Coverage for melissa/server/simulation.py: 92%

83 statements  

« prev | ^ index | » next — report generated by coverage.py v7.2.7 on 2023-09-22 10:36 +0000

1import logging 

2import struct 

3import time 

4from dataclasses import dataclass 

5from enum import Enum 

6from typing import Any, Dict, List, Optional, Union 

7 

8import numpy as np 

9import numpy.typing as npt 

10 

# Module-level logger, named after this module (standard `logging` convention).
logger: logging.Logger = logging.getLogger(__name__)

12 

13 

@dataclass
class SimulationData:
    """One time step of one simulation, together with its data and input parameters."""

    simulation_id: int
    time_step: int
    data: Any
    parameters: List[Any]

    def __repr__(self) -> str:
        # Intentionally terse: the data/parameters payloads are left out.
        cls_name = type(self).__name__
        return f"<{cls_name} simulation={self.simulation_id} time step={self.time_step}>"

28 

29 

class PartialSimulationData:
    """A single field of a single time step, as sent by one client rank.

    Instances are typically decoded from a raw message with :meth:`from_msg`.
    """

    # Width in bytes of the NUL-padded field-name slot in the message header.
    MAX_FIELD_NAME_SIZE = 128

    def __init__(
        self,
        time_step: int,
        simulation_id: int,
        client_rank: int,
        data_size: int,
        field: str,
        data: Any,
    ):
        """Store the decoded header values and payload as-is (no validation)."""
        self.time_step = time_step
        self.simulation_id = simulation_id
        self.client_rank = client_rank
        # Value reported in the message header; not checked against len(data) here.
        self.data_size = data_size
        self.field = field
        self.data = data

    @classmethod
    def from_msg(cls, msg, learning) -> "PartialSimulationData":
        """Decode a raw client message into a ``PartialSimulationData``.

        Message layout: four native ``int`` values (time step, simulation id,
        client rank, data size), then a ``MAX_FIELD_NAME_SIZE``-byte NUL-padded
        UTF-8 field name, then the data array filling the rest of ``msg``.

        Args:
            msg: bytes-like object (buffer protocol) holding the whole message.
            learning: when ``> 0`` the payload is decoded as float32, otherwise
                as float64.

        Returns:
            A new instance whose ``data`` is a NumPy array sharing ``msg``'s
            buffer (``np.frombuffer`` does not copy).

        Raises:
            struct.error: if ``msg`` is shorter than the header.
            ValueError: from ``np.frombuffer`` if the payload length is not a
                multiple of the element size.
        """
        header_fmt = f"4i{cls.MAX_FIELD_NAME_SIZE}s"
        # Derive the header size from the format itself so the unpack and the
        # payload offset can never disagree.
        size_metadata: int = struct.calcsize(header_fmt)
        (
            time_step,
            simulation_id,
            client_rank,
            data_size,
            raw_field,
        ) = struct.unpack_from(header_fmt, msg)
        # The name is NUL-padded up to MAX_FIELD_NAME_SIZE bytes.
        field_name: str = raw_field.split(b"\x00", 1)[0].decode("utf-8")

        # float32 for DL, float64 for SA
        dtype = "f" if learning > 0 else "d"
        data: npt.NDArray = np.frombuffer(msg, offset=size_metadata, dtype=dtype)

        return cls(time_step, simulation_id, client_rank, data_size, field_name, data)

    def is_matching(self, simulation_data: "PartialSimulationData") -> bool:
        """Return True if ``simulation_data`` is a *different* field of the same
        simulation and time step as ``self``."""
        return (
            self.field != simulation_data.field
            and self.simulation_id == simulation_data.simulation_id
            and self.time_step == simulation_data.time_step
        )

    def __repr__(self) -> str:
        return (
            f"<{self.__class__.__name__}: simulation {self.simulation_id}, "
            f"time step {self.time_step}, field {self.field}>"
        )

85 

86 

