From 38eb452b94c162c981a306aed670a94df794712d Mon Sep 17 00:00:00 2001 From: Andrey Vlasov Date: Thu, 13 Jun 2024 02:06:41 -0700 Subject: [PATCH] Support mixed (bf16, sfp) tiled MatMul. Same sfp-decompress strategy as in (f32, sfp) tiled MatMul. PiperOrigin-RevId: 642901844 --- gemma/ops.h | 112 ++++++++++++++++++++++++++++++++++++++++++---- gemma/ops_test.cc | 2 + 2 files changed, 106 insertions(+), 8 deletions(-) diff --git a/gemma/ops.h b/gemma/ops.h index ca6fa8e..7b29b9f 100644 --- a/gemma/ops.h +++ b/gemma/ops.h @@ -345,9 +345,9 @@ HWY_INLINE void GEMM_4x4_Tile(const MatT* HWY_RESTRICT A, c23, c30, c31, c32, c33, tile_c, stride_c); } -// As above, for SfpStream. -template -HWY_INLINE void GEMM_4x4_Tile(const float* HWY_RESTRICT A, +// Same as above, but with mixed Mat types: (f32, sfp)). +template +HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A, const SfpStream* HWY_RESTRICT B, float* HWY_RESTRICT C, const size_t idx_tile, const size_t xtiles, const size_t stride_a, @@ -359,7 +359,7 @@ HWY_INLINE void GEMM_4x4_Tile(const float* HWY_RESTRICT A, const size_t row_a = idx_tile / xtiles * kRegRows; const size_t row_b_col_c = idx_tile % xtiles * kRegCols; - const hn::ScalableTag d; + const hn::ScalableTag d; const size_t N = Lanes(d); using V = hn::Vec; @@ -383,7 +383,7 @@ HWY_INLINE void GEMM_4x4_Tile(const float* HWY_RESTRICT A, V c32 = hn::Zero(d); V c33 = hn::Zero(d); - const float* HWY_RESTRICT tile_a = A + stride_a * row_a; + const MatTA* HWY_RESTRICT tile_a = A + stride_a * row_a; hwy::AlignedFreeUniquePtr tile_b_unique_ptr = hwy::AllocateAligned(kRegRows * kColsA_RowsB); @@ -432,7 +432,103 @@ HWY_INLINE void GEMM_4x4_Tile(const float* HWY_RESTRICT A, c23, c30, c31, c32, c33, tile_c, stride_c); } -// Same as above, but with mixed Mat types. +// Same as above, but with mixed Mat types: (bf16, sfp)). +template +HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A, + const SfpStream* HWY_RESTRICT B, + float* HWY_RESTRICT C, const size_t idx_tile, + const size_t xtiles, const size_t stride_a, + const size_t stride_b, const size_t stride_c) { + constexpr size_t kRegRows = 4; + constexpr size_t kRegCols = 4; + + // Top-left of tile is (row_a, col_ab) for A, and (row_b_col_c, col_ab) for B. + const size_t row_a = idx_tile / xtiles * kRegRows; + const size_t row_b_col_c = idx_tile % xtiles * kRegCols; + + const hn::ScalableTag d32; + const size_t N = Lanes(d32); + using V = hn::Vec; + // TODO: Using half-vectors for now, it might be faster to + // PromoteLower/UpperTo, and more so to PromoteEven/OddTo if we have packed B + // accordingly. + const hn::Rebind d16; + HWY_DASSERT(Lanes(d16) == Lanes(d32)); + + V c00 = hn::Zero(d32); + V c01 = hn::Zero(d32); + V c02 = hn::Zero(d32); + V c03 = hn::Zero(d32); + + V c10 = hn::Zero(d32); + V c11 = hn::Zero(d32); + V c12 = hn::Zero(d32); + V c13 = hn::Zero(d32); + + V c20 = hn::Zero(d32); + V c21 = hn::Zero(d32); + V c22 = hn::Zero(d32); + V c23 = hn::Zero(d32); + + V c30 = hn::Zero(d32); + V c31 = hn::Zero(d32); + V c32 = hn::Zero(d32); + V c33 = hn::Zero(d32); + + const MatTA* HWY_RESTRICT tile_a = A + stride_a * row_a; + + hwy::AlignedFreeUniquePtr tile_b_unique_ptr = + hwy::AllocateAligned(kRegRows * kColsA_RowsB); + CompressTraits::Decompress( + d32, + /*in_capacity=*/0, B, stride_b * row_b_col_c, tile_b_unique_ptr.get(), + kRegRows * kColsA_RowsB); + const float* HWY_RESTRICT tile_b = tile_b_unique_ptr.get(); + + // Loop over columns of A and columns of the transposed B, in steps of N. + // Accumulates into the c## vectors. + HWY_UNROLL(1) + for (size_t col_ab = 0; col_ab < kColsA_RowsB; col_ab += N) { + const V b0 = hn::LoadU(d32, tile_b + stride_b * 0 + col_ab); + const V b1 = hn::LoadU(d32, tile_b + stride_b * 1 + col_ab); + const V b2 = hn::LoadU(d32, tile_b + stride_b * 2 + col_ab); + const V b3 = hn::LoadU(d32, tile_b + stride_b * 3 + col_ab); + + const V a0 = + hn::PromoteTo(d32, hn::LoadU(d16, tile_a + stride_a * 0 + col_ab)); + c00 = hn::MulAdd(a0, b0, c00); + c01 = hn::MulAdd(a0, b1, c01); + c02 = hn::MulAdd(a0, b2, c02); + c03 = hn::MulAdd(a0, b3, c03); + + const V a1 = + hn::PromoteTo(d32, hn::LoadU(d16, tile_a + stride_a * 1 + col_ab)); + c10 = hn::MulAdd(a1, b0, c10); + c11 = hn::MulAdd(a1, b1, c11); + c12 = hn::MulAdd(a1, b2, c12); + c13 = hn::MulAdd(a1, b3, c13); + + const V a2 = + hn::PromoteTo(d32, hn::LoadU(d16, tile_a + stride_a * 2 + col_ab)); + c20 = hn::MulAdd(a2, b0, c20); + c21 = hn::MulAdd(a2, b1, c21); + c22 = hn::MulAdd(a2, b2, c22); + c23 = hn::MulAdd(a2, b3, c23); + + const V a3 = + hn::PromoteTo(d32, hn::LoadU(d16, tile_a + stride_a * 3 + col_ab)); + c30 = hn::MulAdd(a3, b0, c30); + c31 = hn::MulAdd(a3, b1, c31); + c32 = hn::MulAdd(a3, b2, c32); + c33 = hn::MulAdd(a3, b3, c33); + } + + float* HWY_RESTRICT tile_c = C + stride_c * row_a + row_b_col_c; + StoreHorizontalSums(d32, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21, + c22, c23, c30, c31, c32, c33, tile_c, stride_c); +} + +// Same as above, but with mixed Mat types: (f32, bf16). template HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A, @@ -606,8 +702,8 @@ HWY_INLINE void MatMulSlow(const MatTA* HWY_RESTRICT a, } } -template -HWY_INLINE void MatMulSlow(const float* HWY_RESTRICT a, +template +HWY_INLINE void MatMulSlow(const MatTA* HWY_RESTRICT a, const SfpStream* HWY_RESTRICT b_sfp_stream, float* HWY_RESTRICT out) { const hn::ScalableTag d; diff --git a/gemma/ops_test.cc b/gemma/ops_test.cc index 7c007fd..aec0d7f 100644 --- a/gemma/ops_test.cc +++ b/gemma/ops_test.cc @@ -555,12 +555,14 @@ void TestAllTiledMatMul() { TestTiledMatMul<512, 512, 512, hwy::bfloat16_t>(); TestTiledMatMul<512, 512, 512, float, hwy::bfloat16_t>(); TestTiledMatMul<512, 512, 512, float, SfpStream>(); + TestTiledMatMul<512, 512, 512, hwy::bfloat16_t, SfpStream>(); // minimal non-square test TestTiledMatMul<4, 128, 4, float>(); TestTiledMatMul<4, 128, 4, hwy::bfloat16_t>(); TestTiledMatMul<4, 128, 4, float, hwy::bfloat16_t>(); TestTiledMatMul<32, 128, 32, float, SfpStream>(); + TestTiledMatMul<32, 128, 32, hwy::bfloat16_t, SfpStream>(); // large-scale test // TODO(philculliton): investigate rounding issues with large matrices