Coverage for melissa/server/deep_learning/reservoir.py: 89%

325 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-09-22 10:36 +0000

1from typing import Optional, Tuple, Union, List, Deque 

2import logging 

3from time import time 

4import random 

5from dataclasses import dataclass 

6from collections import deque, Counter 

7import threading 

8from queue import Full, Empty 

9 

10import numpy as np 

11 

12from melissa.server.simulation import SimulationData 

13from melissa.types import GetProtocol, Threshold 

14 

15logger = logging.getLogger(__name__) 

16 

17 

@dataclass
class Sample:
    """Reservoir entry pairing a piece of simulation data with a read counter.

    """

    # Payload held by the reservoir.
    data: SimulationData
    # How many times this entry has been read; a new entry starts at zero.
    seen: int = 0

27 

28 

class PutGetMetric:
    """Lock-protected counter used to monitor the put/get balance of the reservoir.

    """

    def __init__(self, val: int):
        # Current value of the metric.
        self.val: int = val
        # Guards every update of ``val``.
        self.increment_lock: threading.Lock = threading.Lock()

    def inc(self, val):
        """Add ``val`` to the metric while holding the lock."""
        with self.increment_lock:
            self.val += val

    def dec(self, val):
        """Subtract ``val`` from the metric while holding the lock."""
        with self.increment_lock:
            self.val -= val

45 

46 

class NotEnoughData(Empty):
    """Not enough data in the queue."""

49 

50 

class BaseQueue:
    """A queue inspired by the Lib/queue.py,
    but with only the functionality needed, hence no task_done feature.
    See https://github.com/python/cpython/blob/main/Lib/queue.py#L28

    Subclasses customise the container through the ``_init_queue``, ``_size``,
    ``_get``, ``_put``, ``_on_full`` and ``_is_sampling_ready`` hooks.
    ``get`` and ``put`` call these hooks while holding ``mutex``.

    """

    def __init__(self, maxsize: int = 0):
        # A ``maxsize`` of zero or less means the queue is unbounded.
        self.maxsize = maxsize
        # Re-entrant so that a subclass may wrap ``put`` in ``with self.mutex``.
        self.mutex = threading.RLock()
        # Both conditions share the same underlying lock.
        self.not_empty = threading.Condition(self.mutex)
        self.not_full = threading.Condition(self.mutex)
        self._init_queue()
        # Incremented on every stored item, decremented on every served item.
        self.put_get_metric: PutGetMetric = PutGetMetric(0)

    def _init_queue(self):
        """Create the underlying container."""
        self.queue: Deque[SimulationData] = deque()

    def save_state(self):
        """Return a dictionary holding the underlying container."""
        return {"queue": self.queue}

    def load_from_state(self, state: dict):
        """Restore the container from a dictionary produced by ``save_state``."""
        self.queue = state["queue"]

    def _size(self) -> int:
        return len(self.queue)

    def _is_sampling_ready(self) -> bool:
        """Tell whether ``get`` may serve an item right now."""
        return self._size() > 0

    def __len__(self):
        with self.mutex:
            return self._size()

    def empty(self):
        """Return True when the queue holds no item."""
        with self.mutex:
            return self._size() == 0

    def _is_full(self) -> bool:
        # Only meaningful for a bounded queue; ``put`` checks ``maxsize > 0`` first.
        return self._size() >= self.maxsize

    def _on_full(self, block, timeout) -> bool:
        """React to a ``put`` on a full queue.

        Returns True when ``put`` may go on (it re-checks fullness in the
        blocking modes) and False when the incoming item must be discarded.

        Raises:
            Full: if ``block`` is false.
        """
        if not block:
            raise Full
        self.not_full.wait(timeout)
        return True

    def _get(self) -> "SimulationData":
        return self.queue.popleft()

    def _put(self, item: "SimulationData"):
        return self.queue.append(item)

    def get(self, block: bool = True, timeout: Optional[float] = None):
        """Remove and return an item from the queue.

        Args:
            block: wait until ``_is_sampling_ready`` holds instead of failing
                immediately.
            timeout: maximum number of seconds to wait when blocking; ``None``
                waits indefinitely.

        Raises:
            NotEnoughData: if the queue is not ready (non-blocking call) or is
                still not ready once ``timeout`` has elapsed.
            ValueError: if ``timeout`` is negative.
        """
        with self.not_empty:
            if not block:
                if not self._is_sampling_ready():
                    raise NotEnoughData
            elif timeout is None:
                while not self._is_sampling_ready():
                    self.not_empty.wait()
            elif timeout < 0:
                raise ValueError("'timeout' must be a non-negative number")
            else:
                endtime = time() + timeout
                while not self._is_sampling_ready():
                    remaining = endtime - time()
                    if remaining <= 0.0:
                        raise NotEnoughData
                    self.not_empty.wait(remaining)
            item = self._get()
            self.not_full.notify()
            self.put_get_metric.dec(1)
            return item

    def put(self, item, block: bool = True, timeout: Optional[float] = None):
        """Store an item in the queue.

        When the queue is bounded and full, ``_on_full`` decides what happens:
        it may wait, make room, raise ``Full`` or return False to have the
        item discarded.

        Args:
            item: the object to store.
            block: allow ``_on_full`` to wait for room.
            timeout: maximum number of seconds to wait when blocking; ``None``
                waits indefinitely.

        Raises:
            Full: if the queue is full (non-blocking call) or is still full
                once ``timeout`` has elapsed.
            ValueError: if ``timeout`` is negative (bounded queue only).
        """
        add_item = True  # Do we actually add the item to the queue
        with self.not_full:
            if self.maxsize > 0:
                if not block:
                    if self._is_full():
                        add_item = self._on_full(block=block, timeout=timeout)
                elif timeout is None:
                    # Stop retrying as soon as the hook asks to discard the
                    # item, otherwise a False answer would be overwritten by
                    # the next iteration and never honoured.
                    while add_item and self._is_full():
                        add_item = self._on_full(block=block, timeout=timeout)
                elif timeout < 0:
                    raise ValueError("'timeout' must be a non-negative number")
                else:
                    endtime = time() + timeout
                    while add_item and self._is_full():
                        remaining = endtime - time()
                        if remaining <= 0.0:
                            raise Full
                        add_item = self._on_full(block=block, timeout=remaining)
            if add_item:
                self._put(item)
                self.put_get_metric.inc(1)
            if self._is_sampling_ready():
                self.not_empty.notify()

    def compute_buffer_statistics(self) -> Tuple[np.ndarray, np.ndarray]:
        """Not implemented by the base queue."""
        raise NotImplementedError

154 

155 

class CounterMixin:
    """Mixin recording how many times each sample was seen before leaving the reservoir.

    Items are wrapped in ``Sample`` on ``put`` and unwrapped again on read, so
    callers only ever handle the raw data.

    """

    def __init__(self):
        # Maps a "times seen" value to the number of evicted samples that reached it.
        self.seen_ctr = Counter()

    def put(self, item: SimulationData, block: bool = True, timeout: Optional[float] = None):
        """Wrap ``item`` in a ``Sample`` and delegate to the next ``put`` in the MRO."""
        super().put(Sample(item), block, timeout)  # type: ignore

    def _get_with_eviction(self, index: Union[List, int]):
        """Evict the sample(s) at ``index``, record their final seen count, return the data."""
        fetched: Union[List[Sample], Sample] = super()._get_with_eviction(  # type: ignore
            index
        )
        if not isinstance(fetched, list):
            # Single sample: it leaves the container, so only the counter is updated.
            self.seen_ctr += Counter([fetched.seen + 1])
            return fetched.data
        seen_counts = []
        payloads = []
        for sample in fetched:
            sample.seen += 1
            seen_counts.append(sample.seen)
            payloads.append(sample.data)
        self.seen_ctr += Counter(seen_counts)
        return payloads

    def _get_without_eviction(self, index: Union[List, int]):
        """Read the sample(s) at ``index``, bump their seen count, return the data."""
        fetched: Union[List[Sample], Sample] = super()._get_without_eviction(  # type: ignore
            index
        )
        if not isinstance(fetched, list):
            fetched.seen += 1
            return fetched.data
        payloads = []
        for sample in fetched:
            sample.seen += 1
            payloads.append(sample.data)
        return payloads

    def _get(self):
        # When the parent offers index-based reads, unwrapping already happens in
        # the two methods above; otherwise the sample is unwrapped here.
        if hasattr(super(), "_get_without_eviction"):
            return super()._get()  # type: ignore
        sample: Sample = super()._get()  # type: ignore
        self.seen_ctr += Counter([sample.seen + 1])
        return sample.data

    def __repr__(self) -> str:
        return f"{super().__repr__()}: {self.seen_ctr}"

