From 630750fdef2b050fed05fa51670ceb454746be56 Mon Sep 17 00:00:00 2001 From: Ed Addario Date: Sun, 17 Aug 2025 09:42:18 +0100 Subject: [PATCH] Validate number of elements if in_sum is present --- tools/imatrix/imatrix.cpp | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/tools/imatrix/imatrix.cpp b/tools/imatrix/imatrix.cpp index 52a15ebd82..1698146ffb 100644 --- a/tools/imatrix/imatrix.cpp +++ b/tools/imatrix/imatrix.cpp @@ -940,11 +940,11 @@ bool IMatrixCollector::load_imatrix(const char * file_name) { for (const auto & sc : sums_counts_for) { const std::string & name = sc.first; - const struct ggml_tensor * in_sum = std::get<2>(sc.second); + const struct ggml_tensor * in_sum = std::get<2>(sc.second); const struct ggml_tensor * in_sum2 = std::get<0>(sc.second); const struct ggml_tensor * counts = std::get<1>(sc.second); - if (!in_sum2 || !counts) { + if (!in_sum2 || !counts || (in_sum != nullptr && ggml_nelements(in_sum) != ggml_nelements(in_sum2))) { LOG_ERR("%s: mismatched sums and counts for %s\n", __func__, name.c_str()); gguf_free(ctx_gguf); ggml_free(ctx); @@ -981,16 +981,12 @@ bool IMatrixCollector::load_imatrix(const char * file_name) { // Recreate the state as expected by save_imatrix() for (int64_t j = 0; j < nval; j++) { + if (in_sum != nullptr) { e.activations[j] += ((const float *) in_sum->data)[j]; } e.values[j] += ((const float *) in_sum2->data)[j]; } for (int64_t j = 0; j < ncounts; j++) { e.counts[j] += std::lround(((const float *) counts->data)[j]); } - if (in_sum != nullptr) { - for (int64_t j = 0; j < nval; j++) { - e.activations[j] += ((const float *) in_sum->data)[j]; - } - } } // TODO: extract into its own method; this is also used by the legacy format