Coverage for melissa/server/deep_learning/dataset/__init__.py: 26%
27 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-03-10 22:25 +0100
« prev ^ index » next coverage.py v7.6.12, created at 2025-03-10 22:25 +0100
1from typing import Optional, Dict, Any, Callable
3from melissa.server.deep_learning import FrameworkType
4from melissa.server.deep_learning.reservoir import BaseQueue
5from melissa.server.deep_learning.dataset.base_dataset import (
6 MelissaIterableDataset, GeneralDataLoader
7)
# Names exported by ``from melissa.server.deep_learning.dataset import *``.
__all__ = [
    "MelissaIterableDataset",
    "FrameworkType",
    "make_dataset",
    "make_dataloader",
]
def make_dataset(framework_t: FrameworkType,
                 buffer: BaseQueue,
                 tb_logger: Any,
                 config_dict: Dict[str, Any],
                 transform: Callable) -> MelissaIterableDataset:
    """Factory function to create datasets based on the specified deep learning framework.

    This function selects the dataset class matching the provided framework type
    (generic, PyTorch or TensorFlow), instantiates it and returns it.

    ### Parameters
    - **framework_t** (`FrameworkType`): The type of framework (`DEFAULT`, `TORCH` or `TENSORFLOW`).
    - **buffer** (`BaseQueue`): The data buffer to be used by the dataset.
    - **tb_logger** (`Any`): A logger for TensorBoard metrics.
    - **config_dict** (`Dict[str, Any]`): Configuration dictionary for the dataset.
    - **transform** (`Callable`): A transformation function to process data before yielding it.

    ### Returns
    - `MelissaIterableDataset`: A dataset object compatible with the specified framework.

    ### Raises
    - `RuntimeError` if the specified framework is not one of the supported types."""

    if framework_t is FrameworkType.DEFAULT:
        dataset_class = MelissaIterableDataset
    elif framework_t is FrameworkType.TORCH:
        # Imported inside the branch rather than at module level; presumably so the
        # torch backend is only needed when it is actually selected.
        from melissa.server.deep_learning.dataset.torch_dataset import TorchMelissaIterableDataset
        dataset_class = TorchMelissaIterableDataset
    elif framework_t is FrameworkType.TENSORFLOW:
        # Same deferred-import pattern for the TensorFlow backend.
        from melissa.server.deep_learning.dataset.tf_dataset import TfMelissaIterableDataset
        dataset_class = TfMelissaIterableDataset
    else:
        # Without this branch `dataset_class` would be unbound and the function
        # would fail with an opaque UnboundLocalError. Raise the same exception
        # type that `make_dataloader` uses for an unknown framework.
        raise RuntimeError(f"Unsupported framework type: {framework_t!r}")

    # NOTE: the constructor takes the arguments in a different order than this
    # factory does (the logger goes last).
    return dataset_class(buffer, config_dict, transform, tb_logger)
def make_dataloader(framework_t: FrameworkType,
                    iter_dataset: MelissaIterableDataset,
                    batch_size: int,
                    collate_fn: Optional[Callable] = None,
                    num_workers: int = 0,
                    **extra_torch_dl_args) -> Any:
    """Factory function to create dataloader based on the specified deep learning framework.

    ### Parameters
    - **framework_t** (`FrameworkType`): The type of framework (`DEFAULT`, `TORCH` or `TENSORFLOW`).
    - **iter_dataset** (`MelissaIterableDataset`): An iterable dataset that streams data via its
    `__iter__` method. For `TORCH` and `TENSORFLOW` it must be an instance of the
    matching framework-specific dataset class.
    - **batch_size** (`int`): Number of samples per batch.
    - **collate_fn** (`Callable`, optional): A function to combine multiple samples into a batch.
    Defaults to `None`, which creates batches as lists of samples.
    - **num_workers** (`int`, optional): Number of worker threads for parallel data loading.
    Defaults to `0` (no threading). Not forwarded to the TensorFlow backend.
    - **extra_torch_dl_args** (`Dict[str, Any]`, optional): Extra `kwargs` for
    `torch.utils.data.DataLoader`. Only forwarded when `framework_t` is `TORCH`.

    ### Returns
    - `Union[GeneralDataLoader, torch.utils.data.DataLoader, tensorflow.data.Dataset]`:
    An iterable for training over batches.

    ### Raises
    - `RuntimeError` if the specified framework is not found.
    - `AssertionError` if `iter_dataset` is not of the dataset class expected by the
    selected framework (`TORCH` / `TENSORFLOW` only)."""

    if framework_t is FrameworkType.DEFAULT:
        return GeneralDataLoader(iter_dataset, batch_size, collate_fn, num_workers)

    if framework_t is FrameworkType.TORCH:
        from melissa.server.deep_learning.dataset.torch_dataset import (
            TorchMelissaIterableDataset,
            as_torch_dataloader
        )
        # Check (and narrow for type checkers) that the dataset matches the backend.
        assert isinstance(iter_dataset, TorchMelissaIterableDataset), (
            "TORCH dataloader requires a TorchMelissaIterableDataset"
        )
        return as_torch_dataloader(
            iter_dataset, batch_size, collate_fn, num_workers, **extra_torch_dl_args
        )

    if framework_t is FrameworkType.TENSORFLOW:
        from melissa.server.deep_learning.dataset.tf_dataset import (
            TfMelissaIterableDataset,
            as_tensorflow_dataset
        )
        assert isinstance(iter_dataset, TfMelissaIterableDataset), (
            "TENSORFLOW dataloader requires a TfMelissaIterableDataset"
        )
        # `num_workers` and the extra torch kwargs are not passed to this backend.
        return as_tensorflow_dataset(iter_dataset, batch_size, collate_fn)

    # None of the known framework types matched.
    raise RuntimeError(f"Unsupported framework type: {framework_t!r}")