From f467670de7d1925d93a5c1ab4daa03371214a1de Mon Sep 17 00:00:00 2001
From: "The gemma.cpp Authors"
Date: Wed, 12 Jun 2024 01:11:28 -0700
Subject: [PATCH] Implement float * SfpStream matmul by decompressing 4 *
 kColsA_RowsB -sized chunks of the second matrix.

PiperOrigin-RevId: 642533996
---
 gemma/ops.h       | 100 ++++++++++++++++++++++++++++++++++++++++++++++
 gemma/ops_test.cc |  13 +++---
 2 files changed, 107 insertions(+), 6 deletions(-)

diff --git a/gemma/ops.h b/gemma/ops.h
index 9089745..ff6ba58 100644
--- a/gemma/ops.h
+++ b/gemma/ops.h
@@ -27,6 +27,7 @@
 #include <type_traits>  // std::enable_if_t
 
 #include "compression/sfp.h"
+#include "hwy/aligned_allocator.h"
 #include "hwy/base.h"
 #include "hwy/contrib/thread_pool/thread_pool.h"
 #include "hwy/detect_targets.h"
@@ -344,6 +345,93 @@ HWY_INLINE void GEMM_4x4_Tile(const MatT* HWY_RESTRICT A,
                       c23, c30, c31, c32, c33, tile_c, stride_c);
 }
 
+// As above, for SfpStream.
+template <size_t kColsA_RowsB>
+HWY_INLINE void GEMM_4x4_Tile(const float* HWY_RESTRICT A,
+                              const SfpStream* HWY_RESTRICT B,
+                              float* HWY_RESTRICT C, const size_t idx_tile,
+                              const size_t xtiles, const size_t stride_a,
+                              const size_t stride_b, const size_t stride_c) {
+  constexpr size_t kRegRows = 4;
+  constexpr size_t kRegCols = 4;
+
+  // Top-left of tile is (row_a, col_ab) for A, and (row_b_col_c, col_ab) for
+  // B.
+  const size_t row_a = idx_tile / xtiles * kRegRows;
+  const size_t row_b_col_c = idx_tile % xtiles * kRegCols;
+
+  const hn::ScalableTag<float> d;
+  const size_t N = Lanes(d);
+  using V = hn::Vec<decltype(d)>;
+
+  V c00 = hn::Zero(d);
+  V c01 = hn::Zero(d);
+  V c02 = hn::Zero(d);
+  V c03 = hn::Zero(d);
+
+  V c10 = hn::Zero(d);
+  V c11 = hn::Zero(d);
+  V c12 = hn::Zero(d);
+  V c13 = hn::Zero(d);
+
+  V c20 = hn::Zero(d);
+  V c21 = hn::Zero(d);
+  V c22 = hn::Zero(d);
+  V c23 = hn::Zero(d);
+
+  V c30 = hn::Zero(d);
+  V c31 = hn::Zero(d);
+  V c32 = hn::Zero(d);
+  V c33 = hn::Zero(d);
+
+  const float* HWY_RESTRICT tile_a = A + stride_a * row_a;
+
+  hwy::AlignedFreeUniquePtr<float[]> tile_b_unique_ptr =
+      hwy::AllocateAligned<float>(kRegRows * kColsA_RowsB);
+  CompressTraits<SfpStream>::Decompress(
+      d,
+      /*in_capacity=*/0, B, stride_b * row_b_col_c, tile_b_unique_ptr.get(),
+      kRegRows * kColsA_RowsB);
+  const float* HWY_RESTRICT tile_b = tile_b_unique_ptr.get();
+
+  // Loop over columns of A and columns of the transposed B, in steps of N.
+  // Accumulates into the c## vectors.
+  HWY_UNROLL(1)
+  for (size_t col_ab = 0; col_ab < kColsA_RowsB; col_ab += N) {
+    const V b0 = hn::LoadU(d, tile_b + stride_b * 0 + col_ab);
+    const V b1 = hn::LoadU(d, tile_b + stride_b * 1 + col_ab);
+    const V b2 = hn::LoadU(d, tile_b + stride_b * 2 + col_ab);
+    const V b3 = hn::LoadU(d, tile_b + stride_b * 3 + col_ab);
+
+    const V a0 = hn::LoadU(d, tile_a + stride_a * 0 + col_ab);
+    c00 = hn::MulAdd(a0, b0, c00);
+    c01 = hn::MulAdd(a0, b1, c01);
+    c02 = hn::MulAdd(a0, b2, c02);
+    c03 = hn::MulAdd(a0, b3, c03);
+
+    const V a1 = hn::LoadU(d, tile_a + stride_a * 1 + col_ab);
+    c10 = hn::MulAdd(a1, b0, c10);
+    c11 = hn::MulAdd(a1, b1, c11);
+    c12 = hn::MulAdd(a1, b2, c12);
+    c13 = hn::MulAdd(a1, b3, c13);
+
+    const V a2 = hn::LoadU(d, tile_a + stride_a * 2 + col_ab);
+    c20 = hn::MulAdd(a2, b0, c20);
+    c21 = hn::MulAdd(a2, b1, c21);
+    c22 = hn::MulAdd(a2, b2, c22);
+    c23 = hn::MulAdd(a2, b3, c23);
+
+    const V a3 = hn::LoadU(d, tile_a + stride_a * 3 + col_ab);
+    c30 = hn::MulAdd(a3, b0, c30);
+    c31 = hn::MulAdd(a3, b1, c31);
+    c32 = hn::MulAdd(a3, b2, c32);
+    c33 = hn::MulAdd(a3, b3, c33);
+  }
+
+  float* HWY_RESTRICT tile_c = C + stride_c * row_a + row_b_col_c;
+  StoreHorizontalSums(d, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21, c22,
+                      c23, c30, c31, c32, c33, tile_c, stride_c);
+}
+
 // Same as above, but with mixed Mat types.
 template <size_t kColsA_RowsB, typename MatTA, typename MatTB>
@@ -518,6 +606,18 @@ HWY_INLINE void MatMulSlow(const MatTA* HWY_RESTRICT a,
   }
 }
 
+template <size_t kM, size_t kN, size_t kK>
+HWY_INLINE void MatMulSlow(const float* HWY_RESTRICT a,
+                           const SfpStream* HWY_RESTRICT b_sfp_stream,
+                           float* HWY_RESTRICT out) {
+  const hn::ScalableTag<float> d;
+  hwy::AlignedFreeUniquePtr<float[]> b = hwy::AllocateAligned<float>(kK * kN);
+  CompressTraits<SfpStream>::Decompress(d,
+                                        /*in_capacity=*/0, b_sfp_stream, 0,
+                                        b.get(), kK * kN);
+  MatMulSlow<kM, kN, kK>(a, b.get(), out);
+}
+
 HWY_INLINE void ToEvenOddF32(const hwy::bfloat16_t* HWY_RESTRICT vec_aligned,
                              const size_t size, float* HWY_RESTRICT out) {
   const hn::ScalableTag<float> df;
diff --git a/gemma/ops_test.cc b/gemma/ops_test.cc
index abb7dae..7c007fd 100644
--- a/gemma/ops_test.cc
+++ b/gemma/ops_test.cc
@@ -293,8 +293,8 @@ struct TestSoftmax {
     for (size_t i = 0; i < count; ++i) {
       sum += x[i];
       double rel = std::abs(x[i] - e[i]) / e[i];
-      ASSERT_LT(rel, 1e-6)
-          << "Mismatch on coordinate " << i << " out of " << count;
+      ASSERT_LT(rel, 1e-6) << "Mismatch on coordinate " << i << " out of "
+                           << count;
     }
     ASSERT_NEAR(sum, 1.0, 1e-6);
   }
@@ -388,7 +388,7 @@ std::unique_ptr<CompressedArray<MatT, kOuter * kInner>> GenerateMatHeap(
       new CompressedArray<MatT, kOuter * kInner>);
   hwy::AlignedFreeUniquePtr<float[]> content =
       hwy::AllocateAligned<float>(kOuter * kInner);
-  const float scale = 1.0f / kInner;
+  const float scale = 1.875f / (kInner * kOuter + offset);
   pool.Run(0, kOuter, [&](const size_t i, size_t /*thread*/) {
     for (size_t j = 0; j < kInner; j++) {
       content[i * kInner + j] =
@@ -411,7 +411,7 @@ GenerateTransposeMatHeap(size_t offset, hwy::ThreadPool& pool) {
       new CompressedArray<MatT, kOuter * kInner>);
   hwy::AlignedFreeUniquePtr<float[]> content =
       hwy::AllocateAligned<float>(kOuter * kInner);
-  const float scale = 1.0f / kInner;
+  const float scale = 1.875f / (kInner * kOuter + offset);
   pool.Run(0, kOuter, [&](const size_t i, size_t /*thread*/) {
     for (size_t j = 0; j < kInner; j++) {
       content[j * kOuter + i] =
@@ -554,17 +554,17 @@ void TestAllTiledMatMul() {
   TestTiledMatMul<512, 512, 512, float>();
   TestTiledMatMul<512, 512, 512, hwy::bfloat16_t>();
   TestTiledMatMul<512, 512, 512, float, hwy::bfloat16_t>();
+  TestTiledMatMul<512, 512, 512, float, SfpStream>();
 
   // minimal non-square test
   TestTiledMatMul<4, 128, 4, float>();
   TestTiledMatMul<4, 128, 4, hwy::bfloat16_t>();
   TestTiledMatMul<4, 128, 4, float, hwy::bfloat16_t>();
+  TestTiledMatMul<32, 128, 32, float, SfpStream>();
 
   // large-scale test
   // TODO(philculliton): investigate rounding issues with large matrices
   TestTiledMatMul<512, 24576, 3072, float>();
-
-  // TODO(janwas): SFP
 }
 
 void TestMatVecAdd() {
@@ -666,6 +666,7 @@ HWY_AFTER_NAMESPACE();
 
 namespace gcpp {
 HWY_BEFORE_TEST(OpsTest);
+
 HWY_EXPORT_AND_TEST_P(OpsTest, TestAllAddFrom);
 HWY_EXPORT_AND_TEST_P(OpsTest, TestAllMulBy);
 HWY_EXPORT_AND_TEST_P(OpsTest, TestAllMulByConst);
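
Note on the kernel: the new GEMM_4x4_Tile overload keeps the 4x4 register-tile
layout of the existing float and bfloat16 kernels; the only new step is
decompressing the four rows of B that the tile touches (4 * kColsA_RowsB
values) from SfpStream into a float scratch buffer before the usual FMA loop.
The scalar sketch below illustrates the indexing the tile performs once B is
decompressed. It is a minimal illustration, not code from the patch: it
assumes B is stored transposed (rows of B^T are columns of B) and that the B
stride equals kColsA_RowsB, and the names (GemmTileScalar, kM/kN/kK) are
invented for the example.

  #include <array>
  #include <cstddef>
  #include <cstdio>

  constexpr size_t kM = 8, kN = 8, kK = 16;     // illustrative dimensions
  constexpr size_t kRegRows = 4, kRegCols = 4;  // 4x4 tile, as in the patch

  // Computes one 4x4 tile of C = A * B, where a is kM x kK row-major and bt
  // holds B transposed (kN x kK row-major). In the patch, the four bt rows
  // read here are first decompressed from SfpStream into a float buffer; the
  // arithmetic is unchanged, each sum mirrors one c## accumulator vector.
  void GemmTileScalar(const float* a, const float* bt, float* c,
                      size_t idx_tile, size_t xtiles) {
    const size_t row_a = idx_tile / xtiles * kRegRows;
    const size_t row_bt = idx_tile % xtiles * kRegCols;
    for (size_t i = 0; i < kRegRows; ++i) {
      for (size_t j = 0; j < kRegCols; ++j) {
        float sum = 0.0f;  // plays the role of one c## accumulator
        for (size_t k = 0; k < kK; ++k) {
          sum += a[(row_a + i) * kK + k] * bt[(row_bt + j) * kK + k];
        }
        c[(row_a + i) * kN + row_bt + j] = sum;
      }
    }
  }

  int main() {
    std::array<float, kM * kK> a;
    std::array<float, kN * kK> bt;
    std::array<float, kM * kN> c{};
    a.fill(1.0f);
    bt.fill(0.5f);
    const size_t xtiles = kN / kRegCols;
    for (size_t t = 0; t < (kM / kRegRows) * xtiles; ++t) {
      GemmTileScalar(a.data(), bt.data(), c.data(), t, xtiles);
    }
    std::printf("C[0][0] = %.1f (expect %.1f)\n", c[0], 0.5f * kK);
    return 0;
  }

One consequence of this structure is that the kernel allocates and
decompresses a fresh 4 * kColsA_RowsB buffer for every tile; that keeps the
change local and simple, at the cost of one heap allocation per tile.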
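
Note on the test changes: the previous scale of 1.0f / kInner let generated
entries grow past what SfpStream can hold, since they increase with
i * kInner + j. The new scale, 1.875f / (kInner * kOuter + offset), caps the
largest entry just under 1.875, which corresponds to the maximum magnitude the
8-bit SFP format in compression/sfp.h is designed to represent; this is what
allows the SfpStream variants of TestTiledMatMul to be enabled, replacing the
"TODO(janwas): SFP" marker. The check below is a small sketch of that bound;
the exact fill expression is truncated in the hunk, so it assumes entry (i, j)
is (i * kInner + j + offset) * scale, matching the visible context, and the
dimensions are illustrative.

  #include <cmath>
  #include <cstddef>
  #include <cstdio>

  int main() {
    // Illustrative dimensions; GenerateMatHeap's kOuter/kInner are per-test.
    constexpr size_t kOuter = 32, kInner = 128, kOffset = 0;
    const float scale = 1.875f / (kInner * kOuter + kOffset);
    float max_abs = 0.0f;
    for (size_t i = 0; i < kOuter; ++i) {
      for (size_t j = 0; j < kInner; ++j) {
        // Assumed fill rule, reconstructed from the truncated hunk context.
        const float v = (i * kInner + j + kOffset) * scale;
        max_abs = std::fmax(max_abs, std::fabs(v));
      }
    }
    // Every generated value stays within SFP's representable range.
    std::printf("max |entry| = %f (< 1.875)\n", max_abs);
    return 0;
  }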