fix: fix hvx_min_scalar_f32

This commit is contained in:
shouyud 2025-12-16 15:20:44 -05:00
parent 925a83ac70
commit 946f1a2037
2 changed files with 22 additions and 34 deletions

View File

@ -229,19 +229,9 @@ static void glu_swiglu_oai_fp32_per_thread(const struct htp_tensor * src0,
} }
// x (src0_spad_data) = std::min(src0_p[k], limit); // x (src0_spad_data) = std::min(src0_p[k], limit);
//hvx_min_scalar_f32((const uint8_t *) src0, limit, src0_spad_data, nc); hvx_min_scalar_f32((const uint8_t *) src0, limit, src0_spad_data, nc);
// // y1 (src1_spad_data) = std::clamp(src1_p[k], -limit, limit); // // y1 (src1_spad_data) = std::clamp(src1_p[k], -limit, limit);
// hvx_clamp_scalar_f32((const uint8_t *) src1, limit, limit, src1_spad_data, nc); // hvx_clamp_scalar_f32((const uint8_t *) src1, limit, limit, src1_spad_data, nc);
// do manual limit
for (int i = 0; i < nc; i++) {
if (src0[i] > limit) {
((float *) src0_spad_data)[i] = limit;
} else {
((float *) src0_spad_data)[i] = src0[i];
}
}
for (int i = 0; i < nc; i++) { for (int i = 0; i < nc; i++) {
if (src1[i] > limit) { if (src1[i] > limit) {
((float *) src1_spad_data)[i] = limit; ((float *) src1_spad_data)[i] = limit;

View File

@ -884,46 +884,44 @@ void hvx_min_scalar_f32(const uint8_t * restrict src, const float val, uint8_t *
const float * src_f = (const float *) src; const float * src_f = (const float *) src;
HVX_Vector vec_min = Q6_V_vsplat_R(val); HVX_Vector vec_min =hvx_vec_splat_fp32(val);
int unaligned_input_addr = 0; int unalign_address = 0;
int unaligned_output_addr = 0;
if(htp_is_aligned((void *) src, VLEN) == 0) {
unaligned_input_addr = 1; if(htp_is_aligned((void *) src, VLEN) == 0 ||htp_is_aligned((void *) dst, VLEN) == 0 ) {
} unalign_address = 1;
if(htp_is_aligned((void *) dst, VLEN) == 0) {
unaligned_output_addr = 1;
} }
if(unalign_address == 0){
if(unaligned_input_addr == 0 && unaligned_output_addr == 0){
HVX_Vector * restrict vec_in = (HVX_Vector *) src; HVX_Vector * restrict vec_in = (HVX_Vector *) src;
HVX_Vector * restrict vec_out = (HVX_Vector *) dst; HVX_Vector * restrict vec_out = (HVX_Vector *) dst;
#pragma unroll(4) #pragma unroll(4)
for (int i = 0; i < num_elems_whole; i += VLEN_FP32) { for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
vec_min = Q6_Vsf_vmin_VsfVsf(vec_min, *vec_in++); HVX_Vector min_clamp = Q6_Vsf_vmin_VsfVsf(vec_min, *vec_in++);
*vec_out++ = Q6_Vsf_equals_Vqf32(vec_min); *vec_out++ = (min_clamp);
}
}else{
HVX_UVector * restrict vec_in = (HVX_Vector *) src;
HVX_UVector * restrict vec_out = (HVX_Vector *) dst;
#pragma unroll(4)
for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
HVX_Vector min_clamp = Q6_Vsf_vmin_VsfVsf(vec_min, *vec_in++);
*vec_out++ = (min_clamp);
} }
} }
else if(unaligned_output_addr == 0){
}else if()
if (left_over > 0 ) { if (left_over > 0 ) {
const float * srcf = (const float *) src + num_elems_whole; const float * srcf = (const float *) src + num_elems_whole;
float * dstf = (float *) dst + num_elems_whole; float * dstf = (float *) dst + num_elems_whole;
HVX_Vector in = *(HVX_UVector *) srcf; HVX_UVector in = *(HVX_UVector *) srcf;
vec_min = Q6_Vsf_vmin_VsfVsf(vec_min, in); HVX_UVector min_clamp = Q6_Vsf_vmin_VsfVsf(vec_min, in);
hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, Q6_Vsf_equals_Vqf32(vec_min)); hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, (min_clamp));
} }
} }