Source code for pytagi.nn.slstm
import cutagi
from pytagi.nn.base_layer import BaseLayer
[docs]
class SLSTM(BaseLayer):
    """Smoothing LSTM layer backed by the compiled ``cutagi.SLSTM`` object.

    A variant of the standard Long Short-Term Memory layer that adds
    smoothing of its hidden and cell states. This Python class only stores
    the configuration and forwards every call to the C++/CUDA backend.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        seq_len: int,
        bias: bool = True,
        gain_weight: float = 1.0,
        gain_bias: float = 1.0,
        init_method: str = "He",
    ):
        """Record the layer configuration and build the ``cutagi.SLSTM`` backend.

        :param input_size: Number of features in each input $x$.
        :type input_size: int
        :param output_size: Number of features in the hidden state $h$, which is also the layer output.
        :type output_size: int
        :param seq_len: Sequence length the layer is configured for.
        :type seq_len: int
        :param bias: When ``True``, the internal linear transformations use bias terms.
        :type bias: bool
        :param gain_weight: Scaling factor applied when initializing the weights.
        :type gain_weight: float
        :param gain_bias: Scaling factor applied when initializing the biases.
        :type gain_bias: float
        :param init_method: Name of the initialization scheme (default ``"He"``; e.g. ``"Xavier"``).
        :type init_method: str
        """
        super().__init__()

        # NOTE(review): BaseLayer is not visible here and may implement some of
        # these names as properties, so keep this assignment order and keep it
        # ahead of the backend construction below.
        self.input_size = input_size
        self.output_size = output_size
        self.seq_len = seq_len
        self.bias = bias
        self.gain_weight = gain_weight
        self.gain_bias = gain_bias
        self.init_method = init_method

        # Build the backend from the raw arguments, not from attribute
        # read-backs, in the positional order cutagi.SLSTM expects.
        backend_args = (
            input_size,
            output_size,
            seq_len,
            bias,
            gain_weight,
            gain_bias,
            init_method,
        )
        self._cpp_backend = cutagi.SLSTM(*backend_args)

    def get_layer_info(self) -> str:
        """Return the backend's textual description of this layer's configuration."""
        backend = self._cpp_backend
        return backend.get_layer_info()

    def get_layer_name(self) -> str:
        """Return the layer name as reported by the backend."""
        backend = self._cpp_backend
        return backend.get_layer_name()

    def init_weight_bias(self):
        """Have the backend initialize the layer's weights and biases using the configured method."""
        backend = self._cpp_backend
        backend.init_weight_bias()