From 2357f6f193649ad54e2f0f1e0f0582d5012d2dc2 Mon Sep 17 00:00:00 2001 From: gatbontonpc Date: Sat, 10 Jan 2026 22:19:08 -0800 Subject: [PATCH 1/4] working llama-eval mc and math suite --- examples/llama-eval/llama-eval.py | 358 ++++++++++++++++++++++++++++++ 1 file changed, 358 insertions(+) create mode 100644 examples/llama-eval/llama-eval.py diff --git a/examples/llama-eval/llama-eval.py b/examples/llama-eval/llama-eval.py new file mode 100644 index 0000000000..10ec766fe6 --- /dev/null +++ b/examples/llama-eval/llama-eval.py @@ -0,0 +1,358 @@ +#!/usr/bin/env python3 + +import re +import argparse +import json +import os +import random +import subprocess +from time import sleep, time +from typing import Optional, Union + +import datasets +import logging +import requests +from tqdm.contrib.concurrent import thread_map +from typing import Iterator +from abc import ABC + +logging.basicConfig(level=logging.INFO, format='%(message)s') +logger = logging.getLogger("llama-eval") + + +MATH_TEMPLATE = """ +{question} +Put your final answer within \\boxed{{}}. +""" + +MC_FROM_INT = { + 0: "A", + 1: "B", + 2: "C", + 3: "D", +} + + +def format_multiple_choice(prompt: str, choices: list[str]): + QUERY_TEMPLATE_MULTICHOICE = """ + {question} + + (A) {A} + (B) {B} + (C) {C} + (D) {D} + + Express your final answer as the corresponding option 'A', 'B', 'C', or 'D'. Put your final answer within \\boxed{{}}. + + """.strip() + A_str = choices[0] + B_str = choices[1] + C_str = choices[2] + D_str = choices[3] + query = QUERY_TEMPLATE_MULTICHOICE.format( + question=prompt, A=A_str, B=B_str, C=C_str, D=D_str + ) + return query + + +# Preprocess hellaswag +def preprocess(text): + text = text.strip() + # NOTE: Brackets are artifacts of the WikiHow dataset portion of HellaSwag. + text = text.replace(" [title]", ". ") + text = re.sub("\\[.*?\\]", "", text) + text = text.replace(" ", " ") + return text + + +def hellaswag_process_doc(doc): + ctx = doc["ctx_a"] + " " + doc["ctx_b"].capitalize() + question = preprocess(doc["activity_label"] + ": " + ctx) + proc_answers = [preprocess(answer) for answer in doc["endings"]] + prompt = format_multiple_choice(question, proc_answers) + out_doc = { + "prompt": prompt, + "gold": MC_FROM_INT[int(doc["label"])], + } + return out_doc + + +def mmlu_process_doc(doc): + prompt = format_multiple_choice(doc["question"], doc["choices"]) + out_doc = { + "prompt": prompt, + "gold": MC_FROM_INT[int(doc["answer"])], + } + return out_doc + + +def extract_boxed_text(text): + pattern = r"boxed{(.*?)}|framebox{(.*?)}" + matches = re.findall(pattern, text, re.DOTALL) + logger.debug(matches) + if matches: + for match in matches[::-1]: + for group in match: + if group != "": + return group.split(",")[-1].strip() + logger.warning( + "Could not extract boxed text. Using last integer. 
Maybe expand context window" + ) + pattern = r"\d+" # get the last integer if no pattern found + matches = re.findall(pattern, text, re.DOTALL) + if matches: + return matches[-1] + + return "" + + +def get_prompts_text( + dataset_name: str, ds: datasets.Dataset +) -> Optional[tuple[list[str], list[str]]]: + ret = [] + if dataset_name.lower() == "mmlu": + ds = ds.map(mmlu_process_doc) + ret = ds["prompt"], ds["gold"] + elif dataset_name.lower() == "hellaswag": + ds = ds.map(hellaswag_process_doc) + ret = ds["prompt"], ds["gold"] + elif dataset_name.lower() == "aime": + ds = ds.map( + lambda k: { + "prompt": MATH_TEMPLATE.format( + question=k["problem"], + ) + } + ) + ret = ds["prompt"], ds["answer"] + elif dataset_name.lower() == "gsm8k": + ds = ds.map(lambda k: {"prompt": MATH_TEMPLATE.format(question=k["question"])}) + la = [] + for answer in ds["answer"]: + la.append(answer.split("### ")[-1].rstrip()) + ret = ds["prompt"], la + else: + return None + + return ret + + +def get_dataset( + dataset_name: str, n_prompts: int, rng_seed: int +) -> Optional[datasets.Dataset]: + ds = None + cache_dir = "./build/bin/datasets" + logger.info(f"Loading {dataset_name.lower()} dataset...") + if dataset_name.lower() == "mmlu": + ds = datasets.load_dataset( + "cais/mmlu", "all", split="test", cache_dir=cache_dir + ) + elif dataset_name.lower() == "hellaswag": + ds = datasets.load_dataset( + "Rowan/hellaswag", split="validation", cache_dir=cache_dir + ) + elif dataset_name.lower() == "aime": + ds = datasets.load_dataset( + "AI-MO/aimo-validation-aime", split="train", cache_dir=cache_dir + ) + elif dataset_name.lower() == "gsm8k": + ds = datasets.load_dataset("openai/gsm8k", split="test") + else: + return None + + if n_prompts >= 0: + ds = ds.shuffle(seed=rng_seed) + ds = ds.select(range(min(n_prompts, len(ds)))) + return ds + + +def send_prompt(data: dict) -> int: + session = data["session"] + server_address: str = data["server_address"] + prompt: str = data["prompt"] + logger.info(f"data['external_server'] {data['external_server']}") + logger.info(f"data['prompt'] {prompt}") + logger.info(f"data['n_predict'] {data['n_predict']}") + + json_data: dict = { + "prompt": prompt, + "max_tokens": data["n_predict"], + "temperature": 0, + } + response = session.post(f"{server_address}/v1/completions", json=json_data) + res = json.loads(response.text) + logger.info(f"response {res}") + extracted_answer = extract_boxed_text(res["choices"][0]["text"]) + source_answer = data["answer"] + if data["prompt_source"] == "aime" or data["prompt_source"] == "gsm8k": + try: # All AIME answers are integers, so we convert the extracted answer to an integer + extracted_answer = int(extracted_answer) + source_answer = int(source_answer) + except (ValueError, TypeError): + extracted_answer = None + logger.info(f"extracted_answer {extracted_answer}") + logger.info(f"data['answer'] {data['answer']}") + + score = 1 if extracted_answer == source_answer else 0 + + return score + + +def get_server(path_server: str, path_log: Optional[str]) -> dict: + if path_server.startswith("http://") or path_server.startswith("https://"): + return {"process": None, "address": path_server, "fout": None} + if os.environ.get("LLAMA_ARG_HOST") is None: + logger.info("LLAMA_ARG_HOST not explicitly set, using 127.0.0.1") + os.environ["LLAMA_ARG_HOST"] = "127.0.0.1" + if os.environ.get("LLAMA_ARG_PORT") is None: + logger.info("LLAMA_ARG_PORT not explicitly set, using 8080") + os.environ["LLAMA_ARG_PORT"] = "8080" + hostname: Optional[str] = 
os.environ.get("LLAMA_ARG_HOST") + port: Optional[str] = os.environ.get("LLAMA_ARG_PORT") + assert hostname is not None + assert port is not None + address: str = f"http://{hostname}:{port}" + logger.info(f"Starting the llama.cpp server under {address}...") + + fout = open(path_log.format(port=port), "w") if path_log is not None else subprocess.DEVNULL + process = subprocess.Popen([path_server], stdout=fout, stderr=subprocess.STDOUT) + + n_failures: int = 0 + while True: + try: + sleep(1.0) + exit_code = process.poll() + if exit_code is not None: + raise RuntimeError(f"llama.cpp server exited unexpectedly with exit code {exit_code}{path_log and f', see {path_log.format(port=port)}' or ''}") + response = requests.get(f"{address}/health") + if response.status_code == 200: + break + except requests.ConnectionError: + n_failures += 1 + if n_failures >= 10: + raise RuntimeError("llama.cpp server is not healthy after 10 seconds") + + return {"process": process, "address": address, "fout": fout} + + +def benchmark( + path_server: str, + path_log: Optional[str], + prompt_source: str, + n_prompts: int, + n_predict: int, + rng_seed: int, +): + external_server: bool = path_server.startswith("http://") or path_server.startswith("https://") + if os.environ.get("LLAMA_ARG_N_PARALLEL") is None: + logger.info("LLAMA_ARG_N_PARALLEL not explicitly set, using 32") + os.environ["LLAMA_ARG_N_PARALLEL"] = "32" + + parallel: int = int(os.environ.get("LLAMA_ARG_N_PARALLEL")) # type: ignore + ds: Union[datasets.Dataset, None] = get_dataset(prompt_source, n_prompts, rng_seed) + if not ds: + logger.error("ERROR: get_dataset") + exit(0) + + res: Union[tuple[list[str], list[str]], None] = get_prompts_text(prompt_source, ds) + if not res: + logger.error("ERROR: get_prompts_text") + exit(0) + + prompts: Union[list[str], list[list[int]]] = res[0] + answer: Union[list[str], list[list[int]]] = res[1] + + logger.info(prompts) + logger.info(f"external_server {external_server}") + + server: Optional[dict] = None + session = None + try: + server = get_server(path_server, path_log) + server_address: str = server["address"] + assert external_server == (server["process"] is None) + + adapter = requests.adapters.HTTPAdapter(pool_connections=parallel, pool_maxsize=parallel) # type: ignore + session = requests.Session() + session.mount("http://", adapter) + session.mount("https://", adapter) + + data: list[dict] = [] + for p, a in zip(prompts, answer): + data.append( + { + "prompt_source": prompt_source, + "session": session, + "server_address": server_address, + "external_server": external_server, + "prompt": p, + "answer": a, + "n_predict": n_predict, + } + ) + + logger.info("Starting the benchmark...\n") + t0 = time() + results: list[int] = thread_map( + send_prompt, data, max_workers=parallel, chunksize=1 + ) + finally: + if server is not None and server["process"] is not None: + server["process"].terminate() + server["process"].wait() + if session is not None: + session.close() + + t1 = time() + + correct: int = sum(results) + total_questions: int = len(data) + logger.info(f"llama-eval duration: {t1-t0:.2f} s") + logger.info(f"{prompt_source} correct: {correct}") + logger.info(f"{prompt_source} total_questions: {total_questions}") + logger.info(f"{prompt_source} accuracy: {correct / total_questions}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Tool for benchmarking the throughput of the llama.cpp HTTP server. 
" + "Results are printed to console and visualized as plots (saved to current working directory). " + "To pass arguments such as the model path to the server, set the corresponding environment variables (see llama-server --help). " + "The reported numbers are the speeds as observed by the Python script and may differ from the performance reported by the server, " + "particularly when the server is fast vs. the network or Python script (e.g. when serving a very small model)." + ) + parser.add_argument( + "--path_server", + type=str, + default="llama-server", + help="Path to the llama.cpp server binary", + ) + parser.add_argument( + "--path_log", + type=str, + default="server-bench-{port}.log", + help="Path to the model to use for the benchmark", + ) + parser.add_argument( + "--prompt_source", + type=str, + default="mmlu", + help="How to get the prompts for the benchmark, either 'mmlu' for MMLU questions", + ) + parser.add_argument( + "--n_prompts", type=int, default=100, help="Number of prompts to evaluate" + ) + parser.add_argument( + "--rng_seed", + type=int, + default=42, + help="Number to see rng (Used to select prompts from datasource)", + ) + parser.add_argument( + "--n_predict", + type=int, + default=2048, + help="Max. number of tokens to predict per prompt", + ) + args = parser.parse_args() + benchmark(**vars(args)) From f3a5b4ea728fb14bd0b2e5394f82c90d315e65b5 Mon Sep 17 00:00:00 2001 From: gatbontonpc Date: Mon, 12 Jan 2026 13:47:43 -0500 Subject: [PATCH 2/4] multi source llama-eval --- examples/llama-eval/llama-eval.py | 705 ++++++++++++++++++++---------- 1 file changed, 472 insertions(+), 233 deletions(-) diff --git a/examples/llama-eval/llama-eval.py b/examples/llama-eval/llama-eval.py index 10ec766fe6..411d0adbab 100644 --- a/examples/llama-eval/llama-eval.py +++ b/examples/llama-eval/llama-eval.py @@ -2,91 +2,43 @@ import re import argparse -import json import os -import random -import subprocess -from time import sleep, time -from typing import Optional, Union +from time import time +from typing import Union, Any, Mapping, cast import datasets import logging import requests from tqdm.contrib.concurrent import thread_map from typing import Iterator -from abc import ABC +from abc import ABC, abstractmethod +from dataclasses import dataclass logging.basicConfig(level=logging.INFO, format='%(message)s') logger = logging.getLogger("llama-eval") - MATH_TEMPLATE = """ {question} -Put your final answer within \\boxed{{}}. +Do not include any explanation. Put your final answer within \\boxed{{}}. """ -MC_FROM_INT = { - 0: "A", - 1: "B", - 2: "C", - 3: "D", -} - def format_multiple_choice(prompt: str, choices: list[str]): - QUERY_TEMPLATE_MULTICHOICE = """ - {question} + lines = [prompt] - (A) {A} - (B) {B} - (C) {C} - (D) {D} - - Express your final answer as the corresponding option 'A', 'B', 'C', or 'D'. Put your final answer within \\boxed{{}}. - - """.strip() - A_str = choices[0] - B_str = choices[1] - C_str = choices[2] - D_str = choices[3] - query = QUERY_TEMPLATE_MULTICHOICE.format( - question=prompt, A=A_str, B=B_str, C=C_str, D=D_str + labels = [chr(ord("A") + i) for i in range(len(choices))] + for l, c in zip(labels, choices): + lines.append(f"({l}): {c.strip()}") + lines.append( + "Do not include any explanation. 
Answer with the corresponding option letter only" ) - return query + lines.append(", ".join(labels)) + lines.append("Put your final answer within \\boxed{{}}.") + + return "\n".join(lines), labels -# Preprocess hellaswag -def preprocess(text): - text = text.strip() - # NOTE: Brackets are artifacts of the WikiHow dataset portion of HellaSwag. - text = text.replace(" [title]", ". ") - text = re.sub("\\[.*?\\]", "", text) - text = text.replace(" ", " ") - return text - - -def hellaswag_process_doc(doc): - ctx = doc["ctx_a"] + " " + doc["ctx_b"].capitalize() - question = preprocess(doc["activity_label"] + ": " + ctx) - proc_answers = [preprocess(answer) for answer in doc["endings"]] - prompt = format_multiple_choice(question, proc_answers) - out_doc = { - "prompt": prompt, - "gold": MC_FROM_INT[int(doc["label"])], - } - return out_doc - - -def mmlu_process_doc(doc): - prompt = format_multiple_choice(doc["question"], doc["choices"]) - out_doc = { - "prompt": prompt, - "gold": MC_FROM_INT[int(doc["answer"])], - } - return out_doc - - -def extract_boxed_text(text): +def extract_boxed_text(text: str) -> str: pattern = r"boxed{(.*?)}|framebox{(.*?)}" matches = re.findall(pattern, text, re.DOTALL) logger.debug(matches) @@ -95,222 +47,515 @@ def extract_boxed_text(text): for group in match: if group != "": return group.split(",")[-1].strip() - logger.warning( - "Could not extract boxed text. Using last integer. Maybe expand context window" - ) - pattern = r"\d+" # get the last integer if no pattern found - matches = re.findall(pattern, text, re.DOTALL) - if matches: - return matches[-1] + logger.warning("Could not extract boxed text. Maybe expand context window") return "" -def get_prompts_text( - dataset_name: str, ds: datasets.Dataset -) -> Optional[tuple[list[str], list[str]]]: - ret = [] - if dataset_name.lower() == "mmlu": - ds = ds.map(mmlu_process_doc) - ret = ds["prompt"], ds["gold"] - elif dataset_name.lower() == "hellaswag": - ds = ds.map(hellaswag_process_doc) - ret = ds["prompt"], ds["gold"] - elif dataset_name.lower() == "aime": +@dataclass(frozen=True) +class Case: + task: str + kind: str + case_id: str + prompt: str + gold: str + meta_data: dict[str, Any] + + +class TaskSpec(ABC): + name: str + kind: str + + @abstractmethod + def load(self, limit, seed) -> datasets.Dataset: + pass + + @abstractmethod + def iter_cases(self, limit: int, seed: int) -> Iterator[Case]: + pass + + @staticmethod + @abstractmethod + def grade(case: Case, response: dict) -> dict[str, Any]: + pass + + +class MCTaskSpec(TaskSpec): + @staticmethod + def grade(case: Case, response: dict) -> dict[str, Any]: + logger.debug(f"response {response}") + result = { + "task": case.task, + "case_id": case.case_id, + "correct": 0, + "pred": None, + "gold": case.gold, + "status": "ok", + } + + try: + extracted_answer = extract_boxed_text(response["choices"][0]["text"]) + except Exception as e: + result["status"] = "error" + logger.warning("ERROR: extract_boxed_text") + + return result + + if not extracted_answer: + result["status"] = "invalid" + logger.warning("INVALID: extract_boxed_text") + return result + + logger.debug(f"extracted_answer {extracted_answer}") + logger.debug(f"data['answer'] {case.gold}") + result["pred"] = extracted_answer + result["correct"] = 1 if extracted_answer == case.gold else 0 + + return result + + +class MathTaskSpec(TaskSpec): + + @staticmethod + def grade(case: Case, response: dict) -> dict[str, Any]: + logger.debug(f"response {response}") + result = { + "task": case.task, + "case_id": 
case.case_id, + "correct": 0, + "gold": case.gold, + "status": "ok", + "pred": None, + } + + try: + extracted_answer = extract_boxed_text(response["choices"][0]["text"]) + except Exception as e: + result["status"] = "error" + return result + + source_answer = case.gold + try: # All AIME answers are integers, so we convert the extracted answer to an integer + extracted_answer = int(extracted_answer) + source_answer = int(case.gold) + except (ValueError, TypeError): + result["status"] = "invalid" + return result + + logger.debug(f"extracted_answer {extracted_answer}") + logger.debug(f"data['answer'] {case.gold}") + result["pred"] = extracted_answer + result["correct"] = 1 if extracted_answer == source_answer else 0 + + return result + + +class ARC_Task(MCTaskSpec): + + def __init__(self): + self.name = "arc" + self.kind = "mc" + + def load(self, limit, seed) -> datasets.Dataset: + ds = datasets.load_dataset("allenai/ai2_arc", "ARC-Challenge", split="test") + if limit: + ds = ds.shuffle(seed=seed) + ds = ds.select(range(min(limit, len(ds)))) + return ds + + def iter_cases(self, limit: int, seed: int) -> Iterator[Case]: + ds = self.load(limit, seed) + + for i, doc in enumerate(ds): + doc = cast(Mapping[str, Any], doc) + + prompt, labels = format_multiple_choice( + doc["question"], doc["choices"]["text"] + ) + yield Case( + task=self.name, + kind=self.kind, + case_id=f"ARC-Challenge:{i}", + prompt=prompt, + gold=doc["answerKey"], + meta_data={"labels": labels}, + ) + + +class WinoGrande_Task(MCTaskSpec): + + def __init__(self): + self.name = "winogrande" + self.kind = "mc" + + def load(self, limit, seed) -> datasets.Dataset: + ds = datasets.load_dataset( + "winogrande", "winogrande_debiased", split="validation" + ) + if limit: + ds = ds.shuffle(seed=seed) + ds = ds.select(range(min(limit, len(ds)))) + return ds + + def iter_cases(self, limit: int, seed: int) -> Iterator[Case]: + ds = self.load(limit, seed) + + for i, doc in enumerate(ds): + doc = cast(Mapping[str, Any], doc) + + prompt, labels = format_multiple_choice( + doc["sentence"], [doc["option1"], doc["option2"]] + ) + yield Case( + task=self.name, + kind=self.kind, + case_id=f"winogrande:{i}", + prompt=prompt, + gold=labels[int(doc["answer"]) - 1], # winogrande answers are 1 based + meta_data={"labels": labels}, + ) + + +class MMLU_Task(MCTaskSpec): + + def __init__(self): + self.name = "mmlu" + self.kind = "mc" + + def load(self, limit, seed) -> datasets.Dataset: + ds = datasets.load_dataset("cais/mmlu", "all", split="test") + if limit: + ds = ds.shuffle(seed=seed) + ds = ds.select(range(min(limit, len(ds)))) + return ds + + def iter_cases(self, limit: int, seed: int) -> Iterator[Case]: + ds = self.load(limit, seed) + + for i, doc in enumerate(ds): + doc = cast(Mapping[str, Any], doc) + + prompt, labels = format_multiple_choice(doc["question"], doc["choices"]) + yield Case( + task=self.name, + kind=self.kind, + case_id=f"mmlu:{doc['subject']}:{i}", + prompt=prompt, + gold=labels[int(doc["answer"])], + meta_data={"subject": doc["subject"], "labels": labels}, + ) + + +class Hellaswag_Task(MCTaskSpec): + + # Preprocess hellaswag + @staticmethod + def preprocess(text: str): + text = text.strip() + # NOTE: Brackets are artifacts of the WikiHow dataset portion of HellaSwag. + text = text.replace(" [title]", ". 
") + text = re.sub("\\[.*?\\]", "", text) + text = text.replace(" ", " ") + return text + + @staticmethod + def hellaswag_process_doc(doc: dict[str, str]): + ctx = doc["ctx_a"] + " " + doc["ctx_b"].capitalize() + question = Hellaswag_Task.preprocess(doc["activity_label"] + ": " + ctx) + proc_answers = [Hellaswag_Task.preprocess(answer) for answer in doc["endings"]] + prompt, labels = format_multiple_choice(question, proc_answers) + out_doc = { + "prompt": prompt, + "gold": labels[int(doc["label"])], + } + return out_doc + + def __init__(self): + self.name = "hellaswag" + self.kind = "mc" + + def load(self, limit, seed) -> datasets.Dataset: + ds = datasets.load_dataset("Rowan/hellaswag", split="validation") + if limit: + ds = ds.shuffle(seed=seed) + ds = ds.select(range(min(limit, len(ds)))) + ds = ds.map(Hellaswag_Task.hellaswag_process_doc) + + return ds + + def iter_cases(self, limit: int, seed: int) -> Iterator[Case]: + ds = self.load(limit, seed) + for i, doc in enumerate(ds): + doc = cast(Mapping[str, Any], doc) + yield Case( + task=self.name, + kind=self.kind, + case_id=f"hellaswag:{i}", + prompt=doc["prompt"], + gold=doc["gold"], + meta_data={}, + ) + + +class Aime_Task(MathTaskSpec): + + def __init__(self): + self.name = "aime" + self.kind = "math" + + def load(self, limit, seed) -> datasets.Dataset: + ds = datasets.load_dataset("AI-MO/aimo-validation-aime", split="train") + + if limit: + ds = ds.shuffle(seed=seed) + ds = ds.select(range(min(limit, len(ds)))) + ds = ds.map( - lambda k: { + lambda ex: { "prompt": MATH_TEMPLATE.format( - question=k["problem"], + question=ex["problem"], ) } ) - ret = ds["prompt"], ds["answer"] - elif dataset_name.lower() == "gsm8k": - ds = ds.map(lambda k: {"prompt": MATH_TEMPLATE.format(question=k["question"])}) - la = [] - for answer in ds["answer"]: - la.append(answer.split("### ")[-1].rstrip()) - ret = ds["prompt"], la - else: - return None + return ds - return ret + def iter_cases(self, limit: int, seed: int) -> Iterator[Case]: + ds = self.load(limit, seed) + + for i, doc in enumerate(ds): + doc = cast(Mapping[str, Any], doc) + yield Case( + task=self.name, + kind=self.kind, + case_id=f"aime:{i}", + prompt=doc["prompt"], + gold=doc["answer"], + meta_data={}, + ) -def get_dataset( - dataset_name: str, n_prompts: int, rng_seed: int -) -> Optional[datasets.Dataset]: - ds = None - cache_dir = "./build/bin/datasets" - logger.info(f"Loading {dataset_name.lower()} dataset...") - if dataset_name.lower() == "mmlu": - ds = datasets.load_dataset( - "cais/mmlu", "all", split="test", cache_dir=cache_dir +class Gsm8k_Task(MathTaskSpec): + + def __init__(self): + self.name = "gsm8k" + self.kind = "math" + + def load(self, limit, seed) -> datasets.Dataset: + ds = datasets.load_dataset("openai/gsm8k", "main", split="test") + if limit: + ds = ds.shuffle(seed=seed) + ds = ds.select(range(min(limit, len(ds)))) + + ds = ds.map( + lambda k: { + "prompt": MATH_TEMPLATE.format( + question=k["question"], + ), + "gold": k["answer"].split("### ")[-1].rstrip(), + } ) - elif dataset_name.lower() == "hellaswag": - ds = datasets.load_dataset( - "Rowan/hellaswag", split="validation", cache_dir=cache_dir - ) - elif dataset_name.lower() == "aime": - ds = datasets.load_dataset( - "AI-MO/aimo-validation-aime", split="train", cache_dir=cache_dir - ) - elif dataset_name.lower() == "gsm8k": - ds = datasets.load_dataset("openai/gsm8k", split="test") - else: - return None + return ds - if n_prompts >= 0: - ds = ds.shuffle(seed=rng_seed) - ds = ds.select(range(min(n_prompts, len(ds)))) - 
return ds + def iter_cases(self, limit: int, seed: int) -> Iterator[Case]: + ds = self.load(limit, seed) + + for i, doc in enumerate(ds): + doc = cast(Mapping[str, Any], doc) + yield Case( + task=self.name, + kind=self.kind, + case_id=f"gsm8k:{i}", + prompt=doc["prompt"], + gold=doc["gold"], + meta_data={}, + ) -def send_prompt(data: dict) -> int: - session = data["session"] - server_address: str = data["server_address"] - prompt: str = data["prompt"] - logger.info(f"data['external_server'] {data['external_server']}") - logger.info(f"data['prompt'] {prompt}") - logger.info(f"data['n_predict'] {data['n_predict']}") +TASK_DICT: dict[str, type[TaskSpec]] = { + "mmlu": MMLU_Task, + "aime": Aime_Task, + "gsm8k": Gsm8k_Task, + "hellaswag": Hellaswag_Task, + "arc": ARC_Task, + "winogrande": WinoGrande_Task, +} - json_data: dict = { - "prompt": prompt, - "max_tokens": data["n_predict"], + +def build_request(case: Case, n_predict: int) -> dict[str, Any]: + json_data = { + "n_predict": n_predict, + "max_tokens": n_predict, "temperature": 0, + "prompt": case.prompt, } - response = session.post(f"{server_address}/v1/completions", json=json_data) - res = json.loads(response.text) - logger.info(f"response {res}") - extracted_answer = extract_boxed_text(res["choices"][0]["text"]) - source_answer = data["answer"] - if data["prompt_source"] == "aime" or data["prompt_source"] == "gsm8k": - try: # All AIME answers are integers, so we convert the extracted answer to an integer - extracted_answer = int(extracted_answer) - source_answer = int(source_answer) - except (ValueError, TypeError): - extracted_answer = None - logger.info(f"extracted_answer {extracted_answer}") - logger.info(f"data['answer'] {data['answer']}") - - score = 1 if extracted_answer == source_answer else 0 - - return score + return json_data -def get_server(path_server: str, path_log: Optional[str]) -> dict: - if path_server.startswith("http://") or path_server.startswith("https://"): - return {"process": None, "address": path_server, "fout": None} - if os.environ.get("LLAMA_ARG_HOST") is None: - logger.info("LLAMA_ARG_HOST not explicitly set, using 127.0.0.1") - os.environ["LLAMA_ARG_HOST"] = "127.0.0.1" - if os.environ.get("LLAMA_ARG_PORT") is None: - logger.info("LLAMA_ARG_PORT not explicitly set, using 8080") - os.environ["LLAMA_ARG_PORT"] = "8080" - hostname: Optional[str] = os.environ.get("LLAMA_ARG_HOST") - port: Optional[str] = os.environ.get("LLAMA_ARG_PORT") - assert hostname is not None - assert port is not None - address: str = f"http://{hostname}:{port}" - logger.info(f"Starting the llama.cpp server under {address}...") +def send_prompt( + case: Case, + data: dict, +) -> dict[str, Union[str, int]]: + ret_err = { + "task": case.task, + "case_id": case.case_id, + "status": "error", + "correct": 0, + "gold": case.gold, + "pred": "", + "error": "", + } + session: requests.Session = data["session"] + server_address: str = data["server_address"] + task = TASK_DICT.get(case.task) + if task is None: + ret_err["error"] = f"unknown_task: {case.task}" + return ret_err + logger.debug(case.prompt) - fout = open(path_log.format(port=port), "w") if path_log is not None else subprocess.DEVNULL - process = subprocess.Popen([path_server], stdout=fout, stderr=subprocess.STDOUT) + json_data = build_request(case, data["n_predict"]) + try: + response = session.post(f"{server_address}/v1/completions", json=json_data) + if response.ok: + res_json = response.json() + else: + ret_err["error"] = f"http_response: {response.status_code}" + 
logger.warning(ret_err["error"]) + return ret_err + except Exception as e: + ret_err["error"] = f"http_exception: {e}" + logger.warning(ret_err["error"]) + return ret_err + logger.debug(response.text) + return TASK_DICT[case.task].grade(case, res_json) - n_failures: int = 0 - while True: - try: - sleep(1.0) - exit_code = process.poll() - if exit_code is not None: - raise RuntimeError(f"llama.cpp server exited unexpectedly with exit code {exit_code}{path_log and f', see {path_log.format(port=port)}' or ''}") - response = requests.get(f"{address}/health") - if response.status_code == 200: - break - except requests.ConnectionError: - n_failures += 1 - if n_failures >= 10: - raise RuntimeError("llama.cpp server is not healthy after 10 seconds") - return {"process": process, "address": address, "fout": fout} +def aggregate_by_task(results: list[dict[str, Any]]) -> dict[str, dict[str, int]]: + tmp = { + "total": 0, + "error": 0, + "invalid": 0, + "correct": 0, + } + agg: dict[str, dict[str, int]] = {} + for row in results: + d = agg.get(row["task"], tmp.copy()) + d["total"] += 1 + status = row["status"] + if status == "ok": + d["correct"] += row["correct"] + elif status == "invalid": + d["invalid"] += 1 + elif status == "error": + d["error"] += 1 + + agg[row["task"]] = d + return agg + + +def print_summary(pertask_results: dict[str, dict[str, int]]): + print("\n=== llama-eval suite summary ===") + print( + f"{'Task':<15} {'Acc':>8} {'Correct':>8} {'Total':>8} {'Invalid':>8} {'Error':>8}" + ) + print("-" * 65) + + suite_total = 0 + suite_correct = 0 + + for task in sorted(pertask_results.keys()): + stats = pertask_results[task] + total = stats["total"] + correct = stats["correct"] + invalid = stats["invalid"] + error = stats["error"] + + acc = (correct / total) if total > 0 else 0.0 + + print( + f"{task:<15} " + f"{acc:8.3f} " + f"{correct:8d} " + f"{total:8d} " + f"{invalid:8d} " + f"{error:8d}" + ) + + suite_total += total + suite_correct += correct + + # Overall summary + print("-" * 65) + suite_acc = (suite_correct / suite_total) if suite_total > 0 else 0.0 + print( + f"{'ALL':<15} " f"{suite_acc:8.3f} " f"{suite_correct:8d} " f"{suite_total:8d}" + ) def benchmark( path_server: str, - path_log: Optional[str], prompt_source: str, n_prompts: int, n_predict: int, rng_seed: int, ): - external_server: bool = path_server.startswith("http://") or path_server.startswith("https://") + if not path_server.startswith("http://") and not path_server.startswith("https://"): + logger.error("ERROR: malformed server path") + return + if os.environ.get("LLAMA_ARG_N_PARALLEL") is None: logger.info("LLAMA_ARG_N_PARALLEL not explicitly set, using 32") os.environ["LLAMA_ARG_N_PARALLEL"] = "32" - parallel: int = int(os.environ.get("LLAMA_ARG_N_PARALLEL")) # type: ignore - ds: Union[datasets.Dataset, None] = get_dataset(prompt_source, n_prompts, rng_seed) - if not ds: - logger.error("ERROR: get_dataset") - exit(0) + parallel: int = int(os.environ.get("LLAMA_ARG_N_PARALLEL")) # type: ignore - res: Union[tuple[list[str], list[str]], None] = get_prompts_text(prompt_source, ds) - if not res: - logger.error("ERROR: get_prompts_text") - exit(0) + task_queue: set[TaskSpec] = set() + for src in prompt_source.split(","): + if src == "all": + for v in TASK_DICT.values(): + task_queue.add(v()) + break + task_queue.add(TASK_DICT[src]()) - prompts: Union[list[str], list[list[int]]] = res[0] - answer: Union[list[str], list[list[int]]] = res[1] - - logger.info(prompts) - logger.info(f"external_server {external_server}") - - server: 
Optional[dict] = None session = None try: - server = get_server(path_server, path_log) - server_address: str = server["address"] - assert external_server == (server["process"] is None) + server_address: str = path_server adapter = requests.adapters.HTTPAdapter(pool_connections=parallel, pool_maxsize=parallel) # type: ignore session = requests.Session() session.mount("http://", adapter) session.mount("https://", adapter) + cases: list[Case] = [] data: list[dict] = [] - for p, a in zip(prompts, answer): - data.append( - { - "prompt_source": prompt_source, - "session": session, - "server_address": server_address, - "external_server": external_server, - "prompt": p, - "answer": a, - "n_predict": n_predict, - } - ) - + for task in task_queue: + for case in task.iter_cases(n_prompts, rng_seed): + cases.append(case) + data.append( + { + "prompt_source": prompt_source, + "session": session, + "server_address": server_address, + "n_predict": n_predict, + } + ) logger.info("Starting the benchmark...\n") t0 = time() - results: list[int] = thread_map( - send_prompt, data, max_workers=parallel, chunksize=1 + results: list[dict[str, Union[str, int]]] = thread_map( + send_prompt, + cases, + data, + max_workers=parallel, + chunksize=1, ) finally: - if server is not None and server["process"] is not None: - server["process"].terminate() - server["process"].wait() if session is not None: session.close() t1 = time() + logger.info(f"\nllama-eval duration: {t1-t0:.2f} s") - correct: int = sum(results) - total_questions: int = len(data) - logger.info(f"llama-eval duration: {t1-t0:.2f} s") - logger.info(f"{prompt_source} correct: {correct}") - logger.info(f"{prompt_source} total_questions: {total_questions}") - logger.info(f"{prompt_source} accuracy: {correct / total_questions}") + pertask_results = aggregate_by_task(results) + print_summary(pertask_results) if __name__ == "__main__": @@ -324,23 +569,17 @@ if __name__ == "__main__": parser.add_argument( "--path_server", type=str, - default="llama-server", - help="Path to the llama.cpp server binary", - ) - parser.add_argument( - "--path_log", - type=str, - default="server-bench-{port}.log", - help="Path to the model to use for the benchmark", + default="http://localhost:8033", + help="llama-server url", ) parser.add_argument( "--prompt_source", type=str, default="mmlu", - help="How to get the prompts for the benchmark, either 'mmlu' for MMLU questions", + help=f"Eval types supported: all,{TASK_DICT.keys()}", ) parser.add_argument( - "--n_prompts", type=int, default=100, help="Number of prompts to evaluate" + "--n_prompts", type=int, default=None, help="Number of prompts to evaluate" ) parser.add_argument( "--rng_seed", From b0d50a5681706ca965044ba378a3a1c9bf9883b7 Mon Sep 17 00:00:00 2001 From: gatbontonpc Date: Mon, 12 Jan 2026 13:53:39 -0500 Subject: [PATCH 3/4] Add readme --- examples/llama-eval/README.md | 20 ++++++++++++++++++++ examples/llama-eval/llama-eval.py | 2 +- 2 files changed, 21 insertions(+), 1 deletion(-) create mode 100644 examples/llama-eval/README.md diff --git a/examples/llama-eval/README.md b/examples/llama-eval/README.md new file mode 100644 index 0000000000..4dfaf09a22 --- /dev/null +++ b/examples/llama-eval/README.md @@ -0,0 +1,20 @@ +# llama.cpp/example/llama-eval + +The purpose of this example is to to run evaluations metrics against a an openapi api compatible LLM via http (llama-server). 
+ +```bash +./llama-server -m model.gguf --port 8033 +``` + +```bash +python examples/llama-eval/llama-eval.py --path_server http://localhost:8033 --n_prompt 100 --prompt_source arc +``` + +## Supported tasks (MVP) + +- **GSM8K** — grade-school math (final-answer only) +- **AIME** — competition math (final-answer only) +- **MMLU** — multi-domain knowledge (multiple choice) +- **HellaSwag** — commonsense reasoning (multiple choice) +- **ARC** — grade-school science reasoning (multiple choice) +- **WinoGrande** — commonsense coreference resolution (multiple choice) \ No newline at end of file diff --git a/examples/llama-eval/llama-eval.py b/examples/llama-eval/llama-eval.py index 411d0adbab..0ded50545c 100644 --- a/examples/llama-eval/llama-eval.py +++ b/examples/llama-eval/llama-eval.py @@ -576,7 +576,7 @@ if __name__ == "__main__": "--prompt_source", type=str, default="mmlu", - help=f"Eval types supported: all,{TASK_DICT.keys()}", + help=f"Eval types supported: all,{list(TASK_DICT.keys())}", ) parser.add_argument( "--n_prompts", type=int, default=None, help="Number of prompts to evaluate" From 979299a32f63e8804760ce50c47639305a567117 Mon Sep 17 00:00:00 2001 From: gatbontonpc Date: Fri, 16 Jan 2026 17:58:31 -0500 Subject: [PATCH 4/4] add checkpointing --- examples/llama-eval/README.md | 21 ++-- examples/llama-eval/llama-eval.py | 182 +++++++++++++++++++++++------- 2 files changed, 153 insertions(+), 50 deletions(-) diff --git a/examples/llama-eval/README.md b/examples/llama-eval/README.md index 4dfaf09a22..46224be3ec 100644 --- a/examples/llama-eval/README.md +++ b/examples/llama-eval/README.md @@ -1,20 +1,17 @@ # llama.cpp/example/llama-eval -The purpose of this example is to to run evaluations metrics against a an openapi api compatible LLM via http (llama-server). +`llama-eval.py` is a single-script evaluation runner that sends prompt/response pairs to any OpenAI-compatible HTTP server (the default `llama-server`). 
 
 ```bash
 ./llama-server -m model.gguf --port 8033
+python examples/llama-eval/llama-eval.py --path_server http://localhost:8033 --n_prompts 100 --prompt_source arc
 ```
 
-```bash
-python examples/llama-eval/llama-eval.py --path_server http://localhost:8033 --n_prompt 100 --prompt_source arc
-```
-
-## Supported tasks (MVP)
+The supported tasks are:
 
-- **GSM8K** — grade-school math (final-answer only)
-- **AIME** — competition math (final-answer only)
-- **MMLU** — multi-domain knowledge (multiple choice)
-- **HellaSwag** — commonsense reasoning (multiple choice)
-- **ARC** — grade-school science reasoning (multiple choice)
-- **WinoGrande** — commonsense coreference resolution (multiple choice)
\ No newline at end of file
+- **GSM8K** — grade-school math
+- **AIME** — competition math (integer answers)
+- **MMLU** — multi-domain multiple choice
+- **HellaSwag** — commonsense reasoning multiple choice
+- **ARC** — grade-school science multiple choice
+- **WinoGrande** — commonsense coreference multiple choice
diff --git a/examples/llama-eval/llama-eval.py b/examples/llama-eval/llama-eval.py
index 0ded50545c..78bfc0c2e4 100644
--- a/examples/llama-eval/llama-eval.py
+++ b/examples/llama-eval/llama-eval.py
@@ -10,9 +10,12 @@ import datasets
 import logging
 import requests
 from tqdm.contrib.concurrent import thread_map
-from typing import Iterator
+from typing import Iterator, Set
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
+from pathlib import Path
+import json
+import threading
 
 logging.basicConfig(level=logging.INFO, format='%(message)s')
 logger = logging.getLogger("llama-eval")
@@ -47,7 +50,7 @@ def extract_boxed_text(text: str) -> str:
         for group in match:
             if group != "":
                 return group.split(",")[-1].strip()
-    logger.warning("Could not extract boxed text. Maybe expand context window")
+    logger.debug("Could not extract boxed text. 
Maybe expand context window") return "" @@ -130,8 +133,9 @@ class MathTaskSpec(TaskSpec): try: extracted_answer = extract_boxed_text(response["choices"][0]["text"]) - except Exception as e: + except: result["status"] = "error" + logger.warning("ERROR: extract_boxed_text") return result source_answer = case.gold @@ -155,9 +159,12 @@ class ARC_Task(MCTaskSpec): def __init__(self): self.name = "arc" self.kind = "mc" + self.config = "ARC-Challenge" + self.split = "test" def load(self, limit, seed) -> datasets.Dataset: - ds = datasets.load_dataset("allenai/ai2_arc", "ARC-Challenge", split="test") + ds = datasets.load_dataset("allenai/ai2_arc", self.config, split=self.split) + ds = ds.add_column("_row_id", list(range(len(ds)))) if limit: ds = ds.shuffle(seed=seed) ds = ds.select(range(min(limit, len(ds)))) @@ -166,7 +173,7 @@ class ARC_Task(MCTaskSpec): def iter_cases(self, limit: int, seed: int) -> Iterator[Case]: ds = self.load(limit, seed) - for i, doc in enumerate(ds): + for doc in ds: doc = cast(Mapping[str, Any], doc) prompt, labels = format_multiple_choice( @@ -175,7 +182,7 @@ class ARC_Task(MCTaskSpec): yield Case( task=self.name, kind=self.kind, - case_id=f"ARC-Challenge:{i}", + case_id=f"ARC-Challenge_{self.config}_{self.split}_{doc['_row_id']}", prompt=prompt, gold=doc["answerKey"], meta_data={"labels": labels}, @@ -187,11 +194,13 @@ class WinoGrande_Task(MCTaskSpec): def __init__(self): self.name = "winogrande" self.kind = "mc" + self.config = "winogrande_debiased" + self.split = "validation" def load(self, limit, seed) -> datasets.Dataset: - ds = datasets.load_dataset( - "winogrande", "winogrande_debiased", split="validation" - ) + ds = datasets.load_dataset("winogrande", self.config, split=self.split) + + ds = ds.add_column("_row_id", list(range(len(ds)))) if limit: ds = ds.shuffle(seed=seed) ds = ds.select(range(min(limit, len(ds)))) @@ -200,7 +209,7 @@ class WinoGrande_Task(MCTaskSpec): def iter_cases(self, limit: int, seed: int) -> Iterator[Case]: ds = self.load(limit, seed) - for i, doc in enumerate(ds): + for doc in ds: doc = cast(Mapping[str, Any], doc) prompt, labels = format_multiple_choice( @@ -209,7 +218,7 @@ class WinoGrande_Task(MCTaskSpec): yield Case( task=self.name, kind=self.kind, - case_id=f"winogrande:{i}", + case_id=f"winogrande_{self.config}_{self.split}_{doc['_row_id']}", prompt=prompt, gold=labels[int(doc["answer"]) - 1], # winogrande answers are 1 based meta_data={"labels": labels}, @@ -221,9 +230,12 @@ class MMLU_Task(MCTaskSpec): def __init__(self): self.name = "mmlu" self.kind = "mc" + self.config = "all" + self.split = "test" def load(self, limit, seed) -> datasets.Dataset: - ds = datasets.load_dataset("cais/mmlu", "all", split="test") + ds = datasets.load_dataset("cais/mmlu", self.config, split=self.split) + ds = ds.add_column("_row_id", list(range(len(ds)))) if limit: ds = ds.shuffle(seed=seed) ds = ds.select(range(min(limit, len(ds)))) @@ -232,14 +244,14 @@ class MMLU_Task(MCTaskSpec): def iter_cases(self, limit: int, seed: int) -> Iterator[Case]: ds = self.load(limit, seed) - for i, doc in enumerate(ds): + for doc in ds: doc = cast(Mapping[str, Any], doc) prompt, labels = format_multiple_choice(doc["question"], doc["choices"]) yield Case( task=self.name, kind=self.kind, - case_id=f"mmlu:{doc['subject']}:{i}", + case_id=f"mmlu_{self.config}_{self.split}_{doc['subject']}_{doc['_row_id']}", prompt=prompt, gold=labels[int(doc["answer"])], meta_data={"subject": doc["subject"], "labels": labels}, @@ -285,12 +297,12 @@ class Hellaswag_Task(MCTaskSpec): def 
iter_cases(self, limit: int, seed: int) -> Iterator[Case]: ds = self.load(limit, seed) - for i, doc in enumerate(ds): + for doc in ds: doc = cast(Mapping[str, Any], doc) yield Case( task=self.name, kind=self.kind, - case_id=f"hellaswag:{i}", + case_id=f"hellaswag_{doc['split']}_{doc['ind']}", prompt=doc["prompt"], gold=doc["gold"], meta_data={}, @@ -302,9 +314,10 @@ class Aime_Task(MathTaskSpec): def __init__(self): self.name = "aime" self.kind = "math" + self.split = "train" def load(self, limit, seed) -> datasets.Dataset: - ds = datasets.load_dataset("AI-MO/aimo-validation-aime", split="train") + ds = datasets.load_dataset("AI-MO/aimo-validation-aime", split=self.split) if limit: ds = ds.shuffle(seed=seed) @@ -327,10 +340,10 @@ class Aime_Task(MathTaskSpec): yield Case( task=self.name, kind=self.kind, - case_id=f"aime:{i}", + case_id=f"aime_{self.split}_{doc['id']}", prompt=doc["prompt"], gold=doc["answer"], - meta_data={}, + meta_data={"id": doc["id"]}, ) @@ -339,9 +352,12 @@ class Gsm8k_Task(MathTaskSpec): def __init__(self): self.name = "gsm8k" self.kind = "math" + self.config = "main" + self.split = "test" def load(self, limit, seed) -> datasets.Dataset: - ds = datasets.load_dataset("openai/gsm8k", "main", split="test") + ds = datasets.load_dataset("openai/gsm8k", self.config, split=self.split) + ds = ds.add_column("_row_id", list(range(len(ds)))) if limit: ds = ds.shuffle(seed=seed) ds = ds.select(range(min(limit, len(ds)))) @@ -359,12 +375,12 @@ class Gsm8k_Task(MathTaskSpec): def iter_cases(self, limit: int, seed: int) -> Iterator[Case]: ds = self.load(limit, seed) - for i, doc in enumerate(ds): + for doc in ds: doc = cast(Mapping[str, Any], doc) yield Case( task=self.name, kind=self.kind, - case_id=f"gsm8k:{i}", + case_id=f"gsm8k_{self.config}_{self.split}:{doc['_row_id']}", prompt=doc["prompt"], gold=doc["gold"], meta_data={}, @@ -391,11 +407,21 @@ def build_request(case: Case, n_predict: int) -> dict[str, Any]: return json_data +def write_checkpoint_line( + checkpoint_file: Path, + row: dict[str, Any], + file_lock: threading.Lock, +): + with file_lock: + with checkpoint_file.open(mode="a", encoding="utf-8") as f: + f.write(json.dumps(row) + "\n") + + def send_prompt( case: Case, data: dict, ) -> dict[str, Union[str, int]]: - ret_err = { + result = { "task": case.task, "case_id": case.case_id, "status": "error", @@ -408,26 +434,29 @@ def send_prompt( server_address: str = data["server_address"] task = TASK_DICT.get(case.task) if task is None: - ret_err["error"] = f"unknown_task: {case.task}" - return ret_err + result["error"] = f"unknown_task: {case.task}" + return result logger.debug(case.prompt) json_data = build_request(case, data["n_predict"]) + res_json = {} try: response = session.post(f"{server_address}/v1/completions", json=json_data) - if response.ok: - res_json = response.json() - else: - ret_err["error"] = f"http_response: {response.status_code}" - logger.warning(ret_err["error"]) - return ret_err + res_json = response.json() + result["status"] = "ok" except Exception as e: - ret_err["error"] = f"http_exception: {e}" - logger.warning(ret_err["error"]) - return ret_err - logger.debug(response.text) - return TASK_DICT[case.task].grade(case, res_json) + result["error"] = f"http_exception: {e}" + logger.warning(result["error"]) + if result["status"] == "ok": + result = TASK_DICT[case.task].grade(case, res_json) + + write_checkpoint_line( + data["checkpoint_file"], + result.copy(), + data["file_lock"], + ) + return result def aggregate_by_task(results: list[dict[str, 
Any]]) -> dict[str, dict[str, int]]: tmp = { @@ -491,13 +520,52 @@ def print_summary(pertask_results: dict[str, dict[str, int]]): ) +def read_checkpoint( + checkpoint_file: Path, resume_flag: bool +) -> tuple[Set[str], Set[str], list[dict[str, Any]]]: + done = set() + errored = set() + results = [] + if not resume_flag or not checkpoint_file.is_file(): + return done, errored, results + + with checkpoint_file.open(mode="r", encoding="utf-8") as f: + for line in f: + line = line.strip() + if not line: + continue + try: + row = json.loads(line) + except Exception as e: + logger.warning(f"WARNING: malformed checkpoint line {line}\n{e}") + continue + + case_id = row.get("case_id") + if not case_id: + continue + + if row["status"] == "error": + errored.add(case_id) + else: + done.add(case_id) + results.append(row) + errored -= done + return done, errored, results + + def benchmark( path_server: str, prompt_source: str, n_prompts: int, n_predict: int, rng_seed: int, + resume_flag: bool, + checkpoint_file: Path, + log_level: int, ): + logger.setLevel(log_level) + done, errored, checkpoint_results = read_checkpoint(checkpoint_file, resume_flag) + if not path_server.startswith("http://") and not path_server.startswith("https://"): logger.error("ERROR: malformed server path") return @@ -524,11 +592,15 @@ def benchmark( session = requests.Session() session.mount("http://", adapter) session.mount("https://", adapter) - + file_lock = threading.Lock() cases: list[Case] = [] data: list[dict] = [] for task in task_queue: for case in task.iter_cases(n_prompts, rng_seed): + if case.case_id in done or case.case_id in errored: + logger.debug(f"Skipping case_id {case.case_id} from checkpoint") + continue + cases.append(case) data.append( { @@ -536,6 +608,8 @@ def benchmark( "session": session, "server_address": server_address, "n_predict": n_predict, + "file_lock": file_lock, + "checkpoint_file": checkpoint_file, } ) logger.info("Starting the benchmark...\n") @@ -553,7 +627,7 @@ def benchmark( t1 = time() logger.info(f"\nllama-eval duration: {t1-t0:.2f} s") - + results.extend(checkpoint_results) pertask_results = aggregate_by_task(results) print_summary(pertask_results) @@ -593,5 +667,37 @@ if __name__ == "__main__": default=2048, help="Max. number of tokens to predict per prompt", ) + parser.add_argument( + "--resume", + dest="resume_flag", + action="store_true", + default=True, + help="Enable resuming from last state stored in checkpoint file", + ) + parser.add_argument( + "--no-resume", + dest="resume_flag", + action="store_false", + help="Disble resuming from last state stored in checkpoint file", + ) + parser.add_argument( + "--checkpoint-file", + type=Path, + dest="checkpoint_file", + default="./llama-eval-checkpoint.jsonl", + help="Checkpoint file to read last state from", + ) + parser.set_defaults(log_level=logging.INFO) + parser.add_argument( + "--quiet", action="store_const", dest="log_level", const=logging.ERROR + ) + parser.add_argument( + "--debug", + action="store_const", + default=True, + dest="log_level", + const=logging.DEBUG, + ) + args = parser.parse_args() benchmark(**vars(args))
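
For reference, the checkpoint JSONL written by `write_checkpoint_line()` can also be re-aggregated offline, without re-running the server. Below is a minimal sketch under stated assumptions: it relies only on the row fields visible in the patch (`task`, `case_id`, `status`, `correct`) and on the default `./llama-eval-checkpoint.jsonl` path; the script name `summarize_checkpoint.py` is hypothetical and the file is not part of the patches above.

```python
#!/usr/bin/env python3
# summarize_checkpoint.py (hypothetical helper, not part of the patch series).
# Re-aggregates the checkpoint JSONL written by write_checkpoint_line() in
# llama-eval.py into a per-task summary, assuming each row carries:
# task, case_id, status, correct.

import json
import sys
from collections import defaultdict


def summarize(path: str) -> None:
    stats: dict[str, dict[str, int]] = defaultdict(
        lambda: {"total": 0, "correct": 0, "invalid": 0, "error": 0}
    )
    seen: set[str] = set()
    with open(path, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            row = json.loads(line)
            case_id = row.get("case_id")
            if case_id is None or case_id in seen:
                # skip malformed rows and duplicate case_ids (e.g. concatenated runs)
                continue
            seen.add(case_id)
            d = stats[row["task"]]
            d["total"] += 1
            status = row.get("status")
            if status == "ok":
                d["correct"] += int(row.get("correct", 0))
            elif status in ("invalid", "error"):
                d[status] += 1
    for task in sorted(stats):
        d = stats[task]
        acc = d["correct"] / d["total"] if d["total"] else 0.0
        print(
            f"{task:<15} acc={acc:.3f} correct={d['correct']:>5} "
            f"total={d['total']:>5} invalid={d['invalid']:>5} error={d['error']:>5}"
        )


if __name__ == "__main__":
    summarize(sys.argv[1] if len(sys.argv) > 1 else "llama-eval-checkpoint.jsonl")
```

Example usage, assuming a previous run left a checkpoint in the working directory: `python summarize_checkpoint.py ./llama-eval-checkpoint.jsonl`.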