resume eval

Georgi Gerganov 2026-02-16 16:21:36 +02:00
parent ad3a54eb68
commit e6e777cfb3
GPG Key ID: 449E073F9DC10735
2 changed files with 399 additions and 240 deletions


@@ -1,29 +0,0 @@
{
"id": "gpqa",
"tasks": [
"gpqa"
],
"task_states": {
"gpqa": {
"total": 1,
"correct": 0,
"cases": {
"gpqa": [
{
"case_id": "gpqa_000_184",
"prompt": "Consider a system with Hamiltonian operator $H = \\varepsilon \\vec{\\sigma}.\\vec{n}$. Here, $\\vec{n}$ is an arbitrary unit vector, $\\varepsilon $ is a constant of dimension energy, and components of $\\vec{\\sigma}$ are the Pauli spin matrices. What are the eigenvalues of the Hamiltonian operator?\n\n\n(A) +\\hbar/2, -\\hbar/2\n(B) +1, -1\n(C) +\\varepsilon \\hbar/2, - \\varepsilon \\hbar/2\n(D) + \\varepsilon, -\\varepsilon\n\n\nExpress your final answer as the corresponding option 'A', 'B', 'C', or 'D'.\n",
"gold": "+ \\varepsilon, -\\varepsilon\n",
"pred": null,
"extracted": null,
"correct": false,
"status": "error: HTTPConnectionPool(host='localhost', port=8034): Max retries exceeded with url: /v1/chat/completions (Caused by NewConnectionError(\"HTTPConnection(host='localhost', port=8034): Failed to establish a new connection: [Errno 61] Connection refused\"))"
}
]
}
}
},
"sampling_config": {
"temperature": 0,
"max_tokens": 2048
}
}
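For context, a minimal sketch of how a saved state file like the one above is consumed once resume support is in place: any case whose status is not "ok" (pending cases, or failures such as the connection error recorded here) gets re-run. This assumes the new on-disk layout written by EvalState.dump() further down, where "cases" is keyed by case id rather than the per-dataset list used in this deleted file; the helper name pending_case_ids is illustrative and not part of the script.

import json
from pathlib import Path
from typing import List

def pending_case_ids(state_path: Path) -> List[str]:
    """Collect the case ids that still need to be (re-)run from a saved eval state."""
    with open(state_path, "r") as f:
        state = json.load(f)

    dataset_type = state["id"]
    # assumed layout: task_states[<dataset>]["cases"] is a dict keyed by case id
    cases = state.get("task_states", {}).get(dataset_type, {}).get("cases", {})

    pending = []
    for task_id in state.get("tasks", []):
        if cases.get(task_id, {}).get("status") != "ok":
            pending.append(task_id)
    return pending

if __name__ == "__main__":
    # a run interrupted by a dead server leaves its cases in an "error: ..." state
    print(pending_case_ids(Path("llama-eval-state.json")))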


@@ -8,8 +8,9 @@ import re
import subprocess
import sys
import time
from abc import ABC, abstractmethod
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass, asdict
from dataclasses import dataclass, asdict, field
from pathlib import Path
from typing import Dict, List, Optional, Any, Tuple
import requests
@@ -71,12 +72,23 @@ Express your final answer as the corresponding option 'A', 'B', 'C', or 'D'.
""",
}
@dataclass
class EvalState:
id: str
tasks: List[str]
task_states: Dict[str, Dict[str, Any]]
sampling_config: Dict[str, Any]
class BaseDataset(ABC):
@abstractmethod
def get_question(self, index: int) -> Dict:
pass
@abstractmethod
def get_answer(self, question: Dict) -> str:
pass
@abstractmethod
def get_prompt(self, question: Dict) -> str:
pass
def __len__(self) -> int:
return len(self.questions)
@dataclass
class TaskState:
@@ -88,13 +100,267 @@ class TaskState:
correct: bool = False
status: str = "pending"
class EvalState:
def __init__(
self,
dataset_type: str,
sampling_config: Dict[str, Any],
output_file: Path = Path("llama-eval-state.json")
):
self.dataset_type = dataset_type
self.sampling_config = sampling_config
self.output_file = output_file
self.dataset: Optional[BaseDataset] = None
self.tasks: List[Tuple[int, str]] = []
self.all_tasks: List[Tuple[int, str]] = []
self.task_states: Dict[str, Any] = {}
self.total = 0
self.correct = 0
self.processed = 0
def load_dataset(self, seed: int = 1234):
if self.dataset_type == "aime":
self.dataset = AimeDataset()
elif self.dataset_type == "aime2025":
self.dataset = Aime2025Dataset()
elif self.dataset_type == "gsm8k":
self.dataset = Gsm8kDataset()
elif self.dataset_type == "gpqa":
self.dataset = GpqaDataset(variant="diamond", seed=seed)
else:
raise ValueError(f"Unknown dataset type: {self.dataset_type}")
def setup_tasks(self, n_cases: Optional[int] = None, seed: int = 1234):
if self.dataset is None:
raise ValueError("Dataset not loaded. Call load_dataset() first.")
if n_cases is None:
n_cases = len(self.dataset)
dataset_size = len(self.dataset)
rng = random.Random(seed)
self.tasks = []
for chunk_idx in range((n_cases + dataset_size - 1) // dataset_size):
chunk_size = min(dataset_size, n_cases - chunk_idx * dataset_size)
indices = list(range(dataset_size))
rng.shuffle(indices)
chunk_indices = indices[:chunk_size]
for i in chunk_indices:
task_id = f"{self.dataset_type}_{chunk_idx:03d}_{i:03d}"
self.tasks.append((i, task_id))
self.all_tasks = list(self.tasks)
def get_case(self, index: int) -> Tuple[str, str]:
if self.dataset is None:
raise ValueError("Dataset not loaded.")
question = self.dataset.get_question(index)
prompt = self.dataset.get_prompt(question)
gold = self.dataset.get_answer(question)
return prompt, gold
def add_result(
self,
task_id: str,
prompt: str,
gold: str,
pred: Optional[str],
extracted: Optional[str],
correct: bool,
status: str
):
if self.dataset_type not in self.task_states:
self.task_states[self.dataset_type] = {}
if "cases" not in self.task_states[self.dataset_type]:
self.task_states[self.dataset_type]["cases"] = {}
self.task_states[self.dataset_type]["cases"][task_id] = {
"case_id": task_id,
"prompt": prompt,
"gold": gold,
"pred": pred,
"extracted": extracted,
"correct": correct,
"status": status
}
if correct:
self.correct += 1
else:
self.correct = sum(1 for c in self.task_states.get(self.dataset_type, {}).get("cases", {}).values() if c.get("correct", False))
def add_grader_log(self, grader_log: Dict[str, Any]):
if self.dataset_type not in self.task_states:
self.task_states[self.dataset_type] = {}
if "grader_log" not in self.task_states[self.dataset_type]:
self.task_states[self.dataset_type]["grader_log"] = []
self.task_states[self.dataset_type]["grader_log"].append(grader_log)
def print_task_header(self):
tasks_to_show = self.all_tasks if self.all_tasks else self.tasks
cases = self.task_states.get(self.dataset_type, {}).get("cases", {})
print("Tasks:")
print(" Task ID Dataset Prompt (first 40 chars) Expected Extracted Status")
for i, task_id in tasks_to_show:
prompt, gold = self.get_case(i)
case = cases.get(task_id, {})
status = case.get("status", "pending")
extracted = case.get("extracted", "N/A") if status == "ok" else "N/A"
is_correct = case.get("correct", False) if status == "ok" else False
symbol = "" if is_correct else ("" if status == "ok" else "")
first_line = prompt.split('\n')[0]
truncated_prompt = first_line[:43]
if len(first_line) > 43:
truncated_prompt += "..."
else:
truncated_prompt = truncated_prompt.ljust(43) + "..."
print(f" {task_id:<20} {self.dataset_type.upper()} {truncated_prompt:<40} {gold:<10} {extracted:<10} {symbol}{status}")
print()
def print_progress(self, task_state: TaskState, total_tasks: int, correct_count: int = 0):
extracted_display = task_state.extracted if task_state.extracted else "N/A"
success_ratio = correct_count / self.processed if self.processed > 0 else 0.0
first_line = task_state.prompt.split('\n')[0]
truncated_prompt = first_line[:43]
if len(first_line) > 43:
truncated_prompt += "..."
else:
truncated_prompt = truncated_prompt.ljust(43) + "..."
print(f"{self.processed:3}/{total_tasks:3} {task_state.case_id:<20} {self.dataset_type.upper()} {truncated_prompt:<40} {task_state.gold:<10} {extracted_display:<10} {'' if task_state.correct else ''} [{correct_count:3}/{self.processed:3}, {success_ratio:.3f}]")
def print_summary(self):
if self.total == 0:
print(f"\n{'='*60}")
print(f"Results: 0/0 correct (0.0%)")
print(f"{'='*60}")
else:
print(f"\n{'='*60}")
print(f"Results: {self.correct}/{self.total} correct ({self.correct/self.total*100:.1f}%)")
print(f"{'='*60}")
def dump(self):
tasks_to_save = self.all_tasks if self.all_tasks else self.tasks
all_cases = {}
for i, task_id in tasks_to_save:
prompt, gold = self.get_case(i)
if task_id in self.task_states.get(self.dataset_type, {}).get("cases", {}):
all_cases[task_id] = self.task_states[self.dataset_type]["cases"][task_id]
else:
all_cases[task_id] = {
"case_id": task_id,
"prompt": prompt,
"gold": gold,
"pred": None,
"extracted": None,
"correct": False,
"status": "pending"
}
data = {
"id": self.dataset_type,
"tasks": [tid for _, tid in tasks_to_save],
"task_states": {
self.dataset_type: {
"total": self.total,
"correct": self.correct,
"cases": all_cases,
"grader_log": self.task_states.get("grader_log", [])
}
},
"sampling_config": self.sampling_config
}
with open(self.output_file, "w") as f:
json.dump(data, f, indent=2)
@classmethod
def load(cls, path: Path) -> "EvalState":
with open(path, "r") as f:
data = json.load(f)
eval_state = cls(
dataset_type=data["id"],
sampling_config=data["sampling_config"],
output_file=path
)
eval_state.load_dataset()
eval_state.tasks = []
eval_state.all_tasks = []
for task_id in data.get("tasks", []):
parts = task_id.rsplit("_", 2)
if len(parts) >= 3:
idx = int(parts[-1])
else:
idx = 0
eval_state.tasks.append((idx, task_id))
eval_state.all_tasks.append((idx, task_id))
eval_state.task_states = data.get("task_states", {})
cases = eval_state.task_states.get(eval_state.dataset_type, {}).get("cases", {})
eval_state.total = eval_state.task_states.get(eval_state.dataset_type, {}).get("total", 0)
eval_state.correct = eval_state.task_states.get(eval_state.dataset_type, {}).get("correct", 0)
if eval_state.total == 0:
eval_state.total = len(cases)
eval_state.correct = sum(1 for c in cases.values() if c.get("correct", False))
return eval_state
def is_complete(self) -> bool:
if not self.all_tasks:
return False
cases = self.task_states.get(self.dataset_type, {}).get("cases", {})
completed = {tid for tid in self.task_states.get(self.dataset_type, {}).get("cases", {}).keys() if cases.get(tid, {}).get("status") == "ok"}
return len(completed) == len(self.all_tasks)
def get_pending_tasks(self) -> List[Tuple[int, str]]:
cases = self.task_states.get(self.dataset_type, {}).get("cases", {})
pending = []
for i, task_id in self.all_tasks:
if cases.get(task_id, {}).get("status") != "ok":
pending.append((i, task_id))
return pending
def print_all_tasks(self):
cases = self.task_states.get(self.dataset_type, {}).get("cases", {})
tasks_to_show = self.all_tasks if self.all_tasks else self.tasks
print("Tasks:")
print(" Task ID Dataset Prompt (first 40 chars) Expected Extracted Status")
for i, task_id in tasks_to_show:
prompt, gold = self.get_case(i)
case = cases.get(task_id, {})
status = case.get("status", "pending")
extracted = case.get("extracted", "N/A") if status == "ok" else "N/A"
is_correct = case.get("correct", False) if status == "ok" else False
symbol = "" if is_correct else ("" if status == "ok" else "")
first_line = prompt.split('\n')[0]
truncated_prompt = first_line[:43]
if len(first_line) > 43:
truncated_prompt += "..."
else:
truncated_prompt = truncated_prompt.ljust(43) + "..."
print(f" {task_id:<20} {self.dataset_type.upper()} {truncated_prompt:<40} {gold:<10} {extracted:<10} {symbol}{status}")
print()
def print_existing_summary(self):
cases = self.task_states.get(self.dataset_type, {}).get("cases", {})
correct = sum(1 for c in cases.values() if c.get("correct", False))
total = len(cases)
print(f"\n{'='*60}")
print(f"Results: {correct}/{total} correct ({correct/total*100:.1f}%)")
print(f"{'='*60}")
def normalize_number(s: str) -> Optional[int]:
match = re.match(r"\d+", s) # match digits from the start
if not match:
return None
return int(match.group(0))
class AimeDataset:
class AimeDataset(BaseDataset):
def __init__(self, split: str = "train"):
self.split = split
self.questions: List[Dict] = []
@@ -139,7 +405,7 @@ class AimeDataset:
question=question["problem"] if "problem" in question else question["question"]
)
class Aime2025Dataset:
class Aime2025Dataset(BaseDataset):
def __init__(self):
self.questions: List[Dict] = []
self._load_dataset()
@@ -197,7 +463,7 @@ class Aime2025Dataset:
question=question["question"]
)
class Gsm8kDataset:
class Gsm8kDataset(BaseDataset):
def __init__(self, split: str = "train"):
self.split = split
self.questions: List[Dict] = []
@@ -253,7 +519,7 @@ class Gsm8kDataset:
question=question["problem"] if "problem" in question else question["question"]
)
class GpqaDataset:
class GpqaDataset(BaseDataset):
def __init__(self, variant: str = "diamond", seed: int = 1234):
self.variant = variant
self.seed = seed
@@ -461,84 +727,38 @@ class Processor:
def __init__(
self,
server_url: str,
n_predict: int = -1,
threads: int = 32,
verbose: bool = False,
grader: Optional[Grader] = None,
grader: Grader,
model_name: Optional[str] = None,
judge_server_url: str = "",
judge_model_name: Optional[str] = None,
dataset_type: str = "aime",
seed: int = 1234,
sampling_config: Optional[Dict[str, Any]] = None,
output_file: Optional[Path] = None
threads: int = 32
):
self.server_url = server_url
self.n_predict = n_predict
self.threads = threads
self.verbose = verbose
self.grader = grader
self.model_name = model_name
self.judge_server_url = judge_server_url if judge_server_url else server_url
self.judge_model_name = judge_model_name
self.dataset_type = dataset_type
self.seed = seed
self.grader = grader or Grader()
self.sampling_config = sampling_config or {"n_predict": n_predict}
self.output_file = output_file or Path("llama-eval-state.json")
self.eval_state = EvalState(
id=dataset_type,
tasks=[dataset_type],
task_states={dataset_type: {}},
sampling_config=self.sampling_config
)
self.threads = threads
# Pass judge configuration to grader if using LLM grader
if self.grader.grader_type == "llm":
if self.judge_model_name:
self.grader.judge_model_name = self.judge_model_name
if self.judge_server_url:
self.grader.judge_server_url = self.judge_server_url
# Initialize appropriate dataset
if dataset_type == "aime":
self.dataset = AimeDataset()
elif dataset_type == "aime2025":
self.dataset = Aime2025Dataset()
elif dataset_type == "gsm8k":
self.dataset = Gsm8kDataset()
elif dataset_type == "gpqa":
self.dataset = GpqaDataset(variant="diamond", seed=self.seed)
else:
raise ValueError(f"Unknown dataset type: {dataset_type}")
def _make_request(self, prompt: str) -> Dict[str, Any]:
"""Make HTTP request to the server"""
def _make_request(self, eval_state: EvalState, prompt: str) -> Dict[str, Any]:
url = f"{self.server_url}/v1/chat/completions"
headers = {"Content-Type": "application/json"}
data = {
"model": self.model_name if self.model_name else "llama",
"messages": [{"role": "user", "content": prompt}],
"n_predict": self.n_predict
"n_predict": eval_state.sampling_config.get("n_predict", -1)
}
if self.sampling_config.get("temperature") is not None:
data["temperature"] = self.sampling_config["temperature"]
if self.sampling_config.get("top_k") is not None:
data["top_k"] = self.sampling_config["top_k"]
if self.sampling_config.get("top_p") is not None:
data["top_p"] = self.sampling_config["top_p"]
if self.sampling_config.get("min_p") is not None:
data["min_p"] = self.sampling_config["min_p"]
if eval_state.sampling_config.get("temperature") is not None:
data["temperature"] = eval_state.sampling_config["temperature"]
if eval_state.sampling_config.get("top_k") is not None:
data["top_k"] = eval_state.sampling_config["top_k"]
if eval_state.sampling_config.get("top_p") is not None:
data["top_p"] = eval_state.sampling_config["top_p"]
if eval_state.sampling_config.get("min_p") is not None:
data["min_p"] = eval_state.sampling_config["min_p"]
response = requests.post(url, headers=headers, json=data)
response.raise_for_status()
return response.json()
def _process_single_case(self, i: int, task_id: str) -> TaskState:
"""Process a single case (thread-safe)"""
question = self.dataset.get_question(i)
dataset_id = f"{self.dataset_type}_{i}"
gold = self.dataset.get_answer(question)
prompt = self.dataset.get_prompt(question)
def _process_single_case(self, eval_state: EvalState, i: int, task_id: str) -> TaskState:
prompt, gold = eval_state.get_case(i)
task_state = TaskState(
case_id=task_id,
@@ -547,20 +767,16 @@ class Processor:
)
try:
response = self._make_request(prompt)
response = self._make_request(eval_state, prompt)
pred = response["choices"][0]["message"]["content"]
task_state.pred = pred
# Truncate the response to its last lines (max_lines=10) before grading
pred_truncated = self.grader._truncate_response(pred, max_lines=10)
# Grade the response
is_correct, extracted = self.grader.grade(gold, pred_truncated, prompt)
task_state.correct = is_correct
task_state.extracted = extracted
task_state.status = "ok"
# Log grader request details for debugging
grader_log = {
"case_id": task_id,
"gold": gold,
@@ -571,111 +787,49 @@ class Processor:
}
if self.grader.grader_type == "regex" and self.grader.pattern:
grader_log["pattern"] = self.grader.pattern
if "grader_log" not in self.eval_state.task_states[self.dataset_type]:
self.eval_state.task_states[self.dataset_type]["grader_log"] = []
self.eval_state.task_states[self.dataset_type]["grader_log"].append(grader_log)
eval_state.add_grader_log(grader_log)
# Initialize cases dict if it doesn't exist
if "cases" not in self.eval_state.task_states[self.dataset_type]:
self.eval_state.task_states[self.dataset_type]["cases"] = {}
eval_state.add_result(task_id, prompt, gold, pred, extracted, is_correct, "ok")
# Update eval state with grading details
self.eval_state.task_states[self.dataset_type]["cases"][task_id] = {
"case_id": task_id,
"prompt": prompt,
"gold": gold,
"pred": pred,
"extracted": extracted,
"correct": is_correct,
"status": "ok"
}
eval_state.dump()
# Save eval state to disk after each task
try:
self.dump_state(self.output_file)
except Exception as dump_error:
task_state.status = f"error: {str(e)}; dump error: {str(dump_error)}"
except Exception as processing_error:
task_state.status = f"error: {str(processing_error)}"
except Exception as e:
task_state.status = f"error: {str(e)}"
return task_state
def process(self, n_cases: int = None, seed: int = 1234):
"""Process cases and update eval state"""
if n_cases is None:
n_cases = len(self.dataset.questions)
def evaluate(self, eval_state: EvalState, verbose: bool = False, resume: bool = False):
total_tasks = len(eval_state.tasks)
eval_state.total = len(eval_state.all_tasks) if eval_state.all_tasks else total_tasks
eval_state.processed = 0
print(f"\nProcessing {n_cases} {self.dataset_type.upper()} questions...")
print(f"\nProcessing {len(eval_state.tasks)} {eval_state.dataset_type.upper()} questions...")
print(f"Server: {self.server_url} (model: {self.model_name})")
print(f"Grader: {self.grader.grader_type}", end="")
if self.grader.grader_type == "llm":
judge_model = self.judge_model_name if self.judge_model_name else self.model_name
print(f" (judge server: {self.judge_server_url}, model: {judge_model})", end="")
print()
print(f"Grader: {self.grader.grader_type}")
print(f"Threads: {self.threads}")
print(f"Max tokens: {self.n_predict}")
print(f"Seed: {self.seed}")
print(f"Sampling: temp={self.sampling_config.get('temperature', 'skip')}, top-k={self.sampling_config.get('top_k', 'skip')}, top-p={self.sampling_config.get('top_p', 'skip')}, min-p={self.sampling_config.get('min_p', 'skip')}")
print(f"Sampling: temp={eval_state.sampling_config.get('temperature', 'skip')}, top-k={eval_state.sampling_config.get('top_k', 'skip')}, top-p={eval_state.sampling_config.get('top_p', 'skip')}, min-p={eval_state.sampling_config.get('min_p', 'skip')}")
print()
dataset_size = len(self.dataset.questions)
random.seed(seed)
if not resume:
eval_state.print_task_header()
task_list = []
for chunk_idx in range((n_cases + dataset_size - 1) // dataset_size):
chunk_size = min(dataset_size, n_cases - chunk_idx * dataset_size)
indices = list(range(dataset_size))
random.shuffle(indices)
chunk_indices = indices[:chunk_size]
for i in chunk_indices:
task_id = f"{self.dataset_type}_{chunk_idx:03d}_{i:03d}"
task_list.append((i, task_id))
# Print task summary table
print("Tasks:")
print(" Task ID Dataset Prompt (first 40 chars) Expected Status")
for i, task_id in task_list:
question = self.dataset.get_question(i)
prompt = self.dataset.get_prompt(question)
gold = self.dataset.get_answer(question)
first_line = prompt.split('\n')[0]
truncated_prompt = first_line[:43]
if len(first_line) > 43:
truncated_prompt += "..."
else:
truncated_prompt = truncated_prompt.ljust(43) + "..."
print(f" {task_id:<20} {self.dataset_type.upper()} {truncated_prompt:<40} {gold:<10} pending")
print()
task_states: Dict[str, List[TaskState]] = {task: [] for task in self.eval_state.tasks}
total = 0
correct = 0
correct_count = 0
with ThreadPoolExecutor(max_workers=self.threads) as executor:
futures = {executor.submit(self._process_single_case, i, task_id): (i, task_id) for i, task_id in task_list}
futures = {
executor.submit(self._process_single_case, eval_state, i, task_id): (i, task_id)
for i, task_id in eval_state.tasks
}
for future in as_completed(futures):
task_state = future.result()
task_states[self.dataset_type].append(task_state)
total += 1
eval_state.processed += 1
if task_state.correct:
correct += 1
correct_count += 1
eval_state.print_progress(task_state, total_tasks, correct_count)
# Print task completion status
extracted_display = task_state.extracted if task_state.extracted else "N/A"
success_ratio = correct / total if total > 0 else 0.0
first_line = task_state.prompt.split('\n')[0]
truncated_prompt = first_line[:43]
if len(first_line) > 43:
truncated_prompt += "..."
else:
truncated_prompt = truncated_prompt.ljust(43) + "..."
print(f"{total:3}/{n_cases:3} {task_state.case_id:<20} {self.dataset_type.upper()} {truncated_prompt:<40} {task_state.gold:<10} {extracted_display:<10} {'' if task_state.correct else ''} [{correct:3}/{total:3}, {success_ratio:.3f}]")
if self.verbose:
print(f"\nCase {total}: {task_state.correct}")
if verbose:
print(f"\nCase {eval_state.processed}: {task_state.correct}")
print(f" Gold: {task_state.gold}")
if task_state.pred:
print(f" Pred: {task_state.pred}")
@@ -683,25 +837,9 @@ class Processor:
print(f" Extracted: {task_state.extracted}")
print(f" Status: {task_state.status}")
# Merge existing state with new state to preserve grader_log
existing_state = self.eval_state.task_states.get(self.dataset_type, {})
self.eval_state.task_states[self.dataset_type] = {
"total": total,
"correct": correct,
"cases": task_states,
**existing_state
}
print(f"\n{'='*60}")
print(f"Results: {correct}/{total} correct ({correct/total*100:.1f}%)")
print(f"{'='*60}")
return self.eval_state
def dump_state(self, output_file: Path):
"""Dump eval state to JSON file"""
with open(output_file, "w") as f:
json.dump(asdict(self.eval_state), f, indent=2)
eval_state.correct = correct_count
eval_state.print_summary()
eval_state.dump()
def main():
parser = argparse.ArgumentParser(
@@ -810,51 +948,101 @@ def main():
default="",
help="Model name for LLM judge (default: same as main model)"
)
parser.add_argument(
"--resume",
action="store_true",
help="Resume from existing eval state"
)
args = parser.parse_args()
# Validate grader type for GPQA
if args.dataset == "gpqa" and args.grader_type != "llm":
print("Error: GPQA dataset requires --grader-type llm")
parser.print_help()
sys.exit(1)
grader = Grader(
grader_type=args.grader_type,
grader_script=args.grader_script,
judge_model_name=args.judge_model if args.judge_model else args.model,
dataset_type=args.dataset
)
if args.output.exists():
print(f"Loading existing eval state from {args.output}")
eval_state = EvalState.load(args.output)
if args.grader_type == "llm" and not args.judge_server:
print("Warning: Using same server for LLM judge (no --judge-server specified)")
if eval_state.is_complete():
eval_state.print_all_tasks()
eval_state.print_existing_summary()
return
sampling_config = {"n_predict": args.n_predict}
if args.temperature is not None:
sampling_config["temperature"] = args.temperature
if args.top_k is not None:
sampling_config["top_k"] = args.top_k
if args.top_p is not None:
sampling_config["top_p"] = args.top_p
if args.min_p is not None:
sampling_config["min_p"] = args.min_p
eval_state.print_all_tasks()
eval_state.print_existing_summary()
if not args.resume:
print(f"Evaluation incomplete. Run with --resume to continue.")
return
pending_tasks = eval_state.get_pending_tasks()
print(f"Resuming from {len(pending_tasks)} pending tasks")
existing_cases = eval_state.task_states.get(eval_state.dataset_type, {}).get("cases", {})
eval_state.tasks = pending_tasks
eval_state.task_states.get(eval_state.dataset_type, {})["cases"] = existing_cases
eval_state.task_states.get(eval_state.dataset_type, {})["grader_log"] = []
judge_server_url = args.judge_server if args.judge_server else args.server
judge_model_name = args.judge_model if args.judge_model else args.model
grader = Grader(
grader_type=args.grader_type,
grader_script=args.grader_script,
judge_model_name=judge_model_name,
judge_server_url=judge_server_url,
dataset_type=eval_state.dataset_type
)
resume = True
else:
if args.resume:
print("Error: No existing eval state found to resume")
sys.exit(1)
judge_server_url = args.judge_server if args.judge_server else args.server
judge_model_name = args.judge_model if args.judge_model else args.model
grader = Grader(
grader_type=args.grader_type,
grader_script=args.grader_script,
judge_model_name=judge_model_name,
judge_server_url=judge_server_url,
dataset_type=args.dataset
)
if args.grader_type == "llm" and not args.judge_server:
print("Warning: Using same server for LLM judge (no --judge-server specified)")
sampling_config = {"n_predict": args.n_predict}
if args.temperature is not None:
sampling_config["temperature"] = args.temperature
if args.top_k is not None:
sampling_config["top_k"] = args.top_k
if args.top_p is not None:
sampling_config["top_p"] = args.top_p
if args.min_p is not None:
sampling_config["min_p"] = args.min_p
eval_state = EvalState(
dataset_type=args.dataset,
sampling_config=sampling_config,
output_file=args.output
)
eval_state.load_dataset(seed=args.seed)
eval_state.setup_tasks(n_cases=args.n_cases, seed=args.seed)
eval_state.dump()
resume = False
processor = Processor(
server_url=args.server,
n_predict=args.n_predict,
threads=args.threads,
verbose=args.verbose,
grader=grader,
model_name=args.model,
judge_server_url=args.judge_server,
judge_model_name=args.judge_model,
dataset_type=args.dataset,
sampling_config=sampling_config,
output_file=args.output
threads=args.threads
)
eval_state = processor.process(n_cases=args.n_cases, seed=args.seed)
processor.dump_state(args.output)
processor.evaluate(eval_state, verbose=args.verbose, resume=resume)
print(f"\nEval state dumped to {args.output}")
if __name__ == "__main__":