Coverage for melissa/utility/networking.py: 80%
139 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-03-10 22:25 +0100
« prev ^ index » next coverage.py v7.6.12, created at 2025-03-10 22:25 +0100
1#!/usr/bin/python3
3# Copyright (c) 2021-2022, Institut National de Recherche en Informatique et en Automatique (Inria)
4# All rights reserved.
5#
6# Redistribution and use in source and binary forms, with or without
7# modification, are permitted provided that the following conditions are met:
8#
9# * Redistributions of source code must retain the above copyright notice, this
10# list of conditions and the following disclaimer.
11#
12# * Redistributions in binary form must reproduce the above copyright notice,
13# this list of conditions and the following disclaimer in the documentation
14# and/or other materials provided with the distribution.
15#
16# * Neither the name of the copyright holder nor the names of its
17# contributors may be used to endorse or promote products derived from
18# this software without specific prior written permission.
19#
20# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31# pylint: skip-file
34"""This script contains utility methods for networking with respect to connection
35between the launcher-server. Some of the methods are not being used."""
# Public names exported by a star-import of this module.
__all__ = [
    "ConnectionId",
    "LengthPrefixFramingEncoder",
    "LengthPrefixFramingDecoder",
    "Socket",
    "connect_to_launcher",
    "connect_to_server",
    "get_available_protocols",
    "make_passive_socket",
    "pipe",
    "protocol2str",
    "socketpair",
    "str2protocol",
    "select_protocol",
    "is_port_in_use",
    "UnsupportedProtocol",
]
45import os
46import errno
47import logging
48import socket
49from typing import List, Optional, Tuple
50from socket import socket as Socket
51from mpi4py import MPI
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Type alias: connection identifiers are plain integers.
ConnectionId = int
class UnsupportedProtocol(Exception):
    """Error signalling that a requested protocol is not supported.

    The message handed to the constructor is kept in ``msg`` and is embedded
    in the string form of the exception.
    """

    def __init__(self, msg) -> None:
        super().__init__(msg)
        self.msg = msg

    def __str__(self) -> str:
        """Return the message prefixed with ``Unsupported Protocol: ``."""
        return f"Unsupported Protocol: {self.msg}"
class LengthPrefixFramingDecoder:
    """Incremental decoder for length-prefix framed byte streams.

    Each message is preceded by its length, stored as an unsigned
    little-endian integer occupying ``prefix_length`` bytes. Bytes that do
    not yet form a complete message are kept between calls to
    :meth:`execute`.
    """

    def __init__(self, prefix_length: int) -> None:
        """Create a decoder.

        Args:
            prefix_length: Size of the length prefix in bytes (1 to 8).

        Raises:
            AssertionError: If ``prefix_length`` is larger than 8 (only
                while assertions are enabled).
            ValueError: If ``prefix_length`` is zero or negative.
        """
        assert prefix_length <= 8
        if prefix_length <= 0:
            raise ValueError(
                "prefix length must be positive, got {:d}".
                format(prefix_length)
            )
        self._prefix_length = prefix_length
        # Received bytes that do not yet form a complete message.
        self._buffer = b''

    def execute(self, bs: bytes) -> List[bytes]:
        """Feed newly received bytes and return all completed messages.

        Args:
            bs: Bytes to append to the data buffered so far.

        Returns:
            The payloads (without their prefixes) of every message that is
            now complete, in stream order. The list is empty when no
            message could be completed.
        """
        data = self._buffer + bs
        prefix_length = self._prefix_length
        available = len(data)

        messages: List[bytes] = []
        # Walk through the data with an offset instead of re-slicing the
        # remainder after every message; re-slicing copies the whole tail
        # each time and is quadratic when many small messages arrive at once.
        offset = 0
        while available - offset >= prefix_length:
            payload_start = offset + prefix_length
            num_message_bytes = int.from_bytes(
                data[offset:payload_start], byteorder="little"
            )
            payload_end = payload_start + num_message_bytes

            # The payload has not been fully received yet: keep the prefix
            # and the partial payload for the next call.
            if payload_end > available:
                break

            messages.append(data[payload_start:payload_end])
            offset = payload_end

        self._buffer = data[offset:]
        return messages
