Coverage for melissa/server/deep_learning/dataset/base_dataset.py: 45%

103 statements  

coverage.py v7.6.12, created at 2025-03-10 22:25 +0100

1"""This script defines iterable datasets for pytorch and tensorflow servers.""" 

2 

3# pylint: disable=W0223,E0601 

4 

5import time 

6import queue 

7import threading 

8import logging 

9 

10from typing import Union, List, Dict, Any, Callable, Optional, Iterable 

11 

12from melissa.server.simulation import SimulationData 

13from melissa.server.deep_learning.reservoir import BaseQueue, Empty 

14from melissa.server.deep_learning.tensorboard import TensorboardLogger 

15 

16 

17logger = logging.getLogger(__name__) 

18 

19 

20class MelissaIterableDataset: 

21 """A dataset class designed to handle streaming simulation data through a buffer, 

22 with optional data transformations and logging capabilities. 

23 

24 ### Parameters 

25 - **buffer** (`BaseQueue`): 

26 The buffer used for storing and retrieving streaming data. 

27 - **config_dict** (`dict`, optional): 

28 Configuration dictionary for initializing dataset-specific parameters. 

29 Defaults to an empty dictionary. 

30 - **transform** (`Callable`, optional): 

31 A callable transformation function to apply to the data samples. Defaults to `None`. 

32 - **tb_logger** (`TensorboardLogger`, optional): 

33 A logger for tracking dataset operations via TensorBoard. Defaults to `None`. 

34 

35 ### Attributes 

36 - **buffer** (`BaseQueue`): 

37 Holds the data samples in a queue for processing. 

38 - **__tb_logger** (`Optional[TensorboardLogger]`): 

39 Logs dataset-related events or metrics for TensorBoard visualization. 

40 - **_is_receiving** (`bool`): 

41 Indicates whether the dataset is currently receiving data from the buffer. 

42 - **sample_number** (`int`): 

43 Tracks the number of samples processed. 

44 - **config_dict** (`Dict[str, Any]`): 

45 Stores configuration settings for the dataset. 

46 - **__transform** (`Callable`, optional): 

47 Holds the transformation function, if provided. 

48 - **__transform_lock** (`threading.Lock`): 

49 Ensures thread-safe application of the transformation function.""" 

50 def __init__(self, 

51 buffer: BaseQueue, 

52 config_dict: Optional[Dict[str, Any]] = None, 

53 transform: Optional[Callable] = None, 

54 tb_logger: Optional[TensorboardLogger] = None) -> None: 

55 

56 self.buffer: BaseQueue = buffer 

57 self.__tb_logger: Optional[TensorboardLogger] = tb_logger 

58 self._is_receiving: bool = True 

59 self.sample_number = 0 

60 self.config_dict: Optional[Dict[str, Any]] = config_dict 

61 self.__transform: Optional[Callable] = transform 

62 self.__transform_lock: threading.Lock = threading.Lock() 

63 

64 @property 

65 def transform_lock(self): 

66 return self.__transform_lock 

67 

68 @property 

69 def has_data(self) -> bool: 

70 """Returns if the server is still receiving and the buffer is not empty.""" 

71 return self._is_receiving or not self.buffer.empty() 

72 

73 def get_sample_number(self) -> int: 

74 """Returns the total sample count 

75 that were pulled from the buffer and processed. Useful for logging.""" 

76 return self.sample_number 

77 

78 def signal_reception_over(self): 

79 """Called after reception is done to flush the remaining 

80 elements from the buffer.""" 

81 

82 self._is_receiving = False 

83 self.buffer.signal_reception_over() 

84 

85 def wrapped_transform(self, *args, **kwargs): 

86 with self.__transform_lock: 

87 return self.__transform(*args, **kwargs) 

88 

89 def __iter__(self): 

90 """Infinite iterator which will always try to pull from the 

91 buffer as long as the buffer is not empty or the server 

92 is still receiving data.""" 

93 

94 while self.has_data: 

95 try: 

96 start_get = time.time() 

97 items: Union[SimulationData, List[SimulationData]] = self.buffer.get(timeout=1) 

98 end_get = time.time() 

99 if self.__transform and items: 

100 data = self.wrapped_transform(items, self.config_dict) 

101 elif isinstance(items, list): 

102 data = [item.data for item in items] 

103 else: 

104 data = items.data 

105 # compute samples per second 

106 elapsed_time = time.time() - start_get 

107 samples_per_second = 1.0 / elapsed_time if elapsed_time > 0 else 0.0 

108 self.sample_number += 1 

109 

110 if self.__tb_logger: 

111 self.__tb_logger.log_scalar( 

112 "samples_per_second", samples_per_second, self.sample_number 

113 ) 

114 self.__tb_logger.log_scalar("get_time", end_get - start_get, self.sample_number) 

115 self.__tb_logger.log_scalar("buffer_size", len(self.buffer), self.sample_number) 

116 yield data 

117 except Empty: 

118 logger.warning("Buffer empty but still receiving.") 

119 time.sleep(1) # necessary in case multiple rounds 

120 continue 

121 
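

# --- Illustrative usage sketch (assumption: not part of the original module) ---
# A minimal sketch of how `MelissaIterableDataset` might be consumed with a
# user-supplied transform. The helper name `_example_training_loop` and the
# `train_step` callback are hypothetical; any concrete `BaseQueue` instance
# can be passed as `buffer`.
def _example_training_loop(buffer: BaseQueue, train_step: Callable) -> None:
    def flatten_sample(items, config_dict):
        # the transform receives the item(s) pulled from the buffer together with
        # the dataset's config_dict; here we simply unwrap the raw data field
        if isinstance(items, list):
            return [item.data for item in items]
        return items.data

    dataset = MelissaIterableDataset(
        buffer=buffer,
        config_dict={},
        transform=flatten_sample,
    )
    for sample in dataset:  # blocks until data arrives or reception is over
        train_step(sample)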


class GeneralDataLoader:
    """A general-purpose data loader designed to handle streaming datasets
    with optional multi-threaded loading and batch collation.

    This class supports datasets like `MelissaIterableDataset` that provide
    infinite or streaming data. It enables efficient batching and parallel
    data loading while ensuring compatibility with custom collation functions.

    ### Parameters
    - **dataset** (`MelissaIterableDataset`):
        An iterable dataset that streams data via its `__iter__` method.
    - **batch_size** (`int`):
        Number of samples per batch.
    - **collate_fn** (`Callable`, optional):
        A function to combine multiple samples into a batch. Defaults to `None`,
        which creates batches as lists of samples.
    - **num_workers** (`int`, optional):
        Number of worker threads for parallel data loading. Defaults to `0` (no threading).
    - **drop_last** (`bool`, optional):
        Whether to drop the last incomplete batch. Defaults to `True`.

    ### Attributes
    - **dataset** (`MelissaIterableDataset`):
        The dataset being wrapped for batching and loading.
    - **batch_size** (`int`):
        Size of each batch produced by the data loader.
    - **collate_fn** (`Optional[Callable]`):
        The function used to collate samples into batches.
    - **num_workers** (`int`):
        Number of worker threads for parallel data loading.
    - **drop_last** (`bool`):
        Indicates if incomplete batches are dropped.
    - **_queue** (`queue.Queue`):
        An internal buffer to hold preloaded samples during multi-threaded loading.
    - **_stop_event** (`threading.Event`):
        A flag to signal worker threads to stop loading data.
    - **_threads** (`List[threading.Thread]`):
        List of worker threads for parallel data loading."""

    def __init__(self,
                 dataset: MelissaIterableDataset,
                 batch_size: int,
                 collate_fn: Optional[Callable] = None,
                 num_workers: int = 0,
                 drop_last: bool = True) -> None:

        self.dataset = dataset
        self.batch_size = batch_size
        self.collate_fn = collate_fn
        self.num_workers = num_workers
        self.drop_last = drop_last

        # for multi-threaded loading (if num_workers > 0)
        self._queue: queue.Queue = queue.Queue(maxsize=10 * batch_size)
        self._stop_event: threading.Event = threading.Event()
        self._threads: List[threading.Thread] = []

    def _worker_loop(self) -> None:
        """Worker thread loop to fetch data from the dataset and enqueue it."""

        for item in self.dataset:
            if self._stop_event.is_set():
                break
            self._queue.put(item)

    def __iter__(self) -> Iterable:
        """Iterate over the dataset and yield batches."""

        # start worker threads for parallel loading
        if self.num_workers > 0:
            self._stop_event.clear()
            self._threads = [
                threading.Thread(target=self._worker_loop, daemon=True)
                for _ in range(self.num_workers)
            ]
            for thread in self._threads:
                thread.start()

        batch = None
        dataset_iter = iter(self.dataset)  # single iterator reused for the whole loop
        try:
            while True:
                # fetch data from the dataset or the worker queue
                if self.num_workers > 0:
                    item = self._queue.get()
                else:
                    item = next(dataset_iter)

                if batch is None:
                    batch = [[] for _ in item]
                    # making the first dimension equal to the number of returns
                    # from `process_simulation_data` defined by the user
                for i, it in enumerate(item):
                    batch[i].append(it)  # type: ignore

                if len(batch[0]) == self.batch_size:  # type: ignore
                    if self.collate_fn:
                        yield self.collate_fn(batch)
                    else:
                        yield batch
                    batch = None

        except StopIteration:
            pass
        finally:
            # handle any leftover batch
            if batch and not self.drop_last:
                if self.collate_fn:
                    yield self.collate_fn(batch)
                else:
                    yield batch

            # stop worker threads if running
            if self.num_workers > 0:
                self._stop_event.set()
                for thread in self._threads:
                    thread.join()

    def __len__(self) -> int:
        """`__len__` is not supported for infinite datasets."""
        raise TypeError("`__len__` is not supported for infinite datasets.")
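

# --- Illustrative usage sketch (assumption: not part of the original module) ---
# A minimal sketch of how `GeneralDataLoader` could batch a streaming
# `MelissaIterableDataset` with a custom collate function. The helper name
# `_example_batched_loop` and the `consume_batch` callback are hypothetical.
def _example_batched_loop(dataset: MelissaIterableDataset,
                          consume_batch: Callable) -> None:
    def stack_columns(batch):
        # `batch` is a list of per-field lists (one list per value returned by the
        # user's transform); a real collate_fn would typically stack each field
        # into a tensor, but this sketch just passes the lists through
        return batch

    loader = GeneralDataLoader(
        dataset=dataset,
        batch_size=4,
        collate_fn=stack_columns,
        num_workers=0,   # single-threaded: items are pulled directly from the dataset
        drop_last=True,
    )
    for batch in loader:
        consume_batch(batch)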