improve grader

This commit is contained in:
Georgi Gerganov 2026-02-15 21:50:45 +02:00
parent 68dde884d6
commit d2b10302ce
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
1 changed file with 113 additions and 21 deletions

View File

@ -9,7 +9,7 @@ import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass, asdict
from pathlib import Path
from typing import Dict, List, Optional, Any
from typing import Dict, List, Optional, Any, Tuple
import requests
from tqdm import tqdm
import random
@ -47,6 +47,7 @@ class TaskState:
prompt: str
gold: str
pred: Optional[str] = None
extracted: Optional[str] = None
correct: bool = False
status: str = "pending"
@ -97,35 +98,49 @@ class Grader:
self,
grader_type: str = "regex",
grader_regex_type: str = "aime",
grader_script: Optional[str] = None
grader_script: Optional[str] = None,
judge_model_name: Optional[str] = None,
judge_server_url: str = ""
):
self.grader_type = grader_type
self.grader_regex_type = grader_regex_type
self.grader_script = grader_script
self.judge_model_name = judge_model_name
self.judge_server_url = judge_server_url
self.pattern = self._get_pattern()
def _get_pattern(self) -> str:
def _get_pattern(self) -> Optional[str]:
if self.grader_type == "regex":
if self.grader_regex_type not in GRADER_PATTERNS:
raise ValueError(f"Unknown grader regex type: {self.grader_regex_type}")
return GRADER_PATTERNS[self.grader_regex_type]
return None
def _grade_regex(self, gold: str, pred: str) -> bool:
"""Grade using regex pattern matching"""
def _extract_answer_regex(self, pred: str) -> Optional[str]:
"""Extract answer using regex pattern"""
if not self.pattern:
return None
matches = re.findall(self.pattern, pred, re.IGNORECASE)
if not matches:
return False
return None
for match in matches:
if isinstance(match, tuple):
match = match[0] if match[0] else match[1]
if match.strip() == gold.strip():
return True
extracted = match.strip()
if extracted:
return extracted
return None
return False
def _grade_regex(self, gold: str, pred: str) -> Tuple[bool, Optional[str]]:
"""Grade using regex pattern matching"""
extracted = self._extract_answer_regex(pred)
if extracted is None:
return False, None
is_correct = extracted.strip() == gold.strip()
return is_correct, extracted
def _grade_cli(self, gold: str, pred: str) -> bool:
def _grade_cli(self, gold: str, pred: str) -> Tuple[bool, Optional[str]]:
"""Grade using external CLI script"""
if not self.grader_script:
raise ValueError("CLI grader requires --grader-script")
@ -141,18 +156,54 @@ class Grader:
text=True,
timeout=30
)
return result.returncode == 0
is_correct = result.returncode == 0
extracted = pred if is_correct else None
return is_correct, extracted
except subprocess.TimeoutExpired:
return False
return False, None
except Exception as e:
return False
return False, None
def grade(self, gold: str, pred: str) -> bool:
def _grade_llm(self, gold: str, pred: str, problem: str) -> Tuple[bool, Optional[str]]:
    """Grade using LLM-based extraction.

    Sends the raw model response to the judge server's chat-completions
    endpoint, asks it to extract just the final answer, and compares the
    extraction case-insensitively against the gold answer.

    Args:
        gold: Expected answer string.
        pred: Model response text to extract the answer from.
        problem: Original problem statement (not used in the current
            extraction prompt; kept for interface symmetry with other
            graders).

    Returns:
        Tuple of (is_correct, extracted_answer); (False, None) when the
        judge request fails or its reply cannot be parsed.
    """
    prompt = f"""Extract the answer from this response:
Response: {pred}
Expected answer: {gold}
Please provide only the extracted answer, nothing else."""
    url = f"{self.judge_server_url}/v1/chat/completions"
    headers = {"Content-Type": "application/json"}
    data = {
        "model": self.judge_model_name,
        "messages": [{"role": "user", "content": prompt}],
        "temperature": 0,
        "max_tokens": 256
    }
    try:
        # Bound the request so a stuck judge server cannot hang a worker
        # thread forever (the CLI grader path is similarly bounded).
        response = requests.post(url, headers=headers, json=data, timeout=120)
        response.raise_for_status()
        extracted = response.json()["choices"][0]["message"]["content"].strip()
        # `extracted` is already stripped above; compare case-insensitively.
        is_correct = extracted.lower() == gold.strip().lower()
        return is_correct, extracted
    except Exception:
        # Best-effort: any judge failure (network, HTTP error, malformed
        # JSON payload) counts as incorrect rather than crashing the run.
        return False, None
def _truncate_response(self, response: str, max_lines: int = 3) -> str:
"""Keep only last N lines of response"""
lines = response.split('\n')
return '\n'.join(lines[-max_lines:]) if len(lines) > max_lines else response
def grade(self, gold: str, pred: str, problem: str = "") -> Tuple[bool, Optional[str]]:
    """Route grading to the backend selected by ``grader_type``.

    Delegates to the regex, CLI, or LLM grader and returns its
    ``(is_correct, extracted_answer)`` tuple.

    Raises:
        ValueError: if ``grader_type`` is not a recognized backend.
    """
    grader_type = self.grader_type
    if grader_type == "llm":
        return self._grade_llm(gold, pred, problem)
    if grader_type == "cli":
        return self._grade_cli(gold, pred)
    if grader_type == "regex":
        return self._grade_regex(gold, pred)
    raise ValueError(f"Unknown grader type: {self.grader_type}")
