model : crude proof-of-concept implementation of the DSA indexer for DeepSeek V3.2.
This commit is contained in:
parent
961bc95d96
commit
9a63e7ab76
|
|
@ -1,5 +1,38 @@
|
|||
#include "models.h"
|
||||
|
||||
#include "llama-kv-cache.h"
|
||||
|
||||
void mask_top_k_callback(struct ggml_tensor * dst, const struct ggml_tensor * a, const struct ggml_tensor * b, int ith, int nth, void * userdata) {
|
||||
// a = kq_mask, b = top_k, dst = output tensor
|
||||
const int n_seq = a->ne[1];
|
||||
const int n_tokens = a->ne[0];
|
||||
const int k = b->ne[0];
|
||||
|
||||
// Get data pointers (assuming F32 for mask, I32 for indices)
|
||||
const float * mask_data = (const float *) a->data;
|
||||
const int32_t * topk_data = (const int32_t *) b->data;
|
||||
float * dst_data = (float *) dst->data;
|
||||
|
||||
// Distribute work across threads if nth > 1
|
||||
const int start_row = (n_seq * ith) / nth;
|
||||
const int end_row = (n_seq * (ith + 1)) / nth;
|
||||
|
||||
for (int i = start_row; i < end_row; ++i) {
|
||||
// First, set the entire row to -inf
|
||||
for (int j = 0; j < n_tokens; ++j) {
|
||||
dst_data[i * n_tokens + j] = -INFINITY;
|
||||
}
|
||||
|
||||
// Then, restore the values indicated by top_k
|
||||
for (int j = 0; j < k; ++j) {
|
||||
int32_t keep_idx = topk_data[i * k + j];
|
||||
if (keep_idx >= 0 && keep_idx < n_tokens) {
|
||||
dst_data[i * n_tokens + keep_idx] = mask_data[i * n_tokens + keep_idx];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
llm_build_deepseek32::llm_build_deepseek32(const llama_model & model, const llm_graph_params & params) :
|
||||
llm_graph_context(params) {
|
||||
const bool is_mla = hparams.is_mla();
|
||||
|
|
@ -15,6 +48,7 @@ llm_build_deepseek32::llm_build_deepseek32(const llama_model & model, const llm_
|
|||
const int64_t n_embd_indexer_head = hparams.indexer_head_size;
|
||||
const int64_t n_embd_indexer_head_rope = hparams.n_rot();
|
||||
const int64_t n_embd_indexer_head_nope = n_embd_indexer_head - n_embd_indexer_head_rope;
|
||||
const uint32_t n_indexer_top_k = hparams.indexer_top_k;
|
||||
|
||||
const uint32_t kv_lora_rank = hparams.n_lora_kv;
|
||||
|
||||
|
|
@ -60,6 +94,10 @@ llm_build_deepseek32::llm_build_deepseek32(const llama_model & model, const llm_
|
|||
qr = build_norm(qr, model.layers[il].attn_q_a_norm, nullptr, LLM_NORM_RMS, il);
|
||||
cb(qr, "qr", il);
|
||||
|
||||
ggml_tensor * kq_mask = is_mla ? inp_attn_k->get_kq_mask() : inp_attn_kv->get_kq_mask();
|
||||
ggml_tensor * kq_mask_bak = ggml_dup(ctx0, kq_mask);
|
||||
ggml_build_forward_expand(gf, kq_mask_bak);
|
||||
|
||||
// lightning indexer
|
||||
{
|
||||
ggml_tensor * indexer_q = ggml_mul_mat(ctx0, model.layers[il].indexer_attn_q_b, qr);
|
||||
|
|
@ -119,8 +157,68 @@ llm_build_deepseek32::llm_build_deepseek32(const llama_model & model, const llm_
|
|||
indexer_k = ggml_concat(ctx0, indexer_k_pe, indexer_k_nope, 0);
|
||||
cb(indexer_k, "indexer_k", il);
|
||||
|
||||
ggml_build_forward_expand(gf, indexer_q);
|
||||
ggml_build_forward_expand(gf, indexer_k);
|
||||
indexer_q = ggml_hadamard(ctx0, indexer_q, n_embd_indexer_head);
|
||||
cb(indexer_q, "indexer_q", il);
|
||||
indexer_k = ggml_hadamard(ctx0, indexer_k, n_embd_indexer_head);
|
||||
cb(indexer_k, "indexer_k", il);
|
||||
|
||||
// store to KV cache
|
||||
const auto * mctx_cur = is_mla ? inp_attn_k->mctx : inp_attn_kv->mctx;
|
||||
const auto & k_idxs = is_mla ? inp_attn_k->get_k_idxs() : inp_attn_kv->get_k_idxs();
|
||||
ggml_build_forward_expand(gf, mctx_cur->cpy_ik(ctx0, indexer_k, k_idxs, il));
|
||||
|
||||
ggml_tensor * indexer_weights = ggml_mul_mat(ctx0, model.layers[il].indexer_proj, cur);
|
||||
cb(indexer_weights, "indexer_weights", il);
|
||||
|
||||
indexer_weights = ggml_scale(ctx0, indexer_weights, 1.0f / sqrtf(float(n_indexer_head)));
|
||||
cb(indexer_weights, "indexer_weights", il);
|
||||
|
||||
// get cached indexer keys
|
||||
indexer_k = mctx_cur->get_ik(ctx0, il);
|
||||
|
||||
indexer_q = ggml_permute(ctx0, indexer_q, 0, 2, 1, 3);
|
||||
cb(indexer_q, "indexer_q", il);
|
||||
indexer_k = ggml_permute(ctx0, indexer_k, 0, 2, 1, 3);
|
||||
cb(indexer_k, "indexer_k", il);
|
||||
|
||||
ggml_tensor * indexer_kq = ggml_mul_mat(ctx0, indexer_k, indexer_q);
|
||||
cb(indexer_kq, "indexer_kq", il);
|
||||
|
||||
indexer_kq = ggml_cont(ctx0, ggml_permute(ctx0, indexer_kq, 2, 1, 0, 3));
|
||||
cb(indexer_kq, "indexer_kq", il);
|
||||
|
||||
ggml_tensor * indexer_score = ggml_relu(ctx0, indexer_kq);
|
||||
cb(indexer_score, "indexer_score", il);
|
||||
|
||||
indexer_score = ggml_mul(ctx0, indexer_score, indexer_weights);
|
||||
cb(indexer_score, "indexer_score", il);
|
||||
|
||||
indexer_score = ggml_sum_rows(ctx0, indexer_score);
|
||||
cb(indexer_score, "indexer_score", il);
|
||||
|
||||
indexer_score = ggml_permute(ctx0, indexer_score, 2, 1, 0, 3);
|
||||
cb(indexer_score, "indexer_score", il);
|
||||
|
||||
indexer_score = ggml_cont(ctx0, indexer_score);
|
||||
cb(indexer_score, "indexer_score", il);
|
||||
|
||||
indexer_score = ggml_scale(ctx0, indexer_score, 1.0f / sqrtf(float(n_embd_indexer_head)));
|
||||
cb(indexer_score, "indexer_score", il);
|
||||
|
||||
// mask indexer scores
|
||||
ggml_tensor * kq_mask_f32 = ggml_cast(ctx0, kq_mask, GGML_TYPE_F32);
|
||||
indexer_score = ggml_add(ctx0, indexer_score, kq_mask_f32);
|
||||
cb(indexer_score, "indexer_score", il);
|
||||
|
||||
uint32_t n_top_k = indexer_score->ne[0] < n_indexer_top_k ? indexer_score->ne[0] : n_indexer_top_k;
|
||||
ggml_tensor * top_k = ggml_cont(ctx0, ggml_argsort_top_k(ctx0, indexer_score, n_top_k));
|
||||
cb(top_k, "top_k", il);
|
||||
|
||||
// modify kq mask by masking tokens that are not in top_k indices
|
||||
ggml_tensor * kq_mask_top_k = ggml_map_custom2(ctx0, kq_mask_f32, top_k, mask_top_k_callback, GGML_DEFAULT_N_THREADS, NULL);
|
||||
cb(kq_mask_top_k, "kq_mask_top_k", il);
|
||||
|
||||
ggml_build_forward_expand(gf, ggml_cpy(ctx0, ggml_cast(ctx0, kq_mask_top_k, kq_mask->type), kq_mask));
|
||||
}
|
||||
|
||||
ggml_tensor * q = ggml_mul_mat(ctx0, model.layers[il].wq_b, qr);
|
||||
|
|
@ -230,6 +328,8 @@ llm_build_deepseek32::llm_build_deepseek32(const llama_model & model, const llm_
|
|||
model.layers[il].wo, NULL,
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
|
||||
}
|
||||
|
||||
ggml_build_forward_expand(gf, ggml_cpy(ctx0, kq_mask_bak, kq_mask));
|
||||
}
|
||||
if (il == effective_n_layers - 1 && inp_out_ids) {
|
||||
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
||||
|
|
|
|||
Loading…
Reference in New Issue