Validate number of elements if in_sum is present
This commit is contained in:
parent
1f72bc157f
commit
630750fdef
|
|
@ -940,11 +940,11 @@ bool IMatrixCollector::load_imatrix(const char * file_name) {
|
||||||
|
|
||||||
for (const auto & sc : sums_counts_for) {
|
for (const auto & sc : sums_counts_for) {
|
||||||
const std::string & name = sc.first;
|
const std::string & name = sc.first;
|
||||||
const struct ggml_tensor * in_sum = std::get<2>(sc.second);
|
const struct ggml_tensor * in_sum = std::get<2>(sc.second);
|
||||||
const struct ggml_tensor * in_sum2 = std::get<0>(sc.second);
|
const struct ggml_tensor * in_sum2 = std::get<0>(sc.second);
|
||||||
const struct ggml_tensor * counts = std::get<1>(sc.second);
|
const struct ggml_tensor * counts = std::get<1>(sc.second);
|
||||||
|
|
||||||
if (!in_sum2 || !counts) {
|
if (!in_sum2 || !counts || (in_sum != nullptr && ggml_nelements(in_sum) != ggml_nelements(in_sum2))) {
|
||||||
LOG_ERR("%s: mismatched sums and counts for %s\n", __func__, name.c_str());
|
LOG_ERR("%s: mismatched sums and counts for %s\n", __func__, name.c_str());
|
||||||
gguf_free(ctx_gguf);
|
gguf_free(ctx_gguf);
|
||||||
ggml_free(ctx);
|
ggml_free(ctx);
|
||||||
|
|
@ -981,16 +981,12 @@ bool IMatrixCollector::load_imatrix(const char * file_name) {
|
||||||
|
|
||||||
// Recreate the state as expected by save_imatrix()
|
// Recreate the state as expected by save_imatrix()
|
||||||
for (int64_t j = 0; j < nval; j++) {
|
for (int64_t j = 0; j < nval; j++) {
|
||||||
|
if (in_sum != nullptr) { e.activations[j] += ((const float *) in_sum->data)[j]; }
|
||||||
e.values[j] += ((const float *) in_sum2->data)[j];
|
e.values[j] += ((const float *) in_sum2->data)[j];
|
||||||
}
|
}
|
||||||
for (int64_t j = 0; j < ncounts; j++) {
|
for (int64_t j = 0; j < ncounts; j++) {
|
||||||
e.counts[j] += std::lround(((const float *) counts->data)[j]);
|
e.counts[j] += std::lround(((const float *) counts->data)[j]);
|
||||||
}
|
}
|
||||||
if (in_sum != nullptr) {
|
|
||||||
for (int64_t j = 0; j < nval; j++) {
|
|
||||||
e.activations[j] += ((const float *) in_sum->data)[j];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: extract into its own method; this is also used by the legacy format
|
// TODO: extract into its own method; this is also used by the legacy format
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue