Coverage for melissa/utility/external_validator.py: 0%
97 statements
« prev ^ index » next coverage.py v7.10.1, created at 2025-11-03 09:52 +0100
« prev ^ index » next coverage.py v7.10.1, created at 2025-11-03 09:52 +0100
1import os
2import time
3import hashlib
4import argparse
5import logging
6from abc import ABC, abstractmethod
7from typing import Any, Optional
8from dataclasses import dataclass
# Shared "melissa" logger used for every validator status message.
logger = logging.getLogger("melissa")


class BaseExternalValidator(ABC):
    """Poll a directory for new checkpoints and validate each new one.

    Subclasses provide three hooks: ``filter_files`` (which files count as
    checkpoints), ``load_checkpoint`` and ``validation_entrypoint``.
    ``run_loop`` drives them until ``should_stop`` reports a stop condition.
    """

    def __init__(
        self,
        checkpoint_dir: str,
        poll_interval: float,
        stop_file: str,
        max_ckpt_age: float,
        max_iterations: int,
    ):
        """Store the polling configuration.

        Args:
            checkpoint_dir: Directory scanned for checkpoint files.
            poll_interval: Seconds to sleep between polls when idle.
            stop_file: Path whose existence ends the loop; a falsy value
                disables this check.
            max_ckpt_age: Maximum age in seconds (by mtime) of the latest
                checkpoint before the loop gives up.
            max_iterations: Number of validations after which the loop ends.
        """
        self.checkpoint_dir = checkpoint_dir
        self.poll_interval = poll_interval
        self.stop_file = stop_file
        self.max_ckpt_age = max_ckpt_age
        self.max_iterations = max_iterations
        # Digest of the last validated checkpoint; None until the first one.
        self.last_hash: Optional[str] = None
        # Number of checkpoints validated so far.
        self.iterations = 0

    @abstractmethod
    def filter_files(self, files: list[str]) -> list[str]:
        """Return the subset of ``files`` that are candidate checkpoints."""

    @abstractmethod
    def load_checkpoint(self, checkpoint_path: str) -> Any:
        """Load the checkpoint stored at ``checkpoint_path``."""

    @abstractmethod
    def validation_entrypoint(self, ckpt: Any) -> Any:
        """Validate the loaded checkpoint; a truthy result is logged as metrics."""

    def get_latest_checkpoint_path(self) -> Optional[str]:
        """Return the most recently modified checkpoint file, if any.

        Returns:
            The path with the greatest mtime among the files accepted by
            ``filter_files``, or ``None`` when the directory is missing, is
            not a directory, or holds no usable checkpoint.
        """
        # isdir (not exists) so a regular file at this path is treated as
        # "no checkpoints" instead of making listdir raise.
        if not os.path.isdir(self.checkpoint_dir):
            return None
        try:
            names = os.listdir(self.checkpoint_dir)
        except (FileNotFoundError, NotADirectoryError):
            # The directory vanished between the check and the listing.
            return None

        paths = (os.path.join(self.checkpoint_dir, name) for name in names)
        files = self.filter_files([p for p in paths if os.path.isfile(p)])

        latest: Optional[str] = None
        latest_mtime: Optional[float] = None
        for path in files:
            try:
                mtime = os.path.getmtime(path)
            except FileNotFoundError:
                # Removed (e.g. rotated) after it was listed; ignore it.
                continue
            # Strict ">" keeps the first file on ties, like max() would.
            if latest_mtime is None or mtime > latest_mtime:
                latest, latest_mtime = path, mtime
        return latest

    def file_hash(self, path: str) -> str:
        """Return the hex MD5 digest of the file at ``path``.

        The digest is only used to detect that a checkpoint changed, so the
        hash is flagged as non-security use; this keeps it available on
        builds that restrict MD5.
        """
        h = hashlib.md5(usedforsecurity=False)
        with open(path, "rb") as f:
            while chunk := f.read(8192):
                h.update(chunk)
        return h.hexdigest()

    def should_stop(self, ckpt_path: Optional[str]) -> bool:
        """Tell whether the polling loop must end.

        Args:
            ckpt_path: Latest checkpoint path, or ``None`` if there is none.

        Returns:
            ``True`` when the stop file exists, the checkpoint is older than
            ``max_ckpt_age`` seconds, or ``max_iterations`` was reached.
        """
        if self.stop_file and os.path.exists(self.stop_file):
            logger.info("[Validator] Stop file found.")
            return True
        if ckpt_path:
            try:
                age = time.time() - os.path.getmtime(ckpt_path)
            except FileNotFoundError:
                # Checkpoint disappeared after it was selected: its age is
                # unknown, so it cannot be "too old".
                age = None
            if age is not None and age > self.max_ckpt_age:
                logger.info("[Validator] Checkpoint too old.")
                return True
        if self.iterations >= self.max_iterations:
            logger.info("[Validator] Max iterations reached.")
            return True
        return False

    def _validate_if_new(self, ckpt_path: Optional[str]) -> bool:
        """Validate ``ckpt_path`` if its content changed; return whether it ran."""
        if not ckpt_path:
            return False
        try:
            ckpt_hash = self.file_hash(ckpt_path)
        except FileNotFoundError:
            # Checkpoint was removed before it could be read.
            return False
        if ckpt_hash == self.last_hash:
            return False

        self.last_hash = ckpt_hash
        logger.info("[Validator] New checkpoint: %s", ckpt_path)
        ckpt = self.load_checkpoint(ckpt_path)
        metrics = self.validation_entrypoint(ckpt)
        if metrics:
            logger.info("[Validator] Metrics: %s", metrics)
        self.iterations += 1
        return True

    def run_loop(self):
        """Poll for checkpoints and validate new ones until told to stop.

        After a validation the next poll happens immediately; every idle
        poll (no checkpoint, or an unchanged one) sleeps ``poll_interval``
        seconds so the loop never spins on the filesystem.
        """
        logger.info("[Validator] Loop started.")
        while True:
            ckpt_path = self.get_latest_checkpoint_path()
            if self.should_stop(ckpt_path):
                break
            if not self._validate_if_new(ckpt_path):
                time.sleep(self.poll_interval)
        logger.info(
            "[Validator] Loop ended with total iterations=%s.", self.iterations
        )
