This commit is contained in:
Quentin Fuxa 2026-03-15 23:55:07 +02:00 committed by GitHub
commit 41f2d34aef
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 2038 additions and 6 deletions

View File

@ -14,6 +14,7 @@ llama_add_compile_flags()
if (EMSCRIPTEN)
else()
add_subdirectory(attn-weights)
add_subdirectory(batched)
add_subdirectory(debug)
add_subdirectory(embedding)

View File

@ -0,0 +1,5 @@
# Build and install the llama-attn-weights example binary.
set(TARGET llama-attn-weights)
add_executable(${TARGET} attn-weights.cpp)
install(TARGETS ${TARGET} RUNTIME)
# Links against the main llama library; thread libs for platforms that need them.
target_link_libraries(${TARGET} PRIVATE llama ${CMAKE_THREAD_LIBS_INIT})
target_compile_features(${TARGET} PRIVATE cxx_std_17)

View File

@ -0,0 +1,255 @@
// Attention weights extraction example
//
// Decodes a prompt and prints an ASCII heatmap of the attention pattern
// for a selected (layer, head) pair.
//
// Usage:
// ./llama-attn-weights -m model.gguf [-l layer] [-hd head] [-ngl N] [prompt]
//
// Defaults: last layer, head 0.
// Use -l -1 for the last layer.
#include "llama.h"
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <string>
#include <vector>
// Render a token id as a printable piece of text.
// Characters that would break single-line output are replaced with
// visible placeholders ('\n' -> '|', '\t' and '\r' -> ' ').
// Returns "???" when the token cannot be converted.
static std::string token_str(const llama_vocab * vocab, llama_token id) {
    char piece[256];
    const int len = llama_token_to_piece(vocab, id, piece, sizeof(piece), 0, true);
    if (len < 0) {
        return "???";
    }
    std::string out(piece, len);
    for (size_t i = 0; i < out.size(); ++i) {
        switch (out[i]) {
            case '\n': out[i] = '|'; break;
            case '\t':
            case '\r': out[i] = ' '; break;
            default: break;
        }
    }
    return out;
}
// Return `s` adjusted to exactly `w` characters: longer strings are
// truncated, shorter ones are right-padded with spaces.
static std::string fit(const std::string & s, int w) {
    const int len = (int) s.size();
    if (len > w) {
        return s.substr(0, w);
    }
    std::string out = s;
    out.append((size_t) (w - len), ' ');
    return out;
}
// Show the command-line synopsis and option summary for this example.
static void print_usage(int /*argc*/, char ** argv) {
    fprintf(stdout, "\nexample usage:\n");
    fprintf(stdout, "\n %s -m model.gguf [-l layer] [-hd head] [-ngl n_gpu_layers] [prompt]\n", argv[0]);
    fprintf(stdout, "\n");
    fprintf(stdout, " -l layer index (default: -1 = last layer)\n");
    fprintf(stdout, " -hd head index (default: 0)\n");
    fprintf(stdout, "\n");
}
// Entry point: parse CLI options, decode the prompt once with attention
// recording enabled, then print per-query-token attention rows (numeric
// values plus an ASCII heatmap) for the selected (layer, head) pair.
int main(int argc, char ** argv) {
    std::string model_path;
    std::string prompt = "The cat sat on the mat";
    int ngl = 99;
    int layer = -1; // -1 means last layer
    int head = 0;
    // parse args
    {
        int i = 1;
        for (; i < argc; i++) {
            if (strcmp(argv[i], "-m") == 0) {
                if (i + 1 < argc) { model_path = argv[++i]; }
                else { print_usage(argc, argv); return 1; }
            } else if (strcmp(argv[i], "-l") == 0) {
                if (i + 1 < argc) {
                    try { layer = std::stoi(argv[++i]); } catch (...) { print_usage(argc, argv); return 1; }
                } else { print_usage(argc, argv); return 1; }
            } else if (strcmp(argv[i], "-hd") == 0) {
                if (i + 1 < argc) {
                    try { head = std::stoi(argv[++i]); } catch (...) { print_usage(argc, argv); return 1; }
                } else { print_usage(argc, argv); return 1; }
            } else if (strcmp(argv[i], "-ngl") == 0) {
                if (i + 1 < argc) {
                    try { ngl = std::stoi(argv[++i]); } catch (...) { print_usage(argc, argv); return 1; }
                } else { print_usage(argc, argv); return 1; }
            } else {
                // first non-option argument starts the prompt
                break;
            }
        }
        if (model_path.empty()) {
            print_usage(argc, argv);
            return 1;
        }
        // all remaining args are joined with spaces to form the prompt
        if (i < argc) {
            prompt = argv[i++];
            for (; i < argc; i++) {
                prompt += " ";
                prompt += argv[i];
            }
        }
    }
    // load backends & model
    ggml_backend_load_all();
    llama_model_params model_params = llama_model_default_params();
    model_params.n_gpu_layers = ngl;
    llama_model * model = llama_model_load_from_file(model_path.c_str(), model_params);
    if (!model) {
        fprintf(stderr, "%s: error: unable to load model\n", __func__);
        return 1;
    }
    const llama_vocab * vocab = llama_model_get_vocab(model);
    // resolve layer index (negative values count back from the last layer)
    const int32_t n_layer = llama_model_n_layer(model);
    if (layer < 0) {
        layer = n_layer + layer; // -1 -> last layer
    }
    if (layer < 0 || layer >= n_layer) {
        fprintf(stderr, "%s: error: layer %d out of range [0, %d)\n", __func__, layer, n_layer);
        llama_model_free(model);
        return 1;
    }
    // NOTE(review): `head` is not range-checked against the model's head
    // count here — confirm llama_set_attn_heads rejects invalid heads.
    // tokenize: a first call with a NULL buffer returns the negated
    // required token count, which sizes the real call below
    const int n_prompt = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, true, true);
    std::vector<llama_token> tokens(n_prompt);
    if (llama_tokenize(vocab, prompt.c_str(), prompt.size(), tokens.data(), tokens.size(), true, true) < 0) {
        fprintf(stderr, "%s: error: failed to tokenize\n", __func__);
        llama_model_free(model);
        return 1;
    }
    // create context with attention weights enabled
    llama_context_params ctx_params = llama_context_default_params();
    ctx_params.n_ctx = n_prompt;
    ctx_params.n_batch = n_prompt;
    // NOTE(review): flash attention disabled — presumably required for
    // weight extraction; confirm against the attn_weights implementation
    ctx_params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_DISABLED;
    ctx_params.attn_weights = true;
    ctx_params.no_perf = true;
    llama_context * ctx = llama_init_from_model(model, ctx_params);
    if (!ctx) {
        fprintf(stderr, "%s: error: failed to create context\n", __func__);
        llama_model_free(model);
        return 1;
    }
    // configure which (layer, head) to extract
    {
        int32_t layers[] = { (int32_t) layer };
        int32_t heads[] = { (int32_t) head };
        llama_set_attn_heads(ctx, layers, heads, 1);
    }
    // prepare batch — request output for all tokens
    llama_batch batch = llama_batch_init(n_prompt, 0, 1);
    for (int i = 0; i < n_prompt; i++) {
        batch.token[i] = tokens[i];
        batch.pos[i] = i;
        batch.n_seq_id[i] = 1;
        batch.seq_id[i][0] = 0;
        batch.logits[i] = 1; // request output for every token
    }
    batch.n_tokens = n_prompt;
    // decode
    if (llama_decode(ctx, batch)) {
        fprintf(stderr, "%s: error: decode failed\n", __func__);
        llama_batch_free(batch);
        llama_free(ctx);
        llama_model_free(model);
        return 1;
    }
    // collect token labels
    std::vector<std::string> labels(n_prompt);
    for (int i = 0; i < n_prompt; i++) {
        labels[i] = token_str(vocab, tokens[i]);
    }
    const int32_t n_kv = llama_get_attn_n_kv(ctx);
    const int32_t n_ctx = llama_n_ctx(ctx);
    printf("\nAttention weights — layer %d, head %d\n", layer, head);
    printf("Tokens: %d, KV entries: %d\n\n", n_prompt, n_kv);
    // print as a matrix: rows = query tokens, columns = key tokens
    // only the first n_kv columns are valid
    const int col_w = 7; // column width for values
    const int lbl_w = 12; // label column width
    // header row: key token labels
    printf("%s", fit("", lbl_w).c_str());
    for (int k = 0; k < n_kv && k < n_prompt; k++) {
        printf("%s", fit(labels[k], col_w).c_str());
    }
    printf("\n");
    // separator
    printf("%s", std::string(lbl_w + col_w * std::min(n_kv, (int32_t) n_prompt), '-').c_str());
    printf("\n");
    // ASCII heatmap chars from low to high attention
    const char * heat = " .:-=+*#@";
    const int heat_len = 9;
    // one row per query token
    for (int q = 0; q < n_prompt; q++) {
        float * attn = llama_get_attn_ith(ctx, q);
        if (!attn) {
            printf("%s (no data)\n", fit(labels[q], lbl_w).c_str());
            continue;
        }
        // find max for this row (for heatmap scaling)
        float row_max = 0.0f;
        for (int k = 0; k < n_kv && k < n_prompt; k++) {
            if (attn[k] > row_max) {
                row_max = attn[k];
            }
        }
        // numeric row
        printf("%s", fit(labels[q], lbl_w).c_str());
        for (int k = 0; k < n_kv && k < n_prompt; k++) {
            printf(" %5.3f ", attn[k]);
        }
        printf("\n");
        // heatmap row (skipped when the whole row is zero)
        if (row_max > 0.0f) {
            printf("%s", fit("", lbl_w).c_str());
            for (int k = 0; k < n_kv && k < n_prompt; k++) {
                float norm = attn[k] / row_max;
                int idx = (int)(norm * (heat_len - 1));
                idx = std::max(0, std::min(heat_len - 1, idx));
                // center the char in the column
                int pad_l = (col_w - 1) / 2;
                int pad_r = col_w - 1 - pad_l;
                printf("%*s%c%*s", pad_l, "", heat[idx], pad_r, "");
            }
            printf("\n");
        }
    }
    printf("\nLegend: ' .:-=+*#@' (low -> high attention)\n");
    // n_ctx is fetched for reference only; the cast silences -Wunused-variable
    (void) n_ctx;
    llama_batch_free(batch);
    llama_free(ctx);
    llama_model_free(model);
    return 0;
}

View File

@ -0,0 +1,343 @@
#!/usr/bin/env python3
"""
Benchmark: TranslateGemma 4B + AlignAtt streaming policy (EN→ZH)
llama.cpp implementation comparable to CTranslate2/benchmark/eval_benchmark_translategemma.py
Translates FLORES EN→ZH sentences using llama.cpp with alignment heads.
Computes latency metrics (AL, LAAL, AP, CW) and saves translations for quality scoring.
Usage:
# First convert TranslateGemma to GGUF:
# python3 convert_hf_to_gguf.py /tmp/translategemma-text-only --outfile /tmp/translategemma.gguf
python3 examples/attn-weights/eval_benchmark_translategemma.py /tmp/translategemma.gguf \\
--heads examples/attn-weights/translation_heads_en_zh.json \\
-n 100 --output examples/attn-weights/results_llamacpp.json
"""
import argparse
import json
import os
import sys
import time
import numpy as np
sys.path.insert(0, os.path.dirname(__file__))
import llama_attn as ll
# Copy heads JSON from CTranslate2 benchmark if not present locally
CT2_HEADS = os.path.join(
os.path.dirname(__file__), "..", "..", "..", "CTranslate2", "benchmark", "translation_heads_en_zh.json"
)
PROMPT_TEMPLATE = (
"<bos><start_of_turn>user\n"
"You are a professional English (en) to Chinese (zh) translator. "
"Your goal is to accurately convey the meaning and nuances of the "
"original English text while adhering to Chinese grammar, vocabulary, "
"and cultural sensitivities.\n"
"Produce only the Chinese translation, without any additional "
"explanations or commentary. Please translate the following English "
"text into Chinese:\n\n\n"
"{source}<end_of_turn>\n"
"<start_of_turn>model\n"
)
def find_source_token_range(source_text, vocab):
    """Find token range of source text within the prompt.

    Returns (start, end): start = number of tokens in the template prefix,
    end = tokens in the full prompt minus the suffix token count.
    NOTE(review): assumes tokenization splits cleanly at the prefix/source
    and source/suffix boundaries; a BPE merge across either boundary could
    shift the range by a token — confirm for the target tokenizer.
    """
    prefix = PROMPT_TEMPLATE.split("{source}")[0]
    suffix = PROMPT_TEMPLATE.split("{source}")[1]
    prefix_ids = ll.tokenize(vocab, prefix, add_bos=False, special=True)
    full_ids = ll.tokenize(vocab, prefix + source_text + suffix, add_bos=False, special=True)
    suffix_ids = ll.tokenize(vocab, suffix, add_bos=False, special=True)
    return len(prefix_ids), len(full_ids) - len(suffix_ids)
def compute_metrics(alignments, num_src, num_tgt):
    """Compute latency metrics for a streaming translation hypothesis.

    alignments: per-target-token index of the source token it depended on.
    Returns a dict with AL (average lagging), LAAL, AP (average proportion)
    and Max CW (max consecutive wait) — same definitions as the CTranslate2
    benchmark.
    """
    if not alignments or num_tgt == 0 or num_src == 0:
        return {"al": 0.0, "max_cw": 0, "laal": 0.0, "ap": 0.0}

    # Enforce monotonicity: a target token cannot depend on less source
    # context than its predecessors (running maximum, floored at 0).
    mono = []
    running = 0
    for dep in alignments:
        if dep > running:
            running = dep
        mono.append(running)

    # Average Lagging: mean positive lag of source reads vs. the ideal
    # diagonal (one source token per num_src/num_tgt target tokens).
    ratio = num_src / num_tgt
    al = sum(max(0, (mono[t] + 1) - t * ratio) for t in range(num_tgt)) / num_tgt

    # Length-adaptive AL: scale down when the hypothesis is longer than the source.
    laal = al * min(num_src / num_tgt, 1.0)

    # Average Proportion: mean fraction of the source read per emitted token.
    ap = sum(mono[t] + 1 for t in range(num_tgt)) / (num_src * num_tgt)

    # Max CW: longest run of consecutive source reads without any step that
    # emitted a token using no fresh source context.
    max_cw = cw = prev_needed = 0
    for m in mono:
        needed = m + 1
        fresh = max(0, needed - prev_needed)
        if fresh > 0:
            cw += fresh
            max_cw = max(max_cw, cw)
        else:
            cw = 0
        prev_needed = needed

    return {"al": al, "max_cw": max_cw, "laal": laal, "ap": ap}
def aggregate_ts_weighted_vote(src_attn, ts_scores):
    """TS-weighted vote: each head votes for its argmax source position,
    weighted by the head's TS score; the position with the largest total
    weight wins (ties broken by first-voted position)."""
    votes = {}
    for score, row in zip(ts_scores, src_attn):
        pos = int(np.argmax(row))
        votes[pos] = votes.get(pos, 0) + score
    return max(votes, key=votes.get)
def aggregate_mean(src_attn):
    """Average attention over heads, then take the argmax source position."""
    return int(np.argmax(np.mean(src_attn, axis=0)))
