Refactor estimate_lambda()
This commit is contained in:
parent
bdefdb673c
commit
6b8cedf3bc
|
|
@ -975,30 +975,29 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
|
||||||
};
|
};
|
||||||
|
|
||||||
// Returns one lambda per slice (always at least one entry); every entry is 0.0 if there are no activations
|
// Returns one lambda per slice (always at least one entry); every entry is 0.0 if there are no activations
|
||||||
auto estimate_lambda = [&](const float * values, const float * activations, const int64_t n_per_row, const int64_t ne2) -> std::vector<float>
|
auto estimate_lambda = [&](const float * values, const float * activations, const int64_t n_per_row, const int64_t ne2) -> std::vector<float> {
|
||||||
{
|
const int64_t ns = std::max<int64_t>(1, ne2);
|
||||||
std::vector<float> lambdas(std::max<int64_t>(1, ne2), 0.0f);
|
std::vector<float> lambdas(ns, 0.0f);
|
||||||
if (!activations) { return lambdas; }
|
if (!activations) { return lambdas; }
|
||||||
|
|
||||||
for (int64_t s = 0; s < std::max<int64_t>(1, ne2); ++s) {
|
for (int64_t s = 0; s < ns; ++s) {
|
||||||
const float * v = values ? values + s * n_per_row : nullptr;
|
const float * v = values ? values + s * n_per_row : nullptr;
|
||||||
const float * a = activations + s * n_per_row;
|
const float * a = activations + s * n_per_row;
|
||||||
double s1 = 0.0;
|
double s1 = 0.0;
|
||||||
double s2 = 0.0;
|
double s2 = 0.0;
|
||||||
for (int64_t j = 0; j < n_per_row; ++j) {
|
for (int64_t j = 0; j < n_per_row; ++j) {
|
||||||
const double w = v ? std::max(0.0f, v[j]) : 1.0;
|
const double w = v ? std::max(0.0f, v[j]) : 1.0;
|
||||||
const double aw = std::sqrt(w) * a[j];
|
const double aw2 = std::sqrt(w) * a[j];
|
||||||
const double aw2 = aw * aw;
|
const double z = aw2 * aw2;
|
||||||
s1 += aw2;
|
s1 += z;
|
||||||
s2 += aw2 * aw2;
|
s2 += z * z;
|
||||||
}
|
}
|
||||||
|
|
||||||
float l = 0.0f;
|
float l = 0.0f;
|
||||||
if (s1 > 0.0) {
|
if (s1 > 0.0) {
|
||||||
const auto n = (double)n_per_row;
|
const auto n = (double)n_per_row;
|
||||||
const double c = std::max(0.0, s2 / (s1 * s1 + epsilon) - 1.0 / n);
|
const double c = std::max(0.0, s2 / (s1 * s1 + epsilon) - 1.0 / n);
|
||||||
double lambda = 8.0 * (c / (c + 1.0));
|
l = (float) std::clamp(8.0 * (c / (c + 1.0)), 0.0, 12.0);
|
||||||
l = (float)std::clamp(lambda, 0.0, 12.0);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
lambdas[(size_t)s] = l;
|
lambdas[(size_t)s] = l;
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue