Torch Server¶
melissa.server.deep_learning.torch_server.TorchServer¶
Bases: DeepMelissaServer
TorchServer manages and trains a PyTorch model in a distributed setting.
It performs the following tasks:
- Sets up the distributed environment for multi-GPU or CPU training using PyTorch's DDP.
- Wraps the model in torch.nn.parallel.DistributedDataParallel (DDP) for distributed training.
- Manages model and optimizer state, including loading and saving checkpoints.
- Synchronizes data availability across processes to avoid deadlocks during distributed training.
unwrapped_model
property
¶
Returns the torch.nn.Module object of the original model given by the user, before DDP wrapping. Useful for checkpointing custom state as well as for calling user-defined methods belonging to the model. Since the server wraps the given model with torch.nn.parallel.DistributedDataParallel, which sets a module attribute of its own, this property returns self.model.module when the model is wrapped, or self.model otherwise.
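As a hedged illustration (plain Python stand-ins, not the Melissa classes), the property amounts to a getattr fallback, since DDP exposes the original model through its module attribute:

```python
class DummyModel:
    """Stand-in for a user-defined torch.nn.Module."""
    def custom_method(self):
        return "user logic"

class DummyDDP:
    """Stand-in for torch.nn.parallel.DistributedDataParallel,
    which stores the original model in its `module` attribute."""
    def __init__(self, module):
        self.module = module

def unwrapped_model(model):
    # Return the inner module if wrapped, else the model itself.
    return getattr(model, "module", model)

inner = DummyModel()
wrapped = DummyDDP(inner)
assert unwrapped_model(wrapped) is inner  # DDP-wrapped case
assert unwrapped_model(inner) is inner    # plain, unwrapped case
```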
setup_environment()
¶
Configures the environment for distributed GPU or CPU training using PyTorch's torch.distributed package.
This method sets up the master address and port for inter-process communication, determines the appropriate device (GPU or CPU), selects the communication backend, and initializes the distributed process group. It also initializes a TensorBoard logger for tracking training metrics.
Behavior¶
- If multiple GPUs are available and their count matches the communication size (self.comm_size), the nccl backend is used with GPU devices.
- If GPUs are unavailable or insufficient, the gloo backend is used with CPU devices.
Raises¶
RuntimeError: If torch.distributed.init_process_group fails during initialization.
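The backend choice described above can be sketched as a small helper (a hedged illustration; select_backend and its parameters are hypothetical names, not part of the Melissa API):

```python
def select_backend(num_gpus: int, comm_size: int) -> str:
    """Pick the torch.distributed backend following the rule above:
    nccl when every rank can be pinned to its own GPU, gloo otherwise."""
    if num_gpus > 0 and num_gpus >= comm_size:
        return "nccl"
    return "gloo"

# Four ranks, four GPUs -> GPU training with nccl.
assert select_backend(num_gpus=4, comm_size=4) == "nccl"
# Four ranks, two GPUs -> fall back to CPU training with gloo.
assert select_backend(num_gpus=2, comm_size=4) == "gloo"
```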
_setup_environment_slurm()
¶
Configures the multi-node distributed data parallel (DDP) environment using SLURM for GPU-based training (following the Jean Zay supercomputer recommendations).
It retrieves information from the SLURM environment and sets up the local rank, world size, and device for each process. It also initializes the distributed process group with the appropriate backend.
Behavior¶
- If GPUs are available, sets the local GPU device and configures the DDP environment using the nccl backend.
- If no GPUs are found, raises a RuntimeError with a descriptive error message.
Raises¶
- RuntimeError: If no GPUs are available or the SLURM environment is not set up correctly for DDP.
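The rank information retrieved from SLURM can be sketched as follows (a hedged illustration: the helper name is hypothetical, while SLURM_PROCID, SLURM_NTASKS, and SLURM_LOCALID are standard SLURM environment variables):

```python
def slurm_ddp_ranks(env: dict) -> tuple:
    """Derive (global_rank, world_size, local_rank) from a SLURM
    environment mapping, as a multi-node DDP setup typically does."""
    global_rank = int(env["SLURM_PROCID"])  # rank across all nodes
    world_size = int(env["SLURM_NTASKS"])   # total number of processes
    local_rank = int(env["SLURM_LOCALID"])  # rank within the node -> local GPU index
    return global_rank, world_size, local_rank

# Example: 2 nodes x 4 tasks; this process is task 1 on the second node.
env = {"SLURM_PROCID": "5", "SLURM_NTASKS": "8", "SLURM_LOCALID": "1"}
assert slurm_ddp_ranks(env) == (5, 8, 1)
```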
wrap_model_ddp(model)
¶
Wraps the model in DistributedDataParallel (DDP) for multi-GPU training.
Depending on the setup (SLURM or local CUDA), this method wraps the model in DDP using the appropriate device(s).
Parameters¶
- model (torch.nn.Module): Instantiated torch model.
Returns¶
torch.nn.parallel.DistributedDataParallel: The model wrapped in DDP for distributed training.
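A hedged sketch of the device placement involved (the helper is hypothetical; the device_ids convention is PyTorch's: a one-element list pins the DDP wrapper to one GPU, None means CPU):

```python
def ddp_device_ids(cuda_available: bool, local_rank: int):
    """Choose the device_ids argument for DistributedDataParallel:
    pin each process to its local GPU, or pass None for CPU training."""
    if cuda_available:
        return [local_rank]  # e.g. DDP(model, device_ids=[local_rank])
    return None              # e.g. DDP(model) on CPU with the gloo backend

assert ddp_device_ids(True, 2) == [2]
assert ddp_device_ids(False, 0) is None
```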
_server_finalize(exit_=0)
¶
Finalizes the server operations by calling torch.distributed.destroy_process_group.
Parameters¶
- exit_ (int, optional): The exit status code indicating the outcome of the server's operations. Defaults to 0, which signifies a successful termination.
_synchronize_data_availability()
¶
Coordinates the ranks so that every process agrees data is available before a training step proceeds. This avoids deadlocks in the torch.distributed.all_reduce of the gradients.
Returns¶
bool: True only if data is available on all processes.
_load_model_from_checkpoint()
¶
Loads the torch self.model and self.optimizer attributes from the last checkpoint.
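The data-availability synchronization can be sketched as a min-reduction over per-rank flags (a hedged stand-in for torch.distributed.all_reduce with a MIN reduction; the names are illustrative):

```python
def all_ranks_have_data(flags):
    """Each rank reports 1 if its dataset has a batch ready, else 0.
    An all_reduce with MIN leaves every rank holding the global minimum,
    so training proceeds only when the result is 1 on every process."""
    return min(flags) == 1

assert all_ranks_have_data([1, 1, 1, 1]) is True   # every rank can step
assert all_ranks_have_data([1, 0, 1, 1]) is False  # one starved rank blocks all
```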
checkpoint(batch_idx=0, path='checkpoints')
¶
The method called to initiate full tree checkpointing.
Saves the self.model and self.optimizer states.
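A hedged sketch of the checkpoint round trip (stdlib pickle standing in for torch.save/torch.load; the state dicts and helper names are illustrative, not the Melissa implementation):

```python
import io
import pickle

def save_checkpoint(buf, model_state, optimizer_state, batch_idx=0):
    # torch.save-style: bundle model/optimizer states plus progress info.
    pickle.dump(
        {"model": model_state, "optimizer": optimizer_state, "batch_idx": batch_idx},
        buf,
    )

def load_checkpoint(buf):
    # torch.load-style: restore the bundled states from the last checkpoint.
    buf.seek(0)
    return pickle.load(buf)

buf = io.BytesIO()
save_checkpoint(buf, {"w": [0.1, 0.2]}, {"lr": 1e-3}, batch_idx=7)
state = load_checkpoint(buf)
assert state["batch_idx"] == 7
assert state["optimizer"]["lr"] == 1e-3
```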