Coverage for melissa/server/deep_learning/base_dl_server.py: 44%
241 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-09-22 10:36 +0000
« prev ^ index » next coverage.py v7.2.7, created at 2023-09-22 10:36 +0000
1import logging
2import os
3import threading
4from abc import abstractmethod
5from functools import wraps
6from typing import Any, Callable, Dict, Optional, Tuple, Union
7import time
8import numpy as np
10from melissa.server.deep_learning.tensorboard_logger import TensorboardLogger
11from melissa.server.deep_learning.dataset import MelissaIterableDataset
12from melissa.server.deep_learning.reservoir import FIFO
13from melissa.server.base_server import BaseServer
14from melissa.server.simulation import (PartialSimulationData, Simulation,
15 SimulationData, SimulationDataStatus)
16from melissa.utility.networking import get_rank_and_num_server_proc
17from pathlib import Path
18import cloudpickle
21logger = logging.getLogger(__name__)
class DeepMelissaServer(BaseServer):
    """
    Director to be used for any DeepMelissa study.

    The server is initialized from the study configuration, and
    ``start()`` sets the order of operations, including the
    user-created training loop ``train()``.
    """

    def __init__(self, config: Dict[str, Any]):
        """
        Read the deep-learning options and create the default buffer and dataset.

        Args:
            config: study configuration; it must provide the ``dl_config``
                and ``study_options`` sections.
        """
        super().__init__(config)

        self.learning: int = 2
        self.check_group_size()
        self.dl_config: Dict[str, Any] = config['dl_config']
        self.study_options: Dict[str, Any] = config['study_options']
        self.debug = self.study_options["verbosity"] >= 3
        self.tb_logger = TensorboardLogger()
        # define temporary elementary buffer for typing purposes.
        self.buffer: FIFO = FIFO()
        self.dataset: MelissaIterableDataset = MelissaIterableDataset(
            buffer=self.buffer,
            tb_logger=self.tb_logger,
            config=self.config,
            transform=self.process_simulation_data,
        )
        self.batch_size: int = self.dl_config['batch_size']
        self.per_server_watermark: int = self.dl_config['per_server_watermark']
        self.buffer_size: int = self.dl_config["buffer_size"]
        self.pseudo_epochs: int = self.dl_config.get("pseudo_epochs", 1)
        self.sample_number: int = 0
        # 0 is used elsewhere in this class as "number of batches not known yet"
        self.n_expected_batches: int = 1
        self.idr_rank: int | None = None
        self.setup_slurm_ddp: bool = self.dl_config.get("setup_slurm_ddp", False)
        self.n_batches_update: int = self.dl_config["n_batches_update"]
        self.batch_offset = 0

    def check_group_size(self):
        """
        Flag an error when the number of clients is not a multiple of group_size.
        """
        if self.group_size > 1 and self.num_clients % self.group_size != 0:
            logger.error("Incorrect group_size, please remove or adjust this option")
            self.catch_error = True

    def start(self):
        """
        The main execution method.

        Launches the first groups (fresh start only), sets up the
        environment, builds or restores the model, starts the reception
        thread, runs ``server_online()`` and finalizes the server.
        """
        if not self.restart:
            self.launch_first_groups()

        if self.setup_slurm_ddp:
            self.setup_environment_slurm()
        else:
            self.setup_environment()

        if self.restart:
            # the reinitialization from checkpoint occurs here
            logger.info(f"Continuing from checkpoint {self.restart}")
            self.restart_from_checkpoint()
            if self.rank == 0:
                self.kill_and_restart_simulations()
        else:
            self.set_model()

        self.set_expected_batches_samples_watermark()
        self.configure_data_collection()

        # put server receive on a separate thread.
        # should not be accessed by user
        rcv = threading.Thread(target=self.receive)
        rcv.start()

        self.server_online()

        self.tb_logger.close()
        if self.dl_config.get("convert_log_to_df", False):
            try:
                self.convert_log_to_df()
            except ImportError as e:
                logger.error(f"Unable to import dependencies for log conversion {e}. "
                             "Please install pandas and tensorflow.")

        self.server_finalize()

    def setup_environment(self):
        """
        Sets environment for distributed GPU training if desired.

        No-op in the base class.
        """

    def setup_environment_slurm(self):
        """
        Unique DDP env setup with slurm as recommended by
        http://www.idris.fr/eng/jean-zay/gpu/jean-zay-gpu-torch-multi-eng.html

        No-op in the base class.
        """

    @abstractmethod
    def set_model(self):
        """
        Configure the server self.model to prepare for initialization
        """

    @abstractmethod
    def server_online(self):
        """
        Initiating data collection and directing the custom
        methods for acting on collected data.
        """

    def receive(self):
        """
        Reception loop, executed on the thread created by ``start()``.

        Calls ``self.run()`` until ``self.all_done()`` and puts every
        ``SimulationData`` it returns into the buffer. Any exception closes
        the connection and terminates the whole process with exit code 1.
        """
        try:
            self._is_receiving = True
            while not self.all_done():
                start = time.time()
                data = self.run()
                # isinstance() is already False for None
                if isinstance(data, SimulationData):
                    logger.debug(
                        f"Receive Message {data.simulation_id} time-step {data.time_step}"
                    )
                    self.buffer.put(data)
                    self.sample_number += 1
                    self.tb_logger.log_scalar(
                        "put_time", time.time() - start, self.sample_number)

            self._is_receiving = False
            self.dataset.signal_reception_over()
            logger.debug("Signal end of reception.")

        except Exception as e:
            logger.exception(f"Exception was raised in the receiving thread: \n {e}")
            self.close_connection(1)
            # hard exit: this is not the main thread, a plain exception would not
            # stop the process
            os._exit(1)

    def handle_simulation_data(
        self, msg
    ) -> Optional[Union[SimulationData, PartialSimulationData]]:
        """
        Parses and validates the incoming data messages from simulations

        Args:
            msg: raw message, deserialized with ``PartialSimulationData.from_msg``.

        Returns:
            The assembled data once every field of a time step has been
            received; ``None`` for termination messages, rejected messages,
            duplicates and incomplete time steps.
        """
        # 1. Deserialize message
        msg_data: PartialSimulationData = PartialSimulationData.from_msg(msg, self.learning)
        group_id = msg_data.simulation_id // self.group_size
        logger.debug(
            f"Rank {self.rank} received {msg_data} from rank {msg_data.client_rank} "
            f"(vect_size: {msg_data.data_size})")

        # 2. Apply filters
        if msg_data.field == "termination":
            logger.info(f"Rank {self.rank} received termination message "
                        f"from simulation {msg_data.simulation_id} with "
                        f"{msg_data.time_step} expected time-steps"
                        )
            # increment the total number of expected time-steps if not known
            if self.n_expected_batches == 0:
                self.num_samples += msg_data.time_step
                self.groups[group_id].simulations[msg_data.simulation_id].n_time_steps = (
                    msg_data.time_step
                )
                logger.info(f"Rank {self.rank}: simulation {msg_data.simulation_id} finished "
                            f"number of expected samples incremented to {self.num_samples}")
                self.n_finished_simulations += 1
            return None

        # the upper bound only applies when the number of samples is known
        if (
            msg_data.time_step < 0
            or (msg_data.time_step > self.num_samples and self.n_expected_batches > 0)
        ):
            logger.warning(f"Rank {self.rank}: bad time-step {msg_data.time_step}")
            return None

        # termination messages have already returned above, so any unknown
        # field reaching this point is a genuine error
        if msg_data.field not in self.fields:
            logger.warning(f"Rank {self.rank}: bad field {msg_data.field}")
            return None

        simulation = self.groups[group_id].simulations[msg_data.simulation_id]

        simulation_status, simulation_data = self.check_simulation_data(
            simulation, msg_data)
        if simulation_status == SimulationDataStatus.COMPLETE and simulation_data:
            logger.debug(
                f"Rank {self.rank}: assembled time-step {simulation_data.time_step} "
                f"- simulationID {simulation_data.simulation_id}")
        elif simulation_status == SimulationDataStatus.ALREADY_RECEIVED:
            logger.warning(
                f"Rank {self.rank}: duplicate simulation data {msg_data}")

        # Check if simulation has finished
        if (
            simulation_status in (SimulationDataStatus.COMPLETE, SimulationDataStatus.EMPTY)
            and simulation.finished()
        ):
            logger.info(f"Rank {self.rank}: simulation {simulation.id} finished")
            self.n_finished_simulations += 1

        return simulation_data

    def check_simulation_data(
        self, simulation: Simulation, simulation_data: PartialSimulationData
    ) -> Tuple[
        SimulationDataStatus,
        Union[Optional[SimulationData], Optional[PartialSimulationData]],
    ]:
        """
        Look for duplicated messages,
        update received_simulation_data and the simulation_data status.

        Args:
            simulation: simulation the message belongs to.
            simulation_data: one field of one time step sent by one client rank.

        Returns:
            ``(status, data)`` where ``data`` is a ``SimulationData`` only
            when ``status`` is ``COMPLETE`` and ``None`` otherwise.
        """
        client_rank = simulation_data.client_rank
        time_step = simulation_data.time_step

        if client_rank not in simulation.received_simulation_data:
            simulation.received_simulation_data[client_rank] = {}
            # the 2D-array is allocated at once if the number of expected samples
            # is known, if not it is initialized with a single column
            n_columns = self.num_samples if self.n_expected_batches != 0 else 1
            simulation.received_time_steps[client_rank] = np.zeros(
                (len(self.fields), n_columns), dtype=bool
            )

        # if num_samples is unknown the received_time_step matrix is built on the fly.
        # Grow by as many columns as needed so that `time_step` is a valid column
        # even when time steps arrive out of order or are skipped.
        received = simulation.received_time_steps[client_rank]
        n_missing_columns = time_step + 1 - received.shape[1]
        if n_missing_columns > 0:
            simulation.received_time_steps[client_rank] = np.concatenate(
                [received, np.zeros((len(self.fields), n_missing_columns), dtype=bool)],
                axis=1,
            )

        # Data have already been received
        if simulation.has_already_received(client_rank, time_step, simulation_data.field):
            logger.debug(f"simulation {simulation.id} "
                         f"field {simulation_data.field} "
                         f"timestep {simulation_data.time_step} discarded")
            return SimulationDataStatus.ALREADY_RECEIVED, None

        pending = simulation.received_simulation_data[client_rank]
        # Time step has never been seen
        if time_step not in pending:
            pending[time_step] = {field: None for field in simulation.fields}
        # Update the entry
        pending[time_step][simulation_data.field] = simulation_data
        simulation._mark_as_received(client_rank, time_step, simulation_data.field)

        if not simulation.is_complete(time_step):
            # Not all fields have been received yet
            return SimulationDataStatus.PARTIAL, None

        # All fields have been received for the time step
        simulation.n_received_time_steps += 1
        # the time step is finished either way: drop it from the pending entries
        fields_data = pending.pop(time_step)

        # Check there is actual data
        if simulation_data.data_size == 0:
            # Data have been set to another device, fields are empty
            return SimulationDataStatus.EMPTY, None

        # Concatenate data in the same order as fields.
        data = []
        for sd in fields_data.values():
            if not sd:
                logger.warning('No data dictionary found')
            else:
                assert isinstance(sd, PartialSimulationData)
                data.append(sd.data)

        return SimulationDataStatus.COMPLETE, SimulationData(
            simulation_data.simulation_id,
            simulation_data.time_step,
            data,
            simulation.parameters,
        )

    def set_expected_batches_samples_watermark(self):
        """
        Takes user config and computes the expected samples per server proc
        and expected batches per server proc

        When ``num_samples`` is not positive, ``n_expected_batches`` is set
        to 0 to mark the count as unknown.
        """
        # when num_samples is not known a priori
        if self.num_samples <= 0:
            logger.info("Number of expected samples a priori unknown")
            self.n_expected_batches = 0
            return

        # standard case where num_samples is given in the config file
        # ensure watermark is sufficient
        self.check_water_mark()

        # Account for possible accumulated shift
        self.n_expected_samples = (self.num_clients // self.num_server_proc) * self.num_samples
        self.n_expected_batches = (
            self.n_expected_samples // self.batch_size * self.pseudo_epochs
        )

        if self.pseudo_epochs > 1 and self.buffer_size != self.n_expected_samples:
            logger.warning(
                "User tried using pseudo_epochs with buffer size smaller than expected "
                "samples. Setting buffer size to number of expected samples "
                f"({self.n_expected_samples})."
            )
            self.buffer_size = self.n_expected_samples

        logger.info(
            f"Expecting {self.n_expected_samples} "
            f"samples across {self.n_expected_batches} batches.")

    def check_water_mark(self):
        """
        Ensures there are sufficient samples to reach the per_server_watermark

        Raises:
            Exception: if ``per_server_watermark`` exceeds the number of
                samples each server process will receive.
        """
        total_samples = self.num_samples * self.num_clients
        samples_per_server = total_samples // self.num_server_proc
        if self.dl_config["per_server_watermark"] > samples_per_server:
            raise Exception('Insufficient samples to reach per_server_watermark. '
                            'please increase num_samples, or decrease per_server_watermark.')

    def other_processes_finished(self, batch_number: int) -> bool:
        """
        Ensure the server processes are emptying their buffers together
        after data reception is finished.

        Args:
            batch_number: zero-based index of the current batch (used for logging).

        Returns:
            ``True`` when ``synchronize_data_availability()`` reports that
            data is no longer available.
        """
        logger.debug(f"{self.rank} is on batch {batch_number + 1}/{self.n_expected_batches}")

        # ensure self.dataset._is_receiving is in sync across all server
        # processes.
        data_available = self.synchronize_data_availability()

        # in case of pseudo_offline training, we want to avoid a
        # server timeout so we ping the launcher with time_monitor
        if not data_available:
            # at this point the total number of expected samples should be known
            # and used to update the value of self.n_expected_batches
            if self.n_expected_batches == 0:
                # per client number of expected time-steps
                self.num_samples //= self.num_clients
                self.set_expected_batches_samples_watermark()
            self.time_monitor.check_clock(time.monotonic(), self)
            logger.debug("One of the server processes finished receiving. "
                         f"{self.rank} is on batch {batch_number + 1}/{self.n_expected_batches}")

        return not data_available

    def convert_log_to_df(self):
        """
        Convert local TensorBoard data into Pandas DataFrame.
        Saves the pandas dataframe as a pickle file inside
        out_dir/tensorboard.

        Raises:
            ImportError: if pandas or tensorflow is not installed.
        """
        from tensorflow.python.summary.summary_iterator import summary_iterator
        import pandas as pd

        def parse_tfevent(tfevent):
            # only the first value of each event is kept
            return dict(
                wall_time=tfevent.wall_time,
                name=tfevent.summary.value[0].tag,
                step=tfevent.step,
                value=float(tfevent.summary.value[0].simple_value),
            )

        def convert_tfevent(filepath):
            # events without any summary value are skipped
            return pd.DataFrame([
                parse_tfevent(e) for e in summary_iterator(filepath) if len(e.summary.value)
            ])

        columns_order = ['wall_time', 'name', 'step', 'value']

        out = []
        for folder in Path("tensorboard").iterdir():
            if f"gpu_{self.rank}" not in str(folder):
                continue
            for file in folder.iterdir():
                if "events.out.tfevents" not in str(file):
                    continue
                if f"rank_{self.rank}" not in str(file):
                    continue
                logger.info(f"Parsing {str(file)}")
                out.append(convert_tfevent(str(file)))

        # pd.concat raises ValueError on an empty list
        if not out:
            logger.warning(f"No TensorBoard event file found for rank {self.rank}, "
                           "skipping log conversion.")
            return

        # reset_index returns a new frame: the result must be kept, otherwise the
        # pickled frame keeps the duplicated per-file indices
        all_df = pd.concat(out)[columns_order].reset_index(drop=True)
        all_df.to_pickle(f"./tensorboard/data_rank_{self.rank}.pkl")

    def server_finalize(self):
        """
        All finalization methods go here.

        No-op in the base class.
        """

    @abstractmethod
    def configure_data_collection(self):
        """
        Instantiates the data collector and buffer.
        """

    @abstractmethod
    def train(self):
        """
        Use-case based training loop.
        """

    def test(self, model: Any):
        """
        User can setup a test function if desired.
        Not required.

        The base implementation returns ``model`` unchanged.
        """
        return model

    def synchronize_data_availability(self) -> bool:
        """
        Coordinates the dataset _is_receiving across all
        server processes. This usually requires a library
        specific all_reduce function (e.g. dist.all_reduce
        in pytorch)

        The base implementation always returns ``True``.
        """
        return True

    def checkpoint_state(self):
        """
        Checkpoint the current state of the server

        Saves the base state, then pickles the buffer state to
        ``checkpoints/buffer_state_<rank>.pkl``.
        """
        self.save_base_state()
        # serialize the self.buffer.queue and then pickle it
        with open(f'checkpoints/buffer_state_{self.rank}.pkl', 'wb') as f:
            cloudpickle.dump(self.buffer.save_state(), f)

    @abstractmethod
    def checkpoint(self, batch: int, path: str):
        """
        The method called to initiate full tree checkpointing. This is
        specific to torch or tf.
        """

    def restart_from_checkpoint(self):
        """
        Restart the server from a checkpoint

        Raises:
            Exception: if the ``checkpoints`` directory, a ``model.pt`` file
                or this rank's buffer state is missing.
        """
        self.load_base_state()

        checkpoint_dir = "checkpoints"
        buffer_state_path = f'checkpoints/buffer_state_{self.rank}.pkl'
        # a missing directory is reported like a missing checkpoint instead of
        # surfacing as a bare FileNotFoundError from os.listdir
        model_found = os.path.isdir(checkpoint_dir) and any(
            "model.pt" in filename for filename in os.listdir(checkpoint_dir)
        )
        if not model_found or not os.path.exists(buffer_state_path):
            raise Exception(f"No checkpoint and/or queue found on rank {self.rank}. Exiting.")

        logger.info(f"Restarting from checkpoint {self.rank}")
        # NOTE: unpickling can execute arbitrary code; only load checkpoint
        # files written by a trusted run.
        with open(buffer_state_path, 'rb') as f:
            state = cloudpickle.load(f)
            self.buffer.load_from_state(state)

        # lib specific loading method (torch vs tf)
        self.load_model_from_checkpoint()

    @abstractmethod
    def load_model_from_checkpoint(self):
        """
        Library specific model loading function
        """
def rank_zero_only(fn: Callable) -> Callable:
    """Decorator that restricts a function/method to server rank 0.

    The rank is looked up once, when the decorator is applied. On rank 0
    the wrapped callable is invoked normally and its result returned; on
    any other rank the call is skipped and ``None`` is returned.
    Inspired by pytorch_lightning
    https://pytorch-lightning.readthedocs.io/en/latest/_modules/pytorch_lightning/utilities/rank_zero.html#rank_zero_info
    """
    server_rank, _num_server_proc = get_rank_and_num_server_proc()

    @wraps(fn)
    def wrapped_fn(*args: Any, **kwargs: Any) -> Optional[Any]:
        # guard clause: every rank other than 0 does nothing
        if server_rank != 0:
            return None
        return fn(*args, **kwargs)

    return wrapped_fn