From 6b8cedf3bcd2282e9f31b00026178d6bb393fc3e Mon Sep 17 00:00:00 2001 From: Ed Addario Date: Sun, 21 Sep 2025 13:42:31 +0100 Subject: [PATCH] Refactor estimate_lambda() --- src/llama-quant.cpp | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/src/llama-quant.cpp b/src/llama-quant.cpp index b1302df431..ebacf68806 100644 --- a/src/llama-quant.cpp +++ b/src/llama-quant.cpp @@ -975,30 +975,29 @@ static std::unordered_map target_bpw_type( }; // Returns lambda per slice or 0.0 if no activations - auto estimate_lambda = [&](const float * values, const float * activations, const int64_t n_per_row, const int64_t ne2) -> std::vector - { - std::vector lambdas(std::max(1, ne2), 0.0f); + auto estimate_lambda = [&](const float * values, const float * activations, const int64_t n_per_row, const int64_t ne2) -> std::vector { + const int64_t ns = std::max(1, ne2); + std::vector lambdas(ns, 0.0f); if (!activations) { return lambdas; } - for (int64_t s = 0; s < std::max(1, ne2); ++s) { + for (int64_t s = 0; s < ns; ++s) { const float * v = values ? values + s * n_per_row : nullptr; const float * a = activations + s * n_per_row; double s1 = 0.0; double s2 = 0.0; for (int64_t j = 0; j < n_per_row; ++j) { const double w = v ? std::max(0.0f, v[j]) : 1.0; - const double aw = std::sqrt(w) * a[j]; - const double aw2 = aw * aw; - s1 += aw2; - s2 += aw2 * aw2; + const double aw2 = std::sqrt(w) * a[j]; + const double z = aw2 * aw2; + s1 += z; + s2 += z * z; } float l = 0.0f; if (s1 > 0.0) { const auto n = (double)n_per_row; const double c = std::max(0.0, s2 / (s1 * s1 + epsilon) - 1.0 / n); - double lambda = 8.0 * (c / (c + 1.0)); - l = (float)std::clamp(lambda, 0.0, 12.0); + l = (float) std::clamp(8.0 * (c / (c + 1.0)), 0.0, 12.0); } lambdas[(size_t)s] = l;