Coverage for melissa/server/deep_learning/tensorboard/torch_logger.py: 0%

30 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-03-10 22:25 +0100

1from typing import Any, Optional, Union, List 

2from torch.utils.tensorboard import SummaryWriter 

3from melissa.server.deep_learning.tensorboard.base_logger import TensorboardLogger 

4from matplotlib.figure import Figure 

5 

6 

class TorchTensorboardLogger(TensorboardLogger):
    """TensorBoard logger backed by ``torch.utils.tensorboard.SummaryWriter``.

    Each logging method does nothing when the logger was built with
    ``disable=True`` or when ``self.writer`` is ``None``.

    NOTE(review): ``self.disable``, ``self.debug`` and the ``self.writer``
    accessor are not defined here. They presumably come from
    ``TensorboardLogger``, with ``writer`` exposing ``self._writer``. Confirm
    against the base class.
    """

    def __init__(
        self,
        rank: int = 0,
        logdir: str = "tensorboard",
        disable: bool = False,
        debug: bool = False,
    ) -> None:
        """Set up a per-rank ``SummaryWriter`` unless logging is disabled.

        Args:
            rank: Used to name the event directory ``<logdir>/gpu_<rank>``
                and the event-file suffix ``rank_<rank>``.
            logdir: Parent directory for the event files.
            disable: Forwarded to the base class. When ``self.disable`` ends
                up true, no writer is created here.
            debug: Forwarded to the base class. ``log_scalar_dbg`` checks
                ``self.debug``.
        """
        super().__init__(disable, debug)
        if self.disable:
            # Disabled logger: do not create a writer or register a layout.
            return

        self._writer = SummaryWriter(
            f"{logdir}/gpu_{rank}", filename_suffix=f"rank_{rank}"
        )

        # Custom-scalars dashboard. It puts the train/valid loss curves on one
        # multiline chart and the put/get timings on another.
        loss_chart = ["Multiline", ["Loss/train", "Loss/valid"]]
        timing_chart = ["Multiline", ["put_time", "get_time"]]
        self.writer.add_custom_scalars(
            {
                "Server stats": {
                    "loss": loss_chart,
                    "put_get_time": timing_chart,
                },
            }
        )

    def _writer_ready(self) -> bool:
        """Return True when logging is enabled and a writer is available."""
        # ``disable`` is checked first so that ``self.writer`` is never
        # consulted on a disabled logger.
        if self.disable:
            return False
        return self.writer is not None

    def log_scalar(self, tag: str, scalar_value: Any, step: int):
        """Record ``scalar_value`` under ``tag`` at global step ``step``."""
        if not self._writer_ready():
            return
        self.writer.add_scalar(tag, scalar_value, step)

    def log_scalars(self, main_tag: str, tag_scalar_dict: dict, step: int):
        """Record several scalars, keyed by sub-tag, under ``main_tag``."""
        if not self._writer_ready():
            return
        self.writer.add_scalars(main_tag, tag_scalar_dict, step)

    def log_histogram(self, tag: str, values: Any, step: Optional[int] = None):
        """Record a histogram of ``values`` under ``tag``; ``step`` may be None."""
        if not self._writer_ready():
            return
        self.writer.add_histogram(tag, values, step)

    def log_scalar_dbg(self, tag: str, scalar_value: Any, step: int):
        """Like ``log_scalar``, but only writes when ``self.debug`` is truthy."""
        if not (self._writer_ready() and self.debug):
            return
        self.writer.add_scalar(tag, scalar_value, step)

    def log_figure(
        self,
        tag: str,
        figure: Union[Figure, List[Figure]],
        step: Optional[int] = None,
        close: bool = True,
    ):
        """Record one matplotlib figure, or a list of them, under ``tag``.

        ``close`` is passed straight through to ``SummaryWriter.add_figure``.
        """
        if not self._writer_ready():
            return
        self.writer.add_figure(tag, figure, step, close=close)

    def close(self):
        """Flush pending events, then close the underlying writer."""
        if not self._writer_ready():
            return
        self.writer.flush()
        self.writer.close()