Coverage for melissa/server/base_server.py: 47%

675 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2025-11-19 09:33 +0100

1"""This script defines the parent class for melissa server.""" 

2 

3from datetime import timedelta 

4import logging 

5import os 

6import socket 

7import threading 

8import datetime 

9import time 

10from abc import ABC, abstractmethod 

11from enum import Enum 

12from typing import ( 

13 Any, Dict, Optional, 

14 Tuple, Union, List, Type 

15) 

16 

17import psutil 

18import numpy as np 

19import zmq 

20from mpi4py import MPI 

21import cloudpickle 

22import rapidjson 

23try: 

24 import debugpy 

25except ModuleNotFoundError: 

26 pass 

27 

28 

29from iterative_stats.iterative_moments import IterativeMoments 

30from melissa.launcher import config, message 

31from melissa.scheduler import job 

32from melissa.server.fault_tolerance import FaultTolerance 

33from melissa.server.parameters import ( 

34 BaseExperiment, ParameterSamplerType, 

35 ParameterSamplerClass, make_parameter_sampler 

36) 

37from melissa.utility.message import Message 

38from melissa.server.message import ConnectionRequest, ConnectionResponse 

39from melissa.server.simulation import ( 

40 Group, PartialSimulationData, 

41 Simulation, SimulationData, 

42 SimulationDataStatus 

43) 

44from melissa.utility.networking import ( 

45 LengthPrefixFramingDecoder, 

46 LengthPrefixFramingEncoder, 

47 connect_to_launcher, 

48 select_protocol, 

49 is_port_in_use, 

50 is_launcher_socket_alive_and_ready 

51) 

52from melissa.server.exceptions import ( 

53 FatalError, 

54 InitialConnectionError, 

55 UnsupportedProtocol, 

56 ConfigurationFileError, 

57 FaultToleranceError 

58) 

59from melissa.utility.rank_helper import ClusterEnvironment, initialize_sampling_rank 

60from melissa.utility.timer import Timer 

61from melissa.utility.client_scripts import get_client_script_path 

62from melissa.utility.logger import configure_logger, get_log_level_from_verbosity 

63 

64 

65logger = logging.getLogger(__name__) 

66 

67 

class ServerStatus(Enum):
    """Enumerates the states the server can report: checkpointing or timed out."""

    CHECKPOINT = 1
    TIMEOUT = 2

72 

73 

def bytes_to_readable(total_bytes: int) -> str:
    """Returns human-readable byte representation.

    Scales the raw byte count through KB/MB/GB thresholds; plain byte
    counts are printed as integers, scaled values with three decimals."""
    scaled = float(total_bytes)
    unit = "bytes"
    for threshold, label in ((1024, "KB"), (1024 ** 2, "MB"), (1024 ** 3, "GB")):
        if total_bytes >= threshold:
            scaled /= 1024
            unit = label
    if unit == "bytes":
        return f"{int(scaled)} {unit}"
    return f"{scaled:.3f} {unit}"

90 

91 

92class BaseServer(ABC): 

93 """`BaseServer` class that handles the following tasks: 

94 

95 - Manages connections with the launcher and clients. 

96 - Generates client scripts for simulations. 

97 - Encodes and decodes messages between the server and clients. 

98 - Provides basic checkpointing functionality to save and restore states. 

99 

100 ### Parameters 

101 - **config_dict** (`Dict[str, Any]`): A dictionary containing configuration settings for 

102 initializing the server. 

103 - **checkpoint_file** (`str`, optional): The filename for the checkpoint file 

104 (default is `"checkpoint.pkl"`). This file is used for saving and restoring the server's state. 

105 

106 ### Attributes 

107 - **comm** (`MPI.Intracomm`): The MPI communicator for inter-process communication. 

108 - **rank** (`int`): The rank of the current process in the MPI communicator. 

109 - **comm_size** (`int`): The total number of server processes in the MPI communicator. 

110 - **client_comm_size** (`int`): The total number of client processes. 

111 - **server_processes** (`int`): Synonym for `comm_size`. 

112 - **connection_port** (`int`): The server port to establish request-response connection with 

113 the clients. 

114 - **data_puller_port** (`int`): The server port to establish data pulling with the clients. 

115 

116 - **_offline_mode** (`bool`): Internal flag indicating offline mode where no sending operation 

117 takes place. Useful when running multiple clients to produce datasets. 

118 Call `self.make_study_offline` to enable. 

119 - **_learning** (`int`): Internal flag indicating the learning state (initially 0). 

120 - **__t0** (`float`): The timestamp marking the initialization time of the object. 

121 - **_job_limit** (`int`): Maximum number of jobs the launcher can manage concurrently. 

122 - **__is_direct_scheduler** (`bool`): Flag indicating whether the study is using a direct 

123 scheduler. 

124 

125 - **_restart** (`int`): Flag indicating if the system is in a restart state; initialized 

126 from the `MELISSA_RESTART` environment variable. 

127 - **_consistency_lock** (`threading.RLock`): Reentrant lock to ensure thread-safe operations 

128 on shared resources. 

129 - **_is_receiving** (`bool`): Flag indicating whether data reception is ongoing. 

130 - **_is_online** (`bool`): Flag indicating if the system is in an online operational mode. 

131 - **_sobol_op** (`bool`): Flag indicating whether Sobol operations are being performed. 

132 - **_total_bytes_recv** (`int`): Tracks the total number of bytes received over the network. 

133 - **_active_sim_ids** (`set`): Set of active simulation ids currently being managed. 

134 

135 - **_groups** (`Dict[int, Group]`): Dictionary mapping group ids to `Group` objects. 

136 - **_parameter_sampler** (`Optional[BaseExperiment]`): Sampler for generating parameter values 

137 for simulations. 

138 - **__parameter_generator** (`Any`): Internal generator object for producing parameters. 

139 

140 - **verbose_level** (`int`): Determines the verbosity level for logging and debugging output. 

141 - **config_dict** (`Dict[str, Any]`): Configuration dictionary provided during initialization. 

142 - **checkpoint_file** (`str`): File name used for storing checkpoint data. 

143 

144 - **crashes_before_redraw** (`int`): Number of simulation crashes allowed before 

145 redrawing parameters. 

146 - **max_delay** (`Union[int, float]`): Maximum allowed delay for simulations, in seconds. 

147 - **rm_script** (`bool`): Indicates whether client scripts should be removed after execution. 

148 - **group_size** (`int`): Number of simulations grouped together for batch processing. 

149 - **zmq_hwm** (`int`): High-water mark for ZeroMQ communication. 

150 

151 - **fields** (`List[str]`): List of field names used in the study. 

152 - **nb_parameters** (`int`): Number of parameters in the parameter sweep study. 

153 - **nb_time_steps** (`int`): Number of time steps in each simulation. 

154 - **nb_clients** (`int`): Total number of clients participating in the parameter sweep study. 

155 

156 - **nb_groups** (`int`): Total number of groups, derived from the number of clients 

157 and group size. 

158 - **nb_submitted_groups** (`int`): Tracks the number of groups submitted so far. 

159 - **finished_groups** (`set`): Tracks the finished set of groups. 

160 - **mtt_simulation_completion** (`float`): Iteratively keeps track of mean of simulation 

161 durations. 

162 

163 - **no_fault_tolerance** (`bool`): Indicates whether fault tolerance is disabled, 

164 based on the `MELISSA_FAULT_TOLERANCE` environment variable. 

165 - **__ft** (`FaultTolerance`): Fault tolerance object managing simulation 

166 crashes and retries.""" 

    def __init__(self,
                 config_dict: Dict[str, Any],
                 checkpoint_file: str = "checkpoint.pkl") -> None:
        """Initializes the MPI context, study options, fault tolerance and
        checkpoint directories from the given configuration.

        ### Parameters
        - **config_dict** (`Dict[str, Any]`): Parsed configuration of the study.
        - **checkpoint_file** (`str`, optional): File name used when saving and
          restoring the server state (default `"checkpoint.pkl"`).

        ### Raises
        - `ConfigurationFileError`: If a mandatory study option is missing, or
          if client deaths are ignored while fault-tolerance is enabled."""

        # MPI initialization
        cluster = ClusterEnvironment()
        self.comm: MPI.Comm = cluster.comm_world
        self.rank: int = cluster.comm_world_rank
        self.comm_size: int = cluster.comm_world_size
        self.server_processes: int = self.comm_size
        self.client_comm_size: int = 0
        # default ports; real ones are negotiated later in __initialize_ports
        self.__connection_port: int = 2003
        self.__data_puller_port: int = 5000
        self.__connected_with_launcher: bool = False

        self._offline_mode: bool = False
        self._learning: int = 0
        self._bind_simulation_to_server_rank: int = 0
        self.__t0: float = time.time()
        # one job slot is kept back from the launcher's configured limit
        self._job_limit: int = config_dict["launcher_config"]["job_limit"] - 1
        self.__is_direct_scheduler: bool = config_dict["launcher_config"]["scheduler"] == "openmpi"

        self._restart: int = int(os.environ["MELISSA_RESTART"])
        self._consistency_lock: threading.RLock = threading.RLock()
        self._is_receiving: bool = False
        self._is_online: bool = False
        self._sobol_op: bool = False
        self._total_bytes_recv: int = 0
        self._active_sim_ids: set = set()
        self._groups: Dict[int, Group] = {}
        self._parameter_sampler: Optional[BaseExperiment] = None

        self.config_dict: Dict[str, Any] = config_dict
        self.checkpoint_file: str = checkpoint_file
        study_options: Dict[str, Any] = self.config_dict["study_options"]
        self.verbose_level: int = study_options["verbosity"]

        # Scan study options dictionary
        self.crashes_before_redraw: int = study_options.get("crashes_before_redraw", 1)
        self.max_delay: Union[int, float] = study_options.get("simulation_timeout", 60)
        self.rm_script: bool = study_options.get("remove_client_scripts", False)
        self.group_size: int = 1
        self.zmq_hwm: int = study_options.get("zmq_hwm", 0)
        try:
            self.fields: List[str] = study_options["field_names"]
            self.nb_parameters: int = study_options["nb_parameters"]
            self.nb_clients: int = study_options["parameter_sweep_size"]
        except KeyError as e:
            logger.error(f"[INCORRECT] Key not found in the configuration: {e}")
            raise ConfigurationFileError

        # negative user input is clamped to 0, meaning "unknown"
        self.nb_time_steps: int = max(study_options.get("nb_time_steps", 0), 0)
        if self.nb_time_steps == 0:
            logger.warning(f"Rank {self.rank}>> Number of timesteps to be received not provided.")

        self.nb_groups: int = self.nb_clients // self.group_size
        self.nb_submitted_groups: int = 0
        self.finished_groups: set = set()
        # iterative first-order moment of simulation completion times
        self.mtt_simulation_completion: IterativeMoments = IterativeMoments(max_order=1, dim=1)
        # never allow more concurrent jobs than there are groups
        self._job_limit = min(self._job_limit, self.nb_groups)

        # Fault-Tolerance initialization
        self.no_fault_tolerance: bool = os.environ["MELISSA_FAULT_TOLERANCE"] == "OFF"

        self.ignore_client_death: bool = config_dict.get("ignore_client_death", False)
        if self.ignore_client_death and self.no_fault_tolerance is False:
            raise ConfigurationFileError("Client deaths cannot be ignored if Fault-tolerance is ON")

        logger.info("fault-tolerance " + ("OFF" if self.no_fault_tolerance else "ON"))
        self.__ft: FaultTolerance = FaultTolerance(
            self.no_fault_tolerance,
            self.max_delay,
            self.crashes_before_redraw,
            self.nb_groups,
        )

        # this lets us ping the launcher periodically
        # ensuring the launcher does not assume the server is dead.
        self.__ping_interval: int = min(
            10,
            config_dict["launcher_config"].get("server_timeout", 10)
        )
        self.__run_handler_thread: bool = False
        self.__pinger_thread: threading.Thread = threading.Thread(
            name="pinger", target=self.__loop_pings
        )

        if self.rank == 0:
            # make a directory for the checkpoint files if one does not exist
            os.makedirs("checkpoints", exist_ok=True)
            metadata = {"MELISSA_RESTART": int(os.environ["MELISSA_RESTART"])}
            # NOTE(review): the file is opened in binary mode ('wb') while
            # rapidjson.dump typically writes str — confirm this combination
            # works with the pinned python-rapidjson version
            with open("checkpoints/restart_metadata.json", 'wb') as f:
                # write the metadata to json
                rapidjson.dump(metadata, f)
        # all ranks wait until rank 0 has created the checkpoint directory
        self.comm.Barrier()
        self.ckpt_metadata_path: str = f"checkpoints/{self.rank}/metadata.pkl"
        os.makedirs(f"checkpoints/{self.rank}", exist_ok=True)

        # a runtime decision on which rank becomes the sampling rank
        initialize_sampling_rank()

