Implement bpw_overrides call

This commit is contained in:
Ed Addario 2025-08-19 11:07:03 +01:00
parent 92f49ab399
commit 1187f6aa9e
1 changed file with 9 additions and 0 deletions

@@ -1314,6 +1314,12 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
}
}
std::unordered_map<std::string, ggml_type> bpw_overrides = {};
if (params->target_bpw != -1.0f) {
LLAMA_LOG_INFO("%s: computing tensor quantization mix to achieve %.3f bpw at lowest ppl - this operation may take some time\n", __func__, params->target_bpw);
bpw_overrides = target_bpw_type(ml, read_data, model, tensors, mapped, values_data, activations_data, params->target_bpw, nthread);
}
int cur_split = -1;
std::ofstream fout;
auto close_ofstream = [&]() {
@@ -1430,6 +1436,9 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
if (!params->pure && ggml_is_quantized(default_type)) {
int fallback = qs.n_fallback;
new_type = llama_tensor_get_type(qs, new_type, tensor, ftype);
// get bpw override
const auto override = bpw_overrides.find(name);
if (override != bpw_overrides.end()) { new_type = override->second; }
// unless the user specifies a type, and the tensor geometry will not require fallback quantisation
if (params->tensor_types && qs.n_fallback - fallback == 0) {
const std::vector<tensor_quantization> & tensor_types = *static_cast<const std::vector<tensor_quantization> *>(params->tensor_types);
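
The pattern the diff adds is a per-tensor type override: a map keyed by tensor name is built once (via target_bpw_type) and then consulted inside the quantization loop, replacing the type chosen by the default heuristic whenever an entry exists. The following is a minimal, self-contained sketch of that lookup-and-apply step, not the llama.cpp implementation; the ggml_type values and tensor name used here are placeholders for illustration.

#include <cstdio>
#include <string>
#include <unordered_map>

// Stand-in for ggml's quantization type enum (illustrative values only).
enum ggml_type { GGML_TYPE_Q4_K, GGML_TYPE_Q5_K, GGML_TYPE_Q6_K };

static const char * type_name(ggml_type t) {
    switch (t) {
        case GGML_TYPE_Q4_K: return "Q4_K";
        case GGML_TYPE_Q5_K: return "Q5_K";
        case GGML_TYPE_Q6_K: return "Q6_K";
    }
    return "?";
}

int main() {
    // Pretend these overrides came from a target-bpw search.
    std::unordered_map<std::string, ggml_type> bpw_overrides = {
        {"blk.0.attn_v.weight", GGML_TYPE_Q6_K},
    };

    const std::string name = "blk.0.attn_v.weight";
    ggml_type new_type = GGML_TYPE_Q4_K;       // type picked by the default heuristic

    // Per-tensor override wins if one was computed for this tensor.
    const auto it = bpw_overrides.find(name);
    if (it != bpw_overrides.end()) { new_type = it->second; }

    std::printf("%s -> %s\n", name.c_str(), type_name(new_type));
    return 0;
}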