ggml-quants : use a max-heap for TQ1_0 and TQ2_0 quantization

This commit is contained in:
Francis Couture-Harpin 2025-03-22 12:03:26 -04:00
parent f86b8ff210
commit 3e4b675c9f
1 changed files with 199 additions and 39 deletions

View File

@ -1007,8 +1007,8 @@ static float make_qkxs_quants(int n, int nmin, int nmax, const float * restrict
// exhaustive search with cumulative sums
static float make_qkxh_quants(int n, const float * restrict x, const float * restrict weights, int8_t * restrict L, int8_t * restrict Laux, struct k_heap * restrict k_heap, bool signed_scale) {
const int nmin = -k_heap->mid_k; // TODO: maybe directly pass these
const int nmax = k_heap->k + nmin - 1;
const int nmin = MIN(0, -k_heap->mid_k); // TODO: maybe directly pass these
const int nmax = MAX(0, k_heap->k + nmin - 1);
float amax = fabsf(x[0]);
float w_amax = (weights ? weights[0] : x[0] * x[0]) * amax;
int amax_i = 0;
@ -1046,10 +1046,10 @@ static float make_qkxh_quants(int n, const float * restrict x, const float * res
// Find the max range in [0, amax_range] which doesn't result in clamping.
// This is the range from the side which would clamp first (biggest ratio of max to nmax).
// But it's easier and safer to simply use the smallest range.
int amax_range = MIN(abs(nmin), abs(nmax));
int amax_range = MIN(-nmin, nmax);
if (amax_range == 0) {
// one side will clamp anyway
amax_range = MAX(abs(nmin), abs(nmax));
amax_range = MAX(-nmin, nmax);
}
float sumlx = 0.0f;
float suml2 = 0.0f;
@ -1087,40 +1087,192 @@ static float make_qkxh_quants(int n, const float * restrict x, const float * res
const int max_odd = 2*(imax_range + 1) + 1;
const float wmax = fabsf(x[w_amax_i]);
// const float wmax = amax;
{
int best_p_i = -1; // consecutive with 0..n_frac
int i = 0;
while (k_heap->n > 0) {
struct fraction frac = k_heap_pop(k_heap);
if (frac.numer == 0.0f) { break; }
const float v_max_odd = frac.numer * max_odd;
if (wmax * frac.denom > v_max_odd) {
// stop when the inverse scale would result in clamping the most important value
break;
int best_p_i = -1; // consecutive with 0..n_frac
for (int i = 0; k_heap->n > 0; ++i) {
struct fraction frac = k_heap_pop(k_heap);
if (frac.numer == 0.0f) { break; }
const float v_max_odd = frac.numer * max_odd;
if (wmax * frac.denom > v_max_odd) {
// stop when the inverse scale would result in clamping the most important value
break;
}
// maximize the weighted cosine similarity
const int ii = frac.i;
const float w = weights ? weights[ii] : x[ii] * x[ii];
if (negative_scale) {
frac.numer = -frac.numer;
}
sumlx += w * frac.numer;
suml2 += w * frac.denom;
const float current = sumlx * sumlx;
Laux[ii] += (x[ii] < 0.0f) != negative_scale ? -1 : 1;
if (suml2 > 0.0f && current * best_denom > best * suml2) {
best = current;
best_denom = suml2;
scale = sumlx / suml2;
if (i == best_p_i + 1) {
// reduce copies for consecutive bests
L[ii] += (x[ii] < 0.0f) != negative_scale ? -1 : 1;
} else {
memcpy(L, Laux, n);
}
// maximize the weighted cosine similarity
const int ii = frac.i;
const float w = weights ? weights[ii] : x[ii] * x[ii];
if (negative_scale) {
frac.numer = -frac.numer;
}
sumlx += w * frac.numer;
suml2 += w * frac.denom;
const float current = sumlx * sumlx;
Laux[ii] += (x[ii] < 0.0f) != negative_scale ? -1 : 1;
if (suml2 > 0.0f && current * best_denom > best * suml2) {
best = current;
best_denom = suml2;
scale = sumlx / suml2;
if (i == best_p_i + 1) {
// reduce copies for consecutive bests
L[ii] += (x[ii] < 0.0f) != negative_scale ? -1 : 1;
} else {
memcpy(L, Laux, n);
best_p_i = i;
}
}
return scale;
}
// like make_qkxh_quants, but doesn't assume the sign of the scale is the sign of the absmax value
static float make_qkxsh_quants(int n, int nmin, int nmax, const float * restrict x, const float * restrict weights, int8_t * restrict L, int8_t * restrict Laux, struct k_heap * restrict k_heap) {
nmin = MIN(0, nmin);
nmax = MAX(0, nmax);
// start at zero
float amax = fabsf(x[0]);
float min = x[0];
float max = x[0];
float w_amax = weights[0] * amax;
int w_amax_i = 0;
for (int i = 1; i < n; ++i) {
const float w = weights[i];
const float ax = fabsf(x[i]);
const float wax = w * ax;
if (ax > amax) { amax = ax; }
if (x[i] > max) { max = x[i]; }
if (x[i] < min) { min = x[i]; }
// Find the most important value
if (wax > w_amax) { w_amax = wax; w_amax_i = i; }
}
if (amax < GROUP_MAX_EPS) { // all zero
memset(L, 0, n);
return 0.0f;
}
// Use the side which will clamp first.
// The first clamped value is the absmax at the end of the common range.
int amax_range = MIN(-nmin, nmax);
if (amax_range == 0) {
// One side will always clamp anyway
amax_range = MAX(-nmin, nmax);
}
float sumlx_p = 0.0f;
float suml2_p = 0.0f;
float sumlx_n = 0.0f;
float suml2_n = 0.0f;
float scale = 0.0f;
float best = 0.0f;
float best_denom = 1.0f;
int best_i = -1; // consecutive with 0..n_frac
// Pre-calculate the half-point for the common range.
// All smaller vectors have a representable vector with twice the values, and thus can be skipped.
if (amax_range > 1) {
const float iscale = ((float)((amax_range >> 1) + 1))/amax;
for (int i = 0; i < n; ++i) {
const float w = weights[i];
int l = MAX(nmin, MIN(lroundf(x[i] * iscale), nmax));
Laux[i] = l + k_heap->mid_k;
suml2_p += w * l * l;
sumlx_p += w * l * x[i];
}
sumlx_n = -sumlx_p;
suml2_n = suml2_p;
const float current_p = sumlx_p * sumlx_p;
if (suml2_p > 0.0f) {
best = current_p;
best_denom = suml2_p;
scale = sumlx_p / suml2_p;
best_i = -1; // right before 0 of the loop after sorting
}
} else {
memset(Laux, k_heap->mid_k, n);
}
memcpy(L, Laux, n);
k_heap_set_x_L(k_heap, x, Laux, n, false);
// TODO: make that range sign-aware to reduce the search space
const int imax_range = MAX(nmax, -nmin);
const int max_odd = 2*(imax_range + 1) + 1;
const float wmax = fabsf(x[w_amax_i]);
const float max_common_odd = (MIN(nmax, -nmin) * 2) + 1;
const float max_odd_p = (nmax * 2) + 1;
const float max_odd_n = (-nmin * 2) + 1;
for (int i = 0; k_heap->n > 0; ++i) {
struct fraction frac = k_heap_pop(k_heap);
// maximize the weighted cosine similarity
const int ii = frac.i;
const float w = weights[ii];
const float lx = w * frac.numer;
const float odd = frac.denom;
const float l2 = w * odd;
if (wmax * odd > frac.numer * max_odd) {
// stop when the inverse scale would result in clamping the most important value
break;
}
Laux[ii] += x[ii] < 0.0f ? -1 : 1;
float sumlx;
float proj;
float norm;
if (odd < max_common_odd) {
sumlx_p += lx;
suml2_p += l2;
sumlx_n -= lx;
suml2_n += l2;
sumlx = sumlx_p;
proj = sumlx_p * sumlx_p;
norm = suml2_p;
// avoid double-copying Laux in a single iteration
if (suml2_p != suml2_n && suml2_p * suml2_n > 0.0f) {
const float proj_n = sumlx_n * sumlx_n;
if (proj_n * norm > proj * suml2_n) {
sumlx = sumlx_n;
proj = proj_n;
norm = suml2_n;
}
best_p_i = i;
}
i += 1;
} else if (x[ii] < 0.0f ? odd < max_odd_n : odd < max_odd_p) {
sumlx_p += lx;
suml2_p += l2;
sumlx = sumlx_p;
proj = sumlx_p * sumlx_p;
norm = suml2_p;
} else {
// outside the positive range means we're now into negatives
sumlx_n -= lx;
suml2_n += l2;
sumlx = sumlx_n;
proj = sumlx_n * sumlx_n;
norm = suml2_n;
}
if (norm > 0.0f && proj * best_denom > best * norm) {
best = proj;
best_denom = norm;
scale = sumlx / norm;
if (i == best_i + 1) {
// reduce copies for consecutive bests
L[ii] += x[ii] < 0.0f ? -1 : 1;
} else {
memcpy(L, Laux, n);
}
best_i = i;
}
}
if (scale < 0.0f) {
for (int i = 0; i < n; ++i) {
L[i] = MAX(nmin, MIN(-(L[i] - k_heap->mid_k), nmax)) - nmin;
}
} else {
for (int i = 0; i < n; ++i) {
L[i] = MAX(nmin, MIN(L[i] - k_heap->mid_k, nmax)) - nmin;
}
}
@ -2898,7 +3050,11 @@ static void quantize_row_tq1_0_impl(const float * restrict x, block_tq1_0 * rest
float weight[QK_K];
int8_t L[QK_K];
int8_t Laux[QK_K];
struct fraction Faux[1 * QK_K];
struct k_heap_cell heap_cells[QK_K];
float odd[3];
struct k_heap k_heap;
k_heap_init_linear(&k_heap, -1, 1, heap_cells, odd);
float sum_x2 = 0;
for (int j = 0; j < n_per_row; ++j) { sum_x2 += x[j]*x[j]; }
@ -2910,7 +3066,7 @@ static void quantize_row_tq1_0_impl(const float * restrict x, block_tq1_0 * rest
const float * qw = quant_weights + QK_K * ib;
const int8_t * Lptr = L;
for (int j = 0; j < QK_K; ++j) { weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]); }
float d = make_qkxs_quants(QK_K, -1, 1, xb, weight, L, Laux, Faux, false);
float d = make_qkxh_quants(QK_K, xb, weight, L, Laux, &k_heap, false);
y[ib].d = GGML_FP32_TO_FP16(d);
// 5 elements per byte, along 32 bytes
@ -2999,7 +3155,11 @@ static void quantize_row_tq2_0_impl(const float * restrict x, block_tq2_0 * rest
float weight[QK_K];
int8_t L[QK_K];
int8_t Laux[QK_K];
struct fraction Faux[2 * QK_K];
struct k_heap_cell heap_cells[QK_K];
float odd[4 + 1];
struct k_heap k_heap;
k_heap_init_linear(&k_heap, -2, 2, heap_cells, odd);
float sum_x2 = 0;
for (int j = 0; j < n_per_row; ++j) { sum_x2 += x[j]*x[j]; }
@ -3010,7 +3170,7 @@ static void quantize_row_tq2_0_impl(const float * restrict x, block_tq2_0 * rest
const float * xb = x + QK_K * ib;
const float * qw = quant_weights + QK_K * ib;
for (int j = 0; j < QK_K; ++j) { weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]); }
float d = make_qkxss_quants(QK_K, -1, 2, xb, weight, L, Laux, Faux);
float d = make_qkxsh_quants(QK_K, -1, 2, xb, weight, L, Laux, &k_heap);
y[ib].d = GGML_FP32_TO_FP16(d);
for (size_t j = 0; j < sizeof(y->qs); j += 32) {