Clamp CosSim to [-1, 1] to avoid float drift

This commit is contained in:
Ed Addario 2025-10-28 18:29:59 +00:00
parent af3b6aca22
commit c9a0874f35
No known key found for this signature in database
GPG Key ID: E7875815A3230993
1 changed files with 12 additions and 7 deletions

View File

@ -258,12 +258,12 @@ static void compute_tensor_statistics(std::vector<tensor_statistics> & tstats) {
if (std::smatch match; std::regex_search(ts.tensor, match, pattern)) {
const int blk = std::stoi(match[1]);
if (blk <= 0) continue;
if (blk <= 0) { continue; }
std::string tname(ts.tensor);
tname.replace(match.position(1), match.length(1), std::to_string(blk-1));
auto prev = std::find_if(tstats.begin(), tstats.end(),
[tname](const tensor_statistics & t) { return t.tensor == tname; });
if (prev == tstats.end()) continue;
if (prev == tstats.end()) { continue; }
const auto curr_avg = compute_tensor_averages(ts.stats);
const auto prev_avg = compute_tensor_averages(prev->stats);
if (curr_avg.size() == prev_avg.size() && !curr_avg.empty()) {
@ -275,7 +275,12 @@ static void compute_tensor_statistics(std::vector<tensor_statistics> & tstats) {
vec1 += curr_avg[i] * curr_avg[i];
vec2 += prev_avg[i] * prev_avg[i];
}
if (vec1 > 0 && vec2 > 0) ts.cossim = dot_prod / (std::sqrt(vec1) * std::sqrt(vec2));
if (vec1 > 0 && vec2 > 0) {
float cs = dot_prod / (std::sqrt(vec1) * std::sqrt(vec2));
cs = std::min(cs, 1.0f);
cs = std::max(cs, -1.0f);
ts.cossim = cs;
}
}
}
}
@ -283,19 +288,19 @@ static void compute_tensor_statistics(std::vector<tensor_statistics> & tstats) {
// compute the L2 Norm (Euclidian Distance) between the same tensors in consecutive layers
for (auto & ts : tstats) {
ts.l2_norm = 0.0f;
if (ts.stats.activations.empty()) continue;
if (ts.stats.activations.empty()) { continue; }
if (std::smatch match; std::regex_search(ts.tensor, match, pattern)) {
const int blk = std::stoi(match[1]);
if (blk <= 0) continue;
if (blk <= 0) { continue; }
std::string tname(ts.tensor);
tname.replace(match.position(1), match.length(1), std::to_string(blk - 1));
auto prev = std::find_if(tstats.begin(), tstats.end(),
[tname](const tensor_statistics & t) { return t.tensor == tname; });
if (prev == tstats.end()) continue;
if (prev == tstats.end()) { continue; }
const auto cur_avg = compute_tensor_averages(ts.stats);
const auto prev_avg = compute_tensor_averages(prev->stats);
if (cur_avg.empty() || prev_avg.empty() || cur_avg.size() != prev_avg.size()) continue;
if (cur_avg.empty() || prev_avg.empty() || cur_avg.size() != prev_avg.size()) { continue; }
float dist = 0.0;
for (size_t i = 0; i < cur_avg.size(); ++i) {