Feat: optimized unaligned sigmoid_f32

shouyud 2025-12-12 11:58:45 -05:00
parent 84f2f23aa9
commit fc2289dc96
2 changed files with 109 additions and 4 deletions


@ -317,10 +317,7 @@ static void unary_gelu_fp32_per_thread(const struct htp_tensor * src0,
    } else {
        hvx_mul_scalar_f32((const uint8_t *) src0, (float) 1.702, (uint8_t *) src0_spad_data, ne0);
        // sigmoid
-       hvx_exp_f32((const uint8_t *) src0_spad_data, src0_spad_data, ne0, true);
-       hvx_add_scalar_f32(src0_spad_data, 1.0, dst_spad_data, ne0);
-       hvx_inverse_f32(dst_spad_data, src0_spad_data, ne0);
+       hvx_sigmoid_f32((const uint8_t *) src0_spad_data, (uint8_t *) src0_spad_data, ne0);
        hvx_mul_f32((const uint8_t *) src0, src0_spad_data, (uint8_t *) dst, ne0);
    }
}
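For reference, the path above implements the sigmoid-based GELU approximation; the commit simply replaces the three-call sigmoid (exp, add 1, reciprocal) with the fused hvx_sigmoid_f32 introduced below. A scalar sketch of the approximation, for illustration only (gelu_approx_ref is a made-up name, not a function from this commit):

#include <math.h>

/* gelu(x) ~= x * sigmoid(1.702 * x) = x / (1 + exp(-1.702 * x)) */
static float gelu_approx_ref(float x) {
    return x / (1.f + expf(-1.702f * x));
}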


@ -265,12 +265,16 @@ static inline void hvx_bcast_fp32_a(uint8_t * restrict dst, float elem, uint32_t
}
}
/* Return whether the 'n' bytes starting at 'addr' fall within a single aligned chunk of 'chunk_size' bytes. */
static __attribute__((always_inline)) int32_t is_in_one_chunk(void * addr, uint32_t n, uint32_t chunk_size) {
    uint32_t left_off  = (size_t) addr & (chunk_size - 1);
    uint32_t right_off = left_off + n;
    return right_off <= chunk_size;
}
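For illustration, a tiny host-side sketch of the same chunk test (standalone C; buf, the offsets, and the output comments are made-up values for this example, not taken from the commit):

#include <stdint.h>
#include <stddef.h>
#include <stdio.h>

/* Same logic as is_in_one_chunk() above, duplicated so the sketch compiles on its own. */
static int32_t in_one_chunk(void * addr, uint32_t n, uint32_t chunk_size) {
    uint32_t left_off = (size_t) addr & (chunk_size - 1);
    return left_off + n <= chunk_size;
}

int main(void) {
    static uint8_t buf[256] __attribute__((aligned(128)));
    printf("%d\n", in_one_chunk(buf + 8,   16, 128)); // 1: bytes 8..23 stay inside the first 128-byte chunk
    printf("%d\n", in_one_chunk(buf + 120, 16, 128)); // 0: bytes 120..135 cross into the next chunk
    return 0;
}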
static void hvx_vec_dump_fp16_n(char * pref, HVX_Vector v, uint32_t n) {
HVX_VectorAlias u = { .v = v };
@ -994,6 +998,110 @@ static inline void hvx_fast_sigmoid_f32(const uint8_t * restrict src, uint8_t *
}
}
static inline void hvx_sigmoid_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems) {
    int step_of_1 = num_elems >> 5; // div by 32: one 128-byte HVX vector holds 32 fp32 elements
    int leftover  = num_elems - (step_of_1 * VLEN_FP32);

    int32_t leftover_size = leftover * sizeof(float);

    static const float kMinExp = -87.f; // below this exp() underflows and sigmoid saturates to 0
    static const float kMaxExp =  87.f; // above this exp() overflows and sigmoid saturates to 1

    const HVX_Vector one     = hvx_vec_splat_fp32(1.f);
    const HVX_Vector max_exp = hvx_vec_splat_fp32(kMaxExp);
    const HVX_Vector min_exp = hvx_vec_splat_fp32(kMinExp);

    const float * input  = (float *) src;
    float *       output = (float *) dst;

    HVX_Vector *  input_v_ptr  = (HVX_Vector *) input;
    HVX_UVector * output_v_ptr = (HVX_UVector *) output;

    HVX_Vector slinep;
    HVX_Vector slinec;
    HVX_Vector sline;

    // Aligned load of the vector containing the (possibly unaligned) start of the input.
    slinep = *input_v_ptr++;

#pragma unroll(4)
    for (int32_t i = step_of_1 - 1; i > 0; i--) { // signed counter so the loop is skipped cleanly when step_of_1 == 0
        slinec = *input_v_ptr++;
        // Recombine the two aligned loads into one vector at the unaligned source offset.
        sline  = Q6_V_valign_VVR(slinec, slinep, (size_t) input);

        *output_v_ptr++ = hvx_vec_fast_sigmoid_fp32_guard(sline, one, max_exp, min_exp);

        /* Prepare slinep for next iteration */
        slinep = slinec;
    }

    // Last full vector: only load past the end if the data actually extends there.
    if (step_of_1 > 0) {
        slinec = htp_is_aligned(input_v_ptr, 128) && leftover == 0 ? slinep : *input_v_ptr++;
        sline  = Q6_V_valign_VVR(slinec, slinep, (size_t) input);

        *output_v_ptr++ = hvx_vec_fast_sigmoid_fp32_guard(sline, one, max_exp, min_exp);

        slinep = slinec;
    }

    // Leftover (< 32) elements: skip the extra load if the tail fits in the chunk already read.
    if (leftover > 0) {
        slinec = is_in_one_chunk(input_v_ptr, leftover_size, 128) ? slinep : *input_v_ptr++;
        sline  = Q6_V_valign_VVR(slinec, slinep, (size_t) input);

        HVX_Vector sout = hvx_vec_fast_sigmoid_fp32_guard(sline, one, max_exp, min_exp);

        /* Store output */
        hvx_vec_store_u(output_v_ptr, leftover_size, sout);
    }
}
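As a plain-C sanity check against the vector kernel above. Assumption: hvx_vec_fast_sigmoid_fp32_guard computes 1 / (1 + exp(-x)) with its input clamped to the kMinExp/kMaxExp range; its exact polynomial and rounding are not reproduced here, so small ULP differences are expected.

#include <math.h>

/* Scalar reference for hvx_sigmoid_f32: clamp to +/-87 like the kernel's guard, then 1 / (1 + exp(-x)).
 * Checking aid only; not part of the commit. */
static void ref_sigmoid_f32(const float * src, float * dst, int num_elems) {
    for (int i = 0; i < num_elems; i++) {
        float x = src[i];
        if (x < -87.f) x = -87.f;
        if (x >  87.f) x =  87.f;
        dst[i] = 1.f / (1.f + expf(-x));
    }
}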
// static inline void hvx_sigmoid_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems){
// int step_of_1 = num_elems >> 5; // divby 32, because 32 float = 128 bytes per HVX vector
// int leftover = num_elems - (step_of_1 * VLEN_FP32);
// // assert(remaining == 0);//TODO: handle remaining elements later
// int32_t leftover_size = leftover * sizeof(float);
// static const float kMinExp = -87.f; // 0
// static const float kMaxExp = 87.f; // 1
// const HVX_Vector one = hvx_vec_splat_fp32(1.f);
// const HVX_Vector max_exp = hvx_vec_splat_fp32(kMaxExp);
// const HVX_Vector min_exp = hvx_vec_splat_fp32(kMinExp);
// const float *input = (float *)src;
// float *output = (float *)dst;
// HVX_UVector * input_v_ptr = (HVX_UVector *) input;
// HVX_UVector * output_v_ptr = (HVX_UVector *) output;
// // #pragma unroll(4) NOTE: this actually got slower
// for(uint32_t i = step_of_1; i> 0; i--){
// *((HVX_UVector *)(output_v_ptr++)) = hvx_vec_fast_sigmoid_fp32_guard(*(input_v_ptr++), one, max_exp, min_exp);
// }
// if(leftover> 0){
// HVX_Vector sout = hvx_vec_fast_sigmoid_fp32_guard(*(input_v_ptr++), one, max_exp, min_exp);
// /* Store output */
// hvx_vec_store_u(output_v_ptr, leftover_size, sout);
// }
// }
float hvx_sum_of_squares_f32(const uint8_t * restrict src, const int num_elems);
void hvx_mul_f32(const uint8_t * restrict src0,
const uint8_t * restrict src1,