TF Server

melissa.server.deep_learning.tf_server.TFServer

Bases: DeepMelissaServer

TFServer manages and trains a TensorFlow model in a distributed setting.

This class handles the following tasks:

  • Sets up the distributed environment for multi-GPU or CPU training using TensorFlow's MultiWorkerMirroredStrategy.
  • Configures the environment based on available GPUs or falls back to CPU-only training.
  • Initializes the distributed training strategy using cluster resolvers for SLURM, OAR, or local clusters.
  • Synchronizes data availability to ensure smooth gradient aggregation during distributed training.

strategy property

Returns the initialized tensorflow.distribute.MultiWorkerMirroredStrategy instance.

__configure_visible_devices(device_type='GPU')

Sets the visible device(s) for the current process based on the rank.

This method ensures that each process sees only the GPU corresponding to its local rank, facilitating distributed GPU training.

Parameters
  • device_type (str, optional): The type of device to configure (default is "GPU").
Returns
  • List: A list of physical devices of the specified type found on the machine.
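The rank-to-device mapping can be sketched as follows. `pick_device_for_rank` and `configure_visible_devices` are hypothetical helper names, and reading the local rank from `SLURM_LOCALID` is an assumption; the `tf.config` calls mirror the public TensorFlow API:

```python
import os


def pick_device_for_rank(devices, local_rank):
    """Hypothetical helper: map a process's local rank to one physical device.

    With one process launched per GPU, process i on a node sees only device i.
    """
    if not devices:
        raise RuntimeError("no devices of the requested type found")
    return devices[local_rank % len(devices)]


def configure_visible_devices(device_type="GPU"):
    """Sketch of rank-based visibility (assumes TensorFlow is installed)."""
    import tensorflow as tf  # lazy import so the sketch loads without TF

    local_rank = int(os.environ.get("SLURM_LOCALID", 0))  # assumed env var
    devices = tf.config.list_physical_devices(device_type)
    tf.config.set_visible_devices(
        pick_device_for_rank(devices, local_rank), device_type
    )
    return devices
```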

__initialize_slurm_cluster()

Initializes the cluster using Slurm environment variables for multi-node training.

It uses the tensorflow.distribute.cluster_resolver.SlurmClusterResolver to set up a distributed cluster environment, configuring GPU usage and networking based on Slurm settings.

Returns
  • tensorflow.distribute.cluster_resolver.SlurmClusterResolver: A cluster resolver configured for the current Slurm setup.
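As a sketch, the resolver can be constructed as below; the function name and argument values are illustrative, not the server's actual defaults. `SlurmClusterResolver` reads Slurm variables such as `SLURM_PROCID` and `SLURM_NTASKS` at resolution time:

```python
def build_slurm_resolver(gpus_per_node=1, port_base=8888):
    """Sketch: resolve the cluster layout from Slurm environment variables.

    Assumes TensorFlow is installed; argument values are illustrative.
    """
    import tensorflow as tf  # lazy import so the sketch loads without TF

    return tf.distribute.cluster_resolver.SlurmClusterResolver(
        gpus_per_node=gpus_per_node,
        gpus_per_task=1,
        port_base=port_base,
        auto_set_gpu=False,  # device visibility is handled per rank separately
    )
```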

__initialize_oar_cluster()

Initializes the cluster configuration using OAR environment variables.

This method sets up the distributed cluster for environments managed by OAR, extracting the node list from the OAR_NODEFILE and configuring the Slurm-compatible cluster resolver.

Returns
  • tensorflow.distribute.cluster_resolver.SlurmClusterResolver: A cluster resolver configured for OAR-based multi-node training.
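OAR writes one line per allocated core to `OAR_NODEFILE`, so hostnames repeat and must be deduplicated before building the cluster. A minimal sketch (the helper name is hypothetical):

```python
def read_oar_nodes(nodefile_path):
    """Hypothetical helper: collect unique hostnames from an OAR nodefile.

    The file holds one hostname per allocated core, so hosts repeat;
    first-seen order is preserved.
    """
    seen = []
    with open(nodefile_path) as f:
        for line in f:
            host = line.strip()
            if host and host not in seen:
                seen.append(host)
    return seen
```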

__initialize_local_cluster()

Initializes the cluster for a local, single-node distributed training setup.

This method configures the tensorflow.distribute.cluster_resolver.SlurmClusterResolver for local environments without external cluster managers like Slurm or OAR, using the local hostname as the node.

Returns
  • tensorflow.distribute.cluster_resolver.SlurmClusterResolver: A cluster resolver configured for local distributed training.
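The single-node case reduces to a cluster spec whose only host is the local machine. A pure-Python sketch of that spec (function name and port base are hypothetical):

```python
import socket


def local_cluster_spec(num_workers=1, port_base=8888):
    """Hypothetical sketch: one-node cluster spec using the local hostname.

    Each worker gets a distinct port on the same host.
    """
    host = socket.gethostname()
    return {"worker": [f"{host}:{port_base + i}" for i in range(num_workers)]}
```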

__initialize_strategy(cluster_resolver)

Initializes the distributed training strategy using NCCL for multi-GPU communication.

This method configures the MultiWorkerMirroredStrategy with the given cluster resolver and sets the communication options to use the NCCL backend for efficient GPU-based inter-process communication.

Parameters
  • cluster_resolver (tensorflow.distribute.cluster_resolver.SlurmClusterResolver): The cluster resolver that defines the distributed cluster configuration.
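Pinning the strategy to NCCL can be sketched with the public `tf.distribute` API; the function name is hypothetical and assumes TensorFlow is installed:

```python
def build_nccl_strategy(cluster_resolver):
    """Sketch: a MultiWorkerMirroredStrategy forced onto the NCCL backend.

    NCCL handles the GPU-to-GPU collective communication between workers.
    """
    import tensorflow as tf  # lazy import so the sketch loads without TF

    options = tf.distribute.experimental.CommunicationOptions(
        implementation=tf.distribute.experimental.CommunicationImplementation.NCCL
    )
    return tf.distribute.MultiWorkerMirroredStrategy(
        cluster_resolver=cluster_resolver,
        communication_options=options,
    )
```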

__fallback_to_cpu_strategy()

Provides a fallback strategy for CPU-only distributed training.

This method ensures that distributed training can proceed when no GPUs are available. If multiple CPUs are detected, it raises an error since TensorFlow cannot distribute workloads across multiple CPUs without GPUs.

Raises
  • TFServerDeviceError: If multiple CPUs are detected.
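The fallback decision can be sketched as a small helper; the name and return values are hypothetical, and the real server raises TFServerDeviceError rather than RuntimeError:

```python
def choose_strategy_kind(num_gpus, num_cpus):
    """Hypothetical sketch of the device-fallback decision.

    Multi-CPU distribution without GPUs is rejected, mirroring the
    documented behavior.
    """
    if num_gpus > 0:
        return "multi_worker_nccl"
    if num_cpus > 1:
        raise RuntimeError("cannot distribute across multiple CPUs without GPUs")
    return "cpu_default"
```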

setup_environment()

Configures the environment for distributed training using GPUs or CPUs.

This method sets up GPU visibility, initializes the appropriate cluster (SLURM, OAR, or local), and configures the distribution strategy for multi-worker training. If no GPUs are available, it falls back to CPU-based distributed training.

Behavior
  • Detects and configures GPU devices for the current process.
  • Initializes the cluster using SLURM, OAR, or a local setup.
  • Sets up a MultiWorkerMirroredStrategy using NCCL for GPU communication.
  • Falls back to CPU strategy if no GPUs are found.
Raises
  • TFServerDeviceError: Raised when no GPUs are available for distributed training.

_load_model_from_checkpoint()

Loads the TensorFlow self.model and self.optimizer attributes from the last checkpoint.

checkpoint(batch_idx=0)

Initiates full tree checkpointing, saving the self.model and self.optimizer states.
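Saving and restoring model and optimizer state can be sketched with `tf.train.Checkpoint` and `tf.train.CheckpointManager`; the function name and defaults below are illustrative, not the server's internals:

```python
def save_checkpoint(model, optimizer, directory, max_to_keep=3):
    """Sketch: persist model/optimizer state via tf.train.CheckpointManager.

    Assumes TensorFlow; restore later with
    checkpoint.restore(manager.latest_checkpoint).
    """
    import tensorflow as tf  # lazy import so the sketch loads without TF

    checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
    manager = tf.train.CheckpointManager(
        checkpoint, directory, max_to_keep=max_to_keep
    )
    return manager.save()  # returns the path of the saved checkpoint
```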

Exceptions

melissa.server.deep_learning.tf_server.TFServerDeviceError

Bases: Exception

Exception raised for distributed TensorFlow server errors.