pytagi.nn.data_struct

Classes

BaseHiddenStates

Represents the base hidden states, acting as a Python wrapper for the C++ backend.

BaseDeltaStates

Represents the base delta states, acting as a Python wrapper for the C++ backend.

HRCSoftmax

Hierarchical softmax wrapper for the C++ backend.

Module Contents

class pytagi.nn.data_struct.BaseHiddenStates(size: int | None = None, block_size: int | None = None)

Represents the base hidden states, acting as a Python wrapper for the C++ backend. This class manages the mean (mu_a), variance (var_a), and Jacobian (jcb) of the hidden states.

Initializes the BaseHiddenStates.

Parameters:
  • size (Optional[int]) – The size of the hidden states.

  • block_size (Optional[int]) – The block size for the hidden states.

property mu_a: List[float]

Gets or sets the mean of the hidden states (mu_a).

property var_a: List[float]

Gets or sets the variance of the hidden states (var_a).

property jcb: List[float]

Gets or sets the Jacobian of the hidden states (jcb).
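To show how mu_a, var_a and jcb relate, the following is a minimal sketch of first-order (locally linearized) moment propagation through a tanh activation: the mean is pushed through the function, the Jacobian is the derivative evaluated at the mean, and the variance is scaled by the squared Jacobian. The function name tanh_moments and the choice of tanh are illustrative assumptions; this is not pytagi's backend code.

```python
import math
from typing import List, Tuple


def tanh_moments(
    mu_z: List[float], var_z: List[float]
) -> Tuple[List[float], List[float], List[float]]:
    """Linearized mean, variance and Jacobian of a = tanh(z), element-wise."""
    mu_a: List[float] = []
    var_a: List[float] = []
    jcb: List[float] = []
    for m, v in zip(mu_z, var_z):
        a = math.tanh(m)
        j = 1.0 - a * a  # derivative of tanh evaluated at the mean of z
        mu_a.append(a)
        jcb.append(j)
        var_a.append(j * j * v)  # first-order approximation: Var[a] ~ J^2 Var[z]
    return mu_a, var_a, jcb


mu_a, var_a, jcb = tanh_moments([0.0, 1.0], [1.0, 0.5])
print(mu_a[0], var_a[0], jcb[0])  # 0.0 1.0 1.0
```

Keeping jcb alongside the moments lets a backward pass reuse the local derivative without recomputing the activation.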

property size: int

Gets the size of the hidden states.

property block_size: int

Gets the block size of the hidden states.

property actual_size: int

Gets the actual size of the hidden states.

set_input_x(mu_x: List[float], var_x: List[float], block_size: int)

Sets the input for the hidden states.

Parameters:
  • mu_x (List[float]) – The mean of the input x.

  • var_x (List[float]) – The variance of the input x.

  • block_size (int) – The block size for the input.
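set_input_x takes flat lists plus a block size. Below is a minimal sketch of preparing them from a nested batch, assuming that samples are concatenated one after another and that block_size is the number of samples in the batch; neither layout detail is stated in this reference. The helper flatten_batch is hypothetical, and only the commented set_input_x call follows the documented signature.

```python
from typing import List, Optional, Sequence, Tuple


def flatten_batch(
    batch_mu: Sequence[Sequence[float]],
    batch_var: Optional[Sequence[Sequence[float]]] = None,
) -> Tuple[List[float], List[float], int]:
    """Flatten a batch of samples into (mu_x, var_x, block_size)."""
    # Concatenate the samples one after another (sample-major order, assumed).
    mu_x = [float(v) for sample in batch_mu for v in sample]
    if batch_var is None:
        # No input uncertainty given: treat the inputs as exactly observed.
        var_x = [0.0] * len(mu_x)
    else:
        var_x = [float(v) for sample in batch_var for v in sample]
    if len(var_x) != len(mu_x):
        raise ValueError("mu_x and var_x must have the same length")
    # block_size is assumed to be the number of samples in the batch.
    return mu_x, var_x, len(batch_mu)


mu_x, var_x, block_size = flatten_batch([[0.1, 0.2], [0.3, 0.4]])
# With pytagi installed, these could then be passed on, e.g.:
# hidden_states.set_input_x(mu_x, var_x, block_size)  # hidden_states: a BaseHiddenStates
```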

get_name() -> str

Gets the name of the hidden states type.

Returns:

The name of the hidden states type.

Return type:

str

set_size(new_size: int, new_block_size: int) -> str

Sets a new size and block size for the hidden states.

Parameters:
  • new_size (int) – The new size.

  • new_block_size (int) – The new block size.

Returns:

A message indicating the success or failure of the operation.

Return type:

str

class pytagi.nn.data_struct.BaseDeltaStates(size: int | None = None, block_size: int | None = None)

Represents the base delta states, acting as a Python wrapper for the C++ backend. This class manages the change in mean (delta_mu) and change in variance (delta_var) induced by the update step.

Initializes the BaseDeltaStates.

Parameters:
  • size (Optional[int]) – The size of the delta states.

  • block_size (Optional[int]) – The block size for the delta states.

property delta_mu: List[float]

Gets or sets the change in mean of the delta states (delta_mu).

property delta_var: List[float]

Gets or sets the change in variance of the delta states (delta_var).
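As an illustration of what "change in mean" and "change in variance" can mean, here is a minimal sketch of the posterior-minus-prior change for a Gaussian variable a observed through y = a + v with v ~ N(0, var_obs), which is standard Gaussian conditioning. The function name gaussian_update_deltas is hypothetical, and the backend may store its deltas in a different (for instance normalized) form, so read this as an illustration of the quantities' meaning only.

```python
from typing import List, Tuple


def gaussian_update_deltas(
    mu_a: List[float], var_a: List[float], y: List[float], var_obs: float
) -> Tuple[List[float], List[float]]:
    """Posterior-minus-prior change in mean and variance of a, given y = a + v."""
    delta_mu: List[float] = []
    delta_var: List[float] = []
    for m, s2, obs in zip(mu_a, var_a, y):
        gain = s2 / (s2 + var_obs)  # Kalman-style gain, between 0 and 1 for var_obs > 0
        delta_mu.append(gain * (obs - m))  # shift the mean toward the observation
        delta_var.append(-gain * s2)  # conditioning never increases the variance
    return delta_mu, delta_var


delta_mu, delta_var = gaussian_update_deltas([0.0], [1.0], [2.0], var_obs=1.0)
print(delta_mu, delta_var)  # [1.0] [-0.5]
```

Adding the deltas to the prior gives the posterior: mean 0.0 + 1.0 = 1.0 and variance 1.0 - 0.5 = 0.5.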

property size: int

Gets the size of the delta states.

property block_size: int

Gets the block size of the delta states.

property actual_size: int

Gets the actual size of the delta states.

get_name() -> str

Gets the name of the delta states type.

Returns:

The name of the delta states type.

Return type:

str

reset_zeros() -> None

Resets all values of delta_mu and delta_var to zero.

copy_from(source: BaseDeltaStates, num_data: int = -1) -> None

Copies the values of delta_mu and delta_var from another delta states object.

Parameters:
  • source (BaseDeltaStates) – The source delta states object to copy from.

  • num_data (int) – The number of data points to copy. Defaults to -1 (all).
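The -1 default acts as an "everything" sentinel. The sketch below mimics reset_zeros and copy_from with a hypothetical pure-Python stand-in class, ToyDeltaStates, which is not part of pytagi. It assumes that num_data counts entries from the start of each list; the exact unit the backend uses for a "data point" is not stated here.

```python
from typing import List


class ToyDeltaStates:
    """Hypothetical stand-in that mimics the documented reset/copy behaviour."""

    def __init__(self, size: int) -> None:
        self.delta_mu: List[float] = [0.0] * size
        self.delta_var: List[float] = [0.0] * size

    def reset_zeros(self) -> None:
        # Overwrite every entry in place with 0.0.
        for i in range(len(self.delta_mu)):
            self.delta_mu[i] = 0.0
            self.delta_var[i] = 0.0

    def copy_from(self, source: "ToyDeltaStates", num_data: int = -1) -> None:
        # A negative num_data (default -1) means "copy everything".
        n = len(source.delta_mu) if num_data < 0 else num_data
        # Never read or write past the end of either object.
        n = min(n, len(source.delta_mu), len(self.delta_mu))
        self.delta_mu[:n] = source.delta_mu[:n]
        self.delta_var[:n] = source.delta_var[:n]


src = ToyDeltaStates(4)
src.delta_mu = [1.0, 2.0, 3.0, 4.0]
src.delta_var = [-0.1, -0.2, -0.3, -0.4]
dst = ToyDeltaStates(4)
dst.copy_from(src, num_data=2)  # dst.delta_mu == [1.0, 2.0, 0.0, 0.0]
```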

set_size(new_size: int, new_block_size: int) -> str

Sets a new size and block size for the delta states.

Parameters:
  • new_size (int) – The new size.

  • new_block_size (int) – The new block size.

Returns:

A message indicating the success or failure of the operation.

Return type:

str

class pytagi.nn.data_struct.HRCSoftmax

Hierarchical softmax wrapper for the C++ backend.

Initializes the HRCSoftmax object.

property obs: List[float]

Gets or sets the fictive observation in [-1, 1].

property idx: List[int]

Gets or sets the indices assigned to each label.

property num_obs: int

Gets or sets the number of indices for each label.

property len: int

Gets or sets the length of an observation (e.g., 10 labels -> len(obs) = 11).
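To make obs, idx, num_obs and len concrete, here is a minimal pure-Python sketch of one common hierarchical-softmax encoding. Each label becomes a root-to-leaf path in a balanced binary tree of depth ceil(log2 K), with one ±1 fictive observation per tree node on that path. The function name hierarchical_codes, the breadth-first node numbering and the sign convention are assumptions for illustration and need not match the backend's ordering. For 10 labels it yields 4 observations per label and 11 distinct nodes, consistent with the len example above.

```python
import math
from typing import Dict, List, Tuple


def hierarchical_codes(
    num_classes: int,
) -> Tuple[List[List[float]], List[List[int]], int, int]:
    """Encode each label as a root-to-leaf path in a balanced binary tree.

    Returns (obs, idx, num_obs, length): obs[k] holds the +/-1 observations
    for label k, idx[k] the node indices they apply to, num_obs the path
    depth, and length the number of distinct nodes used by any label.
    """
    num_obs = max(1, math.ceil(math.log2(num_classes)))
    # Binary code of each label, most significant bit first (root decision first).
    codes = [format(k, f"0{num_obs}b") for k in range(num_classes)]

    # A node is identified by the path prefix leading to it. IDs are assigned
    # breadth-first (level by level, in label order), and only nodes that some
    # label actually passes through are created.
    node_ids: Dict[str, int] = {}
    for level in range(num_obs):
        for code in codes:
            node_ids.setdefault(code[:level], len(node_ids))

    # Assumed sign convention: bit 0 (go left) -> +1, bit 1 (go right) -> -1.
    obs = [[1.0 if bit == "0" else -1.0 for bit in code] for code in codes]
    idx = [[node_ids[code[:level]] for level in range(num_obs)] for code in codes]
    return obs, idx, num_obs, len(node_ids)


obs, idx, num_obs, length = hierarchical_codes(10)
print(num_obs, length)  # 4 11
```

Only the num_obs entries listed in idx[k] are observed for label k; the remaining entries of the length-len vector are left unconstrained.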