Coverage for melissa/utility/idr_torch.py: 0%

49 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2025-11-03 09:52 +0100

1#!/usr/bin/env python 

2# coding: utf-8 

3 

4 

5"""Fetching SLURM environment for torch distributed training.""" 

6 

7import os 

8import re 

9from dataclasses import dataclass, field 

10from typing import List 

11from melissa.scheduler.slurm_parser import break2str 

12 

13 

@dataclass
class SlurmEnvironment:
    """Snapshot of the SLURM environment for torch distributed training.

    Every attribute is read from ``SLURM_*`` environment variables in
    ``__post_init__``; absent variables fall back to ``0`` (or a trivial
    node list), so instantiation is safe outside a SLURM allocation.
    """

    rank: int = field(init=False)             # global task rank (SLURM_PROCID)
    local_rank: int = field(init=False)       # rank within the node (SLURM_LOCALID)
    node_rank: int = field(init=False)        # node index (SLURM_NODEID) — was assigned but undeclared
    size: int = field(init=False)             # total task count (SLURM_NTASKS)
    cpus_per_task: int = field(init=False)    # CPUs allocated per task (SLURM_CPUS_PER_TASK)
    hostnames: List[str] = field(init=False)  # expanded node list (SLURM_JOB_NODELIST)
    nnodes: int = field(init=False)           # number of allocated nodes (SLURM_NNODES)
    ntasks_per_node: int = field(init=False)  # tasks per node (SLURM_NTASKS_PER_NODE)
    gpu_world_size: int = field(init=False)   # total GPUs in the allocation (derived)

    def __post_init__(self) -> None:
        """Populate all fields from the current process environment."""
        self.rank = int(os.environ.get('SLURM_PROCID', 0))
        self.local_rank = int(os.environ.get('SLURM_LOCALID', 0))
        self.node_rank = int(os.environ.get('SLURM_NODEID', 0))
        self.size = int(os.environ.get('SLURM_NTASKS', 0))
        self.cpus_per_task = int(os.environ.get('SLURM_CPUS_PER_TASK', 0))
        # break2str expands SLURM's compressed node-list syntax into
        # explicit hostnames; the "," default keeps it callable when the
        # variable is unset.  (Semantics of break2str assumed from its
        # call site here — defined in melissa.scheduler.slurm_parser.)
        self.hostnames = break2str(os.environ.get('SLURM_JOB_NODELIST', ","))

        self.nnodes = int(os.environ.get("SLURM_NNODES", 0))
        self.ntasks_per_node = int(os.environ.get("SLURM_NTASKS_PER_NODE", 0))
        self.gpu_world_size = self._get_total_gpus()

    def _get_total_gpus(self) -> int:
        """Return the total GPU count of the allocation.

        SLURM exposes the GPU allocation through several variables
        depending on how the job was submitted; they are consulted in
        priority order.  Returns 0 when none of them is set.
        """
        gpus = os.environ.get("SLURM_GPUS")
        gpus_per_task = os.environ.get("SLURM_GPUS_PER_TASK")
        ntasks = os.environ.get("SLURM_NTASKS")
        gpus_on_node = os.environ.get("SLURM_GPUS_ON_NODE")
        nnodes = os.environ.get("SLURM_NNODES")
        gpus_per_node = os.environ.get("SLURM_GPUS_PER_NODE")
        gres = os.environ.get("SLURM_GRES")

        # --gpus: already the job-wide total.
        if gpus:
            return int(gpus)

        # NOTE(review): SLURM_GPUS_PER_TASK / SLURM_GPUS_PER_NODE may
        # carry a "type:count" form (e.g. "a100:2") on some clusters,
        # which plain int() would reject — confirm against target site.
        if gpus_per_task and ntasks:
            return int(gpus_per_task) * int(ntasks)

        if gpus_on_node and nnodes:
            return int(gpus_on_node) * int(nnodes)

        if gpus_per_node and nnodes:
            return int(gpus_per_node) * int(nnodes)

        # Legacy generic-resource spec, e.g. "gpu:A100:4" — per node.
        if gres and nnodes:
            return self._extract_gpus_from_gres(gres) * int(nnodes)

        return 0  # No GPU allocation information found.

    def _extract_gpus_from_gres(self, gres: str) -> int:
        """Extract the per-node GPU count from a GRES string.

        Accepts both ``gpu:<count>`` and ``gpu:<type>:<count>``
        (e.g. ``'gpu:A100:4'`` -> 4).  Returns 0 when no GPU entry
        is present.
        """
        match = re.search(r"gpu(:[a-zA-Z0-9_-]+)?:([0-9]+)", gres)
        if match:
            return int(match.group(2))
        return 0