Coverage for melissa/server/base_server.py: 40%

382 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-09-22 10:36 +0000

1from datetime import timedelta 

2import logging 

3import os 

4import socket 

5import threading 

6import time 

7from abc import ABC, abstractmethod 

8from enum import Enum 

9from pathlib import Path 

10from typing import Any, Dict, Optional, Tuple, Union 

11 

12import numpy as np 

13import zmq 

14from mpi4py import MPI 

15import cloudpickle 

16 

17from melissa.launcher import config, message 

18from melissa.scheduler import job 

19from melissa.server.fault_tolerance import FaultTolerance 

20from melissa.server.parameters import AbstractExperiment 

21from melissa.server.message import ConnectionRequest, ConnectionResponse 

22from melissa.server.simulation import (Group, PartialSimulationData, 

23 Simulation, SimulationData, 

24 SimulationDataStatus) 

25from melissa.utility.networking import (LengthPrefixFramingDecoder, 

26 LengthPrefixFramingEncoder, 

27 connect_to_launcher, 

28 get_rank_and_num_server_proc) 

29from melissa.utility.timer import Timer 

30import rapidjson 

31 

32logger = logging.getLogger(__name__) 

33 

34 

class ServerStatus(Enum):
    """Status codes returned by the server's main polling loop."""
    # Checkpoint requested; not produced in this file — presumably returned
    # by subclass servers (TODO confirm).
    CHECKPOINT = 1
    # run() returns this when poll() finishes with no ready socket.
    TIMEOUT = 2

38 

39 

def select_protocol() -> Tuple[int, str]:
    """
    Read the MELISSA_LAUNCHER_PROTOCOL environment variable and map it to
    a socket protocol constant.

    Returns:
        (protocol_constant, protocol_name) where protocol_constant is
        socket.IPPROTO_SCTP or socket.IPPROTO_TCP.

    Raises:
        Exception: when the variable is unset or names an unsupported protocol.
    """
    protocol = os.environ.get("MELISSA_LAUNCHER_PROTOCOL")
    if protocol is None:
        raise Exception("Undefined protocol")

    if protocol == "SCTP":
        return socket.IPPROTO_SCTP, "SCTP"
    elif protocol == "TCP":
        return socket.IPPROTO_TCP, "TCP"
    else:
        raise Exception(f"Unsupported server/launcher communication protocol {protocol}")

52 

53 

