Coverage for melissa/server/deep_learning/reservoir.py: 63%

519 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2025-11-03 09:52 +0100

1"""This script defines all the buffer classes.""" 

2 

3import logging 

4import random 

5import threading 

6import numpy as np 

7from scipy.stats import norm 

8from collections import Counter, deque 

9from dataclasses import dataclass 

10from enum import Enum 

11from queue import Empty, Full 

12from time import time 

13from typing import Any, Dict, Deque, List, Optional, Tuple, Union 

14from numpy.typing import NDArray 

15 

16from melissa.server.simulation import SimulationData 

17from melissa.types import QueueProtocol 

18 

19logger = logging.getLogger(__name__) 

20 

21 

@dataclass
class Sample:
    """An item stored in a reservoir.

    Pairs a piece of simulation data with `seen`, a counter of how many
    times the item has been read."""

    data: SimulationData
    seen: int = 0

    def __repr__(self) -> str:
        """Returns a string representation of the `Sample` object."""
        # Built from space-joined fragments; the resulting text is the same
        # as a single concatenated f-string.
        parts = [
            f"<{type(self).__name__}",
            f"sim-id={self.data.simulation_id}",
            f"time-step={self.data.time_step}",
            f"fields={list(self.data.payload.keys())}",
            f"seen={self.seen}>",
        ]
        return " ".join(parts)

40 

41 

class PutGetMetric:
    """Lock-protected counter used to monitor the balance between puts
    and gets in the reservoir."""

    def __init__(self, val: int) -> None:
        """Stores the initial counter value and creates the guarding lock."""
        self.val: int = val
        self.increment_lock: threading.Lock = threading.Lock()

    def inc(self, val: int) -> None:
        """Adds `val` to the counter while holding the lock."""
        with self.increment_lock:
            self.val += val

    def dec(self, val: int) -> None:
        """Subtracts `val` from the counter while holding the lock."""
        with self.increment_lock:
            self.val -= val

59 

60 

class NotEnoughData(Empty):
    """Not enough data in the queue.

    Subclass of `queue.Empty`, so handlers written for `Empty` also
    catch it."""

63 

64 

class BaseQueue:
    """A queue inspired by the Lib/queue.py,
    but with only the functionality needed, hence no task_done feature.
    See https://github.com/python/cpython/blob/main/Lib/queue.py#L28

    The container is guarded by one re-entrant lock (`mutex`); the
    `not_empty` and `not_full` conditions are both built on that lock.
    Subclasses customise the behaviour through the underscore hooks
    (`_init_queue`, `_size`, `_get`, `_put`, `_on_full`,
    `_is_sampling_ready`)."""

    def __init__(self, maxsize: int = 0) -> None:
        """Creates the lock, the two conditions, the container and the metric.

        ### Parameters
        - **maxsize** (`int`, default=`0`): Capacity of the queue.
        `put` only enforces the capacity when `maxsize > 0`."""

        self.maxsize: int = maxsize
        self.mutex: threading.RLock = threading.RLock()
        self.not_empty: threading.Condition = threading.Condition(self.mutex)
        self.not_full: threading.Condition = threading.Condition(self.mutex)
        self._init_queue()
        self.put_get_metric: PutGetMetric = PutGetMetric(0)

    def _init_queue(self) -> None:
        """Initialize the queue with a specific data-structure.
        Unique to subclasses."""
        self.queue: Deque = deque()

    def save_state(self) -> Dict[str, Any]:
        """Returns the current state of the queue.

        The returned dictionary references the live container (no copy)."""
        return {"queue": self.queue}

    def load_from_state(self, state: Dict[str, Any]) -> None:
        """Loads the queue from the previous state."""
        self.queue = state["queue"]

    def _size(self) -> int:
        """Returns the current size of the queue (caller holds the lock)."""
        return len(self.queue)

    def _is_sampling_ready(self) -> bool:
        """Returns `True` when `get` may hand out an item."""
        return self._size() > 0

    def __len__(self) -> int:
        """Returns the size of the queue, taken under the lock."""
        with self.mutex:
            return self._size()

    def empty(self) -> bool:
        """Returns `True` if the queue currently holds no item."""
        with self.mutex:
            return self._size() == 0

    def _is_full(self) -> bool:
        """Returns `True` when the size has reached `maxsize`.
        Only meaningful when `maxsize > 0`."""
        return self._size() >= self.maxsize

    def _on_full(self, block, timeout) -> bool:
        """Handles the behavior when the queue is full,
        blocking the operation or raising an exception.

        ### Parameters
        - **block** (`bool`): Indicates whether the operation should block
        until the queue is not full.
        - **timeout** (`float`): The maximum time (in seconds) to wait if blocking is enabled.

        ### Returns
        - `bool`: `True` if the pending item may be added once the queue has room.
        Subclasses may return `False` to make `put` drop the item.

        ### Raises
        - `Full`: If `block` is `False`."""

        if not block:
            raise Full
        self.not_full.wait(timeout)
        return True

    def _get(self) -> Union[List[SimulationData], SimulationData]:
        """Returns the `SimulationData` object from the queue."""
        return self.queue.popleft()

    def _put(self, item: SimulationData) -> None:
        """Inserts `SimulationData` object in the queue."""
        self.queue.append(item)

    def get(
        self, block: bool = True, timeout: Optional[float] = None
    ) -> Union[List[SimulationData], SimulationData]:
        """Retrieves an item from the queue, blocking if necessary until data is available.

        ### Parameters
        - **block** (`bool`, default=`True`):
            - If `True`, the method will block until data is available or the timeout expires.
            - If `False`, raises `NotEnoughData` if data is not immediately available.
        - **timeout** (`Optional[float]`):
            - Maximum time (in seconds) to wait for data if `block` is `True`.
            - If `None`, waits indefinitely.
            - Must be a non-negative number if provided.

        ### Returns
        - `Union[List[SimulationData], SimulationData]`: The retrieved item from the queue.

        ### Raises
        - `NotEnoughData`: If the queue does not have enough data to retrieve:
            - When `block` is `False` and data is unavailable.
            - When `block` is `True` and the timeout expires.
        - `ValueError`: If `timeout` is a negative number."""

        with self.not_empty:
            if not block:
                if not self._is_sampling_ready():
                    raise NotEnoughData
            elif timeout is None:
                while not self._is_sampling_ready():
                    self.not_empty.wait()
            elif timeout < 0:
                raise ValueError("'timeout' must be a non-negative number")
            else:
                endtime = time() + timeout
                while not self._is_sampling_ready():
                    remaining = endtime - time()
                    if remaining <= 0.0:
                        raise NotEnoughData
                    self.not_empty.wait(remaining)
            item = self._get()
            # A reader went through: wake one writer possibly waiting for room.
            self.not_full.notify()
            self.put_get_metric.dec(1)
            return item

    def put(
        self, item: SimulationData, block: bool = True, timeout: Optional[float] = None
    ) -> None:
        """Adds an item to the queue, blocking if necessary until space is available.

        - If the queue has a finite `maxsize` and is full:
            - If `block` is `False`, checks whether the item can be
            conditionally added based on the `_on_full` method.
            - If `block` is `True`, waits until space is available or the timeout expires.
        - Adds the item to the queue if space is available or the `_on_full` method allows it.
        - If `_on_full` returns `False`, the item is dropped and `_on_full` is not
        called again for this item.
        - Notifies waiting threads if the queue is no longer empty after the addition.

        ### Parameters
        - **item** (`SimulationData`): The item to be added to the queue.
        - **block** (`bool`, default=`True`):
            - If `True`, the method will block until space becomes available or the timeout expires.
            - If `False`, raises `Full` if the queue is full.
        - **timeout** (`Optional[float]`):
            - Maximum time (in seconds) to wait for space in the queue if `block` is `True`.
            - If `None`, waits indefinitely.
            - Must be a non-negative number if provided.

        ### Raises
        - `Full`:
            - If the queue is full, `block` is `False`, and the item cannot be conditionally added.
            - If the queue is full, `block` is `True`, and the timeout expires.
        - `ValueError`: If `timeout` is a negative number."""

        add_item = True  # Do we actually add the item to the queue
        with self.not_full:
            if self.maxsize > 0:
                if not block:
                    if self._is_full():
                        add_item = self._on_full(block=block, timeout=timeout)
                elif timeout is None:
                    # Stop as soon as `_on_full` declines the item; looping
                    # again would re-ask until the item is admitted and make
                    # a `False` answer impossible in blocking mode.
                    while add_item and self._is_full():
                        add_item = self._on_full(block=block, timeout=timeout)
                elif timeout < 0:
                    raise ValueError("'timeout' must be a non-negative number")
                else:
                    endtime = time() + timeout
                    while add_item and self._is_full():
                        remaining = endtime - time()
                        if remaining <= 0.0:
                            raise Full
                        add_item = self._on_full(block=block, timeout=remaining)
            if add_item:
                self._put(item)
                self.put_get_metric.inc(1)
            if self._is_sampling_ready():
                self.not_empty.notify()

    def compute_buffer_statistics(self) -> Tuple[NDArray, NDArray]:
        """Not needed."""
        raise NotImplementedError

235 

236 

class CounterMixin:
    """A Mixin to track the number of times each sample has been seen in the reservoir.

    Incoming items are wrapped in `Sample` objects on `put`. Each read
    increments the sample's `seen` counter, and `seen_ctr` accumulates the
    seen counts of samples that leave the reservoir."""

    def __init__(self):
        """Starts with an empty `seen_ctr` histogram."""
        self.seen_ctr = Counter()

    def put(
        self, item: SimulationData, block: bool = True, timeout: Optional[float] = None
    ) -> None:
        """Wraps `item` in a `Sample` and forwards it to the next `put` in the MRO."""
        super().put(Sample(item), block, timeout)  # type: ignore

    def _get_with_eviction(
        self, index: Union[List, int]
    ) -> Union[List[SimulationData], SimulationData]:
        """Retrieves items from the reservoir, evicting them if necessary,
        and records their final seen counts in `seen_ctr`."""

        fetched: Union[List[Sample], Sample] = super()._get_with_eviction(  # type: ignore
            index
        )
        if not isinstance(fetched, list):
            # Single sample: it leaves the reservoir, so only the histogram
            # records this last read.
            self.seen_ctr += Counter([fetched.seen + 1])
            return fetched.data

        payloads = []
        seen_values = []
        for sample in fetched:
            sample.seen += 1
            seen_values.append(sample.seen)
            payloads.append(sample.data)
        self.seen_ctr += Counter(seen_values)
        return payloads

    def _get_without_eviction(
        self, index: Union[List, int]
    ) -> Union[List[SimulationData], SimulationData]:
        """Retrieves items from the reservoir without eviction, updating their seen counts."""

        fetched: Union[List[Sample], Sample] = super()._get_without_eviction(  # type: ignore
            index
        )
        if not isinstance(fetched, list):
            fetched.seen += 1
            return fetched.data

        payloads = []
        for sample in fetched:
            sample.seen += 1
            payloads.append(sample.data)
        return payloads

    def _get(self) -> Union[List[SimulationData], SimulationData]:
        """Retrieves a single item from the reservoir and updates its seen count."""

        if not hasattr(super(), "_get_without_eviction"):
            # Plain queue underneath: unwrap the sample here.
            sample: Sample = super()._get()  # type: ignore
            self.seen_ctr += Counter([sample.seen + 1])
            return sample.data

        # The parent exposes index-based getters; this mixin's overrides of
        # those getters do the bookkeeping.
        return super()._get()  # type: ignore

    def __repr__(self) -> str:
        """Appends the `seen_ctr` histogram to the parent representation."""
        base = super().__repr__()
        return f"{base}: {self.seen_ctr}"

305 

306 

class ReceptionDependant:
    """A Mixin to signal when the reception of data is over.

    Useful when an action depends on the end of data reception, such as
    emptying or processing the contents of a reservoir. Other parts of the
    system call `signal_reception_over` once no more data will arrive.

    ### Attributes:
    - **_is_reception_over** (`bool`): A flag indicating whether data reception is over.
    """

    def __init__(self) -> None:
        """Reception is considered ongoing until signalled otherwise."""
        self._is_reception_over: bool = False

    def signal_reception_over(self) -> None:
        """Marks the reception of data as over."""
        self._is_reception_over = True

325 

326 

class SamplingDependant:
    """A Mixin to receive the signal that the reservoir is ready for sampling.
    Especially useful with thresholds."""

    def _is_sampling_ready(self) -> bool:
        """Always ready; subclasses add their own conditions on top."""
        return True

