Validate number of elements if in_sum is present
This commit is contained in:
parent
1f72bc157f
commit
630750fdef
|
|
@ -940,11 +940,11 @@ bool IMatrixCollector::load_imatrix(const char * file_name) {
|
|||
|
||||
for (const auto & sc : sums_counts_for) {
|
||||
const std::string & name = sc.first;
|
||||
const struct ggml_tensor * in_sum = std::get<2>(sc.second);
|
||||
const struct ggml_tensor * in_sum = std::get<2>(sc.second);
|
||||
const struct ggml_tensor * in_sum2 = std::get<0>(sc.second);
|
||||
const struct ggml_tensor * counts = std::get<1>(sc.second);
|
||||
|
||||
if (!in_sum2 || !counts) {
|
||||
if (!in_sum2 || !counts || (in_sum != nullptr && ggml_nelements(in_sum) != ggml_nelements(in_sum2))) {
|
||||
LOG_ERR("%s: mismatched sums and counts for %s\n", __func__, name.c_str());
|
||||
gguf_free(ctx_gguf);
|
||||
ggml_free(ctx);
|
||||
|
|
@ -981,16 +981,12 @@ bool IMatrixCollector::load_imatrix(const char * file_name) {
|
|||
|
||||
// Recreate the state as expected by save_imatrix()
|
||||
for (int64_t j = 0; j < nval; j++) {
|
||||
if (in_sum != nullptr) { e.activations[j] += ((const float *) in_sum->data)[j]; }
|
||||
e.values[j] += ((const float *) in_sum2->data)[j];
|
||||
}
|
||||
for (int64_t j = 0; j < ncounts; j++) {
|
||||
e.counts[j] += std::lround(((const float *) counts->data)[j]);
|
||||
}
|
||||
if (in_sum != nullptr) {
|
||||
for (int64_t j = 0; j < nval; j++) {
|
||||
e.activations[j] += ((const float *) in_sum->data)[j];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: extract into its own method; this is also used by the legacy format
|
||||
|
|
|
|||
Loading…
Reference in New Issue