Coverage for melissa/utility/networking.py: 87%
125 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-09-22 10:36 +0000
« prev ^ index » next coverage.py v7.2.7, created at 2023-09-22 10:36 +0000
1#!/usr/bin/python3
3# Copyright (c) 2021-2022, Institut National de Recherche en Informatique et en Automatique (Inria)
4# All rights reserved.
5#
6# Redistribution and use in source and binary forms, with or without
7# modification, are permitted provided that the following conditions are met:
8#
9# * Redistributions of source code must retain the above copyright notice, this
10# list of conditions and the following disclaimer.
11#
12# * Redistributions in binary form must reproduce the above copyright notice,
13# this list of conditions and the following disclaimer in the documentation
14# and/or other materials provided with the distribution.
15#
16# * Neither the name of the copyright holder nor the names of its
17# contributors may be used to endorse or promote products derived from
18# this software without specific prior written permission.
19#
20# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
# Public API of this module: the names exposed by a star-import.
# Every public callable defined below is listed, including
# get_rank_and_num_server_proc.
__all__ = [
    "ConnectionId", "LengthPrefixFramingEncoder", "LengthPrefixFramingDecoder",
    "Socket", "connect_to_launcher", "connect_to_server",
    "get_available_protocols", "get_rank_and_num_server_proc",
    "make_passive_socket", "pipe", "protocol2str", "socketpair",
    "str2protocol"
]
38import os
39import errno
40import logging
41import socket
42from socket import socket as Socket
43from typing import List, Optional, Tuple
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Connection identifiers are plain integers; this alias only names the intent.
ConnectionId = int
class LengthPrefixFramingDecoder:
    """
    Incremental decoder for length-prefix framed byte streams.

    Every message on the wire is preceded by its length, encoded as a
    little-endian unsigned integer occupying `prefix_length` bytes. Bytes
    that do not yet form a complete message are kept between calls to
    `execute()`.
    """

    def __init__(self, prefix_length: int) -> None:
        """
        :param prefix_length: size of the length prefix in bytes; must be
            between 1 and 8 (the upper bound is enforced by an assertion).
        :raises ValueError: if `prefix_length` is zero or negative.
        """
        assert prefix_length <= 8
        if prefix_length <= 0:
            raise ValueError(
                "prefix length must be positive, got {:d}".
                format(prefix_length)
            )
        self._prefix_length = prefix_length
        # bytes received so far that do not yet form a complete message
        self._buffer = b''

    def execute(self, bs: bytes) -> List[bytes]:
        """
        Feed newly received bytes and return all messages completed by them.

        :param bs: the next chunk of the byte stream (may be empty).
        :return: the payloads (without prefix) of every complete message, in
            stream order; an empty list if no message is complete yet.
        """
        data = self._buffer + bs
        prefix_length = self._prefix_length
        available = len(data)

        messages = []  # type: List[bytes]
        # Walk the buffer with an offset instead of re-slicing the remainder
        # after every message; the unconsumed tail is copied only once.
        offset = 0
        while available - offset >= prefix_length:
            body_start = offset + prefix_length
            num_message_bytes = int.from_bytes(
                data[offset:body_start], byteorder="little"
            )
            body_end = body_start + num_message_bytes

            if body_end > available:
                # the message body has not been received completely
                break

            messages.append(data[body_start:body_end])
            offset = body_end

        self._buffer = data[offset:]
        return messages
class LengthPrefixFramingEncoder:
    """
    Encoder for length-prefix framing.

    A message is framed by prepending its length as a little-endian unsigned
    integer occupying `prefix_length` bytes.
    """

    def __init__(self, prefix_length: int) -> None:
        """
        :param prefix_length: size of the length prefix in bytes; must be
            between 1 and 8 (the upper bound is enforced by an assertion).
        :raises ValueError: if `prefix_length` is zero or negative.
        """
        assert prefix_length <= 8
        if prefix_length <= 0:
            raise ValueError(
                "prefix length must be positive, got {:d}".
                format(prefix_length)
            )
        self._prefix_length = prefix_length

    def execute(self, bs: bytes) -> bytes:
        """
        Return `bs` preceded by its length prefix.

        :param bs: the message payload.
        :return: prefix followed by the unchanged payload.
        :raises RuntimeError: if the payload length cannot be represented in
            `prefix_length` bytes.
        """
        n = len(bs)
        bits_per_byte = 8
        # smallest length that no longer fits into the prefix
        max_message_length = 1 << (self._prefix_length * bits_per_byte)
        if n >= max_message_length:
            # The two literals are concatenated, so the first one must end
            # with a space to keep the words apart.
            raise RuntimeError(
                "message of length {:d} too large for length "
                "prefix framing with a prefix of length {:d}"
                .format(n, self._prefix_length)
            )

        prefix = n.to_bytes(self._prefix_length, byteorder="little")
        return prefix + bs
def _getenv(key: str) -> str:
    """
    Look up a mandatory environment variable.

    Behaves like os.getenv() except that a missing variable is an error
    (RuntimeError) instead of yielding None.
    """
    assert key

    found = os.environ.get(key)
    if found is not None:
        return found
    raise RuntimeError("environment variable %s not set" % key)
def _fix_hostname(hostname: str) -> str:
    """
    Map the name of the local machine to "localhost".

    If `hostname` equals the name of the machine running this code,
    "localhost" is returned; any other hostname is returned unchanged.

    Since 2005, Debian-based distributions make the hostname an alias to
    127.0.1.1 (take a look at `/etc/hosts`) and in some cases, this leads to
    a hanging connect(2) when server and client are on the same host (is this
    a bug?). The problem may occur with SCTP and TCP.

    Related:
    * https://bugs.debian.org/cgi-bin/bugreport.cgi?bug=316099
    """
    if hostname == socket.gethostname():
        return "localhost"
    return hostname
def connect_to_launcher() -> socket.socket:
    """
    Open a stream connection to the launcher described by the environment.

    The target is read from the environment variables
    MELISSA_LAUNCHER_HOST, MELISSA_LAUNCHER_PORT and
    MELISSA_LAUNCHER_PROTOCOL ("SCTP" or "TCP").

    :return: the connected IPv4 stream socket.
    :raises RuntimeError: if one of the environment variables is not set.
    :raises ValueError: if the port is not an integer or the protocol name
        is unknown.
    :raises Exception: if connecting fails; the socket is closed and the
        original error is attached as the cause.
    """
    launcher_host_raw = _getenv("MELISSA_LAUNCHER_HOST")
    launcher_host = _fix_hostname(launcher_host_raw)
    launcher_port_str = _getenv("MELISSA_LAUNCHER_PORT")
    launcher_port = int(launcher_port_str)
    protocol_str = _getenv("MELISSA_LAUNCHER_PROTOCOL")
    protocol = str2protocol(protocol_str)

    logger.debug(
        f"connecting to launcher host={launcher_host} port={launcher_port} protocol={protocol_str}"
    )

    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, protocol)
    try:
        s.connect((launcher_host, launcher_port))
    except Exception as e:
        # do not leak the file descriptor of the unconnected socket
        s.close()
        # chain explicitly so the traceback reports the connection error as
        # the direct cause instead of an error raised while handling it
        raise Exception(f'Exception caught during connection {e}') from e
    return s
def connect_to_server() -> Socket:
    """
    Connect to the server named by the environment.

    Host and port are taken from MELISSA_SERVER_HOST and
    MELISSA_SERVER_PORT; the connected socket is returned.
    """
    server_host = _fix_hostname(_getenv("MELISSA_SERVER_HOST"))
    server_port = int(_getenv("MELISSA_SERVER_PORT"))

    logger.debug(
        f"connecting to server host={server_host} port={server_port}"
    )

    server_address = (server_host, server_port)
    return socket.create_connection(server_address)
def get_available_protocols() -> List[int]:
    """
    Return the transport protocols usable for IPv4 stream sockets.

    SCTP and TCP are probed, in this order, by trying to create a socket
    with the respective protocol. A protocol is dropped from the result if
    the attempt fails with EPROTONOSUPPORT; any other OSError is propagated.
    """
    address_family = socket.AF_INET
    socket_kind = socket.SOCK_STREAM

    def can_open(proto: int) -> bool:
        try:
            socket.socket(address_family, socket_kind, proto).close()
        except OSError as err:
            if err.errno == errno.EPROTONOSUPPORT:
                return False
            raise
        return True

    candidates = (socket.IPPROTO_SCTP, socket.IPPROTO_TCP)
    return list(filter(can_open, candidates))