def main():
    """Run the EN→ZH FLORES benchmark: translate with llama.cpp, extract a
    per-step source alignment from attention heads, and report latency
    metrics plus translations for downstream quality scoring."""
    parser = argparse.ArgumentParser(description="TranslateGemma benchmark with llama.cpp attention")
    parser.add_argument("model", help="Path to GGUF model file")
    parser.add_argument("-n", type=int, default=100, help="Number of FLORES sentences")
    parser.add_argument("--heads", default=None, help="Alignment heads JSON (default: CTranslate2 benchmark)")
    parser.add_argument("--top-k", type=int, default=10, help="Number of top alignment heads")
    parser.add_argument("--strategy", default="ts_weighted_vote", choices=["ts_weighted_vote", "mean"])
    parser.add_argument("--output", default="benchmark/results_llamacpp.json")
    parser.add_argument("--n-ctx", type=int, default=1024)
    parser.add_argument("--max-tokens", type=int, default=256)
    args = parser.parse_args()
    # Load alignment heads: explicit --heads path, else the CTranslate2
    # benchmark copy, else a local fallback next to this script.
    heads_path = args.heads or CT2_HEADS
    if not os.path.exists(heads_path):
        # Try local copy
        local = os.path.join(os.path.dirname(__file__), "translation_heads_en_zh.json")
        if os.path.exists(local):
            heads_path = local
        else:
            print(f"ERROR: Cannot find alignment heads JSON at {heads_path}")
            print("Copy from CTranslate2/benchmark/translation_heads_en_zh.json")
            sys.exit(1)
    with open(heads_path) as f:
        data = json.load(f)
    # Keep only the top-k heads ranked by the JSON's ordering.
    top_heads = data["token_alignment_heads"][:args.top_k]
    head_layers = [h["layer"] for h in top_heads]
    head_indices = [h["head"] for h in top_heads]
    ts_scores = [h["ts"] for h in top_heads]
    num_heads = len(top_heads)
    print(f"Using {num_heads} alignment heads (strategy={args.strategy})")
    for h in top_heads[:5]:
        print(f" L{h['layer']:2d} H{h['head']} (TS={h['ts']:.3f})")
    if num_heads > 5:
        print(f" ... and {num_heads - 5} more")
    # Load FLORES EN→ZH and pair sentences by shared id.
    print("\nLoading FLORES EN→ZH...")
    from datasets import load_dataset
    ds = load_dataset("openlanguagedata/flores_plus", split="dev")
    en_ds = ds.filter(lambda x: x["iso_639_3"] == "eng" and x["iso_15924"] == "Latn")
    zh_ds = ds.filter(lambda x: x["iso_639_3"] == "cmn" and x["iso_15924"] == "Hans")
    en_map = {row["id"]: row["text"] for row in en_ds}
    zh_map = {row["id"]: row["text"] for row in zh_ds}
    common_ids = sorted(set(en_map) & set(zh_map))
    n = min(args.n, len(common_ids))
    print(f" {n} sentence pairs")
    # Initialize llama.cpp
    print(f"\nLoading model: {args.model}")
    ll.init()
    model = ll.load_model(args.model)
    vocab = ll.get_vocab(model)
    nv = ll.n_vocab(vocab)
    n_layers = ll.n_layer(model)
    eos_id = ll.vocab_eos(vocab)
    print(f" {n_layers} layers, vocab={nv}")
    # Find stop token IDs (single-token chat terminators plus EOS)
    stop_ids = set()
    for tok_str in ["<end_of_turn>", "<eos>"]:
        toks = ll.tokenize(vocab, tok_str, add_bos=False, special=True)
        if len(toks) == 1:
            stop_ids.add(toks[0])
    stop_ids.add(eos_id)
    print(f" Stop IDs: {stop_ids}")
    # Evaluate
    results = []
    al_list, cw_list, laal_list, ap_list = [], [], [], []
    for idx in range(n):
        sid = common_ids[idx]
        source = en_map[sid]
        reference = zh_map[sid]
        prompt = PROMPT_TEMPLATE.format(source=source)
        prompt_tokens = ll.tokenize(vocab, prompt, add_bos=False, special=True)
        prompt_len = len(prompt_tokens)
        src_start, src_end = find_source_token_range(source, vocab)
        num_src = src_end - src_start
        # Create fresh context for each sentence (clean KV cache)
        ctx = ll.create_context(model, n_ctx=args.n_ctx, n_batch=args.n_ctx, attn_weights=True)
        ll.set_attn_heads(ctx, head_layers, head_indices)
        n_c = ll.n_ctx(ctx)
        t0 = time.time()
        # Prefill
        ret = ll.decode_batch(ctx, prompt_tokens, output_last_only=True)
        if ret != 0:
            print(f" [{idx+1}] prefill failed: {ret}, skipping")
            ll.free_context(ctx)
            continue
        # Autoregressive generation with attention extraction
        generated_ids = []
        alignments = []
        pos = prompt_len
        for step in range(args.max_tokens):
            # Get attention for current token
            attn = ll.get_attn_weights(ctx, -1, num_heads, n_c)
            if attn is not None:
                # Extract attention over source token range only
                n_kv = ll._lib.llama_get_attn_n_kv(ctx)
                if src_start < n_kv and src_end <= n_kv:
                    src_attn = attn[:, src_start:src_end]  # (num_heads, num_src)
                    if args.strategy == "ts_weighted_vote":
                        aligned_pos = aggregate_ts_weighted_vote(src_attn, ts_scores)
                    else:
                        aligned_pos = aggregate_mean(src_attn)
                    alignments.append(aligned_pos)
                else:
                    # Source range not yet in KV cache: treat as full read
                    alignments.append(num_src - 1)
            else:
                alignments.append(num_src - 1)
            # Get next token (greedy)
            next_tok = ll.argmax_logits(ctx, -1, nv)
            if next_tok in stop_ids or next_tok < 0:
                break
            generated_ids.append(next_tok)
            # Decode next token
            ret = ll.decode_single(ctx, next_tok, pos, output=True)
            if ret != 0:
                break
            pos += 1
        gen_time = time.time() - t0
        ll.free_context(ctx)
        num_tgt = len(generated_ids)
        if num_tgt == 0:
            continue
        # Trim alignments to match generated tokens
        alignments = alignments[:num_tgt]
        # Decode translation text
        pieces = [ll.token_to_piece(vocab, tid) for tid in generated_ids]
        translation = "".join(pieces)
        metrics = compute_metrics(alignments, num_src, num_tgt)
        al_list.append(metrics["al"])
        cw_list.append(metrics["max_cw"])
        laal_list.append(metrics["laal"])
        ap_list.append(metrics["ap"])
        results.append({
            "id": int(sid),
            "source": source,
            "reference": reference,
            "translation": translation,
            "num_src_tokens": num_src,
            "num_tgt_tokens": num_tgt,
            "al": round(metrics["al"], 3),
            "max_cw": metrics["max_cw"],
            "laal": round(metrics["laal"], 3),
            "ap": round(metrics["ap"], 3),
            "gen_time_ms": round(gen_time * 1000),
        })
        if (idx + 1) % 10 == 0:
            avg_al = np.mean(al_list)
            print(f" [{idx+1}/{n}] Avg AL={avg_al:.2f}, last: {translation[:50]}...", flush=True)
    # Summary
    if not results:
        print("No results!")
        ll.free_model(model)
        ll.cleanup()
        sys.exit(1)
    al_arr = np.array(al_list)
    cw_arr = np.array(cw_list)
    laal_arr = np.array(laal_list)
    ap_arr = np.array(ap_list)
    summary = {
        "system": "TranslateGemma-4B + AlignAtt (llama.cpp)",
        "language_pair": "en-zh",
        "num_sentences": len(results),
        "num_alignment_heads": num_heads,
        "strategy": args.strategy,
        "latency": {
            "avg_al": round(float(np.mean(al_arr)), 3),
            "median_al": round(float(np.median(al_arr)), 3),
            "p90_al": round(float(np.percentile(al_arr, 90)), 3),
            "avg_max_cw": round(float(np.mean(cw_arr)), 1),
            "max_max_cw": int(np.max(cw_arr)),
            "avg_laal": round(float(np.mean(laal_arr)), 3),
            "avg_ap": round(float(np.mean(ap_arr)), 3),
        },
        "avg_gen_time_ms": round(float(np.mean([r["gen_time_ms"] for r in results]))),
    }
    output = {"summary": summary, "sentences": results}
    os.makedirs(os.path.dirname(args.output) or ".", exist_ok=True)
    with open(args.output, "w") as f:
        json.dump(output, f, indent=2, ensure_ascii=False)
    print(f"\n{'='*60}")
    print(f"TranslateGemma + AlignAtt (llama.cpp) — EN→ZH ({len(results)} sentences)")
    print(f"{'='*60}")
    print(f" Avg AL: {summary['latency']['avg_al']:.3f}")
    print(f" Median AL: {summary['latency']['median_al']:.3f}")
    print(f" P90 AL: {summary['latency']['p90_al']:.3f}")
    print(f" Avg Max CW: {summary['latency']['avg_max_cw']:.1f}")
    print(f" Max CW: {summary['latency']['max_max_cw']}")
    print(f" Avg LAAL: {summary['latency']['avg_laal']:.3f}")
    print(f" Avg AP: {summary['latency']['avg_ap']:.3f}")
    print(f" Avg gen: {summary['avg_gen_time_ms']}ms/sentence")
    print(f"\nSaved to {args.output}")
    ll.free_model(model)
    ll.cleanup()
if __name__ == "__main__":
    main()

View File

@ -0,0 +1,387 @@
"""
Minimal ctypes wrapper for llama.cpp attention weight extraction API.
"""
import ctypes
import os
import sys
# Find the shared library
# The build tree places libllama next to the example binaries; probe the
# platform-specific names in order (macOS, Linux, Windows).
_LIB_DIR = os.path.join(os.path.dirname(__file__), "..", "..", "build", "bin")
_LIB_NAMES = ["libllama.dylib", "libllama.so", "llama.dll"]
_lib = None
for name in _LIB_NAMES:
    path = os.path.join(_LIB_DIR, name)
    if os.path.exists(path):
        _lib = ctypes.CDLL(path)
        break
if _lib is None:
    raise RuntimeError(f"Cannot find libllama in {_LIB_DIR}. Build first with: cmake --build build")
# --- Types ---
# Opaque handle types: never dereferenced from Python, only passed back to C.
class llama_model(ctypes.Structure):
    pass
class llama_context(ctypes.Structure):
    pass
class llama_vocab(ctypes.Structure):
    pass
llama_token = ctypes.c_int32
llama_pos = ctypes.c_int32
llama_seq_id = ctypes.c_int32
# NOTE(review): this must mirror struct llama_batch in llama.h exactly —
# re-verify field order and types whenever llama.cpp is upgraded.
class llama_batch(ctypes.Structure):
    _fields_ = [
        ("n_tokens", ctypes.c_int32),
        ("token", ctypes.POINTER(llama_token)),
        ("embd", ctypes.POINTER(ctypes.c_float)),
        ("pos", ctypes.POINTER(llama_pos)),
        ("n_seq_id", ctypes.POINTER(ctypes.c_int32)),
        ("seq_id", ctypes.POINTER(ctypes.POINTER(llama_seq_id))),
        ("logits", ctypes.POINTER(ctypes.c_int8)),
    ]
# Model params - we only need to get the default and possibly modify a few fields
# Since the struct is large and complex, we'll treat it as opaque bytes
# and use the C function to get defaults
# NOTE(review): passing a by-value struct whose size differs from the real
# llama_model_params is ABI-fragile; 256 bytes is assumed to be an upper
# bound — confirm against the current llama.h.
class llama_model_params(ctypes.Structure):
    _fields_ = [
        ("_opaque", ctypes.c_uint8 * 256),  # oversized, safe
    ]
