Coverage for melissa/server/deep_learning/tf_server.py: 0% (79 statements)
1"""This script extends `DeepMelissaServer` and implements `TFServer`"""
3# pylint: disable=E1101,E0401
5import os
6import logging
7from typing_extensions import override
8from typing import Any, Dict, List
10import tensorflow as tf
11import tensorflow.distribute as tfd
12try:
13 from tensorflow.distribute.cluster_resolver import SlurmClusterResolver # type: ignore
14except ModuleNotFoundError:
15 pass
17from melissa.server.deep_learning.frameworks import FrameworkType
18from melissa.server.deep_learning.base_dl_server import DeepMelissaServer
20logger = logging.getLogger(__name__)


class TFServerDeviceError(Exception):
    """Exception raised for distributed TensorFlow server errors."""


class TFServer(DeepMelissaServer):
    """`TFServer` manages and trains a TensorFlow model in a distributed setting.

    This class handles the following tasks:

    - Sets up the distributed environment for multi-GPU or CPU training using TensorFlow's
      `MultiWorkerMirroredStrategy`.
    - Configures the environment based on available GPUs or falls back to CPU-only training.
    - Initializes the distributed training strategy using cluster resolvers for SLURM, OAR,
      or local clusters.
    - Synchronizes data availability to ensure smooth gradient aggregation during
      distributed training."""

    def __init__(self, config_dict: Dict[str, Any]) -> None:
        super().__init__(config_dict)
        self._framework_t: FrameworkType = FrameworkType.TENSORFLOW
        self.__strategy: tfd.MultiWorkerMirroredStrategy = tfd.MultiWorkerMirroredStrategy()

    @property
    def strategy(self) -> tfd.MultiWorkerMirroredStrategy:
        """Returns the initialized `tensorflow.distribute.MultiWorkerMirroredStrategy` instance."""
        return self.__strategy

    def __configure_visible_devices(self, device_type: str = "GPU") -> List:
        """Sets the visible device(s) for the current process based on the rank.

        This method ensures that each process sees only the GPU corresponding
        to its local rank, facilitating distributed GPU training.

        ### Parameters
        - **device_type** (`str`, optional): The type of device to configure (default is `"GPU"`).

        ### Returns
        - `List`: A list of physical devices of the specified type found on the machine."""

        physical_devices = tf.config.list_physical_devices(device_type)
        if physical_devices:
            local_rank = self.rank % len(physical_devices)
            tf.config.set_visible_devices(physical_devices[local_rank], device_type)

        return physical_devices
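
    # Illustrative mapping (hypothetical values, not from the original source): on a node exposing
    # 4 GPUs, the modulo above assigns rank 0 -> GPU 0, rank 1 -> GPU 1, rank 4 -> GPU 0,
    # rank 5 -> GPU 1, assuming `self.rank` numbers the server processes consecutively per node.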

    def __initialize_slurm_cluster(self) -> "SlurmClusterResolver":
        """Initializes the cluster using Slurm environment variables for multi-node training.

        It uses the `tensorflow.distribute.cluster_resolver.SlurmClusterResolver`
        to set up a distributed cluster environment, configuring GPU usage and
        networking based on Slurm settings.

        ### Returns
        - `tensorflow.distribute.cluster_resolver.SlurmClusterResolver`: A cluster resolver
          configured for the current Slurm setup."""

        return SlurmClusterResolver(
            port_base=12345,
            gpus_per_node=self.comm_size // int(os.environ['SLURM_NNODES']),
            auto_set_gpu=False
        )
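
    # Worked example (hypothetical values): a server launched with comm_size = 8 across
    # SLURM_NNODES = 2 yields gpus_per_node = 8 // 2 = 4; the arithmetic assumes one server rank
    # per GPU and an even split of ranks across the allocated nodes.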

    def __initialize_oar_cluster(self) -> "SlurmClusterResolver":
        """Initializes the cluster configuration using OAR environment variables.

        This method sets up the distributed cluster for environments managed by OAR,
        extracting the node list from the `OAR_NODEFILE` and configuring the
        Slurm-compatible cluster resolver.

        ### Returns
        - `tensorflow.distribute.cluster_resolver.SlurmClusterResolver`:
          A cluster resolver configured for OAR-based multi-node training."""

        os.environ['SLURM_PROCID'] = str(self.rank)
        with open(os.environ['OAR_NODEFILE']) as my_file:
            host_list = list(set(my_file.read().splitlines()))
        return SlurmClusterResolver(
            jobs={'worker': self.comm_size},
            port_base=12345,
            gpus_per_node=self.comm_size // len(host_list),
            gpus_per_task=1,
            tasks_per_node={host: self.comm_size // len(host_list) for host in host_list},
            auto_set_gpu=False
        )
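
    # Worked example (hypothetical values): an OAR node file typically lists one hostname per
    # reserved core, e.g. node1, node1, node2, node2. After deduplication host_list is
    # ["node1", "node2"], so with comm_size = 4 the resolver receives gpus_per_node = 2 and
    # tasks_per_node = {"node1": 2, "node2": 2}.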

    def __initialize_local_cluster(self) -> "SlurmClusterResolver":
        """Initializes the cluster for a local, single-node distributed training setup.

        This method configures the `tensorflow.distribute.cluster_resolver.SlurmClusterResolver`
        for local environments without external cluster managers like Slurm or OAR,
        using the local hostname as the node.

        ### Returns
        - `tensorflow.distribute.cluster_resolver.SlurmClusterResolver`:
          A cluster resolver configured for local distributed training."""

        os.environ['SLURM_PROCID'] = str(self.rank)
        return SlurmClusterResolver(
            jobs={'worker': self.comm_size},
            port_base=12345,
            gpus_per_node=None,
            gpus_per_task=1,
            tasks_per_node={os.uname()[1]: self.comm_size},
            auto_set_gpu=False
        )

    def __initialize_strategy(self, cluster_resolver: "SlurmClusterResolver") -> None:
        """Initializes the distributed training strategy using NCCL for multi-GPU communication.

        This method configures the `MultiWorkerMirroredStrategy` with the given cluster resolver
        and sets the communication options to use the NCCL backend for efficient GPU-based
        inter-process communication.

        ### Parameters
        - **cluster_resolver** (`tensorflow.distribute.cluster_resolver.SlurmClusterResolver`):
          The cluster resolver that defines the distributed cluster configuration."""

        communication_options = tfd.experimental.CommunicationOptions(
            implementation=tfd.experimental.CommunicationImplementation.NCCL
        )
        self.__strategy = tfd.MultiWorkerMirroredStrategy(
            cluster_resolver=cluster_resolver,
            communication_options=communication_options
        )
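
    # Usage sketch (illustrative, not taken from this codebase): model and optimizer variables
    # should be created under the strategy scope so they are mirrored across workers, e.g.
    #
    #     with self.strategy.scope():
    #         model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
    #         optimizer = tf.keras.optimizers.Adam()
    #
    # Gradient aggregation during training then uses the NCCL all-reduce configured above.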

    def __fallback_to_cpu_strategy(self) -> None:
        """Provides a fallback strategy for CPU-only distributed training.

        This method ensures that distributed training can proceed when no GPUs are available.
        If multiple CPUs are detected, it raises an error since TensorFlow cannot distribute
        workloads across multiple CPUs without GPUs.

        ### Raises
        - `TFServerDeviceError`: If multiple CPUs are detected."""

        if len(tf.config.list_physical_devices('CPU')) > 1:
            raise TFServerDeviceError(
                "TensorFlow cannot be distributed on multiple non-GPU devices."
            )
        if len(tf.config.list_physical_devices('CPU')) == 1:
            logger.info("Default `MultiWorkerMirroredStrategy` with CPU.")
            self.__strategy = tfd.MultiWorkerMirroredStrategy()
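
    # Note (assumption): TensorFlow normally exposes a single logical "CPU" physical device
    # regardless of core count, so the multi-CPU guard above is defensive. If no CPU device is
    # listed at all, the default strategy created in `__init__` remains in place.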

    @override
    def setup_environment(self) -> None:
        """Configures the environment for distributed training using GPUs or CPUs.

        This method sets up GPU visibility, initializes the appropriate cluster
        (SLURM, OAR, or local), and configures the distribution strategy for
        multi-worker training. If no GPUs are available, it falls back to CPU-based
        distributed training.

        ### Behavior
        - Detects and configures GPU devices for the current process.
        - Initializes the cluster using SLURM, OAR, or a local setup.
        - Sets up a `MultiWorkerMirroredStrategy` using NCCL for GPU communication.
        - Falls back to the CPU strategy if no GPUs are found.

        ### Raises
        - **TFServerDeviceError**: Raised internally when no GPUs are available; the exception
          is caught and handled by falling back to the CPU strategy."""

        try:
            physical_devices = self.__configure_visible_devices('GPU')

            if not physical_devices:
                raise TFServerDeviceError("No GPU found")

            if 'SLURM_NODELIST' in os.environ:
                logger.info("Slurm cluster initialization")
                cluster_resolver = self.__initialize_slurm_cluster()
            elif 'OAR_NODEFILE' in os.environ:
                logger.info("OAR cluster initialization")
                cluster_resolver = self.__initialize_oar_cluster()
            else:
                logger.info("Local cluster initialization")
                cluster_resolver = self.__initialize_local_cluster()

            logger.info(f"Rank {self.rank}>> physical-devices={physical_devices}")
            logger.info(
                f"Rank {self.rank}>> visible-devices={tf.config.get_visible_devices('GPU')}"
            )

            self.__initialize_strategy(cluster_resolver)
        except TFServerDeviceError as e:
            logger.info(f"SLURM, OAR, and local cluster initialization failed with exception: {e}")
            self.__fallback_to_cpu_strategy()
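
    # Note: the `TFServerDeviceError` raised above (e.g. "No GPU found") is caught locally rather
    # than propagated, so GPU-less runs continue with the CPU fallback strategy.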

    @override
    def _load_model_from_checkpoint(self) -> None:
        """Loads the TensorFlow `self.model` and `self.optimizer` attributes
        from the last checkpoint."""

        step = tf.Variable(0, trainable=False)
        with self.__strategy.scope():
            restore = tf.train.Checkpoint(
                step=step,
                optimizer=self.optimizer,
                model=self.model
            )
            restore.read(self.ckpt_model_path)

        self.batch_offset = step.numpy()
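
    # Note: `restore.read` expects the file written by `checkpoint()` below via `ckpt.write`; the
    # restored `step` becomes `self.batch_offset` so training can resume from the saved batch index.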

    @override
    def checkpoint(self, batch_idx: int = 0) -> None:
        """The method called to initiate full tree checkpointing.
        Saves `self.model` and `self.optimizer` states."""

        if self.rank == 0:
            # tensorflow checkpoint
            ckpt = tf.train.Checkpoint(
                step=tf.Variable(batch_idx, trainable=False),
                optimizer=self.optimizer,
                model=self.model
            )

            ckpt.write(self.ckpt_model_path)
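

# Usage sketch (illustrative; the configuration keys expected by `DeepMelissaServer` are not shown
# here and must match the Melissa launcher's config file):
#
#     server = TFServer(config_dict)
#     server.setup_environment()          # picks SLURM/OAR/local cluster, or the CPU fallback
#     with server.strategy.scope():
#         ...                             # build the model and optimizer, then train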