333 

334 

class ThresholdMixin(SamplingDependant, ReceptionDependant):
    """A Mixin that blocks operations when not enough data is available in the container,
    based on a threshold.

    Sampling is reported as ready only while the container holds strictly
    more than `threshold` items. When reception is signalled as over, the
    threshold drops to 0 so the remaining items can be read.

    ### Attributes
    - **threshold** (`int`): The number of items the container must exceed
    for sampling to proceed."""

    def __init__(self, threshold: int) -> None:
        """Initializes the reception flag and stores the threshold."""
        ReceptionDependant.__init__(self)
        self.threshold = threshold

    def _is_sampling_ready(self: QueueProtocol) -> bool:
        """Determines if sampling is ready based on the threshold
        and available data in the container."""

        return super()._is_sampling_ready() and (self._size() > self.threshold)  # type: ignore

    def set_threshold(self, t: int) -> None:
        """Sets a new threshold for when sampling is considered ready."""
        self.threshold = t

    def signal_reception_over(self: QueueProtocol) -> None:
        """Signals that the reception of data is complete,
        sets the threshold to 0, and notifies one waiting reader."""

        with self.mutex:
            super().signal_reception_over()  # type: ignore
            self.set_threshold(0)
            # The readiness condition just changed: wake a reader blocked
            # on `not_empty`.
            self.not_empty.notify()

371 

372 

class ReadingWithoutEvictionMixin(ReceptionDependant):
    """A Mixin that retrieves data without eviction until data reception is over.

    While data is still being received, reads leave the items in the
    container. Once reception is marked as over, reads switch to the
    evicting getter. The reception flag comes from `ReceptionDependant`."""

    def _get_from_index(
        self: QueueProtocol, index: Union[list, int]
    ) -> Union[List[SimulationData], SimulationData]:
        """Retrieves data from the container either with or without eviction,
        depending on whether the reception process has been completed."""

        reader = (
            self._get_with_eviction
            if self._is_reception_over
            else self._get_without_eviction
        )
        return reader(index)

389 

390 

class RandomQueue(BaseQueue):
    """A queue that randomly selects items from the container and evicts them upon retrieval.

    Items are stored in a list and read at a uniformly random position.
    With `pseudo_epochs <= 1` every read evicts the item. With a larger
    value, reads leave the items in place until `pseudo_epochs` epochs have
    been counted, after which reads evict. An epoch is counted each time the
    number of non-evicting reads reaches `total_samples_in_epoch`."""

    def __init__(self, maxsize: int, pseudo_epochs: int = 1) -> None:
        """Creates the queue and resets the pseudo-epoch bookkeeping.

        ### Parameters
        - **maxsize** (`int`): Capacity forwarded to `BaseQueue`.
        - **pseudo_epochs** (`int`, default=`1`): Number of non-evicting
        epochs to run before reads start evicting."""

        super().__init__(maxsize)
        self.pseudo_epochs = pseudo_epochs
        self.current_epoch: int = 0
        self.epoch_access_count: int = 0
        self.total_samples_in_epoch: int = 0

    def _init_queue(self) -> None:
        """Initializes the internal list (queue)."""
        self.queue: List[SimulationData] = []  # type: ignore

    def _get_index(self) -> int:
        """Randomly selects an index from the queue."""
        return random.randrange(self._size())

    def _get_without_eviction(
        self, index: Union[list, int]
    ) -> Union[List[SimulationData], SimulationData]:
        """Retrieves items without eviction.

        Returns the item at `index` (or the list of items for a list of
        indices). Any other index type yields an empty placeholder
        `SimulationData`."""

        if isinstance(index, int):
            return self.queue[index]
        elif isinstance(index, list):
            return [self.queue[i] for i in index]
        return SimulationData(0, 0, {}, [])

    def _get_with_eviction(
        self, index: Union[list, int]
    ) -> Union[List[SimulationData], SimulationData]:
        """Retrieves items and evicts them.

        For a list of indices the items are returned in the requested order.
        Each distinct position is removed once, even if it is requested
        several times. Implicitly returns `None` for any other index type."""

        if isinstance(index, int):
            return self.queue.pop(index)
        if isinstance(index, list):
            items = [self.queue[i] for i in index]
            # Delete from the highest position down so earlier deletions do
            # not shift later ones. De-duplicate first: deleting the same
            # position twice would remove a neighbouring, unrequested item.
            for i in sorted(set(index), reverse=True):
                del self.queue[i]
            return items

    def _get_from_index(
        self, index: Union[List, int]
    ) -> Union[List[SimulationData], SimulationData]:
        """Retrieves data based on the current pseudo-epoch setting."""

        if (
            self.pseudo_epochs <= 1
            or self.current_epoch >= self.pseudo_epochs
        ):
            return self._get_with_eviction(index)

        # still in pseudo-epoch mode
        result = self._get_without_eviction(index)

        if isinstance(index, int):
            self.epoch_access_count += 1
        elif isinstance(index, list):
            self.epoch_access_count += len(index)

        self._check_epoch_completion()

        return result

    def _get(self) -> Union[List[SimulationData], SimulationData]:
        """Retrieves a random item from the queue and evicts it, if necessary."""
        return self._get_from_index(self._get_index())

    def _put(self, item: SimulationData) -> None:
        """Inserts SimulationData object in the queue."""
        self.queue.append(item)

        # track total samples during first epoch
        if self.pseudo_epochs > 1 and self.current_epoch == 0:
            self.total_samples_in_epoch = len(self.queue)

    def _check_epoch_completion(self) -> None:
        """Check if we've completed an epoch and should start the next one."""
        if self.pseudo_epochs <= 1:
            return

        # if we've seen all samples in this epoch
        if self.epoch_access_count >= self.total_samples_in_epoch:
            self.current_epoch += 1
            self.epoch_access_count = 0
            logger.debug(
                "Running epoch=%s/%s", self.current_epoch, self.pseudo_epochs
            )

            # if we've completed all pseudo epochs, now we can evict
            if self.current_epoch >= self.pseudo_epochs:
                logger.info("All pseudo epochs completed, switching to eviction mode.")

496 

497 

class ReservoirQueue(ReadingWithoutEvictionMixin, RandomQueue):
    """The `ReservoirQueue` supports adding items to the queue, evicting items when the queue
    reaches its maximum size, and retrieving items without eviction during sampling.

    ### Attributes
    - **put_ctr** (`int`): Number of `put` calls that completed without raising."""

    def __init__(self, maxsize: int) -> None:
        """Initializes both parents and the put counter."""
        ReadingWithoutEvictionMixin.__init__(self)
        RandomQueue.__init__(self, maxsize)
        self.put_ctr: int = 0

    def _evict(self, index: int) -> None:
        """Removes an item from the queue at the specified index."""
        del self.queue[index]

    def _on_full(self, block, timeout) -> bool:
        """Handles eviction when the queue is full.

        Draws a slot uniformly among all items offered so far, the incoming
        one included. If the slot falls inside the reservoir, the item in
        that slot is evicted and `True` is returned (the new item takes its
        place). Otherwise `False` is returned and the new item is dropped.

        `block` and `timeout` are unused: this hook never waits."""

        # `put_ctr` counts the items offered before this one, so the incoming
        # item is number `put_ctr + 1`. Drawing from `put_ctr + 1` values
        # gives it the `maxsize / (put_ctr + 1)` acceptance probability and
        # also avoids `randrange(0)` when the counter is still zero.
        index = random.randrange(self.put_ctr + 1)
        if index < self.maxsize:
            self._evict(index)
            return True
        return False

    def put(self, item, block=True, timeout=None) -> None:
        """Adds an item to the queue, evicting items as needed."""

        # The re-entrant mutex keeps the insertion and the counter update atomic.
        with self.mutex:
            RandomQueue.put(self, item, block, timeout)
            self.put_ctr += 1

527 

528 

class RandomEvictOnWriteQueue(ReadingWithoutEvictionMixin, RandomQueue):
    """A queue that overwrites random samples that have already been seen.

    This queue implements a variant of reservoir sampling where items are randomly evicted
    when the queue reaches its maximum size. The key difference is that only items that
    have already been "seen" can be overwritten, so unread items are never dropped on write.

    The queue maintains two lists: `not_seen` for items that have not been processed yet,
    and `seen` for items that have already been processed. Items are added to the `not_seen`
    list and moved to the `seen` list once they are accessed. Indices passed to the getters
    address the concatenation `not_seen + seen`.

    ### Attributes
    - **not_seen** (`list`): A list holding items that have not yet been accessed.
    - **seen** (`list`): A list holding items that have already been accessed."""

    def __init__(self, maxsize: int) -> None:
        """Initializes both parents."""
        ReadingWithoutEvictionMixin.__init__(self)
        RandomQueue.__init__(self, maxsize)

    def _init_queue(self) -> None:
        """Initializes `seen` and `not_seen` lists."""
        self.not_seen: list = []
        self.seen: list = []

    def save_state(self) -> Dict[str, List]:
        """Saves the current state of the queue, including the `not_seen` and `seen` lists."""
        return {"not_seen": self.not_seen, "seen": self.seen}

    def load_from_state(self, state: Dict[str, List]) -> None:
        """Loads the queue state from a saved state dictionary."""
        self.not_seen = state["not_seen"]
        self.seen = state["seen"]

    def __repr__(self) -> str:
        """Returns a string representation of the queue, showing the size of the
        `not_seen` and `seen` lists."""

        return (
            f"{self.__class__.__name__}:"
            f"not yet seen samples {len(self.not_seen)}, already seen {len(self.seen)}"
        )

    def _size(self) -> int:
        """Returns the total size of the queue (sum of `not_seen` and `seen`)."""
        return len(self.not_seen) + len(self.seen)

    def _evict(self, index) -> None:
        """Evicts the item at the given index from the `seen` list."""
        del self.seen[index]

    def _on_full(self, block, timeout) -> bool:
        """Handles eviction when the queue is full, randomly evicting
        an item from the `seen` list.

        - If at least one item has been seen, a random seen item is evicted.
        - If nothing has been seen yet and `block` is `False`, raises `Full`.
        - If nothing has been seen yet and `block` is `True`, waits on
        `not_full` (up to `timeout`) and returns. `put` then re-checks
        whether the queue is still full and calls this hook again if needed.

        ### Returns
        - `bool`: Always `True`; this queue never drops the incoming item."""

        if self.seen:
            self._evict(random.randrange(len(self.seen)))
            return True

        if not block:
            raise Full

        # Nothing can be evicted yet. After waking up (notification or
        # timeout) do not evict blindly: `seen` may still be empty, or a
        # reader may already have freed a slot. Hand control back to `put`,
        # which re-evaluates the state.
        self.not_full.wait(timeout)
        return True

    def _get_with_eviction(
        self, index: Union[list, int]
    ) -> Union[List[SimulationData], SimulationData]:
        """Retrieves an item by index, evicting it if necessary.

        For a list of indices, the returned items are ordered as: items taken
        from `not_seen` (highest index first), then items taken from `seen`
        (highest index first). Implicitly returns `None` for other index types."""

        if isinstance(index, int):
            unseen_count = len(self.not_seen)
            if index < unseen_count:
                return self.not_seen.pop(index)
            return self.seen.pop(index - unseen_count)

        if isinstance(index, list):
            unseen_count = len(self.not_seen)
            # Descending order so that deletions do not shift pending positions.
            not_seen_index = sorted(
                [i for i in index if i < unseen_count], reverse=True
            )
            seen_index = sorted(
                [i - unseen_count for i in index if i >= unseen_count],
                reverse=True,
            )
            items = [self.not_seen[i] for i in not_seen_index]
            items += [self.seen[i] for i in seen_index]
            for i in not_seen_index:
                del self.not_seen[i]
            for i in seen_index:
                del self.seen[i]
            return items

    def _get_without_eviction(
        self, index: Union[list, int]
    ) -> Union[List[SimulationData], SimulationData]:
        """Retrieves an item without eviction, moving it from
        `not_seen` to `seen` if accessed.

        The ordering of a list result is the same as in `_get_with_eviction`.
        Implicitly returns `None` for other index types."""

        if isinstance(index, int):
            unseen_count = len(self.not_seen)
            if index < unseen_count:
                item = self.not_seen.pop(index)
                self.seen.append(item)
            else:
                item = self.seen[index - unseen_count]
            return item

        if isinstance(index, list):
            # Both index lists are computed before any item moves, so they
            # refer to the layout at call time.
            unseen_count = len(self.not_seen)
            not_seen_index = sorted(
                [i for i in index if i < unseen_count], reverse=True
            )
            seen_index = sorted(
                [i - unseen_count for i in index if i >= unseen_count],
                reverse=True,
            )
            items = []
            for i in not_seen_index:
                item = self.not_seen[i]
                items.append(item)
                del self.not_seen[i]
                # Appending at the end keeps the existing `seen` positions valid.
                self.seen.append(item)
            items += [self.seen[i] for i in seen_index]
            return items

    def _put(self, item) -> None:
        """Adds an item to the `not_seen` list."""
        self.not_seen.append(item)

    def all_seen(self) -> bool:
        """Returns `True` if there are no more items in the `not_seen` list,
        indicating that all items have been processed."""
        return len(self.not_seen) == 0

