hexagon: improve RMS_NORM and DIV accuracy (#21251)
* hexagon-rms_norm: fix RMS_NORM for non-aligned tensor sizes Co-authored-by: Krishna Sridhar <srsr@qti.qualcomm.com> * hexagon-div: perform DIV in fp16 domain for lower dsp archs --------- Co-authored-by: Krishna Sridhar <srsr@qti.qualcomm.com>
This commit is contained in:
parent
1d6d4cf7a5
commit
8710e5f9b9
|
|
@ -16,8 +16,10 @@
|
|||
|
||||
#if __HVX_ARCH__ < 79
|
||||
#define HVX_OP_MUL_F32(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(a, b))
|
||||
#define HVX_OP_MUL_F16(a, b) Q6_Vhf_equals_Wqf32(Q6_Wqf32_vmpy_VhfVhf(a, b))
|
||||
#else
|
||||
#define HVX_OP_MUL_F32(a, b) Q6_Vsf_vmpy_VsfVsf(a, b)
|
||||
#define HVX_OP_MUL_F16(a, b) Q6_Vhf_vmpy_VhfVhf(a, b)
|
||||
#endif
|
||||
|
||||
// Compute div by scaler in f32. Requires first by expanding fp32 to fp16 and converting the result back to fp32.
|
||||
|
|
@ -43,46 +45,67 @@ static inline HVX_Vector hvx_div_mul_f16_const_using_f32(HVX_Vector vec1_hf, HVX
|
|||
return res;
|
||||
}
|
||||
|
||||
#define hvx_div_scaler_f16_loop_body(dst_type, src_type, vec_store) \
|
||||
do { \
|
||||
dst_type * restrict vdst = (dst_type *) dst; \
|
||||
src_type * restrict vsrc = (src_type *) src; \
|
||||
HVX_Vector hf_one = Q6_Vh_vsplat_R(0x3C00); \
|
||||
\
|
||||
const uint32_t nvec = n / VLEN_FP16; \
|
||||
const uint32_t nloe = n % VLEN_FP16; \
|
||||
\
|
||||
uint32_t i = 0; \
|
||||
\
|
||||
_Pragma("unroll(4)") \
|
||||
for (; i < nvec; i++) { \
|
||||
HVX_Vector res = hvx_div_mul_f16_const_using_f32(vsrc[i], val_vec_f32, hf_one); \
|
||||
vdst[i] = res; \
|
||||
} \
|
||||
if (nloe) { \
|
||||
HVX_Vector res = hvx_div_mul_f16_const_using_f32(vsrc[i], val_vec_f32, hf_one); \
|
||||
vec_store((void *) &vdst[i], nloe * SIZEOF_FP16, res); \
|
||||
} \
|
||||
// Variant for <v79: Use pre-computed f16 reciprocal constant
|
||||
static inline HVX_Vector hvx_div_mul_f16_const_using_f16(HVX_Vector vec1_hf, HVX_Vector const_inv_hf) {
|
||||
// Multiply by pre-computed f16 reciprocal constant
|
||||
return HVX_OP_MUL_F16(vec1_hf, const_inv_hf);
|
||||
}
|
||||
|
||||
#define hvx_div_scaler_f16_loop_body(dst_type, src_type, vec_store) \
|
||||
do { \
|
||||
dst_type * restrict vdst = (dst_type *) dst; \
|
||||
src_type * restrict vsrc = (src_type *) src; \
|
||||
\
|
||||
HVX_Vector hf_one = Q6_Vh_vsplat_R(0x3C00); \
|
||||
\
|
||||
const uint32_t nvec = n / VLEN_FP16; \
|
||||
const uint32_t nloe = n % VLEN_FP16; \
|
||||
\
|
||||
uint32_t i = 0; \
|
||||
\
|
||||
_Pragma("unroll(4)") \
|
||||
for (; i < nvec; i++) { \
|
||||
HVX_Vector res; \
|
||||
if (__HVX_ARCH__ < 79) { \
|
||||
res = hvx_div_mul_f16_const_using_f16(vsrc[i], val_vec_f16); \
|
||||
} else { \
|
||||
res = hvx_div_mul_f16_const_using_f32(vsrc[i], val_vec_f32, hf_one); \
|
||||
} \
|
||||
vdst[i] = res; \
|
||||
} \
|
||||
if (nloe) { \
|
||||
HVX_Vector res; \
|
||||
if (__HVX_ARCH__ < 79) { \
|
||||
res = hvx_div_mul_f16_const_using_f16(vsrc[i], val_vec_f16); \
|
||||
} else { \
|
||||
res = hvx_div_mul_f16_const_using_f32(vsrc[i], val_vec_f32, hf_one); \
|
||||
} \
|
||||
vec_store((void *) &vdst[i], nloe * SIZEOF_FP16, res); \
|
||||
} \
|
||||
} while(0)
|
||||
|
||||
static inline void hvx_div_scalar_f16_aa(uint8_t * restrict dst, const uint8_t * restrict src, const _Float16 val, uint32_t n) {
|
||||
const HVX_Vector val_vec_f32 = hvx_vec_splat_f32(1.0f/((float)val));
|
||||
const HVX_Vector val_vec_f16 = hvx_vec_splat_f16(1.0f / val);
|
||||
assert((uintptr_t) dst % 128 == 0);
|
||||
assert((uintptr_t) src % 128 == 0);
|
||||
hvx_div_scaler_f16_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a);
|
||||
}
|
||||
static inline void hvx_div_scalar_f16_au(uint8_t * restrict dst, const uint8_t * restrict src, const _Float16 val, uint32_t n) {
|
||||
const HVX_Vector val_vec_f32 = hvx_vec_splat_f32(1.0f/((float)val));
|
||||
const HVX_Vector val_vec_f16 = hvx_vec_splat_f16(1.0f / val);
|
||||
assert((uintptr_t) dst % 128 == 0);
|
||||
hvx_div_scaler_f16_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a);
|
||||
}
|
||||
static inline void hvx_div_scalar_f16_ua(uint8_t * restrict dst, const uint8_t * restrict src, const _Float16 val, uint32_t n) {
|
||||
const HVX_Vector val_vec_f32 = hvx_vec_splat_f32(1.0f/((float)val));
|
||||
const HVX_Vector val_vec_f16 = hvx_vec_splat_f16(1.0f / val);
|
||||
assert((uintptr_t) src % 128 == 0);
|
||||
hvx_div_scaler_f16_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u);
|
||||
}
|
||||
static inline void hvx_div_scalar_f16_uu(uint8_t * restrict dst, const uint8_t * restrict src, const _Float16 val, uint32_t n) {
|
||||
const HVX_Vector val_vec_f32 = hvx_vec_splat_f32(1.0f/((float)val));
|
||||
const HVX_Vector val_vec_f16 = hvx_vec_splat_f16(1.0f / val);
|
||||
hvx_div_scaler_f16_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u);
|
||||
}
|
||||
|
||||
|
|
@ -128,13 +151,25 @@ static inline HVX_Vector hvx_vec_div_f16_using_f32(HVX_Vector vec1, HVX_Vector v
|
|||
return recip;
|
||||
}
|
||||
|
||||
// Hybrid approach: f16 reciprocal for <v79, f32 precision for >=v79
|
||||
static inline HVX_Vector hvx_vec_hybrid_div_f16(HVX_Vector vec1, HVX_Vector vec2, HVX_Vector f32_nan_inf_mask, HVX_Vector f16_nan_inf_mask, HVX_Vector vec_hf_one_1_0) {
|
||||
#if __HVX_ARCH__ < 79
|
||||
// For older architectures, use f16 reciprocal to avoid NaN/-inf issues
|
||||
HVX_Vector vec2_inv = hvx_vec_inverse_f16_guard(vec2, f16_nan_inf_mask);
|
||||
return HVX_OP_MUL_F16(vec1, vec2_inv);
|
||||
#else
|
||||
return hvx_vec_div_f16_using_f32(vec1, vec2, f32_nan_inf_mask, vec_hf_one_1_0);
|
||||
#endif
|
||||
}
|
||||
|
||||
#define hvx_div_f16_loop_body(dst_type, src0_type, src1_type, vec_store) \
|
||||
do { \
|
||||
dst_type * restrict vdst = (dst_type *) dst; \
|
||||
src0_type * restrict vsrc0 = (src0_type *) src0; \
|
||||
src1_type * restrict vsrc1 = (src1_type *) src1; \
|
||||
\
|
||||
const HVX_Vector nan_inf_mask = Q6_V_vsplat_R(0x7f800000); \
|
||||
const HVX_Vector f32_nan_inf_mask = Q6_V_vsplat_R(0x7f800000); \
|
||||
const HVX_Vector f16_nan_inf_mask = Q6_Vh_vsplat_R(0x7c00); \
|
||||
const HVX_Vector hf_one = Q6_Vh_vsplat_R(0x3C00); \
|
||||
\
|
||||
const uint32_t nvec = n / VLEN_FP16; \
|
||||
|
|
@ -144,11 +179,15 @@ static inline HVX_Vector hvx_vec_div_f16_using_f32(HVX_Vector vec1, HVX_Vector v
|
|||
\
|
||||
_Pragma("unroll(4)") \
|
||||
for (; i < nvec; i++) { \
|
||||
HVX_Vector res = hvx_vec_div_f16_using_f32(vsrc0[i], vsrc1[i], nan_inf_mask, hf_one); \
|
||||
HVX_Vector res = hvx_vec_hybrid_div_f16(vsrc0[i], vsrc1[i], \
|
||||
f32_nan_inf_mask, f16_nan_inf_mask, \
|
||||
hf_one); \
|
||||
vdst[i] = res; \
|
||||
} \
|
||||
if (nloe) { \
|
||||
HVX_Vector res = hvx_vec_div_f16_using_f32(vsrc0[i], vsrc1[i], nan_inf_mask, hf_one); \
|
||||
HVX_Vector res = hvx_vec_hybrid_div_f16(vsrc0[i], vsrc1[i], \
|
||||
f32_nan_inf_mask, f16_nan_inf_mask, \
|
||||
hf_one); \
|
||||
vec_store((void *) &vdst[i], nloe * SIZEOF_FP16, res); \
|
||||
} \
|
||||
} while(0)
|
||||
|
|
@ -247,5 +286,6 @@ HVX_DIV_DISPATCHER(hvx_div_f32)
|
|||
HVX_DIV_DISPATCHER(hvx_div_f16)
|
||||
|
||||
#undef HVX_OP_MUL_F32
|
||||
#undef HVX_OP_MUL_F16
|
||||
|
||||
#endif // HVX_DIV_H
|
||||
|
|
|
|||
|
|
@ -67,34 +67,61 @@ static void hvx_fast_rms_norm_f32(const uint8_t * restrict src,
|
|||
uint8_t * restrict pad,
|
||||
const int num_elems,
|
||||
float epsilon) {
|
||||
(void)pad;
|
||||
|
||||
const HVX_Vector * restrict v_src = (HVX_Vector *) src;
|
||||
HVX_Vector * restrict v_dst = (HVX_Vector *) dst;
|
||||
|
||||
HVX_Vector sum_v = Q6_V_vsplat_R(0x00000000);
|
||||
const int nvec = num_elems / VLEN_FP32; // number of full vectors
|
||||
const int nloe = num_elems % VLEN_FP32; // leftover elements
|
||||
|
||||
// Compute sum of squares for full vectors
|
||||
HVX_Vector sum_v = Q6_V_vsplat_R(0x00000000);
|
||||
HVX_Vector epsilon_v = hvx_vec_splat_f32(epsilon);
|
||||
|
||||
int step_of_1 = num_elems >> 5;
|
||||
#pragma unroll(4)
|
||||
for (int i = 0; i < step_of_1; i++) {
|
||||
for (int i = 0; i < nvec; i++) {
|
||||
HVX_Vector v1 = v_src[i];
|
||||
HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, v1);
|
||||
sum_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_v, v2);
|
||||
sum_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_v, v2);
|
||||
}
|
||||
|
||||
sum_v = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(sum_v)); // replicated over all lanes
|
||||
// Handle tail elements using vectorized ops with masking
|
||||
if (nloe > 0) {
|
||||
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4);
|
||||
HVX_Vector v1 = Q6_V_vand_QV(bmask, v_src[nvec]);
|
||||
HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, v1);
|
||||
sum_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_v, v2);
|
||||
}
|
||||
|
||||
// Reduce HVX sum
|
||||
sum_v = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(sum_v));
|
||||
|
||||
HVX_Vector t_v = hvx_vec_splat_f32((float) num_elems);
|
||||
HVX_Vector denom_v = hvx_vec_inverse_f32(t_v);
|
||||
HVX_Vector mean_v = Q6_Vqf32_vmpy_VsfVsf(sum_v, denom_v);
|
||||
HVX_Vector mean_epsilon_v = Q6_Vqf32_vadd_Vqf32Vsf(mean_v, epsilon_v);
|
||||
|
||||
// Scale full vectors
|
||||
HVX_Vector scale_v = hvx_vec_rsqrt_f32(Q6_Vsf_equals_Vqf32(mean_epsilon_v));
|
||||
|
||||
#pragma unroll(4)
|
||||
for (int i = 0; i < step_of_1; i++) {
|
||||
for (int i = 0; i < nvec; i++) {
|
||||
HVX_Vector v1 = v_src[i];
|
||||
HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, scale_v);
|
||||
v_dst[i] = Q6_Vsf_equals_Vqf32(v2);
|
||||
v_dst[i] = Q6_Vsf_equals_Vqf32(v2);
|
||||
}
|
||||
|
||||
// Handle tail elements using vectorized ops with masking
|
||||
if (nloe > 0) {
|
||||
|
||||
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4);
|
||||
HVX_Vector v1 = Q6_V_vand_QV(bmask, v_src[nvec]);
|
||||
HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, scale_v);
|
||||
HVX_Vector result = Q6_Vsf_equals_Vqf32(v2);
|
||||
|
||||
// Store with masking to avoid overwriting memory beyond the tensor
|
||||
hvx_vec_store_a(&v_dst[nvec], nloe * 4, result);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue