Server Class¶
melissa.server.deep_learning.base_dl_server.DeepMelissaServer¶
Bases: BaseServer, TrainingWorkflowMixin
DeepMelissaServer is designed for studies involving deep learning workflows.
It manages data buffering, logging, and configuration necessary for distributed training.
Parameters¶
- config_dict (Dict[str, Any]): A dictionary containing configuration settings for initializing the server.
Attributes¶
- _bind_simulation_to_server_rank (int): Whether to bind sending all timesteps of a simulation to the same server rank. By default, timesteps are sent in a round-robin fashion.
- dl_config (Dict[str, Any]): Dictionary containing deep learning-specific configurations.
- batch_size (int): The size of batches used for training.
- per_server_watermark (int): Watermark level to determine data buffering thresholds.
- buffer_size (int): Total size of the buffer to store training data.
- _batches_watermark_set (bool): Prevents recalculation of the expected number of batches once the timesteps are known.
- pseudo_epochs (int): Number of pseudo-epochs to replicate epoch-based training in an online setting.
- current_sample_count (int): Counter for the number of samples received.
- self.nb_local_batches (int): Number of batches processed. Local to MPI rank.
- nb_expected_batches (int): Number of expected batches; deterministic for FIFO and FIRO.
- nb_batches_update (int): Number of batches after which updates are triggered during training (for example, validation).
- checkpoint_interval (int): Number of batches after which checkpointing is triggered. Defaults to nb_batches_update.
- _model (Any): Model to be trained.
- _optimizer (Any): Training optimizer.
- __buffer (BaseQueue): Instance of a buffer object (FIFO, FIRO, or Reservoir) to manage training data.
- __dataset (Dataset): Dataset interface that provides training data, integrates with the buffer, and allows transformations via process_simulation_data.
- _framework_t (FrameworkType): Framework type for the iterable dataset to be instantiated with make_dataset. It can be either DEFAULT, TORCH, or TENSORFLOW.
- _tb_logger (TensorboardLogger): Logger to handle logging of metrics for visualization during training.
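The two routing modes mentioned for _bind_simulation_to_server_rank can be pictured with a minimal sketch. The helper name `target_rank` and both modulo formulas are assumptions made for illustration only; the source does not state how Melissa actually computes the destination rank.

```python
def target_rank(sim_id: int, time_step: int, nb_ranks: int, bind: bool) -> int:
    """Hypothetical helper: pick the server rank that receives one timestep."""
    if bind:
        # Bound mode: every timestep of a simulation goes to the same rank.
        return sim_id % nb_ranks
    # Assumed round-robin mode: consecutive timesteps rotate over the ranks.
    return (sim_id + time_step) % nb_ranks


# Simulation 3 sending four timesteps to a server with two ranks.
bound = [target_rank(3, t, nb_ranks=2, bind=True) for t in range(4)]
rotating = [target_rank(3, t, nb_ranks=2, bind=False) for t in range(4)]
```

In this sketch `bound` stays on a single rank, while `rotating` alternates between the two ranks.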
__adjust_buffer_size_for_pseudo_epochs()¶
Adjusts buffer size before buffer instantiation for pseudo-offline training.
__configure_data_collection()¶
Instantiates the data collection components, i.e., the buffer, dataset, and dataloader.
Users must implement process_simulation_data.
start()¶
The main entrypoint for the server events.
__signal_end_of_reception()¶
Unsets the reception when all data has been received,
and notifies MelissaIterableDataset to stop the batch formation.
__set_expected_batches_samples_watermark()¶
Computes and sets the expected samples and batches per server process.
__check_water_mark()¶
Ensures there are sufficient samples to reach the per_server_watermark.
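As a rough illustration of watermark gating, the sketch below withholds batches until a minimum number of samples has been buffered. The class `WatermarkedFifo` is hypothetical and is not Melissa's buffer implementation; in particular, the choice to latch the watermark once it has been reached is an assumption.

```python
from collections import deque
from typing import Deque, List, Optional


class WatermarkedFifo:
    """Hypothetical FIFO buffer that withholds batches below a watermark."""

    def __init__(self, watermark: int, batch_size: int) -> None:
        self.watermark = watermark
        self.batch_size = batch_size
        self._items: Deque[int] = deque()
        self._watermark_reached = False

    def put(self, sample: int) -> None:
        self._items.append(sample)
        # Assumption: once the watermark is reached, it stays reached.
        if len(self._items) >= self.watermark:
            self._watermark_reached = True

    def next_batch(self) -> Optional[List[int]]:
        # No batch until the watermark was reached and a full batch is buffered.
        if not self._watermark_reached or len(self._items) < self.batch_size:
            return None
        return [self._items.popleft() for _ in range(self.batch_size)]


buf = WatermarkedFifo(watermark=4, batch_size=2)
for sample in range(3):
    buf.put(sample)
early = buf.next_batch()   # still below the watermark
buf.put(3)
first = buf.next_batch()   # watermark reached, oldest samples leave first
```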
other_processes_finished(batch_idx)¶
process_simulation_data(data, config_dict) abstractmethod¶
Transforms data while creating batches with MelissaIterableDataset.
See SimulationData for usage of attributes associated with the received data.
Parameters¶
- data (Union[List[SimulationData], SimulationData]): The data message received from the simulation (pulled from the buffer).
- config_dict (Dict[str, Any]): A dictionary containing configuration settings.
Returns¶
Any: Transformed data before creating a batch from it.
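A minimal sketch of what such a transformation could look like is given here as a free function, although in a server subclass it would be a method. `FakeSimulationData`, its fields, and the `"scale"` configuration key are all invented for illustration; the real SimulationData attributes should be checked in its own documentation.

```python
from dataclasses import dataclass
from typing import Any, Dict, List, Tuple, Union


@dataclass
class FakeSimulationData:
    """Hypothetical stand-in; the real SimulationData attributes may differ."""
    parameters: List[float]
    time_step: int
    field: List[float]


def process_simulation_data(
    data: Union[List[FakeSimulationData], FakeSimulationData],
    config_dict: Dict[str, Any],
) -> Any:
    """Turn buffered message(s) into (inputs, targets) pairs."""
    scale = config_dict.get("scale", 1.0)  # "scale" is an invented key

    def to_pair(msg: FakeSimulationData) -> Tuple[List[float], List[float]]:
        # Inputs: simulation parameters followed by the timestep.
        inputs = list(msg.parameters) + [float(msg.time_step)]
        # Targets: the field values, rescaled.
        targets = [value * scale for value in msg.field]
        return inputs, targets

    # The signature allows either one message or a list of messages.
    if isinstance(data, list):
        return [to_pair(msg) for msg in data]
    return to_pair(data)


sample = FakeSimulationData(parameters=[0.5, 2.0], time_step=3, field=[10.0, 20.0])
pair = process_simulation_data(sample, {"scale": 0.5})
pairs = process_simulation_data([sample, sample], {})
```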
get_reduced_validation_loss(valid_loss)¶
Returns reduced validation loss across all server ranks.
validation(batch_idx)¶
Predefined, framework-agnostic validation loop. It can be overridden by the user.
train()¶
Predefined training loop agnostic of frameworks.
prepare_training_attributes() abstractmethod¶
Required to configure server's self.model and self.optimizer attributes,
preparing them for initialization.
Returns¶
Tuple[Any, Any]:
- model (Any): Instantiated model object.
- optimizer (Any): Instantiated optimizer object.
checkpoint(batch_idx) abstractmethod¶
Called to initiate full-tree checkpointing. The implementation is
specific to the Torch or TensorFlow server.
melissa.server.deep_learning.train_workflow.TrainingWorkflowMixin¶
Bases: ABC
Provides a structure for overriding training, validation, and hook methods.
training_step(batch, batch_idx, **kwargs) abstractmethod¶
Defines the logic for a single training step.
Parameters¶
- batch (Any): A single batch of data.
- batch_idx (int): The index of the batch.
validation_step(batch, valid_batch_idx, batch_idx, **kwargs)¶
on_train_start()¶
Hook called at the start of training.
on_train_end()¶
Hook called at the end of training.
on_batch_start(batch_idx)¶
Hook called at the start of batch iteration.
on_batch_end(batch_idx)¶
Hook called at the end of batch iteration.
on_validation_start(batch_idx)¶
Hook called at the start of validation.
on_validation_end(batch_idx)¶
Hook called at the end of validation.
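To make the hook names above concrete, the following sketch records the order in which a simple loop might call them. `RecordingWorkflow` and `run_training` are hypothetical stand-ins and do not import Melissa; the exact call order in the predefined train() loop, including whether validation runs before or after on_batch_end, is an assumption here.

```python
from typing import Any, Iterable, List


class RecordingWorkflow:
    """Hypothetical stand-in for a TrainingWorkflowMixin subclass."""

    def __init__(self, nb_batches_update: int) -> None:
        self.nb_batches_update = nb_batches_update
        self.calls: List[str] = []

    def on_train_start(self) -> None:
        self.calls.append("train_start")

    def on_train_end(self) -> None:
        self.calls.append("train_end")

    def on_batch_start(self, batch_idx: int) -> None:
        self.calls.append(f"batch_start:{batch_idx}")

    def on_batch_end(self, batch_idx: int) -> None:
        self.calls.append(f"batch_end:{batch_idx}")

    def on_validation_start(self, batch_idx: int) -> None:
        self.calls.append(f"validation_start:{batch_idx}")

    def on_validation_end(self, batch_idx: int) -> None:
        self.calls.append(f"validation_end:{batch_idx}")

    def training_step(self, batch: Any, batch_idx: int, **kwargs: Any) -> None:
        self.calls.append(f"step:{batch_idx}")


def run_training(workflow: RecordingWorkflow, batches: Iterable[Any]) -> None:
    """Assumed loop shape: hooks wrap each step, validation every N batches."""
    workflow.on_train_start()
    for batch_idx, batch in enumerate(batches, start=1):
        workflow.on_batch_start(batch_idx)
        workflow.training_step(batch, batch_idx)
        workflow.on_batch_end(batch_idx)
        if batch_idx % workflow.nb_batches_update == 0:
            workflow.on_validation_start(batch_idx)
            workflow.on_validation_end(batch_idx)
    workflow.on_train_end()


workflow = RecordingWorkflow(nb_batches_update=2)
run_training(workflow, batches=["a", "b"])
```

After the run, `workflow.calls` lists the hooks in the order they fired, which can help when deciding where to place logging or scheduler updates in a real subclass.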