diff --git a/include/llama.h b/include/llama.h
index 01c5b67c75..3a5bda32ea 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -357,6 +357,7 @@ extern "C" {
         void * tensor_types;   // pointer to vector containing tensor types
         void * prune_layers;   // pointer to vector containing layer indices to prune
         float target_bpw;      // target bits per weight (bpw)
+        bool precise_lambda;   // use the precise_lambda calculation - slow to compute but very accurate
     } llama_model_quantize_params;
 
     typedef struct llama_logit_bias {
diff --git a/src/llama-quant.cpp b/src/llama-quant.cpp
index 662760fbe9..98fc11d840 100644
--- a/src/llama-quant.cpp
+++ b/src/llama-quant.cpp
@@ -722,7 +722,8 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
         const float * values_sample,
         const float * activations_sample,
         std::vector<uint8_t> & quantized_buffer,
-        std::vector<float> & dequantized_buffer) -> double
+        std::vector<float> & dequantized_buffer,
+        float bias_lambda) -> double
     {
         const int64_t n_per_row = t->ne[0];
         const int64_t nrows = t->ne[1];
@@ -878,10 +879,6 @@
             }
         }
-        // bias_lambda adjusts the trade-off between systematic bias (introduced by block-wise scaling) and MSE
-        // larger value favours quantisation types that produce smaller bias even if the MSE is slightly larger
-        constexpr float bias_lambda = 1.5f;
-
         constexpr double epsilon = 1e-12;
         double err_num = weighted_mse;
         if (activations && bias_lambda != 0.0f) {
             const double proj = bias_num * bias_num / (bias_denom + epsilon);
@@ -1163,6 +1160,15 @@
         std::sort(compatible_candidates.begin(), compatible_candidates.end());
         compatible_candidates.erase(std::unique(compatible_candidates.begin(), compatible_candidates.end()), compatible_candidates.end());
 
+        // Compute adaptive bias_lambda for this tensor
+        float bias_lambda = 0.0f;
+        {
+            const float * values = values_sample.empty() ? nullptr : values_sample.data();
+            const float * activations = activations_sample.empty() ? nullptr : activations_sample.data();
+            bias_lambda = params->precise_lambda ? precise_lambda(t, f32_sample, sample_rows_per_slice, values, activations, compatible_candidates) :
+                                                   fast_lambda(values, activations, n_per_row);
+        }
+
         // Now evaluate candidates
         std::vector<candidate_types> eval_candidates(compatible_candidates.size());
         const float * values = values_sample.empty() ? nullptr : values_sample.data();
@@ -1186,7 +1192,7 @@
             const ggml_type tt = compatible_candidates[i];
             const auto bpw = (float)tensor_bpw(t, tt);
             const size_t bytes = tensor_bytes(t, tt);
-            const auto err = (float)estimate_error(t, tt, f32_sample, sample_rows_per_slice, values, activations, tl_quantized_buffer, tl_dequantised_buffer);
+            const auto err = (float)estimate_error(t, tt, f32_sample, sample_rows_per_slice, values, activations, tl_quantized_buffer, tl_dequantised_buffer, bias_lambda);
             eval_candidates[i] = candidate_types{ tt, bpw, bytes, err };
         }
     });
@@ -1301,7 +1307,6 @@
     };
 
     auto recompute_best_upgrade = [&]() -> upgrade {
-        const double eps = 1e-12;
         upgrade best{ -1, -1, 0.0, 0, -1.0 };
         for (int i = 0; i < (int) all.size(); ++i) {
             const auto & ti = all[i];
@@ -1653,10 +1658,11 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::string & fname_out, const llama_model_quantize_params * params) {
     if (params->target_bpw != -1.0f && !params->only_copy) {
         if (params->imatrix) {
             if (params->activations) {
-                LLAMA_LOG_INFO("%s: imatrix with activations provided, target bpw quantization will be more accurate\n", __func__);
+                LLAMA_LOG_INFO("%s: imatrix with activations provided, target bpw quantization will be more accurate - ", __func__);
             } else {
-                LLAMA_LOG_WARN("%s: imatrix without activations provided, target bpw quantization will be less accurate\n", __func__);
+                LLAMA_LOG_WARN("%s: imatrix without activations provided, target bpw quantization will be less accurate - ", __func__);
             }
+            LLAMA_LOG_INFO("using %s\n", params->precise_lambda ? "precise lambda (slow)" : "fast lambda");
             LLAMA_LOG_INFO("%s: computing tensor quantization mix to achieve %.4f bpw\n", __func__, params->target_bpw);
             bpw_overrides = target_bpw_type(ml, read_data, model, tensors, mapped, values_data, activations_data, params, nthread);
         } else {
@@ -1966,7 +1972,8 @@ llama_model_quantize_params llama_model_quantize_default_params() {
         /*.kv_overrides   =*/ nullptr,
         /*.tensor_type    =*/ nullptr,
         /*.prune_layers   =*/ nullptr,
-        /*.target_bpw     =*/ -1.0f
+        /*.target_bpw     =*/ -1.0f,
+        /*.precise_lambda =*/ false
     };
 
     return result;
diff --git a/tools/quantize/quantize.cpp b/tools/quantize/quantize.cpp
index 77fa6b90ce..0c9460513c 100644
--- a/tools/quantize/quantize.cpp
+++ b/tools/quantize/quantize.cpp
@@ -132,7 +132,9 @@ static void usage(const char * executable) {
     printf("      Advanced option to selectively quantize tensors. May be specified multiple times.\n");
     printf("  --prune-layers L0,L1,L2...comma-separated list of layer numbers to prune from the model\n");
     printf("      Advanced option to remove all tensors from the given layers\n");
-    printf("  --target-bpw: target bits per weight (bpw). Must be a positive number between 0.0 and 16.0 \n");
+    printf("  --target-bpw: target bits per weight (bpw). Must be a positive number between 0.0 and 16.0\n");
+    printf("      Advanced option to automatically select quantization types to achieve a total bits per weight (bpw) target\n");
+    printf("  --precise-lambda: given a target bpw, use a high-precision error computation at the expense of longer processing times\n");
     printf("  --keep-split: will generate quantized model in the same shards as input\n");
     printf("  --override-kv KEY=TYPE:VALUE\n");
     printf("      Advanced option to override model metadata by key in the quantized model. May be specified multiple times.\n");
@@ -538,6 +540,8 @@ int main(int argc, char ** argv) {
             if (arg_idx == argc-1 || !parse_target_bpw(argv[++arg_idx], target_bpw)) {
                 usage(argv[0]);
             }
+        } else if (strcmp(argv[arg_idx], "--precise-lambda") == 0) {
+            params.precise_lambda = true;
         } else if (strcmp(argv[arg_idx], "--prune-layers") == 0) {
            if (arg_idx == argc-1 || !parse_layer_prune(argv[++arg_idx], prune_layers)) {
                 usage(argv[0]);
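
For reference, a minimal sketch of driving the new option through the public API once this patch is applied. The file names and the 4.5 bpw target below are placeholders, and in practice an imatrix (params.imatrix, optionally with params.activations) would also be supplied, as the log messages above indicate:

    #include "llama.h"

    int main(void) {
        llama_backend_init();

        llama_model_quantize_params params = llama_model_quantize_default_params();
        params.target_bpw     = 4.5f; // request an automatic per-tensor type mix at ~4.5 bits per weight
        params.precise_lambda = true; // opt into the slower, more accurate lambda computation

        // returns 0 on success
        const uint32_t rc = llama_model_quantize("model-f32.gguf", "model-bpw.gguf", &params);

        llama_backend_free();
        return rc == 0 ? 0 : 1;
    }

The CLI equivalent is the new --precise-lambda flag that quantize.cpp wires to the same field above.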