Coverage for melissa/utility/idr_torch.py: 84%

32 statements  

coverage.py v7.6.12, created at 2025-03-10 22:25 +0100

#!/usr/bin/env python
# coding: utf-8

"""Fetching the SLURM environment for torch distributed training."""

import os
from dataclasses import dataclass, field
from typing import List

from melissa.scheduler.slurm_parser import break2str


@dataclass
class SlurmEnvironment:
    rank: int = field(init=False)
    local_rank: int = field(init=False)
    size: int = field(init=False)
    cpus_per_task: int = field(init=False)
    hostnames: List[str] = field(init=False)
    gpu_ids: List[str] = field(init=False)
    nnodes: int = field(init=False)
    ntasks_per_node: int = field(init=False)
    master_addr: str = field(init=False)
    master_port: str = field(init=False)

    def __post_init__(self):
        # Per-process identity and job geometry exported by SLURM.
        self.rank = int(os.environ.get('SLURM_PROCID', 0))
        self.local_rank = int(os.environ.get('SLURM_LOCALID', 0))
        self.size = int(os.environ.get('SLURM_NTASKS', 0))
        self.cpus_per_task = int(os.environ.get('SLURM_CPUS_PER_TASK', 0))
        # Expand the compressed SLURM nodelist into individual hostnames.
        self.hostnames = break2str(os.environ.get('SLURM_JOB_NODELIST', ","))
        # GPU ids assigned to this step, e.g. "0,1" -> ["0", "1"].
        self.gpu_ids = os.environ.get('SLURM_STEP_GPUS', "")
        if self.gpu_ids != "":
            self.gpu_ids = self.gpu_ids.split(",")

        self.nnodes = int(os.environ.get("SLURM_NNODES", 0))
        self.ntasks_per_node = int(os.environ.get("SLURM_NTASKS_PER_NODE", -1))

        if self.gpu_ids != "" and len(self.hostnames) > 1:
            # Multi-node GPU job: the first node hosts the rendezvous and the
            # port is derived from the lowest assigned GPU id to reduce
            # collisions between concurrent job steps.
            self.master_addr = self.hostnames[0]
            self.master_port = str(12345 + int(min(self.gpu_ids) if self.gpu_ids else 0))

            # Update environment variables for distributed training.
            os.environ['MASTER_ADDR'] = self.master_addr
            os.environ['MASTER_PORT'] = self.master_port
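
Below is a minimal usage sketch, not part of the reported module, showing how the fields collected by SlurmEnvironment might feed torch.distributed when a script is launched with srun. It assumes PyTorch is installed; the fallback MASTER_ADDR/MASTER_PORT defaults and the backend choice are illustrative assumptions, not behaviour of idr_torch.py.

import os

import torch.distributed as dist

from melissa.utility.idr_torch import SlurmEnvironment

env = SlurmEnvironment()

# __post_init__ only exports MASTER_ADDR/MASTER_PORT for multi-node GPU jobs;
# these fallbacks (assumed values) let the "env://" rendezvous work elsewhere too.
os.environ.setdefault("MASTER_ADDR", env.hostnames[0] if env.hostnames else "localhost")
os.environ.setdefault("MASTER_PORT", "12345")

if env.size > 1:
    dist.init_process_group(
        backend="nccl" if env.gpu_ids else "gloo",  # assumed backend choice
        init_method="env://",
        world_size=env.size,
        rank=env.rank,
    )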