Fix compilation and tests for gcc

This commit is contained in:
Zoltan Szabadka 2024-06-04 07:39:54 +00:00
parent 36e4d8bbfe
commit 7e639856da
2 changed files with 5 additions and 5 deletions

View File

@ -34,7 +34,7 @@
// Include guard for (potentially) SIMD code.
#if defined(THIRD_PARTY_GEMMA_CPP_BACKWARD_TOGGLE) == defined(HWY_TARGET_TOGGLE)
#ifdef THIRD_PARTY_GEMMA_CPP_BACkWARD_TOGGLE
#ifdef THIRD_PARTY_GEMMA_CPP_BACKWARD_TOGGLE
#undef THIRD_PARTY_GEMMA_CPP_BACKWARD_TOGGLE
#else
#define THIRD_PARTY_GEMMA_CPP_BACKWARD_TOGGLE

View File

@ -92,8 +92,8 @@ void TestMatMulVJP() {
memset(&grad_scalar, 0, sizeof(grad_scalar));
MatMulVJPT(weights.data(), x.data(), dy.data(), grad_scalar.data(),
dx_scalar.data(), kRows, kCols, kTokens);
TestNear(dx, dx_scalar, 0, 0, __LINE__);
TestNear(grad, grad_scalar, 0, 0, __LINE__);
TestNear(dx, dx_scalar, 5e-5, 5e-5, __LINE__);
TestNear(grad, grad_scalar, 5e-5, 5e-5, __LINE__);
}
}
@ -137,8 +137,8 @@ void TestMultiHeadMatMulVJP() {
memset(&grad_scalar, 0, sizeof(grad_scalar));
MultiHeadMatMulVJPT(weights.data(), x.data(), dy.data(), grad_scalar.data(),
dx_scalar.data(), kHeads, kRows, kCols, kTokens);
TestNear(dx, dx_scalar, 0, 0, __LINE__);
TestNear(grad, grad_scalar, 0, 0, __LINE__);
TestNear(dx, dx_scalar, 5e-5, 5e-5, __LINE__);
TestNear(grad, grad_scalar, 5e-5, 5e-5, __LINE__);
}
}