Coverage for melissa/scheduler/slurm_semiglobal.py: 24%

278 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2025-11-03 09:52 +0100

1#!/usr/bin/python3 

2 

3# Copyright (c) 2020-2022, Institut National de Recherche en Informatique et en Automatique (Inria) 

4# All rights reserved. 

5# 

6# Redistribution and use in source and binary forms, with or without 

7# modification, are permitted provided that the following conditions are met: 

8# 

9# * Redistributions of source code must retain the above copyright notice, this 

10# list of conditions and the following disclaimer. 

11# 

12# * Redistributions in binary form must reproduce the above copyright notice, 

13# this list of conditions and the following disclaimer in the documentation 

14# and/or other materials provided with the distribution. 

15# 

16# * Neither the name of the copyright holder nor the names of its 

17# contributors may be used to endorse or promote products derived from 

18# this software without specific prior written permission. 

19# 

20# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 

21# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 

22# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 

23# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 

24# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 

25# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 

26# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 

27# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 

28# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 

29# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 

30 

31import enum 

32import logging 

33import os 

34import re 

35import shutil 

36import subprocess 

37from typing import List, Tuple, Union, Dict, Optional, Any 

38from pathlib import Path 

39 

40from melissa.utility import time 

41from melissa.utility.process import ArgumentList, CompletedProcess, Environment, Process 

42 

43from .job import Id, Job, State 

44from .scheduler import HybridScheduler, Options 

45from .slurm_parser import break2str 

46 

47logger = logging.getLogger(__name__) 

48 

49 

class JobType(enum.Enum):
    """Kind of job managed by the semi-global Slurm scheduler."""

    # the Melissa server: submitted indirectly via an sbatch script
    Server = 0
    # a simulation client: launched directly with srun inside the allocation
    Client = 1

54 

55 

class HybridJob(Job):
    """Job handle wrapping a locally spawned subprocess.

    The job starts in the WAITING state; the scheduler updates the state
    by polling the underlying process.
    """

    def __init__(
        self,
        uid: Id,
        process: "subprocess.Popen[str]",
        job_type: JobType,
        stdout_fname: str | None,
        stderr_fname: str | None,
    ) -> None:
        super().__init__()
        # every freshly created job is considered waiting until polled
        self._state = State.WAITING
        self._uid = uid
        self._type = job_type
        self._process = process
        self.stdout_fname = stdout_fname
        self.stderr_fname = stderr_fname

    def id(self) -> Id:
        """Return the PID of the wrapped process."""
        return self._process.pid

    def unique_id(self) -> Union[str, int]:
        """Return the scheduler-assigned unique identifier."""
        return self._uid

    def state(self) -> State:
        """Return the most recently observed job state."""
        return self._state

    def __repr__(self) -> str:
        return f"<{type(self).__name__} (id={self.id():d},state={self.state()})>"

85 

86 

class NodeWorkload:
    """Keep track of submitted jobs on different nodes."""

    def __init__(self):
        # nodes of the allocation and the jobs currently assigned to each
        self._node_list: List[str] = []
        self._allocation_size: int = 0
        self._ntasks_per_node: int = 0
        self._node_workload: Dict[str, List[HybridJob]] = {}

    @property
    def node_list(self) -> List[str]:
        return self._node_list

    @node_list.setter
    def node_list(self, nodes: List[str]):
        # assigning a new node list resets the per-node bookkeeping
        self._node_list = nodes
        self._node_workload = dict((node, []) for node in nodes)

    @property
    def allocation_size(self) -> int:
        return self._allocation_size

    @allocation_size.setter
    def allocation_size(self, size: int):
        self._allocation_size = size

    @property
    def ntasks_per_node(self) -> int:
        return self._ntasks_per_node

    @ntasks_per_node.setter
    def ntasks_per_node(self, ntasks: int):
        self._ntasks_per_node = ntasks

    def append(self, node: str, job: HybridJob):
        """Record that `job` has been submitted on `node`."""
        self._node_workload[node].append(job)

    def select_node(self, iteration) -> str:
        """Return the node with the fewest tracked jobs.

        NOTE(review): `iteration` is unused; kept for API compatibility.
        """
        workload = self._node_workload
        return min(workload, key=lambda node: len(workload[node]))

    def __len__(self) -> int:
        """Return the total number of tracked jobs across all nodes."""
        return sum(len(jobs) for jobs in self._node_workload.values())

    def update_jobs(self):
        """Poll all tracked jobs and drop those no longer waiting/running."""
        for node, jobs in self._node_workload.items():
            SlurmSemiGlobalScheduler._update_jobs_impl(jobs)
            still_active = [
                job for job in jobs if job._state in (State.WAITING, State.RUNNING)
            ]
            self._node_workload[node] = still_active