267 

268 @property 

269 def offline_mode(self) -> bool: 

270 return self._offline_mode 

271 

272 def make_study_offline(self) -> None: 

273 self._offline_mode = True 

274 logger.warning( 

275 f"Rank {self.rank}>> Currently running with offline mode. " 

276 "No reception will take place." 

277 ) 

278 

279 @property 

280 def time_steps_known(self) -> bool: 

281 """Time steps are known prior study or not.""" 

282 return self.nb_time_steps > 0 

283 

284 @property 

285 def nb_finished_groups(self) -> int: 

286 return len(self.finished_groups) 

287 

288 @property 

289 def is_direct_scheduler(self) -> bool: 

290 """Study is using a direct scheduler or not.""" 

291 return self.__is_direct_scheduler 

292 

293 @property 

294 def learning(self) -> int: 

295 """Deep learning activated? 

296 Required when establishing a connection with clients.""" 

297 return self._learning 

298 

299 @property 

300 def consistency_lock(self) -> threading.RLock: 

301 """Useful for active sampling.""" 

302 return self._consistency_lock 

303 

304 @property 

305 def is_receiving(self) -> bool: 

306 return self._is_receiving 

307 

308 @is_receiving.setter 

309 def is_receiving(self, value: bool): 

310 self._is_receiving = value 

311 

312 @property 

313 def is_online(self) -> bool: 

314 return self._is_online 

315 

316 @is_online.setter 

317 def is_online(self, value: bool): 

318 self._is_online = value 

319 

320 @property 

321 def sobol_op(self) -> bool: 

322 return self._sobol_op 

323 

324 @sobol_op.setter 

325 def sobol_op(self, value: bool): 

326 self._sobol_op = value 

327 

328 @property 

329 def parameter_sampler(self) -> Optional[BaseExperiment]: 

330 return self._parameter_sampler 

331 

332 @parameter_sampler.setter 

333 def parameter_sampler(self, value: Optional[BaseExperiment]): 

334 self._parameter_sampler = value 

335 

336 def __loop_pings(self) -> None: 

337 """Maintains communication with the launcher to ensure it 

338 does not assume the server has become unresponsive.""" 

339 

340 while self.__run_handler_thread: 

341 self._launcherfd.send(self._encode_msg(message.Ping())) 

342 logger.debug(f"Rank {self.rank}>> pinging launcher.") 

343 time.sleep(self.__ping_interval) 

344 

345 def _start_pinger_thread(self) -> None: 

346 """Starts the pinger thread and set the flag.""" 

347 

348 if self.rank == 0: 

349 assert threading.current_thread() != self.__pinger_thread 

350 self.__run_handler_thread = True 

351 if not self.__pinger_thread.is_alive(): 

352 self.__pinger_thread.start() 

353 

354 def _stop_pinger_thread(self) -> None: 

355 """Stops the pinger thread and unsets the flag.""" 

356 

357 if self.rank == 0: 

358 assert threading.current_thread() != self.__pinger_thread 

359 self.__run_handler_thread = False 

360 if self.__pinger_thread.is_alive(): 

361 self.__pinger_thread.join(timeout=1.0) 

