Coverage for melissa/server/base_server.py: 47%
675 statements
« prev ^ index » next — coverage.py v7.10.1, created at 2025-11-19 09:33 +0100
1"""This script defines the parent class for melissa server."""
3from datetime import timedelta
4import logging
5import os
6import socket
7import threading
8import datetime
9import time
10from abc import ABC, abstractmethod
11from enum import Enum
12from typing import (
13 Any, Dict, Optional,
14 Tuple, Union, List, Type
15)
17import psutil
18import numpy as np
19import zmq
20from mpi4py import MPI
21import cloudpickle
22import rapidjson
23try:
24 import debugpy
25except ModuleNotFoundError:
26 pass
29from iterative_stats.iterative_moments import IterativeMoments
30from melissa.launcher import config, message
31from melissa.scheduler import job
32from melissa.server.fault_tolerance import FaultTolerance
33from melissa.server.parameters import (
34 BaseExperiment, ParameterSamplerType,
35 ParameterSamplerClass, make_parameter_sampler
36)
37from melissa.utility.message import Message
38from melissa.server.message import ConnectionRequest, ConnectionResponse
39from melissa.server.simulation import (
40 Group, PartialSimulationData,
41 Simulation, SimulationData,
42 SimulationDataStatus
43)
44from melissa.utility.networking import (
45 LengthPrefixFramingDecoder,
46 LengthPrefixFramingEncoder,
47 connect_to_launcher,
48 select_protocol,
49 is_port_in_use,
50 is_launcher_socket_alive_and_ready
51)
52from melissa.server.exceptions import (
53 FatalError,
54 InitialConnectionError,
55 UnsupportedProtocol,
56 ConfigurationFileError,
57 FaultToleranceError
58)
59from melissa.utility.rank_helper import ClusterEnvironment, initialize_sampling_rank
60from melissa.utility.timer import Timer
61from melissa.utility.client_scripts import get_client_script_path
62from melissa.utility.logger import configure_logger, get_log_level_from_verbosity
# Module-level logger; per-rank log files and level are set up later via
# `BaseServer.configure_logger`.
logger = logging.getLogger(__name__)
class ServerStatus(Enum):
    """Enumeration of status codes reported by the server."""

    CHECKPOINT = 1
    TIMEOUT = 2
def bytes_to_readable(total_bytes: int) -> str:
    """Returns human-readable byte representation.

    The raw count is scaled down by 1024 once for every threshold it
    clears (KB, MB, GB). Plain byte counts are printed as integers,
    everything larger with three decimals.
    """
    value = float(total_bytes)
    unit = "bytes"
    for power, name in ((1, "KB"), (2, "MB"), (3, "GB")):
        if total_bytes >= 1024 ** power:
            value /= 1024
            unit = name
    if unit == "bytes":
        return f"{int(value)} {unit}"
    return f"{value:.3f} {unit}"
