diff --git a/gemma/ops_test.cc b/gemma/ops_test.cc index 76bb9bc..604190f 100644 --- a/gemma/ops_test.cc +++ b/gemma/ops_test.cc @@ -414,8 +414,8 @@ GenerateTransposeMatHeap(size_t offset, hwy::ThreadPool& pool) { const float scale = 1.0f / kInner; pool.Run(0, kOuter, [&](const size_t i, size_t /*thread*/) { for (size_t j = 0; j < kInner; j++) { - content[i * kInner + j] = - static_cast((j * kInner + i + offset) * scale); + content[j * kOuter + i] = + static_cast((i * kInner + j + offset) * scale); } }); @@ -525,13 +525,10 @@ void AssertClose(const MatT* HWY_RESTRICT expected, } } -template +template void TestTiledMatMul() { hwy::ThreadPool pool(3); - constexpr size_t kM = 512; // 384 - constexpr size_t kN = 512; // * 5; // 6; // 768 - constexpr size_t kK = 512; // * 5; // 640 - std::unique_ptr> a = GenerateMatHeap(0, pool); std::unique_ptr> b = @@ -550,9 +547,20 @@ void TestTiledMatMul() { } void TestAllTiledMatMul() { - TestTiledMatMul(); - TestTiledMatMul(); - TestTiledMatMul(); + // medium-sized square test + TestTiledMatMul<512, 512, 512, float>(); + TestTiledMatMul<512, 512, 512, hwy::bfloat16_t>(); + TestTiledMatMul<512, 512, 512, float, hwy::bfloat16_t>(); + + // minimal non-square test + TestTiledMatMul<4, 128, 4, float>(); + TestTiledMatMul<4, 128, 4, hwy::bfloat16_t>(); + TestTiledMatMul<4, 128, 4, float, hwy::bfloat16_t>(); + + // large-scale test + // TODO(philculliton): investigate rounding issues with large matrices + // TestTiledMatMul<512, 24576, 3072, float>(); + // TODO(janwas): SFP }