362 

    def _save_base_state(self) -> None:
        """Checkpoints all common attributes in the server class to preserve the current state."""

        # checkpointing is skipped entirely when fault tolerance is disabled
        if self.no_fault_tolerance:
            return

        # make sure every rank checkpoints the same logical point in time
        self.comm.Barrier()

        # save some state metadata to be reloaded later
        # hold the lock so concurrent threads cannot mutate the bookkeeping
        # while it is being serialized
        with self.consistency_lock:
            metadata = {
                "nb_groups": self.nb_groups,
                "nb_submitted_groups": self.nb_submitted_groups,
                "finished_groups": self.finished_groups,
                "groups": self._groups,
                "t0": self.__t0,
                "total_bytes_recv": self._total_bytes_recv
            }

            with open(self.ckpt_metadata_path, 'wb') as f:
                cloudpickle.dump(metadata, f)

            # the sampler keeps its own checkpoint file
            if self.parameter_sampler is not None:
                self.parameter_sampler.checkpoint_state()

387 

388 def _load_base_state(self) -> None: 

389 """Loads all common attributes in the server class from a checkpoint or saved state.""" 

390 

391 if self.no_fault_tolerance: 

392 return 

393 

394 try: 

395 # load the metadata 

396 with open(self.ckpt_metadata_path, 'rb') as f: 

397 metadata = cloudpickle.load(f) 

398 except FileNotFoundError as e: 

399 raise FatalError( 

400 f"Fault-tolerance must be set for checkpointing.\n{e}" 

401 ) 

402 

403 self.nb_groups = metadata["nb_groups"] 

404 self.nb_submitted_groups = metadata["nb_submitted_groups"] 

405 self.finished_groups = metadata["finished_groups"] 

406 self._groups = metadata["groups"] 

407 self.__t0 = metadata["t0"] 

408 self._total_bytes_recv = metadata["total_bytes_recv"] 

409 

410 if self.parameter_sampler is not None: 

411 self.parameter_sampler.restart_from_checkpoint() 

412 

    def __initialize_ports(self,
                           connection_port: int = 2003,
                           data_puller_port: int = 5000) -> None:
        """Assigns port numbers for connection and data pulling as class attributes.
        If the specified ports are already in use, likely due to multiple servers running
        on the same node, the function attempts to find available ports by incrementing the
        base port values and rechecking their availability.

        _Note: When multiple independent `melissa-server` jobs are running simultaneously
        on the same node, there is a chance that a port may incorrectly appear as available,
        leading to potential conflicts._

        ### Parameters
        - **connection_port** (`int`, optional): The port number used for establishing the main
        connection (default is `2003`).
        - **data_puller_port** (`int`, optional): The port number used for pulling data
        (default is `5000`).

        ### Raises
        - `FatalError`: If no ports were found after given number of attempts."""

        # Ports initialization
        logger.info(f"Rank {self.rank}>> Initializing server...")
        self.node_name = socket.gethostname()

        # the retry budget also sizes each rank's candidate port range below
        attempts = max(10, self.comm_size)
        if self.rank == 0:
            # only rank 0 accepts client connection requests
            self.__connection_port = connection_port
            i = 0
            while is_port_in_use(self.__connection_port) and i < attempts:
                logger.warning(
                    f"Rank {self.rank}>> Connection port {self.__connection_port} in use. "
                    "Trying another..."
                )
                self.__connection_port += 1
                i += 1

            if i == attempts:
                logger.error(
                    f"{self.rank}>> Could not find an available connection port after "
                    f"{attempts} attempts."
                )
                raise InitialConnectionError

        # Set data puller port
        # each rank probes its own disjoint range of (attempts + 1) ports
        self.__data_puller_port = data_puller_port + (self.rank * (attempts + 1))
        i = 0
        while is_port_in_use(self.__data_puller_port) and i < attempts:
            logger.warning(f"Rank {self.rank}>> Data puller port {self.__data_puller_port} in use. "
                           "Trying another...")
            self.__data_puller_port += 1
            i += 1

        if i == attempts:
            logger.error(
                f"{self.rank}>> Could not find an available data puller port after "
                f"{attempts} attempts."
            )
            raise InitialConnectionError

        # make every rank's puller endpoint known to all ranks
        self.__data_puller_port_name = f"tcp://{self.node_name}:{self.__data_puller_port}"
        self._port_names = self.comm.allgather(self.__data_puller_port_name)
        logger.debug(f"port_names {self._port_names}")

476 

    def __connect_to_launcher(self) -> None:
        """Establishes a connection with the launcher and sends metadata about the study."""

        self._launcherfd: socket.socket
        # Setup communication instances
        self.__protocol, prot_name = select_protocol()
        logger.info(f"Server/launcher communication protocol: {prot_name}")
        if self.rank == 0:
            # rank 0 connects first so the launcher learns the comm and
            # group sizes before any other rank shows up
            self._launcherfd = connect_to_launcher()
            self._launcherfd.send(self._encode_msg(message.CommSize(self.comm_size)))
            logger.debug(f"Rank {self.rank}>> Comm size {self.comm_size} sent to launcher")
            self._launcherfd.send(self._encode_msg(message.GroupSize(self.group_size)))
            logger.debug(f"Rank {self.rank}>> Group size {self.group_size} sent to launcher")
        # synchronize non-zero ranks after rank 0 connection to make sure
        # these ranks only connect after comm_size is known to the launcher
        self.comm.Barrier()

        if self.rank > 0:
            # avoiding simultaneous connection that could cause
            # race conditions for FDs on supercomputers
            time.sleep(0.001 * self.rank)
            self._launcherfd = connect_to_launcher()
        # for i in range(1, self.comm_size):
        #     if self.rank == i:
        #         self._launcherfd = connect_to_launcher()
        #     self.comm.Barrier()

        # every rank must report a live socket, otherwise startup is aborted
        all_fds_ready = self.comm.allreduce(
            is_launcher_socket_alive_and_ready(self._launcherfd),
            op=MPI.LAND
        )
        if not all_fds_ready:
            raise InitialConnectionError(
                "Some ranks failed to connect with the launcher."
            )

        logger.debug(f"Rank {self.rank}>> Launcher fd set up: {self._launcherfd.fileno()}")
        self.__connected_with_launcher = True

515 

516 def __setup_sockets(self) -> None: 

517 """Sets up ZeroMQ (ZMQ) sockets over a given TCP connection port for communication.""" 

518 

519 self.__zmq_context = zmq.Context() 

520 # Simulations (REQ) <-> Server (REP) 

521 self._connection_responder = self.__zmq_context.socket(zmq.REP) 

522 if self.rank == 0: 

523 addr1 = f"tcp://*:{self.__connection_port}" 

524 try: 

525 self._connection_responder.bind(addr1) 

526 except InitialConnectionError as e: 

527 raise e 

528 logger.info( 

529 f"Rank {self.rank}>> Binding to {addr1} successful." 

530 ) 

531 

532 # Simulations (PUSH) -> Server (PULL) 

533 self.__data_puller = self.__zmq_context.socket(zmq.PULL) 

534 self.__data_puller.setsockopt(zmq.RCVHWM, self.zmq_hwm) 

535 self.__data_puller.setsockopt(zmq.RCVBUF, 4 * 1024 ** 2) 

536 self.__data_puller.setsockopt(zmq.LINGER, -1) 

537 addr2 = f"tcp://*:{self.__data_puller_port}" 

538 try: 

539 self.__data_puller.bind(addr2) 

540 except InitialConnectionError as e: 

541 raise e 

542 logger.info( 

543 f"Rank {self.rank}>> Data puller to {addr2} successful." 

544 ) 

545 

546 # Time-out checker (creates thread) 

547 self.__timerfd_0, self.__timerfd_1 = socket.socketpair( 

548 socket.AF_UNIX, socket.SOCK_STREAM 

549 ) 

550 timer = Timer(self.__timerfd_1, timedelta(seconds=self.max_delay)) 

551 self.__t_timer = threading.Thread(target=timer.run, daemon=True) 

552 self.__t_timer.start() 

553 

554 def __setup_poller(self) -> None: 

555 """This method sets up the polling mechanism by registering three important sockets: 

556 - **Data Socket**: Handles data communication. 

557 - **Timer Socket**: Manages timing events. 

558 - **Launcher Socket**: Facilitates communication with the launcher.""" 

559 

560 self.__zmq_poller = zmq.Poller() 