class BaseServer(ABC):
    """
    Abstract base class for Melissa servers.

    Holds the study configuration, the MPI communicator, the launcher
    connection state and the fault-tolerance bookkeeping shared by the
    concrete server flavors (SA vs DL).
    """

    def __init__(
        self,
        config: Dict[str, Any],
        checkpoint_file: str = "checkpoint.pkl",
        data_hwm: int = 4096,
        restart: bool = False,
    ):
        """
        Initializes the server, opens all connections and
        launches the study simulations
        """
        # NOTE(review): restart state is taken from the environment, not from
        # the `restart` parameter (which raises below when True) — confirm the
        # env var is authoritative. `data_hwm` is not used in this method.
        self.restart = int(os.environ["MELISSA_RESTART"])
        rank, nb_proc_server = get_rank_and_num_server_proc()
        # NOTE(review): this parameter shadows the imported
        # melissa.launcher.config module inside this method.
        self.config: Dict[str, Any] = config
        self.sweep_params: Dict[str, Any] = config.get("sweep_params", {})
        self.checkpoint_file: str = checkpoint_file
        self.num_server_proc: int = nb_proc_server
        self.study_options: Dict[str, Any] = self.config["study_options"]
        self.user_data_path = Path("user_data")
        self.rank: int = rank
        self._is_receiving: bool = False
        self._is_online: bool = False
        # set when the study options are invalid; connect_to_launcher() raises
        # once it sees this flag
        self.catch_error: bool = False

        # Variables for final report
        self.total_bytes_recv: int = 0
        self.t0: float = time.time()

        # MPI initialization
        self.sobol_op: int = 0
        self.comm = MPI.COMM_WORLD
        # NOTE(review): overwrites the rank fetched above; both are expected
        # to agree — the MPI communicator is treated as the authority.
        self.rank = self.comm.Get_rank()
        self.comm_size = self.comm.Get_size()
        self.client_comm_size: int = 0

        # Scan study options dictionary
        self.crashes_before_redraw = self.study_options.get("crashes_before_redraw", 1)
        self.max_delay = self.study_options.get("simulation_timeout", 60)
        self.rm_script = self.study_options.get("remove_client_scripts", False)
        try:
            self.fields = self.study_options["field_names"]
            self.nb_parameters = self.study_options["nb_parameters"]
            self.num_samples = self.study_options["num_samples"]
            self.parameter_sweep_size = self.study_options["parameter_sweep_size"]
            self.num_clients: int = self.parameter_sweep_size  # introduced for semantic clarity
        except Exception as e:
            # NOTE(review): when this triggers, num_clients is never set and
            # the number_of_groups computation below would raise
            # AttributeError before catch_error is acted upon — confirm.
            logger.error(f"Incorrect study conf file: {e}")
            self.catch_error = True
        self.group_size = self.study_options.get("group_size", 1)
        # 0 is ZMQ's "no limit" high-water mark
        self.zmq_hwm = self.study_options.get("zmq_hwm", 0)

        # Termination condition arguments
        self.number_of_groups: int = self.num_clients // self.group_size
        self.n_submitted_simulations: int = 0
        self.n_finished_simulations: int = 0
        self.n_connected_simulations: int = 0

        # Fault-Tolerance initialization
        self.no_fault_tolerance = (
            True if os.environ["MELISSA_FAULT_TOLERANCE"] == "OFF" else False
        )
        logger.info("fault-tolerance " + ("OFF" if self.no_fault_tolerance else "ON"))
        self.ft = FaultTolerance(
            self.no_fault_tolerance,
            self.max_delay,
            self.crashes_before_redraw,
            self.number_of_groups,
        )
        # group_id -> Group, filled by generate_client_scripts()
        self.groups: Dict[int, Group] = {}
        # must be assigned by the concrete subclass before script generation
        self.parameter_generator: Optional[AbstractExperiment] = None

        # TODO: maybe this should be an MPI gather on rank 0.
        self.verbose_level = 0

        if restart:
            # TODO implement generic restart/checkpoint
            raise Exception("Error, this needs to be implemented")

132 

133 def initialize_connections(self): 

134 self.initialize_ports() 

135 # 0. Connect to launcher 

136 self.connect_to_launcher() 

137 # 1. Set up sockets 

138 self.setup_sockets() 

139 # 2. Setup poller 

140 self.setup_poller() 

141 # 3. Setup TimeMonitor object 

142 self.time_monitor = TimeMonitor(time.monotonic(), 30.) 

143 

144 if self.config.get("vscode_debugging", False): 

145 self.start_debugger() 

146 

    def start_debugger(self):
        """
        Block until a VS Code remote debugger attaches, then tell the
        launcher to stop its timeout monitoring so the paused server is
        not considered dead.
        """
        # imported lazily: debugpy is only needed when vscode_debugging is on
        import debugpy
        # 5678 is the default attach port that we recommend users to set
        # in the documentation.
        debugpy.listen(5678)
        logger.warning("Waiting for debugger attach, please start the "
                       "debugger by navigating to debugger pane ctrl+shift+d "
                       "and selecting\n"
                       "Python: Remote Attach")
        # blocks until the IDE connects
        debugpy.wait_for_client()
        logger.info("Debugger successfully attached.")
        # send message to launcher to ensure debugger doesnt timeout
        snd_msg = self.encode_msg(message.StopTimeoutMonitoring())
        self.launcherfd.send(snd_msg)

161 

162 def initialize_ports(self, connection_port: int = 2003, data_puller_port: int = 2006): 

163 # Ports initialization 

164 logger.info("Rank {} | Initializing server...".format(self.rank)) 

165 self.node_name = socket.gethostname() 

166 self.connection_port = connection_port 

167 self.data_puller_port = str(data_puller_port) + str(self.rank) 

168 self.data_puller_port_name = "tcp://{}:{}".format( 

169 self.node_name, 

170 self.data_puller_port, 

171 ) 

172 self.port_names = self.comm.allgather(self.data_puller_port_name) 

173 logger.debug("port_names {}".format(self.port_names)) 

174 

    def connect_to_launcher(self):
        """
        Open the TCP/SCTP control connection to the launcher.

        Rank 0 connects first and announces the server comm size and group
        size; the other ranks wait at the barrier and connect only once the
        launcher already knows how many connections to expect.

        NOTE(review): this method shares its name with the imported helper
        melissa.utility.networking.connect_to_launcher — the bare calls below
        resolve to the module-level import, not to this method.
        """
        # Setup communication instances
        self.protocol, prot_name = select_protocol()
        logger.info(f"Server/launcher communication protocol: {prot_name}")
        if self.rank == 0:
            self.launcherfd: socket.socket = connect_to_launcher()
            logger.debug(f"Launcher fd set up: { self.launcherfd.fileno()}")
            self.launcherfd.send(self.encode_msg(message.CommSize(self.comm_size)))
            logger.debug(f"Comm size {self.comm_size} sent to launcher")
            self.launcherfd.send(self.encode_msg(message.GroupSize(self.group_size)))
            logger.debug(f"Group size {self.group_size} sent to launcher")
        # synchronize non-zero ranks after rank 0 connection to make sure
        # these ranks only connect after comm_size is known to the launcher
        self.comm.Barrier()
        if self.rank > 0:
            self.launcherfd: socket.socket = connect_to_launcher()
            logger.debug(f"Launcher fd set up: {self.launcherfd.fileno()}")

        # if an error was caught earlier the connection is immediately closed
        # this will prevent the server to be relaunched infinitely
        if self.catch_error:
            raise Exception("Incorrect configuration file")

197 

    def setup_sockets(self):
        """
        Create the ZMQ sockets and the internal timer channel.

        - connection_responder (REP): bound on rank 0 only, answers client
          connection requests.
        - data_puller (PULL): bound on every rank, receives simulation data.
        - timerfd_0 / timerfd_1: UNIX socket pair connecting this rank to
          its time-out checker thread.
        """
        self.context = zmq.Context()
        # Simulations (REQ) <-> Server (REP)
        self.connection_responder = self.context.socket(zmq.REP)
        if self.rank == 0:
            logger.info("Rank {}: binding..".format(self.rank))
            addr = "tcp://*:{port}".format(port=self.connection_port)
            try:
                self.connection_responder.bind(addr)
            except Exception as e:
                logger.error("Rank {}: could not bind to {}".format(self.rank, addr))
                raise e
        # Simulations (PUSH) -> Server (PULL)
        self.data_puller = self.context.socket(zmq.PULL)
        # receive high-water mark; 0 (the config default) means unlimited
        self.data_puller.setsockopt(zmq.RCVHWM, self.zmq_hwm)
        # -1: linger forever on close until pending messages are delivered
        self.data_puller.setsockopt(zmq.LINGER, -1)
        addr = "tcp://*:{}".format(self.data_puller_port)
        logger.info("Rank {}: connecting puller to: {}".format(self.rank, addr))
        self.data_puller.bind(addr)
        # Time-out checker (creates thread)
        self.timerfd_0, self.timerfd_1 = socket.socketpair(
            socket.AF_UNIX, socket.SOCK_STREAM
        )
        # the Timer writes to timerfd_1; run() polls timerfd_0
        self.timer = Timer(self.timerfd_1, timedelta(seconds=self.max_delay))
        # daemon thread: does not block interpreter shutdown
        self.t_timer = threading.Thread(target=lambda: self.timer.run(), daemon=True)
        self.t_timer.start()

224 

225 def setup_poller(self): 

226 self.poller = zmq.Poller() 

227 self.poller.register(self.data_puller, zmq.POLLIN) 

228 self.poller.register(self.timerfd_0, zmq.POLLIN) 

229 self.poller.register(self.launcherfd, zmq.POLLIN) 

230 if self.rank == 0: 

231 self.poller.register(self.connection_responder, zmq.POLLIN) 

232 

233 def launch_first_groups(self): 

234 """ 

235 Launches the study groups 

236 """ 

237 # Get current working directory containing the client script template 

238 client_script = os.path.abspath("client.sh") 

239 

240 if not os.path.isfile(client_script): 

241 raise Exception("error client script not found") 

242 

243 # Generate all client scripts 

244 self.generate_client_scripts(0, self.num_clients) 

245 

246 # Launch every group 

247 # note that the launcher message stacking feature does not work 

248 for grp_id in range(self.number_of_groups): 

249 self.launch_group(grp_id) 

250 

    def generate_client_scripts(self, first_id, number_of_scripts, default_parameters=None):
        """
        Creates all required client.X.sh scripts and set up dict
        for fault tolerance

        Covers simulation ids first_id .. first_id + number_of_scripts - 1.
        When default_parameters is given it is reused for every script
        (relaunch case); otherwise fresh parameters are drawn from the
        parameter generator on every rank.
        """
        for sim_id in range(first_id, first_id + number_of_scripts):
            # if the number of scripts to be generated becomes too significant
            # the server may spend too much time in this loop hence causing
            # the launcher to believe that the server timed out if no PING is
            # received
            if number_of_scripts > 10000 and (sim_id - first_id) % 10000 == 0:
                self.time_monitor.check_clock(time.monotonic(), self)
            if default_parameters is not None:
                parameters = default_parameters
            else:
                # NOTE(review): parameter_generator is None after __init__;
                # subclasses must have assigned it before this runs.
                parameters = list(next(self.parameter_generator))
            if self.rank == 0:
                # only rank 0 writes the script files
                self.generate_client_script(sim_id, parameters)

                logger.info(
                    f"Rank {self.rank}: created client.{sim_id}.sh with parameters {parameters}"
                )

            # Fault-tolerance dictionary creation and update
            group_id = sim_id // self.group_size
            if group_id not in self.groups:
                group = Group(group_id)
                self.groups[group_id] = group
            if sim_id not in self.groups[group_id].simulations:
                # counted once per distinct simulation id
                self.n_submitted_simulations += 1
                simulation = Simulation(
                    sim_id, self.num_samples, self.fields, parameters
                )
                self.groups[group_id].simulations[sim_id] = simulation

285 

286 def generate_client_script(self, sim_id, parameters): 

287 """ 

288 Generate a single client script 

289 """ 

290 client_script_i = os.path.abspath(f"./client_scripts/client.{str(sim_id)}.sh") 

291 Path('./client_scripts').mkdir(parents=True, exist_ok=True) 

292 with open(client_script_i, "w") as f: 

293 print("#!/bin/sh", file=f) 

294 print("exec env \\", file=f) 

295 print(f" MELISSA_SIMU_ID={sim_id} \\", file=f) 

296 print(f" MELISSA_SERVER_NODE_NAME={self.node_name} \\", file=f) 

297 # str() conversion causes problems with scientific notations 

298 # and should not be used 

299 if self.rm_script: 

300 print( 

301 " " 

302 + " ".join( 

303 [os.path.join(os.getcwd(), "client.sh")] 

304 + [ 

305 np.format_float_positional(x) if type(x) is not str 

306 else x for x in parameters 

307 ] 

308 ) 

309 + " &", 

310 file=f, 

311 ) 

312 print(" wait", file=f) 

313 print(' rm "$0"', file=f) 

314 else: 

315 print( 

316 " " 

317 + " ".join( 

318 [os.path.join(os.getcwd(), "client.sh")] 

319 + [ 

320 np.format_float_positional(x) if type(x) is not str 

321 else x for x in parameters 

322 ] 

323 ), 

324 file=f, 

325 ) 

326 

327 os.chmod(client_script_i, 0o744) 

328 

329 def launch_group(self, group_id): 

330 """ 

331 Launches group group_id 

332 """ 

333 if self.rank == 0: 

334 # Job submission message to launcher (initial_id,num_jobs) 

335 snd_msg = self.encode_msg(message.JobSubmission(group_id, 1)) 

