diff --git a/examples/llama-eval/llama-eval.py b/examples/llama-eval/llama-eval.py
index 7d7348aa8e..f7c29832c6 100755
--- a/examples/llama-eval/llama-eval.py
+++ b/examples/llama-eval/llama-eval.py
@@ -325,18 +325,30 @@ class Grader:
 
     def _get_pattern(self) -> Optional[str]:
         if self.grader_type == "regex":
-            return GRADER_PATTERNS.get(self.grader_type)  # Use grader_type as key
+            return GRADER_PATTERNS.get(self.dataset_type)  # Use dataset_type as key
         return None
 
     def _extract_answer_regex(self, pred: str) -> Optional[str]:
         """Extract answer using regex pattern"""
         if not self.pattern:
             return None
+
+        # For AIME datasets, prioritize boxed answers
+        if self.dataset_type in ["aime", "aime2025"]:
+            boxed_pattern = r'\\boxed{([^}]+)}'
+            boxed_matches = re.findall(boxed_pattern, pred, re.IGNORECASE)
+            if boxed_matches:
+                # Return the last boxed answer found (most likely the final answer)
+                return boxed_matches[-1].strip()
+
+        # For other datasets, search for numbers from the end of the text
+        # This prioritizes numbers that appear later in the response
         matches = re.findall(self.pattern, pred, re.IGNORECASE)
         if not matches:
             return None
 
-        for match in matches:
+        # Process matches from end to start
+        for match in reversed(matches):
             if isinstance(match, tuple):
                 match = match[0] if match[0] else match[1]
             extracted = match.strip()
@@ -446,7 +458,8 @@ class Processor:
         judge_model_name: Optional[str] = None,
         dataset_type: str = "aime",
         seed: int = 1234,
-        sampling_config: Optional[Dict[str, Any]] = None
+        sampling_config: Optional[Dict[str, Any]] = None,
+        output_file: Optional[Path] = None
     ):
         self.server_url = server_url
         self.n_predict = n_predict
@@ -459,10 +472,11 @@ class Processor:
         self.seed = seed
         self.grader = grader or Grader()
         self.sampling_config = sampling_config or {"n_predict": n_predict}
+        self.output_file = output_file or Path("llama-eval-state.json")
         self.eval_state = EvalState(
             id=dataset_type,
             tasks=[dataset_type],
-            task_states={},
+            task_states={dataset_type: {}},
             sampling_config=self.sampling_config
         )
 
@@ -533,8 +547,44 @@ class Processor:
             task_state.correct = is_correct
             task_state.extracted = extracted
             task_state.status = "ok"
-        except Exception as e:
-            task_state.status = f"error: {str(e)}"
+
+            # Log grader request details for debugging
+            grader_log = {
+                "case_id": task_id,
+                "gold": gold,
+                "pred": pred_truncated,
+                "extracted": extracted,
+                "correct": is_correct,
+                "grader_type": self.grader.grader_type
+            }
+            if self.grader.grader_type == "regex" and self.grader.pattern:
+                grader_log["pattern"] = self.grader.pattern
+            if "grader_log" not in self.eval_state.task_states[self.dataset_type]:
+                self.eval_state.task_states[self.dataset_type]["grader_log"] = []
+            self.eval_state.task_states[self.dataset_type]["grader_log"].append(grader_log)
+
+            # Initialize cases dict if it doesn't exist
+            if "cases" not in self.eval_state.task_states[self.dataset_type]:
+                self.eval_state.task_states[self.dataset_type]["cases"] = {}
+
+            # Update eval state with grading details
+            self.eval_state.task_states[self.dataset_type]["cases"][task_id] = {
+                "case_id": task_id,
+                "prompt": prompt,
+                "gold": gold,
+                "pred": pred,
+                "extracted": extracted,
+                "correct": is_correct,
+                "status": "ok"
+            }
+
+            # Save eval state to disk after each task
+            try:
+                self.dump_state(self.output_file)
+            except Exception as dump_error:
+                task_state.status = f"error: failed to dump eval state: {str(dump_error)}"
+        except Exception as processing_error:
+            task_state.status = f"error: {str(processing_error)}"
 
         return task_state
 
@@ -621,10 +671,13 @@ class Processor:
             print(f" Extracted: {task_state.extracted}")
             print(f" Status: {task_state.status}")
 
+        # Merge existing state with new state to preserve grader_log
+        existing_state = self.eval_state.task_states.get(self.dataset_type, {})
         self.eval_state.task_states[self.dataset_type] = {
             "total": total,
             "correct": correct,
-            "cases": task_states
+            "cases": task_states,
+            **existing_state
         }
 
         print(f"\n{'='*60}")
@@ -637,7 +690,6 @@ class Processor:
         """Dump eval state to JSON file"""
         with open(output_file, "w") as f:
             json.dump(asdict(self.eval_state), f, indent=2)
-        print(f"\nEval state dumped to {output_file}")
 
 def main():
     parser = argparse.ArgumentParser(
@@ -785,11 +837,13 @@ def main():
         judge_server_url=args.judge_server,
         judge_model_name=args.judge_model,
         dataset_type=args.dataset,
-        sampling_config=sampling_config
+        sampling_config=sampling_config,
+        output_file=args.output
     )
 
     eval_state = processor.process(n_cases=args.n_cases, seed=args.seed)
     processor.dump_state(args.output)
+    print(f"\nEval state dumped to {args.output}")
 
 if __name__ == "__main__":
     main()