mla : make the V tensor a view of K (#18986)
* mla : pass V as a view of K to the FA op
* cuda : adjust mla logic to new layout
* kv-cache : fix rope shift
* tests : remove comment
* cuda : fix reusable_cutoff

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
parent e2baf02162
commit a5eaa1d6a3
ggml/src/ggml-cuda/fattn-common.cuh
@@ -778,12 +778,15 @@ void launch_fattn(
 ) {
     constexpr int ncols = ncols1 * ncols2;

-    const bool is_mla = DV == 512; // TODO better parameterization
-
     const ggml_tensor * Q = dst->src[0];
     const ggml_tensor * K = dst->src[1];
     const ggml_tensor * V = dst->src[2];

+    // TODO: make this more generic by removing the notion of "MLA".
+    // for example "is V a view of K?" so we can skip loading it.
+    // V strides should be driven by V itself and avoid assumption of the data layout
+    const bool is_mla = V->op == GGML_OP_VIEW && V->src[0] == K;
+
     GGML_ASSERT(V || is_mla);

     const ggml_tensor * mask = dst->src[3];
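The hard-coded DV == 512 test is replaced by a structural check: V is treated as the MLA case whenever it is literally a view into K. A minimal standalone sketch of that predicate, using only public ggml tensor fields (this restates the hunk above for clarity, it is not additional kernel logic):

    #include "ggml.h"

    // True when V aliases K's data, i.e. the MLA case the kernels can exploit.
    // V may be absent for callers that do not pass a separate V tensor.
    static bool v_is_view_of_k(const ggml_tensor * v, const ggml_tensor * k) {
        return v != nullptr && v->op == GGML_OP_VIEW && v->src[0] == k;
    }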
ggml/src/ggml-cuda/fattn-mma-f16.cuh
@@ -794,7 +794,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
     // For MLA K and V have the same data.
     // Therefore, iterate over V in reverse and re-use the data if possible.
     static_assert(!mla || nstages <= 1, "combination of MLA and multi-stage loading not implemented");
-    constexpr int reusable_cutoff = mla ? (DKQ - 1) - (DKQ - 1) % (2*nbatch_K2) - (DKQ - DV) : DV;
+    // constexpr int reusable_cutoff = mla ? (DV - 1) - (DV - 1) % (2*nbatch_K2) : DV;
+    constexpr int reusable_cutoff = DV; // TODO implement properly
 #if defined(AMD_WMMA_AVAILABLE) && !defined(LDMATRIX_TRANS_AVAILABLE)
     T_A_VKQ A_identity;
     make_identity_mat(A_identity);
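The old cutoff formula assumed V lives at the tail of each K row; with V moved to the head of the row it no longer holds, so the commit pins reusable_cutoff to DV, disabling the K/V data re-use until it is reimplemented. For concreteness, what the three expressions evaluate to under illustrative sizes (DKQ = 576, DV = 512, nbatch_K2 = 64 are assumed values, not taken from this diff):

    #include <cstdio>

    int main() {
        constexpr int DKQ = 576, DV = 512, nbatch_K2 = 64;
        // old formula (V at the tail of K): align down to a 2*nbatch_K2 boundary, then shift by the rope size
        constexpr int old_cutoff = (DKQ - 1) - (DKQ - 1) % (2*nbatch_K2) - (DKQ - DV); // 448
        // commented-out candidate for the new head-of-K layout
        constexpr int candidate = (DV - 1) - (DV - 1) % (2*nbatch_K2);                 // 384
        printf("old = %d, candidate = %d, current = %d (DV: no re-use)\n", old_cutoff, candidate, DV);
        return 0;
    }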
@@ -1552,7 +1553,7 @@ static __global__ void flash_attn_ext_f16(
         (const half *) (mask + nb33*(sequence % ne33));
     float2 * dstk = ((float2 *) dst) + (sequence*ne01.z*ne02 + head0) * (DV/2);

-    const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio));
+    const half2 * V_h2 = mla ? K_h2 : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio));
     const float * sinks_f = sinks ? (const float *) sinks + head0 : nullptr;

     const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f;
@@ -1596,7 +1597,7 @@ static __global__ void flash_attn_ext_f16(
         (const half *) (mask + nb33*(sequence % ne33));
     float2 * dstk = ((float2 *) dst) + (sequence*ne01.z*ne02 + head0) * (DV/2);

-    const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio));
+    const half2 * V_h2 = mla ? K_h2 : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio));
     const float * sinks_f = sinks ? (const float *) sinks + head0 : nullptr;

     const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f;
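Both kernel variants drop the (DKQ/2 - DV/2) offset for the same reason: with the concat order flipped in the graph code (see the deepseek2 hunks below), the compressed KV part now sits at the start of each K row and the rope part at the end, so V's data begins at the start of K's row. A standalone check of the pointer arithmetic, assuming the usual DeepSeek-style MLA head sizes DKQ = 576 and DV = 512:

    #include <cstdio>

    int main() {
        constexpr int DKQ = 576, DV = 512;
        // offsets into a K row, in half2 (2-element) units, as used by V_h2 above
        printf("old V offset: %d half2\n", DKQ/2 - DV/2); // 32: V was the tail of the row
        printf("new V offset: %d half2\n", 0);            // V is now the head of the row
        return 0;
    }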
src/llama-graph.cpp
@@ -1565,6 +1565,11 @@ ggml_tensor * llm_graph_context::build_attn_mha(
         v = ggml_transpose(ctx0, v);
     }

+    // TODO: update llama_kv_cache to not store V cache in the MLA case and automatically return a view of K
+    if (v_mla) {
+        v = ggml_view_4d(ctx0, k, v->ne[0], v->ne[1], v->ne[2], v->ne[3], k->nb[1], k->nb[2], k->nb[3], 0);
+    }
+
     // this can happen when KV cache is not used (e.g. an embedding model with non-causal attn)
     if (k->type == GGML_TYPE_F32) {
         k = ggml_cast(ctx0, k, GGML_TYPE_F16);
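A self-contained sketch of the aliasing that build_attn_mha now sets up in the MLA case. Shapes here are illustrative (a 576-wide K row holding a 512-wide compressed part), not the model's actual dimensions:

    #include "ggml.h"

    int main() {
        ggml_init_params params = {16*1024*1024, nullptr, false};
        ggml_context * ctx = ggml_init(params);

        // each K row holds [compressed KV (512) | rope part (64)] for one token
        ggml_tensor * k = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, 576, 8);

        // V exposes the first 512 elements of every K row: same data, K's strides, offset 0
        ggml_tensor * v = ggml_view_4d(ctx, k, 512, 8, 1, 1, k->nb[1], k->nb[2], k->nb[3], 0);

        // this is exactly what the new is_mla check in launch_fattn detects
        const bool is_view_of_k = v->op == GGML_OP_VIEW && v->src[0] == k;

        ggml_free(ctx);
        return is_view_of_k ? 0 : 1;
    }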
src/llama-kv-cache.cpp
@@ -1594,6 +1594,10 @@ ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_co
     const auto & n_embd_head_k = hparams.n_embd_head_k;
     //const auto & n_embd_head_v = hparams.n_embd_head_v;

+    const auto & n_rot = hparams.n_rot;
+
+    const auto n_embd_nope = hparams.n_lora_kv > 0 ? n_embd_head_k - n_rot : 0;
+
     auto inp = std::make_unique<llm_graph_input_k_shift>(this);

     inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, (int64_t) get_size()*n_stream);
@@ -1614,10 +1618,10 @@ ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_co

         ggml_tensor * k =
             ggml_view_3d(ctx, layer.k,
-                n_embd_head_k, n_head_kv, get_size()*n_stream,
+                n_rot, n_head_kv, get_size()*n_stream,
                 ggml_row_size(layer.k->type, n_embd_head_k),
                 ggml_row_size(layer.k->type, n_embd_k_gqa),
-                0);
+                ggml_row_size(layer.k->type, n_embd_nope));

         ggml_tensor * cur = build_rope_shift(cparams, ctx, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l);
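The shift now has to touch only the rope slice, which after this commit sits at the end of each K row. Hence the view covers n_rot elements per head and starts at a byte offset of ggml_row_size(type, n_embd_nope) instead of 0. A small sketch of that geometry with assumed DeepSeek-style sizes (kv_lora_rank = 512, n_rot = 64, F16 cache):

    #include <cstdio>
    #include <cstddef>

    int main() {
        const size_t f16_size      = 2;
        const size_t n_rot         = 64;
        const size_t n_embd_head_k = 576;                   // kv_lora_rank + n_rot
        const size_t n_embd_nope   = n_embd_head_k - n_rot; // 512; 0 when n_lora_kv == 0

        // the rope-shift view: n_rot elements per row, starting past the nope part
        printf("view ne0    = %zu elements\n", n_rot);
        printf("view offset = %zu bytes\n", n_embd_nope * f16_size); // ggml_row_size(GGML_TYPE_F16, n_embd_nope)
        return 0;
    }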
src/models/deepseek2.cpp
@@ -124,14 +124,14 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr

             // {n_embd_head_qk_rope + kv_lora_rank, n_head, n_tokens}
-            // note: rope must go first for in-place context shifting in build_rope_shift()
-            ggml_tensor * Qcur = ggml_concat(ctx0, q_pe, q_nope_absorbed, 0);
+            ggml_tensor * Qcur = ggml_concat(ctx0, q_nope_absorbed, q_pe, 0);
             cb(Qcur, "Qcur", il);

             kv_cmpr = ggml_reshape_3d(ctx0, kv_cmpr, kv_lora_rank, 1, n_tokens);
             cb(kv_cmpr, "kv_cmpr_reshape", il);

             // {n_embd_head_qk_rope + kv_lora_rank, 1, n_tokens}
-            ggml_tensor * Kcur = ggml_concat(ctx0, k_pe, kv_cmpr, 0);
+            ggml_tensor * Kcur = ggml_concat(ctx0, kv_cmpr, k_pe, 0);
             cb(Kcur, "Kcur", il);

             // {kv_lora_rank, 1, n_tokens}
@@ -169,11 +169,10 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr
             Vcur = ggml_cont(ctx0, Vcur);
             cb(Vcur, "Vcur_cont", il);

-            // note: rope must go first for in-place context shifting in build_rope_shift()
-            ggml_tensor * Qcur = ggml_concat(ctx0, q_pe, q_nope, 0);
+            ggml_tensor * Qcur = ggml_concat(ctx0, q_nope, q_pe, 0);
             cb(Qcur, "Qcur", il);

-            ggml_tensor * Kcur = ggml_concat(ctx0, ggml_repeat(ctx0, k_pe, q_pe), k_nope, 0);
+            ggml_tensor * Kcur = ggml_concat(ctx0, k_nope, ggml_repeat(ctx0, k_pe, q_pe), 0);
             cb(Kcur, "Kcur", il);

             if (inp_attn_scale) {
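In both the MLA and the naive branch the concat order is flipped so the rope part comes last, matching the kernels' V-at-the-head assumption and the new rope-shift offset. A minimal ggml sketch of the flipped concat (per-head sizes 128/64 are illustrative, not the model's):

    #include "ggml.h"

    int main() {
        ggml_init_params params = {8*1024*1024, nullptr, false};
        ggml_context * ctx = ggml_init(params);

        ggml_tensor * k_nope = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 128, 1, 4); // non-rope part
        ggml_tensor * k_pe   = ggml_new_tensor_3d(ctx, GGML_TYPE_F32,  64, 1, 4); // rope part

        // before: ggml_concat(ctx, k_pe, k_nope, 0) put the rope part at offset 0;
        // after, the rope slice starts at element 128 of each row, which is where
        // build_graph_shift() now points its shifted view
        ggml_tensor * Kcur = ggml_concat(ctx, k_nope, k_pe, 0);

        ggml_free(ctx);
        return Kcur->ne[0] == 192 ? 0 : 1;
    }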
src/models/minicpm3.cpp
@@ -9,6 +9,7 @@ llm_build_minicpm3::llm_build_minicpm3(const llama_model & model, const llm_grap
     const uint32_t n_embd_head_qk_rope = hparams.n_rot;
     const uint32_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot;
+
     const uint32_t kv_lora_rank = hparams.n_lora_kv;

     ggml_tensor * cur;
src/models/plm.cpp
@@ -5,6 +5,7 @@ llm_build_plm::llm_build_plm(const llama_model & model, const llm_graph_params &
     const uint32_t n_embd_head_qk_rope = hparams.n_rot;
     const uint32_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot;
+
     const uint32_t kv_lora_rank = hparams.n_lora_kv;

     ggml_tensor * cur;
tests/test-backend-ops.cpp
@@ -6122,7 +6122,19 @@ struct test_flash_attn_ext : public test_case {
         ggml_tensor * k = create_permuted(type_KV, hsk_padded, kv, nh, nr23[1], true); // the K tensor is usually a view of the K cache
         ggml_set_name(k, "k");

-        ggml_tensor * v = create_permuted(type_KV, hsv_padded, kv, nh, nr23[1], true); // the V tensor is usually a view of the V cache
+        ggml_tensor * v = nullptr;
+        if (hsk_padded == 576 && hsv_padded == 512) {
+            // TODO: this branch should become a separate test case parameter instead of hardcoding this for these head shapes
+
+            // in this branch, the V cache is a sub-view of the K cache. this is used by some MLA-based models
+            // for more info:
+            //   - https://github.com/ggml-org/llama.cpp/pull/13435
+            //   - https://github.com/ggml-org/llama.cpp/pull/18953#issuecomment-3774948392
+            //   - https://github.com/ggml-org/llama.cpp/pull/18986
+            v = ggml_view_4d(ctx, k, hsv_padded, kv, nh, nr23[1], k->nb[1], k->nb[2], k->nb[3], 0);
+        } else {
+            v = create_permuted(type_KV, hsv_padded, kv, nh, nr23[1], true); // the V tensor is usually a view of the V cache
+        }
         ggml_set_name(v, "v");

         ggml_tensor * m = nullptr;
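The view in the 576/512 branch keeps K's strides, so V rows stay hsk_padded elements apart while exposing only the first hsv_padded elements; this is what forces backends to honor V's own strides instead of assuming a contiguous V cache. A standalone check of those properties (contiguous tensors for simplicity; the test itself uses permuted ones):

    #include "ggml.h"
    #include <cassert>

    int main() {
        ggml_init_params params = {64*1024*1024, nullptr, false};
        ggml_context * ctx = ggml_init(params);

        const int64_t hsk = 576, hsv = 512, kv = 32, nh = 4;
        ggml_tensor * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hsk, kv, nh, 1);
        ggml_tensor * v = ggml_view_4d(ctx, k, hsv, kv, nh, 1, k->nb[1], k->nb[2], k->nb[3], 0);

        assert(v->data == k->data);   // V aliases K's buffer
        assert(v->nb[1] == k->nb[1]); // rows are still hsk elements apart
        assert(v->ne[0] == hsv);      // but only the first hsv elements are exposed

        ggml_free(ctx);
        return 0;
    }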