add checkpointing
parent b0d50a5681, commit 979299a32f
@@ -1,20 +1,17 @@
 # llama.cpp/example/llama-eval

-The purpose of this example is to to run evaluations metrics against a an openapi api compatible LLM via http (llama-server).
+`llama-eval.py` is a single-script evaluation runner that sends prompt/response pairs to any OpenAI-compatible HTTP server (the default `llama-server`).

-```bash
-./llama-server -m model.gguf --port 8033
-python examples/llama-eval/llama-eval.py --path_server http://localhost:8033 --n_prompts 100 --prompt_source arc
-```
+```bash
+python examples/llama-eval/llama-eval.py --path_server http://localhost:8033 --n_prompt 100 --prompt_source arc
+```

-The supported tasks are:
+## Supported tasks (MVP)

-- **GSM8K** — grade-school math (final-answer only)
-- **AIME** — competition math (final-answer only)
-- **MMLU** — multi-domain knowledge (multiple choice)
-- **HellaSwag** — commonsense reasoning (multiple choice)
-- **ARC** — grade-school science reasoning (multiple choice)
-- **WinoGrande** — commonsense coreference resolution (multiple choice)
+- **GSM8K** — grade-school math
+- **AIME** — competition math (integer answers)
+- **MMLU** — multi-domain multiple choice
+- **HellaSwag** — commonsense reasoning multiple choice
+- **ARC** — grade-school science multiple choice
+- **WinoGrande** — commonsense coreference multiple choice
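This commit makes those runs resumable. A minimal sketch of a resumable invocation, assuming the `--checkpoint-file` flag defined in the argparse hunk at the end of this diff:

```bash
# resume is on by default; every graded case is appended to the checkpoint
python examples/llama-eval/llama-eval.py --path_server http://localhost:8033 \
    --n_prompt 100 --prompt_source arc --checkpoint-file ./llama-eval-checkpoint.jsonl

# re-running the same command after an interruption skips already-recorded cases
```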
@@ -10,9 +10,12 @@ import datasets
 import logging
 import requests
 from tqdm.contrib.concurrent import thread_map
-from typing import Iterator
+from typing import Iterator, Set
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
+from pathlib import Path
+import json
+import threading

 logging.basicConfig(level=logging.INFO, format='%(message)s')
 logger = logging.getLogger("llama-eval")
@@ -47,7 +50,7 @@ def extract_boxed_text(text: str) -> str:
     for group in match:
         if group != "":
             return group.split(",")[-1].strip()
-    logger.warning("Could not extract boxed text. Maybe expand context window")
+    logger.debug("Could not extract boxed text. Maybe expand context window")

     return ""
@@ -130,8 +133,9 @@ class MathTaskSpec(TaskSpec):
         try:
             extracted_answer = extract_boxed_text(response["choices"][0]["text"])
-        except Exception as e:
+        except:
+            result["status"] = "error"
             logger.warning("ERROR: extract_boxed_text")
             return result

         source_answer = case.gold
@@ -155,9 +159,12 @@ class ARC_Task(MCTaskSpec):
     def __init__(self):
         self.name = "arc"
         self.kind = "mc"
+        self.config = "ARC-Challenge"
+        self.split = "test"

     def load(self, limit, seed) -> datasets.Dataset:
-        ds = datasets.load_dataset("allenai/ai2_arc", "ARC-Challenge", split="test")
+        ds = datasets.load_dataset("allenai/ai2_arc", self.config, split=self.split)
+        ds = ds.add_column("_row_id", list(range(len(ds))))
         if limit:
             ds = ds.shuffle(seed=seed)
             ds = ds.select(range(min(limit, len(ds))))
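The `_row_id` column is attached before any shuffling, so a case keeps the same ID regardless of which subset a run samples; that is what lets the checkpoint match cases across runs. A standalone sketch of that property with a toy dataset (hypothetical data, not the real ARC download):

```python
import datasets

# toy stand-in for an ARC split
ds = datasets.Dataset.from_dict({"question": ["q0", "q1", "q2", "q3"]})
ds = ds.add_column("_row_id", list(range(len(ds))))  # IDs fixed before shuffling

sampled = ds.shuffle(seed=42).select(range(2))
for doc in sampled:
    # each sampled row still carries its original position in the split
    print(doc["_row_id"], doc["question"])
```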
@@ -166,7 +173,7 @@ class ARC_Task(MCTaskSpec):
     def iter_cases(self, limit: int, seed: int) -> Iterator[Case]:
         ds = self.load(limit, seed)

-        for i, doc in enumerate(ds):
+        for doc in ds:
             doc = cast(Mapping[str, Any], doc)

             prompt, labels = format_multiple_choice(
@@ -175,7 +182,7 @@ class ARC_Task(MCTaskSpec):
             yield Case(
                 task=self.name,
                 kind=self.kind,
-                case_id=f"ARC-Challenge:{i}",
+                case_id=f"ARC-Challenge_{self.config}_{self.split}_{doc['_row_id']}",
                 prompt=prompt,
                 gold=doc["answerKey"],
                 meta_data={"labels": labels},
@@ -187,11 +194,13 @@ class WinoGrande_Task(MCTaskSpec):
     def __init__(self):
         self.name = "winogrande"
         self.kind = "mc"
+        self.config = "winogrande_debiased"
+        self.split = "validation"

     def load(self, limit, seed) -> datasets.Dataset:
-        ds = datasets.load_dataset(
-            "winogrande", "winogrande_debiased", split="validation"
-        )
+        ds = datasets.load_dataset("winogrande", self.config, split=self.split)
+
+        ds = ds.add_column("_row_id", list(range(len(ds))))
         if limit:
             ds = ds.shuffle(seed=seed)
             ds = ds.select(range(min(limit, len(ds))))
@@ -200,7 +209,7 @@ class WinoGrande_Task(MCTaskSpec):
     def iter_cases(self, limit: int, seed: int) -> Iterator[Case]:
         ds = self.load(limit, seed)

-        for i, doc in enumerate(ds):
+        for doc in ds:
             doc = cast(Mapping[str, Any], doc)

             prompt, labels = format_multiple_choice(
@@ -209,7 +218,7 @@ class WinoGrande_Task(MCTaskSpec):
             yield Case(
                 task=self.name,
                 kind=self.kind,
-                case_id=f"winogrande:{i}",
+                case_id=f"winogrande_{self.config}_{self.split}_{doc['_row_id']}",
                 prompt=prompt,
                 gold=labels[int(doc["answer"]) - 1],  # winogrande answers are 1 based
                 meta_data={"labels": labels},
@@ -221,9 +230,12 @@ class MMLU_Task(MCTaskSpec):
     def __init__(self):
         self.name = "mmlu"
         self.kind = "mc"
+        self.config = "all"
+        self.split = "test"

     def load(self, limit, seed) -> datasets.Dataset:
-        ds = datasets.load_dataset("cais/mmlu", "all", split="test")
+        ds = datasets.load_dataset("cais/mmlu", self.config, split=self.split)
+        ds = ds.add_column("_row_id", list(range(len(ds))))
         if limit:
             ds = ds.shuffle(seed=seed)
             ds = ds.select(range(min(limit, len(ds))))
@@ -232,14 +244,14 @@ class MMLU_Task(MCTaskSpec):
     def iter_cases(self, limit: int, seed: int) -> Iterator[Case]:
         ds = self.load(limit, seed)

-        for i, doc in enumerate(ds):
+        for doc in ds:
             doc = cast(Mapping[str, Any], doc)

             prompt, labels = format_multiple_choice(doc["question"], doc["choices"])
             yield Case(
                 task=self.name,
                 kind=self.kind,
-                case_id=f"mmlu:{doc['subject']}:{i}",
+                case_id=f"mmlu_{self.config}_{self.split}_{doc['subject']}_{doc['_row_id']}",
                 prompt=prompt,
                 gold=labels[int(doc["answer"])],
                 meta_data={"subject": doc["subject"], "labels": labels},
@@ -285,12 +297,12 @@ class Hellaswag_Task(MCTaskSpec):

     def iter_cases(self, limit: int, seed: int) -> Iterator[Case]:
         ds = self.load(limit, seed)
-        for i, doc in enumerate(ds):
+        for doc in ds:
             doc = cast(Mapping[str, Any], doc)
             yield Case(
                 task=self.name,
                 kind=self.kind,
-                case_id=f"hellaswag:{i}",
+                case_id=f"hellaswag_{doc['split']}_{doc['ind']}",
                 prompt=doc["prompt"],
                 gold=doc["gold"],
                 meta_data={},
@@ -302,9 +314,10 @@ class Aime_Task(MathTaskSpec):
     def __init__(self):
         self.name = "aime"
         self.kind = "math"
+        self.split = "train"

     def load(self, limit, seed) -> datasets.Dataset:
-        ds = datasets.load_dataset("AI-MO/aimo-validation-aime", split="train")
+        ds = datasets.load_dataset("AI-MO/aimo-validation-aime", split=self.split)

         if limit:
             ds = ds.shuffle(seed=seed)
@@ -327,10 +340,10 @@ class Aime_Task(MathTaskSpec):
             yield Case(
                 task=self.name,
                 kind=self.kind,
-                case_id=f"aime:{i}",
+                case_id=f"aime_{self.split}_{doc['id']}",
                 prompt=doc["prompt"],
                 gold=doc["answer"],
-                meta_data={},
+                meta_data={"id": doc["id"]},
             )
@@ -339,9 +352,12 @@ class Gsm8k_Task(MathTaskSpec):
     def __init__(self):
         self.name = "gsm8k"
         self.kind = "math"
+        self.config = "main"
+        self.split = "test"

     def load(self, limit, seed) -> datasets.Dataset:
-        ds = datasets.load_dataset("openai/gsm8k", "main", split="test")
+        ds = datasets.load_dataset("openai/gsm8k", self.config, split=self.split)
+        ds = ds.add_column("_row_id", list(range(len(ds))))
         if limit:
             ds = ds.shuffle(seed=seed)
             ds = ds.select(range(min(limit, len(ds))))
@@ -359,12 +375,12 @@ class Gsm8k_Task(MathTaskSpec):
     def iter_cases(self, limit: int, seed: int) -> Iterator[Case]:
         ds = self.load(limit, seed)

-        for i, doc in enumerate(ds):
+        for doc in ds:
             doc = cast(Mapping[str, Any], doc)
             yield Case(
                 task=self.name,
                 kind=self.kind,
-                case_id=f"gsm8k:{i}",
+                case_id=f"gsm8k_{self.config}_{self.split}:{doc['_row_id']}",
                 prompt=doc["prompt"],
                 gold=doc["gold"],
                 meta_data={},
@@ -391,11 +407,21 @@ def build_request(case: Case, n_predict: int) -> dict[str, Any]:
     return json_data


+def write_checkpoint_line(
+    checkpoint_file: Path,
+    row: dict[str, Any],
+    file_lock: threading.Lock,
+):
+    with file_lock:
+        with checkpoint_file.open(mode="a", encoding="utf-8") as f:
+            f.write(json.dumps(row) + "\n")
+
+
 def send_prompt(
     case: Case,
     data: dict,
 ) -> dict[str, Union[str, int]]:
-    ret_err = {
+    result = {
         "task": case.task,
         "case_id": case.case_id,
         "status": "error",
@@ -408,26 +434,29 @@ def send_prompt(
     server_address: str = data["server_address"]
     task = TASK_DICT.get(case.task)
     if task is None:
-        ret_err["error"] = f"unknown_task: {case.task}"
-        return ret_err
+        result["error"] = f"unknown_task: {case.task}"
+        return result
     logger.debug(case.prompt)

     json_data = build_request(case, data["n_predict"])
+    res_json = {}
     try:
         response = session.post(f"{server_address}/v1/completions", json=json_data)
-        if response.ok:
-            res_json = response.json()
-        else:
-            ret_err["error"] = f"http_response: {response.status_code}"
-            logger.warning(ret_err["error"])
-            return ret_err
+        res_json = response.json()
+        result["status"] = "ok"
     except Exception as e:
-        ret_err["error"] = f"http_exception: {e}"
-        logger.warning(ret_err["error"])
-        return ret_err
-    logger.debug(response.text)
-    return TASK_DICT[case.task].grade(case, res_json)
+        result["error"] = f"http_exception: {e}"
+        logger.warning(result["error"])
+
+    if result["status"] == "ok":
+        result = TASK_DICT[case.task].grade(case, res_json)
+
+    write_checkpoint_line(
+        data["checkpoint_file"],
+        result.copy(),
+        data["file_lock"],
+    )
+    return result


 def aggregate_by_task(results: list[dict[str, Any]]) -> dict[str, dict[str, int]]:
     tmp = {
@@ -491,13 +520,52 @@ def print_summary(pertask_results: dict[str, dict[str, int]]):
     )


+def read_checkpoint(
+    checkpoint_file: Path, resume_flag: bool
+) -> tuple[Set[str], Set[str], list[dict[str, Any]]]:
+    done = set()
+    errored = set()
+    results = []
+    if not resume_flag or not checkpoint_file.is_file():
+        return done, errored, results
+
+    with checkpoint_file.open(mode="r", encoding="utf-8") as f:
+        for line in f:
+            line = line.strip()
+            if not line:
+                continue
+            try:
+                row = json.loads(line)
+            except Exception as e:
+                logger.warning(f"WARNING: malformed checkpoint line {line}\n{e}")
+                continue
+
+            case_id = row.get("case_id")
+            if not case_id:
+                continue
+
+            if row["status"] == "error":
+                errored.add(case_id)
+            else:
+                done.add(case_id)
+                results.append(row)
+    errored -= done
+    return done, errored, results
+
+
 def benchmark(
     path_server: str,
     prompt_source: str,
     n_prompts: int,
     n_predict: int,
     rng_seed: int,
+    resume_flag: bool,
+    checkpoint_file: Path,
     log_level: int,
 ):
     logger.setLevel(log_level)
+    done, errored, checkpoint_results = read_checkpoint(checkpoint_file, resume_flag)

     if not path_server.startswith("http://") and not path_server.startswith("https://"):
         logger.error("ERROR: malformed server path")
         return
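The dedup rule at the end (`errored -= done`) means a case that failed once but succeeded on a retry counts as done. A quick sketch of that behavior, assuming `read_checkpoint` from this file is in scope and using a hypothetical checkpoint:

```python
import json
from pathlib import Path

# case 7 errored on the first run, then succeeded on a retry
rows = [
    {"task": "gsm8k", "case_id": "gsm8k_main_test:7", "status": "error",
     "error": "http_exception: timeout"},
    {"task": "gsm8k", "case_id": "gsm8k_main_test:7", "status": "ok"},
    {"task": "gsm8k", "case_id": "gsm8k_main_test:8", "status": "ok"},
]
ckpt = Path("demo-checkpoint.jsonl")
ckpt.write_text("".join(json.dumps(r) + "\n" for r in rows), encoding="utf-8")

done, errored, results = read_checkpoint(ckpt, resume_flag=True)
print(sorted(done))     # ['gsm8k_main_test:7', 'gsm8k_main_test:8']
print(sorted(errored))  # []
print(len(results))     # 2: only non-error rows are replayed into the totals
```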
@@ -524,11 +592,15 @@ def benchmark(
     session = requests.Session()
     session.mount("http://", adapter)
     session.mount("https://", adapter)

+    file_lock = threading.Lock()
     cases: list[Case] = []
     data: list[dict] = []
     for task in task_queue:
         for case in task.iter_cases(n_prompts, rng_seed):
+            if case.case_id in done or case.case_id in errored:
+                logger.debug(f"Skipping case_id {case.case_id} from checkpoint")
+                continue
+
             cases.append(case)
             data.append(
                 {
@@ -536,6 +608,8 @@ def benchmark(
                     "session": session,
                     "server_address": server_address,
                     "n_predict": n_predict,
+                    "file_lock": file_lock,
+                    "checkpoint_file": checkpoint_file,
                 }
             )
     logger.info("Starting the benchmark...\n")
@@ -553,7 +627,7 @@ def benchmark(
     t1 = time()
     logger.info(f"\nllama-eval duration: {t1-t0:.2f} s")

+    results.extend(checkpoint_results)
     pertask_results = aggregate_by_task(results)
     print_summary(pertask_results)
@@ -593,5 +667,37 @@ if __name__ == "__main__":
         default=2048,
         help="Max. number of tokens to predict per prompt",
     )
+    parser.add_argument(
+        "--resume",
+        dest="resume_flag",
+        action="store_true",
+        default=True,
+        help="Enable resuming from last state stored in checkpoint file",
+    )
+    parser.add_argument(
+        "--no-resume",
+        dest="resume_flag",
+        action="store_false",
+        help="Disable resuming from last state stored in checkpoint file",
+    )
+    parser.add_argument(
+        "--checkpoint-file",
+        type=Path,
+        dest="checkpoint_file",
+        default="./llama-eval-checkpoint.jsonl",
+        help="Checkpoint file to read last state from",
+    )
+    parser.set_defaults(log_level=logging.INFO)
+    parser.add_argument(
+        "--quiet", action="store_const", dest="log_level", const=logging.ERROR
+    )
+    parser.add_argument(
+        "--debug",
+        action="store_const",
+        default=True,
+        dest="log_level",
+        const=logging.DEBUG,
+    )

     args = parser.parse_args()
     benchmark(**vars(args))
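Since resuming defaults to on, starting over takes an explicit opt-out. A small sketch of the two modes, using only the flags defined above:

```bash
# default: read the checkpoint and skip case_ids already recorded there
python examples/llama-eval/llama-eval.py --path_server http://localhost:8033 --prompt_source gsm8k

# fresh run: ignore existing checkpoint state (new results are still appended)
python examples/llama-eval/llama-eval.py --path_server http://localhost:8033 --prompt_source gsm8k --no-resume
```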