model : used new GGML_OP_WHERE_ID op in DeepSeek V3.2 lightning indexer implementation
This commit is contained in:
parent
08dc7fd9d9
commit
998f496475
|
|
@ -2,37 +2,6 @@
|
|||
|
||||
#include "llama-kv-cache.h"
|
||||
|
||||
void mask_top_k_callback(struct ggml_tensor * dst, const struct ggml_tensor * a, const struct ggml_tensor * b, int ith, int nth, void * userdata) {
|
||||
// a = kq_mask, b = top_k, dst = output tensor
|
||||
const int n_seq = a->ne[1];
|
||||
const int n_tokens = a->ne[0];
|
||||
const int k = b->ne[0];
|
||||
|
||||
// Get data pointers (assuming F32 for mask, I32 for indices)
|
||||
const float * mask_data = (const float *) a->data;
|
||||
const int32_t * topk_data = (const int32_t *) b->data;
|
||||
float * dst_data = (float *) dst->data;
|
||||
|
||||
// Distribute work across threads if nth > 1
|
||||
const int start_row = (n_seq * ith) / nth;
|
||||
const int end_row = (n_seq * (ith + 1)) / nth;
|
||||
|
||||
for (int i = start_row; i < end_row; ++i) {
|
||||
// First, set the entire row to -inf
|
||||
for (int j = 0; j < n_tokens; ++j) {
|
||||
dst_data[i * n_tokens + j] = -INFINITY;
|
||||
}
|
||||
|
||||
// Then, restore the values indicated by top_k
|
||||
for (int j = 0; j < k; ++j) {
|
||||
int32_t keep_idx = topk_data[i * k + j];
|
||||
if (keep_idx >= 0 && keep_idx < n_tokens) {
|
||||
dst_data[i * n_tokens + keep_idx] = mask_data[i * n_tokens + keep_idx];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
llm_build_deepseek32::llm_build_deepseek32(const llama_model & model, const llm_graph_params & params) :
|
||||
llm_graph_context(params) {
|
||||
const bool is_mla = hparams.is_mla();
|
||||
|
|
@ -214,8 +183,12 @@ llm_build_deepseek32::llm_build_deepseek32(const llama_model & model, const llm_
|
|||
ggml_tensor * top_k = ggml_cont(ctx0, ggml_argsort_top_k(ctx0, indexer_score, n_top_k));
|
||||
cb(top_k, "top_k", il);
|
||||
|
||||
// modify kq mask by masking tokens that are not in top_k indices
|
||||
ggml_tensor * kq_mask_top_k = ggml_map_custom2(ctx0, kq_mask_f32, top_k, mask_top_k_callback, GGML_DEFAULT_N_THREADS, NULL);
|
||||
// prepare new kq mask - starts filled with -INFINITY
|
||||
ggml_tensor * kq_mask_all = ggml_fill(ctx0, kq_mask_f32, -INFINITY);
|
||||
cb(kq_mask_all, "kq_mask_all", il);
|
||||
|
||||
// modify it by unmasking tokens that are in top_k indices
|
||||
ggml_tensor * kq_mask_top_k = ggml_where_id(ctx0, kq_mask_f32, kq_mask_all, top_k);
|
||||
cb(kq_mask_top_k, "kq_mask_top_k", il);
|
||||
|
||||
ggml_build_forward_expand(gf, ggml_cpy(ctx0, ggml_cast(ctx0, kq_mask_top_k, kq_mask->type), kq_mask));
|
||||
|
|
|
|||
Loading…
Reference in New Issue