class BaseServer(ABC):
    """`BaseServer` class that handles the following tasks:

    - Manages connections with the launcher and clients.
    - Generates client scripts for simulations.
    - Encodes and decodes messages between the server and clients.
    - Provides basic checkpointing functionality to save and restore states.

    ### Parameters
    - **config_dict** (`Dict[str, Any]`): A dictionary containing configuration settings for
    initializing the server.
    - **checkpoint_file** (`str`, optional): The filename for the checkpoint file
    (default is `"checkpoint.pkl"`). This file is used for saving and restoring the server's state.

    ### Attributes
    - **comm** (`MPI.Intracomm`): The MPI communicator for inter-process communication.
    - **rank** (`int`): The rank of the current process in the MPI communicator.
    - **comm_size** (`int`): The total number of server processes in the MPI communicator.
    - **client_comm_size** (`int`): The total number of client processes.
    - **server_processes** (`int`): Synonym for `comm_size`.
    - **connection_port** (`int`): The server port to establish request-response connection with
    the clients.
    - **data_puller_port** (`int`): The server port to establish data pulling with the clients.

    - **_offline_mode** (`bool`): Internal flag indicating offline mode where no sending operation
    takes place. Useful when running multiple clients to produce datasets.
    Call `self.make_study_offline` to enable.
    - **_learning** (`int`): Internal flag indicating the learning state (initially 0).
    - **__t0** (`float`): The timestamp marking the initialization time of the object.
    - **_job_limit** (`int`): Maximum number of jobs the launcher can manage concurrently.
    - **__is_direct_scheduler** (`bool`): Flag indicating whether the study is using a direct
    scheduler.

    - **_restart** (`int`): Flag indicating if the system is in a restart state; initialized
    from the `MELISSA_RESTART` environment variable.
    - **_consistency_lock** (`threading.RLock`): Reentrant lock to ensure thread-safe operations
    on shared resources.
    - **_is_receiving** (`bool`): Flag indicating whether data reception is ongoing.
    - **_is_online** (`bool`): Flag indicating if the system is in an online operational mode.
    - **_sobol_op** (`bool`): Flag indicating whether Sobol operations are being performed.
    - **_total_bytes_recv** (`int`): Tracks the total number of bytes received over the network.
    - **_active_sim_ids** (`set`): Set of active simulation ids currently being managed.

    - **_groups** (`Dict[int, Group]`): Dictionary mapping group ids to `Group` objects.
    - **_parameter_sampler** (`Optional[BaseExperiment]`): Sampler for generating parameter values
    for simulations.
    - **__parameter_generator** (`Any`): Internal generator object for producing parameters.

    - **verbose_level** (`int`): Determines the verbosity level for logging and debugging output.
    - **config_dict** (`Dict[str, Any]`): Configuration dictionary provided during initialization.
    - **checkpoint_file** (`str`): File name used for storing checkpoint data.

    - **crashes_before_redraw** (`int`): Number of simulation crashes allowed before
    redrawing parameters.
    - **max_delay** (`Union[int, float]`): Maximum allowed delay for simulations, in seconds.
    - **rm_script** (`bool`): Indicates whether client scripts should be removed after execution.
    - **group_size** (`int`): Number of simulations grouped together for batch processing.
    - **zmq_hwm** (`int`): High-water mark for ZeroMQ communication.

    - **fields** (`List[str]`): List of field names used in the study.
    - **nb_parameters** (`int`): Number of parameters in the parameter sweep study.
    - **nb_time_steps** (`int`): Number of time steps in each simulation.
    - **nb_clients** (`int`): Total number of clients participating in the parameter sweep study.

    - **nb_groups** (`int`): Total number of groups, derived from the number of clients
    and group size.
    - **nb_submitted_groups** (`int`): Tracks the number of groups submitted so far.
    - **finished_groups** (`set`): Tracks the finished set of groups.
    - **mtt_simulation_completion** (`float`): Iteratively keeps track of mean of simulation
    durations.

    - **no_fault_tolerance** (`bool`): Indicates whether fault tolerance is disabled,
    based on the `MELISSA_FAULT_TOLERANCE` environment variable.
    - **__ft** (`FaultTolerance`): Fault tolerance object managing simulation
    crashes and retries."""
    def __init__(self,
                 config_dict: Dict[str, Any],
                 checkpoint_file: str = "checkpoint.pkl") -> None:
        """Initialize common server state from the study configuration.

        ### Parameters
        - **config_dict** (`Dict[str, Any]`): parsed configuration; must contain
          `launcher_config` and `study_options` sections.
        - **checkpoint_file** (`str`, optional): file name used for checkpoint data.

        ### Raises
        - `ConfigurationFileError`: if a mandatory study option is missing, or if
          `ignore_client_death` is requested while fault tolerance is enabled.
        - `KeyError`: if `MELISSA_RESTART` or `MELISSA_FAULT_TOLERANCE` is not set
          in the environment (both are read unconditionally)."""

        # MPI initialization
        cluster = ClusterEnvironment()
        self.comm: MPI.Comm = cluster.comm_world
        self.rank: int = cluster.comm_world_rank
        self.comm_size: int = cluster.comm_world_size
        self.server_processes: int = self.comm_size
        self.client_comm_size: int = 0
        # default ports; actual free ports are picked in __initialize_ports
        self.__connection_port: int = 2003
        self.__data_puller_port: int = 5000
        self.__connected_with_launcher: bool = False

        self._offline_mode: bool = False
        self._learning: int = 0
        self._bind_simulation_to_server_rank: int = 0
        self.__t0: float = time.time()
        # one job slot is reserved (hence the -1); clamped to nb_groups below
        self._job_limit: int = config_dict["launcher_config"]["job_limit"] - 1
        self.__is_direct_scheduler: bool = config_dict["launcher_config"]["scheduler"] == "openmpi"

        self._restart: int = int(os.environ["MELISSA_RESTART"])
        self._consistency_lock: threading.RLock = threading.RLock()
        self._is_receiving: bool = False
        self._is_online: bool = False
        self._sobol_op: bool = False
        self._total_bytes_recv: int = 0
        self._active_sim_ids: set = set()
        self._groups: Dict[int, Group] = {}
        self._parameter_sampler: Optional[BaseExperiment] = None

        self.config_dict: Dict[str, Any] = config_dict
        self.checkpoint_file: str = checkpoint_file
        study_options: Dict[str, Any] = self.config_dict["study_options"]
        self.verbose_level: int = study_options["verbosity"]

        # Scan study options dictionary (optional keys fall back to defaults)
        self.crashes_before_redraw: int = study_options.get("crashes_before_redraw", 1)
        self.max_delay: Union[int, float] = study_options.get("simulation_timeout", 60)
        self.rm_script: bool = study_options.get("remove_client_scripts", False)
        self.group_size: int = 1
        self.zmq_hwm: int = study_options.get("zmq_hwm", 0)
        try:
            # mandatory keys: a missing one aborts server construction
            self.fields: List[str] = study_options["field_names"]
            self.nb_parameters: int = study_options["nb_parameters"]
            self.nb_clients: int = study_options["parameter_sweep_size"]
        except KeyError as e:
            logger.error(f"[INCORRECT] Key not found in the configuration: {e}")
            raise ConfigurationFileError

        # negative values are treated as "unknown" (0)
        self.nb_time_steps: int = max(study_options.get("nb_time_steps", 0), 0)
        if self.nb_time_steps == 0:
            logger.warning(f"Rank {self.rank}>> Number of timesteps to be received not provided.")

        self.nb_groups: int = self.nb_clients // self.group_size
        self.nb_submitted_groups: int = 0
        self.finished_groups: set = set()
        # iterative first-order moment of simulation completion times
        self.mtt_simulation_completion: IterativeMoments = IterativeMoments(max_order=1, dim=1)
        self._job_limit = min(self._job_limit, self.nb_groups)

        # Fault-Tolerance initialization
        self.no_fault_tolerance: bool = os.environ["MELISSA_FAULT_TOLERANCE"] == "OFF"

        self.ignore_client_death: bool = config_dict.get("ignore_client_death", False)
        if self.ignore_client_death and self.no_fault_tolerance is False:
            raise ConfigurationFileError("Client deaths cannot be ignored if Fault-tolerance is ON")

        logger.info("fault-tolerance " + ("OFF" if self.no_fault_tolerance else "ON"))
        self.__ft: FaultTolerance = FaultTolerance(
            self.no_fault_tolerance,
            self.max_delay,
            self.crashes_before_redraw,
            self.nb_groups,
        )

        # this lets us ping the launcher periodically
        # ensuring the launcher does not assume the server is dead.
        self.__ping_interval: int = min(
            10,
            config_dict["launcher_config"].get("server_timeout", 10)
        )
        self.__run_handler_thread: bool = False
        self.__pinger_thread: threading.Thread = threading.Thread(
            name="pinger", target=self.__loop_pings
        )

        if self.rank == 0:
            # make a directory for the checkpoint files if one does not exist
            os.makedirs("checkpoints", exist_ok=True)
            metadata = {"MELISSA_RESTART": int(os.environ["MELISSA_RESTART"])}
            # NOTE(review): file is opened in binary mode for rapidjson.dump —
            # assumes rapidjson accepts binary streams; confirm with the
            # python-rapidjson version in use.
            with open("checkpoints/restart_metadata.json", 'wb') as f:
                # write the metadata to json
                rapidjson.dump(metadata, f)
        # all ranks wait until rank 0 has created the checkpoints directory
        self.comm.Barrier()
        self.ckpt_metadata_path: str = f"checkpoints/{self.rank}/metadata.pkl"
        os.makedirs(f"checkpoints/{self.rank}", exist_ok=True)

        # a runtime decision on which rank becomes the sampling rank
        initialize_sampling_rank()
    @property
    def offline_mode(self) -> bool:
        """Whether the server runs without sending data back to clients."""
        return self._offline_mode

    def make_study_offline(self) -> None:
        """Enables offline mode: no reception/sending takes place."""
        self._offline_mode = True
        logger.warning(
            f"Rank {self.rank}>> Currently running with offline mode. "
            "No reception will take place."
        )
    @property
    def time_steps_known(self) -> bool:
        """Time steps are known prior study or not."""
        return self.nb_time_steps > 0

    @property
    def nb_finished_groups(self) -> int:
        """Number of groups that have completed so far."""
        return len(self.finished_groups)

    @property
    def is_direct_scheduler(self) -> bool:
        """Study is using a direct scheduler or not."""
        return self.__is_direct_scheduler
    @property
    def learning(self) -> int:
        """Deep learning activated?
        Required when establishing a connection with clients."""
        return self._learning

    @property
    def consistency_lock(self) -> threading.RLock:
        """Useful for active sampling."""
        return self._consistency_lock

    @property
    def is_receiving(self) -> bool:
        """Whether data reception from clients is currently ongoing."""
        return self._is_receiving

    @is_receiving.setter
    def is_receiving(self, value: bool) -> None:
        self._is_receiving = value

    @property
    def is_online(self) -> bool:
        """Whether the server is in online operational mode."""
        return self._is_online

    @is_online.setter
    def is_online(self, value: bool) -> None:
        self._is_online = value

    @property
    def sobol_op(self) -> bool:
        """Whether Sobol operations are being performed."""
        return self._sobol_op

    @sobol_op.setter
    def sobol_op(self, value: bool) -> None:
        self._sobol_op = value

    @property
    def parameter_sampler(self) -> Optional[BaseExperiment]:
        """The sampler producing simulation parameters (None until set)."""
        return self._parameter_sampler

    @parameter_sampler.setter
    def parameter_sampler(self, value: Optional[BaseExperiment]) -> None:
        self._parameter_sampler = value
    def __loop_pings(self) -> None:
        """Maintains communication with the launcher to ensure it
        does not assume the server has become unresponsive.

        Runs on the dedicated "pinger" thread until `__run_handler_thread`
        is cleared by `_stop_pinger_thread`."""

        while self.__run_handler_thread:
            self._launcherfd.send(self._encode_msg(message.Ping()))
            logger.debug(f"Rank {self.rank}>> pinging launcher.")
            time.sleep(self.__ping_interval)
345 def _start_pinger_thread(self) -> None:
346 """Starts the pinger thread and set the flag."""
348 if self.rank == 0:
349 assert threading.current_thread() != self.__pinger_thread
350 self.__run_handler_thread = True
351 if not self.__pinger_thread.is_alive():
352 self.__pinger_thread.start()
354 def _stop_pinger_thread(self) -> None:
355 """Stops the pinger thread and unsets the flag."""
357 if self.rank == 0:
358 assert threading.current_thread() != self.__pinger_thread
359 self.__run_handler_thread = False
360 if self.__pinger_thread.is_alive():
361 self.__pinger_thread.join(timeout=1.0)
    def _save_base_state(self) -> None:
        """Checkpoints all common attributes in the server class to preserve the current state.

        No-op when fault tolerance is disabled. The metadata written here is
        the counterpart of what `_load_base_state` reads back."""

        if self.no_fault_tolerance:
            return

        # make sure every rank checkpoints at the same point in the study
        self.comm.Barrier()

        # save some state metadata to be reloaded later
        with self.consistency_lock:
            metadata = {
                "nb_groups": self.nb_groups,
                "nb_submitted_groups": self.nb_submitted_groups,
                "finished_groups": self.finished_groups,
                "groups": self._groups,
                "t0": self.__t0,
                "total_bytes_recv": self._total_bytes_recv
            }

            with open(self.ckpt_metadata_path, 'wb') as f:
                cloudpickle.dump(metadata, f)

        if self.parameter_sampler is not None:
            self.parameter_sampler.checkpoint_state()
388 def _load_base_state(self) -> None:
389 """Loads all common attributes in the server class from a checkpoint or saved state."""
391 if self.no_fault_tolerance:
392 return
394 try:
395 # load the metadata
396 with open(self.ckpt_metadata_path, 'rb') as f:
397 metadata = cloudpickle.load(f)
398 except FileNotFoundError as e:
399 raise FatalError(
400 f"Fault-tolerance must be set for checkpointing.\n{e}"
401 )
403 self.nb_groups = metadata["nb_groups"]
404 self.nb_submitted_groups = metadata["nb_submitted_groups"]
405 self.finished_groups = metadata["finished_groups"]
406 self._groups = metadata["groups"]
407 self.__t0 = metadata["t0"]
408 self._total_bytes_recv = metadata["total_bytes_recv"]
410 if self.parameter_sampler is not None:
411 self.parameter_sampler.restart_from_checkpoint()
    def __initialize_ports(self,
                           connection_port: int = 2003,
                           data_puller_port: int = 5000) -> None:
        """Assigns port numbers for connection and data pulling as class attributes.
        If the specified ports are already in use, likely due to multiple servers running
        on the same node, the function attempts to find available ports by incrementing the
        base port values and rechecking their availability.

        _Note: When multiple independent `melissa-server` jobs are running simultaneously
        on the same node, there is a chance that a port may incorrectly appear as available,
        leading to potential conflicts._

        ### Parameters
        - **connection_port** (`int`, optional): The port number used for establishing the main
        connection (default is `2003`).
        - **data_puller_port** (`int`, optional): The port number used for pulling data
        (default is `5000`).

        ### Raises
        - `InitialConnectionError`: If no free port was found after the given
          number of attempts."""

        # Ports initialization
        logger.info(f"Rank {self.rank}>> Initializing server...")
        self.node_name = socket.gethostname()

        attempts = max(10, self.comm_size)
        # only rank 0 owns the client connection port
        if self.rank == 0:
            self.__connection_port = connection_port
            i = 0
            while is_port_in_use(self.__connection_port) and i < attempts:
                logger.warning(
                    f"Rank {self.rank}>> Connection port {self.__connection_port} in use. "
                    "Trying another..."
                )
                self.__connection_port += 1
                i += 1

            if i == attempts:
                logger.error(
                    f"{self.rank}>> Could not find an available connection port after "
                    f"{attempts} attempts."
                )
                raise InitialConnectionError

        # Set data puller port: each rank probes a disjoint window of
        # (attempts + 1) ports so ranks on the same node cannot collide
        self.__data_puller_port = data_puller_port + (self.rank * (attempts + 1))
        i = 0
        while is_port_in_use(self.__data_puller_port) and i < attempts:
            logger.warning(f"Rank {self.rank}>> Data puller port {self.__data_puller_port} in use. "
                           "Trying another...")
            self.__data_puller_port += 1
            i += 1

        if i == attempts:
            logger.error(
                f"{self.rank}>> Could not find an available data puller port after "
                f"{attempts} attempts."
            )
            raise InitialConnectionError

        # every rank learns every other rank's puller endpoint
        self.__data_puller_port_name = f"tcp://{self.node_name}:{self.__data_puller_port}"
        self._port_names = self.comm.allgather(self.__data_puller_port_name)
        logger.debug(f"port_names {self._port_names}")
    def __connect_to_launcher(self) -> None:
        """Establishes a connection with the launcher and sends metadata about the study.

        Rank 0 connects first and announces the communicator and group sizes;
        the remaining ranks connect only afterwards (enforced by the barrier).

        ### Raises
        - `InitialConnectionError`: If any rank's launcher socket is not alive
          and ready after connecting."""

        self._launcherfd: socket.socket
        # Setup communication instances
        self.__protocol, prot_name = select_protocol()
        logger.info(f"Server/launcher communication protocol: {prot_name}")
        if self.rank == 0:
            self._launcherfd = connect_to_launcher()
            self._launcherfd.send(self._encode_msg(message.CommSize(self.comm_size)))
            logger.debug(f"Rank {self.rank}>> Comm size {self.comm_size} sent to launcher")
            self._launcherfd.send(self._encode_msg(message.GroupSize(self.group_size)))
            logger.debug(f"Rank {self.rank}>> Group size {self.group_size} sent to launcher")
        # synchronize non-zero ranks after rank 0 connection to make sure
        # these ranks only connect after comm_size is known to the launcher
        self.comm.Barrier()

        if self.rank > 0:
            # avoiding simultaneous connection that could cause
            # race conditions for FDs on supercomputers
            time.sleep(0.001 * self.rank)
            self._launcherfd = connect_to_launcher()

        # all ranks must report a healthy socket, otherwise the whole
        # server aborts the initial connection
        all_fds_ready = self.comm.allreduce(
            is_launcher_socket_alive_and_ready(self._launcherfd),
            op=MPI.LAND
        )
        if not all_fds_ready:
            raise InitialConnectionError(
                "Some ranks failed to connect with the launcher."
            )

        logger.debug(f"Rank {self.rank}>> Launcher fd set up: {self._launcherfd.fileno()}")
        self.__connected_with_launcher = True
516 def __setup_sockets(self) -> None:
517 """Sets up ZeroMQ (ZMQ) sockets over a given TCP connection port for communication."""
519 self.__zmq_context = zmq.Context()
520 # Simulations (REQ) <-> Server (REP)
521 self._connection_responder = self.__zmq_context.socket(zmq.REP)
522 if self.rank == 0:
523 addr1 = f"tcp://*:{self.__connection_port}"
524 try:
525 self._connection_responder.bind(addr1)
526 except InitialConnectionError as e:
527 raise e
528 logger.info(
529 f"Rank {self.rank}>> Binding to {addr1} successful."
530 )
532 # Simulations (PUSH) -> Server (PULL)
533 self.__data_puller = self.__zmq_context.socket(zmq.PULL)
534 self.__data_puller.setsockopt(zmq.RCVHWM, self.zmq_hwm)
535 self.__data_puller.setsockopt(zmq.RCVBUF, 4 * 1024 ** 2)
536 self.__data_puller.setsockopt(zmq.LINGER, -1)
537 addr2 = f"tcp://*:{self.__data_puller_port}"
538 try:
539 self.__data_puller.bind(addr2)
540 except InitialConnectionError as e:
541 raise e
542 logger.info(
543 f"Rank {self.rank}>> Data puller to {addr2} successful."
544 )
546 # Time-out checker (creates thread)
547 self.__timerfd_0, self.__timerfd_1 = socket.socketpair(
548 socket.AF_UNIX, socket.SOCK_STREAM
549 )
550 timer = Timer(self.__timerfd_1, timedelta(seconds=self.max_delay))
551 self.__t_timer = threading.Thread(target=timer.run, daemon=True)
552 self.__t_timer.start()
    def __setup_poller(self) -> None:
        """This method sets up the polling mechanism by registering three important sockets:
        - **Data Socket**: Handles data communication.
        - **Timer Socket**: Manages timing events.
        - **Launcher Socket**: Facilitates communication with the launcher."""

        self.__zmq_poller = zmq.Poller()
        self.__zmq_poller.register(self.__data_puller, zmq.POLLIN)
        self.__zmq_poller.register(self.__timerfd_0, zmq.POLLIN)
        self.__zmq_poller.register(self._launcherfd, zmq.POLLIN)
        if self.rank == 0:
            # only rank 0 binds the connection responder (see __setup_sockets),
            # so only rank 0 polls it
            self.__zmq_poller.register(self._connection_responder, zmq.POLLIN)
    def __start_debugger(self) -> None:
        """Launches the Visual Studio Code (VSCode) debugger for debugging purposes.

        Blocks until a debugger client attaches, then tells the launcher to
        stop its timeout monitoring so the paused server is not declared dead.

        NOTE(review): `debugpy` is imported at module level inside a
        try/except — if the package is absent this method raises NameError."""

        # 5678 is the default attach port that we recommend users to set
        # in the documentation.
        debugpy.listen(5678)
        logger.warning("Waiting for debugger attach, please start the "
                       "debugger by navigating to debugger pane ctrl+shift+d "
                       "and selecting\n"
                       "Python: Remote Attach")
        debugpy.wait_for_client()
        logger.info("Debugger successfully attached.")
        # send message to launcher to ensure debugger doesnt timeout
        snd_msg = self._encode_msg(message.StopTimeoutMonitoring())
        self._launcherfd.send(snd_msg)
583 def configure_logger(self) -> None:
584 """Configures server loggers for each MPI rank."""
585 log_level = get_log_level_from_verbosity(self.verbose_level)
586 app_str = f"_restart_{self._restart}" if self._restart else ""
587 configure_logger(f"melissa_server_{self.rank}{app_str}.log", log_level)
    def initialize_connections(self) -> None:
        """Initializes socket connections for communication.

        Order matters: ports are chosen first, the launcher socket must exist
        before ZMQ sockets/poller registration, and the optional debugger is
        started last."""

        self.__initialize_ports()
        self.__connect_to_launcher()
        self.__setup_sockets()
        self.__setup_poller()

        if self.config_dict.get("vscode_debugging", False):
            self.__start_debugger()
    def _get_group_id_by_simulation(self, sim_id: int) -> int:
        """Returns group id of the given simulation id."""
        # simulation ids are allocated contiguously per group
        return sim_id // self.group_size
604 def _get_sim_id_list_by_group(self, group_id: int) -> List[int]:
605 """Returns a list of all simulation ids for a given group id."""
606 first_id = group_id * self.group_size
607 last_id = first_id + self.group_size
608 return list(range(first_id, last_id))
610 def _get_all_sim_ids(self):
611 """Yields all simulation ids across all groups."""
612 for group_id in self._groups.keys():
613 yield from self._get_sim_id_list_by_group(group_id)
615 def _verify_and_update_sampler_kwargs(self, sampler_t, **kwargs) -> Dict[str, Any]:
616 """Updates the parameters that were not provided by the user when
617 creating a sampler using `set_parameter_sampler` method. It also ensures whether
618 the seed is given or not for a parallel server."""
620 # if not provided by the users
621 if "nb_params" not in kwargs:
622 kwargs["nb_params"] = self.nb_parameters
623 if "nb_sims" not in kwargs:
624 kwargs["nb_sims"] = self.nb_clients
625 if "seed" not in kwargs:
626 if hasattr(self, "seed"):
627 kwargs["seed"] = self.seed
629 return kwargs
    def set_parameter_sampler(self,
                              sampler_t: Union[ParameterSamplerType, Type[ParameterSamplerClass]],
                              **kwargs) -> None:
        """Sets the defined parameter sampler type. This dictates how parameters are sampled
        for experiments. This sampler type can either be pre-defined or customized
        by inheriting a pre-defined sampling class.

        ### Parameters
        - **sampler_t** (`Union[ParameterSamplerType, Type[ParameterSamplerClass]]`):
            - `ParameterSamplerType`: Enum specifying pre-defined samplers.
            - `Type[ParameterSamplerClass]`: A class type to instantiate.
        - **kwargs** (`Dict[str, Any]`): Dictionary of keyword arguments.
        Useful to pass custom parameter as well as strict parameter such as
        `l_bounds`, `u_bounds`, `apply_pick_freeze`, `second_order`, `seed=0`, etc."""
        # fill in study-wide defaults (nb_params, nb_sims, seed) before building
        kwargs = self._verify_and_update_sampler_kwargs(sampler_t, **kwargs)
        self._parameter_sampler = make_parameter_sampler(sampler_t, **kwargs)
    def _update_parameter_sampler(self) -> None:
        """Updates the existing parameter sampler."""

        # no-op until a sampler has been set via `set_parameter_sampler`
        if self._parameter_sampler:
            self._parameter_sampler.flush_to_disk()
    def _launch_groups(self, group_ids: List[int]) -> None:
        """Launches the study groups for the very first run.
        This process involves generating the client scripts and ensures that
        no restart has occurred in the case of fault tolerance.

        ### Parameters
        - **group_ids** (`List[int]`): A list of group identifiers to launch.

        ### Raises
        - `FileNotFoundError`: If the `client.sh` template is missing from the
          working directory.
        """

        # nothing to do once every group has already been submitted
        if self.nb_submitted_groups >= self.nb_groups:
            return

        # Get current working directory containing the client script template
        client_script = os.path.abspath("client.sh")

        if not os.path.isfile(client_script):
            raise FileNotFoundError("error client script not found")

        # Generate all client scripts
        self._generate_client_scripts(group_ids)

        # Launch every group
        # note that the launcher message stacking feature does not work
        for group_id in group_ids:
            self._launch_group(group_id)
    def _get_client_script_path(self, sim_id: int) -> str:
        """Returns the path of the generated client script for `sim_id`."""
        return get_client_script_path(sim_id)
    def _generate_client_scripts(self,
                                 group_ids: List[int],
                                 create_new_group: bool = False) -> None:
        """Creates all required client scripts (e.g., `client.X.sh`),
        and sets up a dictionary for fault tolerance.

        ### Parameters
        - **group_ids** (`List[int]`): A list of group identifiers.
        - **create_new_group** (`bool`, optional): Flag indicating whether to
        create a new group of clients (default is `False`)."""

        for group_id in group_ids:
            for sim_id in self._get_sim_id_list_by_group(group_id):
                # a sampler must have been set before any script generation
                assert self._parameter_sampler is not None
                parameters = list(self._parameter_sampler.draw(sim_id))

                script_path = self._get_client_script_path(sim_id)
                fname = os.path.basename(script_path)
                if self.rank == 0:
                    # only rank 0 writes scripts to disk
                    self.__generate_client_script(sim_id, parameters, script_path)
                    # ping the launcher periodically during long generation
                    # runs so it does not assume the server died
                    if sim_id > 0 and sim_id % 1000 == 0:
                        snd_msg = self._encode_msg(message.Ping())
                        self._launcherfd.send(snd_msg)

                    logger.info(
                        f"Rank {self.rank}>> Created {fname} with parameters {parameters}"
                    )

                # Fault-tolerance dictionary creation and update
                # create_new_group is specifically for active sampling
                if create_new_group or group_id not in self._groups:
                    self._groups[group_id] = Group(group_id, self.sobol_op)
                if sim_id not in self._groups[group_id].simulations:
                    self._groups[group_id].simulations[sim_id] = Simulation(
                        sim_id,
                        script_path,
                        self.nb_time_steps,
                        self.fields,
                        parameters
                    )
                self._groups[group_id].simulations[sim_id].last_message = None
    def __generate_client_script(self,
                                 sim_id: int,
                                 parameters: List[Any],
                                 script_path: str) -> None:
        """Generates a single client script for a given simulation id and parameters.

        ### Parameters
        - **sim_id** (`int`): The simulation id associated with the client script.
        - **parameters** (`list`): The list of parameters.
        - **script_path** (`str`): The absolute path of the client script to create.
        """

        if self.rank == 0:
            with open(script_path, "w") as f:
                print("#!/bin/sh", file=f)
                self._write_environment_variables(f, sim_id)
                self._write_execution_command(f, parameters)
            # rwxr--r--: the scheduler must be able to execute the script
            os.chmod(script_path, 0o744)
    def _write_environment_variables(self, f: Any, sim_id: int) -> None:
        """Writes environment variables to the client script.

        The trailing backslash on every line (including the last one) makes the
        shell continue the `exec env` command onto the execution command that
        `_write_execution_command` appends next.

        ### Parameters
        - **f** (`Any`): The file object to write to.
        - **sim_id** (`int`): The simulation id associated with the client script.
        """
        print("exec env \\", file=f)
        print(f" MELISSA_VERBOSE={self.verbose_level} \\", file=f)
        print(f" MELISSA_SIMU_ID={sim_id} \\", file=f)
        print(f" MELISSA_SERVER_NODE_NAME={self.node_name} \\", file=f)
        print(f" MELISSA_SERVER_PORT={self.__connection_port} \\", file=f)
