From e8a807519a8b57368f04ac542596cfd6c52520b6 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 15 Feb 2026 23:19:46 +0200 Subject: [PATCH] datasets : add gsm8k --- examples/llama-eval/llama-eval-discussion.md | 65 ++++++++++++ examples/llama-eval/llama-eval-new.py | 104 ++++++++++++++++--- 2 files changed, 152 insertions(+), 17 deletions(-) diff --git a/examples/llama-eval/llama-eval-discussion.md b/examples/llama-eval/llama-eval-discussion.md index 57bcda138f..1747aa0655 100644 --- a/examples/llama-eval/llama-eval-discussion.md +++ b/examples/llama-eval/llama-eval-discussion.md @@ -328,3 +328,68 @@ Questions: - Updated `_grade_llm()` to use instance variables instead of parameters - Simplified Processor initialization to pass judge config to grader - Updated startup info to show judge server and model + +### llama-eval-new.py GSM8K Dataset Support + +**Changes Made:** +1. **GSM8K Dataset Integration** - Added support for GSM8K dataset alongside AIME + - Created `Gsm8kDataset` class with proper answer extraction logic + - GSM8K uses `"question"` field instead of `"problem"` field + - GSM8K answer field contains full reasoning with `####` prefix + - Extracts numeric answer from answer field during initialization + - Uses same regex grader pattern as AIME (`\b(\d+)\b`) + +2. **Dataset Type Configuration** - Added dataset selection support + - Added `--dataset` CLI argument with choices `aime` and `gsm8k` + - Updated `Processor` class to accept `dataset_type` parameter + - Dataset-specific initialization in `Processor.__init__()` + - Dataset name displayed in task summary table + +3. **Template Registry** - Added dataset-specific prompt templates + - AIME template: includes `\boxed{}` wrapper for final answer + - GSM8K template: plain text answer without wrapper + - Templates applied based on `question["dataset_type"]` field + +4. **Answer Extraction Logic** - Fixed GSM8K answer extraction + - GSM8K has pre-extracted `"gold"` field with numeric answer + - `Gsm8kDataset.get_answer()` checks for `"gold"` field first + - Falls back to answer field if gold field not present + - `AimeDataset.get_answer()` simplified to remove duplicate method + +5. **Task ID Format** - Fixed duplicate prefix in task IDs + - Changed from `f"{dataset_type}_{eval_state.id}_{chunk_idx:03d}_{i:03d}"` + - To `f"{dataset_type}_{chunk_idx:03d}_{i:03d}"` + - Removed redundant `eval_state.id` (was "gsm8k" for GSM8K) + +6. **Column Width Adjustments** - Improved table formatting + - Task ID column: 25 characters + - Dataset column: 5 characters + - Prompt column: 40 characters + - Expected column: 10 characters + +**Testing Results:** +- ✅ GSM8K dataset loads correctly with 7473 questions +- ✅ Numeric answers extracted from full reasoning text +- ✅ Task summary table displays correctly with adjusted column widths +- ✅ Task IDs show correct format (e.g., `gsm8k_000_3169`) +- ✅ Both AIME and GSM8K datasets work with same script +- ✅ Answer extraction works for both boxed and plain text formats +- ✅ Progress tracking shows extracted answers for both datasets + +**Key Technical Decisions:** +- GSM8K uses `"question"` field instead of `"problem"` field +- GSM8K answer field contains full reasoning with `####` prefix +- Numeric answer extracted during dataset initialization +- Same regex grader pattern works for both datasets +- Dataset selection via CLI argument for separate runs +- Template registry supports different prompt formats per dataset +- Task ID format simplified to avoid duplication + +**Refactoring:** +- Removed duplicate `get_question()` method from `AimeDataset` +- Removed "2025" suffix from eval state ID (was remnant from old version) +- Removed "2025" suffix from task summary table output +- Removed "2025" suffix from progress tracking output +- Updated `Processor.__init__()` to initialize appropriate dataset based on type +- Updated `_process_single_case()` to handle both `"problem"` and `"question"` fields +- Updated `process()` method to display dataset name and use `dataset_type` for task states diff --git a/examples/llama-eval/llama-eval-new.py b/examples/llama-eval/llama-eval-new.py index ff62777653..8426dae724 100755 --- a/examples/llama-eval/llama-eval-new.py +++ b/examples/llama-eval/llama-eval-new.py @@ -31,6 +31,9 @@ GRADER_PATTERNS = { TEMPLATE_REGISTRY = { "aime": """{question} Please reason step by step, and put your final answer within \\boxed{{}}. +""", + "gsm8k": """{question} +Please reason step by step, and provide your final answer. """, } @@ -93,6 +96,56 @@ class AimeDataset: return str(normalized) if normalized is not None else answer return str(answer) +class Gsm8kDataset: + def __init__(self, split: str = "train"): + self.split = split + self.questions: List[Dict] = [] + self._load_dataset() + + def _load_dataset(self): + print(f"Loading GSM8K dataset (split: {self.split})...") + from datasets import load_dataset + + cache_path = cache_dir / "openai___gsm8k" / "default" / "0.0.0" + if cache_path.exists(): + print(f"Using cached dataset from {cache_path}") + ds = load_dataset("openai/gsm8k", "main", split=self.split, cache_dir=str(cache_path)) + else: + ds = load_dataset("openai/gsm8k", "main", split=self.split) + + self.questions = [] + for row in ds: + question = dict(row) + question["dataset_type"] = "gsm8k" + + # Extract numeric answer from the answer field (already has #### prefix) + gold = question["answer"] + # Split by #### and take the last part + parts = gold.split("####") + if len(parts) > 1: + gold = parts[-1].strip() + # Extract the first number from the remaining text + normalized = normalize_number(gold) + question["gold"] = str(normalized) if normalized is not None else gold + + self.questions.append(question) + + print(f"GSM8K dataset loaded: {len(self.questions)} questions") + + def get_question(self, index: int) -> Dict: + """Get question by index""" + return self.questions[index] + + def get_answer(self, question: Dict) -> str: + # GSM8K has pre-extracted gold field, AIME uses answer field + if "gold" in question: + return question["gold"] + answer = question["answer"] + if isinstance(answer, str): + normalized = normalize_number(answer) + return str(normalized) if normalized is not None else answer + return str(answer) + class Grader: def __init__( self, @@ -217,7 +270,8 @@ class Processor: grader: Optional[Grader] = None, model_name: Optional[str] = None, judge_server_url: str = "", - judge_model_name: Optional[str] = None + judge_model_name: Optional[str] = None, + dataset_type: str = "aime" ): self.server_url = server_url self.n_predict = n_predict @@ -226,11 +280,11 @@ class Processor: self.model_name = model_name self.judge_server_url = judge_server_url if judge_server_url else server_url self.judge_model_name = judge_model_name - self.dataset = AimeDataset() + self.dataset_type = dataset_type self.grader = grader or Grader() self.eval_state = EvalState( - id="aime-2025", - tasks=["aime"], + id=dataset_type, + tasks=[dataset_type], task_states={}, sampling_config={"temperature": 0, "max_tokens": n_predict} ) @@ -242,6 +296,14 @@ class Processor: if self.judge_server_url: self.grader.judge_server_url = self.judge_server_url + # Initialize appropriate dataset + if dataset_type == "aime": + self.dataset = AimeDataset() + elif dataset_type == "gsm8k": + self.dataset = Gsm8kDataset() + else: + raise ValueError(f"Unknown dataset type: {dataset_type}") + def _make_request(self, prompt: str) -> Dict[str, Any]: """Make HTTP request to the server""" url = f"{self.server_url}/v1/chat/completions" @@ -260,14 +322,14 @@ class Processor: def _process_single_case(self, i: int, task_id: str) -> TaskState: """Process a single case (thread-safe)""" question = self.dataset.get_question(i) - dataset_id = f"aime_{self.dataset.split}_{question['id']}" + dataset_id = f"{self.dataset_type}_{self.dataset.split}_{i}" gold = self.dataset.get_answer(question) # Apply template if available if question["dataset_type"] in TEMPLATE_REGISTRY: - prompt = TEMPLATE_REGISTRY[question["dataset_type"]].format(question=question["problem"]) + prompt = TEMPLATE_REGISTRY[question["dataset_type"]].format(question=question["problem"] if "problem" in question else question["question"]) else: - prompt = question["problem"] + prompt = question["problem"] if "problem" in question else question["question"] task_state = TaskState( case_id=task_id, @@ -298,7 +360,7 @@ class Processor: if n_cases is None: n_cases = len(self.dataset.questions) - print(f"\nProcessing {n_cases} AIME questions...") + print(f"\nProcessing {n_cases} {self.dataset_type.upper()} questions...") print(f"Server: {self.server_url}") print(f"Threads: {self.threads}") print(f"Max tokens: {self.n_predict}") @@ -319,18 +381,18 @@ class Processor: chunk_indices = indices[:chunk_size] for i in chunk_indices: - task_id = f"aime_{self.eval_state.id}_{chunk_idx:03d}_{i:03d}" + task_id = f"{self.dataset_type}_{chunk_idx:03d}_{i:03d}" task_list.append((i, task_id)) # Print task summary table print("Tasks:") - print(" Task ID Dataset Prompt (first 40 chars) Expected Status") + print(" Task ID Dataset Prompt (first 40 chars) Expected Status") for i, task_id in task_list: question = self.dataset.get_question(i) - prompt = question["problem"] + prompt = question["problem"] if "problem" in question else question["question"] gold = self.dataset.get_answer(question) truncated_prompt = prompt[:40] + "..." if len(prompt) > 40 else prompt - print(f" {task_id:<15} AIME2025 {truncated_prompt:<40} {gold:<10} pending") + print(f" {task_id:<20} {self.dataset_type.upper()} {truncated_prompt:<40} {gold:<10} pending") print() task_states: Dict[str, List[TaskState]] = {task: [] for task in self.eval_state.tasks} @@ -342,7 +404,7 @@ class Processor: for future in as_completed(futures): task_state = future.result() - task_states["aime"].append(task_state) + task_states[self.dataset_type].append(task_state) total += 1 if task_state.correct: @@ -351,7 +413,7 @@ class Processor: # Print task completion status extracted_display = task_state.extracted if task_state.extracted else "N/A" success_ratio = correct / total if total > 0 else 0.0 - print(f"{total:3}/{n_cases:3} {task_state.case_id:<15} AIME2025 {task_state.prompt[:40]:<40} {task_state.gold:<10} {extracted_display:<10} {'✓' if task_state.correct else '✗'} [{correct:3}/{total:3}, {success_ratio:.3f}]") + print(f"{total:3}/{n_cases:3} {task_state.case_id:<20} {self.dataset_type.upper()} {task_state.prompt[:40]:<40} {task_state.gold:<10} {extracted_display:<10} {'✓' if task_state.correct else '✗'} [{correct:3}/{total:3}, {success_ratio:.3f}]") if self.verbose: print(f"\nCase {total}: {task_state.correct}") @@ -362,7 +424,7 @@ class Processor: print(f" Extracted: {task_state.extracted}") print(f" Status: {task_state.status}") - self.eval_state.task_states["aime"] = { + self.eval_state.task_states[self.dataset_type] = { "total": total, "correct": correct, "cases": task_states @@ -382,7 +444,7 @@ class Processor: def main(): parser = argparse.ArgumentParser( - description="Simplified AIME evaluation tool for llama.cpp" + description="Simplified evaluation tool for llama.cpp" ) parser.add_argument( "--server", @@ -390,6 +452,13 @@ def main(): default="http://localhost:8033", help="llama-server URL (default: http://localhost:8033)" ) + parser.add_argument( + "--dataset", + type=str, + default="aime", + choices=["aime", "gsm8k"], + help="Dataset type (default: aime)" + ) parser.add_argument( "--n_cases", type=int, @@ -483,7 +552,8 @@ def main(): grader=grader, model_name=args.model, judge_server_url=args.judge_server, - judge_model_name=args.judge_model + judge_model_name=args.judge_model, + dataset_type=args.dataset ) eval_state = processor.process(n_cases=args.n_cases, seed=args.seed)