Coverage for melissa/utility/networking.py: 87%

125 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-09-22 10:36 +0000

1#!/usr/bin/python3 

2 

3# Copyright (c) 2021-2022, Institut National de Recherche en Informatique et en Automatique (Inria) 

4# All rights reserved. 

5# 

6# Redistribution and use in source and binary forms, with or without 

7# modification, are permitted provided that the following conditions are met: 

8# 

9# * Redistributions of source code must retain the above copyright notice, this 

10# list of conditions and the following disclaimer. 

11# 

12# * Redistributions in binary form must reproduce the above copyright notice, 

13# this list of conditions and the following disclaimer in the documentation 

14# and/or other materials provided with the distribution. 

15# 

16# * Neither the name of the copyright holder nor the names of its 

17# contributors may be used to endorse or promote products derived from 

18# this software without specific prior written permission. 

19# 

20# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 

21# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 

22# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 

23# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 

24# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 

25# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 

26# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 

27# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 

28# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 

29# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 

30 

# Public API of this module: the names exported by ``from ... import *``.
__all__ = [
    "ConnectionId",
    "LengthPrefixFramingEncoder",
    "LengthPrefixFramingDecoder",
    "Socket",
    "connect_to_launcher",
    "connect_to_server",
    "get_available_protocols",
    "make_passive_socket",
    "pipe",
    "protocol2str",
    "socketpair",
    "str2protocol",
]

37 

38import os 

39import errno 

40import logging 

41import socket 

42from socket import socket as Socket 

43from typing import List, Optional, Tuple 

44 

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Connection identifiers are plain integers (type alias).
ConnectionId = int

48 

49 

class LengthPrefixFramingDecoder:
    """Split a byte stream into messages framed by a length prefix.

    On the wire, every message is ``prefix_length`` bytes holding the payload
    size (unsigned, little-endian) followed by that many payload bytes. Bytes
    that do not yet form a complete frame are kept in an internal buffer and
    are combined with the data passed to the next :meth:`execute` call.
    """

    def __init__(self, prefix_length: int) -> None:
        """Create a decoder.

        :param prefix_length: size of the length prefix in bytes, in [1, 8].
        :raises ValueError: if ``prefix_length`` is outside [1, 8].
        """
        if prefix_length <= 0:
            raise ValueError(
                "prefix length must be positive, got {:d}".
                format(prefix_length)
            )
        # Explicit check instead of `assert`, which is stripped under -O.
        if prefix_length > 8:
            raise ValueError(
                "prefix length must be at most 8, got {:d}".
                format(prefix_length)
            )
        self._prefix_length = prefix_length
        self._buffer = b''

    def execute(self, bs: bytes) -> List[bytes]:
        """Feed ``bs`` to the decoder and return the messages it completes.

        :param bs: the next chunk of the byte stream (may be empty).
        :returns: the payloads (without prefixes) of every frame completed
            by this chunk, in stream order; possibly empty.
        """
        xs = self._buffer + bs
        prefix_length = self._prefix_length

        messages = []  # type: List[bytes]
        # Walk the data with an offset rather than re-slicing the remainder
        # after each message, which copied the tail once per message.
        start = 0
        while len(xs) - start >= prefix_length:
            num_message_bytes = int.from_bytes(
                xs[start:start + prefix_length], byteorder="little"
            )
            end = start + prefix_length + num_message_bytes
            if end > len(xs):
                # Incomplete frame: wait for more data.
                break

            messages.append(xs[start + prefix_length:end])
            start = end

        # Keep the unconsumed tail (a partial prefix or partial payload).
        self._buffer = xs[start:]
        return messages

80 

81 

class LengthPrefixFramingEncoder:
    """Frame messages with a length prefix.

    Each encoded message is ``prefix_length`` bytes holding the payload size
    (little-endian) followed by the payload itself.
    """

    def __init__(self, prefix_length: int) -> None:
        """Create an encoder.

        :param prefix_length: size of the length prefix in bytes, in [1, 8].
        :raises ValueError: if ``prefix_length`` is outside [1, 8].
        """
        if prefix_length <= 0:
            raise ValueError(
                "prefix length must be positive, got {:d}".
                format(prefix_length)
            )
        # Explicit check instead of `assert`, which is stripped under -O.
        if prefix_length > 8:
            raise ValueError(
                "prefix length must be at most 8, got {:d}".
                format(prefix_length)
            )
        self._prefix_length = prefix_length

    def execute(self, bs: bytes) -> bytes:
        """Return ``bs`` preceded by its length prefix.

        :param bs: the message payload.
        :returns: the framed message.
        :raises RuntimeError: if ``len(bs)`` does not fit in the prefix,
            i.e. ``len(bs) >= 256 ** prefix_length``.
        """
        n = len(bs)
        bits_per_byte = 8
        max_message_length = 1 << (self._prefix_length * bits_per_byte)
        if n >= max_message_length:
            # Note the trailing space: adjacent literals are concatenated
            # verbatim.
            raise RuntimeError(
                "message of length {:d} too large for length "
                "prefix framing with a prefix of length {:d}"
                .format(n, self._prefix_length)
            )

        prefix = n.to_bytes(self._prefix_length, byteorder="little")
        return prefix + bs

105 

106 

def _getenv(key: str) -> str:
    """Return the value of the environment variable ``key``.

    Unlike :func:`os.getenv`, a missing variable is an error instead of
    ``None``.

    :raises RuntimeError: if ``key`` is not set in the environment.
    """
    assert key

    value = os.environ.get(key)
    if value is not None:
        return value
    raise RuntimeError("environment variable %s not set" % key)

115 

116 

def _fix_hostname(hostname: str) -> str:
    """Map this machine's own hostname to ``"localhost"``; return others as-is.

    Since 2005, Debian-based distributions make the hostname an alias of
    127.0.1.1 (see `/etc/hosts`). In some cases this causes connect(2) to
    hang when server and client run on the same host, with both SCTP and
    TCP. Connecting to "localhost" instead avoids the problem.

    Related:
    * https://bugs.debian.org/cgi-bin/bugreport.cgi?bug=316099
    """
    if hostname != socket.gethostname():
        return hostname
    return "localhost"

131 

132 

def connect_to_launcher() -> Socket:
    """Open an IPv4 stream connection to the Melissa launcher.

    The endpoint is read from the environment variables
    ``MELISSA_LAUNCHER_HOST``, ``MELISSA_LAUNCHER_PORT`` and
    ``MELISSA_LAUNCHER_PROTOCOL`` (``"SCTP"`` or ``"TCP"``).

    :returns: the connected socket; the caller is responsible for closing it.
    :raises RuntimeError: if one of the environment variables is not set.
    :raises ValueError: if the port is not an integer or the protocol name
        is unknown.
    :raises OSError: if the socket cannot be created or connecting fails;
        the original exception (e.g. ``ConnectionRefusedError``) propagates
        unchanged.
    """
    launcher_host = _fix_hostname(_getenv("MELISSA_LAUNCHER_HOST"))
    launcher_port = int(_getenv("MELISSA_LAUNCHER_PORT"))
    protocol_str = _getenv("MELISSA_LAUNCHER_PROTOCOL")
    protocol = str2protocol(protocol_str)

    logger.debug(
        "connecting to launcher host=%s port=%d protocol=%s",
        launcher_host, launcher_port, protocol_str,
    )

    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, protocol)
    try:
        s.connect((launcher_host, launcher_port))
    except BaseException:
        # Do not leak the descriptor. Re-raise the original exception instead
        # of wrapping it in a bare Exception, so callers keep its type and
        # traceback.
        s.close()
        raise
    return s

152 

153 

def connect_to_server() -> Socket:
    """Open a stream connection to the Melissa server.

    The endpoint comes from ``MELISSA_SERVER_HOST`` and
    ``MELISSA_SERVER_PORT``. The connection is made with
    :func:`socket.create_connection`.
    """
    host = _fix_hostname(_getenv("MELISSA_SERVER_HOST"))
    port = int(_getenv("MELISSA_SERVER_PORT"))

    logger.debug(f"connecting to server host={host} port={port}")

    endpoint = (host, port)
    return socket.create_connection(endpoint)

165 

166 

def get_available_protocols() -> List[int]:
    """Return the stream protocols (SCTP, then TCP) this host can open.

    Each protocol is probed by creating, then immediately closing, an IPv4
    stream socket. A probe failing with ``EPROTONOSUPPORT`` marks the
    protocol as unavailable. Any other ``OSError`` is propagated.
    """
    available = []  # type: List[int]
    for candidate in (socket.IPPROTO_SCTP, socket.IPPROTO_TCP):
        try:
            socket.socket(socket.AF_INET, socket.SOCK_STREAM, candidate).close()
        except OSError as err:
            if err.errno != errno.EPROTONOSUPPORT:
                raise
        else:
            available.append(candidate)
    return available

182 

183 

def make_passive_socket(
    node: Optional[str] = None,
    *,
    protocol: int,
    backlog: Optional[int] = None
) -> Socket:
    """Create an IPv4 stream socket listening on a kernel-chosen port.

    :param node: local address to bind to; ``None`` binds to all interfaces.
    :param protocol: transport protocol number, e.g. ``socket.IPPROTO_TCP``
        or ``socket.IPPROTO_SCTP``.
    :param backlog: ``listen()`` backlog; ``None`` uses Python's default.
    :returns: the listening socket; the caller is responsible for closing
        it. Use ``getsockname()`` to find the port that was chosen.
    :raises OSError: if the socket cannot be created, bound or put into the
        listening state; the original exception propagates unchanged.
    """
    # The IP compatibility layer on Infiniband is usually bound to an IPv4
    # address.
    sockfd = socket.socket(socket.AF_INET, socket.SOCK_STREAM, protocol)
    try:
        # Port 0: let the kernel pick a free port.
        sockfd.bind(("" if node is None else node, 0))
        if backlog is None:
            sockfd.listen()
        else:
            sockfd.listen(backlog)
        listen_addr = sockfd.getsockname()
    except BaseException:
        # Do not leak the descriptor. Re-raise the original exception instead
        # of wrapping it in a bare Exception, so callers keep its type and
        # traceback.
        sockfd.close()
        raise

    logger.info(
        "listening for connections on %s:%d", listen_addr[0], listen_addr[1]
    )
    return sockfd

206 

207 

# Pipes are built from sockets so that callers only ever deal with Socket
# objects, never with a mix of Socket objects and raw integer file
# descriptors.
def pipe() -> Tuple[Socket, Socket]:
    """Return a one-way ``(reader, writer)`` pair built on :func:`socketpair`.

    The reader is shut down for writing and the writer for reading, so data
    can only flow from the writer to the reader.
    """
    reader, writer = socketpair()
    for end, direction in ((reader, socket.SHUT_WR), (writer, socket.SHUT_RD)):
        end.shutdown(direction)
    return reader, writer

215 

216 

def protocol2str(proto: int) -> str:
    """Return the name (``"SCTP"`` or ``"TCP"``) of a protocol number.

    :raises ValueError: for any other protocol number.
    """
    known = ((socket.IPPROTO_SCTP, "SCTP"), (socket.IPPROTO_TCP, "TCP"))
    for number, label in known:
        if proto == number:
            return label
    raise ValueError("unknown protocol {:d}".format(proto))

223 

224 

def socketpair() -> Tuple[Socket, Socket]:
    """Return two connected Unix-domain stream sockets."""
    family, kind = socket.AF_UNIX, socket.SOCK_STREAM
    return socket.socketpair(family, kind)

227 

228 

def str2protocol(name: str) -> int:
    """Return the protocol number for ``"SCTP"`` or ``"TCP"`` (case-sensitive).

    :raises ValueError: for any other name.
    """
    known = (("SCTP", socket.IPPROTO_SCTP), ("TCP", socket.IPPROTO_TCP))
    for label, number in known:
        if name == label:
            return number
    raise ValueError("unknown protocol {:s}".format(name))

235 

236 

def get_rank_and_num_server_proc() -> Tuple[int, int]:
    """Return ``(rank, number_of_processes)`` for the current server process.

    SLURM's ``SLURM_PROCID`` and ``SLURM_NTASKS`` are used when both are set.
    Otherwise the values come from ``MPI.COMM_WORLD``, and ``mpi4py`` is
    imported only in that case.
    """
    try:
        return int(os.environ["SLURM_PROCID"]), int(os.environ["SLURM_NTASKS"])
    except KeyError:
        # Not (fully) described by SLURM: fall back to MPI.
        from mpi4py import MPI
        world = MPI.COMM_WORLD
        return world.Get_rank(), world.Get_size()