Implement bpw_overrides call
This commit is contained in:
parent
92f49ab399
commit
1187f6aa9e
|
|
@ -1314,6 +1314,12 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::unordered_map<std::string, ggml_type> bpw_overrides = {};
|
||||||
|
if (params->target_bpw != -1.0f) {
|
||||||
|
LLAMA_LOG_INFO("%s: computing tensor quantization mix to achieve %.3f bpw at lowest ppl - this operation may take some time\n", __func__, params->target_bpw);
|
||||||
|
bpw_overrides = target_bpw_type(ml, read_data, model, tensors, mapped, values_data, activations_data, params->target_bpw, nthread);
|
||||||
|
}
|
||||||
|
|
||||||
int cur_split = -1;
|
int cur_split = -1;
|
||||||
std::ofstream fout;
|
std::ofstream fout;
|
||||||
auto close_ofstream = [&]() {
|
auto close_ofstream = [&]() {
|
||||||
|
|
@ -1430,6 +1436,9 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
|
||||||
if (!params->pure && ggml_is_quantized(default_type)) {
|
if (!params->pure && ggml_is_quantized(default_type)) {
|
||||||
int fallback = qs.n_fallback;
|
int fallback = qs.n_fallback;
|
||||||
new_type = llama_tensor_get_type(qs, new_type, tensor, ftype);
|
new_type = llama_tensor_get_type(qs, new_type, tensor, ftype);
|
||||||
|
// get bpw override
|
||||||
|
const auto override = bpw_overrides.find(name);
|
||||||
|
if (override != bpw_overrides.end()) { new_type = override->second; }
|
||||||
// unless the user specifies a type, and the tensor geometry will not require fallback quantisation
|
// unless the user specifies a type, and the tensor geometry will not require fallback quantisation
|
||||||
if (params->tensor_types && qs.n_fallback - fallback == 0) {
|
if (params->tensor_types && qs.n_fallback - fallback == 0) {
|
||||||
const std::vector<tensor_quantization> & tensor_types = *static_cast<const std::vector<tensor_quantization> *>(params->tensor_types);
|
const std::vector<tensor_quantization> & tensor_types = *static_cast<const std::vector<tensor_quantization> *>(params->tensor_types);
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue