diff --git a/examples/llama-eval/llama-eval.py b/examples/llama-eval/llama-eval.py index 262c307988..726936ef40 100755 --- a/examples/llama-eval/llama-eval.py +++ b/examples/llama-eval/llama-eval.py @@ -52,23 +52,29 @@ SAMPLE_ANSWERS = { } TEMPLATE_REGISTRY = { - "aime": """{question} -Please reason step by step, and put your final answer within \\boxed{{}}. + "aime": """Solve the following math problem step by step. Put your answer inside \\boxed{{}}. + +{question} + +Remember to put your answer inside \\boxed{{}}. """, - "aime2025": """{question} -Please reason step by step, and put your final answer within \\boxed{{}}. + "aime2025": """Solve the following math problem step by step. Put your answer inside \\boxed{{}}. + +{question} + +Remember to put your answer inside \\boxed{{}}. """, "gsm8k": """{question} Please reason step by step, and put your final numeric answer within \\boxed{{}} without any extra characters. """, - "gpqa": """{Question} + "gpqa": """Answer the following multiple choice question. The last line of your response should be in the following format: 'Answer: A/B/C/D' (e.g. 'Answer: A'). -(A) {A} -(B) {B} -(C) {C} -(D) {D} +{Question} -Express your final answer as the corresponding option 'A', 'B', 'C', or 'D'. +A) {A} +B) {B} +C) {C} +D) {D} """, } @@ -78,6 +84,10 @@ class BaseDataset(ABC): def get_question(self, index: int) -> Dict: pass + @abstractmethod + def get_question_str(self, question: Dict) -> str: + pass + @abstractmethod def get_answer(self, question: Dict) -> str: pass @@ -155,13 +165,14 @@ class EvalState: self.all_tasks = list(self.tasks) - def get_case(self, index: int) -> Tuple[str, str]: + def get_case(self, index: int) -> Tuple[str, str, str]: if self.dataset is None: raise ValueError("Dataset not loaded.") question = self.dataset.get_question(index) + question_str = self.dataset.get_question_str(question) prompt = self.dataset.get_prompt(question) gold = self.dataset.get_answer(question) - return prompt, gold + return question_str, prompt, gold def add_result( self, @@ -218,7 +229,7 @@ class EvalState: tasks_to_save = self.all_tasks if self.all_tasks else self.tasks all_cases = {} for i, task_id in tasks_to_save: - prompt, gold = self.get_case(i) + question, prompt, gold = self.get_case(i) if task_id in self.task_states.get("cases", {}): all_cases[task_id] = self.task_states["cases"][task_id] else: @@ -303,19 +314,19 @@ class EvalState: print("Tasks:") print(" Task ID Dataset Prompt (first 40 chars) Expected Extracted Status") for i, task_id in tasks_to_show: - prompt, gold = self.get_case(i) + question, prompt, gold = self.get_case(i) case = cases.get(task_id, {}) status = case.get("status", "pending") extracted = case.get("extracted", "N/A") if status == "ok" else "N/A" is_correct = case.get("correct", False) if status == "ok" else False symbol = "✓ " if is_correct else ("✗ " if status == "ok" else "") - first_line = prompt.split('\n')[0] - truncated_prompt = first_line[:43] + first_line = question.split('\n')[0] + question_trunc = first_line[:43] if len(first_line) > 43: - truncated_prompt += "..." + question_trunc += "..." else: - truncated_prompt = truncated_prompt.ljust(43) + "..." - print(f" {task_id:<20} {self.dataset_type.upper()} {truncated_prompt:<40} {gold:<10} {extracted:<10} {symbol}{status}") + question_trunc = question_trunc.ljust(43) + "..." + print(f" {task_id:<20} {self.dataset_type.upper()} {question_trunc:<40} {gold:<10} {extracted:<10} {symbol}{status}") print() def print_existing_summary(self): @@ -367,6 +378,10 @@ class AimeDataset(BaseDataset): """Get question by index""" return self.questions[index] + def get_question_str(self, question: Dict) -> str: + """Get question string""" + return question["problem"] if "problem" in question else question["question"] + def get_answer(self, question: Dict) -> str: answer = question["answer"] if isinstance(answer, str): @@ -376,12 +391,9 @@ class AimeDataset(BaseDataset): def get_prompt(self, question: Dict) -> str: """Get formatted prompt for the question""" - if question["dataset_type"] == "gpqa": - return TEMPLATE_REGISTRY["gpqa"].format(**question) - else: - return TEMPLATE_REGISTRY[question["dataset_type"]].format( - question=question["problem"] if "problem" in question else question["question"] - ) + return TEMPLATE_REGISTRY[question["dataset_type"]].format( + question=self.get_question_str(question), + ) class Aime2025Dataset(BaseDataset): def __init__(self): @@ -428,6 +440,10 @@ class Aime2025Dataset(BaseDataset): """Get question by index""" return self.questions[index] + def get_question_str(self, question: Dict) -> str: + """Get question string""" + return question["question"] + def get_answer(self, question: Dict) -> str: answer = question["answer"] if isinstance(answer, str): @@ -438,7 +454,7 @@ class Aime2025Dataset(BaseDataset): def get_prompt(self, question: Dict) -> str: """Get formatted prompt for the question""" return TEMPLATE_REGISTRY["aime2025"].format( - question=question["question"] + question=self.get_question_str(question), ) class Gsm8kDataset(BaseDataset): @@ -481,6 +497,10 @@ class Gsm8kDataset(BaseDataset): """Get question by index""" return self.questions[index] + def get_question_str(self, question: Dict) -> str: + """Get question string""" + return question["problem"] if "problem" in question else question["question"] + def get_answer(self, question: Dict) -> str: # GSM8K has pre-extracted gold field, AIME uses answer field if "gold" in question: @@ -494,7 +514,7 @@ class Gsm8kDataset(BaseDataset): def get_prompt(self, question: Dict) -> str: """Get formatted prompt for the question""" return TEMPLATE_REGISTRY[question["dataset_type"]].format( - question=question["problem"] if "problem" in question else question["question"] + question=self.get_question_str(question), ) class GpqaDataset(BaseDataset): @@ -549,6 +569,10 @@ class GpqaDataset(BaseDataset): """Get question by index""" return self.questions[index] + def get_question_str(self, question: Dict) -> str: + """Get question string""" + return question["Question"] + def get_answer(self, question: Dict) -> str: # GPQA returns the correct letter (A, B, C, or D) return question["correct_letter"] @@ -556,7 +580,7 @@ class GpqaDataset(BaseDataset): def get_prompt(self, question: Dict) -> str: """Get formatted prompt for the question""" return TEMPLATE_REGISTRY["gpqa"].format( - Question=question["Question"], + Question=self.get_question_str(question), A=question["shuffled_answers"][0], B=question["shuffled_answers"][1], C=question["shuffled_answers"][2], @@ -737,7 +761,7 @@ class Processor: return response.json() def _process_single_case(self, eval_state: EvalState, i: int, task_id: str) -> TaskState: - prompt, gold = eval_state.get_case(i) + question, prompt, gold = eval_state.get_case(i) task_state = TaskState( case_id=task_id,