Coverage for melissa/server/sensitivity_analysis/ 26%

383 statements  

« prev     ^ index     » next     coverage.py v7.2.7, created at 2023-09-22 10:36 +0000

1import logging 

2import os 

3import time 

4from typing import Any, Dict, List, Optional, Tuple, Union 

5from pathlib import Path 


7import numpy as np 

8import numpy.typing as npt 

9from mpi4py import MPI 


11from melissa.launcher import message 

12from melissa.server.base_server import BaseServer 

13from melissa.server.simulation import (Group, PartialSimulationData, Simulation, 

14 SimulationData, SimulationDataStatus) 

15from iterative_stats.sensitivity.sensitivity_martinez import IterativeSensitivityMartinez 

16from iterative_stats.iterative_moments import IterativeMoments 

17import cloudpickle 

18import rapidjson 


# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)



class SensitivityAnalysisServer(BaseServer):
    """
    Server to be used for Sensitivity Analysis (SA) studies.

    For every field, client rank and time step it accumulates iterative
    statistical moments (``IterativeMoments``) and, when
    ``sa_config["sobol_indices"]`` is enabled, Martinez Sobol estimators
    (``IterativeSensitivityMartinez``). Once every simulation is done, the
    statistics are gathered on server rank 0 and written under ``./results/``.
    """

    def __init__(self, config: Dict[str, Any]):
        """
        Parameters
        ----------
        config:
            Study configuration. On top of what ``BaseServer`` consumes, this
            reads ``config["sa_config"]``: boolean statistic switches, optional
            threshold/quantile values and the mandatory ``checkpoint_interval``.

        Raises
        ------
        ValueError
            If ``num_samples`` (the number of time steps) is not set.
        """
        super().__init__(config)
        if not self.num_samples:
            raise ValueError(
                "Error, in case of an SA study num_samples must be set by the user"
            )

        self.learning: int = 0
        self.sa_config: Dict[str, Any] = config["sa_config"]

        # 1 when Sobol indices are requested, 0 otherwise
        self.sobol_op = 1 if self.sa_config.get("sobol_indices", False) else 0
        self.check_group_size()

        self.mean = self.sa_config.get("mean", False)
        self.variance = self.sa_config.get("variance", False)
        self.skewness = self.sa_config.get("skewness", False)
        self.kurtosis = self.sa_config.get("kurtosis", False)
        self.min = self.sa_config.get("min", False)
        self.max = self.sa_config.get("max", False)
        self.threshold_exceedance = self.sa_config.get("threshold_exceedance", False)
        self.threshold_values = self.sa_config.get("threshold_values", [0.7, 0.8])
        self.quantiles = self.sa_config.get("quantiles", False)
        self.quantile_values = self.sa_config.get(
            "quantile_values", [0.05, 0.25, 0.5, 0.75, 0.95]
        )
        Path("./results/").mkdir(parents=True, exist_ok=True)

        # Melissa statistical data structures
        # {field: {client_rank: {t: IterativeMoments}}}
        self.melissa_moments: Dict[str, Dict] = {}
        if self.sobol_op:
            self.pick_freeze_matrix: List[List[Union[int, float]]] = []
            # {field: {client_rank: {t: IterativeSensitivityMartinez}}}
            self.melissa_sobol: Dict[str, Dict] = {}

        # Highest moment order required by the requested statistics
        self.max_order: int = 0
        if self.kurtosis:
            self.max_order = 4
        elif self.skewness:
            self.max_order = 3
        elif self.variance:
            self.max_order = 2
        elif self.mean:
            self.max_order = 1

        if self.min or self.max:
            logger.warning("min max not implemented")
        if self.threshold_exceedance:
            logger.warning("threshold not implemented")
        if self.quantiles:
            logger.warning("quantiles not implemented")

        self.first_stat_computation: bool = True
        self.seen_ranks: List[int] = []  # client ranks this server rank got data from
        self.checkpoint_count: int = 0
        self.checkpoint_interval: int = self.sa_config["checkpoint_interval"]

    def check_group_size(self):
        """
        Adapt the group layout to the study.

        With Sobol indices, each group holds ``nb_parameters + 2`` simulations
        (one pick-freeze design) and there is one group per parameter-sweep
        entry. Otherwise, a user ``group_size`` that does not divide
        ``num_clients`` is logged as an error and ``catch_error`` is raised
        (presumably acted upon by ``BaseServer`` — confirm).
        """
        if self.sobol_op:
            self.group_size = self.nb_parameters + 2
            self.number_of_groups = self.parameter_sweep_size
            self.num_clients = self.group_size * self.parameter_sweep_size
        elif self.group_size > 1 and self.num_clients % self.group_size != 0:
            logger.error("Incorrect group_size, please remove or adjust this option")
            self.catch_error = True

    def generate_client_scripts(self, first_id, number_of_scripts, default_parameters=None):
        """
        Create the client scripts for simulation ids
        ``[first_id, first_id + number_of_scripts)`` and register the
        simulations in ``self.groups`` for fault tolerance.

        Scripts are written by server rank 0 only, but every rank draws the
        parameters and updates the bookkeeping. ``default_parameters``, when
        given, is used for every simulation instead of drawing new ones.
        """
        if self.rank == 0 and number_of_scripts > 0:
            Path("./client_scripts").mkdir(parents=True, exist_ok=True)

        for sim_id in range(first_id, first_id + number_of_scripts):
            # if the number of scripts to be generated becomes too significant
            # the server may spend too much time in this loop hence causing
            # the launcher to believe that the server timed out if no PING is
            # received
            if number_of_scripts > 10000 and (sim_id - first_id) % 10000 == 0:
                self.time_monitor.check_clock(time.monotonic(), self)

            if default_parameters is not None:
                parameters = default_parameters
            elif self.sobol_op:
                parameters = self.draw_from_pick_freeze()
            else:
                parameters = list(next(self.parameter_generator))

            if self.rank == 0:
                self._write_client_script(sim_id, parameters)

            # Fault-tolerance dictionary creation and update
            group_id = sim_id // self.group_size
            if group_id not in self.groups:
                self.groups[group_id] = Group(group_id)
            if sim_id not in self.groups[group_id].simulations:
                self.n_submitted_simulations += 1
                self.groups[group_id].simulations[sim_id] = Simulation(
                    sim_id, self.num_samples, self.fields, parameters
                )

    def _write_client_script(self, sim_id, parameters):
        """Write and make executable ``./client_scripts/client.<sim_id>.sh``."""
        script_path = os.path.abspath(f"./client_scripts/client.{sim_id}.sh")
        # str() conversion causes problems with scientific notations
        # and should not be used
        arguments = [
            x if isinstance(x, str) else np.format_float_positional(x)
            for x in parameters
        ]
        # NOTE(review): the executable name was lost in the extracted source;
        # reconstructed as "client.sh" in the working directory — confirm.
        command = " " + " ".join([os.path.join(os.getcwd(), "client.sh")] + arguments)

        with open(script_path, "w") as f:
            print("#!/bin/sh", file=f)
            print("exec env \\", file=f)
            print(f" MELISSA_SIMU_ID={sim_id} \\", file=f)
            print(f" MELISSA_SERVER_NODE_NAME={self.node_name} \\", file=f)
            if self.rm_script:
                # run in background, wait, then delete the script itself
                print(command + " &", file=f)
                print(" wait", file=f)
                print(' rm "$0"', file=f)
            else:
                print(command, file=f)

        os.chmod(script_path, 0o744)
        logger.info(
            f"Rank {self.rank}: created client.{sim_id}.sh with parameters {parameters}"
        )

    def draw_from_pick_freeze(self) -> List:
        """
        Return the next row of the pick-freeze matrix, building a new matrix
        (one group's worth of rows) when the current one is exhausted.
        """
        if not self.pick_freeze_matrix:
            logger.debug("Build pick-freeze matrix")
            self.build_pick_freeze_matrix()
        return self.pick_freeze_matrix.pop(0)

    def build_pick_freeze_matrix(self):
        """
        Builds the pick-freeze matrix for one group from the next item of the
        parameter generator.
        """
        self.pick_freeze_matrix = list(next(self.parameter_generator))

    def start(self):
        """
        The main execution method: launch (or restore) the study, receive and
        process data, write the statistics and finalize.
        """
        if self.restart:
            # the reinitialization from checkpoint occurs here
            logger.info(f"Continuing from checkpoint {self.restart}")
            self.restart_from_checkpoint()
            if self.rank == 0:
                self.kill_and_restart_simulations()
        else:
            self.launch_first_groups()

        self.setup_environment()

        self.server_online()

        self.server_offline()

        self.server_finalize()

    def server_online(self):
        """
        Method where user controls the data handling while
        server is online.
        """
        self.receive()
        return

    def receive(self):
        """
        Process incoming simulation data until every simulation is done.

        Each received ``PartialSimulationData`` updates the statistics and may
        trigger a checkpoint.
        """
        self._is_receiving = True
        try:
            while not self.all_done():
                # NOTE(review): this call was lost in the extracted source and is
                # reconstructed as BaseServer.run() — confirm.
                status = self.run()
                if isinstance(status, PartialSimulationData):
                    logger.debug(
                        f"receive message: sim_id {status.simulation_id}, "
                        f"timestep {status.time_step}",
                    )
                    # compute the statistics on the received data
                    self.compute_stats(status)
                    self.checkpoint_state()
        finally:
            self._is_receiving = False

    def handle_simulation_data(self, msg):
        """
        Parse and validate an incoming data message from the simulations.

        Returns the (partial) simulation data to process, or ``None`` when the
        message is invalid, a duplicate, or carries no data.
        """
        # 1. Deserialize message
        msg_data: PartialSimulationData = PartialSimulationData.from_msg(msg, self.learning)
        logger.debug(
            f"Rank {self.rank} received {msg_data} from rank {msg_data.client_rank} "
            f"(vect_size: {len(msg_data.data)})"
        )
        # 2. Apply filters
        if not (0 <= msg_data.time_step < self.num_samples):
            logger.warning(f"Rank {self.rank}: bad timestep {msg_data.time_step}")
            return None
        if msg_data.field not in self.fields:
            logger.warning(f"Rank {self.rank}: bad field {msg_data.field}")
            return None

        # when sobol_op=1 the results of each simulation in the group are gathered on
        # the ranks of its first simulation and are sent at once by each rank
        # which means that len(msg_data.data) = group_size * data_size
        # in addition msg_data.simulation_id is actually the group_id
        if msg_data.data_size > 0:
            n_simulations = len(msg_data.data) // msg_data.data_size
        else:
            # empty payload: cannot divide by data_size, so count the
            # simulations the message stands for
            n_simulations = self.group_size if self.sobol_op else 1

        simulation_data = None
        for sim in range(n_simulations):
            if self.sobol_op:
                group_id = msg_data.simulation_id
                sim_id = group_id * self.group_size + sim
            else:
                group_id = msg_data.simulation_id // self.group_size
                sim_id = msg_data.simulation_id
            simulation = self.groups[group_id].simulations[sim_id]

            simulation_status, simulation_data = self.check_simulation_data(
                simulation, msg_data
            )
            if simulation_status == SimulationDataStatus.COMPLETE:
                logger.debug(
                    f"Rank {self.rank}: assembled time-step {simulation_data.time_step} "
                    f"- simulationID {simulation_data.simulation_id}"
                )
            elif simulation_status == SimulationDataStatus.ALREADY_RECEIVED:
                logger.warning(f"Rank {self.rank}: duplicate simulation data {msg_data}")

            # Check if simulation has finished
            if simulation_status in (
                SimulationDataStatus.COMPLETE,
                SimulationDataStatus.EMPTY,
            ) and simulation.finished():
                logger.info(f"Rank {self.rank}: simulation {simulation.id} finished")
                self.n_finished_simulations += 1

        return simulation_data

    def check_simulation_data(
        self, simulation: Simulation, simulation_data: PartialSimulationData
    ) -> Tuple[
        SimulationDataStatus,
        Union[Optional[SimulationData], Optional[PartialSimulationData]],
    ]:
        """
        Look for duplicated messages and update the reception bookkeeping of
        ``simulation``.

        Returns
        -------
        (ALREADY_RECEIVED, None) for a duplicate, (PARTIAL, data) while fields
        of the time step are still missing, (EMPTY, None) when the completed
        time step carries no data, and (COMPLETE, data) otherwise.
        """
        client_rank = simulation_data.client_rank
        time_step = simulation_data.time_step

        if client_rank not in simulation.received_simulation_data:
            simulation.received_simulation_data[client_rank] = {}
            simulation.received_time_steps[client_rank] = np.zeros(
                (len(self.fields), self.num_samples), dtype=bool
            )
        # Data have already been received
        if simulation.has_already_received(client_rank, time_step, simulation_data.field):
            return SimulationDataStatus.ALREADY_RECEIVED, None

        per_step = simulation.received_simulation_data[client_rank]
        # Time step has never been seen
        if time_step not in per_step:
            per_step[time_step] = {field: None for field in simulation.fields}
        # for SA it is more memory efficient to store a marker than the data itself
        per_step[time_step][simulation_data.field] = 1
        simulation._mark_as_received(client_rank, time_step, simulation_data.field)

        if not simulation.is_complete(time_step):
            # Not all fields have been received yet
            return SimulationDataStatus.PARTIAL, simulation_data

        # All fields have been received for the time step
        simulation.n_received_time_steps += 1
        del per_step[time_step]
        if simulation_data.data_size == 0:
            # Data have been sent to another device, fields are empty
            return SimulationDataStatus.EMPTY, None
        return SimulationDataStatus.COMPLETE, simulation_data

    def setup_environment(self):
        """Delegate to ``BaseServer.setup_environment``."""
        return super().setup_environment()

    def server_offline(self):
        """
        Post processing goes here: write the computed statistics.
        """
        self.melissa_write_stats()
        return

    def server_finalize(self):
        """
        All finalization methods go here.
        """
        logger.info("stop server")
        self.write_final_report()
        self.close_connection()
        return

    def process_simulation_data(cls, msg: SimulationData, config: dict):
        """
        Hook to custom-process SA data; does nothing here.

        NOTE(review): the first parameter is named ``cls`` but there is no
        ``@classmethod`` decorator, so it actually receives the instance.
        """
        return

    def compute_stats(self, pdata: PartialSimulationData) -> None:
        """
        Feed one received data chunk into the iterative statistics.

        ``pdata.data`` is reshaped to ``(-1, data_size)``. With Sobol indices
        it holds one row per simulation of the group; rows 0 and 1 update the
        moments and all rows update the Martinez estimator.
        """
        if pdata.data_size == 0:
            # nothing to accumulate; reshape(-1, 0) would fail and registering
            # the client rank with dim=0 would break its later, non-empty data
            return

        if self.first_stat_computation:
            self.first_stat_computation = False
            for field in self.fields:
                self.melissa_moments[field] = {}
                if self.sobol_op:
                    self.melissa_sobol[field] = {}

        # Progressive initialization of the client_rank entry
        # so that we do not iterate over client_comm_size which
        # has not been broadcasted yet
        if pdata.client_rank not in self.seen_ranks:
            self.seen_ranks.append(pdata.client_rank)
            for field in self.fields:
                self.melissa_moments[field][pdata.client_rank] = {
                    t: IterativeMoments(self.max_order, dim=pdata.data_size)
                    for t in range(self.num_samples)
                }
                if self.sobol_op:
                    self.melissa_sobol[field][pdata.client_rank] = {
                        t: IterativeSensitivityMartinez(
                            nb_parms=self.nb_parameters, dim=pdata.data_size
                        )
                        for t in range(self.num_samples)
                    }

        np_data = pdata.data.reshape(-1, pdata.data_size)
        moments = self.melissa_moments[pdata.field][pdata.client_rank][pdata.time_step]
        moments.increment(np_data[0])

        if self.sobol_op:
            # increment the sobol data structure
            self.melissa_sobol[
                pdata.field
            ][pdata.client_rank][pdata.time_step]._increment(np_data)
            # increment the moments with the second solution
            moments.increment(np_data[1])

    def melissa_write_stats(self):
        """
        Write the computed statistics to ``./results/``.

        One file per (field, statistic[, parameter], time step). Each server
        rank contributes its local slice, which is gathered on rank 0 and
        saved there with ``np.savetxt``. All ranks must call this method.
        """
        # Turn server monitoring off
        if self.rank == 0:
            snd_msg = self.encode_msg(message.StopTimeoutMonitoring())
            self.launcherfd.send(snd_msg)

        # Broadcast client_comm_size from rank 0 to all server ranks
        self.client_comm_size = self.comm.bcast(self.client_comm_size, root=0)
        logger.info(
            f"Rank: {self.rank}, gathered client comm size: {self.client_comm_size}"
        )

        # Pad melissa_moments with empty (dim=0) entries for client ranks this
        # server rank never received data from (or all of them if it got none)
        for field in self.fields:
            field_moments = self.melissa_moments.setdefault(field, {})
            for client_rank in range(self.client_comm_size):
                if client_rank not in field_moments:
                    field_moments[client_rank] = {
                        t: IterativeMoments(self.max_order, dim=0)
                        for t in range(self.num_samples)
                    }

        # Local and global vector sizes, computed from the first field
        local_vect_size = int(sum(
            np.size(per_step[0].m1)
            for per_step in self.melissa_moments[self.fields[0]].values()
        ))
        local_vect_sizes = np.array(self.comm.allgather(local_vect_size), dtype=int)
        global_vect_size = int(np.sum(local_vect_sizes))
        logger.info(f"global_vect: {global_vect_size}")

        d_buffer = np.zeros(global_vect_size)
        stamp_width = len(str(self.num_samples))

        moment_stats = (
            ("mean", self.mean),
            ("variance", self.variance),
            ("skewness", self.skewness),
            ("kurtosis", self.kurtosis),
        )
        for stat, enabled in moment_stats:
            if not enabled:
                continue
            self.comm.Barrier()
            for field in self.fields:
                for t in range(self.num_samples):
                    file_name = (
                        f"./results/results.{field}_{stat}."
                        f"{str(t + 1).zfill(stamp_width)}"
                    )
                    d_buffer = self._write_gathered(
                        file_name,
                        self._moment_chunks(field, t, stat),
                        local_vect_sizes,
                        d_buffer,
                    )

        if self.sobol_op:
            self.comm.Barrier()
            # first-order indices come from pearson_B, total ones from pearson_A
            for stat, attribute in (("sobol", "pearson_B"), ("sobol_tot", "pearson_A")):
                for field in self.fields:
                    for param in range(self.nb_parameters):
                        for t in range(self.num_samples):
                            file_name = (
                                f"./results/results.{field}_{stat}{param}."
                                f"{str(t + 1).zfill(stamp_width)}"
                            )
                            d_buffer = self._write_gathered(
                                file_name,
                                self._sobol_chunks(field, t, param, attribute),
                                local_vect_sizes,
                                d_buffer,
                            )

    def _moment_chunks(self, field, t, stat):
        """Yield, per client rank, this rank's values of moment ``stat`` at step ``t``."""
        for client_rank in range(self.client_comm_size):
            moments = self.melissa_moments[field][client_rank][t]
            if stat == "mean":
                yield moments.m1
            elif np.size(moments.get_mean()) > 0:
                # higher moments are only queried where data was accumulated
                yield getattr(moments, f"get_{stat}")()

    def _sobol_chunks(self, field, t, param, attribute):
        """Yield, per client rank, this rank's ``attribute[param]`` Sobol values at step ``t``."""
        for client_rank in range(self.client_comm_size):
            per_step = self.melissa_sobol.get(field, {}).get(client_rank)
            if per_step is None:
                # this server rank never received data from that client rank
                continue
            yield getattr(per_step[t], attribute)[param]

    def _write_gathered(self, file_name, local_chunks, local_vect_sizes, d_buffer):
        """
        Pack ``local_chunks`` at the start of ``d_buffer``, gather all ranks'
        slices on rank 0 and save them to ``file_name`` there.
        """
        if self.rank == 0:
            logger.info(f"file name: {file_name}")
        offset = 0
        for chunk in local_chunks:
            size = np.size(chunk)
            if size > 0:
                d_buffer[offset:offset + size] = chunk
                offset += size
        d_buffer = self.gather_data(local_vect_sizes, d_buffer)
        if self.rank == 0:
            np.savetxt(file_name, d_buffer)
        return d_buffer

    def gather_data(
        self,
        local_vect_sizes: npt.NDArray[np.int_],
        d_buffer: npt.NDArray[np.float64],
    ) -> npt.NDArray[np.float64]:
        """
        Gather data on rank 0.

        Rank ``r`` holds its ``local_vect_sizes[r]`` values at the start of its
        ``d_buffer``; rank 0 receives them and stores them at offset
        ``sum(local_vect_sizes[:r])`` of its own buffer. Other ranks return
        their buffer unchanged.
        """
        if self.rank == 0:
            # exclusive prefix sum: start of each rank's slice in the global vector
            offsets = np.cumsum(local_vect_sizes) - local_vect_sizes
            for rank in range(1, self.comm_size):
                size = local_vect_sizes[rank]
                if size > 0:
                    start = offsets[rank]
                    d_buffer[start:start + size] = self.comm.recv(source=rank)
        elif local_vect_sizes[self.rank] > 0:
            self.comm.send(d_buffer[:local_vect_sizes[self.rank]], dest=0)

        return d_buffer

    def _checkpoint_suffix(self) -> str:
        """File-name suffix keeping each server rank's statistics checkpoint separate."""
        # a single-rank run keeps the historical file names
        return f"_{self.rank}" if self.comm_size > 1 else ""

    def checkpoint_state(self):
        """
        Every ``checkpoint_interval`` calls, save the base server state plus
        this rank's statistics under ``checkpoints/``. Disabled when
        ``checkpoint_interval`` is falsy.
        """
        if not self.checkpoint_interval:
            return

        self.checkpoint_count += 1
        if self.checkpoint_count % self.checkpoint_interval != 0:
            return

        logger.info("Checkpointing state")
        self.save_base_state()

        suffix = self._checkpoint_suffix()
        stats_metadata = {"seen_ranks": self.seen_ranks, "num_samples": self.num_samples}

        with open(f"checkpoints/melissa_moments{suffix}.pkl", "wb") as f:
            cloudpickle.dump(self.melissa_moments, f)

        if self.sobol_op:
            with open(f"checkpoints/melissa_sobol{suffix}.pkl", "wb") as f:
                cloudpickle.dump(self.melissa_sobol, f)

        with open(f"checkpoints/stats_metadata{suffix}.json", "w") as f:
            rapidjson.dump(stats_metadata, f)
        return

    def restart_from_checkpoint(self):
        """
        Invert ``checkpoint_state``: restore the base state and this rank's
        statistics. The pickles are loaded with cloudpickle, so checkpoint
        files must come from a trusted source.
        """
        self.load_base_state()

        suffix = self._checkpoint_suffix()
        with open(f"checkpoints/melissa_moments{suffix}.pkl", "rb") as f:
            self.melissa_moments = cloudpickle.load(f)

        if self.sobol_op:
            with open(f"checkpoints/melissa_sobol{suffix}.pkl", "rb") as f:
                self.melissa_sobol = cloudpickle.load(f)

        with open(f"checkpoints/stats_metadata{suffix}.json", "r") as f:
            stats_metadata = rapidjson.load(f)

        self.seen_ranks = stats_metadata["seen_ranks"]
        self.num_samples = stats_metadata["num_samples"]
        # the per-field dictionaries were restored above
        self.first_stat_computation = False

        return