From 2fbde785bc106ae1c4102b0e82b9b41d9c466579 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Sat, 17 Jan 2026 15:42:42 +0200
Subject: [PATCH] kv-cache : optimize KQ mask construction (#18842)

* kv-cache : optimize KQ mask construction

* cont : add explanation + improve

* cont : fix
---
 src/llama-hparams.cpp  |  36 ------
 src/llama-hparams.h    |  39 +++++-
 src/llama-kv-cache.cpp | 282 +++++++++++++++++++++++++++++++----------
 src/llama-kv-cache.h   |   2 -
 4 files changed, 250 insertions(+), 109 deletions(-)

diff --git a/src/llama-hparams.cpp b/src/llama-hparams.cpp
index c847ef91b7..5f1df995f3 100644
--- a/src/llama-hparams.cpp
+++ b/src/llama-hparams.cpp
@@ -200,42 +200,6 @@ uint32_t llama_hparams::n_layer_kv() const {
     return res;
 }
 
-bool llama_hparams::is_masked_swa(uint32_t n_swa, llama_swa_type swa_type, llama_pos p0, llama_pos p1) {
-    assert(p0 >= 0 && p1 >= 0);
-
-    switch (swa_type) {
-        case LLAMA_SWA_TYPE_NONE:
-            {
-            } break;
-        case LLAMA_SWA_TYPE_STANDARD:
-            {
-                if (p1 - p0 >= (int32_t) n_swa) {
-                    return true;
-                }
-            } break;
-        case LLAMA_SWA_TYPE_CHUNKED:
-            {
-                const llama_pos pos_chunk_start = (p1 / n_swa) * n_swa;
-
-                if (p0 < pos_chunk_start) {
-                    return true;
-                }
-            } break;
-        case LLAMA_SWA_TYPE_SYMMETRIC:
-            {
-                const int32_t half_n_swa = (int32_t) n_swa / 2;
-                const int32_t pos_diff = p1 - p0;
-
-                // Mask if outside the symmetric window
-                if (pos_diff < -half_n_swa || pos_diff > half_n_swa) {
-                    return true;
-                }
-            } break;
-    }
-
-    return false;
-}
-
 bool llama_hparams::use_mrope() const {
     return rope_sections[0] > 0 && rope_sections[1] > 0;
 }
diff --git a/src/llama-hparams.h b/src/llama-hparams.h
index 7ae3ec292e..2bf8665520 100644
--- a/src/llama-hparams.h
+++ b/src/llama-hparams.h
@@ -3,6 +3,7 @@
 #include "llama.h"
 
 #include <array>
+#include <cassert>
 
 // bump if necessary
 #define LLAMA_MAX_LAYERS 512
@@ -274,9 +275,45 @@ struct llama_hparams {
     uint32_t n_layer_kv() const;
 
     // note that this function uses different SWA parameters from those in the hparams
+    // note: inlined on purpose for performance reasons
     // TODO: think of a better place for this function
     // TODO: pack the SWA params in a struct?
-    static bool is_masked_swa(uint32_t n_swa, llama_swa_type swa_type, llama_pos p0, llama_pos p1);
+    static bool is_masked_swa(uint32_t n_swa, llama_swa_type swa_type, llama_pos p0, llama_pos p1) {
+        assert(p0 >= 0 && p1 >= 0);
+
+        switch (swa_type) {
+            case LLAMA_SWA_TYPE_NONE:
+                {
+                } break;
+            case LLAMA_SWA_TYPE_STANDARD:
+                {
+                    if (p1 - p0 >= (int32_t) n_swa) {
+                        return true;
+                    }
+                } break;
+            case LLAMA_SWA_TYPE_CHUNKED:
+                {
+                    const llama_pos pos_chunk_start = (p1 / n_swa) * n_swa;
+
+                    if (p0 < pos_chunk_start) {
+                        return true;
+                    }
+                } break;
+            case LLAMA_SWA_TYPE_SYMMETRIC:
+                {
+                    const int32_t half_n_swa = (int32_t) n_swa / 2;
+                    const int32_t pos_diff = p1 - p0;
+
+                    // Mask if outside the symmetric window
+                    if (pos_diff < -half_n_swa || pos_diff > half_n_swa) {
+                        return true;
+                    }
+                } break;
+        }
+
+        return false;
+    }
+
     bool use_mrope() const;
 };
diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp
index 3186242d60..fd9f97d52e 100644
--- a/src/llama-kv-cache.cpp
+++ b/src/llama-kv-cache.cpp
@@ -852,7 +852,7 @@ llama_kv_cache::slot_info llama_kv_cache::find_slot(const llama_ubatch & ubatch,
                 const llama_seq_id seq_id_cell = cells.seq_get(idx);
 
                 // SWA mask
-                if (is_masked_swa(pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) {
+                if (llama_hparams::is_masked_swa(n_swa, swa_type, pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) {
                     can_use = true;
                 }
             }
@@ -1237,6 +1237,197 @@ void llama_kv_cache::set_input_k_shift(ggml_tensor * dst) const {
     }
 }
 
+struct args_set_input_kq_mask {
+    const llama_hparams & hparams;
+    const llama_ubatch  * ubatch;
+
+    const std::vector<llama_kv_cells_unified> & v_cells;
+    const std::vector<uint32_t>               & seq_to_stream;
+
+    uint32_t       n_swa;
+    llama_swa_type swa_type;
+
+    int64_t n_kv;
+    int64_t n_stream;
+    int64_t n_tps;
+};
+
+template <bool causal, bool swa, bool is_2d, bool alibi>
+static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, float * data) {
+    //const auto & hparams = args.hparams;
+    const auto & ubatch = args.ubatch;
+
+    const auto & v_cells       = args.v_cells;
+    const auto & seq_to_stream = args.seq_to_stream;
+
+    const uint32_t       n_swa    = args.n_swa;
+    const llama_swa_type swa_type = args.swa_type;
+
+    const int64_t n_kv     = args.n_kv;
+    const int64_t n_stream = args.n_stream;
+    const int64_t n_tps    = args.n_tps;
+
+    // the min position in the batch for each sequence
+    llama_pos seq_pos_min[LLAMA_MAX_SEQ];
+    std::fill(seq_pos_min, seq_pos_min + LLAMA_MAX_SEQ, INT32_MAX);
+
+    for (uint32_t i = 0; i < ubatch->n_tokens; ++i) {
+        const llama_seq_id seq_id = ubatch->seq_id[i][0];
+
+        seq_pos_min[seq_id] = std::min(seq_pos_min[seq_id], ubatch->pos[i]);
+    }
+
+    for (uint32_t s = 0; s < n_stream; ++s) {
+        // bookkeeping of the KQ mask cells that could change for other tokens of the same sequence
+        std::unordered_map<llama_seq_id, uint32_t>              seq_srct;
+        std::unordered_map<llama_seq_id, std::vector<uint32_t>> seq_idxs;
+
+        for (uint32_t ii = 0; ii < n_tps; ++ii) {
+            const uint32_t i = s*n_tps + ii;
+
+            const llama_seq_id seq_id = ubatch->seq_id[i][0];
+
+            const auto & cells = v_cells.at(seq_to_stream[seq_id]);
+
+            llama_pos       p0 = -1;
+            const llama_pos p1 = ubatch->pos[i];
+
+            // for M-RoPE
+            const llama_pos p1_x = is_2d ? ubatch->pos[i + ubatch->n_tokens*2] : 0;
+            const llama_pos p1_y = is_2d ? ubatch->pos[i + ubatch->n_tokens] : 0;
+
+            const uint64_t idst = n_kv*i;
+
+            // for tokens of the same sequence, the mask is mostly the same, so we can reuse it
+            // the only cells that could change are the ones with positions similar to the
+            // ones in the batch (i.e. due to causal masking, SWA, etc.)
+            // keep track of those cells and shortcut the loop to save time
+            // note: this optimization is not compatible with Alibi position encoding
+            // ref: https://github.com/ggml-org/llama.cpp/pull/18842
+            bool prev = false;
+
+            auto & idxs = seq_idxs[seq_id];
+
+            if (!alibi) {
+                if (seq_srct.find(seq_id) != seq_srct.end()) {
+                    const uint32_t srct = seq_srct[seq_id];
+
+                    const uint64_t idst_prev = n_kv*srct;
+
+                    std::copy(data + idst_prev, data + idst_prev + n_kv, data + idst);
+
+                    prev = true;
+                } else {
+                    idxs.clear();
+                    idxs.reserve(ubatch->n_tokens + n_swa + 32);
+
+                    seq_srct[seq_id] = i;
+                }
+            }
+
+            for (uint32_t jj = 0; jj < n_kv; ++jj) {
+                uint32_t j = jj;
+
+                // we have an existing mask for this sequence -> update just seq_idxs
+                if (!alibi) {
+                    if (prev) {
+                        if (jj >= idxs.size()) {
+                            break;
+                        }
+
+                        j = idxs[jj];
+                    }
+                }
+
+                if (cells.is_empty(j)) {
+                    goto skip;
+                }
+
+                // mask the token if not the same sequence
+                if (!cells.seq_has(j, seq_id)) {
+                    goto skip;
+                }
+
+                p0 = cells.pos_get(j);
+
+                if (!alibi) {
+                    if (!prev) {
+                        // record all cells for which: p0 >= seq_pos_min[seq_id] - n_swa - 32
+                        if (p0 + (int32_t) (n_swa + 32) >= seq_pos_min[seq_id]) {
+                            idxs.push_back(j);
+                        }
+                    }
+                }
+
+                if (causal) {
+                    // mask future tokens
+                    if (p0 > p1) {
+                        goto skip;
+                    }
+
+                    // M-RoPE causal mask
+                    if (is_2d) {
+                        if (p0 == p1) {
+                            const auto & p0_ext = cells.ext_get(j);
+
+                            if (p0_ext.is_2d_gt(p1_x, p1_y)) {
+                                goto skip;
+                            }
+                        }
+                    }
+                }
+
+                // apply SWA if any
+                if (swa) {
+                    if (llama_hparams::is_masked_swa(n_swa, swa_type, p0, p1)) {
+                        goto skip;
+                    }
+                }
+
+                if (alibi) {
+                    data[idst + j] = -std::abs(p0 - p1);
+                } else {
+                    data[idst + j] = 0.0f;
+                }
+
+                continue;
+skip:
+                data[idst + j] = -INFINITY;
+            }
+        }
+    }
+}
+
+template <bool causal, bool swa, bool is_2d>
+static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, float * data) {
+    const bool alibi = args.hparams.use_alibi;
+    if (alibi) {
+        set_input_kq_mask_impl<causal, swa, is_2d, true>(args, data);
+    } else {
+        set_input_kq_mask_impl<causal, swa, is_2d, false>(args, data);
+    }
+}
+
+template <bool causal, bool swa>
+static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, float * data) {
+    const bool is_2d = args.ubatch->is_pos_2d();
+    if (is_2d) {
+        set_input_kq_mask_impl<causal, swa, true>(args, data);
+    } else {
+        set_input_kq_mask_impl<causal, swa, false>(args, data);
+    }
+}
+
+template <bool causal>
+static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, float * data) {
+    const bool swa = args.swa_type != LLAMA_SWA_TYPE_NONE;
+    if (swa) {
+        set_input_kq_mask_impl<causal, true>(args, data);
+    } else {
+        set_input_kq_mask_impl<causal, false>(args, data);
+    }
+}
+
 void llama_kv_cache::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
     const uint32_t n_tokens = ubatch->n_tokens;
 
@@ -1251,74 +1442,29 @@ void llama_kv_cache::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * u
     // n_tps == n_tokens_per_stream
     const int64_t n_tps = n_tokens/n_stream;
 
-    std::fill(data, data + ggml_nelements(dst), -INFINITY);
+    //const int64_t t_start = ggml_time_us();
 
-    // Use only the previous KV cells of the correct sequence for each token of the ubatch.
-    // It's assumed that if a token in the batch has multiple sequences, they are equivalent.
-    // Example with a cache of 10 tokens, 2 tokens populated in cache and 3 tokens in batch:
-    //   Causal mask:
-    //      xxx-------
-    //      xxxx------
-    //      xxxxx-----
-    //   Non-causal mask:
-    //      xxxxx-----
-    //      xxxxx-----
-    //      xxxxx-----
-    //   To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615
-    // TODO: optimize this section
-    for (uint32_t h = 0; h < 1; ++h) {
-        for (uint32_t s = 0; s < n_stream; ++s) {
-            for (uint32_t ii = 0; ii < n_tps; ++ii) {
-                const uint32_t i = s*n_tps + ii;
-
+    const args_set_input_kq_mask args = {
+        /*.hparams       =*/ hparams,
+        /*.ubatch        =*/ ubatch,
+        /*.v_cells       =*/ v_cells,
+        /*.seq_to_stream =*/ seq_to_stream,
+        /*.n_swa         =*/ n_swa,
+        /*.swa_type      =*/ swa_type,
+        /*.n_kv          =*/ n_kv,
+        /*.n_stream      =*/ n_stream,
+        /*.n_tps         =*/ n_tps,
+    };
-                const llama_seq_id seq_id = ubatch->seq_id[i][0];
-
-                const auto & cells = v_cells[seq_to_stream[seq_id]];
-
-                const llama_pos p1 = ubatch->pos[i];
-
-                // for M-RoPE
-                const bool is_2d = ubatch->is_pos_2d();
-                const llama_pos p1_x = is_2d ? ubatch->pos[i + ubatch->n_tokens*2] : 0;
-                const llama_pos p1_y = is_2d ? ubatch->pos[i + ubatch->n_tokens] : 0;
-
-                const uint64_t idst = n_kv*(h*n_stream*n_tps + s*n_tps + ii);
-
-                for (uint32_t j = 0; j < n_kv; ++j) {
-                    if (cells.is_empty(j)) {
-                        continue;
-                    }
-
-                    // mask the token if not the same sequence
-                    if (!cells.seq_has(j, seq_id)) {
-                        continue;
-                    }
-
-                    const llama_pos p0 = cells.pos_get(j);
-
-                    // mask future tokens
-                    if (causal_attn && p0 > p1) {
-                        continue;
-                    }
-
-                    // M-RoPE causal mask
-                    if (causal_attn && is_2d && p0 == p1) {
-                        const auto & p0_ext = cells.ext_get(j);
-                        if (p0_ext.is_2d_gt(p1_x, p1_y)) {
-                            continue;
-                        }
-                    }
-
-                    // apply SWA if any
-                    if (is_masked_swa(p0, p1)) {
-                        continue;
-                    }
-
-                    data[idst + j] = hparams.use_alibi ? -std::abs(p0 - p1) : 0.0f;
-                }
-            }
-        }
+
+    if (causal_attn) {
+        set_input_kq_mask_impl<true>(args, data);
+    } else {
+        set_input_kq_mask_impl<false>(args, data);
     }
+
+    //const int64_t t_end = ggml_time_us();
+
+    //LLAMA_LOG_ERROR("%s: kq mask time: %0.3f ms\n", __func__, (t_end - t_start)/1000.0);
 }
 
 void llama_kv_cache::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const {
@@ -1483,10 +1629,6 @@ ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_co
     return gf;
 }
 
-bool llama_kv_cache::is_masked_swa(llama_pos p0, llama_pos p1) const {
-    return llama_hparams::is_masked_swa(n_swa, swa_type, p0, p1);
-}
-
 void llama_kv_cache::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
     GGML_UNUSED(flags);
 
diff --git a/src/llama-kv-cache.h b/src/llama-kv-cache.h
index 0c4ed64845..e194bf3e26 100644
--- a/src/llama-kv-cache.h
+++ b/src/llama-kv-cache.h
@@ -257,8 +257,6 @@ private:
     size_t size_k_bytes() const;
     size_t size_v_bytes() const;
 
-    bool is_masked_swa(llama_pos p0, llama_pos p1) const;
-
     ggml_tensor * build_rope_shift(
             const llama_cparams & cparams,
                    ggml_context * ctx,
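
Illustration of the optimization (not part of the patch): within one sequence, consecutive batch tokens see almost the same mask row. Only cells whose positions lie within roughly n_swa + 32 of the smallest batch position for that sequence can flip between masked and unmasked from one token to the next, so the patch computes the row fully for the first token of the sequence (tracked in seq_srct), records those "hot" cells in seq_idxs, and for every later token copies that row and re-evaluates only the hot cells. The following standalone sketch reproduces the idea for a single sequence with standard SWA; the sizes and the helper names (cell_pos, tok_pos, masked) are made up for the example and are not llama.cpp API.

// Standalone illustration of the row-reuse idea (not llama.cpp code): build a
// causal+SWA KQ mask for a toy single-sequence cache twice -- once with the
// straightforward per-token scan over all cells, and once by copying the first
// token's row and re-checking only the cells whose visibility can differ
// between batch tokens -- then verify that both produce the same mask.
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <cstdio>
#include <limits>
#include <vector>

int main() {
    const float NEG_INF = -std::numeric_limits<float>::infinity();

    const int32_t n_kv     = 64;  // cells in the toy cache, cell j holds position j
    const int32_t n_tokens = 8;   // tokens in the toy batch, positions continue after the cache
    const int32_t n_swa    = 16;  // standard sliding-window size

    std::vector<int32_t> cell_pos(n_kv);
    for (int32_t j = 0; j < n_kv; ++j) {
        cell_pos[j] = j;
    }

    std::vector<int32_t> tok_pos(n_tokens);
    for (int32_t i = 0; i < n_tokens; ++i) {
        tok_pos[i] = n_kv + i;
    }

    // causal + standard SWA masking rule
    auto masked = [&](int32_t p0, int32_t p1) {
        return p0 > p1 || p1 - p0 >= n_swa;
    };

    // reference: recompute every cell for every token
    std::vector<float> ref(n_tokens*n_kv, NEG_INF);
    for (int32_t i = 0; i < n_tokens; ++i) {
        for (int32_t j = 0; j < n_kv; ++j) {
            if (!masked(cell_pos[j], tok_pos[i])) {
                ref[i*n_kv + j] = 0.0f;
            }
        }
    }

    // optimized: scan fully only for the first token and record the "hot" cells,
    // then copy that row and re-check just the hot cells for the remaining tokens
    std::vector<float> opt(n_tokens*n_kv, NEG_INF);

    const int32_t pos_min = tok_pos[0];  // min batch position of the sequence
    std::vector<int32_t> idxs;           // cells whose mask value can still change

    for (int32_t i = 0; i < n_tokens; ++i) {
        if (i == 0) {
            for (int32_t j = 0; j < n_kv; ++j) {
                if (cell_pos[j] + n_swa + 32 >= pos_min) {
                    idxs.push_back(j);  // everything else is masked identically for all tokens
                }
                if (!masked(cell_pos[j], tok_pos[i])) {
                    opt[i*n_kv + j] = 0.0f;
                }
            }
        } else {
            std::copy(&opt[0], &opt[0] + n_kv, &opt[i*n_kv]);
            for (const int32_t j : idxs) {
                opt[i*n_kv + j] = masked(cell_pos[j], tok_pos[i]) ? NEG_INF : 0.0f;
            }
        }
    }

    assert(ref == opt);
    std::printf("masks match: %d rows x %d cells, %zu hot cells per row\n",
            (int) n_tokens, (int) n_kv, idxs.size());
    return 0;
}

With this structure the per-token cost drops from checking all n_kv cells to one row copy plus a check of the bookkept cells. The reuse is skipped for Alibi, where the stored value is the position-dependent bias -|p0 - p1| rather than just 0 or -INFINITY, so every cell genuinely differs per token; that is why the patch dispatches on the alibi template parameter and keeps the full scan in that case.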