multi source llama-eval

This commit is contained in:
gatbontonpc 2026-01-12 13:47:43 -05:00
parent 2357f6f193
commit f3a5b4ea72
1 changed file with 472 additions and 233 deletions


@@ -2,91 +2,43 @@
import re
import argparse
import json
import os
import random
import subprocess
from time import time
from typing import Union, Any, Mapping, cast
import datasets
import logging
import requests
from tqdm.contrib.concurrent import thread_map
from typing import Iterator
from abc import ABC, abstractmethod
from dataclasses import dataclass
logging.basicConfig(level=logging.INFO, format='%(message)s')
logger = logging.getLogger("llama-eval")
MATH_TEMPLATE = """
{question}
Do not include any explanation. Put your final answer within \\boxed{{}}.
"""
def format_multiple_choice(prompt: str, choices: list[str]):
    lines = [prompt]
    labels = [chr(ord("A") + i) for i in range(len(choices))]
    for l, c in zip(labels, choices):
        lines.append(f"({l}): {c.strip()}")
    lines.append(
        "Do not include any explanation. Answer with the corresponding option letter only"
    )
    lines.append(", ".join(labels))
    # Single braces here: this string is not passed through str.format().
    lines.append("Put your final answer within \\boxed{}.")
    return "\n".join(lines), labels
def extract_boxed_text(text: str) -> str:
pattern = r"boxed{(.*?)}|framebox{(.*?)}"
matches = re.findall(pattern, text, re.DOTALL)
logger.debug(matches)
@@ -95,222 +47,515 @@ def extract_boxed_text(text):
    for match in matches:
for group in match:
if group != "":
return group.split(",")[-1].strip()
logger.warning(
"Could not extract boxed text. Using last integer. Maybe expand context window"
)
pattern = r"\d+" # get the last integer if no pattern found
matches = re.findall(pattern, text, re.DOTALL)
if matches:
return matches[-1]
logger.warning("Could not extract boxed text. Maybe expand context window")
return ""
@dataclass(frozen=True)
class Case:
task: str
kind: str
case_id: str
prompt: str
gold: str
meta_data: dict[str, Any]
class TaskSpec(ABC):
name: str
kind: str
@abstractmethod
def load(self, limit, seed) -> datasets.Dataset:
pass
@abstractmethod
def iter_cases(self, limit: int, seed: int) -> Iterator[Case]:
pass
@staticmethod
@abstractmethod
def grade(case: Case, response: dict) -> dict[str, Any]:
pass
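# Each concrete task below follows the same flow: load() fetches and optionally
# subsamples the HF dataset, iter_cases() turns rows into Case objects carrying a
# ready-to-send prompt and a gold answer, and grade() scores one server response
# for one Case. MCTaskSpec and MathTaskSpec differ only in how grade() compares
# the extracted answer against the gold answer.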
class MCTaskSpec(TaskSpec):
@staticmethod
def grade(case: Case, response: dict) -> dict[str, Any]:
logger.debug(f"response {response}")
result = {
"task": case.task,
"case_id": case.case_id,
"correct": 0,
"pred": None,
"gold": case.gold,
"status": "ok",
}
try:
extracted_answer = extract_boxed_text(response["choices"][0]["text"])
except Exception as e:
result["status"] = "error"
logger.warning("ERROR: extract_boxed_text")
return result
if not extracted_answer:
result["status"] = "invalid"
logger.warning("INVALID: extract_boxed_text")
return result
logger.debug(f"extracted_answer {extracted_answer}")
logger.debug(f"data['answer'] {case.gold}")
result["pred"] = extracted_answer
result["correct"] = 1 if extracted_answer == case.gold else 0
return result
class MathTaskSpec(TaskSpec):
@staticmethod
def grade(case: Case, response: dict) -> dict[str, Any]:
logger.debug(f"response {response}")
result = {
"task": case.task,
"case_id": case.case_id,
"correct": 0,
"gold": case.gold,
"status": "ok",
"pred": None,
}
try:
extracted_answer = extract_boxed_text(response["choices"][0]["text"])
except Exception as e:
result["status"] = "error"
return result
source_answer = case.gold
        try:  # AIME and GSM8K answers are integers, so compare as ints
extracted_answer = int(extracted_answer)
source_answer = int(case.gold)
except (ValueError, TypeError):
result["status"] = "invalid"
return result
logger.debug(f"extracted_answer {extracted_answer}")
logger.debug(f"data['answer'] {case.gold}")
result["pred"] = extracted_answer
result["correct"] = 1 if extracted_answer == source_answer else 0
return result
class ARC_Task(MCTaskSpec):
def __init__(self):
self.name = "arc"
self.kind = "mc"
def load(self, limit, seed) -> datasets.Dataset:
ds = datasets.load_dataset("allenai/ai2_arc", "ARC-Challenge", split="test")
if limit:
ds = ds.shuffle(seed=seed)
ds = ds.select(range(min(limit, len(ds))))
return ds
def iter_cases(self, limit: int, seed: int) -> Iterator[Case]:
ds = self.load(limit, seed)
for i, doc in enumerate(ds):
doc = cast(Mapping[str, Any], doc)
prompt, labels = format_multiple_choice(
doc["question"], doc["choices"]["text"]
)
            # NOTE: some ARC answer keys are numeric ("1"-"4"); map the key through
            # the dataset's own label list so it matches the generated A/B/C/... labels.
            gold_idx = doc["choices"]["label"].index(doc["answerKey"])
            yield Case(
                task=self.name,
                kind=self.kind,
                case_id=f"ARC-Challenge:{i}",
                prompt=prompt,
                gold=labels[gold_idx],
                meta_data={"labels": labels},
            )
class WinoGrande_Task(MCTaskSpec):
def __init__(self):
self.name = "winogrande"
self.kind = "mc"
def load(self, limit, seed) -> datasets.Dataset:
ds = datasets.load_dataset(
"winogrande", "winogrande_debiased", split="validation"
)
if limit:
ds = ds.shuffle(seed=seed)
ds = ds.select(range(min(limit, len(ds))))
return ds
def iter_cases(self, limit: int, seed: int) -> Iterator[Case]:
ds = self.load(limit, seed)
for i, doc in enumerate(ds):
doc = cast(Mapping[str, Any], doc)
prompt, labels = format_multiple_choice(
doc["sentence"], [doc["option1"], doc["option2"]]
)
yield Case(
task=self.name,
kind=self.kind,
case_id=f"winogrande:{i}",
prompt=prompt,
gold=labels[int(doc["answer"]) - 1], # winogrande answers are 1 based
meta_data={"labels": labels},
)
class MMLU_Task(MCTaskSpec):
def __init__(self):
self.name = "mmlu"
self.kind = "mc"
def load(self, limit, seed) -> datasets.Dataset:
ds = datasets.load_dataset("cais/mmlu", "all", split="test")
if limit:
ds = ds.shuffle(seed=seed)
ds = ds.select(range(min(limit, len(ds))))
return ds
def iter_cases(self, limit: int, seed: int) -> Iterator[Case]:
ds = self.load(limit, seed)
for i, doc in enumerate(ds):
doc = cast(Mapping[str, Any], doc)
prompt, labels = format_multiple_choice(doc["question"], doc["choices"])
yield Case(
task=self.name,
kind=self.kind,
case_id=f"mmlu:{doc['subject']}:{i}",
prompt=prompt,
gold=labels[int(doc["answer"])],
meta_data={"subject": doc["subject"], "labels": labels},
)
class Hellaswag_Task(MCTaskSpec):
# Preprocess hellaswag
@staticmethod
def preprocess(text: str):
text = text.strip()
# NOTE: Brackets are artifacts of the WikiHow dataset portion of HellaSwag.
text = text.replace(" [title]", ". ")
text = re.sub("\\[.*?\\]", "", text)
        text = text.replace("  ", " ")  # collapse double spaces
return text
@staticmethod
def hellaswag_process_doc(doc: dict[str, str]):
ctx = doc["ctx_a"] + " " + doc["ctx_b"].capitalize()
question = Hellaswag_Task.preprocess(doc["activity_label"] + ": " + ctx)
proc_answers = [Hellaswag_Task.preprocess(answer) for answer in doc["endings"]]
prompt, labels = format_multiple_choice(question, proc_answers)
out_doc = {
"prompt": prompt,
"gold": labels[int(doc["label"])],
}
return out_doc
def __init__(self):
self.name = "hellaswag"
self.kind = "mc"
def load(self, limit, seed) -> datasets.Dataset:
ds = datasets.load_dataset("Rowan/hellaswag", split="validation")
if limit:
ds = ds.shuffle(seed=seed)
ds = ds.select(range(min(limit, len(ds))))
ds = ds.map(Hellaswag_Task.hellaswag_process_doc)
return ds
def iter_cases(self, limit: int, seed: int) -> Iterator[Case]:
ds = self.load(limit, seed)
for i, doc in enumerate(ds):
doc = cast(Mapping[str, Any], doc)
yield Case(
task=self.name,
kind=self.kind,
case_id=f"hellaswag:{i}",
prompt=doc["prompt"],
gold=doc["gold"],
meta_data={},
)
class Aime_Task(MathTaskSpec):
def __init__(self):
self.name = "aime"
self.kind = "math"
def load(self, limit, seed) -> datasets.Dataset:
ds = datasets.load_dataset("AI-MO/aimo-validation-aime", split="train")
if limit:
ds = ds.shuffle(seed=seed)
ds = ds.select(range(min(limit, len(ds))))
ds = ds.map(
            lambda ex: {
                "prompt": MATH_TEMPLATE.format(
                    question=ex["problem"],
                )
            }
        )
        return ds
def iter_cases(self, limit: int, seed: int) -> Iterator[Case]:
ds = self.load(limit, seed)
for i, doc in enumerate(ds):
doc = cast(Mapping[str, Any], doc)
yield Case(
task=self.name,
kind=self.kind,
case_id=f"aime:{i}",
prompt=doc["prompt"],
gold=doc["answer"],
meta_data={},
)
class Gsm8k_Task(MathTaskSpec):
def __init__(self):
self.name = "gsm8k"
self.kind = "math"
def load(self, limit, seed) -> datasets.Dataset:
ds = datasets.load_dataset("openai/gsm8k", "main", split="test")
if limit:
ds = ds.shuffle(seed=seed)
ds = ds.select(range(min(limit, len(ds))))
ds = ds.map(
lambda k: {
"prompt": MATH_TEMPLATE.format(
question=k["question"],
),
"gold": k["answer"].split("### ")[-1].rstrip(),
}
)
        return ds
def iter_cases(self, limit: int, seed: int) -> Iterator[Case]:
ds = self.load(limit, seed)
for i, doc in enumerate(ds):
doc = cast(Mapping[str, Any], doc)
yield Case(
task=self.name,
kind=self.kind,
case_id=f"gsm8k:{i}",
prompt=doc["prompt"],
gold=doc["gold"],
meta_data={},
)
TASK_DICT: dict[str, type[TaskSpec]] = {
"mmlu": MMLU_Task,
"aime": Aime_Task,
"gsm8k": Gsm8k_Task,
"hellaswag": Hellaswag_Task,
"arc": ARC_Task,
"winogrande": WinoGrande_Task,
}
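# Adding a new benchmark only requires a TaskSpec subclass and an entry here,
# e.g. (hypothetical): TASK_DICT["mytask"] = MyTask, after which it can be
# selected via --prompt_source mytask or picked up by "all".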
def build_request(case: Case, n_predict: int) -> dict[str, Any]:
    json_data = {
        "n_predict": n_predict,
        "max_tokens": n_predict,
        "temperature": 0,
        "prompt": case.prompt,
    }
    return json_data
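# Note: "n_predict" is llama-server's native generation-limit field and
# "max_tokens" the OpenAI-compatible one; sending both is presumably meant to cap
# generation length regardless of which field the completions endpoint honours.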
def send_prompt(
case: Case,
data: dict,
) -> dict[str, Union[str, int]]:
ret_err = {
"task": case.task,
"case_id": case.case_id,
"status": "error",
"correct": 0,
"gold": case.gold,
"pred": "",
"error": "",
}
session: requests.Session = data["session"]
server_address: str = data["server_address"]
task = TASK_DICT.get(case.task)
if task is None:
ret_err["error"] = f"unknown_task: {case.task}"
return ret_err
logger.debug(case.prompt)
json_data = build_request(case, data["n_predict"])
try:
response = session.post(f"{server_address}/v1/completions", json=json_data)
if response.ok:
res_json = response.json()
else:
ret_err["error"] = f"http_response: {response.status_code}"
logger.warning(ret_err["error"])
return ret_err
except Exception as e:
ret_err["error"] = f"http_exception: {e}"
logger.warning(ret_err["error"])
return ret_err
logger.debug(response.text)
return TASK_DICT[case.task].grade(case, res_json)
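# Every path through send_prompt() returns a flat result dict with at least
# "task", "case_id", "status" ("ok" / "invalid" / "error"), "correct" (0/1),
# "gold" and "pred", which is what aggregate_by_task() consumes below.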
def aggregate_by_task(results: list[dict[str, Any]]) -> dict[str, dict[str, int]]:
tmp = {
"total": 0,
"error": 0,
"invalid": 0,
"correct": 0,
}
agg: dict[str, dict[str, int]] = {}
for row in results:
d = agg.get(row["task"], tmp.copy())
d["total"] += 1
status = row["status"]
if status == "ok":
d["correct"] += row["correct"]
elif status == "invalid":
d["invalid"] += 1
elif status == "error":
d["error"] += 1
agg[row["task"]] = d
return agg
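# Example of the aggregated shape this returns (values are illustrative):
#   {"mmlu": {"total": 100, "error": 1, "invalid": 2, "correct": 63}, ...}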
def print_summary(pertask_results: dict[str, dict[str, int]]):
print("\n=== llama-eval suite summary ===")
print(
f"{'Task':<15} {'Acc':>8} {'Correct':>8} {'Total':>8} {'Invalid':>8} {'Error':>8}"
)
print("-" * 65)
suite_total = 0
suite_correct = 0
for task in sorted(pertask_results.keys()):
stats = pertask_results[task]
total = stats["total"]
correct = stats["correct"]
invalid = stats["invalid"]
error = stats["error"]
acc = (correct / total) if total > 0 else 0.0
print(
f"{task:<15} "
f"{acc:8.3f} "
f"{correct:8d} "
f"{total:8d} "
f"{invalid:8d} "
f"{error:8d}"
)
suite_total += total
suite_correct += correct
# Overall summary
print("-" * 65)
suite_acc = (suite_correct / suite_total) if suite_total > 0 else 0.0
print(
f"{'ALL':<15} " f"{suite_acc:8.3f} " f"{suite_correct:8d} " f"{suite_total:8d}"
)
def benchmark(
path_server: str,
prompt_source: str,
n_prompts: int,
n_predict: int,
rng_seed: int,
):
if not path_server.startswith("http://") and not path_server.startswith("https://"):
logger.error("ERROR: malformed server path")
return
if os.environ.get("LLAMA_ARG_N_PARALLEL") is None:
logger.info("LLAMA_ARG_N_PARALLEL not explicitly set, using 32")
os.environ["LLAMA_ARG_N_PARALLEL"] = "32"
parallel: int = int(os.environ.get("LLAMA_ARG_N_PARALLEL")) # type: ignore
task_queue: set[TaskSpec] = set()
for src in prompt_source.split(","):
if src == "all":
for v in TASK_DICT.values():
task_queue.add(v())
break
task_queue.add(TASK_DICT[src]())
session = None
try:
server_address: str = path_server
adapter = requests.adapters.HTTPAdapter(pool_connections=parallel, pool_maxsize=parallel) # type: ignore
session = requests.Session()
session.mount("http://", adapter)
session.mount("https://", adapter)
cases: list[Case] = []
data: list[dict] = []
for task in task_queue:
for case in task.iter_cases(n_prompts, rng_seed):
cases.append(case)
data.append(
{
"prompt_source": prompt_source,
"session": session,
"server_address": server_address,
"n_predict": n_predict,
}
)
logger.info("Starting the benchmark...\n")
t0 = time()
results: list[dict[str, Union[str, int]]] = thread_map(
send_prompt,
cases,
data,
max_workers=parallel,
chunksize=1,
)
finally:
if session is not None:
session.close()
t1 = time()
logger.info(f"\nllama-eval duration: {t1-t0:.2f} s")
correct: int = sum(results)
total_questions: int = len(data)
logger.info(f"llama-eval duration: {t1-t0:.2f} s")
logger.info(f"{prompt_source} correct: {correct}")
logger.info(f"{prompt_source} total_questions: {total_questions}")
logger.info(f"{prompt_source} accuracy: {correct / total_questions}")
pertask_results = aggregate_by_task(results)
print_summary(pertask_results)
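# Example invocation against an already running llama-server instance (script
# filename assumed; flag names match the argparse definitions below):
#   python llama-eval.py --path_server http://localhost:8033 --prompt_source mmlu,gsm8k --n_prompts 50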
if __name__ == "__main__":
@@ -324,23 +569,17 @@ if __name__ == "__main__":
parser.add_argument(
"--path_server",
type=str,
default="llama-server",
help="Path to the llama.cpp server binary",
)
parser.add_argument(
"--path_log",
type=str,
default="server-bench-{port}.log",
help="Path to the model to use for the benchmark",
default="http://localhost:8033",
help="llama-server url",
)
parser.add_argument(
"--prompt_source",
type=str,
default="mmlu",
help="How to get the prompts for the benchmark, either 'mmlu' for MMLU questions",
help=f"Eval types supported: all,{TASK_DICT.keys()}",
)
parser.add_argument(
"--n_prompts", type=int, default=100, help="Number of prompts to evaluate"
"--n_prompts", type=int, default=None, help="Number of prompts to evaluate"
)
parser.add_argument(
"--rng_seed",