diff --git a/gemma/ops_test.cc b/gemma/ops_test.cc index 604190f..abb7dae 100644 --- a/gemma/ops_test.cc +++ b/gemma/ops_test.cc @@ -510,11 +510,14 @@ template void AssertClose(const MatT* HWY_RESTRICT expected, const MatT* HWY_RESTRICT actual, size_t num) { for (size_t idx = 0; idx < num; idx++) { - double expected_value = hwy::ConvertScalarTo(expected[idx]); - double actual_value = hwy::ConvertScalarTo(actual[idx]); + const double expected_value = hwy::ConvertScalarTo(expected[idx]); + const double actual_value = hwy::ConvertScalarTo(actual[idx]); + + const double magnitude = std::abs(expected_value); const double tolerance = - expected_value * 21 * 1.0 / (1ULL << hwy::MantissaBits()); + 64.0 * hwy::ConvertScalarTo(hwy::Epsilon()) * + HWY_MAX(magnitude, 1.0); if (!(expected_value - tolerance <= actual_value && actual_value <= expected_value + tolerance)) { @@ -559,7 +562,7 @@ void TestAllTiledMatMul() { // large-scale test // TODO(philculliton): investigate rounding issues with large matrices - // TestTiledMatMul<512, 24576, 3072, float>(); + TestTiledMatMul<512, 24576, 3072, float>(); // TODO(janwas): SFP }