Coverage for melissa/server/deep_learning/torch_server.py: 37%
141 statements
1"""This script extends `DeepMelissaServer` and implements `TorchServer`, as well as
2`ExperimentalDeepMelissaActiveSamplingServer`."""
4import os
5import logging
6from datetime import timedelta
7from typing import Any, Dict, Union, Optional, List
8from typing_extensions import override
10import torch
11from mpi4py import MPI
12import torch.distributed as dist
13from torch.nn.parallel import DistributedDataParallel as DDP
14import torch.utils.data
15import cloudpickle
17from melissa.server.deep_learning.frameworks import FrameworkType
18from melissa.server.deep_learning.base_dl_server import DeepMelissaServer
19from melissa.utility.networking import is_port_in_use
20from melissa.utility.rank_helper import ClusterEnvironment
21from melissa.server.exceptions import FatalError, TrainingError
24logger = logging.getLogger(__name__)


class TorchServer(DeepMelissaServer):
    """`TorchServer` for managing and training a PyTorch model in a distributed setting.
    It performs the following tasks:

    - Sets up the distributed environment for multi-GPU or CPU training using PyTorch's DDP.
    - Wraps the model in `torch.nn.parallel.DistributedDataParallel` (DDP) for distributed training.
    - Manages model and optimizer state, including loading and saving checkpoints.
    - Synchronizes data availability across processes to avoid deadlocks during
      distributed training."""

    def __init__(self, config_dict: Dict[str, Any]) -> None:
        super().__init__(config_dict)
        self.device: str = "cpu"
        self._framework_t = FrameworkType.TORCH
        self.dist_group: Optional[dist.ProcessGroup] = dist.group.WORLD
        self.ckpt_net_arch_path: str = "checkpoints/net_arch.pkl"

    @property
    def dist_group_initialized(self) -> bool:
        if self.dist_group is not dist.GroupMember.NON_GROUP_MEMBER:
            return dist.get_world_size(group=self.dist_group) > 0
        return False

    @property
    def unwrapped_model(self) -> torch.nn.Module:
        """Returns the original `torch.nn.Module` given by the user, before DDP wrapping.
        Useful for checkpointing custom state as well as for calling user-defined methods
        on `self.model`. The server wraps the given model with
        `torch.nn.parallel.DistributedDataParallel`, which exposes the original model
        through its `module` attribute, so this returns `self.model.module` when that
        attribute exists and `self.model` otherwise."""
        return self.model.module if hasattr(self.model, "module") else self.model
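
    # Hedged usage sketch: attributes defined on the user's original module must be
    # reached through `unwrapped_model`, because the DDP wrapper only exposes them via
    # `.module` (the `custom_state` method below is purely illustrative):
    #
    #   extra = server.unwrapped_model.custom_state()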

    @override
    def _on_train_start(self) -> None:
        self.model.train()
        super()._on_train_start()

    @override
    def _on_validation_start(self, batch_idx: int) -> None:
        self.model.eval()
        super()._on_validation_start(batch_idx)

    @override
    def _on_validation_end(self, batch_idx: int) -> None:
        self.model.train()
        super()._on_validation_end(batch_idx)

    @override
    @torch.no_grad()
    def validation(self, batch_idx: int) -> None:
        super().validation(batch_idx)

    def _set_master_addr_port(
        self,
        master_addr: str = "127.0.0.1",
        master_port: int = 29500
    ) -> None:
        if self.rank == 0:
            attempts = 10
            i = 0
            while is_port_in_use(master_port) and i < attempts:
                logger.warning(
                    f"Rank {self.rank}>> MASTER_PORT={master_port} "
                    "for torch.distributed is already being used. Trying another..."
                )
                master_port += 1
                i += 1

            if i == attempts:
                logger.error(
                    f"{self.rank}>> Could not find an available MASTER_PORT after "
                    f"{attempts} attempts."
                )
                raise RuntimeError(
                    f"No available MASTER_PORT found after {attempts} attempts."
                )

        cluster = ClusterEnvironment()
        master_addr, master_port = cluster.comm_world.bcast([master_addr, master_port], root=0)
        # either use already set environment variables or use the broadcasted ones
        os.environ["MASTER_PORT"] = os.getenv("MASTER_PORT", str(master_port))
        os.environ["MASTER_ADDR"] = os.getenv("MASTER_ADDR", master_addr)

        logger.info(
            f"Rank {self.rank}>> torch.distributed will use "
            f"env://{os.environ['MASTER_ADDR']}:{os.environ['MASTER_PORT']}"
        )
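
    # Note: because of the `os.getenv(...)` fallbacks above, MASTER_ADDR/MASTER_PORT
    # exported before launch take precedence over the broadcasted values. Illustrative
    # launch line (the launcher and script name are assumptions, not prescribed here):
    #
    #   MASTER_ADDR=node001 MASTER_PORT=29513 mpirun -n 4 python3 my_server.py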

    def __initialize_distributed_backend(self) -> None:
        """Initializes the distributed backend for both single-node and multi-node
        server configurations. This method sets up the device and backend for
        distributed training using either CPU or GPU, depending on availability.

        It first obtains the local rank, global rank, and world size from the
        Open MPI environment. Based on the availability of GPUs, it configures
        the device and selects the appropriate backend ('gloo' for CPU and 'nccl'
        for GPU).

        Finally, it initializes the process group for distributed training and
        creates a new group for the active ranks that participate in training.

        ### Raises
        - `RuntimeError`: If there are fewer MPI ranks than visible GPUs, i.e. the
          environment is not set up correctly for Distributed Data Parallel (DDP).
        """

        cluster = ClusterEnvironment()
        comm_world = cluster.comm_world
        local_rank = cluster.comm_world_local_rank
        local_world_size = cluster.comm_world_local_size
        rank = cluster.comm_world_rank
        world_size = cluster.comm_world_size

        self.device = f"cpu:{local_rank}"
        backend = 'gloo'
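        # The defaults above give per-rank CPU devices with the Gloo backend; they are
        # upgraded to CUDA + NCCL below when GPUs are visible to this process.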
        if torch.cuda.is_available():
            local_gpus_count = torch.cuda.device_count()
            assert local_world_size >= local_gpus_count
            # if multi-node then gather local_gpus_count
            # from local_rank 0 of every node
            total_gpus_count = comm_world.allreduce(  # type: ignore
                local_gpus_count if local_rank == 0 else 0,
                op=MPI.SUM
            )
            if world_size >= total_gpus_count:
                self.device = f"cuda:{local_rank % local_gpus_count}"
                backend = 'nccl'
            else:
                msg = (
                    "The number of MPI ranks must be at least equal to "
                    "the number of available GPUs."
                )
                logger.error(msg)
                raise RuntimeError(msg)

        logger.info(
            f"Rank {self.rank}>> backend=\"{backend}\", device=\"{self.device}\", "
            f"world-size={world_size}"
        )

        # timeout should scale according to the number of nodes
        timeout = timedelta(
            minutes=min(
                # cluster.universe_size,
                cluster.comm_world_size,
                10,
            )
        )
        dist.init_process_group(
            init_method="env://",
            backend=backend,
            rank=rank,
            world_size=world_size,
            timeout=timeout
        )
        # creating a dedicated group makes it possible to keep some ranks out of training
        # NOTE: always use this group when doing collective calls on `torch.distributed`
        self.dist_group = dist.new_group(
            ranks=self._get_active_ranks(),
            timeout=timeout
        )
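        # Ranks excluded from `_get_active_ranks()` receive
        # `dist.GroupMember.NON_GROUP_MEMBER` here, which is exactly what the
        # `dist_group_initialized` property checks for.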

    def _get_active_ranks(self) -> Optional[List[int]]:
        """Returns a list of ranks that will be used for creating a group
        that participates in training and collective calls.

        Override this method to adjust the ranks associated with training.
        """
        return list(range(self.comm_size))
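
    # Hedged override sketch: a subclass could, for instance, keep the last server
    # rank out of the training group (the subclass name below is illustrative):
    #
    #   class MyTorchServer(TorchServer):
    #       @override
    #       def _get_active_ranks(self) -> Optional[List[int]]:
    #           return list(range(self.comm_size - 1))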

    @override
    def setup_environment(self) -> None:
        """Configures the environment for distributed GPU or CPU training using
        PyTorch's `torch.distributed` package."""
        self._set_master_addr_port(master_addr=self.node_name)
        self.__initialize_distributed_backend()

    def wrap_model_ddp(self, model: Union[torch.nn.Module, DDP]) -> DDP:
        """Wraps the model in DistributedDataParallel (DDP) for multi-GPU training.

        Depending on the setup (SLURM or local CUDA), this method wraps the model
        in DDP using the appropriate device(s).

        ### Parameters
        - **model** (`Union[torch.nn.Module, DDP]`): Instantiated torch model.

        ### Returns
        - `torch.nn.parallel.DistributedDataParallel`:
            The model wrapped in DDP for distributed training."""
        if isinstance(model, DDP):
            return model

        try:
            model = model.to(torch.device(self.device))
            device_ids = None if "cpu" in self.device else [self.device]
            return DDP(
                module=model,
                device_ids=device_ids,
                process_group=self.dist_group
            )
        except dist.DistBackendError as e:
            logger.exception(
                f"Rank {self.rank}>> MPI ranks and the number of GPUs must maintain "
                "a 1-to-1 mapping. "
                f'For CPU training, export CUDA_VISIBLE_DEVICES="". Original error: {e}'
            )
            self._destroy_distributed_backend()
            raise FatalError
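
    # Hedged usage sketch, assuming `net` is a user-provided `torch.nn.Module` instance:
    #
    #   server.setup_environment()                # torch.distributed must be initialized first
    #   server.model = server.wrap_model_ddp(net)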

    def _destroy_distributed_backend(self) -> None:
        # destroy the default (WORLD) group
        if dist.is_initialized():
            dist.destroy_process_group()
            logger.info(f"Rank {self.rank}>> torch.distributed process group destroyed.")

    @override
    def _server_finalize(self, exit_: int = 0) -> None:
        """Finalizes the server operations by calling
        `torch.distributed.destroy_process_group`.

        ### Parameters
        - **exit_** (`int`, optional): The exit status code indicating the outcome of
          the server's operations. Defaults to 0, which signifies a successful
          termination."""
        self._destroy_distributed_backend()
        super()._server_finalize(exit_)

    @override
    def _synchronize_data_availability(self) -> bool:
        """Coordinates the dataset across ranks to make sure data is available to be
        processed on every process. This avoids any deadlock in the
        `torch.distributed.all_reduce` of the gradients."""
        assert self.dist_group_initialized
        _status = torch.tensor(
            int(self.dataset.has_data),
            dtype=torch.int64,
            device=self.device
        )
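        # Sum the per-rank flags: training proceeds only when every rank reports data;
        # otherwise a rank without batches would hang in DDP's gradient all_reduce.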
        try:
            dist.all_reduce(_status, op=dist.ReduceOp.SUM, group=self.dist_group)
            return _status.item() == self.comm_size
        except RuntimeError as e:
            self._destroy_distributed_backend()
            raise TrainingError(str(e))

    @override
    def _load_model_from_checkpoint(self) -> None:
        with open(self.ckpt_net_arch_path, 'rb') as f:
            self.model = cloudpickle.load(f)
        self.model = self.wrap_model_ddp(self.model)
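
        # Remap tensors that were saved from the checkpointing rank's "cuda:0" device
        # onto this rank's device (or onto the CPU) when restoring.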
        map_location = {"cuda:0": self.device} if "cuda" in self.device else "cpu"
        checkpoint = torch.load(
            self.ckpt_model_path,
            map_location=map_location,  # type: ignore
            weights_only=False
        )

        self.model.load_state_dict(checkpoint["model_state_dict"])

        self.optimizer = checkpoint["optimizer"]
        self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

        self.batch_offset = checkpoint["batch_idx"]
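
        # Wait until every active rank has finished restoring before training resumes.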
        dist.barrier(group=self.dist_group)

    @override
    def checkpoint(self, batch_idx: int = 0) -> None:
        """The method called to initiate full tree checkpointing.
        Saves `self.model` and `self.optimizer` states."""
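        # Only rank 0 serializes the network architecture and the training state; the
        # other ranks hold identical DDP replicas, so a single copy on disk suffices.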
        if self.rank == 0:
            with open(self.ckpt_net_arch_path, 'wb') as f:
                cloudpickle.dump(self.unwrapped_model, f)

            torch.save(
                {
                    "optimizer": self.optimizer,
                    "batch_idx": batch_idx,
                    "model_state_dict": self.model.state_dict(),
                    "optimizer_state_dict": self.optimizer.state_dict(),
                },
                self.ckpt_model_path,
            )