1.6x speedup of MatMulSlow using compensated Dot

PiperOrigin-RevId: 679063289
This commit is contained in:
Jan Wassenberg 2024-09-26 02:42:09 -07:00 committed by Copybara-Service
parent 606427022c
commit 1bd64ec350
1 changed files with 27 additions and 29 deletions

View File

@ -38,6 +38,7 @@
#include "hwy/highway.h" #include "hwy/highway.h"
// After highway.h // After highway.h
#include "compression/compress-inl.h" #include "compression/compress-inl.h"
#include "ops/dot-inl.h"
#include "ops/matmul-inl.h" #include "ops/matmul-inl.h"
#include "hwy/tests/test_util-inl.h" #include "hwy/tests/test_util-inl.h"
@ -51,7 +52,7 @@ using FloatPtr = hwy::AlignedFreeUniquePtr<float[]>;
template <typename MatT, size_t kRows, size_t kCols, template <typename MatT, size_t kRows, size_t kCols,
size_t kNum = kRows * kCols, size_t kNum = kRows * kCols,
class MatPtr = std::unique_ptr<CompressedArray<MatT, kNum>>> class MatPtr = std::unique_ptr<CompressedArray<MatT, kNum>>>
MatPtr GenerateMatHeap(size_t offset, hwy::ThreadPool& pool) { MatPtr GenerateMat(size_t offset, hwy::ThreadPool& pool) {
gcpp::CompressWorkingSet ws; gcpp::CompressWorkingSet ws;
FloatPtr content = hwy::AllocateAligned<float>(kNum); FloatPtr content = hwy::AllocateAligned<float>(kNum);
HWY_ASSERT(content); HWY_ASSERT(content);
@ -72,7 +73,7 @@ MatPtr GenerateMatHeap(size_t offset, hwy::ThreadPool& pool) {
template <typename MatT, size_t kRows, size_t kCols, template <typename MatT, size_t kRows, size_t kCols,
size_t kNum = kRows * kCols, size_t kNum = kRows * kCols,
class MatPtr = std::unique_ptr<CompressedArray<MatT, kNum>>> class MatPtr = std::unique_ptr<CompressedArray<MatT, kNum>>>
MatPtr GenerateTransposeMatHeap(size_t offset, hwy::ThreadPool& pool) { MatPtr GenerateTransposedMat(size_t offset, hwy::ThreadPool& pool) {
gcpp::CompressWorkingSet ws; gcpp::CompressWorkingSet ws;
FloatPtr content = hwy::AllocateAligned<float>(kNum); FloatPtr content = hwy::AllocateAligned<float>(kNum);
const float scale = SfpStream::kMax / (kNum + offset); const float scale = SfpStream::kMax / (kNum + offset);
@ -93,7 +94,7 @@ MatPtr GenerateTransposeMatHeap(size_t offset, hwy::ThreadPool& pool) {
template <typename MatT, size_t kRows, size_t kCols, template <typename MatT, size_t kRows, size_t kCols,
size_t kNum = kRows * kCols, size_t kNum = kRows * kCols,
class MatPtr = std::unique_ptr<CompressedArray<MatT, kNum>>> class MatPtr = std::unique_ptr<CompressedArray<MatT, kNum>>>
MatPtr GenerateZeroMatHeap(hwy::ThreadPool& pool) { MatPtr GenerateZeroMat(hwy::ThreadPool& pool) {
gcpp::CompressWorkingSet ws; gcpp::CompressWorkingSet ws;
FloatPtr content = hwy::AllocateAligned<float>(kNum); FloatPtr content = hwy::AllocateAligned<float>(kNum);
HWY_ASSERT(content); HWY_ASSERT(content);
@ -158,19 +159,20 @@ void AssertClose(size_t rows_ac, size_t cols_ab, size_t cols_c_rows_b,
} }
} }
} }
// Largely unoptimized; reordered innermost loops nets ~5-10X speedup. // Largely unoptimized; reordered innermost loops nets ~5-10X speedup.
template <typename MatTA, typename MatTB, HWY_IF_T_SIZE_GT(MatTB, 1)> template <typename MatTA, typename MatTB, HWY_IF_NOT_BF16(MatTA)>
HWY_INLINE void MatMulSlow(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc, HWY_INLINE void MatMulSlow(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc,
const MatTA* HWY_RESTRICT a, const MatTA* HWY_RESTRICT a,
const MatTB* HWY_RESTRICT b, const float scale, const MatTB* HWY_RESTRICT b_trans, const float scale,
const float* add, float* HWY_RESTRICT out) { const float* add, float* HWY_RESTRICT out) {
const hn::ScalableTag<float> df;
const PackedSpan<const MatTB> b_span =
MakeSpan(b_trans, cols_a_rows_b * cols_bc);
for (size_t i = 0; i < rows_ac; ++i) { for (size_t i = 0; i < rows_ac; ++i) {
for (size_t k = 0; k < cols_a_rows_b; ++k) { for (size_t j = 0; j < cols_bc; ++j) {
for (size_t j = 0; j < cols_bc; ++j) { out[i * cols_bc + j] = scale * Dot(df, b_span, j * cols_a_rows_b,
const float a1 = hwy::ConvertScalarTo<float>(a[i * cols_a_rows_b + k]); a + i * cols_a_rows_b, cols_a_rows_b);
const float b1 = hwy::ConvertScalarTo<float>(b[k * cols_bc + j]);
out[i * cols_bc + j] += scale * a1 * b1;
}
} }
if (add != nullptr) { if (add != nullptr) {
for (size_t j = 0; j < cols_bc; ++j) { for (size_t j = 0; j < cols_bc; ++j) {
@ -180,21 +182,20 @@ HWY_INLINE void MatMulSlow(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc,
} }
} }
// The above overload can handle combinations of f32 and bf16, but this one // The above overload can handle A=f32 and any B; handle A=bf16 via Decompress.
// is required for MatTB = {SFP, NUQ}. template <typename MatTA, typename MatTB, HWY_IF_BF16(MatTA)>
template <typename MatTA, typename MatTB, HWY_IF_T_SIZE(MatTB, 1)>
HWY_INLINE void MatMulSlow(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc, HWY_INLINE void MatMulSlow(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc,
const MatTA* HWY_RESTRICT a, const MatTA* HWY_RESTRICT a,
const MatTB* HWY_RESTRICT b_compr, const float scale, const MatTB* HWY_RESTRICT b_trans, const float scale,
const float* add, float* HWY_RESTRICT out) { const float* add, float* HWY_RESTRICT out) {
const size_t num_b = cols_a_rows_b * cols_bc; const size_t num_a = cols_a_rows_b * rows_ac;
FloatPtr b = hwy::AllocateAligned<float>(num_b); FloatPtr a_raw = hwy::AllocateAligned<float>(num_a);
HWY_ASSERT(b); HWY_ASSERT(a_raw);
const hn::ScalableTag<float> df; const hn::ScalableTag<float> df;
DecompressAndZeroPad(df, MakeSpan(b_compr, num_b), 0, b.get(), num_b); DecompressAndZeroPad(df, MakeSpan(a, num_a), 0, a_raw.get(), num_a);
MatMulSlow(rows_ac, cols_a_rows_b, cols_bc, a, b.get(), scale, add, out); MatMulSlow(rows_ac, cols_a_rows_b, cols_bc, a_raw.get(), b_trans, scale, add,
out);
} }
void PrintSpeed(const char* algo, size_t rows_ac, size_t cols_a_rows_b, void PrintSpeed(const char* algo, size_t rows_ac, size_t cols_a_rows_b,
size_t cols_bc, double elapsed) { size_t cols_bc, double elapsed) {
const size_t num_b = cols_a_rows_b * cols_bc; const size_t num_b = cols_a_rows_b * cols_bc;
@ -213,26 +214,23 @@ void TestMatMul(MatMulEnv& env) {
TypeName<MatTB>()); TypeName<MatTB>());
std::unique_ptr<CompressedArray<MatTA, kRowsAC * kColsARowsB>> a = std::unique_ptr<CompressedArray<MatTA, kRowsAC * kColsARowsB>> a =
GenerateMatHeap<MatTA, kRowsAC, kColsARowsB>(0, pool); GenerateMat<MatTA, kRowsAC, kColsARowsB>(0, pool);
std::unique_ptr<CompressedArray<MatTB, kColsARowsB * kColsBC>> b_trans = std::unique_ptr<CompressedArray<MatTB, kColsARowsB * kColsBC>> b_trans =
GenerateTransposeMatHeap<MatTB, kColsARowsB, kColsBC>(0, pool); GenerateTransposedMat<MatTB, kColsARowsB, kColsBC>(0, pool);
FloatPtr c = hwy::AllocateAligned<float>(kRowsAC * kColsBC); FloatPtr c = hwy::AllocateAligned<float>(kRowsAC * kColsBC);
HWY_ASSERT(c); HWY_ASSERT(c);
const float scale = a->scale() * b_trans->scale(); const float scale = a->scale() * b_trans->scale();
std::unique_ptr<CompressedArray<float, kColsBC>> add; std::unique_ptr<CompressedArray<float, kColsBC>> add;
if (kAdd) { if (kAdd) {
add = GenerateMatHeap<float, 1, kColsBC>(0, pool); add = GenerateMat<float, 1, kColsBC>(0, pool);
add->set_scale(1.0f); add->set_scale(1.0f);
} }
std::unique_ptr<CompressedArray<MatTB, kColsARowsB * kColsBC>> b =
GenerateMatHeap<MatTB, kColsARowsB, kColsBC>(0, pool);
HWY_ASSERT_EQ(scale, a->scale() * b->scale());
std::unique_ptr<CompressedArray<float, kRowsAC * kColsBC>> c_slow = std::unique_ptr<CompressedArray<float, kRowsAC * kColsBC>> c_slow =
GenerateZeroMatHeap<float, kRowsAC, kColsBC>(pool); GenerateZeroMat<float, kRowsAC, kColsBC>(pool);
const double start_slow = hwy::platform::Now(); const double start_slow = hwy::platform::Now();
MatMulSlow(kRowsAC, kColsARowsB, kColsBC, a->data(), b->data(), scale, MatMulSlow(kRowsAC, kColsARowsB, kColsBC, a->data(), b_trans->data(), scale,
kAdd ? add->data() : nullptr, c_slow->data()); kAdd ? add->data() : nullptr, c_slow->data());
if (want_bench) { if (want_bench) {
PrintSpeed("MatMulSlow", kRowsAC, kColsARowsB, kColsBC, PrintSpeed("MatMulSlow", kRowsAC, kColsARowsB, kColsBC,