Record time measurements in MatMul tests.

PiperOrigin-RevId: 651060711
This commit is contained in:
Andrey Vlasov 2024-07-10 10:04:06 -07:00 committed by Copybara-Service
parent ee6e017a77
commit 960ff4b4ec
1 changed file with 10 additions and 2 deletions

View File

@ -511,8 +511,9 @@ void AssertClose(const hwy::AlignedFreeUniquePtr<float[]>& a,
template <size_t kM, size_t kN, size_t kK, bool kAdd, typename MatTA,
typename MatTB = MatTA>
void TestTiledBatchMatMul() {
fprintf(stderr, "TestTiledBatchMatMul %lu, %lu, %lu, add=%d ", kM, kN, kK,
kAdd);
fprintf(stderr,
"TestTiledBatchMatMul %lu, %lu, %lu, add=%d, MatTA=%s, MatTB=%s\n",
kM, kN, kK, kAdd, typeid(MatTA).name(), typeid(MatTB).name());
hwy::ThreadPool pool(3);
std::unique_ptr<CompressedArray<MatTA, kM * kN>> a =
GenerateMatHeap<MatTA, kM, kN>(0, pool);
@ -523,18 +524,25 @@ void TestTiledBatchMatMul() {
std::unique_ptr<CompressedArray<float, kM * kK>> c_slow =
GenerateZeroMatHeap<float, kM, kK>(pool);
const double start_slow = hwy::platform::Now();
MatMulSlowBatch<kN, kK>(kM, a->data(), b->data(),
kAdd ? add->data() : nullptr, c_slow->data());
const double slow_matmul_seconds = hwy::platform::Now() - start_slow;
fprintf(stderr, "MatMulSlowBatch took %f seconds.\n", slow_matmul_seconds);
hwy::AlignedFreeUniquePtr<float[]> c = hwy::AllocateAligned<float>(kM * kK);
std::unique_ptr<CompressedArray<MatTB, kN * kK>> b_trans =
GenerateTransposeMatHeap<MatTB, kN, kK>(0, pool);
const double start_tiled = hwy::platform::Now();
if (kAdd) {
MatMul_4x4_Batch_Add<kN, kK, kAdd>(kM, a->data(), b_trans->data(), c.get(),
add->data(), pool);
} else {
MatMul_4x4_Batch<kN, kK>(kM, a->data(), b_trans->data(), c.get(), pool);
}
const double tiled_matmul_seconds = hwy::platform::Now() - start_tiled;
fprintf(stderr, "MatMul_4x4_Batch took %f seconds.\n", tiled_matmul_seconds);
AssertClose(c_slow->data(), c.get(), kM * kK);
}