Coverage for melissa/utility/idr_torch.py: 84%
32 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-03-10 22:25 +0100
1#!/usr/bin/env python
2# coding: utf-8
5"""Fetching SLURM environment for torch distributed training."""
7import os
8from dataclasses import dataclass, field
9from typing import List
10from melissa.scheduler.slurm_parser import break2str
@dataclass
class SlurmEnvironment:
    """SLURM-derived settings for torch distributed training.

    All fields are populated in ``__post_init__`` from SLURM environment
    variables, so the class is constructed with no arguments:
    ``env = SlurmEnvironment()``.

    Attributes:
        rank: Global process rank (``SLURM_PROCID``, default 0).
        local_rank: Rank within the node (``SLURM_LOCALID``, default 0).
        size: Total number of tasks (``SLURM_NTASKS``, default 0).
        cpus_per_task: CPUs allocated per task (``SLURM_CPUS_PER_TASK``, default 0).
        hostnames: Expanded node list parsed from ``SLURM_JOB_NODELIST``.
        gpu_ids: GPU ids from ``SLURM_STEP_GPUS``; empty list when no GPUs
            were allocated to the step.
        nnodes: Number of nodes (``SLURM_NNODES``, default 0).
        ntasks_per_node: Tasks per node (``SLURM_NTASKS_PER_NODE``, default -1).
        master_addr: Rendezvous address for torch distributed. Defaults to any
            pre-existing ``MASTER_ADDR`` (else ``""``); overridden with the
            first hostname in the multi-node GPU case.
        master_port: Rendezvous port as a string. Defaults to any pre-existing
            ``MASTER_PORT`` (else ``""``); overridden in the multi-node GPU case.
    """

    rank: int = field(init=False)
    local_rank: int = field(init=False)
    size: int = field(init=False)
    cpus_per_task: int = field(init=False)
    hostnames: List[str] = field(init=False)
    gpu_ids: List[str] = field(init=False)
    nnodes: int = field(init=False)
    ntasks_per_node: int = field(init=False)
    master_addr: str = field(init=False)
    master_port: str = field(init=False)

    def __post_init__(self) -> None:
        self.rank = int(os.environ.get('SLURM_PROCID', 0))
        self.local_rank = int(os.environ.get('SLURM_LOCALID', 0))
        self.size = int(os.environ.get('SLURM_NTASKS', 0))
        self.cpus_per_task = int(os.environ.get('SLURM_CPUS_PER_TASK', 0))
        self.hostnames = break2str(os.environ.get('SLURM_JOB_NODELIST', ","))
        # Normalize to the declared List[str]: an unset/empty SLURM_STEP_GPUS
        # yields [] instead of the str "" the annotation contradicted.
        raw_gpu_ids = os.environ.get('SLURM_STEP_GPUS', "")
        self.gpu_ids = raw_gpu_ids.split(",") if raw_gpu_ids else []
        self.nnodes = int(os.environ.get("SLURM_NNODES", 0))
        self.ntasks_per_node = int(os.environ.get("SLURM_NTASKS_PER_NODE", -1))

        # Always define master_addr/master_port so attribute access never
        # raises AttributeError in the single-node or CPU-only case
        # (previously they were only assigned inside the branch below).
        self.master_addr = os.environ.get('MASTER_ADDR', "")
        self.master_port = os.environ.get('MASTER_PORT', "")

        if self.gpu_ids and len(self.hostnames) > 1:
            # Multi-node GPU job: pick the first node as rendezvous host.
            self.master_addr = self.hostnames[0]
            # NOTE(review): min() on strings is lexicographic ("10" < "2");
            # presumably SLURM step GPU ids are small integers — confirm.
            self.master_port = str(12345 + int(min(self.gpu_ids)))

            # Export for torch.distributed init on every rank.
            os.environ['MASTER_ADDR'] = self.master_addr
            os.environ['MASTER_PORT'] = self.master_port