CANN: add GGML_OP_FILL support
Implement FILL using aclnnInplaceFillScalar to fill a tensor with a constant scalar value from op_params.
This commit is contained in:
parent
93e0c17661
commit
4a7bb25226
|
|
@ -694,6 +694,14 @@ void ggml_cann_cumsum(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
|
|||
ggml_cann_type_mapping(dst->type), acl_dst.get());
|
||||
}
|
||||
|
||||
void ggml_cann_fill(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
|
||||
float c = ggml_get_op_params_f32(dst, 0);
|
||||
|
||||
acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst);
|
||||
acl_scalar_ptr acl_c = ggml_cann_create_scalar(&c, ACL_FLOAT);
|
||||
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceFillScalar, acl_dst.get(), acl_c.get());
|
||||
}
|
||||
|
||||
void ggml_cann_tri(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
|
||||
ggml_tensor * src = dst->src[0];
|
||||
|
||||
|
|
|
|||
|
|
@ -346,6 +346,14 @@ void ggml_cann_cumsum(ggml_backend_cann_context & ctx, ggml_tensor * dst);
|
|||
*/
|
||||
void ggml_cann_tri(ggml_backend_cann_context & ctx, ggml_tensor * dst);
|
||||
|
||||
/**
|
||||
* @brief Fills a tensor with a constant scalar value using the CANN backend.
|
||||
*
|
||||
* @param ctx The CANN context used for operations.
|
||||
* @param dst The destination tensor. dst->op is `GGML_OP_FILL`.
|
||||
*/
|
||||
void ggml_cann_fill(ggml_backend_cann_context & ctx, ggml_tensor * dst);
|
||||
|
||||
/**
|
||||
* @brief Upsamples a ggml tensor using nearest neighbor interpolation using
|
||||
* the CANN backend.
|
||||
|
|
|
|||
|
|
@ -1914,6 +1914,9 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context & ctx, struct gg
|
|||
case GGML_OP_TRI:
|
||||
ggml_cann_tri(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_FILL:
|
||||
ggml_cann_fill(ctx, dst);
|
||||
break;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
|
|
@ -2601,6 +2604,8 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten
|
|||
return op->src[0]->type == GGML_TYPE_F32;
|
||||
case GGML_OP_TRI:
|
||||
return op->src[0]->type == GGML_TYPE_F32;
|
||||
case GGML_OP_FILL:
|
||||
return op->src[0]->type == GGML_TYPE_F32;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue