Coverage for melissa/server/deep_learning/tf_server.py: 0%

82 statements  

coverage.py v7.2.7, created at 2023-09-22 10:36 +0000

import os
import logging
from abc import abstractmethod
from typing import Any, Dict
from mpi4py import MPI
import tensorflow as tf

from melissa.server.deep_learning.tensorboard_logger import TfTensorboardLogger
from melissa.server.deep_learning.base_dl_server import DeepMelissaServer

logger = logging.getLogger(__name__)


class TFServer(DeepMelissaServer):
    """
    Director to be used for any DeepMelissa study.
    The MelissaServer is initialized with the proper options;
    self.start() sets the order of operations, including the
    user-created training loop train().
    """

    def __init__(self, config: Dict[str, Any]):
        super().__init__(config)
        self.model: Any = None
        self.optimizer: Any = None

    def setup_environment(self):
        """
        Sets up the environment for distributed data-parallel training, if desired.
        """
        try:  # GPU execution
            physical_devices = tf.config.list_physical_devices('GPU')

            if not physical_devices:
                raise Exception("No GPU found")

            local_rank = self.rank % len(physical_devices)
            tf.config.set_visible_devices(physical_devices[local_rank], 'GPU')

            if 'SLURM_NODELIST' in os.environ:
                logger.info("Slurm cluster initialization")
                # build the multi-worker environment from Slurm variables
                cluster_resolver = tf.distribute.cluster_resolver.SlurmClusterResolver(
                    port_base=12345,
                    gpus_per_node=self.num_server_proc // int(os.environ['SLURM_NNODES']),
                    auto_set_gpu=False
                )

            elif 'OAR_NODEFILE' in os.environ:
                logger.info("OAR cluster initialization")
                # SLURM_PROCID is expected and used to get the worker rank
                os.environ['SLURM_PROCID'] = str(self.rank)
                # extract the host list
                with open(os.environ['OAR_NODEFILE']) as my_file:
                    host_list = list(set(my_file.read().splitlines()))
                cluster_resolver = tf.distribute.cluster_resolver.SlurmClusterResolver(
                    jobs={'worker': self.num_server_proc},
                    port_base=12345,
                    gpus_per_node=self.num_server_proc // len(host_list),
                    gpus_per_task=1,  # this enforces GPU/process correspondence
                    tasks_per_node={
                        host: self.num_server_proc // len(host_list) for host in host_list
                    },
                    auto_set_gpu=False
                )

            else:
                logger.info("local cluster initialization")
                # SLURM_PROCID is expected and used to get the worker rank
                os.environ['SLURM_PROCID'] = str(self.rank)
                cluster_resolver = tf.distribute.cluster_resolver.SlurmClusterResolver(
                    jobs={'worker': self.num_server_proc},
                    port_base=12345,
                    gpus_per_node=None,  # this will be retrieved with nvidia-smi
                    gpus_per_task=1,  # this enforces GPU/process correspondence
                    tasks_per_node={os.uname()[1]: self.num_server_proc},
                    auto_set_gpu=False
                )

            # log the lists of physical and visible devices
            logger.info(f"list of physical devices: {physical_devices}")
            logger.info(f"list of visible devices: {tf.config.get_visible_devices('GPU')}")

            # use the NCCL communication protocol
            implementation = tf.distribute.experimental.CommunicationImplementation.NCCL
            communication_options = tf.distribute.experimental.CommunicationOptions(
                implementation=implementation)

            # declare the distribution strategy
            self.strategy = tf.distribute.MultiWorkerMirroredStrategy(
                cluster_resolver=cluster_resolver, communication_options=communication_options)

        except Exception as e:  # CPU execution
            logger.info(f"Slurm, OAR and local cluster initialization failed with exception {e}")

            if len(tf.config.list_physical_devices('CPU')) > 1:
                raise Exception("tensorflow cannot be distributed on multiple non-gpu devices")

            elif len(tf.config.list_physical_devices('CPU')) == 1:
                logger.info("default MultiWorkerMirroredStrategy")
                self.strategy = tf.distribute.MultiWorkerMirroredStrategy()

        # initialize the TensorBoard logger
        self.tb_logger = TfTensorboardLogger(
            self.rank,
            disable=not self.dl_config["tensorboard"],
            debug=self.debug
        )

    def server_online(self):
        """
        What the server should do while "online".
        """
        # main thread heads over to handle distributed training
        self.train()
        return

    def server_finalize(self):
        """
        All finalization methods go here.
        """
        logger.info("Stop Server")
        self.write_final_report()
        self.close_connection()

        return

    def synchronize_data_availability(self) -> bool:
        """
        Coordinates across server processes to make sure there are data
        available to be processed on every rank. This avoids any deadlock
        in the all_reduce of the gradients.
        """
        is_ready = self.dataset.has_data
        reduced_receiving = 0
        reduced_receiving += self.comm.allreduce(int(is_ready), op=MPI.SUM)
        is_ready = True if reduced_receiving >= self.num_server_proc else False
        return bool(is_ready)

    @abstractmethod
    def configure_data_collection(self):
        """
        Instantiates the data collector and buffer.
        """
        return

    @abstractmethod
    def train(self):
        """
        Use-case based training loop.
        """
        return

    def test(self, model: Any):
        """
        The user can set up a test function if desired.
        Not required.
        """
        return model

    def load_model_from_checkpoint(self):
        """
        TensorFlow-specific checkpoint loading pattern.
        """
        self.set_model()
        step = tf.Variable(0, trainable=False)
        with self.strategy.scope():
            restore = tf.train.Checkpoint(
                step=step, optimizer=self.optimizer, model=self.model
            )
            restore.read("checkpoints/model.pt")

        self.batch_offset = step.numpy()

        return

    def checkpoint(
        self,
        batch: int = 0,
        path: str = "checkpoints",
    ):

        with self.buffer.mutex:
            self.checkpoint_state()
            # tensorflow checkpoint
            ckpt = tf.train.Checkpoint(
                step=tf.Variable(batch, trainable=False), optimizer=self.optimizer, model=self.model
            )
            if self.rank == 0:
                ckpt.write(f'{path}/model.pt')
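
For orientation, below is a minimal sketch of how a study might subclass TFServer. It assumes that the surrounding Melissa machinery calls setup_environment(), configure_data_collection(), and server_online() in order (the class docstring refers to this sequencing via self.start()), and that self.dataset yields (inputs, targets) batches once data are available. The class name MyStudyServer, the toy Keras model, and the checkpoint interval are illustrative assumptions, not part of the module above.

# Hypothetical example, not part of tf_server.py
from typing import Any, Dict

import tensorflow as tf

from melissa.server.deep_learning.tf_server import TFServer


class MyStudyServer(TFServer):
    """Hypothetical study server; names below are assumptions for illustration."""

    def configure_data_collection(self):
        # Study-specific: instantiate the data collector and buffer
        # (see DeepMelissaServer for the expected hooks).
        ...

    def train(self):
        # Create the model and optimizer inside the strategy scope so that
        # variables are mirrored across workers; setup_environment() has
        # already built self.strategy by the time train() runs.
        with self.strategy.scope():
            self.model = tf.keras.Sequential([
                tf.keras.layers.Dense(64, activation="relu"),
                tf.keras.layers.Dense(1),
            ])
            self.optimizer = tf.keras.optimizers.Adam()

        @tf.function
        def train_step(x, y):
            def step_fn(inputs, targets):
                with tf.GradientTape() as tape:
                    predictions = self.model(inputs, training=True)
                    loss = tf.reduce_mean(tf.square(predictions - targets))
                gradients = tape.gradient(loss, self.model.trainable_variables)
                # apply_gradients aggregates gradients across workers
                self.optimizer.apply_gradients(
                    zip(gradients, self.model.trainable_variables))
                return loss

            per_replica_loss = self.strategy.run(step_fn, args=(x, y))
            return self.strategy.reduce(
                tf.distribute.ReduceOp.MEAN, per_replica_loss, axis=None)

        batch = self.batch_offset  # non-zero after load_model_from_checkpoint()
        data_iterator = iter(self.dataset)  # assumed to yield (inputs, targets)
        # Step only when every rank has data, so no worker deadlocks in the
        # collective all_reduce (see synchronize_data_availability()).
        while self.synchronize_data_availability():
            x, y = next(data_iterator)
            loss = train_step(x, y)
            batch += 1
            if batch % 100 == 0:  # illustrative checkpoint interval
                self.checkpoint(batch)

A concrete study would then construct this class with its configuration dictionary and call start(), which (per the class docstring) sequences setup, data collection, and the train() loop; the periodic self.checkpoint(batch) calls pair with load_model_from_checkpoint(), which restores self.batch_offset from checkpoints/model.pt on restart.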