Coverage for melissa/server/deep_learning/base_dl_server.py: 43%

338 statements  


1"""This script defines a base class for deep learning.""" 

2 

3import sys 

4import logging 

5import os 

6from threading import Thread 

7import time 

8from abc import abstractmethod 

9from functools import wraps 

10from typing_extensions import override 

11from typing import Any, Callable, Dict, Optional, Tuple 

12 

13from mpi4py import MPI 

14 

15import cloudpickle 

16from melissa.launcher import message 

17from melissa.server.main import ServerError 

18from melissa.server.base_server import ( 

19 ReceptionError, 

20 BaseServer, 

21 UnsupportedConfiguration, 

22) 

23from melissa.server.deep_learning import FrameworkType 

24from melissa.server.deep_learning.train_workflow import TrainingWorkflowMixin 

25from melissa.server.deep_learning.experimental_monitoring import ( # type: ignore 

26 ExperimentalMonitoringMixin 

27) 

28from melissa.server.deep_learning.dataset import ( 

29 MelissaIterableDataset, 

30 make_dataset, 

31 make_dataloader, 

32) 

33from melissa.server.deep_learning.reservoir import BaseQueue, make_buffer, BufferType 

34from melissa.server.deep_learning.tensorboard import ( 

35 TensorboardLogger, 

36 convert_tb_logs_to_df, 

37 make_tb_logger 

38) 

39from melissa.server.simulation import PartialSimulationData, Simulation, SimulationData 

40from melissa.utility.networking import get_rank_and_num_server_proc 

41from melissa.utility.idr_torch import SlurmEnvironment 

42 

43 

44logger = logging.getLogger(__name__) 

45 

46 

class BaseDLServerError(Exception):
    """Any base dl server non-training error."""

    def __init__(self, msg) -> None:
        self.msg = msg

    def __str__(self) -> str:
        return f"OtherError: {self.msg}"


class TrainingError(Exception):
    """Errors from the training loop."""

    def __init__(self, msg) -> None:
        self.msg = msg

    def __str__(self) -> str:
        return f"Training Error: {self.msg}"


# TODO: Remove ExperimentalMonitoringMixin in the future.
class DeepMelissaServer(
    BaseServer,
    ExperimentalMonitoringMixin,
    TrainingWorkflowMixin,
):
    """`DeepMelissaServer` is designed for studies involving deep learning workflows.
    It manages the data buffering, logging, and configuration necessary for distributed training.

    ### Parameters
    - **config_dict** (`Dict[str, Any]`): A dictionary containing configuration settings for
      initializing the server.

    ### Attributes
    - **dl_config** (`Dict[str, Any]`): Dictionary containing deep learning-specific configurations.
    - **batch_size** (`int`): The size of batches used for training.
    - **per_server_watermark** (`int`): Watermark level to determine data buffering thresholds.
    - **buffer_size** (`int`): Total size of the buffer to store training data.
    - **pseudo_epochs** (`int`): Number of pseudo-epochs to replicate epoch-based training
      in an online setting.
    - **current_sample_count** (`int`): Counter for the number of samples received.
    - **nb_expected_batches** (`int`): Number of expected batches; deterministic for FIFO and FIRO.
    - **nb_batches_update** (`int`): Number of batches after which updates are triggered during
      training (for example, for validation).
    - **checkpoint_interval** (`int`): Number of batches after which checkpointing is triggered.
      Defaults to `nb_batches_update`.
    - **setup_slurm_ddp** (`bool`): Boolean flag to configure SLURM-based
      distributed data parallel training.
    - **_model** (`Any`): Model to be trained.
    - **_optimizer** (`Any`): Training optimizer.
    - **_buffer** (`BaseQueue`): Instance of a buffer object (FIFO, FIRO, or Reservoir) to
      manage training data.
    - **__dataset** (`MelissaIterableDataset`): Dataset interface to provide training data;
      integrates with the buffer and allows transformations via `process_simulation_data`.
    - **_framework_t** (`FrameworkType`): Framework type for the iterable dataset to be
      instantiated with `make_dataset`. It can either be `DEFAULT`, `TORCH`, or `TENSORFLOW`.
    - **_tb_logger** (`TensorboardLogger`): Logger to handle logging of metrics for
      visualization during training."""

    def __init__(self, config_dict: Dict[str, Any], **kwargs) -> None:

        super().__init__(config_dict, **kwargs)

        # connection request requires this to be set
        self._learning = 2
        self._check_group_size()

        # this lets us ping the launcher periodically,
        # ensuring the launcher does not assume the server is dead.
        self.__run_handler_thread: bool = False
        self.__pinger_thread: Thread = Thread(name="pinger", target=self.__loop_pings)
        self.__receiver_thread: Thread = Thread(target=self._receive)
        self.__first_completion: bool = False

        self.dl_config: Dict[str, Any] = config_dict["dl_config"]
        self.batch_size: int = self.dl_config["batch_size"]
        self.per_server_watermark: int = self.dl_config["per_server_watermark"]
        self.buffer_size: int = self.dl_config["buffer_size"]
        self.pseudo_epochs: int = self.dl_config.get("pseudo_epochs", 1)
        self.current_sample_count: int = 0
        self.nb_expected_batches: int = 1
        self.nb_expected_time_steps: int = 1
        self.nb_batches_update: int = self.dl_config["nb_batches_update"]
        self.checkpoint_interval: int = self.dl_config.get(
            "checkpoint_interval",
            self.nb_batches_update
        )

        slurm_env = SlurmEnvironment()
        self.setup_slurm_ddp: bool = (
            slurm_env.nnodes > 1
            and len(slurm_env.gpu_ids) >= 1
        )

        self.batch_offset: int = 0
        self._model: Any = None
        self._optimizer: Any = None

        # str -> enum
        self._tb_logger: Optional[TensorboardLogger] = None
        self._framework_t: FrameworkType = FrameworkType.DEFAULT

        self._valid_dataloader: Any = None

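    # A minimal sketch of the `dl_config` block consumed above (an illustration,
    # not part of the original module). The keys mirror the lookups made in
    # `__init__` and `__configure_data_collection`; the values and the subclass
    # name below are hypothetical:
    #
    #     config_dict = {
    #         "dl_config": {
    #             "batch_size": 32,
    #             "per_server_watermark": 500,
    #             "buffer_size": 10_000,
    #             "nb_batches_update": 100,
    #             "pseudo_epochs": 1,             # optional, defaults to 1
    #             "checkpoint_interval": 100,     # optional, defaults to nb_batches_update
    #             "buffer": "FIRO",               # optional, defaults to "FIRO"
    #             "tensorboard": True,
    #             "convert_log_to_df": False,     # optional
    #         },
    #         # ... other server-level options expected by BaseServer ...
    #     }
    #     server = MyDeepMelissaServer(config_dict)  # hypothetical subclass
    #     server.start()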

    def __configure_data_collection(self) -> None:
        """Instantiates the data collection objects, i.e. the buffer, dataset, and dataloader.
        Users must implement `process_simulation_data`."""

        buffer_t: BufferType = BufferType[self.dl_config.get("buffer", "FIRO")]
        # initialize the TensorboardLogger
        self._tb_logger = make_tb_logger(
            framework_t=self._framework_t,
            rank=self.rank,
            disable=not self.dl_config["tensorboard"],
            debug=self.verbose_level >= 3
        )

        self._buffer: BaseQueue = make_buffer(
            buffer_size=self.buffer_size,
            buffer_t=buffer_t,
            per_server_watermark=self.per_server_watermark,
            pseudo_epochs=self.pseudo_epochs,
        )

        self.__dataset: MelissaIterableDataset = make_dataset(
            framework_t=self._framework_t,
            buffer=self._buffer,
            tb_logger=self._tb_logger,
            config_dict=self.config_dict,
            transform=self.process_simulation_data,
        )

        self._train_dataloader = make_dataloader(
            framework_t=self._framework_t,
            iter_dataset=self.dataset,
            batch_size=self.batch_size,
            num_workers=0,
            drop_last=True,
        )

    @property
    def time_steps_known(self) -> bool:
        return self.nb_expected_batches != 0

    @property
    def tb_logger(self) -> TensorboardLogger:
        assert self._tb_logger is not None
        return self._tb_logger

    @property
    def buffer(self) -> BaseQueue:
        return self._buffer

    @property
    def optimizer(self) -> Any:
        if self._optimizer is None:
            raise AttributeError(
                "Parent classes rely on `self.optimizer`. It must be set by the user."
            )
        return self._optimizer

    @optimizer.setter
    def optimizer(self, optimizer: Any) -> None:
        self._optimizer = optimizer

    @property
    def model(self) -> Any:
        if self._model is None:
            raise AttributeError(
                "Parent classes rely on `self.model`. It must be set by the user."
            )
        return self._model

    @model.setter
    def model(self, model: Any) -> None:
        self._model = model

    @property
    def dataset(self) -> MelissaIterableDataset:
        return self.__dataset

    @dataset.setter
    def dataset(self, dataset: MelissaIterableDataset) -> None:
        self.__dataset = dataset

    @property
    def valid_dataloader(self) -> Any:
        return self._valid_dataloader

    @valid_dataloader.setter
    def valid_dataloader(self, dataloader: Any) -> None:
        if self.rank == 0:
            logger.info(f"Rank {self.rank}>> Setting valid_dataloader.")
        self._valid_dataloader = dataloader

    @override
    def _check_group_size(self) -> None:
        """Checks if the group size was correctly set."""

        if self.group_size > 1 and self.nb_clients % self.group_size != 0:
            m = "Incorrect group_size, please remove or adjust this option"
            logger.error(m)
            self._catch_error = True
            raise UnsupportedConfiguration(m)

    def __loop_pings(self) -> None:
        """Maintains communication with the launcher to ensure it
        does not assume the server has become unresponsive."""

        while self.__run_handler_thread:
            self._launcherfd.send(self._encode_msg(message.Ping()))
            logger.debug(f"Rank {self.rank}>> pinging launcher.")
            time.sleep(5)

    def __start_pinger_thread(self) -> None:
        """Starts the pinger thread and sets the flag."""

        self.__run_handler_thread = True
        self.__pinger_thread.start()

    def __stop_pinger_thread(self) -> None:
        """Stops the pinger thread and unsets the flag."""

        self.__run_handler_thread = False
        if self.__pinger_thread.is_alive():
            self.__pinger_thread.join(timeout=1.0)

    @override
    def start(self) -> None:
        """The main entrypoint for the server events."""

        try:
            self.__configure_data_collection()
            if not self._restart:
                self._launch_first_groups()
            if not self.setup_slurm_ddp:
                self.setup_environment()
            else:
                self._setup_environment_slurm()
            if not self._restart:
                self.model, self.optimizer = self.prepare_training_attributes()
            else:
                # the reinitialization from checkpoint occurs here
                logger.info(
                    f"Rank {self.rank}>> Continuing from checkpoint restart-count={self._restart}"
                )
                self._restart_from_checkpoint()
                if self.rank == 0:
                    self._kill_and_restart_simulations()
            self.__set_expected_batches_samples_watermark()

            self._server_online()

            if self._tb_logger is not None:
                self._tb_logger.close()
                if self.dl_config.get("convert_log_to_df", False):
                    convert_tb_logs_to_df(self.rank)

            self._server_finalize()

        except ServerError as e:
            logger.exception(e)
            raise e

    @override
    def _server_online(self) -> None:
        """Initiates data collection and directs the custom methods
        for acting on the collected data."""

        # put server receive on a separate thread;
        # it should not be accessed by the user.
        self.__receiver_thread.start()
        try:
            self.train()
        except TrainingError as exc:
            self._catch_error = True
            logger.exception(f"Exception was raised in the training thread: \n {exc}")
            if self.no_fault_tolerance:
                self._server_finalize(exit_=1)
                sys.exit(1)

    @override
    def _server_finalize(self, exit_: int = 0):
        """Finalizes the server operations.

        ### Parameters
        - **exit_** (`int`, optional): The exit status code indicating
          the outcome of the server's operations.
          Defaults to 0, which signifies a successful termination."""

        if self.__receiver_thread.is_alive():
            self.__receiver_thread.join(timeout=1.0)

        self.comm.Barrier()
        self.__stop_pinger_thread()

        self.__dataset.signal_reception_over()

        super()._server_finalize(exit_)

    def __signal_end_of_reception(self) -> None:
        """Unsets the reception flag when all data has been received,
        and notifies `MelissaIterableDataset` to stop the batch formation."""

        self._is_receiving = False
        self.__dataset.signal_reception_over()
        logger.debug("Signal end of reception.")

    @override
    def _receive(self) -> None:
        """Handles data coming from the server object."""

        try:
            self._is_receiving = True
            while not self._all_done():
                start = time.time()
                data = self.poll_sockets()

                if data is not None and isinstance(data, SimulationData):
                    logger.debug(
                        f"Rank {self.rank}>> "
                        f"sim-id={data.simulation_id}, "
                        f"time-step={data.time_step} received."
                    )
                    self._buffer.put(data)
                    self.current_sample_count += 1
                    self._tb_logger.log_scalar(  # type: ignore
                        "put_time", time.time() - start, self.current_sample_count
                    )

                    if self.current_sample_count % 10000 == 0:
                        consumed, _ = self.get_memory_info_in_gb()
                        self._tb_logger.log_scalar(  # type: ignore
                            "memory_consumed", consumed, self.current_sample_count
                        )
            # endwhile
            self.comm.Barrier()
            self.__signal_end_of_reception()
            # ping until the training loop is done.
            self.__start_pinger_thread()

        except ReceptionError as exc:
            self._catch_error = True
            logger.exception(f"Exception was raised in the receiving thread: \n {exc}")
            if self.no_fault_tolerance:
                self.__signal_end_of_reception()
                logger.warning(
                    f"Rank {self.rank}>> Training will stop once the buffers are empty."
                )
                self._server_finalize(exit_=1)
                sys.exit(1)

    @override
    def _process_partial_data_reception(
        self, simulation: Simulation, simulation_data: PartialSimulationData
    ) -> None:
        """Partial data has to be assembled for `DeepMelissaServer`,
        so nothing is done here."""
        return None

    @override
    def _process_complete_data_reception(
        self, simulation: Simulation, simulation_data: PartialSimulationData
    ) -> SimulationData:

        # extract actual data from `PartialSimulationData` object.
        all_fields_data: Dict[str, Any] = {
            key: val.data
            for key, val in simulation.get_data(
                simulation_data.client_rank, simulation_data.time_step
            ).items()
            if isinstance(val, PartialSimulationData)
        }

        # dereference `received_simulation_data` as we will put the returned data in the buffer.
        simulation.clear_data(simulation_data.client_rank, simulation_data.time_step)

        if not self.__first_completion:
            self.__first_completion = True
            _, total = self.get_memory_info_in_gb()
            expected_buffer_consumption = (32 / 8) * self.buffer_size
            expected_buffer_consumption = expected_buffer_consumption * sum(
                v.size for v in all_fields_data.values()
            )
            expected_buffer_consumption /= 1024**3
            if expected_buffer_consumption / total < 0.2:
                logger.warning(
                    f"Rank {self.rank}>> [Suggestion] Buffer size can be increased. "
                    f"Buffer/Main memory={expected_buffer_consumption:.2f}/{total:.2f} GB"
                )

        return SimulationData(
            simulation_data.simulation_id,
            simulation_data.time_step,
            all_fields_data,
            simulation.parameters,
        )

    @override
    def _validate_data(self, simulation_data: PartialSimulationData) -> bool:
        """Validates the simulation data."""

        sim_id, field, time_step = (
            simulation_data.simulation_id,
            simulation_data.field,
            simulation_data.time_step,
        )
        group_id = self._get_group_id_by_simulation(sim_id)

        # handle termination messages
        if field == "termination":
            logger.info(
                f"Rank {self.rank}>> [Termination] sim-id={sim_id}, "
                f"total time-steps={time_step} received as expected."
            )
            # modify the time steps received accordingly
            if self.nb_expected_batches == 0:
                self.nb_time_steps += time_step
                self._groups[group_id].simulations[sim_id].nb_time_steps = time_step
            logger.info(f"Rank {self.rank}>> sim-id={sim_id} finished sending.")
            self.nb_finished_simulations += 1
            return False

        # apply validation checks
        if field not in self.fields:
            if field != "termination":
                logger.warning(
                    f'Rank {self.rank}>> [Bad] sim-id={sim_id}, field="{field}"'
                )
            return False

        return super()._validate_data(simulation_data)

    @override
    def _write_final_report(self) -> None:

        total_batches = self.comm.allreduce(
            self.dataset.get_sample_number() // self.batch_size,
            op=MPI.SUM
        )
        if self.rank == 0:
            logger.info(f" - Number of global batches: {total_batches}")
        super()._write_final_report()

    def __set_expected_batches_samples_watermark(self) -> None:
        """Computes and sets the expected samples and batches per server process."""

        # standard case where nb_time_steps is given in the config file
        if self.nb_time_steps > 0:
            # ensure the watermark is sufficient
            self.__check_water_mark()

            # account for possible accumulated shift
            self.nb_expected_time_steps = (
                self.nb_clients // self.comm_size
            ) * self.nb_time_steps
            self.nb_expected_batches = (
                self.nb_expected_time_steps // self.batch_size * self.pseudo_epochs
            )

            if (
                self.pseudo_epochs > 1
                and self.buffer_size != self.nb_expected_time_steps
            ):
                logger.warning(
                    "User tried using `pseudo_epochs` with `buffer_size` smaller than expected "
                    "samples. Setting `buffer_size` to the number of expected time steps "
                    f"({self.nb_expected_time_steps})."
                )
                self.buffer_size = self.nb_expected_time_steps

            logger.info(
                f"Expecting {self.nb_expected_time_steps} "
                f"samples across {self.nb_expected_batches} batches."
            )
        # when `nb_time_steps` is not known a priori
        else:
            logger.info("Number of expected samples a priori unknown.")
            self.nb_expected_batches = 0

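    # Worked example (illustrative numbers, not from the original source): with
    # nb_clients=8, comm_size=2 server processes, nb_time_steps=100 per client,
    # batch_size=10 and pseudo_epochs=1, each server process expects
    # (8 // 2) * 100 = 400 time steps and 400 // 10 * 1 = 40 batches.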

    def __check_water_mark(self) -> None:
        """Ensures there are sufficient samples to reach the `per_server_watermark`."""

        total_time_steps = self.nb_time_steps * self.nb_clients
        samples_per_server = total_time_steps // self.comm_size
        if not self.dl_config["per_server_watermark"] <= samples_per_server:
            raise UnsupportedConfiguration(
                "Insufficient samples to reach `per_server_watermark`. "
                "Please increase `nb_time_steps` "
                "or decrease `per_server_watermark`."
            )

    def other_processes_finished(self, batch_idx: int) -> bool:
        """Checks if other server processes have finished emptying their buffers.

        ### Parameters
        - **batch_idx** (`int`): The current batch number being processed.

        ### Returns
        - **`bool`**: True if at least one other server process has emptied its buffer."""

        logger.debug(
            f"Rank {self.rank}>> batch-idx={batch_idx}, "
            f"expected-batches >= {self.nb_expected_batches}"
        )

        # ensure self.__dataset._is_receiving is in sync across all server
        # processes.
        data_available: bool = self._synchronize_data_availability()

        # in case of pseudo_offline training, we want to avoid a
        # server timeout, so we ping the launcher with time_monitor
        if not data_available:
            # at this point the total number of expected samples should be known
            # and used to update the value of self.nb_expected_batches
            if self.nb_expected_batches == 0:
                # per-client number of expected time-steps
                self.nb_time_steps //= self.nb_clients
                self.__set_expected_batches_samples_watermark()
            logger.debug(
                f"Rank {self.rank}>> One of the server processes finished receiving at "
                f"batch_idx={batch_idx}, "
                f"expected-batches >= {self.nb_expected_batches}"
            )

        return not data_available

    def _synchronize_data_availability(self) -> bool:
        """Coordinates the dataset's data-availability status across all
        server processes. This usually requires a library-specific
        all_reduce function (e.g. `dist.all_reduce` in PyTorch).

        The default behaviour is to check whether the buffer is empty across all MPI ranks.
        If at least one rank finishes, training stops for all ranks."""

        local_status = int(self.dataset.has_data)
        global_status = self.comm.allreduce(local_status, op=MPI.SUM)

        return global_status == self.comm_size

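    # A hedged sketch of a framework-specific override (not part of the original
    # module): a PyTorch subclass could synchronize the flag with
    # `torch.distributed` instead of MPI, assuming a process group is already
    # initialized.
    #
    #     import torch
    #     import torch.distributed as dist
    #
    #     def _synchronize_data_availability(self) -> bool:
    #         status = torch.tensor(int(self.dataset.has_data))
    #         dist.all_reduce(status, op=dist.ReduceOp.SUM)  # sum over all ranks
    #         return int(status.item()) == self.comm_size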

    @override
    def checkpoint_state(self) -> None:
        """Checkpoint the current state of the server."""

        if self.no_fault_tolerance:
            return

        self._save_base_state()
        # serialize the self._buffer.queue and then pickle it
        with open(f"checkpoints/{self.rank}/buffer_state.pkl", "wb") as file_:
            cloudpickle.dump(self._buffer.save_state(), file_)

    @override
    def _restart_from_checkpoint(self) -> None:
        """Restarts the server object from a checkpoint."""

        self._load_base_state()

        if (
            not os.path.exists("checkpoints/model.pt")
            or not os.path.exists(f"checkpoints/{self.rank}/buffer_state.pkl")
        ):
            raise UnsupportedConfiguration(
                f"Rank {self.rank}>> No checkpoints and/or buffer state were found. "
                "Please make sure that fault-tolerance is set."
            )

        logger.info(f"Rank {self.rank}>> Restarting from checkpoint.")
        with open(f"checkpoints/{self.rank}/buffer_state.pkl", "rb") as file_:
            state = cloudpickle.load(file_)
            self._buffer.load_from_state(state)

        # lib-specific loading method (torch vs tf)
        self._load_model_from_checkpoint()

    @abstractmethod
    def process_simulation_data(self, data: SimulationData, config_dict: dict) -> Any:
        """Transforms data while creating batches with `MelissaIterableDataset`.
        See `SimulationData` for usage of the attributes associated with the received data.

        ### Parameters
        - **data** (`SimulationData`): The data message received from the simulation
          (pulled from the buffer).
        - **config_dict** (`Dict[str, Any]`): A dictionary containing configuration settings.

        ### Returns
        - **`Any`**: Transformed data before creating a batch from it."""
        raise NotImplementedError("This method must be implemented in a subclass.")

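    # A minimal sketch of an override (illustrative only; the field name
    # "temperature", the attribute layout of `SimulationData`, and the use of
    # torch tensors are assumptions, not part of the original module):
    #
    #     import torch
    #
    #     def process_simulation_data(self, data: SimulationData, config_dict: dict) -> Any:
    #         x = torch.tensor(data.parameters).float()           # simulation inputs
    #         y = torch.tensor(data.data["temperature"]).float()  # received field
    #         return x, y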

    def validation(self, batch_idx: int):
        """Predefined validation loop agnostic of frameworks."""

        logger.info(f"Rank {self.rank} Running validation at batch_idx={batch_idx}")
        self._on_validation_start(batch_idx)
        for v_batch_idx, v_batch in enumerate(self._valid_dataloader):
            self.validation_step(v_batch, v_batch_idx, batch_idx)
        self._on_validation_end(batch_idx)

    def train(self) -> None:
        """Predefined training loop agnostic of frameworks."""

        try:
            valid_found = self._valid_dataloader is not None
            if not valid_found:
                logger.warning(
                    f"Rank {self.rank}>> [Ignorable] Attribute `valid_dataloader` must be set "
                    "to perform validation."
                )

            self._on_train_start()
            for batch_idx, batch in enumerate(self._train_dataloader):
                batch_idx += self.batch_offset
                logger.debug(f"Rank {self.rank}>> Training on batch-idx={batch_idx}")
                if self.other_processes_finished(batch_idx):
                    logger.info(
                        f"Rank {self.rank}>> At least one other process has finished. "
                        "Break training loop."
                    )
                    break

                self._on_batch_start(batch_idx)
                self.training_step(batch, batch_idx)
                self._on_batch_end(batch_idx)
                # TODO: Multi-GPU validation (not a priority).
                # It should be a framework-specific thing and may require
                # overriding this entire if block.
                if (
                    self.rank == 0
                    and valid_found
                    and batch_idx > 0
                    and (batch_idx + 1) % self.nb_batches_update == 0
                ):
                    self.validation(batch_idx)

                self._checkpoint(batch_idx)
            # end training loop
            self._on_train_end()
        except TrainingError as exc:
            raise exc

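    # A hedged sketch of the per-batch hook a PyTorch subclass might provide
    # (illustrative only; it assumes `self.model`/`self.optimizer` are torch
    # objects and that `process_simulation_data` yields `(x, y)` pairs):
    #
    #     import torch
    #
    #     def training_step(self, batch, batch_idx: int) -> None:
    #         x, y = batch
    #         self.optimizer.zero_grad()
    #         loss = torch.nn.functional.mse_loss(self.model(x), y)
    #         loss.backward()
    #         self.optimizer.step()
    #         self.tb_logger.log_scalar("loss", loss.item(), batch_idx)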

    @abstractmethod
    def _setup_environment_slurm(self) -> None:
        """Sets up the Distributed Data Parallel (DDP) environment using SLURM,
        as per the recommendations from the [Jean-Zay documentation]
        (http://www.idris.fr/eng/jean-zay/gpu/jean-zay-gpu-torch-multi-eng.html)."""
        raise NotImplementedError("This method must be implemented in a subclass.")

    @abstractmethod
    def prepare_training_attributes(self) -> Tuple[Any, Any]:
        """Required to configure the server's `self.model` and `self.optimizer` attributes,
        preparing them for initialization.

        ### Returns
        - `Tuple[Any, Any]`:
            - **model** (`Any`): Instantiated model object.
            - **optimizer** (`Any`): Instantiated optimizer object."""

        raise NotImplementedError(
            "This method must be implemented in a subclass. "
            "Parent classes rely on `self.model` and `self.optimizer`, "
            "which are set from the return values of this method."
        )

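    # A hedged sketch of a PyTorch-style implementation (illustrative only; the
    # model architecture and hyper-parameters are assumptions):
    #
    #     import torch
    #
    #     def prepare_training_attributes(self) -> Tuple[Any, Any]:
    #         model = torch.nn.Linear(in_features=8, out_features=1)
    #         optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    #         return model, optimizer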

    @abstractmethod
    def checkpoint(self, batch_idx: int, path: str = "checkpoints"):
        """The method called to initiate full-tree checkpointing. This is
        specific to the `torch` or `tensorflow` server."""
        raise NotImplementedError("This method must be implemented in a subclass.")

    def _checkpoint(self, batch_idx: int, path: str = "checkpoints"):
        """Checkpointing at a specific interval. The interval defaults to `nb_batches_update`."""
        if batch_idx > 0 and (batch_idx + 1) % self.checkpoint_interval == 0:
            self.checkpoint(batch_idx, path)

    @abstractmethod
    def _load_model_from_checkpoint(self):
        """Library-specific model loading function. This is
        specific to the `torch` or `tensorflow` server."""
        raise NotImplementedError("This method must be implemented in a subclass.")


def rank_zero_only(fn_: Callable) -> Callable:
    """Function that can be used as a decorator to enable a function/method
    being called only on rank 0. Inspired by pytorch_lightning."""

    rank, _ = get_rank_and_num_server_proc()

    @wraps(fn_)
    def wrapped_fn(*args: Any, **kwargs: Any) -> Optional[Any]:
        if rank == 0:
            return fn_(*args, **kwargs)
        return None

    return wrapped_fn
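

# Usage sketch for `rank_zero_only` (illustrative; `report_metrics` is a
# hypothetical function, not part of the original module):
#
#     @rank_zero_only
#     def report_metrics(loss: float) -> None:
#         logger.info(f"loss={loss:.4f}")  # only executed on rank 0
#
#     report_metrics(0.25)  # no-op on ranks != 0, returns None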