Coverage for melissa/utility/networking.py: 73%
148 statements
« prev ^ index » next coverage.py v7.10.1, created at 2025-11-03 09:52 +0100
« prev ^ index » next coverage.py v7.10.1, created at 2025-11-03 09:52 +0100
1#!/usr/bin/python3
3# Copyright (c) 2021-2022, Institut National de Recherche en Informatique et en Automatique (Inria)
4# All rights reserved.
5#
6# Redistribution and use in source and binary forms, with or without
7# modification, are permitted provided that the following conditions are met:
8#
9# * Redistributions of source code must retain the above copyright notice, this
10# list of conditions and the following disclaimer.
11#
12# * Redistributions in binary form must reproduce the above copyright notice,
13# this list of conditions and the following disclaimer in the documentation
14# and/or other materials provided with the distribution.
15#
16# * Neither the name of the copyright holder nor the names of its
17# contributors may be used to endorse or promote products derived from
18# this software without specific prior written permission.
19#
20# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31# pylint: skip-file
"""This module contains networking utilities for the connection between the
launcher and the server. Some of these functions are not currently used."""
# Public API of this module. ``is_launcher_socket_alive_and_ready`` is public
# (no leading underscore) and is exported alongside ``connect_to_launcher``.
__all__ = [
    "ConnectionId", "LengthPrefixFramingEncoder", "LengthPrefixFramingDecoder",
    "Socket", "connect_to_launcher", "connect_to_server",
    "get_available_protocols", "make_passive_socket", "pipe", "protocol2str",
    "socketpair", "str2protocol", "select_protocol", "is_port_in_use",
    "is_launcher_socket_alive_and_ready",
]
44import os
45import errno
46import time
47import logging
48import socket
49from typing import List, Optional, Tuple
50from socket import socket as Socket
51from melissa.launcher.config import TCP_MESSAGE_PREFIX_LENGTH
52from melissa.server.exceptions import UnsupportedProtocol
# Module-level logger, named after this module.
logger: logging.Logger = logging.getLogger(__name__)

# Alias used to label peer connections; it is simply ``int``.
ConnectionId = int

# Payload expected from the launcher after the handshake; compared against
# the received bytes by is_launcher_socket_alive_and_ready().
PEER_SANITY_MSG: bytes = b"salut"
class LengthPrefixFramingDecoder:
    """Incrementally split a byte stream into length-prefixed messages.

    On the wire, each message is an unsigned little-endian length field of
    ``prefix_length`` bytes followed by that many payload bytes. Bytes that do
    not yet form a complete message are buffered until a later ``execute``
    call supplies the rest.
    """

    def __init__(self, prefix_length: int) -> None:
        """
        Args:
            prefix_length: size in bytes of the length field (1 to 8).

        Raises:
            ValueError: if ``prefix_length`` is not in ``[1, 8]``.
        """
        # Explicit checks rather than ``assert`` so validation survives
        # ``python -O``.
        if prefix_length <= 0:
            raise ValueError(
                "prefix length must be positive, got {:d}".format(prefix_length)
            )
        if prefix_length > 8:
            raise ValueError(
                "prefix length must be at most 8, got {:d}".format(prefix_length)
            )
        self._prefix_length = prefix_length
        self._buffer = b''

    def execute(self, bs: bytes) -> List[bytes]:
        """Feed ``bs`` to the decoder.

        Returns:
            The payloads of all messages completed by this call, in order
            (possibly empty). Trailing incomplete data is kept for next time.
        """
        xs = self._buffer + bs
        prefix_length = self._prefix_length
        total = len(xs)

        messages: List[bytes] = []
        offset = 0
        # Walk an offset instead of re-slicing ``xs`` after every message,
        # which would copy the remaining buffer once per decoded message.
        while total - offset >= prefix_length:
            start = offset + prefix_length
            num_message_bytes = int.from_bytes(
                xs[offset:start], byteorder="little"
            )
            end = start + num_message_bytes
            if end > total:
                break  # payload not fully received yet
            messages.append(xs[start:end])
            offset = end

        self._buffer = xs[offset:]
        return messages
class LengthPrefixFramingEncoder:
    """Prepend a fixed-size little-endian length field to each message."""

    def __init__(self, prefix_length: int) -> None:
        """
        Args:
            prefix_length: size in bytes of the length field (1 to 8).

        Raises:
            ValueError: if ``prefix_length`` is not in ``[1, 8]``.
        """
        # Explicit checks rather than ``assert`` so validation survives
        # ``python -O``.
        if prefix_length <= 0:
            raise ValueError(
                "prefix length must be positive, got {:d}".format(prefix_length)
            )
        if prefix_length > 8:
            raise ValueError(
                "prefix length must be at most 8, got {:d}".format(prefix_length)
            )
        self._prefix_length = prefix_length

    def execute(self, bs: bytes) -> bytes:
        """Return ``bs`` framed as ``<length><payload>``.

        Raises:
            RuntimeError: if ``len(bs)`` does not fit in the length field.
        """
        n = len(bs)
        bits_per_byte = 8
        max_message_length = 1 << (self._prefix_length * bits_per_byte)
        if n >= max_message_length:
            # Note the trailing space: the two literals are concatenated.
            raise RuntimeError(
                "message of length {:d} too large for length "
                "prefix framing with a prefix of length {:d}"
                .format(n, self._prefix_length)
            )

        prefix = n.to_bytes(self._prefix_length, byteorder="little")
        return prefix + bs
def _getenv(key: str) -> str:
    """Return the value of environment variable ``key``.

    Unlike ``os.getenv``, a missing variable is an error instead of ``None``.

    Raises:
        RuntimeError: if ``key`` is not set in the environment.
    """
    assert key

    found = os.environ.get(key)
    if found is not None:
        return found
    raise RuntimeError(f"environment variable {key} not set")
def _fix_hostname(hostname: str) -> str:
    """Map this machine's own hostname to ``"localhost"``.

    Any other hostname is returned as-is.

    Why: Debian-based distributions have aliased the hostname to 127.0.1.1
    since 2005 (see ``/etc/hosts``). In some cases this makes connect(2) hang
    when client and server share a host, with both SCTP and TCP.

    Related:
      * https://bugs.debian.org/cgi-bin/bugreport.cgi?bug=316099
    """
    if hostname != socket.gethostname():
        return hostname
    return "localhost"
def is_port_in_use(port: int, host: str = '127.0.0.1') -> bool:
    """Return True when a TCP connection to ``(host, port)`` succeeds.

    The probe socket is always closed, whatever the outcome.
    """
    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        status = probe.connect_ex((host, port))
    finally:
        probe.close()
    return status == 0
def is_launcher_socket_alive_and_ready(launcherfd: socket.socket) -> bool:
    """Post-handshake sanity check on the launcher socket.

    The check sleeps one second and performs a single ``recv``. It then
    compares the received payload with ``PEER_SANITY_MSG``. When the protocol
    is TCP, the payload is length-prefix framed and is decoded first.

    Args:
        launcherfd: connected socket to the launcher.

    Returns:
        True if the expected sanity message was received. False if a socket
        error occurred, the peer closed the connection, the framed message was
        incomplete, or the payload differs.

    Raises:
        RuntimeError: if MELISSA_LAUNCHER_PROTOCOL is not set.
        ValueError: if MELISSA_LAUNCHER_PROTOCOL names an unknown protocol.
    """
    protocol = str2protocol(_getenv("MELISSA_LAUNCHER_PROTOCOL"))

    try:
        # NOTE(review): the fixed delay presumably gives the launcher time to
        # send the sanity message; confirm whether it is still required.
        time.sleep(1)
        data = launcherfd.recv(TCP_MESSAGE_PREFIX_LENGTH + len(PEER_SANITY_MSG))
    except OSError as e:  # socket.error is an alias of OSError
        logger.error(e)
        return False

    if not data:
        logger.error(
            "launcher closed the connection before sending the sanity message"
        )
        return False

    if protocol == socket.IPPROTO_TCP:
        messages = LengthPrefixFramingDecoder(
            TCP_MESSAGE_PREFIX_LENGTH
        ).execute(data)
        if not messages:
            # A single recv() may return fewer bytes than one framed message;
            # previously this raised an uncaught IndexError.
            logger.error(
                "incomplete sanity message received from launcher: %r", data
            )
            return False
        data = messages[0]

    # %r on raw bytes cannot fail, unlike decoding arbitrary peer data as UTF-8.
    logger.debug("SANITY_MSG=%r RECEIVED=%r", PEER_SANITY_MSG, data)
    return data == PEER_SANITY_MSG
def connect_to_launcher() -> socket.socket:
    """Open the primary stream socket to the launcher process.

    Host, port and protocol (SCTP or TCP) are read from the environment
    variables MELISSA_LAUNCHER_HOST, MELISSA_LAUNCHER_PORT and
    MELISSA_LAUNCHER_PROTOCOL.

    Returns:
        A connected ``AF_INET``/``SOCK_STREAM`` socket using the selected
        protocol.

    Raises:
        RuntimeError: if one of the environment variables is not set.
        ValueError: if the port is not an integer or the protocol is unknown.
        Exception: if connecting fails; the original error is chained as
            ``__cause__``.
    """
    launcher_host = _fix_hostname(_getenv("MELISSA_LAUNCHER_HOST"))
    launcher_port = int(_getenv("MELISSA_LAUNCHER_PORT"))
    protocol_str = _getenv("MELISSA_LAUNCHER_PROTOCOL")
    protocol = str2protocol(protocol_str)

    logger.debug(
        f"connecting to launcher host={launcher_host} port={launcher_port} protocol={protocol_str}"
    )

    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, protocol)
    try:
        s.connect((launcher_host, launcher_port))
    except Exception as e:
        # Do not leak the socket; keep the root cause in the traceback.
        s.close()
        raise Exception(f'Exception caught during connection {e}') from e
    return s
