Coverage for melissa/server/deep_learning/tensorboard/torch_logger.py: 0%
30 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-03-10 22:25 +0100
« prev ^ index » next coverage.py v7.6.12, created at 2025-03-10 22:25 +0100
1from typing import Any, Optional, Union, List
2from torch.utils.tensorboard import SummaryWriter
3from melissa.server.deep_learning.tensorboard.base_logger import TensorboardLogger
4from matplotlib.figure import Figure
class TorchTensorboardLogger(TensorboardLogger):
    """TensorBoard logger that writes through ``torch.utils.tensorboard.SummaryWriter``.

    Every ``log_*`` method and ``close`` does nothing while ``self.disable`` is
    truthy or ``self.writer`` is ``None``.

    NOTE(review): ``self.disable``, ``self.debug`` and the ``writer`` attribute
    read below are assumed to be provided by ``TensorboardLogger`` (``writer``
    presumably exposes ``self._writer``) — confirm against the base class.
    """

    def __init__(
        self,
        rank: int = 0,
        logdir: str = "tensorboard",
        disable: bool = False,
        debug: bool = False,
    ) -> None:
        """Create the per-rank writer and register the custom-scalars layout.

        Args:
            rank: Rank index; selects the ``gpu_<rank>`` sub-directory of
                ``logdir`` and the ``rank_<rank>`` event-file suffix.
            logdir: Root directory for the TensorBoard event files.
            disable: Forwarded to the base class; when truthy, no
                ``SummaryWriter`` is created here.
            debug: Forwarded to the base class; ``log_scalar_dbg`` only
                writes when ``self.debug`` is truthy.
        """
        super().__init__(disable, debug)
        if self.disable:
            # Disabled: skip creating a SummaryWriter entirely.
            return

        run_dir = f"{logdir}/gpu_{rank}"
        self._writer = SummaryWriter(run_dir, filename_suffix=f"rank_{rank}")

        # Custom-scalars dashboard: overlay the train/valid losses in one
        # chart and the put/get timings in another.
        loss_chart = ["Multiline", ["Loss/train", "Loss/valid"]]
        timing_chart = ["Multiline", ["put_time", "get_time"]]
        self.writer.add_custom_scalars(
            {"Server stats": {"loss": loss_chart, "put_get_time": timing_chart}}
        )

    def log_scalar(self, tag: str, scalar_value: Any, step: int) -> None:
        """Write one scalar ``scalar_value`` under ``tag`` at ``step``."""
        if self.disable or self.writer is None:
            return
        self.writer.add_scalar(tag, scalar_value, step)

    def log_scalars(self, main_tag: str, tag_scalar_dict: dict, step: int) -> None:
        """Write several scalars from ``tag_scalar_dict`` grouped under ``main_tag``."""
        if self.disable or self.writer is None:
            return
        self.writer.add_scalars(main_tag, tag_scalar_dict, step)

    def log_histogram(self, tag: str, values: Any, step: Optional[int] = None) -> None:
        """Write a histogram of ``values`` under ``tag``; ``step`` may be ``None``."""
        if self.disable or self.writer is None:
            return
        self.writer.add_histogram(tag, values, step)

    def log_scalar_dbg(self, tag: str, scalar_value: Any, step: int) -> None:
        """Same as ``log_scalar``, but only writes when ``self.debug`` is truthy."""
        if self.disable or self.writer is None or not self.debug:
            return
        self.writer.add_scalar(tag, scalar_value, step)

    def log_figure(
        self,
        tag: str,
        figure: Union[Figure, List[Figure]],
        step: Optional[int] = None,
        close: bool = True,
    ) -> None:
        """Render a matplotlib figure (or a list of figures) under ``tag``.

        ``close`` is passed straight through to ``SummaryWriter.add_figure``.
        """
        if self.disable or self.writer is None:
            return
        self.writer.add_figure(tag, figure, step, close=close)

    def close(self) -> None:
        """Flush pending events, then close the writer."""
        if self.disable or self.writer is None:
            return
        self.writer.flush()
        self.writer.close()