Coverage for melissa/server/deep_learning/tensorboard/tf_logger.py: 0%

59 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-07-09 14:19 +0200

1import io 

2 

3import numpy as np 

4from typing import Any, Union, List, Optional 

5import tensorflow as tf 

6import matplotlib.pyplot as plt 

7from matplotlib.figure import Figure 

8from PIL import Image 

9 

10from melissa.server.deep_learning.tensorboard import TensorboardLogger 

11 

12 

class TfTensorboardLogger(TensorboardLogger):
    """TensorBoard logger backed by ``tf.summary``.

    Event files are written to ``{logdir}/gpu_{rank}`` with the file-name
    suffix ``rank_{rank}``. Every logging method is a no-op when the logger
    is disabled or after :meth:`close` has been called.
    """

    def __init__(
        self,
        rank: int = 0,
        logdir: str = "tensorboard",
        disable: bool = False,
        debug: bool = False,
    ) -> None:
        """Create the logger and, unless disabled, its TF summary writer.

        Args:
            rank: Rank of the owning process; used to build the output
                sub-directory and the event-file suffix.
            logdir: Root directory for the TensorBoard event files.
            disable: When True, no writer is created and nothing is logged.
            debug: When True, :meth:`log_scalar_dbg` also logs.
        """
        # NOTE(review): `self.disable` / `self.debug` are assumed to be set by
        # TensorboardLogger.__init__ — confirm against the base class.
        super().__init__(disable, debug)
        if self.disable:
            # Always define the attribute so every method can test it safely.
            self._writer = None
        else:
            self._writer = tf.summary.create_file_writer(
                f"{logdir}/gpu_{rank}", filename_suffix=f"rank_{rank}"
            )
            # Default step used by tf.summary.* calls that receive step=None.
            tf.summary.experimental.set_step(-1)

    def _writer_ready(self) -> bool:
        """Return True when logging is enabled and the writer is still open."""
        return not self.disable and self._writer is not None

    def log_scalar(self, tag: str, scalar_value: Any, step: int):
        """Log a single scalar ``scalar_value`` under ``tag`` at ``step``."""
        if self._writer_ready():
            with self._writer.as_default():
                tf.summary.scalar(name=tag, data=scalar_value, step=step)

    def log_scalars(self, main_tag: str, tag_scalar_dict: dict, step: int):
        """Log each ``key: value`` pair of ``tag_scalar_dict`` as a scalar.

        Each value is written under the tag ``f"{main_tag}/{key}"``.
        """
        if self._writer_ready():
            with self._writer.as_default():
                for key, value in tag_scalar_dict.items():
                    tf.summary.scalar(
                        name=f"{main_tag}/{key}", data=value, step=step
                    )

    def log_histogram(self, tag: str, values: Any, step: Optional[int] = None):
        """Log a histogram of ``values`` under ``tag``.

        When ``step`` is None, ``tf.summary`` falls back to its default step.
        """
        if self._writer_ready():
            with self._writer.as_default():
                tf.summary.histogram(name=tag, data=values, step=step)

    def log_scalar_dbg(self, tag: str, scalar_value: Any, step: int):
        """Like :meth:`log_scalar`, but only logs when debug mode is on."""
        if self._writer_ready() and self.debug:
            with self._writer.as_default():
                tf.summary.scalar(name=tag, data=scalar_value, step=step)

    def __figure_to_image(
        self, figures: Union[Figure, List[Figure]], close: bool = True
    ):
        """Convert matplotlib figure(s) into a float32 image batch tensor.

        Inspired by ``torch.utils.tensorboard._utils.figure_to_image``.

        Args:
            figures: A single figure or a non-empty list of figures. Figures
                in a list are stacked into one array, so they must all render
                to the same pixel size.
            close: When True, each figure is closed with ``plt.close`` after
                it has been rendered.

        Returns:
            A ``tf.float32`` tensor of shape ``(N, H, W, 3)`` with values
            scaled to ``[0, 1]``.

        Raises:
            ValueError: If ``figures`` is an empty list.
        """

        def render_to_rgb(figure: Figure) -> np.ndarray:
            # Render to an in-memory PNG, then decode it back as an RGB array.
            try:
                with io.BytesIO() as buf:
                    figure.savefig(buf, format="png")
                    buf.seek(0)
                    with Image.open(buf) as image:
                        # convert() forces the pixel data to load while
                        # the buffer is still open.
                        return np.array(image.convert("RGB"))
            finally:
                if close:
                    plt.close(figure)

        if isinstance(figures, list):
            if not figures:
                raise ValueError("`figures` must contain at least one figure.")
            images = np.stack([render_to_rgb(fig) for fig in figures])
        else:
            # Single (H, W, C) image: add a leading batch axis -> (1, H, W, C).
            images = render_to_rgb(figures)[np.newaxis, ...]

        return tf.convert_to_tensor(
            images.astype(np.float32) / 255.0, dtype=tf.float32
        )

    def log_figure(
        self,
        tag: str,
        figure: Union[Figure, List[Figure]],
        step: Optional[int] = None,
        close: bool = True,
    ):
        """Render matplotlib figure(s) and log them as images under ``tag``.

        Args:
            tag: TensorBoard tag for the image summary.
            figure: A single figure or a list of equally sized figures.
            step: Step to record; None uses the ``tf.summary`` default step.
            close: Close the figure(s) afterwards. This is honoured even when
                the logger is disabled, so figures are not leaked.
        """
        if not self._writer_ready():
            if close:
                for fig in figure if isinstance(figure, list) else [figure]:
                    plt.close(fig)
            return
        # Rendering does not need the writer context; only the summary call does.
        image = self.__figure_to_image(figure, close)
        with self._writer.as_default():
            tf.summary.image(tag, image, step=step)

    def close(self):
        """Flush and close the writer; later logging calls become no-ops."""
        if self._writer_ready():
            writer, self._writer = self._writer, None
            try:
                writer.flush()
            finally:
                # Close the writer even if flushing fails.
                writer.close()