Use internal cb_eval for attention extraction to eliminate graph splits
Instead of marking attention tensors as output (which caused ~70 graph splits and +28% overhead), extract attention weights via an internal cb_eval callback during graph execution. This reduces graph splits back to baseline (2) with ~0% overhead while keeping the clean public API (llama_set_attn_heads + llama_get_attn_ith).

Changes:
- Remove ggml_set_output() on t_attn tensors in set_outputs()
- Add attn_cb_eval_fn that intercepts kq_soft_max tensors and copies attention data using ggml_backend_tensor_get with tensor strides
- Remove post-decode attention extraction loop (now done during execution)
- Remove sched_need_reserve from set_attn_heads (graph topology unchanged)
- Chain to user's cb_eval callback when both are active
This commit is contained in:
parent
14bf6d45bb
commit
b550fa6e18
|
|
@ -804,7 +804,6 @@ void llama_context::set_attn_heads(const int32_t * layers, const int32_t * heads
|
|||
cparams.attn_layers.insert(layers[i]);
|
||||
}
|
||||
|
||||
sched_need_reserve = true;
|
||||
}
|
||||
|
||||
float * llama_context::get_attn_ith(int32_t i) {
|
||||
|
|
@ -828,6 +827,81 @@ int32_t llama_context::get_attn_n_kv() const {
|
|||
return attn_n_kv;
|
||||
}
|
||||
|
||||
// static
|
||||
bool llama_context::attn_cb_eval_fn(struct ggml_tensor * t, bool ask, void * user_data) {
|
||||
auto * state = static_cast<attn_cb_state *>(user_data);
|
||||
auto * ctx = state->ctx;
|
||||
|
||||
const auto user_cb = ctx->cparams.cb_eval;
|
||||
const auto user_data_cb = ctx->cparams.cb_eval_user_data;
|
||||
|
||||
const char * name = t->name;
|
||||
if (strncmp(name, "kq_soft_max-", 12) != 0) {
|
||||
return user_cb ? user_cb(t, ask, user_data_cb) : false;
|
||||
}
|
||||
|
||||
const int layer_idx = atoi(name + 12);
|
||||
|
||||
if (ctx->cparams.attn_layers.find(layer_idx) == ctx->cparams.attn_layers.end()) {
|
||||
return user_cb ? user_cb(t, ask, user_data_cb) : false;
|
||||
}
|
||||
|
||||
if (ask) {
|
||||
if (user_cb) {
|
||||
user_cb(t, true, user_data_cb);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// data is available, extract attention weights
|
||||
const auto & ubatch = *state->ubatch;
|
||||
const auto & attn_heads_vec = ctx->attn_heads;
|
||||
const size_t n_pairs = attn_heads_vec.size();
|
||||
const size_t attn_stride = n_pairs * ctx->cparams.n_ctx;
|
||||
|
||||
const int64_t t_n_kv = t->ne[0];
|
||||
const int64_t t_n_tokens = t->ne[1];
|
||||
const int64_t t_n_head = t->ne[2];
|
||||
const int64_t t_n_stream = t->ne[3];
|
||||
|
||||
ctx->attn_n_kv = (int32_t) t_n_kv;
|
||||
|
||||
for (size_t p = 0; p < n_pairs; p++) {
|
||||
if (attn_heads_vec[p].layer != layer_idx) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const int head = attn_heads_vec[p].head;
|
||||
if (head >= t_n_head) {
|
||||
continue;
|
||||
}
|
||||
|
||||
int32_t out_idx = 0;
|
||||
for (uint32_t i = 0; i < ubatch.n_tokens; i++) {
|
||||
if (!ubatch.output[i]) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const int64_t stream = t_n_stream > 1 ? (int64_t) i / t_n_tokens : 0;
|
||||
const int64_t t_in_str = t_n_stream > 1 ? (int64_t) i % t_n_tokens : (int64_t) i;
|
||||
|
||||
const size_t src_offset = stream * t->nb[3] + head * t->nb[2] + t_in_str * t->nb[1];
|
||||
|
||||
float * dst = ctx->attn.data + (state->n_outputs_prev + out_idx) * attn_stride + p * ctx->cparams.n_ctx;
|
||||
|
||||
ggml_backend_tensor_get(t, dst, src_offset, t_n_kv * sizeof(float));
|
||||
|
||||
out_idx++;
|
||||
}
|
||||
}
|
||||
|
||||
if (user_cb) {
|
||||
return user_cb(t, false, user_data_cb);
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
llama_token llama_context::get_sampled_token_ith(int32_t idx) {
|
||||
output_reorder();
|
||||
|
||||
|
|
@ -1146,7 +1220,12 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll
|
|||
res->reset();
|
||||
|
||||
ggml_backend_sched_reset(sched.get());
|
||||
ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
|
||||
|
||||
if (attn.data && cparams.attn_weights && !attn_heads.empty()) {
|
||||
ggml_backend_sched_set_eval_callback(sched.get(), attn_cb_eval_fn, &attn_cb);
|
||||
} else {
|
||||
ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
|
||||
}
|
||||
|
||||
//const auto t_start_us = ggml_time_us();
|
||||
|
||||
|
|
@ -1633,6 +1712,10 @@ int llama_context::decode(const llama_batch & batch_inp) {
|
|||
n_outputs = n_outputs_new;
|
||||
}
|
||||
|
||||
attn_cb.ctx = this;
|
||||
attn_cb.ubatch = &ubatch;
|
||||
attn_cb.n_outputs_prev = n_outputs_prev;
|
||||
|
||||
ggml_status status;
|
||||
const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status);
|
||||
|
||||
|
|
@ -1750,59 +1833,6 @@ int llama_context::decode(const llama_batch & batch_inp) {
|
|||
}
|
||||
}
|
||||
|
||||
// extract attention weights
|
||||
if (attn.data && cparams.attn_weights && n_outputs > 0 && !attn_heads.empty()) {
|
||||
const size_t n_pairs = attn_heads.size();
|
||||
const size_t attn_stride = n_pairs * cparams.n_ctx; // stride per output token in the attn buffer
|
||||
|
||||
// iterate over requested (layer, head) pairs
|
||||
for (size_t p = 0; p < n_pairs; p++) {
|
||||
const int layer = attn_heads[p].layer;
|
||||
const int head = attn_heads[p].head;
|
||||
|
||||
auto it = res->t_attn.find(layer);
|
||||
if (it == res->t_attn.end() || it->second == nullptr) {
|
||||
continue;
|
||||
}
|
||||
|
||||
ggml_tensor * t = it->second;
|
||||
ggml_backend_t backend_attn = ggml_backend_sched_get_tensor_backend(sched.get(), t);
|
||||
GGML_ASSERT(backend_attn != nullptr);
|
||||
|
||||
// t shape: [n_kv, n_tokens_in_stream, n_head, n_stream]
|
||||
const int64_t t_n_kv = t->ne[0];
|
||||
const int64_t t_n_tokens_in_stream = t->ne[1];
|
||||
const int64_t t_n_head = t->ne[2];
|
||||
const int64_t t_n_stream = t->ne[3];
|
||||
|
||||
attn_n_kv = (int32_t) t_n_kv;
|
||||
|
||||
GGML_ASSERT(head < t_n_head);
|
||||
|
||||
// extract attention for each output token
|
||||
int32_t out_idx = 0;
|
||||
for (uint32_t i = 0; i < ubatch.n_tokens; i++) {
|
||||
if (!ubatch.output[i]) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// compute token position within stream
|
||||
const int64_t stream = t_n_stream > 1 ? (int64_t) i / t_n_tokens_in_stream : 0;
|
||||
const int64_t t_in_str = t_n_stream > 1 ? (int64_t) i % t_n_tokens_in_stream : (int64_t) i;
|
||||
|
||||
// byte offset into the tensor for this (token, head, stream)
|
||||
const size_t src_offset = ((stream * t_n_head + head) * t_n_tokens_in_stream + t_in_str) * t_n_kv * sizeof(float);
|
||||
|
||||
// destination in the output buffer
|
||||
float * dst = attn.data + (n_outputs_prev + out_idx) * attn_stride + p * cparams.n_ctx;
|
||||
|
||||
ggml_backend_tensor_get_async(backend_attn, t, dst, src_offset, t_n_kv * sizeof(float));
|
||||
|
||||
out_idx++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Copy backend sampling output if this ubatch produced any sampling tensors.
|
||||
if (has_samplers && (!res->t_sampled.empty() || !res->t_sampled_probs.empty() || !res->t_sampled_logits.empty())) {
|
||||
const auto seq_to_output_row = build_seq_to_output_row(ubatch, n_outputs_prev);
|
||||
|
|
@ -2201,7 +2231,7 @@ ggml_status llama_context::graph_compute(
|
|||
LLAMA_LOG_ERROR("%s: ggml_backend_sched_graph_compute_async failed with error %d\n", __func__, status);
|
||||
}
|
||||
|
||||
// fprintf(stderr, "splits: %d\n", ggml_backend_sched_get_n_splits(sched));
|
||||
// fprintf(stderr, "splits: %d\n", ggml_backend_sched_get_n_splits(sched.get()));
|
||||
|
||||
return status;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -308,6 +308,15 @@ private:
|
|||
buffer_view<float> attn = {nullptr, 0}; // [n_outputs][n_pairs * n_ctx]
|
||||
int32_t attn_n_kv = 0; // KV cache length at time of last decode
|
||||
|
||||
struct attn_cb_state {
|
||||
llama_context * ctx = nullptr;
|
||||
const llama_ubatch * ubatch = nullptr;
|
||||
int64_t n_outputs_prev = 0;
|
||||
};
|
||||
attn_cb_state attn_cb;
|
||||
|
||||
static bool attn_cb_eval_fn(struct ggml_tensor * t, bool ask, void * user_data);
|
||||
|
||||
// reuse the batch_allocr to avoid unnecessary memory allocations
|
||||
std::unique_ptr<llama_batch_allocr> balloc;
|
||||
|
||||
|
|
|
|||
|
|
@ -790,11 +790,6 @@ void llm_graph_result::set_outputs() {
|
|||
ggml_set_output(t);
|
||||
}
|
||||
}
|
||||
for (auto & [layer, t] : t_attn) {
|
||||
if (t != nullptr) {
|
||||
ggml_set_output(t);
|
||||
}
|
||||
}
|
||||
for (auto & [seq_id, t] : t_candidates) {
|
||||
if (t != nullptr) {
|
||||
ggml_set_output(t);
|
||||
|
|
|
|||
Loading…
Reference in New Issue