diff --git a/include/llama.h b/include/llama.h index 8b3c8a7b10..274e776bad 100644 --- a/include/llama.h +++ b/include/llama.h @@ -378,9 +378,14 @@ extern "C" { bool pure; // quantize all tensors to the default type bool keep_split; // quantize to the same number of shards void * imatrix; // pointer to importance matrix data + void * activations; // pointer to activations data void * kv_overrides; // pointer to vector containing overrides void * tensor_types; // pointer to vector containing tensor types void * prune_layers; // pointer to vector containing layer indices to prune + float target_bpw; // target bits per weight (bpw) + bool keep_bpw_state; // keep bpw state file + void * bpw_state; // pointer to bpw state file + bool no_importance; // allocate target bpw budget equitably across all tensors } llama_model_quantize_params; typedef struct llama_logit_bias { diff --git a/src/llama-quant.cpp b/src/llama-quant.cpp index bc4b05c3b5..67e5aa9827 100644 --- a/src/llama-quant.cpp +++ b/src/llama-quant.cpp @@ -4,14 +4,20 @@ #include "llama-model-loader.h" #include +#include #include #include #include +#include #include #include +#include +#include +#include #include #include #include +#include // Quantization types. Changes to this struct must be replicated in quantize.cpp struct tensor_quantization { @@ -19,6 +25,87 @@ struct tensor_quantization { ggml_type quant = GGML_TYPE_COUNT; }; +static bool is_quantizable(const std::string & name, const llm_arch arch, const llama_model_quantize_params * params) { + if (params->only_copy) { return false; } + + const auto tn = LLM_TN(arch); + + // This used to be a regex, but has an extreme cost to compile times. + bool q = name.size() >= 6 && name.rfind("weight") == name.size() - 6; // ends with 'weight'? + + // Do not quantize norm tensors + q &= name.find("_norm.weight") == std::string::npos; + + // Do not quantize expert gating tensors + // NOTE: can't use LLM_TN here because the layer number is not known + q &= name.find("ffn_gate_inp.weight") == std::string::npos; + + // These are very small (e.g. 
4x4) + q &= name.find("altup") == std::string::npos; + q &= name.find("laurel") == std::string::npos; + + // These are not too big so keep them as it is + q &= name.find("per_layer_model_proj") == std::string::npos; + + // Do not quantize positional embeddings and token types (BERT) + q &= name != tn(LLM_TENSOR_POS_EMBD, "weight"); + q &= name != tn(LLM_TENSOR_TOKEN_TYPES, "weight"); + + // Do not quantize Jamba, Mamba, LFM2's small yet 2D weights + // NOTE: can't use LLM_TN here because the layer number is not known + q &= name.find("ssm_conv1d.weight") == std::string::npos; + q &= name.find("shortconv.conv.weight") == std::string::npos; + + // Do not quantize ARWKV, RWKV's small yet 2D weights + q &= name.find("time_mix_first.weight") == std::string::npos; + q &= name.find("time_mix_w0.weight") == std::string::npos; + q &= name.find("time_mix_w1.weight") == std::string::npos; + q &= name.find("time_mix_w2.weight") == std::string::npos; + q &= name.find("time_mix_v0.weight") == std::string::npos; + q &= name.find("time_mix_v1.weight") == std::string::npos; + q &= name.find("time_mix_v2.weight") == std::string::npos; + q &= name.find("time_mix_a0.weight") == std::string::npos; + q &= name.find("time_mix_a1.weight") == std::string::npos; + q &= name.find("time_mix_a2.weight") == std::string::npos; + q &= name.find("time_mix_g1.weight") == std::string::npos; + q &= name.find("time_mix_g2.weight") == std::string::npos; + q &= name.find("time_mix_decay_w1.weight") == std::string::npos; + q &= name.find("time_mix_decay_w2.weight") == std::string::npos; + q &= name.find("time_mix_lerp_fused.weight") == std::string::npos; + + // Do not quantize relative position bias (T5) + q &= name.find("attn_rel_b.weight") == std::string::npos; + + return q; +} + +static enum ggml_type fallback_type(const enum ggml_type new_type) { + switch (new_type) { + case GGML_TYPE_TQ1_0: + case GGML_TYPE_TQ2_0: + return GGML_TYPE_Q4_0; // symmetric-ish fallback + case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ2_S: + case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ3_S: + case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ1_M: + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + case GGML_TYPE_IQ4_XS: + return GGML_TYPE_IQ4_NL; + case GGML_TYPE_Q4_K: + return GGML_TYPE_Q5_0; + case GGML_TYPE_Q5_K: + return GGML_TYPE_Q5_1; + case GGML_TYPE_Q6_K: + return GGML_TYPE_Q8_0; + default: + return new_type; + } +} + static void zeros(std::ofstream & file, size_t n) { char zero = 0; for (size_t i = 0; i < n; ++i) { @@ -66,7 +153,6 @@ static std::string remap_imatrix (const std::string & orig_name, const std::map< for (const auto & p : mapped) { if (p.second == blk) { - LLAMA_LOG_DEBUG("(blk.%d imatrix) ", p.first); return new_name.replace(match.position(1), match.length(1), std::to_string(p.first)); } } @@ -89,10 +175,11 @@ struct quantize_state_impl { int i_ffn_gate = 0; int i_ffn_up = 0; - int n_k_quantized = 0; - int n_fallback = 0; + int n_k_quantized = 0; + int n_fallback = 0; - bool has_imatrix = false; + bool has_imatrix = false; + bool has_activations = false; // used to figure out if a model shares tok_embd with the output weight bool has_output = false; @@ -530,6 +617,1150 @@ static size_t llama_tensor_quantize_impl(enum ggml_type new_type, const float * return new_size; } +static std::atomic bpw_stop{ false }; + +static void signal_handler(int) { + bpw_stop.store(true, std::memory_order_relaxed); +} + +// Returns tensor type overrides that meet a global bpw target +static std::unordered_map target_bpw_type( + 
llama_model_loader & ml, + const llama_model & model, + const std::vector & tensors, + const std::map & mapped, + const std::unordered_map> * values_data, + const std::unordered_map> * activations_data, + const llama_model_quantize_params * params, + int nthread +) { + bpw_stop.store(false, std::memory_order_relaxed); + // SIGINT/SIGTERM signal handlers + struct signal_scope_guard { + using handler_t = void (*)(int); + handler_t prev_int = SIG_DFL; + handler_t prev_term = SIG_DFL; + signal_scope_guard() { + prev_int = std::signal(SIGINT, signal_handler); + prev_term = std::signal(SIGTERM, signal_handler); + } + ~signal_scope_guard() { + std::signal(SIGINT, prev_int); + std::signal(SIGTERM, prev_term); + } + } signal_guard; + + // Error and bias projection per GGML_TYPE per tensor + struct candidate_types { + ggml_type type = GGML_TYPE_COUNT; + float bpw = 0.0f; + size_t bytes = 0; + double error = 0.0; + double mse = 0.0; + double proj = 0.0; + }; + + // Per‑tensor quantization mix that satisfies a global bpw target + struct tensor_info { + const llama_model_loader::llama_tensor_weight * w = nullptr; + std::vector candidate; + int choice = -1; + float min_bpw = 0.0; + float max_bpw = 0.0; + size_t n_elements = 0; + }; + + // subset of quantization types with the best accuracy/size tradeoff + constexpr ggml_type quant_types[] = { + GGML_TYPE_IQ1_S, + GGML_TYPE_IQ1_M, + GGML_TYPE_IQ2_XXS, + GGML_TYPE_IQ2_XS, + GGML_TYPE_IQ2_S, + GGML_TYPE_Q2_K, + GGML_TYPE_IQ3_XXS, + GGML_TYPE_Q3_K, + GGML_TYPE_IQ4_XS, + GGML_TYPE_IQ4_NL, + GGML_TYPE_Q4_K, + GGML_TYPE_Q5_K, + GGML_TYPE_Q6_K, + GGML_TYPE_Q8_0, +#ifdef GGML_USE_METAL + GGML_TYPE_F16 +#else + GGML_TYPE_BF16 +#endif + }; + + constexpr double epsilon = 1e-12; + constexpr double infinity = std::numeric_limits::infinity(); + constexpr uint32_t file_magic = 0x4d534531; // MSE1 + constexpr uint64_t arbitrary_magic = 0xeabada55cafed00d; + const char * func = __func__; + + // Tensor size in bytes for a given type + auto tensor_bytes = [](const ggml_tensor * t, const ggml_type typ) -> size_t { + const int64_t n_per_row = t->ne[0]; + const size_t row_sz = ggml_row_size(typ, n_per_row); + return (size_t)ggml_nrows(t) * row_sz; + }; + + // Tensor bpw for a given type + auto tensor_bpw = [&](const ggml_tensor * t, const ggml_type typ) -> double { + const size_t bytes = tensor_bytes(t, typ); + return (double)bytes * 8.0 / (double)ggml_nelements(t); + }; + + // Check if tensor is compatible with quantization type + auto is_compatible = [](const ggml_tensor * t, const ggml_type typ) -> bool { + const int64_t blck = ggml_blck_size(typ); + return blck <= 1 || (t->ne[0] % blck) == 0; + }; + + // Get suitable fallback for type + auto make_compatible = [&](const ggml_tensor * t, const ggml_type typ) -> ggml_type { + if (is_compatible(t, typ)) { return typ; } + const ggml_type fb = fallback_type(typ); + return is_compatible(t, fb) ? 
fb : GGML_TYPE_F16; + }; + + // Check if tensor is an IQ type + auto is_iq = [](const enum ggml_type t) { + switch (t) { + case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ1_M: + case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ2_S: + case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ3_S: + case GGML_TYPE_IQ4_NL: + case GGML_TYPE_IQ4_XS: + return true; + default: + return false; + } + }; + + // Check if tensor can be quantized + auto can_quantize = [&](const ggml_tensor * t) -> bool { + if (ggml_n_dims(t) < 2) { return false; } // skip 1D tensors + return is_quantizable(ggml_get_name(t), model.arch, params); + }; + + // Saved state per tensor + struct saved_info { + std::vector candidate; + int choice = -1; + float min_bpw = 0.0f; + float max_bpw = 0.0f; + size_t n_elements = 0; + }; + + // DJB2 hashing algorithm + auto djb2_hash = [&](const uint8_t * data, const size_t n) -> uint64_t { + uint64_t h = 5381; + for (size_t i = 0; i < n; ++i) { + h = (h << 5) + h + data[i]; + } + return h ? h : arbitrary_magic; + }; + + // Get model ID from metadata hash + auto metadata_id = [&](const gguf_context * ctx) -> uint64_t { + const size_t sz = gguf_get_meta_size(ctx); + std::vector buf(sz); + gguf_get_meta_data(ctx, buf.data()); + return djb2_hash(buf.data(), buf.size()); + }; + + std::string gen_name; + std::string checkpoint_file; + char hex[17]; + const uint64_t model_id = metadata_id(ml.meta.get()); + + std::snprintf(hex, sizeof(hex), "%016" PRIx64, (uint64_t)model_id); + ml.get_key(LLM_KV_GENERAL_NAME, gen_name, false); + std::replace(gen_name.begin(), gen_name.end(), ' ', '_'); + + gen_name.empty() ? checkpoint_file = ml.arch_name : checkpoint_file = gen_name; + checkpoint_file += "-" + std::string(hex) + "-mse.bpw_state"; + + if (params->keep_bpw_state && params->bpw_state) { + const auto * filename = static_cast(params->bpw_state); + std::ifstream ifs(filename, std::ios::binary); + if (ifs.good()) { + checkpoint_file = std::string(filename); + } else { + std::ofstream ofs(filename, std::ios::binary | std::ios::app); + if (ofs.is_open()) { + checkpoint_file = std::string(filename); + ofs.close(); + std::remove(checkpoint_file.c_str()); + } else { + LLAMA_LOG_WARN("%s: %s is not a valid file name. 
Using %s instead\n", func, filename, checkpoint_file.c_str()); + } + } + } + + // Serializes vector to disk + auto save_bpw_state = [&](const std::vector & all_vec) { + const std::string tmp = checkpoint_file + ".tmp"; + std::ofstream ofs(tmp, std::ios::binary | std::ios::trunc); + if (!ofs) { return; } + ofs.write((const char *)&file_magic, sizeof(file_magic)); + ofs.write((const char *)&model_id, sizeof(model_id)); + const uint64_t n = all_vec.size(); + ofs.write((const char *)&n, sizeof(n)); + for (const auto & ti : all_vec) { + const std::string name = ggml_get_name(ti.w->tensor); + const auto len = (uint32_t)name.size(); + ofs.write((const char *)&len, sizeof(len)); + ofs.write(name.data(), len); + + const uint64_t cn = ti.candidate.size(); + ofs.write((const char *)&cn, sizeof(cn)); + ofs.write((const char *)&ti.choice, sizeof(ti.choice)); + ofs.write((const char *)&ti.min_bpw, sizeof(ti.min_bpw)); + ofs.write((const char *)&ti.max_bpw, sizeof(ti.max_bpw)); + const uint64_t ne = ti.n_elements; + ofs.write((const char *)&ne, sizeof(ne)); + + for (const auto & c : ti.candidate) { + const int32_t t = c.type; + const uint64_t b = c.bytes; + ofs.write((const char *)&t, sizeof(t)); + ofs.write((const char *)&c.bpw, sizeof(c.bpw)); + ofs.write((const char *)&b, sizeof(b)); + ofs.write((const char *)&c.error, sizeof(c.error)); + } + } + + ofs.close(); + std::remove(checkpoint_file.c_str()); + std::rename(tmp.c_str(), checkpoint_file.c_str()); + LLAMA_LOG_INFO("%s: saved progress for %lu tensors to %s\n", func, all_vec.size(), checkpoint_file.c_str()); + }; + + // Deserializes vector from disk + auto load_bpw_state = [&]() -> std::unordered_map { + std::unordered_map out; + std::ifstream ifs(checkpoint_file, std::ios::binary); + if (!ifs) { return out; } + + uint32_t magic = 0; + uint64_t id = 0; + ifs.read((char *)&magic, sizeof(magic)); + ifs.read((char *)&id, sizeof(id)); + if (magic != file_magic) { + LLAMA_LOG_WARN("%s: invalid resume file, ignoring: %s\n", func, checkpoint_file.c_str()); + return out; + } + if (id != model_id) { + LLAMA_LOG_WARN("%s: model ID mismatch, ignoring: %s\n", func, checkpoint_file.c_str()); + return out; + } + + LLAMA_LOG_INFO("%s: state file found, resuming tensor quantization\n", func); + + uint64_t n = 0; + ifs.read((char *)&n, sizeof(n)); + for (uint64_t i = 0; i < n; ++i) { + uint32_t len = 0; + ifs.read((char *)&len, sizeof(len)); + std::string name(len, '\0'); + ifs.read(name.data(), len); + + uint64_t cn = 0; + ifs.read((char *)&cn, sizeof(cn)); + + saved_info si; + ifs.read((char *)&si.choice, sizeof(si.choice)); + ifs.read((char *)&si.min_bpw, sizeof(si.min_bpw)); + ifs.read((char *)&si.max_bpw, sizeof(si.max_bpw)); + uint64_t ne = 0; + ifs.read((char *)&ne, sizeof(ne)); + si.n_elements = (size_t)ne; + + si.candidate.resize(cn); + for (auto & s : si.candidate) { + int32_t t = 0; + uint64_t b = 0; + ifs.read((char *)&t, sizeof(t)); + s.type = (ggml_type)t; + ifs.read((char *)&s.bpw, sizeof(s.bpw)); + ifs.read((char *)&b, sizeof(b)); + s.bytes = (size_t)b; + ifs.read((char *)&s.error, sizeof(s.error)); + } + + out.emplace(std::move(name), std::move(si)); + } + + LLAMA_LOG_INFO("%s: loaded bpw state for %lu tensors from %s\n", func, out.size(), checkpoint_file.c_str()); + return out; + }; + + // Deletes checkpoint file unless --keep-bpw-state is set + auto delete_bpw_state = [&] { + std::ifstream ifs(checkpoint_file); + if (ifs.good() && !params->keep_bpw_state) { + LLAMA_LOG_INFO("%s: deleting %s\n", func, checkpoint_file.c_str()); + 
std::remove(checkpoint_file.c_str()); + } + }; + + // Check for user interrupt and save progress + auto check_signal_handler = [&](const std::vector & all_vec) { + if (bpw_stop.load(std::memory_order_relaxed)) { + LLAMA_LOG_INFO("\n%s: saving progress for %lu tensors to %s\n", func, all_vec.size(), checkpoint_file.c_str()); + save_bpw_state(all_vec); + throw std::runtime_error("user interrupted the process"); + } + }; + + // Estimate error for a given type using a sampled subset of rows + auto estimate_error = [&](const ggml_tensor * t, + const ggml_type quant_type, + const std::vector & f32_sample, + const std::vector & rows_sample, + const float * values_sample, + const float * activations_sample, + std::vector & quantized_buffer, + std::vector & dequantized_buffer, + float tensor_bias_lambda, + const float * slice_bias_lambda, + double * out_mse = nullptr, + double * out_proj = nullptr) -> double + { + const int64_t n_per_row = t->ne[0]; + const int64_t nrows = t->ne[1]; + const int64_t ne2 = t->ne[2] > 0 ? t->ne[2] : 1; + const size_t sample_elems = f32_sample.size(); + const size_t sample_rows = n_per_row > 0 ? sample_elems / (size_t)n_per_row : 0; + + if (sample_rows == 0) { + if (out_mse) { *out_mse = 0.0; } + if (out_proj) { *out_proj = 0.0; } + return 0.0; + } + + size_t expected_rows = 0; + for (int64_t s = 0; s < ne2; ++s) { + expected_rows += (size_t)rows_sample[s]; + } + + if (expected_rows != sample_rows) { + if (out_mse) { *out_mse = infinity; } + if (out_proj) { *out_proj = 0.0; } + return infinity; + } + + const size_t row_sz = ggml_row_size(quant_type, n_per_row); + const size_t buf_sz = row_sz * sample_rows; + + if (quantized_buffer.size() < buf_sz) { quantized_buffer.resize(buf_sz); } + if (dequantized_buffer.size() < sample_elems) { dequantized_buffer.resize(sample_elems); } + + const bool has_values = values_sample != nullptr; + const bool has_activations = activations_sample != nullptr; + + // Bias denominators per slice + std::vector bias_denom(ne2, 0.0); + if (has_activations) { + for (int64_t s = 0; s < ne2; ++s) { + const float * v = has_values ? values_sample + s * n_per_row : nullptr; + const float * a = activations_sample + s * n_per_row; + double denom = 0.0; + for (int64_t j = 0; j < n_per_row; ++j) { + const double w = v ? std::max(0.0f, v[j]) : 1.0; + const double aj = a[j]; + denom += w * aj * aj; + } + + bias_denom[s] = denom; + } + } + + // Row squared norms (weighted if values present) + std::vector row_sq_norm(sample_rows, 0.0); + { + size_t off = 0; + size_t ridx = 0; + for (int64_t s = 0; s < ne2; ++s) { + const int64_t rs = rows_sample[s]; + if (rs == 0) { continue; } + + const float * v = has_values ? values_sample + s * n_per_row : nullptr; + for (int64_t r = 0; r < rs; ++r, ++ridx) { + const float * x = f32_sample.data() + off; + double sum = 0.0; + if (v) { + for (int64_t j = 0; j < n_per_row; ++j) { + const double w = std::max(0.0f, v[j]); + const double xx = x[j]; + sum += w * xx * xx; + } + } else { + for (int64_t j = 0; j < n_per_row; ++j) { + const double xx = x[j]; + sum += xx * xx; + } + } + + row_sq_norm[ridx] = sum; + off += (size_t)n_per_row; + } + } + } + + // Quantize per slice into quantized_buffer + { + size_t qoff = 0; + size_t foff = 0; + for (int64_t s = 0; s < ne2; ++s) { + const int64_t rs = rows_sample[s]; + if (rs == 0) { continue; } + + const float * v = has_values ? 
values_sample + s * n_per_row : nullptr; + (void)ggml_quantize_chunk(quant_type, f32_sample.data() + foff, quantized_buffer.data() + qoff, 0, rs, n_per_row, v); + qoff += row_sz * (size_t)rs; + foff += (size_t)rs * (size_t)n_per_row; + } + } + + // Dequantize into dequantized_buffer + { + if (quant_type == GGML_TYPE_F16) { + for (size_t r = 0; r < sample_rows; ++r) { + auto src = (const ggml_fp16_t *)(quantized_buffer.data() + r * row_sz); + float * dst = dequantized_buffer.data() + r * (size_t)n_per_row; + ggml_fp16_to_fp32_row(src, dst, (int)n_per_row); + } + } else if (quant_type == GGML_TYPE_BF16) { + for (size_t r = 0; r < sample_rows; ++r) { + auto src = (const ggml_bf16_t *)(quantized_buffer.data() + r * row_sz); + float * dst = dequantized_buffer.data() + r * (size_t)n_per_row; + ggml_bf16_to_fp32_row(src, dst, (int)n_per_row); + } + } else { + const ggml_type_traits * traits = ggml_get_type_traits(quant_type); + if (!traits || !traits->to_float) { + if (out_mse) { *out_mse = infinity; } + if (out_proj) { *out_proj = 0.0; } + return infinity; + } + for (size_t r = 0; r < sample_rows; ++r) { + const uint8_t * src = quantized_buffer.data() + r * row_sz; + float * dst = dequantized_buffer.data() + r * (size_t)n_per_row; + traits->to_float(src, dst, (int)n_per_row); + } + } + } + + // Compute error per slice with trimmed aggregation + auto trimmed_mean = [](std::vector & v) -> double { + const int64_t n = (int64_t)v.size(); + if (n == 0) { return 0.0; } + double sum = std::accumulate(v.begin(), v.end(), 0.0); + if (n < 50) { return sum / (double)n; } // too few elements to trim + int64_t k = (int64_t) std::floor(0.025 * (double)n); // trim 5% (2.5% each side) + std::sort(v.begin(), v.end()); + const auto num = (double)(n - 2 * k); + sum = std::accumulate(v.begin() + k, v.begin() + (n - k), 0.0); + return sum / std::max(1.0, num); + }; + + size_t off = 0; + size_t ridx = 0; + double total_mse = 0.0; + double total_proj = 0.0; + double total_bias = 0.0; + for (int64_t s = 0; s < ne2; ++s) { + const int64_t rs = rows_sample[s]; + if (rs == 0) { continue; } + + const float * v = has_values ? values_sample + s * n_per_row : nullptr; + const float * a = has_activations ? activations_sample + s * n_per_row : nullptr; + const double denom_bias = has_activations ? bias_denom[s] : 0.0; + std::vector row_mse_norm; + row_mse_norm.reserve(rs); + std::vector row_proj_norm; + if (a) { row_proj_norm.reserve(rs); } + + for (int64_t r = 0; r < rs; ++r, ++ridx) { + const float * x = f32_sample.data() + off; + const float * y = dequantized_buffer.data() + off; + double w_mse = 0.0; + double bias_num = 0.0; + for (int64_t j = 0; j < n_per_row; ++j) { + const double wj = v ? std::max(0.0f, v[j]) : 1.0; + const double e = y[j] - x[j]; + w_mse += wj * e * e; + if (a) { bias_num += wj * e * a[j]; } + } + + const double denom_x = row_sq_norm[ridx]; + const double m_norm = w_mse / (denom_x + epsilon); + row_mse_norm.push_back(std::isfinite(m_norm) ? m_norm : infinity); + + if (a) { + double p_norm = 0.0; + if (denom_bias > 0.0) { + const double proj = bias_num * bias_num / (denom_bias + epsilon); + p_norm = std::isfinite(proj) ? proj : 0.0; + } + + row_proj_norm.push_back(p_norm); + } + + off += (size_t)n_per_row; + } + + const double slice_mse = trimmed_mean(row_mse_norm) * (double)nrows; + const double slice_proj = a ? trimmed_mean(row_proj_norm) * (double)nrows : 0.0; + + total_mse += slice_mse; + total_proj += slice_proj; + + const double bl = slice_bias_lambda ? 
(double)std::max(0.0f, slice_bias_lambda[s]) : (double)tensor_bias_lambda; + total_bias += bl * slice_proj; + + if (!std::isfinite(total_mse) || !std::isfinite(total_proj) || !std::isfinite(total_bias)) { + if (out_mse) { *out_mse = infinity; } + if (out_proj) { *out_proj = 0.0; } + return infinity; + } + } + + if (out_mse) { *out_mse = total_mse; } + if (out_proj) { *out_proj = total_proj; } + + const double total_err = total_mse + total_bias; + return std::isfinite(total_err) ? total_err : infinity; + }; + + // Returns lambda per slice or 0.0 if no activations + auto estimate_lambda = [&](const float * values, const float * activations, const int64_t n_per_row, const int64_t ne2) -> std::vector { + const int64_t ns = std::max(1, ne2); + std::vector lambdas(ns, 0.0f); + if (!activations) { return lambdas; } + + for (int64_t s = 0; s < ns; ++s) { + const float * v = values ? values + s * n_per_row : nullptr; + const float * a = activations + s * n_per_row; + double s1 = 0.0; + double s2 = 0.0; + for (int64_t j = 0; j < n_per_row; ++j) { + const double w = v ? std::max(0.0f, v[j]) : 1.0; + const double aw = std::sqrt(w) * a[j]; + const double z = aw * aw; + s1 += z; + s2 += z * z; + } + + float l = 0.0f; + if (s1 > 0.0) { + const auto n = (double)n_per_row; + const double c = std::max(0.0, s2 / (s1 * s1 + epsilon) - 1.0 / n); + l = (float)std::clamp(12.0 * (c / (c + 1.0)), 0.0, 16.0); + } + + lambdas[(size_t)s] = l; + } + + return lambdas; + }; + + const auto bpw_data = load_bpw_state(); + + // Parallelize tensor processing (courtesy of https://github.com/ddh0) + auto process_tensor = [&](const llama_model_loader::llama_tensor_weight * tw, + std::vector> & thread_local_buffer, + std::mutex & loader_mutex, + std::mutex & log_mutex) -> std::optional + { + ggml_tensor * tensor = tw->tensor; + const std::string name = ggml_get_name(tensor); + if (bpw_stop.load(std::memory_order_relaxed)) { + return std::nullopt; + } + + // check for pre-computed results from a checkpoint file. + auto it_saved = bpw_data.find(name); + if (it_saved != bpw_data.end()) { + tensor_info info; + info.w = tw; + info.candidate = it_saved->second.candidate; + info.choice = it_saved->second.choice; + info.min_bpw = it_saved->second.min_bpw; + info.max_bpw = it_saved->second.max_bpw; + info.n_elements = it_saved->second.n_elements ? it_saved->second.n_elements : (size_t)ggml_nelements(tensor); + return info; + } + { + std::lock_guard lock(log_mutex); + LLAMA_LOG_INFO("\t%s: - processing tensor %45s \t(%12" PRId64 " elements)\n", func, name.c_str(), ggml_nelements(tensor)); + } + + if (!ml.use_mmap) { + if (thread_local_buffer.size() < ggml_nbytes(tensor)) { thread_local_buffer.resize(ggml_nbytes(tensor)); } + tensor->data = thread_local_buffer.data(); + } + { + std::lock_guard lock(loader_mutex); + ml.load_data_for(tensor); + } + + // Dequantize sampled rows into f32_sample + const int64_t n_per_row = tensor->ne[0]; + const int64_t nrows_total = tensor->ne[1]; + const int64_t ne2 = tensor->ne[2] > 0 ? tensor->ne[2] : 1; + + // Compute rows based on tensor shape and slice count + auto sample_rows = [](const int64_t n, const int64_t rows, const int64_t n2, const bool has_acts) -> int64_t { + const double tensor_budget = has_acts ? 1 * 1024 * 1024 : 0.5 * 1024 * 1024; + const double scale_rows = std::clamp(std::sqrt(std::max(1.0, (double)rows) / 4096.0), 0.5, 2.0); // favour more rows for large tensors + const double slice_budget = tensor_budget * scale_rows / std::max(1, n2); + const int64_t min_rows = has_acts ? 
128 : 64; + constexpr int64_t max_rows = 4096; // row limit to avoid excessive memory use + int64_t total_rows = std::llround(slice_budget / std::max(1, n)); + total_rows = std::max(min_rows, std::min(total_rows, std::min(rows, max_rows))); + if (rows <= min_rows * 2) { total_rows = rows; } + return total_rows; + }; + + const int64_t rows_sample_per_expert = sample_rows(n_per_row, nrows_total, ne2, activations_data != nullptr); + std::vector f32_sample; + f32_sample.reserve((size_t)ne2 * (size_t)std::min(nrows_total, rows_sample_per_expert) * (size_t)n_per_row); + std::vector rows_sample(ne2, 0); + const ggml_type src_type = tensor->type; + const ggml_type_traits * src_traits = ggml_get_type_traits(src_type); + const bool src_is_quant = ggml_is_quantized(src_type); + const size_t src_row_sz = ggml_row_size(src_type, n_per_row); + + // Convert a single row to fp32 + auto row_to_fp32 = [&](const uint8_t * src, float * dst) { + const ggml_type t = src_type; + if (t == GGML_TYPE_F32) { + std::memcpy(dst, src, sizeof(float) * (size_t)n_per_row); + return; + } + if (t == GGML_TYPE_F16) { + ggml_fp16_to_fp32_row((const ggml_fp16_t *)src, dst, (int)n_per_row); + return; + } + if (t == GGML_TYPE_BF16) { + ggml_bf16_to_fp32_row((const ggml_bf16_t *)src, dst, (int)n_per_row); + return; + } + if (src_is_quant) { + GGML_ASSERT(src_traits && src_traits->to_float); + src_traits->to_float(src, dst, (int)n_per_row); + return; + } + + throw std::runtime_error(format("unsupported src type %s for sampling", ggml_type_name(t))); + }; + + // Sample rows randomly per slice + { + f32_sample.clear(); + std::vector row_buffer(n_per_row); + for (int64_t slice = 0; slice < ne2; ++slice) { + std::mt19937 rng(std::hash{}(name) ^ arbitrary_magic ^ slice); + const int64_t rows_sample_max = std::max(1, std::min(nrows_total, rows_sample_per_expert)); + const int64_t stride = std::max(1, nrows_total / rows_sample_max); + int64_t offset = 0; + if (stride > 1) { + std::uniform_int_distribution dist(0, stride - 1); + offset = dist(rng); + } + + int64_t current = 0; + for (int64_t r = offset; r < nrows_total && current < rows_sample_max; r += stride) { + const uint8_t * src_row = (const uint8_t *)tensor->data + slice * (src_row_sz * nrows_total) + r * src_row_sz; + if (src_type == GGML_TYPE_F32) { + const auto *src_f32 = (const float *)src_row; + f32_sample.insert(f32_sample.end(), src_f32, src_f32 + n_per_row); + } else { + row_to_fp32(src_row, row_buffer.data()); + f32_sample.insert(f32_sample.end(), row_buffer.begin(), row_buffer.end()); + } + + ++current; + } + + rows_sample[slice] = current; + } + } + + auto side_data = [&](const std::unordered_map> * m, const std::string & tensor_name) { + if (!m) { return std::pair{nullptr, 0}; } + + const std::string key = remap_imatrix(tensor_name, mapped); + const auto it = m->find(key); + return it == m->end() ? 
std::pair{nullptr, 0} : std::pair{ it->second.data(), it->second.size() }; + }; + + // Copy this row's side data (values and activations), or broadcasts to all slices + auto copy_or_broadcast = [&](const float * src, size_t src_sz, std::vector & dst) { + dst.clear(); + if (!src || src_sz == 0) { return; } + + const size_t want = (size_t)ne2 * (size_t)n_per_row; + if (src_sz == want) { + dst.assign(src, src + want); + return; + } + if (src_sz == (size_t)n_per_row) { + dst.resize(want); + for (int64_t s = 0; s < ne2; ++s) { + std::memcpy(dst.data() + s * n_per_row, src, n_per_row * sizeof(float)); + } + return; + } + + std::lock_guard lock(log_mutex); + LLAMA_LOG_WARN("%s: side data size mismatch for %s: got %zu, expected %zu or %zu; ignoring\n", func, name.c_str(), src_sz, (size_t)n_per_row, want); + }; + + const auto [values_all, values_sz] = side_data(values_data, name); + const auto [activations_all, activations_sz] = side_data(activations_data, name); + std::vector values_sample; + std::vector activations_sample; + if (values_all) { copy_or_broadcast(values_all, values_sz, values_sample); } + if (activations_all) { copy_or_broadcast(activations_all, activations_sz, activations_sample); } + + tensor_info info; + info.w = tw; + info.n_elements = ggml_nelements(tensor); + size_t total_sampled_rows = f32_sample.size() / n_per_row; + + // Build list of candidate types first (compatible ones) + const bool has_valid_imatrix = !values_sample.empty() && values_sample.size() == (size_t)ne2 * (size_t)n_per_row; + size_t max_row_sz = 0; + const ggml_type * base_arr = quant_types; + const size_t base_sz = std::size(quant_types); + std::vector compatible_candidates; + compatible_candidates.reserve(base_sz); + + for (size_t i = 0; i < base_sz; ++i) { + ggml_type ts_type = base_arr[i]; + if (is_iq(ts_type) && !has_valid_imatrix) { + std::lock_guard lock(log_mutex); + LLAMA_LOG_WARN("\t%s: skipping %s for %s, no or mismatched imatrix\n", func, ggml_type_name(ts_type), name.c_str()); + continue; + } + + ggml_type tt = make_compatible(tensor, ts_type); + if (!is_compatible(tensor, tt)) { continue; } + compatible_candidates.push_back(tt); + max_row_sz = std::max(max_row_sz, ggml_row_size(tt, n_per_row)); + } + + std::sort(compatible_candidates.begin(), compatible_candidates.end()); + compatible_candidates.erase(std::unique(compatible_candidates.begin(), compatible_candidates.end()), compatible_candidates.end()); + + // Adjusts the trade-off between systematic bias (introduced by block‑wise scaling) and MSE. + // Larger values favours quantisation types that produce smaller bias even if the MSE is slightly bigger + float tensor_lambda = 0.0f; + std::vector lambdas; + const float * values = values_sample.empty() ? nullptr : values_sample.data(); + const float * activations = activations_sample.empty() ? nullptr : activations_sample.data(); + double acc = 0.0; + int ns = 0; + lambdas = estimate_lambda(values, activations, n_per_row, ne2); + for (float l : lambdas) { acc += l; ++ns; } + tensor_lambda = ns ? (float)(acc / ns) : 0.0f; + + // Evaluate candidates + std::vector eval_candidates(compatible_candidates.size()); + std::vector quantized_buffer(max_row_sz * total_sampled_rows); + std::vector dequantized_buffer(f32_sample.size()); + const float * slice_lambda = lambdas.empty() ? 
nullptr : lambdas.data(); + for (size_t i = 0; i < compatible_candidates.size(); ++i) { + if (bpw_stop.load(std::memory_order_relaxed)) { return std::nullopt; } + + const ggml_type tensor_type = compatible_candidates[i]; + const auto bpw = (float)tensor_bpw(tensor, tensor_type); + const size_t bytes = tensor_bytes(tensor, tensor_type); + double mse = 0.0; + double proj = 0.0; + const auto err = estimate_error(tensor, tensor_type, f32_sample, rows_sample, values, activations, + quantized_buffer, dequantized_buffer, tensor_lambda, slice_lambda, &mse, &proj); + eval_candidates[i] = candidate_types{ tensor_type, bpw, bytes, err, mse, proj }; + } + + if (bpw_stop.load(std::memory_order_relaxed)) { return std::nullopt; } + + // Check if biasing is needed + bool bias_needed = false; + if (!lambdas.empty()) { + int min_mse = -1; + int min_bias = -1; + double best_mse = std::numeric_limits::infinity(); + double best_err = std::numeric_limits::infinity(); + for (int i = 0; i < (int)eval_candidates.size(); ++i) { + const auto & c = eval_candidates[i]; + if (c.bytes == 0) { continue; } + if (c.mse < best_mse) { + best_mse = c.mse; + min_mse = i; + } + if (c.error < best_err) { + best_err = c.error; + min_bias = i; + } + } + + if (min_mse != min_bias) { + bias_needed = true; + } else { + double max_rel_bias = 0.0; + for (const auto & c : eval_candidates) { + if (c.bytes == 0) { continue; } + const double mse = std::max(c.mse, epsilon); + const double bias_term = std::max(0.0, c.error - c.mse); + max_rel_bias = std::max(bias_term / mse, max_rel_bias); + } + + bias_needed = max_rel_bias >= 0.5; // >= 50% of MSE? + } + } + + for (auto & c : eval_candidates) { + if (c.bytes == 0) { continue; } + const double final_err = bias_needed ? c.error : c.mse; + info.candidate.push_back(candidate_types{ c.type, c.bpw, c.bytes, final_err, c.mse, c.proj }); + } + + if (info.candidate.empty()) { + // As a last resort, keep original type + float bpw = ggml_nbytes(tensor) * 8.0f / info.n_elements; + info.candidate.push_back(candidate_types{ tensor->type, bpw, ggml_nbytes(tensor), 0.0 }); + } + + // Keep only the pareto‑optimal candidates and enforce convexity in (bytes, error) curve + auto pareto_convex = [&](std::vector & candidates) { + if (candidates.empty()) { return; } + + std::sort(candidates.begin(), candidates.end(), [](const candidate_types & a, const candidate_types & b) { + if (a.bytes != b.bytes) { return a.bytes < b.bytes; } + return a.error < b.error; + }); + candidates.erase(std::unique(candidates.begin(), candidates.end(), [](const candidate_types & a, const candidate_types & b) { + return a.bytes == b.bytes; + }), candidates.end()); + std::vector pareto; + pareto.reserve(candidates.size()); + double best_err = infinity; + for (const auto & c : candidates) { + if (c.error < best_err) { + best_err = c.error; + pareto.push_back(c); + } + } + candidates.swap(pareto); + if (candidates.size() < 3) { return; } // need at least 3 points to do convex hull + + // Convex hull (lower envelope) + auto cross_product = [](const candidate_types & h0, const candidate_types & h1, const candidate_types & p) -> double { + const double dx1 = (double)h1.bytes - (double)h0.bytes; + const double dy1 = h1.error - h0.error; + const double dx2 = (double)p.bytes - (double)h0.bytes; + const double dy2 = p.error - h0.error; + return dx1 * dy2 - dx2 * dy1; + }; + std::vector hull; hull.reserve(candidates.size()); + for (const auto & c : candidates) { + while (hull.size() >= 2) { + if (cross_product(hull[hull.size() - 2], 
hull[hull.size() - 1], c) <= epsilon) { + hull.pop_back(); + } else { + break; + } + } + + hull.push_back(c); + } + + candidates.swap(hull); + }; + + pareto_convex(info.candidate); + + // Initialize choice at the smallest bpw candidate + info.choice = 0; + info.min_bpw = info.candidate.front().bpw; + info.max_bpw = info.candidate.back().bpw; + + return info; + }; + + std::vector all; // this vector will be populated by the parallel workers + { + std::atomic tensor_idx{0}; // shared work queue index for all threads + const size_t tensors_to_process = tensors.size(); + std::mutex loader_mutex; + std::mutex log_mutex; + std::mutex results_mutex; + std::vector workers; + int threads_to_spawn = std::max(1, std::min(nthread, (int)tensors_to_process)); + + for (int i = 0; i < threads_to_spawn; ++i) { + workers.emplace_back([&]() { + std::vector> thread_local_buffer; + while (true) { + const size_t current_idx = tensor_idx.fetch_add(1); + if (current_idx >= tensors_to_process) { break; } + const auto * tw = tensors[current_idx]; + if (!can_quantize(tw->tensor)) { continue; } + // Execute the main processing logic for this tensor + std::optional result_info = process_tensor(tw, thread_local_buffer, loader_mutex, log_mutex); + if (result_info) { + std::lock_guard lock(results_mutex); + all.push_back(std::move(*result_info)); + } + } + }); + } + + for (auto & w : workers) { w.join(); } + } + + check_signal_handler(all); + if (params->keep_bpw_state) { save_bpw_state(all); } + + if (all.empty()) { return {}; } + + // Compute total elements across all tensors and bytes for non-quantizable tensors + size_t nq_elements = 0; + size_t nq_bytes = 0; + for (const auto * it : tensors) { + const ggml_tensor * tensor = it->tensor; + const std::string name = ggml_get_name(tensor); + nq_elements += (size_t)ggml_nelements(tensor); + if (!can_quantize(tensor)) { nq_bytes += ggml_nbytes(tensor); } + } + + auto total_bytes = [&]() -> size_t { + size_t tb = 0; + for (const auto & ti : all) { + tb += ti.candidate[ti.choice].bytes; + } + + return tb; + }; + + size_t q_elements = 0; + size_t min_bytes = 0; + size_t max_bytes = 0; + for (const auto & ti : all) { + q_elements += (size_t)ti.n_elements; + min_bytes += ti.candidate.front().bytes; // smallest candidate per tensor + max_bytes += ti.candidate.back().bytes; // largest candidate per tensor + } + + if (q_elements == 0) { return {}; } + + const double target_bpw = params->target_bpw; + size_t target_total_bytes = std::llround(target_bpw * (double)nq_elements / 8.0); + size_t budget_bytes = target_total_bytes >= nq_bytes ? 
target_total_bytes - nq_bytes : min_bytes; + + // Get the types' override + auto emit_overrides = [&]() -> std::unordered_map { + std::unordered_map overrides; + LLAMA_LOG_INFO("%s: - estimated tensor quantization mix:\n", func); + for (const auto & ti : all) { + LLAMA_LOG_INFO("\t%s: %45s - \t%8s, \t%1.4f bpw,\terror: %.4f\n", + func, ggml_get_name(ti.w->tensor), ggml_type_name(ti.candidate[ti.choice].type), ti.candidate[ti.choice].bpw, ti.candidate[ti.choice].error); + overrides[ggml_get_name(ti.w->tensor)] = ti.candidate[ti.choice].type; + } + + return overrides; + }; + + if (budget_bytes <= min_bytes) { + for (auto & ti : all) { ti.choice = 0; } + return emit_overrides(); + } + if (budget_bytes >= max_bytes) { + for (auto & ti : all) { ti.choice = (int)ti.candidate.size() - 1; } + return emit_overrides(); + } + + // Certain tensors have a higher impact on model quality, so we apply a lower penalty to them + auto is_important = [&](const std::string & tensor_name) -> bool { + bool important = tensor_name == "output.weight"; + if (!important && !params->no_importance) { + important = tensor_name.find(".attn_v.weight") != std::string::npos || + tensor_name.find(".time_mix_value.weight") != std::string::npos || + tensor_name.find(".ffn_down.weight") != std::string::npos || + tensor_name.find(".ffn_down_exps.weight") != std::string::npos || + tensor_name.find(".attn_output.weight") != std::string::npos || + tensor_name.find(".time_mix_output.weight") != std::string::npos || + tensor_name.find(".attn_o.weight") != std::string::npos; + } + + return important; + }; + + // Lagrangian relaxation to minimize error subject to a bpw target constraint + auto lagrange_penalty = [&](const double mu, std::vector & choice, size_t & bytes, double & err) { + choice.resize(all.size()); + bytes = 0; + err = 0.0; + for (size_t i = 0; i < all.size(); ++i) { + const auto & candidate = all[i].candidate; + const std::string tensor_name = ggml_get_name(all[i].w->tensor); + double effective_mu = mu; + if (is_important(tensor_name)) { effective_mu *= 0.1; } // important tensors get 10x lower penalty + + int best_j = 0; + double best_val = infinity; + for (int j = 0; j < (int)candidate.size(); ++j) { + const double bits = (double)candidate[j].bytes * 8.0; + const double val = candidate[j].error + effective_mu * bits; + if (val < best_val - epsilon || (std::abs(val - best_val) <= epsilon && candidate[j].bytes < candidate[best_j].bytes)) { + best_val = val; + best_j = j; + } + } + + choice[i] = best_j; + bytes += candidate[best_j].bytes; + err += candidate[best_j].error; + } + }; + + size_t bytes_lo = 0; + size_t bytes_hi = 0; + size_t bytes_mid = 0; + double mu_lo = 0.0; + double mu_hi = 1.0; + double err_lo = 0.0; + double err_hi = 0.0; + double err_mid = 0.0; + std::vector choice_lo; + std::vector choice_hi; + std::vector choice_mid; + std::vector best_under_choice; + std::vector best_over_choice; + + lagrange_penalty(mu_lo, choice_lo, bytes_lo, err_lo); + + // Increase mu until we get under budget or hit a safety cap + { + int expand = 0; + size_t prev_bytes_hi = std::numeric_limits::max(); + while (true) { + lagrange_penalty(mu_hi, choice_hi, bytes_hi, err_hi); + if (bytes_hi <= budget_bytes) { break; } + if (bytes_hi >= prev_bytes_hi) { break; } + prev_bytes_hi = bytes_hi; + + mu_hi *= 2.0; // double the penalty multiplier to reduce tensor sizes + if (++expand > 60) { break; } // safety cap to prevent an infinite loop + } + } + + double best_under_gap = infinity; + double best_over_gap = infinity; + double 
best_under_err = infinity; + double best_over_err = infinity; + for (int it = 0; it < 40; ++it) { // binary search iterations for optimal Lagrange multiplier (40 ≈ 1e-12 precision) + double mu = 0.5 * (mu_lo + mu_hi); // midpoint of current bounds + lagrange_penalty(mu, choice_mid, bytes_mid, err_mid); + + const double gap = std::abs((double)bytes_mid - (double)budget_bytes); + if (bytes_mid > budget_bytes) { + // Too big, need stronger penalty + mu_lo = mu; + if (gap < best_over_gap - epsilon || (std::abs(gap - best_over_gap) <= epsilon && err_mid < best_over_err)) { + best_over_gap = gap; + best_over_err = err_mid; + best_over_choice = choice_mid; + } + } else { + // Under budget, good candidate + mu_hi = mu; + if (gap < best_under_gap - epsilon || (std::abs(gap - best_under_gap) <= epsilon && err_mid < best_under_err)) { + best_under_gap = gap; + best_under_err = err_mid; + best_under_choice = choice_mid; + } + } + } + + if (!best_under_choice.empty()) { + for (size_t i = 0; i < all.size(); ++i) { + all[i].choice = best_under_choice[i]; + } + } else if (!best_over_choice.empty()) { + for (size_t i = 0; i < all.size(); ++i) { + all[i].choice = best_over_choice[i]; + } + } else { + // Pick whichever side we already have, or keep minimal + if (bytes_hi <= budget_bytes && !choice_hi.empty()) { + for (size_t i = 0; i < all.size(); ++i) { + all[i].choice = choice_hi[i]; + } + } else { + for (auto & ti : all) { + ti.choice = 0; + } + } + } + + // Spend any remaining budget with best upgrades that still fit (one pass) + { + auto cur_bytes = total_bytes(); + while (true) { + int best_i = -1; + int best_j = -1; + double best_ratio = -1.0; + double best_gain = -1.0; + + for (int i = 0; i < (int)all.size(); ++i) { + const auto & ti = all[i]; + const std::string tensor_name = ggml_get_name(ti.w->tensor); + int j = ti.choice + 1; + if (j >= (int)ti.candidate.size()) { continue; } // no upgrade available + + size_t delta_bytes = ti.candidate[j].bytes - ti.candidate[ti.choice].bytes; + if (cur_bytes + delta_bytes > budget_bytes) { continue; } // won't fit in budget + + double err_gain = std::max(0.0, ti.candidate[ti.choice].error - ti.candidate[j].error); + if (err_gain < epsilon) { continue; } // no error improvement + + double ratio = err_gain / (double)delta_bytes; // error reduction per byte + if (is_important(tensor_name)) { ratio *= 5.0; } // important tensors get 5x boost + + // For tie-breaking, prioritize the largest absolute error improvement. 
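+                // i.e. a strictly better error-reduction-per-byte ratio always wins; when two upgrades are within
+                // epsilon of each other on that ratio, the one removing more absolute error is preferred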
+ if (ratio > best_ratio + epsilon || (std::abs(ratio - best_ratio) <= epsilon && err_gain > best_gain)) { + best_ratio = ratio; + best_gain = err_gain; + best_i = i; + best_j = j; + } + } + + if (best_i < 0) { break; } // no more upgrades within budget found + + size_t upgrade_cost = all[best_i].candidate[best_j].bytes - all[best_i].candidate[all[best_i].choice].bytes; + all[best_i].choice = best_j; + cur_bytes += upgrade_cost; + } + } + + delete_bpw_state(); + + return emit_overrides(); +} + static void llama_model_quantize_impl(const std::string & fname_inp, const std::string & fname_out, const llama_model_quantize_params * params) { ggml_type default_type; llama_ftype ftype = params->ftype; @@ -610,14 +1841,15 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: if (params->only_copy) { ftype = ml.ftype; } - const std::unordered_map> * imatrix_data = nullptr; + const std::unordered_map> * values_data = nullptr; + const std::unordered_map> * activations_data = nullptr; if (params->imatrix) { - imatrix_data = static_cast>*>(params->imatrix); - if (imatrix_data) { - LLAMA_LOG_INFO("================================ Have weights data with %d entries\n",int(imatrix_data->size())); + values_data = static_cast>*>(params->imatrix); + if (values_data) { + LLAMA_LOG_INFO("================================ Have weights data with %d entries",int(values_data->size())); qs.has_imatrix = true; // check imatrix for nans or infs - for (const auto & kv : *imatrix_data) { + for (const auto & kv : *values_data) { for (float f : kv.second) { if (!std::isfinite(f)) { throw std::runtime_error(format("imatrix contains non-finite value %f\n", f)); @@ -626,8 +1858,23 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: } } } + if (params->activations) { + activations_data = static_cast>*>(params->activations); + if (activations_data) { + LLAMA_LOG_INFO(" and %d activations",int(activations_data->size())); + qs.has_activations = true; + // check activations for nans or infs + for (const auto & kv : *activations_data) { + for (float f : kv.second) { + if (!std::isfinite(f)) { + throw std::runtime_error(format("activations contain non-finite value %f\n", f)); + } + } + } + } + } + LLAMA_LOG_INFO("\n"); - const size_t align = GGUF_DEFAULT_ALIGNMENT; gguf_context_ptr ctx_out { gguf_init_empty() }; std::vector prune_list = {}; @@ -756,6 +2003,27 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: } } + std::unordered_map bpw_overrides = {}; + if (params->target_bpw != -1.0f && !params->only_copy) { + if (params->imatrix) { + if (params->activations) { + LLAMA_LOG_INFO("%s: imatrix has activations, process will be more accurate\n", __func__); + } else { + LLAMA_LOG_INFO("%s: imatrix does not have activations, process may be less accurate\n", __func__); + } + if (params->no_importance) { + LLAMA_LOG_INFO("%s: distributing bpw budget equitably across all tensors\n", __func__); + } else { + LLAMA_LOG_INFO("%s: assigning more bpw budget to important tensors\n", __func__); + } + LLAMA_LOG_INFO("%s: computing tensor quantization mix to achieve %.4f bpw\n", __func__, params->target_bpw); + + bpw_overrides = target_bpw_type(ml, model, tensors, mapped, values_data, activations_data, params, nthread); + } else { + LLAMA_LOG_WARN("%s: --target-bpw requires an imatrix but none was provided, option will be ignored\n", __func__); + } + } + int cur_split = -1; std::ofstream fout; auto close_ofstream = [&]() { @@ -788,6 +2056,7 @@ static void 
llama_model_quantize_impl(const std::string & fname_inp, const std:: const auto tn = LLM_TN(model.arch); new_ofstream(0); for (const auto * it : tensors) { + const size_t align = GGUF_DEFAULT_ALIGNMENT; const auto & weight = *it; ggml_tensor * tensor = weight.tensor; if (weight.idx != cur_split && params->keep_split) { @@ -806,62 +2075,10 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: ml.load_data_for(tensor); LLAMA_LOG_INFO("[%4d/%4d] %36s - [%s], type = %6s, ", - ++idx, ml.n_tensors, - ggml_get_name(tensor), - llama_format_tensor_shape(tensor).c_str(), - ggml_type_name(tensor->type)); - - // This used to be a regex, but has an extreme cost to compile times. - bool quantize = name.rfind("weight") == name.size() - 6; // ends with 'weight'? - - // quantize only 2D and 3D tensors (experts) - quantize &= (ggml_n_dims(tensor) >= 2); - - // do not quantize norm tensors - quantize &= name.find("_norm.weight") == std::string::npos; + ++idx, ml.n_tensors, ggml_get_name(tensor), llama_format_tensor_shape(tensor).c_str(), ggml_type_name(tensor->type)); + bool quantize = ggml_n_dims(tensor) >= 2 && is_quantizable(name, model.arch, params); quantize &= params->quantize_output_tensor || name != "output.weight"; - quantize &= !params->only_copy; - - // do not quantize expert gating tensors - // NOTE: can't use LLM_TN here because the layer number is not known - quantize &= name.find("ffn_gate_inp.weight") == std::string::npos; - - // these are very small (e.g. 4x4) - quantize &= name.find("altup") == std::string::npos; - quantize &= name.find("laurel") == std::string::npos; - - // these are not too big so keep them as it is - quantize &= name.find("per_layer_model_proj") == std::string::npos; - - // do not quantize positional embeddings and token types (BERT) - quantize &= name != LLM_TN(model.arch)(LLM_TENSOR_POS_EMBD, "weight"); - quantize &= name != LLM_TN(model.arch)(LLM_TENSOR_TOKEN_TYPES, "weight"); - - // do not quantize Mamba's small yet 2D weights - // NOTE: can't use LLM_TN here because the layer number is not known - quantize &= name.find("ssm_conv1d.weight") == std::string::npos; - quantize &= name.find("shortconv.conv.weight") == std::string::npos; - - // do not quantize RWKV's small yet 2D weights - quantize &= name.find("time_mix_first.weight") == std::string::npos; - quantize &= name.find("time_mix_w0.weight") == std::string::npos; - quantize &= name.find("time_mix_w1.weight") == std::string::npos; - quantize &= name.find("time_mix_w2.weight") == std::string::npos; - quantize &= name.find("time_mix_v0.weight") == std::string::npos; - quantize &= name.find("time_mix_v1.weight") == std::string::npos; - quantize &= name.find("time_mix_v2.weight") == std::string::npos; - quantize &= name.find("time_mix_a0.weight") == std::string::npos; - quantize &= name.find("time_mix_a1.weight") == std::string::npos; - quantize &= name.find("time_mix_a2.weight") == std::string::npos; - quantize &= name.find("time_mix_g1.weight") == std::string::npos; - quantize &= name.find("time_mix_g2.weight") == std::string::npos; - quantize &= name.find("time_mix_decay_w1.weight") == std::string::npos; - quantize &= name.find("time_mix_decay_w2.weight") == std::string::npos; - quantize &= name.find("time_mix_lerp_fused.weight") == std::string::npos; - - // do not quantize relative position bias (T5) - quantize &= name.find("attn_rel_b.weight") == std::string::npos; // do not quantize specific multimodal tensors quantize &= name.find(".position_embd.") == std::string::npos; @@ 
-874,17 +2091,27 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: new_type = default_type; // get more optimal quantization type based on the tensor shape, layer, etc. - if (!params->pure && ggml_is_quantized(default_type)) { + if (!params->pure && (ggml_is_quantized(default_type) || params->target_bpw != -1.0f)) { int fallback = qs.n_fallback; new_type = llama_tensor_get_type(qs, new_type, tensor, ftype); - // unless the user specifies a type, and the tensor geometry will not require fallback quantisation + + // get quantization type overrides targeting a given bits per weight budget + if (params->target_bpw != -1.0f && !bpw_overrides.empty()) { + const auto override = bpw_overrides.find(name); + if (override != bpw_overrides.end() && override->second != new_type) { + LLAMA_LOG_DEBUG("(bpw override %s) ", ggml_type_name(new_type)); + new_type = override->second; + } + } + + // unless the user specifies a type, and the tensor shape will not require fallback quantisation if (params->tensor_types && qs.n_fallback - fallback == 0) { const std::vector & tensor_types = *static_cast *>(params->tensor_types); const std::string tensor_name(tensor->name); for (const auto & [tname, qtype] : tensor_types) { if (std::regex pattern(tname); std::regex_search(tensor_name, pattern)) { if (qtype != new_type) { - LLAMA_LOG_DEBUG("(overriding %s) ", ggml_type_name(new_type)); + LLAMA_LOG_DEBUG("(type override %s) ", ggml_type_name(new_type)); new_type = qtype; // if two or more types are specified for the same tensor, the last match wins } } @@ -912,10 +2139,10 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: const int64_t nelements = ggml_nelements(tensor); const float * imatrix = nullptr; - if (imatrix_data) { - auto it = imatrix_data->find(remap_imatrix(tensor->name, mapped)); - if (it == imatrix_data->end()) { - LLAMA_LOG_INFO("\n====== %s: did not find weights for %s\n", __func__, tensor->name); + if (values_data) { + auto it = values_data->find(remap_imatrix(tensor->name, mapped)); + if (it == values_data->end()) { + LLAMA_LOG_INFO("\n====== %s: did not find weights for %s, ", __func__, tensor->name); } else { if (it->second.size() == (size_t)tensor->ne[0]*tensor->ne[2]) { imatrix = it->second.data(); @@ -1049,9 +2276,14 @@ llama_model_quantize_params llama_model_quantize_default_params() { /*.pure =*/ false, /*.keep_split =*/ false, /*.imatrix =*/ nullptr, + /*.activations =*/ nullptr, /*.kv_overrides =*/ nullptr, /*.tensor_type =*/ nullptr, - /*.prune_layers =*/ nullptr + /*.prune_layers =*/ nullptr, + /*.target_bpw =*/ -1.0f, + /*.keep_bpw_state =*/ false, + /*.bpw_state =*/ nullptr, + /*.no_importance =*/ false }; return result; diff --git a/tools/quantize/README.md b/tools/quantize/README.md index 22f0710286..986ba95be5 100644 --- a/tools/quantize/README.md +++ b/tools/quantize/README.md @@ -56,8 +56,10 @@ Options: * `--keep-split` will generate the quantized model in the same shards as the input file otherwise it will produce a single quantized file Advanced options: -* `--tensor-type` quantize specific tensor(s) to specific quant types. Supports regex syntax. May be specified multiple times. +* `--tensor-type` quantize specific tensor(s) to specific quant types. Supports regex syntax. 
May be specified multiple times * `--prune-layers` prune (remove) the layers in the list +* `--target-bpw` automatically choose quant types so that the overall model size matches a given bits per weight (bpw) average +* `--no-importance` during bpw computation, treat each tensor equally instead of prioritizing some. It may yield better quality for some models * `--override-kv` option to override model metadata by key in the quantized model. May be specified multiple times Examples: @@ -97,59 +99,54 @@ Examples: ./llama-quantize --imatrix imatrix.gguf --override-kv qwen3moe.expert_used_count=int:16 --prune-layers 20,21,22 input-model-f32.gguf pruned-model-f32.gguf copy 8 ``` +```bash +# quantize model targeting a specific bpw average and save the bpw computations to the default file. Model type is optional and can be omitted +./llama-quantize --target-bpw 4.567 --keep-bpw-state --imatrix imatrix.gguf input-model-f32.gguf 8 +``` + ## Memory/Disk Requirements When running the larger models, make sure you have enough disk space to store all the intermediate files. As the models are currently fully loaded into memory, you will need adequate disk space to save them and sufficient RAM to load them. At the moment, memory and disk requirements are the same. For exmaple (Llama 3.1): | Model | Original size | Quantized size (Q4_K_M) | -| ----: | ------------: | ----------------------: | +|------:|--------------:|------------------------:| | 8B | 32.1 GB | 4.9 GB | | 70B | 280.9 GB | 43.1 GB | | 405B | 1,625.1 GB | 249.1 GB | - ## Quantization Several quantization methods are supported. They differ in the resulting model disk size and inference speed. For example, ### [meta-llama/Llama-3.1-8B](https://huggingface.co/meta-llama/Llama-3.1-8B) -| Measure | IQ1_S | IQ1_M | IQ2_XXS | IQ2_XS | IQ2_S | IQ2_M | -| --------------------------- | ------------ | ------------ | ------------ | ------------- | ------------- | ------------ | -| bits/weight | 2.0042 | 2.1460 | 2.3824 | 2.5882 | 2.7403 | 2.9294 | -| size (GiB) | 1.87 | 2.01 | 2.23 | 2.42 | 2.56 | 2.74 | -| prompt processing t/s @ 512 | 858.88 ±1.22 | 847.99 ±0.47 | 852.39 ±0.85 | 826.99 ±12.51 | 783.55 ±13.73 | 787.68 ±7.00 | -| text generation t/s @ 128 | 79.73 ±0.79 | 72.92 ±0.14 | 79.86 ±0.22 | 78.04 ±0.46 | 77.30 ±2.47 | 74.44 ±0.15 | - -| Measure | IQ3_XXS | IQ3_XS | IQ3_S | IQ3_M | IQ4_XS | IQ4_NL | -| --------------------------- | ------------ | ------------ | ------------ | ------------- | ------------- | ------------ | -| bits/weight | 3.2548 | 3.4977 | 3.6606 | 3.7628 | 4.4597 | 4.6818 | -| size (GiB) | 3.04 | 3.27 | 3.42 | 3.52 | 4.17 | 4.38 | -| prompt processing t/s @ 512 | 813.88 ±6.53 | 708.71 ±1.26 | 798.78 ±8.81 | 768.70 ±13.73 | 771.80 ±11.38 | 806.03 ±7.07 | -| text generation t/s @ 128 | 73.95 ±0.20 | 71.67 ±0.54 | 69.31 ±0.63 | 70.15 ±0.33 | 77.51 ±0.20 | 76.63 ±0.28 | - - -| Measure | Q2_K_S | Q2_K | Q3_K_S | Q3_K_M | Q3_K_L | Q4_K_S | -| --------------------------- | ------------ | ------------ | ------------ | ------------ | ------------ | ------------ | -| bits/weight | 2.9697 | 3.1593 | 3.6429 | 3.9960 | 4.2979 | 4.6672 | -| size (GiB) | 2.78 | 2.95 | 3.41 | 3.74 | 4.02 | 4.36 | -| prompt processing t/s @ 512 | 798.91 ±6.40 | 784.45 ±7.85 | 752.17 ±7.94 | 783.44 ±9.92 | 761.17 ±7.55 | 818.55 ±9.58 | -| text generation t/s @ 128 | 90.01 ±0.12 | 79.85 ±0.20 | 69.84 ±0.18 | 71.68 ±0.22 | 69.38 ±0.49 | 76.71 ±0.20 | - -| Measure | Q4_K_S | Q4_K_M | Q5_K_S | Q5_K_M | Q6_K | Q8_0 | -| --------------------------- | ------------ | 
+
 ## Memory/Disk Requirements
 
 When running the larger models, make sure you have enough disk space to store all the intermediate files. As the models are currently fully loaded into memory, you will need adequate disk space to save them and sufficient RAM to load them. At the moment, memory and disk requirements are the same. For example (Llama 3.1):
 
 | Model | Original size | Quantized size (Q4_K_M) |
-| ----: | ------------: | ----------------------: |
+|------:|--------------:|------------------------:|
 | 8B | 32.1 GB | 4.9 GB |
 | 70B | 280.9 GB | 43.1 GB |
 | 405B | 1,625.1 GB | 249.1 GB |
 
-
 ## Quantization
 
 Several quantization methods are supported. They differ in the resulting model disk size and inference speed. For example,
 
 ### [meta-llama/Llama-3.1-8B](https://huggingface.co/meta-llama/Llama-3.1-8B)
 
-| Measure | IQ1_S | IQ1_M | IQ2_XXS | IQ2_XS | IQ2_S | IQ2_M |
-| --------------------------- | ------------ | ------------ | ------------ | ------------- | ------------- | ------------ |
-| bits/weight | 2.0042 | 2.1460 | 2.3824 | 2.5882 | 2.7403 | 2.9294 |
-| size (GiB) | 1.87 | 2.01 | 2.23 | 2.42 | 2.56 | 2.74 |
-| prompt processing t/s @ 512 | 858.88 ±1.22 | 847.99 ±0.47 | 852.39 ±0.85 | 826.99 ±12.51 | 783.55 ±13.73 | 787.68 ±7.00 |
-| text generation t/s @ 128 | 79.73 ±0.79 | 72.92 ±0.14 | 79.86 ±0.22 | 78.04 ±0.46 | 77.30 ±2.47 | 74.44 ±0.15 |
-
-| Measure | IQ3_XXS | IQ3_XS | IQ3_S | IQ3_M | IQ4_XS | IQ4_NL |
-| --------------------------- | ------------ | ------------ | ------------ | ------------- | ------------- | ------------ |
-| bits/weight | 3.2548 | 3.4977 | 3.6606 | 3.7628 | 4.4597 | 4.6818 |
-| size (GiB) | 3.04 | 3.27 | 3.42 | 3.52 | 4.17 | 4.38 |
-| prompt processing t/s @ 512 | 813.88 ±6.53 | 708.71 ±1.26 | 798.78 ±8.81 | 768.70 ±13.73 | 771.80 ±11.38 | 806.03 ±7.07 |
-| text generation t/s @ 128 | 73.95 ±0.20 | 71.67 ±0.54 | 69.31 ±0.63 | 70.15 ±0.33 | 77.51 ±0.20 | 76.63 ±0.28 |
-
-
-| Measure | Q2_K_S | Q2_K | Q3_K_S | Q3_K_M | Q3_K_L | Q4_K_S |
-| --------------------------- | ------------ | ------------ | ------------ | ------------ | ------------ | ------------ |
-| bits/weight | 2.9697 | 3.1593 | 3.6429 | 3.9960 | 4.2979 | 4.6672 |
-| size (GiB) | 2.78 | 2.95 | 3.41 | 3.74 | 4.02 | 4.36 |
-| prompt processing t/s @ 512 | 798.91 ±6.40 | 784.45 ±7.85 | 752.17 ±7.94 | 783.44 ±9.92 | 761.17 ±7.55 | 818.55 ±9.58 |
-| text generation t/s @ 128 | 90.01 ±0.12 | 79.85 ±0.20 | 69.84 ±0.18 | 71.68 ±0.22 | 69.38 ±0.49 | 76.71 ±0.20 |
-
-| Measure | Q4_K_S | Q4_K_M | Q5_K_S | Q5_K_M | Q6_K | Q8_0 |
-| --------------------------- | ------------ | ------------- | ------------ | ------------ | ------------- | ------------ |
-| bits/weight | 4.6672 | 4.8944 | 5.5704 | 5.7036 | 6.5633 | 8.5008 |
-| size (GiB) | 4.36 | 4.58 | 5.21 | 5.33 | 6.14 | 7.95 |
-| prompt processing t/s @ 512 | 818.55 ±9.58 | 821.81 ±21.44 | 752.52 ±0.99 | 758.69 ±7.43 | 812.01 ±10.82 | 865.09 ±8.30 |
-| text generation t/s @ 128 | 76.71 ±0.20 | 71.93 ±1.52 | 69.53 ±0.18 | 67.23 ±1.08 | 58.67 ±3.13 | 50.93 ±0.08 |
-
-| Measure | F16 |
-| --------------------------- | ------------ |
-| bits/weight | 16.0005 |
-| size (GiB) | 14.96 |
-| prompt processing t/s @ 512 | 923.49 ±0.53 |
-| text generation t/s @ 128 | 29.17 ±0.04 |
+| Quant Type | bits/weight | size (GiB) | prompt processing t/s @ 512 | text generation t/s @ 128 |
+|:----------:|------------:|-----------:|----------------------------:|--------------------------:|
+| IQ1_S | 2.0042 | 1.87 | 858.88 ±1.22 | 79.73 ±0.79 |
+| IQ1_M | 2.1460 | 2.01 | 847.99 ±0.47 | 72.92 ±0.14 |
+| IQ2_XXS | 2.3824 | 2.23 | 852.39 ±0.85 | 79.86 ±0.22 |
+| IQ2_XS | 2.5882 | 2.42 | 826.99 ±12.51 | 78.04 ±0.46 |
+| IQ2_S | 2.7403 | 2.56 | 783.55 ±13.73 | 77.30 ±2.47 |
+| IQ2_M | 2.9294 | 2.74 | 787.68 ±7.00 | 74.44 ±0.15 |
+| IQ3_XXS | 3.2548 | 3.04 | 813.88 ±6.53 | 73.95 ±0.20 |
+| IQ3_XS | 3.4977 | 3.27 | 708.71 ±1.26 | 71.67 ±0.54 |
+| IQ3_S | 3.6606 | 3.42 | 798.78 ±8.81 | 69.31 ±0.63 |
+| IQ3_M | 3.7628 | 3.52 | 768.70 ±13.73 | 70.15 ±0.33 |
+| IQ4_XS | 4.4597 | 4.17 | 771.80 ±11.38 | 77.51 ±0.20 |
+| IQ4_NL | 4.6818 | 4.38 | 806.03 ±7.07 | 76.63 ±0.28 |
+| Q2_K_S | 2.9697 | 2.78 | 798.91 ±6.40 | 90.01 ±0.12 |
+| Q2_K | 3.1593 | 2.95 | 784.45 ±7.85 | 79.85 ±0.20 |
+| Q3_K_S | 3.6429 | 3.41 | 752.17 ±7.94 | 69.84 ±0.18 |
+| Q3_K_M | 3.9960 | 3.74 | 783.44 ±9.92 | 71.68 ±0.22 |
+| Q3_K_L | 4.2979 | 4.02 | 761.17 ±7.55 | 69.38 ±0.49 |
+| Q4_K_S | 4.6672 | 4.36 | 818.55 ±9.58 | 76.71 ±0.20 |
+| Q4_K_M | 4.8944 | 4.58 | 821.81 ±21.44 | 71.93 ±1.52 |
+| Q5_K_S | 5.5704 | 5.21 | 752.52 ±0.99 | 69.53 ±0.18 |
+| Q5_K_M | 5.7036 | 5.33 | 758.69 ±7.43 | 67.23 ±1.08 |
+| Q6_K | 6.5633 | 6.14 | 812.01 ±10.82 | 58.67 ±3.13 |
+| Q8_0 | 8.5008 | 7.95 | 865.09 ±8.30 | 50.93 ±0.08 |
+| F16 | 16.0005 | 14.96 | 923.49 ±0.53 | 29.17 ±0.04 |
 
 ## Background information on llama-quantize
 
diff --git a/tools/quantize/quantize.cpp b/tools/quantize/quantize.cpp
index 881f4b3dd9..9ffff70d09 100644
--- a/tools/quantize/quantize.cpp
+++ b/tools/quantize/quantize.cpp
@@ -118,21 +118,27 @@ static bool try_parse_ftype(const std::string & ftype_str_in, llama_ftype & ftyp
 [[noreturn]]
 static void usage(const char * executable) {
-    printf("usage: %s [--help] [--allow-requantize] [--leave-output-tensor] [--pure] [--imatrix] [--include-weights]\n", executable);
-    printf("       [--exclude-weights] [--output-tensor-type] [--token-embedding-type] [--tensor-type] [--prune-layers] [--keep-split] [--override-kv]\n");
-    printf("       model-f32.gguf [model-quant.gguf] type [nthreads]\n\n");
-    printf("  --allow-requantize: Allows requantizing tensors that have already been quantized. Warning: This can severely reduce quality compared to quantizing from 16bit or 32bit\n");
-    printf("  --leave-output-tensor: Will leave output.weight un(re)quantized. 
Increases model size but may also increase quality, especially when requantizing\n"); - printf(" --pure: Disable k-quant mixtures and quantize all tensors to the same type\n"); + printf("usage: %s [--help] [--allow-requantize] [--leave-output-tensor] [--pure] [--imatrix] [--include-weights] [--exclude-weights]\n", executable); + printf(" [--target-bpw n] [--no-importance] [--keep-bpw-state] [--bpw-state filename] [--output-tensor-type] [--token-embedding-type] [--tensor-type]\n"); + printf(" [--prune-layers] [--keep-split] [--override-kv] model-f32.gguf [model-quant.gguf] type [nthreads]\n\n"); + printf(" --allow-requantize: allows requantizing tensors that have already been quantized. Warning: This can severely reduce quality compared to quantizing from 16bit or 32bit\n"); + printf(" --leave-output-tensor: will leave output.weight un(re)quantized. Increases model size but may also increase quality, especially when requantizing\n"); + printf(" --pure: disable k-quant mixtures and quantize all tensors to the same type\n"); printf(" --imatrix file_name: use data in file_name as importance matrix for quant optimizations\n"); printf(" --include-weights tensor_name: use importance matrix for this/these tensor(s)\n"); printf(" --exclude-weights tensor_name: use importance matrix for this/these tensor(s)\n"); printf(" --output-tensor-type ggml_type: use this ggml_type for the output.weight tensor\n"); printf(" --token-embedding-type ggml_type: use this ggml_type for the token embeddings tensor\n"); - printf(" --tensor-type TENSOR=TYPE: quantize this tensor to this ggml_type. example: --tensor-type attn_q=q8_0\n"); + printf(" --tensor-type TENSOR=TYPE: quantize this tensor to this ggml_type. Example: --tensor-type attn_q=q8_0\n"); printf(" Advanced option to selectively quantize tensors. May be specified multiple times.\n"); printf(" --prune-layers L0,L1,L2...comma-separated list of layer numbers to prune from the model\n"); printf(" Advanced option to remove all tensors from the given layers\n"); + printf(" --target-bpw: target bits per weight (bpw). Must be a positive number between 0.0 and 16.0\n"); + printf(" Advanced option to automatically select quantization types to achieve a total bits per weight (bpw) target\n"); + printf(" --no-importance: distribute bpw budget equitably across all tensors\n"); + printf(" Advanced option to disable assigning more bpw budget to important tensors. It may increase quality for some models\n"); + printf(" --keep-bpw-state: save the bpw computations to -.bpw_state\n"); + printf(" --bpw-state: file name to use instead of default\n"); printf(" --keep-split: will generate quantized model in the same shards as input\n"); printf(" --override-kv KEY=TYPE:VALUE\n"); printf(" Advanced option to override model metadata by key in the quantized model. 
May be specified multiple times.\n"); @@ -215,7 +221,10 @@ static int load_legacy_imatrix(const std::string & imatrix_file, std::vector & imatrix_datasets, std::unordered_map> & imatrix_data) { +static int load_imatrix(const std::string & imatrix_file, + std::vector & imatrix_datasets, + std::unordered_map> & values_data, + std::unordered_map> & activations_data) { struct ggml_context * ctx = nullptr; struct gguf_init_params meta_gguf_params = { @@ -225,7 +234,7 @@ static int load_imatrix(const std::string & imatrix_file, std::vector> sums_counts_for; + std::map> sums_counts_for; for (struct ggml_tensor * cur = ggml_get_first_tensor(ctx); cur; cur = ggml_get_next_tensor(ctx, cur)) { std::string name = cur->name; @@ -259,44 +269,55 @@ static int load_imatrix(const std::string & imatrix_file, std::vector(sums_counts_for[std::move(name)]) = cur; + } else if (string_remove_suffix(name, sums2_suffix)) { // in_sum2 - sums_counts_for[std::move(name)].first = cur; + std::get<1>(sums_counts_for[std::move(name)]) = cur; } else if (string_remove_suffix(name, counts_suffix)) { // counts - sums_counts_for[std::move(name)].second = cur; - } else { + std::get<2>(sums_counts_for[std::move(name)]) = cur; + } else { // ignore other tensors } } for (const auto & sc : sums_counts_for) { const std::string & name = sc.first; - const struct ggml_tensor * sums = sc.second.first; - const struct ggml_tensor * counts = sc.second.second; + const struct ggml_tensor * sums = std::get<0>(sc.second); + const struct ggml_tensor * sums2 = std::get<1>(sc.second); + const struct ggml_tensor * counts = std::get<2>(sc.second); - if (!sums || !counts) { + // check sums2 and counts are present, and that sums and sums2 have the same shape + if (!sums2 || !counts || (sums != nullptr && ggml_nelements(sums) != ggml_nelements(sums2))) { fprintf(stderr, "%s: mismatched sums and counts for %s\n", __func__, name.c_str()); gguf_free(ctx_gguf); ggml_free(ctx); exit(1); } - const int64_t ne0 = sums->ne[0]; - const int64_t ne1 = sums->ne[1]; + const int64_t ne0 = sums2->ne[0]; + const int64_t ne1 = sums2->ne[1]; - auto & e = imatrix_data[name]; - e.resize(ggml_nelements(sums)); + auto & activations = activations_data[name]; + auto & values = values_data[name]; + if (sums) { + activations.resize(ggml_nelements(sums)); + } + values.resize(ggml_nelements(sums2)); float max_count = 0.0f; for (int64_t j = 0; j < ne1; ++j) { const float count = ((const float *) counts->data)[j]; if (count > 0.0f) { for (int64_t i = 0; i < ne0; ++i) { - e[j*ne0 + i] = ((const float *) sums->data)[j*ne0 + i] / count; + values[j*ne0 + i] = ((const float *) sums2->data)[j*ne0 + i] / count; + if (sums) { activations[j*ne0 + i] = ((const float *) sums->data)[j*ne0 + i] / count; } } } else { // Partial imatrix data, this tensor never got any input during calibration for (int64_t i = 0; i < ne0; ++i) { - e[j*ne0 + i] = 1; + values[j*ne0 + i] = 1; + if (sums) { activations[j*ne0 + i] = 0; } } } if (count > max_count) { @@ -304,7 +325,8 @@ static int load_imatrix(const std::string & imatrix_file, std::vector & imatrix_dataset, const std::vector & included_weights, const std::vector & excluded_weights, - std::unordered_map> & imatrix_data) { + std::unordered_map> & values_data, + std::unordered_map> & activations_data) { int m_last_call = -1; if (!imatrix_file.empty()) { - m_last_call = load_imatrix(imatrix_file, imatrix_dataset, imatrix_data); + m_last_call = load_imatrix(imatrix_file, imatrix_dataset, values_data, activations_data); } - if (imatrix_data.empty()) { + if 
(values_data.empty()) { return m_last_call; } if (!excluded_weights.empty()) { for (const auto & name : excluded_weights) { - for (auto it = imatrix_data.begin(); it != imatrix_data.end();) { - auto pos = it->first.find(name); + for (auto vt = values_data.begin(); vt != values_data.end();) { + auto pos = vt->first.find(name); if (pos != std::string::npos) { - it = imatrix_data.erase(it); + vt = values_data.erase(vt); } else { - ++it; + ++vt; + } + } + for (auto at = activations_data.begin(); at != activations_data.end();) { + auto pos = at->first.find(name); + if (pos != std::string::npos) { + at = activations_data.erase(at); + } else { + ++at; } } } } if (!included_weights.empty()) { - std::unordered_map> tmp; + std::unordered_map> tmp_values; + std::unordered_map> tmp_activations; for (const auto & name : included_weights) { - for (auto & e : imatrix_data) { + for (auto & e : values_data) { auto pos = e.first.find(name); if (pos != std::string::npos) { - tmp.emplace(std::move(e)); + tmp_values.emplace(std::move(e)); + } + } + for (auto & a : activations_data) { + auto pos = a.first.find(name); + if (pos != std::string::npos) { + tmp_activations.emplace(std::move(a)); } } } - imatrix_data = std::move(tmp); - } - if (!imatrix_data.empty()) { - printf("%s: have %d importance matrix entries\n", __func__, int(imatrix_data.size())); + values_data = std::move(tmp_values); + activations_data = std::move(tmp_activations); } + return m_last_call; } @@ -441,6 +478,52 @@ static bool parse_layer_prune(const char * data, std::vector & prune_layers return true; } +static bool parse_target_bpw(const char * data, float & target_bpw) { + if (!data) { + printf("\n%s: no target bits per weight (bpw) provided\n\n", __func__); + return false; + } + + try { + target_bpw = std::stof(data); + if (target_bpw < 0.0f || target_bpw > 16.0f) { + printf("\n%s: target bits per weight (bpw) must be a positive number between 0.0 and 16.0\n\n", __func__); + return false; + } + } + catch (const std::exception & e) { + printf("\n%s: '%s' is not valid. 
Target bits per weight (bpw) must be a positive number between 0.0 and 16.0\n\n", __func__, data); + return false; + } + + return true; +} + +static const char * get_ftype(const float bpw) { + const std::map quant_bpw = { + {1.5625, "IQ1_S"}, + {1.7500, "IQ1_M"}, + {2.0625, "IQ2_XXS"}, + {2.3125, "IQ2_XS"}, + {2.5625, "IQ2_S"}, + {2.6250, "Q2_K"}, + {3.0625, "IQ3_XXS"}, + {3.4375, "Q3_K"}, + {4.2500, "IQ4_XS"}, + {4.5000, "Q4_K"}, + {5.5000, "Q5_K"}, + {6.5625, "Q6_K"}, + {8.5000, "Q8_0"}, +#ifdef GGML_USE_METAL + {16.0000, "F16"} +#else + {16.0000, "BF16"} +#endif + }; + + return quant_bpw.lower_bound(bpw)->second; +} + int main(int argc, char ** argv) { if (argc < 3) { usage(argv[0]); @@ -454,6 +537,7 @@ int main(int argc, char ** argv) { std::vector kv_overrides; std::vector tensor_types; std::vector prune_layers; + float target_bpw = -1.0f; for (; arg_idx < argc && strncmp(argv[arg_idx], "--", 2) == 0; arg_idx++) { if (strcmp(argv[arg_idx], "--leave-output-tensor") == 0) { @@ -480,6 +564,20 @@ int main(int argc, char ** argv) { if (arg_idx == argc-1 || !parse_tensor_type(argv[++arg_idx], tensor_types)) { usage(argv[0]); } + } else if (strcmp(argv[arg_idx], "--target-bpw") == 0) { + if (arg_idx == argc-1 || !parse_target_bpw(argv[++arg_idx], target_bpw)) { + usage(argv[0]); + } + } else if (strcmp(argv[arg_idx], "--no-importance") == 0) { + params.no_importance = true; + } else if (strcmp(argv[arg_idx], "--keep-bpw-state") == 0) { + params.keep_bpw_state = true; + } else if (strcmp(argv[arg_idx], "--bpw-state") == 0) { + if (arg_idx < argc-1) { + params.bpw_state = argv[++arg_idx]; + } else { + usage(argv[0]); + } } else if (strcmp(argv[arg_idx], "--prune-layers") == 0) { if (arg_idx == argc-1 || !parse_layer_prune(argv[++arg_idx], prune_layers)) { usage(argv[0]); @@ -526,10 +624,11 @@ int main(int argc, char ** argv) { } std::vector imatrix_datasets; - std::unordered_map> imatrix_data; - int m_last_call = prepare_imatrix(imatrix_file, imatrix_datasets, included_weights, excluded_weights, imatrix_data); - if (!imatrix_data.empty()) { - params.imatrix = &imatrix_data; + std::unordered_map> values_data; + std::unordered_map> activations_data; + int m_last_call = prepare_imatrix(imatrix_file, imatrix_datasets, included_weights, excluded_weights, values_data, activations_data); + if (!values_data.empty()) { + params.imatrix = &values_data; { llama_model_kv_override kvo; std::strcpy(kvo.key, LLM_KV_QUANTIZE_IMATRIX_FILE); @@ -552,7 +651,7 @@ int main(int argc, char ** argv) { llama_model_kv_override kvo; std::strcpy(kvo.key, LLM_KV_QUANTIZE_IMATRIX_N_ENTRIES); kvo.tag = LLAMA_KV_OVERRIDE_TYPE_INT; - kvo.val_i64 = imatrix_data.size(); + kvo.val_i64 = values_data.size(); kv_overrides.emplace_back(std::move(kvo)); } @@ -564,6 +663,9 @@ int main(int argc, char ** argv) { kv_overrides.emplace_back(std::move(kvo)); } } + if (!activations_data.empty()) { + params.activations = &activations_data; + } if (!kv_overrides.empty()) { kv_overrides.emplace_back(); kv_overrides.back().key[0] = 0; @@ -575,6 +677,9 @@ int main(int argc, char ** argv) { if (!prune_layers.empty()) { params.prune_layers = &prune_layers; } + if (target_bpw != -1.0f) { + params.target_bpw = target_bpw; + } llama_backend_init(); @@ -585,6 +690,7 @@ int main(int argc, char ** argv) { std::string ftype_str; std::string suffix = ".gguf"; + std::vector tmp_argv(argv, argv + argc); if (try_parse_ftype(argv[arg_idx], params.ftype, ftype_str)) { std::string fpath; const size_t pos = fname_inp.find_last_of("/\\"); @@ -608,7 +714,15 @@ int 
main(int argc, char ** argv) {
     }
     arg_idx++;
 
-    if (argc <= arg_idx) {
+    // if --target-bpw is set, select the quantization type automatically unless the user specified both a type and nthreads
+    if (argc - arg_idx <= 1 && params.target_bpw != -1.0f) {
+        auto * ftype = const_cast<char *>(get_ftype(params.target_bpw));
+        if (argc == arg_idx) { tmp_argv.push_back(ftype); }
+        else { tmp_argv.insert(tmp_argv.end() - 1, ftype); }
+        tmp_argv.push_back(nullptr);
+        argv = const_cast<char **>(tmp_argv.data());
+        argc++;
+    } else if (argc <= arg_idx) {
         fprintf(stderr, "%s: missing ftype\n", __func__);
         return 1;
     }
@@ -637,7 +751,7 @@ int main(int argc, char ** argv) {
         params.ftype == LLAMA_FTYPE_MOSTLY_IQ2_S ||
         params.ftype == LLAMA_FTYPE_MOSTLY_Q2_K_S ||
         params.ftype == LLAMA_FTYPE_MOSTLY_IQ1_S ||
-        params.ftype == LLAMA_FTYPE_MOSTLY_IQ1_M) && imatrix_data.empty()) {
+        params.ftype == LLAMA_FTYPE_MOSTLY_IQ1_M) && values_data.empty()) {
         fprintf(stderr, "\n==========================================================================================================\n");
         fprintf(stderr, "Please do not use IQ1_S, IQ1_M, IQ2_S, IQ2_XXS, IQ2_XS or Q2_K_S quantization without an importance matrix\n");
         fprintf(stderr, "==========================================================================================================\n\n\n");