Coverage for melissa/server/deep_learning/dataset.py: 69%
64 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-09-22 10:36 +0000
1from typing import Union, List, Dict, Any, Callable, Optional
2import logging
3import time
5try: # double import for unit tests and CI
6 from torch.utils.data.dataset import IterableDataset
7 import tensorflow as tf
8except ModuleNotFoundError: # only one framework available
9 try:
10 from torch.utils.data.dataset import IterableDataset
11 except ModuleNotFoundError:
12 import tensorflow as tf
14 class IterableDataset: # type: ignore
15 pass
17from melissa.server.simulation import SimulationData
18from melissa.server.deep_learning.reservoir import FIFO, Empty
19from melissa.server.deep_learning.tensorboard_logger import TensorboardLogger
22logger = logging.getLogger(__name__)
class MelissaIterableDataset:
    """Iterable dataset that streams samples out of a FIFO reception buffer.

    Items are pulled from ``buffer`` for as long as the server is still
    receiving data or the buffer is non-empty. Each pulled item is either
    passed through ``transform`` or reduced to its ``.data`` payload, and
    timing/buffer metrics are optionally logged to TensorBoard.
    """

    def __init__(
        self,
        buffer: FIFO,
        config: Optional[dict] = None,
        transform: Optional[Callable] = None,
        tb_logger: Optional[TensorboardLogger] = None,
    ) -> None:
        """
        Parameters
        ----------
        buffer:
            FIFO queue the server fills with received simulation data.
        config:
            Options dict forwarded to ``transform``. Defaults to an empty
            dict (``None`` avoids the mutable-default-argument pitfall of
            sharing one dict across instances).
        transform:
            Optional callable ``transform(items, config)`` applied to every
            item pulled from the buffer.
        tb_logger:
            Optional TensorBoard logger for timing and buffer-size metrics.
        """
        self.buffer: FIFO = buffer
        self.tb_logger = tb_logger
        # Flips to False once the server signals the end of reception.
        self._is_receiving: bool = True
        self.batch = 0
        # Monotonically increasing step index used for TensorBoard scalars.
        self.sample_number = 0
        self.config: Dict[str, Any] = {} if config is None else config
        self.transform = transform

    @property
    def has_data(self) -> bool:
        """True while more data may still be yielded."""
        return self._is_receiving or not self.buffer.empty()

    def signal_reception_over(self) -> None:
        """Mark the end of data reception and propagate it to the buffer."""
        self._is_receiving = False
        self.buffer.signal_reception_over()

    def __iter__(self):
        # Infinite iterator which will always try to pull from the
        # buffer as long as the buffer is not empty or the server
        # is still receiving data.
        while self.has_data:
            try:
                start_get = time.time()
                items: Union[SimulationData, List[SimulationData]] = self.buffer.get(timeout=1)
                end_get = time.time()

                start_processing = time.time()
                if self.transform:
                    data = self.transform(items, self.config)
                elif isinstance(items, list):
                    data = [item.data for item in items]
                else:
                    data = items.data
                end_processing = time.time()

                if self.tb_logger:
                    self.sample_number += 1
                    self.tb_logger.log_scalar("get_time", end_get - start_get, self.sample_number)
                    self.tb_logger.log_scalar(
                        "processing_time", end_processing - start_processing, self.sample_number
                    )
                    self.tb_logger.log_scalar("buffer_size", len(self.buffer), self.sample_number)
                yield data
            except Empty:
                # get() timed out: the buffer is momentarily empty but the
                # server may still deliver more data, so keep polling.
                logger.warning("Buffer empty but still receiving.")
                continue
class TorchMelissaIterableDataset(MelissaIterableDataset, IterableDataset):
    """
    Object whose only job is to make an iterable dataset available to torch.
    Only used by Melissa-DL children.
    """

    def __init__(
        self,
        buffer: FIFO,
        config: Optional[dict] = None,
        transform: Optional[Callable] = None,
        tb_logger: Optional[TensorboardLogger] = None,
    ) -> None:
        # Normalize here so we never rely on a shared mutable default dict
        # (the `config: dict = {}` pitfall); `None` means "no options".
        MelissaIterableDataset.__init__(
            self, buffer, {} if config is None else config, transform, tb_logger
        )
        # Explicit base-class calls because the two bases have unrelated
        # __init__ signatures (cooperative super() would not dispatch cleanly).
        IterableDataset.__init__(self)
class TfMelissaIterableDataset(MelissaIterableDataset):
    """
    Object whose only job is to make an iterable dataset available to tensorflow.
    Only used by Melissa-DL children.
    """

    def __init__(
        self,
        buffer: FIFO,
        config: Optional[dict] = None,
        transform: Optional[Callable] = None,
        tb_logger: Optional[TensorboardLogger] = None,
    ) -> None:
        # Normalize here so we never rely on a shared mutable default dict
        # (the `config: dict = {}` pitfall); `None` means "no options".
        MelissaIterableDataset.__init__(
            self, buffer, {} if config is None else config, transform, tb_logger
        )

    def as_tensorflow_dataset(self):
        """Wrap this iterable as a ``tf.data.Dataset``.

        Each yielded sample is expected to be a (input, target) pair of
        float32 1-D tensors — NOTE(review): shapes/dtypes are hard-coded
        here; confirm they match what ``transform`` produces.
        """
        dataset = tf.data.Dataset.from_generator(
            self.__iter__,
            output_types=(tf.float32, tf.float32),
            output_shapes=((None,), (None,)),
        )
        return dataset
115 return dataset