Added if condition to support only vector length 256.
This commit is contained in:
parent
cde62986b6
commit
3b9b4df2da
|
|
@ -3039,330 +3039,332 @@ void ggml_gemm_q4_K_8x8_q8_K(int n,
|
|||
UNUSED(blocklen);
|
||||
|
||||
#if defined(__aarch64__) && defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8)
|
||||
constexpr int q8_k_blocklen = 4;
|
||||
const svuint8_t m4b_1 = svdup_n_u8(0x0f);
|
||||
// 8 accumulators: 2 row pairs × 4 col pairs
|
||||
svfloat32_t acc_f32_01, acc_f32_23, acc_f32_45, acc_f32_67;
|
||||
uint32_t idx_arr[8] = { 0, 2, 4, 6, 1, 3, 5, 7 };
|
||||
svbool_t pg = svptrue_pat_b32(SV_VL8);
|
||||
svuint32_t idx = svld1(pg, idx_arr);
|
||||
if (svcntb()*8 == 256) {
|
||||
constexpr int q8_k_blocklen = 4;
|
||||
const svuint8_t m4b_1 = svdup_n_u8(0x0f);
|
||||
// 8 accumulators: 2 row pairs × 4 col pairs
|
||||
svfloat32_t acc_f32_01, acc_f32_23, acc_f32_45, acc_f32_67;
|
||||
uint32_t idx_arr[8] = { 0, 2, 4, 6, 1, 3, 5, 7 };
|
||||
svbool_t pg = svptrue_pat_b32(SV_VL8);
|
||||
svuint32_t idx = svld1(pg, idx_arr);
|
||||
|
||||
static const uint32_t idx_data[8] = {0, 4, 2, 6, 1, 5, 3, 7};
|
||||
svuint32_t idx1 = svld1_u32(svptrue_b32(), idx_data);
|
||||
static const uint32_t idx_data[8] = {0, 4, 2, 6, 1, 5, 3, 7};
|
||||
svuint32_t idx1 = svld1_u32(svptrue_b32(), idx_data);
|
||||
|
||||
for (int y = 0; y < nr / q8_k_blocklen; y++) {
|
||||
const block_q8_Kx4 * GGML_RESTRICT q8_ptr = (const block_q8_Kx4 *) vy + (y * nb);
|
||||
const block_q8_Kx4 * GGML_RESTRICT q8_ptr_1 = (const block_q8_Kx4 *) vy + (y * nb);
|
||||
for (int y = 0; y < nr / q8_k_blocklen; y++) {
|
||||
const block_q8_Kx4 * GGML_RESTRICT q8_ptr = (const block_q8_Kx4 *) vy + (y * nb);
|
||||
const block_q8_Kx4 * GGML_RESTRICT q8_ptr_1 = (const block_q8_Kx4 *) vy + (y * nb);
|
||||
|
||||
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
||||
const block_q4_Kx8 * GGML_RESTRICT q4_ptr = (const block_q4_Kx8 *) vx + (x * nb);
|
||||
const block_q4_Kx8 * GGML_RESTRICT q4_ptr_1 = (const block_q4_Kx8 *) vx + (x * nb);
|
||||
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
||||
const block_q4_Kx8 * GGML_RESTRICT q4_ptr = (const block_q4_Kx8 *) vx + (x * nb);
|
||||
const block_q4_Kx8 * GGML_RESTRICT q4_ptr_1 = (const block_q4_Kx8 *) vx + (x * nb);
|
||||
|
||||
acc_f32_01 = svdup_n_f32(0);
|
||||
acc_f32_23 = svdup_n_f32(0);
|
||||
acc_f32_45 = svdup_n_f32(0);
|
||||
acc_f32_67 = svdup_n_f32(0);
|
||||
acc_f32_01 = svdup_n_f32(0);
|
||||
acc_f32_23 = svdup_n_f32(0);
|
||||
acc_f32_45 = svdup_n_f32(0);
|
||||
acc_f32_67 = svdup_n_f32(0);
|
||||
|
||||
for (int b = 0; b < nb; b++) {
|
||||
// bsums pairs belongs to the same q8_k subblock // 64 elemnts loaded and made sum of 0-7 and 8-15 sum || 16-23 and 24 - 31 sum
|
||||
const int16x8_t bsums[4]{
|
||||
vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 0), vld1q_s16(q8_ptr[b].bsums + 16 * 0 + 8)),
|
||||
vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 1), vld1q_s16(q8_ptr[b].bsums + 16 * 1 + 8)),
|
||||
vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 2), vld1q_s16(q8_ptr[b].bsums + 16 * 2 + 8)),
|
||||
vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 3), vld1q_s16(q8_ptr[b].bsums + 16 * 3 + 8)),
|
||||
};
|
||||
for (int b = 0; b < nb; b++) {
|
||||
// bsums pairs belongs to the same q8_k subblock // 64 elemnts loaded and made sum of 0-7 and 8-15 sum || 16-23 and 24 - 31 sum
|
||||
const int16x8_t bsums[4]{
|
||||
vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 0), vld1q_s16(q8_ptr[b].bsums + 16 * 0 + 8)),
|
||||
vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 1), vld1q_s16(q8_ptr[b].bsums + 16 * 1 + 8)),
|
||||
vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 2), vld1q_s16(q8_ptr[b].bsums + 16 * 2 + 8)),
|
||||
vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 3), vld1q_s16(q8_ptr[b].bsums + 16 * 3 + 8)),
|
||||
};
|
||||
|
||||
int32_t bsums_arr32[4][8];
|
||||
int32_t bsums_arr32[4][8];
|
||||
|
||||
for (int q8_row = 0; q8_row < 4; q8_row++) {
|
||||
int16x8_t v16 = bsums[q8_row];
|
||||
for (int q8_row = 0; q8_row < 4; q8_row++) {
|
||||
int16x8_t v16 = bsums[q8_row];
|
||||
|
||||
// low 4
|
||||
int32x4_t v32_lo = vmovl_s16(vget_low_s16(v16));
|
||||
vst1q_s32(&bsums_arr32[q8_row][0], v32_lo);
|
||||
// low 4
|
||||
int32x4_t v32_lo = vmovl_s16(vget_low_s16(v16));
|
||||
vst1q_s32(&bsums_arr32[q8_row][0], v32_lo);
|
||||
|
||||
// high 4
|
||||
int32x4_t v32_hi = vmovl_s16(vget_high_s16(v16));
|
||||
vst1q_s32(&bsums_arr32[q8_row][4], v32_hi);
|
||||
}
|
||||
|
||||
svint32_t sb_acc_0 = svdup_n_s32(0);
|
||||
svint32_t sb_acc_2 = svdup_n_s32(0);
|
||||
|
||||
svint32_t acc_00 = svdup_n_s32(0);
|
||||
svint32_t acc_11 = svdup_n_s32(0);
|
||||
svint32_t acc_22 = svdup_n_s32(0);
|
||||
svint32_t acc_33 = svdup_n_s32(0);
|
||||
svint32_t acc_44 = svdup_n_s32(0);
|
||||
svint32_t acc_55 = svdup_n_s32(0);
|
||||
svint32_t acc_66 = svdup_n_s32(0);
|
||||
svint32_t acc_77 = svdup_n_s32(0);
|
||||
|
||||
svint32_t bias_acc_00 = svdup_n_s32(0);
|
||||
svint32_t bias_acc_22 = svdup_n_s32(0);
|
||||
svint32_t bias_acc_44 = svdup_n_s32(0);
|
||||
svint32_t bias_acc_66 = svdup_n_s32(0);
|
||||
|
||||
for (int sb = 0; sb < QK_K / 64; sb++) {
|
||||
// Need scales for the low and high nibbles
|
||||
// 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total
|
||||
svint32_t block_scale_0, block_scale_1, block_scale_2, block_scale_3;
|
||||
svint32_t q4sb_mins_0, q4sb_mins_1;
|
||||
{
|
||||
// 2-superblock I am working on
|
||||
const int offset = sb * 24 + 0 * 12;
|
||||
const uint8_t * scales_in = &q4_ptr_1[b].scales[offset];
|
||||
|
||||
const int offset1 = sb * 24 + 12;
|
||||
const uint8_t * scales_in1 = &q4_ptr_1[b].scales[offset1];
|
||||
|
||||
constexpr uint32_t kmask1 = 0x3f3f3f3f;
|
||||
constexpr uint32_t kmask2 = 0x0f0f0f0f;
|
||||
constexpr uint32_t kmask3 = 0x03030303;
|
||||
constexpr uint8_t scales_size = 12;
|
||||
|
||||
uint32_t sm[3];
|
||||
memcpy(sm, scales_in, scales_size);
|
||||
|
||||
uint32_t sm1[3];
|
||||
memcpy(sm1, scales_in1, scales_size);
|
||||
|
||||
const uint32_t mins_0_3 = sm[1] & kmask1;
|
||||
const uint32_t mins_4_7 = ((sm[2] >> 4) & kmask2) | (((sm[1] >> 6) & kmask3) << 4);
|
||||
|
||||
const uint32_t mins_0_3_1 = sm1[1] & kmask1;
|
||||
const uint32_t mins_4_7_1 = ((sm1[2] >> 4) & kmask2) | (((sm1[1] >> 6) & kmask3) << 4);
|
||||
|
||||
svuint32_t mins_u32_temp = svzip1_u32(svdup_n_u32(mins_0_3), svdup_n_u32(mins_4_7));
|
||||
svuint32_t mins_u32_temp_1 = svzip1_u32(svdup_n_u32(mins_0_3_1), svdup_n_u32(mins_4_7_1));
|
||||
|
||||
/* reinterpret u32 → u8 */
|
||||
svuint8_t mins_u8 = svreinterpret_u8_u32(mins_u32_temp);
|
||||
svuint8_t mins_u8_1 = svreinterpret_u8_u32(mins_u32_temp_1);
|
||||
|
||||
/* widen u8 → u16->u32 (lower half only) */
|
||||
svuint32_t mins_u16 = svunpklo_u32(svunpklo_u16(mins_u8));
|
||||
svuint32_t mins_u16_1 = svunpklo_u32(svunpklo_u16(mins_u8_1));
|
||||
|
||||
q4sb_mins_0 = svreinterpret_s32_u32(mins_u16);
|
||||
|
||||
q4sb_mins_1 = svreinterpret_s32_u32(mins_u16_1);
|
||||
|
||||
uint32_t scales_u32_0 = sm[0] & kmask1;
|
||||
uint32_t scales_u32_1 = (sm[2] & kmask2) | (((sm[0] >> 6) & kmask3) << 4);
|
||||
uint32_t scales_u32_2 = sm1[0] & kmask1;
|
||||
uint32_t scales_u32_3 = (sm1[2] & kmask2) | (((sm1[0] >> 6) & kmask3) << 4);
|
||||
|
||||
svuint32_t S01 = svdup_n_u32(scales_u32_0);
|
||||
svuint32_t S23 = svdup_n_u32(scales_u32_1);
|
||||
svuint32_t R01 = svdup_n_u32(scales_u32_2);
|
||||
svuint32_t R23 = svdup_n_u32(scales_u32_3);
|
||||
|
||||
svint8_t S01_b = svreinterpret_s8_u32(S01); // s0 s1 s2 s3 ...
|
||||
svint8_t S23_b = svreinterpret_s8_u32(S23); // s4 s5 s6 s7 ...
|
||||
svint8_t R01_b = svreinterpret_s8_u32(R01); // r0 r1 r2 r3 ...
|
||||
svint8_t R23_b = svreinterpret_s8_u32(R23); // r4 r5 r6 r7 ...
|
||||
|
||||
svint32_t S01_d = svunpklo_s32(svunpklo_s16(svzip1_s8(S01_b, S01_b)));
|
||||
// s0 s0 s1 s1 s2 s2 s3 s3 ...
|
||||
|
||||
svint32_t R01_d = svunpklo_s32(svunpklo_s16(svzip1_s8(R01_b, R01_b)));
|
||||
// r0 r0 r1 r1 r2 r2 r3 r3 ...
|
||||
|
||||
svint32_t S23_d = svunpklo_s32(svunpklo_s16(svzip1_s8(S23_b, S23_b)));
|
||||
// s4 s4 s5 s5 s6 s6 s7 s7 ...
|
||||
|
||||
svint32_t R23_d = svunpklo_s32(svunpklo_s16(svzip1_s8(R23_b, R23_b)));
|
||||
// r4 r4 r5 r5 r6 r6 r7 r7 ...
|
||||
|
||||
block_scale_0 = svtbl_s32(svzip1_s32(S01_d, R01_d), idx);
|
||||
// s0 s0 s1 s1 r0 r0 r1 r1
|
||||
|
||||
block_scale_1 = svtbl_s32(svzip2_s32(S01_d, R01_d), idx);
|
||||
// s2 s2 s3 s3 r2 r2 r3 r3
|
||||
|
||||
block_scale_2 = svtbl_s32(svzip1_s32(S23_d, R23_d), idx);
|
||||
// s4 s4 s5 s5 r4 r4 r5 r5
|
||||
|
||||
block_scale_3 = svtbl_s32(svzip2_s32(S23_d, R23_d), idx);
|
||||
// s6 s6 s7 s7 r6 r6 r7 r7
|
||||
// high 4
|
||||
int32x4_t v32_hi = vmovl_s16(vget_high_s16(v16));
|
||||
vst1q_s32(&bsums_arr32[q8_row][4], v32_hi);
|
||||
}
|
||||
|
||||
// q8_ptr[b].qs has interleaved Q8 rows (01, 23)
|
||||
// const int8_t * q8_base = q8_ptr[b].qs + sb * 256;
|
||||
const int8_t * q8_base_1 = q8_ptr_1[b].qs + sb * 256;
|
||||
svint32_t sb_acc_0 = svdup_n_s32(0);
|
||||
svint32_t sb_acc_2 = svdup_n_s32(0);
|
||||
|
||||
// Load 32-byte per row pair, 1 subblock each time
|
||||
// predicate for activating higher lanes for 16 int8 elements
|
||||
const svbool_t ph16 = svptrue_pat_b8(SV_VL16);
|
||||
// predicate for activating lower lanes for 16 int8 elements
|
||||
const svbool_t pl16 = svnot_b_z(svptrue_b8(), ph16);
|
||||
svint32_t acc_00 = svdup_n_s32(0);
|
||||
svint32_t acc_11 = svdup_n_s32(0);
|
||||
svint32_t acc_22 = svdup_n_s32(0);
|
||||
svint32_t acc_33 = svdup_n_s32(0);
|
||||
svint32_t acc_44 = svdup_n_s32(0);
|
||||
svint32_t acc_55 = svdup_n_s32(0);
|
||||
svint32_t acc_66 = svdup_n_s32(0);
|
||||
svint32_t acc_77 = svdup_n_s32(0);
|
||||
|
||||
svint8_t q8_qs_0 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 0), svld1_s8(pl16, q8_base_1 + 112));
|
||||
svint8_t q8_qs_2 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 32), svld1_s8(pl16, q8_base_1 + 144));
|
||||
svint8_t q8_qs_4 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 64), svld1_s8(pl16, q8_base_1 + 176));
|
||||
svint8_t q8_qs_6 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 96), svld1_s8(pl16, q8_base_1 + 208));
|
||||
svint32_t bias_acc_00 = svdup_n_s32(0);
|
||||
svint32_t bias_acc_22 = svdup_n_s32(0);
|
||||
svint32_t bias_acc_44 = svdup_n_s32(0);
|
||||
svint32_t bias_acc_66 = svdup_n_s32(0);
|
||||
|
||||
svint8_t q8_qs_1 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 16), svld1_s8(pl16, q8_base_1 + 128));
|
||||
svint8_t q8_qs_3 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 48), svld1_s8(pl16, q8_base_1 + 160));
|
||||
svint8_t q8_qs_5 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 80), svld1_s8(pl16, q8_base_1 + 192));
|
||||
svint8_t q8_qs_7 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 112), svld1_s8(pl16, q8_base_1 + 224));
|
||||
for (int sb = 0; sb < QK_K / 64; sb++) {
|
||||
// Need scales for the low and high nibbles
|
||||
// 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total
|
||||
svint32_t block_scale_0, block_scale_1, block_scale_2, block_scale_3;
|
||||
svint32_t q4sb_mins_0, q4sb_mins_1;
|
||||
{
|
||||
// 2-superblock I am working on
|
||||
const int offset = sb * 24 + 0 * 12;
|
||||
const uint8_t * scales_in = &q4_ptr_1[b].scales[offset];
|
||||
|
||||
// Q4s columns iterated in pairs (01, 23, 45, 67)
|
||||
for (int cp = 0; cp < ncols_interleaved / 2; cp++) {
|
||||
const int offset1 = sb * 24 + 12;
|
||||
const uint8_t * scales_in1 = &q4_ptr_1[b].scales[offset1];
|
||||
|
||||
sb_acc_0 = svdup_n_s32(0);
|
||||
sb_acc_2 = svdup_n_s32(0);
|
||||
constexpr uint32_t kmask1 = 0x3f3f3f3f;
|
||||
constexpr uint32_t kmask2 = 0x0f0f0f0f;
|
||||
constexpr uint32_t kmask3 = 0x03030303;
|
||||
constexpr uint8_t scales_size = 12;
|
||||
|
||||
svuint8_t q4_qs_cp_00 = svld1rq_u8(svptrue_b8(), q4_ptr_1[b].qs + sb * QK_K + 16 * cp + 0);
|
||||
svuint8_t q4_qs_cp_01 = svld1rq_u8(svptrue_b8(), q4_ptr_1[b].qs + sb * QK_K + 16 * cp + 64);
|
||||
svuint8_t q4_qs_cp_02 = svld1rq_u8(svptrue_b8(), q4_ptr_1[b].qs + sb * QK_K + 16 * cp + 128);
|
||||
svuint8_t q4_qs_cp_03 = svld1rq_u8(svptrue_b8(), q4_ptr_1[b].qs + sb * QK_K + 16 * cp + 192);
|
||||
uint32_t sm[3];
|
||||
memcpy(sm, scales_in, scales_size);
|
||||
|
||||
svint8_t q4_nibbles_00 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_u8_m(ph16, q4_qs_cp_00, m4b_1), 4));
|
||||
svint8_t q4_nibbles_01 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_u8_m(ph16, q4_qs_cp_01, m4b_1), 4));
|
||||
svint8_t q4_nibbles_02 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_u8_m(ph16, q4_qs_cp_02, m4b_1), 4));
|
||||
svint8_t q4_nibbles_03 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_u8_m(ph16, q4_qs_cp_03, m4b_1), 4));
|
||||
uint32_t sm1[3];
|
||||
memcpy(sm1, scales_in1, scales_size);
|
||||
|
||||
sb_acc_0 = svmmla_s32(sb_acc_0, q4_nibbles_00, q8_qs_0);
|
||||
sb_acc_0 = svmmla_s32(sb_acc_0, q4_nibbles_01, q8_qs_2);
|
||||
const uint32_t mins_0_3 = sm[1] & kmask1;
|
||||
const uint32_t mins_4_7 = ((sm[2] >> 4) & kmask2) | (((sm[1] >> 6) & kmask3) << 4);
|
||||
|
||||
sb_acc_0 = svmmla_s32(sb_acc_0, q4_nibbles_02, q8_qs_4);
|
||||
sb_acc_0 = svmmla_s32(sb_acc_0, q4_nibbles_03, q8_qs_6);
|
||||
const uint32_t mins_0_3_1 = sm1[1] & kmask1;
|
||||
const uint32_t mins_4_7_1 = ((sm1[2] >> 4) & kmask2) | (((sm1[1] >> 6) & kmask3) << 4);
|
||||
|
||||
sb_acc_2 = svmmla_s32(sb_acc_2, q4_nibbles_00, q8_qs_1);
|
||||
sb_acc_2 = svmmla_s32(sb_acc_2, q4_nibbles_01, q8_qs_3);
|
||||
svuint32_t mins_u32_temp = svzip1_u32(svdup_n_u32(mins_0_3), svdup_n_u32(mins_4_7));
|
||||
svuint32_t mins_u32_temp_1 = svzip1_u32(svdup_n_u32(mins_0_3_1), svdup_n_u32(mins_4_7_1));
|
||||
|
||||
sb_acc_2 = svmmla_s32(sb_acc_2, q4_nibbles_02, q8_qs_5);
|
||||
sb_acc_2 = svmmla_s32(sb_acc_2, q4_nibbles_03, q8_qs_7);
|
||||
/* reinterpret u32 → u8 */
|
||||
svuint8_t mins_u8 = svreinterpret_u8_u32(mins_u32_temp);
|
||||
svuint8_t mins_u8_1 = svreinterpret_u8_u32(mins_u32_temp_1);
|
||||
|
||||
if(cp == 0) {
|
||||
acc_00 = svmla_s32_m(svptrue_b32(), acc_00, sb_acc_0, block_scale_0);
|
||||
acc_44 = svmla_s32_m(svptrue_b32(), acc_44, sb_acc_2, block_scale_0);
|
||||
/* widen u8 → u16->u32 (lower half only) */
|
||||
svuint32_t mins_u16 = svunpklo_u32(svunpklo_u16(mins_u8));
|
||||
svuint32_t mins_u16_1 = svunpklo_u32(svunpklo_u16(mins_u8_1));
|
||||
|
||||
q4sb_mins_0 = svreinterpret_s32_u32(mins_u16);
|
||||
|
||||
q4sb_mins_1 = svreinterpret_s32_u32(mins_u16_1);
|
||||
|
||||
uint32_t scales_u32_0 = sm[0] & kmask1;
|
||||
uint32_t scales_u32_1 = (sm[2] & kmask2) | (((sm[0] >> 6) & kmask3) << 4);
|
||||
uint32_t scales_u32_2 = sm1[0] & kmask1;
|
||||
uint32_t scales_u32_3 = (sm1[2] & kmask2) | (((sm1[0] >> 6) & kmask3) << 4);
|
||||
|
||||
svuint32_t S01 = svdup_n_u32(scales_u32_0);
|
||||
svuint32_t S23 = svdup_n_u32(scales_u32_1);
|
||||
svuint32_t R01 = svdup_n_u32(scales_u32_2);
|
||||
svuint32_t R23 = svdup_n_u32(scales_u32_3);
|
||||
|
||||
svint8_t S01_b = svreinterpret_s8_u32(S01); // s0 s1 s2 s3 ...
|
||||
svint8_t S23_b = svreinterpret_s8_u32(S23); // s4 s5 s6 s7 ...
|
||||
svint8_t R01_b = svreinterpret_s8_u32(R01); // r0 r1 r2 r3 ...
|
||||
svint8_t R23_b = svreinterpret_s8_u32(R23); // r4 r5 r6 r7 ...
|
||||
|
||||
svint32_t S01_d = svunpklo_s32(svunpklo_s16(svzip1_s8(S01_b, S01_b)));
|
||||
// s0 s0 s1 s1 s2 s2 s3 s3 ...
|
||||
|
||||
svint32_t R01_d = svunpklo_s32(svunpklo_s16(svzip1_s8(R01_b, R01_b)));
|
||||
// r0 r0 r1 r1 r2 r2 r3 r3 ...
|
||||
|
||||
svint32_t S23_d = svunpklo_s32(svunpklo_s16(svzip1_s8(S23_b, S23_b)));
|
||||
// s4 s4 s5 s5 s6 s6 s7 s7 ...
|
||||
|
||||
svint32_t R23_d = svunpklo_s32(svunpklo_s16(svzip1_s8(R23_b, R23_b)));
|
||||
// r4 r4 r5 r5 r6 r6 r7 r7 ...
|
||||
|
||||
block_scale_0 = svtbl_s32(svzip1_s32(S01_d, R01_d), idx);
|
||||
// s0 s0 s1 s1 r0 r0 r1 r1
|
||||
|
||||
block_scale_1 = svtbl_s32(svzip2_s32(S01_d, R01_d), idx);
|
||||
// s2 s2 s3 s3 r2 r2 r3 r3
|
||||
|
||||
block_scale_2 = svtbl_s32(svzip1_s32(S23_d, R23_d), idx);
|
||||
// s4 s4 s5 s5 r4 r4 r5 r5
|
||||
|
||||
block_scale_3 = svtbl_s32(svzip2_s32(S23_d, R23_d), idx);
|
||||
// s6 s6 s7 s7 r6 r6 r7 r7
|
||||
}
|
||||
if(cp == 1) {
|
||||
acc_11 = svmla_s32_m(svptrue_b32(), acc_11, sb_acc_0, block_scale_1);
|
||||
acc_55 = svmla_s32_m(svptrue_b32(), acc_55, sb_acc_2, block_scale_1);
|
||||
|
||||
// q8_ptr[b].qs has interleaved Q8 rows (01, 23)
|
||||
// const int8_t * q8_base = q8_ptr[b].qs + sb * 256;
|
||||
const int8_t * q8_base_1 = q8_ptr_1[b].qs + sb * 256;
|
||||
|
||||
// Load 32-byte per row pair, 1 subblock each time
|
||||
// predicate for activating higher lanes for 16 int8 elements
|
||||
const svbool_t ph16 = svptrue_pat_b8(SV_VL16);
|
||||
// predicate for activating lower lanes for 16 int8 elements
|
||||
const svbool_t pl16 = svnot_b_z(svptrue_b8(), ph16);
|
||||
|
||||
svint8_t q8_qs_0 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 0), svld1_s8(pl16, q8_base_1 + 112));
|
||||
svint8_t q8_qs_2 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 32), svld1_s8(pl16, q8_base_1 + 144));
|
||||
svint8_t q8_qs_4 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 64), svld1_s8(pl16, q8_base_1 + 176));
|
||||
svint8_t q8_qs_6 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 96), svld1_s8(pl16, q8_base_1 + 208));
|
||||
|
||||
svint8_t q8_qs_1 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 16), svld1_s8(pl16, q8_base_1 + 128));
|
||||
svint8_t q8_qs_3 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 48), svld1_s8(pl16, q8_base_1 + 160));
|
||||
svint8_t q8_qs_5 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 80), svld1_s8(pl16, q8_base_1 + 192));
|
||||
svint8_t q8_qs_7 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 112), svld1_s8(pl16, q8_base_1 + 224));
|
||||
|
||||
// Q4s columns iterated in pairs (01, 23, 45, 67)
|
||||
for (int cp = 0; cp < ncols_interleaved / 2; cp++) {
|
||||
|
||||
sb_acc_0 = svdup_n_s32(0);
|
||||
sb_acc_2 = svdup_n_s32(0);
|
||||
|
||||
svuint8_t q4_qs_cp_00 = svld1rq_u8(svptrue_b8(), q4_ptr_1[b].qs + sb * QK_K + 16 * cp + 0);
|
||||
svuint8_t q4_qs_cp_01 = svld1rq_u8(svptrue_b8(), q4_ptr_1[b].qs + sb * QK_K + 16 * cp + 64);
|
||||
svuint8_t q4_qs_cp_02 = svld1rq_u8(svptrue_b8(), q4_ptr_1[b].qs + sb * QK_K + 16 * cp + 128);
|
||||
svuint8_t q4_qs_cp_03 = svld1rq_u8(svptrue_b8(), q4_ptr_1[b].qs + sb * QK_K + 16 * cp + 192);
|
||||
|
||||
svint8_t q4_nibbles_00 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_u8_m(ph16, q4_qs_cp_00, m4b_1), 4));
|
||||
svint8_t q4_nibbles_01 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_u8_m(ph16, q4_qs_cp_01, m4b_1), 4));
|
||||
svint8_t q4_nibbles_02 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_u8_m(ph16, q4_qs_cp_02, m4b_1), 4));
|
||||
svint8_t q4_nibbles_03 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_u8_m(ph16, q4_qs_cp_03, m4b_1), 4));
|
||||
|
||||
sb_acc_0 = svmmla_s32(sb_acc_0, q4_nibbles_00, q8_qs_0);
|
||||
sb_acc_0 = svmmla_s32(sb_acc_0, q4_nibbles_01, q8_qs_2);
|
||||
|
||||
sb_acc_0 = svmmla_s32(sb_acc_0, q4_nibbles_02, q8_qs_4);
|
||||
sb_acc_0 = svmmla_s32(sb_acc_0, q4_nibbles_03, q8_qs_6);
|
||||
|
||||
sb_acc_2 = svmmla_s32(sb_acc_2, q4_nibbles_00, q8_qs_1);
|
||||
sb_acc_2 = svmmla_s32(sb_acc_2, q4_nibbles_01, q8_qs_3);
|
||||
|
||||
sb_acc_2 = svmmla_s32(sb_acc_2, q4_nibbles_02, q8_qs_5);
|
||||
sb_acc_2 = svmmla_s32(sb_acc_2, q4_nibbles_03, q8_qs_7);
|
||||
|
||||
if(cp == 0) {
|
||||
acc_00 = svmla_s32_m(svptrue_b32(), acc_00, sb_acc_0, block_scale_0);
|
||||
acc_44 = svmla_s32_m(svptrue_b32(), acc_44, sb_acc_2, block_scale_0);
|
||||
}
|
||||
if(cp == 1) {
|
||||
acc_11 = svmla_s32_m(svptrue_b32(), acc_11, sb_acc_0, block_scale_1);
|
||||
acc_55 = svmla_s32_m(svptrue_b32(), acc_55, sb_acc_2, block_scale_1);
|
||||
}
|
||||
if(cp == 2) {
|
||||
acc_22 = svmla_s32_m(svptrue_b32(), acc_22, sb_acc_0, block_scale_2);
|
||||
acc_66 = svmla_s32_m(svptrue_b32(), acc_66, sb_acc_2, block_scale_2);
|
||||
}
|
||||
if(cp == 3) {
|
||||
acc_33 = svmla_s32_m(svptrue_b32(), acc_33, sb_acc_0, block_scale_3);
|
||||
acc_77 = svmla_s32_m(svptrue_b32(), acc_77, sb_acc_2, block_scale_3);
|
||||
}
|
||||
}
|
||||
if(cp == 2) {
|
||||
acc_22 = svmla_s32_m(svptrue_b32(), acc_22, sb_acc_0, block_scale_2);
|
||||
acc_66 = svmla_s32_m(svptrue_b32(), acc_66, sb_acc_2, block_scale_2);
|
||||
|
||||
bias_acc_00 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_00, svdup_n_s32(bsums_arr32[sb][0]), q4sb_mins_0);
|
||||
bias_acc_00 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_00, svdup_n_s32(bsums_arr32[sb][1]), q4sb_mins_1);
|
||||
|
||||
bias_acc_22 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_22, svdup_n_s32(bsums_arr32[sb][2]), q4sb_mins_0);
|
||||
bias_acc_22 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_22, svdup_n_s32(bsums_arr32[sb][3]), q4sb_mins_1);
|
||||
|
||||
bias_acc_44 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_44, svdup_n_s32(bsums_arr32[sb][4]), q4sb_mins_0);
|
||||
bias_acc_44 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_44, svdup_n_s32(bsums_arr32[sb][5]), q4sb_mins_1);
|
||||
|
||||
bias_acc_66 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_66, svdup_n_s32(bsums_arr32[sb][6]), q4sb_mins_0);
|
||||
bias_acc_66 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_66, svdup_n_s32(bsums_arr32[sb][7]), q4sb_mins_1);
|
||||
} // for sb
|
||||
|
||||
|
||||
// acc[0..3] // acc[4..7]
|
||||
acc_00 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_00, svext_s32(acc_00, acc_00, 4));
|
||||
acc_11 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_11, svext_s32(acc_11, acc_11, 4));
|
||||
acc_22 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_22, svext_s32(acc_22, acc_22, 4));
|
||||
acc_33 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_33, svext_s32(acc_33, acc_33, 4));
|
||||
acc_44 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_44, svext_s32(acc_44, acc_44, 4));
|
||||
acc_55 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_55, svext_s32(acc_55, acc_55, 4));
|
||||
acc_66 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_66, svext_s32(acc_66, acc_66, 4));
|
||||
acc_77 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_77, svext_s32(acc_77, acc_77, 4));
|
||||
|
||||
svint32_t reorder_acc_01 = svtbl_s32( svzip1_s32( svtrn1_s32(acc_00, acc_11), svtrn1_s32(acc_22, acc_33)), idx1);
|
||||
svint32_t reorder_acc_23 = svtbl_s32( svzip1_s32( svtrn2_s32(acc_00, acc_11), svtrn2_s32(acc_22, acc_33)), idx1);
|
||||
|
||||
svint32_t reorder_acc_45 = svtbl_s32( svzip1_s32( svtrn1_s32(acc_44, acc_55), svtrn1_s32(acc_66, acc_77)), idx1);
|
||||
svint32_t reorder_acc_67 = svtbl_s32( svzip1_s32( svtrn2_s32(acc_44, acc_55), svtrn2_s32(acc_66, acc_77)), idx1);
|
||||
|
||||
// Broadcast q8 scalar
|
||||
svfloat32_t q8_d = svdup_f32(q8_ptr[b].d[0]);
|
||||
|
||||
svfloat32_t q4_dmin_temp = svcvt_f32_f16_x(svptrue_b32(), svzip1_f16( svld1_f16(svptrue_pat_b16(SV_VL8), (const __fp16 *)q4_ptr[b].dmin), svdup_f16(0)));
|
||||
|
||||
svfloat32_t q4_d_temp = svcvt_f32_f16_x(svptrue_b32(), svzip1_f16( svld1_f16(svptrue_pat_b16(SV_VL8), (const __fp16 *)q4_ptr[b].d), svdup_f16(0)));
|
||||
|
||||
svfloat32_t scale1 = svmul_f32_x(svptrue_b32(), q4_d_temp, q8_d);
|
||||
svfloat32_t dmins1 = svmul_f32_x(svptrue_b32(), q4_dmin_temp, q8_d);
|
||||
|
||||
acc_f32_01 = svmls_f32_m(svptrue_b32(), acc_f32_01, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), bias_acc_00), dmins1);
|
||||
acc_f32_01 = svmla_f32_m(svptrue_b32(), acc_f32_01, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), reorder_acc_01), scale1);
|
||||
|
||||
q8_d = svdup_f32(q8_ptr[b].d[1]);
|
||||
|
||||
scale1 = svmul_f32_x(svptrue_b32(), q4_d_temp, q8_d);
|
||||
dmins1 = svmul_f32_x(svptrue_b32(), q4_dmin_temp, q8_d);
|
||||
|
||||
acc_f32_23 = svmls_f32_m(svptrue_b32(), acc_f32_23, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), bias_acc_22), dmins1);
|
||||
acc_f32_23 = svmla_f32_m(svptrue_b32(), acc_f32_23, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), reorder_acc_23), scale1);
|
||||
|
||||
q8_d = svdup_f32(q8_ptr[b].d[2]);
|
||||
|
||||
|
||||
scale1 = svmul_f32_x(svptrue_b32(), q4_d_temp, q8_d);
|
||||
dmins1 = svmul_f32_x(svptrue_b32(), q4_dmin_temp, q8_d);
|
||||
|
||||
acc_f32_45 = svmls_f32_m(svptrue_b32(), acc_f32_45, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), bias_acc_44), dmins1);
|
||||
acc_f32_45 = svmla_f32_m(svptrue_b32(), acc_f32_45, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), reorder_acc_45), scale1);
|
||||
|
||||
q8_d = svdup_f32(q8_ptr[b].d[3]);
|
||||
|
||||
scale1 = svmul_f32_x(svptrue_b32(), q4_d_temp, q8_d);
|
||||
dmins1 = svmul_f32_x(svptrue_b32(), q4_dmin_temp, q8_d);
|
||||
|
||||
acc_f32_67 = svmls_f32_m(svptrue_b32(), acc_f32_67, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), bias_acc_66), dmins1);
|
||||
acc_f32_67 = svmla_f32_m(svptrue_b32(), acc_f32_67, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), reorder_acc_67), scale1);
|
||||
|
||||
} // for b
|
||||
|
||||
// With the previous reorder, the tile is already in the correct memory layout.
|
||||
// Predicate for exactly 4 lanes
|
||||
svbool_t pg4 = svptrue_pat_b32(SV_VL4);
|
||||
for (int i = 0; i < q8_k_blocklen; i++) {
|
||||
int row = y * q8_k_blocklen + i;
|
||||
for (int j = 0; j < 2; j++) {
|
||||
int col = x * ncols_interleaved + j * 4;
|
||||
int offset = row * bs + col;
|
||||
|
||||
if (i == 0 && j == 0) {
|
||||
// acc_f32_0 → lower half of acc_f32_01
|
||||
svst1_f32(pg4, s + offset, acc_f32_01);
|
||||
} else if (i == 0 && j == 1) {
|
||||
// acc_f32_1 → upper half of acc_f32_01
|
||||
svst1_f32(pg4, s + offset, svext_f32(acc_f32_01, acc_f32_01, 4));
|
||||
} else if (i == 1 && j == 0) {
|
||||
// acc_f32_2
|
||||
svst1_f32(pg4, s + offset, acc_f32_23);
|
||||
} else if (i == 1 && j == 1) {
|
||||
// acc_f32_3
|
||||
svst1_f32(pg4, s + offset, svext_f32(acc_f32_23, acc_f32_23, 4));
|
||||
} else if (i == 2 && j == 0) {
|
||||
// acc_f32_4
|
||||
svst1_f32(pg4, s + offset, acc_f32_45);
|
||||
} else if (i == 2 && j == 1) {
|
||||
// acc_f32_5
|
||||
svst1_f32(pg4, s + offset, svext_f32(acc_f32_45, acc_f32_45, 4));
|
||||
} else if (i == 3 && j == 0) {
|
||||
// acc_f32_6
|
||||
svst1_f32(pg4, s + offset, acc_f32_67);
|
||||
} else if (i == 3 && j == 1) {
|
||||
// acc_f32_7
|
||||
svst1_f32(pg4, s + offset, svext_f32(acc_f32_67, acc_f32_67, 4));
|
||||
}
|
||||
if(cp == 3) {
|
||||
acc_33 = svmla_s32_m(svptrue_b32(), acc_33, sb_acc_0, block_scale_3);
|
||||
acc_77 = svmla_s32_m(svptrue_b32(), acc_77, sb_acc_2, block_scale_3);
|
||||
}
|
||||
}
|
||||
|
||||
bias_acc_00 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_00, svdup_n_s32(bsums_arr32[sb][0]), q4sb_mins_0);
|
||||
bias_acc_00 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_00, svdup_n_s32(bsums_arr32[sb][1]), q4sb_mins_1);
|
||||
|
||||
bias_acc_22 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_22, svdup_n_s32(bsums_arr32[sb][2]), q4sb_mins_0);
|
||||
bias_acc_22 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_22, svdup_n_s32(bsums_arr32[sb][3]), q4sb_mins_1);
|
||||
|
||||
bias_acc_44 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_44, svdup_n_s32(bsums_arr32[sb][4]), q4sb_mins_0);
|
||||
bias_acc_44 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_44, svdup_n_s32(bsums_arr32[sb][5]), q4sb_mins_1);
|
||||
|
||||
bias_acc_66 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_66, svdup_n_s32(bsums_arr32[sb][6]), q4sb_mins_0);
|
||||
bias_acc_66 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_66, svdup_n_s32(bsums_arr32[sb][7]), q4sb_mins_1);
|
||||
} // for sb
|
||||
|
||||
|
||||
// acc[0..3] // acc[4..7]
|
||||
acc_00 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_00, svext_s32(acc_00, acc_00, 4));
|
||||
acc_11 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_11, svext_s32(acc_11, acc_11, 4));
|
||||
acc_22 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_22, svext_s32(acc_22, acc_22, 4));
|
||||
acc_33 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_33, svext_s32(acc_33, acc_33, 4));
|
||||
acc_44 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_44, svext_s32(acc_44, acc_44, 4));
|
||||
acc_55 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_55, svext_s32(acc_55, acc_55, 4));
|
||||
acc_66 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_66, svext_s32(acc_66, acc_66, 4));
|
||||
acc_77 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_77, svext_s32(acc_77, acc_77, 4));
|
||||
|
||||
svint32_t reorder_acc_01 = svtbl_s32( svzip1_s32( svtrn1_s32(acc_00, acc_11), svtrn1_s32(acc_22, acc_33)), idx1);
|
||||
svint32_t reorder_acc_23 = svtbl_s32( svzip1_s32( svtrn2_s32(acc_00, acc_11), svtrn2_s32(acc_22, acc_33)), idx1);
|
||||
|
||||
svint32_t reorder_acc_45 = svtbl_s32( svzip1_s32( svtrn1_s32(acc_44, acc_55), svtrn1_s32(acc_66, acc_77)), idx1);
|
||||
svint32_t reorder_acc_67 = svtbl_s32( svzip1_s32( svtrn2_s32(acc_44, acc_55), svtrn2_s32(acc_66, acc_77)), idx1);
|
||||
|
||||
// Broadcast q8 scalar
|
||||
svfloat32_t q8_d = svdup_f32(q8_ptr[b].d[0]);
|
||||
|
||||
svfloat32_t q4_dmin_temp = svcvt_f32_f16_x(svptrue_b32(), svzip1_f16( svld1_f16(svptrue_pat_b16(SV_VL8), (const __fp16 *)q4_ptr[b].dmin), svdup_f16(0)));
|
||||
|
||||
svfloat32_t q4_d_temp = svcvt_f32_f16_x(svptrue_b32(), svzip1_f16( svld1_f16(svptrue_pat_b16(SV_VL8), (const __fp16 *)q4_ptr[b].d), svdup_f16(0)));
|
||||
|
||||
svfloat32_t scale1 = svmul_f32_x(svptrue_b32(), q4_d_temp, q8_d);
|
||||
svfloat32_t dmins1 = svmul_f32_x(svptrue_b32(), q4_dmin_temp, q8_d);
|
||||
|
||||
acc_f32_01 = svmls_f32_m(svptrue_b32(), acc_f32_01, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), bias_acc_00), dmins1);
|
||||
acc_f32_01 = svmla_f32_m(svptrue_b32(), acc_f32_01, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), reorder_acc_01), scale1);
|
||||
|
||||
q8_d = svdup_f32(q8_ptr[b].d[1]);
|
||||
|
||||
scale1 = svmul_f32_x(svptrue_b32(), q4_d_temp, q8_d);
|
||||
dmins1 = svmul_f32_x(svptrue_b32(), q4_dmin_temp, q8_d);
|
||||
|
||||
acc_f32_23 = svmls_f32_m(svptrue_b32(), acc_f32_23, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), bias_acc_22), dmins1);
|
||||
acc_f32_23 = svmla_f32_m(svptrue_b32(), acc_f32_23, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), reorder_acc_23), scale1);
|
||||
|
||||
q8_d = svdup_f32(q8_ptr[b].d[2]);
|
||||
|
||||
|
||||
scale1 = svmul_f32_x(svptrue_b32(), q4_d_temp, q8_d);
|
||||
dmins1 = svmul_f32_x(svptrue_b32(), q4_dmin_temp, q8_d);
|
||||
|
||||
acc_f32_45 = svmls_f32_m(svptrue_b32(), acc_f32_45, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), bias_acc_44), dmins1);
|
||||
acc_f32_45 = svmla_f32_m(svptrue_b32(), acc_f32_45, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), reorder_acc_45), scale1);
|
||||
|
||||
q8_d = svdup_f32(q8_ptr[b].d[3]);
|
||||
|
||||
scale1 = svmul_f32_x(svptrue_b32(), q4_d_temp, q8_d);
|
||||
dmins1 = svmul_f32_x(svptrue_b32(), q4_dmin_temp, q8_d);
|
||||
|
||||
acc_f32_67 = svmls_f32_m(svptrue_b32(), acc_f32_67, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), bias_acc_66), dmins1);
|
||||
acc_f32_67 = svmla_f32_m(svptrue_b32(), acc_f32_67, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), reorder_acc_67), scale1);
|
||||
|
||||
} // for b
|
||||
|
||||
// With the previous reorder, the tile is already in the correct memory layout.
|
||||
// Predicate for exactly 4 lanes
|
||||
svbool_t pg4 = svptrue_pat_b32(SV_VL4);
|
||||
for (int i = 0; i < q8_k_blocklen; i++) {
|
||||
int row = y * q8_k_blocklen + i;
|
||||
for (int j = 0; j < 2; j++) {
|
||||
int col = x * ncols_interleaved + j * 4;
|
||||
int offset = row * bs + col;
|
||||
|
||||
if (i == 0 && j == 0) {
|
||||
// acc_f32_0 → lower half of acc_f32_01
|
||||
svst1_f32(pg4, s + offset, acc_f32_01);
|
||||
} else if (i == 0 && j == 1) {
|
||||
// acc_f32_1 → upper half of acc_f32_01
|
||||
svst1_f32(pg4, s + offset, svext_f32(acc_f32_01, acc_f32_01, 4));
|
||||
} else if (i == 1 && j == 0) {
|
||||
// acc_f32_2
|
||||
svst1_f32(pg4, s + offset, acc_f32_23);
|
||||
} else if (i == 1 && j == 1) {
|
||||
// acc_f32_3
|
||||
svst1_f32(pg4, s + offset, svext_f32(acc_f32_23, acc_f32_23, 4));
|
||||
} else if (i == 2 && j == 0) {
|
||||
// acc_f32_4
|
||||
svst1_f32(pg4, s + offset, acc_f32_45);
|
||||
} else if (i == 2 && j == 1) {
|
||||
// acc_f32_5
|
||||
svst1_f32(pg4, s + offset, svext_f32(acc_f32_45, acc_f32_45, 4));
|
||||
} else if (i == 3 && j == 0) {
|
||||
// acc_f32_6
|
||||
svst1_f32(pg4, s + offset, acc_f32_67);
|
||||
} else if (i == 3 && j == 1) {
|
||||
// acc_f32_7
|
||||
svst1_f32(pg4, s + offset, svext_f32(acc_f32_67, acc_f32_67, 4));
|
||||
}
|
||||
}
|
||||
}
|
||||
} // for x
|
||||
} // for y
|
||||
return;
|
||||
} // for x
|
||||
} // for y
|
||||
return;
|
||||
}
|
||||
|
||||
#elif defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
|
||||
constexpr int q8_k_blocklen = 4;
|
||||
|
|
|
|||
Loading…
Reference in New Issue