Read statistics_data from imatrix

This commit is contained in:
Ed Addario 2026-01-21 18:27:44 +00:00
parent 25d7ecc42a
commit 3ba6798d45
No known key found for this signature in database
GPG Key ID: E7875815A3230993
1 changed files with 49 additions and 6 deletions

View File

@ -226,7 +226,8 @@ static int load_legacy_imatrix(const std::string & imatrix_file, std::vector<std
static int load_imatrix(const std::string & imatrix_file,
std::vector<std::string> & imatrix_datasets,
std::unordered_map<std::string, std::vector<float>> & values_data,
std::unordered_map<std::string, std::vector<float>> & activations_data) {
std::unordered_map<std::string, std::vector<float>> & activations_data,
std::unordered_map<std::string, std::vector<float>> & statistics_data) {
struct ggml_context * ctx = nullptr;
struct gguf_init_params meta_gguf_params = {
@ -261,9 +262,10 @@ static int load_imatrix(const std::string & imatrix_file,
const std::string sums_suffix{ ".in_sum" };
const std::string sums2_suffix{ ".in_sum2" };
const std::string counts_suffix{ ".counts" };
const std::string stats_suffix{ ".stats" };
// Using an ordered map to get a deterministic iteration order.
std::map<std::string, std::tuple<struct ggml_tensor *, struct ggml_tensor *, struct ggml_tensor *>> sums_counts_for;
std::map<std::string, std::tuple<struct ggml_tensor *, struct ggml_tensor *, struct ggml_tensor *, struct ggml_tensor *>> sums_counts_for;
for (struct ggml_tensor * cur = ggml_get_first_tensor(ctx); cur; cur = ggml_get_next_tensor(ctx, cur)) {
std::string name = cur->name;
@ -279,7 +281,10 @@ static int load_imatrix(const std::string & imatrix_file,
} else if (string_remove_suffix(name, counts_suffix)) {
// counts
std::get<2>(sums_counts_for[std::move(name)]) = cur;
} else {
} else if (string_remove_suffix(name, stats_suffix)) {
// stats
std::get<3>(sums_counts_for[std::move(name)]) = cur;
} else {
// ignore other tensors
}
}
@ -289,6 +294,7 @@ static int load_imatrix(const std::string & imatrix_file,
const struct ggml_tensor * sums = std::get<0>(sc.second);
const struct ggml_tensor * sums2 = std::get<1>(sc.second);
const struct ggml_tensor * counts = std::get<2>(sc.second);
const struct ggml_tensor * stats = std::get<3>(sc.second);
// check sums2 and counts are present, and that sums and sums2 have the same shape
if (!sums2 || !counts || (sums != nullptr && ggml_nelements(sums) != ggml_nelements(sums2))) {
@ -306,6 +312,20 @@ static int load_imatrix(const std::string & imatrix_file,
if (sums) {
activations.resize(ggml_nelements(sums));
}
if (stats) {
auto & statistics = statistics_data[name];
statistics.resize(ggml_nelements(stats));
if (stats->type == GGML_TYPE_F32) {
std::memcpy(statistics.data(), stats->data, ggml_nelements(stats) * sizeof(float));
} else {
fprintf(stderr, "%s: unsupported .stats type '%s' for '%s' - ignoring entry\n",
__func__, ggml_type_name(stats->type), name.c_str());
statistics.clear();
statistics_data.erase(name);
}
}
values.resize(ggml_nelements(sums2));
float max_count = 0.0f;
for (int64_t j = 0; j < ne1; ++j) {
@ -358,10 +378,11 @@ static int prepare_imatrix(const std::string & imatrix_file,
const std::vector<std::string> & included_weights,
const std::vector<std::string> & excluded_weights,
std::unordered_map<std::string, std::vector<float>> & values_data,
std::unordered_map<std::string, std::vector<float>> & activations_data) {
std::unordered_map<std::string, std::vector<float>> & activations_data,
std::unordered_map<std::string, std::vector<float>> & statistics_data) {
int m_last_call = -1;
if (!imatrix_file.empty()) {
m_last_call = load_imatrix(imatrix_file, imatrix_dataset, values_data, activations_data);
m_last_call = load_imatrix(imatrix_file, imatrix_dataset, values_data, activations_data, statistics_data);
}
if (values_data.empty()) {
return m_last_call;
@ -384,11 +405,20 @@ static int prepare_imatrix(const std::string & imatrix_file,
++at;
}
}
for (auto st = statistics_data.begin(); st != statistics_data.end();) {
auto pos = st->first.find(name);
if (pos != std::string::npos) {
st = activations_data.erase(st);
} else {
++st;
}
}
}
}
if (!included_weights.empty()) {
std::unordered_map<std::string, std::vector<float>> tmp_values;
std::unordered_map<std::string, std::vector<float>> tmp_activations;
std::unordered_map<std::string, std::vector<float>> tmp_statistics;
for (const auto & name : included_weights) {
for (auto & e : values_data) {
auto pos = e.first.find(name);
@ -402,9 +432,16 @@ static int prepare_imatrix(const std::string & imatrix_file,
tmp_activations.emplace(std::move(a));
}
}
for (auto & s : statistics_data) {
auto pos = s.first.find(name);
if (pos != std::string::npos) {
tmp_statistics.emplace(std::move(s));
}
}
}
values_data = std::move(tmp_values);
activations_data = std::move(tmp_activations);
statistics_data = std::move(tmp_statistics);
}
return m_last_call;
@ -611,6 +648,8 @@ int main(int argc, char ** argv) {
if (arg_idx == argc-1 || !parse_target_size(argv[++arg_idx], target_size)) {
usage(argv[0]);
}
} else if (strcmp(argv[arg_idx], "--use-wce") == 0) {
params.use_wce = true;
} else if (strcmp(argv[arg_idx], "--ignore-tensor-importance") == 0) {
params.ignore_tensor_importance = true;
} else if (strcmp(argv[arg_idx], "--save-state") == 0) {
@ -669,7 +708,8 @@ int main(int argc, char ** argv) {
std::vector<std::string> imatrix_datasets;
std::unordered_map<std::string, std::vector<float>> values_data;
std::unordered_map<std::string, std::vector<float>> activations_data;
int m_last_call = prepare_imatrix(imatrix_file, imatrix_datasets, included_weights, excluded_weights, values_data, activations_data);
std::unordered_map<std::string, std::vector<float>> statistics_data;
int m_last_call = prepare_imatrix(imatrix_file, imatrix_datasets, included_weights, excluded_weights, values_data, activations_data, statistics_data);
if (!values_data.empty()) {
params.imatrix = &values_data;
{
@ -709,6 +749,9 @@ int main(int argc, char ** argv) {
if (!activations_data.empty()) {
params.activations = &activations_data;
}
if (!statistics_data.empty()) {
params.statistics = &statistics_data;
}
if (!kv_overrides.empty()) {
kv_overrides.emplace_back();
kv_overrides.back().key[0] = 0;