From c709e1a3353cbefbe58320c2eae1a1edafc0f618 Mon Sep 17 00:00:00 2001 From: Ed Addario Date: Sun, 14 Sep 2025 22:38:27 +0100 Subject: [PATCH] Fix MoE tensor estimation --- src/llama-quant.cpp | 45 ++++++++++++++++++++++++++++----------------- 1 file changed, 28 insertions(+), 17 deletions(-) diff --git a/src/llama-quant.cpp b/src/llama-quant.cpp index 41fd819f86..1efb1c5eee 100644 --- a/src/llama-quant.cpp +++ b/src/llama-quant.cpp @@ -1021,27 +1021,38 @@ static std::unordered_map target_bpw_type( }; // Faster to compute but may yield lower precision. Best option for the vast majority of cases - auto fast_lambda = [&](const float * values, const float * activations, const int64_t n_per_row) { + auto fast_lambda = [&](const float * values, const float * activations, const int64_t n_per_row, const int64_t ne2) { if (!activations) { return 0.0f; } - double s = 0.0; - double s2 = 0.0; - for (int64_t j = 0; j < n_per_row; ++j) { - const double w = values ? std::max(0.0f, values[j]) : 1.0; - const double aw = std::sqrt(w) * activations[j]; - const double aw2 = aw * aw; - s += aw2; - s2 += aw2 * aw2; + double accum = 0.0; + int ns = 0; + + for (int64_t s = 0; s < std::max(1, ne2); ++s) { + const float * v = values ? values + s * n_per_row : nullptr; + const float * a = activations + s * n_per_row; + + double s1 = 0.0; + double s2 = 0.0; + for (int64_t j = 0; j < n_per_row; ++j) { + const double w = v ? std::max(0.0f, v[j]) : 1.0; + const double aw = std::sqrt(w) * a[j]; + const double aw2 = aw * aw; + s1 += aw2; + s2 += aw2 * aw2; + } + + if (s1 > 0.0) { + const double n = (double)n_per_row; + double c = std::max(0.0, s2 / (s1 * s1 + epsilon) - 1.0 / n); + double lambda = 8.0 * (c / (c + 1.0)); + accum += std::clamp(lambda, 0.0, 8.0); + ++ns; + } } - if (s2 <= 0.0) { return 0.0f; } - const auto d = (double)n_per_row; - double base = 1.0 - s * s / (d * s2 + epsilon); - base = std::clamp(base, 0.0, 1.0); + if (ns == 0) { return 0.0f; } - const double lambda = std::clamp(base, 0.0, 1.0) * 8.0; - - return (float)lambda; + return (float)(accum / ns); }; std::vector all; @@ -1190,7 +1201,7 @@ static std::unordered_map target_bpw_type( const float * values = values_sample.empty() ? nullptr : values_sample.data(); const float * activations = activations_sample.empty() ? nullptr : activations_sample.data(); if (params->bpw_bias == 1) { - bias_lambda = fast_lambda(values, activations, n_per_row); + bias_lambda = fast_lambda(values, activations, n_per_row, ne2); } else if (params->bpw_bias == 2) { bias_lambda = precise_lambda(t, f32_sample, sample_rows_per_slice, values, activations, compatible_candidates); }