examples: add threading support and model parameter to llama-eval-new.py
- Add ThreadPoolExecutor for parallel request processing controlled by --threads - Add --model argument to specify model name in request data - Refactor process() to use thread-safe _process_single_case() method - Update progress tracking to work with concurrent execution
This commit is contained in:
parent
37b26cafee
commit
62b04cef54
|
|
@ -6,6 +6,7 @@ import os
|
|||
import re
|
||||
import subprocess
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from dataclasses import dataclass, asdict
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Any
|
||||
|
|
@ -140,12 +141,14 @@ class Processor:
|
|||
n_predict: int = 2048,
|
||||
threads: int = 32,
|
||||
verbose: bool = False,
|
||||
grader: Optional[Grader] = None
|
||||
grader: Optional[Grader] = None,
|
||||
model_name: Optional[str] = None
|
||||
):
|
||||
self.server_url = server_url
|
||||
self.n_predict = n_predict
|
||||
self.threads = threads
|
||||
self.verbose = verbose
|
||||
self.model_name = model_name
|
||||
self.dataset = AimeDataset()
|
||||
self.grader = grader or Grader()
|
||||
self.eval_state = EvalState(
|
||||
|
|
@ -160,7 +163,7 @@ class Processor:
|
|||
url = f"{self.server_url}/v1/chat/completions"
|
||||
headers = {"Content-Type": "application/json"}
|
||||
data = {
|
||||
"model": "llama",
|
||||
"model": self.model_name if self.model_name else "llama",
|
||||
"messages": [{"role": "user", "content": prompt}],
|
||||
"temperature": 0,
|
||||
"max_tokens": self.n_predict
|
||||
|
|
@ -170,6 +173,30 @@ class Processor:
|
|||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
def _process_single_case(self, i: int) -> TaskState:
|
||||
"""Process a single case (thread-safe)"""
|
||||
question = self.dataset.get_question(i)
|
||||
case_id = f"aime_{self.dataset.split}_{question['id']}"
|
||||
prompt = question["problem"]
|
||||
gold = self.dataset.get_answer(question)
|
||||
|
||||
task_state = TaskState(
|
||||
case_id=case_id,
|
||||
prompt=prompt,
|
||||
gold=gold
|
||||
)
|
||||
|
||||
try:
|
||||
response = self._make_request(prompt)
|
||||
pred = response["choices"][0]["message"]["content"]
|
||||
task_state.pred = pred
|
||||
task_state.correct = self.grader.grade(gold, pred)
|
||||
task_state.status = "ok"
|
||||
except Exception as e:
|
||||
task_state.status = f"error: {str(e)}"
|
||||
|
||||
return task_state
|
||||
|
||||
def process(self, n_cases: int = None, seed: int = 42):
|
||||
"""Process cases and update eval state"""
|
||||
if n_cases is None:
|
||||
|
|
@ -185,39 +212,25 @@ class Processor:
|
|||
total = 0
|
||||
correct = 0
|
||||
|
||||
for i in tqdm(range(min(n_cases, len(self.dataset.questions))), desc="Processing"):
|
||||
question = self.dataset.get_question(i)
|
||||
case_id = f"aime_{self.dataset.split}_{question['id']}"
|
||||
prompt = question["problem"]
|
||||
gold = self.dataset.get_answer(question)
|
||||
indices = list(range(min(n_cases, len(self.dataset.questions))))
|
||||
|
||||
task_state = TaskState(
|
||||
case_id=case_id,
|
||||
prompt=prompt,
|
||||
gold=gold
|
||||
)
|
||||
with ThreadPoolExecutor(max_workers=self.threads) as executor:
|
||||
futures = {executor.submit(self._process_single_case, i): i for i in indices}
|
||||
|
||||
try:
|
||||
response = self._make_request(prompt)
|
||||
pred = response["choices"][0]["message"]["content"]
|
||||
task_state.pred = pred
|
||||
task_state.correct = self.grader.grade(gold, pred)
|
||||
task_state.status = "ok"
|
||||
for future in tqdm(as_completed(futures), total=len(futures), desc="Processing"):
|
||||
task_state = future.result()
|
||||
task_states["aime"].append(task_state)
|
||||
total += 1
|
||||
|
||||
if task_state.correct:
|
||||
correct += 1
|
||||
except Exception as e:
|
||||
task_state.status = f"error: {str(e)}"
|
||||
|
||||
task_states["aime"].append(task_state)
|
||||
total += 1
|
||||
|
||||
if self.verbose:
|
||||
print(f"\nCase {i+1}/{total}: {task_state.correct}")
|
||||
print(f" Gold: {gold}")
|
||||
if task_state.pred:
|
||||
print(f" Pred: {task_state.pred}")
|
||||
print(f" Status: {task_state.status}")
|
||||
if self.verbose:
|
||||
print(f"\nCase {total}: {task_state.correct}")
|
||||
print(f" Gold: {task_state.gold}")
|
||||
if task_state.pred:
|
||||
print(f" Pred: {task_state.pred}")
|
||||
print(f" Status: {task_state.status}")
|
||||
|
||||
self.eval_state.task_states["aime"] = {
|
||||
"total": total,
|
||||
|
|
@ -265,6 +278,12 @@ def main():
|
|||
default=32,
|
||||
help="Number of threads for parallel requests (default: 32)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Model name to append as query parameter (e.g., gpt-oss-20b-hf)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--verbose",
|
||||
action="store_true",
|
||||
|
|
@ -310,7 +329,8 @@ def main():
|
|||
n_predict=args.n_predict,
|
||||
threads=args.threads,
|
||||
verbose=args.verbose,
|
||||
grader=grader
|
||||
grader=grader,
|
||||
model_name=args.model
|
||||
)
|
||||
|
||||
eval_state = processor.process(n_cases=args.n_cases)
|
||||
|
|
|
|||
Loading…
Reference in New Issue