Coverage for melissa/server/deep_learning/base_dl_server.py: 44%

241 statements  

import logging
import os
import threading
from abc import abstractmethod
from functools import wraps
from typing import Any, Callable, Dict, Optional, Tuple, Union
import time
import numpy as np

from melissa.server.deep_learning.tensorboard_logger import TensorboardLogger
from melissa.server.deep_learning.dataset import MelissaIterableDataset
from melissa.server.deep_learning.reservoir import FIFO
from melissa.server.base_server import BaseServer
from melissa.server.simulation import (PartialSimulationData, Simulation,
                                       SimulationData, SimulationDataStatus)
from melissa.utility.networking import get_rank_and_num_server_proc
from pathlib import Path
import cloudpickle


logger = logging.getLogger(__name__)


class DeepMelissaServer(BaseServer):
    """
    Director to be used for any DeepMelissa study.
    The server is initialized with the proper options;
    self.start() sets the order of operations, including the
    user-defined training loop train().
    """

    def __init__(self, config: Dict[str, Any]):
        super().__init__(config)

        self.learning: int = 2
        self.check_group_size()
        self.dl_config: Dict[str, Any] = config['dl_config']
        self.study_options: Dict[str, Any] = config['study_options']
        self.debug = self.study_options["verbosity"] >= 3
        self.tb_logger = TensorboardLogger()
        # define a temporary elementary buffer for typing purposes.
        self.buffer: FIFO = FIFO()
        self.dataset: MelissaIterableDataset = MelissaIterableDataset(
            buffer=self.buffer,
            tb_logger=self.tb_logger,
            config=self.config,
            transform=self.process_simulation_data,
        )
        self.batch_size: int = self.dl_config['batch_size']
        self.per_server_watermark: int = self.dl_config['per_server_watermark']
        self.buffer_size: int = self.dl_config["buffer_size"]
        self.pseudo_epochs: int = self.dl_config.get("pseudo_epochs", 1)
        self.sample_number: int = 0
        self.n_expected_batches: int = 1
        self.idr_rank: Optional[int] = None
        self.setup_slurm_ddp: bool = self.dl_config.get("setup_slurm_ddp", False)
        self.n_batches_update: int = self.dl_config["n_batches_update"]
        self.batch_offset = 0
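    # A minimal sketch of the "dl_config" keys consumed above (illustrative
    # values only; the full schema lives with the study configuration):
    #
    #   config = {
    #       "dl_config": {
    #           "batch_size": 32,
    #           "per_server_watermark": 64,
    #           "buffer_size": 1024,
    #           "pseudo_epochs": 1,           # optional, defaults to 1
    #           "setup_slurm_ddp": False,     # optional, defaults to False
    #           "n_batches_update": 10,
    #       },
    #       "study_options": {"verbosity": 2},
    #   }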

    def check_group_size(self):
        if self.group_size > 1 and self.num_clients % self.group_size != 0:
            logger.error("Incorrect group_size, please remove or adjust this option")
            self.catch_error = True

    def start(self):
        """
        The main execution method.
        """
        if not self.restart:
            self.launch_first_groups()
        if not self.setup_slurm_ddp:
            self.setup_environment()
        else:
            self.setup_environment_slurm()
        if not self.restart:
            self.set_model()
        else:
            # the reinitialization from checkpoint occurs here
            logger.info(f"Continuing from checkpoint {self.restart}")
            self.restart_from_checkpoint()
            if self.rank == 0:
                self.kill_and_restart_simulations()
        self.set_expected_batches_samples_watermark()
        self.configure_data_collection()

        # put the server receive loop on a separate thread;
        # it should not be accessed by the user.
        rcv = threading.Thread(target=self.receive)
        rcv.start()

        self.server_online()

        self.tb_logger.close()
        if self.dl_config.get("convert_log_to_df", False):
            try:
                self.convert_log_to_df()
            except ImportError as e:
                logger.error(f"Unable to import dependencies for log conversion {e}. "
                             "Please install pandas and tensorflow.")

        self.server_finalize()

    def setup_environment(self):
        """
        Sets up the environment for distributed GPU training, if desired.
        """
        return

    def setup_environment_slurm(self):
        """
        Unique DDP environment setup with Slurm, as recommended by
        http://www.idris.fr/eng/jean-zay/gpu/jean-zay-gpu-torch-multi-eng.html
        """
        return

    @abstractmethod
    def set_model(self):
        """
        Configures the server's self.model to prepare for initialization.
        """
        return

    @abstractmethod
    def server_online(self):
        """
        Initiates data collection and directs the custom
        methods for acting on collected data.
        """
        return

    def receive(self):
        """
        Receives data from the simulations and feeds the buffer.
        """
        try:
            self._is_receiving = True
            while not self.all_done():
                start = time.time()
                data = self.run()
                if isinstance(data, SimulationData):
                    logger.debug(
                        f"Received message {data.simulation_id} time-step {data.time_step}"
                    )
                    self.buffer.put(data)
                    self.sample_number += 1
                    self.tb_logger.log_scalar(
                        "put_time", time.time() - start, self.sample_number)

            self._is_receiving = False
            self.dataset.signal_reception_over()
            logger.debug("Signaled end of reception.")

        except Exception as e:
            logger.exception(f"Exception was raised in the receiving thread: \n {e}")
            self.close_connection(1)
            os._exit(1)

    def handle_simulation_data(self, msg):
        """
        Parses and validates incoming data messages from the simulations.
        """
        # 1. Deserialize the message
        msg_data: PartialSimulationData = PartialSimulationData.from_msg(msg, self.learning)
        group_id = msg_data.simulation_id // self.group_size
        logger.debug(
            f"Rank {self.rank} received {msg_data} from rank {msg_data.client_rank} "
            f"(vect_size: {msg_data.data_size})")

        # 2. Apply filters
        if msg_data.field == "termination":
            logger.info(f"Rank {self.rank} received termination message "
                        f"from simulation {msg_data.simulation_id} with "
                        f"{msg_data.time_step} expected time-steps")
            # increment the total number of expected time-steps if not known
            if self.n_expected_batches == 0:
                self.num_samples += msg_data.time_step
                self.groups[group_id].simulations[msg_data.simulation_id].n_time_steps = (
                    msg_data.time_step
                )
                logger.info(f"Rank {self.rank}: simulation {msg_data.simulation_id} finished, "
                            f"number of expected samples incremented to {self.num_samples}")
            self.n_finished_simulations += 1
            return None
        if (
            msg_data.time_step < 0
            or (msg_data.time_step > self.num_samples and self.n_expected_batches > 0)
        ):
            logger.warning(f"Rank {self.rank}: bad time-step {msg_data.time_step}")
            return None
        if msg_data.field not in self.fields:
            if msg_data.field != "termination":
                logger.warning(f"Rank {self.rank}: bad field {msg_data.field}")
            return None

        simulation = self.groups[group_id].simulations[msg_data.simulation_id]

        simulation_status, simulation_data = self.check_simulation_data(
            simulation, msg_data)
        if simulation_status == SimulationDataStatus.COMPLETE and simulation_data:
            logger.debug(
                f"Rank {self.rank}: assembled time-step {simulation_data.time_step} "
                f"- simulationID {simulation_data.simulation_id}")
        elif simulation_status == SimulationDataStatus.ALREADY_RECEIVED:
            logger.warning(
                f"Rank {self.rank}: duplicate simulation data {msg_data}")

        # Check whether the simulation has finished
        if (
            simulation_status == SimulationDataStatus.COMPLETE
            or simulation_status == SimulationDataStatus.EMPTY
        ) and simulation.finished():
            logger.info(f"Rank {self.rank}: simulation {simulation.id} finished")
            self.n_finished_simulations += 1

        return simulation_data

    def check_simulation_data(
        self, simulation: Simulation, simulation_data: PartialSimulationData
    ) -> Tuple[
        SimulationDataStatus,
        Union[Optional[SimulationData], Optional[PartialSimulationData]],
    ]:
        """
        Looks for duplicated messages, then updates received_simulation_data
        and the simulation-data status.
        """
        if simulation_data.client_rank not in simulation.received_simulation_data:
            simulation.received_simulation_data[simulation_data.client_rank] = {}
            # the 2D array is allocated at once if the number of expected samples is known
            if self.n_expected_batches != 0:
                simulation.received_time_steps[simulation_data.client_rank] = (
                    np.zeros((len(self.fields), self.num_samples), dtype=bool)
                )
            # if not, it is initialized with a single column
            else:
                simulation.received_time_steps[simulation_data.client_rank] = (
                    np.zeros((len(self.fields), 1), dtype=bool)
                )
        # if num_samples is unknown, the received_time_steps matrix is built on the fly
        if (
            simulation_data.time_step
            > simulation.received_time_steps[simulation_data.client_rank].shape[1] - 1
        ):
            simulation.received_time_steps[simulation_data.client_rank] = np.concatenate(
                [simulation.received_time_steps[simulation_data.client_rank],
                 np.zeros((len(self.fields), 1), dtype=bool)], axis=1
            )

        # Data have already been received
        if simulation.has_already_received(
            simulation_data.client_rank, simulation_data.time_step, simulation_data.field
        ):
            logger.debug(f"simulation {simulation.id} "
                         f"field {simulation_data.field} "
                         f"timestep {simulation_data.time_step} discarded")
            return SimulationDataStatus.ALREADY_RECEIVED, None
        # Time step has never been seen
        if simulation_data.time_step not in simulation.received_simulation_data[
            simulation_data.client_rank
        ]:
            simulation.received_simulation_data[simulation_data.client_rank][
                simulation_data.time_step
            ] = {field: None for field in simulation.fields}
        # Update the entry
        simulation.received_simulation_data[simulation_data.client_rank][
            simulation_data.time_step
        ][simulation_data.field] = simulation_data
        simulation._mark_as_received(
            simulation_data.client_rank, simulation_data.time_step, simulation_data.field
        )
        if simulation.is_complete(simulation_data.time_step):
            # All fields have been received for the time step
            simulation.n_received_time_steps += 1
            # Check there is actual data
            is_empty = simulation_data.data_size == 0
            if is_empty:
                # Data have been sent to another device; fields are empty
                del simulation.received_simulation_data[simulation_data.client_rank][
                    simulation_data.time_step
                ]
                return SimulationDataStatus.EMPTY, None
            # Concatenate data in the same order as fields.
            data = []
            for sd in simulation.received_simulation_data[simulation_data.client_rank][
                    simulation_data.time_step].values():
                if not sd:
                    logger.warning('No data dictionary found')
                else:
                    assert isinstance(sd, PartialSimulationData)
                    data.append(sd.data)

            del simulation.received_simulation_data[simulation_data.client_rank][
                simulation_data.time_step
            ]
            return SimulationDataStatus.COMPLETE, SimulationData(
                simulation_data.simulation_id,
                simulation_data.time_step,
                data,
                simulation.parameters,
            )
        else:
            # Not all fields have been received yet
            return SimulationDataStatus.PARTIAL, None

    def set_expected_batches_samples_watermark(self):
        """
        Takes the user config and computes the expected number of samples
        and batches per server process.
        """
        # standard case where num_samples is given in the config file
        if self.num_samples > 0:
            # ensure the watermark is sufficient
            self.check_water_mark()

            # Account for possible accumulated shift
            self.n_expected_samples = (self.num_clients // self.num_server_proc) * self.num_samples
            self.n_expected_batches = (
                self.n_expected_samples // self.batch_size * self.pseudo_epochs
            )

            if self.pseudo_epochs > 1 and self.buffer_size != self.n_expected_samples:
                logger.warning(
                    "User tried using pseudo_epochs with a buffer size smaller than the "
                    "expected number of samples. Setting buffer size to the number of "
                    f"expected samples ({self.n_expected_samples})."
                )
                self.buffer_size = self.n_expected_samples

            logger.info(
                f"Expecting {self.n_expected_samples} "
                f"samples across {self.n_expected_batches} batches.")
        # when num_samples is not known a priori
        else:
            logger.info("Number of expected samples a priori unknown")
            self.n_expected_batches = 0
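    # Worked example for the computation above (illustrative numbers, not
    # defaults): with num_clients=8, num_server_proc=2, num_samples=100,
    # batch_size=10 and pseudo_epochs=1, each server process expects
    # (8 // 2) * 100 = 400 samples, i.e. 400 // 10 * 1 = 40 batches.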

    def check_water_mark(self):
        """
        Ensures there are sufficient samples to reach the per_server_watermark.
        """
        total_samples = self.num_samples * self.num_clients
        samples_per_server = total_samples // self.num_server_proc
        if self.dl_config["per_server_watermark"] > samples_per_server:
            raise Exception('Insufficient samples to reach per_server_watermark. '
                            'Please increase num_samples or decrease per_server_watermark.')

    def other_processes_finished(self, batch_number: int) -> bool:
        """
        Ensures the server processes empty their buffers together
        after data reception is finished.
        """

        logger.debug(f"{self.rank} is on batch {batch_number + 1}/{self.n_expected_batches}")

        # ensure self.dataset._is_receiving is in sync across all server
        # processes.
        data_available = self.synchronize_data_availability()

        # in case of pseudo-offline training, we want to avoid a
        # server timeout, so we ping the launcher with time_monitor
        if not data_available:
            # at this point the total number of expected samples should be known
            # and used to update the value of self.n_expected_batches
            if self.n_expected_batches == 0:
                # per-client number of expected time-steps
                self.num_samples //= self.num_clients
                self.set_expected_batches_samples_watermark()
            self.time_monitor.check_clock(time.monotonic(), self)
            logger.debug("One of the server processes finished receiving. "
                         f"{self.rank} is on batch {batch_number + 1}/{self.n_expected_batches}")

        return not data_available

    def convert_log_to_df(self):
        """
        Converts local TensorBoard data into a pandas DataFrame and
        saves the DataFrame as a pickle file inside out_dir/tensorboard.
        """
        from tensorflow.python.summary.summary_iterator import summary_iterator
        import pandas as pd

        def convert_tfevent(filepath):
            return pd.DataFrame([
                parse_tfevent(e) for e in summary_iterator(filepath) if len(e.summary.value)
            ])

        def parse_tfevent(tfevent):
            return dict(
                wall_time=tfevent.wall_time,
                name=tfevent.summary.value[0].tag,
                step=tfevent.step,
                value=float(tfevent.summary.value[0].simple_value),
            )

        columns_order = ['wall_time', 'name', 'step', 'value']

        out = []
        for folder in Path("tensorboard").iterdir():
            if f"gpu_{self.rank}" in str(folder):
                for file in folder.iterdir():
                    if "events.out.tfevents" not in str(file):
                        continue
                    if f"rank_{self.rank}" not in str(file):
                        continue
                    logger.info(f"Parsing {str(file)}")
                    out.append(convert_tfevent(str(file)))

        all_df = pd.concat(out)[columns_order]
        all_df = all_df.reset_index(drop=True)
        all_df.to_pickle(f"./tensorboard/data_rank_{self.rank}.pkl")
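    # The per-rank pickle written above can later be reloaded for analysis
    # with pandas, e.g. pd.read_pickle("tensorboard/data_rank_0.pkl").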

    def server_finalize(self):
        """
        All finalization methods go here.
        """
        return

    @abstractmethod
    def configure_data_collection(self):
        """
        Instantiates the data collector and the buffer.
        """
        return

    @abstractmethod
    def train(self):
        """
        Use-case-specific training loop.
        """
        return

    def test(self, model: Any):
        """
        Users can set up a test function if desired.
        Not required.
        """
        return model

    def synchronize_data_availability(self) -> bool:
        """
        Coordinates the dataset's _is_receiving flag across all
        server processes. This usually requires a library-specific
        all-reduce function (e.g. dist.all_reduce in PyTorch).
        """

        return True

    def checkpoint_state(self):
        """
        Checkpoints the current state of the server.
        """
        self.save_base_state()
        # serialize self.buffer.queue and then pickle it
        with open(f'checkpoints/buffer_state_{self.rank}.pkl', 'wb') as f:
            cloudpickle.dump(self.buffer.save_state(), f)

        return

    @abstractmethod
    def checkpoint(self, batch: int, path: str):
        """
        The method called to initiate full-tree checkpointing. This is
        specific to torch or tf.
        """
        return

    def restart_from_checkpoint(self):
        """
        Restarts the server from a checkpoint.
        """
        self.load_base_state()

        if (
            not any("model.pt" in filename for filename in os.listdir("checkpoints"))
            or not os.path.exists(f"checkpoints/buffer_state_{self.rank}.pkl")
        ):
            raise Exception(f"No checkpoint and/or queue found on rank {self.rank}. Exiting.")

        logger.info(f"Restarting from checkpoint {self.rank}")
        with open(f'checkpoints/buffer_state_{self.rank}.pkl', 'rb') as f:
            state = cloudpickle.load(f)
            self.buffer.load_from_state(state)

        # lib-specific loading method (torch vs tf)
        self.load_model_from_checkpoint()

        return

    @abstractmethod
    def load_model_from_checkpoint(self):
        """
        Library-specific model loading function.
        """
        return

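# A minimal subclass sketch (hypothetical names, shown only to illustrate the
# abstract surface; a real study also implements configure_data_collection,
# checkpoint and load_model_from_checkpoint):
#
#   class MyDeepServer(DeepMelissaServer):
#       def set_model(self):
#           self.model = make_model()      # user-provided model factory
#
#       def server_online(self):
#           self.train()
#
#       def train(self):
#           for batch in self.dataset:     # assembled SimulationData samples
#               ...                        # one optimization step per batch
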

def rank_zero_only(fn: Callable) -> Callable:
    """Decorator that restricts a function/method to being called
    only on rank 0.
    Inspired by pytorch_lightning:
    https://pytorch-lightning.readthedocs.io/en/latest/_modules/pytorch_lightning/utilities/rank_zero.html#rank_zero_info
    """

    rank, _ = get_rank_and_num_server_proc()

    @wraps(fn)
    def wrapped_fn(*args: Any, **kwargs: Any) -> Optional[Any]:
        if rank == 0:
            return fn(*args, **kwargs)
        return None

    return wrapped_fn
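
# Example use of rank_zero_only (hypothetical helper, for illustration only):
#
#   @rank_zero_only
#   def log_summary(msg: str) -> None:
#       logger.info(msg)
#
# log_summary("...") then executes only on server rank 0; on every other
# rank the wrapped call is skipped and returns None.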