diff --git a/gemma/ops_test.cc b/gemma/ops_test.cc index 5641efa..2eddb7a 100644 --- a/gemma/ops_test.cc +++ b/gemma/ops_test.cc @@ -511,8 +511,9 @@ void AssertClose(const hwy::AlignedFreeUniquePtr& a, template void TestTiledBatchMatMul() { - fprintf(stderr, "TestTiledBatchMatMul %lu, %lu, %lu, add=%d ", kM, kN, kK, - kAdd); + fprintf(stderr, + "TestTiledBatchMatMul %lu, %lu, %lu, add=%d, MatTA=%s, MatTB=%s\n", + kM, kN, kK, kAdd, typeid(MatTA).name(), typeid(MatTB).name()); hwy::ThreadPool pool(3); std::unique_ptr> a = GenerateMatHeap(0, pool); @@ -523,18 +524,25 @@ void TestTiledBatchMatMul() { std::unique_ptr> c_slow = GenerateZeroMatHeap(pool); + const double start_slow = hwy::platform::Now(); MatMulSlowBatch(kM, a->data(), b->data(), kAdd ? add->data() : nullptr, c_slow->data()); + const double slow_matmul_seconds = hwy::platform::Now() - start_slow; + fprintf(stderr, "MatMulSlowBatch took %f seconds.\n", slow_matmul_seconds); hwy::AlignedFreeUniquePtr c = hwy::AllocateAligned(kM * kK); std::unique_ptr> b_trans = GenerateTransposeMatHeap(0, pool); + + const double start_tiled = hwy::platform::Now(); if (kAdd) { MatMul_4x4_Batch_Add(kM, a->data(), b_trans->data(), c.get(), add->data(), pool); } else { MatMul_4x4_Batch(kM, a->data(), b_trans->data(), c.get(), pool); } + const double tiled_matmul_seconds = hwy::platform::Now() - start_tiled; + fprintf(stderr, "MatMul_4x4_Batch took %f seconds.\n", tiled_matmul_seconds); AssertClose(c_slow->data(), c.get(), kM * kK); }