Refactor estimate_error()
This commit is contained in:
parent
f75265f55b
commit
73124a9921
|
|
@ -742,38 +742,33 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
|
|||
const size_t sample_row_count = sample_element_count / (size_t)n_per_row;
|
||||
if (sample_row_count == 0) { return 0.0; }
|
||||
|
||||
const size_t row_size = ggml_row_size(quant_type, n_per_row);
|
||||
const size_t buffer_size = row_size * sample_row_count;
|
||||
if (quantized_buffer.size() < buffer_size) { quantized_buffer.resize(buffer_size); }
|
||||
const size_t row_sz = ggml_row_size(quant_type, n_per_row);
|
||||
const size_t buffer_sz = row_sz * sample_row_count;
|
||||
|
||||
if (quantized_buffer.size() < buffer_sz) { quantized_buffer.resize(buffer_sz); }
|
||||
if (dequantized_buffer.size() < sample_element_count) { dequantized_buffer.resize(sample_element_count); }
|
||||
|
||||
std::vector<double> row_sq_norm(sample_row_count, 0.0);
|
||||
std::vector<double> bias_denominator_per_slice(ne2, 0.0);
|
||||
const bool has_values = values_sample != nullptr;
|
||||
const bool has_activations = activations_sample != nullptr;
|
||||
|
||||
// Precompute bias denominator per slice
|
||||
const bool has_values = (values_sample != nullptr);
|
||||
const bool has_activations = (activations_sample != nullptr);
|
||||
// Bias denominators per slice (only needed if we have activations)
|
||||
std::vector<double> bias_denominator_per_slice(ne2, 0.0);
|
||||
if (has_activations) {
|
||||
for (int64_t s = 0; s < ne2; ++s) {
|
||||
const float * values = has_values ? values_sample + s * n_per_row : nullptr;
|
||||
const float * activations = activations_sample + s * n_per_row;
|
||||
double bias_denominator = 0.0;
|
||||
if (has_values) {
|
||||
for (int64_t j = 0; j < n_per_row; ++j) {
|
||||
const double a = activations[j];
|
||||
bias_denominator += values[j] * a * a;
|
||||
}
|
||||
} else {
|
||||
for (int64_t j = 0; j < n_per_row; ++j) {
|
||||
const double a = activations[j];
|
||||
bias_denominator += a * a;
|
||||
}
|
||||
double denom = 0.0;
|
||||
for (int64_t j = 0; j < n_per_row; ++j) {
|
||||
const double a = activations[j];
|
||||
const double w = values ? values[j] : 1.0;
|
||||
denom += w * a * a;
|
||||
}
|
||||
bias_denominator_per_slice[s] = bias_denominator;
|
||||
bias_denominator_per_slice[s] = denom;
|
||||
}
|
||||
}
|
||||
|
||||
// Compute squared norms of sampled rows
|
||||
// Compute per-row squared norms with weighting (if values are provided)
|
||||
std::vector<double> row_sq_norm(sample_row_count, 0.0);
|
||||
{
|
||||
size_t offset = 0;
|
||||
size_t row_idx = 0;
|
||||
|
|
@ -784,18 +779,18 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
|
|||
const float * values = has_values ? values_sample + s * n_per_row : nullptr;
|
||||
|
||||
for (int64_t r = 0; r < rs; ++r, ++row_idx) {
|
||||
const float * row = f32_sample.data() + offset;
|
||||
const float * x = f32_sample.data() + offset;
|
||||
double rsn = 0.0;
|
||||
if (has_values) {
|
||||
if (values) {
|
||||
for (int64_t j = 0; j < n_per_row; ++j) {
|
||||
const double v = values[j];
|
||||
const double x = row[j];
|
||||
rsn += v * x * x;
|
||||
const double v = values[j];
|
||||
const double xx = x[j];
|
||||
rsn += v * xx * xx;
|
||||
}
|
||||
} else {
|
||||
for (int64_t j = 0; j < n_per_row; ++j) {
|
||||
const double x = row[j];
|
||||
rsn += x * x;
|
||||
const double xx = x[j];
|
||||
rsn += xx * xx;
|
||||
}
|
||||
}
|
||||
row_sq_norm[row_idx] = rsn;
|
||||
|
|
@ -805,35 +800,44 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
|
|||
}
|
||||
|
||||
// Quantize sampled rows slice-by-slice into quantized_buffer
|
||||
size_t quantised_offset = 0;
|
||||
size_t floats_offset = 0;
|
||||
for (int64_t slice = 0; slice < ne2; ++slice) {
|
||||
const int64_t rs = sample_rows_per_slice[slice];
|
||||
if (rs == 0) { continue; }
|
||||
{
|
||||
size_t q_offset = 0;
|
||||
size_t f_offset = 0;
|
||||
for (int64_t slice = 0; slice < ne2; ++slice) {
|
||||
const int64_t rs = sample_rows_per_slice[slice];
|
||||
if (rs == 0) { continue; }
|
||||
|
||||
const float * value = values_sample ? values_sample + slice * n_per_row : nullptr;
|
||||
(void)ggml_quantize_chunk(quant_type, f32_sample.data() + floats_offset, quantized_buffer.data() + quantised_offset, 0, rs, n_per_row, value);
|
||||
const float * value = has_values ? values_sample + slice * n_per_row : nullptr;
|
||||
(void)ggml_quantize_chunk(quant_type, f32_sample.data() + f_offset, quantized_buffer.data() + q_offset, 0, rs, n_per_row, value);
|
||||
|
||||
quantised_offset += row_size * (size_t)rs;
|
||||
floats_offset += (size_t)rs * (size_t)n_per_row;
|
||||
q_offset += row_sz * (size_t)rs;
|
||||
f_offset += (size_t)rs * (size_t)n_per_row;
|
||||
}
|
||||
}
|
||||
|
||||
// Dequantize into dequantized_buffer
|
||||
{
|
||||
const ggml_type_traits * traits = ggml_get_type_traits(quant_type);
|
||||
if (quant_type == GGML_TYPE_F16) {
|
||||
ggml_fp16_to_fp32_row((const ggml_fp16_t *)quantized_buffer.data(), dequantized_buffer.data(), (int)sample_element_count);
|
||||
} else if (quant_type == GGML_TYPE_BF16) {
|
||||
ggml_bf16_to_fp32_row((const ggml_bf16_t *)quantized_buffer.data(), dequantized_buffer.data(), (int)sample_element_count);
|
||||
} else {
|
||||
if (!traits || !traits->to_float) {
|
||||
LLAMA_LOG_WARN("%s: unsupported quantization type %s\n", __func__, ggml_type_name(quant_type));
|
||||
return 1e35;
|
||||
}
|
||||
const size_t row_size = ggml_row_size(quant_type, n_per_row);
|
||||
for (size_t r = 0; r < sample_row_count; ++r) {
|
||||
traits->to_float(quantized_buffer.data() + r * row_size, dequantized_buffer.data() + r * n_per_row, (int)n_per_row);
|
||||
auto row_to_float = [&](size_t r) {
|
||||
uint8_t * src = quantized_buffer.data() + r * row_sz;
|
||||
float * dst = dequantized_buffer.data() + r * (size_t)n_per_row;
|
||||
if (quant_type == GGML_TYPE_F16) {
|
||||
ggml_fp16_to_fp32_row((const ggml_fp16_t *)src, dst, (int)n_per_row);
|
||||
} else if (quant_type == GGML_TYPE_BF16) {
|
||||
ggml_bf16_to_fp32_row((const ggml_bf16_t *)src, dst, (int)n_per_row);
|
||||
} else {
|
||||
if (!traits || !traits->to_float) {
|
||||
LLAMA_LOG_WARN("%s: unsupported quantization type %s\n", __func__, ggml_type_name(quant_type));
|
||||
return false;
|
||||
}
|
||||
traits->to_float(src, dst, (int)n_per_row);
|
||||
}
|
||||
|
||||
return true;
|
||||
};
|
||||
|
||||
for (size_t r = 0; r < sample_row_count; ++r) {
|
||||
if (!row_to_float(r)) { return 1e35; }
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -847,20 +851,22 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
|
|||
|
||||
const float * values = has_values ? values_sample + slice * n_per_row : nullptr;
|
||||
const float * activations = has_activations ? activations_sample + slice * n_per_row : nullptr;
|
||||
const double bias_denominator = has_activations ? bias_denominator_per_slice[slice] : 0.0;
|
||||
const double bias_denom = has_activations ? bias_denominator_per_slice[slice] : 0.0;
|
||||
|
||||
double slice_err = 0.0;
|
||||
|
||||
for (int64_t r = 0; r < rs; ++r, ++row_idx) {
|
||||
const float * x = f32_sample.data() + offset;
|
||||
const float * y = dequantized_buffer.data() + offset;
|
||||
double weighted_mse = 0.0;
|
||||
double bias_numerator = 0.0;
|
||||
double bias_num = 0.0;
|
||||
if (values && activations) {
|
||||
for (int64_t j = 0; j < n_per_row; ++j) {
|
||||
const double v = values[j];
|
||||
const double e = y[j] - x[j];
|
||||
const double a = activations[j];
|
||||
weighted_mse += v * e * e;
|
||||
bias_numerator += v * e * a;
|
||||
bias_num += v * e * a;
|
||||
}
|
||||
} else if (values) {
|
||||
for (int64_t j = 0; j < n_per_row; ++j) {
|
||||
|
|
@ -873,7 +879,7 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
|
|||
const double e = y[j] - x[j];
|
||||
const double a = activations[j];
|
||||
weighted_mse += e * e;
|
||||
bias_numerator += e * a;
|
||||
bias_num += e * a;
|
||||
}
|
||||
} else {
|
||||
for (int64_t j = 0; j < n_per_row; ++j) {
|
||||
|
|
@ -882,24 +888,19 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
|
|||
}
|
||||
}
|
||||
|
||||
double err_numerator = weighted_mse;
|
||||
constexpr float bias_lambda = 1.75f;
|
||||
constexpr double epsilon = 1e-12;
|
||||
constexpr float bias_lambda = 1.0;
|
||||
//bias_lambda defines the weight of the bias term in the weigthed MSE error function
|
||||
// 0.0 means no bias (standard MSE) 1.0 means equal weight for bias and error,
|
||||
// 2.0 means twice as much weight for bias, etc. Default is 1.0.
|
||||
if (activations && bias_lambda != 0.0) {
|
||||
const double proj = bias_numerator * bias_numerator / (bias_denominator + epsilon);
|
||||
err_numerator += bias_lambda * proj;
|
||||
double err_num = weighted_mse;
|
||||
if (activations && bias_lambda != 0.0f) {
|
||||
const double proj = bias_num * bias_num / (bias_denom + epsilon);
|
||||
err_num += (double)bias_lambda * proj;
|
||||
}
|
||||
|
||||
const double err_denominator = row_sq_norm[row_idx] + epsilon;
|
||||
const double row_err = err_numerator / err_denominator;
|
||||
slice_err += row_err;
|
||||
const double err_den = row_sq_norm[row_idx] + epsilon;
|
||||
slice_err += err_num / err_den;
|
||||
offset += (size_t)n_per_row;
|
||||
}
|
||||
|
||||
// scale to full rows (nrows)
|
||||
const double scale_rows = (double)nrows / std::max(1.0, (double)rs);
|
||||
total_err += slice_err * scale_rows;
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue