From 1bd64ec350029672030161d7a7e52ada0ea08ec2 Mon Sep 17 00:00:00 2001 From: Jan Wassenberg Date: Thu, 26 Sep 2024 02:42:09 -0700 Subject: [PATCH] 1.6x speedup of MatMulSlow using compensated Dot PiperOrigin-RevId: 679063289 --- ops/matmul_test.cc | 56 ++++++++++++++++++++++------------------------ 1 file changed, 27 insertions(+), 29 deletions(-) diff --git a/ops/matmul_test.cc b/ops/matmul_test.cc index dedeb8c..6e6c674 100644 --- a/ops/matmul_test.cc +++ b/ops/matmul_test.cc @@ -38,6 +38,7 @@ #include "hwy/highway.h" // After highway.h #include "compression/compress-inl.h" +#include "ops/dot-inl.h" #include "ops/matmul-inl.h" #include "hwy/tests/test_util-inl.h" @@ -51,7 +52,7 @@ using FloatPtr = hwy::AlignedFreeUniquePtr; template >> -MatPtr GenerateMatHeap(size_t offset, hwy::ThreadPool& pool) { +MatPtr GenerateMat(size_t offset, hwy::ThreadPool& pool) { gcpp::CompressWorkingSet ws; FloatPtr content = hwy::AllocateAligned(kNum); HWY_ASSERT(content); @@ -72,7 +73,7 @@ MatPtr GenerateMatHeap(size_t offset, hwy::ThreadPool& pool) { template >> -MatPtr GenerateTransposeMatHeap(size_t offset, hwy::ThreadPool& pool) { +MatPtr GenerateTransposedMat(size_t offset, hwy::ThreadPool& pool) { gcpp::CompressWorkingSet ws; FloatPtr content = hwy::AllocateAligned(kNum); const float scale = SfpStream::kMax / (kNum + offset); @@ -93,7 +94,7 @@ MatPtr GenerateTransposeMatHeap(size_t offset, hwy::ThreadPool& pool) { template >> -MatPtr GenerateZeroMatHeap(hwy::ThreadPool& pool) { +MatPtr GenerateZeroMat(hwy::ThreadPool& pool) { gcpp::CompressWorkingSet ws; FloatPtr content = hwy::AllocateAligned(kNum); HWY_ASSERT(content); @@ -158,19 +159,20 @@ void AssertClose(size_t rows_ac, size_t cols_ab, size_t cols_c_rows_b, } } } + // Largely unoptimized; reordered innermost loops nets ~5-10X speedup. -template +template HWY_INLINE void MatMulSlow(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc, const MatTA* HWY_RESTRICT a, - const MatTB* HWY_RESTRICT b, const float scale, + const MatTB* HWY_RESTRICT b_trans, const float scale, const float* add, float* HWY_RESTRICT out) { + const hn::ScalableTag df; + const PackedSpan b_span = + MakeSpan(b_trans, cols_a_rows_b * cols_bc); for (size_t i = 0; i < rows_ac; ++i) { - for (size_t k = 0; k < cols_a_rows_b; ++k) { - for (size_t j = 0; j < cols_bc; ++j) { - const float a1 = hwy::ConvertScalarTo(a[i * cols_a_rows_b + k]); - const float b1 = hwy::ConvertScalarTo(b[k * cols_bc + j]); - out[i * cols_bc + j] += scale * a1 * b1; - } + for (size_t j = 0; j < cols_bc; ++j) { + out[i * cols_bc + j] = scale * Dot(df, b_span, j * cols_a_rows_b, + a + i * cols_a_rows_b, cols_a_rows_b); } if (add != nullptr) { for (size_t j = 0; j < cols_bc; ++j) { @@ -180,21 +182,20 @@ HWY_INLINE void MatMulSlow(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc, } } -// The above overload can handle combinations of f32 and bf16, but this one -// is required for MatTB = {SFP, NUQ}. -template +// The above overload can handle A=f32 and any B; handle A=bf16 via Decompress. +template HWY_INLINE void MatMulSlow(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc, const MatTA* HWY_RESTRICT a, - const MatTB* HWY_RESTRICT b_compr, const float scale, + const MatTB* HWY_RESTRICT b_trans, const float scale, const float* add, float* HWY_RESTRICT out) { - const size_t num_b = cols_a_rows_b * cols_bc; - FloatPtr b = hwy::AllocateAligned(num_b); - HWY_ASSERT(b); + const size_t num_a = cols_a_rows_b * rows_ac; + FloatPtr a_raw = hwy::AllocateAligned(num_a); + HWY_ASSERT(a_raw); const hn::ScalableTag df; - DecompressAndZeroPad(df, MakeSpan(b_compr, num_b), 0, b.get(), num_b); - MatMulSlow(rows_ac, cols_a_rows_b, cols_bc, a, b.get(), scale, add, out); + DecompressAndZeroPad(df, MakeSpan(a, num_a), 0, a_raw.get(), num_a); + MatMulSlow(rows_ac, cols_a_rows_b, cols_bc, a_raw.get(), b_trans, scale, add, + out); } - void PrintSpeed(const char* algo, size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc, double elapsed) { const size_t num_b = cols_a_rows_b * cols_bc; @@ -213,26 +214,23 @@ void TestMatMul(MatMulEnv& env) { TypeName()); std::unique_ptr> a = - GenerateMatHeap(0, pool); + GenerateMat(0, pool); std::unique_ptr> b_trans = - GenerateTransposeMatHeap(0, pool); + GenerateTransposedMat(0, pool); FloatPtr c = hwy::AllocateAligned(kRowsAC * kColsBC); HWY_ASSERT(c); const float scale = a->scale() * b_trans->scale(); std::unique_ptr> add; if (kAdd) { - add = GenerateMatHeap(0, pool); + add = GenerateMat(0, pool); add->set_scale(1.0f); } - std::unique_ptr> b = - GenerateMatHeap(0, pool); - HWY_ASSERT_EQ(scale, a->scale() * b->scale()); std::unique_ptr> c_slow = - GenerateZeroMatHeap(pool); + GenerateZeroMat(pool); const double start_slow = hwy::platform::Now(); - MatMulSlow(kRowsAC, kColsARowsB, kColsBC, a->data(), b->data(), scale, + MatMulSlow(kRowsAC, kColsARowsB, kColsBC, a->data(), b_trans->data(), scale, kAdd ? add->data() : nullptr, c_slow->data()); if (want_bench) { PrintSpeed("MatMulSlow", kRowsAC, kColsARowsB, kColsBC,