pytagi.nn.data_struct#
Classes#
BaseHiddenStates | Represents the base hidden states, acting as a Python wrapper for the C++ backend. |
BaseDeltaStates | Represents the base delta states, acting as a Python wrapper for the C++ backend. |
HRCSoftmax | Hierarchical softmax wrapper from the C++ backend. |
Module Contents#
- class pytagi.nn.data_struct.BaseHiddenStates(size: int | None = None, block_size: int | None = None)[source]#
Represents the base hidden states, acting as a Python wrapper for the C++ backend. This class manages the mean (mu_a), variance (var_a), and Jacobian (jcb) of hidden states.
Initializes the BaseHiddenStates.
- Parameters:
size (Optional[int]) – The size of the hidden states.
block_size (Optional[int]) – The block size for the hidden states.
- set_input_x(mu_x: List[float], var_x: List[float], block_size: int)[source]#
Sets the input for the hidden states.
- Parameters:
mu_x (List[float]) – The mean of the input x.
var_x (List[float]) – The variance of the input x.
block_size (int) – The block size for the input.
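The sketch below shows one way the list arguments for set_input_x could be prepared from a batch of samples. It is plain Python and does not call pytagi. The helper name flatten_batch is hypothetical, and the layout (samples stored back to back, with block_size equal to the number of samples) is an assumption to check against the backend rather than documented behaviour.

```python
from typing import List, Sequence, Tuple


def flatten_batch(
    mu_batch: Sequence[Sequence[float]],
    var_batch: Sequence[Sequence[float]],
) -> Tuple[List[float], List[float], int]:
    """Flatten per-sample means and variances into flat lists.

    Hypothetical helper, not part of pytagi.
    Assumption: samples are stored one after another and block_size
    is the number of samples in the batch.
    """
    if len(mu_batch) != len(var_batch):
        raise ValueError("mu_batch and var_batch must hold the same number of samples")

    mu_x: List[float] = []
    var_x: List[float] = []
    for mu_row, var_row in zip(mu_batch, var_batch):
        if len(mu_row) != len(var_row):
            raise ValueError("each sample needs one variance per mean")
        mu_x.extend(float(v) for v in mu_row)
        var_x.extend(float(v) for v in var_row)

    block_size = len(mu_batch)
    return mu_x, var_x, block_size


# Two samples with three features each.
mu_x, var_x, block_size = flatten_batch(
    [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
    [[0.01, 0.01, 0.01], [0.02, 0.02, 0.02]],
)
```

With pytagi installed, the three values could then be passed to set_input_x(mu_x, var_x, block_size) on a BaseHiddenStates instance whose size is large enough to hold them.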
- class pytagi.nn.data_struct.BaseDeltaStates(size: int | None = None, block_size: int | None = None)[source]#
Represents the base delta states, acting as a Python wrapper for the C++ backend. This class manages the change in mean (delta_mu) and change in variance (delta_var) induced by the update step.
Initializes the BaseDeltaStates.
- Parameters:
size (Optional[int]) – The size of the delta states.
block_size (Optional[int]) – The block size for the delta states.
- property delta_mu: List[float][source]#
Gets or sets the change in mean of the delta states (delta_mu).
- property delta_var: List[float][source]#
Gets or sets the change in variance of the delta states (delta_var).
- get_name() → str [source]#
Gets the name of the delta states type.
- Returns:
The name of the delta states type.
- Return type:
str
- copy_from(source: BaseDeltaStates, num_data: int = -1) → None [source]#
Copies the values of delta_mu and delta_var from another delta states object.
- Parameters:
source (BaseDeltaStates) – The source delta states object to copy from.
num_data (int) – The number of data points to copy. Defaults to -1, which copies all of them.
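The copy semantics described above can be illustrated with a small pure-Python stand-in. DeltaStatesSketch is a hypothetical class written for this example, not the pytagi wrapper. It assumes that num_data counts leading list entries and that -1 copies as many entries as both objects hold; the real backend may differ.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class DeltaStatesSketch:
    """Hypothetical stand-in for a delta states container (not pytagi)."""

    delta_mu: List[float] = field(default_factory=list)
    delta_var: List[float] = field(default_factory=list)

    def copy_from(self, source: "DeltaStatesSketch", num_data: int = -1) -> None:
        # Assumption: -1 means "copy every entry both objects can hold".
        limit = min(len(source.delta_mu), len(self.delta_mu))
        if num_data == -1:
            num_data = limit
        if num_data < 0 or num_data > limit:
            raise ValueError("num_data is out of range for these objects")
        # Overwrite only the first num_data entries; the rest stay untouched.
        self.delta_mu[:num_data] = source.delta_mu[:num_data]
        self.delta_var[:num_data] = source.delta_var[:num_data]


src = DeltaStatesSketch([1.0, 2.0, 3.0], [0.1, 0.2, 0.3])

# Partial copy: only the first two entries are taken from src.
partial = DeltaStatesSketch([0.0, 0.0, 0.0], [0.0, 0.0, 0.0])
partial.copy_from(src, num_data=2)

# Default copy: num_data=-1 copies everything.
full = DeltaStatesSketch([0.0, 0.0, 0.0], [0.0, 0.0, 0.0])
full.copy_from(src)
```

Copying slices rather than sharing the list objects keeps the destination independent of the source, so later changes to src do not leak into the copies.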