pytagi.nn.ddp#
Classes#
- DDPConfig: Configuration for Distributed Data Parallel (DDP) training.
- DDPSequential: A wrapper for Sequential models to enable Distributed Data Parallel (DDP) training.
Module Contents#
- class pytagi.nn.ddp.DDPConfig(device_ids: List[int], backend: str = 'nccl', rank: int = 0, world_size: int = 1)[source]#
Configuration for Distributed Data Parallel (DDP) training.
This class holds all the necessary settings for initializing a distributed process group.
Initializes the DDP configuration.
- Parameters:
device_ids (List[int]) – A list of GPU device IDs to be used for training.
backend (str, optional) – The distributed backend to use; 'nccl' is recommended for GPUs. Defaults to 'nccl'.
rank (int, optional) – The unique rank of the current process. Defaults to 0.
world_size (int, optional) – The total number of processes participating in the training. Defaults to 1.
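As an illustration of how these fields fit together, the configuration's shape can be mirrored in a plain dataclass (a hypothetical stand-in, not pytagi's actual implementation). A typical single-node, two-GPU setup runs one process per GPU, so each process constructs a config with its own rank and the shared world size:

```python
from dataclasses import dataclass
from typing import List

@dataclass
class DDPConfigSketch:
    """Sketch mirroring the fields of pytagi.nn.ddp.DDPConfig (illustrative only)."""
    device_ids: List[int]
    backend: str = "nccl"
    rank: int = 0
    world_size: int = 1

# Process 0 of 2 on a node with GPUs 0 and 1.
config = DDPConfigSketch(device_ids=[0, 1], rank=0, world_size=2)
```

The sibling process would construct the same object with `rank=1`; `device_ids`, `backend`, and `world_size` must agree across all processes.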
- class pytagi.nn.ddp.DDPSequential(model: pytagi.nn.sequential.Sequential, config: DDPConfig, average: bool = True)[source]#
A wrapper for Sequential models to enable Distributed Data Parallel (DDP) training.
This class handles gradient synchronization and parameter updates across multiple processes, allowing for scalable training on multiple GPUs.
Initializes the DDPSequential wrapper.
- Parameters:
model (Sequential) – The Sequential model to be parallelized.
config (DDPConfig) – The DDP configuration object.
average (bool, optional) – If True, updates are averaged across processes during synchronization; if False, they are summed. Defaults to True.
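The effect of the `average` flag can be sketched with plain NumPy: each process contributes a local update, and synchronization combines them either by mean or by sum. This is a conceptual single-process sketch; pytagi performs the reduction across separate processes via the distributed backend.

```python
import numpy as np

def sync_updates(local_updates, average=True):
    """Combine per-process updates: element-wise mean when average=True, sum otherwise."""
    stacked = np.stack(local_updates)
    return stacked.mean(axis=0) if average else stacked.sum(axis=0)

# Two "processes" each produce a local update vector.
updates = [np.array([1.0, 2.0]), np.array([3.0, 4.0])]
averaged = sync_updates(updates, average=True)   # -> array([2., 3.])
summed = sync_updates(updates, average=False)    # -> array([4., 6.])
```

Averaging keeps the effective update magnitude independent of `world_size`, which is why it is the default.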
- property output_z_buffer: pytagi.nn.data_struct.BaseHiddenStates[source]#
The output hidden states buffer from the forward pass of the underlying model.
- property input_delta_z_buffer: pytagi.nn.data_struct.BaseDeltaStates[source]#
The input delta states buffer for the backward pass of the underlying model.
- __call__(mu_x: numpy.ndarray, var_x: numpy.ndarray = None) → Tuple[numpy.ndarray, numpy.ndarray][source]#
A convenient alias for the forward pass.
- Parameters:
mu_x (np.ndarray) – The mean of the input data for the current process.
var_x (np.ndarray, optional) – The variance of the input data for the current process. Defaults to None.
- Returns:
A tuple containing the mean and variance of the model’s output.
- Return type:
Tuple[np.ndarray, np.ndarray]
- forward(mu_x: numpy.ndarray, var_x: numpy.ndarray = None) → Tuple[numpy.ndarray, numpy.ndarray][source]#
Performs a forward pass on the local model replica.
- Parameters:
mu_x (np.ndarray) – The mean of the input data.
var_x (np.ndarray, optional) – The variance of the input data. Defaults to None.
- Returns:
A tuple containing the mean and variance of the output.
- Return type:
Tuple[np.ndarray, np.ndarray]
- barrier()[source]#
Synchronizes all processes.
Blocks until all processes in the distributed group have reached this point.
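The rendezvous semantics of barrier() can be illustrated with Python's standard threading.Barrier (an analogy only; pytagi synchronizes separate processes through the distributed backend, not threads):

```python
import threading

results = []
barrier = threading.Barrier(2)  # two participants must arrive before any proceeds

def worker(rank):
    results.append(f"rank {rank} before barrier")
    barrier.wait()  # blocks until both workers have reached this point
    results.append(f"rank {rank} after barrier")

threads = [threading.Thread(target=worker, args=(r,)) for r in range(2)]
for t in threads:
    t.start()
for t in threads:
    t.join()
# Every "before" entry is recorded ahead of every "after" entry.
```

In DDP training a barrier is typically placed before steps that assume all replicas have finished the preceding phase, e.g. before saving a checkpoint from rank 0.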