diff --git a/include/llama.h b/include/llama.h index ce04011e19..517ef5e0fb 100644 --- a/include/llama.h +++ b/include/llama.h @@ -368,6 +368,7 @@ extern "C" { float target_bpw; // target bits per weight (bpw) bool keep_bpw_state; // keep bpw state file void * bpw_state; // pointer to bpw state file + void * statistics; // pointer to statistics data } llama_model_quantize_params; typedef struct llama_logit_bias { diff --git a/src/llama-quant.cpp b/src/llama-quant.cpp index fdce1f4285..a8153494f9 100644 --- a/src/llama-quant.cpp +++ b/src/llama-quant.cpp @@ -631,6 +631,7 @@ static std::unordered_map target_bpw_type( const std::map & mapped, const std::unordered_map> * values_data, const std::unordered_map> * activations_data, + const std::unordered_map> * statistics_data, const llama_model_quantize_params * params, int nthread ) { @@ -1815,6 +1816,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: } const std::unordered_map> * values_data = nullptr; const std::unordered_map> * activations_data = nullptr; + const std::unordered_map> * statistics_data = nullptr; if (params->imatrix) { values_data = static_cast>*>(params->imatrix); if (values_data) { @@ -1845,6 +1847,12 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: } } } + if (params->statistics) { + statistics_data = static_cast>*>(params->statistics); + if (statistics_data) { + LLAMA_LOG_INFO(" and %d statistics",int(statistics_data->size())); + } + } LLAMA_LOG_INFO("\n"); gguf_context_ptr ctx_out { gguf_init_empty() }; @@ -1999,15 +2007,18 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: std::unordered_map bpw_overrides = {}; if (params->target_bpw != -1.0f && !params->only_copy) { if (params->imatrix) { - if (params->activations) { - LLAMA_LOG_INFO("%s: imatrix with activations provided, target bpw quantization will be more accurate\n",__func__); - } else { - LLAMA_LOG_WARN("%s: imatrix without activations provided, target bpw quantization will be less accurate\n", __func__); - } + const char* base_msg = params->activations + ? (params->statistics + ? "imatrix with activations and statistics provided, process will be more accurate\n" + : "imatrix with activations provided, process will be accurate\n") + : "imatrix without activations provided, process will be less accurate\n"; + if (params->activations) { LLAMA_LOG_INFO("%s: %s", __func__, base_msg); } + else { LLAMA_LOG_WARN("%s: %s", __func__, base_msg); } + LLAMA_LOG_INFO("%s: computing tensor quantization mix to achieve %.4f bpw\n", __func__, params->target_bpw); - bpw_overrides = target_bpw_type(ml, model, tensors, mapped, values_data, activations_data, params, nthread); + bpw_overrides = target_bpw_type(ml, model, tensors, mapped, values_data, activations_data, statistics_data, params, nthread); } else { - LLAMA_LOG_WARN("%s: no imatrix provided, target bpw will not apply\n", __func__); + LLAMA_LOG_WARN("%s: --target-bpw requires an imatrix but none was provided, option will be ignored\n", __func__); } } @@ -2269,7 +2280,8 @@ llama_model_quantize_params llama_model_quantize_default_params() { /*.prune_layers =*/ nullptr, /*.target_bpw =*/ -1.0f, /*.keep_bpw_state =*/ false, - /*.bpw_state =*/ nullptr + /*.bpw_state =*/ nullptr, + /*.statistics =*/ nullptr }; return result; diff --git a/tools/quantize/quantize.cpp b/tools/quantize/quantize.cpp index f994999e59..0b2b05b60a 100644 --- a/tools/quantize/quantize.cpp +++ b/tools/quantize/quantize.cpp @@ -221,7 +221,8 @@ static int load_legacy_imatrix(const std::string & imatrix_file, std::vector & imatrix_datasets, std::unordered_map> & values_data, - std::unordered_map> & activations_data) { + std::unordered_map> & activations_data, + std::unordered_map> & statistics_data) { struct ggml_context * ctx = nullptr; struct gguf_init_params meta_gguf_params = { @@ -256,24 +257,28 @@ static int load_imatrix(const std::string & imatrix_file, const std::string sums_suffix{ ".in_sum" }; const std::string sums2_suffix{ ".in_sum2" }; const std::string counts_suffix{ ".counts" }; + const std::string stats_suffix{ ".stats" }; // Using an ordered map to get a deterministic iteration order. - std::map> sums_counts_for; + std::map> sums_counts_for; for (struct ggml_tensor * cur = ggml_get_first_tensor(ctx); cur; cur = ggml_get_next_tensor(ctx, cur)) { std::string name = cur->name; if (name.empty()) { continue; } - if (string_remove_suffix(name, sums2_suffix)) { - // in_sum2 + if (string_remove_suffix(name, sums_suffix)) { + // in_sum std::get<0>(sums_counts_for[std::move(name)]) = cur; + } else if (string_remove_suffix(name, sums2_suffix)) { + // in_sum2 + std::get<1>(sums_counts_for[std::move(name)]) = cur; } else if (string_remove_suffix(name, counts_suffix)) { // counts - std::get<1>(sums_counts_for[std::move(name)]) = cur; - } else if (string_remove_suffix(name, sums_suffix)) { - // in_sum std::get<2>(sums_counts_for[std::move(name)]) = cur; + } else if (string_remove_suffix(name, stats_suffix)) { + // stats + std::get<3>(sums_counts_for[std::move(name)]) = cur; } else { // ignore other tensors @@ -282,11 +287,12 @@ static int load_imatrix(const std::string & imatrix_file, for (const auto & sc : sums_counts_for) { const std::string & name = sc.first; - const struct ggml_tensor * sums = std::get<2>(sc.second); - const struct ggml_tensor * sums2 = std::get<0>(sc.second); - const struct ggml_tensor * counts = std::get<1>(sc.second); + const struct ggml_tensor * sums = std::get<0>(sc.second); + const struct ggml_tensor * sums2 = std::get<1>(sc.second); + const struct ggml_tensor * counts = std::get<2>(sc.second); + const struct ggml_tensor * stats = std::get<3>(sc.second); - // check that sums, sums2 and counts have the same shape + // check sums2 and counts are present, and that sums and sums2 have the same shape if (!sums2 || !counts || (sums != nullptr && ggml_nelements(sums) != ggml_nelements(sums2))) { fprintf(stderr, "%s: mismatched sums and counts for %s\n", __func__, name.c_str()); gguf_free(ctx_gguf); @@ -302,6 +308,19 @@ static int load_imatrix(const std::string & imatrix_file, if (sums) { activations.resize(ggml_nelements(sums)); } + if (stats) { + auto & statistics = statistics_data[name]; + statistics.resize(ggml_nelements(stats)); + if (stats->type == GGML_TYPE_F32) { + std::memcpy(statistics.data(), stats->data, ggml_nelements(stats) * sizeof(float)); + } else { + fprintf(stderr, "%s: unsupported .stats type '%s' for '%s' - ignoring entry\n", + __func__, ggml_type_name(stats->type), name.c_str()); + statistics.clear(); + statistics_data.erase(name); + } + + } values.resize(ggml_nelements(sums2)); float max_count = 0.0f; for (int64_t j = 0; j < ne1; ++j) { @@ -354,10 +373,11 @@ static int prepare_imatrix(const std::string & imatrix_file, const std::vector & included_weights, const std::vector & excluded_weights, std::unordered_map> & values_data, - std::unordered_map> & activations_data) { + std::unordered_map> & activations_data, + std::unordered_map> & statistics_data) { int m_last_call = -1; if (!imatrix_file.empty()) { - m_last_call = load_imatrix(imatrix_file, imatrix_dataset, values_data, activations_data); + m_last_call = load_imatrix(imatrix_file, imatrix_dataset, values_data, activations_data, statistics_data); } if (values_data.empty()) { return m_last_call; @@ -380,11 +400,20 @@ static int prepare_imatrix(const std::string & imatrix_file, ++at; } } + for (auto st = statistics_data.begin(); st != statistics_data.end();) { + auto pos = st->first.find(name); + if (pos != std::string::npos) { + st = activations_data.erase(st); + } else { + ++st; + } + } } } if (!included_weights.empty()) { std::unordered_map> tmp_values; std::unordered_map> tmp_activations; + std::unordered_map> tmp_statistics; for (const auto & name : included_weights) { for (auto & e : values_data) { auto pos = e.first.find(name); @@ -398,9 +427,16 @@ static int prepare_imatrix(const std::string & imatrix_file, tmp_activations.emplace(std::move(a)); } } + for (auto & s : statistics_data) { + auto pos = s.first.find(name); + if (pos != std::string::npos) { + tmp_statistics.emplace(std::move(s)); + } + } } values_data = std::move(tmp_values); activations_data = std::move(tmp_activations); + statistics_data = std::move(tmp_statistics); } return m_last_call; @@ -617,7 +653,8 @@ int main(int argc, char ** argv) { std::vector imatrix_datasets; std::unordered_map> values_data; std::unordered_map> activations_data; - int m_last_call = prepare_imatrix(imatrix_file, imatrix_datasets, included_weights, excluded_weights, values_data, activations_data); + std::unordered_map> statistics_data; + int m_last_call = prepare_imatrix(imatrix_file, imatrix_datasets, included_weights, excluded_weights, values_data, activations_data, statistics_data); if (!values_data.empty()) { params.imatrix = &values_data; { @@ -657,6 +694,9 @@ int main(int argc, char ** argv) { if (!activations_data.empty()) { params.activations = &activations_data; } + if (!statistics_data.empty()) { + params.statistics = &statistics_data; + } if (!kv_overrides.empty()) { kv_overrides.emplace_back(); kv_overrides.back().key[0] = 0;