diff --git a/examples/llama-eval/llama-eval.py b/examples/llama-eval/llama-eval.py index cb6c36148c..d44530e6ef 100755 --- a/examples/llama-eval/llama-eval.py +++ b/examples/llama-eval/llama-eval.py @@ -110,6 +110,7 @@ class TaskState: grader_log: Dict[str, Any] = field(default_factory=dict) correct: bool = False status: str = "pending" + tokens: Optional[int] = None class EvalState: @@ -183,7 +184,8 @@ class EvalState: extracted: Optional[str], grader_log: Dict[str, Any], correct: bool, - status: str + status: str, + tokens: Optional[int] = None ): if "cases" not in self.task_states: self.task_states["cases"] = {} @@ -196,7 +198,8 @@ class EvalState: "extracted": extracted, "grader_log": grader_log, "correct": correct, - "status": status + "status": status, + "tokens": tokens } if correct: @@ -206,6 +209,7 @@ class EvalState: def print_progress(self, task_state: TaskState, total_tasks: int, correct_count: int = 0): extracted_display = task_state.extracted if task_state.extracted else "N/A" + tokens_display = str(task_state.tokens) if task_state.tokens is not None else "N/A" success_ratio = correct_count / self.processed if self.processed > 0 else 0.0 first_line = task_state.prompt.split('\n')[0] truncated_prompt = first_line[:43] @@ -213,7 +217,7 @@ class EvalState: truncated_prompt += "..." else: truncated_prompt = truncated_prompt.ljust(43) + "..." - print(f"{self.processed:3}/{total_tasks:3} {task_state.case_id:<20} {self.dataset_type.upper()} {truncated_prompt:<40} {task_state.gold:<10} {extracted_display:<10} {'✓' if task_state.correct else '✗'} [{correct_count:3}/{self.processed:3}, {success_ratio:.3f}]") + print(f"{self.processed:3}/{total_tasks:3} {task_state.case_id:<20} {self.dataset_type.upper()} {truncated_prompt:<40} {task_state.gold:<10} {extracted_display:<10} {tokens_display:<6} {'✓' if task_state.correct else '✗'} [{correct_count:3}/{self.processed:3}, {success_ratio:.3f}]") def print_summary(self): if self.total == 0: @@ -241,7 +245,8 @@ class EvalState: "extracted": None, "grader_log": {}, "correct": False, - "status": "pending" + "status": "pending", + "tokens": None } data = { @@ -296,6 +301,9 @@ class EvalState: status_class = "error" status_text = f"Error: {status}" + tokens = case.get("tokens") + tokens_str = str(tokens) if tokens is not None else "" + result_escaped = self._escape_html(result) prompt_escaped = self._escape_html(prompt) grader_log_str = self._escape_html(json.dumps(grader_log, indent=2)) @@ -305,9 +313,10 @@ class EvalState: {status_text} {self._escape_html(gold)} {self._escape_html(extracted)} + {tokens_str} - +

Prompt

{prompt_escaped}
@@ -371,6 +380,7 @@ class EvalState: Status Gold Extracted + Tokens @@ -451,12 +461,14 @@ class EvalState: tasks_to_show = self.all_tasks if self.all_tasks else self.tasks print() print("Tasks:") - print(" Task ID Dataset Prompt (first 40 chars) Expected Extracted Status") + print(" Task ID Dataset Prompt (first 40 chars) Expected Extracted Tokens Status") for i, task_id in tasks_to_show: question, prompt, gold = self.get_case(i) case = cases.get(task_id, {}) status = case.get("status", "pending") extracted = case.get("extracted", "N/A") if status == "ok" else "N/A" + tokens = case.get("tokens") + tokens_str = str(tokens) if tokens is not None else "N/A" is_correct = case.get("correct", False) if status == "ok" else False symbol = "✓ " if is_correct else ("✗ " if status == "ok" else "") first_line = question.split('\n')[0] @@ -465,7 +477,7 @@ class EvalState: question_trunc += "..." else: question_trunc = question_trunc.ljust(43) + "..." - print(f" {task_id:<20} {self.dataset_type.upper()} {question_trunc:<40} {gold:<10} {extracted:<10} {symbol}{status}") + print(f" {task_id:<20} {self.dataset_type.upper()} {question_trunc:<40} {gold:<10} {extracted:<10} {tokens_str:<6} {symbol}{status}") print() def print_existing_summary(self): @@ -878,7 +890,7 @@ class Processor: self.model_name = model_name self.threads = threads - def _make_request(self, eval_state: EvalState, prompt: str) -> Dict[str, Any]: + def _make_request(self, eval_state: EvalState, prompt: str) -> Tuple[Dict[str, Any], int]: url = f"{self.server_url}/v1/chat/completions" headers = {"Content-Type": "application/json"} data = { @@ -897,7 +909,9 @@ class Processor: response = requests.post(url, headers=headers, json=data) response.raise_for_status() - return response.json() + result = response.json() + tokens = result.get("usage", {}).get("completion_tokens", 0) + return result, tokens def _process_single_case(self, eval_state: EvalState, i: int, task_id: str) -> TaskState: question, prompt, gold = eval_state.get_case(i) @@ -909,9 +923,10 @@ class Processor: ) try: - response = self._make_request(eval_state, prompt) + response, tokens = self._make_request(eval_state, prompt) result = response["choices"][0]["message"]["content"] task_state.result = result + task_state.tokens = tokens result_truncated = self.grader._truncate_response(result, max_lines=10) is_correct, extracted = self.grader.grade(gold, result_truncated, prompt) @@ -928,7 +943,7 @@ class Processor: task_state.grader_log = grader_log task_state.status = "ok" - eval_state.add_result(task_id, prompt, gold, result, extracted, grader_log, is_correct, "ok") + eval_state.add_result(task_id, prompt, gold, result, extracted, grader_log, is_correct, "ok", tokens) eval_state.dump()