model : GGML_OP_SCATTER and GGML_OP_FILL now work with f16 data, so we can get rid of the ggml_cast() calls in the sparse attention implementation
This commit is contained in:
parent
1c830a178b
commit
83a0313a14
|
|
@ -2171,17 +2171,14 @@ ggml_tensor * llm_graph_context::build_attn(
|
||||||
|
|
||||||
const auto & kq_mask = inp->get_kq_mask();
|
const auto & kq_mask = inp->get_kq_mask();
|
||||||
|
|
||||||
ggml_tensor * kq_mask_f32 = ggml_cast(ctx0, kq_mask, GGML_TYPE_F32);
|
|
||||||
|
|
||||||
// prepare new kq mask - starts filled with -INFINITY
|
// prepare new kq mask - starts filled with -INFINITY
|
||||||
ggml_tensor * kq_mask_all = ggml_fill(ctx0, kq_mask_f32, -INFINITY);
|
ggml_tensor * kq_mask_all = ggml_fill(ctx0, kq_mask, -INFINITY);
|
||||||
|
|
||||||
// modify it by unmasking tokens that are in top_k indices
|
// modify it by unmasking tokens that are in top_k indices
|
||||||
ggml_tensor * kq_mask_top_k = ggml_scatter(ctx0, kq_mask_all, top_k, 0);
|
ggml_tensor * kq_mask_top_k = ggml_scatter(ctx0, kq_mask_all, top_k, 0);
|
||||||
|
|
||||||
// combine with the original kq mask
|
// combine with the original kq mask
|
||||||
kq_mask_top_k = ggml_add(ctx0, kq_mask_top_k, kq_mask_f32);
|
kq_mask_top_k = ggml_add(ctx0, kq_mask_top_k, kq_mask);
|
||||||
kq_mask_top_k = ggml_cast(ctx0, kq_mask_top_k, kq_mask->type);
|
|
||||||
|
|
||||||
ggml_tensor * q = q_cur;
|
ggml_tensor * q = q_cur;
|
||||||
ggml_tensor * k = mctx_cur->get_k(ctx0, il);
|
ggml_tensor * k = mctx_cur->get_k(ctx0, il);
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue