Coverage for melissa/scheduler/openmpi.py: 46%
108 statements
« prev ^ index » next coverage.py v7.10.1, created at 2026-01-23 15:52 +0100
1#!/usr/bin/python3
3# Copyright (c) 2020-2022, Institut National de Recherche en Informatique et en Automatique (Inria)
4# All rights reserved.
5#
6# Redistribution and use in source and binary forms, with or without
7# modification, are permitted provided that the following conditions are met:
8#
9# * Redistributions of source code must retain the above copyright notice, this
10# list of conditions and the following disclaimer.
11#
12# * Redistributions in binary form must reproduce the above copyright notice,
13# this list of conditions and the following disclaimer in the documentation
14# and/or other materials provided with the distribution.
15#
16# * Neither the name of the copyright holder nor the names of its
17# contributors may be used to endorse or promote products derived from
18# this software without specific prior written permission.
19#
20# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31import logging
32import os
33import re
34import shutil
35import subprocess
36from typing import List, Tuple, Union, Any
38from melissa.utility import time
39from melissa.utility.process import ArgumentList, Environment, Process
41from .job import Id, Job, State
42from .scheduler import DirectScheduler, Options
44logger = logging.getLogger(__name__)
class OpenMpiJob(Job):
    """Handle for a job launched as a local ``mpirun`` child process."""

    def __init__(self, uid: Id, process: "subprocess.Popen[str]") -> None:
        super().__init__()
        # Scheduler-assigned unique identifier; distinct from the OS PID.
        self._uid = uid
        # The underlying mpirun process; polled to derive the job state.
        self._process = process
        self._state = State.RUNNING

    def id(self) -> Id:
        # The job id is simply the PID of the spawned process.
        return self._process.pid

    def unique_id(self) -> Union[str, int]:
        return self._uid

    def state(self) -> State:
        return self._state

    def __repr__(self) -> str:
        return f"OpenMpiJob(id={self.id():d},state={str(self._state):s})"
class OpenMpiScheduler(DirectScheduler[OpenMpiJob]):
    """Scheduler that launches jobs directly via OpenMPI's ``mpirun``."""

    @classmethod
    def _name_impl(cls) -> str:
        return "openmpi"

    @classmethod
    def _is_available_impl(cls) -> Tuple[bool, Union[str, Tuple[str, str]]]:
        """Probe for a usable OpenMPI ``mpirun``.

        Returns ``(True, (path, version))`` if an OpenMPI mpirun was found,
        ``(False, reason)`` otherwise.
        """
        mpirun_path = shutil.which("mpirun")
        if mpirun_path is None:
            return False, "mpirun executable not found"

        mpirun = subprocess.run(
            [mpirun_path, "--version"],
            stdin=subprocess.DEVNULL,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            universal_newlines=True,
        )
        if mpirun.returncode != 0:
            return False, "failed to execute %s: %s" % (mpirun_path, mpirun.stderr)
        if mpirun.stderr != "":
            # Some mpirun builds print warnings on stderr even on success.
            logger.warning("%s: %s", mpirun_path, mpirun.stderr)

        # Only OpenMPI's mpirun prints this banner; other MPI
        # implementations (e.g. MPICH) format their version differently.
        version_re = r"^mpirun \(Open MPI\) (\d+[.]\d+[.]\d+)"
        match = re.match(version_re, mpirun.stdout)
        if not match:
            return False, "{:s} is not an OpenMPI mpirun".format(mpirun_path)

        ompi_version = match.group(1)
        return True, (mpirun_path, ompi_version)

    def _sanity_check_impl(self, options: Options) -> List[str]:
        """Return human-readable complaints about user-supplied arguments.

        The scheduler manages launching and process counts itself, so the
        user must not pass flags that interfere with either.
        """
        # NOTE: both the "do-not-launch" and the process-count checks
        # produce the same message, so one comprehension covers both.
        forbidden = ("-N", "-c", "-n", "--n", "-np")
        return [
            "remove `{:s}` argument".format(a)
            for a in options.raw_arguments
            if "do-not-launch" in a or a in forbidden
        ]

    def _submit_job_impl(
        self,
        commands: List[ArgumentList],
        env: Environment,
        options: Options,
        name: str,
        unique_id: int,
    ) -> Tuple[ArgumentList, Environment]:
        """Assemble the full mpirun command line and its environment.

        Approach to environment variables: follow the OpenMPI mpirun man
        page advice, that is,
        * set `VARIABLE=VALUE` in the mpirun environment,
        * pass `-x VARIABLE` on the mpirun command line.
        """
        ompi_env = os.environ.copy()
        env_args = []  # type: List[str]
        # Sort keys for a deterministic, reproducible command line.
        for key in sorted(env.keys()):
            ompi_env[key] = env[key]
            env_args += ["-x", key]

        ompi_options = options.raw_arguments + options.sched_cmd_opt

        ompi_commands = []  # type: List[str]
        for i, cmd in enumerate(commands):
            # Multiple executables are joined with `:` (MPMD-style launch);
            # the last command must not be followed by a separator.
            # Use in-place extension: repeated `list + list` in a loop is
            # quadratic in the total command length.
            ompi_commands += (
                ompi_options
                + env_args
                + ["--"]
                + cmd
                + ([":"] if i + 1 < len(commands) else [])
            )

        ompi_call = [options.sched_cmd] + ompi_commands
        return ompi_call, ompi_env

    def _make_job_impl(self, proc: "Process[str]", unique_id: int,
                       **kwargs: "dict[str, Any]") -> OpenMpiJob:
        return OpenMpiJob(unique_id, proc)

    @classmethod
    def _update_jobs_impl(cls, jobs: List[OpenMpiJob]) -> None:
        """Poll each job's process and refresh its cached state."""
        for j in jobs:
            returncode = j._process.poll()
            if returncode is None:
                state = State.RUNNING
            elif returncode == 0:
                state = State.TERMINATED
            else:
                state = State.FAILED
            j._state = state

    @classmethod
    def _cancel_jobs_impl(cls, jobs: List[OpenMpiJob]) -> None:
        # when the user presses ctrl+c, the shell will send all processes in
        # the same process group SIGINT. some programs respond intelligently to
        # signals by freeing resources and exiting. these programs may also
        # exit _immediately_ if they receive a second signal within a short
        # time frame (e.g., srun or mpirun which won't terminate its child
        # processes in this case). for this reason, we wait before terminating
        # jobs here.
        max_wait_time = time.Time(seconds=5)

        # wait at most max_wait_time overall
        def compute_timeout(t_start: time.Time) -> time.Time:
            t_waited = time.monotonic() - t_start
            if t_waited < max_wait_time:
                return max_wait_time - t_waited
            return time.Time(seconds=0)

        # First pass: give each process a grace period, then terminate
        # (SIGTERM) the ones still running.
        t_0 = time.monotonic()
        for j in jobs:
            try:
                timeout = compute_timeout(t_0)
                j._process.wait(timeout.total_seconds())
            except subprocess.TimeoutExpired:
                logger.debug("OpenMPI scheduler terminating process %d", j.id())
                j._process.terminate()

        # Second pass: kill (SIGKILL) processes that ignored the terminate.
        t_1 = time.monotonic()
        for j in jobs:
            try:
                timeout = compute_timeout(t_1)
                j._process.wait(timeout.total_seconds())
            except subprocess.TimeoutExpired:
                logger.debug("OpenMPI scheduler killing process %d", j.id())
                j._process.kill()

            # Every cancelled job ends up in the FAILED state, whether it
            # exited on its own or had to be terminated/killed.
            j._state = State.FAILED