Source code for pytagi.nn.data_struct
from typing import List, Optional
import cutagi
[docs]
class BaseHiddenStates:
    """
    Python wrapper around the C++ backend object ``cutagi.BaseHiddenStates``.

    Exposes the mean (mu_a), variance (var_a), and Jacobian (jcb) of the hidden
    states. Every property and method forwards to the backend instance held in
    ``self._cpp_backend``; no state is kept on the Python side.
    """

    def __init__(
        self, size: Optional[int] = None, block_size: Optional[int] = None
    ):
        """
        Initializes the BaseHiddenStates.

        Args:
            size (Optional[int]): The size of the hidden states.
            block_size (Optional[int]): The block size for the hidden states.

        Note:
            The sized backend constructor is used only when BOTH arguments are
            given. If either one is None, the backend's default constructor is
            used and the other argument is ignored.
        """
        if size is not None and block_size is not None:
            self._cpp_backend = cutagi.BaseHiddenStates(size, block_size)
        else:
            self._cpp_backend = cutagi.BaseHiddenStates()

    @property
    def mu_a(self) -> List[float]:
        """
        Gets or sets the mean of the hidden states (mu_a).
        """
        return self._cpp_backend.mu_a

    @mu_a.setter
    def mu_a(self, value: List[float]) -> None:
        self._cpp_backend.mu_a = value

    @property
    def var_a(self) -> List[float]:
        """
        Gets or sets the variance of the hidden states (var_a).
        """
        return self._cpp_backend.var_a

    @var_a.setter
    def var_a(self, value: List[float]) -> None:
        self._cpp_backend.var_a = value

    @property
    def jcb(self) -> List[float]:
        """
        Gets or sets the Jacobian of the hidden states (jcb).
        """
        return self._cpp_backend.jcb

    @jcb.setter
    def jcb(self, value: List[float]) -> None:
        self._cpp_backend.jcb = value

    @property
    def size(self) -> int:
        """
        Gets the size of the hidden states (read-only).
        """
        return self._cpp_backend.size

    @property
    def block_size(self) -> int:
        """
        Gets the block size of the hidden states (read-only).
        """
        return self._cpp_backend.block_size

    @property
    def actual_size(self) -> int:
        """
        Gets the actual size of the hidden states (read-only).
        """
        return self._cpp_backend.actual_size

    def get_name(self) -> str:
        """
        Gets the name of the hidden states type.

        Returns:
            str: The name of the hidden states type, as reported by the backend.
        """
        return self._cpp_backend.get_name()

    def set_size(self, new_size: int, new_block_size: int) -> str:
        """
        Sets a new size and block size for the hidden states.

        Args:
            new_size (int): The new size.
            new_block_size (int): The new block size.

        Returns:
            str: Whatever the backend's ``set_size`` returns, passed through
            unchanged.
            NOTE(review): this is annotated as a status message; confirm the
            C++ binding actually returns one (a void binding yields None).
        """
        # Propagate the backend result; it was previously dropped, so this
        # method always returned None despite its declared return type.
        return self._cpp_backend.set_size(new_size, new_block_size)
[docs]
class BaseDeltaStates:
    """
    Python wrapper around the C++ backend object ``cutagi.BaseDeltaStates``.

    Exposes the change in mean (delta_mu) and change in variance (delta_var)
    induced by the update step. Every property and method forwards to the
    backend instance held in ``self._cpp_backend``.
    """

    def __init__(
        self, size: Optional[int] = None, block_size: Optional[int] = None
    ):
        """
        Initializes the BaseDeltaStates.

        Args:
            size (Optional[int]): The size of the delta states.
            block_size (Optional[int]): The block size for the delta states.

        Note:
            The sized backend constructor is used only when BOTH arguments are
            given. If either one is None, the backend's default constructor is
            used and the other argument is ignored.
        """
        if size is not None and block_size is not None:
            self._cpp_backend = cutagi.BaseDeltaStates(size, block_size)
        else:
            self._cpp_backend = cutagi.BaseDeltaStates()

    @property
    def delta_mu(self) -> List[float]:
        """
        Gets or sets the change in mean of the delta states (delta_mu).
        """
        return self._cpp_backend.delta_mu

    @delta_mu.setter
    def delta_mu(self, value: List[float]) -> None:
        self._cpp_backend.delta_mu = value

    @property
    def delta_var(self) -> List[float]:
        """
        Gets or sets the change in variance of the delta states (delta_var).
        """
        return self._cpp_backend.delta_var

    @delta_var.setter
    def delta_var(self, value: List[float]) -> None:
        self._cpp_backend.delta_var = value

    @property
    def size(self) -> int:
        """
        Gets the size of the delta states (read-only).
        """
        return self._cpp_backend.size

    @property
    def block_size(self) -> int:
        """
        Gets the block size of the delta states (read-only).
        """
        return self._cpp_backend.block_size

    @property
    def actual_size(self) -> int:
        """
        Gets the actual size of the delta states (read-only).
        """
        return self._cpp_backend.actual_size

    def get_name(self) -> str:
        """
        Gets the name of the delta states type.

        Returns:
            str: The name of the delta states type, as reported by the backend.
        """
        return self._cpp_backend.get_name()

    def reset_zeros(self) -> None:
        """Reset all delta_mu and delta_var to zeros."""
        self._cpp_backend.reset_zeros()

    def copy_from(self, source: "BaseDeltaStates", num_data: int = -1) -> None:
        """
        Copy values of delta_mu and delta_var from another delta states object.

        Args:
            source (BaseDeltaStates): The source delta states object to copy
                from. A raw backend object is also accepted and passed through
                as-is.
            num_data (int): The number of data points to copy. Defaults to -1 (all).
        """
        # The backend operates on backend objects, not on this Python wrapper,
        # so unwrap the source first. Objects without a ``_cpp_backend``
        # attribute are forwarded untouched.
        backend_source = getattr(source, "_cpp_backend", source)
        self._cpp_backend.copy_from(backend_source, num_data)

    def set_size(self, new_size: int, new_block_size: int) -> str:
        """
        Sets a new size and block size for the delta states.

        Args:
            new_size (int): The new size.
            new_block_size (int): The new block size.

        Returns:
            str: Whatever the backend's ``set_size`` returns, passed through
            unchanged.
            NOTE(review): this is annotated as a status message; confirm the
            C++ binding actually returns one (a void binding yields None).
        """
        # Propagate the backend result; it was previously dropped, so this
        # method always returned None despite its declared return type.
        return self._cpp_backend.set_size(new_size, new_block_size)
[docs]
class HRCSoftmax:
    """
    Hierarchical softmax wrapper from the CPP backend.

    Every property forwards to the ``cutagi.HRCSoftmax`` instance held in
    ``self._cpp_backend``; no state is kept on the Python side.
    """

    def __init__(self) -> None:
        """Initializes the HRCSoftmax object."""
        self._cpp_backend = cutagi.HRCSoftmax()

    @property
    def obs(self) -> List[float]:
        # Raw docstring: ``\in`` is not a valid escape sequence in a normal
        # string literal and triggers a warning at compile time.
        r"""
        Gets or sets the fictive observation \in [-1, 1].
        """
        return self._cpp_backend.obs

    @obs.setter
    def obs(self, value: List[float]) -> None:
        self._cpp_backend.obs = value

    @property
    def idx(self) -> List[int]:
        """
        Gets or sets the indices assigned to each label.
        """
        return self._cpp_backend.idx

    @idx.setter
    def idx(self, value: List[int]) -> None:
        self._cpp_backend.idx = value

    @property
    def num_obs(self) -> int:
        """
        Gets or sets the number of indices for each label.
        """
        return self._cpp_backend.num_obs

    @num_obs.setter
    def num_obs(self, value: int) -> None:
        self._cpp_backend.num_obs = value

    @property
    def len(self) -> int:
        """
        Gets or sets the length of an observation (e.g., 10 labels -> len(obs) = 11).
        """
        return self._cpp_backend.len

    @len.setter
    def len(self, value: int) -> None:
        self._cpp_backend.len = value