Source code for pytagi.nn.lstm

import cutagi

from pytagi.nn.base_layer import BaseLayer


class LSTM(BaseLayer):
    """Python front-end for the C++ ``cutagi.LSTM`` layer.

    Per the layer's original description, the backend walks an input
    sequence one timestep at a time. It carries the hidden state from each
    timestep into the next. That loop lives in the C++ backend, not in this
    module.

    Every constructor argument is kept as a public attribute of the same
    name. The arguments are also handed, in declaration order, to
    ``cutagi.LSTM``. The remaining methods are thin delegations to that
    backend object.

    Args:
        input_size: Size of the input fed to the layer (presumably the
            number of features per timestep -- TODO confirm against cutagi).
        output_size: Size of the layer's output.
        last_timestep: Flag forwarded to the backend. Presumably it limits
            the output to the final timestep only -- TODO confirm.
        seq_len: Sequence length passed to the backend. Defaults to 1.
        bias: Whether the layer uses bias terms. Defaults to True.
        gain_weight: Gain forwarded to the backend. Assumed to scale the
            weight initialization -- verify. Defaults to 1.0.
        gain_bias: Gain forwarded to the backend. Assumed to scale the bias
            initialization -- verify. Defaults to 1.0.
        init_method: Name of the initialization scheme forwarded to the
            backend. Defaults to "He".
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        last_timestep: bool = False,
        seq_len: int = 1,
        bias: bool = True,
        gain_weight: float = 1.0,
        gain_bias: float = 1.0,
        init_method: str = "He",
    ):
        super().__init__()

        # Mirror the configuration on the Python object so callers can
        # inspect it without going through the C++ binding.
        self.input_size = input_size
        self.output_size = output_size
        self.last_timestep = last_timestep
        self.seq_len = seq_len
        self.bias = bias
        self.gain_weight = gain_weight
        self.gain_bias = gain_bias
        self.init_method = init_method

        # The binding is called positionally, so this tuple's order must
        # match the cutagi.LSTM constructor signature exactly.
        backend_args = (
            input_size,
            output_size,
            last_timestep,
            seq_len,
            bias,
            gain_weight,
            gain_bias,
            init_method,
        )
        self._cpp_backend = cutagi.LSTM(*backend_args)

    def get_layer_info(self) -> str:
        """Return the layer-info string reported by the C++ backend."""
        backend = self._cpp_backend
        return backend.get_layer_info()

    def get_layer_name(self) -> str:
        """Return the layer name reported by the C++ backend."""
        backend = self._cpp_backend
        return backend.get_layer_name()

    def init_weight_bias(self):
        """Delegate weight and bias initialization to the C++ backend.

        Nothing is returned; any result from the backend call is discarded.
        """
        backend = self._cpp_backend
        backend.init_weight_bias()