mla : make the V tensor a view of K (#18986)

* mla : pass V as a view of K to the FA op

* cuda : adjust mla logic to new layout

* kv-cache : fix rope shift

* tests : remove comment

* cuda : fix reusable_cutoff

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>

---------

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
This commit is contained in:
Georgi Gerganov 2026-01-22 22:09:01 +02:00 committed by GitHub
parent e2baf02162
commit a5eaa1d6a3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 39 additions and 13 deletions

View File

@ -778,12 +778,15 @@ void launch_fattn(
) {
constexpr int ncols = ncols1 * ncols2;
const bool is_mla = DV == 512; // TODO better parameterization
const ggml_tensor * Q = dst->src[0];
const ggml_tensor * K = dst->src[1];
const ggml_tensor * V = dst->src[2];
// TODO: make this more generic by removing the notion of "MLA".
// for example "is V a view of K?" so we can skip loading it.
// V strides should be driven by V itself and avoid assumption of the data layout
const bool is_mla = V->op == GGML_OP_VIEW && V->src[0] == K;
GGML_ASSERT(V || is_mla);
const ggml_tensor * mask = dst->src[3];

View File

@ -794,7 +794,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
// For MLA K and V have the same data.
// Therefore, iterate over V in reverse and re-use the data if possible.
static_assert(!mla || nstages <= 1, "combination of MLA and multi-stage loading not implemented");
constexpr int reusable_cutoff = mla ? (DKQ - 1) - (DKQ - 1) % (2*nbatch_K2) - (DKQ - DV) : DV;
// constexpr int reusable_cutoff = mla ? (DV - 1) - (DV - 1) % (2*nbatch_K2) : DV;
constexpr int reusable_cutoff = DV; // TODO implement properly
#if defined(AMD_WMMA_AVAILABLE) && !defined(LDMATRIX_TRANS_AVAILABLE)
T_A_VKQ A_identity;
make_identity_mat(A_identity);
@ -1552,7 +1553,7 @@ static __global__ void flash_attn_ext_f16(
(const half *) (mask + nb33*(sequence % ne33));
float2 * dstk = ((float2 *) dst) + (sequence*ne01.z*ne02 + head0) * (DV/2);
const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio));
const half2 * V_h2 = mla ? K_h2 : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio));
const float * sinks_f = sinks ? (const float *) sinks + head0 : nullptr;
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f;
@ -1596,7 +1597,7 @@ static __global__ void flash_attn_ext_f16(
(const half *) (mask + nb33*(sequence % ne33));
float2 * dstk = ((float2 *) dst) + (sequence*ne01.z*ne02 + head0) * (DV/2);
const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio));
const half2 * V_h2 = mla ? K_h2 : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio));
const float * sinks_f = sinks ? (const float *) sinks + head0 : nullptr;
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f;

View File

@ -1565,6 +1565,11 @@ ggml_tensor * llm_graph_context::build_attn_mha(
v = ggml_transpose(ctx0, v);
}
// TODO: update llama_kv_cache to not store V cache in the MLA case and automatically return a view of K
if (v_mla) {
v = ggml_view_4d(ctx0, k, v->ne[0], v->ne[1], v->ne[2], v->ne[3], k->nb[1], k->nb[2], k->nb[3], 0);
}
// this can happen when KV cache is not used (e.g. an embedding model with non-causal attn)
if (k->type == GGML_TYPE_F32) {
k = ggml_cast(ctx0, k, GGML_TYPE_F16);

View File

@ -1594,6 +1594,10 @@ ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_co
const auto & n_embd_head_k = hparams.n_embd_head_k;
//const auto & n_embd_head_v = hparams.n_embd_head_v;
const auto & n_rot = hparams.n_rot;
const auto n_embd_nope = hparams.n_lora_kv > 0 ? n_embd_head_k - n_rot : 0;
auto inp = std::make_unique<llm_graph_input_k_shift>(this);
inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, (int64_t) get_size()*n_stream);
@ -1614,10 +1618,10 @@ ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_co
ggml_tensor * k =
ggml_view_3d(ctx, layer.k,
n_embd_head_k, n_head_kv, get_size()*n_stream,
n_rot, n_head_kv, get_size()*n_stream,
ggml_row_size(layer.k->type, n_embd_head_k),
ggml_row_size(layer.k->type, n_embd_k_gqa),
0);
ggml_row_size(layer.k->type, n_embd_nope));
ggml_tensor * cur = build_rope_shift(cparams, ctx, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l);

