Refactor legacy mode

This commit is contained in:
Ed Addario 2025-08-05 14:16:45 +01:00
parent 4c3fea89d6
commit 88854c9179
No known key found for this signature in database
GPG Key ID: E7875815A3230993
1 changed file with 17 additions and 18 deletions

View File

@ -160,7 +160,7 @@ static std::vector<float> compute_tensor_averages(const Stats & tstats) {
return vec; return vec;
} }
static int compute_vector_statistics(std::vector<tensor_statistics> & tstats, const std::string & name, const Stats & e) { static bool compute_vector_statistics(std::vector<tensor_statistics> & tstats, const std::string & name, const Stats & e) {
if (e.values.size() % e.counts.size() != 0) { if (e.values.size() % e.counts.size() != 0) {
LOG_ERR("%s: activation size mismatch for tensor %s (%zu vs %zu)\n", __func__, name.c_str(), e.counts.size(), e.values.size()); LOG_ERR("%s: activation size mismatch for tensor %s (%zu vs %zu)\n", __func__, name.c_str(), e.counts.size(), e.values.size());
return -1;; return -1;;
@ -172,7 +172,6 @@ static int compute_vector_statistics(std::vector<tensor_statistics> & tstats, co
const int n_mat = e.counts.size(); const int n_mat = e.counts.size();
const int row_size = e.values.size() / n_mat; const int row_size = e.values.size() / n_mat;
const int calc_mode = e.activations.empty() ? 2 : 1;
std::vector<float> activations; std::vector<float> activations;
@ -203,7 +202,15 @@ static int compute_vector_statistics(std::vector<tensor_statistics> & tstats, co
const float std_deviation = std::sqrt(std::max(0.0f, variance)); const float std_deviation = std::sqrt(std::max(0.0f, variance));
float entropy = 0; float entropy = 0;
if (calc_mode == 1) { if (e.activations.empty()) {
if (sum > 0) {
for (const auto act : activations) {
if (const float p = act / sum; p > 0) {
entropy -= p * std::log2(p);
}
}
}
} else {
float div = 0.0; float div = 0.0;
std::vector<float> weights(activations.size()); std::vector<float> weights(activations.size());
for (size_t i = 0; i < activations.size(); ++i) { for (size_t i = 0; i < activations.size(); ++i) {
@ -218,14 +225,6 @@ static int compute_vector_statistics(std::vector<tensor_statistics> & tstats, co
if (p > 0.0) entropy -= p * std::log2(p); if (p > 0.0) entropy -= p * std::log2(p);
} }
} }
} else {
if (sum > 0) {
for (const auto act : activations) {
if (const float p = act / sum; p > 0) {
entropy -= p * std::log2(p);
}
}
}
} }
int z_score = 0; int z_score = 0;
@ -247,7 +246,7 @@ static int compute_vector_statistics(std::vector<tensor_statistics> & tstats, co
ts.entropy = entropy; ts.entropy = entropy;
ts.zd_score = static_cast<float>(z_score) / ts.elements; ts.zd_score = static_cast<float>(z_score) / ts.elements;
return calc_mode; return e.activations.empty();
} }
static void compute_tensor_statistics(std::vector<tensor_statistics> & tstats) { static void compute_tensor_statistics(std::vector<tensor_statistics> & tstats) {
@ -1257,7 +1256,7 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params, c
static bool show_statistics(const common_params & params) { static bool show_statistics(const common_params & params) {
std::vector<tensor_statistics> ts; std::vector<tensor_statistics> ts;
int tensor_calc_mode = 0; bool legacy_mode = false;
if (params.in_files.empty() || params.in_files.size() > 1) { if (params.in_files.empty() || params.in_files.size() > 1) {
LOG_ERR("\nError: a single imatrix file is required to compute tensor statistics\n\n"); LOG_ERR("\nError: a single imatrix file is required to compute tensor statistics\n\n");
@ -1265,7 +1264,7 @@ static bool show_statistics(const common_params & params) {
} }
if (g_collector.load_imatrix(params.in_files[0].c_str())) { if (g_collector.load_imatrix(params.in_files[0].c_str())) {
for (const auto & [name, stats] :g_collector.get_mstats()) { for (const auto & [name, stats] :g_collector.get_mstats()) {
tensor_calc_mode =compute_vector_statistics(ts, name, stats); legacy_mode = compute_vector_statistics(ts, name, stats);
} }
} else { } else {
LOG_ERR("\nError: %s is not a valid imatrix file\n\n", params.in_files[0].c_str()); LOG_ERR("\nError: %s is not a valid imatrix file\n\n", params.in_files[0].c_str());
@ -1300,7 +1299,7 @@ static bool show_statistics(const common_params & params) {
LOG_INF("\n%6s\t%18s\t%13s\t%8s\t%8s\t%7s\t%15s\t%13s\t%12s\t%s\t%5s\t%10s\n", LOG_INF("\n%6s\t%18s\t%13s\t%8s\t%8s\t%7s\t%15s\t%13s\t%12s\t%s\t%5s\t%10s\n",
"Layer", "Layer",
"Tensor", "Tensor",
tensor_calc_mode == 1 ? "L₂ Norm" : "Σ(Act²)", legacy_mode ? "Σ(Act²)" : "L₂ Norm",
"Min", "Min",
"Max", "Max",
"μ", "μ",
@ -1327,7 +1326,7 @@ static bool show_statistics(const common_params & params) {
LOG_INF("%5s\t%-20s\t%11.2f\t%10.4f\t%10.4f\t%8.2f\t%8.2f\t%7d\t%12.4f\t%7.2f%%\t%6.2f%%\t%10.4f\n", LOG_INF("%5s\t%-20s\t%11.2f\t%10.4f\t%10.4f\t%8.2f\t%8.2f\t%7d\t%12.4f\t%7.2f%%\t%6.2f%%\t%10.4f\n",
layer.c_str(), layer.c_str(),
name.c_str(), name.c_str(),
tensor_calc_mode == 1 ? tstat.l2_norm : tstat.sum_values, legacy_mode == 1 ? tstat.sum_values : tstat.l2_norm,
tstat.min_values, tstat.min_values,
tstat.max_values, tstat.max_values,
tstat.mean_values, tstat.mean_values,
@ -1361,7 +1360,7 @@ static bool show_statistics(const common_params & params) {
LOG_INF("\nComputing aggregated statistics per layer (%ld layers)\n", layers); LOG_INF("\nComputing aggregated statistics per layer (%ld layers)\n", layers);
LOG_INF("\n%6s\t%13s\t%5s\t%10s\n", LOG_INF("\n%6s\t%13s\t%5s\t%10s\n",
"Layer", "Layer",
tensor_calc_mode == 1 ? "L₂ Norm" : "Σ(Act²)", legacy_mode ? "Σ(Act²)" : "L₂ Norm",
"ZD", "ZD",
"CosSim"); "CosSim");
LOG_INF("============================================\n"); LOG_INF("============================================\n");
@ -1375,7 +1374,7 @@ static bool show_statistics(const common_params & params) {
const float l2_norm = (ll2n != lyr_l2_norm.end()) ? ll2n->second : 0.0f; const float l2_norm = (ll2n != lyr_l2_norm.end()) ? ll2n->second : 0.0f;
LOG_INF("%5d\t%11.2f\t%6.2f%%\t%10.4f\n", LOG_INF("%5d\t%11.2f\t%6.2f%%\t%10.4f\n",
layer, layer,
tensor_calc_mode == 1 ? l2_norm: lyr_sum, legacy_mode ? lyr_sum : l2_norm,
100.0f * lyr_zd, 100.0f * lyr_zd,
lyr_cs); lyr_cs);
} }