def connect_to_server() -> Socket:
    """Open a stream connection to the Melissa server.

    The address is taken from MELISSA_SERVER_HOST and MELISSA_SERVER_PORT.
    The host equal to this machine's hostname is rewritten to "localhost".

    Raises:
        RuntimeError: if either environment variable is not set.
        ValueError: if MELISSA_SERVER_PORT is not an integer.
    """
    host = _fix_hostname(_getenv("MELISSA_SERVER_HOST"))
    port = int(_getenv("MELISSA_SERVER_PORT"))

    logger.debug(f"connecting to server host={host} port={port}")

    address = (host, port)
    return socket.create_connection(address)
def get_available_protocols() -> List[int]:
    """Return the stream protocols this host supports, SCTP first, then TCP.

    A protocol counts as supported if an ``AF_INET``/``SOCK_STREAM`` socket
    can be created with it. Failures other than ``EPROTONOSUPPORT`` are
    re-raised.
    """
    def probe(proto: int) -> bool:
        try:
            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, proto)
        except OSError as err:
            if err.errno == errno.EPROTONOSUPPORT:
                return False
            raise
        sock.close()
        return True

    candidates = (socket.IPPROTO_SCTP, socket.IPPROTO_TCP)
    return list(filter(probe, candidates))
def make_passive_socket(
    node: Optional[str] = None,
    *,
    protocol: int,
    backlog: Optional[int] = None
) -> Socket:
    """Create a listening IPv4 stream socket on an OS-chosen port.

    Args:
        node: address to bind to; ``None`` binds to all interfaces ("").
        protocol: transport protocol number (e.g. ``socket.IPPROTO_TCP``).
        backlog: value for ``listen``; ``None`` uses the library default.

    Returns:
        The bound, listening socket.

    Raises:
        Exception: if binding or listening fails. The socket is closed and the
            original error is chained as ``__cause__``.
    """
    # The IP compatibility layer on Infiniband is usually bound to an IPv4
    # address.
    sockfd = socket.socket(socket.AF_INET, socket.SOCK_STREAM, protocol)
    try:
        # Port 0 asks the OS for any free port.
        sockfd.bind(("" if node is None else node, 0))
        if backlog is None:
            sockfd.listen()
        else:
            sockfd.listen(backlog)

        listen_addr = sockfd.getsockname()
        logger.info(f"listening for connections on {listen_addr[0]}:{listen_addr[1]}")
    except Exception as e:
        sockfd.close()
        raise Exception(f'Exception caught listening for connection {e}') from e

    return sockfd
# Socket objects are used everywhere (rather than raw integer file
# descriptors) so callers never have to distinguish between the two.
def pipe() -> Tuple[Socket, Socket]:
    """Return a one-way ``(read_end, write_end)`` pair built on socketpair()."""
    read_end, write_end = socketpair()
    read_end.shutdown(socket.SHUT_WR)   # the reading end never writes
    write_end.shutdown(socket.SHUT_RD)  # the writing end never reads
    return read_end, write_end
def select_protocol() -> Tuple[int, str]:
    """Select the protocol for server-launcher socket messages.

    The choice comes from the MELISSA_LAUNCHER_PROTOCOL environment variable.

    Returns:
        ``(protocol_number, protocol_name)`` for "SCTP" or "TCP".

    Raises:
        KeyError: if MELISSA_LAUNCHER_PROTOCOL is not set.
        UnsupportedProtocol: for any value other than "SCTP" or "TCP".
    """
    try:
        name = os.environ["MELISSA_LAUNCHER_PROTOCOL"]
    except KeyError as err:
        raise KeyError("Undefined protocol") from err

    if name not in ("SCTP", "TCP"):
        raise UnsupportedProtocol(f"{name} not supported for server-launcher communication.")

    number = socket.IPPROTO_SCTP if name == "SCTP" else socket.IPPROTO_TCP
    return number, name
def protocol2str(proto: int) -> str:
    """Return "SCTP" or "TCP" for the matching ``IPPROTO_*`` number.

    Raises:
        ValueError: for any other protocol number.
    """
    if proto == socket.IPPROTO_SCTP:
        name = "SCTP"
    elif proto == socket.IPPROTO_TCP:
        name = "TCP"
    else:
        raise ValueError(f"unknown protocol {proto:d}")
    return name
def socketpair() -> Tuple[Socket, Socket]:
    """Return two connected UNIX-domain stream sockets."""
    first, second = socket.socketpair(socket.AF_UNIX, socket.SOCK_STREAM)
    return first, second
def str2protocol(name: str) -> int:
    """Map "SCTP" or "TCP" (case-sensitive) to its ``IPPROTO_*`` number.

    Raises:
        ValueError: for any other name.
    """
    if name != "SCTP" and name != "TCP":
        raise ValueError(f"unknown protocol {name:s}")
    return socket.IPPROTO_SCTP if name == "SCTP" else socket.IPPROTO_TCP