Coverage for melissa/server/deep_learning/tensorboard_logger.py: 39%
62 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-09-22 10:36 +0000
« prev ^ index » next coverage.py v7.2.7, created at 2023-09-22 10:36 +0000
1from typing import Any, Optional
2try: # double import for CI
3 from torch.utils.tensorboard import SummaryWriter as TorchSummaryWriter
4 from tensorflow.summary import scalar, create_file_writer
5 from tensorflow.summary import SummaryWriter as TfSummaryWriter
6except ModuleNotFoundError:
7 try: # only one framework available
8 from torch.utils.tensorboard import SummaryWriter as TorchSummaryWriter
9 except ModuleNotFoundError:
10 from tensorflow.summary import scalar, create_file_writer
11 from tensorflow.summary import SummaryWriter as TfSummaryWriter
class TensorboardLogger:
    """Common interface for the framework-specific TensorBoard loggers.

    This base class only stores the two configuration flags; every logging
    method raises ``NotImplementedError`` and must be overridden by a
    concrete subclass.

    Args:
        disable: stored as ``self.disable``; the subclasses in this module
            create no writer and skip all logging when it is true.
        debug: stored as ``self.debug``; the subclasses in this module only
            emit ``log_scalar_dbg`` values when it is true.
    """

    def __init__(self, disable: bool = False, debug: bool = False):
        self.disable, self.debug = disable, debug

    def log_scalar(self, tag: str, scalar_value: Any, step: int):
        """Record ``scalar_value`` under ``tag`` at ``step`` (abstract)."""
        raise NotImplementedError()

    def log_scalar_dbg(self, tag: str, scalar_value: Any, step: int):
        """Record a debug-only scalar under ``tag`` at ``step`` (abstract)."""
        raise NotImplementedError()

    def log_histogram(self, tag: str, values: Any, step: Optional[int] = None):
        """Record a histogram of ``values`` under ``tag`` (abstract)."""
        raise NotImplementedError()

    def close(self):
        """Flush pending data and release the underlying writer (abstract)."""
        raise NotImplementedError()
class TorchTensorboardLogger(TensorboardLogger):
    """TensorBoard logger backed by PyTorch's ``SummaryWriter``.

    Events are written under ``<logdir>/gpu_<rank>`` with the event-file
    suffix ``rank_<rank>``. When ``disable`` is true no writer is created
    and every logging method does nothing.
    """

    def __init__(
        self, rank: int, logdir: str = "tensorboard",
        disable: bool = False, debug: bool = False
    ):
        super().__init__(disable, debug)
        self.writer: Optional[TorchSummaryWriter] = None
        if self.disable:
            return
        run_dir = f"{logdir}/gpu_{rank}"
        self.writer = TorchSummaryWriter(run_dir, filename_suffix=f"rank_{rank}")
        # Group related tags into multiline charts on the Custom Scalars tab.
        self.writer.add_custom_scalars({
            "Server stats": {
                "loss": ["Multiline", ["Loss/train", "Loss/valid"]],
                "put_get_time": ["Multiline", ["put_time", "get_time"]],
            },
        })

    def _writer_ready(self) -> bool:
        """Return True when logging is enabled and a writer exists."""
        return not self.disable and self.writer is not None

    def log_scalar(self, tag: str, scalar_value: Any, step: int):
        """Write ``scalar_value`` under ``tag`` at ``step``."""
        if self._writer_ready():
            self.writer.add_scalar(tag, scalar_value, step)

    def log_histogram(self, tag: str, values: Any, step: Optional[int] = None):
        """Write a histogram of ``values`` under ``tag`` at ``step``."""
        if self._writer_ready():
            self.writer.add_histogram(tag, values, step)

    def log_scalar_dbg(self, tag: str, scalar_value: Any, step: int):
        """Write ``scalar_value`` only when the logger was built with ``debug``."""
        if self._writer_ready() and self.debug:
            self.writer.add_scalar(tag, scalar_value, step)

    def close(self):
        """Flush pending events and close the writer; no-op when disabled."""
        if not self._writer_ready():
            return
        self.writer.flush()
        self.writer.close()
class TfTensorboardLogger(TensorboardLogger):
    """TensorBoard logger backed by TensorFlow's ``tf.summary`` API.

    Events are written under ``<logdir>/gpu_<rank>`` with the event-file
    suffix ``rank_<rank>``. When ``disable`` is true no writer is created
    and every logging method does nothing.
    """

    def __init__(
        self, rank: int, logdir: str = "tensorboard",
        disable: bool = False, debug: bool = False
    ):
        super().__init__(disable, debug)
        self.writer: Optional[TfSummaryWriter] = None
        if not self.disable:
            # f-string (as in TorchTensorboardLogger) also accepts path-like logdir.
            self.writer = create_file_writer(
                f"{logdir}/gpu_{rank}", filename_suffix=f"rank_{rank}"
            )

    def log_scalar(self, tag: str, scalar_value: Any, step: int):
        """Write ``scalar_value`` under ``tag`` at ``step``."""
        if not self.disable and self.writer is not None:
            with self.writer.as_default():
                scalar(name=tag, data=scalar_value, step=step)

    def log_scalar_dbg(self, tag: str, scalar_value: Any, step: int):
        """Write ``scalar_value`` only when the logger was built with ``debug``."""
        if not self.disable and self.writer is not None and self.debug:
            with self.writer.as_default():
                scalar(name=tag, data=scalar_value, step=step)

    def log_histogram(self, tag: str, values: Any, step: Optional[int] = None):
        """Write a histogram of ``values`` under ``tag`` at ``step``.

        Without this override the base-class method raised
        ``NotImplementedError``, unlike ``TorchTensorboardLogger`` which
        implements it. When ``step`` is None, TensorFlow falls back to its
        default summary step (``tf.summary.experimental.set_step``), which
        must then have been set.
        """
        if not self.disable and self.writer is not None:
            # Local import: the module-level import only brings in `scalar`
            # and `create_file_writer` from tensorflow.summary.
            from tensorflow.summary import histogram
            with self.writer.as_default():
                histogram(name=tag, data=values, step=step)

    def close(self):
        """Flush pending events and close the writer; no-op when disabled."""
        if not self.disable and self.writer is not None:
            self.writer.flush()
            self.writer.close()