mirror of https://github.com/google/gemma.cpp.git
Update AssertClose for large matrices and add large matrix test
PiperOrigin-RevId: 642277221
This commit is contained in:
parent
8ec8eef524
commit
b6565e3bf6
|
|
@ -510,11 +510,14 @@ template <typename MatT>
|
|||
void AssertClose(const MatT* HWY_RESTRICT expected,
|
||||
const MatT* HWY_RESTRICT actual, size_t num) {
|
||||
for (size_t idx = 0; idx < num; idx++) {
|
||||
double expected_value = hwy::ConvertScalarTo<double>(expected[idx]);
|
||||
double actual_value = hwy::ConvertScalarTo<double>(actual[idx]);
|
||||
const double expected_value = hwy::ConvertScalarTo<double>(expected[idx]);
|
||||
const double actual_value = hwy::ConvertScalarTo<double>(actual[idx]);
|
||||
|
||||
const double magnitude = std::abs(expected_value);
|
||||
|
||||
const double tolerance =
|
||||
expected_value * 21 * 1.0 / (1ULL << hwy::MantissaBits<MatT>());
|
||||
64.0 * hwy::ConvertScalarTo<double>(hwy::Epsilon<MatT>()) *
|
||||
HWY_MAX(magnitude, 1.0);
|
||||
|
||||
if (!(expected_value - tolerance <= actual_value &&
|
||||
actual_value <= expected_value + tolerance)) {
|
||||
|
|
@ -559,7 +562,7 @@ void TestAllTiledMatMul() {
|
|||
|
||||
// large-scale test
|
||||
// TODO(philculliton): investigate rounding issues with large matrices
|
||||
// TestTiledMatMul<512, 24576, 3072, float>();
|
||||
TestTiledMatMul<512, 24576, 3072, float>();
|
||||
|
||||
// TODO(janwas): SFP
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue