Coverage for melissa/server/deep_learning/active_sampling/breed_utils.py: 41%

73 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-03-10 22:25 +0100

1"""`melissa.server.deep_learning.active_sampling.breed_utils` defines all the helper functions 

2required for active sampling breed algorithm.""" 

3 

4# pylint: disable=W0603 

5 

6import threading 

7from collections import deque 

8from typing import ( 

9 Union, Dict, List, Tuple, Set, 

10 Deque, Protocol, Callable, Iterable 

11) 

12 

13import numpy as np 

14from numpy.typing import NDArray 

15from iterative_stats.iterative_moments import IterativeMoments 

16 

17 

class TensorType(Protocol):
    """Structural type for scalar-like objects accepted by `record_increments`.

    Any object exposing an `item` callable (used there to unwrap a single
    Python number) satisfies this protocol."""

    # Callable returning the wrapped Python scalar.
    item: Callable

21 

22 

class LastN:
    """Maintains the last `n` unique values in the order of insertion.

    Re-inserting a value that is already tracked moves it to the most recent
    position instead of duplicating it. When the container is full, the oldest
    value is evicted to make room.

    ### Attributes
    - **n** (`int`): Maximum number of unique values to keep. `0` keeps nothing;
      a negative value never triggers eviction (the container is unbounded).
    - **queue** (`deque`): Stores the ordered unique values, oldest first.
    - **set_** (`set`): Tracks the unique values for fast membership checks."""

    def __init__(self, n: int = 1) -> None:
        self.n: int = n
        self.queue: Deque[int] = deque()
        self.set_: Set[int] = set()

    def insert(self, value: int) -> None:
        """Inserts `value` as the most recent entry, maintaining order and uniqueness.

        If `value` is already present it is moved to the most recent position;
        otherwise, when the container is full, the oldest value is evicted."""
        if self.n == 0:
            # A zero-capacity window keeps nothing. Without this guard the
            # eviction branch below would `popleft()` from an empty deque and
            # raise IndexError.
            return
        if value in self.set_:
            # O(len) removal; acceptable for small sliding windows.
            self.queue.remove(value)
        elif len(self.queue) == self.n:
            oldest = self.queue.popleft()
            self.set_.remove(oldest)
        self.queue.append(value)
        self.set_.add(value)

    def get_values(self) -> list:
        """Returns the tracked values as a new list, oldest first."""
        return list(self.queue)

52 

53 

# Configured sliding-window size; -1 until
# `initialize_container_with_sliding_window` is called.
# NOTE(review): the default `LastN()` below has capacity 1 while `__N` is -1,
# so `get_sliding_window_size()` and the actual window disagree until
# initialization — confirm this is intended.
__N: int = -1
# Window of the most recently recorded simulation IDs (see `_record_increment`).
__LASTN: LastN = LastN()
# Held by `_record_increment` while it updates `__MOMENT_1` / `__LASTN`, and by
# `get_fitnesses` while it reads `__MOMENT_1`.
__LOCK: threading.Lock = threading.Lock()
# WARNING: do not combine the (sim_id, tstep) as a tuple key
# Nested mapping: sim_id -> time step -> first-order IterativeMoments (mean).
MomentsType = Dict[int, Dict[int, IterativeMoments]]
__MOMENT_1: MomentsType = {}
# Collections accepted by `record_increments`: lists, arrays, or iterables of
# objects exposing `.item()` (presumably framework tensors — confirm).
RecordType = Union[List, NDArray, Iterable[TensorType]]

61 

62 

def get_moment() -> MomentsType:
    """Returns the module-level moment store (sim_id -> time step -> moments).

    The live dict is returned, not a copy; `_record_increment` keeps mutating
    it under the module lock."""
    store = __MOMENT_1
    return store

65 

66 

def get_sliding_window() -> LastN:
    """Returns the current `LastN` window of recently recorded simulation IDs.

    This is the live object. `initialize_container_with_sliding_window` rebinds
    the module-level window, so a reference obtained earlier goes stale."""
    window = __LASTN
    return window

69 

70 

def get_sliding_window_size() -> int:
    """Returns the configured sliding-window size.

    ### Returns
    - `int`: the `n` last passed to `initialize_container_with_sliding_window`,
      or `-1` if it has not been called yet."""
    return __N

73 

74 

def initialize_container_with_sliding_window(n: int) -> None:
    """Replaces the sliding window with a fresh, empty `LastN` of capacity `n`.

    ### Parameters
    - **n** (`int`): capacity of the new window; also stored as the value
      reported by `get_sliding_window_size`."""
    global __LASTN, __N
    __N = n
    __LASTN = LastN(n)

81 

82 

def weighted_average(values: Union[List, NDArray]) -> NDArray:
    """Applies a logarithmically decaying weight to each value and divides by the count.

    For an `n`-element input, element `i` is multiplied by `log(n + 2 - i)`
    (earlier elements get larger weights; the last one gets `log(3)`) and then
    divided by `n`. The result is element-wise; it is NOT reduced to a scalar.

    ### Parameters
    - **values** (`Union[List, NDArray]`): values to weight; converted to
      `float32` before weighting.

    ### Returns
    - `NDArray`: the weighted values (same length as `values` for 1-D input)."""
    samples = np.array(values, dtype=np.float32)
    count = len(samples)
    # Descending integers n+2, n+1, ..., 3, turned into log weights.
    weights = np.log(np.arange(count + 2, 2, -1))
    return samples * weights / count

98 

99 

def calculate_delta_loss(loss_per_sample_in_batch: NDArray) -> NDArray:
    """Calculates delta loss values for every sample in a batch.

    A sample whose loss is above the batch mean gets `(loss - mean) / std`;
    every other sample gets `0.0`.

    ### Parameters
    - **loss_per_sample_in_batch** (`NDArray`): an array of losses per sample in a batch.

    ### Returns
    - `NDArray`: An array containing delta loss values calculated per sample in a batch.
    """
    avg = loss_per_sample_in_batch.mean()
    std = loss_per_sample_in_batch.std()
    # `np.where` evaluates both branches for every sample. When all losses are
    # equal (always the case for a batch of one) `std` is 0, so the unused
    # branch computes 0/0 -> NaN and NumPy emits a RuntimeWarning, even though
    # no sample is above the mean and every entry becomes 0.0. Silence only
    # that spurious "invalid value" warning; results are unchanged.
    with np.errstate(invalid="ignore"):
        return np.where(
            loss_per_sample_in_batch > avg,
            (loss_per_sample_in_batch - avg) / std,
            0.0,
        )

113 

114 

def _record_increment(sim_id: int, t_step: int, value: float) -> None:
    """Folds one delta-loss `value` into the running mean kept for
    (`sim_id`, `t_step`), then marks `sim_id` as the most recently seen ID.

    The moment store and the sliding window are both updated while holding the
    module lock.

    ### Parameters
    - **sim_id** (`int`): The ID of the simulation.
    - **t_step** (`int`): The time step within that simulation.
    - **value** (`float`): The delta loss to accumulate."""
    with __LOCK:
        per_step = __MOMENT_1.setdefault(sim_id, {})
        moments = per_step.get(t_step)
        if moments is None:
            moments = IterativeMoments(max_order=1, dim=1)
            per_step[t_step] = moments
        moments.increment(value)
        __LASTN.insert(sim_id)

131 

132 

def record_increments(sim_ids: RecordType, time_ids: RecordType, values: RecordType):
    """Records a batch of delta losses, one `_record_increment` call per triple.

    `sim_ids`, `time_ids` and `values` are walked in lockstep with `zip`, so
    processing stops at the shortest of the three.

    ### Parameters
    - **sim_ids** (`RecordType`): A collection of simulation IDs.
    - **time_ids** (`RecordType`): A collection of time steps aligned with `sim_ids`.
    - **values** (`RecordType`): A collection of loss values to record.

    Every element is unwrapped with `.item()` when it has one, and must then be
    an `int` or `float`. That check is an `assert`, so it is skipped under
    `python -O`."""

    def _as_number(element):
        scalar = element.item() if hasattr(element, "item") else element
        assert isinstance(scalar, (int, float))
        return scalar

    for sim_id, t_step, loss in zip(sim_ids, time_ids, values):
        _record_increment(_as_number(sim_id), _as_number(t_step), _as_number(loss))

156 

157 

def get_fitnesses(
    weighted: bool = False, lastN: bool = True
) -> Tuple[List[float], List[int]]:
    """Computes fitness values for simulations based on their mean delta loss.

    For each selected simulation, the running mean of every recorded time step
    is collected, and those per-time-step means are then averaged.

    ### Parameters
    - **weighted** (`bool`, optional): aggregates the per-time-step means with
      `weighted_average`; otherwise uses a simple mean. (Default is `False`).
    - **lastN** (`bool`, optional): if `True`, only the simulation IDs currently
      in the sliding window are scored (oldest first); otherwise every simulation
      recorded so far. (Default is `True`).

    ### Returns
    - `Tuple[List[float], List[int]]`:
        - `List[float]`: A list of fitness values (averaged or weighted delta losses).
        - `List[int]`: A list of simulation IDs corresponding to the fitness values."""
    last_averages: list = []

    with __LOCK:
        # Take the ID snapshot inside the same critical section as the reads.
        # `_record_increment` mutates both the window and `__MOMENT_1` while
        # holding this lock, so reading them outside it could race with a
        # concurrent record.
        last_sim_ids: list = __LASTN.get_values() if lastN else list(__MOMENT_1)
        for sim_id in last_sim_ids:
            # average by batches: running mean of each recorded time step
            step_means = [moments.get_mean()[0] for moments in __MOMENT_1[sim_id].values()]
            # average by timestep
            # NOTE(review): `weighted_average` returns an element-wise array, not
            # a scalar, so with `weighted=True` each entry here is an array,
            # contradicting `List[float]` — confirm whether a reduction is intended.
            last_averages.append(
                weighted_average(step_means) if weighted else np.mean(step_means)
            )

    return last_averages, last_sim_ids