336 self.launcherfd.send(snd_msg) 

337 logger.info( 

338 f"Rank {self.rank}: group " 

339 f"{group_id+1}/{self.number_of_groups} " 

340 "submitted to launcher" 

341 ) 

342 

343 def kill_group(self, group_id): 

344 """ 

345 Kills the specified group 

346 """ 

347 if self.rank == 0: 

348 logger.warning(f"Server crashed, restarting incomplete job {group_id}") 

349 snd_msg = self.encode_msg(message.JobCancellation(group_id)) 

350 self.launcherfd.send(snd_msg) 

351 

    def relaunch_group(self, group_id: int, new_grp: bool) -> None:
        """
        Relaunch a failed group with or without new parameters

        new_grp=True: drop the old group entirely and draw fresh parameters.
        new_grp=False: regenerate the scripts with the parameters the failed
        simulations already had.
        """
        if new_grp:
            del self.groups[group_id]
            # n_submitted_simulations needs to be decremented after each group deletion
            # to compensate for generate_client_scripts which will increment it group_size times
            self.n_submitted_simulations -= self.group_size
            for sim in range(self.group_size):
                # one call per simulation id; None -> draw new parameters
                self.generate_client_scripts(
                    group_id * self.group_size + sim, 1, None
                )
        else:
            for sim in range(self.group_size):
                # reuse the recorded parameters of the failed simulation
                self.generate_client_scripts(
                    group_id * self.group_size + sim,
                    1,
                    self.groups[group_id].simulations[group_id * self.group_size + sim].parameters,
                )

        self.launch_group(group_id)

374 

    def handle_simulation_connection(self, msg):
        """
        New simulation connection. Executed by rank 0 only

        Decodes the ConnectionRequest, records the client communicator
        size, and replies with the study geometry plus the per-rank
        data-puller endpoints.  Returns the string "Connection".
        """
        request = ConnectionRequest.recv(msg)
        self.client_comm_size = request.comm_size
        logger.debug(
            f"Rank {self.rank}: [Connection] received connection message "
            f"from simulation {request.simulation_id} (client comm size {self.client_comm_size})"
        )
        # NOTE(review): self.learning is never set in BaseServer — it is
        # expected to be provided by the concrete subclass before any client
        # connects; confirm.
        logger.debug(
            f"Rank {self.rank}: [Connection] sending response to simulation {request.simulation_id}"
            f" with learning set to {self.learning}"
        )
        response = ConnectionResponse(
            self.comm_size,
            self.sobol_op,
            self.learning,
            self.nb_parameters,
            self.verbose_level,
            self.port_names,
        )
        self.connection_responder.send(response.encode())
        self.n_connected_simulations += 1
        logger.info(
            f"Rank {self.rank}: [Connection] connection established "
            f"with simulation {request.simulation_id}"
        )
        return "Connection"

404 

    def run(self, timeout=10):
        """
        main function which handles the incoming messages from
        the launcher, the timer and the clients

        Polls all registered channels once (``timeout`` in milliseconds,
        per zmq.Poller.poll).  Returns ServerStatus.TIMEOUT when nothing
        was ready, the result of handle_simulation_data when client data
        arrived, and None otherwise.
        """
        # Checkpoint: this needs to be implemented specifically for melissa-dl

        # 1. Poll sockets
        # ZMQ sockets
        sockets = dict(self.poller.poll(timeout))
        if not sockets:
            return ServerStatus.TIMEOUT

        if self.rank == 0:
            # 2. Look for connections from new simulations (just at rank 0)
            if (
                self.connection_responder in sockets
                and sockets[self.connection_responder] == zmq.POLLIN
            ):
                msg = self.connection_responder.recv()
                logger.debug(f"Rank {self.rank}: Handle client connection request")
                self.handle_simulation_connection(msg)

        # 3. Handle launcher message
        # this is a TCP/SCTP socket not a zmq one, so handled differently
        if self.launcherfd.fileno() in sockets:
            logger.debug(f"Rank {self.rank}: Handle launcher message")
            self.handle_fd()

        # 4. Handle simulation data message
        if self.data_puller in sockets and sockets[self.data_puller] == zmq.POLLIN:
            logger.debug(f"Rank {self.rank}: Handle simulation data")
            msg = self.data_puller.recv()
            self.total_bytes_recv += len(msg)
            # NOTE(review): returning here postpones any pending timer event
            # to the next call of run() — confirm this is intentional.
            return self.handle_simulation_data(msg)

        # 5. Handle timer message
        if self.timerfd_0.fileno() in sockets:
            logger.debug(f"Rank {self.rank}: Handle timer message")
            self.handle_timerfd()

