From 8df1d00ae4042a1eee38c1fc9ac06137d5ce5078 Mon Sep 17 00:00:00 2001 From: Ed Addario Date: Thu, 28 Aug 2025 16:04:28 +0100 Subject: [PATCH] Add directional scaling --- src/llama-quant.cpp | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/src/llama-quant.cpp b/src/llama-quant.cpp index da1267ddbc..a9621eab8e 100644 --- a/src/llama-quant.cpp +++ b/src/llama-quant.cpp @@ -900,6 +900,27 @@ static std::unordered_map target_bpw_type( return std::isfinite(total_err) ? total_err : 1e35; }; + auto directional_scale = [&](const float * values, const float * activations, int64_t n_per_row) { + if (!activations) { return 1.0f; } + // Compute dominance = ||sqrt(v).*a||_2 / (RMS(a)*sqrt(sum(v))) + // If no values, use v=1 + double sum_v = 0.0; + double sum_aw2 = 0.0; + double sum_a2 = 0.0; + for (int64_t j = 0; j < n_per_row; ++j) { + const double v = values ? std::max(0.0f, values[j]) : 1.0; + const double a = activations[j]; + sum_v += v; + sum_aw2 += v * a * a; + sum_a2 += a * a; + } + const double rms_a = std::sqrt(sum_a2 / std::max(1.0, (double)n_per_row)); + const double denom = std::sqrt(std::max(epsilon, sum_v)) * std::max(epsilon, rms_a); + const double scale = denom > 0.0 ? std::sqrt(sum_aw2) / denom : 1.0; + + // Clamp to a reasonable range + return (float)std::clamp(scale, 0.5, 2.0); + }; std::vector all; all.reserve(tensors.size()); for (const auto * tw : tensors) {