Coverage for melissa/scheduler/slurm_semiglobal.py: 24%
278 statements
coverage.py v7.10.1, created at 2025-11-03 09:52 +0100
#!/usr/bin/python3

# Copyright (c) 2020-2022, Institut National de Recherche en Informatique et en Automatique (Inria)
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
#   list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
#   this list of conditions and the following disclaimer in the documentation
#   and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
#   contributors may be used to endorse or promote products derived from
#   this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import enum
import logging
import os
import re
import shutil
import subprocess
from typing import List, Tuple, Union, Dict, Optional, Any
from pathlib import Path

from melissa.utility import time
from melissa.utility.process import ArgumentList, CompletedProcess, Environment, Process

from .job import Id, Job, State
from .scheduler import HybridScheduler, Options
from .slurm_parser import break2str

logger = logging.getLogger(__name__)


class JobType(enum.Enum):

    Server = 0
    Client = 1


class HybridJob(Job):
    def __init__(
        self,
        uid: Id,
        process: "subprocess.Popen[str]",
        job_type: JobType,
        stdout_fname: str | None,
        stderr_fname: str | None,
    ) -> None:
        super().__init__()
        self._type = job_type
        self._state = State.WAITING
        self._uid = uid
        self._process = process
        self.stdout_fname = stdout_fname
        self.stderr_fname = stderr_fname

    def id(self) -> Id:
        return self._process.pid

    def unique_id(self) -> Union[str, int]:
        return self._uid

    def state(self) -> State:
        return self._state

    def __repr__(self) -> str:
        r = f"<{self.__class__.__name__} (id={self.id():d},state={self.state()})>"
        return r


class NodeWorkload:
    """Keep track of submitted jobs on different nodes."""

    def __init__(self):
        self._node_list: List[str] = []
        self._allocation_size: int = 0
        self._ntasks_per_node: int = 0
        self._node_workload: Dict[str, List[HybridJob]] = {}

    @property
    def node_list(self) -> List[str]:
        return self._node_list

    @node_list.setter
    def node_list(self, nodes: List[str]):
        self._node_list = nodes
        self._node_workload = {node: [] for node in nodes}

    @property
    def allocation_size(self) -> int:
        return self._allocation_size

    @allocation_size.setter
    def allocation_size(self, size: int):
        self._allocation_size = size

    @property
    def ntasks_per_node(self) -> int:
        return self._ntasks_per_node

    @ntasks_per_node.setter
    def ntasks_per_node(self, ntasks: int):
        self._ntasks_per_node = ntasks

    def append(self, node: str, job: HybridJob):
        self._node_workload[node].append(job)

    def select_node(self, iteration) -> str:
        return min(self._node_workload, key=lambda node: len(self._node_workload[node]))

    def __len__(self) -> int:
        n_submitted_jobs = 0
        for submitted_jobs in self._node_workload.values():
            n_submitted_jobs += len(submitted_jobs)
        return n_submitted_jobs

    def update_jobs(self):
        for node in self._node_workload:
            SlurmSemiGlobalScheduler._update_jobs_impl(self._node_workload[node])
            self._node_workload[node] = [
                job
                for job in self._node_workload[node]
                if job._state in [State.WAITING, State.RUNNING]
            ]
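
# Illustrative sketch (editor's addition, not part of the original module): how
# NodeWorkload balances client jobs across the allocated nodes. The node names
# and `fake_job` below are hypothetical placeholders.
#
#     workload = NodeWorkload()
#     workload.node_list = ["node01", "node02"]  # resets the per-node bookkeeping
#     workload.append("node01", fake_job)        # fake_job: an existing HybridJob
#     workload.select_node(0)                    # -> "node02", the least-loaded node
#     len(workload)                              # -> 1 submitted job in total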


class SlurmSemiGlobalScheduler(HybridScheduler[HybridJob]):
    # always compile regular expressions to enforce ASCII matching
    # allow matching things like version 1.2.3-rc4
    _srun_version_regexp = r"slurm (\d+)[.](\d+)[.](\d+\S*)"
    _srun_version_pattern = re.compile(_srun_version_regexp, re.ASCII)
    _sbatch_job_id_regexp = r"(\d+)"
    _sbatch_job_id_pattern = re.compile(_sbatch_job_id_regexp, re.ASCII)
    _sacct_line_regexp = r"(\d+)([+]0[.]batch)?[|](\w+)"
    _sacct_line_pattern = re.compile(_sacct_line_regexp, re.ASCII)
    _use_het_prefix: bool = False

    @classmethod
    def is_direct(cls) -> bool:
        return True

    @classmethod
    def _name_impl(cls) -> str:
        return "slurm"

    @classmethod
    def _is_available_impl(cls) -> Tuple[bool, Union[str, Tuple[str, str]]]:
        srun_path = shutil.which("srun")
        if srun_path is None:
            return False, "srun executable not found"

        srun = subprocess.run(
            [srun_path, "--version"],
            stdin=subprocess.DEVNULL,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            universal_newlines=True,
        )
        if srun.returncode != 0:
            return False, "failed to execute %s: %s" % (srun_path, srun.stderr)

        # do not use pattern.fullmatch() because of the newline at the end of
        # the output
        match = cls._srun_version_pattern.match(srun.stdout)
        if match is None:
            e = "srun output '{:s}' does not match expected format"
            raise ValueError(e.format(srun.stdout))

        version_major = int(match.group(1))
        # the minor version might be of the form `05` but the Python int()
        # function handles this correctly
        version_minor = int(match.group(2))
        version_patch = match.group(3)

        if version_major < 19 or (version_major == 19 and version_minor < 5):
            logger.warning(
                "Melissa has not been tested with Slurm versions older than 19.05.5"
            )

        if version_major < 17 or (version_major == 17 and version_minor < 11):
            return (
                False,
                (
                    "expected at least Slurm 17.11, "
                    f"got {version_major}.{version_minor}.{version_patch} "
                    "which does not support heterogeneous jobs"
                ),
            )

        cls._use_het_prefix = version_major >= 20

        version_str = srun.stdout[match.span(1)[0] : match.span(3)[1]]
        return True, (srun_path, version_str)
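
    # Illustrative sketch (editor's addition, not part of the original module):
    # what the version check above does with a hypothetical `srun --version` output.
    #
    #     srun.stdout = "slurm 20.11.8\n"
    #     _srun_version_pattern.match(...)   # groups: ("20", "11", "8")
    #     version_major, version_minor = 20, 11
    #     cls._use_het_prefix = True         # heterogeneous-job prefix for Slurm >= 20
    #     version_str = "20.11.8"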

    def __init__(self) -> None:
        # Attributes for the srun allocation on multiple nodes
        self.srun_ctr: int = 0
        self._node_workload = NodeWorkload()
        # Standard initialization function
        is_available, info = self.is_available()
        if not is_available:
            raise RuntimeError("Slurm unavailable: %s" % (info,))

        assert self._use_het_prefix is not None

    def _sanity_check_impl(self, options: Options) -> List[str]:
        args = options.raw_arguments
        errors = []

        for a in args:
            if a[0] != "-":
                errors.append("non-option argument '{:s}' detected".format(a))
            elif a in ["-n", "--ntasks", "--test-only"]:
                errors.append("remove '{:s}' argument".format(a))

        command = ["srun", "--test-only", "--ntasks=1"] + args + ["--", "true"]
        srun = subprocess.run(
            command,
            stdin=subprocess.DEVNULL,
            stdout=subprocess.DEVNULL,
            stderr=subprocess.PIPE,
            universal_newlines=True,
        )
        if srun.returncode != 0:
            e = "srun error on trial execution: {:s}".format(srun.stderr)
            errors.append(e)

        return errors

    def _submit_job_impl(
        self,
        commands: List[ArgumentList],
        env: Environment,
        options: Options,
        name: str,
        unique_id: int,
    ) -> Tuple[ArgumentList, Environment]:
        if name == "melissa-server":
            return self._submit_job_impl_server(commands, env, options, name, unique_id)
        else:
            return self._submit_job_impl_client(commands, env, options, name, unique_id)

    def _submit_job_impl_server(
        self,
        commands: List[ArgumentList],
        env: Environment,
        options: Options,
        name: str,
        unique_id: int,
    ) -> Tuple[ArgumentList, Environment]:
        # sbatch submission
        sbatch_env = os.environ.copy()
        sbatch_env.update(env)

        output_filename = "./stdout/job.{:d}.{:s}.out".format(unique_id, name)
        error_filename = "./stdout/job.{:d}.{:s}.err".format(unique_id, name)

        propagated_options = [
            "--output={:s}".format(output_filename),
            "--error={:s}".format(error_filename),
        ]

        uid = unique_id
        self._server_uid = uid

        sbatch_options = propagated_options + options.raw_arguments

        # serialize sbatch options
        def options2str(options: str) -> str:
            return "#SBATCH " + options

        sbatch_options_str = [options2str(o) for o in sbatch_options]

        sched_cmd = options.sched_cmd
        sched_cmd_opt = options.sched_cmd_opt

        # assemble srun arguments
        srun_arguments: List[List[str]] = []
        if not sched_cmd:
            srun_arguments = [[""] + commands[0]]
        else:
            srun_arguments = [sched_cmd_opt + ["--"] + commands[0]]

        # serialize srun arguments
        def srunargs2str(hetgroup: int, args: List[str]) -> str:
            assert hetgroup >= 0
            assert hetgroup < len(commands)

            prefix = ": " if hetgroup > 0 else ""
            suffix = " \\" if hetgroup + 1 < len(commands) else ""
            return " " + prefix + " ".join(args) + suffix

        srun_arguments_str = [
            srunargs2str(i, args) for i, args in enumerate(srun_arguments)
        ]

        # write srun calls to file
        Path("./sbatch").mkdir(parents=True, exist_ok=True)
        sbatch_script_filename = "./sbatch/sbatch.{:d}.sh".format(uid)
        sbatch_script = (
            ["#!/bin/sh"]
            + ["# sbatch script for job {:s}".format(name)]
            + sbatch_options_str
            + [""]
            + (["exec \\"] if not sched_cmd else [f"exec {sched_cmd} \\"])
            + srun_arguments_str
        )

        # POSIX requires files to end with a newline; missing newlines at the
        # end of a file may break scripts that append text.
        # this string won't contain a newline at the end; it must be added
        # manually or by using a function that adds it, e.g., `print`
        sbatch_script_str_noeol = "\n".join(sbatch_script)

        with open(sbatch_script_filename, "w") as f:
            print(sbatch_script_str_noeol, file=f)

        sbatch_call = (
            ["sbatch"]
            + ["--parsable"]
            + ["--job-name={:s}".format(name)]
            + [sbatch_script_filename]
        )

        return sbatch_call, sbatch_env
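
    # Illustrative sketch (editor's addition, not part of the original module):
    # a generated ./sbatch/sbatch.<uid>.sh might look roughly like this, assuming
    # a hypothetical uid of 0, raw_arguments ["--time=01:00:00"], sched_cmd "srun"
    # and sched_cmd_opt ["--ntasks=2"].
    #
    #     #!/bin/sh
    #     # sbatch script for job melissa-server
    #     #SBATCH --output=./stdout/job.0.melissa-server.out
    #     #SBATCH --error=./stdout/job.0.melissa-server.err
    #     #SBATCH --time=01:00:00
    #
    #     exec srun \
    #      --ntasks=2 -- <server command from commands[0]>
    #
    # The script is then submitted with `sbatch --parsable --job-name=... <script>`,
    # so the job id can be read back from sbatch's output for later cancellation.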

    def _submit_job_impl_client(
        self,
        commands: List[ArgumentList],
        env: Environment,
        options: Options,
        name: str,
        unique_id: int,
    ) -> Tuple[ArgumentList, Environment]:
        # srun submission for clients
        srun_env = os.environ.copy()
        srun_env.update(env)
        self.srun_ctr += 1

        output_filename = "./stdout/job.{:d}.{:s}.out".format(unique_id, name)
        error_filename = "./stdout/job.{:d}.{:s}.err".format(unique_id, name)

        # the Slurm allocation is only scanned on the first client submission
        if self.srun_ctr == 1:
            # extract all environment variables
            self._node_workload.allocation_size = int(srun_env["SLURM_NTASKS"])
            self._node_workload.ntasks_per_node = int(srun_env["SLURM_NTASKS_PER_NODE"])
            self._node_workload.node_list = break2str(srun_env["SLURM_NODELIST"])
            assert (
                len(self._node_workload.node_list)
                == self._node_workload.allocation_size
                // self._node_workload.ntasks_per_node
            )
            logger.debug(
                f"allocation of {self._node_workload.allocation_size} cores "
                f"with identified nodelist {self._node_workload.node_list} "
                f"from environment variable {srun_env['SLURM_NODELIST']}"
            )

        # make sure resources are available before next series of srun
        logger.debug(
            f"Scheduler has currently {len(self._node_workload)} jobs running."
        )
        self._node_workload.update_jobs()
        node = self._node_workload.select_node(self.srun_ctr - 1)

        # build propagated options of the srun command line
        propagated_options = [
            "--output={:s}".format(output_filename),
            "--error={:s}".format(error_filename),
            "--nodelist={:s}".format(node),
        ] + options.raw_arguments

        # assemble srun arguments
        sched_cmd = options.sched_cmd
        sched_cmd_opt = options.sched_cmd_opt

        # retrieve variable value amongst scheduler_arg_client
        def get_from_options(list_of_options: List[str], variable_name: str) -> str:
            for opt in list_of_options:
                if variable_name in opt:
                    idx = opt.find("=")
                    return opt[idx + 1 :]
            logger.error(
                f"The variable name {variable_name} is not amongst {list_of_options}. "
                "Please make sure scheduler_arg_client is set correctly."
            )
            return ""

        # write srun/job execution call
        srun_call: List[str] = []
        srun_arguments: List[str] = []
        if not sched_cmd:
            assert (
                len(commands) == 1
            ), "non-unit groups are not supported in this configuration"
            srun_call = commands[0]
            # node and resource specification with environment variables
            new_env = {
                "NODEID": node,
                "NTASKS": get_from_options(propagated_options, "ntasks"),
                "TIME": get_from_options(propagated_options, "time"),
            }
            srun_env.update(new_env)
        else:
            for i, cmd in enumerate(commands):
                args = (
                    sched_cmd_opt
                    + propagated_options
                    + ["--"]
                    + cmd
                    + ([":"] if i + 1 < len(commands) else [])
                )
                srun_arguments.extend(args)
            srun_call = [sched_cmd] + srun_arguments

        return srun_call, srun_env
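
    # Illustrative sketch (editor's addition, not part of the original module):
    # with a hypothetical sched_cmd of "srun", an empty sched_cmd_opt, raw_arguments
    # ["--ntasks=1", "--time=00:10:00"], unique_id 3 and selected node "node02",
    # the assembled client call would look roughly like
    #
    #     srun --output=./stdout/job.3.<name>.out --error=./stdout/job.3.<name>.err \
    #          --nodelist=node02 --ntasks=1 --time=00:10:00 -- <client command>
    #
    # Without a sched_cmd, the client command is launched directly and the node and
    # resource selection is passed through the NODEID/NTASKS/TIME environment variables.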

    def _make_job_impl(
        self, proc: "Process[str]", unique_id: int, **kwargs: "dict[str, Any]"
    ) -> HybridJob:
        if unique_id == self._server_uid:
            job_type = JobType.Server
        else:
            job_type = JobType.Client
        job = HybridJob(
            unique_id,
            proc,
            job_type,
            kwargs["stdout_fname"],  # type: ignore
            kwargs["stderr_fname"],  # type: ignore
        )
        if job._type == JobType.Client:
            # Retrieve the node of the client
            node: str = ""
            # retrieved from process arguments (standard submission)
            for arg in proc.args:  # type: ignore
                if "nodelist" in arg:  # type: ignore
                    node = arg.split("=")[-1]  # type: ignore
                    break
            # retrieved from the process environment (no sched_cmd submission)
            if not node:
                node = proc.env["NODEID"]  # type: ignore
            assert node, f"Node list missing in {proc.args} and {proc.env}"  # type: ignore
            self._node_workload.append(node, job)
        return job

    @classmethod
    def _update_jobs_impl(cls, jobs: List[HybridJob]) -> None:
        logger.debug(f"Update - job list {[j._type for j in jobs]}")

        for j in jobs:
            if j._type == JobType.Client:
                returncode = j._process.poll()
                if returncode is None:
                    state = State.RUNNING
                elif returncode == 0:
                    state = State.TERMINATED
                else:
                    state = State.FAILED
                j._state = state
        # note that the server job status is not monitored
        # through subprocess since it is an indirect submission.
        # instead, the server job keeps the initial WAITING status and
        # the server will be marked as running by the state machine
        # as soon as it connects to the launcher.
        # Then, its running state is only monitored through regular pinging.

    @classmethod
    def _cancel_client_jobs_impl(cls, jobs: List[HybridJob]) -> None:
        logger.debug(
            f"Cancel - job list {[j._type for j in jobs if j._type == JobType.Client]}"
        )

        # when the user presses ctrl+c, the shell will send all processes in
        # the same process group SIGINT. some programs respond intelligently to
        # signals by freeing resources and exiting. these programs may also
        # exit _immediately_ if they receive a second signal within a short
        # time frame (e.g., srun or mpirun which won't terminate its child
        # processes in this case). for this reason, we wait before terminating
        # jobs here.
        max_wait_time = time.Time(seconds=5)

        # wait at most max_wait_time overall
        def compute_timeout(t_start: time.Time) -> time.Time:
            t_waited = time.monotonic() - t_start
            if t_waited < max_wait_time:
                return max_wait_time - t_waited
            return time.Time(seconds=0)

        # terminate processes
        t_0 = time.monotonic()
        for j in jobs:
            if j._type == JobType.Client:
                try:
                    timeout = compute_timeout(t_0)
                    j._process.wait(timeout.total_seconds())
                except subprocess.TimeoutExpired:
                    logger.debug(f"Slurm srun scheduler terminating process {j.id()}")
                    j._process.terminate()

        # kill processes if necessary
        t_1 = time.monotonic()
        for j in jobs:
            if j._type == JobType.Client:
                try:
                    timeout = compute_timeout(t_1)
                    j._process.wait(timeout.total_seconds())
                except subprocess.TimeoutExpired:
                    logger.debug(f"Slurm srun scheduler killing process {j.id()}")
                    j._process.kill()

                j._state = State.FAILED
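
    # Illustrative sketch (editor's addition, not part of the original module):
    # how the shared 5-second budget plays out, with hypothetical timings.
    #
    #     t_0 = time.monotonic()
    #     # first client already waited 2 s   -> compute_timeout(t_0) is about 3 s
    #     # a later client finds the budget spent -> compute_timeout(t_0) is 0 s,
    #     # so wait() raises TimeoutExpired right away and terminate()/kill()
    #     # is issued without further delay.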

    @classmethod
    def _cancel_server_job_impl(
        cls, jobs: List[HybridJob]
    ) -> Optional[Tuple[ArgumentList, Environment]]:
        server_job: List[HybridJob] = [j for j in jobs if j._type == JobType.Server]
        logger.debug(f"Cancel - job list {[j._type for j in server_job]}")

        # if file exists then read the sbatch job id
        if server_job[0].stdout_fname is not None:
            with open(server_job[0].stdout_fname) as f:
                match = cls._sbatch_job_id_pattern.fullmatch(f.read().strip())
        # else read from the process' stored stdout
        else:
            assert server_job[0]._process.stdout is not None, (
                "Launcher was not able to find stdout containing sbatch id "
                "required for the server job cancellation."
            )
            match = cls._sbatch_job_id_pattern.fullmatch(
                server_job[0]._process.stdout.read().strip()
            )
        if match is None:
            e = "no job ID found in server sbatch output."
            raise ValueError(e)

        scancel_command = [
            "scancel",
            "--batch",
            "--quiet",
            match.group(1),  # type: ignore
        ]

        return (scancel_command, os.environ) if server_job else None
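
    # Illustrative sketch (editor's addition, not part of the original module):
    # because the server was submitted with `sbatch --parsable`, its stdout is
    # expected to be just the numeric job id, e.g. a hypothetical
    #
    #     "123456\n"  ->  _sbatch_job_id_pattern.fullmatch("123456").group(1) == "123456"
    #
    # which yields the cancel command ["scancel", "--batch", "--quiet", "123456"].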

    def _parse_cancel_jobs_impl(
        self, jobs: List[HybridJob], proc: CompletedProcess
    ) -> None:
        # scancel exits with status 1 if at least one job had already been
        # terminated
        if proc.exit_status not in [0, 1]:
            raise RuntimeError(
                "scancel error: exit status {:d}".format(proc.exit_status)
            )