model : used new GGML_OP_WHERE_ID op in DeepSeek V3.2 lightning indexer implementation

This commit is contained in:
Stanisław Szymczyk 2026-03-15 22:09:33 +01:00
parent 08dc7fd9d9
commit 998f496475
1 changed files with 6 additions and 33 deletions

View File

@ -2,37 +2,6 @@
#include "llama-kv-cache.h"
void mask_top_k_callback(struct ggml_tensor * dst, const struct ggml_tensor * a, const struct ggml_tensor * b, int ith, int nth, void * userdata) {
// a = kq_mask, b = top_k, dst = output tensor
const int n_seq = a->ne[1];
const int n_tokens = a->ne[0];
const int k = b->ne[0];
// Get data pointers (assuming F32 for mask, I32 for indices)
const float * mask_data = (const float *) a->data;
const int32_t * topk_data = (const int32_t *) b->data;
float * dst_data = (float *) dst->data;
// Distribute work across threads if nth > 1
const int start_row = (n_seq * ith) / nth;
const int end_row = (n_seq * (ith + 1)) / nth;
for (int i = start_row; i < end_row; ++i) {
// First, set the entire row to -inf
for (int j = 0; j < n_tokens; ++j) {
dst_data[i * n_tokens + j] = -INFINITY;
}
// Then, restore the values indicated by top_k
for (int j = 0; j < k; ++j) {
int32_t keep_idx = topk_data[i * k + j];
if (keep_idx >= 0 && keep_idx < n_tokens) {
dst_data[i * n_tokens + keep_idx] = mask_data[i * n_tokens + keep_idx];
}
}
}
}
llm_build_deepseek32::llm_build_deepseek32(const llama_model & model, const llm_graph_params & params) :
llm_graph_context(params) {
const bool is_mla = hparams.is_mla();
@ -214,8 +183,12 @@ llm_build_deepseek32::llm_build_deepseek32(const llama_model & model, const llm_
ggml_tensor * top_k = ggml_cont(ctx0, ggml_argsort_top_k(ctx0, indexer_score, n_top_k));
cb(top_k, "top_k", il);
// modify kq mask by masking tokens that are not in top_k indices
ggml_tensor * kq_mask_top_k = ggml_map_custom2(ctx0, kq_mask_f32, top_k, mask_top_k_callback, GGML_DEFAULT_N_THREADS, NULL);
// prepare new kq mask - starts filled with -INFINITY
ggml_tensor * kq_mask_all = ggml_fill(ctx0, kq_mask_f32, -INFINITY);
cb(kq_mask_all, "kq_mask_all", il);
// modify it by unmasking tokens that are in top_k indices
ggml_tensor * kq_mask_top_k = ggml_where_id(ctx0, kq_mask_f32, kq_mask_all, top_k);
cb(kq_mask_top_k, "kq_mask_top_k", il);
ggml_build_forward_expand(gf, ggml_cpy(ctx0, ggml_cast(ctx0, kq_mask_top_k, kq_mask->type), kq_mask));