From e6e777cfb32e8f71b45f1ff7995d9930d19e674c Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Mon, 16 Feb 2026 16:21:36 +0200
Subject: [PATCH] resume eval

Persist the eval state to the output JSON after every completed case and
add a --resume flag: when an existing state file is found, completed
cases are kept and only the pending ones are re-dispatched. The dataset
and task bookkeeping moves from Processor into EvalState, and the stale
example state file is removed.
---
 examples/llama-eval/llama-eval-state.json |  29 -
 examples/llama-eval/llama-eval.py         | 610 ++++++++++++++--------
 2 files changed, 399 insertions(+), 240 deletions(-)
 delete mode 100644 examples/llama-eval/llama-eval-state.json

diff --git a/examples/llama-eval/llama-eval-state.json b/examples/llama-eval/llama-eval-state.json
deleted file mode 100644
index add0f626a3..0000000000
--- a/examples/llama-eval/llama-eval-state.json
+++ /dev/null
@@ -1,29 +0,0 @@
-{
-  "id": "gpqa",
-  "tasks": [
-    "gpqa"
-  ],
-  "task_states": {
-    "gpqa": {
-      "total": 1,
-      "correct": 0,
-      "cases": {
-        "gpqa": [
-          {
-            "case_id": "gpqa_000_184",
-            "prompt": "Consider a system with Hamiltonian operator $H = \\varepsilon \\vec{\\sigma}.\\vec{n}$. Here, $\\vec{n}$ is an arbitrary unit vector, $\\varepsilon $ is a constant of dimension energy, and components of $\\vec{\\sigma}$ are the Pauli spin matrices. What are the eigenvalues of the Hamiltonian operator?\n\n\n(A) +\\hbar/2, -\\hbar/2\n(B) +1, -1\n(C) +\\varepsilon \\hbar/2, - \\varepsilon \\hbar/2\n(D) + \\varepsilon, -\\varepsilon\n\n\nExpress your final answer as the corresponding option 'A', 'B', 'C', or 'D'.\n",
-            "gold": "+ \\varepsilon, -\\varepsilon\n",
-            "pred": null,
-            "extracted": null,
-            "correct": false,
-            "status": "error: HTTPConnectionPool(host='localhost', port=8034): Max retries exceeded with url: /v1/chat/completions (Caused by NewConnectionError(\"HTTPConnection(host='localhost', port=8034): Failed to establish a new connection: [Errno 61] Connection refused\"))"
-          }
-        ]
-      }
-    }
-  },
-  "sampling_config": {
-    "temperature": 0,
-    "max_tokens": 2048
-  }
-}
\ No newline at end of file
diff --git a/examples/llama-eval/llama-eval.py b/examples/llama-eval/llama-eval.py
index 6959ff08d9..0cfa06ff43 100755
--- a/examples/llama-eval/llama-eval.py
+++ b/examples/llama-eval/llama-eval.py
@@ -8,8 +8,9 @@ import re
 import subprocess
 import sys
 import time
+from abc import ABC, abstractmethod
 from concurrent.futures import ThreadPoolExecutor, as_completed
 from dataclasses import dataclass, asdict
 from pathlib import Path
 from typing import Dict, List, Optional, Any, Tuple
 import requests
@@ -71,12 +72,23 @@ Express your final answer as the corresponding option 'A', 'B', 'C', or 'D'.
 """,
 }

-@dataclass
-class EvalState:
-    id: str
-    tasks: List[str]
-    task_states: Dict[str, Dict[str, Any]]
-    sampling_config: Dict[str, Any]
+
+class BaseDataset(ABC):
+    @abstractmethod
+    def get_question(self, index: int) -> Dict:
+        pass
+
+    @abstractmethod
+    def get_answer(self, question: Dict) -> str:
+        pass
+
+    @abstractmethod
+    def get_prompt(self, question: Dict) -> str:
+        pass
+
+    def __len__(self) -> int:
+        return len(self.questions)
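+
+# NOTE: BaseDataset does not define self.questions itself - each concrete
+# dataset below (AimeDataset, Aime2025Dataset, Gsm8kDataset, GpqaDataset)
+# populates self.questions in its __init__, which is what __len__ relies on.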
""", } -@dataclass -class EvalState: - id: str - tasks: List[str] - task_states: Dict[str, Dict[str, Any]] - sampling_config: Dict[str, Any] + +class BaseDataset(ABC): + @abstractmethod + def get_question(self, index: int) -> Dict: + pass + + @abstractmethod + def get_answer(self, question: Dict) -> str: + pass + + @abstractmethod + def get_prompt(self, question: Dict) -> str: + pass + + def __len__(self) -> int: + return len(self.questions) + @dataclass class TaskState: @@ -88,13 +100,267 @@ class TaskState: correct: bool = False status: str = "pending" + +class EvalState: + def __init__( + self, + dataset_type: str, + sampling_config: Dict[str, Any], + output_file: Path = Path("llama-eval-state.json") + ): + self.dataset_type = dataset_type + self.sampling_config = sampling_config + self.output_file = output_file + self.dataset: Optional[BaseDataset] = None + self.tasks: List[Tuple[int, str]] = [] + self.all_tasks: List[Tuple[int, str]] = [] + self.task_states: Dict[str, Any] = {} + self.total = 0 + self.correct = 0 + self.processed = 0 + + def load_dataset(self, seed: int = 1234): + if self.dataset_type == "aime": + self.dataset = AimeDataset() + elif self.dataset_type == "aime2025": + self.dataset = Aime2025Dataset() + elif self.dataset_type == "gsm8k": + self.dataset = Gsm8kDataset() + elif self.dataset_type == "gpqa": + self.dataset = GpqaDataset(variant="diamond", seed=seed) + else: + raise ValueError(f"Unknown dataset type: {self.dataset_type}") + + def setup_tasks(self, n_cases: Optional[int] = None, seed: int = 1234): + if self.dataset is None: + raise ValueError("Dataset not loaded. Call load_dataset() first.") + + if n_cases is None: + n_cases = len(self.dataset) + + dataset_size = len(self.dataset) + rng = random.Random(seed) + + self.tasks = [] + for chunk_idx in range((n_cases + dataset_size - 1) // dataset_size): + chunk_size = min(dataset_size, n_cases - chunk_idx * dataset_size) + indices = list(range(dataset_size)) + rng.shuffle(indices) + chunk_indices = indices[:chunk_size] + + for i in chunk_indices: + task_id = f"{self.dataset_type}_{chunk_idx:03d}_{i:03d}" + self.tasks.append((i, task_id)) + + self.all_tasks = list(self.tasks) + + def get_case(self, index: int) -> Tuple[str, str]: + if self.dataset is None: + raise ValueError("Dataset not loaded.") + question = self.dataset.get_question(index) + prompt = self.dataset.get_prompt(question) + gold = self.dataset.get_answer(question) + return prompt, gold + + def add_result( + self, + task_id: str, + prompt: str, + gold: str, + pred: Optional[str], + extracted: Optional[str], + correct: bool, + status: str + ): + if self.dataset_type not in self.task_states: + self.task_states[self.dataset_type] = {} + if "cases" not in self.task_states[self.dataset_type]: + self.task_states[self.dataset_type]["cases"] = {} + + self.task_states[self.dataset_type]["cases"][task_id] = { + "case_id": task_id, + "prompt": prompt, + "gold": gold, + "pred": pred, + "extracted": extracted, + "correct": correct, + "status": status + } + + if correct: + self.correct += 1 + else: + self.correct = sum(1 for c in self.task_states.get(self.dataset_type, {}).get("cases", {}).values() if c.get("correct", False)) + + def add_grader_log(self, grader_log: Dict[str, Any]): + if self.dataset_type not in self.task_states: + self.task_states[self.dataset_type] = {} + if "grader_log" not in self.task_states[self.dataset_type]: + self.task_states[self.dataset_type]["grader_log"] = [] + self.task_states[self.dataset_type]["grader_log"].append(grader_log) + + 
+
+    @staticmethod
+    def _truncate_prompt(prompt: str) -> str:
+        # first line of the prompt, clipped or padded to 43 chars and always
+        # followed by "..." so the table columns stay aligned
+        first_line = prompt.split('\n')[0]
+        if len(first_line) > 43:
+            return first_line[:43] + "..."
+        return first_line.ljust(43) + "..."
+
+    def print_progress(self, task_state: TaskState, total_tasks: int, correct_count: int = 0):
+        extracted_display = task_state.extracted if task_state.extracted else "N/A"
+        success_ratio = correct_count / self.processed if self.processed > 0 else 0.0
+        truncated_prompt = self._truncate_prompt(task_state.prompt)
+        print(f"{self.processed:3}/{total_tasks:3} {task_state.case_id:<20} {self.dataset_type.upper()} {truncated_prompt:<40} {task_state.gold:<10} {extracted_display:<10} {'✓' if task_state.correct else '✗'} [{correct_count:3}/{self.processed:3}, {success_ratio:.3f}]")
+
+    def print_summary(self):
+        pct = self.correct / self.total * 100 if self.total > 0 else 0.0
+        print(f"\n{'='*60}")
+        print(f"Results: {self.correct}/{self.total} correct ({pct:.1f}%)")
+        print(f"{'='*60}")
+
+    def dump(self):
+        tasks_to_save = self.all_tasks if self.all_tasks else self.tasks
+        all_cases = {}
+        for i, task_id in tasks_to_save:
+            prompt, gold = self.get_case(i)
+            if task_id in self.task_states.get(self.dataset_type, {}).get("cases", {}):
+                all_cases[task_id] = self.task_states[self.dataset_type]["cases"][task_id]
+            else:
+                all_cases[task_id] = {
+                    "case_id": task_id,
+                    "prompt": prompt,
+                    "gold": gold,
+                    "pred": None,
+                    "extracted": None,
+                    "correct": False,
+                    "status": "pending"
+                }
+
+        data = {
+            "id": self.dataset_type,
+            "tasks": [tid for _, tid in tasks_to_save],
+            "task_states": {
+                self.dataset_type: {
+                    "total": self.total,
+                    "correct": self.correct,
+                    "cases": all_cases,
+                    "grader_log": self.task_states.get(self.dataset_type, {}).get("grader_log", [])
+                }
+            },
+            "sampling_config": self.sampling_config
+        }
+        with open(self.output_file, "w") as f:
+            json.dump(data, f, indent=2)
+
+    @classmethod
+    def load(cls, path: Path) -> "EvalState":
+        with open(path, "r") as f:
+            data = json.load(f)
+
+        eval_state = cls(
+            dataset_type=data["id"],
+            sampling_config=data["sampling_config"],
+            output_file=path
+        )
+        eval_state.load_dataset()
+
+        eval_state.tasks = []
+        eval_state.all_tasks = []
+        for task_id in data.get("tasks", []):
+            # task ids have the form "<dataset>_<chunk:03d>_<index:03d>"
+            parts = task_id.rsplit("_", 2)
+            if len(parts) >= 3:
+                idx = int(parts[-1])
+            else:
+                idx = 0
+            eval_state.tasks.append((idx, task_id))
+            eval_state.all_tasks.append((idx, task_id))
+
+        eval_state.task_states = data.get("task_states", {})
+
+        cases = eval_state.task_states.get(eval_state.dataset_type, {}).get("cases", {})
+        eval_state.total = eval_state.task_states.get(eval_state.dataset_type, {}).get("total", 0)
+        eval_state.correct = eval_state.task_states.get(eval_state.dataset_type, {}).get("correct", 0)
+
+        if eval_state.total == 0:
+            eval_state.total = len(cases)
+            eval_state.correct = sum(1 for c in cases.values() if c.get("correct", False))
+
+        return eval_state
+
+    def is_complete(self) -> bool:
+        if not self.all_tasks:
+            return False
+        cases = self.task_states.get(self.dataset_type, {}).get("cases", {})
+        completed = sum(1 for c in cases.values() if c.get("status") == "ok")
+        return completed == len(self.all_tasks)
+
+    def get_pending_tasks(self) -> List[Tuple[int, str]]:
+        cases = self.task_states.get(self.dataset_type, {}).get("cases", {})
+        pending = []
+        for i, task_id in self.all_tasks:
+            if cases.get(task_id, {}).get("status") != "ok":
+                pending.append((i, task_id))
+        return pending
+
+    def print_all_tasks(self):
+        cases = self.task_states.get(self.dataset_type, {}).get("cases", {})
+        tasks_to_show = self.all_tasks if self.all_tasks else self.tasks
+        print("Tasks:")
+        print("  Task ID              Dataset  Prompt (first 40 chars)                     Expected   Extracted  Status")
+        for i, task_id in tasks_to_show:
+            prompt, gold = self.get_case(i)
+            case = cases.get(task_id, {})
+            status = case.get("status", "pending")
+            extracted = (case.get("extracted") or "N/A") if status == "ok" else "N/A"
+            is_correct = case.get("correct", False) if status == "ok" else False
+            symbol = "✓ " if is_correct else ("✗ " if status == "ok" else "")
+            truncated_prompt = self._truncate_prompt(prompt)
+            print(f"  {task_id:<20} {self.dataset_type.upper()} {truncated_prompt:<40} {gold:<10} {extracted:<10} {symbol}{status}")
+        print()
+
+    def print_existing_summary(self):
+        cases = self.task_states.get(self.dataset_type, {}).get("cases", {})
+        correct = sum(1 for c in cases.values() if c.get("correct", False))
+        total = len(cases)
+        pct = correct / total * 100 if total > 0 else 0.0
+        print(f"\n{'='*60}")
+        print(f"Results: {correct}/{total} correct ({pct:.1f}%)")
+        print(f"{'='*60}")
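+
+    # Typical lifecycle (sketch only - the real flags live in main() below):
+    #
+    #   state = EvalState("gsm8k", {"n_predict": 2048})
+    #   state.load_dataset(seed=1234)
+    #   state.setup_tasks(n_cases=50, seed=1234)
+    #   state.dump()                       # all cases start out "pending"
+    #
+    # and after an interrupted run:
+    #
+    #   state = EvalState.load(Path("llama-eval-state.json"))
+    #   todo = state.get_pending_tasks()   # only cases without status "ok"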
+

 def normalize_number(s: str) -> Optional[int]:
     match = re.match(r"\d+", s) # match digits from the start
     if not match:
         return None
     return int(match.group(0))

-class AimeDataset:
+class AimeDataset(BaseDataset):
     def __init__(self, split: str = "train"):
         self.split = split
         self.questions: List[Dict] = []
@@ -139,7 +405,7 @@ class AimeDataset:
         question=question["problem"] if "problem" in question else question["question"]
     )

-class Aime2025Dataset:
+class Aime2025Dataset(BaseDataset):
     def __init__(self):
         self.questions: List[Dict] = []
         self._load_dataset()
@@ -197,7 +463,7 @@ class Aime2025Dataset:
         question=question["question"]
     )

-class Gsm8kDataset:
+class Gsm8kDataset(BaseDataset):
     def __init__(self, split: str = "train"):
         self.split = split
         self.questions: List[Dict] = []
@@ -253,7 +519,7 @@ class Gsm8kDataset:
         question=question["problem"] if "problem" in question else question["question"]
     )

-class GpqaDataset:
+class GpqaDataset(BaseDataset):
     def __init__(self, variant: str = "diamond", seed: int = 1234):
         self.variant = variant
         self.seed = seed
@@ -461,84 +727,38 @@ class Processor:
     def __init__(
         self,
         server_url: str,
-        n_predict: int = -1,
-        threads: int = 32,
-        verbose: bool = False,
-        grader: Optional[Grader] = None,
+        grader: Grader,
         model_name: Optional[str] = None,
-        judge_server_url: str = "",
-        judge_model_name: Optional[str] = None,
-        dataset_type: str = "aime",
-        seed: int = 1234,
-        sampling_config: Optional[Dict[str, Any]] = None,
-        output_file: Optional[Path] = None
+        threads: int = 32
     ):
         self.server_url = server_url
-        self.n_predict = n_predict
-        self.threads = threads
-        self.verbose = verbose
+        self.grader = grader
         self.model_name = model_name
-        self.judge_server_url = judge_server_url if judge_server_url else server_url
-        self.judge_model_name = judge_model_name
-        self.dataset_type = dataset_type
-        self.seed = seed
-        self.grader = grader or Grader()
-        self.sampling_config = sampling_config or {"n_predict": n_predict}
-        self.output_file = output_file or Path("llama-eval-state.json")
-        self.eval_state = EvalState(
-            id=dataset_type,
-            tasks=[dataset_type],
-            task_states={dataset_type: {}},
-            sampling_config=self.sampling_config
-        )
+        self.threads = threads

-        # Pass judge configuration to grader if using LLM grader
-        if self.grader.grader_type == "llm":
-            if self.judge_model_name:
-                self.grader.judge_model_name = self.judge_model_name
-            if self.judge_server_url:
-                self.grader.judge_server_url = self.judge_server_url
-
-        # Initialize appropriate dataset
-        if dataset_type == "aime":
-            self.dataset = AimeDataset()
-        elif dataset_type == "aime2025":
-            self.dataset = Aime2025Dataset()
-        elif dataset_type == "gsm8k":
-            self.dataset = Gsm8kDataset()
-        elif dataset_type == "gpqa":
-            self.dataset = GpqaDataset(variant="diamond", seed=self.seed)
-        else:
-            raise ValueError(f"Unknown dataset type: {dataset_type}")
-
-    def _make_request(self, prompt: str) -> Dict[str, Any]:
-        """Make HTTP request to the server"""
+    def _make_request(self, eval_state: EvalState, prompt: str) -> Dict[str, Any]:
         url = f"{self.server_url}/v1/chat/completions"
         headers = {"Content-Type": "application/json"}
         data = {
             "model": self.model_name if self.model_name else "llama",
             "messages": [{"role": "user", "content": prompt}],
-            "n_predict": self.n_predict
+            "n_predict": eval_state.sampling_config.get("n_predict", -1)
         }
-        if self.sampling_config.get("temperature") is not None:
-            data["temperature"] = self.sampling_config["temperature"]
-        if self.sampling_config.get("top_k") is not None:
-            data["top_k"] = self.sampling_config["top_k"]
-        if self.sampling_config.get("top_p") is not None:
-            data["top_p"] = self.sampling_config["top_p"]
-        if self.sampling_config.get("min_p") is not None:
-            data["min_p"] = self.sampling_config["min_p"]
+        if eval_state.sampling_config.get("temperature") is not None:
+            data["temperature"] = eval_state.sampling_config["temperature"]
+        if eval_state.sampling_config.get("top_k") is not None:
+            data["top_k"] = eval_state.sampling_config["top_k"]
+        if eval_state.sampling_config.get("top_p") is not None:
+            data["top_p"] = eval_state.sampling_config["top_p"]
+        if eval_state.sampling_config.get("min_p") is not None:
+            data["min_p"] = eval_state.sampling_config["min_p"]

         response = requests.post(url, headers=headers, json=data)
         response.raise_for_status()
         return response.json()

-    def _process_single_case(self, i: int, task_id: str) -> TaskState:
-        """Process a single case (thread-safe)"""
-        question = self.dataset.get_question(i)
-        dataset_id = f"{self.dataset_type}_{i}"
-        gold = self.dataset.get_answer(question)
-        prompt = self.dataset.get_prompt(question)
+    def _process_single_case(self, eval_state: EvalState, i: int, task_id: str) -> TaskState:
+        prompt, gold = eval_state.get_case(i)

         task_state = TaskState(
             case_id=task_id,
@@ -547,20 +767,16 @@ class Processor:
         )

         try:
-            response = self._make_request(prompt)
+            response = self._make_request(eval_state, prompt)
             pred = response["choices"][0]["message"]["content"]
             task_state.pred = pred

-            # Truncate response to last 2-3 lines for grading
             pred_truncated = self.grader._truncate_response(pred, max_lines=10)
-
-            # Grade the response
             is_correct, extracted = self.grader.grade(gold, pred_truncated, prompt)
             task_state.correct = is_correct
             task_state.extracted = extracted
             task_state.status = "ok"

-            # Log grader request details for debugging
             grader_log = {
                 "case_id": task_id,
                 "gold": gold,
@@ -571,111 +787,49 @@ class Processor:
             }
             if self.grader.grader_type == "regex" and self.grader.pattern:
                 grader_log["pattern"] = self.grader.pattern
-            if "grader_log" not in self.eval_state.task_states[self.dataset_type]:
-                self.eval_state.task_states[self.dataset_type]["grader_log"] = []
-            self.eval_state.task_states[self.dataset_type]["grader_log"].append(grader_log)
+            eval_state.add_grader_log(grader_log)

-            # Initialize cases dict if it doesn't exist
-            if "cases" not in self.eval_state.task_states[self.dataset_type]:
-                self.eval_state.task_states[self.dataset_type]["cases"] = {}
+            eval_state.add_result(task_id, prompt, gold, pred, extracted, is_correct, "ok")

-            # Update eval state with grading details
-            self.eval_state.task_states[self.dataset_type]["cases"][task_id] = {
-                "case_id": task_id,
-                "prompt": prompt,
-                "gold": gold,
-                "pred": pred,
-                "extracted": extracted,
-                "correct": is_correct,
-                "status": "ok"
-            }
+            eval_state.dump()

-            # Save eval state to disk after each task
-            try:
-                self.dump_state(self.output_file)
-            except Exception as dump_error:
-                task_state.status = f"error: {str(e)}; dump error: {str(dump_error)}"
-        except Exception as processing_error:
-            task_state.status = f"error: {str(processing_error)}"
+        except Exception as e:
+            task_state.status = f"error: {str(e)}"

         return task_state

-    def process(self, n_cases: int = None, seed: int = 1234):
-        """Process cases and update eval state"""
-        if n_cases is None:
-            n_cases = len(self.dataset.questions)
+    def evaluate(self, eval_state: EvalState, verbose: bool = False, resume: bool = False):
+        # when resuming, eval_state.tasks holds only the pending cases while
+        # all_tasks keeps the full schedule for the summary totals
+        total_tasks = len(eval_state.tasks)
+        eval_state.total = len(eval_state.all_tasks) if eval_state.all_tasks else total_tasks
+        eval_state.processed = 0

-        print(f"\nProcessing {n_cases} {self.dataset_type.upper()} questions...")
+        print(f"\nProcessing {len(eval_state.tasks)} {eval_state.dataset_type.upper()} questions...")
         print(f"Server: {self.server_url} (model: {self.model_name})")
-        print(f"Grader: {self.grader.grader_type}", end="")
-        if self.grader.grader_type == "llm":
-            judge_model = self.judge_model_name if self.judge_model_name else self.model_name
-            print(f" (judge server: {self.judge_server_url}, model: {judge_model})", end="")
-        print()
+        print(f"Grader: {self.grader.grader_type}")
         print(f"Threads: {self.threads}")
-        print(f"Max tokens: {self.n_predict}")
-        print(f"Seed: {self.seed}")
-        print(f"Sampling: temp={self.sampling_config.get('temperature', 'skip')}, top-k={self.sampling_config.get('top_k', 'skip')}, top-p={self.sampling_config.get('top_p', 'skip')}, min-p={self.sampling_config.get('min_p', 'skip')}")
+        print(f"Sampling: temp={eval_state.sampling_config.get('temperature', 'skip')}, top-k={eval_state.sampling_config.get('top_k', 'skip')}, top-p={eval_state.sampling_config.get('top_p', 'skip')}, min-p={eval_state.sampling_config.get('min_p', 'skip')}")
         print()

-        dataset_size = len(self.dataset.questions)
-        random.seed(seed)
+        if not resume:
+            eval_state.print_all_tasks()

-        task_list = []
-        for chunk_idx in range((n_cases + dataset_size - 1) // dataset_size):
-            chunk_size = min(dataset_size, n_cases - chunk_idx * dataset_size)
-            indices = list(range(dataset_size))
-            random.shuffle(indices)
-            chunk_indices = indices[:chunk_size]
-
-            for i in chunk_indices:
-                task_id = f"{self.dataset_type}_{chunk_idx:03d}_{i:03d}"
-                task_list.append((i, task_id))
-
-        # Print task summary table
-        print("Tasks:")
-        print("  Task ID              Dataset  Prompt (first 40 chars)                     Expected   Status")
-        for i, task_id in task_list:
-            question = self.dataset.get_question(i)
-            prompt = self.dataset.get_prompt(question)
-            gold = self.dataset.get_answer(question)
-            first_line = prompt.split('\n')[0]
-            truncated_prompt = first_line[:43]
-            if len(first_line) > 43:
-                truncated_prompt += "..."
-            else:
-                truncated_prompt = truncated_prompt.ljust(43) + "..."
-            print(f"  {task_id:<20} {self.dataset_type.upper()} {truncated_prompt:<40} {gold:<10} pending")
-        print()
+        correct_count = 0

-        task_states: Dict[str, List[TaskState]] = {task: [] for task in self.eval_state.tasks}
-        total = 0
-        correct = 0
         with ThreadPoolExecutor(max_workers=self.threads) as executor:
-            futures = {executor.submit(self._process_single_case, i, task_id): (i, task_id) for i, task_id in task_list}
+            futures = {
+                executor.submit(self._process_single_case, eval_state, i, task_id): (i, task_id)
+                for i, task_id in eval_state.tasks
+            }
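+
+            # every worker persists the state via eval_state.dump() as soon
+            # as its case finishes, so an interrupted run loses at most the
+            # cases still in flight and can be picked up again with --resume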

             for future in as_completed(futures):
                 task_state = future.result()
-                task_states[self.dataset_type].append(task_state)
-                total += 1
-
+                eval_state.processed += 1
                 if task_state.correct:
-                    correct += 1
+                    correct_count += 1
+                eval_state.print_progress(task_state, total_tasks, correct_count)

-                # Print task completion status
-                extracted_display = task_state.extracted if task_state.extracted else "N/A"
-                success_ratio = correct / total if total > 0 else 0.0
-                first_line = task_state.prompt.split('\n')[0]
-                truncated_prompt = first_line[:43]
-                if len(first_line) > 43:
-                    truncated_prompt += "..."
-                else:
-                    truncated_prompt = truncated_prompt.ljust(43) + "..."
-                print(f"{total:3}/{n_cases:3} {task_state.case_id:<20} {self.dataset_type.upper()} {truncated_prompt:<40} {task_state.gold:<10} {extracted_display:<10} {'✓' if task_state.correct else '✗'} [{correct:3}/{total:3}, {success_ratio:.3f}]")
-
-                if self.verbose:
-                    print(f"\nCase {total}: {task_state.correct}")
+                if verbose:
+                    print(f"\nCase {eval_state.processed}: {task_state.correct}")
                     print(f"  Gold: {task_state.gold}")
                     if task_state.pred:
                         print(f"  Pred: {task_state.pred}")
@@ -683,25 +837,9 @@ class Processor:
                     print(f"  Extracted: {task_state.extracted}")
                     print(f"  Status: {task_state.status}")

-        # Merge existing state with new state to preserve grader_log
-        existing_state = self.eval_state.task_states.get(self.dataset_type, {})
-        self.eval_state.task_states[self.dataset_type] = {
-            "total": total,
-            "correct": correct,
-            "cases": task_states,
-            **existing_state
-        }
-
-        print(f"\n{'='*60}")
-        print(f"Results: {correct}/{total} correct ({correct/total*100:.1f}%)")
-        print(f"{'='*60}")
-
-        return self.eval_state
-
-    def dump_state(self, output_file: Path):
-        """Dump eval state to JSON file"""
-        with open(output_file, "w") as f:
-            json.dump(asdict(self.eval_state), f, indent=2)
+        # eval_state.correct is maintained by add_result() and, unlike
+        # correct_count, also covers cases completed before a resume
+        eval_state.print_summary()
+        eval_state.dump()

 def main():
     parser = argparse.ArgumentParser(
@@ -810,51 +948,101 @@ def main():
         default="",
         help="Model name for LLM judge (default: same as main model)"
     )
+    parser.add_argument(
+        "--resume",
+        action="store_true",
+        help="Resume from existing eval state"
+    )

     args = parser.parse_args()

-    # Validate grader type for GPQA
     if args.dataset == "gpqa" and args.grader_type != "llm":
         print("Error: GPQA dataset requires --grader-type llm")
         parser.print_help()
         sys.exit(1)

-    grader = Grader(
-        grader_type=args.grader_type,
-        grader_script=args.grader_script,
-        judge_model_name=args.judge_model if args.judge_model else args.model,
-        dataset_type=args.dataset
-    )
+    if args.output.exists():
+        print(f"Loading existing eval state from {args.output}")
+        eval_state = EvalState.load(args.output)

-    if args.grader_type == "llm" and not args.judge_server:
-        print("Warning: Using same server for LLM judge (no --judge-server specified)")
+        eval_state.print_all_tasks()
+        eval_state.print_existing_summary()

-    sampling_config = {"n_predict": args.n_predict}
-    if args.temperature is not None:
-        sampling_config["temperature"] = args.temperature
-    if args.top_k is not None:
-        sampling_config["top_k"] = args.top_k
-    if args.top_p is not None:
-        sampling_config["top_p"] = args.top_p
-    if args.min_p is not None:
-        sampling_config["min_p"] = args.min_p
+        if eval_state.is_complete():
+            return
+
+        if not args.resume:
+            print("Evaluation incomplete. Run with --resume to continue.")
+            return
+
+        pending_tasks = eval_state.get_pending_tasks()
+        print(f"Resuming from {len(pending_tasks)} pending tasks")
+
+        # setdefault keeps the per-dataset dict attached to task_states -
+        # chaining .get() with a default here would only mutate a throwaway
+        # dict when the key is missing
+        dataset_state = eval_state.task_states.setdefault(eval_state.dataset_type, {})
+        dataset_state.setdefault("cases", {})
+        dataset_state["grader_log"] = []
+
+        eval_state.tasks = pending_tasks
+
+        judge_server_url = args.judge_server if args.judge_server else args.server
+        judge_model_name = args.judge_model if args.judge_model else args.model
+
+        grader = Grader(
+            grader_type=args.grader_type,
+            grader_script=args.grader_script,
+            judge_model_name=judge_model_name,
+            judge_server_url=judge_server_url,
+            dataset_type=eval_state.dataset_type
+        )
+        resume = True
+    else:
+        if args.resume:
+            print("Error: No existing eval state found to resume")
+            sys.exit(1)
+
+        judge_server_url = args.judge_server if args.judge_server else args.server
+        judge_model_name = args.judge_model if args.judge_model else args.model
+
+        grader = Grader(
+            grader_type=args.grader_type,
+            grader_script=args.grader_script,
+            judge_model_name=judge_model_name,
+            judge_server_url=judge_server_url,
+            dataset_type=args.dataset
+        )
+
+        if args.grader_type == "llm" and not args.judge_server:
+            print("Warning: Using same server for LLM judge (no --judge-server specified)")
+
+        sampling_config = {"n_predict": args.n_predict}
+        if args.temperature is not None:
+            sampling_config["temperature"] = args.temperature
+        if args.top_k is not None:
+            sampling_config["top_k"] = args.top_k
+        if args.top_p is not None:
+            sampling_config["top_p"] = args.top_p
+        if args.min_p is not None:
+            sampling_config["min_p"] = args.min_p
+
+        eval_state = EvalState(
+            dataset_type=args.dataset,
+            sampling_config=sampling_config,
+            output_file=args.output
+        )
+        eval_state.load_dataset(seed=args.seed)
+        eval_state.setup_tasks(n_cases=args.n_cases, seed=args.seed)
+        eval_state.dump()
+        resume = False

     processor = Processor(
         server_url=args.server,
-        n_predict=args.n_predict,
-        threads=args.threads,
-        verbose=args.verbose,
         grader=grader,
         model_name=args.model,
-        judge_server_url=args.judge_server,
-        judge_model_name=args.judge_model,
-        dataset_type=args.dataset,
-        sampling_config=sampling_config,
-        output_file=args.output
+        threads=args.threads
     )

-    eval_state = processor.process(n_cases=args.n_cases, seed=args.seed)
-    processor.dump_state(args.output)
+    processor.evaluate(eval_state, verbose=args.verbose, resume=resume)
     print(f"\nEval state dumped to {args.output}")

 if __name__ == "__main__":