diff --git a/examples/llama-eval/llama-eval.py b/examples/llama-eval/llama-eval.py index 249b211f07..262c307988 100755 --- a/examples/llama-eval/llama-eval.py +++ b/examples/llama-eval/llama-eval.py @@ -59,7 +59,7 @@ Please reason step by step, and put your final answer within \\boxed{{}}. Please reason step by step, and put your final answer within \\boxed{{}}. """, "gsm8k": """{question} -Please reason step by step, and provide your final answer. +Please reason step by step, and put your final numeric answer within \\boxed{{}} without any extra characters. """, "gpqa": """{Question} @@ -97,6 +97,7 @@ class TaskState: gold: str pred: Optional[str] = None extracted: Optional[str] = None + grader_log: Dict[str, Any] = field(default_factory=dict) correct: bool = False status: str = "pending" @@ -169,20 +170,20 @@ class EvalState: gold: str, pred: Optional[str], extracted: Optional[str], + grader_log: Dict[str, Any], correct: bool, status: str ): - if self.dataset_type not in self.task_states: - self.task_states[self.dataset_type] = {} - if "cases" not in self.task_states[self.dataset_type]: - self.task_states[self.dataset_type]["cases"] = {} + if "cases" not in self.task_states: + self.task_states["cases"] = {} - self.task_states[self.dataset_type]["cases"][task_id] = { + self.task_states["cases"][task_id] = { "case_id": task_id, "prompt": prompt, "gold": gold, "pred": pred, "extracted": extracted, + "grader_log": grader_log, "correct": correct, "status": status } @@ -190,14 +191,7 @@ class EvalState: if correct: self.correct += 1 else: - self.correct = sum(1 for c in self.task_states.get(self.dataset_type, {}).get("cases", {}).values() if c.get("correct", False)) - - def add_grader_log(self, grader_log: Dict[str, Any]): - if self.dataset_type not in self.task_states: - self.task_states[self.dataset_type] = {} - if "grader_log" not in self.task_states[self.dataset_type]: - self.task_states[self.dataset_type]["grader_log"] = [] - self.task_states[self.dataset_type]["grader_log"].append(grader_log) + self.correct = sum(1 for c in self.task_states.get("cases", {}).values() if c.get("correct", False)) def print_progress(self, task_state: TaskState, total_tasks: int, correct_count: int = 0): extracted_display = task_state.extracted if task_state.extracted else "N/A" @@ -225,8 +219,8 @@ class EvalState: all_cases = {} for i, task_id in tasks_to_save: prompt, gold = self.get_case(i) - if task_id in self.task_states.get(self.dataset_type, {}).get("cases", {}): - all_cases[task_id] = self.task_states[self.dataset_type]["cases"][task_id] + if task_id in self.task_states.get("cases", {}): + all_cases[task_id] = self.task_states["cases"][task_id] else: all_cases[task_id] = { "case_id": task_id, @@ -234,6 +228,7 @@ class EvalState: "gold": gold, "pred": None, "extracted": None, + "grader_log": {}, "correct": False, "status": "pending" } @@ -242,12 +237,9 @@ class EvalState: "id": self.dataset_type, "tasks": [tid for _, tid in tasks_to_save], "task_states": { - self.dataset_type: { - "total": self.total, - "correct": self.correct, - "cases": all_cases, - "grader_log": self.task_states.get("grader_log", []) - } + "total": self.total, + "correct": self.correct, + "cases": all_cases, }, "sampling_config": self.sampling_config } @@ -279,9 +271,9 @@ class EvalState: eval_state.task_states = data.get("task_states", {}) - cases = eval_state.task_states.get(eval_state.dataset_type, {}).get("cases", {}) - eval_state.total = eval_state.task_states.get(eval_state.dataset_type, {}).get("total", 0) - eval_state.correct = eval_state.task_states.get(eval_state.dataset_type, {}).get("correct", 0) + cases = eval_state.task_states.get("cases", {}) + eval_state.total = eval_state.task_states.get("total", 0) + eval_state.correct = eval_state.task_states.get("correct", 0) if eval_state.total == 0: eval_state.total = len(cases) @@ -292,12 +284,12 @@ class EvalState: def is_complete(self) -> bool: if not self.all_tasks: return False - cases = self.task_states.get(self.dataset_type, {}).get("cases", {}) - completed = {tid for tid in self.task_states.get(self.dataset_type, {}).get("cases", {}).keys() if cases.get(tid, {}).get("status") == "ok"} + cases = self.task_states.get("cases", {}) + completed = {tid for tid in self.task_states.get("cases", {}).keys() if cases.get(tid, {}).get("status") == "ok"} return len(completed) == len(self.all_tasks) def get_pending_tasks(self) -> List[Tuple[int, str]]: - cases = self.task_states.get(self.dataset_type, {}).get("cases", {}) + cases = self.task_states.get("cases", {}) pending = [] for i, task_id in self.all_tasks: if cases.get(task_id, {}).get("status") != "ok": @@ -305,7 +297,7 @@ class EvalState: return pending def print_all_tasks(self): - cases = self.task_states.get(self.dataset_type, {}).get("cases", {}) + cases = self.task_states.get("cases", {}) tasks_to_show = self.all_tasks if self.all_tasks else self.tasks print() print("Tasks:") @@ -327,7 +319,7 @@ class EvalState: print() def print_existing_summary(self): - cases = self.task_states.get(self.dataset_type, {}).get("cases", {}) + cases = self.task_states.get("cases", {}) completed_cases = {tid: c for tid, c in cases.items() if c.get("status") == "ok"} correct = sum(1 for c in completed_cases.values() if c.get("correct", False)) total = len(completed_cases) @@ -450,7 +442,7 @@ class Aime2025Dataset(BaseDataset): ) class Gsm8kDataset(BaseDataset): - def __init__(self, split: str = "train"): + def __init__(self, split: str = "test"): self.split = split self.questions: List[Dict] = [] self._load_dataset() @@ -683,6 +675,7 @@ Please provide only the extracted answer, nothing else. If there is no clear ans ], "temperature": 0, } + #print(json.dumps(data, indent=2)) try: response = requests.post(url, headers=headers, json=data) @@ -759,23 +752,20 @@ class Processor: pred_truncated = self.grader._truncate_response(pred, max_lines=10) is_correct, extracted = self.grader.grade(gold, pred_truncated, prompt) - task_state.correct = is_correct - task_state.extracted = extracted - task_state.status = "ok" grader_log = { - "case_id": task_id, - "gold": gold, "pred": pred_truncated, - "extracted": extracted, - "correct": is_correct, "grader_type": self.grader.grader_type } if self.grader.grader_type == "regex" and self.grader.pattern: grader_log["pattern"] = self.grader.pattern - eval_state.add_grader_log(grader_log) - eval_state.add_result(task_id, prompt, gold, pred, extracted, is_correct, "ok") + task_state.correct = is_correct + task_state.extracted = extracted + task_state.grader_log = grader_log + task_state.status = "ok" + + eval_state.add_result(task_id, prompt, gold, pred, extracted, grader_log, is_correct, "ok") eval_state.dump() @@ -962,11 +952,10 @@ def main(): pending_tasks = eval_state.get_pending_tasks() print(f"Resuming from {len(pending_tasks)} pending tasks") - existing_cases = eval_state.task_states.get(eval_state.dataset_type, {}).get("cases", {}) + existing_cases = eval_state.task_states.get("cases", {}) eval_state.tasks = pending_tasks - eval_state.task_states.get(eval_state.dataset_type, {})["cases"] = existing_cases - eval_state.task_states.get(eval_state.dataset_type, {})["grader_log"] = [] + eval_state.task_states["cases"] = existing_cases judge_server_url = args.judge_server if args.judge_server else args.server judge_model_name = args.judge_model if args.judge_model else args.model