Coverage for melissa/server/deep_learning/base_dl_server.py: 43%
338 statements
"""This script defines a base class for deep learning."""

import sys
import logging
import os
from threading import Thread
import time
from abc import abstractmethod
from functools import wraps
from typing_extensions import override
from typing import Any, Callable, Dict, Optional, Tuple

from mpi4py import MPI

import cloudpickle
from melissa.launcher import message
from melissa.server.main import ServerError
from melissa.server.base_server import (
    ReceptionError,
    BaseServer,
    UnsupportedConfiguration,
)
from melissa.server.deep_learning import FrameworkType
from melissa.server.deep_learning.train_workflow import TrainingWorkflowMixin
from melissa.server.deep_learning.experimental_monitoring import (  # type: ignore
    ExperimentalMonitoringMixin
)
from melissa.server.deep_learning.dataset import (
    MelissaIterableDataset,
    make_dataset,
    make_dataloader,
)
from melissa.server.deep_learning.reservoir import BaseQueue, make_buffer, BufferType
from melissa.server.deep_learning.tensorboard import (
    TensorboardLogger,
    convert_tb_logs_to_df,
    make_tb_logger
)
from melissa.server.simulation import PartialSimulationData, Simulation, SimulationData
from melissa.utility.networking import get_rank_and_num_server_proc
from melissa.utility.idr_torch import SlurmEnvironment


logger = logging.getLogger(__name__)


class BaseDLServerError(Exception):
    """Any base dl server non-training error."""

    def __init__(self, msg) -> None:
        self.msg = msg

    def __str__(self) -> str:
        return f"OtherError: {self.msg}"


class TrainingError(Exception):
    """Errors from the training loop."""

    def __init__(self, msg) -> None:
        self.msg = msg

    def __str__(self) -> str:
        return f"Training Error: {self.msg}"


# TODO: Remove ExperimentalMonitoringMixin in the future.
class DeepMelissaServer(
    BaseServer,
    ExperimentalMonitoringMixin,
    TrainingWorkflowMixin,
):
    """`DeepMelissaServer` is designed for studies involving deep learning workflows.
    It manages data buffering, logging, and configuration necessary for distributed training.

    ### Parameters
    - **config_dict** (`Dict[str, Any]`): A dictionary containing configuration settings for
      initializing the server.

    ### Attributes
    - **dl_config** (`Dict[str, Any]`): Dictionary containing deep learning-specific configurations.
    - **batch_size** (`int`): The size of batches used for training.
    - **per_server_watermark** (`int`): Watermark level to determine data buffering thresholds.
    - **buffer_size** (`int`): Total size of the buffer to store training data.
    - **pseudo_epochs** (`int`): Number of pseudo-epochs to replicate epoch-based training
      in an online setting.
    - **current_sample_count** (`int`): Counter for the number of samples received.
    - **nb_expected_batches** (`int`): Number of expected batches; deterministic for FIFO and FIRO.
    - **nb_batches_update** (`int`): Number of batches after which updates are triggered during
      training (for example, for validation).
    - **checkpoint_interval** (`int`): Number of batches after which checkpointing is triggered.
      Defaults to `nb_batches_update`.
    - **setup_slurm_ddp** (`bool`): Boolean flag to configure SLURM-based
      distributed data parallel training.
    - **_model** (`Any`): Model to be trained.
    - **_optimizer** (`Any`): Training optimizer.
    - **__buffer** (`BaseQueue`): Instance of a buffer object (FIFO, FIRO, or Reservoir) to
      manage training data.
    - **__dataset** (`Dataset`): Dataset interface to provide training data; integrates with the
      buffer and allows transformations via `process_simulation_data`.
    - **_framework_t** (`FrameworkType`): Framework type for the iterable dataset to be
      instantiated with `make_dataset`. It can either be `DEFAULT`, `TORCH`, or `TENSORFLOW`.
    - **_tb_logger** (`TensorboardLogger`): Logger to handle logging of metrics for
      visualization during training."""
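
    # An illustrative `dl_config` covering the keys read in `__init__` below (the values
    # are placeholders chosen for this sketch, not defaults prescribed by Melissa):
    #
    #     "dl_config": {
    #         "batch_size": 32,
    #         "per_server_watermark": 500,
    #         "buffer_size": 2000,
    #         "pseudo_epochs": 1,           # optional, defaults to 1
    #         "nb_batches_update": 100,
    #         "checkpoint_interval": 100,   # optional, defaults to nb_batches_update
    #         "buffer": "FIRO",             # optional, e.g. "FIFO" or "FIRO"
    #         "tensorboard": True,
    #         "convert_log_to_df": False,   # optional
    #     }
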
    def __init__(self, config_dict: Dict[str, Any], **kwargs) -> None:

        super().__init__(config_dict, **kwargs)

        # connection request requires this to be set
        self._learning = 2
        self._check_group_size()

        # this lets us ping the launcher periodically,
        # ensuring the launcher does not assume the server is dead.
        self.__run_handler_thread: bool = False
        self.__pinger_thread: Thread = Thread(name="pinger", target=self.__loop_pings)
        self.__receiver_thread: Thread = Thread(target=self._receive)
        self.__first_completion: bool = False

        self.dl_config: Dict[str, Any] = config_dict["dl_config"]
        self.batch_size: int = self.dl_config["batch_size"]
        self.per_server_watermark: int = self.dl_config["per_server_watermark"]
        self.buffer_size: int = self.dl_config["buffer_size"]
        self.pseudo_epochs: int = self.dl_config.get("pseudo_epochs", 1)
        self.current_sample_count: int = 0
        self.nb_expected_batches: int = 1
        self.nb_expected_time_steps: int = 1
        self.nb_batches_update: int = self.dl_config["nb_batches_update"]
        self.checkpoint_interval: int = self.dl_config.get(
            "checkpoint_interval",
            self.nb_batches_update
        )

        slurm_env = SlurmEnvironment()
        self.setup_slurm_ddp: bool = (
            slurm_env.nnodes > 1
            and len(slurm_env.gpu_ids) >= 1
        )

        self.batch_offset: int = 0
        self._model: Any = None
        self._optimizer: Any = None

        # str -> enum
        self._tb_logger: Optional[TensorboardLogger] = None
        self._framework_t: FrameworkType = FrameworkType.DEFAULT

        self._valid_dataloader: Any = None

    def __configure_data_collection(self) -> None:
        """Instantiates the data collection, i.e. buffer, dataset, and dataloader.
        Users must implement `process_simulation_data`."""

        buffer_t: BufferType = BufferType[self.dl_config.get("buffer", "FIRO")]
        # initialize the TensorboardLogger
        self._tb_logger = make_tb_logger(
            framework_t=self._framework_t,
            rank=self.rank,
            disable=not self.dl_config["tensorboard"],
            debug=self.verbose_level >= 3
        )

        self._buffer: BaseQueue = make_buffer(
            buffer_size=self.buffer_size,
            buffer_t=buffer_t,
            per_server_watermark=self.per_server_watermark,
            pseudo_epochs=self.pseudo_epochs,
        )

        self.__dataset: MelissaIterableDataset = make_dataset(
            framework_t=self._framework_t,
            buffer=self._buffer,
            tb_logger=self._tb_logger,
            config_dict=self.config_dict,
            transform=self.process_simulation_data,
        )

        self._train_dataloader = make_dataloader(
            framework_t=self._framework_t,
            iter_dataset=self.dataset,
            batch_size=self.batch_size,
            num_workers=0,
            drop_last=True,
        )

    @property
    def time_steps_known(self) -> bool:
        return self.nb_expected_batches != 0

    @property
    def tb_logger(self) -> TensorboardLogger:
        assert self._tb_logger is not None
        return self._tb_logger

    @property
    def buffer(self) -> BaseQueue:
        return self._buffer

    @property
    def optimizer(self) -> Any:
        if self._optimizer is None:
            raise AttributeError(
                "Parent classes rely on `self.optimizer`. It must be set by the user."
            )
        return self._optimizer

    @optimizer.setter
    def optimizer(self, optimizer: Any) -> None:
        self._optimizer = optimizer

    @property
    def model(self) -> Any:
        if self._model is None:
            raise AttributeError(
                "Parent classes rely on `self.model`. It must be set by the user."
            )
        return self._model

    @model.setter
    def model(self, model: Any) -> None:
        self._model = model

    @property
    def dataset(self) -> MelissaIterableDataset:
        return self.__dataset

    @dataset.setter
    def dataset(self, dataset: MelissaIterableDataset) -> None:
        self.__dataset = dataset

    @property
    def valid_dataloader(self) -> Any:
        return self._valid_dataloader

    @valid_dataloader.setter
    def valid_dataloader(self, dataloader: Any) -> None:
        if self.rank == 0:
            logger.info(f"Rank {self.rank}>> Setting valid_dataloader.")
        self._valid_dataloader = dataloader

    @override
    def _check_group_size(self) -> None:
        """Checks if the group size was correctly set."""

        if self.group_size > 1 and self.nb_clients % self.group_size != 0:
            m = "Incorrect group_size, please remove or adjust this option"
            logger.error(m)
            self._catch_error = True
            raise UnsupportedConfiguration(m)

    def __loop_pings(self) -> None:
        """Maintains communication with the launcher to ensure it
        does not assume the server has become unresponsive."""

        while self.__run_handler_thread:
            self._launcherfd.send(self._encode_msg(message.Ping()))
            logger.debug(f"Rank {self.rank}>> pinging launcher.")
            time.sleep(5)

    def __start_pinger_thread(self) -> None:
        """Starts the pinger thread and sets the flag."""

        self.__run_handler_thread = True
        self.__pinger_thread.start()

    def __stop_pinger_thread(self) -> None:
        """Stops the pinger thread and unsets the flag."""

        self.__run_handler_thread = False
        if self.__pinger_thread.is_alive():
            self.__pinger_thread.join(timeout=1.0)

    @override
    def start(self) -> None:
        """The main entry point for the server events."""

        try:
            self.__configure_data_collection()
            if not self._restart:
                self._launch_first_groups()
            if not self.setup_slurm_ddp:
                self.setup_environment()
            else:
                self._setup_environment_slurm()
            if not self._restart:
                self.model, self.optimizer = self.prepare_training_attributes()
            else:
                # the reinitialization from checkpoint occurs here
                logger.info(
                    f"Rank {self.rank}>> Continuing from checkpoint restart-count={self._restart}"
                )
                self._restart_from_checkpoint()
                if self.rank == 0:
                    self._kill_and_restart_simulations()
            self.__set_expected_batches_samples_watermark()

            self._server_online()

            if self._tb_logger is not None:
                self._tb_logger.close()
            if self.dl_config.get("convert_log_to_df", False):
                convert_tb_logs_to_df(self.rank)

            self._server_finalize()
        except ServerError as e:
            logger.exception(e)
            raise e

    @override
    def _server_online(self) -> None:
        """Initiates data collection, and
        directs the custom methods for acting on collected data."""

        # put server receive on a separate thread;
        # it should not be accessed by the user.
        self.__receiver_thread.start()
        try:
            self.train()
        except TrainingError as exc:
            self._catch_error = True
            logger.exception(f"Exception was raised in the training thread: \n {exc}")
            if self.no_fault_tolerance:
                self._server_finalize(exit_=1)
                sys.exit(1)

    @override
    def _server_finalize(self, exit_: int = 0):
        """Finalizes the server operations.

        ### Parameters
        - **exit_** (`int`, optional): The exit status code indicating
          the outcome of the server's operations.
          Defaults to 0, which signifies a successful termination."""

        if self.__receiver_thread.is_alive():
            self.__receiver_thread.join(timeout=1.0)

        self.comm.Barrier()
        self.__stop_pinger_thread()

        self.__dataset.signal_reception_over()

        super()._server_finalize(exit_)

    def __signal_end_of_reception(self) -> None:
        """Unsets the reception flag when all data has been received,
        and notifies `MelissaIterableDataset` to stop the batch formation."""

        self._is_receiving = False
        self.__dataset.signal_reception_over()
        logger.debug("Signal end of reception.")

    @override
    def _receive(self) -> None:
        """Handles data coming from the server object."""

        try:
            self._is_receiving = True
            while not self._all_done():
                start = time.time()
                data = self.poll_sockets()

                if data is not None and isinstance(data, SimulationData):
                    logger.debug(
                        f"Rank {self.rank}>> "
                        f"sim-id={data.simulation_id}, "
                        f"time-step={data.time_step} received."
                    )
                    self._buffer.put(data)
                    self.current_sample_count += 1
                    self._tb_logger.log_scalar(  # type: ignore
                        "put_time", time.time() - start, self.current_sample_count
                    )

                    if self.current_sample_count % 10000 == 0:
                        consumed, _ = self.get_memory_info_in_gb()
                        self._tb_logger.log_scalar(  # type: ignore
                            "memory_consumed", consumed, self.current_sample_count
                        )
            # endwhile
            self.comm.Barrier()
            self.__signal_end_of_reception()
            # ping until the training loop is done.
            self.__start_pinger_thread()

        except ReceptionError as exc:
            self._catch_error = True
            logger.exception(f"Exception was raised in the receiving thread: \n {exc}")
            if self.no_fault_tolerance:
                self.__signal_end_of_reception()
                logger.warning(
                    f"Rank {self.rank}>> Training will stop once the buffers are empty."
                )
                self._server_finalize(exit_=1)
                sys.exit(1)

    @override
    def _process_partial_data_reception(
        self, simulation: Simulation, simulation_data: PartialSimulationData
    ) -> None:
        """Partial data has to be assembled for `DeepMelissaServer`,
        so nothing is done here."""
        return None

    @override
    def _process_complete_data_reception(
        self, simulation: Simulation, simulation_data: PartialSimulationData
    ) -> SimulationData:

        # extract actual data from `PartialSimulationData` objects.
        all_fields_data: Dict[str, Any] = {
            key: val.data
            for key, val in simulation.get_data(
                simulation_data.client_rank, simulation_data.time_step
            ).items()
            if isinstance(val, PartialSimulationData)
        }

        # dereference `received_simulation_data` as we will put the returned data in the buffer.
        simulation.clear_data(simulation_data.client_rank, simulation_data.time_step)

        if not self.__first_completion:
            self.__first_completion = True
            _, total = self.get_memory_info_in_gb()
            expected_buffer_consumption = (32 / 8) * self.buffer_size
            expected_buffer_consumption = expected_buffer_consumption * sum(
                v.size for v in all_fields_data.values()
            )
            expected_buffer_consumption /= 1024**3
            if expected_buffer_consumption / total < 0.2:
                logger.warning(
                    f"Rank {self.rank}>> [Suggestion] Buffer size can be increased. "
                    f"Buffer/Main memory={expected_buffer_consumption:.2f}/{total:.2f} GB"
                )

        return SimulationData(
            simulation_data.simulation_id,
            simulation_data.time_step,
            all_fields_data,
            simulation.parameters,
        )

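    # Rough arithmetic behind the buffer-size suggestion above (illustrative numbers,
    # not taken from any particular study): with buffer_size=1000 samples and 10,000
    # float32 values per sample, the buffer needs about 4 * 1000 * 10000 bytes ≈ 0.04 GB;
    # the warning fires whenever this estimate is below 20% of the node's total memory.
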
    @override
    def _validate_data(self, simulation_data: PartialSimulationData) -> bool:
        """Validates the simulation data."""

        sim_id, field, time_step = (
            simulation_data.simulation_id,
            simulation_data.field,
            simulation_data.time_step,
        )
        group_id = self._get_group_id_by_simulation(sim_id)

        # handle termination messages
        if field == "termination":
            logger.info(
                f"Rank {self.rank}>> [Termination] sim-id={sim_id}, "
                f"total time-steps={time_step} received as expected."
            )
            # modify the time steps received accordingly
            if self.nb_expected_batches == 0:
                self.nb_time_steps += time_step
                self._groups[group_id].simulations[sim_id].nb_time_steps = time_step
            logger.info(f"Rank {self.rank}>> sim-id={sim_id} finished sending.")
            self.nb_finished_simulations += 1
            return False

        # apply validation checks
        if field not in self.fields:
            if field != "termination":
                logger.warning(
                    f'Rank {self.rank}>> [Bad] sim-id={sim_id}, field="{field}"'
                )
            return False

        return super()._validate_data(simulation_data)

    @override
    def _write_final_report(self) -> None:

        total_batches = self.comm.allreduce(
            self.dataset.get_sample_number() // self.batch_size,
            op=MPI.SUM
        )
        if self.rank == 0:
            logger.info(f" - Number of global batches: {total_batches}")
        super()._write_final_report()

    def __set_expected_batches_samples_watermark(self) -> None:
        """Computes and sets the expected samples and batches per server process."""

        # standard case where nb_time_steps is given in the config file
        if self.nb_time_steps > 0:
            # ensure the watermark is sufficient
            self.__check_water_mark()

            # account for possible accumulated shift
            self.nb_expected_time_steps = (
                self.nb_clients // self.comm_size
            ) * self.nb_time_steps
            self.nb_expected_batches = (
                self.nb_expected_time_steps // self.batch_size * self.pseudo_epochs
            )

            if (
                self.pseudo_epochs > 1
                and self.buffer_size != self.nb_expected_time_steps
            ):
                logger.warning(
                    "User tried using `pseudo_epochs` with a `buffer_size` smaller than the "
                    "expected number of samples. Setting `buffer_size` to the number of "
                    f"expected time steps ({self.nb_expected_time_steps})."
                )
                self.buffer_size = self.nb_expected_time_steps

            logger.info(
                f"Expecting {self.nb_expected_time_steps} "
                f"samples across {self.nb_expected_batches} batches."
            )
        # when `nb_time_steps` is not known a priori
        else:
            logger.info("Number of expected samples a priori unknown.")
            self.nb_expected_batches = 0

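    # Worked example for the computation above (illustrative numbers): with 8 clients,
    # 2 server processes, nb_time_steps=100, batch_size=10 and pseudo_epochs=1, each
    # server process expects (8 // 2) * 100 = 400 time steps, i.e. 400 // 10 * 1 = 40 batches.
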
    def __check_water_mark(self) -> None:
        """Ensures there are sufficient samples to reach the `per_server_watermark`."""

        total_time_steps = self.nb_time_steps * self.nb_clients
        samples_per_server = total_time_steps // self.comm_size
        if not self.dl_config["per_server_watermark"] <= samples_per_server:
            raise UnsupportedConfiguration(
                "Insufficient samples to reach `per_server_watermark`. "
                "Please increase `nb_time_steps` "
                "or decrease `per_server_watermark`."
            )

    def other_processes_finished(self, batch_idx: int) -> bool:
        """Checks if other server processes have finished emptying their buffers.

        ### Parameters
        - **batch_idx** (`int`): The current batch number being processed.

        ### Returns
        - **`bool`**: Whether all other server processes have emptied their buffers."""

        logger.debug(
            f"Rank {self.rank}>> batch-idx={batch_idx}, "
            f"expected-batches >= {self.nb_expected_batches}"
        )

        # ensure self.__dataset._is_receiving is in sync across all server
        # processes.
        data_available: bool = self._synchronize_data_availability()

        # in case of pseudo-offline training, we want to avoid a
        # server timeout, so we ping the launcher with time_monitor
        if not data_available:
            # at this point the total number of expected samples should be known
            # and used to update the value of self.nb_expected_batches
            if self.nb_expected_batches == 0:
                # per-client number of expected time-steps
                self.nb_time_steps //= self.nb_clients
                self.__set_expected_batches_samples_watermark()
            logger.debug(
                f"Rank {self.rank}>> One of the server processes finished receiving at "
                f"batch_idx={batch_idx}, "
                f"expected-batches >= {self.nb_expected_batches}"
            )

        return not data_available

    def _synchronize_data_availability(self) -> bool:
        """Coordinates the dataset data availability status across all
        server processes. This usually requires a library-specific
        all_reduce function (e.g. `dist.all_reduce` in PyTorch).

        The default behaviour is to check whether the buffer is empty across all MPI ranks.
        If at least one rank finishes, training stops for all ranks."""

        local_status = int(self.dataset.has_data)
        global_status = self.comm.allreduce(local_status, op=MPI.SUM)

        return global_status == self.comm_size

    @override
    def checkpoint_state(self) -> None:
        """Checkpoints the current state of the server."""

        if self.no_fault_tolerance:
            return

        self._save_base_state()
        # serialize the self._buffer.queue and then pickle it
        with open(f"checkpoints/{self.rank}/buffer_state.pkl", "wb") as file_:
            cloudpickle.dump(self._buffer.save_state(), file_)

    @override
    def _restart_from_checkpoint(self) -> None:
        """Restarts the server object from a checkpoint."""

        self._load_base_state()

        if (
            not os.path.exists("checkpoints/model.pt")
            or not os.path.exists(f"checkpoints/{self.rank}/buffer_state.pkl")
        ):
            raise UnsupportedConfiguration(
                f"Rank {self.rank}>> No checkpoints and/or buffer state were found. "
                "Please make sure that fault-tolerance is set."
            )

        logger.info(f"Rank {self.rank}>> Restarting from checkpoint.")
        with open(f"checkpoints/{self.rank}/buffer_state.pkl", "rb") as file_:
            state = cloudpickle.load(file_)
            self._buffer.load_from_state(state)

        # lib-specific loading method (torch vs tf)
        self._load_model_from_checkpoint()

    @abstractmethod
    def process_simulation_data(self, data: SimulationData, config_dict: dict) -> Any:
        """Transforms data while creating batches with `MelissaIterableDataset`.
        See `SimulationData` for usage of the attributes associated with the received data.

        ### Parameters
        - **data** (`SimulationData`): The data message received from the simulation
          (pulled from the buffer).
        - **config_dict** (`Dict[str, Any]`): A dictionary containing configuration settings.

        ### Returns
        - **`Any`**: Transformed data before creating a batch from it."""
        raise NotImplementedError("This method must be implemented in a subclass.")

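    # A minimal sketch of a subclass override (illustrative only: the field name
    # "temperature", the exact attribute layout of `SimulationData`, and the use of
    # NumPy are assumptions, not guarantees made by this base class):
    #
    #     def process_simulation_data(self, data: SimulationData, config_dict: dict) -> Any:
    #         x = np.asarray(data.data["temperature"], dtype=np.float32)
    #         y = np.asarray(data.parameters, dtype=np.float32)
    #         return x, y
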
    def validation(self, batch_idx: int):
        """Predefined validation loop agnostic of frameworks."""

        logger.info(f"Rank {self.rank}>> Running validation at batch_idx={batch_idx}")
        self._on_validation_start(batch_idx)
        for v_batch_idx, v_batch in enumerate(self._valid_dataloader):
            self.validation_step(v_batch, v_batch_idx, batch_idx)
        self._on_validation_end(batch_idx)

    def train(self) -> None:
        """Predefined training loop agnostic of frameworks."""

        try:
            valid_found = self._valid_dataloader is not None
            if not valid_found:
                logger.warning(
                    f"Rank {self.rank}>> [Ignorable] Attribute `valid_dataloader` must be set "
                    "to perform validation."
                )

            self._on_train_start()
            for batch_idx, batch in enumerate(self._train_dataloader):
                batch_idx += self.batch_offset
                logger.debug(f"Rank {self.rank}>> Training on batch-idx={batch_idx}")
                if self.other_processes_finished(batch_idx):
                    logger.info(
                        f"Rank {self.rank}>> At least one other process has finished. "
                        "Breaking training loop."
                    )
                    break

                self._on_batch_start(batch_idx)
                self.training_step(batch, batch_idx)
                self._on_batch_end(batch_idx)
                # TODO: Multi-GPU validation (not a priority).
                # It should be framework-specific and may require overriding
                # this entire if block.
                if (
                    self.rank == 0
                    and valid_found
                    and batch_idx > 0
                    and (batch_idx + 1) % self.nb_batches_update == 0
                ):
                    self.validation(batch_idx)

                self._checkpoint(batch_idx)
            # end training loop
            self._on_train_end()
        except TrainingError as exc:
            raise exc

    @abstractmethod
    def _setup_environment_slurm(self) -> None:
        """Sets up the unique Distributed Data Parallel (DDP) environment using SLURM,
        as per the recommendations from the [Jean-Zay documentation]
        (http://www.idris.fr/eng/jean-zay/gpu/jean-zay-gpu-torch-multi-eng.html)."""
        raise NotImplementedError("This method must be implemented in a subclass.")

    @abstractmethod
    def prepare_training_attributes(self) -> Tuple[Any, Any]:
        """Required to configure the server's `self.model` and `self.optimizer` attributes,
        preparing them for initialization.

        ### Returns
        - `Tuple[Any, Any]`:
            - **model** (`Any`): Instantiated model object.
            - **optimizer** (`Any`): Instantiated optimizer object."""

        raise NotImplementedError(
            "This method must be implemented in a subclass. "
            "Parent classes rely on `self.model` and `self.optimizer`, "
            "which are set from the return values of this method."
        )

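    # A minimal sketch of a torch-based override (illustrative only: `MyNet` and the
    # optimizer choice are assumptions, not part of this base class):
    #
    #     def prepare_training_attributes(self) -> Tuple[Any, Any]:
    #         model = MyNet()
    #         optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    #         return model, optimizer
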
    @abstractmethod
    def checkpoint(self, batch_idx: int, path: str = "checkpoints"):
        """The method called to initiate full tree checkpointing. This is
        specific to the `torch` or `tensorflow` server."""
        raise NotImplementedError("This method must be implemented in a subclass.")

    def _checkpoint(self, batch_idx: int, path: str = "checkpoints"):
        """Checkpointing at a specific interval. The interval defaults to `nb_batches_update`."""
        if batch_idx > 0 and (batch_idx + 1) % self.checkpoint_interval == 0:
            self.checkpoint(batch_idx, path)

    @abstractmethod
    def _load_model_from_checkpoint(self):
        """Library-specific model loading function. This is
        specific to the `torch` or `tensorflow` server."""
        raise NotImplementedError("This method must be implemented in a subclass.")


def rank_zero_only(fn_: Callable) -> Callable:
    """Function that can be used as a decorator to enable a function/method
    being called only on rank 0. Inspired by pytorch_lightning."""

    rank, _ = get_rank_and_num_server_proc()

    @wraps(fn_)
    def wrapped_fn(*args: Any, **kwargs: Any) -> Optional[Any]:
        if rank == 0:
            return fn_(*args, **kwargs)
        return None

    return wrapped_fn
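
# A minimal usage sketch of `rank_zero_only` (illustrative only: `log_summary` is a
# hypothetical helper, not part of this module):
#
#     @rank_zero_only
#     def log_summary(msg: str) -> None:
#         logger.info(msg)
#
#     log_summary("printed on rank 0 only")  # returns None on every other rank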