diff --git a/include/llama.h b/include/llama.h
index f862930099..bf4e688ca4 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -378,9 +378,14 @@ extern "C" {
         bool pure;             // quantize all tensors to the default type
         bool keep_split;       // quantize to the same number of shards
         void * imatrix;        // pointer to importance matrix data
+        void * activations;    // pointer to activations data
         void * kv_overrides;   // pointer to vector containing overrides
         void * tensor_types;   // pointer to vector containing tensor types
         void * prune_layers;   // pointer to vector containing layer indices to prune
+        float target_bpw;      // target bits per weight (bpw)
+        bool keep_bpw_state;   // keep bpw state file
+        void * bpw_state;      // pointer to bpw state file
+        bool no_importance;    // allocate target bpw budget equitably across all tensors
     } llama_model_quantize_params;

     typedef struct llama_logit_bias {
diff --git a/src/llama-quant.cpp b/src/llama-quant.cpp
index bc4b05c3b5..33b7f7e584 100644
--- a/src/llama-quant.cpp
+++ b/src/llama-quant.cpp
@@ -4,14 +4,20 @@
 #include "llama-model-loader.h"

 #include <algorithm>
+#include <atomic>
 #include <cinttypes>
 #include <cmath>
+#include <csignal>
 #include <cstring>
 #include <fstream>
+#include <map>
+#include <numeric>
+#include <optional>
+#include <random>
 #include <mutex>
 #include <regex>
 #include <thread>
 #include <unordered_map>

 // Quantization types. Changes to this struct must be replicated in quantize.cpp
 struct tensor_quantization {
@@ -19,6 +25,87 @@ struct tensor_quantization {
     ggml_type quant = GGML_TYPE_COUNT;
 };

+static bool is_quantizable(const std::string & name, const llm_arch arch, const llama_model_quantize_params * params) {
+    if (params->only_copy) { return false; }
+
+    const auto tn = LLM_TN(arch);
+
+    // This used to be a regex, but <regex> has an extreme cost to compile times.
+    bool q = name.size() >= 6 && name.rfind("weight") == name.size() - 6; // ends with 'weight'?
+
+    // Do not quantize norm tensors
+    q &= name.find("_norm.weight") == std::string::npos;
+
+    // Do not quantize expert gating tensors
+    // NOTE: can't use LLM_TN here because the layer number is not known
+    q &= name.find("ffn_gate_inp.weight") == std::string::npos;
+
+    // These are very small (e.g. 4x4)
+    q &= name.find("altup") == std::string::npos;
+    q &= name.find("laurel") == std::string::npos;
+
+    // These are not too big so keep them as they are
+    q &= name.find("per_layer_model_proj") == std::string::npos;
+
+    // Do not quantize positional embeddings and token types (BERT)
+    q &= name != tn(LLM_TENSOR_POS_EMBD, "weight");
+    q &= name != tn(LLM_TENSOR_TOKEN_TYPES, "weight");
+
+    // Do not quantize Jamba, Mamba and LFM2's small yet 2D weights
+    // NOTE: can't use LLM_TN here because the layer number is not known
+    q &= name.find("ssm_conv1d.weight") == std::string::npos;
+    q &= name.find("shortconv.conv.weight") == std::string::npos;
+
+    // Do not quantize ARWKV and RWKV's small yet 2D weights
+    q &= name.find("time_mix_first.weight") == std::string::npos;
+    q &= name.find("time_mix_w0.weight") == std::string::npos;
+    q &= name.find("time_mix_w1.weight") == std::string::npos;
+    q &= name.find("time_mix_w2.weight") == std::string::npos;
+    q &= name.find("time_mix_v0.weight") == std::string::npos;
+    q &= name.find("time_mix_v1.weight") == std::string::npos;
+    q &= name.find("time_mix_v2.weight") == std::string::npos;
+    q &= name.find("time_mix_a0.weight") == std::string::npos;
+    q &= name.find("time_mix_a1.weight") == std::string::npos;
+    q &= name.find("time_mix_a2.weight") == std::string::npos;
+    q &= name.find("time_mix_g1.weight") == std::string::npos;
+    q &= name.find("time_mix_g2.weight") == std::string::npos;
+    q &= name.find("time_mix_decay_w1.weight") == std::string::npos;
+    q &= name.find("time_mix_decay_w2.weight") == std::string::npos;
+    q &= name.find("time_mix_lerp_fused.weight") == std::string::npos;
+
+    // Do not quantize relative position bias (T5)
+    q &= name.find("attn_rel_b.weight") == std::string::npos;
+
+    return q;
+}
+
+static enum ggml_type fallback_type(const enum ggml_type new_type) {
+    switch (new_type) {
+        case GGML_TYPE_TQ1_0:
+        case GGML_TYPE_TQ2_0:
+            return GGML_TYPE_Q4_0; // symmetric-ish fallback
+        case GGML_TYPE_IQ2_XXS:
+        case GGML_TYPE_IQ2_XS:
+        case GGML_TYPE_IQ2_S:
+        case GGML_TYPE_IQ3_XXS:
+        case GGML_TYPE_IQ3_S:
+        case GGML_TYPE_IQ1_S:
+        case GGML_TYPE_IQ1_M:
+        case GGML_TYPE_Q2_K:
+        case GGML_TYPE_Q3_K:
+        case GGML_TYPE_IQ4_XS:
+            return GGML_TYPE_IQ4_NL;
+        case GGML_TYPE_Q4_K:
+            return GGML_TYPE_Q5_0;
+        case GGML_TYPE_Q5_K:
+            return GGML_TYPE_Q5_1;
+        case GGML_TYPE_Q6_K:
+            return GGML_TYPE_Q8_0;
+        default:
+            return new_type;
+    }
+}
+
 static void zeros(std::ofstream & file, size_t n) {
     char zero = 0;
     for (size_t i = 0; i < n; ++i) {
@@ -66,7 +153,6 @@ static std::string remap_imatrix (const std::string & orig_name, const std::map<int, int> & mapped) {

     for (const auto & p : mapped) {
         if (p.second == blk) {
-            LLAMA_LOG_DEBUG("(blk.%d imatrix) ", p.first);
             return new_name.replace(match.position(1), match.length(1), std::to_string(p.first));
         }
     }
@@ -89,10 +175,11 @@ struct quantize_state_impl {
     int i_ffn_gate = 0;
     int i_ffn_up   = 0;

-    int n_k_quantized = 0;
-    int n_fallback = 0;
+    int n_k_quantized = 0;
+    int n_fallback    = 0;

-    bool has_imatrix = false;
+    bool has_imatrix     = false;
+    bool has_activations = false;

     // used to figure out if a model shares tok_embd with the output weight
     bool has_output = false;
@@ -530,6 +617,1127 @@ static size_t llama_tensor_quantize_impl(enum ggml_type new_type, const float * f32_data, void * new_data, const int64_t chunk_size, int64_t nrows, int64_t n_per_row, const float * imatrix, std::vector<std::thread> & workers, const int nthread) {
     return new_size;
 }

+static std::atomic<bool> bpw_stop{ false };
+
+static void signal_handler(int) {
+    bpw_stop.store(true, std::memory_order_relaxed);
+}
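+
+// A rough sketch of the search below, not a formal spec: each quantizable tensor
+// gets a per-type error estimate from a sampled subset of its rows, the resulting
+// (bytes, error) candidates are reduced to a convex Pareto front, and a Lagrangian
+// relaxation then picks one type per tensor so the whole file lands on the
+// requested bits-per-weight target.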
+// Returns tensor type overrides to meet a global bpw target
+static std::unordered_map<std::string, ggml_type> target_bpw_type(
+    llama_model_loader & ml,
+    const llama_model & model,
+    const std::vector<const llama_model_loader::llama_tensor_weight *> & tensors,
+    const std::map<int, int> & mapped,
+    const std::unordered_map<std::string, std::vector<float>> * values_data,
+    const std::unordered_map<std::string, std::vector<float>> * activations_data,
+    const llama_model_quantize_params * params,
+    int nthread
+) {
+    bpw_stop.store(false, std::memory_order_relaxed);
+    // SIGINT/SIGTERM signal handlers
+    struct signal_scope_guard {
+        using handler_t = void (*)(int);
+        handler_t prev_int  = SIG_DFL;
+        handler_t prev_term = SIG_DFL;
+        signal_scope_guard() {
+            prev_int  = std::signal(SIGINT,  signal_handler);
+            prev_term = std::signal(SIGTERM, signal_handler);
+        }
+        ~signal_scope_guard() {
+            std::signal(SIGINT,  prev_int);
+            std::signal(SIGTERM, prev_term);
+        }
+    } signal_guard;
+
+    struct candidate_types {
+        ggml_type type = GGML_TYPE_COUNT;
+        float  bpw   = 0.0f;
+        size_t bytes = 0;
+        double error = 0.0;
+        double mse   = 0.0;
+        double proj  = 0.0;
+    };
+
+    struct tensor_info {
+        const llama_model_loader::llama_tensor_weight * w = nullptr;
+        std::vector<candidate_types> candidate;
+        int choice = -1;
+        float min_bpw = 0.0;
+        float max_bpw = 0.0;
+        size_t n_elements = 0;
+    };
+
+    // subset of quantization types with the best accuracy/size tradeoff
+    constexpr ggml_type quant_types[] = {
+        GGML_TYPE_IQ1_S,
+        GGML_TYPE_IQ1_M,
+        GGML_TYPE_IQ2_XXS,
+        GGML_TYPE_IQ2_XS,
+        GGML_TYPE_IQ2_S,
+        GGML_TYPE_Q2_K,
+        GGML_TYPE_IQ3_XXS,
+        GGML_TYPE_Q3_K,
+        GGML_TYPE_IQ4_XS,
+        GGML_TYPE_IQ4_NL,
+        GGML_TYPE_Q4_K,
+        GGML_TYPE_Q5_K,
+        GGML_TYPE_Q6_K,
+        GGML_TYPE_Q8_0,
+#ifdef GGML_USE_METAL
+        GGML_TYPE_F16
+#else
+        GGML_TYPE_BF16
+#endif
+    };
+
+    constexpr double epsilon  = 1e-12;
+    constexpr double infinity = std::numeric_limits<double>::infinity();
+    constexpr uint32_t file_magic      = 0x42505731; // BPW1
+    constexpr uint64_t arbitrary_magic = 0xeabada55cafed00d;
+    const char * func = __func__;
+
+    auto tensor_bytes = [](const ggml_tensor * t, const ggml_type typ) -> size_t {
+        const int64_t n_per_row = t->ne[0];
+        const size_t row_sz = ggml_row_size(typ, n_per_row);
+        return (size_t)ggml_nrows(t) * row_sz;
+    };
+
+    auto tensor_bpw = [&](const ggml_tensor * t, const ggml_type typ) -> double {
+        const size_t bytes = tensor_bytes(t, typ);
+        return (double)bytes * 8.0 / (double)ggml_nelements(t);
+    };
+
+    auto is_compatible = [](const ggml_tensor * t, const ggml_type typ) -> bool {
+        const int64_t blck = ggml_blck_size(typ);
+        return blck <= 1 || (t->ne[0] % blck) == 0;
+    };
+
+    auto is_iq = [](const enum ggml_type t) {
+        switch (t) {
+            case GGML_TYPE_IQ1_S:
+            case GGML_TYPE_IQ1_M:
+            case GGML_TYPE_IQ2_XXS:
+            case GGML_TYPE_IQ2_XS:
+            case GGML_TYPE_IQ2_S:
+            case GGML_TYPE_IQ3_XXS:
+            case GGML_TYPE_IQ3_S:
+            case GGML_TYPE_IQ4_NL:
+            case GGML_TYPE_IQ4_XS:
+                return true;
+            default:
+                return false;
+        }
+    };
+
+    auto make_compatible = [&](const ggml_tensor * t, const ggml_type typ) -> ggml_type {
+        if (is_compatible(t, typ)) { return typ; }
+        const ggml_type fb = fallback_type(typ);
+        return is_compatible(t, fb) ? fb : GGML_TYPE_F16;
+    };
+
+    auto can_quantize = [&](const ggml_tensor * t) -> bool {
+        if (ggml_n_dims(t) < 2) { return false; } // skip 1D tensors
+        return is_quantizable(ggml_get_name(t), model.arch, params);
+    };
+
+    // Saved state per tensor
+    struct saved_info {
+        std::vector<candidate_types> candidate;
+        int choice = -1;
+        float min_bpw = 0.0f;
+        float max_bpw = 0.0f;
+        size_t n_elements = 0;
+    };
+
+    auto djb2_hash = [&](const uint8_t * data, const size_t n) -> uint64_t {
+        uint64_t h = 5381;
+        for (size_t i = 0; i < n; ++i) {
+            h = (h << 5) + h + data[i];
+        }
+        return h ? h : arbitrary_magic;
+    };
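+
+    // Together with djb2_hash above, the next helper derives a stable id for the
+    // model from its GGUF metadata blob; the 64-bit hash names the checkpoint
+    // file, so resuming with a state file written for a different model is
+    // detected and the stale state is ignored.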
+    auto metadata_id = [&](const gguf_context * ctx) -> uint64_t {
+        const size_t sz = gguf_get_meta_size(ctx);
+        std::vector<uint8_t> buf(sz);
+        gguf_get_meta_data(ctx, buf.data());
+        return djb2_hash(buf.data(), buf.size());
+    };
+
+    char hex[17];
+    const uint64_t model_id = metadata_id(ml.meta.get());
+    std::snprintf(hex, sizeof(hex), "%016" PRIx64, (uint64_t)model_id);
+    std::string checkpoint_file = ml.arch_name + "-" + std::string(hex) + ".bpw_state";
+    if (params->keep_bpw_state && params->bpw_state) {
+        const auto * filename = static_cast<const char *>(params->bpw_state);
+        std::ifstream ifs(filename, std::ios::binary);
+        if (ifs.good()) {
+            checkpoint_file = std::string(filename);
+        } else {
+            std::ofstream ofs(filename, std::ios::binary | std::ios::app);
+            if (ofs.is_open()) {
+                checkpoint_file = std::string(filename);
+                ofs.close();
+                std::remove(checkpoint_file.c_str());
+            } else {
+                LLAMA_LOG_WARN("%s: %s is not a valid file name. Using %s instead\n", func, filename, checkpoint_file.c_str());
+            }
+        }
+    }
+
+    auto save_bpw_state = [&](const std::vector<tensor_info> & all_vec) {
+        const std::string tmp = checkpoint_file + ".tmp";
+        std::ofstream ofs(tmp, std::ios::binary | std::ios::trunc);
+        if (!ofs) { return; }
+        ofs.write((const char *)&file_magic, sizeof(file_magic));
+        ofs.write((const char *)&model_id, sizeof(model_id));
+        const uint64_t n = all_vec.size();
+        ofs.write((const char *)&n, sizeof(n));
+        for (const auto & ti : all_vec) {
+            const std::string name = ggml_get_name(ti.w->tensor);
+            const auto len = (uint32_t)name.size();
+            ofs.write((const char *)&len, sizeof(len));
+            ofs.write(name.data(), len);
+
+            const uint64_t cn = ti.candidate.size();
+            ofs.write((const char *)&cn, sizeof(cn));
+            ofs.write((const char *)&ti.choice, sizeof(ti.choice));
+            ofs.write((const char *)&ti.min_bpw, sizeof(ti.min_bpw));
+            ofs.write((const char *)&ti.max_bpw, sizeof(ti.max_bpw));
+            const uint64_t ne = ti.n_elements;
+            ofs.write((const char *)&ne, sizeof(ne));
+
+            for (const auto & c : ti.candidate) {
+                const int32_t t = c.type;
+                const uint64_t b = c.bytes;
+                ofs.write((const char *)&t, sizeof(t));
+                ofs.write((const char *)&c.bpw, sizeof(c.bpw));
+                ofs.write((const char *)&b, sizeof(b));
+                ofs.write((const char *)&c.error, sizeof(c.error));
+            }
+        }
+
+        ofs.close();
+        std::remove(checkpoint_file.c_str());
+        std::rename(tmp.c_str(), checkpoint_file.c_str());
+        LLAMA_LOG_INFO("%s: saved progress for %lu tensors to %s\n", func, all_vec.size(), checkpoint_file.c_str());
+    };
+
+    auto load_bpw_state = [&]() -> std::unordered_map<std::string, saved_info> {
+        std::unordered_map<std::string, saved_info> out;
+        std::ifstream ifs(checkpoint_file, std::ios::binary);
+        if (!ifs) { return out; }
+
+        uint32_t magic = 0;
+        uint64_t id = 0;
+        ifs.read((char *)&magic, sizeof(magic));
+        ifs.read((char *)&id, sizeof(id));
+        if (magic != file_magic) {
+            LLAMA_LOG_WARN("%s: invalid resume file, ignoring: %s\n", func, checkpoint_file.c_str());
+            return out;
+        }
+        if (id != model_id) {
+            LLAMA_LOG_WARN("%s: model ID mismatch, ignoring: %s\n", func, checkpoint_file.c_str());
+            return out;
+        }
+
+        LLAMA_LOG_INFO("%s: state file found, resuming tensor quantization\n", func);
+
+        uint64_t n = 0;
+        ifs.read((char *)&n, sizeof(n));
+        for (uint64_t i = 0; i < n; ++i) {
+            uint32_t len = 0;
+            ifs.read((char *)&len, sizeof(len));
+            std::string name(len, '\0');
+            ifs.read(name.data(), len);
+
+            uint64_t cn = 0;
+            ifs.read((char *)&cn, sizeof(cn));
+
+            saved_info si;
+            ifs.read((char *)&si.choice, sizeof(si.choice));
+            ifs.read((char *)&si.min_bpw, sizeof(si.min_bpw));
+            ifs.read((char *)&si.max_bpw, sizeof(si.max_bpw));
+            uint64_t ne = 0;
+            ifs.read((char *)&ne, sizeof(ne));
+            si.n_elements = (size_t)ne;
+
+            si.candidate.resize(cn);
+            for (auto & s : si.candidate) {
+                int32_t t = 0;
+                uint64_t b = 0;
+                ifs.read((char *)&t, sizeof(t));
+                s.type = (ggml_type)t;
+                ifs.read((char *)&s.bpw, sizeof(s.bpw));
+                ifs.read((char *)&b, sizeof(b));
+                s.bytes = (size_t)b;
+                ifs.read((char *)&s.error, sizeof(s.error));
+            }
+
+            out.emplace(std::move(name), std::move(si));
+        }
+
+        LLAMA_LOG_INFO("%s: loaded bpw state for %lu tensors from %s\n", func, out.size(), checkpoint_file.c_str());
+        return out;
+    };
+
+    auto delete_bpw_state = [&] {
+        std::ifstream ifs(checkpoint_file);
+        if (ifs.good() && !params->keep_bpw_state) {
+            LLAMA_LOG_INFO("%s: deleting %s\n", func, checkpoint_file.c_str());
+            std::remove(checkpoint_file.c_str());
+        }
+    };
+
+    auto check_signal_handler = [&](const std::vector<tensor_info> & all_vec) {
+        if (bpw_stop.load(std::memory_order_relaxed)) {
+            LLAMA_LOG_INFO("\n%s: saving progress for %lu tensors to %s\n", func, all_vec.size(), checkpoint_file.c_str());
+            save_bpw_state(all_vec);
+            throw std::runtime_error("user interrupted the process");
+        }
+    };
+
+    // Estimate error for a given type using a sampled subset of rows
+    auto estimate_error = [&](const ggml_tensor * t,
+                              const ggml_type quant_type,
+                              const std::vector<float> & f32_sample,
+                              const std::vector<int64_t> & rows_sample,
+                              const float * values_sample,
+                              const float * activations_sample,
+                              std::vector<uint8_t> & quantized_buffer,
+                              std::vector<float> & dequantized_buffer,
+                              float tensor_bias_lambda,
+                              const float * slice_bias_lambda,
+                              double * out_mse = nullptr,
+                              double * out_proj = nullptr) -> double
+    {
+        const int64_t n_per_row = t->ne[0];
+        const int64_t nrows = t->ne[1];
+        const int64_t ne2 = t->ne[2] > 0 ? t->ne[2] : 1;
+        const size_t sample_elems = f32_sample.size();
+        const size_t sample_rows = n_per_row > 0 ? sample_elems / (size_t)n_per_row : 0;
+
+        if (sample_rows == 0) {
+            if (out_mse)  { *out_mse = 0.0; }
+            if (out_proj) { *out_proj = 0.0; }
+            return 0.0;
+        }
+
+        size_t expected_rows = 0;
+        for (int64_t s = 0; s < ne2; ++s) {
+            expected_rows += (size_t)rows_sample[s];
+        }
+
+        if (expected_rows != sample_rows) {
+            if (out_mse)  { *out_mse = infinity; }
+            if (out_proj) { *out_proj = 0.0; }
+            return infinity;
+        }
+
+        const size_t row_sz = ggml_row_size(quant_type, n_per_row);
+        const size_t buf_sz = row_sz * sample_rows;
+
+        if (quantized_buffer.size() < buf_sz) { quantized_buffer.resize(buf_sz); }
+        if (dequantized_buffer.size() < sample_elems) { dequantized_buffer.resize(sample_elems); }
+
+        const bool has_values      = values_sample != nullptr;
+        const bool has_activations = activations_sample != nullptr;
+
+        // Bias denominators per slice
+        std::vector<double> bias_denom(ne2, 0.0);
+        if (has_activations) {
+            for (int64_t s = 0; s < ne2; ++s) {
+                const float * v = has_values ? values_sample + s * n_per_row : nullptr;
+                const float * a = activations_sample + s * n_per_row;
+                double denom = 0.0;
+                for (int64_t j = 0; j < n_per_row; ++j) {
+                    const double w = v ? std::max(0.0f, v[j]) : 1.0;
+                    const double aj = a[j];
+                    denom += w * aj * aj;
+                }
+
+                bias_denom[s] = denom;
+            }
+        }
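+
+        // Note the error computed below is relative, not absolute: each row's
+        // weighted squared error is divided by that row's weighted squared norm,
+        // so e.g. a normalized MSE of 0.01 means roughly 1% of the row's energy
+        // was lost, independent of the tensor's scale.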
+        // Row squared norms (weighted if values present)
+        std::vector<double> row_sq_norm(sample_rows, 0.0);
+        {
+            size_t off = 0;
+            size_t ridx = 0;
+            for (int64_t s = 0; s < ne2; ++s) {
+                const int64_t rs = rows_sample[s];
+                if (rs == 0) { continue; }
+
+                const float * v = has_values ? values_sample + s * n_per_row : nullptr;
+                for (int64_t r = 0; r < rs; ++r, ++ridx) {
+                    const float * x = f32_sample.data() + off;
+                    double sum = 0.0;
+                    if (v) {
+                        for (int64_t j = 0; j < n_per_row; ++j) {
+                            const double w  = std::max(0.0f, v[j]);
+                            const double xx = x[j];
+                            sum += w * xx * xx;
+                        }
+                    } else {
+                        for (int64_t j = 0; j < n_per_row; ++j) {
+                            const double xx = x[j];
+                            sum += xx * xx;
+                        }
+                    }
+
+                    row_sq_norm[ridx] = sum;
+                    off += (size_t)n_per_row;
+                }
+            }
+        }
+
+        // Quantize per slice into quantized_buffer
+        {
+            size_t qoff = 0;
+            size_t foff = 0;
+            for (int64_t s = 0; s < ne2; ++s) {
+                const int64_t rs = rows_sample[s];
+                if (rs == 0) { continue; }
+
+                const float * v = has_values ? values_sample + s * n_per_row : nullptr;
+                (void)ggml_quantize_chunk(quant_type, f32_sample.data() + foff, quantized_buffer.data() + qoff, 0, rs, n_per_row, v);
+                qoff += row_sz * (size_t)rs;
+                foff += (size_t)rs * (size_t)n_per_row;
+            }
+        }
+
+        // Dequantize into dequantized_buffer
+        {
+            if (quant_type == GGML_TYPE_F16) {
+                for (size_t r = 0; r < sample_rows; ++r) {
+                    auto src = (const ggml_fp16_t *)(quantized_buffer.data() + r * row_sz);
+                    float * dst = dequantized_buffer.data() + r * (size_t)n_per_row;
+                    ggml_fp16_to_fp32_row(src, dst, (int)n_per_row);
+                }
+            } else if (quant_type == GGML_TYPE_BF16) {
+                for (size_t r = 0; r < sample_rows; ++r) {
+                    auto src = (const ggml_bf16_t *)(quantized_buffer.data() + r * row_sz);
+                    float * dst = dequantized_buffer.data() + r * (size_t)n_per_row;
+                    ggml_bf16_to_fp32_row(src, dst, (int)n_per_row);
+                }
+            } else {
+                const ggml_type_traits * traits = ggml_get_type_traits(quant_type);
+                if (!traits || !traits->to_float) {
+                    if (out_mse)  { *out_mse = infinity; }
+                    if (out_proj) { *out_proj = 0.0; }
+                    return infinity;
+                }
+                for (size_t r = 0; r < sample_rows; ++r) {
+                    const uint8_t * src = quantized_buffer.data() + r * row_sz;
+                    float * dst = dequantized_buffer.data() + r * (size_t)n_per_row;
+                    traits->to_float(src, dst, (int)n_per_row);
+                }
+            }
+        }
+
+        // Compute error per slice with trimmed aggregation
+        auto trimmed_mean = [](std::vector<double> & v) -> double {
+            const int64_t n = (int64_t)v.size();
+            if (n == 0) { return 0.0; }
+            double sum = std::accumulate(v.begin(), v.end(), 0.0);
+            if (n < 50) { return sum / (double)n; } // too few elements to trim
+            int64_t k = (int64_t) std::floor(0.025 * (double)n); // trim 5% (2.5% each side)
+            std::sort(v.begin(), v.end());
+            const auto num = (double)(n - 2 * k);
+            sum = std::accumulate(v.begin() + k, v.begin() + (n - k), 0.0);
+            return sum / std::max(1.0, num);
+        };
+
+        size_t off = 0;
+        size_t ridx = 0;
+        double total_mse  = 0.0;
+        double total_proj = 0.0;
+        double total_bias = 0.0;
+        for (int64_t s = 0; s < ne2; ++s) {
+            const int64_t rs = rows_sample[s];
+            if (rs == 0) { continue; }
+
+            const float * v = has_values ? values_sample + s * n_per_row : nullptr;
+            const float * a = has_activations ? activations_sample + s * n_per_row : nullptr;
+            const double denom_bias = has_activations ? bias_denom[s] : 0.0;
+            std::vector<double> row_mse_norm;
+            row_mse_norm.reserve(rs);
+            std::vector<double> row_proj_norm;
+            if (a) { row_proj_norm.reserve(rs); }
+
+            for (int64_t r = 0; r < rs; ++r, ++ridx) {
+                const float * x = f32_sample.data() + off;
+                const float * y = dequantized_buffer.data() + off;
+                double w_mse    = 0.0;
+                double bias_num = 0.0;
+                for (int64_t j = 0; j < n_per_row; ++j) {
+                    const double wj = v ? std::max(0.0f, v[j]) : 1.0;
+                    const double e = y[j] - x[j];
+                    w_mse += wj * e * e;
+                    if (a) { bias_num += wj * e * a[j]; }
+                }
+
+                const double denom_x = row_sq_norm[ridx];
+                const double m_norm  = w_mse / (denom_x + epsilon);
+                row_mse_norm.push_back(std::isfinite(m_norm) ? m_norm : infinity);
+
+                if (a) {
+                    double p_norm = 0.0;
+                    if (denom_bias > 0.0) {
+                        const double proj = bias_num * bias_num / (denom_bias + epsilon);
+                        p_norm = std::isfinite(proj) ? proj : 0.0;
+                    }
+
+                    row_proj_norm.push_back(p_norm);
+                }
+
+                off += (size_t)n_per_row;
+            }
+
+            const double slice_mse  = trimmed_mean(row_mse_norm) * (double)nrows;
+            const double slice_proj = a ? trimmed_mean(row_proj_norm) * (double)nrows : 0.0;
+
+            total_mse  += slice_mse;
+            total_proj += slice_proj;
+
+            const double bl = slice_bias_lambda ? (double)std::max(0.0f, slice_bias_lambda[s]) : (double)tensor_bias_lambda;
+            total_bias += bl * slice_proj;
+
+            if (!std::isfinite(total_mse) || !std::isfinite(total_proj) || !std::isfinite(total_bias)) {
+                if (out_mse)  { *out_mse = infinity; }
+                if (out_proj) { *out_proj = 0.0; }
+                return infinity;
+            }
+        }
+
+        if (out_mse)  { *out_mse = total_mse; }
+        if (out_proj) { *out_proj = total_proj; }
+
+        const double total_err = total_mse + total_bias;
+        return std::isfinite(total_err) ? total_err : infinity;
+    };
+
+    // Returns lambda per slice or 0.0 if no activations
+    auto estimate_lambda = [&](const float * values, const float * activations, const int64_t n_per_row, const int64_t ne2) -> std::vector<float> {
+        const int64_t ns = std::max<int64_t>(1, ne2);
+        std::vector<float> lambdas(ns, 0.0f);
+        if (!activations) { return lambdas; }
+
+        for (int64_t s = 0; s < ns; ++s) {
+            const float * v = values ? values + s * n_per_row : nullptr;
+            const float * a = activations + s * n_per_row;
+            double s1 = 0.0;
+            double s2 = 0.0;
+            for (int64_t j = 0; j < n_per_row; ++j) {
+                const double w  = v ? std::max(0.0f, v[j]) : 1.0;
+                const double aw = std::sqrt(w) * a[j];
+                const double z  = aw * aw;
+                s1 += z;
+                s2 += z * z;
+            }
+
+            float l = 0.0f;
+            if (s1 > 0.0) {
+                const auto n = (double)n_per_row;
+                const double c = std::max(0.0, s2 / (s1 * s1 + epsilon) - 1.0 / n);
+                l = (float)std::clamp(12.0 * (c / (c + 1.0)), 0.0, 16.0);
+            }
+
+            lambdas[(size_t)s] = l;
+        }
+
+        return lambdas;
+    };
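+
+    // Intuition for the lambda heuristic above: c estimates how concentrated the
+    // activation energy is across a row (c is ~0 when all z_j are equal, since
+    // s2/s1^2 then approaches 1/n), and lambda = 12*c/(c+1) clamped to [0, 16],
+    // so strongly peaked activation patterns weight the bias term more heavily.
+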
+    const auto bpw_data = load_bpw_state();
+
+    // Parallelize tensor processing - courtesy of https://github.com/ddh0
+    auto process_tensor = [&](const llama_model_loader::llama_tensor_weight * tw,
+                              std::vector<no_init<uint8_t>> & thread_local_buffer,
+                              std::mutex & loader_mutex,
+                              std::mutex & log_mutex) -> std::optional<tensor_info>
+    {
+        ggml_tensor * tensor = tw->tensor;
+        const std::string name = ggml_get_name(tensor);
+        if (bpw_stop.load(std::memory_order_relaxed)) {
+            return std::nullopt;
+        }
+
+        // check for pre-computed results from a checkpoint file.
+        auto it_saved = bpw_data.find(name);
+        if (it_saved != bpw_data.end()) {
+            tensor_info info;
+            info.w = tw;
+            info.candidate  = it_saved->second.candidate;
+            info.choice     = it_saved->second.choice;
+            info.min_bpw    = it_saved->second.min_bpw;
+            info.max_bpw    = it_saved->second.max_bpw;
+            info.n_elements = it_saved->second.n_elements ? it_saved->second.n_elements : (size_t)ggml_nelements(tensor);
+            return info;
+        }
+        {
+            std::lock_guard<std::mutex> lock(log_mutex);
+            LLAMA_LOG_INFO("\t%s: - processing tensor %45s \t(%12" PRId64 " elements)\n", func, name.c_str(), ggml_nelements(tensor));
+        }
+
+        if (!ml.use_mmap) {
+            if (thread_local_buffer.size() < ggml_nbytes(tensor)) { thread_local_buffer.resize(ggml_nbytes(tensor)); }
+            tensor->data = thread_local_buffer.data();
+        }
+        {
+            std::lock_guard<std::mutex> lock(loader_mutex);
+            ml.load_data_for(tensor);
+        }
+
+        // Dequantize sampled rows into f32_sample
+        const int64_t n_per_row   = tensor->ne[0];
+        const int64_t nrows_total = tensor->ne[1];
+        const int64_t ne2         = tensor->ne[2] > 0 ? tensor->ne[2] : 1;
+
+        // Compute rows based on tensor shape and slice count
+        auto sample_rows = [](const int64_t n, const int64_t rows, const int64_t n2, const bool has_acts) -> int64_t {
+            const double tensor_budget = has_acts ? 1 * 1024 * 1024 : 0.5 * 1024 * 1024;
+            const double scale_rows    = std::clamp(std::sqrt(std::max(1.0, (double)rows) / 4096.0), 0.5, 2.0); // favour more rows for large tensors
+            const double slice_budget  = tensor_budget * scale_rows / std::max<int64_t>(1, n2);
+            const int64_t min_rows = has_acts ? 128 : 64;
+            constexpr int64_t max_rows = 4096; // row limit to avoid excessive memory use
+            int64_t total_rows = std::llround(slice_budget / std::max<int64_t>(1, n));
+            total_rows = std::max(min_rows, std::min(total_rows, std::min(rows, max_rows)));
+            if (rows <= min_rows * 2) { total_rows = rows; }
+            return total_rows;
+        };
+
+        const int64_t rows_sample_per_expert = sample_rows(n_per_row, nrows_total, ne2, activations_data != nullptr);
+        std::vector<float> f32_sample;
+        f32_sample.reserve((size_t)ne2 * (size_t)std::min(nrows_total, rows_sample_per_expert) * (size_t)n_per_row);
+        std::vector<int64_t> rows_sample(ne2, 0);
+        const ggml_type src_type = tensor->type;
+        const ggml_type_traits * src_traits = ggml_get_type_traits(src_type);
+        const bool src_is_quant = ggml_is_quantized(src_type);
+        const size_t src_row_sz = ggml_row_size(src_type, n_per_row);
+
+        // Convert a single row to fp32
+        auto row_to_fp32 = [&](const uint8_t * src, float * dst) {
+            const ggml_type t = src_type;
+            if (t == GGML_TYPE_F32) {
+                std::memcpy(dst, src, sizeof(float) * (size_t)n_per_row);
+                return;
+            }
+            if (t == GGML_TYPE_F16) {
+                ggml_fp16_to_fp32_row((const ggml_fp16_t *)src, dst, (int)n_per_row);
+                return;
+            }
+            if (t == GGML_TYPE_BF16) {
+                ggml_bf16_to_fp32_row((const ggml_bf16_t *)src, dst, (int)n_per_row);
+                return;
+            }
+            if (src_is_quant) {
+                GGML_ASSERT(src_traits && src_traits->to_float);
+                src_traits->to_float(src, dst, (int)n_per_row);
+                return;
+            }
+
+            throw std::runtime_error(format("unsupported src type %s for sampling", ggml_type_name(t)));
+        };
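+
+        // Strided row sampling with a per-slice offset: e.g. 4096 rows and a
+        // 512-row budget give stride 8, and an rng seeded from the tensor name
+        // and slice index picks which of the 8 cosets is read, so the sample is
+        // spread across the tensor yet repeatable between runs.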
+        // Sample rows randomly per slice
+        {
+            f32_sample.clear();
+            std::vector<float> row_buffer(n_per_row);
+            for (int64_t slice = 0; slice < ne2; ++slice) {
+                std::mt19937 rng(std::hash<std::string>{}(name) ^ arbitrary_magic ^ slice);
+                const int64_t rows_sample_max = std::max<int64_t>(1, std::min(nrows_total, rows_sample_per_expert));
+                const int64_t stride = std::max<int64_t>(1, nrows_total / rows_sample_max);
+                int64_t offset = 0;
+                if (stride > 1) {
+                    std::uniform_int_distribution<int64_t> dist(0, stride - 1);
+                    offset = dist(rng);
+                }
+
+                int64_t current = 0;
+                for (int64_t r = offset; r < nrows_total && current < rows_sample_max; r += stride) {
+                    const uint8_t * src_row = (const uint8_t *)tensor->data + slice * (src_row_sz * nrows_total) + r * src_row_sz;
+                    if (src_type == GGML_TYPE_F32) {
+                        const auto * src_f32 = (const float *)src_row;
+                        f32_sample.insert(f32_sample.end(), src_f32, src_f32 + n_per_row);
+                    } else {
+                        row_to_fp32(src_row, row_buffer.data());
+                        f32_sample.insert(f32_sample.end(), row_buffer.begin(), row_buffer.end());
+                    }
+
+                    ++current;
+                }
+
+                rows_sample[slice] = current;
+            }
+        }
+
+        auto side_data = [&](const std::unordered_map<std::string, std::vector<float>> * m, const std::string & tensor_name) {
+            if (!m) { return std::pair<const float *, size_t>{nullptr, 0}; }
+
+            const std::string key = remap_imatrix(tensor_name, mapped);
+            const auto it = m->find(key);
+            return it == m->end() ? std::pair<const float *, size_t>{nullptr, 0} : std::pair<const float *, size_t>{ it->second.data(), it->second.size() };
+        };
+
+        // Copy this tensor's side data (values and activations), or broadcast it to all slices
+        auto copy_or_broadcast = [&](const float * src, size_t src_sz, std::vector<float> & dst) {
+            dst.clear();
+            if (!src || src_sz == 0) { return; }
+
+            const size_t want = (size_t)ne2 * (size_t)n_per_row;
+            if (src_sz == want) {
+                dst.assign(src, src + want);
+                return;
+            }
+            if (src_sz == (size_t)n_per_row) {
+                dst.resize(want);
+                for (int64_t s = 0; s < ne2; ++s) {
+                    std::memcpy(dst.data() + s * n_per_row, src, n_per_row * sizeof(float));
+                }
+                return;
+            }
+
+            std::lock_guard<std::mutex> lock(log_mutex);
+            LLAMA_LOG_WARN("%s: side data size mismatch for %s: got %zu, expected %zu or %zu; ignoring\n", func, name.c_str(), src_sz, (size_t)n_per_row, want);
+        };
+
+        const auto [values_all, values_sz]           = side_data(values_data, name);
+        const auto [activations_all, activations_sz] = side_data(activations_data, name);
+        std::vector<float> values_sample;
+        std::vector<float> activations_sample;
+        if (values_all)      { copy_or_broadcast(values_all, values_sz, values_sample); }
+        if (activations_all) { copy_or_broadcast(activations_all, activations_sz, activations_sample); }
+
+        tensor_info info;
+        info.w = tw;
+        info.n_elements = ggml_nelements(tensor);
+        size_t total_sampled_rows = f32_sample.size() / n_per_row;
+
+        // Build list of candidate types first (compatible ones)
+        const bool has_valid_imatrix = !values_sample.empty() && values_sample.size() == (size_t)ne2 * (size_t)n_per_row;
+        size_t max_row_sz = 0;
+        const ggml_type * base_arr = quant_types;
+        const size_t base_sz = std::size(quant_types);
+        std::vector<ggml_type> compatible_candidates;
+        compatible_candidates.reserve(base_sz);
+
+        for (size_t i = 0; i < base_sz; ++i) {
+            ggml_type ts_type = base_arr[i];
+            if (is_iq(ts_type) && !has_valid_imatrix) {
+                std::lock_guard<std::mutex> lock(log_mutex);
+                LLAMA_LOG_WARN("\t%s: skipping %s for %s, no or mismatched imatrix\n", func, ggml_type_name(ts_type), name.c_str());
+                continue;
+            }
+
+            ggml_type tt = make_compatible(tensor, ts_type);
+            if (!is_compatible(tensor, tt)) { continue; }
+            compatible_candidates.push_back(tt);
+            max_row_sz = std::max(max_row_sz, ggml_row_size(tt, n_per_row));
+        }
+
+        std::sort(compatible_candidates.begin(), compatible_candidates.end());
+        compatible_candidates.erase(std::unique(compatible_candidates.begin(), compatible_candidates.end()), compatible_candidates.end());
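+
+        // Note: make_compatible() can map two requested types onto the same
+        // fallback when the tensor's row length is not a multiple of a type's
+        // block size (e.g. Q2_K and IQ3_XXS both degrading to IQ4_NL), which is
+        // why the list is sorted and deduplicated above.
+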
+        // Adjusts the trade-off between systematic bias (introduced by block-wise scaling) and MSE.
+        // Larger values favour quantisation types that produce smaller bias even if the MSE is slightly bigger
+        float tensor_lambda = 0.0f;
+        std::vector<float> lambdas;
+        const float * values      = values_sample.empty() ? nullptr : values_sample.data();
+        const float * activations = activations_sample.empty() ? nullptr : activations_sample.data();
+        double acc = 0.0;
+        int ns = 0;
+        lambdas = estimate_lambda(values, activations, n_per_row, ne2);
+        for (float l : lambdas) { acc += l; ++ns; }
+        tensor_lambda = ns ? (float)(acc / ns) : 0.0f;
+
+        // Evaluate candidates
+        std::vector<candidate_types> eval_candidates(compatible_candidates.size());
+        std::vector<uint8_t> quantized_buffer(max_row_sz * total_sampled_rows);
+        std::vector<float> dequantized_buffer(f32_sample.size());
+        const float * slice_lambda = lambdas.empty() ? nullptr : lambdas.data();
+        for (size_t i = 0; i < compatible_candidates.size(); ++i) {
+            if (bpw_stop.load(std::memory_order_relaxed)) { return std::nullopt; }
+
+            const ggml_type tensor_type = compatible_candidates[i];
+            const auto bpw = (float)tensor_bpw(tensor, tensor_type);
+            const size_t bytes = tensor_bytes(tensor, tensor_type);
+            double mse  = 0.0;
+            double proj = 0.0;
+            const auto err = estimate_error(tensor, tensor_type, f32_sample, rows_sample, values, activations,
+                                            quantized_buffer, dequantized_buffer, tensor_lambda, slice_lambda, &mse, &proj);
+            eval_candidates[i] = candidate_types{ tensor_type, bpw, bytes, err, mse, proj };
+        }
+
+        if (bpw_stop.load(std::memory_order_relaxed)) { return std::nullopt; }
+
+        // Check if biasing is needed
+        bool bias_needed = false;
+        if (!lambdas.empty()) {
+            int min_mse  = -1;
+            int min_bias = -1;
+            double best_mse = std::numeric_limits<double>::infinity();
+            double best_err = std::numeric_limits<double>::infinity();
+            for (int i = 0; i < (int)eval_candidates.size(); ++i) {
+                const auto & c = eval_candidates[i];
+                if (c.bytes == 0) { continue; }
+                if (c.mse < best_mse) {
+                    best_mse = c.mse;
+                    min_mse  = i;
+                }
+                if (c.error < best_err) {
+                    best_err = c.error;
+                    min_bias = i;
+                }
+            }
+
+            if (min_mse != min_bias) {
+                bias_needed = true;
+            } else {
+                double max_rel_bias = 0.0;
+                for (const auto & c : eval_candidates) {
+                    if (c.bytes == 0) { continue; }
+                    const double mse = std::max(c.mse, epsilon);
+                    const double bias_term = std::max(0.0, c.error - c.mse);
+                    max_rel_bias = std::max(bias_term / mse, max_rel_bias);
+                }
+
+                bias_needed = max_rel_bias >= 0.5; // >= 50% of MSE?
+            }
+        }
+
+        for (auto & c : eval_candidates) {
+            if (c.bytes == 0) { continue; }
+            const double final_err = bias_needed ? c.error : c.mse;
+            info.candidate.push_back(candidate_types{ c.type, c.bpw, c.bytes, final_err, c.mse, c.proj });
+        }
+
+        if (info.candidate.empty()) {
+            // As a last resort, keep original type
+            float bpw = ggml_nbytes(tensor) * 8.0f / info.n_elements;
+            info.candidate.push_back(candidate_types{ tensor->type, bpw, ggml_nbytes(tensor), 0.0 });
+        }
+
+        // Keep only the pareto-optimal candidates and enforce convexity in the (bytes, error) curve
+        auto pareto_convex = [&](std::vector<candidate_types> & candidates) {
+            if (candidates.empty()) { return; }
+
+            std::sort(candidates.begin(), candidates.end(), [](const candidate_types & a, const candidate_types & b) {
+                if (a.bytes != b.bytes) { return a.bytes < b.bytes; }
+                return a.error < b.error;
+            });
+            candidates.erase(std::unique(candidates.begin(), candidates.end(), [](const candidate_types & a, const candidate_types & b) {
+                return a.bytes == b.bytes;
+            }), candidates.end());
+            std::vector<candidate_types> pareto;
+            pareto.reserve(candidates.size());
+            double best_err = infinity;
+            for (const auto & c : candidates) {
+                if (c.error < best_err) {
+                    best_err = c.error;
+                    pareto.push_back(c);
+                }
+            }
+            candidates.swap(pareto);
+            if (candidates.size() < 3) { return; } // need at least 3 points to do a convex hull
+
+            // Convex hull (lower envelope)
+            auto cross_product = [](const candidate_types & h0, const candidate_types & h1, const candidate_types & p) -> double {
+                const double dx1 = (double)h1.bytes - (double)h0.bytes;
+                const double dy1 = h1.error - h0.error;
+                const double dx2 = (double)p.bytes - (double)h0.bytes;
+                const double dy2 = p.error - h0.error;
+                return dx1 * dy2 - dx2 * dy1;
+            };
+            std::vector<candidate_types> hull;
+            hull.reserve(candidates.size());
+            for (const auto & c : candidates) {
+                while (hull.size() >= 2) {
+                    if (cross_product(hull[hull.size() - 2], hull[hull.size() - 1], c) <= epsilon) {
+                        hull.pop_back();
+                    } else {
+                        break;
+                    }
+                }
+
+                hull.push_back(c);
+            }
+
+            candidates.swap(hull);
+        };
+
+        pareto_convex(info.candidate);
+
+        // Initialize choice at the smallest bpw candidate
+        info.choice  = 0;
+        info.min_bpw = info.candidate.front().bpw;
+        info.max_bpw = info.candidate.back().bpw;
+
+        return info;
+    };
+
+    std::vector<tensor_info> all; // this vector will be populated by the parallel workers
+    {
+        std::atomic<size_t> tensor_idx{0}; // shared work queue index for all threads
+        const size_t tensors_to_process = tensors.size();
+        std::mutex loader_mutex;
+        std::mutex log_mutex;
+        std::mutex results_mutex;
+        std::vector<std::thread> workers;
+        int threads_to_spawn = std::max(1, std::min(nthread, (int)tensors_to_process));
+
+        for (int i = 0; i < threads_to_spawn; ++i) {
+            workers.emplace_back([&]() {
+                std::vector<no_init<uint8_t>> thread_local_buffer;
+                while (true) {
+                    const size_t current_idx = tensor_idx.fetch_add(1);
+                    if (current_idx >= tensors_to_process) { break; }
+                    const auto * tw = tensors[current_idx];
+                    if (!can_quantize(tw->tensor)) { continue; }
+                    // Execute the main processing logic for this tensor
+                    std::optional<tensor_info> result_info = process_tensor(tw, thread_local_buffer, loader_mutex, log_mutex);
+                    if (result_info) {
+                        std::lock_guard<std::mutex> lock(results_mutex);
+                        all.push_back(std::move(*result_info));
+                    }
+                }
+            });
+        }
+
+        for (auto & w : workers) { w.join(); }
+    }
+
+    check_signal_handler(all);
+    if (params->keep_bpw_state) { save_bpw_state(all); }
+
+    if (all.empty()) { return {}; }
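+
+    // Budget arithmetic, roughly: with 8e9 total weights and --target-bpw 4.0
+    // the file target is 8e9 * 4.0 / 8 = 4.0 GB; the bytes of tensors that never
+    // get quantized are subtracted first, and what remains is the byte budget
+    // the per-tensor type selection below has to hit.
+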
+    // Compute total elements across all tensors and bytes for non-quantizable tensors
+    size_t nq_elements = 0;
+    size_t nq_bytes    = 0;
+    for (const auto * it : tensors) {
+        const ggml_tensor * tensor = it->tensor;
+        const std::string name = ggml_get_name(tensor);
+        nq_elements += (size_t)ggml_nelements(tensor);
+        if (!can_quantize(tensor)) { nq_bytes += ggml_nbytes(tensor); }
+    }
+
+    auto total_bytes = [&]() -> size_t {
+        size_t tb = 0;
+        for (const auto & ti : all) {
+            tb += ti.candidate[ti.choice].bytes;
+        }
+
+        return tb;
+    };
+
+    size_t q_elements = 0;
+    size_t min_bytes  = 0;
+    size_t max_bytes  = 0;
+    for (const auto & ti : all) {
+        q_elements += (size_t)ti.n_elements;
+        min_bytes  += ti.candidate.front().bytes; // smallest candidate per tensor
+        max_bytes  += ti.candidate.back().bytes;  // largest candidate per tensor
+    }
+
+    if (q_elements == 0) { return {}; }
+
+    const double target_bpw = params->target_bpw;
+    size_t target_total_bytes = std::llround(target_bpw * (double)nq_elements / 8.0);
+    size_t budget_bytes = target_total_bytes >= nq_bytes ? target_total_bytes - nq_bytes : min_bytes;
+
+    auto emit_overrides = [&]() -> std::unordered_map<std::string, ggml_type> {
+        std::unordered_map<std::string, ggml_type> overrides;
+        LLAMA_LOG_INFO("%s: - estimated tensor quantization mix:\n", func);
+        for (const auto & ti : all) {
+            LLAMA_LOG_INFO("\t%s: %45s - \t%8s, \t%1.4f bpw,\terror: %.4f\n",
+                func, ggml_get_name(ti.w->tensor), ggml_type_name(ti.candidate[ti.choice].type), ti.candidate[ti.choice].bpw, ti.candidate[ti.choice].error);
+            overrides[ggml_get_name(ti.w->tensor)] = ti.candidate[ti.choice].type;
+        }
+
+        return overrides;
+    };
+
+    if (budget_bytes <= min_bytes) {
+        for (auto & ti : all) { ti.choice = 0; }
+        return emit_overrides();
+    }
+    if (budget_bytes >= max_bytes) {
+        for (auto & ti : all) { ti.choice = (int)ti.candidate.size() - 1; }
+        return emit_overrides();
+    }
+
+    // Certain tensors have a higher impact on model quality, so we apply a lower penalty to them
+    auto is_important = [&](const std::string & tensor_name) -> bool {
+        bool important = tensor_name == "output.weight";
+        if (!important && !params->no_importance) {
+            important = tensor_name.find(".attn_v.weight") != std::string::npos ||
+                        tensor_name.find(".time_mix_value.weight") != std::string::npos ||
+                        tensor_name.find(".ffn_down.weight") != std::string::npos ||
+                        tensor_name.find(".ffn_down_exps.weight") != std::string::npos ||
+                        tensor_name.find(".attn_output.weight") != std::string::npos ||
+                        tensor_name.find(".time_mix_output.weight") != std::string::npos ||
+                        tensor_name.find(".attn_o.weight") != std::string::npos;
+        }
+
+        return important;
+    };
+
+    // Lagrangian relaxation to minimise error subject to a bpw target constraint
+    auto lagrange_penalty = [&](const double mu, std::vector<int> & choice, size_t & bytes, double & err) {
+        choice.resize(all.size());
+        bytes = 0;
+        err   = 0.0;
+        for (size_t i = 0; i < all.size(); ++i) {
+            const auto & candidate = all[i].candidate;
+            const std::string tensor_name = ggml_get_name(all[i].w->tensor);
+            double effective_mu = mu;
+            if (is_important(tensor_name)) { effective_mu *= 0.1; } // important tensors get 10x lower penalty
+
+            int best_j = 0;
+            double best_val = infinity;
+            for (int j = 0; j < (int)candidate.size(); ++j) {
+                const double bits = (double)candidate[j].bytes * 8.0;
+                const double val  = candidate[j].error + effective_mu * bits;
+                if (val < best_val - epsilon || (std::abs(val - best_val) <= epsilon && candidate[j].bytes < candidate[best_j].bytes)) {
+                    best_val = val;
+                    best_j   = j;
+                }
+            }
+
+            choice[i] = best_j;
+            bytes += candidate[best_j].bytes;
+            err   += candidate[best_j].error;
+        }
+    };
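+
+    // The relaxation minimizes err_i(j) + mu * bits_i(j) independently per tensor.
+    // At mu = 0 every tensor takes its lowest-error (largest) candidate; as mu
+    // grows, the chosen sizes shrink monotonically, so bracketing mu and then
+    // bisecting is enough to land right next to the byte budget.
+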
+    size_t bytes_lo  = 0;
+    size_t bytes_hi  = 0;
+    size_t bytes_mid = 0;
+    double mu_lo = 0.0;
+    double mu_hi = 1.0;
+    double err_lo  = 0.0;
+    double err_hi  = 0.0;
+    double err_mid = 0.0;
+    std::vector<int> choice_lo;
+    std::vector<int> choice_hi;
+    std::vector<int> choice_mid;
+    std::vector<int> best_under_choice;
+    std::vector<int> best_over_choice;
+
+    lagrange_penalty(mu_lo, choice_lo, bytes_lo, err_lo);
+
+    // increase mu until we get under budget or hit a safety cap
+    {
+        int expand = 0;
+        size_t prev_bytes_hi = std::numeric_limits<size_t>::max();
+        while (true) {
+            lagrange_penalty(mu_hi, choice_hi, bytes_hi, err_hi);
+            if (bytes_hi <= budget_bytes)  { break; }
+            if (bytes_hi >= prev_bytes_hi) { break; }
+            prev_bytes_hi = bytes_hi;
+
+            mu_hi *= 2.0; // double the penalty multiplier to reduce tensor sizes
+            if (++expand > 60) { break; } // safety cap to prevent an infinite loop
+        }
+    }
+
+    double best_under_gap = infinity;
+    double best_over_gap  = infinity;
+    double best_under_err = infinity;
+    double best_over_err  = infinity;
+    for (int it = 0; it < 40; ++it) { // binary search iterations for optimal Lagrange multiplier (40 ≈ 1e-12 precision)
+        double mu = 0.5 * (mu_lo + mu_hi); // midpoint of current bounds
+        lagrange_penalty(mu, choice_mid, bytes_mid, err_mid);
+
+        const double gap = std::abs((double)bytes_mid - (double)budget_bytes);
+        if (bytes_mid > budget_bytes) {
+            // Too big, need stronger penalty
+            mu_lo = mu;
+            if (gap < best_over_gap - epsilon || (std::abs(gap - best_over_gap) <= epsilon && err_mid < best_over_err)) {
+                best_over_gap = gap;
+                best_over_err = err_mid;
+                best_over_choice = choice_mid;
+            }
+        } else {
+            // Under budget, good candidate
+            mu_hi = mu;
+            if (gap < best_under_gap - epsilon || (std::abs(gap - best_under_gap) <= epsilon && err_mid < best_under_err)) {
+                best_under_gap = gap;
+                best_under_err = err_mid;
+                best_under_choice = choice_mid;
+            }
+        }
+    }
+
+    if (!best_under_choice.empty()) {
+        for (size_t i = 0; i < all.size(); ++i) {
+            all[i].choice = best_under_choice[i];
+        }
+    } else if (!best_over_choice.empty()) {
+        for (size_t i = 0; i < all.size(); ++i) {
+            all[i].choice = best_over_choice[i];
+        }
+    } else {
+        // Pick whichever side we already have, or keep minimal
+        if (bytes_hi <= budget_bytes && !choice_hi.empty()) {
+            for (size_t i = 0; i < all.size(); ++i) {
+                all[i].choice = choice_hi[i];
+            }
+        } else {
+            for (auto & ti : all) {
+                ti.choice = 0;
+            }
+        }
+    }
+
+    // Spend any remaining budget with best upgrades that still fit (one pass)
+    {
+        auto cur_bytes = total_bytes();
+        while (true) {
+            int best_i = -1;
+            int best_j = -1;
+            double best_ratio = -1.0;
+            double best_gain  = -1.0;
+
+            for (int i = 0; i < (int)all.size(); ++i) {
+                const auto & ti = all[i];
+                const std::string tensor_name = ggml_get_name(ti.w->tensor);
+                int j = ti.choice + 1;
+                if (j >= (int)ti.candidate.size()) { continue; } // no upgrade available
+
+                size_t delta_bytes = ti.candidate[j].bytes - ti.candidate[ti.choice].bytes;
+                if (cur_bytes + delta_bytes > budget_bytes) { continue; } // won't fit in budget
+
+                double err_gain = std::max(0.0, ti.candidate[ti.choice].error - ti.candidate[j].error);
+                if (err_gain < epsilon) { continue; } // no error improvement
+
+                double ratio = err_gain / (double)delta_bytes; // error reduction per byte
+                if (is_important(tensor_name)) { ratio *= 5.0; } // important tensors get 5x boost
+
+                // For tie-breaking, prioritize the largest absolute error improvement.
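+                // (Two upgrades can tie on ratio, e.g. one extra quant step on two
+                // similarly sized tensors; preferring the larger raw err_gain makes
+                // this pass less sensitive to tensor iteration order.)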
+                if (ratio > best_ratio + epsilon || (std::abs(ratio - best_ratio) <= epsilon && err_gain > best_gain)) {
+                    best_ratio = ratio;
+                    best_gain  = err_gain;
+                    best_i = i;
+                    best_j = j;
+                }
+            }
+
+            if (best_i < 0) { break; } // no more upgrades within budget found
+
+            size_t upgrade_cost = all[best_i].candidate[best_j].bytes - all[best_i].candidate[all[best_i].choice].bytes;
+            all[best_i].choice = best_j;
+            cur_bytes += upgrade_cost;
+        }
+    }
+
+    delete_bpw_state(); // we're done, clear any checkpoint
+
+    return emit_overrides();
+}
+
 static void llama_model_quantize_impl(const std::string & fname_inp, const std::string & fname_out, const llama_model_quantize_params * params) {
     ggml_type default_type;
     llama_ftype ftype = params->ftype;
@@ -610,14 +1818,15 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::string & fname_out, const llama_model_quantize_params * params) {
     if (params->only_copy) {
         ftype = ml.ftype;
     }
-    const std::unordered_map<std::string, std::vector<float>> * imatrix_data = nullptr;
+    const std::unordered_map<std::string, std::vector<float>> * values_data      = nullptr;
+    const std::unordered_map<std::string, std::vector<float>> * activations_data = nullptr;
    if (params->imatrix) {
-        imatrix_data = static_cast<const std::unordered_map<std::string, std::vector<float>>*>(params->imatrix);
-        if (imatrix_data) {
-            LLAMA_LOG_INFO("================================ Have weights data with %d entries\n",int(imatrix_data->size()));
+        values_data = static_cast<const std::unordered_map<std::string, std::vector<float>>*>(params->imatrix);
+        if (values_data) {
+            LLAMA_LOG_INFO("================================ Have weights data with %d entries",int(values_data->size()));
             qs.has_imatrix = true;
             // check imatrix for nans or infs
-            for (const auto & kv : *imatrix_data) {
+            for (const auto & kv : *values_data) {
                 for (float f : kv.second) {
                     if (!std::isfinite(f)) {
                         throw std::runtime_error(format("imatrix contains non-finite value %f\n", f));
@@ -626,8 +1835,23 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::string & fname_out, const llama_model_quantize_params * params) {
                 }
            }
        }
    }
+    if (params->activations) {
+        activations_data = static_cast<const std::unordered_map<std::string, std::vector<float>>*>(params->activations);
+        if (activations_data) {
+            LLAMA_LOG_INFO(" and %d activations",int(activations_data->size()));
+            qs.has_activations = true;
+            // check activations for nans or infs
+            for (const auto & kv : *activations_data) {
+                for (float f : kv.second) {
+                    if (!std::isfinite(f)) {
+                        throw std::runtime_error(format("activations contain non-finite value %f\n", f));
+                    }
+                }
+            }
+        }
+    }
+    LLAMA_LOG_INFO("\n");
-    const size_t align = GGUF_DEFAULT_ALIGNMENT;
     gguf_context_ptr ctx_out { gguf_init_empty() };

     std::vector<uint32_t> prune_list = {};
@@ -756,6 +1980,27 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::string & fname_out, const llama_model_quantize_params * params) {
         }
     }

+    std::unordered_map<std::string, ggml_type> bpw_overrides = {};
+    if (params->target_bpw != -1.0f && !params->only_copy) {
+        if (params->imatrix) {
+            if (params->activations) {
+                LLAMA_LOG_INFO("%s: imatrix has activations, process will be more accurate\n", __func__);
+            } else {
+                LLAMA_LOG_INFO("%s: imatrix does not have activations, process may be less accurate\n", __func__);
+            }
+            if (params->no_importance) {
+                LLAMA_LOG_INFO("%s: distributing bpw budget equitably across all tensors\n", __func__);
+            } else {
+                LLAMA_LOG_INFO("%s: assigning more bpw budget to important tensors\n", __func__);
+            }
+            LLAMA_LOG_INFO("%s: computing tensor quantization mix to achieve %.4f bpw\n", __func__, params->target_bpw);
+
+            bpw_overrides = target_bpw_type(ml, model, tensors, mapped, values_data, activations_data, params, nthread);
+        } else {
+            LLAMA_LOG_WARN("%s: --target-bpw requires an imatrix but none was provided, option will be ignored\n", __func__);
+        }
+    }
+
     int cur_split = -1;
     std::ofstream fout;
     auto close_ofstream = [&]() {
@@ -788,6 +2033,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::string & fname_out, const llama_model_quantize_params * params) {
     const auto tn = LLM_TN(model.arch);
     new_ofstream(0);
     for (const auto * it : tensors) {
+        const size_t align = GGUF_DEFAULT_ALIGNMENT;
         const auto & weight = *it;
         ggml_tensor * tensor = weight.tensor;
         if (weight.idx != cur_split && params->keep_split) {
@@ -806,62 +2052,10 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::string & fname_out, const llama_model_quantize_params * params) {
         ml.load_data_for(tensor);

         LLAMA_LOG_INFO("[%4d/%4d] %36s - [%s], type = %6s, ",
-               ++idx, ml.n_tensors,
-               ggml_get_name(tensor),
-               llama_format_tensor_shape(tensor).c_str(),
-               ggml_type_name(tensor->type));
-
-        // This used to be a regex, but <regex> has an extreme cost to compile times.
-        bool quantize = name.rfind("weight") == name.size() - 6; // ends with 'weight'?
-
-        // quantize only 2D and 3D tensors (experts)
-        quantize &= (ggml_n_dims(tensor) >= 2);
-
-        // do not quantize norm tensors
-        quantize &= name.find("_norm.weight") == std::string::npos;
+               ++idx, ml.n_tensors, ggml_get_name(tensor), llama_format_tensor_shape(tensor).c_str(), ggml_type_name(tensor->type));

+        bool quantize = ggml_n_dims(tensor) >= 2 && is_quantizable(name, model.arch, params);
         quantize &= params->quantize_output_tensor || name != "output.weight";
-        quantize &= !params->only_copy;
-
-        // do not quantize expert gating tensors
-        // NOTE: can't use LLM_TN here because the layer number is not known
-        quantize &= name.find("ffn_gate_inp.weight") == std::string::npos;
-
-        // these are very small (e.g. 4x4)
-        quantize &= name.find("altup") == std::string::npos;
-        quantize &= name.find("laurel") == std::string::npos;
-
-        // these are not too big so keep them as it is
-        quantize &= name.find("per_layer_model_proj") == std::string::npos;
-
-        // do not quantize positional embeddings and token types (BERT)
-        quantize &= name != LLM_TN(model.arch)(LLM_TENSOR_POS_EMBD, "weight");
-        quantize &= name != LLM_TN(model.arch)(LLM_TENSOR_TOKEN_TYPES, "weight");
-
-        // do not quantize Mamba's small yet 2D weights
-        // NOTE: can't use LLM_TN here because the layer number is not known
-        quantize &= name.find("ssm_conv1d.weight") == std::string::npos;
-        quantize &= name.find("shortconv.conv.weight") == std::string::npos;
-
-        // do not quantize RWKV's small yet 2D weights
-        quantize &= name.find("time_mix_first.weight") == std::string::npos;
-        quantize &= name.find("time_mix_w0.weight") == std::string::npos;
-        quantize &= name.find("time_mix_w1.weight") == std::string::npos;
-        quantize &= name.find("time_mix_w2.weight") == std::string::npos;
-        quantize &= name.find("time_mix_v0.weight") == std::string::npos;
-        quantize &= name.find("time_mix_v1.weight") == std::string::npos;
-        quantize &= name.find("time_mix_v2.weight") == std::string::npos;
-        quantize &= name.find("time_mix_a0.weight") == std::string::npos;
-        quantize &= name.find("time_mix_a1.weight") == std::string::npos;
-        quantize &= name.find("time_mix_a2.weight") == std::string::npos;
-        quantize &= name.find("time_mix_g1.weight") == std::string::npos;
-        quantize &= name.find("time_mix_g2.weight") == std::string::npos;
-        quantize &= name.find("time_mix_decay_w1.weight") == std::string::npos;
-        quantize &= name.find("time_mix_decay_w2.weight") == std::string::npos;
-        quantize &= name.find("time_mix_lerp_fused.weight") == std::string::npos;
-
-        // do not quantize relative position bias (T5)
-        quantize &= name.find("attn_rel_b.weight") == std::string::npos;

         // do not quantize specific multimodal tensors
         quantize &= name.find(".position_embd.") == std::string::npos;
@@ -874,17 +2068,27 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::string & fname_out, const llama_model_quantize_params * params) {
             new_type = default_type;

             // get more optimal quantization type based on the tensor shape, layer, etc.
-            if (!params->pure && ggml_is_quantized(default_type)) {
+            if (!params->pure && (ggml_is_quantized(default_type) || params->target_bpw != -1.0f)) {
                 int fallback = qs.n_fallback;
                 new_type = llama_tensor_get_type(qs, new_type, tensor, ftype);
-                // unless the user specifies a type, and the tensor geometry will not require fallback quantisation
+
+                // get quantization type overrides targeting a given bits per weight budget
+                if (params->target_bpw != -1.0f && !bpw_overrides.empty()) {
+                    const auto override = bpw_overrides.find(name);
+                    if (override != bpw_overrides.end() && override->second != new_type) {
+                        LLAMA_LOG_DEBUG("(bpw override %s) ", ggml_type_name(new_type));
+                        new_type = override->second;
+                    }
+                }
+
+                // unless the user specifies a type, and the tensor shape will not require fallback quantisation
                 if (params->tensor_types && qs.n_fallback - fallback == 0) {
                     const std::vector<tensor_quantization> & tensor_types = *static_cast<const std::vector<tensor_quantization> *>(params->tensor_types);
                     const std::string tensor_name(tensor->name);
                     for (const auto & [tname, qtype] : tensor_types) {
                         if (std::regex pattern(tname); std::regex_search(tensor_name, pattern)) {
                             if (qtype != new_type) {
-                                LLAMA_LOG_DEBUG("(overriding %s) ", ggml_type_name(new_type));
+                                LLAMA_LOG_DEBUG("(type override %s) ", ggml_type_name(new_type));
                                 new_type = qtype; // if two or more types are specified for the same tensor, the last match wins
                             }
                         }
@@ -912,10 +2116,10 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::string & fname_out, const llama_model_quantize_params * params) {
             const int64_t nelements = ggml_nelements(tensor);

             const float * imatrix = nullptr;
-            if (imatrix_data) {
-                auto it = imatrix_data->find(remap_imatrix(tensor->name, mapped));
-                if (it == imatrix_data->end()) {
-                    LLAMA_LOG_INFO("\n====== %s: did not find weights for %s\n", __func__, tensor->name);
+            if (values_data) {
+                auto it = values_data->find(remap_imatrix(tensor->name, mapped));
+                if (it == values_data->end()) {
+                    LLAMA_LOG_INFO("\n====== %s: did not find weights for %s, ", __func__, tensor->name);
                 } else {
                     if (it->second.size() == (size_t)tensor->ne[0]*tensor->ne[2]) {
                         imatrix = it->second.data();
@@ -1049,9 +2253,14 @@ llama_model_quantize_params llama_model_quantize_default_params() {
         /*.pure                =*/ false,
         /*.keep_split          =*/ false,
         /*.imatrix             =*/ nullptr,
+        /*.activations         =*/ nullptr,
         /*.kv_overrides        =*/ nullptr,
         /*.tensor_type         =*/ nullptr,
-        /*.prune_layers        =*/ nullptr
+        /*.prune_layers        =*/ nullptr,
+        /*.target_bpw          =*/ -1.0f,
+        /*.keep_bpw_state      =*/ false,
+        /*.bpw_state           =*/ nullptr,
+        /*.no_importance       =*/ false
     };

     return result;
diff --git a/tools/quantize/quantize.cpp b/tools/quantize/quantize.cpp
index 470dc3d916..a1426ea4a3 100644
--- a/tools/quantize/quantize.cpp
+++ b/tools/quantize/quantize.cpp
@@ -117,21 +117,27 @@ static bool try_parse_ftype(const std::string & ftype_str_in, llama_ftype & ftype, std::string & ftype_str_out) {

 [[noreturn]]
 static void usage(const char * executable) {
-    printf("usage: %s [--help] [--allow-requantize] [--leave-output-tensor] [--pure] [--imatrix] [--include-weights]\n", executable);
-    printf("       [--exclude-weights] [--output-tensor-type] [--token-embedding-type] [--tensor-type] [--prune-layers] [--keep-split] [--override-kv]\n");
-    printf("       model-f32.gguf [model-quant.gguf] type [nthreads]\n\n");
-    printf("  --allow-requantize: Allows requantizing tensors that have already been quantized. Warning: This can severely reduce quality compared to quantizing from 16bit or 32bit\n");
-    printf("  --leave-output-tensor: Will leave output.weight un(re)quantized. Increases model size but may also increase quality, especially when requantizing\n");
-    printf("  --pure: Disable k-quant mixtures and quantize all tensors to the same type\n");
+    printf("usage: %s [--help] [--allow-requantize] [--leave-output-tensor] [--pure] [--imatrix] [--include-weights] [--exclude-weights]\n", executable);
+    printf("       [--target-bpw n] [--no-importance] [--keep-bpw-state] [--bpw-state filename] [--output-tensor-type] [--token-embedding-type] [--tensor-type]\n");
+    printf("       [--prune-layers] [--keep-split] [--override-kv] model-f32.gguf [model-quant.gguf] type [nthreads]\n\n");
+    printf("  --allow-requantize: allows requantizing tensors that have already been quantized. Warning: This can severely reduce quality compared to quantizing from 16bit or 32bit\n");
+    printf("  --leave-output-tensor: will leave output.weight un(re)quantized. Increases model size but may also increase quality, especially when requantizing\n");
+    printf("  --pure: disable k-quant mixtures and quantize all tensors to the same type\n");
     printf("  --imatrix file_name: use data in file_name as importance matrix for quant optimizations\n");
     printf("  --include-weights tensor_name: use importance matrix for this/these tensor(s)\n");
     printf("  --exclude-weights tensor_name: use importance matrix for this/these tensor(s)\n");
     printf("  --output-tensor-type ggml_type: use this ggml_type for the output.weight tensor\n");
     printf("  --token-embedding-type ggml_type: use this ggml_type for the token embeddings tensor\n");
-    printf("  --tensor-type TENSOR=TYPE: quantize this tensor to this ggml_type. example: --tensor-type attn_q=q8_0\n");
+    printf("  --tensor-type TENSOR=TYPE: quantize this tensor to this ggml_type. Example: --tensor-type attn_q=q8_0\n");
     printf("      Advanced option to selectively quantize tensors. May be specified multiple times.\n");
     printf("  --prune-layers L0,L1,L2...comma-separated list of layer numbers to prune from the model\n");
     printf("      Advanced option to remove all tensors from the given layers\n");
+    printf("  --target-bpw: target bits per weight (bpw). Must be a positive number between 0.0 and 16.0\n");
+    printf("      Advanced option to automatically select quantization types to achieve a total bits per weight (bpw) target\n");
+    printf("  --no-importance: distribute bpw budget equitably across all tensors\n");
+    printf("      Advanced option to disable assigning more bpw budget to important tensors. It may increase quality for some models\n");
+    printf("  --keep-bpw-state: save the bpw computations to <arch>-<hash>.bpw_state\n");
+    printf("  --bpw-state: file name to use instead of default\n");
     printf("  --keep-split: will generate quantized model in the same shards as input\n");
     printf("  --override-kv KEY=TYPE:VALUE\n");
     printf("      Advanced option to override model metadata by key in the quantized model. May be specified multiple times.\n");
@@ -214,7 +220,10 @@ static int load_legacy_imatrix(const std::string & imatrix_file, std::vector<std::string> & imatrix_datasets, std::unordered_map<std::string, std::vector<float>> & imatrix_data) {

-static int load_imatrix(const std::string & imatrix_file, std::vector<std::string> & imatrix_datasets, std::unordered_map<std::string, std::vector<float>> & imatrix_data) {
+static int load_imatrix(const std::string & imatrix_file,
+        std::vector<std::string> & imatrix_datasets,
+        std::unordered_map<std::string, std::vector<float>> & values_data,
+        std::unordered_map<std::string, std::vector<float>> & activations_data) {
     struct ggml_context * ctx = nullptr;

     struct gguf_init_params meta_gguf_params = {
@@ -224,7 +233,7 @@ static int load_imatrix(const std::string & imatrix_file, std::vector<std::string> & imatrix_datasets, std::unordered_map<std::string, std::vector<float>> & values_data, std::unordered_map<std::string, std::vector<float>> & activations_data) {
-    std::map<std::string, std::pair<struct ggml_tensor *, struct ggml_tensor *>> sums_counts_for;
+    std::map<std::string, std::tuple<struct ggml_tensor *, struct ggml_tensor *, struct ggml_tensor *>> sums_counts_for;

     for (struct ggml_tensor * cur = ggml_get_first_tensor(ctx); cur; cur = ggml_get_next_tensor(ctx, cur)) {
         std::string name = cur->name;
@@ -258,44 +268,55 @@ static int load_imatrix(const std::string & imatrix_file, std::vector<std::string> & imatrix_datasets, std::unordered_map<std::string, std::vector<float>> & values_data, std::unordered_map<std::string, std::vector<float>> & activations_data) {
         if (string_remove_suffix(name, sums_suffix)) {
-            // in_sum2
-            sums_counts_for[std::move(name)].first = cur;
+            // in_sum
+            std::get<0>(sums_counts_for[std::move(name)]) = cur;
+        } else if (string_remove_suffix(name, sums2_suffix)) {
+            // in_sum2
+            std::get<1>(sums_counts_for[std::move(name)]) = cur;
         } else if (string_remove_suffix(name, counts_suffix)) {
             // counts
-            sums_counts_for[std::move(name)].second = cur;
+            std::get<2>(sums_counts_for[std::move(name)]) = cur;
         } else {
+            // ignore other tensors
         }
     }

     for (const auto & sc : sums_counts_for) {
         const std::string & name = sc.first;
-        const struct ggml_tensor * sums   = sc.second.first;
-        const struct ggml_tensor * counts = sc.second.second;
+        const struct ggml_tensor * sums   = std::get<0>(sc.second);
+        const struct ggml_tensor * sums2  = std::get<1>(sc.second);
+        const struct ggml_tensor * counts = std::get<2>(sc.second);

-        if (!sums || !counts) {
+        // check sums2 and counts are present, and that sums and sums2 have the same shape
+        if (!sums2 || !counts || (sums != nullptr && ggml_nelements(sums) != ggml_nelements(sums2))) {
             fprintf(stderr, "%s: mismatched sums and counts for %s\n", __func__, name.c_str());
             gguf_free(ctx_gguf);
             ggml_free(ctx);
             exit(1);
         }

-        const int64_t ne0 = sums->ne[0];
-        const int64_t ne1 = sums->ne[1];
+        const int64_t ne0 = sums2->ne[0];
+        const int64_t ne1 = sums2->ne[1];

-        auto & e = imatrix_data[name];
-        e.resize(ggml_nelements(sums));
+        auto & activations = activations_data[name];
+        auto & values      = values_data[name];
+        if (sums) {
+            activations.resize(ggml_nelements(sums));
+        }
+        values.resize(ggml_nelements(sums2));
         float max_count = 0.0f;
         for (int64_t j = 0; j < ne1; ++j) {
             const float count = ((const float *) counts->data)[j];
             if (count > 0.0f) {
                 for (int64_t i = 0; i < ne0; ++i) {
-                    e[j*ne0 + i] = ((const float *) sums->data)[j*ne0 + i] / count;
+                    values[j*ne0 + i] = ((const float *) sums2->data)[j*ne0 + i] / count;
+                    if (sums) { activations[j*ne0 + i] = ((const float *) sums->data)[j*ne0 + i] / count; }
                 }
             } else {
                 // Partial imatrix data, this tensor never got any input during calibration
                 for (int64_t i = 0; i < ne0; ++i) {
-                    e[j*ne0 + i] = 1;
+                    values[j*ne0 + i] = 1;
+                    if (sums) { activations[j*ne0 + i] = 0; }
                 }
             }
             if (count > max_count) {
@@ -303,7 +324,8 @@ static int load_imatrix(const std::string & imatrix_file, std::vector<std::string> & imatrix_datasets, std::unordered_map<std::string, std::vector<float>> & values_data, std::unordered_map<std::string, std::vector<float>> & activations_data) {
 static int prepare_imatrix(const std::string & imatrix_file,
         std::vector<std::string> & imatrix_dataset,
         const std::vector<std::string> & included_weights,
        const std::vector<std::string> & excluded_weights,
-        std::unordered_map<std::string, std::vector<float>> & imatrix_data) {
+        std::unordered_map<std::string, std::vector<float>> & values_data,
+        std::unordered_map<std::string, std::vector<float>> & activations_data) {
     int m_last_call = -1;
     if (!imatrix_file.empty()) {
-        m_last_call = load_imatrix(imatrix_file, imatrix_dataset, imatrix_data);
+        m_last_call = load_imatrix(imatrix_file, imatrix_dataset, values_data, activations_data);
     }
-    if (imatrix_data.empty()) {
         return m_last_call;
     }
     if (!excluded_weights.empty()) {
         for (const auto & name : excluded_weights) {
-            for (auto it = imatrix_data.begin(); it != imatrix_data.end();) {
-                auto pos = it->first.find(name);
+            for (auto vt = values_data.begin(); vt != values_data.end();) {
+                auto pos = vt->first.find(name);
                 if (pos != std::string::npos) {
-                    it = imatrix_data.erase(it);
+                    vt = values_data.erase(vt);
                 } else {
-                    ++it;
+                    ++vt;
+                }
+            }
+            for (auto at = activations_data.begin(); at != activations_data.end();) {
+                auto pos = at->first.find(name);
+                if (pos != std::string::npos) {
+                    at = activations_data.erase(at);
+                } else {
+                    ++at;
                 }
             }
         }
     }
     if (!included_weights.empty()) {
-        std::unordered_map<std::string, std::vector<float>> tmp;
+        std::unordered_map<std::string, std::vector<float>> tmp_values;
+        std::unordered_map<std::string, std::vector<float>> tmp_activations;
         for (const auto & name : included_weights) {
-            for (auto & e : imatrix_data) {
+            for (auto & e : values_data) {
                 auto pos = e.first.find(name);
                 if (pos != std::string::npos) {
-                    tmp.emplace(std::move(e));
+                    tmp_values.emplace(std::move(e));
+                }
+            }
+            for (auto & a : activations_data) {
+                auto pos = a.first.find(name);
+                if (pos != std::string::npos) {
+                    tmp_activations.emplace(std::move(a));
                 }
             }
         }
-        imatrix_data = std::move(tmp);
-    }
-    if (!imatrix_data.empty()) {
-        printf("%s: have %d importance matrix entries\n", __func__, int(imatrix_data.size()));
+        values_data = std::move(tmp_values);
+        activations_data = std::move(tmp_activations);
     }
+
     return m_last_call;
 }
@@ -440,6 +477,52 @@ static bool parse_layer_prune(const char * data, std::vector<int> & prune_layers) {
     return true;
 }

+static bool parse_target_bpw(const char * data, float & target_bpw) {
+    if (!data) {
+        printf("\n%s: no target bits per weight (bpw) provided\n\n", __func__);
+        return false;
+    }
+
+    try {
+        target_bpw = std::stof(data);
+        // reject 0.0 as well: the help text requires a positive value, and a zero-bpw target is meaningless
+        if (target_bpw <= 0.0f || target_bpw > 16.0f) {
+            printf("\n%s: target bits per weight (bpw) must be a positive number between 0.0 and 16.0\n\n", __func__);
+            return false;
+        }
+    } catch (const std::exception & e) {
+        printf("\n%s: '%s' is not valid. Target bits per weight (bpw) must be a positive number between 0.0 and 16.0\n\n", __func__, data);
+        return false;
+    }
+
+    return true;
+}
+
+static const char * get_ftype(const float bpw) {
+    // bits per weight of each named quantization type, in ascending order
+    const std::map<float, const char *> quant_bpw = {
+        {  1.5625, "IQ1_S"   },
+        {  1.7500, "IQ1_M"   },
+        {  2.0625, "IQ2_XXS" },
+        {  2.3125, "IQ2_XS"  },
+        {  2.5625, "IQ2_S"   },
+        {  2.6250, "Q2_K"    },
+        {  3.0625, "IQ3_XXS" },
+        {  3.4375, "Q3_K"    },
+        {  4.2500, "IQ4_XS"  },
+        {  4.5000, "Q4_K"    },
+        {  5.5000, "Q5_K"    },
+        {  6.5625, "Q6_K"    },
+        {  8.5000, "Q8_0"    },
+#ifdef GGML_USE_METAL
+        { 16.0000, "F16"     }
+#else
+        { 16.0000, "BF16"    }
+#endif
+    };
+
+    // parse_target_bpw() guarantees 0.0 < bpw <= 16.0, so lower_bound() never returns end()
+    return quant_bpw.lower_bound(bpw)->second;
+}
+
 int main(int argc, char ** argv) {
     if (argc < 3) {
         usage(argv[0]);
@@ -453,6 +536,7 @@ int main(int argc, char ** argv) {
     std::vector<llama_model_kv_override> kv_overrides;
     std::vector<tensor_quantization> tensor_types;
     std::vector<int> prune_layers;
+    float target_bpw = -1.0f;

     for (; arg_idx < argc && strncmp(argv[arg_idx], "--", 2) == 0; arg_idx++) {
         if (strcmp(argv[arg_idx], "--leave-output-tensor") == 0) {
@@ -479,6 +563,20 @@ int main(int argc, char ** argv) {
             if (arg_idx == argc-1 || !parse_tensor_type(argv[++arg_idx], tensor_types)) {
                 usage(argv[0]);
             }
+        } else if (strcmp(argv[arg_idx], "--target-bpw") == 0) {
+            if (arg_idx == argc-1 || !parse_target_bpw(argv[++arg_idx], target_bpw)) {
+                usage(argv[0]);
+            }
+        } else if (strcmp(argv[arg_idx], "--no-importance") == 0) {
+            params.no_importance = true;
+        } else if (strcmp(argv[arg_idx], "--keep-bpw-state") == 0) {
+            params.keep_bpw_state = true;
+        } else if (strcmp(argv[arg_idx], "--bpw-state") == 0) {
+            if (arg_idx < argc-1) {
+                params.bpw_state = argv[++arg_idx];
+            } else {
+                usage(argv[0]);
+            }
         } else if (strcmp(argv[arg_idx], "--prune-layers") == 0) {
             if (arg_idx == argc-1 || !parse_layer_prune(argv[++arg_idx], prune_layers)) {
                 usage(argv[0]);
@@ -525,10 +623,11 @@ int main(int argc, char ** argv) {
     }

     std::vector<std::string> imatrix_datasets;
-    std::unordered_map<std::string, std::vector<float>> imatrix_data;
-    int m_last_call = prepare_imatrix(imatrix_file, imatrix_datasets, included_weights, excluded_weights, imatrix_data);
-    if (!imatrix_data.empty()) {
-        params.imatrix = &imatrix_data;
+    std::unordered_map<std::string, std::vector<float>> values_data;
+    std::unordered_map<std::string, std::vector<float>> activations_data;
+    int m_last_call = prepare_imatrix(imatrix_file, imatrix_datasets, included_weights, excluded_weights, values_data, activations_data);
+    if (!values_data.empty()) {
+        params.imatrix = &values_data;
         {
             llama_model_kv_override kvo;
             std::strcpy(kvo.key, LLM_KV_QUANTIZE_IMATRIX_FILE);
@@ -551,7 +650,7 @@ int main(int argc, char ** argv) {
             llama_model_kv_override kvo;
             std::strcpy(kvo.key, LLM_KV_QUANTIZE_IMATRIX_N_ENTRIES);
             kvo.tag = LLAMA_KV_OVERRIDE_TYPE_INT;
-            kvo.val_i64 = imatrix_data.size();
+            kvo.val_i64 = values_data.size();
             kv_overrides.emplace_back(std::move(kvo));
         }
@@ -563,6 +662,9 @@ int main(int argc, char ** argv) {
             kv_overrides.emplace_back(std::move(kvo));
         }
     }
+    if (!activations_data.empty()) {
+        params.activations = &activations_data;
+    }
     if (!kv_overrides.empty()) {
         kv_overrides.emplace_back();
         kv_overrides.back().key[0] = 0;
@@ -574,6 +676,9 @@ int main(int argc, char ** argv) {
     if (!prune_layers.empty()) {
         params.prune_layers = &prune_layers;
     }
+    if (target_bpw != -1.0f) {
+        params.target_bpw = target_bpw;
+    }

     llama_backend_init();
@@ -584,6 +689,7 @@ int main(int argc, char ** argv) {
     std::string ftype_str;
     std::string suffix = ".gguf";
+    std::vector<char *> tmp_argv(argv, argv + argc);
     if (try_parse_ftype(argv[arg_idx], params.ftype, ftype_str)) {
         std::string fpath;
         const size_t pos = fname_inp.find_last_of("/\\");
@@ -607,7 +713,15 @@ int main(int argc, char ** argv) {
     }
     arg_idx++;

-    if (argc <= arg_idx) {
+    // if --target-bpw is set and no explicit type was given, inject the type suggested by
+    // get_ftype(); a single remaining argument is then assumed to be nthreads and kept in place
+    if (argc - arg_idx <= 1 && params.target_bpw != -1.0f) {
+        auto * ftype = const_cast<char *>(get_ftype(params.target_bpw));
+        if (argc == arg_idx) {
+            tmp_argv.push_back(ftype);
+        } else {
+            tmp_argv.insert(tmp_argv.end() - 1, ftype);
+        }
+        tmp_argv.push_back(nullptr);
+        argv = const_cast<char **>(tmp_argv.data());
+        argc++;
+    } else if (argc <= arg_idx) {
         fprintf(stderr, "%s: missing ftype\n", __func__);
         return 1;
     }
@@ -636,7 +750,7 @@ int main(int argc, char ** argv) {
         params.ftype == LLAMA_FTYPE_MOSTLY_IQ2_S ||
         params.ftype == LLAMA_FTYPE_MOSTLY_Q2_K_S ||
         params.ftype == LLAMA_FTYPE_MOSTLY_IQ1_S ||
-        params.ftype == LLAMA_FTYPE_MOSTLY_IQ1_M) && imatrix_data.empty()) {
+        params.ftype == LLAMA_FTYPE_MOSTLY_IQ1_M) && values_data.empty()) {
         fprintf(stderr, "\n==========================================================================================================\n");
         fprintf(stderr, "Please do not use IQ1_S, IQ1_M, IQ2_S, IQ2_XXS, IQ2_XS or Q2_K_S quantization without an importance matrix\n");
         fprintf(stderr, "==========================================================================================================\n\n\n");
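A note on get_ftype(): std::map::lower_bound() returns the first entry whose key (bpw) is greater than or equal to its argument, so the target is effectively rounded up to the nearest named type. The standalone sketch below (hypothetical, using only a three-entry subset of the quant_bpw table; not part of the patch) demonstrates the selection behaviour:

#include <cstdio>
#include <map>

int main() {
    // subset of the quant_bpw table in get_ftype(); lower_bound() picks the
    // first entry whose bpw is >= the requested target
    const std::map<float, const char *> quant_bpw = {
        { 4.2500f, "IQ4_XS" },
        { 4.5000f, "Q4_K"   },
        { 5.5000f, "Q5_K"   },
    };
    for (const float target : { 4.0f, 4.25f, 4.3f, 5.0f }) {
        // prints: 4.00 -> IQ4_XS, 4.25 -> IQ4_XS, 4.30 -> Q4_K, 5.00 -> Q5_K
        printf("%.2f bpw -> %s\n", target, quant_bpw.lower_bound(target)->second);
    }
    return 0;
}

In quantize.cpp the full table tops out at 16.0 (F16/BF16) and parse_target_bpw() rejects targets outside (0.0, 16.0], so the lookup can never return quant_bpw.end().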