ggml : add C implementation for the "Q4_K quantization for AArch64" patch
This commit is contained in:
parent
2f3dfe2e74
commit
8a4e25d796
@ -210,6 +210,141 @@ void ggml_quantize_mat_q8_0_4x8(const float * GGML_RESTRICT x, void * GGML_RESTR
|
|||
#endif
|
||||
}
|
||||
|
||||
void ggml_quantize_mat_q8_K_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
|
||||
assert(QK_K == 256);
|
||||
assert(k % QK_K == 0);
|
||||
const int nb = k / QK_K;
|
||||
|
||||
block_q8_Kx4 * GGML_RESTRICT y = (block_q8_Kx4 *) vy;
|
||||
|
||||
#if defined(__ARM_NEON)
|
||||
const int blck_size_interleave = 8;
|
||||
float32x4_t srcv[4][64]; // 64 = QK_K/4
|
||||
float iscale[4];
|
||||
|
||||
for (int i = 0; i < nb; i++) {
|
||||
float32x4_t asrcv[64];
|
||||
float32x4_t amaxv[64];
|
||||
|
||||
// d: compute the per-row scale d (and its inverse iscale) for the 4 rows
|
||||
for (int row_iter = 0; row_iter < 4; row_iter++) {
|
||||
for (int j = 0; j < 64; j++) srcv[row_iter][j] = vld1q_f32(x + row_iter * k + i * 256 + 4 * j);
|
||||
for (int j = 0; j < 64; j++) asrcv[j] = vabsq_f32(srcv[row_iter][j]);
|
||||
|
||||
for (int j = 0; j < 32; j++) amaxv[2 * j] = vmaxq_f32(asrcv[2 * j], asrcv[2 * j + 1]);
|
||||
for (int j = 0; j < 16; j++) amaxv[4 * j] = vmaxq_f32(amaxv[4 * j], amaxv[4 * j + 2]);
|
||||
for (int j = 0; j < 8; j++) amaxv[8 * j] = vmaxq_f32(amaxv[8 * j], amaxv[8 * j + 4]);
|
||||
for (int j = 0; j < 4; j++) amaxv[16 * j] = vmaxq_f32(amaxv[16 * j], amaxv[16 * j + 8]);
|
||||
for (int j = 0; j < 2; j++) amaxv[32 * j] = vmaxq_f32(amaxv[32 * j], amaxv[32 * j + 16]);
|
||||
for (int j = 0; j < 1; j++) amaxv[64 * j] = vmaxq_f32(amaxv[64 * j], amaxv[64 * j + 32]);
|
||||
|
||||
const float amax = vmaxvq_f32(amaxv[0]);
|
||||
|
||||
// Check whether any original (signed) value equals +amax
|
||||
float32x4_t amax_vec = vdupq_n_f32(amax);
|
||||
uint32x4_t mask_all = vdupq_n_u32(0);
|
||||
for (int j = 0; j < 64; j++) {
|
||||
uint32x4_t mask_curr = vceqq_f32(amax_vec, srcv[row_iter][j]);
|
||||
mask_all = vorrq_u32(mask_all, mask_curr);
|
||||
}
|
||||
|
||||
// Assume no original value equals +amax; if mask_all shows one does, flip the sign of iscale below
|
||||
iscale[row_iter] = ( amax != 0.0f ) ? 127.f / amax : 0.0f;
|
||||
uint32x4_t cmp = vceqq_u32(mask_all, vdupq_n_u32(0xFFFFFFFFu));
|
||||
if (vmaxvq_u32(cmp) != 0) {
|
||||
iscale[row_iter] = ( amax != 0.0f ) ? -127.f / amax : 0.0f;
|
||||
}
|
||||
|
||||
y[i].d[row_iter] = amax ? 1/iscale[row_iter] : 0;
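// Scalar sketch of the per-row scale logic above (illustrative comment only):
//   amax   = max over the 256 values of |x_row[e]|
//   sign   = (any x_row[e] == +amax) ? -1 : +1
//   iscale = (amax != 0) ? sign * 127 / amax : 0
//   d      = (amax != 0) ? 1 / iscale : 0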
|
||||
}
|
||||
|
||||
// qs: 8-byte groups interleaved across the 4 rows
// bsums: generated one by one; row i is computed before row i+1
// The loop below runs QK_K / 8 / 2 = 16 times, handling two 8-value groups per row per iteration.
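// Resulting qs layout for group j (64 bytes), as a sketch inferred from the stores below:
//   [ row0 v(16j..16j+7)    | row1 v(16j..16j+7)    | row2 v(16j..16j+7)    | row3 v(16j..16j+7)    |
//     row0 v(16j+8..16j+15) | row1 v(16j+8..16j+15) | row2 v(16j+8..16j+15) | row3 v(16j+8..16j+15) ]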
|
||||
for (int j = 0; j < QK_K / blck_size_interleave / 2; j++) {
|
||||
// Process row0 and row1
|
||||
float32x4_t f0_0_3 = vrndnq_f32(vmulq_n_f32(srcv[0][4 * j], iscale[0]));
|
||||
float32x4_t f0_4_7 = vrndnq_f32(vmulq_n_f32(srcv[0][4 * j + 1], iscale[0]));
|
||||
float32x4_t f0_8_11 = vrndnq_f32(vmulq_n_f32(srcv[0][4 * j + 2], iscale[0]));
|
||||
float32x4_t f0_12_15 = vrndnq_f32(vmulq_n_f32(srcv[0][4 * j + 3], iscale[0]));
|
||||
int32x4_t i0_0_3 = vcvtnq_s32_f32(f0_0_3);
|
||||
int32x4_t i0_4_7 = vcvtnq_s32_f32(f0_4_7);
|
||||
int16x8_t i0_0_7 = vcombine_s16(vqmovn_s32(i0_0_3), vqmovn_s32(i0_4_7)); // int32x4 * 2 → int16x4 * 2 → int16x8
|
||||
int32x4_t i0_8_11 = vcvtnq_s32_f32(f0_8_11);
|
||||
int32x4_t i0_12_15 = vcvtnq_s32_f32(f0_12_15);
|
||||
int16x8_t i0_8_15 = vcombine_s16(vqmovn_s32(i0_8_11), vqmovn_s32(i0_12_15)); // int32x4 * 2 → int16x4 * 2 → int16x8
|
||||
|
||||
float32x4_t f1_0_3 = vrndnq_f32(vmulq_n_f32(srcv[1][4 * j], iscale[1]));
|
||||
float32x4_t f1_4_7 = vrndnq_f32(vmulq_n_f32(srcv[1][4 * j + 1], iscale[1]));
|
||||
float32x4_t f1_8_11 = vrndnq_f32(vmulq_n_f32(srcv[1][4 * j + 2], iscale[1]));
|
||||
float32x4_t f1_12_15 = vrndnq_f32(vmulq_n_f32(srcv[1][4 * j + 3], iscale[1]));
|
||||
int32x4_t i1_0_3 = vcvtnq_s32_f32(f1_0_3);
|
||||
int32x4_t i1_4_7 = vcvtnq_s32_f32(f1_4_7);
|
||||
int16x8_t i1_0_7 = vcombine_s16(vqmovn_s32(i1_0_3), vqmovn_s32(i1_4_7)); // int32x4 * 2 → int16x4 * 2 → int16x8
|
||||
int32x4_t i1_8_11 = vcvtnq_s32_f32(f1_8_11);
|
||||
int32x4_t i1_12_15 = vcvtnq_s32_f32(f1_12_15);
|
||||
int16x8_t i1_8_15 = vcombine_s16(vqmovn_s32(i1_8_11), vqmovn_s32(i1_12_15)); // int32x4 * 2 → int16x4 * 2 → int16x8
|
||||
|
||||
// Calculate and store qs
|
||||
int8x16_t i0_i1_0_7 = vcombine_s8(vqmovn_s16(i0_0_7), vqmovn_s16(i1_0_7)); // int16x8 * 2 → int8x8 * 2 → int8x16
|
||||
int8x16_t i0_i1_8_15 = vcombine_s8(vqmovn_s16(i0_8_15), vqmovn_s16(i1_8_15)); // int16x8 * 2 → int8x8 * 2 → int8x16
|
||||
vst1q_s8(y[i].qs + 64 * j, i0_i1_0_7);
|
||||
vst1q_s8(y[i].qs + 64 * j + 32, i0_i1_8_15);
|
||||
// Calculate and store bsum
|
||||
int8x16_t i0_0_15 = vcombine_s8(vqmovn_s16(i0_0_7), vqmovn_s16(i0_8_15)); // int16x8 * 2 → int8x8 * 2 → int8x16
|
||||
int8x16_t i1_0_15 = vcombine_s8(vqmovn_s16(i1_0_7), vqmovn_s16(i1_8_15)); // int16x8 * 2 → int8x8 * 2 → int8x16
|
||||
y[i].bsums[j] = vaddlvq_s8(i0_0_15);
|
||||
y[i].bsums[j + 16] = vaddlvq_s8(i1_0_15);
|
||||
|
||||
// Process row2 and row3
|
||||
f0_0_3 = vrndnq_f32(vmulq_n_f32(srcv[2][4 * j], iscale[2]));
|
||||
f0_4_7 = vrndnq_f32(vmulq_n_f32(srcv[2][4 * j + 1], iscale[2]));
|
||||
f0_8_11 = vrndnq_f32(vmulq_n_f32(srcv[2][4 * j + 2], iscale[2]));
|
||||
f0_12_15 = vrndnq_f32(vmulq_n_f32(srcv[2][4 * j + 3], iscale[2]));
|
||||
i0_0_3 = vcvtnq_s32_f32(f0_0_3);
|
||||
i0_4_7 = vcvtnq_s32_f32(f0_4_7);
|
||||
i0_0_7 = vcombine_s16(vqmovn_s32(i0_0_3), vqmovn_s32(i0_4_7)); // int32x4 * 2 → int16x4 * 2 → int16x8
|
||||
i0_8_11 = vcvtnq_s32_f32(f0_8_11);
|
||||
i0_12_15 = vcvtnq_s32_f32(f0_12_15);
|
||||
i0_8_15 = vcombine_s16(vqmovn_s32(i0_8_11), vqmovn_s32(i0_12_15)); // int32x4 * 2 → int16x4 * 2 → int16x8
|
||||
|
||||
f1_0_3 = vrndnq_f32(vmulq_n_f32(srcv[3][4 * j], iscale[3]));
|
||||
f1_4_7 = vrndnq_f32(vmulq_n_f32(srcv[3][4 * j + 1], iscale[3]));
|
||||
f1_8_11 = vrndnq_f32(vmulq_n_f32(srcv[3][4 * j + 2], iscale[3]));
|
||||
f1_12_15 = vrndnq_f32(vmulq_n_f32(srcv[3][4 * j + 3], iscale[3]));
|
||||
i1_0_3 = vcvtnq_s32_f32(f1_0_3);
|
||||
i1_4_7 = vcvtnq_s32_f32(f1_4_7);
|
||||
i1_0_7 = vcombine_s16(vqmovn_s32(i1_0_3), vqmovn_s32(i1_4_7)); // int32x4 * 2 → int16x4 * 2 → int16x8
|
||||
i1_8_11 = vcvtnq_s32_f32(f1_8_11);
|
||||
i1_12_15 = vcvtnq_s32_f32(f1_12_15);
|
||||
i1_8_15 = vcombine_s16(vqmovn_s32(i1_8_11), vqmovn_s32(i1_12_15)); // int32x4 * 2 → int16x4 * 2 → int16x8
|
||||
|
||||
// Calculate and store qs
|
||||
i0_i1_0_7 = vcombine_s8(vqmovn_s16(i0_0_7), vqmovn_s16(i1_0_7)); // int16x8 * 2 → int8x8 * 2 → int8x16
|
||||
i0_i1_8_15 = vcombine_s8(vqmovn_s16(i0_8_15), vqmovn_s16(i1_8_15)); // int16x8 * 2 → int8x8 * 2 → int8x16
|
||||
vst1q_s8(y[i].qs + 64 * j + 16, i0_i1_0_7);
|
||||
vst1q_s8(y[i].qs + 64 * j + 48, i0_i1_8_15);
|
||||
// Calculate and store bsum
|
||||
i0_0_15 = vcombine_s8(vqmovn_s16(i0_0_7), vqmovn_s16(i0_8_15)); // int16x8 * 2 → int8x8 * 2 → int8x16
|
||||
i1_0_15 = vcombine_s8(vqmovn_s16(i1_0_7), vqmovn_s16(i1_8_15)); // int16x8 * 2 → int8x8 * 2 → int8x16
|
||||
y[i].bsums[j + 32] = vaddlvq_s8(i0_0_15);
|
||||
y[i].bsums[j + 48] = vaddlvq_s8(i1_0_15);
|
||||
}
|
||||
}
|
||||
|
||||
#else
|
||||
// NOTE:
// This C implementation follows the x86 AVX2 design rather than the ARM NEON design above,
// because the bsums layout inside block_q8_Kx4 differs between the two designs.
// Since NEON is available on virtually all modern ARM platforms, this generic path is rarely
// reached there (and the exceptions are not practical targets for this workload).
// Logically, though, a matching generic version for ARM (e.g. xxx_generic_arm) may still be needed.
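// Layout sketch, inferred from the stores above and the AVX2 path (informal):
//   NEON path         : bsums[0..15] = row 0, bsums[16..31] = row 1, bsums[32..47] = row 2, bsums[48..63] = row 3,
//                       each entry being the sum of 16 consecutive quants (two entries per 32-value sub-block).
//   AVX2/generic path : the same 64 partial sums, stored interleaved across rows and sub-block pairs in the
//                       order produced by _mm256_maddubs_epi16 in the AVX2 quantizer.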
|
||||
UNUSED(nb);
|
||||
UNUSED(y);
|
||||
ggml_quantize_mat_q8_K_4x8_generic(x, vy, k);
|
||||
#endif
|
||||
}
|
||||
|
||||
void ggml_gemv_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
||||
const int qk = QK8_0;
|
||||
const int nb = n / qk;
|
||||
|
|
@ -540,7 +675,8 @@ void ggml_gemv_q4_K_4x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
|
|||
return;
|
||||
}
|
||||
#else
|
||||
// todo, c implementation
|
||||
// C implementation
|
||||
ggml_gemv_q4_K_4x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
|
||||
#endif
|
||||
}
|
||||
|
||||
|
|
@ -2656,7 +2792,8 @@ void ggml_gemm_q4_K_4x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
|
|||
}
|
||||
return;
|
||||
#else
|
||||
// todo, c implementation
|
||||
// C implementation
|
||||
ggml_gemm_q4_K_4x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
|
||||
#endif
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -287,6 +287,230 @@ void ggml_quantize_mat_q8_0_4x8(const float * GGML_RESTRICT x, void * GGML_RESTR
|
|||
#endif
|
||||
}
|
||||
|
||||
void ggml_quantize_mat_q8_K_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
|
||||
assert(QK_K == 256);
|
||||
assert(k % QK_K == 0);
|
||||
const int nb = k / QK_K;
|
||||
|
||||
block_q8_Kx4 * GGML_RESTRICT y = (block_q8_Kx4 *) vy;
|
||||
|
||||
#if defined(__AVX2__)
|
||||
float iscale[4];
|
||||
__m256 srcv[4][32];
|
||||
__m256 iscale_vec[4];
|
||||
|
||||
for (int i = 0; i < nb; i++) {
|
||||
for (int row_iter = 0; row_iter < 4; row_iter++) {
|
||||
// Load elements into 4 AVX vectors
|
||||
__m256 v0 = _mm256_loadu_ps( x + row_iter * k + i * 256 );
|
||||
__m256 v1 = _mm256_loadu_ps( x + row_iter * k + i * 256 + 8 );
|
||||
__m256 v2 = _mm256_loadu_ps( x + row_iter * k + i * 256 + 16 );
|
||||
__m256 v3 = _mm256_loadu_ps( x + row_iter * k + i * 256 + 24 );
|
||||
|
||||
// Compute max(abs(e)) for the block
|
||||
const __m256 signBit = _mm256_set1_ps( -0.0f );
|
||||
__m256 abs0 = _mm256_andnot_ps( signBit, v0 );
|
||||
__m256 abs1 = _mm256_andnot_ps( signBit, v1 );
|
||||
__m256 abs2 = _mm256_andnot_ps( signBit, v2 );
|
||||
__m256 abs3 = _mm256_andnot_ps( signBit, v3 );
|
||||
|
||||
__m256 maxAbs = _mm256_max_ps( abs0, abs1 );
|
||||
maxAbs = _mm256_max_ps( maxAbs, abs2 );
|
||||
maxAbs = _mm256_max_ps( maxAbs, abs3 );
|
||||
|
||||
__m256 mask0 = _mm256_cmp_ps( maxAbs, v0, _CMP_EQ_OQ );
|
||||
__m256 mask1 = _mm256_cmp_ps( maxAbs, v1, _CMP_EQ_OQ );
|
||||
__m256 mask2 = _mm256_cmp_ps( maxAbs, v2, _CMP_EQ_OQ );
|
||||
__m256 mask3 = _mm256_cmp_ps( maxAbs, v3, _CMP_EQ_OQ );
|
||||
|
||||
__m256 maskAbs = _mm256_or_ps(_mm256_or_ps(mask0, mask1),_mm256_or_ps(mask2, mask3));
|
||||
|
||||
srcv[row_iter][0] = v0;
|
||||
srcv[row_iter][1] = v1;
|
||||
srcv[row_iter][2] = v2;
|
||||
srcv[row_iter][3] = v3;
|
||||
|
||||
for (int sb = 1; sb < 8; sb++) {
|
||||
// Remember the running max of absolute values from the previous sub-blocks
|
||||
__m256 tempAbs = maxAbs;
|
||||
|
||||
// Load elements into 4 AVX vectors
|
||||
__m256 v0 = _mm256_loadu_ps( x + row_iter * k + i * 256 + sb * 32);
|
||||
__m256 v1 = _mm256_loadu_ps( x + row_iter * k + i * 256 + sb * 32 + 8 );
|
||||
__m256 v2 = _mm256_loadu_ps( x + row_iter * k + i * 256 + sb * 32 + 16 );
|
||||
__m256 v3 = _mm256_loadu_ps( x + row_iter * k + i * 256 + sb * 32 + 24 );
|
||||
|
||||
// Compute max(abs(e)) for the block
|
||||
__m256 abs0 = _mm256_andnot_ps( signBit, v0 );
|
||||
__m256 abs1 = _mm256_andnot_ps( signBit, v1 );
|
||||
__m256 abs2 = _mm256_andnot_ps( signBit, v2 );
|
||||
__m256 abs3 = _mm256_andnot_ps( signBit, v3 );
|
||||
|
||||
maxAbs = _mm256_max_ps( maxAbs, abs0 );
|
||||
maxAbs = _mm256_max_ps( maxAbs, abs1 );
|
||||
maxAbs = _mm256_max_ps( maxAbs, abs2 );
|
||||
maxAbs = _mm256_max_ps( maxAbs, abs3 );
|
||||
|
||||
__m256 mask_prev = _mm256_cmp_ps( tempAbs, maxAbs, _CMP_EQ_OQ );
|
||||
maskAbs = _mm256_and_ps( maskAbs, mask_prev );
|
||||
|
||||
mask0 = _mm256_cmp_ps( maxAbs, v0, _CMP_EQ_OQ );
|
||||
mask1 = _mm256_cmp_ps( maxAbs, v1, _CMP_EQ_OQ );
|
||||
mask2 = _mm256_cmp_ps( maxAbs, v2, _CMP_EQ_OQ );
|
||||
mask3 = _mm256_cmp_ps( maxAbs, v3, _CMP_EQ_OQ );
|
||||
|
||||
__m256 mask_curr = _mm256_or_ps(_mm256_or_ps(mask0, mask1),_mm256_or_ps(mask2, mask3));
|
||||
maskAbs = _mm256_or_ps(maskAbs, mask_curr);
|
||||
|
||||
srcv[row_iter][sb * 4] = v0;
|
||||
srcv[row_iter][sb * 4 + 1] = v1;
|
||||
srcv[row_iter][sb * 4 + 2] = v2;
|
||||
srcv[row_iter][sb * 4 + 3] = v3;
|
||||
}
|
||||
|
||||
__m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) );
|
||||
max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
|
||||
max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
|
||||
const float maxScalar = _mm_cvtss_f32( max4 );
|
||||
|
||||
__m256 maxScalarVec = _mm256_set1_ps(maxScalar);
|
||||
|
||||
__m256 mask_next = _mm256_cmp_ps( maxScalarVec, maxAbs, _CMP_EQ_OQ );
|
||||
__m256 finalMask = _mm256_and_ps(maskAbs, mask_next);
|
||||
|
||||
const int mask = _mm256_movemask_ps(finalMask);
|
||||
iscale[row_iter] = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f;
|
||||
|
||||
if(mask) {
|
||||
iscale[row_iter] = ( maxScalar != 0.0f ) ? -127.f / maxScalar: 0.0f;
|
||||
}
|
||||
|
||||
y[i].d[row_iter] = maxScalar ? 1/iscale[row_iter] : 0;
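// Same scalar meaning as the NEON path: d = 1 / iscale = +/-maxScalar / 127, with the negative
// sign chosen when a positive element attains the overall max-abs (finalMask non-zero above).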
|
||||
iscale_vec[row_iter] = _mm256_set1_ps(iscale[row_iter]);
|
||||
}
|
||||
|
||||
__m256i quants_interleaved[32];
|
||||
for (int j = 0; j < 32; j++) {
|
||||
// Apply the multiplier
|
||||
__m256 v0 = _mm256_mul_ps(srcv[0][j], iscale_vec[0]);
|
||||
__m256 v1 = _mm256_mul_ps(srcv[1][j], iscale_vec[1]);
|
||||
__m256 v2 = _mm256_mul_ps(srcv[2][j], iscale_vec[2]);
|
||||
__m256 v3 = _mm256_mul_ps(srcv[3][j], iscale_vec[3]);
|
||||
|
||||
// Round to nearest integer
|
||||
v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST );
|
||||
v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST );
|
||||
v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST );
|
||||
v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST );
|
||||
|
||||
// Convert floats to integers
|
||||
__m256i i0 = _mm256_cvtps_epi32( v0 );
|
||||
__m256i i1 = _mm256_cvtps_epi32( v1 );
|
||||
__m256i i2 = _mm256_cvtps_epi32( v2 );
|
||||
__m256i i3 = _mm256_cvtps_epi32( v3 );
|
||||
|
||||
// Convert int32 to int16
|
||||
i0 = _mm256_packs_epi32( i0, i1 );
|
||||
i2 = _mm256_packs_epi32( i2, i3 );
|
||||
// Convert int16 to int8
|
||||
i0 = _mm256_packs_epi16( i0, i2 );
|
||||
|
||||
// Permute and store the quantized weights in the required order after the pack instruction
|
||||
const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 );
|
||||
i0 = _mm256_permutevar8x32_epi32( i0, perm );
|
||||
|
||||
_mm256_storeu_si256((__m256i *)(y[i].qs + 32 * j), i0);
|
||||
quants_interleaved[j] = i0;
|
||||
}
|
||||
|
||||
// Masks to shuffle the quants of the corresponding sub-blocks, rearranging them for vectorized bsums computation
|
||||
__m256i shuffle_mask_sb2 = _mm256_castsi128_si256(_mm_setr_epi8(0, 1, 0, 1, 4, 5, 6, 7, 8, 9, 8, 9, 12, 13, 14, 15));
|
||||
shuffle_mask_sb2 = _mm256_permute2f128_si256(shuffle_mask_sb2, shuffle_mask_sb2, 0);
|
||||
__m256i shuffle_mask_sb3 = _mm256_castsi128_si256(_mm_setr_epi8(0, 1, 2, 3, 0, 1, 6, 7, 8, 9, 10, 11, 8, 9, 14, 15));
|
||||
shuffle_mask_sb3 = _mm256_permute2f128_si256(shuffle_mask_sb3, shuffle_mask_sb3, 0);
|
||||
__m256i shuffle_mask_sb4 = _mm256_castsi128_si256(_mm_setr_epi8(0, 1, 2, 3, 4, 5, 0, 1, 8, 9, 10, 11, 12, 13, 8, 9));
|
||||
shuffle_mask_sb4 = _mm256_permute2f128_si256(shuffle_mask_sb4, shuffle_mask_sb4, 0);
|
||||
|
||||
for (int k = 0; k < 4; k++) {
|
||||
// Quants from four different sub blocks are taken
|
||||
__m256i q0 = quants_interleaved[k * 8 + 0];
|
||||
__m256i q1 = quants_interleaved[k * 8 + 1];
|
||||
__m256i q2 = quants_interleaved[k * 8 + 2];
|
||||
__m256i q3 = quants_interleaved[k * 8 + 3];
|
||||
__m256i q4 = quants_interleaved[k * 8 + 4];
|
||||
__m256i q5 = quants_interleaved[k * 8 + 5];
|
||||
__m256i q6 = quants_interleaved[k * 8 + 6];
|
||||
__m256i q7 = quants_interleaved[k * 8 + 7];
|
||||
|
||||
|
||||
// The below code block has the first half of different sub blocks shuffled and blended so as to process 2 values from each sub block at a time
|
||||
__m256i sb2_h1_shuffled = _mm256_shuffle_epi8(q2, shuffle_mask_sb2);
|
||||
__m256i sb_h1_interleaved = _mm256_blend_epi16(q0, sb2_h1_shuffled, 34);
|
||||
__m256i sb3_h1_shuffled = _mm256_shuffle_epi8(q4, shuffle_mask_sb3);
|
||||
sb_h1_interleaved = _mm256_blend_epi16(sb_h1_interleaved, sb3_h1_shuffled, 68);
|
||||
__m256i sb4_h1_shuffled = _mm256_shuffle_epi8(q6, shuffle_mask_sb4);
|
||||
sb_h1_interleaved = _mm256_blend_epi16(sb_h1_interleaved, sb4_h1_shuffled, 136);
|
||||
|
||||
__m256i one = _mm256_set1_epi8(1);
|
||||
__m256i bsums_r1 = _mm256_maddubs_epi16(one, sb_h1_interleaved);
|
||||
|
||||
for (int l = 0; l < 3; l++) {
|
||||
// Quants value shifted to process next two values from each sub block
|
||||
q0 = _mm256_srli_epi64(q0, 16);
|
||||
q2 = _mm256_srli_epi64(q2, 16);
|
||||
q4 = _mm256_srli_epi64(q4, 16);
|
||||
q6 = _mm256_srli_epi64(q6, 16);
|
||||
|
||||
sb2_h1_shuffled = _mm256_shuffle_epi8(q2, shuffle_mask_sb2);
|
||||
sb_h1_interleaved = _mm256_blend_epi16(q0, sb2_h1_shuffled, 34);
|
||||
sb3_h1_shuffled = _mm256_shuffle_epi8(q4, shuffle_mask_sb3);
|
||||
sb_h1_interleaved = _mm256_blend_epi16(sb_h1_interleaved, sb3_h1_shuffled, 68);
|
||||
sb4_h1_shuffled = _mm256_shuffle_epi8(q6, shuffle_mask_sb4);
|
||||
sb_h1_interleaved = _mm256_blend_epi16(sb_h1_interleaved, sb4_h1_shuffled, 136);
|
||||
|
||||
bsums_r1 = _mm256_add_epi16(bsums_r1, _mm256_maddubs_epi16(one, sb_h1_interleaved));
|
||||
}
|
||||
|
||||
// The below code block has the second half of different sub blocks shuffled and blended so as to process 2 values from each sub block at a time
|
||||
__m256i sb2_h2_shuffled = _mm256_shuffle_epi8(q3, shuffle_mask_sb2);
|
||||
__m256i sb_h2_interleaved = _mm256_blend_epi16(q1, sb2_h2_shuffled, 34);
|
||||
__m256i sb3_h2_shuffled = _mm256_shuffle_epi8(q5, shuffle_mask_sb3);
|
||||
sb_h2_interleaved = _mm256_blend_epi16(sb_h2_interleaved, sb3_h2_shuffled, 68);
|
||||
__m256i sb4_h2_shuffled = _mm256_shuffle_epi8(q7, shuffle_mask_sb4);
|
||||
sb_h2_interleaved = _mm256_blend_epi16(sb_h2_interleaved, sb4_h2_shuffled, 136);
|
||||
|
||||
__m256i bsums_r2 = _mm256_maddubs_epi16(one, sb_h2_interleaved);
|
||||
|
||||
for (int l = 0; l < 3; l++) {
|
||||
// Quants value shifted to process next two values from each sub block
|
||||
q1 = _mm256_srli_epi64(q1, 16);
|
||||
q3 = _mm256_srli_epi64(q3, 16);
|
||||
q5 = _mm256_srli_epi64(q5, 16);
|
||||
q7 = _mm256_srli_epi64(q7, 16);
|
||||
|
||||
sb2_h2_shuffled = _mm256_shuffle_epi8(q3, shuffle_mask_sb2);
|
||||
sb_h2_interleaved = _mm256_blend_epi16(q1, sb2_h2_shuffled, 34);
|
||||
sb3_h2_shuffled = _mm256_shuffle_epi8(q5, shuffle_mask_sb3);
|
||||
sb_h2_interleaved = _mm256_blend_epi16(sb_h2_interleaved, sb3_h2_shuffled, 68);
|
||||
sb4_h2_shuffled = _mm256_shuffle_epi8(q7, shuffle_mask_sb4);
|
||||
sb_h2_interleaved = _mm256_blend_epi16(sb_h2_interleaved, sb4_h2_shuffled, 136);
|
||||
|
||||
bsums_r2 = _mm256_add_epi16(bsums_r2, _mm256_maddubs_epi16(one, sb_h2_interleaved));
|
||||
}
|
||||
|
||||
// Overall bsums, in interleaved order, computed by adding the results of the two halves
|
||||
__m256i bsums_r = _mm256_add_epi16(bsums_r1, bsums_r2);
|
||||
_mm256_storeu_si256((__m256i *)(y[i].bsums + 16 * k), bsums_r);
|
||||
}
|
||||
}
|
||||
|
||||
#else
|
||||
UNUSED(nb);
|
||||
UNUSED(y);
|
||||
ggml_quantize_mat_q8_K_4x8_generic(x, vy, k);
|
||||
#endif
|
||||
}
|
||||
|
||||
//
|
||||
// GEMV/GEMM templates
|
||||
//
|
||||
|
|
|
|||
|
|
@ -227,346 +227,6 @@ void ggml_quantize_mat_q8_K_4x8_generic(const float * GGML_RESTRICT x, void * GG
|
|||
}
|
||||
}
|
||||
|
||||
void ggml_quantize_mat_q8_K_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
|
||||
assert(QK_K == 256);
|
||||
assert(k % QK_K == 0);
|
||||
const int nb = k / QK_K;
|
||||
|
||||
block_q8_Kx4 * GGML_RESTRICT y = (block_q8_Kx4 *) vy;
|
||||
|
||||
#if defined(__ARM_NEON)
|
||||
const int blck_size_interleave = 8;
|
||||
float32x4_t srcv[4][64]; // 64 = QK_K/4
|
||||
float iscale[4];
|
||||
|
||||
for (int i = 0; i < nb; i++) {
|
||||
float32x4_t asrcv[64];
|
||||
float32x4_t amaxv[64];
|
||||
|
||||
// d:
|
||||
for (int row_iter = 0; row_iter < 4; row_iter++) {
|
||||
for (int j = 0; j < 64; j++) srcv[row_iter][j] = vld1q_f32(x + row_iter * k + i * 256 + 4 * j);
|
||||
for (int j = 0; j < 64; j++) asrcv[j] = vabsq_f32(srcv[row_iter][j]);
|
||||
|
||||
for (int j = 0; j < 32; j++) amaxv[2 * j] = vmaxq_f32(asrcv[2 * j], asrcv[2 * j + 1]);
|
||||
for (int j = 0; j < 16; j++) amaxv[4 * j] = vmaxq_f32(amaxv[4 * j], amaxv[4 * j + 2]);
|
||||
for (int j = 0; j < 8; j++) amaxv[8 * j] = vmaxq_f32(amaxv[8 * j], amaxv[8 * j + 4]);
|
||||
for (int j = 0; j < 4; j++) amaxv[16 * j] = vmaxq_f32(amaxv[16 * j], amaxv[16 * j + 8]);
|
||||
for (int j = 0; j < 2; j++) amaxv[32 * j] = vmaxq_f32(amaxv[32 * j], amaxv[32 * j + 16]);
|
||||
for (int j = 0; j < 1; j++) amaxv[64 * j] = vmaxq_f32(amaxv[64 * j], amaxv[64 * j + 32]);
|
||||
|
||||
const float amax = vmaxvq_f32(amaxv[0]);
|
||||
|
||||
// Check if exists: orig == amax
|
||||
float32x4_t amax_vec = vdupq_n_f32(amax);
|
||||
uint32x4_t mask_all = vdupq_n_u32(0);
|
||||
for (int j = 0; j < 64; j++) {
|
||||
uint32x4_t mask_curr = vceqq_f32(amax_vec, srcv[row_iter][j]);
|
||||
mask_all = vorrq_u32(mask_all, mask_curr);
|
||||
}
|
||||
|
||||
// Assume that none == amax, then check mask_all to reverse
|
||||
iscale[row_iter] = ( amax != 0.0f ) ? 127.f / amax : 0.0f;
|
||||
uint32x4_t cmp = vceqq_u32(mask_all, vdupq_n_u32(0xFFFFFFFFu));
|
||||
if (vmaxvq_u32(cmp) != 0) {
|
||||
iscale[row_iter] = ( amax != 0.0f ) ? -127.f / amax : 0.0f;
|
||||
}
|
||||
|
||||
y[i].d[row_iter] = amax ? 1/iscale[row_iter] : 0;
|
||||
}
|
||||
|
||||
// qs: 8 byte interleave over 4 rows, loop = QK_K/8
|
||||
// bsums: simply generated one by one, row_i is calculated before row_i+1
|
||||
// loops = 16
|
||||
for (int j = 0; j < QK_K / blck_size_interleave / 2; j++) {
|
||||
// Process row0 and row1
|
||||
float32x4_t f0_0_3 = vrndnq_f32(vmulq_n_f32(srcv[0][4 * j], iscale[0]));
|
||||
float32x4_t f0_4_7 = vrndnq_f32(vmulq_n_f32(srcv[0][4 * j + 1], iscale[0]));
|
||||
float32x4_t f0_8_11 = vrndnq_f32(vmulq_n_f32(srcv[0][4 * j + 2], iscale[0]));
|
||||
float32x4_t f0_12_15 = vrndnq_f32(vmulq_n_f32(srcv[0][4 * j + 3], iscale[0]));
|
||||
int32x4_t i0_0_3 = vcvtnq_s32_f32(f0_0_3);
|
||||
int32x4_t i0_4_7 = vcvtnq_s32_f32(f0_4_7);
|
||||
int16x8_t i0_0_7 = vcombine_s16(vqmovn_s32(i0_0_3), vqmovn_s32(i0_4_7)); // int32x4 * 2 → int16x4 * 2 → int16x8
|
||||
int32x4_t i0_8_11 = vcvtnq_s32_f32(f0_8_11);
|
||||
int32x4_t i0_12_15 = vcvtnq_s32_f32(f0_12_15);
|
||||
int16x8_t i0_8_15 = vcombine_s16(vqmovn_s32(i0_8_11), vqmovn_s32(i0_12_15)); // int32x4 * 2 → int16x4 * 2 → int16x8
|
||||
|
||||
float32x4_t f1_0_3 = vrndnq_f32(vmulq_n_f32(srcv[1][4 * j], iscale[1]));
|
||||
float32x4_t f1_4_7 = vrndnq_f32(vmulq_n_f32(srcv[1][4 * j + 1], iscale[1]));
|
||||
float32x4_t f1_8_11 = vrndnq_f32(vmulq_n_f32(srcv[1][4 * j + 2], iscale[1]));
|
||||
float32x4_t f1_12_15 = vrndnq_f32(vmulq_n_f32(srcv[1][4 * j + 3], iscale[1]));
|
||||
int32x4_t i1_0_3 = vcvtnq_s32_f32(f1_0_3);
|
||||
int32x4_t i1_4_7 = vcvtnq_s32_f32(f1_4_7);
|
||||
int16x8_t i1_0_7 = vcombine_s16(vqmovn_s32(i1_0_3), vqmovn_s32(i1_4_7)); // int32x4 * 2 → int16x4 * 2 → int16x8
|
||||
int32x4_t i1_8_11 = vcvtnq_s32_f32(f1_8_11);
|
||||
int32x4_t i1_12_15 = vcvtnq_s32_f32(f1_12_15);
|
||||
int16x8_t i1_8_15 = vcombine_s16(vqmovn_s32(i1_8_11), vqmovn_s32(i1_12_15)); // int32x4 * 2 → int16x4 * 2 → int16x8
|
||||
|
||||
// Calculate and store qs
|
||||
int8x16_t i0_i1_0_7 = vcombine_s8(vqmovn_s16(i0_0_7), vqmovn_s16(i1_0_7)); // int16x8 * 2 → int8x8 * 2 → int8x16
|
||||
int8x16_t i0_i1_8_15 = vcombine_s8(vqmovn_s16(i0_8_15), vqmovn_s16(i1_8_15)); // int16x8 * 2 → int8x8 * 2 → int8x16
|
||||
vst1q_s8(y[i].qs + 64 * j, i0_i1_0_7);
|
||||
vst1q_s8(y[i].qs + 64 * j + 32, i0_i1_8_15);
|
||||
// Calculate and store bsum
|
||||
int8x16_t i0_0_15 = vcombine_s8(vqmovn_s16(i0_0_7), vqmovn_s16(i0_8_15)); // int16x8 * 2 → int8x8 * 2 → int8x16
|
||||
int8x16_t i1_0_15 = vcombine_s8(vqmovn_s16(i1_0_7), vqmovn_s16(i1_8_15)); // int16x8 * 2 → int8x8 * 2 → int8x16
|
||||
y[i].bsums[j] = vaddlvq_s8(i0_0_15);
|
||||
y[i].bsums[j + 16] = vaddlvq_s8(i1_0_15);
|
||||
|
||||
// Process row2 and row3
|
||||
f0_0_3 = vrndnq_f32(vmulq_n_f32(srcv[2][4 * j], iscale[2]));
|
||||
f0_4_7 = vrndnq_f32(vmulq_n_f32(srcv[2][4 * j + 1], iscale[2]));
|
||||
f0_8_11 = vrndnq_f32(vmulq_n_f32(srcv[2][4 * j + 2], iscale[2]));
|
||||
f0_12_15 = vrndnq_f32(vmulq_n_f32(srcv[2][4 * j + 3], iscale[2]));
|
||||
i0_0_3 = vcvtnq_s32_f32(f0_0_3);
|
||||
i0_4_7 = vcvtnq_s32_f32(f0_4_7);
|
||||
i0_0_7 = vcombine_s16(vqmovn_s32(i0_0_3), vqmovn_s32(i0_4_7)); // int32x4 * 2 → int16x4 * 2 → int16x8
|
||||
i0_8_11 = vcvtnq_s32_f32(f0_8_11);
|
||||
i0_12_15 = vcvtnq_s32_f32(f0_12_15);
|
||||
i0_8_15 = vcombine_s16(vqmovn_s32(i0_8_11), vqmovn_s32(i0_12_15)); // int32x4 * 2 → int16x4 * 2 → int16x8
|
||||
|
||||
f1_0_3 = vrndnq_f32(vmulq_n_f32(srcv[3][4 * j], iscale[3]));
|
||||
f1_4_7 = vrndnq_f32(vmulq_n_f32(srcv[3][4 * j + 1], iscale[3]));
|
||||
f1_8_11 = vrndnq_f32(vmulq_n_f32(srcv[3][4 * j + 2], iscale[3]));
|
||||
f1_12_15 = vrndnq_f32(vmulq_n_f32(srcv[3][4 * j + 3], iscale[3]));
|
||||
i1_0_3 = vcvtnq_s32_f32(f1_0_3);
|
||||
i1_4_7 = vcvtnq_s32_f32(f1_4_7);
|
||||
i1_0_7 = vcombine_s16(vqmovn_s32(i1_0_3), vqmovn_s32(i1_4_7)); // int32x4 * 2 → int16x4 * 2 → int16x8
|
||||
i1_8_11 = vcvtnq_s32_f32(f1_8_11);
|
||||
i1_12_15 = vcvtnq_s32_f32(f1_12_15);
|
||||
i1_8_15 = vcombine_s16(vqmovn_s32(i1_8_11), vqmovn_s32(i1_12_15)); // int32x4 * 2 → int16x4 * 2 → int16x8
|
||||
|
||||
// Calculate and store qs
|
||||
i0_i1_0_7 = vcombine_s8(vqmovn_s16(i0_0_7), vqmovn_s16(i1_0_7)); // int16x8 * 2 → int8x8 * 2 → int8x16
|
||||
i0_i1_8_15 = vcombine_s8(vqmovn_s16(i0_8_15), vqmovn_s16(i1_8_15)); // int16x8 * 2 → int8x8 * 2 → int8x16
|
||||
vst1q_s8(y[i].qs + 64 * j + 16, i0_i1_0_7);
|
||||
vst1q_s8(y[i].qs + 64 * j + 48, i0_i1_8_15);
|
||||
// Calculate and store bsum
|
||||
i0_0_15 = vcombine_s8(vqmovn_s16(i0_0_7), vqmovn_s16(i0_8_15)); // int16x8 * 2 → int8x8 * 2 → int8x16
|
||||
i1_0_15 = vcombine_s8(vqmovn_s16(i1_0_7), vqmovn_s16(i1_8_15)); // int16x8 * 2 → int8x8 * 2 → int8x16
|
||||
y[i].bsums[j + 32] = vaddlvq_s8(i0_0_15);
|
||||
y[i].bsums[j + 48] = vaddlvq_s8(i1_0_15);
|
||||
}
|
||||
}
|
||||
#elif defined(__AVX2__)
|
||||
float iscale[4];
|
||||
__m256 srcv[4][32];
|
||||
__m256 iscale_vec[4];
|
||||
|
||||
for (int i = 0; i < nb; i++) {
|
||||
for (int row_iter = 0; row_iter < 4; row_iter++) {
|
||||
// Load elements into 4 AVX vectors
|
||||
__m256 v0 = _mm256_loadu_ps( x + row_iter * k + i * 256 );
|
||||
__m256 v1 = _mm256_loadu_ps( x + row_iter * k + i * 256 + 8 );
|
||||
__m256 v2 = _mm256_loadu_ps( x + row_iter * k + i * 256 + 16 );
|
||||
__m256 v3 = _mm256_loadu_ps( x + row_iter * k + i * 256 + 24 );
|
||||
|
||||
// Compute max(abs(e)) for the block
|
||||
const __m256 signBit = _mm256_set1_ps( -0.0f );
|
||||
__m256 abs0 = _mm256_andnot_ps( signBit, v0 );
|
||||
__m256 abs1 = _mm256_andnot_ps( signBit, v1 );
|
||||
__m256 abs2 = _mm256_andnot_ps( signBit, v2 );
|
||||
__m256 abs3 = _mm256_andnot_ps( signBit, v3 );
|
||||
|
||||
__m256 maxAbs = _mm256_max_ps( abs0, abs1 );
|
||||
maxAbs = _mm256_max_ps( maxAbs, abs2 );
|
||||
maxAbs = _mm256_max_ps( maxAbs, abs3 );
|
||||
|
||||
__m256 mask0 = _mm256_cmp_ps( maxAbs, v0, _CMP_EQ_OQ );
|
||||
__m256 mask1 = _mm256_cmp_ps( maxAbs, v1, _CMP_EQ_OQ );
|
||||
__m256 mask2 = _mm256_cmp_ps( maxAbs, v2, _CMP_EQ_OQ );
|
||||
__m256 mask3 = _mm256_cmp_ps( maxAbs, v3, _CMP_EQ_OQ );
|
||||
|
||||
__m256 maskAbs = _mm256_or_ps(_mm256_or_ps(mask0, mask1),_mm256_or_ps(mask2, mask3));
|
||||
|
||||
srcv[row_iter][0] = v0;
|
||||
srcv[row_iter][1] = v1;
|
||||
srcv[row_iter][2] = v2;
|
||||
srcv[row_iter][3] = v3;
|
||||
|
||||
for (int sb = 1; sb < 8; sb++) {
|
||||
// Temporarily stores absolute quant values
|
||||
__m256 tempAbs = maxAbs;
|
||||
|
||||
// Load elements into 4 AVX vectors
|
||||
__m256 v0 = _mm256_loadu_ps( x + row_iter * k + i * 256 + sb * 32);
|
||||
__m256 v1 = _mm256_loadu_ps( x + row_iter * k + i * 256 + sb * 32 + 8 );
|
||||
__m256 v2 = _mm256_loadu_ps( x + row_iter * k + i * 256 + sb * 32 + 16 );
|
||||
__m256 v3 = _mm256_loadu_ps( x + row_iter * k + i * 256 + sb * 32 + 24 );
|
||||
|
||||
// Compute max(abs(e)) for the block
|
||||
__m256 abs0 = _mm256_andnot_ps( signBit, v0 );
|
||||
__m256 abs1 = _mm256_andnot_ps( signBit, v1 );
|
||||
__m256 abs2 = _mm256_andnot_ps( signBit, v2 );
|
||||
__m256 abs3 = _mm256_andnot_ps( signBit, v3 );
|
||||
|
||||
maxAbs = _mm256_max_ps( maxAbs, abs0 );
|
||||
maxAbs = _mm256_max_ps( maxAbs, abs1 );
|
||||
maxAbs = _mm256_max_ps( maxAbs, abs2 );
|
||||
maxAbs = _mm256_max_ps( maxAbs, abs3 );
|
||||
|
||||
__m256 mask_prev = _mm256_cmp_ps( tempAbs, maxAbs, _CMP_EQ_OQ );
|
||||
maskAbs = _mm256_and_ps( maskAbs, mask_prev );
|
||||
|
||||
mask0 = _mm256_cmp_ps( maxAbs, v0, _CMP_EQ_OQ );
|
||||
mask1 = _mm256_cmp_ps( maxAbs, v1, _CMP_EQ_OQ );
|
||||
mask2 = _mm256_cmp_ps( maxAbs, v2, _CMP_EQ_OQ );
|
||||
mask3 = _mm256_cmp_ps( maxAbs, v3, _CMP_EQ_OQ );
|
||||
|
||||
__m256 mask_curr = _mm256_or_ps(_mm256_or_ps(mask0, mask1),_mm256_or_ps(mask2, mask3));
|
||||
maskAbs = _mm256_or_ps(maskAbs, mask_curr);
|
||||
|
||||
srcv[row_iter][sb * 4] = v0;
|
||||
srcv[row_iter][sb * 4 + 1] = v1;
|
||||
srcv[row_iter][sb * 4 + 2] = v2;
|
||||
srcv[row_iter][sb * 4 + 3] = v3;
|
||||
}
|
||||
|
||||
__m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) );
|
||||
max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
|
||||
max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
|
||||
const float maxScalar = _mm_cvtss_f32( max4 );
|
||||
|
||||
__m256 maxScalarVec = _mm256_set1_ps(maxScalar);
|
||||
|
||||
__m256 mask_next = _mm256_cmp_ps( maxScalarVec, maxAbs, _CMP_EQ_OQ );
|
||||
__m256 finalMask = _mm256_and_ps(maskAbs, mask_next);
|
||||
|
||||
const int mask = _mm256_movemask_ps(finalMask);
|
||||
iscale[row_iter] = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f;
|
||||
|
||||
if(mask) {
|
||||
iscale[row_iter] = ( maxScalar != 0.0f ) ? -127.f / maxScalar: 0.0f;
|
||||
}
|
||||
|
||||
y[i].d[row_iter] = maxScalar ? 1/iscale[row_iter] : 0;
|
||||
iscale_vec[row_iter] = _mm256_set1_ps(iscale[row_iter]);
|
||||
}
|
||||
|
||||
__m256i quants_interleaved[32];
|
||||
for (int j = 0; j < 32; j++) {
|
||||
// Apply the multiplier
|
||||
__m256 v0 = _mm256_mul_ps(srcv[0][j], iscale_vec[0]);
|
||||
__m256 v1 = _mm256_mul_ps(srcv[1][j], iscale_vec[1]);
|
||||
__m256 v2 = _mm256_mul_ps(srcv[2][j], iscale_vec[2]);
|
||||
__m256 v3 = _mm256_mul_ps(srcv[3][j], iscale_vec[3]);
|
||||
|
||||
// Round to nearest integer
|
||||
v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST );
|
||||
v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST );
|
||||
v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST );
|
||||
v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST );
|
||||
|
||||
// Convert floats to integers
|
||||
__m256i i0 = _mm256_cvtps_epi32( v0 );
|
||||
__m256i i1 = _mm256_cvtps_epi32( v1 );
|
||||
__m256i i2 = _mm256_cvtps_epi32( v2 );
|
||||
__m256i i3 = _mm256_cvtps_epi32( v3 );
|
||||
|
||||
// Convert int32 to int16
|
||||
i0 = _mm256_packs_epi32( i0, i1 );
|
||||
i2 = _mm256_packs_epi32( i2, i3 );
|
||||
// Convert int16 to int8
|
||||
i0 = _mm256_packs_epi16( i0, i2 );
|
||||
|
||||
// Permute and store the quantized weights in the required order after the pack instruction
|
||||
const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 );
|
||||
i0 = _mm256_permutevar8x32_epi32( i0, perm );
|
||||
|
||||
_mm256_storeu_si256((__m256i *)(y[i].qs + 32 * j), i0);
|
||||
quants_interleaved[j] = i0;
|
||||
}
|
||||
|
||||
// Masks to shuffle the quants of corresonding sub blocks for rearraning quants for vectorized bsums computation
|
||||
__m256i shuffle_mask_sb2 = _mm256_castsi128_si256(_mm_setr_epi8(0, 1, 0, 1, 4, 5, 6, 7, 8, 9, 8, 9, 12, 13, 14, 15));
|
||||
shuffle_mask_sb2 = _mm256_permute2f128_si256(shuffle_mask_sb2, shuffle_mask_sb2, 0);
|
||||
__m256i shuffle_mask_sb3 = _mm256_castsi128_si256(_mm_setr_epi8(0, 1, 2, 3, 0, 1, 6, 7, 8, 9, 10, 11, 8, 9, 14, 15));
|
||||
shuffle_mask_sb3 = _mm256_permute2f128_si256(shuffle_mask_sb3, shuffle_mask_sb3, 0);
|
||||
__m256i shuffle_mask_sb4 = _mm256_castsi128_si256(_mm_setr_epi8(0, 1, 2, 3, 4, 5, 0, 1, 8, 9, 10, 11, 12, 13, 8, 9));
|
||||
shuffle_mask_sb4 = _mm256_permute2f128_si256(shuffle_mask_sb4, shuffle_mask_sb4, 0);
|
||||
|
||||
for (int k = 0; k < 4; k++) {
|
||||
// Quants from four different sub blocks are taken
|
||||
__m256i q0 = quants_interleaved[k * 8 + 0];
|
||||
__m256i q1 = quants_interleaved[k * 8 + 1];
|
||||
__m256i q2 = quants_interleaved[k * 8 + 2];
|
||||
__m256i q3 = quants_interleaved[k * 8 + 3];
|
||||
__m256i q4 = quants_interleaved[k * 8 + 4];
|
||||
__m256i q5 = quants_interleaved[k * 8 + 5];
|
||||
__m256i q6 = quants_interleaved[k * 8 + 6];
|
||||
__m256i q7 = quants_interleaved[k * 8 + 7];
|
||||
|
||||
|
||||
// The below code block has the first half of different sub blocks shuffled and blended so as to process 2 values from each sub block at a time
|
||||
__m256i sb2_h1_shuffled = _mm256_shuffle_epi8(q2, shuffle_mask_sb2);
|
||||
__m256i sb_h1_interleaved = _mm256_blend_epi16(q0, sb2_h1_shuffled, 34);
|
||||
__m256i sb3_h1_shuffled = _mm256_shuffle_epi8(q4, shuffle_mask_sb3);
|
||||
sb_h1_interleaved = _mm256_blend_epi16(sb_h1_interleaved, sb3_h1_shuffled, 68);
|
||||
__m256i sb4_h1_shuffled = _mm256_shuffle_epi8(q6, shuffle_mask_sb4);
|
||||
sb_h1_interleaved = _mm256_blend_epi16(sb_h1_interleaved, sb4_h1_shuffled, 136);
|
||||
|
||||
__m256i one = _mm256_set1_epi8(1);
|
||||
__m256i bsums_r1 = _mm256_maddubs_epi16(one, sb_h1_interleaved);
|
||||
|
||||
for (int l = 0; l < 3; l++) {
|
||||
// Quants value shifted to process next two values from each sub block
|
||||
q0 = _mm256_srli_epi64(q0, 16);
|
||||
q2 = _mm256_srli_epi64(q2, 16);
|
||||
q4 = _mm256_srli_epi64(q4, 16);
|
||||
q6 = _mm256_srli_epi64(q6, 16);
|
||||
|
||||
sb2_h1_shuffled = _mm256_shuffle_epi8(q2, shuffle_mask_sb2);
|
||||
sb_h1_interleaved = _mm256_blend_epi16(q0, sb2_h1_shuffled, 34);
|
||||
sb3_h1_shuffled = _mm256_shuffle_epi8(q4, shuffle_mask_sb3);
|
||||
sb_h1_interleaved = _mm256_blend_epi16(sb_h1_interleaved, sb3_h1_shuffled, 68);
|
||||
sb4_h1_shuffled = _mm256_shuffle_epi8(q6, shuffle_mask_sb4);
|
||||
sb_h1_interleaved = _mm256_blend_epi16(sb_h1_interleaved, sb4_h1_shuffled, 136);
|
||||
|
||||
bsums_r1 = _mm256_add_epi16(bsums_r1, _mm256_maddubs_epi16(one, sb_h1_interleaved));
|
||||
}
|
||||
|
||||
// The below code block has the second half of different sub blocks shuffled and blended so as to process 2 values from each sub block at a time
|
||||
__m256i sb2_h2_shuffled = _mm256_shuffle_epi8(q3, shuffle_mask_sb2);
|
||||
__m256i sb_h2_interleaved = _mm256_blend_epi16(q1, sb2_h2_shuffled, 34);
|
||||
__m256i sb3_h2_shuffled = _mm256_shuffle_epi8(q5, shuffle_mask_sb3);
|
||||
sb_h2_interleaved = _mm256_blend_epi16(sb_h2_interleaved, sb3_h2_shuffled, 68);
|
||||
__m256i sb4_h2_shuffled = _mm256_shuffle_epi8(q7, shuffle_mask_sb4);
|
||||
sb_h2_interleaved = _mm256_blend_epi16(sb_h2_interleaved, sb4_h2_shuffled, 136);
|
||||
|
||||
__m256i bsums_r2 = _mm256_maddubs_epi16(one, sb_h2_interleaved);
|
||||
|
||||
for (int l = 0; l < 3; l++) {
|
||||
// Quants value shifted to process next two values from each sub block
|
||||
q1 = _mm256_srli_epi64(q1, 16);
|
||||
q3 = _mm256_srli_epi64(q3, 16);
|
||||
q5 = _mm256_srli_epi64(q5, 16);
|
||||
q7 = _mm256_srli_epi64(q7, 16);
|
||||
|
||||
sb2_h2_shuffled = _mm256_shuffle_epi8(q3, shuffle_mask_sb2);
|
||||
sb_h2_interleaved = _mm256_blend_epi16(q1, sb2_h2_shuffled, 34);
|
||||
sb3_h2_shuffled = _mm256_shuffle_epi8(q5, shuffle_mask_sb3);
|
||||
sb_h2_interleaved = _mm256_blend_epi16(sb_h2_interleaved, sb3_h2_shuffled, 68);
|
||||
sb4_h2_shuffled = _mm256_shuffle_epi8(q7, shuffle_mask_sb4);
|
||||
sb_h2_interleaved = _mm256_blend_epi16(sb_h2_interleaved, sb4_h2_shuffled, 136);
|
||||
|
||||
bsums_r2 = _mm256_add_epi16(bsums_r2, _mm256_maddubs_epi16(one, sb_h2_interleaved));
|
||||
}
|
||||
|
||||
// Overall bsums in interleaved fashion computed by adding results of both halves
|
||||
__m256i bsums_r = _mm256_add_epi16(bsums_r1, bsums_r2);
|
||||
_mm256_storeu_si256((__m256i *)(y[i].bsums + 16 * k), bsums_r);
|
||||
}
|
||||
}
|
||||
|
||||
#else
|
||||
// NOTE: This default c implementation is aligned with AVX2 implemantation, but differs from arm implementation.
|
||||
// especially in bsums arrangement.
|
||||
UNUSED(nb);
|
||||
UNUSED(y);
|
||||
ggml_quantize_mat_q8_K_4x8_generic(x, vy, k);
|
||||
#endif
|
||||
}
|
||||
|
||||
} // extern "C"
|
||||
|
||||
template <int64_t INTER_SIZE, ggml_type PARAM_TYPE>
|
||||
|
|
@ -731,6 +391,87 @@ void ggml_gemv_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs,
|
|||
}
|
||||
}
|
||||
|
||||
void ggml_gemv_q4_K_4x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
||||
const int qk = QK_K;
|
||||
const int nb = n / qk;
|
||||
const int ncols_interleaved = 4;
|
||||
const int blocklen = 8;
|
||||
assert (n % qk == 0);
|
||||
assert (nc % ncols_interleaved == 0);
|
||||
UNUSED(s);
|
||||
UNUSED(bs);
|
||||
UNUSED(vx);
|
||||
UNUSED(vy);
|
||||
UNUSED(nr);
|
||||
UNUSED(nc);
|
||||
UNUSED(nb);
|
||||
UNUSED(ncols_interleaved);
|
||||
UNUSED(blocklen);
|
||||
|
||||
uint8_t scales[4][8]; // scales for the 8 sub-blocks of each of the 4 interleaved q4_K columns
uint8_t mins[4][8]; // mins for the 8 sub-blocks of each of the 4 interleaved q4_K columns
float sumf[4]; // 1x4 unit: final results
float sum_minf[4]; // 1x4 unit: accumulated min (dmin) corrections, subtracted at the end
|
||||
int sumi1;
|
||||
int sumi2;
|
||||
int sumi;
|
||||
|
||||
const block_q8_K * a_ptr = (const block_q8_K *) vy;
|
||||
|
||||
// loop on n dimension, each iteration works on 4 columns
|
||||
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
||||
const block_q4_Kx4 * b_ptr = (const block_q4_Kx4 *) vx + (x * nb);
|
||||
|
||||
// initialize results for 4 cols
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
sumf[j] = 0.0;
|
||||
sum_minf[j] = 0.0;
|
||||
}
|
||||
|
||||
// loop on k dimension, each iteration works on 1 q4_kx4 block
|
||||
for (int n = 0; n < nb; n++) {
|
||||
// prepare scales and mins for 4 cols
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
for (int i = 0; i < 8; i++) {
|
||||
scales[j][i] = b_ptr[n].scales[i * 8 + j];
|
||||
mins[j][i] = b_ptr[n].scales[i * 8 + j + ncols_interleaved];
|
||||
}
|
||||
}
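// (Layout implied by the indexing above: b_ptr[n].scales holds 8 groups of 8 bytes, one group per
//  sub-block, each group being the 4 column scales followed by the 4 corresponding column mins.)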
|
||||
|
||||
// core loop: each iteration works on an interleaved unit (four 8-byte segments from 4 cols)
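// Each unit covers 4 columns x 8 bytes of packed nibbles; a byte packs two quants of the same
// 32-value sub-block, 16 positions apart (low nibble = value t, high nibble = value t + 16),
// so both nibbles use the same sub-block scale scales[j][k / 2] (as implied by the indexing below).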
|
||||
for (int k = 0; k < (qk / (2 * blocklen)); k++) {
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
sumi1 = 0;
|
||||
sumi2 = 0;
|
||||
sumi = 0;
|
||||
for (int i = 0; i < blocklen; i++) {
|
||||
const int v0 = (int8_t)(b_ptr[n].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xf);
|
||||
const int v1 = (int8_t)(b_ptr[n].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4);
sumi1 = (v0 * a_ptr[n].qs[(k / 2) * 32 + (k % 2) * blocklen + i]);
sumi2 = (v1 * a_ptr[n].qs[(k / 2) * 32 + (k % 2) * blocklen + i + 16]);
|
||||
uint8_t scale = scales[j][k / 2];
|
||||
sumi += sumi1 * scale + sumi2 * scale;
|
||||
}
|
||||
sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[n].d[j]) * a_ptr[n].d;
|
||||
}
|
||||
}
|
||||
|
||||
// accumulate the per-sub-block min (dmin) corrections, applied in the final step
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
for (int i = 0; i < QK_K / 32; i++) {
|
||||
sum_minf[j] += mins[j][i] * (a_ptr[n].bsums[i * 2] + a_ptr[n].bsums[i * 2 + 1]) * GGML_CPU_FP16_TO_FP32(b_ptr[n].dmin[j]) * a_ptr[n].d;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// save results
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
s[x * ncols_interleaved + j] = sumf[j] - sum_minf[j];
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
void ggml_gemv_q4_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
||||
const int qk = QK_K;
|
||||
const int nb = n / qk;
|
||||
|
|
@ -1290,6 +1031,95 @@ void ggml_gemm_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs,
|
|||
}
|
||||
}
|
||||
|
||||
void ggml_gemm_q4_K_4x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
||||
const int qk = QK_K;
|
||||
const int nb = n / qk;
|
||||
const int ncols_interleaved = 4;
|
||||
const int blocklen = 8;
|
||||
|
||||
assert (n % qk == 0);
|
||||
assert (nr % 4 == 0);
|
||||
assert (nc % ncols_interleaved == 0);
|
||||
|
||||
UNUSED(nb);
|
||||
UNUSED(ncols_interleaved);
|
||||
UNUSED(blocklen);
|
||||
|
||||
uint8_t scales[4][8]; // scales for the 8 sub-blocks of each of the 4 interleaved q4_K columns
uint8_t mins[4][8]; // mins for the 8 sub-blocks of each of the 4 interleaved q4_K columns
float sumf[4][4]; // 4x4 tile: final results
float sum_minf[4][4]; // 4x4 tile: accumulated min (dmin) corrections, subtracted at the end
|
||||
int sumi1;
|
||||
int sumi2;
|
||||
int sumi;
|
||||
|
||||
// loop on m dimension, each iteration works on 4 rows
|
||||
for (int y = 0; y < nr / 4; y++) {
|
||||
const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb);
|
||||
|
||||
// loop on n dimension, each iteration works on 4 columns
|
||||
for (int x = 0; x < nc / ncols_interleaved; x++) {
|
||||
const block_q4_Kx4 * b_ptr = (const block_q4_Kx4 *) vx + (x * nb);
|
||||
|
||||
// initialize results for 4 cols
|
||||
for (int m = 0; m < 4; m++) {
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
sumf[m][j] = 0.0;
|
||||
sum_minf[m][j] = 0.0;
|
||||
}
|
||||
}
|
||||
|
||||
// loop on k dimension, each iteration works on 1 q4_kx4 block
|
||||
for (int n = 0; n < nb; n++) {
|
||||
// prepare scales and mins for 4 cols
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
for (int i = 0; i < 8; i++) {
|
||||
scales[j][i] = b_ptr[n].scales[i * 8 + j];
|
||||
mins[j][i] = b_ptr[n].scales[i * 8 + j + ncols_interleaved];
|
||||
}
|
||||
}
|
||||
|
||||
// core loop: each iteration works on an interleaved unit (four 8-byte segments from 4 cols)
|
||||
for (int k = 0; k < (qk / (2 * blocklen)); k++) {
|
||||
for (int m = 0; m < 4; m++) {
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
sumi1 = 0;
|
||||
sumi2 = 0;
|
||||
sumi = 0;
|
||||
for (int i = 0; i < blocklen; i++) {
|
||||
const int v0 = (int8_t)(b_ptr[n].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xf);
|
||||
const int v1 = (int8_t)(b_ptr[n].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4);
sumi1 = (v0 * a_ptr[n].qs[(k / 2) * 128 + (k % 2) * 4 * blocklen + m * blocklen + i]);
sumi2 = (v1 * a_ptr[n].qs[(k / 2) * 128 + (k % 2) * 4 * blocklen + m * blocklen + i + 64]);
|
||||
uint8_t scale = scales[j][k / 2];
|
||||
sumi += sumi1 * scale + sumi2 * scale;
|
||||
}
|
||||
sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[n].d[j]) * a_ptr[n].d[m];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// accumulate the per-sub-block min (dmin) corrections, applied in the final step
|
||||
for (int m = 0; m < 4; m++) {
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
for (int i = 0; i < QK_K / 32; i++) {
|
||||
const int16_t *bsums = a_ptr[n].bsums + (i * 8) - ((i % 2) * 6) + (m * 4);
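// The offset (i * 8) - ((i % 2) * 6) + (m * 4) follows the interleaved bsums layout written by the
// AVX2/generic quantizer: each 16-entry group covers a sub-block pair across all 4 rows, with two
// partial sums per (row, sub-block). (Inferred from the quantizer's store pattern.)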
|
||||
sum_minf[m][j] += mins[j][i] * (bsums[0] + bsums[1]) * GGML_CPU_FP16_TO_FP32(b_ptr[n].dmin[j]) * a_ptr[n].d[m];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// save results
|
||||
for (int m = 0; m < 4; m++) {
|
||||
for (int j = 0; j < ncols_interleaved; j++) {
|
||||
s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j] - sum_minf[m][j];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_gemm_q4_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
||||
const int qk = QK_K;
|
||||
const int nb = n / qk;
|
||||
|
|
@ -3037,8 +2867,10 @@ static size_t ggml_backend_cpu_repack_buffer_type_get_alignment(ggml_backend_buf
|
|||
GGML_UNUSED(buft);
|
||||
}
|
||||
|
||||
// size calculation after q4_kx4 repacking, it's different from traditional type
|
||||
size_t ggml_nbytes_q4_kx4(const struct ggml_tensor * tensor) {
|
||||
// For the aarch64 q4_k repack implementation only:
// new tensor storage size calculation after q4_kx4 repacking,
// needed because sizeof(block_q4_Kx4) is slightly larger than 4 * sizeof(block_q4_K).
|
||||
static inline size_t ggml_nbytes_q4_kx4(const struct ggml_tensor * tensor) {
|
||||
size_t nbytes;
|
||||
const size_t blck_size = 256;
|
||||
const size_t type_size = sizeof(block_q4_Kx4) / 4;
|
||||