Coverage for melissa/scheduler/slurm_global.py: 28%

138 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2025-11-19 09:33 +0100

1#!/usr/bin/python3 

2 

3# Copyright (c) 2020-2022, Institut National de Recherche en Informatique et en Automatique (Inria) 

4# All rights reserved. 

5# 

6# Redistribution and use in source and binary forms, with or without 

7# modification, are permitted provided that the following conditions are met: 

8# 

9# * Redistributions of source code must retain the above copyright notice, this 

10# list of conditions and the following disclaimer. 

11# 

12# * Redistributions in binary form must reproduce the above copyright notice, 

13# this list of conditions and the following disclaimer in the documentation 

14# and/or other materials provided with the distribution. 

15# 

16# * Neither the name of the copyright holder nor the names of its 

17# contributors may be used to endorse or promote products derived from 

18# this software without specific prior written permission. 

19# 

20# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 

21# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 

22# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 

23# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 

24# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 

25# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 

26# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 

27# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 

28# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 

29# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 

30 

31import logging 

32import os 

33import re 

34import shutil 

35import subprocess 

36from types import ModuleType 

37from typing import cast, List, Tuple, Union, Any 

38 

39from melissa.utility import time 

40from melissa.utility.process import ArgumentList, Environment, Process 

41 

42from .job import Id, Job, State 

43from .scheduler import DirectScheduler, Options 

44 

# Rebind the name `logging` to this module's named logger, so the rest of the
# file calls `logging.debug(...)`, `logging.warning(...)` etc. on the logger
# rather than on the root-logger helpers of the logging module. The cast only
# keeps static type checkers treating the name as the module it shadows.
logging = cast(ModuleType, logging.getLogger(name=__name__))

46 

47 

class SrunJob(Job):
    """Job backed by a locally started process.

    `id()` reports the PID of the wrapped process, while `unique_id()` returns
    the identifier supplied by the caller at construction time. The state
    starts as RUNNING and is only changed by whoever writes `_state`.
    """

    def __init__(self, uid: Id, process: "subprocess.Popen[str]") -> None:
        super().__init__()
        # the three attributes are independent of each other
        self._state = State.RUNNING
        self._process = process
        self._uid = uid

    def id(self) -> Id:
        """Return the PID of the wrapped process."""
        pid = self._process.pid
        return pid

    def unique_id(self) -> Union[str, int]:
        """Return the caller-supplied unique identifier."""
        return self._uid

    def state(self) -> State:
        """Return the most recently recorded state of this job."""
        return self._state

    def __repr__(self) -> str:
        template = "SrunJob(id={:d},state={:s})"
        return template.format(self.id(), str(self._state))

67 

68 