class SimulationDataStatus(Enum):
    """Outcome tags for a piece of incoming simulation data.

    Values are explicit integers; keep them stable in case they are compared
    or serialized elsewhere (NOTE(review): confirm whether they are).
    """

    PARTIAL = 0  # only part of the expected data is in
    COMPLETE = 1  # everything expected is in
    ALREADY_RECEIVED = 2  # duplicate of something seen before
    EMPTY = 3  # nothing usable (exact meaning set by callers — confirm)

92 

93 

class SimulationStatus(Enum):
    """Lifecycle stage of a simulation as tracked by the server."""

    CONNECTED = 0  # client has connected
    RUNNING = 1  # data is flowing
    FINISHED = 2  # simulation is done

98 

99 

class Simulation:
    """Server-side bookkeeping of what has been received for one simulation."""

    def __init__(self, id: int, n_time_steps: int, fields: List[str], parameters: List[Any]):
        # NOTE: `id` shadows the builtin; the name is kept for keyword callers.
        self.id = id
        self.n_time_steps = n_time_steps
        # Field order defines the row index used in `received_time_steps` arrays.
        self.fields = fields
        # client_rank -> time_step -> field -> int (SA) or PartialSimulationData (DL);
        # a None value is treated by is_complete() as "not received".
        self.received_simulation_data: Dict[
            int, Dict[int, Dict[str, Optional[Union[int, PartialSimulationData]]]]
        ] = {}
        # client_rank -> boolean array indexed [field_idx, time_step].
        # Not populated in this class; presumably filled by the server — confirm.
        self.received_time_steps: Dict[int, npt.NDArray[np.bool_]] = {}
        self.parameters: List[Any] = parameters
        self.n_received_time_steps: int = 0
        self.n_failures: int = 0
        # Wall-clock time (time.time()) of the last _mark_as_received() call.
        self.last_message: Optional[float] = None

    def crashed(self) -> bool:
        """Always False here (NOTE(review): placeholder or meant to be overridden?)."""
        return False

    def has_already_received(self, client_rank: int, time_step: int, field: str) -> bool:
        """Return whether ``field`` at ``time_step`` was marked received for ``client_rank``.

        Raises:
            ValueError: if ``field`` is not in ``self.fields``.
            KeyError: if no array exists for ``client_rank``.
        """
        field_idx = self.fields.index(field)
        # Convert the numpy.bool_ scalar so the result honours the `bool` annotation
        # (and identity checks such as `is True` behave as expected).
        return bool(self.received_time_steps[client_rank][field_idx, time_step])

    def is_complete(self, time_step: int) -> bool:
        """Return True if every known client rank has all fields for ``time_step``.

        A rank counts as complete for the step when the step is present and the
        number of non-None field entries equals ``len(self.fields)``.
        NOTE(review): with no client ranks recorded yet this returns True
        (vacuous truth) — confirm callers rely on that.
        """
        n_fields = len(self.fields)
        for per_time_step in self.received_simulation_data.values():
            if time_step not in per_time_step:
                return False
            n_received = sum(
                value is not None for value in per_time_step[time_step].values()
            )
            if n_received != n_fields:
                return False
        return True

    def _mark_as_received(self, client_rank: int, time_step: int, field: str):
        """Flag ``field`` at ``time_step`` as received and refresh ``last_message``.

        Raises:
            ValueError: if ``field`` is not in ``self.fields``.
            KeyError: if no array exists for ``client_rank``.
        """
        field_idx = self.fields.index(field)
        self.received_time_steps[client_rank][field_idx, time_step] = True
        self.last_message = time.time()

    def finished(self) -> bool:
        """Return True once the received-step counter reaches ``n_time_steps``."""
        return self.n_received_time_steps == self.n_time_steps

142 

143 

class Group:
    """A set of simulations collected under one group identifier."""

    def __init__(self, group_id: int):
        self.group_id: int = group_id
        # Keyed by int (presumably the simulation id — confirm with callers).
        self.simulations: Dict[int, Simulation] = {}
        # Failure counter; starts at zero and is not updated in this class.
        self.n_failures: int = 0