From 14bf6d45bb526d23e66566c7563aaa201c78b94c Mon Sep 17 00:00:00 2001 From: Quentin Fuxa Date: Fri, 27 Feb 2026 21:47:12 +0100 Subject: [PATCH 1/2] llama : add attention weights extraction API [EXPERIMENTAL] --- examples/CMakeLists.txt | 1 + examples/attn-weights/CMakeLists.txt | 5 + examples/attn-weights/attn-weights.cpp | 255 ++++++++ .../eval_benchmark_translategemma.py | 343 +++++++++++ examples/attn-weights/llama_attn.py | 387 ++++++++++++ examples/attn-weights/test_attn_weights.py | 249 ++++++++ .../attn-weights/translation_heads_en_zh.json | 558 ++++++++++++++++++ include/llama.h | 26 + src/llama-context.cpp | 146 ++++- src/llama-context.h | 13 + src/llama-cparams.h | 4 + src/llama-graph.cpp | 10 + src/llama-graph.h | 9 +- 13 files changed, 2002 insertions(+), 4 deletions(-) create mode 100644 examples/attn-weights/CMakeLists.txt create mode 100644 examples/attn-weights/attn-weights.cpp create mode 100644 examples/attn-weights/eval_benchmark_translategemma.py create mode 100644 examples/attn-weights/llama_attn.py create mode 100644 examples/attn-weights/test_attn_weights.py create mode 100644 examples/attn-weights/translation_heads_en_zh.json diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index a29dc707c3..139faa6d1d 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -14,6 +14,7 @@ llama_add_compile_flags() if (EMSCRIPTEN) else() + add_subdirectory(attn-weights) add_subdirectory(batched) add_subdirectory(debug) add_subdirectory(embedding) diff --git a/examples/attn-weights/CMakeLists.txt b/examples/attn-weights/CMakeLists.txt new file mode 100644 index 0000000000..4bfa55e26c --- /dev/null +++ b/examples/attn-weights/CMakeLists.txt @@ -0,0 +1,5 @@ +set(TARGET llama-attn-weights) +add_executable(${TARGET} attn-weights.cpp) +install(TARGETS ${TARGET} RUNTIME) +target_link_libraries(${TARGET} PRIVATE llama ${CMAKE_THREAD_LIBS_INIT}) +target_compile_features(${TARGET} PRIVATE cxx_std_17) diff --git 
a/examples/attn-weights/attn-weights.cpp b/examples/attn-weights/attn-weights.cpp new file mode 100644 index 0000000000..8de9bbd68a --- /dev/null +++ b/examples/attn-weights/attn-weights.cpp @@ -0,0 +1,255 @@ +// Attention weights extraction example +// +// Decodes a prompt and prints an ASCII heatmap of the attention pattern +// for a selected (layer, head) pair. +// +// Usage: +// ./llama-attn-weights -m model.gguf [-l layer] [-hd head] [-ngl N] [prompt] +// +// Defaults: last layer, head 0. +// Use -l -1 for the last layer. + +#include "llama.h" + +#include +#include +#include +#include +#include +#include + +static std::string token_str(const llama_vocab * vocab, llama_token id) { + char buf[256]; + int n = llama_token_to_piece(vocab, id, buf, sizeof(buf), 0, true); + if (n < 0) { + return "???"; + } + // sanitize: replace newlines/tabs with visible chars + std::string s(buf, n); + for (auto & c : s) { + if (c == '\n') { c = '|'; } + if (c == '\t') { c = ' '; } + if (c == '\r') { c = ' '; } + } + return s; +} + +// Truncate or pad a string to exactly `w` characters +static std::string fit(const std::string & s, int w) { + if ((int) s.size() <= w) { + return s + std::string(w - (int) s.size(), ' '); + } + return s.substr(0, w); +} + +static void print_usage(int, char ** argv) { + printf("\nexample usage:\n"); + printf("\n %s -m model.gguf [-l layer] [-hd head] [-ngl n_gpu_layers] [prompt]\n", argv[0]); + printf("\n"); + printf(" -l layer index (default: -1 = last layer)\n"); + printf(" -hd head index (default: 0)\n"); + printf("\n"); +} + +int main(int argc, char ** argv) { + std::string model_path; + std::string prompt = "The cat sat on the mat"; + int ngl = 99; + int layer = -1; // -1 means last layer + int head = 0; + + // parse args + { + int i = 1; + for (; i < argc; i++) { + if (strcmp(argv[i], "-m") == 0) { + if (i + 1 < argc) { model_path = argv[++i]; } + else { print_usage(argc, argv); return 1; } + } else if (strcmp(argv[i], "-l") == 0) { + if (i + 1 < 
argc) { + try { layer = std::stoi(argv[++i]); } catch (...) { print_usage(argc, argv); return 1; } + } else { print_usage(argc, argv); return 1; } + } else if (strcmp(argv[i], "-hd") == 0) { + if (i + 1 < argc) { + try { head = std::stoi(argv[++i]); } catch (...) { print_usage(argc, argv); return 1; } + } else { print_usage(argc, argv); return 1; } + } else if (strcmp(argv[i], "-ngl") == 0) { + if (i + 1 < argc) { + try { ngl = std::stoi(argv[++i]); } catch (...) { print_usage(argc, argv); return 1; } + } else { print_usage(argc, argv); return 1; } + } else { + break; + } + } + if (model_path.empty()) { + print_usage(argc, argv); + return 1; + } + if (i < argc) { + prompt = argv[i++]; + for (; i < argc; i++) { + prompt += " "; + prompt += argv[i]; + } + } + } + + // load backends & model + ggml_backend_load_all(); + + llama_model_params model_params = llama_model_default_params(); + model_params.n_gpu_layers = ngl; + + llama_model * model = llama_model_load_from_file(model_path.c_str(), model_params); + if (!model) { + fprintf(stderr, "%s: error: unable to load model\n", __func__); + return 1; + } + + const llama_vocab * vocab = llama_model_get_vocab(model); + + // resolve layer index + const int32_t n_layer = llama_model_n_layer(model); + if (layer < 0) { + layer = n_layer + layer; // -1 -> last layer + } + if (layer < 0 || layer >= n_layer) { + fprintf(stderr, "%s: error: layer %d out of range [0, %d)\n", __func__, layer, n_layer); + llama_model_free(model); + return 1; + } + + // tokenize + const int n_prompt = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, true, true); + std::vector tokens(n_prompt); + if (llama_tokenize(vocab, prompt.c_str(), prompt.size(), tokens.data(), tokens.size(), true, true) < 0) { + fprintf(stderr, "%s: error: failed to tokenize\n", __func__); + llama_model_free(model); + return 1; + } + + // create context with attention weights enabled + llama_context_params ctx_params = llama_context_default_params(); + 
ctx_params.n_ctx = n_prompt; + ctx_params.n_batch = n_prompt; + ctx_params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_DISABLED; + ctx_params.attn_weights = true; + ctx_params.no_perf = true; + + llama_context * ctx = llama_init_from_model(model, ctx_params); + if (!ctx) { + fprintf(stderr, "%s: error: failed to create context\n", __func__); + llama_model_free(model); + return 1; + } + + // configure which (layer, head) to extract + { + int32_t layers[] = { (int32_t) layer }; + int32_t heads[] = { (int32_t) head }; + llama_set_attn_heads(ctx, layers, heads, 1); + } + + // prepare batch — request output for all tokens + llama_batch batch = llama_batch_init(n_prompt, 0, 1); + for (int i = 0; i < n_prompt; i++) { + batch.token[i] = tokens[i]; + batch.pos[i] = i; + batch.n_seq_id[i] = 1; + batch.seq_id[i][0] = 0; + batch.logits[i] = 1; // request output for every token + } + batch.n_tokens = n_prompt; + + // decode + if (llama_decode(ctx, batch)) { + fprintf(stderr, "%s: error: decode failed\n", __func__); + llama_batch_free(batch); + llama_free(ctx); + llama_model_free(model); + return 1; + } + + // collect token labels + std::vector labels(n_prompt); + for (int i = 0; i < n_prompt; i++) { + labels[i] = token_str(vocab, tokens[i]); + } + + const int32_t n_kv = llama_get_attn_n_kv(ctx); + const int32_t n_ctx = llama_n_ctx(ctx); + + printf("\nAttention weights — layer %d, head %d\n", layer, head); + printf("Tokens: %d, KV entries: %d\n\n", n_prompt, n_kv); + + // print as a matrix: rows = query tokens, columns = key tokens + // only the first n_kv columns are valid + + const int col_w = 7; // column width for values + const int lbl_w = 12; // label column width + + // header row: key token labels + printf("%s", fit("", lbl_w).c_str()); + for (int k = 0; k < n_kv && k < n_prompt; k++) { + printf("%s", fit(labels[k], col_w).c_str()); + } + printf("\n"); + + // separator + printf("%s", std::string(lbl_w + col_w * std::min(n_kv, (int32_t) n_prompt), '-').c_str()); + 
printf("\n"); + + // ASCII heatmap chars from low to high attention + const char * heat = " .:-=+*#@"; + const int heat_len = 9; + + // one row per query token + for (int q = 0; q < n_prompt; q++) { + float * attn = llama_get_attn_ith(ctx, q); + if (!attn) { + printf("%s (no data)\n", fit(labels[q], lbl_w).c_str()); + continue; + } + + // find max for this row (for heatmap scaling) + float row_max = 0.0f; + for (int k = 0; k < n_kv && k < n_prompt; k++) { + if (attn[k] > row_max) { + row_max = attn[k]; + } + } + + // numeric row + printf("%s", fit(labels[q], lbl_w).c_str()); + for (int k = 0; k < n_kv && k < n_prompt; k++) { + printf(" %5.3f ", attn[k]); + } + printf("\n"); + + // heatmap row + if (row_max > 0.0f) { + printf("%s", fit("", lbl_w).c_str()); + for (int k = 0; k < n_kv && k < n_prompt; k++) { + float norm = attn[k] / row_max; + int idx = (int)(norm * (heat_len - 1)); + idx = std::max(0, std::min(heat_len - 1, idx)); + // center the char in the column + int pad_l = (col_w - 1) / 2; + int pad_r = col_w - 1 - pad_l; + printf("%*s%c%*s", pad_l, "", heat[idx], pad_r, ""); + } + printf("\n"); + } + } + + printf("\nLegend: ' .:-=+*#@' (low -> high attention)\n"); + + // print the n_ctx value for reference + (void) n_ctx; + + llama_batch_free(batch); + llama_free(ctx); + llama_model_free(model); + + return 0; +} diff --git a/examples/attn-weights/eval_benchmark_translategemma.py b/examples/attn-weights/eval_benchmark_translategemma.py new file mode 100644 index 0000000000..03c07fa800 --- /dev/null +++ b/examples/attn-weights/eval_benchmark_translategemma.py @@ -0,0 +1,343 @@ +#!/usr/bin/env python3 +""" +Benchmark: TranslateGemma 4B + AlignAtt streaming policy (EN→ZH) +llama.cpp implementation — comparable to CTranslate2/benchmark/eval_benchmark_translategemma.py + +Translates FLORES EN→ZH sentences using llama.cpp with alignment heads. +Computes latency metrics (AL, LAAL, AP, CW) and saves translations for quality scoring. 
+ +Usage: + # First convert TranslateGemma to GGUF: + # python3 convert_hf_to_gguf.py /tmp/translategemma-text-only --outfile /tmp/translategemma.gguf + + python3 examples/attn-weights/eval_benchmark_translategemma.py /tmp/translategemma.gguf \\ + --heads examples/attn-weights/translation_heads_en_zh.json \\ + -n 100 --output examples/attn-weights/results_llamacpp.json +""" + +import argparse +import json +import os +import sys +import time + +import numpy as np + +sys.path.insert(0, os.path.dirname(__file__)) +import llama_attn as ll + + +# Copy heads JSON from CTranslate2 benchmark if not present locally +CT2_HEADS = os.path.join( + os.path.dirname(__file__), "..", "..", "..", "CTranslate2", "benchmark", "translation_heads_en_zh.json" +) + +PROMPT_TEMPLATE = ( + "user\n" + "You are a professional English (en) to Chinese (zh) translator. " + "Your goal is to accurately convey the meaning and nuances of the " + "original English text while adhering to Chinese grammar, vocabulary, " + "and cultural sensitivities.\n" + "Produce only the Chinese translation, without any additional " + "explanations or commentary. 
Please translate the following English " + "text into Chinese:\n\n\n" + "{source}\n" + "model\n" +) + + +def find_source_token_range(source_text, vocab): + """Find token range of source text within the prompt.""" + prefix = PROMPT_TEMPLATE.split("{source}")[0] + suffix = PROMPT_TEMPLATE.split("{source}")[1] + prefix_ids = ll.tokenize(vocab, prefix, add_bos=False, special=True) + full_ids = ll.tokenize(vocab, prefix + source_text + suffix, add_bos=False, special=True) + suffix_ids = ll.tokenize(vocab, suffix, add_bos=False, special=True) + return len(prefix_ids), len(full_ids) - len(suffix_ids) + + +def compute_metrics(alignments, num_src, num_tgt): + """Compute AL, LAAL, AP, Max CW — same as CTranslate2 benchmark.""" + if not alignments or num_tgt == 0 or num_src == 0: + return {"al": 0.0, "max_cw": 0, "laal": 0.0, "ap": 0.0} + + mono = [] + max_dep = 0 + for dep in alignments: + max_dep = max(max_dep, dep) + mono.append(max_dep) + + ratio = num_src / num_tgt + total_lag = sum(max(0, (mono[t] + 1) - t * ratio) for t in range(num_tgt)) + al = total_lag / num_tgt + + tau = min(num_src / num_tgt, 1.0) + laal = al * tau + + total_src_read = sum(mono[t] + 1 for t in range(num_tgt)) + ap = total_src_read / (num_src * num_tgt) + + max_cw = 0 + cw = 0 + prev_needed = 0 + for dep in mono: + needed = dep + 1 + reads = max(0, needed - prev_needed) + if reads > 0: + cw += reads + max_cw = max(max_cw, cw) + else: + cw = 0 + prev_needed = needed + cw = 0 + + return {"al": al, "max_cw": max_cw, "laal": laal, "ap": ap} + + +def aggregate_ts_weighted_vote(src_attn, ts_scores): + """TS-weighted vote: weighted argmax across heads.""" + head_argmaxes = np.argmax(src_attn, axis=1) + weighted = {} + for h, pos in enumerate(head_argmaxes): + pos = int(pos) + weighted[pos] = weighted.get(pos, 0) + ts_scores[h] + return max(weighted, key=weighted.get) + + +def aggregate_mean(src_attn): + avg = src_attn.mean(axis=0) + return int(np.argmax(avg)) + + +def main(): + parser = 
argparse.ArgumentParser(description="TranslateGemma benchmark with llama.cpp attention") + parser.add_argument("model", help="Path to GGUF model file") + parser.add_argument("-n", type=int, default=100, help="Number of FLORES sentences") + parser.add_argument("--heads", default=None, help="Alignment heads JSON (default: CTranslate2 benchmark)") + parser.add_argument("--top-k", type=int, default=10, help="Number of top alignment heads") + parser.add_argument("--strategy", default="ts_weighted_vote", choices=["ts_weighted_vote", "mean"]) + parser.add_argument("--output", default="benchmark/results_llamacpp.json") + parser.add_argument("--n-ctx", type=int, default=1024) + parser.add_argument("--max-tokens", type=int, default=256) + args = parser.parse_args() + + # Load alignment heads + heads_path = args.heads or CT2_HEADS + if not os.path.exists(heads_path): + # Try local copy + local = os.path.join(os.path.dirname(__file__), "translation_heads_en_zh.json") + if os.path.exists(local): + heads_path = local + else: + print(f"ERROR: Cannot find alignment heads JSON at {heads_path}") + print("Copy from CTranslate2/benchmark/translation_heads_en_zh.json") + sys.exit(1) + + with open(heads_path) as f: + data = json.load(f) + + top_heads = data["token_alignment_heads"][:args.top_k] + head_layers = [h["layer"] for h in top_heads] + head_indices = [h["head"] for h in top_heads] + ts_scores = [h["ts"] for h in top_heads] + num_heads = len(top_heads) + + print(f"Using {num_heads} alignment heads (strategy={args.strategy})") + for h in top_heads[:5]: + print(f" L{h['layer']:2d} H{h['head']} (TS={h['ts']:.3f})") + if num_heads > 5: + print(f" ... 
and {num_heads - 5} more") + + # Load FLORES EN→ZH + print("\nLoading FLORES EN→ZH...") + from datasets import load_dataset + ds = load_dataset("openlanguagedata/flores_plus", split="dev") + en_ds = ds.filter(lambda x: x["iso_639_3"] == "eng" and x["iso_15924"] == "Latn") + zh_ds = ds.filter(lambda x: x["iso_639_3"] == "cmn" and x["iso_15924"] == "Hans") + en_map = {row["id"]: row["text"] for row in en_ds} + zh_map = {row["id"]: row["text"] for row in zh_ds} + common_ids = sorted(set(en_map) & set(zh_map)) + n = min(args.n, len(common_ids)) + print(f" {n} sentence pairs") + + # Initialize llama.cpp + print(f"\nLoading model: {args.model}") + ll.init() + model = ll.load_model(args.model) + vocab = ll.get_vocab(model) + nv = ll.n_vocab(vocab) + n_layers = ll.n_layer(model) + eos_id = ll.vocab_eos(vocab) + + print(f" {n_layers} layers, vocab={nv}") + + # Find stop token IDs + stop_ids = set() + for tok_str in ["", ""]: + toks = ll.tokenize(vocab, tok_str, add_bos=False, special=True) + if len(toks) == 1: + stop_ids.add(toks[0]) + stop_ids.add(eos_id) + print(f" Stop IDs: {stop_ids}") + + # Evaluate + results = [] + al_list, cw_list, laal_list, ap_list = [], [], [], [] + + for idx in range(n): + sid = common_ids[idx] + source = en_map[sid] + reference = zh_map[sid] + + prompt = PROMPT_TEMPLATE.format(source=source) + prompt_tokens = ll.tokenize(vocab, prompt, add_bos=False, special=True) + prompt_len = len(prompt_tokens) + + src_start, src_end = find_source_token_range(source, vocab) + num_src = src_end - src_start + + # Create fresh context for each sentence (clean KV cache) + ctx = ll.create_context(model, n_ctx=args.n_ctx, n_batch=args.n_ctx, attn_weights=True) + ll.set_attn_heads(ctx, head_layers, head_indices) + n_c = ll.n_ctx(ctx) + + t0 = time.time() + + # Prefill + ret = ll.decode_batch(ctx, prompt_tokens, output_last_only=True) + if ret != 0: + print(f" [{idx+1}] prefill failed: {ret}, skipping") + ll.free_context(ctx) + continue + + # Autoregressive 
generation with attention extraction + generated_ids = [] + alignments = [] + pos = prompt_len + + for step in range(args.max_tokens): + # Get attention for current token + attn = ll.get_attn_weights(ctx, -1, num_heads, n_c) + + if attn is not None: + # Extract attention over source token range only + n_kv = ll._lib.llama_get_attn_n_kv(ctx) + if src_start < n_kv and src_end <= n_kv: + src_attn = attn[:, src_start:src_end] # (num_heads, num_src) + + if args.strategy == "ts_weighted_vote": + aligned_pos = aggregate_ts_weighted_vote(src_attn, ts_scores) + else: + aligned_pos = aggregate_mean(src_attn) + alignments.append(aligned_pos) + else: + alignments.append(num_src - 1) + else: + alignments.append(num_src - 1) + + # Get next token + next_tok = ll.argmax_logits(ctx, -1, nv) + + if next_tok in stop_ids or next_tok < 0: + break + + generated_ids.append(next_tok) + + # Decode next token + ret = ll.decode_single(ctx, next_tok, pos, output=True) + if ret != 0: + break + pos += 1 + + gen_time = time.time() - t0 + ll.free_context(ctx) + + num_tgt = len(generated_ids) + if num_tgt == 0: + continue + + # Trim alignments to match generated tokens + alignments = alignments[:num_tgt] + + # Decode translation text + pieces = [ll.token_to_piece(vocab, tid) for tid in generated_ids] + translation = "".join(pieces) + + metrics = compute_metrics(alignments, num_src, num_tgt) + al_list.append(metrics["al"]) + cw_list.append(metrics["max_cw"]) + laal_list.append(metrics["laal"]) + ap_list.append(metrics["ap"]) + + results.append({ + "id": int(sid), + "source": source, + "reference": reference, + "translation": translation, + "num_src_tokens": num_src, + "num_tgt_tokens": num_tgt, + "al": round(metrics["al"], 3), + "max_cw": metrics["max_cw"], + "laal": round(metrics["laal"], 3), + "ap": round(metrics["ap"], 3), + "gen_time_ms": round(gen_time * 1000), + }) + + if (idx + 1) % 10 == 0: + avg_al = np.mean(al_list) + print(f" [{idx+1}/{n}] Avg AL={avg_al:.2f}, last: 
{translation[:50]}...", flush=True) + + # Summary + if not results: + print("No results!") + ll.free_model(model) + ll.cleanup() + sys.exit(1) + + al_arr = np.array(al_list) + cw_arr = np.array(cw_list) + laal_arr = np.array(laal_list) + ap_arr = np.array(ap_list) + + summary = { + "system": "TranslateGemma-4B + AlignAtt (llama.cpp)", + "language_pair": "en-zh", + "num_sentences": len(results), + "num_alignment_heads": num_heads, + "strategy": args.strategy, + "latency": { + "avg_al": round(float(np.mean(al_arr)), 3), + "median_al": round(float(np.median(al_arr)), 3), + "p90_al": round(float(np.percentile(al_arr, 90)), 3), + "avg_max_cw": round(float(np.mean(cw_arr)), 1), + "max_max_cw": int(np.max(cw_arr)), + "avg_laal": round(float(np.mean(laal_arr)), 3), + "avg_ap": round(float(np.mean(ap_arr)), 3), + }, + "avg_gen_time_ms": round(float(np.mean([r["gen_time_ms"] for r in results]))), + } + + output = {"summary": summary, "sentences": results} + os.makedirs(os.path.dirname(args.output) or ".", exist_ok=True) + with open(args.output, "w") as f: + json.dump(output, f, indent=2, ensure_ascii=False) + + print(f"\n{'='*60}") + print(f"TranslateGemma + AlignAtt (llama.cpp) — EN→ZH ({len(results)} sentences)") + print(f"{'='*60}") + print(f" Avg AL: {summary['latency']['avg_al']:.3f}") + print(f" Median AL: {summary['latency']['median_al']:.3f}") + print(f" P90 AL: {summary['latency']['p90_al']:.3f}") + print(f" Avg Max CW: {summary['latency']['avg_max_cw']:.1f}") + print(f" Max CW: {summary['latency']['max_max_cw']}") + print(f" Avg LAAL: {summary['latency']['avg_laal']:.3f}") + print(f" Avg AP: {summary['latency']['avg_ap']:.3f}") + print(f" Avg gen: {summary['avg_gen_time_ms']}ms/sentence") + print(f"\nSaved to {args.output}") + + ll.free_model(model) + ll.cleanup() + + +if __name__ == "__main__": + main() diff --git a/examples/attn-weights/llama_attn.py b/examples/attn-weights/llama_attn.py new file mode 100644 index 0000000000..dc010c5980 --- /dev/null +++ 
b/examples/attn-weights/llama_attn.py @@ -0,0 +1,387 @@ +""" +Minimal ctypes wrapper for llama.cpp attention weight extraction API. +""" + +import ctypes +import os +import sys + +# Find the shared library +_LIB_DIR = os.path.join(os.path.dirname(__file__), "..", "..", "build", "bin") +_LIB_NAMES = ["libllama.dylib", "libllama.so", "llama.dll"] + +_lib = None +for name in _LIB_NAMES: + path = os.path.join(_LIB_DIR, name) + if os.path.exists(path): + _lib = ctypes.CDLL(path) + break + +if _lib is None: + raise RuntimeError(f"Cannot find libllama in {_LIB_DIR}. Build first with: cmake --build build") + + +# --- Types --- + +class llama_model(ctypes.Structure): + pass + +class llama_context(ctypes.Structure): + pass + +class llama_vocab(ctypes.Structure): + pass + +llama_token = ctypes.c_int32 +llama_pos = ctypes.c_int32 +llama_seq_id = ctypes.c_int32 + + +class llama_batch(ctypes.Structure): + _fields_ = [ + ("n_tokens", ctypes.c_int32), + ("token", ctypes.POINTER(llama_token)), + ("embd", ctypes.POINTER(ctypes.c_float)), + ("pos", ctypes.POINTER(llama_pos)), + ("n_seq_id", ctypes.POINTER(ctypes.c_int32)), + ("seq_id", ctypes.POINTER(ctypes.POINTER(llama_seq_id))), + ("logits", ctypes.POINTER(ctypes.c_int8)), + ] + + +# Model params - we only need to get the default and possibly modify a few fields +# Since the struct is large and complex, we'll treat it as opaque bytes +# and use the C function to get defaults +class llama_model_params(ctypes.Structure): + _fields_ = [ + ("_opaque", ctypes.c_uint8 * 256), # oversized, safe + ] + + +# We need the exact layout of llama_context_params to set attn_weights and flash_attn_type +# Let's read it from the header. The key fields we need are at known positions. +# Instead of matching the exact struct, we'll use the C API defaults and patch bytes. 
+ +# Enums +LLAMA_FLASH_ATTN_TYPE_AUTO = -1 +LLAMA_FLASH_ATTN_TYPE_DISABLED = 0 +LLAMA_FLASH_ATTN_TYPE_ENABLED = 1 + + +# --- Function signatures --- + +# void llama_backend_init(void) +_lib.llama_backend_init.argtypes = [] +_lib.llama_backend_init.restype = None + +# void llama_backend_free(void) +_lib.llama_backend_free.argtypes = [] +_lib.llama_backend_free.restype = None + +# llama_model_params llama_model_default_params(void) +_lib.llama_model_default_params.argtypes = [] +_lib.llama_model_default_params.restype = llama_model_params + +# llama_model * llama_model_load_from_file(const char * path, llama_model_params params) +_lib.llama_model_load_from_file.argtypes = [ctypes.c_char_p, llama_model_params] +_lib.llama_model_load_from_file.restype = ctypes.POINTER(llama_model) + +# void llama_model_free(llama_model * model) +_lib.llama_model_free.argtypes = [ctypes.POINTER(llama_model)] +_lib.llama_model_free.restype = None + +# const llama_vocab * llama_model_get_vocab(const llama_model * model) +_lib.llama_model_get_vocab.argtypes = [ctypes.POINTER(llama_model)] +_lib.llama_model_get_vocab.restype = ctypes.POINTER(llama_vocab) + +# int32_t llama_model_n_layer(const llama_model * model) +_lib.llama_model_n_layer.argtypes = [ctypes.POINTER(llama_model)] +_lib.llama_model_n_layer.restype = ctypes.c_int32 + +# int32_t llama_model_n_head(const llama_model * model) +_lib.llama_model_n_head.argtypes = [ctypes.POINTER(llama_model)] +_lib.llama_model_n_head.restype = ctypes.c_int32 + +# int32_t llama_tokenize(const llama_vocab *, const char *, int32_t, llama_token *, int32_t, bool, bool) +_lib.llama_tokenize.argtypes = [ + ctypes.POINTER(llama_vocab), ctypes.c_char_p, ctypes.c_int32, + ctypes.POINTER(llama_token), ctypes.c_int32, ctypes.c_bool, ctypes.c_bool +] +_lib.llama_tokenize.restype = ctypes.c_int32 + +# llama_batch llama_batch_init(int32_t n_tokens, int32_t embd, int32_t n_seq_max) +_lib.llama_batch_init.argtypes = [ctypes.c_int32, ctypes.c_int32, ctypes.c_int32] 
+_lib.llama_batch_init.restype = llama_batch + +# void llama_batch_free(llama_batch batch) +_lib.llama_batch_free.argtypes = [llama_batch] +_lib.llama_batch_free.restype = None + +# int32_t llama_decode(llama_context * ctx, llama_batch batch) +_lib.llama_decode.argtypes = [ctypes.POINTER(llama_context), llama_batch] +_lib.llama_decode.restype = ctypes.c_int32 + +# void llama_free(llama_context * ctx) +_lib.llama_free.argtypes = [ctypes.POINTER(llama_context)] +_lib.llama_free.restype = None + +# void llama_set_attn_heads(llama_context *, const int32_t * layers, const int32_t * heads, size_t n_pairs) +_lib.llama_set_attn_heads.argtypes = [ + ctypes.POINTER(llama_context), + ctypes.POINTER(ctypes.c_int32), ctypes.POINTER(ctypes.c_int32), ctypes.c_size_t +] +_lib.llama_set_attn_heads.restype = None + +# float * llama_get_attn_ith(llama_context * ctx, int32_t i) +_lib.llama_get_attn_ith.argtypes = [ctypes.POINTER(llama_context), ctypes.c_int32] +_lib.llama_get_attn_ith.restype = ctypes.POINTER(ctypes.c_float) + +# int32_t llama_get_attn_n_kv(llama_context * ctx) +_lib.llama_get_attn_n_kv.argtypes = [ctypes.POINTER(llama_context)] +_lib.llama_get_attn_n_kv.restype = ctypes.c_int32 + +# float * llama_get_logits_ith(llama_context * ctx, int32_t i) +_lib.llama_get_logits_ith.argtypes = [ctypes.POINTER(llama_context), ctypes.c_int32] +_lib.llama_get_logits_ith.restype = ctypes.POINTER(ctypes.c_float) + +# uint32_t llama_n_ctx(const llama_context * ctx) +_lib.llama_n_ctx.argtypes = [ctypes.POINTER(llama_context)] +_lib.llama_n_ctx.restype = ctypes.c_uint32 + +# void llama_synchronize(llama_context * ctx) +_lib.llama_synchronize.argtypes = [ctypes.POINTER(llama_context)] +_lib.llama_synchronize.restype = None + +# int32_t llama_vocab_n_tokens(const llama_vocab * vocab) +_lib.llama_vocab_n_tokens.argtypes = [ctypes.POINTER(llama_vocab)] +_lib.llama_vocab_n_tokens.restype = ctypes.c_int32 + +# llama_token llama_vocab_bos(const llama_vocab * vocab) +_lib.llama_vocab_bos.argtypes 
= [ctypes.POINTER(llama_vocab)] +_lib.llama_vocab_bos.restype = llama_token + +# llama_token llama_vocab_eos(const llama_vocab * vocab) +_lib.llama_vocab_eos.argtypes = [ctypes.POINTER(llama_vocab)] +_lib.llama_vocab_eos.restype = llama_token + +# int32_t llama_token_to_piece(const llama_vocab *, llama_token, char *, int32_t, int32_t, bool) +_lib.llama_token_to_piece.argtypes = [ + ctypes.POINTER(llama_vocab), llama_token, + ctypes.POINTER(ctypes.c_char), ctypes.c_int32, ctypes.c_int32, ctypes.c_bool +] +_lib.llama_token_to_piece.restype = ctypes.c_int32 + + +# --- Context creation --- +# Since llama_context_params is complex and may change layout between versions, +# we use a helper approach: call llama_context_default_params from C, then patch +# the specific fields we need. + +# We need to know the struct size and field offsets. Let's use a small C helper. +# Actually, let's just build a properly aligned struct by reading the header. +# The key insight: we can create context via a C helper function. + +# For now, let's use ctypes.c_uint8 array as an opaque blob and set fields at +# known byte offsets. This is fragile but works for our specific build. + +def _create_context(model_ptr, n_ctx=512, n_batch=512, attn_weights=True, n_gpu_layers=0): + """Create a llama_context with attention weights enabled. + + Uses a small C shim compiled on-the-fly to avoid struct layout issues. + """ + import tempfile, subprocess + + shim_src = r""" +#include "llama.h" +#include + +// Export a function that creates a context with the right params +__attribute__((visibility("default"))) +struct llama_context * create_ctx_with_attn( + struct llama_model * model, + int n_ctx, int n_batch, int attn_weights, int n_gpu_layers) { + struct llama_context_params params = llama_context_default_params(); + params.n_ctx = n_ctx; + params.n_batch = n_batch; + params.n_ubatch = n_batch; + params.attn_weights = attn_weights ? 
true : false; + params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_DISABLED; + params.offload_kqv = n_gpu_layers > 0; + return llama_init_from_model(model, params); +} +""" + llama_dir = os.path.join(os.path.dirname(__file__), "..", "..") + include_dir = os.path.join(llama_dir, "include") + lib_dir = os.path.join(llama_dir, "build", "bin") + ggml_include = os.path.join(llama_dir, "ggml", "include") + + with tempfile.NamedTemporaryFile(suffix=".c", mode="w", delete=False) as f: + f.write(shim_src) + src_path = f.name + + shim_lib = os.path.join(lib_dir, "libllama_attn_shim.dylib") + if sys.platform == "linux": + shim_lib = os.path.join(lib_dir, "libllama_attn_shim.so") + + cmd = [ + "cc", "-shared", "-fPIC", "-o", shim_lib, src_path, + f"-I{include_dir}", f"-I{ggml_include}", + f"-L{lib_dir}", "-lllama", + f"-Wl,-rpath,{lib_dir}", + ] + result = subprocess.run(cmd, capture_output=True, text=True) + os.unlink(src_path) + if result.returncode != 0: + raise RuntimeError(f"Failed to compile shim: {result.stderr}") + + shim = ctypes.CDLL(shim_lib) + shim.create_ctx_with_attn.argtypes = [ + ctypes.POINTER(llama_model), ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int + ] + shim.create_ctx_with_attn.restype = ctypes.POINTER(llama_context) + + ctx = shim.create_ctx_with_attn(model_ptr, n_ctx, n_batch, 1 if attn_weights else 0, n_gpu_layers) + if not ctx: + raise RuntimeError("Failed to create llama_context") + return ctx + + +# --- High-level helpers --- + +def tokenize(vocab_ptr, text, add_bos=True, special=True): + """Tokenize text, returning a list of token ids.""" + text_bytes = text.encode("utf-8") + buf = (llama_token * (len(text_bytes) + 32))() + n = _lib.llama_tokenize(vocab_ptr, text_bytes, len(text_bytes), buf, len(buf), add_bos, special) + if n < 0: + buf = (llama_token * (-n))() + n = _lib.llama_tokenize(vocab_ptr, text_bytes, len(text_bytes), buf, len(buf), add_bos, special) + return list(buf[:n]) + + +def decode_batch(ctx_ptr, tokens, output_last_only=True): 
+ """Decode a batch of tokens. Returns llama_decode return code.""" + n = len(tokens) + batch = _lib.llama_batch_init(n, 0, 1) + + for i in range(n): + batch.token[i] = tokens[i] + batch.pos[i] = i + batch.n_seq_id[i] = 1 + # Write seq_id value into the pre-allocated buffer (don't replace the pointer) + batch.seq_id[i][0] = 0 + batch.logits[i] = 1 if (not output_last_only or i == n - 1) else 0 + + batch.n_tokens = n + ret = _lib.llama_decode(ctx_ptr, batch) + _lib.llama_batch_free(batch) + return ret + + +def decode_single(ctx_ptr, token, pos, output=True): + """Decode a single token at a given position.""" + batch = _lib.llama_batch_init(1, 0, 1) + batch.token[0] = token + batch.pos[0] = pos + batch.n_seq_id[0] = 1 + batch.seq_id[0][0] = 0 # Write value into pre-allocated buffer + batch.logits[0] = 1 if output else 0 + batch.n_tokens = 1 + ret = _lib.llama_decode(ctx_ptr, batch) + _lib.llama_batch_free(batch) + return ret + + +def get_attn_weights(ctx_ptr, token_idx, n_pairs, n_ctx): + """Get attention weights for a given output token index. + + Returns numpy array of shape (n_pairs, n_kv) or None. 
+ """ + import numpy as np + + ptr = _lib.llama_get_attn_ith(ctx_ptr, token_idx) + if not ptr: + return None + + n_kv = _lib.llama_get_attn_n_kv(ctx_ptr) + if n_kv <= 0: + return None + + # Layout: [n_pairs * n_ctx] floats, each pair has n_ctx floats, first n_kv valid + result = np.zeros((n_pairs, n_kv), dtype=np.float32) + for p in range(n_pairs): + offset = p * n_ctx + arr = (ctypes.c_float * n_kv).from_address(ctypes.addressof(ptr.contents) + offset * 4) + result[p] = np.frombuffer(arr, dtype=np.float32) + + return result + + +def argmax_logits(ctx_ptr, token_idx, n_vocab): + """Get the argmax of logits for a given output token.""" + ptr = _lib.llama_get_logits_ith(ctx_ptr, token_idx) + if not ptr: + return -1 + logits = (ctypes.c_float * n_vocab).from_address(ctypes.addressof(ptr.contents)) + import numpy as np + return int(np.argmax(np.frombuffer(logits, dtype=np.float32))) + + +# --- Public API --- + +def init(): + _lib.llama_backend_init() + +def cleanup(): + _lib.llama_backend_free() + +def load_model(path, n_gpu_layers=0): + params = _lib.llama_model_default_params() + # n_gpu_layers is at offset 0 in llama_model_params (first field) + # Actually let's just use default params for simplicity + model = _lib.llama_model_load_from_file(path.encode(), params) + if not model: + raise RuntimeError(f"Failed to load model from {path}") + return model + +def create_context(model, n_ctx=512, n_batch=512, attn_weights=True): + return _create_context(model, n_ctx, n_batch, attn_weights) + +def set_attn_heads(ctx, layers, heads): + n = len(layers) + assert len(heads) == n + l_arr = (ctypes.c_int32 * n)(*layers) + h_arr = (ctypes.c_int32 * n)(*heads) + _lib.llama_set_attn_heads(ctx, l_arr, h_arr, n) + +def get_vocab(model): + return _lib.llama_model_get_vocab(model) + +def n_layer(model): + return _lib.llama_model_n_layer(model) + +def n_head(model): + return _lib.llama_model_n_head(model) + +def n_vocab(vocab): + return _lib.llama_vocab_n_tokens(vocab) + +def 
n_ctx(ctx): + return _lib.llama_n_ctx(ctx) + +def vocab_eos(vocab): + return _lib.llama_vocab_eos(vocab) + +def free_context(ctx): + _lib.llama_free(ctx) + +def token_to_piece(vocab, token_id, special=True): + """Convert a single token ID to its string piece.""" + buf = (ctypes.c_char * 256)() + n = _lib.llama_token_to_piece(vocab, token_id, buf, 256, 0, special) + if n > 0: + return buf[:n].decode("utf-8", errors="replace") + return "" + +def free_model(model): + _lib.llama_model_free(model) diff --git a/examples/attn-weights/test_attn_weights.py b/examples/attn-weights/test_attn_weights.py new file mode 100644 index 0000000000..99257ccea1 --- /dev/null +++ b/examples/attn-weights/test_attn_weights.py @@ -0,0 +1,249 @@ +#!/usr/bin/env python3 +""" +Test: verify llama.cpp attention weight extraction works correctly. + +Usage: + python3 benchmark/test_attn_weights.py + +Tests: + 1. Basic extraction: attention weights are non-null and sum to ~1.0 + 2. Multi-head: multiple (layer, head) pairs return independent weights + 3. Greedy generation: attention is extracted at each autoregressive step + 4. 
Cross-validation with known properties (monotonicity, sparsity) +""" + +import sys +import os +import numpy as np + +sys.path.insert(0, os.path.dirname(__file__)) +import llama_attn as ll + + +def test_basic(model, vocab, ctx, n_layers): + """Test 1: basic attention extraction on a prompt.""" + print("=" * 60) + print("TEST 1: Basic attention extraction") + print("=" * 60) + + tokens = ll.tokenize(vocab, "The quick brown fox jumps over the lazy dog") + n_tokens = len(tokens) + print(f" Tokens: {n_tokens}") + + ret = ll.decode_batch(ctx, tokens, output_last_only=True) + assert ret == 0, f"decode failed: {ret}" + + n_c = ll.n_ctx(ctx) + attn = ll.get_attn_weights(ctx, -1, 1, n_c) + assert attn is not None, "get_attn_weights returned None" + + n_kv = ll._lib.llama_get_attn_n_kv(ctx) + print(f" n_kv: {n_kv}") + print(f" Attention shape: {attn.shape}") + print(f" Attention sum: {attn[0].sum():.6f}") + print(f" Attention max: {attn[0].max():.6f} at position {attn[0].argmax()}") + print(f" Attention min: {attn[0].min():.6f}") + + # Softmax output should sum to ~1.0 + assert abs(attn[0].sum() - 1.0) < 0.05, f"Attention doesn't sum to 1.0: {attn[0].sum()}" + # All values should be non-negative + assert (attn[0] >= 0).all(), "Negative attention values found" + + print(" PASSED\n") + return True + + +def test_multi_head(model, vocab, ctx, n_layers): + """Test 2: multiple (layer, head) pairs.""" + print("=" * 60) + print("TEST 2: Multi-head attention extraction") + print("=" * 60) + + # Set multiple heads across different layers + layers = [0, n_layers // 2, n_layers - 1] + heads = [0, 0, 0] + n_pairs = len(layers) + ll.set_attn_heads(ctx, layers, heads) + print(f" Configured {n_pairs} heads: {list(zip(layers, heads))}") + + tokens = ll.tokenize(vocab, "Hello world, this is a test of attention") + ret = ll.decode_batch(ctx, tokens, output_last_only=True) + assert ret == 0, f"decode failed: {ret}" + + n_c = ll.n_ctx(ctx) + attn = ll.get_attn_weights(ctx, -1, n_pairs, n_c) + 
assert attn is not None, "get_attn_weights returned None" + + print(f" Attention shape: {attn.shape}") + for p in range(n_pairs): + s = attn[p].sum() + print(f" Pair {p} (L{layers[p]},H{heads[p]}): sum={s:.6f}, max={attn[p].max():.4f} @ pos {attn[p].argmax()}") + assert abs(s - 1.0) < 0.05, f"Pair {p} doesn't sum to 1.0: {s}" + + # Different layers should produce different attention patterns + if n_pairs >= 2: + diff = np.abs(attn[0] - attn[-1]).mean() + print(f" Mean abs difference between first and last layer: {diff:.6f}") + # They should not be identical (unless the model is degenerate) + # Don't assert this as a hard requirement + + # Reset to default (last layer, head 0) + ll.set_attn_heads(ctx, [n_layers - 1], [0]) + + print(" PASSED\n") + return True + + +def test_generation(model, vocab, ctx, n_layers): + """Test 3: attention during autoregressive generation.""" + print("=" * 60) + print("TEST 3: Autoregressive generation with attention") + print("=" * 60) + + tokens = ll.tokenize(vocab, "Once upon a time") + n_prompt = len(tokens) + print(f" Prompt: {n_prompt} tokens") + + # Prefill + ret = ll.decode_batch(ctx, tokens, output_last_only=True) + assert ret == 0, f"prefill decode failed: {ret}" + + n_c = ll.n_ctx(ctx) + nv = ll.n_vocab(vocab) + eos = ll.vocab_eos(vocab) + + max_gen = 10 + gen_tokens = [] + attn_sums = [] + + for step in range(max_gen): + # Get attention for current token + attn = ll.get_attn_weights(ctx, -1, 1, n_c) + assert attn is not None, f"Step {step}: attention is None" + + n_kv = ll._lib.llama_get_attn_n_kv(ctx) + s = attn[0].sum() + attn_sums.append(s) + + # Get next token (greedy) + next_tok = ll.argmax_logits(ctx, -1, nv) + if next_tok == eos: + print(f" Step {step}: EOS") + break + + gen_tokens.append(next_tok) + + # Decode next token + pos = n_prompt + step + ret = ll.decode_single(ctx, next_tok, pos, output=True) + assert ret == 0, f"Step {step}: decode failed: {ret}" + + print(f" Generated {len(gen_tokens)} tokens") + print(f" 
Attention sums: {[f'{s:.4f}' for s in attn_sums]}") + + for i, s in enumerate(attn_sums): + assert abs(s - 1.0) < 0.05, f"Step {i}: attention sum = {s}" + + print(" PASSED\n") + return True + + +def test_multiple_heads_same_layer(model, vocab, ctx, n_layers): + """Test 4: multiple heads from the same layer.""" + print("=" * 60) + print("TEST 4: Multiple heads from same layer") + print("=" * 60) + + n_h = ll.n_head(model) + last_layer = n_layers - 1 + n_test_heads = min(4, n_h) + + layers = [last_layer] * n_test_heads + heads = list(range(n_test_heads)) + ll.set_attn_heads(ctx, layers, heads) + print(f" Layer {last_layer}, heads {heads}") + + tokens = ll.tokenize(vocab, "Attention is all you need") + ret = ll.decode_batch(ctx, tokens, output_last_only=True) + assert ret == 0, f"decode failed: {ret}" + + n_c = ll.n_ctx(ctx) + attn = ll.get_attn_weights(ctx, -1, n_test_heads, n_c) + assert attn is not None, "get_attn_weights returned None" + + print(f" Attention shape: {attn.shape}") + for h in range(n_test_heads): + s = attn[h].sum() + peak = attn[h].argmax() + print(f" Head {h}: sum={s:.6f}, peak @ pos {peak}, max={attn[h].max():.4f}") + assert abs(s - 1.0) < 0.05, f"Head {h}: sum = {s}" + + # Different heads should show at least some variation + if n_test_heads >= 2: + patterns_identical = all( + np.allclose(attn[0], attn[h], atol=1e-5) + for h in range(1, n_test_heads) + ) + if patterns_identical: + print(" WARNING: all heads have identical attention patterns") + else: + print(" OK: heads show different patterns") + + # Reset + ll.set_attn_heads(ctx, [n_layers - 1], [0]) + + print(" PASSED\n") + return True + + +def main(): + if len(sys.argv) < 2: + print(f"Usage: {sys.argv[0]} ") + sys.exit(1) + + model_path = sys.argv[1] + n_ctx = 512 + if len(sys.argv) > 2: + n_ctx = int(sys.argv[2]) + + print(f"Model: {model_path}") + print(f"n_ctx: {n_ctx}\n") + + ll.init() + + model = ll.load_model(model_path) + vocab = ll.get_vocab(model) + n_layers = ll.n_layer(model) + 
n_heads = ll.n_head(model) + nv = ll.n_vocab(vocab) + print(f"Loaded: {n_layers} layers, {n_heads} heads, vocab={nv}\n") + + passed = 0 + failed = 0 + + for test_fn in [test_basic, test_multi_head, test_generation, test_multiple_heads_same_layer]: + # Create fresh context for each test + ctx = ll.create_context(model, n_ctx=n_ctx, n_batch=n_ctx, attn_weights=True) + try: + if test_fn(model, vocab, ctx, n_layers): + passed += 1 + else: + failed += 1 + except Exception as e: + print(f" FAILED: {e}\n") + failed += 1 + finally: + ll.free_context(ctx) + + print("=" * 60) + print(f"Results: {passed} passed, {failed} failed") + print("=" * 60) + + ll.free_model(model) + ll.cleanup() + + sys.exit(0 if failed == 0 else 1) + + +if __name__ == "__main__": + main() diff --git a/examples/attn-weights/translation_heads_en_zh.json b/examples/attn-weights/translation_heads_en_zh.json new file mode 100644 index 0000000000..f02f628768 --- /dev/null +++ b/examples/attn-weights/translation_heads_en_zh.json @@ -0,0 +1,558 @@ +{ + "model": "TranslateGemma-4B", + "language_pair": "en-zh", + "num_layers": 34, + "num_heads": 8, + "num_sentences": 200, + "total_alignable_tokens": 5283, + "ts_threshold": 0.1, + "ts_matrix": [ + [ + 0.0015142911224682945, + 0.017982207079311, + 0.029528676888131742, + 0.028203672155971984, + 0.0009464319515426841, + 0.0018928639030853683, + 0.0, + 0.0 + ], + [ + 0.0, + 0.0, + 0.0028392958546280523, + 0.0001892863903085368, + 0.0003785727806170736, + 0.008517887563884156, + 0.0, + 0.0 + ], + [ + 0.0, + 0.0, + 0.0, + 0.0, + 0.0017035775127768314, + 0.0, + 0.0, + 0.05148589816392202 + ], + [ + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0001892863903085368, + 0.0, + 0.0 + ], + [ + 0.045618020064357376, + 0.007003596441415862, + 0.10240393715691842, + 0.02138936210486466, + 0.0, + 0.0, + 0.0, + 0.0 + ], + [ + 0.20632216543630513, + 0.17527919742570508, + 0.21370433465833807, + 0.02574294908196101, + 0.0, + 0.0005678591709256105, + 0.0001892863903085368, + 
0.03388226386522809 + ], + [ + 0.11073253833049404, + 0.0028392958546280523, + 0.05830020821502934, + 0.0308536816202915, + 0.10618966496308915, + 0.0001892863903085368, + 0.002271436683702442, + 0.0003785727806170736 + ], + [ + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0018928639030853683, + 0.003217868635245126, + 0.11186825667234526 + ], + [ + 0.0013250047321597578, + 0.0308536816202915, + 0.0, + 0.0, + 0.03615370054893053, + 0.06719666855953058, + 0.0, + 0.0 + ], + [ + 0.1378004921446148, + 0.0035964414158621994, + 0.13250047321597577, + 0.24626159379140639, + 0.008707173954192694, + 0.039371569184175656, + 0.003975014196479273, + 0.13325761877720993 + ], + [ + 0.0, + 0.0, + 0.0001892863903085368, + 0.0005678591709256105, + 0.0, + 0.0, + 0.0007571455612341472, + 0.0017035775127768314 + ], + [ + 0.0, + 0.06643952299829642, + 0.26500094643195155, + 0.17565777020632217, + 0.011167897028203672, + 0.3577512776831346, + 0.11451826613666477, + 0.07666098807495741 + ], + [ + 0.012114328979746356, + 0.24096157486276737, + 0.0001892863903085368, + 0.0, + 0.004542873367404884, + 0.0056785917092561046, + 0.05848949460533787, + 0.00946431951542684 + ], + [ + 0.0, + 0.0, + 0.167139882642438, + 0.0, + 0.1419647927314026, + 0.021010789324247586, + 0.0026500094643195156, + 0.0009464319515426841 + ], + [ + 0.030285822449365892, + 0.0009464319515426841, + 0.18076850274465267, + 0.24285443876585275, + 0.002271436683702442, + 0.0, + 0.0001892863903085368, + 0.0 + ], + [ + 0.0009464319515426841, + 0.0001892863903085368, + 0.0, + 0.0, + 0.1198182850653038, + 0.049025175089911034, + 0.06814310051107325, + 0.1692220329358319 + ], + [ + 0.0, + 0.07760742002650009, + 0.0, + 0.19307211811470756, + 0.010410751466969525, + 0.017982207079311, + 0.0, + 0.11792542116221844 + ], + [ + 0.2824152943403369, + 0.13003975014196478, + 0.15010410751466968, + 0.3244368729888321, + 0.14328979746356238, + 0.37100132500473215, + 0.0, + 0.0 + ], + [ + 0.0, + 0.049403747870528106, + 0.05318947567669884, + 
0.0001892863903085368, + 0.0, + 0.0, + 0.0, + 0.0 + ], + [ + 0.0, + 0.2040507287526027, + 0.0, + 0.0, + 0.0308536816202915, + 0.27900813931478324, + 0.0001892863903085368, + 0.07230740109786106 + ], + [ + 0.012114328979746356, + 0.26537951921256864, + 0.0, + 0.009842892296043914, + 0.0, + 0.0, + 0.0, + 0.0068143100511073255 + ], + [ + 0.3876585273518834, + 0.019875070982396367, + 0.0, + 0.0, + 0.0, + 0.0, + 0.12114328979746357, + 0.0009464319515426841 + ], + [ + 0.0, + 0.0, + 0.0001892863903085368, + 0.0, + 0.0, + 0.19250425894378195, + 0.13496119628998676, + 0.19269354533409047 + ], + [ + 0.030096536059057353, + 0.038425137232632973, + 0.004921446148021957, + 0.3140261215218626, + 0.0, + 0.0, + 0.3736513344690517, + 0.24550444823017226 + ], + [ + 0.0001892863903085368, + 0.08006814310051108, + 0.010221465076660987, + 0.0017035775127768314, + 0.0848003028582245, + 0.0, + 0.0013250047321597578, + 0.0 + ], + [ + 0.0015142911224682945, + 0.0, + 0.0, + 0.0, + 0.0017035775127768314, + 0.0, + 0.0, + 0.0 + ], + [ + 0.024039371569184176, + 0.012682188150671967, + 0.03180011357183418, + 0.006246450880181715, + 0.014007192882831724, + 0.049593034260836645, + 0.0, + 0.0 + ], + [ + 0.0, + 0.007571455612341473, + 0.0, + 0.0, + 0.0, + 0.0, + 0.07363240583002083, + 0.0 + ], + [ + 0.00473215975771342, + 0.022146507666098807, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0 + ], + [ + 0.0003785727806170736, + 0.0003785727806170736, + 0.0005678591709256105, + 0.0005678591709256105, + 0.002082150293393905, + 0.003028582244936589, + 0.0, + 0.0 + ], + [ + 0.0013250047321597578, + 0.0003785727806170736, + 0.005867878099564641, + 0.01722506151807685, + 0.0003785727806170736, + 0.0003785727806170736, + 0.0026500094643195156, + 0.004542873367404884 + ], + [ + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.011357183418512209 + ], + [ + 0.0, + 0.025175089911035398, + 0.0, + 0.0001892863903085368, + 0.0, + 0.0, + 0.0, + 0.0 + ], + [ + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0 + ] + ], 
+ "token_alignment_heads": [ + { + "layer": 21, + "head": 0, + "ts": 0.3877 + }, + { + "layer": 23, + "head": 6, + "ts": 0.3737 + }, + { + "layer": 17, + "head": 5, + "ts": 0.371 + }, + { + "layer": 11, + "head": 5, + "ts": 0.3578 + }, + { + "layer": 17, + "head": 3, + "ts": 0.3244 + }, + { + "layer": 23, + "head": 3, + "ts": 0.314 + }, + { + "layer": 17, + "head": 0, + "ts": 0.2824 + }, + { + "layer": 19, + "head": 5, + "ts": 0.279 + }, + { + "layer": 20, + "head": 1, + "ts": 0.2654 + }, + { + "layer": 11, + "head": 2, + "ts": 0.265 + }, + { + "layer": 9, + "head": 3, + "ts": 0.2463 + }, + { + "layer": 23, + "head": 7, + "ts": 0.2455 + }, + { + "layer": 14, + "head": 3, + "ts": 0.2429 + }, + { + "layer": 12, + "head": 1, + "ts": 0.241 + }, + { + "layer": 5, + "head": 2, + "ts": 0.2137 + }, + { + "layer": 5, + "head": 0, + "ts": 0.2063 + }, + { + "layer": 19, + "head": 1, + "ts": 0.2041 + }, + { + "layer": 16, + "head": 3, + "ts": 0.1931 + }, + { + "layer": 22, + "head": 7, + "ts": 0.1927 + }, + { + "layer": 22, + "head": 5, + "ts": 0.1925 + }, + { + "layer": 14, + "head": 2, + "ts": 0.1808 + }, + { + "layer": 11, + "head": 3, + "ts": 0.1757 + }, + { + "layer": 5, + "head": 1, + "ts": 0.1753 + }, + { + "layer": 15, + "head": 7, + "ts": 0.1692 + }, + { + "layer": 13, + "head": 2, + "ts": 0.1671 + }, + { + "layer": 17, + "head": 2, + "ts": 0.1501 + }, + { + "layer": 17, + "head": 4, + "ts": 0.1433 + }, + { + "layer": 13, + "head": 4, + "ts": 0.142 + }, + { + "layer": 9, + "head": 0, + "ts": 0.1378 + }, + { + "layer": 22, + "head": 6, + "ts": 0.135 + }, + { + "layer": 9, + "head": 7, + "ts": 0.1333 + }, + { + "layer": 9, + "head": 2, + "ts": 0.1325 + }, + { + "layer": 17, + "head": 1, + "ts": 0.13 + }, + { + "layer": 21, + "head": 6, + "ts": 0.1211 + }, + { + "layer": 15, + "head": 4, + "ts": 0.1198 + }, + { + "layer": 16, + "head": 7, + "ts": 0.1179 + }, + { + "layer": 11, + "head": 6, + "ts": 0.1145 + }, + { + "layer": 7, + "head": 7, + "ts": 0.1119 + }, + { + 
"layer": 6, + "head": 0, + "ts": 0.1107 + }, + { + "layer": 6, + "head": 4, + "ts": 0.1062 + }, + { + "layer": 4, + "head": 2, + "ts": 0.1024 + } + ] +} \ No newline at end of file diff --git a/include/llama.h b/include/llama.h index a84d56a885..ce9442de8f 100644 --- a/include/llama.h +++ b/include/llama.h @@ -370,6 +370,7 @@ extern "C" { bool kv_unified; // use a unified buffer across the input sequences when computing the attention // try to disable when n_seq_max > 1 for improved performance when the sequences do not share a large prefix // ref: https://github.com/ggml-org/llama.cpp/pull/14363 + bool attn_weights; // if true, extract attention weights (requires flash_attn disabled) // [EXPERIMENTAL] // backend sampler chain configuration (make sure the caller keeps the sampler chains alive) @@ -999,6 +1000,31 @@ extern "C" { // otherwise: float[n_embd] (1-dimensional) LLAMA_API float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id); + // + // attention weights API [EXPERIMENTAL] + // requires llama_context_params.attn_weights = true and flash_attn disabled + // + + // Configure which (layer, head) pairs to extract attention weights from. + // By default (if never called), extracts head 0 of the last layer. + // layers and heads are parallel arrays of length n_pairs. + LLAMA_API void llama_set_attn_heads( + struct llama_context * ctx, + const int32_t * layers, + const int32_t * heads, + size_t n_pairs); + + // Get attention weights for the ith output token. + // Returns a pointer to [n_pairs * n_ctx] floats where n_ctx = llama_n_ctx(ctx). + // Each pair occupies n_ctx floats; only the first llama_get_attn_n_kv() entries per pair are valid. + // Pairs are in the order set by llama_set_attn_heads: pair p starts at offset p * n_ctx. + // Returns NULL if attention extraction is not enabled or index is invalid. 
+ LLAMA_API float * llama_get_attn_ith(struct llama_context * ctx, int32_t i); + + // Get the number of valid KV entries per (layer, head) pair in the attention weight arrays. + // Use this to know how many entries to read from each n_ctx-sized chunk. + LLAMA_API int32_t llama_get_attn_n_kv(struct llama_context * ctx); + // // backend sampling API [EXPERIMENTAL] // note: use only if the llama_context was created with at least one llama_sampler_seq_config diff --git a/src/llama-context.cpp b/src/llama-context.cpp index eee9021296..9a16c28c37 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -158,6 +158,17 @@ llama_context::llama_context( cparams.op_offload = params.op_offload; cparams.kv_unified = params.kv_unified; + cparams.attn_weights = params.attn_weights; + if (cparams.attn_weights && cparams.flash_attn) { + throw std::runtime_error("attention weight extraction requires flash_attn to be disabled"); + } + if (cparams.attn_weights) { + // default: head 0 of the last layer + const int32_t default_layer = (int32_t)(hparams.n_layer - 1); + attn_heads.push_back({default_layer, 0}); + cparams.attn_layers.insert(default_layer); + } + // initialized later cparams.pipeline_parallel = false; @@ -771,6 +782,52 @@ float * llama_context::get_embeddings_seq(llama_seq_id seq_id) { return it->second.data(); } +void llama_context::set_attn_heads(const int32_t * layers, const int32_t * heads, size_t n_pairs) { + const auto & hparams = model.hparams; + + attn_heads.clear(); + attn_heads.reserve(n_pairs); + cparams.attn_layers.clear(); + + for (size_t i = 0; i < n_pairs; i++) { + if (layers[i] < 0 || layers[i] >= (int32_t) hparams.n_layer) { + LLAMA_LOG_ERROR("%s: invalid layer index %d (model has %u layers), skipping\n", + __func__, layers[i], hparams.n_layer); + continue; + } + if (heads[i] < 0 || heads[i] >= (int32_t) hparams.n_head(layers[i])) { + LLAMA_LOG_ERROR("%s: invalid head index %d for layer %d (layer has %u heads), skipping\n", + __func__, heads[i], 
layers[i], hparams.n_head(layers[i])); + continue; + } + attn_heads.push_back({layers[i], heads[i]}); + cparams.attn_layers.insert(layers[i]); + } + + sched_need_reserve = true; +} + +float * llama_context::get_attn_ith(int32_t i) { + output_reorder(); + + try { + if (attn.data == nullptr) { + return nullptr; + } + + const int64_t j = output_resolve_row(i); + const size_t attn_stride = attn_heads.size() * cparams.n_ctx; + return attn.data + j * attn_stride; + } catch (const std::exception & err) { + LLAMA_LOG_ERROR("%s: invalid attention id %d, reason: %s\n", __func__, i, err.what()); + return nullptr; + } +} + +int32_t llama_context::get_attn_n_kv() const { + return attn_n_kv; +} + llama_token llama_context::get_sampled_token_ith(int32_t idx) { output_reorder(); @@ -1693,6 +1750,59 @@ int llama_context::decode(const llama_batch & batch_inp) { } } + // extract attention weights + if (attn.data && cparams.attn_weights && n_outputs > 0 && !attn_heads.empty()) { + const size_t n_pairs = attn_heads.size(); + const size_t attn_stride = n_pairs * cparams.n_ctx; // stride per output token in the attn buffer + + // iterate over requested (layer, head) pairs + for (size_t p = 0; p < n_pairs; p++) { + const int layer = attn_heads[p].layer; + const int head = attn_heads[p].head; + + auto it = res->t_attn.find(layer); + if (it == res->t_attn.end() || it->second == nullptr) { + continue; + } + + ggml_tensor * t = it->second; + ggml_backend_t backend_attn = ggml_backend_sched_get_tensor_backend(sched.get(), t); + GGML_ASSERT(backend_attn != nullptr); + + // t shape: [n_kv, n_tokens_in_stream, n_head, n_stream] + const int64_t t_n_kv = t->ne[0]; + const int64_t t_n_tokens_in_stream = t->ne[1]; + const int64_t t_n_head = t->ne[2]; + const int64_t t_n_stream = t->ne[3]; + + attn_n_kv = (int32_t) t_n_kv; + + GGML_ASSERT(head < t_n_head); + + // extract attention for each output token + int32_t out_idx = 0; + for (uint32_t i = 0; i < ubatch.n_tokens; i++) { + if (!ubatch.output[i]) { 
+ continue; + } + + // compute token position within stream + const int64_t stream = t_n_stream > 1 ? (int64_t) i / t_n_tokens_in_stream : 0; + const int64_t t_in_str = t_n_stream > 1 ? (int64_t) i % t_n_tokens_in_stream : (int64_t) i; + + // byte offset into the tensor for this (token, head, stream) + const size_t src_offset = ((stream * t_n_head + head) * t_n_tokens_in_stream + t_in_str) * t_n_kv * sizeof(float); + + // destination in the output buffer + float * dst = attn.data + (n_outputs_prev + out_idx) * attn_stride + p * cparams.n_ctx; + + ggml_backend_tensor_get_async(backend_attn, t, dst, src_offset, t_n_kv * sizeof(float)); + + out_idx++; + } + } + } + // Copy backend sampling output if this ubatch produced any sampling tensors. if (has_samplers && (!res->t_sampled.empty() || !res->t_sampled_probs.empty() || !res->t_sampled_logits.empty())) { const auto seq_to_output_row = build_seq_to_output_row(ubatch, n_outputs_prev); @@ -1795,6 +1905,10 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { logits.size = has_logits ? n_vocab*n_outputs_max : 0; embd.size = has_embd ? n_embd_out*n_outputs_max : 0; + const bool has_attn = cparams.attn_weights && !attn_heads.empty(); + const size_t n_attn_pairs = attn_heads.size(); + attn.size = has_attn ? n_attn_pairs * cparams.n_ctx * n_outputs_max : 0; + // Allocate backend sampling output buffers if there are backend samplers configured. const bool has_sampling = !sampling.samplers.empty(); if (has_sampling) { @@ -1809,8 +1923,8 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { const size_t prev_size = buf_output ? 
ggml_backend_buffer_get_size(buf_output.get()) : 0; const size_t new_size = - (logits.size + embd.size + backend_float_count) * sizeof(float) + - ( backend_token_count) * sizeof(llama_token); + (logits.size + embd.size + attn.size + backend_float_count) * sizeof(float) + + ( backend_token_count) * sizeof(llama_token); // alloc only when more than the current capacity is required // TODO: also consider shrinking the buffer @@ -1826,6 +1940,7 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { buf_output = nullptr; logits.data = nullptr; embd.data = nullptr; + attn.data = nullptr; } auto * buft = ggml_backend_cpu_buffer_type(); @@ -1853,6 +1968,9 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { embd = has_embd ? buffer_view{(float *) (base + offset), embd.size} : buffer_view{nullptr, 0}; offset += embd.size * sizeof(float); + attn = has_attn ? buffer_view{(float *) (base + offset), attn.size} : buffer_view{nullptr, 0}; + offset += attn.size * sizeof(float); + if (has_sampling) { sampling.logits = {(float *) (base + offset), (size_t)(n_vocab*n_outputs_max)}; offset += sampling.logits.size * sizeof(float); @@ -1917,6 +2035,13 @@ void llama_context::output_reorder() { } } + if (attn.size > 0) { + const uint64_t attn_stride = attn_heads.size() * cparams.n_ctx; + for (uint64_t k = 0; k < attn_stride; k++) { + std::swap(attn.data[i0*attn_stride + k], attn.data[i1*attn_stride + k]); + } + } + if (!sampling.samplers.empty()) { assert(sampling.logits.size > 0); assert(sampling.probs.size > 0); @@ -2799,6 +2924,7 @@ llama_context_params llama_context_default_params() { /*.op_offload =*/ true, /*.swa_full =*/ true, /*.kv_unified =*/ false, + /*.attn_weights =*/ false, /*.sampler =*/ nullptr, /*.n_sampler =*/ 0, }; @@ -2989,6 +3115,22 @@ float * llama_get_embeddings_seq(llama_context * ctx, llama_seq_id seq_id) { return ctx->get_embeddings_seq(seq_id); } +void llama_set_attn_heads(llama_context * ctx, const int32_t * layers, const int32_t * heads, size_t 
n_pairs) { + ctx->set_attn_heads(layers, heads, n_pairs); +} + +float * llama_get_attn_ith(llama_context * ctx, int32_t i) { + ctx->synchronize(); + + return ctx->get_attn_ith(i); +} + +int32_t llama_get_attn_n_kv(llama_context * ctx) { + ctx->synchronize(); + + return ctx->get_attn_n_kv(); +} + bool llama_set_sampler(llama_context * ctx, llama_seq_id seq_id, llama_sampler * smpl) { return ctx->set_sampler(seq_id, smpl); } diff --git a/src/llama-context.h b/src/llama-context.h index e0d0085c1c..78a4f3a637 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -79,6 +79,10 @@ struct llama_context { float * get_embeddings_ith(int32_t i); float * get_embeddings_seq(llama_seq_id seq_id); + void set_attn_heads(const int32_t * layers, const int32_t * heads, size_t n_pairs); + float * get_attn_ith(int32_t i); + int32_t get_attn_n_kv() const; + llama_token * get_sampled_tokens() const; llama_token get_sampled_token_ith(int32_t idx); @@ -295,6 +299,15 @@ private: // populated only when pooling_type != LLAMA_POOLING_TYPE_NONE std::map> embd_seq; + // attention weight extraction + struct attn_head_pair { + int32_t layer; + int32_t head; + }; + std::vector attn_heads; // which (layer, head) pairs to extract + buffer_view attn = {nullptr, 0}; // [n_outputs][n_pairs * n_ctx] + int32_t attn_n_kv = 0; // KV cache length at time of last decode + // reuse the batch_allocr to avoid unnecessary memory allocations std::unique_ptr balloc; diff --git a/src/llama-cparams.h b/src/llama-cparams.h index 2da3bbd6f9..f97c235895 100644 --- a/src/llama-cparams.h +++ b/src/llama-cparams.h @@ -3,6 +3,7 @@ #include "llama.h" #include +#include #define LLAMA_MAX_SEQ 256 @@ -36,6 +37,9 @@ struct llama_cparams { bool op_offload; bool kv_unified; bool pipeline_parallel; + bool attn_weights; + + std::set attn_layers; // which layers to extract attention from (derived from attn_heads) enum llama_pooling_type pooling_type; diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 
b8126ce508..add8dd4ad7 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -736,6 +736,7 @@ void llm_graph_result::reset() { t_logits = nullptr; t_embd = nullptr; t_embd_pooled = nullptr; + t_attn.clear(); t_sampled.clear(); t_sampled_probs.clear(); t_sampled_logits.clear(); @@ -789,6 +790,11 @@ void llm_graph_result::set_outputs() { ggml_set_output(t); } } + for (auto & [layer, t] : t_attn) { + if (t != nullptr) { + ggml_set_output(t); + } + } for (auto & [seq_id, t] : t_candidates) { if (t != nullptr) { ggml_set_output(t); @@ -1833,6 +1839,10 @@ ggml_tensor * llm_graph_context::build_attn_mha( ggml_soft_max_add_sinks(kq, sinks); cb(kq, "kq_soft_max", il); + if (!cparams.attn_layers.empty() && cparams.attn_layers.count(il)) { + res->t_attn[il] = kq; + } + if (!v_trans) { // note: avoid this branch v = ggml_cont(ctx0, ggml_transpose(ctx0, v)); diff --git a/src/llama-graph.h b/src/llama-graph.h index e8f006977d..ee99260463 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -613,8 +613,9 @@ struct llm_graph_params { } return - cparams.embeddings == other.cparams.embeddings && - cparams.causal_attn == other.cparams.causal_attn && + cparams.embeddings == other.cparams.embeddings && + cparams.causal_attn == other.cparams.causal_attn && + cparams.attn_layers == other.cparams.attn_layers && arch == other.arch && gtype == other.gtype && cvec == other.cvec && @@ -667,6 +668,10 @@ public: std::map t_sampled; std::map t_sampled_probs; + // attention weight tensors, keyed by layer index + // shape per tensor: [n_kv, n_tokens, n_head, n_stream] (after permute in build_attn_mha) + std::map t_attn; + std::vector inputs; ggml_context_ptr ctx_compute; From b550fa6e18073fee9256c5aabbe32f32fbf4cd7d Mon Sep 17 00:00:00 2001 From: Quentin Fuxa Date: Thu, 5 Mar 2026 11:34:03 +0100 Subject: [PATCH 2/2] Use internal cb_eval for attention extraction to eliminate graph splits Instead of marking attention tensors as output (which caused ~70 graph splits and +28% overhead), 
extract attention weights via an internal cb_eval callback during graph execution. This reduces graph splits back to baseline (2) with ~0% overhead while keeping the clean public API (llama_set_attn_heads + llama_get_attn_ith). Changes: - Remove ggml_set_output() on t_attn tensors in set_outputs() - Add attn_cb_eval_fn that intercepts kq_soft_max tensors and copies attention data using ggml_backend_tensor_get with tensor strides - Remove post-decode attention extraction loop (now done during execution) - Remove sched_need_reserve from set_attn_heads (graph topology unchanged) - Chain to user's cb_eval callback when both are active --- src/llama-context.cpp | 142 +++++++++++++++++++++++++----------------- src/llama-context.h | 9 +++ src/llama-graph.cpp | 5 -- 3 files changed, 95 insertions(+), 61 deletions(-) diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 9a16c28c37..7c66f9a1d3 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -804,7 +804,6 @@ void llama_context::set_attn_heads(const int32_t * layers, const int32_t * heads cparams.attn_layers.insert(layers[i]); } - sched_need_reserve = true; } float * llama_context::get_attn_ith(int32_t i) { @@ -828,6 +827,81 @@ int32_t llama_context::get_attn_n_kv() const { return attn_n_kv; } +// static +bool llama_context::attn_cb_eval_fn(struct ggml_tensor * t, bool ask, void * user_data) { + auto * state = static_cast(user_data); + auto * ctx = state->ctx; + + const auto user_cb = ctx->cparams.cb_eval; + const auto user_data_cb = ctx->cparams.cb_eval_user_data; + + const char * name = t->name; + if (strncmp(name, "kq_soft_max-", 12) != 0) { + return user_cb ? user_cb(t, ask, user_data_cb) : false; + } + + const int layer_idx = atoi(name + 12); + + if (ctx->cparams.attn_layers.find(layer_idx) == ctx->cparams.attn_layers.end()) { + return user_cb ? 
user_cb(t, ask, user_data_cb) : false; + } + + if (ask) { + if (user_cb) { + user_cb(t, true, user_data_cb); + } + return true; + } + + // data is available, extract attention weights + const auto & ubatch = *state->ubatch; + const auto & attn_heads_vec = ctx->attn_heads; + const size_t n_pairs = attn_heads_vec.size(); + const size_t attn_stride = n_pairs * ctx->cparams.n_ctx; + + const int64_t t_n_kv = t->ne[0]; + const int64_t t_n_tokens = t->ne[1]; + const int64_t t_n_head = t->ne[2]; + const int64_t t_n_stream = t->ne[3]; + + ctx->attn_n_kv = (int32_t) t_n_kv; + + for (size_t p = 0; p < n_pairs; p++) { + if (attn_heads_vec[p].layer != layer_idx) { + continue; + } + + const int head = attn_heads_vec[p].head; + if (head >= t_n_head) { + continue; + } + + int32_t out_idx = 0; + for (uint32_t i = 0; i < ubatch.n_tokens; i++) { + if (!ubatch.output[i]) { + continue; + } + + const int64_t stream = t_n_stream > 1 ? (int64_t) i / t_n_tokens : 0; + const int64_t t_in_str = t_n_stream > 1 ? (int64_t) i % t_n_tokens : (int64_t) i; + + const size_t src_offset = stream * t->nb[3] + head * t->nb[2] + t_in_str * t->nb[1]; + + float * dst = ctx->attn.data + (state->n_outputs_prev + out_idx) * attn_stride + p * ctx->cparams.n_ctx; + + ggml_backend_tensor_get(t, dst, src_offset, t_n_kv * sizeof(float)); + + out_idx++; + } + } + + if (user_cb) { + return user_cb(t, false, user_data_cb); + } + + return true; +} + llama_token llama_context::get_sampled_token_ith(int32_t idx) { output_reorder(); @@ -1146,7 +1220,12 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll res->reset(); ggml_backend_sched_reset(sched.get()); - ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data); + + if (attn.data && cparams.attn_weights && !attn_heads.empty()) { + ggml_backend_sched_set_eval_callback(sched.get(), attn_cb_eval_fn, &attn_cb); + } else { + ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, 
cparams.cb_eval_user_data); + } //const auto t_start_us = ggml_time_us(); @@ -1633,6 +1712,10 @@ int llama_context::decode(const llama_batch & batch_inp) { n_outputs = n_outputs_new; } + attn_cb.ctx = this; + attn_cb.ubatch = &ubatch; + attn_cb.n_outputs_prev = n_outputs_prev; + ggml_status status; const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status); @@ -1750,59 +1833,6 @@ int llama_context::decode(const llama_batch & batch_inp) { } } - // extract attention weights - if (attn.data && cparams.attn_weights && n_outputs > 0 && !attn_heads.empty()) { - const size_t n_pairs = attn_heads.size(); - const size_t attn_stride = n_pairs * cparams.n_ctx; // stride per output token in the attn buffer - - // iterate over requested (layer, head) pairs - for (size_t p = 0; p < n_pairs; p++) { - const int layer = attn_heads[p].layer; - const int head = attn_heads[p].head; - - auto it = res->t_attn.find(layer); - if (it == res->t_attn.end() || it->second == nullptr) { - continue; - } - - ggml_tensor * t = it->second; - ggml_backend_t backend_attn = ggml_backend_sched_get_tensor_backend(sched.get(), t); - GGML_ASSERT(backend_attn != nullptr); - - // t shape: [n_kv, n_tokens_in_stream, n_head, n_stream] - const int64_t t_n_kv = t->ne[0]; - const int64_t t_n_tokens_in_stream = t->ne[1]; - const int64_t t_n_head = t->ne[2]; - const int64_t t_n_stream = t->ne[3]; - - attn_n_kv = (int32_t) t_n_kv; - - GGML_ASSERT(head < t_n_head); - - // extract attention for each output token - int32_t out_idx = 0; - for (uint32_t i = 0; i < ubatch.n_tokens; i++) { - if (!ubatch.output[i]) { - continue; - } - - // compute token position within stream - const int64_t stream = t_n_stream > 1 ? (int64_t) i / t_n_tokens_in_stream : 0; - const int64_t t_in_str = t_n_stream > 1 ? 
(int64_t) i % t_n_tokens_in_stream : (int64_t) i; - - // byte offset into the tensor for this (token, head, stream) - const size_t src_offset = ((stream * t_n_head + head) * t_n_tokens_in_stream + t_in_str) * t_n_kv * sizeof(float); - - // destination in the output buffer - float * dst = attn.data + (n_outputs_prev + out_idx) * attn_stride + p * cparams.n_ctx; - - ggml_backend_tensor_get_async(backend_attn, t, dst, src_offset, t_n_kv * sizeof(float)); - - out_idx++; - } - } - } - // Copy backend sampling output if this ubatch produced any sampling tensors. if (has_samplers && (!res->t_sampled.empty() || !res->t_sampled_probs.empty() || !res->t_sampled_logits.empty())) { const auto seq_to_output_row = build_seq_to_output_row(ubatch, n_outputs_prev); @@ -2201,7 +2231,7 @@ ggml_status llama_context::graph_compute( LLAMA_LOG_ERROR("%s: ggml_backend_sched_graph_compute_async failed with error %d\n", __func__, status); } - // fprintf(stderr, "splits: %d\n", ggml_backend_sched_get_n_splits(sched)); + // fprintf(stderr, "splits: %d\n", ggml_backend_sched_get_n_splits(sched.get())); return status; } diff --git a/src/llama-context.h b/src/llama-context.h index 78a4f3a637..034e6739f6 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -308,6 +308,15 @@ private: buffer_view attn = {nullptr, 0}; // [n_outputs][n_pairs * n_ctx] int32_t attn_n_kv = 0; // KV cache length at time of last decode + struct attn_cb_state { + llama_context * ctx = nullptr; + const llama_ubatch * ubatch = nullptr; + int64_t n_outputs_prev = 0; + }; + attn_cb_state attn_cb; + + static bool attn_cb_eval_fn(struct ggml_tensor * t, bool ask, void * user_data); + // reuse the batch_allocr to avoid unnecessary memory allocations std::unique_ptr<llama_batch_allocr> balloc; diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index add8dd4ad7..6ade0030cd 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -790,11 +790,6 @@ void llm_graph_result::set_outputs() { ggml_set_output(t); } } - for (auto & [layer, t] 
: t_attn) { - if (t != nullptr) { - ggml_set_output(t); - } - } for (auto & [seq_id, t] : t_candidates) { if (t != nullptr) { ggml_set_output(t);