fix prompts
This commit is contained in:
parent
6c41664b8b
commit
e2e998a2d6
|
|
@ -52,23 +52,29 @@ SAMPLE_ANSWERS = {
|
|||
}
|
||||
|
||||
TEMPLATE_REGISTRY = {
|
||||
"aime": """{question}
|
||||
Please reason step by step, and put your final answer within \\boxed{{}}.
|
||||
"aime": """Solve the following math problem step by step. Put your answer inside \\boxed{{}}.
|
||||
|
||||
{question}
|
||||
|
||||
Remember to put your answer inside \\boxed{{}}.
|
||||
""",
|
||||
"aime2025": """{question}
|
||||
Please reason step by step, and put your final answer within \\boxed{{}}.
|
||||
"aime2025": """Solve the following math problem step by step. Put your answer inside \\boxed{{}}.
|
||||
|
||||
{question}
|
||||
|
||||
Remember to put your answer inside \\boxed{{}}.
|
||||
""",
|
||||
"gsm8k": """{question}
|
||||
Please reason step by step, and put your final numeric answer within \\boxed{{}} without any extra characters.
|
||||
""",
|
||||
"gpqa": """{Question}
|
||||
"gpqa": """Answer the following multiple choice question. The last line of your response should be in the following format: 'Answer: A/B/C/D' (e.g. 'Answer: A').
|
||||
|
||||
(A) {A}
|
||||
(B) {B}
|
||||
(C) {C}
|
||||
(D) {D}
|
||||
{Question}
|
||||
|
||||
Express your final answer as the corresponding option 'A', 'B', 'C', or 'D'.
|
||||
A) {A}
|
||||
B) {B}
|
||||
C) {C}
|
||||
D) {D}
|
||||
""",
|
||||
}
|
||||
|
||||
|
|
@ -78,6 +84,10 @@ class BaseDataset(ABC):
|
|||
def get_question(self, index: int) -> Dict:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_question_str(self, question: Dict) -> str:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_answer(self, question: Dict) -> str:
|
||||
pass
|
||||
|
|
@ -155,13 +165,14 @@ class EvalState:
|
|||
|
||||
self.all_tasks = list(self.tasks)
|
||||
|
||||
def get_case(self, index: int) -> Tuple[str, str]:
|
||||
def get_case(self, index: int) -> Tuple[str, str, str]:
|
||||
if self.dataset is None:
|
||||
raise ValueError("Dataset not loaded.")
|
||||
question = self.dataset.get_question(index)
|
||||
question_str = self.dataset.get_question_str(question)
|
||||
prompt = self.dataset.get_prompt(question)
|
||||
gold = self.dataset.get_answer(question)
|
||||
return prompt, gold
|
||||
return question_str, prompt, gold
|
||||
|
||||
def add_result(
|
||||
self,
|
||||
|
|
@ -218,7 +229,7 @@ class EvalState:
|
|||
tasks_to_save = self.all_tasks if self.all_tasks else self.tasks
|
||||
all_cases = {}
|
||||
for i, task_id in tasks_to_save:
|
||||
prompt, gold = self.get_case(i)
|
||||
question, prompt, gold = self.get_case(i)
|
||||
if task_id in self.task_states.get("cases", {}):
|
||||
all_cases[task_id] = self.task_states["cases"][task_id]
|
||||
else:
|
||||
|
|
@ -303,19 +314,19 @@ class EvalState:
|
|||
print("Tasks:")
|
||||
print(" Task ID Dataset Prompt (first 40 chars) Expected Extracted Status")
|
||||
for i, task_id in tasks_to_show:
|
||||
prompt, gold = self.get_case(i)
|
||||
question, prompt, gold = self.get_case(i)
|
||||
case = cases.get(task_id, {})
|
||||
status = case.get("status", "pending")
|
||||
extracted = case.get("extracted", "N/A") if status == "ok" else "N/A"
|
||||
is_correct = case.get("correct", False) if status == "ok" else False
|
||||
symbol = "✓ " if is_correct else ("✗ " if status == "ok" else "")
|
||||
first_line = prompt.split('\n')[0]
|
||||
truncated_prompt = first_line[:43]
|
||||
first_line = question.split('\n')[0]
|
||||
question_trunc = first_line[:43]
|
||||
if len(first_line) > 43:
|
||||
truncated_prompt += "..."
|
||||
question_trunc += "..."
|
||||
else:
|
||||
truncated_prompt = truncated_prompt.ljust(43) + "..."
|
||||
print(f" {task_id:<20} {self.dataset_type.upper()} {truncated_prompt:<40} {gold:<10} {extracted:<10} {symbol}{status}")
|
||||
question_trunc = question_trunc.ljust(43) + "..."
|
||||
print(f" {task_id:<20} {self.dataset_type.upper()} {question_trunc:<40} {gold:<10} {extracted:<10} {symbol}{status}")
|
||||
print()
|
||||
|
||||
def print_existing_summary(self):
|
||||
|
|
@ -367,6 +378,10 @@ class AimeDataset(BaseDataset):
|
|||
"""Get question by index"""
|
||||
return self.questions[index]
|
||||
|
||||
def get_question_str(self, question: Dict) -> str:
    """Return the bare problem text of a question record.

    Reads the ``"problem"`` entry when the record has that key, and the
    ``"question"`` entry otherwise.
    """
    # Key presence decides the field, not truthiness: an empty "problem"
    # value is still returned as-is.
    field = "problem" if "problem" in question else "question"
    return question[field]
|
||||
|
||||
def get_answer(self, question: Dict) -> str:
|
||||
answer = question["answer"]
|
||||
if isinstance(answer, str):
|
||||
|
|
@ -376,12 +391,9 @@ class AimeDataset(BaseDataset):
|
|||
|
||||
def get_prompt(self, question: Dict) -> str:
|
||||
"""Get formatted prompt for the question"""
|
||||
if question["dataset_type"] == "gpqa":
|
||||
return TEMPLATE_REGISTRY["gpqa"].format(**question)
|
||||
else:
|
||||
return TEMPLATE_REGISTRY[question["dataset_type"]].format(
|
||||
question=question["problem"] if "problem" in question else question["question"]
|
||||
)
|
||||
return TEMPLATE_REGISTRY[question["dataset_type"]].format(
|
||||
question=self.get_question_str(question),
|
||||
)
|
||||
|
||||
class Aime2025Dataset(BaseDataset):
|
||||
def __init__(self):
|
||||
|
|
@ -428,6 +440,10 @@ class Aime2025Dataset(BaseDataset):
|
|||
"""Get question by index"""
|
||||
return self.questions[index]
|
||||
|
||||
def get_question_str(self, question: Dict) -> str:
    """Return the text stored under the record's ``"question"`` key."""
    text = question["question"]
    return text
|
||||
|
||||
def get_answer(self, question: Dict) -> str:
|
||||
answer = question["answer"]
|
||||
if isinstance(answer, str):
|
||||
|
|
@ -438,7 +454,7 @@ class Aime2025Dataset(BaseDataset):
|
|||
def get_prompt(self, question: Dict) -> str:
|
||||
"""Get formatted prompt for the question"""
|
||||
return TEMPLATE_REGISTRY["aime2025"].format(
|
||||
question=question["question"]
|
||||
question=self.get_question_str(question),
|
||||
)
|
||||
|
||||
class Gsm8kDataset(BaseDataset):
|
||||
|
|
@ -481,6 +497,10 @@ class Gsm8kDataset(BaseDataset):
|
|||
"""Get question by index"""
|
||||
return self.questions[index]
|
||||
|
||||
def get_question_str(self, question: Dict) -> str:
    """Return the bare problem text of a question record.

    A record carrying a ``"problem"`` key yields that entry; any other
    record yields its ``"question"`` entry.
    """
    if "problem" in question:
        return question["problem"]
    return question["question"]
|
||||
|
||||
def get_answer(self, question: Dict) -> str:
|
||||
# GSM8K has pre-extracted gold field, AIME uses answer field
|
||||
if "gold" in question:
|
||||
|
|
@ -494,7 +514,7 @@ class Gsm8kDataset(BaseDataset):
|
|||
def get_prompt(self, question: Dict) -> str:
|
||||
"""Get formatted prompt for the question"""
|
||||
return TEMPLATE_REGISTRY[question["dataset_type"]].format(
|
||||
question=question["problem"] if "problem" in question else question["question"]
|
||||
question=self.get_question_str(question),
|
||||
)
|
||||
|
||||
class GpqaDataset(BaseDataset):
|
||||
|
|
@ -549,6 +569,10 @@ class GpqaDataset(BaseDataset):
|
|||
"""Get question by index"""
|
||||
return self.questions[index]
|
||||
|
||||
def get_question_str(self, question: Dict) -> str:
    """Return the text stored under the record's ``"Question"`` key."""
    # The key is capitalised here, unlike the lowercase "question" key
    # read by the other dataset classes.
    text = question["Question"]
    return text
|
||||
|
||||
def get_answer(self, question: Dict) -> str:
|
||||
# GPQA returns the correct letter (A, B, C, or D)
|
||||
return question["correct_letter"]
|
||||
|
|
@ -556,7 +580,7 @@ class GpqaDataset(BaseDataset):
|
|||
def get_prompt(self, question: Dict) -> str:
|
||||
"""Get formatted prompt for the question"""
|
||||
return TEMPLATE_REGISTRY["gpqa"].format(
|
||||
Question=question["Question"],
|
||||
Question=self.get_question_str(question),
|
||||
A=question["shuffled_answers"][0],
|
||||
B=question["shuffled_answers"][1],
|
||||
C=question["shuffled_answers"][2],
|
||||
|
|
@ -737,7 +761,7 @@ class Processor:
|
|||
return response.json()
|
||||
|
||||
def _process_single_case(self, eval_state: EvalState, i: int, task_id: str) -> TaskState:
|
||||
prompt, gold = eval_state.get_case(i)
|
||||
question, prompt, gold = eval_state.get_case(i)
|
||||
|
||||
task_state = TaskState(
|
||||
case_id=task_id,
|
||||
|
|
|
|||
Loading…
Reference in New Issue