From baa4ba0aecde5bce2e801b6bd7ecf020219bf2b7 Mon Sep 17 00:00:00 2001
From: hipudding
Date: Fri, 16 Jan 2026 16:18:49 +0800
Subject: [PATCH] CANN: support gated linear attn (#18653)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* CANN: support gated linear attn

This change adds support for the GGML_OP_GATED_LINEAR_ATTN operator.
The feature was implemented by YushengZhao. Because the previous
submission was based on an outdated codebase, this PR was rebased
before merging.

Co-authored-by: YushengZhao
Co-authored-by: hipudding

* CANN: optimize OP gla

Optimize gla for high performance

* Remove unused comments

---------

Co-authored-by: 赵禹昇 <2501112001@cninfer02.localdomain>
Co-authored-by: YushengZhao
---
 ggml/src/ggml-cann/aclnn_ops.cpp | 220 ++++++++++++++++++++-----------
 ggml/src/ggml-cann/aclnn_ops.h | 123 ++++++-----------
 ggml/src/ggml-cann/ggml-cann.cpp | 4 +
 3 files changed, 186 insertions(+), 161 deletions(-)

diff --git a/ggml/src/ggml-cann/aclnn_ops.cpp b/ggml/src/ggml-cann/aclnn_ops.cpp
index 6b718e01c3..02867e4fdb 100644
--- a/ggml/src/ggml-cann/aclnn_ops.cpp
+++ b/ggml/src/ggml-cann/aclnn_ops.cpp
@@ -58,6 +58,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
 #include
@@ -2338,20 +2339,21 @@ static void aclnn_rope_cache_init(ggml_backend_cann_context & ctx,
 // Step1.2: prepare rope_yarn_ramp, if this part updated, should update theta_scale_tensor.
 // TODO: acl_yarn_ramp_tensor use rope cache.
- bool yarn_ramp_tensor_updated = false;
- acl_tensor_ptr acl_yarn_ramp_tensor;
+ bool yarn_ramp_tensor_updated = false;
+ acl_tensor_ptr acl_yarn_ramp_tensor;
 if (ext_factor != 0 &&
 (theta_scale_updated || ctx.rope_cache.theta_scale_length != theta_scale_length ||
 ctx.rope_cache.freq_scale != freq_scale)) {
 yarn_ramp_tensor_updated = true;
 if (ctx.rope_cache.yarn_ramp_cache != nullptr) {
 ACL_CHECK(aclrtFree(ctx.rope_cache.yarn_ramp_cache));
 }
- ACL_CHECK(aclrtMalloc(&ctx.rope_cache.yarn_ramp_cache, theta_scale_length * sizeof(float), ACL_MEM_MALLOC_HUGE_FIRST));
+ ACL_CHECK(aclrtMalloc(&ctx.rope_cache.yarn_ramp_cache, theta_scale_length * sizeof(float),
+ ACL_MEM_MALLOC_HUGE_FIRST));
 // -rope_yarn_ramp
 // const float y = (i0 / 2 - low) / MAX(0.001f, high - low);
 // return MIN(1, MAX(0, y)) - 1;
- acl_yarn_ramp_tensor =
- ggml_cann_create_tensor(ctx.rope_cache.yarn_ramp_cache, ACL_FLOAT, sizeof(float), theta_scale_ne, theta_scale_nb, 1);
+ acl_yarn_ramp_tensor = ggml_cann_create_tensor(ctx.rope_cache.yarn_ramp_cache, ACL_FLOAT, sizeof(float),
+ theta_scale_ne, theta_scale_nb, 1);
 float zero_value = 0, one_value = 1;
 float denom_safe_value = MAX(0.001f, corr_dims[1] - corr_dims[0]);
 acl_scalar_ptr low = ggml_cann_create_scalar(&corr_dims[0], aclDataType::ACL_FLOAT);
@@ -2382,8 +2384,8 @@ static void aclnn_rope_cache_init(ggml_backend_cann_context & ctx,
 GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMuls, acl_yarn_ramp_tensor.get(), freq_scale_1_sc.get());
 GGML_CANN_CALL_ACLNN_OP(ctx, InplaceAdds, acl_yarn_ramp_tensor.get(), freq_scale_sc.get(), one.get());
 } else {
- acl_yarn_ramp_tensor =
- ggml_cann_create_tensor(ctx.rope_cache.yarn_ramp_cache, ACL_FLOAT, sizeof(float), theta_scale_ne, theta_scale_nb, 1);
+ acl_yarn_ramp_tensor = ggml_cann_create_tensor(ctx.rope_cache.yarn_ramp_cache, ACL_FLOAT, sizeof(float),
+ theta_scale_ne, theta_scale_nb, 1);
 }
 // Step 1.3: update theta_scale_tensor according to ext_factor or freq_scale.
if (ext_factor != 0) { @@ -2991,20 +2993,20 @@ void ggml_cann_argmax(ggml_backend_cann_context & ctx, ggml_tensor * dst) { GGML_CANN_CALL_ACLNN_OP(ctx, ArgMax, acl_src.get(), 3, false, acl_dst.get()); } -void ggml_cann_conv_transpose_1d(ggml_backend_cann_context& ctx, ggml_tensor* dst){ +void ggml_cann_conv_transpose_1d(ggml_backend_cann_context & ctx, ggml_tensor * dst) { ggml_tensor * src0 = dst->src[0]; ggml_tensor * src1 = dst->src[1]; // stride - int64_t s0 = ((const int32_t*)(dst->op_params))[0]; + int64_t s0 = ((const int32_t *) (dst->op_params))[0]; - acl_tensor_ptr acl_input = ggml_cann_create_tensor(src1, src1->ne, src1->nb, 3, ACL_FORMAT_NCL); + acl_tensor_ptr acl_input = ggml_cann_create_tensor(src1, src1->ne, src1->nb, 3, ACL_FORMAT_NCL); acl_tensor_ptr acl_weight = ggml_cann_create_tensor(src0, src0->ne, src0->nb, 3, ACL_FORMAT_NCL); - acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst, dst->ne, dst->nb, 3, ACL_FORMAT_NCL); + acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst, dst->ne, dst->nb, 3, ACL_FORMAT_NCL); // get base information of input and kernel - int64_t input_len = *(src1->ne); - int64_t dst_len = *(dst->ne); + int64_t input_len = *(src1->ne); + int64_t dst_len = *(dst->ne); int64_t kernel_size = *(src0->ne); // set the max kernel size for each conv @@ -3012,56 +3014,55 @@ void ggml_cann_conv_transpose_1d(ggml_backend_cann_context& ctx, ggml_tensor* ds // compute the partition of kernel int64_t part_num = 1; - part_num = (kernel_size + max_kernel_size - 1) / max_kernel_size; + part_num = (kernel_size + max_kernel_size - 1) / max_kernel_size; int64_t strideVal[1]; - strideVal[0] = s0; - acl_int_array_ptr stride = ggml_cann_create_int_array(strideVal, 1); - int64_t paddingVal[] = {0}; - acl_int_array_ptr padding = ggml_cann_create_int_array(paddingVal, 1); - int64_t dilationVal[] = {1}; - acl_int_array_ptr dilation = ggml_cann_create_int_array(dilationVal, 1); - bool transposed = true; - int64_t groups = 1; - int8_t cubeMathType = 0; + strideVal[0] = s0; + acl_int_array_ptr stride = ggml_cann_create_int_array(strideVal, 1); + int64_t paddingVal[] = { 0 }; + acl_int_array_ptr padding = ggml_cann_create_int_array(paddingVal, 1); + int64_t dilationVal[] = { 1 }; + acl_int_array_ptr dilation = ggml_cann_create_int_array(dilationVal, 1); + bool transposed = true; + int64_t groups = 1; + int8_t cubeMathType = 0; #ifdef ASCEND_310P cubeMathType = 1; #endif auto weight_type = ggml_cann_type_mapping(src0->type); - auto dst_type = ggml_cann_type_mapping(dst->type); + auto dst_type = ggml_cann_type_mapping(dst->type); // slice the kernel to make each conv available - int64_t slice_dim = -1; + int64_t slice_dim = -1; int64_t slice_start = 0; - int64_t slice_end = max_kernel_size; - int64_t slice_step = 1; - int64_t interval = max_kernel_size; + int64_t slice_end = max_kernel_size; + int64_t slice_step = 1; + int64_t interval = max_kernel_size; - int64_t left_pad_len = dilationVal[0] * (max_kernel_size - 1) + 1 - 2 * paddingVal[0]; + int64_t left_pad_len = dilationVal[0] * (max_kernel_size - 1) + 1 - 2 * paddingVal[0]; int64_t right_pad_len = 0; - acl_scalar_ptr alpha = nullptr; - float alphaValue = 1.0; - alpha = ggml_cann_create_scalar(&alphaValue, aclDataType::ACL_FLOAT); + acl_scalar_ptr alpha = nullptr; + float alphaValue = 1.0; + alpha = ggml_cann_create_scalar(&alphaValue, aclDataType::ACL_FLOAT); // set zero to destination GGML_CANN_CALL_ACLNN_OP(ctx, InplaceZero, acl_dst.get()); - for(int k = 0; k < part_num; k++){ - + for (int k = 0; k < part_num; k++) { // 
create part kernel tensor and slice from big kernel
 slice_start = max_kernel_size * k;
- if(k == part_num - 1){
+ if (k == part_num - 1) {
 slice_end = kernel_size;
- interval = kernel_size - max_kernel_size * k;
- }else{
- slice_end = max_kernel_size * (k+1);
+ interval = kernel_size - max_kernel_size * k;
+ } else {
+ slice_end = max_kernel_size * (k + 1);
 }
 int64_t part_ne[4];
- for(int i = 0; i < 4; i++) {
+ for (int i = 0; i < 4; i++) {
 part_ne[i] = *(src0->ne + i);
 }
 part_ne[0] = interval;
@@ -3074,16 +3075,17 @@ void ggml_cann_conv_transpose_1d(ggml_backend_cann_context& ctx, ggml_tensor* ds
 ggml_cann_pool_alloc part_kernel_allocator;
 part_kernel_allocator.alloc(ctx.pool(), part_nb[3]);
- void* part_kernel_buf = part_kernel_allocator.get();
+ void * part_kernel_buf = part_kernel_allocator.get();
- acl_tensor_ptr part_kernel = ggml_cann_create_tensor(part_kernel_buf, weight_type,
- ggml_element_size(src0), part_ne, part_nb, 3, ACL_FORMAT_NCL);
+ acl_tensor_ptr part_kernel = ggml_cann_create_tensor(part_kernel_buf, weight_type, ggml_element_size(src0),
+ part_ne, part_nb, 3, ACL_FORMAT_NCL);
- GGML_CANN_CALL_ACLNN_OP(ctx, Slice, acl_weight.get(), slice_dim, slice_start, slice_end, slice_step, part_kernel.get());
+ GGML_CANN_CALL_ACLNN_OP(ctx, Slice, acl_weight.get(), slice_dim, slice_start, slice_end, slice_step,
+ part_kernel.get());
 // create the part conv result tensor
 int64_t part_dst_ne[4];
- for(int i = 0; i < 4; i++){
+ for (int i = 0; i < 4; i++) {
 part_dst_ne[i] = *(dst->ne + i);
 }
 part_dst_ne[0] = (input_len - 1) * strideVal[0] - 2 * paddingVal[0] + dilationVal[0] * (part_ne[0] - 1) + 1;
@@ -3095,32 +3097,33 @@ void ggml_cann_conv_transpose_1d(ggml_backend_cann_context& ctx, ggml_tensor* ds
 }
 ggml_cann_pool_alloc part_dst_allocator;
 part_dst_allocator.alloc(ctx.pool(), part_dst_nb[3]);
- void* part_dst_buf = part_dst_allocator.get();
+ void * part_dst_buf = part_dst_allocator.get();
 acl_tensor_ptr acl_part_dst = ggml_cann_create_tensor(part_dst_buf, dst_type, ggml_element_size(dst),
- part_dst_ne, part_dst_nb, 3, ACL_FORMAT_NCL);
+ part_dst_ne, part_dst_nb, 3, ACL_FORMAT_NCL);
 GGML_CANN_CALL_ACLNN_OP(ctx, InplaceZero, acl_part_dst.get());
 // compute part conv transpose 1d
 GGML_CANN_CALL_ACLNN_OP(ctx, Convolution, acl_input.get(), part_kernel.get(), nullptr, stride.get(),
- padding.get(), dilation.get(), transposed, padding.get(), groups, acl_part_dst.get(), cubeMathType);
+ padding.get(), dilation.get(), transposed, padding.get(), groups, acl_part_dst.get(),
+ cubeMathType);
 // compute the position of part result in final result
 int64_t global_start = slice_start;
- int64_t global_end = std::min((input_len - 1) * strideVal[0] + slice_end, dst_len);
+ int64_t global_end = std::min((input_len - 1) * strideVal[0] + slice_end, dst_len);
- left_pad_len = global_start;
+ left_pad_len = global_start;
 right_pad_len = dst_len - global_end;
- std::vector<int64_t> padDataVal = {left_pad_len,right_pad_len};
- acl_int_array_ptr padData = ggml_cann_create_int_array(padDataVal.data(), 2);
+ std::vector<int64_t> padDataVal = { left_pad_len, right_pad_len };
+ acl_int_array_ptr padData = ggml_cann_create_int_array(padDataVal.data(), 2);
- acl_scalar_ptr pad_value = nullptr;
- float pad_valueVal = 0.0;
- pad_value = ggml_cann_create_scalar(&pad_valueVal, aclDataType::ACL_FLOAT);
+ acl_scalar_ptr pad_value = nullptr;
+ float pad_valueVal = 0.0;
+ pad_value = ggml_cann_create_scalar(&pad_valueVal, aclDataType::ACL_FLOAT);
 int64_t conv_result_ne[4];
- for(int i = 0; i < 4; i++){
+ for (int i = 0; i < 4; i++) {
conv_result_ne[i] = *(dst->ne + i); } @@ -3132,13 +3135,14 @@ void ggml_cann_conv_transpose_1d(ggml_backend_cann_context& ctx, ggml_tensor* ds ggml_cann_pool_alloc conv_result_allocator; conv_result_allocator.alloc(ctx.pool(), conv_result_nb[3]); - void* conv_result_buf = conv_result_allocator.get(); + void * conv_result_buf = conv_result_allocator.get(); acl_tensor_ptr conv_result = ggml_cann_create_tensor(conv_result_buf, dst_type, ggml_element_size(dst), - conv_result_ne, conv_result_nb, 3, ACL_FORMAT_NCL); + conv_result_ne, conv_result_nb, 3, ACL_FORMAT_NCL); GGML_CANN_CALL_ACLNN_OP(ctx, InplaceZero, conv_result.get()); - GGML_CANN_CALL_ACLNN_OP(ctx, ConstantPadNd, acl_part_dst.get(), padData.get(), pad_value.get(), conv_result.get()); + GGML_CANN_CALL_ACLNN_OP(ctx, ConstantPadNd, acl_part_dst.get(), padData.get(), pad_value.get(), + conv_result.get()); GGML_CANN_CALL_ACLNN_OP(ctx, InplaceAdd, acl_dst.get(), conv_result.get(), alpha.get()); } } @@ -3742,15 +3746,15 @@ void ggml_cann_ssm_conv(ggml_backend_cann_context & ctx, ggml_tensor * dst) { // we want a view: ne_w = { nc, 1, nr } // [K, 1, C] // so that reversed dims -> [C, 1, K] which matches // [out_channels, in_channels/groups, kernel_size] - int64_t w_ne[GGML_MAX_DIMS] = { nc, 1, nr, 1 }; // [K, 1 input ch. per group, C groups] + int64_t w_ne[GGML_MAX_DIMS] = { nc, 1, nr, 1 }; // [K, 1 input ch. per group, C groups] // Layout: src1 data is [K, C] with // offset(k, c) = k*nb0 + c*nb1 // We want offset_w(k, 0, c) = k*nb0 + c*nb1, // so we can reuse nb0 and nb1, and set nb2 = nb1. - size_t w_nb[GGML_MAX_DIMS] = { src1->nb[0], src1->nb[1], src1->nb[1], src1->nb[3] }; // same as src1 + size_t w_nb[GGML_MAX_DIMS] = { src1->nb[0], src1->nb[1], src1->nb[1], src1->nb[3] }; // same as src1 - acl_tensor_ptr acl_w = ggml_cann_create_tensor( - src1->data, ggml_cann_type_mapping(src1->type), ggml_type_size(src1->type), w_ne, w_nb, 3, ACL_FORMAT_NCL); + acl_tensor_ptr acl_w = ggml_cann_create_tensor(src1->data, ggml_cann_type_mapping(src1->type), + ggml_type_size(src1->type), w_ne, w_nb, 3, ACL_FORMAT_NCL); // 3) Output: dst is { d_inner, n_t, n_s } (CLN) // @@ -3768,11 +3772,12 @@ void ggml_cann_ssm_conv(ggml_backend_cann_context & ctx, ggml_tensor * dst) { // nb_y[0] = nr * sizeof(float); // step in L // nb_y[1] = sizeof(float); // step in C // nb_y[2] = nr * n_t * sizeof(float); // step in N - int64_t y_ne[GGML_MAX_DIMS] = { n_t, nr, n_s, 1 }; // [L_out, C, N] - size_t y_nb[GGML_MAX_DIMS] = { dst->ne[0] * sizeof(float), sizeof(float), dst->ne[0] * dst->ne[1] * sizeof(float), dst->nb[3] }; // [nr, 1, nr * n_t] + int64_t y_ne[GGML_MAX_DIMS] = { n_t, nr, n_s, 1 }; // [L_out, C, N] + size_t y_nb[GGML_MAX_DIMS] = { dst->ne[0] * sizeof(float), sizeof(float), dst->ne[0] * dst->ne[1] * sizeof(float), + dst->nb[3] }; // [nr, 1, nr * n_t] - acl_tensor_ptr acl_y = ggml_cann_create_tensor( - dst->data, ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type), y_ne, y_nb, 3, ACL_FORMAT_NCL); + acl_tensor_ptr acl_y = ggml_cann_create_tensor(dst->data, ggml_cann_type_mapping(dst->type), + ggml_type_size(dst->type), y_ne, y_nb, 3, ACL_FORMAT_NCL); // --- Conv1d parameters: depthwise, stride 1, no padding ("valid") --- int64_t strideVal[1] = { 1 }; @@ -3791,22 +3796,15 @@ void ggml_cann_ssm_conv(ggml_backend_cann_context & ctx, ggml_tensor * dst) { cubeMathType = 1; #endif - GGML_CANN_CALL_ACLNN_OP(ctx, - Convolution, + GGML_CANN_CALL_ACLNN_OP(ctx, Convolution, acl_x.get(), // input: N, C, L_in = ncs acl_w.get(), // weight: [C, 1, K] with groups=nr 
nullptr, // bias
- stride.get(),
- padding.get(),
- dilation.get(),
- transposed,
- padding.get(), // output padding (unused for non-transposed)
- groups,
- acl_y.get(),
- cubeMathType);
+ stride.get(), padding.get(), dilation.get(), transposed,
+ padding.get(), // output padding (unused for non-transposed)
+ groups, acl_y.get(), cubeMathType);
 }
-
 void ggml_cann_op_add_rms_norm_fused(ggml_backend_cann_context & ctx,
 ggml_tensor * add_node,
 ggml_tensor * rms_norm_node) {
@@ -3860,3 +3858,71 @@ void ggml_cann_op_add_rms_norm_fused(ggml_backend_cann_context & ctx,
 eps, // double type
 acl_yout.get(), acl_rstd.get(), acl_xout.get());
 }
+
+void ggml_cann_gated_linear_attn(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
+ ggml_tensor * k = dst->src[0];
+ ggml_tensor * v = dst->src[1];
+ ggml_tensor * q = dst->src[2];
+ ggml_tensor * g = dst->src[3];
+ ggml_tensor * s = dst->src[4];
+
+ int64_t B = dst->src[4]->ne[1];
+ int64_t T = dst->src[0]->ne[2];
+ int64_t H = dst->src[0]->ne[1];
+ int64_t C = dst->ne[0];
+ int64_t D = C / H;
+ int64_t L = T / B;
+
+ int64_t ne_qkg[2] = { 1, D };
+ int64_t ne_s[2] = { D, D };
+ int64_t ne_st[2] = { ne_s[1], ne_s[0] };
+ int64_t ne_vo[2] = { D, 1 };
+ int64_t ne_q[1] = { D };
+ size_t nb_base = ggml_type_size(k->type);
+ size_t nb_qkg[2] = { nb_base, nb_base };
+ size_t nb_s[2] = { nb_base, D * nb_base };
+ size_t nb_st[2] = { nb_s[1], nb_s[0] };
+ size_t nb_vo[2] = { nb_base, D * nb_base };
+ size_t nb_q[1] = { nb_base };
+
+ const float scale = ggml_get_op_params_f32(dst, 0);
+
+ acl_tensor_ptr acl_s = ggml_cann_create_tensor(s, s->ne, s->nb, 2, ACL_FORMAT_ND);
+ acl_tensor_ptr new_state = ggml_cann_create_tensor(dst, s->ne, s->nb, 2, ACL_FORMAT_ND, (B * L * H * D) * nb_base);
+ cann_copy(ctx, acl_s.get(), new_state.get());
+
+ for (int64_t b = 0; b < B; b++) {
+ for (int64_t h = 0; h < H; h++) {
+ size_t s_offset = (b * (H * D * D) + h * (D * D)) * nb_base;
+ // D * D
+ acl_tensor_ptr acl_s_new =
+ ggml_cann_create_tensor(dst, ne_s, nb_s, 2, ACL_FORMAT_ND, (B * L * H * D) * nb_base + s_offset);
+ acl_tensor_ptr acl_s_new_t =
+ ggml_cann_create_tensor(dst, ne_st, nb_st, 2, ACL_FORMAT_ND, (B * L * H * D) * nb_base + s_offset);
+ for (int64_t l = 0; l < L; l++) {
+ size_t qkvgo_offset = (b * (L * H * D) + l * (H * D) + h * (D)) * nb_base;
+ // D * 1
+ acl_tensor_ptr acl_k = ggml_cann_create_tensor(k, ne_qkg, nb_qkg, 2, ACL_FORMAT_ND, qkvgo_offset);
+ acl_tensor_ptr acl_g = ggml_cann_create_tensor(g, ne_qkg, nb_qkg, 2, ACL_FORMAT_ND, qkvgo_offset);
+ // D
+ acl_tensor_ptr acl_q = ggml_cann_create_tensor(q, ne_q, nb_q, 1, ACL_FORMAT_ND, qkvgo_offset);
+ // 1 * D
+ acl_tensor_ptr acl_v = ggml_cann_create_tensor(v, ne_vo, nb_vo, 2, ACL_FORMAT_ND, qkvgo_offset);
+ // D
+ acl_tensor_ptr acl_o = ggml_cann_create_tensor(dst, ne_q, nb_q, 1, ACL_FORMAT_ND, qkvgo_offset);
+ // k ⊗ v
+ size_t buf_size = D * D * nb_base;
+ ggml_cann_pool_alloc buffer_allocator(ctx.pool(), buf_size);
+ acl_tensor_ptr tmp_tensor = ggml_cann_create_tensor(
+ buffer_allocator.get(), ggml_cann_type_mapping(k->type), nb_base, ne_s, nb_s, 2);
+ aclnn_mul(ctx, acl_k.get(), acl_v.get(), tmp_tensor.get());
+ // s_new = g ⊗ s_old + k ⊗ v
+ aclnn_mul(ctx, acl_s_new.get(), acl_g.get(), nullptr);
+ aclnn_add(ctx, acl_s_new.get(), tmp_tensor.get(), nullptr);
+ // compute output
+ GGML_CANN_CALL_ACLNN_OP(ctx, Mv, acl_s_new_t.get(), acl_q.get(), acl_o.get(), 1);
+ aclnn_muls(ctx, acl_o.get(), scale, nullptr, true);
+ }
+ }
+ }
+}
diff --git a/ggml/src/ggml-cann/aclnn_ops.h
b/ggml/src/ggml-cann/aclnn_ops.h
index 08ee7b1fbd..b76e4707ac 100644
--- a/ggml/src/ggml-cann/aclnn_ops.h
+++ b/ggml/src/ggml-cann/aclnn_ops.h
@@ -814,67 +814,20 @@ void ggml_cann_step(ggml_backend_cann_context & ctx, ggml_tensor * dst);
 */
 void ggml_cann_flash_attn_ext(ggml_backend_cann_context & ctx, ggml_tensor * dst);
-/*
- * @brief A generic wrapper for ACL resources with custom deleter support.
- */
-using any_acl_resource = std::unique_ptr<void, std::function<void(void *)>>;
-
 /**
- * @brief Trait structure used to define how to destroy a given ACL resource type.
+ * @brief Forward Gated Linear Attention on the CANN backend.
 *
- * @tparam T ACL resource type.
- */
-template <typename T> struct acl_resource_traits;
-
-/**
- * @brief Specialization for aclTensor, defines how to destroy an aclTensor resource.
- */
-template <> struct acl_resource_traits<aclTensor> {
- static void destroy(void * p) { ACL_CHECK(aclDestroyTensor(static_cast<aclTensor *>(p))); }
-};
-
-/**
- * @brief Specialization for aclIntArray, defines how to destroy an aclIntArray resource.
- */
-template <> struct acl_resource_traits<aclIntArray> {
- static void destroy(void * p) { ACL_CHECK(aclDestroyIntArray(static_cast<aclIntArray *>(p))); }
-};
-
-/**
- * @brief Specialization for aclScalar, defines how to destroy an aclScalar resource.
- */
-template <> struct acl_resource_traits<aclScalar> {
- static void destroy(void * p) { ACL_CHECK(aclDestroyScalar(static_cast<aclScalar *>(p))); }
-};
-
-/**
- * @brief Specialization for aclTensorList, defines how to destroy an aclTensorList resource.
- */
-template <> struct acl_resource_traits<aclTensorList> {
- static void destroy(void * p) { ACL_CHECK(aclDestroyTensorList(static_cast<aclTensorList *>(p))); }
-};
-
-/**
- * @brief Creates a generic ACL resource wrapper with proper destruction logic.
+ * Expects dst->src[0..4] = {k, v, q, g, s} with shape conventions:
+ * k, v, q, g: [D] with outer dims T x H batched as ne[2]=T, ne[1]=H
+ * s: initial state [B, H, D, D], where B is batch and D=C/H
+ * dst holds both outputs (o) and updated state; a scale factor is read from op params.
 *
- * @tparam T ACL resource type.
- * @param ptr Raw pointer to ACL resource.
- * @return any_acl_resource Smart pointer that handles destruction.
- */
-template <typename T> any_acl_resource make_acl_resource(T * ptr) {
- return any_acl_resource(static_cast<void *>(ptr), [](void * p) { acl_resource_traits<T>::destroy(p); });
-}
-
-/**
- * @brief Registers multiple ACL resources into a vector for lifetime management.
+ * The kernel updates per time step l: S_new = g ⊗ S_old + k ⊗ v, then computes o = (S_new^T q) * scale.
 *
- * @tparam Args Variadic list of ACL resource types.
- * @param vec Target vector to hold ACL resources.
- * @param args Raw pointers to ACL resources.
+ * @param ctx Backend context providing stream/allocator utilities.
+ * @param dst Output tensor; src deps are k, v, q, g, s as above.
 */
-template <typename... Args> void register_acl_resources(std::vector<any_acl_resource> & vec, Args *... args) {
- (vec.emplace_back(make_acl_resource(args)), ...);
-}
+void ggml_cann_gated_linear_attn(ggml_backend_cann_context & ctx, ggml_tensor * dst);

 /**
 * @brief Launches an asynchronous task using the memory allocator.
@@ -894,19 +847,19 @@ template <typename... Args> void register_acl_resources(std::vector<any_acl_resource> & vec, Args *...
-#define GGML_CANN_CALL_ACLNN_OP(CTX, OP_NAME, ...) \
- do { \
- uint64_t workspaceSize = 0; \
- aclOpExecutor * executor; \
- void * workspaceAddr = nullptr; \
- ACL_CHECK(aclnn##OP_NAME##GetWorkspaceSize(__VA_ARGS__, &workspaceSize, &executor)); \
- /* workspace should alloced in main thread to keep malloc order when using vmm. */ \
- if (workspaceSize > 0) { \
- ggml_cann_pool_alloc workspace_allocator(CTX.pool(), workspaceSize); \
- workspaceAddr = workspace_allocator.get(); \
- } \
- ACL_CHECK(aclnn##OP_NAME(workspaceAddr, workspaceSize, executor, CTX.stream())); \
- } while (0)
+# define GGML_CANN_CALL_ACLNN_OP(CTX, OP_NAME, ...)
\
+ do { \
+ uint64_t workspaceSize = 0; \
+ aclOpExecutor * executor; \
+ void * workspaceAddr = nullptr; \
+ ACL_CHECK(aclnn##OP_NAME##GetWorkspaceSize(__VA_ARGS__, &workspaceSize, &executor)); \
+ /* workspace should be allocated in the main thread to keep malloc order when using vmm. */ \
+ if (workspaceSize > 0) { \
+ ggml_cann_pool_alloc workspace_allocator(CTX.pool(), workspaceSize); \
+ workspaceAddr = workspace_allocator.get(); \
+ } \
+ ACL_CHECK(aclnn##OP_NAME(workspaceAddr, workspaceSize, executor, CTX.stream())); \
+ } while (0)

 /**
 * @brief Performs sparse expert-based matrix multiplication using the CANN backend.
@@ -947,7 +900,9 @@ void ggml_cann_mul_mat_id(ggml_backend_cann_context & ctx, ggml_tensor * dst);
 * @param rms_norm_tensor The RMS_NORM operation node, contains the gamma weights
 * and epsilon parameter.
 */
-void ggml_cann_op_add_rms_norm_fused(ggml_backend_cann_context & ctx, ggml_tensor * add_node, ggml_tensor * rms_norm_node);
+void ggml_cann_op_add_rms_norm_fused(ggml_backend_cann_context & ctx,
+ ggml_tensor * add_node,
+ ggml_tensor * rms_norm_node);

 /**
 * @brief Check whether a tensor is a weight tensor for matrix multiplication.
@@ -1104,13 +1059,13 @@ void ggml_cann_op_unary_gated(std::function
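// ----------------------------------------------------------------------------
// Reviewer note, not part of the patch: a minimal scalar sketch of the
// recurrence that ggml_cann_gated_linear_attn implements, useful when checking
// the tensor views in the diff above. The helper name gla_reference and the
// flat row-major layout are assumptions made here for illustration only; the
// CANN kernel expresses the same math through aclnn_mul / aclnn_add / Mv on
// strided views of k, v, q, g and the state.
//
// Per batch b, head h and time step l:
//   S[i][j] = g[i] * S[i][j] + k[i] * v[j]      // gated outer-product update
//   o[j]    = scale * sum_i S[i][j] * q[i]      // o = (S^T q) * scale
#include <cstdint>

static void gla_reference(int64_t B, int64_t L, int64_t H, int64_t D, float scale,
                          const float * k, const float * v, const float * q, const float * g,
                          float * S,   // state [B, H, D, D], updated in place
                          float * o) { // output [B, L, H, D]
    for (int64_t b = 0; b < B; b++) {
        for (int64_t h = 0; h < H; h++) {
            float * s = S + (b * H + h) * D * D;                // per-(b, h) state slice
            for (int64_t l = 0; l < L; l++) {
                const int64_t off = ((b * L + l) * H + h) * D;  // same layout as qkvgo_offset
                // S = diag(g) * S + k ⊗ v: row i decays by g[i], then the
                // outer product of k and v is accumulated.
                for (int64_t i = 0; i < D; i++) {
                    for (int64_t j = 0; j < D; j++) {
                        s[i * D + j] = g[off + i] * s[i * D + j] + k[off + i] * v[off + j];
                    }
                }
                // o = (S^T q) * scale, matching the Mv call on the transposed view.
                for (int64_t j = 0; j < D; j++) {
                    float acc = 0.0f;
                    for (int64_t i = 0; i < D; i++) {
                        acc += s[i * D + j] * q[off + i];
                    }
                    o[off + j] = acc * scale;
                }
            }
        }
    }
}
// In the kernel the updated state is written into dst right after the per-token
// outputs (offset (B * L * H * D) * nb_base); the sketch keeps it in a separate
// buffer S for clarity.
// ----------------------------------------------------------------------------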