ggml-hexagon: implement 2x2 matmul kernel

Trivikram Reddy 2026-02-05 08:37:02 -08:00
parent 2ceda3f662
commit ff0a674947
1 changed file with 159 additions and 5 deletions
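The new vec_dot_rx2x2 path computes a 2x2 tile of the output per call: two adjacent src0 rows against two adjacent src1 columns, so each quantized row fetched via DMA and each quantized src1 column feeds two dot products instead of one. A rough scalar model of the tile (plain float, ignoring the quantized layout and the HVX intrinsics; the names below are illustrative, not taken from the patch):

// Conceptual reference for one 2x2 tile: src0 rows ir0/ir0+1 vs src1 columns ir1/ir1+1.
// x0/x1 and y0/y1 stand for already-dequantized rows/columns of length n.
static void tile_2x2_ref(int n,
                         const float * x0, const float * x1,
                         const float * y0, const float * y1,
                         float * d00, float * d10,   // dst column ir1:   rows ir0, ir0+1 (s0[0], s0[1])
                         float * d01, float * d11) { // dst column ir1+1: rows ir0, ir0+1 (s1[0], s1[1])
    float s00 = 0.0f, s01 = 0.0f, s10 = 0.0f, s11 = 0.0f;
    for (int k = 0; k < n; k++) {
        s00 += x0[k] * y0[k];
        s01 += x0[k] * y1[k];
        s10 += x1[k] * y0[k];
        s11 += x1[k] * y1[k];
    }
    *d00 = s00; *d10 = s10;
    *d01 = s01; *d11 = s11;
}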


@@ -27,6 +27,7 @@ struct htp_matmul_type {
const char * type;
void (*vec_dot)(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
void (*vec_dot_rx2)(const int n, float * restrict s, const void * restrict vx, uint32_t vx_row_size, const void * restrict vy);
void (*vec_dot_rx2x2)(const int n, float * restrict s0, float * restrict s1, const void * restrict vx, uint32_t vx_row_size, const void * restrict vy0, const void * restrict vy1);
};
// vdelta control to replicate first 4x fp32 values across lanes
@@ -471,6 +472,145 @@ static void vec_dot_q4x4x2_q8x4x2_rx2(const int n,
hvx_vec_store_u(&s[0], 8, rsum);
}
static void vec_dot_q4x4x2_q8x4x2_rx2x2(const int n,
float * restrict s0,
float * restrict s1,
const void * restrict vx,
uint32_t vx_row_size,
const void * restrict vy0,
const void * restrict vy1) {
assert(n % 32 == 0); // min sub-block size
assert((unsigned long) vx % 128 == 0);
assert((unsigned long) vy0 % 128 == 0);
assert((unsigned long) vy1 % 128 == 0);
const uint32_t qk = QK_Q4_0x4x2 * 4;
const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16
const uint32_t x_qblk_size = qk / 2; // int4
const uint32_t x_qrow_size = n / 2; // int4 (not padded)
const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16
const uint32_t y_qblk_size = qk; // int8
const uint32_t y_qrow_size = n; // int8 (not padded)
const uint8_t * restrict r0_x_q = ((const uint8_t *) (vx + (0 * vx_row_size)) + 0); // quants first
const uint8_t * restrict r0_x_d = ((const uint8_t *) (vx + (0 * vx_row_size)) + x_qrow_size); // then scales
const uint8_t * restrict r1_x_q = ((const uint8_t *) (vx + (1 * vx_row_size)) + 0); // quants first
const uint8_t * restrict r1_x_d = ((const uint8_t *) (vx + (1 * vx_row_size)) + x_qrow_size); // then scales
const uint8_t * restrict y0_q = ((const uint8_t *) vy0 + 0); // quants first
const uint8_t * restrict y0_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales
const uint8_t * restrict y1_q = ((const uint8_t *) vy1 + 0); // quants first
const uint8_t * restrict y1_d = ((const uint8_t *) vy1 + y_qrow_size); // then scales
// Row sums (sf) - 4 accumulators for 2×2 tile
HVX_Vector r0_c0_sum = Q6_V_vsplat_R(0);
HVX_Vector r0_c1_sum = Q6_V_vsplat_R(0);
HVX_Vector r1_c0_sum = Q6_V_vsplat_R(0);
HVX_Vector r1_c1_sum = Q6_V_vsplat_R(0);
const uint32_t nb = n / qk; // num full blocks
const uint32_t nloe = n % qk; // num leftover elements
uint32_t i = 0;
for (; i < nb; i++) {
// Load src1 columns (reused across both src0 rows)
HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8(y0_q + i * y_qblk_size);
HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8(y1_q + i * y_qblk_size);
// Load src0 rows (reused across both src1 columns)
HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8(r0_x_q + i * x_qblk_size);
HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8(r1_x_q + i * x_qblk_size);
// Compute 4 dot products: r0×c0, r0×c1, r1×c0, r1×c1
HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy0_q));
HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy1_q));
HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy0_q));
HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy1_q));
// Load scales
HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size));
HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size));
HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
// Compute combined scales
HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy0_d)));
HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy1_d)));
HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d)));
HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d)));
// Apply scales and accumulate
HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd);
HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd);
HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd);
HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd);
r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum));
r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum));
r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum));
r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum));
}
// Process leftovers
if (nloe) {
HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8(y0_q + i * y_qblk_size);
HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8(y1_q + i * y_qblk_size);
HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8(r0_x_q + i * x_qblk_size);
HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8(r1_x_q + i * x_qblk_size);
HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy0_q, nloe));
HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy1_q, nloe));
HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r1_q, vy0_q, nloe));
HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r1_q, vy1_q, nloe));
HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size));
HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size));
HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy0_d)));
HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy1_d)));
HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d)));
HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d)));
// Zero out unused scales
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
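// vsetq enables the first nloe/8 bytes, i.e. the first nloe/32 32-bit lanes, so the ANDs below keep only the block sums and scales backed by real leftover elements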
r0_c0_dd = Q6_V_vand_QV(bmask, r0_c0_dd);
r0_c1_dd = Q6_V_vand_QV(bmask, r0_c1_dd);
r1_c0_dd = Q6_V_vand_QV(bmask, r1_c0_dd);
r1_c1_dd = Q6_V_vand_QV(bmask, r1_c1_dd);
r0_c0_ia = Q6_V_vand_QV(bmask, r0_c0_ia);
r0_c1_ia = Q6_V_vand_QV(bmask, r0_c1_ia);
r1_c0_ia = Q6_V_vand_QV(bmask, r1_c0_ia);
r1_c1_ia = Q6_V_vand_QV(bmask, r1_c1_ia);
HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd);
HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd);
HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd);
HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd);
r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum));
r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum));
r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum));
r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum));
}
// Reduce and store results
r0_c0_sum = hvx_vec_reduce_sum_f32(r0_c0_sum);
r0_c1_sum = hvx_vec_reduce_sum_f32(r0_c1_sum);
r1_c0_sum = hvx_vec_reduce_sum_f32(r1_c0_sum);
r1_c1_sum = hvx_vec_reduce_sum_f32(r1_c1_sum);
hvx_vec_store_u(&s0[0], 4, r0_c0_sum); // row0, col0
hvx_vec_store_u(&s0[1], 4, r1_c0_sum); // row1, col0
hvx_vec_store_u(&s1[0], 4, r0_c1_sum); // row0, col1
hvx_vec_store_u(&s1[1], 4, r1_c1_sum); // row1, col1
}
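Per 32-element sub-block, each accumulator lane above holds the integer dot product of the quantized values scaled by the product of the two block scales. A scalar sketch of that per-block contribution, assuming the usual Q4_0/Q8_0 convention of one fp16 scale per 32 elements (qx stands for the already-unpacked signed 4-bit values; names are illustrative):

// One 32-element sub-block, as accumulated into the r*_c*_sum vectors:
// (sum of int4*int8 products) converted to float, then scaled by dx*dy.
static inline float block_dot_ref(const int8_t qx[32], float dx,
                                  const int8_t qy[32], float dy) {
    int32_t isum = 0;
    for (int k = 0; k < 32; k++) {
        isum += (int32_t) qx[k] * (int32_t) qy[k];
    }
    return (float) isum * dx * dy;
}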
static void vec_dot_q8x4x2_q8x4x2(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
assert(n % 32 == 0); // min sub-block size
assert((unsigned long) vx % 128 == 0);
@@ -1219,8 +1359,21 @@ static void matmul_2d(struct htp_matmul_type * mt, struct htp_ops_context * octx
for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) {
const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
// Process src1 columns in pairs (2×2 tiling)
uint32_t ir1 = 0;
if (mt->vec_dot_rx2x2) {
for (; ir1 + 1 < src1_nrows; ir1 += 2) {
const uint8_t * restrict src1_col0 = (const uint8_t *) (src1_data + (ir1+0) * src1_stride);
const uint8_t * restrict src1_col1 = (const uint8_t *) (src1_data + (ir1+1) * src1_stride);
float * restrict dst_row0 = (float *) (dst->data + ((ir1+0) * dst_row_size));
float * restrict dst_row1 = (float *) (dst->data + ((ir1+1) * dst_row_size));
mt->vec_dot_rx2x2(ne00, &dst_row0[ir0], &dst_row1[ir0], ss0, src0_stride, src1_col0, src1_col1);
}
}
// Handle remaining src1 rows (fallback to 2×1)
for (; ir1 < src1_nrows; ++ir1) {
const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + ir1 * src1_stride);
float * restrict dst_row = (float *) (dst->data + (ir1 * dst_row_size));
mt->vec_dot_rx2(ne00, &dst_row[ir0], ss0, src0_stride, src1_col);
@@ -1902,9 +2055,10 @@ static void htp_matmul_2d_q4x4x2_q8x4x2(unsigned int n, unsigned int i, void * d
struct htp_ops_context * octx = data;
struct htp_matmul_type mt;
mt.type = "q4x4x2-q8x4x2";
mt.vec_dot = vec_dot_q4x4x2_q8x4x2;
mt.vec_dot_rx2 = vec_dot_q4x4x2_q8x4x2_rx2;
mt.vec_dot_rx2x2 = vec_dot_q4x4x2_q8x4x2_rx2x2;
matmul_2d(&mt, octx, n, i);
}