Updated repack.cpp
This commit is contained in:
parent
c74d605db4
commit
cde62986b6
|
|
@ -3086,10 +3086,8 @@ void ggml_gemm_q4_K_8x8_q8_K(int n,
|
|||
vst1q_s32(&bsums_arr32[q8_row][4], v32_hi);
|
||||
}
|
||||
|
||||
svint32_t sb_acc_00 = svdup_n_s32(0);
|
||||
svint32_t sb_acc_11 = svdup_n_s32(0);
|
||||
svint32_t sb_acc_22 = svdup_n_s32(0);
|
||||
svint32_t sb_acc_33 = svdup_n_s32(0);
|
||||
svint32_t sb_acc_0 = svdup_n_s32(0);
|
||||
svint32_t sb_acc_2 = svdup_n_s32(0);
|
||||
|
||||
svint32_t acc_00 = svdup_n_s32(0);
|
||||
svint32_t acc_11 = svdup_n_s32(0);
|
||||
|
|
@ -3200,69 +3198,59 @@ void ggml_gemm_q4_K_8x8_q8_K(int n,
|
|||
// predicate for activating lower lanes for 16 int8 elements
|
||||
const svbool_t pl16 = svnot_b_z(svptrue_b8(), ph16);
|
||||
|
||||
svint8_t q8_qs_0 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 0), svld1_s8(pl16, q8_base_1 + 112));
|
||||
svint8_t q8_qs_2 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 32), svld1_s8(pl16, q8_base_1 + 144));
|
||||
svint8_t q8_qs_4 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 64), svld1_s8(pl16, q8_base_1 + 176));
|
||||
svint8_t q8_qs_6 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 96), svld1_s8(pl16, q8_base_1 + 208));
|
||||
|
||||
svint8_t q8_qs_00 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 0), svld1_s8(pl16, q8_base_1 + 16));
|
||||
svint8_t q8_qs_02 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 64), svld1_s8(pl16, q8_base_1 + 80));
|
||||
svint8_t q8_qs_04 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 128), svld1_s8(pl16, q8_base_1 + 144));
|
||||
svint8_t q8_qs_06 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 192), svld1_s8(pl16, q8_base_1 + 208));
|
||||
|
||||
svint8_t q8_qs_10 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 16), svld1_s8(pl16, q8_base_1 + 32));
|
||||
svint8_t q8_qs_12 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 80), svld1_s8(pl16, q8_base_1 + 96));
|
||||
svint8_t q8_qs_14 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 144), svld1_s8(pl16, q8_base_1 + 160));
|
||||
svint8_t q8_qs_16 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 208), svld1_s8(pl16, q8_base_1 + 224));
|
||||
svint8_t q8_qs_1 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 16), svld1_s8(pl16, q8_base_1 + 128));
|
||||
svint8_t q8_qs_3 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 48), svld1_s8(pl16, q8_base_1 + 160));
|
||||
svint8_t q8_qs_5 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 80), svld1_s8(pl16, q8_base_1 + 192));
|
||||
svint8_t q8_qs_7 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 112), svld1_s8(pl16, q8_base_1 + 224));
|
||||
|
||||
// Q4s columns iterated in pairs (01, 23, 45, 67)
|
||||
for (int cp = 0; cp < ncols_interleaved / 2; cp++) {
|
||||
|
||||
sb_acc_00 = svdup_n_s32(0);
|
||||
sb_acc_11 = svdup_n_s32(0);
|
||||
sb_acc_22 = svdup_n_s32(0);
|
||||
sb_acc_33 = svdup_n_s32(0);
|
||||
sb_acc_0 = svdup_n_s32(0);
|
||||
sb_acc_2 = svdup_n_s32(0);
|
||||
|
||||
svbool_t pg = svptrue_pat_b8(SV_VL16);
|
||||
svuint8_t q4_qs_cp_00 = svld1rq_u8(svptrue_b8(), q4_ptr_1[b].qs + sb * QK_K + 16 * cp + 0);
|
||||
svuint8_t q4_qs_cp_01 = svld1rq_u8(svptrue_b8(), q4_ptr_1[b].qs + sb * QK_K + 16 * cp + 64);
|
||||
svuint8_t q4_qs_cp_02 = svld1rq_u8(svptrue_b8(), q4_ptr_1[b].qs + sb * QK_K + 16 * cp + 128);
|
||||
svuint8_t q4_qs_cp_03 = svld1rq_u8(svptrue_b8(), q4_ptr_1[b].qs + sb * QK_K + 16 * cp + 192);
|
||||
|
||||
svuint8_t q4_qs_cp_0 = svadd_u8_m(svptrue_b8(),svld1_u8(pg, q4_ptr_1[b].qs + sb * QK_K + 16 * cp + 0), svld1_u8(pl16, q4_ptr_1[b].qs + sb * QK_K + 16 * cp + 48));
|
||||
svuint8_t q4_qs_cp_1 = svadd_u8_m(svptrue_b8(),svld1_u8(pg, q4_ptr_1[b].qs + sb * QK_K + 16 * cp + 128), svld1_u8(pl16, q4_ptr_1[b].qs + sb * QK_K + 16 * cp + 176));
|
||||
svint8_t q4_nibbles_00 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_u8_m(ph16, q4_qs_cp_00, m4b_1), 4));
|
||||
svint8_t q4_nibbles_01 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_u8_m(ph16, q4_qs_cp_01, m4b_1), 4));
|
||||
svint8_t q4_nibbles_02 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_u8_m(ph16, q4_qs_cp_02, m4b_1), 4));
|
||||
svint8_t q4_nibbles_03 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_u8_m(ph16, q4_qs_cp_03, m4b_1), 4));
|
||||
|
||||
svint8_t q4_nibbles_0 = svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), q4_qs_cp_0, m4b_1));
|
||||
svint8_t q4_nibbles_1 = svreinterpret_s8_u8(svand_u8_m(svptrue_b8(), q4_qs_cp_1, m4b_1));
|
||||
svint8_t q4_nibbles_2 = svreinterpret_s8_u8(svlsr_n_u8_m(svptrue_b8(), q4_qs_cp_0, 4));
|
||||
svint8_t q4_nibbles_3 = svreinterpret_s8_u8(svlsr_n_u8_m(svptrue_b8(), q4_qs_cp_1, 4));
|
||||
sb_acc_0 = svmmla_s32(sb_acc_0, q4_nibbles_00, q8_qs_0);
|
||||
sb_acc_0 = svmmla_s32(sb_acc_0, q4_nibbles_01, q8_qs_2);
|
||||
|
||||
sb_acc_00 = svmmla_s32(sb_acc_00, q4_nibbles_0, q8_qs_00);
|
||||
sb_acc_00 = svmmla_s32(sb_acc_00, q4_nibbles_1, q8_qs_02);
|
||||
sb_acc_0 = svmmla_s32(sb_acc_0, q4_nibbles_02, q8_qs_4);
|
||||
sb_acc_0 = svmmla_s32(sb_acc_0, q4_nibbles_03, q8_qs_6);
|
||||
|
||||
sb_acc_11 = svmmla_s32(sb_acc_11, q4_nibbles_2, q8_qs_04);
|
||||
sb_acc_11 = svmmla_s32(sb_acc_11, q4_nibbles_3, q8_qs_06);
|
||||
sb_acc_2 = svmmla_s32(sb_acc_2, q4_nibbles_00, q8_qs_1);
|
||||
sb_acc_2 = svmmla_s32(sb_acc_2, q4_nibbles_01, q8_qs_3);
|
||||
|
||||
sb_acc_22 = svmmla_s32(sb_acc_22, q4_nibbles_0, q8_qs_10);
|
||||
sb_acc_22 = svmmla_s32(sb_acc_22, q4_nibbles_1, q8_qs_12);
|
||||
|
||||
sb_acc_33 = svmmla_s32(sb_acc_33, q4_nibbles_2, q8_qs_14);
|
||||
sb_acc_33 = svmmla_s32(sb_acc_33, q4_nibbles_3, q8_qs_16);
|
||||
|
||||
sb_acc_00 = svadd_s32_z(svptrue_pat_b32(SV_VL4), sb_acc_00, svext_s32(sb_acc_00, sb_acc_00, 4));
|
||||
sb_acc_11 = svext_s32(svdup_s32(0), svadd_s32_z(svptrue_pat_b32(SV_VL4), sb_acc_11, svext_s32(sb_acc_11, sb_acc_11, 4)), 4);
|
||||
sb_acc_22 = svadd_s32_z(svptrue_pat_b32(SV_VL4), sb_acc_22, svext_s32(sb_acc_22, sb_acc_22, 4));
|
||||
sb_acc_33 = svext_s32(svdup_s32(0), svadd_s32_z(svptrue_pat_b32(SV_VL4), sb_acc_33, svext_s32(sb_acc_33, sb_acc_33, 4)), 4);
|
||||
sb_acc_00 = svadd_s32_m(svptrue_b32(), sb_acc_00, sb_acc_11);
|
||||
sb_acc_22 = svadd_s32_m(svptrue_b32(), sb_acc_22, sb_acc_33);
|
||||
sb_acc_2 = svmmla_s32(sb_acc_2, q4_nibbles_02, q8_qs_5);
|
||||
sb_acc_2 = svmmla_s32(sb_acc_2, q4_nibbles_03, q8_qs_7);
|
||||
|
||||
if(cp == 0) {
|
||||
acc_00 = svmla_s32_m(svptrue_b32(), acc_00, sb_acc_00, block_scale_0);
|
||||
acc_44 = svmla_s32_m(svptrue_b32(), acc_44, sb_acc_22, block_scale_0);
|
||||
acc_00 = svmla_s32_m(svptrue_b32(), acc_00, sb_acc_0, block_scale_0);
|
||||
acc_44 = svmla_s32_m(svptrue_b32(), acc_44, sb_acc_2, block_scale_0);
|
||||
}
|
||||
if(cp == 1) {
|
||||
acc_11 = svmla_s32_m(svptrue_b32(), acc_11, sb_acc_00, block_scale_1);
|
||||
acc_55 = svmla_s32_m(svptrue_b32(), acc_55, sb_acc_22, block_scale_1);
|
||||
acc_11 = svmla_s32_m(svptrue_b32(), acc_11, sb_acc_0, block_scale_1);
|
||||
acc_55 = svmla_s32_m(svptrue_b32(), acc_55, sb_acc_2, block_scale_1);
|
||||
}
|
||||
if(cp == 2) {
|
||||
acc_22 = svmla_s32_m(svptrue_b32(), acc_22, sb_acc_00, block_scale_2);
|
||||
acc_66 = svmla_s32_m(svptrue_b32(), acc_66, sb_acc_22, block_scale_2);
|
||||
acc_22 = svmla_s32_m(svptrue_b32(), acc_22, sb_acc_0, block_scale_2);
|
||||
acc_66 = svmla_s32_m(svptrue_b32(), acc_66, sb_acc_2, block_scale_2);
|
||||
}
|
||||
if(cp == 3) {
|
||||
acc_33 = svmla_s32_m(svptrue_b32(), acc_33, sb_acc_00, block_scale_3);
|
||||
acc_77 = svmla_s32_m(svptrue_b32(), acc_77, sb_acc_22, block_scale_3);
|
||||
acc_33 = svmla_s32_m(svptrue_b32(), acc_33, sb_acc_0, block_scale_3);
|
||||
acc_77 = svmla_s32_m(svptrue_b32(), acc_77, sb_acc_2, block_scale_3);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue