1"""This script defines a base class for deep learning."""
3import logging
4import os
5from threading import Thread, current_thread
6import time
7from abc import abstractmethod
8from typing_extensions import override
9from typing import Any, Dict, Optional, Tuple
11from mpi4py import MPI
13import cloudpickle
14from melissa.server.base_server import BaseServer
15from melissa.server.deep_learning import FrameworkType
16from melissa.server.deep_learning.train_workflow import TrainingWorkflowMixin
17from melissa.server.deep_learning.dataset import (
18 MelissaIterableDataset,
19 make_dataset,
20 make_dataloader,
21)
22from melissa.server.deep_learning.reservoir import BaseQueue, make_buffer, BufferType
23from melissa.server.deep_learning.tensorboard import (
24 TensorboardLogger,
25 convert_tb_logs_to_df,
26 make_tb_logger
27)
28from melissa.server.simulation import PartialSimulationData, Simulation, SimulationData
29from melissa.server.exceptions import (
30 ConfigurationFileError,
31 TrainingError,
32 ReceptionError,
33 insert_exception,
34 check_and_raise_exceptions_from_other_thread
35)
36from melissa.utility.metadata import Payload
39logger = logging.getLogger(__name__)


class DeepMelissaServer(
    BaseServer,
    TrainingWorkflowMixin,
):
    """`DeepMelissaServer` is designed for studies involving deep learning workflows.
    It manages data buffering, logging, and configuration necessary for distributed training.

    ### Parameters
    - **config_dict** (`Dict[str, Any]`): A dictionary containing configuration settings for
    initializing the server.

    ### Attributes
    - **_bind_simulation_to_server_rank** (`int`): Whether to bind all time steps of a
    simulation to the same server rank. By default, time steps are sent in a round-robin fashion.
    - **dl_config** (`Dict[str, Any]`): Dictionary containing deep learning-specific configurations.
    - **batch_size** (`int`): The size of batches used for training.
    - **per_server_watermark** (`int`): Watermark level to determine data buffering thresholds.
    - **buffer_size** (`int`): Total size of the buffer to store training data.
    - **_batches_watermark_set** (`bool`): Guards against recalculating the expected number of
    batches once the number of time steps is known.
    - **pseudo_epochs** (`int`): Number of pseudo-epochs to replicate epoch-based training
    in an online setting.
    - **current_sample_count** (`int`): Counter for the number of samples received.
    - **nb_local_batches** (`int`): Number of batches processed. Local to each MPI rank.
    - **nb_expected_batches** (`int`): Number of expected batches; deterministic for FIFO and FIRO.
    - **nb_batches_update** (`int`): Number of batches after which updates are triggered during
    training (for example, for validation).
    - **checkpoint_interval** (`int`): Number of batches after which checkpointing is triggered.
    Defaults to `nb_batches_update`.
    - **_model** (`Any`): Model to be trained.
    - **_optimizer** (`Any`): Training optimizer.
    - **_buffer** (`BaseQueue`): Instance of a buffer object (FIFO, FIRO, or Reservoir) to
    manage training data.
    - **__dataset** (`MelissaIterableDataset`): Dataset interface to provide training data;
    integrates with the buffer and allows transformations via `process_simulation_data`.
    - **_framework_t** (`FrameworkType`): Framework type for the iterable dataset to be
    instantiated with `make_dataset`. It can either be `DEFAULT`, `TORCH`, or `TENSORFLOW`.
    - **_tb_logger** (`TensorboardLogger`): Logger to handle logging of metrics for
    visualization during training."""
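
    # A minimal sketch of the `dl_config` section consumed by this class; the keys
    # are the ones read in `__init__` and `__configure_data_collection`, the values
    # are purely illustrative:
    #
    #     "dl_config": {
    #         "batch_size": 32,
    #         "per_server_watermark": 256,   # must be reachable, see __check_water_mark
    #         "buffer_size": 2048,
    #         "buffer": "FIRO",              # one of the BufferType names
    #         "pseudo_epochs": 1,            # > 1 is supported only with FIRO
    #         "nb_batches_update": 50,       # divided by the number of server ranks
    #         "checkpoint_interval": 50,     # defaults to nb_batches_update
    #         "tensorboard": True,
    #         "convert_log_to_df": False,
    #     }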

    def __init__(self, config_dict: Dict[str, Any], **kwargs) -> None:

        super().__init__(config_dict, **kwargs)

        # connection request requires this to be set
        self._learning = 2
        self._check_group_size()
        self._bind_simulation_to_server_rank = int(
            config_dict["study_options"].get(
                "bind_simulation_to_server_rank",
                0
            )
        )
        self.__receiver_thread: Thread = Thread(target=self._receive, name="receive_thread")
        self.__first_completion: bool = False

        self.dl_config: Dict[str, Any] = config_dict["dl_config"]
        self.batch_size: int = self.dl_config["batch_size"]
        self.per_server_watermark: int = self.dl_config["per_server_watermark"]
        self.buffer_size: int = self.dl_config["buffer_size"]
        self._batches_watermark_set: bool = False
        self.pseudo_epochs: int = self.dl_config.get("pseudo_epochs", 1)
        self.current_sample_count: int = 0
        self.nb_local_batches: int = 0
        self.nb_expected_batches: int = 1
        self.nb_expected_time_steps: int = 1
        # adjust the value based on the total server ranks launched
        self.nb_batches_update: int = self.dl_config["nb_batches_update"] // self.comm_size
        self.checkpoint_interval: int = self.dl_config.get(
            "checkpoint_interval",
            self.nb_batches_update
        )

        self.batch_offset: int = 0
        self._model: Any = None
        self._optimizer: Any = None

        # str -> enum
        self._tb_logger: Optional[TensorboardLogger] = None
        self._framework_t: FrameworkType = FrameworkType.DEFAULT

        self._valid_dataloader: Any = None

        self.ckpt_buffer_state_path: str = f"checkpoints/{self.rank}/buffer_state.pkl"
        self.ckpt_model_path: str = "checkpoints/model.pt"

    def __adjust_buffer_size_for_pseudo_epochs(self) -> None:
        """Adjusts buffer size before buffer instantiation for pseudo-offline training."""

        if not self.time_steps_known:
            return

        if self.pseudo_epochs > 1:
            expected_samples = (self.nb_clients // self.comm_size) * self.nb_time_steps

            if self.buffer_size < expected_samples:
                logger.warning(
                    f"Rank {self.rank}>> Adjusting buffer_size from {self.buffer_size} "
                    f"to {expected_samples} for pseudo_epochs={self.pseudo_epochs}"
                )
                self.buffer_size = expected_samples
                self.dl_config["buffer_size"] = expected_samples
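
    # Worked example of the resize above (illustrative numbers): with
    # nb_clients=8, comm_size=2 and nb_time_steps=100, each server rank expects
    # (8 // 2) * 100 = 400 unique samples, so a configured buffer_size of 256
    # would be raised to 400 to keep one full pseudo-epoch in memory.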

    def __configure_data_collection(self) -> None:
        """Instantiates the data collection, i.e. buffer, dataset, and dataloader.
        Users must implement `process_simulation_data`."""

        self.__adjust_buffer_size_for_pseudo_epochs()

        buffer_str = self.dl_config.get("buffer", "FIRO")
        buffer_t: BufferType = BufferType[buffer_str]
        if self._bind_simulation_to_server_rank:
            if "reservoir" not in buffer_str.lower():
                raise ConfigurationFileError(
                    "Binding a simulation to a specific server rank is supported only with "
                    "the `Reservoir` type buffers."
                )
            if self._job_limit < self.comm_size or self.nb_clients < self.comm_size:
                raise ConfigurationFileError(
                    "Binding a simulation to a specific server rank is supported only when "
                    "`job_limit` and `parameter_sweep_size` are greater than or equal to "
                    "the total server ranks launched."
                )

        if self.pseudo_epochs > 1 and buffer_t is not BufferType.FIRO:
            logger.warning(f"Rank {self.rank}>> `pseudo_epochs > 1` is only supported "
                           "for the `FIRO` buffer.")
            self.pseudo_epochs = 1

        rr_type = "Simulations" if self._bind_simulation_to_server_rank else "Timesteps"
        logger.info(
            f"Rank {self.rank}>> Round-robin strategy has been set to \"{rr_type}\" "
            "across the server ranks."
        )
        # initialize TensorboardLogger
        self._tb_logger = make_tb_logger(
            framework_t=self._framework_t,
            rank=self.rank,
            disable=not self.dl_config["tensorboard"],
            debug=self.verbose_level >= 3
        )

        self._buffer: BaseQueue = make_buffer(
            buffer_size=self.buffer_size,
            buffer_t=buffer_t,
            per_server_watermark=self.per_server_watermark,
            pseudo_epochs=self.pseudo_epochs,
            sweep_size=self.nb_clients,
            comm_size=self.comm_size,
            nb_time_steps=self.nb_time_steps,
        )

        self.__dataset: MelissaIterableDataset = make_dataset(
            framework_t=self._framework_t,
            buffer=self._buffer,
            tb_logger=self._tb_logger,
            config_dict=self.config_dict,
            transform=self.process_simulation_data,
        )

        self._train_dataloader = make_dataloader(
            framework_t=self._framework_t,
            iter_dataset=self.dataset,
            batch_size=self.batch_size,
            num_workers=0,
            drop_last=True,
        )

    @property
    def tb_logger(self) -> TensorboardLogger:
        assert self._tb_logger is not None
        return self._tb_logger

    @property
    def buffer(self) -> BaseQueue:
        return self._buffer

    @property
    def optimizer(self) -> Any:
        if self._optimizer is None:
            raise AttributeError(
                "Parent classes rely on `self.optimizer`. It must be set by the user."
            )
        return self._optimizer

    @optimizer.setter
    def optimizer(self, optimizer: Any) -> None:
        self._optimizer = optimizer

    @property
    def model(self) -> Any:
        if self._model is None:
            raise AttributeError(
                "Parent classes rely on `self.model`. It must be set by the user."
            )
        return self._model

    @model.setter
    def model(self, model: Any) -> None:
        self._model = model

    @property
    def dataset(self) -> MelissaIterableDataset:
        return self.__dataset

    @dataset.setter
    def dataset(self, dataset: MelissaIterableDataset) -> None:
        self.__dataset = dataset

    @property
    def valid_dataloader(self) -> Any:
        return self._valid_dataloader

    @valid_dataloader.setter
    def valid_dataloader(self, dataloader: Any) -> None:
        logger.info(f"Rank {self.rank}>> Setting valid_dataloader.")
        self._valid_dataloader = dataloader
        if self.comm_size > 1:
            logger.warning(
                f"Rank {self.rank}>> `valid_dataloader` must load different data "
                "for each server rank."
            )
            logger.warning(
                f"Rank {self.rank}>> Users must call `get_reduced_validation_loss` "
                "to obtain the mean validation loss across all server ranks."
            )
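
    # A minimal sketch of per-rank validation sharding (hypothetical dataset and
    # names; assumes a torch-style DataLoader) so that each server rank validates
    # on a disjoint slice, as the warnings above require:
    #
    #     shard = torch.utils.data.Subset(
    #         valid_dataset,
    #         list(range(self.rank, len(valid_dataset), self.comm_size)),
    #     )
    #     self.valid_dataloader = torch.utils.data.DataLoader(shard, batch_size=64)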

    def _restart_simulations(self):
        """Alias for the `_restart_groups` method."""
        self._restart_groups()

    @override
    def _check_group_size(self) -> None:
        """Checks if the group size was correctly set."""

        if self.group_size > 1 and self.nb_clients % self.group_size != 0:
            raise ConfigurationFileError(
                "Incorrect group_size; please remove or adjust this option."
            )

    @override
    def start(self) -> None:
        """The main entrypoint for the server events."""

        self.__configure_data_collection()
        if self._restart == 0:
            self._launch_groups(list(range(0, self._job_limit)))

        self.setup_environment()

        if self._restart == 0:
            self.model, self.optimizer = self.prepare_training_attributes()
        else:
            self._restart_from_checkpoint()
            self._restart_simulations()
        self.__set_expected_batches_samples_watermark()

        self._server_online()
        self._server_offline()
        self._server_finalize()

    @override
    def _server_online(self) -> None:
        """Initiates data collection, and
        directs the custom methods for acting on collected data."""

        # put server receive on a separate thread.
        # should not be accessed by the user
        self.__receiver_thread.start()
        try:
            self.train()
        except ReceptionError:
            logger.error(
                f"Rank {self.rank}>> An error occurred on the receiving thread."
            )
        except Exception as e:
            insert_exception(e)
            raise e

    @override
    def _server_offline(self):
        """Optional. Post-processing steps."""
        if self._tb_logger is not None:
            self._tb_logger.close()
        if self.dl_config.get("convert_log_to_df", False):
            convert_tb_logs_to_df(self.rank)

    @override
    def _server_finalize(self, exit_: int = 0):
        """Finalizes the server operations.

        ### Parameters
        - **exit_** (`int`, optional): The exit status code indicating
        the outcome of the server's operations.
        Defaults to 0, which signifies a successful termination."""

        if (
            current_thread() != self.__receiver_thread
            and self.__receiver_thread.is_alive()
        ):
            self.__receiver_thread.join(timeout=1.0)

        self._stop_pinger_thread()

        self.__dataset.signal_reception_over()

        super()._server_finalize(exit_)

    def __signal_end_of_reception(self) -> None:
        """Unsets the reception flag when all data has been received,
        and notifies `MelissaIterableDataset` to stop the batch formation."""

        with self._consistency_lock:
            self._is_receiving = False
            self.__dataset.signal_reception_over()
        logger.debug("Signal end of reception.")

    @override
    def _receive(self) -> None:
        """Handles data coming into the server object."""

        try:
            self._is_receiving = True
            while not self._all_done():
                check_and_raise_exceptions_from_other_thread()
                start = time.time()
                data = self.poll_sockets()

                if data is not None and isinstance(data, SimulationData):
                    logger.debug(
                        f"Rank {self.rank}>> "
                        f"sim-id={data.simulation_id}, "
                        f"time-step={data.time_step} received."
                    )
                    self._buffer.put(data)
                    self.current_sample_count += 1
                    self._tb_logger.log_scalar(  # type: ignore
                        "put_time", time.time() - start, self.current_sample_count
                    )

                    if self.current_sample_count % 10000 == 0:
                        consumed, _ = self.get_memory_info_in_gb()
                        self._tb_logger.log_scalar(  # type: ignore
                            "memory_consumed", consumed, self.current_sample_count
                        )
            # endwhile
            self.comm.Barrier()
            self.__signal_end_of_reception()
            # ping until the training loop is done.
            self._start_pinger_thread()

        except TrainingError:
            logger.error(
                f"Rank {self.rank}>> An error occurred on the training thread."
            )
        except Exception as e:
            insert_exception(e)
            self.__signal_end_of_reception()
            raise e
        finally:
            self._is_receiving = False

    @override
    def _process_partial_data_reception(
        self, simulation: Simulation, simulation_data: PartialSimulationData
    ) -> None:
        """Partial data has to be assembled for `DeepMelissaServer`,
        so nothing is done here."""
        return None

    @override
    def _process_complete_data_reception(
        self, simulation: Simulation, simulation_data: PartialSimulationData
    ) -> SimulationData:

        # extract actual data from the `PartialSimulationData` objects.
        all_fields_payload: Dict[str, Payload] = {
            key: val.payload
            for key, val in simulation.get_data(
                simulation_data.client_rank,
                simulation_data.time_step
            ).items()
            if isinstance(val, PartialSimulationData)
        }

        # dereference the received simulation data, as we will put the returned
        # data in the buffer.
        simulation.clear_data(simulation_data.client_rank, simulation_data.time_step)

        if not self.__first_completion:
            self.__first_completion = True
            _, total = self.get_memory_info_in_gb()
            # 32-bit values, i.e. 4 bytes per element
            expected_buffer_consumption = (32 / 8) * self.buffer_size
            expected_buffer_consumption = expected_buffer_consumption * sum(
                v.data.size for v in all_fields_payload.values()
            )
            expected_buffer_consumption /= 1024**3
            if expected_buffer_consumption / total < 0.2:
                logger.warning(
                    f"Rank {self.rank}>> [Suggestion] Buffer size can be increased. "
                    f"Buffer/Main memory={expected_buffer_consumption:.2f}/{total:.2f} GB"
                )

        return SimulationData(
            simulation_data.simulation_id,
            simulation_data.time_step,
            all_fields_payload,
            simulation.parameters,
        )
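
    # Worked example of the estimate above (illustrative numbers): with
    # buffer_size=10_000 and 2_000 float32 values per sample, the buffer holds
    # roughly 4 B * 10_000 * 2_000 / 1024**3 ~= 0.07 GB; against 64 GB of main
    # memory that ratio is far below 0.2, so the suggestion would fire.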

    @override
    def _write_final_report(self) -> None:

        global_batches = self.comm.allreduce(
            self.nb_local_batches,
            op=MPI.SUM
        )
        if self.rank == 0:
            logger.info(f" - Number of Global Batches: {global_batches}")
        super()._write_final_report()

    def __set_expected_batches_samples_watermark(self) -> None:
        """Computes and sets the expected samples and batches per server process."""

        if self._batches_watermark_set:
            return

        if not self.time_steps_known and self.pseudo_epochs > 1:
            raise ConfigurationFileError(
                "`nb_time_steps` must be provided to adjust the buffer "
                "size for pseudo-offline training."
            )

        # standard case where nb_time_steps is given in the config file
        if self.time_steps_known:
            self._batches_watermark_set = True
            # ensure watermark is sufficient
            self.__check_water_mark()

            # unique samples
            self.nb_expected_time_steps = (
                self.nb_clients // self.comm_size
            ) * self.nb_time_steps

            # unique batches
            self.nb_expected_batches = self.nb_expected_time_steps // self.batch_size

            # calculate expected batches considering pseudo-epochs
            if self.pseudo_epochs > 1:
                # for pseudo-offline, samples are reused across epochs
                self.nb_expected_batches *= self.pseudo_epochs

                # verify buffer size was properly adjusted
                # (should have been done in __adjust_buffer_size_for_pseudo_epochs)
                assert self.buffer_size == self.nb_expected_time_steps, (
                    f"Rank {self.rank}>> Buffer size mismatch detected! "
                    f"buffer_size={self.buffer_size}, expected={self.nb_expected_time_steps}. "
                    "This should have been adjusted before buffer instantiation."
                )

                logger.info(
                    f"Rank {self.rank}>> [Pseudo-offline Training] Expecting "
                    f"{self.nb_expected_time_steps * self.pseudo_epochs} samples "
                    f"across {self.nb_expected_batches} batches."
                )
            else:
                logger.info(
                    f"Rank {self.rank}>> [Online Training] Expecting "
                    f"{self.nb_expected_time_steps} samples "
                    f"across {self.nb_expected_batches} batches."
                )

        # when `nb_time_steps` is not known a priori
        else:
            logger.info(f"Rank {self.rank}>> Number of expected samples a priori unknown.")
            self.nb_expected_batches = 0
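
    # Worked example of the counts above (illustrative numbers): with
    # nb_clients=8, comm_size=2, nb_time_steps=100 and batch_size=32, each rank
    # expects (8 // 2) * 100 = 400 unique samples, i.e. 400 // 32 = 12 batches;
    # with pseudo_epochs=3 this grows to 12 * 3 = 36 expected batches.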

    def __check_water_mark(self) -> None:
        """Ensures there are sufficient samples to reach the `per_server_watermark`."""

        total_time_steps = self.nb_time_steps * self.nb_clients
        samples_per_server = total_time_steps // self.comm_size
        if self.dl_config["per_server_watermark"] > samples_per_server:
            raise ConfigurationFileError(
                "Insufficient samples to reach `per_server_watermark`. "
                "Please increase `nb_time_steps` "
                "or decrease `per_server_watermark`."
            )

    def other_processes_finished(self, batch_idx: int) -> bool:
        """Checks if other server processes have finished emptying their buffers.

        ### Parameters
        - **batch_idx** (`int`): The current batch number being processed.

        ### Returns
        - **`bool`**: `True` if at least one other server process has emptied its buffer."""

        logger.debug(
            f"Rank {self.rank}>> batch-idx={batch_idx}, "
            f"expected-batches >= {self.nb_expected_batches}"
        )
        with self._consistency_lock:
            # ensure self.__dataset._is_receiving is in sync across all server
            # processes.
            data_available: bool = self._synchronize_data_availability()

        if not data_available:
            # at this point the total number of expected samples should be known
            # and used to update the value of self.nb_expected_batches
            self.__set_expected_batches_samples_watermark()
            logger.debug(
                f"Rank {self.rank}>> One of the server processes finished receiving at "
                f"batch_idx={batch_idx}, "
                f"expected-batches >= {self.nb_expected_batches}"
            )

        return not data_available

    def _synchronize_data_availability(self) -> bool:
        """Coordinates the dataset's data-availability status across all
        server processes. This usually requires a library-specific
        all_reduce function (e.g. `dist.all_reduce` in pytorch).

        The default behaviour is to check whether the buffer is empty across all
        MPI ranks. If at least one rank finishes, training stops on all ranks."""

        local_status = int(self.dataset.has_data)
        global_status = self.comm.allreduce(local_status, op=MPI.SUM)

        return global_status == self.comm_size
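
    # A minimal sketch of a framework-specific override (assumes
    # `torch.distributed` has been initialized across the server ranks; names
    # are illustrative, not part of this class):
    #
    #     def _synchronize_data_availability(self) -> bool:
    #         import torch
    #         import torch.distributed as dist
    #         status = torch.tensor(int(self.dataset.has_data))
    #         dist.all_reduce(status, op=dist.ReduceOp.SUM)
    #         return status.item() == self.comm_size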

    @override
    def _save_base_state(self) -> None:
        """Checkpoints the current state of the server."""

        if self.no_fault_tolerance:
            return

        super()._save_base_state()
        with self.buffer.mutex:
            # serialize the self._buffer.queue and then pickle it
            with open(self.ckpt_buffer_state_path, "wb") as file_:
                cloudpickle.dump(self._buffer.save_state(), file_)

    @override
    def _restart_from_checkpoint(self, **kwargs) -> None:
        """Restarts the server object from a checkpoint."""
        logger.info(
            f"Rank {self.rank}>> Continuing from checkpoint restart-count={self._restart}"
        )

        self._load_base_state()

        if (
            not os.path.exists(self.ckpt_model_path)
            or not os.path.exists(self.ckpt_buffer_state_path)
        ):
            raise ConfigurationFileError(
                f"Rank {self.rank}>> No checkpoints and/or buffer state were found. "
                "Please make sure that fault-tolerance is set."
            )

        with open(self.ckpt_buffer_state_path, "rb") as file_:
            state = cloudpickle.load(file_)
            self._buffer.load_from_state(state)

        # lib-specific loading method (torch vs tf)
        self._load_model_from_checkpoint()

    @abstractmethod
    def process_simulation_data(self, data: SimulationData, config_dict: dict) -> Any:
        """Transforms data while creating batches with `MelissaIterableDataset`.
        See `SimulationData` for usage of the attributes associated with the received data.

        ### Parameters
        - **data** (`SimulationData`): The data message received from the simulation
        (pulled from the buffer).
        - **config_dict** (`Dict[str, Any]`): A dictionary containing configuration settings.

        ### Returns
        - **`Any`**: Transformed data before creating a batch from it."""
        raise NotImplementedError("This method must be implemented in a subclass.")
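
    # A hedged sketch of an override (hypothetical field name "temperature" and
    # hypothetical attribute access on the payload; assumes a torch study):
    #
    #     def process_simulation_data(self, data: SimulationData, config_dict: dict):
    #         payload = data.data["temperature"]  # a Payload, as assembled above
    #         x = torch.from_numpy(payload.data)
    #         p = torch.tensor(data.parameters, dtype=torch.float32)
    #         return x, p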

    def get_reduced_validation_loss(self, valid_loss: float) -> float:
        """Returns the mean validation loss across all server ranks."""
        return self.comm.allreduce(valid_loss, op=MPI.SUM) / self.comm.size
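
    # Typical usage from a user-defined validation hook (illustrative names):
    #
    #     mean_loss = self.get_reduced_validation_loss(local_valid_loss)
    #     if self.rank == 0:
    #         logger.info(f"mean validation loss={mean_loss:.4f}")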

    def validation(self, batch_idx: int):
        """Predefined validation loop, agnostic of frameworks.
        It can be overridden by the user."""

        if (
            self._valid_dataloader is not None
            and batch_idx > 0
            and (batch_idx + 1) % self.nb_batches_update == 0
        ):
            logger.info(f"Rank {self.rank}>> Running validation at batch_idx={batch_idx}")
            self._on_validation_start(batch_idx)
            for v_batch_idx, v_batch in enumerate(self._valid_dataloader):
                self.validation_step(v_batch, v_batch_idx, batch_idx)
            self._on_validation_end(batch_idx)

    def train(self) -> None:
        """Predefined training loop, agnostic of frameworks."""

        if self._valid_dataloader is None:
            logger.warning(
                f"Rank {self.rank}>> [Ignorable] Attribute `valid_dataloader` must be set "
                "to perform validation."
            )

        self._on_train_start()

        batch_idx = -1
        last_log_time = time.time()
        last_log_batch_idx = self.batch_offset

        for batch_idx, batch in enumerate(self._train_dataloader, start=self.batch_offset):
            logger.debug(f"Rank {self.rank}>> Training on batch-idx={batch_idx}")
            if self.other_processes_finished(batch_idx):
                logger.info(
                    f"Rank {self.rank}>> At least one other process has finished. "
                    "Breaking the training loop."
                )
                break

            self._on_batch_start(batch_idx)
            self.training_step(batch, batch_idx)
            self._on_batch_end(batch_idx)

            self.validation(batch_idx)

            self._checkpoint(batch_idx)
            if (batch_idx + 1) % self.nb_batches_update == 0:
                current_time = time.time()
                batches_processed = batch_idx - last_log_batch_idx
                time_elapsed = current_time - last_log_time

                if time_elapsed > 0:
                    batches_per_sec = batches_processed / time_elapsed
                    logger.info(
                        f"Rank {self.rank}>> [Insight] "
                        f"Training throughput={batches_per_sec:.2f} batches/sec."
                    )

                last_log_time = current_time
                last_log_batch_idx = batch_idx

        # end training loop
        self._on_train_end()
        self.nb_local_batches = batch_idx + 1

    @abstractmethod
    def prepare_training_attributes(self) -> Tuple[Any, Any]:
        """Required to configure the server's `self.model` and `self.optimizer` attributes,
        preparing them for initialization.

        ### Returns
        - `Tuple[Any, Any]`:
            - **model** (`Any`): Instantiated model object.
            - **optimizer** (`Any`): Instantiated optimizer object."""
        raise NotImplementedError(
            "This method must be implemented in a subclass. "
            "Parent classes rely on `self.model` and `self.optimizer`, "
            "which are set from the return values of this method."
        )
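
    # A minimal sketch of an override for a torch study (illustrative model and
    # hyper-parameters):
    #
    #     def prepare_training_attributes(self):
    #         model = torch.nn.Linear(8, 1)
    #         optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    #         return model, optimizer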

    @abstractmethod
    def checkpoint(self, batch_idx: int):
        """The method called to initiate full tree checkpointing. This is
        specific to the `torch` or `tensorflow` server."""
        raise NotImplementedError("This method must be implemented in a subclass.")

    @override
    def _checkpoint(self, batch_idx: int):  # type: ignore
        """Checkpointing at a specific interval. The interval defaults to `nb_batches_update`."""
        if batch_idx > 0 and (batch_idx + 1) % self.checkpoint_interval == 0:
            self._save_base_state()
            self.checkpoint(batch_idx)

    @abstractmethod
    def _load_model_from_checkpoint(self):
        """Library-specific model loading function. This is
        specific to the `torch` or `tensorflow` server."""
        raise NotImplementedError("This method must be implemented in a subclass.")
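
# A compact, illustrative sketch of a concrete subclass wiring the abstract
# methods above together (names and framework calls are assumptions, not part
# of this module):
#
#     class MyDeepServer(DeepMelissaServer):
#         def prepare_training_attributes(self):
#             model = torch.nn.Linear(8, 1)
#             return model, torch.optim.Adam(model.parameters(), lr=1e-3)
#
#         def process_simulation_data(self, data, config_dict):
#             ...  # turn a SimulationData message into a training sample
#
#         def checkpoint(self, batch_idx):
#             ...  # e.g. torch.save(self.model.state_dict(), self.ckpt_model_path)
#
#         def _load_model_from_checkpoint(self):
#             ...  # e.g. self.model.load_state_dict(torch.load(self.ckpt_model_path))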