Validate number of elements if in_sum is present

This commit is contained in:
Ed Addario 2025-08-17 09:42:18 +01:00
parent 1f72bc157f
commit 630750fdef
No known key found for this signature in database
GPG Key ID: E7875815A3230993
1 changed files with 3 additions and 7 deletions

View File

@ -940,11 +940,11 @@ bool IMatrixCollector::load_imatrix(const char * file_name) {
for (const auto & sc : sums_counts_for) {
const std::string & name = sc.first;
const struct ggml_tensor * in_sum = std::get<2>(sc.second);
const struct ggml_tensor * in_sum = std::get<2>(sc.second);
const struct ggml_tensor * in_sum2 = std::get<0>(sc.second);
const struct ggml_tensor * counts = std::get<1>(sc.second);
if (!in_sum2 || !counts) {
if (!in_sum2 || !counts || (in_sum != nullptr && ggml_nelements(in_sum) != ggml_nelements(in_sum2))) {
LOG_ERR("%s: mismatched sums and counts for %s\n", __func__, name.c_str());
gguf_free(ctx_gguf);
ggml_free(ctx);
@ -981,16 +981,12 @@ bool IMatrixCollector::load_imatrix(const char * file_name) {
// Recreate the state as expected by save_imatrix()
for (int64_t j = 0; j < nval; j++) {
if (in_sum != nullptr) { e.activations[j] += ((const float *) in_sum->data)[j]; }
e.values[j] += ((const float *) in_sum2->data)[j];
}
for (int64_t j = 0; j < ncounts; j++) {
e.counts[j] += std::lround(((const float *) counts->data)[j]);
}
if (in_sum != nullptr) {
for (int64_t j = 0; j < nval; j++) {
e.activations[j] += ((const float *) in_sum->data)[j];
}
}
}
// TODO: extract into its own method; this is also used by the legacy format