Coverage for melissa/server/deep_learning/dataset/base_dataset.py: 45%
103 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-03-10 22:25 +0100
1"""This script defines iterable datasets for pytorch and tensorflow servers."""
3# pylint: disable=W0223,E0601
import time
import queue
import threading
import logging

from typing import Union, List, Dict, Any, Callable, NoReturn, Optional, Iterable

from melissa.server.simulation import SimulationData
from melissa.server.deep_learning.reservoir import BaseQueue, Empty
from melissa.server.deep_learning.tensorboard import TensorboardLogger
# Module-level logger following the standard `logging.getLogger(__name__)` idiom.
logger = logging.getLogger(__name__)
class MelissaIterableDataset:
    """Streaming dataset that pulls simulation samples from a buffer.

    Items are fetched from `buffer` for as long as the server is receiving
    or the buffer still holds data. An optional, lock-protected `transform`
    is applied to every fetched item, and fetch timing / throughput metrics
    are reported to TensorBoard when a logger is supplied.

    ### Parameters
    - **buffer** (`BaseQueue`):
        The buffer used for storing and retrieving streaming data.
    - **config_dict** (`dict`, optional):
        Configuration dictionary forwarded to the transform function.
        Defaults to `None`.
    - **transform** (`Callable`, optional):
        A callable applied to every item pulled from the buffer.
        Defaults to `None`.
    - **tb_logger** (`TensorboardLogger`, optional):
        Logger for TensorBoard metrics. Defaults to `None`.

    ### Attributes
    - **buffer** (`BaseQueue`): queue holding the incoming samples.
    - **sample_number** (`int`): number of items pulled and processed so far.
    - **config_dict** (`Optional[Dict[str, Any]]`): dataset configuration.
    - **_is_receiving** (`bool`): whether the server still receives data.
    - **__tb_logger** (`Optional[TensorboardLogger]`): metrics sink, if any.
    - **__transform** (`Optional[Callable]`): user transform, if any.
    - **__transform_lock** (`threading.Lock`): serializes transform calls."""

    def __init__(self,
                 buffer: BaseQueue,
                 config_dict: Optional[Dict[str, Any]] = None,
                 transform: Optional[Callable] = None,
                 tb_logger: Optional[TensorboardLogger] = None) -> None:

        self.buffer: BaseQueue = buffer
        self.sample_number = 0
        self._is_receiving: bool = True
        self.config_dict: Optional[Dict[str, Any]] = config_dict
        # name-mangled privates: the transform, its lock, and the logger
        self.__transform: Optional[Callable] = transform
        self.__transform_lock: threading.Lock = threading.Lock()
        self.__tb_logger: Optional[TensorboardLogger] = tb_logger

    @property
    def transform_lock(self):
        """Lock serializing calls to the transform function."""
        return self.__transform_lock

    @property
    def has_data(self) -> bool:
        """Returns if the server is still receiving and the buffer is not empty."""
        return self._is_receiving or not self.buffer.empty()

    def get_sample_number(self) -> int:
        """Return how many items were pulled from the buffer so far.
        Useful for logging."""
        return self.sample_number

    def signal_reception_over(self):
        """Mark reception as finished so the remaining elements can be
        flushed from the buffer."""
        self._is_receiving = False
        self.buffer.signal_reception_over()

    def wrapped_transform(self, *args, **kwargs):
        """Apply the user transform while holding its lock (thread-safe)."""
        with self.__transform_lock:
            return self.__transform(*args, **kwargs)

    def __iter__(self):
        """Infinite iterator that keeps pulling from the buffer while data
        can still arrive (server receiving, or buffer non-empty)."""
        while self.has_data:
            try:
                fetch_start = time.time()
                payload: Union[SimulationData, List[SimulationData]] = \
                    self.buffer.get(timeout=1)
                fetch_end = time.time()

                if self.__transform and payload:
                    sample = self.wrapped_transform(payload, self.config_dict)
                elif isinstance(payload, list):
                    sample = [entry.data for entry in payload]
                else:
                    sample = payload.data

                # throughput: one item per elapsed fetch+process interval
                elapsed = time.time() - fetch_start
                rate = 1.0 / elapsed if elapsed > 0 else 0.0
                self.sample_number += 1

                if self.__tb_logger:
                    self.__tb_logger.log_scalar(
                        "samples_per_second", rate, self.sample_number
                    )
                    self.__tb_logger.log_scalar(
                        "get_time", fetch_end - fetch_start, self.sample_number
                    )
                    self.__tb_logger.log_scalar(
                        "buffer_size", len(self.buffer), self.sample_number
                    )
                yield sample
            except Empty:
                logger.warning("Buffer empty but still receiving.")
                time.sleep(1)  # necessary in case multiple rounds
                continue
