Coverage for melissa/server/sensitivity_analysis/sensitivity_analysis_server.py: 44% (260 statements)


1"""This script defines `SensitivityAnalysisServer` class.""" 

2 

3import logging 

4from pathlib import Path 

5from dataclasses import dataclass 

6from typing_extensions import override 

7from typing import Any, Dict, List, Optional, Set, Tuple, Union, Callable 

8 

9import cloudpickle 

10import numpy as np 

11from numpy.typing import NDArray 

12from mpi4py import MPI 

13import rapidjson 

14 

15from iterative_stats.iterative_moments import IterativeMoments 

16from iterative_stats.sensitivity.sensitivity_martinez import IterativeSensitivityMartinez 

17from melissa.launcher import message 

18from melissa.server.base_server import BaseServer 

19from melissa.server.simulation import PartialSimulationData, Simulation 

20from melissa.server.exceptions import ReceptionError 

21from melissa.utility.rank_helper import MPI2NP_DT 

22 

23 

24logger = logging.getLogger(__name__) 

25 

26 

@dataclass
class FieldMetadata:
    """A class to store and manage metadata for a field.

    ### Parameters
    - **size** (`int`): The number of local vectors, i.e. the number of client ranks.

    ### Attributes
    - **local_vect_sizes** (`NDArray`): An array containing the local vector
    sizes for each process.
    - **global_vect_size** (`int`): The total global vector size, calculated as the sum
    of the local vector sizes."""

    local_vect_sizes: NDArray
    global_vect_size: int

    def __init__(self, size: int):
        self.local_vect_sizes = np.zeros(size, dtype=MPI2NP_DT["int"])

    def compute_global_size(self):
        """Computes the global vector size by summing the local vector sizes."""
        self.global_vect_size = int(np.sum(self.local_vect_sizes))

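
# A minimal usage sketch for `FieldMetadata` (illustrative only; the sizes are
# hypothetical and this snippet is not executed by the server):
#
#     meta = FieldMetadata(size=4)
#     meta.local_vect_sizes[:] = [10, 12, 10, 8]
#     meta.compute_global_size()
#     assert meta.global_vect_size == 40
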

class SensitivityAnalysisServer(BaseServer):
    """`SensitivityAnalysisServer` class extends the `BaseServer` class and provides specialized
    functionalities for sensitivity analysis. The primary tasks of this class include:

    - Generating parameters and scripts using pick-freeze sampling.
    - Calculating statistical moments with `IterativeMoments` and Sobol indices
    with the `IterativeSensitivityMartinez` method.
    - Overriding or redefining abstract methods.

    ### Parameters
    - **config_dict** (`Dict[str, Any]`):
        A dictionary containing configuration settings for initializing
        the sensitivity analysis server.

    ### Attributes
    - **sobol_op** (`bool`): Indicates if Sobol sensitivity analysis is enabled.
    - **second_order** (`bool`): Flag to activate second-order parameter sampling
    during pick-freeze.
    - **__mean** (`bool`): Flag for computing the mean as part of the statistical analysis.
    - **__variance** (`bool`): Flag for computing the variance as part of the statistical analysis.
    - **__skewness** (`bool`): Flag for computing the skewness as part of the statistical analysis.
    - **__kurtosis** (`bool`): Flag for computing the kurtosis as part of the statistical analysis.
    - **__seen_ranks** (`Set[int]`): Set of ranks corresponding to clients that have been processed.
    - **__checkpoint_count** (`int`): Counter for the number of checkpoints performed.
    - **__checkpoint_interval** (`int`): Interval at which checkpoints are taken,
    specified in the configuration.
    - **__max_order** (`int`): The maximum statistical moment order to compute
    (for example, mean = 1, variance = 2, etc.).
    - **__melissa_moments** (`Dict[tuple, IterativeMoments]`): Dictionary to store statistical
    moments for each field, rank, and time step.
    - **__pick_freeze_matrix** (`List[List[Union[int, float]]]`): Matrix of
    frozen parameter sets for Sobol computations (used if Sobol analysis is enabled).
    - **__melissa_sobol** (`Dict[tuple, IterativeSensitivityMartinez]`): Dictionary to store
    Sobol sensitivity indices for each field, rank, and time step, if enabled."""
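
    # Background note on pick-freeze sampling (a sketch of the standard scheme;
    # the exact row ordering inside a Melissa group is an assumption here): for
    # p = nb_parameters, each Sobol group runs p + 2 simulations, one per row of
    # the sampling matrices A, B, and the p hybrids A_B^(i), where A_B^(i)
    # equals A with column i replaced by column i of B. Comparing the outputs
    # of A, B, and A_B^(i) yields first-order and total Sobol indices.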

    def __init__(self, config_dict: Dict[str, Any]) -> None:

        super().__init__(config_dict)

        Path('./results/').mkdir(parents=True, exist_ok=True)
        sa_config: Dict[str, Any] = config_dict["sa_config"]

        self.sobol_op = sa_config.get("sobol_indices", False)
        self.second_order = self.sobol_op and sa_config.get("second_order", False)
        self._check_group_size()

        self.__mean = sa_config.get("mean", False)
        self.__variance = sa_config.get("variance", False)
        self.__skewness = sa_config.get("skewness", False)
        self.__kurtosis = sa_config.get("kurtosis", False)

        self.__seen_ranks: Set[int] = set()  # set of seen client ranks
        self.__checkpoint_count: int = 0
        self.__checkpoint_interval: int = sa_config["checkpoint_interval"]

        # Instantiate the melissa statistical data structures
        self.__max_order: int = 0
        # {(field, clt_rank, t): StatisticalMoments}
        self.__melissa_moments: Dict[Tuple[str, int, int], IterativeMoments] = {}
        if self.sobol_op:
            self.__pick_freeze_matrix: List[List[Union[int, float]]] = []
            # {(field, clt_rank, t): IterSobolMartinez}
            self.__melissa_sobol: Dict[Tuple[str, int, int], IterativeSensitivityMartinez] = {}

        # the highest requested moment determines the order
        # (kurtosis = 4, skewness = 3, variance = 2, mean = 1)
        if self.__kurtosis:
            self.__max_order = 4
        elif self.__skewness:
            self.__max_order = 3
        elif self.__variance:
            self.__max_order = 2
        elif self.__mean:
            self.__max_order = 1
        else:
            self.__max_order = 0

        # called only to warn about statistics that are not implemented yet.
        self.__unimplemented_stats(sa_config)

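
    # A hedged example of the `sa_config` block consumed above (the key names
    # come from this file; the surrounding configuration layout is an
    # assumption):
    #
    #     "sa_config": {
    #         "sobol_indices": True,
    #         "second_order": False,
    #         "mean": True,
    #         "variance": True,
    #         "skewness": False,
    #         "kurtosis": False,
    #         "checkpoint_interval": 10
    #     }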

    @property
    def melissa_moments(self) -> Dict[Tuple[str, int, int], IterativeMoments]:
        return self.__melissa_moments

    @property
    def melissa_sobol(self) -> Dict[Tuple[str, int, int], IterativeSensitivityMartinez]:
        return self.__melissa_sobol

    # keeping it modularized for better code management.
    def __unimplemented_stats(self, sa_config):
        """No implementation available for the following yet."""

        self.__min = sa_config.get("min", False)
        self.__max = sa_config.get("max", False)
        self.__threshold_exceedance = sa_config.get("threshold_exceedance", False)
        self.__threshold_values = sa_config.get("threshold_values", [0.7, 0.8])
        self.__quantiles = sa_config.get("quantiles", False)
        self.__quantile_values = sa_config.get(
            "quantile_values", [0.05, 0.25, 0.5, 0.75, 0.95]
        )

        if self.__min or self.__max:
            logger.warning("min max not implemented")
        if self.__threshold_exceedance:
            logger.warning("threshold not implemented")
        if self.__quantiles:
            logger.warning("quantiles not implemented")


    @override
    def _check_group_size(self) -> None:
        """If Sobol indices are requested, sets the group size needed for
        pick-freeze sampling and updates the number of clients accordingly."""

        if self.sobol_op:
            self.group_size = self.nb_parameters + 2
            self.nb_groups = self.nb_clients
            self.nb_clients = self.group_size * self.nb_groups

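
    # For example (arithmetic implied by the assignments above): with
    # nb_parameters = 3 and nb_clients = 10 configured, Sobol mode yields
    # group_size = 5, nb_groups = 10, and nb_clients = 50 simulations in total.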

    @override
    def _verify_and_update_sampler_kwargs(self, sampler_t, **kwargs) -> Dict[str, Any]:
        if not self.offline_mode:
            kwargs.update({
                "apply_pick_freeze": self.sobol_op,
                "second_order": self.sobol_op and self.second_order
            })

        return super()._verify_and_update_sampler_kwargs(sampler_t, **kwargs)


    @override
    def _receive(self):
        """Receives and handles data sent by the client simulations."""

        try:
            self._is_receiving = True
            received_samples = 0
            while not self._all_done():
                status = self.poll_sockets()
                if status is not None:
                    if isinstance(status, PartialSimulationData):
                        logger.debug(
                            f"Rank {self.rank}>> "
                            f"sim-id={status.simulation_id}, "
                            f"time-step={status.time_step} received."
                        )
                        received_samples += 1

                        # compute the statistics on the received data
                        self._compute_stats(status)
                        self._checkpoint()

            self._is_receiving = False
        except Exception as e:
            raise ReceptionError() from e


    @override
    def _server_online(self):
        """Steps to perform while the server is online."""
        self._receive()

    @override
    def _server_offline(self):
        """Optional. Post-processing steps."""

        self._start_pinger_thread()
        self._melissa_write_stats()
        self._stop_pinger_thread()

    @override
    def start(self):
        """The main entrypoint for the server events."""
        if self._restart:
            self._restart_from_checkpoint()
            self._restart_groups()
        else:
            self._launch_groups(list(range(0, self._job_limit)))

        self.setup_environment()
        self._server_online()
        self._server_offline()
        self._server_finalize()

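
    # Hedged usage sketch (the config-loading step is an assumption; only
    # `SensitivityAnalysisServer(config_dict)` and `start()` come from this
    # module):
    #
    #     with open("config.json") as f:
    #         config_dict = rapidjson.load(f)
    #     server = SensitivityAnalysisServer(config_dict)
    #     server.start()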

    @override
    def _process_partial_data_reception(self,
                                        _: Simulation,
                                        simulation_data: PartialSimulationData
                                        ) -> PartialSimulationData:
        return simulation_data

    @override
    def _process_complete_data_reception(self,
                                         simulation: Simulation,
                                         simulation_data: PartialSimulationData
                                         ) -> PartialSimulationData:

        simulation.clear_data(simulation_data.client_rank, simulation_data.time_step)
        return simulation_data


    def __get_cached_sobol_data(self,
                                pdata: PartialSimulationData) -> Union[bool, Optional[NDArray]]:
        """Caches time steps received from each simulation in a group for Sobol sampling.

        ### Parameters
        - **pdata** (`PartialSimulationData`): The data message received from the simulation.

        ### Returns
        - **Union[bool, Optional[NDArray]]**:
            - `NDArray` if the data for this field, client rank, and time step has
              been received from every simulation in the group.
            - `False` otherwise."""

        group_id = self._get_group_id_by_simulation(pdata.simulation_id)
        current_group = self._groups[group_id]
        current_group.cache(pdata)
        np_data = current_group.get_cached(
            pdata.field, pdata.client_rank, pdata.time_step
        )

        return np_data if len(np_data) == self.group_size else False


    def _compute_stats(self, pdata: PartialSimulationData) -> None:
        """Computes statistics iteratively and Sobol sensitivity indices, if `sobol_op` is set.
        - Initializes `IterativeMoments` and `IterativeSensitivityMartinez` objects
        per new combination of field, client rank, and time step.
        - Handles Sobol calculations by caching received data for a specific group.

        ### Parameters
        - **pdata** (`PartialSimulationData`): The data message received from the simulation."""

        np_data: Union[bool, Optional[NDArray]]

        self.__seen_ranks.add(pdata.client_rank)
        current_key = (pdata.field, pdata.client_rank, pdata.time_step)
        if current_key not in self.__melissa_moments:
            self.__melissa_moments[current_key] = IterativeMoments(
                self.__max_order,
                dim=pdata.data_size
            )
            if self.sobol_op:
                self.__melissa_sobol[current_key] = IterativeSensitivityMartinez(
                    nb_parms=self.nb_parameters,
                    dim=pdata.data_size
                )

        if not self.sobol_op:
            np_data = pdata.payload.data.reshape(-1, pdata.data_size)
            self.__melissa_moments[current_key].increment(np_data[0])
        else:
            np_data = self.__get_cached_sobol_data(pdata)
            if isinstance(np_data, np.ndarray):
                self.__melissa_moments[current_key].increment(np_data[0])
                # increment the sobol data structure
                self.__melissa_sobol[current_key].increment(np_data)
                # increment the moments with the second solution
                self.__melissa_moments[current_key].increment(np_data[1])

        del np_data

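
    # Note on the layout assumed above (an inference from the pick-freeze
    # grouping, not stated in this file): when `sobol_op` is set, `np_data`
    # holds one row per simulation in the group, so rows 0 and 1 carry the two
    # independent samples that feed the moments, while the full
    # group_size x data_size array feeds the Martinez Sobol update.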

    def __gather_data(self,
                      local_vect_sizes: NDArray[np.int32],
                      d_buffer: NDArray[np.float64]
                      ) -> NDArray[np.float64]:
        """Gathers data from all ranks to rank 0 using MPI's Gatherv function.

        ### Parameters:
        - **local_vect_sizes** (`NDArray[np.int32]`): An array containing the size
        of the data vector per rank.
        - **d_buffer** (`NDArray[np.float64]`): An array containing the local data to be gathered.

        ### Returns:
        - `NDArray[np.float64]`: An array with the gathered data at rank 0."""

        offsets = [0] + list(np.cumsum(local_vect_sizes))[:-1]
        # TODO: mpi4py version 4+ has Gatherv_init which removes initialization overhead
        # for persistent gather calls.
        self.comm.Gatherv(
            d_buffer[:local_vect_sizes[self.rank]],
            [d_buffer, local_vect_sizes, offsets, MPI.DOUBLE],
            root=0
        )
        return d_buffer

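
    # Worked example of the offset computation above: for
    # local_vect_sizes = [3, 2, 4], np.cumsum gives [3, 5, 9], so
    # offsets = [0, 3, 5]; rank r writes its slice at offsets[r] in the
    # gathered buffer on rank 0.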

    def __gather_and_write_moments(self,
                                   field: str,
                                   global_vect_size: int,
                                   local_vect_sizes: NDArray[np.int32],
                                   stat_type: str,
                                   values_fn: Callable) -> None:
        """Gathers data from all ranks based on the specified statistical type,
        and writes the results.

        ### Parameters:
        - **field** (`str`): The field for which the data is to be gathered.
        - **global_vect_size** (`int`): The size of the global vector.
        - **local_vect_sizes** (`NDArray[np.int32]`): An array containing the size
        of the data vector per rank.
        - **stat_type** (`str`): A string specifying the type of statistics to gather
        (called per moment, for example, `mean`).
        - **values_fn** (`Callable`): A function that takes an `IterativeMoments`
        object and returns the requested statistic
        (called per moment, for example, `lambda m: m.get_mean()`)."""

        d_buffer = np.zeros(global_vect_size)
        self.comm.Barrier()

        assert self.time_steps_known, "melissa_finalize() must be called on the client-side."

        for t in range(self.nb_time_steps):
            file_name = f"./results/results.{field}_{stat_type}." \
                        f"{str(t + 1).zfill(len(str(self.nb_time_steps)))}"

            temp_offset = 0
            for rank in range(self.client_comm_size):
                key = (field, rank, t)
                values = values_fn(self.__melissa_moments[key])

                if np.size(values) > 0:
                    last_offset = temp_offset + np.size(values)
                    d_buffer[temp_offset:last_offset] = values
                    temp_offset = last_offset

            d_buffer = self.__gather_data(local_vect_sizes, d_buffer)
            if self.rank == 0:
                np.savetxt(file_name, d_buffer)
                logger.info(f"file name: {file_name}")

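
    # File-naming example implied by the f-string above: with
    # field="temperature" (a hypothetical field name), stat_type="mean", and
    # nb_time_steps=100, the time-step index is zero-padded to 3 digits,
    # producing ./results/results.temperature_mean.001 through
    # ./results/results.temperature_mean.100.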

    def __gather_and_write_sobol(self,
                                 field: str,
                                 global_vect_size: int,
                                 local_vect_sizes: NDArray[np.int32]) -> None:
        """Gathers Sobol data from all ranks, and writes the results.

        ### Parameters:
        - **field** (`str`): The field for which the data is to be gathered.
        - **global_vect_size** (`int`): The size of the global vector.
        - **local_vect_sizes** (`NDArray[np.int32]`): An array containing the size
        of the data vector per rank."""

        d_buffer_a = np.zeros(global_vect_size)
        d_buffer_b = np.zeros(global_vect_size)
        self.comm.Barrier()

        for param in range(self.nb_parameters):
            for t in range(self.nb_time_steps):
                file_name_b = f"./results/results.{field}_sobol{str(param)}." \
                              f"{str(t + 1).zfill(len(str(self.nb_time_steps)))}"

                file_name_a = f"./results/results.{field}_sobol_tot{str(param)}." \
                              f"{str(t + 1).zfill(len(str(self.nb_time_steps)))}"

                temp_offset = 0
                for rank in range(self.client_comm_size):
                    key = (field, rank, t)
                    pearson_b = self.__melissa_sobol[key].pearson_B[:, param]
                    pearson_a = self.__melissa_sobol[key].pearson_A[:, param]

                    if (
                        np.size(pearson_b) > 0
                        and np.size(pearson_a) == np.size(pearson_b)
                    ):
                        last_offset = temp_offset + np.size(pearson_b)
                        d_buffer_b[temp_offset:last_offset] = pearson_b
                        d_buffer_a[temp_offset:last_offset] = pearson_a
                        temp_offset = last_offset

                d_buffer_b = self.__gather_data(local_vect_sizes, d_buffer_b)
                d_buffer_a = self.__gather_data(local_vect_sizes, d_buffer_a)

                if self.rank == 0:
                    np.savetxt(file_name_b, d_buffer_b)
                    logger.info(f"file name: {file_name_b}")
                    np.savetxt(file_name_a, d_buffer_a)
                    logger.info(f"file name: {file_name_a}")


    def _melissa_write_stats(self):
        """Gathers and writes all results to the `results/` folder."""

        # turn server monitoring off
        if self.rank == 0:
            snd_msg = self._encode_msg(message.StopTimeoutMonitoring())
            self._launcherfd.send(snd_msg)

        # broadcast client_comm_size to all server ranks
        client_comm_size: int = self.client_comm_size
        if self.rank == 0:
            self.comm.bcast(client_comm_size, root=0)
        else:
            client_comm_size = self.comm.bcast(client_comm_size, root=0)
            self.client_comm_size = client_comm_size
        logger.info(f"Rank {self.rank}>> gathered client-comm-size={self.client_comm_size}")

        # update __melissa_moments and __melissa_sobol with missing client ranks.
        # these are just placeholders and do not contribute to the results
        unseen_ranks = set(range(self.client_comm_size)) - self.__seen_ranks
        for field in self.fields:
            for client_rank in unseen_ranks:
                for t in range(self.nb_time_steps):
                    key = (field, client_rank, t)
                    self.__melissa_moments[key] = IterativeMoments(
                        self.__max_order,
                        dim=0
                    )
                    if self.sobol_op:
                        self.__melissa_sobol[key] = IterativeSensitivityMartinez(
                            nb_parms=self.nb_parameters,
                            dim=0
                        )
        # avoiding code repetitions
        stat2values_fn: Dict[str, Tuple[bool, Callable]] = {
            "mean": (self.__mean, lambda m: m.get_mean()),
            "variance": (self.__variance, lambda m: m.get_variance()),
            "skewness": (self.__skewness, lambda m: m.get_skewness()),
            "kurtosis": (self.__kurtosis, lambda m: m.get_kurtosis())
        }

        # compute the vector size across all client ranks
        # and gather them to calculate the global size for every server rank.
        # this is done across all fields
        # and finally gather results for all moments and sobol
        field_metadata: Dict[str, FieldMetadata] = {}
        for field in self.fields:
            field_metadata[field] = FieldMetadata(self.comm_size)
            vect_size = np.zeros(1, dtype=MPI2NP_DT["int"])
            for client_rank in range(self.client_comm_size):
                key = (field, client_rank, 0)
                vect_size += np.size(self.__melissa_moments[key].get_mean())

            self.comm.Allgather(
                [vect_size, MPI.INT],
                [field_metadata[field].local_vect_sizes, MPI.INT]
            )

            field_metadata[field].compute_global_size()
            local_vect_sizes = field_metadata[field].local_vect_sizes
            global_vect_size = field_metadata[field].global_vect_size

            logger.info(
                f"Rank {self.rank}>> field=\"{field}\", "
                f"local-vect-size={vect_size[-1]}, "
                f"global-vect-size={global_vect_size}"
            )
            for stat_type, (condition, values_fn) in stat2values_fn.items():
                if condition:
                    self.__gather_and_write_moments(
                        field,
                        global_vect_size,
                        local_vect_sizes,
                        stat_type,
                        values_fn
                    )

            if self.sobol_op:
                self.__gather_and_write_sobol(field, global_vect_size, local_vect_sizes)


    @override
    def _checkpoint(self, **kwargs) -> None:
        """Checkpoints moments and Sobol information."""

        if self.no_fault_tolerance or not self.__checkpoint_interval:
            return

        self.__checkpoint_count += 1
        if self.__checkpoint_count % self.__checkpoint_interval != 0:
            return

        self._save_base_state()

        stats_metadata = {
            "seen_ranks": list(self.__seen_ranks),
            "nb_time_steps": self.nb_time_steps
        }

        with open(f"checkpoints/{self.rank}/melissa_moments.pkl", 'wb') as f:
            cloudpickle.dump(self.__melissa_moments, f)

        if self.sobol_op:
            with open(f"checkpoints/{self.rank}/melissa_sobol.pkl", 'wb') as f:
                cloudpickle.dump(self.__melissa_sobol, f)

        with open(f"checkpoints/{self.rank}/stats_metadata.json", 'w') as f:
            rapidjson.dump(stats_metadata, f)


    @override
    def _restart_from_checkpoint(self, **kwargs) -> None:
        """Loads from the last checkpoint, in case of a restart."""

        self._load_base_state()

        with open(f"checkpoints/{self.rank}/melissa_moments.pkl", 'rb') as f:
            self.__melissa_moments = cloudpickle.load(f)

        if self.sobol_op:
            with open(f"checkpoints/{self.rank}/melissa_sobol.pkl", 'rb') as f:
                self.__melissa_sobol = cloudpickle.load(f)

        with open(f"checkpoints/{self.rank}/stats_metadata.json", 'r') as f:
            stats_metadata = rapidjson.load(f)

        self.__seen_ranks = set(stats_metadata["seen_ranks"])
        self.nb_time_steps = stats_metadata["nb_time_steps"]