561 self.__zmq_poller.register(self.__data_puller, zmq.POLLIN) 

562 self.__zmq_poller.register(self.__timerfd_0, zmq.POLLIN) 

563 self.__zmq_poller.register(self._launcherfd, zmq.POLLIN) 

564 if self.rank == 0: 

565 self.__zmq_poller.register(self._connection_responder, zmq.POLLIN) 

566 

567 def __start_debugger(self) -> None: 

568 """Launches the Visual Studio Code (VSCode) debugger for debugging purposes.""" 

569 

570 # 5678 is the default attach port that we recommend users to set 

571 # in the documentation. 

572 debugpy.listen(5678) 

573 logger.warning("Waiting for debugger attach, please start the " 

574 "debugger by navigating to debugger pane ctrl+shift+d " 

575 "and selecting\n" 

576 "Python: Remote Attach") 

577 debugpy.wait_for_client() 

578 logger.info("Debugger successfully attached.") 

579 # send message to launcher to ensure debugger doesnt timeout 

580 snd_msg = self._encode_msg(message.StopTimeoutMonitoring()) 

581 self._launcherfd.send(snd_msg) 

582 

583 def configure_logger(self) -> None: 

584 """Configures server loggers for each MPI rank.""" 

585 log_level = get_log_level_from_verbosity(self.verbose_level) 

586 app_str = f"_restart_{self._restart}" if self._restart else "" 

587 configure_logger(f"melissa_server_{self.rank}{app_str}.log", log_level) 

588 

    def initialize_connections(self) -> None:
        """Initializes socket connections for communication."""

        # order matters: ports must be chosen before the launcher handshake,
        # and the poller can only be set up once all sockets exist
        self.__initialize_ports()
        self.__connect_to_launcher()
        self.__setup_sockets()
        self.__setup_poller()

        # optionally block until a VSCode debugger attaches
        if self.config_dict.get("vscode_debugging", False):
            self.__start_debugger()

599 

600 def _get_group_id_by_simulation(self, sim_id: int) -> int: 

601 """Returns group id of the given simulation id.""" 

602 return sim_id // self.group_size 

603 

604 def _get_sim_id_list_by_group(self, group_id: int) -> List[int]: 

605 """Returns a list of all simulation ids for a given group id.""" 

606 first_id = group_id * self.group_size 

607 last_id = first_id + self.group_size 

608 return list(range(first_id, last_id)) 

609 

610 def _get_all_sim_ids(self): 

611 """Yields all simulation ids across all groups.""" 

612 for group_id in self._groups.keys(): 

613 yield from self._get_sim_id_list_by_group(group_id) 

614 

615 def _verify_and_update_sampler_kwargs(self, sampler_t, **kwargs) -> Dict[str, Any]: 

616 """Updates the parameters that were not provided by the user when 

617 creating a sampler using `set_parameter_sampler` method. It also ensures whether 

618 the seed is given or not for a parallel server.""" 

619 

620 # if not provided by the users 

621 if "nb_params" not in kwargs: 

622 kwargs["nb_params"] = self.nb_parameters 

623 if "nb_sims" not in kwargs: 

624 kwargs["nb_sims"] = self.nb_clients 

625 if "seed" not in kwargs: 

626 if hasattr(self, "seed"): 

627 kwargs["seed"] = self.seed 

628 

629 return kwargs 

630 

631 def set_parameter_sampler(self, 

632 sampler_t: Union[ParameterSamplerType, Type[ParameterSamplerClass]], 

633 **kwargs) -> None: 

634 """Sets the defined parameter sampler type. This dictates how parameters are sampled 

635 for experiments. This sampler type can either be pre-defined or customized 

636 by inheriting a pre-defined sampling class. 

637 

638 ### Parameters 

639 - **sampler_t** (`Union[ParameterSamplerType, Type[ParameterSamplerClass]]`): 

640 - `ParameterSamplerType`: Enum specifying pre-defined samplers. 

641 - `Type[ParameterSamplerClass]`: A class type to instantiate. 

642 - **kwargs** (`Dict[str, Any]`): Dictionary of keyword arguments. 

643 Useful to pass custom parameter as well as strict parameter such as 

644 `l_bounds`, `u_bounds`, `apply_pick_freeze`, `second_order`, `seed=0`, etc.""" 

645 kwargs = self._verify_and_update_sampler_kwargs(sampler_t, **kwargs) 

646 self._parameter_sampler = make_parameter_sampler(sampler_t, **kwargs) 

647 

648 def _update_parameter_sampler(self) -> None: 

649 """Updates the existing parameter sampler.""" 

650 

651 if self._parameter_sampler: 

652 self._parameter_sampler.flush_to_disk() 

653 

654 def _launch_groups(self, group_ids: List[int]) -> None: 

655 """Launches the study groups for the very first run. 

656 This process involves generating the client scripts and ensures that 

657 no restart has occurred in the case of fault tolerance. 

658 

659 ### Parameters 

660 - **group_ids** (`List[int]`): A list of group identifiers to launch. 

661 """ 

662 

663 if self.nb_submitted_groups >= self.nb_groups: 

664 return 

665 

666 # Get current working directory containing the client script template 

667 client_script = os.path.abspath("client.sh") 

668 

669 if not os.path.isfile(client_script): 

670 raise FileNotFoundError("error client script not found") 

671 

672 # Generate all client scripts 

673 self._generate_client_scripts(group_ids) 

674 

675 # Launch every group 

676 # note that the launcher message stacking feature does not work 

677 for group_id in group_ids: 

678 self._launch_group(group_id) 

679 

680 def _get_client_script_path(self, sim_id: int) -> str: 

681 return get_client_script_path(sim_id) 

682 

    def _generate_client_scripts(self,
                                 group_ids: List[int],
                                 create_new_group: bool = False) -> None:
        """Creates all required client scripts (e.g., `client.X.sh`),
        and sets up a dictionary for fault tolerance.

        ### Parameters
        - **group_ids** (`List[int]`): A list of group identifiers.
        - **create_new_group** (`bool`, optional): Flag indicating whether to
        create a new group of clients (default is `False`)."""

        for group_id in group_ids:
            for sim_id in self._get_sim_id_list_by_group(group_id):
                # drawing parameters requires the sampler to be set beforehand
                assert self._parameter_sampler is not None
                parameters = list(self._parameter_sampler.draw(sim_id))

                script_path = self._get_client_script_path(sim_id)
                fname = os.path.basename(script_path)
                if self.rank == 0:
                    # only rank 0 touches the filesystem
                    self.__generate_client_script(sim_id, parameters, script_path)
                    # ping the launcher periodically so a long generation run
                    # is not mistaken for a dead server
                    if sim_id > 0 and sim_id % 1000 == 0:
                        snd_msg = self._encode_msg(message.Ping())
                        self._launcherfd.send(snd_msg)

                logger.info(
                    f"Rank {self.rank}>> Created {fname} with parameters {parameters}"
                )

                # Fault-tolerance dictionary creation and update
                # create_new_group is specifically for active sampling
                if create_new_group or group_id not in self._groups:
                    self._groups[group_id] = Group(group_id, self.sobol_op)
                if sim_id not in self._groups[group_id].simulations:
                    self._groups[group_id].simulations[sim_id] = Simulation(
                        sim_id,
                        script_path,
                        self.nb_time_steps,
                        self.fields,
                        parameters
                    )
                # reset any stale message kept from a previous run
                self._groups[group_id].simulations[sim_id].last_message = None

724 

725 def __generate_client_script(self, 

726 sim_id: int, 

727 parameters: List[Any], 

728 script_path: str) -> None: 

729 """Generates a single client script for a given simulation id and parameters. 

730 

731 ### Parameters 

732 - **sim_id** (`int`): The simulation id associated with the client script. 

733 - **parameters** (`list`): The list of parameters. 

734 - **script_path** (`str`): The absolute path of the client script to create. 

735 """ 

736 

737 if self.rank == 0: 

738 with open(script_path, "w") as f: 

739 print("#!/bin/sh", file=f) 

740 self._write_environment_variables(f, sim_id) 

741 self._write_execution_command(f, parameters) 

742 

