Coverage for melissa/server/deep_learning/dataset/tf_dataset.py: 0%
13 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-03-10 22:25 +0100
« prev ^ index » next coverage.py v7.6.12, created at 2025-03-10 22:25 +0100
1import tensorflow as tf
2from typing import Dict, Any, Callable, Optional
4from melissa.server.deep_learning.reservoir import BaseQueue
5from melissa.server.deep_learning.tensorboard import TensorboardLogger
6from melissa.server.deep_learning.dataset import MelissaIterableDataset
class TfMelissaIterableDataset(MelissaIterableDataset):
    """A TensorFlow-compatible extension of the MelissaIterableDataset.

    This class adapts the MelissaIterableDataset to work with TensorFlow
    pipelines. It adds no state or behavior of its own: construction is
    forwarded unchanged to `MelissaIterableDataset`, and iteration is
    inherited from it.

    ### Parameters
    - **buffer** (`BaseQueue`): Queue handed to the parent dataset.
    - **config_dict** (`Dict[str, Any]`, optional): Configuration passed
    through to the parent dataset.
    - **transform** (`Callable`, optional): Callable passed through to the
    parent dataset (presumably applied to samples; see
    `MelissaIterableDataset`).
    - **tb_logger** (`TensorboardLogger`, optional): Logger passed through
    to the parent dataset."""

    def __init__(
        self,
        buffer: BaseQueue,
        config_dict: Optional[Dict[str, Any]] = None,
        transform: Optional[Callable] = None,
        tb_logger: Optional[TensorboardLogger] = None,
    ) -> None:
        # Use super() so the call follows the MRO instead of hard-coding
        # the parent class; arguments stay positional, in the same order.
        super().__init__(buffer, config_dict, transform, tb_logger)
def as_tensorflow_dataset(iter_dataset: TfMelissaIterableDataset,
                          batch_size: int,
                          collate_fn: Optional[Callable] = None) -> tf.data.Dataset:
    """Converts the iterable dataset into a TensorFlow `tf.data.Dataset`.

    This function uses TensorFlow's `from_generator` functionality to
    wrap the given iterable dataset into a `tf.data.Dataset`, allowing
    integration with TensorFlow's data processing pipelines. The resulting
    dataset is batched, and `collate_fn` (if given) is then mapped over
    the batches.

    ### Parameters
    - **iter_dataset** (`TfMelissaIterableDataset`): An iterable dataset instance
    defining `__iter__` method. Each item it yields must be a
    `(features, labels)` pair.
    - **batch_size** (`int`): Number of consecutive samples per batch.
    - **collate_fn** (`Callable`, optional):
    A function applied with `Dataset.map` to every batched element,
    i.e. it receives already-batched `(features, labels)` tensors.

    ### Returns
    - `tf.data.Dataset`: A batched TensorFlow dataset with elements
    structured as `(features, labels)` (or whatever `collate_fn` returns).
    Before batching, both features and labels are rank-1 `tf.float32`
    tensors of unspecified length (`None`)."""
    # `output_types` / `output_shapes` are deprecated in `from_generator`;
    # `output_signature` expresses the same dtype and shape per component.
    vector_spec = tf.TensorSpec(shape=(None,), dtype=tf.float32)

    dataset = tf.data.Dataset.from_generator(
        iter_dataset.__iter__,
        output_signature=(vector_spec, vector_spec),
    ).batch(batch_size)

    # Compare against None explicitly so a callable object that happens to
    # be falsy (e.g. defines `__len__` or `__bool__`) is still applied.
    if collate_fn is not None:
        dataset = dataset.map(collate_fn)

    return dataset