Coverage for melissa/scheduler/slurm_semiglobal.py: 25%

268 statements  

coverage.py v7.2.7, created at 2023-09-22 10:36 +0000

1#!/usr/bin/python3 

2 

3# Copyright (c) 2020-2022, Institut National de Recherche en Informatique et en Automatique (Inria) 

4# All rights reserved. 

5# 

6# Redistribution and use in source and binary forms, with or without 

7# modification, are permitted provided that the following conditions are met: 

8# 

9# * Redistributions of source code must retain the above copyright notice, this 

10# list of conditions and the following disclaimer. 

11# 

12# * Redistributions in binary form must reproduce the above copyright notice, 

13# this list of conditions and the following disclaimer in the documentation 

14# and/or other materials provided with the distribution. 

15# 

16# * Neither the name of the copyright holder nor the names of its 

17# contributors may be used to endorse or promote products derived from 

18# this software without specific prior written permission. 

19# 

20# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 

21# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 

22# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 

23# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 

24# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 

25# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 

26# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 

27# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 

28# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 

29# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 

30 

31import enum 

32import logging 

33import os 

34import re 

35import shutil 

36import subprocess 

37from typing import List, Tuple, Union, Dict, Optional 

38from pathlib import Path 

39 

40from melissa.utility import time 

41from melissa.utility.process import ArgumentList, CompletedProcess, Environment, Process 

42 

43from .job import Id, Job, State 

44from .scheduler import HybridScheduler, Options 

45from .slurm_parser import break2str 

46 

47logger = logging.getLogger(__name__) 

48 

49 

50class JobType(enum.Enum): 

51 

52 Server = 0 

53 Client = 1 

54 

55 

56class HybridJob(Job): 

57 def __init__(self, uid: Id, process: "subprocess.Popen[str]", job_type: JobType) -> None: 

58 super().__init__() 

59 self._type = job_type 

60 self._state = State.WAITING 

61 self._uid = uid 

62 self._process = process 

63 

64 def id(self) -> Id: 

65 return self._process.pid 

66 

67 def unique_id(self) -> Union[str, int]: 

68 return self._uid 

69 

70 def state(self) -> State: 

71 return self._state 

72 

73 def __repr__(self) -> str: 

74 r = f"<{self.__class__.__name__} (id={self.id():d}, state={self.state()!s})>" 

75 return r 

76 

77 

78class NodeWorkload: 

79 """Keep track of submitted jobs on different nodes.""" 

80 

81 def __init__(self): 

82 self._node_list: List[str] = [] 

83 self._allocation_size: int = 0 

84 self._ntasks_per_node: int = 0 

85 self._node_workload: Dict[str, List[HybridJob]] = {} 

86 

87 @property 

88 def node_list(self) -> List[str]: 

89 return self._node_list 

90 

91 @node_list.setter 

92 def node_list(self, nodes: List[str]): 

93 self._node_list = nodes 

94 self._node_workload = {node: [] for node in nodes} 

95 

96 @property 

97 def allocation_size(self) -> int: 

98 return self._allocation_size 

99 

100 @allocation_size.setter 

101 def allocation_size(self, size: int): 

102 self._allocation_size = size 

103 

104 @property 

105 def ntasks_per_node(self) -> int: 

106 return self._ntasks_per_node 

107 

108 @ntasks_per_node.setter 

109 def ntasks_per_node(self, ntasks: int): 

110 self._ntasks_per_node = ntasks 

111 

112 def append(self, node: str, job: HybridJob): 

113 self._node_workload[node].append(job) 

114 

115 def select_node(self, iteration) -> str: 

116 return min( 

117 self._node_workload, key=lambda node: len(self._node_workload[node]) 

118 ) 

119 

120 def __len__(self) -> int: 

121 n_submitted_jobs = 0 

122 for submitted_jobs in self._node_workload.values(): 

123 n_submitted_jobs += len(submitted_jobs) 

124 return n_submitted_jobs 

125 

126 def update_jobs(self): 

127 for node in self._node_workload: 

128 SlurmSemiGlobalScheduler._update_jobs_impl(self._node_workload[node]) 

129 self._node_workload[node] = [ 

130 job for job in self._node_workload[node] 

131 if job._state in [State.WAITING, State.RUNNING] 

132 ] 

133 

134 
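# Illustrative sketch (editorial, not part of the module): how NodeWorkload is
# used by the scheduler below; `some_job` is a hypothetical HybridJob instance.
#
#   workload = NodeWorkload()
#   workload.node_list = ["node01", "node02"]   # also resets the per-node job lists
#   workload.append("node01", some_job)         # record a client submission
#   workload.select_node(0)                     # -> "node02", the least-loaded node
#   workload.update_jobs()                      # poll jobs, drop finished ones
#   len(workload)                               # number of jobs still tracked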

135class SlurmSemiGlobalScheduler(HybridScheduler[HybridJob]): 

136 # always compile regular expressions to enforce ASCII matching 

137 # allow matching things like version 1.2.3-rc4 

138 _srun_version_regexp = r"slurm (\d+)[.](\d+)[.](\d+\S*)" 

139 _srun_version_pattern = re.compile(_srun_version_regexp, re.ASCII) 

140 _sbatch_job_id_regexp = r"(\d+)" 

141 _sbatch_job_id_pattern = re.compile(_sbatch_job_id_regexp, re.ASCII) 

142 _sacct_line_regexp = r"(\d+)([+]0[.]batch)?[|](\w+)" 

143 _sacct_line_pattern = re.compile(_sacct_line_regexp, re.ASCII) 
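# Editorial note (hedged): the patterns above are expected to match output such as
#   "slurm 20.11.8"            -> groups ("20", "11", "8") via _srun_version_pattern
#   "12345|RUNNING"            -> groups ("12345", None, "RUNNING") via _sacct_line_pattern
#   "12345+0.batch|COMPLETED"  -> groups ("12345", "+0.batch", "COMPLETED")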

144 _use_het_prefix: bool = False 

145 

146 @classmethod 

147 def is_direct(cls) -> bool: 

148 return True 

149 

150 @classmethod 

151 def _name_impl(cls) -> str: 

152 return "slurm" 

153 

154 @classmethod 

155 def _is_available_impl(cls) -> Tuple[bool, Union[str, Tuple[str, str]]]: 

156 srun_path = shutil.which("srun") 

157 if srun_path is None: 

158 return False, "srun executable not found" 

159 

160 srun = subprocess.run( 

161 [srun_path, "--version"], 

162 stdin=subprocess.DEVNULL, 

163 stdout=subprocess.PIPE, 

164 stderr=subprocess.PIPE, 

165 universal_newlines=True, 

166 ) 

167 if srun.returncode != 0: 

168 return False, "failed to execute %s: %s" % (srun_path, srun.stderr) 

169 

170 # do not use pattern.fullmatch() because of the newline at the end of 

171 # the output 

172 match = cls._srun_version_pattern.match(srun.stdout) 

173 if match is None: 

174 e = "srun output '{:s}' does not match expected format" 

175 raise ValueError(e.format(srun.stdout)) 

176 

177 version_major = int(match.group(1)) 

178 # the minor version might be of the form `05` but the Python int() 

179 # function handles this correctly 

180 version_minor = int(match.group(2)) 

181 version_patch = match.group(3) 

182 

183 if version_major < 19 or (version_major == 19 and version_minor < 5): 

184 logger.warning( 

185 "Melissa has not been tested with Slurm versions older than 19.05.5" 

186 ) 

187 

188 if version_major < 17 or (version_major == 17 and version_minor < 11): 

189 return ( 

190 False, 

191 ( 

192 "expected at least Slurm 17.11," 

193 f"got {version_major}.{version_minor}.{version_patch}" 

194 "which does not support heterogeneous jobs" 

195 ), 

196 ) 

197 

198 cls._use_het_prefix = version_major >= 20 

199 

200 version_str = srun.stdout[match.span(1)[0] : match.span(3)[1]] 

201 return True, (srun_path, version_str) 

202 

203 def __init__(self) -> None: 

204 # Attributes for the srun allocation on multiple nodes 

205 self.srun_ctr: int = 0 

206 self._node_workload = NodeWorkload() 

207 # Standard initialization function 

208 is_available, info = self.is_available() 

209 if not is_available: 

210 raise RuntimeError("Slurm unavailable: %s" % (info,)) 

211 

212 assert self._use_het_prefix is not None 

213 

214 def _sanity_check_impl(self, options: Options) -> List[str]: 

215 args = options.raw_arguments 

216 errors = [] 

217 

218 for a in args: 

219 if a[0] != "-": 

220 errors.append("non-option argument '{:s}' detected".format(a)) 

221 elif a in ["-n", "--ntasks", "--test-only"]: 

222 errors.append("remove '{:s}' argument".format(a)) 

223 

224 command = ["srun", "--test-only", "--ntasks=1"] + args + ["--", "true"] 

225 srun = subprocess.run( 

226 command, 

227 stdin=subprocess.DEVNULL, 

228 stdout=subprocess.DEVNULL, 

229 stderr=subprocess.PIPE, 

230 universal_newlines=True, 

231 ) 

232 if srun.returncode != 0: 

233 e = "srun error on trial execution: {:s}".format(srun.stderr) 

234 errors.append(e) 

235 

236 return errors 

237 

238 def _submit_job_impl( 

239 self, 

240 commands: List[ArgumentList], 

241 env: Environment, 

242 options: Options, 

243 name: str, 

244 unique_id: int, 

245 ) -> Tuple[ArgumentList, Environment]: 

246 if name == "melissa-server": 

247 return self._submit_job_impl_server(commands, env, options, name, unique_id) 

248 else: 

249 return self._submit_job_impl_client(commands, env, options, name, unique_id) 

250 

251 def _submit_job_impl_server( 

252 self, 

253 commands: List[ArgumentList], 

254 env: Environment, 

255 options: Options, 

256 name: str, 

257 unique_id: int, 

258 ) -> Tuple[ArgumentList, Environment]: 

259 # sbatch submission 

260 sbatch_env = os.environ.copy() 

261 sbatch_env.update(env) 

262 

263 output_filename = "./stdout/job.{:d}.{:s}.out".format(unique_id, name) 

264 error_filename = "./stdout/job.{:d}.{:s}.err".format(unique_id, name) 

265 

266 propagated_options = [ 

267 "--output={:s}".format(output_filename), 

268 "--error={:s}".format(error_filename), 

269 ] 

270 

271 uid = unique_id 

272 self._server_uid = uid 

273 

274 sbatch_options = propagated_options + options.raw_arguments 

275 

276 # serialize sbatch options 

277 def options2str(option: str) -> str: 

278 return "#SBATCH " + option 

279 

280 sbatch_options_str = [options2str(o) for o in sbatch_options] 

281 

282 sched_cmd = options.sched_cmd 

283 sched_cmd_opt = options.sched_cmd_opt 

284 

285 # assemble srun arguments 

286 srun_arguments: List[List[str]] = [] 

287 if not sched_cmd: 

288 srun_arguments = [[""] + commands[0]] 

289 else: 

290 srun_arguments = [sched_cmd_opt + ["--"] + commands[0]] 

291 

292 # serialize srun arguments 

293 def srunargs2str(hetgroup: int, args: List[str]) -> str: 

294 assert hetgroup >= 0 

295 assert hetgroup < len(commands) 

296 

297 prefix = ": " if hetgroup > 0 else "" 

298 suffix = " \\" if hetgroup + 1 < len(commands) else "" 

299 return " " + prefix + " ".join(args) + suffix 

300 

301 srun_arguments_str = [srunargs2str(i, args) for i, args in enumerate(srun_arguments)] 

302 

303 # write srun calls to file 

304 Path('./sbatch').mkdir(parents=True, exist_ok=True) 

305 sbatch_script_filename = "./sbatch/sbatch.{:d}.sh".format(uid) 

306 sbatch_script = ( 

307 ["#!/bin/sh"] 

308 + ["# sbatch script for job {:s}".format(name)] 

309 + sbatch_options_str 

310 + [""] 

311 + (["exec \\"] if not sched_cmd else [f"exec {sched_cmd} \\"]) 

312 + srun_arguments_str 

313 ) 
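# Illustrative sketch (editorial): assuming sched_cmd="srun",
# sched_cmd_opt=["--ntasks=2"] and commands[0]=["./melissa_server"] (all
# hypothetical), the script written to disk below would look roughly like:
#
#   #!/bin/sh
#   # sbatch script for job melissa-server
#   #SBATCH --output=./stdout/job.0.melissa-server.out
#   #SBATCH --error=./stdout/job.0.melissa-server.err
#
#   exec srun \
#    --ntasks=2 -- ./melissa_server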

314 

315 # POSIX requires files to end with a newline; missing newlines at the 

316 # end of a file may break scripts that append text. 

317 # this string won't contain a newline at the end; it must be added 

318 # manually or by using a function that adds it, e.g., `print` 

319 sbatch_script_str_noeol = "\n".join(sbatch_script) 

320 

321 with open(sbatch_script_filename, "w") as f: 

322 print(sbatch_script_str_noeol, file=f) 

323 

324 sbatch_call = ( 

325 ["sbatch"] 

326 + ["--parsable"] 

327 + ["--job-name={:s}".format(name)] 

328 + [sbatch_script_filename] 

329 ) 

330 

331 return sbatch_call, sbatch_env 

332 

333 def _submit_job_impl_client( 

334 self, 

335 commands: List[ArgumentList], 

336 env: Environment, 

337 options: Options, 

338 name: str, 

339 unique_id: int, 

340 ) -> Tuple[ArgumentList, Environment]: 

341 # srun submission for clients 

342 srun_env = os.environ.copy() 

343 srun_env.update(env) 

344 self.srun_ctr += 1 

345 

346 output_filename = "./stdout/job.{:d}.{:s}.out".format(unique_id, name) 

347 error_filename = "./stdout/job.{:d}.{:s}.err".format(unique_id, name) 

348 

349 # the Slurm allocation environment is only inspected on the first client submission 

350 if self.srun_ctr == 1: 

351 # extract all environment variables 

352 self._node_workload.allocation_size = int(srun_env["SLURM_NTASKS"]) 

353 self._node_workload.ntasks_per_node = int(srun_env["SLURM_NTASKS_PER_NODE"]) 

354 self._node_workload.node_list = break2str(srun_env["SLURM_NODELIST"]) 
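# Editorial note (hedged): break2str from .slurm_parser appears to expand a
# compact Slurm nodelist, e.g. "node[01-03]" -> ["node01", "node02", "node03"],
# so that client jobs can be pinned to individual nodes below.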

355 assert ( 

356 len(self._node_workload.node_list) 

357 == self._node_workload.allocation_size // self._node_workload.ntasks_per_node 

358 ) 

359 logger.debug( 

360 f"allocation of {self._node_workload.allocation_size} cores " 

361 f"with identified nodelist {self._node_workload.node_list} " 

362 f"from environment variable {srun_env['SLURM_NODELIST']}" 

363 ) 

364 

365 # make sure resources are available before next series of srun 

366 logger.debug(f"Scheduler has currently {len(self._node_workload)} jobs running.") 

367 self._node_workload.update_jobs() 

368 node = self._node_workload.select_node(self.srun_ctr - 1) 

369 

370 # build propagated options of the srun command line 

371 propagated_options = [ 

372 "--output={:s}".format(output_filename), 

373 "--error={:s}".format(error_filename), 

374 "--nodelist={:s}".format(node), 

375 ] + options.raw_arguments 

376 

377 # assemble srun arguments 

378 sched_cmd = options.sched_cmd 

379 sched_cmd_opt = options.sched_cmd_opt 

380 

381 # retrieve a variable's value from the scheduler_arg_client options 

382 def get_from_options(list_of_options: List[str], variable_name: str) -> str: 

383 for opt in list_of_options: 

384 if variable_name in opt: 

385 idx = opt.find("=") 

386 return opt[idx + 1:] 

387 logger.error( 

388 f"The variable name {variable_name} is not amongst {list_of_options} " 

389 "Please make sure scheduler_arg_client is well set" 

390 ) 

391 return "" 

392 

393 # write srun/job execution call 

394 srun_call: List[str] = [] 

395 srun_arguments: List[str] = [] 

396 if not sched_cmd: 

397 assert len(commands) == 1, "non-unit groups are not supported in this configuration" 

398 srun_call = commands[0] 

399 # node and resource specification with environment variables 

400 new_env = { 

401 "NODEID": node, 

402 "NTASKS": get_from_options(propagated_options, "ntasks"), 

403 "TIME": get_from_options(propagated_options, "time") 

404 } 

405 srun_env.update(new_env) 

406 else: 

407 for i, cmd in enumerate(commands): 

408 args = ( 

409 sched_cmd_opt + propagated_options + ["--"] + cmd 

410 + ([":"] if i + 1 < len(commands) else []) 

411 ) 

412 srun_arguments.extend(args) 

413 srun_call = [sched_cmd] + srun_arguments 
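# Illustrative sketch (editorial): assuming sched_cmd="srun",
# sched_cmd_opt=["--ntasks=1"] and a single client command ["./melissa_client"]
# (all hypothetical), srun_call would be roughly
#   ["srun", "--ntasks=1",
#    "--output=./stdout/job.1.client.out", "--error=./stdout/job.1.client.err",
#    "--nodelist=node01", "--", "./melissa_client"]
# with further ":"-separated groups appended for heterogeneous commands.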

414 

415 return srun_call, srun_env 

416 

417 def _make_job_impl(self, proc: "Process[str]", unique_id: int) -> HybridJob: 

418 if unique_id == self._server_uid: 

419 job_type = JobType.Server 

420 else: 

421 job_type = JobType.Client 

422 job = HybridJob(unique_id, proc, job_type) 

423 if job._type == JobType.Client: 

424 # Retrieve the node of the client 

425 node: str = "" 

426 # retrieved from process arguments (standard submission) 

427 for arg in proc.args: # type: ignore 

428 if "nodelist" in arg: # type: ignore 

429 node = arg.split("=")[-1] # type: ignore 

430 break 

431 # retrieved from the process environment (no sched_cmd submission) 

432 if not node: 

433 node = proc.env["NODEID"] # type: ignore 

434 assert node, f"Node list missing in {proc.args} and {proc.env}" # type: ignore 

435 self._node_workload.append(node, job) 

436 return job 

437 

438 @classmethod 

439 def _update_jobs_impl(cls, jobs: List[HybridJob]) -> None: 

440 logger.debug(f"Update - job list {[j._type for j in jobs]}") 

441 

442 for j in jobs: 

443 if j._type == JobType.Client: 

444 returncode = j._process.poll() 

445 if returncode is None: 

446 state = State.RUNNING 

447 elif returncode == 0: 

448 state = State.TERMINATED 

449 else: 

450 state = State.FAILED 

451 j._state = state 

452 # note that the server job status is not monitored 

453 # through subprocess since it is an indirect submission. 

454 # instead, the server job keeps the initial WAITING status and 

455 # the server will be marked as running by the state machine 

456 # as soon as it connects to the launcher. 

457 # From then on, its running state is only monitored through regular pinging. 

458 

459 @classmethod 

460 def _cancel_client_jobs_impl(cls, jobs: List[HybridJob]) -> None: 

461 logger.debug(f"Cancel - job list {[j._type for j in jobs if j._type == JobType.Client]}") 

462 

463 # when the user presses ctrl+c, the shell sends SIGINT to all processes 

464 # in the same process group. some programs respond intelligently to 

465 # signals by freeing resources and exiting. these programs may also 

466 # exit _immediately_ if they receive a second signal within a short 

467 # time frame (e.g., srun or mpirun which won't terminate its child 

468 # processes in this case). for this reason, we wait before terminating 

469 # jobs here. 

470 max_wait_time = time.Time(seconds=5) 

471 

472 # wait at most max_wait_time overall 

473 def compute_timeout(t_start: time.Time) -> time.Time: 

474 t_waited = time.monotonic() - t_start 

475 if t_waited < max_wait_time: 

476 return max_wait_time - t_waited 

477 return time.Time(seconds=0) 

478 

479 # terminate processes 

480 t_0 = time.monotonic() 

481 for j in jobs: 

482 if j._type == JobType.Client: 

483 try: 

484 timeout = compute_timeout(t_0) 

485 j._process.wait(timeout.total_seconds()) 

486 except subprocess.TimeoutExpired: 

487 logger.debug(f"Slurm srun scheduler terminating process {j.id()}") 

488 j._process.terminate() 

489 

490 # kill processes if necessary 

491 t_1 = time.monotonic() 

492 for j in jobs: 

493 if j._type == JobType.Client: 

494 try: 

495 timeout = compute_timeout(t_1) 

496 j._process.wait(timeout.total_seconds()) 

497 except subprocess.TimeoutExpired: 

498 logger.debug(f"Slurm srun scheduler killing process {j.id()}") 

499 j._process.kill() 

500 

501 j._state = State.FAILED 

502 

503 @classmethod 

504 def _cancel_server_job_impl( 

505 cls, jobs: List[HybridJob] 

506 ) -> Optional[Tuple[ArgumentList, Environment]]: 

507 server_job: List[HybridJob] = [j for j in jobs if j._type == JobType.Server] 

508 logger.debug(f"Cancel - job list {[j._type for j in server_job]}") 

509 

510 job_list = [str(j.id()) for j in server_job] 

511 scancel_command = ["scancel", "--batch", "--quiet"] + job_list 

512 

513 return (scancel_command, os.environ) if server_job else None 

514 

515 def _parse_cancel_jobs_impl(self, jobs: List[HybridJob], proc: CompletedProcess) -> None: 

516 # scancel exits with status 1 if at least one job had already been 

517 # terminated 

518 if proc.exit_status not in [0, 1]: 

519 raise RuntimeError("scancel error: exit status {:d}".format(proc.exit_status))