670 

671 

672class SimPairedQueue(ReadingWithoutEvictionMixin, BaseQueue): 

673 """A specialized queue that maintains a 2D matrix of simulation data samples, 

674 organized by simulation ID and time step. 

675 

676 This queue is designed to efficiently store and retrieve pairs of data points 

677 from the same simulation trajectory at different time steps. It uses a probability-based 

678 sampling mechanism that prioritizes time pairs with smaller time differences that 

679 have been sampled less frequently, implementing a form of importance sampling. 

680 

681 The queue maintains a count of how many times each time delta (difference between time steps) 

682 has been sampled, allowing it to adaptively focus on undersampled time differences. 

683 

684 ### Parameters: 

685 - **maxsize** (`int`): Maximum number of items to store in the queue 

686 - **sweep_size** (`int`): Number of distinct simulation trajectories 

687 - **nb_time_steps** (`int`): Number of time steps per simulation trajectory 

688 - **comm_size** (`int`, default=`1`): Number of parallel processes or threads 

689 """ 

690 

691 def __init__( 

692 self, maxsize: int, sweep_size: int, nb_time_steps: int, comm_size: int 

693 ) -> None: 

694 

695 self.nb_simulations = sweep_size // comm_size 

696 self.nb_time_steps = nb_time_steps 

697 self.comm_size = comm_size 

698 ReadingWithoutEvictionMixin.__init__(self) 

699 BaseQueue.__init__(self, maxsize) 

700 

701 def _init_queue(self) -> None: 

702 """Initializes the queue with a specific data-structure.""" 

703 self.queue: NDArray = np.empty( # type: ignore 

704 (self.nb_simulations, self.nb_time_steps), dtype=object 

705 ) 

706 self.queue_status: NDArray = np.zeros( 

707 (self.nb_simulations, self.nb_time_steps), dtype=np.int8 

708 ) 

709 self.delta_counts: NDArray = np.zeros((self.nb_time_steps - 1), dtype=int) 

710 self.sim_counts: NDArray = np.zeros((self.nb_simulations), dtype=int) 

711 self.sim_reception: NDArray = np.zeros((self.nb_simulations), dtype=int) 

712 self.avail_ts: NDArray = np.zeros((self.nb_simulations), dtype=int) 

713 self.cumul_seen: NDArray = np.zeros((self.nb_simulations), dtype=int) 

714 

715 self.seen_ctr: Counter = Counter() 

716 

717 def save_state(self) -> Dict: 

718 """Saves the current state of the queue and already seen time deltas.""" 

719 return { 

720 "queue": self.queue, 

721 "queue_status": self.queue_status, 

722 "delta_counts": self.delta_counts, 

723 "sim_counts": self.sim_counts, 

724 "sim_reception": self.sim_reception, 

725 } 

726 

727 def load_from_state(self, state: Dict) -> None: 

728 """Loads the queue state from a saved state dictionary.""" 

729 

730 self.queue = state["queue"] 

731 self.queue_status = state["queue_status"] 

732 self.delta_counts = state["delta_counts"] 

733 self.sim_counts = state["sim_counts"] 

734 self.sim_reception = state["sim_reception"] 

735 self.avail_ts = np.array( 

736 [(sim_queue is not None).sum() for sim_queue in self.queue], dtype=int # type: ignore 

737 ) 

738 self.cumul_seen = ( 

739 np.vectorize(lambda x: x.seen if hasattr(x, "seen") else 0)(self.queue) 

740 .sum(axis=1) 

741 .astype(np.int32) 

742 ) 

743 

744 def is_queue_empty(self, sim_id: int) -> bool: 

745 if ( 

746 self.avail_ts[sim_id] == 1 

747 and self.sim_reception[sim_id] == self.nb_time_steps 

748 ): 

749 time_step = np.where(self.queue_status[sim_id] == 1)[0].item() 

750 index = sim_id, time_step 

751 self._evict(index) 

752 return True 

753 elif self.avail_ts[sim_id] == 0: 

754 return True 

755 else: 

756 return False 

757 

758 def get_evicted_index(self, sim_id: int, std_factor: int = 4) -> Tuple[int, int]: 

759 """Computes the probabilities for evicting items from the queue.""" 

760 counts = np.vectorize(lambda x: x.seen if hasattr(x, "seen") else 0)( 

761 self.queue[sim_id] 

762 ).astype(np.float32) 

763 if std_factor > 0: 

764 mean = self.nb_time_steps / 2 

765 normal_pdf_values = norm.pdf( 

766 np.arange(self.nb_time_steps), 

767 loc=mean, 

768 scale=mean / std_factor, 

769 ) 

770 logits = counts * normal_pdf_values 

771 else: 

772 logits = counts 

773 time_step = np.random.choice(a=len(counts), p=logits / np.sum(logits)) 

774 return (sim_id, time_step) 

775 

776 def get_sampled_indices(self, sim_id: int) -> List[Tuple[int, int]]: 

777 """ 

778 Sample a pair of indices from a queue based on their time differences, with a bias towards 

779 smaller delta values after applying a transformation with a delta counter. 

780 """ 

781 # Extract indices of non-None elements in the queue 

782 (valid_ts,) = np.nonzero(self.queue[sim_id] != None) 

783 assert len(valid_ts) == self.avail_ts[sim_id], ( 

784 f"[Inconsistent state] Queue for sim_id={sim_id}: " 

785 f"avail_ts={self.avail_ts[sim_id]}, valid_ts={len(valid_ts)}" 

786 ) 

787 

788 # Get all possible pairs of indices (i,j) where i < j using upper triangular indices 

789 i, j = np.triu_indices(len(valid_ts), k=1) 

790 # Calculate the absolute time differences between all valid timestamps 

791 deltas = np.abs(valid_ts[i] - valid_ts[j]) 

792 # Define a mapping function that transforms raw deltas using delta_counts lookup 

793 mapped_deltas = np.take(self.delta_counts, deltas - 1) 

794 # Sample an index based on the calculated weights 

795 logits = -mapped_deltas 

796 weights = np.exp(logits - np.max(logits)) 

797 sampled_idx = np.random.choice(len(mapped_deltas), p=weights / np.sum(weights)) 

798 # Return the sampled indices 

799 t0, t1 = valid_ts[i[sampled_idx]], valid_ts[j[sampled_idx]] 

800 return [(sim_id, t0), (sim_id, t1)] 

801 

802 def get_sim_logits(self, mode: str) -> Tuple[NDArray, NDArray]: 

803 assert mode in ["evict", "sample"], "mode must be either 'evict' or 'sample'" 

804 avail_sim = ( 

805 np.where(self.cumul_seen != 0)[0] 

806 if mode == "evict" 

807 else np.where(self.avail_ts > 1)[0] 

808 ) 

809 recep_over_sim = np.where(self.sim_reception == self.nb_time_steps)[0] 

810 intersec = np.intersect1d(avail_sim, recep_over_sim) 

811 avail_sim = intersec if intersec.size > 0 else avail_sim 

812 

813 assert avail_sim.size > 0, " ".join( # type: ignore 

814 f"[Inconsistent state] mode={mode}", 

815 f"avail_ts={self.avail_ts}", 

816 f"sim_counts={self.sim_counts}", 

817 f"avail_sim={avail_sim}", 

818 f"recep_over_sim={recep_over_sim}" 

819 f"is_reception_over={self._is_reception_over}" 

820 ) 

821 

822 if mode == "evict": 

823 logits = self.sim_counts[avail_sim] 

824 elif mode == "sample": 

825 logits = -(self.sim_counts[avail_sim] / self.avail_ts[avail_sim]) 

826 logits = np.exp(logits - np.max(logits)) 

827 return avail_sim, logits 

828 

829 def _evict_most_seen(self) -> None: 

830 avail_sim, logits = self.get_sim_logits(mode="evict") 

831 sim_id = np.random.choice(avail_sim, p=logits / np.sum(logits)) 

832 index = self.get_evicted_index(sim_id) 

833 self._evict(index) 

834 self.is_queue_empty(sim_id) 

835 

836 def put( 

837 self, item: SimulationData, block: bool = True, timeout: Optional[float] = None 

838 ) -> None: 

839 """Adds an item to the reservoir and tracks how many times it has been seen.""" 

840 

841 watched_item = Sample(item) 

842 super().put(watched_item, block, timeout) # type: ignore 

843 

844 def _put(self, item: Sample) -> None: # type: ignore 

845 """Adds an item to the queue acording to its sim_id and time_step.""" 

846 sim_id, time_step = item.data.simulation_id, item.data.time_step 

847 rank_sim_id = sim_id // self.comm_size 

848 self.queue[rank_sim_id][time_step] = item 

849 self.queue_status[rank_sim_id][time_step] = 1 

850 self.sim_reception[rank_sim_id] += 1 

851 self.avail_ts[rank_sim_id] += 1 

852 

853 def _size(self) -> int: 

854 """Returns the total size of the queue.""" 

855 return np.sum(self.avail_ts) 

856 

def _evict(self, index) -> None:
    """Remove the sample stored at `index` and update the bookkeeping.

    ### Parameters
    - **index**: A `(sim_id, time_step)` pair addressing a slot of `queue`.

    The evicted sample's `seen` value is recorded in `seen_ctr` and
    subtracted from `cumul_seen` for that row; the slot is cleared, the
    row's available-time-step counter is decremented and the slot status
    is set to `-1`. An `AssertionError` is raised if the slot is empty."""
    sim_id, time_step = index
    row = self.queue[sim_id]
    sample = row[time_step]
    assert sample is not None, (
        f"Evicting empty slot {sim_id},{time_step}"
    )

    times_seen = sample.seen
    self.seen_ctr += Counter([times_seen])
    self.cumul_seen[sim_id] -= times_seen

    row[time_step] = None
    self.avail_ts[sim_id] -= 1
    self.queue_status[sim_id][time_step] = -1

868 

def _on_full(self, block, timeout) -> bool:
    """Handle a `put` on a full queue.

    ### Parameters
    - **block**: When falsy, `Full` is raised immediately.
    - **timeout**: Passed to `self.not_full.wait`.

    ### Returns
    `bool`: Always `True` when no exception is raised.

    When blocking, the method waits on `not_full`; if the wait returns a
    truthy value, `_evict_most_seen` is called before returning."""
    if not block:
        raise Full

    # NOTE(review): a timed-out wait returns True without evicting, while a
    # successful wait evicts -- confirm this direction against the caller.
    if self.not_full.wait(timeout):
        self._evict_most_seen()
    return True

881 

def _get_index(self) -> List[Tuple[int, int]]:
    """Draw a simulation and return the indices sampled for it.

    A simulation is drawn with `np.random.choice` using the normalised
    weights from `get_sim_logits(mode="sample")`, then
    `get_sampled_indices` supplies the `(sim_id, time_step)` entries.

    Bookkeeping performed for the drawn simulation:
    - `sim_counts` is incremented by one,
    - `delta_counts` is incremented at the position given by the absolute
      time-step difference of the first two indices, minus one,
    - `cumul_seen` is incremented by two.

    ### Returns
    `List[Tuple[int, int]]`: The indices returned by `get_sampled_indices`."""
    candidates, weights = self.get_sim_logits(mode="sample")
    probabilities = weights / np.sum(weights)
    chosen_sim = np.random.choice(candidates, p=probabilities)
    pair = self.get_sampled_indices(chosen_sim)

    self.sim_counts[chosen_sim] += 1
    step_gap = abs(pair[1][1] - pair[0][1])
    self.delta_counts[step_gap - 1] += 1
    self.cumul_seen[chosen_sim] += 2
    return pair

891 

def _get_with_eviction(
    self, indices: List[Tuple[int, int]]
) -> "List[SimulationData]":
    """Return the data stored at `indices`, then trigger one eviction.

    Each addressed sample has its `seen` counter incremented. Afterwards
    `_evict_most_seen` is called once; the samples read here are still
    referenced locally, so their data is returned regardless of which slot
    that call clears.

    ### Parameters
    - **indices** (`List[Tuple[int, int]]`): `(sim_id, time_step)` pairs.

    ### Returns
    `List[SimulationData]`: The `data` of each addressed sample, in order."""
    picked = []
    for sim_id, time_step in indices:
        picked.append(self.queue[sim_id][time_step])
    for sample in picked:
        sample.seen += 1
    self._evict_most_seen()
    return [sample.data for sample in picked]

901 

def _get_without_eviction(
    self, indices: List[Tuple[int, int]]
) -> "List[SimulationData]":
    """Return the data stored at `indices` without evicting anything.

    Each addressed sample has its `seen` counter incremented.

    ### Parameters
    - **indices** (`List[Tuple[int, int]]`): `(sim_id, time_step)` pairs.

    ### Returns
    `List[SimulationData]`: The `data` of each addressed sample, in order."""
    picked = []
    for sim_id, time_step in indices:
        picked.append(self.queue[sim_id][time_step])
    for sample in picked:
        sample.seen += 1
    return [sample.data for sample in picked]

910 

def _get(self) -> "List[SimulationData]":
    """Sample indices with `_get_index` and fetch them via `_get_from_index`.

    ### Returns
    `List[SimulationData]`: Whatever `_get_from_index` returns for the
    sampled indices."""
    sampled_indices = self._get_index()
    return self._get_from_index(sampled_indices)  # type: ignore

916 

917 

class BatchGetMixin(SamplingDependant, ReceptionDependant):
    """A Mixin that serves data from the reservoir one batch at a time.

    Instead of returning a single item per `_get`, this mixin draws
    `batch_size` positions at random. While data reception is ongoing, sampling
    is only considered ready once the reservoir holds at least `batch_size`
    items; once reception is over, the batch shrinks to whatever is left.

    ### Attributes
    - **batch_size** (`int`): The desired number of items to retrieve in each batch."""

    def __init__(self, batch_size: int) -> None:

        ReceptionDependant.__init__(self)
        self.batch_size = batch_size

    def _is_sampling_ready(self: QueueProtocol) -> bool:
        """Tell whether a batch can be served.

        The parent readiness check always applies. While reception is still
        ongoing, the reservoir must additionally hold at least `batch_size`
        items."""

        ready = super()._is_sampling_ready()  # type: ignore
        if not self._is_reception_over:
            # More data may still arrive: wait until a full batch is available.
            ready = ready and (self._size() >= self.batch_size)
        return ready

    def _get(self: QueueProtocol) -> Union[List[SimulationData], SimulationData]:
        """Draw a random batch of positions and fetch them.

        The batch holds `batch_size` positions while reception is ongoing and
        at most the current size once reception is over. Positions are drawn
        without replacement and handed to `_get_from_index` in descending
        order."""

        population = (
            min(self.batch_size, self._size())
            if self._is_reception_over
            else self.batch_size
        )
        # NOTE(review): descending order presumably keeps positions valid if
        # `_get_from_index` removes entries one by one -- confirm.
        drawn = random.sample(range(self._size()), k=population)
        drawn.sort(reverse=True)
        return self._get_from_index(drawn)

956 

957 

class FIFO(CounterMixin, ReceptionDependant, BaseQueue):
    """First In First Out (FIFO) Queue.

    Combines `BaseQueue` with `CounterMixin` and `ReceptionDependant`; the
    constructor only initialises these three parents.

    ### Attributes
    - **maxsize** (`int`): Forwarded to `BaseQueue`. Defaults to 0."""

    def __init__(self, maxsize: int = 0) -> None:

        # Initialisation order is kept as-is: mixins first, queue last.
        CounterMixin.__init__(self)
        ReceptionDependant.__init__(self)
        BaseQueue.__init__(self, maxsize)

971 

972 

class FIRO(CounterMixin, ThresholdMixin, RandomQueue):
    """First In Random Out (FIRO) Queue.

    Combines `RandomQueue` with `CounterMixin` and `ThresholdMixin`. The
    constructor checks that the threshold does not exceed the maximum size
    and then initialises the three parents.

    ### Attributes
    - **maxsize** (`int`): Forwarded to `RandomQueue`.
    - **threshold** (`int`): Forwarded to `ThresholdMixin`; must not exceed
    `maxsize`.
    - **pseudo_epochs** (`int`): Forwarded to `RandomQueue`. Defaults to 1."""

    def __init__(self, maxsize: int, threshold: int, pseudo_epochs: int = 1) -> None:

        assert threshold <= maxsize
        # Initialisation order is kept as-is: the queue is set up before the
        # threshold mixin.
        CounterMixin.__init__(self)
        RandomQueue.__init__(self, maxsize, pseudo_epochs)
        ThresholdMixin.__init__(self, threshold)

994 

995 

class Reservoir(CounterMixin, ThresholdMixin, RandomEvictOnWriteQueue):
    """First In Random Out (FIRO) with eviction on write.

    Combines `RandomEvictOnWriteQueue` with `CounterMixin` and
    `ThresholdMixin`. The constructor checks that the threshold does not
    exceed the maximum size and then initialises the three parents.

    ### Attributes
    - **maxsize** (`int`): Forwarded to `RandomEvictOnWriteQueue`.
    - **threshold** (`int`): Forwarded to `ThresholdMixin`; must not exceed
    `maxsize`.
    """

    def __init__(self, maxsize: int, threshold: int) -> None:

        assert threshold <= maxsize
        # Initialisation order is kept as-is: mixins first, queue last.
        CounterMixin.__init__(self)
        ThresholdMixin.__init__(self, threshold)
        RandomEvictOnWriteQueue.__init__(self, maxsize)

1018 

1019 

class BatchReservoir(BatchGetMixin, Reservoir):
    """A `Reservoir` whose items are retrieved in batches.

    Combines `Reservoir` with `BatchGetMixin`. The constructor checks that
    neither the threshold nor the batch size exceeds the maximum size and
    then initialises both parents.

    ### Attributes
    - **maxsize** (`int`): Forwarded to `Reservoir`.
    - **threshold** (`int`): Forwarded to `Reservoir`; must not exceed `maxsize`.
    - **batch_size** (`int`): Forwarded to `BatchGetMixin`; must not exceed
    `maxsize`."""

    def __init__(self, maxsize: int, threshold: int, batch_size: int) -> None:

        assert threshold <= maxsize
        assert batch_size <= maxsize
        # Initialisation order is kept as-is: the reservoir before the mixin.
        Reservoir.__init__(self, maxsize, threshold)
        BatchGetMixin.__init__(self, batch_size)

1040 

1041 

class SimPairedReservoir(ThresholdMixin, SimPairedQueue):
    """A `SimPairedQueue` gated by a sampling threshold.

    Combines `SimPairedQueue` with `ThresholdMixin`. The constructor checks
    that the threshold does not exceed the maximum size and then initialises
    both parents.

    ### Attributes:
    - **maxsize** (`int`): Forwarded to `SimPairedQueue`.
    - **threshold** (`int`): Forwarded to `ThresholdMixin`; must not exceed
    `maxsize`.
    - **sweep_size** (`int`): Forwarded to `SimPairedQueue`.
    - **nb_time_steps** (`int`): Forwarded to `SimPairedQueue`.
    - **comm_size** (`int`): Forwarded to `SimPairedQueue`. Defaults to 1.
    """

    def __init__(
        self,
        maxsize: int,
        threshold: int,
        sweep_size: int,
        nb_time_steps: int,
        comm_size: int = 1,
    ) -> None:

        assert threshold <= maxsize
        # Initialisation order is kept as-is: the mixin before the queue.
        ThresholdMixin.__init__(self, threshold)
        SimPairedQueue.__init__(self, maxsize, sweep_size, nb_time_steps, comm_size)

1072 

1073 

class BufferType(Enum):
    """Enumeration of the buffer strategies that `make_buffer` can build.

    | Member               | Value |
    |----------------------|-------|
    | `FIFO`               | 0     |
    | `FIRO`               | 1     |
    | `Reservoir`          | 2     |
    | `BatchReservoir`     | 3     |
    | `SimPairedReservoir` | 4     |"""

    FIFO = 0
    FIRO = 1
    Reservoir = 2
    BatchReservoir = 3
    SimPairedReservoir = 4

1088 

1089 

def make_buffer(
    buffer_size: int,
    buffer_t: BufferType = BufferType.FIFO,
    per_server_watermark: Optional[int] = None,
    pseudo_epochs: Optional[int] = None,
    batch_size: Optional[int] = None,
    sweep_size: Optional[int] = None,
    nb_time_steps: Optional[int] = None,
    comm_size: Optional[int] = 1,
) -> BaseQueue:
    """Factory function to create different types of buffers based on the specified buffer type.

    This function initializes and returns a buffer (queue) object based on the given buffer size
    and type. The available buffer types are defined by the `BufferType` enum, and each type
    has specific requirements. The function raises a `ValueError` if the necessary parameters
    are not provided for the selected buffer type.

    ### Parameters
    - **buffer_size** (`int`): The maximum size of the buffer.
    - **buffer_t** (BufferType, optional): The type of buffer to create.
    Default is `BufferType.FIFO`. A falsy value (e.g. `None`) also falls back
    to `FIFO`, with a warning.
    - **per_server_watermark** (`int`, optional): The threshold required by the
    `FIRO`, `Reservoir`, `BatchReservoir` and `SimPairedReservoir` buffer types.
    - **pseudo_epochs** (`int`, optional): A parameter for `FIRO` buffer type
    to define the number of epochs for the buffer.
    - **batch_size** (`int`, optional): A parameter for `BatchReservoir` buffer type
    to define the size of a batch.
    - **sweep_size** (`int`, optional): Required by `SimPairedReservoir`.
    The number of distinct simulation trajectories.
    - **nb_time_steps** (`int`, optional): Required by `SimPairedReservoir`.
    The number of time steps per simulation trajectory.
    - **comm_size** (`int`, optional): Used by `SimPairedReservoir` only;
    `None` is replaced by 1.

    ### Returns
    `BaseQueue`: A buffer object of the specified type.

    ### Raises
    `ValueError`: If a parameter required by the selected buffer type is
    missing, or if `buffer_t` is not a known `BufferType` member."""

    # Enum members are always truthy, so this only triggers for values such
    # as `None` that are not `BufferType` members.
    if not buffer_t:
        logger.warning("Buffer type is not specified. Defaulting to `FIFO`.")
        buffer_t = BufferType.FIFO

    if buffer_t is BufferType.FIFO:
        return FIFO(buffer_size)

    if buffer_t is BufferType.FIRO:
        if per_server_watermark is None or pseudo_epochs is None:
            raise ValueError(
                "`per_server_watermark` and `pseudo_epochs`"
                " are required for `FIRO` buffers."
            )
        return FIRO(buffer_size, per_server_watermark, pseudo_epochs)

    if buffer_t is BufferType.Reservoir:
        if per_server_watermark is None:
            raise ValueError(
                "`per_server_watermark` is required for Reservoir buffers."
            )
        return Reservoir(buffer_size, per_server_watermark)

    if buffer_t is BufferType.BatchReservoir:
        if per_server_watermark is None or batch_size is None:
            raise ValueError(
                "`per_server_watermark` and `batch_size`"
                " are required for `BatchReservoir` buffers."
            )
        return BatchReservoir(buffer_size, per_server_watermark, batch_size)

    if buffer_t is BufferType.SimPairedReservoir:
        if per_server_watermark is None or sweep_size is None or nb_time_steps is None:
            # The message names the actual keyword arguments of this function
            # so the caller knows which one to supply.
            raise ValueError(
                "`per_server_watermark`, `sweep_size` and `nb_time_steps` are required for "
                "`SimPairedReservoir` buffers."
            )

        if comm_size is None:
            comm_size = 1

        return SimPairedReservoir(
            buffer_size, per_server_watermark, sweep_size, nb_time_steps, comm_size
        )

    raise ValueError(
        f"Unknown buffer type: '{buffer_t}'. "
        "Must be `FIFO`, `FIRO`, `Reservoir`, `BatchReservoir` or `SimPairedReservoir`."
    )