class LengthPrefixFramingEncoder:
    """Encoder producing length-prefix framed messages.

    A message is emitted as its length, stored as an unsigned little-endian
    integer occupying ``prefix_length`` bytes, followed by the payload.
    """

    def __init__(self, prefix_length: int) -> None:
        """Create an encoder.

        Args:
            prefix_length: Size of the length prefix in bytes (1 to 8).

        Raises:
            AssertionError: If ``prefix_length`` is larger than 8 (only
                while assertions are enabled).
            ValueError: If ``prefix_length`` is zero or negative.
        """
        assert prefix_length <= 8
        if prefix_length <= 0:
            raise ValueError(
                "prefix length must be positive, got {:d}".
                format(prefix_length)
            )
        self._prefix_length = prefix_length

    def execute(self, bs: bytes) -> bytes:
        """Return ``bs`` preceded by its length prefix.

        Args:
            bs: Payload to frame.

        Returns:
            The length prefix followed by the unchanged payload.

        Raises:
            RuntimeError: If the payload length cannot be represented in
                ``prefix_length`` bytes.
        """
        n = len(bs)
        bits_per_byte = 8
        # Smallest length that no longer fits into the prefix.
        max_message_length = 1 << (self._prefix_length * bits_per_byte)
        if n >= max_message_length:
            # The first literal ends with a space so that the concatenated
            # message reads "length prefix framing".
            raise RuntimeError(
                "message of length {:d} too large for length "
                "prefix framing with a prefix of length {:d}"
                .format(n, self._prefix_length)
            )

        prefix = n.to_bytes(self._prefix_length, byteorder="little")
        return prefix + bs
def _getenv(key: str) -> str:
    """Return the value of environment variable ``key``.

    Unlike ``os.getenv()``, a missing variable is reported with a
    ``RuntimeError`` instead of returning ``None``.
    """
    assert key

    value = os.environ.get(key)
    if value is not None:
        return value
    raise RuntimeError("environment variable %s not set" % key)
def _fix_hostname(hostname: str) -> str:
    """
    Map the local machine's own hostname to "localhost".

    Any other hostname is returned unchanged.

    Since 2005, Debian-based distributions make the hostname an alias to
    127.0.1.1 (take a look at `/etc/hosts`) and in some cases, this leads to
    a hanging connect(2) when server and client are on the same host (is
    this a bug?). The problem may occur with SCTP and TCP.

    Related:
    * https://bugs.debian.org/cgi-bin/bugreport.cgi?bug=316099
    """
    if hostname == socket.gethostname():
        return "localhost"
    return hostname
def is_port_in_use(port: int, host: str = '127.0.0.1') -> bool:
    """Report whether a TCP connection to ``host:port`` can be established.

    Returns:
        True when ``connect_ex`` reports success (return code 0), False
        otherwise.
    """
    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    with probe:
        status = probe.connect_ex((host, port))
    return status == 0
def connect_to_launcher() -> socket.socket:
    """Open a stream socket connection to the launcher process.

    The target is read from the environment variables
    ``MELISSA_LAUNCHER_HOST``, ``MELISSA_LAUNCHER_PORT`` and
    ``MELISSA_LAUNCHER_PROTOCOL`` (``"SCTP"`` or ``"TCP"``).

    Returns:
        The connected socket.

    Raises:
        RuntimeError: If one of the environment variables is not set.
        ValueError: If the port is not an integer or the protocol name is
            unknown.
        Exception: If the connection attempt fails; the original error is
            attached as the cause.
    """
    launcher_host_raw = _getenv("MELISSA_LAUNCHER_HOST")
    launcher_host = _fix_hostname(launcher_host_raw)
    launcher_port_str = _getenv("MELISSA_LAUNCHER_PORT")
    launcher_port = int(launcher_port_str)
    protocol_str = _getenv("MELISSA_LAUNCHER_PROTOCOL")
    protocol = str2protocol(protocol_str)

    logger.debug(
        f"connecting to launcher host={launcher_host} port={launcher_port} protocol={protocol_str}"
    )

    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, protocol)
    try:
        s.connect((launcher_host, launcher_port))
    except Exception as e:
        # Do not leak the descriptor of a socket that never connected.
        s.close()
        # Chain explicitly so the traceback shows the original failure as
        # the direct cause of this error.
        raise Exception(f'Exception caught during connection {e}') from e
    return s
def connect_to_server() -> Socket:
    """Open a connection to the server process.

    Host and port are read from the environment variables
    ``MELISSA_SERVER_HOST`` and ``MELISSA_SERVER_PORT``.

    Returns:
        The socket returned by ``socket.create_connection``.
    """
    server_host = _fix_hostname(_getenv("MELISSA_SERVER_HOST"))
    server_port = int(_getenv("MELISSA_SERVER_PORT"))

    logger.debug(
        f"connecting to server host={server_host} port={server_port}"
    )

    address = (server_host, server_port)
    return socket.create_connection(address)
def get_available_protocols() -> List[int]:
    """Return the protocols for which an IPv4 stream socket can be created.

    SCTP and TCP are probed, in that order, by creating and immediately
    closing a socket.

    Returns:
        The protocol numbers (``socket.IPPROTO_*``) that could be used.

    Raises:
        OSError: If probing fails for a reason other than
            ``EPROTONOSUPPORT``.
    """
    candidates = (socket.IPPROTO_SCTP, socket.IPPROTO_TCP)

    def can_create(proto: int) -> bool:
        """Tell whether a socket using ``proto`` can be opened."""
        try:
            probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM, proto)
            probe.close()
        except OSError as err:
            if err.errno == errno.EPROTONOSUPPORT:
                return False
            raise
        return True

    return list(filter(can_create, candidates))
def make_passive_socket(
    node: Optional[str] = None,
    *,
    protocol: int,
    backlog: Optional[int] = None
) -> Socket:
    """Create an IPv4 stream socket that listens for incoming connections.

    The socket is bound to port 0, i.e. the port number is chosen by the
    operating system; the resulting address is logged.

    Args:
        node: Address to bind to. ``None`` binds to the empty host string.
        protocol: Protocol number passed to ``socket.socket`` (keyword-only).
        backlog: Backlog passed to ``listen``. ``None`` calls ``listen()``
            without an argument.

    Returns:
        The listening socket.

    Raises:
        Exception: If binding or listening fails; the original error is
            attached as the cause.
    """
    # The IP compatibility layer on Infiniband is usually bound to an IPv4
    # address.
    family = socket.AF_INET
    # Named ``sock_type`` so the builtin ``type`` is not shadowed.
    sock_type = socket.SOCK_STREAM
    sockfd = socket.socket(family, sock_type, protocol)
    try:
        sockfd.bind(("" if node is None else node, 0))
        if backlog is None:
            sockfd.listen()
        else:
            sockfd.listen(backlog)

        listen_addr = sockfd.getsockname()
        logger.info(f"listening for connections on {listen_addr[0]}:{listen_addr[1]}")
    except Exception as e:
        # Do not leak the descriptor of a socket that never became usable.
        sockfd.close()
        # Chain explicitly so the traceback shows the original failure as
        # the direct cause of this error.
        raise Exception(f'Exception caught listening for connection {e}') from e

    return sockfd
# Socket objects are used everywhere so that callers never have to care
# about the distinction between Socket objects and plain integer file
# descriptors.
def pipe() -> Tuple[Socket, Socket]:
    """Return a one-directional ``(read_end, write_end)`` socket pair.

    The pair comes from ``socketpair()``; the read end has its write
    direction shut down and the write end has its read direction shut down.
    """
    read_end, write_end = socketpair()
    read_end.shutdown(socket.SHUT_WR)
    write_end.shutdown(socket.SHUT_RD)
    return read_end, write_end
def select_protocol() -> Tuple[int, str]:
    """Pick the protocol named by ``MELISSA_LAUNCHER_PROTOCOL``.

    Returns:
        A ``(protocol number, protocol name)`` pair for ``"SCTP"`` or
        ``"TCP"``.

    Raises:
        KeyError: If the environment variable is not set.
        UnsupportedProtocol: If the variable holds any other value.
    """
    try:
        protocol = os.environ["MELISSA_LAUNCHER_PROTOCOL"]
    except KeyError as e:
        raise KeyError("Undefined protocol") from e

    if protocol not in ("SCTP", "TCP"):
        raise UnsupportedProtocol(f"{protocol} not supported for server-launcher communication.")
    if protocol == "SCTP":
        return socket.IPPROTO_SCTP, "SCTP"
    return socket.IPPROTO_TCP, "TCP"
def protocol2str(proto: int) -> str:
    """Translate a protocol number into its name (``"SCTP"`` or ``"TCP"``).

    Raises:
        ValueError: If ``proto`` is neither ``socket.IPPROTO_SCTP`` nor
            ``socket.IPPROTO_TCP``.
    """
    known = (
        (socket.IPPROTO_SCTP, "SCTP"),
        (socket.IPPROTO_TCP, "TCP"),
    )
    for number, label in known:
        if proto == number:
            return label
    raise ValueError("unknown protocol {:d}".format(proto))
def socketpair() -> Tuple[Socket, Socket]:
    """Return two connected ``AF_UNIX`` stream sockets."""
    first, second = socket.socketpair(socket.AF_UNIX, socket.SOCK_STREAM)
    return first, second
def str2protocol(name: str) -> int:
    """Translate a protocol name (``"SCTP"`` or ``"TCP"``) into its number.

    Raises:
        ValueError: If ``name`` is neither ``"SCTP"`` nor ``"TCP"``.
    """
    if name == "TCP":
        return socket.IPPROTO_TCP
    if name == "SCTP":
        return socket.IPPROTO_SCTP
    raise ValueError("unknown protocol {:s}".format(name))
def get_rank_and_num_server_proc() -> Tuple[int, int]:
    """Return ``(rank, size)`` of the calling process in ``MPI.COMM_WORLD``."""
    world = MPI.COMM_WORLD
    return world.Get_rank(), world.Get_size()