scripts: add script to compare logprobs of llama.cpp against other frameworks (#17947)

* scripts: add script to compare logits of llama.cpp against other frameworks

* accept custom prompt file

* fix code style

* clarify endpoint

* fix displaying

* use abs for diff

* fix vllm case

* rm output file

* rename to compare-logprobs

* add "pattern"
This commit is contained in:
Xuan-Son Nguyen 2025-12-13 22:33:29 +01:00 committed by GitHub
parent 4ed2bae50d
commit c00ff929dc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 281 additions and 0 deletions

281
scripts/compare-logprobs.py Normal file
View File

@ -0,0 +1,281 @@
import argparse
import requests
import json
from pathlib import Path
import logging
# Module-level logger shared by every function in this script; the root
# logging level is set to INFO so progress messages are emitted.
logger = logging.getLogger("compare-logprobs")
logging.basicConfig(level=logging.INFO)
# Help text handed to argparse as the program description (see parse_args).
DESCRIPTION = """
Compare logits between llama.cpp and another inference engine using OpenAI-compatible server endpoints.
Unlike compare-logits.py, it allows dumping logits from a hosted API endpoint. Useful when it's not possible to run both models locally.
Example usage:
Step 1: Dump logits from two different servers
python scripts/compare-logprobs.py dump logits_llama.log http://localhost:8080/v1/completions
python scripts/compare-logprobs.py dump logits_other.log http://other-engine:8000/v1/completions
(optionally, you can add --api-key <key> if the endpoint requires authentication)
Step 2: Compare the dumped logits
python scripts/compare-logprobs.py compare logits_llama.log logits_other.log report.md
"""
def generate_input_prompt(length: int) -> list[str]:
    """Return the built-in prompt as a list of ``length`` words.

    The embedded corpus is split on single space characters only; each piece
    is trimmed of surrounding whitespace and empty pieces are discarded. The
    resulting word list is repeated until it holds at least ``length``
    entries and is then cut down to ``length``.
    """
    corpus_text = """
You are an advanced AI assistant capable of using tools to gather information, perform calculations, or execute tasks. Always think step by step before responding. If a user's query requires external data, computation, or actions beyond your internal knowledge, use the appropriate tools via function calls.
### Tool Call Format:
When you need to use a tool, output the call in this exact XML format. Include the opening and closing tags. Do not escape arguments; they will be parsed as plain text.
You can make multiple calls in one go by placing them one after another.
"""
    trimmed = (chunk.strip() for chunk in corpus_text.strip().split(" "))
    pool = [piece for piece in trimmed if piece]  # drop empty strings
    # Keep doubling the pool until it can satisfy the requested length.
    while len(pool) < length:
        pool *= 2
    return pool[:length]
def dump_logits(
    endpoint: str,
    output_path: Path,
    input_words: list[str],
    pattern: list[tuple[bool, int]],
    api_key=None,
):
    """Query ``endpoint`` with a growing prompt and dump each response as a JSON line.

    The prompt is built word by word from ``input_words``. ``pattern`` is a
    sequence of ``(get, n)`` pairs: for a "get" pair, one completion request
    is sent after each of the next ``n`` words is appended; for a "skip"
    pair, the next ``n`` words are appended without any request.

    Each response is written to ``output_path`` as one JSON object per line,
    with an extra ``"__index"`` key holding the zero-based index of the last
    word included in the prompt.

    Args:
        endpoint: URL the completion requests are POSTed to.
        output_path: File that receives the JSON lines (overwritten).
        input_words: Words used to build the prompt. The list is only read;
            it is not modified.
        pattern: ``(get, n)`` pairs as described above.
        api_key: If truthy, sent as a ``Bearer`` token in the
            ``Authorization`` header.

    Raises:
        ValueError: If ``input_words`` has fewer words than ``pattern``
            consumes. Checked before any request is sent or the output
            file is opened.
        requests.HTTPError: Propagated from ``response.raise_for_status()``.
    """
    logger.info(f"Dumping logits to {output_path} from endpoint {endpoint}...")
    i_total = len(input_words)
    n_required = sum(n for _, n in pattern)
    if i_total < n_required:
        # Fail early instead of hitting an IndexError halfway through a dump.
        raise ValueError(
            f"Pattern requires {n_required} words, but only {i_total} were provided."
        )
    n_total = sum(n for get, n in pattern if get)
    headers = {"Authorization": f"Bearer {api_key}"} if api_key else {}
    curr_text = ""
    n_done = 0
    i_cur = 0  # index of the next word to append to the prompt
    with output_path.open("w") as f:
        for get, n in pattern:
            if not get:
                # skip n words
                for _ in range(n):
                    curr_text += input_words[i_cur] + " "
                    i_cur += 1
                continue
            # get n words
            for _ in range(n):
                curr_text += input_words[i_cur] + " "
                payload = {
                    "prompt": curr_text.strip(),
                    "temperature": 0.0,
                    "top_k": 1,
                    "max_tokens": 1,
                    "logprobs": 1,
                    "stream": False,
                }
                response = requests.post(
                    endpoint,
                    json=payload,
                    headers=headers,
                )
                response.raise_for_status()
                data = response.json()
                data["__index"] = i_cur  # add index for easier debugging later
                data = json.dumps(data)
                f.write(f"{data}\n")
                n_done += 1
                i_cur += 1
                logger.info(
                    f"\n\n{data}\n\n[Step: {n_done}/{n_total} | Word: {i_cur}/{i_total}]"
                )
    logger.info(f"Logits dumped to {output_path}")
def get_token_logprobs(data: dict):
    """Return ``(token, logprob)`` for the first token of a completion response.

    Two layouts of ``data["choices"][0]["logprobs"]`` are handled: one with a
    ``"content"`` list whose first entry carries ``"top_logprobs"``
    (llama.cpp), and one with parallel ``"tokens"`` / ``"token_logprobs"``
    lists (vLLM).
    """
    info = data["choices"][0]["logprobs"]
    if "content" not in info:
        # vllm case
        tokens, token_logprobs = info["tokens"], info["token_logprobs"]
        return tokens[0], token_logprobs[0]
    # llama.cpp case
    best = info["content"][0]["top_logprobs"][0]
    return best["token"], best["logprob"]
def clean_text(text: str) -> str:
    """Quote ``text`` for display in a Markdown table cell.

    Newline, tab and carriage return are shown as their backslash escapes,
    the pipe character is backslash-escaped, and the result is wrapped in
    single quotes.
    """
    escapes = str.maketrans(
        {
            "\n": "\\n",
            "\t": "\\t",
            "\r": "\\r",
            "|": "\\|",
        }
    )
    return f"'{text.translate(escapes)}'"
def compare_logits(input1: Path, input2: Path, output_path: Path):
    """Compare two dumped JSON-lines files and write a Markdown table report.

    Lines are paired by position. For each pair the first token and its
    logprob are extracted from both responses, and a table row is produced
    with the ``__index`` of the first file (plus one), both tokens, both
    logprobs and the absolute logprob difference. Pairs in which either line
    is blank are skipped. A warning is logged when the ``__index`` values of
    a pair differ.

    Args:
        input1: First dump file.
        input2: Second dump file.
        output_path: Report file to write. It is opened only after both
            inputs were read and processed, so a failure does not truncate
            an existing report.

    Raises:
        ValueError: If the two inputs have a different number of lines.
    """
    with input1.open("r") as f1, input2.open("r") as f2:
        lines1 = f1.readlines()
        lines2 = f2.readlines()

    # Explicit raise rather than `assert`, which is removed under `python -O`.
    if len(lines1) != len(lines2):
        raise ValueError("Input files must have the same number of lines.")

    tab_header = [
        "idx",
        input1.name,
        "logprob_1",
        input2.name,
        "logprob_2",
        "diff (abs)",
    ]
    tab_entries = []
    for i, (line1, line2) in enumerate(zip(lines1, lines2)):
        if not line1.strip() or not line2.strip():
            continue  # skip empty lines
        data1 = json.loads(line1)
        data2 = json.loads(line2)
        idx1 = data1.get("__index", -1)
        idx2 = data2.get("__index", -1)
        if idx1 != idx2:
            logger.warning(
                f"Warning: Mismatched indices at line {i}: {idx1} vs {idx2}"
            )
        token1, logprob1 = get_token_logprobs(data1)
        token2, logprob2 = get_token_logprobs(data2)
        abs_diff = abs(logprob1 - logprob2)
        tab_entries.append(
            (
                str(idx1 + 1),
                clean_text(token1),
                f"{logprob1:.4f}",
                clean_text(token2),
                f"{logprob2:.4f}",
                f"{abs_diff:.4f}",
            )
        )

    # Each column is as wide as its widest cell, header included.
    tab_max_widths = [
        max(len(cell) for cell in column)
        for column in zip(tab_header, *tab_entries)
    ]

    def render_row(cells):
        # One left-aligned, padded Markdown table row.
        padded = (
            f"| {cell:<{width}} " for cell, width in zip(cells, tab_max_widths)
        )
        return "".join(padded) + "|\n"

    separator = "".join(f"|{'-' * (width + 2)}" for width in tab_max_widths) + "|\n"
    output = (
        render_row(tab_header)
        + separator
        + "".join(render_row(entry) for entry in tab_entries)
    )

    logger.info("\n" + output)
    with output_path.open("w") as fout:
        fout.write("# Logits Comparison Report\n\n")
        fout.write(output)
    logger.info(f"Report written to {output_path}")
def parse_pattern(pattern: str) -> list[tuple[bool, int]]:
    """Parse a comma-separated ``n_get,n_skip,n_get,...`` string.

    Returns one ``(get, n)`` tuple per field, where ``get`` is ``True`` for
    fields at even positions (words to get) and ``False`` for fields at odd
    positions (words to skip). ``int()`` raises ``ValueError`` for a field
    that is not an integer.
    """
    return [
        (position % 2 == 0, int(field))
        for position, field in enumerate(pattern.split(","))
    ]
def parse_args() -> argparse.Namespace:
    """Build the command-line parser and parse the process arguments.

    Two sub-commands are defined, selected through the ``verb`` attribute of
    the returned namespace: ``dump`` (positional ``output`` and ``endpoint``,
    optional ``--api-key``, ``--file`` and ``--pattern``) and ``compare``
    (positional ``input1``, ``input2`` and ``output``).
    """
    cli = argparse.ArgumentParser(
        formatter_class=argparse.RawTextHelpFormatter,
        description=DESCRIPTION,
    )
    verbs = cli.add_subparsers(help="action to perform", required=True, dest="verb")

    # dump subcommand
    dump_cmd = verbs.add_parser("dump", help="dump logits from an endpoint")
    dump_cmd.add_argument(
        "output", help="output path for dumped logits (.log)", type=Path
    )
    dump_cmd.add_argument(
        "endpoint", help="OAI-compat /completions endpoint", type=str
    )
    dump_cmd.add_argument(
        "--api-key",
        help="API key for authentication (if required)",
        default=None,
        type=str,
    )
    dump_cmd.add_argument(
        "--file",
        help="File containing prompt to use instead of the default",
        default=None,
        type=Path,
    )
    dump_cmd.add_argument(
        "--pattern",
        help="Pattern n_get,n_skip,... where n_get is number of words to get and n_skip is number of words to skip (num of words, NOT num of tokens)",
        default="10,1000,10,4000,10",
        type=str,
    )

    # compare subcommand
    compare_cmd = verbs.add_parser("compare", help="compare two dumped logits files")
    compare_positionals = (
        ("input1", "first input file (.log)"),
        ("input2", "second input file (.log)"),
        ("output", "output path for comparison report (.md)"),
    )
    for arg_name, arg_help in compare_positionals:
        compare_cmd.add_argument(arg_name, type=Path, help=arg_help)

    try:
        parsed = cli.parse_args()
    except Exception as exc:
        # Show the usage text before propagating the failure.
        cli.print_help()
        raise exc
    return parsed
def main():
    """Entry point: dispatch to the ``dump`` or ``compare`` sub-command.

    For ``dump``, the prompt words come from ``--file`` when given (split on
    single spaces), otherwise from the built-in prompt sized to the pattern.

    Raises:
        ValueError: If the prompt file has fewer words than the pattern
            consumes, or if the parsed verb is not recognised.
    """
    args = parse_args()
    if args.verb == "dump":
        pattern = parse_pattern(args.pattern)
        n_required = sum(n for _, n in pattern)
        if args.file is None:
            input_words = generate_input_prompt(n_required)
        else:
            with args.file.open("r") as f:
                input_words = f.read().strip().split(" ")
            # Compare the words actually read against what the pattern needs.
            if len(input_words) < n_required:
                raise ValueError(
                    f"Input file has only {len(input_words)} words, but pattern requires at least {n_required} words."
                )
        logger.info(f"Using {len(input_words)} words")
        dump_logits(args.endpoint, args.output, input_words, pattern, args.api_key)
    elif args.verb == "compare":
        compare_logits(args.input1, args.input2, args.output)
    else:
        raise ValueError(f"Unknown verb: {args.verb}")
# Run the command-line interface only when this file is executed as a script.
if __name__ == "__main__":
    main()