Determine calculation mode

This commit is contained in:
Ed Addario 2025-08-02 16:36:12 +01:00
parent 78ddb475de
commit 9744a4a1c6
No known key found for this signature in database
GPG Key ID: E7875815A3230993
1 changed file with 7 additions and 4 deletions

View File

@@ -127,18 +127,19 @@ static void process_tensor_name(const std::string & input, std::string & layer,
} }
} }
static void compute_statistics(std::vector<tensor_statistics> & tstats, const std::string & name, const Stats & e) { static int compute_tensor_statistics(std::vector<tensor_statistics> & tstats, const std::string & name, const Stats & e) {
if (e.in_sum2.size() % e.counts.size() != 0) { if (e.in_sum2.size() % e.counts.size() != 0) {
LOG_ERR("%s: activation size mismatch for tensor %s (%zu vs %zu)\n", __func__, name.c_str(), e.counts.size(), e.in_sum2.size()); LOG_ERR("%s: activation size mismatch for tensor %s (%zu vs %zu)\n", __func__, name.c_str(), e.counts.size(), e.in_sum2.size());
return; return -1;;
} }
if (e.counts.empty()) { if (e.counts.empty()) {
LOG_ERR("%s: there are no activations for tensor %s. The imatrix may be suboptimal\n", __func__, name.c_str()); LOG_ERR("%s: there are no activations for tensor %s. The imatrix may be suboptimal\n", __func__, name.c_str());
return; return -1;
} }
const int n_mat = e.counts.size(); const int n_mat = e.counts.size();
const int row_size = e.in_sum2.size() / n_mat; const int row_size = e.in_sum2.size() / n_mat;
const int calc_mode = e.in_sum.empty() ? 2 : 1;
std::vector<float> activations; std::vector<float> activations;
@@ -1104,13 +1105,15 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params, c
static bool show_statistics(const common_params & params) { static bool show_statistics(const common_params & params) {
std::vector<tensor_statistics> ts; std::vector<tensor_statistics> ts;
int tensor_calc_mode = 0;
if (params.in_files.empty() || params.in_files.size() > 1) { if (params.in_files.empty() || params.in_files.size() > 1) {
LOG_ERR("\nError: a single imatrix file is required to compute tensor statistics\n\n"); LOG_ERR("\nError: a single imatrix file is required to compute tensor statistics\n\n");
return false; return false;
} }
if (g_collector.load_imatrix(params.in_files[0].c_str())) { if (g_collector.load_imatrix(params.in_files[0].c_str())) {
for (const auto & [name, stats] :g_collector.get_mstats()) { for (const auto & [name, stats] :g_collector.get_mstats()) {
compute_statistics(ts, name, stats); tensor_calc_mode =compute_tensor_statistics(ts, name, stats);
} }
} else { } else {
LOG_ERR("\nError: %s is not a valid imatrix file\n\n", params.in_files[0].c_str()); LOG_ERR("\nError: %s is not a valid imatrix file\n\n", params.in_files[0].c_str());