diff --git a/examples/llama-eval/llama-eval-new.py b/examples/llama-eval/llama-eval-new.py index a27ed4a37c..1026ecee44 100755 --- a/examples/llama-eval/llama-eval-new.py +++ b/examples/llama-eval/llama-eval-new.py @@ -3,6 +3,8 @@ import argparse import json import os +import re +import subprocess import time from dataclasses import dataclass, asdict from pathlib import Path @@ -13,6 +15,16 @@ from tqdm import tqdm cache_dir = Path.home() / ".cache" / "huggingface" / "datasets" cache_dir.mkdir(parents=True, exist_ok=True) os.environ["HF_DATASETS_CACHE"] = str(cache_dir) +os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1" + +GRADER_PATTERNS = { + "aime": r'\boxed{(\d+)}|\b(\d+)\b', + "gsm8k": r'\b(\d+)\b', + "mmlu": r'[A-D]', + "hellaswag": r'[A-D]', + "arc": r'[A-D]', + "winogrande": r'[A-D]', +} @dataclass class EvalState: @@ -50,19 +62,85 @@ class AimeDataset: def get_answer(self, question: Dict) -> str: return str(question["answer"]) +class Grader: + def __init__( + self, + grader_type: str = "regex", + grader_regex_type: str = "aime", + grader_script: Optional[str] = None + ): + self.grader_type = grader_type + self.grader_regex_type = grader_regex_type + self.grader_script = grader_script + self.pattern = self._get_pattern() + + def _get_pattern(self) -> str: + if self.grader_type == "regex": + if self.grader_regex_type not in GRADER_PATTERNS: + raise ValueError(f"Unknown grader regex type: {self.grader_regex_type}") + return GRADER_PATTERNS[self.grader_regex_type] + return None + + def _grade_regex(self, gold: str, pred: str) -> bool: + """Grade using regex pattern matching""" + matches = re.findall(self.pattern, pred, re.IGNORECASE) + if not matches: + return False + + for match in matches: + if isinstance(match, tuple): + match = match[0] if match[0] else match[1] + if match.strip() == gold.strip(): + return True + + return False + + def _grade_cli(self, gold: str, pred: str) -> bool: + """Grade using external CLI script""" + if not self.grader_script: + raise ValueError("CLI grader requires --grader-script") + + script_path = Path(self.grader_script) + if not script_path.exists(): + raise FileNotFoundError(f"Grader script not found: {self.grader_script}") + + try: + result = subprocess.run( + [str(script_path), "--answer", pred, "--expected", gold], + capture_output=True, + text=True, + timeout=30 + ) + return result.returncode == 0 + except subprocess.TimeoutExpired: + return False + except Exception as e: + return False + + def grade(self, gold: str, pred: str) -> bool: + """Grade the response""" + if self.grader_type == "regex": + return self._grade_regex(gold, pred) + elif self.grader_type == "cli": + return self._grade_cli(gold, pred) + else: + raise ValueError(f"Unknown grader type: {self.grader_type}") + class Processor: def __init__( self, server_url: str, n_predict: int = 2048, threads: int = 32, - verbose: bool = False + verbose: bool = False, + grader: Optional[Grader] = None ): self.server_url = server_url self.n_predict = n_predict self.threads = threads self.verbose = verbose self.dataset = AimeDataset() + self.grader = grader or Grader() self.eval_state = EvalState( id="aime-2025", tasks=["aime"], @@ -85,15 +163,6 @@ class Processor: response.raise_for_status() return response.json() - def _grade_response(self, gold: str, pred: str) -> bool: - """Grade the response - abstracted for external grader support""" - try: - gold_int = int(gold) - pred_int = int(pred) - return gold_int == pred_int - except (ValueError, TypeError): - return False - def process(self, n_cases: int = None, seed: int = 42): """Process cases and update eval state""" if n_cases is None: @@ -125,7 +194,7 @@ class Processor: response = self._make_request(prompt) pred = response["choices"][0]["message"]["content"] task_state.pred = pred - task_state.correct = self._grade_response(gold, pred) + task_state.correct = self.grader.grade(gold, pred) task_state.status = "ok" if task_state.correct: @@ -200,14 +269,41 @@ def main(): default=Path("llama-eval-state.json"), help="Output file for eval state (default: llama-eval-state.json)" ) + parser.add_argument( + "--grader-type", + type=str, + default="regex", + choices=["regex", "cli"], + help="Grader type: regex or cli (default: regex)" + ) + parser.add_argument( + "--grader-regex-type", + type=str, + default="aime", + choices=list(GRADER_PATTERNS.keys()), + help="Regex grader type (default: aime)" + ) + parser.add_argument( + "--grader-script", + type=str, + default=None, + help="CLI grader script path (required for --grader-type cli)" + ) args = parser.parse_args() + grader = Grader( + grader_type=args.grader_type, + grader_regex_type=args.grader_regex_type, + grader_script=args.grader_script + ) + processor = Processor( server_url=args.server, n_predict=args.n_predict, threads=args.threads, - verbose=args.verbose + verbose=args.verbose, + grader=grader ) eval_state = processor.process(n_cases=args.n_cases) diff --git a/examples/llama-eval/test-grader.py b/examples/llama-eval/test-grader.py new file mode 100755 index 0000000000..c32901cf70 --- /dev/null +++ b/examples/llama-eval/test-grader.py @@ -0,0 +1,26 @@ +#!/usr/bin/env python3 + +import sys +import argparse + +def main(): + parser = argparse.ArgumentParser(description="Test grader script") + parser.add_argument("--answer", type=str, required=True, help="Predicted answer") + parser.add_argument("--expected", type=str, required=True, help="Expected answer") + args = parser.parse_args() + + pred = args.answer.strip() + gold = args.expected.strip() + + print(f"Gold: {gold}") + print(f"Pred: {pred}") + + if pred == gold: + print("Correct!") + sys.exit(0) + else: + print("Incorrect") + sys.exit(1) + +if __name__ == "__main__": + main()