ggml : add C implementation for the "Q4_K quantization for AArch64" patch

This commit is contained in:
hongyang 2025-11-14 19:05:21 +08:00
parent 2f3dfe2e74
commit 8a4e25d796
3 changed files with 537 additions and 344 deletions

View File

@@ -210,6 +210,141 @@ void ggml_quantize_mat_q8_0_4x8(const float * GGML_RESTRICT x, void * GGML_RESTR
#endif
}
void ggml_quantize_mat_q8_K_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
assert(QK_K == 256);
assert(k % QK_K == 0);
const int nb = k / QK_K;
block_q8_Kx4 * GGML_RESTRICT y = (block_q8_Kx4 *) vy;
#if defined(__ARM_NEON)
const int blck_size_interleave = 8;
float32x4_t srcv[4][64]; // 64 = QK_K/4
float iscale[4];
for (int i = 0; i < nb; i++) {
float32x4_t asrcv[64];
float32x4_t amaxv[64];
// d:
for (int row_iter = 0; row_iter < 4; row_iter++) {
for (int j = 0; j < 64; j++) srcv[row_iter][j] = vld1q_f32(x + row_iter * k + i * 256 + 4 * j);
for (int j = 0; j < 64; j++) asrcv[j] = vabsq_f32(srcv[row_iter][j]);
for (int j = 0; j < 32; j++) amaxv[2 * j] = vmaxq_f32(asrcv[2 * j], asrcv[2 * j + 1]);
for (int j = 0; j < 16; j++) amaxv[4 * j] = vmaxq_f32(amaxv[4 * j], amaxv[4 * j + 2]);
for (int j = 0; j < 8; j++) amaxv[8 * j] = vmaxq_f32(amaxv[8 * j], amaxv[8 * j + 4]);
for (int j = 0; j < 4; j++) amaxv[16 * j] = vmaxq_f32(amaxv[16 * j], amaxv[16 * j + 8]);
for (int j = 0; j < 2; j++) amaxv[32 * j] = vmaxq_f32(amaxv[32 * j], amaxv[32 * j + 16]);
for (int j = 0; j < 1; j++) amaxv[64 * j] = vmaxq_f32(amaxv[64 * j], amaxv[64 * j + 32]);
const float amax = vmaxvq_f32(amaxv[0]);
// Check whether any original (signed) value equals +amax
float32x4_t amax_vec = vdupq_n_f32(amax);
uint32x4_t mask_all = vdupq_n_u32(0);
for (int j = 0; j < 64; j++) {
uint32x4_t mask_curr = vceqq_f32(amax_vec, srcv[row_iter][j]);
mask_all = vorrq_u32(mask_all, mask_curr);
}
// Assume no element equals +amax, then check mask_all and flip the sign of iscale if one does
iscale[row_iter] = ( amax != 0.0f ) ? 127.f / amax : 0.0f;
uint32x4_t cmp = vceqq_u32(mask_all, vdupq_n_u32(0xFFFFFFFFu));
if (vmaxvq_u32(cmp) != 0) {
iscale[row_iter] = ( amax != 0.0f ) ? -127.f / amax : 0.0f;
}
y[i].d[row_iter] = amax ? 1/iscale[row_iter] : 0;
}
// qs: 8-byte interleave over the 4 rows
// bsums: generated one by one, row i finished before row i+1
// 16 iterations = QK_K / blck_size_interleave / 2
for (int j = 0; j < QK_K / blck_size_interleave / 2; j++) {
// Process row0 and row1
float32x4_t f0_0_3 = vrndnq_f32(vmulq_n_f32(srcv[0][4 * j], iscale[0]));
float32x4_t f0_4_7 = vrndnq_f32(vmulq_n_f32(srcv[0][4 * j + 1], iscale[0]));
float32x4_t f0_8_11 = vrndnq_f32(vmulq_n_f32(srcv[0][4 * j + 2], iscale[0]));
float32x4_t f0_12_15 = vrndnq_f32(vmulq_n_f32(srcv[0][4 * j + 3], iscale[0]));
int32x4_t i0_0_3 = vcvtnq_s32_f32(f0_0_3);
int32x4_t i0_4_7 = vcvtnq_s32_f32(f0_4_7);
int16x8_t i0_0_7 = vcombine_s16(vqmovn_s32(i0_0_3), vqmovn_s32(i0_4_7)); // int32x4 * 2 → int16x4 * 2 → int16x8
int32x4_t i0_8_11 = vcvtnq_s32_f32(f0_8_11);
int32x4_t i0_12_15 = vcvtnq_s32_f32(f0_12_15);
int16x8_t i0_8_15 = vcombine_s16(vqmovn_s32(i0_8_11), vqmovn_s32(i0_12_15)); // int32x4 * 2 → int16x4 * 2 → int16x8
float32x4_t f1_0_3 = vrndnq_f32(vmulq_n_f32(srcv[1][4 * j], iscale[1]));
float32x4_t f1_4_7 = vrndnq_f32(vmulq_n_f32(srcv[1][4 * j + 1], iscale[1]));
float32x4_t f1_8_11 = vrndnq_f32(vmulq_n_f32(srcv[1][4 * j + 2], iscale[1]));
float32x4_t f1_12_15 = vrndnq_f32(vmulq_n_f32(srcv[1][4 * j + 3], iscale[1]));
int32x4_t i1_0_3 = vcvtnq_s32_f32(f1_0_3);
int32x4_t i1_4_7 = vcvtnq_s32_f32(f1_4_7);
int16x8_t i1_0_7 = vcombine_s16(vqmovn_s32(i1_0_3), vqmovn_s32(i1_4_7)); // int32x4 * 2 → int16x4 * 2 → int16x8
int32x4_t i1_8_11 = vcvtnq_s32_f32(f1_8_11);
int32x4_t i1_12_15 = vcvtnq_s32_f32(f1_12_15);
int16x8_t i1_8_15 = vcombine_s16(vqmovn_s32(i1_8_11), vqmovn_s32(i1_12_15)); // int32x4 * 2 → int16x4 * 2 → int16x8
// Calculate and store qs
int8x16_t i0_i1_0_7 = vcombine_s8(vqmovn_s16(i0_0_7), vqmovn_s16(i1_0_7)); // int16x8 * 2 → int8x8 * 2 → int8x16
int8x16_t i0_i1_8_15 = vcombine_s8(vqmovn_s16(i0_8_15), vqmovn_s16(i1_8_15)); // int16x8 * 2 → int8x8 * 2 → int8x16
vst1q_s8(y[i].qs + 64 * j, i0_i1_0_7);
vst1q_s8(y[i].qs + 64 * j + 32, i0_i1_8_15);
// Calculate and store bsum
int8x16_t i0_0_15 = vcombine_s8(vqmovn_s16(i0_0_7), vqmovn_s16(i0_8_15)); // int16x8 * 2 → int8x8 * 2 → int8x16
int8x16_t i1_0_15 = vcombine_s8(vqmovn_s16(i1_0_7), vqmovn_s16(i1_8_15)); // int16x8 * 2 → int8x8 * 2 → int8x16
y[i].bsums[j] = vaddlvq_s8(i0_0_15);
y[i].bsums[j + 16] = vaddlvq_s8(i1_0_15);
// Process row2 and row3
f0_0_3 = vrndnq_f32(vmulq_n_f32(srcv[2][4 * j], iscale[2]));
f0_4_7 = vrndnq_f32(vmulq_n_f32(srcv[2][4 * j + 1], iscale[2]));
f0_8_11 = vrndnq_f32(vmulq_n_f32(srcv[2][4 * j + 2], iscale[2]));
f0_12_15 = vrndnq_f32(vmulq_n_f32(srcv[2][4 * j + 3], iscale[2]));
i0_0_3 = vcvtnq_s32_f32(f0_0_3);
i0_4_7 = vcvtnq_s32_f32(f0_4_7);
i0_0_7 = vcombine_s16(vqmovn_s32(i0_0_3), vqmovn_s32(i0_4_7)); // int32x4 * 2 → int16x4 * 2 → int16x8
i0_8_11 = vcvtnq_s32_f32(f0_8_11);
i0_12_15 = vcvtnq_s32_f32(f0_12_15);
i0_8_15 = vcombine_s16(vqmovn_s32(i0_8_11), vqmovn_s32(i0_12_15)); // int32x4 * 2 → int16x4 * 2 → int16x8
f1_0_3 = vrndnq_f32(vmulq_n_f32(srcv[3][4 * j], iscale[3]));
f1_4_7 = vrndnq_f32(vmulq_n_f32(srcv[3][4 * j + 1], iscale[3]));
f1_8_11 = vrndnq_f32(vmulq_n_f32(srcv[3][4 * j + 2], iscale[3]));
f1_12_15 = vrndnq_f32(vmulq_n_f32(srcv[3][4 * j + 3], iscale[3]));
i1_0_3 = vcvtnq_s32_f32(f1_0_3);
i1_4_7 = vcvtnq_s32_f32(f1_4_7);
i1_0_7 = vcombine_s16(vqmovn_s32(i1_0_3), vqmovn_s32(i1_4_7)); // int32x4 * 2 → int16x4 * 2 → int16x8
i1_8_11 = vcvtnq_s32_f32(f1_8_11);
i1_12_15 = vcvtnq_s32_f32(f1_12_15);
i1_8_15 = vcombine_s16(vqmovn_s32(i1_8_11), vqmovn_s32(i1_12_15)); // int32x4 * 2 → int16x4 * 2 → int16x8
// Calculate and store qs
i0_i1_0_7 = vcombine_s8(vqmovn_s16(i0_0_7), vqmovn_s16(i1_0_7)); // int16x8 * 2 → int8x8 * 2 → int8x16
i0_i1_8_15 = vcombine_s8(vqmovn_s16(i0_8_15), vqmovn_s16(i1_8_15)); // int16x8 * 2 → int8x8 * 2 → int8x16
vst1q_s8(y[i].qs + 64 * j + 16, i0_i1_0_7);
vst1q_s8(y[i].qs + 64 * j + 48, i0_i1_8_15);
// Calculate and store bsum
i0_0_15 = vcombine_s8(vqmovn_s16(i0_0_7), vqmovn_s16(i0_8_15)); // int16x8 * 2 → int8x8 * 2 → int8x16
i1_0_15 = vcombine_s8(vqmovn_s16(i1_0_7), vqmovn_s16(i1_8_15)); // int16x8 * 2 → int8x8 * 2 → int8x16
y[i].bsums[j + 32] = vaddlvq_s8(i0_0_15);
y[i].bsums[j + 48] = vaddlvq_s8(i1_0_15);
}
}
#else
// NOTE:
// This C fallback follows the x86 AVX2 design rather than the ARM NEON design above,
// because the bsums layout inside block_q8_Kx4 differs between the two (see the layout
// sketch after this function).
// Since NEON is available on virtually all modern ARM platforms, this generic path is
// rarely reached nowadays, and the exceptional cases are not suitable for AI workloads.
// Logically, however, a matching ARM-layout generic version (e.g. xxx_generic_arm)
// may still be needed.
UNUSED(nb);
UNUSED(y);
ggml_quantize_mat_q8_K_4x8_generic(x, vy, k);
#endif
}
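For reference, the block_q8_Kx4 layout produced by the NEON path above can be summarized by two scalar index helpers. This is only a sketch derived from the store offsets in the loop; the helper names are illustrative and not part of the patch.
// Sketch of the ARM (NEON) layout inside one block_q8_Kx4 (names are illustrative).
// Quantized element e (0..QK_K-1) of row r (0..3): 8-byte chunks interleaved across the 4 rows.
static inline int q8_kx4_arm_qs_index(int r, int e) { return (e / 8) * 32 + r * 8 + (e % 8); }
// 16-element partial sum g (= e / 16) of row r, stored row-major (all sums of row 0, then row 1, ...).
static inline int q8_kx4_arm_bsums_index(int r, int g) { return r * 16 + g; }
The AVX2 and generic paths share the same qs interleave but store bsums in a sub-block-interleaved order, which is why the NOTE in the #else branch above points out that the generic fallback does not reproduce this ARM ordering.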
void ggml_gemv_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
const int qk = QK8_0;
const int nb = n / qk;
@@ -540,7 +675,8 @@ void ggml_gemv_q4_K_4x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
return;
}
#else
// todo, c implementation
// C implementation
ggml_gemv_q4_K_4x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
#endif
}
@@ -2656,7 +2792,8 @@ void ggml_gemm_q4_K_4x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
}
return;
#else
// todo, c implementation
// C implementation
ggml_gemm_q4_K_4x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
#endif
}

View File

@@ -287,6 +287,230 @@ void ggml_quantize_mat_q8_0_4x8(const float * GGML_RESTRICT x, void * GGML_RESTR
#endif
}
void ggml_quantize_mat_q8_K_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
assert(QK_K == 256);
assert(k % QK_K == 0);
const int nb = k / QK_K;
block_q8_Kx4 * GGML_RESTRICT y = (block_q8_Kx4 *) vy;
#if defined(__AVX2__)
float iscale[4];
__m256 srcv[4][32];
__m256 iscale_vec[4];
for (int i = 0; i < nb; i++) {
for (int row_iter = 0; row_iter < 4; row_iter++) {
// Load elements into 4 AVX vectors
__m256 v0 = _mm256_loadu_ps( x + row_iter * k + i * 256 );
__m256 v1 = _mm256_loadu_ps( x + row_iter * k + i * 256 + 8 );
__m256 v2 = _mm256_loadu_ps( x + row_iter * k + i * 256 + 16 );
__m256 v3 = _mm256_loadu_ps( x + row_iter * k + i * 256 + 24 );
// Compute max(abs(e)) for the block
const __m256 signBit = _mm256_set1_ps( -0.0f );
__m256 abs0 = _mm256_andnot_ps( signBit, v0 );
__m256 abs1 = _mm256_andnot_ps( signBit, v1 );
__m256 abs2 = _mm256_andnot_ps( signBit, v2 );
__m256 abs3 = _mm256_andnot_ps( signBit, v3 );
__m256 maxAbs = _mm256_max_ps( abs0, abs1 );
maxAbs = _mm256_max_ps( maxAbs, abs2 );
maxAbs = _mm256_max_ps( maxAbs, abs3 );
__m256 mask0 = _mm256_cmp_ps( maxAbs, v0, _CMP_EQ_OQ );
__m256 mask1 = _mm256_cmp_ps( maxAbs, v1, _CMP_EQ_OQ );
__m256 mask2 = _mm256_cmp_ps( maxAbs, v2, _CMP_EQ_OQ );
__m256 mask3 = _mm256_cmp_ps( maxAbs, v3, _CMP_EQ_OQ );
__m256 maskAbs = _mm256_or_ps(_mm256_or_ps(mask0, mask1),_mm256_or_ps(mask2, mask3));
srcv[row_iter][0] = v0;
srcv[row_iter][1] = v1;
srcv[row_iter][2] = v2;
srcv[row_iter][3] = v3;
for (int sb = 1; sb < 8; sb++) {
// Temporarily stores absolute quant values
__m256 tempAbs = maxAbs;
// Load elements into 4 AVX vectors
__m256 v0 = _mm256_loadu_ps( x + row_iter * k + i * 256 + sb * 32);
__m256 v1 = _mm256_loadu_ps( x + row_iter * k + i * 256 + sb * 32 + 8 );
__m256 v2 = _mm256_loadu_ps( x + row_iter * k + i * 256 + sb * 32 + 16 );
__m256 v3 = _mm256_loadu_ps( x + row_iter * k + i * 256 + sb * 32 + 24 );
// Compute max(abs(e)) for the block
__m256 abs0 = _mm256_andnot_ps( signBit, v0 );
__m256 abs1 = _mm256_andnot_ps( signBit, v1 );
__m256 abs2 = _mm256_andnot_ps( signBit, v2 );
__m256 abs3 = _mm256_andnot_ps( signBit, v3 );
maxAbs = _mm256_max_ps( maxAbs, abs0 );
maxAbs = _mm256_max_ps( maxAbs, abs1 );
maxAbs = _mm256_max_ps( maxAbs, abs2 );
maxAbs = _mm256_max_ps( maxAbs, abs3 );
__m256 mask_prev = _mm256_cmp_ps( tempAbs, maxAbs, _CMP_EQ_OQ );
maskAbs = _mm256_and_ps( maskAbs, mask_prev );
mask0 = _mm256_cmp_ps( maxAbs, v0, _CMP_EQ_OQ );
mask1 = _mm256_cmp_ps( maxAbs, v1, _CMP_EQ_OQ );
mask2 = _mm256_cmp_ps( maxAbs, v2, _CMP_EQ_OQ );
mask3 = _mm256_cmp_ps( maxAbs, v3, _CMP_EQ_OQ );
__m256 mask_curr = _mm256_or_ps(_mm256_or_ps(mask0, mask1),_mm256_or_ps(mask2, mask3));
maskAbs = _mm256_or_ps(maskAbs, mask_curr);
srcv[row_iter][sb * 4] = v0;
srcv[row_iter][sb * 4 + 1] = v1;
srcv[row_iter][sb * 4 + 2] = v2;
srcv[row_iter][sb * 4 + 3] = v3;
}
__m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) );
max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
const float maxScalar = _mm_cvtss_f32( max4 );
__m256 maxScalarVec = _mm256_set1_ps(maxScalar);
__m256 mask_next = _mm256_cmp_ps( maxScalarVec, maxAbs, _CMP_EQ_OQ );
__m256 finalMask = _mm256_and_ps(maskAbs, mask_next);
const int mask = _mm256_movemask_ps(finalMask);
iscale[row_iter] = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f;
if(mask) {
iscale[row_iter] = ( maxScalar != 0.0f ) ? -127.f / maxScalar: 0.0f;
}
y[i].d[row_iter] = maxScalar ? 1/iscale[row_iter] : 0;
iscale_vec[row_iter] = _mm256_set1_ps(iscale[row_iter]);
}
__m256i quants_interleaved[32];
for (int j = 0; j < 32; j++) {
// Apply the multiplier
__m256 v0 = _mm256_mul_ps(srcv[0][j], iscale_vec[0]);
__m256 v1 = _mm256_mul_ps(srcv[1][j], iscale_vec[1]);
__m256 v2 = _mm256_mul_ps(srcv[2][j], iscale_vec[2]);
__m256 v3 = _mm256_mul_ps(srcv[3][j], iscale_vec[3]);
// Round to nearest integer
v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST );
v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST );
v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST );
v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST );
// Convert floats to integers
__m256i i0 = _mm256_cvtps_epi32( v0 );
__m256i i1 = _mm256_cvtps_epi32( v1 );
__m256i i2 = _mm256_cvtps_epi32( v2 );
__m256i i3 = _mm256_cvtps_epi32( v3 );
// Convert int32 to int16
i0 = _mm256_packs_epi32( i0, i1 );
i2 = _mm256_packs_epi32( i2, i3 );
// Convert int16 to int8
i0 = _mm256_packs_epi16( i0, i2 );
// Permute and store the quantized weights in the required order after the pack instruction
const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 );
i0 = _mm256_permutevar8x32_epi32( i0, perm );
_mm256_storeu_si256((__m256i *)(y[i].qs + 32 * j), i0);
quants_interleaved[j] = i0;
}
// Masks to shuffle the quants of the corresponding sub blocks, rearranging them for the vectorized bsums computation
__m256i shuffle_mask_sb2 = _mm256_castsi128_si256(_mm_setr_epi8(0, 1, 0, 1, 4, 5, 6, 7, 8, 9, 8, 9, 12, 13, 14, 15));
shuffle_mask_sb2 = _mm256_permute2f128_si256(shuffle_mask_sb2, shuffle_mask_sb2, 0);
__m256i shuffle_mask_sb3 = _mm256_castsi128_si256(_mm_setr_epi8(0, 1, 2, 3, 0, 1, 6, 7, 8, 9, 10, 11, 8, 9, 14, 15));
shuffle_mask_sb3 = _mm256_permute2f128_si256(shuffle_mask_sb3, shuffle_mask_sb3, 0);
__m256i shuffle_mask_sb4 = _mm256_castsi128_si256(_mm_setr_epi8(0, 1, 2, 3, 4, 5, 0, 1, 8, 9, 10, 11, 12, 13, 8, 9));
shuffle_mask_sb4 = _mm256_permute2f128_si256(shuffle_mask_sb4, shuffle_mask_sb4, 0);
for (int k = 0; k < 4; k++) {
// Quants from four different sub blocks are taken
__m256i q0 = quants_interleaved[k * 8 + 0];
__m256i q1 = quants_interleaved[k * 8 + 1];
__m256i q2 = quants_interleaved[k * 8 + 2];
__m256i q3 = quants_interleaved[k * 8 + 3];
__m256i q4 = quants_interleaved[k * 8 + 4];
__m256i q5 = quants_interleaved[k * 8 + 5];
__m256i q6 = quants_interleaved[k * 8 + 6];
__m256i q7 = quants_interleaved[k * 8 + 7];
// The below code block has the first half of different sub blocks shuffled and blended so as to process 2 values from each sub block at a time
__m256i sb2_h1_shuffled = _mm256_shuffle_epi8(q2, shuffle_mask_sb2);
__m256i sb_h1_interleaved = _mm256_blend_epi16(q0, sb2_h1_shuffled, 34);
__m256i sb3_h1_shuffled = _mm256_shuffle_epi8(q4, shuffle_mask_sb3);
sb_h1_interleaved = _mm256_blend_epi16(sb_h1_interleaved, sb3_h1_shuffled, 68);
__m256i sb4_h1_shuffled = _mm256_shuffle_epi8(q6, shuffle_mask_sb4);
sb_h1_interleaved = _mm256_blend_epi16(sb_h1_interleaved, sb4_h1_shuffled, 136);
__m256i one = _mm256_set1_epi8(1);
__m256i bsums_r1 = _mm256_maddubs_epi16(one, sb_h1_interleaved);
for (int l = 0; l < 3; l++) {
// Quants value shifted to process next two values from each sub block
q0 = _mm256_srli_epi64(q0, 16);
q2 = _mm256_srli_epi64(q2, 16);
q4 = _mm256_srli_epi64(q4, 16);
q6 = _mm256_srli_epi64(q6, 16);
sb2_h1_shuffled = _mm256_shuffle_epi8(q2, shuffle_mask_sb2);
sb_h1_interleaved = _mm256_blend_epi16(q0, sb2_h1_shuffled, 34);
sb3_h1_shuffled = _mm256_shuffle_epi8(q4, shuffle_mask_sb3);
sb_h1_interleaved = _mm256_blend_epi16(sb_h1_interleaved, sb3_h1_shuffled, 68);
sb4_h1_shuffled = _mm256_shuffle_epi8(q6, shuffle_mask_sb4);
sb_h1_interleaved = _mm256_blend_epi16(sb_h1_interleaved, sb4_h1_shuffled, 136);
bsums_r1 = _mm256_add_epi16(bsums_r1, _mm256_maddubs_epi16(one, sb_h1_interleaved));
}
// The below code block has the second half of different sub blocks shuffled and blended so as to process 2 values from each sub block at a time
__m256i sb2_h2_shuffled = _mm256_shuffle_epi8(q3, shuffle_mask_sb2);
__m256i sb_h2_interleaved = _mm256_blend_epi16(q1, sb2_h2_shuffled, 34);
__m256i sb3_h2_shuffled = _mm256_shuffle_epi8(q5, shuffle_mask_sb3);
sb_h2_interleaved = _mm256_blend_epi16(sb_h2_interleaved, sb3_h2_shuffled, 68);
__m256i sb4_h2_shuffled = _mm256_shuffle_epi8(q7, shuffle_mask_sb4);
sb_h2_interleaved = _mm256_blend_epi16(sb_h2_interleaved, sb4_h2_shuffled, 136);
__m256i bsums_r2 = _mm256_maddubs_epi16(one, sb_h2_interleaved);
for (int l = 0; l < 3; l++) {
// Quants value shifted to process next two values from each sub block
q1 = _mm256_srli_epi64(q1, 16);
q3 = _mm256_srli_epi64(q3, 16);
q5 = _mm256_srli_epi64(q5, 16);
q7 = _mm256_srli_epi64(q7, 16);
sb2_h2_shuffled = _mm256_shuffle_epi8(q3, shuffle_mask_sb2);
sb_h2_interleaved = _mm256_blend_epi16(q1, sb2_h2_shuffled, 34);
sb3_h2_shuffled = _mm256_shuffle_epi8(q5, shuffle_mask_sb3);
sb_h2_interleaved = _mm256_blend_epi16(sb_h2_interleaved, sb3_h2_shuffled, 68);
sb4_h2_shuffled = _mm256_shuffle_epi8(q7, shuffle_mask_sb4);
sb_h2_interleaved = _mm256_blend_epi16(sb_h2_interleaved, sb4_h2_shuffled, 136);
bsums_r2 = _mm256_add_epi16(bsums_r2, _mm256_maddubs_epi16(one, sb_h2_interleaved));
}
// Overall bsums in interleaved fashion computed by adding results of both halves
__m256i bsums_r = _mm256_add_epi16(bsums_r1, bsums_r2);
_mm256_storeu_si256((__m256i *)(y[i].bsums + 16 * k), bsums_r);
}
}
#else
UNUSED(nb);
UNUSED(y);
ggml_quantize_mat_q8_K_4x8_generic(x, vy, k);
#endif
}
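The bsums ordering written by the four 16 * k stores above can likewise be captured in one addressing helper. This is a sketch that mirrors the consumer-side arithmetic in ggml_gemv_q4_K_4x8_q8_K_generic / ggml_gemm_q4_K_4x8_q8_K_generic further down in this commit; the helper name is illustrative and not part of the patch.
// Sketch of the AVX2/generic bsums addressing inside one block_q8_Kx4 (name is illustrative).
// For sub-block sb (0..7) of row m (0..3), the two 16-element partial sums start at:
static inline const int16_t * q8_kx4_avx2_bsums(const block_q8_Kx4 * y, int m, int sb) {
    return y->bsums + (sb / 2) * 16 + m * 4 + (sb % 2) * 2; // same as (sb * 8) - ((sb % 2) * 6) + (m * 4)
}
Adding the two int16 values at that address gives the full sum of the 32-element sub-block, which is exactly what the mins-correction loop of the generic GEMM consumes.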
//
// GEMV/GEMM templates
//

View File

@@ -227,346 +227,6 @@ void ggml_quantize_mat_q8_K_4x8_generic(const float * GGML_RESTRICT x, void * GG
}
}
void ggml_quantize_mat_q8_K_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
assert(QK_K == 256);
assert(k % QK_K == 0);
const int nb = k / QK_K;
block_q8_Kx4 * GGML_RESTRICT y = (block_q8_Kx4 *) vy;
#if defined(__ARM_NEON)
const int blck_size_interleave = 8;
float32x4_t srcv[4][64]; // 64 = QK_K/4
float iscale[4];
for (int i = 0; i < nb; i++) {
float32x4_t asrcv[64];
float32x4_t amaxv[64];
// d:
for (int row_iter = 0; row_iter < 4; row_iter++) {
for (int j = 0; j < 64; j++) srcv[row_iter][j] = vld1q_f32(x + row_iter * k + i * 256 + 4 * j);
for (int j = 0; j < 64; j++) asrcv[j] = vabsq_f32(srcv[row_iter][j]);
for (int j = 0; j < 32; j++) amaxv[2 * j] = vmaxq_f32(asrcv[2 * j], asrcv[2 * j + 1]);
for (int j = 0; j < 16; j++) amaxv[4 * j] = vmaxq_f32(amaxv[4 * j], amaxv[4 * j + 2]);
for (int j = 0; j < 8; j++) amaxv[8 * j] = vmaxq_f32(amaxv[8 * j], amaxv[8 * j + 4]);
for (int j = 0; j < 4; j++) amaxv[16 * j] = vmaxq_f32(amaxv[16 * j], amaxv[16 * j + 8]);
for (int j = 0; j < 2; j++) amaxv[32 * j] = vmaxq_f32(amaxv[32 * j], amaxv[32 * j + 16]);
for (int j = 0; j < 1; j++) amaxv[64 * j] = vmaxq_f32(amaxv[64 * j], amaxv[64 * j + 32]);
const float amax = vmaxvq_f32(amaxv[0]);
// Check whether any original (signed) value equals +amax
float32x4_t amax_vec = vdupq_n_f32(amax);
uint32x4_t mask_all = vdupq_n_u32(0);
for (int j = 0; j < 64; j++) {
uint32x4_t mask_curr = vceqq_f32(amax_vec, srcv[row_iter][j]);
mask_all = vorrq_u32(mask_all, mask_curr);
}
// Assume no element equals +amax, then check mask_all and flip the sign of iscale if one does
iscale[row_iter] = ( amax != 0.0f ) ? 127.f / amax : 0.0f;
uint32x4_t cmp = vceqq_u32(mask_all, vdupq_n_u32(0xFFFFFFFFu));
if (vmaxvq_u32(cmp) != 0) {
iscale[row_iter] = ( amax != 0.0f ) ? -127.f / amax : 0.0f;
}
y[i].d[row_iter] = amax ? 1/iscale[row_iter] : 0;
}
// qs: 8-byte interleave over the 4 rows
// bsums: generated one by one, row i finished before row i+1
// 16 iterations = QK_K / blck_size_interleave / 2
for (int j = 0; j < QK_K / blck_size_interleave / 2; j++) {
// Process row0 and row1
float32x4_t f0_0_3 = vrndnq_f32(vmulq_n_f32(srcv[0][4 * j], iscale[0]));
float32x4_t f0_4_7 = vrndnq_f32(vmulq_n_f32(srcv[0][4 * j + 1], iscale[0]));
float32x4_t f0_8_11 = vrndnq_f32(vmulq_n_f32(srcv[0][4 * j + 2], iscale[0]));
float32x4_t f0_12_15 = vrndnq_f32(vmulq_n_f32(srcv[0][4 * j + 3], iscale[0]));
int32x4_t i0_0_3 = vcvtnq_s32_f32(f0_0_3);
int32x4_t i0_4_7 = vcvtnq_s32_f32(f0_4_7);
int16x8_t i0_0_7 = vcombine_s16(vqmovn_s32(i0_0_3), vqmovn_s32(i0_4_7)); // int32x4 * 2 → int16x4 * 2 → int16x8
int32x4_t i0_8_11 = vcvtnq_s32_f32(f0_8_11);
int32x4_t i0_12_15 = vcvtnq_s32_f32(f0_12_15);
int16x8_t i0_8_15 = vcombine_s16(vqmovn_s32(i0_8_11), vqmovn_s32(i0_12_15)); // int32x4 * 2 → int16x4 * 2 → int16x8
float32x4_t f1_0_3 = vrndnq_f32(vmulq_n_f32(srcv[1][4 * j], iscale[1]));
float32x4_t f1_4_7 = vrndnq_f32(vmulq_n_f32(srcv[1][4 * j + 1], iscale[1]));
float32x4_t f1_8_11 = vrndnq_f32(vmulq_n_f32(srcv[1][4 * j + 2], iscale[1]));
float32x4_t f1_12_15 = vrndnq_f32(vmulq_n_f32(srcv[1][4 * j + 3], iscale[1]));
int32x4_t i1_0_3 = vcvtnq_s32_f32(f1_0_3);
int32x4_t i1_4_7 = vcvtnq_s32_f32(f1_4_7);
int16x8_t i1_0_7 = vcombine_s16(vqmovn_s32(i1_0_3), vqmovn_s32(i1_4_7)); // int32x4 * 2 → int16x4 * 2 → int16x8
int32x4_t i1_8_11 = vcvtnq_s32_f32(f1_8_11);
int32x4_t i1_12_15 = vcvtnq_s32_f32(f1_12_15);
int16x8_t i1_8_15 = vcombine_s16(vqmovn_s32(i1_8_11), vqmovn_s32(i1_12_15)); // int32x4 * 2 → int16x4 * 2 → int16x8
// Calculate and store qs
int8x16_t i0_i1_0_7 = vcombine_s8(vqmovn_s16(i0_0_7), vqmovn_s16(i1_0_7)); // int16x8 * 2 → int8x8 * 2 → int8x16
int8x16_t i0_i1_8_15 = vcombine_s8(vqmovn_s16(i0_8_15), vqmovn_s16(i1_8_15)); // int16x8 * 2 → int8x8 * 2 → int8x16
vst1q_s8(y[i].qs + 64 * j, i0_i1_0_7);
vst1q_s8(y[i].qs + 64 * j + 32, i0_i1_8_15);
// Calculate and store bsum
int8x16_t i0_0_15 = vcombine_s8(vqmovn_s16(i0_0_7), vqmovn_s16(i0_8_15)); // int16x8 * 2 → int8x8 * 2 → int8x16
int8x16_t i1_0_15 = vcombine_s8(vqmovn_s16(i1_0_7), vqmovn_s16(i1_8_15)); // int16x8 * 2 → int8x8 * 2 → int8x16
y[i].bsums[j] = vaddlvq_s8(i0_0_15);
y[i].bsums[j + 16] = vaddlvq_s8(i1_0_15);
// Process row2 and row3
f0_0_3 = vrndnq_f32(vmulq_n_f32(srcv[2][4 * j], iscale[2]));
f0_4_7 = vrndnq_f32(vmulq_n_f32(srcv[2][4 * j + 1], iscale[2]));
f0_8_11 = vrndnq_f32(vmulq_n_f32(srcv[2][4 * j + 2], iscale[2]));
f0_12_15 = vrndnq_f32(vmulq_n_f32(srcv[2][4 * j + 3], iscale[2]));
i0_0_3 = vcvtnq_s32_f32(f0_0_3);
i0_4_7 = vcvtnq_s32_f32(f0_4_7);
i0_0_7 = vcombine_s16(vqmovn_s32(i0_0_3), vqmovn_s32(i0_4_7)); // int32x4 * 2 → int16x4 * 2 → int16x8
i0_8_11 = vcvtnq_s32_f32(f0_8_11);
i0_12_15 = vcvtnq_s32_f32(f0_12_15);
i0_8_15 = vcombine_s16(vqmovn_s32(i0_8_11), vqmovn_s32(i0_12_15)); // int32x4 * 2 → int16x4 * 2 → int16x8
f1_0_3 = vrndnq_f32(vmulq_n_f32(srcv[3][4 * j], iscale[3]));
f1_4_7 = vrndnq_f32(vmulq_n_f32(srcv[3][4 * j + 1], iscale[3]));
f1_8_11 = vrndnq_f32(vmulq_n_f32(srcv[3][4 * j + 2], iscale[3]));
f1_12_15 = vrndnq_f32(vmulq_n_f32(srcv[3][4 * j + 3], iscale[3]));
i1_0_3 = vcvtnq_s32_f32(f1_0_3);
i1_4_7 = vcvtnq_s32_f32(f1_4_7);
i1_0_7 = vcombine_s16(vqmovn_s32(i1_0_3), vqmovn_s32(i1_4_7)); // int32x4 * 2 → int16x4 * 2 → int16x8
i1_8_11 = vcvtnq_s32_f32(f1_8_11);
i1_12_15 = vcvtnq_s32_f32(f1_12_15);
i1_8_15 = vcombine_s16(vqmovn_s32(i1_8_11), vqmovn_s32(i1_12_15)); // int32x4 * 2 → int16x4 * 2 → int16x8
// Calculate and store qs
i0_i1_0_7 = vcombine_s8(vqmovn_s16(i0_0_7), vqmovn_s16(i1_0_7)); // int16x8 * 2 → int8x8 * 2 → int8x16
i0_i1_8_15 = vcombine_s8(vqmovn_s16(i0_8_15), vqmovn_s16(i1_8_15)); // int16x8 * 2 → int8x8 * 2 → int8x16
vst1q_s8(y[i].qs + 64 * j + 16, i0_i1_0_7);
vst1q_s8(y[i].qs + 64 * j + 48, i0_i1_8_15);
// Calculate and store bsum
i0_0_15 = vcombine_s8(vqmovn_s16(i0_0_7), vqmovn_s16(i0_8_15)); // int16x8 * 2 → int8x8 * 2 → int8x16
i1_0_15 = vcombine_s8(vqmovn_s16(i1_0_7), vqmovn_s16(i1_8_15)); // int16x8 * 2 → int8x8 * 2 → int8x16
y[i].bsums[j + 32] = vaddlvq_s8(i0_0_15);
y[i].bsums[j + 48] = vaddlvq_s8(i1_0_15);
}
}
#elif defined(__AVX2__)
float iscale[4];
__m256 srcv[4][32];
__m256 iscale_vec[4];
for (int i = 0; i < nb; i++) {
for (int row_iter = 0; row_iter < 4; row_iter++) {
// Load elements into 4 AVX vectors
__m256 v0 = _mm256_loadu_ps( x + row_iter * k + i * 256 );
__m256 v1 = _mm256_loadu_ps( x + row_iter * k + i * 256 + 8 );
__m256 v2 = _mm256_loadu_ps( x + row_iter * k + i * 256 + 16 );
__m256 v3 = _mm256_loadu_ps( x + row_iter * k + i * 256 + 24 );
// Compute max(abs(e)) for the block
const __m256 signBit = _mm256_set1_ps( -0.0f );
__m256 abs0 = _mm256_andnot_ps( signBit, v0 );
__m256 abs1 = _mm256_andnot_ps( signBit, v1 );
__m256 abs2 = _mm256_andnot_ps( signBit, v2 );
__m256 abs3 = _mm256_andnot_ps( signBit, v3 );
__m256 maxAbs = _mm256_max_ps( abs0, abs1 );
maxAbs = _mm256_max_ps( maxAbs, abs2 );
maxAbs = _mm256_max_ps( maxAbs, abs3 );
__m256 mask0 = _mm256_cmp_ps( maxAbs, v0, _CMP_EQ_OQ );
__m256 mask1 = _mm256_cmp_ps( maxAbs, v1, _CMP_EQ_OQ );
__m256 mask2 = _mm256_cmp_ps( maxAbs, v2, _CMP_EQ_OQ );
__m256 mask3 = _mm256_cmp_ps( maxAbs, v3, _CMP_EQ_OQ );
__m256 maskAbs = _mm256_or_ps(_mm256_or_ps(mask0, mask1),_mm256_or_ps(mask2, mask3));
srcv[row_iter][0] = v0;
srcv[row_iter][1] = v1;
srcv[row_iter][2] = v2;
srcv[row_iter][3] = v3;
for (int sb = 1; sb < 8; sb++) {
// Temporarily stores absolute quant values
__m256 tempAbs = maxAbs;
// Load elements into 4 AVX vectors
__m256 v0 = _mm256_loadu_ps( x + row_iter * k + i * 256 + sb * 32);
__m256 v1 = _mm256_loadu_ps( x + row_iter * k + i * 256 + sb * 32 + 8 );
__m256 v2 = _mm256_loadu_ps( x + row_iter * k + i * 256 + sb * 32 + 16 );
__m256 v3 = _mm256_loadu_ps( x + row_iter * k + i * 256 + sb * 32 + 24 );
// Compute max(abs(e)) for the block
__m256 abs0 = _mm256_andnot_ps( signBit, v0 );
__m256 abs1 = _mm256_andnot_ps( signBit, v1 );
__m256 abs2 = _mm256_andnot_ps( signBit, v2 );
__m256 abs3 = _mm256_andnot_ps( signBit, v3 );
maxAbs = _mm256_max_ps( maxAbs, abs0 );
maxAbs = _mm256_max_ps( maxAbs, abs1 );
maxAbs = _mm256_max_ps( maxAbs, abs2 );
maxAbs = _mm256_max_ps( maxAbs, abs3 );
__m256 mask_prev = _mm256_cmp_ps( tempAbs, maxAbs, _CMP_EQ_OQ );
maskAbs = _mm256_and_ps( maskAbs, mask_prev );
mask0 = _mm256_cmp_ps( maxAbs, v0, _CMP_EQ_OQ );
mask1 = _mm256_cmp_ps( maxAbs, v1, _CMP_EQ_OQ );
mask2 = _mm256_cmp_ps( maxAbs, v2, _CMP_EQ_OQ );
mask3 = _mm256_cmp_ps( maxAbs, v3, _CMP_EQ_OQ );
__m256 mask_curr = _mm256_or_ps(_mm256_or_ps(mask0, mask1),_mm256_or_ps(mask2, mask3));
maskAbs = _mm256_or_ps(maskAbs, mask_curr);
srcv[row_iter][sb * 4] = v0;
srcv[row_iter][sb * 4 + 1] = v1;
srcv[row_iter][sb * 4 + 2] = v2;
srcv[row_iter][sb * 4 + 3] = v3;
}
__m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) );
max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
const float maxScalar = _mm_cvtss_f32( max4 );
__m256 maxScalarVec = _mm256_set1_ps(maxScalar);
__m256 mask_next = _mm256_cmp_ps( maxScalarVec, maxAbs, _CMP_EQ_OQ );
__m256 finalMask = _mm256_and_ps(maskAbs, mask_next);
const int mask = _mm256_movemask_ps(finalMask);
iscale[row_iter] = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f;
if(mask) {
iscale[row_iter] = ( maxScalar != 0.0f ) ? -127.f / maxScalar: 0.0f;
}
y[i].d[row_iter] = maxScalar ? 1/iscale[row_iter] : 0;
iscale_vec[row_iter] = _mm256_set1_ps(iscale[row_iter]);
}
__m256i quants_interleaved[32];
for (int j = 0; j < 32; j++) {
// Apply the multiplier
__m256 v0 = _mm256_mul_ps(srcv[0][j], iscale_vec[0]);
__m256 v1 = _mm256_mul_ps(srcv[1][j], iscale_vec[1]);
__m256 v2 = _mm256_mul_ps(srcv[2][j], iscale_vec[2]);
__m256 v3 = _mm256_mul_ps(srcv[3][j], iscale_vec[3]);
// Round to nearest integer
v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST );
v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST );
v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST );
v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST );
// Convert floats to integers
__m256i i0 = _mm256_cvtps_epi32( v0 );
__m256i i1 = _mm256_cvtps_epi32( v1 );
__m256i i2 = _mm256_cvtps_epi32( v2 );
__m256i i3 = _mm256_cvtps_epi32( v3 );
// Convert int32 to int16
i0 = _mm256_packs_epi32( i0, i1 );
i2 = _mm256_packs_epi32( i2, i3 );
// Convert int16 to int8
i0 = _mm256_packs_epi16( i0, i2 );
// Permute and store the quantized weights in the required order after the pack instruction
const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 );
i0 = _mm256_permutevar8x32_epi32( i0, perm );
_mm256_storeu_si256((__m256i *)(y[i].qs + 32 * j), i0);
quants_interleaved[j] = i0;
}
// Masks to shuffle the quants of the corresponding sub blocks, rearranging them for the vectorized bsums computation
__m256i shuffle_mask_sb2 = _mm256_castsi128_si256(_mm_setr_epi8(0, 1, 0, 1, 4, 5, 6, 7, 8, 9, 8, 9, 12, 13, 14, 15));
shuffle_mask_sb2 = _mm256_permute2f128_si256(shuffle_mask_sb2, shuffle_mask_sb2, 0);
__m256i shuffle_mask_sb3 = _mm256_castsi128_si256(_mm_setr_epi8(0, 1, 2, 3, 0, 1, 6, 7, 8, 9, 10, 11, 8, 9, 14, 15));
shuffle_mask_sb3 = _mm256_permute2f128_si256(shuffle_mask_sb3, shuffle_mask_sb3, 0);
__m256i shuffle_mask_sb4 = _mm256_castsi128_si256(_mm_setr_epi8(0, 1, 2, 3, 4, 5, 0, 1, 8, 9, 10, 11, 12, 13, 8, 9));
shuffle_mask_sb4 = _mm256_permute2f128_si256(shuffle_mask_sb4, shuffle_mask_sb4, 0);
for (int k = 0; k < 4; k++) {
// Quants from four different sub blocks are taken
__m256i q0 = quants_interleaved[k * 8 + 0];
__m256i q1 = quants_interleaved[k * 8 + 1];
__m256i q2 = quants_interleaved[k * 8 + 2];
__m256i q3 = quants_interleaved[k * 8 + 3];
__m256i q4 = quants_interleaved[k * 8 + 4];
__m256i q5 = quants_interleaved[k * 8 + 5];
__m256i q6 = quants_interleaved[k * 8 + 6];
__m256i q7 = quants_interleaved[k * 8 + 7];
// The below code block has the first half of different sub blocks shuffled and blended so as to process 2 values from each sub block at a time
__m256i sb2_h1_shuffled = _mm256_shuffle_epi8(q2, shuffle_mask_sb2);
__m256i sb_h1_interleaved = _mm256_blend_epi16(q0, sb2_h1_shuffled, 34);
__m256i sb3_h1_shuffled = _mm256_shuffle_epi8(q4, shuffle_mask_sb3);
sb_h1_interleaved = _mm256_blend_epi16(sb_h1_interleaved, sb3_h1_shuffled, 68);
__m256i sb4_h1_shuffled = _mm256_shuffle_epi8(q6, shuffle_mask_sb4);
sb_h1_interleaved = _mm256_blend_epi16(sb_h1_interleaved, sb4_h1_shuffled, 136);
__m256i one = _mm256_set1_epi8(1);
__m256i bsums_r1 = _mm256_maddubs_epi16(one, sb_h1_interleaved);
for (int l = 0; l < 3; l++) {
// Quants value shifted to process next two values from each sub block
q0 = _mm256_srli_epi64(q0, 16);
q2 = _mm256_srli_epi64(q2, 16);
q4 = _mm256_srli_epi64(q4, 16);
q6 = _mm256_srli_epi64(q6, 16);
sb2_h1_shuffled = _mm256_shuffle_epi8(q2, shuffle_mask_sb2);
sb_h1_interleaved = _mm256_blend_epi16(q0, sb2_h1_shuffled, 34);
sb3_h1_shuffled = _mm256_shuffle_epi8(q4, shuffle_mask_sb3);
sb_h1_interleaved = _mm256_blend_epi16(sb_h1_interleaved, sb3_h1_shuffled, 68);
sb4_h1_shuffled = _mm256_shuffle_epi8(q6, shuffle_mask_sb4);
sb_h1_interleaved = _mm256_blend_epi16(sb_h1_interleaved, sb4_h1_shuffled, 136);
bsums_r1 = _mm256_add_epi16(bsums_r1, _mm256_maddubs_epi16(one, sb_h1_interleaved));
}
// The below code block has the second half of different sub blocks shuffled and blended so as to process 2 values from each sub block at a time
__m256i sb2_h2_shuffled = _mm256_shuffle_epi8(q3, shuffle_mask_sb2);
__m256i sb_h2_interleaved = _mm256_blend_epi16(q1, sb2_h2_shuffled, 34);
__m256i sb3_h2_shuffled = _mm256_shuffle_epi8(q5, shuffle_mask_sb3);
sb_h2_interleaved = _mm256_blend_epi16(sb_h2_interleaved, sb3_h2_shuffled, 68);
__m256i sb4_h2_shuffled = _mm256_shuffle_epi8(q7, shuffle_mask_sb4);
sb_h2_interleaved = _mm256_blend_epi16(sb_h2_interleaved, sb4_h2_shuffled, 136);
__m256i bsums_r2 = _mm256_maddubs_epi16(one, sb_h2_interleaved);
for (int l = 0; l < 3; l++) {
// Quants value shifted to process next two values from each sub block
q1 = _mm256_srli_epi64(q1, 16);
q3 = _mm256_srli_epi64(q3, 16);
q5 = _mm256_srli_epi64(q5, 16);
q7 = _mm256_srli_epi64(q7, 16);
sb2_h2_shuffled = _mm256_shuffle_epi8(q3, shuffle_mask_sb2);
sb_h2_interleaved = _mm256_blend_epi16(q1, sb2_h2_shuffled, 34);
sb3_h2_shuffled = _mm256_shuffle_epi8(q5, shuffle_mask_sb3);
sb_h2_interleaved = _mm256_blend_epi16(sb_h2_interleaved, sb3_h2_shuffled, 68);
sb4_h2_shuffled = _mm256_shuffle_epi8(q7, shuffle_mask_sb4);
sb_h2_interleaved = _mm256_blend_epi16(sb_h2_interleaved, sb4_h2_shuffled, 136);
bsums_r2 = _mm256_add_epi16(bsums_r2, _mm256_maddubs_epi16(one, sb_h2_interleaved));
}
// Overall bsums in interleaved fashion computed by adding results of both halves
__m256i bsums_r = _mm256_add_epi16(bsums_r1, bsums_r2);
_mm256_storeu_si256((__m256i *)(y[i].bsums + 16 * k), bsums_r);
}
}
#else
// NOTE: This default C implementation is aligned with the AVX2 implementation, but differs from the ARM implementation,
// especially in the bsums arrangement.
UNUSED(nb);
UNUSED(y);
ggml_quantize_mat_q8_K_4x8_generic(x, vy, k);
#endif
}
} // extern "C"
template <int64_t INTER_SIZE, ggml_type PARAM_TYPE>
@@ -731,6 +391,87 @@ void ggml_gemv_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs,
}
}
void ggml_gemv_q4_K_4x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
const int qk = QK_K;
const int nb = n / qk;
const int ncols_interleaved = 4;
const int blocklen = 8;
assert (n % qk == 0);
assert (nc % ncols_interleaved == 0);
UNUSED(s);
UNUSED(bs);
UNUSED(vx);
UNUSED(vy);
UNUSED(nr);
UNUSED(nc);
UNUSED(nb);
UNUSED(ncols_interleaved);
UNUSED(blocklen);
uint8_t scales[4][8]; // scales of the 8 sub-blocks for each of the 4 interleaved q4_K columns
uint8_t mins[4][8]; // mins of the 8 sub-blocks for each of the 4 interleaved q4_K columns
float sumf[4]; // 1x4 tile: accumulated result
float sum_minf[4]; // 1x4 tile: accumulated mins correction, subtracted at the end
int sumi1;
int sumi2;
int sumi;
const block_q8_K * a_ptr = (const block_q8_K *) vy;
// loop on n dimension, each iteration works on 4 columns
for (int x = 0; x < nc / ncols_interleaved; x++) {
const block_q4_Kx4 * b_ptr = (const block_q4_Kx4 *) vx + (x * nb);
// initialize results for 4 cols
for (int j = 0; j < ncols_interleaved; j++) {
sumf[j] = 0.0;
sum_minf[j] = 0.0;
}
// loop on k dimension, each iteration works on 1 q4_kx4 block
for (int n = 0; n < nb; n++) {
// prepare scales and mins for 4 cols
for (int j = 0; j < ncols_interleaved; j++) {
for (int i = 0; i < 8; i++) {
scales[j][i] = b_ptr[n].scales[i * 8 + j];
mins[j][i] = b_ptr[n].scales[i * 8 + j + ncols_interleaved];
}
}
// core loop: each iteration works on an interleaved unit (four 8-byte segments from 4 cols)
for (int k = 0; k < (qk / (2 * blocklen)); k++) {
for (int j = 0; j < ncols_interleaved; j++) {
sumi1 = 0;
sumi2 = 0;
sumi = 0;
for (int i = 0; i < blocklen; i++) {
const int v0 = (int8_t)(b_ptr[n].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xf);
const int v1 = (int8_t)(b_ptr[n].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4);
sumi1 = (v0 * a_ptr[n].qs[(k / 2) * 32 + (k % 2) * blocklen + i]);
sumi2 = (v1 * a_ptr[n].qs[(k / 2) * 32 + (k % 2) * blocklen + i + 16]);
uint8_t scale = scales[j][k / 2];
sumi += sumi1 * scale + sumi2 * scale;
}
sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[n].d[j]) * a_ptr[n].d;
}
}
// accumulate the per-sub-block mins correction, subtracted from the final result below
for (int j = 0; j < ncols_interleaved; j++) {
for (int i = 0; i < QK_K / 32; i++) {
sum_minf[j] += mins[j][i] * (a_ptr[n].bsums[i * 2] + a_ptr[n].bsums[i * 2 + 1]) * GGML_CPU_FP16_TO_FP32(b_ptr[n].dmin[j]) * a_ptr[n].d;
}
}
}
// save results
for (int j = 0; j < ncols_interleaved; j++) {
s[x * ncols_interleaved + j] = sumf[j] - sum_minf[j];
}
}
}
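The nibble and scale bookkeeping in the loop above can be hard to read in its flattened form; a scalar per-element accessor makes the addressing explicit. This is only a sketch, assuming it sits next to the generic kernels so that block_q4_Kx4 and GGML_CPU_FP16_TO_FP32 are in scope; the function name is illustrative and not part of the patch.
// Sketch: dequantize element e (0..QK_K-1) of interleaved column j (0..3) from one block_q4_Kx4.
static float q4_kx4_get_elem(const block_q4_Kx4 * b, int j, int e) {
    const int sub  = e / 32;                  // 32-element sub-block, selects the 8-bit scale/min
    const int half = (e % 32) / 16;           // 0: low nibble, 1: high nibble
    const int k    = sub * 2 + (e % 16) / 8;  // interleaved 8-byte unit, 0..15
    const int i    = e % 8;                   // byte inside the unit
    const uint8_t  q = b->qs[k * 32 + j * 8 + i];
    const uint8_t sc = b->scales[sub * 8 + j];      // unpacked as scales[j][sub] above
    const uint8_t mn = b->scales[sub * 8 + j + 4];  // unpacked as mins[j][sub] above
    const int      v = half ? (q >> 4) : (q & 0xf);
    return GGML_CPU_FP16_TO_FP32(b->d[j]) * sc * v - GGML_CPU_FP16_TO_FP32(b->dmin[j]) * mn;
}
With this accessor, the output for column x * 4 + j is effectively the dot product of the dequantized column with the q8_K activations, which is what the sumf/sum_minf split above computes without dequantizing the weights explicitly.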
void ggml_gemv_q4_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
const int qk = QK_K;
const int nb = n / qk;
@@ -1290,6 +1031,95 @@ void ggml_gemm_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs,
}
}
void ggml_gemm_q4_K_4x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
const int qk = QK_K;
const int nb = n / qk;
const int ncols_interleaved = 4;
const int blocklen = 8;
assert (n % qk == 0);
assert (nr % 4 == 0);
assert (nc % ncols_interleaved == 0);
UNUSED(nb);
UNUSED(ncols_interleaved);
UNUSED(blocklen);
uint8_t scales[4][8]; // scales of the 8 sub-blocks for each of the 4 interleaved q4_K columns
uint8_t mins[4][8]; // mins of the 8 sub-blocks for each of the 4 interleaved q4_K columns
float sumf[4][4]; // 4x4 tile: accumulated result
float sum_minf[4][4]; // 4x4 tile: accumulated mins correction, subtracted at the end
int sumi1;
int sumi2;
int sumi;
// loop on m dimension, each iteration works on 4 rows
for (int y = 0; y < nr / 4; y++) {
const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb);
// loop on n dimension, each iteration works on 4 columns
for (int x = 0; x < nc / ncols_interleaved; x++) {
const block_q4_Kx4 * b_ptr = (const block_q4_Kx4 *) vx + (x * nb);
// initialize results for 4 cols
for (int m = 0; m < 4; m++) {
for (int j = 0; j < ncols_interleaved; j++) {
sumf[m][j] = 0.0;
sum_minf[m][j] = 0.0;
}
}
// loop on k dimension, each iteration works on 1 q4_kx4 block
for (int n = 0; n < nb; n++) {
// prepare scales and mins for 4 cols
for (int j = 0; j < ncols_interleaved; j++) {
for (int i = 0; i < 8; i++) {
scales[j][i] = b_ptr[n].scales[i * 8 + j];
mins[j][i] = b_ptr[n].scales[i * 8 + j + ncols_interleaved];
}
}
// core loop: each iteration works on an interleaved unit (four 8-byte segments from 4 cols)
for (int k = 0; k < (qk / (2 * blocklen)); k++) {
for (int m = 0; m < 4; m++) {
for (int j = 0; j < ncols_interleaved; j++) {
sumi1 = 0;
sumi2 = 0;
sumi = 0;
for (int i = 0; i < blocklen; i++) {
const int v0 = (int8_t)(b_ptr[n].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xf);
const int v1 = (int8_t)(b_ptr[n].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4);
sumi1 = (v0 * a_ptr[n].qs[(k / 2) * 128 + (k % 2) * 4 * blocklen + m * blocklen + i]);
sumi2 = (v1 * a_ptr[n].qs[(k / 2) * 128 + (k % 2) * 4 * blocklen + m * blocklen + i + 64]);
uint8_t scale = scales[j][k / 2];
sumi += sumi1 * scale + sumi2 * scale;
}
sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[n].d[j]) * a_ptr[n].d[m];
}
}
}
// accumulate the per-sub-block mins correction, subtracted from the final result below
for (int m = 0; m < 4; m++) {
for (int j = 0; j < ncols_interleaved; j++) {
for (int i = 0; i < QK_K / 32; i++) {
const int16_t *bsums = a_ptr[n].bsums + (i * 8) - ((i % 2) * 6) + (m * 4);
sum_minf[m][j] += mins[j][i] * (bsums[0] + bsums[1]) * GGML_CPU_FP16_TO_FP32(b_ptr[n].dmin[j]) * a_ptr[n].d[m];
}
}
}
}
// save results
for (int m = 0; m < 4; m++) {
for (int j = 0; j < ncols_interleaved; j++) {
s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j] - sum_minf[m][j];
}
}
}
}
}
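Since the generic GEMM expects activations in the generic (AVX2-aligned) block_q8_Kx4 layout noted earlier in this commit, it pairs with ggml_quantize_mat_q8_K_4x8_generic rather than with the NEON quantizer. A hypothetical driver, sketched here only to show the expected shapes and strides (buffer names and sizes are assumptions, not part of the patch):
// act: nr x K row-major floats; wq4: (nc/4) * (K/QK_K) repacked block_q4_Kx4;
// q8 : (nr/4) * (K/QK_K) block_q8_Kx4 scratch; out: nr x nc row-major floats.
static void q4_kx4_gemm_ref(const float * act, const void * wq4, void * q8,
                            float * out, int nr, int nc, int K) {
    const int nb = K / QK_K;
    for (int y = 0; y < nr / 4; y++) {
        // quantize 4 activation rows at a time into the interleaved q8_Kx4 format
        ggml_quantize_mat_q8_K_4x8_generic(act + (size_t) y * 4 * K,
                                           (block_q8_Kx4 *) q8 + (size_t) y * nb, K);
    }
    ggml_gemm_q4_K_4x8_q8_K_generic(K, out, /*bs=*/(size_t) nc, wq4, q8, nr, nc);
}
Each 4x4 tile lands at s[(y * 4 + m) * bs + x * 4 + j], so with bs = nc the result is a plain nr x nc row-major matrix.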
void ggml_gemm_q4_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
const int qk = QK_K;
const int nb = n / qk;
@@ -3037,8 +2867,10 @@ static size_t ggml_backend_cpu_repack_buffer_type_get_alignment(ggml_backend_buf
GGML_UNUSED(buft);
}
// size calculation after q4_kx4 repacking; it differs from the traditional type size
size_t ggml_nbytes_q4_kx4(const struct ggml_tensor * tensor) {
// For the AArch64 q4_K repack implementation only:
// tensor storage size calculation after q4_Kx4 repacking, which differs from the default
// because sizeof(block_q4_Kx4) is slightly larger than 4 * sizeof(block_q4_K).
static inline size_t ggml_nbytes_q4_kx4(const struct ggml_tensor * tensor) {
size_t nbytes;
const size_t blck_size = 256;
const size_t type_size = sizeof(block_q4_Kx4) / 4;