This commit is contained in:
Georgi Gerganov 2026-02-16 19:47:06 +02:00
parent 7b84af8051
commit 6c41664b8b
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
1 changed file with 32 additions and 43 deletions

View File

@ -59,7 +59,7 @@ Please reason step by step, and put your final answer within \\boxed{{}}.
Please reason step by step, and put your final answer within \\boxed{{}}.
""",
"gsm8k": """{question}
Please reason step by step, and provide your final answer.
Please reason step by step, and put your final numeric answer within \\boxed{{}} without any extra characters.
""",
"gpqa": """{Question}
@ -97,6 +97,7 @@ class TaskState:
gold: str
pred: Optional[str] = None
extracted: Optional[str] = None
grader_log: Dict[str, Any] = field(default_factory=dict)
correct: bool = False
status: str = "pending"
@ -169,20 +170,20 @@ class EvalState:
gold: str,
pred: Optional[str],
extracted: Optional[str],
grader_log: Dict[str, Any],
correct: bool,
status: str
):
if self.dataset_type not in self.task_states:
self.task_states[self.dataset_type] = {}
if "cases" not in self.task_states[self.dataset_type]:
self.task_states[self.dataset_type]["cases"] = {}
if "cases" not in self.task_states:
self.task_states["cases"] = {}
self.task_states[self.dataset_type]["cases"][task_id] = {
self.task_states["cases"][task_id] = {
"case_id": task_id,
"prompt": prompt,
"gold": gold,
"pred": pred,
"extracted": extracted,
"grader_log": grader_log,
"correct": correct,
"status": status
}
@ -190,14 +191,7 @@ class EvalState:
if correct:
self.correct += 1
else:
self.correct = sum(1 for c in self.task_states.get(self.dataset_type, {}).get("cases", {}).values() if c.get("correct", False))
def add_grader_log(self, grader_log: Dict[str, Any]):
    """Append one grader log record for the current dataset.

    Records accumulate, in call order, in the list stored at
    ``self.task_states[self.dataset_type]["grader_log"]``. Both the
    per-dataset dict and the list are created on first use; any other keys
    already stored for the dataset are left untouched.

    Args:
        grader_log: The log record to store. It is appended as-is (not copied).
    """
    # setdefault returns the existing container or installs a fresh one, which
    # replaces the explicit "if key not in ..." checks for both levels.
    dataset_state = self.task_states.setdefault(self.dataset_type, {})
    dataset_state.setdefault("grader_log", []).append(grader_log)
self.correct = sum(1 for c in self.task_states.get("cases", {}).values() if c.get("correct", False))
def print_progress(self, task_state: TaskState, total_tasks: int, correct_count: int = 0):
extracted_display = task_state.extracted if task_state.extracted else "N/A"
@ -225,8 +219,8 @@ class EvalState:
all_cases = {}
for i, task_id in tasks_to_save:
prompt, gold = self.get_case(i)
if task_id in self.task_states.get(self.dataset_type, {}).get("cases", {}):
all_cases[task_id] = self.task_states[self.dataset_type]["cases"][task_id]
if task_id in self.task_states.get("cases", {}):
all_cases[task_id] = self.task_states["cases"][task_id]
else:
all_cases[task_id] = {
"case_id": task_id,
@ -234,6 +228,7 @@ class EvalState:
"gold": gold,
"pred": None,
"extracted": None,
"grader_log": {},
"correct": False,
"status": "pending"
}
@ -242,12 +237,9 @@ class EvalState:
"id": self.dataset_type,
"tasks": [tid for _, tid in tasks_to_save],
"task_states": {
self.dataset_type: {
"total": self.total,
"correct": self.correct,
"cases": all_cases,
"grader_log": self.task_states.get("grader_log", [])
}
"total": self.total,
"correct": self.correct,
"cases": all_cases,
},
"sampling_config": self.sampling_config
}
@ -279,9 +271,9 @@ class EvalState:
eval_state.task_states = data.get("task_states", {})
cases = eval_state.task_states.get(eval_state.dataset_type, {}).get("cases", {})
eval_state.total = eval_state.task_states.get(eval_state.dataset_type, {}).get("total", 0)
eval_state.correct = eval_state.task_states.get(eval_state.dataset_type, {}).get("correct", 0)
cases = eval_state.task_states.get("cases", {})
eval_state.total = eval_state.task_states.get("total", 0)
eval_state.correct = eval_state.task_states.get("correct", 0)
if eval_state.total == 0:
eval_state.total = len(cases)
@ -292,12 +284,12 @@ class EvalState:
def is_complete(self) -> bool:
    """Return True when every task in ``self.all_tasks`` has finished.

    A task counts as finished when its entry in
    ``self.task_states["cases"]`` has ``status == "ok"``. With no tasks
    registered in ``self.all_tasks`` the evaluation is never reported as
    complete.

    Returns:
        True if ``all_tasks`` is non-empty and each of its task ids maps to a
        case whose status is "ok"; False otherwise.
    """
    if not self.all_tasks:
        return False

    cases = self.task_states.get("cases", {})
    # Check the scheduled tasks themselves instead of comparing the number of
    # "ok" cases with len(all_tasks): a finished case that is not part of
    # all_tasks must not be able to stand in for one that is still pending.
    return all(
        cases.get(task_id, {}).get("status") == "ok"
        for _, task_id in self.all_tasks
    )
def get_pending_tasks(self) -> List[Tuple[int, str]]:
cases = self.task_states.get(self.dataset_type, {}).get("cases", {})
cases = self.task_states.get("cases", {})
pending = []
for i, task_id in self.all_tasks:
if cases.get(task_id, {}).get("status") != "ok":
@ -305,7 +297,7 @@ class EvalState:
return pending
def print_all_tasks(self):
cases = self.task_states.get(self.dataset_type, {}).get("cases", {})
cases = self.task_states.get("cases", {})
tasks_to_show = self.all_tasks if self.all_tasks else self.tasks
print()
print("Tasks:")
@ -327,7 +319,7 @@ class EvalState:
print()
def print_existing_summary(self):
cases = self.task_states.get(self.dataset_type, {}).get("cases", {})
cases = self.task_states.get("cases", {})
completed_cases = {tid: c for tid, c in cases.items() if c.get("status") == "ok"}
correct = sum(1 for c in completed_cases.values() if c.get("correct", False))
total = len(completed_cases)
@ -450,7 +442,7 @@ class Aime2025Dataset(BaseDataset):
)
class Gsm8kDataset(BaseDataset):
def __init__(self, split: str = "train"):
def __init__(self, split: str = "test"):
self.split = split
self.questions: List[Dict] = []
self._load_dataset()
@ -683,6 +675,7 @@ Please provide only the extracted answer, nothing else. If there is no clear ans
],
"temperature": 0,
}
#print(json.dumps(data, indent=2))
try:
response = requests.post(url, headers=headers, json=data)
@ -759,23 +752,20 @@ class Processor:
pred_truncated = self.grader._truncate_response(pred, max_lines=10)
is_correct, extracted = self.grader.grade(gold, pred_truncated, prompt)
task_state.correct = is_correct
task_state.extracted = extracted
task_state.status = "ok"
grader_log = {
"case_id": task_id,
"gold": gold,
"pred": pred_truncated,
"extracted": extracted,
"correct": is_correct,
"grader_type": self.grader.grader_type
}
if self.grader.grader_type == "regex" and self.grader.pattern:
grader_log["pattern"] = self.grader.pattern
eval_state.add_grader_log(grader_log)
eval_state.add_result(task_id, prompt, gold, pred, extracted, is_correct, "ok")
task_state.correct = is_correct
task_state.extracted = extracted
task_state.grader_log = grader_log
task_state.status = "ok"
eval_state.add_result(task_id, prompt, gold, pred, extracted, grader_log, is_correct, "ok")
eval_state.dump()
@ -962,11 +952,10 @@ def main():
pending_tasks = eval_state.get_pending_tasks()
print(f"Resuming from {len(pending_tasks)} pending tasks")
existing_cases = eval_state.task_states.get(eval_state.dataset_type, {}).get("cases", {})
existing_cases = eval_state.task_states.get("cases", {})
eval_state.tasks = pending_tasks
eval_state.task_states.get(eval_state.dataset_type, {})["cases"] = existing_cases
eval_state.task_states.get(eval_state.dataset_type, {})["grader_log"] = []
eval_state.task_states["cases"] = existing_cases
judge_server_url = args.judge_server if args.judge_server else args.server
judge_model_name = args.judge_model if args.judge_model else args.model