Coverage for melissa/server/simulation.py: 92%

120 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-03-10 22:25 +0100

1"""This script defines classes associated with simulations, groups, and their metadata.""" 

2 

3import logging 

4import struct 

5import time 

6from dataclasses import dataclass 

7from enum import Enum 

8from typing import Any, Dict, List, Optional, Union, Tuple 

9 

10import numpy as np 

11from numpy.typing import NDArray, ArrayLike 

12 

13logger = logging.getLogger(__name__) 

14 

15 

@dataclass
class SimulationData:
    """Container for the assembled data of one simulation at one time step.

    ### Attributes
    - **simulation_id** (`int`): The id of the simulation.
    - **time_step** (`int`): The time step the data belongs to.
    - **data** (`Dict[str, Any]`): The data of the simulation, keyed by field name.
    - **parameters** (`list`): The list of parameters of the simulation."""

    simulation_id: int
    time_step: int
    data: Dict[str, Any]
    parameters: List

    def __repr__(self) -> str:
        """Returns a short description holding only the simulation id and the time step
        (the payload and the parameters are deliberately left out)."""
        cls_name = type(self).__name__
        return f"<{cls_name} sim-id={self.simulation_id} time-step={self.time_step}>"

39 

40 

class PartialSimulationData:
    """Stores partial data for a specific simulation,
    including information on time step, client rank, and field.

    ### Attributes
    - **time_step** (`int`): The time step for the simulation.
    - **simulation_id** (`int`): The id of the simulation.
    - **client_rank** (`int`): The rank of the client submitting the data.
    - **data_size** (`int`): The size of the data, as announced in the message header.
    - **field** (`str`): The field to which the data belongs.
    - **data** (`Any`): The actual data associated with the simulation."""

    # Width in bytes of the fixed, NUL-padded field-name slot of the message header.
    MAX_FIELD_NAME_SIZE = 128

    def __init__(
        self,
        time_step: int,
        simulation_id: int,
        client_rank: int,
        data_size: int,
        field: str,
        data: Any,
    ) -> None:

        self.time_step = time_step
        self.simulation_id = simulation_id
        self.client_rank = client_rank
        self.data_size = data_size
        self.field = field
        self.data = data

    @classmethod
    def from_msg(cls,
                 msg: bytes,
                 learning: int) -> "PartialSimulationData":
        """Class method to deserialize a message and create an instance of the class.

        The message starts with a header made of four native `int` values
        (time step, simulation id, client rank, data size) followed by a
        `MAX_FIELD_NAME_SIZE`-byte field name; everything after the header is
        interpreted as a flat array of floating point values.

        ### Parameters
        - **msg** (`bytes`): The serialized message in bytes format to be deserialized.
        - **learning** (`int`): The payload is read as `float32` when this value is
        greater than 0 and as `float64` otherwise.

        ### Returns
        - `PartialSimulationData`: A new instance of the class created from the
        deserialized message.

        ### Raises
        - `struct.error`: if `msg` is shorter than the header.
        - `ValueError`: if the payload length is not a multiple of the element size."""

        time_step: int
        simulation_id: int
        client_rank: int
        data_size: int
        field: bytes

        # unpack metadata; the header size is derived from the format itself so
        # that the payload offset can never drift from what was actually parsed
        header_format = f"4i{cls.MAX_FIELD_NAME_SIZE}s"
        size_metadata: int = struct.calcsize(header_format)
        time_step, simulation_id, client_rank, data_size, field = struct.unpack_from(
            header_format, msg
        )
        # the name is NUL-padded up to MAX_FIELD_NAME_SIZE: keep what precedes the first NUL
        field_name: str = field.split(b"\x00")[0].decode("utf-8")

        # unpack data array (float32 for DL and float64 for SA)
        dtype = np.dtype(np.float32 if learning > 0 else np.float64)
        data = np.frombuffer(
            msg,
            offset=size_metadata,
            dtype=dtype
        )

        return cls(time_step, simulation_id, client_rank, data_size, field_name, data)

    def __repr__(self) -> str:
        """Returns a string representation of the `PartialSimulationData` object."""
        return (
            f"<{self.__class__.__name__}: sim-id={self.simulation_id}, "
            f"time-step={self.time_step}, field=\"{self.field}\">"
        )

116 

117 

class SimulationDataStatus(Enum):
    """Enumeration of the states that a piece of received simulation data can be in.

    The integer values are spelled out explicitly.
    NOTE(review): presumably they must stay stable if they are stored or
    exchanged anywhere - confirm before renumbering."""

    PARTIAL = 0
    COMPLETE = 1
    ALREADY_RECEIVED = 2
    EMPTY = 3

124 

125 

class SimulationStatus(Enum):
    """Enumeration of the states that a simulation itself can be in.

    The integer values are spelled out explicitly.
    NOTE(review): presumably they must stay stable if they are stored or
    exchanged anywhere - confirm before renumbering."""

    CONNECTED = 0
    RUNNING = 1
    FINISHED = 2

131 

132 

class Simulation:
    """Represents a single simulation with associated metadata. Each object corresponds
    to a unique simulation or client and contains information necessary for tracking
    the simulation's state, parameters, and fault tolerance.

    ### Parameters
    - **id_** (`int`): The unique identifier for the simulation (stored as `id`).
    - **nb_time_steps** (`int`): The total number of time steps for the simulation.
    - **fields** (`List[str]`): A list of fields that will be sent.
    - **parameters** (`List[Any]`): A list of parameters used in the simulation.

    ### Attributes
    - **received_simulation_data**: nested mapping
    `client_rank -> time_step -> field -> data`, where the data is an `int` (SA),
    a `PartialSimulationData` (DL) or `None` when nothing is stored.
    - **received_time_steps**: per client rank, a boolean matrix of shape
    `(number of fields, number of tracked time steps)` flagging what was received.
    - **nb_received_time_steps** (`int`): counter compared against `nb_time_steps`
    by `finished()`; it is not modified by this class.
    - **nb_failures** (`int`): failure counter; it is not modified by this class.
    - **last_message** (`Optional[float]`): timestamp of the last `mark_as_received()`.
    - **t0** (`float`): timestamp of the first `init_structures()` call, `-1.` before it.
    - **duration** (`float`): elapsed time since `t0`, set by `finished()`,
    `-1.` before that."""

    def __init__(self,
                 id_: int,
                 nb_time_steps: int,
                 fields: List[str],
                 parameters: List[Any]) -> None:

        self.id = id_
        self.nb_time_steps = nb_time_steps
        self.fields = fields
        self.received_simulation_data: Dict[
            int, Dict[int, Dict[str, Optional[Union[int, PartialSimulationData]]]]
        ] = {}  # client_rank, time_step, field, int (SA) or PartialSimulationData (DL)
        self.received_time_steps: Dict[int, NDArray[np.bool_]] = {}
        self.parameters: List[Any] = parameters
        self.nb_received_time_steps: int = 0
        self.nb_failures: int = 0
        self.last_message: Optional[float] = None
        self.duration: float = -1.
        self.t0: float = -1.

    def init_structures(self, client_rank: int, time_steps_known: bool = True) -> None:
        """Initializes data structures to track received simulation data for a given client rank.

        The first call (for any rank) also records the start time `t0`.

        ### Parameters
        - **client_rank** (`int`): The rank of the client whose data is being tracked.
        - **time_steps_known** (`bool`, optional): if the total number of time steps
        is known in advance. Default is `True`. When `False`, the tracking matrix
        starts with a single column and is meant to be grown with
        `time_step_expansion()`."""

        # ensure client rank is initialized
        # NOTE(review): the tracking matrix is allocated only the first time a rank
        # is seen so that repeated calls do not wipe the received flags - confirm.
        if client_rank not in self.received_simulation_data:
            self.received_simulation_data[client_rank] = {}
            # Known sample count
            shape = (len(self.fields), self.nb_time_steps if time_steps_known else 1)
            self.received_time_steps[client_rank] = np.zeros(
                shape=shape,
                dtype=bool
            )

        if self.t0 < 0:
            self.t0 = time.time()

    def init_data_storage(self, client_rank: int, time_step: int) -> None:
        """Prepares storage for tracking field-level data
        for a specific time step of a client rank.

        An existing entry for the time step is left untouched.

        ### Parameters
        - **client_rank** (`int`): The rank of the client whose data is being tracked.
        - **time_step** (`int`): The time step for which data storage is initialized."""

        if time_step not in self.received_simulation_data[client_rank]:
            self.received_simulation_data[client_rank][time_step] = {
                f: None for f in self.fields
            }

    def time_step_expansion(self, client_rank: int, time_step: int) -> None:
        """Dynamically expands the `received_time_steps` matrix for a client rank
        if the time step exceeds current capacity.

        The matrix is grown so that `time_step` becomes a valid column index, even
        when it lies several columns past the current end (e.g. when time steps do
        not arrive in order). Existing flags are preserved and new columns are `False`.

        ### Parameters
        - **client_rank** (`int`): The rank of the client whose matrix needs expansion.
        - **time_step** (`int`): The time step that triggered the need for expansion."""

        tracked = self.received_time_steps[client_rank]
        nb_missing = time_step + 1 - tracked.shape[1]
        if nb_missing > 0:
            self.received_time_steps[client_rank] = np.concatenate(
                [
                    tracked,
                    np.zeros((tracked.shape[0], nb_missing), dtype=bool),
                ],
                axis=1,
            )

    def update(self,
               client_rank: int,
               time_step: int,
               field: str,
               data: Optional[Union[int, "PartialSimulationData"]]) -> None:
        """Updates the data associated with a specific field and time step for a client rank.

        The storage for `(client_rank, time_step)` must already exist
        (see `init_data_storage()`), otherwise a `KeyError` is raised.

        ### Parameters
        - **client_rank** (`int`): The rank of the client whose data is being updated.
        - **time_step** (`int`): The time step associated with the data.
        - **field** (`str`): The specific field being updated.
        - **data** (`Optional[Union[int, PartialSimulationData]]`):
        The new data to store for the field."""
        self.received_simulation_data[client_rank][time_step][field] = data

    def get_data(self,
                 client_rank: int,
                 time_step: int) -> Dict[str, Optional[Union[int, "PartialSimulationData"]]]:
        """Retrieves the data for all fields associated with a specific time step of a client rank.

        The stored dictionary itself is returned, not a copy.

        ### Parameters
        - **client_rank** (`int`): The rank of the client whose data is being retrieved.
        - **time_step** (`int`): The time step for which data is being fetched.

        ### Returns
        - `Dict[str, Optional[Union[int, PartialSimulationData]]]`:
        A dictionary containing field names as keys and their corresponding data as values."""
        return self.received_simulation_data[client_rank][time_step]

    def clear_data(self,
                   client_rank: int,
                   time_step: int) -> None:
        """Clears the data for all fields associated with a specific time step of a client rank.
        Useful when the complete data is given to the server for post-processing and no longer needs
        to be in this object. This is to avoid duplications in checkpointing.

        The field keys are kept; only their values are reset to `None`.

        ### Parameters
        - **client_rank** (`int`): The rank of the client whose data is being cleaned.
        - **time_step** (`int`): The time step for which data is being cleaned."""

        ref = self.received_simulation_data[client_rank][time_step]
        for field in ref:
            ref[field] = None

    # TODO: remove
    def crashed(self) -> bool:
        """Unused."""
        return False

    def has_already_received(self,
                             client_rank: int,
                             time_step: int,
                             field: str) -> bool:
        """Checks if the given time step has already been received
        for the specified client and field, helping to avoid duplication of data.

        ### Parameters
        - **client_rank** (`int`): The rank of the client whose data is being retrieved.
        - **time_step** (`int`): The time step for which the data is being checked.
        - **field** (`str`): The field associated with the data being checked.

        ### Returns
        - `bool`: if the data has already been received.

        ### Raises
        - `ValueError`: if `field` is not one of the simulation's fields.
        - `IndexError`: if `time_step` lies outside the tracking matrix."""

        field_idx = self.fields.index(field)
        # convert the numpy scalar so that callers get the annotated builtin `bool`
        return bool(self.received_time_steps[client_rank][field_idx, time_step])

    def is_complete(self, time_step: int) -> bool:
        """Checks if the given time step has data for all the defined fields.

        Every client rank known to this simulation must hold a non-`None` value for
        each field at `time_step`. With no known client rank the result is `True`.

        ### Parameters
        - **time_step** (`int`): The time step to check for completeness.

        ### Returns
        - `bool`: if data has been received for all defined fields."""

        nb_expected = len(self.fields)
        for client_data in self.received_simulation_data.values():
            if time_step not in client_data:
                return False
            nb_stored = sum(
                value is not None for value in client_data[time_step].values()
            )
            if nb_stored != nb_expected:
                return False
        return True

    def mark_as_received(self,
                         client_rank: int,
                         time_step: int,
                         field: str) -> None:
        """Marks the given time step as received for the specified field
        and records the current time as the time of the last message.

        ### Parameters
        - **client_rank** (`int`): The rank of the client sending the data.
        - **time_step** (`int`): The time step for which the data is received.
        - **field** (`str`): The field associated with the data being marked as received."""

        field_idx = self.fields.index(field)
        self.received_time_steps[client_rank][field_idx, time_step] = True
        self.last_message = time.time()

    def finished(self) -> bool:
        """Checks whether all time steps for the simulation have been received.

        When that is the case, `duration` is refreshed with the time elapsed since `t0`.

        ### Returns
        - `bool`: `True` if all time steps have been received."""

        has_finished = self.nb_received_time_steps == self.nb_time_steps

        if has_finished:
            self.duration = time.time() - self.t0

        return has_finished

