"""
|
|
Minimal ctypes wrapper for llama.cpp attention weight extraction API.
|
|
"""
|
|
|
|
import ctypes
|
|
import os
|
|
import sys
|
|
|
|
# --- Shared library discovery ---
# The llama.cpp shared library is expected in the standard CMake output dir.
_LIB_DIR = os.path.join(os.path.dirname(__file__), "..", "..", "build", "bin")
_LIB_NAMES = ["libllama.dylib", "libllama.so", "llama.dll"]


def _load_lib():
    """Return a CDLL handle for the first libllama found in _LIB_DIR, else None."""
    for lib_name in _LIB_NAMES:
        candidate = os.path.join(_LIB_DIR, lib_name)
        if os.path.exists(candidate):
            return ctypes.CDLL(candidate)
    return None


_lib = _load_lib()
if _lib is None:
    raise RuntimeError(f"Cannot find libllama in {_LIB_DIR}. Build first with: cmake --build build")
|
|
|
|
|
|
# --- Types ---
|
|
|
|
# --- Opaque handle types (contents owned and managed by the C library) ---

class llama_model(ctypes.Structure):
    """Opaque handle to a loaded model."""


class llama_context(ctypes.Structure):
    """Opaque handle to an inference context."""


class llama_vocab(ctypes.Structure):
    """Opaque handle to a model vocabulary."""


# Scalar typedefs mirroring llama.h.
llama_token = ctypes.c_int32
llama_pos = ctypes.c_int32
llama_seq_id = ctypes.c_int32
|
|
|
|
|
|
class llama_batch(ctypes.Structure):
    """Mirror of struct llama_batch from llama.h (token-input form)."""

    _fields_ = [
        ("n_tokens", ctypes.c_int32),                              # number of valid entries
        ("token", ctypes.POINTER(llama_token)),                    # token ids
        ("embd", ctypes.POINTER(ctypes.c_float)),                  # raw embeddings (unused here)
        ("pos", ctypes.POINTER(llama_pos)),                        # per-token positions
        ("n_seq_id", ctypes.POINTER(ctypes.c_int32)),              # seq-id count per token
        ("seq_id", ctypes.POINTER(ctypes.POINTER(llama_seq_id))),  # seq-ids per token
        ("logits", ctypes.POINTER(ctypes.c_int8)),                 # output-request flag per token
    ]
|
|
|
|
|
|
# llama_model_params is large and version-dependent. We never touch its
# fields from Python — we only pass the C-provided default back by value —
# so an oversized opaque byte blob stands in for the real layout.
class llama_model_params(ctypes.Structure):
    """Opaque, oversized stand-in for struct llama_model_params."""

    _fields_ = [("_opaque", ctypes.c_uint8 * 256)]  # oversized on purpose, safe
|
|
|
|
|
|
# We need the exact layout of llama_context_params to set attn_weights and flash_attn_type
|
|
# Let's read it from the header. The key fields we need are at known positions.
|
|
# Instead of matching the exact struct, we'll use the C API defaults and patch bytes.
|
|
|
|
# Enums (subset of llama.h)
LLAMA_FLASH_ATTN_TYPE_AUTO = -1
LLAMA_FLASH_ATTN_TYPE_DISABLED = 0
LLAMA_FLASH_ATTN_TYPE_ENABLED = 1


# --- Function signatures ---
# Declared as (name, argtypes, restype) triples and applied in one pass.
_model_p = ctypes.POINTER(llama_model)
_ctx_p = ctypes.POINTER(llama_context)
_vocab_p = ctypes.POINTER(llama_vocab)
_i32 = ctypes.c_int32

_SIGNATURES = [
    # void llama_backend_init(void)
    ("llama_backend_init", [], None),
    # void llama_backend_free(void)
    ("llama_backend_free", [], None),
    # llama_model_params llama_model_default_params(void)
    ("llama_model_default_params", [], llama_model_params),
    # llama_model * llama_model_load_from_file(const char * path, llama_model_params params)
    ("llama_model_load_from_file", [ctypes.c_char_p, llama_model_params], _model_p),
    # void llama_model_free(llama_model * model)
    ("llama_model_free", [_model_p], None),
    # const llama_vocab * llama_model_get_vocab(const llama_model * model)
    ("llama_model_get_vocab", [_model_p], _vocab_p),
    # int32_t llama_model_n_layer(const llama_model * model)
    ("llama_model_n_layer", [_model_p], _i32),
    # int32_t llama_model_n_head(const llama_model * model)
    ("llama_model_n_head", [_model_p], _i32),
    # int32_t llama_tokenize(const llama_vocab *, const char *, int32_t, llama_token *, int32_t, bool, bool)
    ("llama_tokenize",
     [_vocab_p, ctypes.c_char_p, _i32,
      ctypes.POINTER(llama_token), _i32, ctypes.c_bool, ctypes.c_bool],
     _i32),
    # llama_batch llama_batch_init(int32_t n_tokens, int32_t embd, int32_t n_seq_max)
    ("llama_batch_init", [_i32, _i32, _i32], llama_batch),
    # void llama_batch_free(llama_batch batch)
    ("llama_batch_free", [llama_batch], None),
    # int32_t llama_decode(llama_context * ctx, llama_batch batch)
    ("llama_decode", [_ctx_p, llama_batch], _i32),
    # void llama_free(llama_context * ctx)
    ("llama_free", [_ctx_p], None),
    # void llama_set_attn_heads(llama_context *, const int32_t * layers, const int32_t * heads, size_t n_pairs)
    ("llama_set_attn_heads",
     [_ctx_p, ctypes.POINTER(_i32), ctypes.POINTER(_i32), ctypes.c_size_t],
     None),
    # float * llama_get_attn_ith(llama_context * ctx, int32_t i)
    ("llama_get_attn_ith", [_ctx_p, _i32], ctypes.POINTER(ctypes.c_float)),
    # int32_t llama_get_attn_n_kv(llama_context * ctx)
    ("llama_get_attn_n_kv", [_ctx_p], _i32),
    # float * llama_get_logits_ith(llama_context * ctx, int32_t i)
    ("llama_get_logits_ith", [_ctx_p, _i32], ctypes.POINTER(ctypes.c_float)),
    # uint32_t llama_n_ctx(const llama_context * ctx)
    ("llama_n_ctx", [_ctx_p], ctypes.c_uint32),
    # void llama_synchronize(llama_context * ctx)
    ("llama_synchronize", [_ctx_p], None),
    # int32_t llama_vocab_n_tokens(const llama_vocab * vocab)
    ("llama_vocab_n_tokens", [_vocab_p], _i32),
    # llama_token llama_vocab_bos(const llama_vocab * vocab)
    ("llama_vocab_bos", [_vocab_p], llama_token),
    # llama_token llama_vocab_eos(const llama_vocab * vocab)
    ("llama_vocab_eos", [_vocab_p], llama_token),
    # int32_t llama_token_to_piece(const llama_vocab *, llama_token, char *, int32_t, int32_t, bool)
    ("llama_token_to_piece",
     [_vocab_p, llama_token, ctypes.POINTER(ctypes.c_char), _i32, _i32, ctypes.c_bool],
     _i32),
]

for _name, _argtypes, _restype in _SIGNATURES:
    _fn = getattr(_lib, _name)
    _fn.argtypes = _argtypes
    _fn.restype = _restype
|
|
|
|
|
|
# --- Context creation ---
|
|
# Since llama_context_params is complex and may change layout between versions,
|
|
# we use a helper approach: call llama_context_default_params from C, then patch
|
|
# the specific fields we need.
|
|
|
|
# We need to know the struct size and field offsets. Let's use a small C helper.
|
|
# Actually, let's just build a properly aligned struct by reading the header.
|
|
# The key insight: we can create context via a C helper function.
|
|
|
|
# For now, let's use ctypes.c_uint8 array as an opaque blob and set fields at
|
|
# known byte offsets. This is fragile but works for our specific build.
|
|
|
|
def _create_context(model_ptr, n_ctx=512, n_batch=512, attn_weights=True, n_gpu_layers=0):
    """Create a llama_context with attention weights enabled.

    llama_context_params is a large struct whose layout may change between
    llama.cpp versions, so instead of mirroring it in ctypes we compile a
    tiny C shim that fills in the defaults and patches only the fields we
    need. The compiled shim is cached on this function, so the compile and
    dlopen happen at most once per process (previously every call
    recompiled and re-dlopened the library, which is wasteful and fragile
    when overwriting an already-loaded dylib).

    Args:
        model_ptr: POINTER(llama_model) from load_model().
        n_ctx: context window size.
        n_batch: batch size; also used for n_ubatch in the shim.
        attn_weights: enable attention-weight extraction in the context.
        n_gpu_layers: >0 enables offload_kqv in the shim.

    Returns:
        POINTER(llama_context).

    Raises:
        RuntimeError: if the shim fails to compile or context creation fails.
    """
    import tempfile, subprocess

    shim = getattr(_create_context, "_shim", None)
    if shim is None:
        shim_src = r"""
#include "llama.h"
#include <stdlib.h>

// Export a function that creates a context with the right params
__attribute__((visibility("default")))
struct llama_context * create_ctx_with_attn(
    struct llama_model * model,
    int n_ctx, int n_batch, int attn_weights, int n_gpu_layers) {
    struct llama_context_params params = llama_context_default_params();
    params.n_ctx = n_ctx;
    params.n_batch = n_batch;
    params.n_ubatch = n_batch;
    params.attn_weights = attn_weights ? true : false;
    params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_DISABLED;
    params.offload_kqv = n_gpu_layers > 0;
    return llama_init_from_model(model, params);
}
"""
        llama_dir = os.path.join(os.path.dirname(__file__), "..", "..")
        include_dir = os.path.join(llama_dir, "include")
        lib_dir = os.path.join(llama_dir, "build", "bin")
        ggml_include = os.path.join(llama_dir, "ggml", "include")

        with tempfile.NamedTemporaryFile(suffix=".c", mode="w", delete=False) as f:
            f.write(shim_src)
            src_path = f.name

        shim_lib = os.path.join(lib_dir, "libllama_attn_shim.dylib")
        if sys.platform == "linux":
            shim_lib = os.path.join(lib_dir, "libllama_attn_shim.so")

        cmd = [
            "cc", "-shared", "-fPIC", "-o", shim_lib, src_path,
            f"-I{include_dir}", f"-I{ggml_include}",
            f"-L{lib_dir}", "-lllama",
            f"-Wl,-rpath,{lib_dir}",
        ]
        try:
            result = subprocess.run(cmd, capture_output=True, text=True)
        finally:
            # Always remove the temp C source, even if `cc` itself is missing.
            os.unlink(src_path)
        if result.returncode != 0:
            raise RuntimeError(f"Failed to compile shim: {result.stderr}")

        shim = ctypes.CDLL(shim_lib)
        shim.create_ctx_with_attn.argtypes = [
            ctypes.POINTER(llama_model), ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int
        ]
        shim.create_ctx_with_attn.restype = ctypes.POINTER(llama_context)
        _create_context._shim = shim  # cache: compile/load once per process

    ctx = shim.create_ctx_with_attn(model_ptr, n_ctx, n_batch, 1 if attn_weights else 0, n_gpu_layers)
    if not ctx:
        raise RuntimeError("Failed to create llama_context")
    return ctx
|
|
|
|
|
|
# --- High-level helpers ---
|
|
|
|
def tokenize(vocab_ptr, text, add_bos=True, special=True):
    """Tokenize *text* and return the token ids as a Python list."""
    data = text.encode("utf-8")
    # First attempt with a generously sized buffer. llama_tokenize returns
    # the negated required size when the buffer is too small, so retry once
    # with exactly that capacity.
    cap = len(data) + 32
    buf = (llama_token * cap)()
    count = _lib.llama_tokenize(vocab_ptr, data, len(data), buf, cap, add_bos, special)
    if count < 0:
        cap = -count
        buf = (llama_token * cap)()
        count = _lib.llama_tokenize(vocab_ptr, data, len(data), buf, cap, add_bos, special)
    return list(buf[:count])
|
|
|
|
|
|
def decode_batch(ctx_ptr, tokens, output_last_only=True):
    """Decode a batch of tokens in sequence 0 at positions 0..len(tokens)-1.

    Args:
        ctx_ptr: POINTER(llama_context).
        tokens: sequence of token ids.
        output_last_only: if True, request logits only for the final token.

    Returns:
        The llama_decode return code (0 on success).
    """
    n = len(tokens)
    batch = _lib.llama_batch_init(n, 0, 1)
    try:
        for i, tok in enumerate(tokens):
            batch.token[i] = tok
            batch.pos[i] = i
            batch.n_seq_id[i] = 1
            # Write the seq_id value into the pre-allocated buffer —
            # replacing the pointer itself would leak the C allocation.
            batch.seq_id[i][0] = 0
            batch.logits[i] = 1 if (not output_last_only or i == n - 1) else 0

        batch.n_tokens = n
        return _lib.llama_decode(ctx_ptr, batch)
    finally:
        # Free the C-side batch even if filling it raised (e.g. a non-int
        # token triggering a ctypes TypeError) — otherwise it leaks.
        _lib.llama_batch_free(batch)
|
|
|
|
|
|
def decode_single(ctx_ptr, token, pos, output=True):
    """Decode a single token at position *pos* in sequence 0.

    Args:
        ctx_ptr: POINTER(llama_context).
        token: token id to decode.
        pos: position of the token in the sequence.
        output: if True, request logits for this token.

    Returns:
        The llama_decode return code (0 on success).
    """
    batch = _lib.llama_batch_init(1, 0, 1)
    try:
        batch.token[0] = token
        batch.pos[0] = pos
        batch.n_seq_id[0] = 1
        # Write into the pre-allocated seq_id buffer (don't replace the pointer).
        batch.seq_id[0][0] = 0
        batch.logits[0] = 1 if output else 0
        batch.n_tokens = 1
        return _lib.llama_decode(ctx_ptr, batch)
    finally:
        # Free the C-side batch even if filling it raised — otherwise it leaks.
        _lib.llama_batch_free(batch)
|
|
|
|
|
|
def get_attn_weights(ctx_ptr, token_idx, n_pairs, n_ctx):
    """Get attention weights for a given output token index.

    The C buffer is laid out as [n_pairs * n_ctx] floats — n_ctx floats per
    (layer, head) pair, of which only the first n_kv are valid.

    Args:
        ctx_ptr: POINTER(llama_context).
        token_idx: output token index to query.
        n_pairs: number of (layer, head) pairs previously selected.
        n_ctx: context size used to stride between pairs in the buffer.

    Returns:
        numpy float32 array of shape (n_pairs, n_kv), or None if no
        weights are available for this token.
    """
    import numpy as np

    ptr = _lib.llama_get_attn_ith(ctx_ptr, token_idx)
    if not ptr:
        return None

    n_kv = _lib.llama_get_attn_n_kv(ctx_ptr)
    if n_kv <= 0:
        return None

    # View the whole C buffer at once and copy out the valid prefix of each
    # row — one vectorized slice instead of a per-pair ctypes cast. The
    # .copy() detaches the result from the context-owned memory.
    flat = (ctypes.c_float * (n_pairs * n_ctx)).from_address(
        ctypes.addressof(ptr.contents))
    full = np.frombuffer(flat, dtype=np.float32).reshape(n_pairs, n_ctx)
    return full[:, :n_kv].copy()
