From 9a63e7ab76b435cbb87d9bdebcb023382475066b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stanis=C5=82aw=20Szymczyk?= Date: Sat, 14 Mar 2026 20:20:39 +0100 Subject: [PATCH] model : crude proof-of-concept implementation of the DSA indexer for DeepSeek V3.2. --- src/models/deepseek32.cpp | 104 +++++++++++++++++++++++++++++++++++++- 1 file changed, 102 insertions(+), 2 deletions(-) diff --git a/src/models/deepseek32.cpp b/src/models/deepseek32.cpp index 836c582850..20d31f73ac 100644 --- a/src/models/deepseek32.cpp +++ b/src/models/deepseek32.cpp @@ -1,5 +1,38 @@ #include "models.h" +#include "llama-kv-cache.h" + +void mask_top_k_callback(struct ggml_tensor * dst, const struct ggml_tensor * a, const struct ggml_tensor * b, int ith, int nth, void * userdata) { + // a = kq_mask, b = top_k, dst = output tensor + const int n_seq = a->ne[1]; + const int n_tokens = a->ne[0]; + const int k = b->ne[0]; + + // Get data pointers (assuming F32 for mask, I32 for indices) + const float * mask_data = (const float *) a->data; + const int32_t * topk_data = (const int32_t *) b->data; + float * dst_data = (float *) dst->data; + + // Distribute work across threads if nth > 1 + const int start_row = (n_seq * ith) / nth; + const int end_row = (n_seq * (ith + 1)) / nth; + + for (int i = start_row; i < end_row; ++i) { + // First, set the entire row to -inf + for (int j = 0; j < n_tokens; ++j) { + dst_data[i * n_tokens + j] = -INFINITY; + } + + // Then, restore the values indicated by top_k + for (int j = 0; j < k; ++j) { + int32_t keep_idx = topk_data[i * k + j]; + if (keep_idx >= 0 && keep_idx < n_tokens) { + dst_data[i * n_tokens + keep_idx] = mask_data[i * n_tokens + keep_idx]; + } + } + } +} + llm_build_deepseek32::llm_build_deepseek32(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const bool is_mla = hparams.is_mla(); @@ -15,6 +48,7 @@ llm_build_deepseek32::llm_build_deepseek32(const llama_model & model, const llm_ const int64_t n_embd_indexer_head = hparams.indexer_head_size; const int64_t n_embd_indexer_head_rope = hparams.n_rot(); const int64_t n_embd_indexer_head_nope = n_embd_indexer_head - n_embd_indexer_head_rope; + const uint32_t n_indexer_top_k = hparams.indexer_top_k; const uint32_t kv_lora_rank = hparams.n_lora_kv; @@ -60,6 +94,10 @@ llm_build_deepseek32::llm_build_deepseek32(const llama_model & model, const llm_ qr = build_norm(qr, model.layers[il].attn_q_a_norm, nullptr, LLM_NORM_RMS, il); cb(qr, "qr", il); + ggml_tensor * kq_mask = is_mla ? inp_attn_k->get_kq_mask() : inp_attn_kv->get_kq_mask(); + ggml_tensor * kq_mask_bak = ggml_dup(ctx0, kq_mask); + ggml_build_forward_expand(gf, kq_mask_bak); + // lightning indexer { ggml_tensor * indexer_q = ggml_mul_mat(ctx0, model.layers[il].indexer_attn_q_b, qr); @@ -119,8 +157,68 @@ llm_build_deepseek32::llm_build_deepseek32(const llama_model & model, const llm_ indexer_k = ggml_concat(ctx0, indexer_k_pe, indexer_k_nope, 0); cb(indexer_k, "indexer_k", il); - ggml_build_forward_expand(gf, indexer_q); - ggml_build_forward_expand(gf, indexer_k); + indexer_q = ggml_hadamard(ctx0, indexer_q, n_embd_indexer_head); + cb(indexer_q, "indexer_q", il); + indexer_k = ggml_hadamard(ctx0, indexer_k, n_embd_indexer_head); + cb(indexer_k, "indexer_k", il); + + // store to KV cache + const auto * mctx_cur = is_mla ? inp_attn_k->mctx : inp_attn_kv->mctx; + const auto & k_idxs = is_mla ? inp_attn_k->get_k_idxs() : inp_attn_kv->get_k_idxs(); + ggml_build_forward_expand(gf, mctx_cur->cpy_ik(ctx0, indexer_k, k_idxs, il)); + + ggml_tensor * indexer_weights = ggml_mul_mat(ctx0, model.layers[il].indexer_proj, cur); + cb(indexer_weights, "indexer_weights", il); + + indexer_weights = ggml_scale(ctx0, indexer_weights, 1.0f / sqrtf(float(n_indexer_head))); + cb(indexer_weights, "indexer_weights", il); + + // get cached indexer keys + indexer_k = mctx_cur->get_ik(ctx0, il); + + indexer_q = ggml_permute(ctx0, indexer_q, 0, 2, 1, 3); + cb(indexer_q, "indexer_q", il); + indexer_k = ggml_permute(ctx0, indexer_k, 0, 2, 1, 3); + cb(indexer_k, "indexer_k", il); + + ggml_tensor * indexer_kq = ggml_mul_mat(ctx0, indexer_k, indexer_q); + cb(indexer_kq, "indexer_kq", il); + + indexer_kq = ggml_cont(ctx0, ggml_permute(ctx0, indexer_kq, 2, 1, 0, 3)); + cb(indexer_kq, "indexer_kq", il); + + ggml_tensor * indexer_score = ggml_relu(ctx0, indexer_kq); + cb(indexer_score, "indexer_score", il); + + indexer_score = ggml_mul(ctx0, indexer_score, indexer_weights); + cb(indexer_score, "indexer_score", il); + + indexer_score = ggml_sum_rows(ctx0, indexer_score); + cb(indexer_score, "indexer_score", il); + + indexer_score = ggml_permute(ctx0, indexer_score, 2, 1, 0, 3); + cb(indexer_score, "indexer_score", il); + + indexer_score = ggml_cont(ctx0, indexer_score); + cb(indexer_score, "indexer_score", il); + + indexer_score = ggml_scale(ctx0, indexer_score, 1.0f / sqrtf(float(n_embd_indexer_head))); + cb(indexer_score, "indexer_score", il); + + // mask indexer scores + ggml_tensor * kq_mask_f32 = ggml_cast(ctx0, kq_mask, GGML_TYPE_F32); + indexer_score = ggml_add(ctx0, indexer_score, kq_mask_f32); + cb(indexer_score, "indexer_score", il); + + uint32_t n_top_k = indexer_score->ne[0] < n_indexer_top_k ? indexer_score->ne[0] : n_indexer_top_k; + ggml_tensor * top_k = ggml_cont(ctx0, ggml_argsort_top_k(ctx0, indexer_score, n_top_k)); + cb(top_k, "top_k", il); + + // modify kq mask by masking tokens that are not in top_k indices + ggml_tensor * kq_mask_top_k = ggml_map_custom2(ctx0, kq_mask_f32, top_k, mask_top_k_callback, GGML_DEFAULT_N_THREADS, NULL); + cb(kq_mask_top_k, "kq_mask_top_k", il); + + ggml_build_forward_expand(gf, ggml_cpy(ctx0, ggml_cast(ctx0, kq_mask_top_k, kq_mask->type), kq_mask)); } ggml_tensor * q = ggml_mul_mat(ctx0, model.layers[il].wq_b, qr); @@ -230,6 +328,8 @@ llm_build_deepseek32::llm_build_deepseek32(const llama_model & model, const llm_ model.layers[il].wo, NULL, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); } + + ggml_build_forward_expand(gf, ggml_cpy(ctx0, kq_mask_bak, kq_mask)); } if (il == effective_n_layers - 1 && inp_out_ids) { cur = ggml_get_rows(ctx0, cur, inp_out_ids);