eval : support multiple dataset runs

This commit is contained in:
Georgi Gerganov 2026-02-02 22:34:25 +02:00
parent 8156d549f6
commit fd90796da2
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
1 changed file with 29 additions and 11 deletions

View File

@ -12,6 +12,7 @@ from pathlib import Path
from typing import Dict, List, Optional, Any from typing import Dict, List, Optional, Any
import requests import requests
from tqdm import tqdm from tqdm import tqdm
import random
cache_dir = Path.home() / ".cache" / "huggingface" / "datasets" cache_dir = Path.home() / ".cache" / "huggingface" / "datasets"
cache_dir.mkdir(parents=True, exist_ok=True) cache_dir.mkdir(parents=True, exist_ok=True)
@ -194,10 +195,10 @@ class Processor:
response.raise_for_status() response.raise_for_status()
return response.json() return response.json()
def _process_single_case(self, i: int) -> TaskState: def _process_single_case(self, i: int, task_id: str) -> TaskState:
"""Process a single case (thread-safe)""" """Process a single case (thread-safe)"""
question = self.dataset.get_question(i) question = self.dataset.get_question(i)
case_id = f"aime_{self.dataset.split}_{question['id']}" dataset_id = f"aime_{self.dataset.split}_{question['id']}"
gold = self.dataset.get_answer(question) gold = self.dataset.get_answer(question)
# Apply template if available # Apply template if available
@ -207,7 +208,7 @@ class Processor:
prompt = question["problem"] prompt = question["problem"]
task_state = TaskState( task_state = TaskState(
case_id=case_id, case_id=task_id,
prompt=prompt, prompt=prompt,
gold=gold gold=gold
) )
@ -223,7 +224,7 @@ class Processor:
return task_state return task_state
def process(self, n_cases: int = None, seed: int = 42): def process(self, n_cases: int = None, seed: int = 1234):
"""Process cases and update eval state""" """Process cases and update eval state"""
if n_cases is None: if n_cases is None:
n_cases = len(self.dataset.questions) n_cases = len(self.dataset.questions)
@ -234,26 +235,37 @@ class Processor:
print(f"Max tokens: {self.n_predict}") print(f"Max tokens: {self.n_predict}")
print() print()
dataset_size = len(self.dataset.questions)
random.seed(seed)
task_list = []
for chunk_idx in range((n_cases + dataset_size - 1) // dataset_size):
chunk_size = min(dataset_size, n_cases - chunk_idx * dataset_size)
indices = list(range(dataset_size))
random.shuffle(indices)
chunk_indices = indices[:chunk_size]
for i in chunk_indices:
task_id = f"aime_{self.eval_state.id}_{chunk_idx:03d}_{i:03d}"
task_list.append((i, task_id))
# Print task summary table # Print task summary table
print("Tasks:") print("Tasks:")
print(" Task ID Dataset Prompt (first 40 chars) Expected Status") print(" Task ID Dataset Prompt (first 40 chars) Expected Status")
for i in range(min(n_cases, len(self.dataset.questions))): for i, task_id in task_list:
question = self.dataset.get_question(i) question = self.dataset.get_question(i)
case_id = f"aime_{self.dataset.split}_{question['id']}"
prompt = question["problem"] prompt = question["problem"]
gold = self.dataset.get_answer(question) gold = self.dataset.get_answer(question)
truncated_prompt = prompt[:40] + "..." if len(prompt) > 40 else prompt truncated_prompt = prompt[:40] + "..." if len(prompt) > 40 else prompt
print(f" {case_id:<15} AIME2025 {truncated_prompt:<40} {gold:<10} pending") print(f" {task_id:<15} AIME2025 {truncated_prompt:<40} {gold:<10} pending")
print() print()
task_states: Dict[str, List[TaskState]] = {task: [] for task in self.eval_state.tasks} task_states: Dict[str, List[TaskState]] = {task: [] for task in self.eval_state.tasks}
total = 0 total = 0
correct = 0 correct = 0
indices = list(range(min(n_cases, len(self.dataset.questions))))
with ThreadPoolExecutor(max_workers=self.threads) as executor: with ThreadPoolExecutor(max_workers=self.threads) as executor:
futures = {executor.submit(self._process_single_case, i): i for i in indices} futures = {executor.submit(self._process_single_case, i, task_id): (i, task_id) for i, task_id in task_list}
for future in as_completed(futures): for future in as_completed(futures):
task_state = future.result() task_state = future.result()
@ -309,6 +321,12 @@ def main():
default=None, default=None,
help="Number of cases to evaluate (default: all)" help="Number of cases to evaluate (default: all)"
) )
parser.add_argument(
"--seed",
type=int,
default=1234,
help="Random seed for shuffling (default: 1234)"
)
parser.add_argument( parser.add_argument(
"--n_predict", "--n_predict",
type=int, type=int,
@ -376,7 +394,7 @@ def main():
model_name=args.model model_name=args.model
) )
eval_state = processor.process(n_cases=args.n_cases) eval_state = processor.process(n_cases=args.n_cases, seed=args.seed)
processor.dump_state(args.output) processor.dump_state(args.output)
if __name__ == "__main__": if __name__ == "__main__":