mirror of https://github.com/google/gemma.cpp.git
Record time measurements in MatMul tests.
PiperOrigin-RevId: 651060711
This commit is contained in:
parent
ee6e017a77
commit
960ff4b4ec
|
|
@ -511,8 +511,9 @@ void AssertClose(const hwy::AlignedFreeUniquePtr<float[]>& a,
|
|||
template <size_t kM, size_t kN, size_t kK, bool kAdd, typename MatTA,
|
||||
typename MatTB = MatTA>
|
||||
void TestTiledBatchMatMul() {
|
||||
fprintf(stderr, "TestTiledBatchMatMul %lu, %lu, %lu, add=%d ", kM, kN, kK,
|
||||
kAdd);
|
||||
fprintf(stderr,
|
||||
"TestTiledBatchMatMul %lu, %lu, %lu, add=%d, MatTA=%s, MatTB=%s\n",
|
||||
kM, kN, kK, kAdd, typeid(MatTA).name(), typeid(MatTB).name());
|
||||
hwy::ThreadPool pool(3);
|
||||
std::unique_ptr<CompressedArray<MatTA, kM * kN>> a =
|
||||
GenerateMatHeap<MatTA, kM, kN>(0, pool);
|
||||
|
|
@ -523,18 +524,25 @@ void TestTiledBatchMatMul() {
|
|||
std::unique_ptr<CompressedArray<float, kM * kK>> c_slow =
|
||||
GenerateZeroMatHeap<float, kM, kK>(pool);
|
||||
|
||||
const double start_slow = hwy::platform::Now();
|
||||
MatMulSlowBatch<kN, kK>(kM, a->data(), b->data(),
|
||||
kAdd ? add->data() : nullptr, c_slow->data());
|
||||
const double slow_matmul_seconds = hwy::platform::Now() - start_slow;
|
||||
fprintf(stderr, "MatMulSlowBatch took %f seconds.\n", slow_matmul_seconds);
|
||||
|
||||
hwy::AlignedFreeUniquePtr<float[]> c = hwy::AllocateAligned<float>(kM * kK);
|
||||
std::unique_ptr<CompressedArray<MatTB, kN * kK>> b_trans =
|
||||
GenerateTransposeMatHeap<MatTB, kN, kK>(0, pool);
|
||||
|
||||
const double start_tiled = hwy::platform::Now();
|
||||
if (kAdd) {
|
||||
MatMul_4x4_Batch_Add<kN, kK, kAdd>(kM, a->data(), b_trans->data(), c.get(),
|
||||
add->data(), pool);
|
||||
} else {
|
||||
MatMul_4x4_Batch<kN, kK>(kM, a->data(), b_trans->data(), c.get(), pool);
|
||||
}
|
||||
const double tiled_matmul_seconds = hwy::platform::Now() - start_tiled;
|
||||
fprintf(stderr, "MatMul_4x4_Batch took %f seconds.\n", tiled_matmul_seconds);
|
||||
|
||||
AssertClose(c_slow->data(), c.get(), kM * kK);
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue