diff --git a/ggml/src/ggml-cpu/arch/arm/repack.cpp b/ggml/src/ggml-cpu/arch/arm/repack.cpp
index 5f6fd655d2..b461930a10 100644
--- a/ggml/src/ggml-cpu/arch/arm/repack.cpp
+++ b/ggml/src/ggml-cpu/arch/arm/repack.cpp
@@ -3039,7 +3039,7 @@ void ggml_gemm_q4_K_8x8_q8_K(int n,
     UNUSED(blocklen);
 
 #if defined(__aarch64__) && defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8)
-    if (svcntb()*8 == 256) {
+    if (svcntb() * 8 == 256) {
         constexpr int q8_k_blocklen = 4;
         const svuint8_t m4b_1 = svdup_n_u8(0x0f);
         // 8 accumulators: 2 row pairs × 4 col pairs
@@ -3053,11 +3053,9 @@ void ggml_gemm_q4_K_8x8_q8_K(int n,
 
         for (int y = 0; y < nr / q8_k_blocklen; y++) {
             const block_q8_Kx4 * GGML_RESTRICT q8_ptr = (const block_q8_Kx4 *) vy + (y * nb);
-            const block_q8_Kx4 * GGML_RESTRICT q8_ptr_1 = (const block_q8_Kx4 *) vy + (y * nb);
 
             for (int x = 0; x < nc / ncols_interleaved; x++) {
                 const block_q4_Kx8 * GGML_RESTRICT q4_ptr = (const block_q4_Kx8 *) vx + (x * nb);
-                const block_q4_Kx8 * GGML_RESTRICT q4_ptr_1 = (const block_q4_Kx8 *) vx + (x * nb);
 
                 acc_f32_01 = svdup_n_f32(0);
                 acc_f32_23 = svdup_n_f32(0);
@@ -3065,7 +3063,8 @@ void ggml_gemm_q4_K_8x8_q8_K(int n,
                 acc_f32_67 = svdup_n_f32(0);
 
                 for (int b = 0; b < nb; b++) {
-                    // bsums pairs belongs to the same q8_k subblock // 64 elemnts loaded and made sum of 0-7 and 8-15 sum || 16-23 and 24 - 31 sum
+                    // bsums pairs belong to the same q8_k subblock
+                    // 64 elements loaded; pairwise sums of 0-7 with 8-15 || 16-23 with 24-31
                     const int16x8_t bsums[4]{
                         vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 0), vld1q_s16(q8_ptr[b].bsums + 16 * 0 + 8)),
                         vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 1), vld1q_s16(q8_ptr[b].bsums + 16 * 1 + 8)),
@@ -3112,10 +3111,10 @@ void ggml_gemm_q4_K_8x8_q8_K(int n,
 
                         {
                             // 2-superblock I am working on
                             const int offset = sb * 24 + 0 * 12;
-                            const uint8_t * scales_in = &q4_ptr_1[b].scales[offset];
+                            const uint8_t * scales_in = &q4_ptr[b].scales[offset];
                             const int offset1 = sb * 24 + 12;
-                            const uint8_t * scales_in1 = &q4_ptr_1[b].scales[offset1];
+                            const uint8_t * scales_in1 = &q4_ptr[b].scales[offset1];
 
                             constexpr uint32_t kmask1 = 0x3f3f3f3f;
                             constexpr uint32_t kmask2 = 0x0f0f0f0f;
@@ -3159,39 +3158,23 @@ void ggml_gemm_q4_K_8x8_q8_K(int n,
                             svuint32_t R01 = svdup_n_u32(scales_u32_2);
                             svuint32_t R23 = svdup_n_u32(scales_u32_3);
 
-                            svint8_t S01_b = svreinterpret_s8_u32(S01); // s0 s1 s2 s3 ...
-                            svint8_t S23_b = svreinterpret_s8_u32(S23); // s4 s5 s6 s7 ...
-                            svint8_t R01_b = svreinterpret_s8_u32(R01); // r0 r1 r2 r3 ...
-                            svint8_t R23_b = svreinterpret_s8_u32(R23); // r4 r5 r6 r7 ...
+                            svint8_t S01_b = svreinterpret_s8_u32(S01);
+                            svint8_t S23_b = svreinterpret_s8_u32(S23);
+                            svint8_t R01_b = svreinterpret_s8_u32(R01);
+                            svint8_t R23_b = svreinterpret_s8_u32(R23);
 
                             svint32_t S01_d = svunpklo_s32(svunpklo_s16(svzip1_s8(S01_b, S01_b)));
-                            // s0 s0 s1 s1 s2 s2 s3 s3 ...
-
                             svint32_t R01_d = svunpklo_s32(svunpklo_s16(svzip1_s8(R01_b, R01_b)));
-                            // r0 r0 r1 r1 r2 r2 r3 r3 ...
-
                             svint32_t S23_d = svunpklo_s32(svunpklo_s16(svzip1_s8(S23_b, S23_b)));
-                            // s4 s4 s5 s5 s6 s6 s7 s7 ...
-
                             svint32_t R23_d = svunpklo_s32(svunpklo_s16(svzip1_s8(R23_b, R23_b)));
-                            // r4 r4 r5 r5 r6 r6 r7 r7 ...
 
                             block_scale_0 = svtbl_s32(svzip1_s32(S01_d, R01_d), idx);
-                            // s0 s0 s1 s1 r0 r0 r1 r1
-
                             block_scale_1 = svtbl_s32(svzip2_s32(S01_d, R01_d), idx);
-                            // s2 s2 s3 s3 r2 r2 r3 r3
-
                             block_scale_2 = svtbl_s32(svzip1_s32(S23_d, R23_d), idx);
-                            // s4 s4 s5 s5 r4 r4 r5 r5
-
                             block_scale_3 = svtbl_s32(svzip2_s32(S23_d, R23_d), idx);
-                            // s6 s6 s7 s7 r6 r6 r7 r7
                         }
 
-                        // q8_ptr[b].qs has interleaved Q8 rows (01, 23)
-                        // const int8_t * q8_base = q8_ptr[b].qs + sb * 256;
-                        const int8_t * q8_base_1 = q8_ptr_1[b].qs + sb * 256;
+                        const int8_t * q8_base_1 = q8_ptr[b].qs + sb * 256;
 
                         // Load 32-byte per row pair, 1 subblock each time
                         // predicate for activating higher lanes for 16 int8 elements
@@ -3215,10 +3198,10 @@ void ggml_gemm_q4_K_8x8_q8_K(int n,
                             sb_acc_0 = svdup_n_s32(0);
                             sb_acc_2 = svdup_n_s32(0);
 
-                            svuint8_t q4_qs_cp_00 = svld1rq_u8(svptrue_b8(), q4_ptr_1[b].qs + sb * QK_K + 16 * cp + 0);
-                            svuint8_t q4_qs_cp_01 = svld1rq_u8(svptrue_b8(), q4_ptr_1[b].qs + sb * QK_K + 16 * cp + 64);
-                            svuint8_t q4_qs_cp_02 = svld1rq_u8(svptrue_b8(), q4_ptr_1[b].qs + sb * QK_K + 16 * cp + 128);
-                            svuint8_t q4_qs_cp_03 = svld1rq_u8(svptrue_b8(), q4_ptr_1[b].qs + sb * QK_K + 16 * cp + 192);
+                            svuint8_t q4_qs_cp_00 = svld1rq_u8(svptrue_b8(), q4_ptr[b].qs + sb * QK_K + 16 * cp + 0);
+                            svuint8_t q4_qs_cp_01 = svld1rq_u8(svptrue_b8(), q4_ptr[b].qs + sb * QK_K + 16 * cp + 64);
+                            svuint8_t q4_qs_cp_02 = svld1rq_u8(svptrue_b8(), q4_ptr[b].qs + sb * QK_K + 16 * cp + 128);
+                            svuint8_t q4_qs_cp_03 = svld1rq_u8(svptrue_b8(), q4_ptr[b].qs + sb * QK_K + 16 * cp + 192);
 
                             svint8_t q4_nibbles_00 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_u8_m(ph16, q4_qs_cp_00, m4b_1), 4));
                             svint8_t q4_nibbles_01 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_u8_m(ph16, q4_qs_cp_01, m4b_1), 4));
@@ -3269,7 +3252,6 @@ void ggml_gemm_q4_K_8x8_q8_K(int n,
 
                 } // for sb
 
-                // acc[0..3] // acc[4..7]
                 acc_00 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_00, svext_s32(acc_00, acc_00, 4));
                 acc_11 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_11, svext_s32(acc_11, acc_11, 4));
                 acc_22 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_22, svext_s32(acc_22, acc_22, 4));
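
Note on the q4_nibbles_* lines kept as context above: after the quadword-replicating svld1rq_u8 load, the merge-predicated svand_u8_m(..., m4b_1) keeps the low 4-bit values in one half of the vector while svlsr_n_u8_m(..., 4) produces the high 4-bit values in the other half, since every packed Q4 byte carries two quantized values. A minimal scalar sketch of that unpacking idea (illustrative only; it does not reproduce the interleaved block_q4_Kx8 layout or the SVE predication used in the kernel):

#include <cstdint>
#include <cstdio>

int main() {
    // Each packed byte stores two 4-bit quantized values.
    const uint8_t packed[4] = { 0x4A, 0x3F, 0x70, 0x9C };
    for (int i = 0; i < 4; ++i) {
        const int lo = packed[i] & 0x0F; // low nibble, same role as the m4b_1 = 0x0f mask
        const int hi = packed[i] >> 4;   // high nibble, same role as the logical shift right by 4
        printf("byte %d: lo=%d hi=%d\n", i, lo, hi);
    }
    return 0;
}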