Improve WCE to be magnitude-aware

This commit is contained in:
Ed Addario 2026-03-01 09:19:55 +00:00
parent a057d827ca
commit 06d3b50b03
No known key found for this signature in database
GPG Key ID: E7875815A3230993
1 changed files with 9 additions and 8 deletions

View File

@ -1091,17 +1091,18 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
}
}
// Cosine Distance
double cos_sim;
const double norm_prod = nx * ny;
// Concordance correlation coefficient (magnitude-Aware WCE)
double ccc;
const double norm_sum = nx + ny;
if (norm_prod <= EPSILON) { cos_sim = nx <= EPSILON && ny <= EPSILON ? 1.0 : 0.0; }
else { cos_sim = dot / std::sqrt(norm_prod); }
if (norm_sum <= EPSILON) { ccc = nx <= EPSILON && ny <= EPSILON ? 1.0 : 0.0; }
else { ccc = 2.0 * dot / norm_sum; }
if (cos_sim > 1.0) { cos_sim = 1.0; }
else if (cos_sim < -1.0) { cos_sim = -1.0; }
slice_sum += 1.0 - cos_sim;
if (ccc > 1.0) { ccc = 1.0; }
else if (ccc < -1.0) { ccc = -1.0; }
slice_sum += 1.0 - ccc;
off += (size_t) n_per_row;
}