From c616abe6284a0d3cf52934088724abdbc4252476 Mon Sep 17 00:00:00 2001
From: Phil Culliton
Date: Wed, 29 May 2024 13:01:58 -0700
Subject: [PATCH] Unrolled / tiled 4x4 MatMul

PiperOrigin-RevId: 638384686
---
 gemma/ops.h       | 183 ++++++++++++++++++++++++++++++++++++++++++++--
 gemma/ops_test.cc |  80 ++++++++++++++++++--
 2 files changed, 250 insertions(+), 13 deletions(-)

diff --git a/gemma/ops.h b/gemma/ops.h
index 90b0d13..1ea800b 100644
--- a/gemma/ops.h
+++ b/gemma/ops.h
@@ -22,6 +22,7 @@
 #include
 #include
+#include
 #include
 #include <type_traits>  // std::enable_if_t
@@ -93,11 +94,179 @@ HWY_INLINE constexpr size_t RowsPerStrip() {
   return kRowsPerStrip;
 }
 
+// Processes a single 4x4 tile of A x B. Shared between static and dynamic
+// versions.
+template <size_t kColsA, typename MatT>
+HWY_INLINE void GEMM_4x4_Tile(const MatT* HWY_RESTRICT A,
+                              const MatT* HWY_RESTRICT B, MatT* HWY_RESTRICT C,
+                              size_t tile_num, const int xtiles, const int lda,
+                              const int ldb, const int ldc) {
+  constexpr int RM = 4;
+  constexpr int RN = 4;
+
+  // Calculate tile start coords (row ii, column jj) in C.
+  int ii = tile_num / xtiles * RM;
+  int jj = tile_num % xtiles * RN;
+
+  const hn::ScalableTag<MatT> d;
+  const size_t N = Lanes(d);
+  using V = hn::Vec<decltype(d)>;
+
+  V c00 = hn::Zero(d);
+  V c01 = hn::Zero(d);
+  V c02 = hn::Zero(d);
+  V c03 = hn::Zero(d);
+
+  V c10 = hn::Zero(d);
+  V c11 = hn::Zero(d);
+  V c12 = hn::Zero(d);
+  V c13 = hn::Zero(d);
+
+  V c20 = hn::Zero(d);
+  V c21 = hn::Zero(d);
+  V c22 = hn::Zero(d);
+  V c23 = hn::Zero(d);
+
+  V c30 = hn::Zero(d);
+  V c31 = hn::Zero(d);
+  V c32 = hn::Zero(d);
+  V c33 = hn::Zero(d);
+
+  // Steps down the rows of A and B, and across the width (kColsA) in steps of
+  // N (Lanes()). Accumulates into the cache vectors. hn::ReduceSum() is
+  // called on each of the cache vectors to sum the partial sums into C.
+  for (size_t l = 0; l < kColsA; l += N) {
+    V k0 = hn::LoadU(d, B + ldb * (jj + 0) + l);
+    V k1 = hn::LoadU(d, B + ldb * (jj + 1) + l);
+    V k2 = hn::LoadU(d, B + ldb * (jj + 2) + l);
+    V k3 = hn::LoadU(d, B + ldb * (jj + 3) + l);
+
+    V a0 = hn::LoadU(d, A + lda * (ii + 0) + l);
+    c00 = hn::MulAdd(a0, k0, c00);
+    c01 = hn::MulAdd(a0, k1, c01);
+    c02 = hn::MulAdd(a0, k2, c02);
+    c03 = hn::MulAdd(a0, k3, c03);
+
+    V a1 = hn::LoadU(d, A + lda * (ii + 1) + l);
+    c10 = hn::MulAdd(a1, k0, c10);
+    c11 = hn::MulAdd(a1, k1, c11);
+    c12 = hn::MulAdd(a1, k2, c12);
+    c13 = hn::MulAdd(a1, k3, c13);
+
+    V a2 = hn::LoadU(d, A + lda * (ii + 2) + l);
+    c20 = hn::MulAdd(a2, k0, c20);
+    c21 = hn::MulAdd(a2, k1, c21);
+    c22 = hn::MulAdd(a2, k2, c22);
+    c23 = hn::MulAdd(a2, k3, c23);
+
+    V a3 = hn::LoadU(d, A + lda * (ii + 3) + l);
+    c30 = hn::MulAdd(a3, k0, c30);
+    c31 = hn::MulAdd(a3, k1, c31);
+    c32 = hn::MulAdd(a3, k2, c32);
+    c33 = hn::MulAdd(a3, k3, c33);
+  }
+
+  C[ldc * (ii + 0) + (jj + 0)] = hn::ReduceSum(d, c00);
+  C[ldc * (ii + 0) + (jj + 1)] = hn::ReduceSum(d, c01);
+  C[ldc * (ii + 0) + (jj + 2)] = hn::ReduceSum(d, c02);
+  C[ldc * (ii + 0) + (jj + 3)] = hn::ReduceSum(d, c03);
+
+  C[ldc * (ii + 1) + (jj + 0)] = hn::ReduceSum(d, c10);
+  C[ldc * (ii + 1) + (jj + 1)] = hn::ReduceSum(d, c11);
+  C[ldc * (ii + 1) + (jj + 2)] = hn::ReduceSum(d, c12);
+  C[ldc * (ii + 1) + (jj + 3)] = hn::ReduceSum(d, c13);
+
+  C[ldc * (ii + 2) + (jj + 0)] = hn::ReduceSum(d, c20);
+  C[ldc * (ii + 2) + (jj + 1)] = hn::ReduceSum(d, c21);
+  C[ldc * (ii + 2) + (jj + 2)] = hn::ReduceSum(d, c22);
+  C[ldc * (ii + 2) + (jj + 3)] = hn::ReduceSum(d, c23);
+
+  C[ldc * (ii + 3) + (jj + 0)] = hn::ReduceSum(d, c30);
+  C[ldc * (ii + 3) + (jj + 1)] = hn::ReduceSum(d, c31);
+  C[ldc * (ii + 3) + (jj + 2)] = hn::ReduceSum(d, c32);
+  C[ldc * (ii + 3) + (jj + 3)] = hn::ReduceSum(d, c33);
+}
+
+// Tiled 4x4 GEMM. Covers the primary use case of M = 4..512, k = 3k/24k,
+// n = 24k/3k. This version uses tiling suitable for static scheduling.
+// Note: expects transposed / shuffled B.
+template <size_t kM, size_t kColsA, size_t kK, typename MatT>
+void GEMM_4x4_Static(const MatT* HWY_RESTRICT A, const MatT* HWY_RESTRICT B,
+                     MatT* HWY_RESTRICT C) {
+  const hn::ScalableTag<MatT> d;
+  const size_t N = hn::Lanes(d);  // column step size
+  constexpr int RM = 4;           // tile height
+  constexpr int RN = 4;           // tile width
+
+  HWY_ASSERT(kM % RM == 0);
+  HWY_ASSERT(kColsA % N == 0);
+  HWY_ASSERT(kColsA % RN == 0);
+
+  int lda = kColsA;
+  int ldb = kColsA;  // n instead of k because we're transposing
+  int ldc = kK;
+
+  int ytiles = (kM) / RM;
+  int xtiles = (kK) / RN;  // k instead of n because we're transposing
+  int tiles = xtiles * ytiles;
+
+  for (int job = 0; job < tiles; ++job) {
+    GEMM_4x4_Tile<kColsA>(A, B, C, job, xtiles, lda, ldb, ldc);
+  }
+}
+
+// Tiled 4x4 GEMM. Covers the primary use case of M = 4..512, k = 3k/24k,
+// n = 24k/3k. This version uses tiling and pooled threads.
+// Note: expects transposed / shuffled B.
+template <size_t kM, size_t kColsA, size_t kK, typename MatT>
+HWY_NOINLINE void MatMul_4x4_Impl(const MatT* HWY_RESTRICT A,
+                                  const MatT* HWY_RESTRICT B,
+                                  MatT* HWY_RESTRICT C, hwy::ThreadPool& pool) {
+  // Process 4x4 chunks of C in parallel. Each pool thread handles a single
+  // A x B tile. Note that C is being addressed directly without a buffer, and
+  // that the cache vectors (c00, c01, etc.) are being summed directly into C.
+  // There may be additional stability / speed gains to be made by using a
+  // buffer.
+  const hn::ScalableTag<MatT> d;
+  const size_t N = Lanes(d);
+
+  const int lda = kColsA;
+  const int ldb = kColsA;  // n instead of k because we're transposing
+  const int ldc = kK;
+
+  // 4x4 tile dimensions.
+  const int RM = 4;
+  const int RN = 4;
+
+  const int ytiles = (kM) / RM;
+  const int xtiles = (kK) / RN;  // k instead of n because we're transposing
+  const int tiles = xtiles * ytiles;
+
+  // The 4x4 case requires kM % 4 == 0, kColsA % N == 0, kK % 4 == 0.
+  HWY_ASSERT(kM % RM == 0);
+  HWY_ASSERT(kColsA % N == 0);
+  HWY_ASSERT(kColsA % RN == 0);
+  HWY_ASSERT(kK % RN == 0);
+  HWY_ASSERT(kColsA >= N);
+
+  // Each task handles a single 4x4 chunk, which is completed and then written
+  // into C.
+  pool.Run(0, tiles, [&](const uint64_t chunk, size_t /*thread*/) HWY_ATTR {
+    GEMM_4x4_Tile<kColsA>(A, B, C, chunk, xtiles, lda, ldb, ldc);
+  });
+}
+
+// Requires m % 4 == 0, n % Lanes() == 0, k % 4 == 0.
+template <size_t kM, size_t kN, size_t kK, typename MatT>
+HWY_INLINE void MatMul_4x4(const MatT* HWY_RESTRICT a,
+                           const MatT* HWY_RESTRICT b, MatT* HWY_RESTRICT out,
+                           hwy::ThreadPool& pool) {
+  return MatMul_4x4_Impl<kM, kN, kK, MatT>(a, b, out, pool);
+}
+
 // Largely unoptimized; reordered innermost loops nets ~5-10X speedup on
 // ops_test across instruction sets.
-template <size_t kM, size_t kN, size_t kK>
-HWY_INLINE void MatMul(const float* HWY_RESTRICT a, const float* HWY_RESTRICT b,
-                       float* HWY_RESTRICT out) {
+template <size_t kM, size_t kN, size_t kK, typename MatT>
+HWY_INLINE void MatMul(const MatT* HWY_RESTRICT a, const MatT* HWY_RESTRICT b,
+                       MatT* HWY_RESTRICT out) {
   int i, j, k;
   for (i = 0; i < kM; ++i) {
     for (k = 0; k < kN; ++k) {
@@ -167,8 +336,8 @@ HWY_INLINE void TwoOfsMatVecAddLoop(const ArrayT& mat, const size_t mat_ofs0,
 
 namespace detail {
 
 // For each i = [0, num_rows), compute partial (length `num_cols`) dot product
-// of row i with `vec_aligned` and add into `out[i]`. The upper-left coordinate
-// of the tile is r0, c0.
+// of row i with `vec_aligned` and add into `out[i]`. The upper-left
+// coordinate of the tile is r0, c0.
 template <class DF, class ArrayT, class VecT>
 HWY_INLINE void AccumulatePartialDotProducts(
     DF df, const ArrayT& mat, size_t mat_ofs, size_t mat_stride, size_t r0,
@@ -208,8 +377,8 @@ HWY_INLINE void SetFirstPartialDotProducts(DF df, const ArrayT& mat,
 
 // Adds together partial dot products for all tiles with the same r0 (a
 // horizontal strip of the entire matrix); the result is the full dot product
-// for rows r in [r0, r0 + num_rows) + optionally the add vector, which we store
-// into in out[r - r0].
+// for rows r in [r0, r0 + num_rows) + optionally the add vector, which we
+// store into in out[r - r0].
 template <bool kAdd, class DF, class ArrayT, class VecT, class AddT>
 HWY_INLINE void FullDotProductsForStrip(DF df, const ArrayT& mat,
diff --git a/gemma/ops_test.cc b/gemma/ops_test.cc
index 5d59f63..0348687 100644
--- a/gemma/ops_test.cc
+++ b/gemma/ops_test.cc
@@ -350,9 +350,44 @@ CompressedArray<MatT, kOuter * kInner> GenerateMat(size_t offset) {
   const float scale = 1.0f / kInner;
   for (size_t i = 0; i < kOuter; i++) {
     for (size_t j = 0; j < kInner; j++) {
-      content[i * kInner + j] = static_cast<MatT>((i + j + offset) * scale);
+      content[i * kInner + j] =
+          static_cast<MatT>((i * kInner + j + offset) * scale);
     }
   }
+
+  // for (size_t i = 0; i < kOuter; i++) {
+  //   for (size_t j = 0; j < kInner; j++) {
+  //     fprintf(stderr, "content[%lu] = %f\n", i * kInner + j,
+  //             content[i * kInner + j]);
+  //   }
+  // }
+
+  Compress(content, ws, mat, pool);
+  mat.set_scale(1.0f);
+  return mat;
+}
+
+template <typename MatT, size_t kOuter, size_t kInner>
+CompressedArray<MatT, kOuter * kInner> GenerateTransposeMat(size_t offset) {
+  hwy::ThreadPool pool(0);
+  gcpp::CompressWorkingSet ws;
+  CompressedArray<MatT, kOuter * kInner> mat;
+  std::array<MatT, kOuter * kInner> content;
+  const float scale = 1.0f / kInner;
+  for (size_t i = 0; i < kOuter; i++) {
+    for (size_t j = 0; j < kInner; j++) {
+      content[j * kOuter + i] =
+          static_cast<MatT>((i * kInner + j + offset) * scale);
+    }
+  }
+
+  // for (size_t i = 0; i < kOuter; i++) {
+  //   for (size_t j = 0; j < kInner; j++) {
+  //     fprintf(stderr, "content[%lu] = %f (transpose)\n", i * kInner + j,
+  //             content[i * kInner + j]);
+  //   }
+  // }
+
   Compress(content, ws, mat, pool);
   mat.set_scale(1.0f);
   return mat;
@@ -403,6 +438,7 @@ hwy::AlignedFreeUniquePtr<float[]> SimpleMatMul(
       }
     }
   }
+
   return out;
 }
@@ -445,7 +481,7 @@ void AssertClose(const hwy::AlignedFreeUniquePtr<float[]>& expected,
     double actual_value = hwy::ConvertScalarTo<double>(actual[idx]);
 
     const double tolerance =
-        expected_value * 20 * 1.0 / (1ULL << hwy::MantissaBits<float>());
+        expected_value * 21 * 1.0 / (1ULL << hwy::MantissaBits<float>());
 
     if (!(expected_value - tolerance <= actual_value &&
           actual_value <= expected_value + tolerance)) {
@@ -456,11 +492,11 @@
   }
 }
 
-void TestMatMul() {
+void TestTiledMatMul() {
   hwy::ThreadPool pool(0);
-  constexpr size_t kM = 128 * 3;  // 384
-  constexpr size_t kK = 128 * 5;  // 640
-  constexpr size_t kN = 128 * 6;  // 768
+  constexpr size_t kM = 512;
+  constexpr size_t kN = 512;
+  constexpr size_t kK = 512;
 
   CompressedArray<float, kM * kN> a1 = GenerateMat<float, kM, kN>(0);
   CompressedArray<float, kN * kK> b1 = GenerateMat<float, kN, kK>(0);
@@ -478,6 +514,37 @@ void TestMatMul() {
   hwy::AlignedFreeUniquePtr<float[]> c = hwy::AllocateAligned<float>(kM * kK);
   Decompress(compressed_c, 0, c.get(), kM * kK);
 
+  CompressedArray<float, kN * kK> b1_trans = GenerateTransposeMat<float, kN, kK>(0);
+  hwy::AlignedFreeUniquePtr<float[]> b_trans =
+      hwy::AllocateAligned<float>(kN * kK);
+  Decompress(b1_trans, 0, b_trans.get(), kN * kK);
+  MatMul_4x4<kM, kN, kK>(a.get(), b_trans.get(), c.get(), pool);
+
+  AssertClose(expected_out1, c, kM * kK);
+}
+
+void TestMatMul() {
+  constexpr size_t kM = 512;
+  constexpr size_t kN = 512;
+  constexpr size_t kK = 512;
+
+  CompressedArray<float, kM * kN> a1 = GenerateMat<float, kM, kN>(0);
+  CompressedArray<float, kN * kK> b1 = GenerateMat<float, kN, kK>(0);
+
+  hwy::AlignedFreeUniquePtr<float[]> a = hwy::AllocateAligned<float>(kM * kN);
+  Decompress(a1, 0, a.get(), kM * kN);
+
+  hwy::AlignedFreeUniquePtr<float[]> b = hwy::AllocateAligned<float>(kN * kK);
+  Decompress(b1, 0, b.get(), kN * kK);
+
+  hwy::AlignedFreeUniquePtr<float[]> expected_out1 =
+      SimpleMatMul<kM, kN, kK>(a, b);
+
+  CompressedArray<float, kM * kK> compressed_c = GenerateZeroMat<float, kM, kK>(0);
+  hwy::AlignedFreeUniquePtr<float[]> c = hwy::AllocateAligned<float>(kM * kK);
+  Decompress(compressed_c, 0, c.get(), kM * kK);
+
+  Decompress(b1, 0, b.get(), kN * kK);
   MatMul<kM, kN, kK>(a.get(), b.get(), c.get());
 
   AssertClose(expected_out1, c, kM * kK);
@@ -583,6 +650,7 @@ HWY_EXPORT_AND_TEST_P(OpsTest, TestAllMulByConst);
 HWY_EXPORT_AND_TEST_P(OpsTest, TestAllMulByConstAndAdd);
 HWY_EXPORT_AND_TEST_P(OpsTest, TestAllSoftmax);
 HWY_EXPORT_AND_TEST_P(OpsTest, TestAllCreateDistribution);
+HWY_EXPORT_AND_TEST_P(OpsTest, TestTiledMatMul);
 HWY_EXPORT_AND_TEST_P(OpsTest, TestMatMul);
 HWY_EXPORT_AND_TEST_P(OpsTest, TestMatVecAdd);
 HWY_EXPORT_AND_TEST_P(OpsTest, TestTwoMatVecAdd);
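For reference, the following is a minimal scalar sketch (plain C++, no Highway, not part of the patch) of the register-tiling scheme that GEMM_4x4_Tile vectorizes: because B is supplied pre-transposed, each C[i][j] is the dot product of row i of A with row j of the transposed B, and each task fills one complete 4x4 block of C before writing it out. The helper name Tile4x4 and the toy dimensions below are illustrative only.

// Scalar illustration of the 4x4 tiling used above. Assumes B is already
// transposed: Bt is kK x kColsA, row-major. Not part of the patch.
#include <cstddef>
#include <cstdio>
#include <vector>

// Computes one 4x4 tile of C = A * Bt^T, with the tile's top-left corner at
// (row ii, column jj). Mirrors GEMM_4x4_Tile, but with scalar accumulators.
void Tile4x4(const float* A, const float* Bt, float* C, size_t ii, size_t jj,
             size_t cols_a, size_t ldc) {
  float acc[4][4] = {};
  for (size_t l = 0; l < cols_a; ++l) {
    for (size_t r = 0; r < 4; ++r) {
      const float a = A[(ii + r) * cols_a + l];
      for (size_t c = 0; c < 4; ++c) {
        acc[r][c] += a * Bt[(jj + c) * cols_a + l];
      }
    }
  }
  for (size_t r = 0; r < 4; ++r) {
    for (size_t c = 0; c < 4; ++c) {
      C[(ii + r) * ldc + (jj + c)] = acc[r][c];
    }
  }
}

int main() {
  // Tiny example: A is 8x8, Bt is 8x8 (B transposed), C = A * B in 4x4 tiles.
  const size_t kM = 8, kColsA = 8, kK = 8;
  std::vector<float> A(kM * kColsA), Bt(kK * kColsA), C(kM * kK);
  for (size_t i = 0; i < A.size(); ++i) A[i] = static_cast<float>(i % 7);
  for (size_t i = 0; i < Bt.size(); ++i) Bt[i] = static_cast<float>(i % 5);
  for (size_t ii = 0; ii < kM; ii += 4) {
    for (size_t jj = 0; jj < kK; jj += 4) {
      Tile4x4(A.data(), Bt.data(), C.data(), ii, jj, kColsA, kK);
    }
  }
  std::printf("C[0] = %f\n", C[0]);
  return 0;
}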