Coverage for melissa/launcher/io.py: 11% (353 statements)
#!/usr/bin/python3

# Copyright (c) 2021-2022, Institut National de Recherche en Informatique et en Automatique (Inria)
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
#   list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
#   this list of conditions and the following disclaimer in the documentation
#   and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
#   contributors may be used to endorse or promote products derived from
#   this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import errno
import logging
import os
import select
import socket
import subprocess
import sys
import threading
from pathlib import Path
from typing import cast, Dict, Generic, List, Optional, Tuple, TypeVar, Union

from melissa.launcher import config
from melissa.launcher.state_machine import Phase
from melissa.scheduler.scheduler import Scheduler, IndirectScheduler, HybridScheduler
from melissa.scheduler.scheduler import Options as SchedulerOptions
from melissa.utility import networking, process
from melissa.utility.networking import Socket

from . import action, event, message, queue
from .processor import Processor, DefaultProcessor

logger = logging.getLogger(__name__)

JobT = TypeVar("JobT")
class IoMaster(Generic[JobT]):
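    """Single-threaded I/O multiplexer driving the launcher state machine.

    The master polls the listening, signal, timer, web, and internal event
    file descriptors, translates readable descriptors into events for the
    processor, and executes the actions the processor returns (job
    submission, cancellation, update, message sending, ...).
    """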
    # the event file descriptors are reading and writing ends of a pipe or a
    # UNIX domain socket; they are passed explicitly for testing purposes.
    #
    # the event FIFO should have unlimited size because the thread reading from
    # the queue will also be putting elements into the queue.
    def __init__(
        self, listenfd: Socket, signalfd: Socket, timerfd: Socket,
        webfd: Socket, eventfd_r: Socket, events: queue.Queue,
        scheduler: Scheduler[JobT], client_options: SchedulerOptions,
        server_options: SchedulerOptions, processor: Union[Processor, DefaultProcessor],
        protocol: int, job_limit: int = 1000, std_output: bool = True
    ) -> None:
        self._listenfd = listenfd
        self._protocol = protocol
        self._signalfd = signalfd
        self._timerfd = timerfd
        self._webfd = webfd
        self._eventfd_r = eventfd_r
        self._events = events
        self._scheduler = scheduler
        self._client_options = client_options
        self._server_options = server_options
        self._uid = 0
        self._peers = []  # type: List[Socket]
        self._decoders: Dict[int, networking.LengthPrefixFramingDecoder] = {}
        self._processor = processor
        self._poller = select.poll()
        self._minions = []  # type: List[threading.Thread]
        self._child_processes = []  # type: List[process.Process[str]]
        self.job_limit = job_limit
        self.std_output: bool = std_output

        # this counter is used to keep track of the submitted jobs
        # that have not been added to the jobs list yet
        self.job_count: int = 0

        # a poller index is introduced to make sure the server socket (peerfd)
        # has a chance to be checked after each job postponement
        self.poller_idx: int = 0

        # a postponed_job list is added to avoid overloading the event queue;
        # it will contain all postponed JobSubmission messages
        self.postponed_job_list: List[event.MessageReception] = []

        # this value is used to open sockets for server ranks > 0
        self.server_comm_size: int = 1
    def _close_connection(self, fd: int) -> None:
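        """Unregister and close the peer socket whose fileno is ``fd``."""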
        if fd in self._decoders:
            del self._decoders[fd]
        sockfd = [p for p in self._peers if p.fileno() == fd][0]
        self._poller.unregister(sockfd)
        self._peers.remove(sockfd)
        sockfd.close()
    def _next_uid(self) -> int:
        uid = self._uid
        self._uid += 1
        return uid
    def _handle_file_descriptor_readable(self, fd: int) -> List[action.Action]:
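        """Translate a readable file descriptor into processor events.

        New connections, signals, timer ticks, queued internal events, and
        HTTP requests each map to a dedicated event; any other descriptor is
        treated as a peer socket with incoming message bytes.
        """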
        if fd == self._listenfd.fileno():
            peerfd, peer_addr = self._listenfd.accept()
            logger.info(f"accepted connection peerfd {peerfd} peeraddr {peer_addr}")

            self._peers.append(peerfd)
            self._poller.register(peerfd, select.POLLIN)

            if self._protocol == socket.IPPROTO_SCTP:
                try:
                    # TODO query the system for the value
                    SCTP_NODELAY = 3
                    # send pings immediately
                    peerfd.setsockopt(socket.IPPROTO_SCTP, SCTP_NODELAY, 1)
                except Exception as e:
                    logger.warning(
                        f"failed to enable SCTP_NODELAY on socket {peerfd}, {e}"
                    )
                    peerfd.close()
                    raise
            elif self._protocol == socket.IPPROTO_TCP:
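                # unlike SCTP, TCP is a byte stream without message
                # boundaries, so incoming bytes are run through a
                # length-prefix framing decoder to recover whole messages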
                self._decoders[peerfd.fileno()] = \
                    networking.LengthPrefixFramingDecoder(
                        config.TCP_MESSAGE_PREFIX_LENGTH
                    )
            else:
                raise NotImplementedError(
                    f"no support for protocol {self._protocol}"
                )

            cid = peerfd.fileno()
            return self._processor.execute(event.NewConnection(cid))
        if fd == self._signalfd.fileno():
            b = self._signalfd.recv(1)
            if len(b) == 0:
                return []
            signo = int.from_bytes(b, byteorder=sys.byteorder)
            return self._processor.execute(event.Signal(signo))

        if fd == self._timerfd.fileno():
            b = self._timerfd.recv(1)
            if len(b) == 0:
                return []
            return self._processor.execute(event.Timeout())

        if fd == self._eventfd_r.fileno():
            self._eventfd_r.recv(1)
            num_items = self._events.qsize()
            if num_items >= 16:
                logger.warning(f"event FIFO contains {num_items} items")
            ev = self._events.get()
            if isinstance(ev, event.ProcessCompletion_):
                return self._handle_process_completion(ev)

            if isinstance(ev, event.JobPostponing):
                # the poller index alternates so as not to miss server pings
                # (last socket) or signals and timeouts (first sockets)
                self.poller_idx = -1 if self.poller_idx == 0 else 0
                return self._handle_job_postponing(ev)

            if isinstance(ev, event.JobSubmission):
                # a JobSubmission event indicates that a MessageReception
                # transition occurred and one job was added to the jobs list;
                # the counter must then be decremented
                if self.job_count > 0:
                    self.job_count -= 1
            return self._processor.execute(ev)
        if fd == self._webfd.fileno():
            return self._processor.execute(event.HttpRequest_())

        raw_bytes = os.read(fd, 4096)

        if len(raw_bytes) == 0:
            self._close_connection(fd)
            return self._processor.execute(event.ConnectionShutdown(fd))

        if self._protocol == socket.IPPROTO_SCTP:
            dsz_msg = message.deserialize(raw_bytes)
            assert isinstance(self._processor, DefaultProcessor)
            if (
                isinstance(dsz_msg, message.JobSubmission)
                and len(self._processor._state.jobs) + self.job_count >= self.job_limit
            ):
                # logger.debug(
                #     f"job-limit was reached {len(self._processor._state.jobs)} "
                #     f"+ {self.job_count}, "
                #     "job submission postponed"
                # )
                if not self.postponed_job_list:
                    self._events.put(event.JobPostponing())
                self.postponed_job_list.append(event.MessageReception(fd, dsz_msg))
                return [action.PostponeSubmission()]
            else:
                logger.debug("execute message reception")
                if isinstance(dsz_msg, message.JobSubmission):
                    self.job_count += 1
                elif isinstance(dsz_msg, message.CommSize):
                    logger.debug(f"received comm size {dsz_msg.comm_size}")
                    self.server_comm_size = dsz_msg.comm_size
                return self._processor.execute(
                    event.MessageReception(fd, dsz_msg)
                )
        if self._protocol == socket.IPPROTO_TCP:
            byte_blocks = self._decoders[fd].execute(raw_bytes)
            actions = []  # type: List[action.Action]
            for bs in byte_blocks:
                dsz_msg = message.deserialize(bs)
                assert isinstance(self._processor, DefaultProcessor)
                if (
                    isinstance(dsz_msg, message.JobSubmission)
                    and len(self._processor._state.jobs) + self.job_count >= self.job_limit
                ):
                    # logger.debug(
                    #     f"job-limit reached: {len(self._processor._state.jobs)} "
                    #     f"+ {self.job_count},"
                    #     " job submission postponed"
                    # )
                    if not self.postponed_job_list:
                        self._events.put(event.JobPostponing())
                    self.postponed_job_list.append(event.MessageReception(fd, dsz_msg))
                    actions.append(action.PostponeSubmission())
                else:
                    if isinstance(dsz_msg, message.JobSubmission):
                        self.job_count += 1
                    logger.debug("execute message reception")
                    actions.extend(
                        self._processor.execute(event.MessageReception(fd, dsz_msg))
                    )

            return actions

        raise NotImplementedError("BUG protocol {:d}".format(self._protocol))
    def _handle_process_completion(
        self, pc: event.ProcessCompletion_
    ) -> List[action.Action]:
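        """Process the completion of an asynchronous scheduler subprocess.

        The recorded stdout/stderr are parsed by the scheduler to build the
        matching JobCancellation, JobSubmission, or JobUpdate event; any
        parsing failure becomes an ActionFailure event.
        """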
        assert (
            isinstance(self._scheduler, IndirectScheduler)
            or isinstance(self._scheduler, HybridScheduler)
        )

        stdout_filename, stderr_filename = self._make_io_redirect_filenames(
            self._scheduler, pc.id
        )
        with open(stdout_filename, "r") as f:
            stdout = f.read()
            if stdout == "":
                os.unlink(stdout_filename)

        with open(stderr_filename, "r") as f:
            stderr = f.read()
            if stderr == "":
                os.unlink(stderr_filename)

        exit_status = pc.process.returncode
        proc = process.CompletedProcess(exit_status, stdout, stderr)
        try:
            ev = event.Event()  # mypy Python 3.5 work-around
            if isinstance(pc.action, action.JobCancellation):
                self._scheduler.parse_cancel_jobs(pc.action.jobs, proc)
                ev = event.JobCancellation(pc.action.jobs)
                logger.debug(f"job cancellation succeeded (UID {pc.id})")

            elif isinstance(pc.action, action.JobSubmission):
                job = self._scheduler.make_job(proc, pc.id)  # type: ignore
                ev = event.JobSubmission(pc.action, job)
                logger.debug(
                    f"job submission succeeded (UID {pc.id} ID {job.id()})"
                )
                # a JobSubmission event indicates that a MessageReception
                # transition occurred and one job was added to the jobs list;
                # the counter must then be decremented
                if self.job_count > 0:
                    self.job_count -= 1

            elif isinstance(pc.action, action.JobUpdate):
                self._scheduler.parse_update_jobs(pc.action.jobs, proc)
                ev = event.JobUpdate(pc.action.jobs)
                logger.debug(f"job update succeeded (UID {pc.id})")

            else:
                raise NotImplementedError("BUG not implemented")

            # remove the parsed stdout and stderr files
            if not self.std_output:
                args = ["rm", f"{stdout_filename}", f"{stderr_filename}"]
                process.launch(args, {}, subprocess.DEVNULL, subprocess.DEVNULL)  # type: ignore

        except Exception as e:
            ev = event.ActionFailure(pc.action, e)
            logger.debug(f"scheduling action failed (UID {pc.id})")

        return self._processor.execute(ev)
    def _handle_job_cancellation(
        self, jc: action.JobCancellation[JobT]
    ) -> None:
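        """Cancel jobs in the way the scheduler kind requires.

        Hybrid: cancel client jobs in a thread and the server job through a
        subprocess; direct: cancel in a thread; indirect: spawn a
        cancellation subprocess and await it asynchronously.
        """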
        assert jc.jobs

        if self._scheduler.is_hybrid():
            # when cancel_jobs is called inside a thread, the return values
            # are ignored and only client jobs are cancelled through their
            # associated subprocesses
            t = threading.Thread(
                target=lambda: self._scheduler.cancel_client_jobs(jc.jobs)  # type: ignore
            )
            t.start()

            hysched = cast(HybridScheduler[JobT], self._scheduler)
            # this time the return values are kept and processed so that only
            # the server job is cancelled through the scheduler
            out = hysched.cancel_server_job(jc.jobs)
            if out:
                uid = self._next_uid()
                args, env = out
                proc = self._launch_process(uid, args, env)
                self._run_process_asynchronously(jc, uid, proc)
                logger.debug(f"cancelling jobs uid={uid}")
            else:
                self._events.put(event.JobCancellation(jc.jobs))
            return

        if self._scheduler.is_direct():
            t = threading.Thread(
                target=lambda: self._scheduler.cancel_jobs(jc.jobs)  # type: ignore
            )
            t.start()
            self._events.put(event.JobCancellation(jc.jobs))
            return

        sched = cast(IndirectScheduler[JobT], self._scheduler)
        uid = self._next_uid()
        args, env = sched.cancel_jobs(jc.jobs)
        proc = self._launch_process(uid, args, env)
        self._run_process_asynchronously(jc, uid, proc)
        logger.debug(f"cancelling jobs uid={uid}")
    def _handle_job_submission(self, sub: action.JobSubmission) -> None:
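        """Submit a job with server or client scheduler options.

        Direct schedulers yield a job object immediately; for other
        schedulers the submission subprocess is awaited asynchronously.
        """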
        if isinstance(sub, action.ServerJobSubmission):
            options = self._server_options
        else:
            options = self._client_options
        uid = self._next_uid()
        args, env = self._scheduler.submit_job(
            sub.commands, sub.environment, options, sub.job_name, uid
        )
        proc = self._launch_process(uid, args, env)

        if self._scheduler.is_direct():
            job = self._scheduler.make_job(proc, uid)  # type: ignore
            self._events.put(event.JobSubmission(sub, job))
            logger.debug(f"job launched uid={uid} id={job.id()}")
            return

        self._run_process_asynchronously(sub, uid, proc)
        logger.debug(f"submitting job uid {uid}")
    def _handle_job_update(self, ju: action.JobUpdate[JobT]) -> None:
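        """Update job states, directly or through a scheduler subprocess."""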
        assert ju.jobs
        jobs = ju.jobs

        if self._scheduler.is_direct():
            self._scheduler.update_jobs(jobs)  # type: ignore
            self._events.put(event.JobUpdate(jobs))
            return

        sched = cast(IndirectScheduler[JobT], self._scheduler)
        uid = self._next_uid()
        args, env = sched.update_jobs(jobs)
        proc = self._launch_process(uid, args, env)
        self._run_process_asynchronously(ju, uid, proc)
        logger.debug(f"updating jobs uid {uid}")
    def _handle_message_sending(self, msg: action.MessageSending) -> None:
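        """Serialize and send a message to the peers listed in ``msg.cid``.

        TCP payloads are length-prefix framed first; send errors are turned
        into an ActionFailure event.
        """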
        peers = [p for p in self._peers if p.fileno() in msg.cid]
        assert len(peers) <= self.server_comm_size
        logger.debug(f"message sending on peers {peers} with msg.cid {msg.cid}")
        for peer in peers:
            try:
                if self._protocol == socket.IPPROTO_TCP:
                    bs = networking.LengthPrefixFramingEncoder(
                        config.TCP_MESSAGE_PREFIX_LENGTH
                    ).execute(msg.message.serialize())
                    ret = peer.send(bs, socket.MSG_DONTWAIT)
                    assert ret == len(bs)
                else:
                    bs = msg.message.serialize()
                    ret = peer.send(bs, socket.MSG_DONTWAIT)
                    assert ret == len(bs)
            except OSError as e:
                assert e.errno != errno.EMSGSIZE
                self._events.put(event.ActionFailure(msg, e))
    def _handle_job_postponing(
        self, ev: event.JobPostponing
    ) -> List[Union[action.Action, action.PostponeSubmission]]:
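        """Reconsider the oldest postponed job submission.

        Postponed jobs are dropped if the server is dead, postponed again if
        the job limit is still reached, and submitted otherwise.
        """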
        assert isinstance(self._processor, DefaultProcessor)
        # make sure no additional client will be submitted if the server is dead
        if self._processor._state.phase == Phase.SERVER_DEAD:
            logger.warning("Server is dead, postponed job won't be submitted")
            self.postponed_job_list = []
            return [action.PostponeSubmission()]
        # postpone again if the job-limit is still reached
        elif len(self._processor._state.jobs) + self.job_count >= self.job_limit:
            # logger.debug(
            #     f"job-limit reached: {len(self._processor._state.jobs)} + {self.job_count}, "
            #     "job submission postponed again"
            # )
            self._events.put(ev)
            return [action.PostponeSubmission()]
        # submit the job
        else:
            logger.debug(
                f"job-limit not reached: {len(self._processor._state.jobs)} + {self.job_count}, "
                "job will be submitted"
            )
            # as long as the JobSubmission event resulting from the transition
            # below is not processed, the jobs list won't be updated, so the
            # counter must be incremented
            self.job_count += 1
            job_submission_message = self.postponed_job_list.pop(0)
            if len(self.postponed_job_list) > 0:
                self._events.put(ev)
            return self._processor.execute(job_submission_message)
    def run(self) -> int:
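        """Run the poll loop until an Exit action sets the exit status."""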
        for sock in [
            self._listenfd, self._signalfd, self._timerfd, self._webfd,
            self._eventfd_r
        ]:
            self._poller.register(sock, select.POLLIN)

        exit_status: Optional[int] = None
        while exit_status is None:
            listfd = self._poller.poll()

            assert listfd != []

            def is_set(x: int, flag: int) -> bool:
                assert x >= 0
                assert flag >= 0
                return x & flag == flag

            # process only one event at a time. this is relevant, e.g., when
            # the state machine wants to stop.
            fd, revent = listfd[self.poller_idx]
            if is_set(revent, select.POLLIN):
                try:
                    actions = self._handle_file_descriptor_readable(fd)
                except Exception as e:
                    logger.warning(f"server job already inactive {e}")
                    continue
            elif is_set(revent, select.POLLOUT):
                assert False
                continue
            elif is_set(revent, select.POLLERR):
                logger.warning(f"file descriptor {fd} read end closed")
                self._close_connection(fd)
                # TODO: Is this okay? Should we remove this?
                # There will be assertion errors when server ranks > 0 try to reach this point
                if fd != self._processor._state.server_cid:
                    continue
                actions = self._processor.execute(event.ConnectionShutdown(fd))
            elif is_set(revent, select.POLLHUP):
                logger.warning(f"file descriptor {fd} write end closed")
                continue
            elif is_set(revent, select.POLLNVAL):
                logger.error(f"file descriptor is closed {fd}")
                assert False
                continue
            for ac in actions:
                assert exit_status is None
                if type(ac) is not action.PostponeSubmission:
                    logger.debug(f"executing {type(ac)}")

                if isinstance(ac, action.ConnectionClosure):
                    self._close_connection(ac.cid)
                elif isinstance(ac, action.Exit):
                    exit_status = ac.status
                elif isinstance(ac, action.JobCancellation):
                    self._handle_job_cancellation(ac)
                elif isinstance(ac, action.JobSubmission):
                    self._handle_job_submission(ac)
                elif isinstance(ac, action.JobUpdate):
                    self._handle_job_update(ac)
                elif isinstance(ac, action.MessageSending):
                    self._handle_message_sending(ac)
                elif isinstance(ac, action.PostponeSubmission):
                    pass
                elif isinstance(ac, action.ConnectionServer):
                    pass
                else:
                    logger.error(f"unhandled action {ac}")

            self._minions = [t for t in self._minions if t.is_alive()]
            self._child_processes = [
                p for p in self._child_processes if p.poll() is None
            ]
        for sock in self._peers:
            sock.close()

        if self._child_processes:
            logger.warning(
                f"{len(self._child_processes)} child processes still running"
            )

        if self._minions:
            logger.warning(
                f"list of worker threads not empty (length {len(self._minions)})"
            )

        # indirect scheduler: the object in the queue is assumed to be a
        # ProcessCompletion_ instance from the job cancellation
        if (self._scheduler.is_direct() and not self._events.empty()) \
                or self._events.qsize() > 1:
            logger.warning(
                f"expected empty event queue, have {self._events.qsize()} queued items"
            )

        return exit_status
    @classmethod
    def _make_io_redirect_filenames(cls, scheduler: Scheduler[JobT],
                                    uid: int) -> Tuple[str, str]:
        """
        Generate names for the files storing standard output and standard error
        of a process.
        """
555 Path("./stdout").mkdir(parents=True, exist_ok=True)
557 def f(suffix: str) -> str:
558 return f"./stdout/{scheduler.name()}.{uid}.{suffix}"
560 return f("out"), f("err")
    def _launch_process(
        self,
        uid: int,
        args: process.ArgumentList,
        env: process.Environment,
    ) -> "process.Process[str]":
        logger.info(f"submission command: {' '.join(args)}")
        if not self.std_output and self._scheduler.is_direct():
            return process.launch(args, env, subprocess.DEVNULL, subprocess.DEVNULL)  # type: ignore
        stdout_filename, stderr_filename = self._make_io_redirect_filenames(
            self._scheduler, uid
        )
        with open(stdout_filename, "w") as stdout:
            with open(stderr_filename, "w") as stderr:
                return process.launch(args, env, stdout, stderr)
    def _run_process_asynchronously(
        self, ac: action.Action, uid: int, proc: "process.Process[str]"
    ) -> None:
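        """Wait for ``proc`` on a worker thread and queue its completion."""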
        assert not self._scheduler.is_direct() or self._scheduler.is_hybrid()

        minion = threading.Thread(
            target=self._wait_for_process, args=(self._events, ac, uid, proc)
        )
        self._child_processes.append(proc)
        self._minions.append(minion)
        minion.start()
    @classmethod
    def _wait_for_process(
        cls,
        events: queue.Queue,
        ac: action.Action,
        uid: int,
        proc: "process.Process[str]",
    ) -> None:
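        """Block until ``proc`` exits, then queue a ProcessCompletion_ event."""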
        try:
            proc.wait()
            events.put(event.ProcessCompletion_(ac, uid, proc))
        except Exception as e:
            logger.debug(f"wait for process exception: {e}")
            # this situation would arise when the number of job submissions is
            # restricted: e.g. if only n oarsub commands are allowed, any
            # additional ones will fail, ultimately causing the launcher to fail
            logger.error("wait_for_process thread crashed, the launcher will stop")