Coverage for melissa/launcher/io.py: 54%
362 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-01-23 15:52 +0100
1#!/usr/bin/python3
3# Copyright (c) 2021-2022, Institut National de Recherche en Informatique et en Automatique (Inria)
4# All rights reserved.
5#
6# Redistribution and use in source and binary forms, with or without
7# modification, are permitted provided that the following conditions are met:
8#
9# * Redistributions of source code must retain the above copyright notice, this
10# list of conditions and the following disclaimer.
11#
12# * Redistributions in binary form must reproduce the above copyright notice,
13# this list of conditions and the following disclaimer in the documentation
14# and/or other materials provided with the distribution.
15#
16# * Neither the name of the copyright holder nor the names of its
17# contributors may be used to endorse or promote products derived from
18# this software without specific prior written permission.
19#
20# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31import errno
32import logging
33import os
34import select
35import socket
36import sys
37import threading
38import subprocess
39from typing import cast, Dict, Generic, List, Optional, Tuple, TypeVar, Union
40from pathlib import Path
42from melissa.utility import networking, process
43from melissa.utility.networking import Socket, PEER_SANITY_MSG
44from melissa.scheduler.scheduler import Scheduler, IndirectScheduler, HybridScheduler
45from melissa.scheduler.scheduler import Options as SchedulerOptions
46from melissa.launcher.state_machine import Phase
48from . import action, event, message
49from .processor import Processor, DefaultProcessor
50from . import queue
51from melissa.launcher import config
# Module-level logger, named after this module per the logging convention.
logger = logging.getLogger(__name__)

# Scheduler-specific job handle type managed by IoMaster (e.g. a batch job id).
JobT = TypeVar("JobT")
class IoMaster(Generic[JobT]):
    """Central poll()-based I/O loop of the launcher.

    Multiplexes the listening socket, the signal/timer/web/event file
    descriptors and the connected server peers, feeds the resulting events
    into the state-machine processor, and executes the actions the
    processor returns (job submission, cancellation, update, message
    sending).
    """

    # the event file descriptors are reading and writing ends of a pipe or a
    # UNIX domain socket; they are passed explicitly for testing purposes.
    #
    # the event FIFO should have unlimited size because the thread reading from
    # the queue will also be putting elements into the queue.
    def __init__(
        self, listenfd: Socket, signalfd: Socket, timerfd: Socket,
        webfd: Socket, eventfd_r: Socket, events: queue.Queue,
        scheduler: Scheduler[JobT], client_options: SchedulerOptions,
        server_options: SchedulerOptions, processor: Union[Processor, DefaultProcessor],
        protocol: int, job_limit: int = 1000, std_output: bool = True
    ) -> None:
        """Store file descriptors, scheduler handles and bookkeeping state.

        :param listenfd: socket accepting new peer connections
        :param signalfd: read end delivering caught signal numbers
        :param timerfd: read end delivering timer ticks
        :param webfd: socket serving HTTP requests
        :param eventfd_r: read end notified when ``events`` holds items
        :param events: FIFO carrying events produced by worker threads
        :param scheduler: scheduler used to launch and manage jobs
        :param client_options: scheduler options for client jobs
        :param server_options: scheduler options for the server job
        :param processor: state machine turning events into actions
        :param protocol: peer transport, ``IPPROTO_TCP`` or ``IPPROTO_SCTP``
        :param job_limit: cap on simultaneously tracked + in-flight jobs
        :param std_output: when False, captured output files are removed
            after being parsed (see ``_handle_process_completion``)
        """
        self._listenfd = listenfd
        self._protocol = protocol
        self._signalfd = signalfd
        self._timerfd = timerfd
        self._webfd = webfd
        self._eventfd_r = eventfd_r
        self._events = events
        self._scheduler = scheduler
        self._client_options = client_options
        self._server_options = server_options
        # monotonically increasing unique id handed out by _next_uid()
        self._uid = 0
        # connected peer sockets, registered with the poller
        self._peers = []  # type: List[Socket]
        # per-descriptor framing decoders (TCP only; SCTP keeps boundaries)
        self._decoders: Dict[int, networking.LengthPrefixFramingDecoder] = {
        }
        self._processor = processor
        self._poller = select.poll()
        # worker threads waiting on scheduler subprocesses
        self._minions = []  # type: List[threading.Thread]
        # scheduler subprocesses that have not been reaped yet
        self._child_processes = []  # type: List[process.Process[str]]
        self.job_limit = job_limit
        self.std_output: bool = std_output

        # this counter is used to keep track of the submitted jobs
        # that have not been added to the jobs list yet
        self.job_count: int = 0

        # a poller index is introduced to make sure the server socket (peerfd)
        # has a chance to be checked after each job postponement
        self.poller_idx: int = 0

        # a postponed_job list is added to avoid overloading the event queue
        # it will contain all postponed JobSubmission messages
        self.postponed_job_list: List[event.MessageReception] = []

        # this value is used to open sockets for server ranks > 0
        self.server_comm_size: int = 1
107 def _close_connection(self, fd: int) -> None:
108 if fd in self._decoders:
109 del self._decoders[fd]
110 sockfd = [p for p in self._peers if p.fileno() == fd][0]
111 self._poller.unregister(sockfd)
112 self._peers.remove(sockfd)
113 sockfd.close()
115 def _next_uid(self) -> int:
116 uid = self._uid
117 self._uid += 1
118 return uid
    def _handle_file_descriptor_readable(self, fd: int) -> List[action.Action]:
        """Dispatch a readable file descriptor to its handler.

        Depending on which descriptor became readable this accepts a new
        peer connection, reads a signal number, a timer tick, an internal
        event, an HTTP request, or decodes message bytes from a connected
        peer.  Returns the actions emitted by the processor for the
        resulting event.
        """
        if fd == self._listenfd.fileno():
            # new incoming peer connection
            peerfd, peer_addr = self._listenfd.accept()
            logger.info(f"accepted connection peerfd {peerfd} peeraddr {peer_addr}")

            self._peers.append(peerfd)
            self._poller.register(peerfd, select.POLLIN)

            if self._protocol == socket.IPPROTO_SCTP:
                try:
                    # TODO query the system for the value
                    SCTP_NODELAY = 3
                    # send pings immediately
                    peerfd.setsockopt(socket.IPPROTO_SCTP, SCTP_NODELAY, 1)
                    peerfd.sendall(PEER_SANITY_MSG)
                    logger.debug(f"sent {PEER_SANITY_MSG.decode('utf-8')} "
                                 f"to the peer={peerfd.fileno()}")
                except Exception as e:
                    logger.warning(
                        f"failed to enable SCTP_NODELAY on socket {peerfd}, {e}"
                    )
                    peerfd.close()
                    raise
            elif self._protocol == socket.IPPROTO_TCP:
                # TCP is a byte stream: a length-prefix framing decoder is
                # needed to recover message boundaries
                self._decoders[peerfd.fileno()
                               ] = networking.LengthPrefixFramingDecoder(
                    config.TCP_MESSAGE_PREFIX_LENGTH
                )
                peerfd.sendall(
                    networking.LengthPrefixFramingEncoder(config.TCP_MESSAGE_PREFIX_LENGTH)
                    .execute(PEER_SANITY_MSG)
                )
                logger.debug(f"sent {PEER_SANITY_MSG.decode('utf-8')} "
                             f"to the peer={peerfd.fileno()}")
            else:
                raise NotImplementedError(
                    f"no support for protocol {self._protocol}")

            cid = peerfd.fileno()
            return self._processor.execute(event.NewConnection(cid))

        if fd == self._signalfd.fileno():
            # one byte carrying the signal number
            b = self._signalfd.recv(1)
            if len(b) == 0:
                return []
            signo = int.from_bytes(b, byteorder=sys.byteorder)
            return self._processor.execute(event.Signal(signo))

        if fd == self._timerfd.fileno():
            # drain the tick byte and forward a timeout event
            b = self._timerfd.recv(1)
            if len(b) == 0:
                return []
            return self._processor.execute(event.Timeout())

        if fd == self._eventfd_r.fileno():
            # drain one notification byte, then take exactly one queued event
            self._eventfd_r.recv(1)
            num_items = self._events.qsize()
            if num_items >= 16:
                logger.warning(f"event FIFO contains {num_items} items")
            ev = self._events.get()

            if isinstance(ev, event.ProcessCompletion_):
                return self._handle_process_completion(ev)

            if isinstance(ev, event.JobPostponing):
                # poller index alternates not to miss server pings (last socket)
                # signals and timeout (first sockets)
                self.poller_idx = -1 if self.poller_idx == 0 else 0
                return self._handle_job_postponing(ev)

            if isinstance(ev, event.JobSubmission):
                # a JobSubmission event indicate that a MessageReception
                # transition occurred and one job was added to the jobs list
                # the counter must then be decremented
                if self.job_count > 0:
                    self.job_count -= 1

            return self._processor.execute(ev)

        if fd == self._webfd.fileno():
            return self._processor.execute(event.HttpRequest_())

        # otherwise the descriptor belongs to a connected peer socket
        raw_bytes = os.read(fd, 4096)

        if len(raw_bytes) == 0:
            # EOF: the peer closed the connection
            self._close_connection(fd)
            return self._processor.execute(event.ConnectionShutdown(fd))

        if self._protocol == socket.IPPROTO_SCTP:
            # SCTP preserves message boundaries: one read is one message
            dsz_msg = message.deserialize(raw_bytes)
            assert isinstance(self._processor, DefaultProcessor)
            if (
                isinstance(dsz_msg, message.JobSubmission)
                and len(self._processor._state.jobs) + self.job_count >= self.job_limit
            ):
                # job limit reached: postpone the submission instead of
                # forwarding it to the processor
                if not self.postponed_job_list:
                    self._events.put(event.JobPostponing())
                self.postponed_job_list.append(event.MessageReception(fd, dsz_msg))
                return [action.PostponeSubmission()]
            else:
                logger.debug("execute message reception")
                if isinstance(dsz_msg, message.JobSubmission):
                    self.job_count += 1
                elif isinstance(dsz_msg, message.CommSize):
                    logger.debug(f"received comm size {dsz_msg.comm_size}")
                    self.server_comm_size = dsz_msg.comm_size
                logger.debug("execute message reception")
                return self._processor.execute(
                    event.MessageReception(fd, dsz_msg)
                )

        if self._protocol == socket.IPPROTO_TCP:
            # feed the bytes to the framing decoder; it may yield zero, one
            # or several complete messages
            byte_blocks = self._decoders[fd].execute(raw_bytes)
            actions = []  # type: List[action.Action]
            for bs in byte_blocks:
                dsz_msg = message.deserialize(bs)
                assert isinstance(self._processor, DefaultProcessor)
                if (
                    isinstance(dsz_msg, message.JobSubmission)
                    and len(self._processor._state.jobs) + self.job_count >= self.job_limit
                ):
                    # job limit reached: postpone this submission
                    if not self.postponed_job_list:
                        self._events.put(event.JobPostponing())
                    self.postponed_job_list.append(event.MessageReception(fd, dsz_msg))
                    actions.extend([action.PostponeSubmission()])
                else:
                    if isinstance(dsz_msg, message.JobSubmission):
                        self.job_count += 1
                    elif isinstance(dsz_msg, message.CommSize):
                        logger.debug(f"received comm size {dsz_msg.comm_size}")
                        self.server_comm_size = dsz_msg.comm_size
                    logger.debug("execute message reception")
                    actions.extend(
                        self._processor.execute(event.MessageReception(fd, dsz_msg))
                    )

            return actions

        raise NotImplementedError("BUG protocol {:d}".format(self._protocol))
