From f3a5b4ea728fb14bd0b2e5394f82c90d315e65b5 Mon Sep 17 00:00:00 2001
From: gatbontonpc
Date: Mon, 12 Jan 2026 13:47:43 -0500
Subject: [PATCH] multi source llama-eval

---
 examples/llama-eval/llama-eval.py | 705 ++++++++++++++++++++----------
 1 file changed, 472 insertions(+), 233 deletions(-)

diff --git a/examples/llama-eval/llama-eval.py b/examples/llama-eval/llama-eval.py
index 10ec766fe6..411d0adbab 100644
--- a/examples/llama-eval/llama-eval.py
+++ b/examples/llama-eval/llama-eval.py
@@ -2,91 +2,43 @@
 import re
 import argparse
-import json
 import os
-import random
-import subprocess
-from time import sleep, time
-from typing import Optional, Union
+from time import time
+from typing import Union, Any, Mapping, cast

 import datasets
 import logging
 import requests
 from tqdm.contrib.concurrent import thread_map
 from typing import Iterator
-from abc import ABC
+from abc import ABC, abstractmethod
+from dataclasses import dataclass

 logging.basicConfig(level=logging.INFO, format='%(message)s')
 logger = logging.getLogger("llama-eval")

-
 MATH_TEMPLATE = """
 {question}
-Put your final answer within \\boxed{{}}.
+Do not include any explanation. Put your final answer within \\boxed{{}}.
 """
-
-MC_FROM_INT = {
-    0: "A",
-    1: "B",
-    2: "C",
-    3: "D",
-}
-

 def format_multiple_choice(prompt: str, choices: list[str]):
-    QUERY_TEMPLATE_MULTICHOICE = """
-    {question}
+    lines = [prompt]

-    (A) {A}
-    (B) {B}
-    (C) {C}
-    (D) {D}
-
-    Express your final answer as the corresponding option 'A', 'B', 'C', or 'D'. Put your final answer within \\boxed{{}}.
-
-    """.strip()
-    A_str = choices[0]
-    B_str = choices[1]
-    C_str = choices[2]
-    D_str = choices[3]
-    query = QUERY_TEMPLATE_MULTICHOICE.format(
-        question=prompt, A=A_str, B=B_str, C=C_str, D=D_str
+    labels = [chr(ord("A") + i) for i in range(len(choices))]
+    for l, c in zip(labels, choices):
+        lines.append(f"({l}): {c.strip()}")
+    lines.append(
+        "Do not include any explanation. Answer with the corresponding option letter only"
     )
-    return query
+    lines.append(", ".join(labels))
+    lines.append("Put your final answer within \\boxed{{}}.")
+
+    return "\n".join(lines), labels


-# Preprocess hellaswag
-def preprocess(text):
-    text = text.strip()
-    # NOTE: Brackets are artifacts of the WikiHow dataset portion of HellaSwag.
-    text = text.replace(" [title]", ". ")
-    text = re.sub("\\[.*?\\]", "", text)
-    text = text.replace("  ", " ")
-    return text
-
-
-def hellaswag_process_doc(doc):
-    ctx = doc["ctx_a"] + " " + doc["ctx_b"].capitalize()
-    question = preprocess(doc["activity_label"] + ": " + ctx)
-    proc_answers = [preprocess(answer) for answer in doc["endings"]]
-    prompt = format_multiple_choice(question, proc_answers)
-    out_doc = {
-        "prompt": prompt,
-        "gold": MC_FROM_INT[int(doc["label"])],
-    }
-    return out_doc
-
-
-def mmlu_process_doc(doc):
-    prompt = format_multiple_choice(doc["question"], doc["choices"])
-    out_doc = {
-        "prompt": prompt,
-        "gold": MC_FROM_INT[int(doc["answer"])],
-    }
-    return out_doc
-
-
-def extract_boxed_text(text):
+def extract_boxed_text(text: str) -> str:
     pattern = r"boxed{(.*?)}|framebox{(.*?)}"
     matches = re.findall(pattern, text, re.DOTALL)
     logger.debug(matches)
@@ -95,222 +47,515 @@ def extract_boxed_text(text):
         for group in match:
             if group != "":
                 return group.split(",")[-1].strip()
-    logger.warning(
-        "Could not extract boxed text. Using last integer. Maybe expand context window"
-    )
-    pattern = r"\d+"  # get the last integer if no pattern found
-    matches = re.findall(pattern, text, re.DOTALL)
-    if matches:
-        return matches[-1]
+    logger.warning("Could not extract boxed text. Maybe expand context window")
     return ""


-def get_prompts_text(
-    dataset_name: str, ds: datasets.Dataset
-) -> Optional[tuple[list[str], list[str]]]:
-    ret = []
-    if dataset_name.lower() == "mmlu":
-        ds = ds.map(mmlu_process_doc)
-        ret = ds["prompt"], ds["gold"]
-    elif dataset_name.lower() == "hellaswag":
-        ds = ds.map(hellaswag_process_doc)
-        ret = ds["prompt"], ds["gold"]
-    elif dataset_name.lower() == "aime":
+@dataclass(frozen=True)
+class Case:
+    task: str
+    kind: str
+    case_id: str
+    prompt: str
+    gold: str
+    meta_data: dict[str, Any]
+
+
+class TaskSpec(ABC):
+    name: str
+    kind: str
+
+    @abstractmethod
+    def load(self, limit, seed) -> datasets.Dataset:
+        pass
+
+    @abstractmethod
+    def iter_cases(self, limit: int, seed: int) -> Iterator[Case]:
+        pass
+
+    @staticmethod
+    @abstractmethod
+    def grade(case: Case, response: dict) -> dict[str, Any]:
+        pass
+
+
+class MCTaskSpec(TaskSpec):
+    @staticmethod
+    def grade(case: Case, response: dict) -> dict[str, Any]:
+        logger.debug(f"response {response}")
+        result = {
+            "task": case.task,
+            "case_id": case.case_id,
+            "correct": 0,
+            "pred": None,
+            "gold": case.gold,
+            "status": "ok",
+        }
+
+        try:
+            extracted_answer = extract_boxed_text(response["choices"][0]["text"])
+        except Exception as e:
+            result["status"] = "error"
+            logger.warning("ERROR: extract_boxed_text")
+
+            return result
+
+        if not extracted_answer:
+            result["status"] = "invalid"
+            logger.warning("INVALID: extract_boxed_text")
+            return result
+
+        logger.debug(f"extracted_answer {extracted_answer}")
+        logger.debug(f"data['answer'] {case.gold}")
+        result["pred"] = extracted_answer
+        result["correct"] = 1 if extracted_answer == case.gold else 0
+
+        return result
+
+
+class MathTaskSpec(TaskSpec):
+
+    @staticmethod
+    def grade(case: Case, response: dict) -> dict[str, Any]:
+        logger.debug(f"response {response}")
+        result = {
+            "task": case.task,
+            "case_id": case.case_id,
+            "correct": 0,
+            "gold": case.gold,
+            "status": "ok",
+            "pred": None,
+        }
+
+        try:
+            extracted_answer = extract_boxed_text(response["choices"][0]["text"])
+        except Exception as e:
+            result["status"] = "error"
+            return result
+
+        source_answer = case.gold
+        try:  # All AIME answers are integers, so we convert the extracted answer to an integer
+            extracted_answer = int(extracted_answer)
+            source_answer = int(case.gold)
+        except (ValueError, TypeError):
+            result["status"] = "invalid"
+            return result
+
+        logger.debug(f"extracted_answer {extracted_answer}")
+        logger.debug(f"data['answer'] {case.gold}")
+        result["pred"] = extracted_answer
+        result["correct"] = 1 if extracted_answer == source_answer else 0
+
+        return result
+
+
+class ARC_Task(MCTaskSpec):
+
+    def __init__(self):
+        self.name = "arc"
+        self.kind = "mc"
+
+    def load(self, limit, seed) -> datasets.Dataset:
+        ds = datasets.load_dataset("allenai/ai2_arc", "ARC-Challenge", split="test")
+        if limit:
+            ds = ds.shuffle(seed=seed)
+            ds = ds.select(range(min(limit, len(ds))))
+        return ds
+
+    def iter_cases(self, limit: int, seed: int) -> Iterator[Case]:
+        ds = self.load(limit, seed)
+
+        for i, doc in enumerate(ds):
+            doc = cast(Mapping[str, Any], doc)
+
+            prompt, labels = format_multiple_choice(
+                doc["question"], doc["choices"]["text"]
+            )
+            yield Case(
+                task=self.name,
+                kind=self.kind,
+                case_id=f"ARC-Challenge:{i}",
+                prompt=prompt,
gold=doc["answerKey"], + meta_data={"labels": labels}, + ) + + +class WinoGrande_Task(MCTaskSpec): + + def __init__(self): + self.name = "winogrande" + self.kind = "mc" + + def load(self, limit, seed) -> datasets.Dataset: + ds = datasets.load_dataset( + "winogrande", "winogrande_debiased", split="validation" + ) + if limit: + ds = ds.shuffle(seed=seed) + ds = ds.select(range(min(limit, len(ds)))) + return ds + + def iter_cases(self, limit: int, seed: int) -> Iterator[Case]: + ds = self.load(limit, seed) + + for i, doc in enumerate(ds): + doc = cast(Mapping[str, Any], doc) + + prompt, labels = format_multiple_choice( + doc["sentence"], [doc["option1"], doc["option2"]] + ) + yield Case( + task=self.name, + kind=self.kind, + case_id=f"winogrande:{i}", + prompt=prompt, + gold=labels[int(doc["answer"]) - 1], # winogrande answers are 1 based + meta_data={"labels": labels}, + ) + + +class MMLU_Task(MCTaskSpec): + + def __init__(self): + self.name = "mmlu" + self.kind = "mc" + + def load(self, limit, seed) -> datasets.Dataset: + ds = datasets.load_dataset("cais/mmlu", "all", split="test") + if limit: + ds = ds.shuffle(seed=seed) + ds = ds.select(range(min(limit, len(ds)))) + return ds + + def iter_cases(self, limit: int, seed: int) -> Iterator[Case]: + ds = self.load(limit, seed) + + for i, doc in enumerate(ds): + doc = cast(Mapping[str, Any], doc) + + prompt, labels = format_multiple_choice(doc["question"], doc["choices"]) + yield Case( + task=self.name, + kind=self.kind, + case_id=f"mmlu:{doc['subject']}:{i}", + prompt=prompt, + gold=labels[int(doc["answer"])], + meta_data={"subject": doc["subject"], "labels": labels}, + ) + + +class Hellaswag_Task(MCTaskSpec): + + # Preprocess hellaswag + @staticmethod + def preprocess(text: str): + text = text.strip() + # NOTE: Brackets are artifacts of the WikiHow dataset portion of HellaSwag. + text = text.replace(" [title]", ". 
") + text = re.sub("\\[.*?\\]", "", text) + text = text.replace(" ", " ") + return text + + @staticmethod + def hellaswag_process_doc(doc: dict[str, str]): + ctx = doc["ctx_a"] + " " + doc["ctx_b"].capitalize() + question = Hellaswag_Task.preprocess(doc["activity_label"] + ": " + ctx) + proc_answers = [Hellaswag_Task.preprocess(answer) for answer in doc["endings"]] + prompt, labels = format_multiple_choice(question, proc_answers) + out_doc = { + "prompt": prompt, + "gold": labels[int(doc["label"])], + } + return out_doc + + def __init__(self): + self.name = "hellaswag" + self.kind = "mc" + + def load(self, limit, seed) -> datasets.Dataset: + ds = datasets.load_dataset("Rowan/hellaswag", split="validation") + if limit: + ds = ds.shuffle(seed=seed) + ds = ds.select(range(min(limit, len(ds)))) + ds = ds.map(Hellaswag_Task.hellaswag_process_doc) + + return ds + + def iter_cases(self, limit: int, seed: int) -> Iterator[Case]: + ds = self.load(limit, seed) + for i, doc in enumerate(ds): + doc = cast(Mapping[str, Any], doc) + yield Case( + task=self.name, + kind=self.kind, + case_id=f"hellaswag:{i}", + prompt=doc["prompt"], + gold=doc["gold"], + meta_data={}, + ) + + +class Aime_Task(MathTaskSpec): + + def __init__(self): + self.name = "aime" + self.kind = "math" + + def load(self, limit, seed) -> datasets.Dataset: + ds = datasets.load_dataset("AI-MO/aimo-validation-aime", split="train") + + if limit: + ds = ds.shuffle(seed=seed) + ds = ds.select(range(min(limit, len(ds)))) + ds = ds.map( - lambda k: { + lambda ex: { "prompt": MATH_TEMPLATE.format( - question=k["problem"], + question=ex["problem"], ) } ) - ret = ds["prompt"], ds["answer"] - elif dataset_name.lower() == "gsm8k": - ds = ds.map(lambda k: {"prompt": MATH_TEMPLATE.format(question=k["question"])}) - la = [] - for answer in ds["answer"]: - la.append(answer.split("### ")[-1].rstrip()) - ret = ds["prompt"], la - else: - return None + return ds - return ret + def iter_cases(self, limit: int, seed: int) -> Iterator[Case]: + ds = self.load(limit, seed) + + for i, doc in enumerate(ds): + doc = cast(Mapping[str, Any], doc) + yield Case( + task=self.name, + kind=self.kind, + case_id=f"aime:{i}", + prompt=doc["prompt"], + gold=doc["answer"], + meta_data={}, + ) -def get_dataset( - dataset_name: str, n_prompts: int, rng_seed: int -) -> Optional[datasets.Dataset]: - ds = None - cache_dir = "./build/bin/datasets" - logger.info(f"Loading {dataset_name.lower()} dataset...") - if dataset_name.lower() == "mmlu": - ds = datasets.load_dataset( - "cais/mmlu", "all", split="test", cache_dir=cache_dir +class Gsm8k_Task(MathTaskSpec): + + def __init__(self): + self.name = "gsm8k" + self.kind = "math" + + def load(self, limit, seed) -> datasets.Dataset: + ds = datasets.load_dataset("openai/gsm8k", "main", split="test") + if limit: + ds = ds.shuffle(seed=seed) + ds = ds.select(range(min(limit, len(ds)))) + + ds = ds.map( + lambda k: { + "prompt": MATH_TEMPLATE.format( + question=k["question"], + ), + "gold": k["answer"].split("### ")[-1].rstrip(), + } ) - elif dataset_name.lower() == "hellaswag": - ds = datasets.load_dataset( - "Rowan/hellaswag", split="validation", cache_dir=cache_dir - ) - elif dataset_name.lower() == "aime": - ds = datasets.load_dataset( - "AI-MO/aimo-validation-aime", split="train", cache_dir=cache_dir - ) - elif dataset_name.lower() == "gsm8k": - ds = datasets.load_dataset("openai/gsm8k", split="test") - else: - return None + return ds - if n_prompts >= 0: - ds = ds.shuffle(seed=rng_seed) - ds = ds.select(range(min(n_prompts, len(ds)))) - 
-    return ds
+    def iter_cases(self, limit: int, seed: int) -> Iterator[Case]:
+        ds = self.load(limit, seed)
+
+        for i, doc in enumerate(ds):
+            doc = cast(Mapping[str, Any], doc)
+            yield Case(
+                task=self.name,
+                kind=self.kind,
+                case_id=f"gsm8k:{i}",
+                prompt=doc["prompt"],
+                gold=doc["gold"],
+                meta_data={},
+            )


-def send_prompt(data: dict) -> int:
-    session = data["session"]
-    server_address: str = data["server_address"]
-    prompt: str = data["prompt"]
-    logger.info(f"data['external_server'] {data['external_server']}")
-    logger.info(f"data['prompt'] {prompt}")
-    logger.info(f"data['n_predict'] {data['n_predict']}")
+TASK_DICT: dict[str, type[TaskSpec]] = {
+    "mmlu": MMLU_Task,
+    "aime": Aime_Task,
+    "gsm8k": Gsm8k_Task,
+    "hellaswag": Hellaswag_Task,
+    "arc": ARC_Task,
+    "winogrande": WinoGrande_Task,
+}

-    json_data: dict = {
-        "prompt": prompt,
-        "max_tokens": data["n_predict"],
+
+def build_request(case: Case, n_predict: int) -> dict[str, Any]:
+    json_data = {
+        "n_predict": n_predict,
+        "max_tokens": n_predict,
         "temperature": 0,
+        "prompt": case.prompt,
     }
-    response = session.post(f"{server_address}/v1/completions", json=json_data)
-    res = json.loads(response.text)
-    logger.info(f"response {res}")
-    extracted_answer = extract_boxed_text(res["choices"][0]["text"])
-    source_answer = data["answer"]
-    if data["prompt_source"] == "aime" or data["prompt_source"] == "gsm8k":
-        try:  # All AIME answers are integers, so we convert the extracted answer to an integer
-            extracted_answer = int(extracted_answer)
-            source_answer = int(source_answer)
-        except (ValueError, TypeError):
-            extracted_answer = None
-    logger.info(f"extracted_answer {extracted_answer}")
-    logger.info(f"data['answer'] {data['answer']}")
-
-    score = 1 if extracted_answer == source_answer else 0
-
-    return score
+    return json_data


-def get_server(path_server: str, path_log: Optional[str]) -> dict:
-    if path_server.startswith("http://") or path_server.startswith("https://"):
-        return {"process": None, "address": path_server, "fout": None}
-    if os.environ.get("LLAMA_ARG_HOST") is None:
-        logger.info("LLAMA_ARG_HOST not explicitly set, using 127.0.0.1")
-        os.environ["LLAMA_ARG_HOST"] = "127.0.0.1"
-    if os.environ.get("LLAMA_ARG_PORT") is None:
-        logger.info("LLAMA_ARG_PORT not explicitly set, using 8080")
-        os.environ["LLAMA_ARG_PORT"] = "8080"
-    hostname: Optional[str] = os.environ.get("LLAMA_ARG_HOST")
-    port: Optional[str] = os.environ.get("LLAMA_ARG_PORT")
-    assert hostname is not None
-    assert port is not None
-    address: str = f"http://{hostname}:{port}"
-    logger.info(f"Starting the llama.cpp server under {address}...")
+def send_prompt(
+    case: Case,
+    data: dict,
+) -> dict[str, Union[str, int]]:
+    ret_err = {
+        "task": case.task,
+        "case_id": case.case_id,
+        "status": "error",
+        "correct": 0,
+        "gold": case.gold,
+        "pred": "",
+        "error": "",
+    }
+    session: requests.Session = data["session"]
+    server_address: str = data["server_address"]
+    task = TASK_DICT.get(case.task)
+    if task is None:
+        ret_err["error"] = f"unknown_task: {case.task}"
+        return ret_err
+    logger.debug(case.prompt)

-    fout = open(path_log.format(port=port), "w") if path_log is not None else subprocess.DEVNULL
-    process = subprocess.Popen([path_server], stdout=fout, stderr=subprocess.STDOUT)
+    json_data = build_request(case, data["n_predict"])
+    try:
+        response = session.post(f"{server_address}/v1/completions", json=json_data)
+        if response.ok:
+            res_json = response.json()
+        else:
+            ret_err["error"] = f"http_response: {response.status_code}"
logger.warning(ret_err["error"]) + return ret_err + except Exception as e: + ret_err["error"] = f"http_exception: {e}" + logger.warning(ret_err["error"]) + return ret_err + logger.debug(response.text) + return TASK_DICT[case.task].grade(case, res_json) - n_failures: int = 0 - while True: - try: - sleep(1.0) - exit_code = process.poll() - if exit_code is not None: - raise RuntimeError(f"llama.cpp server exited unexpectedly with exit code {exit_code}{path_log and f', see {path_log.format(port=port)}' or ''}") - response = requests.get(f"{address}/health") - if response.status_code == 200: - break - except requests.ConnectionError: - n_failures += 1 - if n_failures >= 10: - raise RuntimeError("llama.cpp server is not healthy after 10 seconds") - return {"process": process, "address": address, "fout": fout} +def aggregate_by_task(results: list[dict[str, Any]]) -> dict[str, dict[str, int]]: + tmp = { + "total": 0, + "error": 0, + "invalid": 0, + "correct": 0, + } + agg: dict[str, dict[str, int]] = {} + for row in results: + d = agg.get(row["task"], tmp.copy()) + d["total"] += 1 + status = row["status"] + if status == "ok": + d["correct"] += row["correct"] + elif status == "invalid": + d["invalid"] += 1 + elif status == "error": + d["error"] += 1 + + agg[row["task"]] = d + return agg + + +def print_summary(pertask_results: dict[str, dict[str, int]]): + print("\n=== llama-eval suite summary ===") + print( + f"{'Task':<15} {'Acc':>8} {'Correct':>8} {'Total':>8} {'Invalid':>8} {'Error':>8}" + ) + print("-" * 65) + + suite_total = 0 + suite_correct = 0 + + for task in sorted(pertask_results.keys()): + stats = pertask_results[task] + total = stats["total"] + correct = stats["correct"] + invalid = stats["invalid"] + error = stats["error"] + + acc = (correct / total) if total > 0 else 0.0 + + print( + f"{task:<15} " + f"{acc:8.3f} " + f"{correct:8d} " + f"{total:8d} " + f"{invalid:8d} " + f"{error:8d}" + ) + + suite_total += total + suite_correct += correct + + # Overall summary + print("-" * 65) + suite_acc = (suite_correct / suite_total) if suite_total > 0 else 0.0 + print( + f"{'ALL':<15} " f"{suite_acc:8.3f} " f"{suite_correct:8d} " f"{suite_total:8d}" + ) def benchmark( path_server: str, - path_log: Optional[str], prompt_source: str, n_prompts: int, n_predict: int, rng_seed: int, ): - external_server: bool = path_server.startswith("http://") or path_server.startswith("https://") + if not path_server.startswith("http://") and not path_server.startswith("https://"): + logger.error("ERROR: malformed server path") + return + if os.environ.get("LLAMA_ARG_N_PARALLEL") is None: logger.info("LLAMA_ARG_N_PARALLEL not explicitly set, using 32") os.environ["LLAMA_ARG_N_PARALLEL"] = "32" - parallel: int = int(os.environ.get("LLAMA_ARG_N_PARALLEL")) # type: ignore - ds: Union[datasets.Dataset, None] = get_dataset(prompt_source, n_prompts, rng_seed) - if not ds: - logger.error("ERROR: get_dataset") - exit(0) + parallel: int = int(os.environ.get("LLAMA_ARG_N_PARALLEL")) # type: ignore - res: Union[tuple[list[str], list[str]], None] = get_prompts_text(prompt_source, ds) - if not res: - logger.error("ERROR: get_prompts_text") - exit(0) + task_queue: set[TaskSpec] = set() + for src in prompt_source.split(","): + if src == "all": + for v in TASK_DICT.values(): + task_queue.add(v()) + break + task_queue.add(TASK_DICT[src]()) - prompts: Union[list[str], list[list[int]]] = res[0] - answer: Union[list[str], list[list[int]]] = res[1] - - logger.info(prompts) - logger.info(f"external_server {external_server}") - - server: 
     session = None
     try:
-        server = get_server(path_server, path_log)
-        server_address: str = server["address"]
-        assert external_server == (server["process"] is None)
+        server_address: str = path_server
         adapter = requests.adapters.HTTPAdapter(pool_connections=parallel, pool_maxsize=parallel)  # type: ignore
         session = requests.Session()
         session.mount("http://", adapter)
         session.mount("https://", adapter)

+        cases: list[Case] = []
         data: list[dict] = []
-        for p, a in zip(prompts, answer):
-            data.append(
-                {
-                    "prompt_source": prompt_source,
-                    "session": session,
-                    "server_address": server_address,
-                    "external_server": external_server,
-                    "prompt": p,
-                    "answer": a,
-                    "n_predict": n_predict,
-                }
-            )
-
+        for task in task_queue:
+            for case in task.iter_cases(n_prompts, rng_seed):
+                cases.append(case)
+                data.append(
+                    {
+                        "prompt_source": prompt_source,
+                        "session": session,
+                        "server_address": server_address,
+                        "n_predict": n_predict,
+                    }
+                )
         logger.info("Starting the benchmark...\n")

         t0 = time()
-        results: list[int] = thread_map(
-            send_prompt, data, max_workers=parallel, chunksize=1
+        results: list[dict[str, Union[str, int]]] = thread_map(
+            send_prompt,
+            cases,
+            data,
+            max_workers=parallel,
+            chunksize=1,
         )
     finally:
-        if server is not None and server["process"] is not None:
-            server["process"].terminate()
-            server["process"].wait()
         if session is not None:
             session.close()
     t1 = time()
+    logger.info(f"\nllama-eval duration: {t1-t0:.2f} s")

-    correct: int = sum(results)
-    total_questions: int = len(data)
-    logger.info(f"llama-eval duration: {t1-t0:.2f} s")
-    logger.info(f"{prompt_source} correct: {correct}")
-    logger.info(f"{prompt_source} total_questions: {total_questions}")
-    logger.info(f"{prompt_source} accuracy: {correct / total_questions}")
+    pertask_results = aggregate_by_task(results)
+    print_summary(pertask_results)


 if __name__ == "__main__":
@@ -324,23 +569,17 @@ if __name__ == "__main__":
     parser.add_argument(
         "--path_server",
         type=str,
-        default="llama-server",
-        help="Path to the llama.cpp server binary",
-    )
-    parser.add_argument(
-        "--path_log",
-        type=str,
-        default="server-bench-{port}.log",
-        help="Path to the model to use for the benchmark",
+        default="http://localhost:8033",
+        help="llama-server url",
     )
     parser.add_argument(
         "--prompt_source",
         type=str,
         default="mmlu",
-        help="How to get the prompts for the benchmark, either 'mmlu' for MMLU questions",
+        help=f"Eval types supported: all,{TASK_DICT.keys()}",
     )
     parser.add_argument(
-        "--n_prompts", type=int, default=100, help="Number of prompts to evaluate"
+        "--n_prompts", type=int, default=None, help="Number of prompts to evaluate"
     )
     parser.add_argument(
         "--rng_seed",