#!/usr/bin/env python3 import re import argparse import os from time import time from typing import Union, Any, Mapping, cast import datasets import logging import requests from tqdm.contrib.concurrent import thread_map from typing import Iterator, Set from abc import ABC, abstractmethod from dataclasses import dataclass from pathlib import Path import json import threading logging.basicConfig(level=logging.INFO, format='%(message)s') logger = logging.getLogger("llama-eval") MATH_TEMPLATE = """ {question} Do not include any explanation. Put your final answer within \\boxed{{}}. """ def format_multiple_choice(prompt: str, choices: list[str]): lines = [prompt] labels = [chr(ord("A") + i) for i in range(len(choices))] for l, c in zip(labels, choices): lines.append(f"({l}): {c.strip()}") lines.append( "Do not include any explanation. Answer with the corresponding option letter only" ) lines.append(", ".join(labels)) lines.append("Put your final answer within \\boxed{{}}.") return "\n".join(lines), labels def extract_boxed_text(text: str) -> str: pattern = r"boxed{(.*?)}|framebox{(.*?)}" matches = re.findall(pattern, text, re.DOTALL) logger.debug(matches) if matches: for match in matches[::-1]: for group in match: if group != "": return group.split(",")[-1].strip() logger.debug("Could not extract boxed text. Maybe expand context window") return "" @dataclass(frozen=True) class Case: task: str kind: str case_id: str prompt: str gold: str meta_data: dict[str, Any] class TaskSpec(ABC): name: str kind: str @abstractmethod def load(self, limit, seed) -> datasets.Dataset: pass @abstractmethod def iter_cases(self, limit: int, seed: int) -> Iterator[Case]: pass @staticmethod @abstractmethod def grade(case: Case, response: dict) -> dict[str, Any]: pass class MCTaskSpec(TaskSpec): @staticmethod def grade(case: Case, response: dict) -> dict[str, Any]: logger.debug(f"response {response}") result = { "task": case.task, "case_id": case.case_id, "correct": 0, "pred": None, "gold": case.gold, "status": "ok", } try: extracted_answer = extract_boxed_text(response["choices"][0]["text"]) except Exception as e: result["status"] = "error" logger.warning("ERROR: extract_boxed_text") return result if not extracted_answer: result["status"] = "invalid" logger.warning("INVALID: extract_boxed_text") return result logger.debug(f"extracted_answer {extracted_answer}") logger.debug(f"data['answer'] {case.gold}") result["pred"] = extracted_answer result["correct"] = 1 if extracted_answer == case.gold else 0 return result class MathTaskSpec(TaskSpec): @staticmethod def grade(case: Case, response: dict) -> dict[str, Any]: logger.debug(f"response {response}") result = { "task": case.task, "case_id": case.case_id, "correct": 0, "gold": case.gold, "status": "ok", "pred": None, } try: extracted_answer = extract_boxed_text(response["choices"][0]["text"]) except: result["status"] = "error" logger.warning("ERROR: extract_boxed_text") return result source_answer = case.gold try: # All AIME answers are integers, so we convert the extracted answer to an integer extracted_answer = int(extracted_answer) source_answer = int(case.gold) except (ValueError, TypeError): result["status"] = "invalid" return result logger.debug(f"extracted_answer {extracted_answer}") logger.debug(f"data['answer'] {case.gold}") result["pred"] = extracted_answer result["correct"] = 1 if extracted_answer == source_answer else 0 return result class ARC_Task(MCTaskSpec): def __init__(self): self.name = "arc" self.kind = "mc" self.config = "ARC-Challenge" self.split = "test" def load(self, limit, seed) -> datasets.Dataset: ds = datasets.load_dataset("allenai/ai2_arc", self.config, split=self.split) ds = ds.add_column("_row_id", list(range(len(ds)))) if limit: ds = ds.shuffle(seed=seed) ds = ds.select(range(min(limit, len(ds)))) return ds def iter_cases(self, limit: int, seed: int) -> Iterator[Case]: ds = self.load(limit, seed) for doc in ds: doc = cast(Mapping[str, Any], doc) prompt, labels = format_multiple_choice( doc["question"], doc["choices"]["text"] ) yield Case( task=self.name, kind=self.kind, case_id=f"ARC-Challenge_{self.config}_{self.split}_{doc['_row_id']}", prompt=prompt, gold=doc["answerKey"], meta_data={"labels": labels}, ) class WinoGrande_Task(MCTaskSpec): def __init__(self): self.name = "winogrande" self.kind = "mc" self.config = "winogrande_debiased" self.split = "validation" def load(self, limit, seed) -> datasets.Dataset: ds = datasets.load_dataset("winogrande", self.config, split=self.split) ds = ds.add_column("_row_id", list(range(len(ds)))) if limit: ds = ds.shuffle(seed=seed) ds = ds.select(range(min(limit, len(ds)))) return ds def iter_cases(self, limit: int, seed: int) -> Iterator[Case]: ds = self.load(limit, seed) for doc in ds: doc = cast(Mapping[str, Any], doc) prompt, labels = format_multiple_choice( doc["sentence"], [doc["option1"], doc["option2"]] ) yield Case( task=self.name, kind=self.kind, case_id=f"winogrande_{self.config}_{self.split}_{doc['_row_id']}", prompt=prompt, gold=labels[int(doc["answer"]) - 1], # winogrande answers are 1 based meta_data={"labels": labels}, ) class MMLU_Task(MCTaskSpec): def __init__(self): self.name = "mmlu" self.kind = "mc" self.config = "all" self.split = "test" def load(self, limit, seed) -> datasets.Dataset: ds = datasets.load_dataset("cais/mmlu", self.config, split=self.split) ds = ds.add_column("_row_id", list(range(len(ds)))) if limit: ds = ds.shuffle(seed=seed) ds = ds.select(range(min(limit, len(ds)))) return ds def iter_cases(self, limit: int, seed: int) -> Iterator[Case]: ds = self.load(limit, seed) for doc in ds: doc = cast(Mapping[str, Any], doc) prompt, labels = format_multiple_choice(doc["question"], doc["choices"]) yield Case( task=self.name, kind=self.kind, case_id=f"mmlu_{self.config}_{self.split}_{doc['subject']}_{doc['_row_id']}", prompt=prompt, gold=labels[int(doc["answer"])], meta_data={"subject": doc["subject"], "labels": labels}, ) class Hellaswag_Task(MCTaskSpec): # Preprocess hellaswag @staticmethod def preprocess(text: str): text = text.strip() # NOTE: Brackets are artifacts of the WikiHow dataset portion of HellaSwag. text = text.replace(" [title]", ". ") text = re.sub("\\[.*?\\]", "", text) text = text.replace(" ", " ") return text @staticmethod def hellaswag_process_doc(doc: dict[str, str]): ctx = doc["ctx_a"] + " " + doc["ctx_b"].capitalize() question = Hellaswag_Task.preprocess(doc["activity_label"] + ": " + ctx) proc_answers = [Hellaswag_Task.preprocess(answer) for answer in doc["endings"]] prompt, labels = format_multiple_choice(question, proc_answers) out_doc = { "prompt": prompt, "gold": labels[int(doc["label"])], } return out_doc def __init__(self): self.name = "hellaswag" self.kind = "mc" def load(self, limit, seed) -> datasets.Dataset: ds = datasets.load_dataset("Rowan/hellaswag", split="validation") if limit: ds = ds.shuffle(seed=seed) ds = ds.select(range(min(limit, len(ds)))) ds = ds.map(Hellaswag_Task.hellaswag_process_doc) return ds def iter_cases(self, limit: int, seed: int) -> Iterator[Case]: ds = self.load(limit, seed) for doc in ds: doc = cast(Mapping[str, Any], doc) yield Case( task=self.name, kind=self.kind, case_id=f"hellaswag_{doc['split']}_{doc['ind']}", prompt=doc["prompt"], gold=doc["gold"], meta_data={}, ) class Aime_Task(MathTaskSpec): def __init__(self): self.name = "aime" self.kind = "math" self.split = "train" def load(self, limit, seed) -> datasets.Dataset: ds = datasets.load_dataset("AI-MO/aimo-validation-aime", split=self.split) if limit: ds = ds.shuffle(seed=seed) ds = ds.select(range(min(limit, len(ds)))) ds = ds.map( lambda ex: { "prompt": MATH_TEMPLATE.format( question=ex["problem"], ) } ) return ds def iter_cases(self, limit: int, seed: int) -> Iterator[Case]: ds = self.load(limit, seed) for i, doc in enumerate(ds): doc = cast(Mapping[str, Any], doc) yield Case( task=self.name, kind=self.kind, case_id=f"aime_{self.split}_{doc['id']}", prompt=doc["prompt"], gold=doc["answer"], meta_data={"id": doc["id"]}, ) class Gsm8k_Task(MathTaskSpec): def __init__(self): self.name = "gsm8k" self.kind = "math" self.config = "main" self.split = "test" def load(self, limit, seed) -> datasets.Dataset: ds = datasets.load_dataset("openai/gsm8k", self.config, split=self.split) ds = ds.add_column("_row_id", list(range(len(ds)))) if limit: ds = ds.shuffle(seed=seed) ds = ds.select(range(min(limit, len(ds)))) ds = ds.map( lambda k: { "prompt": MATH_TEMPLATE.format( question=k["question"], ), "gold": k["answer"].split("### ")[-1].rstrip(), } ) return ds def iter_cases(self, limit: int, seed: int) -> Iterator[Case]: ds = self.load(limit, seed) for doc in ds: doc = cast(Mapping[str, Any], doc) yield Case( task=self.name, kind=self.kind, case_id=f"gsm8k_{self.config}_{self.split}:{doc['_row_id']}", prompt=doc["prompt"], gold=doc["gold"], meta_data={}, ) TASK_DICT: dict[str, type[TaskSpec]] = { "mmlu": MMLU_Task, "aime": Aime_Task, "gsm8k": Gsm8k_Task, "hellaswag": Hellaswag_Task, "arc": ARC_Task, "winogrande": WinoGrande_Task, } def build_request(case: Case, n_predict: int) -> dict[str, Any]: json_data = { "n_predict": n_predict, "max_tokens": n_predict, "temperature": 0, "prompt": case.prompt, } return json_data def write_checkpoint_line( checkpoint_file: Path, row: dict[str, Any], file_lock: threading.Lock, ): with file_lock: with checkpoint_file.open(mode="a", encoding="utf-8") as f: f.write(json.dumps(row) + "\n") def send_prompt( case: Case, data: dict, ) -> dict[str, Union[str, int]]: result = { "task": case.task, "case_id": case.case_id, "status": "error", "correct": 0, "gold": case.gold, "pred": "", "error": "", } session: requests.Session = data["session"] server_address: str = data["server_address"] task = TASK_DICT.get(case.task) if task is None: result["error"] = f"unknown_task: {case.task}" return result logger.debug(case.prompt) json_data = build_request(case, data["n_predict"]) res_json = {} try: response = session.post(f"{server_address}/v1/completions", json=json_data) res_json = response.json() result["status"] = "ok" except Exception as e: result["error"] = f"http_exception: {e}" logger.warning(result["error"]) if result["status"] == "ok": result = TASK_DICT[case.task].grade(case, res_json) write_checkpoint_line( data["checkpoint_file"], result.copy(), data["file_lock"], ) return result def aggregate_by_task(results: list[dict[str, Any]]) -> dict[str, dict[str, int]]: tmp = { "total": 0, "error": 0, "invalid": 0, "correct": 0, } agg: dict[str, dict[str, int]] = {} for row in results: d = agg.get(row["task"], tmp.copy()) d["total"] += 1 status = row["status"] if status == "ok": d["correct"] += row["correct"] elif status == "invalid": d["invalid"] += 1 elif status == "error": d["error"] += 1 agg[row["task"]] = d return agg def print_summary(pertask_results: dict[str, dict[str, int]]): print("\n=== llama-eval suite summary ===") print( f"{'Task':<15} {'Acc':>8} {'Correct':>8} {'Total':>8} {'Invalid':>8} {'Error':>8}" ) print("-" * 65) suite_total = 0 suite_correct = 0 for task in sorted(pertask_results.keys()): stats = pertask_results[task] total = stats["total"] correct = stats["correct"] invalid = stats["invalid"] error = stats["error"] acc = (correct / total) if total > 0 else 0.0 print( f"{task:<15} " f"{acc:8.3f} " f"{correct:8d} " f"{total:8d} " f"{invalid:8d} " f"{error:8d}" ) suite_total += total suite_correct += correct # Overall summary print("-" * 65) suite_acc = (suite_correct / suite_total) if suite_total > 0 else 0.0 print( f"{'ALL':<15} " f"{suite_acc:8.3f} " f"{suite_correct:8d} " f"{suite_total:8d}" ) def read_checkpoint( checkpoint_file: Path, resume_flag: bool ) -> tuple[Set[str], Set[str], list[dict[str, Any]]]: done = set() errored = set() results = [] if not resume_flag or not checkpoint_file.is_file(): return done, errored, results with checkpoint_file.open(mode="r", encoding="utf-8") as f: for line in f: line = line.strip() if not line: continue try: row = json.loads(line) except Exception as e: logger.warning(f"WARNING: malformed checkpoint line {line}\n{e}") continue case_id = row.get("case_id") if not case_id: continue if row["status"] == "error": errored.add(case_id) else: done.add(case_id) results.append(row) errored -= done return done, errored, results def benchmark( path_server: str, prompt_source: str, n_prompts: int, n_predict: int, rng_seed: int, resume_flag: bool, checkpoint_file: Path, log_level: int, ): logger.setLevel(log_level) done, errored, checkpoint_results = read_checkpoint(checkpoint_file, resume_flag) if not path_server.startswith("http://") and not path_server.startswith("https://"): logger.error("ERROR: malformed server path") return if os.environ.get("LLAMA_ARG_N_PARALLEL") is None: logger.info("LLAMA_ARG_N_PARALLEL not explicitly set, using 32") os.environ["LLAMA_ARG_N_PARALLEL"] = "32" parallel: int = int(os.environ.get("LLAMA_ARG_N_PARALLEL")) # type: ignore task_queue: set[TaskSpec] = set() for src in prompt_source.split(","): if src == "all": for v in TASK_DICT.values(): task_queue.add(v()) break task_queue.add(TASK_DICT[src]()) session = None try: server_address: str = path_server adapter = requests.adapters.HTTPAdapter(pool_connections=parallel, pool_maxsize=parallel) # type: ignore session = requests.Session() session.mount("http://", adapter) session.mount("https://", adapter) file_lock = threading.Lock() cases: list[Case] = [] data: list[dict] = [] for task in task_queue: for case in task.iter_cases(n_prompts, rng_seed): if case.case_id in done or case.case_id in errored: logger.debug(f"Skipping case_id {case.case_id} from checkpoint") continue cases.append(case) data.append( { "prompt_source": prompt_source, "session": session, "server_address": server_address, "n_predict": n_predict, "file_lock": file_lock, "checkpoint_file": checkpoint_file, } ) logger.info("Starting the benchmark...\n") t0 = time() results: list[dict[str, Union[str, int]]] = thread_map( send_prompt, cases, data, max_workers=parallel, chunksize=1, ) finally: if session is not None: session.close() t1 = time() logger.info(f"\nllama-eval duration: {t1-t0:.2f} s") results.extend(checkpoint_results) pertask_results = aggregate_by_task(results) print_summary(pertask_results) if __name__ == "__main__": parser = argparse.ArgumentParser( description="Tool for benchmarking the throughput of the llama.cpp HTTP server. " "Results are printed to console and visualized as plots (saved to current working directory). " "To pass arguments such as the model path to the server, set the corresponding environment variables (see llama-server --help). " "The reported numbers are the speeds as observed by the Python script and may differ from the performance reported by the server, " "particularly when the server is fast vs. the network or Python script (e.g. when serving a very small model)." ) parser.add_argument( "--path_server", type=str, default="http://localhost:8033", help="llama-server url", ) parser.add_argument( "--prompt_source", type=str, default="mmlu", help=f"Eval types supported: all,{list(TASK_DICT.keys())}", ) parser.add_argument( "--n_prompts", type=int, default=None, help="Number of prompts to evaluate" ) parser.add_argument( "--rng_seed", type=int, default=42, help="Number to see rng (Used to select prompts from datasource)", ) parser.add_argument( "--n_predict", type=int, default=2048, help="Max. number of tokens to predict per prompt", ) parser.add_argument( "--resume", dest="resume_flag", action="store_true", default=True, help="Enable resuming from last state stored in checkpoint file", ) parser.add_argument( "--no-resume", dest="resume_flag", action="store_false", help="Disble resuming from last state stored in checkpoint file", ) parser.add_argument( "--checkpoint-file", type=Path, dest="checkpoint_file", default="./llama-eval-checkpoint.jsonl", help="Checkpoint file to read last state from", ) parser.set_defaults(log_level=logging.INFO) parser.add_argument( "--quiet", action="store_const", dest="log_level", const=logging.ERROR ) parser.add_argument( "--debug", action="store_const", default=True, dest="log_level", const=logging.DEBUG, ) args = parser.parse_args() benchmark(**vars(args))