206 

207 

class ReceptionDependant:
    """Mixin holding a flag that is raised once the reception of data is over.

    Especially useful to empty the reservoir.

    """

    def __init__(self):
        # Stays False until ``signal_reception_over`` is called.
        self._is_reception_over: bool = False

    def signal_reception_over(self):
        """Record that the reception of data is over."""
        self._is_reception_over = True

219 

220 

class SamplingDependant:
    """Mixin answering whether the reservoir is ready for sampling.

    Especially useful with thresholds.

    """

    def _is_sampling_ready(self) -> bool:
        """Always ready; subclasses add their own conditions on top of this."""
        return True

229 

230 

class ThresholdMixin(SamplingDependant, ReceptionDependant):
    """Mixin that holds back reads until the container holds more than ``threshold`` items.

    """

    def __init__(self, threshold: int):
        ReceptionDependant.__init__(self)
        # Sampling is ready only when the size is strictly greater than this value.
        self.threshold = threshold

    def _is_sampling_ready(self: Threshold) -> bool:
        if not super()._is_sampling_ready():  # type: ignore
            return False
        return self._size() > self.threshold

    def signal_reception_over(self):
        """Drop the threshold to zero and wake up a waiting reader."""
        with self.mutex:
            self.threshold = 0
            super().signal_reception_over()
            self.not_empty.notify()

249 

250 

class ReadingWithoutEvictionMixin(ReceptionDependant):
    """Mixin that reads without eviction while data is still being received.

    Once the reception is over, reads evict the items they return.

    """

    def _get_from_index(self: GetProtocol, index: Union[list, int]):
        reader = (
            self._get_with_eviction if self._is_reception_over else self._get_without_eviction
        )
        return reader(index)

260 

261 

class RandomQueue(BaseQueue):
    """Queue that randomly reads items from the list container.
    It evicts on reading, unless ``pseudo_epochs`` is greater than one.

    """

    def __init__(self, maxsize: int, pseudo_epochs: int = 1):
        super().__init__(maxsize)
        # Reads evict the returned item only when this value is <= 1.
        self.pseudo_epochs = pseudo_epochs

    def _init_queue(self):
        """Create the underlying list container."""
        self.queue: List[SimulationData] = []

    def _get_index(self) -> int:
        """Draw a uniformly random position in the container."""
        return random.randrange(self._size())

    def _get_without_eviction(self, index: Union[list, int]):
        """Return the item(s) at ``index`` and leave the container untouched."""
        if isinstance(index, int):
            return self.queue[index]
        elif isinstance(index, list):
            return [self.queue[i] for i in index]

    def _get_with_eviction(self, index: Union[list, int]):
        """Return the item(s) at ``index`` and remove them from the container.

        For a list of positions the items are returned in the order given.
        """
        if isinstance(index, int):
            return self.queue.pop(index)
        elif isinstance(index, list):
            items = [self.queue[i] for i in index]
            # Delete from the highest position down so that earlier deletions
            # do not shift the positions still to be removed, whatever order
            # the caller supplied them in.
            for i in sorted(index, reverse=True):
                del self.queue[i]
            return items

    def _get_from_index(self, index: int):
        if self.pseudo_epochs <= 1:
            return self._get_with_eviction(index)
        return self._get_without_eviction(index)

    def _get(self):
        """Read the item at a random position."""
        return self._get_from_index(self._get_index())

304 

305 

class ReservoirQueue(ReadingWithoutEvictionMixin, RandomQueue):
    """Queue implementing the reservoir sampling algorithm.

    """

    def __init__(self, maxsize: int, queue: Optional[Deque] = None):
        # ``queue`` is accepted but not used by this constructor.
        ReadingWithoutEvictionMixin.__init__(self)
        RandomQueue.__init__(self, maxsize)
        # Number of completed ``put`` calls, whether or not the item was kept.
        self.put_ctr: int = 0

    def _evict(self, index: int):
        del self.queue[index]

    def _on_full(self, block, timeout) -> bool:
        """Decide whether the incoming item replaces a stored one.

        ``block`` and ``timeout`` are ignored: this hook never waits.

        Returns:
            True after evicting a stored item to make room, False when the
            incoming item should be discarded instead.
        """
        # The incoming item is number ``put_ctr + 1`` of the stream, so the
        # draw must range over ``put_ctr + 1`` values; drawing over ``put_ctr``
        # only would skew the acceptance probability and fail on ``put_ctr == 0``.
        index = random.randrange(self.put_ctr + 1)
        if index < self.maxsize:
            self._evict(index)
            return True
        return False

    def put(self, item, block=True, timeout=None):
        """Offer ``item`` to the reservoir and count the attempt."""
        with self.mutex:
            RandomQueue.put(self, item, block, timeout)
            self.put_ctr += 1

330 

331 

class RandomEvictOnWriteQueue(ReadingWithoutEvictionMixin, RandomQueue):
    """A queue that overwrites random samples that have been already seen.

    Items are kept in two lists, ``not_seen`` and ``seen``. A position ``i``
    refers to ``not_seen[i]`` when ``i < len(not_seen)`` and to
    ``seen[i - len(not_seen)]`` otherwise.

    """

    def __init__(self, maxsize: int, queue: Optional[Deque] = None):
        # ``queue`` is accepted but not used by this constructor.
        ReadingWithoutEvictionMixin.__init__(self)
        RandomQueue.__init__(self, maxsize)

    def _init_queue(self):
        """Create the two underlying lists."""
        self.not_seen: list = []
        self.seen: list = []

    def save_state(self):
        """Return a dictionary holding both lists."""
        return {"not_seen": self.not_seen, "seen": self.seen}

    def load_from_state(self, state: dict):
        """Restore both lists from a dictionary produced by ``save_state``."""
        self.not_seen = state["not_seen"]
        self.seen = state["seen"]

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}:"
            f"not yet seen samples {len(self.not_seen)}, already seen {len(self.seen)}"
        )

    def _size(self) -> int:
        return len(self.not_seen) + len(self.seen)

    def _evict(self, index):
        del self.seen[index]

    def _on_full(self, block, timeout) -> bool:
        """Make room by evicting a random, already seen sample.

        When no sample has been seen yet, a blocking call waits on
        ``not_full`` for up to ``timeout`` seconds.

        Returns:
            Always True; the caller re-checks fullness in the blocking modes.

        Raises:
            Full: if ``block`` is false and no sample has been seen yet.
        """
        if len(self.seen) == 0:
            if not block:
                raise Full
            timed_in = self.not_full.wait(timeout)
            if not timed_in:
                return True
            if len(self.seen) == 0:
                # Woken up although nothing became evictable (for instance an
                # evicting read freed a slot). Let the caller re-check whether
                # the queue is still full instead of drawing from an empty list.
                return True
        index = random.randrange(len(self.seen))
        self._evict(index)
        return True

    def _get_with_eviction(self, index: Union[list, int]):
        """Return the item(s) at ``index`` and remove them from their list.

        For a list of positions, not yet seen items come first, then seen
        ones, each group by decreasing position.
        """
        boundary = len(self.not_seen)
        if isinstance(index, int):
            if index < boundary:
                return self.not_seen.pop(index)
            return self.seen.pop(index - boundary)
        elif isinstance(index, list):
            not_seen_index = sorted([i for i in index if i < boundary], reverse=True)
            seen_index = sorted([i - boundary for i in index if i >= boundary], reverse=True)
            items = [self.not_seen[i] for i in not_seen_index]
            items += [self.seen[i] for i in seen_index]
            # Positions are sorted in decreasing order, so deleting one never
            # shifts the positions still to be deleted.
            for i in not_seen_index:
                del self.not_seen[i]
            for i in seen_index:
                del self.seen[i]
            return items

    def _get_without_eviction(self, index: Union[list, int]):
        """Return the item(s) at ``index``, moving not yet seen ones to ``seen``."""
        boundary = len(self.not_seen)
        if isinstance(index, int):
            if index < boundary:
                item = self.not_seen.pop(index)
                self.seen.append(item)
            else:
                item = self.seen[index - boundary]
            return item

        elif isinstance(index, list):
            # Decreasing order keeps the remaining positions valid while items
            # are removed from ``not_seen``, whatever order the caller used.
            not_seen_index = sorted([i for i in index if i < boundary], reverse=True)
            # Computed against the original boundary; appending to ``seen``
            # below does not shift these positions.
            seen_index = [i - boundary for i in index if i >= boundary]
            items = []
            for i in not_seen_index:
                item = self.not_seen.pop(i)
                items.append(item)
                self.seen.append(item)
            items += [self.seen[i] for i in seen_index]
            return items

    def _put(self, item):
        self.not_seen.append(item)

425 

426 

class BatchGetMixin(SamplingDependant, ReceptionDependant):
    """Mixin that reads data from the reservoir as batches instead of individual samples.

    """

    def __init__(self, batch_size: int):
        ReceptionDependant.__init__(self)
        # Number of items served by one read while data is still being received.
        self.batch_size = batch_size

    def _is_sampling_ready(self):
        ready = super()._is_sampling_ready()
        if not self._is_reception_over:
            # More data may still arrive, so wait until a full batch is available.
            ready = ready and (self._size() >= self.batch_size)
        # Otherwise no more data will arrive, we may not be able to serve batch_size data
        return ready

    def _get(self):
        """Read a batch at distinct random positions, given in decreasing order."""
        available = self._size()
        if self._is_reception_over:
            population = min(self.batch_size, available)
        else:
            population = self.batch_size
        indices = random.sample(range(available), k=population)
        indices.sort(reverse=True)
        return self._get_from_index(indices)

451 

452 

class FIFO(CounterMixin, ReceptionDependant, BaseQueue):
    """First In First Out. """

    def __init__(self, maxsize: int = 0):
        # Each initialiser sets its own attributes, so they are called one by one.
        BaseQueue.__init__(self, maxsize=maxsize)
        ReceptionDependant.__init__(self)
        CounterMixin.__init__(self)

460 

461 

class FIRO(CounterMixin, ThresholdMixin, RandomQueue):
    """First In Random Out. """

    def __init__(self, maxsize: int, threshold: int, pseudo_epochs: int = 1):
        assert threshold <= maxsize
        # Each initialiser sets its own attributes, so they are called one by one.
        RandomQueue.__init__(self, maxsize=maxsize, pseudo_epochs=pseudo_epochs)
        ThresholdMixin.__init__(self, threshold=threshold)
        CounterMixin.__init__(self)

470 

471 

class Reservoir(CounterMixin, ThresholdMixin, RandomEvictOnWriteQueue):
    """Random evict-on-write queue with a sampling threshold and seen counters. """

    def __init__(self, maxsize: int, threshold: int):
        assert threshold <= maxsize
        # Each initialiser sets its own attributes, so they are called one by one.
        RandomEvictOnWriteQueue.__init__(self, maxsize=maxsize)
        ThresholdMixin.__init__(self, threshold=threshold)
        CounterMixin.__init__(self)

478 

479 

class BatchReservoir(BatchGetMixin, Reservoir):
    """Reservoir whose reads return batches of ``batch_size`` items. """

    def __init__(self, maxsize: int, threshold: int, batch_size: int):
        assert threshold <= maxsize
        assert batch_size <= maxsize
        Reservoir.__init__(self, maxsize=maxsize, threshold=threshold)
        BatchGetMixin.__init__(self, batch_size=batch_size)