diff --git a/examples/llama-eval/README.md b/examples/llama-eval/README.md new file mode 100644 index 0000000000..46224be3ec --- /dev/null +++ b/examples/llama-eval/README.md @@ -0,0 +1,17 @@ +# llama.cpp/example/llama-eval + +`llama-eval.py` is a single-script evaluation runner that sends prompt/response pairs to any OpenAI-compatible HTTP server (the default `llama-server`). + +```bash +./llama-server -m model.gguf --port 8033 +python examples/llama-eval/llama-eval.py --path_server http://localhost:8033 --n_prompts 100 --prompt_source arc +``` + +The supported tasks are: + +- **GSM8K** — grade-school math +- **AIME** — competition math (integer answers) +- **MMLU** — multi-domain multiple choice +- **HellaSwag** — commonsense reasoning multiple choice +- **ARC** — grade-school science multiple choice +- **WinoGrande** — commonsense coreference multiple choice diff --git a/examples/llama-eval/llama-eval.py b/examples/llama-eval/llama-eval.py new file mode 100644 index 0000000000..78bfc0c2e4 --- /dev/null +++ b/examples/llama-eval/llama-eval.py @@ -0,0 +1,703 @@ +#!/usr/bin/env python3 + +import re +import argparse +import os +from time import time +from typing import Union, Any, Mapping, cast + +import datasets +import logging +import requests +from tqdm.contrib.concurrent import thread_map +from typing import Iterator, Set +from abc import ABC, abstractmethod +from dataclasses import dataclass +from pathlib import Path +import json +import threading + +logging.basicConfig(level=logging.INFO, format='%(message)s') +logger = logging.getLogger("llama-eval") + +MATH_TEMPLATE = """ +{question} +Do not include any explanation. Put your final answer within \\boxed{{}}. +""" + + +def format_multiple_choice(prompt: str, choices: list[str]): + lines = [prompt] + + labels = [chr(ord("A") + i) for i in range(len(choices))] + for l, c in zip(labels, choices): + lines.append(f"({l}): {c.strip()}") + lines.append( + "Do not include any explanation. Answer with the corresponding option letter only" + ) + lines.append(", ".join(labels)) + lines.append("Put your final answer within \\boxed{{}}.") + + return "\n".join(lines), labels + + +def extract_boxed_text(text: str) -> str: + pattern = r"boxed{(.*?)}|framebox{(.*?)}" + matches = re.findall(pattern, text, re.DOTALL) + logger.debug(matches) + if matches: + for match in matches[::-1]: + for group in match: + if group != "": + return group.split(",")[-1].strip() + logger.debug("Could not extract boxed text. 
Maybe expand context window") + + return "" + + +@dataclass(frozen=True) +class Case: + task: str + kind: str + case_id: str + prompt: str + gold: str + meta_data: dict[str, Any] + + +class TaskSpec(ABC): + name: str + kind: str + + @abstractmethod + def load(self, limit, seed) -> datasets.Dataset: + pass + + @abstractmethod + def iter_cases(self, limit: int, seed: int) -> Iterator[Case]: + pass + + @staticmethod + @abstractmethod + def grade(case: Case, response: dict) -> dict[str, Any]: + pass + + +class MCTaskSpec(TaskSpec): + @staticmethod + def grade(case: Case, response: dict) -> dict[str, Any]: + logger.debug(f"response {response}") + result = { + "task": case.task, + "case_id": case.case_id, + "correct": 0, + "pred": None, + "gold": case.gold, + "status": "ok", + } + + try: + extracted_answer = extract_boxed_text(response["choices"][0]["text"]) + except Exception as e: + result["status"] = "error" + logger.warning("ERROR: extract_boxed_text") + + return result + + if not extracted_answer: + result["status"] = "invalid" + logger.warning("INVALID: extract_boxed_text") + return result + + logger.debug(f"extracted_answer {extracted_answer}") + logger.debug(f"data['answer'] {case.gold}") + result["pred"] = extracted_answer + result["correct"] = 1 if extracted_answer == case.gold else 0 + + return result + + +class MathTaskSpec(TaskSpec): + + @staticmethod + def grade(case: Case, response: dict) -> dict[str, Any]: + logger.debug(f"response {response}") + result = { + "task": case.task, + "case_id": case.case_id, + "correct": 0, + "gold": case.gold, + "status": "ok", + "pred": None, + } + + try: + extracted_answer = extract_boxed_text(response["choices"][0]["text"]) + except: + result["status"] = "error" + logger.warning("ERROR: extract_boxed_text") + return result + + source_answer = case.gold + try: # All AIME answers are integers, so we convert the extracted answer to an integer + extracted_answer = int(extracted_answer) + source_answer = int(case.gold) + except (ValueError, TypeError): + result["status"] = "invalid" + return result + + logger.debug(f"extracted_answer {extracted_answer}") + logger.debug(f"data['answer'] {case.gold}") + result["pred"] = extracted_answer + result["correct"] = 1 if extracted_answer == source_answer else 0 + + return result + + +class ARC_Task(MCTaskSpec): + + def __init__(self): + self.name = "arc" + self.kind = "mc" + self.config = "ARC-Challenge" + self.split = "test" + + def load(self, limit, seed) -> datasets.Dataset: + ds = datasets.load_dataset("allenai/ai2_arc", self.config, split=self.split) + ds = ds.add_column("_row_id", list(range(len(ds)))) + if limit: + ds = ds.shuffle(seed=seed) + ds = ds.select(range(min(limit, len(ds)))) + return ds + + def iter_cases(self, limit: int, seed: int) -> Iterator[Case]: + ds = self.load(limit, seed) + + for doc in ds: + doc = cast(Mapping[str, Any], doc) + + prompt, labels = format_multiple_choice( + doc["question"], doc["choices"]["text"] + ) + yield Case( + task=self.name, + kind=self.kind, + case_id=f"ARC-Challenge_{self.config}_{self.split}_{doc['_row_id']}", + prompt=prompt, + gold=doc["answerKey"], + meta_data={"labels": labels}, + ) + + +class WinoGrande_Task(MCTaskSpec): + + def __init__(self): + self.name = "winogrande" + self.kind = "mc" + self.config = "winogrande_debiased" + self.split = "validation" + + def load(self, limit, seed) -> datasets.Dataset: + ds = datasets.load_dataset("winogrande", self.config, split=self.split) + + ds = ds.add_column("_row_id", list(range(len(ds)))) + if limit: + ds = 
ds.shuffle(seed=seed) + ds = ds.select(range(min(limit, len(ds)))) + return ds + + def iter_cases(self, limit: int, seed: int) -> Iterator[Case]: + ds = self.load(limit, seed) + + for doc in ds: + doc = cast(Mapping[str, Any], doc) + + prompt, labels = format_multiple_choice( + doc["sentence"], [doc["option1"], doc["option2"]] + ) + yield Case( + task=self.name, + kind=self.kind, + case_id=f"winogrande_{self.config}_{self.split}_{doc['_row_id']}", + prompt=prompt, + gold=labels[int(doc["answer"]) - 1], # winogrande answers are 1 based + meta_data={"labels": labels}, + ) + + +class MMLU_Task(MCTaskSpec): + + def __init__(self): + self.name = "mmlu" + self.kind = "mc" + self.config = "all" + self.split = "test" + + def load(self, limit, seed) -> datasets.Dataset: + ds = datasets.load_dataset("cais/mmlu", self.config, split=self.split) + ds = ds.add_column("_row_id", list(range(len(ds)))) + if limit: + ds = ds.shuffle(seed=seed) + ds = ds.select(range(min(limit, len(ds)))) + return ds + + def iter_cases(self, limit: int, seed: int) -> Iterator[Case]: + ds = self.load(limit, seed) + + for doc in ds: + doc = cast(Mapping[str, Any], doc) + + prompt, labels = format_multiple_choice(doc["question"], doc["choices"]) + yield Case( + task=self.name, + kind=self.kind, + case_id=f"mmlu_{self.config}_{self.split}_{doc['subject']}_{doc['_row_id']}", + prompt=prompt, + gold=labels[int(doc["answer"])], + meta_data={"subject": doc["subject"], "labels": labels}, + ) + + +class Hellaswag_Task(MCTaskSpec): + + # Preprocess hellaswag + @staticmethod + def preprocess(text: str): + text = text.strip() + # NOTE: Brackets are artifacts of the WikiHow dataset portion of HellaSwag. + text = text.replace(" [title]", ". ") + text = re.sub("\\[.*?\\]", "", text) + text = text.replace(" ", " ") + return text + + @staticmethod + def hellaswag_process_doc(doc: dict[str, str]): + ctx = doc["ctx_a"] + " " + doc["ctx_b"].capitalize() + question = Hellaswag_Task.preprocess(doc["activity_label"] + ": " + ctx) + proc_answers = [Hellaswag_Task.preprocess(answer) for answer in doc["endings"]] + prompt, labels = format_multiple_choice(question, proc_answers) + out_doc = { + "prompt": prompt, + "gold": labels[int(doc["label"])], + } + return out_doc + + def __init__(self): + self.name = "hellaswag" + self.kind = "mc" + + def load(self, limit, seed) -> datasets.Dataset: + ds = datasets.load_dataset("Rowan/hellaswag", split="validation") + if limit: + ds = ds.shuffle(seed=seed) + ds = ds.select(range(min(limit, len(ds)))) + ds = ds.map(Hellaswag_Task.hellaswag_process_doc) + + return ds + + def iter_cases(self, limit: int, seed: int) -> Iterator[Case]: + ds = self.load(limit, seed) + for doc in ds: + doc = cast(Mapping[str, Any], doc) + yield Case( + task=self.name, + kind=self.kind, + case_id=f"hellaswag_{doc['split']}_{doc['ind']}", + prompt=doc["prompt"], + gold=doc["gold"], + meta_data={}, + ) + + +class Aime_Task(MathTaskSpec): + + def __init__(self): + self.name = "aime" + self.kind = "math" + self.split = "train" + + def load(self, limit, seed) -> datasets.Dataset: + ds = datasets.load_dataset("AI-MO/aimo-validation-aime", split=self.split) + + if limit: + ds = ds.shuffle(seed=seed) + ds = ds.select(range(min(limit, len(ds)))) + + ds = ds.map( + lambda ex: { + "prompt": MATH_TEMPLATE.format( + question=ex["problem"], + ) + } + ) + return ds + + def iter_cases(self, limit: int, seed: int) -> Iterator[Case]: + ds = self.load(limit, seed) + + for i, doc in enumerate(ds): + doc = cast(Mapping[str, Any], doc) + yield Case( + 
task=self.name, + kind=self.kind, + case_id=f"aime_{self.split}_{doc['id']}", + prompt=doc["prompt"], + gold=doc["answer"], + meta_data={"id": doc["id"]}, + ) + + +class Gsm8k_Task(MathTaskSpec): + + def __init__(self): + self.name = "gsm8k" + self.kind = "math" + self.config = "main" + self.split = "test" + + def load(self, limit, seed) -> datasets.Dataset: + ds = datasets.load_dataset("openai/gsm8k", self.config, split=self.split) + ds = ds.add_column("_row_id", list(range(len(ds)))) + if limit: + ds = ds.shuffle(seed=seed) + ds = ds.select(range(min(limit, len(ds)))) + + ds = ds.map( + lambda k: { + "prompt": MATH_TEMPLATE.format( + question=k["question"], + ), + "gold": k["answer"].split("### ")[-1].rstrip(), + } + ) + return ds + + def iter_cases(self, limit: int, seed: int) -> Iterator[Case]: + ds = self.load(limit, seed) + + for doc in ds: + doc = cast(Mapping[str, Any], doc) + yield Case( + task=self.name, + kind=self.kind, + case_id=f"gsm8k_{self.config}_{self.split}:{doc['_row_id']}", + prompt=doc["prompt"], + gold=doc["gold"], + meta_data={}, + ) + + +TASK_DICT: dict[str, type[TaskSpec]] = { + "mmlu": MMLU_Task, + "aime": Aime_Task, + "gsm8k": Gsm8k_Task, + "hellaswag": Hellaswag_Task, + "arc": ARC_Task, + "winogrande": WinoGrande_Task, +} + + +def build_request(case: Case, n_predict: int) -> dict[str, Any]: + json_data = { + "n_predict": n_predict, + "max_tokens": n_predict, + "temperature": 0, + "prompt": case.prompt, + } + return json_data + + +def write_checkpoint_line( + checkpoint_file: Path, + row: dict[str, Any], + file_lock: threading.Lock, +): + with file_lock: + with checkpoint_file.open(mode="a", encoding="utf-8") as f: + f.write(json.dumps(row) + "\n") + + +def send_prompt( + case: Case, + data: dict, +) -> dict[str, Union[str, int]]: + result = { + "task": case.task, + "case_id": case.case_id, + "status": "error", + "correct": 0, + "gold": case.gold, + "pred": "", + "error": "", + } + session: requests.Session = data["session"] + server_address: str = data["server_address"] + task = TASK_DICT.get(case.task) + if task is None: + result["error"] = f"unknown_task: {case.task}" + return result + logger.debug(case.prompt) + + json_data = build_request(case, data["n_predict"]) + res_json = {} + try: + response = session.post(f"{server_address}/v1/completions", json=json_data) + res_json = response.json() + result["status"] = "ok" + except Exception as e: + result["error"] = f"http_exception: {e}" + logger.warning(result["error"]) + + if result["status"] == "ok": + result = TASK_DICT[case.task].grade(case, res_json) + + write_checkpoint_line( + data["checkpoint_file"], + result.copy(), + data["file_lock"], + ) + return result + +def aggregate_by_task(results: list[dict[str, Any]]) -> dict[str, dict[str, int]]: + tmp = { + "total": 0, + "error": 0, + "invalid": 0, + "correct": 0, + } + agg: dict[str, dict[str, int]] = {} + for row in results: + d = agg.get(row["task"], tmp.copy()) + d["total"] += 1 + status = row["status"] + if status == "ok": + d["correct"] += row["correct"] + elif status == "invalid": + d["invalid"] += 1 + elif status == "error": + d["error"] += 1 + + agg[row["task"]] = d + return agg + + +def print_summary(pertask_results: dict[str, dict[str, int]]): + print("\n=== llama-eval suite summary ===") + print( + f"{'Task':<15} {'Acc':>8} {'Correct':>8} {'Total':>8} {'Invalid':>8} {'Error':>8}" + ) + print("-" * 65) + + suite_total = 0 + suite_correct = 0 + + for task in sorted(pertask_results.keys()): + stats = pertask_results[task] + total = stats["total"] + 
correct = stats["correct"] + invalid = stats["invalid"] + error = stats["error"] + + acc = (correct / total) if total > 0 else 0.0 + + print( + f"{task:<15} " + f"{acc:8.3f} " + f"{correct:8d} " + f"{total:8d} " + f"{invalid:8d} " + f"{error:8d}" + ) + + suite_total += total + suite_correct += correct + + # Overall summary + print("-" * 65) + suite_acc = (suite_correct / suite_total) if suite_total > 0 else 0.0 + print( + f"{'ALL':<15} " f"{suite_acc:8.3f} " f"{suite_correct:8d} " f"{suite_total:8d}" + ) + + +def read_checkpoint( + checkpoint_file: Path, resume_flag: bool +) -> tuple[Set[str], Set[str], list[dict[str, Any]]]: + done = set() + errored = set() + results = [] + if not resume_flag or not checkpoint_file.is_file(): + return done, errored, results + + with checkpoint_file.open(mode="r", encoding="utf-8") as f: + for line in f: + line = line.strip() + if not line: + continue + try: + row = json.loads(line) + except Exception as e: + logger.warning(f"WARNING: malformed checkpoint line {line}\n{e}") + continue + + case_id = row.get("case_id") + if not case_id: + continue + + if row["status"] == "error": + errored.add(case_id) + else: + done.add(case_id) + results.append(row) + errored -= done + return done, errored, results + + +def benchmark( + path_server: str, + prompt_source: str, + n_prompts: int, + n_predict: int, + rng_seed: int, + resume_flag: bool, + checkpoint_file: Path, + log_level: int, +): + logger.setLevel(log_level) + done, errored, checkpoint_results = read_checkpoint(checkpoint_file, resume_flag) + + if not path_server.startswith("http://") and not path_server.startswith("https://"): + logger.error("ERROR: malformed server path") + return + + if os.environ.get("LLAMA_ARG_N_PARALLEL") is None: + logger.info("LLAMA_ARG_N_PARALLEL not explicitly set, using 32") + os.environ["LLAMA_ARG_N_PARALLEL"] = "32" + + parallel: int = int(os.environ.get("LLAMA_ARG_N_PARALLEL")) # type: ignore + + task_queue: set[TaskSpec] = set() + for src in prompt_source.split(","): + if src == "all": + for v in TASK_DICT.values(): + task_queue.add(v()) + break + task_queue.add(TASK_DICT[src]()) + + session = None + try: + server_address: str = path_server + + adapter = requests.adapters.HTTPAdapter(pool_connections=parallel, pool_maxsize=parallel) # type: ignore + session = requests.Session() + session.mount("http://", adapter) + session.mount("https://", adapter) + file_lock = threading.Lock() + cases: list[Case] = [] + data: list[dict] = [] + for task in task_queue: + for case in task.iter_cases(n_prompts, rng_seed): + if case.case_id in done or case.case_id in errored: + logger.debug(f"Skipping case_id {case.case_id} from checkpoint") + continue + + cases.append(case) + data.append( + { + "prompt_source": prompt_source, + "session": session, + "server_address": server_address, + "n_predict": n_predict, + "file_lock": file_lock, + "checkpoint_file": checkpoint_file, + } + ) + logger.info("Starting the benchmark...\n") + t0 = time() + results: list[dict[str, Union[str, int]]] = thread_map( + send_prompt, + cases, + data, + max_workers=parallel, + chunksize=1, + ) + finally: + if session is not None: + session.close() + + t1 = time() + logger.info(f"\nllama-eval duration: {t1-t0:.2f} s") + results.extend(checkpoint_results) + pertask_results = aggregate_by_task(results) + print_summary(pertask_results) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Tool for benchmarking the throughput of the llama.cpp HTTP server. 
" + "Results are printed to console and visualized as plots (saved to current working directory). " + "To pass arguments such as the model path to the server, set the corresponding environment variables (see llama-server --help). " + "The reported numbers are the speeds as observed by the Python script and may differ from the performance reported by the server, " + "particularly when the server is fast vs. the network or Python script (e.g. when serving a very small model)." + ) + parser.add_argument( + "--path_server", + type=str, + default="http://localhost:8033", + help="llama-server url", + ) + parser.add_argument( + "--prompt_source", + type=str, + default="mmlu", + help=f"Eval types supported: all,{list(TASK_DICT.keys())}", + ) + parser.add_argument( + "--n_prompts", type=int, default=None, help="Number of prompts to evaluate" + ) + parser.add_argument( + "--rng_seed", + type=int, + default=42, + help="Number to see rng (Used to select prompts from datasource)", + ) + parser.add_argument( + "--n_predict", + type=int, + default=2048, + help="Max. number of tokens to predict per prompt", + ) + parser.add_argument( + "--resume", + dest="resume_flag", + action="store_true", + default=True, + help="Enable resuming from last state stored in checkpoint file", + ) + parser.add_argument( + "--no-resume", + dest="resume_flag", + action="store_false", + help="Disble resuming from last state stored in checkpoint file", + ) + parser.add_argument( + "--checkpoint-file", + type=Path, + dest="checkpoint_file", + default="./llama-eval-checkpoint.jsonl", + help="Checkpoint file to read last state from", + ) + parser.set_defaults(log_level=logging.INFO) + parser.add_argument( + "--quiet", action="store_const", dest="log_level", const=logging.ERROR + ) + parser.add_argument( + "--debug", + action="store_const", + default=True, + dest="log_level", + const=logging.DEBUG, + ) + + args = parser.parse_args() + benchmark(**vars(args))