123class GeneralDataLoader:
124 """A general-purpose data loader designed to handle streaming datasets
125 with optional multi-threaded loading and batch collation.
127 This class supports datasets like `MelissaIterableDataset` that provide
128 infinite or streaming data. It enables efficient batching and parallel
129 data loading while ensuring compatibility with custom collation functions.
131 ### Parameters
132 - **dataset** (`MelissaIterableDataset`):
133 An iterable dataset that streams data via its `__iter__` method.
134 - **batch_size** (`int`):
135 Number of samples per batch.
136 - **collate_fn** (`Callable`, optional):
137 A function to combine multiple samples into a batch. Defaults to `None`,
138 which creates batches as lists of samples.
139 - **num_workers** (`int`, optional):
140 Number of worker threads for parallel data loading. Defaults to `0` (no threading).
141 - **drop_last** (`bool`, optional):
142 Whether to drop the last incomplete batch. Defaults to `True`.
144 ### Attributes
145 - **dataset** (`MelissaIterableDataset`):
146 The dataset being wrapped for batching and loading.
147 - **batch_size** (`int`):
148 Size of each batch produced by the data loader.
149 - **collate_fn** (`Optional[Callable]`):
150 The function used to collate samples into batches.
151 - **num_workers** (`int`):
152 Number of worker threads for parallel data loading.
153 - **drop_last** (`bool`):
154 Indicates if incomplete batches are dropped.
155 - **_queue** (`queue.Queue`):
156 An internal buffer to hold preloaded samples during multi-threaded loading.
157 - **_stop_event** (`threading.Event`):
158 A flag to signal worker threads to stop loading data.
159 - **_threads** (`List[threading.Thread]`):
160 List of worker threads for parallel data loading."""
161 def __init__(self,
162 dataset: MelissaIterableDataset,
163 batch_size: int,
164 collate_fn: Optional[Callable] = None,
165 num_workers: int = 0,
166 drop_last: bool = True) -> None:
168 self.dataset = dataset
169 self.batch_size = batch_size
170 self.collate_fn = collate_fn
171 self.num_workers = num_workers
172 self.drop_last = drop_last
174 # for multi-threaded loading (if num_workers > 0)
175 self._queue: queue.Queue = queue.Queue(maxsize=10 * batch_size)
176 self._stop_event: threading.Event = threading.Event()
177 self._threads: List[threading.Thread] = []
179 def _worker_loop(self) -> None:
180 """Worker thread loop to fetch data from the dataset and enqueue it."""
182 for item in self.dataset:
183 if self._stop_event.is_set():
184 break
185 self._queue.put(item)
187 def __iter__(self) -> Iterable:
188 """Iterate over the dataset and yield batches."""
190 # start worker threads for parallel loading
191 if self.num_workers > 0:
192 self._stop_event.clear()
193 self._threads = [
194 threading.Thread(target=self._worker_loop, daemon=True)
195 for _ in range(self.num_workers)
196 ]
197 for thread in self._threads:
198 thread.start()
200 batch = None
201 try:
202 while True:
203 # fetch data from the dataset or the worker queue
204 if self.num_workers > 0:
205 item = self._queue.get()
206 else:
207 item = next(iter(self.dataset))
209 if batch is None:
210 batch = [[] for _ in item]
211 # making the first dimension equal to the number of returns
212 # from `process_simulation_data` defined by the user
213 for i, it in enumerate(item):
214 batch[i].append(it) # type: ignore
216 if len(batch[0]) == self.batch_size: # type: ignore
217 if self.collate_fn:
218 yield self.collate_fn(batch)
219 else:
220 yield batch
221 batch = None
223 except StopIteration:
224 pass
225 finally:
226 # handle any leftover batch
227 if batch and not self.drop_last:
228 if self.collate_fn:
229 yield self.collate_fn(batch)
230 else:
231 yield batch
233 # stop worker threads if running
234 if self.num_workers > 0:
235 self._stop_event.set()
236 for thread in self._threads:
237 thread.join()
239 def __len__(self) -> None:
240 """`__len__` is not supported for infinite datasets."""
241 raise TypeError("`__len__` is not supported for infinite datasets.")