diff --git a/examples/llama-eval/llama-eval-new.py b/examples/llama-eval/llama-eval-new.py
index 7c4a7582b2..d3c318e151 100755
--- a/examples/llama-eval/llama-eval-new.py
+++ b/examples/llama-eval/llama-eval-new.py
@@ -27,6 +27,13 @@ GRADER_PATTERNS = {
     "winogrande": r'[A-D]',
 }
 
+TEMPLATE_REGISTRY = {
+    "aime": """
+{question}
+Please reason step by step, and put your final answer within \\boxed{{}}.
+""",
+}
+
 @dataclass
 class EvalState:
     id: str
@@ -43,6 +50,12 @@ class TaskState:
     correct: bool = False
     status: str = "pending"
 
+def normalize_number(s: str) -> Optional[int]:
+    match = re.match(r"\d+", s)  # match digits from the start
+    if not match:
+        return None
+    return int(match.group(0))
+
 class AimeDataset:
     def __init__(self, split: str = "train"):
         self.split = split
@@ -60,7 +73,12 @@ class AimeDataset:
         else:
             ds = load_dataset("AI-MO/aimo-validation-aime", split=self.split)
 
-        self.questions = list(ds)
+        self.questions = []
+        for row in ds:
+            question = dict(row)
+            question["dataset_type"] = "aime"
+            self.questions.append(question)
+
         print(f"AIME dataset loaded: {len(self.questions)} questions")
 
     def get_question(self, index: int) -> Dict:
@@ -68,7 +86,11 @@ class AimeDataset:
         return self.questions[index]
 
     def get_answer(self, question: Dict) -> str:
-        return str(question["answer"])
+        answer = question["answer"]
+        if isinstance(answer, str):
+            normalized = normalize_number(answer)
+            return str(normalized) if normalized is not None else answer
+        return str(answer)
 
 class Grader:
     def __init__(
@@ -177,9 +199,14 @@ class Processor:
         """Process a single case (thread-safe)"""
         question = self.dataset.get_question(i)
         case_id = f"aime_{self.dataset.split}_{question['id']}"
-        prompt = question["problem"]
         gold = self.dataset.get_answer(question)
 
+        # Apply template if available
+        if question["dataset_type"] in TEMPLATE_REGISTRY:
+            prompt = TEMPLATE_REGISTRY[question["dataset_type"]].format(question=question["problem"])
+        else:
+            prompt = question["problem"]
+
         task_state = TaskState(
             case_id=case_id,
             prompt=prompt,