llama : extract attention weights via a scheduler eval callback

Replace the post-decode attention-weight extraction loop in
llama_context::decode() with a ggml_backend_sched eval callback
(attn_cb_eval_fn). The callback intercepts the per-layer
"kq_soft_max-<il>" tensors while the graph is still being evaluated, so
llm_graph_result no longer has to keep t_attn outputs alive, and the
schedule no longer needs a re-reserve when attention heads change.

Notes from review:
- restored template arguments that had been garbled out of the patch
  text: static_cast<attn_cb_state *>(user_data) in the new callback, and
  the context line std::unique_ptr<llama_batch_allocr> balloc; (the hunk
  cannot apply without the latter matching the target file)
---
diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index 9a16c28c37..7c66f9a1d3 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -804,7 +804,6 @@ void llama_context::set_attn_heads(const int32_t * layers, const int32_t * heads
         cparams.attn_layers.insert(layers[i]);
     }
 
-    sched_need_reserve = true;
 }
 
 float * llama_context::get_attn_ith(int32_t i) {
@@ -828,6 +827,85 @@ int32_t llama_context::get_attn_n_kv() const {
     return attn_n_kv;
 }
 
+// static
+// scheduler eval callback used while attention-weight extraction is enabled:
+// intercepts the "kq_soft_max-<il>" tensors of the requested layers and copies
+// the selected heads' rows into ctx->attn.data as the graph is evaluated;
+// everything else is forwarded to the user-provided cb_eval callback (if any)
+bool llama_context::attn_cb_eval_fn(struct ggml_tensor * t, bool ask, void * user_data) {
+    auto * state = static_cast<attn_cb_state *>(user_data);
+    auto * ctx = state->ctx;
+
+    const auto user_cb = ctx->cparams.cb_eval;
+    const auto user_data_cb = ctx->cparams.cb_eval_user_data;
+
+    const char * name = t->name;
+    if (strncmp(name, "kq_soft_max-", 12) != 0) {
+        return user_cb ? user_cb(t, ask, user_data_cb) : false;
+    }
+
+    const int layer_idx = atoi(name + 12);
+
+    if (ctx->cparams.attn_layers.find(layer_idx) == ctx->cparams.attn_layers.end()) {
+        return user_cb ? user_cb(t, ask, user_data_cb) : false;
+    }
+
+    if (ask) {
+        if (user_cb) {
+            user_cb(t, true, user_data_cb);
+        }
+        return true;
+    }
+
+    // data is available, extract attention weights
+    const auto & ubatch = *state->ubatch;
+    const auto & attn_heads_vec = ctx->attn_heads;
+    const size_t n_pairs = attn_heads_vec.size();
+    const size_t attn_stride = n_pairs * ctx->cparams.n_ctx;
+
+    const int64_t t_n_kv = t->ne[0];
+    const int64_t t_n_tokens = t->ne[1];
+    const int64_t t_n_head = t->ne[2];
+    const int64_t t_n_stream = t->ne[3];
+
+    ctx->attn_n_kv = (int32_t) t_n_kv;
+
+    for (size_t p = 0; p < n_pairs; p++) {
+        if (attn_heads_vec[p].layer != layer_idx) {
+            continue;
+        }
+
+        const int head = attn_heads_vec[p].head;
+        if (head >= t_n_head) {
+            continue;
+        }
+
+        int32_t out_idx = 0;
+        for (uint32_t i = 0; i < ubatch.n_tokens; i++) {
+            if (!ubatch.output[i]) {
+                continue;
+            }
+
+            const int64_t stream = t_n_stream > 1 ? (int64_t) i / t_n_tokens : 0;
+            const int64_t t_in_str = t_n_stream > 1 ? (int64_t) i % t_n_tokens : (int64_t) i;
+
+            const size_t src_offset = stream * t->nb[3] + head * t->nb[2] + t_in_str * t->nb[1];
+
+            float * dst = ctx->attn.data + (state->n_outputs_prev + out_idx) * attn_stride + p * ctx->cparams.n_ctx;
+
+            ggml_backend_tensor_get(t, dst, src_offset, t_n_kv * sizeof(float));
+
+            out_idx++;
+        }
+    }
+
+    if (user_cb) {
+        return user_cb(t, false, user_data_cb);
+    }
+
+    return true;
+}
+
 llama_token llama_context::get_sampled_token_ith(int32_t idx) {
     output_reorder();
 
@@ -1146,7 +1224,12 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll
     res->reset();
 
     ggml_backend_sched_reset(sched.get());
-    ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
+
+    if (attn.data && cparams.attn_weights && !attn_heads.empty()) {
+        ggml_backend_sched_set_eval_callback(sched.get(), attn_cb_eval_fn, &attn_cb);
+    } else {
+        ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
+    }
 
     //const auto t_start_us = ggml_time_us();
 
@@ -1633,6 +1716,10 @@ int llama_context::decode(const llama_batch & batch_inp) {
             n_outputs = n_outputs_new;
         }
 
+        attn_cb.ctx = this;
+        attn_cb.ubatch = &ubatch;
+        attn_cb.n_outputs_prev = n_outputs_prev;
+
         ggml_status status;
         const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status);
 
@@ -1750,59 +1837,6 @@ int llama_context::decode(const llama_batch & batch_inp) {
             }
         }
 
-        // extract attention weights
-        if (attn.data && cparams.attn_weights && n_outputs > 0 && !attn_heads.empty()) {
-            const size_t n_pairs = attn_heads.size();
-            const size_t attn_stride = n_pairs * cparams.n_ctx; // stride per output token in the attn buffer
-
-            // iterate over requested (layer, head) pairs
-            for (size_t p = 0; p < n_pairs; p++) {
-                const int layer = attn_heads[p].layer;
-                const int head = attn_heads[p].head;
-
-                auto it = res->t_attn.find(layer);
-                if (it == res->t_attn.end() || it->second == nullptr) {
-                    continue;
-                }
-
-                ggml_tensor * t = it->second;
-                ggml_backend_t backend_attn = ggml_backend_sched_get_tensor_backend(sched.get(), t);
-                GGML_ASSERT(backend_attn != nullptr);
-
-                // t shape: [n_kv, n_tokens_in_stream, n_head, n_stream]
-                const int64_t t_n_kv = t->ne[0];
-                const int64_t t_n_tokens_in_stream = t->ne[1];
-                const int64_t t_n_head = t->ne[2];
-                const int64_t t_n_stream = t->ne[3];
-
-                attn_n_kv = (int32_t) t_n_kv;
-
-                GGML_ASSERT(head < t_n_head);
-
-                // extract attention for each output token
-                int32_t out_idx = 0;
-                for (uint32_t i = 0; i < ubatch.n_tokens; i++) {
-                    if (!ubatch.output[i]) {
-                        continue;
-                    }
-
-                    // compute token position within stream
-                    const int64_t stream = t_n_stream > 1 ? (int64_t) i / t_n_tokens_in_stream : 0;
-                    const int64_t t_in_str = t_n_stream > 1 ? (int64_t) i % t_n_tokens_in_stream : (int64_t) i;
-
-                    // byte offset into the tensor for this (token, head, stream)
-                    const size_t src_offset = ((stream * t_n_head + head) * t_n_tokens_in_stream + t_in_str) * t_n_kv * sizeof(float);
-
-                    // destination in the output buffer
-                    float * dst = attn.data + (n_outputs_prev + out_idx) * attn_stride + p * cparams.n_ctx;
-
-                    ggml_backend_tensor_get_async(backend_attn, t, dst, src_offset, t_n_kv * sizeof(float));
-
-                    out_idx++;
-                }
-            }
-        }
-
         // Copy backend sampling output if this ubatch produced any sampling tensors.
         if (has_samplers && (!res->t_sampled.empty() || !res->t_sampled_probs.empty() || !res->t_sampled_logits.empty())) {
             const auto seq_to_output_row = build_seq_to_output_row(ubatch, n_outputs_prev);
@@ -2201,7 +2235,7 @@ ggml_status llama_context::graph_compute(
         LLAMA_LOG_ERROR("%s: ggml_backend_sched_graph_compute_async failed with error %d\n", __func__, status);
     }
 
-    // fprintf(stderr, "splits: %d\n", ggml_backend_sched_get_n_splits(sched));
+    // fprintf(stderr, "splits: %d\n", ggml_backend_sched_get_n_splits(sched.get()));
 
     return status;
 }
diff --git a/src/llama-context.h b/src/llama-context.h
index 78a4f3a637..034e6739f6 100644
--- a/src/llama-context.h
+++ b/src/llama-context.h
@@ -308,6 +308,17 @@ private:
     buffer_view attn = {nullptr, 0}; // [n_outputs][n_pairs * n_ctx]
     int32_t attn_n_kv = 0; // KV cache length at time of last decode
 
+    // state handed to attn_cb_eval_fn through the scheduler eval callback;
+    // refreshed at the start of each decoded ubatch
+    struct attn_cb_state {
+        llama_context * ctx = nullptr;
+        const llama_ubatch * ubatch = nullptr;
+        int64_t n_outputs_prev = 0;
+    };
+    attn_cb_state attn_cb;
+
+    static bool attn_cb_eval_fn(struct ggml_tensor * t, bool ask, void * user_data);
+
     // reuse the batch_allocr to avoid unnecessary memory allocations
     std::unique_ptr<llama_batch_allocr> balloc;
 
diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp
index add8dd4ad7..6ade0030cd 100644
--- a/src/llama-graph.cpp
+++ b/src/llama-graph.cpp
@@ -790,11 +790,6 @@ void llm_graph_result::set_outputs() {
             ggml_set_output(t);
         }
     }
-    for (auto & [layer, t] : t_attn) {
-        if (t != nullptr) {
-            ggml_set_output(t);
-        }
-    }
     for (auto & [seq_id, t] : t_candidates) {
         if (t != nullptr) {
             ggml_set_output(t);