Coverage for melissa/server/deep_learning/tensorboard/__init__.py: 17%

23 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-03-10 22:25 +0100

1from melissa.server.deep_learning.tensorboard.base_logger import ( 

2 TensorboardLogger, 

3 convert_tb_logs_to_df 

4) 

5from melissa.server.deep_learning import FrameworkType 

6 

7 

# Names exported by `from melissa.server.deep_learning.tensorboard import *`.
__all__ = [
    "TensorboardLogger",
    "convert_tb_logs_to_df",
    "make_tb_logger",
]

13 

14 

def make_tb_logger(framework_t: FrameworkType,
                   rank: int = 0,
                   logdir: str = "tensorboard",
                   disable: bool = False,
                   debug: bool = False) -> TensorboardLogger:
    """Factory function to create a TensorBoard logger based on the specified deep learning
    framework.

    ### Parameters
    - **framework_t** (`FrameworkType`): The type of framework (`TORCH`, `TENSORFLOW`,
    or `DEFAULT`). With `DEFAULT`, the Torch logger is preferred and the TensorFlow
    logger is used only if the Torch logger module cannot be imported.
    - **rank** (`int`, optional): Rank of the process (used for distributed training).
    Defaults to `0`.
    - **logdir** (`str`, optional): Directory where TensorBoard logs are stored.
    Defaults to `"tensorboard"`.
    - **disable** (`bool`, optional): If `True`, disables logging. Defaults to `False`.
    - **debug** (`bool`, optional): If `True`, enables debug mode for the logger.
    Defaults to `False`.

    ### Returns
    - `TensorboardLogger`: An instance of the appropriate TensorBoard logger
    (`TorchTensorboardLogger` or `TfTensorboardLogger`).

    ### Raises
    - `ModuleNotFoundError`: If `TORCH` or `TENSORFLOW` is selected and that logger's
    module cannot be imported, or if `DEFAULT` is selected and neither the PyTorch
    nor the TensorFlow logger module can be imported (in which case the last import
    error is attached as `__cause__`).
    - `ValueError`: If an unsupported framework type is provided."""

    # Explicit backends: import errors are deliberately left to propagate so the
    # caller learns that the framework they asked for is unavailable.
    if framework_t is FrameworkType.TORCH:
        from melissa.server.deep_learning.tensorboard.torch_logger import (
            TorchTensorboardLogger,
        )
        return TorchTensorboardLogger(rank, logdir, disable, debug)

    if framework_t is FrameworkType.TENSORFLOW:
        from melissa.server.deep_learning.tensorboard.tf_logger import (
            TfTensorboardLogger,
        )
        return TfTensorboardLogger(rank, logdir, disable, debug)

    if framework_t is FrameworkType.DEFAULT:
        # Only the import is guarded: a ModuleNotFoundError raised while
        # *constructing* a logger is a real failure, not "backend not installed",
        # and must not silently trigger the fallback.
        last_error = None

        try:
            from melissa.server.deep_learning.tensorboard.torch_logger import (
                TorchTensorboardLogger,
            )
        except ModuleNotFoundError as err:
            last_error = err  # keep it so the final error explains what was missing
        else:
            return TorchTensorboardLogger(rank, logdir, disable, debug)

        try:
            from melissa.server.deep_learning.tensorboard.tf_logger import (
                TfTensorboardLogger,
            )
        except ModuleNotFoundError as err:
            last_error = err
        else:
            return TfTensorboardLogger(rank, logdir, disable, debug)

        raise ModuleNotFoundError(
            "Neither Torch nor TensorFlow TensorBoard loggers are available."
        ) from last_error

    raise ValueError(f"Unsupported framework type: {framework_t}")