445 

446 def handle_timerfd(self): 

447 """ 

448 Handles timer messages 

449 """ 

450 self.timerfd_0.recv(1) 

451 self.ft.check_time_out(self.groups) 

452 

453 [ 

454 self.relaunch_group(grp_id, new_grp) 

455 for grp_id, new_grp in self.ft.restart_grp.items() 

456 ] 

457 self.ft.restart_grp = {} 

458 

    def handle_fd(self):
        """
        Handles the launcher's messages through the filedescriptor

        Reads a chunk from the launcher socket, decodes it into messages,
        relaunches any group the launcher reports as failed, and (rank 0)
        answers each decoded message with a PING.
        """
        # NOTE(review): a single recv(256) may carry several length-prefixed
        # frames (the decoder handles that) but could also truncate a frame
        # longer than 256 bytes — confirm launcher messages always fit.
        bs = self.launcherfd.recv(256)
        rcvd_msg = self.decode_msg(bs)

        for msg in rcvd_msg:
            # 1. Launcher sent JOB_UPDATE message (msg.job_id <=> group_id)
            if isinstance(msg, message.JobUpdate):
                group = self.groups[msg.job_id]
                group_id = group.group_id

                # React to simulation status
                if msg.job_state in [job.State.ERROR, job.State.FAILED]:
                    logger.debug("Launcher indicates job failure")
                    new_grp = self.ft.handle_failed_group(group_id, group)
                    self.relaunch_group(group_id, new_grp)

            # 2. Server sends PING (one per decoded message)
            if self.rank == 0:
                logger.debug(
                    "Server got message from launcher and sends PING back"
                )
                snd_msg = self.encode_msg(message.Ping())
                self.launcherfd.send(snd_msg)

485 

486 def decode_msg(self, byte_stream): 

487 msg_list = [] 

488 if self.protocol == socket.IPPROTO_TCP: 

489 packets = LengthPrefixFramingDecoder( 

490 config.TCP_MESSAGE_PREFIX_LENGTH 

491 ).execute(byte_stream) 

492 for p in packets: 

493 msg_list.append(message.deserialize(p)) 

494 logger.debug(f"Decoded launcher messages {msg_list}") 

495 return msg_list 

496 elif self.protocol == socket.IPPROTO_SCTP: 

497 msg_list.append(message.deserialize(byte_stream)) 

498 logger.debug(f"Decoded launcher messages {msg_list}") 

499 return msg_list 

500 else: # impossible case of figure 

501 raise Exception("Unsupported protocol") 

502 

503 def encode_msg(self, msg): 

504 if self.protocol == socket.IPPROTO_TCP: 

505 encoded_packet = LengthPrefixFramingEncoder( 

506 config.TCP_MESSAGE_PREFIX_LENGTH 

507 ).execute(msg.serialize()) 

508 return encoded_packet 

509 elif self.protocol == socket.IPPROTO_SCTP: 

510 return msg.serialize() 

511 else: # impossible case of figure 

512 raise Exception("Unsupported protocol") 

513 

    def all_done(self) -> bool:
        """
        Checks whether all clients data were received

        Returns True once every submitted simulation has finished; on that
        transition the timer channel is also torn down.
        """
        if self.n_finished_simulations == self.n_submitted_simulations:
            # join thread and close timer sockets
            logger.info(f"Rank {self.rank}: closes timer sockets")
            self.timerfd_0.close()
            # NOTE(review): unregistering after close relies on the poller
            # keying on the socket object rather than its fileno — confirm.
            self.poller.unregister(self.timerfd_0)
            self.t_timer.join(timeout=1)
            if self.t_timer.is_alive():
                logger.warning("timer thread did not terminate")
            else:
                # only close the thread's end once the thread is gone
                self.timerfd_1.close()
            return True
        else:
            return False

531 

532 def close_connection(self, exit: int = 0): 

533 """ 

534 Signals to the launcher that the study has ended. 

535 """ 

536 snd_msg = self.encode_msg(message.Exit(exit)) 

537 # 1. Syn server processes 

538 self.comm.Barrier() 

539 

540 # 2. Send msg to the launcher 

541 if self.rank == 0: 

542 if exit > 0: 

543 logger.error("The server failed with an error") 

544 else: 

545 logger.info(f"Rank {self.rank}: turns the launcher off") 

546 self.launcherfd.send(snd_msg) 

547 return 

548 

549 def write_final_report(self) -> None: 

550 """ 

551 Write miscellaneous information about the analysis 

552 """ 

553 # Number of simulation 

554 if self.rank == 0: 

555 logger.info(f" - Number of simulations: {self.num_clients}") 

556 # Number of simulation processes 

557 if self.rank == 0: 

558 logger.info(f" - Number of simulation processes: {self.client_comm_size}") 

559 # Number of analysis processes 

560 if self.rank == 0: 

561 logger.info(f" - Number of server processes: {self.comm_size}") 

562 # Total time 

563 total_time = np.zeros(1, dtype=float) 

564 self.comm.Allreduce( 

565 [np.array([time.time() - self.t0], dtype=float), MPI.DOUBLE], 

566 [total_time, MPI.DOUBLE], 

567 op=MPI.SUM 

568 ) 

569 if self.rank == 0: 

570 logger.info(f" - Total time: {total_time[0] / self.comm_size} s") 

571 # Total MB received 

572 total_b = np.zeros(1, dtype=int) 

573 self.comm.Allreduce( 

574 [np.array([self.total_bytes_recv], dtype=int), MPI.INT], 

575 [total_b, MPI.INT], 

576 op=MPI.SUM 

577 ) 

578 if self.rank == 0: 

579 logger.info(f" - MB received: {total_b[0] / 1024**2} MB") 

580 

581 # abstract methods that children must override. 

582 

    @abstractmethod
    def check_group_size(self):
        """
        Checks if the group size was correctly set

        Implemented by concrete servers (SA vs DL have different rules).
        """

    @abstractmethod
    def handle_simulation_data(self, msg) -> Union[SimulationData, PartialSimulationData]:
        """
        Parses and validates the incoming data messages from simulations. Unique to melissa flavors.
        """

    @abstractmethod
    def check_simulation_data(
        self, simulation: Simulation, simulation_data: PartialSimulationData
    ) -> Tuple[
        SimulationDataStatus,
        Union[Optional[SimulationData], Optional[PartialSimulationData]],
    ]:
        """
        Look for duplicated messages,
        update received_simulation_data and the simulation_data status.
        """
        # default returned when a subclass delegates to super()
        return SimulationDataStatus.EMPTY, None

607 

608 def kill_and_restart_simulations(self): 

609 """ 

610 After a server crash and restart, cycle through 

611 loaded self.groups dictionary to find simulations which were unfinished 

612 kill them and relaunch them. 

613 """ 

614 for group in self.groups.keys(): 

615 for sim in self.groups[group].simulations.keys(): 

616 simulation = self.groups[group].simulations[sim] 

617 if simulation.n_received_time_steps < simulation.n_time_steps: 

618 params = self.groups[group].simulations[sim].parameters 

619 # generate the client script just in-case a client finished, removed 

620 # its own script, but before the server checkpointed itself. This 

621 # is a corner case. 

622 self.generate_client_script(sim, params) 

623 # relaunch the group 

624 self.launch_group(group) 

625 

    @abstractmethod
    def setup_environment(self):
        """
        Any necessary setup methods go here. DeepMelissa needs
        e.g. dist.init_process_group
        """
        return

    @abstractmethod
    def receive(self) -> None:
        """
        Handle data coming from the server object
        """

    @abstractmethod
    def process_simulation_data(cls, msg: SimulationData, config: dict):
        """
        method used to custom process data. See example in
        LorenzLearner

        NOTE(review): the first parameter is named ``cls`` but there is no
        @classmethod decorator — confirm whether this is meant to be a
        classmethod.
        """
        return

    @abstractmethod
    def start(self):
        """
        The high level organization of events.
        """

    def other_processes_finished(self, batch_number: int) -> bool:
        """
        Ensure distributed server processes are coordinated for
        _is_receiving

        The base implementation always returns False; distributed
        subclasses override it.
        """
        return False

    @abstractmethod
    def checkpoint_state(self):
        """
        Checkpoint the server object at the current state. This depends
        on the type of server currently running (SA vs DL)
        """
        return

668 

669 def save_base_state(self): 

670 """ 

671 Checkpoint all common objects in the server class 

672 """ 

673 # make a directory for the checkpoint files if one does not exist 

674 if not os.path.exists("checkpoints"): 

675 os.makedirs("checkpoints", exist_ok=True) 

676 

677 # save some state metadata to be reloaded later 

678 

679 metadata = {"number_of_groups": self.number_of_groups, 

680 "n_submitted_simulations": self.n_submitted_simulations, 

681 "n_finished_simulations": self.n_finished_simulations, 

682 "n_connected_simulations": self.n_connected_simulations, 

683 "groups": self.groups, 

684 "t0": self.t0, 

685 "total_bytes_recv": self.total_bytes_recv 

686 } 

687 with open(f'checkpoints/metadata_{self.rank}.pkl', 'wb') as f: 

688 cloudpickle.dump(metadata, f) 

689 

690 if self.rank == 0: 

691 metadata = {"MELISSA_RESTART": int(os.environ["MELISSA_RESTART"])} 

692 with open('checkpoints/restart_metadata.json', 'wb') as f: 

693 # write the metadata to json 

694 rapidjson.dump(metadata, f) 

695 

696 return 

697 

    @abstractmethod
    def restart_from_checkpoint(self):
        """
        Restart the server object from a checkpoint. This depends
        on the type of server currently running (SA vs DL)

        Counterpart of checkpoint_state().
        """
        return

705 

706 def load_base_state(self): 

707 """ 

708 Load all common objects in the server class 

709 """ 

710 # load the metadata 

711 with open(f'checkpoints/metadata_{self.rank}.pkl', 'rb') as f: 

712 metadata = cloudpickle.load(f) 

713 

714 self.number_of_groups = metadata["number_of_groups"] 

715 self.n_submitted_simulations = metadata["n_submitted_simulations"] 

716 self.n_finished_simulations = metadata["n_finished_simulations"] 

717 self.n_connected_simulations = metadata["n_connected_simulations"] 

718 self.groups = metadata["groups"] 

719 self.t0 = metadata["t0"] 

720 self.total_bytes_recv = metadata["total_bytes_recv"] 

721 

722 return 

723 

    @property
    def is_receiving(self):
        """Read-only view of _is_receiving (flipped by subclasses while pulling data)."""
        return self._is_receiving

727 

728 

class TimeMonitor:
    """This class implements a time monitor object to make sure that
    during the offline phase of the training the server keeps sending
    PINGs to the launcher at least every time_delay (units in sec)
    """

    def __init__(self, time: float, time_delay: float) -> None:
        """
        time: current clock value (the callers pass time.monotonic()).
        time_delay: maximum allowed interval between PINGs, in seconds.
        Note: the parameter name `time` shadows the `time` module; kept
        for backward compatibility with keyword callers.
        """
        self.last_ping = time
        self.time_delay = time_delay

    def check_clock(self, time: float, server: "BaseServer") -> None:
        """
        Run one server poll iteration (which triggers the launcher PING
        logic) when more than time_delay seconds elapsed since the last
        check; otherwise do nothing.
        """
        # (forward reference quoted; the original dead `else: pass` removed)
        if time - self.last_ping > self.time_delay:
            server.run(10)
            self.last_ping = time