diff --git a/src/llama-quant.cpp b/src/llama-quant.cpp index 3911eba43b..a4a10da062 100644 --- a/src/llama-quant.cpp +++ b/src/llama-quant.cpp @@ -730,7 +730,7 @@ static std::unordered_map target_bpw_type( std::vector f32_sample(sample_rows * n_per_row); std::vector deq(sample_rows * n_per_row); - float total_err = 0.0; + double total_err = 0.0; for (int64_t slice = 0; slice < ne2; ++slice) { const float * value = values_all ? (values_all + slice * n_per_row) : nullptr; @@ -754,9 +754,9 @@ static std::unordered_map target_bpw_type( const float * xs = f32_sample.data() + s * n_per_row; const float * ys = deq.data() + s * n_per_row; - float mse_w = 0.0; - float bias = 0.0; - float bias_sum = 0.0; + double mse_w = 0.0; + double bias = 0.0; + double bias_sum = 0.0; if (value) { for (int64_t j = 0; j < n_per_row; ++j) { @@ -769,19 +769,17 @@ static std::unordered_map target_bpw_type( } else { for (int64_t j = 0; j < n_per_row; ++j) { const float e = ys[j] - xs[j]; - mse_w += e*e; + mse_w += e * e; if (activation) { bias_sum += e * activation[j]; } } } - if (activation) { - bias = std::abs(bias_sum); - } + if (activation) { bias = std::abs(bias_sum); } // Normalize by n_per_row to get a per-row average scale - float row_err = mse_w / std::max(1, n_per_row); + double row_err = mse_w / std::max(1, n_per_row); if (bias_lambda != 0.0) { row_err += bias_lambda * (bias / std::max(1, n_per_row)); } @@ -790,11 +788,11 @@ static std::unordered_map target_bpw_type( } // Scale for the rows we didn't sample in this expert: multiply by stride-ish factor - const float scale_rows = (float)rows_per_expert / std::max(1.0f, (float)rs); + const auto scale_rows = (double)rows_per_expert / std::max(1.0, (double)rs); total_err *= scale_rows; } - return total_err; + return std::isfinite(total_err) ? total_err : 1e35; }; std::vector all;