improve grader
This commit is contained in:
parent
68dde884d6
commit
d2b10302ce
|
|
@ -9,7 +9,7 @@ import time
|
||||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||||
from dataclasses import dataclass, asdict
|
from dataclasses import dataclass, asdict
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, List, Optional, Any
|
from typing import Dict, List, Optional, Any, Tuple
|
||||||
import requests
|
import requests
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
import random
|
import random
|
||||||
|
|
@ -47,6 +47,7 @@ class TaskState:
|
||||||
prompt: str
|
prompt: str
|
||||||
gold: str
|
gold: str
|
||||||
pred: Optional[str] = None
|
pred: Optional[str] = None
|
||||||
|
extracted: Optional[str] = None
|
||||||
correct: bool = False
|
correct: bool = False
|
||||||
status: str = "pending"
|
status: str = "pending"
|
||||||
|
|
||||||
|
|
@ -97,35 +98,49 @@ class Grader:
|
||||||
self,
|
self,
|
||||||
grader_type: str = "regex",
|
grader_type: str = "regex",
|
||||||
grader_regex_type: str = "aime",
|
grader_regex_type: str = "aime",
|
||||||
grader_script: Optional[str] = None
|
grader_script: Optional[str] = None,
|
||||||
|
judge_model_name: Optional[str] = None,
|
||||||
|
judge_server_url: str = ""
|
||||||
):
|
):
|
||||||
self.grader_type = grader_type
|
self.grader_type = grader_type
|
||||||
self.grader_regex_type = grader_regex_type
|
self.grader_regex_type = grader_regex_type
|
||||||
self.grader_script = grader_script
|
self.grader_script = grader_script
|
||||||
|
self.judge_model_name = judge_model_name
|
||||||
|
self.judge_server_url = judge_server_url
|
||||||
self.pattern = self._get_pattern()
|
self.pattern = self._get_pattern()
|
||||||
|
|
||||||
def _get_pattern(self) -> str:
|
def _get_pattern(self) -> Optional[str]:
|
||||||
if self.grader_type == "regex":
|
if self.grader_type == "regex":
|
||||||
if self.grader_regex_type not in GRADER_PATTERNS:
|
if self.grader_regex_type not in GRADER_PATTERNS:
|
||||||
raise ValueError(f"Unknown grader regex type: {self.grader_regex_type}")
|
raise ValueError(f"Unknown grader regex type: {self.grader_regex_type}")
|
||||||
return GRADER_PATTERNS[self.grader_regex_type]
|
return GRADER_PATTERNS[self.grader_regex_type]
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _grade_regex(self, gold: str, pred: str) -> bool:
|
def _extract_answer_regex(self, pred: str) -> Optional[str]:
|
||||||
"""Grade using regex pattern matching"""
|
"""Extract answer using regex pattern"""
|
||||||
|
if not self.pattern:
|
||||||
|
return None
|
||||||
matches = re.findall(self.pattern, pred, re.IGNORECASE)
|
matches = re.findall(self.pattern, pred, re.IGNORECASE)
|
||||||
if not matches:
|
if not matches:
|
||||||
return False
|
return None
|
||||||
|
|
||||||
for match in matches:
|
for match in matches:
|
||||||
if isinstance(match, tuple):
|
if isinstance(match, tuple):
|
||||||
match = match[0] if match[0] else match[1]
|
match = match[0] if match[0] else match[1]
|
||||||
if match.strip() == gold.strip():
|
extracted = match.strip()
|
||||||
return True
|
if extracted:
|
||||||
|
return extracted
|
||||||
|
return None
|
||||||
|
|
||||||
return False
|
def _grade_regex(self, gold: str, pred: str) -> Tuple[bool, Optional[str]]:
|
||||||
|
"""Grade using regex pattern matching"""
|
||||||
|
extracted = self._extract_answer_regex(pred)
|
||||||
|
if extracted is None:
|
||||||
|
return False, None
|
||||||
|
is_correct = extracted.strip() == gold.strip()
|
||||||
|
return is_correct, extracted
|
||||||
|
|
||||||
def _grade_cli(self, gold: str, pred: str) -> bool:
|
def _grade_cli(self, gold: str, pred: str) -> Tuple[bool, Optional[str]]:
|
||||||
"""Grade using external CLI script"""
|
"""Grade using external CLI script"""
|
||||||
if not self.grader_script:
|
if not self.grader_script:
|
||||||
raise ValueError("CLI grader requires --grader-script")
|
raise ValueError("CLI grader requires --grader-script")
|
||||||
|
|
@ -141,18 +156,54 @@ class Grader:
|
||||||
text=True,
|
text=True,
|
||||||
timeout=30
|
timeout=30
|
||||||
)
|
)
|
||||||
return result.returncode == 0
|
is_correct = result.returncode == 0
|
||||||
|
extracted = pred if is_correct else None
|
||||||
|
return is_correct, extracted
|
||||||
except subprocess.TimeoutExpired:
|
except subprocess.TimeoutExpired:
|
||||||
return False
|
return False, None
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return False
|
return False, None
|
||||||
|
|
||||||
def grade(self, gold: str, pred: str) -> bool:
|
def _grade_llm(self, gold: str, pred: str, problem: str) -> Tuple[bool, Optional[str]]:
    """Grade by asking a judge model to extract the answer from ``pred``.

    Sends one chat-completion request to
    ``{self.judge_server_url}/v1/chat/completions`` using
    ``self.judge_model_name`` and compares the judge's reply with ``gold``,
    ignoring case and surrounding whitespace.

    Args:
        gold: Expected answer. It is also embedded in the judge prompt.
        pred: Model response from which the answer is extracted.
        problem: Accepted so all graders share a call shape; not used here.

    Returns:
        ``(is_correct, extracted)`` where ``extracted`` is the judge's
        stripped reply. ``(False, None)`` when the request fails or the
        reply does not have the expected ``choices[0].message.content``
        shape.
    """
    # NOTE(review): the prompt reveals the expected answer to the judge,
    # which may bias extraction -- confirm this is intended.
    prompt = (
        "Extract the answer from this response:\n"
        "\n"
        f"Response: {pred}\n"
        "\n"
        f"Expected answer: {gold}\n"
        "\n"
        "Please provide only the extracted answer, nothing else."
    )
    url = f"{self.judge_server_url}/v1/chat/completions"
    headers = {"Content-Type": "application/json"}
    payload = {
        "model": self.judge_model_name,
        "messages": [{"role": "user", "content": prompt}],
        "temperature": 0,
        "max_tokens": 256,
    }

    try:
        # Without a timeout an unresponsive judge server would block this
        # worker forever. (connect, read) in seconds; the read limit is
        # generous because a non-streaming reply arrives only once
        # generation has finished.
        response = requests.post(url, headers=headers, json=payload, timeout=(10, 600))
        response.raise_for_status()
        extracted = response.json()["choices"][0]["message"]["content"].strip()
    except (
        requests.RequestException,  # connection, timeout, HTTP status, bad URL
        ValueError,  # body is not valid JSON
        KeyError,
        IndexError,
        TypeError,
        AttributeError,  # e.g. "content" is null
    ):
        # Best effort: an unusable judge reply counts as an incorrect answer.
        return False, None

    is_correct = extracted.lower() == gold.strip().lower()
    return is_correct, extracted
