Coverage for melissa/server/message.py: 75%
64 statements
« prev ^ index » next coverage.py v7.10.1, created at 2025-11-03 09:52 +0100
« prev ^ index » next coverage.py v7.10.1, created at 2025-11-03 09:52 +0100
"""This module defines helper classes for the server-launcher connection messages."""

import struct
from typing import Any, List, Union
from enum import Enum

import numpy as np
from numpy.typing import NDArray


# Width to which each port name is NUL-padded by ConnectionResponse.
# NOTE(review): the name suggests it also bounds node names on the peer side — confirm.
MELISSA_MAX_NODE_NAME_LENGTH = 128
class MessageType(Enum):
    """Message types for launcher-server communication.

    The integer value of a member is what goes on the wire: the encoders in
    this module pack ``MessageType.<MEMBER>.value`` as a little-endian int32
    (``struct`` format ``"<i"``). Renumbering a member presumably breaks
    compatibility with the peer — confirm before changing any value.
    """
    HELLO = 0
    JOB = 1
    DROP = 2
    STOP = 3
    TIMEOUT = 4
    SIMU_STATUS = 5
    SERVER = 6
    ALIVE = 7
    CONFIDENCE_INTERVAL = 8
    OPTIONS = 9
class ServerNodeName:
    """A server rank paired with its node name, serializable as a SERVER message.

    Wire layout produced by :meth:`encode` (little-endian)::

        int32 MessageType.SERVER | int32 rank | UTF-8 node name bytes

    The node name is appended as-is, with no padding or terminator.
    """

    def __init__(self, rank: int, node_name: str) -> None:
        # Store the name already UTF-8 encoded so encode() can append it directly.
        encoded_name = node_name.encode("utf-8")
        self.rank = rank
        self.node_name = encoded_name

    def encode(self) -> bytes:
        """Return the SERVER message bytes for this node."""
        prefix = struct.pack("<ii", MessageType.SERVER.value, self.rank)
        return b"".join((prefix, self.node_name))
class ConnectionRequest:
    """Connection request carrying a simulation id and the server communicator size.

    Wire layout (little-endian, 8 bytes)::

        int32 comm_size | int32 simulation_id

    Note that the on-wire field order is the reverse of the constructor's
    argument order.
    """

    def __init__(self, simulation_id: int, comm_size: int) -> None:
        self.simulation_id = simulation_id
        self.comm_size = comm_size

    def encode(self) -> bytes:
        """Serialize to 8 bytes: communicator size first, then simulation id."""
        wire_fields = (self.comm_size, self.simulation_id)
        return struct.pack("<ii", *wire_fields)

    @classmethod
    def recv(cls, buff: bytes) -> "ConnectionRequest":
        """Decode a buffer laid out as produced by :meth:`encode`.

        Raises:
            struct.error: if ``buff`` is not exactly 8 bytes long.
        """
        size, sim_id = struct.unpack("<ii", buff)
        return cls(sim_id, size)
class ConnectionResponse:
    """Response to a connection request: a config header followed by port names.

    Wire layout produced by :meth:`encode` (little-endian)::

        int32 server_comm_size | int32 learning | int32 bind_to_server |
        int32 nb_parameters | port name 0 | port name 1 | ...

    Each port name is UTF-8 encoded and right-padded with NUL bytes to exactly
    ``MELISSA_MAX_NODE_NAME_LENGTH`` bytes, so every name occupies a
    fixed-width field.
    """

    def __init__(
        self,
        server_comm_size: int,
        learning: int,
        bind_simulation_to_server: int,
        nb_parameters: int,
        port_names: List[str],
    ) -> None:
        """Store the header fields and pre-serialize the padded port names.

        Raises:
            ValueError: if a port name encodes to more than
                ``MELISSA_MAX_NODE_NAME_LENGTH`` bytes and therefore cannot
                fit in its fixed-width field.
        """
        self.server_comm_size = server_comm_size
        self.learning = learning
        self.bind_to_server = bind_simulation_to_server
        self.nb_parameters = nb_parameters

        width = MELISSA_MAX_NODE_NAME_LENGTH
        fields = []
        for name in port_names:
            # Encode BEFORE padding: padding a str counts characters, so any
            # multi-byte UTF-8 character would overflow the fixed byte width
            # and shift every field that follows.
            raw = name.encode("utf-8")
            if len(raw) > width:
                raise ValueError(
                    f"port name {name!r} is {len(raw)} bytes in UTF-8, "
                    f"exceeding the {width}-byte field"
                )
            fields.append(raw.ljust(width, b"\x00"))
        self.port_names = b"".join(fields)

    def encode(self) -> bytes:
        """Return the four-int32 header followed by the padded port names."""
        header = struct.pack(
            "<4i",
            self.server_comm_size,
            self.learning,
            self.bind_to_server,
            self.nb_parameters,
        )
        return header + self.port_names
class JobDetails:
    """Details of a job: its simulation id, job id, and parameter values."""

    def __init__(self, simulation_id: int, job_id: str, parameters: Union[List, NDArray]) -> None:
        self.simulation_id = simulation_id
        self.job_id = job_id
        self.parameters = parameters

    @classmethod
    def from_msg(cls, msg: bytes, nb_parameters: int) -> "JobDetails":
        """Decode a job message.

        Expected layout (little-endian)::

            int32 message type | int32 simulation_id | UTF-8 job id |
            nb_parameters x float64

        The job id spans every byte between the 8-byte header and the
        trailing block of ``8 * nb_parameters`` parameter bytes.

        Args:
            msg: The raw message bytes.
            nb_parameters: Number of float64 values at the end of ``msg``.

        Returns:
            A ``JobDetails`` whose ``parameters`` is a float64 array (a
            read-only view onto ``msg``).

        Raises:
            struct.error: if ``msg`` is shorter than the 8-byte header.
            ValueError: if ``nb_parameters`` is negative, or ``msg`` is too
                short to hold the header plus the parameter block.
        """
        # The leading message type is not needed here; only the simulation id is kept.
        _, simulation_id = struct.unpack("<ii", msg[:8])
        if nb_parameters < 0:
            raise ValueError(f"nb_parameters must be non-negative, got {nb_parameters}")
        param_start = len(msg) - 8 * nb_parameters
        if param_start < 8:
            raise ValueError(
                f"message of {len(msg)} bytes is too short for an 8-byte header "
                f"and {nb_parameters} float64 parameters"
            )
        # Slice with an explicit end index: negative-index slicing (msg[8:-n])
        # breaks when n == 0 because -0 == 0.
        job_id = msg[8:param_start].decode("utf-8")
        # Little-endian doubles, matching the explicitly little-endian header.
        parameters = np.frombuffer(msg, dtype="<f8", count=nb_parameters, offset=param_start)
        return cls(simulation_id, job_id, parameters)
class Stop:
    """A STOP message; it carries nothing beyond its type code."""

    def encode(self) -> bytes:
        """Return the STOP type code as a 4-byte little-endian int32."""
        stop_code = MessageType.STOP.value
        return struct.pack("<i", stop_code)
class SimulationStatusMessage:
    """A SIMU_STATUS message reporting the status of one simulation.

    ``status`` must expose an integer ``.value`` attribute (for example an
    ``Enum`` member); that value is what gets serialized.
    """

    def __init__(self, simulation_id: int, status: Any) -> None:
        self.simulation_id = simulation_id
        self.status = status

    def encode(self) -> bytes:
        """Serialize as three little-endian int32s: type, simulation id, status value."""
        payload = (
            MessageType.SIMU_STATUS.value,
            self.simulation_id,
            self.status.value,
        )
        return struct.pack("<3i", *payload)