Server Class

melissa.server.deep_learning.base_dl_server.DeepMelissaServer

Bases: BaseServer, TrainingWorkflowMixin

DeepMelissaServer is designed for studies involving deep learning workflows. It manages data buffering, logging, and configuration necessary for distributed training.

Parameters
  • config_dict (Dict[str, Any]): A dictionary containing configuration settings for initializing the server.
Attributes
  • _bind_simulation_to_server_rank (int): Whether to send all time steps of a simulation to the same server rank. By default, time steps are distributed across ranks in a round-robin fashion.
  • dl_config (Dict[str, Any]): Dictionary containing deep learning-specific configurations.
  • batch_size (int): The size of batches used for training.
  • per_server_watermark (int): Watermark level to determine data buffering thresholds.
  • buffer_size (int): Total size of the buffer to store training data.
  • _batches_watermark_set (bool): Guards against recalculating the expected number of batches once the number of time steps is known.
  • pseudo_epochs (int): Number of pseudo-epochs to replicate epoch-based training in an online setting.
  • current_sample_count (int): Counter for the number of samples received.
  • nb_local_batches (int): Number of batches processed, local to the MPI rank.
  • nb_expected_batches (int): Number of expected batches; deterministic for FIFO and FIRO.
  • nb_batches_update (int): Number of batches after which updates are triggered during training (For example, for validation).
  • checkpoint_interval (int): Number of batches after which checkpointing is triggered. Defaults to nb_batches_update.
  • _model (Any): Model to be trained.
  • _optimizer (Any): Training optimizer.
  • __buffer (BaseQueue): Instance of a buffer object (FIFO, FIRO, or Reservoir) to manage training data.
  • __dataset (Dataset): Dataset interface to provide training data, integrates with the buffer and allows transformations via process_simulation_data.
  • _framework_t (FrameworkType): Framework type for the iterable dataset to be instantiated with make_dataset. It can either be DEFAULT, TORCH, or TENSORFLOW.
  • _tb_logger (TensorboardLogger): Logger to handle logging of metrics for visualization during training.
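
The attributes above are populated from config_dict. As a purely illustrative sketch, a configuration might take the following shape; the key names and nesting are assumptions for this example, not the authoritative Melissa schema.

```python
# Hypothetical config_dict shape, inferred from the attributes listed
# above. Consult the actual Melissa configuration reference for the
# real key names and structure.
config_dict = {
    "dl_config": {
        "batch_size": 64,             # size of training batches
        "per_server_watermark": 500,  # buffered samples required before training
        "buffer_size": 10_000,        # capacity of the training-data buffer
        "pseudo_epochs": 1,           # >1 replicates epoch-style training online
    },
}
```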

__adjust_buffer_size_for_pseudo_epochs()

Adjusts buffer size before buffer instantiation for pseudo-offline training.

__configure_data_collection()

Instantiates the data collection components, i.e. the buffer, dataset, and dataloader. Users must implement process_simulation_data.

start()

The main entrypoint for the server events.

__signal_end_of_reception()

Unsets the reception flag when all data has been received, and notifies the MelissaIterableDataset to stop forming batches.

__set_expected_batches_samples_watermark()

Computes and sets the expected samples and batches per server process.
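
As a rough sketch, the per-rank batch count could be derived as follows. The function name and formula are assumptions for illustration only; the actual implementation also accounts for the buffer policy and watermark.

```python
def expected_batches_per_rank(
    nb_time_steps: int,
    nb_simulations: int,
    nb_server_ranks: int,
    batch_size: int,
    pseudo_epochs: int = 1,
) -> int:
    """Assumed formula: total samples are split evenly across server
    ranks, grouped into full batches, then repeated per pseudo-epoch."""
    samples_per_rank = (nb_time_steps * nb_simulations) // nb_server_ranks
    return (samples_per_rank // batch_size) * pseudo_epochs
```

For example, 10 simulations of 100 time steps spread over 2 server ranks with a batch size of 50 would yield 10 batches per rank per pseudo-epoch.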

__check_water_mark()

Ensures there are sufficient samples to reach the per_server_watermark.

other_processes_finished(batch_idx)

Checks if other server processes have finished emptying their buffers.

Parameters
  • batch_idx (int): The current batch number being processed.
Returns
  • bool: True if all other server processes have emptied their buffers.

process_simulation_data(data, config_dict) abstractmethod

Transforms data while creating batches with MelissaIterableDataset. See SimulationData for usage of attributes associated with the received data.

Parameters
  • data (Union[List[SimulationData], SimulationData]): The data message received from the simulation (pulled from the buffer).
  • config_dict (Dict[str, Any]): A dictionary containing configuration settings.
Returns
  • Any: Transformed data before creating a batch from it.
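
A minimal override might look like the following sketch. The dataclass below is a stand-in for Melissa's SimulationData, and the scaling transform and its config key are hypothetical.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List

@dataclass
class SimulationData:  # stand-in for melissa's SimulationData message
    simulation_id: int
    time_step: int
    data: List[float] = field(default_factory=list)

def process_simulation_data(data: SimulationData, config_dict: Dict[str, Any]) -> Any:
    # Hypothetical transform: scale each received value before batching.
    scale = config_dict.get("dl_config", {}).get("scale", 1.0)
    return [x * scale for x in data.data]
```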

get_reduced_validation_loss(valid_loss)

Returns reduced validation loss across all server ranks.

validation(batch_idx)

Predefined validation loop, agnostic of frameworks. It can be overridden by the user.

train()

Predefined training loop agnostic of frameworks.
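
Conceptually, the predefined loop drives the TrainingWorkflowMixin hooks in the order sketched below. This is an assumed simplification based on the hook names documented further down; the real loop also interleaves buffering, validation scheduling, and checkpointing.

```python
def run_training(server, batches):
    # Assumed hook ordering, for illustration only.
    server.on_train_start()
    for batch_idx, batch in enumerate(batches):
        server.on_batch_start(batch_idx)
        server.training_step(batch, batch_idx)
        server.on_batch_end(batch_idx)
    server.on_train_end()
```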

prepare_training_attributes() abstractmethod

Required to configure the server's model and optimizer attributes by instantiating and returning them.

Returns
  • Tuple[Any, Any]:
    • model (Any): Instantiated model object.
    • optimizer (Any): Instantiated optimizer object.
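
A hypothetical torch-based implementation might look like this; the architecture and hyperparameters are illustrative, not prescribed by Melissa.

```python
import torch
from torch import nn

def prepare_training_attributes():
    # Illustrative model and optimizer; any framework-compatible pair works.
    model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 1))
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    return model, optimizer
```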

checkpoint(batch_idx) abstractmethod

The method called to initiate full-tree checkpointing. The implementation is specific to the torch or tensorflow server.

melissa.server.deep_learning.train_workflow.TrainingWorkflowMixin

Bases: ABC

Provides a structure for overriding training, validation, and hook methods.

training_step(batch, batch_idx, **kwargs) abstractmethod

Defines the logic for a single training step.

Parameters
  • batch (Any): A single batch of data.
  • batch_idx (int): The index of the batch.
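
A sketch of a torch-style training_step; here `self` is assumed to be a server subclass exposing model and optimizer, and the loss function is illustrative.

```python
import torch
import torch.nn.functional as F

def training_step(self, batch, batch_idx, **kwargs):
    x, y = batch
    self.optimizer.zero_grad()
    loss = F.mse_loss(self.model(x), y)  # illustrative loss choice
    loss.backward()
    self.optimizer.step()
    return {"loss": loss.item()}
```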

validation_step(batch, valid_batch_idx, batch_idx, **kwargs)

Defines the logic for a single validation step.

Parameters
  • batch (Any): A single batch of validation data.
  • valid_batch_idx (int): The index of the validation batch.
  • batch_idx (int): The index of the current training batch.
Returns
  • Dict[str, Any]: Output from the validation step.
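
A corresponding torch-style sketch, with the same assumptions as above (a `self` exposing a model, an illustrative loss):

```python
import torch
import torch.nn.functional as F

def validation_step(self, batch, valid_batch_idx, batch_idx, **kwargs):
    x, y = batch
    with torch.no_grad():  # no gradient tracking during validation
        loss = F.mse_loss(self.model(x), y)
    return {"val_loss": loss.item()}
```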

on_train_start()

Hook called at the start of training.

on_train_end()

Hook called at the end of training.

on_batch_start(batch_idx)

Hook called at the start of batch iteration.

on_batch_end(batch_idx)

Hook called at the end of batch iteration.

on_validation_start(batch_idx)

Hook called at the start of validation.

on_validation_end(batch_idx)

Hook called at the end of validation.