Coverage for melissa/utility/rank_helper.py: 67%
89 statements
« prev ^ index » next coverage.py v7.10.1, created at 2025-11-03 09:52 +0100
1import os
2from dataclasses import dataclass, field
3from typing import Tuple, Callable, Any, Optional, List, Dict
4from functools import wraps
6from numpy.typing import DTypeLike
7from mpi4py.util.dtlib import to_numpy_dtype
8from mpi4py import MPI
# Lazily-initialized global sampling rank; populated by initialize_sampling_rank()
# and read by get_sampling_rank(). None until initialized.
SAMPLING_RANK: Optional[int] = None

# Mapping from MPI datatype names (as strings) to their NumPy dtype equivalents.
MPI2NP_DT: Dict[str, DTypeLike] = {
    "int": to_numpy_dtype(MPI.INT),
    "float": to_numpy_dtype(MPI.FLOAT),
    "double": to_numpy_dtype(MPI.DOUBLE),
}
@dataclass
class ClusterEnvironment:
    """Topology of the current MPI job, read from the launcher's environment.

    Values come from SLURM_* variables when running under a SLURM step
    (detected via the presence of SLURM_PROCID), otherwise from Open MPI's
    OMPI_* variables.
    """

    # Handle to MPI.COMM_WORLD.
    comm_world: MPI.Comm = field(init=False)
    # Total number of ranks in the job/step.
    comm_world_size: int = field(init=False)
    # This process's global rank.
    comm_world_rank: int = field(init=False)
    # Number of ranks on this process's node.
    comm_world_local_size: int = field(init=False)
    # This process's rank within its node.
    comm_world_local_rank: int = field(init=False)
    # Number of nodes (SLURM branch) / OMPI universe size (Open MPI branch).
    universe_size: int = field(init=False)
    # Index of the node this process runs on (see NOTE in __post_init__).
    comm_world_node_rank: int = field(init=False)

    def __post_init__(self) -> None:
        """Populate all fields from SLURM or Open MPI environment variables."""
        self.comm_world = MPI.COMM_WORLD
        if 'SLURM_PROCID' in os.environ:
            self.comm_world_size = int(os.environ.get('SLURM_STEP_NUM_TASKS', 1))
            self.comm_world_rank = int(os.environ.get('SLURM_PROCID', 0))
            self.comm_world_local_rank = int(os.environ.get('SLURM_LOCALID', 0))
            self.universe_size = int(os.environ.get('SLURM_STEP_NUM_NODES', 1))
            if self.universe_size == 0:
                # Fall back to the allocation-wide node count.
                self.universe_size = int(
                    os.environ.get('SLURM_NNODES', 1)
                )
            self.comm_world_node_rank = int(os.environ.get('SLURM_NODEID', 0))
            self.comm_world_local_size = self.__get_slurm_tasks_for_node(self.comm_world_node_rank)
            if self.comm_world_local_size == 0:
                # SLURM_STEP_TASKS_PER_NODE yielded 0; try the per-node task count.
                self.comm_world_local_size = int(
                    os.environ.get('SLURM_NTASKS_PER_NODE', '0')
                )
        else:
            # Open MPI launcher (mpirun/mpiexec).
            # NOTE(review): Open MPI documents OMPI_COMM_WORLD_NODE_RANK as the
            # process's rank *within* its node (same as LOCAL_RANK), not the
            # node's index — confirm it matches the SLURM_NODEID semantics
            # assumed in the branch above before relying on comm_world_node_rank.
            self.comm_world_size = int(os.environ.get('OMPI_COMM_WORLD_SIZE', 1))
            self.comm_world_rank = int(os.environ.get('OMPI_COMM_WORLD_RANK', 0))
            self.comm_world_local_size = int(os.environ.get('OMPI_COMM_WORLD_LOCAL_SIZE', 1))
            self.comm_world_local_rank = int(os.environ.get('OMPI_COMM_WORLD_LOCAL_RANK', 0))
            self.universe_size = int(os.environ.get('OMPI_UNIVERSE_SIZE', 1))
            self.comm_world_node_rank = int(os.environ.get('OMPI_COMM_WORLD_NODE_RANK', 0))

    def __get_slurm_tasks_for_node(self, nodeid: int) -> int:
        """Return the task count for node `nodeid` parsed from SLURM_STEP_TASKS_PER_NODE.

        The variable is either a plain comma list ("40,44") or uses SLURM's
        run-length notation ("28,40(x2),28(x5)"), which is expanded here.

        Raises:
            IndexError: if `nodeid` exceeds the number of entries parsed.
        """
        tasks_per_node_str = os.environ.get('SLURM_STEP_TASKS_PER_NODE', '1')
        if '(' in tasks_per_node_str:
            # handle the case like "28,40(x2),28(x5)"
            tasks_per_node = []
            for part in tasks_per_node_str.split(','):
                if '(' in part:
                    count, times = part.split('(')
                    # "x2)" -> "2": drop the trailing ')' and the leading 'x'.
                    times = times.rstrip(')').split('x')[1]
                    tasks_per_node.extend([int(count)] * int(times))
                else:
                    tasks_per_node.append(int(part))
        else:
            # handle the case like 40, 44
            tasks_per_node = list(map(int, tasks_per_node_str.split(',')))

        return tasks_per_node[nodeid]
def get_rank_and_num_server_proc() -> Tuple[int, int]:
    """Return (rank, world size) of the current `MPI.COMM_WORLD`."""
    env = ClusterEnvironment()
    return (env.comm_world_rank, env.comm_world_size)
def __set_sampling_rank() -> int:
    """Determine the sampling rank based on the maximum local size of MPI processes.

    Collective over COMM_WORLD: every rank must call this, since it performs
    a gather to rank 0 followed by a broadcast. Rank 0 selects the node with
    the largest per-node task count and designates that node's last rank as
    the sampling rank.
    """
    cluster = ClusterEnvironment()
    comm = cluster.comm_world
    local_size = cluster.comm_world_local_size
    node_rank = cluster.comm_world_node_rank
    # Fallback (last rank of this node) used if the gather yields nothing on rank 0.
    sampling_rank = local_size - 1

    # gather for finding the max local size
    gathered_data: Optional[List[Tuple[int, int]]] = comm.gather((node_rank, local_size), root=0)

    if cluster.comm_world_rank == 0:
        if gathered_data:
            max_node_rank, max_local_size = max(gathered_data, key=lambda x: x[1])
            # NOTE(review): this formula assumes every node up to and including
            # max_node_rank hosts max_local_size ranks — verify the result on
            # heterogeneous allocations.
            sampling_rank = (max_node_rank + 1) * max_local_size - 1

    # broadcast the sampling rank to all processes
    sampling_rank = comm.bcast(sampling_rank, root=0)

    return sampling_rank
def initialize_sampling_rank():
    """Compute and cache the module-global SAMPLING_RANK (idempotent)."""
    global SAMPLING_RANK
    if SAMPLING_RANK is not None:
        # Already initialized — nothing to do.
        return
    SAMPLING_RANK = __set_sampling_rank()
def get_sampling_rank() -> int:
    """Return the previously initialized sampling rank.

    Raises:
        RuntimeError: if `initialize_sampling_rank()` has not been called yet.
    """
    # Explicit raise instead of `assert`: asserts are stripped under
    # `python -O`, which would silently return None here.
    if SAMPLING_RANK is None:
        raise RuntimeError(
            "sampling rank not initialized; call initialize_sampling_rank() first"
        )
    return SAMPLING_RANK
def is_sampling_rank() -> bool:
    """True iff this process is the designated sampling rank."""
    current_rank, _ = get_rank_and_num_server_proc()
    return get_sampling_rank() == current_rank
def sampling_rank_only(fn: Callable) -> Callable:
    """Decorator that executes the wrapped function only on the sampling MPI rank.

    On every other rank the wrapper is a no-op that returns None.
    """
    @wraps(fn)
    def wrapper(*args, **kwargs):
        if not is_sampling_rank():
            return None
        return fn(*args, **kwargs)

    return wrapper
def rank_zero_only(fn_: Callable) -> Callable:
    """Function that can be used as a decorator to enable a function/method
    being called only on rank 0. Inspired by pytorch_lightning"""
    # Rank is resolved once, at decoration time; an MPI rank is fixed for
    # the lifetime of the process, so this is equivalent to checking per call.
    current_rank, _ = get_rank_and_num_server_proc()

    @wraps(fn_)
    def wrapped(*args: Any, **kwargs: Any) -> Optional[Any]:
        return fn_(*args, **kwargs) if current_rank == 0 else None

    return wrapped