Coverage for melissa/utility/rank_helper.py: 67%

89 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2025-11-03 09:52 +0100

1import os 

2from dataclasses import dataclass, field 

3from typing import Tuple, Callable, Any, Optional, List, Dict 

4from functools import wraps 

5 

6from numpy.typing import DTypeLike 

7from mpi4py.util.dtlib import to_numpy_dtype 

8from mpi4py import MPI 

9 

# Rank designated to run sampling; set lazily by initialize_sampling_rank().
SAMPLING_RANK: Optional[int] = None

# Mapping from MPI C type names to their matching NumPy dtypes.
MPI2NP_DT: Dict[str, DTypeLike] = {
    "int": to_numpy_dtype(MPI.INT),
    "float": to_numpy_dtype(MPI.FLOAT),
    "double": to_numpy_dtype(MPI.DOUBLE),
}

16 

17 

@dataclass
class ClusterEnvironment:
    """Snapshot of the MPI/job-scheduler layout for the current process.

    Fields are populated in ``__post_init__`` from ``SLURM_*`` environment
    variables when running inside a SLURM step, otherwise from the ``OMPI_*``
    variables exported by Open MPI's launcher.
    """

    comm_world: MPI.Comm = field(init=False)        # handle to MPI.COMM_WORLD
    comm_world_size: int = field(init=False)        # total number of ranks
    comm_world_rank: int = field(init=False)        # global rank of this process
    comm_world_local_size: int = field(init=False)  # number of ranks on this node
    comm_world_local_rank: int = field(init=False)  # rank of this process on its node
    universe_size: int = field(init=False)          # nodes (SLURM) / universe size (OMPI)
    comm_world_node_rank: int = field(init=False)   # index of the node this rank is on

    def __post_init__(self) -> None:
        """Fill every field from the scheduler's environment variables."""
        self.comm_world = MPI.COMM_WORLD
        if 'SLURM_PROCID' in os.environ:
            self.comm_world_size = int(os.environ.get('SLURM_STEP_NUM_TASKS', 1))
            self.comm_world_rank = int(os.environ.get('SLURM_PROCID', 0))
            self.comm_world_local_rank = int(os.environ.get('SLURM_LOCALID', 0))
            self.universe_size = int(os.environ.get('SLURM_STEP_NUM_NODES', 1))
            if self.universe_size == 0:
                # Fall back to the allocation-wide node count.
                self.universe_size = int(os.environ.get('SLURM_NNODES', 1))
            self.comm_world_node_rank = int(os.environ.get('SLURM_NODEID', 0))
            self.comm_world_local_size = self.__get_slurm_tasks_for_node(
                self.comm_world_node_rank
            )
            if self.comm_world_local_size == 0:
                # SLURM_STEP_TASKS_PER_NODE was absent or did not cover this
                # node; use the per-node task count as a last resort.
                self.comm_world_local_size = int(
                    os.environ.get('SLURM_NTASKS_PER_NODE', '0')
                )
        else:
            self.comm_world_size = int(os.environ.get('OMPI_COMM_WORLD_SIZE', 1))
            self.comm_world_rank = int(os.environ.get('OMPI_COMM_WORLD_RANK', 0))
            self.comm_world_local_size = int(os.environ.get('OMPI_COMM_WORLD_LOCAL_SIZE', 1))
            self.comm_world_local_rank = int(os.environ.get('OMPI_COMM_WORLD_LOCAL_RANK', 0))
            self.universe_size = int(os.environ.get('OMPI_UNIVERSE_SIZE', 1))
            self.comm_world_node_rank = int(os.environ.get('OMPI_COMM_WORLD_NODE_RANK', 0))

    def __get_slurm_tasks_for_node(self, nodeid: int) -> int:
        """Parse ``SLURM_STEP_TASKS_PER_NODE`` and return the task count for *nodeid*.

        The variable is either a plain comma-separated list (e.g. ``"40,44"``)
        or uses SLURM's repeat notation (e.g. ``"28,40(x2),28(x5)"``).

        Returns 0 when *nodeid* is not covered by the list (e.g. the variable
        is unset on a multi-node step, leaving only the default ``'1'``) so
        that the caller can fall back to ``SLURM_NTASKS_PER_NODE``; previously
        this case raised an ``IndexError``.
        """
        tasks_per_node_str = os.environ.get('SLURM_STEP_TASKS_PER_NODE', '1')
        tasks_per_node: List[int] = []
        for part in tasks_per_node_str.split(','):
            if '(' in part:
                # "40(x2)" means two consecutive nodes with 40 tasks each.
                count, times = part.split('(')
                repeat = int(times.rstrip(')').split('x')[1])
                tasks_per_node.extend([int(count)] * repeat)
            else:
                tasks_per_node.append(int(part))

        if 0 <= nodeid < len(tasks_per_node):
            return tasks_per_node[nodeid]
        # Out-of-range node id: signal "unknown" instead of crashing.
        return 0

70 

71 

def get_rank_and_num_server_proc() -> Tuple[int, int]:
    """Return ``(rank, world size)`` of the current ``MPI.COMM_WORLD``."""
    env = ClusterEnvironment()
    return env.comm_world_rank, env.comm_world_size

76 

77 

def __set_sampling_rank() -> int:
    """Determine the sampling rank from the largest per-node local size.

    Rank 0 gathers ``(node_rank, local_size)`` pairs from every process,
    derives the sampling rank from the node with the largest local size, and
    broadcasts the result so all ranks agree on the same value.
    """
    cluster = ClusterEnvironment()
    comm = cluster.comm_world

    # Default guess held by every rank until the broadcast arrives.
    sampling_rank = cluster.comm_world_local_size - 1

    # Collect (node rank, local size) pairs on rank 0 to find the max.
    pairs: Optional[List[Tuple[int, int]]] = comm.gather(
        (cluster.comm_world_node_rank, cluster.comm_world_local_size), root=0
    )

    if cluster.comm_world_rank == 0 and pairs:
        best_node, best_size = max(pairs, key=lambda item: item[1])
        sampling_rank = (best_node + 1) * best_size - 1

    # Everyone adopts rank 0's decision.
    return comm.bcast(sampling_rank, root=0)

98 

99 

def initialize_sampling_rank():
    """Compute and cache the global ``SAMPLING_RANK`` once (idempotent)."""
    global SAMPLING_RANK
    if SAMPLING_RANK is not None:
        return
    SAMPLING_RANK = __set_sampling_rank()

104 

105 

def get_sampling_rank() -> int:
    """Return the cached sampling rank.

    Raises:
        RuntimeError: if ``initialize_sampling_rank()`` has not been called yet.
    """
    if SAMPLING_RANK is None:
        # Explicit error instead of `assert`, which is stripped under `python -O`.
        raise RuntimeError(
            "sampling rank not initialized; call initialize_sampling_rank() first"
        )
    return SAMPLING_RANK

109 

110 

def is_sampling_rank() -> bool:
    """Return True iff the calling process is the designated sampling rank."""
    my_rank = get_rank_and_num_server_proc()[0]
    return my_rank == get_sampling_rank()

114 

115 

def sampling_rank_only(fn: Callable) -> Callable:
    """Decorator: run the wrapped function only on the sampling MPI rank.

    On every other rank the call is a no-op that returns ``None``.
    """
    @wraps(fn)
    def wrapper(*args: Any, **kwargs: Any) -> Optional[Any]:
        if not is_sampling_rank():
            return None
        return fn(*args, **kwargs)

    return wrapper

125 

126 

def rank_zero_only(fn_: Callable) -> Callable:
    """Decorator enabling a function/method to be called only on rank 0.

    Inspired by pytorch_lightning.  NOTE: the rank is read once, when the
    decorator is applied, not at each call.
    """
    current_rank, _ = get_rank_and_num_server_proc()

    @wraps(fn_)
    def wrapped_fn(*args: Any, **kwargs: Any) -> Optional[Any]:
        return fn_(*args, **kwargs) if current_rank == 0 else None

    return wrapped_fn