Coverage for melissa/server/deep_learning/torch_server.py: 31% (93 statements)

coverage.py v7.2.7, created at 2023-09-22 10:36 +0000

import logging
import os
from abc import abstractmethod
from typing import Any, Dict

import cloudpickle
import torch
import torch.distributed as dist
import torch.utils.data
from torch.nn.parallel import DistributedDataParallel as DDP

from melissa.server.deep_learning.base_dl_server import DeepMelissaServer
from melissa.server.deep_learning.tensorboard_logger import TorchTensorboardLogger

logger = logging.getLogger(__name__)


class TorchServer(DeepMelissaServer):
    """
    Director to be used for any DeepMelissa study based on PyTorch.

    The server is initialized with the proper options, and `self.start()`
    sets the order of operations, including the user-defined training
    loop `train()`.
    """

    def __init__(self, config: Dict[str, Any]):
        super().__init__(config)
        self.model: Any = None

    def setup_environment(self):
        """
        Sets up the environment for distributed training, using GPUs (NCCL)
        when enough are available and falling back to CPU (Gloo) otherwise.
        """
        os.environ["MASTER_ADDR"] = "127.0.0.1"
        os.environ["MASTER_PORT"] = "29500"
        if torch.cuda.is_available() and torch.cuda.device_count() >= self.num_server_proc:
            world_size = torch.cuda.device_count()
            self.device = f"cuda:{self.rank}"
            backend = 'nccl'
            logger.info(f"CUDA available with world_size {world_size}")
        else:
            world_size = self.num_server_proc
            self.device = f"cpu:{self.rank}"
            backend = 'gloo'
        logger.info(f"Rank {self.rank} - self device: {self.device} - world_size {world_size}")

        dist.init_process_group(
            backend, rank=self.rank, world_size=world_size)

        # initialize the TensorBoard logger
        self.tb_logger = TorchTensorboardLogger(
            self.rank,
            disable=not self.dl_config["tensorboard"],
            debug=self.debug
        )
        return

    def setup_environment_slurm(self):
        """
        Uses the JZ (Jean Zay/IDRIS) recommendations for setting up a
        multi-node DDP environment with Slurm.
        """
        from melissa.utility import idr_torch

        if torch.cuda.is_available():
            torch.cuda.set_device(idr_torch.local_rank)
            world_size = idr_torch.size
            self.device = f"cuda:{idr_torch.local_rank}"
            self.idr_rank = idr_torch.local_rank
            backend = 'nccl'
            logger.info(f"CUDA available with world_size {world_size}")
        else:
            logger.error("Using setup_slurm_ddp requires GPU reservations. No GPU found.")
            raise RuntimeError("setup_environment_slurm requires at least one GPU.")

        dist.init_process_group(
            backend, init_method="env://", rank=idr_torch.rank, world_size=idr_torch.size)
        return

    def server_online(self):
        """
        What the server should do while "online".
        """
        self.train()
        return

    def wrap_model_ddp(self):
        """
        Wraps `self.model` in DistributedDataParallel according to the
        configured environment.
        """
        if not self.setup_slurm_ddp and "cuda" in self.device:
            model = DDP(self.model, device_ids=[self.device])
        elif self.setup_slurm_ddp:
            model = DDP(self.model, device_ids=[self.idr_rank])
        else:
            model = DDP(self.model)

        return model

    def server_finalize(self):
        """
        All finalization methods go here.
        """
        logger.info("Stop Server")
        self.write_final_report()
        self.close_connection()
        dist.destroy_process_group()
        return

    def synchronize_data_availability(self) -> bool:
        """
        Coordinates the dataset across server ranks to make sure there are
        data available to be processed. This avoids any deadlock in the
        all_reduce of the gradients.
        """
        is_ready = self.dataset.has_data
        if "cuda" in self.device:
            is_ready = torch.tensor(
                is_ready, dtype=bool, device=self.device)  # type: ignore
        else:
            is_ready = torch.tensor(
                int(is_ready), dtype=int, device=self.device)  # type: ignore

        # A product reduction is truthy only if every rank reports data,
        # so all ranks take the same branch and none blocks alone.
        dist.all_reduce(is_ready, op=dist.ReduceOp.PRODUCT)

        return bool(is_ready)

    @abstractmethod
    def configure_data_collection(self):
        """
        Instantiates the data collector and buffer.
        """
        return

    @abstractmethod
    def train(self):
        """
        Use-case based training loop.
        """
        return

    def test(self, model: Any):
        """
        Users can set up a test function if desired. Not required.
        """
        return model

    def load_model_from_checkpoint(self):
        """
        Torch-specific loading pattern.
        """
        with open('checkpoints/net_arch.pkl', 'rb') as f:
            self.model = cloudpickle.load(f)

        self.model.to(self.device)
        self.model = self.wrap_model_ddp()

        checkpoint = torch.load("checkpoints/model.pt")
        self.optimizer = checkpoint["optimizer"]
        self.model.load_state_dict(checkpoint["model_state_dict"])
        self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        self.batch_offset = checkpoint["batch"]
        return

    def checkpoint(
        self,
        batch: int = 0,
        path: str = "checkpoints",
    ):
        """
        Checkpoints the buffer state and, on rank 0, the model architecture,
        weights, and optimizer.
        """
        with self.buffer.mutex:
            self.checkpoint_state()

        if self.rank == 0:
            with open(f'{path}/net_arch.pkl', 'wb') as f:
                cloudpickle.dump(self.model.module, f)

            torch.save(
                {
                    "optimizer": self.optimizer,
                    "batch": batch,
                    "model_state_dict": self.model.state_dict(),
                    "optimizer_state_dict": self.optimizer.state_dict(),
                },
                f"{path}/model.pt",
            )
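

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the Melissa API): a minimal, hypothetical
# subclass showing how the pieces above are typically combined. It assumes
# the parent class exposes a torch-compatible `self.dataset` (suggested by
# the `torch.utils.data` import and `self.dataset.has_data` above) and that
# each batch is a dict with "x" and "y" tensors; the network, optimizer,
# hyperparameters, and batch layout are assumptions made for the example.
# ---------------------------------------------------------------------------
import torch.nn as nn  # used only by the example below


class ExampleTorchServer(TorchServer):

    def configure_data_collection(self):
        # Instantiate the data collector and buffer needed by the study
        # here (omitted in this sketch).
        pass

    def train(self):
        # Build the model, move it to the device chosen in
        # setup_environment(), and wrap it for distributed training.
        self.model = nn.Sequential(nn.Linear(8, 64), nn.ReLU(), nn.Linear(64, 1))
        self.model.to(self.device)
        self.model = self.wrap_model_ddp()
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-3)
        loss_fn = nn.MSELoss()

        loader = torch.utils.data.DataLoader(self.dataset, batch_size=32)
        batch_offset = getattr(self, "batch_offset", 0)  # set when restarting
        for batch_idx, batch in enumerate(loader, start=batch_offset):
            # Agree across ranks that data are available before stepping,
            # so no rank blocks alone in the gradient all_reduce.
            if not self.synchronize_data_availability():
                break
            x = batch["x"].to(self.device)
            y = batch["y"].to(self.device)
            self.optimizer.zero_grad()
            loss = loss_fn(self.model(x), y)
            loss.backward()  # DDP synchronizes gradients across ranks here
            self.optimizer.step()
            if batch_idx % 100 == 0:
                self.checkpoint(batch=batch_idx)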