class SlurmGlobalScheduler(DirectScheduler[SrunJob]):
    """Direct scheduler that starts every job as a local process.

    When `options.sched_cmd` is set (presumably `srun` — confirm with the
    configuration), the job command is wrapped in that command together with
    `options.sched_cmd_opt`, per-job output/error files and the user's raw
    scheduler arguments. Availability requires an `srun` executable that
    reports Slurm 17.11 or newer.
    """

    # always compile regular expressions to enforce ASCII matching
    # allow matching things like version 1.2.3-rc4
    _srun_version_regexp = r"slurm (\d+)\.(\d+)\.(\d+)(-\S+)?"
    _srun_version_pattern = re.compile(_srun_version_regexp, re.ASCII)
    _sbatch_job_id_regexp = r"(\d+)"
    _sbatch_job_id_pattern = re.compile(_sbatch_job_id_regexp, re.ASCII)
    _sacct_line_regexp = r"(\d+)([+]0[.]batch)?[|](\w+)"
    _sacct_line_pattern = re.compile(_sacct_line_regexp, re.ASCII)
    # set by _is_available_impl(): True for Slurm major version >= 20
    _use_het_prefix: bool = False

    @classmethod
    def _name_impl(cls) -> str:
        """Return the scheduler name."""
        return "slurm"

    @classmethod
    def _is_available_impl(cls) -> Tuple[bool, Union[str, Tuple[str, str]]]:
        """Check that a usable `srun` executable is installed.

        Returns:
            `(True, (srun_path, version_string))` if `srun` is found and
            reports Slurm >= 17.11, otherwise `(False, reason)`.

        Raises:
            ValueError: if the output of `srun --version` cannot be parsed.

        Side effect: sets `cls._use_het_prefix` when the check succeeds.
        """
        srun_path = shutil.which("srun")
        if srun_path is None:
            return False, "srun executable not found"

        srun = subprocess.run(
            [srun_path, "--version"],
            stdin=subprocess.DEVNULL,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            universal_newlines=True,
        )
        if srun.returncode != 0:
            return False, "failed to execute %s: %s" % (srun_path, srun.stderr)

        # do not use pattern.fullmatch() because of the newline at the end of
        # the output
        match = cls._srun_version_pattern.match(srun.stdout)
        if match is None:
            e = "srun output '{:s}' does not match expected format"
            raise ValueError(e.format(srun.stdout))

        # the minor version might be of the form `05` but the Python int()
        # function handles this correctly
        version_major = int(match.group(1))
        version_minor = int(match.group(2))
        version_patch = int(match.group(3))
        version = (version_major, version_minor, version_patch)

        if version < (19, 5, 5):
            # Logger.warning() takes %-style format arguments, not a warning
            # category: passing RuntimeWarning made the record fail to format.
            logging.warning(
                "Melissa has not been tested with Slurm versions older than 19.05.5"
            )

        if version < (17, 11, 0):
            return (
                False,
                (
                    "Expected at least Slurm 17.11, got "
                    f"{version_major}.{version_minor}.{version_patch}"
                    " which does not support heterogeneous jobs"
                ),
            )

        cls._use_het_prefix = version_major >= 20

        # "major.minor.patch" without any "-suffix"
        version_str = srun.stdout[match.start(1) : match.end(3)]
        return True, (srun_path, version_str)

    def __init__(self) -> None:
        """Raise RuntimeError if Slurm is not available on this machine."""
        is_available, info = self.is_available()
        if not is_available:
            raise RuntimeError("Slurm unavailable: %s" % (info,))

        assert self._use_het_prefix is not None

    def _sanity_check_impl(self, options: Options) -> List[str]:
        """Validate the user's raw scheduler arguments.

        Every argument must be an option (start with `-`) and must not set
        the task count or `--test-only`, because the trial run below supplies
        both. The arguments are then tried with `srun --test-only`.

        Returns:
            Human-readable error messages; empty if every check passed.
        """
        args = options.raw_arguments
        errors = []

        for a in args:
            # startswith() instead of a[0]: an empty argument is reported
            # instead of raising IndexError
            if not a.startswith("-"):
                errors.append("non-option argument '{:s}' detected".format(a))
            elif (
                a in ("--ntasks", "--test-only")
                # `-n`, `-n4` (short option with attached value)
                or a.startswith("-n")
                # `--ntasks=4`
                or a.startswith("--ntasks=")
            ):
                errors.append("remove '{:s}' argument".format(a))

        command = ["srun", "--test-only", "--ntasks=1"] + args + ["--", "true"]
        srun = subprocess.run(
            command,
            stdin=subprocess.DEVNULL,
            stdout=subprocess.DEVNULL,
            stderr=subprocess.PIPE,
            universal_newlines=True,
        )
        if srun.returncode != 0:
            e = "srun error on trial execution: {:s}".format(srun.stderr)
            errors.append(e)

        return errors

    def _submit_job_impl(
        self,
        commands: List[ArgumentList],
        env: Environment,
        options: Options,
        name: str,
        unique_id: int,
    ) -> Tuple[ArgumentList, Environment]:
        """Build the argv and environment used to start one job.

        Without `options.sched_cmd`, the job command is returned unchanged.
        Otherwise the result is
        `[sched_cmd, *sched_cmd_opt, --output=..., --error=..., *raw_arguments, "--", *command]`.

        Returns:
            `(argv, environment)`, where the environment is a copy of the
            current process environment updated with `env`.
        """
        # Approach to environment variables:
        # By default all environment variables are propagated
        srun_env = os.environ.copy()
        srun_env.update(env)

        output_filename = "./stdout/job.{:d}.{:s}.out".format(unique_id, name)
        error_filename = "./stdout/job.{:d}.{:s}.err".format(unique_id, name)

        # build propagated options of the srun command line
        propagated_options = [
            "--output={:s}".format(output_filename),
            "--error={:s}".format(error_filename),
        ] + options.raw_arguments

        sched_cmd = options.sched_cmd
        sched_cmd_opt = options.sched_cmd_opt

        assert len(commands) == 1, "non-unit groups are not supported in this configuration"
        # splice the argv elements in as-is; joining them into one string
        # would fuse a multi-element command into a single bogus token
        command = list(commands[0])

        # write srun/job execution call
        if not sched_cmd:
            srun_call = command
        else:
            srun_call = [sched_cmd] + sched_cmd_opt + propagated_options + ["--"] + command
        logging.debug("srun_call: %s", srun_call)

        return srun_call, srun_env

    def _make_job_impl(
        self, proc: "Process[str]", unique_id: int, **kwargs: Any
    ) -> SrunJob:
        """Wrap a started process in an SrunJob; extra keyword arguments are ignored."""
        return SrunJob(unique_id, proc)

    @classmethod
    def _update_jobs_impl(cls, jobs: List[SrunJob]) -> None:
        """Refresh each job's state from its process without blocking.

        A still-running process is RUNNING, exit code 0 is TERMINATED and any
        other exit code is FAILED.
        """
        for j in jobs:
            returncode = j._process.poll()
            if returncode is None:
                state = State.RUNNING
            elif returncode == 0:
                state = State.TERMINATED
            else:
                state = State.FAILED
            j._state = state

    @classmethod
    def _cancel_jobs_impl(cls, jobs: List[SrunJob]) -> None:
        """Stop the given jobs and mark them all FAILED.

        Each phase has a shared 5 second budget. First, wait for the processes
        to exit on their own, then terminate() the ones still running. Next,
        wait again and kill() whatever is left.
        """
        # when the user presses ctrl+c, the shell will send all processes in
        # the same process group SIGINT. some programs respond intelligently to
        # signals by freeing resources and exiting. these programs may also
        # exit _immediately_ if they receive a second signal within a short
        # time frame (e.g., srun or mpirun which won't terminate its child
        # processes in this case). for this reason, we wait before terminating
        # jobs here.
        max_wait_time = time.Time(seconds=5)

        # wait at most max_wait_time overall (per phase)
        def compute_timeout(t_start: time.Time) -> time.Time:
            t_waited = time.monotonic() - t_start
            if t_waited < max_wait_time:
                return max_wait_time - t_waited
            return time.Time(seconds=0)

        # terminate processes
        t_0 = time.monotonic()
        for j in jobs:
            try:
                timeout = compute_timeout(t_0)
                j._process.wait(timeout.total_seconds())
            except subprocess.TimeoutExpired:
                logging.debug("Slurm srun scheduler terminating process %d", j.id())
                j._process.terminate()

        # kill processes if necessary
        t_1 = time.monotonic()
        for j in jobs:
            try:
                timeout = compute_timeout(t_1)
                j._process.wait(timeout.total_seconds())
            except subprocess.TimeoutExpired:
                logging.debug("Slurm srun scheduler killing process %d", j.id())
                j._process.kill()

            j._state = State.FAILED