multi source llama-eval

gatbontonpc 2026-01-12 13:47:43 -05:00
parent 2357f6f193
commit f3a5b4ea72
1 changed file with 472 additions and 233 deletions


@@ -2,91 +2,43 @@
import re
import argparse
import os
from time import time
from typing import Any, Iterator, Mapping, Union, cast

import datasets
import logging
import requests
from tqdm.contrib.concurrent import thread_map
from abc import ABC, abstractmethod
from dataclasses import dataclass

logging.basicConfig(level=logging.INFO, format='%(message)s')
logger = logging.getLogger("llama-eval")

MATH_TEMPLATE = """
{question}
Put your final answer within \\boxed{{}}.
"""

def format_multiple_choice(prompt: str, choices: list[str]):
    lines = [prompt]

    labels = [chr(ord("A") + i) for i in range(len(choices))]
    for label, choice in zip(labels, choices):
        lines.append(f"({label}): {choice.strip()}")
    lines.append(
        "Do not include any explanation. Answer with the corresponding option letter only"
    )
    lines.append(", ".join(labels))
    # NOTE: this is a plain string, not a .format() template, so the braces
    # must not be doubled here (doubled braces would appear literally).
    lines.append("Put your final answer within \\boxed{}.")
    return "\n".join(lines), labels

def extract_boxed_text(text: str) -> str:
    pattern = r"boxed{(.*?)}|framebox{(.*?)}"
    matches = re.findall(pattern, text, re.DOTALL)
    logger.debug(matches)
@@ -95,222 +47,515 @@ def extract_boxed_text(text):
            for group in match:
                if group != "":
                    return group.split(",")[-1].strip()
    logger.warning("Could not extract boxed text. Maybe expand context window")
    return ""

@dataclass(frozen=True)
class Case:
    task: str
    kind: str
    case_id: str
    prompt: str
    gold: str
    meta_data: dict[str, Any]


class TaskSpec(ABC):
    name: str
    kind: str

    @abstractmethod
    def load(self, limit, seed) -> datasets.Dataset:
        pass

    @abstractmethod
    def iter_cases(self, limit: int, seed: int) -> Iterator[Case]:
        pass

    @staticmethod
    @abstractmethod
    def grade(case: Case, response: dict) -> dict[str, Any]:
        pass
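
# A new benchmark is added by subclassing MCTaskSpec or MathTaskSpec below,
# implementing load() and iter_cases(), and registering the class in
# TASK_DICT; grade() is inherited from the kind-specific base class.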

class MCTaskSpec(TaskSpec):
    @staticmethod
    def grade(case: Case, response: dict) -> dict[str, Any]:
        logger.debug(f"response {response}")
        result = {
            "task": case.task,
            "case_id": case.case_id,
            "correct": 0,
            "pred": None,
            "gold": case.gold,
            "status": "ok",
        }
        try:
            extracted_answer = extract_boxed_text(response["choices"][0]["text"])
        except Exception as e:
            result["status"] = "error"
            logger.warning(f"ERROR: extract_boxed_text: {e}")
            return result
        if not extracted_answer:
            result["status"] = "invalid"
            logger.warning("INVALID: extract_boxed_text")
            return result
        logger.debug(f"extracted_answer {extracted_answer}")
        logger.debug(f"case.gold {case.gold}")
        result["pred"] = extracted_answer
        result["correct"] = 1 if extracted_answer == case.gold else 0
        return result

class MathTaskSpec(TaskSpec):
    @staticmethod
    def grade(case: Case, response: dict) -> dict[str, Any]:
        logger.debug(f"response {response}")
        result = {
            "task": case.task,
            "case_id": case.case_id,
            "correct": 0,
            "gold": case.gold,
            "status": "ok",
            "pred": None,
        }
        try:
            extracted_answer = extract_boxed_text(response["choices"][0]["text"])
        except Exception as e:
            result["status"] = "error"
            logger.warning(f"ERROR: extract_boxed_text: {e}")
            return result
        source_answer = case.gold
        try:  # AIME and GSM8K answers are integers, so compare as integers
            extracted_answer = int(extracted_answer)
            source_answer = int(case.gold)
        except (ValueError, TypeError):
            result["status"] = "invalid"
            return result
        logger.debug(f"extracted_answer {extracted_answer}")
        logger.debug(f"case.gold {case.gold}")
        result["pred"] = extracted_answer
        result["correct"] = 1 if extracted_answer == source_answer else 0
        return result
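
# A graded row is a flat dict; an illustrative example:
#   {"task": "gsm8k", "case_id": "gsm8k:0", "correct": 1, "gold": "42",
#    "status": "ok", "pred": 42}
# aggregate_by_task() below folds these rows into per-task counts.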

class ARC_Task(MCTaskSpec):
    def __init__(self):
        self.name = "arc"
        self.kind = "mc"

    def load(self, limit, seed) -> datasets.Dataset:
        ds = datasets.load_dataset("allenai/ai2_arc", "ARC-Challenge", split="test")
        if limit:
            ds = ds.shuffle(seed=seed)
            ds = ds.select(range(min(limit, len(ds))))
        return ds

    def iter_cases(self, limit: int, seed: int) -> Iterator[Case]:
        ds = self.load(limit, seed)
        for i, doc in enumerate(ds):
            doc = cast(Mapping[str, Any], doc)
            prompt, labels = format_multiple_choice(
                doc["question"], doc["choices"]["text"]
            )
            # NOTE: a small number of ARC items use numeric answer keys
            # ("1"-"4"); those can never match the letter labels generated by
            # format_multiple_choice and will always grade as incorrect.
            yield Case(
                task=self.name,
                kind=self.kind,
                case_id=f"ARC-Challenge:{i}",
                prompt=prompt,
                gold=doc["answerKey"],
                meta_data={"labels": labels},
            )

class WinoGrande_Task(MCTaskSpec):
    def __init__(self):
        self.name = "winogrande"
        self.kind = "mc"

    def load(self, limit, seed) -> datasets.Dataset:
        ds = datasets.load_dataset(
            "winogrande", "winogrande_debiased", split="validation"
        )
        if limit:
            ds = ds.shuffle(seed=seed)
            ds = ds.select(range(min(limit, len(ds))))
        return ds

    def iter_cases(self, limit: int, seed: int) -> Iterator[Case]:
        ds = self.load(limit, seed)
        for i, doc in enumerate(ds):
            doc = cast(Mapping[str, Any], doc)
            prompt, labels = format_multiple_choice(
                doc["sentence"], [doc["option1"], doc["option2"]]
            )
            yield Case(
                task=self.name,
                kind=self.kind,
                case_id=f"winogrande:{i}",
                prompt=prompt,
                gold=labels[int(doc["answer"]) - 1],  # winogrande answers are 1-based
                meta_data={"labels": labels},
            )

class MMLU_Task(MCTaskSpec):
    def __init__(self):
        self.name = "mmlu"
        self.kind = "mc"

    def load(self, limit, seed) -> datasets.Dataset:
        ds = datasets.load_dataset("cais/mmlu", "all", split="test")
        if limit:
            ds = ds.shuffle(seed=seed)
            ds = ds.select(range(min(limit, len(ds))))
        return ds

    def iter_cases(self, limit: int, seed: int) -> Iterator[Case]:
        ds = self.load(limit, seed)
        for i, doc in enumerate(ds):
            doc = cast(Mapping[str, Any], doc)
            prompt, labels = format_multiple_choice(doc["question"], doc["choices"])
            yield Case(
                task=self.name,
                kind=self.kind,
                case_id=f"mmlu:{doc['subject']}:{i}",
                prompt=prompt,
                gold=labels[int(doc["answer"])],
                meta_data={"subject": doc["subject"], "labels": labels},
            )

class Hellaswag_Task(MCTaskSpec):
    # Preprocess hellaswag
    @staticmethod
    def preprocess(text: str):
        text = text.strip()
        # NOTE: Brackets are artifacts of the WikiHow dataset portion of HellaSwag.
        text = text.replace(" [title]", ". ")
        text = re.sub("\\[.*?\\]", "", text)
        text = text.replace("  ", " ")  # collapse double spaces
        return text

    @staticmethod
    def hellaswag_process_doc(doc: dict[str, str]):
        ctx = doc["ctx_a"] + " " + doc["ctx_b"].capitalize()
        question = Hellaswag_Task.preprocess(doc["activity_label"] + ": " + ctx)
        proc_answers = [Hellaswag_Task.preprocess(answer) for answer in doc["endings"]]
        prompt, labels = format_multiple_choice(question, proc_answers)
        out_doc = {
            "prompt": prompt,
            "gold": labels[int(doc["label"])],
        }
        return out_doc

    def __init__(self):
        self.name = "hellaswag"
        self.kind = "mc"

    def load(self, limit, seed) -> datasets.Dataset:
        ds = datasets.load_dataset("Rowan/hellaswag", split="validation")
        if limit:
            ds = ds.shuffle(seed=seed)
            ds = ds.select(range(min(limit, len(ds))))
        ds = ds.map(Hellaswag_Task.hellaswag_process_doc)
        return ds

    def iter_cases(self, limit: int, seed: int) -> Iterator[Case]:
        ds = self.load(limit, seed)
        for i, doc in enumerate(ds):
            doc = cast(Mapping[str, Any], doc)
            yield Case(
                task=self.name,
                kind=self.kind,
                case_id=f"hellaswag:{i}",
                prompt=doc["prompt"],
                gold=doc["gold"],
                meta_data={},
            )
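
# Illustrative preprocess() behavior on a hypothetical input:
#   "Mix the flour [title] Then [step] knead the dough."
# becomes
#   "Mix the flour. Then knead the dough."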

class Aime_Task(MathTaskSpec):
    def __init__(self):
        self.name = "aime"
        self.kind = "math"

    def load(self, limit, seed) -> datasets.Dataset:
        ds = datasets.load_dataset("AI-MO/aimo-validation-aime", split="train")
        if limit:
            ds = ds.shuffle(seed=seed)
            ds = ds.select(range(min(limit, len(ds))))
        ds = ds.map(
            lambda ex: {
                "prompt": MATH_TEMPLATE.format(
                    question=ex["problem"],
                )
            }
        )
        return ds

    def iter_cases(self, limit: int, seed: int) -> Iterator[Case]:
        ds = self.load(limit, seed)
        for i, doc in enumerate(ds):
            doc = cast(Mapping[str, Any], doc)
            yield Case(
                task=self.name,
                kind=self.kind,
                case_id=f"aime:{i}",
                prompt=doc["prompt"],
                gold=doc["answer"],
                meta_data={},
            )

class Gsm8k_Task(MathTaskSpec):
    def __init__(self):
        self.name = "gsm8k"
        self.kind = "math"

    def load(self, limit, seed) -> datasets.Dataset:
        ds = datasets.load_dataset("openai/gsm8k", "main", split="test")
        if limit:
            ds = ds.shuffle(seed=seed)
            ds = ds.select(range(min(limit, len(ds))))
        ds = ds.map(
            lambda k: {
                "prompt": MATH_TEMPLATE.format(
                    question=k["question"],
                ),
                # GSM8K delimits the final answer with "#### "; strip any
                # thousands separators so int() conversion in grade() works.
                "gold": k["answer"].split("#### ")[-1].strip().replace(",", ""),
            }
        )
        return ds

    def iter_cases(self, limit: int, seed: int) -> Iterator[Case]:
        ds = self.load(limit, seed)
        for i, doc in enumerate(ds):
            doc = cast(Mapping[str, Any], doc)
            yield Case(
                task=self.name,
                kind=self.kind,
                case_id=f"gsm8k:{i}",
                prompt=doc["prompt"],
                gold=doc["gold"],
                meta_data={},
            )

TASK_DICT: dict[str, type[TaskSpec]] = {
    "mmlu": MMLU_Task,
    "aime": Aime_Task,
    "gsm8k": Gsm8k_Task,
    "hellaswag": Hellaswag_Task,
    "arc": ARC_Task,
    "winogrande": WinoGrande_Task,
}

def build_request(case: Case, n_predict: int) -> dict[str, Any]:
    json_data = {
        "n_predict": n_predict,
        "max_tokens": n_predict,
        "temperature": 0,
        "prompt": case.prompt,
    }
    return json_data
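
# Both n_predict (llama.cpp's native completion-length field) and its
# OpenAI-compatible alias max_tokens are sent, so the request should work
# against either style of /v1/completions endpoint.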

def send_prompt(
    case: Case,
    data: dict,
) -> dict[str, Union[str, int]]:
    ret_err = {
        "task": case.task,
        "case_id": case.case_id,
        "status": "error",
        "correct": 0,
        "gold": case.gold,
        "pred": "",
        "error": "",
    }
    session: requests.Session = data["session"]
    server_address: str = data["server_address"]
    task = TASK_DICT.get(case.task)
    if task is None:
        ret_err["error"] = f"unknown_task: {case.task}"
        return ret_err
    logger.debug(case.prompt)

    json_data = build_request(case, data["n_predict"])
    try:
        response = session.post(f"{server_address}/v1/completions", json=json_data)
        if response.ok:
            res_json = response.json()
        else:
            ret_err["error"] = f"http_response: {response.status_code}"
            logger.warning(ret_err["error"])
            return ret_err
    except Exception as e:
        ret_err["error"] = f"http_exception: {e}"
        logger.warning(ret_err["error"])
        return ret_err
    logger.debug(response.text)
    return task.grade(case, res_json)
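
# A failed request yields a ret_err row with status "error", so network or
# server problems are counted in the summary's Error column instead of
# aborting the whole run.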
return {"process": process, "address": address, "fout": fout} def aggregate_by_task(results: list[dict[str, Any]]) -> dict[str, dict[str, int]]:
tmp = {
"total": 0,
"error": 0,
"invalid": 0,
"correct": 0,
}
agg: dict[str, dict[str, int]] = {}
for row in results:
d = agg.get(row["task"], tmp.copy())
d["total"] += 1
status = row["status"]
if status == "ok":
d["correct"] += row["correct"]
elif status == "invalid":
d["invalid"] += 1
elif status == "error":
d["error"] += 1
agg[row["task"]] = d
return agg

def print_summary(pertask_results: dict[str, dict[str, int]]):
    print("\n=== llama-eval suite summary ===")
    print(
        f"{'Task':<15} {'Acc':>8} {'Correct':>8} {'Total':>8} {'Invalid':>8} {'Error':>8}"
    )
    print("-" * 65)
    suite_total = 0
    suite_correct = 0
    for task in sorted(pertask_results.keys()):
        stats = pertask_results[task]
        total = stats["total"]
        correct = stats["correct"]
        invalid = stats["invalid"]
        error = stats["error"]
        acc = (correct / total) if total > 0 else 0.0
        print(
            f"{task:<15} "
            f"{acc:8.3f} "
            f"{correct:8d} "
            f"{total:8d} "
            f"{invalid:8d} "
            f"{error:8d}"
        )
        suite_total += total
        suite_correct += correct

    # Overall summary
    print("-" * 65)
    suite_acc = (suite_correct / suite_total) if suite_total > 0 else 0.0
    print(
        f"{'ALL':<15} " f"{suite_acc:8.3f} " f"{suite_correct:8d} " f"{suite_total:8d}"
    )
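
# Example summary (illustrative numbers, spacing approximate):
#   === llama-eval suite summary ===
#   Task                 Acc  Correct    Total  Invalid    Error
#   -----------------------------------------------------------------
#   gsm8k              0.860       86      100        2        0
#   mmlu               0.612       61      100        1        1
#   -----------------------------------------------------------------
#   ALL                0.735      147      200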

def benchmark(
    path_server: str,
    prompt_source: str,
    n_prompts: int,
    n_predict: int,
    rng_seed: int,
):
    if not path_server.startswith("http://") and not path_server.startswith("https://"):
        logger.error("ERROR: malformed server path")
        return

    if os.environ.get("LLAMA_ARG_N_PARALLEL") is None:
        logger.info("LLAMA_ARG_N_PARALLEL not explicitly set, using 32")
        os.environ["LLAMA_ARG_N_PARALLEL"] = "32"
    parallel: int = int(os.environ.get("LLAMA_ARG_N_PARALLEL"))  # type: ignore

    task_queue: set[TaskSpec] = set()
    for src in prompt_source.split(","):
        if src == "all":
            for v in TASK_DICT.values():
                task_queue.add(v())
            break
        task_queue.add(TASK_DICT[src]())
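
    # "all" expands the queue to every registered task; otherwise each
    # comma-separated source must be a key of TASK_DICT (an unknown key
    # raises KeyError).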

    session = None
    try:
        server_address: str = path_server

        adapter = requests.adapters.HTTPAdapter(pool_connections=parallel, pool_maxsize=parallel)  # type: ignore
        session = requests.Session()
        session.mount("http://", adapter)
        session.mount("https://", adapter)

        cases: list[Case] = []
        data: list[dict] = []
        for task in task_queue:
            for case in task.iter_cases(n_prompts, rng_seed):
                cases.append(case)
                data.append(
                    {
                        "prompt_source": prompt_source,
                        "session": session,
                        "server_address": server_address,
                        "n_predict": n_predict,
                    }
                )

        logger.info("Starting the benchmark...\n")
        t0 = time()
        results: list[dict[str, Union[str, int]]] = thread_map(
            send_prompt,
            cases,
            data,
            max_workers=parallel,
            chunksize=1,
        )
    finally:
        if session is not None:
            session.close()
    t1 = time()

    logger.info(f"\nllama-eval duration: {t1-t0:.2f} s")
    pertask_results = aggregate_by_task(results)
    print_summary(pertask_results)

if __name__ == "__main__":
@@ -324,23 +569,17 @@ if __name__ == "__main__":
    parser.add_argument(
        "--path_server",
        type=str,
        default="http://localhost:8033",
        help="llama-server url",
    )
    parser.add_argument(
        "--prompt_source",
        type=str,
        default="mmlu",
        help=f"Eval types supported: all,{','.join(TASK_DICT.keys())}",
    )
    parser.add_argument(
        "--n_prompts",
        type=int,
        default=None,
        help="Number of prompts to evaluate per task (default: all)",
    )
    parser.add_argument(
        "--rng_seed",