ggml-hexagon: implement 2x2 matmul kernel
commit ff0a674947 (parent 2ceda3f662)
@@ -27,6 +27,7 @@ struct htp_matmul_type {
     const char * type;
     void (*vec_dot)(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
     void (*vec_dot_rx2)(const int n, float * restrict s, const void * restrict vx, uint32_t vx_row_size, const void * restrict vy);
+    void (*vec_dot_rx2x2)(const int n, float * restrict s0, float * restrict s1, const void * restrict vx, uint32_t vx_row_size, const void * restrict vy0, const void * restrict vy1);
 };
 
 // vdelta control to replicate first 4x fp32 values across lanes
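For orientation, the contract of the new vec_dot_rx2x2 hook can be modelled in plain C: one call produces a 2×2 tile of the output, dotting two adjacent src0 rows against two src1 columns, so every loaded operand is reused twice. The sketch below uses float elements and a hypothetical helper name; the actual kernel added in this commit (next hunk) operates on quantized q4x4x2/q8x4x2 rows.

```c
#include <stdint.h>

// Scalar model of the vec_dot_rx2x2 contract (illustrative only; the real
// kernel consumes quantized q4x4x2/q8x4x2 rows, not floats).
static void vec_dot_rx2x2_ref(const int n,
                              float * restrict s0, float * restrict s1,
                              const void * restrict vx, uint32_t vx_row_size,
                              const void * restrict vy0, const void * restrict vy1) {
    const float * x0 = (const float *) vx;                                   // src0 row 0
    const float * x1 = (const float *) ((const uint8_t *) vx + vx_row_size); // src0 row 1
    const float * y0 = (const float *) vy0;                                  // src1 col 0
    const float * y1 = (const float *) vy1;                                  // src1 col 1

    float r0c0 = 0.0f, r0c1 = 0.0f, r1c0 = 0.0f, r1c1 = 0.0f;
    for (int i = 0; i < n; i++) {
        r0c0 += x0[i] * y0[i]; // row 0 x col 0
        r0c1 += x0[i] * y1[i]; // row 0 x col 1: x0[i] load reused
        r1c0 += x1[i] * y0[i]; // row 1 x col 0: y0[i] load reused
        r1c1 += x1[i] * y1[i]; // row 1 x col 1
    }

    s0[0] = r0c0; s0[1] = r1c0; // two consecutive dst elements of column 0
    s1[0] = r0c1; s1[1] = r1c1; // two consecutive dst elements of column 1
}
```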
@@ -471,6 +472,145 @@ static void vec_dot_q4x4x2_q8x4x2_rx2(const int n,
     hvx_vec_store_u(&s[0], 8, rsum);
 }
 
+static void vec_dot_q4x4x2_q8x4x2_rx2x2(const int n,
+                                        float * restrict s0,
+                                        float * restrict s1,
+                                        const void * restrict vx,
+                                        uint32_t vx_row_size,
+                                        const void * restrict vy0,
+                                        const void * restrict vy1) {
+    assert(n % 32 == 0); // min sub-block size
+    assert((unsigned long) vx % 128 == 0);
+    assert((unsigned long) vy0 % 128 == 0);
+    assert((unsigned long) vy1 % 128 == 0);
+
+    const uint32_t qk = QK_Q4_0x4x2 * 4;
+
+    const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16
+    const uint32_t x_qblk_size = qk / 2;    // int4
+    const uint32_t x_qrow_size = n / 2;     // int4 (not padded)
+
+    const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16
+    const uint32_t y_qblk_size = qk;        // int8
+    const uint32_t y_qrow_size = n;         // int8 (not padded)
+
+    const uint8_t * restrict r0_x_q = ((const uint8_t *) (vx + (0 * vx_row_size)) + 0);           // quants first
+    const uint8_t * restrict r0_x_d = ((const uint8_t *) (vx + (0 * vx_row_size)) + x_qrow_size); // then scales
+
+    const uint8_t * restrict r1_x_q = ((const uint8_t *) (vx + (1 * vx_row_size)) + 0);           // quants first
+    const uint8_t * restrict r1_x_d = ((const uint8_t *) (vx + (1 * vx_row_size)) + x_qrow_size); // then scales
+
+    const uint8_t * restrict y0_q = ((const uint8_t *) vy0 + 0);           // quants first
+    const uint8_t * restrict y0_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales
+
+    const uint8_t * restrict y1_q = ((const uint8_t *) vy1 + 0);           // quants first
+    const uint8_t * restrict y1_d = ((const uint8_t *) vy1 + y_qrow_size); // then scales
+
+    // Row sums (sf) - 4 accumulators for 2×2 tile
+    HVX_Vector r0_c0_sum = Q6_V_vsplat_R(0);
+    HVX_Vector r0_c1_sum = Q6_V_vsplat_R(0);
+    HVX_Vector r1_c0_sum = Q6_V_vsplat_R(0);
+    HVX_Vector r1_c1_sum = Q6_V_vsplat_R(0);
+
+    const uint32_t nb   = n / qk; // num full blocks
+    const uint32_t nloe = n % qk; // num leftover elements
+
+    uint32_t i = 0;
+    for (; i < nb; i++) {
+        // Load src1 columns (reused across both src0 rows)
+        HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8(y0_q + i * y_qblk_size);
+        HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8(y1_q + i * y_qblk_size);
+
+        // Load src0 rows (reused across both src1 columns)
+        HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8(r0_x_q + i * x_qblk_size);
+        HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8(r1_x_q + i * x_qblk_size);
+
+        // Compute 4 dot products: r0×c0, r0×c1, r1×c0, r1×c1
+        HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy0_q));
+        HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy1_q));
+        HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy0_q));
+        HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy1_q));
+
+        // Load scales
+        HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size));
+        HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size));
+        HVX_Vector r0_d  = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
+        HVX_Vector r1_d  = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
+
+        // Compute combined scales
+        HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy0_d)));
+        HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy1_d)));
+        HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d)));
+        HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d)));
+
+        // Apply scales and accumulate
+        HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd);
+        HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd);
+        HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd);
+        HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd);
+
+        r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum));
+        r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum));
+        r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum));
+        r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum));
+    }
+
+    // Process leftovers
+    if (nloe) {
+        HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8(y0_q + i * y_qblk_size);
+        HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8(y1_q + i * y_qblk_size);
+        HVX_Vector_x8 r0_q  = hvx_vec_load_q4x4x8(r0_x_q + i * x_qblk_size);
+        HVX_Vector_x8 r1_q  = hvx_vec_load_q4x4x8(r1_x_q + i * x_qblk_size);
+
+        HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy0_q, nloe));
+        HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy1_q, nloe));
+        HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r1_q, vy0_q, nloe));
+        HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r1_q, vy1_q, nloe));
+
+        HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size));
+        HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size));
+        HVX_Vector r0_d  = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size));
+        HVX_Vector r1_d  = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size));
+
+        HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy0_d)));
+        HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy1_d)));
+        HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d)));
+        HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d)));
+
+        // Zero out unused scales and sums (nloe/32 valid fp32 lanes x 4 bytes = nloe/8 bytes)
+        HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
+        r0_c0_dd = Q6_V_vand_QV(bmask, r0_c0_dd);
+        r0_c1_dd = Q6_V_vand_QV(bmask, r0_c1_dd);
+        r1_c0_dd = Q6_V_vand_QV(bmask, r1_c0_dd);
+        r1_c1_dd = Q6_V_vand_QV(bmask, r1_c1_dd);
+        r0_c0_ia = Q6_V_vand_QV(bmask, r0_c0_ia);
+        r0_c1_ia = Q6_V_vand_QV(bmask, r0_c1_ia);
+        r1_c0_ia = Q6_V_vand_QV(bmask, r1_c0_ia);
+        r1_c1_ia = Q6_V_vand_QV(bmask, r1_c1_ia);
+
+        HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd);
+        HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd);
+        HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd);
+        HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd);
+
+        r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum));
+        r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum));
+        r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum));
+        r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum));
+    }
+
+    // Reduce and store results
+    r0_c0_sum = hvx_vec_reduce_sum_f32(r0_c0_sum);
+    r0_c1_sum = hvx_vec_reduce_sum_f32(r0_c1_sum);
+    r1_c0_sum = hvx_vec_reduce_sum_f32(r1_c0_sum);
+    r1_c1_sum = hvx_vec_reduce_sum_f32(r1_c1_sum);
+
+    hvx_vec_store_u(&s0[0], 4, r0_c0_sum); // row0, col0
+    hvx_vec_store_u(&s0[1], 4, r1_c0_sum); // row1, col0
+    hvx_vec_store_u(&s1[0], 4, r0_c1_sum); // row0, col1
+    hvx_vec_store_u(&s1[1], 4, r1_c1_sum); // row1, col1
+}
+
 static void vec_dot_q8x4x2_q8x4x2(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
     assert(n % 32 == 0); // min sub-block size
     assert((unsigned long) vx % 128 == 0);
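The size constants at the top of the kernel can be sanity-checked numerically. The snippet below assumes QK_Q4_0x4x2 = 256, so qk = 1024 elements per iteration (32 fp16 scales per block times 32-element sub-blocks); that value is not visible in this diff and is only inferred from the x_dblk_size comment.

```c
#include <assert.h>
#include <stdio.h>

// Assumed: QK_Q4_0x4x2 == 256, giving qk = QK_Q4_0x4x2 * 4 = 1024.
// The real definition lives elsewhere in ggml-hexagon.
#define QK_Q4_0x4x2 256u

int main(void) {
    const unsigned n  = 4096;            // sample row length (multiple of 32, as asserted)
    const unsigned qk = QK_Q4_0x4x2 * 4; // elements consumed per loop iteration
    assert(n % 32 == 0);

    printf("x_dblk_size = %u bytes\n", 8u * 4u * 2u); // 64: 32x __fp16 scales
    printf("x_qblk_size = %u bytes\n", qk / 2);       // 512: int4 quants per block
    printf("x_qrow_size = %u bytes\n", n / 2);        // 2048: int4 quants per row
    printf("y_qblk_size = %u bytes\n", qk);           // 1024: int8 quants per block
    printf("nb   = %u full blocks\n", n / qk);        // 4
    printf("nloe = %u leftover elements\n", n % qk);  // 0
    return 0;
}
```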
@@ -1219,8 +1359,21 @@ static void matmul_2d(struct htp_matmul_type * mt, struct htp_ops_context * octx
     for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) {
         const uint8_t * ss0 = dma_queue_pop(dma_queue).dst;
 
-        #pragma unroll(2)
-        for (uint32_t ir1 = 0; ir1 < src1_nrows; ++ir1) {
+        // Process src1 columns in pairs (2×2 tiling)
+        uint32_t ir1 = 0;
+        if (mt->vec_dot_rx2x2) {
+            for (; ir1 + 1 < src1_nrows; ir1 += 2) {
+                const uint8_t * restrict src1_col0 = (const uint8_t *) (src1_data + (ir1+0) * src1_stride);
+                const uint8_t * restrict src1_col1 = (const uint8_t *) (src1_data + (ir1+1) * src1_stride);
+                float * restrict dst_row0 = (float *) (dst->data + ((ir1+0) * dst_row_size));
+                float * restrict dst_row1 = (float *) (dst->data + ((ir1+1) * dst_row_size));
+
+                mt->vec_dot_rx2x2(ne00, &dst_row0[ir0], &dst_row1[ir0], ss0, src0_stride, src1_col0, src1_col1);
+            }
+        }
+
+        // Handle remaining src1 rows (fallback to 2×1)
+        for (; ir1 < src1_nrows; ++ir1) {
             const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + ir1 * src1_stride);
             float * restrict dst_row = (float *) (dst->data + (ir1 * dst_row_size));
             mt->vec_dot_rx2(ne00, &dst_row[ir0], ss0, src0_stride, src1_col);
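The control flow of the new dispatch reduces to a pair-then-remainder pattern: consume full column pairs with the 2×2 kernel, then fall back to the 2×1 kernel for an odd trailing column (or for all columns when no 2×2 kernel is registered). A standalone sketch with the kernels stubbed out (function names here are illustrative, not from the file):

```c
#include <stdint.h>
#include <stdio.h>

static void r2x2(uint32_t c0, uint32_t c1) { printf("2x2 tile: cols %u,%u\n", c0, c1); }
static void r2x1(uint32_t c)               { printf("2x1 tile: col %u\n", c); }

static void dispatch(uint32_t ncols, int have_2x2) {
    uint32_t c = 0;
    if (have_2x2) {
        for (; c + 1 < ncols; c += 2) { r2x2(c, c + 1); } // full column pairs
    }
    for (; c < ncols; ++c) { r2x1(c); }                   // odd remainder, or no 2x2 kernel
}

int main(void) {
    dispatch(5, 1); // two 2x2 tiles, then one 2x1 fallback for the odd fifth column
    return 0;
}
```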
@@ -1902,9 +2055,10 @@ static void htp_matmul_2d_q4x4x2_q8x4x2(unsigned int n, unsigned int i, void * d
     struct htp_ops_context * octx = data;
 
     struct htp_matmul_type mt;
-    mt.type = "q4x4x2-q8x4x2";
-    mt.vec_dot = vec_dot_q4x4x2_q8x4x2;
-    mt.vec_dot_rx2 = vec_dot_q4x4x2_q8x4x2_rx2;
+    mt.type = "q4x4x2-q8x4x2";
+    mt.vec_dot = vec_dot_q4x4x2_q8x4x2;
+    mt.vec_dot_rx2 = vec_dot_q4x4x2_q8x4x2_rx2;
+    mt.vec_dot_rx2x2 = vec_dot_q4x4x2_q8x4x2_rx2x2;
 
     matmul_2d(&mt, octx, n, i);
 }
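Since matmul_2d now branches on mt->vec_dot_rx2x2 and mt is a local struct, matmul types that do not provide a 2×2 kernel need the member to be NULL, for example by zero-initializing the struct. A self-contained sketch of that dispatch convention (the struct here is a trimmed stand-in, not the real htp_matmul_type):

```c
#include <stdio.h>

// Trimmed stand-in for struct htp_matmul_type, just for this example.
struct htp_matmul_type_sketch {
    const char * type;
    void (*vec_dot_rx2x2)(void); // stand-in signature
};

int main(void) {
    struct htp_matmul_type_sketch mt = {0}; // zero-init leaves vec_dot_rx2x2 NULL
    mt.type = "q8x4x2-q8x4x2";
    // matmul_2d-style dispatch: NULL means "no 2x2 kernel, take the 2x1 path"
    printf("%s: %s\n", mt.type, mt.vec_dot_rx2x2 ? "2x2 tiling" : "2x1 fallback");
    return 0;
}
```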