fix compatibility with other q4_k repacking models
This commit is contained in:
parent
da606bd736
commit
126ce2c9e4
|
|
@ -210,129 +210,137 @@ void ggml_quantize_mat_q8_0_4x8(const float * GGML_RESTRICT x, void * GGML_RESTR
|
|||
#endif
|
||||
}
|
||||
|
||||
void ggml_quantize_mat_q8_K_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
|
||||
void ggml_quantize_mat_q8_K_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k, int64_t nc) {
|
||||
assert(QK_K == 256);
|
||||
assert(k % QK_K == 0);
|
||||
UNUSED(nc);
|
||||
const int nb = k / QK_K;
|
||||
|
||||
block_q8_Kx4 * GGML_RESTRICT y = (block_q8_Kx4 *) vy;
|
||||
|
||||
#if defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
|
||||
const int blck_size_interleave = 8;
|
||||
float32x4_t srcv[4][64]; // 64 = QK_K/4
|
||||
float iscale[4];
|
||||
if (nc % 8 == 0) {
|
||||
UNUSED(nb);
|
||||
UNUSED(y);
|
||||
ggml_quantize_mat_q8_K_4x8_generic(x, vy, k, nc);
|
||||
} else if (nc % 4 == 0) {
|
||||
const int blck_size_interleave = 8;
|
||||
float32x4_t srcv[4][64]; // 64 = QK_K/4
|
||||
float iscale[4];
|
||||
|
||||
for (int i = 0; i < nb; i++) {
|
||||
float32x4_t asrcv[64];
|
||||
float32x4_t amaxv[64];
|
||||
for (int i = 0; i < nb; i++) {
|
||||
float32x4_t asrcv[64];
|
||||
float32x4_t amaxv[64];
|
||||
|
||||
// d:
|
||||
for (int row_iter = 0; row_iter < 4; row_iter++) {
|
||||
for (int j = 0; j < 64; j++) srcv[row_iter][j] = vld1q_f32(x + row_iter * k + i * 256 + 4 * j);
|
||||
for (int j = 0; j < 64; j++) asrcv[j] = vabsq_f32(srcv[row_iter][j]);
|
||||
// d:
|
||||
for (int row_iter = 0; row_iter < 4; row_iter++) {
|
||||
for (int j = 0; j < 64; j++) srcv[row_iter][j] = vld1q_f32(x + row_iter * k + i * 256 + 4 * j);
|
||||
for (int j = 0; j < 64; j++) asrcv[j] = vabsq_f32(srcv[row_iter][j]);
|
||||
|
||||
for (int j = 0; j < 32; j++) amaxv[2 * j] = vmaxq_f32(asrcv[2 * j], asrcv[2 * j + 1]);
|
||||
for (int j = 0; j < 16; j++) amaxv[4 * j] = vmaxq_f32(amaxv[4 * j], amaxv[4 * j + 2]);
|
||||
for (int j = 0; j < 8; j++) amaxv[8 * j] = vmaxq_f32(amaxv[8 * j], amaxv[8 * j + 4]);
|
||||
for (int j = 0; j < 4; j++) amaxv[16 * j] = vmaxq_f32(amaxv[16 * j], amaxv[16 * j + 8]);
|
||||
for (int j = 0; j < 2; j++) amaxv[32 * j] = vmaxq_f32(amaxv[32 * j], amaxv[32 * j + 16]);
|
||||
for (int j = 0; j < 1; j++) amaxv[64 * j] = vmaxq_f32(amaxv[64 * j], amaxv[64 * j + 32]);
|
||||
for (int j = 0; j < 32; j++) amaxv[2 * j] = vmaxq_f32(asrcv[2 * j], asrcv[2 * j + 1]);
|
||||
for (int j = 0; j < 16; j++) amaxv[4 * j] = vmaxq_f32(amaxv[4 * j], amaxv[4 * j + 2]);
|
||||
for (int j = 0; j < 8; j++) amaxv[8 * j] = vmaxq_f32(amaxv[8 * j], amaxv[8 * j + 4]);
|
||||
for (int j = 0; j < 4; j++) amaxv[16 * j] = vmaxq_f32(amaxv[16 * j], amaxv[16 * j + 8]);
|
||||
for (int j = 0; j < 2; j++) amaxv[32 * j] = vmaxq_f32(amaxv[32 * j], amaxv[32 * j + 16]);
|
||||
for (int j = 0; j < 1; j++) amaxv[64 * j] = vmaxq_f32(amaxv[64 * j], amaxv[64 * j + 32]);
|
||||
|
||||
const float amax = vmaxvq_f32(amaxv[0]);
|
||||
const float amax = vmaxvq_f32(amaxv[0]);
|
||||
|
||||
// Check if exists: orig == amax
|
||||
float32x4_t amax_vec = vdupq_n_f32(amax);
|
||||
uint32x4_t mask_all = vdupq_n_u32(0);
|
||||
for (int j = 0; j < 64; j++) {
|
||||
uint32x4_t mask_curr = vceqq_f32(amax_vec, srcv[row_iter][j]);
|
||||
mask_all = vorrq_u32(mask_all, mask_curr);
|
||||
// Check if exists: orig == amax
|
||||
float32x4_t amax_vec = vdupq_n_f32(amax);
|
||||
uint32x4_t mask_all = vdupq_n_u32(0);
|
||||
for (int j = 0; j < 64; j++) {
|
||||
uint32x4_t mask_curr = vceqq_f32(amax_vec, srcv[row_iter][j]);
|
||||
mask_all = vorrq_u32(mask_all, mask_curr);
|
||||
}
|
||||
|
||||
// Assume that none == amax, then check mask_all to reverse
|
||||
iscale[row_iter] = ( amax != 0.0f ) ? 127.f / amax : 0.0f;
|
||||
uint32x4_t cmp = vceqq_u32(mask_all, vdupq_n_u32(0xFFFFFFFFu));
|
||||
if (vmaxvq_u32(cmp) != 0) {
|
||||
iscale[row_iter] = ( amax != 0.0f ) ? -127.f / amax : 0.0f;
|
||||
}
|
||||
|
||||
y[i].d[row_iter] = amax ? 1/iscale[row_iter] : 0;
|
||||
}
|
||||
|
||||
// Assume that none == amax, then check mask_all to reverse
|
||||
iscale[row_iter] = ( amax != 0.0f ) ? 127.f / amax : 0.0f;
|
||||
uint32x4_t cmp = vceqq_u32(mask_all, vdupq_n_u32(0xFFFFFFFFu));
|
||||
if (vmaxvq_u32(cmp) != 0) {
|
||||
iscale[row_iter] = ( amax != 0.0f ) ? -127.f / amax : 0.0f;
|
||||
// qs: 8 byte interleave over 4 rows, loop = QK_K/8
|
||||
// bsums: simply generated one by one, row_i is calculated before row_i+1
|
||||
// loops = 16
|
||||
for (int j = 0; j < QK_K / blck_size_interleave / 2; j++) {
|
||||
// Process row0 and row1
|
||||
float32x4_t f0_0_3 = vrndnq_f32(vmulq_n_f32(srcv[0][4 * j], iscale[0]));
|
||||
float32x4_t f0_4_7 = vrndnq_f32(vmulq_n_f32(srcv[0][4 * j + 1], iscale[0]));
|
||||
float32x4_t f0_8_11 = vrndnq_f32(vmulq_n_f32(srcv[0][4 * j + 2], iscale[0]));
|
||||
float32x4_t f0_12_15 = vrndnq_f32(vmulq_n_f32(srcv[0][4 * j + 3], iscale[0]));
|
||||
int32x4_t i0_0_3 = vcvtnq_s32_f32(f0_0_3);
|
||||
int32x4_t i0_4_7 = vcvtnq_s32_f32(f0_4_7);
|
||||
int16x8_t i0_0_7 = vcombine_s16(vqmovn_s32(i0_0_3), vqmovn_s32(i0_4_7)); // int32x4 * 2 → int16x4 * 2 → int16x8
|
||||
int32x4_t i0_8_11 = vcvtnq_s32_f32(f0_8_11);
|
||||
int32x4_t i0_12_15 = vcvtnq_s32_f32(f0_12_15);
|
||||
int16x8_t i0_8_15 = vcombine_s16(vqmovn_s32(i0_8_11), vqmovn_s32(i0_12_15)); // int32x4 * 2 → int16x4 * 2 → int16x8
|
||||
|
||||
float32x4_t f1_0_3 = vrndnq_f32(vmulq_n_f32(srcv[1][4 * j], iscale[1]));
|
||||
float32x4_t f1_4_7 = vrndnq_f32(vmulq_n_f32(srcv[1][4 * j + 1], iscale[1]));
|
||||
float32x4_t f1_8_11 = vrndnq_f32(vmulq_n_f32(srcv[1][4 * j + 2], iscale[1]));
|
||||
float32x4_t f1_12_15 = vrndnq_f32(vmulq_n_f32(srcv[1][4 * j + 3], iscale[1]));
|
||||
int32x4_t i1_0_3 = vcvtnq_s32_f32(f1_0_3);
|
||||
int32x4_t i1_4_7 = vcvtnq_s32_f32(f1_4_7);
|
||||
int16x8_t i1_0_7 = vcombine_s16(vqmovn_s32(i1_0_3), vqmovn_s32(i1_4_7)); // int32x4 * 2 → int16x4 * 2 → int16x8
|
||||
int32x4_t i1_8_11 = vcvtnq_s32_f32(f1_8_11);
|
||||
int32x4_t i1_12_15 = vcvtnq_s32_f32(f1_12_15);
|
||||
int16x8_t i1_8_15 = vcombine_s16(vqmovn_s32(i1_8_11), vqmovn_s32(i1_12_15)); // int32x4 * 2 → int16x4 * 2 → int16x8
|
||||
|
||||
// Calculate and store qs
|
||||
int8x16_t i0_i1_0_7 = vcombine_s8(vqmovn_s16(i0_0_7), vqmovn_s16(i1_0_7)); // int16x8 * 2 → int8x8 * 2 → int8x16
|
||||
int8x16_t i0_i1_8_15 = vcombine_s8(vqmovn_s16(i0_8_15), vqmovn_s16(i1_8_15)); // int16x8 * 2 → int8x8 * 2 → int8x16
|
||||
vst1q_s8(y[i].qs + 64 * j, i0_i1_0_7);
|
||||
vst1q_s8(y[i].qs + 64 * j + 32, i0_i1_8_15);
|
||||
// Calculate and store bsum
|
||||
int8x16_t i0_0_15 = vcombine_s8(vqmovn_s16(i0_0_7), vqmovn_s16(i0_8_15)); // int16x8 * 2 → int8x8 * 2 → int8x16
|
||||
int8x16_t i1_0_15 = vcombine_s8(vqmovn_s16(i1_0_7), vqmovn_s16(i1_8_15)); // int16x8 * 2 → int8x8 * 2 → int8x16
|
||||
y[i].bsums[j] = vaddlvq_s8(i0_0_15);
|
||||
y[i].bsums[j + 16] = vaddlvq_s8(i1_0_15);
|
||||
|
||||
// Process row2 and row3
|
||||
f0_0_3 = vrndnq_f32(vmulq_n_f32(srcv[2][4 * j], iscale[2]));
|
||||
f0_4_7 = vrndnq_f32(vmulq_n_f32(srcv[2][4 * j + 1], iscale[2]));
|
||||
f0_8_11 = vrndnq_f32(vmulq_n_f32(srcv[2][4 * j + 2], iscale[2]));
|
||||
f0_12_15 = vrndnq_f32(vmulq_n_f32(srcv[2][4 * j + 3], iscale[2]));
|
||||
i0_0_3 = vcvtnq_s32_f32(f0_0_3);
|
||||
i0_4_7 = vcvtnq_s32_f32(f0_4_7);
|
||||
i0_0_7 = vcombine_s16(vqmovn_s32(i0_0_3), vqmovn_s32(i0_4_7)); // int32x4 * 2 → int16x4 * 2 → int16x8
|
||||
i0_8_11 = vcvtnq_s32_f32(f0_8_11);
|
||||
i0_12_15 = vcvtnq_s32_f32(f0_12_15);
|
||||
i0_8_15 = vcombine_s16(vqmovn_s32(i0_8_11), vqmovn_s32(i0_12_15)); // int32x4 * 2 → int16x4 * 2 → int16x8
|
||||
|
||||
f1_0_3 = vrndnq_f32(vmulq_n_f32(srcv[3][4 * j], iscale[3]));
|
||||
f1_4_7 = vrndnq_f32(vmulq_n_f32(srcv[3][4 * j + 1], iscale[3]));
|
||||
f1_8_11 = vrndnq_f32(vmulq_n_f32(srcv[3][4 * j + 2], iscale[3]));
|
||||
f1_12_15 = vrndnq_f32(vmulq_n_f32(srcv[3][4 * j + 3], iscale[3]));
|
||||
i1_0_3 = vcvtnq_s32_f32(f1_0_3);
|
||||
i1_4_7 = vcvtnq_s32_f32(f1_4_7);
|
||||
i1_0_7 = vcombine_s16(vqmovn_s32(i1_0_3), vqmovn_s32(i1_4_7)); // int32x4 * 2 → int16x4 * 2 → int16x8
|
||||
i1_8_11 = vcvtnq_s32_f32(f1_8_11);
|
||||
i1_12_15 = vcvtnq_s32_f32(f1_12_15);
|
||||
i1_8_15 = vcombine_s16(vqmovn_s32(i1_8_11), vqmovn_s32(i1_12_15)); // int32x4 * 2 → int16x4 * 2 → int16x8
|
||||
|
||||
// Calculate and store qs
|
||||
i0_i1_0_7 = vcombine_s8(vqmovn_s16(i0_0_7), vqmovn_s16(i1_0_7)); // int16x8 * 2 → int8x8 * 2 → int8x16
|
||||
i0_i1_8_15 = vcombine_s8(vqmovn_s16(i0_8_15), vqmovn_s16(i1_8_15)); // int16x8 * 2 → int8x8 * 2 → int8x16
|
||||
vst1q_s8(y[i].qs + 64 * j + 16, i0_i1_0_7);
|
||||
vst1q_s8(y[i].qs + 64 * j + 48, i0_i1_8_15);
|
||||
// Calculate and store bsum
|
||||
i0_0_15 = vcombine_s8(vqmovn_s16(i0_0_7), vqmovn_s16(i0_8_15)); // int16x8 * 2 → int8x8 * 2 → int8x16
|
||||
i1_0_15 = vcombine_s8(vqmovn_s16(i1_0_7), vqmovn_s16(i1_8_15)); // int16x8 * 2 → int8x8 * 2 → int8x16
|
||||
y[i].bsums[j + 32] = vaddlvq_s8(i0_0_15);
|
||||
y[i].bsums[j + 48] = vaddlvq_s8(i1_0_15);
|
||||
}
|
||||
|
||||
y[i].d[row_iter] = amax ? 1/iscale[row_iter] : 0;
|
||||
}
|
||||
|
||||
// qs: 8 byte interleave over 4 rows, loop = QK_K/8
|
||||
// bsums: simply generated one by one, row_i is calculated before row_i+1
|
||||
// loops = 16
|
||||
for (int j = 0; j < QK_K / blck_size_interleave / 2; j++) {
|
||||
// Process row0 and row1
|
||||
float32x4_t f0_0_3 = vrndnq_f32(vmulq_n_f32(srcv[0][4 * j], iscale[0]));
|
||||
float32x4_t f0_4_7 = vrndnq_f32(vmulq_n_f32(srcv[0][4 * j + 1], iscale[0]));
|
||||
float32x4_t f0_8_11 = vrndnq_f32(vmulq_n_f32(srcv[0][4 * j + 2], iscale[0]));
|
||||
float32x4_t f0_12_15 = vrndnq_f32(vmulq_n_f32(srcv[0][4 * j + 3], iscale[0]));
|
||||
int32x4_t i0_0_3 = vcvtnq_s32_f32(f0_0_3);
|
||||
int32x4_t i0_4_7 = vcvtnq_s32_f32(f0_4_7);
|
||||
int16x8_t i0_0_7 = vcombine_s16(vqmovn_s32(i0_0_3), vqmovn_s32(i0_4_7)); // int32x4 * 2 → int16x4 * 2 → int16x8
|
||||
int32x4_t i0_8_11 = vcvtnq_s32_f32(f0_8_11);
|
||||
int32x4_t i0_12_15 = vcvtnq_s32_f32(f0_12_15);
|
||||
int16x8_t i0_8_15 = vcombine_s16(vqmovn_s32(i0_8_11), vqmovn_s32(i0_12_15)); // int32x4 * 2 → int16x4 * 2 → int16x8
|
||||
|
||||
float32x4_t f1_0_3 = vrndnq_f32(vmulq_n_f32(srcv[1][4 * j], iscale[1]));
|
||||
float32x4_t f1_4_7 = vrndnq_f32(vmulq_n_f32(srcv[1][4 * j + 1], iscale[1]));
|
||||
float32x4_t f1_8_11 = vrndnq_f32(vmulq_n_f32(srcv[1][4 * j + 2], iscale[1]));
|
||||
float32x4_t f1_12_15 = vrndnq_f32(vmulq_n_f32(srcv[1][4 * j + 3], iscale[1]));
|
||||
int32x4_t i1_0_3 = vcvtnq_s32_f32(f1_0_3);
|
||||
int32x4_t i1_4_7 = vcvtnq_s32_f32(f1_4_7);
|
||||
int16x8_t i1_0_7 = vcombine_s16(vqmovn_s32(i1_0_3), vqmovn_s32(i1_4_7)); // int32x4 * 2 → int16x4 * 2 → int16x8
|
||||
int32x4_t i1_8_11 = vcvtnq_s32_f32(f1_8_11);
|
||||
int32x4_t i1_12_15 = vcvtnq_s32_f32(f1_12_15);
|
||||
int16x8_t i1_8_15 = vcombine_s16(vqmovn_s32(i1_8_11), vqmovn_s32(i1_12_15)); // int32x4 * 2 → int16x4 * 2 → int16x8
|
||||
|
||||
// Calculate and store qs
|
||||
int8x16_t i0_i1_0_7 = vcombine_s8(vqmovn_s16(i0_0_7), vqmovn_s16(i1_0_7)); // int16x8 * 2 → int8x8 * 2 → int8x16
|
||||
int8x16_t i0_i1_8_15 = vcombine_s8(vqmovn_s16(i0_8_15), vqmovn_s16(i1_8_15)); // int16x8 * 2 → int8x8 * 2 → int8x16
|
||||
vst1q_s8(y[i].qs + 64 * j, i0_i1_0_7);
|
||||
vst1q_s8(y[i].qs + 64 * j + 32, i0_i1_8_15);
|
||||
// Calculate and store bsum
|
||||
int8x16_t i0_0_15 = vcombine_s8(vqmovn_s16(i0_0_7), vqmovn_s16(i0_8_15)); // int16x8 * 2 → int8x8 * 2 → int8x16
|
||||
int8x16_t i1_0_15 = vcombine_s8(vqmovn_s16(i1_0_7), vqmovn_s16(i1_8_15)); // int16x8 * 2 → int8x8 * 2 → int8x16
|
||||
y[i].bsums[j] = vaddlvq_s8(i0_0_15);
|
||||
y[i].bsums[j + 16] = vaddlvq_s8(i1_0_15);
|
||||
|
||||
// Process row2 and row3
|
||||
f0_0_3 = vrndnq_f32(vmulq_n_f32(srcv[2][4 * j], iscale[2]));
|
||||
f0_4_7 = vrndnq_f32(vmulq_n_f32(srcv[2][4 * j + 1], iscale[2]));
|
||||
f0_8_11 = vrndnq_f32(vmulq_n_f32(srcv[2][4 * j + 2], iscale[2]));
|
||||
f0_12_15 = vrndnq_f32(vmulq_n_f32(srcv[2][4 * j + 3], iscale[2]));
|
||||
i0_0_3 = vcvtnq_s32_f32(f0_0_3);
|
||||
i0_4_7 = vcvtnq_s32_f32(f0_4_7);
|
||||
i0_0_7 = vcombine_s16(vqmovn_s32(i0_0_3), vqmovn_s32(i0_4_7)); // int32x4 * 2 → int16x4 * 2 → int16x8
|
||||
i0_8_11 = vcvtnq_s32_f32(f0_8_11);
|
||||
i0_12_15 = vcvtnq_s32_f32(f0_12_15);
|
||||
i0_8_15 = vcombine_s16(vqmovn_s32(i0_8_11), vqmovn_s32(i0_12_15)); // int32x4 * 2 → int16x4 * 2 → int16x8
|
||||
|
||||
f1_0_3 = vrndnq_f32(vmulq_n_f32(srcv[3][4 * j], iscale[3]));
|
||||
f1_4_7 = vrndnq_f32(vmulq_n_f32(srcv[3][4 * j + 1], iscale[3]));
|
||||
f1_8_11 = vrndnq_f32(vmulq_n_f32(srcv[3][4 * j + 2], iscale[3]));
|
||||
f1_12_15 = vrndnq_f32(vmulq_n_f32(srcv[3][4 * j + 3], iscale[3]));
|
||||
i1_0_3 = vcvtnq_s32_f32(f1_0_3);
|
||||
i1_4_7 = vcvtnq_s32_f32(f1_4_7);
|
||||
i1_0_7 = vcombine_s16(vqmovn_s32(i1_0_3), vqmovn_s32(i1_4_7)); // int32x4 * 2 → int16x4 * 2 → int16x8
|
||||
i1_8_11 = vcvtnq_s32_f32(f1_8_11);
|
||||
i1_12_15 = vcvtnq_s32_f32(f1_12_15);
|
||||
i1_8_15 = vcombine_s16(vqmovn_s32(i1_8_11), vqmovn_s32(i1_12_15)); // int32x4 * 2 → int16x4 * 2 → int16x8
|
||||
|
||||
// Calculate and store qs
|
||||
i0_i1_0_7 = vcombine_s8(vqmovn_s16(i0_0_7), vqmovn_s16(i1_0_7)); // int16x8 * 2 → int8x8 * 2 → int8x16
|
||||
i0_i1_8_15 = vcombine_s8(vqmovn_s16(i0_8_15), vqmovn_s16(i1_8_15)); // int16x8 * 2 → int8x8 * 2 → int8x16
|
||||
vst1q_s8(y[i].qs + 64 * j + 16, i0_i1_0_7);
|
||||
vst1q_s8(y[i].qs + 64 * j + 48, i0_i1_8_15);
|
||||
// Calculate and store bsum
|
||||
i0_0_15 = vcombine_s8(vqmovn_s16(i0_0_7), vqmovn_s16(i0_8_15)); // int16x8 * 2 → int8x8 * 2 → int8x16
|
||||
i1_0_15 = vcombine_s8(vqmovn_s16(i1_0_7), vqmovn_s16(i1_8_15)); // int16x8 * 2 → int8x8 * 2 → int8x16
|
||||
y[i].bsums[j + 32] = vaddlvq_s8(i0_0_15);
|
||||
y[i].bsums[j + 48] = vaddlvq_s8(i1_0_15);
|
||||
}
|
||||
}
|
||||
return;
|
||||
#endif
|
||||
|
||||
#else
|
||||
// NOTE:
|
||||
// Current C impl of Q8_K quanti is originally designed to work with block_q4_Kx8 in x86 AVX design, and differs from
|
||||
// above Q8_K quanti logic in AArch64 NEON design, which is designed to work with block_q4_Kx4. The main difference is in
|
||||
|
|
@ -340,8 +348,7 @@ void ggml_quantize_mat_q8_K_4x8(const float * GGML_RESTRICT x, void * GGML_RESTR
|
|||
// "[bsums] layout" correctly in ggml_gemm_q4_K_4x8_q8_K_generic().
|
||||
UNUSED(nb);
|
||||
UNUSED(y);
|
||||
ggml_quantize_mat_q8_K_4x8_generic(x, vy, k);
|
||||
#endif
|
||||
ggml_quantize_mat_q8_K_4x8_generic(x, vy, k, nc);
|
||||
}
|
||||
|
||||
void ggml_gemv_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
|
||||
|
|
|
|||
|
|
@ -176,9 +176,10 @@ void ggml_quantize_mat_q8_K_4x4_generic(const float * GGML_RESTRICT x, void * GG
|
|||
}
|
||||
}
|
||||
|
||||
void ggml_quantize_mat_q8_K_4x8_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
|
||||
void ggml_quantize_mat_q8_K_4x8_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k, int64_t nc) {
|
||||
assert(QK_K == 256);
|
||||
assert(k % QK_K == 0);
|
||||
UNUSED(nc);
|
||||
const int nb = k / QK_K;
|
||||
|
||||
block_q8_Kx4 * GGML_RESTRICT y = (block_q8_Kx4 *) vy;
|
||||
|
|
@ -230,30 +231,33 @@ void ggml_quantize_mat_q8_K_4x8_generic(const float * GGML_RESTRICT x, void * GG
|
|||
} // extern "C"
|
||||
|
||||
template <int64_t INTER_SIZE, ggml_type PARAM_TYPE>
|
||||
void ggml_quantize_mat_t(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row);
|
||||
void ggml_quantize_mat_t(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row, int64_t ncols);
|
||||
|
||||
template <> void ggml_quantize_mat_t<4, GGML_TYPE_Q8_0>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) {
|
||||
template <> void ggml_quantize_mat_t<4, GGML_TYPE_Q8_0>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row, int64_t ncols) {
|
||||
assert(nrow == 4);
|
||||
UNUSED(nrow);
|
||||
UNUSED(ncols);
|
||||
ggml_quantize_mat_q8_0_4x4(x, vy, n_per_row);
|
||||
}
|
||||
|
||||
template <> void ggml_quantize_mat_t<8, GGML_TYPE_Q8_0>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) {
|
||||
template <> void ggml_quantize_mat_t<8, GGML_TYPE_Q8_0>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row, int64_t ncols) {
|
||||
assert(nrow == 4);
|
||||
UNUSED(nrow);
|
||||
UNUSED(ncols);
|
||||
ggml_quantize_mat_q8_0_4x8(x, vy, n_per_row);
|
||||
}
|
||||
|
||||
template <> void ggml_quantize_mat_t<4, GGML_TYPE_Q8_K>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) {
|
||||
template <> void ggml_quantize_mat_t<4, GGML_TYPE_Q8_K>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row, int64_t ncols) {
|
||||
assert(nrow == 4);
|
||||
UNUSED(nrow);
|
||||
UNUSED(ncols);
|
||||
ggml_quantize_mat_q8_K_4x4(x, vy, n_per_row);
|
||||
}
|
||||
|
||||
template <> void ggml_quantize_mat_t<8, GGML_TYPE_Q8_K>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) {
|
||||
template <> void ggml_quantize_mat_t<8, GGML_TYPE_Q8_K>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row, int64_t ncols) {
|
||||
assert(nrow == 4);
|
||||
UNUSED(nrow);
|
||||
ggml_quantize_mat_q8_K_4x8(x, vy, n_per_row);
|
||||
ggml_quantize_mat_q8_K_4x8(x, vy, n_per_row, ncols);
|
||||
}
|
||||
|
||||
extern "C" {
|
||||
|
|
@ -2502,7 +2506,7 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
|
|||
|
||||
for (int64_t i11 = ith * 4; i11 < ne11 - ne11 % 4; i11 += nth * 4) {
|
||||
ggml_quantize_mat_t<INTER_SIZE, PARAM_TYPE>((float *) (data_ptr + i11 * nb11),
|
||||
(void *) (wdata_ptr + i11 * nbw1), 4, ne10);
|
||||
(void *) (wdata_ptr + i11 * nbw1), 4, ne10, ne01);
|
||||
}
|
||||
|
||||
const int64_t i11_processed = ne11 - ne11 % 4;
|
||||
|
|
@ -2775,15 +2779,13 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons
|
|||
return &q4_K_8x8_q8_K;
|
||||
}
|
||||
}
|
||||
if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) { // new for ARM N2
|
||||
if (cur->ne[1] % 4 == 0) {
|
||||
return &q4_K_4x8_q8_K;
|
||||
}
|
||||
}
|
||||
if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) {
|
||||
if (cur->ne[1] % 8 == 0) {
|
||||
return &q4_K_8x8_q8_K;
|
||||
}
|
||||
if (cur->ne[1] % 4 == 0) {
|
||||
return &q4_K_4x8_q8_K;
|
||||
}
|
||||
}
|
||||
if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) {
|
||||
if (cur->ne[1] % 8 == 0) {
|
||||
|
|
|
|||
|
|
@ -88,7 +88,7 @@ extern "C" {
|
|||
void ggml_quantize_mat_q8_0_4x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
|
||||
void ggml_quantize_mat_q8_0_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
|
||||
void ggml_quantize_mat_q8_K_4x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
|
||||
void ggml_quantize_mat_q8_K_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
|
||||
void ggml_quantize_mat_q8_K_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k, int64_t nc);
|
||||
void ggml_gemv_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
|
|
@ -116,7 +116,7 @@ void ggml_gemm_q8_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
|
|||
void ggml_quantize_mat_q8_0_4x4_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
|
||||
void ggml_quantize_mat_q8_0_4x8_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
|
||||
void ggml_quantize_mat_q8_K_4x4_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
|
||||
void ggml_quantize_mat_q8_K_4x8_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
|
||||
void ggml_quantize_mat_q8_K_4x8_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k, int64_t nc);
|
||||
// gemv_generic ???
|
||||
void ggml_gemv_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
void ggml_gemv_q4_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
||||
|
|
|
|||
Loading…
Reference in New Issue