mirror of https://github.com/google/gemma.cpp.git
parent
7dc98902d3
commit
463a3682be
|
|
@ -18,6 +18,7 @@
|
|||
#define THIRD_PARTY_GEMMA_CPP_COMPRESSION_TEST_UTIL_INL_H_
|
||||
|
||||
#include <stddef.h>
|
||||
#include <stdio.h>
|
||||
|
||||
#include <vector>
|
||||
|
||||
|
|
@ -219,6 +220,8 @@ void AssertClose(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
|
|||
// magnitude, but also to f32 accumulation of rows in A and B.
|
||||
const double norm = MaxRowAbsSum(a_batch) * MaxRowAbsSum(b_trans_batch);
|
||||
const float max_abs = MaxAbs(a_batch) * MaxAbs(b_trans_batch);
|
||||
HWY_ASSERT(hn::AllTrue(df, hn::IsFinite(hn::Set(df, norm))));
|
||||
HWY_ASSERT(hn::AllTrue(df, hn::IsFinite(hn::Set(df, max_abs))));
|
||||
const double eps_bf16 = hwy::ConvertScalarTo<double>(hwy::Epsilon<BF16>());
|
||||
const double eps_f32 = hwy::ConvertScalarTo<double>(hwy::Epsilon<float>());
|
||||
// Dot() uses double-precision summation.
|
||||
|
|
@ -232,10 +235,7 @@ void AssertClose(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
|
|||
if (tolerance > 500.0) {
|
||||
HWY_WARN("high tolerance %f norm %f maxabs %f\n", tolerance, norm, max_abs);
|
||||
}
|
||||
const double rel_tolerance =
|
||||
1.0 + hwy::ConvertScalarTo<double>(hwy::Epsilon<TC>());
|
||||
|
||||
double max_rel = 0.0;
|
||||
double worst_l1 = 0.0;
|
||||
size_t worst_r = 0;
|
||||
size_t worst_c = 0;
|
||||
double worst_actual = 0.0;
|
||||
|
|
@ -247,34 +247,45 @@ void AssertClose(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
|
|||
for (size_t c = 0; c < B.Rows(); c++) {
|
||||
const double expected_value = static_cast<double>(expected_row[c]);
|
||||
const double actual_value = static_cast<double>(actual_row[c]);
|
||||
const bool in_range = expected_value - tolerance <= actual_value &&
|
||||
actual_value <= expected_value + tolerance;
|
||||
const double l1 = hwy::ScalarAbs(expected_value - actual_value);
|
||||
if (l1 > HWY_MAX(tolerance, tolerance * hwy::ScalarAbs(expected_value))) {
|
||||
fprintf(stderr, "%zu,%zu\n", r, c);
|
||||
++num_outside;
|
||||
}
|
||||
|
||||
if (!in_range) {
|
||||
const double max = HWY_MAX(expected_value, actual_value);
|
||||
const double min = HWY_MIN(expected_value, actual_value);
|
||||
const double rel = max / HWY_MAX(min, 1E-6);
|
||||
if (rel > max_rel) {
|
||||
worst_expected = expected_value;
|
||||
worst_actual = actual_value;
|
||||
worst_r = r;
|
||||
worst_c = c;
|
||||
max_rel = rel;
|
||||
++num_outside;
|
||||
}
|
||||
if (l1 > worst_l1) {
|
||||
worst_l1 = l1;
|
||||
worst_expected = expected_value;
|
||||
worst_actual = actual_value;
|
||||
worst_r = r;
|
||||
worst_c = c;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (max_rel > rel_tolerance) {
|
||||
if (num_outside > 0) {
|
||||
const size_t r_begin = worst_r >= 1 ? worst_r - 1 : 0;
|
||||
const size_t r_end = HWY_MIN(r_begin + 3, A.Rows());
|
||||
const size_t c_begin = worst_c >= 3 ? worst_c - 3 : 0;
|
||||
const size_t c_end = HWY_MIN(c_begin + 7, B.Rows());
|
||||
fprintf(stderr,
|
||||
"%zu outside. Printing rows [%zu, %zu) and columns [%zu, %zu)\n",
|
||||
num_outside, r_begin, r_end, c_begin, c_end);
|
||||
for (size_t r = r_begin; r < r_end; r++) {
|
||||
const float* expected_row = c_slow_batch.Row(r);
|
||||
const float* actual_row = c_batch.Row(r);
|
||||
for (size_t c = c_begin; c < c_end; c++) {
|
||||
fprintf(stderr, "%6.3f=%6.3f ", expected_row[c], actual_row[c]);
|
||||
}
|
||||
fprintf(stderr, "\n");
|
||||
}
|
||||
|
||||
hwy::Abort(__FILE__, line,
|
||||
"(%zu,%zu): expected %f, actual %f, norm %f maxabs %f "
|
||||
"tolerance %f rel %E max_rel %E num_outside %zu\n",
|
||||
"tolerance %f worst_l1 %E\n",
|
||||
worst_r, worst_c, worst_expected, worst_actual, norm, max_abs,
|
||||
tolerance, max_rel, rel_tolerance, num_outside);
|
||||
tolerance, worst_l1);
|
||||
}
|
||||
HWY_ASSERT(hn::AllFalse(
|
||||
df, hn::IsEitherNaN(hn::Set(df, norm), hn::Set(df, max_abs))));
|
||||
}
|
||||
|
||||
// NOLINTNEXTLINE(google-readability-namespace-comments)
|
||||
|
|
|
|||
Loading…
Reference in New Issue