diff --git a/tools/quantize/quantize.cpp b/tools/quantize/quantize.cpp index c2a4767fc9..2c45adab75 100644 --- a/tools/quantize/quantize.cpp +++ b/tools/quantize/quantize.cpp @@ -215,7 +215,10 @@ static int load_legacy_imatrix(const std::string & imatrix_file, std::vector & imatrix_datasets, std::unordered_map> & imatrix_data) { +static int load_imatrix(const std::string & imatrix_file, + std::vector & imatrix_datasets, + std::unordered_map> & values_data, + std::unordered_map> & activations_data) { struct ggml_context * ctx = nullptr; struct gguf_init_params meta_gguf_params = { @@ -225,7 +228,7 @@ static int load_imatrix(const std::string & imatrix_file, std::vector & imatrix_dataset, const std::vector & included_weights, const std::vector & excluded_weights, - std::unordered_map> & imatrix_data) { + std::unordered_map> & values_data, + std::unordered_map> & activations_data) { int m_last_call = -1; if (!imatrix_file.empty()) { - m_last_call = load_imatrix(imatrix_file, imatrix_dataset, imatrix_data); + m_last_call = load_imatrix(imatrix_file, imatrix_dataset, values_data, activations_data); } - if (imatrix_data.empty()) { + if (values_data.empty()) { return m_last_call; } if (!excluded_weights.empty()) { for (const auto & name : excluded_weights) { - for (auto it = imatrix_data.begin(); it != imatrix_data.end();) { + for (auto it = values_data.begin(); it != values_data.end();) { auto pos = it->first.find(name); if (pos != std::string::npos) { - it = imatrix_data.erase(it); + it = values_data.erase(it); } else { ++it; } } + for (auto at = activations_data.begin(); at != activations_data.end();) { + auto pos = at->first.find(name); + if (pos != std::string::npos) { + at = activations_data.erase(at); + } else { + ++at; + } + } } } if (!included_weights.empty()) { - std::unordered_map> tmp; + std::unordered_map> tmp_values; + std::unordered_map> tmp_activations; for (const auto & name : included_weights) { - for (auto & e : imatrix_data) { + for (auto & e : values_data) { auto pos = e.first.find(name); if (pos != std::string::npos) { - tmp.emplace(std::move(e)); + tmp_values.emplace(std::move(e)); + } + } + for (auto & a : activations_data) { + auto pos = a.first.find(name); + if (pos != std::string::npos) { + tmp_activations.emplace(std::move(a)); } } } - imatrix_data = std::move(tmp); + values_data = std::move(tmp_values); + activations_data = std::move(tmp_activations); } - if (!imatrix_data.empty()) { - printf("%s: have %d importance matrix entries\n", __func__, int(imatrix_data.size())); + if (!values_data.empty()) { + printf("%s: have %d importance matrix value entries\n", __func__, int(values_data.size())); + } + if (!activations_data.empty()) { + printf("%s: have %d importance matrix activation entries\n", __func__, int(activations_data.size())); } return m_last_call; }