758 def _write_execution_command(self, f: Any, parameters: List[Any]) -> None:
759 """Writes the execution command to the client script.
761 ### Parameters
762 - **f** (`Any`): The file object to write to.
763 - **parameters** (`list`): The list of parameters.
764 """
765 if self.rm_script:
766 print(
767 " "
768 + " ".join(
769 [os.path.join(os.getcwd(), "client.sh")]
770 + [
771 np.format_float_positional(x) if not isinstance(x, str)
772 else x for x in parameters
773 ]
774 )
775 + " &",
776 file=f,
777 )
778 print(" wait", file=f)
779 print(' rm "$0"', file=f)
780 else:
781 print(
782 " "
783 + " ".join(
784 [os.path.join(os.getcwd(), "client.sh")]
785 + [
786 np.format_float_positional(x) if not isinstance(x, str)
787 else x for x in parameters
788 ]
789 ),
790 file=f,
791 )
    def _launch_group(self, group_id: int) -> None:
        """Submits a request to the launcher to run a given group id.
        For non-Sobol studies, the group id and simulation id are the same.

        ### Parameters
        - **group_id** (`int`): The unique identifier of the group to be launched.
        """
        if self.rank == 0:
            # Job submission message to launcher (initial_id,num_jobs)
            for sim_id in self._get_sim_id_list_by_group(group_id):
                snd_msg = self._encode_msg(message.JobSubmission(sim_id, 1))
                self._launcherfd.send(snd_msg)
            logger.debug(
                f"Rank {self.rank}>> group "
                f"{group_id + 1}/{self.nb_groups} "
                "submitted to launcher"
            )

        # bookkeeping happens on every rank, not just rank 0
        self._groups[group_id].submitted = True
        for _, simulation in self._groups[group_id].simulations.items():
            simulation.connected = False

        self.nb_submitted_groups += 1
    def _kill_group(self, group_id: int) -> None:
        """Submits a request to the launcher to terminate a given group id.

        ### Parameters
        - **group_id** (`int`): The unique identifier of the group to be terminated."""

        group = self._groups[group_id]
        if self.rank == 0:
            # NOTE(review): logged as [RESTART] because this path is used when
            # resubmitting an incomplete group (see `_relaunch_group`)
            logger.warning(
                f"[RESTART] Resubmitting incomplete group-id={group_id} to the launcher."
            )
            snd_msg = self._encode_msg(message.JobCancellation(group_id))
            self._launcherfd.send(snd_msg)

        # bookkeeping happens on every rank, not just rank 0
        group.submitted = False
        for sim_id in group.simulations:
            group.simulations[sim_id].connected = False
        self.nb_submitted_groups -= 1
    def _relaunch_group(self, group_id: int, create_new_group: bool) -> None:
        """Relaunches a failed group with or without new parameters,
        depending on the fault tolerance configuration.

        ### Parameters
        - **group_id** (`int`): The unique identifier of the group to be relaunched.
        - **create_new_group** (`bool`): A flag indicating whether to create a new group
        with new parameters."""

        # regenerate scripts first (new parameters if create_new_group),
        # then cancel the old jobs before submitting the fresh ones
        self._generate_client_scripts([group_id], create_new_group)
        assert not self._groups[group_id].has_finished()
        self._kill_group(group_id)
        self._launch_group(group_id)
    def _handle_simulation_connection(self, msg: bytes) -> int:
        """Handles an incoming connection request from a submitted simulation.
        This method is executed by rank 0 only.

        ### Parameters
        - **msg** (`bytes`): The message received from the simulation requesting a connection.

        ### Returns
        - `int`: The simulation id of the connected simulation, or `-1` if the connection
        could not be established."""

        request = ConnectionRequest.recv(msg)
        self.client_comm_size = request.comm_size
        sim_id = request.simulation_id

        # a corner case which may not happen
        # at this point, it is expected that
        # the group/simulation is already running
        group_id = self._get_group_id_by_simulation(sim_id)
        if group_id not in self._groups or not self._groups[group_id].submitted:
            logger.warning(f"Rank {self.rank}>> group-id={group_id} does not exist.")
            return -1

        logger.debug(
            f"Rank {self.rank}>> [Connection] received connection message "
            f"from sim-id={sim_id} with client-comm-size={self.client_comm_size}."
        )
        logger.debug(
            f"Rank {self.rank}>> [Connection] sending response to sim-id={sim_id}"
            f" with learning={self._learning}"
        )
        # the response tells the client how many server ranks exist, whether
        # deep learning is active, and where every rank's data puller lives
        response = ConnectionResponse(
            self.comm_size,
            self._learning,
            int(self._bind_simulation_to_server_rank),
            self.nb_parameters,
            self._port_names,
        )
        self._connection_responder.send(response.encode())
        logger.info(
            f"Rank {self.rank}>> [Connection] sim-id={sim_id} established."
        )
        self._groups[group_id].simulations[sim_id].connected = True

        return sim_id
898 def _restart_groups(self) -> None:
899 """Kills and restarts simulations that were running when the server crashed."""
901 resubmitted_cnt: int = 0
903 for group_id, group in self._groups.items():
904 if group.submitted and not group.has_finished():
905 self.nb_submitted_groups -= 1
906 for sim_id, sim in group.simulations.items():
907 # regenerate scripts as some things might change
908 # such as environment variables pointing to the host:port
909 # of the newly launched server.
910 self.__generate_client_script(
911 sim_id,
912 sim.parameters,
913 sim.script_path
914 )
915 self._launch_group(group_id)
916 resubmitted_cnt += 1
918 # possible corner case where all were presumed to be submitted
919 # then we need to initiate at least one new submission
920 if resubmitted_cnt == 0:
921 self._launch_groups([self.nb_submitted_groups])
    def poll_sockets(self,
                     timeout: int = 10,
                     ) -> Optional[Union[ServerStatus, SimulationData, PartialSimulationData]]:
        """Performs polling over the registered socket descriptors to monitor various events,
        including timer, launcher messages, new client connections, and data readiness.

        ### Parameters
        - **timeout** (`int`, optional): The maximum time to wait for a socket event
        before returning. Passed straight to `zmq.Poller.poll`, which interprets it
        in **milliseconds**. Default is `10`.

        ### Returns
        - `Optional[Union[ServerStatus, SimulationData, PartialSimulationData]]`:
            - `ServerStatus.TIMEOUT` if no socket became ready within `timeout`.
            - `SimulationData` / `PartialSimulationData` if simulation data is received.
            - `None` if a non-data event (connection, launcher, timer) was handled."""

        # 1. Poll sockets
        # ZMQ sockets
        sockets = dict(self.__zmq_poller.poll(timeout))
        if not sockets:
            return ServerStatus.TIMEOUT

        if self.rank == 0:
            # 2. Look for connections from new simulations (just at rank 0)
            if (
                self._connection_responder in sockets
                and sockets[self._connection_responder] == zmq.POLLIN
            ):
                msg = self._connection_responder.recv()
                logger.debug(f"Rank {self.rank}>> Handle client connection request")
                self._handle_simulation_connection(msg)

        # 3. Handle launcher message
        # this is a TCP/SCTP socket not a zmq one, so handled differently
        # (plain file descriptors appear in the poll dict via fileno())
        if self._launcherfd.fileno() in sockets:
            logger.debug(f"Rank {self.rank}>> Handle launcher message")
            self._handle_fd()

        # 4. Handle simulation data message
        if (
            not self.offline_mode
            and self.__data_puller in sockets
            and sockets[self.__data_puller] == zmq.POLLIN
        ):
            logger.debug(f"Rank {self.rank}>> Handle simulation data")
            msg = self.__data_puller.recv()
            self._total_bytes_recv += len(msg)  # accounted for the final report
            return self._handle_simulation_data(msg)

        # 5. Handle timer message (drives fault-tolerance timeout checks)
        if self.__timerfd_0.fileno() in sockets:
            logger.debug(f"Rank {self.rank}>> Handle timer message")
            self.__handle_timerfd()

        return None
979 def __forceful_group_termination(self, group_id: int) -> None:
980 """Forcefully terminates all clients in a group."""
982 for sim_id in self._get_sim_id_list_by_group(group_id):
983 sim = self._groups[group_id].simulations[sim_id]
984 self.__process_simulation_completion(sim, force=True)
986 def __handle_timerfd(self) -> None:
987 """Handles timer messages."""
989 self.__timerfd_0.recv(1)
990 try:
991 self.__ft.check_time_out(self._groups)
992 except FaultToleranceError as e:
993 if not self.ignore_client_death:
994 raise e
996 for group_id, create_new_group in self.__ft.restart_group.items():
997 if not self.ignore_client_death:
998 self._relaunch_group(group_id, create_new_group)
999 else:
1000 self.__forceful_group_termination(group_id)
1002 self.__ft.restart_group = {}
1004 def __handle_failed_group(self, group_id: int) -> None:
1005 """Handles failed group by using fault-tolerance to decide resubmission."""
1007 try:
1008 group = self._groups[group_id]
1009 create_new_group = self.__ft.handle_failed_group(group_id, group)
1010 except FaultToleranceError as e:
1011 if not self.ignore_client_death:
1012 raise e
1014 if not self.ignore_client_death:
1015 self._relaunch_group(group_id, create_new_group)
1016 else:
1017 self.__forceful_group_termination(group_id)
    def _handle_fd(self) -> None:
        """Handles the launcher's messages received on the plain TCP/SCTP socket."""

        bs = self._launcherfd.recv(256)
        # a single recv may carry several framed messages (TCP)
        rcvd_msg = self._decode_msg(bs)

        for msg in rcvd_msg:
            # 1. Launcher sent JOB_UPDATE message (msg.job_id <=> group_id)
            if isinstance(msg, message.JobUpdate):
                if msg.job_id not in self._groups:
                    continue
                group = self._groups[msg.job_id]
                group_id = group.group_id
                # React to simulation status
                if msg.job_state in [job.State.ERROR, job.State.FAILED]:
                    logger.debug(f"Launcher indicates failure of group-id/sim-id={group_id}")
                    self.__handle_failed_group(group_id)
                elif msg.job_state is job.State.TERMINATED:
                    logger.info(
                        f"Rank {self.rank}>> [Termination] group-id/sim-id={group_id}"
                    )
                    if not group.has_finished() and not self.ignore_client_death:
                        logger.warning(
                            f"[Inconsistent State] Launcher reports group-id={group_id} as "
                            "terminated, but the server has not marked the group as "
                            "finished. The server expects a \"termination\" message from the "
                            "group before considering it complete. This may occur if message "
                            "delivery is delayed and the launcher detects group exit prematurely."
                        )

                    # a slot freed up: keep submitting new clients one by one
                    # (the next group id equals the count of submitted groups)
                    if self.nb_submitted_groups < self.nb_groups:
                        current_max_group_id = self.nb_submitted_groups
                        self._launch_groups([current_max_group_id])

        # 2. Server sends PING so the launcher knows the server is alive
        if self.rank == 0:
            logger.debug(
                "Server got message from launcher and sends PING back"
            )
            snd_msg = self._encode_msg(message.Ping())
            self._launcherfd.send(snd_msg)
1062 def _decode_msg(self, byte_stream: bytes) -> List[Message]:
1063 """Deserializes a message based on the specified protocol.
1065 ### Parameters
1066 - **byte_stream** (`bytes`): The byte stream to be deserialized,
1067 representing the encoded message.
1069 ### Returns
1070 - `List[Message]`: A list of byte sequences representing
1071 the deserialized message components."""
1073 msg_list = []
1074 if self.__protocol == socket.IPPROTO_TCP:
1075 packets = LengthPrefixFramingDecoder(
1076 config.TCP_MESSAGE_PREFIX_LENGTH
1077 ).execute(byte_stream)
1078 for p in packets:
1079 msg_list.append(message.deserialize(p))
1080 logger.debug(f"Decoded launcher messages {msg_list}")
1081 return msg_list
1082 if self.__protocol == socket.IPPROTO_SCTP:
1083 msg_list.append(message.deserialize(byte_stream))
1084 logger.debug(f"Decoded launcher messages {msg_list}")
1085 return msg_list
1086 raise UnsupportedProtocol(f"{self.__protocol} not supported for decoding.")
1088 def _encode_msg(self, msg: Message) -> bytes:
1089 """Serializes message based on the specified protocol.
1091 ### Parameters
1092 - **msg** (`Message`): The message to be serialized,
1093 typically a byte sequence that needs encoding.
1095 ### Returns
1096 - `bytes`: The serialized byte stream representing the encoded message."""
1098 if self.__protocol == socket.IPPROTO_TCP:
1099 encoded_packet = LengthPrefixFramingEncoder(
1100 config.TCP_MESSAGE_PREFIX_LENGTH
1101 ).execute(msg.serialize())
1102 return encoded_packet
1103 if self.__protocol == socket.IPPROTO_SCTP:
1104 return msg.serialize()
1105 raise UnsupportedProtocol(f"{self.__protocol} not supported for encoding.")
    def _all_done(self) -> bool:
        """Checks whether all clients' data has been received and
        unregisters the timer socket if completed.

        ### Returns
        - `bool`: if all clients' data has been successfully received."""

        if self.nb_finished_groups == self.nb_groups:
            # join thread and close timer sockets
            logger.info(f"Rank {self.rank}>> closes timer sockets.")
            # NOTE(review): the socket is closed *before* being unregistered
            # from the poller; zmq tolerates this, but the conventional
            # order is unregister-then-close — confirm intent.
            self.__timerfd_0.close()
            self.__zmq_poller.unregister(self.__timerfd_0)
            self.__t_timer.join(timeout=1)
            if self.__t_timer.is_alive():
                logger.warning("timer thread did not terminate")
            else:
                # __timerfd_1 is closed only once the timer thread has exited
                # — presumably the thread owns that end; verify against the
                # timer-thread implementation.
                self.__timerfd_1.close()
            return True

        return False
