mirror of https://github.com/google/gemma.cpp.git
1.6x speedup of MatMulSlow using compensated Dot
PiperOrigin-RevId: 679063289
This commit is contained in:
parent
606427022c
commit
1bd64ec350
|
|
@ -38,6 +38,7 @@
|
|||
#include "hwy/highway.h"
|
||||
// After highway.h
|
||||
#include "compression/compress-inl.h"
|
||||
#include "ops/dot-inl.h"
|
||||
#include "ops/matmul-inl.h"
|
||||
#include "hwy/tests/test_util-inl.h"
|
||||
|
||||
|
|
@ -51,7 +52,7 @@ using FloatPtr = hwy::AlignedFreeUniquePtr<float[]>;
|
|||
template <typename MatT, size_t kRows, size_t kCols,
|
||||
size_t kNum = kRows * kCols,
|
||||
class MatPtr = std::unique_ptr<CompressedArray<MatT, kNum>>>
|
||||
MatPtr GenerateMatHeap(size_t offset, hwy::ThreadPool& pool) {
|
||||
MatPtr GenerateMat(size_t offset, hwy::ThreadPool& pool) {
|
||||
gcpp::CompressWorkingSet ws;
|
||||
FloatPtr content = hwy::AllocateAligned<float>(kNum);
|
||||
HWY_ASSERT(content);
|
||||
|
|
@ -72,7 +73,7 @@ MatPtr GenerateMatHeap(size_t offset, hwy::ThreadPool& pool) {
|
|||
template <typename MatT, size_t kRows, size_t kCols,
|
||||
size_t kNum = kRows * kCols,
|
||||
class MatPtr = std::unique_ptr<CompressedArray<MatT, kNum>>>
|
||||
MatPtr GenerateTransposeMatHeap(size_t offset, hwy::ThreadPool& pool) {
|
||||
MatPtr GenerateTransposedMat(size_t offset, hwy::ThreadPool& pool) {
|
||||
gcpp::CompressWorkingSet ws;
|
||||
FloatPtr content = hwy::AllocateAligned<float>(kNum);
|
||||
const float scale = SfpStream::kMax / (kNum + offset);
|
||||
|
|
@ -93,7 +94,7 @@ MatPtr GenerateTransposeMatHeap(size_t offset, hwy::ThreadPool& pool) {
|
|||
template <typename MatT, size_t kRows, size_t kCols,
|
||||
size_t kNum = kRows * kCols,
|
||||
class MatPtr = std::unique_ptr<CompressedArray<MatT, kNum>>>
|
||||
MatPtr GenerateZeroMatHeap(hwy::ThreadPool& pool) {
|
||||
MatPtr GenerateZeroMat(hwy::ThreadPool& pool) {
|
||||
gcpp::CompressWorkingSet ws;
|
||||
FloatPtr content = hwy::AllocateAligned<float>(kNum);
|
||||
HWY_ASSERT(content);
|
||||
|
|
@ -158,19 +159,20 @@ void AssertClose(size_t rows_ac, size_t cols_ab, size_t cols_c_rows_b,
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Largely unoptimized; reordered innermost loops nets ~5-10X speedup.
|
||||
template <typename MatTA, typename MatTB, HWY_IF_T_SIZE_GT(MatTB, 1)>
|
||||
template <typename MatTA, typename MatTB, HWY_IF_NOT_BF16(MatTA)>
|
||||
HWY_INLINE void MatMulSlow(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc,
|
||||
const MatTA* HWY_RESTRICT a,
|
||||
const MatTB* HWY_RESTRICT b, const float scale,
|
||||
const MatTB* HWY_RESTRICT b_trans, const float scale,
|
||||
const float* add, float* HWY_RESTRICT out) {
|
||||
const hn::ScalableTag<float> df;
|
||||
const PackedSpan<const MatTB> b_span =
|
||||
MakeSpan(b_trans, cols_a_rows_b * cols_bc);
|
||||
for (size_t i = 0; i < rows_ac; ++i) {
|
||||
for (size_t k = 0; k < cols_a_rows_b; ++k) {
|
||||
for (size_t j = 0; j < cols_bc; ++j) {
|
||||
const float a1 = hwy::ConvertScalarTo<float>(a[i * cols_a_rows_b + k]);
|
||||
const float b1 = hwy::ConvertScalarTo<float>(b[k * cols_bc + j]);
|
||||
out[i * cols_bc + j] += scale * a1 * b1;
|
||||
}
|
||||
for (size_t j = 0; j < cols_bc; ++j) {
|
||||
out[i * cols_bc + j] = scale * Dot(df, b_span, j * cols_a_rows_b,
|
||||
a + i * cols_a_rows_b, cols_a_rows_b);
|
||||
}
|
||||
if (add != nullptr) {
|
||||
for (size_t j = 0; j < cols_bc; ++j) {
|
||||
|
|
@ -180,21 +182,20 @@ HWY_INLINE void MatMulSlow(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc,
|
|||
}
|
||||
}
|
||||
|
||||
// The above overload can handle combinations of f32 and bf16, but this one
|
||||
// is required for MatTB = {SFP, NUQ}.
|
||||
template <typename MatTA, typename MatTB, HWY_IF_T_SIZE(MatTB, 1)>
|
||||
// The above overload can handle A=f32 and any B; handle A=bf16 via Decompress.
|
||||
template <typename MatTA, typename MatTB, HWY_IF_BF16(MatTA)>
|
||||
HWY_INLINE void MatMulSlow(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc,
|
||||
const MatTA* HWY_RESTRICT a,
|
||||
const MatTB* HWY_RESTRICT b_compr, const float scale,
|
||||
const MatTB* HWY_RESTRICT b_trans, const float scale,
|
||||
const float* add, float* HWY_RESTRICT out) {
|
||||
const size_t num_b = cols_a_rows_b * cols_bc;
|
||||
FloatPtr b = hwy::AllocateAligned<float>(num_b);
|
||||
HWY_ASSERT(b);
|
||||
const size_t num_a = cols_a_rows_b * rows_ac;
|
||||
FloatPtr a_raw = hwy::AllocateAligned<float>(num_a);
|
||||
HWY_ASSERT(a_raw);
|
||||
const hn::ScalableTag<float> df;
|
||||
DecompressAndZeroPad(df, MakeSpan(b_compr, num_b), 0, b.get(), num_b);
|
||||
MatMulSlow(rows_ac, cols_a_rows_b, cols_bc, a, b.get(), scale, add, out);
|
||||
DecompressAndZeroPad(df, MakeSpan(a, num_a), 0, a_raw.get(), num_a);
|
||||
MatMulSlow(rows_ac, cols_a_rows_b, cols_bc, a_raw.get(), b_trans, scale, add,
|
||||
out);
|
||||
}
|
||||
|
||||
void PrintSpeed(const char* algo, size_t rows_ac, size_t cols_a_rows_b,
|
||||
size_t cols_bc, double elapsed) {
|
||||
const size_t num_b = cols_a_rows_b * cols_bc;
|
||||
|
|
@ -213,26 +214,23 @@ void TestMatMul(MatMulEnv& env) {
|
|||
TypeName<MatTB>());
|
||||
|
||||
std::unique_ptr<CompressedArray<MatTA, kRowsAC * kColsARowsB>> a =
|
||||
GenerateMatHeap<MatTA, kRowsAC, kColsARowsB>(0, pool);
|
||||
GenerateMat<MatTA, kRowsAC, kColsARowsB>(0, pool);
|
||||
std::unique_ptr<CompressedArray<MatTB, kColsARowsB * kColsBC>> b_trans =
|
||||
GenerateTransposeMatHeap<MatTB, kColsARowsB, kColsBC>(0, pool);
|
||||
GenerateTransposedMat<MatTB, kColsARowsB, kColsBC>(0, pool);
|
||||
FloatPtr c = hwy::AllocateAligned<float>(kRowsAC * kColsBC);
|
||||
HWY_ASSERT(c);
|
||||
|
||||
const float scale = a->scale() * b_trans->scale();
|
||||
std::unique_ptr<CompressedArray<float, kColsBC>> add;
|
||||
if (kAdd) {
|
||||
add = GenerateMatHeap<float, 1, kColsBC>(0, pool);
|
||||
add = GenerateMat<float, 1, kColsBC>(0, pool);
|
||||
add->set_scale(1.0f);
|
||||
}
|
||||
|
||||
std::unique_ptr<CompressedArray<MatTB, kColsARowsB * kColsBC>> b =
|
||||
GenerateMatHeap<MatTB, kColsARowsB, kColsBC>(0, pool);
|
||||
HWY_ASSERT_EQ(scale, a->scale() * b->scale());
|
||||
std::unique_ptr<CompressedArray<float, kRowsAC * kColsBC>> c_slow =
|
||||
GenerateZeroMatHeap<float, kRowsAC, kColsBC>(pool);
|
||||
GenerateZeroMat<float, kRowsAC, kColsBC>(pool);
|
||||
const double start_slow = hwy::platform::Now();
|
||||
MatMulSlow(kRowsAC, kColsARowsB, kColsBC, a->data(), b->data(), scale,
|
||||
MatMulSlow(kRowsAC, kColsARowsB, kColsBC, a->data(), b_trans->data(), scale,
|
||||
kAdd ? add->data() : nullptr, c_slow->data());
|
||||
if (want_bench) {
|
||||
PrintSpeed("MatMulSlow", kRowsAC, kColsARowsB, kColsBC,
|
||||
|
|
|
|||
Loading…
Reference in New Issue