Coverage for melissa/server/deep_learning/tf_server.py: 0%
82 statements
coverage.py v7.2.7, created at 2023-09-22 10:36 +0000
import os
import logging
from abc import abstractmethod
from typing import Any, Dict
from mpi4py import MPI
import tensorflow as tf

from melissa.server.deep_learning.tensorboard_logger import TfTensorboardLogger
from melissa.server.deep_learning.base_dl_server import DeepMelissaServer

logger = logging.getLogger(__name__)


class TFServer(DeepMelissaServer):
    """
    Director to be used for any TensorFlow-based DeepMelissa study.
    The server is initialized with the proper options, and
    self.start() sets the order of operations, including the
    user-created training loop "train()".
    """

    def __init__(self, config: Dict[str, Any]):
        super().__init__(config)
        self.model: Any = None
        self.optimizer: Any = None

    def setup_environment(self):
        """
        Sets up the environment for data-distributed training, if desired.
        """
        try:  # GPU execution
            physical_devices = tf.config.list_physical_devices('GPU')

            if not physical_devices:
                raise Exception("No GPU found")
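
            # pin this process to a single GPU, round-robin over the node's devices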
            local_rank = self.rank % len(physical_devices)
            tf.config.set_visible_devices(physical_devices[local_rank], 'GPU')

            if 'SLURM_NODELIST' in os.environ:
                logger.info("Slurm cluster initialization")
                # build multi-worker environment from Slurm variables
                cluster_resolver = tf.distribute.cluster_resolver.SlurmClusterResolver(
                    port_base=12345,
                    gpus_per_node=self.num_server_proc // int(os.environ['SLURM_NNODES']),
                    auto_set_gpu=False
                )

            elif 'OAR_NODEFILE' in os.environ:
                logger.info("OAR cluster initialization")
                # SLURM_PROCID is expected by the resolver to get the worker rank
                os.environ['SLURM_PROCID'] = str(self.rank)
                # extract the host list
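                # (the OAR nodefile typically lists each host once per
                # reserved core, hence the deduplication via set())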
                with open(os.environ['OAR_NODEFILE']) as my_file:
                    host_list = list(set(my_file.read().splitlines()))
                cluster_resolver = tf.distribute.cluster_resolver.SlurmClusterResolver(
                    jobs={'worker': self.num_server_proc},
                    port_base=12345,
                    gpus_per_node=self.num_server_proc // len(host_list),
                    gpus_per_task=1,  # this will enforce gpu/process correspondence
                    tasks_per_node={
                        host: self.num_server_proc // len(host_list) for host in host_list
                    },
                    auto_set_gpu=False
                )

            else:
                logger.info("local cluster initialization")
                # SLURM_PROCID is expected by the resolver to get the worker rank
                os.environ['SLURM_PROCID'] = str(self.rank)
                cluster_resolver = tf.distribute.cluster_resolver.SlurmClusterResolver(
                    jobs={'worker': self.num_server_proc},
                    port_base=12345,
                    gpus_per_node=None,  # this will be retrieved with nvidia-smi
                    gpus_per_task=1,  # this will enforce gpu/process correspondence
                    tasks_per_node={os.uname()[1]: self.num_server_proc},
                    auto_set_gpu=False
                )

            # log the lists of physical and visible devices
            logger.info(f"list of physical devices: {physical_devices}")
            logger.info(f"list of visible devices: {tf.config.get_visible_devices('GPU')}")

            # use the NCCL communication protocol
            implementation = tf.distribute.experimental.CommunicationImplementation.NCCL
            communication_options = tf.distribute.experimental.CommunicationOptions(
                implementation=implementation)

            # declare the distribution strategy
            self.strategy = tf.distribute.MultiWorkerMirroredStrategy(
                cluster_resolver=cluster_resolver, communication_options=communication_options)
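
            # NOTE: the model and optimizer built later should be created
            # under self.strategy.scope() so that their variables are
            # mirrored across workers, e.g. (hypothetical sketch):
            #     with self.strategy.scope():
            #         self.model = build_model()
            #         self.optimizer = tf.keras.optimizers.Adam()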

        except Exception as e:  # CPU execution
            logger.info(f"Slurm, OAR and local cluster initialization failed with exception {e}")

            if len(tf.config.list_physical_devices('CPU')) > 1:
                raise Exception("TensorFlow cannot be distributed on multiple non-GPU devices")

            elif len(tf.config.list_physical_devices('CPU')) == 1:
                logger.info("default MultiWorkerMirroredStrategy")
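                # without a cluster resolver, the strategy reads the TF_CONFIG
                # environment variable and falls back to a single local worker
                # when it is unset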
                self.strategy = tf.distribute.MultiWorkerMirroredStrategy()

        # initialize the tensorboard logger
        self.tb_logger = TfTensorboardLogger(
            self.rank,
            disable=not self.dl_config["tensorboard"],
            debug=self.debug
        )

    def server_online(self):
        """
        What the server should do while "online".
        """
        # the main thread heads over to handle distributed training
        self.train()
        return

    def server_finalize(self):
        """
        All finalization methods go here.
        """
        logger.info("Stop Server")
        self.write_final_report()
        self.close_connection()

        return

    def synchronize_data_availability(self) -> bool:
        """
        Coordinates the dataset across ranks to make sure there is data
        available to be processed. This avoids any deadlock in the
        all_reduce of the gradients.
        """
        is_ready = self.dataset.has_data
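        # count how many server processes currently have data; training
        # proceeds only when all of them report ready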
        ready_count = self.comm.allreduce(int(is_ready), op=MPI.SUM)
        return ready_count >= self.num_server_proc

    @abstractmethod
    def configure_data_collection(self):
        """
        Instantiates the data collector and buffer.
        """
        return

    @abstractmethod
    def train(self):
        """
        Use-case-specific training loop.
        """
        return

    def test(self, model: Any):
        """
        The user can set up a test function here if desired.
        Not required.
        """
        return model

    def load_model_from_checkpoint(self):
        """
        TensorFlow-specific load pattern.
        """
        self.set_model()
        step = tf.Variable(0, trainable=False)
        with self.strategy.scope():
            restore = tf.train.Checkpoint(
                step=step, optimizer=self.optimizer, model=self.model
            )
            restore.read("checkpoints/model.pt")
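
        # resume the batch count from the restored step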
        self.batch_offset = step.numpy()

        return

    def checkpoint(
        self,
        batch: int = 0,
        path: str = "checkpoints",
    ):

        with self.buffer.mutex:
            self.checkpoint_state()
            # tensorflow checkpoint
            ckpt = tf.train.Checkpoint(
                step=tf.Variable(batch, trainable=False), optimizer=self.optimizer, model=self.model
            )
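            # variables are kept in sync across workers by the mirrored
            # strategy, so writing from a single rank is sufficient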
            if self.rank == 0:
                ckpt.write(f'{path}/model.pt')
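

# A minimal usage sketch (illustration only, not part of the original module):
# a concrete study subclasses TFServer and implements the abstract methods.
# `build_model` and the training-step body are hypothetical:
#
#     class MyTFServer(TFServer):
#         def configure_data_collection(self):
#             ...  # instantiate the data collector and buffer
#
#         def train(self):
#             with self.strategy.scope():
#                 self.model = build_model()
#                 self.optimizer = tf.keras.optimizers.Adam()
#             while self.synchronize_data_availability():
#                 ...  # one distributed training step per available batch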