diff --git a/gemma/backward-inl.h b/gemma/backward-inl.h index 1f55ad9..53cba96 100644 --- a/gemma/backward-inl.h +++ b/gemma/backward-inl.h @@ -34,7 +34,7 @@ // Include guard for (potentially) SIMD code. #if defined(THIRD_PARTY_GEMMA_CPP_BACKWARD_TOGGLE) == defined(HWY_TARGET_TOGGLE) -#ifdef THIRD_PARTY_GEMMA_CPP_BACkWARD_TOGGLE +#ifdef THIRD_PARTY_GEMMA_CPP_BACKWARD_TOGGLE #undef THIRD_PARTY_GEMMA_CPP_BACKWARD_TOGGLE #else #define THIRD_PARTY_GEMMA_CPP_BACKWARD_TOGGLE diff --git a/gemma/backward_test.cc b/gemma/backward_test.cc index 9b80936..50918a0 100644 --- a/gemma/backward_test.cc +++ b/gemma/backward_test.cc @@ -92,8 +92,8 @@ void TestMatMulVJP() { memset(&grad_scalar, 0, sizeof(grad_scalar)); MatMulVJPT(weights.data(), x.data(), dy.data(), grad_scalar.data(), dx_scalar.data(), kRows, kCols, kTokens); - TestNear(dx, dx_scalar, 0, 0, __LINE__); - TestNear(grad, grad_scalar, 0, 0, __LINE__); + TestNear(dx, dx_scalar, 5e-5, 5e-5, __LINE__); + TestNear(grad, grad_scalar, 5e-5, 5e-5, __LINE__); } } @@ -137,8 +137,8 @@ void TestMultiHeadMatMulVJP() { memset(&grad_scalar, 0, sizeof(grad_scalar)); MultiHeadMatMulVJPT(weights.data(), x.data(), dy.data(), grad_scalar.data(), dx_scalar.data(), kHeads, kRows, kCols, kTokens); - TestNear(dx, dx_scalar, 0, 0, __LINE__); - TestNear(grad, grad_scalar, 0, 0, __LINE__); + TestNear(dx, dx_scalar, 5e-5, 5e-5, __LINE__); + TestNear(grad, grad_scalar, 5e-5, 5e-5, __LINE__); } }