From 0a0a0108ceca4537bc676e0642b97f32c7f1fb78 Mon Sep 17 00:00:00 2001 From: "Vithule, Prashant" Date: Mon, 5 Jan 2026 03:56:06 +0000 Subject: [PATCH] Updated repack.cpp --- ggml/src/ggml-cpu/arch/arm/repack.cpp | 383 +++++++++++++++++++++++++- 1 file changed, 382 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-cpu/arch/arm/repack.cpp b/ggml/src/ggml-cpu/arch/arm/repack.cpp index 99bb70274c..105c4e5564 100644 --- a/ggml/src/ggml-cpu/arch/arm/repack.cpp +++ b/ggml/src/ggml-cpu/arch/arm/repack.cpp @@ -3038,7 +3038,388 @@ void ggml_gemm_q4_K_8x8_q8_K(int n, UNUSED(ncols_interleaved); UNUSED(blocklen); -#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8) +#if defined(__aarch64__) && defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8) + constexpr int q8_k_blocklen = 4; + const uint8x16_t m4b = vdupq_n_u8(0x0f); + const svuint8_t m4b_1 = svdup_n_u8(0x0f); + // 8 accumulators: 2 row pairs × 4 col pairs + // printf("Ai is going to"); + float32x4_t acc_f32[blocklen]; + svfloat32_t acc_f32_0, acc_f32_1, acc_f32_2, acc_f32_3, acc_f32_4, acc_f32_5, acc_f32_6, acc_f32_7; + uint32_t idx_arr[8] = {0,2,1,3,4,6,5,7}; + svbool_t pg = svptrue_pat_b32(SV_VL8); + + svuint32_t idx = svld1(pg, idx_arr); + for (int y = 0; y < nr / q8_k_blocklen; y++) { + const block_q8_Kx4 * GGML_RESTRICT q8_ptr = (const block_q8_Kx4 *) vy + (y * nb); + const block_q8_Kx4 * GGML_RESTRICT q8_ptr_1 = (const block_q8_Kx4 *) vy + (y * nb); + + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q4_Kx8 * GGML_RESTRICT q4_ptr = (const block_q4_Kx8 *) vx + (x * nb); + const block_q4_Kx8 * GGML_RESTRICT q4_ptr_1 = (const block_q4_Kx8 *) vx + (x * nb); + // for (int i = 0; i < blocklen; i++) { + // acc_f32[i] = vdupq_n_f32(0); + // } + acc_f32_0 = svdup_n_f32(0); + acc_f32_1 = svdup_n_f32(0); + acc_f32_2 = svdup_n_f32(0); + acc_f32_3 = svdup_n_f32(0); + acc_f32_4 = svdup_n_f32(0); + acc_f32_5 = svdup_n_f32(0); + acc_f32_6 = svdup_n_f32(0); + acc_f32_7 = svdup_n_f32(0); + + for (int b = 0; b < nb; b++) { + // bsums pairs belongs to the same q8_k subblock // 64 elemnts loaded and made sum of 0-7 and 8-15 sum || 16-23 and 24 - 31 sum + const int16x8_t bsums[4]{ + vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 0), vld1q_s16(q8_ptr[b].bsums + 16 * 0 + 8)), + vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 1), vld1q_s16(q8_ptr[b].bsums + 16 * 1 + 8)), + vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 2), vld1q_s16(q8_ptr[b].bsums + 16 * 2 + 8)), + vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 3), vld1q_s16(q8_ptr[b].bsums + 16 * 3 + 8)), + }; + + int32_t bsums_arr32[4][8]; + + for (int q8_row = 0; q8_row < 4; q8_row++) { + int16x8_t v16 = bsums[q8_row]; + + // low 4 + int32x4_t v32_lo = vmovl_s16(vget_low_s16(v16)); + vst1q_s32(&bsums_arr32[q8_row][0], v32_lo); + + // high 4 + int32x4_t v32_hi = vmovl_s16(vget_high_s16(v16)); + vst1q_s32(&bsums_arr32[q8_row][4], v32_hi); + } + + svint32_t sb_acc_0 = svdup_n_s32(0); + svint32_t sb_acc_1 = svdup_n_s32(0); + svint32_t sb_acc_2 = svdup_n_s32(0); + svint32_t sb_acc_3 = svdup_n_s32(0); + + svint32_t acc_0 = svdup_n_s32(0); + svint32_t acc_1 = svdup_n_s32(0); + svint32_t acc_2 = svdup_n_s32(0); + svint32_t acc_3 = svdup_n_s32(0); + svint32_t acc_4 = svdup_n_s32(0); + svint32_t acc_5 = svdup_n_s32(0); + svint32_t acc_6 = svdup_n_s32(0); + svint32_t acc_7 = svdup_n_s32(0); + + svint32_t bias_acc_0 = svdup_n_s32(0); + svint32_t bias_acc_1 = svdup_n_s32(0); + svint32_t bias_acc_2 = svdup_n_s32(0); + svint32_t bias_acc_3 = svdup_n_s32(0); + svint32_t bias_acc_4 = svdup_n_s32(0); + svint32_t bias_acc_5 = svdup_n_s32(0); + svint32_t bias_acc_6 = svdup_n_s32(0); + svint32_t bias_acc_7 = svdup_n_s32(0); + + for (int sb = 0; sb < QK_K / 64; sb++) { + // Need scales for the low and high nibbles + // 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total + int8_t q4sb_scales_1[2][8]; + svint32_t q4sb_mins_0_0, q4sb_mins_0_1, q4sb_mins_1_0, q4sb_mins_1_1; + for (int i = 0; i < 2; i++) { // 2-superblock I am working on + const int offset = sb * 24 + i * 12; + const uint8_t * scales_in = &q4_ptr_1[b].scales[offset]; + + constexpr uint32_t kmask1 = 0x3f3f3f3f; + constexpr uint32_t kmask2 = 0x0f0f0f0f; + constexpr uint32_t kmask3 = 0x03030303; + constexpr uint8_t scales_size = 12; + + uint32_t sm[3]; + memcpy(sm, scales_in, scales_size); + + const uint32_t mins_0_3 = sm[1] & kmask1; + const uint32_t mins_4_7 = ((sm[2] >> 4) & kmask2) | (((sm[1] >> 6) & kmask3) << 4); + + svuint32_t mins_u321 = svdup_n_u32(mins_0_3); + svuint32_t mins_u322 = svdup_n_u32(mins_4_7); + /* reinterpret u32 → u8 */ + svuint8_t mins_u81 = svreinterpret_u8_u32(mins_u321); + svuint8_t mins_u82 = svreinterpret_u8_u32(mins_u322); + + /* widen u8 → u16 (lower half only) */ + svuint32_t mins_u161 = svunpklo_u32(svunpklo_u16(mins_u81)); + svuint32_t mins_u162 = svunpklo_u32(svunpklo_u16(mins_u82)); + + /* reinterpret u16 → s16 */ + if(i == 0) { + q4sb_mins_0_0 = svreinterpret_s32_u32(mins_u161); + q4sb_mins_0_1 = svreinterpret_s32_u32(mins_u162); + } else { + q4sb_mins_1_0 = svreinterpret_s32_u32(mins_u161); + q4sb_mins_1_1 = svreinterpret_s32_u32(mins_u162); + } + + uint32_t scales_u32[2]; + scales_u32[0] = sm[0] & kmask1; + scales_u32[1] = (sm[2] & kmask2) | (((sm[0] >> 6) & kmask3) << 4); + memcpy(q4sb_scales_1[i], scales_u32, 8); + } + + // q8_ptr[b].qs has interleaved Q8 rows (01, 23) + const int8_t * q8_base = q8_ptr[b].qs + sb * 256; + const int8_t * q8_base_1 = q8_ptr_1[b].qs + sb * 256; + + const svbool_t ph16 = svptrue_pat_b8(SV_VL16); + // predicate for activating lower lanes for 16 int8 elements + const svbool_t pl16 = svnot_b_z(svptrue_b8(), ph16); + + svint8_t q8s_00 = svld1_s8(ph16, q8_base_1 + 0); + svint8_t q8s_01 = svld1_s8(ph16, q8_base_1 + 32); + svint8_t q8s_02 = svld1_s8(ph16, q8_base_1 + 64); + svint8_t q8s_03 = svld1_s8(ph16, q8_base_1 + 96); + svint8_t q8s_04 = svld1_s8(ph16, q8_base_1 + 128); + svint8_t q8s_05 = svld1_s8(ph16, q8_base_1 + 160); + svint8_t q8s_06 = svld1_s8(ph16, q8_base_1 + 192); + svint8_t q8s_07 = svld1_s8(ph16, q8_base_1 + 224); + + svint8_t q8s_10 = svld1_s8(ph16, q8_base_1 + 16); + svint8_t q8s_11 = svld1_s8(ph16, q8_base_1 + 48); + svint8_t q8s_12 = svld1_s8(ph16, q8_base_1 + 80); + svint8_t q8s_13 = svld1_s8(ph16, q8_base_1 + 112); + svint8_t q8s_14 = svld1_s8(ph16, q8_base_1 + 144); + svint8_t q8s_15 = svld1_s8(ph16, q8_base_1 + 176); + svint8_t q8s_16 = svld1_s8(ph16, q8_base_1 + 208); + svint8_t q8s_17 = svld1_s8(ph16, q8_base_1 + 240); + + // Q4s columns iterated in pairs (01, 23, 45, 67) + for (int cp = 0; cp < ncols_interleaved / 2; cp++) { + + sb_acc_0 = svdup_n_s32(0); + sb_acc_1 = svdup_n_s32(0); + sb_acc_2 = svdup_n_s32(0); + sb_acc_3 = svdup_n_s32(0); + + svbool_t pg = svptrue_pat_b8(SV_VL16); + + svuint8_t q4_qs_cp_00 = svld1_u8(pg, q4_ptr_1[b].qs + sb * QK_K + 16 * cp + 0); + svuint8_t q4_qs_cp_10 = svld1_u8(pg, q4_ptr_1[b].qs + sb * QK_K + 16 * cp + 64); + svuint8_t q4_qs_cp_20 = svld1_u8(pg, q4_ptr_1[b].qs + sb * QK_K + 16 * cp + 128); + svuint8_t q4_qs_cp_30 = svld1_u8(pg, q4_ptr_1[b].qs + sb * QK_K + 16 * cp + 192); + + svint8_t q4_nibbles_0_0 = svreinterpret_s8_u8(svand_u8_m(pg, q4_qs_cp_00, m4b_1)); + svint8_t q4_nibbles_0_1 = svreinterpret_s8_u8(svand_u8_m(pg, q4_qs_cp_10, m4b_1)); + svint8_t q4_nibbles_0_2 = svreinterpret_s8_u8(svand_u8_m(pg, q4_qs_cp_20, m4b_1)); + svint8_t q4_nibbles_0_3 = svreinterpret_s8_u8(svand_u8_m(pg, q4_qs_cp_30, m4b_1)); + + svint8_t q4_nibbles_1_0 = svreinterpret_s8_u8(svlsr_n_u8_m(pg, q4_qs_cp_00, 4)); + svint8_t q4_nibbles_1_1 = svreinterpret_s8_u8(svlsr_n_u8_m(pg, q4_qs_cp_10, 4)); + svint8_t q4_nibbles_1_2 = svreinterpret_s8_u8(svlsr_n_u8_m(pg, q4_qs_cp_20, 4)); + svint8_t q4_nibbles_1_3 = svreinterpret_s8_u8(svlsr_n_u8_m(pg, q4_qs_cp_30, 4)); + + sb_acc_0 = svmmla_s32(sb_acc_0, q4_nibbles_0_0, q8s_00); + sb_acc_0 = svmmla_s32(sb_acc_0, q4_nibbles_0_1, q8s_01); + sb_acc_0 = svmmla_s32(sb_acc_0, q4_nibbles_0_2, q8s_02); + sb_acc_0 = svmmla_s32(sb_acc_0, q4_nibbles_0_3, q8s_03); + + sb_acc_1 = svmmla_s32(sb_acc_1, q4_nibbles_1_0, q8s_04); + sb_acc_1 = svmmla_s32(sb_acc_1, q4_nibbles_1_1, q8s_05); + sb_acc_1 = svmmla_s32(sb_acc_1, q4_nibbles_1_2, q8s_06); + sb_acc_1 = svmmla_s32(sb_acc_1, q4_nibbles_1_3, q8s_07); + + sb_acc_2 = svmmla_s32(sb_acc_2, q4_nibbles_0_0, q8s_10); + sb_acc_2 = svmmla_s32(sb_acc_2, q4_nibbles_0_1, q8s_11); + sb_acc_2 = svmmla_s32(sb_acc_2, q4_nibbles_0_2, q8s_12); + sb_acc_2 = svmmla_s32(sb_acc_2, q4_nibbles_0_3, q8s_13); + + sb_acc_3 = svmmla_s32(sb_acc_3, q4_nibbles_1_0, q8s_14); + sb_acc_3 = svmmla_s32(sb_acc_3, q4_nibbles_1_1, q8s_15); + sb_acc_3 = svmmla_s32(sb_acc_3, q4_nibbles_1_2, q8s_16); + sb_acc_3 = svmmla_s32(sb_acc_3, q4_nibbles_1_3, q8s_17); + + // Scales[i] corresponds to column i + const int scale_offset = cp * 2; + + int32_t tmp[8] = { + (int32_t) q4sb_scales_1[0][scale_offset], + (int32_t) q4sb_scales_1[0][scale_offset], + (int32_t) q4sb_scales_1[0][scale_offset + 1], + (int32_t) q4sb_scales_1[0][scale_offset + 1], + }; + int32_t tmp1[4] = { + (int32_t) q4sb_scales_1[1][scale_offset], + (int32_t) q4sb_scales_1[1][scale_offset], + (int32_t) q4sb_scales_1[1][scale_offset + 1], + (int32_t) q4sb_scales_1[1][scale_offset + 1], + }; + + svint32_t block_scale = svld1_s32(svptrue_pat_b32(SV_VL4), tmp); + svint32_t block_scale1 = svld1_s32(svptrue_pat_b32(SV_VL4), tmp1); + + if(cp == 0) { + acc_0 = svmla_s32_m(svptrue_pat_b32(SV_VL4), acc_0, sb_acc_0, block_scale); + acc_4 = svmla_s32_m(svptrue_pat_b32(SV_VL4), acc_4, sb_acc_2, block_scale); + acc_0 = svmla_s32_m(svptrue_pat_b32(SV_VL4), acc_0, sb_acc_1, block_scale1); + acc_4 = svmla_s32_m(svptrue_pat_b32(SV_VL4), acc_4, sb_acc_3, block_scale1); + } + if(cp == 1) { + acc_1 = svmla_s32_m(svptrue_pat_b32(SV_VL4), acc_1, sb_acc_0, block_scale); + acc_5 = svmla_s32_m(svptrue_pat_b32(SV_VL4), acc_5, sb_acc_2, block_scale); + acc_1 = svmla_s32_m(svptrue_pat_b32(SV_VL4), acc_1, sb_acc_1, block_scale1); + acc_5 = svmla_s32_m(svptrue_pat_b32(SV_VL4), acc_5, sb_acc_3, block_scale1); + } + if(cp == 2) { + acc_2 = svmla_s32_m(svptrue_pat_b32(SV_VL4), acc_2, sb_acc_0, block_scale); + acc_6 = svmla_s32_m(svptrue_pat_b32(SV_VL4), acc_6, sb_acc_2, block_scale); + acc_2 = svmla_s32_m(svptrue_pat_b32(SV_VL4), acc_2, sb_acc_1, block_scale1); + acc_6 = svmla_s32_m(svptrue_pat_b32(SV_VL4), acc_6, sb_acc_3, block_scale1); + } + if(cp == 3) { + acc_3 = svmla_s32_m(svptrue_pat_b32(SV_VL4), acc_3, sb_acc_0, block_scale); + acc_7 = svmla_s32_m(svptrue_pat_b32(SV_VL4), acc_7, sb_acc_2, block_scale); + acc_3 = svmla_s32_m(svptrue_pat_b32(SV_VL4), acc_3, sb_acc_1, block_scale1); + acc_7 = svmla_s32_m(svptrue_pat_b32(SV_VL4), acc_7, sb_acc_3, block_scale1); + } + } + + bias_acc_0 = svmla_s32_m(svptrue_pat_b32(SV_VL4), bias_acc_0, svdup_n_s32(bsums_arr32[sb][0]), q4sb_mins_0_0); + bias_acc_0 = svmla_s32_m(svptrue_pat_b32(SV_VL4), bias_acc_0, svdup_n_s32(bsums_arr32[sb][1]), q4sb_mins_1_0); + bias_acc_1 = svmla_s32_m(svptrue_pat_b32(SV_VL4), bias_acc_1, svdup_n_s32(bsums_arr32[sb][0]), q4sb_mins_0_1); + bias_acc_1 = svmla_s32_m(svptrue_pat_b32(SV_VL4), bias_acc_1, svdup_n_s32(bsums_arr32[sb][1]), q4sb_mins_1_1); + + bias_acc_2 = svmla_s32_m(svptrue_pat_b32(SV_VL4), bias_acc_2, svdup_n_s32(bsums_arr32[sb][2]), q4sb_mins_0_0); + bias_acc_2 = svmla_s32_m(svptrue_pat_b32(SV_VL4), bias_acc_2, svdup_n_s32(bsums_arr32[sb][3]), q4sb_mins_1_0); + bias_acc_3 = svmla_s32_m(svptrue_pat_b32(SV_VL4), bias_acc_3, svdup_n_s32(bsums_arr32[sb][2]), q4sb_mins_0_1); + bias_acc_3 = svmla_s32_m(svptrue_pat_b32(SV_VL4), bias_acc_3, svdup_n_s32(bsums_arr32[sb][3]), q4sb_mins_1_1); + + bias_acc_4 = svmla_s32_m(svptrue_pat_b32(SV_VL4), bias_acc_4, svdup_n_s32(bsums_arr32[sb][4]), q4sb_mins_0_0); + bias_acc_4 = svmla_s32_m(svptrue_pat_b32(SV_VL4), bias_acc_4, svdup_n_s32(bsums_arr32[sb][5]), q4sb_mins_1_0); + bias_acc_5 = svmla_s32_m(svptrue_pat_b32(SV_VL4), bias_acc_5, svdup_n_s32(bsums_arr32[sb][4]), q4sb_mins_0_1); + bias_acc_5 = svmla_s32_m(svptrue_pat_b32(SV_VL4), bias_acc_5, svdup_n_s32(bsums_arr32[sb][5]), q4sb_mins_1_1); + + bias_acc_6 = svmla_s32_m(svptrue_pat_b32(SV_VL4), bias_acc_6, svdup_n_s32(bsums_arr32[sb][6]), q4sb_mins_0_0); + bias_acc_6 = svmla_s32_m(svptrue_pat_b32(SV_VL4), bias_acc_6, svdup_n_s32(bsums_arr32[sb][7]), q4sb_mins_1_0); + bias_acc_7 = svmla_s32_m(svptrue_pat_b32(SV_VL4), bias_acc_7, svdup_n_s32(bsums_arr32[sb][6]), q4sb_mins_0_1); + bias_acc_7 = svmla_s32_m(svptrue_pat_b32(SV_VL4), bias_acc_7, svdup_n_s32(bsums_arr32[sb][7]), q4sb_mins_1_1); + } // for sb + + svint32_t reorder_acc_0 = svtbl_s32(svtrn1_s32(acc_0, acc_1), idx); + svint32_t reorder_acc_1 = svtbl_s32(svtrn1_s32(acc_2, acc_3), idx); + svint32_t reorder_acc_2 = svtbl_s32(svtrn2_s32(acc_0, acc_1), idx); + svint32_t reorder_acc_3 = svtbl_s32(svtrn2_s32(acc_2, acc_3), idx); + // acc[4..7] + svint32_t reorder_acc_4 = svtbl_s32(svtrn1_s32(acc_4, acc_5), idx); + svint32_t reorder_acc_5 = svtbl_s32(svtrn1_s32(acc_6, acc_7), idx); + svint32_t reorder_acc_6 = svtbl_s32(svtrn2_s32(acc_4, acc_5), idx); + svint32_t reorder_acc_7 = svtbl_s32(svtrn2_s32(acc_6, acc_7), idx); + + // Predicate for exactly 4 lanes + svbool_t pg4 = svptrue_pat_b32(SV_VL4); + + // Broadcast q8 scalar + svfloat32_t q8_d = svdup_f32(q8_ptr[b].d[0]); + + // ---- dmins ---- + svfloat16_t q4_dmin_h = svld1_f16(svptrue_pat_b16(SV_VL4), (const __fp16 *)(q4_ptr[b].dmin + 0)); + q4_dmin_h = svzip1_f16(q4_dmin_h, q4_dmin_h); + svfloat32_t q4_dmin = svcvt_f32_f16_x(svptrue_pat_b32(SV_VL4), q4_dmin_h); + + svfloat16_t q4_dmin_h1 = svld1_f16(svptrue_pat_b16(SV_VL4), (const __fp16 *)(q4_ptr[b].dmin + 4)); + q4_dmin_h1 = svzip1_f16(q4_dmin_h1, q4_dmin_h1); + svfloat32_t q4_dmin1 = svcvt_f32_f16_x(svptrue_pat_b32(SV_VL4), q4_dmin_h1); + + //-----scale ----------------------- + svfloat16_t q4_d_h = svld1_f16(svptrue_pat_b16(SV_VL4), (const __fp16 *)(q4_ptr[b].d + 0)); + q4_d_h = svzip1_f16(q4_d_h, q4_d_h); + svfloat32_t q4_d = svcvt_f32_f16_x(svptrue_pat_b32(SV_VL4), q4_d_h); + + svfloat16_t q4_d_h1 = svld1_f16(svptrue_pat_b16(SV_VL4), (const __fp16 *)(q4_ptr[b].d + 4)); + q4_d_h1 = svzip1_f16(q4_d_h1, q4_d_h1); + svfloat32_t q4_d1 = svcvt_f32_f16_x(svptrue_pat_b32(SV_VL4), q4_d_h1); + + svfloat32_t scale = svmul_f32_x(pg4, q4_d, q8_d); + svfloat32_t dmins = svmul_f32_x(pg4, q4_dmin, q8_d); + + acc_f32_0 = svmls_f32_m(pg4, acc_f32_0, svcvt_f32_s32_m(svdup_n_f32(0), pg4, bias_acc_0), dmins); + acc_f32_0 = svmla_f32_m(pg4, acc_f32_0, svcvt_f32_s32_m(svdup_n_f32(0), pg4, reorder_acc_0), scale); + + scale = svmul_f32_x(pg4, q4_d1, q8_d); + dmins = svmul_f32_x(pg4, q4_dmin1, q8_d); + + acc_f32_1 = svmls_f32_m(pg4, acc_f32_1, svcvt_f32_s32_m(svdup_n_f32(0), pg4, bias_acc_1), dmins); + acc_f32_1 = svmla_f32_m(pg4, acc_f32_1, svcvt_f32_s32_m(svdup_n_f32(0), pg4, reorder_acc_1), scale); + + q8_d = svdup_f32(q8_ptr[b].d[1]); + + scale = svmul_f32_x(pg4, q4_d, q8_d); + dmins = svmul_f32_x(pg4, q4_dmin, q8_d); + + acc_f32_2 = svmls_f32_m(pg4, acc_f32_2, svcvt_f32_s32_m(svdup_n_f32(0), pg4, bias_acc_2), dmins); + acc_f32_2 = svmla_f32_m(pg4, acc_f32_2, svcvt_f32_s32_m(svdup_n_f32(0), pg4, reorder_acc_2), scale); + + scale = svmul_f32_x(pg4, q4_d1, q8_d); + dmins = svmul_f32_x(pg4, q4_dmin1, q8_d); + + acc_f32_3 = svmls_f32_m(pg4, acc_f32_3, svcvt_f32_s32_m(svdup_n_f32(0), pg4, bias_acc_3), dmins); + acc_f32_3 = svmla_f32_m(pg4, acc_f32_3, svcvt_f32_s32_m(svdup_n_f32(0), pg4, reorder_acc_3), scale); + + q8_d = svdup_f32(q8_ptr[b].d[2]); + + scale = svmul_f32_x(pg4, q4_d, q8_d); + dmins = svmul_f32_x(pg4, q4_dmin, q8_d); + + acc_f32_4 = svmls_f32_m(pg4, acc_f32_4, svcvt_f32_s32_m(svdup_n_f32(0), pg4, bias_acc_4), dmins); + acc_f32_4 = svmla_f32_m(pg4, acc_f32_4, svcvt_f32_s32_m(svdup_n_f32(0), pg4, reorder_acc_4), scale); + + scale = svmul_f32_x(pg4, q4_d1, q8_d); + dmins = svmul_f32_x(pg4, q4_dmin1, q8_d); + + acc_f32_5 = svmls_f32_m(pg4, acc_f32_5, svcvt_f32_s32_m(svdup_n_f32(0), pg4, bias_acc_5), dmins); + acc_f32_5 = svmla_f32_m(pg4, acc_f32_5, svcvt_f32_s32_m(svdup_n_f32(0), pg4, reorder_acc_5), scale); + + q8_d = svdup_f32(q8_ptr[b].d[3]); + + scale = svmul_f32_x(pg4, q4_d, q8_d); + dmins = svmul_f32_x(pg4, q4_dmin, q8_d); + + acc_f32_6 = svmls_f32_m(pg4, acc_f32_6, svcvt_f32_s32_m(svdup_n_f32(0), pg4, bias_acc_6), dmins); + acc_f32_6 = svmla_f32_m(pg4, acc_f32_6, svcvt_f32_s32_m(svdup_n_f32(0), pg4, reorder_acc_6), scale); + + scale = svmul_f32_x(pg4, q4_d1, q8_d); + dmins = svmul_f32_x(pg4, q4_dmin1, q8_d); + + acc_f32_7 = svmls_f32_m(pg4, acc_f32_7, svcvt_f32_s32_m(svdup_n_f32(0), pg4, bias_acc_7), dmins); + acc_f32_7 = svmla_f32_m(pg4, acc_f32_7, svcvt_f32_s32_m(svdup_n_f32(0), pg4, reorder_acc_7), scale); + + } // for b + + // With the previous reorder, the tile is already in the correct memory layout. + // Predicate for exactly 4 lanes + svbool_t pg4 = svptrue_pat_b32(SV_VL4); + for (int i = 0; i < q8_k_blocklen; i++) { + int row = y * q8_k_blocklen + i; + for (int j = 0; j < 2; j++) { + int col = x * ncols_interleaved + j * 4; + int offset = row * bs + col; + // vst1q_f32(s + offset, acc_f32[2 * i + j]); + if (i == 0 && j == 0) { + svst1_f32(pg4, s + offset, acc_f32_0); + } else if (i == 0 && j == 1) { + svst1_f32(pg4, s + offset, acc_f32_1); + } else if (i == 1 && j == 0) { + svst1_f32(pg4, s + offset, acc_f32_2); + } else if (i == 1 && j == 1) { + svst1_f32(pg4, s + offset, acc_f32_3); + } else if (i == 2 && j == 0) { + svst1_f32(pg4, s + offset, acc_f32_4); + } else if (i == 2 && j == 1) { + svst1_f32(pg4, s + offset, acc_f32_5); + } else if (i == 3 && j == 0) { + svst1_f32(pg4, s + offset, acc_f32_6); + } else if (i == 3 && j == 1) { + svst1_f32(pg4, s + offset, acc_f32_7); + } + } + } + } // for x + } // for y + return; + +#elif defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8) constexpr int q8_k_blocklen = 4; const uint8x16_t m4b = vdupq_n_u8(0x0f);