|
||||||
|
|
||||||
|
def _truncate_response(self, response: str, max_lines: int = 3) -> str:
|
||||||
|
"""Keep only last N lines of response"""
|
||||||
|
lines = response.split('\n')
|
||||||
|
return '\n'.join(lines[-max_lines:]) if len(lines) > max_lines else response
|
||||||
|
|
||||||
|
def grade(self, gold: str, pred: str, problem: str = "") -> Tuple[bool, Optional[str]]:
|
||||||
"""Grade the response"""
|
"""Grade the response"""
|
||||||
if self.grader_type == "regex":
|
if self.grader_type == "regex":
|
||||||
return self._grade_regex(gold, pred)
|
return self._grade_regex(gold, pred)
|
||||||
elif self.grader_type == "cli":
|
elif self.grader_type == "cli":
|
||||||
return self._grade_cli(gold, pred)
|
return self._grade_cli(gold, pred)
|
||||||
|
elif self.grader_type == "llm":
|
||||||
|
return self._grade_llm(gold, pred, problem)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown grader type: {self.grader_type}")
|
raise ValueError(f"Unknown grader type: {self.grader_type}")
|
||||||
|
|
||||||
|
|
@ -164,13 +215,17 @@ class Processor:
|
||||||
threads: int = 32,
|
threads: int = 32,
|
||||||
verbose: bool = False,
|
verbose: bool = False,
|
||||||
grader: Optional[Grader] = None,
|
grader: Optional[Grader] = None,
|
||||||
model_name: Optional[str] = None
|
model_name: Optional[str] = None,
|
||||||
|
judge_server_url: str = "",
|
||||||
|
judge_model_name: Optional[str] = None
|
||||||
):
|
):
|
||||||
self.server_url = server_url
|
self.server_url = server_url
|
||||||
self.n_predict = n_predict
|
self.n_predict = n_predict
|
||||||
self.threads = threads
|
self.threads = threads
|
||||||
self.verbose = verbose
|
self.verbose = verbose
|
||||||
self.model_name = model_name
|
self.model_name = model_name
|
||||||
|
self.judge_server_url = judge_server_url if judge_server_url else server_url
|
||||||
|
self.judge_model_name = judge_model_name
|
||||||
self.dataset = AimeDataset()
|
self.dataset = AimeDataset()
|
||||||
self.grader = grader or Grader()
|
self.grader = grader or Grader()
|
||||||
self.eval_state = EvalState(
|
self.eval_state = EvalState(
|
||||||
|
|
@ -180,6 +235,13 @@ class Processor:
|
||||||
sampling_config={"temperature": 0, "max_tokens": n_predict}
|
sampling_config={"temperature": 0, "max_tokens": n_predict}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Pass judge configuration to grader if using LLM grader
|
||||||
|
if self.grader.grader_type == "llm":
|
||||||
|
if self.judge_model_name:
|
||||||
|
self.grader.judge_model_name = self.judge_model_name
|
||||||
|
if self.judge_server_url:
|
||||||
|
self.grader.judge_server_url = self.judge_server_url
|
||||||
|
|
||||||
def _make_request(self, prompt: str) -> Dict[str, Any]:
|
def _make_request(self, prompt: str) -> Dict[str, Any]:
|
||||||
"""Make HTTP request to the server"""
|
"""Make HTTP request to the server"""
|
||||||
url = f"{self.server_url}/v1/chat/completions"
|
url = f"{self.server_url}/v1/chat/completions"
|
||||||
|
|
@ -217,7 +279,14 @@ class Processor:
|
||||||
response = self._make_request(prompt)
|
response = self._make_request(prompt)
|
||||||
pred = response["choices"][0]["message"]["content"]
|
pred = response["choices"][0]["message"]["content"]
|
||||||
task_state.pred = pred
|
task_state.pred = pred
|
||||||
task_state.correct = self.grader.grade(gold, pred)
|
|
||||||
|
# Truncate response to last 2-3 lines for grading
|
||||||
|
pred_truncated = self.grader._truncate_response(pred, max_lines=3)
|
||||||
|
|
||||||
|
# Grade the response
|
||||||
|
is_correct, extracted = self.grader.grade(gold, pred_truncated, prompt)
|
||||||
|
task_state.correct = is_correct
|
||||||
|
task_state.extracted = extracted
|
||||||
task_state.status = "ok"
|
task_state.status = "ok"
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
task_state.status = f"error: {str(e)}"
|
task_state.status = f"error: {str(e)}"
|
||||||
|
|
@ -233,6 +302,10 @@ class Processor:
|
||||||
print(f"Server: {self.server_url}")
|
print(f"Server: {self.server_url}")
|
||||||
print(f"Threads: {self.threads}")
|
print(f"Threads: {self.threads}")
|
||||||
print(f"Max tokens: {self.n_predict}")
|
print(f"Max tokens: {self.n_predict}")
|
||||||
|
print(f"Grader: {self.grader.grader_type}", end="")
|
||||||
|
if self.grader.grader_type == "llm":
|
||||||
|
print(f" (judge server: {self.judge_server_url}, model: {self.judge_model_name})", end="")
|
||||||
|
print()
|
||||||
print()
|
print()
|
||||||
|
|
||||||
dataset_size = len(self.dataset.questions)
|
dataset_size = len(self.dataset.questions)
|
||||||
|
|
@ -276,15 +349,17 @@ class Processor:
|
||||||
correct += 1
|
correct += 1
|
||||||
|
|
||||||
# Print task completion status
|
# Print task completion status
|
||||||
pred_display = task_state.pred if task_state.pred else "N/A"
|
extracted_display = task_state.extracted if task_state.extracted else "N/A"
|
||||||
success_ratio = correct / total if total > 0 else 0.0
|
success_ratio = correct / total if total > 0 else 0.0
|
||||||
print(f"{total:3}/{n_cases:3} {task_state.case_id:<15} AIME2025 {task_state.prompt[:40]:<40} {task_state.gold:<10} {pred_display:<10} {'✓' if task_state.correct else '✗'} [{correct:3}/{total:3}, {success_ratio:.3f}]")
|
print(f"{total:3}/{n_cases:3} {task_state.case_id:<15} AIME2025 {task_state.prompt[:40]:<40} {task_state.gold:<10} {extracted_display:<10} {'✓' if task_state.correct else '✗'} [{correct:3}/{total:3}, {success_ratio:.3f}]")
|
||||||
|
|
||||||
if self.verbose:
|
if self.verbose:
|
||||||
print(f"\nCase {total}: {task_state.correct}")
|
print(f"\nCase {total}: {task_state.correct}")
|
||||||
print(f" Gold: {task_state.gold}")
|
print(f" Gold: {task_state.gold}")
|
||||||
if task_state.pred:
|
if task_state.pred:
|
||||||
print(f" Pred: {task_state.pred}")
|
print(f" Pred: {task_state.pred}")
|
||||||
|
if task_state.extracted:
|
||||||
|
print(f" Extracted: {task_state.extracted}")
|
||||||
print(f" Status: {task_state.status}")
|
print(f" Status: {task_state.status}")
|
||||||
|
|
||||||
self.eval_state.task_states["aime"] = {
|
self.eval_state.task_states["aime"] = {
|
||||||
|
|
@ -360,8 +435,8 @@ def main():
|
||||||
"--grader-type",
|
"--grader-type",
|
||||||
type=str,
|
type=str,
|
||||||
default="regex",
|
default="regex",
|
||||||
choices=["regex", "cli"],
|
choices=["regex", "cli", "llm"],
|
||||||
help="Grader type: regex or cli (default: regex)"
|
help="Grader type: regex, cli, or llm (default: regex)"
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--grader-regex-type",
|
"--grader-regex-type",
|
||||||
|
|
@ -376,6 +451,18 @@ def main():
|
||||||
default=None,
|
default=None,
|
||||||
help="CLI grader script path (required for --grader-type cli)"
|
help="CLI grader script path (required for --grader-type cli)"
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--judge-server",
|
||||||
|
type=str,
|
||||||
|
default="",
|
||||||
|
help="Server URL for LLM judge (default: same as main server)"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--judge-model",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Model name for LLM judge (default: same as main model)"
|
||||||
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
|
@ -385,13 +472,18 @@ def main():
|
||||||
grader_script=args.grader_script
|
grader_script=args.grader_script
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if args.grader_type == "llm" and not args.judge_server:
|
||||||
|
print("Warning: Using same server for LLM judge (no --judge-server specified)")
|
||||||
|
|
||||||
processor = Processor(
|
processor = Processor(
|
||||||
server_url=args.server,
|
server_url=args.server,
|
||||||
n_predict=args.n_predict,
|
n_predict=args.n_predict,
|
||||||
threads=args.threads,
|
threads=args.threads,
|
||||||
verbose=args.verbose,
|
verbose=args.verbose,
|
||||||
grader=grader,
|
grader=grader,
|
||||||
model_name=args.model
|
model_name=args.model,
|
||||||
|
judge_server_url=args.judge_server,
|
||||||
|
judge_model_name=args.judge_model
|
||||||
)
|
)
|
||||||
|
|
||||||
eval_state = processor.process(n_cases=args.n_cases, seed=args.seed)
|
eval_state = processor.process(n_cases=args.n_cases, seed=args.seed)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue