Coverage for melissa/server/deep_learning/active_sampling/parameters.py: 16%
195 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-03-10 22:25 +0100
« prev ^ index » next coverage.py v7.6.12, created at 2025-03-10 22:25 +0100
1"""This script defines the `ExperimentBreeder` base class and `DefaultBreeder`
2which uses our breeding approach to sample the parameters."""
4# pylint: disable=C0103,R0901,R0902,W0221,W0613,W0238
6import logging
7from typing import (
8 Union, Any, List, Dict, Tuple,
9 Iterable, Optional, Callable
10)
12import numpy as np
13from numpy.typing import DTypeLike, NDArray
14from scipy.stats import multivariate_normal
15import matplotlib
16import matplotlib.pyplot as plt
18from melissa.server.deep_learning.active_sampling.exceptions import BreedError
19from melissa.server.parameters import (
20 StaticExperiment,
21 RandomUniformSamplerMixIn,
22 HaltonSamplerMixIn,
23 LatinHypercubeSamplerMixIn,
24 ParameterSamplerType
25)
26from melissa.server.deep_learning.active_sampling.breed_utils import get_fitnesses
27from melissa.server.deep_learning.tensorboard import TensorboardLogger
# Module-wide logger used by the breeders defined below.
logger = logging.getLogger(__name__)

# Use the non-interactive "agg" backend: figures are rendered off the main
# thread, which would otherwise make matplotlib emit a warning.
matplotlib.use(backend="agg")
class ExperimentBreeder(RandomUniformSamplerMixIn,
                        HaltonSamplerMixIn,
                        LatinHypercubeSamplerMixIn,
                        StaticExperiment):
    """Base class for breeding and sampling experiment parameters.

    Sets up the static experiment, then initialises the single sampling mix-in
    selected by `non_breed_sampler_t`. This "non-breed" sampler is the one
    `get_non_breed_samples()` and the default `next_parameters()` go through.

    **When inheriting this class, override `sample()` and optionally, `draw()`.**

    ### Parameters
    - **non_breed_sampler_t** (`ParameterSamplerType`, default=`RANDOM_UNIFORM`):
        Sampler used for non-bred samples: `RANDOM_UNIFORM`, `HALTON` or `LHS`.
        Any other value falls back to the random-uniform mix-in.
        **The first generation is populated from this sampler type.**

    ### Attributes
    - **checkpoint_data** (`Dict[str, Any]`): Experiment state collected for checkpoints.
    - **tb_logger** (`Optional[TensorboardLogger]`): TensorBoard logger. It is `None`
        until `set_tb_logger()` is called.

    NOTE(review): `self.sample` is looked up through the MRO, where
    `RandomUniformSamplerMixIn` comes first. Presumably the mix-ins dispatch on
    state set in their `__init__`; confirm against `melissa.server.parameters`.
    """

    def __init__(
        self,
        nb_params: int,
        nb_sims: int,
        l_bounds: List[Union[float, int]],
        u_bounds: List[Union[float, int]],
        seed: Optional[int] = None,
        dtype: DTypeLike = np.float32,
        non_breed_sampler_t: ParameterSamplerType = ParameterSamplerType.RANDOM_UNIFORM,
    ) -> None:
        StaticExperiment.__init__(
            self, nb_params, nb_sims, l_bounds, u_bounds, seed, dtype
        )

        # Only the mix-in matching the requested sampler type gets initialised.
        if non_breed_sampler_t is ParameterSamplerType.HALTON:
            sampler_mixin = HaltonSamplerMixIn
        elif non_breed_sampler_t is ParameterSamplerType.LHS:
            sampler_mixin = LatinHypercubeSamplerMixIn
        else:
            sampler_mixin = RandomUniformSamplerMixIn
        sampler_mixin.__init__(self)

        self.checkpoint_data: Dict[str, Any] = {}
        self.tb_logger: Optional[TensorboardLogger] = None

    def get_non_breed_samples(self, nb_samples: int = -1) -> NDArray:
        """Draw `nb_samples` parameter sets from the non-breed sampler via `sample()`."""
        samples = self.sample(nb_samples)
        return samples

    def next_parameters(self, **kwargs) -> Any:
        """Produce the next generation of parameters.

        Call this only through the `melissa.server.deep_learning.active_sampling`
        module. Subclasses may override it.

        ### Parameters
        - **kwargs** (`Dict[str, Any]`): Keyword arguments for custom preprocessing
            (unused by this base implementation).

        ### Returns
        - `Any`: Whatever `sample(self.nb_sims)` returns for the next generation."""
        generation_size = self.nb_sims
        return self.sample(generation_size)

    def set_tb_logger(self, tb_logger: TensorboardLogger) -> None:
        """Attach the TensorBoard logger used for experiment metrics."""
        self.tb_logger = tb_logger
class DefaultBreeder(ExperimentBreeder):
    """An `ExperimentBreeder` that breeds new parameters around high-fitness parents.

    Each call to `next_parameters()` draws parents from the simulations returned by
    `get_fitnesses()`. A parent's probability is proportional to its fitness, shifted
    so that the minimum fitness is 0. Roughly a fraction `R` of the requested children
    are sampled from a multivariate normal centred on their parent, with diagonal
    covariance `covs[parent]`. The remaining children come from the non-breed sampler.
    `R` follows a linear schedule from `start` to `end`.

    ### Parameters
    - **sigma** (`Union[float, Iterable]`, default=0.005): Initial diagonal covariance
        for breeding. Give either one value for every parameter or one value per
        parameter. Each value is clipped to `[sigma_opt * (u - l), (u - l) / 8]`,
        where `u - l` is that parameter's range.
    - **start** (`float`, default=0.15): Breeding ratio of the first bred generation.
    - **end** (`float`, default=0.75): Final breeding ratio of the schedule.
    - **breakpoint** (`int`, default=10): Number of points in the linear ratio schedule.
        Values below 2 are raised to 2. After the last point the ratio stays at `end`.
    - **use_true_mixing** (`bool`, default=False): If `True`, `round(nb_children * R)`
        parents are drawn up front. When `R <= 1`, each of them produces one child, and
        these children are interleaved at random with non-bred ones. If `False`, each
        child is bred independently with probability `R`.
    - **log_extra** (`bool`, default=False): Log extra TensorBoard figures and scalars
        for debugging.
    - **scatter_function** (`Callable[[NDArray], Tuple[str, NDArray, str, NDArray]]`,
        default=`lambda x: ("x0", x[:, 0], "x1", x[:, 1])`): Maps `self.parameters`,
        of shape `(nb_sims, nb_params)`, to `(x_label, x_values, y_label, y_values)`.
        Both value arrays have shape `(nb_sims,)`. Used for the scatter plots.

    ### Attributes
    - **sigma_opt** (`float`): Lower bound on `sigma`, as a fraction of each parameter's range.
    - **sigma** (`NDArray`): Per-parameter covariance given to non-bred children.
    - **covs** (`NDArray[np.float32]`): Per-simulation diagonal covariance,
        shape `(nb_sims, nb_params)`.
    - **Rs** (`NDArray`): Breeding-ratio schedule.
    - **R_i** (`int`): Number of breeding steps performed so far (an index into `Rs`).
    - **R** (`float`): Current breeding ratio, `Rs[min(R_i, len(Rs) - 1)]`.
    - **oob_factor** (`float`): Factor applied to a parent's covariance, on each
        out-of-bounds dimension, when a sampled child is rejected.
    - **max_oob_count** (`int`): Maximum number of sampling attempts per bred child.
        After that the parent itself is copied.
    - **oob_count** (`List[List[int]]`): For each simulation, the number of
        out-of-bounds retries of every breeding that used it as parent. The first
        entry is a `0` placeholder.
    - **parameters_is_bred** (`NDArray[np.bool_]`): Whether each simulation's
        parameters were bred.
    - **plot_limits_xy** (`Optional[Tuple]`): Union of the scatter-plot axis limits seen
        so far, so every logged plot shares the same view. `None` until the first plot."""

    def __init__(
        self,
        nb_params: int,
        nb_sims: int,
        l_bounds: List[Union[float, int]],
        u_bounds: List[Union[float, int]],
        seed: Optional[int] = None,
        dtype: DTypeLike = np.float32,
        non_breed_sampler_t: ParameterSamplerType = ParameterSamplerType.RANDOM_UNIFORM,
        sigma: Union[float, Iterable] = 0.005,
        start: float = 0.15,
        end: float = 0.75,
        breakpoint: int = 10,  # shadows the `breakpoint()` builtin; kept for API compatibility
        use_true_mixing: bool = False,
        log_extra: bool = False,
        scatter_function: Callable[[NDArray], Tuple[str, NDArray, str, NDArray]] = (
            lambda x: ("x0", x[:, 0], "x1", x[:, 1])
        ),
    ) -> None:

        super().__init__(
            nb_params, nb_sims, l_bounds, u_bounds, seed, dtype, non_breed_sampler_t
        )

        self.sigma_opt: float = 0.05
        self._sigma_init(sigma)
        self.covs: NDArray[np.float32] = np.full(
            (self.nb_sims, self.nb_params), self.sigma
        ).astype(np.float32)
        self.Rs: NDArray = np.linspace(start, end, max(breakpoint, 2), endpoint=True)
        self.R_i: int = 0
        self.R: float = start
        self.oob_factor: float = 0.3  # per-retry covariance shrink factor
        self.max_oob_count: int = 5
        # One independent list per simulation. `[[0]] * nb_sims` would alias a single
        # list, so appending for one parent would append for every simulation.
        self.oob_count: List[List[int]] = [[0] for _ in range(self.nb_sims)]
        self.parameters_is_bred: NDArray[np.bool_] = np.full(self.nb_sims, False)
        self.use_true_mixing: bool = use_true_mixing
        self.log_extra: bool = log_extra
        self.scatter_function: Callable[
            [NDArray], Tuple[str, NDArray, str, NDArray]
        ] = scatter_function
        self.plot_limits_xy: Optional[
            Tuple[Tuple[float, float], Tuple[float, float]]
        ] = None

    def set_tb_logger(self, tb_logger: TensorboardLogger) -> None:
        """Set the TensorBoard logger. With `log_extra`, also log the initial parameters."""
        super().set_tb_logger(tb_logger)
        if not self.log_extra:
            return
        fig, axes = plt.subplots(1, 1, figsize=(8, 8))
        label_X, data_X, label_Y, data_Y = self.scatter_function(self.parameters)
        axes.set_xlabel(label_X)
        axes.set_ylabel(label_Y)
        axes.scatter(data_X, data_Y, c="red", s=10)
        axes.set_title("Initial parameters")
        self.plot_limits_xy = (axes.get_xlim(), axes.get_ylim())
        if self.tb_logger is not None:
            self.tb_logger.log_figure(
                "Parents_vs_children", fig, close=True
            )
        else:
            # Nothing takes ownership of the figure, so close it to avoid leaking it.
            plt.close(fig)

    def checkpoint_state(self) -> None:
        """Save the breeding state for checkpointing.

        The values are copied, so later in-place updates (covariance shrinking,
        bred flags, OOB counts) do not alter the saved snapshot."""

        self.checkpoint_data.update({
            "R_index": self.R_i,
            "oob_count": [list(counts) for counts in self.oob_count],
            "covariances": np.array(self.covs),
            "parameters_is_bred": np.array(self.parameters_is_bred),
        })

        super().checkpoint_state()

    def restart_from_checkpoint(self) -> None:
        """Restore the breeding state, including the current ratio `R`, from a checkpoint."""

        super().restart_from_checkpoint()

        self.R_i = self.checkpoint_data["R_index"]
        # `R` is not checkpointed directly. Rebuild it from the schedule, mirroring the
        # stepping in `__breed_algorithm`; otherwise a restart resumes at `start`.
        self.R = self.Rs[min(self.R_i, len(self.Rs) - 1)]
        # Rebuild independent per-simulation lists. Checkpoints written by older
        # versions may hold a single list aliased across all simulations.
        self.oob_count = [list(counts) for counts in self.checkpoint_data["oob_count"]]
        self.covs = self.checkpoint_data["covariances"]
        self.parameters_is_bred = self.checkpoint_data["parameters_is_bred"]

    def get_breeding_status_per_parameter(self) -> NDArray[np.bool_]:
        """Return a boolean array stating which simulation indices were bred."""
        return self.parameters_is_bred

    def _sigma_init(self, sigma: Union[float, Iterable]) -> None:
        """Set `self.sigma`, the per-parameter breeding covariance, within valid bounds.

        - A real scalar (bool excluded) is broadcast to all `nb_params` parameters.
        - A list, tuple or ndarray is flattened and must contain `nb_params` values.
        - Each value is clipped to `[sigma_opt * (u_bounds - l_bounds),
          (u_bounds - l_bounds) / 8]`. A warning is logged if any value changes.

        ### Parameters
        - **sigma** (`Union[float, Iterable]`): A single value for every parameter, or
            one value per parameter.

        ### Raises
        - `BreedError`: If `sigma` is neither a real scalar nor a list/tuple/ndarray.
        - `AssertionError`: If an array-like `sigma` does not have `nb_params` values."""

        if isinstance(sigma, (int, float, np.integer, np.floating)) and not isinstance(
            sigma, bool
        ):
            sigma = np.full((self.nb_params,), sigma, dtype=float)
        elif isinstance(sigma, (np.ndarray, list, tuple)):  # array_like but not string
            sigma = np.array(sigma).ravel()
            msg = (
                f"The size of array-like `sigma` must be equal to number "
                f"of dimensions of PDE domain: {sigma.size} != {self.nb_params}"
            )
            assert sigma.shape[0] == self.nb_params, msg
        else:
            raise BreedError(
                "Argument `sigma` for `samplers.Breed` must be one of `float` or `iterable`. "
                f"Given type was {type(sigma)}."
            )

        dim_scales = self.u_bounds - self.l_bounds
        min_sigma = dim_scales * self.sigma_opt  # sigma must not be too small (or unset)
        max_sigma = dim_scales / 8  # sigma must not be too big
        if np.logical_and(sigma <= max_sigma, sigma >= min_sigma).all():
            self.sigma = sigma
        else:
            # Apply both bounds. The previous two `np.where` calls both started from the
            # raw `sigma`, so the second call discarded the upper clamp.
            self.sigma = np.clip(sigma, min_sigma, max_sigma)
            logger.warning(
                f"WARNING (samplers.Breed): Given value of sigma ({sigma}) "
                f"is updated to {self.sigma}"
            )

    def next_parameters(  # type: ignore
        self, start_sim_id: int, max_breeding_count: int = -1
    ) -> None:
        """Override the parent's sampling with the breeding algorithm.

        Fetches per-simulation fitnesses with `get_fitnesses()`, then breeds or
        resamples the parameters of simulations `start_sim_id` onwards.

        ### Parameters
        - **start_sim_id** (`int`): The simulation id at which breeding starts.
        - **max_breeding_count** (`int`, optional): Maximum number of simulations to
            (re)generate. The default of -1 means all remaining simulations.

        **Note**: Unlike the parent, this method returns nothing. `__breed_algorithm`
        writes the new values directly into the `parameters` attribute."""

        fitness_per_sim, sim_ids = get_fitnesses()
        self.__breed_algorithm(
            fitness_per_sim, sim_ids, start_sim_id, max_breeding_count
        )

    def reset_index(self) -> None:
        """Placeholder method. `__breed_algorithm` updates `_current_idx` value directly."""
        return

    @staticmethod
    def __parent_distribution(fitness_per_sim: NDArray) -> NDArray:
        """Convert fitness values into parent-selection probabilities.

        Fitness is shifted so its minimum is 0, then normalised to sum to 1. If the
        shifted fitness sums to 0 (all values equal) or is not finite, return a uniform
        distribution instead of the NaN probabilities the division would produce.

        ### Raises
        - `BreedError`: If no fitness values are given."""
        if fitness_per_sim.size == 0:
            raise BreedError("Cannot breed: no fitness values were provided.")
        fitness = (fitness_per_sim - fitness_per_sim.min()).ravel()
        total = fitness.sum()
        if not np.isfinite(total) or total <= 0:
            logger.warning(
                f"Degenerate fitness values ({fitness_per_sim}); "
                "selecting parents uniformly."
            )
            return np.full(fitness.size, 1.0 / fitness.size)
        return fitness / total

    def __breed_algorithm(
        self,
        fitness_per_sim: Union[NDArray, list],
        sim_ids: Union[NDArray, list],
        start_sim_id: int,
        max_breeding_count: int = -1,
    ) -> None:
        """Generate new parameters from simulation fitness.

        Selects parents according to their fitness and fills the parameters of the
        children `start_sim_id .. end`. Logs breeding statistics and, with `log_extra`,
        TensorBoard plots. Finally advances the breeding-ratio schedule.

        ### Parameters
        - **fitness_per_sim** (`Union[NDArray, list]`): Fitness of each candidate parent.
        - **sim_ids** (`Union[NDArray, list]`): Simulation ids matching `fitness_per_sim`.
        - **start_sim_id** (`int`): The first simulation id to (re)generate.
        - **max_breeding_count** (`int`, optional): Maximum number of simulations to
            generate after `start_sim_id`, capped at `nb_sims`. The default of -1 means
            all remaining simulations.

        ### Raises
        - `BreedError`: If no fitness is given or parent selection fails."""

        logger.info(f"Current breeding ratio={self.R:.3f}")
        fitness_per_sim = np.array(fitness_per_sim, dtype=np.float32)
        distribution = self.__parent_distribution(fitness_per_sim)

        self._current_idx = start_sim_id
        # Never index past the last simulation, even if `max_breeding_count` is too large.
        current_final_sim_id = (
            self.nb_sims
            if max_breeding_count < 0
            else min(start_sim_id + max_breeding_count, self.nb_sims)
        )
        nb_children = current_final_sim_id - start_sim_id
        child_idx_list = np.arange(start_sim_id, current_final_sim_id)

        # `int()`: `round()` of a NumPy scalar may return a float, which `choice` rejects.
        nb_parents = (
            int(round(nb_children * float(self.R)))
            if self.use_true_mixing
            else nb_children
        )
        try:
            parent_idx_list = np.random.choice(sim_ids, size=nb_parents, p=distribution)
        except ValueError as e:
            # `np.random.choice` signals bad ids, sizes or probabilities with ValueError.
            logger.exception(f"Active sampling breed sim_ids={sim_ids}")
            logger.error(f"Active sampling breed fitness={fitness_per_sim}")
            logger.error(f"Active sampling breed loss dist={distribution}")
            raise BreedError(e) from e

        self.__set_children(parent_idx_list, child_idx_list)
        nb_bred = self.parameters_is_bred[child_idx_list].sum()
        logger.info(f"Bred samples={nb_bred}/{nb_children}")

        if self.log_extra and self.tb_logger is not None:
            self.__log_parents_vs_children(
                fitness_per_sim, sim_ids, child_idx_list, nb_bred, nb_children
            )

        # _R_step: advance the schedule; past its end, R stays at the last value.
        self.R_i += 1
        if self.R_i < len(self.Rs):
            self.R = self.Rs[self.R_i]

    def __log_parents_vs_children(
        self, fitness_per_sim, sim_ids, child_idx_list, nb_bred, nb_children
    ) -> None:
        """Log the empirical and expected breeding ratios and a parents/children scatter plot."""
        if self.tb_logger is None:
            return
        if nb_children > 0:  # avoid a 0/0 ratio when nothing was generated
            self.tb_logger.log_scalar(
                "Ratio_sampler/Empirical", nb_bred / nb_children, self.R_i
            )
        self.tb_logger.log_scalar("Ratio_sampler/Expected", self.R, self.R_i)

        fig, axes = plt.subplots(1, 1, figsize=(9, 8))
        label_X, data_X, label_Y, data_Y = self.scatter_function(self.parameters)
        axes.set_xlabel(label_X)
        axes.set_ylabel(label_Y)
        axes.scatter(
            data_X[child_idx_list],
            data_Y[child_idx_list],
            c=np.where(self.parameters_is_bred[child_idx_list], "b", "g"),
            alpha=0.5,
            s=15,
        )
        scat = axes.scatter(
            data_X[sim_ids],
            data_Y[sim_ids],
            c=fitness_per_sim,
            cmap="Reds",
            marker="v",
            s=30,
            edgecolors="k",
            linewidths=0.3,
        )
        axes.set_title("blue - proposal, green - uniform, red - parents")
        fig.colorbar(scat, label="fitness (deltaloss)")

        # Widen the shared limits so successive plots keep a comparable view.
        xlim, ylim = axes.get_xlim(), axes.get_ylim()
        if self.plot_limits_xy is None:
            self.plot_limits_xy = (xlim, ylim)
        else:
            prev_x, prev_y = self.plot_limits_xy
            self.plot_limits_xy = (
                (min(xlim[0], prev_x[0]), max(xlim[1], prev_x[1])),
                (min(ylim[0], prev_y[0]), max(ylim[1], prev_y[1])),
            )
        axes.set_xlim(self.plot_limits_xy[0])
        axes.set_ylim(self.plot_limits_xy[1])
        self.tb_logger.log_figure(
            tag="Parents_vs_children", figure=fig, step=self.R_i, close=True
        )

    def __set_proposal_child(self, parent_id: int, child_id: int) -> None:
        """Breed `child_id` from `parent_id` by sampling N(parent, diag(covs[parent_id])).

        If a sample falls outside the bounds, the parent's covariance is multiplied by
        `oob_factor` on the offending dimensions and sampling is retried, up to
        `max_oob_count` attempts in total. If every attempt fails, the child copies
        the parent."""
        parent = self.parameters[parent_id]
        parent_oob = 0
        while parent_oob < self.max_oob_count:
            mvn_sampler = multivariate_normal(mean=parent, cov=self.covs[parent_id])
            # `atleast_1d`: with a single parameter, `rvs` returns a scalar.
            child = np.atleast_1d(mvn_sampler.rvs(size=1))
            # TODO: division by the proposal density is not implemented.
            child_dim_inside = np.logical_and(
                child > self.l_bounds, child < self.u_bounds
            )
            if child_dim_inside.all():
                break
            # Out of bounds: shrink the covariance on the failing dims and resample.
            parent_oob += 1
            self.covs[parent_id][~child_dim_inside] *= self.oob_factor
        else:
            # Too close to the boundary: the child is a copy of the parent.
            child = np.copy(parent)

        self.covs[child_id] = self.covs[parent_id]
        self.oob_count[parent_id].append(parent_oob)
        self.parameters_is_bred[child_id] = True
        self.parameters[child_id] = child

    def __set_random_child(self, child_id: int) -> None:
        """Fill `child_id` with one sample from the non-breed sampler and reset its covariance."""
        child = self.get_non_breed_samples(nb_samples=1).astype(np.float32)
        self.covs[child_id] = self.sigma
        self.parameters_is_bred[child_id] = False
        self.parameters[child_id] = child

    def __set_children(self, parent_idx_list, child_idx_list) -> None:
        """Generate the parameters of every child in `child_idx_list`.

        With `use_true_mixing`, every drawn parent produces one child. These children are
        interleaved at random, with probability `R` per slot, with non-bred children.
        Otherwise each child is paired with its parent in `parent_idx_list` and is bred
        with probability `R`, or resampled from the non-breed sampler.

        Sampling is done one simulation at a time with
        `scipy.stats.multivariate_normal`.

        ### Parameters
        - **parent_idx_list** (`NDArray`): Indices of the parent simulations.
        - **child_idx_list** (`NDArray`): Indices of the child simulations to generate."""

        proposal_parent_idx_list: List[int] = []
        if self.use_true_mixing:
            nb_proposal = len(parent_idx_list)
            nb_random = len(child_idx_list) - nb_proposal
            idx_proposal, idx_random = 0, 0
            for child_id in child_idx_list:
                proposals_left = idx_proposal < nb_proposal
                randoms_left = idx_random < nb_random
                # Draw only while both kinds remain. Once the random slots are used up,
                # the remaining children are proposals, with no rejection spinning.
                if proposals_left and (
                    not randoms_left or np.random.uniform(0, 1) < self.R
                ):
                    parent_id = parent_idx_list[idx_proposal]
                    self.__set_proposal_child(parent_id, child_id)
                    proposal_parent_idx_list.append(parent_id)
                    idx_proposal += 1
                else:
                    self.__set_random_child(child_id)
                    idx_random += 1
        else:
            for child_id, parent_id in zip(child_idx_list, parent_idx_list):
                if np.random.uniform(0, 1) < self.R:
                    self.__set_proposal_child(parent_id, child_id)
                    proposal_parent_idx_list.append(parent_id)
                else:
                    self.__set_random_child(child_id)

        if self.log_extra:
            self.__log_parent_statistics(proposal_parent_idx_list)

    @staticmethod
    def __integer_histogram(axis, values) -> None:
        """Draw a histogram of `values` with one left-aligned bar per integer value."""
        low, high = min(values), max(values)
        axis.hist(
            values,
            bins=int(high - low) + 1,
            range=(low, high + 1),
            align="left",
            edgecolor="black",
        )
        axis.set_axisbelow(True)
        axis.grid()

    def __log_parent_statistics(self, proposal_parent_idx_list) -> None:
        """Log histograms of the chosen parents and of the family sizes."""
        if not proposal_parent_idx_list:
            # Nothing was bred. `min`/`max` in the histograms would fail on an empty list.
            return
        fig, ax = plt.subplots(1, 2, figsize=(16, 5))

        ax[0].set_title("Distribution of chosen parents")
        self.__integer_histogram(ax[0], proposal_parent_idx_list)
        ax[0].set_xlabel("Simulation id")
        ax[0].set_ylabel("Number times chosen")

        ax[1].set_title('Distribution of "family" size')
        counts = np.unique(proposal_parent_idx_list, return_counts=True)[1]
        self.__integer_histogram(ax[1], counts)
        ax[1].set_xlabel("...with that number of children")
        ax[1].set_ylabel("Number of parents...")

        if self.tb_logger is not None:
            self.tb_logger.log_figure(
                "Parents_statistics_distribution", fig, self.R_i, close=True
            )
        else:
            # Nothing takes ownership of the figure, so close it to avoid leaking it.
            plt.close(fig)