Improve precise_lambda() efficiency
This commit is contained in:
parent
bc8762f27f
commit
4dff85fbe5
|
|
@ -725,7 +725,9 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
|
||||||
const float * activations_sample,
|
const float * activations_sample,
|
||||||
std::vector<uint8_t> & quantized_buffer,
|
std::vector<uint8_t> & quantized_buffer,
|
||||||
std::vector<float> & dequantized_buffer,
|
std::vector<float> & dequantized_buffer,
|
||||||
float bias_lambda) -> double
|
float bias_lambda,
|
||||||
|
double * out_mse = nullptr,
|
||||||
|
double * out_proj = nullptr) -> double
|
||||||
{
|
{
|
||||||
const int64_t n_per_row = t->ne[0];
|
const int64_t n_per_row = t->ne[0];
|
||||||
const int64_t nrows = t->ne[1];
|
const int64_t nrows = t->ne[1];
|
||||||
|
|
@ -733,13 +735,23 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
|
||||||
|
|
||||||
const size_t sample_element_count = f32_sample.size();
|
const size_t sample_element_count = f32_sample.size();
|
||||||
const size_t sample_row_count = n_per_row > 0 ? sample_element_count / (size_t)n_per_row : 0;
|
const size_t sample_row_count = n_per_row > 0 ? sample_element_count / (size_t)n_per_row : 0;
|
||||||
if (sample_row_count == 0) { return 0.0; }
|
if (sample_row_count == 0) {
|
||||||
|
if (out_mse) { *out_mse = 0.0; }
|
||||||
|
if (out_proj) { *out_proj = 0.0; }
|
||||||
|
|
||||||
|
return 0.0;
|
||||||
|
}
|
||||||
|
|
||||||
size_t expected_rows = 0;
|
size_t expected_rows = 0;
|
||||||
for (int64_t s = 0; s < ne2; ++s) {
|
for (int64_t s = 0; s < ne2; ++s) {
|
||||||
expected_rows += (size_t)sample_rows_per_slice[s];
|
expected_rows += (size_t)sample_rows_per_slice[s];
|
||||||
}
|
}
|
||||||
if (expected_rows != sample_row_count) { return infinity; }
|
if (expected_rows != sample_row_count) {
|
||||||
|
if (out_mse) { *out_mse = infinity; }
|
||||||
|
if (out_proj) { *out_proj = 0.0; }
|
||||||
|
|
||||||
|
return infinity;
|
||||||
|
}
|
||||||
|
|
||||||
const size_t row_sz = ggml_row_size(quant_type, n_per_row);
|
const size_t row_sz = ggml_row_size(quant_type, n_per_row);
|
||||||
const size_t buffer_sz = row_sz * sample_row_count;
|
const size_t buffer_sz = row_sz * sample_row_count;
|
||||||
|
|
@ -750,7 +762,7 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
|
||||||
const bool has_values = values_sample != nullptr;
|
const bool has_values = values_sample != nullptr;
|
||||||
const bool has_activations = activations_sample != nullptr;
|
const bool has_activations = activations_sample != nullptr;
|
||||||
|
|
||||||
// Bias denominators per slice (only needed if we have activations)
|
// Bias denominators per slice
|
||||||
std::vector<double> bias_denominator_per_slice(ne2, 0.0);
|
std::vector<double> bias_denominator_per_slice(ne2, 0.0);
|
||||||
if (has_activations) {
|
if (has_activations) {
|
||||||
for (int64_t s = 0; s < ne2; ++s) {
|
for (int64_t s = 0; s < ne2; ++s) {
|
||||||
|
|
@ -815,7 +827,6 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
|
||||||
// quantized_buffer -> dequantized_buffer
|
// quantized_buffer -> dequantized_buffer
|
||||||
{
|
{
|
||||||
const ggml_type_traits * traits = ggml_get_type_traits(quant_type);
|
const ggml_type_traits * traits = ggml_get_type_traits(quant_type);
|
||||||
|
|
||||||
const bool is_fp16 = quant_type == GGML_TYPE_F16;
|
const bool is_fp16 = quant_type == GGML_TYPE_F16;
|
||||||
const bool is_bf16 = quant_type == GGML_TYPE_BF16;
|
const bool is_bf16 = quant_type == GGML_TYPE_BF16;
|
||||||
if (!is_fp16 && !is_bf16 && traits && traits->to_float) {
|
if (!is_fp16 && !is_bf16 && traits && traits->to_float) {
|
||||||
|
|
@ -825,12 +836,19 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
|
||||||
uint8_t * src = quantized_buffer.data() + r * row_sz;
|
uint8_t * src = quantized_buffer.data() + r * row_sz;
|
||||||
float * dst = dequantized_buffer.data() + r * (size_t) n_per_row;
|
float * dst = dequantized_buffer.data() + r * (size_t) n_per_row;
|
||||||
if (is_fp16) {
|
if (is_fp16) {
|
||||||
ggml_fp16_to_fp32_row((const ggml_fp16_t *) src, dst, (int)n_per_row);
|
ggml_fp16_to_fp32_row((const ggml_fp16_t *) src, dst, (int) n_per_row);
|
||||||
} else if (is_bf16) {
|
}
|
||||||
ggml_bf16_to_fp32_row((const ggml_bf16_t *) src, dst, (int)n_per_row);
|
else if (is_bf16) {
|
||||||
} else {
|
ggml_bf16_to_fp32_row((const ggml_bf16_t *) src, dst, (int) n_per_row);
|
||||||
if (!traits || !traits->to_float) { return infinity; }
|
}
|
||||||
traits->to_float(src, dst, (int)n_per_row);
|
else {
|
||||||
|
if (!traits || !traits->to_float) {
|
||||||
|
if (out_mse) { *out_mse = infinity; }
|
||||||
|
if (out_proj) { *out_proj = 0.0; }
|
||||||
|
|
||||||
|
return infinity;
|
||||||
|
}
|
||||||
|
traits->to_float(src, dst, (int) n_per_row);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -839,8 +857,8 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
|
||||||
// Compute error
|
// Compute error
|
||||||
size_t offset = 0;
|
size_t offset = 0;
|
||||||
size_t row_idx = 0;
|
size_t row_idx = 0;
|
||||||
double total_err = 0.0;
|
double total_mse = 0.0;
|
||||||
|
double total_proj = 0.0;
|
||||||
for (int64_t slice = 0; slice < ne2; ++slice) {
|
for (int64_t slice = 0; slice < ne2; ++slice) {
|
||||||
const int64_t rs = sample_rows_per_slice[slice];
|
const int64_t rs = sample_rows_per_slice[slice];
|
||||||
if (rs == 0) { continue; }
|
if (rs == 0) { continue; }
|
||||||
|
|
@ -848,7 +866,11 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
|
||||||
const float * values = has_values ? values_sample + slice * n_per_row : nullptr;
|
const float * values = has_values ? values_sample + slice * n_per_row : nullptr;
|
||||||
const float * activations = has_activations ? activations_sample + slice * n_per_row : nullptr;
|
const float * activations = has_activations ? activations_sample + slice * n_per_row : nullptr;
|
||||||
const double bias_denom = has_activations ? bias_denominator_per_slice[slice] : 0.0;
|
const double bias_denom = has_activations ? bias_denominator_per_slice[slice] : 0.0;
|
||||||
double slice_err = 0.0;
|
std::vector<double> row_mse_norm;
|
||||||
|
std::vector<double> row_proj_norm;
|
||||||
|
row_mse_norm.reserve(rs);
|
||||||
|
if (activations) { row_proj_norm.reserve(rs); }
|
||||||
|
|
||||||
for (int64_t r = 0; r < rs; ++r, ++row_idx) {
|
for (int64_t r = 0; r < rs; ++r, ++row_idx) {
|
||||||
const float * x = f32_sample.data() + offset;
|
const float * x = f32_sample.data() + offset;
|
||||||
const float * y = dequantized_buffer.data() + offset;
|
const float * y = dequantized_buffer.data() + offset;
|
||||||
|
|
@ -868,13 +890,6 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
|
||||||
const double e = y[j] - x[j];
|
const double e = y[j] - x[j];
|
||||||
weighted_mse += w * e * e;
|
weighted_mse += w * e * e;
|
||||||
}
|
}
|
||||||
} else if (activations) {
|
|
||||||
for (int64_t j = 0; j < n_per_row; ++j) {
|
|
||||||
const double e = y[j] - x[j];
|
|
||||||
const double a = activations[j];
|
|
||||||
weighted_mse += e * e;
|
|
||||||
bias_num += e * a;
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
for (int64_t j = 0; j < n_per_row; ++j) {
|
for (int64_t j = 0; j < n_per_row; ++j) {
|
||||||
const double e = y[j] - x[j];
|
const double e = y[j] - x[j];
|
||||||
|
|
@ -882,28 +897,64 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
double err_num = weighted_mse;
|
const double denom_x = row_sq_norm[row_idx];
|
||||||
if (activations && bias_lambda != 0.0f) {
|
double m_norm = weighted_mse / (denom_x + epsilon);
|
||||||
|
row_mse_norm.push_back(std::isfinite(m_norm) ? m_norm : infinity);
|
||||||
|
|
||||||
|
if (activations) {
|
||||||
|
double p_norm = 0.0;
|
||||||
if (bias_denom > 0.0) {
|
if (bias_denom > 0.0) {
|
||||||
const double proj = bias_num * bias_num / (bias_denom + epsilon);
|
const double proj = bias_num * bias_num / (bias_denom + epsilon);
|
||||||
err_num += bias_lambda * proj;
|
p_norm = std::isfinite(proj) ? proj : 0.0;
|
||||||
}
|
}
|
||||||
|
row_proj_norm.push_back(p_norm);
|
||||||
}
|
}
|
||||||
|
|
||||||
const double denom = row_sq_norm[row_idx] + epsilon;
|
|
||||||
slice_err += err_num / denom;
|
|
||||||
offset += (size_t)n_per_row;
|
offset += (size_t)n_per_row;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Trimmed sum to avoid outlier rows dominating the results
|
||||||
|
auto trimmed_sum = [&](std::vector<double> & v) -> double {
|
||||||
|
if (v.empty()) { return 0.0; }
|
||||||
|
const int64_t n = (int64_t)v.size();
|
||||||
|
if (n < 50) {
|
||||||
|
double s = 0.0;
|
||||||
|
for (const double z : v) { s += z; }
|
||||||
|
return s;
|
||||||
|
}
|
||||||
|
|
||||||
|
int64_t k = (int64_t) std::floor(0.02 * (double)n); // trim 2% on each side
|
||||||
|
k = std::max<int64_t>(0, std::min<int64_t>(k, n / 32)); // but not more than 3.125%
|
||||||
|
std::nth_element(v.begin(), v.begin() + k, v.end());
|
||||||
|
std::nth_element(v.begin() + k, v.begin() + (n - k), v.end());
|
||||||
|
double s = 0.0;
|
||||||
|
for (int64_t i = k; i < n - k; ++i) {
|
||||||
|
s += v[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
return s;
|
||||||
|
};
|
||||||
|
|
||||||
const double scale_rows = (double)nrows / std::max(1.0, (double)rs);
|
const double scale_rows = (double)nrows / std::max(1.0, (double)rs);
|
||||||
total_err += slice_err * scale_rows;
|
|
||||||
if (!std::isfinite(total_err)) { return infinity; }
|
total_mse += trimmed_sum(row_mse_norm) * scale_rows;
|
||||||
|
if (activations) { total_proj += trimmed_sum(row_proj_norm) * scale_rows; }
|
||||||
|
|
||||||
|
if (!std::isfinite(total_mse) || !std::isfinite(total_proj)) {
|
||||||
|
if (out_mse) { *out_mse = infinity; }
|
||||||
|
if (out_proj) { *out_proj = 0.0; }
|
||||||
|
|
||||||
|
return infinity;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (out_mse) { *out_mse = total_mse; }
|
||||||
|
if (out_proj) { *out_proj = total_proj; }
|
||||||
|
|
||||||
|
const double total_err = total_mse + bias_lambda * total_proj;
|
||||||
return std::isfinite(total_err) ? total_err : infinity;
|
return std::isfinite(total_err) ? total_err : infinity;
|
||||||
};
|
};
|
||||||
|
|
||||||
// Higher precision but much longer to compute
|
// Higher precision but longer to compute
|
||||||
auto precise_lambda = [&](const ggml_tensor * t,
|
auto precise_lambda = [&](const ggml_tensor * t,
|
||||||
const std::vector<float> & f32_sample,
|
const std::vector<float> & f32_sample,
|
||||||
const std::vector<int64_t> & sample_rows_per_slice,
|
const std::vector<int64_t> & sample_rows_per_slice,
|
||||||
|
|
@ -936,22 +987,17 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
|
||||||
const int64_t n_per_row = t->ne[0];
|
const int64_t n_per_row = t->ne[0];
|
||||||
const size_t total_sampled_rows = f32_sample.size() / n_per_row;
|
const size_t total_sampled_rows = f32_sample.size() / n_per_row;
|
||||||
size_t max_row_sz = 0;
|
size_t max_row_sz = 0;
|
||||||
for (auto pt : probes) {
|
for (auto pt : probes) max_row_sz = std::max(max_row_sz, ggml_row_size(pt, n_per_row));
|
||||||
max_row_sz = std::max(max_row_sz, ggml_row_size(pt, n_per_row));
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<uint8_t> quantized_buffer(max_row_sz * total_sampled_rows);
|
std::vector<uint8_t> quantized_buffer(max_row_sz * total_sampled_rows);
|
||||||
std::vector<float> dequantized_buffer(f32_sample.size());
|
std::vector<float> dequantized_buffer(f32_sample.size());
|
||||||
|
|
||||||
std::vector<double> ratios;
|
std::vector<double> ratios;
|
||||||
ratios.reserve(probes.size());
|
ratios.reserve(probes.size());
|
||||||
for (const auto pt : probes) {
|
for (const auto pt : probes) {
|
||||||
// err at lambda=0 => pure weighted MSE part
|
double m = 0.0;
|
||||||
double err0 = estimate_error(t, pt, f32_sample, sample_rows_per_slice, values, activations, quantized_buffer, dequantized_buffer, 0.0f);
|
double p = 0.0;
|
||||||
// err at lambda=1 => weighted MSE + projection penalty
|
(void)estimate_error(t, pt, f32_sample, sample_rows_per_slice, values, activations, quantized_buffer, dequantized_buffer, 0.0f, &m, &p);
|
||||||
const double err1 = estimate_error(t, pt, f32_sample, sample_rows_per_slice, values, activations, quantized_buffer, dequantized_buffer, 1.0f);
|
|
||||||
|
|
||||||
const double p = std::max(0.0, err1 - err0); // projection term contribution
|
|
||||||
const double m = std::max(0.0, err0); // MSE term contribution
|
|
||||||
if (p > epsilon && std::isfinite(m) && std::isfinite(p)) {
|
if (p > epsilon && std::isfinite(m) && std::isfinite(p)) {
|
||||||
ratios.push_back(m / p);
|
ratios.push_back(m / p);
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue