diff --git a/ggml/src/ggml-cpu/arch/arm/repack.cpp b/ggml/src/ggml-cpu/arch/arm/repack.cpp index b0d25cdc47..c07e28f37d 100644 --- a/ggml/src/ggml-cpu/arch/arm/repack.cpp +++ b/ggml/src/ggml-cpu/arch/arm/repack.cpp @@ -3086,10 +3086,8 @@ void ggml_gemm_q4_K_8x8_q8_K(int n, vst1q_s32(&bsums_arr32[q8_row][4], v32_hi); } - svint32_t sb_acc_00 = svdup_n_s32(0); - svint32_t sb_acc_11 = svdup_n_s32(0); - svint32_t sb_acc_22 = svdup_n_s32(0); - svint32_t sb_acc_33 = svdup_n_s32(0); + svint32_t sb_acc_0 = svdup_n_s32(0); + svint32_t sb_acc_2 = svdup_n_s32(0); svint32_t acc_00 = svdup_n_s32(0); svint32_t acc_11 = svdup_n_s32(0); @@ -3200,69 +3198,59 @@ void ggml_gemm_q4_K_8x8_q8_K(int n, // predicate for activating lower lanes for 16 int8 elements const svbool_t pl16 = svnot_b_z(svptrue_b8(), ph16); + svint8_t q8_qs_0 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 0), svld1_s8(pl16, q8_base_1 + 112)); + svint8_t q8_qs_2 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 32), svld1_s8(pl16, q8_base_1 + 144)); + svint8_t q8_qs_4 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 64), svld1_s8(pl16, q8_base_1 + 176)); + svint8_t q8_qs_6 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 96), svld1_s8(pl16, q8_base_1 + 208)); - svint8_t q8_qs_00 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 0), svld1_s8(pl16, q8_base_1 + 16)); - svint8_t q8_qs_02 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 64), svld1_s8(pl16, q8_base_1 + 80)); - svint8_t q8_qs_04 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 128), svld1_s8(pl16, q8_base_1 + 144)); - svint8_t q8_qs_06 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 192), svld1_s8(pl16, q8_base_1 + 208)); - - svint8_t q8_qs_10 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 16), svld1_s8(pl16, q8_base_1 + 32)); - svint8_t q8_qs_12 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 80), svld1_s8(pl16, q8_base_1 + 96)); - svint8_t q8_qs_14 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 144), svld1_s8(pl16, q8_base_1 + 160)); - svint8_t q8_qs_16 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 208), svld1_s8(pl16, q8_base_1 + 224)); + svint8_t q8_qs_1 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 16), svld1_s8(pl16, q8_base_1 + 128)); + svint8_t q8_qs_3 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 48), svld1_s8(pl16, q8_base_1 + 160)); + svint8_t q8_qs_5 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 80), svld1_s8(pl16, q8_base_1 + 192)); + svint8_t q8_qs_7 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 112), svld1_s8(pl16, q8_base_1 + 224)); // Q4s columns iterated in pairs (01, 23, 45, 67) for (int cp = 0; cp < ncols_interleaved / 2; cp++) { - sb_acc_00 = svdup_n_s32(0); - sb_acc_11 = svdup_n_s32(0); - sb_acc_22 = svdup_n_s32(0); - sb_acc_33 = svdup_n_s32(0); + sb_acc_0 = svdup_n_s32(0); + sb_acc_2 = svdup_n_s32(0); - svbool_t pg = svptrue_pat_b8(SV_VL16); + svuint8_t q4_qs_cp_00 = svld1rq_u8(svptrue_b8(), q4_ptr_1[b].qs + sb * QK_K + 16 * cp + 0); + svuint8_t q4_qs_cp_01 = svld1rq_u8(svptrue_b8(), q4_ptr_1[b].qs + sb * QK_K + 16 * cp + 64); + svuint8_t q4_qs_cp_02 = svld1rq_u8(svptrue_b8(), q4_ptr_1[b].qs + sb * QK_K + 16 * cp + 128); + svuint8_t q4_qs_cp_03 = svld1rq_u8(svptrue_b8(), q4_ptr_1[b].qs + sb * QK_K + 16 * cp + 192); - svuint8_t q4_qs_cp_0 = svadd_u8_m(svptrue_b8(),svld1_u8(pg, q4_ptr_1[b].qs + sb * QK_K + 16 * cp + 0), svld1_u8(pl16, q4_ptr_1[b].qs + sb * QK_K + 16 * cp + 48)); - svuint8_t q4_qs_cp_1 = svadd_u8_m(svptrue_b8(),svld1_u8(pg, q4_ptr_1[b].qs + sb * QK_K + 16 * cp + 128), svld1_u8(pl16, q4_ptr_1[b].qs + sb * QK_K + 16 * cp + 176)); + svint8_t q4_nibbles_00 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_u8_m(ph16, q4_qs_cp_00, m4b_1), 4)); + svint8_t q4_nibbles_01 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_u8_m(ph16, q4_qs_cp_01, m4b_1), 4)); + svint8_t q4_nibbles_02 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_u8_m(ph16, q4_qs_cp_02, m4b_1), 4)); + svint8_t q4_nibbles_03 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_u8_m(ph16, q4_qs_cp_03, m4b_1), 4)); - svint8_t q4_nibbles_0 = svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), q4_qs_cp_0, m4b_1)); - svint8_t q4_nibbles_1 = svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), q4_qs_cp_1, m4b_1)); - svint8_t q4_nibbles_2 = svreinterpret_s8_u8(svlsr_n_u8_m(svptrue_b8(), q4_qs_cp_0, 4)); - svint8_t q4_nibbles_3 = svreinterpret_s8_u8(svlsr_n_u8_m(svptrue_b8(), q4_qs_cp_1, 4)); + sb_acc_0 = svmmla_s32(sb_acc_0, q4_nibbles_00, q8_qs_0); + sb_acc_0 = svmmla_s32(sb_acc_0, q4_nibbles_01, q8_qs_2); - sb_acc_00 = svmmla_s32(sb_acc_00, q4_nibbles_0, q8_qs_00); - sb_acc_00 = svmmla_s32(sb_acc_00, q4_nibbles_1, q8_qs_02); + sb_acc_0 = svmmla_s32(sb_acc_0, q4_nibbles_02, q8_qs_4); + sb_acc_0 = svmmla_s32(sb_acc_0, q4_nibbles_03, q8_qs_6); - sb_acc_11 = svmmla_s32(sb_acc_11, q4_nibbles_2, q8_qs_04); - sb_acc_11 = svmmla_s32(sb_acc_11, q4_nibbles_3, q8_qs_06); + sb_acc_2 = svmmla_s32(sb_acc_2, q4_nibbles_00, q8_qs_1); + sb_acc_2 = svmmla_s32(sb_acc_2, q4_nibbles_01, q8_qs_3); - sb_acc_22 = svmmla_s32(sb_acc_22, q4_nibbles_0, q8_qs_10); - sb_acc_22 = svmmla_s32(sb_acc_22, q4_nibbles_1, q8_qs_12); - - sb_acc_33 = svmmla_s32(sb_acc_33, q4_nibbles_2, q8_qs_14); - sb_acc_33 = svmmla_s32(sb_acc_33, q4_nibbles_3, q8_qs_16); - - sb_acc_00 = svadd_s32_z(svptrue_pat_b32(SV_VL4), sb_acc_00, svext_s32(sb_acc_00, sb_acc_00, 4)); - sb_acc_11 = svext_s32(svdup_s32(0), svadd_s32_z(svptrue_pat_b32(SV_VL4), sb_acc_11, svext_s32(sb_acc_11, sb_acc_11, 4)), 4); - sb_acc_22 = svadd_s32_z(svptrue_pat_b32(SV_VL4), sb_acc_22, svext_s32(sb_acc_22, sb_acc_22, 4)); - sb_acc_33 = svext_s32(svdup_s32(0), svadd_s32_z(svptrue_pat_b32(SV_VL4), sb_acc_33, svext_s32(sb_acc_33, sb_acc_33, 4)), 4); - sb_acc_00 = svadd_s32_m(svptrue_b32(), sb_acc_00, sb_acc_11); - sb_acc_22 = svadd_s32_m(svptrue_b32(), sb_acc_22, sb_acc_33); + sb_acc_2 = svmmla_s32(sb_acc_2, q4_nibbles_02, q8_qs_5); + sb_acc_2 = svmmla_s32(sb_acc_2, q4_nibbles_03, q8_qs_7); if(cp == 0) { - acc_00 = svmla_s32_m(svptrue_b32(), acc_00, sb_acc_00, block_scale_0); - acc_44 = svmla_s32_m(svptrue_b32(), acc_44, sb_acc_22, block_scale_0); + acc_00 = svmla_s32_m(svptrue_b32(), acc_00, sb_acc_0, block_scale_0); + acc_44 = svmla_s32_m(svptrue_b32(), acc_44, sb_acc_2, block_scale_0); } if(cp == 1) { - acc_11 = svmla_s32_m(svptrue_b32(), acc_11, sb_acc_00, block_scale_1); - acc_55 = svmla_s32_m(svptrue_b32(), acc_55, sb_acc_22, block_scale_1); + acc_11 = svmla_s32_m(svptrue_b32(), acc_11, sb_acc_0, block_scale_1); + acc_55 = svmla_s32_m(svptrue_b32(), acc_55, sb_acc_2, block_scale_1); } if(cp == 2) { - acc_22 = svmla_s32_m(svptrue_b32(), acc_22, sb_acc_00, block_scale_2); - acc_66 = svmla_s32_m(svptrue_b32(), acc_66, sb_acc_22, block_scale_2); + acc_22 = svmla_s32_m(svptrue_b32(), acc_22, sb_acc_0, block_scale_2); + acc_66 = svmla_s32_m(svptrue_b32(), acc_66, sb_acc_2, block_scale_2); } if(cp == 3) { - acc_33 = svmla_s32_m(svptrue_b32(), acc_33, sb_acc_00, block_scale_3); - acc_77 = svmla_s32_m(svptrue_b32(), acc_77, sb_acc_22, block_scale_3); + acc_33 = svmla_s32_m(svptrue_b32(), acc_33, sb_acc_0, block_scale_3); + acc_77 = svmla_s32_m(svptrue_b32(), acc_77, sb_acc_2, block_scale_3); } }