Coverage for melissa/server/deep_learning/dataset.py: 69%

64 statements  

coverage.py v7.2.7, created at 2023-09-22 10:36 +0000

from typing import Union, List, Dict, Any, Callable, Optional
import logging
import time

try:  # double import for unit tests and CI
    from torch.utils.data.dataset import IterableDataset
    import tensorflow as tf
except ModuleNotFoundError:  # only one framework available
    try:
        from torch.utils.data.dataset import IterableDataset
    except ModuleNotFoundError:
        import tensorflow as tf

        class IterableDataset:  # type: ignore
            pass

from melissa.server.simulation import SimulationData
from melissa.server.deep_learning.reservoir import FIFO, Empty
from melissa.server.deep_learning.tensorboard_logger import TensorboardLogger


logger = logging.getLogger(__name__)


class MelissaIterableDataset:
    """
    Framework-agnostic iterable dataset which streams SimulationData pulled
    from the server's FIFO buffer, applies an optional transform, and logs
    timing and buffer-size metrics to TensorBoard when a logger is provided.
    """
    def __init__(
        self,
        buffer: FIFO,
        config: dict = {},
        transform: Optional[Callable] = None,
        tb_logger: Optional[TensorboardLogger] = None,
    ) -> None:
        self.buffer: FIFO = buffer
        self.tb_logger = tb_logger
        self._is_receiving: bool = True
        self.batch = 0
        self.sample_number = 0
        self.config: Dict[str, Any] = config
        self.transform = transform

    @property
    def has_data(self) -> bool:
        return self._is_receiving or not self.buffer.empty()

    def signal_reception_over(self):
        self._is_receiving = False
        self.buffer.signal_reception_over()

    def __iter__(self):
        # Keep pulling from the buffer for as long as it is non-empty
        # or the server is still receiving data.
        while self.has_data:
            try:
                start_get = time.time()
                items: Union[SimulationData, List[SimulationData]] = self.buffer.get(timeout=1)
                end_get = time.time()
                start_processing = time.time()
                if self.transform:
                    data = self.transform(items, self.config)
                elif isinstance(items, list):
                    data = [item.data for item in items]
                else:
                    data = items.data
                end_processing = time.time()
                if self.tb_logger:
                    self.sample_number += 1
                    self.tb_logger.log_scalar("get_time", end_get - start_get, self.sample_number)
                    self.tb_logger.log_scalar(
                        "processing_time", end_processing - start_processing, self.sample_number
                    )
                    self.tb_logger.log_scalar("buffer_size", len(self.buffer), self.sample_number)
                yield data
            except Empty:
                logger.warning("Buffer empty but still receiving.")
                continue
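# --- Illustrative sketch (not part of dataset.py) --------------------------
# The iterator above calls `transform(items, self.config)`, where `items` is
# either a single SimulationData or a list of them. A user-supplied transform
# could therefore look like the hypothetical function below; only the `.data`
# attribute is assumed here, because it is the one read by the default branch
# of `__iter__`, and the float32 conversion is an arbitrary example choice.
import numpy as np


def example_transform(items, config):
    """Hypothetical transform: stack raw simulation fields into one array."""
    if isinstance(items, list):
        return np.stack([np.asarray(item.data, dtype=np.float32) for item in items])
    return np.asarray(items.data, dtype=np.float32)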

class TorchMelissaIterableDataset(MelissaIterableDataset, IterableDataset):
    """
    Object whose only job is to make an iterable dataset available to torch.
    Only used by Melissa-DL children.
    """
    def __init__(
        self,
        buffer: FIFO,
        config: dict = {},
        transform: Optional[Callable] = None,
        tb_logger: Optional[TensorboardLogger] = None,
    ) -> None:
        MelissaIterableDataset.__init__(self, buffer, config, transform, tb_logger)
        IterableDataset.__init__(self)
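# --- Illustrative sketch (not part of dataset.py) --------------------------
# A hedged example of consuming the torch variant: wrap it in a standard
# DataLoader. In the real server the buffer and transform come from the
# training code path; `my_buffer` and `example_transform` are placeholders.
# `batch_size=None` passes samples through unbatched, which avoids assuming
# anything about how `data` should be collated.
def example_torch_loop(my_buffer: FIFO):
    from torch.utils.data import DataLoader

    dataset = TorchMelissaIterableDataset(my_buffer, transform=example_transform)
    loader = DataLoader(dataset, batch_size=None)
    for sample in loader:
        ...  # training step goes here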

class TfMelissaIterableDataset(MelissaIterableDataset):
    """
    Object whose only job is to make an iterable dataset available to tensorflow.
    Only used by Melissa-DL children.
    """
    def __init__(
        self,
        buffer: FIFO,
        config: dict = {},
        transform: Optional[Callable] = None,
        tb_logger: Optional[TensorboardLogger] = None,
    ) -> None:
        MelissaIterableDataset.__init__(self, buffer, config, transform, tb_logger)

    def as_tensorflow_dataset(self):
        dataset = tf.data.Dataset.from_generator(
            self.__iter__,
            output_types=(tf.float32, tf.float32),
            output_shapes=((None,), (None,)),
        )
        return dataset
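# --- Illustrative sketch (not part of dataset.py) --------------------------
# A hedged example of consuming the TensorFlow variant. Note that
# `as_tensorflow_dataset()` declares (tf.float32, tf.float32) outputs with
# 1-D shapes, so the transform attached to the dataset is expected to yield
# (input, target) pairs matching that signature; `my_buffer` and
# `pair_transform` are placeholders, not names from the Melissa code base.
def example_tf_loop(my_buffer: FIFO, pair_transform: Callable):
    dataset = TfMelissaIterableDataset(my_buffer, transform=pair_transform)
    tf_dataset = dataset.as_tensorflow_dataset().batch(8)
    for x_batch, y_batch in tf_dataset:
        ...  # training step goes here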