743 os.chmod(script_path, 0o744) 

744 

    def _write_environment_variables(self, f: Any, sim_id: int) -> None:
        """Writes environment variables to the client script.

        ### Parameters
        - **f** (`Any`): The file object to write to.
        - **sim_id** (`int`): The simulation id associated with the client script.
        """
        # every line ends with a backslash: the env invocation is continued
        # by the command line written in _write_execution_command
        print("exec env \\", file=f)
        print(f" MELISSA_VERBOSE={self.verbose_level} \\", file=f)
        print(f" MELISSA_SIMU_ID={sim_id} \\", file=f)
        print(f" MELISSA_SERVER_NODE_NAME={self.node_name} \\", file=f)
        print(f" MELISSA_SERVER_PORT={self.__connection_port} \\", file=f)

757 

758 def _write_execution_command(self, f: Any, parameters: List[Any]) -> None: 

759 """Writes the execution command to the client script. 

760 

761 ### Parameters 

762 - **f** (`Any`): The file object to write to. 

763 - **parameters** (`list`): The list of parameters. 

764 """ 

765 if self.rm_script: 

766 print( 

767 " " 

768 + " ".join( 

769 [os.path.join(os.getcwd(), "client.sh")] 

770 + [ 

771 np.format_float_positional(x) if not isinstance(x, str) 

772 else x for x in parameters 

773 ] 

774 ) 

775 + " &", 

776 file=f, 

777 ) 

778 print(" wait", file=f) 

779 print(' rm "$0"', file=f) 

780 else: 

781 print( 

782 " " 

783 + " ".join( 

784 [os.path.join(os.getcwd(), "client.sh")] 

785 + [ 

786 np.format_float_positional(x) if not isinstance(x, str) 

787 else x for x in parameters 

788 ] 

789 ), 

790 file=f, 

791 ) 

792 

    def _launch_group(self, group_id: int) -> None:
        """Submits a request to the launcher to run a given group id.
        For non-Sobol studies, the group id and simulation id are the same.

        ### Parameters
        - **group_id** (`int`): The unique identifier of the group to be launched.
        """
        if self.rank == 0:
            # Job submission message to launcher (initial_id,num_jobs)
            # one submission per simulation of the group
            for sim_id in self._get_sim_id_list_by_group(group_id):
                snd_msg = self._encode_msg(message.JobSubmission(sim_id, 1))
                self._launcherfd.send(snd_msg)
            # snd_msg = self._encode_msg(message.JobSubmission(group_id, 1))
            # self._launcherfd.send(snd_msg)
            logger.debug(
                f"Rank {self.rank}>> group "
                f"{group_id + 1}/{self.nb_groups} "
                "submitted to launcher"
            )

        # bookkeeping happens on every rank so the counters stay consistent
        self._groups[group_id].submitted = True
        for _, simulation in self._groups[group_id].simulations.items():
            simulation.connected = False

        self.nb_submitted_groups += 1

818 

819 def _kill_group(self, group_id: int) -> None: 

820 """Submits a request to the launcher to terminate a given group id. 

821 

822 ### Parameters 

823 - **group_id** (`int`): The unique identifier of the group to be terminated.""" 

824 

825 group = self._groups[group_id] 

826 if self.rank == 0: 

827 logger.warning( 

828 f"[RESTART] Resubmitting incomplete group-id={group_id} to the launcher." 

829 ) 

830 snd_msg = self._encode_msg(message.JobCancellation(group_id)) 

831 self._launcherfd.send(snd_msg) 

832 

833 group.submitted = False 

834 for sim_id in group.simulations: 

835 group.simulations[sim_id].connected = False 

836 self.nb_submitted_groups -= 1 

837 

838 def _relaunch_group(self, group_id: int, create_new_group: bool) -> None: 

839 """Relaunches a failed group with or without new parameters, 

840 depending on the fault tolerance configuration. 

841 

842 ### Parameters 

843 - **group_id** (`int`): The unique identifier of the group to be relaunched. 

844 - **create_new_group** (`bool`): A flag indicating whether to create a new group 

845 with new parameters.""" 

846 

847 self._generate_client_scripts([group_id], create_new_group) 

848 assert not self._groups[group_id].has_finished() 

849 self._kill_group(group_id) 

850 self._launch_group(group_id) 

851 

852 def _handle_simulation_connection(self, msg: bytes) -> int: 

853 """Handles an incoming connection request from a submitted simulation. 

854 This method is executed by rank 0 only. 

855 

856 ### Parameters 

857 - **msg** (`bytes`): The message received from the simulation requesting a connection. 

858 

859 ### Returns 

860 - `int`: The simulation id of the connected simulation, or `-1` if the connection 

861 could not be established.""" 

862 

863 request = ConnectionRequest.recv(msg) 

864 self.client_comm_size = request.comm_size 

865 sim_id = request.simulation_id 

866 

867 # a corner case which may not happen 

868 # at this point, it is expected that 

869 # the group/simulation is already running 

870 group_id = self._get_group_id_by_simulation(sim_id) 

871 if group_id not in self._groups or not self._groups[group_id].submitted: 

872 logger.warning(f"Rank {self.rank}>> group-id={group_id} does not exist.") 

873 return -1 

874 

875 logger.debug( 

876 f"Rank {self.rank}>> [Connection] received connection message " 

877 f"from sim-id={sim_id} with client-comm-size={self.client_comm_size}." 

878 ) 

879 logger.debug( 

880 f"Rank {self.rank}>> [Connection] sending response to sim-id={sim_id}" 

881 f" with learning={self._learning}" 

882 ) 

883 response = ConnectionResponse( 

884 self.comm_size, 

885 self._learning, 

886 int(self._bind_simulation_to_server_rank), 

887 self.nb_parameters, 

888 self._port_names, 

889 ) 

890 self._connection_responder.send(response.encode()) 

891 logger.info( 

892 f"Rank {self.rank}>> [Connection] sim-id={sim_id} established." 

893 ) 

894 self._groups[group_id].simulations[sim_id].connected = True 

895 

896 return sim_id 

897 

898 def _restart_groups(self) -> None: 

899 """Kills and restarts simulations that were running when the server crashed.""" 

900 

901 resubmitted_cnt: int = 0 

902 

903 for group_id, group in self._groups.items(): 

904 if group.submitted and not group.has_finished(): 

905 self.nb_submitted_groups -= 1 

906 for sim_id, sim in group.simulations.items(): 

907 # regenerate scripts as some things might change 

908 # such as environment variables pointing to the host:port 

909 # of the newly launched server. 

910 self.__generate_client_script( 

911 sim_id, 

912 sim.parameters, 

913 sim.script_path 

914 ) 

915 self._launch_group(group_id) 

916 resubmitted_cnt += 1 

917 

918 # possible corner case where all were presumed to be submitted 

919 # then we need to initiate at least one new submission 

920 if resubmitted_cnt == 0: 

921 self._launch_groups([self.nb_submitted_groups]) 

922 

    def poll_sockets(self,
                     timeout: int = 10,
                     ) -> Optional[Union[ServerStatus, SimulationData, PartialSimulationData]]:
        """Performs polling over the registered socket descriptors to monitor various events,
        including timer, launcher messages, new client connections, and data readiness.

        ### Parameters
        - **timeout** (`int`, optional): Poll timeout forwarded to
        `zmq.Poller.poll`, which pyzmq interprets in **milliseconds**
        (the previous docstring said seconds). Default is `10`.

        ### Returns
        - `Optional[Union[ServerStatus, SimulationData, PartialSimulationData]]`:
            - `ServerStatus.TIMEOUT` if no socket fired within the timeout.
            - `SimulationData` / `PartialSimulationData` if simulation data was received.
            - `None` if some other event (connection, launcher, timer) was handled."""

        # 1. Poll sockets
        # ZMQ sockets (plain file descriptors are registered with the same poller)
        sockets = dict(self.__zmq_poller.poll(timeout))
        if not sockets:
            return ServerStatus.TIMEOUT

        if self.rank == 0:
            # 2. Look for connections from new simulations (just at rank 0)
            if (
                self._connection_responder in sockets
                and sockets[self._connection_responder] == zmq.POLLIN
            ):
                msg = self._connection_responder.recv()
                logger.debug(f"Rank {self.rank}>> Handle client connection request")
                self._handle_simulation_connection(msg)

        # 3. Handle launcher message
        # this is a TCP/SCTP socket not a zmq one, so handled differently
        # (matched via its fileno rather than the socket object itself)
        if self._launcherfd.fileno() in sockets:
            logger.debug(f"Rank {self.rank}>> Handle launcher message")
            self._handle_fd()

        # 4. Handle simulation data message
        if (
            not self.offline_mode
            and self.__data_puller in sockets
            and sockets[self.__data_puller] == zmq.POLLIN
        ):
            logger.debug(f"Rank {self.rank}>> Handle simulation data")
            msg = self.__data_puller.recv()
            # track total traffic for the final report
            self._total_bytes_recv += len(msg)
            return self._handle_simulation_data(msg)

        # 5. Handle timer message (fault-tolerance time-out checks)
        if self.__timerfd_0.fileno() in sockets:
            logger.debug(f"Rank {self.rank}>> Handle timer message")
            self.__handle_timerfd()

        return None

978 

979 def __forceful_group_termination(self, group_id: int) -> None: 

980 """Forcefully terminates all clients in a group.""" 

981 

982 for sim_id in self._get_sim_id_list_by_group(group_id): 

983 sim = self._groups[group_id].simulations[sim_id] 

984 self.__process_simulation_completion(sim, force=True) 

985 

986 def __handle_timerfd(self) -> None: 

987 """Handles timer messages.""" 

988 

989 self.__timerfd_0.recv(1) 

990 try: 

991 self.__ft.check_time_out(self._groups) 

992 except FaultToleranceError as e: 

993 if not self.ignore_client_death: 

994 raise e 

995 

996 for group_id, create_new_group in self.__ft.restart_group.items(): 

997 if not self.ignore_client_death: 

998 self._relaunch_group(group_id, create_new_group) 

999 else: 

1000 self.__forceful_group_termination(group_id) 

1001 

1002 self.__ft.restart_group = {} 

1003 

1004 def __handle_failed_group(self, group_id: int) -> None: 

1005 """Handles failed group by using fault-tolerance to decide resubmission.""" 

1006 

1007 try: 

1008 group = self._groups[group_id] 

1009 create_new_group = self.__ft.handle_failed_group(group_id, group) 

1010 except FaultToleranceError as e: 

1011 if not self.ignore_client_death: 

1012 raise e 

1013 

1014 if not self.ignore_client_death: 

1015 self._relaunch_group(group_id, create_new_group) 

1016 else: 

1017 self.__forceful_group_termination(group_id) 

1018 

    def _handle_fd(self) -> None:
        """Handles the launcher's messages.

        Receives a chunk from the launcher socket, decodes one or more
        messages, reacts to job-state updates (failure -> fault-tolerance
        handling, termination -> consistency check and next submission),
        and on rank 0 replies with a PING."""

        bs = self._launcherfd.recv(256)
        rcvd_msg = self._decode_msg(bs)

        for msg in rcvd_msg:
            # 1. Launcher sent JOB_UPDATE message (msg.job_id <=> group_id)
            if isinstance(msg, message.JobUpdate):
                if msg.job_id not in self._groups:
                    continue
                group = self._groups[msg.job_id]
                group_id = group.group_id
                # React to simulation status
                if msg.job_state in [job.State.ERROR, job.State.FAILED]:
                    logger.debug(f"Launcher indicates failure of group-id/sim-id={group_id}")
                    self.__handle_failed_group(group_id)
                elif msg.job_state is job.State.TERMINATED:
                    logger.info(
                        f"Rank {self.rank}>> [Termination] group-id/sim-id={group_id}"
                    )
                    if not group.has_finished() and not self.ignore_client_death:
                        logger.warning(
                            f"[Inconsistent State] Launcher reports group-id={group_id} as "
                            "terminated, but the server has not marked the group as "
                            "finished. The server expects a \"termination\" message from the "
                            "group before considering it complete. This may occur if message "
                            "delivery is delayed and the launcher detects group exit prematurely."
                        )

                    # NOTE(review): nesting reconstructed from a whitespace-mangled
                    # extract -- this resubmission is assumed to apply to TERMINATED
                    # updates only; confirm against the original indentation.
                    if self.nb_submitted_groups < self.nb_groups:
                        # keep submitting new clients one by one
                        current_max_group_id = self.nb_submitted_groups
                        self._launch_groups([current_max_group_id])

            # 2. Server sends PING
            if self.rank == 0:
                logger.debug(
                    "Server got message from launcher and sends PING back"
                )
                snd_msg = self._encode_msg(message.Ping())
                self._launcherfd.send(snd_msg)

1062 def _decode_msg(self, byte_stream: bytes) -> List[Message]: 

1063 """Deserializes a message based on the specified protocol. 

1064 

1065 ### Parameters 

1066 - **byte_stream** (`bytes`): The byte stream to be deserialized, 

1067 representing the encoded message. 

1068 

1069 ### Returns 

1070 - `List[Message]`: A list of byte sequences representing 

1071 the deserialized message components.""" 

1072 

1073 msg_list = [] 

1074 if self.__protocol == socket.IPPROTO_TCP: 

1075 packets = LengthPrefixFramingDecoder( 

1076 config.TCP_MESSAGE_PREFIX_LENGTH 

1077 ).execute(byte_stream) 

1078 for p in packets: 

1079 msg_list.append(message.deserialize(p)) 

1080 logger.debug(f"Decoded launcher messages {msg_list}") 

1081 return msg_list 

1082 if self.__protocol == socket.IPPROTO_SCTP: 

1083 msg_list.append(message.deserialize(byte_stream)) 

1084 logger.debug(f"Decoded launcher messages {msg_list}") 

1085 return msg_list 

1086 raise UnsupportedProtocol(f"{self.__protocol} not supported for decoding.") 

1087 

1088 def _encode_msg(self, msg: Message) -> bytes: 

1089 """Serializes message based on the specified protocol. 

1090 

1091 ### Parameters 

1092 - **msg** (`Message`): The message to be serialized, 

1093 typically a byte sequence that needs encoding. 

1094 

1095 ### Returns 

1096 - `bytes`: The serialized byte stream representing the encoded message.""" 

1097 

1098 if self.__protocol == socket.IPPROTO_TCP: 

1099 encoded_packet = LengthPrefixFramingEncoder( 

1100 config.TCP_MESSAGE_PREFIX_LENGTH 

1101 ).execute(msg.serialize()) 

1102 return encoded_packet 

1103 if self.__protocol == socket.IPPROTO_SCTP: 

1104 return msg.serialize() 

1105 raise UnsupportedProtocol(f"{self.__protocol} not supported for encoding.") 

1106 

1107 def _all_done(self) -> bool: 

1108 """Checks whether all clients' data has been received and 

1109 unregisters the timer socket if completed. 

1110 

1111 ### Returns 

1112 - `bool`: if all clients' data has been successfully received.""" 

1113 

1114 if self.nb_finished_groups == self.nb_groups: 

1115 # join thread and close timer sockets 

1116 logger.info(f"Rank {self.rank}>> closes timer sockets.") 

1117 self.__timerfd_0.close() 

1118 self.__zmq_poller.unregister(self.__timerfd_0) 

1119 self.__t_timer.join(timeout=1) 

1120 if self.__t_timer.is_alive(): 

1121 logger.warning("timer thread did not terminate") 

1122 else: 

1123 self.__timerfd_1.close() 

1124 return True 

1125 

1126 return False 

1127 

1128 def close_connection(self, exit_: int = 0) -> None: 

1129 """Signals to the launcher that the study has ended with a specified exit status. 

1130 

1131 ### Parameters 

1132 - `exit_` (`int`, optional): The exit status code to be sent to the launcher. 

1133 Defaults to `0`, indicating successful completion.""" 

1134 

1135 if self.rank == 0 and self.__connected_with_launcher: 

1136 self._launcherfd.send( 

1137 self._encode_msg(message.Exit(exit_)) 

1138 ) 

1139 

1140 self.mpi_abort(exit_) 

1141 

