Use internal cb_eval for attention extraction to eliminate graph splits

Instead of marking attention tensors as output (which caused ~70 graph
splits and +28% overhead), extract attention weights via an internal
cb_eval callback during graph execution. This reduces graph splits
back to baseline (2) with ~0% overhead while keeping the clean
public API (llama_set_attn_heads + llama_get_attn_ith).

Changes:
- Remove ggml_set_output() on t_attn tensors in set_outputs()
- Add attn_cb_eval_fn that intercepts kq_soft_max tensors and copies
  attention data using ggml_backend_tensor_get with tensor strides
- Remove post-decode attention extraction loop (now done during execution)
- Remove sched_need_reserve from set_attn_heads (graph topology unchanged)
- Chain to user's cb_eval callback when both are active
This commit is contained in:
Quentin Fuxa 2026-03-05 11:34:03 +01:00
parent 14bf6d45bb
commit b550fa6e18
3 changed files with 95 additions and 61 deletions

View File

@@ -804,7 +804,6 @@ void llama_context::set_attn_heads(const int32_t * layers, const int32_t * heads
cparams.attn_layers.insert(layers[i]);
}
sched_need_reserve = true;
}
float * llama_context::get_attn_ith(int32_t i) {
@@ -828,6 +827,81 @@ int32_t llama_context::get_attn_n_kv() const {
return attn_n_kv;
}
// static
// Scheduler eval callback used to extract attention weights while the graph
// is executing, instead of marking kq_soft_max tensors as graph outputs
// (which forced extra graph splits). Chains to the user's cb_eval callback
// when one is configured. user_data points at the context's attn_cb_state.
bool llama_context::attn_cb_eval_fn(struct ggml_tensor * t, bool ask, void * user_data) {
auto * state = static_cast<attn_cb_state *>(user_data);
auto * ctx = state->ctx;
const auto user_cb = ctx->cparams.cb_eval;
const auto user_data_cb = ctx->cparams.cb_eval_user_data;
const char * name = t->name;
// only tensors named "kq_soft_max-<layer>" are of interest here; anything
// else is forwarded to the user's callback (or declined when none is set)
if (strncmp(name, "kq_soft_max-", 12) != 0) {
return user_cb ? user_cb(t, ask, user_data_cb) : false;
}
const int layer_idx = atoi(name + 12);
// skip layers the caller did not request via llama_set_attn_heads
if (ctx->cparams.attn_layers.find(layer_idx) == ctx->cparams.attn_layers.end()) {
return user_cb ? user_cb(t, ask, user_data_cb) : false;
}
if (ask) {
// "ask" phase: give the user's callback a chance to see the tensor too,
// but always return true ourselves so the data phase below runs.
// NOTE(review): the user's answer is intentionally ignored here — we need
// this tensor regardless.
if (user_cb) {
user_cb(t, true, user_data_cb);
}
return true;
}
// data is available, extract attention weights
const auto & ubatch = *state->ubatch;
const auto & attn_heads_vec = ctx->attn_heads;
const size_t n_pairs = attn_heads_vec.size();
// stride per output token in the destination buffer: one n_ctx-sized row
// for each requested (layer, head) pair
const size_t attn_stride = n_pairs * ctx->cparams.n_ctx;
// t shape: [n_kv, n_tokens_in_stream, n_head, n_stream]
const int64_t t_n_kv = t->ne[0];
const int64_t t_n_tokens = t->ne[1];
const int64_t t_n_head = t->ne[2];
const int64_t t_n_stream = t->ne[3];
// remember the KV length so readers of the attn buffer know the row width
ctx->attn_n_kv = (int32_t) t_n_kv;
for (size_t p = 0; p < n_pairs; p++) {
// this callback fires once per layer tensor; only handle pairs for it
if (attn_heads_vec[p].layer != layer_idx) {
continue;
}
const int head = attn_heads_vec[p].head;
// silently skip out-of-range heads rather than asserting mid-execution
if (head >= t_n_head) {
continue;
}
int32_t out_idx = 0;
for (uint32_t i = 0; i < ubatch.n_tokens; i++) {
// only tokens flagged for output get their attention row copied
if (!ubatch.output[i]) {
continue;
}
// map the batch-wide token index onto (stream, index-within-stream);
// single-stream tensors use the token index directly
const int64_t stream = t_n_stream > 1 ? (int64_t) i / t_n_tokens : 0;
const int64_t t_in_str = t_n_stream > 1 ? (int64_t) i % t_n_tokens : (int64_t) i;
// byte offset computed from the tensor's own strides (nb[]), so
// non-contiguous backend layouts are handled correctly
const size_t src_offset = stream * t->nb[3] + head * t->nb[2] + t_in_str * t->nb[1];
// destination row: (global output index) * attn_stride, plus this
// pair's n_ctx-sized slot within the row
float * dst = ctx->attn.data + (state->n_outputs_prev + out_idx) * attn_stride + p * ctx->cparams.n_ctx;
// copy one row of n_kv floats out of the backend tensor
ggml_backend_tensor_get(t, dst, src_offset, t_n_kv * sizeof(float));
out_idx++;
}
}
// chain to the user's callback for the data phase as well
if (user_cb) {
return user_cb(t, false, user_data_cb);
}
return true;
}
llama_token llama_context::get_sampled_token_ith(int32_t idx) {
output_reorder();
@@ -1146,7 +1220,12 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll
res->reset();
ggml_backend_sched_reset(sched.get());
ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
if (attn.data && cparams.attn_weights && !attn_heads.empty()) {
ggml_backend_sched_set_eval_callback(sched.get(), attn_cb_eval_fn, &attn_cb);
} else {
ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
}
//const auto t_start_us = ggml_time_us();
@@ -1633,6 +1712,10 @@ int llama_context::decode(const llama_batch & batch_inp) {
n_outputs = n_outputs_new;
}
attn_cb.ctx = this;
attn_cb.ubatch = &ubatch;
attn_cb.n_outputs_prev = n_outputs_prev;
ggml_status status;
const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status);
@@ -1750,59 +1833,6 @@ int llama_context::decode(const llama_batch & batch_inp) {
}
}
// extract attention weights
if (attn.data && cparams.attn_weights && n_outputs > 0 && !attn_heads.empty()) {
const size_t n_pairs = attn_heads.size();
const size_t attn_stride = n_pairs * cparams.n_ctx; // stride per output token in the attn buffer
// iterate over requested (layer, head) pairs
for (size_t p = 0; p < n_pairs; p++) {
const int layer = attn_heads[p].layer;
const int head = attn_heads[p].head;
auto it = res->t_attn.find(layer);
if (it == res->t_attn.end() || it->second == nullptr) {
continue;
}
ggml_tensor * t = it->second;
ggml_backend_t backend_attn = ggml_backend_sched_get_tensor_backend(sched.get(), t);
GGML_ASSERT(backend_attn != nullptr);
// t shape: [n_kv, n_tokens_in_stream, n_head, n_stream]
const int64_t t_n_kv = t->ne[0];
const int64_t t_n_tokens_in_stream = t->ne[1];
const int64_t t_n_head = t->ne[2];
const int64_t t_n_stream = t->ne[3];
attn_n_kv = (int32_t) t_n_kv;
GGML_ASSERT(head < t_n_head);
// extract attention for each output token
int32_t out_idx = 0;
for (uint32_t i = 0; i < ubatch.n_tokens; i++) {
if (!ubatch.output[i]) {
continue;
}
// compute token position within stream
const int64_t stream = t_n_stream > 1 ? (int64_t) i / t_n_tokens_in_stream : 0;
const int64_t t_in_str = t_n_stream > 1 ? (int64_t) i % t_n_tokens_in_stream : (int64_t) i;
// byte offset into the tensor for this (token, head, stream)
const size_t src_offset = ((stream * t_n_head + head) * t_n_tokens_in_stream + t_in_str) * t_n_kv * sizeof(float);
// destination in the output buffer
float * dst = attn.data + (n_outputs_prev + out_idx) * attn_stride + p * cparams.n_ctx;
ggml_backend_tensor_get_async(backend_attn, t, dst, src_offset, t_n_kv * sizeof(float));
out_idx++;
}
}
}
// Copy backend sampling output if this ubatch produced any sampling tensors.
if (has_samplers && (!res->t_sampled.empty() || !res->t_sampled_probs.empty() || !res->t_sampled_logits.empty())) {
const auto seq_to_output_row = build_seq_to_output_row(ubatch, n_outputs_prev);
@@ -2201,7 +2231,7 @@ ggml_status llama_context::graph_compute(
LLAMA_LOG_ERROR("%s: ggml_backend_sched_graph_compute_async failed with error %d\n", __func__, status);
}
// fprintf(stderr, "splits: %d\n", ggml_backend_sched_get_n_splits(sched));
// fprintf(stderr, "splits: %d\n", ggml_backend_sched_get_n_splits(sched.get()));
return status;
}

View File

@@ -308,6 +308,15 @@ private:
buffer_view<float> attn = {nullptr, 0}; // [n_outputs][n_pairs * n_ctx]
int32_t attn_n_kv = 0; // KV cache length at time of last decode
// State handed to attn_cb_eval_fn through the scheduler's eval-callback
// user_data pointer; the fields are refreshed in decode() before each ubatch.
struct attn_cb_state {
llama_context * ctx = nullptr; // owning context (attn buffer, cparams, head list)
const llama_ubatch * ubatch = nullptr; // ubatch currently being decoded
int64_t n_outputs_prev = 0; // output rows already produced by earlier ubatches
};
attn_cb_state attn_cb;
// scheduler eval callback that extracts attention weights during graph execution
static bool attn_cb_eval_fn(struct ggml_tensor * t, bool ask, void * user_data);
// reuse the batch_allocr to avoid unnecessary memory allocations
std::unique_ptr<llama_batch_allocr> balloc;

View File

@@ -790,11 +790,6 @@ void llm_graph_result::set_outputs() {
ggml_set_output(t);
}
}
for (auto & [layer, t] : t_attn) {
if (t != nullptr) {
ggml_set_output(t);
}
}
for (auto & [seq_id, t] : t_candidates) {
if (t != nullptr) {
ggml_set_output(t);