Coverage for melissa/server/deep_learning/train_workflow.py: 74%

27 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-07-09 14:19 +0200

1import logging 

2from typing import Any 

3from abc import ABC, abstractmethod 

4 

5 

# Module logger, bound to the shared "melissa" logger name rather than __name__.
logger: logging.Logger = logging.getLogger("melissa")

7 

8 

class TrainingWorkflowMixin(ABC):
    """Provides a structure for overriding training, validation, and hook methods.

    Subclasses must implement `training_step`. `validation_step` and the `on_*`
    hooks are optional: their default implementations do nothing and return
    `None`. The underscore-prefixed `_on_*` methods are the entry points meant to
    be called from the training loop; each currently just forwards to the
    matching public hook.
    """

    # ==================================================================================================
    # USER-DEFINED METHODS
    # ==================================================================================================
    @abstractmethod
    def training_step(self, batch: Any, batch_idx: int, **kwargs) -> None:
        """Defines the logic for a single training step.

        ### Parameters
        - **batch** (`Any`): A single batch of data.
        - **batch_idx** (`int`): The index of the batch.
        - **kwargs**: Extra keyword arguments passed by the calling training loop
          (not specified by this mixin)."""

        # Reached only if a subclass explicitly calls `super().training_step(...)`.
        raise NotImplementedError

    # Intentionally not abstract: validation is optional, so subclasses may skip it.
    # The default implementation is a no-op.
    def validation_step(
        self, batch: Any, valid_batch_idx: int, batch_idx: int, **kwargs
    ) -> Any:
        """Defines the logic for a single validation step.

        ### Parameters
        - **batch** (`Any`): A single batch of validation data.
        - **valid_batch_idx** (`int`): The index of the validation batch.
        - **batch_idx** (`int`): The index of the current training batch.
        - **kwargs**: Extra keyword arguments passed by the calling training loop
          (not specified by this mixin).

        ### Returns
        - `Any`: Output from the validation step (previously documented as a
          `Dict[str, Any]`). The default implementation returns `None`."""
        return None

    def on_train_start(self) -> None:
        """Hook called at the start of training. No-op by default."""

    def on_train_end(self) -> None:
        """Hook called at the end of training. No-op by default."""

    # note that `batch_idx` is always referring to the training batch index
    def on_batch_start(self, batch_idx: int) -> None:
        """Hook called at the start of batch iteration. No-op by default."""

    def on_batch_end(self, batch_idx: int) -> None:
        """Hook called at the end of batch iteration. No-op by default."""

    def on_validation_start(self, batch_idx: int) -> None:
        """Hook called at the start of validation. No-op by default."""

    def on_validation_end(self, batch_idx: int) -> None:
        """Hook called at the end of validation. No-op by default."""

    # ==================================================================================================
    # METHODS TO BE CALLED IN THE TRAINING LOOP THAT DO FIXED THINGS ON THE SERVER-SIDE
    # BEFORE CALLING THE USER-DEFINED METHODS.
    # ==================================================================================================
    # NOTE: at present every wrapper below only forwards to its user hook; any fixed
    # server-side work should be placed before the forwarding call.

    def _on_train_start(self) -> None:
        """Server-side entry point for the start of training; forwards to `on_train_start`."""
        self.on_train_start()

    def _on_train_end(self) -> None:
        """Server-side entry point for the end of training; forwards to `on_train_end`."""
        self.on_train_end()

    def _on_batch_start(self, batch_idx: int) -> None:
        """Server-side entry point before a training batch; forwards to `on_batch_start`."""
        self.on_batch_start(batch_idx)

    def _on_batch_end(self, batch_idx: int) -> None:
        """Server-side entry point after a training batch; forwards to `on_batch_end`."""
        self.on_batch_end(batch_idx)

    def _on_validation_start(self, batch_idx: int) -> None:
        """Server-side entry point before validation; forwards to `on_validation_start`."""
        self.on_validation_start(batch_idx)

    def _on_validation_end(self, batch_idx: int) -> None:
        """Server-side entry point after validation; forwards to `on_validation_end`."""
        self.on_validation_end(batch_idx)