Coverage for melissa/server/deep_learning/dataset/tf_dataset.py: 0%

13 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-03-10 22:25 +0100

1import tensorflow as tf 

2from typing import Dict, Any, Callable, Optional 

3 

4from melissa.server.deep_learning.reservoir import BaseQueue 

5from melissa.server.deep_learning.tensorboard import TensorboardLogger 

6from melissa.server.deep_learning.dataset import MelissaIterableDataset 

7 

8 

class TfMelissaIterableDataset(MelissaIterableDataset):
    """TensorFlow-facing variant of `MelissaIterableDataset`.

    The subclass adds no behavior of its own: every constructor argument
    is handed unchanged to `MelissaIterableDataset`. It acts as the bridge
    type between the Melissa distributed data system and TensorFlow
    input pipelines.

    ### Parameters
    - **buffer** (`BaseQueue`): Queue forwarded to the base dataset.
    - **config_dict** (`Dict[str, Any]`, optional): Configuration forwarded
    to the base dataset.
    - **transform** (`Callable`, optional): Transform forwarded to the
    base dataset.
    - **tb_logger** (`TensorboardLogger`, optional): Logger forwarded to the
    base dataset."""

    def __init__(
        self,
        buffer: BaseQueue,
        config_dict: Optional[Dict[str, Any]] = None,
        transform: Optional[Callable] = None,
        tb_logger: Optional[TensorboardLogger] = None,
    ) -> None:
        # Arguments are passed positionally, in the order they were received.
        super().__init__(buffer, config_dict, transform, tb_logger)

24 

25 

def as_tensorflow_dataset(iter_dataset: TfMelissaIterableDataset,
                          batch_size: int,
                          collate_fn: Optional[Callable] = None) -> tf.data.Dataset:
    """Converts the iterable dataset into a batched TensorFlow `tf.data.Dataset`.

    This function wraps `iter_dataset.__iter__` with
    `tf.data.Dataset.from_generator`, batches the resulting stream and,
    when requested, maps `collate_fn` over every batch.

    ### Parameters
    - **iter_dataset** (`TfMelissaIterableDataset`): An iterable dataset instance
    defining `__iter__` method. Each yielded item must be a
    `(features, labels)` pair of 1-D `tf.float32`-compatible values.
    - **batch_size** (`int`): Number of consecutive samples grouped per batch.
    - **collate_fn** (`Callable`, optional):
    A function applied to each batch through `Dataset.map`
    (i.e. after batching). Skipped when `None`.

    ### Returns
    - `tf.data.Dataset`: A TensorFlow dataset with elements
    structured as `(features, labels)`. Both are `tf.float32`; before
    `collate_fn` is applied each has a leading batch dimension followed by
    one dynamic dimension (`None`)."""

    # `output_types`/`output_shapes` are deprecated arguments of
    # `from_generator`; `output_signature` expresses the same dtype/shape
    # contract (1-D float32 of unknown length) for both tuple components.
    sample_spec = tf.TensorSpec(shape=(None,), dtype=tf.float32)

    dataset = tf.data.Dataset.from_generator(
        iter_dataset.__iter__,
        output_signature=(sample_spec, sample_spec),
    ).batch(batch_size)

    # Explicit None check: the parameter is Optional, so only the absence
    # of a callable should disable the mapping step.
    if collate_fn is not None:
        dataset = dataset.map(collate_fn)

    return dataset