Coverage for melissa/server/deep_learning/torch_server.py: 31%
93 statements
coverage.py v7.2.7, created at 2023-09-22 10:36 +0000
import logging
import os
from abc import abstractmethod
from typing import Any, Dict
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.distributed as dist
import torch
import torch.utils.data
import cloudpickle
from melissa.server.deep_learning.tensorboard_logger import TorchTensorboardLogger
from melissa.server.deep_learning.base_dl_server import DeepMelissaServer

logger = logging.getLogger(__name__)
class TorchServer(DeepMelissaServer):
    """
    Director to be used for any DeepMelissa study.
    The MelissaServer is initialized with the proper options.
    self.start() sets the order of operations, including the
    user-created training loop "train()".
    """

    def __init__(self, config: Dict[str, Any]):
        super().__init__(config)
        self.model: Any = None
    def setup_environment(self):
        """
        Sets up the environment for distributed (GPU) training if desired.
        """
        os.environ["MASTER_ADDR"] = "127.0.0.1"
        os.environ["MASTER_PORT"] = "29500"
        if torch.cuda.is_available() and torch.cuda.device_count() >= self.num_server_proc:
            world_size = torch.cuda.device_count()
            self.device = f"cuda:{self.rank}"
            backend = 'nccl'
            logger.info(f"CUDA available with world_size {world_size}")
        else:
            world_size = self.num_server_proc
            self.device = f"cpu:{self.rank}"
            backend = 'gloo'
        logger.info(f"Rank {self.rank} - self device: {self.device} - world_size {world_size}")

        dist.init_process_group(
            backend, rank=self.rank, world_size=world_size)

        # initialize the TensorBoard logger
        self.tb_logger = TorchTensorboardLogger(
            self.rank,
            disable=not self.dl_config["tensorboard"],
            debug=self.debug
        )
        return
    def setup_environment_slurm(self):
        """
        Uses the Jean Zay (JZ) recommendations for setting up a multi-node DDP
        environment with Slurm.
        """
        from melissa.utility import idr_torch

        if torch.cuda.is_available():
            torch.cuda.set_device(idr_torch.local_rank)
            world_size = idr_torch.size
            self.device = f"cuda:{idr_torch.local_rank}"
            self.idr_rank = idr_torch.local_rank
            backend = 'nccl'
            logger.info(f"CUDA available with world_size {world_size}")
        else:
            logger.error("Using setup_slurm_ddp requires GPU reservations. No GPU found.")
            raise RuntimeError

        dist.init_process_group(
            backend, init_method="env://", rank=idr_torch.rank, world_size=idr_torch.size)
        return
    def server_online(self):
        """
        What the server should do while "online".
        """
        self.train()

        return
    def wrap_model_ddp(self):
        """
        Wraps self.model in DistributedDataParallel according to the current setup.
        """
        if not self.setup_slurm_ddp and "cuda" in self.device:
            model = DDP(self.model, device_ids=[self.device])
        elif self.setup_slurm_ddp:
            model = DDP(self.model, device_ids=[self.idr_rank])
        else:
            model = DDP(self.model)

        return model
    def server_finalize(self):
        """
        All finalization methods go here.
        """
        logger.info("Stop Server")
        self.write_final_report()
        self.close_connection()
        dist.destroy_process_group()

        return
    def synchronize_data_availability(self) -> bool:
        """
        Coordinates the ranks to make sure every process has data available to be
        processed. This avoids any deadlock in the all_reduce of the gradients.
        """
        is_ready = self.dataset.has_data
        if "cuda" in self.device:
            is_ready = torch.tensor(
                is_ready, dtype=bool, device=self.device)  # type: ignore
        else:
            is_ready = torch.tensor(
                int(is_ready), dtype=int, device=self.device)  # type: ignore

        dist.all_reduce(is_ready, op=dist.ReduceOp.PRODUCT)

        return bool(is_ready)
    @abstractmethod
    def configure_data_collection(self):
        """
        Instantiates the data collector and buffer.
        """
        return

    @abstractmethod
    def train(self):
        """
        Use-case based training loop.
        """
        return
    def test(self, model: Any):
        """
        User can set up a test function if desired.
        Not required.
        """
        return model
    def load_model_from_checkpoint(self):
        """
        Torch-specific load pattern.
        """
        with open('checkpoints/net_arch.pkl', 'rb') as f:
            self.model = cloudpickle.load(f)

        self.model.to(self.device)
        self.model = self.wrap_model_ddp()

        checkpoint = torch.load("checkpoints/model.pt")
        self.optimizer = checkpoint["optimizer"]
        self.model.load_state_dict(checkpoint["model_state_dict"])
        self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        self.batch_offset = checkpoint["batch"]

        return
    def checkpoint(
        self,
        batch: int = 0,
        path: str = "checkpoints",
    ):
        """
        Checkpoints the server state and, on rank 0, the model architecture,
        model weights, and optimizer.
        """
        with self.buffer.mutex:
            self.checkpoint_state()

        if self.rank == 0:
            with open(f'{path}/net_arch.pkl', 'wb') as f:
                cloudpickle.dump(self.model.module, f)

            torch.save(
                {
                    "optimizer": self.optimizer,
                    "batch": batch,
                    "model_state_dict": self.model.state_dict(),
                    "optimizer_state_dict": self.optimizer.state_dict(),
                },
                f"{path}/model.pt",
            )
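For reference, below is a minimal sketch of how a study-specific server might subclass TorchServer. It is not part of torch_server.py: the placeholder network, loss, and the get_next_batch helper are hypothetical, and only the hooks defined above (configure_data_collection, train, synchronize_data_availability, wrap_model_ddp, checkpoint) come from this file.

# Illustrative subclass sketch only; the network, loss, and get_next_batch are placeholders.
import torch
import torch.nn as nn

class MyTorchServer(TorchServer):
    def configure_data_collection(self):
        # Instantiate the study-specific data collector and buffer here.
        ...

    def train(self):
        # Build the model once, move it to this rank's device, then wrap it in DDP.
        self.model = nn.Linear(16, 1).to(self.device)    # placeholder network
        self.model = self.wrap_model_ddp()
        self.optimizer = torch.optim.Adam(self.model.parameters())
        loss_fn = nn.MSELoss()

        batch = 0
        # synchronize_data_availability() returns True only when every rank has data,
        # so the gradient all_reduce inside backward() cannot deadlock.
        while self.synchronize_data_availability():
            x, y = self.get_next_batch()                 # hypothetical helper
            self.optimizer.zero_grad()
            loss = loss_fn(self.model(x.to(self.device)), y.to(self.device))
            loss.backward()                              # DDP averages gradients across ranks
            self.optimizer.step()
            batch += 1
            if batch % 100 == 0:
                self.checkpoint(batch)                   # rank 0 writes net_arch.pkl and model.pt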