diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 29d804638c..21a4158c79 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -2171,17 +2171,14 @@ ggml_tensor * llm_graph_context::build_attn( const auto & kq_mask = inp->get_kq_mask(); - ggml_tensor * kq_mask_f32 = ggml_cast(ctx0, kq_mask, GGML_TYPE_F32); - // prepare new kq mask - starts filled with -INFINITY - ggml_tensor * kq_mask_all = ggml_fill(ctx0, kq_mask_f32, -INFINITY); + ggml_tensor * kq_mask_all = ggml_fill(ctx0, kq_mask, -INFINITY); // modify it by unmasking tokens that are in top_k indices ggml_tensor * kq_mask_top_k = ggml_scatter(ctx0, kq_mask_all, top_k, 0); // combine with the original kq mask - kq_mask_top_k = ggml_add(ctx0, kq_mask_top_k, kq_mask_f32); - kq_mask_top_k = ggml_cast(ctx0, kq_mask_top_k, kq_mask->type); + kq_mask_top_k = ggml_add(ctx0, kq_mask_top_k, kq_mask); ggml_tensor * q = q_cur; ggml_tensor * k = mctx_cur->get_k(ctx0, il);