diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh index 8468ba8488..a781fb91f5 100644 --- a/ggml/src/ggml-cuda/fattn-common.cuh +++ b/ggml/src/ggml-cuda/fattn-common.cuh @@ -778,12 +778,15 @@ void launch_fattn( ) { constexpr int ncols = ncols1 * ncols2; - const bool is_mla = DV == 512; // TODO better parameterization - const ggml_tensor * Q = dst->src[0]; const ggml_tensor * K = dst->src[1]; const ggml_tensor * V = dst->src[2]; + // TODO: make this more generic by removing the notion of "MLA". + // for example "is V a view of K?" so we can skip loading it. + // V strides should be driven by V itself and avoid assumption of the data layout + const bool is_mla = V->op == GGML_OP_VIEW && V->src[0] == K; + GGML_ASSERT(V || is_mla); const ggml_tensor * mask = dst->src[3]; diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh index 8cca89c2bf..203569e345 100644 --- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh @@ -794,7 +794,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( // For MLA K and V have the same data. // Therefore, iterate over V in reverse and re-use the data if possible. static_assert(!mla || nstages <= 1, "combination of MLA and multi-stage loading not implemented"); - constexpr int reusable_cutoff = mla ? (DKQ - 1) - (DKQ - 1) % (2*nbatch_K2) - (DKQ - DV) : DV; + // constexpr int reusable_cutoff = mla ? (DV - 1) - (DV - 1) % (2*nbatch_K2) : DV; + constexpr int reusable_cutoff = DV; // TODO implement properly #if defined(AMD_WMMA_AVAILABLE) && !defined(LDMATRIX_TRANS_AVAILABLE) T_A_VKQ A_identity; make_identity_mat(A_identity); @@ -1552,7 +1553,7 @@ static __global__ void flash_attn_ext_f16( (const half *) (mask + nb33*(sequence % ne33)); float2 * dstk = ((float2 *) dst) + (sequence*ne01.z*ne02 + head0) * (DV/2); - const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio)); + const half2 * V_h2 = mla ? K_h2 : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio)); const float * sinks_f = sinks ? (const float *) sinks + head0 : nullptr; const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f; @@ -1596,7 +1597,7 @@ static __global__ void flash_attn_ext_f16( (const half *) (mask + nb33*(sequence % ne33)); float2 * dstk = ((float2 *) dst) + (sequence*ne01.z*ne02 + head0) * (DV/2); - const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio)); + const half2 * V_h2 = mla ? K_h2 : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio)); const float * sinks_f = sinks ? (const float *) sinks + head0 : nullptr; const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f; diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 57485c534e..5ebd0cf8aa 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -1565,6 +1565,11 @@ ggml_tensor * llm_graph_context::build_attn_mha( v = ggml_transpose(ctx0, v); } + // TODO: update llama_kv_cache to not store V cache in the MLA case and automatically return a view of K + if (v_mla) { + v = ggml_view_4d(ctx0, k, v->ne[0], v->ne[1], v->ne[2], v->ne[3], k->nb[1], k->nb[2], k->nb[3], 0); + } + // this can happen when KV cache is not used (e.g. an embedding model with non-causal attn) if (k->type == GGML_TYPE_F32) { k = ggml_cast(ctx0, k, GGML_TYPE_F16); diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index fd9f97d52e..a7327c4987 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -1594,6 +1594,10 @@ ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_co const auto & n_embd_head_k = hparams.n_embd_head_k; //const auto & n_embd_head_v = hparams.n_embd_head_v; + const auto & n_rot = hparams.n_rot; + + const auto n_embd_nope = hparams.n_lora_kv > 0 ? n_embd_head_k - n_rot : 0; + auto inp = std::make_unique(this); inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, (int64_t) get_size()*n_stream); @@ -1614,10 +1618,10 @@ ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_co ggml_tensor * k = ggml_view_3d(ctx, layer.k, - n_embd_head_k, n_head_kv, get_size()*n_stream, + n_rot, n_head_kv, get_size()*n_stream, ggml_row_size(layer.k->type, n_embd_head_k), ggml_row_size(layer.k->type, n_embd_k_gqa), - 0); + ggml_row_size(layer.k->type, n_embd_nope)); ggml_tensor * cur = build_rope_shift(cparams, ctx, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l); diff --git a/src/models/deepseek2.cpp b/src/models/deepseek2.cpp index ca63a62ad1..c404c1946d 100644 --- a/src/models/deepseek2.cpp +++ b/src/models/deepseek2.cpp @@ -124,14 +124,14 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr // {n_embd_head_qk_rope + kv_lora_rank, n_head, n_tokens} // note: rope must go first for in-place context shifting in build_rope_shift() - ggml_tensor * Qcur = ggml_concat(ctx0, q_pe, q_nope_absorbed, 0); + ggml_tensor * Qcur = ggml_concat(ctx0, q_nope_absorbed, q_pe, 0); cb(Qcur, "Qcur", il); kv_cmpr = ggml_reshape_3d(ctx0, kv_cmpr, kv_lora_rank, 1, n_tokens); cb(kv_cmpr, "kv_cmpr_reshape", il); // {n_embd_head_qk_rope + kv_lora_rank, 1, n_tokens} - ggml_tensor * Kcur = ggml_concat(ctx0, k_pe, kv_cmpr, 0); + ggml_tensor * Kcur = ggml_concat(ctx0, kv_cmpr, k_pe, 0); cb(Kcur, "Kcur", il); // {kv_lora_rank, 1, n_tokens} @@ -169,11 +169,10 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr Vcur = ggml_cont(ctx0, Vcur); cb(Vcur, "Vcur_cont", il); - // note: rope must go first for in-place context shifting in build_rope_shift() - ggml_tensor * Qcur = ggml_concat(ctx0, q_pe, q_nope, 0); + ggml_tensor * Qcur = ggml_concat(ctx0, q_nope, q_pe, 0); cb(Qcur, "Qcur", il); - ggml_tensor * Kcur = ggml_concat(ctx0, ggml_repeat(ctx0, k_pe, q_pe), k_nope, 0); + ggml_tensor * Kcur = ggml_concat(ctx0, k_nope, ggml_repeat(ctx0, k_pe, q_pe), 0); cb(Kcur, "Kcur", il); if (inp_attn_scale) { diff --git a/src/models/minicpm3.cpp b/src/models/minicpm3.cpp index f374a9fd03..297cc34ba5 100644 --- a/src/models/minicpm3.cpp +++ b/src/models/minicpm3.cpp @@ -9,6 +9,7 @@ llm_build_minicpm3::llm_build_minicpm3(const llama_model & model, const llm_grap const uint32_t n_embd_head_qk_rope = hparams.n_rot; const uint32_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot; + const uint32_t kv_lora_rank = hparams.n_lora_kv; ggml_tensor * cur; diff --git a/src/models/plm.cpp b/src/models/plm.cpp index 481cbba690..612a487c56 100644 --- a/src/models/plm.cpp +++ b/src/models/plm.cpp @@ -5,6 +5,7 @@ llm_build_plm::llm_build_plm(const llama_model & model, const llm_graph_params & const uint32_t n_embd_head_qk_rope = hparams.n_rot; const uint32_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot; + const uint32_t kv_lora_rank = hparams.n_lora_kv; ggml_tensor * cur; diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 9f61c6483d..146d05f53b 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -6122,7 +6122,19 @@ struct test_flash_attn_ext : public test_case { ggml_tensor * k = create_permuted(type_KV, hsk_padded, kv, nh, nr23[1], true); // the K tensor is usually a view of the K cache ggml_set_name(k, "k"); - ggml_tensor * v = create_permuted(type_KV, hsv_padded, kv, nh, nr23[1], true); // the V tensor is usually a view of the V cache + ggml_tensor * v = nullptr; + if (hsk_padded == 576 && hsv_padded == 512) { + // TODO: this branch should become a separate test case parameter instead of hardcoding this for these head shapes + + // in this branch, the V cache is sub-view of the K cache. this is used by some MLA-based models + // for more info: + // - https://github.com/ggml-org/llama.cpp/pull/13435 + // - https://github.com/ggml-org/llama.cpp/pull/18953#issuecomment-3774948392 + // - https://github.com/ggml-org/llama.cpp/pull/18986 + v = ggml_view_4d(ctx, k, hsv_padded, kv, nh, nr23[1], k->nb[1], k->nb[2], k->nb[3], 0); + } else { + v = create_permuted(type_KV, hsv_padded, kv, nh, nr23[1], true); // the V tensor is usually a view of the V cache + } ggml_set_name(v, "v"); ggml_tensor * m = nullptr;