#!/usr/bin/env python
# coding: utf-8

"""Fetching SLURM environment for torch distributed training."""

import os
import re
from dataclasses import dataclass, field
from typing import List

from melissa.scheduler.slurm_parser import break2str


@dataclass
class SlurmEnvironment:
    rank: int = field(init=False)
    local_rank: int = field(init=False)
    node_rank: int = field(init=False)
    size: int = field(init=False)
    cpus_per_task: int = field(init=False)
    hostnames: List[str] = field(init=False)
    nnodes: int = field(init=False)
    ntasks_per_node: int = field(init=False)
    gpu_world_size: int = field(init=False)

    def __post_init__(self):
        self.rank = int(os.environ.get('SLURM_PROCID', 0))
        self.local_rank = int(os.environ.get('SLURM_LOCALID', 0))
        self.node_rank = int(os.environ.get('SLURM_NODEID', 0))
        self.size = int(os.environ.get('SLURM_NTASKS', 0))
        self.cpus_per_task = int(os.environ.get('SLURM_CPUS_PER_TASK', 0))
        self.hostnames = break2str(os.environ.get('SLURM_JOB_NODELIST', ","))

        self.nnodes = int(os.environ.get("SLURM_NNODES", 0))
        self.ntasks_per_node = int(os.environ.get("SLURM_NTASKS_PER_NODE", 0))
        self.gpu_world_size = self._get_total_gpus()

    def _get_total_gpus(self) -> int:
        """Determine total GPUs allocated using SLURM environment variables in priority order."""

        # Declare variables before checks
        gpus = os.environ.get("SLURM_GPUS")
        gpus_per_task = os.environ.get("SLURM_GPUS_PER_TASK")
        ntasks = os.environ.get("SLURM_NTASKS")
        gpus_on_node = os.environ.get("SLURM_GPUS_ON_NODE")
        nnodes = os.environ.get("SLURM_NNODES")
        gpus_per_node = os.environ.get("SLURM_GPUS_PER_NODE")
        gres = os.environ.get("SLURM_GRES")

        # Check each condition in priority order
        if gpus:
            return int(gpus)

        if gpus_per_task and ntasks:
            return int(gpus_per_task) * int(ntasks)

        if gpus_on_node and nnodes:
            return int(gpus_on_node) * int(nnodes)

        if gpus_per_node and nnodes:
            return int(gpus_per_node) * int(nnodes)

        if gres and nnodes:
            return self._extract_gpus_from_gres(gres) * int(nnodes)

        return 0  # Default to 0 if no valid GPU allocation info is found.

    def _extract_gpus_from_gres(self, gres: str) -> int:
        """Extracts GPU count from SLURM_GRES (e.g., 'gpu:A100:4' -> 4)."""
        match = re.search(r"gpu(:[a-zA-Z0-9_-]+)?:([0-9]+)", gres)
        if match:
            return int(match.group(2))
        return 0  # Default to 0 if no GPUs are found in SLURM_GRES.
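

# Usage sketch (illustrative, not part of the original module): shows how the
# collected SLURM values could feed torch.distributed initialization.
# Assumptions: torch is installed, the process runs inside a SLURM job step,
# and the rendezvous via the first allocated hostname on port 29500 is a
# demonstration choice, not something prescribed by this module.
if __name__ == "__main__":
    import torch.distributed as dist

    env = SlurmEnvironment()
    # Use the first node of the allocation as rendezvous host (assumption).
    os.environ.setdefault("MASTER_ADDR", env.hostnames[0])
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group(
        backend="nccl" if env.gpu_world_size > 0 else "gloo",
        rank=env.rank,
        world_size=env.size,
    )
    print(f"rank {env.rank}/{env.size} on node {env.node_rank} initialized")
    dist.destroy_process_group()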