Coverage for melissa/utility/networking.py: 73%

148 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2025-11-03 09:52 +0100

1#!/usr/bin/python3 

2 

3# Copyright (c) 2021-2022, Institut National de Recherche en Informatique et en Automatique (Inria) 

4# All rights reserved. 

5# 

6# Redistribution and use in source and binary forms, with or without 

7# modification, are permitted provided that the following conditions are met: 

8# 

9# * Redistributions of source code must retain the above copyright notice, this 

10# list of conditions and the following disclaimer. 

11# 

12# * Redistributions in binary form must reproduce the above copyright notice, 

13# this list of conditions and the following disclaimer in the documentation 

14# and/or other materials provided with the distribution. 

15# 

16# * Neither the name of the copyright holder nor the names of its 

17# contributors may be used to endorse or promote products derived from 

18# this software without specific prior written permission. 

19# 

20# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 

21# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 

22# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 

23# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 

24# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 

25# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 

26# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 

27# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 

28# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 

29# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 

30 

31# pylint: skip-file 

32 

33 

34"""This script contains utility methods for networking with respect to connection 

35between the launcher-server. Some of the methods are not being used.""" 

36 

# Names exported by a star-import of this module.
# NOTE(review): is_launcher_socket_alive_and_ready is defined in this module
# but is not listed here -- confirm whether that is intentional.
__all__ = [
    "ConnectionId", "LengthPrefixFramingEncoder", "LengthPrefixFramingDecoder",
    "Socket", "connect_to_launcher", "connect_to_server",
    "get_available_protocols", "make_passive_socket", "pipe", "protocol2str",
    "socketpair", "str2protocol", "select_protocol", "is_port_in_use",
]

43 

44import os 

45import errno 

46import time 

47import logging 

48import socket 

49from typing import List, Optional, Tuple 

50from socket import socket as Socket 

51from melissa.launcher.config import TCP_MESSAGE_PREFIX_LENGTH 

52from melissa.server.exceptions import UnsupportedProtocol 

53 

# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Connection identifiers are plain integers.
ConnectionId = int

# Payload that is_launcher_socket_alive_and_ready() expects to receive from
# its peer.
PEER_SANITY_MSG = b"salut"

59 

60 

class LengthPrefixFramingDecoder:
    """Incrementally split a byte stream into length-prefixed messages.

    Each message is preceded by its length, stored as a little-endian
    unsigned integer that occupies ``prefix_length`` bytes. Bytes that do
    not yet form a complete message are kept in an internal buffer and are
    prepended to the input of the next call to :meth:`execute`.
    """

    def __init__(self, prefix_length: int) -> None:
        """Create a decoder for prefixes of ``prefix_length`` bytes.

        Raises:
            ValueError: if ``prefix_length`` is not in the range 1..8.
        """
        if prefix_length <= 0:
            raise ValueError(
                "prefix length must be positive, got {:d}".
                format(prefix_length)
            )
        # Raised explicitly (not asserted) so the check survives ``python -O``.
        if prefix_length > 8:
            raise ValueError(
                "prefix length must be at most 8, got {:d}".
                format(prefix_length)
            )
        self._prefix_length = prefix_length
        self._buffer = b''

    def execute(self, bs: bytes) -> List[bytes]:
        """Feed ``bs`` to the decoder and return every completed message.

        Args:
            bs: the bytes received since the previous call.

        Returns:
            The payloads (without prefix) of all messages that are complete
            after appending ``bs`` to the buffered remainder, in stream
            order. The list is empty when no message is complete yet.
        """
        data = self._buffer + bs
        prefix_length = self._prefix_length
        total = len(data)

        messages: List[bytes] = []
        # Walk through the data with an offset instead of re-slicing the
        # remainder after each message; the remainder is copied only once.
        offset = 0
        while total - offset >= prefix_length:
            body_start = offset + prefix_length
            body_end = body_start + int.from_bytes(
                data[offset:body_start], byteorder="little"
            )

            if body_end > total:
                # The payload has not fully arrived yet.
                break

            messages.append(data[body_start:body_end])
            offset = body_end

        self._buffer = data[offset:]
        return messages

91 

92 

class LengthPrefixFramingEncoder:
    """Wrap messages in length-prefix framing.

    The length of each message is stored in front of it as a little-endian
    unsigned integer that occupies ``prefix_length`` bytes.
    """

    def __init__(self, prefix_length: int) -> None:
        """Create an encoder that emits prefixes of ``prefix_length`` bytes.

        Raises:
            ValueError: if ``prefix_length`` is not in the range 1..8.
        """
        if prefix_length <= 0:
            raise ValueError(
                "prefix length must be positive, got {:d}".
                format(prefix_length)
            )
        # Raised explicitly (not asserted) so the check survives ``python -O``.
        if prefix_length > 8:
            raise ValueError(
                "prefix length must be at most 8, got {:d}".
                format(prefix_length)
            )
        self._prefix_length = prefix_length

    def execute(self, bs: bytes) -> bytes:
        """Return ``bs`` preceded by its length prefix.

        Raises:
            RuntimeError: if ``len(bs)`` does not fit into the prefix.
        """
        n = len(bs)
        bits_per_byte = 8
        # Smallest length that can no longer be represented by the prefix.
        max_message_length = 1 << (self._prefix_length * bits_per_byte)
        if n >= max_message_length:
            raise RuntimeError(
                "message of length {:d} too large for length "
                "prefix framing with a prefix of length {:d}"
                .format(n, self._prefix_length)
            )

        prefix = n.to_bytes(self._prefix_length, byteorder="little")
        return prefix + bs

116 

117 

118def _getenv(key: str) -> str: 

119 """Like os.getenv() but throws if the environment variable is not set.""" 

120 assert key 

121 

122 value = os.getenv(key) 

123 if value is None: 

124 raise RuntimeError("environment variable %s not set" % key) 

125 return value 

126 

127 

128def _fix_hostname(hostname: str) -> str: 

129 """ 

130 This function returns "localhost" if the hostname is the name of the 

131 machine running this code; otherwise it returns the hostname unchanged. 

132 

133 Since 2005, Debian-based distributions make the hostname an alias to 

134 127.0.1.1 (take a look a `/etc/hosts`) and in some cases, this leads to a 

135 hanging connect(2) when server and client are on the same host (is this a 

136 bug?). The problem may occur with SCTP and TCP. 

137 

138 Related: 

139 * https://bugs.debian.org/cgi-bin/bugreport.cgi?bug=316099 

140 """ 

141 return "localhost" if hostname == socket.gethostname() else hostname 

142 

143 

def is_port_in_use(port: int, host: str = '127.0.0.1') -> bool:
    """Report whether a TCP connection to ``host``:``port`` succeeds.

    A throw-away IPv4 stream socket attempts the connection; the port is
    considered in use when connect_ex() returns 0.
    """
    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        status = probe.connect_ex((host, port))
    finally:
        # The probe socket is closed no matter how the attempt ends.
        probe.close()
    return status == 0

148 

149 

def is_launcher_socket_alive_and_ready(launcherfd: socket.socket) -> bool:
    """Post handshake sanity check on the launcher fd.

    Sleeps for one second, performs a single recv() on ``launcherfd`` and
    compares the received payload with PEER_SANITY_MSG. The protocol is read
    from the MELISSA_LAUNCHER_PROTOCOL environment variable; for TCP the
    payload is first extracted from its length-prefix frame.

    Args:
        launcherfd: connected socket to the launcher.

    Returns:
        True if the received payload equals PEER_SANITY_MSG. False if the
        peer closed the connection, a socket error occurred, the TCP frame
        was incomplete, or the payload differs.

    Raises:
        RuntimeError: if MELISSA_LAUNCHER_PROTOCOL is not set.
        ValueError: if MELISSA_LAUNCHER_PROTOCOL names an unknown protocol.
    """
    protocol_str = _getenv("MELISSA_LAUNCHER_PROTOCOL")
    protocol = str2protocol(protocol_str)

    try:
        # NOTE(review): presumably gives the launcher time to send its
        # message before the single recv() below -- confirm.
        time.sleep(1)
        data = launcherfd.recv(TCP_MESSAGE_PREFIX_LENGTH + len(PEER_SANITY_MSG))
    except socket.error as e:
        logger.error(e)
        return False

    if len(data) == 0:
        # recv() returning no bytes means the peer closed the connection.
        logger.error("launcher closed the connection before sending the sanity message")
        return False

    if protocol == socket.IPPROTO_TCP:
        messages = LengthPrefixFramingDecoder(
            TCP_MESSAGE_PREFIX_LENGTH
        ).execute(data)
        if not messages:
            # Only part of the frame arrived; indexing [0] would raise.
            logger.error(f"incomplete sanity message received from launcher: {data!r}")
            return False
        data = messages[0]

    # errors='replace' keeps the log statement from raising on garbage input.
    received = data.decode('utf-8', errors='replace')
    logger.debug(
        f"SANITY_MSG={PEER_SANITY_MSG.decode('utf-8')!r} "
        f"RECEIVED={received!r}"
    )
    return data == PEER_SANITY_MSG

175 

176 

def connect_to_launcher() -> socket.socket:
    """Performs primary socket handshake with the launcher process.

    Host, port and protocol are read from the environment variables
    MELISSA_LAUNCHER_HOST, MELISSA_LAUNCHER_PORT and
    MELISSA_LAUNCHER_PROTOCOL ("SCTP" or "TCP").

    Returns:
        The connected IPv4 stream socket.

    Raises:
        RuntimeError: if one of the environment variables is not set.
        ValueError: if the port is not an integer or the protocol is unknown.
        Exception: if connecting fails; the socket is closed first and the
            original error is attached as the cause.
    """
    launcher_host_raw = _getenv("MELISSA_LAUNCHER_HOST")
    launcher_host = _fix_hostname(launcher_host_raw)
    launcher_port_str = _getenv("MELISSA_LAUNCHER_PORT")
    launcher_port = int(launcher_port_str)
    protocol_str = _getenv("MELISSA_LAUNCHER_PROTOCOL")
    protocol = str2protocol(protocol_str)

    logger.debug(
        f"connecting to launcher host={launcher_host} port={launcher_port} protocol={protocol_str}"
    )

    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, protocol)
    try:
        s.connect((launcher_host, launcher_port))
    except Exception as e:
        # Do not leak the descriptor when the connection attempt fails.
        s.close()
        raise Exception(f'Exception caught during connection {e}') from e
    return s

197 

198 

def connect_to_server() -> Socket:
    """Connect to the server named by the environment and return the socket.

    The host is read from MELISSA_SERVER_HOST (passed through
    _fix_hostname) and the port from MELISSA_SERVER_PORT; the connection is
    opened with socket.create_connection().
    """
    host = _fix_hostname(_getenv("MELISSA_SERVER_HOST"))
    port = int(_getenv("MELISSA_SERVER_PORT"))

    logger.debug(f"connecting to server host={host} port={port}")

    endpoint = (host, port)
    return socket.create_connection(endpoint)

210 

211 

def get_available_protocols() -> List[int]:
    """Return the protocols for which an IPv4 stream socket can be created.

    SCTP and TCP are probed in that order by creating and immediately
    closing a socket. A protocol is skipped when the attempt fails with
    EPROTONOSUPPORT; any other OSError is propagated.
    """
    candidates = [socket.IPPROTO_SCTP, socket.IPPROTO_TCP]

    available = []  # type: List[int]
    for candidate in candidates:
        try:
            socket.socket(socket.AF_INET, socket.SOCK_STREAM, candidate).close()
        except OSError as err:
            if err.errno != errno.EPROTONOSUPPORT:
                raise
        else:
            available.append(candidate)
    return available

227 

228 

def make_passive_socket(
    node: Optional[str] = None,
    *,
    protocol: int,
    backlog: Optional[int] = None
) -> Socket:
    """Create a listening IPv4 stream socket bound to an ephemeral port.

    Args:
        node: address to bind to; None binds to all interfaces ("").
        protocol: protocol number passed to socket.socket().
        backlog: argument for listen(); None uses the listen() default.

    Returns:
        The listening socket.

    Raises:
        Exception: if binding or listening fails; the socket is closed
            first and the original error is attached as the cause.
    """
    # The IP compatibility layer on Infiniband is usually bound to an IPv4
    # address.
    sockfd = socket.socket(socket.AF_INET, socket.SOCK_STREAM, protocol)
    try:
        # Port 0 lets the operating system pick a free port.
        sockfd.bind(("" if node is None else node, 0))
        if backlog is None:
            sockfd.listen()
        else:
            sockfd.listen(backlog)

        listen_addr = sockfd.getsockname()
        logger.info(f"listening for connections on {listen_addr[0]}:{listen_addr[1]}")
    except Exception as e:
        # Do not leak the descriptor when setting up the listener fails.
        sockfd.close()
        raise Exception(f'Exception caught listening for connection {e}') from e

    return sockfd

251 

252 

# Socket objects are used everywhere so that callers never have to tell
# Socket objects and plain integer file descriptors apart.
def pipe() -> Tuple[Socket, Socket]:
    """Return a ``(read end, write end)`` pair built from socketpair().

    The read end is shut down for writing and the write end is shut down
    for reading.
    """
    reader, writer = socketpair()
    for end, direction in ((reader, socket.SHUT_WR), (writer, socket.SHUT_RD)):
        end.shutdown(direction)
    return reader, writer

260 

261 

def select_protocol() -> Tuple[int, str]:
    """Select the protocol named by MELISSA_LAUNCHER_PROTOCOL.

    Returns the pair (protocol number, protocol name) for "SCTP" or "TCP".
    Raises KeyError when the variable is not set and UnsupportedProtocol
    for any other value.
    """
    try:
        requested = os.environ["MELISSA_LAUNCHER_PROTOCOL"]
    except KeyError as e:
        raise KeyError("Undefined protocol") from e

    for label in ("SCTP", "TCP"):
        if requested == label:
            # The constant is looked up only for the matching name.
            return getattr(socket, "IPPROTO_" + label), label
    raise UnsupportedProtocol(f"{requested} not supported for server-launcher communication.")

274 

275 

def protocol2str(proto: int) -> str:
    """Return "SCTP" or "TCP" for the matching socket.IPPROTO_* constant.

    Raises ValueError for any other protocol number.
    """
    for label in ("SCTP", "TCP"):
        if proto == getattr(socket, "IPPROTO_" + label):
            return label
    raise ValueError("unknown protocol {:d}".format(proto))

282 

283 

def socketpair() -> Tuple[Socket, Socket]:
    """Return a pair of connected AF_UNIX stream sockets."""
    first, second = socket.socketpair(socket.AF_UNIX, socket.SOCK_STREAM)
    return first, second

286 

287 

288def str2protocol(name: str) -> int: 

289 if name == "SCTP": 

290 return socket.IPPROTO_SCTP 

291 if name == "TCP": 

292 return socket.IPPROTO_TCP 

293 raise ValueError("unknown protocol {:s}".format(name))