CANN: add CUMSUM and TRI op support, fix graph cache op_params matching

- Implement GGML_OP_CUMSUM using aclnnCumsum
- Implement GGML_OP_TRI with all 4 tri types (LOWER, LOWER_DIAG, UPPER, UPPER_DIAG)
  using Tril/MaskedFillScalar approach to work around CANN sparse-zero bugs
- Fix graph cache to always compare op_params for all ops, not just a whitelist
This commit is contained in:
hipudding 2026-03-28 05:29:44 +00:00
parent 11e78d8499
commit 93e0c17661
4 changed files with 146 additions and 4 deletions

View File

@ -25,6 +25,17 @@
#include "ggml-impl.h"
#include "ggml.h"
// Forward-declare InplaceFillDiagonal because aclnn_fill_diagonal.h has a
// broken include guard (OP_API_INC_ADD_H_) that conflicts with aclnn_add.h.
extern "C" {
aclnnStatus aclnnInplaceFillDiagonalGetWorkspaceSize(
aclTensor * selfRef, const aclScalar * fillValue, bool wrap,
uint64_t * workspaceSize, aclOpExecutor ** executor);
aclnnStatus aclnnInplaceFillDiagonal(
void * workspace, uint64_t workspaceSize, aclOpExecutor * executor,
aclrtStream stream);
}
#include <aclnnop/aclnn_add.h>
#include <aclnnop/aclnn_add_rms_norm.h>
#include <aclnnop/aclnn_addcdiv.h>
@ -75,6 +86,8 @@
#include <aclnnop/aclnn_threshold.h>
#include <aclnnop/aclnn_tril.h>
#include <aclnnop/aclnn_triu.h>
#include <aclnnop/aclnn_logical_not.h>
#include <aclnnop/aclnn_masked_fill_scalar.h>
#include <aclnnop/aclnn_upsample_nearest_2d.h>
#include <aclnnop/aclnn_weight_quant_batch_matmul_v2.h>
#include <aclnnop/aclnn_zero.h>
@ -670,6 +683,107 @@ void ggml_cann_sum(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
aclnn_reduce_sum(ctx, dst, reduce_dims, 4);
}
void ggml_cann_cumsum(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
    ggml_tensor * src0 = dst->src[0];

    acl_tensor_ptr input  = ggml_cann_create_tensor(src0);
    acl_tensor_ptr output = ggml_cann_create_tensor(dst);

    // GGML's cumulative sum runs along its dim 0 (the innermost dimension,
    // ne[0]). ggml_cann_create_tensor lays dimensions out in reversed order
    // [ne3, ne2, ne1, ne0], so GGML dim 0 is CANN dim 3 of the 4-D tensor.
    const int64_t cann_dim = 3;
    GGML_CANN_CALL_ACLNN_OP(ctx, Cumsum, input.get(), cann_dim,
                            ggml_cann_type_mapping(dst->type), output.get());
}
void ggml_cann_tri(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
    ggml_tensor * src = dst->src[0];

    const int64_t dim     = src->ne[0];                // side length of the square matrices
    const int64_t batches = src->ne[2] * src->ne[3];   // number of stacked matrices

    const size_t f32_sz  = sizeof(float);
    const size_t bool_sz = sizeof(uint8_t);

    const size_t f32_bytes  = batches * dim * dim * f32_sz;
    const size_t bool_bytes = batches * dim * dim * bool_sz;

    int64_t shape[3]       = { dim, dim, batches };
    size_t  stride_f32[3]  = { f32_sz, dim * f32_sz, dim * dim * f32_sz };
    size_t  stride_bool[3] = { bool_sz, dim * bool_sz, dim * dim * bool_sz };

    const ggml_tri_type ttype = (ggml_tri_type) ggml_get_op_params_i32(dst, 0);

    acl_tensor_ptr acl_src = ggml_cann_create_tensor(src->data, ACL_FLOAT, f32_sz, shape, stride_f32, 3);
    acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst->data, ACL_FLOAT, f32_sz, shape, stride_f32, 3);

    // LOWER is the easy case: Tril(-1) yields the strict lower triangle
    // directly (because of the CANN dimension reversal, Tril(-1) matches
    // GGML's col < row condition).
    if (ttype == GGML_TRI_TYPE_LOWER) {
        GGML_CANN_CALL_ACLNN_OP(ctx, Tril, acl_src.get(), (int64_t) -1, acl_dst.get());
        return;
    }

    // Remaining types: copy src into dst, build a BOOL mask marking the
    // positions that must become zero, then apply MaskedFillScalar.
    GGML_CANN_CALL_ACLNN_OP(ctx, InplaceCopy, acl_dst.get(), acl_src.get());

    // Step 1: a float tensor of all ones...
    ggml_cann_pool_alloc ones_alloc(ctx.pool(), f32_bytes);
    void * ones_buf = ones_alloc.get();
    acl_tensor_ptr acl_ones = ggml_cann_create_tensor(ones_buf, ACL_FLOAT, f32_sz, shape, stride_f32, 3);
    {
        float one_val = 1.0f;
        acl_scalar_ptr acl_one = ggml_cann_create_scalar(&one_val, ACL_FLOAT);
        GGML_CANN_CALL_ACLNN_OP(ctx, InplaceFillScalar, acl_ones.get(), acl_one.get());
    }

    // ...turned into a strict-lower float mask (1 below the diagonal,
    // 0 everywhere else).
    ggml_cann_pool_alloc mask_f_alloc(ctx.pool(), f32_bytes);
    void * mask_f_buf = mask_f_alloc.get();
    acl_tensor_ptr acl_mask_f = ggml_cann_create_tensor(mask_f_buf, ACL_FLOAT, f32_sz, shape, stride_f32, 3);
    GGML_CANN_CALL_ACLNN_OP(ctx, Tril, acl_ones.get(), (int64_t) -1, acl_mask_f.get());

    // Step 2: for LOWER_DIAG and UPPER the mask must also cover the diagonal.
    // Tril(0) is buggy on CANN (behaves like Tril(-1)), so instead copy ones
    // onto the diagonal through a strided view with step (dim + 1).
    if (ttype == GGML_TRI_TYPE_LOWER_DIAG || ttype == GGML_TRI_TYPE_UPPER) {
        int64_t diag_shape[2]  = { dim, batches };
        size_t  diag_stride[2] = { (dim + 1) * f32_sz, dim * dim * f32_sz };
        acl_tensor_ptr acl_ones_diag = ggml_cann_create_tensor(ones_buf, ACL_FLOAT, f32_sz, diag_shape, diag_stride, 2);
        acl_tensor_ptr acl_mask_diag = ggml_cann_create_tensor(mask_f_buf, ACL_FLOAT, f32_sz, diag_shape, diag_stride, 2);
        GGML_CANN_CALL_ACLNN_OP(ctx, InplaceCopy, acl_mask_diag.get(), acl_ones_diag.get());
    }

    // Step 3: cast the float mask to BOOL for MaskedFillScalar.
    ggml_cann_pool_alloc mask_b_alloc(ctx.pool(), bool_bytes);
    void * mask_b_buf = mask_b_alloc.get();
    acl_tensor_ptr acl_mask_b = ggml_cann_create_tensor(mask_b_buf, ACL_BOOL, bool_sz, shape, stride_bool, 3);
    GGML_CANN_CALL_ACLNN_OP(ctx, Cast, acl_mask_f.get(), ACL_BOOL, acl_mask_b.get());

    // Step 4: pick the BOOL mask whose True positions get zeroed.
    //   LOWER_DIAG: invert the lower+diag mask -> strict-upper positions.
    //   UPPER:      lower+diag mask as-is (diagonal was filled in step 2).
    //   UPPER_DIAG: strict-lower mask as-is.
    ggml_cann_pool_alloc mask_inv_alloc(ctx.pool(), bool_bytes);
    void * mask_inv_buf = mask_inv_alloc.get();
    acl_tensor_ptr acl_mask_inv = ggml_cann_create_tensor(mask_inv_buf, ACL_BOOL, bool_sz, shape, stride_bool, 3);

    aclTensor * fill_mask = nullptr;
    switch (ttype) {
        case GGML_TRI_TYPE_LOWER_DIAG:
            GGML_CANN_CALL_ACLNN_OP(ctx, LogicalNot, acl_mask_b.get(), acl_mask_inv.get());
            fill_mask = acl_mask_inv.get();
            break;
        case GGML_TRI_TYPE_UPPER_DIAG:
            /* fallthrough */
        case GGML_TRI_TYPE_UPPER:
            fill_mask = acl_mask_b.get();
            break;
        default:
            GGML_ABORT("unsupported tri type");
    }

    float zero_val = 0.0f;
    acl_scalar_ptr acl_zero = ggml_cann_create_scalar(&zero_val, ACL_FLOAT);
    GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMaskedFillScalar, acl_dst.get(), fill_mask, acl_zero.get());
}
void ggml_cann_upsample_nearest2d(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
ggml_tensor * src = dst->src[0];
acl_tensor_ptr acl_src = ggml_cann_create_tensor(src, nullptr, nullptr, 0, ACL_FORMAT_NCHW);

View File

@ -32,6 +32,9 @@
#include <aclnnop/aclnn_cat.h>
#include <aclnnop/aclnn_clamp.h>
#include <aclnnop/aclnn_cos.h>
#include <aclnnop/aclnn_cumsum.h>
#include <aclnnop/aclnn_tril.h>
#include <aclnnop/aclnn_triu.h>
#include <aclnnop/aclnn_exp.h>
#include <aclnnop/aclnn_gelu.h>
#include <aclnnop/aclnn_gelu_v2.h>
@ -325,6 +328,24 @@ void ggml_cann_sum_rows(ggml_backend_cann_context & ctx, ggml_tensor * dst);
void ggml_cann_sum(ggml_backend_cann_context & ctx, ggml_tensor * dst);
/**
* @brief Computes the cumulative sum of a ggml tensor along dim 0 using the
* CANN backend.
*
* @param ctx The CANN context used for operations.
* @param dst The destination tensor. dst->op is `GGML_OP_CUMSUM`.
*/
void ggml_cann_cumsum(ggml_backend_cann_context & ctx, ggml_tensor * dst);
/**
 * @brief Applies a triangular (tril/triu) selection to a square ggml tensor
 * using the CANN backend: keeps the selected triangle (lower/upper, with or
 * without the diagonal) and zeroes the complementary elements.
 *
 * @param ctx The CANN context used for operations.
 * @param dst The destination tensor. dst->op is `GGML_OP_TRI`.
 */
void ggml_cann_tri(ggml_backend_cann_context & ctx, ggml_tensor * dst);
/**
* @brief Upsamples a ggml tensor using nearest neighbor interpolation using
* the CANN backend.

View File

@ -277,10 +277,7 @@ struct ggml_graph_node_properties {
}
}
if (node->op == GGML_OP_SCALE || node->op == GGML_OP_UNARY || node->op == GGML_OP_GLU || node->op == GGML_OP_ROPE){
return memcmp(this->op_params, node->op_params, GGML_MAX_OP_PARAMS) == 0;
}
return true;
return memcmp(this->op_params, node->op_params, GGML_MAX_OP_PARAMS) == 0;
}
};

View File

@ -1908,6 +1908,12 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context & ctx, struct gg
case GGML_OP_GATED_DELTA_NET:
ggml_cann_gated_delta_net(ctx, dst);
break;
case GGML_OP_CUMSUM:
ggml_cann_cumsum(ctx, dst);
break;
case GGML_OP_TRI:
ggml_cann_tri(ctx, dst);
break;
default:
return false;
}
@ -2591,6 +2597,10 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten
&& ggml_is_contiguous(beta)
&& q->type == GGML_TYPE_F32;
}
case GGML_OP_CUMSUM:
return op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_TRI:
return op->src[0]->type == GGML_TYPE_F32;
default:
return false;
}