From be765afce2abcf080e95e713e66b30b63bdcb950 Mon Sep 17 00:00:00 2001 From: Jan Wassenberg Date: Thu, 11 Jul 2024 06:58:07 -0700 Subject: [PATCH] Simplify matmul: only 2 overloads Also add StoreHorizontalSumsMaybeAdd wrapper function, move MatMulSlowBatch into test. 1.02-1.06x speedup. PiperOrigin-RevId: 651394791 --- gemma/ops.h | 531 ++++++---------------------------------------- gemma/ops_test.cc | 37 ++++ 2 files changed, 103 insertions(+), 465 deletions(-) diff --git a/gemma/ops.h b/gemma/ops.h index 8a31f70..3415c4c 100644 --- a/gemma/ops.h +++ b/gemma/ops.h @@ -28,7 +28,6 @@ #include // std::enable_if_t #include "compression/sfp.h" -#include "hwy/aligned_allocator.h" #include "hwy/base.h" #include "hwy/contrib/thread_pool/thread_pool.h" #include "hwy/detect_targets.h" @@ -142,14 +141,14 @@ HWY_INLINE void StoreHorizontalSums(DF df, VF c00, VF c01, VF c02, VF c03, } // Completes the tile by summing across the vectors, and adds the biases. -template , typename AddT> +template > HWY_INLINE void StoreHorizontalSumsAdd(DF df, VF c00, VF c01, VF c02, VF c03, - VF c10, VF c11, VF c12, VF c13, // - VF c20, VF c21, VF c22, VF c23, // - VF c30, VF c31, VF c32, VF c33, - const AddT* add, - float* HWY_RESTRICT tile_c, - size_t stride_c) { + VF c10, VF c11, VF c12, VF c13, // + VF c20, VF c21, VF c22, VF c23, // + VF c30, VF c31, VF c32, VF c33, + const float* HWY_RESTRICT add, + float* HWY_RESTRICT tile_c, + size_t stride_c) { // We are computing the product of (4, 4N) * (4N, 4) = (4, 4) tiles. // Each entry of C[r,c] is a dot product of A.row and B.col, which reside in // the lanes of `c$r$c`, so we store their horizontal sum (ReduceSum). This is @@ -182,103 +181,23 @@ HWY_INLINE void StoreHorizontalSumsAdd(DF df, VF c00, VF c01, VF c02, VF c03, tile_c[stride_c * 3 + 3] = hn::ReduceSum(df, c33) + addon3; } -// Accumulates a single kNumRowsx4 tile of A x B into C. B is transposed, so we -// can iterate over both A and B with consecutive vector loads. kNumRows<=4. -// Shared between parallelized and sequential (loop) callers. -template -HWY_INLINE void GEMM_4x4_Tile(const MatT* HWY_RESTRICT A, - const MatT* HWY_RESTRICT B, MatT* HWY_RESTRICT C, - const AddT* add, - const size_t idx_tile, const size_t xtiles, - const size_t stride_a, const size_t stride_b, - const size_t stride_c) { - // Tile size. 4x unrolling makes sense on many platforms because we can fit - // 4x4 accumulators and 8 temporaries in the 32 vectors; we have more than - // #FMA units * FMA latency (up to 2*5) independent computations in flight; - // threads write in units of 4*N elements, which is at least one cache line. - constexpr size_t kRegRows = 4; - constexpr size_t kRegCols = 4; - static_assert(kNumRows <= kRegRows); - - // Top-left of tile is (row_a, col_ab) for A, and (row_b_col_c, col_ab) for B. 
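As a standalone illustration of this tile-coordinate mapping and of what each accumulator ultimately holds, a small scalar sketch (hypothetical helper names; kRegRows and kRegCols fixed at 4 as in the kernel):

#include <cstddef>

// Maps a linear tile index to the top-left coordinates of a 4x4 tile of C.
// xtiles is the number of 4-column tiles per row of C. For example, with
// xtiles = 8, idx_tile = 11 selects row_a = 4 (second tile row) and
// row_b_col_c = 12 (fourth tile column).
struct TileCoords {
  size_t row_a;        // first of the 4 rows of A / C covered by this tile
  size_t row_b_col_c;  // first of the 4 rows of B == columns of C
};

TileCoords TileCoordsFromIndex(size_t idx_tile, size_t xtiles) {
  constexpr size_t kRegRows = 4;
  constexpr size_t kRegCols = 4;
  return TileCoords{idx_tile / xtiles * kRegRows, idx_tile % xtiles * kRegCols};
}

// Scalar model of one tile entry: the vector kernel keeps N partial products
// of this dot product in the lanes of c$r$c and collapses them with ReduceSum.
// B is transposed, so both operands are contiguous row accesses.
float TileEntry(const float* a_row, const float* b_row, size_t cols_a_rows_b) {
  float sum = 0.0f;
  for (size_t k = 0; k < cols_a_rows_b; ++k) sum += a_row[k] * b_row[k];
  return sum;
}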
- const size_t row_a = idx_tile / xtiles * kRegRows; - const size_t row_b_col_c = idx_tile % xtiles * kRegCols; - - const hn::ScalableTag d; - const size_t N = Lanes(d); - using V = hn::Vec; - - V c00 = hn::Zero(d); - V c01 = hn::Zero(d); - V c02 = hn::Zero(d); - V c03 = hn::Zero(d); - - V c10 = hn::Zero(d); - V c11 = hn::Zero(d); - V c12 = hn::Zero(d); - V c13 = hn::Zero(d); - - V c20 = hn::Zero(d); - V c21 = hn::Zero(d); - V c22 = hn::Zero(d); - V c23 = hn::Zero(d); - - V c30 = hn::Zero(d); - V c31 = hn::Zero(d); - V c32 = hn::Zero(d); - V c33 = hn::Zero(d); - - const MatT* HWY_RESTRICT tile_a = A + stride_a * row_a; - const MatT* HWY_RESTRICT tile_b = B + stride_b * row_b_col_c; - - // Loop over columns of A and columns of the transposed B, in steps of N. - // Accumulates into the c## vectors. - HWY_UNROLL(1) - for (size_t col_ab = 0; col_ab < kColsA_RowsB; col_ab += N) { - const V b0 = hn::LoadU(d, tile_b + stride_b * 0 + col_ab); - const V b1 = hn::LoadU(d, tile_b + stride_b * 1 + col_ab); - const V b2 = hn::LoadU(d, tile_b + stride_b * 2 + col_ab); - const V b3 = hn::LoadU(d, tile_b + stride_b * 3 + col_ab); - - const V a0 = hn::LoadU(d, tile_a + stride_a * 0 + col_ab); - c00 = hn::MulAdd(a0, b0, c00); - c01 = hn::MulAdd(a0, b1, c01); - c02 = hn::MulAdd(a0, b2, c02); - c03 = hn::MulAdd(a0, b3, c03); - if constexpr (kNumRows == 1) continue; - - const V a1 = hn::LoadU(d, tile_a + stride_a * 1 + col_ab); - c10 = hn::MulAdd(a1, b0, c10); - c11 = hn::MulAdd(a1, b1, c11); - c12 = hn::MulAdd(a1, b2, c12); - c13 = hn::MulAdd(a1, b3, c13); - if constexpr (kNumRows == 2) continue; - - const V a2 = hn::LoadU(d, tile_a + stride_a * 2 + col_ab); - c20 = hn::MulAdd(a2, b0, c20); - c21 = hn::MulAdd(a2, b1, c21); - c22 = hn::MulAdd(a2, b2, c22); - c23 = hn::MulAdd(a2, b3, c23); - if constexpr (kNumRows == 3) continue; - - const V a3 = hn::LoadU(d, tile_a + stride_a * 3 + col_ab); - c30 = hn::MulAdd(a3, b0, c30); - c31 = hn::MulAdd(a3, b1, c31); - c32 = hn::MulAdd(a3, b2, c32); - c33 = hn::MulAdd(a3, b3, c33); - } - - float* HWY_RESTRICT tile_c = C + stride_c * row_a + row_b_col_c; +// Wrapper around StoreHorizontalSums and StoreHorizontalSumsAdd to shorten call +// sites. If `!kAdd`, `add` is nullptr, so adding `add_offset` to it would be +// UB, hence we pass it as a separate argument. 
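A reduced model of that dispatch, with hypothetical names and a single row of two columns instead of the full 4x4 tile, showing why the offset is applied only inside the kAdd branch:

#include <cstddef>

// If kAdd is false, `add` is nullptr by contract, so the offset must not be
// applied to it (pointer arithmetic on a null pointer is undefined behavior).
template <bool kAdd>
void StoreRowMaybeAddSketch(float sum0, float sum1, const float* add,
                            size_t add_offset, float* row_c) {
  if constexpr (kAdd) {
    const float* tile_add = add + add_offset;  // safe: add is non-null here
    row_c[0] = sum0 + tile_add[0];
    row_c[1] = sum1 + tile_add[1];
  } else {
    row_c[0] = sum0;  // add is ignored entirely in this branch
    row_c[1] = sum1;
  }
}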
+template > +HWY_INLINE void StoreHorizontalSumsMaybeAdd( + DF df, VF c00, VF c01, VF c02, VF c03, VF c10, VF c11, VF c12, VF c13, + VF c20, VF c21, VF c22, VF c23, VF c30, VF c31, VF c32, VF c33, + const float* HWY_RESTRICT add, size_t add_offset, + float* HWY_RESTRICT tile_c, size_t stride_c) { if constexpr (kAdd) { - const AddT* tile_add = add + row_b_col_c; - StoreHorizontalSumsAdd(d, c00, c01, c02, c03, c10, c11, c12, c13, + StoreHorizontalSumsAdd(df, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21, c22, c23, c30, c31, c32, c33, - tile_add, tile_c, stride_c); + add + add_offset, tile_c, stride_c); } else { - StoreHorizontalSums( - d, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21, c22, c23, - c30, c31, c32, c33, tile_c, stride_c); + StoreHorizontalSums(df, c00, c01, c02, c03, c10, c11, c12, c13, + c20, c21, c22, c23, c30, c31, c32, c33, + tile_c, stride_c); } } @@ -290,12 +209,13 @@ HWY_INLINE void GEMM_4x4_Tile(const MatT* HWY_RESTRICT A, #define GEMMA_NATIVE_BF16 0 #endif -// As above, for MatT=bf16 -template -HWY_INLINE void GEMM_4x4_Tile(const MatT* HWY_RESTRICT A, - const MatT* HWY_RESTRICT B, float* HWY_RESTRICT C, - const AddT* add, +#if GEMMA_NATIVE_BF16 + +// Specialization for f32 += bf16 * bf16 that avoids promoting to f32. +template +HWY_INLINE void GEMM_4x4_Tile(const hwy::bfloat16_t* HWY_RESTRICT A, + const hwy::bfloat16_t* HWY_RESTRICT B, + float* HWY_RESTRICT C, const float* add, const size_t idx_tile, const size_t xtiles, const size_t stride_a, const size_t stride_b, const size_t stride_c) { @@ -309,17 +229,10 @@ HWY_INLINE void GEMM_4x4_Tile(const MatT* HWY_RESTRICT A, const hn::ScalableTag df; using VF = hn::Vec; -#if GEMMA_NATIVE_BF16 // ReorderWidenMulAccumulate does not use its sum1 arg and we can use full // bf16 vectors. - const hn::Repartition d; + const hn::Repartition d; VF unused_sum1 = hn::Zero(df); -#else - // Emulated: use half-vectors of bf16 because we cannot afford two sums for - // each c##. - const hn::Rebind d; - HWY_DASSERT(Lanes(d) == Lanes(df)); -#endif const size_t N = Lanes(d); @@ -343,14 +256,13 @@ HWY_INLINE void GEMM_4x4_Tile(const MatT* HWY_RESTRICT A, VF c32 = hn::Zero(df); VF c33 = hn::Zero(df); - const MatT* HWY_RESTRICT tile_a = A + stride_a * row_a; - const MatT* HWY_RESTRICT tile_b = B + stride_b * row_b_col_c; + const hwy::bfloat16_t* HWY_RESTRICT tile_a = A + stride_a * row_a; + const hwy::bfloat16_t* HWY_RESTRICT tile_b = B + stride_b * row_b_col_c; // Loop over columns of A and columns of the transposed B, in steps of N. // Accumulates into the c## vectors. 
HWY_UNROLL(1) for (size_t col_ab = 0; col_ab < kColsA_RowsB; col_ab += N) { -#if GEMMA_NATIVE_BF16 using V = hn::Vec; const V b0 = hn::LoadU(d, tile_b + stride_b * 0 + col_ab); const V b1 = hn::LoadU(d, tile_b + stride_b * 1 + col_ab); @@ -383,89 +295,44 @@ HWY_INLINE void GEMM_4x4_Tile(const MatT* HWY_RESTRICT A, c31 = hn::ReorderWidenMulAccumulate(df, a3, b1, c31, unused_sum1); c32 = hn::ReorderWidenMulAccumulate(df, a3, b2, c32, unused_sum1); c33 = hn::ReorderWidenMulAccumulate(df, a3, b3, c33, unused_sum1); -#else // Emulated: promote bf16 to f32 - const VF b0 = - hn::PromoteTo(df, hn::LoadU(d, tile_b + stride_b * 0 + col_ab)); - const VF b1 = - hn::PromoteTo(df, hn::LoadU(d, tile_b + stride_b * 1 + col_ab)); - const VF b2 = - hn::PromoteTo(df, hn::LoadU(d, tile_b + stride_b * 2 + col_ab)); - const VF b3 = - hn::PromoteTo(df, hn::LoadU(d, tile_b + stride_b * 3 + col_ab)); - - const VF a0 = - hn::PromoteTo(df, hn::LoadU(d, tile_a + stride_a * 0 + col_ab)); - c00 = hn::MulAdd(a0, b0, c00); - c01 = hn::MulAdd(a0, b1, c01); - c02 = hn::MulAdd(a0, b2, c02); - c03 = hn::MulAdd(a0, b3, c03); - if constexpr (kNumRows == 1) continue; - - const VF a1 = - hn::PromoteTo(df, hn::LoadU(d, tile_a + stride_a * 1 + col_ab)); - c10 = hn::MulAdd(a1, b0, c10); - c11 = hn::MulAdd(a1, b1, c11); - c12 = hn::MulAdd(a1, b2, c12); - c13 = hn::MulAdd(a1, b3, c13); - if constexpr (kNumRows == 2) continue; - - const VF a2 = - hn::PromoteTo(df, hn::LoadU(d, tile_a + stride_a * 2 + col_ab)); - c20 = hn::MulAdd(a2, b0, c20); - c21 = hn::MulAdd(a2, b1, c21); - c22 = hn::MulAdd(a2, b2, c22); - c23 = hn::MulAdd(a2, b3, c23); - if constexpr (kNumRows == 3) continue; - - const VF a3 = - hn::PromoteTo(df, hn::LoadU(d, tile_a + stride_a * 3 + col_ab)); - c30 = hn::MulAdd(a3, b0, c30); - c31 = hn::MulAdd(a3, b1, c31); - c32 = hn::MulAdd(a3, b2, c32); - c33 = hn::MulAdd(a3, b3, c33); -#endif // !GEMMA_NATIVE_BF16 } -#if GEMMA_NATIVE_BF16 // Ensure sum1 was indeed unused. HWY_DASSERT(hn::AllTrue(df, hn::Eq(unused_sum1, hn::Zero(df)))); -#endif float* HWY_RESTRICT tile_c = C + stride_c * row_a + row_b_col_c; - if constexpr (kAdd) { - const AddT* dd = add + row_b_col_c; - StoreHorizontalSumsAdd(df, c00, c01, c02, c03, c10, c11, c12, c13, - c20, c21, c22, c23, c30, c31, c32, c33, dd, - tile_c, stride_c); - } else { - StoreHorizontalSums( - df, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21, c22, c23, - c30, c31, c32, c33, tile_c, stride_c); - } + StoreHorizontalSumsMaybeAdd( + df, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21, c22, c23, c30, c31, + c32, c33, add, row_b_col_c, tile_c, stride_c); } +#endif // GEMMA_NATIVE_BF16 + +// The col_ab loop is unrolled 2x, so we have two consecutive a0/a1 and b00/b01 +// etc. Multiplies a[c] with b[r,c] and adds to c[r]. 
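A scalar model of this per-row update (illustrative only; in the kernel a0/a1 and the b/c values are N-lane vectors and hn::MulAdd is the lane-wise FMA). Visiting all four accumulators with the first chunk before revisiting them with the second keeps consecutive FMAs independent, which presumably helps hide FMA latency:

// Scalar model of UpdateTileRow: the two unrolled column chunks (suffix 0/1)
// feed the same four accumulators c[0..3].
void UpdateTileRowScalar(float a0, float a1, const float b0[4],
                         const float b1[4], float c[4]) {
  for (int r = 0; r < 4; ++r) c[r] += a0 * b0[r];  // chunk 0: a[c] * b[r,c]
  for (int r = 0; r < 4; ++r) c[r] += a1 * b1[r];  // chunk 1
}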
template HWY_INLINE void UpdateTileRow(const VF& a0, const VF& a1, const VF& b00, const VF& b01, const VF& b10, const VF& b11, const VF& b20, const VF& b21, const VF& b30, const VF& b31, VF& c0, VF& c1, VF& c2, VF& c3) { - c0 = MulAdd(a0, b00, c0); - c0 = MulAdd(a1, b01, c0); - c1 = MulAdd(a0, b10, c1); - c1 = MulAdd(a1, b11, c1); - c2 = MulAdd(a0, b20, c2); - c2 = MulAdd(a1, b21, c2); - c3 = MulAdd(a0, b30, c3); - c3 = MulAdd(a1, b31, c3); + c0 = hn::MulAdd(a0, b00, c0); + c1 = hn::MulAdd(a0, b10, c1); + c2 = hn::MulAdd(a0, b20, c2); + c3 = hn::MulAdd(a0, b30, c3); + c0 = hn::MulAdd(a1, b01, c0); + c1 = hn::MulAdd(a1, b11, c1); + c2 = hn::MulAdd(a1, b21, c2); + c3 = hn::MulAdd(a1, b31, c3); } -// Same as above, for when there exists CompressTraits::Decompress2 and -// MatTB is compressed. +// Accumulates a single kNumRowsx4 tile of A x B into C. B is transposed, so we +// can iterate over both A and B with consecutive vector loads. kNumRows<=4. +// General case: uses CompressTraits to load from A and B. template + typename MatTB> HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A, const MatTB* HWY_RESTRICT B, - float* HWY_RESTRICT C, const AddT* add, + float* HWY_RESTRICT C, const float* add, const size_t idx_tile, const size_t xtiles, const size_t stride_a, const size_t stride_b, const size_t stride_c) { @@ -473,6 +340,9 @@ HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A, constexpr size_t kRegCols = 4; static_assert(kNumRows <= kRegRows); + using TraitsA = CompressTraits; + using TraitsB = CompressTraits; + // Top-left of tile is (row_a, col_ab) for A, and (row_b_col_c, col_ab) for B. const size_t row_a = idx_tile / xtiles * kRegRows; const size_t row_b_col_c = idx_tile % xtiles * kRegCols; @@ -512,273 +382,42 @@ HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A, HWY_UNROLL(1) for (; col_ab <= kColsA_RowsB - 2 * N; col_ab += 2 * N) { V b00, b01; - CompressTraits::Decompress2( - d32, B, tile_b_ofs + stride_b * 0 + col_ab, b00, b01); + TraitsB::Decompress2(d32, B, tile_b_ofs + stride_b * 0 + col_ab, b00, b01); V b10, b11; - CompressTraits::Decompress2( - d32, B, tile_b_ofs + stride_b * 1 + col_ab, b10, b11); + TraitsB::Decompress2(d32, B, tile_b_ofs + stride_b * 1 + col_ab, b10, b11); V b20, b21; - CompressTraits::Decompress2( - d32, B, tile_b_ofs + stride_b * 2 + col_ab, b20, b21); + TraitsB::Decompress2(d32, B, tile_b_ofs + stride_b * 2 + col_ab, b20, b21); V b30, b31; - CompressTraits::Decompress2( - d32, B, tile_b_ofs + stride_b * 3 + col_ab, b30, b31); + TraitsB::Decompress2(d32, B, tile_b_ofs + stride_b * 3 + col_ab, b30, b31); V a00, a01; - CompressTraits::Decompress2( - d32, A, tile_a_ofs + stride_a * 0 + col_ab, a00, a01); + TraitsA::Decompress2(d32, A, tile_a_ofs + stride_a * 0 + col_ab, a00, a01); UpdateTileRow(a00, a01, b00, b01, b10, b11, b20, b21, b30, b31, c00, c01, c02, c03); - if constexpr (kNumRows == 1) continue; V a10, a11; - CompressTraits::Decompress2( - d32, A, tile_a_ofs + stride_a * 1 + col_ab, a10, a11); + TraitsA::Decompress2(d32, A, tile_a_ofs + stride_a * 1 + col_ab, a10, a11); UpdateTileRow(a10, a11, b00, b01, b10, b11, b20, b21, b30, b31, c10, c11, c12, c13); - if constexpr (kNumRows == 2) continue; V a20, a21; - CompressTraits::Decompress2( - d32, A, tile_a_ofs + stride_a * 2 + col_ab, a20, a21); + TraitsA::Decompress2(d32, A, tile_a_ofs + stride_a * 2 + col_ab, a20, a21); UpdateTileRow(a20, a21, b00, b01, b10, b11, b20, b21, b30, b31, c20, c21, c22, c23); - if constexpr (kNumRows == 3) continue; V a30, a31; - 
CompressTraits::Decompress2( - d32, A, tile_a_ofs + stride_a * 3 + col_ab, a30, a31); + TraitsA::Decompress2(d32, A, tile_a_ofs + stride_a * 3 + col_ab, a30, a31); UpdateTileRow(a30, a31, b00, b01, b10, b11, b20, b21, b30, b31, c30, c31, c32, c33); } float* HWY_RESTRICT tile_c = C + stride_c * row_a + row_b_col_c; - if constexpr (kAdd) { - const AddT* dd = add + row_b_col_c; - StoreHorizontalSumsAdd(d32, c00, c01, c02, c03, c10, c11, c12, - c13, c20, c21, c22, c23, c30, c31, c32, - c33, dd, tile_c, stride_c); - } else { - StoreHorizontalSums(d32, c00, c01, c02, c03, c10, c11, c12, c13, - c20, c21, c22, c23, c30, c31, c32, c33, - tile_c, stride_c); - } -} - -// Same as above, but with mixed Mat types: (f32, bf16). -template -HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A, - const MatTB* HWY_RESTRICT B, - float* HWY_RESTRICT C, const AddT* add, - const size_t idx_tile, - const size_t xtiles, const size_t stride_a, - const size_t stride_b, const size_t stride_c) { - constexpr size_t kRegRows = 4; - constexpr size_t kRegCols = 4; - static_assert(kNumRows <= kRegRows); - - // Top-left of tile is (row_a, col_ab) for A, and (row_b_col_c, col_ab) for B. - const size_t row_a = idx_tile / xtiles * kRegRows; - const size_t row_b_col_c = idx_tile % xtiles * kRegCols; - - const hn::ScalableTag d32; - using VF = hn::Vec; - - // TODO: Using half-vectors for now, it might be faster to - // PromoteLower/UpperTo, and more so to PromoteEven/OddTo if we have packed B - // accordingly. - const hn::Rebind d16; - HWY_DASSERT(Lanes(d16) == Lanes(d32)); - - const size_t N = Lanes(d16); - - VF c00 = hn::Zero(d32); - VF c01 = hn::Zero(d32); - VF c02 = hn::Zero(d32); - VF c03 = hn::Zero(d32); - - VF c10 = hn::Zero(d32); - VF c11 = hn::Zero(d32); - VF c12 = hn::Zero(d32); - VF c13 = hn::Zero(d32); - - VF c20 = hn::Zero(d32); - VF c21 = hn::Zero(d32); - VF c22 = hn::Zero(d32); - VF c23 = hn::Zero(d32); - - VF c30 = hn::Zero(d32); - VF c31 = hn::Zero(d32); - VF c32 = hn::Zero(d32); - VF c33 = hn::Zero(d32); - - const MatTA* HWY_RESTRICT tile_a = A + stride_a * row_a; - const MatTB* HWY_RESTRICT tile_b = B + stride_b * row_b_col_c; - - // Loop over columns of A and columns of the transposed B, in steps of N. - // Accumulates into the c## vectors. 
- HWY_UNROLL(1) - for (size_t col_ab = 0; col_ab < kColsA_RowsB; col_ab += N) { - // Promote bf16 to f32 - const VF b0 = - hn::PromoteTo(d32, hn::LoadU(d16, tile_b + stride_b * 0 + col_ab)); - const VF b1 = - hn::PromoteTo(d32, hn::LoadU(d16, tile_b + stride_b * 1 + col_ab)); - const VF b2 = - hn::PromoteTo(d32, hn::LoadU(d16, tile_b + stride_b * 2 + col_ab)); - const VF b3 = - hn::PromoteTo(d32, hn::LoadU(d16, tile_b + stride_b * 3 + col_ab)); - - const VF a0 = hn::LoadU(d32, tile_a + stride_a * 0 + col_ab); - c00 = hn::MulAdd(a0, b0, c00); - c01 = hn::MulAdd(a0, b1, c01); - c02 = hn::MulAdd(a0, b2, c02); - c03 = hn::MulAdd(a0, b3, c03); - if constexpr (kNumRows == 1) continue; - - const VF a1 = hn::LoadU(d32, tile_a + stride_a * 1 + col_ab); - c10 = hn::MulAdd(a1, b0, c10); - c11 = hn::MulAdd(a1, b1, c11); - c12 = hn::MulAdd(a1, b2, c12); - c13 = hn::MulAdd(a1, b3, c13); - if constexpr (kNumRows == 2) continue; - - const VF a2 = hn::LoadU(d32, tile_a + stride_a * 2 + col_ab); - c20 = hn::MulAdd(a2, b0, c20); - c21 = hn::MulAdd(a2, b1, c21); - c22 = hn::MulAdd(a2, b2, c22); - c23 = hn::MulAdd(a2, b3, c23); - if constexpr (kNumRows == 3) continue; - - const VF a3 = hn::LoadU(d32, tile_a + stride_a * 3 + col_ab); - c30 = hn::MulAdd(a3, b0, c30); - c31 = hn::MulAdd(a3, b1, c31); - c32 = hn::MulAdd(a3, b2, c32); - c33 = hn::MulAdd(a3, b3, c33); - } - - float* HWY_RESTRICT tile_c = C + stride_c * row_a + row_b_col_c; - if constexpr (kAdd) { - const AddT* dd = add + row_b_col_c; - StoreHorizontalSumsAdd(d32, c00, c01, c02, c03, c10, c11, c12, - c13, c20, c21, c22, c23, c30, c31, c32, - c33, dd, tile_c, stride_c); - } else { - StoreHorizontalSums( - d32, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21, c22, c23, - c30, c31, c32, c33, tile_c, stride_c); - } -} - -// Same as above, but with mixed Mat types: (bf16, f32). -template -HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A, - const MatTB* HWY_RESTRICT B, - float* HWY_RESTRICT C, const AddT* add, - const size_t idx_tile, - const size_t xtiles, const size_t stride_a, - const size_t stride_b, const size_t stride_c) { - constexpr size_t kRegRows = 4; - constexpr size_t kRegCols = 4; - static_assert(kNumRows <= kRegRows); - - // Top-left of tile is (row_a, col_ab) for A, and (row_b_col_c, col_ab) for B. - const size_t row_a = idx_tile / xtiles * kRegRows; - const size_t row_b_col_c = idx_tile % xtiles * kRegCols; - - const hn::ScalableTag d32; - using VF = hn::Vec; - - // TODO: Using half-vectors for now, it might be faster to - // PromoteLower/UpperTo, and more so to PromoteEven/OddTo if we have packed B - // accordingly. - const hn::Rebind d16; - HWY_DASSERT(Lanes(d16) == Lanes(d32)); - - const size_t N = Lanes(d16); - - VF c00 = hn::Zero(d32); - VF c01 = hn::Zero(d32); - VF c02 = hn::Zero(d32); - VF c03 = hn::Zero(d32); - - VF c10 = hn::Zero(d32); - VF c11 = hn::Zero(d32); - VF c12 = hn::Zero(d32); - VF c13 = hn::Zero(d32); - - VF c20 = hn::Zero(d32); - VF c21 = hn::Zero(d32); - VF c22 = hn::Zero(d32); - VF c23 = hn::Zero(d32); - - VF c30 = hn::Zero(d32); - VF c31 = hn::Zero(d32); - VF c32 = hn::Zero(d32); - VF c33 = hn::Zero(d32); - - const MatTA* HWY_RESTRICT tile_a = A + stride_a * row_a; - const MatTB* HWY_RESTRICT tile_b = B + stride_b * row_b_col_c; - - // Loop over columns of A and columns of the transposed B, in steps of N. - // Accumulates into the c## vectors. 
- HWY_UNROLL(1) - for (size_t col_ab = 0; col_ab < kColsA_RowsB; col_ab += N) { - // Promote bf16 to f32 - const VF b0 = hn::LoadU(d32, tile_b + stride_b * 0 + col_ab); - const VF b1 = hn::LoadU(d32, tile_b + stride_b * 1 + col_ab); - const VF b2 = hn::LoadU(d32, tile_b + stride_b * 2 + col_ab); - const VF b3 = hn::LoadU(d32, tile_b + stride_b * 3 + col_ab); - - const VF a0 = - hn::PromoteTo(d32, hn::LoadU(d16, tile_a + stride_a * 0 + col_ab)); - c00 = hn::MulAdd(a0, b0, c00); - c01 = hn::MulAdd(a0, b1, c01); - c02 = hn::MulAdd(a0, b2, c02); - c03 = hn::MulAdd(a0, b3, c03); - if constexpr (kNumRows == 1) continue; - - const VF a1 = - hn::PromoteTo(d32, hn::LoadU(d16, tile_a + stride_a * 1 + col_ab)); - c10 = hn::MulAdd(a1, b0, c10); - c11 = hn::MulAdd(a1, b1, c11); - c12 = hn::MulAdd(a1, b2, c12); - c13 = hn::MulAdd(a1, b3, c13); - if constexpr (kNumRows == 2) continue; - - const VF a2 = - hn::PromoteTo(d32, hn::LoadU(d16, tile_a + stride_a * 2 + col_ab)); - c20 = hn::MulAdd(a2, b0, c20); - c21 = hn::MulAdd(a2, b1, c21); - c22 = hn::MulAdd(a2, b2, c22); - c23 = hn::MulAdd(a2, b3, c23); - if constexpr (kNumRows == 3) continue; - - const VF a3 = - hn::PromoteTo(d32, hn::LoadU(d16, tile_a + stride_a * 3 + col_ab)); - c30 = hn::MulAdd(a3, b0, c30); - c31 = hn::MulAdd(a3, b1, c31); - c32 = hn::MulAdd(a3, b2, c32); - c33 = hn::MulAdd(a3, b3, c33); - } - - float* HWY_RESTRICT tile_c = C + stride_c * row_a + row_b_col_c; - if constexpr (kAdd) { - const AddT* dd = add + row_b_col_c; - StoreHorizontalSumsAdd(d32, c00, c01, c02, c03, c10, c11, c12, - c13, c20, c21, c22, c23, c30, c31, c32, - c33, dd, tile_c, stride_c); - } else { - StoreHorizontalSums( - d32, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21, c22, c23, - c30, c31, c32, c33, tile_c, stride_c); - } + StoreHorizontalSumsMaybeAdd( + d32, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21, c22, c23, c30, c31, + c32, c33, add, row_b_col_c, tile_c, stride_c); } // Tiled 4x4 GEMM. Typically batch_size is 1..512, kColsA_RowsB is 3k or 24k, @@ -842,45 +481,7 @@ HWY_NOINLINE void MatMul_4x4_Batch( batch_size, A, B, C, /*add=*/static_cast(nullptr), pool); } -// Largely unoptimized; reordered innermost loops nets ~5-10X speedup on -// ops_test across instruction sets. -template -HWY_INLINE void MatMulSlowBatch(size_t batch_size, const MatTA* HWY_RESTRICT a, - const MatTB* HWY_RESTRICT b, - const float* add, - float* HWY_RESTRICT out) { - for (size_t i = 0; i < batch_size; ++i) { - for (size_t k = 0; k < kN; ++k) { - for (size_t j = 0; j < kK; ++j) { - const float a1 = hwy::ConvertScalarTo(a[i * kN + k]); - const float b1 = hwy::ConvertScalarTo(b[k * kK + j]); - out[i * kK + j] += a1 * b1; - } - } - if (add != nullptr) { - for (size_t j = 0; j < kK; ++j) { - out[i * kK + j] += add[j]; - } - } - } -} - -// The above overload can handle combinations of f32 and bf16, but this one -// is required for MatTB = {SFP, NUQ}. 
-template -HWY_INLINE void MatMulSlowBatch(size_t batch_size, const MatTA* HWY_RESTRICT a, - const MatTB* HWY_RESTRICT b_compr, - const float* add, - float* HWY_RESTRICT out) { - const hn::ScalableTag d; - hwy::AlignedFreeUniquePtr b = hwy::AllocateAligned(kK * kN); - CompressTraits::Decompress(d, /*in_capacity=*/0, b_compr, 0, b.get(), - kK * kN); - MatMulSlowBatch(batch_size, a, b.get(), add, out); -} - +//------------------------------------------------------------------------------ HWY_INLINE void ToEvenOddF32(const hwy::bfloat16_t* HWY_RESTRICT vec_aligned, const size_t size, float* HWY_RESTRICT out) { const hn::ScalableTag df; diff --git a/gemma/ops_test.cc b/gemma/ops_test.cc index 2eddb7a..48429ca 100644 --- a/gemma/ops_test.cc +++ b/gemma/ops_test.cc @@ -508,6 +508,43 @@ void AssertClose(const hwy::AlignedFreeUniquePtr& a, } } +// Largely unoptimized; reordered innermost loops nets ~5-10X speedup on +// ops_test across instruction sets. +template +HWY_INLINE void MatMulSlowBatch(size_t batch_size, const MatTA* HWY_RESTRICT a, + const MatTB* HWY_RESTRICT b, const float* add, + float* HWY_RESTRICT out) { + for (size_t i = 0; i < batch_size; ++i) { + for (size_t k = 0; k < kN; ++k) { + for (size_t j = 0; j < kK; ++j) { + const float a1 = hwy::ConvertScalarTo(a[i * kN + k]); + const float b1 = hwy::ConvertScalarTo(b[k * kK + j]); + out[i * kK + j] += a1 * b1; + } + } + if (add != nullptr) { + for (size_t j = 0; j < kK; ++j) { + out[i * kK + j] += add[j]; + } + } + } +} + +// The above overload can handle combinations of f32 and bf16, but this one +// is required for MatTB = {SFP, NUQ}. +template +HWY_INLINE void MatMulSlowBatch(size_t batch_size, const MatTA* HWY_RESTRICT a, + const MatTB* HWY_RESTRICT b_compr, + const float* add, float* HWY_RESTRICT out) { + const hn::ScalableTag d; + hwy::AlignedFreeUniquePtr b = hwy::AllocateAligned(kK * kN); + CompressTraits::Decompress(d, /*in_capacity=*/0, b_compr, 0, b.get(), + kK * kN); + MatMulSlowBatch(batch_size, a, b.get(), add, out); +} + template void TestTiledBatchMatMul() {
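For reference, a minimal usage sketch of the scalar reference above. This is hypothetical and not part of the test: it assumes the compile-time extents kN and kK lead the template parameter list (as the body's indexing suggests) and that the elided constraints select the plain overload for float inputs. Note that MatMulSlowBatch accumulates into `out`, so the destination must start at zero:

#include <array>
#include <cstddef>
#include <cstdio>

// Hypothetical demo: A is 2x3, B is 3x4 row-major as implied by the indexing
// b[k * kK + j] above, C is 2x4.
void MatMulSlowBatchDemo() {
  constexpr size_t kN = 3, kK = 4, kBatch = 2;
  const std::array<float, kBatch * kN> a = {1, 2, 3, 4, 5, 6};
  std::array<float, kN * kK> b{};
  for (size_t i = 0; i < b.size(); ++i) b[i] = static_cast<float>(i);
  std::array<float, kBatch * kK> out{};  // zero-initialized: += accumulation
  MatMulSlowBatch<kN, kK>(kBatch, a.data(), b.data(), /*add=*/nullptr,
                          out.data());
  std::printf("out[0][0] = %g\n", out[0]);  // 1*0 + 2*4 + 3*8 = 32
}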