diff --git a/src/llama-quant.cpp b/src/llama-quant.cpp index 9fbc908c16..48c6fe7c15 100644 --- a/src/llama-quant.cpp +++ b/src/llama-quant.cpp @@ -1091,17 +1091,18 @@ static std::unordered_map target_bpw_type( } } - // Cosine Distance - double cos_sim; - const double norm_prod = nx * ny; + // Concordance correlation coefficient (magnitude-Aware WCE) + double ccc; + const double norm_sum = nx + ny; - if (norm_prod <= EPSILON) { cos_sim = nx <= EPSILON && ny <= EPSILON ? 1.0 : 0.0; } - else { cos_sim = dot / std::sqrt(norm_prod); } + if (norm_sum <= EPSILON) { ccc = nx <= EPSILON && ny <= EPSILON ? 1.0 : 0.0; } + else { ccc = 2.0 * dot / norm_sum; } - if (cos_sim > 1.0) { cos_sim = 1.0; } - else if (cos_sim < -1.0) { cos_sim = -1.0; } - slice_sum += 1.0 - cos_sim; + if (ccc > 1.0) { ccc = 1.0; } + else if (ccc < -1.0) { ccc = -1.0; } + + slice_sum += 1.0 - ccc; off += (size_t) n_per_row; }