Coverage for melissa/server/simulation.py: 88%
153 statements
« prev ^ index » next coverage.py v7.10.1, created at 2025-11-03 09:52 +0100
« prev ^ index » next coverage.py v7.10.1, created at 2025-11-03 09:52 +0100
1"""This script defines classes associated with simulations, groups, and their metadata."""
3import logging
4import json
5import struct
6import time
7from dataclasses import dataclass
8from enum import Enum
9from typing import Any, Dict, List, Optional, Union, Tuple
11import numpy as np
12from numpy.typing import NDArray
14from melissa.utility.metadata import Payload
17logger = logging.getLogger(__name__)
@dataclass
class SimulationData:
    """Bundles the per-field payloads of one time step of one simulation.

    ### Attributes
    - **simulation_id** (`int`): The id of the simulation.
    - **time_step** (`int`): The time step the payloads belong to.
    - **payload** (`Dict[str, Payload]`): Payload (data + metadata) of each field,
    keyed by field name.
    - **parameters** (`list`): The list of parameters for the simulation."""

    simulation_id: int
    time_step: int
    payload: Dict[str, Payload]
    parameters: List

    def __repr__(self) -> str:
        """Returns a compact description: class name, ids and the field names."""
        field_names = list(self.payload)
        return (
            f"<{type(self).__name__} "
            f"sim-id={self.simulation_id} "
            f"time-step={self.time_step} "
            f"fields={field_names}>"
        )

    def __getitem__(self, field) -> Payload:
        """Returns the payload stored for `field` (raises `KeyError` if absent)."""
        return self.payload[field]
class PartialSimulationData:
    """Holds one message worth of simulation data: a single field sent by a
    single client rank for a single time step.

    ### Attributes
    - **time_step** (`int`): The time step for the simulation.
    - **simulation_id** (`int`): The id of the simulation.
    - **client_rank** (`int`): The rank of the client submitting the data.
    - **data_size** (`int`): The size of the data, as announced in the message header.
    - **external_metadata_size** (`int`): The size in bytes of the user-given metadata.
    - **field** (`str`): The field to which the data belongs.
    - **payload** (`Payload`): The actual data + extra metadata associated with the simulation."""

    # comply with common/melissa_data.h::MELISSA_MAX_FIELD_NAME_LENGTH
    MAX_FIELD_NAME_SIZE = 128

    # Header layout ('<' = little-endian, standard sizes, no padding):
    #   2I -> time_step, simulation_id      (uint32_t)
    #   i  -> client_rank                   (int32)
    #   2Q -> data_size, metadata size      (size_t sent as unsigned long long)
    #   Ns -> field name buffer, N = MAX_FIELD_NAME_SIZE + 1
    metadata_struct_fmt = f"<2Ii2Q{MAX_FIELD_NAME_SIZE + 1}s"

    def __init__(
        self,
        time_step: int,
        simulation_id: int,
        client_rank: int,
        data_size: int,
        external_metadata_size: int,
        field: str,
        payload: Payload
    ) -> None:

        self.payload = payload
        self.field = field
        self.external_metadata_size = external_metadata_size
        self.data_size = data_size
        self.client_rank = client_rank
        self.simulation_id = simulation_id
        self.time_step = time_step

    @classmethod
    def from_msg(cls, msg: bytes, learning: int) -> "PartialSimulationData":
        """Builds an instance from a raw message.

        The message is read as three consecutive parts: the fixed-size header
        described by `metadata_struct_fmt`, an optional JSON document of
        `external_metadata_size` bytes, and the data array filling the rest.

        ### Parameters
        - **msg** (`bytes`): The serialized message to decode.
        - **learning** (`int`): Values `> 0` make the data be read as `float32`,
        anything else as `float64`.

        ### Returns
        - `PartialSimulationData`: The instance described by the message."""

        header_size = struct.calcsize(cls.metadata_struct_fmt)
        (
            time_step,
            simulation_id,
            client_rank,
            data_size,
            external_metadata_size,
            raw_field,
        ) = struct.unpack_from(cls.metadata_struct_fmt, msg, 0)

        # the name sits in a fixed-size buffer: keep what precedes the first NUL
        field_name: str = raw_field.partition(b"\x00")[0].decode("utf-8")

        # user metadata is optional; an absent block yields an empty dict
        external_metadata: Dict[str, Any] = {}
        if external_metadata_size > 0:
            (raw_json,) = struct.unpack_from(
                f"{external_metadata_size}s", msg, offset=header_size
            )
            external_metadata = json.loads(raw_json)

        # FP32 for DL, FP64 for SA
        # (the client already converted the values before sending)
        if learning > 0:
            dtype = np.dtype(np.float32)
        else:
            dtype = np.dtype(np.float64)
        data_offset = header_size + external_metadata_size
        data: NDArray = np.frombuffer(msg, dtype=dtype, offset=data_offset)

        return cls(
            time_step,
            simulation_id,
            client_rank,
            data_size,
            external_metadata_size,
            field_name,
            Payload(data=data, metadata=external_metadata)
        )

    @property
    def data(self) -> NDArray:
        """Shortcut to the data array of the payload."""
        return self.payload.data

    @property
    def metadata(self) -> Dict[str, Any]:
        """Shortcut to the metadata of the payload."""
        return self.payload.metadata

    def __repr__(self) -> str:
        """Returns a string representation of the `PartialSimulationData` object."""
        cls_name = type(self).__name__
        return (
            f"<{cls_name}: sim-id={self.simulation_id}, "
            f'time-step={self.time_step}, field="{self.field}">'
        )
class SimulationDataStatus(Enum):
    """Enumerates the states in which received simulation data can be reported."""

    # integer values are part of the interface: keep them stable
    PARTIAL = 0
    COMPLETE = 1
    ALREADY_RECEIVED = 2
    EMPTY = 3
class SimulationStatus(Enum):
    """Enumerates the states a simulation can be in."""

    # integer values are part of the interface: keep them stable
    CONNECTED = 0
    RUNNING = 1
    FINISHED = 2
class Simulation:
    """Represents a single simulation with associated metadata. Each object corresponds
    to a unique simulation or client and contains information necessary for tracking
    the simulation's state, parameters, and fault tolerance.

    ### Attributes
    - **id** (`int`): The unique identifier for the simulation (constructor argument `id_`).
    - **script_path** (`str`): Path of the simulation script (to be generated).
    - **connected** (`bool`): If the simulation has established a connection with the server.
    - **nb_time_steps** (`int`): The total number of time steps for the simulation.
    Grows through `time_step_expansion` when it is not known in advance.
    - **fields** (`List[str]`): A list of fields that will be sent.
    - **parameters** (`List[Any]`): A list of parameters used in the simulation.
    - **received_simulation_data** (`dict`): client rank -> time step -> field ->
    `None`, an `int` (SA) or a `PartialSimulationData` (DL).
    - **received_time_steps** (`Dict[int, NDArray[np.bool_]]`): Per client rank, a
    `(len(fields), nb_time_steps)` boolean matrix of the (field, time step) pairs received.
    - **nb_received_time_steps** (`int`): Counter compared against `nb_time_steps` by
    `has_finished`; it is not incremented inside this class.
    - **nb_failures** (`int`): Failure counter, initialised to 0.
    - **last_message** (`Optional[float]`): `time.time()` of the last `mark_as_received` call.
    - **finished** (`bool`): Set by `has_finished`.
    - **t0** (`float`): `time.time()` of the first `init_structures` call, `-1.0` before.
    - **duration** (`float`): Seconds between `t0` and the moment the simulation was
    marked finished, `-1.0` while unknown."""

    def __init__(
        self,
        id_: int,
        script_path: str,
        nb_time_steps: int,
        fields: List[str],
        parameters: List[Any],
    ) -> None:

        self.id = id_
        self.script_path: str = script_path
        self.nb_time_steps = nb_time_steps
        self.fields = fields
        self.connected: bool = False
        self.received_simulation_data: Dict[
            int, Dict[int, Dict[str, Optional[Union[int, "PartialSimulationData"]]]]
        ] = {}  # client_rank, time_step, field, int (SA) or PartialSimulationData (DL)
        self.received_time_steps: Dict[int, NDArray[np.bool_]] = {}
        self.parameters: List[Any] = parameters
        self.nb_received_time_steps: int = 0
        self.nb_failures: int = 0
        self.last_message: Optional[float] = None
        self.finished: bool = False
        self.duration: float = -1.0
        self.t0: float = -1.0

    def init_structures(self, client_rank: int, time_steps_known: bool = True) -> None:
        """Initializes data structures to track received simulation data for a given client rank.

        ### Parameters
        - **client_rank** (`int`): The rank of the client whose data is being tracked.
        - **time_steps_known** (`bool`, optional): if the total number of time steps
        is known in advance. Default is `True`."""

        # allocate only the first time a rank shows up, so that calling this
        # again for the same rank does not wipe what was already recorded
        if client_rank not in self.received_simulation_data:
            self.received_simulation_data[client_rank] = {}
            # one column per time step when the count is known, otherwise a
            # single column that `time_step_expansion` grows on demand
            shape = (len(self.fields), self.nb_time_steps if time_steps_known else 1)
            self.received_time_steps[client_rank] = np.zeros(shape=shape, dtype=bool)

        # start the clock on the very first call only
        if self.t0 < 0:
            self.t0 = time.time()

    def init_data_storage(self, client_rank: int, time_step: int) -> None:
        """Prepares storage for tracking field-level data
        for a specific time step of a client rank.

        ### Parameters
        - **client_rank** (`int`): The rank of the client whose data is being tracked.
        - **time_step** (`int`): The time step for which data storage is initialized."""

        if time_step not in self.received_simulation_data[client_rank]:
            self.received_simulation_data[client_rank][time_step] = {
                f: None for f in self.fields
            }

    def time_step_expansion(self, client_rank: int, time_step: int) -> None:
        """Dynamically expands the `received_time_steps` matrix for a client rank
        if the time step exceeds current capacity.

        The matrix is grown by as many columns as needed for `time_step` to be a
        valid index, so a time step arriving ahead of its predecessors is handled too.

        ### Parameters
        - **client_rank** (`int`): The rank of the client whose matrix needs expansion.
        - **time_step** (`int`): The time step that triggered the need for expansion."""

        flags = self.received_time_steps[client_rank]
        nb_missing = time_step + 1 - flags.shape[1]
        if nb_missing > 0:
            self.received_time_steps[client_rank] = np.concatenate(
                [
                    flags,
                    np.zeros((len(self.fields), nb_missing), dtype=bool),
                ],
                axis=1,
            )

        self.nb_time_steps = max(self.nb_time_steps, time_step + 1)

    def update(
        self,
        client_rank: int,
        time_step: int,
        field: str,
        data: Optional[Union[int, "PartialSimulationData"]],
    ) -> None:
        """Updates the data associated with a specific field and time step for a client rank.
        The storage must have been created with `init_data_storage` beforehand.

        ### Parameters
        - **client_rank** (`int`): The rank of the client whose data is being updated.
        - **time_step** (`int`): The time step associated with the data.
        - **field** (`str`): The specific field being updated.
        - **data** (`Optional[Union[int, PartialSimulationData]]`):
        The new data to store for the field."""
        self.received_simulation_data[client_rank][time_step][field] = data

    def get_data(
        self, client_rank: int, time_step: int
    ) -> Dict[str, Optional[Union[int, "PartialSimulationData"]]]:
        """Retrieves the data for all fields associated with a specific time step of a client rank.

        ### Parameters
        - **client_rank** (`int`): The rank of the client whose data is being retrieved.
        - **time_step** (`int`): The time step for which data is being fetched.

        ### Returns
        - `Dict[str, Optional[Union[int, PartialSimulationData]]]`:
        A dictionary containing field names as keys and their corresponding data as values.
        This is the stored dictionary itself, not a copy."""
        return self.received_simulation_data[client_rank][time_step]

    def clear_data(self, client_rank: int, time_step: int) -> None:
        """Clears the data for all fields associated with a specific time step of a client rank.
        Useful when the complete data is given to the server for post-processing and no longer needs
        to be in this object. This is to avoid duplications in checkpointing.

        ### Parameters
        - **client_rank** (`int`): The rank of the client whose data is being cleaned.
        - **time_step** (`int`): The time step for which data is being cleaned."""

        stored = self.received_simulation_data[client_rank][time_step]
        # keep the keys, drop the values: the entry stays but holds no data
        for field in stored:
            stored[field] = None

    def has_already_received(
        self, client_rank: int, time_step: int, field: str
    ) -> bool:
        """Checks if the given time step has already been received
        for the specified client and field, helping to avoid duplication of data.

        ### Parameters
        - **client_rank** (`int`): The rank of the client whose data is being retrieved.
        - **time_step** (`int`): The time step for which the data is being checked.
        - **field** (`str`): The field associated with the data being checked.

        ### Returns
        - `bool`: if the data has already been received."""

        field_idx = self.fields.index(field)
        # indexing yields a numpy bool; convert to honour the annotated type
        return bool(self.received_time_steps[client_rank][field_idx, time_step])

    def is_complete(self, time_step: int) -> bool:
        """Checks if the given time step has data for all the defined fields,
        for every client rank known so far.

        ### Parameters
        - **time_step** (`int`): The time step to check for completeness.

        ### Returns
        - `bool`: if data has been received for all defined fields."""

        nb_fields = len(self.fields)
        for client_data in self.received_simulation_data.values():
            if time_step not in client_data:
                return False
            nb_filled = sum(
                value is not None for value in client_data[time_step].values()
            )
            if nb_filled != nb_fields:
                return False
        return True

    def mark_as_received(self, client_rank: int, time_step: int, field: str) -> None:
        """Marks the given time step as received for the specified field
        and records the current time in `last_message`.

        ### Parameters
        - **client_rank** (`int`): The rank of the client sending the data.
        - **time_step** (`int`): The time step for which the data is received.
        - **field** (`str`): The field associated with the data being marked as received.
        """

        field_idx = self.fields.index(field)
        self.received_time_steps[client_rank][field_idx, time_step] = True
        self.last_message = time.time()

    def _update_duration(self) -> None:
        """Stores the time elapsed since `t0`, if the clock was ever started."""
        # t0 stays negative until `init_structures` ran; without that guard the
        # "duration" would be the current epoch time plus one second
        if self.t0 >= 0:
            self.duration = time.time() - self.t0

    def has_finished(self, force: bool = False) -> bool:
        """Checks whether all time steps for the simulation have been received,
        or forcefully marks as finished, if requested.

        ### Parameters
        - **force** (`bool`, optional): mark the simulation as finished no matter
        how many time steps were received. Default is `False`.

        ### Returns
        - `bool`: `True` if all time steps have been received or `force` is `True`."""

        if force:
            self.finished = True
            self._update_duration()
            return True

        if (
            not self.finished
            and self.nb_time_steps > 0
            and self.nb_received_time_steps == self.nb_time_steps
        ):
            self.finished = True
            self._update_duration()

        return self.finished
