Coverage for melissa/server/simulation.py: 92%
83 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-09-22 10:36 +0000
« prev ^ index » next coverage.py v7.2.7, created at 2023-09-22 10:36 +0000
1import logging
2import struct
3import time
4from dataclasses import dataclass
5from enum import Enum
6from typing import Any, Dict, List, Optional, Union
8import numpy as np
9import numpy.typing as npt
# Module-level logger, named after this module so handlers can target it.
logger: logging.Logger = logging.getLogger(__name__)
@dataclass
class SimulationData:
    """Data received from one simulation for one time step.

    A plain container: fields are stored exactly as given and are not
    validated. The custom ``__repr__`` deliberately omits ``data`` and
    ``parameters`` to keep log lines short.
    """

    simulation_id: int  # identifier of the simulation the data comes from
    time_step: int  # time step the data belongs to
    data: Any  # payload; its concrete type is decided by the caller
    parameters: List  # parameters attached to this simulation

    def __repr__(self) -> str:
        cls_name = type(self).__name__
        return f"<{cls_name} simulation={self.simulation_id} time step={self.time_step}>"
class PartialSimulationData:
    """A single field of one simulation time step, tagged with the client rank.

    Instances are usually built from a raw message with :meth:`from_msg`.
    """

    # Width in bytes of the null-padded field-name slot in the message header.
    MAX_FIELD_NAME_SIZE = 128

    def __init__(
        self,
        time_step: int,
        simulation_id: int,
        client_rank: int,
        data_size: int,
        field: str,
        data: Any,
    ):
        self.time_step = time_step
        self.simulation_id = simulation_id
        self.client_rank = client_rank
        self.data_size = data_size
        self.field = field
        self.data = data

    @classmethod
    def from_msg(cls, msg, learning) -> "PartialSimulationData":
        """Decode a raw message into a ``PartialSimulationData``.

        Message layout: four native ints (time step, simulation id, client
        rank, data size), then a ``MAX_FIELD_NAME_SIZE``-byte null-padded
        UTF-8 field name, then the data array filling the rest of ``msg``.

        Args:
            msg: bytes-like object holding the whole message.
            learning: if ``> 0`` the payload is decoded as float32 (DL),
                otherwise as float64 (SA).

        Returns:
            A new instance. ``data`` is an ``np.frombuffer`` view over ``msg``
            (no copy). ``data_size`` is taken from the header as-is and is not
            checked against the decoded array length.

        Raises:
            struct.error: if ``msg`` is shorter than the header.
            ValueError: if the payload length is not a multiple of the item size.
            UnicodeDecodeError: if the field name is not valid UTF-8.
        """
        header_fmt = f"4i{cls.MAX_FIELD_NAME_SIZE}s"
        # Derive the offset from the format itself so the two cannot drift apart.
        size_metadata: int = struct.calcsize(header_fmt)
        # unpack_from reads in place instead of copying a header slice.
        time_step, simulation_id, client_rank, data_size, field = struct.unpack_from(
            header_fmt, msg, 0
        )
        # The sender pads the name with NUL bytes; keep only the text before the first one.
        field_name: str = field.split(b"\x00")[0].decode("utf-8")

        # Unpack the data array: float32 for DL, float64 for SA.
        dtype = "f" if learning > 0 else "d"
        data: npt.ArrayLike = np.frombuffer(msg, offset=size_metadata, dtype=dtype)

        return cls(time_step, simulation_id, client_rank, data_size, field_name, data)

    def is_matching(self, simulation_data: "PartialSimulationData") -> bool:
        """Return True if ``simulation_data`` is a *different* field of the
        same simulation and time step as ``self``.

        NOTE(review): the field must differ, not be equal. This presumably
        serves to find the other fields of the same time step; confirm this
        against callers.
        """
        return (
            (self.field != simulation_data.field)
            and (self.simulation_id == simulation_data.simulation_id)
            and (self.time_step == simulation_data.time_step)
        )

    def __repr__(self) -> str:
        # The closing ">" was missing; SimulationData.__repr__ closes its bracket.
        return (
            f"<{self.__class__.__name__}: simulation {self.simulation_id}, "
            f"time step {self.time_step}, field {self.field}>"
        )
class SimulationDataStatus(Enum):
    """Status labels for a piece of received simulation data.

    Only the member names and integer values are defined here; where each
    status is assigned is not visible in this module section.
    """

    PARTIAL = 0  # presumably: not all fields received yet
    COMPLETE = 1  # presumably: all fields received
    ALREADY_RECEIVED = 2  # presumably: a duplicate of data seen before
    EMPTY = 3  # presumably: no data carried
class SimulationStatus(Enum):
    """State labels for a simulation.

    Only the names and integer values are defined here; transitions between
    states are handled elsewhere (not visible in this section).
    """

    CONNECTED = 0
    RUNNING = 1
    FINISHED = 2
class Simulation:
    """Server-side bookkeeping for one simulation's incoming data.

    Tracks which (client rank, field, time step) triples have been received.
    It also records when the last message arrived and how many time steps
    have been counted as received.
    """

    def __init__(self, id: int, n_time_steps: int, fields: List[str], parameters: List[Any]):
        self.id = id
        self.n_time_steps = n_time_steps
        self.fields = fields
        # client_rank -> time_step -> field -> int (SA) or PartialSimulationData (DL);
        # a None value means that field has not been received yet.
        self.received_simulation_data: Dict[
            int, Dict[int, Dict[str, Optional[Union[int, PartialSimulationData]]]]
        ] = {}
        # client_rank -> boolean matrix indexed [field_idx, time_step]. It is not
        # allocated here, so entries must be created before the lookup methods run.
        self.received_time_steps: Dict[int, npt.NDArray[np.bool_]] = {}
        self.parameters: List[Any] = parameters
        self.n_received_time_steps: int = 0
        self.n_failures: int = 0
        # time.time() of the last _mark_as_received call, or None if nothing arrived yet.
        self.last_message: Optional[float] = None

    def crashed(self) -> bool:
        """Return whether the simulation is considered crashed (always False here)."""
        return False

    def has_already_received(self, client_rank: int, time_step: int, field: str) -> bool:
        """Return True if ``field`` at ``time_step`` from ``client_rank`` was marked received.

        Raises:
            ValueError: if ``field`` is not in ``self.fields``.
            KeyError: if no tracking matrix exists for ``client_rank``.
        """
        field_idx = self.fields.index(field)
        # Convert NumPy's np.bool_ into a real bool to honour the annotation.
        return bool(self.received_time_steps[client_rank][field_idx, time_step])

    def is_complete(self, time_step: int) -> bool:
        """Return True if every known client rank has every field at ``time_step``.

        A field counts as present when its stored value is not None. With no
        client ranks registered at all, this returns True (vacuously).
        """
        n_fields = len(self.fields)
        return all(
            time_step in per_rank
            and sum(value is not None for value in per_rank[time_step].values()) == n_fields
            for per_rank in self.received_simulation_data.values()
        )

    def _mark_as_received(self, client_rank: int, time_step: int, field: str):
        """Flag ``field`` at ``time_step`` from ``client_rank`` as received and
        refresh ``last_message``."""
        field_idx = self.fields.index(field)
        self.received_time_steps[client_rank][field_idx, time_step] = True
        self.last_message = time.time()

    def finished(self) -> bool:
        """Return True once the received time-step count equals ``n_time_steps``."""
        return self.n_received_time_steps == self.n_time_steps
144class Group:
145 def __init__(self, group_id: int):
146 self.group_id = group_id
147 self.simulations: Dict[int, Simulation] = {}
148 self.n_failures: int = 0