CANN: supports out_prod operator for F32 and F16 (#17406)

Co-authored-by: tianhao <tianhao42@huawei.com>
Author: TianHao324
Date: 2025-11-25 17:39:06 +08:00
Commit: 064c90d843 (parent b1846f1c8e)
3 changed files with 95 additions and 0 deletions

@@ -42,6 +42,7 @@
#include <aclnnop/aclnn_exp.h>
#include <aclnnop/aclnn_fill_scalar.h>
#include <aclnnop/aclnn_fused_infer_attention_score_v2.h>
#include <aclnnop/aclnn_ger.h>
#include <aclnnop/aclnn_group_norm.h>
#include <aclnnop/aclnn_grouped_matmul_v3.h>
#include <aclnnop/aclnn_gt_scalar.h>
@@ -3236,3 +3237,64 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context & ctx, ggml_tensor * dst
            GGML_ABORT("Function is not implemented.");
    }
}

static void ggml_cann_out_prod_fp(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
    ggml_tensor * src0 = dst->src[0];  // weight
    ggml_tensor * src1 = dst->src[1];  // input

    GGML_TENSOR_BINARY_OP_LOCALS

    acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst);
    GGML_CANN_CALL_ACLNN_OP(ctx, InplaceZero, acl_dst.get());

    // Broadcast ratios of dst over src0 in dims 2 and 3.
    const int64_t dps2 = ne2 / ne02;
    const int64_t dps3 = ne3 / ne03;

    for (int64_t i3 = 0; i3 < ne3; i3++) {
        for (int64_t i2 = 0; i2 < ne2; i2++) {
            // src0 indices honour broadcasting; src1 indices match dst.
            const int64_t i02 = i2 / dps2;
            const int64_t i03 = i3 / dps3;
            const int64_t i12 = i2;
            const int64_t i13 = i3;

            // 2D view of the destination plane that collects the result.
            acl_tensor_ptr accumulator =
                ggml_cann_create_tensor((char *) dst->data + i2 * nb2 + i3 * nb3, ggml_cann_type_mapping(dst->type),
                                        ggml_type_size(dst->type), dst->ne, dst->nb, 2);

            // The outer product needs to be accumulated in this dimension.
            for (int64_t i1 = 0; i1 < ne11; i1++) {
                // 1D column views; each view is described with its own tensor's dtype and strides.
                acl_tensor_ptr acl_input = ggml_cann_create_tensor(
                    (char *) src1->data + i1 * nb11 + i12 * nb12 + i13 * nb13, ggml_cann_type_mapping(src1->type),
                    ggml_type_size(src1->type), src1->ne, src1->nb, 1);
                acl_tensor_ptr acl_weight = ggml_cann_create_tensor(
                    (char *) src0->data + i1 * nb01 + i02 * nb02 + i03 * nb03, ggml_cann_type_mapping(src0->type),
                    ggml_type_size(src0->type), src0->ne, src0->nb, 1);

                // Rank-1 update: acl_out = outer(acl_input, acl_weight), then accumulator += acl_out.
                ggml_cann_pool_alloc output_allocator(ctx.pool());
                void *         output_buffer = output_allocator.alloc(ggml_nbytes(dst));
                acl_tensor_ptr acl_out = ggml_cann_create_tensor(output_buffer, ggml_cann_type_mapping(dst->type),
                                                                 ggml_type_size(dst->type), dst->ne, dst->nb, 2);
                GGML_CANN_CALL_ACLNN_OP(ctx, Ger, acl_input.get(), acl_weight.get(), acl_out.get());

                float       alpha_value = 1.0f;
                aclScalar * alpha       = aclCreateScalar(&alpha_value, ACL_FLOAT);
                GGML_CANN_CALL_ACLNN_OP(ctx, InplaceAdd, accumulator.get(), acl_out.get(), alpha);
                ggml_cann_release_resources(ctx, alpha);
            }
        }
    }
}

void ggml_cann_out_prod(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
    ggml_tensor * src0 = dst->src[0];
    const enum ggml_type type = src0->type;

    switch (type) {
        case GGML_TYPE_F32:
        case GGML_TYPE_F16:
            ggml_cann_out_prod_fp(ctx, dst);
            break;
        default:
            GGML_ABORT("Unsupported type for GGML_OP_OUT_PROD");
            break;
    }
}
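
For reference, the kernel above is the classic rank-1-update decomposition of a matrix product: dst = sum over m of outer(src1[:,m], src0[:,m]). A minimal host-side C++ sketch of the same 2D math (plain CPU code, no CANN; the function name out_prod_ref is illustrative only):

#include <cstdint>

// dst[i,j] = sum_m src0[i,m] * src1[j,m], where ne0 = rows of src0,
// ne1 = rows of src1, and ne11 = the contracted column count.
static void out_prod_ref(const float * src0, const float * src1, float * dst,
                         int64_t ne0, int64_t ne1, int64_t ne11) {
    // Mirrors the InplaceZero call: start from a zeroed destination.
    for (int64_t k = 0; k < ne0 * ne1; k++) {
        dst[k] = 0.0f;
    }
    // Mirrors the Ger + InplaceAdd loop: one rank-1 update per column m.
    for (int64_t m = 0; m < ne11; m++) {
        for (int64_t j = 0; j < ne1; j++) {
            for (int64_t i = 0; i < ne0; i++) {
                dst[i + j * ne0] += src0[i + m * ne0] * src1[j + m * ne1];
            }
        }
    }
}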

@@ -1125,3 +1125,23 @@ void ggml_cann_op_unary_gated(std::function<void(ggml_backend_cann_context &, ac
    } while (0)
#endif // CANN_ACLNN_OPS

/**
 * @brief Performs the outer product operation on two ggml tensors using the CANN backend.
 *
 * @details This function computes the outer product of the two input tensors (src0 and src1)
 * and stores the result in the destination tensor. The outer product operation is defined as:
 * dst[i,j,k,l] = sum_m (src0[i,m,k,l] * src1[j,m,k,l])
 *
 * The function supports the F32 and F16 floating-point types. The result is built as a
 * sum of rank-1 updates: for each shared column index, the two column vectors are
 * combined with the Ger (outer product) operator and accumulated into the destination.
 *
 * The implementation handles 4D tensor broadcasting and batch processing automatically.
 *
 * @param ctx The CANN backend context for operation execution and memory management.
 * @param dst The destination ggml_tensor where the outer product result will be stored.
 *            The input tensors are assumed to be `dst->src[0]` and `dst->src[1]`.
 *
 * @see GGML_CANN_CALL_ACLNN_OP for CANN operator invocation
 */
void ggml_cann_out_prod(ggml_backend_cann_context & ctx, ggml_tensor * dst);
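
To make the formula concrete, here is a minimal sketch of driving the operator through the public ggml graph API (standard ggml C API; the sizes 4, 6 and 8 are arbitrary illustration values):

#include "ggml.h"

int main() {
    ggml_init_params params = {
        /*.mem_size   =*/ 16 * 1024 * 1024,
        /*.mem_buffer =*/ nullptr,
        /*.no_alloc   =*/ false,
    };
    ggml_context * ctx = ggml_init(params);

    // src0: 4 rows x 8 columns, src1: 6 rows x 8 columns (shared dim m = 8).
    ggml_tensor * src0 = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 4, 8);
    ggml_tensor * src1 = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 6, 8);

    // dst[i,j] = sum_m src0[i,m] * src1[j,m]  =>  dst has shape 4 x 6.
    ggml_tensor * dst = ggml_out_prod(ctx, src0, src1);
    GGML_ASSERT(dst->ne[0] == 4 && dst->ne[1] == 6);

    ggml_free(ctx);
    return 0;
}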

@@ -1886,6 +1886,9 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context & ctx, struct gg
        case GGML_OP_FLASH_ATTN_EXT:
            ggml_cann_flash_attn_ext(ctx, dst);
            break;
        case GGML_OP_OUT_PROD:
            ggml_cann_out_prod(ctx, dst);
            break;
        default:
            return false;
    }
@@ -2563,6 +2566,16 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten
        case GGML_OP_PAD_REFLECT_1D:
        case GGML_OP_COUNT_EQUAL:
            return true;
        case GGML_OP_OUT_PROD:
            {
                switch (op->src[0]->type) {
                    case GGML_TYPE_F16:
                    case GGML_TYPE_F32:
                        return true;
                    default:
                        return false;
                }
            }
        case GGML_OP_CONV_TRANSPOSE_1D:
            // TODO: ((weightL - 1) * dilationW - padLeft)=1336 should not be larger than 255.
            return (op->src[0]->ne[0] - 1) <= 255;
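
The supports_op gate above is what lets the scheduler route unsupported cases to another backend (typically CPU). A hypothetical standalone predicate mirroring it, e.g. for unit tests (the helper name cann_out_prod_supported is illustrative, not upstream API):

#include "ggml.h"

static bool cann_out_prod_supported(const ggml_tensor * op) {
    switch (op->src[0]->type) {
        case GGML_TYPE_F16:
        case GGML_TYPE_F32:
            return true;   // handled by ggml_cann_out_prod
        default:
            return false;  // e.g. quantized src0 falls back to another backend
    }
}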