Torch Server

melissa.server.deep_learning.torch_server.TorchServer

Bases: DeepMelissaServer

TorchServer for managing and training a PyTorch model in a distributed setting. It performs the following tasks:

  • Sets up the distributed environment for multi-GPU or CPU training using PyTorch's DDP.
  • Wraps the model in torch.nn.parallel.DistributedDataParallel (DDP) for distributed training.
  • Manages model and optimizer state, including loading and saving checkpoints.
  • Synchronizes data availability across processes to avoid deadlocks during distributed training.

unwrapped_model property

Returns the original torch.nn.Module given by the user, before DDP wrapping. Useful for checkpointing custom state and for calling user-defined methods on the model. Because the server wraps the given model in torch.nn.parallel.DistributedDataParallel, which stores the original model in its module attribute, this property returns self.model.module when the model is wrapped and self.model otherwise.
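The unwrap logic amounts to a plain attribute access. A minimal sketch, where DummyModel and DummyDDP are hypothetical stand-ins for the user's module and the DDP wrapper (which stores the original model in its module attribute):

```python
class DummyModel:
    """Stands in for the user's torch.nn.Module."""
    def custom_method(self):
        return "user-defined"

class DummyDDP:
    """Stands in for torch.nn.parallel.DistributedDataParallel,
    which keeps the original model in its `module` attribute."""
    def __init__(self, module):
        self.module = module

def unwrapped_model(model):
    # Return the inner module when the model is DDP-wrapped,
    # otherwise return the model itself.
    return getattr(model, "module", model)

original = DummyModel()
wrapped = DummyDDP(original)

assert unwrapped_model(wrapped) is original
assert unwrapped_model(original) is original
```

In the real server the same effect is achieved by returning self.model.module when the model has been DDP-wrapped.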

setup_environment()

Configures the environment for distributed GPU or CPU training using PyTorch's torch.distributed package.

This method sets up the master address and port for inter-process communication, determines the appropriate device (GPU or CPU), selects the communication backend, and initializes the distributed process group. It also initializes a TensorBoard logger for tracking training metrics.

Behavior
  • If multiple GPUs are available and match the communication size (self.comm_size), the nccl backend is used with GPU devices.
  • If GPUs are unavailable or insufficient, the gloo backend is used with CPU devices.
Raises
  • RuntimeError: If torch.distributed.init_process_group fails during initialization.
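The documented backend selection can be sketched as a pure function (illustrative only; the function name and arguments are assumptions, not the server's API):

```python
def select_backend_and_device(num_gpus: int, comm_size: int, rank: int):
    # nccl with a per-process GPU when the available GPUs match the
    # communicator size; otherwise fall back to gloo on CPU.
    if num_gpus > 0 and num_gpus == comm_size:
        return "nccl", f"cuda:{rank}"
    return "gloo", "cpu"

# The chosen backend would then be passed to, e.g.:
# torch.distributed.init_process_group(backend, rank=rank, world_size=comm_size)
print(select_backend_and_device(4, 4, 1))  # ('nccl', 'cuda:1')
print(select_backend_and_device(2, 4, 1))  # ('gloo', 'cpu')
```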

_setup_environment_slurm()

Configures the multi-node distributed data parallel (DDP) environment using SLURM for GPU-based training (following the Jean Zay recommendations).

It retrieves information from the SLURM environment and sets up the local rank, world size, and device for each process. It also initializes the distributed process group with the appropriate backend.

Behavior
  • If GPUs are available, sets the local GPU device and configures the DDP environment using the nccl backend.
  • If no GPUs are found, raises a RuntimeError with a descriptive error message.
Raises
  • RuntimeError: If no GPUs are available or the SLURM environment is not set up correctly for DDP.
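The SLURM-derived configuration can be sketched with the standard SLURM_PROCID, SLURM_NTASKS, and SLURM_LOCALID variables (a sketch of the usual recipe; the server's exact variable handling may differ):

```python
import os

def slurm_ddp_config(env=os.environ):
    """Derive rank, world size, and local rank from SLURM variables,
    as commonly recommended for Jean Zay. Illustrative sketch only."""
    rank = int(env["SLURM_PROCID"])        # global rank of this process
    world_size = int(env["SLURM_NTASKS"])  # total number of processes
    local_rank = int(env["SLURM_LOCALID"]) # rank within the node -> GPU index
    # MASTER_ADDR is typically derived from the first host in SLURM_NODELIST.
    return {"rank": rank, "world_size": world_size, "local_rank": local_rank}

cfg = slurm_ddp_config({"SLURM_PROCID": "3", "SLURM_NTASKS": "8", "SLURM_LOCALID": "1"})
print(cfg)  # {'rank': 3, 'world_size': 8, 'local_rank': 1}
```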

wrap_model_ddp(model)

Wraps the model in DistributedDataParallel (DDP) for multi-GPU training.

Depending on the setup (SLURM or local CUDA), this method wraps the model in DDP using the appropriate device(s).

Parameters
  • model (torch.nn.Module): Instantiated torch model.
Returns
  • torch.nn.parallel.DistributedDataParallel: The model wrapped in DDP for distributed training.
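The device-dependent wrapping can be sketched as follows (ddp_kwargs is a hypothetical helper; device_ids and output_device are real DistributedDataParallel parameters):

```python
def ddp_kwargs(use_gpu: bool, local_rank: int):
    """Build keyword arguments for DistributedDataParallel depending on
    the device setup. On GPU, DDP is pinned to the process's local device;
    on CPU, no device ids are passed. Illustrative sketch only."""
    if use_gpu:
        return {"device_ids": [local_rank], "output_device": local_rank}
    return {}

print(ddp_kwargs(True, 2))   # {'device_ids': [2], 'output_device': 2}
print(ddp_kwargs(False, 0))  # {}
```

The resulting kwargs would then be passed as torch.nn.parallel.DistributedDataParallel(model, **kwargs).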

_server_finalize(exit_=0)

Finalizes the server operations by calling torch.distributed.destroy_process_group.

Parameters
  • exit_ (int, optional): The exit status code indicating the outcome of the server's operations. Defaults to 0, which signifies a successful termination.

_synchronize_data_availability()

Coordinates the processes so that every rank has data available to process. This avoids deadlocks in torch.distributed.all_reduce during gradient synchronization.

Returns
  • bool: if data is available on all processes.

_load_model_from_checkpoint()

Loads the torch self.model and self.optimizer attributes from the last checkpoint.
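The availability check performed by _synchronize_data_availability behaves like an all-reduce with a MIN operation over per-rank flags: every rank agrees to proceed only if every rank has data. A pure-Python stand-in (in the real server, torch.distributed.all_reduce performs the reduction across processes):

```python
def all_ranks_have_data(local_flags):
    """Emulate all_reduce(MIN) over per-rank availability flags (1 = has
    data, 0 = starved). All ranks see the same result, so either everyone
    enters the training step or no one does, avoiding a deadlock."""
    return min(local_flags) == 1

print(all_ranks_have_data([1, 1, 1]))  # True
print(all_ranks_have_data([1, 0, 1]))  # False
```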

checkpoint(batch_idx=0, path='checkpoints')

The method called to initiate full tree checkpointing. Saves the self.model and self.optimizer states.

Parameters
  • batch_idx (int, optional): Index of the current batch. Defaults to 0.
  • path (str, optional): Directory where checkpoints are saved. Defaults to 'checkpoints'.
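The save/load round trip can be sketched without torch, using a hypothetical Stateful class that mimics the state_dict/load_state_dict protocol of torch modules and optimizers (the real server would use torch.save and torch.load):

```python
import io
import pickle

class Stateful:
    """Minimal stand-in for objects exposing state_dict/load_state_dict,
    like torch modules and optimizers. Hypothetical, for illustration."""
    def __init__(self, **state):
        self._state = dict(state)
    def state_dict(self):
        return dict(self._state)
    def load_state_dict(self, state):
        self._state = dict(state)

def save_checkpoint(model, optimizer, batch_idx, buf):
    # Bundle model and optimizer states plus the batch index, matching
    # what the documented checkpoint covers.
    pickle.dump({"model": model.state_dict(),
                 "optimizer": optimizer.state_dict(),
                 "batch_idx": batch_idx}, buf)

def load_checkpoint(model, optimizer, buf):
    ckpt = pickle.load(buf)
    model.load_state_dict(ckpt["model"])
    optimizer.load_state_dict(ckpt["optimizer"])
    return ckpt["batch_idx"]

model, opt = Stateful(w=1.0), Stateful(lr=0.01)
buf = io.BytesIO()
save_checkpoint(model, opt, batch_idx=7, buf=buf)
buf.seek(0)

fresh_model, fresh_opt = Stateful(w=0.0), Stateful(lr=0.0)
restored_idx = load_checkpoint(fresh_model, fresh_opt, buf)
print(restored_idx)                 # 7
print(fresh_model.state_dict())     # {'w': 1.0}
```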