add checkpointing

gatbontonpc 2026-01-16 17:58:31 -05:00
parent b0d50a5681
commit 979299a32f
2 changed files with 153 additions and 50 deletions


@ -1,20 +1,17 @@
# llama.cpp/example/llama-eval
The purpose of this example is to run evaluation metrics against an OpenAI-API-compatible LLM over HTTP.
`llama-eval.py` is a single-script evaluation runner that sends prompt/response pairs to any OpenAI-compatible HTTP server (the default `llama-server`).
```bash
./llama-server -m model.gguf --port 8033
python examples/llama-eval/llama-eval.py --path_server http://localhost:8033 --n_prompts 100 --prompt_source arc
```
## Supported tasks (MVP)
- **GSM8K** — grade-school math (final-answer only)
- **AIME** — competition math (integer final answers)
- **MMLU** — multi-domain knowledge (multiple choice)
- **HellaSwag** — commonsense reasoning (multiple choice)
- **ARC** — grade-school science reasoning (multiple choice)
- **WinoGrande** — commonsense coreference resolution (multiple choice)
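Each graded case is appended to a JSONL checkpoint file, and resuming from that file is enabled by default, so re-running the same command after an interruption skips cases that were already graded. A rough sketch of checkpoint/resume usage, based on the `--checkpoint-file`, `--resume`, and `--no-resume` flags added in `llama-eval.py` (the other option spellings follow the example above):
```bash
# re-running with the same checkpoint file resumes an interrupted run
python examples/llama-eval/llama-eval.py --path_server http://localhost:8033 --n_prompts 100 --prompt_source arc \
    --checkpoint-file ./llama-eval-checkpoint.jsonl
# pass --no-resume to ignore the checkpoint and start from scratch
python examples/llama-eval/llama-eval.py --path_server http://localhost:8033 --n_prompts 100 --prompt_source arc --no-resume
```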


@ -10,9 +10,12 @@ import datasets
import logging
import requests
from tqdm.contrib.concurrent import thread_map
from typing import Iterator
from typing import Iterator, Set
from abc import ABC, abstractmethod
from dataclasses import dataclass
from pathlib import Path
import json
import threading
logging.basicConfig(level=logging.INFO, format='%(message)s')
logger = logging.getLogger("llama-eval")
@ -47,7 +50,7 @@ def extract_boxed_text(text: str) -> str:
for group in match:
if group != "":
return group.split(",")[-1].strip()
logger.warning("Could not extract boxed text. Maybe expand context window")
logger.debug("Could not extract boxed text. Maybe expand context window")
return ""
@ -130,8 +133,9 @@ class MathTaskSpec(TaskSpec):
try:
extracted_answer = extract_boxed_text(response["choices"][0]["text"])
except Exception as e:
except Exception:
result["status"] = "error"
logger.warning("ERROR: extract_boxed_text")
return result
source_answer = case.gold
@ -155,9 +159,12 @@ class ARC_Task(MCTaskSpec):
def __init__(self):
self.name = "arc"
self.kind = "mc"
self.config = "ARC-Challenge"
self.split = "test"
def load(self, limit, seed) -> datasets.Dataset:
ds = datasets.load_dataset("allenai/ai2_arc", "ARC-Challenge", split="test")
ds = datasets.load_dataset("allenai/ai2_arc", self.config, split=self.split)
ds = ds.add_column("_row_id", list(range(len(ds))))
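# record each row's original index before shuffling so case_ids stay stable across runs (needed for checkpoint resume)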
if limit:
ds = ds.shuffle(seed=seed)
ds = ds.select(range(min(limit, len(ds))))
@ -166,7 +173,7 @@ class ARC_Task(MCTaskSpec):
def iter_cases(self, limit: int, seed: int) -> Iterator[Case]:
ds = self.load(limit, seed)
for i, doc in enumerate(ds):
for doc in ds:
doc = cast(Mapping[str, Any], doc)
prompt, labels = format_multiple_choice(
@ -175,7 +182,7 @@ class ARC_Task(MCTaskSpec):
yield Case(
task=self.name,
kind=self.kind,
case_id=f"ARC-Challenge:{i}",
case_id=f"ARC-Challenge_{self.config}_{self.split}_{doc['_row_id']}",
prompt=prompt,
gold=doc["answerKey"],
meta_data={"labels": labels},
@ -187,11 +194,13 @@ class WinoGrande_Task(MCTaskSpec):
def __init__(self):
self.name = "winogrande"
self.kind = "mc"
self.config = "winogrande_debiased"
self.split = "validation"
def load(self, limit, seed) -> datasets.Dataset:
ds = datasets.load_dataset(
"winogrande", "winogrande_debiased", split="validation"
)
ds = datasets.load_dataset("winogrande", self.config, split=self.split)
ds = ds.add_column("_row_id", list(range(len(ds))))
if limit:
ds = ds.shuffle(seed=seed)
ds = ds.select(range(min(limit, len(ds))))
@ -200,7 +209,7 @@ class WinoGrande_Task(MCTaskSpec):
def iter_cases(self, limit: int, seed: int) -> Iterator[Case]:
ds = self.load(limit, seed)
for i, doc in enumerate(ds):
for doc in ds:
doc = cast(Mapping[str, Any], doc)
prompt, labels = format_multiple_choice(
@ -209,7 +218,7 @@ class WinoGrande_Task(MCTaskSpec):
yield Case(
task=self.name,
kind=self.kind,
case_id=f"winogrande:{i}",
case_id=f"winogrande_{self.config}_{self.split}_{doc['_row_id']}",
prompt=prompt,
gold=labels[int(doc["answer"]) - 1], # winogrande answers are 1 based
meta_data={"labels": labels},
@ -221,9 +230,12 @@ class MMLU_Task(MCTaskSpec):
def __init__(self):
self.name = "mmlu"
self.kind = "mc"
self.config = "all"
self.split = "test"
def load(self, limit, seed) -> datasets.Dataset:
ds = datasets.load_dataset("cais/mmlu", "all", split="test")
ds = datasets.load_dataset("cais/mmlu", self.config, split=self.split)
ds = ds.add_column("_row_id", list(range(len(ds))))
if limit:
ds = ds.shuffle(seed=seed)
ds = ds.select(range(min(limit, len(ds))))
@ -232,14 +244,14 @@ class MMLU_Task(MCTaskSpec):
def iter_cases(self, limit: int, seed: int) -> Iterator[Case]:
ds = self.load(limit, seed)
for i, doc in enumerate(ds):
for doc in ds:
doc = cast(Mapping[str, Any], doc)
prompt, labels = format_multiple_choice(doc["question"], doc["choices"])
yield Case(
task=self.name,
kind=self.kind,
case_id=f"mmlu:{doc['subject']}:{i}",
case_id=f"mmlu_{self.config}_{self.split}_{doc['subject']}_{doc['_row_id']}",
prompt=prompt,
gold=labels[int(doc["answer"])],
meta_data={"subject": doc["subject"], "labels": labels},
@ -285,12 +297,12 @@ class Hellaswag_Task(MCTaskSpec):
def iter_cases(self, limit: int, seed: int) -> Iterator[Case]:
ds = self.load(limit, seed)
for i, doc in enumerate(ds):
for doc in ds:
doc = cast(Mapping[str, Any], doc)
yield Case(
task=self.name,
kind=self.kind,
case_id=f"hellaswag:{i}",
case_id=f"hellaswag_{doc['split']}_{doc['ind']}",
prompt=doc["prompt"],
gold=doc["gold"],
meta_data={},
@ -302,9 +314,10 @@ class Aime_Task(MathTaskSpec):
def __init__(self):
self.name = "aime"
self.kind = "math"
self.split = "train"
def load(self, limit, seed) -> datasets.Dataset:
ds = datasets.load_dataset("AI-MO/aimo-validation-aime", split="train")
ds = datasets.load_dataset("AI-MO/aimo-validation-aime", split=self.split)
if limit:
ds = ds.shuffle(seed=seed)
@ -327,10 +340,10 @@ class Aime_Task(MathTaskSpec):
yield Case(
task=self.name,
kind=self.kind,
case_id=f"aime:{i}",
case_id=f"aime_{self.split}_{doc['id']}",
prompt=doc["prompt"],
gold=doc["answer"],
meta_data={},
meta_data={"id": doc["id"]},
)
@ -339,9 +352,12 @@ class Gsm8k_Task(MathTaskSpec):
def __init__(self):
self.name = "gsm8k"
self.kind = "math"
self.config = "main"
self.split = "test"
def load(self, limit, seed) -> datasets.Dataset:
ds = datasets.load_dataset("openai/gsm8k", "main", split="test")
ds = datasets.load_dataset("openai/gsm8k", self.config, split=self.split)
ds = ds.add_column("_row_id", list(range(len(ds))))
if limit:
ds = ds.shuffle(seed=seed)
ds = ds.select(range(min(limit, len(ds))))
@ -359,12 +375,12 @@ class Gsm8k_Task(MathTaskSpec):
def iter_cases(self, limit: int, seed: int) -> Iterator[Case]:
ds = self.load(limit, seed)
for i, doc in enumerate(ds):
for doc in ds:
doc = cast(Mapping[str, Any], doc)
yield Case(
task=self.name,
kind=self.kind,
case_id=f"gsm8k:{i}",
case_id=f"gsm8k_{self.config}_{self.split}:{doc['_row_id']}",
prompt=doc["prompt"],
gold=doc["gold"],
meta_data={},
@ -391,11 +407,21 @@ def build_request(case: Case, n_predict: int) -> dict[str, Any]:
return json_data
def write_checkpoint_line(
checkpoint_file: Path,
row: dict[str, Any],
file_lock: threading.Lock,
):
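# append one JSON object per line; the lock keeps results from concurrent worker threads from interleaving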
with file_lock:
with checkpoint_file.open(mode="a", encoding="utf-8") as f:
f.write(json.dumps(row) + "\n")
def send_prompt(
case: Case,
data: dict,
) -> dict[str, Union[str, int]]:
ret_err = {
result = {
"task": case.task,
"case_id": case.case_id,
"status": "error",
@ -408,26 +434,29 @@ def send_prompt(
server_address: str = data["server_address"]
task = TASK_DICT.get(case.task)
if task is None:
ret_err["error"] = f"unknown_task: {case.task}"
return ret_err
result["error"] = f"unknown_task: {case.task}"
return result
logger.debug(case.prompt)
json_data = build_request(case, data["n_predict"])
res_json = {}
try:
response = session.post(f"{server_address}/v1/completions", json=json_data)
if response.ok:
res_json = response.json()
else:
ret_err["error"] = f"http_response: {response.status_code}"
logger.warning(ret_err["error"])
return ret_err
res_json = response.json()
result["status"] = "ok"
except Exception as e:
ret_err["error"] = f"http_exception: {e}"
logger.warning(ret_err["error"])
return ret_err
logger.debug(response.text)
return TASK_DICT[case.task].grade(case, res_json)
result["error"] = f"http_exception: {e}"
logger.warning(result["error"])
if result["status"] == "ok":
result = TASK_DICT[case.task].grade(case, res_json)
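# persist each result (including errors) as soon as it is graded so an interrupted run can be resumed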
write_checkpoint_line(
data["checkpoint_file"],
result.copy(),
data["file_lock"],
)
return result
def aggregate_by_task(results: list[dict[str, Any]]) -> dict[str, dict[str, int]]:
tmp = {
@ -491,13 +520,52 @@ def print_summary(pertask_results: dict[str, dict[str, int]]):
)
def read_checkpoint(
checkpoint_file: Path, resume_flag: bool
) -> tuple[Set[str], Set[str], list[dict[str, Any]]]:
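# returns (completed case_ids, errored case_ids, previously graded rows) parsed from the JSONL checkpoint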
done = set()
errored = set()
results = []
if not resume_flag or not checkpoint_file.is_file():
return done, errored, results
with checkpoint_file.open(mode="r", encoding="utf-8") as f:
for line in f:
line = line.strip()
if not line:
continue
try:
row = json.loads(line)
except Exception as e:
logger.warning(f"WARNING: malformed checkpoint line {line}\n{e}")
continue
case_id = row.get("case_id")
if not case_id:
continue
if row["status"] == "error":
errored.add(case_id)
else:
done.add(case_id)
results.append(row)
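# a case that errored once but succeeded on a later attempt is treated as done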
errored -= done
return done, errored, results
def benchmark(
path_server: str,
prompt_source: str,
n_prompts: int,
n_predict: int,
rng_seed: int,
resume_flag: bool,
checkpoint_file: Path,
log_level: int,
):
logger.setLevel(log_level)
done, errored, checkpoint_results = read_checkpoint(checkpoint_file, resume_flag)
if not path_server.startswith("http://") and not path_server.startswith("https://"):
logger.error("ERROR: malformed server path")
return
@ -524,11 +592,15 @@ def benchmark(
session = requests.Session()
session.mount("http://", adapter)
session.mount("https://", adapter)
file_lock = threading.Lock()
cases: list[Case] = []
data: list[dict] = []
for task in task_queue:
for case in task.iter_cases(n_prompts, rng_seed):
if case.case_id in done or case.case_id in errored:
logger.debug(f"Skipping case_id {case.case_id} from checkpoint")
continue
cases.append(case)
data.append(
{
@ -536,6 +608,8 @@ def benchmark(
"session": session,
"server_address": server_address,
"n_predict": n_predict,
"file_lock": file_lock,
"checkpoint_file": checkpoint_file,
}
)
logger.info("Starting the benchmark...\n")
@ -553,7 +627,7 @@ def benchmark(
t1 = time()
logger.info(f"\nllama-eval duration: {t1-t0:.2f} s")
results.extend(checkpoint_results)
pertask_results = aggregate_by_task(results)
print_summary(pertask_results)
@ -593,5 +667,37 @@ if __name__ == "__main__":
default=2048,
help="Max. number of tokens to predict per prompt",
)
parser.add_argument(
"--resume",
dest="resume_flag",
action="store_true",
default=True,
help="Enable resuming from last state stored in checkpoint file",
)
parser.add_argument(
"--no-resume",
dest="resume_flag",
action="store_false",
help="Disble resuming from last state stored in checkpoint file",
)
parser.add_argument(
"--checkpoint-file",
type=Path,
dest="checkpoint_file",
default="./llama-eval-checkpoint.jsonl",
help="Checkpoint file to read last state from",
)
parser.set_defaults(log_level=logging.INFO)
parser.add_argument(
"--quiet", action="store_const", dest="log_level", const=logging.ERROR
)
parser.add_argument(
"--debug",
action="store_const",
dest="log_level",
const=logging.DEBUG,
)
args = parser.parse_args()
benchmark(**vars(args))