CANN: add GGML_OP_SOLVE_TRI support

Implement triangular linear system solve (AX=B) using
aclnnTriangularSolve for the lower-triangular, non-unit case.
This commit is contained in:
hipudding 2026-03-28 05:50:57 +00:00
parent 871ffea262
commit 168d05f3d5
3 changed files with 35 additions and 0 deletions

View File

@ -85,6 +85,7 @@ aclnnStatus aclnnInplaceFillDiagonal(
#include <aclnnop/aclnn_sum.h>
#include <aclnnop/aclnn_threshold.h>
#include <aclnnop/aclnn_tril.h>
#include <aclnnop/aclnn_triangular_solve.h>
#include <aclnnop/aclnn_triu.h>
#include <aclnnop/aclnn_logical_not.h>
#include <aclnnop/aclnn_masked_fill_scalar.h>
@ -694,6 +695,27 @@ void ggml_cann_cumsum(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
ggml_cann_type_mapping(dst->type), acl_dst.get());
}
void ggml_cann_solve_tri(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
ggml_tensor * src0 = dst->src[0]; // A: [N, N, B2, B3] lower triangular
ggml_tensor * src1 = dst->src[1]; // B: [K, N, B2, B3]
acl_tensor_ptr acl_a = ggml_cann_create_tensor(src0);
acl_tensor_ptr acl_b = ggml_cann_create_tensor(src1);
acl_tensor_ptr acl_x = ggml_cann_create_tensor(dst);
// mOut: triangular copy of A (required output), same shape as A.
const size_t a_bytes = ggml_nbytes(src0);
ggml_cann_pool_alloc m_alloc(ctx.pool(), a_bytes);
acl_tensor_ptr acl_m = ggml_cann_create_tensor(
m_alloc.get(), ggml_cann_type_mapping(src0->type),
ggml_type_size(src0->type), src0->ne, src0->nb, GGML_MAX_DIMS);
// Solve AX = B: upper=false (lower tri), transpose=false, unitriangular=false.
GGML_CANN_CALL_ACLNN_OP(ctx, TriangularSolve,
acl_b.get(), acl_a.get(), false, false, false,
acl_x.get(), acl_m.get());
}
void ggml_cann_diag(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
ggml_tensor * src = dst->src[0];

View File

@ -346,6 +346,14 @@ void ggml_cann_cumsum(ggml_backend_cann_context & ctx, ggml_tensor * dst);
*/
void ggml_cann_tri(ggml_backend_cann_context & ctx, ggml_tensor * dst);
/**
* @brief Solves a triangular linear system AX=B using the CANN backend.
*
* @param ctx The CANN context used for operations.
* @param dst The destination tensor. dst->op is `GGML_OP_SOLVE_TRI`.
*/
void ggml_cann_solve_tri(ggml_backend_cann_context & ctx, ggml_tensor * dst);
/**
* @brief Creates a diagonal matrix from a vector using the CANN backend.
*

View File

@ -1920,6 +1920,9 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context & ctx, struct gg
case GGML_OP_DIAG:
ggml_cann_diag(ctx, dst);
break;
case GGML_OP_SOLVE_TRI:
ggml_cann_solve_tri(ctx, dst);
break;
default:
return false;
}
@ -2611,6 +2614,8 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten
return op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_DIAG:
return op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_SOLVE_TRI:
return op->src[0]->type == GGML_TYPE_F32;
default:
return false;
}