Coverage for melissa/server/deep_learning/train_workflow.py: 74%
27 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-07-09 14:19 +0200
1import logging
2from typing import Any
3from abc import ABC, abstractmethod
# Module logger. It is looked up by the fixed name "melissa" rather than
# __name__, so it is the same logger object any other caller of
# logging.getLogger("melissa") receives.
logger: logging.Logger = logging.getLogger(name="melissa")
class TrainingWorkflowMixin(ABC):
    """Provides a structure for overriding training, validation, and hook methods.

    Subclasses must implement `training_step`. Every other user-facing method
    (`validation_step` and the `on_*` hooks) is a no-op by default and may be
    overridden as needed.

    The underscore-prefixed `_on_*` methods are the entry points intended for
    the training loop. Each one currently just forwards to its user-defined
    `on_*` counterpart. That leaves a single place to add fixed server-side
    work that must run before the user hook.
    """

    # ==================================================================================================
    # USER-DEFINED METHODS
    # ==================================================================================================

    @abstractmethod
    def training_step(self, batch: Any, batch_idx: int, **kwargs) -> None:
        """Defines the logic for a single training step.

        Must be implemented by every concrete subclass.

        ### Parameters
        - **batch** (`Any`): A single batch of data.
        - **batch_idx** (`int`): The index of the batch.
        - **kwargs**: Additional keyword arguments supplied by the caller."""

        raise NotImplementedError

    # Not abstract: subclasses may omit it, in which case this no-op default is used.
    def validation_step(
        self, batch: Any, valid_batch_idx: int, batch_idx: int, **kwargs
    ) -> Any:
        """Defines the logic for a single validation step.

        Optional to override; the default implementation does nothing and
        returns `None`.

        ### Parameters
        - **batch** (`Any`): A single batch of validation data.
        - **valid_batch_idx** (`int`): The index of the validation batch.
        - **batch_idx** (`int`): The index of the current training batch.
        - **kwargs**: Additional keyword arguments supplied by the caller.

        ### Returns
        - `Any`: Output from the validation step (for example a `Dict[str, Any]`);
          `None` for the default implementation."""

    def on_train_start(self) -> None:
        """Hook called at the start of training. No-op by default."""

    def on_train_end(self) -> None:
        """Hook called at the end of training. No-op by default."""

    # note that `batch_idx` is always referring to the training batch index
    def on_batch_start(self, batch_idx: int) -> None:
        """Hook called at the start of batch iteration. No-op by default."""

    def on_batch_end(self, batch_idx: int) -> None:
        """Hook called at the end of batch iteration. No-op by default."""

    def on_validation_start(self, batch_idx: int) -> None:
        """Hook called at the start of validation. No-op by default."""

    def on_validation_end(self, batch_idx: int) -> None:
        """Hook called at the end of validation. No-op by default."""

    # ==================================================================================================
    # METHODS TO BE CALLED IN THE TRAINING LOOP THAT DO FIXED THINGS ON THE SERVER-SIDE
    # BEFORE CALLING THE USER-DEFINED METHODS.
    # ==================================================================================================

    def _on_train_start(self) -> None:
        """Loop entry point for training start; delegates to `on_train_start`."""
        self.on_train_start()

    def _on_train_end(self) -> None:
        """Loop entry point for training end; delegates to `on_train_end`."""
        self.on_train_end()

    def _on_batch_start(self, batch_idx: int) -> None:
        """Loop entry point for batch start; delegates to `on_batch_start`."""
        self.on_batch_start(batch_idx)

    def _on_batch_end(self, batch_idx: int) -> None:
        """Loop entry point for batch end; delegates to `on_batch_end`."""
        self.on_batch_end(batch_idx)

    def _on_validation_start(self, batch_idx: int) -> None:
        """Loop entry point for validation start; delegates to `on_validation_start`."""
        self.on_validation_start(batch_idx)

    def _on_validation_end(self, batch_idx: int) -> None:
        """Loop entry point for validation end; delegates to `on_validation_end`."""
        self.on_validation_end(batch_idx)