fix type conversion bug

Author: unknown
Date: 2025-12-14 10:31:45 +08:00
Parent: e840f09af1
Commit: 3c6c17479b
1 changed file with 16 additions and 16 deletions
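Summary: the change converts the remaining raw aclIntArray* / aclScalar* handles in ggml_cann_conv_transpose_1d (created with aclCreateIntArray / aclCreateScalar) to the backend's smart-pointer wrapper types acl_int_array_ptr and acl_scalar_ptr, and passes the raw handles to the aclnn calls via .get(). A sketch of what such wrappers could look like follows the diff.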


@@ -3013,11 +3013,11 @@ void ggml_cann_conv_transpose_1d(ggml_backend_cann_context& ctx, ggml_tensor* ds
     int64_t strideVal[1];
     strideVal[0] = s0;
-    aclIntArray *stride = aclCreateIntArray(strideVal, 1);
+    acl_int_array_ptr stride = ggml_cann_create_int_array(strideVal, 1);
     int64_t paddingVal[] = {0};
-    aclIntArray *padding = aclCreateIntArray(paddingVal, 1);
+    acl_int_array_ptr padding = ggml_cann_create_int_array(paddingVal, 1);
     int64_t dilationVal[] = {1};
-    aclIntArray *dilation = aclCreateIntArray(dilationVal, 1);
+    acl_int_array_ptr dilation = ggml_cann_create_int_array(dilationVal, 1);
     bool transposed = true;
     int64_t groups = 1;
     int8_t cubeMathType = 0;
@@ -3039,12 +3039,12 @@ void ggml_cann_conv_transpose_1d(ggml_backend_cann_context& ctx, ggml_tensor* ds
     int64_t left_pad_len = dilationVal[0] * (max_kernel_size - 1) + 1 - 2 * paddingVal[0];
     int64_t right_pad_len = 0;
-    aclScalar* alpha = nullptr;
+    acl_scalar_ptr alpha = nullptr;
     float alphaValue = 1.0;
-    alpha = aclCreateScalar(&alphaValue, aclDataType::ACL_FLOAT);
+    alpha = ggml_cann_create_scalar(&alphaValue, aclDataType::ACL_FLOAT);
     // set zero to destination
-    GGML_CANN_CALL_ACLNN_OP(ctx, InplaceZero, acl_dst);
+    GGML_CANN_CALL_ACLNN_OP(ctx, InplaceZero, acl_dst.get());
     for(int k = 0; k < part_num; k++){
@@ -3076,7 +3076,7 @@ void ggml_cann_conv_transpose_1d(ggml_backend_cann_context& ctx, ggml_tensor* ds
         acl_tensor_ptr part_kernel = ggml_cann_create_tensor(part_kernel_buf, weight_type,
             ggml_element_size(src0), part_ne, part_nb, 3, ACL_FORMAT_NCL);
-        GGML_CANN_CALL_ACLNN_OP(ctx, Slice, acl_weight, slice_dim, slice_start, slice_end, slice_step, part_kernel);
+        GGML_CANN_CALL_ACLNN_OP(ctx, Slice, acl_weight.get(), slice_dim, slice_start, slice_end, slice_step, part_kernel.get());
         // create the part conv result tensor
         int64_t part_dst_ne[4];
@@ -3096,11 +3096,11 @@ void ggml_cann_conv_transpose_1d(ggml_backend_cann_context& ctx, ggml_tensor* ds
         acl_tensor_ptr acl_part_dst = ggml_cann_create_tensor(part_dst_buf, dst_type, ggml_element_size(dst),
             part_dst_ne, part_dst_nb, 3, ACL_FORMAT_NCL);
-        GGML_CANN_CALL_ACLNN_OP(ctx, InplaceZero, acl_part_dst);
+        GGML_CANN_CALL_ACLNN_OP(ctx, InplaceZero, acl_part_dst.get());
         // compute part conv transpose 1d
-        GGML_CANN_CALL_ACLNN_OP(ctx, Convolution, acl_input, part_kernel, nullptr, stride,
-            padding, dilation, transposed, padding, groups, acl_part_dst, cubeMathType);
+        GGML_CANN_CALL_ACLNN_OP(ctx, Convolution, acl_input.get(), part_kernel.get(), nullptr, stride.get(),
+            padding.get(), dilation.get(), transposed, padding.get(), groups, acl_part_dst.get(), cubeMathType);
         // compute the position of part result in final result
         int64_t global_start = slice_start;
@@ -3110,11 +3110,11 @@ void ggml_cann_conv_transpose_1d(ggml_backend_cann_context& ctx, ggml_tensor* ds
         right_pad_len = dst_len - global_end;
         std::vector<int64_t> padDataVal = {left_pad_len,right_pad_len};
-        aclIntArray *padData = aclCreateIntArray(padDataVal.data(), 2);
+        acl_int_array_ptr padData = ggml_cann_create_int_array(padDataVal.data(), 2);
-        aclScalar* pad_value = nullptr;
+        acl_scalar_ptr pad_value = nullptr;
         float pad_valueVal = 0.0;
-        pad_value = aclCreateScalar(&pad_valueVal, aclDataType::ACL_FLOAT);
+        pad_value = ggml_cann_create_scalar(&pad_valueVal, aclDataType::ACL_FLOAT);
         int64_t conv_result_ne[4];
         for(int i = 0; i < 4; i++){
@@ -3134,9 +3134,9 @@ void ggml_cann_conv_transpose_1d(ggml_backend_cann_context& ctx, ggml_tensor* ds
         acl_tensor_ptr conv_result = ggml_cann_create_tensor(conv_result_buf, dst_type, ggml_element_size(dst),
             conv_result_ne, conv_result_nb, 3, ACL_FORMAT_NCL);
-        GGML_CANN_CALL_ACLNN_OP(ctx, InplaceZero, conv_result);
-        GGML_CANN_CALL_ACLNN_OP(ctx, ConstantPadNd, acl_part_dst, padData, pad_value, conv_result);
-        GGML_CANN_CALL_ACLNN_OP(ctx, InplaceAdd, acl_dst, conv_result, alpha);
+        GGML_CANN_CALL_ACLNN_OP(ctx, InplaceZero, conv_result.get());
+        GGML_CANN_CALL_ACLNN_OP(ctx, ConstantPadNd, acl_part_dst.get(), padData.get(), pad_value.get(), conv_result.get());
+        GGML_CANN_CALL_ACLNN_OP(ctx, InplaceAdd, acl_dst.get(), conv_result.get(), alpha.get());
     }
 }
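A minimal sketch of what the acl_int_array_ptr / acl_scalar_ptr wrapper types could look like, assuming they are std::unique_ptr aliases with custom deleters. The alias and factory names mirror the diff; the definitions below are illustrative, not the project's actual code.

// Sketch only: RAII aliases for ACL objects, assuming std::unique_ptr
// with custom deleters. The real definitions live in the CANN backend.
#include <memory>
#include <aclnn/acl_meta.h>  // aclIntArray, aclScalar, aclCreate*/aclDestroy* (CANN SDK)

struct acl_int_array_deleter {
    void operator()(aclIntArray* arr) const { aclDestroyIntArray(arr); }
};
struct acl_scalar_deleter {
    void operator()(aclScalar* scalar) const { aclDestroyScalar(scalar); }
};

using acl_int_array_ptr = std::unique_ptr<aclIntArray, acl_int_array_deleter>;
using acl_scalar_ptr    = std::unique_ptr<aclScalar, acl_scalar_deleter>;

// Hypothetical factories matching the diff's call sites: create the ACL
// object and immediately transfer ownership to the smart pointer.
inline acl_int_array_ptr ggml_cann_create_int_array(const int64_t* value, uint64_t size) {
    return acl_int_array_ptr(aclCreateIntArray(value, size));
}
inline acl_scalar_ptr ggml_cann_create_scalar(void* value, aclDataType type) {
    return acl_scalar_ptr(aclCreateScalar(value, type));
}

Under this reading, each array and scalar is released automatically when it leaves scope, including on every iteration of the per-part loop, while .get() exposes the raw handle that GGML_CANN_CALL_ACLNN_OP forwards to the underlying aclnn operator; no matching aclDestroy* calls appear in the code touched by this diff, so the raw-pointer version presumably leaked these objects.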