simplify
This commit is contained in:
parent
7b84af8051
commit
6c41664b8b
|
|
@ -59,7 +59,7 @@ Please reason step by step, and put your final answer within \\boxed{{}}.
|
|||
Please reason step by step, and put your final answer within \\boxed{{}}.
|
||||
""",
|
||||
"gsm8k": """{question}
|
||||
Please reason step by step, and provide your final answer.
|
||||
Please reason step by step, and put your final numeric answer within \\boxed{{}} without any extra characters.
|
||||
""",
|
||||
"gpqa": """{Question}
|
||||
|
||||
|
|
@ -97,6 +97,7 @@ class TaskState:
|
|||
gold: str
|
||||
pred: Optional[str] = None
|
||||
extracted: Optional[str] = None
|
||||
grader_log: Dict[str, Any] = field(default_factory=dict)
|
||||
correct: bool = False
|
||||
status: str = "pending"
|
||||
|
||||
|
|
@ -169,20 +170,20 @@ class EvalState:
|
|||
gold: str,
|
||||
pred: Optional[str],
|
||||
extracted: Optional[str],
|
||||
grader_log: Dict[str, Any],
|
||||
correct: bool,
|
||||
status: str
|
||||
):
|
||||
if self.dataset_type not in self.task_states:
|
||||
self.task_states[self.dataset_type] = {}
|
||||
if "cases" not in self.task_states[self.dataset_type]:
|
||||
self.task_states[self.dataset_type]["cases"] = {}
|
||||
if "cases" not in self.task_states:
|
||||
self.task_states["cases"] = {}
|
||||
|
||||
self.task_states[self.dataset_type]["cases"][task_id] = {
|
||||
self.task_states["cases"][task_id] = {
|
||||
"case_id": task_id,
|
||||
"prompt": prompt,
|
||||
"gold": gold,
|
||||
"pred": pred,
|
||||
"extracted": extracted,
|
||||
"grader_log": grader_log,
|
||||
"correct": correct,
|
||||
"status": status
|
||||
}
|
||||
|
|
@ -190,14 +191,7 @@ class EvalState:
|
|||
if correct:
|
||||
self.correct += 1
|
||||
else:
|
||||
self.correct = sum(1 for c in self.task_states.get(self.dataset_type, {}).get("cases", {}).values() if c.get("correct", False))
|
||||
|
||||
def add_grader_log(self, grader_log: Dict[str, Any]):
|
||||
if self.dataset_type not in self.task_states:
|
||||
self.task_states[self.dataset_type] = {}
|
||||
if "grader_log" not in self.task_states[self.dataset_type]:
|
||||
self.task_states[self.dataset_type]["grader_log"] = []
|
||||
self.task_states[self.dataset_type]["grader_log"].append(grader_log)
|
||||
self.correct = sum(1 for c in self.task_states.get("cases", {}).values() if c.get("correct", False))
|
||||
|
||||
def print_progress(self, task_state: TaskState, total_tasks: int, correct_count: int = 0):
|
||||
extracted_display = task_state.extracted if task_state.extracted else "N/A"
|
||||
|
|
@ -225,8 +219,8 @@ class EvalState:
|
|||
all_cases = {}
|
||||
for i, task_id in tasks_to_save:
|
||||
prompt, gold = self.get_case(i)
|
||||
if task_id in self.task_states.get(self.dataset_type, {}).get("cases", {}):
|
||||
all_cases[task_id] = self.task_states[self.dataset_type]["cases"][task_id]
|
||||
if task_id in self.task_states.get("cases", {}):
|
||||
all_cases[task_id] = self.task_states["cases"][task_id]
|
||||
else:
|
||||
all_cases[task_id] = {
|
||||
"case_id": task_id,
|
||||
|
|
@ -234,6 +228,7 @@ class EvalState:
|
|||
"gold": gold,
|
||||
"pred": None,
|
||||
"extracted": None,
|
||||
"grader_log": {},
|
||||
"correct": False,
|
||||
"status": "pending"
|
||||
}
|
||||
|
|
@ -242,12 +237,9 @@ class EvalState:
|
|||
"id": self.dataset_type,
|
||||
"tasks": [tid for _, tid in tasks_to_save],
|
||||
"task_states": {
|
||||
self.dataset_type: {
|
||||
"total": self.total,
|
||||
"correct": self.correct,
|
||||
"cases": all_cases,
|
||||
"grader_log": self.task_states.get("grader_log", [])
|
||||
}
|
||||
"total": self.total,
|
||||
"correct": self.correct,
|
||||
"cases": all_cases,
|
||||
},
|
||||
"sampling_config": self.sampling_config
|
||||
}
|
||||
|
|
@ -279,9 +271,9 @@ class EvalState:
|
|||
|
||||
eval_state.task_states = data.get("task_states", {})
|
||||
|
||||
cases = eval_state.task_states.get(eval_state.dataset_type, {}).get("cases", {})
|
||||
eval_state.total = eval_state.task_states.get(eval_state.dataset_type, {}).get("total", 0)
|
||||
eval_state.correct = eval_state.task_states.get(eval_state.dataset_type, {}).get("correct", 0)
|
||||
cases = eval_state.task_states.get("cases", {})
|
||||
eval_state.total = eval_state.task_states.get("total", 0)
|
||||
eval_state.correct = eval_state.task_states.get("correct", 0)
|
||||
|
||||
if eval_state.total == 0:
|
||||
eval_state.total = len(cases)
|
||||
|
|
@ -292,12 +284,12 @@ class EvalState:
|
|||
def is_complete(self) -> bool:
|
||||
if not self.all_tasks:
|
||||
return False
|
||||
cases = self.task_states.get(self.dataset_type, {}).get("cases", {})
|
||||
completed = {tid for tid in self.task_states.get(self.dataset_type, {}).get("cases", {}).keys() if cases.get(tid, {}).get("status") == "ok"}
|
||||
cases = self.task_states.get("cases", {})
|
||||
completed = {tid for tid in self.task_states.get("cases", {}).keys() if cases.get(tid, {}).get("status") == "ok"}
|
||||
return len(completed) == len(self.all_tasks)
|
||||
|
||||
def get_pending_tasks(self) -> List[Tuple[int, str]]:
|
||||
cases = self.task_states.get(self.dataset_type, {}).get("cases", {})
|
||||
cases = self.task_states.get("cases", {})
|
||||
pending = []
|
||||
for i, task_id in self.all_tasks:
|
||||
if cases.get(task_id, {}).get("status") != "ok":
|
||||
|
|
@ -305,7 +297,7 @@ class EvalState:
|
|||
return pending
|
||||
|
||||
def print_all_tasks(self):
|
||||
cases = self.task_states.get(self.dataset_type, {}).get("cases", {})
|
||||
cases = self.task_states.get("cases", {})
|
||||
tasks_to_show = self.all_tasks if self.all_tasks else self.tasks
|
||||
print()
|
||||
print("Tasks:")
|
||||
|
|
@ -327,7 +319,7 @@ class EvalState:
|
|||
print()
|
||||
|
||||
def print_existing_summary(self):
|
||||
cases = self.task_states.get(self.dataset_type, {}).get("cases", {})
|
||||
cases = self.task_states.get("cases", {})
|
||||
completed_cases = {tid: c for tid, c in cases.items() if c.get("status") == "ok"}
|
||||
correct = sum(1 for c in completed_cases.values() if c.get("correct", False))
|
||||
total = len(completed_cases)
|
||||
|
|
@ -450,7 +442,7 @@ class Aime2025Dataset(BaseDataset):
|
|||
)
|
||||
|
||||
class Gsm8kDataset(BaseDataset):
|
||||
def __init__(self, split: str = "train"):
|
||||
def __init__(self, split: str = "test"):
|
||||
self.split = split
|
||||
self.questions: List[Dict] = []
|
||||
self._load_dataset()
|
||||
|
|
@ -683,6 +675,7 @@ Please provide only the extracted answer, nothing else. If there is no clear ans
|
|||
],
|
||||
"temperature": 0,
|
||||
}
|
||||
#print(json.dumps(data, indent=2))
|
||||
|
||||
try:
|
||||
response = requests.post(url, headers=headers, json=data)
|
||||
|
|
@ -759,23 +752,20 @@ class Processor:
|
|||
|
||||
pred_truncated = self.grader._truncate_response(pred, max_lines=10)
|
||||
is_correct, extracted = self.grader.grade(gold, pred_truncated, prompt)
|
||||
task_state.correct = is_correct
|
||||
task_state.extracted = extracted
|
||||
task_state.status = "ok"
|
||||
|
||||
grader_log = {
|
||||
"case_id": task_id,
|
||||
"gold": gold,
|
||||
"pred": pred_truncated,
|
||||
"extracted": extracted,
|
||||
"correct": is_correct,
|
||||
"grader_type": self.grader.grader_type
|
||||
}
|
||||
if self.grader.grader_type == "regex" and self.grader.pattern:
|
||||
grader_log["pattern"] = self.grader.pattern
|
||||
eval_state.add_grader_log(grader_log)
|
||||
|
||||
eval_state.add_result(task_id, prompt, gold, pred, extracted, is_correct, "ok")
|
||||
task_state.correct = is_correct
|
||||
task_state.extracted = extracted
|
||||
task_state.grader_log = grader_log
|
||||
task_state.status = "ok"
|
||||
|
||||
eval_state.add_result(task_id, prompt, gold, pred, extracted, grader_log, is_correct, "ok")
|
||||
|
||||
eval_state.dump()
|
||||
|
||||
|
|
@ -962,11 +952,10 @@ def main():
|
|||
pending_tasks = eval_state.get_pending_tasks()
|
||||
print(f"Resuming from {len(pending_tasks)} pending tasks")
|
||||
|
||||
existing_cases = eval_state.task_states.get(eval_state.dataset_type, {}).get("cases", {})
|
||||
existing_cases = eval_state.task_states.get("cases", {})
|
||||
|
||||
eval_state.tasks = pending_tasks
|
||||
eval_state.task_states.get(eval_state.dataset_type, {})["cases"] = existing_cases
|
||||
eval_state.task_states.get(eval_state.dataset_type, {})["grader_log"] = []
|
||||
eval_state.task_states["cases"] = existing_cases
|
||||
|
||||
judge_server_url = args.judge_server if args.judge_server else args.server
|
||||
judge_model_name = args.judge_model if args.judge_model else args.model
|
||||
|
|
|
|||
Loading…
Reference in New Issue