View File

@ -124,14 +124,14 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr
// {n_embd_head_qk_rope + kv_lora_rank, n_head, n_tokens}
// note: rope must go first for in-place context shifting in build_rope_shift()
ggml_tensor * Qcur = ggml_concat(ctx0, q_pe, q_nope_absorbed, 0);
ggml_tensor * Qcur = ggml_concat(ctx0, q_nope_absorbed, q_pe, 0);
cb(Qcur, "Qcur", il);
kv_cmpr = ggml_reshape_3d(ctx0, kv_cmpr, kv_lora_rank, 1, n_tokens);
cb(kv_cmpr, "kv_cmpr_reshape", il);
// {n_embd_head_qk_rope + kv_lora_rank, 1, n_tokens}
ggml_tensor * Kcur = ggml_concat(ctx0, k_pe, kv_cmpr, 0);
ggml_tensor * Kcur = ggml_concat(ctx0, kv_cmpr, k_pe, 0);
cb(Kcur, "Kcur", il);
// {kv_lora_rank, 1, n_tokens}
@ -169,11 +169,10 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr
Vcur = ggml_cont(ctx0, Vcur);
cb(Vcur, "Vcur_cont", il);
// note: rope must go first for in-place context shifting in build_rope_shift()
ggml_tensor * Qcur = ggml_concat(ctx0, q_pe, q_nope, 0);
ggml_tensor * Qcur = ggml_concat(ctx0, q_nope, q_pe, 0);
cb(Qcur, "Qcur", il);
ggml_tensor * Kcur = ggml_concat(ctx0, ggml_repeat(ctx0, k_pe, q_pe), k_nope, 0);
ggml_tensor * Kcur = ggml_concat(ctx0, k_nope, ggml_repeat(ctx0, k_pe, q_pe), 0);
cb(Kcur, "Kcur", il);
if (inp_attn_scale) {

View File

@ -9,6 +9,7 @@ llm_build_minicpm3::llm_build_minicpm3(const llama_model & model, const llm_grap
const uint32_t n_embd_head_qk_rope = hparams.n_rot;
const uint32_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot;
const uint32_t kv_lora_rank = hparams.n_lora_kv;
ggml_tensor * cur;

View File

@ -5,6 +5,7 @@ llm_build_plm::llm_build_plm(const llama_model & model, const llm_graph_params &
const uint32_t n_embd_head_qk_rope = hparams.n_rot;
const uint32_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot;
const uint32_t kv_lora_rank = hparams.n_lora_kv;
ggml_tensor * cur;

View File

@ -6122,7 +6122,19 @@ struct test_flash_attn_ext : public test_case {
ggml_tensor * k = create_permuted(type_KV, hsk_padded, kv, nh, nr23[1], true); // the K tensor is usually a view of the K cache
ggml_set_name(k, "k");
ggml_tensor * v = create_permuted(type_KV, hsv_padded, kv, nh, nr23[1], true); // the V tensor is usually a view of the V cache
ggml_tensor * v = nullptr;
if (hsk_padded == 576 && hsv_padded == 512) {
// TODO: this branch should become a separate test case parameter instead of hardcoding this for these head shapes
// in this branch, the V cache is sub-view of the K cache. this is used by some MLA-based models
// for more info:
// - https://github.com/ggml-org/llama.cpp/pull/13435
// - https://github.com/ggml-org/llama.cpp/pull/18953#issuecomment-3774948392
// - https://github.com/ggml-org/llama.cpp/pull/18986
v = ggml_view_4d(ctx, k, hsv_padded, kv, nh, nr23[1], k->nb[1], k->nb[2], k->nb[3], 0);
} else {
v = create_permuted(type_KV, hsv_padded, kv, nh, nr23[1], true); // the V tensor is usually a view of the V cache
}
ggml_set_name(v, "v");
ggml_tensor * m = nullptr;