#!/usr/bin/env python3
"""Accuracy benchmark for a llama.cpp HTTP server on standard eval datasets.

Supports MMLU, HellaSwag, AIME and GSM8K. Prompts are sent concurrently to
either an externally running server (``--path_server`` given as an http(s)
URL) or a locally spawned ``llama-server`` binary. The model's final answer is
extracted from a ``\\boxed{...}`` span and compared against the gold answer;
the script reports the number correct and the accuracy.
"""

import argparse
import json
import logging
import os
import random  # noqa: F401  -- kept: other parts of the project may rely on this module's imports
import re
import subprocess
from abc import ABC  # noqa: F401
from time import sleep, time
from typing import Iterator, Optional, Union  # noqa: F401

import datasets
import requests
from tqdm.contrib.concurrent import thread_map

logging.basicConfig(level=logging.INFO, format="%(message)s")
logger = logging.getLogger("llama-eval")

# Generic math prompt: the model must put its final answer inside \boxed{}.
MATH_TEMPLATE = """
{question}

Put your final answer within \\boxed{{}}.
"""

# Maps dataset label indices (0-3) to multiple-choice letters.
MC_FROM_INT = {
    0: "A",
    1: "B",
    2: "C",
    3: "D",
}


def format_multiple_choice(prompt: str, choices: list[str]) -> str:
    """Render a question plus four answer options as a single MC prompt.

    The model is instructed to answer with the option letter inside \\boxed{}
    so that extract_boxed_text() can score the response.
    """
    QUERY_TEMPLATE_MULTICHOICE = """
{question}

(A) {A}
(B) {B}
(C) {C}
(D) {D}

Express your final answer as the corresponding option 'A', 'B', 'C', or 'D'. Put your final answer within \\boxed{{}}.
""".strip()
    return QUERY_TEMPLATE_MULTICHOICE.format(
        question=prompt,
        A=choices[0],
        B=choices[1],
        C=choices[2],
        D=choices[3],
    )


# Preprocess hellaswag
def preprocess(text):
    """Clean a HellaSwag text fragment: strip WikiHow artifacts and bracket tags."""
    text = text.strip()
    # NOTE: Brackets are artifacts of the WikiHow dataset portion of HellaSwag.
    text = text.replace(" [title]", ". ")
    text = re.sub(r"\[.*?\]", "", text)
    # FIX: collapse double spaces left behind by the substitutions above
    # (the previous code replaced a single space with a single space, a no-op).
    text = text.replace("  ", " ")
    return text


def hellaswag_process_doc(doc):
    """Turn a raw HellaSwag record into a {prompt, gold} multiple-choice doc."""
    ctx = doc["ctx_a"] + " " + doc["ctx_b"].capitalize()
    question = preprocess(doc["activity_label"] + ": " + ctx)
    proc_answers = [preprocess(answer) for answer in doc["endings"]]
    return {
        "prompt": format_multiple_choice(question, proc_answers),
        "gold": MC_FROM_INT[int(doc["label"])],
    }


def mmlu_process_doc(doc):
    """Turn a raw MMLU record into a {prompt, gold} multiple-choice doc."""
    return {
        "prompt": format_multiple_choice(doc["question"], doc["choices"]),
        "gold": MC_FROM_INT[int(doc["answer"])],
    }


def extract_boxed_text(text):
    """Extract the model's final answer from its response text.

    Prefers the last \\boxed{...} or \\framebox{...} span (last element if
    comma-separated); falls back to the last integer in the text; returns ""
    if neither is present.
    """
    pattern = r"boxed{(.*?)}|framebox{(.*?)}"
    matches = re.findall(pattern, text, re.DOTALL)
    logger.debug(matches)
    if matches:
        # Walk the matches back-to-front: the last box is the final answer.
        for match in matches[::-1]:
            for group in match:
                if group != "":
                    return group.split(",")[-1].strip()
    logger.warning(
        "Could not extract boxed text. Using last integer. Maybe expand context window"
    )
    pattern = r"\d+"  # get the last integer if no pattern found
    matches = re.findall(pattern, text, re.DOTALL)
    if matches:
        return matches[-1]
    return ""


def get_prompts_text(
    dataset_name: str, ds: datasets.Dataset
) -> Optional[tuple[list[str], list[str]]]:
    """Convert a loaded dataset into parallel (prompts, gold answers) lists.

    Returns None for an unknown dataset name.
    """
    name = dataset_name.lower()
    if name == "mmlu":
        ds = ds.map(mmlu_process_doc)
        return ds["prompt"], ds["gold"]
    if name == "hellaswag":
        ds = ds.map(hellaswag_process_doc)
        return ds["prompt"], ds["gold"]
    if name == "aime":
        ds = ds.map(lambda k: {"prompt": MATH_TEMPLATE.format(question=k["problem"])})
        return ds["prompt"], ds["answer"]
    if name == "gsm8k":
        ds = ds.map(lambda k: {"prompt": MATH_TEMPLATE.format(question=k["question"])})
        # GSM8K answers end with "#### <number>"; keep only the number.
        answers = [answer.split("### ")[-1].rstrip() for answer in ds["answer"]]
        return ds["prompt"], answers
    return None


def get_dataset(
    dataset_name: str, n_prompts: int, rng_seed: int
) -> Optional[datasets.Dataset]:
    """Load (and optionally subsample) the requested evaluation dataset.

    When n_prompts >= 0 the dataset is shuffled with rng_seed and truncated to
    at most n_prompts rows. Returns None for an unknown dataset name.
    """
    cache_dir = "./build/bin/datasets"
    logger.info(f"Loading {dataset_name.lower()} dataset...")
    name = dataset_name.lower()
    if name == "mmlu":
        ds = datasets.load_dataset(
            "cais/mmlu", "all", split="test", cache_dir=cache_dir
        )
    elif name == "hellaswag":
        ds = datasets.load_dataset(
            "Rowan/hellaswag", split="validation", cache_dir=cache_dir
        )
    elif name == "aime":
        ds = datasets.load_dataset(
            "AI-MO/aimo-validation-aime", split="train", cache_dir=cache_dir
        )
    elif name == "gsm8k":
        # FIX: pass cache_dir here too, consistent with the other datasets.
        ds = datasets.load_dataset("openai/gsm8k", "main", split="test", cache_dir=cache_dir) \
            if False else datasets.load_dataset("openai/gsm8k", split="test", cache_dir=cache_dir)
    else:
        return None

    if n_prompts >= 0:
        ds = ds.shuffle(seed=rng_seed)
        ds = ds.select(range(min(n_prompts, len(ds))))
    return ds


def send_prompt(data: dict) -> int:
    """Send one prompt to the server and score the response.

    Returns 1 if the extracted answer matches the gold answer, else 0.
    """
    session = data["session"]
    server_address: str = data["server_address"]
    prompt: str = data["prompt"]

    logger.info(f"data['external_server'] {data['external_server']}")
    logger.info(f"data['prompt'] {prompt}")
    logger.info(f"data['n_predict'] {data['n_predict']}")

    json_data: dict = {
        "prompt": prompt,
        "max_tokens": data["n_predict"],
        "temperature": 0,  # greedy decoding for reproducible scoring
    }
    response = session.post(f"{server_address}/v1/completions", json=json_data)
    res = response.json()
    logger.info(f"response {res}")

    extracted_answer = extract_boxed_text(res["choices"][0]["text"])
    source_answer = data["answer"]
    if data["prompt_source"] == "aime" or data["prompt_source"] == "gsm8k":
        try:
            # All AIME/GSM8K answers are integers, so compare as integers.
            extracted_answer = int(extracted_answer)
            source_answer = int(source_answer)
        except (ValueError, TypeError):
            extracted_answer = None

    logger.info(f"extracted_answer {extracted_answer}")
    logger.info(f"data['answer'] {data['answer']}")
    return 1 if extracted_answer == source_answer else 0


def get_server(path_server: str, path_log: Optional[str]) -> dict:
    """Return {"process", "address", "fout"} for the server to benchmark.

    If path_server is an http(s) URL the server is assumed to already be
    running (process/fout are None). Otherwise a local llama-server process is
    spawned and polled via /health for up to ~10 seconds.

    Raises RuntimeError if the spawned server exits or never becomes healthy.
    """
    if path_server.startswith("http://") or path_server.startswith("https://"):
        return {"process": None, "address": path_server, "fout": None}

    if os.environ.get("LLAMA_ARG_HOST") is None:
        logger.info("LLAMA_ARG_HOST not explicitly set, using 127.0.0.1")
        os.environ["LLAMA_ARG_HOST"] = "127.0.0.1"
    if os.environ.get("LLAMA_ARG_PORT") is None:
        logger.info("LLAMA_ARG_PORT not explicitly set, using 8080")
        os.environ["LLAMA_ARG_PORT"] = "8080"
    hostname: Optional[str] = os.environ.get("LLAMA_ARG_HOST")
    port: Optional[str] = os.environ.get("LLAMA_ARG_PORT")
    assert hostname is not None
    assert port is not None
    address: str = f"http://{hostname}:{port}"
    logger.info(f"Starting the llama.cpp server under {address}...")

    # Server output goes to the log file when one is configured, else is discarded.
    fout = open(path_log.format(port=port), "w") if path_log is not None else subprocess.DEVNULL
    process = subprocess.Popen([path_server], stdout=fout, stderr=subprocess.STDOUT)

    n_failures: int = 0
    while True:
        try:
            sleep(1.0)
            exit_code = process.poll()
            if exit_code is not None:
                log_hint = f", see {path_log.format(port=port)}" if path_log else ""
                raise RuntimeError(
                    f"llama.cpp server exited unexpectedly with exit code {exit_code}{log_hint}"
                )
            response = requests.get(f"{address}/health")
            if response.status_code == 200:
                break
        except requests.ConnectionError:
            n_failures += 1
            if n_failures >= 10:
                raise RuntimeError("llama.cpp server is not healthy after 10 seconds")

    return {"process": process, "address": address, "fout": fout}


def benchmark(
    path_server: str,
    path_log: Optional[str],
    prompt_source: str,
    n_prompts: int,
    n_predict: int,
    rng_seed: int,
):
    """Run the accuracy benchmark end to end and log the results."""
    external_server: bool = path_server.startswith("http://") or path_server.startswith("https://")
    if os.environ.get("LLAMA_ARG_N_PARALLEL") is None:
        logger.info("LLAMA_ARG_N_PARALLEL not explicitly set, using 32")
        os.environ["LLAMA_ARG_N_PARALLEL"] = "32"
    parallel: int = int(os.environ.get("LLAMA_ARG_N_PARALLEL"))  # type: ignore

    ds: Union[datasets.Dataset, None] = get_dataset(prompt_source, n_prompts, rng_seed)
    if not ds:
        logger.error("ERROR: get_dataset")
        # FIX: exit with a failure status instead of 0 on error.
        exit(1)
    res: Union[tuple[list[str], list[str]], None] = get_prompts_text(prompt_source, ds)
    if not res:
        logger.error("ERROR: get_prompts_text")
        exit(1)
    prompts: Union[list[str], list[list[int]]] = res[0]
    answer: Union[list[str], list[list[int]]] = res[1]
    logger.info(prompts)
    logger.info(f"external_server {external_server}")

    server: Optional[dict] = None
    session = None
    try:
        server = get_server(path_server, path_log)
        server_address: str = server["address"]
        assert external_server == (server["process"] is None)

        # One pooled connection per worker thread.
        adapter = requests.adapters.HTTPAdapter(pool_connections=parallel, pool_maxsize=parallel)  # type: ignore
        session = requests.Session()
        session.mount("http://", adapter)
        session.mount("https://", adapter)

        data: list[dict] = [
            {
                "prompt_source": prompt_source,
                "session": session,
                "server_address": server_address,
                "external_server": external_server,
                "prompt": p,
                "answer": a,
                "n_predict": n_predict,
            }
            for p, a in zip(prompts, answer)
        ]

        logger.info("Starting the benchmark...\n")
        t0 = time()
        results: list[int] = thread_map(
            send_prompt, data, max_workers=parallel, chunksize=1
        )
        # FIX: take the end timestamp before server teardown so the reported
        # duration covers only the evaluation itself.
        t1 = time()
    finally:
        if server is not None:
            if server["process"] is not None:
                server["process"].terminate()
                server["process"].wait()
            # FIX: close the server log file handle (previously leaked).
            # subprocess.DEVNULL is an int and has no close().
            if hasattr(server.get("fout"), "close"):
                server["fout"].close()
        if session is not None:
            session.close()

    correct: int = sum(results)
    total_questions: int = len(data)
    logger.info(f"llama-eval duration: {t1-t0:.2f} s")
    logger.info(f"{prompt_source} correct: {correct}")
    logger.info(f"{prompt_source} total_questions: {total_questions}")
    # FIX: guard against division by zero when no prompts were selected.
    if total_questions > 0:
        logger.info(f"{prompt_source} accuracy: {correct / total_questions}")
    else:
        logger.warning("No questions were evaluated; accuracy is undefined")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Tool for benchmarking the accuracy of a model served by the llama.cpp HTTP server "
        "on standard evaluation datasets (MMLU, HellaSwag, AIME, GSM8K). Results are printed to console. "
        "To pass arguments such as the model path to the server, set the corresponding environment variables "
        "(see llama-server --help)."
    )
    parser.add_argument(
        "--path_server",
        type=str,
        default="llama-server",
        help="Path to the llama.cpp server binary, or the http(s) address of an already running server",
    )
    parser.add_argument(
        "--path_log",
        type=str,
        default="server-bench-{port}.log",
        help="Path to the server log file; '{port}' is replaced with the server port",
    )
    parser.add_argument(
        "--prompt_source",
        type=str,
        default="mmlu",
        help="Dataset to draw prompts from: 'mmlu', 'hellaswag', 'aime', or 'gsm8k'",
    )
    parser.add_argument(
        "--n_prompts", type=int, default=100, help="Number of prompts to evaluate"
    )
    parser.add_argument(
        "--rng_seed",
        type=int,
        default=42,
        help="Number to seed rng (Used to select prompts from datasource)",
    )
    parser.add_argument(
        "--n_predict",
        type=int,
        default=2048,
        help="Max. number of tokens to predict per prompt",
    )
    args = parser.parse_args()
    benchmark(**vars(args))