Coverage for melissa/server/deep_learning/torch_server.py: 37%

141 statements  

coverage.py v7.6.12, created at 2025-07-09 14:19 +0200

1"""This script extends `DeepMelissaServer` and implements `TorchServer`, as well as 

2`ExperimentalDeepMelissaActiveSamplingServer`.""" 

3 

4import os 

5import logging 

6from datetime import timedelta 

7from typing import Any, Dict, Union, Optional, List 

8from typing_extensions import override 

9 

10import torch 

11from mpi4py import MPI 

12import torch.distributed as dist 

13from torch.nn.parallel import DistributedDataParallel as DDP 

14import torch.utils.data 

15import cloudpickle 

16 

17from melissa.server.deep_learning.frameworks import FrameworkType 

18from melissa.server.deep_learning.base_dl_server import DeepMelissaServer 

19from melissa.utility.networking import is_port_in_use 

20from melissa.utility.rank_helper import ClusterEnvironment 

21from melissa.server.exceptions import FatalError, TrainingError 

22 

23 

24logger = logging.getLogger(__name__) 

25 

26 

27class TorchServer(DeepMelissaServer): 

28 """`TorchServer` for managing and training a PyTorch model in a distributed setting. 

29 It performs the following tasks: 

30 

31 - Sets up the distributed environment for multi-GPU or CPU training using PyTorch's DDP. 

32 - Wraps the model in `torch.nn.parallel.DistributedDataParallel` (DDP) for distributed training. 

33 - Manages model and optimizer state, including loading and saving checkpoints. 

34 - Synchronizes data availability across processes to avoid deadlocks during 

35 distributed training.""" 
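    # A minimal driving sketch (hypothetical names, not part of this module): a concrete
    # setup would typically look along these lines, with `config_dict` coming from the
    # Melissa configuration and `MyNet` being a user-defined `torch.nn.Module`:
    #
    #     server = TorchServer(config_dict)
    #     server.setup_environment()                  # MASTER_ADDR/PORT + process group
    #     server.model = server.wrap_model_ddp(MyNet())
    #     server.optimizer = torch.optim.Adam(server.model.parameters())
    #
    # The training loop itself is expected to be driven by the base `DeepMelissaServer`,
    # which calls back into the hooks overridden below (`_on_train_start`, `validation`,
    # `checkpoint`, ...).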

    def __init__(self, config_dict: Dict[str, Any]) -> None:

        super().__init__(config_dict)
        self.device: str = "cpu"
        self._framework_t = FrameworkType.TORCH
        self.dist_group: Optional[dist.ProcessGroup] = dist.group.WORLD
        self.ckpt_net_arch_path: str = "checkpoints/net_arch.pkl"

    @property
    def dist_group_initialized(self) -> bool:
        if self.dist_group is not dist.GroupMember.NON_GROUP_MEMBER:
            return dist.get_world_size(group=self.dist_group) > 0
        return False

    @property
    def unwrapped_model(self) -> torch.nn.Module:
        """Returns the original `torch.nn.Module` given by the user, i.e. the model
        before DDP wrapping. Useful for checkpointing custom state as well as for calling
        user-defined methods belonging to the model. The server wraps the given model with
        `torch.nn.parallel.DistributedDataParallel`, which exposes the wrapped model
        through its `module` attribute; hence this returns `self.model.module` when the
        model is wrapped and `self.model` otherwise."""
        return self.model.module if hasattr(self.model, "module") else self.model
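    # For example, `self.unwrapped_model.some_user_hook()` keeps working after DDP
    # wrapping, whereas `self.model.some_user_hook()` would not, since DDP does not
    # forward arbitrary attribute access ("some_user_hook" is a hypothetical
    # user-defined method, not part of this module).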

    @override
    def _on_train_start(self) -> None:
        self.model.train()
        super()._on_train_start()

    @override
    def _on_validation_start(self, batch_idx: int) -> None:
        self.model.eval()
        super()._on_validation_start(batch_idx)

    @override
    def _on_validation_end(self, batch_idx: int) -> None:
        self.model.train()
        super()._on_validation_end(batch_idx)

    @override
    @torch.no_grad()
    def validation(self, batch_idx: int) -> None:
        super().validation(batch_idx)
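    # Rank 0 probes for a free MASTER_PORT (starting from the given default) and
    # broadcasts the chosen address/port to every rank; MASTER_ADDR/MASTER_PORT
    # environment variables that are already set take precedence over the
    # broadcasted values.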

    def _set_master_addr_port(
        self,
        master_addr: str = "127.0.0.1",
        master_port: int = 29500
    ) -> None:

        if self.rank == 0:
            attempts = 10
            i = 0
            while is_port_in_use(master_port) and i < attempts:
                logger.warning(
                    f"Rank {self.rank}>> MASTER_PORT={master_port} "
                    "for torch.distributed is already being used. Trying another..."
                )
                master_port += 1
                i += 1

            if i == attempts:
                logger.error(
                    f"{self.rank}>> Could not find an available MASTER_PORT after "
                    f"{attempts} attempts."
                )
                raise RuntimeError

        cluster = ClusterEnvironment()
        master_addr, master_port = cluster.comm_world.bcast([master_addr, master_port], root=0)
        # either use already set environment variables or use the broadcasted ones
        os.environ["MASTER_PORT"] = os.getenv("MASTER_PORT", str(master_port))
        os.environ["MASTER_ADDR"] = os.getenv("MASTER_ADDR", master_addr)

        logger.info(
            f"Rank {self.rank}>> torch.distributed will use "
            f"env://{os.environ['MASTER_ADDR']}:{os.environ['MASTER_PORT']}"
        )

    def __initialize_distributed_backend(self) -> None:
        """Initializes the distributed backend for both single-node and multi-node
        server configurations. This method sets up the device and backend for
        distributed training using either CPU or GPU, depending on availability.

        It first obtains the local rank, global rank, and world size from the
        Open MPI environment. Based on the availability of GPUs, it configures
        the device and selects the appropriate backend ('gloo' for CPU and 'nccl'
        for GPU).

        Finally, it initializes the process group for distributed training and
        creates a new group for the active ranks that participate in training.

        ### Raises
        - `RuntimeError`: If there are fewer MPI ranks than available GPUs, i.e. the
          environment is not set up correctly for Distributed Data Parallel (DDP).
        """

        cluster = ClusterEnvironment()
        comm_world = cluster.comm_world
        local_rank = cluster.comm_world_local_rank
        local_world_size = cluster.comm_world_local_size
        rank = cluster.comm_world_rank
        world_size = cluster.comm_world_size

        self.device = f"cpu:{local_rank}"
        backend = 'gloo'
        if torch.cuda.is_available():
            local_gpus_count = torch.cuda.device_count()
            assert local_world_size >= local_gpus_count
            # if multi-node then gather local_gpus_count
            # from local_rank 0 of every node
            total_gpus_count = comm_world.allreduce(  # type: ignore
                local_gpus_count if local_rank == 0 else 0,
                op=MPI.SUM
            )
            if world_size >= total_gpus_count:
                self.device = f"cuda:{local_rank % local_gpus_count}"
                backend = 'nccl'
            else:
                msg = (
                    "The number of MPI ranks must be at least equal to "
                    "the number of available GPUs."
                )
                logger.error(msg)
                raise RuntimeError(msg)

        logger.info(
            f"Rank {self.rank}>> backend=\"{backend}\", device=\"{self.device}\", "
            f"world-size={world_size}"
        )

        # the timeout scales with the number of nodes, capped at 10 minutes
        timeout = timedelta(
            minutes=min(cluster.universe_size, 10)
        )
        dist.init_process_group(
            init_method="env://",
            backend=backend,
            rank=rank,
            world_size=world_size,
            timeout=timeout
        )
        # the creation of a new group helps to diverge some of the ranks away from training
        # NOTE: always use this group when doing collective calls on `torch.distributed`
        self.dist_group = dist.new_group(
            ranks=self._get_active_ranks(),
            timeout=timeout
        )

    def _get_active_ranks(self) -> Optional[List[int]]:
        """Returns a list of ranks that will be used for creating a group
        that participates in training and collective calls.

        Override this method to adjust the ranks associated with training.
        """
        return list(range(self.comm_size))
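    # A minimal `_get_active_ranks` override sketch (hypothetical subclass, not part of
    # this module): dedicate the last rank to a non-training role by excluding it from
    # the training group.
    #
    #     class MyTorchServer(TorchServer):
    #         @override
    #         def _get_active_ranks(self) -> Optional[List[int]]:
    #             return list(range(self.comm_size - 1))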

    @override
    def setup_environment(self) -> None:
        """Configures the environment for distributed GPU or CPU training using
        PyTorch's `torch.distributed` package."""

        self._set_master_addr_port(master_addr=self.node_name)
        self.__initialize_distributed_backend()

    def wrap_model_ddp(self, model: Union[torch.nn.Module, DDP]) -> DDP:
        """Wraps the model in DistributedDataParallel (DDP) for multi-GPU training.

        Depending on the setup (SLURM or local CUDA), this method wraps the model
        in DDP using the appropriate device(s).

        ### Parameters
        - **model** (`Union[torch.nn.Module, DDP]`): Instantiated torch model.
        ### Returns
        - `torch.nn.parallel.DistributedDataParallel`:
            The model wrapped in DDP for distributed training."""
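        # Typical call site (sketch): `self.model = self.wrap_model_ddp(self.model)`,
        # as done in `_load_model_from_checkpoint` below.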

        if isinstance(model, DDP):
            return model

        try:
            model = model.to(torch.device(self.device))
            device_ids = None if "cpu" in self.device else [self.device]
            return DDP(
                module=model,
                device_ids=device_ids,
                process_group=self.dist_group
            )
        except dist.DistBackendError as e:
            logger.exception(
                f"Rank {self.rank}>> MPI ranks and the number of GPUs must maintain "
                "a 1-to-1 mapping. "
                f'For CPU training, export CUDA_VISIBLE_DEVICES="". Original error: {e}'
            )
            self._destroy_distributed_backend()
            raise FatalError

    def _destroy_distributed_backend(self) -> None:
        # destroy the default (WORLD) group
        if dist.is_initialized():
            dist.destroy_process_group()
            logger.info(f"Rank {self.rank}>> NCCL Group destroyed.")

    @override
    def _server_finalize(self, exit_: int = 0) -> None:
        """
        Finalizes the server operations by calling
        `torch.distributed.destroy_process_group`.

        ### Parameters
        - **exit_ (int, optional)**: The exit status code indicating
          the outcome of the server's operations.
          Defaults to 0, which signifies a successful termination."""
        self._destroy_distributed_backend()
        super()._server_finalize(exit_)

    @override
    def _synchronize_data_availability(self) -> bool:
        """Coordinates the ranks to make sure every process has data available to be
        processed. This avoids deadlocks in the `torch.distributed.all_reduce` of the
        gradients."""
        assert self.dist_group_initialized
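        # e.g. with 4 ranks, each rank contributes 1 (has data) or 0 (no data); after
        # the SUM all-reduce every rank sees the same total, and training proceeds only
        # when the total equals the number of ranks.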

        _status = torch.tensor(
            int(self.dataset.has_data),
            dtype=int,  # type: ignore
            device=self.device
        )
        try:
            dist.all_reduce(_status, op=dist.ReduceOp.SUM, group=self.dist_group)
            return _status.item() == self.comm_size
        except RuntimeError as e:
            self._destroy_distributed_backend()
            raise TrainingError(str(e))

    @override
    def _load_model_from_checkpoint(self) -> None:

        with open(self.ckpt_net_arch_path, 'rb') as f:
            self.model = cloudpickle.load(f)
        self.model = self.wrap_model_ddp(self.model)
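        # the checkpoint is written by rank 0 (whose device is "cuda:0" on GPU runs);
        # remap its tensors onto this rank's device, or onto the CPU otherwise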

        map_location = {"cuda:0": self.device} if "cuda" in self.device else "cpu"
        checkpoint = torch.load(
            self.ckpt_model_path,
            map_location=map_location,  # type: ignore
            weights_only=False
        )

        self.model.load_state_dict(checkpoint["model_state_dict"])

        self.optimizer = checkpoint["optimizer"]
        self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

        self.batch_offset = checkpoint["batch_idx"]

        dist.barrier(group=self.dist_group)

    @override
    def checkpoint(self, batch_idx: int = 0) -> None:
        """The method called to initiate full tree checkpointing.
        Saves `self.model` and `self.optimizer` states."""

        if self.rank == 0:
            with open(self.ckpt_net_arch_path, 'wb') as f:
                cloudpickle.dump(self.unwrapped_model, f)

            torch.save(
                {
                    "optimizer": self.optimizer,
                    "batch_idx": batch_idx,
                    "model_state_dict": self.model.state_dict(),
                    "optimizer_state_dict": self.optimizer.state_dict(),
                },
                self.ckpt_model_path,
            )
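# Checkpoint layout produced by `checkpoint` above and consumed by
# `_load_model_from_checkpoint`:
#   checkpoints/net_arch.pkl  - the cloudpickled, unwrapped `torch.nn.Module`
#   `self.ckpt_model_path`    - a dict with "optimizer", "batch_idx",
#                               "model_state_dict" and "optimizer_state_dict"
#                               (the path itself is set outside this module)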