Server Class

melissa.server.deep_learning.base_dl_server.DeepMelissaServer

Bases: BaseServer, ExperimentalMonitoringMixin, TrainingWorkflowMixin

DeepMelissaServer is designed for studies involving deep learning workflows. It manages data buffering, logging, and configuration necessary for distributed training.

Parameters
  • config_dict (Dict[str, Any]): A dictionary containing configuration settings for initializing the server.
Attributes
  • dl_config (Dict[str, Any]): Dictionary containing deep learning-specific configurations.
  • batch_size (int): The size of batches used for training.
  • per_server_watermark (int): Watermark level to determine data buffering thresholds.
  • buffer_size (int): Total size of the buffer to store training data.
  • pseudo_epochs (int): Number of pseudo-epochs to replicate epoch-based training in an online setting.
  • current_sample_count (int): Counter for the number of samples received.
  • nb_expected_batches (int): Number of expected batches; deterministic for FIFO and FIRO.
  • nb_batches_update (int): Number of batches after which updates are triggered during training (For example, for validation).
  • checkpoint_interval (int): Number of batches after which checkpointing is triggered. Defaults to nb_batches_update.
  • setup_slurm_ddp (bool): Boolean flag to configure SLURM-based distributed data parallel training.
  • _model (Any): Model to be trained.
  • _optimizer (Any): Training optimizer.
  • __buffer (BaseQueue): Instance of a buffer object (FIFO, FIRO, or Reservoir) to manage training data.
  • __dataset (Dataset): Dataset interface to provide training data, integrates with the buffer and allows transformations via process_simulation_data.
  • _framework_t (FrameworkType): Framework type for the iterable dataset to be instantiated with make_dataset. It can either be DEFAULT, TORCH, or TENSORFLOW.
  • _tb_logger (TensorboardLogger): Logger to handle logging of metrics for visualization during training.

__configure_data_collection()

Instantiates the data-collection components, i.e., the buffer, dataset, and dataloader. Users must implement process_simulation_data.

_check_group_size()

Checks if the group size was correctly set.

__loop_pings()

Maintains communication with the launcher to ensure it does not assume the server has become unresponsive.

__start_pinger_thread()

Starts the pinger thread and sets the flag.

__stop_pinger_thread()

Stops the pinger thread and unsets the flag.

start()

The main entrypoint for the server events.

_server_online()

Initiates data collection and dispatches the custom methods that act on the collected data.

_server_finalize(exit_=0)

Finalizes the server operations.

Parameters
  • exit_ (int, optional): The exit status code indicating the outcome of the server's operations. Defaults to 0, which signifies a successful termination.

__signal_end_of_reception()

Unsets the reception flag when all data has been received, and notifies MelissaIterableDataset to stop batch formation.

_receive()

Handles data coming from the server object.

_process_partial_data_reception(simulation, simulation_data)

Partial data must be assembled for DeepMelissaServer; this method therefore performs no action.

_validate_data(simulation_data)

Validates the simulation data.

__set_expected_batches_samples_watermark()

Computes and sets the expected samples and batches per server process.

__check_water_mark()

Ensures there are sufficient samples to reach the per_server_watermark.
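These two private methods can be illustrated with a plain-Python sketch. The helper names and the exact formula are assumptions for illustration; the actual implementation may split samples across processes differently:

```python
def expected_batches(total_samples: int, batch_size: int, nb_server_procs: int) -> int:
    # Hypothetical formula: samples are split evenly across server
    # processes, and each process forms full batches from its share.
    per_proc_samples = total_samples // nb_server_procs
    return per_proc_samples // batch_size

def watermark_satisfied(buffer_length: int, per_server_watermark: int) -> bool:
    # Batch formation only starts once the buffer holds at least
    # `per_server_watermark` samples.
    return buffer_length >= per_server_watermark
```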

other_processes_finished(batch_idx)

Checks if other server processes have finished emptying their buffers.

Parameters
  • batch_idx (int): The current batch number being processed.
Returns
  • bool: True if all other server processes have emptied their buffers, False otherwise.

_synchronize_data_availability()

Coordinates the dataset's data-availability status across all server processes. This usually requires a library-specific all_reduce function (e.g. dist.all_reduce in pytorch).

The default behaviour is to check whether the buffer is empty across all MPI ranks. If at least one rank finishes, training stops for all ranks.
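The default behaviour can be illustrated with a pure-Python stand-in for the reduction. In PyTorch this would typically be dist.all_reduce with ReduceOp.MIN over a flag tensor; the helper below is an illustrative assumption, not the library's API:

```python
def global_data_available(per_rank_has_data):
    # Stand-in for an all_reduce MIN across ranks: every rank
    # contributes whether it still has buffered data, and training
    # continues only while ALL ranks report data (min == True).
    return min(per_rank_has_data)
```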

checkpoint_state()

Checkpoints the current state of the server.

_restart_from_checkpoint()

Restarts the server object from a checkpoint.

process_simulation_data(data, config_dict) abstractmethod

Transforms data while creating batches with MelissaIterableDataset. See SimulationData for usage of attributes associated with the received data.

Parameters
  • data (SimulationData): The data message received from the simulation (pulled from the buffer).
  • config_dict (Dict[str, Any]): A dictionary containing configuration settings.
Returns
  • Any: Transformed data before creating a batch from it.
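A minimal sketch of an implementation, assuming hypothetical SimulationData fields (`parameters` and `data`; the real attribute names may differ, see the SimulationData documentation):

```python
from dataclasses import dataclass
from typing import Any, Dict, List, Tuple

@dataclass
class FakeSimulationData:
    # Hypothetical stand-in for melissa's SimulationData.
    simulation_id: int
    time_step: int
    parameters: List[float]  # simulation inputs
    data: List[float]        # received field values

def process_simulation_data(msg: FakeSimulationData,
                            config_dict: Dict[str, Any]) -> Tuple[List[float], List[float]]:
    # Example transform: pair the simulation inputs (x) with the
    # received field (y) so the dataset can batch (x, y) samples.
    x = [float(p) for p in msg.parameters]
    y = [float(v) for v in msg.data]
    return x, y
```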

validation(batch_idx)

Predefined, framework-agnostic validation loop.

train()

Predefined, framework-agnostic training loop.

_setup_environment_slurm() abstractmethod

Sets up the Distributed Data Parallel (DDP) environment using SLURM, following the recommendations of the [Jean-Zay documentation](http://www.idris.fr/eng/jean-zay/gpu/jean-zay-gpu-torch-multi-eng.html).

prepare_training_attributes() abstractmethod

Required to configure the server's self.model and self.optimizer attributes before training begins.

Returns
  • Tuple[Any, Any]:
    • model (Any): Instantiated model object.
    • optimizer (Any): Instantiated optimizer object.

checkpoint(batch_idx, path='checkpoints') abstractmethod

The method called to initiate full-tree checkpointing. The implementation is specific to the torch or tensorflow server.

_checkpoint(batch_idx, path='checkpoints')

Checkpoints at a specific interval. The interval defaults to nb_batches_update.
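The interval check amounts to something like the following (hypothetical helper name; the actual method also performs the framework-specific checkpoint call):

```python
def should_checkpoint(batch_idx: int, checkpoint_interval: int) -> bool:
    # Trigger a checkpoint every `checkpoint_interval` batches,
    # skipping batch 0 so training does not checkpoint immediately.
    return batch_idx > 0 and batch_idx % checkpoint_interval == 0
```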

_load_model_from_checkpoint() abstractmethod

Library-specific model-loading function. The implementation is specific to the torch or tensorflow server.

melissa.server.deep_learning.train_workflow.TrainingWorkflowMixin

Bases: ABC

Provides a structure for overriding training, validation, and hook methods.

training_step(batch, batch_idx, **kwargs) abstractmethod

Defines the logic for a single training step.

Parameters
  • batch (Any): A single batch of data.
  • batch_idx (int): The index of the batch.
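A framework-agnostic sketch of what a training_step override might compute. The explicit `model` argument and the mean-squared-error loss are illustrative assumptions; in a real subclass the model would be the server's self._model, and a torch or tensorflow implementation would also run the optimizer:

```python
def training_step(batch, batch_idx, model, **kwargs):
    # Compute a mean-squared-error loss over (x, y) pairs in the batch
    # and return it in a dict, keyed for logging.
    losses = [(model(x) - y) ** 2 for x, y in batch]
    loss = sum(losses) / len(losses)
    return {"loss": loss, "batch_idx": batch_idx}
```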

validation_step(batch, valid_batch_idx, batch_idx, **kwargs)

Defines the logic for a single validation step.

Parameters
  • batch (Any): A single batch of validation data.
  • valid_batch_idx (int): The index of the validation batch.
  • batch_idx (int): The index of the batch.
Returns
  • Dict[str, Any]: Output from the validation step.

on_train_start()

Hook called at the start of training.

on_train_end()

Hook called at the end of training.

on_batch_start(batch_idx)

Hook called at the start of batch iteration.

on_batch_end(batch_idx)

Hook called at the end of batch iteration.

on_validation_start(batch_idx)

Hook called at the start of validation.

on_validation_end(batch_idx)

Hook called at the end of validation.

Training-specific Exceptions

melissa.server.deep_learning.base_dl_server.TrainingError

Bases: Exception

Errors from the training loop.

melissa.server.deep_learning.base_dl_server.UnsupportedConfiguration

Bases: Exception

Errors from the configuration file.

Others

melissa.server.deep_learning.base_dl_server.rank_zero_only()

Decorator that restricts a function or method so it is called only on rank 0. Inspired by pytorch_lightning.
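A minimal sketch of such a decorator, assuming the process rank is exposed through a RANK environment variable (the real implementation may determine the rank differently):

```python
import functools
import os

def rank_zero_only(fn):
    # Wrap `fn` so it only executes when this process is rank 0;
    # on every other rank the call is a no-op returning None.
    @functools.wraps(fn)
    def wrapped(*args, **kwargs):
        if int(os.environ.get("RANK", "0")) == 0:
            return fn(*args, **kwargs)
        return None
    return wrapped
```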