Coverage for melissa/server/sensitivity_analysis/sensitivity_analysis_server.py: 44% (260 statements)
1"""This script defines `SensitivityAnalysisServer` class."""
3import logging
4from pathlib import Path
5from dataclasses import dataclass
6from typing_extensions import override
7from typing import Any, Dict, List, Optional, Set, Tuple, Union, Callable
9import cloudpickle
10import numpy as np
11from numpy.typing import NDArray
12from mpi4py import MPI
13import rapidjson
15from iterative_stats.iterative_moments import IterativeMoments
16from iterative_stats.sensitivity.sensitivity_martinez import IterativeSensitivityMartinez
17from melissa.launcher import message
18from melissa.server.base_server import BaseServer
19from melissa.server.simulation import PartialSimulationData, Simulation
20from melissa.server.exceptions import ReceptionError
21from melissa.utility.rank_helper import MPI2NP_DT
24logger = logging.getLogger(__name__)


@dataclass
class FieldMetadata:
    """A class to store and manage metadata for a field.

    ### Parameters
    - **size** (`int`): The number of local vectors, i.e. one per server rank.

    ### Attributes
    - **local_vect_sizes** (`NDArray`): An array containing the local vector
    size for each rank.
    - **global_vect_size** (`int`): The total global vector size, calculated as
    the sum of the local vector sizes."""

    local_vect_sizes: NDArray
    global_vect_size: int

    def __init__(self, size: int):
        self.local_vect_sizes = np.zeros(size, dtype=MPI2NP_DT["int"])

    def compute_global_size(self):
        """Computes the global vector size by summing the local vector sizes."""
        self.global_vect_size = int(np.sum(self.local_vect_sizes))
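
# A minimal usage sketch of `FieldMetadata` (sizes are hypothetical):
#
#     meta = FieldMetadata(size=4)                   # four server ranks
#     meta.local_vect_sizes[:] = [100, 100, 50, 50]  # normally filled by Allgather
#     meta.compute_global_size()
#     assert meta.global_vect_size == 300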


class SensitivityAnalysisServer(BaseServer):
    """`SensitivityAnalysisServer` extends the `BaseServer` class and provides specialized
    functionality for sensitivity analysis. The primary tasks of this class are:

    - Generating parameters and scripts using pick-freeze sampling.
    - Calculating statistical moments with `IterativeMoments` and Sobol indices
    with the `IterativeSensitivityMartinez` method.
    - Overriding or redefining abstract methods.

    ### Parameters
    - **config_dict** (`Dict[str, Any]`):
    A dictionary containing configuration settings for initializing
    the sensitivity analysis server.

    ### Attributes
    - **sobol_op** (`bool`): Indicates if Sobol sensitivity analysis is enabled.
    - **second_order** (`bool`): Flag to activate second-order indices during
    pick-freeze parameter sampling.
    - **__mean** (`bool`): Flag for computing the mean as part of the statistical analysis.
    - **__variance** (`bool`): Flag for computing the variance as part of the statistical analysis.
    - **__skewness** (`bool`): Flag for computing the skewness as part of the statistical analysis.
    - **__kurtosis** (`bool`): Flag for computing the kurtosis as part of the statistical analysis.
    - **__seen_ranks** (`Set[int]`): Set of ranks corresponding to clients that have been processed.
    - **__checkpoint_count** (`int`): Counter for the number of checkpoints performed.
    - **__checkpoint_interval** (`int`): Interval at which checkpoints are taken,
    specified in the configuration.
    - **__max_order** (`int`): The maximum statistical moment order to compute
    (for example, mean = 1, variance = 2, etc.).
    - **__melissa_moments** (`Dict[tuple, IterativeMoments]`): Dictionary storing statistical
    moments for each field, client rank, and time step.
    - **__pick_freeze_matrix** (`List[List[Union[int, float]]]`): Matrix of frozen
    parameters for Sobol computations (used if Sobol analysis is enabled).
    - **__melissa_sobol** (`Dict[tuple, IterativeSensitivityMartinez]`): Dictionary storing
    Sobol sensitivity indices for each field, client rank, and time step, if enabled."""

    def __init__(self, config_dict: Dict[str, Any]) -> None:

        super().__init__(config_dict)

        Path('./results/').mkdir(parents=True, exist_ok=True)
        sa_config: Dict[str, Any] = config_dict["sa_config"]

        self.sobol_op = sa_config.get("sobol_indices", False)
        self.second_order = self.sobol_op and sa_config.get("second_order", False)
        self._check_group_size()

        self.__mean = sa_config.get("mean", False)
        self.__variance = sa_config.get("variance", False)
        self.__skewness = sa_config.get("skewness", False)
        self.__kurtosis = sa_config.get("kurtosis", False)

        self.__seen_ranks: Set[int] = set()  # set of seen client ranks
        self.__checkpoint_count: int = 0
        self.__checkpoint_interval: int = sa_config["checkpoint_interval"]

        # Instantiate the melissa statistical data structures
        self.__max_order: int = 0
        # {(field, clt_rank, t): IterativeMoments}
        self.__melissa_moments: Dict[Tuple[str, int, int], IterativeMoments] = {}
        if self.sobol_op:
            self.__pick_freeze_matrix: List[List[Union[int, float]]] = []
            # {(field, clt_rank, t): IterativeSensitivityMartinez}
            self.__melissa_sobol: Dict[Tuple[str, int, int], IterativeSensitivityMartinez] = {}

        # the highest requested moment determines the maximum order
        # (mean = 1, variance = 2, skewness = 3, kurtosis = 4)
        if self.__kurtosis:
            self.__max_order = 4
        elif self.__skewness:
            self.__max_order = 3
        elif self.__variance:
            self.__max_order = 2
        elif self.__mean:
            self.__max_order = 1
        else:
            self.__max_order = 0

        # called only to warn about configured but unimplemented statistics.
        self.__unimplemented_stats(sa_config)
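
    # A minimal sketch of the `sa_config` block this constructor reads
    # (keys as accessed above; values are hypothetical):
    #
    #     "sa_config": {
    #         "mean": True,
    #         "variance": True,
    #         "skewness": False,
    #         "kurtosis": False,
    #         "sobol_indices": True,
    #         "second_order": False,
    #         "checkpoint_interval": 50,
    #     }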

    @property
    def melissa_moments(self) -> Dict[Tuple[str, int, int], IterativeMoments]:
        return self.__melissa_moments

    @property
    def melissa_sobol(self) -> Dict[Tuple[str, int, int], IterativeSensitivityMartinez]:
        return self.__melissa_sobol

    # keeping it modularized for better code management.
    def __unimplemented_stats(self, sa_config):
        """Warns about requested statistics that have no implementation yet."""

        self.__min = sa_config.get("min", False)
        self.__max = sa_config.get("max", False)
        self.__threshold_exceedance = sa_config.get("threshold_exceedance", False)
        self.__threshold_values = sa_config.get("threshold_values", [0.7, 0.8])
        self.__quantiles = sa_config.get("quantiles", False)
        self.__quantile_values = sa_config.get(
            "quantile_values", [0.05, 0.25, 0.5, 0.75, 0.95]
        )

        if self.__min or self.__max:
            logger.warning("min/max not implemented")
        if self.__threshold_exceedance:
            logger.warning("threshold exceedance not implemented")
        if self.__quantiles:
            logger.warning("quantiles not implemented")

    @override
    def _check_group_size(self) -> None:
        """Based on the Sobol flag, validates the given group size
        and updates the number of clients."""

        if self.sobol_op:
            self.group_size = self.nb_parameters + 2
            self.nb_groups = self.nb_clients
            self.nb_clients = self.group_size * self.nb_groups
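
    # Sketch of the group sizing above: pick-freeze draws two independent
    # sample matrices (A and B) plus one "frozen" matrix per parameter, hence
    # group_size = nb_parameters + 2. For example, with nb_parameters = 2 and
    # 10 configured clients (reinterpreted as groups), this gives
    # group_size = 4, nb_groups = 10, and nb_clients = 4 * 10 = 40 simulations.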

    @override
    def _verify_and_update_sampler_kwargs(self, sampler_t, **kwargs) -> Dict[str, Any]:
        if not self.offline_mode:
            kwargs.update({
                "apply_pick_freeze": self.sobol_op,
                "second_order": self.sobol_op and self.second_order
            })

        return super()._verify_and_update_sampler_kwargs(sampler_t, **kwargs)

    @override
    def _receive(self):
        """Receives and processes data sent by the clients."""

        try:
            self._is_receiving = True
            received_samples = 0
            while not self._all_done():
                status = self.poll_sockets()
                if status is not None:
                    if isinstance(status, PartialSimulationData):
                        logger.debug(
                            f"Rank {self.rank}>> "
                            f"sim-id={status.simulation_id}, "
                            f"time-step={status.time_step} received."
                        )
                        received_samples += 1

                        # compute the statistics on the received data
                        self._compute_stats(status)
                        self._checkpoint()
            self._is_receiving = False
        except Exception as e:
            raise ReceptionError() from e

    @override
    def _server_online(self):
        """Steps to perform while the server is online."""
        self._receive()

    @override
    def _server_offline(self):
        """Optional post-processing steps."""

        self._start_pinger_thread()
        self._melissa_write_stats()
        self._stop_pinger_thread()

    @override
    def start(self):
        """The main entrypoint for the server events."""
        if self._restart:
            self._restart_from_checkpoint()
            self._restart_groups()
        else:
            self._launch_groups(list(range(0, self._job_limit)))
        self.setup_environment()
        self._server_online()
        self._server_offline()
        self._server_finalize()

    @override
    def _process_partial_data_reception(self,
                                        _: Simulation,
                                        simulation_data: PartialSimulationData
                                        ) -> PartialSimulationData:
        return simulation_data

    @override
    def _process_complete_data_reception(self,
                                         simulation: Simulation,
                                         simulation_data: PartialSimulationData
                                         ) -> PartialSimulationData:

        simulation.clear_data(simulation_data.client_rank, simulation_data.time_step)
        return simulation_data

    def __get_cached_sobol_data(self,
                                pdata: PartialSimulationData) -> Union[bool, Optional[NDArray]]:
        """Caches the time steps received from each simulation in a group for Sobol sampling.

        ### Parameters
        - **pdata** (`PartialSimulationData`): The data message received from the simulation.

        ### Returns
        - **Union[bool, Optional[NDArray]]**:
            - `NDArray` if this time step has been received from every simulation in the group.
            - `False` otherwise."""

        group_id = self._get_group_id_by_simulation(pdata.simulation_id)
        current_group = self._groups[group_id]
        current_group.cache(pdata)
        np_data = current_group.get_cached(
            pdata.field, pdata.client_rank, pdata.time_step
        )

        return np_data if len(np_data) == self.group_size else False
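
    # Note (an interpretation based on the pick-freeze scheme, not stated in
    # the code): the cached array stacks one row per group member, so in
    # `_compute_stats` below `np_data[0]` and `np_data[1]` are the solutions
    # for the two independent pick-freeze samples, and the remaining
    # nb_parameters rows hold the frozen samples consumed by
    # `IterativeSensitivityMartinez.increment`.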

    def _compute_stats(self, pdata: PartialSimulationData) -> None:
        """Iteratively updates statistical moments and, if `sobol_op` is set,
        Sobol sensitivity indices.
        - Initializes `IterativeMoments` and `IterativeSensitivityMartinez` objects
        for each new combination of field, client rank, and time step.
        - Handles Sobol calculations by caching the received data per group.

        ### Parameters
        - **pdata** (`PartialSimulationData`): The data message received from the simulation."""

        np_data: Union[bool, Optional[NDArray]]

        self.__seen_ranks.add(pdata.client_rank)
        current_key = (pdata.field, pdata.client_rank, pdata.time_step)
        if current_key not in self.__melissa_moments:
            self.__melissa_moments[current_key] = IterativeMoments(
                self.__max_order,
                dim=pdata.data_size
            )
            if self.sobol_op:
                self.__melissa_sobol[current_key] = IterativeSensitivityMartinez(
                    nb_parms=self.nb_parameters,
                    dim=pdata.data_size
                )

        if not self.sobol_op:
            np_data = pdata.payload.data.reshape(-1, pdata.data_size)
            self.__melissa_moments[current_key].increment(np_data[0])
        else:
            np_data = self.__get_cached_sobol_data(pdata)
            if isinstance(np_data, np.ndarray):
                self.__melissa_moments[current_key].increment(np_data[0])
                # increment the sobol data structure
                self.__melissa_sobol[current_key].increment(np_data)
                # increment the moments with the second solution
                self.__melissa_moments[current_key].increment(np_data[1])

        del np_data

    def __gather_data(self,
                      local_vect_sizes: NDArray[np.int32],
                      d_buffer: NDArray[np.float64]
                      ) -> NDArray[np.float64]:
        """Gathers data from all ranks to rank 0 using MPI's Gatherv function.

        ### Parameters:
        - **local_vect_sizes** (`NDArray[np.int32]`): An array containing the size
        of the data vector per rank.
        - **d_buffer** (`NDArray[np.float64]`): An array containing the local data to be gathered.

        ### Returns:
        - `NDArray[np.float64]`: An array with the gathered data at rank 0."""

        offsets = [0] + list(np.cumsum(local_vect_sizes))[:-1]
        # TODO: mpi4py version 4+ has Gatherv_init which removes initialization overhead
        # for persistent gather calls.
        self.comm.Gatherv(
            d_buffer[:local_vect_sizes[self.rank]],
            [d_buffer, local_vect_sizes, offsets, MPI.DOUBLE],
            root=0
        )
        return d_buffer
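
    # Illustrative Gatherv layout (hypothetical sizes): with
    # local_vect_sizes = [3, 2, 4], the displacements are offsets = [0, 3, 5],
    # so on rank 0 the gathered segments land at d_buffer[0:3], d_buffer[3:5],
    # and d_buffer[5:9] respectively.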

    def __gather_and_write_moments(self,
                                   field: str,
                                   global_vect_size: int,
                                   local_vect_sizes: NDArray[np.int32],
                                   stat_type: str,
                                   values_fn: Callable) -> None:
        """Gathers data from all ranks for the specified statistic,
        and writes the results.

        ### Parameters:
        - **field** (`str`): The field for which the data is to be gathered.
        - **global_vect_size** (`int`): The size of the global vector.
        - **local_vect_sizes** (`NDArray[np.int32]`): An array containing the size
        of the data vector per rank.
        - **stat_type** (`str`): A string specifying the statistic to gather
        (called per moment; for example, `mean`).
        - **values_fn** (`Callable`): A function applied to an `IterativeMoments`
        object to extract the values of the requested statistic
        (called per moment; for example, `lambda m: m.get_mean()`)."""

        d_buffer = np.zeros(global_vect_size)
        self.comm.Barrier()

        assert self.time_steps_known, "melissa_finalize() must be called on the client-side."

        for t in range(self.nb_time_steps):
            file_name = f"./results/results.{field}_{stat_type}." \
                        f"{str(t + 1).zfill(len(str(self.nb_time_steps)))}"

            temp_offset = 0
            for rank in range(self.client_comm_size):
                key = (field, rank, t)
                values = values_fn(self.__melissa_moments[key])

                if np.size(values) > 0:
                    last_offset = temp_offset + np.size(values)
                    d_buffer[temp_offset:last_offset] = values
                    temp_offset = last_offset

            d_buffer = self.__gather_data(local_vect_sizes, d_buffer)
            if self.rank == 0:
                np.savetxt(file_name, d_buffer)
                logger.info(f"file name: {file_name}")
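
    # Example of the naming scheme above (field name hypothetical): with
    # nb_time_steps = 100, the mean of field "temperature" at t = 4 is written
    # to ./results/results.temperature_mean.005 (time steps are 1-based and
    # zero-padded to the width of nb_time_steps).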

    def __gather_and_write_sobol(self,
                                 field: str,
                                 global_vect_size: int,
                                 local_vect_sizes: NDArray[np.int32]) -> None:
        """Gathers Sobol data from all ranks, and writes the results.

        ### Parameters:
        - **field** (`str`): The field for which the data is to be gathered.
        - **global_vect_size** (`int`): The size of the global vector.
        - **local_vect_sizes** (`NDArray[np.int32]`): An array containing the size
        of the data vector per rank."""

        d_buffer_a = np.zeros(global_vect_size)
        d_buffer_b = np.zeros(global_vect_size)
        self.comm.Barrier()

        for param in range(self.nb_parameters):
            for t in range(self.nb_time_steps):
                file_name_b = f"./results/results.{field}_sobol{str(param)}." \
                              f"{str(t + 1).zfill(len(str(self.nb_time_steps)))}"

                file_name_a = f"./results/results.{field}_sobol_tot{str(param)}." \
                              f"{str(t + 1).zfill(len(str(self.nb_time_steps)))}"

                temp_offset = 0
                for rank in range(self.client_comm_size):
                    key = (field, rank, t)
                    pearson_b = self.__melissa_sobol[key].pearson_B[:, param]
                    pearson_a = self.__melissa_sobol[key].pearson_A[:, param]

                    if (
                        np.size(pearson_b) > 0
                        and np.size(pearson_a) == np.size(pearson_b)
                    ):
                        last_offset = temp_offset + np.size(pearson_b)
                        d_buffer_b[temp_offset:last_offset] = pearson_b
                        d_buffer_a[temp_offset:last_offset] = pearson_a
                        temp_offset = last_offset

                d_buffer_b = self.__gather_data(local_vect_sizes, d_buffer_b)
                d_buffer_a = self.__gather_data(local_vect_sizes, d_buffer_a)

                if self.rank == 0:
                    np.savetxt(file_name_b, d_buffer_b)
                    logger.info(f"file name: {file_name_b}")
                    np.savetxt(file_name_a, d_buffer_a)
                    logger.info(f"file name: {file_name_a}")
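
    # Reading the two outputs above, as suggested by their names:
    # results.<field>_sobol<i>.<t> holds the first-order index estimates for
    # parameter i (built from pearson_B), while results.<field>_sobol_tot<i>.<t>
    # holds the total-order estimates (built from pearson_A). This mapping is
    # an interpretation of the pick-freeze/Martinez estimator, not stated in
    # the code itself.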

    def _melissa_write_stats(self):
        """Gathers and writes all results to the `results/` folder."""

        # turn server monitoring off
        if self.rank == 0:
            snd_msg = self._encode_msg(message.StopTimeoutMonitoring())
            self._launcherfd.send(snd_msg)

        # broadcast client_comm_size from rank 0 to all server ranks
        self.client_comm_size = self.comm.bcast(self.client_comm_size, root=0)
        logger.info(f"Rank {self.rank}>> gathered client-comm-size={self.client_comm_size}")

        # update __melissa_moments and __melissa_sobol with missing client ranks;
        # these are just placeholders and do not contribute to the results
        unseen_ranks = set(range(self.client_comm_size)) - self.__seen_ranks
        for field in self.fields:
            for client_rank in unseen_ranks:
                for t in range(self.nb_time_steps):
                    key = (field, client_rank, t)
                    self.__melissa_moments[key] = IterativeMoments(
                        self.__max_order,
                        dim=0
                    )
                    if self.sobol_op:
                        self.__melissa_sobol[key] = IterativeSensitivityMartinez(
                            nb_parms=self.nb_parameters,
                            dim=0
                        )

        # avoiding code repetition
        stat2values_fn: Dict[str, Tuple[bool, Callable]] = {
            "mean": (self.__mean, lambda m: m.get_mean()),
            "variance": (self.__variance, lambda m: m.get_variance()),
            "skewness": (self.__skewness, lambda m: m.get_skewness()),
            "kurtosis": (self.__kurtosis, lambda m: m.get_kurtosis())
        }

        # compute the vector size across all client ranks
        # and gather them to calculate the global size for every server rank.
        # this is done across all fields,
        # and finally the results for all moments and sobol are gathered.
        field_metadata: Dict[str, FieldMetadata] = {}
        for field in self.fields:
            field_metadata[field] = FieldMetadata(self.comm_size)
            vect_size = np.zeros(1, dtype=MPI2NP_DT["int"])
            for client_rank in range(self.client_comm_size):
                key = (field, client_rank, 0)
                vect_size += np.size(self.__melissa_moments[key].get_mean())

            self.comm.Allgather(
                [vect_size, MPI.INT],
                [field_metadata[field].local_vect_sizes, MPI.INT]
            )

            field_metadata[field].compute_global_size()
            local_vect_sizes = field_metadata[field].local_vect_sizes
            global_vect_size = field_metadata[field].global_vect_size

            logger.info(
                f"Rank {self.rank}>> field=\"{field}\", "
                f"local-vect-size={vect_size[-1]}, "
                f"global-vect-size={global_vect_size}"
            )
            for stat_type, (condition, values_fn) in stat2values_fn.items():
                if condition:
                    self.__gather_and_write_moments(
                        field,
                        global_vect_size,
                        local_vect_sizes,
                        stat_type,
                        values_fn
                    )

            if self.sobol_op:
                self.__gather_and_write_sobol(field, global_vect_size, local_vect_sizes)

    @override
    def _checkpoint(self, **kwargs) -> None:
        """Checkpoint moments and Sobol information."""

        if self.no_fault_tolerance or not self.__checkpoint_interval:
            return

        self.__checkpoint_count += 1
        if self.__checkpoint_count % self.__checkpoint_interval != 0:
            return

        self._save_base_state()

        stats_metadata = {
            "seen_ranks": list(self.__seen_ranks),
            "nb_time_steps": self.nb_time_steps
        }

        # logger.info(f"Checkpointing moments {self.__melissa_moments}")
        with open(f"checkpoints/{self.rank}/melissa_moments.pkl", 'wb') as f:
            cloudpickle.dump(self.__melissa_moments, f)

        if self.sobol_op:
            with open(f"checkpoints/{self.rank}/melissa_sobol.pkl", 'wb') as f:
                cloudpickle.dump(self.__melissa_sobol, f)

        with open(f"checkpoints/{self.rank}/stats_metadata.json", 'w') as f:
            rapidjson.dump(stats_metadata, f)
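
    # Resulting checkpoint layout per server rank (the sobol file exists only
    # when sobol_indices is enabled):
    #
    #     checkpoints/<rank>/melissa_moments.pkl
    #     checkpoints/<rank>/melissa_sobol.pkl
    #     checkpoints/<rank>/stats_metadata.json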

    @override
    def _restart_from_checkpoint(self, **kwargs) -> None:
        """Loads from the last checkpoint, in case of a restart."""

        self._load_base_state()

        with open(f"checkpoints/{self.rank}/melissa_moments.pkl", 'rb') as f:
            self.__melissa_moments = cloudpickle.load(f)

        if self.sobol_op:
            with open(f"checkpoints/{self.rank}/melissa_sobol.pkl", 'rb') as f:
                self.__melissa_sobol = cloudpickle.load(f)

        with open(f"checkpoints/{self.rank}/stats_metadata.json", 'r') as f:
            stats_metadata = rapidjson.load(f)

        self.__seen_ranks = set(stats_metadata["seen_ranks"])
        self.nb_time_steps = stats_metadata["nb_time_steps"]