resume eval
This commit is contained in:
parent
ad3a54eb68
commit
e6e777cfb3
|
|
@ -1,29 +0,0 @@
|
|||
{
|
||||
"id": "gpqa",
|
||||
"tasks": [
|
||||
"gpqa"
|
||||
],
|
||||
"task_states": {
|
||||
"gpqa": {
|
||||
"total": 1,
|
||||
"correct": 0,
|
||||
"cases": {
|
||||
"gpqa": [
|
||||
{
|
||||
"case_id": "gpqa_000_184",
|
||||
"prompt": "Consider a system with Hamiltonian operator $H = \\varepsilon \\vec{\\sigma}.\\vec{n}$. Here, $\\vec{n}$ is an arbitrary unit vector, $\\varepsilon $ is a constant of dimension energy, and components of $\\vec{\\sigma}$ are the Pauli spin matrices. What are the eigenvalues of the Hamiltonian operator?\n\n\n(A) +\\hbar/2, -\\hbar/2\n(B) +1, -1\n(C) +\\varepsilon \\hbar/2, - \\varepsilon \\hbar/2\n(D) + \\varepsilon, -\\varepsilon\n\n\nExpress your final answer as the corresponding option 'A', 'B', 'C', or 'D'.\n",
|
||||
"gold": "+ \\varepsilon, -\\varepsilon\n",
|
||||
"pred": null,
|
||||
"extracted": null,
|
||||
"correct": false,
|
||||
"status": "error: HTTPConnectionPool(host='localhost', port=8034): Max retries exceeded with url: /v1/chat/completions (Caused by NewConnectionError(\"HTTPConnection(host='localhost', port=8034): Failed to establish a new connection: [Errno 61] Connection refused\"))"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
},
|
||||
"sampling_config": {
|
||||
"temperature": 0,
|
||||
"max_tokens": 2048
|
||||
}
|
||||
}
|
||||
|
|
@ -8,8 +8,9 @@ import re
|
|||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from dataclasses import dataclass, asdict
|
||||
from dataclasses import dataclass, asdict, field
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Any, Tuple
|
||||
import requests
|
||||
|
|
@ -71,12 +72,23 @@ Express your final answer as the corresponding option 'A', 'B', 'C', or 'D'.
|
|||
""",
|
||||
}
|
||||
|
||||
@dataclass
|
||||
class EvalState:
|
||||
id: str
|
||||
tasks: List[str]
|
||||
task_states: Dict[str, Dict[str, Any]]
|
||||
sampling_config: Dict[str, Any]
|
||||
|
||||
class BaseDataset(ABC):
|
||||
@abstractmethod
|
||||
def get_question(self, index: int) -> Dict:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_answer(self, question: Dict) -> str:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_prompt(self, question: Dict) -> str:
|
||||
pass
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.questions)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TaskState:
|
||||
|
|
@ -88,13 +100,267 @@ class TaskState:
|
|||
correct: bool = False
|
||||
status: str = "pending"
|
||||
|
||||
|
||||
class EvalState:
|
||||
def __init__(
|
||||
self,
|
||||
dataset_type: str,
|
||||
sampling_config: Dict[str, Any],
|
||||
output_file: Path = Path("llama-eval-state.json")
|
||||
):
|
||||
self.dataset_type = dataset_type
|
||||
self.sampling_config = sampling_config
|
||||
self.output_file = output_file
|
||||
self.dataset: Optional[BaseDataset] = None
|
||||
self.tasks: List[Tuple[int, str]] = []
|
||||
self.all_tasks: List[Tuple[int, str]] = []
|
||||
self.task_states: Dict[str, Any] = {}
|
||||
self.total = 0
|
||||
self.correct = 0
|
||||
self.processed = 0
|
||||
|
||||
def load_dataset(self, seed: int = 1234):
|
||||
if self.dataset_type == "aime":
|
||||
self.dataset = AimeDataset()
|
||||
elif self.dataset_type == "aime2025":
|
||||
self.dataset = Aime2025Dataset()
|
||||
elif self.dataset_type == "gsm8k":
|
||||
self.dataset = Gsm8kDataset()
|
||||
elif self.dataset_type == "gpqa":
|
||||
self.dataset = GpqaDataset(variant="diamond", seed=seed)
|
||||
else:
|
||||
raise ValueError(f"Unknown dataset type: {self.dataset_type}")
|
||||
|
||||
def setup_tasks(self, n_cases: Optional[int] = None, seed: int = 1234):
|
||||
if self.dataset is None:
|
||||
raise ValueError("Dataset not loaded. Call load_dataset() first.")
|
||||
|
||||
if n_cases is None:
|
||||
n_cases = len(self.dataset)
|
||||
|
||||
dataset_size = len(self.dataset)
|
||||
rng = random.Random(seed)
|
||||
|
||||
self.tasks = []
|
||||
for chunk_idx in range((n_cases + dataset_size - 1) // dataset_size):
|
||||
chunk_size = min(dataset_size, n_cases - chunk_idx * dataset_size)
|
||||
indices = list(range(dataset_size))
|
||||
rng.shuffle(indices)
|
||||
chunk_indices = indices[:chunk_size]
|
||||
|
||||
for i in chunk_indices:
|
||||
task_id = f"{self.dataset_type}_{chunk_idx:03d}_{i:03d}"
|
||||
self.tasks.append((i, task_id))
|
||||
|
||||
self.all_tasks = list(self.tasks)
|
||||
|
||||
def get_case(self, index: int) -> Tuple[str, str]:
|
||||
if self.dataset is None:
|
||||
raise ValueError("Dataset not loaded.")
|
||||
question = self.dataset.get_question(index)
|
||||
prompt = self.dataset.get_prompt(question)
|
||||
gold = self.dataset.get_answer(question)
|
||||
return prompt, gold
|
||||
|
||||
def add_result(
|
||||
self,
|
||||
task_id: str,
|
||||
prompt: str,
|
||||
gold: str,
|
||||
pred: Optional[str],
|
||||
extracted: Optional[str],
|
||||
correct: bool,
|
||||
status: str
|
||||
):
|
||||
if self.dataset_type not in self.task_states:
|
||||
self.task_states[self.dataset_type] = {}
|
||||
if "cases" not in self.task_states[self.dataset_type]:
|
||||
self.task_states[self.dataset_type]["cases"] = {}
|
||||
|
||||
self.task_states[self.dataset_type]["cases"][task_id] = {
|
||||
"case_id": task_id,
|
||||
"prompt": prompt,
|
||||
"gold": gold,
|
||||
"pred": pred,
|
||||
"extracted": extracted,
|
||||
"correct": correct,
|
||||
"status": status
|
||||
}
|
||||
|
||||
if correct:
|
||||
self.correct += 1
|
||||
else:
|
||||
self.correct = sum(1 for c in self.task_states.get(self.dataset_type, {}).get("cases", {}).values() if c.get("correct", False))
|
||||
|
||||
def add_grader_log(self, grader_log: Dict[str, Any]):
|
||||
if self.dataset_type not in self.task_states:
|
||||
self.task_states[self.dataset_type] = {}
|
||||
if "grader_log" not in self.task_states[self.dataset_type]:
|
||||
self.task_states[self.dataset_type]["grader_log"] = []
|
||||
self.task_states[self.dataset_type]["grader_log"].append(grader_log)
|
||||
|
||||
def print_task_header(self):
|
||||
tasks_to_show = self.all_tasks if self.all_tasks else self.tasks
|
||||
cases = self.task_states.get(self.dataset_type, {}).get("cases", {})
|
||||
print("Tasks:")
|
||||
print(" Task ID Dataset Prompt (first 40 chars) Expected Extracted Status")
|
||||
for i, task_id in tasks_to_show:
|
||||
prompt, gold = self.get_case(i)
|
||||
case = cases.get(task_id, {})
|
||||
status = case.get("status", "pending")
|
||||
extracted = case.get("extracted", "N/A") if status == "ok" else "N/A"
|
||||
is_correct = case.get("correct", False) if status == "ok" else False
|
||||
symbol = "✓ " if is_correct else ("✗ " if status == "ok" else "")
|
||||
first_line = prompt.split('\n')[0]
|
||||
truncated_prompt = first_line[:43]
|
||||
if len(first_line) > 43:
|
||||
truncated_prompt += "..."
|
||||
else:
|
||||
truncated_prompt = truncated_prompt.ljust(43) + "..."
|
||||
print(f" {task_id:<20} {self.dataset_type.upper()} {truncated_prompt:<40} {gold:<10} {extracted:<10} {symbol}{status}")
|
||||
print()
|
||||
|
||||
def print_progress(self, task_state: TaskState, total_tasks: int, correct_count: int = 0):
|
||||
extracted_display = task_state.extracted if task_state.extracted else "N/A"
|
||||
success_ratio = correct_count / self.processed if self.processed > 0 else 0.0
|
||||
first_line = task_state.prompt.split('\n')[0]
|
||||
truncated_prompt = first_line[:43]
|
||||
if len(first_line) > 43:
|
||||
truncated_prompt += "..."
|
||||
else:
|
||||
truncated_prompt = truncated_prompt.ljust(43) + "..."
|
||||
print(f"{self.processed:3}/{total_tasks:3} {task_state.case_id:<20} {self.dataset_type.upper()} {truncated_prompt:<40} {task_state.gold:<10} {extracted_display:<10} {'✓' if task_state.correct else '✗'} [{correct_count:3}/{self.processed:3}, {success_ratio:.3f}]")
|
||||
|
||||
def print_summary(self):
|
||||
if self.total == 0:
|
||||
print(f"\n{'='*60}")
|
||||
print(f"Results: 0/0 correct (0.0%)")
|
||||
print(f"{'='*60}")
|
||||
else:
|
||||
print(f"\n{'='*60}")
|
||||
print(f"Results: {self.correct}/{self.total} correct ({self.correct/self.total*100:.1f}%)")
|
||||
print(f"{'='*60}")
|
||||
|
||||
def dump(self):
|
||||
tasks_to_save = self.all_tasks if self.all_tasks else self.tasks
|
||||
all_cases = {}
|
||||
for i, task_id in tasks_to_save:
|
||||
prompt, gold = self.get_case(i)
|
||||
if task_id in self.task_states.get(self.dataset_type, {}).get("cases", {}):
|
||||
all_cases[task_id] = self.task_states[self.dataset_type]["cases"][task_id]
|
||||
else:
|
||||
all_cases[task_id] = {
|
||||
"case_id": task_id,
|
||||
"prompt": prompt,
|
||||
"gold": gold,
|
||||
"pred": None,
|
||||
"extracted": None,
|
||||
"correct": False,
|
||||
"status": "pending"
|
||||
}
|
||||
|
||||
data = {
|
||||
"id": self.dataset_type,
|
||||
"tasks": [tid for _, tid in tasks_to_save],
|
||||
"task_states": {
|
||||
self.dataset_type: {
|
||||
"total": self.total,
|
||||
"correct": self.correct,
|
||||
"cases": all_cases,
|
||||
"grader_log": self.task_states.get("grader_log", [])
|
||||
}
|
||||
},
|
||||
"sampling_config": self.sampling_config
|
||||
}
|
||||
with open(self.output_file, "w") as f:
|
||||
json.dump(data, f, indent=2)
|
||||
|
||||
@classmethod
|
||||
def load(cls, path: Path) -> "EvalState":
|
||||
with open(path, "r") as f:
|
||||
data = json.load(f)
|
||||
|
||||
eval_state = cls(
|
||||
dataset_type=data["id"],
|
||||
sampling_config=data["sampling_config"],
|
||||
output_file=path
|
||||
)
|
||||
eval_state.load_dataset()
|
||||
|
||||
eval_state.tasks = []
|
||||
eval_state.all_tasks = []
|
||||
for task_id in data.get("tasks", []):
|
||||
parts = task_id.rsplit("_", 2)
|
||||
if len(parts) >= 3:
|
||||
idx = int(parts[-1])
|
||||
else:
|
||||
idx = 0
|
||||
eval_state.tasks.append((idx, task_id))
|
||||
eval_state.all_tasks.append((idx, task_id))
|
||||
|
||||
eval_state.task_states = data.get("task_states", {})
|
||||
|
||||
cases = eval_state.task_states.get(eval_state.dataset_type, {}).get("cases", {})
|
||||
eval_state.total = eval_state.task_states.get(eval_state.dataset_type, {}).get("total", 0)
|
||||
eval_state.correct = eval_state.task_states.get(eval_state.dataset_type, {}).get("correct", 0)
|
||||
|
||||
if eval_state.total == 0:
|
||||
eval_state.total = len(cases)
|
||||
eval_state.correct = sum(1 for c in cases.values() if c.get("correct", False))
|
||||
|
||||
return eval_state
|
||||
|
||||
def is_complete(self) -> bool:
|
||||
if not self.all_tasks:
|
||||
return False
|
||||
cases = self.task_states.get(self.dataset_type, {}).get("cases", {})
|
||||
completed = {tid for tid in self.task_states.get(self.dataset_type, {}).get("cases", {}).keys() if cases.get(tid, {}).get("status") == "ok"}
|
||||
return len(completed) == len(self.all_tasks)
|
||||
|
||||
def get_pending_tasks(self) -> List[Tuple[int, str]]:
|
||||
cases = self.task_states.get(self.dataset_type, {}).get("cases", {})
|
||||
pending = []
|
||||
for i, task_id in self.all_tasks:
|
||||
if cases.get(task_id, {}).get("status") != "ok":
|
||||
pending.append((i, task_id))
|
||||
return pending
|
||||
|
||||
def print_all_tasks(self):
|
||||
cases = self.task_states.get(self.dataset_type, {}).get("cases", {})
|
||||
tasks_to_show = self.all_tasks if self.all_tasks else self.tasks
|
||||
print("Tasks:")
|
||||
print(" Task ID Dataset Prompt (first 40 chars) Expected Extracted Status")
|
||||
for i, task_id in tasks_to_show:
|
||||
prompt, gold = self.get_case(i)
|
||||
case = cases.get(task_id, {})
|
||||
status = case.get("status", "pending")
|
||||
extracted = case.get("extracted", "N/A") if status == "ok" else "N/A"
|
||||
is_correct = case.get("correct", False) if status == "ok" else False
|
||||
symbol = "✓ " if is_correct else ("✗ " if status == "ok" else "")
|
||||
first_line = prompt.split('\n')[0]
|
||||
truncated_prompt = first_line[:43]
|
||||
if len(first_line) > 43:
|
||||
truncated_prompt += "..."
|
||||
else:
|
||||
truncated_prompt = truncated_prompt.ljust(43) + "..."
|
||||
print(f" {task_id:<20} {self.dataset_type.upper()} {truncated_prompt:<40} {gold:<10} {extracted:<10} {symbol}{status}")
|
||||
print()
|
||||
|
||||
def print_existing_summary(self):
|
||||
cases = self.task_states.get(self.dataset_type, {}).get("cases", {})
|
||||
correct = sum(1 for c in cases.values() if c.get("correct", False))
|
||||
total = len(cases)
|
||||
print(f"\n{'='*60}")
|
||||
print(f"Results: {correct}/{total} correct ({correct/total*100:.1f}%)")
|
||||
print(f"{'='*60}")
|
||||
|
||||
def normalize_number(s: str) -> Optional[int]:
|
||||
match = re.match(r"\d+", s) # match digits from the start
|
||||
if not match:
|
||||
return None
|
||||
return int(match.group(0))
|
||||
|
||||
class AimeDataset:
|
||||
class AimeDataset(BaseDataset):
|
||||
def __init__(self, split: str = "train"):
|
||||
self.split = split
|
||||
self.questions: List[Dict] = []
|
||||
|
|
@ -139,7 +405,7 @@ class AimeDataset:
|
|||
question=question["problem"] if "problem" in question else question["question"]
|
||||
)
|
||||
|
||||
class Aime2025Dataset:
|
||||
class Aime2025Dataset(BaseDataset):
|
||||
def __init__(self):
|
||||
self.questions: List[Dict] = []
|
||||
self._load_dataset()
|
||||
|
|
@ -197,7 +463,7 @@ class Aime2025Dataset:
|
|||
question=question["question"]
|
||||
)
|
||||
|
||||
class Gsm8kDataset:
|
||||
class Gsm8kDataset(BaseDataset):
|
||||
def __init__(self, split: str = "train"):
|
||||
self.split = split
|
||||
self.questions: List[Dict] = []
|
||||
|
|
@ -253,7 +519,7 @@ class Gsm8kDataset:
|
|||
question=question["problem"] if "problem" in question else question["question"]
|
||||
)
|
||||
|
||||
class GpqaDataset:
|
||||
class GpqaDataset(BaseDataset):
|
||||
def __init__(self, variant: str = "diamond", seed: int = 1234):
|
||||
self.variant = variant
|
||||
self.seed = seed
|
||||
|
|
@ -461,84 +727,38 @@ class Processor:
|
|||
def __init__(
|
||||
self,
|
||||
server_url: str,
|
||||
n_predict: int = -1,
|
||||
threads: int = 32,
|
||||
verbose: bool = False,
|
||||
grader: Optional[Grader] = None,
|
||||
grader: Grader,
|
||||
model_name: Optional[str] = None,
|
||||
judge_server_url: str = "",
|
||||
judge_model_name: Optional[str] = None,
|
||||
dataset_type: str = "aime",
|
||||
seed: int = 1234,
|
||||
sampling_config: Optional[Dict[str, Any]] = None,
|
||||
output_file: Optional[Path] = None
|
||||
threads: int = 32
|
||||
):
|
||||
self.server_url = server_url
|
||||
self.n_predict = n_predict
|
||||
self.threads = threads
|
||||
self.verbose = verbose
|
||||
self.grader = grader
|
||||
self.model_name = model_name
|
||||
self.judge_server_url = judge_server_url if judge_server_url else server_url
|
||||
self.judge_model_name = judge_model_name
|
||||
self.dataset_type = dataset_type
|
||||
self.seed = seed
|
||||
self.grader = grader or Grader()
|
||||
self.sampling_config = sampling_config or {"n_predict": n_predict}
|
||||
self.output_file = output_file or Path("llama-eval-state.json")
|
||||
self.eval_state = EvalState(
|
||||
id=dataset_type,
|
||||
tasks=[dataset_type],
|
||||
task_states={dataset_type: {}},
|
||||
sampling_config=self.sampling_config
|
||||
)
|
||||
self.threads = threads
|
||||
|
||||
# Pass judge configuration to grader if using LLM grader
|
||||
if self.grader.grader_type == "llm":
|
||||
if self.judge_model_name:
|
||||
self.grader.judge_model_name = self.judge_model_name
|
||||
if self.judge_server_url:
|
||||
self.grader.judge_server_url = self.judge_server_url
|
||||
|
||||
# Initialize appropriate dataset
|
||||
if dataset_type == "aime":
|
||||
self.dataset = AimeDataset()
|
||||
elif dataset_type == "aime2025":
|
||||
self.dataset = Aime2025Dataset()
|
||||
elif dataset_type == "gsm8k":
|
||||
self.dataset = Gsm8kDataset()
|
||||
elif dataset_type == "gpqa":
|
||||
self.dataset = GpqaDataset(variant="diamond", seed=self.seed)
|
||||
else:
|
||||
raise ValueError(f"Unknown dataset type: {dataset_type}")
|
||||
|
||||
def _make_request(self, prompt: str) -> Dict[str, Any]:
|
||||
"""Make HTTP request to the server"""
|
||||
def _make_request(self, eval_state: EvalState, prompt: str) -> Dict[str, Any]:
|
||||
url = f"{self.server_url}/v1/chat/completions"
|
||||
headers = {"Content-Type": "application/json"}
|
||||
data = {
|
||||
"model": self.model_name if self.model_name else "llama",
|
||||
"messages": [{"role": "user", "content": prompt}],
|
||||
"n_predict": self.n_predict
|
||||
"n_predict": eval_state.sampling_config.get("n_predict", -1)
|
||||
}
|
||||
if self.sampling_config.get("temperature") is not None:
|
||||
data["temperature"] = self.sampling_config["temperature"]
|
||||
if self.sampling_config.get("top_k") is not None:
|
||||
data["top_k"] = self.sampling_config["top_k"]
|
||||
if self.sampling_config.get("top_p") is not None:
|
||||
data["top_p"] = self.sampling_config["top_p"]
|
||||
if self.sampling_config.get("min_p") is not None:
|
||||
data["min_p"] = self.sampling_config["min_p"]
|
||||
if eval_state.sampling_config.get("temperature") is not None:
|
||||
data["temperature"] = eval_state.sampling_config["temperature"]
|
||||
if eval_state.sampling_config.get("top_k") is not None:
|
||||
data["top_k"] = eval_state.sampling_config["top_k"]
|
||||
if eval_state.sampling_config.get("top_p") is not None:
|
||||
data["top_p"] = eval_state.sampling_config["top_p"]
|
||||
if eval_state.sampling_config.get("min_p") is not None:
|
||||
data["min_p"] = eval_state.sampling_config["min_p"]
|
||||
|
||||
response = requests.post(url, headers=headers, json=data)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
def _process_single_case(self, i: int, task_id: str) -> TaskState:
|
||||
"""Process a single case (thread-safe)"""
|
||||
question = self.dataset.get_question(i)
|
||||
dataset_id = f"{self.dataset_type}_{i}"
|
||||
gold = self.dataset.get_answer(question)
|
||||
prompt = self.dataset.get_prompt(question)
|
||||
def _process_single_case(self, eval_state: EvalState, i: int, task_id: str) -> TaskState:
|
||||
prompt, gold = eval_state.get_case(i)
|
||||
|
||||
task_state = TaskState(
|
||||
case_id=task_id,
|
||||
|
|
@ -547,20 +767,16 @@ class Processor:
|
|||
)
|
||||
|
||||
try:
|
||||
response = self._make_request(prompt)
|
||||
response = self._make_request(eval_state, prompt)
|
||||
pred = response["choices"][0]["message"]["content"]
|
||||
task_state.pred = pred
|
||||
|
||||
# Truncate response to last 2-3 lines for grading
|
||||
pred_truncated = self.grader._truncate_response(pred, max_lines=10)
|
||||
|
||||
# Grade the response
|
||||
is_correct, extracted = self.grader.grade(gold, pred_truncated, prompt)
|
||||
task_state.correct = is_correct
|
||||
task_state.extracted = extracted
|
||||
task_state.status = "ok"
|
||||
|
||||
# Log grader request details for debugging
|
||||
grader_log = {
|
||||
"case_id": task_id,
|
||||
"gold": gold,
|
||||
|
|
@ -571,111 +787,49 @@ class Processor:
|
|||
}
|
||||
if self.grader.grader_type == "regex" and self.grader.pattern:
|
||||
grader_log["pattern"] = self.grader.pattern
|
||||
if "grader_log" not in self.eval_state.task_states[self.dataset_type]:
|
||||
self.eval_state.task_states[self.dataset_type]["grader_log"] = []
|
||||
self.eval_state.task_states[self.dataset_type]["grader_log"].append(grader_log)
|
||||
eval_state.add_grader_log(grader_log)
|
||||
|
||||
# Initialize cases dict if it doesn't exist
|
||||
if "cases" not in self.eval_state.task_states[self.dataset_type]:
|
||||
self.eval_state.task_states[self.dataset_type]["cases"] = {}
|
||||
eval_state.add_result(task_id, prompt, gold, pred, extracted, is_correct, "ok")
|
||||
|
||||
# Update eval state with grading details
|
||||
self.eval_state.task_states[self.dataset_type]["cases"][task_id] = {
|
||||
"case_id": task_id,
|
||||
"prompt": prompt,
|
||||
"gold": gold,
|
||||
"pred": pred,
|
||||
"extracted": extracted,
|
||||
"correct": is_correct,
|
||||
"status": "ok"
|
||||
}
|
||||
eval_state.dump()
|
||||
|
||||
# Save eval state to disk after each task
|
||||
try:
|
||||
self.dump_state(self.output_file)
|
||||
except Exception as dump_error:
|
||||
task_state.status = f"error: {str(e)}; dump error: {str(dump_error)}"
|
||||
except Exception as processing_error:
|
||||
task_state.status = f"error: {str(processing_error)}"
|
||||
except Exception as e:
|
||||
task_state.status = f"error: {str(e)}"
|
||||
|
||||
return task_state
|
||||
|
||||
def process(self, n_cases: int = None, seed: int = 1234):
|
||||
"""Process cases and update eval state"""
|
||||
if n_cases is None:
|
||||
n_cases = len(self.dataset.questions)
|
||||
def evaluate(self, eval_state: EvalState, verbose: bool = False, resume: bool = False):
|
||||
total_tasks = len(eval_state.tasks)
|
||||
eval_state.total = len(eval_state.all_tasks) if eval_state.all_tasks else total_tasks
|
||||
eval_state.processed = 0
|
||||
|
||||
print(f"\nProcessing {n_cases} {self.dataset_type.upper()} questions...")
|
||||
print(f"\nProcessing {len(eval_state.tasks)} {eval_state.dataset_type.upper()} questions...")
|
||||
print(f"Server: {self.server_url} (model: {self.model_name})")
|
||||
print(f"Grader: {self.grader.grader_type}", end="")
|
||||
if self.grader.grader_type == "llm":
|
||||
judge_model = self.judge_model_name if self.judge_model_name else self.model_name
|
||||
print(f" (judge server: {self.judge_server_url}, model: {judge_model})", end="")
|
||||
print()
|
||||
print(f"Grader: {self.grader.grader_type}")
|
||||
print(f"Threads: {self.threads}")
|
||||
print(f"Max tokens: {self.n_predict}")
|
||||
print(f"Seed: {self.seed}")
|
||||
print(f"Sampling: temp={self.sampling_config.get('temperature', 'skip')}, top-k={self.sampling_config.get('top_k', 'skip')}, top-p={self.sampling_config.get('top_p', 'skip')}, min-p={self.sampling_config.get('min_p', 'skip')}")
|
||||
print(f"Sampling: temp={eval_state.sampling_config.get('temperature', 'skip')}, top-k={eval_state.sampling_config.get('top_k', 'skip')}, top-p={eval_state.sampling_config.get('top_p', 'skip')}, min-p={eval_state.sampling_config.get('min_p', 'skip')}")
|
||||
print()
|
||||
|
||||
dataset_size = len(self.dataset.questions)
|
||||
random.seed(seed)
|
||||
if not resume:
|
||||
eval_state.print_task_header()
|
||||
|
||||
task_list = []
|
||||
for chunk_idx in range((n_cases + dataset_size - 1) // dataset_size):
|
||||
chunk_size = min(dataset_size, n_cases - chunk_idx * dataset_size)
|
||||
indices = list(range(dataset_size))
|
||||
random.shuffle(indices)
|
||||
chunk_indices = indices[:chunk_size]
|
||||
|
||||
for i in chunk_indices:
|
||||
task_id = f"{self.dataset_type}_{chunk_idx:03d}_{i:03d}"
|
||||
task_list.append((i, task_id))
|
||||
|
||||
# Print task summary table
|
||||
print("Tasks:")
|
||||
print(" Task ID Dataset Prompt (first 40 chars) Expected Status")
|
||||
for i, task_id in task_list:
|
||||
question = self.dataset.get_question(i)
|
||||
prompt = self.dataset.get_prompt(question)
|
||||
gold = self.dataset.get_answer(question)
|
||||
first_line = prompt.split('\n')[0]
|
||||
truncated_prompt = first_line[:43]
|
||||
if len(first_line) > 43:
|
||||
truncated_prompt += "..."
|
||||
else:
|
||||
truncated_prompt = truncated_prompt.ljust(43) + "..."
|
||||
print(f" {task_id:<20} {self.dataset_type.upper()} {truncated_prompt:<40} {gold:<10} pending")
|
||||
print()
|
||||
|
||||
task_states: Dict[str, List[TaskState]] = {task: [] for task in self.eval_state.tasks}
|
||||
total = 0
|
||||
correct = 0
|
||||
correct_count = 0
|
||||
|
||||
with ThreadPoolExecutor(max_workers=self.threads) as executor:
|
||||
futures = {executor.submit(self._process_single_case, i, task_id): (i, task_id) for i, task_id in task_list}
|
||||
futures = {
|
||||
executor.submit(self._process_single_case, eval_state, i, task_id): (i, task_id)
|
||||
for i, task_id in eval_state.tasks
|
||||
}
|
||||
|
||||
for future in as_completed(futures):
|
||||
task_state = future.result()
|
||||
task_states[self.dataset_type].append(task_state)
|
||||
total += 1
|
||||
|
||||
eval_state.processed += 1
|
||||
if task_state.correct:
|
||||
correct += 1
|
||||
correct_count += 1
|
||||
eval_state.print_progress(task_state, total_tasks, correct_count)
|
||||
|
||||
# Print task completion status
|
||||
extracted_display = task_state.extracted if task_state.extracted else "N/A"
|
||||
success_ratio = correct / total if total > 0 else 0.0
|
||||
first_line = task_state.prompt.split('\n')[0]
|
||||
truncated_prompt = first_line[:43]
|
||||
if len(first_line) > 43:
|
||||
truncated_prompt += "..."
|
||||
else:
|
||||
truncated_prompt = truncated_prompt.ljust(43) + "..."
|
||||
print(f"{total:3}/{n_cases:3} {task_state.case_id:<20} {self.dataset_type.upper()} {truncated_prompt:<40} {task_state.gold:<10} {extracted_display:<10} {'✓' if task_state.correct else '✗'} [{correct:3}/{total:3}, {success_ratio:.3f}]")
|
||||
|
||||
if self.verbose:
|
||||
print(f"\nCase {total}: {task_state.correct}")
|
||||
if verbose:
|
||||
print(f"\nCase {eval_state.processed}: {task_state.correct}")
|
||||
print(f" Gold: {task_state.gold}")
|
||||
if task_state.pred:
|
||||
print(f" Pred: {task_state.pred}")
|
||||
|
|
@ -683,25 +837,9 @@ class Processor:
|
|||
print(f" Extracted: {task_state.extracted}")
|
||||
print(f" Status: {task_state.status}")
|
||||
|
||||
# Merge existing state with new state to preserve grader_log
|
||||
existing_state = self.eval_state.task_states.get(self.dataset_type, {})
|
||||
self.eval_state.task_states[self.dataset_type] = {
|
||||
"total": total,
|
||||
"correct": correct,
|
||||
"cases": task_states,
|
||||
**existing_state
|
||||
}
|
||||
|
||||
print(f"\n{'='*60}")
|
||||
print(f"Results: {correct}/{total} correct ({correct/total*100:.1f}%)")
|
||||
print(f"{'='*60}")
|
||||
|
||||
return self.eval_state
|
||||
|
||||
def dump_state(self, output_file: Path):
|
||||
"""Dump eval state to JSON file"""
|
||||
with open(output_file, "w") as f:
|
||||
json.dump(asdict(self.eval_state), f, indent=2)
|
||||
eval_state.correct = correct_count
|
||||
eval_state.print_summary()
|
||||
eval_state.dump()
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
|
|
@ -810,51 +948,101 @@ def main():
|
|||
default="",
|
||||
help="Model name for LLM judge (default: same as main model)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--resume",
|
||||
action="store_true",
|
||||
help="Resume from existing eval state"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Validate grader type for GPQA
|
||||
if args.dataset == "gpqa" and args.grader_type != "llm":
|
||||
print("Error: GPQA dataset requires --grader-type llm")
|
||||
parser.print_help()
|
||||
sys.exit(1)
|
||||
|
||||
grader = Grader(
|
||||
grader_type=args.grader_type,
|
||||
grader_script=args.grader_script,
|
||||
judge_model_name=args.judge_model if args.judge_model else args.model,
|
||||
dataset_type=args.dataset
|
||||
)
|
||||
if args.output.exists():
|
||||
print(f"Loading existing eval state from {args.output}")
|
||||
eval_state = EvalState.load(args.output)
|
||||
|
||||
if args.grader_type == "llm" and not args.judge_server:
|
||||
print("Warning: Using same server for LLM judge (no --judge-server specified)")
|
||||
if eval_state.is_complete():
|
||||
eval_state.print_all_tasks()
|
||||
eval_state.print_existing_summary()
|
||||
return
|
||||
|
||||
sampling_config = {"n_predict": args.n_predict}
|
||||
if args.temperature is not None:
|
||||
sampling_config["temperature"] = args.temperature
|
||||
if args.top_k is not None:
|
||||
sampling_config["top_k"] = args.top_k
|
||||
if args.top_p is not None:
|
||||
sampling_config["top_p"] = args.top_p
|
||||
if args.min_p is not None:
|
||||
sampling_config["min_p"] = args.min_p
|
||||
eval_state.print_all_tasks()
|
||||
eval_state.print_existing_summary()
|
||||
|
||||
if not args.resume:
|
||||
print(f"Evaluation incomplete. Run with --resume to continue.")
|
||||
return
|
||||
|
||||
pending_tasks = eval_state.get_pending_tasks()
|
||||
print(f"Resuming from {len(pending_tasks)} pending tasks")
|
||||
|
||||
existing_cases = eval_state.task_states.get(eval_state.dataset_type, {}).get("cases", {})
|
||||
|
||||
eval_state.tasks = pending_tasks
|
||||
eval_state.task_states.get(eval_state.dataset_type, {})["cases"] = existing_cases
|
||||
eval_state.task_states.get(eval_state.dataset_type, {})["grader_log"] = []
|
||||
|
||||
judge_server_url = args.judge_server if args.judge_server else args.server
|
||||
judge_model_name = args.judge_model if args.judge_model else args.model
|
||||
grader = Grader(
|
||||
grader_type=args.grader_type,
|
||||
grader_script=args.grader_script,
|
||||
judge_model_name=judge_model_name,
|
||||
judge_server_url=judge_server_url,
|
||||
dataset_type=eval_state.dataset_type
|
||||
)
|
||||
resume = True
|
||||
else:
|
||||
if args.resume:
|
||||
print("Error: No existing eval state found to resume")
|
||||
sys.exit(1)
|
||||
|
||||
judge_server_url = args.judge_server if args.judge_server else args.server
|
||||
judge_model_name = args.judge_model if args.judge_model else args.model
|
||||
|
||||
grader = Grader(
|
||||
grader_type=args.grader_type,
|
||||
grader_script=args.grader_script,
|
||||
judge_model_name=judge_model_name,
|
||||
judge_server_url=judge_server_url,
|
||||
dataset_type=args.dataset
|
||||
)
|
||||
|
||||
if args.grader_type == "llm" and not args.judge_server:
|
||||
print("Warning: Using same server for LLM judge (no --judge-server specified)")
|
||||
|
||||
sampling_config = {"n_predict": args.n_predict}
|
||||
if args.temperature is not None:
|
||||
sampling_config["temperature"] = args.temperature
|
||||
if args.top_k is not None:
|
||||
sampling_config["top_k"] = args.top_k
|
||||
if args.top_p is not None:
|
||||
sampling_config["top_p"] = args.top_p
|
||||
if args.min_p is not None:
|
||||
sampling_config["min_p"] = args.min_p
|
||||
|
||||
eval_state = EvalState(
|
||||
dataset_type=args.dataset,
|
||||
sampling_config=sampling_config,
|
||||
output_file=args.output
|
||||
)
|
||||
eval_state.load_dataset(seed=args.seed)
|
||||
eval_state.setup_tasks(n_cases=args.n_cases, seed=args.seed)
|
||||
eval_state.dump()
|
||||
resume = False
|
||||
|
||||
processor = Processor(
|
||||
server_url=args.server,
|
||||
n_predict=args.n_predict,
|
||||
threads=args.threads,
|
||||
verbose=args.verbose,
|
||||
grader=grader,
|
||||
model_name=args.model,
|
||||
judge_server_url=args.judge_server,
|
||||
judge_model_name=args.judge_model,
|
||||
dataset_type=args.dataset,
|
||||
sampling_config=sampling_config,
|
||||
output_file=args.output
|
||||
threads=args.threads
|
||||
)
|
||||
|
||||
eval_state = processor.process(n_cases=args.n_cases, seed=args.seed)
|
||||
processor.dump_state(args.output)
|
||||
processor.evaluate(eval_state, verbose=args.verbose, resume=resume)
|
||||
print(f"\nEval state dumped to {args.output}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
Loading…
Reference in New Issue