Improve WCE to be magnitude-aware
This commit is contained in:
parent
a057d827ca
commit
06d3b50b03
|
|
@ -1091,17 +1091,18 @@ static std::unordered_map<std::string, ggml_type> target_bpw_type(
|
|||
}
|
||||
}
|
||||
|
||||
// Cosine Distance
|
||||
double cos_sim;
|
||||
const double norm_prod = nx * ny;
|
||||
// Concordance correlation coefficient (magnitude-Aware WCE)
|
||||
double ccc;
|
||||
const double norm_sum = nx + ny;
|
||||
|
||||
if (norm_prod <= EPSILON) { cos_sim = nx <= EPSILON && ny <= EPSILON ? 1.0 : 0.0; }
|
||||
else { cos_sim = dot / std::sqrt(norm_prod); }
|
||||
if (norm_sum <= EPSILON) { ccc = nx <= EPSILON && ny <= EPSILON ? 1.0 : 0.0; }
|
||||
else { ccc = 2.0 * dot / norm_sum; }
|
||||
|
||||
if (cos_sim > 1.0) { cos_sim = 1.0; }
|
||||
else if (cos_sim < -1.0) { cos_sim = -1.0; }
|
||||
|
||||
slice_sum += 1.0 - cos_sim;
|
||||
if (ccc > 1.0) { ccc = 1.0; }
|
||||
else if (ccc < -1.0) { ccc = -1.0; }
|
||||
|
||||
slice_sum += 1.0 - ccc;
|
||||
off += (size_t) n_per_row;
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue