Source code for pytagi.nn.lstm
import cutagi
from pytagi.nn.base_layer import BaseLayer
[docs]
class LSTM(BaseLayer):
    """Temporal LSTM layer backed by ``cutagi.LSTM``.

    The layer is described as processing sequences with an explicit time
    loop that propagates hidden states between timesteps. No such loop
    appears in this Python class: it records its configuration and
    forwards every call to the backend object it creates.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        last_timestep: bool = False,
        seq_len: int = 1,
        bias: bool = True,
        gain_weight: float = 1.0,
        gain_bias: float = 1.0,
        init_method: str = "He",
    ) -> None:
        """Store the configuration and build the ``cutagi.LSTM`` backend.

        Every argument is kept as an attribute of the same name and then
        handed, positionally and in signature order, to ``cutagi.LSTM``.
        The meanings below are inferred from the parameter names and
        should be confirmed against the backend.

        Args:
            input_size: Presumably the size of the input at each timestep.
            output_size: Presumably the size of the layer output.
            last_timestep: Presumably selects whether only the final
                timestep is emitted. Defaults to False.
            seq_len: Presumably the number of timesteps per sequence.
                Defaults to 1.
            bias: Presumably enables the bias terms. Defaults to True.
            gain_weight: Presumably a scale used when initializing
                weights. Defaults to 1.0.
            gain_bias: Presumably a scale used when initializing biases.
                Defaults to 1.0.
            init_method: Name of the initialization scheme forwarded to
                the backend. Defaults to "He".
        """
        super().__init__()

        # The attributes are written before the backend is created, and in
        # the same order as the signature. BaseLayer is not visible from
        # here, so this sequencing is deliberately left as it was.
        self.input_size, self.output_size = input_size, output_size
        self.last_timestep, self.seq_len = last_timestep, seq_len
        self.bias = bias
        self.gain_weight, self.gain_bias = gain_weight, gain_bias
        self.init_method = init_method

        # Arguments stay positional; the binding's keyword names are not
        # known from this file.
        backend_args = (
            input_size,
            output_size,
            last_timestep,
            seq_len,
            bias,
            gain_weight,
            gain_bias,
            init_method,
        )
        self._cpp_backend = cutagi.LSTM(*backend_args)

    def get_layer_info(self) -> str:
        """Return whatever the backend's ``get_layer_info()`` reports."""
        backend = self._cpp_backend
        return backend.get_layer_info()

    def get_layer_name(self) -> str:
        """Return whatever the backend's ``get_layer_name()`` reports."""
        backend = self._cpp_backend
        return backend.get_layer_name()

    def init_weight_bias(self) -> None:
        """Ask the backend to run its ``init_weight_bias()`` routine."""
        backend = self._cpp_backend
        backend.init_weight_bias()