multi source llama-eval
This commit is contained in:
parent
2357f6f193
commit
f3a5b4ea72
|
|
@ -2,91 +2,43 @@
|
||||||
|
|
||||||
import re
|
import re
|
||||||
import argparse
|
import argparse
|
||||||
import json
|
|
||||||
import os
|
import os
|
||||||
import random
|
from time import time
|
||||||
import subprocess
|
from typing import Union, Any, Mapping, cast
|
||||||
from time import sleep, time
|
|
||||||
from typing import Optional, Union
|
|
||||||
|
|
||||||
import datasets
|
import datasets
|
||||||
import logging
|
import logging
|
||||||
import requests
|
import requests
|
||||||
from tqdm.contrib.concurrent import thread_map
|
from tqdm.contrib.concurrent import thread_map
|
||||||
from typing import Iterator
|
from typing import Iterator
|
||||||
from abc import ABC
|
from abc import ABC, abstractmethod
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO, format='%(message)s')
|
logging.basicConfig(level=logging.INFO, format='%(message)s')
|
||||||
logger = logging.getLogger("llama-eval")
|
logger = logging.getLogger("llama-eval")
|
||||||
|
|
||||||
|
|
||||||
MATH_TEMPLATE = """
|
MATH_TEMPLATE = """
|
||||||
{question}
|
{question}
|
||||||
Put your final answer within \\boxed{{}}.
|
Do not include any explanation. Put your final answer within \\boxed{{}}.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
MC_FROM_INT = {
|
|
||||||
0: "A",
|
|
||||||
1: "B",
|
|
||||||
2: "C",
|
|
||||||
3: "D",
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def format_multiple_choice(prompt: str, choices: list[str]):
|
def format_multiple_choice(prompt: str, choices: list[str]):
|
||||||
QUERY_TEMPLATE_MULTICHOICE = """
|
lines = [prompt]
|
||||||
{question}
|
|
||||||
|
|
||||||
(A) {A}
|
labels = [chr(ord("A") + i) for i in range(len(choices))]
|
||||||
(B) {B}
|
for l, c in zip(labels, choices):
|
||||||
(C) {C}
|
lines.append(f"({l}): {c.strip()}")
|
||||||
(D) {D}
|
lines.append(
|
||||||
|
"Do not include any explanation. Answer with the corresponding option letter only"
|
||||||
Express your final answer as the corresponding option 'A', 'B', 'C', or 'D'. Put your final answer within \\boxed{{}}.
|
|
||||||
|
|
||||||
""".strip()
|
|
||||||
A_str = choices[0]
|
|
||||||
B_str = choices[1]
|
|
||||||
C_str = choices[2]
|
|
||||||
D_str = choices[3]
|
|
||||||
query = QUERY_TEMPLATE_MULTICHOICE.format(
|
|
||||||
question=prompt, A=A_str, B=B_str, C=C_str, D=D_str
|
|
||||||
)
|
)
|
||||||
return query
|
lines.append(", ".join(labels))
|
||||||
|
lines.append("Put your final answer within \\boxed{{}}.")
|
||||||
|
|
||||||
|
return "\n".join(lines), labels
|
||||||
|
|
||||||
|
|
||||||
# Preprocess hellaswag
|
def extract_boxed_text(text: str) -> str:
|
||||||
def preprocess(text):
|
|
||||||
text = text.strip()
|
|
||||||
# NOTE: Brackets are artifacts of the WikiHow dataset portion of HellaSwag.
|
|
||||||
text = text.replace(" [title]", ". ")
|
|
||||||
text = re.sub("\\[.*?\\]", "", text)
|
|
||||||
text = text.replace(" ", " ")
|
|
||||||
return text
|
|
||||||
|
|
||||||
|
|
||||||
def hellaswag_process_doc(doc):
|
|
||||||
ctx = doc["ctx_a"] + " " + doc["ctx_b"].capitalize()
|
|
||||||
question = preprocess(doc["activity_label"] + ": " + ctx)
|
|
||||||
proc_answers = [preprocess(answer) for answer in doc["endings"]]
|
|
||||||
prompt = format_multiple_choice(question, proc_answers)
|
|
||||||
out_doc = {
|
|
||||||
"prompt": prompt,
|
|
||||||
"gold": MC_FROM_INT[int(doc["label"])],
|
|
||||||
}
|
|
||||||
return out_doc
|
|
||||||
|
|
||||||
|
|
||||||
def mmlu_process_doc(doc):
|
|
||||||
prompt = format_multiple_choice(doc["question"], doc["choices"])
|
|
||||||
out_doc = {
|
|
||||||
"prompt": prompt,
|
|
||||||
"gold": MC_FROM_INT[int(doc["answer"])],
|
|
||||||
}
|
|
||||||
return out_doc
|
|
||||||
|
|
||||||
|
|
||||||
def extract_boxed_text(text):
|
|
||||||
pattern = r"boxed{(.*?)}|framebox{(.*?)}"
|
pattern = r"boxed{(.*?)}|framebox{(.*?)}"
|
||||||
matches = re.findall(pattern, text, re.DOTALL)
|
matches = re.findall(pattern, text, re.DOTALL)
|
||||||
logger.debug(matches)
|
logger.debug(matches)
|
||||||
|
|
@ -95,222 +47,515 @@ def extract_boxed_text(text):
|
||||||
for group in match:
|
for group in match:
|
||||||
if group != "":
|
if group != "":
|
||||||
return group.split(",")[-1].strip()
|
return group.split(",")[-1].strip()
|
||||||
logger.warning(
|
logger.warning("Could not extract boxed text. Maybe expand context window")
|
||||||
"Could not extract boxed text. Using last integer. Maybe expand context window"
|
|
||||||
)
|
|
||||||
pattern = r"\d+" # get the last integer if no pattern found
|
|
||||||
matches = re.findall(pattern, text, re.DOTALL)
|
|
||||||
if matches:
|
|
||||||
return matches[-1]
|
|
||||||
|
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
|
|
||||||
def get_prompts_text(
|
@dataclass(frozen=True)
|
||||||
dataset_name: str, ds: datasets.Dataset
|
class Case:
|
||||||
) -> Optional[tuple[list[str], list[str]]]:
|
task: str
|
||||||
ret = []
|
kind: str
|
||||||
if dataset_name.lower() == "mmlu":
|
case_id: str
|
||||||
ds = ds.map(mmlu_process_doc)
|
prompt: str
|
||||||
ret = ds["prompt"], ds["gold"]
|
gold: str
|
||||||
elif dataset_name.lower() == "hellaswag":
|
meta_data: dict[str, Any]
|
||||||
ds = ds.map(hellaswag_process_doc)
|
|
||||||
ret = ds["prompt"], ds["gold"]
|
|
||||||
elif dataset_name.lower() == "aime":
|
class TaskSpec(ABC):
|
||||||
|
name: str
|
||||||
|
kind: str
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def load(self, limit, seed) -> datasets.Dataset:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def iter_cases(self, limit: int, seed: int) -> Iterator[Case]:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
@abstractmethod
|
||||||
|
def grade(case: Case, response: dict) -> dict[str, Any]:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class MCTaskSpec(TaskSpec):
|
||||||
|
@staticmethod
|
||||||
|
def grade(case: Case, response: dict) -> dict[str, Any]:
|
||||||
|
logger.debug(f"response {response}")
|
||||||
|
result = {
|
||||||
|
"task": case.task,
|
||||||
|
"case_id": case.case_id,
|
||||||
|
"correct": 0,
|
||||||
|
"pred": None,
|
||||||
|
"gold": case.gold,
|
||||||
|
"status": "ok",
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
extracted_answer = extract_boxed_text(response["choices"][0]["text"])
|
||||||
|
except Exception as e:
|
||||||
|
result["status"] = "error"
|
||||||
|
logger.warning("ERROR: extract_boxed_text")
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
if not extracted_answer:
|
||||||
|
result["status"] = "invalid"
|
||||||
|
logger.warning("INVALID: extract_boxed_text")
|
||||||
|
return result
|
||||||
|
|
||||||
|
logger.debug(f"extracted_answer {extracted_answer}")
|
||||||
|
logger.debug(f"data['answer'] {case.gold}")
|
||||||
|
result["pred"] = extracted_answer
|
||||||
|
result["correct"] = 1 if extracted_answer == case.gold else 0
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
class MathTaskSpec(TaskSpec):
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def grade(case: Case, response: dict) -> dict[str, Any]:
|
||||||
|
logger.debug(f"response {response}")
|
||||||
|
result = {
|
||||||
|
"task": case.task,
|
||||||
|
"case_id": case.case_id,
|
||||||
|
"correct": 0,
|
||||||
|
"gold": case.gold,
|
||||||
|
"status": "ok",
|
||||||
|
"pred": None,
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
extracted_answer = extract_boxed_text(response["choices"][0]["text"])
|
||||||
|
except Exception as e:
|
||||||
|
result["status"] = "error"
|
||||||
|
return result
|
||||||
|
|
||||||
|
source_answer = case.gold
|
||||||
|
try: # All AIME answers are integers, so we convert the extracted answer to an integer
|
||||||
|
extracted_answer = int(extracted_answer)
|
||||||
|
source_answer = int(case.gold)
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
result["status"] = "invalid"
|
||||||
|
return result
|
||||||
|
|
||||||
|
logger.debug(f"extracted_answer {extracted_answer}")
|
||||||
|
logger.debug(f"data['answer'] {case.gold}")
|
||||||
|
result["pred"] = extracted_answer
|
||||||
|
result["correct"] = 1 if extracted_answer == source_answer else 0
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
class ARC_Task(MCTaskSpec):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.name = "arc"
|
||||||
|
self.kind = "mc"
|
||||||
|
|
||||||
|
def load(self, limit, seed) -> datasets.Dataset:
|
||||||
|
ds = datasets.load_dataset("allenai/ai2_arc", "ARC-Challenge", split="test")
|
||||||
|
if limit:
|
||||||
|
ds = ds.shuffle(seed=seed)
|
||||||
|
ds = ds.select(range(min(limit, len(ds))))
|
||||||
|
return ds
|
||||||
|
|
||||||
|
def iter_cases(self, limit: int, seed: int) -> Iterator[Case]:
|
||||||
|
ds = self.load(limit, seed)
|
||||||
|
|
||||||
|
for i, doc in enumerate(ds):
|
||||||
|
doc = cast(Mapping[str, Any], doc)
|
||||||
|
|
||||||
|
prompt, labels = format_multiple_choice(
|
||||||
|
doc["question"], doc["choices"]["text"]
|
||||||
|
)
|
||||||
|
yield Case(
|
||||||
|
task=self.name,
|
||||||
|
kind=self.kind,
|
||||||
|
case_id=f"ARC-Challenge:{i}",
|
||||||
|
prompt=prompt,
|
||||||
|
gold=doc["answerKey"],
|
||||||
|
meta_data={"labels": labels},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class WinoGrande_Task(MCTaskSpec):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.name = "winogrande"
|
||||||
|
self.kind = "mc"
|
||||||
|
|
||||||
|
def load(self, limit, seed) -> datasets.Dataset:
|
||||||
|
ds = datasets.load_dataset(
|
||||||
|
"winogrande", "winogrande_debiased", split="validation"
|
||||||
|
)
|
||||||
|
if limit:
|
||||||
|
ds = ds.shuffle(seed=seed)
|
||||||
|
ds = ds.select(range(min(limit, len(ds))))
|
||||||
|
return ds
|
||||||
|
|
||||||
|
def iter_cases(self, limit: int, seed: int) -> Iterator[Case]:
|
||||||
|
ds = self.load(limit, seed)
|
||||||
|
|
||||||
|
for i, doc in enumerate(ds):
|
||||||
|
doc = cast(Mapping[str, Any], doc)
|
||||||
|
|
||||||
|
prompt, labels = format_multiple_choice(
|
||||||
|
doc["sentence"], [doc["option1"], doc["option2"]]
|
||||||
|
)
|
||||||
|
yield Case(
|
||||||
|
task=self.name,
|
||||||
|
kind=self.kind,
|
||||||
|
case_id=f"winogrande:{i}",
|
||||||
|
prompt=prompt,
|
||||||
|
gold=labels[int(doc["answer"]) - 1], # winogrande answers are 1 based
|
||||||
|
meta_data={"labels": labels},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class MMLU_Task(MCTaskSpec):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.name = "mmlu"
|
||||||
|
self.kind = "mc"
|
||||||
|
|
||||||
|
def load(self, limit, seed) -> datasets.Dataset:
|
||||||
|
ds = datasets.load_dataset("cais/mmlu", "all", split="test")
|
||||||
|
if limit:
|
||||||
|
ds = ds.shuffle(seed=seed)
|
||||||
|
ds = ds.select(range(min(limit, len(ds))))
|
||||||
|
return ds
|
||||||
|
|
||||||
|
def iter_cases(self, limit: int, seed: int) -> Iterator[Case]:
|
||||||
|
ds = self.load(limit, seed)
|
||||||
|
|
||||||
|
for i, doc in enumerate(ds):
|
||||||
|
doc = cast(Mapping[str, Any], doc)
|
||||||
|
|
||||||
|
prompt, labels = format_multiple_choice(doc["question"], doc["choices"])
|
||||||
|
yield Case(
|
||||||
|
task=self.name,
|
||||||
|
kind=self.kind,
|
||||||
|
case_id=f"mmlu:{doc['subject']}:{i}",
|
||||||
|
prompt=prompt,
|
||||||
|
gold=labels[int(doc["answer"])],
|
||||||
|
meta_data={"subject": doc["subject"], "labels": labels},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Hellaswag_Task(MCTaskSpec):
|
||||||
|
|
||||||
|
# Preprocess hellaswag
|
||||||
|
@staticmethod
|
||||||
|
def preprocess(text: str):
|
||||||
|
text = text.strip()
|
||||||
|
# NOTE: Brackets are artifacts of the WikiHow dataset portion of HellaSwag.
|
||||||
|
text = text.replace(" [title]", ". ")
|
||||||
|
text = re.sub("\\[.*?\\]", "", text)
|
||||||
|
text = text.replace(" ", " ")
|
||||||
|
return text
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def hellaswag_process_doc(doc: dict[str, str]):
|
||||||
|
ctx = doc["ctx_a"] + " " + doc["ctx_b"].capitalize()
|
||||||
|
question = Hellaswag_Task.preprocess(doc["activity_label"] + ": " + ctx)
|
||||||
|
proc_answers = [Hellaswag_Task.preprocess(answer) for answer in doc["endings"]]
|
||||||
|
prompt, labels = format_multiple_choice(question, proc_answers)
|
||||||
|
out_doc = {
|
||||||
|
"prompt": prompt,
|
||||||
|
"gold": labels[int(doc["label"])],
|
||||||
|
}
|
||||||
|
return out_doc
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.name = "hellaswag"
|
||||||
|
self.kind = "mc"
|
||||||
|
|
||||||
|
def load(self, limit, seed) -> datasets.Dataset:
|
||||||
|
ds = datasets.load_dataset("Rowan/hellaswag", split="validation")
|
||||||
|
if limit:
|
||||||
|
ds = ds.shuffle(seed=seed)
|
||||||
|
ds = ds.select(range(min(limit, len(ds))))
|
||||||
|
ds = ds.map(Hellaswag_Task.hellaswag_process_doc)
|
||||||
|
|
||||||
|
return ds
|
||||||
|
|
||||||
|
def iter_cases(self, limit: int, seed: int) -> Iterator[Case]:
|
||||||
|
ds = self.load(limit, seed)
|
||||||
|
for i, doc in enumerate(ds):
|
||||||
|
doc = cast(Mapping[str, Any], doc)
|
||||||
|
yield Case(
|
||||||
|
task=self.name,
|
||||||
|
kind=self.kind,
|
||||||
|
case_id=f"hellaswag:{i}",
|
||||||
|
prompt=doc["prompt"],
|
||||||
|
gold=doc["gold"],
|
||||||
|
meta_data={},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Aime_Task(MathTaskSpec):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.name = "aime"
|
||||||
|
self.kind = "math"
|
||||||
|
|
||||||
|
def load(self, limit, seed) -> datasets.Dataset:
|
||||||
|
ds = datasets.load_dataset("AI-MO/aimo-validation-aime", split="train")
|
||||||
|
|
||||||
|
if limit:
|
||||||
|
ds = ds.shuffle(seed=seed)
|
||||||
|
ds = ds.select(range(min(limit, len(ds))))
|
||||||
|
|
||||||
ds = ds.map(
|
ds = ds.map(
|
||||||
lambda k: {
|
lambda ex: {
|
||||||
"prompt": MATH_TEMPLATE.format(
|
"prompt": MATH_TEMPLATE.format(
|
||||||
question=k["problem"],
|
question=ex["problem"],
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
ret = ds["prompt"], ds["answer"]
|
return ds
|
||||||
elif dataset_name.lower() == "gsm8k":
|
|
||||||
ds = ds.map(lambda k: {"prompt": MATH_TEMPLATE.format(question=k["question"])})
|
|
||||||
la = []
|
|
||||||
for answer in ds["answer"]:
|
|
||||||
la.append(answer.split("### ")[-1].rstrip())
|
|
||||||
ret = ds["prompt"], la
|
|
||||||
else:
|
|
||||||
return None
|
|
||||||
|
|
||||||
return ret
|
def iter_cases(self, limit: int, seed: int) -> Iterator[Case]:
|
||||||
|
ds = self.load(limit, seed)
|
||||||
|
|
||||||
|
for i, doc in enumerate(ds):
|
||||||
|
doc = cast(Mapping[str, Any], doc)
|
||||||
|
yield Case(
|
||||||
|
task=self.name,
|
||||||
|
kind=self.kind,
|
||||||
|
case_id=f"aime:{i}",
|
||||||
|
prompt=doc["prompt"],
|
||||||
|
gold=doc["answer"],
|
||||||
|
meta_data={},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_dataset(
|
class Gsm8k_Task(MathTaskSpec):
|
||||||
dataset_name: str, n_prompts: int, rng_seed: int
|
|
||||||
) -> Optional[datasets.Dataset]:
|
def __init__(self):
|
||||||
ds = None
|
self.name = "gsm8k"
|
||||||
cache_dir = "./build/bin/datasets"
|
self.kind = "math"
|
||||||
logger.info(f"Loading {dataset_name.lower()} dataset...")
|
|
||||||
if dataset_name.lower() == "mmlu":
|
def load(self, limit, seed) -> datasets.Dataset:
|
||||||
ds = datasets.load_dataset(
|
ds = datasets.load_dataset("openai/gsm8k", "main", split="test")
|
||||||
"cais/mmlu", "all", split="test", cache_dir=cache_dir
|
if limit:
|
||||||
|
ds = ds.shuffle(seed=seed)
|
||||||
|
ds = ds.select(range(min(limit, len(ds))))
|
||||||
|
|
||||||
|
ds = ds.map(
|
||||||
|
lambda k: {
|
||||||
|
"prompt": MATH_TEMPLATE.format(
|
||||||
|
question=k["question"],
|
||||||
|
),
|
||||||
|
"gold": k["answer"].split("### ")[-1].rstrip(),
|
||||||
|
}
|
||||||
)
|
)
|
||||||
elif dataset_name.lower() == "hellaswag":
|
return ds
|
||||||
ds = datasets.load_dataset(
|
|
||||||
"Rowan/hellaswag", split="validation", cache_dir=cache_dir
|
|
||||||
)
|
|
||||||
elif dataset_name.lower() == "aime":
|
|
||||||
ds = datasets.load_dataset(
|
|
||||||
"AI-MO/aimo-validation-aime", split="train", cache_dir=cache_dir
|
|
||||||
)
|
|
||||||
elif dataset_name.lower() == "gsm8k":
|
|
||||||
ds = datasets.load_dataset("openai/gsm8k", split="test")
|
|
||||||
else:
|
|
||||||
return None
|
|
||||||
|
|
||||||
if n_prompts >= 0:
|
def iter_cases(self, limit: int, seed: int) -> Iterator[Case]:
|
||||||
ds = ds.shuffle(seed=rng_seed)
|
ds = self.load(limit, seed)
|
||||||
ds = ds.select(range(min(n_prompts, len(ds))))
|
|
||||||
return ds
|
for i, doc in enumerate(ds):
|
||||||
|
doc = cast(Mapping[str, Any], doc)
|
||||||
|
yield Case(
|
||||||
|
task=self.name,
|
||||||
|
kind=self.kind,
|
||||||
|
case_id=f"gsm8k:{i}",
|
||||||
|
prompt=doc["prompt"],
|
||||||
|
gold=doc["gold"],
|
||||||
|
meta_data={},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def send_prompt(data: dict) -> int:
|
TASK_DICT: dict[str, type[TaskSpec]] = {
|
||||||
session = data["session"]
|
"mmlu": MMLU_Task,
|
||||||
server_address: str = data["server_address"]
|
"aime": Aime_Task,
|
||||||
prompt: str = data["prompt"]
|
"gsm8k": Gsm8k_Task,
|
||||||
logger.info(f"data['external_server'] {data['external_server']}")
|
"hellaswag": Hellaswag_Task,
|
||||||
logger.info(f"data['prompt'] {prompt}")
|
"arc": ARC_Task,
|
||||||
logger.info(f"data['n_predict'] {data['n_predict']}")
|
"winogrande": WinoGrande_Task,
|
||||||
|
}
|
||||||
|
|
||||||
json_data: dict = {
|
|
||||||
"prompt": prompt,
|
def build_request(case: Case, n_predict: int) -> dict[str, Any]:
|
||||||
"max_tokens": data["n_predict"],
|
json_data = {
|
||||||
|
"n_predict": n_predict,
|
||||||
|
"max_tokens": n_predict,
|
||||||
"temperature": 0,
|
"temperature": 0,
|
||||||
|
"prompt": case.prompt,
|
||||||
}
|
}
|
||||||
response = session.post(f"{server_address}/v1/completions", json=json_data)
|
return json_data
|
||||||
res = json.loads(response.text)
|
|
||||||
logger.info(f"response {res}")
|
|
||||||
extracted_answer = extract_boxed_text(res["choices"][0]["text"])
|
|
||||||
source_answer = data["answer"]
|
|
||||||
if data["prompt_source"] == "aime" or data["prompt_source"] == "gsm8k":
|
|
||||||
try: # All AIME answers are integers, so we convert the extracted answer to an integer
|
|
||||||
extracted_answer = int(extracted_answer)
|
|
||||||
source_answer = int(source_answer)
|
|
||||||
except (ValueError, TypeError):
|
|
||||||
extracted_answer = None
|
|
||||||
logger.info(f"extracted_answer {extracted_answer}")
|
|
||||||
logger.info(f"data['answer'] {data['answer']}")
|
|
||||||
|
|
||||||
score = 1 if extracted_answer == source_answer else 0
|
|
||||||
|
|
||||||
return score
|
|
||||||
|
|
||||||
|
|
||||||
def get_server(path_server: str, path_log: Optional[str]) -> dict:
|
def send_prompt(
|
||||||
if path_server.startswith("http://") or path_server.startswith("https://"):
|
case: Case,
|
||||||
return {"process": None, "address": path_server, "fout": None}
|
data: dict,
|
||||||
if os.environ.get("LLAMA_ARG_HOST") is None:
|
) -> dict[str, Union[str, int]]:
|
||||||
logger.info("LLAMA_ARG_HOST not explicitly set, using 127.0.0.1")
|
ret_err = {
|
||||||
os.environ["LLAMA_ARG_HOST"] = "127.0.0.1"
|
"task": case.task,
|
||||||
if os.environ.get("LLAMA_ARG_PORT") is None:
|
"case_id": case.case_id,
|
||||||
logger.info("LLAMA_ARG_PORT not explicitly set, using 8080")
|
"status": "error",
|
||||||
os.environ["LLAMA_ARG_PORT"] = "8080"
|
"correct": 0,
|
||||||
hostname: Optional[str] = os.environ.get("LLAMA_ARG_HOST")
|
"gold": case.gold,
|
||||||
port: Optional[str] = os.environ.get("LLAMA_ARG_PORT")
|
"pred": "",
|
||||||
assert hostname is not None
|
"error": "",
|
||||||
assert port is not None
|
}
|
||||||
address: str = f"http://{hostname}:{port}"
|
session: requests.Session = data["session"]
|
||||||
logger.info(f"Starting the llama.cpp server under {address}...")
|
server_address: str = data["server_address"]
|
||||||
|
task = TASK_DICT.get(case.task)
|
||||||
|
if task is None:
|
||||||
|
ret_err["error"] = f"unknown_task: {case.task}"
|
||||||
|
return ret_err
|
||||||
|
logger.debug(case.prompt)
|
||||||
|
|
||||||
fout = open(path_log.format(port=port), "w") if path_log is not None else subprocess.DEVNULL
|
json_data = build_request(case, data["n_predict"])
|
||||||
process = subprocess.Popen([path_server], stdout=fout, stderr=subprocess.STDOUT)
|
try:
|
||||||
|
response = session.post(f"{server_address}/v1/completions", json=json_data)
|
||||||
|
if response.ok:
|
||||||
|
res_json = response.json()
|
||||||
|
else:
|
||||||
|
ret_err["error"] = f"http_response: {response.status_code}"
|
||||||
|
logger.warning(ret_err["error"])
|
||||||
|
return ret_err
|
||||||
|
except Exception as e:
|
||||||
|
ret_err["error"] = f"http_exception: {e}"
|
||||||
|
logger.warning(ret_err["error"])
|
||||||
|
return ret_err
|
||||||
|
logger.debug(response.text)
|
||||||
|
return TASK_DICT[case.task].grade(case, res_json)
|
||||||
|
|
||||||
n_failures: int = 0
|
|
||||||
while True:
|
|
||||||
try:
|
|
||||||
sleep(1.0)
|
|
||||||
exit_code = process.poll()
|
|
||||||
if exit_code is not None:
|
|
||||||
raise RuntimeError(f"llama.cpp server exited unexpectedly with exit code {exit_code}{path_log and f', see {path_log.format(port=port)}' or ''}")
|
|
||||||
response = requests.get(f"{address}/health")
|
|
||||||
if response.status_code == 200:
|
|
||||||
break
|
|
||||||
except requests.ConnectionError:
|
|
||||||
n_failures += 1
|
|
||||||
if n_failures >= 10:
|
|
||||||
raise RuntimeError("llama.cpp server is not healthy after 10 seconds")
|
|
||||||
|
|
||||||
return {"process": process, "address": address, "fout": fout}
|
def aggregate_by_task(results: list[dict[str, Any]]) -> dict[str, dict[str, int]]:
|
||||||
|
tmp = {
|
||||||
|
"total": 0,
|
||||||
|
"error": 0,
|
||||||
|
"invalid": 0,
|
||||||
|
"correct": 0,
|
||||||
|
}
|
||||||
|
agg: dict[str, dict[str, int]] = {}
|
||||||
|
for row in results:
|
||||||
|
d = agg.get(row["task"], tmp.copy())
|
||||||
|
d["total"] += 1
|
||||||
|
status = row["status"]
|
||||||
|
if status == "ok":
|
||||||
|
d["correct"] += row["correct"]
|
||||||
|
elif status == "invalid":
|
||||||
|
d["invalid"] += 1
|
||||||
|
elif status == "error":
|
||||||
|
d["error"] += 1
|
||||||
|
|
||||||
|
agg[row["task"]] = d
|
||||||
|
return agg
|
||||||
|
|
||||||
|
|
||||||
|
def print_summary(pertask_results: dict[str, dict[str, int]]):
|
||||||
|
print("\n=== llama-eval suite summary ===")
|
||||||
|
print(
|
||||||
|
f"{'Task':<15} {'Acc':>8} {'Correct':>8} {'Total':>8} {'Invalid':>8} {'Error':>8}"
|
||||||
|
)
|
||||||
|
print("-" * 65)
|
||||||
|
|
||||||
|
suite_total = 0
|
||||||
|
suite_correct = 0
|
||||||
|
|
||||||
|
for task in sorted(pertask_results.keys()):
|
||||||
|
stats = pertask_results[task]
|
||||||
|
total = stats["total"]
|
||||||
|
correct = stats["correct"]
|
||||||
|
invalid = stats["invalid"]
|
||||||
|
error = stats["error"]
|
||||||
|
|
||||||
|
acc = (correct / total) if total > 0 else 0.0
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"{task:<15} "
|
||||||
|
f"{acc:8.3f} "
|
||||||
|
f"{correct:8d} "
|
||||||
|
f"{total:8d} "
|
||||||
|
f"{invalid:8d} "
|
||||||
|
f"{error:8d}"
|
||||||
|
)
|
||||||
|
|
||||||
|
suite_total += total
|
||||||
|
suite_correct += correct
|
||||||
|
|
||||||
|
# Overall summary
|
||||||
|
print("-" * 65)
|
||||||
|
suite_acc = (suite_correct / suite_total) if suite_total > 0 else 0.0
|
||||||
|
print(
|
||||||
|
f"{'ALL':<15} " f"{suite_acc:8.3f} " f"{suite_correct:8d} " f"{suite_total:8d}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def benchmark(
|
def benchmark(
|
||||||
path_server: str,
|
path_server: str,
|
||||||
path_log: Optional[str],
|
|
||||||
prompt_source: str,
|
prompt_source: str,
|
||||||
n_prompts: int,
|
n_prompts: int,
|
||||||
n_predict: int,
|
n_predict: int,
|
||||||
rng_seed: int,
|
rng_seed: int,
|
||||||
):
|
):
|
||||||
external_server: bool = path_server.startswith("http://") or path_server.startswith("https://")
|
if not path_server.startswith("http://") and not path_server.startswith("https://"):
|
||||||
|
logger.error("ERROR: malformed server path")
|
||||||
|
return
|
||||||
|
|
||||||
if os.environ.get("LLAMA_ARG_N_PARALLEL") is None:
|
if os.environ.get("LLAMA_ARG_N_PARALLEL") is None:
|
||||||
logger.info("LLAMA_ARG_N_PARALLEL not explicitly set, using 32")
|
logger.info("LLAMA_ARG_N_PARALLEL not explicitly set, using 32")
|
||||||
os.environ["LLAMA_ARG_N_PARALLEL"] = "32"
|
os.environ["LLAMA_ARG_N_PARALLEL"] = "32"
|
||||||
|
|
||||||
parallel: int = int(os.environ.get("LLAMA_ARG_N_PARALLEL")) # type: ignore
|
parallel: int = int(os.environ.get("LLAMA_ARG_N_PARALLEL")) # type: ignore
|
||||||
ds: Union[datasets.Dataset, None] = get_dataset(prompt_source, n_prompts, rng_seed)
|
|
||||||
if not ds:
|
|
||||||
logger.error("ERROR: get_dataset")
|
|
||||||
exit(0)
|
|
||||||
|
|
||||||
res: Union[tuple[list[str], list[str]], None] = get_prompts_text(prompt_source, ds)
|
task_queue: set[TaskSpec] = set()
|
||||||
if not res:
|
for src in prompt_source.split(","):
|
||||||
logger.error("ERROR: get_prompts_text")
|
if src == "all":
|
||||||
exit(0)
|
for v in TASK_DICT.values():
|
||||||
|
task_queue.add(v())
|
||||||
|
break
|
||||||
|
task_queue.add(TASK_DICT[src]())
|
||||||
|
|
||||||
prompts: Union[list[str], list[list[int]]] = res[0]
|
|
||||||
answer: Union[list[str], list[list[int]]] = res[1]
|
|
||||||
|
|
||||||
logger.info(prompts)
|
|
||||||
logger.info(f"external_server {external_server}")
|
|
||||||
|
|
||||||
server: Optional[dict] = None
|
|
||||||
session = None
|
session = None
|
||||||
try:
|
try:
|
||||||
server = get_server(path_server, path_log)
|
server_address: str = path_server
|
||||||
server_address: str = server["address"]
|
|
||||||
assert external_server == (server["process"] is None)
|
|
||||||
|
|
||||||
adapter = requests.adapters.HTTPAdapter(pool_connections=parallel, pool_maxsize=parallel) # type: ignore
|
adapter = requests.adapters.HTTPAdapter(pool_connections=parallel, pool_maxsize=parallel) # type: ignore
|
||||||
session = requests.Session()
|
session = requests.Session()
|
||||||
session.mount("http://", adapter)
|
session.mount("http://", adapter)
|
||||||
session.mount("https://", adapter)
|
session.mount("https://", adapter)
|
||||||
|
|
||||||
|
cases: list[Case] = []
|
||||||
data: list[dict] = []
|
data: list[dict] = []
|
||||||
for p, a in zip(prompts, answer):
|
for task in task_queue:
|
||||||
data.append(
|
for case in task.iter_cases(n_prompts, rng_seed):
|
||||||
{
|
cases.append(case)
|
||||||
"prompt_source": prompt_source,
|
data.append(
|
||||||
"session": session,
|
{
|
||||||
"server_address": server_address,
|
"prompt_source": prompt_source,
|
||||||
"external_server": external_server,
|
"session": session,
|
||||||
"prompt": p,
|
"server_address": server_address,
|
||||||
"answer": a,
|
"n_predict": n_predict,
|
||||||
"n_predict": n_predict,
|
}
|
||||||
}
|
)
|
||||||
)
|
|
||||||
|
|
||||||
logger.info("Starting the benchmark...\n")
|
logger.info("Starting the benchmark...\n")
|
||||||
t0 = time()
|
t0 = time()
|
||||||
results: list[int] = thread_map(
|
results: list[dict[str, Union[str, int]]] = thread_map(
|
||||||
send_prompt, data, max_workers=parallel, chunksize=1
|
send_prompt,
|
||||||
|
cases,
|
||||||
|
data,
|
||||||
|
max_workers=parallel,
|
||||||
|
chunksize=1,
|
||||||
)
|
)
|
||||||
finally:
|
finally:
|
||||||
if server is not None and server["process"] is not None:
|
|
||||||
server["process"].terminate()
|
|
||||||
server["process"].wait()
|
|
||||||
if session is not None:
|
if session is not None:
|
||||||
session.close()
|
session.close()
|
||||||
|
|
||||||
t1 = time()
|
t1 = time()
|
||||||
|
logger.info(f"\nllama-eval duration: {t1-t0:.2f} s")
|
||||||
|
|
||||||
correct: int = sum(results)
|
pertask_results = aggregate_by_task(results)
|
||||||
total_questions: int = len(data)
|
print_summary(pertask_results)
|
||||||
logger.info(f"llama-eval duration: {t1-t0:.2f} s")
|
|
||||||
logger.info(f"{prompt_source} correct: {correct}")
|
|
||||||
logger.info(f"{prompt_source} total_questions: {total_questions}")
|
|
||||||
logger.info(f"{prompt_source} accuracy: {correct / total_questions}")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
@ -324,23 +569,17 @@ if __name__ == "__main__":
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--path_server",
|
"--path_server",
|
||||||
type=str,
|
type=str,
|
||||||
default="llama-server",
|
default="http://localhost:8033",
|
||||||
help="Path to the llama.cpp server binary",
|
help="llama-server url",
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--path_log",
|
|
||||||
type=str,
|
|
||||||
default="server-bench-{port}.log",
|
|
||||||
help="Path to the model to use for the benchmark",
|
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--prompt_source",
|
"--prompt_source",
|
||||||
type=str,
|
type=str,
|
||||||
default="mmlu",
|
default="mmlu",
|
||||||
help="How to get the prompts for the benchmark, either 'mmlu' for MMLU questions",
|
help=f"Eval types supported: all,{TASK_DICT.keys()}",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--n_prompts", type=int, default=100, help="Number of prompts to evaluate"
|
"--n_prompts", type=int, default=None, help="Number of prompts to evaluate"
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--rng_seed",
|
"--rng_seed",
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue