pytagi.nn.output_updater
Classes
OutputUpdater: A utility to compute the error signal (delta states) for the output layer.
Module Contents
- class pytagi.nn.output_updater.OutputUpdater(model_device: str)
A utility to compute the error signal (delta states) for the output layer.
This class calculates the difference between the model’s predictions and the observations, which is essential for performing the backward pass to update the model’s parameters. It wraps the C++/CUDA backend cutagi.OutputUpdater.
Initializes the OutputUpdater.
- Parameters:
model_device (str) – The computational device the model is on (e.g., ‘cpu’ or ‘cuda:0’).
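To make "error signal" concrete, the following NumPy sketch computes a Gaussian innovation for a small output vector. The helper name `delta_sketch` and the formulas are assumptions chosen for illustration; the real computation is carried out by the cutagi backend and may differ in detail.

```python
import numpy as np

def delta_sketch(mu_pred, var_pred, mu_obs, var_obs):
    """Illustrative error signal for Gaussian outputs (hypothetical helper)."""
    mu_pred = np.asarray(mu_pred, dtype=float)
    var_pred = np.asarray(var_pred, dtype=float)
    mu_obs = np.asarray(mu_obs, dtype=float)
    var_obs = np.asarray(var_obs, dtype=float)

    # Total uncertainty: predictive variance plus observation noise.
    total_var = var_pred + var_obs
    # Assumed form: mean correction is the residual scaled by total variance.
    delta_mu = (mu_obs - mu_pred) / total_var
    # Assumed form: variance correction is negative (observing reduces uncertainty).
    delta_var = -1.0 / total_var
    return delta_mu, delta_var

delta_mu, delta_var = delta_sketch(
    mu_pred=[0.5, 1.0],
    var_pred=[0.1, 0.2],
    mu_obs=[1.0, 1.0],
    var_obs=[0.4, 0.3],
)
```

The first output is pulled toward its observation, while the second already matches its observation and receives no mean correction.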
- update(output_states: pytagi.nn.data_struct.BaseHiddenStates, mu_obs: numpy.ndarray, var_obs: numpy.ndarray, delta_states: pytagi.nn.data_struct.BaseDeltaStates)
Computes the delta states based on observations.
This method is intended for homoscedastic regression, where the observation variance is known and passed in as var_obs.
- Parameters:
output_states (pytagi.nn.data_struct.BaseHiddenStates) – The hidden states (mean and variance) of the model’s output layer.
mu_obs (np.ndarray) – The mean of the ground truth observations.
var_obs (np.ndarray) – The variance of the ground truth observations.
delta_states (pytagi.nn.data_struct.BaseDeltaStates) – The delta states object to be updated with the computed error signal.
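A minimal sketch of where update might sit in a training step is shown below. Only the update signature comes from this page; the network attributes `output_z_buffer` and `input_delta_z_buffer` and the methods `backward` and `step` are assumptions about the network object, and `train_step` is a hypothetical helper name.

```python
def train_step(net, out_updater, x, y, var_y):
    """One forward/update/backward pass (sketch; network attribute names are assumed)."""
    # Forward pass; assumed to populate net.output_z_buffer.
    net(x)

    # Compute the output-layer error signal from the observations.
    out_updater.update(
        output_states=net.output_z_buffer,      # assumed attribute
        mu_obs=y,
        var_obs=var_y,
        delta_states=net.input_delta_z_buffer,  # assumed attribute
    )

    # Propagate the error signal and apply the parameter update
    # (both method names are assumptions).
    net.backward()
    net.step()
```

In practice, out_updater would be an OutputUpdater constructed with the same device string as the network, and y and var_y would be NumPy arrays matching the output size.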
- update_using_indices(output_states: pytagi.nn.data_struct.BaseHiddenStates, mu_obs: numpy.ndarray, var_obs: numpy.ndarray, selected_idx: numpy.ndarray, delta_states: pytagi.nn.data_struct.BaseDeltaStates)
Computes the delta states for a selected subset of outputs.
This is useful in scenarios like hierarchical softmax or when only a sparse set of outputs needs to be updated.
- Parameters:
output_states (pytagi.nn.data_struct.BaseHiddenStates) – The hidden states of the model’s output layer.
mu_obs (np.ndarray) – The mean of the ground truth observations.
var_obs (np.ndarray) – The variance of the ground truth observations.
selected_idx (np.ndarray) – An array of indices specifying which output neurons to update.
delta_states (pytagi.nn.data_struct.BaseDeltaStates) – The delta states object to be updated with the computed error signal.
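The idea of updating only selected outputs can be sketched in plain NumPy as follows. The helper `delta_selected_sketch`, the formulas, and the 0-based indexing are all assumptions for illustration; the backend's index convention and exact computation may differ.

```python
import numpy as np

def delta_selected_sketch(mu_pred, var_pred, mu_obs, var_obs, selected_idx):
    """Error signal for a subset of outputs; unselected outputs get zero (sketch)."""
    mu_pred = np.asarray(mu_pred, dtype=float)
    var_pred = np.asarray(var_pred, dtype=float)
    mu_obs = np.asarray(mu_obs, dtype=float)
    var_obs = np.asarray(var_obs, dtype=float)
    idx = np.asarray(selected_idx, dtype=int)  # 0-based here (assumption)

    delta_mu = np.zeros_like(mu_pred)
    delta_var = np.zeros_like(mu_pred)

    # Observations are given only for the selected outputs.
    total_var = var_pred[idx] + var_obs
    delta_mu[idx] = (mu_obs - mu_pred[idx]) / total_var
    delta_var[idx] = -1.0 / total_var
    return delta_mu, delta_var

d_mu, d_var = delta_selected_sketch(
    mu_pred=[0.2, -0.1, 0.7, 0.0],
    var_pred=[0.5, 0.5, 0.5, 0.5],
    mu_obs=[1.0, -1.0],
    var_obs=[0.5, 0.5],
    selected_idx=[0, 2],
)
```

Outputs 1 and 3 are not selected, so their entries stay at zero and would contribute nothing to the backward pass in this sketch.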
- update_heteros(output_states: pytagi.nn.data_struct.BaseHiddenStates, mu_obs: numpy.ndarray, delta_states: pytagi.nn.data_struct.BaseDeltaStates)
Computes delta states for heteroscedastic regression.
Here the model is expected to predict both the mean and the variance of the output, so no var_obs argument is needed; the predicted variance is read from output_states.
- Parameters:
output_states (pytagi.nn.data_struct.BaseHiddenStates) – The hidden states of the model’s output layer. The model’s predicted variance is sourced from here.
mu_obs (np.ndarray) – The mean of the ground truth observations.
delta_states (pytagi.nn.data_struct.BaseDeltaStates) – The delta states object to be updated with the computed error signal.
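The sketch below illustrates how a noise variance predicted by the model could stand in for var_obs. It assumes, purely for illustration, that the output layer interleaves mean and variance entries as [mean, variance, mean, variance, ...]; that layout, the helper name `heteros_delta_mu_sketch`, and the formula are assumptions, and the update for the variance outputs themselves is omitted.

```python
import numpy as np

def heteros_delta_mu_sketch(mu_out, var_out, mu_obs):
    """Mean error signal using model-predicted noise variance (sketch)."""
    mu_out = np.asarray(mu_out, dtype=float)
    var_out = np.asarray(var_out, dtype=float)
    mu_obs = np.asarray(mu_obs, dtype=float)

    # Assumed layout: even entries are predicted means,
    # odd entries are predicted observation-noise variances.
    mu_pred = mu_out[0::2]
    noise_var = mu_out[1::2]
    # Uncertainty of the predicted means themselves.
    var_pred = var_out[0::2]

    # Residual scaled by total variance; noise_var plays the role of var_obs.
    return (mu_obs - mu_pred) / (var_pred + noise_var)

h_delta_mu = heteros_delta_mu_sketch(
    mu_out=[1.0, 0.3, 2.0, 0.5],
    var_out=[0.2, 0.01, 0.5, 0.02],
    mu_obs=[1.5, 1.0],
)
```

Because the noise level is part of the prediction, each output is scaled by its own predicted variance rather than by a value supplied by the caller.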