diff --git a/examples/llama-eval/llama-eval.py b/examples/llama-eval/llama-eval.py index 35850c2a25..249b211f07 100755 --- a/examples/llama-eval/llama-eval.py +++ b/examples/llama-eval/llama-eval.py @@ -328,11 +328,17 @@ class EvalState: def print_existing_summary(self): cases = self.task_states.get(self.dataset_type, {}).get("cases", {}) - correct = sum(1 for c in cases.values() if c.get("correct", False)) - total = len(cases) - print(f"{'='*60}") - print(f"Results: {correct}/{total} correct ({correct/total*100:.1f}%)") - print(f"{'='*60}") + completed_cases = {tid: c for tid, c in cases.items() if c.get("status") == "ok"} + correct = sum(1 for c in completed_cases.values() if c.get("correct", False)) + total = len(completed_cases) + if total == 0: + print(f"{'='*60}") + print(f"Results: 0/0 correct (0.0%)") + print(f"{'='*60}") + else: + print(f"{'='*60}") + print(f"Results: {correct}/{total} correct ({correct/total*100:.1f}%)") + print(f"{'='*60}") def normalize_number(s: str) -> Optional[int]: match = re.match(r"\d+", s) # match digits from the start @@ -814,7 +820,6 @@ class Processor: print(f" Extracted: {task_state.extracted}") print(f" Status: {task_state.status}") - eval_state.correct = correct_count eval_state.print_summary() eval_state.dump()