Coverage for melissa/server/deep_learning/base_dl_server.py: 39%

315 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2025-11-03 09:52 +0100

1"""This script defines a base class for deep learning.""" 

2 

3import logging 

4import os 

5from threading import Thread, current_thread 

6import time 

7from abc import abstractmethod 

8from typing_extensions import override 

9from typing import Any, Dict, Optional, Tuple 

10 

11from mpi4py import MPI 

12 

13import cloudpickle 

14from melissa.server.base_server import BaseServer 

15from melissa.server.deep_learning import FrameworkType 

16from melissa.server.deep_learning.train_workflow import TrainingWorkflowMixin 

17from melissa.server.deep_learning.dataset import ( 

18 MelissaIterableDataset, 

19 make_dataset, 

20 make_dataloader, 

21) 

22from melissa.server.deep_learning.reservoir import BaseQueue, make_buffer, BufferType 

23from melissa.server.deep_learning.tensorboard import ( 

24 TensorboardLogger, 

25 convert_tb_logs_to_df, 

26 make_tb_logger 

27) 

28from melissa.server.simulation import PartialSimulationData, Simulation, SimulationData 

29from melissa.server.exceptions import ( 

30 ConfigurationFileError, 

31 TrainingError, 

32 ReceptionError, 

33 insert_exception, 

34 check_and_raise_exceptions_from_other_thread 

35) 

36from melissa.utility.metadata import Payload 

37 

38 

# Module-level logger, named after this module; used by every method of the server below.
logger = logging.getLogger(__name__)

40 

41 

