diff --git a/src/llama-quant.cpp b/src/llama-quant.cpp
index 1e24303c52..b0b3be76ca 100644
--- a/src/llama-quant.cpp
+++ b/src/llama-quant.cpp
@@ -1314,6 +1314,12 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
         }
     }
 
+    std::unordered_map<std::string, ggml_type> bpw_overrides = {};
+    if (params->target_bpw != -1.0f) {
+        LLAMA_LOG_INFO("%s: computing tensor quantization mix to achieve %.3f bpw at lowest ppl - this operation may take some time\n", __func__, params->target_bpw);
+        bpw_overrides = target_bpw_type(ml, read_data, model, tensors, mapped, values_data, activations_data, params->target_bpw, nthread);
+    }
+
     int cur_split = -1;
     std::ofstream fout;
     auto close_ofstream = [&]() {
@@ -1430,6 +1436,9 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
         if (!params->pure && ggml_is_quantized(default_type)) {
             int fallback = qs.n_fallback;
             new_type = llama_tensor_get_type(qs, new_type, tensor, ftype);
+            // get bpw override
+            const auto override = bpw_overrides.find(name);
+            if (override != bpw_overrides.end()) { new_type = override->second; }
             // unless the user specifies a type, and the tensor geometry will not require fallback quantisation
             if (params->tensor_types && qs.n_fallback - fallback == 0) {
                 const std::vector<tensor_quantization> & tensor_types = *static_cast<const std::vector<tensor_quantization> *>(params->tensor_types);
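
For reference, a minimal self-contained sketch of the override pattern these hunks add: a map from tensor name to quantization type is computed up front, then consulted per tensor so that an override, when present, replaces the type picked by the default heuristic. The trimmed ggml_type enum and the pick_default_type() helper below are illustrative stand-ins (the real enum lives in ggml.h and the real heuristic is llama_tensor_get_type()), not the patch's implementation.

#include <iostream>
#include <string>
#include <unordered_map>

// Trimmed stand-in for the real ggml_type enum in ggml.h.
enum ggml_type { GGML_TYPE_Q4_K, GGML_TYPE_Q6_K };

// Hypothetical stand-in for llama_tensor_get_type(): the default heuristic.
static ggml_type pick_default_type(const std::string & /*name*/) {
    return GGML_TYPE_Q4_K;
}

int main() {
    // In the patch, this map is filled by target_bpw_type() before the
    // per-tensor quantization loop runs.
    const std::unordered_map<std::string, ggml_type> bpw_overrides = {
        { "blk.0.attn_v.weight", GGML_TYPE_Q6_K },
    };

    for (const std::string name : { "blk.0.attn_v.weight", "blk.0.ffn_up.weight" }) {
        ggml_type new_type = pick_default_type(name);

        // Same lookup as the second hunk: an override, if present,
        // wins over the heuristic's choice.
        const auto it = bpw_overrides.find(name);
        if (it != bpw_overrides.end()) { new_type = it->second; }

        std::cout << name << " -> " << static_cast<int>(new_type) << "\n";
    }
    return 0;
}

Placing the lookup immediately after llama_tensor_get_type() means the bpw mix wins over the built-in shape/layer heuristic, while the user's explicit --tensor-type overrides (handled just below in the existing code) can still take precedence afterwards.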