fix prompts

This commit is contained in:
Georgi Gerganov 2026-02-16 21:02:25 +02:00
parent 6c41664b8b
commit e2e998a2d6
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
1 changed files with 53 additions and 29 deletions

View File

@ -52,23 +52,29 @@ SAMPLE_ANSWERS = {
}
TEMPLATE_REGISTRY = {
"aime": """{question}
Please reason step by step, and put your final answer within \\boxed{{}}.
"aime": """Solve the following math problem step by step. Put your answer inside \\boxed{{}}.
{question}
Remember to put your answer inside \\boxed{{}}.
""",
"aime2025": """{question}
Please reason step by step, and put your final answer within \\boxed{{}}.
"aime2025": """Solve the following math problem step by step. Put your answer inside \\boxed{{}}.
{question}
Remember to put your answer inside \\boxed{{}}.
""",
"gsm8k": """{question}
Please reason step by step, and put your final numeric answer within \\boxed{{}} without any extra characters.
""",
"gpqa": """{Question}
"gpqa": """Answer the following multiple choice question. The last line of your response should be in the following format: 'Answer: A/B/C/D' (e.g. 'Answer: A').
(A) {A}
(B) {B}
(C) {C}
(D) {D}
{Question}
Express your final answer as the corresponding option 'A', 'B', 'C', or 'D'.
A) {A}
B) {B}
C) {C}
D) {D}
""",
}
@ -78,6 +84,10 @@ class BaseDataset(ABC):
def get_question(self, index: int) -> Dict:
pass
@abstractmethod
def get_question_str(self, question: Dict) -> str:
pass
@abstractmethod
def get_answer(self, question: Dict) -> str:
pass
@ -155,13 +165,14 @@ class EvalState:
self.all_tasks = list(self.tasks)
def get_case(self, index: int) -> Tuple[str, str]:
def get_case(self, index: int) -> Tuple[str, str, str]:
if self.dataset is None:
raise ValueError("Dataset not loaded.")
question = self.dataset.get_question(index)
question_str = self.dataset.get_question_str(question)
prompt = self.dataset.get_prompt(question)
gold = self.dataset.get_answer(question)
return prompt, gold
return question_str, prompt, gold
def add_result(
self,
@ -218,7 +229,7 @@ class EvalState:
tasks_to_save = self.all_tasks if self.all_tasks else self.tasks
all_cases = {}
for i, task_id in tasks_to_save:
prompt, gold = self.get_case(i)
question, prompt, gold = self.get_case(i)
if task_id in self.task_states.get("cases", {}):
all_cases[task_id] = self.task_states["cases"][task_id]
else:
@ -303,19 +314,19 @@ class EvalState:
print("Tasks:")
print(" Task ID Dataset Prompt (first 40 chars) Expected Extracted Status")
for i, task_id in tasks_to_show:
prompt, gold = self.get_case(i)
question, prompt, gold = self.get_case(i)
case = cases.get(task_id, {})
status = case.get("status", "pending")
extracted = case.get("extracted", "N/A") if status == "ok" else "N/A"
is_correct = case.get("correct", False) if status == "ok" else False
symbol = "" if is_correct else ("" if status == "ok" else "")
first_line = prompt.split('\n')[0]
truncated_prompt = first_line[:43]
first_line = question.split('\n')[0]
question_trunc = first_line[:43]
if len(first_line) > 43:
truncated_prompt += "..."
question_trunc += "..."
else:
truncated_prompt = truncated_prompt.ljust(43) + "..."
print(f" {task_id:<20} {self.dataset_type.upper()} {truncated_prompt:<40} {gold:<10} {extracted:<10} {symbol}{status}")
question_trunc = question_trunc.ljust(43) + "..."
print(f" {task_id:<20} {self.dataset_type.upper()} {question_trunc:<40} {gold:<10} {extracted:<10} {symbol}{status}")
print()
def print_existing_summary(self):
@ -367,6 +378,10 @@ class AimeDataset(BaseDataset):
"""Get question by index"""
return self.questions[index]
def get_question_str(self, question: Dict) -> str:
"""Get question string"""
return question["problem"] if "problem" in question else question["question"]
def get_answer(self, question: Dict) -> str:
answer = question["answer"]
if isinstance(answer, str):
@ -376,12 +391,9 @@ class AimeDataset(BaseDataset):
def get_prompt(self, question: Dict) -> str:
"""Get formatted prompt for the question"""
if question["dataset_type"] == "gpqa":
return TEMPLATE_REGISTRY["gpqa"].format(**question)
else:
return TEMPLATE_REGISTRY[question["dataset_type"]].format(
question=question["problem"] if "problem" in question else question["question"]
)
return TEMPLATE_REGISTRY[question["dataset_type"]].format(
question=self.get_question_str(question),
)
class Aime2025Dataset(BaseDataset):
def __init__(self):
@ -428,6 +440,10 @@ class Aime2025Dataset(BaseDataset):
"""Get question by index"""
return self.questions[index]
def get_question_str(self, question: Dict) -> str:
"""Get question string"""
return question["question"]
def get_answer(self, question: Dict) -> str:
answer = question["answer"]
if isinstance(answer, str):
@ -438,7 +454,7 @@ class Aime2025Dataset(BaseDataset):
def get_prompt(self, question: Dict) -> str:
"""Get formatted prompt for the question"""
return TEMPLATE_REGISTRY["aime2025"].format(
question=question["question"]
question=self.get_question_str(question),
)
class Gsm8kDataset(BaseDataset):
@ -481,6 +497,10 @@ class Gsm8kDataset(BaseDataset):
"""Get question by index"""
return self.questions[index]
def get_question_str(self, question: Dict) -> str:
"""Get question string"""
return question["problem"] if "problem" in question else question["question"]
def get_answer(self, question: Dict) -> str:
# GSM8K has pre-extracted gold field, AIME uses answer field
if "gold" in question:
@ -494,7 +514,7 @@ class Gsm8kDataset(BaseDataset):
def get_prompt(self, question: Dict) -> str:
"""Get formatted prompt for the question"""
return TEMPLATE_REGISTRY[question["dataset_type"]].format(
question=question["problem"] if "problem" in question else question["question"]
question=self.get_question_str(question),
)
class GpqaDataset(BaseDataset):
@ -549,6 +569,10 @@ class GpqaDataset(BaseDataset):
"""Get question by index"""
return self.questions[index]
def get_question_str(self, question: Dict) -> str:
"""Get question string"""
return question["Question"]
def get_answer(self, question: Dict) -> str:
# GPQA returns the correct letter (A, B, C, or D)
return question["correct_letter"]
@ -556,7 +580,7 @@ class GpqaDataset(BaseDataset):
def get_prompt(self, question: Dict) -> str:
"""Get formatted prompt for the question"""
return TEMPLATE_REGISTRY["gpqa"].format(
Question=question["Question"],
Question=self.get_question_str(question),
A=question["shuffled_answers"][0],
B=question["shuffled_answers"][1],
C=question["shuffled_answers"][2],
@ -737,7 +761,7 @@ class Processor:
return response.json()
def _process_single_case(self, eval_state: EvalState, i: int, task_id: str) -> TaskState:
prompt, gold = eval_state.get_case(i)
question, prompt, gold = eval_state.get_case(i)
task_state = TaskState(
case_id=task_id,