def make_passive_socket(
    node: Optional[str] = None,
    *,
    protocol: int,
    backlog: Optional[int] = None
) -> Socket:
    """
    Create a listening IPv4 stream socket.

    The socket is bound to port 0 on `node`, so the port actually used has
    to be queried from the returned socket with getsockname().

    :param node: address to bind to; None binds to the empty host string.
    :param protocol: protocol number passed to socket.socket().
    :param backlog: backlog passed to listen(); None uses the default of
        listen().
    :return: the bound, listening socket.
    :raises Exception: if binding or listening fails; the socket is closed
        and the original error is attached as the cause.
    """
    # The IP compatibility layer on Infiniband is usually bound to an IPv4
    # address.
    sockfd = socket.socket(socket.AF_INET, socket.SOCK_STREAM, protocol)
    try:
        sockfd.bind(("" if node is None else node, 0))
        if backlog is None:
            sockfd.listen()
        else:
            sockfd.listen(backlog)

        listen_addr = sockfd.getsockname()
        logger.info(f"listening for connections on {listen_addr[0]}:{listen_addr[1]}")
    except Exception as e:
        # do not leak the file descriptor of the half-initialized socket
        sockfd.close()
        # chain explicitly so the traceback reports the original error as the
        # direct cause instead of an error raised while handling it
        raise Exception(f'Exception caught listening for connection {e}') from e

    return sockfd
# Socket objects are used everywhere so that nobody has to care about the
# difference between Socket objects and plain integer file descriptors.
def pipe() -> Tuple[Socket, Socket]:
    """
    Create a unidirectional channel out of a socket pair.

    Returns (read end, write end): the first socket has its sending side
    shut down, the second one its receiving side.
    """
    reader, writer = socketpair()
    reader.shutdown(socket.SHUT_WR)
    writer.shutdown(socket.SHUT_RD)
    return reader, writer
def protocol2str(proto: int) -> str:
    """
    Translate a protocol number into its name ("SCTP" or "TCP").

    Raises ValueError for any other protocol number.
    """
    known_protocols = (
        (socket.IPPROTO_SCTP, "SCTP"),
        (socket.IPPROTO_TCP, "TCP"),
    )
    for number, label in known_protocols:
        if proto == number:
            return label
    raise ValueError("unknown protocol {:d}".format(proto))
def socketpair() -> Tuple[Socket, Socket]:
    """Return a pair of connected UNIX-domain stream sockets."""
    first, second = socket.socketpair(socket.AF_UNIX, socket.SOCK_STREAM)
    return first, second
def str2protocol(name: str) -> int:
    """
    Translate a protocol name ("SCTP" or "TCP") into its protocol number.

    The comparison is case-sensitive; any other name raises ValueError.
    """
    known_protocols = (
        ("SCTP", socket.IPPROTO_SCTP),
        ("TCP", socket.IPPROTO_TCP),
    )
    for label, number in known_protocols:
        if name == label:
            return number
    raise ValueError("unknown protocol {:s}".format(name))
def get_rank_and_num_server_proc() -> Tuple[int, int]:
    """
    Return the rank of this process and the number of server processes.

    The values are read from the SLURM environment variables SLURM_PROCID
    and SLURM_NTASKS. If either variable is missing, rank and size of
    MPI.COMM_WORLD (mpi4py) are used instead.
    """
    slurm_variables = ("SLURM_PROCID", "SLURM_NTASKS")
    try:
        rank, nb_proc_server = (int(os.environ[var]) for var in slurm_variables)
    except KeyError:
        # not started by SLURM: ask MPI (imported only when needed)
        from mpi4py import MPI
        world = MPI.COMM_WORLD
        rank = world.Get_rank()
        nb_proc_server = world.Get_size()

    return rank, nb_proc_server