Add target_bpw_type() logic

Ed Addario 2025-08-19 11:05:01 +01:00
parent 017945a3b2
commit 92f49ab399
1 changed file with 482 additions and 0 deletions


@@ -575,6 +575,488 @@ static size_t llama_tensor_quantize_impl(enum ggml_type new_type, const float *
return new_size;
}
// Returns per-tensor overrides of quantization types that meet the target BPW with the best expected quality.
// values_data:      map from tensor name -> vector of length ne[0] * ne[2] with per-column E[a^2] per expert (the imatrix)
// activations_data: optional map from tensor name -> vector of length ne[0] * ne[2] with per-column E[a] per expert
// bias_lambda:      relative weight of the bias term (|sum e_j * E[a_j]|) vs the MSE term (sum e_j^2 * E[a_j^2])
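// For a sampled row with per-weight error e_j = dequant_j - orig_j, the proxy used below is
//   row_err = (1/n) * ( sum_j e_j^2 * E[a_j^2] + bias_lambda * | sum_j e_j * E[a_j] | )
// falling back to an unweighted MSE when no imatrix entry exists, and dropping the bias term
// when activations are not provided.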
static std::unordered_map<std::string, ggml_type> target_bpw_type(
llama_model_loader & ml,
std::vector<no_init<uint8_t>> & read_data,
const llama_model & model,
const std::vector<const llama_model_loader::llama_tensor_weight *> & tensors,
const std::map<int, std::string> & mapped,
const std::unordered_map<std::string, std::vector<float>> * values_data,
const std::unordered_map<std::string, std::vector<float>> * activations_data,
float target_bpw,
int nthread,
int sample_rows_per_expert = 128,
float bias_lambda = 1.0
) {
struct candidate_types {
ggml_type type;
float bpw;
size_t bytes;
float error; // lower is better
};
struct tensor_info {
const llama_model_loader::llama_tensor_weight * w;
std::vector<candidate_types> candidate; // sorted by bpw ascending
int choice = -1; // index into candidate
float min_bpw = 0.0;
float max_bpw = 0.0;
size_t n_elements = 0;
};
auto name_tn = LLM_TN(model.arch);
// The candidate types we consider; adjust as needed
const ggml_type base_candidates[] = {
// Model's
GGML_TYPE_IQ1_S,
GGML_TYPE_IQ1_M,
GGML_TYPE_IQ2_XXS,
GGML_TYPE_IQ2_XS,
GGML_TYPE_IQ2_S,
GGML_TYPE_IQ3_XXS,
GGML_TYPE_IQ3_S,
GGML_TYPE_IQ4_XS,
GGML_TYPE_IQ4_NL,
GGML_TYPE_Q2_K,
GGML_TYPE_Q3_K,
GGML_TYPE_Q4_0,
GGML_TYPE_Q4_1,
GGML_TYPE_Q4_K,
GGML_TYPE_Q5_0,
GGML_TYPE_Q5_1,
GGML_TYPE_Q5_K,
GGML_TYPE_Q6_K,
GGML_TYPE_Q8_0
};
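// Restrict the search to 2D+ "*.weight" tensors that are actual matrix-multiplication weights; norm weights,
// MoE router gates, positional/token-type embeddings and SSM/RWKV mixing tensors are left to the default
// type-selection logic.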
auto can_quantize = [&](const ggml_tensor * t) -> bool {
const std::string name = ggml_get_name(t);
bool q = name.rfind("weight") == name.size() - 6;
q &= (ggml_n_dims(t) >= 2);
q &= name.find("_norm.weight") == std::string::npos;
//q &= name != name_tn(LLM_TENSOR_TOKEN_EMBD, "weight");
//q &= name != name_tn(LLM_TENSOR_OUTPUT, "weight");
q &= name.find("ffn_gate_inp.weight") == std::string::npos;
q &= name.find("altup") == std::string::npos;
q &= name.find("laurel") == std::string::npos;
q &= name.find("per_layer_model_proj") == std::string::npos;
q &= name != name_tn(LLM_TENSOR_POS_EMBD, "weight");
q &= name != name_tn(LLM_TENSOR_TOKEN_TYPES, "weight");
q &= name.find("ssm_conv1d.weight") == std::string::npos;
q &= name.find("shortconv.conv.weight") == std::string::npos;
q &= name.find("time_mix_first.weight") == std::string::npos;
q &= name.find("time_mix_w0.weight") == std::string::npos;
q &= name.find("time_mix_w1.weight") == std::string::npos;
q &= name.find("time_mix_w2.weight") == std::string::npos;
q &= name.find("time_mix_v0.weight") == std::string::npos;
q &= name.find("time_mix_v1.weight") == std::string::npos;
q &= name.find("time_mix_v2.weight") == std::string::npos;
q &= name.find("time_mix_a0.weight") == std::string::npos;
q &= name.find("time_mix_a1.weight") == std::string::npos;
q &= name.find("time_mix_a2.weight") == std::string::npos;
q &= name.find("time_mix_g1.weight") == std::string::npos;
q &= name.find("time_mix_g2.weight") == std::string::npos;
q &= name.find("time_mix_decay_w1.weight") == std::string::npos;
q &= name.find("time_mix_decay_w2.weight") == std::string::npos;
q &= name.find("time_mix_lerp_fused.weight") == std::string::npos;
q &= name.find("attn_rel_b.weight") == std::string::npos;
return q;
};
auto get_values = [&](const std::string & tensor_name) -> const float * {
if (!values_data) { return nullptr; }
const auto it = values_data->find(remap_imatrix(tensor_name, mapped));
if (it == values_data->end()) { return nullptr; }
return it->second.data();
};
auto get_activations = [&](const std::string & tensor_name) -> const float * {
if (!activations_data) { return nullptr; }
const auto it = activations_data->find(remap_imatrix(tensor_name, mapped));
if (it == activations_data->end()) { return nullptr; }
return it->second.data();
};
auto total_bytes = [](const ggml_tensor * t, const ggml_type typ) -> size_t {
const int64_t n_per_row = t->ne[0];
const int64_t nrows = t->ne[1];
const int64_t ne2 = t->ne[2] > 0 ? t->ne[2] : 1;
const size_t row_sz = ggml_row_size(typ, n_per_row);
return (size_t)ne2 * (size_t)nrows * row_sz;
};
auto tensor_bpw = [&](const ggml_tensor * t, const ggml_type typ) -> double {
const int64_t nelem = ggml_nelements(t);
const size_t bytes = total_bytes(t, typ);
return bytes * 8.0 / nelem;
};
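// Quantized types pack values in fixed-size blocks along a row, so a candidate type is only usable when
// ne[0] is a multiple of its block size (e.g. 256 for K-quants and most IQ types, 32 for Q4_0/Q8_0).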
auto is_compatible = [&](const ggml_tensor * t, const ggml_type typ) -> bool {
const int64_t n_per_row = t->ne[0];
const int64_t blck = ggml_blck_size(typ);
if (blck <= 1) { return true; } // non-blocked types (F32/F16/BF16)
return n_per_row % blck == 0;
};
auto make_compatible = [&](const ggml_tensor * t, const ggml_type typ) -> ggml_type {
if (is_compatible(t, typ)) { return typ; }
ggml_type fb = fallback_type(typ);
if (is_compatible(t, fb)) { return fb; }
return GGML_TYPE_F16; // final guard
};
// Estimate error for a given type using a sampled subset of rows.
// Uses both imatrix (E[a^2]) and activations (E[a]) if available.
auto estimate_error = [&](const ggml_tensor * t, const float * f32_data, const ggml_type typ, const float * values_all, const float * activations_all) -> double {
const int64_t n_per_row = t->ne[0];
const int64_t nrows = t->ne[1];
const int64_t ne2 = t->ne[2] > 0 ? t->ne[2] : 1;
const ggml_type_traits * traits = ggml_get_type_traits(typ);
if (!traits || !traits->to_float) {
// cannot dequantize candidate -> assign very high error
return 1e35f;
}
// Sampling plan: for each expert slice, take up to sample_rows rows spread uniformly
const int64_t rows_per_expert = nrows;
const int64_t sample_rows = std::max<int64_t>(1, std::min<int64_t>(rows_per_expert, sample_rows_per_expert));
const int64_t stride = std::max<int64_t>(1, rows_per_expert / sample_rows);
const size_t row_sz = ggml_row_size(typ, n_per_row);
std::vector<uint8_t> qbuf(row_sz * sample_rows);
std::vector<float> f32_sample(sample_rows * n_per_row);
std::vector<float> deq(sample_rows * n_per_row);
float total_err = 0.0;
for (int64_t i03 = 0; i03 < ne2; ++i03) {
const float * value = values_all ? (values_all + i03 * n_per_row) : nullptr;
const float * activation = activations_all ? (activations_all + i03 * n_per_row) : nullptr;
// Assemble sampled rows into contiguous f32_sample
int64_t rs = 0;
for (int64_t r = 0; r < rows_per_expert && rs < sample_rows; r += stride) {
const float * src = f32_data + i03 * (n_per_row * rows_per_expert) + r * n_per_row;
std::memcpy(f32_sample.data() + rs * n_per_row, src, sizeof(float) * n_per_row);
++rs;
}
if (rs == 0) { continue; }
// Quantize sampled rows in one chunk; pass the imatrix for this expert slice
const size_t got = ggml_quantize_chunk(typ, f32_sample.data(), qbuf.data(), 0, rs, n_per_row, value);
(void)got; // not strictly needed here
// Dequantize
traits->to_float(qbuf.data(), deq.data(), rs * n_per_row);
// Compute error proxy per sampled row
for (int64_t s = 0; s < rs; ++s) {
const float * xs = f32_sample.data() + s * n_per_row;
const float * ys = deq.data() + s * n_per_row;
float mse_w = 0.0;
float bias = 0.0;
float bias_sum = 0.0;
if (value) {
for (int64_t j = 0; j < n_per_row; ++j) {
const float e = ys[j] - xs[j];
mse_w += e * e * value[j];
if (activation) {
bias_sum += e * activation[j];
}
}
} else {
for (int64_t j = 0; j < n_per_row; ++j) {
const float e = ys[j] - xs[j];
mse_w += e*e;
if (activation) {
bias_sum += e * activation[j];
}
}
}
if (activation) {
bias = std::abs(bias_sum);
}
// Normalize by n_per_row to get a per-row average scale
float row_err = mse_w / std::max<int64_t>(1, n_per_row);
if (bias_lambda != 0.0) {
row_err += bias_lambda * (bias / std::max<int64_t>(1, n_per_row));
}
total_err += row_err;
}
// Extrapolate to the rows we did not sample in this expert slice
const float scale_rows = (float)rows_per_expert / (float)std::max<int64_t>(1, rs);
total_err *= scale_rows;
}
return total_err;
};
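// The estimate is extrapolated to the full tensor (all rows and expert slices), which keeps error
// reductions comparable across tensors of different sizes when they are ranked per added bit below.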
// Produce per-tensor candidate lists
std::vector<tensor_info> all;
all.reserve(tensors.size());
for (const auto * tw : tensors) {
// Temporary workers for dequantization
std::vector<std::thread> workers;
workers.reserve(std::max(1, nthread));
ggml_tensor * t = tw->tensor;
const std::string name = ggml_get_name(t);
if (!can_quantize(t)) {
continue;
}
LLAMA_LOG_INFO("\t%s: - processing tensor %45s \t(%12d elements)\n", __func__, name.c_str(), (int)ggml_nelements(t));
if (!ml.use_mmap) {
if (read_data.size() < ggml_nbytes(t)) {
read_data.resize(ggml_nbytes(t));
}
t->data = read_data.data();
}
ml.load_data_for(t);
// Prepare f32 weights for error estimates
const int64_t nelem = ggml_nelements(t);
std::vector<no_init<float>> f32_conv_buf;
float * f32_data = nullptr;
if (t->type == GGML_TYPE_F32) {
f32_data = (float *)t->data;
} else {
llama_tensor_dequantize_impl(t, f32_conv_buf, workers, nelem, nthread);
f32_data = (float *)f32_conv_buf.data();
}
const float * values = get_values(name);
const float * activations = get_activations(name);
tensor_info info;
info.w = tw;
info.n_elements = nelem;
// Candidate build with compatibility handling and availability checks
for (ggml_type ts_type : base_candidates) {
// Skip IQ* without imatrix
if (is_iq(ts_type) && !values) { continue; }
ggml_type tt = make_compatible(t, ts_type);
// After fallback, if still incompatible, skip
if (!is_compatible(t, tt)) { continue; }
// Compute bpw and bytes
auto bpw = (float)tensor_bpw(t, tt);
size_t bytes = total_bytes(t, tt);
// Estimate error
auto err = (float)estimate_error(t, f32_data, tt, values, activations);
info.candidate.push_back(candidate_types{tt, bpw, bytes, err});
}
if (info.candidate.empty()) {
// as a last resort, keep original type
float bpw = ggml_nbytes(t) * 8.0f / nelem;
info.candidate.push_back(candidate_types{t->type, bpw, ggml_nbytes(t), 0.0});
}
// Sort by bpw ascending
std::sort(info.candidate.begin(), info.candidate.end(), [](const candidate_types &a, const candidate_types &b) {
if (a.bpw != b.bpw) { return a.bpw < b.bpw; }
if (a.error != b.error) { return a.error < b.error; }
return a.bytes < b.bytes;
});
// collapse candidates with identical storage size (bytes)
{
std::vector<candidate_types> uniq;
uniq.reserve(info.candidate.size());
for (size_t i = 0; i < info.candidate.size(); ) {
size_t j = i + 1;
candidate_types best = info.candidate[i];
// group same-byte entries, keep the one with the lowest error
while (j < info.candidate.size() && info.candidate[j].bytes == info.candidate[i].bytes) {
if (info.candidate[j].error < best.error) { best = info.candidate[j]; }
++j;
}
uniq.push_back(best);
i = j;
}
info.candidate.swap(uniq);
}
// Initialize choice at the smallest bpw candidate
info.choice = 0;
info.min_bpw = info.candidate.front().bpw;
info.max_bpw = info.candidate.back().bpw;
all.push_back(std::move(info));
}
if (all.empty()) { return {}; }
// Greedy allocation from minimum bpw upward to reach target_bpw
// Start with minimal bpw assignment
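// Each iteration applies the single upgrade with the best error-reduction-per-bit ratio, until the
// next such upgrade would push the overall bpw past target_bpw (a greedy knapsack-style heuristic).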
auto current_total_bytes = [&]() -> size_t {
size_t b = 0;
for (const auto & ti : all) {
b += ti.candidate[ti.choice].bytes;
}
return b;
};
auto total_weights = [&]() -> size_t {
size_t w = 0;
for (const auto & ti : all) {
w += ti.n_elements;
}
return w;
};
const size_t tw = total_weights();
auto current_bpw = [&]() -> double {
return (double)current_total_bytes() * 8.0f / (double)tw;
};
// Precompute current bpw
double bpw_now = current_bpw();
// If minimal bpw is already above the target, we're constrained by geometry; return closest (min bpw)
if (bpw_now >= target_bpw) {
std::unordered_map<std::string, ggml_type> overrides;
for (const auto & ti : all) {
overrides[ggml_get_name(ti.w->tensor)] = ti.candidate[ti.choice].type;
}
return overrides;
}
struct upgrade {
int idx; // tensor index
int next; // next candidate index (strictly larger bytes)
double err; // error reduction
size_t delta_bytes; // increase in bytes
double ratio; // err per added bit
};
// Find next strictly-larger candidate index for a tensor
auto next_distinct_idx = [&](const tensor_info &ti) -> int {
const auto &cand = ti.candidate;
const auto &cur = cand[ti.choice];
int j = ti.choice + 1;
while (j < (int)cand.size() && cand[j].bytes == cur.bytes) ++j;
return j < (int)cand.size() ? j : -1;
};
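// Pick the tensor whose next upgrade gives the largest error reduction per added bit; ties go to the
// smaller byte increase.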
auto recompute_best_upgrade = [&]() -> upgrade {
const double eps = 1e-12;
upgrade best{-1, -1, 0.0, 0, -1.0};
for (int i = 0; i < (int)all.size(); ++i) {
const auto &ti = all[i];
if (ti.choice >= (int)ti.candidate.size() - 1) { continue; }
int j = next_distinct_idx(ti);
if (j < 0) { continue; } // no larger-size candidate remains
const auto &cur = ti.candidate[ti.choice];
const auto &nxt = ti.candidate[j];
size_t delta_bytes = nxt.bytes - cur.bytes;
if (delta_bytes == 0) { continue; } // should not happen after dedup, but be safe
double err = (double)cur.error - (double)nxt.error;
err = std::max(err, 0.0); // do not penalize due to sampling noise
double ratio = err / (double)(delta_bytes * 8ull);
if (ratio > best.ratio + eps || (std::abs(ratio - best.ratio) <= eps && delta_bytes < best.delta_bytes)) {
best = upgrade{i, j, err, delta_bytes, ratio};
}
}
return best;
};
while (true) {
upgrade up = recompute_best_upgrade();
if (up.idx < 0) { break; }
size_t now_bytes = current_total_bytes();
size_t next_bytes = now_bytes + up.delta_bytes;
double bpw_next = (double)next_bytes * 8.0 / (double)tw;
if (bpw_next <= (double)target_bpw + 1e-12) {
all[up.idx].choice = up.next;
bpw_now = bpw_next;
} else {
break;
}
}
// We may still be below the target because any single remaining upgrade would overshoot it.
// Check whether the smallest overshoot lands closer to the target than staying under.
{
double under_gap = (double)target_bpw - bpw_now;
upgrade best_over{-1, -1, 0.0, 0, -1.0};
double best_over_gap = 1e300;
size_t now_bytes = current_total_bytes();
for (int i = 0; i < (int)all.size(); ++i) {
const auto &ti = all[i];
if (ti.choice >= (int)ti.candidate.size() - 1) { continue; }
int j = next_distinct_idx(ti);
if (j < 0) { continue; }
const auto &cur = ti.candidate[ti.choice];
const auto &nxt = ti.candidate[j];
size_t delta_bytes = nxt.bytes - cur.bytes;
if (delta_bytes == 0) { continue; }
size_t over_bytes = now_bytes + delta_bytes;
double bpw_over = (double)over_bytes * 8.0 / (double)tw;
double over_gap = std::abs(bpw_over - (double)target_bpw);
double err = (double)cur.error - (double)nxt.error;
if (err < 0.0) { err = 0.0; }
double ratio = err / (double)(delta_bytes * 8ull);
if (over_gap < best_over_gap - 1e-12 || (std::abs(over_gap - best_over_gap) <= 1e-12 && ratio > best_over.ratio)) {
best_over_gap = over_gap;
best_over = upgrade{i, j, err, delta_bytes, ratio};
}
}
if (best_over.idx >= 0) {
if (best_over_gap < under_gap) {
all[best_over.idx].choice = best_over.next;
}
}
}
// Build the override map
std::unordered_map<std::string, ggml_type> overrides;
LLAMA_LOG_INFO("%s: - estimated tensor quantization mix to achieve %.4f bpw at lowest ppl\n", __func__, target_bpw);
for (const auto & ti : all) {
LLAMA_LOG_INFO("\t%s: %45s - \t%8s, \t%1.4f bpw,\terror: %.4f\n",
__func__, ggml_get_name(ti.w->tensor), ggml_type_name(ti.candidate[ti.choice].type), ti.candidate[ti.choice].bpw, ti.candidate[ti.choice].error);
overrides[ggml_get_name(ti.w->tensor)] = ti.candidate[ti.choice].type;
}
return overrides;
}
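// Example (hypothetical) call site, assuming the caller holds the imatrix and activation maps:
//   auto bpw_overrides = target_bpw_type(ml, read_data, model, tensors, mapped,
//                                        &values_map, &activations_map, /*target_bpw=*/4.25f, nthread);
//   // when quantizing each tensor, prefer the override if one exists:
//   auto it = bpw_overrides.find(ggml_get_name(tensor));
//   if (it != bpw_overrides.end()) { new_type = it->second; }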
static void llama_model_quantize_impl(const std::string & fname_inp, const std::string & fname_out, const llama_model_quantize_params * params) {
ggml_type default_type;
llama_ftype ftype = params->ftype;