# We need the exact layout of llama_context_params to set attn_weights and flash_attn_type
# Let's read it from the header. The key fields we need are at known positions.
# Instead of matching the exact struct, we'll use the C API defaults and patch bytes.
# Enums (values mirror enum llama_flash_attn_type in llama.h)
LLAMA_FLASH_ATTN_TYPE_AUTO = -1
LLAMA_FLASH_ATTN_TYPE_DISABLED = 0
LLAMA_FLASH_ATTN_TYPE_ENABLED = 1
# --- Function signatures ---
# Each binding below declares argtypes/restype so ctypes marshals correctly;
# the comment above each one is the C prototype it corresponds to.
# void llama_backend_init(void)
_lib.llama_backend_init.argtypes = []
_lib.llama_backend_init.restype = None
# void llama_backend_free(void)
_lib.llama_backend_free.argtypes = []
_lib.llama_backend_free.restype = None
# llama_model_params llama_model_default_params(void)
_lib.llama_model_default_params.argtypes = []
_lib.llama_model_default_params.restype = llama_model_params
# llama_model * llama_model_load_from_file(const char * path, llama_model_params params)
_lib.llama_model_load_from_file.argtypes = [ctypes.c_char_p, llama_model_params]
_lib.llama_model_load_from_file.restype = ctypes.POINTER(llama_model)
# void llama_model_free(llama_model * model)
_lib.llama_model_free.argtypes = [ctypes.POINTER(llama_model)]
_lib.llama_model_free.restype = None
# const llama_vocab * llama_model_get_vocab(const llama_model * model)
_lib.llama_model_get_vocab.argtypes = [ctypes.POINTER(llama_model)]
_lib.llama_model_get_vocab.restype = ctypes.POINTER(llama_vocab)
# int32_t llama_model_n_layer(const llama_model * model)
_lib.llama_model_n_layer.argtypes = [ctypes.POINTER(llama_model)]
_lib.llama_model_n_layer.restype = ctypes.c_int32
# int32_t llama_model_n_head(const llama_model * model)
_lib.llama_model_n_head.argtypes = [ctypes.POINTER(llama_model)]
_lib.llama_model_n_head.restype = ctypes.c_int32
# int32_t llama_tokenize(const llama_vocab *, const char *, int32_t, llama_token *, int32_t, bool, bool)
_lib.llama_tokenize.argtypes = [
    ctypes.POINTER(llama_vocab), ctypes.c_char_p, ctypes.c_int32,
    ctypes.POINTER(llama_token), ctypes.c_int32, ctypes.c_bool, ctypes.c_bool
]
_lib.llama_tokenize.restype = ctypes.c_int32
# llama_batch llama_batch_init(int32_t n_tokens, int32_t embd, int32_t n_seq_max)
_lib.llama_batch_init.argtypes = [ctypes.c_int32, ctypes.c_int32, ctypes.c_int32]
_lib.llama_batch_init.restype = llama_batch
# void llama_batch_free(llama_batch batch)
_lib.llama_batch_free.argtypes = [llama_batch]
_lib.llama_batch_free.restype = None
# int32_t llama_decode(llama_context * ctx, llama_batch batch)
_lib.llama_decode.argtypes = [ctypes.POINTER(llama_context), llama_batch]
_lib.llama_decode.restype = ctypes.c_int32
# void llama_free(llama_context * ctx)
_lib.llama_free.argtypes = [ctypes.POINTER(llama_context)]
_lib.llama_free.restype = None
# void llama_set_attn_heads(llama_context *, const int32_t * layers, const int32_t * heads, size_t n_pairs)
_lib.llama_set_attn_heads.argtypes = [
    ctypes.POINTER(llama_context),
    ctypes.POINTER(ctypes.c_int32), ctypes.POINTER(ctypes.c_int32), ctypes.c_size_t
]
_lib.llama_set_attn_heads.restype = None
# float * llama_get_attn_ith(llama_context * ctx, int32_t i)
_lib.llama_get_attn_ith.argtypes = [ctypes.POINTER(llama_context), ctypes.c_int32]
_lib.llama_get_attn_ith.restype = ctypes.POINTER(ctypes.c_float)
# int32_t llama_get_attn_n_kv(llama_context * ctx)
_lib.llama_get_attn_n_kv.argtypes = [ctypes.POINTER(llama_context)]
_lib.llama_get_attn_n_kv.restype = ctypes.c_int32
# float * llama_get_logits_ith(llama_context * ctx, int32_t i)
_lib.llama_get_logits_ith.argtypes = [ctypes.POINTER(llama_context), ctypes.c_int32]
_lib.llama_get_logits_ith.restype = ctypes.POINTER(ctypes.c_float)
# uint32_t llama_n_ctx(const llama_context * ctx)
_lib.llama_n_ctx.argtypes = [ctypes.POINTER(llama_context)]
_lib.llama_n_ctx.restype = ctypes.c_uint32
# void llama_synchronize(llama_context * ctx)
_lib.llama_synchronize.argtypes = [ctypes.POINTER(llama_context)]
_lib.llama_synchronize.restype = None
# int32_t llama_vocab_n_tokens(const llama_vocab * vocab)
_lib.llama_vocab_n_tokens.argtypes = [ctypes.POINTER(llama_vocab)]
_lib.llama_vocab_n_tokens.restype = ctypes.c_int32
# llama_token llama_vocab_bos(const llama_vocab * vocab)
_lib.llama_vocab_bos.argtypes = [ctypes.POINTER(llama_vocab)]
_lib.llama_vocab_bos.restype = llama_token
# llama_token llama_vocab_eos(const llama_vocab * vocab)
_lib.llama_vocab_eos.argtypes = [ctypes.POINTER(llama_vocab)]
_lib.llama_vocab_eos.restype = llama_token
# int32_t llama_token_to_piece(const llama_vocab *, llama_token, char *, int32_t, int32_t, bool)
_lib.llama_token_to_piece.argtypes = [
    ctypes.POINTER(llama_vocab), llama_token,
    ctypes.POINTER(ctypes.c_char), ctypes.c_int32, ctypes.c_int32, ctypes.c_bool
]
_lib.llama_token_to_piece.restype = ctypes.c_int32
def _create_context(model_ptr, n_ctx=512, n_batch=512, attn_weights=True, n_gpu_layers=0):
    """Create a llama_context with attention weights enabled.

    Uses a small C shim compiled on-the-fly to avoid struct layout issues:
    the shim calls llama_context_default_params() in C, sets the fields we
    need, and returns the created context pointer.

    NOTE(review): requires a C compiler (`cc`) plus the llama.cpp headers and
    built library at their expected relative paths; the compiled shim is left
    in build/bin and concurrent callers could race on that file — confirm
    this is acceptable for the benchmark setup.
    """
    import tempfile, subprocess
    shim_src = r"""
#include "llama.h"
#include <stdlib.h>
// Export a function that creates a context with the right params
__attribute__((visibility("default")))
struct llama_context * create_ctx_with_attn(
        struct llama_model * model,
        int n_ctx, int n_batch, int attn_weights, int n_gpu_layers) {
    struct llama_context_params params = llama_context_default_params();
    params.n_ctx = n_ctx;
    params.n_batch = n_batch;
    params.n_ubatch = n_batch;
    params.attn_weights = attn_weights ? true : false;
    params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_DISABLED;
    params.offload_kqv = n_gpu_layers > 0;
    return llama_init_from_model(model, params);
}
"""
    # Resolve repo-relative include/lib paths for the shim build.
    llama_dir = os.path.join(os.path.dirname(__file__), "..", "..")
    include_dir = os.path.join(llama_dir, "include")
    lib_dir = os.path.join(llama_dir, "build", "bin")
    ggml_include = os.path.join(llama_dir, "ggml", "include")
    with tempfile.NamedTemporaryFile(suffix=".c", mode="w", delete=False) as f:
        f.write(shim_src)
        src_path = f.name
    shim_lib = os.path.join(lib_dir, "libllama_attn_shim.dylib")
    if sys.platform == "linux":
        shim_lib = os.path.join(lib_dir, "libllama_attn_shim.so")
    # Compile the shim against libllama, with an rpath so it loads in place.
    cmd = [
        "cc", "-shared", "-fPIC", "-o", shim_lib, src_path,
        f"-I{include_dir}", f"-I{ggml_include}",
        f"-L{lib_dir}", "-lllama",
        f"-Wl,-rpath,{lib_dir}",
    ]
    result = subprocess.run(cmd, capture_output=True, text=True)
    os.unlink(src_path)
    if result.returncode != 0:
        raise RuntimeError(f"Failed to compile shim: {result.stderr}")
    shim = ctypes.CDLL(shim_lib)
    shim.create_ctx_with_attn.argtypes = [
        ctypes.POINTER(llama_model), ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int
    ]
    shim.create_ctx_with_attn.restype = ctypes.POINTER(llama_context)
    ctx = shim.create_ctx_with_attn(model_ptr, n_ctx, n_batch, 1 if attn_weights else 0, n_gpu_layers)
    if not ctx:
        raise RuntimeError("Failed to create llama_context")
    return ctx
# --- High-level helpers ---
def tokenize(vocab_ptr, text, add_bos=True, special=True):
    """Tokenize text, returning a list of token ids."""
    data = text.encode("utf-8")
    # First attempt with a generously sized buffer; llama_tokenize returns
    # the negated required size when the buffer is too small.
    cap = len(data) + 32
    out = (llama_token * cap)()
    count = _lib.llama_tokenize(vocab_ptr, data, len(data), out, cap, add_bos, special)
    if count < 0:
        # Retry with the exact size reported by the first call.
        out = (llama_token * (-count))()
        count = _lib.llama_tokenize(vocab_ptr, data, len(data), out, -count, add_bos, special)
    return list(out[:count])
def decode_batch(ctx_ptr, tokens, output_last_only=True):
    """Decode a batch of tokens. Returns llama_decode return code.

    tokens are placed at positions 0..len(tokens)-1 on sequence 0; when
    output_last_only is True, logits/attention are requested only for the
    final token (typical prefill usage).
    """
    n = len(tokens)
    batch = _lib.llama_batch_init(n, 0, 1)
    for i in range(n):
        batch.token[i] = tokens[i]
        batch.pos[i] = i
        batch.n_seq_id[i] = 1
        # Write seq_id value into the pre-allocated buffer (don't replace the pointer)
        batch.seq_id[i][0] = 0
        batch.logits[i] = 1 if (not output_last_only or i == n - 1) else 0
    batch.n_tokens = n
    ret = _lib.llama_decode(ctx_ptr, batch)
    _lib.llama_batch_free(batch)
    return ret
def decode_single(ctx_ptr, token, pos, output=True):
    """Decode one token at position `pos` on sequence 0.

    Returns llama_decode's return code (0 on success)."""
    b = _lib.llama_batch_init(1, 0, 1)
    b.n_tokens = 1
    b.token[0] = token
    b.pos[0] = pos
    b.n_seq_id[0] = 1
    # Write into the buffer llama_batch_init allocated; don't swap the pointer.
    b.seq_id[0][0] = 0
    b.logits[0] = 1 if output else 0
    rc = _lib.llama_decode(ctx_ptr, b)
    _lib.llama_batch_free(b)
    return rc
def get_attn_weights(ctx_ptr, token_idx, n_pairs, n_ctx):
    """Get attention weights for a given output token index.

    The C side stores n_pairs * n_ctx floats per output token; only the
    first n_kv entries of each pair's stride are valid. Returns a float32
    numpy array of shape (n_pairs, n_kv), or None if unavailable.
    """
    import numpy as np
    ptr = _lib.llama_get_attn_ith(ctx_ptr, token_idx)
    if not ptr:
        return None
    n_kv = _lib.llama_get_attn_n_kv(ctx_ptr)
    if n_kv <= 0:
        return None
    # View the whole [n_pairs, n_ctx] buffer once, then copy out the valid columns.
    total = n_pairs * n_ctx
    flat = (ctypes.c_float * total).from_address(ctypes.addressof(ptr.contents))
    full = np.frombuffer(flat, dtype=np.float32).reshape(n_pairs, n_ctx)
    return full[:, :n_kv].copy()
def argmax_logits(ctx_ptr, token_idx, n_vocab):
    """Return the index of the highest logit for output token `token_idx`,
    or -1 when no logits are available."""
    import numpy as np
    ptr = _lib.llama_get_logits_ith(ctx_ptr, token_idx)
    if not ptr:
        return -1
    buf = (ctypes.c_float * n_vocab).from_address(ctypes.addressof(ptr.contents))
    return int(np.frombuffer(buf, dtype=np.float32).argmax())
# --- Public API ---
def init():
    """Initialize llama.cpp backends (call once before loading models)."""
    _lib.llama_backend_init()
def cleanup():
    """Release backend resources (call once at program exit)."""
    _lib.llama_backend_free()
def load_model(path, n_gpu_layers=0):
    """Load a GGUF model from `path` and return an owning model pointer."""
    params = _lib.llama_model_default_params()
    # n_gpu_layers is at offset 0 in llama_model_params (first field)
    # Actually let's just use default params for simplicity
    # NOTE(review): the n_gpu_layers argument is currently ignored — defaults
    # are passed through unmodified.
    model = _lib.llama_model_load_from_file(path.encode(), params)
    if not model:
        raise RuntimeError(f"Failed to load model from {path}")
    return model
def create_context(model, n_ctx=512, n_batch=512, attn_weights=True):
    """Create an inference context; attn_weights=True enables extraction."""
    return _create_context(model, n_ctx, n_batch, attn_weights)
def set_attn_heads(ctx, layers, heads):
    """Select the (layer, head) pairs whose attention rows are recorded."""
    n = len(layers)
    assert len(heads) == n
    l_arr = (ctypes.c_int32 * n)(*layers)
    h_arr = (ctypes.c_int32 * n)(*heads)
    _lib.llama_set_attn_heads(ctx, l_arr, h_arr, n)
def get_vocab(model):
    """Return the model's vocab pointer."""
    return _lib.llama_model_get_vocab(model)
def n_layer(model):
    """Number of transformer layers in the model."""
    return _lib.llama_model_n_layer(model)
def n_head(model):
    """Number of attention heads per layer."""
    return _lib.llama_model_n_head(model)
def n_vocab(vocab):
    """Vocabulary size."""
    return _lib.llama_vocab_n_tokens(vocab)
def n_ctx(ctx):
    """Context length the context was created with."""
    return _lib.llama_n_ctx(ctx)
def vocab_eos(vocab):
    """End-of-sequence token id."""
    return _lib.llama_vocab_eos(vocab)
def free_context(ctx):
    """Free a context created by create_context()."""
    _lib.llama_free(ctx)
def token_to_piece(vocab, token_id, special=True):
    """Convert a single token ID to its string piece."""
    buf = (ctypes.c_char * 256)()
    n = _lib.llama_token_to_piece(vocab, token_id, buf, 256, 0, special)
    if n > 0:
        return buf[:n].decode("utf-8", errors="replace")
    return ""
def free_model(model):
    """Free a model loaded by load_model()."""
    _lib.llama_model_free(model)

View File

@ -0,0 +1,249 @@
#!/usr/bin/env python3
"""
Test: verify llama.cpp attention weight extraction works correctly.
Usage:
python3 benchmark/test_attn_weights.py <model.gguf>
Tests:
1. Basic extraction: attention weights are non-null and sum to ~1.0
2. Multi-head: multiple (layer, head) pairs return independent weights
3. Greedy generation: attention is extracted at each autoregressive step
4. Cross-validation with known properties (monotonicity, sparsity)
"""
import sys
import os
import numpy as np
sys.path.insert(0, os.path.dirname(__file__))
import llama_attn as ll
def test_basic(model, vocab, ctx, n_layers):
    """Test 1: basic attention extraction on a prompt.

    Decodes a short prompt and checks that the extracted attention row is
    present, non-negative, and sums to ~1.0 (softmax output).
    """
    print("=" * 60)
    print("TEST 1: Basic attention extraction")
    print("=" * 60)
    tokens = ll.tokenize(vocab, "The quick brown fox jumps over the lazy dog")
    n_tokens = len(tokens)
    print(f" Tokens: {n_tokens}")
    ret = ll.decode_batch(ctx, tokens, output_last_only=True)
    assert ret == 0, f"decode failed: {ret}"
    n_c = ll.n_ctx(ctx)
    # -1 selects the last output token; a single (layer, head) pair is configured
    attn = ll.get_attn_weights(ctx, -1, 1, n_c)
    assert attn is not None, "get_attn_weights returned None"
    n_kv = ll._lib.llama_get_attn_n_kv(ctx)
    print(f" n_kv: {n_kv}")
    print(f" Attention shape: {attn.shape}")
    print(f" Attention sum: {attn[0].sum():.6f}")
    print(f" Attention max: {attn[0].max():.6f} at position {attn[0].argmax()}")
    print(f" Attention min: {attn[0].min():.6f}")
    # Softmax output should sum to ~1.0
    assert abs(attn[0].sum() - 1.0) < 0.05, f"Attention doesn't sum to 1.0: {attn[0].sum()}"
    # All values should be non-negative
    assert (attn[0] >= 0).all(), "Negative attention values found"
    print(" PASSED\n")
    return True
def test_multi_head(model, vocab, ctx, n_layers):
    """Test 2: multiple (layer, head) pairs."""
    print("=" * 60)
    print("TEST 2: Multi-head attention extraction")
    print("=" * 60)
    # Set multiple heads across different layers
    layers = [0, n_layers // 2, n_layers - 1]
    heads = [0, 0, 0]
    n_pairs = len(layers)
    ll.set_attn_heads(ctx, layers, heads)
    print(f" Configured {n_pairs} heads: {list(zip(layers, heads))}")
    tokens = ll.tokenize(vocab, "Hello world, this is a test of attention")
    ret = ll.decode_batch(ctx, tokens, output_last_only=True)
    assert ret == 0, f"decode failed: {ret}"
    n_c = ll.n_ctx(ctx)
    # Expect one row of n_ctx floats per configured (layer, head) pair
    attn = ll.get_attn_weights(ctx, -1, n_pairs, n_c)
    assert attn is not None, "get_attn_weights returned None"
    print(f" Attention shape: {attn.shape}")
    for p in range(n_pairs):
        s = attn[p].sum()
        print(f" Pair {p} (L{layers[p]},H{heads[p]}): sum={s:.6f}, max={attn[p].max():.4f} @ pos {attn[p].argmax()}")
        # Each extracted row is a softmax output, so it should sum to ~1.0
        assert abs(s - 1.0) < 0.05, f"Pair {p} doesn't sum to 1.0: {s}"
    # Different layers should produce different attention patterns
    if n_pairs >= 2:
        diff = np.abs(attn[0] - attn[-1]).mean()
        print(f" Mean abs difference between first and last layer: {diff:.6f}")
        # They should not be identical (unless the model is degenerate)
        # Don't assert this as a hard requirement
    # Reset to default (last layer, head 0)
    ll.set_attn_heads(ctx, [n_layers - 1], [0])
    print(" PASSED\n")
    return True
def test_generation(model, vocab, ctx, n_layers):
    """Test 3: attention during autoregressive generation.

    Prefills a short prompt, then greedily generates up to 10 tokens.
    At every step the attention weights of the current output token are
    read and later checked to sum to ~1.0 (softmax property).
    """
    print("=" * 60)
    print("TEST 3: Autoregressive generation with attention")
    print("=" * 60)
    tokens = ll.tokenize(vocab, "Once upon a time")
    n_prompt = len(tokens)
    print(f" Prompt: {n_prompt} tokens")
    # Prefill: decode the whole prompt, keeping only the last token's outputs
    ret = ll.decode_batch(ctx, tokens, output_last_only=True)
    assert ret == 0, f"prefill decode failed: {ret}"
    n_c = ll.n_ctx(ctx)
    nv = ll.n_vocab(vocab)
    eos = ll.vocab_eos(vocab)
    max_gen = 10
    gen_tokens = []
    attn_sums = []
    for step in range(max_gen):
        # Get attention for the current output token (1 pair x n_ctx floats)
        attn = ll.get_attn_weights(ctx, -1, 1, n_c)
        assert attn is not None, f"Step {step}: attention is None"
        # Call kept for its synchronization side effect (the C wrapper
        # synchronizes the context); the returned KV count is unused here.
        ll._lib.llama_get_attn_n_kv(ctx)
        s = attn[0].sum()
        attn_sums.append(s)
        # Get next token (greedy)
        next_tok = ll.argmax_logits(ctx, -1, nv)
        if next_tok == eos:
            print(f" Step {step}: EOS")
            break
        gen_tokens.append(next_tok)
        # Feed the chosen token back at the next position
        pos = n_prompt + step
        ret = ll.decode_single(ctx, next_tok, pos, output=True)
        assert ret == 0, f"Step {step}: decode failed: {ret}"
    print(f" Generated {len(gen_tokens)} tokens")
    print(f" Attention sums: {[f'{s:.4f}' for s in attn_sums]}")
    for i, s in enumerate(attn_sums):
        assert abs(s - 1.0) < 0.05, f"Step {i}: attention sum = {s}"
    print(" PASSED\n")
    return True
def test_multiple_heads_same_layer(model, vocab, ctx, n_layers):
    """Test 4: multiple heads from the same layer."""
    print("=" * 60)
    print("TEST 4: Multiple heads from same layer")
    print("=" * 60)
    n_h = ll.n_head(model)
    last_layer = n_layers - 1
    # Extract up to 4 distinct heads, all from the last layer
    n_test_heads = min(4, n_h)
    layers = [last_layer] * n_test_heads
    heads = list(range(n_test_heads))
    ll.set_attn_heads(ctx, layers, heads)
    print(f" Layer {last_layer}, heads {heads}")
    tokens = ll.tokenize(vocab, "Attention is all you need")
    ret = ll.decode_batch(ctx, tokens, output_last_only=True)
    assert ret == 0, f"decode failed: {ret}"
    n_c = ll.n_ctx(ctx)
    attn = ll.get_attn_weights(ctx, -1, n_test_heads, n_c)
    assert attn is not None, "get_attn_weights returned None"
    print(f" Attention shape: {attn.shape}")
    for h in range(n_test_heads):
        s = attn[h].sum()
        peak = attn[h].argmax()
        print(f" Head {h}: sum={s:.6f}, peak @ pos {peak}, max={attn[h].max():.4f}")
        # Each head's row is a softmax output, so it should sum to ~1.0
        assert abs(s - 1.0) < 0.05, f"Head {h}: sum = {s}"
    # Different heads should show at least some variation
    if n_test_heads >= 2:
        patterns_identical = all(
            np.allclose(attn[0], attn[h], atol=1e-5)
            for h in range(1, n_test_heads)
        )
        # Only warn, don't fail: identical patterns can occur on some models
        if patterns_identical:
            print(" WARNING: all heads have identical attention patterns")
        else:
            print(" OK: heads show different patterns")
    # Reset
    ll.set_attn_heads(ctx, [n_layers - 1], [0])
    print(" PASSED\n")
    return True
def main():
    """Entry point: run every attention-extraction test against the model."""
    if len(sys.argv) < 2:
        print(f"Usage: {sys.argv[0]} <model.gguf>")
        sys.exit(1)
    model_path = sys.argv[1]
    n_ctx = int(sys.argv[2]) if len(sys.argv) > 2 else 512
    print(f"Model: {model_path}")
    print(f"n_ctx: {n_ctx}\n")
    ll.init()
    model = ll.load_model(model_path)
    vocab = ll.get_vocab(model)
    n_layers = ll.n_layer(model)
    n_heads = ll.n_head(model)
    nv = ll.n_vocab(vocab)
    print(f"Loaded: {n_layers} layers, {n_heads} heads, vocab={nv}\n")
    tests = [test_basic, test_multi_head, test_generation, test_multiple_heads_same_layer]
    passed = 0
    failed = 0
    for test_fn in tests:
        # Fresh context per test so no state leaks between tests
        ctx = ll.create_context(model, n_ctx=n_ctx, n_batch=n_ctx, attn_weights=True)
        try:
            ok = test_fn(model, vocab, ctx, n_layers)
            if ok:
                passed += 1
            else:
                failed += 1
        except Exception as e:
            print(f" FAILED: {e}\n")
            failed += 1
        finally:
            ll.free_context(ctx)
    sep = "=" * 60
    print(sep)
    print(f"Results: {passed} passed, {failed} failed")
    print(sep)
    ll.free_model(model)
    ll.cleanup()
    sys.exit(0 if failed == 0 else 1)

View File

@ -0,0 +1,558 @@
{
"model": "TranslateGemma-4B",
"language_pair": "en-zh",
"num_layers": 34,
"num_heads": 8,
"num_sentences": 200,
"total_alignable_tokens": 5283,
"ts_threshold": 0.1,
"ts_matrix": [
[
0.0015142911224682945,
0.017982207079311,
0.029528676888131742,
0.028203672155971984,
0.0009464319515426841,
0.0018928639030853683,
0.0,
0.0
],
[
0.0,
0.0,
0.0028392958546280523,
0.0001892863903085368,
0.0003785727806170736,
0.008517887563884156,
0.0,
0.0
],
[
0.0,
0.0,
0.0,
0.0,
0.0017035775127768314,
0.0,
0.0,
0.05148589816392202
],
[
0.0,
0.0,
0.0,
0.0,
0.0,
0.0001892863903085368,
0.0,
0.0
],
[
0.045618020064357376,
0.007003596441415862,
0.10240393715691842,
0.02138936210486466,
0.0,
0.0,
0.0,
0.0
],
[
0.20632216543630513,
0.17527919742570508,
0.21370433465833807,
0.02574294908196101,
0.0,
0.0005678591709256105,
0.0001892863903085368,
0.03388226386522809
],
[
0.11073253833049404,
0.0028392958546280523,
0.05830020821502934,
0.0308536816202915,
0.10618966496308915,
0.0001892863903085368,
0.002271436683702442,
0.0003785727806170736
],
[
0.0,
0.0,
0.0,
0.0,
0.0,
0.0018928639030853683,
0.003217868635245126,
0.11186825667234526
],
[
0.0013250047321597578,
0.0308536816202915,
0.0,
0.0,
0.03615370054893053,
0.06719666855953058,
0.0,
0.0
],
[
0.1378004921446148,
0.0035964414158621994,
0.13250047321597577,
0.24626159379140639,
0.008707173954192694,
0.039371569184175656,
0.003975014196479273,
0.13325761877720993
],
[
0.0,
0.0,
0.0001892863903085368,
0.0005678591709256105,
0.0,
0.0,
0.0007571455612341472,
0.0017035775127768314
],
[
0.0,
0.06643952299829642,
0.26500094643195155,
0.17565777020632217,
0.011167897028203672,
0.3577512776831346,
0.11451826613666477,
0.07666098807495741
],
[
0.012114328979746356,
0.24096157486276737,
0.0001892863903085368,
0.0,
0.004542873367404884,
0.0056785917092561046,
0.05848949460533787,
0.00946431951542684
],
[
0.0,
0.0,
0.167139882642438,
0.0,
0.1419647927314026,
0.021010789324247586,
0.0026500094643195156,
0.0009464319515426841
],
[
0.030285822449365892,
0.0009464319515426841,
0.18076850274465267,
0.24285443876585275,
0.002271436683702442,
0.0,
0.0001892863903085368,
0.0
],
[
0.0009464319515426841,
0.0001892863903085368,
0.0,
0.0,
0.1198182850653038,
0.049025175089911034,
0.06814310051107325,
0.1692220329358319
],
[
0.0,
0.07760742002650009,
0.0,
0.19307211811470756,
0.010410751466969525,
0.017982207079311,
0.0,
0.11792542116221844
],
[
0.2824152943403369,
0.13003975014196478,
0.15010410751466968,
0.3244368729888321,
0.14328979746356238,
0.37100132500473215,
0.0,
0.0
],
[
0.0,
0.049403747870528106,
0.05318947567669884,
0.0001892863903085368,
0.0,
0.0,
0.0,
0.0
],
[
0.0,
0.2040507287526027,
0.0,
0.0,
0.0308536816202915,
0.27900813931478324,
0.0001892863903085368,
0.07230740109786106
],
[
0.012114328979746356,
0.26537951921256864,
0.0,
0.009842892296043914,
0.0,
0.0,
0.0,
0.0068143100511073255
],
[
0.3876585273518834,
0.019875070982396367,
0.0,
0.0,
0.0,
0.0,
0.12114328979746357,
0.0009464319515426841
],
[
0.0,
0.0,
0.0001892863903085368,
0.0,
0.0,
0.19250425894378195,
0.13496119628998676,
0.19269354533409047
],
[
0.030096536059057353,
0.038425137232632973,
0.004921446148021957,
0.3140261215218626,
0.0,
0.0,
0.3736513344690517,
0.24550444823017226
],
[
0.0001892863903085368,
0.08006814310051108,
0.010221465076660987,
0.0017035775127768314,
0.0848003028582245,
0.0,
0.0013250047321597578,
0.0
],
[
0.0015142911224682945,
0.0,
0.0,
0.0,
0.0017035775127768314,
0.0,
0.0,
0.0
],
[
0.024039371569184176,
0.012682188150671967,
0.03180011357183418,
0.006246450880181715,
0.014007192882831724,
0.049593034260836645,
0.0,
0.0
],
[
0.0,
0.007571455612341473,
0.0,
0.0,
0.0,
0.0,
0.07363240583002083,
0.0
],
[
0.00473215975771342,
0.022146507666098807,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0
],
[
0.0003785727806170736,
0.0003785727806170736,
0.0005678591709256105,
0.0005678591709256105,
0.002082150293393905,
0.003028582244936589,
0.0,
0.0
],
[
0.0013250047321597578,
0.0003785727806170736,
0.005867878099564641,
0.01722506151807685,
0.0003785727806170736,
0.0003785727806170736,
0.0026500094643195156,
0.004542873367404884
],
[
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.011357183418512209
],
[
0.0,
0.025175089911035398,
0.0,
0.0001892863903085368,
0.0,
0.0,
0.0,
0.0
],
[
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0,
0.0
]
],
"token_alignment_heads": [
{
"layer": 21,
"head": 0,
"ts": 0.3877
},
{
"layer": 23,
"head": 6,
"ts": 0.3737
},
{
"layer": 17,
"head": 5,
"ts": 0.371
},
{
"layer": 11,
"head": 5,
"ts": 0.3578
},
{
"layer": 17,
"head": 3,
"ts": 0.3244
},
{
"layer": 23,
"head": 3,
"ts": 0.314
},
{
"layer": 17,
"head": 0,
"ts": 0.2824
},
{
"layer": 19,
"head": 5,
"ts": 0.279
},
{
"layer": 20,
"head": 1,
"ts": 0.2654
},
{
"layer": 11,
"head": 2,
"ts": 0.265
},
{
"layer": 9,
"head": 3,
"ts": 0.2463
},
{
"layer": 23,
"head": 7,
"ts": 0.2455
},
{
"layer": 14,
"head": 3,
"ts": 0.2429
},
{
"layer": 12,
"head": 1,
"ts": 0.241
},
{
"layer": 5,
"head": 2,
"ts": 0.2137
},
{
"layer": 5,
"head": 0,
"ts": 0.2063
},
{
"layer": 19,
"head": 1,
"ts": 0.2041
},
{
"layer": 16,
"head": 3,
"ts": 0.1931
},
{
"layer": 22,
"head": 7,
"ts": 0.1927
},
{
"layer": 22,
"head": 5,
"ts": 0.1925
},
{
"layer": 14,
"head": 2,
"ts": 0.1808
},
{
"layer": 11,
"head": 3,
"ts": 0.1757
},
{
"layer": 5,
"head": 1,
"ts": 0.1753
},
{
"layer": 15,
"head": 7,
"ts": 0.1692
},
{
"layer": 13,
"head": 2,
"ts": 0.1671
},
{
"layer": 17,
"head": 2,
"ts": 0.1501
},
{
"layer": 17,
"head": 4,
"ts": 0.1433
},
{
"layer": 13,
"head": 4,
"ts": 0.142
},
{
"layer": 9,
"head": 0,
"ts": 0.1378
},
{
"layer": 22,
"head": 6,
"ts": 0.135
},
{
"layer": 9,
"head": 7,
"ts": 0.1333
},
{
"layer": 9,
"head": 2,
"ts": 0.1325
},
{
"layer": 17,
"head": 1,
"ts": 0.13
},
{
"layer": 21,
"head": 6,
"ts": 0.1211
},
{
"layer": 15,
"head": 4,
"ts": 0.1198
},
{
"layer": 16,
"head": 7,
"ts": 0.1179
},
{
"layer": 11,
"head": 6,
"ts": 0.1145
},
{
"layer": 7,
"head": 7,
"ts": 0.1119
},
{
"layer": 6,
"head": 0,
"ts": 0.1107
},
{
"layer": 6,
"head": 4,
"ts": 0.1062
},
{
"layer": 4,
"head": 2,
"ts": 0.1024
}
]
}

View File

@ -372,6 +372,7 @@ extern "C" {
bool kv_unified; // use a unified buffer across the input sequences when computing the attention
// try to disable when n_seq_max > 1 for improved performance when the sequences do not share a large prefix
// ref: https://github.com/ggml-org/llama.cpp/pull/14363
bool attn_weights; // if true, extract attention weights (requires flash_attn disabled)
// [EXPERIMENTAL]
// backend sampler chain configuration (make sure the caller keeps the sampler chains alive)
@ -1012,6 +1013,31 @@ extern "C" {
// otherwise: float[n_embd] (1-dimensional)
LLAMA_API float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id);
//
// attention weights API [EXPERIMENTAL]
// requires llama_context_params.attn_weights = true and flash_attn disabled
//
// Configure which (layer, head) pairs to extract attention weights from.
// By default (if never called), extracts head 0 of the last layer.
// layers and heads are parallel arrays of length n_pairs.
LLAMA_API void llama_set_attn_heads(
struct llama_context * ctx,
const int32_t * layers,
const int32_t * heads,
size_t n_pairs);
// Get attention weights for the ith output token.
// Returns a pointer to [n_pairs * n_ctx] floats where n_ctx = llama_n_ctx(ctx).
// Each pair occupies n_ctx floats; only the first llama_get_attn_n_kv() entries per pair are valid.
// Pairs are in the order set by llama_set_attn_heads: pair p starts at offset p * n_ctx.
// Returns NULL if attention extraction is not enabled or index is invalid.
LLAMA_API float * llama_get_attn_ith(struct llama_context * ctx, int32_t i);
// Get the number of valid KV entries per (layer, head) pair in the attention weight arrays.
// Use this to know how many entries to read from each n_ctx-sized chunk.
LLAMA_API int32_t llama_get_attn_n_kv(struct llama_context * ctx);
//
// backend sampling API [EXPERIMENTAL]
// note: use only if the llama_context was created with at least one llama_sampler_seq_config

View File

@ -163,6 +163,17 @@ llama_context::llama_context(
cparams.op_offload = params.op_offload;
cparams.kv_unified = params.kv_unified;
cparams.attn_weights = params.attn_weights;
if (cparams.attn_weights && cparams.flash_attn) {
throw std::runtime_error("attention weight extraction requires flash_attn to be disabled");
}
if (cparams.attn_weights) {
// default: head 0 of the last layer
const int32_t default_layer = (int32_t)(hparams.n_layer - 1);
attn_heads.push_back({default_layer, 0});
cparams.attn_layers.insert(default_layer);
}
// initialized later
cparams.pipeline_parallel = false;
@ -866,6 +877,126 @@ float * llama_context::get_embeddings_seq(llama_seq_id seq_id) {
return it->second.data();
}
// Replace the configured (layer, head) extraction pairs.
// Invalid entries are logged and skipped instead of failing the whole call.
// cparams.attn_layers is rebuilt alongside so the graph only tags the
// kq_soft_max tensors of layers that actually have a configured pair.
void llama_context::set_attn_heads(const int32_t * layers, const int32_t * heads, size_t n_pairs) {
    const auto & hparams = model.hparams;
    attn_heads.clear();
    attn_heads.reserve(n_pairs);
    cparams.attn_layers.clear();
    for (size_t i = 0; i < n_pairs; i++) {
        if (layers[i] < 0 || layers[i] >= (int32_t) hparams.n_layer) {
            LLAMA_LOG_ERROR("%s: invalid layer index %d (model has %u layers), skipping\n",
                __func__, layers[i], hparams.n_layer);
            continue;
        }
        // head count can differ per layer, so validate against n_head(layer)
        if (heads[i] < 0 || heads[i] >= (int32_t) hparams.n_head(layers[i])) {
            LLAMA_LOG_ERROR("%s: invalid head index %d for layer %d (layer has %u heads), skipping\n",
                __func__, heads[i], layers[i], hparams.n_head(layers[i]));
            continue;
        }
        attn_heads.push_back({layers[i], heads[i]});
        cparams.attn_layers.insert(layers[i]);
    }
}
// Return a pointer to the attention weights of the i-th output token:
// attn_heads.size() * n_ctx floats, pair p starting at offset p * n_ctx.
// Returns nullptr when extraction is disabled or the index is invalid.
float * llama_context::get_attn_ith(int32_t i) {
    // make sure rows are in batch order before resolving the index
    output_reorder();
    try {
        if (attn.data == nullptr) {
            // no attention buffer was allocated for this decode
            return nullptr;
        }
        const int64_t j = output_resolve_row(i);
        const size_t attn_stride = attn_heads.size() * cparams.n_ctx;
        return attn.data + j * attn_stride;
    } catch (const std::exception & err) {
        // output_resolve_row() throws on an out-of-range/unresolvable index
        LLAMA_LOG_ERROR("%s: invalid attention id %d, reason: %s\n", __func__, i, err.what());
        return nullptr;
    }
}
// Number of valid KV entries per (layer, head) pair, recorded by the eval
// callback (attn_cb_eval_fn) during the last decode.
int32_t llama_context::get_attn_n_kv() const {
    return attn_n_kv;
}
// static
// Scheduler eval callback used to copy attention-weight tensors off the
// backend. Non-matching tensors are forwarded to the user's original
// cb_eval (if any) so attention extraction composes with an existing callback.
bool llama_context::attn_cb_eval_fn(struct ggml_tensor * t, bool ask, void * user_data) {
    auto * state = static_cast<attn_cb_state *>(user_data);
    auto * ctx = state->ctx;
    const auto user_cb = ctx->cparams.cb_eval;
    const auto user_data_cb = ctx->cparams.cb_eval_user_data;
    const char * name = t->name;
    // only tensors named "kq_soft_max-<layer>" carry attention weights
    if (strncmp(name, "kq_soft_max-", 12) != 0) {
        return user_cb ? user_cb(t, ask, user_data_cb) : false;
    }
    const int layer_idx = atoi(name + 12);
    if (ctx->cparams.attn_layers.find(layer_idx) == ctx->cparams.attn_layers.end()) {
        // layer not selected for extraction - defer to the user callback
        return user_cb ? user_cb(t, ask, user_data_cb) : false;
    }
    if (ask) {
        // "ask" phase: returning true requests the tensor's data be kept
        if (user_cb) {
            user_cb(t, true, user_data_cb);
        }
        return true;
    }
    // data is available, extract attention weights
    const auto & ubatch = *state->ubatch;
    const auto & attn_heads_vec = ctx->attn_heads;
    const size_t n_pairs = attn_heads_vec.size();
    const size_t attn_stride = n_pairs * ctx->cparams.n_ctx;
    // tensor layout: ne = [n_kv, n_tokens, n_head, n_stream]
    const int64_t t_n_kv = t->ne[0];
    const int64_t t_n_tokens = t->ne[1];
    const int64_t t_n_head = t->ne[2];
    const int64_t t_n_stream = t->ne[3];
    ctx->attn_n_kv = (int32_t) t_n_kv;
    for (size_t p = 0; p < n_pairs; p++) {
        if (attn_heads_vec[p].layer != layer_idx) {
            continue;
        }
        const int head = attn_heads_vec[p].head;
        if (head >= t_n_head) {
            continue;
        }
        // copy one n_kv-length row per ubatch token that produces output
        int32_t out_idx = 0;
        for (uint32_t i = 0; i < ubatch.n_tokens; i++) {
            if (!ubatch.output[i]) {
                continue;
            }
            // with multiple streams, token i maps to (stream, in-stream index)
            const int64_t stream = t_n_stream > 1 ? (int64_t) i / t_n_tokens : 0;
            const int64_t t_in_str = t_n_stream > 1 ? (int64_t) i % t_n_tokens : (int64_t) i;
            const size_t src_offset = stream * t->nb[3] + head * t->nb[2] + t_in_str * t->nb[1];
            // destination row: output index within the full decode, pair p slot
            float * dst = ctx->attn.data + (state->n_outputs_prev + out_idx) * attn_stride + p * ctx->cparams.n_ctx;
            ggml_backend_tensor_get(t, dst, src_offset, t_n_kv * sizeof(float));
            out_idx++;
        }
    }
    if (user_cb) {
        return user_cb(t, false, user_data_cb);
    }
    return true;
}
llama_token llama_context::get_sampled_token_ith(int32_t idx) {
output_reorder();
@ -1192,7 +1323,12 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll
res->reset();
ggml_backend_sched_reset(sched.get());
ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
if (attn.data && cparams.attn_weights && !attn_heads.empty()) {
ggml_backend_sched_set_eval_callback(sched.get(), attn_cb_eval_fn, &attn_cb);
} else {
ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
}
//const auto t_start_us = ggml_time_us();
@ -1680,6 +1816,10 @@ int llama_context::decode(const llama_batch & batch_inp) {
n_outputs = n_outputs_new;
}
attn_cb.ctx = this;
attn_cb.ubatch = &ubatch;
attn_cb.n_outputs_prev = n_outputs_prev;
ggml_status status;
const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status);
@ -1899,6 +2039,10 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
logits.size = has_logits ? n_vocab*n_outputs_max : 0;
embd.size = has_embd ? n_embd_out*n_outputs_max : 0;
const bool has_attn = cparams.attn_weights && !attn_heads.empty();
const size_t n_attn_pairs = attn_heads.size();
attn.size = has_attn ? n_attn_pairs * cparams.n_ctx * n_outputs_max : 0;
// Allocate backend sampling output buffers if there are backend samplers configured.
const bool has_sampling = !sampling.samplers.empty();
if (has_sampling) {
@ -1913,8 +2057,8 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
const size_t prev_size = buf_output ? ggml_backend_buffer_get_size(buf_output.get()) : 0;
const size_t new_size =
(logits.size + embd.size + backend_float_count) * sizeof(float) +
( backend_token_count) * sizeof(llama_token);
(logits.size + embd.size + attn.size + backend_float_count) * sizeof(float) +
( backend_token_count) * sizeof(llama_token);
// alloc only when more than the current capacity is required
// TODO: also consider shrinking the buffer
@ -1930,6 +2074,7 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
buf_output = nullptr;
logits.data = nullptr;
embd.data = nullptr;
attn.data = nullptr;
}
auto * buft = ggml_backend_cpu_buffer_type();
@ -1957,6 +2102,9 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
embd = has_embd ? buffer_view<float>{(float *) (base + offset), embd.size} : buffer_view<float>{nullptr, 0};
offset += embd.size * sizeof(float);
attn = has_attn ? buffer_view<float>{(float *) (base + offset), attn.size} : buffer_view<float>{nullptr, 0};
offset += attn.size * sizeof(float);
if (has_sampling) {
sampling.logits = {(float *) (base + offset), (size_t)(n_vocab*n_outputs_max)};
offset += sampling.logits.size * sizeof(float);
@ -2021,6 +2169,13 @@ void llama_context::output_reorder() {
}
}
if (attn.size > 0) {
const uint64_t attn_stride = attn_heads.size() * cparams.n_ctx;
for (uint64_t k = 0; k < attn_stride; k++) {
std::swap(attn.data[i0*attn_stride + k], attn.data[i1*attn_stride + k]);
}
}
if (!sampling.samplers.empty()) {
assert(sampling.logits.size > 0);
assert(sampling.probs.size > 0);
@ -2180,7 +2335,7 @@ ggml_status llama_context::graph_compute(
LLAMA_LOG_ERROR("%s: ggml_backend_sched_graph_compute_async failed with error %d\n", __func__, status);
}
// fprintf(stderr, "splits: %d\n", ggml_backend_sched_get_n_splits(sched));
// fprintf(stderr, "splits: %d\n", ggml_backend_sched_get_n_splits(sched.get()));
return status;
}
@ -2903,6 +3058,7 @@ llama_context_params llama_context_default_params() {
/*.op_offload =*/ true,
/*.swa_full =*/ true,
/*.kv_unified =*/ false,
/*.attn_weights =*/ false,
/*.sampler =*/ nullptr,
/*.n_sampler =*/ 0,
};
@ -3097,6 +3253,22 @@ float * llama_get_embeddings_seq(llama_context * ctx, llama_seq_id seq_id) {
return ctx->get_embeddings_seq(seq_id);
}
// C API: forward (layer, head) pair configuration to the context.
void llama_set_attn_heads(llama_context * ctx, const int32_t * layers, const int32_t * heads, size_t n_pairs) {
    ctx->set_attn_heads(layers, heads, n_pairs);
}
// C API: wait for any in-flight compute, then return the attention weights
// of the i-th output token (nullptr on invalid index / extraction disabled).
float * llama_get_attn_ith(llama_context * ctx, int32_t i) {
    ctx->synchronize();
    return ctx->get_attn_ith(i);
}
// C API: wait for any in-flight compute, then return the number of valid KV
// entries per (layer, head) pair from the last decode.
int32_t llama_get_attn_n_kv(llama_context * ctx) {
    ctx->synchronize();
    return ctx->get_attn_n_kv();
}
bool llama_set_sampler(llama_context * ctx, llama_seq_id seq_id, llama_sampler * smpl) {
return ctx->set_sampler(seq_id, smpl);
}

View File

@ -79,6 +79,10 @@ struct llama_context {
float * get_embeddings_ith(int32_t i);
float * get_embeddings_seq(llama_seq_id seq_id);
void set_attn_heads(const int32_t * layers, const int32_t * heads, size_t n_pairs);
float * get_attn_ith(int32_t i);
int32_t get_attn_n_kv() const;
llama_token * get_sampled_tokens() const;
llama_token get_sampled_token_ith(int32_t idx);
@ -295,6 +299,24 @@ private:
// populated only when pooling_type != LLAMA_POOLING_TYPE_NONE
std::map<llama_seq_id, std::vector<float>> embd_seq;
// attention weight extraction
struct attn_head_pair {
int32_t layer;
int32_t head;
};
std::vector<attn_head_pair> attn_heads; // which (layer, head) pairs to extract
buffer_view<float> attn = {nullptr, 0}; // [n_outputs][n_pairs * n_ctx]
int32_t attn_n_kv = 0; // KV cache length at time of last decode
struct attn_cb_state {
llama_context * ctx = nullptr;
const llama_ubatch * ubatch = nullptr;
int64_t n_outputs_prev = 0;
};
attn_cb_state attn_cb;
static bool attn_cb_eval_fn(struct ggml_tensor * t, bool ask, void * user_data);
// reuse the batch_allocr to avoid unnecessary memory allocations
std::unique_ptr<llama_batch_allocr> balloc;

View File

@ -3,6 +3,7 @@
#include "llama.h"
#include <cstdint>
#include <set>
#define LLAMA_MAX_SEQ 256
@ -39,6 +40,9 @@ struct llama_cparams {
bool op_offload;
bool kv_unified;
bool pipeline_parallel;
bool attn_weights;
std::set<int32_t> attn_layers; // which layers to extract attention from (derived from attn_heads)
enum llama_pooling_type pooling_type;

View File

@ -737,6 +737,7 @@ void llm_graph_result::reset() {
t_logits = nullptr;
t_embd = nullptr;
t_embd_pooled = nullptr;
t_attn.clear();
t_sampled.clear();
t_sampled_probs.clear();
t_sampled_logits.clear();
@ -1883,6 +1884,10 @@ ggml_tensor * llm_graph_context::build_attn_mha(
ggml_soft_max_add_sinks(kq, sinks);
cb(kq, "kq_soft_max", il);
if (!cparams.attn_layers.empty() && cparams.attn_layers.count(il)) {
res->t_attn[il] = kq;
}
if (!v_trans) {
// note: avoid this branch
v = ggml_cont(ctx0, ggml_transpose(ctx0, v));

View File

@ -613,8 +613,9 @@ struct llm_graph_params {
}
return
cparams.embeddings == other.cparams.embeddings &&
cparams.causal_attn == other.cparams.causal_attn &&
cparams.embeddings == other.cparams.embeddings &&
cparams.causal_attn == other.cparams.causal_attn &&
cparams.attn_layers == other.cparams.attn_layers &&
arch == other.arch &&
gtype == other.gtype &&
cvec == other.cvec &&
@ -667,6 +668,10 @@ public:
std::map<llama_seq_id, ggml_tensor*> t_sampled;
std::map<llama_seq_id, ggml_tensor*> t_sampled_probs;
// attention weight tensors, keyed by layer index
// shape per tensor: [n_kv, n_tokens, n_head, n_stream] (after permute in build_attn_mha)
std::map<int, ggml_tensor *> t_attn;
std::vector<llm_graph_input_ptr> inputs;
ggml_context_ptr ctx_compute;