From 3e4b675c9f0f955965ad1ea3a927a57eee582b6d Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Sat, 22 Mar 2025 12:03:26 -0400 Subject: [PATCH] ggml-quants : use a max-heap for TQ1_0 and TQ2_0 quantization --- ggml/src/ggml-quants.c | 238 ++++++++++++++++++++++++++++++++++------- 1 file changed, 199 insertions(+), 39 deletions(-) diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c index 5e25887f10..9be87a671c 100644 --- a/ggml/src/ggml-quants.c +++ b/ggml/src/ggml-quants.c @@ -1007,8 +1007,8 @@ static float make_qkxs_quants(int n, int nmin, int nmax, const float * restrict // exhaustive search with cumulative sums static float make_qkxh_quants(int n, const float * restrict x, const float * restrict weights, int8_t * restrict L, int8_t * restrict Laux, struct k_heap * restrict k_heap, bool signed_scale) { - const int nmin = -k_heap->mid_k; // TODO: maybe directly pass these - const int nmax = k_heap->k + nmin - 1; + const int nmin = MIN(0, -k_heap->mid_k); // TODO: maybe directly pass these + const int nmax = MAX(0, k_heap->k + nmin - 1); float amax = fabsf(x[0]); float w_amax = (weights ? weights[0] : x[0] * x[0]) * amax; int amax_i = 0; @@ -1046,10 +1046,10 @@ static float make_qkxh_quants(int n, const float * restrict x, const float * res // Find the max range in [0, amax_range] which doesn't result in clamping. // This is the range from the side which would clamp first (biggest ratio of max to nmax). // But it's easier and safer to simply use the smallest range. - int amax_range = MIN(abs(nmin), abs(nmax)); + int amax_range = MIN(-nmin, nmax); if (amax_range == 0) { // one side will clamp anyway - amax_range = MAX(abs(nmin), abs(nmax)); + amax_range = MAX(-nmin, nmax); } float sumlx = 0.0f; float suml2 = 0.0f; @@ -1087,40 +1087,192 @@ static float make_qkxh_quants(int n, const float * restrict x, const float * res const int max_odd = 2*(imax_range + 1) + 1; const float wmax = fabsf(x[w_amax_i]); // const float wmax = amax; - { - int best_p_i = -1; // consecutive with 0..n_frac - int i = 0; - while (k_heap->n > 0) { - struct fraction frac = k_heap_pop(k_heap); - if (frac.numer == 0.0f) { break; } - const float v_max_odd = frac.numer * max_odd; - if (wmax * frac.denom > v_max_odd) { - // stop when the inverse scale would result in clamping the most important value - break; + int best_p_i = -1; // consecutive with 0..n_frac + for (int i = 0; k_heap->n > 0; ++i) { + struct fraction frac = k_heap_pop(k_heap); + if (frac.numer == 0.0f) { break; } + const float v_max_odd = frac.numer * max_odd; + if (wmax * frac.denom > v_max_odd) { + // stop when the inverse scale would result in clamping the most important value + break; + } + // maximize the weighted cosine similarity + const int ii = frac.i; + const float w = weights ? weights[ii] : x[ii] * x[ii]; + if (negative_scale) { + frac.numer = -frac.numer; + } + sumlx += w * frac.numer; + suml2 += w * frac.denom; + const float current = sumlx * sumlx; + Laux[ii] += (x[ii] < 0.0f) != negative_scale ? -1 : 1; + if (suml2 > 0.0f && current * best_denom > best * suml2) { + best = current; + best_denom = suml2; + scale = sumlx / suml2; + if (i == best_p_i + 1) { + // reduce copies for consecutive bests + L[ii] += (x[ii] < 0.0f) != negative_scale ? -1 : 1; + } else { + memcpy(L, Laux, n); } - // maximize the weighted cosine similarity - const int ii = frac.i; - const float w = weights ? weights[ii] : x[ii] * x[ii]; - if (negative_scale) { - frac.numer = -frac.numer; - } - sumlx += w * frac.numer; - suml2 += w * frac.denom; - const float current = sumlx * sumlx; - Laux[ii] += (x[ii] < 0.0f) != negative_scale ? -1 : 1; - if (suml2 > 0.0f && current * best_denom > best * suml2) { - best = current; - best_denom = suml2; - scale = sumlx / suml2; - if (i == best_p_i + 1) { - // reduce copies for consecutive bests - L[ii] += (x[ii] < 0.0f) != negative_scale ? -1 : 1; - } else { - memcpy(L, Laux, n); + best_p_i = i; + } + } + + return scale; +} + +// like make_qkxh_quants, but doesn't assume the sign of the scale is the sign of the absmax value +static float make_qkxsh_quants(int n, int nmin, int nmax, const float * restrict x, const float * restrict weights, int8_t * restrict L, int8_t * restrict Laux, struct k_heap * restrict k_heap) { + nmin = MIN(0, nmin); + nmax = MAX(0, nmax); + // start at zero + float amax = fabsf(x[0]); + float min = x[0]; + float max = x[0]; + float w_amax = weights[0] * amax; + int w_amax_i = 0; + for (int i = 1; i < n; ++i) { + const float w = weights[i]; + const float ax = fabsf(x[i]); + const float wax = w * ax; + if (ax > amax) { amax = ax; } + if (x[i] > max) { max = x[i]; } + if (x[i] < min) { min = x[i]; } + // Find the most important value + if (wax > w_amax) { w_amax = wax; w_amax_i = i; } + } + + if (amax < GROUP_MAX_EPS) { // all zero + memset(L, 0, n); + return 0.0f; + } + + // Use the side which will clamp first. + // The first clamped value is the absmax at the end of the common range. + int amax_range = MIN(-nmin, nmax); + if (amax_range == 0) { + // One side will always clamp anyway + amax_range = MAX(-nmin, nmax); + } + float sumlx_p = 0.0f; + float suml2_p = 0.0f; + float sumlx_n = 0.0f; + float suml2_n = 0.0f; + float scale = 0.0f; + float best = 0.0f; + float best_denom = 1.0f; + int best_i = -1; // consecutive with 0..n_frac + // Pre-calculate the half-point for the common range. + // All smaller vectors have a representable vector with twice the values, and thus can be skipped. + if (amax_range > 1) { + const float iscale = ((float)((amax_range >> 1) + 1))/amax; + for (int i = 0; i < n; ++i) { + const float w = weights[i]; + int l = MAX(nmin, MIN(lroundf(x[i] * iscale), nmax)); + Laux[i] = l + k_heap->mid_k; + suml2_p += w * l * l; + sumlx_p += w * l * x[i]; + } + sumlx_n = -sumlx_p; + suml2_n = suml2_p; + const float current_p = sumlx_p * sumlx_p; + if (suml2_p > 0.0f) { + best = current_p; + best_denom = suml2_p; + scale = sumlx_p / suml2_p; + best_i = -1; // right before 0 of the loop after sorting + } + } else { + memset(Laux, k_heap->mid_k, n); + } + memcpy(L, Laux, n); + + k_heap_set_x_L(k_heap, x, Laux, n, false); + + // TODO: make that range sign-aware to reduce the search space + const int imax_range = MAX(nmax, -nmin); + const int max_odd = 2*(imax_range + 1) + 1; + const float wmax = fabsf(x[w_amax_i]); + + const float max_common_odd = (MIN(nmax, -nmin) * 2) + 1; + const float max_odd_p = (nmax * 2) + 1; + const float max_odd_n = (-nmin * 2) + 1; + for (int i = 0; k_heap->n > 0; ++i) { + struct fraction frac = k_heap_pop(k_heap); + // maximize the weighted cosine similarity + const int ii = frac.i; + const float w = weights[ii]; + const float lx = w * frac.numer; + const float odd = frac.denom; + const float l2 = w * odd; + if (wmax * odd > frac.numer * max_odd) { + // stop when the inverse scale would result in clamping the most important value + break; + } + + Laux[ii] += x[ii] < 0.0f ? -1 : 1; + + float sumlx; + float proj; + float norm; + if (odd < max_common_odd) { + sumlx_p += lx; + suml2_p += l2; + sumlx_n -= lx; + suml2_n += l2; + + sumlx = sumlx_p; + proj = sumlx_p * sumlx_p; + norm = suml2_p; + + // avoid double-copying Laux in a single iteration + if (suml2_p != suml2_n && suml2_p * suml2_n > 0.0f) { + const float proj_n = sumlx_n * sumlx_n; + if (proj_n * norm > proj * suml2_n) { + sumlx = sumlx_n; + proj = proj_n; + norm = suml2_n; } - best_p_i = i; } - i += 1; + } else if (x[ii] < 0.0f ? odd < max_odd_n : odd < max_odd_p) { + sumlx_p += lx; + suml2_p += l2; + + sumlx = sumlx_p; + proj = sumlx_p * sumlx_p; + norm = suml2_p; + } else { + // outside the positive range means we're now into negatives + sumlx_n -= lx; + suml2_n += l2; + + sumlx = sumlx_n; + proj = sumlx_n * sumlx_n; + norm = suml2_n; + } + if (norm > 0.0f && proj * best_denom > best * norm) { + best = proj; + best_denom = norm; + scale = sumlx / norm; + if (i == best_i + 1) { + // reduce copies for consecutive bests + L[ii] += x[ii] < 0.0f ? -1 : 1; + } else { + memcpy(L, Laux, n); + } + best_i = i; + } + } + + if (scale < 0.0f) { + for (int i = 0; i < n; ++i) { + L[i] = MAX(nmin, MIN(-(L[i] - k_heap->mid_k), nmax)) - nmin; + } + } else { + for (int i = 0; i < n; ++i) { + L[i] = MAX(nmin, MIN(L[i] - k_heap->mid_k, nmax)) - nmin; } } @@ -2898,7 +3050,11 @@ static void quantize_row_tq1_0_impl(const float * restrict x, block_tq1_0 * rest float weight[QK_K]; int8_t L[QK_K]; int8_t Laux[QK_K]; - struct fraction Faux[1 * QK_K]; + struct k_heap_cell heap_cells[QK_K]; + float odd[3]; + struct k_heap k_heap; + + k_heap_init_linear(&k_heap, -1, 1, heap_cells, odd); float sum_x2 = 0; for (int j = 0; j < n_per_row; ++j) { sum_x2 += x[j]*x[j]; } @@ -2910,7 +3066,7 @@ static void quantize_row_tq1_0_impl(const float * restrict x, block_tq1_0 * rest const float * qw = quant_weights + QK_K * ib; const int8_t * Lptr = L; for (int j = 0; j < QK_K; ++j) { weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]); } - float d = make_qkxs_quants(QK_K, -1, 1, xb, weight, L, Laux, Faux, false); + float d = make_qkxh_quants(QK_K, xb, weight, L, Laux, &k_heap, false); y[ib].d = GGML_FP32_TO_FP16(d); // 5 elements per byte, along 32 bytes @@ -2999,7 +3155,11 @@ static void quantize_row_tq2_0_impl(const float * restrict x, block_tq2_0 * rest float weight[QK_K]; int8_t L[QK_K]; int8_t Laux[QK_K]; - struct fraction Faux[2 * QK_K]; + struct k_heap_cell heap_cells[QK_K]; + float odd[4 + 1]; + struct k_heap k_heap; + + k_heap_init_linear(&k_heap, -2, 2, heap_cells, odd); float sum_x2 = 0; for (int j = 0; j < n_per_row; ++j) { sum_x2 += x[j]*x[j]; } @@ -3010,7 +3170,7 @@ static void quantize_row_tq2_0_impl(const float * restrict x, block_tq2_0 * rest const float * xb = x + QK_K * ib; const float * qw = quant_weights + QK_K * ib; for (int j = 0; j < QK_K; ++j) { weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]); } - float d = make_qkxss_quants(QK_K, -1, 2, xb, weight, L, Laux, Faux); + float d = make_qkxsh_quants(QK_K, -1, 2, xb, weight, L, Laux, &k_heap); y[ib].d = GGML_FP32_TO_FP16(d); for (size_t j = 0; j < sizeof(y->qs); j += 32) {