From 979299a32f63e8804760ce50c47639305a567117 Mon Sep 17 00:00:00 2001
From: gatbontonpc
Date: Fri, 16 Jan 2026 17:58:31 -0500
Subject: [PATCH] add checkpointing

---
 examples/llama-eval/README.md     |  30 ++++++----
 examples/llama-eval/llama-eval.py | 186 +++++++++++++++++++++++-------
 2 files changed, 167 insertions(+), 49 deletions(-)

diff --git a/examples/llama-eval/README.md b/examples/llama-eval/README.md
index 4dfaf09a22..46224be3ec 100644
--- a/examples/llama-eval/README.md
+++ b/examples/llama-eval/README.md
@@ -1,20 +1,26 @@
 # llama.cpp/example/llama-eval
 
-The purpose of this example is to to run evaluations metrics against a an openapi api compatible LLM via http (llama-server).
+`llama-eval.py` is a single-script evaluation runner that sends prompt/response pairs to any OpenAI-compatible HTTP server (by default, `llama-server`).
 
 ```bash
 ./llama-server -m model.gguf --port 8033
+python examples/llama-eval/llama-eval.py --path_server http://localhost:8033 --n_prompts 100 --prompt_source arc
 ```
 
-```bash
-python examples/llama-eval/llama-eval.py --path_server http://localhost:8033 --n_prompt 100 --prompt_source arc
-```
+The supported tasks are:
 
-## Supported tasks (MVP)
-
-- **GSM8K** — grade-school math (final-answer only)
-- **AIME** — competition math (final-answer only)
-- **MMLU** — multi-domain knowledge (multiple choice)
-- **HellaSwag** — commonsense reasoning (multiple choice)
-- **ARC** — grade-school science reasoning (multiple choice)
-- **WinoGrande** — commonsense coreference resolution (multiple choice)
\ No newline at end of file
+- **GSM8K** — grade-school math
+- **AIME** — competition math (integer answers)
+- **MMLU** — multi-domain multiple choice
+- **HellaSwag** — commonsense reasoning multiple choice
+- **ARC** — grade-school science multiple choice
+- **WinoGrande** — commonsense coreference multiple choice
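+
+Results are checkpointed as they complete: each finished prompt is appended as
+one JSON line to `./llama-eval-checkpoint.jsonl` (override the location with
+`--checkpoint-file`). Resuming is on by default, so re-running the same command
+skips prompts that were already attempted; pass `--no-resume` to start fresh:
+
+```bash
+python examples/llama-eval/llama-eval.py --path_server http://localhost:8033 --n_prompts 100 --prompt_source arc --no-resume
+```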
Maybe expand context window") return "" @@ -130,8 +133,9 @@ class MathTaskSpec(TaskSpec): try: extracted_answer = extract_boxed_text(response["choices"][0]["text"]) - except Exception as e: + except: result["status"] = "error" + logger.warning("ERROR: extract_boxed_text") return result source_answer = case.gold @@ -155,9 +159,12 @@ class ARC_Task(MCTaskSpec): def __init__(self): self.name = "arc" self.kind = "mc" + self.config = "ARC-Challenge" + self.split = "test" def load(self, limit, seed) -> datasets.Dataset: - ds = datasets.load_dataset("allenai/ai2_arc", "ARC-Challenge", split="test") + ds = datasets.load_dataset("allenai/ai2_arc", self.config, split=self.split) + ds = ds.add_column("_row_id", list(range(len(ds)))) if limit: ds = ds.shuffle(seed=seed) ds = ds.select(range(min(limit, len(ds)))) @@ -166,7 +173,7 @@ class ARC_Task(MCTaskSpec): def iter_cases(self, limit: int, seed: int) -> Iterator[Case]: ds = self.load(limit, seed) - for i, doc in enumerate(ds): + for doc in ds: doc = cast(Mapping[str, Any], doc) prompt, labels = format_multiple_choice( @@ -175,7 +182,7 @@ class ARC_Task(MCTaskSpec): yield Case( task=self.name, kind=self.kind, - case_id=f"ARC-Challenge:{i}", + case_id=f"ARC-Challenge_{self.config}_{self.split}_{doc['_row_id']}", prompt=prompt, gold=doc["answerKey"], meta_data={"labels": labels}, @@ -187,11 +194,13 @@ class WinoGrande_Task(MCTaskSpec): def __init__(self): self.name = "winogrande" self.kind = "mc" + self.config = "winogrande_debiased" + self.split = "validation" def load(self, limit, seed) -> datasets.Dataset: - ds = datasets.load_dataset( - "winogrande", "winogrande_debiased", split="validation" - ) + ds = datasets.load_dataset("winogrande", self.config, split=self.split) + + ds = ds.add_column("_row_id", list(range(len(ds)))) if limit: ds = ds.shuffle(seed=seed) ds = ds.select(range(min(limit, len(ds)))) @@ -200,7 +209,7 @@ class WinoGrande_Task(MCTaskSpec): def iter_cases(self, limit: int, seed: int) -> Iterator[Case]: ds = self.load(limit, seed) - for i, doc in enumerate(ds): + for doc in ds: doc = cast(Mapping[str, Any], doc) prompt, labels = format_multiple_choice( @@ -209,7 +218,7 @@ class WinoGrande_Task(MCTaskSpec): yield Case( task=self.name, kind=self.kind, - case_id=f"winogrande:{i}", + case_id=f"winogrande_{self.config}_{self.split}_{doc['_row_id']}", prompt=prompt, gold=labels[int(doc["answer"]) - 1], # winogrande answers are 1 based meta_data={"labels": labels}, @@ -221,9 +230,12 @@ class MMLU_Task(MCTaskSpec): def __init__(self): self.name = "mmlu" self.kind = "mc" + self.config = "all" + self.split = "test" def load(self, limit, seed) -> datasets.Dataset: - ds = datasets.load_dataset("cais/mmlu", "all", split="test") + ds = datasets.load_dataset("cais/mmlu", self.config, split=self.split) + ds = ds.add_column("_row_id", list(range(len(ds)))) if limit: ds = ds.shuffle(seed=seed) ds = ds.select(range(min(limit, len(ds)))) @@ -232,14 +244,14 @@ class MMLU_Task(MCTaskSpec): def iter_cases(self, limit: int, seed: int) -> Iterator[Case]: ds = self.load(limit, seed) - for i, doc in enumerate(ds): + for doc in ds: doc = cast(Mapping[str, Any], doc) prompt, labels = format_multiple_choice(doc["question"], doc["choices"]) yield Case( task=self.name, kind=self.kind, - case_id=f"mmlu:{doc['subject']}:{i}", + case_id=f"mmlu_{self.config}_{self.split}_{doc['subject']}_{doc['_row_id']}", prompt=prompt, gold=labels[int(doc["answer"])], meta_data={"subject": doc["subject"], "labels": labels}, @@ -285,12 +297,12 @@ class Hellaswag_Task(MCTaskSpec): def 
+def write_checkpoint_line(
+    checkpoint_file: Path,
+    row: dict[str, Any],
+    file_lock: threading.Lock,
+):
+    with file_lock:
+        with checkpoint_file.open(mode="a", encoding="utf-8") as f:
+            f.write(json.dumps(row) + "\n")
+
+
 def send_prompt(
     case: Case,
     data: dict,
 ) -> dict[str, Union[str, int]]:
-    ret_err = {
+    result = {
         "task": case.task,
         "case_id": case.case_id,
         "status": "error",
@@ -408,26 +437,30 @@ def send_prompt(
     server_address: str = data["server_address"]
     task = TASK_DICT.get(case.task)
     if task is None:
-        ret_err["error"] = f"unknown_task: {case.task}"
-        return ret_err
+        result["error"] = f"unknown_task: {case.task}"
+        return result
 
     logger.debug(case.prompt)
     json_data = build_request(case, data["n_predict"])
+    res_json = {}
     try:
         response = session.post(f"{server_address}/v1/completions", json=json_data)
-        if response.ok:
-            res_json = response.json()
-        else:
-            ret_err["error"] = f"http_response: {response.status_code}"
-            logger.warning(ret_err["error"])
-            return ret_err
+        response.raise_for_status()
+        res_json = response.json()
+        result["status"] = "ok"
     except Exception as e:
-        ret_err["error"] = f"http_exception: {e}"
-        logger.warning(ret_err["error"])
-        return ret_err
-    logger.debug(response.text)
-    return TASK_DICT[case.task].grade(case, res_json)
+        result["error"] = f"http_exception: {e}"
+        logger.warning(result["error"])
+    if result["status"] == "ok":
+        result = TASK_DICT[case.task].grade(case, res_json)
+
+    write_checkpoint_line(
+        data["checkpoint_file"],
+        result.copy(),
+        data["file_lock"],
+    )
+    return result
 
 
 def aggregate_by_task(results: list[dict[str, Any]]) -> dict[str, dict[str, int]]:
     tmp = {
@@ -491,13 +524,55 @@ def print_summary(pertask_results: dict[str, dict[str, int]]):
     )
 
 
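+# A checkpoint row needs at least "task", "case_id" and "status" to be usable
+# here; grading can attach extra task-specific fields, and rows are passed
+# through to the final summary unchanged.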
+def read_checkpoint(
+    checkpoint_file: Path, resume_flag: bool
+) -> tuple[Set[str], Set[str], list[dict[str, Any]]]:
+    done = set()
+    errored = set()
+    results = []
+    if not resume_flag or not checkpoint_file.is_file():
+        return done, errored, results
+
+    with checkpoint_file.open(mode="r", encoding="utf-8") as f:
+        for line in f:
+            line = line.strip()
+            if not line:
+                continue
+            try:
+                row = json.loads(line)
+            except Exception as e:
+                logger.warning(f"WARNING: malformed checkpoint line {line}\n{e}")
+                continue
+
+            case_id = row.get("case_id")
+            if not case_id:
+                continue
+
+            if row.get("status") == "error":
+                errored.add(case_id)
+            else:
+                done.add(case_id)
+            results.append(row)
+    errored -= done
+    return done, errored, results
+
+
 def benchmark(
     path_server: str,
     prompt_source: str,
     n_prompts: int,
     n_predict: int,
     rng_seed: int,
+    resume_flag: bool,
+    checkpoint_file: Path,
+    log_level: int,
 ):
+    logger.setLevel(log_level)
+    done, errored, checkpoint_results = read_checkpoint(checkpoint_file, resume_flag)
+
     if not path_server.startswith("http://") and not path_server.startswith("https://"):
         logger.error("ERROR: malformed server path")
         return
@@ -524,11 +599,15 @@ def benchmark(
     session = requests.Session()
     session.mount("http://", adapter)
     session.mount("https://", adapter)
-
+    file_lock = threading.Lock()
     cases: list[Case] = []
     data: list[dict] = []
     for task in task_queue:
         for case in task.iter_cases(n_prompts, rng_seed):
+            if case.case_id in done or case.case_id in errored:
+                logger.debug(f"Skipping case_id {case.case_id} from checkpoint")
+                continue
+
             cases.append(case)
             data.append(
                 {
@@ -536,6 +615,8 @@
                     "session": session,
                     "server_address": server_address,
                     "n_predict": n_predict,
+                    "file_lock": file_lock,
+                    "checkpoint_file": checkpoint_file,
                 }
             )
     logger.info("Starting the benchmark...\n")
@@ -553,7 +634,7 @@
 
     t1 = time()
     logger.info(f"\nllama-eval duration: {t1-t0:.2f} s")
-
+    results.extend(checkpoint_results)
     pertask_results = aggregate_by_task(results)
     print_summary(pertask_results)
 
@@ -593,5 +674,36 @@ if __name__ == "__main__":
         default=2048,
         help="Max. number of tokens to predict per prompt",
     )
+    parser.add_argument(
+        "--resume",
+        dest="resume_flag",
+        action="store_true",
+        default=True,
+        help="Enable resuming from the last state stored in the checkpoint file (default)",
+    )
+    parser.add_argument(
+        "--no-resume",
+        dest="resume_flag",
+        action="store_false",
+        help="Disable resuming from the last state stored in the checkpoint file",
+    )
+    parser.add_argument(
+        "--checkpoint-file",
+        type=Path,
+        dest="checkpoint_file",
+        default="./llama-eval-checkpoint.jsonl",
+        help="Checkpoint file (JSONL) to read previous state from and append results to",
+    )
+    parser.set_defaults(log_level=logging.INFO)
+    parser.add_argument(
+        "--quiet", action="store_const", dest="log_level", const=logging.ERROR
+    )
+    parser.add_argument(
+        "--debug",
+        action="store_const",
+        dest="log_level",
+        const=logging.DEBUG,
+    )
+
     args = parser.parse_args()
     benchmark(**vars(args))