llama.cpp/examples/attn-weights/test_attn_weights.py

250 lines
7.5 KiB
Python

#!/usr/bin/env python3
"""
Test: verify llama.cpp attention weight extraction works correctly.
Usage:
python3 examples/attn-weights/test_attn_weights.py <model.gguf> [n_ctx]
Tests:
1. Basic extraction: attention weights are non-null and sum to ~1.0
2. Multi-head: multiple (layer, head) pairs return independent weights
3. Greedy generation: attention is extracted at each autoregressive step
4. Same-layer heads: several heads of one layer each sum to ~1.0 and are compared for variation
"""
import sys
import os
import numpy as np
sys.path.insert(0, os.path.dirname(__file__))
import llama_attn as ll
def test_basic(model, vocab, ctx, n_layers):
    """Test 1: extract attention weights for one prompt and sanity-check them.

    Decodes a fixed prompt in a single batch, fetches one row of attention
    weights, and checks that the row is non-negative and sums to ~1.0.

    ``model`` and ``n_layers`` are not used here; they are accepted so that
    every test can be called with the same arguments.

    Returns True on success; raises AssertionError when a check fails.
    """
    banner = "=" * 60
    print(banner)
    print("TEST 1: Basic attention extraction")
    print(banner)

    prompt_ids = ll.tokenize(vocab, "The quick brown fox jumps over the lazy dog")
    print(f" Tokens: {len(prompt_ids)}")

    status = ll.decode_batch(ctx, prompt_ids, output_last_only=True)
    assert status == 0, f"decode failed: {status}"

    # NOTE(review): index -1 presumably selects the last output and the final
    # argument is the row width -- confirm against llama_attn.
    weights = ll.get_attn_weights(ctx, -1, 1, ll.n_ctx(ctx))
    assert weights is not None, "get_attn_weights returned None"

    kv_count = ll._lib.llama_get_attn_n_kv(ctx)
    print(f" n_kv: {kv_count}")
    print(f" Attention shape: {weights.shape}")

    row = weights[0]
    total = row.sum()
    print(f" Attention sum: {total:.6f}")
    print(f" Attention max: {row.max():.6f} at position {row.argmax()}")
    print(f" Attention min: {row.min():.6f}")

    # A softmax row should add up to ~1.0 (loose tolerance) ...
    assert abs(total - 1.0) < 0.05, f"Attention doesn't sum to 1.0: {total}"
    # ... and can never contain negative entries.
    assert (row >= 0).all(), "Negative attention values found"

    print(" PASSED\n")
    return True
def test_multi_head(model, vocab, ctx, n_layers):
    """Test 2: extract attention for several (layer, head) pairs at once.

    Selects head 0 of the first, middle and last layer, decodes a prompt, and
    checks that each returned row sums to ~1.0. The mean absolute difference
    between the first and last row is printed for information only.

    ``model`` is unused; it is part of the shared test signature.

    Returns True on success; raises AssertionError when a check fails.
    """
    banner = "=" * 60
    print(banner)
    print("TEST 2: Multi-head attention extraction")
    print(banner)

    # Head 0 of the first, middle and last layer.
    layers = [0, n_layers // 2, n_layers - 1]
    heads = [0] * len(layers)
    n_pairs = len(layers)
    ll.set_attn_heads(ctx, layers, heads)
    print(f" Configured {n_pairs} heads: {list(zip(layers, heads))}")

    prompt_ids = ll.tokenize(vocab, "Hello world, this is a test of attention")
    status = ll.decode_batch(ctx, prompt_ids, output_last_only=True)
    assert status == 0, f"decode failed: {status}"

    weights = ll.get_attn_weights(ctx, -1, n_pairs, ll.n_ctx(ctx))
    assert weights is not None, "get_attn_weights returned None"
    print(f" Attention shape: {weights.shape}")

    for idx, (layer, head) in enumerate(zip(layers, heads)):
        row = weights[idx]
        total = row.sum()
        print(f" Pair {idx} (L{layer},H{head}): sum={total:.6f}, max={row.max():.4f} @ pos {row.argmax()}")
        assert abs(total - 1.0) < 0.05, f"Pair {idx} doesn't sum to 1.0: {total}"

    # Different layers are expected to attend differently, but a degenerate
    # model could legitimately match, so this is reported and not asserted.
    if n_pairs >= 2:
        layer_gap = np.abs(weights[0] - weights[-1]).mean()
        print(f" Mean abs difference between first and last layer: {layer_gap:.6f}")

    # Restore the single-pair selection (last layer, head 0).
    ll.set_attn_heads(ctx, [n_layers - 1], [0])

    print(" PASSED\n")
    return True
def test_generation(model, vocab, ctx, n_layers):
    """Test 3: attention during autoregressive generation.

    Prefills the context with a short prompt, then greedily generates up to
    ``max_gen`` tokens. Before each next token is picked, one row of attention
    weights is fetched and its sum recorded. After generation, every recorded
    sum must be within 0.05 of 1.0.

    Args:
        model: Model handle (unused; kept for the shared test signature).
        vocab: Vocabulary handle used for tokenizing and for the EOS id and
            vocabulary size.
        ctx: Context to decode into and read attention weights from.
        n_layers: Layer count (unused; kept for the shared test signature).

    Returns:
        True when all checks pass.

    Raises:
        AssertionError: if a decode call returns non-zero, attention weights
            are missing, or a recorded sum is too far from 1.0.
    """
    print("=" * 60)
    print("TEST 3: Autoregressive generation with attention")
    print("=" * 60)

    tokens = ll.tokenize(vocab, "Once upon a time")
    n_prompt = len(tokens)
    print(f" Prompt: {n_prompt} tokens")

    # Prefill
    ret = ll.decode_batch(ctx, tokens, output_last_only=True)
    assert ret == 0, f"prefill decode failed: {ret}"

    n_c = ll.n_ctx(ctx)
    nv = ll.n_vocab(vocab)
    eos = ll.vocab_eos(vocab)

    max_gen = 10
    gen_tokens = []
    attn_sums = []

    for step in range(max_gen):
        # Attention belonging to the most recently decoded token.
        attn = ll.get_attn_weights(ctx, -1, 1, n_c)
        assert attn is not None, f"Step {step}: attention is None"
        attn_sums.append(attn[0].sum())

        # Greedy choice of the next token.
        next_tok = ll.argmax_logits(ctx, -1, nv)
        if next_tok == eos:
            print(f" Step {step}: EOS")
            break
        gen_tokens.append(next_tok)

        # The new token goes right after the prompt plus the tokens generated
        # so far.
        pos = n_prompt + step
        ret = ll.decode_single(ctx, next_tok, pos, output=True)
        assert ret == 0, f"Step {step}: decode failed: {ret}"

    print(f" Generated {len(gen_tokens)} tokens")
    print(f" Attention sums: {[f'{s:.4f}' for s in attn_sums]}")

    # Checked after the loop so the full list of sums is printed first.
    for i, s in enumerate(attn_sums):
        assert abs(s - 1.0) < 0.05, f"Step {i}: attention sum = {s}"

    print(" PASSED\n")
    return True
def test_multiple_heads_same_layer(model, vocab, ctx, n_layers):
    """Test 4: extract attention for several heads of the last layer.

    Selects up to four heads of the last layer, decodes a prompt, and checks
    that each head's row sums to ~1.0. Whether the heads differ from head 0 is
    reported as OK/WARNING but never fails the test.

    Returns True on success; raises AssertionError when a check fails.
    """
    banner = "=" * 60
    print(banner)
    print("TEST 4: Multiple heads from same layer")
    print(banner)

    head_total = ll.n_head(model)
    last_layer = n_layers - 1
    n_test_heads = min(4, head_total)

    heads = list(range(n_test_heads))
    layers = [last_layer for _ in heads]
    ll.set_attn_heads(ctx, layers, heads)
    print(f" Layer {last_layer}, heads {heads}")

    prompt_ids = ll.tokenize(vocab, "Attention is all you need")
    status = ll.decode_batch(ctx, prompt_ids, output_last_only=True)
    assert status == 0, f"decode failed: {status}"

    weights = ll.get_attn_weights(ctx, -1, n_test_heads, ll.n_ctx(ctx))
    assert weights is not None, "get_attn_weights returned None"
    print(f" Attention shape: {weights.shape}")

    for h in heads:
        row = weights[h]
        total = row.sum()
        peak = row.argmax()
        print(f" Head {h}: sum={total:.6f}, peak @ pos {peak}, max={row.max():.4f}")
        assert abs(total - 1.0) < 0.05, f"Head {h}: sum = {total}"

    # Informational only: at least one head should deviate from head 0.
    if n_test_heads >= 2:
        some_head_differs = any(
            not np.allclose(weights[0], weights[h], atol=1e-5)
            for h in range(1, n_test_heads)
        )
        if some_head_differs:
            print(" OK: heads show different patterns")
        else:
            print(" WARNING: all heads have identical attention patterns")

    # Restore the single-pair selection (last layer, head 0).
    ll.set_attn_heads(ctx, [last_layer], [0])

    print(" PASSED\n")
    return True
def main():
    """Run every attention-weight test against the model given on the command line.

    Command line: ``<model.gguf> [n_ctx]``. ``n_ctx`` defaults to 512 and must
    parse as an integer; otherwise the usage text is printed and the process
    exits with status 1, before the library is touched.

    Each test runs in a fresh context that is always freed afterwards. A test
    counts as failed when it returns a falsy value or raises an exception.
    The model and the library are released even if loading or context
    creation raises.

    Exits with status 0 when no test failed, otherwise 1.
    """
    usage = f"Usage: {sys.argv[0]} <model.gguf> [n_ctx]"
    if len(sys.argv) < 2:
        print(usage)
        sys.exit(1)

    model_path = sys.argv[1]
    n_ctx = 512
    if len(sys.argv) > 2:
        try:
            n_ctx = int(sys.argv[2])
        except ValueError:
            # A malformed size is a usage error, not a crash with a traceback.
            print(f"Invalid n_ctx: {sys.argv[2]!r} (expected an integer)")
            print(usage)
            sys.exit(1)

    print(f"Model: {model_path}")
    print(f"n_ctx: {n_ctx}\n")

    tests = [test_basic, test_multi_head, test_generation, test_multiple_heads_same_layer]
    passed = 0
    failed = 0
    model = None

    ll.init()
    try:
        model = ll.load_model(model_path)
        vocab = ll.get_vocab(model)
        n_layers = ll.n_layer(model)
        n_heads = ll.n_head(model)
        nv = ll.n_vocab(vocab)
        print(f"Loaded: {n_layers} layers, {n_heads} heads, vocab={nv}\n")

        for test_fn in tests:
            # Create fresh context for each test
            ctx = ll.create_context(model, n_ctx=n_ctx, n_batch=n_ctx, attn_weights=True)
            try:
                if test_fn(model, vocab, ctx, n_layers):
                    passed += 1
                else:
                    failed += 1
            except Exception as e:
                # Top-level test boundary: report the failure (including failed
                # asserts) and carry on with the remaining tests.
                print(f" FAILED: {e}\n")
                failed += 1
            finally:
                ll.free_context(ctx)

        print("=" * 60)
        print(f"Results: {passed} passed, {failed} failed")
        print("=" * 60)
    finally:
        # Release native resources even when loading or context creation raised.
        if model is not None:
            ll.free_model(model)
        ll.cleanup()

    sys.exit(0 if failed == 0 else 1)
# Script entry point: run the test suite only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()