Coverage for melissa/utility/networking.py: 80%

139 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-03-10 22:25 +0100

1#!/usr/bin/python3 

2 

3# Copyright (c) 2021-2022, Institut National de Recherche en Informatique et en Automatique (Inria) 

4# All rights reserved. 

5# 

6# Redistribution and use in source and binary forms, with or without 

7# modification, are permitted provided that the following conditions are met: 

8# 

9# * Redistributions of source code must retain the above copyright notice, this 

10# list of conditions and the following disclaimer. 

11# 

12# * Redistributions in binary form must reproduce the above copyright notice, 

13# this list of conditions and the following disclaimer in the documentation 

14# and/or other materials provided with the distribution. 

15# 

16# * Neither the name of the copyright holder nor the names of its 

17# contributors may be used to endorse or promote products derived from 

18# this software without specific prior written permission. 

19# 

20# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 

21# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 

22# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 

23# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 

24# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 

25# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 

26# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 

27# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 

28# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 

29# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 

30 

31# pylint: skip-file 

32 

33 

"""Networking utilities for the connection between the launcher and the server.

Some of the functions in this module are currently unused."""

__all__ = [
    "ConnectionId", "LengthPrefixFramingEncoder", "LengthPrefixFramingDecoder",
    "Socket", "connect_to_launcher", "connect_to_server",
    "get_available_protocols", "make_passive_socket", "pipe", "protocol2str",
    "socketpair", "str2protocol", "select_protocol", "is_port_in_use",
    "UnsupportedProtocol",
]

44 

45import os 

46import errno 

47import logging 

48import socket 

49from typing import List, Optional, Tuple 

50from socket import socket as Socket 

51from mpi4py import MPI 

52 

# Logger for this module, named after the module itself.
logger: logging.Logger = logging.getLogger(__name__)

ConnectionId = int  # type alias used to annotate connection identifiers

56 

57 

class UnsupportedProtocol(Exception):
    """Raised when a requested transport protocol is not supported.

    :param msg: description of the offending protocol request.
    """

    def __init__(self, msg) -> None:
        # Forward msg to Exception so that ``args`` is populated. Without it,
        # unpickling calls ``cls(*args)`` with no arguments and fails because
        # ``msg`` is required.
        super().__init__(msg)
        self.msg = msg

    def __str__(self) -> str:
        return f"Unsupported Protocol: {self.msg}"

65 

66 

class LengthPrefixFramingDecoder:
    """Incrementally split a byte stream into length-prefixed messages.

    Every message is preceded by an unsigned, little-endian length field of
    ``prefix_length`` bytes. Bytes that do not yet form a complete message are
    buffered until a later call to :meth:`execute` supplies the rest.
    """

    def __init__(self, prefix_length: int) -> None:
        """
        :param prefix_length: size in bytes of the length field, 1 to 8.
        :raises ValueError: if ``prefix_length`` is outside ``[1, 8]``.
        """
        # Explicit checks rather than ``assert`` so validation survives -O.
        if prefix_length <= 0:
            raise ValueError(
                "prefix length must be positive, got {:d}".
                format(prefix_length)
            )
        if prefix_length > 8:
            raise ValueError(
                "prefix length must be at most 8, got {:d}".
                format(prefix_length)
            )
        self._prefix_length = prefix_length
        self._buffer = b''

    def execute(self, bs: bytes) -> List[bytes]:
        """Feed ``bs`` into the decoder and return every completed message.

        :param bs: newly received bytes (may be empty).
        :returns: the payloads (without prefixes) of all messages completed by
            this call, in stream order; an empty list if none completed.
        """
        data = self._buffer + bs
        prefix_length = self._prefix_length

        messages = []  # type: List[bytes]
        # Walk the data with an offset instead of re-slicing the remainder
        # after each message, which would copy it once per message (O(n^2)).
        offset = 0
        while len(data) - offset >= prefix_length:
            body_start = offset + prefix_length
            num_message_bytes = int.from_bytes(
                data[offset:body_start], byteorder="little"
            )
            body_end = body_start + num_message_bytes

            if body_end > len(data):
                # Payload not fully received yet; wait for more bytes.
                break

            messages.append(data[body_start:body_end])
            offset = body_end

        # Keep any incomplete trailing message for the next call.
        self._buffer = data[offset:]
        return messages

97 

98 

class LengthPrefixFramingEncoder:
    """Frame messages by prepending their length.

    The length is written as an unsigned, little-endian integer of
    ``prefix_length`` bytes, so a message may be at most
    ``256 ** prefix_length - 1`` bytes long.
    """

    def __init__(self, prefix_length: int) -> None:
        """
        :param prefix_length: size in bytes of the length field, 1 to 8.
        :raises ValueError: if ``prefix_length`` is outside ``[1, 8]``.
        """
        # Explicit checks rather than ``assert`` so validation survives -O.
        if prefix_length <= 0:
            raise ValueError(
                "prefix length must be positive, got {:d}".
                format(prefix_length)
            )
        if prefix_length > 8:
            raise ValueError(
                "prefix length must be at most 8, got {:d}".
                format(prefix_length)
            )
        self._prefix_length = prefix_length

    def execute(self, bs: bytes) -> bytes:
        """Return ``bs`` preceded by its length prefix.

        :raises RuntimeError: if ``bs`` is too long to be described by the
            prefix.
        """
        n = len(bs)
        bits_per_byte = 8
        # First length that no longer fits in prefix_length unsigned bytes.
        max_message_length = 1 << (self._prefix_length * bits_per_byte)
        if n >= max_message_length:
            raise RuntimeError(
                "message of length {:d} too large for length "
                "prefix framing with a prefix of length {:d}"
                .format(n, self._prefix_length)
            )

        prefix = n.to_bytes(self._prefix_length, byteorder="little")
        return prefix + bs

122 

123 

124def _getenv(key: str) -> str: 

125 """Like os.getenv() but throws if the environment variable is not set.""" 

126 assert key 

127 

128 value = os.getenv(key) 

129 if value is None: 

130 raise RuntimeError("environment variable %s not set" % key) 

131 return value 

132 

133 

134def _fix_hostname(hostname: str) -> str: 

135 """ 

136 This function returns "localhost" if the hostname is the name of the 

137 machine running this code; otherwise it returns the hostname unchanged. 

138 

139 Since 2005, Debian-based distributions make the hostname an alias to 

140 127.0.1.1 (take a look a `/etc/hosts`) and in some cases, this leads to a 

141 hanging connect(2) when server and client are on the same host (is this a 

142 bug?). The problem may occur with SCTP and TCP. 

143 

144 Related: 

145 * https://bugs.debian.org/cgi-bin/bugreport.cgi?bug=316099 

146 """ 

147 return "localhost" if hostname == socket.gethostname() else hostname 

148 

149 

def is_port_in_use(port: int, host: str = '127.0.0.1') -> bool:
    """Return True if a TCP connection to ``host:port`` succeeds.

    The probe socket is always closed before returning.
    """
    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        status = probe.connect_ex((host, port))
    finally:
        probe.close()
    return status == 0

154 

155 

def connect_to_launcher() -> socket.socket:
    """Open a connection to the launcher process.

    The launcher address and transport protocol come from the
    ``MELISSA_LAUNCHER_HOST``, ``MELISSA_LAUNCHER_PORT`` and
    ``MELISSA_LAUNCHER_PROTOCOL`` environment variables. The protocol may be
    SCTP or TCP; the socket is always an IPv4 stream socket.

    :returns: the connected socket, which the caller must close.
    :raises RuntimeError: if one of the environment variables is not set.
    :raises ValueError: if the port is not an integer or the protocol name is
        unknown.
    :raises Exception: if the connection attempt fails. The socket is closed
        and the original error is chained as ``__cause__``.
    """
    launcher_host_raw = _getenv("MELISSA_LAUNCHER_HOST")
    launcher_host = _fix_hostname(launcher_host_raw)
    launcher_port_str = _getenv("MELISSA_LAUNCHER_PORT")
    launcher_port = int(launcher_port_str)
    protocol_str = _getenv("MELISSA_LAUNCHER_PROTOCOL")
    protocol = str2protocol(protocol_str)

    logger.debug(
        f"connecting to launcher host={launcher_host} port={launcher_port} protocol={protocol_str}"
    )

    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, protocol)
    try:
        s.connect((launcher_host, launcher_port))
    except Exception as e:
        # Do not leak the descriptor when the connection fails.
        s.close()
        # Chain explicitly so the underlying error is reported as the cause.
        raise Exception(f'Exception caught during connection {e}') from e
    return s

176 

177 

def connect_to_server() -> Socket:
    """Open a connection to the server.

    The server address is read from the ``MELISSA_SERVER_HOST`` and
    ``MELISSA_SERVER_PORT`` environment variables.

    :returns: the socket created by :func:`socket.create_connection`.
    :raises RuntimeError: if either environment variable is not set.
    :raises ValueError: if the port is not an integer.
    """
    host = _fix_hostname(_getenv("MELISSA_SERVER_HOST"))
    port = int(_getenv("MELISSA_SERVER_PORT"))

    logger.debug(f"connecting to server host={host} port={port}")

    address = (host, port)
    return socket.create_connection(address)

189 

190 

def get_available_protocols() -> List[int]:
    """Return the IPv4 stream protocols this host can open sockets for.

    SCTP and TCP are probed in that order by creating and immediately closing
    a socket for each. A protocol is skipped when socket creation fails with
    ``EPROTONOSUPPORT``.

    :raises OSError: if probing fails for any other reason.
    """
    candidates = (socket.IPPROTO_SCTP, socket.IPPROTO_TCP)
    supported = []  # type: List[int]
    for proto in candidates:
        try:
            with socket.socket(socket.AF_INET, socket.SOCK_STREAM, proto):
                pass
        except OSError as err:
            if err.errno != errno.EPROTONOSUPPORT:
                raise
            continue
        supported.append(proto)
    return supported

206 

207 

def make_passive_socket(
    node: Optional[str] = None,
    *,
    protocol: int,
    backlog: Optional[int] = None
) -> Socket:
    """Create an IPv4 stream socket listening on a kernel-chosen port.

    :param node: local address to bind to. ``None`` binds to ``""``, i.e. all
        interfaces.
    :param protocol: transport protocol number, e.g. ``socket.IPPROTO_TCP``
        or ``socket.IPPROTO_SCTP``.
    :param backlog: listen backlog. ``None`` uses the default of
        :meth:`socket.socket.listen`.
    :returns: the listening socket, which the caller must close.
    :raises Exception: if binding or listening fails. The socket is closed
        and the original error is chained as ``__cause__``.
    """
    # The IP compatibility layer on Infiniband is usually bound to an IPv4
    # address.
    sockfd = socket.socket(socket.AF_INET, socket.SOCK_STREAM, protocol)
    try:
        # Port 0 asks the kernel to pick a free port.
        sockfd.bind(("" if node is None else node, 0))
        if backlog is None:
            sockfd.listen()
        else:
            sockfd.listen(backlog)

        listen_addr = sockfd.getsockname()
        logger.info(f"listening for connections on {listen_addr[0]}:{listen_addr[1]}")
    except Exception as e:
        # Do not leak the descriptor when setup fails.
        sockfd.close()
        raise Exception(f'Exception caught listening for connection {e}') from e

    return sockfd

230 

231 

# Return Socket objects everywhere so that callers never have to distinguish
# between Socket objects and plain integer file descriptors.
def pipe() -> Tuple[Socket, Socket]:
    """Create a one-way channel from a connected socket pair.

    :returns: ``(read_end, write_end)``. The read end is shut down for
        writing and the write end is shut down for reading.
    """
    reader, writer = socketpair()
    reader.shutdown(socket.SHUT_WR)
    writer.shutdown(socket.SHUT_RD)
    return reader, writer

239 

240 

def select_protocol() -> Tuple[int, str]:
    """Return the launcher transport protocol configured in the environment.

    Reads ``MELISSA_LAUNCHER_PROTOCOL``, which must be ``"SCTP"`` or ``"TCP"``.

    :returns: ``(protocol_number, protocol_name)``.
    :raises KeyError: if ``MELISSA_LAUNCHER_PROTOCOL`` is not set.
    :raises UnsupportedProtocol: for any other protocol name.
    """
    try:
        requested = os.environ["MELISSA_LAUNCHER_PROTOCOL"]
    except KeyError as missing:
        raise KeyError("Undefined protocol") from missing

    if requested == "TCP":
        return socket.IPPROTO_TCP, "TCP"
    if requested == "SCTP":
        return socket.IPPROTO_SCTP, "SCTP"
    raise UnsupportedProtocol(f"{requested} not supported for server-launcher communication.")

253 

254 

def protocol2str(proto: int) -> str:
    """Return the name (``"SCTP"`` or ``"TCP"``) of a protocol number.

    :raises ValueError: if ``proto`` is neither SCTP nor TCP.
    """
    if proto == socket.IPPROTO_SCTP:
        name = "SCTP"
    elif proto == socket.IPPROTO_TCP:
        name = "TCP"
    else:
        raise ValueError("unknown protocol {:d}".format(proto))
    return name

261 

262 

def socketpair() -> Tuple[Socket, Socket]:
    """Create a connected pair of Unix-domain stream sockets."""
    family, kind = socket.AF_UNIX, socket.SOCK_STREAM
    return socket.socketpair(family, kind)

265 

266 

def str2protocol(name: str) -> int:
    """Return the protocol number for ``"SCTP"`` or ``"TCP"``.

    :raises ValueError: for any other name.
    """
    if name not in ("SCTP", "TCP"):
        raise ValueError("unknown protocol {:s}".format(name))
    return socket.IPPROTO_SCTP if name == "SCTP" else socket.IPPROTO_TCP

273 

274 

def get_rank_and_num_server_proc() -> Tuple[int, int]:
    """Return ``(rank, size)`` of this process within ``MPI.COMM_WORLD``."""
    world = MPI.COMM_WORLD
    rank = world.Get_rank()
    return rank, world.Get_size()