270 def _handle_process_completion(
271 self, pc: event.ProcessCompletion_
272 ) -> List[action.Action]:
273 assert (
274 isinstance(self._scheduler, IndirectScheduler)
275 or isinstance(self._scheduler, HybridScheduler)
276 )
278 stdout_filename, stderr_filename = self._make_io_redirect_filenames(
279 self._scheduler, pc.id
280 )
281 with open(stdout_filename, "r") as f:
282 stdout = f.read()
283 if stdout == "":
284 os.unlink(stdout_filename)
286 with open(stderr_filename, "r") as f:
287 stderr = f.read()
288 if stderr == "":
289 os.unlink(stderr_filename)
291 exit_status = pc.process.returncode
292 proc = process.CompletedProcess(exit_status, stdout, stderr)
294 try:
295 ev = event.Event() # mypy Python 3.5 work-around
296 if isinstance(pc.action, action.JobCancellation):
297 self._scheduler.parse_cancel_jobs(pc.action.jobs, proc)
298 ev = event.JobCancellation(pc.action.jobs)
299 logger.debug(f"job cancellation succeeded (UID {pc.id})")
301 elif isinstance(pc.action, action.JobSubmission):
302 job = self._scheduler.make_job(proc, pc.id) # type: ignore
303 ev = event.JobSubmission(pc.action, job)
304 logger.debug(
305 f"job submission succeeded (UID {pc.id} ID {job.id()})"
306 )
307 # a JobSubmission event indicate that a MessageReception
308 # transition occurred and one job was added to the jobs list
309 # the counter must then be decremented
310 if self.job_count > 0:
311 self.job_count -= 1
313 elif isinstance(pc.action, action.JobUpdate):
314 self._scheduler.parse_update_jobs(pc.action.jobs, proc)
315 ev = event.JobUpdate(pc.action.jobs)
316 logger.debug(f"job update succeeded (UID {pc.id})")
318 else:
319 raise NotImplementedError("BUG not implemented")
321 # remove parsed std out and error files
322 if not self.std_output:
323 args = ["rm", f"{stdout_filename}", f"{stderr_filename}"]
324 process.launch(args, {}, subprocess.DEVNULL, subprocess.DEVNULL) # type: ignore
326 except Exception as e:
327 ev = event.ActionFailure(pc.action, e)
328 logger.debug(f"scheduling action failed (UID {pc.id})")
330 return self._processor.execute(ev)
    def _handle_job_cancellation(
        self, jc: action.JobCancellation[JobT]
    ) -> None:
        """Cancel the jobs in *jc* according to the scheduler flavour.

        Hybrid: client jobs are killed in a background thread while the
        server job goes through the scheduler's cancel command.  Direct:
        cancellation runs in a thread and completion is reported at once.
        Indirect: the scheduler's cancel command is launched and a worker
        thread waits for it.
        """
        assert jc.jobs

        if self._scheduler.is_hybrid():
            # when cancel_jobs is called inside Thread, the return values are ignored
            # and only clients jobs are cancelled through their associated subprocess
            # NOTE(review): this thread is not tracked in self._minions — confirm
            # it is intentionally fire-and-forget
            t = threading.Thread(
                target=lambda: self._scheduler.cancel_client_jobs(jc.jobs)  # type: ignore
            )
            t.start()

            hysched = cast(HybridScheduler[JobT], self._scheduler)
            # this time the return values are kept and processed so that the server job
            # only is killed through the scheduler
            out = hysched.cancel_server_job(jc.jobs)
            if out:
                uid = self._next_uid()
                args, env = out
                proc, _, _ = self._launch_process(uid, args, env)
                self._run_process_asynchronously(jc, uid, proc)
                logger.debug(f"cancelling jobs uid={uid}")
            else:
                # no server job to cancel: report completion immediately
                self._events.put(event.JobCancellation(jc.jobs))
            return

        if self._scheduler.is_direct():
            # direct scheduling: cancellation blocks, so run it in a thread
            # and report completion right away
            t = threading.Thread(
                target=lambda: self._scheduler.cancel_jobs(jc.jobs)  # type: ignore
            )
            t.start()
            self._events.put(event.JobCancellation(jc.jobs))
            return

        # indirect scheduling: run the scheduler's cancel command and let a
        # worker thread wait for its completion
        sched = cast(IndirectScheduler[JobT], self._scheduler)
        uid = self._next_uid()
        args, env = sched.cancel_jobs(jc.jobs)
        proc, _, _ = self._launch_process(uid, args, env)
        self._run_process_asynchronously(jc, uid, proc)
        logger.debug(f"cancelling jobs uid={uid}")
374 def _handle_job_submission(self, sub: action.JobSubmission) -> None:
375 if isinstance(sub, action.ServerJobSubmission):
376 options = self._server_options
377 else:
378 options = self._client_options
379 uid = self._next_uid()
380 args, env = self._scheduler.submit_job(
381 sub.commands, sub.environment, options, sub.job_name, uid
382 )
383 proc, stdout_fname, stderr_fname = self._launch_process(uid, args, env)
385 if self._scheduler.is_direct():
386 job = self._scheduler.make_job( # type: ignore
387 proc,
388 uid,
389 stdout_fname=stdout_fname,
390 stderr_fname=stderr_fname
391 )
392 self._events.put(event.JobSubmission(sub, job))
393 logger.debug(f"job launched uid={uid} id={job.id()}")
394 return
396 self._run_process_asynchronously(sub, uid, proc)
397 logger.debug(f"submitting job uid {uid}")
399 def _handle_job_update(self, ju: action.JobUpdate[JobT]) -> None:
400 assert ju.jobs
401 jobs = ju.jobs
403 if self._scheduler.is_direct():
404 self._scheduler.update_jobs(jobs) # type: ignore
405 self._events.put(event.JobUpdate(jobs))
406 return
408 sched = cast(IndirectScheduler[JobT], self._scheduler)
409 uid = self._next_uid()
410 args, env = sched.update_jobs(jobs)
411 proc, _, _ = self._launch_process(uid, args, env)
412 self._run_process_asynchronously(ju, uid, proc)
413 logger.debug(f"updating jobs uid {uid}")
415 def _handle_message_sending(self, msg: action.MessageSending) -> None:
416 peers = [p for p in self._peers if p.fileno() in msg.cids]
417 assert len(peers) <= self.server_comm_size
418 logger.debug(f"message sending on peers {peers} with msg.cids {msg.cids}")
419 for peer in peers:
420 try:
421 if self._protocol == socket.IPPROTO_TCP:
422 bs = networking.LengthPrefixFramingEncoder(
423 config.TCP_MESSAGE_PREFIX_LENGTH
424 ).execute(msg.message.serialize())
425 ret = peer.send(bs, socket.MSG_DONTWAIT)
426 assert ret == len(bs)
427 else:
428 bs = msg.message.serialize()
429 ret = peer.send(bs, socket.MSG_DONTWAIT)
430 assert ret == len(bs)
431 except OSError as e:
432 assert e.errno != errno.EMSGSIZE
433 self._events.put(event.ActionFailure(msg, e))
435 def _handle_job_postponing(
436 self, ev: event.JobPostponing
437 ) -> List[Union[action.Action, action.PostponeSubmission]]:
438 assert isinstance(self._processor, DefaultProcessor)
439 # make sure no additional client will be submitted if the server is dead
440 if self._processor._state.phase == Phase.SERVER_DEAD:
441 logger.warning("Server is dead, postponed job won't be submitted")
442 self.postponed_job_list = []
443 return [action.PostponeSubmission()]
444 # postpone again if the job-limit is still reached
445 elif len(self._processor._state.jobs) + self.job_count >= self.job_limit:
446 # logger.debug(
447 # f"job-limit reached: {len(self._processor._state.jobs)} + {self.job_count}, "
448 # "job submission postponed again"
449 # )
450 self._events.put(ev)
451 return [action.PostponeSubmission()]
452 # submit job
453 else:
454 logger.debug(
455 f"job-limit not reached: {len(self._processor._state.jobs)} + {self.job_count}, "
456 "job will be submitted"
457 )
458 # as long as the JobSubmission event resulting from the transition below is not
459 # processed the jobs list won't be updated so the counter must be incremented
460 self.job_count += 1
461 job_submission_message = self.postponed_job_list.pop(0)
462 if len(self.postponed_job_list) > 0:
463 self._events.put(ev)
464 return self._processor.execute(
465 job_submission_message
466 )
    def run(self) -> int:
        """Run the I/O event loop until the state machine requests an exit.

        Registers the control descriptors with the poller, then repeatedly
        polls, translates one readable descriptor into events/actions and
        executes the resulting actions.  Returns the exit status carried
        by the ``action.Exit`` that terminated the loop.
        """
        for sock in [
            self._listenfd, self._signalfd, self._timerfd, self._webfd,
            self._eventfd_r
        ]:
            self._poller.register(sock, select.POLLIN)

        exit_status: Optional[int] = None
        while exit_status is None:
            listfd = self._poller.poll()

            assert listfd != []

            def is_set(x: int, flag: int) -> bool:
                assert x >= 0
                assert flag >= 0
                return x & flag == flag

            # process only one event at a time. this is relevant, e.g., when
            # the state machine wants to stop.
            # poller_idx is either 0 (first ready fd) or -1 (last ready fd),
            # alternated by _handle_file_descriptor_readable on postponement
            fd, revent = listfd[self.poller_idx]
            if is_set(revent, select.POLLIN):
                try:
                    actions = self._handle_file_descriptor_readable(fd)
                except Exception as e:
                    logger.warning(f"server job already inactive {e}")
                    continue
            elif is_set(revent, select.POLLOUT):
                # sockets are never registered for POLLOUT
                assert False
                continue
            elif is_set(revent, select.POLLERR):
                logger.warning(f"file descriptor {fd} read end closed")
                self._close_connection(fd)
                # TODO: Is this okay ? Should we remove this ?
                # There will be assertion errors when server ranks > 0 try to reach this point
                if fd != self._processor._state.server_cid:
                    continue
                actions = self._processor.execute(event.ConnectionShutdown(fd))
            elif is_set(revent, select.POLLHUP):
                logger.warning(f"file descriptor {fd} write end closed")
                continue
            elif is_set(revent, select.POLLNVAL):
                logger.error(f"file descriptor is closed {fd}")
                assert False
                continue
            # NOTE(review): if none of the flag branches matched, `actions`
            # from a previous iteration would be reused — presumably
            # unreachable since revents always carries a known flag; confirm

            for ac in actions:
                assert exit_status is None
                if type(ac) is not action.PostponeSubmission:
                    logger.debug(f"executing {type(ac)}")

                if isinstance(ac, action.ConnectionClosure):
                    self._close_connection(ac.cid)
                elif isinstance(ac, action.Exit):
                    exit_status = ac.status
                elif isinstance(ac, action.JobCancellation):
                    self._handle_job_cancellation(ac)
                elif isinstance(ac, action.JobSubmission):
                    self._handle_job_submission(ac)
                elif isinstance(ac, action.JobUpdate):
                    self._handle_job_update(ac)
                elif isinstance(ac, action.MessageSending):
                    self._handle_message_sending(ac)
                elif isinstance(ac, action.PostponeSubmission):
                    pass
                elif isinstance(ac, action.ConnectionServer):
                    pass
                else:
                    logger.error(f"unhandled action {ac}")

            # drop finished worker threads and reaped child processes
            self._minions = [t for t in self._minions if t.is_alive()]
            self._child_processes = [
                p for p in self._child_processes if p.poll() is None
            ]

        for sock in self._peers:
            sock.close()

        if self._child_processes:
            logger.warning(
                f"{len(self._child_processes)} child processes still running"
            )

        if self._minions:
            logger.warning(
                f"list of worker threads not empty (length {len(self._minions)})"
            )

        # indirect scheduler: the object in the queue is assumed to be a
        # ProcessCompletion_ instance from the job cancellation
        if (self._scheduler.is_direct() and not self._events.empty()) \
                or self._events.qsize() > 1:
            logger.warning(
                f"expected empty event queue, have {self._events.qsize()} queued items"
            )

        return exit_status
