Coverage for melissa/server/deep_learning/reservoir.py: 63%
519 statements
1"""This module defines all the buffer classes."""
3import logging
4import random
5import threading
6import numpy as np
7from scipy.stats import norm
8from collections import Counter, deque
9from dataclasses import dataclass
10from enum import Enum
11from queue import Empty, Full
12from time import time
13from typing import Any, Dict, Deque, List, Optional, Tuple, Union
14from numpy.typing import NDArray
16from melissa.server.simulation import SimulationData
17from melissa.types import QueueProtocol
19logger = logging.getLogger(__name__)
22@dataclass
23class Sample:
24 """A Sample represents an item contained in the reservoir.
25 It is associated with simulation data and a seen counter."""
27 data: SimulationData
28 seen: int = 0
30 def __repr__(self) -> str:
31 """Returns a string representation of the `Sample` object."""
32 s = (
33 f"<{self.__class__.__name__} "
34 f"sim-id={self.data.simulation_id} "
35 f"time-step={self.data.time_step} "
36 f"fields={list(self.data.payload.keys())} "
37 f"seen={self.seen}>"
38 )
39 return s
42class PutGetMetric:
43 """A class to monitor the balance between puts and gets in the reservoir."""
45 def __init__(self, val: int) -> None:
47 self.val: int = val
48 self.increment_lock: threading.Lock = threading.Lock()
50 def inc(self, val: int) -> None:
51 """Increments the counter."""
52 with self.increment_lock:
53 self.val = self.val + val
55 def dec(self, val: int) -> None:
56 """Decrements the counter."""
57 with self.increment_lock:
58 self.val = self.val - val
61class NotEnoughData(Empty):
62 "Not enough data in the queue"
65class BaseQueue:
66 """A queue inspired by Lib/queue.py,
67 but exposing only the functionality needed here, hence no task_done feature.
68 See https://github.com/python/cpython/blob/main/Lib/queue.py#L28"""
70 def __init__(self, maxsize: int = 0) -> None:
72 self.maxsize: int = maxsize
73 self.mutex: threading.RLock = threading.RLock()
74 self.not_empty: threading.Condition = threading.Condition(self.mutex)
75 self.not_full: threading.Condition = threading.Condition(self.mutex)
76 self._init_queue()
77 self.put_get_metric: PutGetMetric = PutGetMetric(0)
79 def _init_queue(self) -> None:
80 """Initialize the queue with a specific data-structure.
81 Unique to subclasses."""
82 self.queue: Deque = deque()
84 def save_state(self) -> Dict[str, Any]:
85 """Returns the current state of the queue."""
86 return {"queue": self.queue}
88 def load_from_state(self, state: Dict[str, Any]) -> None:
89 """Loads the queue from the previous state."""
90 self.queue = state["queue"]
92 def _size(self) -> int:
93 """Returns the current size of the queue."""
94 return len(self.queue)
96 def _is_sampling_ready(self) -> bool:
97 return self._size() > 0
99 def __len__(self) -> int:
100 with self.mutex:
101 return self._size()
103 def empty(self) -> bool:
104 with self.mutex:
105 return self._size() == 0
107 def _is_full(self) -> bool:
108 return self._size() >= self.maxsize
110 def _on_full(self, block, timeout) -> bool:
111 """Handles the behavior when the queue is full,
112 blocking the operation or raising an exception.
114 ### Parameters
115 - **block** (`bool`): Indicates whether the operation should block
116 until the queue is not full.
117 - **timeout** (`float`): The maximum time (in seconds) to wait if blocking is enabled.
119 ### Returns
120 - `bool`: `True` if the item should still be added after waiting; subclasses
121 may return `False` to drop the item instead."""
123 if not block:
124 raise Full
125 self.not_full.wait(timeout)
126 return True
128 def _get(self) -> Union[List[SimulationData], SimulationData]:
129 """Returns the `SimulationData` object from the queue."""
130 return self.queue.popleft()
132 def _put(self, item: SimulationData) -> None:
133 """Inserts `SimulationData` object in the queue."""
134 self.queue.append(item)
136 def get(
137 self, block: bool = True, timeout: Optional[float] = None
138 ) -> Union[List[SimulationData], SimulationData]:
139 """Retrieves an item from the queue, blocking if necessary until data is available.
141 ### Parameters
142 - **block** (`bool`, default=`True`):
143 - If `True`, the method will block until data is available or the timeout expires.
144 - If `False`, raises `NotEnoughData` if data is not immediately available.
145 - **timeout** (`Optional[float]`):
146 - Maximum time (in seconds) to wait for data if `block` is `True`.
147 - If `None`, waits indefinitely.
148 - Must be a non-negative number if provided.
150 ### Returns
151 - `Union[List[SimulationData], SimulationData]`: The retrieved item from the queue.
153 ### Raises
154 - `NotEnoughData`: If the queue does not have enough data to retrieve:
155 - When `block` is `False` and data is unavailable.
156 - When `block` is `True` and the timeout expires.
157 - `ValueError`: If `timeout` is a negative number."""
159 with self.not_empty:
160 if not block:
161 if not self._is_sampling_ready():
162 raise NotEnoughData
163 elif timeout is None:
164 while not self._is_sampling_ready():
165 self.not_empty.wait()
166 elif timeout < 0:
167 raise ValueError("'timeout' must be a non-negative number")
168 else:
169 endtime = time() + timeout
170 while not self._is_sampling_ready():
171 remaining = endtime - time()
172 if remaining <= 0.0:
173 raise NotEnoughData
174 self.not_empty.wait(remaining)
175 item = self._get()
176 self.not_full.notify()
177 self.put_get_metric.dec(1)
178 return item
180 def put(
181 self, item: SimulationData, block: bool = True, timeout: Optional[float] = None
182 ) -> None:
183 """Adds an item to the queue, blocking if necessary until space is available.
185 - If the queue has a finite `maxsize` and is full:
186 - If `block` is `False`, checks whether the item can be
187 conditionally added based on the `_on_full` method.
188 - If `block` is `True`, waits until space is available or the timeout expires.
189 - Adds the item to the queue if space is available or the `_on_full` method allows it.
190 - Notifies waiting threads if the queue is no longer empty after the addition.
192 ### Parameters
193 - **item** (`SimulationData`): The item to be added to the queue.
194 - **block** (`bool`, default=`True`):
195 - If `True`, the method will block until space becomes available or the timeout expires.
196 - If `False`, raises `Full` if the queue is full.
197 - **timeout** (`Optional[float]`):
198 - Maximum time (in seconds) to wait for space in the queue if `block` is `True`.
199 - If `None`, waits indefinitely.
200 - Must be a non-negative number if provided.
202 ### Raises
203 - `Full`:
204 - If the queue is full, `block` is `False`, and the item cannot be conditionally added.
205 - If the queue is full, `block` is `True`, and the timeout expires.
206 - `ValueError`: If `timeout` is a negative number."""
208 add_item = True # Do we actually add the item to the queue
209 with self.not_full:
210 if self.maxsize > 0:
211 if not block:
212 if self._is_full():
213 add_item = self._on_full(block=block, timeout=timeout)
214 elif timeout is None:
215 while self._is_full():
216 add_item = self._on_full(block=block, timeout=timeout)
217 elif timeout < 0:
218 raise ValueError("'timeout' must be a non-negative number")
219 else:
220 endtime = time() + timeout
221 while self._is_full():
222 remaining = endtime - time()
223 if remaining <= 0.0:
224 raise Full
225 add_item = self._on_full(block=block, timeout=remaining)
226 if add_item:
227 self._put(item)
228 self.put_get_metric.inc(1)
229 if self._is_sampling_ready():
230 self.not_empty.notify()
232 def compute_buffer_statistics(self) -> Tuple[NDArray, NDArray]:
233 """Not needed."""
234 raise NotImplementedError
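As a quick illustration of the blocking contract above (a sketch only; `data` stands for an arbitrary `SimulationData` instance and the sizes are made up):

q = BaseQueue(maxsize=8)
q.put(data)                    # blocks while full (subclasses may evict instead)
try:
    item = q.get(block=False)  # raises NotEnoughData when nothing is ready
except NotEnoughData:
    item = q.get(timeout=1.0)  # waits up to 1 s, then raises NotEnoughData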
237class CounterMixin:
238 """A Mixin to track the number of times each sample has been seen in the reservoir.
240 This class adds counting functionality to the base class, using a `Counter` to track
241 how many times each sample has been added to or retrieved from the reservoir."""
243 def __init__(self):
245 self.seen_ctr = Counter()
247 def put(
248 self, item: SimulationData, block: bool = True, timeout: Optional[float] = None
249 ) -> None:
250 """Adds an item to the reservoir and tracks how many times it has been seen."""
252 watched_item = Sample(item)
253 super().put(watched_item, block, timeout) # type: ignore
255 def _get_with_eviction(
256 self, index: Union[List, int]
257 ) -> Union[List[SimulationData], SimulationData]:
258 """Retrieves items from the reservoir, evicting them if necessary,
259 and updates their seen counts."""
261 items: Union[List[Sample], Sample] = super()._get_with_eviction( # type: ignore
262 index
263 )
264 if isinstance(items, list):
265 item_data = []
266 item_seen = []
267 for item in items:
268 item.seen += 1
269 item_seen.append(item.seen)
270 item_data.append(item.data)
271 self.seen_ctr += Counter(item_seen)
272 return item_data
273 self.seen_ctr += Counter([items.seen + 1])
275 return items.data
277 def _get_without_eviction(
278 self, index: Union[List, int]
279 ) -> Union[List[SimulationData], SimulationData]:
280 """Retrieves items from the reservoir without eviction, updating their seen counts."""
282 items: Union[List[Sample], Sample] = super()._get_without_eviction( # type: ignore
283 index
284 )
285 if isinstance(items, list):
286 for item in items:
287 item.seen += 1
288 return [item.data for item in items]
289 items.seen += 1
290 return items.data
292 def _get(self) -> Union[List[SimulationData], SimulationData]:
293 """Retrieves a single item from the reservoir and updates its seen count."""
295 if hasattr(super(), "_get_without_eviction"):
296 return super()._get() # type: ignore
298 item: Sample = super()._get() # type: ignore
299 self.seen_ctr += Counter([item.seen + 1])
300 return item.data
302 def __repr__(self) -> str:
303 s = super().__repr__()
304 return f"{s}: {self.seen_ctr}"
307class ReceptionDependant:
308 """A Mixin to signal when the reception of data is over.
310 It is particularly useful for cases where actions depend on the completion
311 of data reception, such as emptying or processing the contents of a reservoir. It
312 provides a method to indicate that the reception process is complete, allowing other
313 parts of the system to act accordingly.
315 ### Attributes:
316 - **_is_reception_over** (`bool`): A flag indicating whether data reception is over.
317 """
319 def __init__(self) -> None:
321 self._is_reception_over = False
323 def signal_reception_over(self) -> None:
324 self._is_reception_over = True
327class SamplingDependant:
328 """A Mixin to receive the signal that the reservoir is ready for sampling.
329 Especially useful with thresholds."""
331 def _is_sampling_ready(self) -> bool:
332 return True
335class ThresholdMixin(SamplingDependant, ReceptionDependant):
336 """A Mixin that blocks operations when not enough data is available in the container,
337 based on a threshold.
339 This Mixin is used to manage situations where certain operations should be blocked or delayed
340 until the data in the container reaches a specified threshold. It extends the functionality
341 of both `SamplingDependant` and `ReceptionDependant` mixins, and adds a threshold-based
342 control mechanism for readiness.
344 ### Attributes
345 - **threshold** (`int`): The minimum number of items required in the container
346 for sampling to proceed."""
348 def __init__(self, threshold: int) -> None:
350 ReceptionDependant.__init__(self)
351 self.threshold = threshold
353 def _is_sampling_ready(self: QueueProtocol) -> bool:
354 """Determines if sampling is ready based on the threshold
355 and available data in the container."""
357 is_ready = super()._is_sampling_ready() and (self._size() > self.threshold) # type: ignore
358 return is_ready
360 def set_threshold(self, t: int) -> None:
361 """Sets a new threshold for when sampling is considered ready."""
362 self.threshold = t
364 def signal_reception_over(self: QueueProtocol) -> None:
365 """Signals that the reception of data is complete,
366 sets the threshold to 0, and notifies any waiting processes."""
367 with self.mutex:
368 super().signal_reception_over() # type: ignore
369 self.set_threshold(0)
370 self.not_empty.notify()
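A sketch of the threshold gate in practice, using the `FIRO` queue defined further below (`incoming` is a hypothetical iterable of `SimulationData`): `get()` blocks while the buffer holds `threshold` or fewer items, and `signal_reception_over()` lowers the threshold to 0 so the remaining items can be drained.

buf = FIRO(maxsize=100, threshold=10)
for d in incoming:           # suppose only 5 items ever arrive
    buf.put(d)
# buf.get() would block here: 5 items is not above the threshold of 10
buf.signal_reception_over()  # threshold set to 0, waiting consumers are notified
item = buf.get()             # now returns one of the buffered items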
373class ReadingWithoutEvictionMixin(ReceptionDependant):
374 """A Mixin that retrieves data without eviction until data reception is over.
376 This Mixin allows data to be retrieved from the container without eviction,
377 but once data reception is marked as complete, it switches to eviction-based retrieval.
378 It extends the `ReceptionDependant` Mixin to track the reception status."""
380 def _get_from_index(
381 self: QueueProtocol, index: Union[list, int]
382 ) -> Union[List[SimulationData], SimulationData]:
383 """Retrieves data from the container either with or without eviction,
384 depending on whether the reception process has been completed."""
386 if self._is_reception_over:
387 return self._get_with_eviction(index)
388 return self._get_without_eviction(index)
391class RandomQueue(BaseQueue):
392 """A queue that randomly selects items from the container and evicts them upon retrieval.
394 This queue operates by randomly selecting an item from the list container. Upon reading,
395 the selected item is removed (evicted) from the queue. The queue supports a configurable
396 number of pseudo-epochs: items are read without eviction for that many passes over
397 the data, after which retrieval switches to eviction."""
399 def __init__(self, maxsize: int, pseudo_epochs: int = 1) -> None:
401 super().__init__(maxsize)
402 self.pseudo_epochs = pseudo_epochs
403 self.current_epoch: int = 0
404 self.epoch_access_count: int = 0
405 self.total_samples_in_epoch: int = 0
407 def _init_queue(self) -> None:
408 """Initializes the internal list (queue)."""
409 self.queue: List[SimulationData] = [] # type: ignore
411 def _get_index(self) -> int:
412 """Randomly selects an index from the queue."""
413 index = random.randrange(self._size())
414 return index
416 def _get_without_eviction(
417 self, index: Union[list, int]
418 ) -> Union[List[SimulationData], SimulationData]:
419 """Retrieves items without eviction."""
421 if isinstance(index, int):
422 return self.queue[index]
423 elif isinstance(index, list):
424 return [self.queue[i] for i in index]
425 return SimulationData(0, 0, {}, [])
427 def _get_with_eviction(
428 self, index: Union[list, int]
429 ) -> Union[List[SimulationData], SimulationData]:
430 """Retrieves items and evicts them."""
432 if isinstance(index, int):
433 item = self.queue[index]
434 del self.queue[index]
435 return item
436 if isinstance(index, list):
437 sorted_indices = sorted(index, reverse=True)
438 items = [self.queue[i] for i in index]
439 for i in sorted_indices:
440 del self.queue[i]
441 return items
443 def _get_from_index(
444 self, index: Union[List, int]
445 ) -> Union[List[SimulationData], SimulationData]:
446 """Retrieves data based on the current pseudo-epoch setting."""
448 if (
449 self.pseudo_epochs <= 1
450 or self.current_epoch >= self.pseudo_epochs
451 ):
452 return self._get_with_eviction(index)
454 # still in pseudo-epoch mode
455 result = self._get_without_eviction(index)
457 if isinstance(index, int):
458 self.epoch_access_count += 1
459 elif isinstance(index, list):
460 self.epoch_access_count += len(index)
462 self._check_epoch_completion()
464 return result
466 def _get(self) -> Union[List[SimulationData], SimulationData]:
467 """Retrieves a random item from the queue and evicts it, if necessary."""
469 index = self._get_index()
470 return self._get_from_index(index)
472 def _put(self, item: SimulationData) -> None:
473 """Inserts SimulationData object in the queue."""
474 self.queue.append(item)
476 # track total samples during first epoch
477 if self.pseudo_epochs > 1 and self.current_epoch == 0:
478 self.total_samples_in_epoch = len(self.queue)
480 def _check_epoch_completion(self) -> None:
481 """Check if we've completed an epoch and should start the next one."""
482 if self.pseudo_epochs <= 1:
483 return
485 # if we've seen all samples in this epoch
486 if self.epoch_access_count >= self.total_samples_in_epoch:
487 self.current_epoch += 1
488 self.epoch_access_count = 0
489 logger.debug(
490 f"Running epoch={self.current_epoch}/{self.pseudo_epochs}"
491 )
493 # if we've completed all pseudo epochs, now we can evict
494 if self.current_epoch >= self.pseudo_epochs:
495 logger.info("All pseudo epochs completed, switching to eviction mode.")
498class ReservoirQueue(ReadingWithoutEvictionMixin, RandomQueue):
499 """The `ReservoirQueue` supports adding items to the queue, evicting items when the queue
500 reaches its maximum size, and retrieving items without eviction during sampling."""
502 def __init__(self, maxsize: int) -> None:
504 ReadingWithoutEvictionMixin.__init__(self)
505 RandomQueue.__init__(self, maxsize)
506 self.put_ctr: int = 0
508 def _evict(self, index: int) -> None:
509 """Removes an item from the queue at the specified index."""
510 del self.queue[index]
512 def _on_full(self, block, timeout) -> bool:
513 """Handles eviction when the queue is full."""
515 index = random.randrange(self.put_ctr)
516 if index < self.maxsize:
517 self._evict(index)
518 return True
519 return False
521 def put(self, item, block=True, timeout=None) -> None:
522 """Adds an item to the queue, evicting items as needed."""
524 with self.mutex:
525 RandomQueue.put(self, item, block, timeout)
526 self.put_ctr += 1
529class RandomEvictOnWriteQueue(ReadingWithoutEvictionMixin, RandomQueue):
530 """A queue that overwrites random samples that have already been seen.
532 This queue implements a variant of reservoir sampling where items are randomly evicted
533 when the queue reaches its maximum size. The key difference is that once an item has been
534 "seen", it can be overwritten randomly, ensuring that older items are eventually evicted
535 while still maintaining randomness in the selection.
537 The queue maintains two lists: `not_seen` for items that have not been processed yet,
538 and `seen` for items that have already been processed. Items are added to the `not_seen`
539 list and moved to the `seen` list once they are accessed.
541 ### Attributes
542 - **not_seen** (`list`): A list holding items that have not yet been accessed.
543 - **seen** (`list`): A list holding items that have already been accessed."""
545 def __init__(self, maxsize: int) -> None:
547 ReadingWithoutEvictionMixin.__init__(self)
548 RandomQueue.__init__(self, maxsize)
550 def _init_queue(self) -> None:
551 """Initializes `seen` and `not_seen` lists."""
552 self.not_seen: list = []
553 self.seen: list = []
555 def save_state(self) -> Dict[str, List]:
556 """Saves the current state of the queue, including the `not_seen` and `seen` lists."""
557 return {"not_seen": self.not_seen, "seen": self.seen}
559 def load_from_state(self, state: Dict[str, List]) -> None:
560 """Loads the queue state from a saved state dictionary."""
562 self.not_seen = state["not_seen"]
563 self.seen = state["seen"]
565 def __repr__(self) -> str:
566 """Returns a string representation of the queue, showing the size of the
567 `not_seen` and `seen` lists."""
569 return (
570 f"{self.__class__.__name__}:"
571 f"not yet seen samples {len(self.not_seen)}, already seen {len(self.seen)}"
572 )
574 def _size(self) -> int:
575 """Returns the total size of the queue (sum of `not_seen` and `seen`)."""
576 return len(self.not_seen) + len(self.seen)
578 def _evict(self, index) -> None:
579 """Evicts the item at the given index from the `seen` list."""
580 del self.seen[index]
582 def _on_full(self, block, timeout) -> bool:
583 """Handles eviction when the queue is full, randomly evicting
584 an item from the `seen` list."""
586 if not block:
587 if len(self.seen) == 0:
588 raise Full
589 else:
590 if len(self.seen) == 0:
591 timed_in = self.not_full.wait(timeout)
592 if not timed_in:
593 return True
594 index = random.randrange(len(self.seen))
595 self._evict(index)
596 return True
598 def _get_with_eviction(
599 self, index: Union[list, int]
600 ) -> Union[List[SimulationData], SimulationData]:
601 """Retrieves an item by index, evicting it if necessary."""
603 if isinstance(index, int):
604 if index < len(self.not_seen):
605 item = self.not_seen[index]
606 del self.not_seen[index]
607 else:
608 index = index - len(self.not_seen)
609 item = self.seen[index]
610 del self.seen[index]
611 return item
613 if isinstance(index, list):
614 not_seen_index = sorted(
615 [i for i in index if i < len(self.not_seen)], reverse=True
616 )
617 seen_index = sorted(
618 [i - len(self.not_seen) for i in index if i >= len(self.not_seen)],
619 reverse=True,
620 )
621 items = [self.not_seen[i] for i in not_seen_index]
622 items += [self.seen[i] for i in seen_index]
623 for i in not_seen_index:
624 del self.not_seen[i]
625 for i in seen_index:
626 del self.seen[i]
627 return items
629 def _get_without_eviction(
630 self, index: Union[list, int]
631 ) -> Union[List[SimulationData], SimulationData]:
632 """Retrieves an item without eviction, moving it from
633 `not_seen` to `seen` if accessed."""
635 if isinstance(index, int):
636 if index < len(self.not_seen):
637 item = self.not_seen[index]
638 del self.not_seen[index]
639 self.seen.append(item)
640 else:
641 index = index - len(self.not_seen)
642 item = self.seen[index]
643 return item
645 if isinstance(index, list):
646 not_seen_index = sorted(
647 [i for i in index if i < len(self.not_seen)], reverse=True
648 )
649 seen_index = sorted(
650 [i - len(self.not_seen) for i in index if i >= len(self.not_seen)],
651 reverse=True,
652 )
653 items = []
654 for i in not_seen_index:
655 item = self.not_seen[i]
656 items.append(item)
657 del self.not_seen[i]
658 self.seen.append(item)
659 items += [self.seen[i] for i in seen_index]
660 return items
662 def _put(self, item) -> None:
663 """Adds an item to the `not_seen` list."""
664 self.not_seen.append(item)
666 def all_seen(self) -> bool:
667 """Returns `True` if there are no more items in the `not_seen` list,
668 indicating that all items have been processed."""
669 return len(self.not_seen) == 0
672class SimPairedQueue(ReadingWithoutEvictionMixin, BaseQueue):
673 """A specialized queue that maintains a 2D matrix of simulation data samples,
674 organized by simulation ID and time step.
676 This queue is designed to efficiently store and retrieve pairs of data points
677 from the same simulation trajectory at different time steps. It uses a probability-based
678 sampling mechanism that prioritizes time pairs with smaller time differences that
679 have been sampled less frequently, implementing a form of importance sampling.
681 The queue maintains a count of how many times each time delta (difference between time steps)
682 has been sampled, allowing it to adaptively focus on undersampled time differences.
684 ### Parameters:
685 - **maxsize** (`int`): Maximum number of items to store in the queue
686 - **sweep_size** (`int`): Number of distinct simulation trajectories
687 - **nb_time_steps** (`int`): Number of time steps per simulation trajectory
688 - **comm_size** (`int`): Number of parallel processes or threads
689 """
691 def __init__(
692 self, maxsize: int, sweep_size: int, nb_time_steps: int, comm_size: int
693 ) -> None:
695 self.nb_simulations = sweep_size // comm_size
696 self.nb_time_steps = nb_time_steps
697 self.comm_size = comm_size
698 ReadingWithoutEvictionMixin.__init__(self)
699 BaseQueue.__init__(self, maxsize)
701 def _init_queue(self) -> None:
702 """Initializes the queue with a specific data-structure."""
703 self.queue: NDArray = np.empty( # type: ignore
704 (self.nb_simulations, self.nb_time_steps), dtype=object
705 )
706 self.queue_status: NDArray = np.zeros(
707 (self.nb_simulations, self.nb_time_steps), dtype=np.int8
708 )
709 self.delta_counts: NDArray = np.zeros((self.nb_time_steps - 1), dtype=int)
710 self.sim_counts: NDArray = np.zeros((self.nb_simulations), dtype=int)
711 self.sim_reception: NDArray = np.zeros((self.nb_simulations), dtype=int)
712 self.avail_ts: NDArray = np.zeros((self.nb_simulations), dtype=int)
713 self.cumul_seen: NDArray = np.zeros((self.nb_simulations), dtype=int)
715 self.seen_ctr: Counter = Counter()
717 def save_state(self) -> Dict:
718 """Saves the current state of the queue and already seen time deltas."""
719 return {
720 "queue": self.queue,
721 "queue_status": self.queue_status,
722 "delta_counts": self.delta_counts,
723 "sim_counts": self.sim_counts,
724 "sim_reception": self.sim_reception,
725 }
727 def load_from_state(self, state: Dict) -> None:
728 """Loads the queue state from a saved state dictionary."""
730 self.queue = state["queue"]
731 self.queue_status = state["queue_status"]
732 self.delta_counts = state["delta_counts"]
733 self.sim_counts = state["sim_counts"]
734 self.sim_reception = state["sim_reception"]
735 self.avail_ts = np.array(
736 [(sim_queue is not None).sum() for sim_queue in self.queue], dtype=int # type: ignore
737 )
738 self.cumul_seen = (
739 np.vectorize(lambda x: x.seen if hasattr(x, "seen") else 0)(self.queue)
740 .sum(axis=1)
741 .astype(np.int32)
742 )
744 def is_queue_empty(self, sim_id: int) -> bool:
745 if (
746 self.avail_ts[sim_id] == 1
747 and self.sim_reception[sim_id] == self.nb_time_steps
748 ):
749 time_step = np.where(self.queue_status[sim_id] == 1)[0].item()
750 index = sim_id, time_step
751 self._evict(index)
752 return True
753 elif self.avail_ts[sim_id] == 0:
754 return True
755 else:
756 return False
758 def get_evicted_index(self, sim_id: int, std_factor: int = 4) -> Tuple[int, int]:
759 """Samples the (sim_id, time_step) index to evict, weighted by seen counts
and, when `std_factor` > 0, by a Gaussian favoring mid-trajectory steps."""
760 counts = np.vectorize(lambda x: x.seen if hasattr(x, "seen") else 0)(
761 self.queue[sim_id]
762 ).astype(np.float32)
763 if std_factor > 0:
764 mean = self.nb_time_steps / 2
765 normal_pdf_values = norm.pdf(
766 np.arange(self.nb_time_steps),
767 loc=mean,
768 scale=mean / std_factor,
769 )
770 logits = counts * normal_pdf_values
771 else:
772 logits = counts
773 time_step = np.random.choice(a=len(counts), p=logits / np.sum(logits))
774 return (sim_id, time_step)
776 def get_sampled_indices(self, sim_id: int) -> List[Tuple[int, int]]:
777 """
778 Sample a pair of indices from a queue based on their time differences, with a bias towards
779 smaller delta values after applying a transformation with a delta counter.
780 """
781 # Extract indices of non-None elements in the queue
782 (valid_ts,) = np.nonzero(self.queue[sim_id] != None)
783 assert len(valid_ts) == self.avail_ts[sim_id], (
784 f"[Inconsistent state] Queue for sim_id={sim_id}: "
785 f"avail_ts={self.avail_ts[sim_id]}, valid_ts={len(valid_ts)}"
786 )
788 # Get all possible pairs of indices (i,j) where i < j using upper triangular indices
789 i, j = np.triu_indices(len(valid_ts), k=1)
790 # Calculate the absolute time differences between all valid timestamps
791 deltas = np.abs(valid_ts[i] - valid_ts[j])
792 # Define a mapping function that transforms raw deltas using delta_counts lookup
793 mapped_deltas = np.take(self.delta_counts, deltas - 1)
794 # Sample an index based on the calculated weights
795 logits = -mapped_deltas
796 weights = np.exp(logits - np.max(logits))
797 sampled_idx = np.random.choice(len(mapped_deltas), p=weights / np.sum(weights))
798 # Return the sampled indices
799 t0, t1 = valid_ts[i[sampled_idx]], valid_ts[j[sampled_idx]]
800 return [(sim_id, t0), (sim_id, t1)]
802 def get_sim_logits(self, mode: str) -> Tuple[NDArray, NDArray]:
803 assert mode in ["evict", "sample"], "mode must be either 'evict' or 'sample'"
804 avail_sim = (
805 np.where(self.cumul_seen != 0)[0]
806 if mode == "evict"
807 else np.where(self.avail_ts > 1)[0]
808 )
809 recep_over_sim = np.where(self.sim_reception == self.nb_time_steps)[0]
810 intersec = np.intersect1d(avail_sim, recep_over_sim)
811 avail_sim = intersec if intersec.size > 0 else avail_sim
813 assert avail_sim.size > 0, " ".join([  # type: ignore
814 f"[Inconsistent state] mode={mode}",
815 f"avail_ts={self.avail_ts}",
816 f"sim_counts={self.sim_counts}",
817 f"avail_sim={avail_sim}",
818 f"recep_over_sim={recep_over_sim}",
819 f"is_reception_over={self._is_reception_over}",
820 ])
822 if mode == "evict":
823 logits = self.sim_counts[avail_sim]
824 elif mode == "sample":
825 logits = -(self.sim_counts[avail_sim] / self.avail_ts[avail_sim])
826 logits = np.exp(logits - np.max(logits))
827 return avail_sim, logits
829 def _evict_most_seen(self) -> None:
830 avail_sim, logits = self.get_sim_logits(mode="evict")
831 sim_id = np.random.choice(avail_sim, p=logits / np.sum(logits))
832 index = self.get_evicted_index(sim_id)
833 self._evict(index)
834 self.is_queue_empty(sim_id)
836 def put(
837 self, item: SimulationData, block: bool = True, timeout: Optional[float] = None
838 ) -> None:
839 """Adds an item to the reservoir and tracks how many times it has been seen."""
841 watched_item = Sample(item)
842 super().put(watched_item, block, timeout) # type: ignore
844 def _put(self, item: Sample) -> None: # type: ignore
845 """Adds an item to the queue according to its sim_id and time_step."""
846 sim_id, time_step = item.data.simulation_id, item.data.time_step
847 rank_sim_id = sim_id // self.comm_size
848 self.queue[rank_sim_id][time_step] = item
849 self.queue_status[rank_sim_id][time_step] = 1
850 self.sim_reception[rank_sim_id] += 1
851 self.avail_ts[rank_sim_id] += 1
853 def _size(self) -> int:
854 """Returns the total size of the queue."""
855 return np.sum(self.avail_ts)
857 def _evict(self, index) -> None:
858 """Evicts the item at the given index."""
859 sim_id, time_step = index
860 assert self.queue[sim_id][time_step] is not None, (
861 f"Evicting empty slot {sim_id},{time_step}"
862 )
863 self.seen_ctr += Counter([self.queue[sim_id][time_step].seen])
864 self.cumul_seen[sim_id] -= self.queue[sim_id][time_step].seen
865 self.queue[sim_id][time_step] = None
866 self.avail_ts[sim_id] -= 1
867 self.queue_status[sim_id][time_step] = -1
869 def _on_full(self, block, timeout) -> bool:
870 """Handles eviction when the queue is full, randomly evicting
871 an item from the buffer."""
873 if not block:
874 raise Full
875 else:
876 timed_in = self.not_full.wait(timeout)
877 if not timed_in:
878 return True
879 self._evict_most_seen()
880 return True
882 def _get_index(self) -> List[Tuple[int, int]]:
883 """Selects a pair of indices from the queue for the same simulation ID."""
884 avail_sim, logits = self.get_sim_logits(mode="sample")
885 sim_id = np.random.choice(avail_sim, p=logits / np.sum(logits))
886 indices = self.get_sampled_indices(sim_id)
887 self.sim_counts[sim_id] += 1
888 self.delta_counts[abs(indices[1][1] - indices[0][1]) - 1] += 1
889 self.cumul_seen[sim_id] += 2
890 return indices
892 def _get_with_eviction(
893 self, indices: List[Tuple[int, int]]
894 ) -> List[SimulationData]:
895 """Retrieves and evicts items at the specified indices."""
896 items = [self.queue[sim_id][time_step] for sim_id, time_step in indices]
897 for item in items:
898 item.seen += 1
899 self._evict_most_seen()
900 return [item.data for item in items]
902 def _get_without_eviction(
903 self, indices: List[Tuple[int, int]]
904 ) -> List[SimulationData]:
905 """Retrieves items at the specified indices."""
906 items = [self.queue[sim_id][time_step] for sim_id, time_step in indices]
907 for item in items:
908 item.seen += 1
909 return [item.data for item in items]
911 def _get(self) -> List[SimulationData]:
912 """Retrieves a random item from the queue and evicts it, if necessary."""
914 index = self._get_index()
915 return self._get_from_index(index) # type: ignore
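The pair selection in `get_sampled_indices` is essentially a softmax over negative `delta_counts`, so time differences that have been drawn less often get larger weights. A minimal sketch of that weighting with made-up counts (not part of the module):

import numpy as np

delta_counts = np.array([4, 1, 0])           # draws so far for deltas 1, 2, 3
deltas = np.array([1, 1, 2, 3])              # deltas of the candidate (t0, t1) pairs
logits = -np.take(delta_counts, deltas - 1)  # less-sampled deltas -> larger logits
weights = np.exp(logits - np.max(logits))
probs = weights / weights.sum()              # the delta=3 pair is the most likely pick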
918class BatchGetMixin(SamplingDependant, ReceptionDependant):
919 """A Mixin that retrieves data from the reservoir in batches instead of individual sampling.
921 This class provides functionality to retrieve multiple items (a batch) from the reservoir
922 instead of retrieving them one at a time. It ensures that the reservoir has enough items
923 to serve a batch, and if the reception of data is over, it limits the batch size to the
924 remaining number of items in the reservoir.
926 ### Attributes
927 - **batch_size** (`int`): The desired number of items to retrieve in each batch."""
929 def __init__(self, batch_size: int) -> None:
931 ReceptionDependant.__init__(self)
932 self.batch_size = batch_size
934 def _is_sampling_ready(self: QueueProtocol) -> bool:
935 """Checks if the reservoir has enough items to serve a full batch, considering
936 the batch size and whether data reception is over."""
938 is_ready = super()._is_sampling_ready() # type: ignore
939 if self._is_reception_over:
940 # No more data will arrive, we may not be able to serve batch_size data
941 return is_ready
942 return is_ready and (self._size() >= self.batch_size)
944 def _get(self: QueueProtocol) -> Union[List[SimulationData], SimulationData]:
945 """Retrieves a batch of items from the reservoir. The number of items returned
946 is limited by the batch size, and it ensures that the batch is randomly selected
947 from the available items in the reservoir."""
949 if not self._is_reception_over:
950 population = self.batch_size
951 else:
952 population = min(self.batch_size, self._size())
953 indices = sorted(random.sample(range(self._size()), k=population), reverse=True)
954 items = self._get_from_index(indices)
955 return items
958class FIFO(CounterMixin, ReceptionDependant, BaseQueue):
959 """First In First Out (FIFO) Queue.
961 A queue implementation that follows the FIFO principle, where the first item added
962 is the first one to be retrieved. This queue also supports counting the number of
963 times each item has been seen (via `CounterMixin`) and can signal when data reception
964 is over (via `ReceptionDependant`)."""
966 def __init__(self, maxsize: int = 0) -> None:
968 CounterMixin.__init__(self)
969 ReceptionDependant.__init__(self)
970 BaseQueue.__init__(self, maxsize)
973class FIRO(CounterMixin, ThresholdMixin, RandomQueue):
974 """First In Random Out (FIRO) Queue.
976 A queue implementation that combines elements of FIFO (First In) and random eviction
977 (Random Out). It follows a First In strategy to decide which item is added first,
978 but when retrieving, items are evicted randomly from the queue.
979 The queue also supports counting the number of times each item has been seen
980 (via `CounterMixin`) and can block operations when the threshold is not met
981 (via `ThresholdMixin`).
983 ### Attributes
984 - **threshold** (`int`): The minimum number of items required in the queue to
985 proceed with sampling.
986 - **pseudo_epochs** (`int`): The number of epochs. Defaults to 1."""
988 def __init__(self, maxsize: int, threshold: int, pseudo_epochs: int = 1) -> None:
990 assert threshold <= maxsize
991 CounterMixin.__init__(self)
992 RandomQueue.__init__(self, maxsize, pseudo_epochs)
993 ThresholdMixin.__init__(self, threshold)
996class Reservoir(CounterMixin, ThresholdMixin, RandomEvictOnWriteQueue):
997 """First In Random Out (FIRO) with eviction on write.
999 This queue implementation combines the First In strategy (FIFO) with
1000 random eviction upon writing. It behaves like a traditional FIFO queue for insertion,
1001 but when the queue reaches its capacity, it randomly evicts one of the previously added
1002 elements to make room for the new item. The class also tracks the number of times
1003 each item has been seen (via `CounterMixin`) and enforces a threshold for sampling
1004 (via `ThresholdMixin`).
1006 ### Attributes
1007 - **maxsize** (`int`): The maximum size of the queue. Once the size exceeds this limit,
1008 eviction occurs.
1009 - **threshold** (`int`): The minimum number of items required in the queue to allow sampling.
1010 """
1012 def __init__(self, maxsize: int, threshold: int) -> None:
1014 assert threshold <= maxsize
1015 CounterMixin.__init__(self)
1016 ThresholdMixin.__init__(self, threshold)
1017 RandomEvictOnWriteQueue.__init__(self, maxsize)
1020class BatchReservoir(BatchGetMixin, Reservoir):
1021 """A queue implementing a batch version of the Reservoir sampling algorithm.
1023 This class extends the `Reservoir` queue to support batch sampling.
1024 It combines the functionality of the Reservoir sampling algorithm
1025 (with random eviction on write) with the ability to retrieve
1026 data in batches rather than individual samples.
1027 It ensures that the queue never exceeds a specified maximum size and
1028 maintains a threshold for sampling readiness.
1030 ### Attributes
1031 - **threshold** (`int`): The minimum number of items required in the queue to allow sampling.
1032 - **batch_size** (`int`): The number of items to retrieve in each batch."""
1034 def __init__(self, maxsize: int, threshold: int, batch_size: int) -> None:
1036 assert threshold <= maxsize
1037 assert batch_size <= maxsize
1038 Reservoir.__init__(self, maxsize, threshold)
1039 BatchGetMixin.__init__(self, batch_size)
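A hedged usage sketch of batch retrieval (sizes are illustrative): once more than `threshold` items have been buffered, `get()` returns a list of `batch_size` samples; after `signal_reception_over()` the final batches may be smaller.

buf = BatchReservoir(maxsize=1024, threshold=64, batch_size=32)
# ... producer threads call buf.put(simulation_data) as messages arrive ...
batch = buf.get()            # list of 32 SimulationData once the threshold is exceeded
buf.signal_reception_over()
tail = buf.get()             # while draining, fewer than 32 items may be returned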
1042class SimPairedReservoir(ThresholdMixin, SimPairedQueue):
1043 """A reservoir sampling queue that retrieves paired simulation data.
1045 This class extends the SimPairedQueue to incorporate reservoir sampling techniques
1046 with a threshold-based sampling readiness condition. It's designed for scenarios where
1047 paired samples from the same simulation are needed for training, while also maintaining
1048 a diverse and representative set of samples through reservoir sampling.
1050 It maintains an efficient data structure for storing simulation trajectories and implements
1051 adaptive sampling strategies that prioritize undersampled time differences.
1053 ### Attributes:
1054 - **maxsize** (`int`): Maximum number of items to store in the reservoir
1055 - **threshold** (`int`): Minimum number of items required before sampling
1056 - **sweep_size** (`int`): Number of simulation trajectories to manage
1057 - **nb_time_steps** (`int`): Number of time steps in each simulation trajectory
1058 """
1060 def __init__(
1061 self,
1062 maxsize: int,
1063 threshold: int,
1064 sweep_size: int,
1065 nb_time_steps: int,
1066 comm_size: int = 1,
1067 ) -> None:
1069 assert threshold <= maxsize
1070 ThresholdMixin.__init__(self, threshold)
1071 SimPairedQueue.__init__(self, maxsize, sweep_size, nb_time_steps, comm_size)
1074class BufferType(Enum):
1075 """Enum representing the different types of buffer strategies available for sampling.
1077 - FIFO = 0
1078 - FIRO = 1
1079 - Reservoir = 2
1080 - BatchReservoir = 3
1081 - SimPairedReservoir = 4"""
1083 FIFO = 0
1084 FIRO = 1
1085 Reservoir = 2
1086 BatchReservoir = 3
1087 SimPairedReservoir = 4
1090def make_buffer(
1091 buffer_size: int,
1092 buffer_t: BufferType = BufferType.FIFO,
1093 per_server_watermark: Optional[int] = None,
1094 pseudo_epochs: Optional[int] = None,
1095 batch_size: Optional[int] = None,
1096 sweep_size: Optional[int] = None,
1097 nb_time_steps: Optional[int] = None,
1098 comm_size: Optional[int] = 1,
1099) -> BaseQueue:
1100 """Factory function to create different types of buffers based on the specified buffer type.
1102 This function initializes and returns a buffer (queue) object based on the given buffer size
1103 and type. The available buffer types are defined by the `BufferType` enum, and each type
1104 has specific requirements. The function raises a `ValueError` if the necessary parameters
1105 are not provided for the selected buffer type.
1107 ### Parameters
1108 - **buffer_size** (`int`): The maximum size of the buffer.
1109 - **buffer_t** (BufferType, optional): The type of buffer to create.
1110 Default is `BufferType.FIFO`.
1111 - **per_server_watermark** (`int`, optional): A threshold used in some buffer types
1112 (e.g., `Reservoir`, `FIRO`, and `BatchReservoir`).
1113 - **pseudo_epochs** (`int`, optional): A parameter for `FIRO` buffer type
1114 to define the number of epochs for the buffer.
1115 - **batch_size** (`int`, optional): A parameter for the `BatchReservoir` buffer type
1116 to define the size of a batch.
1117 - **sweep_size** (`int`, optional): The number of distinct simulation trajectories.
1118 - **nb_time_steps** (`int`, optional): The number of time steps per simulation trajectory.
1119 - **comm_size** (`int`, optional): The number of parallel processes or threads.
1121 ### Returns
1122 `BaseQueue`: A buffer object of the specified type."""
1124 if not buffer_t:
1125 logger.warning("Buffer type is not specified. Defaulting to `FIFO`.")
1126 buffer_t = BufferType.FIFO
1128 if buffer_t is BufferType.FIFO:
1129 return FIFO(buffer_size)
1131 if buffer_t is BufferType.FIRO:
1132 if per_server_watermark is None or pseudo_epochs is None:
1133 raise ValueError(
1134 "`per_server_watermark` and `pseudo_epochs`"
1135 " are required for `FIRO` buffers."
1136 )
1137 return FIRO(buffer_size, per_server_watermark, pseudo_epochs)
1139 if buffer_t is BufferType.Reservoir:
1140 if per_server_watermark is None:
1141 raise ValueError(
1142 "`per_server_watermark` is required for Reservoir buffers."
1143 )
1144 return Reservoir(buffer_size, per_server_watermark)
1146 if buffer_t is BufferType.BatchReservoir:
1147 if per_server_watermark is None or batch_size is None:
1148 raise ValueError(
1149 "`per_server_watermark` and `batch_size`"
1150 " are required for `BatchReservoir` buffers."
1151 )
1152 return BatchReservoir(buffer_size, per_server_watermark, batch_size)
1154 if buffer_t is BufferType.SimPairedReservoir:
1155 if per_server_watermark is None or sweep_size is None or nb_time_steps is None:
1156 raise ValueError(
1157 "`per_server_watermark`, `sweep_size` and `nb_time_steps` are required for "
1158 "`SimPairedReservoir` buffers."
1159 )
1161 if comm_size is None:
1162 comm_size = 1
1164 return SimPairedReservoir(
1165 buffer_size, per_server_watermark, sweep_size, nb_time_steps, comm_size
1166 )
1168 raise ValueError(
1169 f"Unknown buffer type: '{buffer_t}'. "
1170 "Must be `FIFO`, `FIRO`, `Reservoir`, `BatchReservoir` or `SimPairedReservoir`."
1171 )
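For reference, a usage sketch of the factory (parameter values are illustrative only; they are not defaults of the module):

reservoir = make_buffer(
    buffer_size=2000,
    buffer_t=BufferType.Reservoir,
    per_server_watermark=500,
)
paired = make_buffer(
    buffer_size=2000,
    buffer_t=BufferType.SimPairedReservoir,
    per_server_watermark=500,
    sweep_size=40,
    nb_time_steps=100,
    comm_size=2,
)

Both returned buffers follow the `put`/`get` semantics of `BaseQueue` shown earlier.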