diff --git a/tools/imatrix/imatrix.cpp b/tools/imatrix/imatrix.cpp index 634739082a..2f16f3489c 100644 --- a/tools/imatrix/imatrix.cpp +++ b/tools/imatrix/imatrix.cpp @@ -160,7 +160,7 @@ static std::vector compute_tensor_averages(const Stats & tstats) { return vec; } -static int compute_vector_statistics(std::vector & tstats, const std::string & name, const Stats & e) { +static bool compute_vector_statistics(std::vector & tstats, const std::string & name, const Stats & e) { if (e.values.size() % e.counts.size() != 0) { LOG_ERR("%s: activation size mismatch for tensor %s (%zu vs %zu)\n", __func__, name.c_str(), e.counts.size(), e.values.size()); return -1;; @@ -172,7 +172,6 @@ static int compute_vector_statistics(std::vector & tstats, co const int n_mat = e.counts.size(); const int row_size = e.values.size() / n_mat; - const int calc_mode = e.activations.empty() ? 2 : 1; std::vector activations; @@ -203,7 +202,15 @@ static int compute_vector_statistics(std::vector & tstats, co const float std_deviation = std::sqrt(std::max(0.0f, variance)); float entropy = 0; - if (calc_mode == 1) { + if (e.activations.empty()) { + if (sum > 0) { + for (const auto act : activations) { + if (const float p = act / sum; p > 0) { + entropy -= p * std::log2(p); + } + } + } + } else { float div = 0.0; std::vector weights(activations.size()); for (size_t i = 0; i < activations.size(); ++i) { @@ -218,14 +225,6 @@ static int compute_vector_statistics(std::vector & tstats, co if (p > 0.0) entropy -= p * std::log2(p); } } - } else { - if (sum > 0) { - for (const auto act : activations) { - if (const float p = act / sum; p > 0) { - entropy -= p * std::log2(p); - } - } - } } int z_score = 0; @@ -247,7 +246,7 @@ static int compute_vector_statistics(std::vector & tstats, co ts.entropy = entropy; ts.zd_score = static_cast(z_score) / ts.elements; - return calc_mode; + return e.activations.empty(); } static void compute_tensor_statistics(std::vector & tstats) { @@ -1257,7 +1256,7 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params, c static bool show_statistics(const common_params & params) { std::vector ts; - int tensor_calc_mode = 0; + bool legacy_mode = false; if (params.in_files.empty() || params.in_files.size() > 1) { LOG_ERR("\nError: a single imatrix file is required to compute tensor statistics\n\n"); @@ -1265,7 +1264,7 @@ static bool show_statistics(const common_params & params) { } if (g_collector.load_imatrix(params.in_files[0].c_str())) { for (const auto & [name, stats] :g_collector.get_mstats()) { - tensor_calc_mode =compute_vector_statistics(ts, name, stats); + legacy_mode = compute_vector_statistics(ts, name, stats); } } else { LOG_ERR("\nError: %s is not a valid imatrix file\n\n", params.in_files[0].c_str()); @@ -1300,7 +1299,7 @@ static bool show_statistics(const common_params & params) { LOG_INF("\n%6s\t%18s\t%13s\t%8s\t%8s\t%7s\t%15s\t%13s\t%12s\t%s\t%5s\t%10s\n", "Layer", "Tensor", - tensor_calc_mode == 1 ? "L₂ Norm" : "Σ(Act²)", + legacy_mode ? "Σ(Act²)" : "L₂ Norm", "Min", "Max", "μ", @@ -1327,7 +1326,7 @@ static bool show_statistics(const common_params & params) { LOG_INF("%5s\t%-20s\t%11.2f\t%10.4f\t%10.4f\t%8.2f\t%8.2f\t%7d\t%12.4f\t%7.2f%%\t%6.2f%%\t%10.4f\n", layer.c_str(), name.c_str(), - tensor_calc_mode == 1 ? tstat.l2_norm : tstat.sum_values, + legacy_mode == 1 ? tstat.sum_values : tstat.l2_norm, tstat.min_values, tstat.max_values, tstat.mean_values, @@ -1361,7 +1360,7 @@ static bool show_statistics(const common_params & params) { LOG_INF("\nComputing aggregated statistics per layer (%ld layers)\n", layers); LOG_INF("\n%6s\t%13s\t%5s\t%10s\n", "Layer", - tensor_calc_mode == 1 ? "L₂ Norm" : "Σ(Act²)", + legacy_mode ? "Σ(Act²)" : "L₂ Norm", "ZD", "CosSim"); LOG_INF("============================================\n"); @@ -1375,7 +1374,7 @@ static bool show_statistics(const common_params & params) { const float l2_norm = (ll2n != lyr_l2_norm.end()) ? ll2n->second : 0.0f; LOG_INF("%5d\t%11.2f\t%6.2f%%\t%10.4f\n", layer, - tensor_calc_mode == 1 ? l2_norm: lyr_sum, + legacy_mode ? lyr_sum : l2_norm, 100.0f * lyr_zd, lyr_cs); }