Coverage for melissa/server/simulation.py: 92%
120 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-03-10 22:25 +0100
1"""This script defines classes associated with simulations, groups, and their metadata."""
3import logging
4import struct
5import time
6from dataclasses import dataclass
7from enum import Enum
8from typing import Any, Dict, List, Optional, Union, Tuple
10import numpy as np
11from numpy.typing import NDArray, ArrayLike
# Module-level logger, named after this module.
logger: logging.Logger = logging.getLogger(name=__name__)
@dataclass
class SimulationData:
    """Container for the data of one simulation at one time step.

    ### Attributes
    - **simulation_id** (`int`): Identifier of the simulation.
    - **time_step** (`int`): Time step the data belongs to.
    - **data** (`Dict[str, Any]`): The simulation's data, keyed by field name.
    - **parameters** (`list`): Parameters of the simulation."""

    simulation_id: int
    time_step: int
    data: Dict[str, Any]
    parameters: List

    def __repr__(self) -> str:
        """Return a compact `<SimulationData sim-id=… time-step=…>` summary."""
        cls_name = type(self).__name__
        return f"<{cls_name} sim-id={self.simulation_id} time-step={self.time_step}>"
class PartialSimulationData:
    """Stores partial data for a specific simulation,
    including information on time step, client rank, and field.

    ### Attributes
    - **time_step** (`int`): The time step for the simulation.
    - **simulation_id** (`int`): The id of the simulation.
    - **client_rank** (`int`): The rank of the client submitting the data.
    - **data_size** (`int`): The size of the data associated with the simulation.
    - **field** (`str`): The field to which the data belongs.
    - **data** (`Any`): The actual data associated with the simulation."""

    # Fixed width, in bytes, of the NUL-padded field name in the message header.
    MAX_FIELD_NAME_SIZE = 128

    def __init__(
        self,
        time_step: int,
        simulation_id: int,
        client_rank: int,
        data_size: int,
        field: str,
        data: Any,
    ) -> None:
        self.time_step = time_step
        self.simulation_id = simulation_id
        self.client_rank = client_rank
        self.data_size = data_size
        self.field = field
        self.data = data

    @classmethod
    def from_msg(cls,
                 msg: bytes,
                 learning: int) -> "PartialSimulationData":
        """Class method to deserialize a message and create an instance of the class.

        The message starts with a header of four native-order `int`s
        (time step, simulation id, client rank, data size) followed by a
        `MAX_FIELD_NAME_SIZE`-byte, NUL-padded UTF-8 field name. Every byte
        after the header is interpreted as the data array; `data_size` is
        stored as-is and is not used to slice the payload.

        ### Parameters
        - **msg** (`bytes`): The serialized message in bytes format to be deserialized.
        - **learning** (`int`): Selects the payload dtype: `float32` when
        `learning > 0` (DL), `float64` otherwise (SA).

        ### Returns
        - `PartialSimulationData`: A new instance of the class created from the
        deserialized message.

        ### Raises
        - `struct.error`: if `msg` is shorter than the header.
        - `ValueError`: if the payload length is not a multiple of the dtype size."""
        size_metadata: int = 4 * 4 + cls.MAX_FIELD_NAME_SIZE

        # unpack metadata directly from the buffer (no intermediate slice copy)
        raw_field: bytes
        time_step, simulation_id, client_rank, data_size, raw_field = struct.unpack_from(
            f"4i{cls.MAX_FIELD_NAME_SIZE}s", msg
        )
        # the name is NUL-padded to a fixed width: keep what precedes the first NUL
        field_name: str = raw_field.split(b"\x00", 1)[0].decode("utf-8")

        # unpack data array (float32 for DL and float64 for SA)
        dtype = np.dtype(np.float32 if learning > 0 else np.float64)
        data: ArrayLike = np.frombuffer(msg, dtype=dtype, offset=size_metadata)

        return cls(time_step, simulation_id, client_rank, data_size, field_name, data)

    def __repr__(self) -> str:
        """Returns a string representation of the `PartialSimulationData` object."""
        return (
            f"<{self.__class__.__name__}: sim-id={self.simulation_id}, "
            f"time-step={self.time_step}, field=\"{self.field}\">"
        )
class SimulationDataStatus(Enum):
    """Enumerates the statuses a piece of simulation data can be in."""

    PARTIAL = 0
    COMPLETE = 1
    ALREADY_RECEIVED = 2
    EMPTY = 3
class SimulationStatus(Enum):
    """Enumerates the states a simulation can be in."""

    CONNECTED = 0
    RUNNING = 1
    FINISHED = 2
class Simulation:
    """Represents a single simulation with associated metadata. Each object corresponds
    to a unique simulation or client and contains information necessary for tracking
    the simulation's state, parameters, and fault tolerance.

    ### Attributes
    - **id_** (`int`): The unique identifier for the simulation.
    - **nb_time_steps** (`int`): The total number of time steps for the simulation.
    - **fields** (`List[str]`): A list of fields that will be sent.
    - **parameters** (`List[Any]`): A list of parameters used in the simulation."""

    def __init__(self,
                 id_: int,
                 nb_time_steps: int,
                 fields: List[str],
                 parameters: List[Any]) -> None:
        self.id = id_
        self.nb_time_steps = nb_time_steps
        self.fields = fields
        # client_rank -> time_step -> field -> int (SA) or PartialSimulationData (DL);
        # a field maps to None until data has been stored for it.
        self.received_simulation_data: Dict[
            int, Dict[int, Dict[str, Optional[Union[int, PartialSimulationData]]]]
        ] = {}
        # client_rank -> boolean matrix indexed by (field index, time step)
        self.received_time_steps: Dict[int, NDArray[np.bool_]] = {}
        self.parameters: List[Any] = parameters
        self.nb_received_time_steps: int = 0
        self.nb_failures: int = 0
        # time.time() of the most recent mark_as_received call
        self.last_message: Optional[float] = None
        self.duration: float = -1.
        self.t0: float = -1.

    def init_structures(self, client_rank: int, time_steps_known: bool = True) -> None:
        """Initializes data structures to track received simulation data for a given client rank.

        Structures that already exist for `client_rank` are left untouched, so
        calling this again never discards already-recorded progress. The first
        call on this simulation also records the start time `t0`.

        ### Parameters
        - **client_rank** (`int`): The rank of the client whose data is being tracked.
        - **time_steps_known** (`bool`, optional): if the total number of time steps
        is known in advance. Default is `True`. When `False`, the received matrix
        starts with a single column and is grown by `time_step_expansion`."""
        if client_rank not in self.received_simulation_data:
            self.received_simulation_data[client_rank] = {}

        if client_rank not in self.received_time_steps:
            nb_columns = self.nb_time_steps if time_steps_known else 1
            self.received_time_steps[client_rank] = np.zeros(
                shape=(len(self.fields), nb_columns),
                dtype=bool
            )

        if self.t0 < 0:
            self.t0 = time.time()

    def init_data_storage(self, client_rank: int, time_step: int) -> None:
        """Prepares storage for tracking field-level data
        for a specific time step of a client rank.

        Every field starts as `None`; existing storage for the time step is kept.
        Raises `KeyError` if `client_rank` was not initialized with `init_structures`.

        ### Parameters
        - **client_rank** (`int`): The rank of the client whose data is being tracked.
        - **time_step** (`int`): The time step for which data storage is initialized."""
        per_time_step = self.received_simulation_data[client_rank]
        if time_step not in per_time_step:
            per_time_step[time_step] = {f: None for f in self.fields}

    def time_step_expansion(self, client_rank: int, time_step: int) -> None:
        """Dynamically expands the `received_time_steps` matrix for a client rank
        if the time step exceeds current capacity.

        Enough all-`False` columns are appended for `time_step` to become a
        valid column index, even when it lies several steps beyond the current
        capacity. Nothing happens when the matrix is already wide enough.

        ### Parameters
        - **client_rank** (`int`): The rank of the client whose matrix needs expansion.
        - **time_step** (`int`): The time step that triggered the need for expansion."""
        received = self.received_time_steps[client_rank]
        nb_missing = time_step + 1 - received.shape[1]
        if nb_missing > 0:
            self.received_time_steps[client_rank] = np.concatenate(
                [received, np.zeros((received.shape[0], nb_missing), dtype=bool)],
                axis=1,
            )

    def update(self,
               client_rank: int,
               time_step: int,
               field: str,
               data: Optional[Union[int, "PartialSimulationData"]]) -> None:
        """Updates the data associated with a specific field and time step for a client rank.

        ### Parameters
        - **client_rank** (`int`): The rank of the client whose data is being updated.
        - **time_step** (`int`): The time step associated with the data.
        - **field** (`str`): The specific field being updated.
        - **data** (`Optional[Union[int, PartialSimulationData]]`):
        The new data to store for the field."""
        self.received_simulation_data[client_rank][time_step][field] = data

    def get_data(self,
                 client_rank: int,
                 time_step: int) -> Dict[str, Optional[Union[int, "PartialSimulationData"]]]:
        """Retrieves the data for all fields associated with a specific time step of a client rank.

        ### Parameters
        - **client_rank** (`int`): The rank of the client whose data is being retrieved.
        - **time_step** (`int`): The time step for which data is being fetched.

        ### Returns
        - `Dict[str, Optional[Union[int, PartialSimulationData]]]`:
        The stored dictionary itself (not a copy), mapping field names to their data."""
        return self.received_simulation_data[client_rank][time_step]

    def clear_data(self,
                   client_rank: int,
                   time_step: int) -> None:
        """Clears the data for all fields associated with a specific time step of a client rank.
        Useful when the complete data is given to the server for post-processing and no longer needs
        to be in this object. This is to avoid duplications in checkpointing.

        The field keys are kept; only their values are reset to `None`.

        ### Parameters
        - **client_rank** (`int`): The rank of the client whose data is being cleaned.
        - **time_step** (`int`): The time step for which data is being cleaned."""
        ref = self.received_simulation_data[client_rank][time_step]
        for field in ref:
            ref[field] = None

    # TODO: remove
    def crashed(self) -> bool:
        """Unused."""
        return False

    def has_already_received(self,
                             client_rank: int,
                             time_step: int,
                             field: str) -> bool:
        """Checks if the given time step has already been received
        for the specified client and field, helping to avoid duplication of data.

        ### Parameters
        - **client_rank** (`int`): The rank of the client whose data is being retrieved.
        - **time_step** (`int`): The time step for which the data is being checked.
        - **field** (`str`): The field associated with the data being checked.

        ### Returns
        - `bool`: if the data has already been received."""
        field_idx = self.fields.index(field)
        # convert numpy.bool_ to a plain bool to honour the annotation
        return bool(self.received_time_steps[client_rank][field_idx, time_step])

    def is_complete(self, time_step: int) -> bool:
        """Checks if the given time step has data for all the defined fields.

        Every initialized client rank must hold a non-`None` value for every
        field at `time_step`. Returns `True` if no client rank has been initialized.

        ### Parameters
        - **time_step** (`int`): The time step to check for completeness.

        ### Returns
        - `bool`: if data has been received for all defined fields."""
        nb_fields = len(self.fields)
        for client_data in self.received_simulation_data.values():
            if time_step not in client_data:
                return False
            nb_filled = sum(value is not None for value in client_data[time_step].values())
            if nb_filled != nb_fields:
                return False
        return True

    def mark_as_received(self,
                         client_rank: int,
                         time_step: int,
                         field: str) -> None:
        """Marks the given time step as received for the specified field
        and records the current time in `last_message`.

        ### Parameters
        - **client_rank** (`int`): The rank of the client sending the data.
        - **time_step** (`int`): The time step for which the data is received.
        - **field** (`str`): The field associated with the data being marked as received."""
        field_idx = self.fields.index(field)
        self.received_time_steps[client_rank][field_idx, time_step] = True
        self.last_message = time.time()

    def finished(self) -> bool:
        """Checks whether all time steps for the simulation have been received.

        When it returns `True`, `duration` is (re)set to the time elapsed since `t0`.

        ### Returns
        - `bool`: `True` if all time steps have been received."""
        has_finished = self.nb_received_time_steps == self.nb_time_steps

        if has_finished:
            self.duration = time.time() - self.t0

        return has_finished
class Group:
    """Represents a Sobol group with all the relevant simulations.

    This class maintains a caching mechanism for storing corresponding
    time steps received for the current group.

    ### Attributes
    - **group_id** (`int`): The id of the group.
    - **sobol_** (`bool`): A flag indicating whether the group uses Sobol cache.
    (default is False)."""

    def __init__(self,
                 group_id: int,
                 sobol_: bool = False) -> None:
        self.group_id = group_id
        self.simulations: Dict[int, Simulation] = {}
        self.nb_failures: int = 0
        self.sobol_on: bool = sobol_
        # (field, client_rank, time_step) -> {simulation_id: data}
        self.sobol_cache: Dict[Tuple[str, int, int], Dict[int, np.ndarray]] = {}

    def __len__(self) -> int:
        """Returns the number of simulations registered in the group."""
        return len(self.simulations)

    def cache(self, pdata: "PartialSimulationData") -> None:
        """Caches the received simulation data for the
        specified field, simulation id, client rank, and time step.

        Data cached again for the same key and simulation id replaces the previous entry.

        ### Parameters
        - **pdata** (`PartialSimulationData`): The simulation data to be cached."""
        key = (pdata.field, pdata.client_rank, pdata.time_step)
        self.sobol_cache.setdefault(key, {})[pdata.simulation_id] = pdata.data

    def get_cached(self,
                   field: str,
                   client_rank: int,
                   time_step: int) -> NDArray[np.float64]:
        """Retrieves cached group data for the
        specified field, client rank, and time step.

        ### Parameters
        - **field** (`str`): The field associated with the simulation data.
        - **client_rank** (`int`): The rank of the client whose data is being retrieved.
        - **time_step** (`int`): The time step of the simulation.

        ### Returns
        - `NDArray[np.float64]`: The cached data sorted by simulation ids, one
        row per simulation. An empty array is returned when nothing is cached
        for the key, or when not every simulation of the group has contributed yet."""
        key = (field, client_rank, time_step)
        per_simulation = self.sobol_cache.get(key)
        # `or` (not `and`): a missing key must not be dereferenced, and an
        # incomplete entry must not be returned.
        if per_simulation is None or len(per_simulation) != len(self):
            return np.array([])

        return np.array([
            per_simulation[sim_id]
            for sim_id in sorted(per_simulation)
        ])