From 07ba6d275b0f5c138c72f75d7f3df2661f17c27a Mon Sep 17 00:00:00 2001 From: Chenguang Li <757486878@qq.com> Date: Thu, 19 Mar 2026 11:02:42 +0800 Subject: [PATCH] CANN: support flash attention for head dim not multiple of 16, fix ALiBi slope offset (#20031) - Allow FLASH_ATTN_EXT when head dimension D is not a multiple of 16 by padding Q/K/V to D_padded = GGML_PAD(D, 16), running FusedInferAttentionScoreV2, then slicing the output back to D (ggml-cann.cpp + aclnn_ops.cpp). - Fix aclnn_get_slope second-part offset: use ggml_type_size(dtype) instead of sizeof(float) so ALiBi slopes are correct when dtype is F16 (e.g. GQA with 48 heads); fixes buffer overflow and large numerical errors in those cases. --- ggml/src/ggml-cann/aclnn_ops.cpp | 78 ++++++++++++++++++++++++++++---- ggml/src/ggml-cann/ggml-cann.cpp | 4 -- 2 files changed, 70 insertions(+), 12 deletions(-) diff --git a/ggml/src/ggml-cann/aclnn_ops.cpp b/ggml/src/ggml-cann/aclnn_ops.cpp index fc7c3e3b72..4b7aab1e72 100644 --- a/ggml/src/ggml-cann/aclnn_ops.cpp +++ b/ggml/src/ggml-cann/aclnn_ops.cpp @@ -1544,8 +1544,8 @@ static void aclnn_get_slope(ggml_backend_cann_context & ctx, end = 2 * ((n_head - 1) - n_head_log2) + 1; step = 2; count = n_head - n_head_log2; - aclnn_get_slope_inner(ctx, (char *) slope_buffer + n_head_log2 * sizeof(float), m1, count, start, end + 1, step, - dtype); + aclnn_get_slope_inner(ctx, (char *) slope_buffer + n_head_log2 * ggml_type_size(dtype), m1, count, start, end + 1, + step, dtype); } } @@ -3599,6 +3599,44 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context & ctx, ggml_tensor * dst acl_k_tensor = ggml_cann_create_tensor(src1, src1_bsnd_ne, src1_bsnd_nb, GGML_MAX_DIMS); acl_v_tensor = ggml_cann_create_tensor(src2, src2_bsnd_ne, src2_bsnd_nb, GGML_MAX_DIMS); + // Step 2.5: Pad Q, K, V along head dimension if D is not a multiple of 16 + // (required by FusedInferAttentionScoreV2) + const int64_t D = src0->ne[0]; + const int64_t D_padded = GGML_PAD(D, 16); + const bool needs_padding = (D != D_padded); + + ggml_cann_pool_alloc q_pad_allocator(ctx.pool()); + ggml_cann_pool_alloc k_pad_allocator(ctx.pool()); + ggml_cann_pool_alloc v_pad_allocator(ctx.pool()); + + if (needs_padding) { + int64_t paddings[] = { 0, D_padded - D, 0, 0, 0, 0, 0, 0 }; + + auto pad_fa_tensor = [&](acl_tensor_ptr & tensor, const int64_t * bsnd_ne, + ggml_cann_pool_alloc & allocator) { + int64_t pad_ne[GGML_MAX_DIMS] = { D_padded, bsnd_ne[1], bsnd_ne[2], bsnd_ne[3] }; + size_t pad_nb[GGML_MAX_DIMS]; + pad_nb[0] = faElemSize; + for (int i = 1; i < GGML_MAX_DIMS; ++i) { + pad_nb[i] = pad_nb[i - 1] * pad_ne[i - 1]; + } + int64_t nelements = pad_ne[0] * pad_ne[1] * pad_ne[2] * pad_ne[3]; + void * buffer = allocator.alloc(nelements * faElemSize); + acl_tensor_ptr padded = + ggml_cann_create_tensor(buffer, faDataType, faElemSize, pad_ne, pad_nb, GGML_MAX_DIMS); + aclnn_pad(ctx, tensor.get(), padded.get(), paddings); + tensor = std::move(padded); + }; + + pad_fa_tensor(acl_q_tensor, src0_bsnd_ne, q_pad_allocator); + pad_fa_tensor(acl_k_tensor, src1_bsnd_ne, k_pad_allocator); + pad_fa_tensor(acl_v_tensor, src2_bsnd_ne, v_pad_allocator); + + src0_bsnd_ne[0] = D_padded; + src1_bsnd_ne[0] = D_padded; + src2_bsnd_ne[0] = D_padded; + } + // Step 3: create the PSEShift tensor if needed // this tensor is considered as mask (f16) in the llama.cpp acl_tensor_ptr bcast_pse_tensor; @@ -3688,17 +3726,16 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context & ctx, ggml_tensor * dst GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); acl_tensor_ptr fa_dst_tensor; - acl_tensor_ptr acl_dst_tensor; ggml_cann_pool_alloc out_f16_allocator(ctx.pool()); - if (dst->type == GGML_TYPE_F32) { - void * out_f16_buffer = out_f16_allocator.alloc(ggml_nelements(dst) * faElemSize); - + if (dst->type == GGML_TYPE_F32 || needs_padding) { int64_t * out_f16_ne = src0_bsnd_ne; size_t out_f16_nb[GGML_MAX_DIMS]; out_f16_nb[0] = faElemSize; for (int i = 1; i < GGML_MAX_DIMS; ++i) { out_f16_nb[i] = out_f16_nb[i - 1] * out_f16_ne[i - 1]; } + int64_t out_nelements = out_f16_ne[0] * out_f16_ne[1] * out_f16_ne[2] * out_f16_ne[3]; + void * out_f16_buffer = out_f16_allocator.alloc(out_nelements * faElemSize); fa_dst_tensor = ggml_cann_create_tensor(out_f16_buffer, faDataType, faElemSize, out_f16_ne, out_f16_nb, GGML_MAX_DIMS); @@ -3730,8 +3767,33 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context & ctx, ggml_tensor * dst nullptr // softmaxLse ); - if (dst->type == GGML_TYPE_F32) { - // Step 6: post-processing, permute and cast to f32 + // Step 6: post-processing — slice padded output and/or cast to f32 + if (needs_padding) { + ggml_cann_pool_alloc sliced_f16_allocator(ctx.pool()); + + if (dst->type == GGML_TYPE_F32) { + int64_t sliced_ne[GGML_MAX_DIMS] = { D, src0_bsnd_ne[1], src0_bsnd_ne[2], src0_bsnd_ne[3] }; + size_t sliced_nb[GGML_MAX_DIMS]; + sliced_nb[0] = faElemSize; + for (int i = 1; i < GGML_MAX_DIMS; ++i) { + sliced_nb[i] = sliced_nb[i - 1] * sliced_ne[i - 1]; + } + int64_t sliced_nelements = sliced_ne[0] * sliced_ne[1] * sliced_ne[2] * sliced_ne[3]; + void * sliced_buffer = sliced_f16_allocator.alloc(sliced_nelements * faElemSize); + acl_tensor_ptr sliced_f16_tensor = ggml_cann_create_tensor(sliced_buffer, faDataType, faElemSize, + sliced_ne, sliced_nb, GGML_MAX_DIMS); + + GGML_CANN_CALL_ACLNN_OP(ctx, Slice, fa_dst_tensor.get(), + (int64_t) -1, (int64_t) 0, D, (int64_t) 1, sliced_f16_tensor.get()); + + acl_tensor_ptr acl_dst_tensor = ggml_cann_create_tensor(dst); + aclnn_cast(ctx, sliced_f16_tensor.get(), acl_dst_tensor.get(), ggml_cann_type_mapping(dst->type)); + } else { + acl_tensor_ptr acl_dst_tensor = ggml_cann_create_tensor(dst); + GGML_CANN_CALL_ACLNN_OP(ctx, Slice, fa_dst_tensor.get(), + (int64_t) -1, (int64_t) 0, D, (int64_t) 1, acl_dst_tensor.get()); + } + } else if (dst->type == GGML_TYPE_F32) { acl_tensor_ptr acl_dst_tensor = ggml_cann_create_tensor(dst); aclnn_cast(ctx, fa_dst_tensor.get(), acl_dst_tensor.get(), ggml_cann_type_mapping(dst->type)); } diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp index 3f3de9f0bc..a682746bb4 100644 --- a/ggml/src/ggml-cann/ggml-cann.cpp +++ b/ggml/src/ggml-cann/ggml-cann.cpp @@ -2503,10 +2503,6 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten // different head sizes of K and V are not supported yet return false; } - if (op->src[0]->ne[0] % 16 != 0) { - // TODO: padding to support - return false; - } float logitSoftcap = 0.0f; memcpy(&logitSoftcap, (const float *) (op->op_params) + 2, sizeof(float)); if (logitSoftcap != 0.0f) {