This commit is contained in:
chraac 2025-12-19 21:26:40 +08:00
parent 398aa85311
commit 500c627fbc
1 changed files with 12 additions and 15 deletions

View File

@ -912,23 +912,20 @@ static void vec_dot_f16_f32(const int n, float * restrict s, const void * restri
const HVX_UVector * restrict vx = (const HVX_UVector * restrict) x;
const HVX_UVectorPair * restrict vy = (const HVX_UVectorPair * restrict) y;
uint32_t nv0 = n / VLEN_FP16; // num full fp16 hvx vectors
uint32_t nv1 = n % VLEN_FP16; // leftover elements
uint32_t nv0 = n / VLEN_FP16; // num full fp16 hvx vectors
uint32_t nv1 = n % VLEN_FP16; // leftover elements
// for some reason we need volatile here so that the compiler doesn't try anything funky
const HVX_Vector zero = Q6_Vh_vsplat_R(0x3C00); // 1.0 in fp16
volatile HVX_Vector rsum = Q6_V_vsplat_R(0);
uint32_t i = 0;
const HVX_Vector zero = Q6_Vh_vsplat_R(0x3C00); // 1.0 in fp16
HVX_Vector rsum = Q6_V_vsplat_R(0);
uint32_t i = 0;
for (i = 0; i < nv0; i++) {
HVX_VectorPair yp = vy[i];
HVX_Vector x = vx[i];
HVX_VectorPair xp = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(x), zero); // mul by 1.0
//NOTE: need volatile here to prevent compiler optimization
// Seem compiler cannot guarantee read-after-write??
volatile HVX_Vector hi = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(Q6_V_hi_W(xp)), Q6_V_hi_W(yp));
volatile HVX_Vector lo = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(Q6_V_lo_W(xp)), Q6_V_lo_W(yp));
HVX_Vector hi = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(Q6_V_hi_W(xp)), Q6_V_hi_W(yp));
HVX_Vector lo = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(Q6_V_lo_W(xp)), Q6_V_lo_W(yp));
HVX_Vector sum = Q6_Vqf32_vadd_Vqf32Vqf32(hi, lo);
rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, sum);
@ -942,8 +939,8 @@ static void vec_dot_f16_f32(const int n, float * restrict s, const void * restri
HVX_Vector l_x;
HVX_Vector l_y;
if (nv1 >= 32) {
volatile HVX_Vector hi = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(Q6_V_lo_W(xp)), Q6_V_lo_W(yp));
rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, hi);
HVX_Vector hi = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(Q6_V_lo_W(xp)), Q6_V_lo_W(yp));
rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, hi);
nv1 -= 32;
l_x = Q6_V_hi_W(xp);
l_y = Q6_V_hi_W(yp);
@ -953,9 +950,9 @@ static void vec_dot_f16_f32(const int n, float * restrict s, const void * restri
}
if (nv1) {
volatile HVX_Vector lo = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(l_x), l_y);
HVX_Vector sum = Q6_V_valign_VVR(lo, Q6_V_vzero(), nv1 * sizeof(float));
rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, sum);
HVX_Vector lo = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(l_x), l_y);
HVX_Vector sum = Q6_V_valign_VVR(lo, Q6_V_vzero(), nv1 * sizeof(float));
rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, sum);
}
// hvx_vec_dump_fp16("X", x);