Coverage for melissa/server/deep_learning/torch_server.py: 37%

141 statements  

coverage.py v7.10.1, created at 2025-11-03 09:52 +0100

1"""This script extends `DeepMelissaServer` and implements `TorchServer`, as well as 

2`ExperimentalDeepMelissaActiveSamplingServer`.""" 

3 

4import os 

5import logging 

6from datetime import timedelta 

7from typing import Any, Dict, Union, Optional, List 

8from typing_extensions import override 

9 

10import torch 

11from mpi4py import MPI 

12import torch.distributed as dist 

13from torch.nn.parallel import DistributedDataParallel as DDP 

14import torch.utils.data 

15import cloudpickle 

16 

17from melissa.server.deep_learning.frameworks import FrameworkType 

18from melissa.server.deep_learning.base_dl_server import DeepMelissaServer 

19from melissa.utility.networking import is_port_in_use 

20from melissa.utility.rank_helper import ClusterEnvironment 

21from melissa.server.exceptions import FatalError, TrainingError 

22 

23 

24logger = logging.getLogger(__name__) 

25 

26 

class TorchServer(DeepMelissaServer):
    """`TorchServer` manages and trains a PyTorch model in a distributed setting.
    It performs the following tasks:

    - Sets up the distributed environment for multi-GPU or CPU training using PyTorch's DDP.
    - Wraps the model in `torch.nn.parallel.DistributedDataParallel` (DDP) for distributed training.
    - Manages model and optimizer state, including loading and saving checkpoints.
    - Synchronizes data availability across processes to avoid deadlocks during
      distributed training."""

    def __init__(self, config_dict: Dict[str, Any]) -> None:

        super().__init__(config_dict)
        self.device: str = "cpu"
        self._framework_t = FrameworkType.TORCH
        self.dist_group: Optional[dist.ProcessGroup] = dist.group.WORLD
        self.ckpt_net_arch_path: str = "checkpoints/net_arch.pkl"

    @property
    def dist_group_initialized(self) -> bool:
        """Returns `True` when this rank belongs to an initialized, non-empty process group."""
        if self.dist_group is not dist.GroupMember.NON_GROUP_MEMBER:
            return dist.get_world_size(group=self.dist_group) > 0
        return False

    @property
    def unwrapped_model(self) -> torch.nn.Module:
        """Returns the original `torch.nn.Module` given by the user, before DDP wrapping.
        The server wraps the user's model with `torch.nn.parallel.DistributedDataParallel`,
        which exposes the original model through its `module` attribute. This property is
        useful for checkpointing custom state as well as for calling user-defined methods
        on the model: it returns `self.model.module` when the model is wrapped, and
        `self.model` otherwise."""
        return self.model.module if hasattr(self.model, "module") else self.model
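
    # A usage sketch (hypothetical, not part of the original file): user-defined
    # helpers live on the original module, so after DDP wrapping they are reached
    # through `unwrapped_model` rather than `self.model`. `compute_statistics` is
    # an assumed method on the user's `torch.nn.Module`, shown only for illustration.
    #
    #     stats = self.unwrapped_model.compute_statistics()
    #     torch.save(stats, "checkpoints/extra_state.pt")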

    @override
    def _on_train_start(self) -> None:
        self.model.train()
        super()._on_train_start()

    @override
    def _on_validation_start(self, batch_idx: int) -> None:
        self.model.eval()
        super()._on_validation_start(batch_idx)

    @override
    def _on_validation_end(self, batch_idx: int) -> None:
        self.model.train()
        super()._on_validation_end(batch_idx)

    @override
    @torch.no_grad()
    def validation(self, batch_idx: int) -> None:
        super().validation(batch_idx)

    def _set_master_addr_port(
        self,
        master_addr: str = "127.0.0.1",
        master_port: int = 29500
    ) -> None:
        """Selects and broadcasts the rendezvous address and port used for the
        env:// initialization of `torch.distributed`."""

        if self.rank == 0:
            attempts = 10
            i = 0
            while is_port_in_use(master_port) and i < attempts:
                logger.warning(
                    f"Rank {self.rank}>> MASTER_PORT={master_port} "
                    "for torch.distributed is already being used. Trying another..."
                )
                master_port += 1
                i += 1

            if i == attempts:
                logger.error(
                    f"{self.rank}>> Could not find an available MASTER_PORT after "
                    f"{attempts} attempts."
                )
                raise RuntimeError(
                    f"No available MASTER_PORT found after {attempts} attempts."
                )

        cluster = ClusterEnvironment()
        master_addr, master_port = cluster.comm_world.bcast([master_addr, master_port], root=0)
        # either use already set environment variables or use the broadcasted ones
        os.environ["MASTER_PORT"] = os.getenv("MASTER_PORT", str(master_port))
        os.environ["MASTER_ADDR"] = os.getenv("MASTER_ADDR", master_addr)

        logger.info(
            f"Rank {self.rank}>> torch.distributed will use "
            f"env://{os.environ['MASTER_ADDR']}:{os.environ['MASTER_PORT']}"
        )

    def __initialize_distributed_backend(self) -> None:
        """Initializes the distributed backend for both single-node and multi-node
        server configurations. This method sets up the device and backend for
        distributed training using either CPU or GPU, depending on availability.

        It first obtains the local rank, global rank, and world size from the
        Open MPI environment. Based on the availability of GPUs, it configures
        the device and selects the appropriate backend ('gloo' for CPU and 'nccl'
        for GPU).

        Finally, it initializes the process group for distributed training and
        creates a new group for the active ranks that participate in training.

        ### Raises
        - `RuntimeError`: If the number of MPI ranks is smaller than the number of
          available GPUs, i.e. the environment is not set up correctly for
          Distributed Data Parallel (DDP).
        """

        cluster = ClusterEnvironment()
        comm_world = cluster.comm_world
        local_rank = cluster.comm_world_local_rank
        local_world_size = cluster.comm_world_local_size
        rank = cluster.comm_world_rank
        world_size = cluster.comm_world_size

        self.device = f"cpu:{local_rank}"
        backend = 'gloo'
        if torch.cuda.is_available():
            local_gpus_count = torch.cuda.device_count()
            assert local_world_size >= local_gpus_count
            # if multi-node then gather local_gpus_count
            # from local_rank 0 of every node
            total_gpus_count = comm_world.allreduce(  # type: ignore
                local_gpus_count if local_rank == 0 else 0,
                op=MPI.SUM
            )
            if world_size >= total_gpus_count:
                self.device = f"cuda:{local_rank % local_gpus_count}"
                backend = 'nccl'
            else:
                msg = (
                    "The number of MPI ranks must be at least equal to "
                    "the number of available GPUs."
                )
                logger.error(msg)
                raise RuntimeError(msg)

        logger.info(
            f"Rank {self.rank}>> backend=\"{backend}\", device=\"{self.device}\", "
            f"world-size={world_size}"
        )

        # timeout should scale according to the number of nodes
        timeout = timedelta(
            minutes=min(
                # cluster.universe_size,
                cluster.comm_world_size,
                10,
            )
        )
        dist.init_process_group(
            init_method="env://",
            backend=backend,
            rank=rank,
            world_size=world_size,
            timeout=timeout
        )
        # the creation of a new group helps to diverge some of the ranks away from training
        # NOTE: always use this group when doing collective calls on `torch.distributed`
        self.dist_group = dist.new_group(
            ranks=self._get_active_ranks(),
            timeout=timeout
        )

    def _get_active_ranks(self) -> Optional[List[int]]:
        """Returns a list of ranks that will be used for creating a group
        that participates in training and collective calls.

        Override this method to adjust the ranks associated with training.
        """
        return list(range(self.comm_size))
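
    # A minimal override sketch (not part of the original file): a subclass can
    # exclude some ranks from the training group, e.g. to dedicate the last MPI
    # rank to data reception only. The subclass name and the split policy are
    # assumptions made purely for illustration.
    #
    #     class PartialTrainingServer(TorchServer):
    #         @override
    #         def _get_active_ranks(self) -> Optional[List[int]]:
    #             # every rank except the last one participates in training
    #             return list(range(self.comm_size - 1))
    #
    # Ranks left out of the group receive `dist.GroupMember.NON_GROUP_MEMBER`
    # from `dist.new_group`, so `dist_group_initialized` is False for them.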

    @override
    def setup_environment(self) -> None:
        """Configures the environment for distributed GPU or CPU training using
        PyTorch's `torch.distributed` package."""

        self._set_master_addr_port(master_addr=self.node_name)
        self.__initialize_distributed_backend()

    def wrap_model_ddp(self, model: Union[torch.nn.Module, DDP]) -> DDP:
        """Wraps the model in DistributedDataParallel (DDP) for multi-GPU training.

        Depending on the setup (SLURM or local CUDA), this method moves the model
        to the configured device and wraps it in DDP.

        ### Parameters
        - **model** (`Union[torch.nn.Module, DDP]`): Instantiated torch model.
        ### Returns
        - `torch.nn.parallel.DistributedDataParallel`:
            The model wrapped in DDP for distributed training."""

        if isinstance(model, DDP):
            return model

        try:
            model = model.to(torch.device(self.device))
            device_ids = None if "cpu" in self.device else [self.device]
            return DDP(
                module=model,
                device_ids=device_ids,
                process_group=self.dist_group
            )
        except dist.DistBackendError as e:
            logger.exception(
                f"Rank {self.rank}>> MPI ranks and the number of GPUs must maintain "
                "a 1-to-1 mapping. "
                'For CPU training, export CUDA_VISIBLE_DEVICES="". '
                f"Original error: {e}"
            )
            self._destroy_distributed_backend()
            raise FatalError
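
    # A usage sketch (assumptions: `MyNet` is a hypothetical user-defined
    # `torch.nn.Module`, and the hook name `prepare_training` is made up; the
    # actual place where the model is set may differ in `DeepMelissaServer`):
    #
    #     def prepare_training(self) -> None:
    #         self.model = self.wrap_model_ddp(MyNet())
    #         # gradients are now averaged over `self.dist_group` on backward;
    #         # user-defined methods stay reachable via `self.unwrapped_model`.
    #         self.optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-3)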

    def _destroy_distributed_backend(self) -> None:
        # destroy the default (WORLD) group
        if dist.is_initialized():
            dist.destroy_process_group()
            logger.info(f"Rank {self.rank}>> Process group destroyed.")

    @override
    def _server_finalize(self, exit_: int = 0) -> None:
        """
        Finalizes the server operations by calling
        `torch.distributed.destroy_process_group`.

        ### Parameters
        - **exit_ (int, optional)**: The exit status code indicating
            the outcome of the server's operations.
            Defaults to 0, which signifies a successful termination."""
        self._destroy_distributed_backend()
        super()._server_finalize(exit_)

    @override
    def _synchronize_data_availability(self) -> bool:
        """Coordinates the ranks to make sure every process has data available to be
        processed. This avoids deadlocks in the `torch.distributed.all_reduce` of the
        gradients."""
        assert self.dist_group_initialized
        _status = torch.tensor(
            int(self.dataset.has_data),
            dtype=torch.int64,
            device=self.device
        )
        try:
            dist.all_reduce(_status, op=dist.ReduceOp.SUM, group=self.dist_group)
            return _status.item() == self.comm_size
        except RuntimeError as e:
            self._destroy_distributed_backend()
            raise TrainingError(str(e))
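
    # The SUM-based agreement above is True only when every rank contributed 1,
    # i.e. all ranks still have data. An equivalent check with `ReduceOp.MIN`
    # (illustration only, not how the server implements it):
    #
    #     flag = torch.tensor(int(self.dataset.has_data), device=self.device)
    #     dist.all_reduce(flag, op=dist.ReduceOp.MIN, group=self.dist_group)
    #     all_ranks_have_data = bool(flag.item())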

    @override
    def _load_model_from_checkpoint(self) -> None:
        """Restores the pickled network architecture, wraps it with DDP, and loads
        the model weights, optimizer state, and batch offset from the checkpoint."""

        with open(self.ckpt_net_arch_path, 'rb') as f:
            self.model = cloudpickle.load(f)
        self.model = self.wrap_model_ddp(self.model)

        map_location = {"cuda:0": self.device} if "cuda" in self.device else "cpu"
        checkpoint = torch.load(
            self.ckpt_model_path,
            map_location=map_location,  # type: ignore
            weights_only=False
        )

        self.model.load_state_dict(checkpoint["model_state_dict"])

        self.optimizer = checkpoint["optimizer"]
        self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

        self.batch_offset = checkpoint["batch_idx"]

        dist.barrier(group=self.dist_group)

    @override
    def checkpoint(self, batch_idx: int = 0) -> None:
        """The method called to initiate full tree checkpointing.
        Saves `self.model` and `self.optimizer` states."""

        if self.rank == 0:
            with open(self.ckpt_net_arch_path, 'wb') as f:
                cloudpickle.dump(self.unwrapped_model, f)

        torch.save(
            {
                "optimizer": self.optimizer,
                "batch_idx": batch_idx,
                "model_state_dict": self.model.state_dict(),
                "optimizer_state_dict": self.optimizer.state_dict(),
            },
            self.ckpt_model_path,
        )
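
# A minimal inspection sketch (not part of the server): the checkpoint written in
# `checkpoint()` can be reopened to recover the saved fields. "checkpoints/model.pt"
# is an assumed path; the actual value of `ckpt_model_path` comes from the base server.
#
#     ckpt = torch.load("checkpoints/model.pt", map_location="cpu", weights_only=False)
#     ckpt["batch_idx"]                  # batch offset to resume from
#     ckpt["model_state_dict"].keys()    # DDP-prefixed ("module.") parameter names
#     type(ckpt["optimizer"])            # the pickled optimizer object itself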