diff --git a/ggml/src/ggml-cann/aclnn_ops.cpp b/ggml/src/ggml-cann/aclnn_ops.cpp index b563906864..24379b9003 100644 --- a/ggml/src/ggml-cann/aclnn_ops.cpp +++ b/ggml/src/ggml-cann/aclnn_ops.cpp @@ -694,6 +694,36 @@ void ggml_cann_cumsum(ggml_backend_cann_context & ctx, ggml_tensor * dst) { ggml_cann_type_mapping(dst->type), acl_dst.get()); } +void ggml_cann_diag(ggml_backend_cann_context & ctx, ggml_tensor * dst) { + ggml_tensor * src = dst->src[0]; + + GGML_ASSERT(src->ne[1] == 1); + + const int64_t N = src->ne[0]; + const int64_t n_batch = src->ne[2] * src->ne[3]; + const size_t nb_f32 = sizeof(float); + + // Fill dst with zeros. + acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst); + { + float zero = 0.0f; + acl_scalar_ptr acl_zero = ggml_cann_create_scalar(&zero, ACL_FLOAT); + GGML_CANN_CALL_ACLNN_OP(ctx, InplaceFillScalar, acl_dst.get(), acl_zero.get()); + } + + // Copy src vector onto the diagonal of dst via strided views. + // src viewed as [N, n_batch], contiguous strides. + int64_t ne_vec[2] = { N, n_batch }; + size_t nb_src_vec[2] = { nb_f32, N * nb_f32 }; + // dst diagonal view: stride (N+1)*4 steps along the diagonal. + size_t nb_dst_diag[2] = { (N + 1) * nb_f32, N * N * nb_f32 }; + + acl_tensor_ptr acl_src_vec = ggml_cann_create_tensor(src->data, ACL_FLOAT, nb_f32, ne_vec, nb_src_vec, 2); + acl_tensor_ptr acl_dst_diag = ggml_cann_create_tensor(dst->data, ACL_FLOAT, nb_f32, ne_vec, nb_dst_diag, 2); + + GGML_CANN_CALL_ACLNN_OP(ctx, InplaceCopy, acl_dst_diag.get(), acl_src_vec.get()); +} + void ggml_cann_fill(ggml_backend_cann_context & ctx, ggml_tensor * dst) { float c = ggml_get_op_params_f32(dst, 0); diff --git a/ggml/src/ggml-cann/aclnn_ops.h b/ggml/src/ggml-cann/aclnn_ops.h index b8bcabff5b..1fb76ea532 100644 --- a/ggml/src/ggml-cann/aclnn_ops.h +++ b/ggml/src/ggml-cann/aclnn_ops.h @@ -346,6 +346,14 @@ void ggml_cann_cumsum(ggml_backend_cann_context & ctx, ggml_tensor * dst); */ void ggml_cann_tri(ggml_backend_cann_context & ctx, ggml_tensor * dst); +/** + * @brief Creates a diagonal matrix from a vector using the CANN backend. + * + * @param ctx The CANN context used for operations. + * @param dst The destination tensor. dst->op is `GGML_OP_DIAG`. + */ +void ggml_cann_diag(ggml_backend_cann_context & ctx, ggml_tensor * dst); + /** * @brief Fills a tensor with a constant scalar value using the CANN backend. * diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp index bd46c1df8c..8fc6b4bbb1 100644 --- a/ggml/src/ggml-cann/ggml-cann.cpp +++ b/ggml/src/ggml-cann/ggml-cann.cpp @@ -1917,6 +1917,9 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context & ctx, struct gg case GGML_OP_FILL: ggml_cann_fill(ctx, dst); break; + case GGML_OP_DIAG: + ggml_cann_diag(ctx, dst); + break; default: return false; } @@ -2606,6 +2609,8 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten return op->src[0]->type == GGML_TYPE_F32; case GGML_OP_FILL: return op->src[0]->type == GGML_TYPE_F32; + case GGML_OP_DIAG: + return op->src[0]->type == GGML_TYPE_F32; default: return false; }