Coverage for melissa/server/simulation.py: 88%

153 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2025-11-03 09:52 +0100

1"""This script defines classes associated with simulations, groups, and their metadata.""" 

2 

3import logging 

4import json 

5import struct 

6import time 

7from dataclasses import dataclass 

8from enum import Enum 

9from typing import Any, Dict, List, Optional, Union, Tuple 

10 

11import numpy as np 

12from numpy.typing import NDArray 

13 

14from melissa.utility.metadata import Payload 

15 

16 

# Module-wide logger, named after this module so its records can be filtered by package.
logger: logging.Logger = logging.getLogger(__name__)

18 

19 

@dataclass
class SimulationData:
    """Complete data of one simulation at one time step, gathered over all fields.

    ### Attributes
    - **simulation_id** (`int`): Identifier of the simulation the data comes from.
    - **time_step** (`int`): Time step the data corresponds to.
    - **payload** (`Dict[str, Payload]`): Payload (data + metadata) keyed by field name.
    - **parameters** (`list`): Parameters associated with the simulation."""

    simulation_id: int
    time_step: int
    payload: Dict[str, Payload]
    parameters: List

    def __repr__(self) -> str:
        """Return a compact description: class name, simulation id, time step and field names."""
        field_names = list(self.payload)
        return (
            f"<{type(self).__name__} "
            f"sim-id={self.simulation_id} "
            f"time-step={self.time_step} "
            f"fields={field_names}>"
        )

    def __getitem__(self, field) -> Payload:
        """Return the payload stored for `field`; raises `KeyError` when the field is absent."""
        return self.payload[field]

48 

49 

class PartialSimulationData:
    """A single message's worth of simulation data: one field, sent by one client rank,
    for one time step of one simulation.

    ### Attributes
    - **time_step** (`int`): The time step for the simulation.
    - **simulation_id** (`int`): The id of the simulation.
    - **client_rank** (`int`): The rank of the client submitting the data.
    - **data_size** (`int`): The data size announced in the message header.
    - **external_metadata_size** (`int`): The size in bytes of the user-given JSON metadata.
    - **field** (`str`): The field to which the data belongs.
    - **payload** (`Payload`): The actual data + extra metadata associated with the simulation."""

    # comply with common/melissa_data.h::MELISSA_MAX_FIELD_NAME_LENGTH
    MAX_FIELD_NAME_SIZE = 128
    # Fixed-size message header ('<' = little-endian, standard sizes, no alignment padding):
    #   'I' = uint32_t            -> time_step, simulation_id
    #   'i' = int32               -> client_rank
    #   'Q' = 8-byte unsigned     -> data_size, external_metadata_size (size_t on the sender)
    #   '{MAX_FIELD_NAME_SIZE + 1}s' -> field name in a fixed-width, NUL-terminated slot
    metadata_struct_fmt = f"<2Ii2Q{MAX_FIELD_NAME_SIZE + 1}s"

    def __init__(
        self,
        time_step: int,
        simulation_id: int,
        client_rank: int,
        data_size: int,
        external_metadata_size: int,
        field: str,
        payload: Payload
    ) -> None:
        # identification of the message
        self.simulation_id = simulation_id
        self.time_step = time_step
        self.client_rank = client_rank
        self.field = field
        # sizes taken from the header
        self.data_size = data_size
        self.external_metadata_size = external_metadata_size
        # data array + user metadata
        self.payload = payload

    @classmethod
    def from_msg(cls, msg: bytes, learning: int) -> "PartialSimulationData":
        """Deserialize a raw client message into a new instance.

        The buffer is read as: the fixed-size header described by `metadata_struct_fmt`,
        then `external_metadata_size` bytes of JSON metadata (possibly none), then the
        data array occupying the rest of the buffer.

        ### Parameters
        - **msg** (`bytes`): The serialized message.
        - **learning** (`int`): Selects the array dtype: `float32` when `> 0` (DL),
          `float64` otherwise (SA).

        ### Returns
        - `PartialSimulationData`: A new instance built from the decoded message."""

        header_size = struct.calcsize(cls.metadata_struct_fmt)
        header = struct.unpack(cls.metadata_struct_fmt, msg[:header_size])
        (
            time_step,
            simulation_id,
            client_rank,
            data_size,
            external_metadata_size,
            raw_field,
        ) = header
        # keep only the bytes before the first NUL of the fixed-width name slot
        field_name: str = raw_field.split(b"\x00")[0].decode("utf-8")

        external_metadata: Dict[str, Any]
        if external_metadata_size <= 0:
            external_metadata = {}
        else:
            (json_bytes,) = struct.unpack_from(
                f"{external_metadata_size}s", msg, offset=header_size
            )
            external_metadata = json.loads(json_bytes)

        # FP32 for DL, FP64 for SA (already converted on the client side before sending)
        if learning > 0:
            dtype = np.dtype(np.float32)
        else:
            dtype = np.dtype(np.float64)
        # np.frombuffer does not copy: the array is a view over the message bytes
        data: NDArray = np.frombuffer(
            msg, dtype=dtype, offset=header_size + external_metadata_size
        )

        return cls(
            time_step,
            simulation_id,
            client_rank,
            data_size,
            external_metadata_size,
            field_name,
            Payload(data=data, metadata=external_metadata),
        )

    @property
    def data(self) -> NDArray:
        """The data array carried by the payload."""
        return self.payload.data

    @property
    def metadata(self) -> Dict[str, Any]:
        """The user-given metadata carried by the payload."""
        return self.payload.metadata

    def __repr__(self) -> str:
        """Return a compact description: class name, simulation id, time step and field."""
        name = type(self).__name__
        return (
            f"<{name}: sim-id={self.simulation_id}, "
            f'time-step={self.time_step}, field="{self.field}">'
        )

163 

164 

class SimulationDataStatus(Enum):
    """Possible statuses of simulation data.

    Members, in definition order: `PARTIAL` (0), `COMPLETE` (1),
    `ALREADY_RECEIVED` (2), `EMPTY` (3)."""

    PARTIAL, COMPLETE, ALREADY_RECEIVED, EMPTY = range(4)

172 

173 

class SimulationStatus(Enum):
    """Possible states of a simulation.

    Members, in definition order: `CONNECTED` (0), `RUNNING` (1), `FINISHED` (2)."""

    CONNECTED, RUNNING, FINISHED = range(3)

180 

181 

class Simulation:
    """Represents a single simulation with associated metadata. Each object corresponds
    to a unique simulation or client and contains information necessary for tracking
    the simulation's state, parameters, and fault tolerance.

    ### Attributes
    - **id** (`int`): The unique identifier for the simulation (constructor argument `id_`).
    - **script_path** (`str`): Path of the simulation script (to be generated).
    - **connected** (`bool`): If the simulation has established a connection with the server.
    - **nb_time_steps** (`int`): The total number of time steps for the simulation;
      grown by `time_step_expansion` when the count is not known in advance.
    - **fields** (`List[str]`): A list of fields that will be sent.
    - **parameters** (`List[Any]`): A list of parameters used in the simulation.
    - **received_simulation_data**: `client_rank -> time_step -> field -> data`, where a
      value is `None` until stored with `update`.
    - **received_time_steps**: `client_rank -> bool matrix` of shape
      `(len(fields), nb_tracked_time_steps)` flagging received (field, time step) pairs.
    - **nb_received_time_steps** (`int`): compared with `nb_time_steps` in `has_finished`;
      it is not incremented inside this class.
    - **last_message** (`Optional[float]`): wall-clock time of the last `mark_as_received`.
    - **finished** (`bool`), **duration** (`float`, `-1.0` until measured),
      **t0** (`float`, `-1.0` until the first client rank is initialized)."""

    def __init__(
        self,
        id_: int,
        script_path: str,
        nb_time_steps: int,
        fields: List[str],
        parameters: List[Any],
    ) -> None:

        self.id = id_
        self.script_path: str = script_path
        self.nb_time_steps = nb_time_steps
        self.fields = fields
        self.connected: bool = False
        # client_rank -> time_step -> field -> int (SA) or PartialSimulationData (DL)
        self.received_simulation_data: Dict[
            int, Dict[int, Dict[str, Optional[Union[int, "PartialSimulationData"]]]]
        ] = {}
        # client_rank -> bool matrix of shape (len(fields), nb tracked time steps)
        self.received_time_steps: Dict[int, NDArray[np.bool_]] = {}
        self.parameters: List[Any] = parameters
        self.nb_received_time_steps: int = 0
        self.nb_failures: int = 0
        self.last_message: Optional[float] = None
        self.finished: bool = False
        self.duration: float = -1.0
        self.t0: float = -1.0

    def init_structures(self, client_rank: int, time_steps_known: bool = True) -> None:
        """Initializes data structures to track received simulation data for a given client rank.

        ### Parameters
        - **client_rank** (`int`): The rank of the client whose data is being tracked.
        - **time_steps_known** (`bool`, optional): if the total number of time steps
          is known in advance. Default is `True`."""

        # ensure client rank is initialized (only once, so received marks are kept)
        if client_rank not in self.received_simulation_data:
            self.received_simulation_data[client_rank] = {}
            # Known sample count -> full-width matrix; unknown -> one column, grown on demand
            shape = (len(self.fields), self.nb_time_steps if time_steps_known else 1)
            self.received_time_steps[client_rank] = np.zeros(shape=shape, dtype=bool)

        # start the clock on the first initialized client rank
        if self.t0 < 0:
            self.t0 = time.time()

    def init_data_storage(self, client_rank: int, time_step: int) -> None:
        """Prepares storage for tracking field-level data
        for a specific time step of a client rank.

        Existing storage for that time step is left untouched.

        ### Parameters
        - **client_rank** (`int`): The rank of the client whose data is being tracked.
        - **time_step** (`int`): The time step for which data storage is initialized."""

        if time_step not in self.received_simulation_data[client_rank]:
            self.received_simulation_data[client_rank][time_step] = {
                f: None for f in self.fields
            }

    def time_step_expansion(self, client_rank: int, time_step: int) -> None:
        """Dynamically expands the `received_time_steps` matrix for a client rank
        so that column `time_step` exists, and raises `nb_time_steps` accordingly.

        The matrix is padded with as many `False` columns as needed, so time steps
        that skip ahead by more than one are handled as well.

        ### Parameters
        - **client_rank** (`int`): The rank of the client whose matrix needs expansion.
        - **time_step** (`int`): The time step that triggered the need for expansion."""

        tracked = self.received_time_steps[client_rank]
        missing = time_step + 1 - tracked.shape[1]
        if missing > 0:
            self.received_time_steps[client_rank] = np.concatenate(
                [tracked, np.zeros((len(self.fields), missing), dtype=bool)],
                axis=1,
            )

        self.nb_time_steps = max(self.nb_time_steps, time_step + 1)

    def update(
        self,
        client_rank: int,
        time_step: int,
        field: str,
        data: Optional[Union[int, "PartialSimulationData"]],
    ) -> None:
        """Updates the data associated with a specific field and time step for a client rank.

        ### Parameters
        - **client_rank** (`int`): The rank of the client whose data is being updated.
        - **time_step** (`int`): The time step associated with the data.
        - **field** (`str`): The specific field being updated.
        - **data** (`Optional[Union[int, PartialSimulationData]]`):
          The new data to store for the field."""
        self.received_simulation_data[client_rank][time_step][field] = data

    def get_data(
        self, client_rank: int, time_step: int
    ) -> Dict[str, Optional[Union[int, "PartialSimulationData"]]]:
        """Retrieves the data for all fields associated with a specific time step of a client rank.

        ### Parameters
        - **client_rank** (`int`): The rank of the client whose data is being retrieved.
        - **time_step** (`int`): The time step for which data is being fetched.

        ### Returns
        - `Dict[str, Optional[Union[int, PartialSimulationData]]]`:
          The stored dictionary itself (not a copy), mapping field names to their data."""
        return self.received_simulation_data[client_rank][time_step]

    def clear_data(self, client_rank: int, time_step: int) -> None:
        """Clears the data for all fields associated with a specific time step of a client rank.
        Useful when the complete data is given to the server for post-processing and no longer needs
        to be in this object. This is to avoid duplications in checkpointing.

        The field keys are kept; only their values are reset to `None`.

        ### Parameters
        - **client_rank** (`int`): The rank of the client whose data is being cleaned.
        - **time_step** (`int`): The time step for which data is being cleaned."""

        ref = self.received_simulation_data[client_rank][time_step]
        for field in ref:
            ref[field] = None

    def has_already_received(
        self, client_rank: int, time_step: int, field: str
    ) -> bool:
        """Checks if the given time step has already been received
        for the specified client and field, helping to avoid duplication of data.

        ### Parameters
        - **client_rank** (`int`): The rank of the client whose data is being retrieved.
        - **time_step** (`int`): The time step for which the data is being checked.
        - **field** (`str`): The field associated with the data being checked.

        ### Returns
        - `bool`: if the data has already been received.

        ### Raises
        - `ValueError`: if `field` is not one of `fields`."""

        field_idx = self.fields.index(field)
        # convert numpy's bool_ into a real Python bool, as annotated
        return bool(self.received_time_steps[client_rank][field_idx, time_step])

    def is_complete(self, time_step: int) -> bool:
        """Checks if the given time step has data for all the defined fields,
        for every initialized client rank.

        ### Parameters
        - **time_step** (`int`): The time step to check for completeness.

        ### Returns
        - `bool`: if data has been received for all defined fields
          (vacuously `True` when no client rank has been initialized)."""

        nb_fields = len(self.fields)
        for client_data in self.received_simulation_data.values():
            if time_step not in client_data:
                return False
            nb_set = sum(1 for value in client_data[time_step].values() if value is not None)
            if nb_set != nb_fields:
                return False
        return True

    def mark_as_received(self, client_rank: int, time_step: int, field: str) -> None:
        """Marks the given time step as received for the specified field
        and records the current time in `last_message`.

        ### Parameters
        - **client_rank** (`int`): The rank of the client sending the data.
        - **time_step** (`int`): The time step for which the data is received.
        - **field** (`str`): The field associated with the data being marked as received.
        """

        field_idx = self.fields.index(field)
        self.received_time_steps[client_rank][field_idx, time_step] = True
        self.last_message = time.time()

    def has_finished(self, force: bool = False) -> bool:
        """Checks whether all time steps for the simulation have been received,
        or forcefully marks as finished, if requested.

        ### Parameters
        - **force** (`bool`, optional): mark the simulation as finished regardless of
          the received time steps. Default is `False`.

        ### Returns
        - `bool`: `True` if all time steps have been received or `force` is `True`."""

        def update_duration() -> None:
            # Only measurable once `init_structures` has recorded a start time;
            # otherwise `duration` keeps its -1.0 "unknown" value.
            if self.t0 > 0:
                self.duration = time.time() - self.t0

        if force:
            self.finished = True
            update_duration()
            return True

        if (
            not self.finished
            and self.nb_time_steps > 0
            and self.nb_received_time_steps == self.nb_time_steps
        ):
            self.finished = True
            update_duration()

        return self.finished

392 

393 

class Group:
    """Represents a Sobol group with all the relevant simulations.

    This class maintains a caching mechanism for storing corresponding
    time steps received for the current group.

    ### Attributes
    - **group_id** (`int`): The id of the group.
    - **sobol_** (`bool`): A flag indicating whether the group uses Sobol cache
      (default is False); stored as `sobol_on`.
    - **simulations** (`Dict[int, Simulation]`): The simulations of the group.
    - **sobol_cache**: `(field, client_rank, time_step) -> {simulation_id: data}`."""

    def __init__(self, group_id: int, sobol_: bool = False) -> None:

        self.group_id = group_id
        self.simulations: Dict[int, Simulation] = {}
        self.submitted: bool = False
        self.nb_failures: int = 0
        self.sobol_on: bool = sobol_
        # always created (empty), whether or not `sobol_on` is set
        self.sobol_cache: Dict[Tuple[str, int, int], Dict[int, np.ndarray]] = {}
        self.finished: bool = False

    def __len__(self) -> int:
        """Returns length of the current group (its number of simulations)."""
        return len(self.simulations)

    def cache(self, pdata: "PartialSimulationData") -> None:
        """Caches the received simulation data for the
        specified field, simulation id, client rank, and time step.

        A later call with the same key and simulation id overwrites the earlier data.

        ### Parameters
        - **pdata** (`PartialSimulationData`): The simulation data to be cached."""

        key = (pdata.field, pdata.client_rank, pdata.time_step)
        self.sobol_cache.setdefault(key, {})[pdata.simulation_id] = pdata.data

    def get_cached(
        self, field: str, client_rank: int, time_step: int
    ) -> NDArray[np.float64]:
        """Retrieves cached group data for the
        specified field, client rank, and time step.

        ### Parameters
        - **field** (`str`): The field associated with the simulation data.
        - **client_rank** (`int`): The rank of the client whose data is being retrieved.
        - **time_step** (`int`): The time step of the simulation.

        ### Returns
        - `NDArray[np.float64]`: The cached data stacked in order of increasing
          simulation id, or an empty array when nothing is cached for this key or
          not every simulation of the group has contributed yet."""

        per_simulation = self.sobol_cache.get((field, client_rank, time_step))
        # missing OR incomplete -> empty result (never a KeyError or a partial stack)
        if per_simulation is None or len(per_simulation) != len(self):
            return np.array([])

        return np.array(
            [per_simulation[sim_id] for sim_id in sorted(per_simulation)]
        )

    def has_finished(self) -> bool:
        """Marks the group as finished if all simulations have finished."""
        self.finished = all(sim.has_finished() for sim in self.simulations.values())
        return self.finished

    def has_connected(self) -> bool:
        """Returns if all simulations of the group have connected."""
        return all(sim.connected for sim in self.simulations.values())