Increase precision for error calculation

commit 936294f6af
parent f22b3097eb
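This commit switches the per-row error accumulators in `target_bpw_type` from `float` to `double`. As a minimal standalone sketch of why that matters when summing many tiny squared errors (illustrative values only, not code from this commit): once a `float` running sum grows large, additions smaller than half its ulp are rounded away and the total stalls, while a `double` accumulator keeps registering them.

```cpp
// Standalone illustration, not part of the commit: add 1e8 squared errors of
// ~1e-8 each. The exact total is 1.0, but the float accumulator stalls early.
#include <cstdio>

int main() {
    const long long n = 100000000;   // number of per-element errors
    const float     e = 1e-4f;       // each squared error is ~1e-8

    float  sum_f = 0.0f;
    double sum_d = 0.0;
    for (long long i = 0; i < n; ++i) {
        sum_f += e * e;                  // rounds to a no-op once sum_f is large enough
        sum_d += (double)e * (double)e;  // keeps accumulating
    }

    std::printf("float  accumulator: %.6f\n", sum_f);  // well below 1.0
    std::printf("double accumulator: %.6f\n", sum_d);  // ~1.000000
}
```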
@@ -730,7 +730,7 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
         std::vector<float> f32_sample(sample_rows * n_per_row);
         std::vector<float> deq(sample_rows * n_per_row);
 
-        float total_err = 0.0;
+        double total_err = 0.0;
 
         for (int64_t slice = 0; slice < ne2; ++slice) {
             const float * value = values_all ? (values_all + slice * n_per_row) : nullptr;
@@ -754,9 +754,9 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
                 const float * xs = f32_sample.data() + s * n_per_row;
                 const float * ys = deq.data() + s * n_per_row;
 
-                float mse_w = 0.0;
-                float bias = 0.0;
-                float bias_sum = 0.0;
+                double mse_w = 0.0;
+                double bias = 0.0;
+                double bias_sum = 0.0;
 
                 if (value) {
                     for (int64_t j = 0; j < n_per_row; ++j) {
@@ -769,19 +769,17 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
                 } else {
                     for (int64_t j = 0; j < n_per_row; ++j) {
                         const float e = ys[j] - xs[j];
-                        mse_w += e*e;
+                        mse_w += e * e;
                         if (activation) {
                             bias_sum += e * activation[j];
                         }
                     }
                 }
 
-                if (activation) {
-                    bias = std::abs(bias_sum);
-                }
+                if (activation) { bias = std::abs(bias_sum); }
 
                 // Normalize by n_per_row to get a per-row average scale
-                float row_err = mse_w / std::max<int64_t>(1, n_per_row);
+                double row_err = mse_w / std::max<int64_t>(1, n_per_row);
                 if (bias_lambda != 0.0) {
                     row_err += bias_lambda * (bias / std::max<int64_t>(1, n_per_row));
                 }
@@ -790,11 +788,11 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
             }
 
             // Scale for the rows we didn't sample in this expert: multiply by stride-ish factor
-            const float scale_rows = (float)rows_per_expert / std::max(1.0f, (float)rs);
+            const auto scale_rows = (double)rows_per_expert / std::max(1.0, (double)rs);
             total_err *= scale_rows;
         }
 
-        return total_err;
+        return std::isfinite(total_err) ? total_err : 1e35;
     };
 
     std::vector<tensor_info> all;
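The return value is also clamped with `std::isfinite` to a large finite sentinel. One plausible reading (not stated in the commit): a NaN or infinite row error would otherwise poison any later "pick the candidate with the smallest error" comparison, because NaN compares false against everything. A small sketch with a hypothetical candidate list:

```cpp
// Standalone sketch with a made-up candidate list: why non-finite errors are
// mapped to a huge but finite value before comparing candidates.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

static double safe_err(double err) {
    return std::isfinite(err) ? err : 1e35;   // same guard as in the diff
}

int main() {
    // With NaN first, min_element keeps it: every "x < NaN" test is false.
    std::vector<double> errs = { std::nan(""), 0.12, 0.07 };
    std::printf("raw  min: %f\n", *std::min_element(errs.begin(), errs.end()));  // nan

    for (double & e : errs) { e = safe_err(e); }
    std::printf("safe min: %f\n", *std::min_element(errs.begin(), errs.end()));  // 0.07
}
```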