examples: implement flexible grader system for answer validation
- Add Grader class supporting regex and CLI-based grading
- Implement built-in regex patterns for AIME, GSM8K, MMLU, HellaSwag, ARC, WinoGrande
- Add CLI grader interface: python script.py --answer <pred> --expected <gold>
- Add HF telemetry disable to avoid warnings
- Support exact match requirement for regex patterns
- Add 30-second timeout for CLI grader
- Handle both boxed and plain text formats for AIME answers
This commit is contained in:
parent
a80814e97b
commit
5a1be6ce37
|
|
@ -3,6 +3,8 @@
|
|||
import argparse
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import subprocess
|
||||
import time
|
||||
from dataclasses import dataclass, asdict
|
||||
from pathlib import Path
|
||||
|
|
@ -13,6 +15,16 @@ from tqdm import tqdm
|
|||
# Keep the Hugging Face datasets cache in a fixed per-user location and
# create it up front so later dataset loads find an existing directory.
cache_dir = Path.home().joinpath(".cache", "huggingface", "datasets")
cache_dir.mkdir(parents=True, exist_ok=True)

# Point the datasets cache at that directory and switch off Hugging Face Hub
# telemetry for this process.
os.environ.update(
    {
        "HF_DATASETS_CACHE": str(cache_dir),
        "HF_HUB_DISABLE_TELEMETRY": "1",
    }
)
|
||||
|
||||
# Built-in answer-extraction patterns, keyed by benchmark name.
#
# Each pattern is applied to the raw model response with ``re.findall``; the
# extracted candidates are then compared with the gold answer.
GRADER_PATTERNS = {
    # Either a LaTeX ``\boxed{N}`` answer (group 1) or a bare integer
    # (group 2).  The backslash and the braces are escaped: an unescaped
    # ``\b`` is a word-boundary assertion, so ``\boxed`` would never match
    # the literal text ``\boxed``.
    "aime": r'\\boxed\{(\d+)\}|\b(\d+)\b',
    "gsm8k": r'\b(\d+)\b',
    # Multiple-choice letters must stand alone; without the word boundaries
    # the "B" in "Because" would be picked up as an answer.
    "mmlu": r'\b[A-D]\b',
    "hellaswag": r'\b[A-D]\b',
    "arc": r'\b[A-D]\b',
    "winogrande": r'\b[A-D]\b',
}
|
||||
|
||||
@dataclass
|
||||
class EvalState:
|
||||
|
|
@ -50,19 +62,85 @@ class AimeDataset:
|
|||
def get_answer(self, question: Dict) -> str:
    """Return the value stored under ``question["answer"]`` as a string."""
    answer = question["answer"]
    return str(answer)
|
||||
|
||||
class Grader:
    """Decide whether a model response matches the expected (gold) answer.

    Two strategies are available, selected by ``grader_type``:

    * ``"regex"`` - extract candidate answers from the response with one of
      the built-in ``GRADER_PATTERNS`` and compare each with the gold answer.
    * ``"cli"`` - run an external script as
      ``<script> --answer <pred> --expected <gold>``; exit status 0 means
      the answer is correct.
    """

    # Seconds an external grader script may run before it is abandoned.
    CLI_TIMEOUT_SECONDS = 30

    def __init__(
        self,
        grader_type: str = "regex",
        grader_regex_type: str = "aime",
        grader_script: Optional[str] = None
    ):
        """Configure the grader.

        Args:
            grader_type: ``"regex"`` or ``"cli"``.
            grader_regex_type: Key into ``GRADER_PATTERNS``; only consulted
                when ``grader_type`` is ``"regex"``.
            grader_script: Path of the external grader; needed by the
                ``"cli"`` grader when an answer is graded.

        Raises:
            ValueError: ``grader_type`` is ``"regex"`` and
                ``grader_regex_type`` is not a known pattern name.
        """
        self.grader_type = grader_type
        self.grader_regex_type = grader_regex_type
        self.grader_script = grader_script
        self.pattern = self._get_pattern()

    def _get_pattern(self) -> Optional[str]:
        """Return the regex for the configured type, or None for non-regex graders."""
        if self.grader_type == "regex":
            if self.grader_regex_type not in GRADER_PATTERNS:
                raise ValueError(f"Unknown grader regex type: {self.grader_regex_type}")
            return GRADER_PATTERNS[self.grader_regex_type]
        return None

    @staticmethod
    def _answers_equal(candidate: str, gold: str) -> bool:
        """Compare two answers, ignoring surrounding whitespace.

        Purely numeric answers are also compared with leading zeros
        removed, so ``"073"`` equals ``"73"``.
        """
        candidate = candidate.strip()
        gold = gold.strip()
        if candidate == gold:
            return True
        if candidate.isdecimal() and gold.isdecimal():
            # Strip zeros instead of calling int() so arbitrarily long digit
            # runs in a model response cannot raise.
            return (candidate.lstrip("0") or "0") == (gold.lstrip("0") or "0")
        return False

    def _grade_regex(self, gold: str, pred: str) -> bool:
        """Return True when any candidate extracted from *pred* equals *gold*."""
        for match in re.findall(self.pattern, pred, re.IGNORECASE):
            if isinstance(match, tuple):
                # Patterns with several alternatives yield one tuple per
                # match; the answer is whichever group actually matched.
                match = next((group for group in match if group), "")
            if self._answers_equal(match, gold):
                return True
        return False

    def _grade_cli(self, gold: str, pred: str) -> bool:
        """Grade by running the external script; exit status 0 means correct.

        Returns False when the script exits non-zero, exceeds
        ``CLI_TIMEOUT_SECONDS``, or cannot be run at all (the last case
        also emits a ``RuntimeWarning`` so the failure is visible).

        Raises:
            ValueError: no grader script was configured.
            FileNotFoundError: the configured script does not exist.
        """
        import warnings

        if not self.grader_script:
            raise ValueError("CLI grader requires --grader-script")

        script_path = Path(self.grader_script)
        if not script_path.exists():
            raise FileNotFoundError(f"Grader script not found: {self.grader_script}")

        try:
            result = subprocess.run(
                [str(script_path), "--answer", pred, "--expected", gold],
                capture_output=True,
                text=True,
                timeout=self.CLI_TIMEOUT_SECONDS
            )
        except subprocess.TimeoutExpired:
            return False
        except (OSError, ValueError, subprocess.SubprocessError) as exc:
            # Still graded as incorrect, but say why: a script that cannot
            # be executed would otherwise silently fail every case.
            warnings.warn(
                f"Grader script {self.grader_script} could not be run: {exc}",
                RuntimeWarning,
                stacklevel=2,
            )
            return False
        return result.returncode == 0

    def grade(self, gold: str, pred: str) -> bool:
        """Grade *pred* against *gold* with the configured strategy.

        Raises:
            ValueError: ``grader_type`` is neither ``"regex"`` nor ``"cli"``.
        """
        if self.grader_type == "regex":
            return self._grade_regex(gold, pred)
        if self.grader_type == "cli":
            return self._grade_cli(gold, pred)
        raise ValueError(f"Unknown grader type: {self.grader_type}")
|
||||
|
||||
class Processor:
|
||||
def __init__(
|
||||
self,
|
||||
server_url: str,
|
||||
n_predict: int = 2048,
|
||||
threads: int = 32,
|
||||
verbose: bool = False
|
||||
verbose: bool = False,
|
||||
grader: Optional[Grader] = None
|
||||
):
|
||||
self.server_url = server_url
|
||||
self.n_predict = n_predict
|
||||
self.threads = threads
|
||||
self.verbose = verbose
|
||||
self.dataset = AimeDataset()
|
||||
self.grader = grader or Grader()
|
||||
self.eval_state = EvalState(
|
||||
id="aime-2025",
|
||||
tasks=["aime"],
|
||||
|
|
@ -85,15 +163,6 @@ class Processor:
|
|||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
def _grade_response(self, gold: str, pred: str) -> bool:
|
||||
"""Grade the response - abstracted for external grader support"""
|
||||
try:
|
||||
gold_int = int(gold)
|
||||
pred_int = int(pred)
|
||||
return gold_int == pred_int
|
||||
except (ValueError, TypeError):
|
||||
return False
|
||||
|
||||
def process(self, n_cases: int = None, seed: int = 42):
|
||||
"""Process cases and update eval state"""
|
||||
if n_cases is None:
|
||||
|
|
@ -125,7 +194,7 @@ class Processor:
|
|||
response = self._make_request(prompt)
|
||||
pred = response["choices"][0]["message"]["content"]
|
||||
task_state.pred = pred
|
||||
task_state.correct = self._grade_response(gold, pred)
|
||||
task_state.correct = self.grader.grade(gold, pred)
|
||||
task_state.status = "ok"
|
||||
|
||||
if task_state.correct:
|
||||
|
|
@ -200,14 +269,41 @@ def main():
|
|||
default=Path("llama-eval-state.json"),
|
||||
help="Output file for eval state (default: llama-eval-state.json)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--grader-type",
|
||||
type=str,
|
||||
default="regex",
|
||||
choices=["regex", "cli"],
|
||||
help="Grader type: regex or cli (default: regex)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--grader-regex-type",
|
||||
type=str,
|
||||
default="aime",
|
||||
choices=list(GRADER_PATTERNS.keys()),
|
||||
help="Regex grader type (default: aime)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--grader-script",
|
||||
type=str,
|
||||
default=None,
|
||||
help="CLI grader script path (required for --grader-type cli)"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
grader = Grader(
|
||||
grader_type=args.grader_type,
|
||||
grader_regex_type=args.grader_regex_type,
|
||||
grader_script=args.grader_script
|
||||
)
|
||||
|
||||
processor = Processor(
|
||||
server_url=args.server,
|
||||
n_predict=args.n_predict,
|
||||
threads=args.threads,
|
||||
verbose=args.verbose
|
||||
verbose=args.verbose,
|
||||
grader=grader
|
||||
)
|
||||
|
||||
eval_state = processor.process(n_cases=args.n_cases)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,26 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
import sys
|
||||
import argparse
|
||||
|
||||
def main():
    """Compare ``--answer`` with ``--expected`` and report via the exit status.

    Both values are stripped of surrounding whitespace before comparison.
    The process exits with 0 when they are equal and 1 otherwise.
    """
    cli = argparse.ArgumentParser(description="Test grader script")
    for flag, description in (
        ("--answer", "Predicted answer"),
        ("--expected", "Expected answer"),
    ):
        cli.add_argument(flag, type=str, required=True, help=description)
    options = cli.parse_args()

    predicted = options.answer.strip()
    expected = options.expected.strip()

    print(f"Gold: {expected}")
    print(f"Pred: {predicted}")

    is_match = predicted == expected
    print("Correct!" if is_match else "Incorrect")
    sys.exit(0 if is_match else 1)


if __name__ == "__main__":
    main()
|
||||
Loading…
Reference in New Issue