Coverage for melissa/utility/external_validator.py: 0%

97 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2025-11-03 09:52 +0100

1import os 

2import time 

3import hashlib 

4import argparse 

5import logging 

6from abc import ABC, abstractmethod 

7from typing import Any, Optional 

8from dataclasses import dataclass 

9 

10 

logger = logging.getLogger("melissa")


class BaseExternalValidator(ABC):
    """Poll a checkpoint directory and validate each new checkpoint.

    Subclasses provide :meth:`filter_files`, :meth:`load_checkpoint` and
    :meth:`validation_entrypoint`. :meth:`run_loop` repeatedly picks the most
    recently modified checkpoint file. When its content hash differs from the
    last one validated, the loop loads it and runs validation on it. The loop
    ends as soon as :meth:`should_stop` reports a stop condition.

    Args:
        checkpoint_dir: Directory scanned (non-recursively) for checkpoint files.
        poll_interval: Seconds to sleep after a poll that found nothing new.
        stop_file: Path whose existence stops the loop; an empty string
            disables this check.
        max_ckpt_age: Stop once the newest checkpoint's modification time is
            more than this many seconds in the past.
        max_iterations: Stop after this many validation runs.
    """

    # Bytes read per chunk when hashing; a large chunk keeps the number of
    # read calls low for big checkpoint files.
    _HASH_CHUNK_SIZE = 1 << 20

    def __init__(
        self,
        checkpoint_dir: str,
        poll_interval: float,
        stop_file: str,
        max_ckpt_age: float,
        max_iterations: int,
    ):
        self.checkpoint_dir = checkpoint_dir
        self.poll_interval = poll_interval
        self.stop_file = stop_file
        self.max_ckpt_age = max_ckpt_age
        self.max_iterations = max_iterations
        # Hash of the last checkpoint handed to validation (None before the first).
        self.last_hash: Optional[str] = None
        # Number of validation runs performed so far.
        self.iterations = 0

    @abstractmethod
    def filter_files(self, files: list[str]) -> list[str]:
        """Return the subset of ``files`` (full paths) to treat as checkpoints."""
        pass

    @abstractmethod
    def load_checkpoint(self, checkpoint_path: str) -> Any:
        """Load the checkpoint at ``checkpoint_path``; the result is passed to
        :meth:`validation_entrypoint`."""
        pass

    @abstractmethod
    def validation_entrypoint(self, ckpt: Any) -> Any:
        """Validate ``ckpt`` and return metrics; a truthy result is logged."""
        pass

    def get_latest_checkpoint_path(self) -> Optional[str]:
        """Return the most recently modified candidate checkpoint, or None.

        Returns None when ``checkpoint_dir`` is not an existing directory or
        when no file survives :meth:`filter_files`. Files that disappear
        between listing and ``stat`` (e.g. rotated away) are skipped rather
        than raising. On mtime ties the first listed file wins.
        """
        if not os.path.isdir(self.checkpoint_dir):
            return None

        files = [
            os.path.join(self.checkpoint_dir, f)
            for f in os.listdir(self.checkpoint_dir)
            if os.path.isfile(os.path.join(self.checkpoint_dir, f))
        ]
        files = self.filter_files(files)

        latest: Optional[str] = None
        latest_mtime: Optional[float] = None
        for path in files:
            try:
                mtime = os.path.getmtime(path)
            except OSError:
                # Removed since it was listed; not a candidate anymore.
                continue
            if latest_mtime is None or mtime > latest_mtime:
                latest, latest_mtime = path, mtime
        return latest

    def file_hash(self, path: str) -> str:
        """Return the hex MD5 digest of the file at ``path``.

        MD5 serves only as a change detector here, so it is created with
        ``usedforsecurity=False``. Without that flag, FIPS-restricted
        builds reject it.
        """
        h = hashlib.md5(usedforsecurity=False)
        with open(path, "rb") as f:
            while chunk := f.read(self._HASH_CHUNK_SIZE):
                h.update(chunk)
        return h.hexdigest()

    def should_stop(self, ckpt_path: Optional[str]) -> bool:
        """Return True when the validation loop should terminate.

        Stops when the stop file exists, when ``ckpt_path`` is older than
        ``max_ckpt_age`` seconds, or when ``max_iterations`` validation runs
        have completed. A ``ckpt_path`` that no longer exists skips the age
        check instead of raising.
        """
        if self.stop_file and os.path.exists(self.stop_file):
            logger.info("[Validator] Stop file found.")
            return True
        if ckpt_path:
            try:
                age: Optional[float] = time.time() - os.path.getmtime(ckpt_path)
            except OSError:
                # Checkpoint vanished after listing; there is no age to judge.
                age = None
            if age is not None and age > self.max_ckpt_age:
                logger.info("[Validator] Checkpoint too old.")
                return True
        if self.iterations >= self.max_iterations:
            logger.info("[Validator] Max iterations reached.")
            return True
        return False

    def run_loop(self):
        """Poll for new checkpoints and validate each one until told to stop.

        Each poll validates the newest checkpoint if its content hash changed
        since the last validation. Otherwise the poll sleeps
        ``poll_interval`` seconds before polling again. This covers both "no
        checkpoint yet" and "checkpoint unchanged", so the loop never
        busy-waits.
        """
        logger.info("[Validator] Loop started.")
        while True:
            ckpt_path = self.get_latest_checkpoint_path()
            if self.should_stop(ckpt_path):
                break

            validated = False
            if ckpt_path:
                try:
                    ckpt_hash: Optional[str] = self.file_hash(ckpt_path)
                except OSError:
                    # Removed or unreadable since listing; retry on the next poll.
                    ckpt_hash = None
                if ckpt_hash is not None and ckpt_hash != self.last_hash:
                    self.last_hash = ckpt_hash
                    logger.info(f"[Validator] New checkpoint: {ckpt_path}")
                    ckpt = self.load_checkpoint(ckpt_path)
                    metrics = self.validation_entrypoint(ckpt)
                    if metrics:
                        logger.info(f"[Validator] Metrics: {metrics}")
                    self.iterations += 1
                    validated = True

            if not validated:
                # Nothing new this round: wait instead of spinning.
                time.sleep(self.poll_interval)
        logger.info(f"[Validator] Loop ended with total iterations={self.iterations}.")

97 

98 

@dataclass(frozen=True)
class Defaults:
    """Fallback values for the external validator's command-line options.

    The parser reads these as class attributes. The dataclass is frozen, so
    an instance cannot be modified once created.
    """

    # Directory polled for checkpoint files.
    checkpoint_dir: str = "./checkpoints"
    # Seconds between polls.
    poll_interval: float = 10.0
    # Creating this file asks the validation loop to stop.
    stop_file: str = "./STOP_VALIDATION"
    # Maximum allowed age (seconds) of the newest checkpoint.
    max_ckpt_age: float = 600.0
    # Upper bound on validation runs.
    max_iterations: int = 50

106 

107 

def get_default_parser():
    """Build the argument parser for the external validation script.

    Each option's default comes from ``Defaults``. ``--help`` shows what
    the option means and its default value. The option names match the
    keyword arguments the script passes to the validator constructor.

    Returns:
        argparse.ArgumentParser: the configured parser.
    """
    parser = argparse.ArgumentParser(
        description="Run validation loop on latest model checkpoint.",
        # Append "(default: ...)" to each option's help text.
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--checkpoint_dir",
        type=str,
        default=Defaults.checkpoint_dir,
        help="Directory polled for checkpoint files.",
    )
    parser.add_argument(
        "--poll_interval",
        type=float,
        default=Defaults.poll_interval,
        help="Seconds to wait between polls when no new checkpoint is found.",
    )
    parser.add_argument(
        "--stop_file",
        type=str,
        default=Defaults.stop_file,
        help="Validation stops once this file exists (empty string disables the check).",
    )
    parser.add_argument(
        "--max_ckpt_age",
        type=float,
        default=Defaults.max_ckpt_age,
        help="Stop when the newest checkpoint was last modified more than this many seconds ago.",
    )
    parser.add_argument(
        "--max_iterations",
        type=int,
        default=Defaults.max_iterations,
        help="Maximum number of validation runs before stopping.",
    )

    return parser

119 

120 

class ExternalValidator(BaseExternalValidator):
    """Minimal demonstration validator that only prints its progress.

    Every listed file is accepted as a checkpoint. Nothing is actually
    loaded, and each validation reports the constant metric ``1``.
    """

    def filter_files(self, files: list[str]) -> list[str]:
        # Accept every candidate unchanged (same list object is returned).
        return files

    def load_checkpoint(self, checkpoint_path: str) -> Any:
        # Placeholder loader: announce the path and hand back nothing.
        print(f"Loaded checkpoint from {checkpoint_path}")
        return None

    def validation_entrypoint(self, ckpt: Any) -> int:
        # Placeholder validation: report progress and a constant metric.
        print(f"Running {self.iterations} validation phase...")
        return 1

131 

132 

if __name__ == "__main__":
    # Logging is not configured anywhere in this script. Under Python's
    # default WARNING threshold, the validator's logger.info progress and
    # stop messages would be dropped. basicConfig is a no-op if the root
    # logger already has handlers, so an existing configuration is respected.
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s %(levelname)s %(name)s: %(message)s",
    )

    parser = get_default_parser()
    args = parser.parse_args()

    validator = ExternalValidator(
        checkpoint_dir=args.checkpoint_dir,
        poll_interval=args.poll_interval,
        stop_file=args.stop_file,
        max_ckpt_age=args.max_ckpt_age,
        max_iterations=args.max_iterations,
    )
    validator.run_loop()