1142 def mpi_abort(self, exit_: int = 0) -> None: 

1143 if exit_ > 0: 

1144 logger.error( 

1145 f"Rank {self.rank}>> An error occured on one of the MPI ranks. Aborting the study." 

1146 ) 

1147 self.comm.Abort(exit_) 

1148 

1149 def get_memory_info_in_gb(self) -> Tuple[float, float]: 

1150 """Returns a `Tuple[float, float]` containing memory consumed and 

1151 the total main memory in GB.""" 

1152 

1153 memory = psutil.virtual_memory() 

1154 consumed = memory.used / (1024 ** 3) 

1155 total = memory.total / (1024 ** 3) 

1156 

1157 return consumed, total 

1158 

1159 def _show_insights(self) -> None: 

1160 """Logs information gathered from clients, and server processing.""" 

1161 

1162 seconds = self.mtt_simulation_completion.get_mean()[0] 

1163 mean_hms = str( 

1164 datetime.timedelta( 

1165 seconds=int(seconds) 

1166 ) 

1167 ) 

1168 t = mean_hms if seconds > 1 else f"{seconds:.2} sec." 

1169 logger.info( 

1170 f"Rank {self.rank}>> [Insight] " 

1171 f"Average simulation completion={t}. " 

1172 "Computed based on the reception of samples " 

1173 "and may vary if time-steps are received out-of-order." 

1174 ) 

1175 

1176 consumed, total = self.get_memory_info_in_gb() 

1177 logger.info( 

1178 f"Rank {self.rank}>> [Insight] Memory consumption={consumed:.2f}/{total:.2f} GB." 

1179 ) 

1180 

1181 def _write_final_report(self) -> None: 

1182 """Write miscellaneous information about the analysis.""" 

1183 

1184 # Total time 

1185 total_time = time.time() - self.__t0 

1186 total_time = self.comm.allreduce(total_time, op=MPI.SUM) 

1187 # Total MB received 

1188 total_b = self.comm.allreduce(self._total_bytes_recv, op=MPI.SUM) 

1189 if self.rank == 0: 

1190 msg_bytes = bytes_to_readable(total_b) 