@dataclass(frozen=True)
class Defaults:
    """Immutable fallback values for the validator command-line options."""

    # Directory scanned for checkpoint files.
    checkpoint_dir: str = "./checkpoints"
    # Seconds between two polls.
    poll_interval: float = 10.0
    # Path whose existence stops the validation loop.
    stop_file: str = "./STOP_VALIDATION"
    # Maximum checkpoint age, in seconds.
    max_ckpt_age: float = 600.0
    # Maximum number of validation iterations.
    max_iterations: int = 50
def get_default_parser():
    """Build the argument parser for the validation loop.

    Returns:
        An ``argparse.ArgumentParser`` exposing ``--checkpoint_dir``,
        ``--poll_interval``, ``--stop_file``, ``--max_ckpt_age`` and
        ``--max_iterations``, each defaulting to the matching ``Defaults``
        attribute.
    """
    cli = argparse.ArgumentParser(
        description="Run validation loop on latest model checkpoint."
    )
    # (flag, converter, fallback) for every supported option, in CLI order.
    option_table = (
        ("--checkpoint_dir", str, Defaults.checkpoint_dir),
        ("--poll_interval", float, Defaults.poll_interval),
        ("--stop_file", str, Defaults.stop_file),
        ("--max_ckpt_age", float, Defaults.max_ckpt_age),
        ("--max_iterations", int, Defaults.max_iterations),
    )
    for flag, converter, fallback in option_table:
        cli.add_argument(flag, type=converter, default=fallback)

    return cli
class ExternalValidator(BaseExternalValidator):
    """Minimal concrete validator that only reports what it is asked to do."""

    def filter_files(self, files: list[str]) -> list[str]:
        """Accept every file: the list is returned unchanged."""
        return files

    def load_checkpoint(self, checkpoint_path: str) -> Any:
        """Print the checkpoint path; nothing is read and ``None`` is returned."""
        message = f"Loaded checkpoint from {checkpoint_path}"
        print(message)
        return None

    def validation_entrypoint(self, ckpt: Any) -> int:
        """Print the current iteration count and return the constant ``1``."""
        message = f"Running {self.iterations} validation phase..."
        print(message)
        return 1
if __name__ == "__main__":
    # Script entry point: parse the CLI options, then run the example
    # validator until one of its stop conditions is met.
    cli_args = get_default_parser().parse_args()

    validator = ExternalValidator(
        checkpoint_dir=cli_args.checkpoint_dir,
        poll_interval=cli_args.poll_interval,
        stop_file=cli_args.stop_file,
        max_ckpt_age=cli_args.max_ckpt_age,
        max_iterations=cli_args.max_iterations,
    )
    validator.run_loop()