142 

class SlurmSemiGlobalScheduler(HybridScheduler[HybridJob]):
    """Hybrid Slurm scheduler.

    The server is submitted indirectly through an sbatch script while
    clients are launched directly with srun inside an existing allocation,
    load-balanced over the allocated nodes.
    """

    # always compile regular expressions to enforce ASCII matching
    # allow matching things like version 1.2.3-rc4
    _srun_version_regexp = r"slurm (\d+)[.](\d+)[.](\d+\S*)"
    _srun_version_pattern = re.compile(_srun_version_regexp, re.ASCII)
    _sbatch_job_id_regexp = r"(\d+)"
    _sbatch_job_id_pattern = re.compile(_sbatch_job_id_regexp, re.ASCII)
    _sacct_line_regexp = r"(\d+)([+]0[.]batch)?[|](\w+)"
    _sacct_line_pattern = re.compile(_sacct_line_regexp, re.ASCII)
    _use_het_prefix: bool = False

    @classmethod
    def is_direct(cls) -> bool:
        return True

    @classmethod
    def _name_impl(cls) -> str:
        return "slurm"

    @classmethod
    def _is_available_impl(cls) -> Tuple[bool, Union[str, Tuple[str, str]]]:
        """Check for a usable srun executable and a supported Slurm version.

        Returns (True, (srun_path, version_str)) on success, otherwise
        (False, reason). Raises ValueError when the srun version output
        cannot be parsed.
        """
        srun_path = shutil.which("srun")
        if srun_path is None:
            return False, "srun executable not found"

        srun = subprocess.run(
            [srun_path, "--version"],
            stdin=subprocess.DEVNULL,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            universal_newlines=True,
        )
        if srun.returncode != 0:
            return False, "failed to execute %s: %s" % (srun_path, srun.stderr)

        # do not use pattern.fullmatch() because of the newline at the end of
        # the output
        match = cls._srun_version_pattern.match(srun.stdout)
        if match is None:
            e = "srun output '{:s}' does not match expected format"
            raise ValueError(e.format(srun.stdout))

        version_major = int(match.group(1))
        # the minor version might be of the form `05` but the Python int()
        # function handles this correctly
        version_minor = int(match.group(2))
        version_patch = match.group(3)

        if version_major < 19 or (version_major == 19 and version_minor < 5):
            # BUGFIX: logger.warn() is deprecated and the stray
            # RuntimeWarning argument (a warnings.warn() idiom) was passed
            # as a %-format argument to the logging call, triggering a
            # logging format error at runtime.
            logger.warning(
                "Melissa has not been tested with Slurm versions older than 19.05.5"
            )

        if version_major < 17 or (version_major == 17 and version_minor < 11):
            # BUGFIX: the adjacent string literals lacked separating spaces,
            # yielding "17.11,got X.Y.Zwhich does not ..."
            return (
                False,
                (
                    "expected at least Slurm 17.11, "
                    f"got {version_major}.{version_minor}.{version_patch} "
                    "which does not support heterogeneous jobs"
                ),
            )

        # heterogeneous-job prefix syntax is used from Slurm 20 onwards
        cls._use_het_prefix = version_major >= 20

        version_str = srun.stdout[match.span(1)[0] : match.span(3)[1]]
        return True, (srun_path, version_str)

    def __init__(self) -> None:
        # Attributes for the srun allocation on multiple nodes
        self.srun_ctr: int = 0
        self._node_workload = NodeWorkload()
        # Standard initialization function
        is_available, info = self.is_available()
        if not is_available:
            raise RuntimeError("Slurm unavailable: %s" % (info,))

        assert self._use_het_prefix is not None

    def _sanity_check_impl(self, options: Options) -> List[str]:
        """Validate user-provided scheduler arguments with a trial srun.

        Returns a list of human-readable error strings (empty when OK).
        """
        args = options.raw_arguments
        errors = []

        for a in args:
            if a[0] != "-":
                errors.append("non-option argument '{:s}' detected".format(a))
            elif a in ["-n", "--ntasks", "--test-only"]:
                # these options are managed by the scheduler itself
                errors.append("remove '{:s}' argument".format(a))

        # --test-only makes srun validate the request without running it
        command = ["srun", "--test-only", "--ntasks=1"] + args + ["--", "true"]
        srun = subprocess.run(
            command,
            stdin=subprocess.DEVNULL,
            stdout=subprocess.DEVNULL,
            stderr=subprocess.PIPE,
            universal_newlines=True,
        )
        if srun.returncode != 0:
            e = "srun error on trial execution: {:s}".format(srun.stderr)
            errors.append(e)

        return errors

    def _submit_job_impl(
        self,
        commands: List[ArgumentList],
        env: Environment,
        options: Options,
        name: str,
        unique_id: int,
    ) -> Tuple[ArgumentList, Environment]:
        """Dispatch submission: the server goes through sbatch, clients
        through srun."""
        if name == "melissa-server":
            return self._submit_job_impl_server(commands, env, options, name, unique_id)
        else:
            return self._submit_job_impl_client(commands, env, options, name, unique_id)

    def _submit_job_impl_server(
        self,
        commands: List[ArgumentList],
        env: Environment,
        options: Options,
        name: str,
        unique_id: int,
    ) -> Tuple[ArgumentList, Environment]:
        """Write an sbatch script for the server job and return the
        sbatch command line and environment to launch it with."""
        # sbatch submission
        sbatch_env = os.environ.copy()
        sbatch_env.update(env)

        output_filename = "./stdout/job.{:d}.{:s}.out".format(unique_id, name)
        error_filename = "./stdout/job.{:d}.{:s}.err".format(unique_id, name)

        propagated_options = [
            "--output={:s}".format(output_filename),
            "--error={:s}".format(error_filename),
        ]

        uid = unique_id
        # remembered so that _make_job_impl can identify the server job
        self._server_uid = uid

        sbatch_options = propagated_options + options.raw_arguments

        # serialize sbatch options
        # (parameter renamed from `options` to avoid shadowing the method
        # argument of the same name)
        def options2str(option: str) -> str:
            return "#SBATCH " + option

        sbatch_options_str = [options2str(o) for o in sbatch_options]

        sched_cmd = options.sched_cmd
        sched_cmd_opt = options.sched_cmd_opt

        # assemble srun arguments
        srun_arguments: List[List[str]] = []
        if not sched_cmd:
            srun_arguments = [[""] + commands[0]]
        else:
            srun_arguments = [sched_cmd_opt + ["--"] + commands[0]]

        # serialize srun arguments; heterogeneous groups are separated by
        # ": " and continued with trailing backslashes
        def srunargs2str(hetgroup: int, args: List[str]) -> str:
            assert hetgroup >= 0
            assert hetgroup < len(commands)

            prefix = ": " if hetgroup > 0 else ""
            suffix = " \\" if hetgroup + 1 < len(commands) else ""
            return " " + prefix + " ".join(args) + suffix

        srun_arguments_str = [
            srunargs2str(i, args) for i, args in enumerate(srun_arguments)
        ]

        # write srun calls to file
        Path("./sbatch").mkdir(parents=True, exist_ok=True)
        sbatch_script_filename = "./sbatch/sbatch.{:d}.sh".format(uid)
        sbatch_script = (
            ["#!/bin/sh"]
            + ["# sbatch script for job {:s}".format(name)]
            + sbatch_options_str
            + [""]
            + (["exec \\"] if not sched_cmd else [f"exec {sched_cmd} \\"])
            + srun_arguments_str
        )

        # POSIX requires files to end with a newline; missing newlines at the
        # end of a file may break scripts that append text.
        # this string won't contain a newline at the end; it must be added
        # manually or by using a function that adds it, e.g., `print`
        sbatch_script_str_noeol = "\n".join(sbatch_script)

        with open(sbatch_script_filename, "w") as f:
            print(sbatch_script_str_noeol, file=f)

        sbatch_call = (
            ["sbatch"]
            + ["--parsable"]
            + ["--job-name={:s}".format(name)]
            + [sbatch_script_filename]
        )

        return sbatch_call, sbatch_env

    def _submit_job_impl_client(
        self,
        commands: List[ArgumentList],
        env: Environment,
        options: Options,
        name: str,
        unique_id: int,
    ) -> Tuple[ArgumentList, Environment]:
        """Build the srun command line (or plain command) for a client job,
        pinned to the least busy node of the allocation."""
        # srun submission for clients
        srun_env = os.environ.copy()
        srun_env.update(env)
        self.srun_ctr += 1

        output_filename = "./stdout/job.{:d}.{:s}.out".format(unique_id, name)
        error_filename = "./stdout/job.{:d}.{:s}.err".format(unique_id, name)

        # allocation layout only scanned on the first client submission
        if self.srun_ctr == 1:
            # extract all environment variables
            self._node_workload.allocation_size = int(srun_env["SLURM_NTASKS"])
            self._node_workload.ntasks_per_node = int(srun_env["SLURM_NTASKS_PER_NODE"])
            self._node_workload.node_list = break2str(srun_env["SLURM_NODELIST"])
            assert (
                len(self._node_workload.node_list)
                == self._node_workload.allocation_size
                // self._node_workload.ntasks_per_node
            )
            logger.debug(
                f"allocation of {self._node_workload.allocation_size} cores "
                f"with identified nodelist {self._node_workload.node_list} "
                f"from environment variable {srun_env['SLURM_NODELIST']}"
            )

        # make sure resources are available before next series of srun
        logger.debug(
            f"Scheduler has currently {len(self._node_workload)} jobs running."
        )
        self._node_workload.update_jobs()
        node = self._node_workload.select_node(self.srun_ctr - 1)

        # build propagated options of the srun command line
        propagated_options = [
            "--output={:s}".format(output_filename),
            "--error={:s}".format(error_filename),
            "--nodelist={:s}".format(node),
        ] + options.raw_arguments

        # assemble srun arguments
        sched_cmd = options.sched_cmd
        sched_cmd_opt = options.sched_cmd_opt

        # retrieve variable value amongst scheduler_arg_client
        # (matches by substring and returns the text after '=')
        def get_from_options(list_of_options: List[str], variable_name: str) -> str:
            for opt in list_of_options:
                if variable_name in opt:
                    idx = opt.find("=")
                    return opt[idx + 1 :]
            logger.error(
                f"The variable name {variable_name} is not amongst {list_of_options} "
                "Please make sure scheduler_arg_client is well set"
            )
            return ""

        # write srun/job execution call
        srun_call: List[str] = []
        srun_arguments: List[str] = []
        if not sched_cmd:
            assert (
                len(commands) == 1
            ), "non-unit groups are not supported in this configuration"
            srun_call = commands[0]
            # node and resource specification with environment variables
            new_env = {
                "NODEID": node,
                "NTASKS": get_from_options(propagated_options, "ntasks"),
                "TIME": get_from_options(propagated_options, "time"),
            }
            srun_env.update(new_env)
        else:
            for i, cmd in enumerate(commands):
                args = (
                    sched_cmd_opt
                    + propagated_options
                    + ["--"]
                    + cmd
                    + ([":"] if i + 1 < len(commands) else [])
                )
                srun_arguments.extend(args)
            srun_call = [sched_cmd] + srun_arguments

        return srun_call, srun_env

    def _make_job_impl(
        self, proc: "Process[str]", unique_id: int, **kwargs: "dict[str, Any]"
    ) -> HybridJob:
        """Wrap a spawned process in a HybridJob and register client jobs
        with the per-node workload tracker."""
        if unique_id == self._server_uid:
            job_type = JobType.Server
        else:
            job_type = JobType.Client
        job = HybridJob(
            unique_id,
            proc,
            job_type,
            kwargs["stdout_fname"],  # type: ignore
            kwargs["stderr_fname"],  # type: ignore
        )
        if job._type == JobType.Client:
            # Retrieve the node of the client
            node: str = ""
            # retrieved from process arguments (standard submission)
            for arg in proc.args:  # type: ignore
                if "nodelist" in arg:  # type: ignore
                    node = arg.split("=")[-1]  # type: ignore
                    break
            # retrieved from the process environment (no sched_cmd submission)
            if not node:
                node = proc.env["NODEID"]  # type: ignore
            assert node, f"Node list missing in {proc.args} and {proc.env}"  # type: ignore
            self._node_workload.append(node, job)
        return job

    @classmethod
    def _update_jobs_impl(cls, jobs: List[HybridJob]) -> None:
        """Refresh the state of all client jobs by polling their processes."""
        logger.debug(f"Update - job list {[j._type for j in jobs]}")

        for j in jobs:
            if j._type == JobType.Client:
                returncode = j._process.poll()
                if returncode is None:
                    state = State.RUNNING
                elif returncode == 0:
                    state = State.TERMINATED
                else:
                    state = State.FAILED
                j._state = state
        # note that the server job status is not monitored
        # through subprocess since it is an indirect submission.
        # instead, the server job keeps the initial WAITING status and
        # the server will be marked as running by the state machine
        # as soon as it connects to the launcher.
        # Then, its running state is only monitored through regular pinging.

    @classmethod
    def _cancel_client_jobs_impl(cls, jobs: List[HybridJob]) -> None:
        """Gracefully wait for, then terminate, then kill client processes."""
        logger.debug(
            f"Cancel - job list {[j._type for j in jobs if j._type == JobType.Client]}"
        )

        # when the user presses ctrl+c, the shell will send all processes in
        # the same process group SIGINT. some programs respond intelligently to
        # signals by freeing resources and exiting. these programs may also
        # exit _immediately_ if they receive a second signal within a short
        # time frame (e.g., srun or mpirun which won't terminate its child
        # processes in this case). for this reason, we wait before terminating
        # jobs here.
        max_wait_time = time.Time(seconds=5)

        # wait at most max_wait_time overall
        def compute_timeout(t_start: time.Time) -> time.Time:
            t_waited = time.monotonic() - t_start
            if t_waited < max_wait_time:
                return max_wait_time - t_waited
            return time.Time(seconds=0)

        # terminate processes
        t_0 = time.monotonic()
        for j in jobs:
            if j._type == JobType.Client:
                try:
                    timeout = compute_timeout(t_0)
                    j._process.wait(timeout.total_seconds())
                except subprocess.TimeoutExpired:
                    logger.debug(f"Slurm srun scheduler terminating process {j.id()}")
                    j._process.terminate()

        # kill processes if necessary
        t_1 = time.monotonic()
        for j in jobs:
            if j._type == JobType.Client:
                try:
                    timeout = compute_timeout(t_1)
                    j._process.wait(timeout.total_seconds())
                except subprocess.TimeoutExpired:
                    logger.debug(f"Slurm srun scheduler killing process {j.id()}")
                    j._process.kill()

                # after the cancellation sequence the client is considered
                # failed regardless of how it ended
                j._state = State.FAILED

    @classmethod
    def _cancel_server_job_impl(
        cls, jobs: List[HybridJob]
    ) -> Optional[Tuple[ArgumentList, Environment]]:
        """Build the scancel command for the sbatch-submitted server job.

        Returns None when no server job is present; raises ValueError when
        no sbatch job ID can be found in the server job's output.
        """
        server_job: List[HybridJob] = [j for j in jobs if j._type == JobType.Server]
        logger.debug(f"Cancel - job list {[j._type for j in server_job]}")

        # BUGFIX: the original indexed server_job[0] before checking
        # emptiness (the check only happened in the final return), raising
        # IndexError instead of returning None when no server job exists
        if not server_job:
            return None

        # if file exists then read the sbatch job id
        if server_job[0].stdout_fname is not None:
            with open(server_job[0].stdout_fname) as f:
                match = cls._sbatch_job_id_pattern.fullmatch(f.read().strip())
        # else read from the process' stored stdout
        else:
            # BUGFIX: the assertion message used to be split across two
            # string literals; the second literal was a dead expression
            # statement and never part of the message
            assert server_job[0]._process.stdout is not None, (
                "Launcher was not able to find stdout containing sbatch id "
                "required for the server job cancellation."
            )
            match = cls._sbatch_job_id_pattern.fullmatch(
                server_job[0]._process.stdout.read().strip()
            )
        if match is None:
            raise ValueError("no job ID found in server sbatch output.")

        scancel_command = [
            "scancel",
            "--batch",
            "--quiet",
            match.group(1),
        ]

        return scancel_command, os.environ

    def _parse_cancel_jobs_impl(
        self, jobs: List[HybridJob], proc: CompletedProcess
    ) -> None:
        """Check the exit status of the scancel invocation."""
        # scancel exits with status 1 if at least one job had already been
        # terminated
        if proc.exit_status not in [0, 1]:
            raise RuntimeError(
                "scancel error: exit status {:d}".format(proc.exit_status)
            )