Coverage for melissa/server/message.py: 71%

80 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-09-22 10:36 +0000

1import ctypes 

2from enum import Enum 

3 

4import numpy as np 

5 

# C prototypes: ctypes pointer types corresponding to C ``char *``,
# ``double *`` and ``void **``.
c_char_ptr = ctypes.POINTER(ctypes.c_char)
c_double_ptr = ctypes.POINTER(ctypes.c_double)
c_void_ptr_ptr = ctypes.POINTER(ctypes.c_void_p)


# This should maybe go into a constants file.
# Fixed width to which ConnectionResponse NUL-pads each port name before
# concatenating them into one message.
MELISSA_MAX_NODE_NAME_LENGTH = 128

14 

15 

def get_str(buff):
    """Return the text before the first NUL byte of *buff*, decoded as UTF-8.

    If *buff* contains no NUL byte, the whole buffer is decoded.
    """
    head, _sep, _tail = buff.partition(b'\x00')
    return head.decode('utf-8')

18 

19 

def encode(msg_parts):
    """Convert each message part to ``bytes``, lazily.

    Returns a ``map`` iterator yielding ``bytes(part)`` for each part. The
    parts are NOT joined. Callers that need one buffer must ``b''.join`` the
    result themselves.
    """
    to_bytes = bytes
    return map(to_bytes, msg_parts)

22 

23 

class MessageType(Enum):
    """Integer identifiers for the kinds of messages exchanged.

    Several message classes in this module write ``MessageType.<X>.value``
    as a leading ``ctypes.c_int32`` (SERVER, STOP, SIMU_STATUS). These
    numbers are therefore part of the wire format.
    NOTE(review): they presumably must match the peer's definitions, so do
    not renumber without updating the other side.
    """
    HELLO = 0
    JOB = 1
    DROP = 2
    STOP = 3
    TIMEOUT = 4
    SIMU_STATUS = 5
    SERVER = 6
    ALIVE = 7
    CONFIDENCE_INTERVAL = 8
    OPTIONS = 9

35 

36 

class ServerNodeName:
    """Message announcing the node name of a server rank.

    :meth:`encode` produces, in native byte order: the int32
    ``MessageType.SERVER`` prefix, the int32 rank, then the UTF-8 node name.
    The name is neither padded nor NUL-terminated.
    """

    def __init__(self, rank, node_name):
        self.msg_prefix = ctypes.c_int32(MessageType.SERVER.value)
        self.rank = ctypes.c_int32(rank)
        self.node_name = node_name.encode('utf-8')

    def encode(self):
        """Return the serialized message as ``bytes``."""
        header = b''.join((bytes(self.msg_prefix), bytes(self.rank)))
        return header + self.node_name

    def recv(self):
        # Receiving this message type is not implemented; this is a no-op.
        return None

49 

50 

class ConnectionRequest:
    """Connection request sent by a simulation.

    Wire layout, as parsed by :meth:`recv` and produced by :meth:`encode`:
    an int32 ``comm_size`` followed by an int32 ``simulation_id``, both in
    native byte order.
    """

    def __init__(self, simulation_id, comm_size):
        self.simulation_id = simulation_id
        self.comm_size = comm_size

    def encode(self, buff=None):
        """Serialize a connection request to 8 bytes.

        Parameters
        ----------
        buff : bytes-like, optional
            Legacy argument. When given, the first two int32 values found in
            *buff* are re-serialized unchanged, and the instance's own fields
            are ignored. When omitted, this request's ``comm_size`` and
            ``simulation_id`` are encoded in the layout :meth:`recv` expects.

        Returns
        -------
        bytes
            ``comm_size`` then ``simulation_id``, each a native int32.

        Raises
        ------
        ValueError
            If *buff* is given but is shorter than 8 bytes.
        """
        if buff is not None:
            comm_size = ctypes.c_int32.from_buffer_copy(buff[:4])
            simulation_id = ctypes.c_int32.from_buffer_copy(buff[4:])
        else:
            comm_size = ctypes.c_int32(self.comm_size)
            simulation_id = ctypes.c_int32(self.simulation_id)
        # Field order must mirror recv(): comm_size first.
        return b''.join(map(bytes, (comm_size, simulation_id)))

    @classmethod
    def recv(cls, buff):
        """Parse a request from *buff* (int32 comm_size, int32 simulation_id)."""
        comm_size = ctypes.c_int32.from_buffer_copy(buff[:4])
        simulation_id = ctypes.c_int32.from_buffer_copy(buff[4:])
        return cls(simulation_id.value, comm_size.value)

68 

69 

class ConnectionResponse(ctypes.Structure):
    """Reply sent back to a connecting client.

    :meth:`encode` produces, in native byte order: five int32 fields
    (``comm_size``, ``sobol_op``, ``learning``, ``nb_parameters``,
    ``verbose_lvl``), followed by every port name. Each port name is UTF-8
    encoded and NUL-padded to a fixed slot of ``MELISSA_MAX_NODE_NAME_LENGTH``
    bytes.

    NOTE(review): the class derives from ``ctypes.Structure`` but declares no
    ``_fields_``. The values are stored as ordinary instance attributes, not
    in the structure's memory.
    """

    def __init__(self, comm_size, sobol_op, learning, nb_parameters,
                 verbose_lvl, port_names):
        """Store the header fields and pre-serialize the port names.

        Raises
        ------
        ValueError
            If a port name's UTF-8 encoding is longer than
            ``MELISSA_MAX_NODE_NAME_LENGTH`` bytes. It would overflow its
            fixed-width slot and shift every following name.
        """
        self.comm_size = ctypes.c_int32(comm_size)
        self.sobol_op = ctypes.c_int32(sobol_op)
        self.learning = ctypes.c_int32(learning)
        self.nb_parameters = ctypes.c_int32(nb_parameters)
        self.verbose_lvl = ctypes.c_int32(verbose_lvl)

        slots = []
        for port in port_names:
            raw = port.encode('utf-8')
            if len(raw) > MELISSA_MAX_NODE_NAME_LENGTH:
                raise ValueError(
                    f'port name {port!r} is {len(raw)} bytes long; the '
                    f'limit is {MELISSA_MAX_NODE_NAME_LENGTH} bytes')
            # Pad AFTER encoding so every slot is exactly N bytes, even for
            # non-ASCII names (padding the str would count characters).
            slots.append(raw.ljust(MELISSA_MAX_NODE_NAME_LENGTH, b'\x00'))
        self.port_names = b''.join(slots)

    def encode(self):
        """Return the serialized response as ``bytes``."""
        parts = [self.comm_size, self.sobol_op, self.learning,
                 self.nb_parameters, self.verbose_lvl, self.port_names]
        return b''.join(map(bytes, parts))

91 

92 

class JobDetails:
    """Job assignment parsed from a raw message.

    Layout read by :meth:`from_msg` (native byte order):

    * bytes 0-3: skipped. Presumably the int32 message-type prefix; confirm
      against the sender.
    * bytes 4-7: int32 simulation id.
    * then: the UTF-8 job id, up to the start of the parameters.
    * last ``nb_parameters * 8`` bytes: float64 parameter values.
    """

    def __init__(self, simulation_id, job_id, parameters):
        self.simulation_id = simulation_id
        self.job_id = job_id
        self.parameters = parameters

    @classmethod
    def from_msg(cls, msg, nb_parameters):
        """Build a :class:`JobDetails` from a raw message.

        Parameters
        ----------
        msg : bytes
            Raw message in the layout described on the class.
        nb_parameters : int
            Number of trailing float64 parameters in *msg*; may be 0.

        Returns
        -------
        JobDetails
            ``parameters`` is a numpy float64 array that views *msg*'s
            buffer. It is read-only when *msg* is ``bytes``.

        Raises
        ------
        ValueError
            If *nb_parameters* is negative, or if *msg* is too short to hold
            the 8-byte header plus the parameters.
        """
        header_size = 8
        if nb_parameters < 0:
            raise ValueError(
                f'nb_parameters must be >= 0, got {nb_parameters}')
        params_size = nb_parameters * np.dtype(np.double).itemsize
        if len(msg) < header_size + params_size:
            raise ValueError(
                f'message of {len(msg)} bytes is too short for a '
                f'{header_size}-byte header and {nb_parameters} parameters')

        # Use a non-negative split offset. Slicing with -params_size breaks
        # when it is 0: msg[8:-0] is empty and msg[-0:] is the whole message.
        params_start = len(msg) - params_size
        simulation_id = ctypes.c_int32.from_buffer_copy(msg[4:header_size])
        job_id = msg[header_size:params_start].decode('utf-8')
        parameters = np.frombuffer(msg[params_start:], dtype=np.double)
        return cls(simulation_id.value, job_id, parameters)

109 

110 

class Stop:
    """Stop message: only the int32 ``MessageType.STOP`` code."""

    def encode(self):
        """Return the 4-byte native-order STOP code."""
        stop_code = ctypes.c_int32(MessageType.STOP.value)
        return bytes(stop_code)

115 

116 

class SimulationStatusMessage:
    """Status update for a single simulation.

    :meth:`encode` emits three native-order int32 values: the
    ``MessageType.SIMU_STATUS`` prefix, the simulation id, and
    ``status.value``.
    """

    def __init__(self, simulation_id, status):
        # ``status`` must expose an integer ``.value`` (e.g. an Enum member).
        self.status = ctypes.c_int32(status.value)
        self.simulation_id = ctypes.c_int32(simulation_id)

    def encode(self):
        """Return the serialized status message as ``bytes``."""
        prefix = ctypes.c_int32(MessageType.SIMU_STATUS.value)
        return bytes(prefix) + bytes(self.simulation_id) + bytes(self.status)