diff --git a/examples/llama-eval/llama-eval.py b/examples/llama-eval/llama-eval.py
index 57cced2dac..6af1459e25 100755
--- a/examples/llama-eval/llama-eval.py
+++ b/examples/llama-eval/llama-eval.py
@@ -132,6 +132,7 @@ class EvalState:
self.total = 0
self.correct = 0
self.processed = 0
+ self.total_time: float = 0.0
def load_dataset(self, seed: int = 1234):
if self.dataset_type == "aime":
@@ -258,6 +259,7 @@ class EvalState:
"task_states": {
"total": self.total,
"correct": self.correct,
+ "total_time": self.total_time,
"cases": all_cases,
},
"sampling_config": self.sampling_config
@@ -377,6 +379,7 @@ class EvalState:
| Incorrect | {incorrect_count} |
| Pending | {pending_count} |
| Accuracy | {accuracy:.1f}% |
+ | Total Time | {self.total_time:.1f}s |
| Sampling | {sampling_str} |
@@ -449,6 +452,7 @@ class EvalState:
cases = eval_state.task_states.get("cases", {})
eval_state.total = eval_state.task_states.get("total", 0)
eval_state.correct = eval_state.task_states.get("correct", 0)
+ eval_state.total_time = eval_state.task_states.get("total_time", 0.0)
if eval_state.total == 0:
eval_state.total = len(cases)
@@ -984,6 +988,7 @@ class Processor:
total_tasks = len(eval_state.tasks)
eval_state.total = len(eval_state.all_tasks) if eval_state.all_tasks else total_tasks
eval_state.processed = 0
+ start_time = time.time()
print(f"\nProcessing {len(eval_state.tasks)} {eval_state.dataset_type.upper()} tasks ...")
print(f"Server: {self.server_url} (model: {self.model_name})")
@@ -1000,11 +1005,16 @@ class Processor:
for i, task_id in eval_state.tasks
}
+ session_time = 0.0
for future in as_completed(futures):
task_state = future.result()
eval_state.processed += 1
if task_state.correct:
correct_count += 1
+ elapsed = time.time() - start_time
+ eval_state.total_time += elapsed
+ session_time += elapsed
+ start_time = time.time()
eval_state.print_progress(task_state, total_tasks, correct_count)
if verbose:
@@ -1016,6 +1026,7 @@ class Processor:
print(f" Answer: {task_state.answer}")
print(f" Status: {task_state.status}")
+ print(f"\nSession time: {session_time:.1f}s | Total accumulated time: {eval_state.total_time:.1f}s")
eval_state.print_summary()
eval_state.dump()