improve grader

This commit is contained in:
Georgi Gerganov 2026-02-15 21:50:45 +02:00
parent 68dde884d6
commit d2b10302ce
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
1 changed file with 113 additions and 21 deletions

View File

@@ -9,7 +9,7 @@ import time
from concurrent.futures import ThreadPoolExecutor, as_completed from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass, asdict from dataclasses import dataclass, asdict
from pathlib import Path from pathlib import Path
from typing import Dict, List, Optional, Any from typing import Dict, List, Optional, Any, Tuple
import requests import requests
from tqdm import tqdm from tqdm import tqdm
import random import random
@@ -47,6 +47,7 @@ class TaskState:
prompt: str prompt: str
gold: str gold: str
pred: Optional[str] = None pred: Optional[str] = None
extracted: Optional[str] = None
correct: bool = False correct: bool = False
status: str = "pending" status: str = "pending"
@@ -97,35 +98,49 @@ class Grader:
self, self,
grader_type: str = "regex", grader_type: str = "regex",
grader_regex_type: str = "aime", grader_regex_type: str = "aime",
grader_script: Optional[str] = None grader_script: Optional[str] = None,
judge_model_name: Optional[str] = None,
judge_server_url: str = ""
): ):
self.grader_type = grader_type self.grader_type = grader_type
self.grader_regex_type = grader_regex_type self.grader_regex_type = grader_regex_type
self.grader_script = grader_script self.grader_script = grader_script
self.judge_model_name = judge_model_name
self.judge_server_url = judge_server_url
self.pattern = self._get_pattern() self.pattern = self._get_pattern()
def _get_pattern(self) -> str: def _get_pattern(self) -> Optional[str]:
if self.grader_type == "regex": if self.grader_type == "regex":
if self.grader_regex_type not in GRADER_PATTERNS: if self.grader_regex_type not in GRADER_PATTERNS:
raise ValueError(f"Unknown grader regex type: {self.grader_regex_type}") raise ValueError(f"Unknown grader regex type: {self.grader_regex_type}")
return GRADER_PATTERNS[self.grader_regex_type] return GRADER_PATTERNS[self.grader_regex_type]
return None return None
def _grade_regex(self, gold: str, pred: str) -> bool: def _extract_answer_regex(self, pred: str) -> Optional[str]:
"""Grade using regex pattern matching""" """Extract answer using regex pattern"""
if not self.pattern:
return None
matches = re.findall(self.pattern, pred, re.IGNORECASE) matches = re.findall(self.pattern, pred, re.IGNORECASE)
if not matches: if not matches:
return False return None
for match in matches: for match in matches:
if isinstance(match, tuple): if isinstance(match, tuple):
match = match[0] if match[0] else match[1] match = match[0] if match[0] else match[1]
if match.strip() == gold.strip(): extracted = match.strip()
return True if extracted:
return extracted
return None
return False def _grade_regex(self, gold: str, pred: str) -> Tuple[bool, Optional[str]]:
"""Grade using regex pattern matching"""
extracted = self._extract_answer_regex(pred)
if extracted is None:
return False, None
is_correct = extracted.strip() == gold.strip()
return is_correct, extracted
def _grade_cli(self, gold: str, pred: str) -> bool: def _grade_cli(self, gold: str, pred: str) -> Tuple[bool, Optional[str]]:
"""Grade using external CLI script""" """Grade using external CLI script"""
if not self.grader_script: if not self.grader_script:
raise ValueError("CLI grader requires --grader-script") raise ValueError("CLI grader requires --grader-script")
@@ -141,18 +156,54 @@ class Grader:
text=True, text=True,
timeout=30 timeout=30
) )
return result.returncode == 0 is_correct = result.returncode == 0
extracted = pred if is_correct else None
return is_correct, extracted
except subprocess.TimeoutExpired: except subprocess.TimeoutExpired:
return False return False, None
except Exception as e: except Exception as e:
return False return False, None
def _grade_llm(self, gold: str, pred: str, problem: str) -> Tuple[bool, Optional[str]]:
    """Grade using LLM-based extraction.

    Asks the judge model to extract the final answer from *pred*, then
    compares it case-insensitively against *gold*.

    Args:
        gold: Expected answer string.
        pred: Model response to extract an answer from.
        problem: Original problem statement (currently unused; kept so
            future judge prompts can include the question context).

    Returns:
        (is_correct, extracted_answer); (False, None) on any judge failure.
    """
    prompt = f"""Extract the answer from this response:
Response: {pred}
Expected answer: {gold}
Please provide only the extracted answer, nothing else."""
    url = f"{self.judge_server_url}/v1/chat/completions"
    headers = {"Content-Type": "application/json"}
    data = {
        "model": self.judge_model_name,
        "messages": [{"role": "user", "content": prompt}],
        "temperature": 0,
        "max_tokens": 256
    }
    try:
        # FIX: requests.post without a timeout can hang the worker thread
        # forever if the judge server stalls; bound it (generous headroom
        # over the CLI grader's 30 s limit to allow for generation time).
        response = requests.post(url, headers=headers, json=data, timeout=120)
        response.raise_for_status()
        extracted = response.json()["choices"][0]["message"]["content"].strip()
        is_correct = extracted.strip().lower() == gold.strip().lower()
        return is_correct, extracted
    except Exception:
        # Best-effort: any judge failure counts the case as incorrect
        # rather than aborting the whole evaluation run.
        return False, None
def _truncate_response(self, response: str, max_lines: int = 3) -> str:
"""Keep only last N lines of response"""
lines = response.split('\n')
return '\n'.join(lines[-max_lines:]) if len(lines) > max_lines else response
def grade(self, gold: str, pred: str, problem: str = "") -> Tuple[bool, Optional[str]]:
    """Grade the response, dispatching on the configured grader type.

    Returns (is_correct, extracted_answer). Raises ValueError for an
    unrecognized grader type.
    """
    dispatch = {
        "regex": lambda: self._grade_regex(gold, pred),
        "cli": lambda: self._grade_cli(gold, pred),
        "llm": lambda: self._grade_llm(gold, pred, problem),
    }
    if self.grader_type not in dispatch:
        raise ValueError(f"Unknown grader type: {self.grader_type}")
    return dispatch[self.grader_type]()
@@ -164,13 +215,17 @@ class Processor:
threads: int = 32, threads: int = 32,
verbose: bool = False, verbose: bool = False,
grader: Optional[Grader] = None, grader: Optional[Grader] = None,
model_name: Optional[str] = None model_name: Optional[str] = None,
judge_server_url: str = "",
judge_model_name: Optional[str] = None
): ):
self.server_url = server_url self.server_url = server_url
self.n_predict = n_predict self.n_predict = n_predict
self.threads = threads self.threads = threads
self.verbose = verbose self.verbose = verbose
self.model_name = model_name self.model_name = model_name
self.judge_server_url = judge_server_url if judge_server_url else server_url
self.judge_model_name = judge_model_name
self.dataset = AimeDataset() self.dataset = AimeDataset()
self.grader = grader or Grader() self.grader = grader or Grader()
self.eval_state = EvalState( self.eval_state = EvalState(
@@ -180,6 +235,13 @@ class Processor:
sampling_config={"temperature": 0, "max_tokens": n_predict} sampling_config={"temperature": 0, "max_tokens": n_predict}
) )
# Pass judge configuration to grader if using LLM grader
if self.grader.grader_type == "llm":
if self.judge_model_name:
self.grader.judge_model_name = self.judge_model_name
if self.judge_server_url:
self.grader.judge_server_url = self.judge_server_url
def _make_request(self, prompt: str) -> Dict[str, Any]: def _make_request(self, prompt: str) -> Dict[str, Any]:
"""Make HTTP request to the server""" """Make HTTP request to the server"""
url = f"{self.server_url}/v1/chat/completions" url = f"{self.server_url}/v1/chat/completions"
@@ -217,7 +279,14 @@
response = self._make_request(prompt) response = self._make_request(prompt)
pred = response["choices"][0]["message"]["content"] pred = response["choices"][0]["message"]["content"]
task_state.pred = pred task_state.pred = pred
task_state.correct = self.grader.grade(gold, pred)
# Truncate response to last 2-3 lines for grading
pred_truncated = self.grader._truncate_response(pred, max_lines=3)
# Grade the response
is_correct, extracted = self.grader.grade(gold, pred_truncated, prompt)
task_state.correct = is_correct
task_state.extracted = extracted
task_state.status = "ok" task_state.status = "ok"
except Exception as e: except Exception as e:
task_state.status = f"error: {str(e)}" task_state.status = f"error: {str(e)}"
@@ -233,6 +302,10 @@
print(f"Server: {self.server_url}") print(f"Server: {self.server_url}")
print(f"Threads: {self.threads}") print(f"Threads: {self.threads}")
print(f"Max tokens: {self.n_predict}") print(f"Max tokens: {self.n_predict}")
print(f"Grader: {self.grader.grader_type}", end="")
if self.grader.grader_type == "llm":
print(f" (judge server: {self.judge_server_url}, model: {self.judge_model_name})", end="")
print()
print() print()
dataset_size = len(self.dataset.questions) dataset_size = len(self.dataset.questions)
@@ -276,15 +349,17 @@
correct += 1 correct += 1
# Print task completion status # Print task completion status
pred_display = task_state.pred if task_state.pred else "N/A" extracted_display = task_state.extracted if task_state.extracted else "N/A"
success_ratio = correct / total if total > 0 else 0.0 success_ratio = correct / total if total > 0 else 0.0
print(f"{total:3}/{n_cases:3} {task_state.case_id:<15} AIME2025 {task_state.prompt[:40]:<40} {task_state.gold:<10} {pred_display:<10} {'' if task_state.correct else ''} [{correct:3}/{total:3}, {success_ratio:.3f}]") print(f"{total:3}/{n_cases:3} {task_state.case_id:<15} AIME2025 {task_state.prompt[:40]:<40} {task_state.gold:<10} {extracted_display:<10} {'' if task_state.correct else ''} [{correct:3}/{total:3}, {success_ratio:.3f}]")
if self.verbose: if self.verbose:
print(f"\nCase {total}: {task_state.correct}") print(f"\nCase {total}: {task_state.correct}")
print(f" Gold: {task_state.gold}") print(f" Gold: {task_state.gold}")
if task_state.pred: if task_state.pred:
print(f" Pred: {task_state.pred}") print(f" Pred: {task_state.pred}")
if task_state.extracted:
print(f" Extracted: {task_state.extracted}")
print(f" Status: {task_state.status}") print(f" Status: {task_state.status}")
self.eval_state.task_states["aime"] = { self.eval_state.task_states["aime"] = {
@@ -360,8 +435,8 @@ def main():
"--grader-type", "--grader-type",
type=str, type=str,
default="regex", default="regex",
choices=["regex", "cli"], choices=["regex", "cli", "llm"],
help="Grader type: regex or cli (default: regex)" help="Grader type: regex, cli, or llm (default: regex)"
) )
parser.add_argument( parser.add_argument(
"--grader-regex-type", "--grader-regex-type",
@@ -376,6 +451,18 @@ def main():
default=None, default=None,
help="CLI grader script path (required for --grader-type cli)" help="CLI grader script path (required for --grader-type cli)"
) )
parser.add_argument(
"--judge-server",
type=str,
default="",
help="Server URL for LLM judge (default: same as main server)"
)
parser.add_argument(
"--judge-model",
type=str,
default=None,
help="Model name for LLM judge (default: same as main model)"
)
args = parser.parse_args() args = parser.parse_args()
@@ -385,13 +472,18 @@
grader_script=args.grader_script grader_script=args.grader_script
) )
if args.grader_type == "llm" and not args.judge_server:
print("Warning: Using same server for LLM judge (no --judge-server specified)")
processor = Processor( processor = Processor(
server_url=args.server, server_url=args.server,
n_predict=args.n_predict, n_predict=args.n_predict,
threads=args.threads, threads=args.threads,
verbose=args.verbose, verbose=args.verbose,
grader=grader, grader=grader,
model_name=args.model model_name=args.model,
judge_server_url=args.judge_server,
judge_model_name=args.judge_model
) )
eval_state = processor.process(n_cases=args.n_cases, seed=args.seed) eval_state = processor.process(n_cases=args.n_cases, seed=args.seed)