Coverage for melissa/scheduler/slurm_global.py: 28%
138 statements
« prev ^ index » next coverage.py v7.10.1, created at 2025-11-19 09:33 +0100
« prev ^ index » next coverage.py v7.10.1, created at 2025-11-19 09:33 +0100
1#!/usr/bin/python3
3# Copyright (c) 2020-2022, Institut National de Recherche en Informatique et en Automatique (Inria)
4# All rights reserved.
5#
6# Redistribution and use in source and binary forms, with or without
7# modification, are permitted provided that the following conditions are met:
8#
9# * Redistributions of source code must retain the above copyright notice, this
10# list of conditions and the following disclaimer.
11#
12# * Redistributions in binary form must reproduce the above copyright notice,
13# this list of conditions and the following disclaimer in the documentation
14# and/or other materials provided with the distribution.
15#
16# * Neither the name of the copyright holder nor the names of its
17# contributors may be used to endorse or promote products derived from
18# this software without specific prior written permission.
19#
20# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31import logging
32import os
33import re
34import shutil
35import subprocess
36from types import ModuleType
37from typing import cast, List, Tuple, Union, Any
39from melissa.utility import time
40from melissa.utility.process import ArgumentList, Environment, Process
42from .job import Id, Job, State
43from .scheduler import DirectScheduler, Options
# The rest of this module logs through the name ``logging``; rebind it to this
# module's own logger. The cast keeps the name typed as a module for checkers.
logging = cast(
    ModuleType,
    logging.getLogger(__name__),
)
class SrunJob(Job):
    """Job record wrapping a locally spawned process.

    The job starts in ``State.RUNNING``; its state attribute is updated from
    outside by the scheduler that owns it.
    """

    def __init__(self, uid: Id, process: "subprocess.Popen[str]") -> None:
        """Store the caller-supplied identifier and the process handle."""
        super().__init__()
        self._state = State.RUNNING
        self._process = process
        self._uid = uid

    def id(self) -> Id:
        """Return the PID of the wrapped process."""
        return self._process.pid

    def unique_id(self) -> Union[str, int]:
        """Return the identifier passed to the constructor."""
        return self._uid

    def state(self) -> State:
        """Return the most recently recorded state."""
        return self._state

    def __repr__(self) -> str:
        state_text = str(self._state)
        return f"SrunJob(id={self.id():d},state={state_text:s})"
class SlurmGlobalScheduler(DirectScheduler[SrunJob]):
    """Direct scheduler that probes for and builds launch commands around ``srun``.

    Availability is determined by locating ``srun`` and parsing the output of
    ``srun --version``. Jobs are tracked as ``SrunJob`` instances wrapping
    local processes.
    """

    # always compile regular expressions to enforce ASCII matching
    # allow matching things like version 1.2.3-rc4
    _srun_version_regexp = r"slurm (\d+)\.(\d+)\.(\d+)(-\S+)?"
    _srun_version_pattern = re.compile(_srun_version_regexp, re.ASCII)
    _sbatch_job_id_regexp = r"(\d+)"
    _sbatch_job_id_pattern = re.compile(_sbatch_job_id_regexp, re.ASCII)
    _sacct_line_regexp = r"(\d+)([+]0[.]batch)?[|](\w+)"
    _sacct_line_pattern = re.compile(_sacct_line_regexp, re.ASCII)
    # set by _is_available_impl(): True when the detected major version is >= 20
    _use_het_prefix: bool = False

    @classmethod
    def _name_impl(cls) -> str:
        """Return the scheduler name, ``"slurm"``."""
        return "slurm"

    @classmethod
    def _is_available_impl(cls) -> Tuple[bool, Union[str, Tuple[str, str]]]:
        """Check whether a usable ``srun`` is installed.

        Returns:
            ``(False, reason)`` when ``srun`` is missing, fails to run, or is
            older than 17.11; otherwise ``(True, (srun_path, version))``.

        Raises:
            ValueError: if the ``srun --version`` output does not match the
                expected ``slurm X.Y.Z`` format.
        """
        srun_path = shutil.which("srun")
        if srun_path is None:
            return False, "srun executable not found"

        srun = subprocess.run(
            [srun_path, "--version"],
            stdin=subprocess.DEVNULL,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            universal_newlines=True,
        )
        if srun.returncode != 0:
            return False, "failed to execute %s: %s" % (srun_path, srun.stderr)

        # do not use pattern.fullmatch() because of the newline at the end of
        # the output
        match = cls._srun_version_pattern.match(srun.stdout)
        if match is None:
            e = "srun output '{:s}' does not match expected format"
            raise ValueError(e.format(srun.stdout))

        version_major = int(match.group(1))
        # the minor version might be of the form `05` but the Python int()
        # function handles this correctly
        version_minor = int(match.group(2))
        # version exactly as printed by srun (keeps leading zeros, drops any
        # "-suffix" after the patch number)
        version_str = srun.stdout[match.start(1):match.end(3)]

        if version_major < 19 or (version_major == 19 and version_minor < 5):
            # Logger.warning() takes %-format arguments only; passing a warning
            # category here would break message formatting.
            logging.warning(
                "Melissa has not been tested with Slurm versions older than 19.05.5"
            )

        if version_major < 17 or (version_major == 17 and version_minor < 11):
            return (
                False,
                (
                    "Expected at least Slurm 17.11, got "
                    f"{version_str} "
                    "which does not support heterogeneous jobs"
                ),
            )

        cls._use_het_prefix = version_major >= 20

        return True, (srun_path, version_str)

    def __init__(self) -> None:
        """Verify that Slurm is available.

        Raises:
            RuntimeError: if the availability check reports failure.
        """
        is_available, info = self.is_available()
        if not is_available:
            raise RuntimeError("Slurm unavailable: %s" % (info,))

        assert self._use_het_prefix is not None

    def _sanity_check_impl(self, options: Options) -> List[str]:
        """Validate the raw scheduler arguments and try them with ``srun --test-only``.

        Returns:
            A list of human-readable error messages; empty when no problem
            was detected.
        """
        args = options.raw_arguments
        errors = []

        for a in args:
            # startswith() also covers the empty string, which indexing a[0]
            # would crash on
            if not a.startswith("-"):
                errors.append("non-option argument '{:s}' detected".format(a))
            elif a in ("-n", "--ntasks", "--test-only"):
                errors.append("remove '{:s}' argument".format(a))

        command = ["srun", "--test-only", "--ntasks=1"] + args + ["--", "true"]
        srun = subprocess.run(
            command,
            stdin=subprocess.DEVNULL,
            stdout=subprocess.DEVNULL,
            stderr=subprocess.PIPE,
            universal_newlines=True,
        )
        if srun.returncode != 0:
            e = "srun error on trial execution: {:s}".format(srun.stderr)
            errors.append(e)

        return errors

    def _submit_job_impl(
        self,
        commands: List[ArgumentList],
        env: Environment,
        options: Options,
        name: str,
        unique_id: int,
    ) -> Tuple[ArgumentList, Environment]:
        """Build the command line and environment used to launch one job.

        Args:
            commands: exactly one argument list (asserted).
            env: variables layered on top of the current process environment.
            options: supplies ``raw_arguments``, ``sched_cmd`` and
                ``sched_cmd_opt``.
            name: job name, used in the output/error file names.
            unique_id: job number, used in the output/error file names.

        Returns:
            ``(argument_list, environment)``. Without ``sched_cmd`` the
            argument list holds only the command; otherwise it is
            ``sched_cmd``, its options, the propagated options, ``--`` and the
            command.
        """
        # Approach to environment variables:
        # By default all environment variables are propagated
        srun_env = os.environ.copy()
        srun_env.update(env)

        output_filename = "./stdout/job.{:d}.{:s}.out".format(unique_id, name)
        error_filename = "./stdout/job.{:d}.{:s}.err".format(unique_id, name)

        # build propagated options of the srun command line
        propagated_options = [
            "--output={:s}".format(output_filename),
            "--error={:s}".format(error_filename),
        ] + options.raw_arguments

        sched_cmd = options.sched_cmd
        sched_cmd_opt = options.sched_cmd_opt

        assert len(commands) == 1, "non-unit groups are not supported in this configuration"
        srun_arguments: List[List[str]] = []
        if not sched_cmd:
            srun_arguments = [commands[0]]
        else:
            srun_arguments = [["--"], commands[0]]

        def args2str(args: List[str]) -> str:
            # NOTE(review): tokens are concatenated with an empty separator, so
            # a multi-element command collapses into a single token - confirm
            # that callers only pass single-element commands.
            return "".join(args)

        logging.debug(f"srun_arguments: {srun_arguments}")
        srun_arguments_str = [args2str(args) for args in srun_arguments]
        logging.debug(f"srun_arguments_str: {srun_arguments_str}")

        # write srun/job execution call
        if not sched_cmd:
            srun_call = srun_arguments_str
        else:
            srun_call = [sched_cmd] + sched_cmd_opt + propagated_options + srun_arguments_str

        return srun_call, srun_env

    def _make_job_impl(self, proc: "Process[str]", unique_id: int,
                       **kwargs: "dict[str, Any]") -> SrunJob:
        """Wrap a started process in an ``SrunJob``; extra keyword arguments are ignored."""
        return SrunJob(unique_id, proc)

    @classmethod
    def _update_jobs_impl(cls, jobs: List[SrunJob]) -> None:
        """Refresh each job's state from its process exit status.

        No exit status yet means RUNNING, exit code 0 means TERMINATED and
        any other exit code means FAILED.
        """
        for j in jobs:
            returncode = j._process.poll()
            if returncode is None:
                state = State.RUNNING
            elif returncode == 0:
                state = State.TERMINATED
            else:
                state = State.FAILED
            j._state = state

    @classmethod
    def _cancel_jobs_impl(cls, jobs: List[SrunJob]) -> None:
        """Stop the given jobs and mark every one of them as FAILED.

        Each phase (terminate, then kill) waits at most five seconds in total
        across all jobs before signalling the processes still running.
        """
        # when the user presses ctrl+c, the shell will send all processes in
        # the same process group SIGINT. some programs respond intelligently to
        # signals by freeing resources and exiting. these programs may also
        # exit _immediately_ if they receive a second signal within a short
        # time frame (e.g., srun or mpirun which won't terminate its child
        # processes in this case). for this reason, we wait before terminating
        # jobs here.
        max_wait_time = time.Time(seconds=5)

        # wait at most max_wait_time overall
        def compute_timeout(t_start: time.Time) -> time.Time:
            t_waited = time.monotonic() - t_start
            if t_waited < max_wait_time:
                return max_wait_time - t_waited
            return time.Time(seconds=0)

        # terminate processes
        t_0 = time.monotonic()
        for j in jobs:
            try:
                timeout = compute_timeout(t_0)
                j._process.wait(timeout.total_seconds())
            except subprocess.TimeoutExpired:
                logging.debug("Slurm srun scheduler terminating process %d", j.id())
                j._process.terminate()

        # kill processes if necessary
        t_1 = time.monotonic()
        for j in jobs:
            try:
                timeout = compute_timeout(t_1)
                j._process.wait(timeout.total_seconds())
            except subprocess.TimeoutExpired:
                logging.debug("Slurm srun scheduler killing process %d", j.id())
                j._process.kill()

            j._state = State.FAILED