327 

328 

class Group:
    """Represents a Sobol group with all the relevant simulations.

    This class maintains a caching mechanism for storing corresponding
    time steps received for the current group.

    ### Parameters
    - **group_id** (`int`): The id of the group.
    - **sobol_** (`bool`): A flag indicating whether the group uses Sobol cache
    (default is False); stored as `sobol_on`.

    ### Attributes
    - **simulations** (`Dict[int, Simulation]`): The simulations of the group by id.
    - **nb_failures** (`int`): failure counter; it is not modified by this class.
    - **sobol_cache**: mapping `(field, client_rank, time_step) -> {simulation_id: data}`."""

    def __init__(self,
                 group_id: int,
                 sobol_: bool = False) -> None:

        self.group_id = group_id
        self.simulations: Dict[int, Simulation] = {}
        self.nb_failures: int = 0
        self.sobol_on: bool = sobol_
        # the cache always exists and starts empty, whatever the value of the flag
        self.sobol_cache: Dict[Tuple[str, int, int], Dict[int, np.ndarray]] = {}

    def __len__(self) -> int:
        """Returns length of the current group, i.e. its number of simulations."""
        return len(self.simulations)

    def cache(self, pdata: "PartialSimulationData") -> None:
        """Caches the received simulation data for the
        specified field, simulation id, client rank, and time step.

        Data cached earlier for the same simulation id and key is overwritten.

        ### Parameters
        - **pdata** (`PartialSimulationData`): The simulation data to be cached."""

        key = (pdata.field, pdata.client_rank, pdata.time_step)
        self.sobol_cache.setdefault(key, {})[pdata.simulation_id] = pdata.data

    def get_cached(self,
                   field: str,
                   client_rank: int,
                   time_step: int) -> NDArray[np.float64]:
        """Retrieves cached group data for the
        specified field, client rank, and time step.

        ### Parameters
        - **field** (`str`): The field associated with the simulation data.
        - **client_rank** (`int`): The rank of the client whose data is being retrieved.
        - **time_step** (`int`): The time step of the simulation.

        ### Returns
        - `NDArray[np.float64]`: The cached data sorted
        by simulation ids. An empty array is returned when nothing is cached for
        the key or when the number of cached simulations differs from the size
        of the group."""

        key = (field, client_rank, time_step)
        per_simulation = self.sobol_cache.get(key)
        # `or`, not `and`: an unknown key must short-circuit instead of being looked
        # up (KeyError), and an incomplete entry must not be handed out
        if per_simulation is None or len(per_simulation) != len(self):
            return np.array([])

        return np.array([
            per_simulation[sim_id]
            for sim_id in sorted(per_simulation)
        ])