pytagi.nn.resnet_block

Classes

ResNetBlock

A Residual Network (ResNet) block structure.

Module Contents

class pytagi.nn.resnet_block.ResNetBlock(main_block: pytagi.nn.base_layer.BaseLayer | pytagi.nn.layer_block.LayerBlock, shortcut: pytagi.nn.base_layer.BaseLayer | pytagi.nn.layer_block.LayerBlock = None)

Bases: pytagi.nn.base_layer.BaseLayer

A Residual Network (ResNet) block structure.

This class implements the core structure of a ResNet block. It consists of a main block, which performs the primary transformations, and an optional shortcut connection, which carries the block's input (either unchanged or through a projection) so that it can be added to the main block's output. It wraps the C++/CUDA backend class cutagi.ResNetBlock.
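In TAGI, hidden states are Gaussian and are carried as a mean and a variance per unit. The sketch below shows one way the two branch outputs could be merged. It assumes the main-block and shortcut outputs are treated as uncorrelated, so their means and variances simply add. This drops any covariance between the branches, and that is an approximation. The function name `merge_branches` and the list-based state are hypothetical, and the sketch is not the cutagi.ResNetBlock implementation.

```python
from typing import List, Tuple

# Hypothetical representation: a hidden state is a pair of equal-length
# lists (means, variances), one entry per hidden unit.
GaussianState = Tuple[List[float], List[float]]


def merge_branches(main_out: GaussianState, shortcut_out: GaussianState) -> GaussianState:
    """Element-wise sum of two Gaussian hidden states.

    Assumes the branches are uncorrelated: the variance of the sum is taken
    as the sum of the variances, with no covariance term.
    """
    main_mu, main_var = main_out
    sc_mu, sc_var = shortcut_out
    if len(main_mu) != len(sc_mu) or len(main_var) != len(sc_var):
        raise ValueError("main block and shortcut outputs must have the same size")
    mu = [a + b for a, b in zip(main_mu, sc_mu)]
    var = [a + b for a, b in zip(main_var, sc_var)]
    return mu, var


if __name__ == "__main__":
    main_out = ([1.0, 2.0], [0.5, 0.25])
    shortcut_out = ([0.5, -1.0], [0.5, 0.25])
    print(merge_branches(main_out, shortcut_out))  # ([1.5, 1.0], [1.0, 0.5])
```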

Initializes the ResNetBlock.

Parameters:
  • main_block (Union[BaseLayer, LayerBlock]) – The primary set of layers in the block (e.g., convolutional layers).

  • shortcut (Union[BaseLayer, LayerBlock], optional) – The optional shortcut connection, often an identity mapping or a projection. If None, an identity shortcut is implicitly assumed by the C++ backend.
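The identity default only works when the main block preserves the shape of its input. Otherwise the two outputs cannot be added element-wise, which is the usual reason to pass a projection shortcut. The following sketch illustrates that check on hypothetical (channels, height, width) tuples. The helper `shortcut_output_shape` is not part of pytagi, and this page does not say whether pytagi validates shapes this way.

```python
from typing import Optional, Tuple

Shape = Tuple[int, ...]


def shortcut_output_shape(
    input_shape: Shape,
    main_output_shape: Shape,
    projection_output_shape: Optional[Shape] = None,
) -> Shape:
    """Return the shape produced by the shortcut branch after checking that it
    can be added element-wise to the main block's output.

    projection_output_shape=None stands for "no shortcut layer given", which is
    treated as an identity shortcut that reproduces the input shape.
    """
    sc_shape = input_shape if projection_output_shape is None else projection_output_shape
    if sc_shape != main_output_shape:
        raise ValueError(
            f"shortcut output {sc_shape} does not match main block output "
            f"{main_output_shape}; pass a projection shortcut"
        )
    return sc_shape


if __name__ == "__main__":
    # Shape-preserving main block: the identity default is enough.
    print(shortcut_output_shape((64, 32, 32), (64, 32, 32)))  # (64, 32, 32)
    # Downsampling main block paired with a projection shortcut of matching shape.
    print(shortcut_output_shape((64, 32, 32), (128, 16, 16), (128, 16, 16)))  # (128, 16, 16)
```

With a downsampling main block and no shortcut, the call above raises `ValueError`. This mirrors the situation in which a projection shortcut has to be supplied.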

init_shortcut_state() → None

Initializes the hidden state buffers for the shortcut layer.

init_shortcut_delta_state() → None

Initializes the delta state buffers (error signals) for the shortcut layer.

init_input_buffer() → None

Initializes the input state buffer used to hold the input for both the main block and the shortcut.
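A dedicated input buffer matters if the main block's output can overwrite the storage that held the block's input. In that case the shortcut would otherwise read modified values. The sketch below illustrates that hazard under a stated assumption: the toy main block writes its result in place. The class `ResidualBufferSketch` and the helper `double_inplace` are hypothetical, and only means are tracked (variances are omitted for brevity). This is not a description of how cutagi manages its buffers.

```python
from typing import Callable, List

Vector = List[float]


class ResidualBufferSketch:
    """Hypothetical residual block with an identity shortcut.

    The main block writes its result into the list it receives, so the
    block snapshots its input first and the shortcut reads that snapshot.
    """

    def __init__(self, main_block_inplace: Callable[[Vector], None]) -> None:
        self.main_block_inplace = main_block_inplace
        self.input_buffer: Vector = []

    def forward(self, state: Vector) -> Vector:
        self.input_buffer = list(state)  # copy shared by both branches
        self.main_block_inplace(state)   # overwrites `state` in place
        # Identity shortcut: add the buffered input, not the overwritten state.
        return [m + s for m, s in zip(state, self.input_buffer)]


def double_inplace(v: Vector) -> None:
    """Toy main block: doubles every entry in place."""
    for i in range(len(v)):
        v[i] *= 2.0


if __name__ == "__main__":
    block = ResidualBufferSketch(double_inplace)
    x = [1.0, 2.0]
    print(block.forward(x))  # [3.0, 6.0]; without the buffer it would be [4.0, 8.0]
```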

property main_block: pytagi.nn.layer_block.LayerBlock

Gets the main block component of the ResNet block.

property shortcut: pytagi.nn.base_layer.BaseLayer

Gets the shortcut component of the ResNet block.

property input_z: pytagi.nn.data_struct.BaseHiddenStates

Gets the buffered input hidden states (mean and variance) for the block.

property input_delta_z: pytagi.nn.data_struct.BaseDeltaStates

Gets the delta states (error signals) associated with the block’s input.

property shortcut_output_z: pytagi.nn.data_struct.BaseHiddenStates

Gets the output hidden states (mean and variance) from the shortcut layer.

property shortcut_output_delta_z: pytagi.nn.data_struct.BaseDeltaStates

Gets the delta states (error signals) associated with the shortcut layer’s output.