diff --git a/ggml/src/ggml-cpu/arch/arm/repack.cpp b/ggml/src/ggml-cpu/arch/arm/repack.cpp
index 105c4e5564..b0d25cdc47 100644
--- a/ggml/src/ggml-cpu/arch/arm/repack.cpp
+++ b/ggml/src/ggml-cpu/arch/arm/repack.cpp
@@ -3040,16 +3040,16 @@ void ggml_gemm_q4_K_8x8_q8_K(int n,
 #if defined(__aarch64__) && defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8)
     constexpr int q8_k_blocklen = 4;
-    const uint8x16_t m4b = vdupq_n_u8(0x0f);
     const svuint8_t m4b_1 = svdup_n_u8(0x0f);
     // 8 accumulators: 2 row pairs × 4 col pairs
-    // printf("Ai is going to");
-    float32x4_t acc_f32[blocklen];
-    svfloat32_t acc_f32_0, acc_f32_1, acc_f32_2, acc_f32_3, acc_f32_4, acc_f32_5, acc_f32_6, acc_f32_7;
-    uint32_t idx_arr[8] = {0,2,1,3,4,6,5,7};
+    svfloat32_t acc_f32_01, acc_f32_23, acc_f32_45, acc_f32_67;
+    uint32_t idx_arr[8] = { 0, 2, 4, 6, 1, 3, 5, 7 };
     svbool_t pg = svptrue_pat_b32(SV_VL8);
-    svuint32_t idx = svld1(pg, idx_arr);
+
+    static const uint32_t idx_data[8] = {0, 4, 2, 6, 1, 5, 3, 7};
+    svuint32_t idx1 = svld1_u32(svptrue_b32(), idx_data);
+
     for (int y = 0; y < nr / q8_k_blocklen; y++) {
         const block_q8_Kx4 * GGML_RESTRICT q8_ptr = (const block_q8_Kx4 *) vy + (y * nb);
         const block_q8_Kx4 * GGML_RESTRICT q8_ptr_1 = (const block_q8_Kx4 *) vy + (y * nb);
@@ -3057,17 +3057,11 @@ void ggml_gemm_q4_K_8x8_q8_K(int n,
         for (int x = 0; x < nc / ncols_interleaved; x++) {
             const block_q4_Kx8 * GGML_RESTRICT q4_ptr = (const block_q4_Kx8 *) vx + (x * nb);
             const block_q4_Kx8 * GGML_RESTRICT q4_ptr_1 = (const block_q4_Kx8 *) vx + (x * nb);
-            // for (int i = 0; i < blocklen; i++) {
-            //     acc_f32[i] = vdupq_n_f32(0);
-            // }
-            acc_f32_0 = svdup_n_f32(0);
-            acc_f32_1 = svdup_n_f32(0);
-            acc_f32_2 = svdup_n_f32(0);
-            acc_f32_3 = svdup_n_f32(0);
-            acc_f32_4 = svdup_n_f32(0);
-            acc_f32_5 = svdup_n_f32(0);
-            acc_f32_6 = svdup_n_f32(0);
-            acc_f32_7 = svdup_n_f32(0);
+
+            acc_f32_01 = svdup_n_f32(0);
+            acc_f32_23 = svdup_n_f32(0);
+            acc_f32_45 = svdup_n_f32(0);
+            acc_f32_67 = svdup_n_f32(0);
             for (int b = 0; b < nb; b++) {
                 // bsums pairs belongs to the same q8_k subblock
                 // 64 elemnts loaded and made sum of 0-7 and 8-15 sum || 16-23 and 24 - 31 sum
@@ -3092,38 +3086,38 @@ void ggml_gemm_q4_K_8x8_q8_K(int n,
                     vst1q_s32(&bsums_arr32[q8_row][4], v32_hi);
                 }

-                svint32_t sb_acc_0 = svdup_n_s32(0);
-                svint32_t sb_acc_1 = svdup_n_s32(0);
-                svint32_t sb_acc_2 = svdup_n_s32(0);
-                svint32_t sb_acc_3 = svdup_n_s32(0);
+                svint32_t sb_acc_00 = svdup_n_s32(0);
+                svint32_t sb_acc_11 = svdup_n_s32(0);
+                svint32_t sb_acc_22 = svdup_n_s32(0);
+                svint32_t sb_acc_33 = svdup_n_s32(0);

-                svint32_t acc_0 = svdup_n_s32(0);
-                svint32_t acc_1 = svdup_n_s32(0);
-                svint32_t acc_2 = svdup_n_s32(0);
-                svint32_t acc_3 = svdup_n_s32(0);
-                svint32_t acc_4 = svdup_n_s32(0);
-                svint32_t acc_5 = svdup_n_s32(0);
-                svint32_t acc_6 = svdup_n_s32(0);
-                svint32_t acc_7 = svdup_n_s32(0);
+                svint32_t acc_00 = svdup_n_s32(0);
+                svint32_t acc_11 = svdup_n_s32(0);
+                svint32_t acc_22 = svdup_n_s32(0);
+                svint32_t acc_33 = svdup_n_s32(0);
+                svint32_t acc_44 = svdup_n_s32(0);
+                svint32_t acc_55 = svdup_n_s32(0);
+                svint32_t acc_66 = svdup_n_s32(0);
+                svint32_t acc_77 = svdup_n_s32(0);

-                svint32_t bias_acc_0 = svdup_n_s32(0);
-                svint32_t bias_acc_1 = svdup_n_s32(0);
-                svint32_t bias_acc_2 = svdup_n_s32(0);
-                svint32_t bias_acc_3 = svdup_n_s32(0);
-                svint32_t bias_acc_4 = svdup_n_s32(0);
-                svint32_t bias_acc_5 = svdup_n_s32(0);
-                svint32_t bias_acc_6 = svdup_n_s32(0);
-                svint32_t bias_acc_7 = svdup_n_s32(0);
+                svint32_t bias_acc_00 = svdup_n_s32(0);
+                svint32_t bias_acc_22 = svdup_n_s32(0);
+                svint32_t bias_acc_44 = svdup_n_s32(0);
+                svint32_t bias_acc_66 = svdup_n_s32(0);

                 for (int sb = 0; sb < QK_K / 64; sb++) {
                     // Need scales for the low and high nibbles
                     // 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total
-                    int8_t q4sb_scales_1[2][8];
-                    svint32_t q4sb_mins_0_0, q4sb_mins_0_1, q4sb_mins_1_0, q4sb_mins_1_1;
-                    for (int i = 0; i < 2; i++) { // 2-superblock I am working on
-                        const int offset = sb * 24 + i * 12;
+                    svint32_t block_scale_0, block_scale_1, block_scale_2, block_scale_3;
+                    svint32_t q4sb_mins_0, q4sb_mins_1;
+                    {
+                        // 2-superblock I am working on
+                        const int offset = sb * 24 + 0 * 12;
                         const uint8_t * scales_in = &q4_ptr_1[b].scales[offset];
+                        const int offset1 = sb * 24 + 12;
+                        const uint8_t * scales_in1 = &q4_ptr_1[b].scales[offset1];
+

                         constexpr uint32_t kmask1 = 0x3f3f3f3f;
                         constexpr uint32_t kmask2 = 0x0f0f0f0f;
                         constexpr uint32_t kmask3 = 0x03030303;
@@ -3132,258 +3126,213 @@ void ggml_gemm_q4_K_8x8_q8_K(int n,
                         uint32_t sm[3];
                         memcpy(sm, scales_in, scales_size);
+                        uint32_t sm1[3];
+                        memcpy(sm1, scales_in1, scales_size);
+
                         const uint32_t mins_0_3 = sm[1] & kmask1;
                         const uint32_t mins_4_7 = ((sm[2] >> 4) & kmask2) | (((sm[1] >> 6) & kmask3) << 4);
-                        svuint32_t mins_u321 = svdup_n_u32(mins_0_3);
-                        svuint32_t mins_u322 = svdup_n_u32(mins_4_7);
+                        const uint32_t mins_0_3_1 = sm1[1] & kmask1;
+                        const uint32_t mins_4_7_1 = ((sm1[2] >> 4) & kmask2) | (((sm1[1] >> 6) & kmask3) << 4);
+
+                        svuint32_t mins_u32_temp = svzip1_u32(svdup_n_u32(mins_0_3), svdup_n_u32(mins_4_7));
+                        svuint32_t mins_u32_temp_1 = svzip1_u32(svdup_n_u32(mins_0_3_1), svdup_n_u32(mins_4_7_1));
+
                         /* reinterpret u32 → u8 */
-                        svuint8_t mins_u81 = svreinterpret_u8_u32(mins_u321);
-                        svuint8_t mins_u82 = svreinterpret_u8_u32(mins_u322);
+                        svuint8_t mins_u8 = svreinterpret_u8_u32(mins_u32_temp);
+                        svuint8_t mins_u8_1 = svreinterpret_u8_u32(mins_u32_temp_1);

-                        /* widen u8 → u16 (lower half only) */
-                        svuint32_t mins_u161 = svunpklo_u32(svunpklo_u16(mins_u81));
-                        svuint32_t mins_u162 = svunpklo_u32(svunpklo_u16(mins_u82));
+                        /* widen u8 → u16->u32 (lower half only) */
+                        svuint32_t mins_u16 = svunpklo_u32(svunpklo_u16(mins_u8));
+                        svuint32_t mins_u16_1 = svunpklo_u32(svunpklo_u16(mins_u8_1));

-                        /* reinterpret u16 → s16 */
-                        if(i == 0) {
-                            q4sb_mins_0_0 = svreinterpret_s32_u32(mins_u161);
-                            q4sb_mins_0_1 = svreinterpret_s32_u32(mins_u162);
-                        } else {
-                            q4sb_mins_1_0 = svreinterpret_s32_u32(mins_u161);
-                            q4sb_mins_1_1 = svreinterpret_s32_u32(mins_u162);
-                        }
+                        q4sb_mins_0 = svreinterpret_s32_u32(mins_u16);

-                        uint32_t scales_u32[2];
-                        scales_u32[0] = sm[0] & kmask1;
-                        scales_u32[1] = (sm[2] & kmask2) | (((sm[0] >> 6) & kmask3) << 4);
-                        memcpy(q4sb_scales_1[i], scales_u32, 8);
+                        q4sb_mins_1 = svreinterpret_s32_u32(mins_u16_1);
+
+                        uint32_t scales_u32_0 = sm[0] & kmask1;
+                        uint32_t scales_u32_1 = (sm[2] & kmask2) | (((sm[0] >> 6) & kmask3) << 4);
+                        uint32_t scales_u32_2 = sm1[0] & kmask1;
+                        uint32_t scales_u32_3 = (sm1[2] & kmask2) | (((sm1[0] >> 6) & kmask3) << 4);
+
+                        svuint32_t S01 = svdup_n_u32(scales_u32_0);
+                        svuint32_t S23 = svdup_n_u32(scales_u32_1);
+                        svuint32_t R01 = svdup_n_u32(scales_u32_2);
+                        svuint32_t R23 = svdup_n_u32(scales_u32_3);
+
+                        svint8_t S01_b = svreinterpret_s8_u32(S01); // s0 s1 s2 s3 ...
+                        svint8_t S23_b = svreinterpret_s8_u32(S23); // s4 s5 s6 s7 ...
+                        svint8_t R01_b = svreinterpret_s8_u32(R01); // r0 r1 r2 r3 ...
+                        svint8_t R23_b = svreinterpret_s8_u32(R23); // r4 r5 r6 r7 ...
+
+                        svint32_t S01_d = svunpklo_s32(svunpklo_s16(svzip1_s8(S01_b, S01_b)));
+                        // s0 s0 s1 s1 s2 s2 s3 s3 ...
+
+                        svint32_t R01_d = svunpklo_s32(svunpklo_s16(svzip1_s8(R01_b, R01_b)));
+                        // r0 r0 r1 r1 r2 r2 r3 r3 ...
+
+                        svint32_t S23_d = svunpklo_s32(svunpklo_s16(svzip1_s8(S23_b, S23_b)));
+                        // s4 s4 s5 s5 s6 s6 s7 s7 ...
+
+                        svint32_t R23_d = svunpklo_s32(svunpklo_s16(svzip1_s8(R23_b, R23_b)));
+                        // r4 r4 r5 r5 r6 r6 r7 r7 ...
+
+                        block_scale_0 = svtbl_s32(svzip1_s32(S01_d, R01_d), idx);
+                        // s0 s0 s1 s1 r0 r0 r1 r1
+
+                        block_scale_1 = svtbl_s32(svzip2_s32(S01_d, R01_d), idx);
+                        // s2 s2 s3 s3 r2 r2 r3 r3
+
+                        block_scale_2 = svtbl_s32(svzip1_s32(S23_d, R23_d), idx);
+                        // s4 s4 s5 s5 r4 r4 r5 r5
+
+                        block_scale_3 = svtbl_s32(svzip2_s32(S23_d, R23_d), idx);
+                        // s6 s6 s7 s7 r6 r6 r7 r7
                     }

                     // q8_ptr[b].qs has interleaved Q8 rows (01, 23)
-                    const int8_t * q8_base = q8_ptr[b].qs + sb * 256;
+                    // const int8_t * q8_base = q8_ptr[b].qs + sb * 256;
                     const int8_t * q8_base_1 = q8_ptr_1[b].qs + sb * 256;
+
+                    // Load 32-byte per row pair, 1 subblock each time
+
                     // predicate for activating higher lanes for 16 int8 elements
                     const svbool_t ph16 = svptrue_pat_b8(SV_VL16);
                     // predicate for activating lower lanes for 16 int8 elements
                     const svbool_t pl16 = svnot_b_z(svptrue_b8(), ph16);
-                    svint8_t q8s_00 = svld1_s8(ph16, q8_base_1 + 0);
-                    svint8_t q8s_01 = svld1_s8(ph16, q8_base_1 + 32);
-                    svint8_t q8s_02 = svld1_s8(ph16, q8_base_1 + 64);
-                    svint8_t q8s_03 = svld1_s8(ph16, q8_base_1 + 96);
-                    svint8_t q8s_04 = svld1_s8(ph16, q8_base_1 + 128);
-                    svint8_t q8s_05 = svld1_s8(ph16, q8_base_1 + 160);
-                    svint8_t q8s_06 = svld1_s8(ph16, q8_base_1 + 192);
-                    svint8_t q8s_07 = svld1_s8(ph16, q8_base_1 + 224);
-                    svint8_t q8s_10 = svld1_s8(ph16, q8_base_1 + 16);
-                    svint8_t q8s_11 = svld1_s8(ph16, q8_base_1 + 48);
-                    svint8_t q8s_12 = svld1_s8(ph16, q8_base_1 + 80);
-                    svint8_t q8s_13 = svld1_s8(ph16, q8_base_1 + 112);
-                    svint8_t q8s_14 = svld1_s8(ph16, q8_base_1 + 144);
-                    svint8_t q8s_15 = svld1_s8(ph16, q8_base_1 + 176);
-                    svint8_t q8s_16 = svld1_s8(ph16, q8_base_1 + 208);
-                    svint8_t q8s_17 = svld1_s8(ph16, q8_base_1 + 240);
+                    svint8_t q8_qs_00 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 0), svld1_s8(pl16, q8_base_1 + 16));
+                    svint8_t q8_qs_02 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 64), svld1_s8(pl16, q8_base_1 + 80));
+                    svint8_t q8_qs_04 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 128), svld1_s8(pl16, q8_base_1 + 144));
+                    svint8_t q8_qs_06 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 192), svld1_s8(pl16, q8_base_1 + 208));
+
+                    svint8_t q8_qs_10 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 16), svld1_s8(pl16, q8_base_1 + 32));
+                    svint8_t q8_qs_12 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 80), svld1_s8(pl16, q8_base_1 + 96));
+                    svint8_t q8_qs_14 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 144), svld1_s8(pl16, q8_base_1 + 160));
+                    svint8_t q8_qs_16 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 208), svld1_s8(pl16, q8_base_1 + 224));

                     // Q4s columns iterated in pairs (01, 23, 45, 67)
                     for (int cp = 0; cp < ncols_interleaved / 2; cp++) {
-                        sb_acc_0 = svdup_n_s32(0);
-                        sb_acc_1 = svdup_n_s32(0);
-                        sb_acc_2 = svdup_n_s32(0);
-                        sb_acc_3 = svdup_n_s32(0);
+                        sb_acc_00 = svdup_n_s32(0);
+                        sb_acc_11 = svdup_n_s32(0);
+                        sb_acc_22 = svdup_n_s32(0);
+                        sb_acc_33 = svdup_n_s32(0);

                         svbool_t pg = svptrue_pat_b8(SV_VL16);
-                        svuint8_t q4_qs_cp_00 = svld1_u8(pg, q4_ptr_1[b].qs + sb * QK_K + 16 * cp + 0);
-                        svuint8_t q4_qs_cp_10 = svld1_u8(pg, q4_ptr_1[b].qs + sb * QK_K + 16 * cp + 64);
-                        svuint8_t q4_qs_cp_20 = svld1_u8(pg, q4_ptr_1[b].qs + sb * QK_K + 16 * cp + 128);
-                        svuint8_t q4_qs_cp_30 = svld1_u8(pg, q4_ptr_1[b].qs + sb * QK_K + 16 * cp + 192);
+                        svuint8_t q4_qs_cp_0 = svadd_u8_m(svptrue_b8(),svld1_u8(pg, q4_ptr_1[b].qs + sb * QK_K + 16 * cp + 0), svld1_u8(pl16, q4_ptr_1[b].qs + sb * QK_K + 16 * cp + 48));
+                        svuint8_t q4_qs_cp_1 = svadd_u8_m(svptrue_b8(),svld1_u8(pg, q4_ptr_1[b].qs + sb * QK_K + 16 * cp + 128), svld1_u8(pl16, q4_ptr_1[b].qs + sb * QK_K + 16 * cp + 176));

-                        svint8_t q4_nibbles_0_0 = svreinterpret_s8_u8(svand_u8_m(pg, q4_qs_cp_00, m4b_1));
-                        svint8_t q4_nibbles_0_1 = svreinterpret_s8_u8(svand_u8_m(pg, q4_qs_cp_10, m4b_1));
-                        svint8_t q4_nibbles_0_2 = svreinterpret_s8_u8(svand_u8_m(pg, q4_qs_cp_20, m4b_1));
-                        svint8_t q4_nibbles_0_3 = svreinterpret_s8_u8(svand_u8_m(pg, q4_qs_cp_30, m4b_1));
+                        svint8_t q4_nibbles_0 = svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), q4_qs_cp_0, m4b_1));
+                        svint8_t q4_nibbles_1 = svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), q4_qs_cp_1, m4b_1));
+                        svint8_t q4_nibbles_2 = svreinterpret_s8_u8(svlsr_n_u8_m(svptrue_b8(), q4_qs_cp_0, 4));
+                        svint8_t q4_nibbles_3 = svreinterpret_s8_u8(svlsr_n_u8_m(svptrue_b8(), q4_qs_cp_1, 4));

-                        svint8_t q4_nibbles_1_0 = svreinterpret_s8_u8(svlsr_n_u8_m(pg, q4_qs_cp_00, 4));
-                        svint8_t q4_nibbles_1_1 = svreinterpret_s8_u8(svlsr_n_u8_m(pg, q4_qs_cp_10, 4));
-                        svint8_t q4_nibbles_1_2 = svreinterpret_s8_u8(svlsr_n_u8_m(pg, q4_qs_cp_20, 4));
-                        svint8_t q4_nibbles_1_3 = svreinterpret_s8_u8(svlsr_n_u8_m(pg, q4_qs_cp_30, 4));
+                        sb_acc_00 = svmmla_s32(sb_acc_00, q4_nibbles_0, q8_qs_00);
+                        sb_acc_00 = svmmla_s32(sb_acc_00, q4_nibbles_1, q8_qs_02);

-                        sb_acc_0 = svmmla_s32(sb_acc_0, q4_nibbles_0_0, q8s_00);
-                        sb_acc_0 = svmmla_s32(sb_acc_0, q4_nibbles_0_1, q8s_01);
-                        sb_acc_0 = svmmla_s32(sb_acc_0, q4_nibbles_0_2, q8s_02);
-                        sb_acc_0 = svmmla_s32(sb_acc_0, q4_nibbles_0_3, q8s_03);
+                        sb_acc_11 = svmmla_s32(sb_acc_11, q4_nibbles_2, q8_qs_04);
+                        sb_acc_11 = svmmla_s32(sb_acc_11, q4_nibbles_3, q8_qs_06);

-                        sb_acc_1 = svmmla_s32(sb_acc_1, q4_nibbles_1_0, q8s_04);
-                        sb_acc_1 = svmmla_s32(sb_acc_1, q4_nibbles_1_1, q8s_05);
-                        sb_acc_1 = svmmla_s32(sb_acc_1, q4_nibbles_1_2, q8s_06);
-                        sb_acc_1 = svmmla_s32(sb_acc_1, q4_nibbles_1_3, q8s_07);
+                        sb_acc_22 = svmmla_s32(sb_acc_22, q4_nibbles_0, q8_qs_10);
+                        sb_acc_22 = svmmla_s32(sb_acc_22, q4_nibbles_1, q8_qs_12);

-                        sb_acc_2 = svmmla_s32(sb_acc_2, q4_nibbles_0_0, q8s_10);
-                        sb_acc_2 = svmmla_s32(sb_acc_2, q4_nibbles_0_1, q8s_11);
-                        sb_acc_2 = svmmla_s32(sb_acc_2, q4_nibbles_0_2, q8s_12);
-                        sb_acc_2 = svmmla_s32(sb_acc_2, q4_nibbles_0_3, q8s_13);
+                        sb_acc_33 = svmmla_s32(sb_acc_33, q4_nibbles_2, q8_qs_14);
+                        sb_acc_33 = svmmla_s32(sb_acc_33, q4_nibbles_3, q8_qs_16);

-                        sb_acc_3 = svmmla_s32(sb_acc_3, q4_nibbles_1_0, q8s_14);
-                        sb_acc_3 = svmmla_s32(sb_acc_3, q4_nibbles_1_1, q8s_15);
-                        sb_acc_3 = svmmla_s32(sb_acc_3, q4_nibbles_1_2, q8s_16);
-                        sb_acc_3 = svmmla_s32(sb_acc_3, q4_nibbles_1_3, q8s_17);
-
-                        // Scales[i] corresponds to column i
-                        const int scale_offset = cp * 2;
-
-                        int32_t tmp[8] = {
-                            (int32_t) q4sb_scales_1[0][scale_offset],
-                            (int32_t) q4sb_scales_1[0][scale_offset],
-                            (int32_t) q4sb_scales_1[0][scale_offset + 1],
-                            (int32_t) q4sb_scales_1[0][scale_offset + 1],
-                        };
-                        int32_t tmp1[4] = {
-                            (int32_t) q4sb_scales_1[1][scale_offset],
-                            (int32_t) q4sb_scales_1[1][scale_offset],
-                            (int32_t) q4sb_scales_1[1][scale_offset + 1],
-                            (int32_t) q4sb_scales_1[1][scale_offset + 1],
-                        };
-
-                        svint32_t block_scale = svld1_s32(svptrue_pat_b32(SV_VL4), tmp);
-                        svint32_t block_scale1 = svld1_s32(svptrue_pat_b32(SV_VL4), tmp1);
+                        sb_acc_00 = svadd_s32_z(svptrue_pat_b32(SV_VL4), sb_acc_00, svext_s32(sb_acc_00, sb_acc_00, 4));
+                        sb_acc_11 = svext_s32(svdup_s32(0), svadd_s32_z(svptrue_pat_b32(SV_VL4), sb_acc_11, svext_s32(sb_acc_11, sb_acc_11, 4)), 4);
+                        sb_acc_22 = svadd_s32_z(svptrue_pat_b32(SV_VL4), sb_acc_22, svext_s32(sb_acc_22, sb_acc_22, 4));
+                        sb_acc_33 = svext_s32(svdup_s32(0), svadd_s32_z(svptrue_pat_b32(SV_VL4), sb_acc_33, svext_s32(sb_acc_33, sb_acc_33, 4)), 4);
+                        sb_acc_00 = svadd_s32_m(svptrue_b32(), sb_acc_00, sb_acc_11);
+                        sb_acc_22 = svadd_s32_m(svptrue_b32(), sb_acc_22, sb_acc_33);

                         if(cp == 0) {
-                            acc_0 = svmla_s32_m(svptrue_pat_b32(SV_VL4), acc_0, sb_acc_0, block_scale);
-                            acc_4 = svmla_s32_m(svptrue_pat_b32(SV_VL4), acc_4, sb_acc_2, block_scale);
-                            acc_0 = svmla_s32_m(svptrue_pat_b32(SV_VL4), acc_0, sb_acc_1, block_scale1);
-                            acc_4 = svmla_s32_m(svptrue_pat_b32(SV_VL4), acc_4, sb_acc_3, block_scale1);
+                            acc_00 = svmla_s32_m(svptrue_b32(), acc_00, sb_acc_00, block_scale_0);
+                            acc_44 = svmla_s32_m(svptrue_b32(), acc_44, sb_acc_22, block_scale_0);
                         }
                         if(cp == 1) {
-                            acc_1 = svmla_s32_m(svptrue_pat_b32(SV_VL4), acc_1, sb_acc_0, block_scale);
-                            acc_5 = svmla_s32_m(svptrue_pat_b32(SV_VL4), acc_5, sb_acc_2, block_scale);
-                            acc_1 = svmla_s32_m(svptrue_pat_b32(SV_VL4), acc_1, sb_acc_1, block_scale1);
-                            acc_5 = svmla_s32_m(svptrue_pat_b32(SV_VL4), acc_5, sb_acc_3, block_scale1);
+                            acc_11 = svmla_s32_m(svptrue_b32(), acc_11, sb_acc_00, block_scale_1);
+                            acc_55 = svmla_s32_m(svptrue_b32(), acc_55, sb_acc_22, block_scale_1);
                         }
                         if(cp == 2) {
-                            acc_2 = svmla_s32_m(svptrue_pat_b32(SV_VL4), acc_2, sb_acc_0, block_scale);
-                            acc_6 = svmla_s32_m(svptrue_pat_b32(SV_VL4), acc_6, sb_acc_2, block_scale);
-                            acc_2 = svmla_s32_m(svptrue_pat_b32(SV_VL4), acc_2, sb_acc_1, block_scale1);
-                            acc_6 = svmla_s32_m(svptrue_pat_b32(SV_VL4), acc_6, sb_acc_3, block_scale1);
+                            acc_22 = svmla_s32_m(svptrue_b32(), acc_22, sb_acc_00, block_scale_2);
+                            acc_66 = svmla_s32_m(svptrue_b32(), acc_66, sb_acc_22, block_scale_2);
                         }
                         if(cp == 3) {
-                            acc_3 = svmla_s32_m(svptrue_pat_b32(SV_VL4), acc_3, sb_acc_0, block_scale);
-                            acc_7 = svmla_s32_m(svptrue_pat_b32(SV_VL4), acc_7, sb_acc_2, block_scale);
-                            acc_3 = svmla_s32_m(svptrue_pat_b32(SV_VL4), acc_3, sb_acc_1, block_scale1);
-                            acc_7 = svmla_s32_m(svptrue_pat_b32(SV_VL4), acc_7, sb_acc_3, block_scale1);
+                            acc_33 = svmla_s32_m(svptrue_b32(), acc_33, sb_acc_00, block_scale_3);
+                            acc_77 = svmla_s32_m(svptrue_b32(), acc_77, sb_acc_22, block_scale_3);
                         }
                     }

-                    bias_acc_0 = svmla_s32_m(svptrue_pat_b32(SV_VL4), bias_acc_0, svdup_n_s32(bsums_arr32[sb][0]), q4sb_mins_0_0);
-                    bias_acc_0 = svmla_s32_m(svptrue_pat_b32(SV_VL4), bias_acc_0, svdup_n_s32(bsums_arr32[sb][1]), q4sb_mins_1_0);
-                    bias_acc_1 = svmla_s32_m(svptrue_pat_b32(SV_VL4), bias_acc_1, svdup_n_s32(bsums_arr32[sb][0]), q4sb_mins_0_1);
-                    bias_acc_1 = svmla_s32_m(svptrue_pat_b32(SV_VL4), bias_acc_1, svdup_n_s32(bsums_arr32[sb][1]), q4sb_mins_1_1);
+                    bias_acc_00 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_00, svdup_n_s32(bsums_arr32[sb][0]), q4sb_mins_0);
+                    bias_acc_00 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_00, svdup_n_s32(bsums_arr32[sb][1]), q4sb_mins_1);

-                    bias_acc_2 = svmla_s32_m(svptrue_pat_b32(SV_VL4), bias_acc_2, svdup_n_s32(bsums_arr32[sb][2]), q4sb_mins_0_0);
-                    bias_acc_2 = svmla_s32_m(svptrue_pat_b32(SV_VL4), bias_acc_2, svdup_n_s32(bsums_arr32[sb][3]), q4sb_mins_1_0);
-                    bias_acc_3 = svmla_s32_m(svptrue_pat_b32(SV_VL4), bias_acc_3, svdup_n_s32(bsums_arr32[sb][2]), q4sb_mins_0_1);
-                    bias_acc_3 = svmla_s32_m(svptrue_pat_b32(SV_VL4), bias_acc_3, svdup_n_s32(bsums_arr32[sb][3]), q4sb_mins_1_1);
+                    bias_acc_22 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_22, svdup_n_s32(bsums_arr32[sb][2]), q4sb_mins_0);
+                    bias_acc_22 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_22, svdup_n_s32(bsums_arr32[sb][3]), q4sb_mins_1);

-                    bias_acc_4 = svmla_s32_m(svptrue_pat_b32(SV_VL4), bias_acc_4, svdup_n_s32(bsums_arr32[sb][4]), q4sb_mins_0_0);
-                    bias_acc_4 = svmla_s32_m(svptrue_pat_b32(SV_VL4), bias_acc_4, svdup_n_s32(bsums_arr32[sb][5]), q4sb_mins_1_0);
-                    bias_acc_5 = svmla_s32_m(svptrue_pat_b32(SV_VL4), bias_acc_5, svdup_n_s32(bsums_arr32[sb][4]), q4sb_mins_0_1);
-                    bias_acc_5 = svmla_s32_m(svptrue_pat_b32(SV_VL4), bias_acc_5, svdup_n_s32(bsums_arr32[sb][5]), q4sb_mins_1_1);
+                    bias_acc_44 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_44, svdup_n_s32(bsums_arr32[sb][4]), q4sb_mins_0);
+                    bias_acc_44 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_44, svdup_n_s32(bsums_arr32[sb][5]), q4sb_mins_1);

-                    bias_acc_6 = svmla_s32_m(svptrue_pat_b32(SV_VL4), bias_acc_6, svdup_n_s32(bsums_arr32[sb][6]), q4sb_mins_0_0);
-                    bias_acc_6 = svmla_s32_m(svptrue_pat_b32(SV_VL4), bias_acc_6, svdup_n_s32(bsums_arr32[sb][7]), q4sb_mins_1_0);
-                    bias_acc_7 = svmla_s32_m(svptrue_pat_b32(SV_VL4), bias_acc_7, svdup_n_s32(bsums_arr32[sb][6]), q4sb_mins_0_1);
-                    bias_acc_7 = svmla_s32_m(svptrue_pat_b32(SV_VL4), bias_acc_7, svdup_n_s32(bsums_arr32[sb][7]), q4sb_mins_1_1);
+                    bias_acc_66 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_66, svdup_n_s32(bsums_arr32[sb][6]), q4sb_mins_0);
+                    bias_acc_66 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_66, svdup_n_s32(bsums_arr32[sb][7]), q4sb_mins_1);
                 } // for sb

-                svint32_t reorder_acc_0 = svtbl_s32(svtrn1_s32(acc_0, acc_1), idx);
-                svint32_t reorder_acc_1 = svtbl_s32(svtrn1_s32(acc_2, acc_3), idx);
-                svint32_t reorder_acc_2 = svtbl_s32(svtrn2_s32(acc_0, acc_1), idx);
-                svint32_t reorder_acc_3 = svtbl_s32(svtrn2_s32(acc_2, acc_3), idx);
-                // acc[4..7]
-                svint32_t reorder_acc_4 = svtbl_s32(svtrn1_s32(acc_4, acc_5), idx);
-                svint32_t reorder_acc_5 = svtbl_s32(svtrn1_s32(acc_6, acc_7), idx);
-                svint32_t reorder_acc_6 = svtbl_s32(svtrn2_s32(acc_4, acc_5), idx);
-                svint32_t reorder_acc_7 = svtbl_s32(svtrn2_s32(acc_6, acc_7), idx);
-                // Predicate for exactly 4 lanes
-                svbool_t pg4 = svptrue_pat_b32(SV_VL4);
+                // acc[0..3] // acc[4..7]
+                acc_00 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_00, svext_s32(acc_00, acc_00, 4));
+                acc_11 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_11, svext_s32(acc_11, acc_11, 4));
+                acc_22 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_22, svext_s32(acc_22, acc_22, 4));
+                acc_33 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_33, svext_s32(acc_33, acc_33, 4));
+                acc_44 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_44, svext_s32(acc_44, acc_44, 4));
+                acc_55 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_55, svext_s32(acc_55, acc_55, 4));
+                acc_66 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_66, svext_s32(acc_66, acc_66, 4));
+                acc_77 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_77, svext_s32(acc_77, acc_77, 4));
+
+                svint32_t reorder_acc_01 = svtbl_s32( svzip1_s32( svtrn1_s32(acc_00, acc_11), svtrn1_s32(acc_22, acc_33)), idx1);
+                svint32_t reorder_acc_23 = svtbl_s32( svzip1_s32( svtrn2_s32(acc_00, acc_11), svtrn2_s32(acc_22, acc_33)), idx1);
+
+                svint32_t reorder_acc_45 = svtbl_s32( svzip1_s32( svtrn1_s32(acc_44, acc_55), svtrn1_s32(acc_66, acc_77)), idx1);
+                svint32_t reorder_acc_67 = svtbl_s32( svzip1_s32( svtrn2_s32(acc_44, acc_55), svtrn2_s32(acc_66, acc_77)), idx1);

                 // Broadcast q8 scalar
                 svfloat32_t q8_d = svdup_f32(q8_ptr[b].d[0]);
-                // ---- dmins ----
-                svfloat16_t q4_dmin_h = svld1_f16(svptrue_pat_b16(SV_VL4), (const __fp16 *)(q4_ptr[b].dmin + 0));
-                q4_dmin_h = svzip1_f16(q4_dmin_h, q4_dmin_h);
-                svfloat32_t q4_dmin = svcvt_f32_f16_x(svptrue_pat_b32(SV_VL4), q4_dmin_h);
+                svfloat32_t q4_dmin_temp = svcvt_f32_f16_x(svptrue_b32(), svzip1_f16( svld1_f16(svptrue_pat_b16(SV_VL8), (const __fp16 *)q4_ptr[b].dmin), svdup_f16(0)));

-                svfloat16_t q4_dmin_h1 = svld1_f16(svptrue_pat_b16(SV_VL4), (const __fp16 *)(q4_ptr[b].dmin + 4));
-                q4_dmin_h1 = svzip1_f16(q4_dmin_h1, q4_dmin_h1);
-                svfloat32_t q4_dmin1 = svcvt_f32_f16_x(svptrue_pat_b32(SV_VL4), q4_dmin_h1);
+                svfloat32_t q4_d_temp = svcvt_f32_f16_x(svptrue_b32(), svzip1_f16( svld1_f16(svptrue_pat_b16(SV_VL8), (const __fp16 *)q4_ptr[b].d), svdup_f16(0)));

-                //-----scale -----------------------
-                svfloat16_t q4_d_h = svld1_f16(svptrue_pat_b16(SV_VL4), (const __fp16 *)(q4_ptr[b].d + 0));
-                q4_d_h = svzip1_f16(q4_d_h, q4_d_h);
-                svfloat32_t q4_d = svcvt_f32_f16_x(svptrue_pat_b32(SV_VL4), q4_d_h);
+                svfloat32_t scale1 = svmul_f32_x(svptrue_b32(), q4_d_temp, q8_d);
+                svfloat32_t dmins1 = svmul_f32_x(svptrue_b32(), q4_dmin_temp, q8_d);

-                svfloat16_t q4_d_h1 = svld1_f16(svptrue_pat_b16(SV_VL4), (const __fp16 *)(q4_ptr[b].d + 4));
-                q4_d_h1 = svzip1_f16(q4_d_h1, q4_d_h1);
-                svfloat32_t q4_d1 = svcvt_f32_f16_x(svptrue_pat_b32(SV_VL4), q4_d_h1);
-
-                svfloat32_t scale = svmul_f32_x(pg4, q4_d, q8_d);
-                svfloat32_t dmins = svmul_f32_x(pg4, q4_dmin, q8_d);
-
-                acc_f32_0 = svmls_f32_m(pg4, acc_f32_0, svcvt_f32_s32_m(svdup_n_f32(0), pg4, bias_acc_0), dmins);
-                acc_f32_0 = svmla_f32_m(pg4, acc_f32_0, svcvt_f32_s32_m(svdup_n_f32(0), pg4, reorder_acc_0), scale);
-
-                scale = svmul_f32_x(pg4, q4_d1, q8_d);
-                dmins = svmul_f32_x(pg4, q4_dmin1, q8_d);
-
-                acc_f32_1 = svmls_f32_m(pg4, acc_f32_1, svcvt_f32_s32_m(svdup_n_f32(0), pg4, bias_acc_1), dmins);
-                acc_f32_1 = svmla_f32_m(pg4, acc_f32_1, svcvt_f32_s32_m(svdup_n_f32(0), pg4, reorder_acc_1), scale);
+                acc_f32_01 = svmls_f32_m(svptrue_b32(), acc_f32_01, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), bias_acc_00), dmins1);
+                acc_f32_01 = svmla_f32_m(svptrue_b32(), acc_f32_01, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), reorder_acc_01), scale1);

                 q8_d = svdup_f32(q8_ptr[b].d[1]);
-                scale = svmul_f32_x(pg4, q4_d, q8_d);
-                dmins = svmul_f32_x(pg4, q4_dmin, q8_d);
+                scale1 = svmul_f32_x(svptrue_b32(), q4_d_temp, q8_d);
+                dmins1 = svmul_f32_x(svptrue_b32(), q4_dmin_temp, q8_d);

-                acc_f32_2 = svmls_f32_m(pg4, acc_f32_2, svcvt_f32_s32_m(svdup_n_f32(0), pg4, bias_acc_2), dmins);
-                acc_f32_2 = svmla_f32_m(pg4, acc_f32_2, svcvt_f32_s32_m(svdup_n_f32(0), pg4, reorder_acc_2), scale);
-
-                scale = svmul_f32_x(pg4, q4_d1, q8_d);
-                dmins = svmul_f32_x(pg4, q4_dmin1, q8_d);
-
-                acc_f32_3 = svmls_f32_m(pg4, acc_f32_3, svcvt_f32_s32_m(svdup_n_f32(0), pg4, bias_acc_3), dmins);
-                acc_f32_3 = svmla_f32_m(pg4, acc_f32_3, svcvt_f32_s32_m(svdup_n_f32(0), pg4, reorder_acc_3), scale);
+                acc_f32_23 = svmls_f32_m(svptrue_b32(), acc_f32_23, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), bias_acc_22), dmins1);
+                acc_f32_23 = svmla_f32_m(svptrue_b32(), acc_f32_23, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), reorder_acc_23), scale1);

                 q8_d = svdup_f32(q8_ptr[b].d[2]);
-                scale = svmul_f32_x(pg4, q4_d, q8_d);
-                dmins = svmul_f32_x(pg4, q4_dmin, q8_d);
-                acc_f32_4 = svmls_f32_m(pg4, acc_f32_4, svcvt_f32_s32_m(svdup_n_f32(0), pg4, bias_acc_4), dmins);
-                acc_f32_4 = svmla_f32_m(pg4, acc_f32_4, svcvt_f32_s32_m(svdup_n_f32(0), pg4, reorder_acc_4), scale);
+                scale1 = svmul_f32_x(svptrue_b32(), q4_d_temp, q8_d);
+                dmins1 = svmul_f32_x(svptrue_b32(), q4_dmin_temp, q8_d);

-                scale = svmul_f32_x(pg4, q4_d1, q8_d);
-                dmins = svmul_f32_x(pg4, q4_dmin1, q8_d);
-
-                acc_f32_5 = svmls_f32_m(pg4, acc_f32_5, svcvt_f32_s32_m(svdup_n_f32(0), pg4, bias_acc_5), dmins);
-                acc_f32_5 = svmla_f32_m(pg4, acc_f32_5, svcvt_f32_s32_m(svdup_n_f32(0), pg4, reorder_acc_5), scale);
+                acc_f32_45 = svmls_f32_m(svptrue_b32(), acc_f32_45, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), bias_acc_44), dmins1);
+                acc_f32_45 = svmla_f32_m(svptrue_b32(), acc_f32_45, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), reorder_acc_45), scale1);

                 q8_d = svdup_f32(q8_ptr[b].d[3]);
-                scale = svmul_f32_x(pg4, q4_d, q8_d);
-                dmins = svmul_f32_x(pg4, q4_dmin, q8_d);
+                scale1 = svmul_f32_x(svptrue_b32(), q4_d_temp, q8_d);
+                dmins1 = svmul_f32_x(svptrue_b32(), q4_dmin_temp, q8_d);

-                acc_f32_6 = svmls_f32_m(pg4, acc_f32_6, svcvt_f32_s32_m(svdup_n_f32(0), pg4, bias_acc_6), dmins);
-                acc_f32_6 = svmla_f32_m(pg4, acc_f32_6, svcvt_f32_s32_m(svdup_n_f32(0), pg4, reorder_acc_6), scale);
-
-                scale = svmul_f32_x(pg4, q4_d1, q8_d);
-                dmins = svmul_f32_x(pg4, q4_dmin1, q8_d);
-
-                acc_f32_7 = svmls_f32_m(pg4, acc_f32_7, svcvt_f32_s32_m(svdup_n_f32(0), pg4, bias_acc_7), dmins);
-                acc_f32_7 = svmla_f32_m(pg4, acc_f32_7, svcvt_f32_s32_m(svdup_n_f32(0), pg4, reorder_acc_7), scale);
+                acc_f32_67 = svmls_f32_m(svptrue_b32(), acc_f32_67, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), bias_acc_66), dmins1);
+                acc_f32_67 = svmla_f32_m(svptrue_b32(), acc_f32_67, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), reorder_acc_67), scale1);
             } // for b
@@ -3395,23 +3344,31 @@ void ggml_gemm_q4_K_8x8_q8_K(int n,
                 for (int j = 0; j < 2; j++) {
                     int col = x * ncols_interleaved + j * 4;
                     int offset = row * bs + col;
-                    // vst1q_f32(s + offset, acc_f32[2 * i + j]);
+
                     if (i == 0 && j == 0) {
-                        svst1_f32(pg4, s + offset, acc_f32_0);
+                        // acc_f32_0 → lower half of acc_f32_01
+                        svst1_f32(pg4, s + offset, acc_f32_01);
                     } else if (i == 0 && j == 1) {
-                        svst1_f32(pg4, s + offset, acc_f32_1);
+                        // acc_f32_1 → upper half of acc_f32_01
+                        svst1_f32(pg4, s + offset, svext_f32(acc_f32_01, acc_f32_01, 4));
                     } else if (i == 1 && j == 0) {
-                        svst1_f32(pg4, s + offset, acc_f32_2);
+                        // acc_f32_2
+                        svst1_f32(pg4, s + offset, acc_f32_23);
                     } else if (i == 1 && j == 1) {
-                        svst1_f32(pg4, s + offset, acc_f32_3);
+                        // acc_f32_3
+                        svst1_f32(pg4, s + offset, svext_f32(acc_f32_23, acc_f32_23, 4));
                     } else if (i == 2 && j == 0) {
-                        svst1_f32(pg4, s + offset, acc_f32_4);
+                        // acc_f32_4
+                        svst1_f32(pg4, s + offset, acc_f32_45);
                     } else if (i == 2 && j == 1) {
-                        svst1_f32(pg4, s + offset, acc_f32_5);
+                        // acc_f32_5
+                        svst1_f32(pg4, s + offset, svext_f32(acc_f32_45, acc_f32_45, 4));
                     } else if (i == 3 && j == 0) {
-                        svst1_f32(pg4, s + offset, acc_f32_6);
+                        // acc_f32_6
+                        svst1_f32(pg4, s + offset, acc_f32_67);
                     } else if (i == 3 && j == 1) {
-                        svst1_f32(pg4, s + offset, acc_f32_7);
+                        // acc_f32_7
+                        svst1_f32(pg4, s + offset, svext_f32(acc_f32_67, acc_f32_67, 4));
                     }
                 }
            }
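For reviewers who want to sanity-check the `svmmla_s32` tiling that the new `sb_acc_*` accumulators rely on, here is a minimal, self-contained sketch (not part of the patch) that compares one SMMLA 2x2 int32 tile against a scalar reference. It assumes an SVE + I8MM capable toolchain and CPU (build flags such as `-march=armv8.6-a+sve+i8mm` are an assumption, not taken from this change) and only exercises the first 128-bit segment of the vector, mirroring the 16-byte predicated loads used in the kernel.

// smmla_tile_check.c — illustrative only; not part of repack.cpp.
// Assumed build: gcc -O2 -march=armv8.6-a+sve+i8mm smmla_tile_check.c
#include <arm_sve.h>
#include <stdint.h>
#include <stdio.h>

int main(void) {
    // Two 2x8 int8 tiles, each stored row-major in 16 bytes, matching how the
    // kernel presents a q4 nibble row pair and a q8 row pair to svmmla_s32.
    int8_t a[16], b[16];
    for (int i = 0; i < 16; i++) { a[i] = (int8_t)(i - 8); b[i] = (int8_t)(3 * i - 20); }

    // Scalar reference: C[r][c] = sum_k A[r][k] * B[c][k]
    // (the second operand is read as two rows, i.e. effectively transposed).
    int32_t ref[4];
    for (int r = 0; r < 2; r++) {
        for (int c = 0; c < 2; c++) {
            int32_t sum = 0;
            for (int k = 0; k < 8; k++) sum += (int32_t)a[r * 8 + k] * (int32_t)b[c * 8 + k];
            ref[r * 2 + c] = sum;
        }
    }

    // SVE i8mm: one SMMLA accumulates the same 2x2 tile into the first
    // 128-bit segment (4 int32 lanes, row-major), starting from a zero accumulator.
    const svbool_t p16 = svptrue_pat_b8(SV_VL16);
    svint8_t  va  = svld1_s8(p16, a);
    svint8_t  vb  = svld1_s8(p16, b);
    svint32_t acc = svmmla_s32(svdup_n_s32(0), va, vb);

    int32_t out[4];
    svst1_s32(svptrue_pat_b32(SV_VL4), out, acc);

    for (int i = 0; i < 4; i++) {
        printf("lane %d: smmla=%d ref=%d%s\n", i, out[i], ref[i],
               out[i] == ref[i] ? "" : "  <-- mismatch");
    }
    return 0;
}

The kernel above applies the same primitive with running accumulators (`sb_acc_00`..`sb_acc_33`) and then rescales per column pair; the sketch only checks the 2x2 tile layout and operand orientation.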