566 @classmethod
567 def _make_io_redirect_filenames(cls, scheduler: Scheduler[JobT],
568 uid: int) -> Tuple[str, str]:
569 """
570 Generate names for the files storing standard output and standard error
571 of a process.
572 """
573 Path("./stdout").mkdir(parents=True, exist_ok=True)
575 def f(suffix: str) -> str:
576 return f"./stdout/{scheduler.name()}.{uid}.{suffix}"
578 return f("out"), f("err")
580 def _launch_process(
581 self,
582 uid: int,
583 args: process.ArgumentList,
584 env: process.Environment,
585 ) -> "Tuple[process.Process[str], str | None, str | None]":
586 logger.info(f"submission command: {' '.join(args)}")
587 if not self.std_output and self._scheduler.is_direct():
588 return (
589 process.launch(args, env, subprocess.PIPE, subprocess.PIPE), # type:ignore
590 None,
591 None
592 )
594 stdout_filename, stderr_filename = self._make_io_redirect_filenames(
595 self._scheduler, uid
596 )
597 with open(stdout_filename, "w") as stdout:
598 with open(stderr_filename, "w") as stderr:
599 return (
600 process.launch(args, env, stdout, stderr),
601 stdout_filename,
602 stderr_filename
603 )
605 def _run_process_asynchronously(
606 self, ac: action.Action, uid: int, proc: "process.Process[str]"
607 ) -> None:
608 assert not self._scheduler.is_direct() or self._scheduler.is_hybrid()
610 minion = threading.Thread(
611 target=self._wait_for_process, args=(self._events, ac, uid, proc)
612 )
613 self._child_processes.append(proc)
614 self._minions.append(minion)
615 minion.start()
617 @classmethod
618 def _wait_for_process(
619 cls,
620 events: queue.Queue,
621 ac: action.Action,
622 uid: int,
623 proc: "process.Process[str]",
624 ) -> None:
625 try:
626 proc.wait()
627 events.put(event.ProcessCompletion_(ac, uid, proc))
628 except Exception as e:
629 logger.error(f"process command: {str(proc)}")
630 logger.debug(f"wait for process exception: {e}")
631 # this situation would arise when the job submission number is restricted
632 # e.g. if only n oarsub are allowed the additional ones will fail ultimately
633 # causing the launcher to fail
634 logger.error("wait_for_process thread crashed the launcher will stop")