examples: implement flexible grader system for answer validation

- Add Grader class supporting regex and CLI-based grading
- Implement built-in regex patterns for AIME, GSM8K, MMLU, HellaSwag, ARC, WinoGrande
- Add CLI grader interface: python script.py --answer <pred> --expected <gold>
- Add HF telemetry disable to avoid warnings
- Support exact match requirement for regex patterns
- Add 30-second timeout for CLI grader
- Handle both boxed and plain text formats for AIME answers
This commit is contained in:
Georgi Gerganov 2026-01-31 16:31:46 +02:00
parent a80814e97b
commit 5a1be6ce37
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
2 changed files with 134 additions and 12 deletions

View File

@ -3,6 +3,8 @@
import argparse
import json
import os
import re
import subprocess
import time
from dataclasses import dataclass, asdict
from pathlib import Path
@ -13,6 +15,16 @@ from tqdm import tqdm
# Keep Hugging Face dataset downloads in a fixed per-user cache directory
# (~/.cache/huggingface/datasets), creating it on first use.
cache_dir = Path.home() / ".cache" / "huggingface" / "datasets"
cache_dir.mkdir(parents=True, exist_ok=True)
# Exported through the process environment so the Hugging Face libraries see
# it. NOTE(review): presumably this has to run before `datasets` is imported
# -- confirm against the import order at the top of the file.
os.environ["HF_DATASETS_CACHE"] = str(cache_dir)
# Opt out of Hugging Face Hub telemetry.
os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1"
# Answer-extraction regexes, keyed by the benchmark name accepted by
# --grader-regex-type. Every capture group holds one candidate answer; the
# grader compares each candidate against the gold answer.
GRADER_PATTERNS = {
    # A literal LaTeX "\boxed{N}" or a bare integer. The backslash and the
    # braces must be escaped: an unescaped "\b" is a word-boundary assertion,
    # so the boxed alternative would never match the text "\boxed{...}".
    "aime": r'\\boxed\{(\d+)\}|\b(\d+)\b',
    "gsm8k": r'\b(\d+)\b',
    # Multiple-choice benchmarks: a stand-alone option letter. The word
    # boundaries keep letters that are part of ordinary words (e.g. the "A"
    # in "Answer") from being picked up as candidate answers.
    "mmlu": r'\b([A-D])\b',
    "hellaswag": r'\b([A-D])\b',
    "arc": r'\b([A-D])\b',
    "winogrande": r'\b([A-D])\b',
}
@dataclass
class EvalState:
@ -50,19 +62,85 @@ class AimeDataset:
def get_answer(self, question: Dict) -> str:
    """Return the value stored under ``question["answer"]`` as a string."""
    answer = question["answer"]
    return str(answer)
class Grader:
    """Decide whether a model response contains the expected answer.

    Two strategies are available, selected by ``grader_type``:

    * ``"regex"`` -- extract candidate answers from the response with the
      ``GRADER_PATTERNS`` entry named by ``grader_regex_type`` and accept the
      response if any candidate equals the gold answer.
    * ``"cli"`` -- run ``grader_script --answer <pred> --expected <gold>`` and
      accept the response if the script exits with status 0.
    """

    # Upper bound, in seconds, on a single run of the external grader script.
    CLI_TIMEOUT_SECONDS = 30

    def __init__(
        self,
        grader_type: str = "regex",
        grader_regex_type: str = "aime",
        grader_script: Optional[str] = None
    ):
        """Store the configuration and resolve the regex pattern.

        Raises:
            ValueError: ``grader_type`` is ``"regex"`` and
                ``grader_regex_type`` is not a key of ``GRADER_PATTERNS``.
        """
        self.grader_type = grader_type
        self.grader_regex_type = grader_regex_type
        self.grader_script = grader_script
        self.pattern = self._get_pattern()

    def _get_pattern(self) -> Optional[str]:
        """Return the regex for the configured type, or None for non-regex graders."""
        if self.grader_type != "regex":
            return None
        if self.grader_regex_type not in GRADER_PATTERNS:
            raise ValueError(f"Unknown grader regex type: {self.grader_regex_type}")
        return GRADER_PATTERNS[self.grader_regex_type]

    @staticmethod
    def _answers_equal(gold: str, candidate: str) -> bool:
        """Compare two answers exactly, falling back to integer equality.

        The integer fallback makes "007" equal to "7"; non-numeric answers
        such as option letters are compared as plain strings only.
        """
        if gold == candidate:
            return True
        try:
            return int(gold) == int(candidate)
        except ValueError:
            return False

    def _grade_regex(self, gold: str, pred: str) -> bool:
        """Return True if any answer extracted from ``pred`` equals ``gold``."""
        expected = gold.strip()
        for match in re.findall(self.pattern, pred, re.IGNORECASE):
            if isinstance(match, tuple):
                # Patterns with alternatives yield one slot per capture
                # group; only the alternative that matched is non-empty.
                match = next((group for group in match if group), "")
            if self._answers_equal(expected, match.strip()):
                return True
        return False

    def _grade_cli(self, gold: str, pred: str) -> bool:
        """Return True if the external grader script exits with status 0.

        A timeout or a failure to launch the script counts as an incorrect
        answer, but is reported on stderr so that a broken grader setup does
        not silently look like a model that is always wrong.

        Raises:
            ValueError: no grader script was configured.
            FileNotFoundError: the configured script does not exist.
        """
        import sys  # local import: only needed for the stderr diagnostics

        if not self.grader_script:
            raise ValueError("CLI grader requires --grader-script")

        script_path = Path(self.grader_script)
        if not script_path.exists():
            raise FileNotFoundError(f"Grader script not found: {self.grader_script}")

        # Resolve to an absolute path: a bare relative name such as
        # "grader.py" would otherwise be searched for on PATH instead of in
        # the current directory.
        command = [str(script_path.resolve()), "--answer", pred, "--expected", gold]
        try:
            result = subprocess.run(
                command,
                capture_output=True,
                text=True,
                timeout=self.CLI_TIMEOUT_SECONDS
            )
        except subprocess.TimeoutExpired:
            print(
                f"warning: grader script timed out after {self.CLI_TIMEOUT_SECONDS}s: {script_path}",
                file=sys.stderr
            )
            return False
        except Exception as exc:
            # Deliberately broad: grading is best-effort and must not abort
            # the evaluation run, but the cause has to be visible.
            print(f"warning: grader script failed to run ({script_path}): {exc}", file=sys.stderr)
            return False
        return result.returncode == 0

    def grade(self, gold: str, pred: str) -> bool:
        """Grade ``pred`` against ``gold`` with the configured strategy.

        Raises:
            ValueError: ``grader_type`` is neither ``"regex"`` nor ``"cli"``.
        """
        if self.grader_type == "regex":
            return self._grade_regex(gold, pred)
        elif self.grader_type == "cli":
            return self._grade_cli(gold, pred)
        else:
            raise ValueError(f"Unknown grader type: {self.grader_type}")
class Processor:
def __init__(
self,
server_url: str,
n_predict: int = 2048,
threads: int = 32,
verbose: bool = False
verbose: bool = False,
grader: Optional[Grader] = None
):
self.server_url = server_url
self.n_predict = n_predict
self.threads = threads
self.verbose = verbose
self.dataset = AimeDataset()
self.grader = grader or Grader()
self.eval_state = EvalState(
id="aime-2025",
tasks=["aime"],
@ -85,15 +163,6 @@ class Processor:
response.raise_for_status()
return response.json()
def _grade_response(self, gold: str, pred: str) -> bool:
"""Grade the response - abstracted for external grader support"""
try:
gold_int = int(gold)
pred_int = int(pred)
return gold_int == pred_int
except (ValueError, TypeError):
return False
def process(self, n_cases: int = None, seed: int = 42):
"""Process cases and update eval state"""
if n_cases is None:
@ -125,7 +194,7 @@ class Processor:
response = self._make_request(prompt)
pred = response["choices"][0]["message"]["content"]
task_state.pred = pred
task_state.correct = self._grade_response(gold, pred)
task_state.correct = self.grader.grade(gold, pred)
task_state.status = "ok"
if task_state.correct:
@ -200,14 +269,41 @@ def main():
default=Path("llama-eval-state.json"),
help="Output file for eval state (default: llama-eval-state.json)"
)
parser.add_argument(
"--grader-type",
type=str,
default="regex",
choices=["regex", "cli"],
help="Grader type: regex or cli (default: regex)"
)
parser.add_argument(
"--grader-regex-type",
type=str,
default="aime",
choices=list(GRADER_PATTERNS.keys()),
help="Regex grader type (default: aime)"
)
parser.add_argument(
"--grader-script",
type=str,
default=None,
help="CLI grader script path (required for --grader-type cli)"
)
args = parser.parse_args()
grader = Grader(
grader_type=args.grader_type,
grader_regex_type=args.grader_regex_type,
grader_script=args.grader_script
)
processor = Processor(
server_url=args.server,
n_predict=args.n_predict,
threads=args.threads,
verbose=args.verbose
verbose=args.verbose,
grader=grader
)
eval_state = processor.process(n_cases=args.n_cases)

View File

@ -0,0 +1,26 @@
#!/usr/bin/env python3
import sys
import argparse
def main():
    """Compare --answer with --expected and report the result via the exit code.

    Both values are stripped of surrounding whitespace, echoed to stdout, and
    compared as strings. The process exits with status 0 when they are equal
    and 1 otherwise.
    """
    parser = argparse.ArgumentParser(description="Test grader script")
    options = (
        ("--answer", "Predicted answer"),
        ("--expected", "Expected answer"),
    )
    for flag, help_text in options:
        parser.add_argument(flag, type=str, required=True, help=help_text)
    args = parser.parse_args()

    gold = args.expected.strip()
    pred = args.answer.strip()

    print(f"Gold: {gold}")
    print(f"Pred: {pred}")

    is_correct = pred == gold
    print("Correct!" if is_correct else "Incorrect")
    sys.exit(0 if is_correct else 1)


if __name__ == "__main__":
    main()