diff --git a/src/llama-quant.cpp b/src/llama-quant.cpp
index fdda5d35a1..1e24303c52 100644
--- a/src/llama-quant.cpp
+++ b/src/llama-quant.cpp
@@ -575,6 +575,488 @@ static size_t llama_tensor_quantize_impl(enum ggml_type new_type, const float *
     return new_size;
 }
 
+// Returns per-tensor overrides of quantization types to meet target BPW with best expected quality.
+// values_data: imatrix map from tensor name -> vector of length (ne[0] * ne[2]) containing per-column E[a^2] by expert
+// activations_data: optional map from tensor name -> vector of length (ne[0] * ne[2]) containing per-column E[a] by expert
+// bias_lambda: relative weight on bias term (|sum e_j * E[a_j]|) vs MSE term (sum e_j^2 * E[a_j^2])
+static std::unordered_map<std::string, ggml_type> target_bpw_type(
+        llama_model_loader & ml,
+        std::vector<no_init<uint8_t>> & read_data,
+        const llama_model & model,
+        const std::vector<const llama_model_loader::llama_tensor_weight *> & tensors,
+        const std::map<int, std::string> & mapped,
+        const std::unordered_map<std::string, std::vector<float>> * values_data,
+        const std::unordered_map<std::string, std::vector<float>> * activations_data,
+        float target_bpw,
+        int nthread,
+        int sample_rows_per_expert = 128,
+        float bias_lambda = 1.0f
+) {
+    struct candidate_types {
+        ggml_type type;
+        float     bpw;
+        size_t    bytes;
+        float     error; // lower is better
+    };
+
+    struct tensor_info {
+        const llama_model_loader::llama_tensor_weight * w;
+        std::vector<candidate_types> candidate; // sorted by bpw ascending
+        int    choice  = -1;   // index into candidate
+        float  min_bpw = 0.0f;
+        float  max_bpw = 0.0f;
+        size_t n_elements = 0;
+    };
+
+    auto name_tn = LLM_TN(model.arch);
+
+    // The candidate types we consider; adjust as needed
+    const ggml_type base_candidates[] = {
+        GGML_TYPE_IQ1_S,
+        GGML_TYPE_IQ1_M,
+        GGML_TYPE_IQ2_XXS,
+        GGML_TYPE_IQ2_XS,
+        GGML_TYPE_IQ2_S,
+        GGML_TYPE_IQ3_XXS,
+        GGML_TYPE_IQ3_S,
+        GGML_TYPE_IQ4_XS,
+        GGML_TYPE_IQ4_NL,
+        GGML_TYPE_Q2_K,
+        GGML_TYPE_Q3_K,
+        GGML_TYPE_Q4_0,
+        GGML_TYPE_Q4_1,
+        GGML_TYPE_Q4_K,
+        GGML_TYPE_Q5_0,
+        GGML_TYPE_Q5_1,
+        GGML_TYPE_Q5_K,
+        GGML_TYPE_Q6_K,
+        GGML_TYPE_Q8_0
+    };
+
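+    // Only 2D (or higher) ".weight" tensors are considered for a per-tensor override below;
+    // norm weights, the MoE router, convolution/SSM state and RWKV time-mix tensors, etc. are
+    // excluded here and keep whatever type the regular quantization rules assign them.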
q &= name.find("time_mix_decay_w2.weight") == std::string::npos; + q &= name.find("time_mix_lerp_fused.weight") == std::string::npos; + q &= name.find("attn_rel_b.weight") == std::string::npos; + return q; + }; + + auto get_values = [&](const std::string & tensor_name) -> const float * { + if (!values_data) { return nullptr; } + const auto it = values_data->find(remap_imatrix(tensor_name, mapped)); + if (it == values_data->end()) { return nullptr; } + return it->second.data(); + }; + + auto get_activations = [&](const std::string & tensor_name) -> const float * { + if (!activations_data) { return nullptr; } + const auto it = activations_data->find(remap_imatrix(tensor_name, mapped)); + if (it == activations_data->end()) { return nullptr; } + return it->second.data(); + }; + + auto total_bytes = [](const ggml_tensor * t, const ggml_type typ) -> size_t { + const int64_t n_per_row = t->ne[0]; + const int64_t nrows = t->ne[1]; + const int64_t ne2 = t->ne[2] > 0 ? t->ne[2] : 1; + const size_t row_sz = ggml_row_size(typ, n_per_row); + return (size_t)ne2 * (size_t)nrows * row_sz; + }; + + auto tensor_bpw = [&](const ggml_tensor * t, const ggml_type typ) -> double { + const int64_t nelem = ggml_nelements(t); + const size_t bytes = total_bytes(t, typ); + return bytes * 8.0 / nelem; + }; + + auto is_compatible = [&](const ggml_tensor * t, const ggml_type typ) -> bool { + const int64_t n_per_row = t->ne[0]; + const int64_t blck = ggml_blck_size(typ); + if (blck <= 1) { return true; } // FP16/BF16/Q8_0 etc + return n_per_row % blck == 0; + }; + + auto make_compatible = [&](const ggml_tensor * t, const ggml_type typ) -> ggml_type { + if (is_compatible(t, typ)) { return typ; } + ggml_type fb = fallback_type(typ); + if (is_compatible(t, fb)) { return fb; } + return GGML_TYPE_F16; // final guard + }; + + // Estimate error for a given type using a sampled subset of rows. + // Uses both imatrix (E[a^2]) and activations (E[a]) if available. + auto estimate_error = [&](const ggml_tensor * t, const float * f32_data, const ggml_type typ, const float * values_all, const float * activations_all) -> double { + const int64_t n_per_row = t->ne[0]; + const int64_t nrows = t->ne[1]; + const int64_t ne2 = t->ne[2] > 0 ? t->ne[2] : 1; + + const ggml_type_traits * traits = ggml_get_type_traits(typ); + if (!traits || !traits->to_float) { + // cannot dequantize candidate -> assign very high error + return 1e35f; + } + + // Sampling plan: for each expert slice, take up to sample_rows rows spread uniformly + const int64_t rows_per_expert = nrows; + const int64_t sample_rows = std::max(1, std::min(rows_per_expert, sample_rows_per_expert)); + const int64_t stride = std::max(1, rows_per_expert / sample_rows); + + const size_t row_sz = ggml_row_size(typ, n_per_row); + std::vector qbuf(row_sz * sample_rows); + std::vector f32_sample(sample_rows * n_per_row); + std::vector deq(sample_rows * n_per_row); + + float total_err = 0.0; + + for (int64_t i03 = 0; i03 < ne2; ++i03) { + const float * value = values_all ? (values_all + i03 * n_per_row) : nullptr; + const float * activation = activations_all ? 
+            // Compute error proxy per sampled row
+            for (int64_t s = 0; s < rs; ++s) {
+                const float * xs = f32_sample.data() + s * n_per_row;
+                const float * ys = deq.data()        + s * n_per_row;
+
+                float mse_w    = 0.0f;
+                float bias     = 0.0f;
+                float bias_sum = 0.0f;
+
+                if (value) {
+                    for (int64_t j = 0; j < n_per_row; ++j) {
+                        const float e = ys[j] - xs[j];
+                        mse_w += e * e * value[j];
+                        if (activation) {
+                            bias_sum += e * activation[j];
+                        }
+                    }
+                } else {
+                    for (int64_t j = 0; j < n_per_row; ++j) {
+                        const float e = ys[j] - xs[j];
+                        mse_w += e * e;
+                        if (activation) {
+                            bias_sum += e * activation[j];
+                        }
+                    }
+                }
+
+                if (activation) {
+                    bias = std::abs(bias_sum);
+                }
+
+                // Normalize by n_per_row to get a per-row average scale
+                float row_err = mse_w / std::max<int64_t>(1, n_per_row);
+                if (bias_lambda != 0.0f) {
+                    row_err += bias_lambda * (bias / std::max<int64_t>(1, n_per_row));
+                }
+
+                total_err += row_err;
+            }
+
+            // Scale for the rows we didn't sample in this expert: multiply by stride-ish factor
+            const float scale_rows = (float)rows_per_expert / (float)std::max<int64_t>(1, rs);
+            total_err *= scale_rows;
+        }
+
+        return total_err;
+    };
+
+    // Produce per-tensor candidate lists
+    std::vector<tensor_info> all;
+    all.reserve(tensors.size());
+
+    for (const auto * tw : tensors) {
+        // Temporary workers for dequantization
+        std::vector<std::thread> workers;
+        workers.reserve(std::max(1, nthread));
+
+        ggml_tensor * t = tw->tensor;
+        const std::string name = ggml_get_name(t);
+
+        if (!can_quantize(t)) {
+            continue;
+        }
+
+        LLAMA_LOG_INFO("\t%s: - processing tensor %45s \t(%12d elements)\n", __func__, name.c_str(), (int)ggml_nelements(t));
+        if (!ml.use_mmap) {
+            if (read_data.size() < ggml_nbytes(t)) {
+                read_data.resize(ggml_nbytes(t));
+            }
+            t->data = read_data.data();
+        }
+        ml.load_data_for(t);
+
+        // Prepare f32 weights for error estimates
+        const int64_t nelem = ggml_nelements(t);
+        std::vector<no_init<float>> f32_conv_buf;
+        float * f32_data = nullptr;
+
+        if (t->type == GGML_TYPE_F32) {
+            f32_data = (float *)t->data;
+        } else {
+            llama_tensor_dequantize_impl(t, f32_conv_buf, workers, nelem, nthread);
+            f32_data = (float *)f32_conv_buf.data();
+        }
+
+        const float * values      = get_values(name);
+        const float * activations = get_activations(name);
+
+        tensor_info info;
+        info.w = tw;
+        info.n_elements = nelem;
+
+        // Candidate build with compatibility handling and availability checks
+        for (ggml_type ts_type : base_candidates) {
+            // Skip IQ* without imatrix
+            if (is_iq(ts_type) && !values) { continue; }
+            ggml_type tt = make_compatible(t, ts_type);
+            // After fallback, if still incompatible, skip
+            if (!is_compatible(t, tt)) { continue; }
+
+            // Compute bpw and bytes
+            auto   bpw   = (float)tensor_bpw(t, tt);
+            size_t bytes = total_bytes(t, tt);
+
+            // Estimate error
+            auto err = (float)estimate_error(t, f32_data, tt, values, activations);
+
+            info.candidate.push_back(candidate_types{tt, bpw, bytes, err});
+        }
+
+        if (info.candidate.empty()) {
+            // As a last resort, keep the original type
+            float bpw = ggml_nbytes(t) * 8.0f / nelem;
+            info.candidate.push_back(candidate_types{t->type, bpw, ggml_nbytes(t), 0.0f});
+        }
+
+        // Sort by bpw ascending
+        std::sort(info.candidate.begin(), info.candidate.end(), [](const candidate_types & a, const candidate_types & b) {
+            if (a.bpw   != b.bpw)   { return a.bpw   < b.bpw;   }
+            if (a.error != b.error) { return a.error < b.error; }
+            return a.bytes < b.bytes;
+        });
+
+        // Collapse candidates with identical storage size (bytes)
+        {
+            std::vector<candidate_types> uniq;
+            uniq.reserve(info.candidate.size());
+
+            for (size_t i = 0; i < info.candidate.size(); ) {
+                size_t j = i + 1;
+                candidate_types best = info.candidate[i];
+                // group same-byte entries, keep the one with the lowest error
+                while (j < info.candidate.size() && info.candidate[j].bytes == info.candidate[i].bytes) {
+                    if (info.candidate[j].error < best.error) { best = info.candidate[j]; }
+                    ++j;
+                }
+                uniq.push_back(best);
+                i = j;
+            }
+            info.candidate.swap(uniq);
+        }
+
+        // Initialize choice at the smallest bpw candidate
+        info.choice  = 0;
+        info.min_bpw = info.candidate.front().bpw;
+        info.max_bpw = info.candidate.back().bpw;
+
+        all.push_back(std::move(info));
+    }
+
+    if (all.empty()) { return {}; }
+
+    // Greedy allocation from minimum bpw upward to reach target_bpw
+    // Start with minimal bpw assignment
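+    // Each tensor starts at its smallest-bpw candidate; upgrades are then applied one at a time,
+    // always picking the candidate switch with the largest estimated error reduction per added
+    // bit, until the next upgrade would push the overall bpw past the target.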
+    auto current_total_bytes = [&]() -> size_t {
+        size_t b = 0;
+        for (const auto & ti : all) {
+            b += ti.candidate[ti.choice].bytes;
+        }
+        return b;
+    };
+
+    auto total_weights = [&]() -> size_t {
+        size_t w = 0;
+        for (const auto & ti : all) {
+            w += ti.n_elements;
+        }
+        return w;
+    };
+
+    const size_t tw = total_weights();
+    auto current_bpw = [&]() -> double {
+        return (double)current_total_bytes() * 8.0 / (double)tw;
+    };
+
+    // Precompute current bpw
+    double bpw_now = current_bpw();
+
+    // If minimal bpw is already above the target, we're constrained by geometry; return closest (min bpw)
+    if (bpw_now >= target_bpw) {
+        std::unordered_map<std::string, ggml_type> overrides;
+        for (const auto & ti : all) {
+            overrides[ggml_get_name(ti.w->tensor)] = ti.candidate[ti.choice].type;
+        }
+        return overrides;
+    }
+
+    struct upgrade {
+        int    idx;         // tensor index
+        int    next;        // next candidate index (strictly larger bytes)
+        double err;         // error reduction
+        size_t delta_bytes; // increase in bytes
+        double ratio;       // error reduction per added bit
+    };
+
+    // Find next strictly-larger candidate index for a tensor
+    auto next_distinct_idx = [&](const tensor_info & ti) -> int {
+        const auto & cand = ti.candidate;
+        const auto & cur  = cand[ti.choice];
+        int j = ti.choice + 1;
+        while (j < (int)cand.size() && cand[j].bytes == cur.bytes) { ++j; }
+        return j < (int)cand.size() ? j : -1;
+    };
+
+    auto recompute_best_upgrade = [&]() -> upgrade {
+        const double eps = 1e-12;
+        upgrade best{-1, -1, 0.0, 0, -1.0};
+        for (int i = 0; i < (int)all.size(); ++i) {
+            const auto & ti = all[i];
+            if (ti.choice >= (int)ti.candidate.size() - 1) { continue; }
+
+            int j = next_distinct_idx(ti);
+            if (j < 0) { continue; } // no larger-size candidate remains
+
+            const auto & cur = ti.candidate[ti.choice];
+            const auto & nxt = ti.candidate[j];
+
+            size_t delta_bytes = nxt.bytes - cur.bytes;
+            if (delta_bytes == 0) { continue; } // should not happen after dedup, but be safe
+
+            double err = (double)cur.error - (double)nxt.error;
+            err = std::max(err, 0.0); // do not penalize due to sampling noise
+
+            double ratio = err / (double)(delta_bytes * 8ull);
+            if (ratio > best.ratio + eps || (std::abs(ratio - best.ratio) <= eps && delta_bytes < best.delta_bytes)) {
+                best = upgrade{i, j, err, delta_bytes, ratio};
+            }
+        }
+        return best;
+    };
+
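+    // Keep applying the best upgrade while the resulting bpw stays at or below the target;
+    // the follow-up pass afterwards may still take a single overshooting upgrade if that lands
+    // closer to the target than stopping short does.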
+    while (true) {
+        upgrade up = recompute_best_upgrade();
+        if (up.idx < 0) { break; }
+
+        size_t now_bytes  = current_total_bytes();
+        size_t next_bytes = now_bytes + up.delta_bytes;
+        double bpw_next   = (double)next_bytes * 8.0 / (double)tw;
+
+        if (bpw_next <= (double)target_bpw + 1e-12) {
+            all[up.idx].choice = up.next;
+            bpw_now = bpw_next;
+        } else {
+            break;
+        }
+    }
+
+    // We might still be below target but taking any single upgrade overshoots.
+    {
+        double under_gap = (double)target_bpw - bpw_now;
+
+        upgrade best_over{-1, -1, 0.0, 0, -1.0};
+        double  best_over_gap = 1e300;
+
+        size_t now_bytes = current_total_bytes();
+
+        for (int i = 0; i < (int)all.size(); ++i) {
+            const auto & ti = all[i];
+            if (ti.choice >= (int)ti.candidate.size() - 1) { continue; }
+
+            int j = next_distinct_idx(ti);
+            if (j < 0) { continue; }
+
+            const auto & cur = ti.candidate[ti.choice];
+            const auto & nxt = ti.candidate[j];
+
+            size_t delta_bytes = nxt.bytes - cur.bytes;
+            if (delta_bytes == 0) { continue; }
+
+            size_t over_bytes = now_bytes + delta_bytes;
+            double bpw_over   = (double)over_bytes * 8.0 / (double)tw;
+
+            double over_gap = std::abs(bpw_over - (double)target_bpw);
+
+            double err = (double)cur.error - (double)nxt.error;
+            if (err < 0.0) { err = 0.0; }
+            double ratio = err / (double)(delta_bytes * 8ull);
+
+            if (over_gap < best_over_gap - 1e-12 || (std::abs(over_gap - best_over_gap) <= 1e-12 && ratio > best_over.ratio)) {
+                best_over_gap = over_gap;
+                best_over     = upgrade{i, j, err, delta_bytes, ratio};
+            }
+        }
+
+        if (best_over.idx >= 0 && best_over_gap < under_gap) {
+            all[best_over.idx].choice = best_over.next;
+        }
+    }
+
+    // Build the override map
+    std::unordered_map<std::string, ggml_type> overrides;
+    LLAMA_LOG_INFO("%s: - estimated tensor quantization mix to achieve %.4f bpw at lowest ppl\n", __func__, target_bpw);
+    for (const auto & ti : all) {
+        LLAMA_LOG_INFO("\t%s: %45s - \t%8s, \t%1.4f bpw,\terror: %.4f\n",
+            __func__, ggml_get_name(ti.w->tensor), ggml_type_name(ti.candidate[ti.choice].type), ti.candidate[ti.choice].bpw, ti.candidate[ti.choice].error);
+        overrides[ggml_get_name(ti.w->tensor)] = ti.candidate[ti.choice].type;
+    }
+    return overrides;
+}
+
 static void llama_model_quantize_impl(const std::string & fname_inp, const std::string & fname_out, const llama_model_quantize_params * params) {
     ggml_type default_type;
     llama_ftype ftype = params->ftype;