working llama-eval mc and math suite
This commit is contained in:
parent
707cbafcaa
commit
2357f6f193
|
|
@ -0,0 +1,358 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
import re
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import subprocess
|
||||
from time import sleep, time
|
||||
from typing import Optional, Union
|
||||
|
||||
import datasets
|
||||
import logging
|
||||
import requests
|
||||
from tqdm.contrib.concurrent import thread_map
|
||||
from typing import Iterator
|
||||
from abc import ABC
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format='%(message)s')
|
||||
logger = logging.getLogger("llama-eval")
|
||||
|
||||
|
||||
MATH_TEMPLATE = """
|
||||
{question}
|
||||
Put your final answer within \\boxed{{}}.
|
||||
"""
|
||||
|
||||
MC_FROM_INT = {
|
||||
0: "A",
|
||||
1: "B",
|
||||
2: "C",
|
||||
3: "D",
|
||||
}
|
||||
|
||||
|
||||
def format_multiple_choice(prompt: str, choices: list[str]):
|
||||
QUERY_TEMPLATE_MULTICHOICE = """
|
||||
{question}
|
||||
|
||||
(A) {A}
|
||||
(B) {B}
|
||||
(C) {C}
|
||||
(D) {D}
|
||||
|
||||
Express your final answer as the corresponding option 'A', 'B', 'C', or 'D'. Put your final answer within \\boxed{{}}.
|
||||
|
||||
""".strip()
|
||||
A_str = choices[0]
|
||||
B_str = choices[1]
|
||||
C_str = choices[2]
|
||||
D_str = choices[3]
|
||||
query = QUERY_TEMPLATE_MULTICHOICE.format(
|
||||
question=prompt, A=A_str, B=B_str, C=C_str, D=D_str
|
||||
)
|
||||
return query
|
||||
|
||||
|
||||
# Preprocess hellaswag
|
||||
def preprocess(text):
|
||||
text = text.strip()
|
||||
# NOTE: Brackets are artifacts of the WikiHow dataset portion of HellaSwag.
|
||||
text = text.replace(" [title]", ". ")
|
||||
text = re.sub("\\[.*?\\]", "", text)
|
||||
text = text.replace(" ", " ")
|
||||
return text
|
||||
|
||||
|
||||
def hellaswag_process_doc(doc):
|
||||
ctx = doc["ctx_a"] + " " + doc["ctx_b"].capitalize()
|
||||
question = preprocess(doc["activity_label"] + ": " + ctx)
|
||||
proc_answers = [preprocess(answer) for answer in doc["endings"]]
|
||||
prompt = format_multiple_choice(question, proc_answers)
|
||||
out_doc = {
|
||||
"prompt": prompt,
|
||||
"gold": MC_FROM_INT[int(doc["label"])],
|
||||
}
|
||||
return out_doc
|
||||
|
||||
|
||||
def mmlu_process_doc(doc):
|
||||
prompt = format_multiple_choice(doc["question"], doc["choices"])
|
||||
out_doc = {
|
||||
"prompt": prompt,
|
||||
"gold": MC_FROM_INT[int(doc["answer"])],
|
||||
}
|
||||
return out_doc
|
||||
|
||||
|
||||
def extract_boxed_text(text):
|
||||
pattern = r"boxed{(.*?)}|framebox{(.*?)}"
|
||||
matches = re.findall(pattern, text, re.DOTALL)
|
||||
logger.debug(matches)
|
||||
if matches:
|
||||
for match in matches[::-1]:
|
||||
for group in match:
|
||||
if group != "":
|
||||
return group.split(",")[-1].strip()
|
||||
logger.warning(
|
||||
"Could not extract boxed text. Using last integer. Maybe expand context window"
|
||||
)
|
||||
pattern = r"\d+" # get the last integer if no pattern found
|
||||
matches = re.findall(pattern, text, re.DOTALL)
|
||||
if matches:
|
||||
return matches[-1]
|
||||
|
||||
return ""
|
||||
|
||||
|
||||
def get_prompts_text(
|
||||
dataset_name: str, ds: datasets.Dataset
|
||||
) -> Optional[tuple[list[str], list[str]]]:
|
||||
ret = []
|
||||
if dataset_name.lower() == "mmlu":
|
||||
ds = ds.map(mmlu_process_doc)
|
||||
ret = ds["prompt"], ds["gold"]
|
||||
elif dataset_name.lower() == "hellaswag":
|
||||
ds = ds.map(hellaswag_process_doc)
|
||||
ret = ds["prompt"], ds["gold"]
|
||||
elif dataset_name.lower() == "aime":
|
||||
ds = ds.map(
|
||||
lambda k: {
|
||||
"prompt": MATH_TEMPLATE.format(
|
||||
question=k["problem"],
|
||||
)
|
||||
}
|
||||
)
|
||||
ret = ds["prompt"], ds["answer"]
|
||||
elif dataset_name.lower() == "gsm8k":
|
||||
ds = ds.map(lambda k: {"prompt": MATH_TEMPLATE.format(question=k["question"])})
|
||||
la = []
|
||||
for answer in ds["answer"]:
|
||||
la.append(answer.split("### ")[-1].rstrip())
|
||||
ret = ds["prompt"], la
|
||||
else:
|
||||
return None
|
||||
|
||||
return ret
|
||||
|
||||
|
||||
def get_dataset(
|
||||
dataset_name: str, n_prompts: int, rng_seed: int
|
||||
) -> Optional[datasets.Dataset]:
|
||||
ds = None
|
||||
cache_dir = "./build/bin/datasets"
|
||||
logger.info(f"Loading {dataset_name.lower()} dataset...")
|
||||
if dataset_name.lower() == "mmlu":
|
||||
ds = datasets.load_dataset(
|
||||
"cais/mmlu", "all", split="test", cache_dir=cache_dir
|
||||
)
|
||||
elif dataset_name.lower() == "hellaswag":
|
||||
ds = datasets.load_dataset(
|
||||
"Rowan/hellaswag", split="validation", cache_dir=cache_dir
|
||||
)
|
||||
elif dataset_name.lower() == "aime":
|
||||
ds = datasets.load_dataset(
|
||||
"AI-MO/aimo-validation-aime", split="train", cache_dir=cache_dir
|
||||
)
|
||||
elif dataset_name.lower() == "gsm8k":
|
||||
ds = datasets.load_dataset("openai/gsm8k", split="test")
|
||||
else:
|
||||
return None
|
||||
|
||||
if n_prompts >= 0:
|
||||
ds = ds.shuffle(seed=rng_seed)
|
||||
ds = ds.select(range(min(n_prompts, len(ds))))
|
||||
return ds
|
||||
|
||||
|
||||
def send_prompt(data: dict) -> int:
|
||||
session = data["session"]
|
||||
server_address: str = data["server_address"]
|
||||
prompt: str = data["prompt"]
|
||||
logger.info(f"data['external_server'] {data['external_server']}")
|
||||
logger.info(f"data['prompt'] {prompt}")
|
||||
logger.info(f"data['n_predict'] {data['n_predict']}")
|
||||
|
||||
json_data: dict = {
|
||||
"prompt": prompt,
|
||||
"max_tokens": data["n_predict"],
|
||||
"temperature": 0,
|
||||
}
|
||||
response = session.post(f"{server_address}/v1/completions", json=json_data)
|
||||
res = json.loads(response.text)
|
||||
logger.info(f"response {res}")
|
||||
extracted_answer = extract_boxed_text(res["choices"][0]["text"])
|
||||
source_answer = data["answer"]
|
||||
if data["prompt_source"] == "aime" or data["prompt_source"] == "gsm8k":
|
||||
try: # All AIME answers are integers, so we convert the extracted answer to an integer
|
||||
extracted_answer = int(extracted_answer)
|
||||
source_answer = int(source_answer)
|
||||
except (ValueError, TypeError):
|
||||
extracted_answer = None
|
||||
logger.info(f"extracted_answer {extracted_answer}")
|
||||
logger.info(f"data['answer'] {data['answer']}")
|
||||
|
||||
score = 1 if extracted_answer == source_answer else 0
|
||||
|
||||
return score
|
||||
|
||||
|
||||
def get_server(path_server: str, path_log: Optional[str]) -> dict:
|
||||
if path_server.startswith("http://") or path_server.startswith("https://"):
|
||||
return {"process": None, "address": path_server, "fout": None}
|
||||
if os.environ.get("LLAMA_ARG_HOST") is None:
|
||||
logger.info("LLAMA_ARG_HOST not explicitly set, using 127.0.0.1")
|
||||
os.environ["LLAMA_ARG_HOST"] = "127.0.0.1"
|
||||
if os.environ.get("LLAMA_ARG_PORT") is None:
|
||||
logger.info("LLAMA_ARG_PORT not explicitly set, using 8080")
|
||||
os.environ["LLAMA_ARG_PORT"] = "8080"
|
||||
hostname: Optional[str] = os.environ.get("LLAMA_ARG_HOST")
|
||||
port: Optional[str] = os.environ.get("LLAMA_ARG_PORT")
|
||||
assert hostname is not None
|
||||
assert port is not None
|
||||
address: str = f"http://{hostname}:{port}"
|
||||
logger.info(f"Starting the llama.cpp server under {address}...")
|
||||
|
||||
fout = open(path_log.format(port=port), "w") if path_log is not None else subprocess.DEVNULL
|
||||
process = subprocess.Popen([path_server], stdout=fout, stderr=subprocess.STDOUT)
|
||||
|
||||
n_failures: int = 0
|
||||
while True:
|
||||
try:
|
||||
sleep(1.0)
|
||||
exit_code = process.poll()
|
||||
if exit_code is not None:
|
||||
raise RuntimeError(f"llama.cpp server exited unexpectedly with exit code {exit_code}{path_log and f', see {path_log.format(port=port)}' or ''}")
|
||||
response = requests.get(f"{address}/health")
|
||||
if response.status_code == 200:
|
||||
break
|
||||
except requests.ConnectionError:
|
||||
n_failures += 1
|
||||
if n_failures >= 10:
|
||||
raise RuntimeError("llama.cpp server is not healthy after 10 seconds")
|
||||
|
||||
return {"process": process, "address": address, "fout": fout}
|
||||
|
||||
|
||||
def benchmark(
|
||||
path_server: str,
|
||||
path_log: Optional[str],
|
||||
prompt_source: str,
|
||||
n_prompts: int,
|
||||
n_predict: int,
|
||||
rng_seed: int,
|
||||
):
|
||||
external_server: bool = path_server.startswith("http://") or path_server.startswith("https://")
|
||||
if os.environ.get("LLAMA_ARG_N_PARALLEL") is None:
|
||||
logger.info("LLAMA_ARG_N_PARALLEL not explicitly set, using 32")
|
||||
os.environ["LLAMA_ARG_N_PARALLEL"] = "32"
|
||||
|
||||
parallel: int = int(os.environ.get("LLAMA_ARG_N_PARALLEL")) # type: ignore
|
||||
ds: Union[datasets.Dataset, None] = get_dataset(prompt_source, n_prompts, rng_seed)
|
||||
if not ds:
|
||||
logger.error("ERROR: get_dataset")
|
||||
exit(0)
|
||||
|
||||
res: Union[tuple[list[str], list[str]], None] = get_prompts_text(prompt_source, ds)
|
||||
if not res:
|
||||
logger.error("ERROR: get_prompts_text")
|
||||
exit(0)
|
||||
|
||||
prompts: Union[list[str], list[list[int]]] = res[0]
|
||||
answer: Union[list[str], list[list[int]]] = res[1]
|
||||
|
||||
logger.info(prompts)
|
||||
logger.info(f"external_server {external_server}")
|
||||
|
||||
server: Optional[dict] = None
|
||||
session = None
|
||||
try:
|
||||
server = get_server(path_server, path_log)
|
||||
server_address: str = server["address"]
|
||||
assert external_server == (server["process"] is None)
|
||||
|
||||
adapter = requests.adapters.HTTPAdapter(pool_connections=parallel, pool_maxsize=parallel) # type: ignore
|
||||
session = requests.Session()
|
||||
session.mount("http://", adapter)
|
||||
session.mount("https://", adapter)
|
||||
|
||||
data: list[dict] = []
|
||||
for p, a in zip(prompts, answer):
|
||||
data.append(
|
||||
{
|
||||
"prompt_source": prompt_source,
|
||||
"session": session,
|
||||
"server_address": server_address,
|
||||
"external_server": external_server,
|
||||
"prompt": p,
|
||||
"answer": a,
|
||||
"n_predict": n_predict,
|
||||
}
|
||||
)
|
||||
|
||||
logger.info("Starting the benchmark...\n")
|
||||
t0 = time()
|
||||
results: list[int] = thread_map(
|
||||
send_prompt, data, max_workers=parallel, chunksize=1
|
||||
)
|
||||
finally:
|
||||
if server is not None and server["process"] is not None:
|
||||
server["process"].terminate()
|
||||
server["process"].wait()
|
||||
if session is not None:
|
||||
session.close()
|
||||
|
||||
t1 = time()
|
||||
|
||||
correct: int = sum(results)
|
||||
total_questions: int = len(data)
|
||||
logger.info(f"llama-eval duration: {t1-t0:.2f} s")
|
||||
logger.info(f"{prompt_source} correct: {correct}")
|
||||
logger.info(f"{prompt_source} total_questions: {total_questions}")
|
||||
logger.info(f"{prompt_source} accuracy: {correct / total_questions}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Tool for benchmarking the throughput of the llama.cpp HTTP server. "
|
||||
"Results are printed to console and visualized as plots (saved to current working directory). "
|
||||
"To pass arguments such as the model path to the server, set the corresponding environment variables (see llama-server --help). "
|
||||
"The reported numbers are the speeds as observed by the Python script and may differ from the performance reported by the server, "
|
||||
"particularly when the server is fast vs. the network or Python script (e.g. when serving a very small model)."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--path_server",
|
||||
type=str,
|
||||
default="llama-server",
|
||||
help="Path to the llama.cpp server binary",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--path_log",
|
||||
type=str,
|
||||
default="server-bench-{port}.log",
|
||||
help="Path to the model to use for the benchmark",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--prompt_source",
|
||||
type=str,
|
||||
default="mmlu",
|
||||
help="How to get the prompts for the benchmark, either 'mmlu' for MMLU questions",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--n_prompts", type=int, default=100, help="Number of prompts to evaluate"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--rng_seed",
|
||||
type=int,
|
||||
default=42,
|
||||
help="Number to see rng (Used to select prompts from datasource)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--n_predict",
|
||||
type=int,
|
||||
default=2048,
|
||||
help="Max. number of tokens to predict per prompt",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
benchmark(**vars(args))
|
||||
Loading…
Reference in New Issue