From d2b10302ce4e515202f5635185681819dcbc77ba Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 15 Feb 2026 21:50:45 +0200 Subject: [PATCH] improve grader --- examples/llama-eval/llama-eval-new.py | 134 ++++++++++++++++++++++---- 1 file changed, 113 insertions(+), 21 deletions(-) diff --git a/examples/llama-eval/llama-eval-new.py b/examples/llama-eval/llama-eval-new.py index 4e104bcc0e..ff62777653 100755 --- a/examples/llama-eval/llama-eval-new.py +++ b/examples/llama-eval/llama-eval-new.py @@ -9,7 +9,7 @@ import time from concurrent.futures import ThreadPoolExecutor, as_completed from dataclasses import dataclass, asdict from pathlib import Path -from typing import Dict, List, Optional, Any +from typing import Dict, List, Optional, Any, Tuple import requests from tqdm import tqdm import random @@ -47,6 +47,7 @@ class TaskState: prompt: str gold: str pred: Optional[str] = None + extracted: Optional[str] = None correct: bool = False status: str = "pending" @@ -97,35 +98,49 @@ class Grader: self, grader_type: str = "regex", grader_regex_type: str = "aime", - grader_script: Optional[str] = None + grader_script: Optional[str] = None, + judge_model_name: Optional[str] = None, + judge_server_url: str = "" ): self.grader_type = grader_type self.grader_regex_type = grader_regex_type self.grader_script = grader_script + self.judge_model_name = judge_model_name + self.judge_server_url = judge_server_url self.pattern = self._get_pattern() - def _get_pattern(self) -> str: + def _get_pattern(self) -> Optional[str]: if self.grader_type == "regex": if self.grader_regex_type not in GRADER_PATTERNS: raise ValueError(f"Unknown grader regex type: {self.grader_regex_type}") return GRADER_PATTERNS[self.grader_regex_type] return None - def _grade_regex(self, gold: str, pred: str) -> bool: - """Grade using regex pattern matching""" + def _extract_answer_regex(self, pred: str) -> Optional[str]: + """Extract answer using regex pattern""" + if not self.pattern: + return None matches = re.findall(self.pattern, pred, re.IGNORECASE) if not matches: - return False + return None for match in matches: if isinstance(match, tuple): match = match[0] if match[0] else match[1] - if match.strip() == gold.strip(): - return True + extracted = match.strip() + if extracted: + return extracted + return None - return False + def _grade_regex(self, gold: str, pred: str) -> Tuple[bool, Optional[str]]: + """Grade using regex pattern matching""" + extracted = self._extract_answer_regex(pred) + if extracted is None: + return False, None + is_correct = extracted.strip() == gold.strip() + return is_correct, extracted - def _grade_cli(self, gold: str, pred: str) -> bool: + def _grade_cli(self, gold: str, pred: str) -> Tuple[bool, Optional[str]]: """Grade using external CLI script""" if not self.grader_script: raise ValueError("CLI grader requires --grader-script") @@ -141,18 +156,54 @@ class Grader: text=True, timeout=30 ) - return result.returncode == 0 + is_correct = result.returncode == 0 + extracted = pred if is_correct else None + return is_correct, extracted except subprocess.TimeoutExpired: - return False + return False, None except Exception as e: - return False + return False, None - def grade(self, gold: str, pred: str) -> bool: + def _grade_llm(self, gold: str, pred: str, problem: str) -> Tuple[bool, Optional[str]]: + """Grade using LLM-based extraction""" + prompt = f"""Extract the answer from this response: + +Response: {pred} + +Expected answer: {gold} + +Please provide only the extracted answer, nothing else.""" + url = f"{self.judge_server_url}/v1/chat/completions" + headers = {"Content-Type": "application/json"} + data = { + "model": self.judge_model_name, + "messages": [{"role": "user", "content": prompt}], + "temperature": 0, + "max_tokens": 256 + } + + try: + response = requests.post(url, headers=headers, json=data) + response.raise_for_status() + extracted = response.json()["choices"][0]["message"]["content"].strip() + is_correct = extracted.strip().lower() == gold.strip().lower() + return is_correct, extracted + except Exception as e: + return False, None + + def _truncate_response(self, response: str, max_lines: int = 3) -> str: + """Keep only last N lines of response""" + lines = response.split('\n') + return '\n'.join(lines[-max_lines:]) if len(lines) > max_lines else response + + def grade(self, gold: str, pred: str, problem: str = "") -> Tuple[bool, Optional[str]]: """Grade the response""" if self.grader_type == "regex": return self._grade_regex(gold, pred) elif self.grader_type == "cli": return self._grade_cli(gold, pred) + elif self.grader_type == "llm": + return self._grade_llm(gold, pred, problem) else: raise ValueError(f"Unknown grader type: {self.grader_type}") @@ -164,13 +215,17 @@ class Processor: threads: int = 32, verbose: bool = False, grader: Optional[Grader] = None, - model_name: Optional[str] = None + model_name: Optional[str] = None, + judge_server_url: str = "", + judge_model_name: Optional[str] = None ): self.server_url = server_url self.n_predict = n_predict self.threads = threads self.verbose = verbose self.model_name = model_name + self.judge_server_url = judge_server_url if judge_server_url else server_url + self.judge_model_name = judge_model_name self.dataset = AimeDataset() self.grader = grader or Grader() self.eval_state = EvalState( @@ -180,6 +235,13 @@ class Processor: sampling_config={"temperature": 0, "max_tokens": n_predict} ) + # Pass judge configuration to grader if using LLM grader + if self.grader.grader_type == "llm": + if self.judge_model_name: + self.grader.judge_model_name = self.judge_model_name + if self.judge_server_url: + self.grader.judge_server_url = self.judge_server_url + def _make_request(self, prompt: str) -> Dict[str, Any]: """Make HTTP request to the server""" url = f"{self.server_url}/v1/chat/completions" @@ -217,7 +279,14 @@ class Processor: response = self._make_request(prompt) pred = response["choices"][0]["message"]["content"] task_state.pred = pred - task_state.correct = self.grader.grade(gold, pred) + + # Truncate response to last 2-3 lines for grading + pred_truncated = self.grader._truncate_response(pred, max_lines=3) + + # Grade the response + is_correct, extracted = self.grader.grade(gold, pred_truncated, prompt) + task_state.correct = is_correct + task_state.extracted = extracted task_state.status = "ok" except Exception as e: task_state.status = f"error: {str(e)}" @@ -233,6 +302,10 @@ class Processor: print(f"Server: {self.server_url}") print(f"Threads: {self.threads}") print(f"Max tokens: {self.n_predict}") + print(f"Grader: {self.grader.grader_type}", end="") + if self.grader.grader_type == "llm": + print(f" (judge server: {self.judge_server_url}, model: {self.judge_model_name})", end="") + print() print() dataset_size = len(self.dataset.questions) @@ -276,15 +349,17 @@ class Processor: correct += 1 # Print task completion status - pred_display = task_state.pred if task_state.pred else "N/A" + extracted_display = task_state.extracted if task_state.extracted else "N/A" success_ratio = correct / total if total > 0 else 0.0 - print(f"{total:3}/{n_cases:3} {task_state.case_id:<15} AIME2025 {task_state.prompt[:40]:<40} {task_state.gold:<10} {pred_display:<10} {'✓' if task_state.correct else '✗'} [{correct:3}/{total:3}, {success_ratio:.3f}]") + print(f"{total:3}/{n_cases:3} {task_state.case_id:<15} AIME2025 {task_state.prompt[:40]:<40} {task_state.gold:<10} {extracted_display:<10} {'✓' if task_state.correct else '✗'} [{correct:3}/{total:3}, {success_ratio:.3f}]") if self.verbose: print(f"\nCase {total}: {task_state.correct}") print(f" Gold: {task_state.gold}") if task_state.pred: print(f" Pred: {task_state.pred}") + if task_state.extracted: + print(f" Extracted: {task_state.extracted}") print(f" Status: {task_state.status}") self.eval_state.task_states["aime"] = { @@ -360,8 +435,8 @@ def main(): "--grader-type", type=str, default="regex", - choices=["regex", "cli"], - help="Grader type: regex or cli (default: regex)" + choices=["regex", "cli", "llm"], + help="Grader type: regex, cli, or llm (default: regex)" ) parser.add_argument( "--grader-regex-type", @@ -376,6 +451,18 @@ def main(): default=None, help="CLI grader script path (required for --grader-type cli)" ) + parser.add_argument( + "--judge-server", + type=str, + default="", + help="Server URL for LLM judge (default: same as main server)" + ) + parser.add_argument( + "--judge-model", + type=str, + default=None, + help="Model name for LLM judge (default: same as main model)" + ) args = parser.parse_args() @@ -385,13 +472,18 @@ def main(): grader_script=args.grader_script ) + if args.grader_type == "llm" and not args.judge_server: + print("Warning: Using same server for LLM judge (no --judge-server specified)") + processor = Processor( server_url=args.server, n_predict=args.n_predict, threads=args.threads, verbose=args.verbose, grader=grader, - model_name=args.model + model_name=args.model, + judge_server_url=args.judge_server, + judge_model_name=args.judge_model ) eval_state = processor.process(n_cases=args.n_cases, seed=args.seed)