Delta updates
We present here how TAGI efficiently leverages the Gaussian conditional update equations by relying on delta_mu[] and delta_var[] in order to compute only the change required for the hidden states, without explicitly computing the updated hidden states themselves. We cover the working principles through two examples: the first on the backward step of the linear layer, and the second on the output updater.
Example through the linear layer hidden-states updates
Let’s take the backward step used to update the expected value of the hidden states in the linear layer, i.e., linear_bwd_fc_delta_z() from linear_layer.cpp:
{
    int ni = input_size;
    int no = output_size;
    // Each j maps to one (input node, batch sample) pair
    for (int j = start_chunk; j < end_chunk; j++) {
        int row = j / B;  // input-node index
        int col = j % B;  // batch-sample index
        float sum_mu_z = 0.0f;
        float sum_var_z = 0.0f;
        // Accumulate the contributions of all output nodes
        for (int i = 0; i < no; i++) {
            sum_mu_z += mu_w[ni * i + row] * delta_mu[col * no + i];
            sum_var_z += mu_w[ni * i + row] * delta_var[col * no + i] * mu_w[ni * i + row];
        }
        // NOTE: Compute directly innovation vector
        delta_mu_z[col * ni + row] = sum_mu_z * jcb[col * ni + row];
        delta_var_z[col * ni + row] = sum_var_z * jcb[col * ni + row] * jcb[col * ni + row];
    }
}
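To see the indexing at work, the loop above can be wrapped in a small standalone function. This is only a sketch: the function name linear_bwd_delta_z_sketch, its signature, and the choice to process the full range in a single chunk are assumptions made for illustration, not the signature used in linear_layer.cpp.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical single-chunk wrapper around the loop shown above.
// The name and signature are assumptions; only the loop body follows the excerpt.
void linear_bwd_delta_z_sketch(const std::vector<float> &mu_w,
                               const std::vector<float> &jcb,
                               const std::vector<float> &delta_mu,
                               const std::vector<float> &delta_var,
                               int input_size, int output_size, int B,
                               std::vector<float> &delta_mu_z,
                               std::vector<float> &delta_var_z) {
    const int ni = input_size;
    const int no = output_size;
    // Basic shape checks for the flattened buffers
    assert(mu_w.size() == static_cast<std::size_t>(ni) * no);
    assert(delta_mu.size() == static_cast<std::size_t>(no) * B);
    assert(delta_var.size() == static_cast<std::size_t>(no) * B);
    assert(jcb.size() == static_cast<std::size_t>(ni) * B);
    assert(delta_mu_z.size() == static_cast<std::size_t>(ni) * B);
    assert(delta_var_z.size() == static_cast<std::size_t>(ni) * B);

    for (int j = 0; j < ni * B; j++) {
        int row = j / B;  // input-node index
        int col = j % B;  // batch-sample index
        float sum_mu_z = 0.0f;
        float sum_var_z = 0.0f;
        for (int i = 0; i < no; i++) {
            float w = mu_w[ni * i + row];  // weight linking input `row` to output `i`
            sum_mu_z += w * delta_mu[col * no + i];
            sum_var_z += w * delta_var[col * no + i] * w;
        }
        float jc = jcb[col * ni + row];
        delta_mu_z[col * ni + row] = sum_mu_z * jc;
        delta_var_z[col * ni + row] = sum_var_z * jc * jc;
    }
}
```

For instance, with two inputs, one output, and a batch of one, the weights {0.5, -2.0}, Jacobians {1.0, 0.5}, delta_mu = {2.0} and delta_var = {-0.25} yield delta_mu_z = {1.0, -2.0} and delta_var_z = {-0.0625, -0.25}. Note that the variance of the hidden states never enters the computation.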
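From the original TAGI paper, the RTS update equation for the mean is:
\[\boldsymbol{\mu}_{\boldsymbol{Z}|\boldsymbol{y}} = \boldsymbol{\mu}_{\boldsymbol{Z}} + \mathbf{J}_{\boldsymbol{Z}}\left(\boldsymbol{\mu}_{\boldsymbol{Z}^{+}|\boldsymbol{y}} - \boldsymbol{\mu}_{\boldsymbol{Z}^{+}}\right),\]
where
\[\mathbf{J}_{\boldsymbol{Z}} = \boldsymbol{\Sigma}_{\boldsymbol{Z}\boldsymbol{Z}^{+}}\,\boldsymbol{\Sigma}_{\boldsymbol{Z}^{+}}^{-1}.\]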
Therefore, by omitting the multiplication by \(\sigma_{Z_i}^{2}\), \(\mathtt{delta\_mu\_z[.]}\) (which becomes \(\mathtt{delta\_mu[.]}\) for the subsequent layer during the backward pass) is already pre-divided by \((\sigma^{+}_{Z_i})^{2}\).
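The pre-division can be made explicit by writing the update element-wise. The following is a sketch that assumes diagonal covariance matrices and a cross-covariance of the form \(\mathrm{cov}(Z_i, Z^{+}_j) = \sigma_{Z_i}^{2}\, J_i\, \mu_{W_{ji}}\); the symbols \(J_i\) (the activation Jacobian, \(\mathtt{jcb[.]}\) in the code) and \(\mu_{W_{ji}}\) (\(\mathtt{mu\_w[.]}\)) are notation introduced here for illustration.

\[
\mu_{Z_i|y} = \mu_{Z_i} + \sigma_{Z_i}^{2}\,
\underbrace{J_i \sum_{j} \mu_{W_{ji}}\,
\underbrace{\frac{\mu^{+}_{Z_j|y} - \mu^{+}_{Z_j}}{(\sigma^{+}_{Z_j})^{2}}}_{\mathtt{delta\_mu[j]}}}_{\mathtt{delta\_mu\_z[i]}}
\]

\[
\sigma_{Z_i|y}^{2} = \sigma_{Z_i}^{2} + \sigma_{Z_i}^{2}\,
\underbrace{J_i^{2} \sum_{j} \mu_{W_{ji}}^{2}\,
\underbrace{\frac{(\sigma^{+}_{Z_j|y})^{2} - (\sigma^{+}_{Z_j})^{2}}{(\sigma^{+}_{Z_j})^{4}}}_{\mathtt{delta\_var[j]}}}_{\mathtt{delta\_var\_z[i]}}\,
\sigma_{Z_i}^{2}
\]

Under these assumptions, the quantities passed between layers are the changes in mean and variance, divided by the prior variance once for the mean and twice for the variance, which matches the two accumulators in the loop above.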
Example through the output hidden-state update
Let’s now take the backward step used to update the expected value of the hidden states in the output layer, i.e., compute_delta_z_output() from base_output_updater.cpp:
{
    float zero_pad = 0;
    float tmp = 0;
    // We compute directly the innovation vector for output layer
    for (int col = start_chunk; col < end_chunk; col++) {
        tmp = jcb[col] / (var_a[col] + var_obs[col]);
        // Guard against a zero or invalid denominator
        if (isinf(tmp) || isnan(tmp)) {
            delta_mu[col] = zero_pad;
            delta_var[col] = zero_pad;
        } else {
            delta_mu[col] = tmp * (obs[col] - mu_a[col]);
            delta_var[col] = -tmp * jcb[col];
        }
    }
}
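As a sanity check, the output-layer deltas can be turned back into an explicit posterior and compared with the scalar Kalman update. The sketch below is illustrative only: output_delta_z_sketch, the Gaussian struct, and apply_delta are hypothetical names that are not part of base_output_updater.cpp, and the recovery step assumes that the deltas are pre-divided by the prior variance of the hidden state as described above.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical single-chunk wrapper around the loop shown above
// (name and signature are assumptions).
void output_delta_z_sketch(const std::vector<float> &mu_a,
                           const std::vector<float> &var_a,
                           const std::vector<float> &jcb,
                           const std::vector<float> &obs,
                           const std::vector<float> &var_obs,
                           std::vector<float> &delta_mu,
                           std::vector<float> &delta_var) {
    assert(var_a.size() == mu_a.size() && jcb.size() == mu_a.size());
    assert(obs.size() == mu_a.size() && var_obs.size() == mu_a.size());
    assert(delta_mu.size() == mu_a.size() && delta_var.size() == mu_a.size());
    for (std::size_t col = 0; col < mu_a.size(); col++) {
        float tmp = jcb[col] / (var_a[col] + var_obs[col]);
        if (std::isinf(tmp) || std::isnan(tmp)) {
            // Zero or invalid denominator: leave this state unchanged
            delta_mu[col] = 0.0f;
            delta_var[col] = 0.0f;
        } else {
            delta_mu[col] = tmp * (obs[col] - mu_a[col]);
            delta_var[col] = -tmp * jcb[col];
        }
    }
}

// Hypothetical helper: recovers an explicit posterior from pre-divided deltas.
struct Gaussian {
    float mu;
    float var;
};

Gaussian apply_delta(float mu, float var, float delta_mu, float delta_var) {
    // Multiply back by the prior variance: once for the mean, twice for the variance
    return {mu + var * delta_mu, var + var * delta_var * var};
}
```

With a linear output (jcb = 1), a prior of mean 1 and variance 1, an observation of 3, and an observation variance of 1, the deltas are delta_mu = 1.0 and delta_var = -0.5. Calling apply_delta(1.0f, 1.0f, 1.0f, -0.5f) then gives a mean of 2.0 and a variance of 0.5, the same values as the scalar Kalman update with gain 0.5.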
The corresponding update for the expected value reads: