Coverage for melissa/server/deep_learning/active_sampling/breed_utils.py: 41%
73 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-03-10 22:25 +0100
« prev ^ index » next coverage.py v7.6.12, created at 2025-03-10 22:25 +0100
1"""`melissa.server.deep_learning.active_sampling.breed_utils` defines all the helper functions
2required for active sampling breed algorithm."""
4# pylint: disable=W0603
6import threading
7from collections import deque
8from typing import (
9 Union, Dict, List, Tuple, Set,
10 Deque, Protocol, Callable, Iterable
11)
13import numpy as np
14from numpy.typing import NDArray
15from iterative_stats.iterative_moments import IterativeMoments
class TensorType(Protocol):
    """Structural type for the tensor-like elements accepted by `record_increments`.

    Any object exposing a callable `item` attribute matches this protocol;
    `record_increments` calls `item()` to obtain the raw Python value."""

    item: Callable
class LastN:
    """Maintains the last `n` unique values in the order of insertion.

    Re-inserting a value that is already tracked moves it to the most recent
    position instead of storing it twice.

    ### Attributes
    - **n** (`int`): Maximum number of unique values to keep. `0` keeps nothing;
    a negative value never triggers eviction, so the container grows unbounded.
    - **queue** (`deque`): Stores the ordered unique values, oldest first.
    - **set_** (`set`): Tracks the unique values for fast membership checks."""

    def __init__(self, n: int = 1) -> None:
        self.n: int = n
        self.queue: Deque[int] = deque()
        self.set_: Set[int] = set()

    def insert(self, value: int) -> None:
        """Inserts a value, maintaining the order and uniqueness.

        ### Parameters
        - **value** (`int`): The value to record as the most recent one."""
        if self.n == 0:
            # Zero capacity: there is nothing to evict and nothing may be stored.
            # Without this guard `popleft()` would raise on the empty deque.
            return

        if value in self.set_:
            # Already tracked: drop the old position so it can move to the end.
            self.queue.remove(value)
        elif len(self.queue) == self.n:
            # Window is full: evict the oldest value to make room.
            oldest = self.queue.popleft()
            self.set_.remove(oldest)
        self.queue.append(value)
        self.set_.add(value)

    def get_values(self) -> List[int]:
        """Returns a copy of the tracked values, ordered from oldest to newest."""
        return list(self.queue)
# Sliding-window size; stays at -1 until
# `initialize_container_with_sliding_window` is called.
__N: int = -1
# Most recently recorded simulation ids (default capacity of one).
__LASTN: LastN = LastN()
# Held while the moments container below is read or updated.
__LOCK: threading.Lock = threading.Lock()
# WARNING: do not combine the (sim_id, tstep) as a tuple key
MomentsType = Dict[int, Dict[int, IterativeMoments]]
# sim_id -> t_step -> iterative first-order moment of the recorded values.
__MOMENT_1: MomentsType = {}
# Containers accepted by `record_increments` for ids, time steps and values.
RecordType = Union[List, NDArray, Iterable[TensorType]]
def get_moment() -> MomentsType:
    """Returns the module-level moments container.

    ### Returns
    - `MomentsType`: The nested `sim_id -> t_step -> IterativeMoments` mapping
    itself (not a copy)."""
    return __MOMENT_1
def get_sliding_window() -> LastN:
    """Returns the module-level `LastN` container currently in use.

    ### Returns
    - `LastN`: The container object itself (not a copy)."""
    return __LASTN
def get_sliding_window_size() -> int:
    """Returns the configured sliding-window size.

    ### Returns
    - `int`: The size passed to `initialize_container_with_sliding_window`,
    or `-1` if it has not been called yet."""
    return __N
def initialize_container_with_sliding_window(n: int) -> None:
    """Stores the sliding-window size and installs a fresh, empty `LastN` container.

    ### Parameters
    - **n** (`int`): Maximum number of unique simulation IDs the new container keeps."""
    global __N, __LASTN
    # Any previously tracked IDs are discarded together with the old container.
    __N, __LASTN = n, LastN(n)
def weighted_average(values: Union[List, NDArray]) -> NDArray:
    """Scales the input values with a logarithmic, position-dependent weighting scheme.

    The value at index `i` out of `n` is multiplied by `log(n - i + 2)` and divided by `n`,
    so earlier entries receive larger weights. The operation is element-wise: the scaled
    values are returned as an array and are not summed into a single number.

    ### Parameters
    - **values** (`Union[List, NDArray]`):
    A list or array of numerical values; converted to `float32` before weighting.

    ### Returns
    - `NDArray`: An array with the same length as `values` holding the weighted values."""
    samples = np.array(values, dtype=np.float32)
    count = len(samples)
    # log(n - i + 2): the `+ 2` keeps every weight strictly positive.
    weights = np.log(count - np.arange(count) + 2)
    return samples * weights / count
def calculate_delta_loss(loss_per_sample_in_batch: NDArray) -> NDArray:
    """Calculates delta loss values for every sample in batch.

    A sample whose loss exceeds the batch mean gets its standardized distance
    from the mean, `(loss - mean) / std`; every other sample gets `0.0`.

    ### Parameters
    - **loss_per_sample_in_batch** (`NDArray`): an array of losses per sample in a batch.

    ### Returns
    - `NDArray`: An array containing delta loss values calculated per sample in a batch.
    """
    avg = loss_per_sample_in_batch.mean()
    std = loss_per_sample_in_batch.std()
    # When the spread is zero (all losses equal, or a single sample) no sample lies
    # above the mean, so the quotient is never selected. Dividing by 1.0 instead
    # avoids evaluating 0/0, which would emit a RuntimeWarning.
    scale = std if std > 0 else 1.0
    return np.where(
        loss_per_sample_in_batch > avg, (loss_per_sample_in_batch - avg) / scale, 0.0
    )
def _record_increment(sim_id: int, t_step: int, value: float) -> None:
    """Records the delta loss for a given simulation ID and time step,
    updating its mean iteratively based on occurrence.

    The simulation ID is also pushed into the `LastN` sliding window so that it
    counts as the most recently seen simulation.

    ### Parameters
    - **sim_id** (`int`): The ID of the simulation.
    - **t_step** (`int`): The current time step in the simulation.
    - **value** (`float`): The delta loss to fold into the running statistics."""
    with __LOCK:
        per_step = __MOMENT_1.setdefault(sim_id, {})
        if t_step not in per_step:
            # First occurrence of this (sim_id, t_step): start a first-order tracker.
            per_step[t_step] = IterativeMoments(max_order=1, dim=1)
        per_step[t_step].increment(value)

        # Update the window while still holding the lock so the moments and the
        # window change together.
        __LASTN.insert(sim_id)
133def record_increments(sim_ids: RecordType, time_ids: RecordType, values: RecordType):
134 """Records delta losses for multiple simulation IDs and time steps,
135 iteratively updating their means based on occurrence.
137 This function processes multiple simulation records in batch,
138 converting inputs to raw numerical values before passing them
139 to `_record_increment`.
141 ### Parameters
142 - **sim_ids** (`RecordType`): A collection of simulation IDs.
143 - **time_ids** (`RecordType`): A collection of time steps corresponding to `sim_ids`.
144 - **values** (`RecordType`): A collection of loss values to record.
146 Each element in `sim_ids`, `time_ids`, and `values` must be convertible
147 to an `int` or `float`."""
149 def raw_val(x):
150 x = x.item() if hasattr(x, "item") else x
151 assert isinstance(x, (int, float))
152 return x
154 for j, t, v in zip(sim_ids, time_ids, values):
155 _record_increment(raw_val(j), raw_val(t), raw_val(v))
def get_fitnesses(
    weighted: bool = False, lastN: bool = True
) -> Tuple[List[float], List[int]]:
    """Computes the fitness values for recent simulations based on their mean delta loss.

    ### Parameters
    - **weighted** (`bool`, optional): combines the per-time-step means with
    `weighted_average`; otherwise, computes a simple mean. (Default is `False`).
    - **lastN** (`bool`, optional): restricts the result to the simulation IDs held in
    the `LastN` sliding window; otherwise every recorded simulation ID is used.
    (Default is `True`).

    ### Returns
    - `Tuple[List[float], List[int]]`:
        - `List[float]`: A list of fitness values (averaged or weighted delta losses).
        - `List[int]`: A list of simulation IDs corresponding to the fitness values.

    NOTE(review): with `weighted=True` each entry is whatever `weighted_average`
    returns, which is an array of per-time-step values rather than a scalar."""
    last_averages: list = []

    with __LOCK:
        # Take the ID snapshot while holding the lock so it cannot be invalidated
        # by a concurrent writer before the moments below are read.
        last_sim_ids: list = __LASTN.get_values() if lastN else list(__MOMENT_1)
        for sim_id in last_sim_ids:
            # average by batches
            per_step_means = [
                moment.get_mean()[0] for moment in __MOMENT_1[sim_id].values()
            ]
            # average by timestep
            last_averages.append(
                weighted_average(per_step_means) if weighted else np.mean(per_step_means)
            )

    return last_averages, last_sim_ids