1191 total_hms = datetime.timedelta(seconds=total_time // self.comm_size) 

1192 logger.info( 

1193 " - Number of Finished Groups (Simulations): " 

1194 f"{self.nb_finished_groups}/{self.nb_submitted_groups}" 

1195 ) 

1196 if self.ignore_client_death and len(self.__ft.failed_ids) > 0: 

1197 logger.info( 

1198 " - Number of Failed Groups (Simulations): " 

1199 f"{len(self.__ft.failed_ids)}/{self.nb_submitted_groups}" 

1200 ) 

1201 logger.info(f" - Number of Server Ranks: {self.comm_size}") 

1202 logger.info(f" - Total time: {str(total_hms)}") 

1203 logger.info(f" - Total data received: {msg_bytes}") 

1204 

1205 def _server_finalize(self, exit_: int = 0) -> None: 

1206 """Finalizes the server operations. 

1207 

1208 ### Parameters 

1209 - `exit_` (`int`, optional): The exit status code indicating the outcome 

1210 of the server's operations. Defaults to `0`, which signifies a successful termination.""" 

1211 if exit_ == 0: 

1212 self._write_final_report() 

1213 

1214 self._stop_pinger_thread() 

1215 self.parameter_sampler.finalize(exit_) # type:ignore 

1216 

1217 logger.info(f"Server finalizing with status {exit_}.") 

1218 self.close_connection(exit_) 

1219 

1220 def setup_environment(self) -> None: 

1221 """Optional. A method that sets up the environment or initialization. 

1222 Any necessary setup methods go here. 

1223 For example, Melissa DL study needs `dist.init_process_group` to be called.""" 

1224 return None 

1225 

    def _check_simulation_data(self,
                               simulation: Simulation,
                               simulation_data: PartialSimulationData
                               ) -> Tuple[
                                   SimulationDataStatus, Union[
                                       Optional[SimulationData],
                                       Optional[PartialSimulationData]]]:
        """Tracks and validates incoming simulation data.

        1. **Client Rank Initialization**:
            Ensures `simulation_data` structures are initialized per client rank.
        2. **Dynamic Matrix Expansion**:
            Handles unknown sizes dynamically as new time steps are encountered.
        3. **Duplicate Data Detection**:
            Discards messages if the data for the specified field and time step has
            already been received.
        4. **Time Step Completion**:
            - Checks if all fields for a specific time step have been received and processes them
            into a `SimulationData` object.
            - Handles cases where the data is empty.
        5. **Partial Data Handling**: Tracks fields received so far and waits for completion.

        ### Parameters
        - **simulation** (`Simulation`): Tracks the state and received data of the simulation.
        - **simulation_data** (`PartialSimulationData`): The incoming data message from
        the simulation.

        ### Returns
        - `SimulationDataStatus`: Status of the simulation data
        (`COMPLETE`, `PARTIAL`, `ALREADY_RECEIVED`, `EMPTY`).
        - `Union[Optional[SimulationData], Optional[PartialSimulationData]]`:
            - Sensitivity Analysis:
                - A `PartialSimulationData` object regardless of it being incomplete
                as SA can be computed independently.
            - Deep Learning:
                - A `SimulationData` object if all fields for the time step are complete.
            - `None` if the data is incomplete or invalid."""

        client_rank, time_step, field = (
            simulation_data.client_rank,
            simulation_data.time_step,
            simulation_data.field,
        )

        # lock needed when training thread might checkpoint on the same data
        # that is being updated below
        # NOTE(review): block nesting reconstructed from a whitespace-mangled
        # extract -- the whole body is assumed to run under the lock; confirm
        # against the original indentation.
        with self.consistency_lock:
            simulation.init_structures(client_rank)
            # following expansion is conditional based on
            # the current shape and the given time step
            simulation.time_step_expansion(client_rank, time_step)

            # check for duplicate data: drop the message entirely
            if simulation.has_already_received(client_rank, time_step, field):
                return SimulationDataStatus.ALREADY_RECEIVED, None

            # initialize storage for all fields
            simulation.init_data_storage(client_rank, time_step)
            # update received data
            simulation.update(client_rank, time_step, field, simulation_data)

            simulation.mark_as_received(client_rank, time_step, field)

            # check if the time step is complete (all fields, all client ranks)
            if simulation.is_complete(time_step):
                simulation.nb_received_time_steps += 1

                # handle empty data scenario
                if simulation_data.data_size == 0:
                    return SimulationDataStatus.EMPTY, None

                return (
                    SimulationDataStatus.COMPLETE,
                    self._process_complete_data_reception(
                        simulation,
                        simulation_data
                    )
                )

            # Partial data received: flavor-specific hook decides what to return
            return (
                SimulationDataStatus.PARTIAL,
                self._process_partial_data_reception(
                    simulation,
                    simulation_data
                )
            )

1313 

1314 def __deserialize_message(self, msg: bytes) -> PartialSimulationData: 

1315 """Deserializes a byte stream into a `PartialSimulationData` object. 

1316 

1317 ### Parameters 

1318 - **msg** (`bytes`): Serialized message containing simulation data. 

1319 

1320 ### Returns 

1321 - `PartialSimulationData`: Data objet.""" 

1322 

1323 data = PartialSimulationData.from_msg(msg, self.learning) 

1324 logger.debug( 

1325 f"Rank {self.rank}>> received message " 

1326 f"from sim-id={data.simulation_id}, " 

1327 f"time-step={data.time_step}, " 

1328 f"client-rank={data.client_rank}, " 

1329 f"vect-size={len(data.data)}" 

1330 ) 

1331 return data 

1332 

1333 def __process_simulation_completion(self, simulation: Simulation, force: bool = False) -> None: 

1334 """Finalizes simulation completion and adjusts metadata associated with it. 

1335 

1336 ### Parameters 

1337 - **simulation** (`Simulation`): Instance of the simulation to finalize. 

1338 - **force** (`bool`): Set to enforce termination, regardless. Default is `False`.""" 

1339 

1340 sim_id = simulation.id 

1341 group_id = self._get_group_id_by_simulation(sim_id) 

1342 group = self._groups[group_id] 

1343 

1344 if not simulation.connected: 

1345 return 

1346 

1347 if simulation.has_finished(force): 

1348 simulation.connected = False 

1349 logger.info( 

1350 f"Rank {self.rank}>> sim-id={sim_id} has finished sending time-steps. " 

1351 f"received={simulation.nb_received_time_steps}, " 

1352 f"expected={self.nb_time_steps}" 

1353 ) 

1354 self.mtt_simulation_completion.increment(simulation.duration) 

1355 if self.nb_finished_groups % 100 == 0: 

1356 self._show_insights() 

1357 

1358 if group.has_finished(): 

1359 self.finished_groups.add(group_id) 

1360 

1361 def _validate_data(self, simulation_data: PartialSimulationData) -> bool: 

1362 """Validates the time step and field of the received simulation data. 

1363 

1364 ### Parameters 

1365 - **simulation_data** (`PartialSimulationData`): The data to validate. 

1366 

1367 ### Returns 

1368 - `bool`: if the data is valid.""" 

1369 

1370 sim_id = simulation_data.simulation_id 

1371 time_step = simulation_data.time_step 

1372 field = simulation_data.field 

1373 group_id = self._get_group_id_by_simulation(sim_id) 

1374 simulation = self._groups[group_id].simulations[sim_id] 

1375 

1376 # apply validation checks 

1377 if group_id not in self._groups: 

1378 return False 

1379 

1380 if ( 

1381 time_step < 0 

1382 or (self.time_steps_known and time_step > self.nb_time_steps) 

1383 ): 

1384 logger.warning(f"Rank {self.rank}>> [BAD] sim-id={sim_id}, time-step={time_step}") 

1385 return False 

1386 

1387 if field != "termination" and field not in self.fields: 

1388 logger.warning(f"Rank {self.rank}>> [BAD] sim-id={sim_id}, field=\"{field}\"") 

1389 return False 

1390 

1391 # handle termination messages 

1392 if field == "termination": 

1393 

1394 # modify the time steps received accordingly 

1395 if not self.time_steps_known: 

1396 self.nb_time_steps = time_step 

1397 

1398 # termination message sends total time steps as its `time_step` 

1399 # value. so make a check on how many time steps are received 

1400 # termination could be received prematurely in high-traffic situations 

1401 if simulation.nb_received_time_steps != self.nb_time_steps: 

1402 logger.warning( 

1403 f"Received termination from sim-id={sim_id} prematurely." 

1404 ) 

1405 else: 

1406 self.__process_simulation_completion(simulation) 

1407 return False 

1408 

1409 self.__process_simulation_completion(simulation) 

1410 

1411 return True 

1412 

1413 def __determine_and_process_simulation_data(self, simulation_data: PartialSimulationData 

1414 ) -> Optional[Union[SimulationData, 

1415 PartialSimulationData]]: 

1416 """Determines the status of the simulation data and handles actions accordingly. 

1417 

1418 ### Parameters 

1419 - **simulation_data** (`PartialSimulationData`): The incoming simulation data to process. 

1420 

1421 ### Returns 

1422 - `Optional[Union[SimulationData, PartialSimulationData]]`: 

1423 return of the `_check_simulation_data` method.""" 

1424 

1425 sim_id = simulation_data.simulation_id 

1426 time_step = simulation_data.time_step 

1427 group_id = self._get_group_id_by_simulation(sim_id) 

1428 group = self._groups[group_id] 

1429 simulation = group.simulations[sim_id] 

1430 simulation.connected = True 

1431 

1432 # check simulation data status 

1433 sim_status, sim_data = self._check_simulation_data(simulation, simulation_data) 

1434 if sim_status is SimulationDataStatus.COMPLETE: 

1435 logger.debug( 

1436 f"Rank {self.rank}>> sim-id={sim_id}, time-step={time_step} assembled." 

1437 ) 

1438 elif sim_status is SimulationDataStatus.ALREADY_RECEIVED: 

1439 logger.warning( 

1440 f"Rank {self.rank}>> [Duplicated] sim-id={sim_id}, time-step={time_step}" 

1441 ) 

1442 if sim_status in [SimulationDataStatus.COMPLETE, SimulationDataStatus.EMPTY]: 

1443 self.__process_simulation_completion(simulation) 

1444 

1445 if group.has_finished(): 

1446 self.finished_groups.add(group_id) 

1447 

1448 return sim_data 

1449 

1450 def _handle_simulation_data(self, 

1451 msg: bytes) -> Optional[ 

1452 Union[SimulationData, 

1453 PartialSimulationData]]: 

1454 """This method handles the following tasks: 

1455 1. **Deserialization**: Converts the incoming byte stream 

1456 into a `PartialSimulationData` object. 

1457 2. **Validation**: Ensures the data is valid based on: 

1458 - Time step being within the allowed range. 

1459 - Field name being recognized. 

1460 3. **Simulation Data Handling**: 

1461 - Updates the status of the simulation based on the received data. 

1462 - Detects and logs duplicate messages. 

1463 4. **Completion Check**: 

1464 - Marks the simulation as finished if all data is received. 

1465 - Updates the count of finished simulations. 

1466 

1467 ### Parameters 

1468 - **msg** (`bytes`): A serialized message containing simulation data. 

1469 

1470 ### Returns 

1471 - `Optional[PartialSimulationData]`: 

1472 - `PartialSimulationData`, if successful. 

1473 - `None`, if the message fails validation.""" 

1474 

1475 data = self.__deserialize_message(msg) 

1476 if self._validate_data(data): 

1477 return self.__determine_and_process_simulation_data(data) 

1478 return None 

1479 

1480# =====================================Abstract Methods===================================== 

1481 

1482 @abstractmethod 

1483 def _server_online(self) -> None: 

1484 """An abstract method where user controls the data handling while server is online. 

1485 Unique to melissa flavors.""" 

1486 raise NotImplementedError("Subclasses must override this method.") 

1487 

1488 @abstractmethod 

1489 def _server_offline(self) -> None: 

1490 """An abstract method where user controls the data handling while server is offline. 

1491 Unique to melissa flavors.""" 

1492 raise NotImplementedError("Subclasses must override this method.") 

1493 

1494 @abstractmethod 

1495 def _check_group_size(self) -> None: 

1496 """An abstract method that checks if the group size was correctly set. 

1497 Unique to melissa flavors.""" 

1498 raise NotImplementedError("Subclasses must override this method.") 

1499 

1500 @abstractmethod 

1501 def _process_partial_data_reception(self, 

1502 simulation: Simulation, 

1503 simulation_data: PartialSimulationData 

1504 ) -> Optional[PartialSimulationData]: 

1505 """Returns a value when data has been partially received. 

1506 Unique to melissa flavors.""" 

1507 raise NotImplementedError("Subclass must override this method.") 

1508 

1509 @abstractmethod 

1510 def _process_complete_data_reception(self, 

1511 simulation: Simulation, 

1512 simulation_data: PartialSimulationData 

1513 ) -> Union[PartialSimulationData, 

1514 SimulationData]: 

1515 """Returns a value when data has been completely received. 

1516 Unique to melissa flavors.""" 

1517 raise NotImplementedError("Subclass must override this method.") 

1518 

1519 @abstractmethod 

1520 def _receive(self) -> None: 

1521 """Handles data coming from the server object. 

1522 Unique to melissa flavors.""" 

1523 raise NotImplementedError("Subclasses must override this method.") 

1524 

1525 @abstractmethod 

1526 def start(self) -> None: 

1527 """The high level organization of server events. 

1528 Unique to melissa flavors.""" 

1529 raise NotImplementedError("Subclasses must override this method.") 

1530 

1531 @abstractmethod 

1532 def _restart_from_checkpoint(self, **kwargs) -> None: 

1533 """Restarts the server object from a checkpoint. 

1534 Unique to melissa flavors.""" 

1535 raise NotImplementedError("Subclasses must override this method.") 

1536 

1537 @abstractmethod 

1538 def _checkpoint(self, **kwargs) -> None: 

1539 """Checkpoint the server object. 

1540 Unique to melissa flavors.""" 

1541 raise NotImplementedError("Subclasses must override this method.")