model : GGML_OP_SCATTER and GGML_OP_FILL now work with f16 data, so we can get rid of ggml_cast() calls in sparse attention implementation

This commit is contained in:
Stanisław Szymczyk 2026-03-25 11:35:57 +01:00
parent 1c830a178b
commit 83a0313a14
1 changed file with 2 additions and 5 deletions

View File

@ -2171,17 +2171,14 @@ ggml_tensor * llm_graph_context::build_attn(
const auto & kq_mask = inp->get_kq_mask();
ggml_tensor * kq_mask_f32 = ggml_cast(ctx0, kq_mask, GGML_TYPE_F32);
// prepare new kq mask - starts filled with -INFINITY
ggml_tensor * kq_mask_all = ggml_fill(ctx0, kq_mask_f32, -INFINITY);
ggml_tensor * kq_mask_all = ggml_fill(ctx0, kq_mask, -INFINITY);
// modify it by unmasking tokens that are in top_k indices
ggml_tensor * kq_mask_top_k = ggml_scatter(ctx0, kq_mask_all, top_k, 0);
// combine with the original kq mask
kq_mask_top_k = ggml_add(ctx0, kq_mask_top_k, kq_mask_f32);
kq_mask_top_k = ggml_cast(ctx0, kq_mask_top_k, kq_mask->type);
kq_mask_top_k = ggml_add(ctx0, kq_mask_top_k, kq_mask);
ggml_tensor * q = q_cur;
ggml_tensor * k = mctx_cur->get_k(ctx0, il);