Coverage for melissa/server/message.py: 75%

64 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2025-11-03 09:52 +0100

"""This module defines helper classes for the server-launcher connection."""

2import struct 

3from typing import Any, List, Union 

4from enum import Enum 

5 

6import numpy as np 

7from numpy.typing import NDArray 

8 

9 

# Fixed field width used when padding each port name in ConnectionResponse.
MELISSA_MAX_NODE_NAME_LENGTH: int = 128

11 

12 

class MessageType(Enum):
    """Identifiers of the message kinds exchanged between launcher and server.

    The integer ``value`` of a member is what gets packed on the wire, e.g.
    ``struct.pack("<i", MessageType.STOP.value)``. Values are explicit, so
    members must never be renumbered.
    """

    HELLO = 0
    JOB = 1
    DROP = 2
    STOP = 3
    TIMEOUT = 4
    SIMU_STATUS = 5
    SERVER = 6
    ALIVE = 7
    CONFIDENCE_INTERVAL = 8
    OPTIONS = 9

25 

26 

class ServerNodeName:
    """Server-node announcement: a rank plus the UTF-8 encoded node name.

    Wire layout produced by :meth:`encode` (little-endian)::

        int32 message type (MessageType.SERVER) | int32 rank | node name bytes

    The node name bytes are appended as-is, with no padding or terminator.
    """

    def __init__(self, rank: int, node_name: str) -> None:
        self.rank = rank
        # Kept pre-encoded so encode() can append it directly.
        self.node_name = node_name.encode("utf-8")

    def encode(self) -> bytes:
        """Serialize this announcement into its binary wire form."""
        header = struct.pack("<ii", MessageType.SERVER.value, self.rank)
        return b"".join((header, self.node_name))

36 

37 

class ConnectionRequest:
    """Connection request carrying a simulation id and the server communicator size.

    Wire layout (little-endian)::

        int32 comm_size | int32 simulation_id

    Note that the communicator size comes first on the wire, which is the
    reverse of the constructor's argument order.
    """

    # Shared, precompiled layout used by both encode() and recv().
    _LAYOUT = struct.Struct("<ii")

    def __init__(self, simulation_id: int, comm_size: int) -> None:
        self.simulation_id = simulation_id
        self.comm_size = comm_size

    def encode(self) -> bytes:
        """Serialize the request into its 8-byte wire form."""
        return self._LAYOUT.pack(self.comm_size, self.simulation_id)

    @classmethod
    def recv(cls, buff: bytes) -> "ConnectionRequest":
        """Decode an 8-byte buffer produced by :meth:`encode`.

        Raises:
            struct.error: if ``buff`` is not exactly 8 bytes long.
        """
        comm_size, simulation_id = cls._LAYOUT.unpack(buff)
        return cls(simulation_id=simulation_id, comm_size=comm_size)

52 

53 

class ConnectionResponse:
    """Response to a connection request: communicator details, config flags and port names.

    Wire layout produced by :meth:`encode` (little-endian)::

        int32 server_comm_size | int32 learning | int32 bind_to_server |
        int32 nb_parameters | port name fields

    Each port name occupies a fixed field of ``MELISSA_MAX_NODE_NAME_LENGTH``
    bytes: its UTF-8 encoding, right-padded with NUL bytes.
    """

    def __init__(
        self,
        server_comm_size: int,
        learning: int,
        bind_simulation_to_server: int,
        nb_parameters: int,
        port_names: List[str],
    ) -> None:
        """Store the response fields and pre-pack the port names.

        Args:
            server_comm_size: size of the server communicator.
            learning: integer flag sent as-is in the header.
            bind_simulation_to_server: integer flag sent as-is in the header
                (stored as ``self.bind_to_server``).
            nb_parameters: number of parameters, sent as-is in the header.
            port_names: port names, each packed into a fixed-width field.

        Raises:
            ValueError: if a port name's UTF-8 encoding is longer than
                ``MELISSA_MAX_NODE_NAME_LENGTH`` bytes and so cannot fit in its
                fixed-width field.
        """
        self.server_comm_size = server_comm_size
        self.learning = learning
        self.bind_to_server = bind_simulation_to_server
        self.nb_parameters = nb_parameters

        self.port_names = b"".join(self._pack_port_name(p) for p in port_names)

    @staticmethod
    def _pack_port_name(name: str) -> bytes:
        """Encode one port name into its fixed-width, NUL-padded field."""
        # Pad the *encoded bytes*, not the str: padding characters before
        # encoding lets non-ASCII names exceed the field width, and an
        # over-long name would otherwise shift every following field.
        raw = name.encode("utf-8")
        if len(raw) > MELISSA_MAX_NODE_NAME_LENGTH:
            raise ValueError(
                f"port name {name!r} is {len(raw)} bytes once UTF-8 encoded; "
                f"the limit is {MELISSA_MAX_NODE_NAME_LENGTH}"
            )
        return raw.ljust(MELISSA_MAX_NODE_NAME_LENGTH, b"\x00")

    def encode(self) -> bytes:
        """Serialize the response: 4-int header followed by the port name fields."""
        header = struct.pack(
            "<4i",
            self.server_comm_size,
            self.learning,
            self.bind_to_server,
            self.nb_parameters,
        )
        return header + self.port_names

83 

84 

class JobDetails:
    """Represents details of a job, including its simulation id, job id, and parameters."""

    def __init__(self, simulation_id: int, job_id: str, parameters: Union[List, NDArray]) -> None:
        self.simulation_id = simulation_id
        self.job_id = job_id
        self.parameters = parameters

    @classmethod
    def from_msg(cls, msg: bytes, nb_parameters: int) -> "JobDetails":
        """Decode a JOB message.

        Expected layout (little-endian)::

            int32 message type | int32 simulation id | job id (UTF-8) |
            nb_parameters x float64

        The job id is everything between the 8-byte header and the trailing
        block of ``nb_parameters`` float64 values.

        Args:
            msg: the raw message bytes.
            nb_parameters: number of trailing float64 parameter values.

        Returns:
            A ``JobDetails`` whose ``parameters`` is a read-only float64 array
            viewing ``msg``.

        Raises:
            struct.error: if ``msg`` is shorter than the 8-byte header.
            ValueError: if ``nb_parameters`` is negative, if ``msg`` is too short
                to hold the header plus the parameter block, or if the job id
                is not valid UTF-8.
        """
        if nb_parameters < 0:
            raise ValueError(f"nb_parameters must be >= 0, got {nb_parameters}")

        header_len = 8
        # The leading message-type field is not used here.
        _, simulation_id = struct.unpack("<ii", msg[:header_len])

        param_bytes_len = nb_parameters * 8
        job_id_end = len(msg) - param_bytes_len
        if job_id_end < header_len:
            raise ValueError(
                f"JOB message of {len(msg)} bytes is too short for a "
                f"{header_len}-byte header and {nb_parameters} float64 parameters"
            )

        # Use an explicit end index rather than negative slicing: with
        # nb_parameters == 0, msg[8:-0] is empty and msg[-0:] is the whole message.
        job_id = msg[header_len:job_id_end].decode("utf-8")
        # Explicit little-endian dtype, matching the "<ii" header.
        parameters = np.frombuffer(msg, dtype="<f8", count=nb_parameters, offset=job_id_end)
        return cls(simulation_id, job_id, parameters)

101 

102 

class Stop:
    """STOP message: a bare 4-byte message-type header with no payload."""

    def encode(self) -> bytes:
        """Return the little-endian int32 encoding of ``MessageType.STOP``."""
        stop_code = MessageType.STOP.value
        return struct.pack("<i", stop_code)

108 

109 

class SimulationStatusMessage:
    """SIMU_STATUS message reporting the status of one simulation.

    Wire layout produced by :meth:`encode` (little-endian)::

        int32 message type (MessageType.SIMU_STATUS) | int32 simulation_id | int32 status

    ``status`` must expose an integer ``.value`` attribute (e.g. an Enum
    member); only that value is serialized.
    """

    def __init__(self, simulation_id: int, status: Any) -> None:
        self.simulation_id = simulation_id
        self.status = status

    def encode(self) -> bytes:
        """Serialize the message into its 12-byte wire form."""
        fields = (MessageType.SIMU_STATUS.value, self.simulation_id, self.status.value)
        return struct.pack("<3i", *fields)