From 3e920885950bd98b71330ba8b0a74983bbaa8ad4 Mon Sep 17 00:00:00 2001
From: Andrey Vlasov
Date: Thu, 11 Jul 2024 05:13:03 -0700
Subject: [PATCH] Remove allocation from GEMM_4x4_Tile when decoding compressed
 weights by implementing SfpCodec::Dec2F and CompressTraits::Decompress2 for
 all supported types.

This also allows removing one of the GEMM_4x4_Tile specializations:
compressed MatB is now handled by a single function.

As before, even when MatA is bf16, computations use 32-bit registers.

Measurements for a 2b-it sfp-encoded model on an AMD Ryzen Threadripper PRO
3945WX 12-core:

baseline:
```
32.6254 prefill tokens / sec
8.91429 tokens / sec
115 milliseconds time to first token
```

this change:
```
54.3045 prefill tokens / sec
16.8191 tokens / sec
56 milliseconds time to first token
```

PiperOrigin-RevId: 651369694
---
 BUILD.bazel                |   2 +
 compression/compress-inl.h |  39 +++++-
 compression/sfp-inl.h      |  12 ++
 gemma/ops.h                | 279 +++++++++++++------------------------
 4 files changed, 142 insertions(+), 190 deletions(-)

diff --git a/BUILD.bazel b/BUILD.bazel
index 467bcf1..12abd4c 100644
--- a/BUILD.bazel
+++ b/BUILD.bazel
@@ -25,6 +25,7 @@ cc_library(
     hdrs = ["gemma/ops.h"],
     deps = [
         "//compression:compress",
+        "//compression:sfp",
         "@hwy//:algo",
         "@hwy//:dot",
         "@hwy//:hwy",
@@ -49,6 +50,7 @@ cc_test(
         "//compression:compress",
         "@hwy//:hwy",
         "@hwy//:hwy_test_util",
+        "@hwy//:nanobenchmark",
        "@hwy//:thread_pool",
     ],
 )
diff --git a/compression/compress-inl.h b/compression/compress-inl.h
index b3e943d..1693cf0 100644
--- a/compression/compress-inl.h
+++ b/compression/compress-inl.h
@@ -79,6 +79,15 @@ struct CompressTraits<float> {
     }
   }
 
+  template <class DF>
+  static HWY_INLINE void Decompress2(DF df, const MatT* HWY_RESTRICT in,
+                                     size_t in_ofs, hn::Vec<DF>& f0,
+                                     hn::Vec<DF>& f1) {
+    const size_t N = hn::Lanes(df);
+    f0 = hn::LoadU(df, in + in_ofs);
+    f1 = hn::LoadU(df, in + in_ofs + N);
+  }
+
   template <class DF>
   static HWY_INLINE void Decompress(DF df, size_t /*in_capacity*/,
                                     const MatT* HWY_RESTRICT in, size_t in_ofs,
@@ -88,8 +97,8 @@ struct CompressTraits<float> {
     HWY_DASSERT(num >= 2 * N && num % (2 * N) == 0);
     for (size_t i = 0; i < num; i += 2 * N) {
-      const VF in0 = hn::LoadU(df, in + in_ofs + i);
-      const VF in1 = hn::LoadU(df, in + in_ofs + i + N);
+      VF in0, in1;
+      Decompress2(df, in, in_ofs + i, in0, in1);
       hn::StoreU(in0, df, out + i);
       hn::StoreU(in1, df, out + i + N);
     }
   }
@@ -174,6 +183,17 @@ struct CompressTraits<hwy::bfloat16_t> {
     }
   }
 
+  template <class DF>
+  static HWY_INLINE void Decompress2(DF df, const MatT* HWY_RESTRICT in,
+                                     size_t in_ofs, hn::Vec<DF>& f0,
+                                     hn::Vec<DF>& f1) {
+    const hn::Repartition<hwy::bfloat16_t, DF> dbf;
+    using VBF = hn::Vec<decltype(dbf)>;
+    const VBF in16 = hn::LoadU(dbf, in + in_ofs);
+    f0 = hn::PromoteLowerTo(df, in16);
+    f1 = hn::PromoteUpperTo(df, in16);
+  }
+
   template <class DF>
   static HWY_INLINE void Decompress(DF df, size_t /*in_capacity*/,
                                     const MatT* HWY_RESTRICT in, size_t in_ofs,
@@ -186,9 +206,8 @@ struct CompressTraits<hwy::bfloat16_t> {
     size_t i = 0;
     if (num >= N16) {
       for (i = 0; i <= num - N16; i += N16) {
-        const VBF in16 = hn::LoadU(dbf, in + in_ofs + i);
-        const VF in0 = hn::PromoteLowerTo(df, in16);
-        const VF in1 = hn::PromoteUpperTo(df, in16);
+        VF in0, in1;
+        Decompress2(df, in, in_ofs + i, in0, in1);
         hn::StoreU(in0, df, out + i);
         hn::StoreU(in1, df, out + i + N16 / 2);
       }
@@ -296,6 +315,16 @@ struct CompressTraits<SfpStream> {
     }
   }
 
+  template <class DF>
+  static HWY_INLINE void Decompress2(DF df, const MatT* HWY_RESTRICT in,
+                                     size_t in_ofs, hn::Vec<DF>& f0,
+                                     hn::Vec<DF>& f1) {
+    const hn::Twice<hn::Rebind<uint8_t, DF>> d8;
+    using V8 = hn::Vec<decltype(d8)>;
+    const V8 packed = hn::LoadU(d8, &in->byte + in_ofs);
+    SfpCodec::Dec2F(df, packed, f0, f1);
+  }
+
   template <class D>
   static HWY_INLINE void Decompress(D d, size_t /*in_capacity*/,
                                     const MatT* HWY_RESTRICT in, size_t in_ofs,
diff --git a/compression/sfp-inl.h b/compression/sfp-inl.h
index 86505c7..3be36ec 100644
--- a/compression/sfp-inl.h
+++ b/compression/sfp-inl.h
@@ -535,6 +535,18 @@ class SfpCodec {
     }
  }
 
+  template <class DF, class V8 = hn::Vec<hn::Twice<hn::Rebind<uint8_t, DF>>>>
+  static HWY_INLINE void Dec2F(DF df, V8 packed, hn::Vec<DF>& f0,
+                               hn::Vec<DF>& f1) {
+    const hn::Rebind<hwy::bfloat16_t, DF> dbf;
+    using VBF = hn::Vec<decltype(dbf)>;
+    VBF bf0, bf1;
+    Dec2B(dbf, packed, bf0, bf1);
+    f0 = hn::PromoteTo(df, bf0);
+    f1 = hn::PromoteTo(df, bf1);
+  }
+
  private:
  // Wrappers to avoid code duplication across float/bf16 input types and
  // the main loop/remainder.
diff --git a/gemma/ops.h b/gemma/ops.h
index cdb7541..8a31f70 100644
--- a/gemma/ops.h
+++ b/gemma/ops.h
@@ -272,9 +272,9 @@ HWY_INLINE void GEMM_4x4_Tile(const MatT* HWY_RESTRICT A,
   float* HWY_RESTRICT tile_c = C + stride_c * row_a + row_b_col_c;
   if constexpr (kAdd) {
     const AddT* tile_add = add + row_b_col_c;
-    StoreHorizontalSumsAdd(
-        d, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21, c22, c23,
-        c30, c31, c32, c33, tile_add, tile_c, stride_c);
+    StoreHorizontalSumsAdd(d, c00, c01, c02, c03, c10, c11, c12, c13,
+                           c20, c21, c22, c23, c30, c31, c32, c33,
+                           tile_add, tile_c, stride_c);
   } else {
     StoreHorizontalSums(
         d, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21, c22, c23,
@@ -433,10 +433,10 @@ HWY_INLINE void GEMM_4x4_Tile(const MatT* HWY_RESTRICT A,
 
   float* HWY_RESTRICT tile_c = C + stride_c * row_a + row_b_col_c;
   if constexpr (kAdd) {
-    const AddT* tile_add = add + row_b_col_c;
-    StoreHorizontalSumsAdd(
-        df, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21, c22, c23,
-        c30, c31, c32, c33, tile_add, tile_c, stride_c);
+    const AddT* dd = add + row_b_col_c;
+    StoreHorizontalSumsAdd(df, c00, c01, c02, c03, c10, c11, c12, c13,
+                           c20, c21, c22, c23, c30, c31, c32, c33, dd,
+                           tile_c, stride_c);
   } else {
     StoreHorizontalSums(
         df, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21, c22, c23,
@@ -444,120 +444,31 @@ HWY_INLINE void GEMM_4x4_Tile(const MatT* HWY_RESTRICT A,
   }
 }
 
-// Same as above, but with mixed Mat types: (f32, compressed).
-template <size_t kNumRows, size_t kColsA_RowsB, bool kAdd, typename MatTA,
-          typename MatTB, typename AddT>
-HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
-                              const MatTB* HWY_RESTRICT B,
-                              float* HWY_RESTRICT C,
-                              const AddT* add,
-                              const size_t idx_tile,
-                              const size_t xtiles, const size_t stride_a,
-                              const size_t stride_b, const size_t stride_c) {
-  constexpr size_t kRegRows = 4;
-  constexpr size_t kRegCols = 4;
-  static_assert(kNumRows <= kRegRows);
-
-  // Top-left of tile is (row_a, col_ab) for A, and (row_b_col_c, col_ab) for B.
-  const size_t row_a = idx_tile / xtiles * kRegRows;
-  const size_t row_b_col_c = idx_tile % xtiles * kRegCols;
-
-  const hn::ScalableTag<float> d;
-  const size_t N = Lanes(d);
-  using V = hn::Vec<decltype(d)>;
-
-  V c00 = hn::Zero(d);
-  V c01 = hn::Zero(d);
-  V c02 = hn::Zero(d);
-  V c03 = hn::Zero(d);
-
-  V c10 = hn::Zero(d);
-  V c11 = hn::Zero(d);
-  V c12 = hn::Zero(d);
-  V c13 = hn::Zero(d);
-
-  V c20 = hn::Zero(d);
-  V c21 = hn::Zero(d);
-  V c22 = hn::Zero(d);
-  V c23 = hn::Zero(d);
-
-  V c30 = hn::Zero(d);
-  V c31 = hn::Zero(d);
-  V c32 = hn::Zero(d);
-  V c33 = hn::Zero(d);
-
-  const MatTA* HWY_RESTRICT tile_a = A + stride_a * row_a;
-
-  hwy::AlignedFreeUniquePtr<float[]> tile_b_unique_ptr =
-      hwy::AllocateAligned<float>(kRegRows * kColsA_RowsB);
-  CompressTraits<MatTB>::Decompress(
-      d,
-      /*in_capacity=*/0, B, stride_b * row_b_col_c, tile_b_unique_ptr.get(),
-      kRegRows * kColsA_RowsB);
-  const float* HWY_RESTRICT tile_b = tile_b_unique_ptr.get();
-
-  // Loop over columns of A and columns of the transposed B, in steps of N.
-  // Accumulates into the c## vectors.
-  HWY_UNROLL(1)
-  for (size_t col_ab = 0; col_ab < kColsA_RowsB; col_ab += N) {
-    const V b0 = hn::LoadU(d, tile_b + stride_b * 0 + col_ab);
-    const V b1 = hn::LoadU(d, tile_b + stride_b * 1 + col_ab);
-    const V b2 = hn::LoadU(d, tile_b + stride_b * 2 + col_ab);
-    const V b3 = hn::LoadU(d, tile_b + stride_b * 3 + col_ab);
-
-    const V a0 = hn::LoadU(d, tile_a + stride_a * 0 + col_ab);
-    c00 = hn::MulAdd(a0, b0, c00);
-    c01 = hn::MulAdd(a0, b1, c01);
-    c02 = hn::MulAdd(a0, b2, c02);
-    c03 = hn::MulAdd(a0, b3, c03);
-    if constexpr (kNumRows == 1) continue;
-
-    const V a1 = hn::LoadU(d, tile_a + stride_a * 1 + col_ab);
-    c10 = hn::MulAdd(a1, b0, c10);
-    c11 = hn::MulAdd(a1, b1, c11);
-    c12 = hn::MulAdd(a1, b2, c12);
-    c13 = hn::MulAdd(a1, b3, c13);
-    if constexpr (kNumRows == 2) continue;
-
-    const V a2 = hn::LoadU(d, tile_a + stride_a * 2 + col_ab);
-    c20 = hn::MulAdd(a2, b0, c20);
-    c21 = hn::MulAdd(a2, b1, c21);
-    c22 = hn::MulAdd(a2, b2, c22);
-    c23 = hn::MulAdd(a2, b3, c23);
-    if constexpr (kNumRows == 3) continue;
-
-    const V a3 = hn::LoadU(d, tile_a + stride_a * 3 + col_ab);
-    c30 = hn::MulAdd(a3, b0, c30);
-    c31 = hn::MulAdd(a3, b1, c31);
-    c32 = hn::MulAdd(a3, b2, c32);
-    c33 = hn::MulAdd(a3, b3, c33);
-  }
-
-  float* HWY_RESTRICT tile_c = C + stride_c * row_a + row_b_col_c;
-  if constexpr (kAdd) {
-    const AddT* tile_add = add + row_b_col_c;
-    StoreHorizontalSumsAdd(
-        d, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21, c22, c23,
-        c30, c31, c32, c33, tile_add, tile_c, stride_c);
-  } else {
-    StoreHorizontalSums(
-        d, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21, c22, c23,
-        c30, c31, c32, c33, tile_c, stride_c);
-  }
+template <class VF>
+HWY_INLINE void UpdateTileRow(const VF& a0, const VF& a1, const VF& b00,
+                              const VF& b01, const VF& b10, const VF& b11,
+                              const VF& b20, const VF& b21, const VF& b30,
+                              const VF& b31, VF& c0, VF& c1, VF& c2, VF& c3) {
+  c0 = hn::MulAdd(a0, b00, c0);
+  c0 = hn::MulAdd(a1, b01, c0);
+  c1 = hn::MulAdd(a0, b10, c1);
+  c1 = hn::MulAdd(a1, b11, c1);
+  c2 = hn::MulAdd(a0, b20, c2);
+  c2 = hn::MulAdd(a1, b21, c2);
+  c3 = hn::MulAdd(a0, b30, c3);
+  c3 = hn::MulAdd(a1, b31, c3);
 }
 
-// Same as above, but with mixed Mat types: (bf16, compressed).
+// Same as above, but for compressed MatTB, for which
+// CompressTraits<MatTB>::Decompress2 exists.
 template <size_t kNumRows, size_t kColsA_RowsB, bool kAdd, typename MatTA,
-          typename MatTB, typename AddT>
+          typename MatTB, HWY_IF_T_SIZE(MatTB, 1), typename AddT>
 HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
                               const MatTB* HWY_RESTRICT B,
-                              float* HWY_RESTRICT C,
-                              const AddT* add,
-                              const size_t idx_tile,
-                              const size_t xtiles, const size_t stride_a,
-                              const size_t stride_b, const size_t stride_c) {
+                              float* HWY_RESTRICT C, const AddT* add,
+                              const size_t idx_tile, const size_t xtiles,
+                              const size_t stride_a, const size_t stride_b,
+                              const size_t stride_c) {
   constexpr size_t kRegRows = 4;
   constexpr size_t kRegCols = 4;
   static_assert(kNumRows <= kRegRows);
@@ -567,13 +478,8 @@ HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
   const size_t row_b_col_c = idx_tile % xtiles * kRegCols;
 
   const hn::ScalableTag<float> d32;
-  const size_t N = Lanes(d32);
+  const size_t N = hn::Lanes(d32);
   using V = hn::Vec<decltype(d32)>;
-  // TODO: Using half-vectors for now, it might be faster to
-  // PromoteLower/UpperTo, and more so to PromoteEven/OddTo if we have packed B
-  // accordingly.
-  const hn::Rebind<hwy::bfloat16_t, decltype(d32)> d16;
-  HWY_DASSERT(Lanes(d16) == Lanes(d32));
 
   V c00 = hn::Zero(d32);
   V c01 = hn::Zero(d32);
@@ -595,67 +501,70 @@ HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
   V c32 = hn::Zero(d32);
   V c33 = hn::Zero(d32);
 
-  const MatTA* HWY_RESTRICT tile_a = A + stride_a * row_a;
+  const size_t tile_a_ofs = stride_a * row_a;
+  const size_t tile_b_ofs = stride_b * row_b_col_c;
 
-  hwy::AlignedFreeUniquePtr<float[]> tile_b_unique_ptr =
-      hwy::AllocateAligned<float>(kRegRows * kColsA_RowsB);
-  CompressTraits<MatTB>::Decompress(
-      d32,
-      /*in_capacity=*/0, B, stride_b * row_b_col_c, tile_b_unique_ptr.get(),
-      kRegRows * kColsA_RowsB);
-  const float* HWY_RESTRICT tile_b = tile_b_unique_ptr.get();
-
-  // Loop over columns of A and columns of the transposed B, in steps of N.
+  // Loop over columns of A and columns of the transposed B, in steps of 2*N
+  // (since we are decoding consecutive bytes at each iteration).
   // Accumulates into the c## vectors.
-  HWY_UNROLL(1)
-  for (size_t col_ab = 0; col_ab < kColsA_RowsB; col_ab += N) {
-    const V b0 = hn::LoadU(d32, tile_b + stride_b * 0 + col_ab);
-    const V b1 = hn::LoadU(d32, tile_b + stride_b * 1 + col_ab);
-    const V b2 = hn::LoadU(d32, tile_b + stride_b * 2 + col_ab);
-    const V b3 = hn::LoadU(d32, tile_b + stride_b * 3 + col_ab);
+  size_t col_ab = 0;
+
+  HWY_UNROLL(1)
+  for (; col_ab <= kColsA_RowsB - 2 * N; col_ab += 2 * N) {
+    V b00, b01;
+    CompressTraits<MatTB>::Decompress2(
+        d32, B, tile_b_ofs + stride_b * 0 + col_ab, b00, b01);
+    V b10, b11;
+    CompressTraits<MatTB>::Decompress2(
+        d32, B, tile_b_ofs + stride_b * 1 + col_ab, b10, b11);
+    V b20, b21;
+    CompressTraits<MatTB>::Decompress2(
+        d32, B, tile_b_ofs + stride_b * 2 + col_ab, b20, b21);
+    V b30, b31;
+    CompressTraits<MatTB>::Decompress2(
+        d32, B, tile_b_ofs + stride_b * 3 + col_ab, b30, b31);
+
+    V a00, a01;
+    CompressTraits<MatTA>::Decompress2(
+        d32, A, tile_a_ofs + stride_a * 0 + col_ab, a00, a01);
+    UpdateTileRow(a00, a01, b00, b01, b10, b11, b20, b21, b30, b31, c00, c01,
+                  c02, c03);
 
-    const V a0 =
-        hn::PromoteTo(d32, hn::LoadU(d16, tile_a + stride_a * 0 + col_ab));
-    c00 = hn::MulAdd(a0, b0, c00);
-    c01 = hn::MulAdd(a0, b1, c01);
-    c02 = hn::MulAdd(a0, b2, c02);
-    c03 = hn::MulAdd(a0, b3, c03);
     if constexpr (kNumRows == 1) continue;
 
-    const V a1 =
-        hn::PromoteTo(d32, hn::LoadU(d16, tile_a + stride_a * 1 + col_ab));
-    c10 = hn::MulAdd(a1, b0, c10);
-    c11 = hn::MulAdd(a1, b1, c11);
-    c12 = hn::MulAdd(a1, b2, c12);
-    c13 = hn::MulAdd(a1, b3, c13);
+    V a10, a11;
+    CompressTraits<MatTA>::Decompress2(
+        d32, A, tile_a_ofs + stride_a * 1 + col_ab, a10, a11);
+    UpdateTileRow(a10, a11, b00, b01, b10, b11, b20, b21, b30, b31, c10, c11,
+                  c12, c13);
+
     if constexpr (kNumRows == 2) continue;
 
-    const V a2 =
-        hn::PromoteTo(d32, hn::LoadU(d16, tile_a + stride_a * 2 + col_ab));
-    c20 = hn::MulAdd(a2, b0, c20);
-    c21 = hn::MulAdd(a2, b1, c21);
-    c22 = hn::MulAdd(a2, b2, c22);
-    c23 = hn::MulAdd(a2, b3, c23);
+    V a20, a21;
+    CompressTraits<MatTA>::Decompress2(
+        d32, A, tile_a_ofs + stride_a * 2 + col_ab, a20, a21);
+    UpdateTileRow(a20, a21, b00, b01, b10, b11, b20, b21, b30, b31, c20, c21,
+                  c22, c23);
+
     if constexpr (kNumRows == 3) continue;
 
-    const V a3 =
-        hn::PromoteTo(d32, hn::LoadU(d16, tile_a + stride_a * 3 + col_ab));
-    c30 = hn::MulAdd(a3, b0, c30);
-    c31 = hn::MulAdd(a3, b1, c31);
-    c32 = hn::MulAdd(a3, b2, c32);
-    c33 = hn::MulAdd(a3, b3, c33);
+    V a30, a31;
+    CompressTraits<MatTA>::Decompress2(
+        d32, A, tile_a_ofs + stride_a * 3 + col_ab, a30, a31);
+    UpdateTileRow(a30, a31, b00, b01, b10, b11, b20, b21, b30, b31, c30, c31,
+                  c32, c33);
   }
 
   float* HWY_RESTRICT tile_c = C + stride_c * row_a + row_b_col_c;
   if constexpr (kAdd) {
-    const AddT* tile_add = add + row_b_col_c;
-    StoreHorizontalSumsAdd(
-        d32, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21, c22, c23,
-        c30, c31, c32, c33, tile_add, tile_c, stride_c);
+    const AddT* dd = add + row_b_col_c;
+    StoreHorizontalSumsAdd(d32, c00, c01, c02, c03, c10, c11, c12,
+                           c13, c20, c21, c22, c23, c30, c31, c32,
+                           c33, dd, tile_c, stride_c);
   } else {
-    StoreHorizontalSums(
-        d32, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21, c22, c23,
-        c30, c31, c32, c33, tile_c, stride_c);
+    StoreHorizontalSums(d32, c00, c01, c02, c03, c10, c11, c12, c13,
+                        c20, c21, c22, c23, c30, c31, c32, c33,
+                        tile_c, stride_c);
   }
 }
 
@@ -755,10 +664,10 @@ HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
 
   float* HWY_RESTRICT tile_c = C + stride_c * row_a + row_b_col_c;
   if constexpr (kAdd) {
-    const AddT* tile_add = add + row_b_col_c;
-    StoreHorizontalSumsAdd(
-        d32, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21, c22, c23,
-        c30, c31, c32, c33, tile_add, tile_c, stride_c);
+    const AddT* dd = add + row_b_col_c;
+    StoreHorizontalSumsAdd(d32, c00, c01, c02, c03, c10, c11, c12,
+                           c13, c20, c21, c22, c23, c30, c31, c32,
+                           c33, dd, tile_c, stride_c);
   } else {
     StoreHorizontalSums(
         d32, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21, c22, c23,
@@ -861,10 +770,10 @@ HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
 
   float* HWY_RESTRICT tile_c = C + stride_c * row_a + row_b_col_c;
   if constexpr (kAdd) {
-    const AddT* tile_add = add + row_b_col_c;
-    StoreHorizontalSumsAdd(
-        d32, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21, c22, c23,
-        c30, c31, c32, c33, tile_add, tile_c, stride_c);
+    const AddT* dd = add + row_b_col_c;
+    StoreHorizontalSumsAdd(d32, c00, c01, c02, c03, c10, c11, c12,
+                           c13, c20, c21, c22, c23, c30, c31, c32,
+                           c33, dd, tile_c, stride_c);
   } else {
     StoreHorizontalSums(
         d32, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21, c22, c23,
@@ -906,20 +815,20 @@ HWY_NOINLINE void MatMul_4x4_Batch_Add(
       HWY_ASSERT(num_rows > 0);
       switch (num_rows) {
         case 1:
-          GEMM_4x4_Tile<1, kColsA_RowsB, kAdd>(
-              A, B, C, add, idx_tile, kTilesX, kStrideA, kStrideB, kStrideC);
+          GEMM_4x4_Tile<1, kColsA_RowsB, kAdd>(A, B, C, add, idx_tile, kTilesX,
+                                               kStrideA, kStrideB, kStrideC);
          break;
        case 2:
-          GEMM_4x4_Tile<2, kColsA_RowsB, kAdd>(
-              A, B, C, add, idx_tile, kTilesX, kStrideA, kStrideB, kStrideC);
+          GEMM_4x4_Tile<2, kColsA_RowsB, kAdd>(A, B, C, add, idx_tile, kTilesX,
+                                               kStrideA, kStrideB, kStrideC);
          break;
        case 3:
-          GEMM_4x4_Tile<3, kColsA_RowsB, kAdd>(
-              A, B, C, add, idx_tile, kTilesX, kStrideA, kStrideB, kStrideC);
+          GEMM_4x4_Tile<3, kColsA_RowsB, kAdd>(A, B, C, add, idx_tile, kTilesX,
+                                               kStrideA, kStrideB, kStrideC);
          break;
        default:
-          GEMM_4x4_Tile<4, kColsA_RowsB, kAdd>(
-              A, B, C, add, idx_tile, kTilesX, kStrideA, kStrideB, kStrideC);
+          GEMM_4x4_Tile<4, kColsA_RowsB, kAdd>(A, B, C, add, idx_tile, kTilesX,
+                                               kStrideA, kStrideB, kStrideC);
       }
     });
}
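
Editorial note (not part of the patch): the sketch below illustrates how the new
CompressTraits<MatTB>::Decompress2 entry point can be consumed without the
temporary buffer that the removed specialization allocated, mirroring the 2*N
decode step of the rewritten GEMM_4x4_Tile. It is a minimal, hypothetical
example: the helper name DotRowF32Compressed, the float-only accumulators, and
the assumption that the row length is a multiple of 2 * hn::Lanes(df) are not
part of this change. It assumes compilation inside a Highway SIMD
(foreach_target) translation unit that already includes
compression/compress-inl.h.

```cpp
// Hypothetical usage sketch of Decompress2: dot-product of a float row with a
// compressed (1-byte-per-element) row, decoding two vectors per iteration and
// never materializing a decompressed buffer.
template <size_t kCols, typename MatTB, HWY_IF_T_SIZE(MatTB, 1)>
HWY_INLINE float DotRowF32Compressed(const float* HWY_RESTRICT a,
                                     const MatTB* HWY_RESTRICT b,
                                     size_t b_ofs) {
  const hn::ScalableTag<float> df;
  using VF = hn::Vec<decltype(df)>;
  const size_t N = hn::Lanes(df);
  VF sum0 = hn::Zero(df);
  VF sum1 = hn::Zero(df);
  // Assumes kCols is a multiple of 2 * N, as with the tile sizes used here.
  for (size_t i = 0; i < kCols; i += 2 * N) {
    VF b0, b1;
    // Decodes 2 * N consecutive compressed elements starting at b_ofs + i.
    CompressTraits<MatTB>::Decompress2(df, b, b_ofs + i, b0, b1);
    sum0 = hn::MulAdd(hn::LoadU(df, a + i), b0, sum0);
    sum1 = hn::MulAdd(hn::LoadU(df, a + i + N), b1, sum1);
  }
  return hn::ReduceSum(df, hn::Add(sum0, sum1));
}
```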