This commit is contained in:
Alfred 2025-12-17 12:02:28 +08:00 committed by GitHub
commit a7fd5a0a76
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 137 additions and 6 deletions

View File

@ -92,6 +92,18 @@ static const uint8_t __attribute__((aligned(128))) repl_1x_fp16[128] = {
0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
}; };
// vdelta control to replicate first fp16 value across all elements
static const uint8_t __attribute__((aligned(128))) repl_2x_fp16[128] = {
0x00, 0x00, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
0x20, 0x20, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
0x00, 0x00, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
0x20, 0x20, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02,
};
// vdelta control to expand first 32 e8m0 values into 32 uint32 elements // vdelta control to expand first 32 e8m0 values into 32 uint32 elements
static const uint8_t __attribute__((aligned(128))) expand_x32_e8m0[128] = { static const uint8_t __attribute__((aligned(128))) expand_x32_e8m0[128] = {
0x00, 0x00, 0x00, 0x00, 0x01, 0x04, 0x00, 0x00, 0x02, 0x00, 0x08, 0x08, 0x01, 0x02, 0x00, 0x04, 0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x04, 0x00, 0x00, 0x02, 0x00, 0x08, 0x08, 0x01, 0x02, 0x00, 0x04, 0x04, 0x00, 0x00,
@ -1593,6 +1605,119 @@ static void matmul_f16_f32(struct htp_tensor * restrict src0,
} }
// *** dynamic quant // *** dynamic quant
typedef void (*quantize_block_fp32_q8_func_t)(float * restrict x, uint8_t * restrict y_q, uint8_t * restrict y_d);
static inline void quantize_block_fp32_q8x1(float * restrict x, uint8_t * restrict y_q, uint8_t * restrict y_d) {
assert((unsigned long) x % 128 == 0);
assert((unsigned long) y_q % 128 == 0);
HVX_Vector * vx = (HVX_Vector *) x;
HVX_Vector zero = Q6_V_vsplat_R(0);
// Use reduce max fp32 to find max(abs(e)) first
HVX_Vector vmax0_sf = hvx_vec_reduce_max_fp32(hvx_vec_abs_fp32(vx[0]));
HVX_Vector vmax1_sf = hvx_vec_reduce_max_fp32(hvx_vec_abs_fp32(vx[1]));
HVX_Vector vmax2_sf = hvx_vec_reduce_max_fp32(hvx_vec_abs_fp32(vx[2]));
HVX_Vector vmax3_sf = hvx_vec_reduce_max_fp32(hvx_vec_abs_fp32(vx[3]));
// Load and convert into QF32
HVX_Vector vx0_qf = Q6_Vqf32_vsub_VsfVsf(vx[0], zero); // 32 elements
HVX_Vector vx1_qf = Q6_Vqf32_vsub_VsfVsf(vx[1], zero); // 32 elements
HVX_Vector vx2_qf = Q6_Vqf32_vsub_VsfVsf(vx[2], zero); // 32 elements
HVX_Vector vx3_qf = Q6_Vqf32_vsub_VsfVsf(vx[3], zero); // 32 elements
// Convert to QF32
HVX_Vector vmax0_qf = Q6_Vqf32_vsub_VsfVsf(vmax0_sf, zero);
HVX_Vector vmax1_qf = Q6_Vqf32_vsub_VsfVsf(vmax1_sf, zero);
HVX_Vector vmax2_qf = Q6_Vqf32_vsub_VsfVsf(vmax2_sf, zero);
HVX_Vector vmax3_qf = Q6_Vqf32_vsub_VsfVsf(vmax3_sf, zero);
// Combine and convert to fp16
HVX_Vector vmax01_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vmax1_qf, vmax0_qf)));
HVX_Vector vmax23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vmax3_qf, vmax2_qf)));
// Convert into fp16
HVX_Vector vx01_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx1_qf, vx0_qf)));
HVX_Vector vx23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx3_qf, vx2_qf)));
// Replicate first fp16 scale across all lanes
HVX_Vector ctrl = *(const HVX_Vector *) repl_2x_fp16;
vmax01_hf = Q6_V_vdelta_VV(vmax01_hf, ctrl);
vmax23_hf = Q6_V_vdelta_VV(vmax23_hf, ctrl);
HVX_Vector vd01_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax01_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0
HVX_Vector vd23_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax23_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0
HVX_Vector vd01_hf = Q6_Vhf_equals_Vqf16(vd01_qf16);
HVX_Vector vd23_hf = Q6_Vhf_equals_Vqf16(vd23_qf16);
hvx_vec_store_u(y_d + 0, 2, vd01_hf);
HVX_Vector rotated_vd_hf = Q6_V_vror_VR(vd01_hf, 64);
hvx_vec_store_u(y_d + 2, 2, rotated_vd_hf);
hvx_vec_store_u(y_d + 4, 2, vd23_hf);
rotated_vd_hf = Q6_V_vror_VR(vd23_hf, 64);
hvx_vec_store_u(y_d + 6, 2, rotated_vd_hf);
// Divide input by the scale
HVX_Vector vd01_inv_hf = hvx_vec_inverse_fp16(vd01_hf);
HVX_Vector vd23_inv_hf = hvx_vec_inverse_fp16(vd23_hf);
vx01_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx01_hf, vd01_inv_hf));
vx23_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx23_hf, vd23_inv_hf));
// Convert to int8
HVX_Vector vx01_i16 = hvx_vec_i16_from_hf_rnd_sat(vx01_hf);
HVX_Vector vx23_i16 = hvx_vec_i16_from_hf_rnd_sat(vx23_hf);
HVX_Vector vx_i8 = Q6_Vb_vpack_VhVh_sat(vx23_i16, vx01_i16);
*(HVX_Vector *) y_q = vx_i8;
}
static inline void quantize_block_fp32_q8x2(float * restrict x, uint8_t * restrict y_q, uint8_t * restrict y_d) {
assert((unsigned long) x % 128 == 0);
assert((unsigned long) y_q % 128 == 0);
HVX_Vector * vx = (HVX_Vector *) x;
// Load and convert into QF32
HVX_Vector zero = Q6_V_vsplat_R(0);
HVX_Vector vx0_qf = Q6_Vqf32_vsub_VsfVsf(vx[0], zero); // 32 elements
HVX_Vector vx1_qf = Q6_Vqf32_vsub_VsfVsf(vx[1], zero); // 32 elements
HVX_Vector vx2_qf = Q6_Vqf32_vsub_VsfVsf(vx[2], zero); // 32 elements
HVX_Vector vx3_qf = Q6_Vqf32_vsub_VsfVsf(vx[3], zero); // 32 elements
// Convert into fp16
HVX_Vector vx01_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx1_qf, vx0_qf)));
HVX_Vector vx23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx3_qf, vx2_qf)));
// Compute max and scale
HVX_Vector vmax01_hf = hvx_vec_reduce_max_fp16(hvx_vec_abs_fp16(vx01_hf));
HVX_Vector vmax23_hf = hvx_vec_reduce_max_fp16(hvx_vec_abs_fp16(vx23_hf));
// Replicate first fp16 scale across all lanes
HVX_Vector ctrl = *(const HVX_Vector *) repl_1x_fp16;
vmax01_hf = Q6_V_vdelta_VV(vmax01_hf, ctrl);
vmax23_hf = Q6_V_vdelta_VV(vmax23_hf, ctrl);
HVX_Vector vd01_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax01_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0
HVX_Vector vd23_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax23_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0
HVX_Vector vd01_hf = Q6_Vhf_equals_Vqf16(vd01_qf16);
HVX_Vector vd23_hf = Q6_Vhf_equals_Vqf16(vd23_qf16);
hvx_vec_store_u(y_d + 0, 4, vd01_hf);
hvx_vec_store_u(y_d + 4, 4, vd23_hf);
// Divide input by the scale
HVX_Vector vd01_inv_hf = hvx_vec_inverse_fp16(vd01_hf);
HVX_Vector vd23_inv_hf = hvx_vec_inverse_fp16(vd23_hf);
vx01_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx01_hf, vd01_inv_hf));
vx23_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx23_hf, vd23_inv_hf));
// Convert to int8
HVX_Vector vx01_i16 = hvx_vec_i16_from_hf_rnd_sat(vx01_hf);
HVX_Vector vx23_i16 = hvx_vec_i16_from_hf_rnd_sat(vx23_hf);
HVX_Vector vx_i8 = Q6_Vb_vpack_VhVh_sat(vx23_i16, vx01_i16);
*(HVX_Vector *) y_q = vx_i8;
}
static inline void quantize_block_fp32_q8x4(float * restrict x, uint8_t * restrict y_q, uint8_t * restrict y_d) { static inline void quantize_block_fp32_q8x4(float * restrict x, uint8_t * restrict y_q, uint8_t * restrict y_d) {
assert((unsigned long) x % 128 == 0); assert((unsigned long) x % 128 == 0);
@ -1638,7 +1763,7 @@ static inline void quantize_block_fp32_q8x4(float * restrict x, uint8_t * restri
} }
// Overrides input x // Overrides input x
static void quantize_row_fp32_q8x4x2(float * restrict x, uint8_t * restrict y, uint32_t k) { static void quantize_row_fp32_q8x4x2(float * restrict x, uint8_t * restrict y, uint32_t k, quantize_block_fp32_q8_func_t quantize_row_func) {
assert(k % 32 == 0); assert(k % 32 == 0);
const uint32_t qk = QK_Q8_0x4x2; const uint32_t qk = QK_Q8_0x4x2;
const uint32_t nb = (k + qk - 1) / qk; const uint32_t nb = (k + qk - 1) / qk;
@ -1655,9 +1780,9 @@ static void quantize_row_fp32_q8x4x2(float * restrict x, uint8_t * restrict y, u
uint8_t * restrict t_d = (uint8_t *) x; uint8_t * restrict t_d = (uint8_t *) x;
for (uint32_t i = 0; i < nb; i++) { for (uint32_t i = 0; i < nb; i++) {
quantize_block_fp32_q8x4(x + (i * 2 + 0) * qk / 2, y_q + (i * 2 + 0) * qblk_size / 2, quantize_row_func(x + (i * 2 + 0) * qk / 2, y_q + (i * 2 + 0) * qblk_size / 2,
t_d + (i * 2 + 0) * dblk_size / 2); t_d + (i * 2 + 0) * dblk_size / 2);
quantize_block_fp32_q8x4(x + (i * 2 + 1) * qk / 2, y_q + (i * 2 + 1) * qblk_size / 2, quantize_row_func(x + (i * 2 + 1) * qk / 2, y_q + (i * 2 + 1) * qblk_size / 2,
t_d + (i * 2 + 1) * dblk_size / 2); t_d + (i * 2 + 1) * dblk_size / 2);
} }
@ -1670,7 +1795,9 @@ static void quantize_fp32_q8x4x2(const struct htp_tensor * src,
struct htp_spad * spad, struct htp_spad * spad,
uint32_t nth, uint32_t nth,
uint32_t ith, uint32_t ith,
uint32_t nrows_per_thread) { uint32_t nrows_per_thread
, quantize_block_fp32_q8_func_t quantize_row_func) {
uint64_t t1 = HAP_perf_get_qtimer_count(); uint64_t t1 = HAP_perf_get_qtimer_count();
const uint32_t ne0 = src->ne[0]; const uint32_t ne0 = src->ne[0];
@ -1698,7 +1825,7 @@ static void quantize_fp32_q8x4x2(const struct htp_tensor * src,
hvx_copy_fp32_aa(tmp_data, src_data, ne0); hvx_copy_fp32_aa(tmp_data, src_data, ne0);
// FARF(HIGH, "quantize-q8x4-row: %u\n", i); // FARF(HIGH, "quantize-q8x4-row: %u\n", i);
quantize_row_fp32_q8x4x2((float *) tmp_data, dst_data, ne0); quantize_row_fp32_q8x4x2((float *) tmp_data, dst_data, ne0, quantize_row_func);
dst_data += dst_row_size; dst_data += dst_row_size;
src_data += src_row_size; src_data += src_row_size;
} }
@ -1711,7 +1838,11 @@ static void quantize_fp32_q8x4x2(const struct htp_tensor * src,
static void htp_quantize_fp32_q8x4x2(unsigned int n, unsigned int i, void * data) { static void htp_quantize_fp32_q8x4x2(unsigned int n, unsigned int i, void * data) {
struct htp_ops_context * octx = data; struct htp_ops_context * octx = data;
quantize_fp32_q8x4x2(&octx->src1, octx->src1_spad.data, &octx->src0_spad, n, i, octx->src1_nrows_per_thread); // quantize_block_fp32_q8x4: use group size 128: tested on Qwen3:0.6B k proj layer 0 on 256 tokens, Relative L2 1.7%
// quantize_block_fp32_q8x2: use group size 64 : Relative L2 1.3%
// quantize_block_fp32_q8x1: use group size 32 : Relative L2 1.1%
quantize_block_fp32_q8_func_t quantize_row_func = quantize_block_fp32_q8x1;
quantize_fp32_q8x4x2(&octx->src1, octx->src1_spad.data, &octx->src0_spad, n, i, octx->src1_nrows_per_thread, quantize_row_func);
} }
// ** matmul callbacks for worker_pool // ** matmul callbacks for worker_pool