Coverage for melissa/server/deep_learning/tensorboard/tf_logger.py: 0%
59 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-07-09 14:19 +0200
« prev ^ index » next coverage.py v7.6.12, created at 2025-07-09 14:19 +0200
1import io
3import numpy as np
4from typing import Any, Union, List, Optional
5import tensorflow as tf
6import matplotlib.pyplot as plt
7from matplotlib.figure import Figure
8from PIL import Image
10from melissa.server.deep_learning.tensorboard import TensorboardLogger
class TfTensorboardLogger(TensorboardLogger):
    """TensorBoard logger that writes summaries through ``tf.summary``.

    Every logging method is a no-op when the logger is disabled or once
    :meth:`close` has released the underlying file writer.
    """

    def __init__(
        self,
        rank: int = 0,
        logdir: str = "tensorboard",
        disable: bool = False,
        debug: bool = False
    ) -> None:
        """Create the summary writer for this rank unless logging is disabled.

        Args:
            rank: Rank used to build the event sub-directory (``gpu_<rank>``)
                and the event-file suffix (``rank_<rank>``).
            logdir: Root directory that receives the per-rank sub-directory.
            disable: Forwarded to the base class; when the resulting
                ``self.disable`` is true no writer is created.
            debug: Forwarded to the base class; gates :meth:`log_scalar_dbg`.
        """
        super().__init__(disable, debug)
        # Always define the attribute so that the `self._writer is not None`
        # guards below are valid even when no writer is created.
        self._writer = None
        if not self.disable:
            self._writer = tf.summary.create_file_writer(
                f"{logdir}/gpu_{rank}", filename_suffix=f"rank_{rank}"
            )
            # sets the default step, if step=None
            tf.summary.experimental.set_step(-1)

    def log_scalar(self, tag: str, scalar_value: Any, step: int):
        """Write one scalar summary named ``tag`` at ``step``."""
        if not self.disable and self._writer is not None:
            with self._writer.as_default():
                tf.summary.scalar(name=tag, data=scalar_value, step=step)

    def log_scalars(self, main_tag: str, tag_scalar_dict: dict, step: int):
        """Write one scalar summary per dict entry, named ``<main_tag>/<key>``."""
        if not self.disable and self._writer is not None:
            with self._writer.as_default():
                for key, value in tag_scalar_dict.items():
                    tf.summary.scalar(name=f"{main_tag}/{key}", data=value, step=step)

    def log_histogram(self, tag: str, values: Any, step: Optional[int] = None):
        """Write a histogram summary of ``values`` named ``tag``.

        ``step=None`` is passed through to ``tf.summary.histogram`` unchanged.
        """
        if not self.disable and self._writer is not None:
            with self._writer.as_default():
                tf.summary.histogram(name=tag, data=values, step=step)

    def log_scalar_dbg(self, tag: str, scalar_value: Any, step: int):
        """Write a scalar summary only when the logger is in debug mode."""
        if self.debug:
            # Same guards and same write as log_scalar; only the debug gate differs.
            self.log_scalar(tag, scalar_value, step)

    def __figure_to_image(self, figures: Union[Figure, List[Figure]], close: bool = True):
        """Converts matplotlib figure(s) to a TensorFlow image tensor.
        Inspired from `torch.utils.tensorboard._utils.figure_to_image`.

        Args:
            figures: A single figure or a list of figures. A list is combined
                with ``np.stack``, so the rendered images must share one shape.
            close: When true, each figure is closed with ``plt.close`` after it
                has been rendered, including when rendering fails.

        Returns:
            A ``tf.float32`` tensor holding the RGB pixel values divided by
            255, with a leading image axis (a single figure yields one image).
        """

        def render_to_rgb(figure):
            try:
                # The context manager releases the PNG buffer on every path.
                with io.BytesIO() as buf:
                    figure.savefig(buf, format='png')
                    buf.seek(0)
                    # np.array copies the pixels before the buffer is closed.
                    return np.array(Image.open(buf).convert("RGB"))
            finally:
                if close:
                    plt.close(figure)

        if isinstance(figures, list):
            images = np.stack([render_to_rgb(fig) for fig in figures])
        else:
            images = render_to_rgb(figures)

        if images.ndim == 3:  # (HWC)
            images = np.expand_dims(images, axis=0)

        images = tf.convert_to_tensor(images, dtype=tf.float32) / 255.0

        return images

    def log_figure(
        self,
        tag: str,
        figure: Union[Figure, List[Figure]],
        step: Optional[int] = None,
        close: bool = True,
    ):
        """Render matplotlib figure(s) and write them as an image summary.

        Args:
            tag: Name of the image summary.
            figure: A single figure or a list of figures to render.
            step: Passed through to ``tf.summary.image``.
            close: Whether the figure(s) are closed after rendering.
        """
        if not self.disable and self._writer is not None:
            with self._writer.as_default():
                image = self.__figure_to_image(figure, close)
                tf.summary.image(tag, image, step=step)

    def close(self):
        """Flush and close the writer, if one exists; later calls do nothing."""
        # Keyed on the writer alone so an open writer is released even if
        # `disable` was switched on after construction.
        if self._writer is not None:
            self._writer.flush()
            self._writer.close()
            self._writer = None