Refactor compute_tensor_averages()

This commit is contained in:
Ed Addario 2026-01-11 17:36:10 +00:00
parent 6d82fa825a
commit d488bbb7c7
No known key found for this signature in database
GPG Key ID: E7875815A3230993
1 changed files with 15 additions and 29 deletions

View File

@ -131,44 +131,30 @@ static void process_tensor_name(const std::string & input, std::string & layer,
static std::vector<float> compute_tensor_averages(const Stats & tstats) {
if (tstats.counts.empty()) { return {}; }
const size_t n_mat = tstats.counts.size();
const size_t len = !tstats.activations.empty() ? tstats.activations.size() : tstats.values.size();
if (len == 0 || n_mat == 0 || len % n_mat != 0) { return {}; }
const size_t row = len / n_mat;
std::vector<float> vec;
vec.reserve(len);
vec.resize(len);
bool has_valid = false;
if (tstats.activations.empty()) {
// Mean of squares (legacy: only values are available)
for (size_t m = 0; m < n_mat; ++m) {
const float c = (float) tstats.counts[m];
const size_t off = m * row;
if (c <= 0.0f) {
for (size_t j = 0; j < row; ++j) { vec.push_back(0.0f); }
continue;
}
const bool use_activations = !tstats.activations.empty();
has_valid = true;
for (size_t j = 0; j < row; ++j) {
vec.push_back(tstats.values[off + j] / c);
}
}
} else {
// Mean (new format: activations + values)
for (size_t m = 0; m < n_mat; ++m) {
const float c = (float) tstats.counts[m];
const size_t off = m * row;
if (c <= 0.0f) {
for (size_t j = 0; j < row; ++j) { vec.push_back(0.0f); }
continue;
}
for (size_t m = 0; m < n_mat; ++m) {
const auto c = (float) tstats.counts[m];
const size_t off = m * row;
has_valid = true;
for (size_t j = 0; j < row; ++j) {
vec.push_back(tstats.activations[off + j] / c);
}
}
if (c <= 0.0f) { continue; }
has_valid = true;
const float scale = 1.0f / c;
const float * src = use_activations ? &tstats.activations[off] : &tstats.values[off];
float * dst = & vec[off];
for (size_t j = 0; j < row; ++j) { dst[j] = src[j] * scale; }
}
if (!has_valid) { return {}; }