1128 def close_connection(self, exit_: int = 0) -> None:
1129 """Signals to the launcher that the study has ended with a specified exit status.
1131 ### Parameters
1132 - `exit_` (`int`, optional): The exit status code to be sent to the launcher.
1133 Defaults to `0`, indicating successful completion."""
1135 if self.rank == 0 and self.__connected_with_launcher:
1136 self._launcherfd.send(
1137 self._encode_msg(message.Exit(exit_))
1138 )
1140 self.mpi_abort(exit_)
1142 def mpi_abort(self, exit_: int = 0) -> None:
1143 if exit_ > 0:
1144 logger.error(
1145 f"Rank {self.rank}>> An error occured on one of the MPI ranks. Aborting the study."
1146 )
1147 self.comm.Abort(exit_)
1149 def get_memory_info_in_gb(self) -> Tuple[float, float]:
1150 """Returns a `Tuple[float, float]` containing memory consumed and
1151 the total main memory in GB."""
1153 memory = psutil.virtual_memory()
1154 consumed = memory.used / (1024 ** 3)
1155 total = memory.total / (1024 ** 3)
1157 return consumed, total
1159 def _show_insights(self) -> None:
1160 """Logs information gathered from clients, and server processing."""
1162 seconds = self.mtt_simulation_completion.get_mean()[0]
1163 mean_hms = str(
1164 datetime.timedelta(
1165 seconds=int(seconds)
1166 )
1167 )
1168 t = mean_hms if seconds > 1 else f"{seconds:.2} sec."
1169 logger.info(
1170 f"Rank {self.rank}>> [Insight] "
1171 f"Average simulation completion={t}. "
1172 "Computed based on the reception of samples "
1173 "and may vary if time-steps are received out-of-order."
1174 )
1176 consumed, total = self.get_memory_info_in_gb()
1177 logger.info(
1178 f"Rank {self.rank}>> [Insight] Memory consumption={consumed:.2f}/{total:.2f} GB."
1179 )
1181 def _write_final_report(self) -> None:
1182 """Write miscellaneous information about the analysis."""
1184 # Total time
1185 total_time = time.time() - self.__t0
1186 total_time = self.comm.allreduce(total_time, op=MPI.SUM)
1187 # Total MB received
1188 total_b = self.comm.allreduce(self._total_bytes_recv, op=MPI.SUM)
1189 if self.rank == 0:
1190 msg_bytes = bytes_to_readable(total_b)
1191 total_hms = datetime.timedelta(seconds=total_time // self.comm_size)
1192 logger.info(
1193 " - Number of Finished Groups (Simulations): "
1194 f"{self.nb_finished_groups}/{self.nb_submitted_groups}"
1195 )
1196 if self.ignore_client_death and len(self.__ft.failed_ids) > 0:
1197 logger.info(
1198 " - Number of Failed Groups (Simulations): "
1199 f"{len(self.__ft.failed_ids)}/{self.nb_submitted_groups}"
1200 )
1201 logger.info(f" - Number of Server Ranks: {self.comm_size}")
1202 logger.info(f" - Total time: {str(total_hms)}")
1203 logger.info(f" - Total data received: {msg_bytes}")
1205 def _server_finalize(self, exit_: int = 0) -> None:
1206 """Finalizes the server operations.
1208 ### Parameters
1209 - `exit_` (`int`, optional): The exit status code indicating the outcome
1210 of the server's operations. Defaults to `0`, which signifies a successful termination."""
1211 if exit_ == 0:
1212 self._write_final_report()
1214 self._stop_pinger_thread()
1215 self.parameter_sampler.finalize(exit_) # type:ignore
1217 logger.info(f"Server finalizing with status {exit_}.")
1218 self.close_connection(exit_)
1220 def setup_environment(self) -> None:
1221 """Optional. A method that sets up the environment or initialization.
1222 Any necessary setup methods go here.
1223 For example, Melissa DL study needs `dist.init_process_group` to be called."""
1224 return None
    def _check_simulation_data(self,
                               simulation: Simulation,
                               simulation_data: PartialSimulationData
                               ) -> Tuple[
                                   SimulationDataStatus, Union[
                                       Optional[SimulationData],
                                       Optional[PartialSimulationData]]]:
        """Tracks and validates incoming simulation data.

        1. **Client Rank Initialization**:
        Ensures `simulation_data` structures are initialized per client rank.
        2. **Dynamic Matrix Expansion**:
        Handles unknown sizes dynamically as new time steps are encountered.
        3. **Duplicate Data Detection**:
        Discards messages if the data for the specified field and time step has
        already been received.
        4. **Time Step Completion**:
            - Checks if all fields for a specific time step have been received and processes them
            into a `SimulationData` object.
            - Handles cases where the data is empty.
        5. **Partial Data Handling**: Tracks fields received so far and waits for completion.

        ### Parameters
        - **simulation** (`Simulation`): Tracks the state and received data of the simulation.
        - **simulation_data** (`PartialSimulationData`): The incoming data message from
        the simulation.

        ### Returns
        - `SimulationDataStatus`: Status of the simulation data
        (`COMPLETE`, `PARTIAL`, `ALREADY_RECEIVED`, `EMPTY`).
        - `Union[Optional[SimulationData], Optional[PartialSimulationData]]`:
            - Sensitivity Analysis:
                - A `PartialSimulationData` object regardless of it being incomplete
                as SA can be computed independently.
            - Deep Learning:
                - A `SimulationData` object if all fields for the time step are complete.
            - `None` if the data is incomplete or invalid."""

        client_rank, time_step, field = (
            simulation_data.client_rank,
            simulation_data.time_step,
            simulation_data.field,
        )

        # lock needed when training thread might checkpoint on the same data
        # that is being updated below
        with self.consistency_lock:
            # per-client-rank structures are created lazily on first contact
            simulation.init_structures(client_rank)
            # following expansion is conditional based on
            # the current shape and the given time step
            simulation.time_step_expansion(client_rank, time_step)

            # check for duplicate data (same rank/time-step/field seen before)
            if simulation.has_already_received(client_rank, time_step, field):
                return SimulationDataStatus.ALREADY_RECEIVED, None

            # initialize storage for all fields
            simulation.init_data_storage(client_rank, time_step)
            # update received data
            simulation.update(client_rank, time_step, field, simulation_data)

            simulation.mark_as_received(client_rank, time_step, field)

            # check if the time step is complete (all fields, all client ranks)
            if simulation.is_complete(time_step):
                simulation.nb_received_time_steps += 1

                # handle empty data scenario
                if simulation_data.data_size == 0:
                    return SimulationDataStatus.EMPTY, None

                return (
                    SimulationDataStatus.COMPLETE,
                    self._process_complete_data_reception(
                        simulation,
                        simulation_data
                    )
                )

            # Partial data received
            return (
                SimulationDataStatus.PARTIAL,
                self._process_partial_data_reception(
                    simulation,
                    simulation_data
                )
            )
1314 def __deserialize_message(self, msg: bytes) -> PartialSimulationData:
1315 """Deserializes a byte stream into a `PartialSimulationData` object.
1317 ### Parameters
1318 - **msg** (`bytes`): Serialized message containing simulation data.
1320 ### Returns
1321 - `PartialSimulationData`: Data objet."""
1323 data = PartialSimulationData.from_msg(msg, self.learning)
1324 logger.debug(
1325 f"Rank {self.rank}>> received message "
1326 f"from sim-id={data.simulation_id}, "
1327 f"time-step={data.time_step}, "
1328 f"client-rank={data.client_rank}, "
1329 f"vect-size={len(data.data)}"
1330 )
1331 return data
    def __process_simulation_completion(self, simulation: Simulation, force: bool = False) -> None:
        """Finalizes simulation completion and adjusts metadata associated with it.

        ### Parameters
        - **simulation** (`Simulation`): Instance of the simulation to finalize.
        - **force** (`bool`): Set to enforce termination, regardless. Default is `False`."""

        sim_id = simulation.id
        group_id = self._get_group_id_by_simulation(sim_id)
        group = self._groups[group_id]

        # a disconnected simulation has already been finalized (or never joined)
        if not simulation.connected:
            return

        if simulation.has_finished(force):
            # mark disconnected so a duplicate completion is a no-op (guard above)
            simulation.connected = False
            logger.info(
                f"Rank {self.rank}>> sim-id={sim_id} has finished sending time-steps. "
                f"received={simulation.nb_received_time_steps}, "
                f"expected={self.nb_time_steps}"
            )
            # feed the running mean of per-simulation wall time
            self.mtt_simulation_completion.increment(simulation.duration)
            # periodic progress report
            # NOTE(review): the cadence is keyed on nb_finished_groups, which is
            # not incremented in this method — presumably maintained by the
            # caller/`finished_groups` consumer; confirm.
            if self.nb_finished_groups % 100 == 0:
                self._show_insights()

        if group.has_finished():
            self.finished_groups.add(group_id)
1361 def _validate_data(self, simulation_data: PartialSimulationData) -> bool:
1362 """Validates the time step and field of the received simulation data.
1364 ### Parameters
1365 - **simulation_data** (`PartialSimulationData`): The data to validate.
1367 ### Returns
1368 - `bool`: if the data is valid."""
1370 sim_id = simulation_data.simulation_id
1371 time_step = simulation_data.time_step
1372 field = simulation_data.field
1373 group_id = self._get_group_id_by_simulation(sim_id)
1374 simulation = self._groups[group_id].simulations[sim_id]
1376 # apply validation checks
1377 if group_id not in self._groups:
1378 return False
1380 if (
1381 time_step < 0
1382 or (self.time_steps_known and time_step > self.nb_time_steps)
1383 ):
1384 logger.warning(f"Rank {self.rank}>> [BAD] sim-id={sim_id}, time-step={time_step}")
1385 return False
1387 if field != "termination" and field not in self.fields:
1388 logger.warning(f"Rank {self.rank}>> [BAD] sim-id={sim_id}, field=\"{field}\"")
1389 return False
1391 # handle termination messages
1392 if field == "termination":
1394 # modify the time steps received accordingly
1395 if not self.time_steps_known:
1396 self.nb_time_steps = time_step
1398 # termination message sends total time steps as its `time_step`
1399 # value. so make a check on how many time steps are received
1400 # termination could be received prematurely in high-traffic situations
1401 if simulation.nb_received_time_steps != self.nb_time_steps:
1402 logger.warning(
1403 f"Received termination from sim-id={sim_id} prematurely."
1404 )
1405 else:
1406 self.__process_simulation_completion(simulation)
1407 return False
1409 self.__process_simulation_completion(simulation)
1411 return True
1413 def __determine_and_process_simulation_data(self, simulation_data: PartialSimulationData
1414 ) -> Optional[Union[SimulationData,
1415 PartialSimulationData]]:
1416 """Determines the status of the simulation data and handles actions accordingly.
1418 ### Parameters
1419 - **simulation_data** (`PartialSimulationData`): The incoming simulation data to process.
1421 ### Returns
1422 - `Optional[Union[SimulationData, PartialSimulationData]]`:
1423 return of the `_check_simulation_data` method."""
1425 sim_id = simulation_data.simulation_id
1426 time_step = simulation_data.time_step
1427 group_id = self._get_group_id_by_simulation(sim_id)
1428 group = self._groups[group_id]
1429 simulation = group.simulations[sim_id]
1430 simulation.connected = True
1432 # check simulation data status
1433 sim_status, sim_data = self._check_simulation_data(simulation, simulation_data)
1434 if sim_status is SimulationDataStatus.COMPLETE:
1435 logger.debug(
1436 f"Rank {self.rank}>> sim-id={sim_id}, time-step={time_step} assembled."
1437 )
1438 elif sim_status is SimulationDataStatus.ALREADY_RECEIVED:
1439 logger.warning(
1440 f"Rank {self.rank}>> [Duplicated] sim-id={sim_id}, time-step={time_step}"
1441 )
1442 if sim_status in [SimulationDataStatus.COMPLETE, SimulationDataStatus.EMPTY]:
1443 self.__process_simulation_completion(simulation)
1445 if group.has_finished():
1446 self.finished_groups.add(group_id)
1448 return sim_data
1450 def _handle_simulation_data(self,
1451 msg: bytes) -> Optional[
1452 Union[SimulationData,
1453 PartialSimulationData]]:
1454 """This method handles the following tasks:
1455 1. **Deserialization**: Converts the incoming byte stream
1456 into a `PartialSimulationData` object.
1457 2. **Validation**: Ensures the data is valid based on:
1458 - Time step being within the allowed range.
1459 - Field name being recognized.
1460 3. **Simulation Data Handling**:
1461 - Updates the status of the simulation based on the received data.
1462 - Detects and logs duplicate messages.
1463 4. **Completion Check**:
1464 - Marks the simulation as finished if all data is received.
1465 - Updates the count of finished simulations.
1467 ### Parameters
1468 - **msg** (`bytes`): A serialized message containing simulation data.
1470 ### Returns
1471 - `Optional[PartialSimulationData]`:
1472 - `PartialSimulationData`, if successful.
1473 - `None`, if the message fails validation."""
1475 data = self.__deserialize_message(msg)
1476 if self._validate_data(data):
1477 return self.__determine_and_process_simulation_data(data)
1478 return None
# =====================================Abstract Methods=====================================

    @abstractmethod
    def _server_online(self) -> None:
        """An abstract method where user controls the data handling while server is online.
        Unique to melissa flavors."""
        raise NotImplementedError("Subclasses must override this method.")

    @abstractmethod
    def _server_offline(self) -> None:
        """An abstract method where user controls the data handling while server is offline.
        Unique to melissa flavors."""
        raise NotImplementedError("Subclasses must override this method.")

    @abstractmethod
    def _check_group_size(self) -> None:
        """An abstract method that checks if the group size was correctly set.
        Unique to melissa flavors."""
        raise NotImplementedError("Subclasses must override this method.")

    @abstractmethod
    def _process_partial_data_reception(self,
                                        simulation: Simulation,
                                        simulation_data: PartialSimulationData
                                        ) -> Optional[PartialSimulationData]:
        """Returns a value when data has been partially received.
        Unique to melissa flavors."""
        raise NotImplementedError("Subclass must override this method.")

    @abstractmethod
    def _process_complete_data_reception(self,
                                         simulation: Simulation,
                                         simulation_data: PartialSimulationData
                                         ) -> Union[PartialSimulationData,
                                                    SimulationData]:
        """Returns a value when data has been completely received.
        Unique to melissa flavors."""
        raise NotImplementedError("Subclass must override this method.")

    @abstractmethod
    def _receive(self) -> None:
        """Handles data coming from the server object.
        Unique to melissa flavors."""
        raise NotImplementedError("Subclasses must override this method.")

    @abstractmethod
    def start(self) -> None:
        """The high level organization of server events.
        Unique to melissa flavors."""
        raise NotImplementedError("Subclasses must override this method.")

    @abstractmethod
    def _restart_from_checkpoint(self, **kwargs) -> None:
        """Restarts the server object from a checkpoint.
        Unique to melissa flavors."""
        raise NotImplementedError("Subclasses must override this method.")

    @abstractmethod
    def _checkpoint(self, **kwargs) -> None:
        """Checkpoint the server object.
        Unique to melissa flavors."""
        raise NotImplementedError("Subclasses must override this method.")