diff --git a/ggml/src/ggml-cann/aclnn_ops.cpp b/ggml/src/ggml-cann/aclnn_ops.cpp
index 24379b9003..0c8fdb68e2 100644
--- a/ggml/src/ggml-cann/aclnn_ops.cpp
+++ b/ggml/src/ggml-cann/aclnn_ops.cpp
@@ -85,6 +85,7 @@ aclnnStatus aclnnInplaceFillDiagonal(
 #include
 #include
 #include
+#include
 #include
 #include
 #include
@@ -694,6 +695,27 @@ void ggml_cann_cumsum(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
                             ggml_cann_type_mapping(dst->type), acl_dst.get());
 }
 
+void ggml_cann_solve_tri(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
+    ggml_tensor * src0 = dst->src[0]; // A: [N, N, B2, B3] lower-triangular coefficient matrix
+    ggml_tensor * src1 = dst->src[1]; // B: [K, N, B2, B3] right-hand side; the solution X is written to dst
+
+    acl_tensor_ptr acl_a = ggml_cann_create_tensor(src0);
+    acl_tensor_ptr acl_b = ggml_cann_create_tensor(src1);
+    acl_tensor_ptr acl_x = ggml_cann_create_tensor(dst);
+
+    // mOut: triangular copy of A (required output, not read afterwards); pool-backed scratch with A's type, shape and strides.
+    const size_t a_bytes = ggml_nbytes(src0);
+    ggml_cann_pool_alloc m_alloc(ctx.pool(), a_bytes);
+    acl_tensor_ptr acl_m = ggml_cann_create_tensor(
+        m_alloc.get(), ggml_cann_type_mapping(src0->type),
+        ggml_type_size(src0->type), src0->ne, src0->nb, GGML_MAX_DIMS);
+
+    // Solve AX = B: upper=false (lower tri), transpose=false, unitriangular=false. Arguments are passed as (B, A, flags, X, mOut).
+    GGML_CANN_CALL_ACLNN_OP(ctx, TriangularSolve,
+        acl_b.get(), acl_a.get(), false, false, false,
+        acl_x.get(), acl_m.get());
+}
+
 void ggml_cann_diag(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
     ggml_tensor * src = dst->src[0];
 
diff --git a/ggml/src/ggml-cann/aclnn_ops.h b/ggml/src/ggml-cann/aclnn_ops.h
index 1fb76ea532..2fe0874f24 100644
--- a/ggml/src/ggml-cann/aclnn_ops.h
+++ b/ggml/src/ggml-cann/aclnn_ops.h
@@ -346,6 +346,14 @@ void ggml_cann_cumsum(ggml_backend_cann_context & ctx, ggml_tensor * dst);
  */
 void ggml_cann_tri(ggml_backend_cann_context & ctx, ggml_tensor * dst);
 
+/**
+ * @brief Solves a triangular linear system AX=B for X using the CANN backend (A is treated as lower triangular).
+ *
+ * @param ctx The CANN context used for operations.
+ * @param dst The destination tensor that receives X. dst->op is `GGML_OP_SOLVE_TRI`; dst->src[0] is A and dst->src[1] is B.
+ */
+void ggml_cann_solve_tri(ggml_backend_cann_context & ctx, ggml_tensor * dst);
+
 /**
  * @brief Creates a diagonal matrix from a vector using the CANN backend.
  *
diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp
index 8fc6b4bbb1..7ef4089147 100644
--- a/ggml/src/ggml-cann/ggml-cann.cpp
+++ b/ggml/src/ggml-cann/ggml-cann.cpp
@@ -1920,6 +1920,9 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context & ctx, struct gg
         case GGML_OP_DIAG:
             ggml_cann_diag(ctx, dst);
             break;
+        case GGML_OP_SOLVE_TRI:
+            ggml_cann_solve_tri(ctx, dst);
+            break;
         default:
             return false;
     }
@@ -2611,6 +2614,8 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten
             return op->src[0]->type == GGML_TYPE_F32;
         case GGML_OP_DIAG:
             return op->src[0]->type == GGML_TYPE_F32;
+        case GGML_OP_SOLVE_TRI:
+            return op->src[0]->type == GGML_TYPE_F32;
         default:
             return false;
     }