model : GGML_OP_SCATTER AND GGML_OP_FILL now work with f16 data, so we can get rid of ggml_cast() calls in sparse attention implementation
This commit is contained in:
parent
1c830a178b
commit
83a0313a14
|
|
@ -2171,17 +2171,14 @@ ggml_tensor * llm_graph_context::build_attn(
|
|||
|
||||
const auto & kq_mask = inp->get_kq_mask();
|
||||
|
||||
ggml_tensor * kq_mask_f32 = ggml_cast(ctx0, kq_mask, GGML_TYPE_F32);
|
||||
|
||||
// prepare new kq mask - starts filled with -INFINITY
|
||||
ggml_tensor * kq_mask_all = ggml_fill(ctx0, kq_mask_f32, -INFINITY);
|
||||
ggml_tensor * kq_mask_all = ggml_fill(ctx0, kq_mask, -INFINITY);
|
||||
|
||||
// modify it by unmasking tokens that are in top_k indices
|
||||
ggml_tensor * kq_mask_top_k = ggml_scatter(ctx0, kq_mask_all, top_k, 0);
|
||||
|
||||
// combine with the original kq mask
|
||||
kq_mask_top_k = ggml_add(ctx0, kq_mask_top_k, kq_mask_f32);
|
||||
kq_mask_top_k = ggml_cast(ctx0, kq_mask_top_k, kq_mask->type);
|
||||
kq_mask_top_k = ggml_add(ctx0, kq_mask_top_k, kq_mask);
|
||||
|
||||
ggml_tensor * q = q_cur;
|
||||
ggml_tensor * k = mctx_cur->get_k(ctx0, il);
|
||||
|
|
|
|||
Loading…
Reference in New Issue