improve grader
This commit is contained in:
parent
68dde884d6
commit
d2b10302ce
|
|
@ -9,7 +9,7 @@ import time
|
|||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from dataclasses import dataclass, asdict
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Any
|
||||
from typing import Dict, List, Optional, Any, Tuple
|
||||
import requests
|
||||
from tqdm import tqdm
|
||||
import random
|
||||
|
|
@ -47,6 +47,7 @@ class TaskState:
|
|||
prompt: str
|
||||
gold: str
|
||||
pred: Optional[str] = None
|
||||
extracted: Optional[str] = None
|
||||
correct: bool = False
|
||||
status: str = "pending"
|
||||
|
||||
|
|
@ -97,35 +98,49 @@ class Grader:
|
|||
self,
|
||||
grader_type: str = "regex",
|
||||
grader_regex_type: str = "aime",
|
||||
grader_script: Optional[str] = None
|
||||
grader_script: Optional[str] = None,
|
||||
judge_model_name: Optional[str] = None,
|
||||
judge_server_url: str = ""
|
||||
):
|
||||
self.grader_type = grader_type
|
||||
self.grader_regex_type = grader_regex_type
|
||||
self.grader_script = grader_script
|
||||
self.judge_model_name = judge_model_name
|
||||
self.judge_server_url = judge_server_url
|
||||
self.pattern = self._get_pattern()
|
||||
|
||||
def _get_pattern(self) -> str:
|
||||
def _get_pattern(self) -> Optional[str]:
|
||||
if self.grader_type == "regex":
|
||||
if self.grader_regex_type not in GRADER_PATTERNS:
|
||||
raise ValueError(f"Unknown grader regex type: {self.grader_regex_type}")
|
||||
return GRADER_PATTERNS[self.grader_regex_type]
|
||||
return None
|
||||
|
||||
def _grade_regex(self, gold: str, pred: str) -> bool:
|
||||
"""Grade using regex pattern matching"""
|
||||
def _extract_answer_regex(self, pred: str) -> Optional[str]:
|
||||
"""Extract answer using regex pattern"""
|
||||
if not self.pattern:
|
||||
return None
|
||||
matches = re.findall(self.pattern, pred, re.IGNORECASE)
|
||||
if not matches:
|
||||
return False
|
||||
return None
|
||||
|
||||
for match in matches:
|
||||
if isinstance(match, tuple):
|
||||
match = match[0] if match[0] else match[1]
|
||||
if match.strip() == gold.strip():
|
||||
return True
|
||||
extracted = match.strip()
|
||||
if extracted:
|
||||
return extracted
|
||||
return None
|
||||
|
||||
return False
|
||||
def _grade_regex(self, gold: str, pred: str) -> Tuple[bool, Optional[str]]:
|
||||
"""Grade using regex pattern matching"""
|
||||
extracted = self._extract_answer_regex(pred)
|
||||
if extracted is None:
|
||||
return False, None
|
||||
is_correct = extracted.strip() == gold.strip()
|
||||
return is_correct, extracted
|
||||
|
||||
def _grade_cli(self, gold: str, pred: str) -> bool:
|
||||
def _grade_cli(self, gold: str, pred: str) -> Tuple[bool, Optional[str]]:
|
||||
"""Grade using external CLI script"""
|
||||
if not self.grader_script:
|
||||
raise ValueError("CLI grader requires --grader-script")
|
||||
|
|
@ -141,18 +156,54 @@ class Grader:
|
|||
text=True,
|
||||
timeout=30
|
||||
)
|
||||
return result.returncode == 0
|
||||
is_correct = result.returncode == 0
|
||||
extracted = pred if is_correct else None
|
||||
return is_correct, extracted
|
||||
except subprocess.TimeoutExpired:
|
||||
return False
|
||||
return False, None
|
||||
except Exception as e:
|
||||
return False
|
||||
return False, None
|
||||
|
||||
def grade(self, gold: str, pred: str) -> bool:
|
||||
def _grade_llm(self, gold: str, pred: str, problem: str) -> Tuple[bool, Optional[str]]:
|
||||
"""Grade using LLM-based extraction"""
|
||||
prompt = f"""Extract the answer from this response:
|
||||
|
||||
Response: {pred}
|
||||
|
||||
Expected answer: {gold}
|
||||
|
||||
Please provide only the extracted answer, nothing else."""
|
||||
url = f"{self.judge_server_url}/v1/chat/completions"
|
||||
headers = {"Content-Type": "application/json"}
|
||||
data = {
|
||||
"model": self.judge_model_name,
|
||||
"messages": [{"role": "user", "content": prompt}],
|
||||
"temperature": 0,
|
||||
"max_tokens": 256
|
||||
}
|
||||
|
||||
try:
|
||||
response = requests.post(url, headers=headers, json=data)
|
||||
response.raise_for_status()
|
||||
extracted = response.json()["choices"][0]["message"]["content"].strip()
|
||||
is_correct = extracted.strip().lower() == gold.strip().lower()
|
||||
return is_correct, extracted
|
||||
except Exception as e:
|
||||
return False, None
|
||||
|
||||
def _truncate_response(self, response: str, max_lines: int = 3) -> str:
|
||||
"""Keep only last N lines of response"""
|
||||
lines = response.split('\n')
|
||||
return '\n'.join(lines[-max_lines:]) if len(lines) > max_lines else response
|
||||
|
||||
def grade(self, gold: str, pred: str, problem: str = "") -> Tuple[bool, Optional[str]]:
|
||||
"""Grade the response"""
|
||||
if self.grader_type == "regex":
|
||||
return self._grade_regex(gold, pred)
|
||||
elif self.grader_type == "cli":
|
||||
return self._grade_cli(gold, pred)
|
||||
elif self.grader_type == "llm":
|
||||
return self._grade_llm(gold, pred, problem)
|
||||
else:
|
||||
raise ValueError(f"Unknown grader type: {self.grader_type}")
|
||||
|
||||
|
|
@ -164,13 +215,17 @@ class Processor:
|
|||
threads: int = 32,
|
||||
verbose: bool = False,
|
||||
grader: Optional[Grader] = None,
|
||||
model_name: Optional[str] = None
|
||||
model_name: Optional[str] = None,
|
||||
judge_server_url: str = "",
|
||||
judge_model_name: Optional[str] = None
|
||||
):
|
||||
self.server_url = server_url
|
||||
self.n_predict = n_predict
|
||||
self.threads = threads
|
||||
self.verbose = verbose
|
||||
self.model_name = model_name
|
||||
self.judge_server_url = judge_server_url if judge_server_url else server_url
|
||||
self.judge_model_name = judge_model_name
|
||||
self.dataset = AimeDataset()
|
||||
self.grader = grader or Grader()
|
||||
self.eval_state = EvalState(
|
||||
|
|
@ -180,6 +235,13 @@ class Processor:
|
|||
sampling_config={"temperature": 0, "max_tokens": n_predict}
|
||||
)
|
||||
|
||||
# Pass judge configuration to grader if using LLM grader
|
||||
if self.grader.grader_type == "llm":
|
||||
if self.judge_model_name:
|
||||
self.grader.judge_model_name = self.judge_model_name
|
||||
if self.judge_server_url:
|
||||
self.grader.judge_server_url = self.judge_server_url
|
||||
|
||||
def _make_request(self, prompt: str) -> Dict[str, Any]:
|
||||
"""Make HTTP request to the server"""
|
||||
url = f"{self.server_url}/v1/chat/completions"
|
||||
|
|
@ -217,7 +279,14 @@ class Processor:
|
|||
response = self._make_request(prompt)
|
||||
pred = response["choices"][0]["message"]["content"]
|
||||
task_state.pred = pred
|
||||
task_state.correct = self.grader.grade(gold, pred)
|
||||
|
||||
# Truncate response to last 2-3 lines for grading
|
||||
pred_truncated = self.grader._truncate_response(pred, max_lines=3)
|
||||
|
||||
# Grade the response
|
||||
is_correct, extracted = self.grader.grade(gold, pred_truncated, prompt)
|
||||
task_state.correct = is_correct
|
||||
task_state.extracted = extracted
|
||||
task_state.status = "ok"
|
||||
except Exception as e:
|
||||
task_state.status = f"error: {str(e)}"
|
||||
|
|
@ -233,6 +302,10 @@ class Processor:
|
|||
print(f"Server: {self.server_url}")
|
||||
print(f"Threads: {self.threads}")
|
||||
print(f"Max tokens: {self.n_predict}")
|
||||
print(f"Grader: {self.grader.grader_type}", end="")
|
||||
if self.grader.grader_type == "llm":
|
||||
print(f" (judge server: {self.judge_server_url}, model: {self.judge_model_name})", end="")
|
||||
print()
|
||||
print()
|
||||
|
||||
dataset_size = len(self.dataset.questions)
|
||||
|
|
@ -276,15 +349,17 @@ class Processor:
|
|||
correct += 1
|
||||
|
||||
# Print task completion status
|
||||
pred_display = task_state.pred if task_state.pred else "N/A"
|
||||
extracted_display = task_state.extracted if task_state.extracted else "N/A"
|
||||
success_ratio = correct / total if total > 0 else 0.0
|
||||
print(f"{total:3}/{n_cases:3} {task_state.case_id:<15} AIME2025 {task_state.prompt[:40]:<40} {task_state.gold:<10} {pred_display:<10} {'✓' if task_state.correct else '✗'} [{correct:3}/{total:3}, {success_ratio:.3f}]")
|
||||
print(f"{total:3}/{n_cases:3} {task_state.case_id:<15} AIME2025 {task_state.prompt[:40]:<40} {task_state.gold:<10} {extracted_display:<10} {'✓' if task_state.correct else '✗'} [{correct:3}/{total:3}, {success_ratio:.3f}]")
|
||||
|
||||
if self.verbose:
|
||||
print(f"\nCase {total}: {task_state.correct}")
|
||||
print(f" Gold: {task_state.gold}")
|
||||
if task_state.pred:
|
||||
print(f" Pred: {task_state.pred}")
|
||||
if task_state.extracted:
|
||||
print(f" Extracted: {task_state.extracted}")
|
||||
print(f" Status: {task_state.status}")
|
||||
|
||||
self.eval_state.task_states["aime"] = {
|
||||
|
|
@ -360,8 +435,8 @@ def main():
|
|||
"--grader-type",
|
||||
type=str,
|
||||
default="regex",
|
||||
choices=["regex", "cli"],
|
||||
help="Grader type: regex or cli (default: regex)"
|
||||
choices=["regex", "cli", "llm"],
|
||||
help="Grader type: regex, cli, or llm (default: regex)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--grader-regex-type",
|
||||
|
|
@ -376,6 +451,18 @@ def main():
|
|||
default=None,
|
||||
help="CLI grader script path (required for --grader-type cli)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--judge-server",
|
||||
type=str,
|
||||
default="",
|
||||
help="Server URL for LLM judge (default: same as main server)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--judge-model",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Model name for LLM judge (default: same as main model)"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
|
|
@ -385,13 +472,18 @@ def main():
|
|||
grader_script=args.grader_script
|
||||
)
|
||||
|
||||
if args.grader_type == "llm" and not args.judge_server:
|
||||
print("Warning: Using same server for LLM judge (no --judge-server specified)")
|
||||
|
||||
processor = Processor(
|
||||
server_url=args.server,
|
||||
n_predict=args.n_predict,
|
||||
threads=args.threads,
|
||||
verbose=args.verbose,
|
||||
grader=grader,
|
||||
model_name=args.model
|
||||
model_name=args.model,
|
||||
judge_server_url=args.judge_server,
|
||||
judge_model_name=args.judge_model
|
||||
)
|
||||
|
||||
eval_state = processor.process(n_cases=args.n_cases, seed=args.seed)
|
||||
|
|
|
|||
Loading…
Reference in New Issue