Server Class¶
melissa.server.deep_learning.base_dl_server.DeepMelissaServer¶
Bases: BaseServer, ExperimentalMonitoringMixin, TrainingWorkflowMixin
DeepMelissaServer is designed for studies involving deep learning workflows.
It manages data buffering, logging, and configuration necessary for distributed training.
Parameters¶
- config_dict (Dict[str, Any]): A dictionary containing configuration settings for initializing the server.
Attributes¶
- dl_config (Dict[str, Any]): Dictionary containing deep learning-specific configurations.
- batch_size (int): The size of batches used for training.
- per_server_watermark (int): Watermark level to determine data buffering thresholds.
- buffer_size (int): Total size of the buffer to store training data.
- pseudo_epochs (int): Number of pseudo-epochs to replicate epoch-based training in an online setting.
- current_sample_count (int): Counter for the number of samples received.
- nb_expected_batches (int): Number of expected batches; deterministic for FIFO and FIRO.
- nb_batches_update (int): Number of batches after which updates (for example, validation) are triggered during training.
- checkpoint_interval (int): Number of batches after which checkpointing is triggered. Defaults to nb_batches_update.
- setup_slurm_ddp (bool): Boolean flag to configure SLURM-based distributed data parallel training.
- _model (Any): Model to be trained.
- _optimizer (Any): Training optimizer.
- __buffer (BaseQueue): Instance of a buffer object (FIFO, FIRO, or Reservoir) to manage training data.
- __dataset (Dataset): Dataset interface that provides training data; integrates with the buffer and allows transformations via process_simulation_data.
- _framework_t (FrameworkType): Framework type for the iterable dataset to be instantiated with make_dataset. It can be DEFAULT, TORCH, or TENSORFLOW.
- _tb_logger (TensorboardLogger): Logger to handle logging of metrics for visualization during training.
__configure_data_collection()
¶
Instantiates the data collection components, i.e. the buffer, dataset, and dataloader.
Users must implement process_simulation_data.
_check_group_size()
¶
Checks if the group size was correctly set.
__loop_pings()
¶
Maintains communication with the launcher to ensure it does not assume the server has become unresponsive.
__start_pinger_thread()
¶
Starts the pinger thread and sets the flag.
__stop_pinger_thread()
¶
Stops the pinger thread and unsets the flag.
start()
¶
The main entrypoint for the server events.
_server_online()
¶
Initiates data collection and dispatches the custom methods for acting on collected data.
_server_finalize(exit_=0)
¶
Finalizes the server operations.
Parameters¶
- exit_ (int, optional): The exit status code indicating the outcome of the server's operations. Defaults to 0, which signifies a successful termination.
__signal_end_of_reception()
¶
Unsets the reception flag when all data has been received,
and notifies MelissaIterableDataset to stop the batch formation.
_receive()
¶
Handles data coming from the server object.
_process_partial_data_reception(simulation, simulation_data)
¶
Partial data has to be assembled for DeepMelissaServer.
This method does not perform anything.
_validate_data(simulation_data)
¶
Validates the simulation data.
__set_expected_batches_samples_watermark()
¶
Computes and sets the expected samples and batches per server process.
__check_water_mark()
¶
Ensures there are sufficient samples to reach the per_server_watermark.
other_processes_finished(batch_idx)
¶
_synchronize_data_availability()
¶
Coordinates the dataset's data availability status across all server processes. This usually requires a library-specific all_reduce function (e.g. dist.all_reduce in PyTorch).
The default behaviour is to check whether the buffer is empty across all MPI ranks: if at least one rank finishes, training stops on all ranks.
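The default logic can be sketched without any framework. The helper names below are invented stand-ins, not melissa's API; a real implementation would use a collective such as torch.distributed.all_reduce with ReduceOp.MAX:

```python
# Illustrative sketch: each rank contributes a flag saying whether its
# buffer is empty, and a max-style all-reduce propagates the stop decision.
def all_reduce_max(values):
    # Stand-in for a collective all-reduce: every rank ends up with the
    # maximum of the values contributed by all ranks.
    return max(values)

def synchronize_data_availability(per_rank_buffer_empty):
    # Each rank contributes 1 if its buffer is empty, 0 otherwise.
    flags = [1 if empty else 0 for empty in per_rank_buffer_empty]
    # If at least one rank is out of data, the reduced value is 1 and
    # every rank stops training together.
    return all_reduce_max(flags) == 0  # True -> keep training

print(synchronize_data_availability([False, False, False]))  # True
print(synchronize_data_availability([False, True, False]))   # False
```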
checkpoint_state()
¶
Checkpoint the current state of the server.
_restart_from_checkpoint()
¶
Restarts the server object from a checkpoint.
process_simulation_data(data, config_dict)
abstractmethod
¶
Transforms data while creating batches with MelissaIterableDataset.
See SimulationData for usage of attributes associated with the received data.
Parameters¶
- data (SimulationData): The data message received from the simulation (pulled from the buffer).
- config_dict (Dict[str, Any]): A dictionary containing configuration settings.
Returns¶
Any: Transformed data before creating a batch from it.
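A minimal override might look like the following sketch. The `.data` attribute on SimulationData and the `scale` config key are assumptions made for illustration; consult SimulationData for the real attribute names:

```python
from dataclasses import dataclass
from typing import Any, Dict, List

@dataclass
class SimulationData:
    # Minimal stand-in for melissa's SimulationData; the real class
    # carries more attributes (see its documentation).
    data: List[float]

class MyServer:
    # Stand-in for a DeepMelissaServer subclass overriding the hook.
    def process_simulation_data(self, data: SimulationData,
                                config_dict: Dict[str, Any]) -> Any:
        # Example transformation: rescale raw values before batching.
        scale = config_dict.get("dl_config", {}).get("scale", 1.0)
        return [v / scale for v in data.data]

sample = SimulationData(data=[2.0, 4.0, 6.0])
print(MyServer().process_simulation_data(sample, {"dl_config": {"scale": 2.0}}))
# [1.0, 2.0, 3.0]
```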
validation(batch_idx)
¶
Predefined validation loop agnostic of frameworks.
train()
¶
Predefined training loop agnostic of frameworks.
_setup_environment_slurm()
abstractmethod
¶
Sets up the Distributed Data Parallel (DDP) environment using SLURM, following the recommendations from the [Jean-Zay Documentation](http://www.idris.fr/eng/jean-zay/gpu/jean-zay-gpu-torch-multi-eng.html).
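A hedged sketch of what such a setup can look like, assuming the standard SLURM_PROCID/SLURM_NTASKS/SLURM_LOCALID variables and the RANK/WORLD_SIZE/MASTER_ADDR/MASTER_PORT variables read by torch.distributed; a production implementation would resolve the master hostname with `scontrol show hostnames`:

```python
import os

def setup_environment_slurm(master_port: str = "29500"):
    # Derive the process topology from the SLURM environment, falling
    # back to single-process defaults outside of a SLURM job.
    rank = int(os.environ.get("SLURM_PROCID", "0"))
    world_size = int(os.environ.get("SLURM_NTASKS", "1"))
    local_rank = int(os.environ.get("SLURM_LOCALID", "0"))
    # Naive master resolution: first entry of the node list. Compact
    # SLURM node ranges (e.g. "node[01-04]") need scontrol in practice.
    master = os.environ.get("SLURM_NODELIST", "localhost").split(",")[0]
    # Export the variables torch.distributed reads at init time.
    os.environ["RANK"] = str(rank)
    os.environ["WORLD_SIZE"] = str(world_size)
    os.environ["LOCAL_RANK"] = str(local_rank)
    os.environ["MASTER_ADDR"] = master
    os.environ["MASTER_PORT"] = master_port
    return rank, world_size, local_rank
```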
prepare_training_attributes()
abstractmethod
¶
Required to configure the server's self.model and self.optimizer attributes, preparing them for initialization.
Returns¶
Tuple[Any, Any]:
- model (Any): Instantiated model object.
- optimizer (Any): Instantiated optimizer object.
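As a non-abstract illustration, any (model, optimizer) pair can be returned. The tiny stand-in classes below are assumptions made for the sketch; a torch server would build an nn.Module and a torch.optim optimizer instead:

```python
from typing import Any, Tuple

class TinyModel:
    # Stand-in model: a single scalar weight.
    def __init__(self):
        self.weight = 0.0

class SGD:
    # Stand-in optimizer applying plain gradient descent to TinyModel.
    def __init__(self, model: TinyModel, lr: float):
        self.model, self.lr = model, lr
    def step(self, grad: float):
        self.model.weight -= self.lr * grad

def prepare_training_attributes() -> Tuple[Any, Any]:
    model = TinyModel()
    optimizer = SGD(model, lr=0.1)
    return model, optimizer

model, optimizer = prepare_training_attributes()
optimizer.step(grad=1.0)
print(model.weight)  # -0.1
```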
checkpoint(batch_idx, path='checkpoints')
abstractmethod
¶
The method called to initiate full-tree checkpointing. This is specific to the torch or tensorflow server.
_checkpoint(batch_idx, path='checkpoints')
¶
Checkpoints at a specific interval. The interval defaults to nb_batches_update.
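The trigger condition amounts to a modulo check on the batch index; a minimal sketch (skipping batch 0 is an assumption of this illustration):

```python
def should_checkpoint(batch_idx: int, checkpoint_interval: int) -> bool:
    # Fire every `checkpoint_interval` batches; batch 0 is skipped here.
    return batch_idx > 0 and batch_idx % checkpoint_interval == 0

print([i for i in range(10) if should_checkpoint(i, 4)])  # [4, 8]
```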
_load_model_from_checkpoint()
abstractmethod
¶
Library-specific model loading function. This is specific to the torch or tensorflow server.
melissa.server.deep_learning.train_workflow.TrainingWorkflowMixin¶
Bases: ABC
Provides a structure for overriding training, validation, and hook methods.
training_step(batch, batch_idx, **kwargs)
abstractmethod
¶
Defines the logic for a single training step.
Parameters¶
- batch (Any): A single batch of data.
- batch_idx (int): The index of the batch.
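A toy override, hedged: the loss computation and the returned dictionary shape are illustrative assumptions, not a structure melissa requires:

```python
from typing import Any, Dict, List, Tuple

def training_step(batch: Tuple[List[float], List[float]],
                  batch_idx: int, **kwargs) -> Dict[str, Any]:
    inputs, targets = batch
    # Toy "model": identity predictions scored with mean squared error.
    loss = sum((x - y) ** 2 for x, y in zip(inputs, targets)) / len(inputs)
    return {"loss": loss, "batch_idx": batch_idx}

print(training_step(([1.0, 2.0], [1.0, 4.0]), 0))
# {'loss': 2.0, 'batch_idx': 0}
```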
validation_step(batch, valid_batch_idx, batch_idx, **kwargs)
¶
on_train_start()
¶
Hook called at the start of training.
on_train_end()
¶
Hook called at the end of training.
on_batch_start(batch_idx)
¶
Hook called at the start of batch iteration.
on_batch_end(batch_idx)
¶
Hook called at the end of batch iteration.
on_validation_start(batch_idx)
¶
Hook called at the start of validation.
on_validation_end(batch_idx)
¶
Hook called at the end of validation.
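The hooks above bracket the training loop roughly as sketched below; the exact call sites inside melissa's predefined train() may differ, so this is an illustration of the ordering implied by the hook names:

```python
class LoggingWorkflow:
    # Records hook invocations so the call order is visible.
    def __init__(self):
        self.calls = []
    def on_train_start(self): self.calls.append("train_start")
    def on_train_end(self): self.calls.append("train_end")
    def on_batch_start(self, batch_idx): self.calls.append(f"batch_start:{batch_idx}")
    def on_batch_end(self, batch_idx): self.calls.append(f"batch_end:{batch_idx}")
    def training_step(self, batch, batch_idx): self.calls.append(f"step:{batch_idx}")

def run_training(workflow, batches):
    # Fire the hooks around each batch, mirroring a generic train loop.
    workflow.on_train_start()
    for batch_idx, batch in enumerate(batches):
        workflow.on_batch_start(batch_idx)
        workflow.training_step(batch, batch_idx)
        workflow.on_batch_end(batch_idx)
    workflow.on_train_end()

wf = LoggingWorkflow()
run_training(wf, [None, None])
print(wf.calls[0], wf.calls[-1])  # train_start train_end
```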