hexmm: implement vec_dot_rx2x2 for Q8_0 and MXFP4

Max Krasnyansky 2026-02-05 19:00:20 -08:00
parent ff0a674947
commit c71da3a3f2
1 changed file with 329 additions and 24 deletions


@@ -600,15 +600,11 @@ static void vec_dot_q4x4x2_q8x4x2_rx2x2(const int n,
}
// Reduce and store results
- r0_c0_sum = hvx_vec_reduce_sum_f32(r0_c0_sum);
- r0_c1_sum = hvx_vec_reduce_sum_f32(r0_c1_sum);
- r1_c0_sum = hvx_vec_reduce_sum_f32(r1_c0_sum);
- r1_c1_sum = hvx_vec_reduce_sum_f32(r1_c1_sum);
+ HVX_Vector r0_r1_c0_sum = hvx_vec_reduce_sum_f32x2(r0_c0_sum, r1_c0_sum);
+ HVX_Vector r0_r1_c1_sum = hvx_vec_reduce_sum_f32x2(r0_c1_sum, r1_c1_sum);
- hvx_vec_store_u(&s0[0], 4, r0_c0_sum); // row0, col0
- hvx_vec_store_u(&s0[1], 4, r1_c0_sum); // row1, col0
- hvx_vec_store_u(&s1[0], 4, r0_c1_sum); // row0, col1
- hvx_vec_store_u(&s1[1], 4, r1_c1_sum); // row1, col1
+ hvx_vec_store_u(&s0[0], 8, r0_r1_c0_sum); // row0,col0 row1,col0
+ hvx_vec_store_u(&s1[0], 8, r0_r1_c1_sum); // row0,col1 row1,col1
}
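For reference, a scalar sketch of what the reworked tail computes, treating each 32-lane HVX accumulator as a plain float array and spelling the horizontal reduction out as a loop (the names here are illustrative, not helpers from this file):
// Illustrative scalar equivalent of the paired reduce + 8-byte store (sketch only).
static void store_2x2_tile_ref(float * s0, float * s1,
                               const float r0_c0[32], const float r1_c0[32],
                               const float r0_c1[32], const float r1_c1[32]) {
    float sum[4] = { 0, 0, 0, 0 };
    for (int i = 0; i < 32; i++) {
        sum[0] += r0_c0[i]; sum[1] += r1_c0[i];
        sum[2] += r0_c1[i]; sum[3] += r1_c1[i];
    }
    s0[0] = sum[0]; s0[1] = sum[1]; // column 0: row0 then row1, written by one 8-byte store in the HVX version
    s1[0] = sum[2]; s1[1] = sum[3]; // column 1: row0 then row1
}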
static void vec_dot_q8x4x2_q8x4x2(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
@@ -782,6 +778,141 @@ static void vec_dot_q8x4x2_q8x4x2_rx2(const int n,
hvx_vec_store_u(&s[0], 8, rsum);
}
static void vec_dot_q8x4x2_q8x4x2_rx2x2(const int n,
float * restrict s0,
float * restrict s1,
const void * restrict vx,
uint32_t vx_row_size,
const void * restrict vy0,
const void * restrict vy1) {
assert(n % 32 == 0); // min sub-block size
assert((unsigned long) vx % 128 == 0);
assert((unsigned long) vy0 % 128 == 0);
assert((unsigned long) vy1 % 128 == 0);
const uint32_t qk = QK_Q8_0x4x2 * 4;
const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16
const uint32_t x_qblk_size = qk; // int8
const uint32_t x_qrow_size = n; // int8 (not padded)
const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16
const uint32_t y_qblk_size = qk; // int8
const uint32_t y_qrow_size = n; // int8 (not padded)
const uint8_t * restrict r0_x_q = ((const uint8_t *) (vx + (0 * vx_row_size)) + 0); // quants first
const uint8_t * restrict r0_x_d = ((const uint8_t *) (vx + (0 * vx_row_size)) + x_qrow_size); // then scales
const uint8_t * restrict r1_x_q = ((const uint8_t *) (vx + (1 * vx_row_size)) + 0); // quants first
const uint8_t * restrict r1_x_d = ((const uint8_t *) (vx + (1 * vx_row_size)) + x_qrow_size); // then scales
const uint8_t * restrict y0_q = ((const uint8_t *) vy0 + 0); // quants first
const uint8_t * restrict y0_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales
const uint8_t * restrict y1_q = ((const uint8_t *) vy1 + 0); // quants first
const uint8_t * restrict y1_d = ((const uint8_t *) vy1 + y_qrow_size); // then scales
// Row sums (sf) - 4 accumulators for 2×2 tile
HVX_Vector r0_c0_sum = Q6_V_vsplat_R(0);
HVX_Vector r0_c1_sum = Q6_V_vsplat_R(0);
HVX_Vector r1_c0_sum = Q6_V_vsplat_R(0);
HVX_Vector r1_c1_sum = Q6_V_vsplat_R(0);
const uint32_t nb = n / qk; // num full blocks
const uint32_t nloe = n % qk; // num leftover elements
uint32_t i = 0;
for (; i < nb; i++) {
// Load src1 columns (reused across both src0 rows)
HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8(y0_q + i * y_qblk_size);
HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8(y1_q + i * y_qblk_size);
// Load src0 rows (reused across both src1 columns)
HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8(r0_x_q + i * x_qblk_size);
HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8(r1_x_q + i * x_qblk_size);
// Compute 4 dot products: r0×c0, r0×c1, r1×c0, r1×c1
HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy0_q));
HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy1_q));
HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy0_q));
HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy1_q));
// Load scales
HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size));
HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size));
HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
// Compute combined scales
HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy0_d)));
HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy1_d)));
HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d)));
HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d)));
// Apply scales and accumulate
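// (the vmpy products are qf32; each accumulator stays in IEEE sf and is converted back after every add)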
HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd);
HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd);
HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd);
HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd);
r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum));
r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum));
r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum));
r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum));
}
// Process leftovers
if (nloe) {
HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8(y0_q + i * y_qblk_size);
HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8(y1_q + i * y_qblk_size);
HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8(r0_x_q + i * x_qblk_size);
HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8(r1_x_q + i * x_qblk_size);
HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy0_q, nloe));
HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy1_q, nloe));
HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r1_q, vy0_q, nloe));
HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r1_q, vy1_q, nloe));
HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size));
HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size));
HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy0_d)));
HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy1_d)));
HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d)));
HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d)));
// Zero out unused scales
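// nloe/8 bytes == nloe/32 word lanes: each fp32 lane of the rmpy result covers one 32-element sub-block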
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
r0_c0_dd = Q6_V_vand_QV(bmask, r0_c0_dd);
r0_c1_dd = Q6_V_vand_QV(bmask, r0_c1_dd);
r1_c0_dd = Q6_V_vand_QV(bmask, r1_c0_dd);
r1_c1_dd = Q6_V_vand_QV(bmask, r1_c1_dd);
r0_c0_ia = Q6_V_vand_QV(bmask, r0_c0_ia);
r0_c1_ia = Q6_V_vand_QV(bmask, r0_c1_ia);
r1_c0_ia = Q6_V_vand_QV(bmask, r1_c0_ia);
r1_c1_ia = Q6_V_vand_QV(bmask, r1_c1_ia);
HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd);
HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd);
HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd);
HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd);
r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum));
r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum));
r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum));
r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum));
}
// Reduce and store results
HVX_Vector r0_r1_c0_sum = hvx_vec_reduce_sum_f32x2(r0_c0_sum, r1_c0_sum);
HVX_Vector r0_r1_c1_sum = hvx_vec_reduce_sum_f32x2(r0_c1_sum, r1_c1_sum);
hvx_vec_store_u(&s0[0], 8, r0_r1_c0_sum); // row0,col0 row1,col0
hvx_vec_store_u(&s1[0], 8, r0_r1_c1_sum); // row0,col1 row1,col1
}
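To make the 2x2 tiling easier to follow, here is a hedged scalar model of what vec_dot_q8x4x2_q8x4x2_rx2x2 computes. It ignores the q8x4x2 interleaved packing (a dot product is order-independent as long as both operands use the same permutation) and assumes one float scale per 32-element sub-block, the layout implied by x_qrow_size/x_dblk_size above; all names below are illustrative:
#include <stdint.h>

// Scalar reference for the 2x2 tile: rows x0/x1 of src0 against columns y0/y1 of src1.
// x*_d and y*_d hold one (already converted) float scale per 32-element sub-block.
static void vec_dot_rx2x2_ref(int n,
                              const int8_t * x0, const float * x0_d,
                              const int8_t * x1, const float * x1_d,
                              const int8_t * y0, const float * y0_d,
                              const int8_t * y1, const float * y1_d,
                              float * s0, float * s1) {
    float r0c0 = 0, r0c1 = 0, r1c0 = 0, r1c1 = 0;
    for (int b = 0; b < n / 32; b++) {
        int32_t i00 = 0, i01 = 0, i10 = 0, i11 = 0;
        for (int k = 0; k < 32; k++) {
            const int j = b * 32 + k;
            i00 += x0[j] * y0[j]; i01 += x0[j] * y1[j];
            i10 += x1[j] * y0[j]; i11 += x1[j] * y1[j];
        }
        r0c0 += i00 * x0_d[b] * y0_d[b]; r0c1 += i01 * x0_d[b] * y1_d[b];
        r1c0 += i10 * x1_d[b] * y0_d[b]; r1c1 += i11 * x1_d[b] * y1_d[b];
    }
    s0[0] = r0c0; s0[1] = r1c0; // dst row for src1 column 0: src0 row0, then row1
    s1[0] = r0c1; s1[1] = r1c1; // dst row for src1 column 1: src0 row0, then row1
}
The HVX version amortizes the loads: each pair of quant rows and columns is loaded once and reused for all four accumulators, which is the point of the 2x2 tiling.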
static void vec_dot_mxfp4x4x2_q8x4x2(const int n,
float * restrict s,
const void * restrict vx,
@@ -1022,6 +1153,181 @@ static void vec_dot_mxfp4x4x2_q8x4x2_rx2(const int n,
hvx_vec_store_u(&s[0], 8, rsum);
}
static void vec_dot_mxfp4x4x2_q8x4x2_rx2x2(const int n,
float * restrict s0,
float * restrict s1,
const void * restrict vx,
uint32_t vx_row_size,
const void * restrict vy0,
const void * restrict vy1) {
assert(n % 32 == 0); // min sub-block size
assert((unsigned long) vx % 128 == 0);
assert((unsigned long) vy0 % 128 == 0);
assert((unsigned long) vy1 % 128 == 0);
const uint32_t qk = QK_MXFP4x4x2 * 4;
const uint32_t x_dblk_size = 8 * 4 * 1; // 32x e8m0
const uint32_t x_qblk_size = qk / 2; // fp4
const uint32_t x_qrow_size = n / 2; // fp4 (not padded)
const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16
const uint32_t y_qblk_size = qk; // int8
const uint32_t y_qrow_size = n; // int8 (not padded)
const uint8_t * restrict r0_x_q = ((const uint8_t *) (vx + (0 * vx_row_size)) + 0); // quants first
const uint8_t * restrict r0_x_d = ((const uint8_t *) (vx + (0 * vx_row_size)) + x_qrow_size); // then scales
const uint8_t * restrict r1_x_q = ((const uint8_t *) (vx + (1 * vx_row_size)) + 0); // quants first
const uint8_t * restrict r1_x_d = ((const uint8_t *) (vx + (1 * vx_row_size)) + x_qrow_size); // then scales
const uint8_t * restrict y0_q = ((const uint8_t *) vy0 + 0); // quants first
const uint8_t * restrict y0_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales
const uint8_t * restrict y1_q = ((const uint8_t *) vy1 + 0); // quants first
const uint8_t * restrict y1_d = ((const uint8_t *) vy1 + y_qrow_size); // then scales
// Row sums (sf) - 4 accumulators for 2×2 tile
HVX_Vector r0_c0_sum = Q6_V_vsplat_R(0);
HVX_Vector r0_c1_sum = Q6_V_vsplat_R(0);
HVX_Vector r1_c0_sum = Q6_V_vsplat_R(0);
HVX_Vector r1_c1_sum = Q6_V_vsplat_R(0);
const uint32_t nb = n / qk; // num full blocks
const uint32_t nloe = n % qk; // num leftover elements
uint32_t i = 0;
for (; i < nb; i++) {
// Load src1 columns (reused across both src0 rows)
HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8(y0_q + i * y_qblk_size);
HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8(y1_q + i * y_qblk_size);
// Load src0 rows (reused across both src1 columns)
HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8(r0_x_q + i * x_qblk_size);
HVX_Vector_x8 r1_q = hvx_vec_load_mxfp4x4x8(r1_x_q + i * x_qblk_size);
// Compute 4 dot products: r0×c0, r0×c1, r1×c0, r1×c1
HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy0_q));
HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy1_q));
HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy0_q));
HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy1_q));
// Load scales
HVX_Vector vy0_d = *(const HVX_UVector *) (y0_d + i * y_dblk_size);
HVX_Vector vy1_d = *(const HVX_UVector *) (y1_d + i * y_dblk_size);
HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size);
HVX_Vector r1_d = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size);
// Convert vy_d from fp16 to fp32 while applying 0.5 scaling which is used for e8m0 halving
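// (assumption: the fp4 LUT behind hvx_vec_load_mxfp4x4x8 stores values at 2x, like ggml-cpu's kvalues_mxfp4, so the 0.5 restores the magnitude)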
HVX_Vector half = Q6_Vh_vsplat_R(0x3800); // 0.5 in fp16
vy0_d = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy0_d), half));
vy0_d = Q6_Vsf_equals_Vqf32(vy0_d);
vy1_d = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy1_d), half));
vy1_d = Q6_Vsf_equals_Vqf32(vy1_d);
// Convert rX_d scales from e8m0 to fp32
// Expand and zero-pad 32x uint8 e8m0 values to uint32s : 0 0 0 0, 0 0 0 1, 0 0 0 2, ...
// Left shift with zero fill to create FP32
// FIXME: might need to handle zero as a special case (see ggml-cpu code)
HVX_Vector expand = *(const HVX_Vector *) expand_x32_e8m0;
HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff);
r0_d = Q6_V_vdelta_VV(r0_d, expand);
r0_d = Q6_V_vand_VV(r0_d, e8m0_mask);
r0_d = Q6_Vw_vasl_VwR(r0_d, 23);
r1_d = Q6_V_vdelta_VV(r1_d, expand);
r1_d = Q6_V_vand_VV(r1_d, e8m0_mask);
r1_d = Q6_Vw_vasl_VwR(r1_d, 23);
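// each word lane now holds the fp32 bit pattern of 2^(e-127); e8m0 == 0 maps to +0.0f here (see FIXME above)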
// Compute combined scales
HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy0_d));
HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy1_d));
HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy0_d));
HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy1_d));
// Apply scales and accumulate
HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd);
HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd);
HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd);
HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd);
r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum));
r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum));
r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum));
r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum));
}
// Process leftovers
if (nloe) {
HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8(y0_q + i * y_qblk_size);
HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8(y1_q + i * y_qblk_size);
HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8(r0_x_q + i * x_qblk_size);
HVX_Vector_x8 r1_q = hvx_vec_load_mxfp4x4x8(r1_x_q + i * x_qblk_size);
HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy0_q, nloe));
HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy1_q, nloe));
HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r1_q, vy0_q, nloe));
HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r1_q, vy1_q, nloe));
HVX_Vector vy0_d = *(const HVX_UVector *) (y0_d + i * y_dblk_size);
HVX_Vector vy1_d = *(const HVX_UVector *) (y1_d + i * y_dblk_size);
HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size);
HVX_Vector r1_d = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size);
// Convert vy_d from fp16 to fp32 while applying 0.5 scaling which is used for e8m0 halving
HVX_Vector half = Q6_Vh_vsplat_R(0x3800); // 0.5 in fp16
vy0_d = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy0_d), half));
vy0_d = Q6_Vsf_equals_Vqf32(vy0_d);
vy1_d = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy1_d), half));
vy1_d = Q6_Vsf_equals_Vqf32(vy1_d);
// Convert rX_d scales from e8m0 to fp32
// Expand and zero-pad 32x uint8 e8m0 values to uint32s : 0 0 0 0, 0 0 0 1, 0 0 0 2, ...
// Left shift with zero fill to create FP32
// FIXME: might need to handle zero as a special case (see ggml-cpu code)
HVX_Vector expand = *(const HVX_Vector *) expand_x32_e8m0;
HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff);
r0_d = Q6_V_vdelta_VV(r0_d, expand);
r0_d = Q6_V_vand_VV(r0_d, e8m0_mask);
r0_d = Q6_Vw_vasl_VwR(r0_d, 23);
r1_d = Q6_V_vdelta_VV(r1_d, expand);
r1_d = Q6_V_vand_VV(r1_d, e8m0_mask);
r1_d = Q6_Vw_vasl_VwR(r1_d, 23);
HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy0_d));
HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy1_d));
HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy0_d));
HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy1_d));
// Zero out unused scales
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
r0_c0_dd = Q6_V_vand_QV(bmask, r0_c0_dd);
r0_c1_dd = Q6_V_vand_QV(bmask, r0_c1_dd);
r1_c0_dd = Q6_V_vand_QV(bmask, r1_c0_dd);
r1_c1_dd = Q6_V_vand_QV(bmask, r1_c1_dd);
r0_c0_ia = Q6_V_vand_QV(bmask, r0_c0_ia);
r0_c1_ia = Q6_V_vand_QV(bmask, r0_c1_ia);
r1_c0_ia = Q6_V_vand_QV(bmask, r1_c0_ia);
r1_c1_ia = Q6_V_vand_QV(bmask, r1_c1_ia);
HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd);
HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd);
HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd);
HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd);
r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum));
r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum));
r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum));
r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum));
}
// Reduce and store results
HVX_Vector r0_r1_c0_sum = hvx_vec_reduce_sum_f32x2(r0_c0_sum, r1_c0_sum);
HVX_Vector r0_r1_c1_sum = hvx_vec_reduce_sum_f32x2(r0_c1_sum, r1_c1_sum);
hvx_vec_store_u(&s0[0], 8, r0_r1_c0_sum); // row0,col0 row1,col0
hvx_vec_store_u(&s1[0], 8, r0_r1_c1_sum); // row0,col1 row1,col1
}
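The e8m0 handling above leans on the fact that an e8m0 scale is exactly an fp32 exponent field: widening each byte to a word and shifting it left by 23 yields the bit pattern of 2^(e-127). A scalar sketch of the same trick, including the edge case flagged by the FIXME (e == 0 decodes to +0.0f this way, while a strict e8m0 decode gives 2^-127, which is why ggml-cpu special-cases it):
#include <stdint.h>
#include <string.h>

// Scalar sketch of the e8m0 -> fp32 conversion done with vdelta/vand/vasl above.
static float e8m0_to_fp32_ref(uint8_t e) {
    uint32_t bits = (uint32_t) e << 23; // e8m0 byte becomes the fp32 exponent field
    float f;
    memcpy(&f, &bits, sizeof(f));       // f == 2^(e - 127) for 1 <= e <= 254
    return f;                           // e == 0 -> +0.0f, e == 255 -> +Inf
}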
static void vec_dot_f16_f16_aa(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
const HVX_Vector * restrict x = (const HVX_Vector *) vx;
const HVX_Vector * restrict y = (const HVX_Vector *) vy;
@@ -1361,17 +1667,14 @@ static void matmul_2d(struct htp_matmul_type * mt, struct htp_ops_context * octx
// Process src1 columns in pairs (2×2 tiling)
uint32_t ir1 = 0;
if (mt->vec_dot_rx2x2) {
for (; ir1 + 1 < src1_nrows; ir1 += 2) {
const uint8_t * restrict src1_col0 = (const uint8_t *) (src1_data + (ir1+0) * src1_stride);
const uint8_t * restrict src1_col1 = (const uint8_t *) (src1_data + (ir1+1) * src1_stride);
float * restrict dst_row0 = (float *) (dst->data + ((ir1+0) * dst_row_size));
float * restrict dst_row1 = (float *) (dst->data + ((ir1+1) * dst_row_size));
mt->vec_dot_rx2x2(ne00, &dst_row0[ir0], &dst_row1[ir0], ss0, src0_stride, src1_col0, src1_col1);
}
// Handle remaining src1 rows (fallback to 2×1)
for (; ir1 < src1_nrows; ++ir1) {
const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + ir1 * src1_stride);
@@ -2078,9 +2381,10 @@ static void htp_matmul_2d_q8x4x2_q8x4x2(unsigned int n, unsigned int i, void * d
struct htp_ops_context * octx = data;
struct htp_matmul_type mt;
mt.type = "q8x4x2-q8x4x2";
mt.vec_dot = vec_dot_q8x4x2_q8x4x2;
mt.vec_dot_rx2 = vec_dot_q8x4x2_q8x4x2_rx2;
mt.vec_dot_rx2x2 = vec_dot_q8x4x2_q8x4x2_rx2x2;
matmul_2d(&mt, octx, n, i);
}
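The registration assumes struct htp_matmul_type gained a vec_dot_rx2x2 function-pointer member; the struct itself is defined elsewhere, but from the kernel definitions above its signature is presumably:
// Presumed signature of the new 2x2-tile entry point (sketch; matches the definitions above).
typedef void (*htp_vec_dot_rx2x2_t)(int n,
                                    float * s0, float * s1,                 // dst rows for src1 col0 / col1
                                    const void * vx, uint32_t vx_row_size,  // first src0 row plus stride to the second
                                    const void * vy0, const void * vy1);    // two packed src1 columns
matmul_2d only takes the 2x2 path when this pointer is set, and an odd trailing src1 row falls back to the 2x1 kernel.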
@@ -2100,9 +2404,10 @@ static void htp_matmul_2d_mxfp4x4x2_q8x4x2(unsigned int n, unsigned int i, void
struct htp_ops_context * octx = data;
struct htp_matmul_type mt;
mt.type = "mxfp4x4x2-q8x4x2";
mt.vec_dot = vec_dot_mxfp4x4x2_q8x4x2;
mt.vec_dot_rx2 = vec_dot_mxfp4x4x2_q8x4x2_rx2;
mt.vec_dot_rx2x2 = vec_dot_mxfp4x4x2_q8x4x2_rx2x2;
matmul_2d(&mt, octx, n, i);
}