TF Server¶
melissa.server.deep_learning.tf_server.TFServer¶
Bases: DeepMelissaServer
TFServer manages and trains a TensorFlow model in a distributed setting.
This class handles the following tasks:
- Sets up the distributed environment for multi-GPU or CPU training using TensorFlow's MultiWorkerMirroredStrategy.
- Configures the environment based on available GPUs, or falls back to CPU-only training.
- Initializes the distributed training strategy using cluster resolvers for Slurm, OAR, or local clusters.
- Synchronizes data availability to ensure smooth gradient aggregation during distributed training.
strategy property
¶
Returns the initialized tensorflow.distribute.MultiWorkerMirroredStrategy instance.
__configure_visible_devices(device_type='GPU')
¶
Sets the visible device(s) for the current process based on the rank.
This method ensures that each process sees only the GPU corresponding to its local rank, facilitating distributed GPU training.
Parameters¶
- device_type (str, optional): The type of device to configure (default is "GPU").
Returns¶
List: A list of physical devices of the specified type found on the machine.
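The per-rank mapping described above can be sketched as follows. Both helpers (`local_rank`, `visible_device_index`) and the launcher variables they read are hypothetical illustrations, not TFServer's actual code:

```python
import os

def local_rank() -> int:
    # Local rank of this process on its node, read from common launcher
    # environment variables (hypothetical helper; TFServer derives the
    # rank from its own launcher environment).
    for var in ("SLURM_LOCALID", "OMPI_COMM_WORLD_LOCAL_RANK"):
        if var in os.environ:
            return int(os.environ[var])
    return 0

def visible_device_index(num_devices: int) -> int:
    # Pin each process to the device matching its local rank.
    return local_rank() % num_devices

# With TensorFlow installed, the visibility restriction itself would read:
#   gpus = tf.config.list_physical_devices("GPU")
#   tf.config.set_visible_devices(gpus[visible_device_index(len(gpus))], "GPU")
```

After `set_visible_devices`, each process sees exactly one logical GPU, which is what lets MultiWorkerMirroredStrategy treat every worker uniformly.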
__initialize_slurm_cluster()
¶
Initializes the cluster using Slurm environment variables for multi-node training.
It uses the tensorflow.distribute.cluster_resolver.SlurmClusterResolver to set up a distributed cluster environment, configuring GPU usage and networking based on Slurm settings.
Returns¶
tensorflow.distribute.cluster_resolver.SlurmClusterResolver: A cluster resolver configured for the current Slurm setup.
__initialize_oar_cluster()
¶
Initializes the cluster configuration using OAR environment variables.
This method sets up the distributed cluster for environments managed by OAR,
extracting the node list from OAR_NODEFILE and configuring the Slurm-compatible cluster resolver.
Returns¶
tensorflow.distribute.cluster_resolver.SlurmClusterResolver: A cluster resolver configured for OAR-based multi-node training.
__initialize_local_cluster()
¶
Initializes the cluster for a local, single-node distributed training setup.
This method configures the tensorflow.distribute.cluster_resolver.SlurmClusterResolver for local environments without external cluster managers like Slurm or OAR, using the local hostname as the node.
Returns¶
tensorflow.distribute.cluster_resolver.SlurmClusterResolver: A cluster resolver configured for local distributed training.
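For illustration only, a single-node cluster description built from the local hostname might look like the TF_CONFIG-style dict below; note this is an assumption for clarity, since TFServer configures a SlurmClusterResolver directly rather than TF_CONFIG:

```python
import json
import socket

def local_tf_config(port: int = 8888) -> str:
    # Single-worker cluster spec: the local hostname is the only node,
    # and this process is worker 0.
    spec = {
        "cluster": {"worker": [f"{socket.gethostname()}:{port}"]},
        "task": {"type": "worker", "index": 0},
    }
    return json.dumps(spec)
```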
__initialize_strategy(cluster_resolver)
¶
Initializes the distributed training strategy using NCCL for multi-GPU communication.
This method configures the MultiWorkerMirroredStrategy with the given cluster resolver and sets the communication options to use the NCCL backend for efficient GPU-based inter-process communication.
Parameters¶
- cluster_resolver (tensorflow.distribute.cluster_resolver.SlurmClusterResolver): The cluster resolver that defines the distributed cluster configuration.
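A minimal sketch of that pattern, assuming TensorFlow is importable; this mirrors the public MultiWorkerMirroredStrategy API described above, not TFServer's exact internals:

```python
def build_nccl_strategy(cluster_resolver):
    """Create a MultiWorkerMirroredStrategy that uses NCCL collectives."""
    import tensorflow as tf

    # Select NCCL as the collective-communication backend for GPU tensors.
    options = tf.distribute.experimental.CommunicationOptions(
        implementation=tf.distribute.experimental.CommunicationImplementation.NCCL
    )
    return tf.distribute.MultiWorkerMirroredStrategy(
        cluster_resolver=cluster_resolver,
        communication_options=options,
    )
```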
__fallback_to_cpu_strategy()
¶
Provides a fallback strategy for CPU-only distributed training.
This method ensures that distributed training can proceed when no GPUs are available. If multiple CPUs are detected, it raises an error since TensorFlow cannot distribute workloads across multiple CPUs without GPUs.
Raises¶
TFServerDeviceError: If multiple CPUs are detected.
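The guard can be sketched as follows; the exception definition and the helper are illustrative stand-ins for TFServer's internals (only the exception name comes from the docs above):

```python
class TFServerDeviceError(RuntimeError):
    """Raised when the device setup cannot support distributed training."""

def fallback_to_cpu(num_cpu_processes: int) -> str:
    # TensorFlow cannot mirror a multi-worker workload across plain CPUs,
    # so anything beyond a single CPU process is rejected outright.
    if num_cpu_processes > 1:
        raise TFServerDeviceError(
            "Multiple CPUs detected: CPU-only distributed training is unsupported."
        )
    return "cpu"  # single-process training proceeds on the CPU
```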
setup_environment()
¶
Configures the environment for distributed training using GPUs or CPUs.
This method sets up GPU visibility, initializes the appropriate cluster (SLURM, OAR, or local), and configures the distribution strategy for multi-worker training. If no GPUs are available, it falls back to CPU-based distributed training.
Behavior¶
- Detects and configures GPU devices for the current process.
- Initializes the cluster using SLURM, OAR, or a local setup.
- Sets up a MultiWorkerMirroredStrategy using NCCL for GPU communication.
- Falls back to a CPU strategy if no GPUs are found.
Raises¶
- TFServerDeviceError: Raised when no GPUs are available for distributed training.
_load_model_from_checkpoint()
¶
Loads the TensorFlow self.model and self.optimizer attributes from the last checkpoint.
checkpoint(batch_idx=0)
¶
The method called to initiate full tree checkpointing.
Saves self.model and self.optimizer states.
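Saving and restoring model/optimizer state in TensorFlow typically rests on tf.train.Checkpoint; the following is a hedged sketch of that standard pattern, assuming TensorFlow is importable (TFServer's actual checkpoint layout, naming, and batch_idx handling may differ):

```python
def save_and_restore(model, optimizer, ckpt_dir):
    """Save model + optimizer state, then restore the newest checkpoint."""
    import tensorflow as tf

    # Bundle trackable objects so both are captured in one checkpoint.
    ckpt = tf.train.Checkpoint(model=model, optimizer=optimizer)
    manager = tf.train.CheckpointManager(ckpt, ckpt_dir, max_to_keep=3)
    manager.save()                            # write model + optimizer states
    ckpt.restore(manager.latest_checkpoint)   # reload the newest checkpoint
    return manager.latest_checkpoint
```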