Torch Server

melissa.server.deep_learning.torch_server.TorchServer

Bases: DeepMelissaServer

TorchServer for managing and training a PyTorch model in a distributed setting. It performs the following tasks:

  • Sets up the distributed environment for multi-GPU or CPU training using PyTorch's DDP.
  • Wraps the model in torch.nn.parallel.DistributedDataParallel (DDP) for distributed training.
  • Manages model and optimizer state, including loading and saving checkpoints.
  • Synchronizes data availability across processes to avoid deadlocks during distributed training.

unwrapped_model property

Returns the torch.nn.Module object of the original model supplied by the user, before DDP wrapping. The server wraps the given model in torch.nn.parallel.DistributedDataParallel, which exposes the original model through its module attribute; this property therefore returns self.model.module when the model is wrapped, and self.model otherwise. Useful for checkpointing custom state and for calling user-defined methods of the original model.

__initialize_distributed_backend()

Initializes the distributed backend for both single-node and multi-node server configurations. This method sets up the device and backend for distributed training using either CPU or GPU, depending on availability.

It first obtains the local rank, global rank, and world size from the Open MPI environment. Based on the availability of GPUs, it configures the device and selects the appropriate backend ('gloo' for CPU and 'nccl' for GPU).

Finally, it initializes the process group for distributed training and creates a new group for active ranks that participate in training.

Raises
  • RuntimeError: If no GPUs are available or the SLURM environment is not set up correctly for Distributed Data Parallel (DDP).
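The selection logic described above can be sketched roughly as below. This is an illustration of the documented behavior, not Melissa's source: the function name is invented, and it assumes the standard Open MPI environment variables (`OMPI_COMM_WORLD_LOCAL_RANK`, `OMPI_COMM_WORLD_RANK`, `OMPI_COMM_WORLD_SIZE`) plus `MASTER_ADDR`/`MASTER_PORT` for rendezvous.

```python
import os

import torch
import torch.distributed as dist

def initialize_distributed_backend():
    """Sketch: pick device and backend, then init the process group."""
    # Ranks and world size come from the Open MPI environment.
    local_rank = int(os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK", 0))
    rank = int(os.environ.get("OMPI_COMM_WORLD_RANK", 0))
    world_size = int(os.environ.get("OMPI_COMM_WORLD_SIZE", 1))

    if torch.cuda.is_available():
        backend = "nccl"  # GPU collectives
        device = torch.device(f"cuda:{local_rank}")
        torch.cuda.set_device(device)
    else:
        backend = "gloo"  # CPU fallback
        device = torch.device("cpu")

    dist.init_process_group(backend=backend, rank=rank, world_size=world_size)
    return device, rank, world_size
```

The real method additionally creates a separate group for the active ranks that participate in training.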

setup_environment()

Configures the environment for distributed GPU or CPU training using PyTorch's torch.distributed package.

wrap_model_ddp(model)

Wraps the model in DistributedDataParallel (DDP) for multi-GPU training.

Depending on the setup (SLURM or local CUDA), this method wraps the model in DDP using the appropriate device(s).

Parameters
  • model (Union[torch.nn.Module, DDP]): Instantiated torch model.
Returns
  • torch.nn.parallel.DistributedDataParallel: The model wrapped in DDP for distributed training.
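A hedged sketch of such a wrapper is shown below; it assumes the process group has already been initialized (e.g. by setup_environment) and that the caller passes the local rank, which the actual method may obtain differently.

```python
import torch
from torch.nn.parallel import DistributedDataParallel as DDP

def wrap_model_ddp(model: torch.nn.Module, local_rank: int) -> DDP:
    """Sketch: wrap a model in DDP on the appropriate device."""
    if torch.cuda.is_available():
        # Place the model on this process's GPU before wrapping.
        model = model.to(f"cuda:{local_rank}")
        return DDP(model, device_ids=[local_rank])
    # CPU / gloo case: DDP takes no device_ids.
    return DDP(model)
```

If the model is already a DDP instance, the real method can return it unchanged rather than wrapping twice.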

checkpoint(batch_idx=0)

The method called to initiate full-tree checkpointing. Saves the states of self.model and self.optimizer.
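The state saved by such a checkpoint can be sketched as below. The function signature and dictionary keys are illustrative, not Melissa's actual format; note the unwrapping so the checkpoint can be reloaded without DDP.

```python
import torch

def checkpoint(model, optimizer, batch_idx: int = 0, path: str = "checkpoint.pt"):
    """Sketch: save model and optimizer state plus the batch index."""
    # Unwrap DDP so the saved weights load into a plain nn.Module.
    raw = model.module if hasattr(model, "module") else model
    state = {
        "batch_idx": batch_idx,
        "model_state": raw.state_dict(),
        "optimizer_state": optimizer.state_dict(),
    }
    torch.save(state, path)
```

Restoring would load the dict with torch.load and feed the two state dicts to load_state_dict on the unwrapped model and the optimizer.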