Coverage for melissa/server/deep_learning/tensorboard_logger.py: 39%

62 statements  

coverage.py v7.2.7, created at 2023-09-22 10:36 +0000

from typing import Any, Optional

try:  # double import for CI
    from torch.utils.tensorboard import SummaryWriter as TorchSummaryWriter
    from tensorflow.summary import scalar, create_file_writer
    from tensorflow.summary import SummaryWriter as TfSummaryWriter
except ModuleNotFoundError:
    try:  # only one framework available
        from torch.utils.tensorboard import SummaryWriter as TorchSummaryWriter
    except ModuleNotFoundError:
        from tensorflow.summary import scalar, create_file_writer
        from tensorflow.summary import SummaryWriter as TfSummaryWriter
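# After this block, whichever framework is installed has its symbols bound:
# torch-only environments get TorchSummaryWriter, tensorflow-only environments
# get scalar/create_file_writer/TfSummaryWriter, and CI runs with both get all.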

class TensorboardLogger:
    """Base interface for per-rank tensorboard loggers.

    `disable` short-circuits all logging; `debug` gates the *_dbg methods.
    """

    def __init__(self, disable: bool = False, debug: bool = False):
        self.disable = disable
        self.debug = debug

    def log_scalar(self, tag: str, scalar_value: Any, step: int):
        """Logs a scalar to the tensorboard logger."""
        raise NotImplementedError

    def log_scalar_dbg(self, tag: str, scalar_value: Any, step: int):
        """Logs a debugging-related scalar to the tensorboard logger."""
        raise NotImplementedError

    def log_histogram(self, tag: str, values: Any, step: Optional[int] = None):
        """Logs a histogram to the tensorboard logger."""
        raise NotImplementedError

    def close(self):
        """Flushes and closes the tensorboard logger."""
        raise NotImplementedError


class TorchTensorboardLogger(TensorboardLogger):
    """PyTorch-backed implementation built on torch.utils.tensorboard."""

    def __init__(
        self, rank: int, logdir: str = "tensorboard",
        disable: bool = False, debug: bool = False
    ):
        super().__init__(disable, debug)
        self.writer: Optional[TorchSummaryWriter] = None
        if not self.disable:
            self.writer = TorchSummaryWriter(
                f"{logdir}/gpu_{rank}", filename_suffix=f"rank_{rank}"
            )
            # Group related charts under "Server stats" in TensorBoard's
            # CUSTOM SCALARS tab.
            layout = {
                "Server stats": {
                    "loss": ["Multiline", ["Loss/train", "Loss/valid"]],
                    "put_get_time": ["Multiline", ["put_time", "get_time"]],
                },
            }
            self.writer.add_custom_scalars(layout)

    def log_scalar(self, tag: str, scalar_value: Any, step: int):
        if not self.disable and self.writer is not None:
            self.writer.add_scalar(tag, scalar_value, step)

    def log_histogram(self, tag: str, values: Any, step: Optional[int] = None):
        if not self.disable and self.writer is not None:
            self.writer.add_histogram(tag, values, step)

    def log_scalar_dbg(self, tag: str, scalar_value: Any, step: int):
        if not self.disable and self.writer is not None and self.debug:
            self.writer.add_scalar(tag, scalar_value, step)

    def close(self):
        if not self.disable and self.writer is not None:
            self.writer.flush()
            self.writer.close()
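A minimal usage sketch for the class above (illustrative, not part of the module; the rank, loop, and values are assumptions, though "Loss/train" and "put_time" match the tags registered in the layout):

    logger = TorchTensorboardLogger(rank=0, logdir="tensorboard", debug=True)
    for step in range(100):
        logger.log_scalar("Loss/train", 1.0 / (step + 1), step)
        logger.log_scalar_dbg("put_time", 0.01, step)  # written only because debug=True
    logger.close()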


class TfTensorboardLogger(TensorboardLogger):
    """TensorFlow-backed implementation built on tensorflow.summary.

    Note: log_histogram is not overridden here, so it raises
    NotImplementedError from the base class.
    """

    def __init__(
        self, rank: int, logdir: str = "tensorboard",
        disable: bool = False, debug: bool = False
    ):
        super().__init__(disable, debug)
        self.writer: Optional[TfSummaryWriter] = None
        if not self.disable:
            self.writer = create_file_writer(
                f"{logdir}/gpu_{rank}", filename_suffix=f"rank_{rank}"
            )

    def log_scalar(self, tag: str, scalar_value: Any, step: int):
        if not self.disable and self.writer is not None:
            with self.writer.as_default():
                scalar(name=tag, data=scalar_value, step=step)

    def log_scalar_dbg(self, tag: str, scalar_value: Any, step: int):
        if not self.disable and self.writer is not None and self.debug:
            with self.writer.as_default():
                scalar(name=tag, data=scalar_value, step=step)

    def close(self):
        if not self.disable and self.writer is not None:
            self.writer.flush()
            self.writer.close()
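
A sketch of how a caller might select a concrete logger based on which framework imported at the top of the module (the make_tb_logger helper and the globals() check are illustrative assumptions, not part of this module):

    def make_tb_logger(rank: int, **kwargs) -> TensorboardLogger:
        # TorchSummaryWriter is only bound when torch imported successfully.
        if "TorchSummaryWriter" in globals():
            return TorchTensorboardLogger(rank, **kwargs)
        return TfTensorboardLogger(rank, **kwargs)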