mirror of https://github.com/google/gemma.cpp.git
Fix compilation and tests for gcc
This commit is contained in:
parent
36e4d8bbfe
commit
7e639856da
|
|
@ -34,7 +34,7 @@
|
|||
|
||||
// Include guard for (potentially) SIMD code.
|
||||
#if defined(THIRD_PARTY_GEMMA_CPP_BACKWARD_TOGGLE) == defined(HWY_TARGET_TOGGLE)
|
||||
#ifdef THIRD_PARTY_GEMMA_CPP_BACkWARD_TOGGLE
|
||||
#ifdef THIRD_PARTY_GEMMA_CPP_BACKWARD_TOGGLE
|
||||
#undef THIRD_PARTY_GEMMA_CPP_BACKWARD_TOGGLE
|
||||
#else
|
||||
#define THIRD_PARTY_GEMMA_CPP_BACKWARD_TOGGLE
|
||||
|
|
|
|||
|
|
@ -92,8 +92,8 @@ void TestMatMulVJP() {
|
|||
memset(&grad_scalar, 0, sizeof(grad_scalar));
|
||||
MatMulVJPT(weights.data(), x.data(), dy.data(), grad_scalar.data(),
|
||||
dx_scalar.data(), kRows, kCols, kTokens);
|
||||
TestNear(dx, dx_scalar, 0, 0, __LINE__);
|
||||
TestNear(grad, grad_scalar, 0, 0, __LINE__);
|
||||
TestNear(dx, dx_scalar, 5e-5, 5e-5, __LINE__);
|
||||
TestNear(grad, grad_scalar, 5e-5, 5e-5, __LINE__);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -137,8 +137,8 @@ void TestMultiHeadMatMulVJP() {
|
|||
memset(&grad_scalar, 0, sizeof(grad_scalar));
|
||||
MultiHeadMatMulVJPT(weights.data(), x.data(), dy.data(), grad_scalar.data(),
|
||||
dx_scalar.data(), kHeads, kRows, kCols, kTokens);
|
||||
TestNear(dx, dx_scalar, 0, 0, __LINE__);
|
||||
TestNear(grad, grad_scalar, 0, 0, __LINE__);
|
||||
TestNear(dx, dx_scalar, 5e-5, 5e-5, __LINE__);
|
||||
TestNear(grad, grad_scalar, 5e-5, 5e-5, __LINE__);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue