Coverage for melissa/server/deep_learning/dataset/__init__.py: 26%

27 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-03-10 22:25 +0100

1from typing import Optional, Dict, Any, Callable 

2 

3from melissa.server.deep_learning import FrameworkType 

4from melissa.server.deep_learning.reservoir import BaseQueue 

5from melissa.server.deep_learning.dataset.base_dataset import ( 

6 MelissaIterableDataset, GeneralDataLoader 

7) 

8 

# Public names re-exported by `from melissa.server.deep_learning.dataset import *`.
__all__ = [
    "MelissaIterableDataset",
    "FrameworkType",
    "make_dataset",
    "make_dataloader",
]

15 

16 

def make_dataset(framework_t: FrameworkType,
                 buffer: BaseQueue,
                 tb_logger: Any,
                 config_dict: Dict[str, Any],
                 transform: Callable) -> MelissaIterableDataset:
    """Factory function to create datasets based on the specified deep learning framework.

    This function initializes and returns a dataset object for the default backend
    (`MelissaIterableDataset`), PyTorch, or TensorFlow, based on the provided framework type.
    Framework-specific dataset modules are imported only when that framework is selected.

    ### Parameters
    - **framework_t** (`FrameworkType`): The type of framework (`DEFAULT`, `TORCH` or `TENSORFLOW`).
    - **buffer** (`BaseQueue`): The data buffer to be used by the dataset.
    - **tb_logger** (`Any`): A logger for TensorBoard metrics.
    - **config_dict** (`Dict[str, Any]`): Configuration dictionary for the dataset.
    - **transform** (`Callable`): A transformation function to process data before yielding it.

    ### Returns
    - `MelissaIterableDataset`: A dataset object compatible with the specified framework.

    ### Raises
    - `RuntimeError`: if `framework_t` is not one of the supported framework types."""

    if framework_t is FrameworkType.DEFAULT:
        dataset_class = MelissaIterableDataset
    elif framework_t is FrameworkType.TORCH:
        # Imported inside the branch so the torch-specific module is only loaded when selected.
        from melissa.server.deep_learning.dataset.torch_dataset import TorchMelissaIterableDataset
        dataset_class = TorchMelissaIterableDataset
    elif framework_t is FrameworkType.TENSORFLOW:
        # Imported inside the branch so the tf-specific module is only loaded when selected.
        from melissa.server.deep_learning.dataset.tf_dataset import TfMelissaIterableDataset
        dataset_class = TfMelissaIterableDataset
    else:
        # Without this branch an unknown framework left `dataset_class` unbound and
        # surfaced as an UnboundLocalError; fail explicitly, like `make_dataloader`.
        raise RuntimeError(f"Unsupported framework type: {framework_t!r}")

    # NOTE: the dataset constructor takes `tb_logger` last, unlike this function's signature.
    return dataset_class(buffer, config_dict, transform, tb_logger)

48 

49 

def make_dataloader(framework_t: FrameworkType,
                    iter_dataset: MelissaIterableDataset,
                    batch_size: int,
                    collate_fn: Optional[Callable] = None,
                    num_workers: int = 0,
                    **extra_torch_dl_args) -> Any:
    """Factory function to create dataloader based on the specified deep learning framework.

    ### Parameters
    - **framework_t** (`FrameworkType`): The type of framework (`DEFAULT`, `TORCH` or `TENSORFLOW`).
    - **iter_dataset** (`MelissaIterableDataset`): An iterable dataset that streams data via its
    `__iter__` method. For `TORCH` it must be a `TorchMelissaIterableDataset`, and for
    `TENSORFLOW` a `TfMelissaIterableDataset`.
    - **batch_size** (`int`): Number of samples per batch.
    - **collate_fn** (`Callable`, optional): A function to combine multiple samples into a batch.
    Defaults to `None`, in which case the selected backend's default batching applies.
    - **num_workers** (`int`, optional): Number of workers for parallel data loading.
    Defaults to `0`. Used by `DEFAULT` and `TORCH`; not forwarded for `TENSORFLOW`.
    - **extra_torch_dl_args** (`Dict[str, Any]`, optional): Extra `kwargs` for
    `torch.utils.data.DataLoader`. Only used for `TORCH`; ignored for other frameworks.

    ### Returns
    - `Union[GeneralDataLoader, torch.utils.data.DataLoader, tensorflow.data.Dataset]`:
    An iterable for training over batches.

    ### Raises
    - `TypeError`: if `iter_dataset` is not the dataset class required by `framework_t`.
    - `RuntimeError`: if the specified framework is not supported."""

    if framework_t is FrameworkType.DEFAULT:
        return GeneralDataLoader(iter_dataset, batch_size, collate_fn, num_workers)

    elif framework_t is FrameworkType.TORCH:
        from melissa.server.deep_learning.dataset.torch_dataset import (
            TorchMelissaIterableDataset,
            as_torch_dataloader,
        )
        # Explicit check instead of `assert`, which is stripped when running with `python -O`.
        if not isinstance(iter_dataset, TorchMelissaIterableDataset):
            raise TypeError(
                "FrameworkType.TORCH requires a TorchMelissaIterableDataset, "
                f"got {type(iter_dataset).__name__}"
            )
        return as_torch_dataloader(
            iter_dataset, batch_size, collate_fn, num_workers, **extra_torch_dl_args
        )

    elif framework_t is FrameworkType.TENSORFLOW:
        from melissa.server.deep_learning.dataset.tf_dataset import (
            TfMelissaIterableDataset,
            as_tensorflow_dataset,
        )
        # Explicit check instead of `assert`, which is stripped when running with `python -O`.
        if not isinstance(iter_dataset, TfMelissaIterableDataset):
            raise TypeError(
                "FrameworkType.TENSORFLOW requires a TfMelissaIterableDataset, "
                f"got {type(iter_dataset).__name__}"
            )
        # `num_workers` and `extra_torch_dl_args` are not forwarded on this path.
        return as_tensorflow_dataset(iter_dataset, batch_size, collate_fn)

    raise RuntimeError(f"Unsupported framework type: {framework_t!r}")