@ -164,13 +215,17 @@ class Processor:
threads: int = 32,
verbose: bool = False,
grader: Optional[Grader] = None,
model_name: Optional[str] = None
model_name: Optional[str] = None,
judge_server_url: str = "",
judge_model_name: Optional[str] = None
):
self.server_url = server_url
self.n_predict = n_predict
self.threads = threads
self.verbose = verbose
self.model_name = model_name
self.judge_server_url = judge_server_url if judge_server_url else server_url
self.judge_model_name = judge_model_name
self.dataset = AimeDataset()
self.grader = grader or Grader()
self.eval_state = EvalState(
@ -180,6 +235,13 @@ class Processor:
sampling_config={"temperature": 0, "max_tokens": n_predict}
)
# Pass judge configuration to grader if using LLM grader
if self.grader.grader_type == "llm":
if self.judge_model_name:
self.grader.judge_model_name = self.judge_model_name
if self.judge_server_url:
self.grader.judge_server_url = self.judge_server_url
def _make_request(self, prompt: str) -> Dict[str, Any]:
"""Make HTTP request to the server"""
url = f"{self.server_url}/v1/chat/completions"
@ -217,7 +279,14 @@ class Processor:
response = self._make_request(prompt)
pred = response["choices"][0]["message"]["content"]
task_state.pred = pred
task_state.correct = self.grader.grade(gold, pred)
# Truncate response to last 2-3 lines for grading
pred_truncated = self.grader._truncate_response(pred, max_lines=3)
# Grade the response
is_correct, extracted = self.grader.grade(gold, pred_truncated, prompt)
task_state.correct = is_correct
task_state.extracted = extracted
task_state.status = "ok"
except Exception as e:
task_state.status = f"error: {str(e)}"
@ -233,6 +302,10 @@ class Processor:
print(f"Server: {self.server_url}")
print(f"Threads: {self.threads}")
print(f"Max tokens: {self.n_predict}")
print(f"Grader: {self.grader.grader_type}", end="")
if self.grader.grader_type == "llm":
print(f" (judge server: {self.judge_server_url}, model: {self.judge_model_name})", end="")
print()
print()
dataset_size = len(self.dataset.questions)
@ -276,15 +349,17 @@ class Processor:
correct += 1
# Print task completion status
pred_display = task_state.pred if task_state.pred else "N/A"
extracted_display = task_state.extracted if task_state.extracted else "N/A"
success_ratio = correct / total if total > 0 else 0.0
print(f"{total:3}/{n_cases:3} {task_state.case_id:<15} AIME2025 {task_state.prompt[:40]:<40} {task_state.gold:<10} {pred_display:<10} {'' if task_state.correct else ''} [{correct:3}/{total:3}, {success_ratio:.3f}]")
print(f"{total:3}/{n_cases:3} {task_state.case_id:<15} AIME2025 {task_state.prompt[:40]:<40} {task_state.gold:<10} {extracted_display:<10} {'' if task_state.correct else ''} [{correct:3}/{total:3}, {success_ratio:.3f}]")
if self.verbose:
print(f"\nCase {total}: {task_state.correct}")
print(f" Gold: {task_state.gold}")
if task_state.pred:
print(f" Pred: {task_state.pred}")
if task_state.extracted:
print(f" Extracted: {task_state.extracted}")
print(f" Status: {task_state.status}")
self.eval_state.task_states["aime"] = {
@ -360,8 +435,8 @@ def main():
"--grader-type",
type=str,
default="regex",
choices=["regex", "cli"],
help="Grader type: regex or cli (default: regex)"
choices=["regex", "cli", "llm"],
help="Grader type: regex, cli, or llm (default: regex)"
)
parser.add_argument(
"--grader-regex-type",
@ -376,6 +451,18 @@ def main():
default=None,
help="CLI grader script path (required for --grader-type cli)"
)
parser.add_argument(
"--judge-server",
type=str,
default="",
help="Server URL for LLM judge (default: same as main server)"
)
parser.add_argument(
"--judge-model",
type=str,
default=None,
help="Model name for LLM judge (default: same as main model)"
)
args = parser.parse_args()
@ -385,13 +472,18 @@ def main():
grader_script=args.grader_script
)
if args.grader_type == "llm" and not args.judge_server:
print("Warning: Using same server for LLM judge (no --judge-server specified)")
processor = Processor(
server_url=args.server,
n_predict=args.n_predict,
threads=args.threads,
verbose=args.verbose,
grader=grader,
model_name=args.model
model_name=args.model,
judge_server_url=args.judge_server,
judge_model_name=args.judge_model
)
eval_state = processor.process(n_cases=args.n_cases, seed=args.seed)