diff --git a/ggml/src/ggml-cpu/arch/arm/repack.cpp b/ggml/src/ggml-cpu/arch/arm/repack.cpp index b545c2586d..ea750d93c4 100644 --- a/ggml/src/ggml-cpu/arch/arm/repack.cpp +++ b/ggml/src/ggml-cpu/arch/arm/repack.cpp @@ -3751,62 +3751,75 @@ void ggml_gemm_q6_K_8x4_q8_K(int n, // 4 rows * 16 elements per scale // 4 reads of 16 bytes each constexpr int reads_per_sb = 4; + int8x16_t q8_l[reads_per_sb]; + int8x16_t q8_h[reads_per_sb]; for (int k = 0; k < reads_per_sb; k++) { - const int8x16_t q8_l = vld1q_s8(q8_base_l + 16 * k); - const int8x16_t q8_h = vld1q_s8(q8_base_h + 16 * k); + q8_l[k] = vld1q_s8(q8_base_l + 16 * k); + q8_h[k] = vld1q_s8(q8_base_h + 16 * k); + } - const int ql_off_base = sb * QK_K / 2 + k * 32; - const int qh_off_base = ql_off_base & 255; + const int ql_off_base = sb * QK_K / 2; + const int qh_off_base = ql_off_base & 255; - uint8x16_t q6_ql_0123 = vld1q_u8(ql_base + ql_off_base); - uint8x16_t q6_ql_4567 = vld1q_u8(ql_base + ql_off_base + 16); - uint8x16_t q6_qh_0123 = vld1q_u8(qh_base + qh_off_base); - uint8x16_t q6_qh_4567 = vld1q_u8(qh_base + qh_off_base + 16); + uint8x16_t q6_ql_0123[reads_per_sb]; + uint8x16_t q6_ql_4567[reads_per_sb]; + uint8x16_t q6_qh_0123[reads_per_sb]; + uint8x16_t q6_qh_4567[reads_per_sb]; - if (sb > 1) { - q6_qh_0123 = vshrq_n_u8(q6_qh_0123, 2); - q6_qh_4567 = vshrq_n_u8(q6_qh_4567, 2); + for (int k = 0; k < reads_per_sb; k++) { + q6_ql_0123[k] = vld1q_u8(ql_base + ql_off_base + k * 32); + q6_ql_4567[k] = vld1q_u8(ql_base + ql_off_base + k * 32 + 16); + q6_qh_0123[k] = vld1q_u8(qh_base + qh_off_base + k * 32); + q6_qh_4567[k] = vld1q_u8(qh_base + qh_off_base + k * 32 + 16); + } + + if (sb > 1) { + for (int k = 0; k < reads_per_sb; k++) { + q6_qh_0123[k] = vshrq_n_u8(q6_qh_0123[k], 2); + q6_qh_4567[k] = vshrq_n_u8(q6_qh_4567[k], 2); } + } + for (int k = 0; k < reads_per_sb; k++) { // q = (ql | qh) - 32 - const uint8x16_t hbit_lo_0123 = vandq_u8(q6_qh_0123, mask_lo); - const uint8x16_t hbit_hi_0123 = vandq_u8(q6_qh_0123, mask_hi); - const uint8x16_t hbit_lo_4567 = vandq_u8(q6_qh_4567, mask_lo); - const uint8x16_t hbit_hi_4567 = vandq_u8(q6_qh_4567, mask_hi); + const uint8x16_t hbit_lo_0123 = vandq_u8(q6_qh_0123[k], mask_lo); + const uint8x16_t hbit_hi_0123 = vandq_u8(q6_qh_0123[k], mask_hi); + const uint8x16_t hbit_lo_4567 = vandq_u8(q6_qh_4567[k], mask_lo); + const uint8x16_t hbit_hi_4567 = vandq_u8(q6_qh_4567[k], mask_hi); const int8x16_t q6_0123_lo = vsubq_s8( - vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(q6_ql_0123, m4b), hbit_lo_0123, 4)), + vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(q6_ql_0123[k], m4b), hbit_lo_0123, 4)), m32s); const int8x16_t q6_0123_hi = vsubq_s8( - vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6_ql_0123, 4), hbit_hi_0123)), + vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6_ql_0123[k], 4), hbit_hi_0123)), m32s); - acc_lo[0] = vdotq_laneq_s32(acc_lo[0], q6_0123_lo, q8_l, 0); // 0..3 r0 c0123 - acc_lo[1] = vdotq_laneq_s32(acc_lo[1], q6_0123_lo, q8_l, 1); // 0..3 r1 c0123 - acc_lo[2] = vdotq_laneq_s32(acc_lo[2], q6_0123_lo, q8_l, 2); // 0..3 r2 c0123 - acc_lo[3] = vdotq_laneq_s32(acc_lo[3], q6_0123_lo, q8_l, 3); // 0..3 r3 c0123 + acc_lo[0] = vdotq_laneq_s32(acc_lo[0], q6_0123_lo, q8_l[k], 0); // 0..3 r0 c0123 + acc_lo[1] = vdotq_laneq_s32(acc_lo[1], q6_0123_lo, q8_l[k], 1); // 0..3 r1 c0123 + acc_lo[2] = vdotq_laneq_s32(acc_lo[2], q6_0123_lo, q8_l[k], 2); // 0..3 r2 c0123 + acc_lo[3] = vdotq_laneq_s32(acc_lo[3], q6_0123_lo, q8_l[k], 3); // 0..3 r3 c0123 - acc_hi[0] = vdotq_laneq_s32(acc_hi[0], q6_0123_hi, q8_h, 0); // 64..67 r0 c0123 - acc_hi[1] = vdotq_laneq_s32(acc_hi[1], q6_0123_hi, q8_h, 1); // 64..67 r1 c0123 - acc_hi[2] = vdotq_laneq_s32(acc_hi[2], q6_0123_hi, q8_h, 2); // 64..67 r2 c0123 - acc_hi[3] = vdotq_laneq_s32(acc_hi[3], q6_0123_hi, q8_h, 3); // 64..67 r3 c0123 + acc_hi[0] = vdotq_laneq_s32(acc_hi[0], q6_0123_hi, q8_h[k], 0); // 64..67 r0 c0123 + acc_hi[1] = vdotq_laneq_s32(acc_hi[1], q6_0123_hi, q8_h[k], 1); // 64..67 r1 c0123 + acc_hi[2] = vdotq_laneq_s32(acc_hi[2], q6_0123_hi, q8_h[k], 2); // 64..67 r2 c0123 + acc_hi[3] = vdotq_laneq_s32(acc_hi[3], q6_0123_hi, q8_h[k], 3); // 64..67 r3 c0123 const int8x16_t q6_4567_lo = vsubq_s8( - vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(q6_ql_4567, m4b), hbit_lo_4567, 4)), + vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(q6_ql_4567[k], m4b), hbit_lo_4567, 4)), m32s); const int8x16_t q6_4567_hi = vsubq_s8( - vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6_ql_4567, 4), hbit_hi_4567)), + vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6_ql_4567[k], 4), hbit_hi_4567)), m32s); - acc_lo[4] = vdotq_laneq_s32(acc_lo[4], q6_4567_lo, q8_l, 0); // 0..3 r0 c4567 - acc_lo[5] = vdotq_laneq_s32(acc_lo[5], q6_4567_lo, q8_l, 1); // 0..3 r1 c4567 - acc_lo[6] = vdotq_laneq_s32(acc_lo[6], q6_4567_lo, q8_l, 2); // 0..3 r2 c4567 - acc_lo[7] = vdotq_laneq_s32(acc_lo[7], q6_4567_lo, q8_l, 3); // 0..3 r3 c4567 + acc_lo[4] = vdotq_laneq_s32(acc_lo[4], q6_4567_lo, q8_l[k], 0); // 0..3 r0 c4567 + acc_lo[5] = vdotq_laneq_s32(acc_lo[5], q6_4567_lo, q8_l[k], 1); // 0..3 r1 c4567 + acc_lo[6] = vdotq_laneq_s32(acc_lo[6], q6_4567_lo, q8_l[k], 2); // 0..3 r2 c4567 + acc_lo[7] = vdotq_laneq_s32(acc_lo[7], q6_4567_lo, q8_l[k], 3); // 0..3 r3 c4567 - acc_hi[4] = vdotq_laneq_s32(acc_hi[4], q6_4567_hi, q8_h, 0); // 64..67 r0 c4567 - acc_hi[5] = vdotq_laneq_s32(acc_hi[5], q6_4567_hi, q8_h, 1); // 64..67 r1 c4567 - acc_hi[6] = vdotq_laneq_s32(acc_hi[6], q6_4567_hi, q8_h, 2); // 64..67 r2 c4567 - acc_hi[7] = vdotq_laneq_s32(acc_hi[7], q6_4567_hi, q8_h, 3); // 64..67 r3 c4567 + acc_hi[4] = vdotq_laneq_s32(acc_hi[4], q6_4567_hi, q8_h[k], 0); // 64..67 r0 c4567 + acc_hi[5] = vdotq_laneq_s32(acc_hi[5], q6_4567_hi, q8_h[k], 1); // 64..67 r1 c4567 + acc_hi[6] = vdotq_laneq_s32(acc_hi[6], q6_4567_hi, q8_h[k], 2); // 64..67 r2 c4567 + acc_hi[7] = vdotq_laneq_s32(acc_hi[7], q6_4567_hi, q8_h[k], 3); // 64..67 r3 c4567 } // Scale and bias