Updated repack.cpp

This commit is contained in:
Vithule, Prashant 2026-01-13 04:41:28 +00:00 committed by Vithulep
parent c74d605db4
commit cde62986b6
1 changed file with 36 additions and 48 deletions

View File

@@ -3086,10 +3086,8 @@ void ggml_gemm_q4_K_8x8_q8_K(int n,
vst1q_s32(&bsums_arr32[q8_row][4], v32_hi);
}
svint32_t sb_acc_00 = svdup_n_s32(0);
svint32_t sb_acc_11 = svdup_n_s32(0);
svint32_t sb_acc_22 = svdup_n_s32(0);
svint32_t sb_acc_33 = svdup_n_s32(0);
svint32_t sb_acc_0 = svdup_n_s32(0);
svint32_t sb_acc_2 = svdup_n_s32(0);
svint32_t acc_00 = svdup_n_s32(0);
svint32_t acc_11 = svdup_n_s32(0);
@@ -3200,69 +3198,59 @@ void ggml_gemm_q4_K_8x8_q8_K(int n,
// predicate for activating lower lanes for 16 int8 elements
const svbool_t pl16 = svnot_b_z(svptrue_b8(), ph16);
svint8_t q8_qs_0 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 0), svld1_s8(pl16, q8_base_1 + 112));
svint8_t q8_qs_2 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 32), svld1_s8(pl16, q8_base_1 + 144));
svint8_t q8_qs_4 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 64), svld1_s8(pl16, q8_base_1 + 176));
svint8_t q8_qs_6 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 96), svld1_s8(pl16, q8_base_1 + 208));
svint8_t q8_qs_00 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 0), svld1_s8(pl16, q8_base_1 + 16));
svint8_t q8_qs_02 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 64), svld1_s8(pl16, q8_base_1 + 80));
svint8_t q8_qs_04 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 128), svld1_s8(pl16, q8_base_1 + 144));
svint8_t q8_qs_06 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 192), svld1_s8(pl16, q8_base_1 + 208));
svint8_t q8_qs_10 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 16), svld1_s8(pl16, q8_base_1 + 32));
svint8_t q8_qs_12 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 80), svld1_s8(pl16, q8_base_1 + 96));
svint8_t q8_qs_14 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 144), svld1_s8(pl16, q8_base_1 + 160));
svint8_t q8_qs_16 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 208), svld1_s8(pl16, q8_base_1 + 224));
svint8_t q8_qs_1 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 16), svld1_s8(pl16, q8_base_1 + 128));
svint8_t q8_qs_3 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 48), svld1_s8(pl16, q8_base_1 + 160));
svint8_t q8_qs_5 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 80), svld1_s8(pl16, q8_base_1 + 192));
svint8_t q8_qs_7 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 112), svld1_s8(pl16, q8_base_1 + 224));
// Q4s columns iterated in pairs (01, 23, 45, 67)
for (int cp = 0; cp < ncols_interleaved / 2; cp++) {
sb_acc_00 = svdup_n_s32(0);
sb_acc_11 = svdup_n_s32(0);
sb_acc_22 = svdup_n_s32(0);
sb_acc_33 = svdup_n_s32(0);
sb_acc_0 = svdup_n_s32(0);
sb_acc_2 = svdup_n_s32(0);
svbool_t pg = svptrue_pat_b8(SV_VL16);
svuint8_t q4_qs_cp_00 = svld1rq_u8(svptrue_b8(), q4_ptr_1[b].qs + sb * QK_K + 16 * cp + 0);
svuint8_t q4_qs_cp_01 = svld1rq_u8(svptrue_b8(), q4_ptr_1[b].qs + sb * QK_K + 16 * cp + 64);
svuint8_t q4_qs_cp_02 = svld1rq_u8(svptrue_b8(), q4_ptr_1[b].qs + sb * QK_K + 16 * cp + 128);
svuint8_t q4_qs_cp_03 = svld1rq_u8(svptrue_b8(), q4_ptr_1[b].qs + sb * QK_K + 16 * cp + 192);
svuint8_t q4_qs_cp_0 = svadd_u8_m(svptrue_b8(),svld1_u8(pg, q4_ptr_1[b].qs + sb * QK_K + 16 * cp + 0), svld1_u8(pl16, q4_ptr_1[b].qs + sb * QK_K + 16 * cp + 48));
svuint8_t q4_qs_cp_1 = svadd_u8_m(svptrue_b8(),svld1_u8(pg, q4_ptr_1[b].qs + sb * QK_K + 16 * cp + 128), svld1_u8(pl16, q4_ptr_1[b].qs + sb * QK_K + 16 * cp + 176));
svint8_t q4_nibbles_00 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_u8_m(ph16, q4_qs_cp_00, m4b_1), 4));
svint8_t q4_nibbles_01 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_u8_m(ph16, q4_qs_cp_01, m4b_1), 4));
svint8_t q4_nibbles_02 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_u8_m(ph16, q4_qs_cp_02, m4b_1), 4));
svint8_t q4_nibbles_03 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_u8_m(ph16, q4_qs_cp_03, m4b_1), 4));
svint8_t q4_nibbles_0 = svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), q4_qs_cp_0, m4b_1));
svint8_t q4_nibbles_1 = svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), q4_qs_cp_1, m4b_1));
svint8_t q4_nibbles_2 = svreinterpret_s8_u8(svlsr_n_u8_m(svptrue_b8(), q4_qs_cp_0, 4));
svint8_t q4_nibbles_3 = svreinterpret_s8_u8(svlsr_n_u8_m(svptrue_b8(), q4_qs_cp_1, 4));
sb_acc_0 = svmmla_s32(sb_acc_0, q4_nibbles_00, q8_qs_0);
sb_acc_0 = svmmla_s32(sb_acc_0, q4_nibbles_01, q8_qs_2);
sb_acc_00 = svmmla_s32(sb_acc_00, q4_nibbles_0, q8_qs_00);
sb_acc_00 = svmmla_s32(sb_acc_00, q4_nibbles_1, q8_qs_02);
sb_acc_0 = svmmla_s32(sb_acc_0, q4_nibbles_02, q8_qs_4);
sb_acc_0 = svmmla_s32(sb_acc_0, q4_nibbles_03, q8_qs_6);
sb_acc_11 = svmmla_s32(sb_acc_11, q4_nibbles_2, q8_qs_04);
sb_acc_11 = svmmla_s32(sb_acc_11, q4_nibbles_3, q8_qs_06);
sb_acc_2 = svmmla_s32(sb_acc_2, q4_nibbles_00, q8_qs_1);
sb_acc_2 = svmmla_s32(sb_acc_2, q4_nibbles_01, q8_qs_3);
sb_acc_22 = svmmla_s32(sb_acc_22, q4_nibbles_0, q8_qs_10);
sb_acc_22 = svmmla_s32(sb_acc_22, q4_nibbles_1, q8_qs_12);
sb_acc_33 = svmmla_s32(sb_acc_33, q4_nibbles_2, q8_qs_14);
sb_acc_33 = svmmla_s32(sb_acc_33, q4_nibbles_3, q8_qs_16);
sb_acc_00 = svadd_s32_z(svptrue_pat_b32(SV_VL4), sb_acc_00, svext_s32(sb_acc_00, sb_acc_00, 4));
sb_acc_11 = svext_s32(svdup_s32(0), svadd_s32_z(svptrue_pat_b32(SV_VL4), sb_acc_11, svext_s32(sb_acc_11, sb_acc_11, 4)), 4);
sb_acc_22 = svadd_s32_z(svptrue_pat_b32(SV_VL4), sb_acc_22, svext_s32(sb_acc_22, sb_acc_22, 4));
sb_acc_33 = svext_s32(svdup_s32(0), svadd_s32_z(svptrue_pat_b32(SV_VL4), sb_acc_33, svext_s32(sb_acc_33, sb_acc_33, 4)), 4);
sb_acc_00 = svadd_s32_m(svptrue_b32(), sb_acc_00, sb_acc_11);
sb_acc_22 = svadd_s32_m(svptrue_b32(), sb_acc_22, sb_acc_33);
sb_acc_2 = svmmla_s32(sb_acc_2, q4_nibbles_02, q8_qs_5);
sb_acc_2 = svmmla_s32(sb_acc_2, q4_nibbles_03, q8_qs_7);
if(cp == 0) {
acc_00 = svmla_s32_m(svptrue_b32(), acc_00, sb_acc_00, block_scale_0);
acc_44 = svmla_s32_m(svptrue_b32(), acc_44, sb_acc_22, block_scale_0);
acc_00 = svmla_s32_m(svptrue_b32(), acc_00, sb_acc_0, block_scale_0);
acc_44 = svmla_s32_m(svptrue_b32(), acc_44, sb_acc_2, block_scale_0);
}
if(cp == 1) {
acc_11 = svmla_s32_m(svptrue_b32(), acc_11, sb_acc_00, block_scale_1);
acc_55 = svmla_s32_m(svptrue_b32(), acc_55, sb_acc_22, block_scale_1);
acc_11 = svmla_s32_m(svptrue_b32(), acc_11, sb_acc_0, block_scale_1);
acc_55 = svmla_s32_m(svptrue_b32(), acc_55, sb_acc_2, block_scale_1);
}
if(cp == 2) {
acc_22 = svmla_s32_m(svptrue_b32(), acc_22, sb_acc_00, block_scale_2);
acc_66 = svmla_s32_m(svptrue_b32(), acc_66, sb_acc_22, block_scale_2);
acc_22 = svmla_s32_m(svptrue_b32(), acc_22, sb_acc_0, block_scale_2);
acc_66 = svmla_s32_m(svptrue_b32(), acc_66, sb_acc_2, block_scale_2);
}
if(cp == 3) {
acc_33 = svmla_s32_m(svptrue_b32(), acc_33, sb_acc_00, block_scale_3);
acc_77 = svmla_s32_m(svptrue_b32(), acc_77, sb_acc_22, block_scale_3);
acc_33 = svmla_s32_m(svptrue_b32(), acc_33, sb_acc_0, block_scale_3);
acc_77 = svmla_s32_m(svptrue_b32(), acc_77, sb_acc_2, block_scale_3);
}
}