fix compatibility with other q4_k repacking models
commit 126ce2c9e4 (parent da606bd736)
@@ -210,129 +210,137 @@ void ggml_quantize_mat_q8_0_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k)
 #endif
 }
 
-void ggml_quantize_mat_q8_K_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
+void ggml_quantize_mat_q8_K_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k, int64_t nc) {
     assert(QK_K == 256);
     assert(k % QK_K == 0);
+    UNUSED(nc);
     const int nb = k / QK_K;
 
     block_q8_Kx4 * GGML_RESTRICT y = (block_q8_Kx4 *) vy;
 
 #if defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
-    const int blck_size_interleave = 8;
-    float32x4_t srcv[4][64]; // 64 = QK_K/4
-    float iscale[4];
+    if (nc % 8 == 0) {
+        UNUSED(nb);
+        UNUSED(y);
+        ggml_quantize_mat_q8_K_4x8_generic(x, vy, k, nc);
+    } else if (nc % 4 == 0) {
+        const int blck_size_interleave = 8;
+        float32x4_t srcv[4][64]; // 64 = QK_K/4
+        float iscale[4];
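This dispatch is the heart of the fix: when nc (the weight tensor's ne[1], the dimension the repacking interleaves over) is a multiple of 8, the weights were repacked as block_q4_Kx8 and the matching GEMM kernels expect the generic quantization layout, so the NEON fast path must step aside; only the block_q4_Kx4 case keeps the hand-written path. A minimal sketch of that contract (hypothetical standalone wrapper, not part of the commit):

    #include <cstdint>

    // Declaration from the public header (gains nc in this commit).
    extern "C" void ggml_quantize_mat_q8_K_4x8_generic(const float * x, void * vy, int64_t k, int64_t nc);

    // Restates the routing above; the real code inlines it.
    static void quantize_q8_K_4x8_routed(const float * x, void * vy, int64_t k, int64_t nc) {
        if (nc % 8 == 0) {
            // block_q4_Kx8 repack -> generic layout required
            ggml_quantize_mat_q8_K_4x8_generic(x, vy, k, nc);
        } else if (nc % 4 == 0) {
            // block_q4_Kx4 repack -> the NEON fast path below applies
        }
    }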
 
     for (int i = 0; i < nb; i++) {
         float32x4_t asrcv[64];
         float32x4_t amaxv[64];
 
         // d:
         for (int row_iter = 0; row_iter < 4; row_iter++) {
             for (int j = 0; j < 64; j++) srcv[row_iter][j] = vld1q_f32(x + row_iter * k + i * 256 + 4 * j);
             for (int j = 0; j < 64; j++) asrcv[j] = vabsq_f32(srcv[row_iter][j]);
 
             for (int j = 0; j < 32; j++) amaxv[2 * j] = vmaxq_f32(asrcv[2 * j], asrcv[2 * j + 1]);
             for (int j = 0; j < 16; j++) amaxv[4 * j] = vmaxq_f32(amaxv[4 * j], amaxv[4 * j + 2]);
             for (int j = 0; j < 8; j++) amaxv[8 * j] = vmaxq_f32(amaxv[8 * j], amaxv[8 * j + 4]);
             for (int j = 0; j < 4; j++) amaxv[16 * j] = vmaxq_f32(amaxv[16 * j], amaxv[16 * j + 8]);
             for (int j = 0; j < 2; j++) amaxv[32 * j] = vmaxq_f32(amaxv[32 * j], amaxv[32 * j + 16]);
             for (int j = 0; j < 1; j++) amaxv[64 * j] = vmaxq_f32(amaxv[64 * j], amaxv[64 * j + 32]);
 
             const float amax = vmaxvq_f32(amaxv[0]);
 
             // Check if exists: orig == amax
             float32x4_t amax_vec = vdupq_n_f32(amax);
             uint32x4_t mask_all = vdupq_n_u32(0);
             for (int j = 0; j < 64; j++) {
                 uint32x4_t mask_curr = vceqq_f32(amax_vec, srcv[row_iter][j]);
                 mask_all = vorrq_u32(mask_all, mask_curr);
             }
 
             // Assume that none == amax, then check mask_all to reverse
             iscale[row_iter] = ( amax != 0.0f ) ? 127.f / amax : 0.0f;
             uint32x4_t cmp = vceqq_u32(mask_all, vdupq_n_u32(0xFFFFFFFFu));
             if (vmaxvq_u32(cmp) != 0) {
                 iscale[row_iter] = ( amax != 0.0f ) ? -127.f / amax : 0.0f;
             }
 
             y[i].d[row_iter] = amax ? 1/iscale[row_iter] : 0;
         }
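The sign handling above appears to mirror the reference quantize_row_q8_K, which chooses the scale's sign so the element of largest magnitude maps to a negative quant: the default assumes the extreme is -amax (iscale = 127/amax) and flips to -127/amax when some value equals +amax. A scalar sketch of the same rule (row_iscale is a hypothetical helper, illustrative only):

    #include <cmath>

    // Returns the multiplier applied before rounding to int8; the stored
    // block scale d is its reciprocal, exactly as y[i].d above.
    static float row_iscale(const float * x, int n) {
        float amax = 0.0f;                                    // largest |x| in the row
        for (int j = 0; j < n; j++) amax = fmaxf(amax, fabsf(x[j]));
        if (amax == 0.0f) return 0.0f;
        bool has_pos_extreme = false;                         // does any value equal +amax?
        for (int j = 0; j < n; j++) has_pos_extreme |= (x[j] == amax);
        return has_pos_extreme ? -127.f / amax : 127.f / amax;
    }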
 
         // qs: 8 byte interleave over 4 rows, loop = QK_K/8
         // bsums: simply generated one by one, row_i is calculated before row_i+1
         // loops = 16
         for (int j = 0; j < QK_K / blck_size_interleave / 2; j++) {
             // Process row0 and row1
             float32x4_t f0_0_3   = vrndnq_f32(vmulq_n_f32(srcv[0][4 * j],     iscale[0]));
             float32x4_t f0_4_7   = vrndnq_f32(vmulq_n_f32(srcv[0][4 * j + 1], iscale[0]));
             float32x4_t f0_8_11  = vrndnq_f32(vmulq_n_f32(srcv[0][4 * j + 2], iscale[0]));
             float32x4_t f0_12_15 = vrndnq_f32(vmulq_n_f32(srcv[0][4 * j + 3], iscale[0]));
             int32x4_t i0_0_3   = vcvtnq_s32_f32(f0_0_3);
             int32x4_t i0_4_7   = vcvtnq_s32_f32(f0_4_7);
             int16x8_t i0_0_7   = vcombine_s16(vqmovn_s32(i0_0_3), vqmovn_s32(i0_4_7));    // int32x4 * 2 → int16x4 * 2 → int16x8
             int32x4_t i0_8_11  = vcvtnq_s32_f32(f0_8_11);
             int32x4_t i0_12_15 = vcvtnq_s32_f32(f0_12_15);
             int16x8_t i0_8_15  = vcombine_s16(vqmovn_s32(i0_8_11), vqmovn_s32(i0_12_15)); // int32x4 * 2 → int16x4 * 2 → int16x8
 
             float32x4_t f1_0_3   = vrndnq_f32(vmulq_n_f32(srcv[1][4 * j],     iscale[1]));
             float32x4_t f1_4_7   = vrndnq_f32(vmulq_n_f32(srcv[1][4 * j + 1], iscale[1]));
             float32x4_t f1_8_11  = vrndnq_f32(vmulq_n_f32(srcv[1][4 * j + 2], iscale[1]));
             float32x4_t f1_12_15 = vrndnq_f32(vmulq_n_f32(srcv[1][4 * j + 3], iscale[1]));
             int32x4_t i1_0_3   = vcvtnq_s32_f32(f1_0_3);
             int32x4_t i1_4_7   = vcvtnq_s32_f32(f1_4_7);
             int16x8_t i1_0_7   = vcombine_s16(vqmovn_s32(i1_0_3), vqmovn_s32(i1_4_7));    // int32x4 * 2 → int16x4 * 2 → int16x8
             int32x4_t i1_8_11  = vcvtnq_s32_f32(f1_8_11);
             int32x4_t i1_12_15 = vcvtnq_s32_f32(f1_12_15);
             int16x8_t i1_8_15  = vcombine_s16(vqmovn_s32(i1_8_11), vqmovn_s32(i1_12_15)); // int32x4 * 2 → int16x4 * 2 → int16x8
 
             // Calculate and store qs
             int8x16_t i0_i1_0_7  = vcombine_s8(vqmovn_s16(i0_0_7), vqmovn_s16(i1_0_7));   // int16x8 * 2 → int8x8 * 2 → int8x16
             int8x16_t i0_i1_8_15 = vcombine_s8(vqmovn_s16(i0_8_15), vqmovn_s16(i1_8_15)); // int16x8 * 2 → int8x8 * 2 → int8x16
             vst1q_s8(y[i].qs + 64 * j, i0_i1_0_7);
             vst1q_s8(y[i].qs + 64 * j + 32, i0_i1_8_15);
             // Calculate and store bsum
             int8x16_t i0_0_15 = vcombine_s8(vqmovn_s16(i0_0_7), vqmovn_s16(i0_8_15));     // int16x8 * 2 → int8x8 * 2 → int8x16
             int8x16_t i1_0_15 = vcombine_s8(vqmovn_s16(i1_0_7), vqmovn_s16(i1_8_15));     // int16x8 * 2 → int8x8 * 2 → int8x16
             y[i].bsums[j] = vaddlvq_s8(i0_0_15);
             y[i].bsums[j + 16] = vaddlvq_s8(i1_0_15);
 
             // Process row2 and row3
             f0_0_3   = vrndnq_f32(vmulq_n_f32(srcv[2][4 * j],     iscale[2]));
             f0_4_7   = vrndnq_f32(vmulq_n_f32(srcv[2][4 * j + 1], iscale[2]));
             f0_8_11  = vrndnq_f32(vmulq_n_f32(srcv[2][4 * j + 2], iscale[2]));
             f0_12_15 = vrndnq_f32(vmulq_n_f32(srcv[2][4 * j + 3], iscale[2]));
             i0_0_3   = vcvtnq_s32_f32(f0_0_3);
             i0_4_7   = vcvtnq_s32_f32(f0_4_7);
             i0_0_7   = vcombine_s16(vqmovn_s32(i0_0_3), vqmovn_s32(i0_4_7));    // int32x4 * 2 → int16x4 * 2 → int16x8
             i0_8_11  = vcvtnq_s32_f32(f0_8_11);
             i0_12_15 = vcvtnq_s32_f32(f0_12_15);
             i0_8_15  = vcombine_s16(vqmovn_s32(i0_8_11), vqmovn_s32(i0_12_15)); // int32x4 * 2 → int16x4 * 2 → int16x8
 
             f1_0_3   = vrndnq_f32(vmulq_n_f32(srcv[3][4 * j],     iscale[3]));
             f1_4_7   = vrndnq_f32(vmulq_n_f32(srcv[3][4 * j + 1], iscale[3]));
             f1_8_11  = vrndnq_f32(vmulq_n_f32(srcv[3][4 * j + 2], iscale[3]));
             f1_12_15 = vrndnq_f32(vmulq_n_f32(srcv[3][4 * j + 3], iscale[3]));
             i1_0_3   = vcvtnq_s32_f32(f1_0_3);
             i1_4_7   = vcvtnq_s32_f32(f1_4_7);
             i1_0_7   = vcombine_s16(vqmovn_s32(i1_0_3), vqmovn_s32(i1_4_7));    // int32x4 * 2 → int16x4 * 2 → int16x8
             i1_8_11  = vcvtnq_s32_f32(f1_8_11);
             i1_12_15 = vcvtnq_s32_f32(f1_12_15);
             i1_8_15  = vcombine_s16(vqmovn_s32(i1_8_11), vqmovn_s32(i1_12_15)); // int32x4 * 2 → int16x4 * 2 → int16x8
 
             // Calculate and store qs
             i0_i1_0_7  = vcombine_s8(vqmovn_s16(i0_0_7), vqmovn_s16(i1_0_7));   // int16x8 * 2 → int8x8 * 2 → int8x16
             i0_i1_8_15 = vcombine_s8(vqmovn_s16(i0_8_15), vqmovn_s16(i1_8_15)); // int16x8 * 2 → int8x8 * 2 → int8x16
             vst1q_s8(y[i].qs + 64 * j + 16, i0_i1_0_7);
             vst1q_s8(y[i].qs + 64 * j + 48, i0_i1_8_15);
             // Calculate and store bsum
             i0_0_15 = vcombine_s8(vqmovn_s16(i0_0_7), vqmovn_s16(i0_8_15));     // int16x8 * 2 → int8x8 * 2 → int8x16
             i1_0_15 = vcombine_s8(vqmovn_s16(i1_0_7), vqmovn_s16(i1_8_15));     // int16x8 * 2 → int8x8 * 2 → int8x16
             y[i].bsums[j + 32] = vaddlvq_s8(i0_0_15);
             y[i].bsums[j + 48] = vaddlvq_s8(i1_0_15);
         }
     }
-#else
+    }
+    return;
+#endif
 
     // NOTE:
     // Current C impl of Q8_K quantization was originally designed to work with block_q4_Kx8 in the x86 AVX design, and differs from
     // the above Q8_K quantization logic in the AArch64 NEON design, which is designed to work with block_q4_Kx4. The main difference is in
@@ -340,8 +348,7 @@ void ggml_quantize_mat_q8_K_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k)
     // "[bsums] layout" correctly in ggml_gemm_q4_K_4x8_q8_K_generic().
     UNUSED(nb);
     UNUSED(y);
-    ggml_quantize_mat_q8_K_4x8_generic(x, vy, k);
+    ggml_quantize_mat_q8_K_4x8_generic(x, vy, k, nc);
-#endif
 }
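For reference, the stores in the NEON branch produce this block_q8_Kx4 layout: for each 16-element chunk j of every row, 8-byte groups of the four rows are interleaved (rows 0..3 low halves, then rows 0..3 high halves, 64 bytes per chunk), and bsums holds one per-chunk sum per row, rows laid out back to back. A scalar sketch of the same mapping (illustrative only; q stands in for the already-quantized rows):

    #include <cstdint>

    static void layout_q8_Kx4(const int8_t q[4][256], int8_t qs[1024], int16_t bsums[64]) {
        for (int j = 0; j < 16; j++) {                 // 16 chunks of 16 elements
            for (int half = 0; half < 2; half++) {     // low / high 8 bytes of the chunk
                for (int row = 0; row < 4; row++) {
                    for (int b = 0; b < 8; b++) {
                        qs[64 * j + 32 * half + 8 * row + b] = q[row][16 * j + 8 * half + b];
                    }
                }
            }
            for (int row = 0; row < 4; row++) {        // bsums[j] row0, [j+16] row1, ...
                int16_t sum = 0;
                for (int e = 0; e < 16; e++) sum += q[row][16 * j + e];
                bsums[16 * row + j] = sum;
            }
        }
    }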
 
 void ggml_gemv_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {

@@ -176,9 +176,10 @@ void ggml_quantize_mat_q8_K_4x4_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k)
 }
 }
 
-void ggml_quantize_mat_q8_K_4x8_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
+void ggml_quantize_mat_q8_K_4x8_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k, int64_t nc) {
     assert(QK_K == 256);
     assert(k % QK_K == 0);
+    UNUSED(nc);
     const int nb = k / QK_K;
 
     block_q8_Kx4 * GGML_RESTRICT y = (block_q8_Kx4 *) vy;
 

@@ -230,30 +231,33 @@ void ggml_quantize_mat_q8_K_4x8_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k)
} // extern "C"
|
} // extern "C"
|
||||||
|
|
||||||
template <int64_t INTER_SIZE, ggml_type PARAM_TYPE>
|
template <int64_t INTER_SIZE, ggml_type PARAM_TYPE>
|
||||||
void ggml_quantize_mat_t(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row);
|
void ggml_quantize_mat_t(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row, int64_t ncols);
|
||||||
|
|
||||||
template <> void ggml_quantize_mat_t<4, GGML_TYPE_Q8_0>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) {
|
template <> void ggml_quantize_mat_t<4, GGML_TYPE_Q8_0>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row, int64_t ncols) {
|
||||||
assert(nrow == 4);
|
assert(nrow == 4);
|
||||||
UNUSED(nrow);
|
UNUSED(nrow);
|
||||||
|
UNUSED(ncols);
|
||||||
ggml_quantize_mat_q8_0_4x4(x, vy, n_per_row);
|
ggml_quantize_mat_q8_0_4x4(x, vy, n_per_row);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <> void ggml_quantize_mat_t<8, GGML_TYPE_Q8_0>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) {
|
template <> void ggml_quantize_mat_t<8, GGML_TYPE_Q8_0>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row, int64_t ncols) {
|
||||||
assert(nrow == 4);
|
assert(nrow == 4);
|
||||||
UNUSED(nrow);
|
UNUSED(nrow);
|
||||||
|
UNUSED(ncols);
|
||||||
ggml_quantize_mat_q8_0_4x8(x, vy, n_per_row);
|
ggml_quantize_mat_q8_0_4x8(x, vy, n_per_row);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <> void ggml_quantize_mat_t<4, GGML_TYPE_Q8_K>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) {
|
template <> void ggml_quantize_mat_t<4, GGML_TYPE_Q8_K>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row, int64_t ncols) {
|
||||||
assert(nrow == 4);
|
assert(nrow == 4);
|
||||||
UNUSED(nrow);
|
UNUSED(nrow);
|
||||||
|
UNUSED(ncols);
|
||||||
ggml_quantize_mat_q8_K_4x4(x, vy, n_per_row);
|
ggml_quantize_mat_q8_K_4x4(x, vy, n_per_row);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <> void ggml_quantize_mat_t<8, GGML_TYPE_Q8_K>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) {
|
template <> void ggml_quantize_mat_t<8, GGML_TYPE_Q8_K>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row, int64_t ncols) {
|
||||||
assert(nrow == 4);
|
assert(nrow == 4);
|
||||||
UNUSED(nrow);
|
UNUSED(nrow);
|
||||||
ggml_quantize_mat_q8_K_4x8(x, vy, n_per_row);
|
ggml_quantize_mat_q8_K_4x8(x, vy, n_per_row, ncols);
|
||||||
}
|
}
|
||||||
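A usage sketch for the widened template (hypothetical sizes, assuming the repack headers are included; only the <8, GGML_TYPE_Q8_K> instantiation actually consumes ncols, forwarding it as nc so the quantizer can pick the layout that matches the weight repack):

    float src[4 * 256] = {0};   // 4 activation rows, n_per_row = 256 (= QK_K)
    block_q8_Kx4 dst[1];        // one interleaved group per QK_K columns
    // ncols = 4096: the weight tensor's ne[1] (ne01 at the real call site)
    ggml_quantize_mat_t<8, GGML_TYPE_Q8_K>(src, dst, /*nrow=*/4, /*n_per_row=*/256, /*ncols=*/4096);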
 
 extern "C" {

@@ -2502,7 +2506,7 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PARAM_TYPE>
 
     for (int64_t i11 = ith * 4; i11 < ne11 - ne11 % 4; i11 += nth * 4) {
         ggml_quantize_mat_t<INTER_SIZE, PARAM_TYPE>((float *) (data_ptr + i11 * nb11),
-            (void *) (wdata_ptr + i11 * nbw1), 4, ne10);
+            (void *) (wdata_ptr + i11 * nbw1), 4, ne10, ne01);
     }
 
     const int64_t i11_processed = ne11 - ne11 % 4;
@@ -2775,15 +2779,13 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(const ggml_tensor * cur)
             return &q4_K_8x8_q8_K;
         }
     }
-    if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) { // new for ARM N2
-        if (cur->ne[1] % 4 == 0) {
-            return &q4_K_4x8_q8_K;
-        }
-    }
     if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) {
         if (cur->ne[1] % 8 == 0) {
             return &q4_K_8x8_q8_K;
         }
+        if (cur->ne[1] % 4 == 0) {
+            return &q4_K_4x8_q8_K;
+        }
     }
     if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) {
         if (cur->ne[1] % 8 == 0) {
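This reordering is what restores compatibility with other q4_K repacking models: in the old code the `// new for ARM N2` block ran first and returned the 4x8 repack for any tensor with ne[1] % 4 == 0, which also captures every ne[1] % 8 == 0 tensor and so shadowed the preferred q4_K_8x8 path. The new order tests the wider interleave first, and because the quantizer receives the same dimension as nc (ne01), both ends of the pipeline now agree. A sketch of the resulting preference order (hypothetical stand-ins for the traits objects):

    #include <cstdint>

    enum q4_K_repack { REPACK_8x8, REPACK_4x8, REPACK_NONE };

    // Widest interleave the dimension allows wins; narrower is the fallback.
    static q4_K_repack pick_q4_K_repack(int64_t ne1) {   // ne1 = cur->ne[1]
        if (ne1 % 8 == 0) return REPACK_8x8;   // pairs with the generic Q8_K layout
        if (ne1 % 4 == 0) return REPACK_4x8;   // pairs with the NEON Q8_K layout
        return REPACK_NONE;
    }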
@@ -88,7 +88,7 @@ extern "C" {
 void ggml_quantize_mat_q8_0_4x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
 void ggml_quantize_mat_q8_0_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
 void ggml_quantize_mat_q8_K_4x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
-void ggml_quantize_mat_q8_K_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
+void ggml_quantize_mat_q8_K_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k, int64_t nc);
 void ggml_gemv_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemv_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);

@@ -116,7 +116,7 @@ void ggml_gemm_q8_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc)
 void ggml_quantize_mat_q8_0_4x4_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
 void ggml_quantize_mat_q8_0_4x8_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
 void ggml_quantize_mat_q8_K_4x4_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
-void ggml_quantize_mat_q8_K_4x8_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
+void ggml_quantize_mat_q8_K_4x8_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k, int64_t nc);
 // gemv_generic ???
 void ggml_gemv_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemv_q4_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);