Coverage for melissa/server/message.py: 71%
80 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-09-22 10:36 +0000
« prev ^ index » next coverage.py v7.2.7, created at 2023-09-22 10:36 +0000
1import ctypes
2from enum import Enum
4import numpy as np
# ctypes pointer types used for the C prototypes
c_char_ptr = ctypes.POINTER(ctypes.c_char)        # pointer to c_char
c_double_ptr = ctypes.POINTER(ctypes.c_double)    # pointer to c_double
c_void_ptr_ptr = ctypes.POINTER(ctypes.c_void_p)  # pointer to c_void_p


# This should maybe go into a constants file
# Width to which each port name is NUL-padded in a ConnectionResponse.
MELISSA_MAX_NODE_NAME_LENGTH = 128
def get_str(buff):
    """Decode the part of *buff* that precedes the first NUL byte as UTF-8.

    When *buff* contains no NUL byte, the whole buffer is decoded.
    """
    head, _sep, _tail = buff.partition(b'\x00')
    return head.decode('utf-8')
def encode(msg_parts):
    """Return a lazy iterator yielding ``bytes(part)`` for each message part.

    The result is a ``map`` object: parts are converted only when the
    iterator is consumed, and it can be consumed only once.
    """
    encoded_parts = map(bytes, msg_parts)
    return encoded_parts
class MessageType(Enum):
    """Integer tags identifying the kinds of messages.

    Message classes in this module write the tag's ``value`` as the leading
    32-bit integer of their encoded form (see the SERVER, STOP and
    SIMU_STATUS users).
    """

    HELLO = 0
    JOB = 1
    DROP = 2
    STOP = 3
    TIMEOUT = 4
    SIMU_STATUS = 5
    SERVER = 6
    ALIVE = 7
    CONFIDENCE_INTERVAL = 8
    OPTIONS = 9
class ServerNodeName:
    """Message made of the SERVER type tag, a rank and a node name."""

    def __init__(self, rank, node_name):
        """Store the SERVER tag and *rank* as 32-bit ints, *node_name* as UTF-8.

        Args:
            rank: Integer rank, wrapped in ``ctypes.c_int32``.
            node_name: Text name; stored UTF-8 encoded, without padding.
        """
        self.msg_prefix = ctypes.c_int32(MessageType.SERVER.value)
        self.rank = ctypes.c_int32(rank)
        self.node_name = node_name.encode('utf-8')

    def encode(self):
        """Return the message bytes: type tag, rank, then the encoded name."""
        header = bytes(self.msg_prefix) + bytes(self.rank)
        return header + self.node_name

    def recv(self):
        """Placeholder: does nothing and returns ``None``."""
        return None
class ConnectionRequest:
    """Connection request carrying a communicator size and a simulation id.

    The encoded form is two consecutive 32-bit integers in native byte
    order: ``comm_size`` first, then ``simulation_id`` (the order in which
    `recv` reads them).
    """

    def __init__(self, simulation_id, comm_size):
        self.simulation_id = simulation_id
        self.comm_size = comm_size

    def encode(self, buff=None):
        """Serialize a connection request to 8 bytes.

        Args:
            buff: Optional raw buffer holding at least two 32-bit integers.
                When given, those two integers are re-encoded and the
                instance attributes are ignored (the historical behavior).
                When omitted, this instance's ``comm_size`` and
                ``simulation_id`` are encoded instead.

        Returns:
            bytes: The two integers, in the layout that `recv` parses.

        Raises:
            ValueError: If *buff* is given but is too short.
        """
        if buff is None:
            # Same field order as recv(), so encode() and recv() round-trip.
            parts = [ctypes.c_int32(self.comm_size),
                     ctypes.c_int32(self.simulation_id)]
        else:
            parts = [ctypes.c_int32.from_buffer_copy(buff[:4]),
                     ctypes.c_int32.from_buffer_copy(buff[4:])]

        return b''.join(map(bytes, parts))

    @classmethod
    def recv(cls, buff):
        """Build a request from a raw buffer.

        Args:
            buff: Buffer whose first 4 bytes hold ``comm_size`` and whose
                next 4 bytes hold ``simulation_id``.

        Returns:
            ConnectionRequest: Instance populated with plain Python ints.

        Raises:
            ValueError: If *buff* is too short to hold both integers.
        """
        comm_size = ctypes.c_int32.from_buffer_copy(buff[:4])
        simulation_id = ctypes.c_int32.from_buffer_copy(buff[4:])
        return cls(simulation_id.value, comm_size.value)
class ConnectionResponse(ctypes.Structure):
    """Connection response: five 32-bit integers followed by port names.

    Each port name occupies a fixed-width, NUL-padded field of
    ``MELISSA_MAX_NODE_NAME_LENGTH`` bytes.
    """

    def __init__(self, comm_size, sobol_op, learning, nb_parameters,
                 verbose_lvl, port_names):
        """Store the integer fields as ``c_int32`` and pack the port names.

        Args:
            comm_size: Integer, stored as ``ctypes.c_int32``.
            sobol_op: Integer, stored as ``ctypes.c_int32``.
            learning: Integer, stored as ``ctypes.c_int32``.
            nb_parameters: Integer, stored as ``ctypes.c_int32``.
            verbose_lvl: Integer, stored as ``ctypes.c_int32``.
            port_names: Iterable of text port names.
        """
        self.comm_size = ctypes.c_int32(comm_size)
        self.sobol_op = ctypes.c_int32(sobol_op)
        self.learning = ctypes.c_int32(learning)
        self.nb_parameters = ctypes.c_int32(nb_parameters)
        self.verbose_lvl = ctypes.c_int32(verbose_lvl)
        print('port_names: ', port_names)
        # Pad AFTER encoding: padding the text first counts characters, so a
        # name with multi-byte UTF-8 characters would produce a field wider
        # than MELISSA_MAX_NODE_NAME_LENGTH bytes and shift every later field.
        # NOTE(review): names already longer than the field width are still
        # emitted unpadded and untruncated, as before - confirm whether the
        # receiver needs them rejected.
        self.port_names = b''.join(
            port.encode('utf-8').ljust(MELISSA_MAX_NODE_NAME_LENGTH, b'\x00')
            for port in port_names
        )

    def encode(self):
        """Return the message bytes: the five integers, then the port block."""
        parts = [self.comm_size, self.sobol_op, self.learning,
                 self.nb_parameters, self.verbose_lvl, self.port_names]
        msg = b''.join(map(bytes, parts))
        return msg
class JobDetails:
    """Simulation id, job id and parameter vector extracted from a message."""

    def __init__(self, simulation_id, job_id, parameters):
        self.simulation_id = simulation_id
        self.job_id = job_id
        self.parameters = parameters

    @classmethod
    def from_msg(cls, msg, nb_parameters):
        """Parse a raw job message.

        Layout read by this parser: bytes 0-3 are skipped, bytes 4-7 hold
        the simulation id as a 32-bit integer, the last
        ``nb_parameters * 8`` bytes hold the parameters as doubles, and
        everything in between is the UTF-8 job id.

        Args:
            msg: Raw message bytes.
            nb_parameters: Number of trailing double-precision parameters.

        Returns:
            JobDetails: With ``simulation_id`` as int, ``job_id`` as str and
            ``parameters`` as a NumPy array of doubles.

        Raises:
            ValueError: If *nb_parameters* is negative or *msg* is too short
                to hold the 8-byte header plus the parameters.
        """
        header_size = 8
        params_size = nb_parameters * 8
        if nb_parameters < 0 or len(msg) < header_size + params_size:
            raise ValueError(
                'job message of %d bytes cannot hold an 8-byte header and '
                '%d parameters' % (len(msg), nb_parameters)
            )

        # Split on an absolute index: a negative slice bound of -0 would
        # select the whole message when there are no parameters.
        params_start = len(msg) - params_size

        simulation_id = ctypes.c_int32.from_buffer_copy(msg[4:8])
        job_id = msg[header_size:params_start].decode('utf-8')
        parameters = np.frombuffer(msg[params_start:], dtype=np.double)
        return cls(simulation_id.value, job_id, parameters)
class Stop:
    """Message consisting only of the STOP type tag."""

    def encode(self):
        """Return the STOP tag as the bytes of a 32-bit integer."""
        stop_tag = ctypes.c_int32(MessageType.STOP.value)
        return bytes(stop_tag)
class SimulationStatusMessage:
    """Message pairing a simulation id with a status value."""

    def __init__(self, simulation_id, status):
        """Store both fields as 32-bit integers.

        Args:
            simulation_id: Integer simulation id.
            status: Object whose ``value`` attribute is the integer status
                (presumably an Enum member; confirm against callers).
        """
        self.status = ctypes.c_int32(status.value)
        self.simulation_id = ctypes.c_int32(simulation_id)

    def encode(self):
        """Return the bytes: SIMU_STATUS tag, simulation id, then status."""
        type_tag = ctypes.c_int32(MessageType.SIMU_STATUS.value)
        return bytes(type_tag) + bytes(self.simulation_id) + bytes(self.status)