Fix MoE tensor estimation

This commit is contained in:
Ed Addario 2025-09-14 22:38:27 +01:00
parent 8503d59ee4
commit c709e1a335
No known key found for this signature in database
GPG Key ID: E7875815A3230993
1 changed file with 28 additions and 17 deletions

View File

@ -1021,27 +1021,38 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
};
// Faster to compute but may yield lower precision. Best option for the vast majority of cases
//
// NOTE(review): this span looks like a diff view whose +/- markers were stripped, so the
// removed and the added form of some lines sit next to each other (see the two
// `auto fast_lambda` declarations directly below). Read it as two variants of one lambda:
//   - 3-argument variant: a single pass over n_per_row elements of values/activations.
//   - 4-argument variant: one pass per slice (max(1, ne2) slices), per-slice results averaged.
// Both variants return 0.0f when `activations` is null and otherwise a value clamped to [0, 8].
// `epsilon` is captured from the enclosing scope (not visible here); presumably a small
// positive constant that keeps the divisions below well defined - TODO confirm.
// 3-argument variant: no slice count.
auto fast_lambda = [&](const float * values, const float * activations, const int64_t n_per_row) {
// 4-argument variant: ne2 is the number of consecutive n_per_row-sized slices to visit.
auto fast_lambda = [&](const float * values, const float * activations, const int64_t n_per_row, const int64_t ne2) {
// No activations: nothing to measure, report 0.
if (!activations) { return 0.0f; }
// 3-argument variant: running sum of aw2 (this `s` is an accumulator, unrelated to the
// slice index `s` of the loop further down, which belongs to the 4-argument variant).
double s = 0.0;
// 4-argument variant: accum sums the per-slice lambdas, ns counts the slices that contributed.
double accum = 0.0;
int ns = 0;
// At least one slice is visited even when ne2 is zero or negative.
for (int64_t s = 0; s < std::max<int64_t>(1, ne2); ++s) {
// Slice s starts s * n_per_row elements into each buffer; v stays null when no values were given.
const float * v = values ? values + s * n_per_row : nullptr;
const float * a = activations + s * n_per_row;
// s1 accumulates aw2, s2 accumulates aw2 squared (see the inner loop).
double s1 = 0.0;
double s2 = 0.0;
for (int64_t j = 0; j < n_per_row; ++j) {
// 3-argument variant: index the caller's buffers directly.
// Weight is the value clamped at 0 from below, or 1.0 when no values are supplied.
const double w = values ? std::max(0.0f, values[j]) : 1.0;
const double aw = std::sqrt(w) * activations[j];
// 4-argument variant: same computation through the per-slice pointers.
const double w = v ? std::max(0.0f, v[j]) : 1.0;
const double aw = std::sqrt(w) * a[j];
// aw2 is mathematically w * activation^2.
const double aw2 = aw * aw;
s += aw2;
s1 += aw2;
s2 += aw2 * aw2;
}
// 3-argument variant: bail out when the sum of aw2 squared is not positive, then compute
// base = 1 - s^2 / (d * s2 + epsilon) with d = n_per_row, clamped to [0, 1].
if (s2 <= 0.0) { return 0.0f; }
const auto d = (double)n_per_row;
double base = 1.0 - s * s / (d * s2 + epsilon);
base = std::clamp(base, 0.0, 1.0);
// 4-argument variant: slices whose aw2 sum is not positive are skipped and do not count in ns.
if (s1 > 0.0) {
const double n = (double)n_per_row;
// c = s2 / (s1^2 + epsilon) - 1/n, floored at 0.
double c = std::max(0.0, s2 / (s1 * s1 + epsilon) - 1.0 / n);
// Map c to 8 * c / (c + 1) and clamp the result to [0, 8].
double lambda = 8.0 * (c / (c + 1.0));
accum += std::clamp(lambda, 0.0, 8.0);
++ns;
}
}
// 3-argument variant: scale the clamped base to [0, 8].
const double lambda = std::clamp(base, 0.0, 1.0) * 8.0;
// 4-argument variant: no slice contributed, report 0.
if (ns == 0) { return 0.0f; }
// 3-argument variant result.
return (float)lambda;
// 4-argument variant result: mean of the per-slice lambdas over the contributing slices.
return (float)(accum / ns);
};
std::vector<tensor_info> all;
@ -1190,7 +1201,7 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
const float * values = values_sample.empty() ? nullptr : values_sample.data();
const float * activations = activations_sample.empty() ? nullptr : activations_sample.data();
if (params->bpw_bias == 1) {
bias_lambda = fast_lambda(values, activations, n_per_row);
bias_lambda = fast_lambda(values, activations, n_per_row, ne2);
} else if (params->bpw_bias == 2) {
bias_lambda = precise_lambda(t, f32_sample, sample_rows_per_slice, values, activations, compatible_candidates);
}