Coverage for melissa/server/sensitivity_analysis/sensitivity_analysis_server.py: 26% (383 statements)

import logging
import os
import time
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union

import cloudpickle
import numpy as np
import numpy.typing as npt
import rapidjson
from iterative_stats.iterative_moments import IterativeMoments
from iterative_stats.sensitivity.sensitivity_martinez import IterativeSensitivityMartinez
from mpi4py import MPI

from melissa.launcher import message
from melissa.server.base_server import BaseServer
from melissa.server.simulation import (Group, PartialSimulationData, Simulation,
                                       SimulationData, SimulationDataStatus)

logger = logging.getLogger(__name__)


class SensitivityAnalysisServer(BaseServer):
    """
    Server to be used for sensitivity analysis (SA) studies.
    """

    def __init__(self, config: Dict[str, Any]):
        super().__init__(config)
        if self.num_samples == 0:
            raise Exception("num_samples must be set by the user for an SA study")

        self.learning: int = 0
        self.sa_config: Dict[str, Any] = config["sa_config"]

        self.sobol_op = 1 if self.sa_config.get("sobol_indices", False) else 0
        self.check_group_size()

        self.mean = self.sa_config.get("mean", False)
        self.variance = self.sa_config.get("variance", False)
        self.skewness = self.sa_config.get("skewness", False)
        self.kurtosis = self.sa_config.get("kurtosis", False)
        self.min = self.sa_config.get("min", False)
        self.max = self.sa_config.get("max", False)
        self.threshold_exceedance = self.sa_config.get("threshold_exceedance", False)
        self.threshold_values = self.sa_config.get("threshold_values", [0.7, 0.8])
        self.quantiles = self.sa_config.get("quantiles", False)
        self.quantile_values = self.sa_config.get(
            "quantile_values", [0.05, 0.25, 0.5, 0.75, 0.95]
        )
        Path('./results/').mkdir(parents=True, exist_ok=True)

        # Instantiate the melissa statistical data structures
        self.max_order: int = 0
        self.melissa_moments: Dict[str, Dict] = {}  # {field: {client_rank: {t: IterativeMoments}}}
        if self.sobol_op:
            self.pick_freeze_matrix: List[List[Union[int, float]]] = []
            self.melissa_sobol: Dict[str, Dict] = {}  # {field: {client_rank: {t: IterativeSensitivityMartinez}}}

        # the highest requested statistic determines the moment order to accumulate
        if self.kurtosis:
            self.max_order = 4
        elif self.skewness:
            self.max_order = 3
        elif self.variance:
            self.max_order = 2
        elif self.mean:
            self.max_order = 1
        else:
            self.max_order = 0

        if self.min or self.max:
            logger.warning("min/max statistics are not implemented")
        if self.threshold_exceedance:
            logger.warning("threshold exceedance is not implemented")
        if self.quantiles:
            logger.warning("quantiles are not implemented")

        self.first_stat_computation: bool = True
        self.seen_ranks: List[int] = []  # list of seen client ranks
        self.checkpoint_count: int = 0
        self.checkpoint_interval: int = self.sa_config["checkpoint_interval"]
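
    # For reference, a minimal "sa_config" section that this constructor can
    # consume might look as follows (illustrative values only; the keys shown
    # are exactly those read above):
    #
    #   "sa_config": {
    #       "sobol_indices": True,        # enables the pick-freeze machinery
    #       "mean": True,
    #       "variance": True,
    #       "checkpoint_interval": 100,   # mandatory: accessed with [] above
    #   }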

    def check_group_size(self):
        if self.sobol_op:
            # a Sobol group always holds nb_parameters + 2 simulations
            self.group_size = self.nb_parameters + 2
            self.number_of_groups = self.parameter_sweep_size
            self.num_clients = self.group_size * self.parameter_sweep_size
        elif self.group_size > 1 and self.num_clients % self.group_size != 0:
            logger.error("Incorrect group_size, please remove or adjust this option")
            self.catch_error = True
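
    # Sizing sketch with made-up numbers: a Sobol study with nb_parameters = 3
    # and parameter_sweep_size = 100 yields group_size = 3 + 2 = 5, hence
    # num_clients = 5 * 100 = 500 simulations overall.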

    def generate_client_scripts(self, first_id, number_of_scripts, default_parameters=None):
        """
        Creates all required client.X.sh scripts and sets up the dict
        for fault tolerance.
        """
        for sim_id in range(first_id, first_id + number_of_scripts):
            # if the number of scripts to be generated becomes too significant,
            # the server may spend too much time in this loop, hence causing
            # the launcher to believe that the server timed out if no PING is
            # received
            if number_of_scripts > 10000 and (sim_id - first_id) % 10000 == 0:
                self.time_monitor.check_clock(time.monotonic(), self)
            if default_parameters is not None:
                parameters = default_parameters
            elif not self.sobol_op:
                parameters = list(next(self.parameter_generator))
            else:
                parameters = self.draw_from_pick_freeze()
            if self.rank == 0:
                client_script_i = os.path.abspath(f"./client_scripts/client.{sim_id}.sh")
                Path('./client_scripts').mkdir(parents=True, exist_ok=True)
                with open(client_script_i, "w") as f:
                    print("#!/bin/sh", file=f)
                    print("exec env \\", file=f)
                    print(f" MELISSA_SIMU_ID={sim_id} \\", file=f)
                    print(f" MELISSA_SERVER_NODE_NAME={self.node_name} \\", file=f)
                    # str() conversion causes problems with scientific notation
                    # and should not be used
                    if self.rm_script:
                        print(
                            " "
                            + " ".join(
                                [os.path.join(os.getcwd(), "client.sh")]
                                + [
                                    np.format_float_positional(x) if type(x) is not str
                                    else x for x in parameters
                                ]
                            )
                            + " &",
                            file=f,
                        )
                        print(" wait", file=f)
                        print(' rm "$0"', file=f)
                    else:
                        print(
                            " "
                            + " ".join(
                                [os.path.join(os.getcwd(), "client.sh")]
                                + [
                                    np.format_float_positional(x) if type(x) is not str
                                    else x for x in parameters
                                ]
                            ),
                            file=f,
                        )

                os.chmod(client_script_i, 0o744)

                logger.info(
                    f"Rank {self.rank}: created client.{sim_id}.sh with parameters {parameters}"
                )

            # Fault-tolerance dictionary creation and update
            group_id = sim_id // self.group_size
            if group_id not in self.groups:
                group = Group(group_id)
                self.groups[group_id] = group
            if sim_id not in self.groups[group_id].simulations:
                self.n_submitted_simulations += 1
                simulation = Simulation(
                    sim_id, self.num_samples, self.fields, parameters
                )
                self.groups[group_id].simulations[sim_id] = simulation
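
    # For illustration, a script generated with rm_script enabled would look
    # roughly like this (sim id, node name, paths and parameter values are
    # made up):
    #
    #   #!/bin/sh
    #   exec env \
    #    MELISSA_SIMU_ID=7 \
    #    MELISSA_SERVER_NODE_NAME=node42 \
    #    /path/to/study/client.sh 0.25 1.5 &
    #    wait
    #    rm "$0"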

    def draw_from_pick_freeze(self) -> List:
        """
        Returns a row from the pick-freeze matrix.
        """
        if len(self.pick_freeze_matrix) > 0:
            return self.pick_freeze_matrix.pop(0)
        else:
            logger.debug("Build pick-freeze matrix")
            self.build_pick_freeze_matrix()
            return self.pick_freeze_matrix.pop(0)

    def build_pick_freeze_matrix(self):
        """
        Builds the pick-freeze matrix for one group.
        """
        self.pick_freeze_matrix = list(next(self.parameter_generator))
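
    # Pick-freeze sketch, assuming nb_parameters = 2 (so group_size = 4) and
    # the common convention where mixed row C_i takes parameter i from draw B
    # and the others from draw A; the actual layout is whatever the study's
    # parameter_generator yields:
    #
    #   A   = [a1, a2]
    #   B   = [b1, b2]
    #   C_1 = [b1, a2]
    #   C_2 = [a1, b2]
    #
    # draw_from_pick_freeze() then pops these rows one simulation at a time.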

    def start(self):
        """
        The main execution method.
        """
        if not self.restart:
            self.launch_first_groups()
        else:
            # the reinitialization from checkpoint occurs here
            logger.info(f"Continuing from checkpoint {self.restart}")
            self.restart_from_checkpoint()
            if self.rank == 0:
                self.kill_and_restart_simulations()

        self.setup_environment()

        self.server_online()

        self.server_offline()

        self.server_finalize()

    def server_online(self):
        """
        Method where the user controls the data handling while
        the server is online.
        """
        self.receive()

    def receive(self):
        """Handle data received from the simulations."""
        self._is_receiving = True
        received_samples = 0
        while not self.all_done():
            status = self.run()
            if status is not None:
                if isinstance(status, PartialSimulationData):
                    logger.debug(
                        f"received message: sim_id {status.simulation_id}, "
                        f"timestep {status.time_step}"
                    )
                    received_samples += 1

                    # compute the statistics on the received data
                    self.compute_stats(status)
                    self.checkpoint_state()

        self._is_receiving = False

    def handle_simulation_data(self, msg):
        """
        Parses and validates the incoming data messages from simulations.
        """
        # 1. Deserialize message
        msg_data: PartialSimulationData = PartialSimulationData.from_msg(msg, self.learning)
        logger.debug(
            f"Rank {self.rank} received {msg_data} from rank {msg_data.client_rank} "
            f"(vect_size: {len(msg_data.data)})"
        )
        # 2. Apply filters
        if not (0 <= msg_data.time_step < self.num_samples):
            logger.warning(
                f"Rank {self.rank}: bad timestep {msg_data.time_step}"
            )
            return None
        if msg_data.field not in self.fields:
            logger.warning(f"Rank {self.rank}: bad field {msg_data.field}")
            return None

        # when sobol_op=1, the results of each simulation in the group are gathered on
        # the ranks of its first simulation and are sent at once by each rank,
        # which means that len(msg_data.data) = group_size * data_size;
        # in addition, msg_data.simulation_id is actually the group_id
        for sim in range(len(msg_data.data) // msg_data.data_size):
            if not self.sobol_op:
                group_id = msg_data.simulation_id // self.group_size
                sim_id = msg_data.simulation_id
            else:
                group_id = msg_data.simulation_id
                sim_id = group_id * self.group_size + sim
            simulation = self.groups[group_id].simulations[sim_id]

            simulation_status, simulation_data = self.check_simulation_data(
                simulation, msg_data
            )
            if simulation_status == SimulationDataStatus.COMPLETE:
                logger.debug(
                    f"Rank {self.rank}: assembled time-step {simulation_data.time_step} "
                    f"- simulationID {simulation_data.simulation_id}"
                )
            elif simulation_status == SimulationDataStatus.ALREADY_RECEIVED:
                logger.warning(f"Rank {self.rank}: duplicate simulation data {msg_data}")

            # Check if the simulation has finished
            if (
                simulation_status == SimulationDataStatus.COMPLETE
                or simulation_status == SimulationDataStatus.EMPTY
            ) and simulation.finished():
                logger.info(f"Rank {self.rank}: simulation {simulation.id} finished")
                self.n_finished_simulations += 1

        return simulation_data
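
    # Layout sketch for the sobol_op=1 case with made-up sizes: group_size = 4
    # and data_size = 1000 give len(msg_data.data) = 4000, so the loop above
    # visits sim = 0..3 and attributes the slice
    # msg_data.data[sim * 1000:(sim + 1) * 1000] to sim_id = group_id * 4 + sim.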

    def check_simulation_data(
        self, simulation: Simulation, simulation_data: PartialSimulationData
    ) -> Tuple[
        SimulationDataStatus,
        Union[Optional[SimulationData], Optional[PartialSimulationData]],
    ]:
        """
        Looks for duplicated messages, then updates
        received_simulation_data and the simulation data status.
        """
        if simulation_data.client_rank not in simulation.received_simulation_data:
            simulation.received_simulation_data[simulation_data.client_rank] = {}
            simulation.received_time_steps[simulation_data.client_rank] = (
                np.zeros((len(self.fields), self.num_samples), dtype=bool)
            )
        # Data have already been received
        if simulation.has_already_received(
            simulation_data.client_rank, simulation_data.time_step, simulation_data.field
        ):
            return SimulationDataStatus.ALREADY_RECEIVED, None
        # Time step has never been seen
        if simulation_data.time_step not in simulation.received_simulation_data[
            simulation_data.client_rank
        ]:
            simulation.received_simulation_data[simulation_data.client_rank][
                simulation_data.time_step
            ] = {field: None for field in simulation.fields}
        # Update the entry
        # for SA it is more memory efficient not to keep track of the whole simulation_data
        simulation.received_simulation_data[simulation_data.client_rank][
            simulation_data.time_step
        ][simulation_data.field] = 1
        simulation._mark_as_received(
            simulation_data.client_rank, simulation_data.time_step, simulation_data.field
        )
        if simulation.is_complete(simulation_data.time_step):
            # All fields have been received for the time step
            simulation.n_received_time_steps += 1
            # the per-time-step bookkeeping is no longer needed
            del simulation.received_simulation_data[simulation_data.client_rank][
                simulation_data.time_step
            ]
            # Check there is actual data
            if simulation_data.data_size == 0:
                # Data have been sent to another device, fields are empty
                return SimulationDataStatus.EMPTY, None
            return SimulationDataStatus.COMPLETE, simulation_data
        else:
            # Not all fields have been received yet
            return SimulationDataStatus.PARTIAL, simulation_data

    def setup_environment(self):
        return super().setup_environment()

    def server_offline(self):
        """
        Post-processing goes here. Not required.
        """
        self.melissa_write_stats()

    def server_finalize(self):
        """
        All finalization methods go here.
        """
        logger.info("stop server")
        self.write_final_report()
        self.close_connection()

    def process_simulation_data(self, msg: SimulationData, config: dict):
        """
        Method used to custom-process SA data (no-op by default).
        """
        return

    def compute_stats(self, pdata: PartialSimulationData) -> None:
        """
        Links into the stats library to compute online statistics.
        """
        if self.first_stat_computation:
            self.first_stat_computation = False
            for field in self.fields:
                self.melissa_moments[field] = {}
                if self.sobol_op:
                    self.melissa_sobol[field] = {}

        # Progressive initialization of the client_rank entry
        # so that we do not iterate over client_comm_size, which
        # has not been broadcast yet
        if pdata.client_rank not in self.seen_ranks:
            self.seen_ranks.append(pdata.client_rank)
            for field in self.fields:
                self.melissa_moments[field][pdata.client_rank] = {}
                if self.sobol_op:
                    self.melissa_sobol[field][pdata.client_rank] = {}
                for t in range(self.num_samples):
                    self.melissa_moments[field][pdata.client_rank][t] = (
                        IterativeMoments(self.max_order, dim=pdata.data_size)
                    )
                    if self.sobol_op:
                        self.melissa_sobol[field][pdata.client_rank][t] = (
                            IterativeSensitivityMartinez(nb_parms=self.nb_parameters,
                                                         dim=pdata.data_size)
                        )

        # when sobol_op=1, results are grouped, thus pdata.data contains the solution
        # vectors of each simulation in the group and must be reshaped,
        # since only the first two solutions are used to compute the moments
        np_data = pdata.data.reshape(-1, pdata.data_size)
        self.melissa_moments[
            pdata.field
        ][pdata.client_rank][pdata.time_step].increment(np_data[0])

        if self.sobol_op:
            # increment the sobol data structure
            self.melissa_sobol[
                pdata.field
            ][pdata.client_rank][pdata.time_step]._increment(np_data)
            # increment the moments with the second solution
            self.melissa_moments[
                pdata.field
            ][pdata.client_rank][pdata.time_step].increment(np_data[1])
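
    # Reshape sketch with made-up sizes: for sobol_op=1, group_size = 4 and
    # data_size = 1000, pdata.data holds 4000 values and np_data becomes a
    # (4, 1000) matrix; rows 0 and 1 (the two independent draws) update the
    # moments while the full matrix updates the Martinez Sobol estimator.
    # Without Sobol indices, np_data is simply (1, data_size).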

    def melissa_write_stats(self):
        """
        Writes the computed statistics to file.
        """
        # Turn server monitoring off
        if self.rank == 0:
            snd_msg = self.encode_msg(message.StopTimeoutMonitoring())
            self.launcherfd.send(snd_msg)

        # Broadcast client_comm_size to all server ranks
        client_comm_size: int = self.client_comm_size
        if self.rank == 0:
            self.comm.bcast(client_comm_size, root=0)
        else:
            client_comm_size = self.comm.bcast(client_comm_size, root=0)
        self.client_comm_size = client_comm_size
        logger.info(f"Rank: {self.rank}, gathered client comm size: {self.client_comm_size}")

        # Update melissa_moments with missing client ranks
        for field in self.fields:
            for client_rank in range(self.client_comm_size):
                if client_rank not in self.melissa_moments[field]:
                    self.melissa_moments[field][client_rank] = {}
                    for t in range(self.num_samples):
                        self.melissa_moments[field][client_rank][t] = (
                            IterativeMoments(self.max_order, dim=0)
                        )

        temp_offset: int = 0
        local_vect_sizes: npt.NDArray[np.int_] = np.zeros(self.comm_size, dtype=int)
        vect_size: npt.NDArray[np.int_] = np.zeros(1, dtype=int)
        global_vect_size: int = 0

        # Compute the global vect size
        field = self.fields[0]
        for client_rank in self.melissa_moments[field].keys():
            vect_size += np.size(self.melissa_moments[field][client_rank][0].m1)

        # let mpi4py infer the MPI datatype from the numpy buffers
        # (np.int_ does not necessarily match MPI.INT)
        self.comm.Allgather(vect_size, local_vect_sizes)
        global_vect_size = np.sum(local_vect_sizes)
        logger.info(f"global_vect: {global_vect_size}")

        d_buffer = np.zeros(global_vect_size)

        if self.mean:
            self.comm.Barrier()
            for field in self.fields:
                for t in range(self.num_samples):
                    file_name = "./results/results.{}_{}.{}".format(
                        field,
                        "mean",
                        str(t + 1).zfill(len(str(self.num_samples)))
                    )
                    if self.rank == 0:
                        logger.info(f"file name: {file_name}")
                    for rank in range(self.client_comm_size):
                        mean = self.melissa_moments[field][rank][t].m1
                        if np.size(mean) > 0:
                            d_buffer[
                                temp_offset:temp_offset + np.size(mean)
                            ] = mean
                            temp_offset += np.size(mean)
                    temp_offset = 0
                    d_buffer = self.gather_data(local_vect_sizes, d_buffer)
                    if self.rank == 0:
                        np.savetxt(file_name, d_buffer)

        if self.variance:
            self.comm.Barrier()
            for field in self.fields:
                for t in range(self.num_samples):
                    file_name = "./results/results.{}_{}.{}".format(
                        field,
                        "variance",
                        str(t + 1).zfill(len(str(self.num_samples)))
                    )
                    if self.rank == 0:
                        logger.info(f"file name: {file_name}")
                    for rank in range(self.client_comm_size):
                        mean = self.melissa_moments[field][rank][t].get_mean()
                        if np.size(mean) > 0:
                            var = self.melissa_moments[field][rank][t].get_variance()
                            d_buffer[
                                temp_offset:temp_offset + np.size(var)
                            ] = var
                            temp_offset += np.size(var)
                    temp_offset = 0
                    d_buffer = self.gather_data(local_vect_sizes, d_buffer)
                    if self.rank == 0:
                        np.savetxt(file_name, d_buffer)

        if self.skewness:
            self.comm.Barrier()
            for field in self.fields:
                for t in range(self.num_samples):
                    file_name = "./results/results.{}_{}.{}".format(
                        field,
                        "skewness",
                        str(t + 1).zfill(len(str(self.num_samples)))
                    )
                    if self.rank == 0:
                        logger.info(f"file name: {file_name}")
                    for rank in range(self.client_comm_size):
                        mean = self.melissa_moments[field][rank][t].get_mean()
                        if np.size(mean) > 0:
                            crank_skewness = self.melissa_moments[field][rank][t].get_skewness()
                            d_buffer[
                                temp_offset:temp_offset + np.size(crank_skewness)
                            ] = crank_skewness
                            temp_offset += np.size(crank_skewness)
                    temp_offset = 0
                    d_buffer = self.gather_data(local_vect_sizes, d_buffer)
                    if self.rank == 0:
                        np.savetxt(file_name, d_buffer)
                    # free memory
                    del crank_skewness

        if self.kurtosis:
            self.comm.Barrier()
            for field in self.fields:
                for t in range(self.num_samples):
                    file_name = "./results/results.{}_{}.{}".format(
                        field,
                        "kurtosis",
                        str(t + 1).zfill(len(str(self.num_samples)))
                    )
                    if self.rank == 0:
                        logger.info(f"file name: {file_name}")
                    for rank in range(self.client_comm_size):
                        mean = self.melissa_moments[field][rank][t].get_mean()
                        if np.size(mean) > 0:
                            crank_kurtosis = self.melissa_moments[field][rank][t].get_kurtosis()
                            d_buffer[
                                temp_offset:temp_offset + np.size(crank_kurtosis)
                            ] = crank_kurtosis
                            temp_offset += np.size(crank_kurtosis)
                    temp_offset = 0
                    d_buffer = self.gather_data(local_vect_sizes, d_buffer)
                    if self.rank == 0:
                        np.savetxt(file_name, d_buffer)
                    # free memory
                    del crank_kurtosis

        if self.sobol_op:
            self.comm.Barrier()
            # first-order indices from the Martinez pearson_B estimator
            for field in self.fields:
                for param in range(self.nb_parameters):
                    for t in range(self.num_samples):
                        file_name = "./results/results.{}_{}{}.{}".format(
                            field,
                            "sobol",
                            str(param),
                            str(t + 1).zfill(len(str(self.num_samples)))
                        )
                        if self.rank == 0:
                            logger.info(f"file name: {file_name}")
                        for rank in range(self.client_comm_size):
                            pearson_b = self.melissa_sobol[field][rank][t].pearson_B[param]
                            if np.size(pearson_b) > 0:
                                d_buffer[
                                    temp_offset:temp_offset + np.size(pearson_b)
                                ] = pearson_b
                                temp_offset += np.size(pearson_b)
                        temp_offset = 0
                        d_buffer = self.gather_data(local_vect_sizes, d_buffer)
                        if self.rank == 0:
                            np.savetxt(file_name, d_buffer)

            # total-order indices from the Martinez pearson_A estimator
            for field in self.fields:
                for param in range(self.nb_parameters):
                    for t in range(self.num_samples):
                        file_name = "./results/results.{}_{}{}.{}".format(
                            field,
                            "sobol_tot",
                            str(param),
                            str(t + 1).zfill(len(str(self.num_samples)))
                        )
                        if self.rank == 0:
                            logger.info(f"file name: {file_name}")
                        for rank in range(self.client_comm_size):
                            pearson_a = self.melissa_sobol[field][rank][t].pearson_A[param]
                            if np.size(pearson_a) > 0:
                                d_buffer[
                                    temp_offset:temp_offset + np.size(pearson_a)
                                ] = pearson_a
                                temp_offset += np.size(pearson_a)
                        temp_offset = 0
                        d_buffer = self.gather_data(local_vect_sizes, d_buffer)
                        if self.rank == 0:
                            np.savetxt(file_name, d_buffer)
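
    # Naming sketch with made-up values: for field "temperature",
    # nb_parameters = 2 and num_samples = 100, this method writes files such
    # as results/results.temperature_mean.001 through .100, plus
    # results.temperature_sobol0.001 (first order, parameter 0) and
    # results.temperature_sobol_tot1.042 (total order, parameter 1), each
    # holding the gathered global field vector for that time step.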

    def gather_data(
        self,
        local_vect_sizes: npt.NDArray[np.int_],
        d_buffer: npt.NDArray[np.float_]
    ) -> npt.NDArray[np.float_]:
        """
        Gathers the data on rank 0.
        """
        temp_offset: int = 0
        if self.rank == 0:
            for rank in range(1, self.comm_size):
                temp_offset += local_vect_sizes[rank - 1]
                if local_vect_sizes[rank] > 0:
                    d_buffer[
                        temp_offset:temp_offset
                        + local_vect_sizes[rank]
                    ] = self.comm.recv(source=rank)
        else:
            if local_vect_sizes[self.rank] > 0:
                self.comm.send(d_buffer[:local_vect_sizes[self.rank]], dest=0)

        return d_buffer
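
    # Offset sketch with made-up sizes: local_vect_sizes = [3, 2, 4] makes
    # rank 0 receive rank 1's two values into d_buffer[3:5] and rank 2's four
    # values into d_buffer[5:9]; rank 0's own slice d_buffer[0:3] is already
    # in place.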

    def checkpoint_state(self):
        if not self.checkpoint_interval:
            return

        self.checkpoint_count += 1
        if self.checkpoint_count % self.checkpoint_interval != 0:
            return

        logger.info("Checkpointing state")
        self.save_base_state()

        stats_metadata = {"seen_ranks": self.seen_ranks, "num_samples": self.num_samples}

        with open("checkpoints/melissa_moments.pkl", 'wb') as f:
            cloudpickle.dump(self.melissa_moments, f)

        if self.sobol_op:
            with open("checkpoints/melissa_sobol.pkl", 'wb') as f:
                cloudpickle.dump(self.melissa_sobol, f)

        with open("checkpoints/stats_metadata.json", 'w') as f:
            rapidjson.dump(stats_metadata, f)
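
    # Resulting on-disk layout (sketch; save_base_state() may add more files):
    #
    #   checkpoints/
    #     melissa_moments.pkl    # {field: {client_rank: {t: IterativeMoments}}}
    #     melissa_sobol.pkl      # only written when sobol_op=1
    #     stats_metadata.json    # {"seen_ranks": [...], "num_samples": N}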

    def restart_from_checkpoint(self):
        """
        Inverts checkpoint_state.
        """
        self.load_base_state()

        with open("checkpoints/melissa_moments.pkl", 'rb') as f:
            self.melissa_moments = cloudpickle.load(f)

        if self.sobol_op:
            with open("checkpoints/melissa_sobol.pkl", 'rb') as f:
                self.melissa_sobol = cloudpickle.load(f)

        with open("checkpoints/stats_metadata.json", 'r') as f:
            stats_metadata = rapidjson.load(f)

        self.seen_ranks = stats_metadata["seen_ranks"]
        self.num_samples = stats_metadata["num_samples"]
        self.first_stat_computation = False