Add target_bpw_type() logic
This commit is contained in:
parent
017945a3b2
commit
92f49ab399
@@ -575,6 +575,488 @@ static size_t llama_tensor_quantize_impl(enum ggml_type new_type, const float *
    return new_size;
}

// Returns per-tensor overrides of quantization types to meet the target BPW with the best expected quality.
// values_data (imatrix): map from tensor name -> length (ne[0] * ne[2]) vector containing per-column E[a^2] by expert
// activations_data: optional map from tensor name -> length (ne[0] * ne[2]) vector containing per-column E[a] by expert
// bias_lambda: relative weight of the bias term (|sum e_j * E[a_j]|) vs the MSE term (sum e_j^2 * E[a_j^2])
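// Concretely, the per-row error proxy computed in estimate_error() below is
//   row_err = (1/n_per_row) * sum_j E[a_j^2] * e_j^2  +  bias_lambda * (1/n_per_row) * |sum_j E[a_j] * e_j|
// where e_j = dequantized_j - original_j; without an imatrix the first term degenerates to plain MSE.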
static std::unordered_map<std::string, ggml_type> target_bpw_type(
        llama_model_loader & ml,
        std::vector<no_init<uint8_t>> & read_data,
        const llama_model & model,
        const std::vector<const llama_model_loader::llama_tensor_weight *> & tensors,
        const std::map<int, std::string> & mapped,
        const std::unordered_map<std::string, std::vector<float>> * values_data,
        const std::unordered_map<std::string, std::vector<float>> * activations_data,
        float target_bpw,
        int nthread,
        int sample_rows_per_expert = 128,
        float bias_lambda = 1.0
) {
    struct candidate_types {
        ggml_type type;
        float bpw;
        size_t bytes;
        float error; // lower is better
    };

    struct tensor_info {
        const llama_model_loader::llama_tensor_weight * w;
        std::vector<candidate_types> candidate; // sorted by bpw ascending
        int choice = -1; // index into candidate
        float min_bpw = 0.0;
        float max_bpw = 0.0;
        size_t n_elements = 0;
    };

    auto name_tn = LLM_TN(model.arch);

    // The candidate types we consider; adjust as needed
    const ggml_type base_candidates[] = {
        // Model's
        GGML_TYPE_IQ1_S,
        GGML_TYPE_IQ1_M,
        GGML_TYPE_IQ2_XXS,
        GGML_TYPE_IQ2_XS,
        GGML_TYPE_IQ2_S,
        GGML_TYPE_IQ3_XXS,
        GGML_TYPE_IQ3_S,
        GGML_TYPE_IQ4_XS,
        GGML_TYPE_IQ4_NL,
        GGML_TYPE_Q2_K,
        GGML_TYPE_Q3_K,
        GGML_TYPE_Q4_0,
        GGML_TYPE_Q4_1,
        GGML_TYPE_Q4_K,
        GGML_TYPE_Q5_0,
        GGML_TYPE_Q5_1,
        GGML_TYPE_Q5_K,
        GGML_TYPE_Q6_K,
        GGML_TYPE_Q8_0
    };

    auto can_quantize = [&](const ggml_tensor * t) -> bool {
        const std::string name = ggml_get_name(t);
        bool q = name.rfind("weight") == name.size() - 6;
        q &= (ggml_n_dims(t) >= 2);
        q &= name.find("_norm.weight") == std::string::npos;
        //q &= name != name_tn(LLM_TENSOR_TOKEN_EMBD, "weight");
        //q &= name != name_tn(LLM_TENSOR_OUTPUT, "weight");
        q &= name.find("ffn_gate_inp.weight") == std::string::npos;
        q &= name.find("altup") == std::string::npos;
        q &= name.find("laurel") == std::string::npos;
        q &= name.find("per_layer_model_proj") == std::string::npos;
        q &= name != name_tn(LLM_TENSOR_POS_EMBD, "weight");
        q &= name != name_tn(LLM_TENSOR_TOKEN_TYPES, "weight");
        q &= name.find("ssm_conv1d.weight") == std::string::npos;
        q &= name.find("shortconv.conv.weight") == std::string::npos;
        q &= name.find("time_mix_first.weight") == std::string::npos;
        q &= name.find("time_mix_w0.weight") == std::string::npos;
        q &= name.find("time_mix_w1.weight") == std::string::npos;
        q &= name.find("time_mix_w2.weight") == std::string::npos;
        q &= name.find("time_mix_v0.weight") == std::string::npos;
        q &= name.find("time_mix_v1.weight") == std::string::npos;
        q &= name.find("time_mix_v2.weight") == std::string::npos;
        q &= name.find("time_mix_a0.weight") == std::string::npos;
        q &= name.find("time_mix_a1.weight") == std::string::npos;
        q &= name.find("time_mix_a2.weight") == std::string::npos;
        q &= name.find("time_mix_g1.weight") == std::string::npos;
        q &= name.find("time_mix_g2.weight") == std::string::npos;
        q &= name.find("time_mix_decay_w1.weight") == std::string::npos;
        q &= name.find("time_mix_decay_w2.weight") == std::string::npos;
        q &= name.find("time_mix_lerp_fused.weight") == std::string::npos;
        q &= name.find("attn_rel_b.weight") == std::string::npos;
        return q;
    };

    auto get_values = [&](const std::string & tensor_name) -> const float * {
        if (!values_data) { return nullptr; }
        const auto it = values_data->find(remap_imatrix(tensor_name, mapped));
        if (it == values_data->end()) { return nullptr; }
        return it->second.data();
    };

    auto get_activations = [&](const std::string & tensor_name) -> const float * {
        if (!activations_data) { return nullptr; }
        const auto it = activations_data->find(remap_imatrix(tensor_name, mapped));
        if (it == activations_data->end()) { return nullptr; }
        return it->second.data();
    };

    auto total_bytes = [](const ggml_tensor * t, const ggml_type typ) -> size_t {
        const int64_t n_per_row = t->ne[0];
        const int64_t nrows = t->ne[1];
        const int64_t ne2 = t->ne[2] > 0 ? t->ne[2] : 1;
        const size_t row_sz = ggml_row_size(typ, n_per_row);
        return (size_t)ne2 * (size_t)nrows * row_sz;
    };

    auto tensor_bpw = [&](const ggml_tensor * t, const ggml_type typ) -> double {
        const int64_t nelem = ggml_nelements(t);
        const size_t bytes = total_bytes(t, typ);
        return bytes * 8.0 / nelem;
    };

    auto is_compatible = [&](const ggml_tensor * t, const ggml_type typ) -> bool {
        const int64_t n_per_row = t->ne[0];
        const int64_t blck = ggml_blck_size(typ);
        if (blck <= 1) { return true; } // F32/F16/BF16 etc
        return n_per_row % blck == 0;
    };

    auto make_compatible = [&](const ggml_tensor * t, const ggml_type typ) -> ggml_type {
        if (is_compatible(t, typ)) { return typ; }
        ggml_type fb = fallback_type(typ);
        if (is_compatible(t, fb)) { return fb; }
        return GGML_TYPE_F16; // final guard
    };
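
    // Note: block-based types can only be used when the row length divides evenly into blocks,
    // e.g. the K-quants use 256-wide blocks, so a tensor with n_per_row not divisible by 256
    // is steered to whatever smaller-block type fallback_type() proposes (or F16 as a last resort).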

    // Estimate error for a given type using a sampled subset of rows.
    // Uses both imatrix (E[a^2]) and activations (E[a]) if available.
    auto estimate_error = [&](const ggml_tensor * t, const float * f32_data, const ggml_type typ, const float * values_all, const float * activations_all) -> double {
        const int64_t n_per_row = t->ne[0];
        const int64_t nrows = t->ne[1];
        const int64_t ne2 = t->ne[2] > 0 ? t->ne[2] : 1;

        const ggml_type_traits * traits = ggml_get_type_traits(typ);
        if (!traits || !traits->to_float) {
            // cannot dequantize candidate -> assign very high error
            return 1e35f;
        }

        // Sampling plan: for each expert slice, take up to sample_rows rows spread uniformly
        const int64_t rows_per_expert = nrows;
        const int64_t sample_rows = std::max<int64_t>(1, std::min<int64_t>(rows_per_expert, sample_rows_per_expert));
        const int64_t stride = std::max<int64_t>(1, rows_per_expert / sample_rows);
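        // e.g. with sample_rows_per_expert = 128 and a 4096-row tensor, stride = 32 and at most
        // 128 rows per expert slice are actually quantized and measured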

        const size_t row_sz = ggml_row_size(typ, n_per_row);
        std::vector<uint8_t> qbuf(row_sz * sample_rows);
        std::vector<float> f32_sample(sample_rows * n_per_row);
        std::vector<float> deq(sample_rows * n_per_row);

        float total_err = 0.0;

        for (int64_t i03 = 0; i03 < ne2; ++i03) {
            const float * value = values_all ? (values_all + i03 * n_per_row) : nullptr;
            const float * activation = activations_all ? (activations_all + i03 * n_per_row) : nullptr;

            // Assemble sampled rows into contiguous f32_sample
            int64_t rs = 0;
            for (int64_t r = 0; r < rows_per_expert && rs < sample_rows; r += stride) {
                const float * src = f32_data + i03 * (n_per_row * rows_per_expert) + r * n_per_row;
                std::memcpy(f32_sample.data() + rs * n_per_row, src, sizeof(float) * n_per_row);
                ++rs;
            }
            if (rs == 0) { continue; }

            // Quantize sampled rows in one chunk; pass the imatrix for this expert slice
            const size_t got = ggml_quantize_chunk(typ, f32_sample.data(), qbuf.data(), 0, rs, n_per_row, value);
            (void)got; // not strictly needed here

            // Dequantize
            traits->to_float(qbuf.data(), deq.data(), rs * n_per_row);

            // Compute error proxy per sampled row
            for (int64_t s = 0; s < rs; ++s) {
                const float * xs = f32_sample.data() + s * n_per_row;
                const float * ys = deq.data() + s * n_per_row;

                float mse_w = 0.0;
                float bias = 0.0;
                float bias_sum = 0.0;

                if (value) {
                    for (int64_t j = 0; j < n_per_row; ++j) {
                        const float e = ys[j] - xs[j];
                        mse_w += e * e * value[j];
                        if (activation) {
                            bias_sum += e * activation[j];
                        }
                    }
                } else {
                    for (int64_t j = 0; j < n_per_row; ++j) {
                        const float e = ys[j] - xs[j];
                        mse_w += e*e;
                        if (activation) {
                            bias_sum += e * activation[j];
                        }
                    }
                }

                if (activation) {
                    bias = std::abs(bias_sum);
                }

                // Normalize by n_per_row to get a per-row average scale
                float row_err = mse_w / std::max<int64_t>(1, n_per_row);
                if (bias_lambda != 0.0) {
                    row_err += bias_lambda * (bias / std::max<int64_t>(1, n_per_row));
                }

                total_err += row_err;
            }

            // Scale for the rows we didn't sample in this expert: multiply by stride-ish factor
            const float scale_rows = rows_per_expert / std::max<int64_t>(1, rs);
            total_err *= scale_rows;
        }

        return total_err;
    };

    // Produce per-tensor candidate lists
    std::vector<tensor_info> all;
    all.reserve(tensors.size());

    for (const auto * tw : tensors) {
        // Temporary workers for dequantization
        std::vector<std::thread> workers;
        workers.reserve(std::max(1, nthread));

        ggml_tensor * t = tw->tensor;
        const std::string name = ggml_get_name(t);

        if (!can_quantize(t)) {
            continue;
        }

        LLAMA_LOG_INFO("\t%s: - processing tensor %45s \t(%12d elements)\n", __func__, name.c_str(), (int)ggml_nelements(t));
        if (!ml.use_mmap) {
            if (read_data.size() < ggml_nbytes(t)) {
                read_data.resize(ggml_nbytes(t));
            }
            t->data = read_data.data();
        }
        ml.load_data_for(t);

        // Prepare f32 weights for error estimates
        const int64_t nelem = ggml_nelements(t);
        std::vector<no_init<float>> f32_conv_buf;
        float * f32_data = nullptr;

        if (t->type == GGML_TYPE_F32) {
            f32_data = (float *)t->data;
        } else {
            llama_tensor_dequantize_impl(t, f32_conv_buf, workers, nelem, nthread);
            f32_data = (float *)f32_conv_buf.data();
        }

        const float * values = get_values(name);
        const float * activations = get_activations(name);

        tensor_info info;
        info.w = tw;
        info.n_elements = nelem;

        // Candidate build with compatibility handling and availability checks
        for (ggml_type ts_type : base_candidates) {
            // Skip IQ* without imatrix
            if (is_iq(ts_type) && !values) { continue; }
            ggml_type tt = make_compatible(t, ts_type);
            // After fallback, if still incompatible, skip
            if (!is_compatible(t, tt)) { continue; }

            // Compute bpw and bytes
            auto bpw = (float)tensor_bpw(t, tt);
            size_t bytes = total_bytes(t, tt);

            // Estimate error
            auto err = (float)estimate_error(t, f32_data, tt, values, activations);

            info.candidate.push_back(candidate_types{tt, bpw, bytes, err});
        }

        if (info.candidate.empty()) {
            // as a last resort, keep original type
            float bpw = ggml_nbytes(t) * 8.0f / nelem;
            info.candidate.push_back(candidate_types{t->type, bpw, ggml_nbytes(t), 0.0});
        }

        // Sort by bpw ascending
        std::sort(info.candidate.begin(), info.candidate.end(), [](const candidate_types &a, const candidate_types &b) {
            if (a.bpw != b.bpw) { return a.bpw < b.bpw; }
            if (a.error != b.error) { return a.error < b.error; }
            return a.bytes < b.bytes;
        });

        // collapse candidates with identical storage size (bytes)
        {
            std::vector<candidate_types> uniq;
            uniq.reserve(info.candidate.size());

            for (size_t i = 0; i < info.candidate.size(); ) {
                size_t j = i + 1;
                candidate_types best = info.candidate[i];
                // group same-byte entries, keep the one with the lowest error
                while (j < info.candidate.size() && info.candidate[j].bytes == info.candidate[i].bytes) {
                    if (info.candidate[j].error < best.error) { best = info.candidate[j]; }
                    ++j;
                }
                uniq.push_back(best);
                i = j;
            }
            info.candidate.swap(uniq);
        }

        // Initialize choice at the smallest bpw candidate
        info.choice = 0;
        info.min_bpw = info.candidate.front().bpw;
        info.max_bpw = info.candidate.back().bpw;

        all.push_back(std::move(info));
    }

    if (all.empty()) { return {}; }

    // Greedy allocation from minimum bpw upward to reach target_bpw
    // Start with minimal bpw assignment
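    // Strategy: every tensor starts at its cheapest candidate; we then repeatedly apply the single
    // "upgrade" (move one tensor to its next-larger candidate) with the best error reduction per
    // added bit, until the next upgrade would push the overall bpw past the target.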
    auto current_total_bytes = [&]() -> size_t {
        size_t b = 0;
        for (const auto & ti : all) {
            b += ti.candidate[ti.choice].bytes;
        }
        return b;
    };

    auto total_weights = [&]() -> size_t {
        size_t w = 0;
        for (const auto & ti : all) {
            w += ti.n_elements;
        }
        return w;
    };

    const size_t tw = total_weights();
    auto current_bpw = [&]() -> double {
        return (double)current_total_bytes() * 8.0f / (double)tw;
    };

    // Precompute current bpw
    double bpw_now = current_bpw();

    // If minimal bpw is already above the target, we're constrained by geometry; return closest (min bpw)
    if (bpw_now >= target_bpw) {
        std::unordered_map<std::string, ggml_type> overrides;
        for (const auto & ti : all) {
            overrides[ggml_get_name(ti.w->tensor)] = ti.candidate[ti.choice].type;
        }
        return overrides;
    }

    struct upgrade {
        int idx;            // tensor index
        int next;           // next candidate index (strictly larger bytes)
        double err;         // error reduction
        size_t delta_bytes; // increase in bytes
        double ratio;       // err per added bit
    };

    // Find next strictly-larger candidate index for a tensor
    auto next_distinct_idx = [&](const tensor_info &ti) -> int {
        const auto &cand = ti.candidate;
        const auto &cur = cand[ti.choice];
        int j = ti.choice + 1;
        while (j < (int)cand.size() && cand[j].bytes == cur.bytes) ++j;
        return j < (int)cand.size() ? j : -1;
    };

    auto recompute_best_upgrade = [&]() -> upgrade {
        const double eps = 1e-12;
        upgrade best{-1, -1, 0.0, 0, -1.0};
        for (int i = 0; i < (int)all.size(); ++i) {
            const auto &ti = all[i];
            if (ti.choice >= (int)ti.candidate.size() - 1) { continue; }

            int j = next_distinct_idx(ti);
            if (j < 0) { continue; } // no larger-size candidate remains

            const auto &cur = ti.candidate[ti.choice];
            const auto &nxt = ti.candidate[j];

            size_t delta_bytes = nxt.bytes - cur.bytes;
            if (delta_bytes == 0) { continue; } // should not happen after dedup, but be safe

            double err = (double)cur.error - (double)nxt.error;
            err = std::max(err, 0.0); // do not penalize due to sampling noise

            double ratio = err / (double)(delta_bytes * 8ull);
            if (ratio > best.ratio + eps || (std::abs(ratio - best.ratio) <= eps && delta_bytes < best.delta_bytes)) {
                best = upgrade{i, j, err, delta_bytes, ratio};
            }
        }
        return best;
    };

    while (true) {
        upgrade up = recompute_best_upgrade();
        if (up.idx < 0) { break; }

        size_t now_bytes = current_total_bytes();
        size_t next_bytes = now_bytes + up.delta_bytes;
        double bpw_next = (double)next_bytes * 8.0 / (double)tw;

        if (bpw_next <= (double)target_bpw + 1e-12) {
            all[up.idx].choice = up.next;
            bpw_now = bpw_next;
        } else {
            break;
        }
    }

    // We might still be below target but taking any single upgrade overshoots.
    {
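        // Pick the single overshooting upgrade that lands closest to the target, and take it only
        // if overshooting leaves us nearer to target_bpw than staying under it.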
        double under_gap = (double)target_bpw - bpw_now;

        upgrade best_over{-1, -1, 0.0, 0, -1.0};
        double best_over_gap = 1e300;

        size_t now_bytes = current_total_bytes();

        for (int i = 0; i < (int)all.size(); ++i) {
            const auto &ti = all[i];
            if (ti.choice >= (int)ti.candidate.size() - 1) { continue; }

            int j = next_distinct_idx(ti);
            if (j < 0) { continue; }

            const auto &cur = ti.candidate[ti.choice];
            const auto &nxt = ti.candidate[j];

            size_t delta_bytes = nxt.bytes - cur.bytes;
            if (delta_bytes == 0) { continue; }

            size_t over_bytes = now_bytes + delta_bytes;
            double bpw_over = (double)over_bytes * 8.0 / (double)tw;

            double over_gap = std::abs(bpw_over - (double)target_bpw);

            double err = (double)cur.error - (double)nxt.error;
            if (err < 0.0) { err = 0.0; }
            double ratio = err / (double)(delta_bytes * 8ull);

            if (over_gap < best_over_gap - 1e-12 || (std::abs(over_gap - best_over_gap) <= 1e-12 && ratio > best_over.ratio)) {
                best_over_gap = over_gap;
                best_over = upgrade{i, j, err, delta_bytes, ratio};
            }
        }

        if (best_over.idx >= 0) {
            if (best_over_gap < under_gap) {
                all[best_over.idx].choice = best_over.next;
            }
        }
    }

    // Build the override map
    std::unordered_map<std::string, ggml_type> overrides;
    LLAMA_LOG_INFO("%s: - estimated tensor quantization mix to achieve %.4f bpw at lowest ppl\n", __func__, target_bpw);
    for (const auto & ti : all) {
        LLAMA_LOG_INFO("\t%s: %45s - \t%8s, \t%1.4f bpw,\terror: %.4f\n",
            __func__, ggml_get_name(ti.w->tensor), ggml_type_name(ti.candidate[ti.choice].type), ti.candidate[ti.choice].bpw, ti.candidate[ti.choice].error);
        overrides[ggml_get_name(ti.w->tensor)] = ti.candidate[ti.choice].type;
    }
    return overrides;
}
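
// A hypothetical usage sketch (not part of this commit): the caller would compute the override
// map once, before the main quantization loop, and consult it per tensor. Names such as
// imatrix_data, activations_data, params->target_bpw, tensor and new_type are placeholders:
//
//   auto bpw_overrides = target_bpw_type(ml, read_data, model, tensors, mapped,
//                                        imatrix_data, activations_data, params->target_bpw, nthread);
//   ...
//   if (auto it = bpw_overrides.find(ggml_get_name(tensor)); it != bpw_overrides.end()) {
//       new_type = it->second;
//   }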

static void llama_model_quantize_impl(const std::string & fname_inp, const std::string & fname_out, const llama_model_quantize_params * params) {
    ggml_type default_type;
    llama_ftype ftype = params->ftype;