class Group:
    """Represents a Sobol group with all the relevant simulations.

    This class maintains a caching mechanism for storing corresponding
    time steps received for the current group.

    ### Attributes
    - **group_id** (`int`): The id of the group.
    - **simulations** (`Dict[int, Simulation]`): The simulations of the group,
    keyed by an integer id; filled by the caller.
    - **submitted** (`bool`): Initialised to `False`; maintained by the caller.
    - **nb_failures** (`int`): Failure counter, initialised to 0.
    - **sobol_on** (`bool`): A flag indicating whether the group uses Sobol cache
    (constructor argument `sobol_`, default is False).
    - **sobol_cache** (`dict`): (field, client rank, time step) -> simulation id -> data.
    - **finished** (`bool`): Result of the last `has_finished` call."""

    def __init__(self, group_id: int, sobol_: bool = False) -> None:

        self.group_id = group_id
        self.simulations: Dict[int, "Simulation"] = {}
        self.submitted: bool = False
        self.nb_failures: int = 0
        self.sobol_on: bool = sobol_
        # always present, whatever the value of `sobol_on`
        self.sobol_cache: Dict[Tuple[str, int, int], Dict[int, np.ndarray]] = {}
        self.finished: bool = False

    def __len__(self) -> int:
        """Returns length of the current group."""
        return len(self.simulations)

    def cache(self, pdata: "PartialSimulationData") -> None:
        """Caches the received simulation data for the
        specified field, simulation id, client rank, and time step.
        Data already cached for the same simulation id is overwritten.

        ### Parameters
        - **pdata** (`PartialSimulationData`): The simulation data to be cached."""

        key = (pdata.field, pdata.client_rank, pdata.time_step)
        self.sobol_cache.setdefault(key, {})[pdata.simulation_id] = pdata.data

    def get_cached(
        self, field: str, client_rank: int, time_step: int
    ) -> NDArray[np.float64]:
        """Retrieves cached group data for the
        specified field, client rank, and time step.

        ### Parameters
        - **field** (`str`): The field associated with the simulation data.
        - **client_rank** (`int`): The rank of the client whose data is being retrieved.
        - **time_step** (`int`): The time step of the simulation.

        ### Returns
        - `NDArray[np.float64]`: The cached data sorted by simulation ids, or an
        empty array while nothing is cached for the key or while the number of
        cached simulations differs from the size of the group."""

        cached = self.sobol_cache.get((field, client_rank, time_step))
        # either condition alone means the group data is not ready yet
        if cached is None or len(cached) != len(self):
            return np.array([])

        return np.array([cached[sim_id] for sim_id in sorted(cached)])

    def has_finished(self) -> bool:
        """Marks the group as finished if all simulations have finished.

        ### Returns
        - `bool`: if every simulation of the group has finished."""
        # build the full list first: `Simulation.has_finished` updates the
        # simulation it is called on, so none must be skipped by short-circuiting
        states = [sim.has_finished() for sim in self.simulations.values()]
        self.finished = all(states)
        return self.finished

    def has_connected(self) -> bool:
        """Returns if all simulations of the group have connected."""
        return all(sim.connected for sim in self.simulations.values())