Coverage for melissa/server/sensitivity_analysis/sensitivity_analysis_server.py: 26%
383 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-09-22 10:36 +0000
« prev ^ index » next coverage.py v7.2.7, created at 2023-09-22 10:36 +0000
1import logging
2import os
3import time
4from typing import Any, Dict, List, Optional, Tuple, Union
5from pathlib import Path
7import numpy as np
8import numpy.typing as npt
9from mpi4py import MPI
11from melissa.launcher import message
12from melissa.server.base_server import BaseServer
13from melissa.server.simulation import (Group, PartialSimulationData, Simulation,
14 SimulationData, SimulationDataStatus)
15from iterative_stats.sensitivity.sensitivity_martinez import IterativeSensitivityMartinez
16from iterative_stats.iterative_moments import IterativeMoments
17import cloudpickle
18import rapidjson
20logger = logging.getLogger(__name__)
class SensitivityAnalysisServer(BaseServer):
    """
    Server to be used for Sensitivity Analysis studies.

    Incoming simulation data are folded into iterative statistics
    (``IterativeMoments`` and, when ``sobol_indices`` is enabled,
    ``IterativeSensitivityMartinez``) kept per field, client rank and time
    step. The requested statistics are written under ``./results/`` once all
    simulations are done.
    """

    def __init__(self, config: Dict[str, Any]):
        """
        Initialize the server from ``config`` and its ``"sa_config"`` section.

        Raises:
            Exception: if ``num_samples`` is 0 once the base class is set up.
            KeyError: if ``config`` has no ``"sa_config"`` entry or the latter
                has no ``"checkpoint_interval"`` entry.
        """
        super().__init__(config)
        if self.num_samples == 0:
            raise Exception("Error, in case of an SA study num_samples must be set by the user")

        self.learning: int = 0
        self.sa_config: Dict[str, Any] = config["sa_config"]

        self.sobol_op = 1 if self.sa_config.get("sobol_indices", False) else 0
        self.check_group_size()

        self.mean = self.sa_config.get("mean", False)
        self.variance = self.sa_config.get("variance", False)
        self.skewness = self.sa_config.get("skewness", False)
        self.kurtosis = self.sa_config.get("kurtosis", False)
        self.min = self.sa_config.get("min", False)
        self.max = self.sa_config.get("max", False)
        self.threshold_exceedance = self.sa_config.get("threshold_exceedance", False)
        self.threshold_values = self.sa_config.get("threshold_values", [0.7, 0.8])
        self.quantiles = self.sa_config.get("quantiles", False)
        self.quantile_values = self.sa_config.get(
            "quantile_values", [0.05, 0.25, 0.5, 0.75, 0.95]
        )
        Path('./results/').mkdir(parents=True, exist_ok=True)

        # Instantiate the melissa statistical data structures
        self.max_order: int = 0
        self.melissa_moments: Dict[str, Dict] = {}  # {field, {clt_rank, {t, StatisticalMoments}}}
        if self.sobol_op:
            self.pick_freeze_matrix: List[List[Union[int, float]]] = []
            self.melissa_sobol: Dict[str, Dict] = {}  # {field, {clt_rank, {t, IterSobolMartinez}}}

        # The highest requested moment fixes the order tracked by IterativeMoments
        if self.kurtosis:
            self.max_order = 4
        elif self.skewness:
            self.max_order = 3
        elif self.variance:
            self.max_order = 2
        elif self.mean:
            self.max_order = 1
        else:
            self.max_order = 0

        # Use the module logger (not the root ``logging`` functions) so these
        # warnings follow the same configuration as every other message here.
        if self.min or self.max:
            logger.warning("min max not implemented")
        if self.threshold_exceedance:
            logger.warning("threshold not implemented")
        if self.quantiles:
            logger.warning("quantiles not implemented")

        self.first_stat_computation: bool = True
        self.seen_ranks: List[int] = []  # list of seen client ranks
        self.checkpoint_count: int = 0
        self.checkpoint_interval: int = self.sa_config["checkpoint_interval"]

    def check_group_size(self):
        """
        Derive or validate the group layout.

        With Sobol indices enabled, ``group_size``, ``number_of_groups`` and
        ``num_clients`` are overwritten from ``nb_parameters`` and
        ``parameter_sweep_size``. Otherwise a ``group_size`` greater than 1
        that does not divide ``num_clients`` is logged as an error and
        ``catch_error`` is set.
        """
        if self.sobol_op:
            self.group_size = self.nb_parameters + 2
            self.number_of_groups = self.parameter_sweep_size
            self.num_clients = self.group_size * self.parameter_sweep_size
        elif self.group_size > 1 and self.num_clients % self.group_size != 0:
            logger.error("Incorrect group_size, please remove or adjust this option")
            self.catch_error = True

    def generate_client_scripts(self, first_id, number_of_scripts, default_parameters=None):
        """
        Creates all required client.X.sh scripts and set up dict
        for fault tolerance

        Args:
            first_id: id of the first simulation to generate a script for.
            number_of_scripts: how many consecutive simulation ids to cover.
            default_parameters: when not None, used as the parameters of every
                generated script instead of drawing new ones.

        Only rank 0 writes the script files; every rank logs and updates
        ``self.groups`` / ``self.n_submitted_simulations``.
        """
        client_executable = os.path.join(os.getcwd(), "client.sh")
        if self.rank == 0 and number_of_scripts > 0:
            Path('./client_scripts').mkdir(parents=True, exist_ok=True)

        for sim_id in range(first_id, first_id + number_of_scripts):
            # if the number of scripts to be generated becomes too significant
            # the server may spend too much time in this loop hence causing
            # the launcher to believe that the server timed out if no PING is
            # received
            if number_of_scripts > 10000 and (sim_id - first_id) % 10000 == 0:
                self.time_monitor.check_clock(time.monotonic(), self)

            if default_parameters is not None:
                parameters = default_parameters
            elif not self.sobol_op:
                parameters = list(next(self.parameter_generator))
            else:
                parameters = self.draw_from_pick_freeze()

            if self.rank == 0:
                client_script_i = os.path.abspath(f"./client_scripts/client.{sim_id}.sh")
                # str() conversion causes problems with scientific notations
                # and should not be used
                formatted_parameters = [
                    x if isinstance(x, str) else np.format_float_positional(x)
                    for x in parameters
                ]
                command = " " + " ".join([client_executable] + formatted_parameters)
                with open(client_script_i, "w") as f:
                    print("#!/bin/sh", file=f)
                    print("exec env \\", file=f)
                    print(f" MELISSA_SIMU_ID={sim_id} \\", file=f)
                    print(f" MELISSA_SERVER_NODE_NAME={self.node_name} \\", file=f)
                    if self.rm_script:
                        # the generated script removes itself after the client ends
                        print(command + " &", file=f)
                        print(" wait", file=f)
                        print(' rm "$0"', file=f)
                    else:
                        print(command, file=f)

                os.chmod(client_script_i, 0o744)

            logger.info(
                f"Rank {self.rank}: created client.{sim_id}.sh with parameters {parameters}"
            )

            # Fault-tolerance dictionary creation and update
            group_id = sim_id // self.group_size
            if group_id not in self.groups:
                self.groups[group_id] = Group(group_id)
            if sim_id not in self.groups[group_id].simulations:
                self.n_submitted_simulations += 1
                self.groups[group_id].simulations[sim_id] = Simulation(
                    sim_id, self.num_samples, self.fields, parameters
                )

    def draw_from_pick_freeze(self) -> List:
        """
        Returns a row from the pick-freeze matrix

        The matrix is rebuilt from the parameter generator whenever it has
        been fully consumed.
        """
        if not self.pick_freeze_matrix:
            logger.debug("Build pick-freeze matrix")
            self.build_pick_freeze_matrix()
        return self.pick_freeze_matrix.pop(0)

    def build_pick_freeze_matrix(self):
        """
        Builds the pick-freeze matrix for one group
        """
        self.pick_freeze_matrix = list(next(self.parameter_generator))

    def start(self):
        """
        The main execution method

        Launches the first groups (or restores the state from a checkpoint
        when ``self.restart`` is set), then runs the online, offline and
        finalization phases in that order.
        """
        if not self.restart:
            self.launch_first_groups()

        if self.restart:
            # the reinitialization from checkpoint occurs here
            logger.info(f"Continuing from checkpoint {self.restart}")
            self.restart_from_checkpoint()
            if self.rank == 0:
                self.kill_and_restart_simulations()

        self.setup_environment()

        self.server_online()

        self.server_offline()

        self.server_finalize()

    def server_online(self):
        """
        Method where user controls the data handling while
        server is online.
        """
        self.receive()

    def receive(self):
        """
        Handle data from the server.

        Loops on ``self.run()`` until ``self.all_done()``; each
        ``PartialSimulationData`` returned is folded into the statistics and
        followed by a checkpoint attempt.
        """
        self._is_receiving = True
        while not self.all_done():
            status = self.run()
            if isinstance(status, PartialSimulationData):
                logger.debug(
                    f"receive message: sim_id {status.simulation_id}, "
                    f"timestep {status.time_step}",
                )

                # compute the statistics on the received data
                self.compute_stats(status)
                self.checkpoint_state()

        self._is_receiving = False

    def handle_simulation_data(self, msg):
        """
        Parses and validates the incoming data messages from simulations

        Returns:
            ``None`` when the message is rejected (bad time step or field);
            otherwise the second value returned by ``check_simulation_data``
            for the last simulation carried by the message.
        """
        # 1. Deserialize message
        msg_data: PartialSimulationData = PartialSimulationData.from_msg(msg, self.learning)
        logger.debug(
            f"Rank {self.rank} received {msg_data} from rank {msg_data.client_rank} "
            f"(vect_size: {len(msg_data.data)})"
        )
        # 2. Apply filters
        if not (0 <= msg_data.time_step < self.num_samples):
            logger.warning(
                f"Rank {self.rank}: bad timestep {msg_data.time_step}"
            )
            return None
        if msg_data.field not in self.fields:
            logger.warning(f"Rank {self.rank}: bad field {msg_data.field}")
            return None

        # when sobol_op=1 the results of each simulation in the group are gathered on
        # the ranks of its first simulation and are sent at once by each rank
        # which means that len(msg_data.data) = group_size * data_size
        # in addition msg_data.simulation_id is actually the group_id
        if msg_data.data_size > 0:
            nb_solutions = len(msg_data.data) // msg_data.data_size
        else:
            # An empty payload cannot be divided by its size; fall back to the
            # number of solution vectors the message is expected to carry so
            # that the EMPTY status of check_simulation_data can be reached.
            nb_solutions = self.group_size if self.sobol_op else 1

        simulation_data = None
        for sim in range(nb_solutions):
            if not self.sobol_op:
                group_id = msg_data.simulation_id // self.group_size
                sim_id = msg_data.simulation_id
            else:
                group_id = msg_data.simulation_id
                sim_id = group_id * self.group_size + sim
            simulation = self.groups[group_id].simulations[sim_id]

            simulation_status, simulation_data = self.check_simulation_data(
                simulation, msg_data
            )
            if simulation_status == SimulationDataStatus.COMPLETE:
                logger.debug(
                    f"Rank {self.rank}: assembled time-step {simulation_data.time_step} "
                    f"- simulationID {simulation_data.simulation_id}"
                )
            elif simulation_status == SimulationDataStatus.ALREADY_RECEIVED:
                logger.warning(f"Rank {self.rank}: duplicate simulation data {msg_data}")

            # Check if simulation has finished
            if (
                simulation_status in (SimulationDataStatus.COMPLETE, SimulationDataStatus.EMPTY)
                and simulation.finished()
            ):
                logger.info(f"Rank {self.rank}: simulation {simulation.id} finished")
                self.n_finished_simulations += 1

        return simulation_data

    def check_simulation_data(
        self, simulation: Simulation, simulation_data: PartialSimulationData
    ) -> Tuple[
        SimulationDataStatus,
        Union[Optional[SimulationData], Optional[PartialSimulationData]],
    ]:
        """
        Look for duplicated messages,
        update received_simulation_data and the simulation_data status.

        Returns:
            ``(ALREADY_RECEIVED, None)`` for a duplicate,
            ``(PARTIAL, simulation_data)`` while fields are still missing for
            the time step, ``(EMPTY, None)`` when the time step is complete
            but ``data_size`` is 0, and ``(COMPLETE, simulation_data)``
            otherwise.
        """
        client_rank = simulation_data.client_rank
        time_step = simulation_data.time_step
        field = simulation_data.field

        if client_rank not in simulation.received_simulation_data:
            simulation.received_simulation_data[client_rank] = {}
            simulation.received_time_steps[client_rank] = (
                np.zeros((len(self.fields), self.num_samples), dtype=bool)
            )
        # Data have already been received
        if simulation.has_already_received(client_rank, time_step, field):
            return SimulationDataStatus.ALREADY_RECEIVED, None

        received = simulation.received_simulation_data[client_rank]
        # Time step has never been seen
        if time_step not in received:
            received[time_step] = {name: None for name in simulation.fields}
        # Update the entry
        # for SA it is more memory efficient not to keep track of the whole simulation_data
        received[time_step][field] = 1
        simulation._mark_as_received(client_rank, time_step, field)

        if not simulation.is_complete(time_step):
            # Not all fields have been received yet
            return SimulationDataStatus.PARTIAL, simulation_data

        # All fields have been received for the time step
        simulation.n_received_time_steps += 1
        del received[time_step]
        # Check there is actual data
        if simulation_data.data_size == 0:
            # Data have been set to another device, fields are empty
            return SimulationDataStatus.EMPTY, None
        return SimulationDataStatus.COMPLETE, simulation_data

    def setup_environment(self):
        """Delegate the environment setup to the base class."""
        return super().setup_environment()

    def server_offline(self):
        """
        Post processing goes here. Not required.
        """
        self.melissa_write_stats()

    def server_finalize(self):
        """
        All finalization methods go here.
        """
        logger.info("stop server")
        self.write_final_report()
        self.close_connection()

    def process_simulation_data(cls, msg: SimulationData, config: dict):
        """
        method used to custom process sa-data

        NOTE(review): the first parameter is named ``cls`` but no
        ``@classmethod`` decorator is visible here - confirm against the
        base-class declaration.
        """
        return

    def compute_stats(self, pdata: PartialSimulationData) -> None:
        """
        Link into stats lib for computing online statistics.

        Lazily creates the per-field and per-client-rank containers, then
        increments the moments (and the Sobol structure when enabled) of the
        ``(field, client_rank, time_step)`` entry targeted by ``pdata``.
        """
        if self.first_stat_computation:
            self.first_stat_computation = False
            for field in self.fields:
                self.melissa_moments[field] = {}
                if self.sobol_op:
                    self.melissa_sobol[field] = {}

        # Progressive initialization of the client_rank entry
        # so that we do not iterate over client_comm_size which
        # has not been broadcasted yet
        client_rank = pdata.client_rank
        if client_rank not in self.seen_ranks:
            self.seen_ranks.append(client_rank)
            for field in self.fields:
                self.melissa_moments[field][client_rank] = {}
                if self.sobol_op:
                    self.melissa_sobol[field][client_rank] = {}
                for t in range(self.num_samples):
                    self.melissa_moments[field][client_rank][t] = (
                        IterativeMoments(self.max_order, dim=pdata.data_size)
                    )
                    if self.sobol_op:
                        self.melissa_sobol[field][client_rank][t] = (
                            IterativeSensitivityMartinez(nb_parms=self.nb_parameters,
                                                         dim=pdata.data_size)
                        )

        # when sobol_op=1, results are grouped thus pdata.data contains the solution
        # vectors of each simulation in the group and must be reshaped
        # since only the first two solutions are used to compute the moments
        np_data = pdata.data.reshape(-1, pdata.data_size)
        moments = self.melissa_moments[pdata.field][client_rank][pdata.time_step]
        moments.increment(np_data[0])

        if self.sobol_op:
            # increment the sobol data structure
            self.melissa_sobol[pdata.field][client_rank][pdata.time_step]._increment(np_data)
            # increment the moments with the second solution
            moments.increment(np_data[1])

    def _write_statistic(self, container, field, label, extract, local_vect_sizes, d_buffer):
        """
        Write one result file per time step for a single statistic of ``field``.

        Args:
            container: ``{field: {client_rank: {t: stats_object}}}`` mapping
                (``melissa_moments`` or ``melissa_sobol``).
            field: name of the field being written.
            label: statistic label used in the result file name.
            extract: callable taking a stats object and returning the values
                to write, or ``None`` when that client rank must be skipped.
            local_vect_sizes: per-server-rank vector sizes.
            d_buffer: work buffer, filled in place.

        Returns:
            The buffer returned by the last ``gather_data`` call.
        """
        width = len(str(self.num_samples))
        per_rank = container.get(field, {})
        for t in range(self.num_samples):
            file_name = "./results/results.{}_{}.{}".format(
                field,
                label,
                str(t + 1).zfill(width)
            )
            if self.rank == 0:
                logger.info(f"file name: {file_name}")
            offset = 0
            for rank in range(self.client_comm_size):
                per_time = per_rank.get(rank)
                if per_time is None:
                    # this server rank never received data from that client rank
                    continue
                values = extract(per_time[t])
                # np.size(None) is 1, hence the explicit None test
                if values is None or np.size(values) == 0:
                    continue
                d_buffer[offset:offset + np.size(values)] = values
                offset += np.size(values)
            d_buffer = self.gather_data(local_vect_sizes, d_buffer)
            if self.rank == 0:
                np.savetxt(file_name, d_buffer)
        return d_buffer

    def melissa_write_stats(self):
        """
        Write the computed statistics on file.

        For every enabled statistic, each server rank packs its local values,
        the values are gathered on rank 0 and rank 0 writes one text file per
        field and time step under ``./results/``.
        """
        # Turn server monitoring off
        if self.rank == 0:
            snd_msg = self.encode_msg(message.StopTimeoutMonitoring())
            self.launcherfd.send(snd_msg)

        # Broadcast client_comm_size to all server ranks
        self.client_comm_size = self.comm.bcast(self.client_comm_size, root=0)
        logger.info(f"Rank: {self.rank}, gathered client comm size: {self.client_comm_size}")

        # Update melissa_moments with missing client ranks
        for field in self.fields:
            # setdefault: a server rank that never computed stats has no entry yet
            field_moments = self.melissa_moments.setdefault(field, {})
            for client_rank in range(self.client_comm_size):
                if client_rank not in field_moments:
                    field_moments[client_rank] = {
                        t: IterativeMoments(self.max_order, dim=0)
                        for t in range(self.num_samples)
                    }

        local_vect_sizes = np.zeros(self.comm_size, dtype=int)
        vect_size = np.zeros(1, dtype=int)

        # Compute the global vect size
        for rank_moments in self.melissa_moments[self.fields[0]].values():
            vect_size += np.size(rank_moments[0].m1)

        self.comm.Allgather([vect_size, MPI.INT], [local_vect_sizes, MPI.INT])
        global_vect_size = np.sum(local_vect_sizes)
        logger.info(f"global_vect: {global_vect_size}")

        d_buffer = np.zeros(global_vect_size)

        def when_mean_available(getter_name):
            """Build an extractor that skips stats objects whose mean is empty."""
            def extract(moments):
                if np.size(moments.get_mean()) == 0:
                    return None
                return getattr(moments, getter_name)()
            return extract

        if self.mean:
            self.comm.Barrier()
            for field in self.fields:
                d_buffer = self._write_statistic(
                    self.melissa_moments, field, "mean",
                    lambda moments: moments.m1,
                    local_vect_sizes, d_buffer,
                )

        higher_moments = (
            (self.variance, "variance", "get_variance"),
            (self.skewness, "skewness", "get_skewness"),
            (self.kurtosis, "kurtosis", "get_kurtosis"),
        )
        for enabled, label, getter_name in higher_moments:
            if not enabled:
                continue
            self.comm.Barrier()
            for field in self.fields:
                d_buffer = self._write_statistic(
                    self.melissa_moments, field, label,
                    when_mean_available(getter_name),
                    local_vect_sizes, d_buffer,
                )

        if self.sobol_op:
            self.comm.Barrier()
            for field in self.fields:
                for param in range(self.nb_parameters):
                    d_buffer = self._write_statistic(
                        self.melissa_sobol, field, f"sobol{param}",
                        lambda sobol, p=param: sobol.pearson_B[p],
                        local_vect_sizes, d_buffer,
                    )

            for field in self.fields:
                for param in range(self.nb_parameters):
                    d_buffer = self._write_statistic(
                        self.melissa_sobol, field, f"sobol_tot{param}",
                        lambda sobol, p=param: sobol.pearson_A[p],
                        local_vect_sizes, d_buffer,
                    )

    def gather_data(
        self,
        local_vect_sizes: npt.NDArray[np.int_],
        d_buffer: npt.NDArray[np.float64]
    ) -> npt.NDArray[np.float64]:
        """
        Gather data on rank 0.

        Every non-root rank with a non-zero local size sends the first
        ``local_vect_sizes[rank]`` entries of its buffer to rank 0, which
        stores them in its own buffer, rank after rank.

        Returns:
            ``d_buffer`` (updated in place on rank 0).
        """
        if self.rank == 0:
            offset = 0
            for rank in range(1, self.comm_size):
                # skip the slice owned by the previous rank
                offset += local_vect_sizes[rank - 1]
                if local_vect_sizes[rank] > 0:
                    d_buffer[
                        offset:offset + local_vect_sizes[rank]
                    ] = self.comm.recv(source=rank)
        elif local_vect_sizes[self.rank] > 0:
            self.comm.send(d_buffer[:local_vect_sizes[self.rank]], dest=0)

        return d_buffer

    def checkpoint_state(self):
        """
        Checkpoint the server every ``checkpoint_interval`` calls.

        Does nothing when ``checkpoint_interval`` is falsy. Otherwise saves
        the base state, then the moments, the Sobol structures (when enabled)
        and the statistics metadata under ``checkpoints/``.

        NOTE(review): the file names do not include the server rank - confirm
        that server ranks cannot overwrite each other's files.
        """
        if not self.checkpoint_interval:
            return

        self.checkpoint_count += 1
        if self.checkpoint_count % self.checkpoint_interval != 0:
            return

        logger.info("Checkpointing state")
        self.save_base_state()

        stats_metadata = {"seen_ranks": self.seen_ranks, "num_samples": self.num_samples}

        with open("checkpoints/melissa_moments.pkl", 'wb') as f:
            cloudpickle.dump(self.melissa_moments, f)

        if self.sobol_op:
            with open("checkpoints/melissa_sobol.pkl", 'wb') as f:
                cloudpickle.dump(self.melissa_sobol, f)

        with open("checkpoints/stats_metadata.json", 'w') as f:
            rapidjson.dump(stats_metadata, f)

    def restart_from_checkpoint(self):
        """
        Invert checkpoint_state

        Reloads the base state, the moments, the Sobol structures (when
        enabled) and the statistics metadata from ``checkpoints/``.
        """
        self.load_base_state()

        # SECURITY: unpickling executes arbitrary code; the checkpoint files
        # must only ever come from a trusted previous run of this server.
        with open("checkpoints/melissa_moments.pkl", 'rb') as f:
            self.melissa_moments = cloudpickle.load(f)

        if self.sobol_op:
            with open("checkpoints/melissa_sobol.pkl", 'rb') as f:
                self.melissa_sobol = cloudpickle.load(f)

        with open("checkpoints/stats_metadata.json", 'r') as f:
            stats_metadata = rapidjson.load(f)

        self.seen_ranks = stats_metadata["seen_ranks"]
        self.num_samples = stats_metadata["num_samples"]
        # the per-field containers were restored, do not reset them
        self.first_stat_computation = False