class DeepMelissaServer(
    BaseServer,
    TrainingWorkflowMixin,
):
    """`DeepMelissaServer` is designed for studies involving deep learning workflows.
    It manages data buffering, logging, and configuration necessary for distributed training.

    ### Parameters
    - **config_dict** (`Dict[str, Any]`): A dictionary containing configuration settings for
    initializing the server.

    ### Attributes

    - **_bind_simulation_to_server_rank** (`int`): Whether to bind sending all timesteps of a
    simulation to the same server rank. By default, timesteps are sent in a round-robin fashion.
    - **dl_config** (`Dict[str, Any]`): Dictionary containing deep learning-specific configurations.
    - **batch_size** (`int`): The size of batches used for training.
    - **per_server_watermark** (`int`): Watermark level to determine data buffering thresholds.
    - **buffer_size** (`int`): Total size of the buffer to store training data.
    - **_batches_watermark_set** (`bool`): Prevents recalculation of expected number of batches
    given that the time-steps are known.
    - **pseudo_epochs** (`int`): Number of pseudo-epochs to replicate epoch-based training
    in an online setting. Only honoured with the `FIRO` buffer; reset to 1 otherwise.
    - **current_sample_count** (`int`): Counter for the number of samples received.
    - **nb_local_batches** (`int`): Index following the last batch trained on this MPI rank
    (includes `batch_offset`).
    - **nb_expected_batches** (`int`): Number of expected batches; deterministic for FIFO and FIRO.
    - **nb_batches_update** (`int`): Number of batches after which updates are triggered during
    training (For example, for validation). Divided by the number of server ranks and
    clamped to at least 1.
    - **checkpoint_interval** (`int`): Number of batches after which checkpointing is triggered.
    Defaults to `nb_batches_update`.
    - **_model** (`Any`): Model to be trained.
    - **_optimizer** (`Any`): Training optimizer.
    - **_buffer** (`BaseQueue`): Instance of a buffer object (FIFO, FIRO, or Reservoir) to
    manage training data.
    - **__dataset** (`Dataset`): Dataset interface to provide training data, integrates with the
    buffer and allows transformations via `process_simulation_data`.
    - **_framework_t** (`FrameworkType`): Framework type for the iterable dataset to be
    instantiated with `make_dataset`. It can either be `DEFAULT`, `TORCH`, or `TENSORFLOW`.
    - **_tb_logger** (`TensorboardLogger`): Logger to handle logging of metrics for
    visualization during training."""

    def __init__(self, config_dict: Dict[str, Any], **kwargs) -> None:
        """Reads the deep-learning configuration and prepares (but does not start)
        the receiver thread. Buffer, dataset and dataloaders are created later,
        in `start`, by `__configure_data_collection`.

        ### Parameters
        - **config_dict** (`Dict[str, Any]`): Full study configuration; the
        `study_options` and `dl_config` sections are read here."""

        super().__init__(config_dict, **kwargs)

        # connection request requires this to be set
        self._learning = 2
        self._check_group_size()
        self._bind_simulation_to_server_rank = int(
            config_dict["study_options"].get(
                "bind_simulation_to_server_rank",
                0
            )
        )
        # `_receive` runs on this thread, concurrently with `train` (see `_server_online`)
        self.__receiver_thread: Thread = Thread(target=self._receive, name="receive_thread")
        # flipped on the first fully assembled time step so the memory hint is logged once
        self.__first_completion: bool = False

        self.dl_config: Dict[str, Any] = config_dict["dl_config"]
        self.batch_size: int = self.dl_config["batch_size"]
        self.per_server_watermark: int = self.dl_config["per_server_watermark"]
        self.buffer_size: int = self.dl_config["buffer_size"]
        self._batches_watermark_set: bool = False
        self.pseudo_epochs: int = self.dl_config.get("pseudo_epochs", 1)
        self.current_sample_count: int = 0
        self.nb_local_batches: int = 0
        self.nb_expected_batches: int = 1
        self.nb_expected_time_steps: int = 1
        # Adjust the value based on the total server ranks launched. Clamp to 1: this
        # value is used as a modulus in `train`, `validation` and (by default)
        # `_checkpoint`, so a configured value smaller than the number of server
        # ranks would otherwise make those checks raise ZeroDivisionError.
        self.nb_batches_update: int = max(
            1, self.dl_config["nb_batches_update"] // self.comm_size
        )
        self.checkpoint_interval: int = self.dl_config.get(
            "checkpoint_interval",
            self.nb_batches_update
        )

        # start index passed to `enumerate` in `train`
        # NOTE(review): presumably updated on restart by a subclass/mixin — confirm
        self.batch_offset: int = 0
        self._model: Any = None
        self._optimizer: Any = None

        # created in `__configure_data_collection`
        self._tb_logger: Optional[TensorboardLogger] = None
        # default framework; presumably overridden by framework-specific subclasses
        self._framework_t: FrameworkType = FrameworkType.DEFAULT

        self._valid_dataloader: Any = None

        # buffer state is per rank; the model checkpoint path is shared by all ranks
        self.ckpt_buffer_state_path: str = f"checkpoints/{self.rank}/buffer_state.pkl"
        self.ckpt_model_path: str = "checkpoints/model.pt"

    def __adjust_buffer_size_for_pseudo_epochs(self) -> None:
        """Adjusts buffer size before buffer instantiation for pseudo-offline training.

        When `pseudo_epochs > 1` and the number of time steps is known, the buffer
        size is set to exactly the number of unique samples this rank expects,
        `(nb_clients // comm_size) * nb_time_steps`. Samples are reused across
        pseudo-epochs, and `__set_expected_batches_samples_watermark` asserts this
        equality. Previously only a smaller buffer was adjusted, so a larger one
        always failed that assertion."""

        if not self.time_steps_known or self.pseudo_epochs <= 1:
            return

        expected_samples = (self.nb_clients // self.comm_size) * self.nb_time_steps

        if self.buffer_size != expected_samples:
            logger.warning(
                f"Rank {self.rank}>> Adjusting buffer_size from {self.buffer_size} "
                f"to {expected_samples} for pseudo_epochs={self.pseudo_epochs}"
            )
            self.buffer_size = expected_samples
            self.dl_config["buffer_size"] = expected_samples

    def __configure_data_collection(self) -> None:
        """Instantiates the data collection i.e buffer, dataset, and dataloader.
        Users must implement `process_simulation_data`.

        ### Raises
        - `ConfigurationFileError`: if simulation-to-rank binding is requested with a
        non-reservoir buffer, or with fewer jobs/clients than server ranks.
        - `KeyError`: if `dl_config["buffer"]` is not a `BufferType` member name."""

        buffer_str = self.dl_config.get("buffer", "FIRO")
        buffer_t: BufferType = BufferType[buffer_str]
        if self._bind_simulation_to_server_rank:
            if "reservoir" not in buffer_str.lower():
                raise ConfigurationFileError(
                    "Binding a simulation to a specific server rank is supported only with "
                    "the `Reservoir` type buffers."
                )
            if self._job_limit < self.comm_size or self.nb_clients < self.comm_size:
                raise ConfigurationFileError(
                    "Binding a simulation to a specific server rank is supported only when the "
                    "`job_limit` and `parameter_sweep_size` is greater than or equal to "
                    "the total server ranks launched."
                )

        # Disable unsupported pseudo-epochs *before* resizing the buffer; otherwise a
        # non-FIRO buffer would be resized for a pseudo-offline mode that is then dropped.
        if self.pseudo_epochs > 1 and buffer_t is not BufferType.FIRO:
            logger.warning(f"Rank {self.rank}>> `pseudo_epochs > 1` is only supported "
                           "for `FIRO` buffer.")
            self.pseudo_epochs = 1

        self.__adjust_buffer_size_for_pseudo_epochs()

        rr_type = "Simulations" if self._bind_simulation_to_server_rank else "Timesteps"
        logger.info(
            f"Rank {self.rank}>> Round-robin strategy has been set to \"{rr_type}\" "
            "across the server ranks."
        )
        # initialize tensorboardLogger
        self._tb_logger = make_tb_logger(
            framework_t=self._framework_t,
            rank=self.rank,
            disable=not self.dl_config["tensorboard"],
            debug=self.verbose_level >= 3
        )

        self._buffer: BaseQueue = make_buffer(
            buffer_size=self.buffer_size,
            buffer_t=buffer_t,
            per_server_watermark=self.per_server_watermark,
            pseudo_epochs=self.pseudo_epochs,
            sweep_size=self.nb_clients,
            comm_size=self.comm_size,
            nb_time_steps=self.nb_time_steps,
        )

        self.__dataset: MelissaIterableDataset = make_dataset(
            framework_t=self._framework_t,
            buffer=self._buffer,
            tb_logger=self._tb_logger,
            config_dict=self.config_dict,
            transform=self.process_simulation_data,
        )

        # single-process loading; incomplete trailing batches are dropped
        self._train_dataloader = make_dataloader(
            framework_t=self._framework_t,
            iter_dataset=self.dataset,
            batch_size=self.batch_size,
            num_workers=0,
            drop_last=True,
        )

    @property
    def tb_logger(self) -> TensorboardLogger:
        """Tensorboard logger; only available after data collection is configured."""
        assert self._tb_logger is not None
        return self._tb_logger

    @property
    def buffer(self) -> BaseQueue:
        """Buffer holding received `SimulationData` until batched."""
        return self._buffer

    @property
    def optimizer(self) -> Any:
        """User-provided optimizer.

        ### Raises
        - `AttributeError`: if it has not been set yet."""
        if self._optimizer is None:
            raise AttributeError(
                "Parent classes rely on `self.optimizer`. It must be set by the user."
            )
        return self._optimizer

    @optimizer.setter
    def optimizer(self, optimizer: Any) -> None:
        self._optimizer = optimizer

    @property
    def model(self) -> Any:
        """User-provided model.

        ### Raises
        - `AttributeError`: if it has not been set yet."""
        if self._model is None:
            raise AttributeError(
                "Parent classes rely on `self.model`. It must be set by the user."
            )
        return self._model

    @model.setter
    def model(self, model: Any) -> None:
        self._model = model

    @property
    def dataset(self) -> MelissaIterableDataset:
        """Iterable dataset built on top of the buffer."""
        return self.__dataset

    @dataset.setter
    def dataset(self, dataset: MelissaIterableDataset) -> None:
        self.__dataset = dataset

    @property
    def valid_dataloader(self) -> Any:
        """Optional validation dataloader; validation is skipped when `None`."""
        return self._valid_dataloader

    @valid_dataloader.setter
    def valid_dataloader(self, dataloader: Any) -> None:
        logger.info(f"Rank {self.rank}>> Setting valid_dataloader.")
        self._valid_dataloader = dataloader
        if self.comm_size > 1:
            logger.warning(
                f"Rank {self.rank}>> `valid_dataloader` must load different data "
                "for each server rank."
            )
            logger.warning(
                f"Rank {self.rank}>> Users must call `get_reduced_validation_loss` "
                "to obtain the mean validation loss across all server ranks."
            )

    def _restart_simulations(self):
        """Alias for `_restart_groups` method."""
        self._restart_groups()

    @override
    def _check_group_size(self) -> None:
        """Checks if the group size was correctly set.

        ### Raises
        - `ConfigurationFileError`: if `group_size > 1` does not divide `nb_clients`."""

        if self.group_size > 1 and self.nb_clients % self.group_size != 0:
            raise ConfigurationFileError(
                "Incorrect group_size, please remove or adjust this option"
            )

    @override
    def start(self) -> None:
        """The main entrypoint for the server events.

        Configures data collection, launches the first groups (fresh start only),
        prepares or restores model/optimizer, then runs the online, offline and
        finalize phases in order."""

        self.__configure_data_collection()
        if self._restart == 0:
            self._launch_groups(list(range(0, self._job_limit)))

        self.setup_environment()

        if self._restart == 0:
            self.model, self.optimizer = self.prepare_training_attributes()
        else:
            self._restart_from_checkpoint()
            self._restart_simulations()
        self.__set_expected_batches_samples_watermark()

        self._server_online()
        self._server_offline()
        self._server_finalize()

    @override
    def _server_online(self) -> None:
        """Initiates data collection, and
        directs the custom methods for acting on collected data.

        Starts the receiver thread and runs `train` on the calling thread. A
        `ReceptionError` is logged and swallowed; any other exception is recorded
        with `insert_exception` and re-raised."""

        # put server receive on a separate thread.
        # should not be accessed by the user
        self.__receiver_thread.start()
        try:
            self.train()
        except ReceptionError:
            logger.error(
                f"Rank {self.rank}>> An error occured on the receiving thread."
            )
        except Exception as e:
            insert_exception(e)
            raise e

    @override
    def _server_offline(self):
        """Optional. Post processing steps.

        Closes the Tensorboard logger and, if `convert_log_to_df` is set, converts
        the logs of this rank into a dataframe."""
        if self._tb_logger is not None:
            self._tb_logger.close()
            if self.dl_config.get("convert_log_to_df", False):
                convert_tb_logs_to_df(self.rank)

    @override
    def _server_finalize(self, exit_: int = 0):
        """Finalizes the server operations.

        ### Parameters
        - **exit_** (`int`, optional): The exit status code indicating
        the outcome of the server's operations.
        Defaults to 0, which signifies a successful termination."""

        # a thread cannot join itself; wait briefly for the receiver otherwise
        if (
            current_thread() != self.__receiver_thread
            and self.__receiver_thread.is_alive()
        ):
            self.__receiver_thread.join(timeout=1.0)

        self._stop_pinger_thread()

        self.__dataset.signal_reception_over()

        super()._server_finalize(exit_)

    def __signal_end_of_reception(self) -> None:
        """Unsets the reception when all data has been received,
        and notifies `MelissaIterableDataset` to stop the batch formation."""

        with self._consistency_lock:
            self._is_receiving = False
            self.__dataset.signal_reception_over()
            logger.debug("Signal end of reception.")

    @override
    def _receive(self) -> None:
        """Receiver-thread loop: polls sockets until `_all_done()` and puts every
        complete `SimulationData` into the buffer.

        After the loop, all ranks synchronize on a barrier, the dataset is told
        that reception is over, and the pinger thread is started. A `TrainingError`
        is logged and swallowed. Any other exception is recorded, reception is
        signalled as over, and the exception is re-raised."""

        try:
            self._is_receiving = True
            while not self._all_done():
                check_and_raise_exceptions_from_other_thread()
                start = time.time()
                data = self.poll_sockets()

                # partial messages and control messages are not `SimulationData`
                if isinstance(data, SimulationData):
                    logger.debug(
                        f"Rank {self.rank}>> "
                        f"sim-id={data.simulation_id}, "
                        f"time-step={data.time_step} received."
                    )
                    self._buffer.put(data)
                    self.current_sample_count += 1
                    self._tb_logger.log_scalar(  # type: ignore
                        "put_time", time.time() - start, self.current_sample_count
                    )

                    if self.current_sample_count % 10000 == 0:
                        consumed, _ = self.get_memory_info_in_gb()
                        self._tb_logger.log_scalar(  # type: ignore
                            "memory_consumed", consumed, self.current_sample_count
                        )
            # endwhile
            self.comm.Barrier()
            self.__signal_end_of_reception()
            # ping until the training loop is done.
            self._start_pinger_thread()

        except TrainingError:
            logger.error(
                f"Rank {self.rank}>> An error occured on the training thread."
            )
        except Exception as e:
            insert_exception(e)
            self.__signal_end_of_reception()
            raise e
        finally:
            self._is_receiving = False

    @override
    def _process_partial_data_reception(
        self, simulation: Simulation, simulation_data: PartialSimulationData
    ) -> None:
        """Partial data has to be assembled for `DeepMelissaServer`.
        Do not perform anything."""
        return None

    @override
    def _process_complete_data_reception(
        self, simulation: Simulation, simulation_data: PartialSimulationData
    ) -> SimulationData:
        """Assembles all fields received for one client rank and time step into a
        single `SimulationData`, then clears those fields from `simulation`.

        On the first call only, logs a suggestion when the estimated buffer
        footprint is below 20% of the total memory.

        ### Parameters
        - **simulation** (`Simulation`): Simulation holding the received partial fields.
        - **simulation_data** (`PartialSimulationData`): The message that completed
        the time step; provides the client rank, time step and simulation id.

        ### Returns
        - `SimulationData`: payloads keyed by field name, plus the simulation parameters."""

        # extract actual data from `PartialSimulationData` object.
        all_fields_payload: Dict[str, Payload] = {
            key: val.payload
            for key, val in simulation.get_data(
                simulation_data.client_rank,
                simulation_data.time_step
            ).items()
            if isinstance(val, PartialSimulationData)
        }

        # dereference `received_simulation_data` as we will put the returned data in the buffer.
        simulation.clear_data(simulation_data.client_rank, simulation_data.time_step)

        if not self.__first_completion:
            self.__first_completion = True
            _, total = self.get_memory_info_in_gb()
            # rough estimate: 4 bytes per element (32-bit values) * elements per sample
            # * buffer capacity, in GB.
            # NOTE(review): assumes float32 payloads — confirm the payload dtype.
            nb_elements = sum(v.data.size for v in all_fields_payload.values())
            expected_buffer_consumption = (
                (32 / 8) * self.buffer_size * nb_elements / 1024**3
            )
            # guard against a zero total to keep the receiver thread alive
            if total > 0 and expected_buffer_consumption / total < 0.2:
                logger.warning(
                    f"Rank {self.rank}>> [Suggestion] Buffer size can be increased. "
                    f"Buffer/Main memory={expected_buffer_consumption:.2f}/{total:.2f} GB"
                )

        return SimulationData(
            simulation_data.simulation_id,
            simulation_data.time_step,
            all_fields_payload,
            simulation.parameters,
        )

    @override
    def _write_final_report(self) -> None:
        """Logs the batch count summed over all ranks (printed by rank 0), then
        writes the base report. The `allreduce` is collective, so every rank must
        call this method."""

        global_batches = self.comm.allreduce(
            self.nb_local_batches,
            op=MPI.SUM
        )
        if self.rank == 0:
            logger.info(f" - Number of Global Batches: {global_batches}")
        super()._write_final_report()

    def __set_expected_batches_samples_watermark(self) -> None:
        """Computes and sets the expected samples and batches per server process.

        Runs its computation only once when `nb_time_steps` is known. Otherwise it
        sets `nb_expected_batches` to 0 and may be called again later.

        ### Raises
        - `ConfigurationFileError`: if pseudo-epochs are requested without a known
        `nb_time_steps`, or the watermark cannot be reached (`__check_water_mark`)."""

        if self._batches_watermark_set:
            return

        if not self.time_steps_known and self.pseudo_epochs > 1:
            raise ConfigurationFileError(
                "`nb_time_steps` must be provided to adjust the buffer "
                "size for pseudo-offline training."
            )

        # standard case where nb_time_steps is given in the config file
        if self.time_steps_known:
            self._batches_watermark_set = True
            # ensure watermark is sufficient
            self.__check_water_mark()

            # unique samples
            self.nb_expected_time_steps = (
                self.nb_clients // self.comm_size
            ) * self.nb_time_steps

            # unique batches
            self.nb_expected_batches = self.nb_expected_time_steps // self.batch_size

            # calculate expected batches considering pseudo epochs
            if self.pseudo_epochs > 1:
                # for pseudo-offline, samples are reused across epochs
                self.nb_expected_batches *= self.pseudo_epochs

                # verify buffer size was properly adjusted
                # should have been done in __adjust_buffer_size_for_pseudo_epochs
                assert self.buffer_size == self.nb_expected_time_steps, (
                    f"Rank {self.rank}>> Buffer size mismatch detected! "
                    f"buffer_size={self.buffer_size}, expected={self.nb_expected_time_steps}. "
                    "This should have been adjusted before buffer instantiation."
                )

                logger.info(
                    f"Rank {self.rank}>> [Pseudo-offline Training] Expecting "
                    f"{self.nb_expected_time_steps * self.pseudo_epochs} samples "
                    f"across {self.nb_expected_batches} batches."
                )
            else:
                logger.info(
                    f"Rank {self.rank}>> [Online Training] Expecting "
                    f"{self.nb_expected_time_steps} samples "
                    f"across {self.nb_expected_batches} batches."
                )

        # when `nb_time_steps` is not known a priori
        else:
            logger.info(f"Rank {self.rank}>> Number of expected samples a priori unknown.")
            self.nb_expected_batches = 0

    def __check_water_mark(self) -> None:
        """Ensures there are sufficient samples to reach the `per_server_watermark`.

        ### Raises
        - `ConfigurationFileError`: if `per_server_watermark` exceeds
        `nb_time_steps * nb_clients // comm_size`."""

        total_time_steps = self.nb_time_steps * self.nb_clients
        samples_per_server = total_time_steps // self.comm_size
        if self.per_server_watermark > samples_per_server:
            raise ConfigurationFileError(
                "Insufficient samples to reach `per_server_watermark`. "
                "please increase `nb_time_steps`, "
                "or decrease `per_server_watermark`."
            )

    def other_processes_finished(self, batch_idx: int) -> bool:
        """Checks if other server processes have finished emptying their buffers.

        Collective call: every rank must call it at each batch, because it performs
        an `allreduce` through `_synchronize_data_availability`.

        ### Parameters
        - **batch_idx** (`int`): The current batch number being processed.

        ### Returns
        - **`bool`**: `True` as soon as at least one server process has no data
        left, i.e. training must stop on every rank."""

        logger.debug(
            f"Rank {self.rank}>> batch-idx={batch_idx}, "
            f"expected-batches >= {self.nb_expected_batches}"
        )
        with self._consistency_lock:
            # ensure self.__dataset._is_receiving is in sync across all server
            # processes.
            data_available: bool = self._synchronize_data_availability()

        if not data_available:
            # at this point the total number of expected samples should be known
            # and used to update the value of self.nb_expected_batches
            self.__set_expected_batches_samples_watermark()
            logger.debug(
                f"Rank {self.rank}>> One of the server processes finished receiving at "
                f"batch_idx={batch_idx}, "
                f"expected-batches >= {self.nb_expected_batches}"
            )

        return not data_available

    def _synchronize_data_availability(self) -> bool:
        """Coordinates the dataset data availability status across all
        server processes. This usually requires a library
        specific all_reduce function (e.g. `dist.all_reduce` in pytorch).

        Default behaviour: sum `dataset.has_data` over all MPI ranks and return
        `True` only if every rank still has data. As soon as one rank runs out,
        training stops on all ranks."""

        local_status = int(self.dataset.has_data)
        global_status = self.comm.allreduce(local_status, op=MPI.SUM)

        return global_status == self.comm_size

    @override
    def _save_base_state(self) -> None:
        """Checkpoint the current state of the server.

        No-op when fault tolerance is disabled. Otherwise saves the base state,
        then pickles the buffer state to `ckpt_buffer_state_path` while holding
        the buffer mutex."""

        if self.no_fault_tolerance:
            return

        super()._save_base_state()
        with self.buffer.mutex:
            # serialize the self._buffer.queue and then pickle it
            with open(self.ckpt_buffer_state_path, "wb") as file_:
                cloudpickle.dump(self._buffer.save_state(), file_)

    @override
    def _restart_from_checkpoint(self, **kwargs) -> None:
        """Restarts the server object from a checkpoint.

        ### Raises
        - `ConfigurationFileError`: if the model checkpoint or this rank's buffer
        state file is missing."""
        logger.info(
            f"Rank {self.rank}>> Continuing from checkpoint restart-count={self._restart}"
        )

        self._load_base_state()

        if (
            not os.path.exists(self.ckpt_model_path)
            or not os.path.exists(self.ckpt_buffer_state_path)
        ):
            raise ConfigurationFileError(
                f"Rank {self.rank}>> No checkpoints and/or buffer state were found. "
                "Please make sure that fault-tolerance is set."
            )

        # NOTE: unpickling executes arbitrary code; the file is assumed to be a
        # checkpoint written by `_save_base_state`, not untrusted input.
        with open(self.ckpt_buffer_state_path, "rb") as file_:
            state = cloudpickle.load(file_)
            self._buffer.load_from_state(state)

        # lib specific loading method (torch vs tf)
        self._load_model_from_checkpoint()

    @abstractmethod
    def process_simulation_data(self, data: SimulationData, config_dict: dict) -> Any:
        """Transforms data while creating batches with `MelissaIterableDataset`.
        See `SimulationData` for usage of attributes associated with the received data.

        ### Parameters
        - **data** (`SimulationData`): The data message received from the simulation
        (pulled from the buffer).
        - **config_dict** (`Dict[str, Any]`): A dictionary containing configuration settings.

        ### Returns
        - **`Any`**: Transformed data before creating a batch from it."""
        raise NotImplementedError("This method must be implemented in a subclass.")

    def get_reduced_validation_loss(self, valid_loss: float) -> float:
        """Returns the mean of `valid_loss` across all server ranks (collective call)."""
        return self.comm.allreduce(valid_loss, op=MPI.SUM) / self.comm.size

    def validation(self, batch_idx: int):
        """Predefined validation loop agnostic of frameworks.
        This can be overridden by the user.

        Runs only when a validation dataloader is set, `batch_idx > 0`, and
        `batch_idx + 1` is a multiple of `nb_batches_update`."""

        if (
            self._valid_dataloader is not None
            and batch_idx > 0
            and (batch_idx + 1) % self.nb_batches_update == 0
        ):
            logger.info(f"Rank {self.rank}>> Running validation at batch_idx={batch_idx}")
            self._on_validation_start(batch_idx)
            for v_batch_idx, v_batch in enumerate(self._valid_dataloader):
                self.validation_step(v_batch, v_batch_idx, batch_idx)
            self._on_validation_end(batch_idx)

    def train(self) -> None:
        """Predefined training loop agnostic of frameworks.

        Iterates the training dataloader starting at `batch_offset`, stops on all
        ranks as soon as one rank runs out of data, and runs validation,
        checkpointing and throughput logging at the configured intervals. On exit,
        `nb_local_batches` is set to the index following the last batch actually
        trained. A batch fetched but not trained because of the early stop is not
        counted."""

        if self._valid_dataloader is None:
            logger.warning(
                f"Rank {self.rank}>> [Ignorable] Attribute `valid_dataloader` must be set "
                "to perform validation."
            )

        self._on_train_start()

        # index following the last trained batch (what the final report counts)
        next_batch_idx = self.batch_offset
        last_log_time = time.time()
        # index of the batch closing the previous throughput window; starting one
        # below `batch_offset` makes the first window count all of its batches.
        last_log_batch_idx = self.batch_offset - 1

        for batch_idx, batch in enumerate(self._train_dataloader, start=self.batch_offset):
            logger.debug(f"Rank {self.rank}>> Training on batch-idx={batch_idx}")
            if self.other_processes_finished(batch_idx):
                logger.info(
                    f"Rank {self.rank}>> At least one other process has finished. "
                    "Break training loop."
                )
                break

            self._on_batch_start(batch_idx)
            self.training_step(batch, batch_idx)
            self._on_batch_end(batch_idx)
            next_batch_idx = batch_idx + 1

            self.validation(batch_idx)

            self._checkpoint(batch_idx)
            if (batch_idx + 1) % self.nb_batches_update == 0:
                current_time = time.time()
                batches_processed = batch_idx - last_log_batch_idx
                time_elapsed = current_time - last_log_time

                if time_elapsed > 0:
                    batches_per_sec = batches_processed / time_elapsed
                    logger.info(
                        f"Rank {self.rank}>> [Insight] "
                        f"Training throughput={batches_per_sec:.2f} batches/sec."
                    )

                last_log_time = current_time
                last_log_batch_idx = batch_idx

        # end training loop
        self._on_train_end()
        self.nb_local_batches = next_batch_idx

    @abstractmethod
    def prepare_training_attributes(self) -> Tuple[Any, Any]:
        """Required to configure server's `self.model` and `self.optimizer` attributes,
        preparing them for initialization.

        ### Returns
        - `Tuple[Any, Any]`:
            - **model** (`Any`): Instantiated model object.
            - **optimizer** (`Any`): Instantiated optimizer object."""

        raise NotImplementedError(
            "This method must be implemented in a subclass. "
            "Parent classes rely on `self.model`, and `self.optimizer` "
            "which are set from the return values of this method."
        )

    @abstractmethod
    def checkpoint(self, batch_idx: int):
        """The method called to initiate full tree checkpointing. This is
        specific to `torch` or `tensorflow` server."""
        raise NotImplementedError("This method must be implemented in a subclass.")

    @override
    def _checkpoint(self, batch_idx: int):  # type: ignore
        """Checkpointing at specific interval. The interval defaults to `nb_batches_update`.

        Saves the base/buffer state and then calls the framework-specific
        `checkpoint` when `batch_idx > 0` and `batch_idx + 1` is a multiple of
        `checkpoint_interval`."""
        if batch_idx > 0 and (batch_idx + 1) % self.checkpoint_interval == 0:
            self._save_base_state()
            self.checkpoint(batch_idx)

    @abstractmethod
    def _load_model_from_checkpoint(self):
        """Library specific model loading function. This is
        specific to `torch` or `tensorflow` server."""
        raise NotImplementedError("This method must be implemented in a subclass.")
728 raise NotImplementedError("This method must be implemented in a subclass.")