#!/usr/bin/env python3
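"""Evaluate a llama-server instance on standard benchmark tasks.

Prompts for the selected tasks (see TASK_DICT: mmlu, aime, gsm8k, hellaswag, arc,
winogrande) are sent to the server's /v1/completions endpoint, the answer is
extracted from the final \\boxed{...} span of each completion, and a per-task
accuracy summary is printed at the end. Finished cases are appended to a JSONL
checkpoint file so interrupted runs can be resumed (--resume is the default).

Example invocation (illustrative):
    python llama-eval.py --path_server http://localhost:8033 --prompt_source mmlu,gsm8k --n_prompts 100
"""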
import re
import argparse
import os
from time import time
from typing import Union, Any, Mapping, cast
import datasets
import logging
import requests
from tqdm.contrib.concurrent import thread_map
from typing import Iterator, Set
from abc import ABC, abstractmethod
from dataclasses import dataclass
from pathlib import Path
import json
import threading
logging.basicConfig(level=logging.INFO, format='%(message)s')
logger = logging.getLogger("llama-eval")
MATH_TEMPLATE = """
{question}
Do not include any explanation. Put your final answer within \\boxed{{}}.
"""
def format_multiple_choice(prompt: str, choices: list[str]):
lines = [prompt]
labels = [chr(ord("A") + i) for i in range(len(choices))]
for l, c in zip(labels, choices):
lines.append(f"({l}): {c.strip()}")
lines.append(
"Do not include any explanation. Answer with the corresponding option letter only"
)
lines.append(", ".join(labels))
lines.append("Put your final answer within \\boxed{{}}.")
return "\n".join(lines), labels
def extract_boxed_text(text: str) -> str:
pattern = r"boxed{(.*?)}|framebox{(.*?)}"
matches = re.findall(pattern, text, re.DOTALL)
logger.debug(matches)
if matches:
for match in matches[::-1]:
for group in match:
if group != "":
return group.split(",")[-1].strip()
logger.debug("Could not extract boxed text. Maybe expand context window")
return ""
@dataclass(frozen=True)
class Case:
task: str
kind: str
case_id: str
prompt: str
gold: str
meta_data: dict[str, Any]
class TaskSpec(ABC):
name: str
kind: str
@abstractmethod
def load(self, limit, seed) -> datasets.Dataset:
pass
@abstractmethod
def iter_cases(self, limit: int, seed: int) -> Iterator[Case]:
pass
@staticmethod
@abstractmethod
def grade(case: Case, response: dict) -> dict[str, Any]:
pass
class MCTaskSpec(TaskSpec):
@staticmethod
def grade(case: Case, response: dict) -> dict[str, Any]:
logger.debug(f"response {response}")
result = {
"task": case.task,
"case_id": case.case_id,
"correct": 0,
"pred": None,
"gold": case.gold,
"status": "ok",
}
try:
extracted_answer = extract_boxed_text(response["choices"][0]["text"])
except Exception as e:
result["status"] = "error"
logger.warning("ERROR: extract_boxed_text")
return result
if not extracted_answer:
result["status"] = "invalid"
logger.warning("INVALID: extract_boxed_text")
return result
logger.debug(f"extracted_answer {extracted_answer}")
logger.debug(f"data['answer'] {case.gold}")
result["pred"] = extracted_answer
result["correct"] = 1 if extracted_answer == case.gold else 0
return result
class MathTaskSpec(TaskSpec):
@staticmethod
def grade(case: Case, response: dict) -> dict[str, Any]:
logger.debug(f"response {response}")
result = {
"task": case.task,
"case_id": case.case_id,
"correct": 0,
"gold": case.gold,
"status": "ok",
"pred": None,
}
try:
extracted_answer = extract_boxed_text(response["choices"][0]["text"])
        except Exception:
result["status"] = "error"
logger.warning("ERROR: extract_boxed_text")
return result
source_answer = case.gold
        try:  # The math tasks used here (AIME, GSM8K) have integer gold answers, so compare as ints
extracted_answer = int(extracted_answer)
source_answer = int(case.gold)
except (ValueError, TypeError):
result["status"] = "invalid"
return result
logger.debug(f"extracted_answer {extracted_answer}")
logger.debug(f"data['answer'] {case.gold}")
result["pred"] = extracted_answer
result["correct"] = 1 if extracted_answer == source_answer else 0
return result
class ARC_Task(MCTaskSpec):
def __init__(self):
self.name = "arc"
self.kind = "mc"
self.config = "ARC-Challenge"
self.split = "test"
def load(self, limit, seed) -> datasets.Dataset:
ds = datasets.load_dataset("allenai/ai2_arc", self.config, split=self.split)
ds = ds.add_column("_row_id", list(range(len(ds))))
if limit:
ds = ds.shuffle(seed=seed)
ds = ds.select(range(min(limit, len(ds))))
return ds
def iter_cases(self, limit: int, seed: int) -> Iterator[Case]:
ds = self.load(limit, seed)
for doc in ds:
doc = cast(Mapping[str, Any], doc)
prompt, labels = format_multiple_choice(
doc["question"], doc["choices"]["text"]
)
yield Case(
task=self.name,
kind=self.kind,
case_id=f"ARC-Challenge_{self.config}_{self.split}_{doc['_row_id']}",
prompt=prompt,
gold=doc["answerKey"],
meta_data={"labels": labels},
)
class WinoGrande_Task(MCTaskSpec):
def __init__(self):
self.name = "winogrande"
self.kind = "mc"
self.config = "winogrande_debiased"
self.split = "validation"
def load(self, limit, seed) -> datasets.Dataset:
ds = datasets.load_dataset("winogrande", self.config, split=self.split)
ds = ds.add_column("_row_id", list(range(len(ds))))
if limit:
ds = ds.shuffle(seed=seed)
ds = ds.select(range(min(limit, len(ds))))
return ds
def iter_cases(self, limit: int, seed: int) -> Iterator[Case]:
ds = self.load(limit, seed)
for doc in ds:
doc = cast(Mapping[str, Any], doc)
prompt, labels = format_multiple_choice(
doc["sentence"], [doc["option1"], doc["option2"]]
)
yield Case(
task=self.name,
kind=self.kind,
case_id=f"winogrande_{self.config}_{self.split}_{doc['_row_id']}",
prompt=prompt,
gold=labels[int(doc["answer"]) - 1], # winogrande answers are 1 based
meta_data={"labels": labels},
)
class MMLU_Task(MCTaskSpec):
def __init__(self):
self.name = "mmlu"
self.kind = "mc"
self.config = "all"
self.split = "test"
def load(self, limit, seed) -> datasets.Dataset:
ds = datasets.load_dataset("cais/mmlu", self.config, split=self.split)
ds = ds.add_column("_row_id", list(range(len(ds))))
if limit:
ds = ds.shuffle(seed=seed)
ds = ds.select(range(min(limit, len(ds))))
return ds
def iter_cases(self, limit: int, seed: int) -> Iterator[Case]:
ds = self.load(limit, seed)
for doc in ds:
doc = cast(Mapping[str, Any], doc)
prompt, labels = format_multiple_choice(doc["question"], doc["choices"])
yield Case(
task=self.name,
kind=self.kind,
case_id=f"mmlu_{self.config}_{self.split}_{doc['subject']}_{doc['_row_id']}",
prompt=prompt,
gold=labels[int(doc["answer"])],
meta_data={"subject": doc["subject"], "labels": labels},
)
class Hellaswag_Task(MCTaskSpec):
# Preprocess hellaswag
@staticmethod
def preprocess(text: str):
text = text.strip()
# NOTE: Brackets are artifacts of the WikiHow dataset portion of HellaSwag.
text = text.replace(" [title]", ". ")
text = re.sub("\\[.*?\\]", "", text)
text = text.replace(" ", " ")
return text
@staticmethod
def hellaswag_process_doc(doc: dict[str, str]):
ctx = doc["ctx_a"] + " " + doc["ctx_b"].capitalize()
question = Hellaswag_Task.preprocess(doc["activity_label"] + ": " + ctx)
proc_answers = [Hellaswag_Task.preprocess(answer) for answer in doc["endings"]]
prompt, labels = format_multiple_choice(question, proc_answers)
out_doc = {
"prompt": prompt,
"gold": labels[int(doc["label"])],
}
return out_doc
def __init__(self):
self.name = "hellaswag"
self.kind = "mc"
def load(self, limit, seed) -> datasets.Dataset:
ds = datasets.load_dataset("Rowan/hellaswag", split="validation")
if limit:
ds = ds.shuffle(seed=seed)
ds = ds.select(range(min(limit, len(ds))))
ds = ds.map(Hellaswag_Task.hellaswag_process_doc)
return ds
def iter_cases(self, limit: int, seed: int) -> Iterator[Case]:
ds = self.load(limit, seed)
for doc in ds:
doc = cast(Mapping[str, Any], doc)
yield Case(
task=self.name,
kind=self.kind,
case_id=f"hellaswag_{doc['split']}_{doc['ind']}",
prompt=doc["prompt"],
gold=doc["gold"],
meta_data={},
)
class Aime_Task(MathTaskSpec):
def __init__(self):
self.name = "aime"
self.kind = "math"
self.split = "train"
def load(self, limit, seed) -> datasets.Dataset:
ds = datasets.load_dataset("AI-MO/aimo-validation-aime", split=self.split)
if limit:
ds = ds.shuffle(seed=seed)
ds = ds.select(range(min(limit, len(ds))))
ds = ds.map(
lambda ex: {
"prompt": MATH_TEMPLATE.format(
question=ex["problem"],
)
}
)
return ds
def iter_cases(self, limit: int, seed: int) -> Iterator[Case]:
ds = self.load(limit, seed)
        for doc in ds:
doc = cast(Mapping[str, Any], doc)
yield Case(
task=self.name,
kind=self.kind,
case_id=f"aime_{self.split}_{doc['id']}",
prompt=doc["prompt"],
gold=doc["answer"],
meta_data={"id": doc["id"]},
)
class Gsm8k_Task(MathTaskSpec):
def __init__(self):
self.name = "gsm8k"
self.kind = "math"
self.config = "main"
self.split = "test"
def load(self, limit, seed) -> datasets.Dataset:
ds = datasets.load_dataset("openai/gsm8k", self.config, split=self.split)
ds = ds.add_column("_row_id", list(range(len(ds))))
if limit:
ds = ds.shuffle(seed=seed)
ds = ds.select(range(min(limit, len(ds))))
ds = ds.map(
lambda k: {
"prompt": MATH_TEMPLATE.format(
question=k["question"],
),
"gold": k["answer"].split("### ")[-1].rstrip(),
}
)
return ds
def iter_cases(self, limit: int, seed: int) -> Iterator[Case]:
ds = self.load(limit, seed)
for doc in ds:
doc = cast(Mapping[str, Any], doc)
yield Case(
task=self.name,
kind=self.kind,
case_id=f"gsm8k_{self.config}_{self.split}:{doc['_row_id']}",
prompt=doc["prompt"],
gold=doc["gold"],
meta_data={},
)
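# Task registry: --prompt_source accepts a comma-separated list of these keys, or "all".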
TASK_DICT: dict[str, type[TaskSpec]] = {
"mmlu": MMLU_Task,
"aime": Aime_Task,
"gsm8k": Gsm8k_Task,
"hellaswag": Hellaswag_Task,
"arc": ARC_Task,
"winogrande": WinoGrande_Task,
}
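# build_request() produces the completion request body. Illustrative payload sent to
# POST {path_server}/v1/completions (with the default n_predict of 2048):
#   {"n_predict": 2048, "max_tokens": 2048, "temperature": 0, "prompt": "<case prompt>"}
# Both n_predict (llama.cpp native) and max_tokens (OpenAI-compatible) are set to the
# same value, presumably so either field is honoured by the server.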
def build_request(case: Case, n_predict: int) -> dict[str, Any]:
json_data = {
"n_predict": n_predict,
"max_tokens": n_predict,
"temperature": 0,
"prompt": case.prompt,
}
return json_data
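# Checkpoint format: one JSON object per line (JSONL). An illustrative line
# (field values are hypothetical):
#   {"task": "mmlu", "case_id": "mmlu_all_test_anatomy_17", "correct": 1, "pred": "B", "gold": "B", "status": "ok"}
# read_checkpoint() below uses case_id and status to decide which cases to skip on resume.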
def write_checkpoint_line(
checkpoint_file: Path,
row: dict[str, Any],
file_lock: threading.Lock,
):
with file_lock:
with checkpoint_file.open(mode="a", encoding="utf-8") as f:
f.write(json.dumps(row) + "\n")
def send_prompt(
case: Case,
data: dict,
) -> dict[str, Union[str, int]]:
result = {
"task": case.task,
"case_id": case.case_id,
"status": "error",
"correct": 0,
"gold": case.gold,
"pred": "",
"error": "",
}
session: requests.Session = data["session"]
server_address: str = data["server_address"]
task = TASK_DICT.get(case.task)
if task is None:
result["error"] = f"unknown_task: {case.task}"
return result
logger.debug(case.prompt)
json_data = build_request(case, data["n_predict"])
res_json = {}
try:
response = session.post(f"{server_address}/v1/completions", json=json_data)
res_json = response.json()
result["status"] = "ok"
except Exception as e:
result["error"] = f"http_exception: {e}"
logger.warning(result["error"])
if result["status"] == "ok":
        result = task.grade(case, res_json)
write_checkpoint_line(
data["checkpoint_file"],
result.copy(),
data["file_lock"],
)
return result
def aggregate_by_task(results: list[dict[str, Any]]) -> dict[str, dict[str, int]]:
tmp = {
"total": 0,
"error": 0,
"invalid": 0,
"correct": 0,
}
agg: dict[str, dict[str, int]] = {}
for row in results:
d = agg.get(row["task"], tmp.copy())
d["total"] += 1
status = row["status"]
if status == "ok":
d["correct"] += row["correct"]
elif status == "invalid":
d["invalid"] += 1
elif status == "error":
d["error"] += 1
agg[row["task"]] = d
return agg
def print_summary(pertask_results: dict[str, dict[str, int]]):
print("\n=== llama-eval suite summary ===")
print(
f"{'Task':<15} {'Acc':>8} {'Correct':>8} {'Total':>8} {'Invalid':>8} {'Error':>8}"
)
print("-" * 65)
suite_total = 0
suite_correct = 0
for task in sorted(pertask_results.keys()):
stats = pertask_results[task]
total = stats["total"]
correct = stats["correct"]
invalid = stats["invalid"]
error = stats["error"]
acc = (correct / total) if total > 0 else 0.0
print(
f"{task:<15} "
f"{acc:8.3f} "
f"{correct:8d} "
f"{total:8d} "
f"{invalid:8d} "
f"{error:8d}"
)
suite_total += total
suite_correct += correct
# Overall summary
print("-" * 65)
suite_acc = (suite_correct / suite_total) if suite_total > 0 else 0.0
print(
f"{'ALL':<15} " f"{suite_acc:8.3f} " f"{suite_correct:8d} " f"{suite_total:8d}"
)
def read_checkpoint(
checkpoint_file: Path, resume_flag: bool
) -> tuple[Set[str], Set[str], list[dict[str, Any]]]:
done = set()
errored = set()
results = []
if not resume_flag or not checkpoint_file.is_file():
return done, errored, results
with checkpoint_file.open(mode="r", encoding="utf-8") as f:
for line in f:
line = line.strip()
if not line:
continue
try:
row = json.loads(line)
except Exception as e:
logger.warning(f"WARNING: malformed checkpoint line {line}\n{e}")
continue
case_id = row.get("case_id")
if not case_id:
continue
if row["status"] == "error":
errored.add(case_id)
else:
done.add(case_id)
results.append(row)
errored -= done
return done, errored, results
def benchmark(
path_server: str,
prompt_source: str,
n_prompts: int,
n_predict: int,
rng_seed: int,
resume_flag: bool,
checkpoint_file: Path,
log_level: int,
):
logger.setLevel(log_level)
done, errored, checkpoint_results = read_checkpoint(checkpoint_file, resume_flag)
if not path_server.startswith("http://") and not path_server.startswith("https://"):
logger.error("ERROR: malformed server path")
return
if os.environ.get("LLAMA_ARG_N_PARALLEL") is None:
logger.info("LLAMA_ARG_N_PARALLEL not explicitly set, using 32")
os.environ["LLAMA_ARG_N_PARALLEL"] = "32"
parallel: int = int(os.environ.get("LLAMA_ARG_N_PARALLEL")) # type: ignore
task_queue: set[TaskSpec] = set()
    for src in prompt_source.split(","):
        if src == "all":
            for v in TASK_DICT.values():
                task_queue.add(v())
            break
        if src not in TASK_DICT:
            logger.error(f"ERROR: unknown prompt source '{src}', expected one of: all, {', '.join(TASK_DICT)}")
            return
        task_queue.add(TASK_DICT[src]())
session = None
try:
server_address: str = path_server
adapter = requests.adapters.HTTPAdapter(pool_connections=parallel, pool_maxsize=parallel) # type: ignore
session = requests.Session()
session.mount("http://", adapter)
session.mount("https://", adapter)
file_lock = threading.Lock()
cases: list[Case] = []
data: list[dict] = []
for task in task_queue:
for case in task.iter_cases(n_prompts, rng_seed):
if case.case_id in done or case.case_id in errored:
logger.debug(f"Skipping case_id {case.case_id} from checkpoint")
continue
cases.append(case)
data.append(
{
"prompt_source": prompt_source,
"session": session,
"server_address": server_address,
"n_predict": n_predict,
"file_lock": file_lock,
"checkpoint_file": checkpoint_file,
}
)
logger.info("Starting the benchmark...\n")
t0 = time()
results: list[dict[str, Union[str, int]]] = thread_map(
send_prompt,
cases,
data,
max_workers=parallel,
chunksize=1,
)
finally:
if session is not None:
session.close()
t1 = time()
logger.info(f"\nllama-eval duration: {t1-t0:.2f} s")
results.extend(checkpoint_results)
pertask_results = aggregate_by_task(results)
print_summary(pertask_results)
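# Example invocation (illustrative):
#   LLAMA_ARG_N_PARALLEL=16 python llama-eval.py --prompt_source arc,winogrande --n_prompts 50 --no-resume
# LLAMA_ARG_N_PARALLEL sets the number of worker threads used by this script (32 if unset).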
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Tool for benchmarking the throughput of the llama.cpp HTTP server. "
"Results are printed to console and visualized as plots (saved to current working directory). "
"To pass arguments such as the model path to the server, set the corresponding environment variables (see llama-server --help). "
"The reported numbers are the speeds as observed by the Python script and may differ from the performance reported by the server, "
"particularly when the server is fast vs. the network or Python script (e.g. when serving a very small model)."
)
parser.add_argument(
"--path_server",
type=str,
default="http://localhost:8033",
help="llama-server url",
)
parser.add_argument(
"--prompt_source",
type=str,
default="mmlu",
help=f"Eval types supported: all,{list(TASK_DICT.keys())}",
)
parser.add_argument(
"--n_prompts", type=int, default=None, help="Number of prompts to evaluate"
)
parser.add_argument(
"--rng_seed",
type=int,
default=42,
help="Number to see rng (Used to select prompts from datasource)",
)
parser.add_argument(
"--n_predict",
type=int,
default=2048,
help="Max. number of tokens to predict per prompt",
)
parser.add_argument(
"--resume",
dest="resume_flag",
action="store_true",
default=True,
help="Enable resuming from last state stored in checkpoint file",
)
parser.add_argument(
"--no-resume",
dest="resume_flag",
action="store_false",
help="Disble resuming from last state stored in checkpoint file",
)
parser.add_argument(
"--checkpoint-file",
type=Path,
dest="checkpoint_file",
default="./llama-eval-checkpoint.jsonl",
help="Checkpoint file to read last state from",
)
    parser.add_argument(
        "--quiet",
        action="store_const",
        dest="log_level",
        const=logging.ERROR,
        help="Only log errors",
    )
    parser.add_argument(
        "--debug",
        action="store_const",
        dest="log_level",
        const=logging.DEBUG,
        help="Enable debug logging",
    )
    parser.set_defaults(log_level=logging.INFO)
args = parser.parse_args()
benchmark(**vars(args))