diff --git a/ggml/src/ggml-cann/aclnn_ops.cpp b/ggml/src/ggml-cann/aclnn_ops.cpp index 533b127b48..b563906864 100644 --- a/ggml/src/ggml-cann/aclnn_ops.cpp +++ b/ggml/src/ggml-cann/aclnn_ops.cpp @@ -694,6 +694,14 @@ void ggml_cann_cumsum(ggml_backend_cann_context & ctx, ggml_tensor * dst) { ggml_cann_type_mapping(dst->type), acl_dst.get()); } +void ggml_cann_fill(ggml_backend_cann_context & ctx, ggml_tensor * dst) { + float c = ggml_get_op_params_f32(dst, 0); + + acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst); + acl_scalar_ptr acl_c = ggml_cann_create_scalar(&c, ACL_FLOAT); + GGML_CANN_CALL_ACLNN_OP(ctx, InplaceFillScalar, acl_dst.get(), acl_c.get()); +} + void ggml_cann_tri(ggml_backend_cann_context & ctx, ggml_tensor * dst) { ggml_tensor * src = dst->src[0]; diff --git a/ggml/src/ggml-cann/aclnn_ops.h b/ggml/src/ggml-cann/aclnn_ops.h index 3926525a95..b8bcabff5b 100644 --- a/ggml/src/ggml-cann/aclnn_ops.h +++ b/ggml/src/ggml-cann/aclnn_ops.h @@ -346,6 +346,14 @@ void ggml_cann_cumsum(ggml_backend_cann_context & ctx, ggml_tensor * dst); */ void ggml_cann_tri(ggml_backend_cann_context & ctx, ggml_tensor * dst); +/** + * @brief Fills a tensor with a constant scalar value using the CANN backend. + * + * @param ctx The CANN context used for operations. + * @param dst The destination tensor. dst->op is `GGML_OP_FILL`. + */ +void ggml_cann_fill(ggml_backend_cann_context & ctx, ggml_tensor * dst); + /** * @brief Upsamples a ggml tensor using nearest neighbor interpolation using * the CANN backend. diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp index 5e3d7d53f3..bd46c1df8c 100644 --- a/ggml/src/ggml-cann/ggml-cann.cpp +++ b/ggml/src/ggml-cann/ggml-cann.cpp @@ -1914,6 +1914,9 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context & ctx, struct gg case GGML_OP_TRI: ggml_cann_tri(ctx, dst); break; + case GGML_OP_FILL: + ggml_cann_fill(ctx, dst); + break; default: return false; } @@ -2601,6 +2604,8 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten return op->src[0]->type == GGML_TYPE_F32; case GGML_OP_TRI: return op->src[0]->type == GGML_TYPE_F32; + case GGML_OP_FILL: + return op->src[0]->type == GGML_TYPE_F32; default: return false; }