graph : fix KQ mask, lora, cvec reuse checks (#19644)

* graph : fix KQ mask reuse condition

* cont : dedup KQ mask build and can_reuse

* cont : fix build

* graph : fix adapter check for reuse
Georgi Gerganov 2026-02-16 09:21:11 +02:00 committed by GitHub
parent 267ba5a1d9
commit d5dfc33027
4 changed files with 67 additions and 54 deletions
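
Note (not part of the diff): the reuse bug fixed here comes from the KQ mask being built as a 4-D tensor of shape (n_kv, n_tokens/n_stream, 1, n_stream) while the old can_reuse checks only compared ne[0] against n_kv and ne[1] against the full n_tokens. With a non-unified KV cache (n_stream = n_seqs_unq > 1) those two checks do not pin down the mask layout, so a graph built for one stream configuration could be reused for another. The stand-alone sketch below reproduces the two conditions with stand-in structs (fake_tensor, fake_ubatch, fake_cparams are illustrative, not llama.cpp types) and shows one case where the old check accepts a stale mask while a can_reuse_kq_mask-style check rejects it:

#include <cstdint>
#include <cstdio>

// stand-ins for the real llama.cpp types - just enough to compare the two checks
struct fake_tensor  { int64_t ne[4]; };
struct fake_ubatch  { int64_t n_tokens; int64_t n_seqs_unq; };
struct fake_cparams { bool kv_unified; };

// shape used by build_kq_mask in this commit: (n_kv, n_tokens/n_stream, 1, n_stream)
static fake_tensor build_mask(int64_t n_kv, const fake_ubatch & ub, const fake_cparams & cp) {
    const int64_t n_stream = cp.kv_unified ? 1 : ub.n_seqs_unq;
    return { { n_kv, ub.n_tokens/n_stream, 1, n_stream } };
}

// old reuse condition: first two dims only, with ne[1] compared to the full n_tokens
static bool can_reuse_old(const fake_tensor & m, int64_t n_kv, const fake_ubatch & ub) {
    return m.ne[0] == n_kv && m.ne[1] == ub.n_tokens;
}

// new reuse condition (mirrors can_reuse_kq_mask): all four dims, per-stream token count
static bool can_reuse_new(const fake_tensor & m, int64_t n_kv, const fake_ubatch & ub, const fake_cparams & cp) {
    const int64_t n_stream = cp.kv_unified ? 1 : ub.n_seqs_unq;
    return m.ne[0] == n_kv && m.ne[1] == ub.n_tokens/n_stream && m.ne[2] == 1 && m.ne[3] == n_stream;
}

int main() {
    const fake_cparams cp = { /*kv_unified =*/ false };

    // mask built for 4 tokens spread over 2 streams: ne = (32, 2, 1, 2)
    const fake_ubatch ub_built = { /*n_tokens =*/ 4, /*n_seqs_unq =*/ 2 };
    const fake_tensor mask     = build_mask(/*n_kv =*/ 32, ub_built, cp);

    // next ubatch: 2 tokens in a single stream - same ne[1] value, different layout
    const fake_ubatch ub_next = { /*n_tokens =*/ 2, /*n_seqs_unq =*/ 1 };

    printf("old check: %d\n", (int) can_reuse_old(mask, 32, ub_next));     // 1 -> stale mask reused
    printf("new check: %d\n", (int) can_reuse_new(mask, 32, ub_next, cp)); // 0 -> graph rebuilt
    return 0;
}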


@@ -39,6 +39,8 @@ private:
std::vector<ggml_tensor *> tensors; // per layer
};
+using llama_adapter_cvec_ptr = std::shared_ptr<llama_adapter_cvec>;
//
// llama_adapter_lora
//
@@ -84,3 +86,4 @@ struct llama_adapter_lora {
};
using llama_adapter_loras = std::unordered_map<llama_adapter_lora *, float>;
+using llama_adapter_loras_ptr = std::unique_ptr<llama_adapter_loras>;


@@ -22,6 +22,8 @@ llama_context::llama_context(
const llama_model & model,
llama_context_params params) :
model(model),
+cvec(std::make_unique<llama_adapter_cvec>()),
+loras(std::make_unique<llama_adapter_loras>()),
balloc(std::make_unique<llama_batch_allocr>(model.hparams.n_pos_per_embd())) {
// TODO warning when creating llama_context with awkward ctx size that is not a power of 2,
// may need to be backend-dependent
@@ -1065,11 +1067,11 @@ void llama_context::set_adapters_lora(llama_adapter_lora ** adapters, size_t n_a
return;
}
-loras.clear();
+loras.reset(new llama_adapter_loras());
for (size_t i = 0; i < n_adapters; i ++) {
if (scales[i] != 0.0f) {
-loras[adapters[i]] = scales[i];
+loras->insert({adapters[i], scales[i]});
}
}
@@ -1079,14 +1081,14 @@ void llama_context::set_adapters_lora(llama_adapter_lora ** adapters, size_t n_a
bool llama_context::adapters_lora_are_same(llama_adapter_lora ** adapters, size_t n_adapters, float * scales) {
LLAMA_LOG_DEBUG("%s: adapters = %p\n", __func__, (void *) adapters);
-if (n_adapters != loras.size()) {
+if (n_adapters != loras->size()) {
return false;
}
for (size_t i = 0; i < n_adapters; i ++) {
-auto it = loras.find(adapters[i]);
+auto it = loras->find(adapters[i]);
-if (it == loras.end() || it->second != scales[i]) {
+if (it == loras->end() || it->second != scales[i]) {
return false;
}
}
@@ -1104,7 +1106,7 @@ bool llama_context::set_adapter_cvec(
// TODO: should we reserve?
-return cvec.apply(model, data, len, n_embd, il_start, il_end);
+return cvec->apply(model, data, len, n_embd, il_start, il_end);
}
llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret) {
@@ -2081,8 +2083,8 @@ llm_graph_params llama_context::graph_params(
/*.gtype =*/ gtype,
/*.sched =*/ sched.get(),
/*.backend_cpu =*/ backend_cpu,
-/*.cvec =*/ &cvec,
-/*.loras =*/ &loras,
+/*.cvec =*/ cvec.get(),
+/*.loras =*/ loras.get(),
/*.mctx =*/ mctx,
/*.cross =*/ &cross,
/*.samplers =*/ sampling.samplers,
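
Note on the adapter part of the fix (reasoning added here, not stated in the diff): graph_params used to pass &cvec and &loras, the addresses of llama_context members, which stay constant even when set_adapters_lora swaps the adapter set, so a pointer-based reuse check could never notice the change. Holding the adapters behind smart pointers and resetting loras in set_adapters_lora makes the pointer handed to the graph params change together with the contents. A minimal sketch of that idea, with stand-in types (context_like, graph_params_like and same_adapters are illustrative names, and the pointer comparison is an assumption about how the reuse path consumes these fields, not the real llama.cpp API):

#include <memory>
#include <unordered_map>

// illustrative stand-ins - the real types live in llama-adapter.h / llama-context.h
struct lora_adapter {};
using  lora_map = std::unordered_map<lora_adapter *, float>;

struct graph_params_like {
    const lora_map * loras; // raw, non-owning view, as in the /*.loras =*/ field above

    // assumed shape of the reuse comparison: pointer identity
    bool same_adapters(const graph_params_like & other) const {
        return loras == other.loras;
    }
};

struct context_like {
    std::unique_ptr<lora_map> loras = std::make_unique<lora_map>();

    // mirrors the commit: replace the map instead of clearing it in place,
    // so the address handed to the graph params changes together with the contents
    void set_adapters(lora_adapter * a, float scale) {
        loras.reset(new lora_map());
        if (scale != 0.0f) {
            loras->insert({a, scale});
        }
    }

    graph_params_like graph_params() const { return { loras.get() }; }
};

int main() {
    context_like ctx;
    lora_adapter adapter;

    const graph_params_like before = ctx.graph_params();
    ctx.set_adapters(&adapter, 1.0f);
    const graph_params_like after = ctx.graph_params();

    // with plain members, the stored address never changes and same_adapters would
    // always return true; after the unique_ptr reset it correctly returns false here
    return before.same_adapters(after) ? 1 : 0;
}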


@@ -256,9 +256,10 @@ private:
const llama_model & model;
-llama_cparams cparams;
-llama_adapter_cvec cvec;
-llama_adapter_loras loras;
+llama_cparams cparams;
+llama_adapter_cvec_ptr cvec;
+llama_adapter_loras_ptr loras;
llama_cross cross; // TODO: tmp for handling cross-attention - need something better probably


@@ -17,6 +17,41 @@
#include <sstream>
#include <unordered_set>
+// dedup helpers
+static ggml_tensor * build_kq_mask(
+ggml_context * ctx,
+const llama_kv_cache_context * mctx,
+const llama_ubatch & ubatch,
+const llama_cparams & cparams) {
+const auto n_kv = mctx->get_n_kv();
+const auto n_tokens = ubatch.n_tokens;
+const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
+return ggml_new_tensor_4d(ctx, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream);
+}
+static bool can_reuse_kq_mask(
+ggml_tensor * kq_mask,
+const llama_kv_cache_context * mctx,
+const llama_ubatch & ubatch,
+const llama_cparams & cparams) {
+const auto n_kv = mctx->get_n_kv();
+const auto n_tokens = ubatch.n_tokens;
+const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
+bool res = true;
+res &= (kq_mask->ne[0] == n_kv);
+res &= (kq_mask->ne[1] == n_tokens/n_stream);
+res &= (kq_mask->ne[2] == 1);
+res &= (kq_mask->ne[3] == n_stream);
+return res;
+}
+// impl
void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) {
if (ubatch->token) {
const int64_t n_tokens = ubatch->n_tokens;
@@ -403,8 +438,7 @@ bool llm_graph_input_attn_kv::can_reuse(const llm_graph_params & params) {
res &= self_k_idxs->ne[0] == params.ubatch.n_tokens;
//res &= self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
-res &= self_kq_mask->ne[0] == mctx->get_n_kv();
-res &= self_kq_mask->ne[1] == params.ubatch.n_tokens;
+res &= can_reuse_kq_mask(self_kq_mask, mctx, params.ubatch, params.cparams);
return res;
}
@@ -424,8 +458,7 @@ bool llm_graph_input_attn_k::can_reuse(const llm_graph_params & params) {
res &= self_k_idxs->ne[0] == params.ubatch.n_tokens;
-res &= self_kq_mask->ne[0] == mctx->get_n_kv();
-res &= self_kq_mask->ne[1] == params.ubatch.n_tokens;
+res &= can_reuse_kq_mask(self_kq_mask, mctx, params.ubatch, params.cparams);
return res;
}
@@ -455,11 +488,8 @@ bool llm_graph_input_attn_kv_iswa::can_reuse(const llm_graph_params & params) {
res &= self_k_idxs_swa->ne[0] == params.ubatch.n_tokens;
//res &= self_v_idxs_swa->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
-res &= self_kq_mask->ne[0] == mctx->get_base()->get_n_kv();
-res &= self_kq_mask->ne[1] == params.ubatch.n_tokens;
-res &= self_kq_mask_swa->ne[0] == mctx->get_swa()->get_n_kv();
-res &= self_kq_mask_swa->ne[1] == params.ubatch.n_tokens;
+res &= can_reuse_kq_mask(self_kq_mask, mctx->get_base(), params.ubatch, params.cparams);
+res &= can_reuse_kq_mask(self_kq_mask_swa, mctx->get_swa(), params.ubatch, params.cparams);
return res;
}
@@ -521,8 +551,7 @@ bool llm_graph_input_mem_hybrid::can_reuse(const llm_graph_params & params) {
res &= inp_attn->self_k_idxs->ne[0] == params.ubatch.n_tokens;
//res &= inp_attn->self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
-res &= inp_attn->self_kq_mask->ne[0] == mctx->get_attn()->get_n_kv();
-res &= inp_attn->self_kq_mask->ne[1] == params.ubatch.n_tokens;
+res &= can_reuse_kq_mask(inp_attn->self_kq_mask, mctx->get_attn(), params.ubatch, params.cparams);
res &= inp_rs->s_copy->ne[0] == mctx->get_recr()->get_n_rs();
@@ -565,8 +594,7 @@ bool llm_graph_input_mem_hybrid_k::can_reuse(const llm_graph_params & params) {
res &= inp_attn->self_k_idxs->ne[0] == params.ubatch.n_tokens;
-res &= inp_attn->self_kq_mask->ne[0] == mctx->get_attn()->get_n_kv();
-res &= inp_attn->self_kq_mask->ne[1] == params.ubatch.n_tokens;
+res &= can_reuse_kq_mask(inp_attn->self_kq_mask, mctx->get_attn(), params.ubatch, params.cparams);
res &= inp_rs->s_copy->ne[0] == mctx->get_recr()->get_n_rs();
@@ -625,8 +653,7 @@ bool llm_graph_input_mem_hybrid_iswa::can_reuse(const llm_graph_params & params)
res &= inp_attn->self_k_idxs->ne[0] == params.ubatch.n_tokens;
//res &= inp_attn->self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
-res &= inp_attn->self_kq_mask->ne[0] == attn_ctx->get_base()->get_n_kv();
-res &= inp_attn->self_kq_mask->ne[1] == params.ubatch.n_tokens;
+res &= can_reuse_kq_mask(inp_attn->self_kq_mask, attn_ctx->get_base(), params.ubatch, params.cparams);
}
// swa tensors may not be allocated if there are no SWA attention layers
@@ -634,8 +661,7 @@ bool llm_graph_input_mem_hybrid_iswa::can_reuse(const llm_graph_params & params)
res &= inp_attn->self_k_idxs_swa->ne[0] == params.ubatch.n_tokens;
//res &= inp_attn->self_v_idxs_swa->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
-res &= inp_attn->self_kq_mask_swa->ne[0] == attn_ctx->get_swa()->get_n_kv();
-res &= inp_attn->self_kq_mask_swa->ne[1] == params.ubatch.n_tokens;
+res &= can_reuse_kq_mask(inp_attn->self_kq_mask_swa, attn_ctx->get_swa(), params.ubatch, params.cparams);
}
res &= inp_rs->s_copy->ne[0] == mctx->get_recr()->get_n_rs();
@@ -1891,14 +1917,11 @@ static std::unique_ptr<llm_graph_input_attn_kv> build_attn_inp_kv_impl(
{
GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_iswa for SWA");
-const auto n_kv = mctx_cur->get_n_kv();
-const auto n_tokens = ubatch.n_tokens;
-const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch);
-inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream);
+inp->self_kq_mask = build_kq_mask(ctx0, mctx_cur, ubatch, cparams);
ggml_set_input(inp->self_kq_mask);
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
@@ -1983,13 +2006,9 @@ static std::unique_ptr<llm_graph_input_attn_k> build_attn_inp_k_impl(
{
GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_iswa for SWA");
-const auto n_kv = mctx_cur->get_n_kv();
-const auto n_tokens = ubatch.n_tokens;
-const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
-inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream);
+inp->self_kq_mask = build_kq_mask(ctx0, mctx_cur, ubatch, cparams);
ggml_set_input(inp->self_kq_mask);
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
@@ -2188,15 +2207,11 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const
auto inp = std::make_unique<llm_graph_input_attn_kv_iswa>(hparams, cparams, mctx_cur);
-const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
{
-const auto n_kv = mctx_cur->get_base()->get_n_kv();
inp->self_k_idxs = mctx_cur->get_base()->build_input_k_idxs(ctx0, ubatch);
inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch);
-inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream);
+inp->self_kq_mask = build_kq_mask(ctx0, mctx_cur->get_base(), ubatch, cparams);
ggml_set_input(inp->self_kq_mask);
ggml_set_name(inp->self_kq_mask, "self_kq_mask");
@@ -2207,12 +2222,10 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const
{
GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache for non-SWA");
-const auto n_kv = mctx_cur->get_swa()->get_n_kv();
inp->self_k_idxs_swa = mctx_cur->get_swa()->build_input_k_idxs(ctx0, ubatch);
inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch);
-inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream);
+inp->self_kq_mask_swa = build_kq_mask(ctx0, mctx_cur->get_swa(), ubatch, cparams);
ggml_set_input(inp->self_kq_mask_swa);
ggml_set_name(inp->self_kq_mask_swa, "self_kq_mask_swa");
@@ -2374,27 +2387,21 @@ llm_graph_input_mem_hybrid_iswa * llm_graph_context::build_inp_mem_hybrid_iswa()
auto inp_attn = std::make_unique<llm_graph_input_attn_kv_iswa>(hparams, cparams, attn_ctx);
-const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
{
-const auto n_kv = attn_ctx->get_base()->get_n_kv();
inp_attn->self_k_idxs = attn_ctx->get_base()->build_input_k_idxs(ctx0, ubatch);
inp_attn->self_v_idxs = attn_ctx->get_base()->build_input_v_idxs(ctx0, ubatch);
-inp_attn->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream);
+inp_attn->self_kq_mask = build_kq_mask(ctx0, attn_ctx->get_base(), ubatch, cparams);
ggml_set_input(inp_attn->self_kq_mask);
inp_attn->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp_attn->self_kq_mask, GGML_TYPE_F16) : inp_attn->self_kq_mask;
}
{
-const auto n_kv = attn_ctx->get_swa()->get_n_kv();
inp_attn->self_k_idxs_swa = attn_ctx->get_swa()->build_input_k_idxs(ctx0, ubatch);
inp_attn->self_v_idxs_swa = attn_ctx->get_swa()->build_input_v_idxs(ctx0, ubatch);
-inp_attn->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream);
+inp_attn->self_kq_mask_swa = build_kq_mask(ctx0, attn_ctx->get_swa(), ubatch, cparams);
ggml_set_input(inp_attn->self_kq_mask_swa);
inp_attn->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp_attn->self_kq_mask_swa, GGML_TYPE_F16) : inp_attn->self_kq_mask_swa;