Internal change

PiperOrigin-RevId: 874097322
This commit is contained in:
Jan Wassenberg 2026-02-23 08:55:04 -08:00 committed by Copybara-Service
parent 7dc98902d3
commit 463a3682be
1 changed files with 34 additions and 23 deletions

View File

@ -18,6 +18,7 @@
#define THIRD_PARTY_GEMMA_CPP_COMPRESSION_TEST_UTIL_INL_H_
#include <stddef.h>
#include <stdio.h>
#include <vector>
@ -219,6 +220,8 @@ void AssertClose(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
// magnitude, but also to f32 accumulation of rows in A and B.
const double norm = MaxRowAbsSum(a_batch) * MaxRowAbsSum(b_trans_batch);
const float max_abs = MaxAbs(a_batch) * MaxAbs(b_trans_batch);
HWY_ASSERT(hn::AllTrue(df, hn::IsFinite(hn::Set(df, norm))));
HWY_ASSERT(hn::AllTrue(df, hn::IsFinite(hn::Set(df, max_abs))));
const double eps_bf16 = hwy::ConvertScalarTo<double>(hwy::Epsilon<BF16>());
const double eps_f32 = hwy::ConvertScalarTo<double>(hwy::Epsilon<float>());
// Dot() uses double-precision summation.
@ -232,10 +235,7 @@ void AssertClose(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
if (tolerance > 500.0) {
HWY_WARN("high tolerance %f norm %f maxabs %f\n", tolerance, norm, max_abs);
}
const double rel_tolerance =
1.0 + hwy::ConvertScalarTo<double>(hwy::Epsilon<TC>());
double max_rel = 0.0;
double worst_l1 = 0.0;
size_t worst_r = 0;
size_t worst_c = 0;
double worst_actual = 0.0;
@ -247,34 +247,45 @@ void AssertClose(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
for (size_t c = 0; c < B.Rows(); c++) {
const double expected_value = static_cast<double>(expected_row[c]);
const double actual_value = static_cast<double>(actual_row[c]);
const bool in_range = expected_value - tolerance <= actual_value &&
actual_value <= expected_value + tolerance;
const double l1 = hwy::ScalarAbs(expected_value - actual_value);
if (l1 > HWY_MAX(tolerance, tolerance * hwy::ScalarAbs(expected_value))) {
fprintf(stderr, "%zu,%zu\n", r, c);
++num_outside;
}
if (!in_range) {
const double max = HWY_MAX(expected_value, actual_value);
const double min = HWY_MIN(expected_value, actual_value);
const double rel = max / HWY_MAX(min, 1E-6);
if (rel > max_rel) {
worst_expected = expected_value;
worst_actual = actual_value;
worst_r = r;
worst_c = c;
max_rel = rel;
++num_outside;
}
if (l1 > worst_l1) {
worst_l1 = l1;
worst_expected = expected_value;
worst_actual = actual_value;
worst_r = r;
worst_c = c;
}
}
}
if (max_rel > rel_tolerance) {
if (num_outside > 0) {
const size_t r_begin = worst_r >= 1 ? worst_r - 1 : 0;
const size_t r_end = HWY_MIN(r_begin + 3, A.Rows());
const size_t c_begin = worst_c >= 3 ? worst_c - 3 : 0;
const size_t c_end = HWY_MIN(c_begin + 7, B.Rows());
fprintf(stderr,
"%zu outside. Printing rows [%zu, %zu) and columns [%zu, %zu)\n",
num_outside, r_begin, r_end, c_begin, c_end);
for (size_t r = r_begin; r < r_end; r++) {
const float* expected_row = c_slow_batch.Row(r);
const float* actual_row = c_batch.Row(r);
for (size_t c = c_begin; c < c_end; c++) {
fprintf(stderr, "%6.3f=%6.3f ", expected_row[c], actual_row[c]);
}
fprintf(stderr, "\n");
}
hwy::Abort(__FILE__, line,
"(%zu,%zu): expected %f, actual %f, norm %f maxabs %f "
"tolerance %f rel %E max_rel %E num_outside %zu\n",
"tolerance %f worst_l1 %E\n",
worst_r, worst_c, worst_expected, worst_actual, norm, max_abs,
tolerance, max_rel, rel_tolerance, num_outside);
tolerance, worst_l1);
}
HWY_ASSERT(hn::AllFalse(
df, hn::IsEitherNaN(hn::Set(df, norm), hn::Set(df, max_abs))));
}
// NOLINTNEXTLINE(google-readability-namespace-comments)