datasets : add gsm8k
This commit is contained in:
parent
1db8428f00
commit
e8a807519a
|
|
@ -328,3 +328,68 @@ Questions:
|
|||
- Updated `_grade_llm()` to use instance variables instead of parameters
|
||||
- Simplified Processor initialization to pass judge config to grader
|
||||
- Updated startup info to show judge server and model
|
||||
|
||||
### llama-eval-new.py GSM8K Dataset Support
|
||||
|
||||
**Changes Made:**
|
||||
1. **GSM8K Dataset Integration** - Added support for GSM8K dataset alongside AIME
|
||||
- Created `Gsm8kDataset` class with proper answer extraction logic
|
||||
- GSM8K uses `"question"` field instead of `"problem"` field
|
||||
- GSM8K answer field contains full reasoning with `####` prefix
|
||||
- Extracts numeric answer from answer field during initialization
|
||||
- Uses same regex grader pattern as AIME (`\b(\d+)\b`)
|
||||
|
||||
2. **Dataset Type Configuration** - Added dataset selection support
|
||||
- Added `--dataset` CLI argument with choices `aime` and `gsm8k`
|
||||
- Updated `Processor` class to accept `dataset_type` parameter
|
||||
- Dataset-specific initialization in `Processor.__init__()`
|
||||
- Dataset name displayed in task summary table
|
||||
|
||||
3. **Template Registry** - Added dataset-specific prompt templates
|
||||
- AIME template: includes `\boxed{}` wrapper for final answer
|
||||
- GSM8K template: plain text answer without wrapper
|
||||
- Templates applied based on `question["dataset_type"]` field
|
||||
|
||||
4. **Answer Extraction Logic** - Fixed GSM8K answer extraction
|
||||
- GSM8K has pre-extracted `"gold"` field with numeric answer
|
||||
- `Gsm8kDataset.get_answer()` checks for `"gold"` field first
|
||||
- Falls back to answer field if gold field not present
|
||||
- `AimeDataset.get_answer()` simplified by removing a duplicated implementation
|
||||
|
||||
5. **Task ID Format** - Fixed duplicate prefix in task IDs
|
||||
- Changed from `f"{dataset_type}_{eval_state.id}_{chunk_idx:03d}_{i:03d}"`
|
||||
- To `f"{dataset_type}_{chunk_idx:03d}_{i:03d}"`
|
||||
- Removed redundant `eval_state.id` (was "gsm8k" for GSM8K)
|
||||
|
||||
6. **Column Width Adjustments** - Improved table formatting
|
||||
- Task ID column: 25 characters
|
||||
- Dataset column: 5 characters
|
||||
- Prompt column: 40 characters
|
||||
- Expected column: 10 characters
|
||||
|
||||
**Testing Results:**
|
||||
- ✅ GSM8K dataset loads correctly with 7473 questions
|
||||
- ✅ Numeric answers extracted from full reasoning text
|
||||
- ✅ Task summary table displays correctly with adjusted column widths
|
||||
- ✅ Task IDs show correct format (e.g., `gsm8k_000_3169`)
|
||||
- ✅ Both AIME and GSM8K datasets work with same script
|
||||
- ✅ Answer extraction works for both boxed and plain text formats
|
||||
- ✅ Progress tracking shows extracted answers for both datasets
|
||||
|
||||
**Key Technical Decisions:**
|
||||
- GSM8K uses `"question"` field instead of `"problem"` field
|
||||
- GSM8K answer field contains full reasoning with `####` prefix
|
||||
- Numeric answer extracted during dataset initialization
|
||||
- Same regex grader pattern works for both datasets
|
||||
- Dataset selection via CLI argument for separate runs
|
||||
- Template registry supports different prompt formats per dataset
|
||||
- Task ID format simplified to avoid duplication
|
||||
|
||||
**Refactoring:**
|
||||
- Removed duplicate `get_question()` method from `AimeDataset`
|
||||
- Removed "2025" suffix from eval state ID (was remnant from old version)
|
||||
- Removed "2025" suffix from task summary table output
|
||||
- Removed "2025" suffix from progress tracking output
|
||||
- Updated `Processor.__init__()` to initialize appropriate dataset based on type
|
||||
- Updated `_process_single_case()` to handle both `"problem"` and `"question"` fields
|
||||
- Updated `process()` method to display dataset name and use `dataset_type` for task states
|
||||
|
|
|
|||
|
|
@ -31,6 +31,9 @@ GRADER_PATTERNS = {
|
|||
# Per-dataset prompt templates, keyed by dataset_type and filled via
# str.format(question=...). The AIME template requests the final answer
# inside \boxed{}; the GSM8K template requests a plain-text final answer.
# NOTE: {{}} is an escaped literal {} in the format string.
TEMPLATE_REGISTRY = {
    "aime": """{question}
Please reason step by step, and put your final answer within \\boxed{{}}.
""",
    "gsm8k": """{question}
Please reason step by step, and provide your final answer.
""",
}
|
||||
|
||||
|
|
@ -93,6 +96,56 @@ class AimeDataset:
|
|||
return str(normalized) if normalized is not None else answer
|
||||
return str(answer)
|
||||
|
||||
class Gsm8kDataset:
    """GSM8K grade-school math dataset wrapper.

    Loads the Hugging Face ``openai/gsm8k`` ("main" config) split, tags each
    row with ``dataset_type="gsm8k"``, and pre-extracts the numeric gold
    answer (the text after the final ``####`` marker in the ``answer`` field)
    into a ``"gold"`` key so graders do not have to re-parse the reasoning.
    """

    def __init__(self, split: str = "train"):
        # split: HF split name ("train" or "test")
        self.split = split
        self.questions: List[Dict] = []
        self._load_dataset()

    def _load_dataset(self):
        """Load the configured split and populate self.questions."""
        print(f"Loading GSM8K dataset (split: {self.split})...")
        from datasets import load_dataset

        cache_path = cache_dir / "openai___gsm8k" / "default" / "0.0.0"
        if cache_path.exists():
            print(f"Using cached dataset from {cache_path}")
            # BUGFIX: `datasets` nests "openai___gsm8k/default/<ver>" under
            # whatever cache_dir it is given, so we must pass the cache ROOT.
            # Passing the already-nested cache_path would double the nesting
            # and miss the cache we just detected.
            ds = load_dataset("openai/gsm8k", "main", split=self.split, cache_dir=str(cache_dir))
        else:
            ds = load_dataset("openai/gsm8k", "main", split=self.split)

        self.questions = []
        for row in ds:
            question = dict(row)
            question["dataset_type"] = "gsm8k"

            # GSM8K's "answer" field holds the full reasoning chain; the
            # final numeric answer follows a "####" marker.
            gold = question["answer"]
            parts = gold.split("####")
            if len(parts) > 1:
                gold = parts[-1].strip()
            # Normalize to a canonical numeric string when possible; fall
            # back to the raw text if normalize_number cannot parse it.
            normalized = normalize_number(gold)
            question["gold"] = str(normalized) if normalized is not None else gold

            self.questions.append(question)

        print(f"GSM8K dataset loaded: {len(self.questions)} questions")

    def get_question(self, index: int) -> Dict:
        """Get question by index."""
        return self.questions[index]

    def get_answer(self, question: Dict) -> str:
        """Return the gold answer string for a question dict.

        Prefers the pre-extracted "gold" field (set in _load_dataset);
        falls back to normalizing the raw "answer" field, matching the
        AIME-style extraction path.
        """
        if "gold" in question:
            return question["gold"]
        answer = question["answer"]
        if isinstance(answer, str):
            normalized = normalize_number(answer)
            return str(normalized) if normalized is not None else answer
        return str(answer)
|
||||
|
||||
class Grader:
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -217,7 +270,8 @@ class Processor:
|
|||
grader: Optional[Grader] = None,
|
||||
model_name: Optional[str] = None,
|
||||
judge_server_url: str = "",
|
||||
judge_model_name: Optional[str] = None
|
||||
judge_model_name: Optional[str] = None,
|
||||
dataset_type: str = "aime"
|
||||
):
|
||||
self.server_url = server_url
|
||||
self.n_predict = n_predict
|
||||
|
|
@ -226,11 +280,11 @@ class Processor:
|
|||
self.model_name = model_name
|
||||
self.judge_server_url = judge_server_url if judge_server_url else server_url
|
||||
self.judge_model_name = judge_model_name
|
||||
self.dataset = AimeDataset()
|
||||
self.dataset_type = dataset_type
|
||||
self.grader = grader or Grader()
|
||||
self.eval_state = EvalState(
|
||||
id="aime-2025",
|
||||
tasks=["aime"],
|
||||
id=dataset_type,
|
||||
tasks=[dataset_type],
|
||||
task_states={},
|
||||
sampling_config={"temperature": 0, "max_tokens": n_predict}
|
||||
)
|
||||
|
|
@ -242,6 +296,14 @@ class Processor:
|
|||
if self.judge_server_url:
|
||||
self.grader.judge_server_url = self.judge_server_url
|
||||
|
||||
# Initialize appropriate dataset
|
||||
if dataset_type == "aime":
|
||||
self.dataset = AimeDataset()
|
||||
elif dataset_type == "gsm8k":
|
||||
self.dataset = Gsm8kDataset()
|
||||
else:
|
||||
raise ValueError(f"Unknown dataset type: {dataset_type}")
|
||||
|
||||
def _make_request(self, prompt: str) -> Dict[str, Any]:
|
||||
"""Make HTTP request to the server"""
|
||||
url = f"{self.server_url}/v1/chat/completions"
|
||||
|
|
@ -260,14 +322,14 @@ class Processor:
|
|||
def _process_single_case(self, i: int, task_id: str) -> TaskState:
|
||||
"""Process a single case (thread-safe)"""
|
||||
question = self.dataset.get_question(i)
|
||||
dataset_id = f"aime_{self.dataset.split}_{question['id']}"
|
||||
dataset_id = f"{self.dataset_type}_{self.dataset.split}_{i}"
|
||||
gold = self.dataset.get_answer(question)
|
||||
|
||||
# Apply template if available
|
||||
if question["dataset_type"] in TEMPLATE_REGISTRY:
|
||||
prompt = TEMPLATE_REGISTRY[question["dataset_type"]].format(question=question["problem"])
|
||||
prompt = TEMPLATE_REGISTRY[question["dataset_type"]].format(question=question["problem"] if "problem" in question else question["question"])
|
||||
else:
|
||||
prompt = question["problem"]
|
||||
prompt = question["problem"] if "problem" in question else question["question"]
|
||||
|
||||
task_state = TaskState(
|
||||
case_id=task_id,
|
||||
|
|
@ -298,7 +360,7 @@ class Processor:
|
|||
if n_cases is None:
|
||||
n_cases = len(self.dataset.questions)
|
||||
|
||||
print(f"\nProcessing {n_cases} AIME questions...")
|
||||
print(f"\nProcessing {n_cases} {self.dataset_type.upper()} questions...")
|
||||
print(f"Server: {self.server_url}")
|
||||
print(f"Threads: {self.threads}")
|
||||
print(f"Max tokens: {self.n_predict}")
|
||||
|
|
@ -319,18 +381,18 @@ class Processor:
|
|||
chunk_indices = indices[:chunk_size]
|
||||
|
||||
for i in chunk_indices:
|
||||
task_id = f"aime_{self.eval_state.id}_{chunk_idx:03d}_{i:03d}"
|
||||
task_id = f"{self.dataset_type}_{chunk_idx:03d}_{i:03d}"
|
||||
task_list.append((i, task_id))
|
||||
|
||||
# Print task summary table
|
||||
print("Tasks:")
|
||||
print(" Task ID Dataset Prompt (first 40 chars) Expected Status")
|
||||
print(" Task ID Dataset Prompt (first 40 chars) Expected Status")
|
||||
for i, task_id in task_list:
|
||||
question = self.dataset.get_question(i)
|
||||
prompt = question["problem"]
|
||||
prompt = question["problem"] if "problem" in question else question["question"]
|
||||
gold = self.dataset.get_answer(question)
|
||||
truncated_prompt = prompt[:40] + "..." if len(prompt) > 40 else prompt
|
||||
print(f" {task_id:<15} AIME2025 {truncated_prompt:<40} {gold:<10} pending")
|
||||
print(f" {task_id:<20} {self.dataset_type.upper()} {truncated_prompt:<40} {gold:<10} pending")
|
||||
print()
|
||||
|
||||
task_states: Dict[str, List[TaskState]] = {task: [] for task in self.eval_state.tasks}
|
||||
|
|
@ -342,7 +404,7 @@ class Processor:
|
|||
|
||||
for future in as_completed(futures):
|
||||
task_state = future.result()
|
||||
task_states["aime"].append(task_state)
|
||||
task_states[self.dataset_type].append(task_state)
|
||||
total += 1
|
||||
|
||||
if task_state.correct:
|
||||
|
|
@ -351,7 +413,7 @@ class Processor:
|
|||
# Print task completion status
|
||||
extracted_display = task_state.extracted if task_state.extracted else "N/A"
|
||||
success_ratio = correct / total if total > 0 else 0.0
|
||||
print(f"{total:3}/{n_cases:3} {task_state.case_id:<15} AIME2025 {task_state.prompt[:40]:<40} {task_state.gold:<10} {extracted_display:<10} {'✓' if task_state.correct else '✗'} [{correct:3}/{total:3}, {success_ratio:.3f}]")
|
||||
print(f"{total:3}/{n_cases:3} {task_state.case_id:<20} {self.dataset_type.upper()} {task_state.prompt[:40]:<40} {task_state.gold:<10} {extracted_display:<10} {'✓' if task_state.correct else '✗'} [{correct:3}/{total:3}, {success_ratio:.3f}]")
|
||||
|
||||
if self.verbose:
|
||||
print(f"\nCase {total}: {task_state.correct}")
|
||||
|
|
@ -362,7 +424,7 @@ class Processor:
|
|||
print(f" Extracted: {task_state.extracted}")
|
||||
print(f" Status: {task_state.status}")
|
||||
|
||||
self.eval_state.task_states["aime"] = {
|
||||
self.eval_state.task_states[self.dataset_type] = {
|
||||
"total": total,
|
||||
"correct": correct,
|
||||
"cases": task_states
|
||||
|
|
@ -382,7 +444,7 @@ class Processor:
|
|||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Simplified AIME evaluation tool for llama.cpp"
|
||||
description="Simplified evaluation tool for llama.cpp"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--server",
|
||||
|
|
@ -390,6 +452,13 @@ def main():
|
|||
default="http://localhost:8033",
|
||||
help="llama-server URL (default: http://localhost:8033)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset",
|
||||
type=str,
|
||||
default="aime",
|
||||
choices=["aime", "gsm8k"],
|
||||
help="Dataset type (default: aime)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--n_cases",
|
||||
type=int,
|
||||
|
|
@ -483,7 +552,8 @@ def main():
|
|||
grader=grader,
|
||||
model_name=args.model,
|
||||
judge_server_url=args.judge_server,
|
||||
judge_model_name=args.judge_model
|
||||
judge_model_name=args.judge_model,
|
||||
dataset_type=args.dataset
|
||||
)
|
||||
|
||||
eval_state = processor.process(n_cases=args.n_cases, seed=args.seed)
|
||||
|
|
|
|||
Loading…
Reference in New Issue