Coverage for melissa/scheduler/slurm_semiglobal.py: 25%
268 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-09-22 10:36 +0000
1#!/usr/bin/python3
3# Copyright (c) 2020-2022, Institut National de Recherche en Informatique et en Automatique (Inria)
4# All rights reserved.
5#
6# Redistribution and use in source and binary forms, with or without
7# modification, are permitted provided that the following conditions are met:
8#
9# * Redistributions of source code must retain the above copyright notice, this
10# list of conditions and the following disclaimer.
11#
12# * Redistributions in binary form must reproduce the above copyright notice,
13# this list of conditions and the following disclaimer in the documentation
14# and/or other materials provided with the distribution.
15#
16# * Neither the name of the copyright holder nor the names of its
17# contributors may be used to endorse or promote products derived from
18# this software without specific prior written permission.
19#
20# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31import enum
32import logging
33import os
34import re
35import shutil
36import subprocess
37from typing import List, Tuple, Union, Dict, Optional
38from pathlib import Path
40from melissa.utility import time
41from melissa.utility.process import ArgumentList, CompletedProcess, Environment, Process
43from .job import Id, Job, State
44from .scheduler import HybridScheduler, Options
45from .slurm_parser import break2str
47logger = logging.getLogger(__name__)
class JobType(enum.Enum):
    """Kind of job managed by the semi-global Slurm scheduler.

    The server is submitted indirectly (sbatch) while clients are launched
    directly (srun), so the two kinds are monitored differently.
    """

    Server = 0
    Client = 1
class HybridJob(Job):
    """A job handle for the semi-global Slurm scheduler.

    Wraps a ``subprocess.Popen`` child process together with a launcher-side
    unique ID and the job kind (server vs. client).
    """

    def __init__(self, uid: Id, process: "subprocess.Popen[str]", job_type: JobType) -> None:
        super().__init__()
        self._type = job_type
        # jobs start out WAITING; state is updated by the scheduler
        self._state = State.WAITING
        self._uid = uid
        self._process = process

    def id(self) -> Id:
        # for direct submissions the child process ID doubles as the job ID
        return self._process.pid

    def unique_id(self) -> Union[str, int]:
        return self._uid

    def state(self) -> State:
        return self._state

    def __repr__(self) -> str:
        # BUG FIX: `id` and `state` are methods, not properties; the original
        # interpolated the bound method objects themselves, which raised a
        # TypeError under the ':d'/':s' format specifiers. Call them instead.
        return f"<{self.__class__.__name__} (id={self.id():d},state={self.state()!s})>"
class NodeWorkload:
    """Keep track of submitted jobs on different nodes."""

    def __init__(self):
        self._node_list: List[str] = []
        self._allocation_size: int = 0
        self._ntasks_per_node: int = 0
        self._node_workload: Dict[str, List[HybridJob]] = {}

    @property
    def node_list(self) -> List[str]:
        """Nodes of the allocation this workload tracker knows about."""
        return self._node_list

    @node_list.setter
    def node_list(self, nodes: List[str]):
        # assigning a new node list resets the per-node bookkeeping
        self._node_list = nodes
        self._node_workload = {name: [] for name in nodes}

    @property
    def allocation_size(self) -> int:
        """Total number of tasks in the allocation."""
        return self._allocation_size

    @allocation_size.setter
    def allocation_size(self, size: int):
        self._allocation_size = size

    @property
    def ntasks_per_node(self) -> int:
        """Number of tasks available on each node."""
        return self._ntasks_per_node

    @ntasks_per_node.setter
    def ntasks_per_node(self, ntasks: int):
        self._ntasks_per_node = ntasks

    def append(self, node: str, job: HybridJob):
        """Record that *job* was submitted on *node*."""
        self._node_workload[node].append(job)

    def select_node(self, iteration) -> str:
        """Return the node currently running the fewest jobs.

        Ties are broken by node insertion order. The *iteration* argument is
        accepted for interface compatibility but not used by this policy.
        """
        least_loaded, _ = min(
            self._node_workload.items(), key=lambda entry: len(entry[1])
        )
        return least_loaded

    def __len__(self) -> int:
        # total number of jobs tracked across all nodes
        return sum(len(jobs) for jobs in self._node_workload.values())

    def update_jobs(self):
        """Refresh job states and drop jobs that are no longer active."""
        for node, jobs in self._node_workload.items():
            SlurmSemiGlobalScheduler._update_jobs_impl(jobs)
            self._node_workload[node] = [
                job for job in jobs
                if job._state in [State.WAITING, State.RUNNING]
            ]
class SlurmSemiGlobalScheduler(HybridScheduler[HybridJob]):
    """Hybrid Slurm scheduler.

    The server is submitted indirectly through a generated sbatch script
    while the clients are launched directly with srun inside the launcher's
    own allocation.  Client placement is load-balanced across the allocated
    nodes via a ``NodeWorkload`` instance.
    """

    # always compile regular expressions to enforce ASCII matching
    # allow matching things like version 1.2.3-rc4
    _srun_version_regexp = r"slurm (\d+)[.](\d+)[.](\d+\S*)"
    _srun_version_pattern = re.compile(_srun_version_regexp, re.ASCII)
    _sbatch_job_id_regexp = r"(\d+)"
    _sbatch_job_id_pattern = re.compile(_sbatch_job_id_regexp, re.ASCII)
    _sacct_line_regexp = r"(\d+)([+]0[.]batch)?[|](\w+)"
    _sacct_line_pattern = re.compile(_sacct_line_regexp, re.ASCII)
    _use_het_prefix: bool = False

    @classmethod
    def is_direct(cls) -> bool:
        # client jobs are direct child processes of the launcher
        return True

    @classmethod
    def _name_impl(cls) -> str:
        return "slurm"

    @classmethod
    def _is_available_impl(cls) -> Tuple[bool, Union[str, Tuple[str, str]]]:
        """Check for a usable srun executable and a supported Slurm version.

        Returns:
            ``(False, reason)`` when Slurm is unavailable or too old,
            otherwise ``(True, (srun_path, version_str))``.

        Raises:
            ValueError: if the srun version output cannot be parsed.
        """
        srun_path = shutil.which("srun")
        if srun_path is None:
            return False, "srun executable not found"

        srun = subprocess.run(
            [srun_path, "--version"],
            stdin=subprocess.DEVNULL,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            universal_newlines=True,
        )
        if srun.returncode != 0:
            return False, "failed to execute %s: %s" % (srun_path, srun.stderr)

        # do not use pattern.fullmatch() because of the newline at the end of
        # the output
        match = cls._srun_version_pattern.match(srun.stdout)
        if match is None:
            e = "srun output '{:s}' does not match expected format"
            raise ValueError(e.format(srun.stdout))

        version_major = int(match.group(1))
        # the minor version might be of the form `05` but the Python int()
        # function handles this correctly
        version_minor = int(match.group(2))
        version_patch = match.group(3)

        if version_major < 19 or (version_major == 19 and version_minor < 5):
            # BUG FIX: the original called the deprecated logger.warn() and
            # passed RuntimeWarning as an extra positional argument (a
            # signature mix-up with warnings.warn); because the message has
            # no %-placeholders, that extra argument triggered a logging
            # formatting error at emit time.
            logger.warning(
                "Melissa has not been tested with Slurm versions older than 19.05.5"
            )

        if version_major < 17 or (version_major == 17 and version_minor < 11):
            # BUG FIX: the adjacent string fragments lacked separating
            # spaces and rendered as "17.11,got X.Y.Zwhich ...".
            return (
                False,
                (
                    "expected at least Slurm 17.11, "
                    f"got {version_major}.{version_minor}.{version_patch} "
                    "which does not support heterogeneous jobs"
                ),
            )

        # Slurm 20+ renamed the heterogeneous-job options to use a `het` prefix
        cls._use_het_prefix = version_major >= 20

        version_str = srun.stdout[match.span(1)[0] : match.span(3)[1]]
        return True, (srun_path, version_str)

    def __init__(self) -> None:
        """Initialize the scheduler; raises RuntimeError if Slurm is missing."""
        # Attributes for the srun allocation on multiple nodes
        self.srun_ctr: int = 0
        self._node_workload = NodeWorkload()
        # Standard initialization function
        is_available, info = self.is_available()
        if not is_available:
            raise RuntimeError("Slurm unavailable: %s" % (info,))

        assert self._use_het_prefix is not None

    def _sanity_check_impl(self, options: Options) -> List[str]:
        """Validate user-supplied scheduler arguments via a dry-run srun call.

        Returns a list of human-readable error strings (empty when valid).
        """
        args = options.raw_arguments
        errors = []

        for a in args:
            if a[0] != "-":
                errors.append("non-option argument '{:s}' detected".format(a))
            elif a in ["-n", "--ntasks", "--test-only"]:
                errors.append("remove '{:s}' argument".format(a))

        # --test-only makes srun validate the request without launching it
        command = ["srun", "--test-only", "--ntasks=1"] + args + ["--", "true"]
        srun = subprocess.run(
            command,
            stdin=subprocess.DEVNULL,
            stdout=subprocess.DEVNULL,
            stderr=subprocess.PIPE,
            universal_newlines=True,
        )
        if srun.returncode != 0:
            e = "srun error on trial execution: {:s}".format(srun.stderr)
            errors.append(e)

        return errors

    def _submit_job_impl(
        self,
        commands: List[ArgumentList],
        env: Environment,
        options: Options,
        name: str,
        unique_id: int,
    ) -> Tuple[ArgumentList, Environment]:
        """Dispatch submission: the server goes through sbatch, clients through srun."""
        if name == "melissa-server":
            return self._submit_job_impl_server(commands, env, options, name, unique_id)
        else:
            return self._submit_job_impl_client(commands, env, options, name, unique_id)

    def _submit_job_impl_server(
        self,
        commands: List[ArgumentList],
        env: Environment,
        options: Options,
        name: str,
        unique_id: int,
    ) -> Tuple[ArgumentList, Environment]:
        """Build the sbatch script and command line for the server job.

        Writes ``./sbatch/sbatch.<uid>.sh`` and returns the ``sbatch``
        invocation together with the environment to run it in.
        """
        # sbatch submission
        sbatch_env = os.environ.copy()
        sbatch_env.update(env)

        output_filename = "./stdout/job.{:d}.{:s}.out".format(unique_id, name)
        error_filename = "./stdout/job.{:d}.{:s}.err".format(unique_id, name)

        propagated_options = [
            "--output={:s}".format(output_filename),
            "--error={:s}".format(error_filename),
        ]

        uid = unique_id
        # remember the server's unique ID so _make_job_impl can tell it apart
        self._server_uid = uid

        sbatch_options = propagated_options + options.raw_arguments

        # serialize sbatch options
        # (parameter renamed from `options` to avoid shadowing the enclosing
        # `options: Options` argument)
        def options2str(opt: str) -> str:
            return "#SBATCH " + opt

        sbatch_options_str = [options2str(o) for o in sbatch_options]

        sched_cmd = options.sched_cmd
        sched_cmd_opt = options.sched_cmd_opt

        # assemble srun arguments
        srun_arguments: List[List[str]] = []
        if not sched_cmd:
            srun_arguments = [[""] + commands[0]]
        else:
            srun_arguments = [sched_cmd_opt + ["--"] + commands[0]]

        # serialize srun arguments
        def srunargs2str(hetgroup: int, args: List[str]) -> str:
            assert hetgroup >= 0
            assert hetgroup < len(commands)

            # heterogeneous job groups are separated by ": "; all but the
            # last line are continued with a trailing backslash
            prefix = ": " if hetgroup > 0 else ""
            suffix = " \\" if hetgroup + 1 < len(commands) else ""
            return " " + prefix + " ".join(args) + suffix

        srun_arguments_str = [srunargs2str(i, args) for i, args in enumerate(srun_arguments)]

        # write srun calls to file
        Path('./sbatch').mkdir(parents=True, exist_ok=True)
        sbatch_script_filename = "./sbatch/sbatch.{:d}.sh".format(uid)
        sbatch_script = (
            ["#!/bin/sh"]
            + ["# sbatch script for job {:s}".format(name)]
            + sbatch_options_str
            + [""]
            + (["exec \\"] if not sched_cmd else [f"exec {sched_cmd} \\"])
            + srun_arguments_str
        )

        # POSIX requires files to end with a newline; missing newlines at the
        # end of a file may break scripts that append text.
        # this string won't contain a newline at the end; it must be added
        # manually or by using a function that adds it, e.g., `print`
        sbatch_script_str_noeol = "\n".join(sbatch_script)

        with open(sbatch_script_filename, "w") as f:
            print(sbatch_script_str_noeol, file=f)

        sbatch_call = (
            ["sbatch"]
            + ["--parsable"]
            + ["--job-name={:s}".format(name)]
            + [sbatch_script_filename]
        )

        return sbatch_call, sbatch_env

    def _submit_job_impl_client(
        self,
        commands: List[ArgumentList],
        env: Environment,
        options: Options,
        name: str,
        unique_id: int,
    ) -> Tuple[ArgumentList, Environment]:
        """Build the srun command line for a client job.

        On the first call the allocation layout (node list, task counts) is
        read from the SLURM_* environment variables; every call then picks
        the least-loaded node for the new client.
        """
        # srun submission for clients
        srun_env = os.environ.copy()
        srun_env.update(env)
        self.srun_ctr += 1

        output_filename = "./stdout/job.{:d}.{:s}.out".format(unique_id, name)
        error_filename = "./stdout/job.{:d}.{:s}.err".format(unique_id, name)

        # command line only scanned the first time
        if self.srun_ctr == 1:
            # extract all environment variables
            self._node_workload.allocation_size = int(srun_env["SLURM_NTASKS"])
            self._node_workload.ntasks_per_node = int(srun_env["SLURM_NTASKS_PER_NODE"])
            self._node_workload.node_list = break2str(srun_env["SLURM_NODELIST"])
            assert (
                len(self._node_workload.node_list)
                == self._node_workload.allocation_size // self._node_workload.ntasks_per_node
            )
            logger.debug(
                f"allocation of {self._node_workload.allocation_size} cores "
                f"with identified nodelist {self._node_workload.node_list} "
                f"from environment variable {srun_env['SLURM_NODELIST']}"
            )

        # make sure resources are available before next series of srun
        logger.debug(f"Scheduler has currently {len(self._node_workload)} jobs running.")
        self._node_workload.update_jobs()
        node = self._node_workload.select_node(self.srun_ctr - 1)

        # build propagated options of the srun command line
        propagated_options = [
            "--output={:s}".format(output_filename),
            "--error={:s}".format(error_filename),
            "--nodelist={:s}".format(node),
        ] + options.raw_arguments

        # assemble srun arguments
        sched_cmd = options.sched_cmd
        sched_cmd_opt = options.sched_cmd_opt

        # retrieve variable value amongst scheduler_arg_client
        def get_from_options(list_of_options: List[str], variable_name: str) -> str:
            for opt in list_of_options:
                if variable_name in opt:
                    idx = opt.find("=")
                    return opt[idx + 1:]
            logger.error(
                f"The variable name {variable_name} is not amongst {list_of_options} "
                "Please make sure scheduler_arg_client is well set"
            )
            return ""

        # write srun/job execution call
        srun_call: List[str] = []
        srun_arguments: List[str] = []
        if not sched_cmd:
            assert len(commands) == 1, "non-unit groups are not supported in this configuration"
            srun_call = commands[0]
            # node and resource specification with environment variables
            new_env = {
                "NODEID": node,
                "NTASKS": get_from_options(propagated_options, "ntasks"),
                "TIME": get_from_options(propagated_options, "time")
            }
            srun_env.update(new_env)
        else:
            for i, cmd in enumerate(commands):
                args = (
                    sched_cmd_opt + propagated_options + ["--"] + cmd
                    + ([":"] if i + 1 < len(commands) else [])
                )
                srun_arguments.extend(args)
            srun_call = [sched_cmd] + srun_arguments

        return srun_call, srun_env

    def _make_job_impl(self, proc: "Process[str]", unique_id: int) -> HybridJob:
        """Wrap a spawned process into a HybridJob and register client placement."""
        if unique_id == self._server_uid:
            job_type = JobType.Server
        else:
            job_type = JobType.Client
        job = HybridJob(unique_id, proc, job_type)
        if job._type == JobType.Client:
            # Retrieve the node of the client
            node: str = ""
            # retrieved from process arguments (standard submission)
            for arg in proc.args:  # type: ignore
                if "nodelist" in arg:  # type: ignore
                    node = arg.split("=")[-1]  # type: ignore
                    break
            # retrieved from the process environment (no sched_cmd submission)
            if not node:
                node = proc.env["NODEID"]  # type: ignore
            assert node, f"Node list missing in {proc.args} and {proc.env}"  # type: ignore
            self._node_workload.append(node, job)
        return job

    @classmethod
    def _update_jobs_impl(cls, jobs: List[HybridJob]) -> None:
        """Poll client processes and update their job states in place."""
        logger.debug(f"Update - job list {[j._type for j in jobs]}")

        for j in jobs:
            if j._type == JobType.Client:
                returncode = j._process.poll()
                if returncode is None:
                    state = State.RUNNING
                elif returncode == 0:
                    state = State.TERMINATED
                else:
                    state = State.FAILED
                j._state = state
        # note that the server job status is not monitored
        # trough subprocess since it is an indirect submission.
        # instead, the server job keeps the initial WAITING status and
        # the server will be marked as running by the state machine
        # as soon as it connects to the launcher.
        # Then, his running state is only monitored through regular pinging.

    @classmethod
    def _cancel_client_jobs_impl(cls, jobs: List[HybridJob]) -> None:
        """Gracefully stop client processes: wait, then terminate, then kill."""
        logger.debug(f"Cancel - job list {[j._type for j in jobs if j._type == JobType.Client]}")

        # when the user presses ctrl+c, the shell will send all processes in
        # the same process group SIGINT. some programs respond intelligently to
        # signals by freeing resources and exiting. these programs may also
        # exit _immediately_ if they receive a second signal within a short
        # time frame (e.g., srun or mpirun which won't terminate its child
        # processes in this case). for this reason, we wait before terminating
        # jobs here.
        max_wait_time = time.Time(seconds=5)

        # wait at most max_wait_time overall
        def compute_timeout(t_start: time.Time) -> time.Time:
            t_waited = time.monotonic() - t_start
            if t_waited < max_wait_time:
                return max_wait_time - t_waited
            return time.Time(seconds=0)

        # terminate processes
        t_0 = time.monotonic()
        for j in jobs:
            if j._type == JobType.Client:
                try:
                    timeout = compute_timeout(t_0)
                    j._process.wait(timeout.total_seconds())
                except subprocess.TimeoutExpired:
                    logger.debug(f"Slurm srun scheduler terminating process {j.id()}")
                    j._process.terminate()

        # kill processes if necessary
        t_1 = time.monotonic()
        for j in jobs:
            if j._type == JobType.Client:
                try:
                    timeout = compute_timeout(t_1)
                    j._process.wait(timeout.total_seconds())
                except subprocess.TimeoutExpired:
                    logger.debug(f"Slurm srun scheduler killing process {j.id()}")
                    j._process.kill()

                j._state = State.FAILED

    @classmethod
    def _cancel_server_job_impl(
        cls, jobs: List[HybridJob]
    ) -> Optional[Tuple[ArgumentList, Environment]]:
        """Return the scancel command for the server job, or None if absent."""
        server_job: List[HybridJob] = [j for j in jobs if j._type == JobType.Server]
        logger.debug(f"Cancel - job list {[j._type for j in server_job]}")

        # BUG FIX: the original indexed server_job[0] unconditionally before
        # checking emptiness, raising IndexError when no server job exists.
        if not server_job:
            return None

        job_list = [str(server_job[0].id())]
        scancel_command = ["scancel", "--batch", "--quiet"] + job_list

        return scancel_command, os.environ

    def _parse_cancel_jobs_impl(self, jobs: List[HybridJob], proc: CompletedProcess) -> None:
        """Validate the scancel exit status.

        Raises:
            RuntimeError: on any exit status other than 0 or 1.
        """
        # scancel exits with status 1 if at least one job had already been
        # terminated
        if proc.exit_status not in [0, 1]:
            raise RuntimeError("scancel error: exit status {:d}".format(proc.exit_status))