Coverage for melissa/server/deep_learning/torch_server.py: 37%
141 statements
1"""This script extends `DeepMelissaServer` and implements `TorchServer`, as well as
2`ExperimentalDeepMelissaActiveSamplingServer`."""
4import os
5import logging
6from datetime import timedelta
7from typing import Any, Dict, Union, Optional, List
8from typing_extensions import override
10import torch
11from mpi4py import MPI
12import torch.distributed as dist
13from torch.nn.parallel import DistributedDataParallel as DDP
14import torch.utils.data
15import cloudpickle
17from melissa.server.deep_learning.frameworks import FrameworkType
18from melissa.server.deep_learning.base_dl_server import DeepMelissaServer
19from melissa.utility.networking import is_port_in_use
20from melissa.utility.rank_helper import ClusterEnvironment
21from melissa.server.exceptions import FatalError, TrainingError
24logger = logging.getLogger(__name__)
27class TorchServer(DeepMelissaServer):
28 """`TorchServer` for managing and training a PyTorch model in a distributed setting.
29 It performs the following tasks:
31 - Sets up the distributed environment for multi-GPU or CPU training using PyTorch's DDP.
32 - Wraps the model in `torch.nn.parallel.DistributedDataParallel` (DDP) for distributed training.
33 - Manages model and optimizer state, including loading and saving checkpoints.
34 - Synchronizes data availability across processes to avoid deadlocks during
35 distributed training."""
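
    # A minimal lifecycle sketch (for orientation only; the training loop itself is
    # driven by the `DeepMelissaServer` base class, and `config_dict` is a
    # hypothetical Melissa configuration dictionary):
    #
    #     server = TorchServer(config_dict)
    #     server.setup_environment()                           # initializes torch.distributed
    #     server.model = server.wrap_model_ddp(server.model)   # DDP-wrap the user model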

    def __init__(self, config_dict: Dict[str, Any]) -> None:
        super().__init__(config_dict)
        self.device: str = "cpu"
        self._framework_t = FrameworkType.TORCH
        self.dist_group: Optional[dist.ProcessGroup] = dist.group.WORLD
        self.ckpt_net_arch_path: str = "checkpoints/net_arch.pkl"

    @property
    def dist_group_initialized(self) -> bool:
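        # Ranks excluded from `dist.new_group` receive the sentinel
        # `GroupMember.NON_GROUP_MEMBER`; for member ranks, a positive world size
        # confirms that the group is usable for collective calls.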
        if self.dist_group is not dist.GroupMember.NON_GROUP_MEMBER:
            return dist.get_world_size(group=self.dist_group) > 0
        return False

    @property
    def unwrapped_model(self) -> torch.nn.Module:
        """Returns the original `torch.nn.Module` provided by the user, before DDP wrapping.
        Useful for checkpointing custom state as well as calling user-defined methods
        belonging to `self.model`. The server wraps the given model with
        `torch.nn.parallel.DistributedDataParallel`, which exposes the original model
        through its `module` attribute. Thus, this returns `self.model.module` when the
        model is wrapped, and `self.model` otherwise."""
        return self.model.module if hasattr(self.model, "module") else self.model

    @override
    def _on_train_start(self) -> None:
        self.model.train()
        super()._on_train_start()

    @override
    def _on_validation_start(self, batch_idx: int) -> None:
        self.model.eval()
        super()._on_validation_start(batch_idx)

    @override
    def _on_validation_end(self, batch_idx: int) -> None:
        self.model.train()
        super()._on_validation_end(batch_idx)

    @override
    @torch.no_grad()
    def validation(self, batch_idx: int) -> None:
        super().validation(batch_idx)

    def _set_master_addr_port(
        self,
        master_addr: str = "127.0.0.1",
        master_port: int = 29500
    ) -> None:

        if self.rank == 0:
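            # Rank 0 probes up to `attempts` consecutive ports, starting from the
            # default, until it finds one that is not already bound.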
            attempts = 10
            i = 0
            while is_port_in_use(master_port) and i < attempts:
                logger.warning(
                    f"Rank {self.rank}>> MASTER_PORT={master_port} "
                    "for torch.distributed is already being used. Trying another..."
                )
                master_port += 1
                i += 1

            if i == attempts:
                logger.error(
                    f"{self.rank}>> Could not find an available MASTER_PORT after "
                    f"{attempts} attempts."
                )
                raise RuntimeError(
                    f"No available MASTER_PORT found after {attempts} attempts."
                )

        cluster = ClusterEnvironment()
        master_addr, master_port = cluster.comm_world.bcast([master_addr, master_port], root=0)
        # either use already set environment variables or use the broadcasted ones
        os.environ["MASTER_PORT"] = os.getenv("MASTER_PORT", str(master_port))
        os.environ["MASTER_ADDR"] = os.getenv("MASTER_ADDR", master_addr)

        logger.info(
            f"Rank {self.rank}>> torch.distributed will use "
            f"env://{os.environ['MASTER_ADDR']}:{os.environ['MASTER_PORT']}"
        )

    def __initialize_distributed_backend(self) -> None:
        """Initializes the distributed backend for both single-node and multi-node
        server configurations. This method sets up the device and backend for
        distributed training using either CPU or GPU, depending on availability.

        It first obtains the local rank, global rank, and world size from the
        Open MPI environment. Based on the availability of GPUs, it configures
        the device and selects the appropriate backend ('gloo' for CPU and 'nccl'
        for GPU).

        Finally, it initializes the process group for distributed training and
        creates a new group for the active ranks that participate in training.

        ### Raises
        - `RuntimeError`: If the number of MPI ranks is smaller than the number of
          available GPUs, i.e. the environment is not set up correctly for
          Distributed Data Parallel (DDP).
        """

        cluster = ClusterEnvironment()
        comm_world = cluster.comm_world
        local_rank = cluster.comm_world_local_rank
        local_world_size = cluster.comm_world_local_size
        rank = cluster.comm_world_rank
        world_size = cluster.comm_world_size
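        # Default to CPU training with the 'gloo' backend; switch to CUDA + 'nccl'
        # below when GPUs are visible and every GPU can be matched to an MPI rank.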
        self.device = f"cpu:{local_rank}"
        backend = 'gloo'
        if torch.cuda.is_available():
            local_gpus_count = torch.cuda.device_count()
            assert local_world_size >= local_gpus_count
            # if multi-node then gather local_gpus_count
            # from local_rank 0 of every node
            total_gpus_count = comm_world.allreduce(  # type: ignore
                local_gpus_count if local_rank == 0 else 0,
                op=MPI.SUM
            )
            if world_size >= total_gpus_count:
                self.device = f"cuda:{local_rank % local_gpus_count}"
                backend = 'nccl'
            else:
                msg = (
                    "The number of MPI ranks must be at least equal to "
                    "the number of available GPUs."
                )
                logger.error(msg)
                raise RuntimeError(msg)

        logger.info(
            f"Rank {self.rank}>> backend=\"{backend}\", device=\"{self.device}\", "
            f"world-size={world_size}"
        )

        # the timeout scales with the number of nodes, capped at 10 minutes
        timeout = timedelta(
            minutes=min(cluster.universe_size, 10)
        )
        dist.init_process_group(
            init_method="env://",
            backend=backend,
            rank=rank,
            world_size=world_size,
            timeout=timeout
        )
        # creating a dedicated group makes it possible to keep some ranks out of training
        # NOTE: always use this group when doing collective calls on `torch.distributed`
        self.dist_group = dist.new_group(
            ranks=self._get_active_ranks(),
            timeout=timeout
        )

    def _get_active_ranks(self) -> Optional[List[int]]:
        """Returns the list of ranks used to create the group that participates in
        training and collective calls.

        Override this method to adjust the ranks associated with training.
        """
        return list(range(self.comm_size))
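
    # Example override (a sketch): a subclass could keep the last rank out of training,
    # e.g. by returning `list(range(self.comm_size - 1))`.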

    @override
    def setup_environment(self) -> None:
        """Configures the environment for distributed GPU or CPU training using
        PyTorch's `torch.distributed` package."""

        self._set_master_addr_port(master_addr=self.node_name)
        self.__initialize_distributed_backend()

    def wrap_model_ddp(self, model: Union[torch.nn.Module, DDP]) -> DDP:
        """Wraps the model in DistributedDataParallel (DDP) for multi-GPU training.

        Depending on the setup (SLURM or local CUDA), this method wraps the model
        in DDP using the appropriate device(s).

        ### Parameters
        - **model** (`Union[torch.nn.Module, DDP]`): Instantiated torch model.

        ### Returns
        - `torch.nn.parallel.DistributedDataParallel`:
          The model wrapped in DDP for distributed training."""

        if isinstance(model, DDP):
            return model

        try:
            model = model.to(torch.device(self.device))
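            # DDP expects `device_ids=None` for CPU (gloo) training and a single
            # device entry per process for single-device CUDA modules.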
            device_ids = None if "cpu" in self.device else [self.device]
            return DDP(
                module=model,
                device_ids=device_ids,
                process_group=self.dist_group
            )
        except dist.DistBackendError as e:
            logger.exception(
                f"Rank {self.rank}>> MPI ranks and the number of GPUs must maintain "
                "a 1-to-1 mapping. "
                'For CPU training, export CUDA_VISIBLE_DEVICES="". '
                f"Original error: {e}"
            )
            self._destroy_distributed_backend()
            raise FatalError

    def _destroy_distributed_backend(self) -> None:
        # destroy the default (WORLD) group
        if dist.is_initialized():
            dist.destroy_process_group()
            logger.info(f"Rank {self.rank}>> torch.distributed process group destroyed.")

    @override
    def _server_finalize(self, exit_: int = 0) -> None:
        """
        Finalizes the server operations by calling
        `torch.distributed.destroy_process_group`.

        ### Parameters
        - **exit_ (int, optional)**: The exit status code indicating
          the outcome of the server's operations.
          Defaults to 0, which signifies a successful termination."""
        self._destroy_distributed_backend()
        super()._server_finalize(exit_)

    @override
    def _synchronize_data_availability(self) -> bool:
        """Coordinates the dataset across ranks to make sure data are available to be
        processed. This avoids deadlocks in `torch.distributed.all_reduce` of the
        gradients."""
        assert self.dist_group_initialized
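        # Each rank contributes 1 if its dataset currently has data; the summed flags
        # must equal the group size, so training proceeds only when every rank can
        # produce a batch.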
        _status = torch.tensor(
            int(self.dataset.has_data),
            dtype=torch.int64,
            device=self.device
        )
        try:
            dist.all_reduce(_status, op=dist.ReduceOp.SUM, group=self.dist_group)
            return _status.item() == self.comm_size
        except RuntimeError as e:
            self._destroy_distributed_backend()
            raise TrainingError(str(e))

    @override
    def _load_model_from_checkpoint(self) -> None:

        with open(self.ckpt_net_arch_path, 'rb') as f:
            self.model = cloudpickle.load(f)
        self.model = self.wrap_model_ddp(self.model)
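        # Remap tensors saved from the checkpointing rank's device (e.g. "cuda:0")
        # onto this rank's device, or onto the CPU when training without GPUs.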
        map_location = {"cuda:0": self.device} if "cuda" in self.device else "cpu"
        checkpoint = torch.load(
            self.ckpt_model_path,
            map_location=map_location,  # type: ignore
            weights_only=False
        )

        self.model.load_state_dict(checkpoint["model_state_dict"])

        self.optimizer = checkpoint["optimizer"]
        self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

        self.batch_offset = checkpoint["batch_idx"]
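        # Make sure every rank has finished restoring its state before training resumes.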
        dist.barrier(group=self.dist_group)

    @override
    def checkpoint(self, batch_idx: int = 0) -> None:
        """The method called to initiate full tree checkpointing.
        Saves `self.model` and `self.optimizer` states."""
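        # Only rank 0 persists the (unwrapped) network architecture and the
        # model/optimizer state, so that ranks do not write the same files concurrently.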
        if self.rank == 0:
            with open(self.ckpt_net_arch_path, 'wb') as f:
                cloudpickle.dump(self.unwrapped_model, f)

            torch.save(
                {
                    "optimizer": self.optimizer,
                    "batch_idx": batch_idx,
                    "model_state_dict": self.model.state_dict(),
                    "optimizer_state_dict": self.optimizer.state_dict(),
                },
                self.ckpt_model_path,
            )