1.6x speedup of MatMulSlow using compensated Dot

PiperOrigin-RevId: 679063289
Author: Jan Wassenberg, 2024-09-26 02:42:09 -07:00 (committed by Copybara-Service)
parent 606427022c
commit 1bd64ec350
1 changed file with 27 additions and 29 deletions

@@ -38,6 +38,7 @@
 #include "hwy/highway.h"
 // After highway.h
 #include "compression/compress-inl.h"
+#include "ops/dot-inl.h"
 #include "ops/matmul-inl.h"
 #include "hwy/tests/test_util-inl.h"
@@ -51,7 +52,7 @@ using FloatPtr = hwy::AlignedFreeUniquePtr<float[]>;
 template <typename MatT, size_t kRows, size_t kCols,
           size_t kNum = kRows * kCols,
           class MatPtr = std::unique_ptr<CompressedArray<MatT, kNum>>>
-MatPtr GenerateMatHeap(size_t offset, hwy::ThreadPool& pool) {
+MatPtr GenerateMat(size_t offset, hwy::ThreadPool& pool) {
   gcpp::CompressWorkingSet ws;
   FloatPtr content = hwy::AllocateAligned<float>(kNum);
   HWY_ASSERT(content);
@@ -72,7 +73,7 @@ MatPtr GenerateMatHeap(size_t offset, hwy::ThreadPool& pool) {
 template <typename MatT, size_t kRows, size_t kCols,
           size_t kNum = kRows * kCols,
           class MatPtr = std::unique_ptr<CompressedArray<MatT, kNum>>>
-MatPtr GenerateTransposeMatHeap(size_t offset, hwy::ThreadPool& pool) {
+MatPtr GenerateTransposedMat(size_t offset, hwy::ThreadPool& pool) {
   gcpp::CompressWorkingSet ws;
   FloatPtr content = hwy::AllocateAligned<float>(kNum);
   const float scale = SfpStream::kMax / (kNum + offset);
@@ -93,7 +94,7 @@ MatPtr GenerateTransposeMatHeap(size_t offset, hwy::ThreadPool& pool) {
 template <typename MatT, size_t kRows, size_t kCols,
           size_t kNum = kRows * kCols,
           class MatPtr = std::unique_ptr<CompressedArray<MatT, kNum>>>
-MatPtr GenerateZeroMatHeap(hwy::ThreadPool& pool) {
+MatPtr GenerateZeroMat(hwy::ThreadPool& pool) {
   gcpp::CompressWorkingSet ws;
   FloatPtr content = hwy::AllocateAligned<float>(kNum);
   HWY_ASSERT(content);
@@ -158,19 +159,20 @@ void AssertClose(size_t rows_ac, size_t cols_ab, size_t cols_c_rows_b,
     }
   }
 }
-// Largely unoptimized; reordered innermost loops nets ~5-10X speedup.
-template <typename MatTA, typename MatTB, HWY_IF_T_SIZE_GT(MatTB, 1)>
+template <typename MatTA, typename MatTB, HWY_IF_NOT_BF16(MatTA)>
 HWY_INLINE void MatMulSlow(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc,
                            const MatTA* HWY_RESTRICT a,
-                           const MatTB* HWY_RESTRICT b, const float scale,
+                           const MatTB* HWY_RESTRICT b_trans, const float scale,
                            const float* add, float* HWY_RESTRICT out) {
+  const hn::ScalableTag<float> df;
+  const PackedSpan<const MatTB> b_span =
+      MakeSpan(b_trans, cols_a_rows_b * cols_bc);
   for (size_t i = 0; i < rows_ac; ++i) {
-    for (size_t k = 0; k < cols_a_rows_b; ++k) {
-      for (size_t j = 0; j < cols_bc; ++j) {
-        const float a1 = hwy::ConvertScalarTo<float>(a[i * cols_a_rows_b + k]);
-        const float b1 = hwy::ConvertScalarTo<float>(b[k * cols_bc + j]);
-        out[i * cols_bc + j] += scale * a1 * b1;
-      }
+    for (size_t j = 0; j < cols_bc; ++j) {
+      out[i * cols_bc + j] = scale * Dot(df, b_span, j * cols_a_rows_b,
+                                         a + i * cols_a_rows_b, cols_a_rows_b);
     }
     if (add != nullptr) {
       for (size_t j = 0; j < cols_bc; ++j) {
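Note on the hunk above: because b_trans is stored transposed, the cols_a_rows_b values contributing to output (i, j) are contiguous starting at offset j * cols_a_rows_b, so each output element becomes a single call to the vectorized, compensated Dot from ops/dot-inl.h. That implementation is not shown in this diff; as a rough guide to what "compensated" means here, a scalar Neumaier-style dot product looks like the sketch below (names are illustrative; the real Dot is SIMD and may additionally compensate the products themselves, e.g. via FMA):

#include <cmath>
#include <cstddef>

// Scalar sketch of a compensated (Neumaier) dot product: track the
// low-order bits each addition rounds away and fold them back at the end.
// Illustration of the general technique, not the code in ops/dot-inl.h.
float CompensatedDot(const float* a, const float* b, size_t num) {
  float sum = 0.0f;   // running sum of products
  float comp = 0.0f;  // accumulated rounding error of the additions
  for (size_t i = 0; i < num; ++i) {
    const float prod = a[i] * b[i];
    const float t = sum + prod;
    if (std::fabs(sum) >= std::fabs(prod)) {
      comp += (sum - t) + prod;  // error term when |sum| dominates
    } else {
      comp += (prod - t) + sum;  // error term when |prod| dominates
    }
    sum = t;
  }
  return sum + comp;
}

Compensation is presumably worthwhile here because MatMulSlow produces the reference values that AssertClose checks the fast MatMul against, and a plain scalar accumulation would carry rounding error of its own.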
@@ -180,21 +182,20 @@ HWY_INLINE void MatMulSlow(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc,
   }
 }
-// The above overload can handle combinations of f32 and bf16, but this one
-// is required for MatTB = {SFP, NUQ}.
-template <typename MatTA, typename MatTB, HWY_IF_T_SIZE(MatTB, 1)>
+// The above overload can handle A=f32 and any B; handle A=bf16 via Decompress.
+template <typename MatTA, typename MatTB, HWY_IF_BF16(MatTA)>
 HWY_INLINE void MatMulSlow(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc,
                            const MatTA* HWY_RESTRICT a,
-                           const MatTB* HWY_RESTRICT b_compr, const float scale,
+                           const MatTB* HWY_RESTRICT b_trans, const float scale,
                            const float* add, float* HWY_RESTRICT out) {
-  const size_t num_b = cols_a_rows_b * cols_bc;
-  FloatPtr b = hwy::AllocateAligned<float>(num_b);
-  HWY_ASSERT(b);
+  const size_t num_a = cols_a_rows_b * rows_ac;
+  FloatPtr a_raw = hwy::AllocateAligned<float>(num_a);
+  HWY_ASSERT(a_raw);
   const hn::ScalableTag<float> df;
-  DecompressAndZeroPad(df, MakeSpan(b_compr, num_b), 0, b.get(), num_b);
-  MatMulSlow(rows_ac, cols_a_rows_b, cols_bc, a, b.get(), scale, add, out);
+  DecompressAndZeroPad(df, MakeSpan(a, num_a), 0, a_raw.get(), num_a);
+  MatMulSlow(rows_ac, cols_a_rows_b, cols_bc, a_raw.get(), b_trans, scale, add,
+             out);
 }
 void PrintSpeed(const char* algo, size_t rows_ac, size_t cols_a_rows_b,
                 size_t cols_bc, double elapsed) {
   const size_t num_b = cols_a_rows_b * cols_bc;
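The bf16 overload now decompresses A instead of B: per the new comment, the overload above already handles any packed B through Dot, so only a bf16 A needs widening to f32 before forwarding. The widening itself is cheap because bfloat16 is the upper 16 bits of an IEEE-754 binary32; a scalar sketch of the per-element step that DecompressAndZeroPad performs in vectorized (and zero-padded) form:

#include <cstdint>
#include <cstring>

// Scalar sketch of bf16 -> f32: shift the stored 16 bits into the top of a
// 32-bit word (the dropped low mantissa bits become zero), then bit-cast.
float BF16ToF32(uint16_t bits) {
  const uint32_t widened = static_cast<uint32_t>(bits) << 16;
  float f;
  std::memcpy(&f, &widened, sizeof(f));  // type-pun without UB
  return f;
}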
@@ -213,26 +214,23 @@ void TestMatMul(MatMulEnv& env) {
           TypeName<MatTB>());
   std::unique_ptr<CompressedArray<MatTA, kRowsAC * kColsARowsB>> a =
-      GenerateMatHeap<MatTA, kRowsAC, kColsARowsB>(0, pool);
+      GenerateMat<MatTA, kRowsAC, kColsARowsB>(0, pool);
   std::unique_ptr<CompressedArray<MatTB, kColsARowsB * kColsBC>> b_trans =
-      GenerateTransposeMatHeap<MatTB, kColsARowsB, kColsBC>(0, pool);
+      GenerateTransposedMat<MatTB, kColsARowsB, kColsBC>(0, pool);
   FloatPtr c = hwy::AllocateAligned<float>(kRowsAC * kColsBC);
   HWY_ASSERT(c);
   const float scale = a->scale() * b_trans->scale();
   std::unique_ptr<CompressedArray<float, kColsBC>> add;
   if (kAdd) {
-    add = GenerateMatHeap<float, 1, kColsBC>(0, pool);
+    add = GenerateMat<float, 1, kColsBC>(0, pool);
     add->set_scale(1.0f);
   }
-  std::unique_ptr<CompressedArray<MatTB, kColsARowsB * kColsBC>> b =
-      GenerateMatHeap<MatTB, kColsARowsB, kColsBC>(0, pool);
-  HWY_ASSERT_EQ(scale, a->scale() * b->scale());
   std::unique_ptr<CompressedArray<float, kRowsAC * kColsBC>> c_slow =
-      GenerateZeroMatHeap<float, kRowsAC, kColsBC>(pool);
+      GenerateZeroMat<float, kRowsAC, kColsBC>(pool);
   const double start_slow = hwy::platform::Now();
-  MatMulSlow(kRowsAC, kColsARowsB, kColsBC, a->data(), b->data(), scale,
+  MatMulSlow(kRowsAC, kColsARowsB, kColsBC, a->data(), b_trans->data(), scale,
              kAdd ? add->data() : nullptr, c_slow->data());
   if (want_bench) {
     PrintSpeed("MatMulSlow", kRowsAC, kColsARowsB, kColsBC,
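For context on the benchmark numbers: the timed slow path multiplies a (kRowsAC x kColsARowsB) matrix by a (kColsARowsB x kColsBC) matrix, which costs one multiply and one add per term, i.e. 2 * kRowsAC * kColsARowsB * kColsBC FLOPs. PrintSpeed's body lies mostly outside this diff, so the sketch below is only an assumed shape for such a report, not the test's actual metric:

#include <cstddef>
#include <cstdio>

// Hypothetical reporter with PrintSpeed's signature: convert the elapsed
// seconds of one matmul into GFLOPS (2 FLOPs per (i, j, k) term).
void PrintGflops(const char* algo, size_t rows_ac, size_t cols_a_rows_b,
                 size_t cols_bc, double elapsed) {
  const double flops = 2.0 * rows_ac * cols_a_rows_b * cols_bc;
  std::fprintf(stderr, "%s: %.2f GFLOPS\n", algo, flops / elapsed * 1e-9);
}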