Coverage for melissa/server/deep_learning/reservoir.py: 89%
325 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-09-22 10:36 +0000
« prev ^ index » next coverage.py v7.2.7, created at 2023-09-22 10:36 +0000
1from typing import Optional, Tuple, Union, List, Deque
2import logging
3from time import time
4import random
5from dataclasses import dataclass
6from collections import deque, Counter
7import threading
8from queue import Full, Empty
10import numpy as np
12from melissa.server.simulation import SimulationData
13from melissa.types import GetProtocol, Threshold
15logger = logging.getLogger(__name__)
@dataclass
class Sample:
    """Reservoir entry pairing simulation data with a read counter.

    A ``Sample`` wraps one piece of simulation data together with the number
    of times it has been seen.
    """

    # Payload held by the reservoir.
    data: SimulationData
    # How many times this sample has been seen; a fresh sample starts at zero.
    seen: int = 0
class PutGetMetric:
    """Lock-protected counter monitoring the balance between puts and gets.

    ``val`` is raised by :meth:`inc` and lowered by :meth:`dec`; both
    updates happen while holding ``increment_lock``.
    """

    def __init__(self, val: int):
        self.increment_lock: threading.Lock = threading.Lock()
        self.val: int = val

    def inc(self, val):
        """Add ``val`` to the metric while holding the lock."""
        with self.increment_lock:
            current = self.val
            self.val = current + val

    def dec(self, val):
        """Subtract ``val`` from the metric while holding the lock."""
        with self.increment_lock:
            current = self.val
            self.val = current - val
class NotEnoughData(Empty):
    """Not enough data in the queue.

    Subclasses ``queue.Empty`` so that callers catching ``Empty`` also
    catch this exception.
    """
class BaseQueue:
    """A queue inspired by the Lib/queue.py,
    but with only the functionality needed, hence no task_done feature.
    See https://github.com/python/cpython/blob/main/Lib/queue.py#L28

    All public operations synchronise on ``mutex`` (a re-entrant lock) and
    on the two conditions ``not_empty`` and ``not_full`` that share it.
    Subclasses customise storage and policy through the underscore-prefixed
    hooks (``_init_queue``, ``_size``, ``_is_sampling_ready``, ``_is_full``,
    ``_on_full``, ``_get`` and ``_put``).
    """

    def __init__(self, maxsize: int = 0):
        """Create the queue; a ``maxsize`` of 0 or less disables the capacity check."""
        self.maxsize = maxsize
        self.mutex = threading.RLock()
        self.not_empty = threading.Condition(self.mutex)
        self.not_full = threading.Condition(self.mutex)
        self._init_queue()
        self.put_get_metric: PutGetMetric = PutGetMetric(0)

    def _init_queue(self):
        """Create the underlying container."""
        self.queue: Deque[SimulationData] = deque()

    def save_state(self):
        """Return a dict holding the underlying container (not a copy)."""
        return {"queue": self.queue}

    def load_from_state(self, state: dict):
        """Replace the underlying container with ``state["queue"]``."""
        self.queue = state["queue"]

    def _size(self) -> int:
        """Number of stored items; the caller is expected to hold ``mutex``."""
        return len(self.queue)

    def _is_sampling_ready(self) -> bool:
        """Tell whether ``get`` may proceed, i.e. at least one item is stored."""
        return self._size() > 0

    def __len__(self):
        with self.mutex:
            return self._size()

    def empty(self):
        """Return True when no item is stored."""
        with self.mutex:
            return self._size() == 0

    def _is_full(self) -> bool:
        """Tell whether the number of stored items has reached ``maxsize``."""
        return self._size() >= self.maxsize

    def _on_full(self, block, timeout) -> bool:
        """React to a put on a full queue.

        Returns:
            True when the caller should go on and add the item (after
            re-checking fullness), False when the item must be dropped.

        Raises:
            Full: when ``block`` is false.
        """
        if not block:
            raise Full
        else:
            self.not_full.wait(timeout)
        return True

    def _get(self) -> SimulationData:
        """Remove and return the oldest item."""
        return self.queue.popleft()

    def _put(self, item: SimulationData):
        """Append ``item`` to the container."""
        return self.queue.append(item)

    def get(self, block: bool = True, timeout: Optional[float] = None):
        """Remove and return an item once sampling is ready.

        Args:
            block: wait for readiness when True, fail immediately otherwise.
            timeout: maximum wait in seconds; ``None`` waits indefinitely.

        Raises:
            NotEnoughData: sampling is not ready (non-blocking call) or did
                not become ready before the timeout elapsed.
            ValueError: ``timeout`` is negative.
        """
        with self.not_empty:
            if not block:
                if not self._is_sampling_ready():
                    raise NotEnoughData
            elif timeout is None:
                while not self._is_sampling_ready():
                    self.not_empty.wait()
            elif timeout < 0:
                raise ValueError("'timeout' must be a non-negative number")
            else:
                endtime = time() + timeout
                while not self._is_sampling_ready():
                    remaining = endtime - time()
                    if remaining <= 0.0:
                        raise NotEnoughData
                    self.not_empty.wait(remaining)
            item = self._get()
            self.not_full.notify()
            self.put_get_metric.dec(1)
            return item

    def put(self, item, block: bool = True, timeout: Optional[float] = None):
        """Add ``item``, applying the ``_on_full`` policy when the queue is full.

        Args:
            item: the element to store.
            block: wait for room when True, otherwise apply ``_on_full`` once.
            timeout: maximum wait in seconds; ``None`` waits indefinitely.

        Raises:
            Full: raised by ``_on_full`` or when the timeout elapses while
                the queue is still full.
            ValueError: ``timeout`` is negative.
        """
        add_item = True  # Do we actually add the item to the queue
        with self.not_full:
            if self.maxsize > 0:
                if not block:
                    if self._is_full():
                        add_item = self._on_full(block=block, timeout=timeout)
                elif timeout is None:
                    # Stop as soon as the policy asks to drop the item:
                    # otherwise a policy that returns False without freeing
                    # room would be polled again until it gives in.
                    while add_item and self._is_full():
                        add_item = self._on_full(block=block, timeout=timeout)
                elif timeout < 0:
                    raise ValueError("'timeout' must be a non-negative number")
                else:
                    endtime = time() + timeout
                    while add_item and self._is_full():
                        remaining = endtime - time()
                        if remaining <= 0.0:
                            raise Full
                        add_item = self._on_full(block=block, timeout=remaining)
            if add_item:
                self._put(item)
                self.put_get_metric.inc(1)
            if self._is_sampling_ready():
                self.not_empty.notify()

    def compute_buffer_statistics(self) -> Tuple[np.ndarray, np.ndarray]:
        """Not implemented by the base queue."""
        raise NotImplementedError
class CounterMixin:
    """A Mixin counting how many times each sample has been seen.

    Items are wrapped in :class:`Sample` on ``put``. ``seen_ctr`` is a
    histogram keyed by seen-count, updated each time a sample leaves the
    container.
    """

    def __init__(self):
        self.seen_ctr = Counter()

    def put(self, item: SimulationData, block: bool = True, timeout: Optional[float] = None):
        """Wrap ``item`` in a ``Sample`` and forward it to the next ``put``."""
        super().put(Sample(item), block, timeout)  # type: ignore

    def _get_with_eviction(self, index: Union[List, int]):
        """Evict the sample(s) at ``index``, record their seen-count, return the data."""
        fetched: Union[List[Sample], Sample] = super()._get_with_eviction(  # type: ignore
            index
        )
        if not isinstance(fetched, list):
            # Single sample: it is leaving, so only the histogram is updated.
            self.seen_ctr += Counter([fetched.seen + 1])
            return fetched.data
        payloads = []
        seen_counts = []
        for sample in fetched:
            sample.seen += 1
            seen_counts.append(sample.seen)
            payloads.append(sample.data)
        self.seen_ctr += Counter(seen_counts)
        return payloads

    def _get_without_eviction(self, index: Union[List, int]):
        """Read the sample(s) at ``index`` in place, bumping their seen-count."""
        fetched: Union[List[Sample], Sample] = super()._get_without_eviction(  # type: ignore
            index
        )
        if not isinstance(fetched, list):
            fetched.seen += 1
            return fetched.data
        for sample in fetched:
            sample.seen += 1
        return [sample.data for sample in fetched]

    def _get(self):
        """Delegate to the next ``_get``; unwrap the sample for plain queues."""
        parent = super()
        if hasattr(parent, "_get_without_eviction"):
            # Index-based queues go through the eviction hooks above, which
            # already unwrap the samples.
            return parent._get()  # type: ignore
        sample: Sample = parent._get()  # type: ignore
        self.seen_ctr += Counter([sample.seen + 1])
        return sample.data

    def __repr__(self) -> str:
        base = super().__repr__()
        return f"{base}: {self.seen_ctr}"
class ReceptionDependant:
    """A Mixin recording that the reception of data is over.

    Especially useful to empty the reservoir.
    """

    def __init__(self):
        # Flipped to True by signal_reception_over().
        self._is_reception_over = False

    def signal_reception_over(self):
        """Mark the reception as finished."""
        self._is_reception_over = True
class SamplingDependant:
    """A Mixin answering whether the reservoir is ready for sampling.

    Especially useful with thresholds. This base answer is always True.
    """

    def _is_sampling_ready(self) -> bool:
        """Return True unconditionally."""
        return True
class ThresholdMixin(SamplingDependant, ReceptionDependant):
    """A Mixin that blocks when not enough data are available in the container.

    Sampling is ready only while the container size is strictly greater
    than ``threshold``.
    """

    def __init__(self, threshold: int):
        ReceptionDependant.__init__(self)
        self.threshold = threshold

    def _is_sampling_ready(self: Threshold) -> bool:
        """Combine the next readiness check with the strict size threshold."""
        return super()._is_sampling_ready() and (self._size() > self.threshold)  # type: ignore

    def signal_reception_over(self):
        """Record the end of reception, drop the threshold to 0 and wake one waiter."""
        with self.mutex:
            super().signal_reception_over()
            # No more data will come: remaining items must become readable.
            self.threshold = 0
            self.not_empty.notify()
class ReadingWithoutEvictionMixin(ReceptionDependant):
    """A Mixin that gets data without eviction until the reception is over.

    Once the reception is over, reads evict the items they return.
    """

    def _get_from_index(self: GetProtocol, index: Union[list, int]):
        """Dispatch ``index`` to the evicting or non-evicting reader."""
        reader = (
            self._get_with_eviction
            if self._is_reception_over
            else self._get_without_eviction
        )
        return reader(index)
class RandomQueue(BaseQueue):
    """Queue that randomly reads items from the list container.
    It evicts on reading.

    Eviction on read only applies while ``pseudo_epochs <= 1``; with a
    larger value, items are read in place and stay in the container.
    """

    def __init__(self, maxsize: int, pseudo_epochs: int = 1):
        super().__init__(maxsize)
        self.pseudo_epochs = pseudo_epochs

    def _init_queue(self):
        """Use a list as container so items can be addressed by position."""
        self.queue: List[SimulationData] = []

    def _get_index(self) -> int:
        """Draw a uniformly random position in the container."""
        return random.randrange(self._size())

    def _get_without_eviction(self, index: Union[list, int]):
        """Return the item(s) at ``index`` and leave the container untouched.

        ``index`` is either a single position or a list of positions; the
        list form returns the items in the order given.
        """
        if isinstance(index, int):
            return self.queue[index]
        elif isinstance(index, list):
            return [self.queue[i] for i in index]

    def _get_with_eviction(self, index: Union[list, int]):
        """Return the item(s) at ``index`` and remove them from the container.

        ``index`` is either a single position or a list of positions; the
        list form returns the items in the order given.
        """
        if isinstance(index, int):
            item = self.queue[index]
            del self.queue[index]
            return item
        elif isinstance(index, list):
            items = [self.queue[i] for i in index]
            # Delete from the highest position down so that earlier deletions
            # do not shift the positions still to be removed.
            for i in sorted(index, reverse=True):
                del self.queue[i]
            return items

    def _get_from_index(self, index: int):
        """Read ``index`` with eviction unless several pseudo epochs are requested."""
        if self.pseudo_epochs <= 1:
            return self._get_with_eviction(index)
        return self._get_without_eviction(index)

    def _get(self):
        """Read the item at a random position."""
        index = self._get_index()
        return self._get_from_index(index)
class ReservoirQueue(ReadingWithoutEvictionMixin, RandomQueue):
    """Queue implementing the reservoir sampling algorithm.

    ``put_ctr`` counts the completed calls to :meth:`put`. When the queue is
    full, the incoming item replaces a stored one with probability
    ``maxsize / (put_ctr + 1)`` and is dropped otherwise.
    """

    def __init__(self, maxsize: int, queue: Optional[Deque] = None):
        # ``queue`` is accepted but not used by this constructor.
        ReadingWithoutEvictionMixin.__init__(self)
        RandomQueue.__init__(self, maxsize)
        self.put_ctr: int = 0

    def _evict(self, index: int):
        """Remove the item stored at ``index``."""
        del self.queue[index]

    def _on_full(self, block, timeout) -> bool:
        """Decide whether the incoming item replaces a stored one.

        Returns:
            True after evicting a stored item to make room, False when the
            incoming item must be dropped.
        """
        # The draw must range over every item offered so far INCLUDING the
        # incoming one, hence ``put_ctr + 1``. Drawing over ``put_ctr`` only
        # would always accept the first overflowing item and would raise
        # ValueError when ``put_ctr`` is 0.
        index = random.randrange(self.put_ctr + 1)
        if index < self.maxsize:
            self._evict(index)
            return True
        return False

    def put(self, item, block=True, timeout=None):
        """Offer ``item`` to the reservoir and count the call once it returns."""
        with self.mutex:
            RandomQueue.put(self, item, block, timeout)
            self.put_ctr += 1
class RandomEvictOnWriteQueue(ReadingWithoutEvictionMixin, RandomQueue):
    """A queue that overwrites random samples that have been already seen.

    Items live in two lists: ``not_seen`` (never read) followed by ``seen``
    (read at least once). A position ``i`` addresses ``not_seen[i]`` when
    ``i < len(not_seen)`` and ``seen[i - len(not_seen)]`` otherwise.
    """

    def __init__(self, maxsize: int, queue: Optional[Deque] = None):
        # ``queue`` is accepted but not used by this constructor.
        ReadingWithoutEvictionMixin.__init__(self)
        RandomQueue.__init__(self, maxsize)

    def _init_queue(self):
        """Create the two containers."""
        self.not_seen: list = []
        self.seen: list = []

    def save_state(self):
        """Return a dict holding both containers (not copies)."""
        return {"not_seen": self.not_seen, "seen": self.seen}

    def load_from_state(self, state: dict):
        """Replace both containers with the ones stored in ``state``."""
        self.not_seen = state["not_seen"]
        self.seen = state["seen"]

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}:"
            f"not yet seen samples {len(self.not_seen)}, already seen {len(self.seen)}"
        )

    def _size(self) -> int:
        """Total number of stored items across both containers."""
        return len(self.not_seen) + len(self.seen)

    def _evict(self, index):
        """Remove the already-seen item at ``index``."""
        del self.seen[index]

    def _on_full(self, block, timeout) -> bool:
        """Make room by evicting a random already-seen item.

        When nothing has been seen yet, a non-blocking call raises ``Full``
        and a blocking call waits on ``not_full``.

        Returns:
            Always True; the caller re-checks whether the queue is still full.

        Raises:
            Full: non-blocking call while no item has been seen yet.
        """
        if len(self.seen) == 0:
            if not block:
                raise Full
            timed_in = self.not_full.wait(timeout)
            # Hand control back on timeout, and also when the wake-up left
            # ``seen`` empty: ``random.randrange(0)`` would raise ValueError.
            if not timed_in or len(self.seen) == 0:
                return True
        index = random.randrange(len(self.seen))
        self._evict(index)
        return True

    def _get_with_eviction(self, index: Union[list, int]):
        """Return the item(s) at ``index`` and remove them from their container.

        For a list of positions, the result holds the not-seen items first,
        then the seen ones, each group by decreasing position.
        """
        if isinstance(index, int):
            if index < len(self.not_seen):
                item = self.not_seen[index]
                del self.not_seen[index]
            else:
                index = index - len(self.not_seen)
                item = self.seen[index]
                del self.seen[index]
            return item
        elif isinstance(index, list):
            not_seen_index = sorted([i for i in index if i < len(self.not_seen)], reverse=True)
            seen_index = sorted(
                [i - len(self.not_seen) for i in index if i >= len(self.not_seen)], reverse=True
            )
            items = [self.not_seen[i] for i in not_seen_index]
            items += [self.seen[i] for i in seen_index]
            for i in not_seen_index:
                del self.not_seen[i]
            for i in seen_index:
                del self.seen[i]
            return items

    def _get_without_eviction(self, index: Union[list, int]):
        """Return the item(s) at ``index``, moving not-seen items to ``seen``.

        Nothing leaves the queue: items read from ``not_seen`` are appended
        to ``seen``; items already in ``seen`` stay where they are.
        """
        if isinstance(index, int):
            if index < len(self.not_seen):
                item = self.not_seen[index]
                del self.not_seen[index]
                self.seen.append(item)
            else:
                index = index - len(self.not_seen)
                item = self.seen[index]
            return item

        elif isinstance(index, list):
            # Split on the length measured BEFORE any move. Moved items are
            # appended to the end of ``seen``, so the translated seen
            # positions stay valid.
            boundary = len(self.not_seen)
            # Highest position first, so each deletion leaves the positions
            # still to be processed untouched.
            not_seen_index = sorted([i for i in index if i < boundary], reverse=True)
            seen_index = [i - boundary for i in index if i >= boundary]
            items = []
            for i in not_seen_index:
                item = self.not_seen[i]
                items.append(item)
                del self.not_seen[i]
                self.seen.append(item)
            items += [self.seen[i] for i in seen_index]
            return items

    def _put(self, item):
        """Store ``item`` as not yet seen."""
        self.not_seen.append(item)
class BatchGetMixin(SamplingDependant, ReceptionDependant):
    """A Mixin that gets data from the reservoir as batches instead of individual sampling.

    While data is still being received, a batch is only served when at least
    ``batch_size`` items are stored; afterwards the last batches may be smaller.
    """

    def __init__(self, batch_size: int):
        ReceptionDependant.__init__(self)
        self.batch_size = batch_size

    def _is_sampling_ready(self):
        """Require a full batch to be available until the reception is over."""
        ready = super()._is_sampling_ready()
        if not self._is_reception_over:
            return ready and (self._size() >= self.batch_size)
        # No more data will arrive, we may not be able to serve batch_size data
        return ready

    def _get(self):
        """Draw a batch of distinct random positions and read them."""
        if self._is_reception_over:
            population = min(self.batch_size, self._size())
        else:
            population = self.batch_size
        # Positions are handed over in decreasing order.
        drawn = random.sample(range(self._size()), k=population)
        indices = sorted(drawn, reverse=True)
        return self._get_from_index(indices)
class FIFO(CounterMixin, ReceptionDependant, BaseQueue):
    """First In First Out.

    Combines the seen-counter, the reception flag and the base queue.
    """

    def __init__(self, maxsize: int = 0):
        # Each base is initialised explicitly, in this order.
        CounterMixin.__init__(self)
        ReceptionDependant.__init__(self)
        BaseQueue.__init__(self, maxsize)
class FIRO(CounterMixin, ThresholdMixin, RandomQueue):
    """First In Random Out.

    Combines the seen-counter, the sampling threshold and the random queue.
    ``threshold`` must not exceed ``maxsize``.
    """

    def __init__(self, maxsize: int, threshold: int, pseudo_epochs: int = 1):
        assert threshold <= maxsize
        # Each base is initialised explicitly, in this order.
        CounterMixin.__init__(self)
        RandomQueue.__init__(self, maxsize, pseudo_epochs)
        ThresholdMixin.__init__(self, threshold)
class Reservoir(CounterMixin, ThresholdMixin, RandomEvictOnWriteQueue):
    """Evict-on-write random queue with a seen-counter and a sampling threshold.

    ``threshold`` must not exceed ``maxsize``.
    """

    def __init__(self, maxsize: int, threshold: int):
        assert threshold <= maxsize
        # Each base is initialised explicitly, in this order.
        CounterMixin.__init__(self)
        ThresholdMixin.__init__(self, threshold)
        RandomEvictOnWriteQueue.__init__(self, maxsize)
class BatchReservoir(BatchGetMixin, Reservoir):
    """A :class:`Reservoir` read in batches of ``batch_size`` items.

    Neither ``threshold`` nor ``batch_size`` may exceed ``maxsize``.
    """

    def __init__(self, maxsize: int, threshold: int, batch_size: int):
        assert threshold <= maxsize
        assert batch_size <= maxsize
        # The reservoir is set up first, then the batch size.
        Reservoir.__init__(self, maxsize, threshold)
        BatchGetMixin.__init__(self, batch_size)