model : crude proof-of-concept implementation of the DSA indexer for DeepSeek V3.2.

This commit is contained in:
Stanisław Szymczyk 2026-03-14 20:20:39 +01:00
parent 961bc95d96
commit 9a63e7ab76
1 changed file with 102 additions and 2 deletions

View File

@ -1,5 +1,38 @@
#include "models.h"
#include "llama-kv-cache.h"
// Custom ggml op (ggml_map_custom2 callback) that sparsifies the attention mask:
// a   = kq_mask (F32), b = top-k indices (I32, k indices per mask row),
// dst = output mask (F32, same shape as a).
// Each row of dst is first fully masked out (-inf), then the positions listed in
// the corresponding row of b are restored from a, so attention is restricted to
// the top-k tokens selected by the indexer.
// NOTE(review): flat indexing assumes a, b and dst are contiguous — they come
// from ggml_cast/ggml_cont in the graph that builds this op; confirm before
// reusing the callback elsewhere.
void mask_top_k_callback(struct ggml_tensor * dst, const struct ggml_tensor * a, const struct ggml_tensor * b, int ith, int nth, void * userdata) {
    (void) userdata; // unused

    const int64_t n_tokens = a->ne[0]; // row length (number of KV positions)
    const int64_t k        = b->ne[0]; // number of indices kept per row

    // iterate over all rows across dims 1..3 (not just ne[1]) so that 3-D/4-D
    // masks with a batch/stream dimension are handled as well; int64_t avoids
    // narrowing the int64_t ggml dims into int
    const int64_t n_rows = a->ne[1]*a->ne[2]*a->ne[3];

    // data pointers (mask was cast to F32 and top_k is I32 in the graph)
    const float   * mask_data = (const float   *) a->data;
    const int32_t * topk_data = (const int32_t *) b->data;
    float         * dst_data  = (float         *) dst->data;

    // split the rows evenly across the nth worker threads
    const int64_t start_row = (n_rows*ith)       / nth;
    const int64_t end_row   = (n_rows*(ith + 1)) / nth;

    for (int64_t i = start_row; i < end_row; ++i) {
        const float   * mask_row = mask_data + i*n_tokens;
        const int32_t * topk_row = topk_data + i*k;
        float         * dst_row  = dst_data  + i*n_tokens;

        // first, mask out the entire row ...
        for (int64_t j = 0; j < n_tokens; ++j) {
            dst_row[j] = -INFINITY;
        }

        // ... then restore the values at the top-k positions; out-of-range
        // indices (e.g. padding) are ignored
        for (int64_t j = 0; j < k; ++j) {
            const int32_t keep_idx = topk_row[j];
            if (keep_idx >= 0 && keep_idx < n_tokens) {
                dst_row[keep_idx] = mask_row[keep_idx];
            }
        }
    }
}
llm_build_deepseek32::llm_build_deepseek32(const llama_model & model, const llm_graph_params & params) :
llm_graph_context(params) {
const bool is_mla = hparams.is_mla();
@ -15,6 +48,7 @@ llm_build_deepseek32::llm_build_deepseek32(const llama_model & model, const llm_
const int64_t n_embd_indexer_head = hparams.indexer_head_size;
const int64_t n_embd_indexer_head_rope = hparams.n_rot();
const int64_t n_embd_indexer_head_nope = n_embd_indexer_head - n_embd_indexer_head_rope;
const uint32_t n_indexer_top_k = hparams.indexer_top_k;
const uint32_t kv_lora_rank = hparams.n_lora_kv;
@ -60,6 +94,10 @@ llm_build_deepseek32::llm_build_deepseek32(const llama_model & model, const llm_
qr = build_norm(qr, model.layers[il].attn_q_a_norm, nullptr, LLM_NORM_RMS, il);
cb(qr, "qr", il);
ggml_tensor * kq_mask = is_mla ? inp_attn_k->get_kq_mask() : inp_attn_kv->get_kq_mask();
ggml_tensor * kq_mask_bak = ggml_dup(ctx0, kq_mask);
ggml_build_forward_expand(gf, kq_mask_bak);
// lightning indexer
{
ggml_tensor * indexer_q = ggml_mul_mat(ctx0, model.layers[il].indexer_attn_q_b, qr);
@ -119,8 +157,68 @@ llm_build_deepseek32::llm_build_deepseek32(const llama_model & model, const llm_
indexer_k = ggml_concat(ctx0, indexer_k_pe, indexer_k_nope, 0);
cb(indexer_k, "indexer_k", il);
ggml_build_forward_expand(gf, indexer_q);
ggml_build_forward_expand(gf, indexer_k);
indexer_q = ggml_hadamard(ctx0, indexer_q, n_embd_indexer_head);
cb(indexer_q, "indexer_q", il);
indexer_k = ggml_hadamard(ctx0, indexer_k, n_embd_indexer_head);
cb(indexer_k, "indexer_k", il);
// store to KV cache
const auto * mctx_cur = is_mla ? inp_attn_k->mctx : inp_attn_kv->mctx;
const auto & k_idxs = is_mla ? inp_attn_k->get_k_idxs() : inp_attn_kv->get_k_idxs();
ggml_build_forward_expand(gf, mctx_cur->cpy_ik(ctx0, indexer_k, k_idxs, il));
ggml_tensor * indexer_weights = ggml_mul_mat(ctx0, model.layers[il].indexer_proj, cur);
cb(indexer_weights, "indexer_weights", il);
indexer_weights = ggml_scale(ctx0, indexer_weights, 1.0f / sqrtf(float(n_indexer_head)));
cb(indexer_weights, "indexer_weights", il);
// get cached indexer keys
indexer_k = mctx_cur->get_ik(ctx0, il);
indexer_q = ggml_permute(ctx0, indexer_q, 0, 2, 1, 3);
cb(indexer_q, "indexer_q", il);
indexer_k = ggml_permute(ctx0, indexer_k, 0, 2, 1, 3);
cb(indexer_k, "indexer_k", il);
ggml_tensor * indexer_kq = ggml_mul_mat(ctx0, indexer_k, indexer_q);
cb(indexer_kq, "indexer_kq", il);
indexer_kq = ggml_cont(ctx0, ggml_permute(ctx0, indexer_kq, 2, 1, 0, 3));
cb(indexer_kq, "indexer_kq", il);
ggml_tensor * indexer_score = ggml_relu(ctx0, indexer_kq);
cb(indexer_score, "indexer_score", il);
indexer_score = ggml_mul(ctx0, indexer_score, indexer_weights);
cb(indexer_score, "indexer_score", il);
indexer_score = ggml_sum_rows(ctx0, indexer_score);
cb(indexer_score, "indexer_score", il);
indexer_score = ggml_permute(ctx0, indexer_score, 2, 1, 0, 3);
cb(indexer_score, "indexer_score", il);
indexer_score = ggml_cont(ctx0, indexer_score);
cb(indexer_score, "indexer_score", il);
indexer_score = ggml_scale(ctx0, indexer_score, 1.0f / sqrtf(float(n_embd_indexer_head)));
cb(indexer_score, "indexer_score", il);
// mask indexer scores
ggml_tensor * kq_mask_f32 = ggml_cast(ctx0, kq_mask, GGML_TYPE_F32);
indexer_score = ggml_add(ctx0, indexer_score, kq_mask_f32);
cb(indexer_score, "indexer_score", il);
uint32_t n_top_k = indexer_score->ne[0] < n_indexer_top_k ? indexer_score->ne[0] : n_indexer_top_k;
ggml_tensor * top_k = ggml_cont(ctx0, ggml_argsort_top_k(ctx0, indexer_score, n_top_k));
cb(top_k, "top_k", il);
// modify kq mask by masking tokens that are not in top_k indices
ggml_tensor * kq_mask_top_k = ggml_map_custom2(ctx0, kq_mask_f32, top_k, mask_top_k_callback, GGML_DEFAULT_N_THREADS, NULL);
cb(kq_mask_top_k, "kq_mask_top_k", il);
ggml_build_forward_expand(gf, ggml_cpy(ctx0, ggml_cast(ctx0, kq_mask_top_k, kq_mask->type), kq_mask));
}
ggml_tensor * q = ggml_mul_mat(ctx0, model.layers[il].wq_b, qr);
@ -230,6 +328,8 @@ llm_build_deepseek32::llm_build_deepseek32(const llama_model & model, const llm_
model.layers[il].wo, NULL,
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
}
ggml_build_forward_expand(gf, ggml_cpy(ctx0, kq_mask_bak, kq_mask));
}
if (il == effective_n_layers - 1 && inp_out_ids) {
cur = ggml_get_rows(ctx0, cur, inp_out_ids);