|
|
|
|
|
|
def argmax_logits(ctx_ptr, token_idx, n_vocab):
    """Return the index of the highest logit for an output token (-1 if none)."""
    import numpy as np

    ptr = _lib.llama_get_logits_ith(ctx_ptr, token_idx)
    if not ptr:
        return -1
    raw = (ctypes.c_float * n_vocab).from_address(ctypes.addressof(ptr.contents))
    return int(np.frombuffer(raw, dtype=np.float32).argmax())
|
|
|
|
|
|
# --- Public API ---
|
|
|
|
def init():
    """Initialize the llama.cpp backend (call once before any other API)."""
    _lib.llama_backend_init()


def cleanup():
    """Tear down the llama.cpp backend."""
    _lib.llama_backend_free()
|
|
|
|
def load_model(path, n_gpu_layers=0):
    """Load a GGUF model from *path* and return a POINTER(llama_model).

    NOTE(review): `n_gpu_layers` is currently IGNORED — the C default
    params are used as-is because llama_model_params is treated as an
    opaque blob and we don't know the field offsets reliably. Callers
    wanting GPU offload must patch this once the layout is pinned down.

    Raises:
        RuntimeError: if the model fails to load.
    """
    params = _lib.llama_model_default_params()
    # n_gpu_layers is at offset 0 in llama_model_params (first field)
    # Actually let's just use default params for simplicity
    model = _lib.llama_model_load_from_file(path.encode(), params)
    if not model:
        raise RuntimeError(f"Failed to load model from {path}")
    return model
|
|
|
|
def create_context(model, n_ctx=512, n_batch=512, attn_weights=True, n_gpu_layers=0):
    """Create a llama_context for *model* with attention-weight capture.

    `n_gpu_layers` (default 0, preserving previous behavior) is now
    forwarded to the shim — previously _create_context accepted it but
    this public wrapper silently dropped it.
    """
    return _create_context(model, n_ctx, n_batch, attn_weights, n_gpu_layers)
|
|
|
|
def set_attn_heads(ctx, layers, heads):
    """Select the (layer, head) pairs whose attention weights to capture.

    Args:
        ctx: POINTER(llama_context).
        layers: sequence of layer indices.
        heads: sequence of head indices, same length as *layers*.

    Raises:
        ValueError: if the two sequences differ in length.
    """
    n = len(layers)
    if len(heads) != n:
        # Explicit check instead of `assert`: asserts are stripped under -O.
        raise ValueError(f"layers/heads length mismatch: {n} != {len(heads)}")
    l_arr = (ctypes.c_int32 * n)(*layers)
    h_arr = (ctypes.c_int32 * n)(*heads)
    _lib.llama_set_attn_heads(ctx, l_arr, h_arr, n)
|
|
|
|
def get_vocab(model):
    """Return the vocab handle for a loaded model."""
    return _lib.llama_model_get_vocab(model)


def n_layer(model):
    """Return the number of transformer layers in the model."""
    return _lib.llama_model_n_layer(model)


def n_head(model):
    """Return the number of attention heads per layer."""
    return _lib.llama_model_n_head(model)


def n_vocab(vocab):
    """Return the vocabulary size."""
    return _lib.llama_vocab_n_tokens(vocab)


def n_ctx(ctx):
    """Return the context size this context was created with."""
    return _lib.llama_n_ctx(ctx)


def vocab_eos(vocab):
    """Return the end-of-sequence token id."""
    return _lib.llama_vocab_eos(vocab)


def free_context(ctx):
    """Release a llama_context."""
    _lib.llama_free(ctx)
|
|
|
|
def token_to_piece(vocab, token_id, special=True):
    """Convert a single token ID to its string piece."""
    buf = (ctypes.c_char * 256)()
    written = _lib.llama_token_to_piece(vocab, token_id, buf, 256, 0, special)
    if written <= 0:
        # Zero or negative means no piece was produced for this token.
        return ""
    return buf[:written].decode("utf-8", errors="replace")
|
|
|
|
def free_model(model):
    """Release a loaded model."""
    _lib.llama_model_free(model)