Coverage for melissa/server/parameters.py: 86%
186 statements
1"""This script consists of helper classes for implementing random sampling
2strategies for the simulation parameters."""
4import os
5from enum import Enum, EnumMeta, unique
6from abc import ABC, abstractmethod
7from typing import (
8 Any, List, Union, TypeVar,
9 Type, Optional, Dict
10)
11from typing_extensions import override
13import cloudpickle
14import numpy as np
15from numpy.typing import NDArray, DTypeLike
16from scipy.stats import qmc
17from mpi4py import MPI
19from melissa.utility.rank_helper import is_sampling_rank, sampling_rank_only
21ParameterSamplerClass = TypeVar("ParameterSamplerClass", bound="BaseExperiment")


class MemmapWrapper:
    """Wrapper class to avoid direct access to the `numpy.memmap` instance, which could get
    overwritten by mistake."""
    def __init__(self, filename: str, shape: tuple, mode: str, dtype=np.float32):
        self.__memmap = np.memmap(  # type: ignore
            filename=filename, dtype=dtype, mode=mode, shape=shape
        )

    def __getitem__(self, index: Union[int, slice]):
        return self.__memmap[index]

    def __setitem__(self, index: Union[int, slice], value):
        self.__memmap[index] = value

    def flush(self):
        self.__memmap.flush()

    def __array__(self):
        return self.__memmap

    def __getattr__(self, attr):
        return getattr(self.__memmap, attr)

    def __repr__(self):
        return repr(self.__memmap)  # return the memmap representation

    def __str__(self):
        return str(self.__memmap)
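

# A minimal usage sketch, not part of the original module, showing how the wrapper is
# typically handled: the underlying memmap stays private while reads, writes, and flushes
# go through the thin delegation layer. The file path and sizes below are hypothetical.
def _memmap_wrapper_demo() -> None:
    demo_file = "checkpoints/demo_parameters.npy"  # hypothetical path for illustration
    os.makedirs(os.path.dirname(demo_file), exist_ok=True)
    # pre-allocate the backing file: 4 simulations x 3 parameters of float32
    with open(demo_file, "wb") as f:
        f.write(bytearray(4 * 3 * np.dtype(np.float32).itemsize))
    wrapper = MemmapWrapper(filename=demo_file, shape=(4, 3), mode="r+", dtype=np.float32)
    wrapper[0] = [1.0, 2.0, 3.0]   # __setitem__ delegates to the memmap
    wrapper.flush()                # persist the written row to disk
    print(np.array(wrapper))       # __array__ exposes the full (4, 3) view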


class BaseExperiment(ABC):
    """A base class for parameter generation strategies for simulations.

    - This class serves as the common parent for different methods of sampling simulation
      parameters. The `draw()` method, which is implemented by subclasses, is called from
      the server when creating client scripts to generate parameter values.
    - The `sample()` method returns parameters in their raw form, which must be
      post-processed.

    ### Attributes
    - **nb_params** (`int`): The number of parameters to be generated.
    - **nb_sims** (`int`): The number of simulations to be run.
    - **l_bounds** (`List[Union[float, int]]`): The lower bounds for the parameter values.
    - **u_bounds** (`List[Union[float, int]]`): The upper bounds for the parameter values.
    - **seed** (`int`): The seed value for random number generation to
      ensure reproducibility (Default is `None`).
    - **dtype** (`DTypeLike`): The data type for the generated parameters
      (Default is `np.float64`).
    - **rng** (`Generator`): A NumPy random number generator initialized with either
      the given seed or a random seed.
    - **parameters** (`MemmapWrapper`): The pre-generated parameter values stored as a
      NumPy array or memmap.
    - **_current_idx** (`int`): The current index for selecting the next set of parameters.
    """

    def __init__(
        self,
        nb_params: int,
        nb_sims: int,
        l_bounds: List[Union[float, int]],
        u_bounds: List[Union[float, int]],
        seed: Optional[int] = None,
        dtype: DTypeLike = np.float64
    ) -> None:

        self.comm: MPI.Intracomm = MPI.COMM_WORLD
        self.rank: int = self.comm.Get_rank()
        self.comm_size: int = self.comm.Get_size()

        self.seed: Optional[int] = seed
        self.nb_sims: int = nb_sims
        self.nb_params: int = nb_params
        self.dtype: DTypeLike = dtype
        self.memmap_file: str = "checkpoints/sampled_parameters.npy"
        self.already_initialized: bool = False

        self.checkpoint_file: str = "checkpoints/sampler_metadata.pkl"
        self.checkpoint_data: Dict[str, Any] = {}

        self.set_seeds()

        assert len(l_bounds) == len(
            u_bounds
        ), "Lengths of boundary lists must be the same."
        if len(l_bounds) == 1:
            self.l_bounds = np.full((self.nb_params,), l_bounds[0])
            self.u_bounds = np.full((self.nb_params,), u_bounds[0])
        else:
            self.l_bounds = np.array(l_bounds)
            self.u_bounds = np.array(u_bounds)

        self.__initialize_memmap()

    def get_current_parameters(self) -> NDArray:
        return np.array(self.parameters, dtype=self.dtype)  # type: ignore

    @sampling_rank_only
    def checkpoint_state(self) -> None:
        """Saves the current state for checkpointing."""
        self.checkpoint_data["already_initialized"] = self.already_initialized
        with open(self.checkpoint_file, "wb") as f:
            cloudpickle.dump(self.checkpoint_data, f)

    def restart_from_checkpoint(self) -> None:
        """Restores the state from a checkpoint."""
        with open(self.checkpoint_file, "rb") as f:
            self.checkpoint_data = cloudpickle.load(f)
        self.already_initialized = self.checkpoint_data["already_initialized"]

    def __initialize_memmap(self) -> None:
        """Initializes a `numpy.memmap` that stores the sampled parameters in a file, so that
        all ranks of a parallel server share a common set of parameters.

        The sampling rank performs the sampling and writes to the file.
        Usually, the sampling rank is the last MPI rank, but the decision is made at runtime
        for multi-node studies.
        """
        if is_sampling_rank() and not os.path.exists(self.memmap_file):
            with open(self.memmap_file, "wb") as f:
                bytes_ = np.dtype(self.dtype).itemsize
                f.write(bytearray(self.nb_sims * self.nb_params * bytes_))
        self.comm.Barrier()

        self.parameters = MemmapWrapper(  # type: ignore
            filename=self.memmap_file,
            shape=(self.nb_sims, self.nb_params),
            mode="r+" if is_sampling_rank() else "r",
            dtype=self.dtype
        )

        # avoid overwriting the initially sampled parameters
        # when the server was restarted due to fault tolerance (FT)
        if not self.already_initialized:
            self.set_parameters(self.sample(self.nb_sims))
            self.already_initialized = True

    def set_seeds(self) -> None:
        """Sets the random seeds to ensure reproducibility across experiments.

        This method seeds:

        - NumPy's legacy global random state (`np.random.seed`)
        - NumPy's default random number generator (`np.random.default_rng`), stored as `self.rng`

        ### Behavior
        - If `self.seed` is specified, the same seed is used for NumPy's global random
          state and for `np.random.default_rng`, giving consistent results.
        - If `self.seed` is not specified (`None` or `0`), `self.rng` is initialized
          with a random seed provided by the system."""

        if self.seed:
            np.random.seed(self.seed)
            self.rng = np.random.default_rng(self.seed)
        else:
            self.rng = np.random.default_rng()

    def process_drawn(self, parameters: NDArray) -> Any:
        """Processes the next set of parameters before generating client scripts.

        ### Parameters
        - **parameters** (`NDArray`): A row of parameters obtained from the pre-generated
          parameter matrix of shape `(nb_sims, nb_params)`.

        ### Returns
        - `Any`: Processed parameters compatible with the client scripts.

        ---
        For example, if your client script expects its inputs in a certain form

        ```
        python3 solver.py --arg1=<p0> --arg2=<p1>
        ```

        then `process_drawn` can be overridden so that it returns

        ```python
        def process_drawn(self, parameters: NDArray) -> List[str]:
            return [f"--arg1={parameters[0]}", f"--arg2={parameters[1]}"]
        ```
        """
        return parameters

    def draw(self, sim_id: int) -> Any:
        """Draws the next set of parameters.

        ### Parameters
        - **sim_id** (`int`): The simulation id for which the parameter set is returned.

        ### Returns
        - `Any`: The set of processed parameters from the pre-generated array.
        """

        params = self.parameters[sim_id]
        return self.process_drawn(params)

    @sampling_rank_only
    def set_parameters(self, parameters: Union[List, NDArray]):
        """Sets the static parameters.

        ### Parameters
        - **parameters** (`Union[List, NDArray]`): A list or array of parameters to set.
        """
        p = np.array(parameters, dtype=self.dtype)
        assert p.shape == (self.nb_sims, self.nb_params), \
            f"Shape mismatch, expected ({self.nb_sims}, {self.nb_params})"
        self.parameters[:] = p  # NOTE: keep this intact!

    @sampling_rank_only
    def flush_to_disk(self) -> None:
        """Writes the current parameter memmap to `memmap_file` on disk."""
        self.parameters.flush()  # type: ignore

    @abstractmethod
    def base_sample(self, nb_samples: int) -> Any:
        """Generates the specified number of parameter sets.

        ### Parameters
        - **nb_samples** (`int`): The number of parameter sets to generate.

        ### Returns
        - `Any`: The generated set of parameters."""
        raise NotImplementedError("The `base_sample` method must be implemented by subclasses.")

    def sample(self, nb_samples: int) -> Any:
        """Generates parameter sets using the base sampling method.

        ### Parameters
        - **nb_samples** (`int`): The number of parameter sets to generate.

        ### Returns
        - `Any`: The generated set of parameters.
        """
        return self.base_sample(nb_samples)

    @sampling_rank_only
    def finalize(self, exit_: int = 0) -> None:
        """Finalizes the experiment by cleaning up resources.

        ### Parameters
        - **exit_** (`int`): Exit code indicating the status of the experiment.

        ### Behavior
        - Removes the memory-mapped file (`memmap_file`) if it exists and the exit code is `0`.
        """

        if exit_ == 0:
            if hasattr(self, "memmap_file") and os.path.exists(self.memmap_file):
                os.remove(self.memmap_file)
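

# A minimal sketch, not part of the original module, of the single method a concrete
# experiment has to provide: `base_sample` returns an array of shape
# (nb_samples, nb_params) inside [l_bounds, u_bounds]. Everything else (memmap handling,
# checkpointing, drawing per simulation) comes from the base class. The class name and
# distribution are hypothetical; instantiating it assumes the usual runtime context
# (an MPI launch and an existing `checkpoints/` directory).
class _GaussianExperimentSketch(BaseExperiment):
    """Hypothetical sampler drawing clipped normal variates centered in the bounds."""

    def base_sample(self, nb_samples: int) -> NDArray:
        center = (self.l_bounds + self.u_bounds) / 2.0
        scale = (self.u_bounds - self.l_bounds) / 6.0
        raw = self.rng.normal(center, scale, size=(nb_samples, self.nb_params))
        # keep the samples inside the declared bounds
        return np.clip(raw, self.l_bounds, self.u_bounds)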


class SobolBaseExperiment(BaseExperiment):
    """`SobolBaseExperiment` is an extension of the `BaseExperiment` class
    designed to support Sobol sensitivity analysis. It introduces the concept
    of the pick-freeze method for generating samples and supports both first-order
    and second-order Sobol indices.

    ### Attributes
    - **apply_pick_freeze** (`bool`): Indicates whether to use the pick-freeze method
      for sample generation.
    - **second_order** (`bool`): Specifies whether to include second-order Sobol indices
      in the pick-freeze matrix generation.
    """
    def __init__(
        self,
        nb_params: int,
        nb_sims: int,
        l_bounds: List[Union[float, int]],
        u_bounds: List[Union[float, int]],
        seed: Optional[int] = None,
        dtype: DTypeLike = np.float64,
        apply_pick_freeze: bool = False,
        second_order: bool = False
    ) -> None:
        self.apply_pick_freeze: bool = apply_pick_freeze
        self.second_order: bool = second_order
        super().__init__(nb_params, nb_sims, l_bounds, u_bounds, seed, dtype)

    def generate_pick_freeze_matrix(self, sample_a: NDArray, sample_b: NDArray) -> NDArray:
        """Returns a pick-freeze matrix given the set of input parameters."""
        matrix = np.vstack((sample_a, sample_b))
        for k in range(self.nb_params):
            sample_ek = sample_a.copy()
            sample_ek[k] = sample_b[k]
            matrix = np.vstack((matrix, sample_ek))

        if self.second_order:
            for k in range(self.nb_params):
                sample_ck = sample_b.copy()
                sample_ck[k] = sample_a[k]
                matrix = np.vstack((matrix, sample_ck))

        return matrix

    def pick_freeze_sample(self, nb_samples: int) -> NDArray:
        """Generates samples using the pick-freeze method."""
        # generate one pick-freeze matrix to determine its size
        def __pf_matrix():
            sample_a = self.base_sample(1).squeeze()
            sample_b = self.base_sample(1).squeeze()
            return self.generate_pick_freeze_matrix(sample_a, sample_b)

        pf_matrix = __pf_matrix()

        # ensure the number of samples is a multiple of the pick-freeze matrix size
        assert nb_samples % pf_matrix.shape[0] == 0, (
            f"nb_samples ({nb_samples}) must be a multiple of pick-freeze matrix size "
            f"({pf_matrix.shape[0]})."
        )
        nb_matrices = nb_samples // pf_matrix.shape[0]

        # stack multiple pick-freeze matrices to generate the required number of samples
        return np.vstack([pf_matrix, *[__pf_matrix() for _ in range(nb_matrices - 1)]])

    @override
    def sample(self, nb_samples: int) -> Any:
        if self.apply_pick_freeze:
            return self.pick_freeze_sample(nb_samples)
        return self.base_sample(nb_samples)
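

# A small sketch, not part of the original module, making the pick-freeze bookkeeping
# explicit: with first-order indices only, each pick-freeze block stacks A, B, and one
# E_k row per parameter, i.e. nb_params + 2 rows; enabling second-order indices adds
# another nb_params rows (the C_k rows), giving 2 * nb_params + 2. `nb_sims` must be a
# multiple of this block size, as asserted in `pick_freeze_sample`.
def _pick_freeze_block_size(nb_params: int, second_order: bool) -> int:
    # rows produced by one call to generate_pick_freeze_matrix
    return (2 * nb_params + 2) if second_order else (nb_params + 2)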


class RandomUniformSamplerMixIn:
    """This mixin defines the `base_sample()` method for random uniform sampling with NumPy."""

    nb_params: int
    l_bounds: NDArray
    u_bounds: NDArray

    def base_sample(self, nb_samples: int) -> NDArray:
        """Generates random samples uniformly distributed between the specified bounds.

        ### Parameters
        - **nb_samples** (`int`): The number of samples to generate.

        ### Returns
        - `NDArray`: A numpy array of shape `(nb_samples, nb_params)`
          containing the sampled parameters."""

        return np.random.uniform(
            self.l_bounds, self.u_bounds, size=(nb_samples, self.nb_params)
        ).squeeze()


class RandomUniform(RandomUniformSamplerMixIn, SobolBaseExperiment):
    """Random uniform sampling experiment that integrates Sobol and uniform sampling methods."""

    def __init__(
        self,
        nb_params: int,
        nb_sims: int,
        l_bounds: List[Union[float, int]],
        u_bounds: List[Union[float, int]],
        seed: Optional[int] = None,
        dtype=np.float64,
        apply_pick_freeze: bool = False,
        second_order: bool = False,
    ) -> None:

        SobolBaseExperiment.__init__(
            self, nb_params, nb_sims, l_bounds, u_bounds, seed, dtype,
            apply_pick_freeze, second_order
        )


class QMCSamplerMixIn:
    """Mixin class for samplers based on `scipy.stats.qmc` methods,
    such as Halton and Latin Hypercube.

    This class defines the `base_sample()` method for generating samples
    using quasi-Monte Carlo methods, including the Halton and Latin Hypercube samplers."""

    l_bounds: NDArray
    u_bounds: NDArray
    _sampler: Union[qmc.Halton, qmc.LatinHypercube]

    def base_sample(self, nb_samples: int) -> NDArray:
        """Generates scaled samples using the chosen quasi-Monte Carlo sampler.

        ### Parameters
        - **nb_samples** (`int`): The number of samples to generate.

        ### Returns
        - `NDArray`: A numpy array of shape `(nb_samples, nb_params)`
          containing the sampled parameters."""
        return qmc.scale(
            self._sampler.random(nb_samples), self.l_bounds, self.u_bounds
        ).squeeze()
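

# A standalone sketch, not part of the original module, of what `base_sample` does under
# the hood: draw points in the unit hypercube with a scipy QMC engine and rescale them to
# the requested bounds. The dimensions, seed, and bounds below are hypothetical;
# `qmc.Halton`, `engine.random`, and `qmc.scale` are the public scipy.stats.qmc API.
def _qmc_scaling_demo() -> NDArray:
    engine = qmc.Halton(d=2, scramble=True, seed=42)
    unit_points = engine.random(8)  # shape (8, 2), values in [0, 1)
    # rescale each column to its own [lower, upper] interval
    return qmc.scale(unit_points, [0.0, -1.0], [10.0, 1.0])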


class HaltonSamplerMixIn(QMCSamplerMixIn):
    """Mixin class that sets the `scipy.stats.qmc.Halton` sampler."""

    def __init__(self, nb_params: int, seed: Optional[int] = None) -> None:
        self._sampler = qmc.Halton(d=nb_params, scramble=True, seed=seed)
        QMCSamplerMixIn.__init__(self)


class HaltonGenerator(HaltonSamplerMixIn, SobolBaseExperiment):
    """Deterministic sample generator based on the scipy Halton sequence:
    https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.qmc.Halton.html
    """
    def __init__(
        self,
        nb_params: int,
        nb_sims: int,
        l_bounds: List[Union[float, int]],
        u_bounds: List[Union[float, int]],
        seed: Optional[int] = None,
        dtype=np.float64,
        apply_pick_freeze: bool = False,
        second_order: bool = False,
    ) -> None:

        HaltonSamplerMixIn.__init__(self, nb_params, seed)
        SobolBaseExperiment.__init__(
            self, nb_params, nb_sims, l_bounds, u_bounds, seed, dtype,
            apply_pick_freeze, second_order
        )


class LatinHypercubeSamplerMixIn(QMCSamplerMixIn):
    """Mixin class that sets the `scipy.stats.qmc.LatinHypercube` sampler."""

    def __init__(self, nb_params: int, seed: Optional[int] = None) -> None:
        self._sampler = qmc.LatinHypercube(d=nb_params, scramble=False, seed=seed)
        QMCSamplerMixIn.__init__(self)


class LHSGenerator(LatinHypercubeSamplerMixIn, SobolBaseExperiment):
    """Non-deterministic sample generator based on scipy LHS sampling:
    https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.qmc.LatinHypercube.html
    """

    def __init__(
        self,
        nb_params: int,
        nb_sims: int,
        l_bounds: List[Union[float, int]],
        u_bounds: List[Union[float, int]],
        seed: Optional[int] = None,
        dtype=np.float64,
        apply_pick_freeze: bool = False,
        second_order: bool = False,
    ) -> None:
        LatinHypercubeSamplerMixIn.__init__(self, nb_params, seed)
        SobolBaseExperiment.__init__(
            self, nb_params, nb_sims, l_bounds, u_bounds, seed, dtype,
            apply_pick_freeze, second_order
        )


class MyGetItem(EnumMeta):
    """`EnumMeta` subclass that allows case-insensitive member lookup and falls back to
    `RANDOM_UNIFORM` for names containing "random" or "uniform"."""
    def __getitem__(self, name):
        try:
            return super().__getitem__(name.upper().strip())
        except KeyError as e:
            if "random" in name.lower().strip() or "uniform" in name.lower().strip():
                return super().__getitem__("RANDOM_UNIFORM")
            raise e


@unique
class ParameterSamplerType(Enum, metaclass=MyGetItem):
    """Enum to choose the parameter sampler type.

    - RANDOM_UNIFORM = 0
    - HALTON = 1
    - LHS = 2"""

    RANDOM_UNIFORM = 0
    HALTON = 1
    LHS = 2
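

# A short sketch, not part of the original module, of the lookup behaviour provided by the
# `MyGetItem` metaclass: names are matched case-insensitively after stripping whitespace,
# and anything containing "random" or "uniform" falls back to RANDOM_UNIFORM.
def _sampler_type_lookup_demo() -> None:
    assert ParameterSamplerType["halton"] is ParameterSamplerType.HALTON
    assert ParameterSamplerType[" LHS "] is ParameterSamplerType.LHS
    assert ParameterSamplerType["uniform"] is ParameterSamplerType.RANDOM_UNIFORM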


def make_parameter_sampler(
    sampler_t: Union[ParameterSamplerType, Type[ParameterSamplerClass]], **kwargs
) -> BaseExperiment:
    """Creates and returns an instance of a parameter sampler based on the specified sampler type.

    This function supports both predefined sampler types from `ParameterSamplerType` and custom
    sampler classes passed directly. It instantiates the appropriate sampler class with the
    provided keyword arguments.

    ### Parameters
    - **sampler_t** (`Union[ParameterSamplerType, Type[ParameterSamplerClass]]`):
        - `ParameterSamplerType`: An enum value from `ParameterSamplerType`
          (`RANDOM_UNIFORM`, `HALTON`, `LHS`).
        - `Type[ParameterSamplerClass]`: A predefined or custom class type to be instantiated
          (not an instance).
    - **kwargs**: Additional keyword arguments passed to the sampler class constructor.

    ### Returns
    - `BaseExperiment`: An instance of the chosen parameter sampler.

    ### Raises
    - `ValueError`: If `sampler_t` is not one of the supported enum values or
      a valid subclass of `BaseExperiment`."""

    if isinstance(sampler_t, ParameterSamplerType):
        if sampler_t is ParameterSamplerType.RANDOM_UNIFORM:
            return RandomUniform(**kwargs)
        elif sampler_t is ParameterSamplerType.HALTON:
            return HaltonGenerator(**kwargs)
        elif sampler_t is ParameterSamplerType.LHS:
            return LHSGenerator(**kwargs)
        else:
            raise ValueError(
                "ParameterSamplerType must be one of "
                "`RANDOM_UNIFORM`, `HALTON`, or `LHS`."
            )
    elif isinstance(sampler_t, type) and issubclass(sampler_t, BaseExperiment):
        return sampler_t(**kwargs)
    else:
        raise ValueError(
            "Invalid sampler type. Must be a ParameterSamplerType enum "
            "or a custom sampler class."
        )
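

# A usage sketch, not part of the original module. Constructing any sampler assumes the
# usual runtime context of the server: an MPI launch and an existing `checkpoints/`
# directory for the parameter memmap and sampler metadata. The argument values are
# hypothetical.
def _make_sampler_demo() -> BaseExperiment:
    # predefined sampler chosen through the enum
    sampler = make_parameter_sampler(
        ParameterSamplerType.HALTON,
        nb_params=3, nb_sims=50, l_bounds=[0.0], u_bounds=[1.0], seed=123,
    )
    # equivalent call with a sampler class (not an instance) passed directly
    same_kind = make_parameter_sampler(
        HaltonGenerator,
        nb_params=3, nb_sims=50, l_bounds=[0.0], u_bounds=[1.0], seed=123,
    )
    assert type(sampler) is type(same_kind)
    return sampler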