Revert "opt: use qf32 internal precision for vec_dot_f16_f32"
This reverts commit 8600ecd20d6c902fe16271d6af1e59504eff4a27.
This commit is contained in:
parent
cb0a8ff4e7
commit
2058f28b3e
|
|
@ -915,21 +915,18 @@ static void vec_dot_f16_f32(const int n, float * restrict s, const void * restri
|
||||||
uint32_t nv0 = n / VLEN_FP16; // num full fp16 hvx vectors
|
uint32_t nv0 = n / VLEN_FP16; // num full fp16 hvx vectors
|
||||||
uint32_t nv1 = n % VLEN_FP16; // leftover elements
|
uint32_t nv1 = n % VLEN_FP16; // leftover elements
|
||||||
|
|
||||||
const HVX_Vector one = Q6_Vh_vsplat_R(0x3C00); // 1.0 in fp16
|
const HVX_Vector zero = Q6_Vh_vsplat_R(0x3C00); // 1.0 in fp16
|
||||||
const HVX_Vector zero = Q6_V_vsplat_R(0); // 0.0 in fp16
|
|
||||||
HVX_Vector rsum = Q6_V_vsplat_R(0);
|
HVX_Vector rsum = Q6_V_vsplat_R(0);
|
||||||
uint32_t i = 0;
|
uint32_t i = 0;
|
||||||
|
|
||||||
#pragma unroll(2)
|
#pragma unroll(2)
|
||||||
for (i = 0; i < nv0; i++) {
|
for (i = 0; i < nv0; i++) {
|
||||||
HVX_VectorPair yp = vy[i];
|
HVX_VectorPair yp = vy[i];
|
||||||
HVX_Vector x = vx[i];
|
HVX_Vector x = vx[i];
|
||||||
HVX_VectorPair xp = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(x), one); // mul by 1.0
|
HVX_VectorPair xp = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(x), zero); // mul by 1.0
|
||||||
HVX_Vector y_hi = Q6_Vqf32_vadd_VsfVsf(Q6_V_hi_W(yp), zero); // convert to qf32
|
|
||||||
HVX_Vector y_lo = Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(yp), zero); // convert to qf32
|
|
||||||
|
|
||||||
HVX_Vector hi = Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_hi_W(xp), y_hi);
|
HVX_Vector hi = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(Q6_V_hi_W(xp)), Q6_V_hi_W(yp));
|
||||||
HVX_Vector lo = Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_lo_W(xp), y_lo);
|
HVX_Vector lo = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(Q6_V_lo_W(xp)), Q6_V_lo_W(yp));
|
||||||
|
|
||||||
HVX_Vector sum = Q6_Vqf32_vadd_Vqf32Vqf32(hi, lo);
|
HVX_Vector sum = Q6_Vqf32_vadd_Vqf32Vqf32(hi, lo);
|
||||||
rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, sum);
|
rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, sum);
|
||||||
|
|
@ -938,25 +935,24 @@ static void vec_dot_f16_f32(const int n, float * restrict s, const void * restri
|
||||||
if (nv1) {
|
if (nv1) {
|
||||||
HVX_VectorPair yp = vy[i];
|
HVX_VectorPair yp = vy[i];
|
||||||
HVX_Vector x = vx[i];
|
HVX_Vector x = vx[i];
|
||||||
HVX_VectorPair xp = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(x), one); // mul by 1.0
|
HVX_VectorPair xp = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(x), zero); // mul by 1.0
|
||||||
|
|
||||||
HVX_Vector leftover_x;
|
HVX_Vector l_x;
|
||||||
HVX_Vector leftover_y;
|
HVX_Vector l_y;
|
||||||
if (nv1 >= VLEN_FP32) {
|
if (nv1 >= VLEN_FP32) {
|
||||||
HVX_Vector y_lo = Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(yp), zero); // convert to qf32
|
HVX_Vector hi = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(Q6_V_lo_W(xp)), Q6_V_lo_W(yp));
|
||||||
HVX_Vector lo = Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_lo_W(xp), y_lo);
|
rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, hi);
|
||||||
rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, lo);
|
|
||||||
nv1 -= VLEN_FP32;
|
nv1 -= VLEN_FP32;
|
||||||
leftover_x = Q6_V_hi_W(xp);
|
l_x = Q6_V_hi_W(xp);
|
||||||
leftover_y = Q6_Vqf32_vadd_VsfVsf(Q6_V_hi_W(yp), zero); // convert to qf32
|
l_y = Q6_V_hi_W(yp);
|
||||||
} else {
|
} else {
|
||||||
leftover_x = Q6_V_lo_W(xp);
|
l_x = Q6_V_lo_W(xp);
|
||||||
leftover_y = Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(yp), zero); // convert to qf32
|
l_y = Q6_V_lo_W(yp);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (nv1) {
|
if (nv1) {
|
||||||
HVX_Vector lo = Q6_Vqf32_vmpy_Vqf32Vqf32(leftover_x, leftover_y);
|
HVX_Vector lo = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(l_x), l_y);
|
||||||
HVX_Vector sum = Q6_V_valign_VVR(lo, zero, nv1 * sizeof(float));
|
HVX_Vector sum = Q6_V_valign_VVR(lo, Q6_V_vzero(), nv1 * sizeof(float));
|
||||||
rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, sum);
|
rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, sum);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue