From 29c0c574e680ecec2594f19b4786d48a853527f6 Mon Sep 17 00:00:00 2001 From: Jan Wassenberg Date: Fri, 14 Jun 2024 06:32:00 -0700 Subject: [PATCH] Integrate matmul into FFW: 4.3x prefill speedup ``` before, bf16: 27.2929 prefill tokens / sec 17.2114 tokens / sec after, bf16 116.496 prefill tokens / sec 17.5391 tokens / sec ``` PiperOrigin-RevId: 643328437 --- gemma/configs.h | 6 ++-- gemma/gemma.cc | 90 ++++++++++++++++++++++++++++++++--------------- gemma/ops.h | 85 ++++++++++++++++++++++++++++++++++++-------- gemma/ops_test.cc | 22 ------------ 4 files changed, 136 insertions(+), 67 deletions(-) diff --git a/gemma/configs.h b/gemma/configs.h index b59b450..47efed6 100644 --- a/gemma/configs.h +++ b/gemma/configs.h @@ -136,7 +136,7 @@ struct ConfigGemmaTiny : public ConfigNoSSM { using Weight = TWeight; // make accessible where we only have a TConfig static constexpr int kSeqLen = 32; - static constexpr int kVocabSize = 16; + static constexpr int kVocabSize = 64; static constexpr std::array kLayerConfig = FixedLayerConfig<3>(LayerAttentionType::kGemma); static constexpr int kLayers = kLayerConfig.size(); @@ -146,8 +146,8 @@ struct ConfigGemmaTiny : public ConfigNoSSM { NumLayersOfTypeBefore(kLayerConfig, LayerAttentionType::kGriffinRecurrentBlock, kLayers); - static constexpr int kModelDim = 64; - static constexpr int kFFHiddenDim = 128; + static constexpr int kModelDim = 128; + static constexpr int kFFHiddenDim = 256; static constexpr int kHeads = 4; static constexpr int kKVHeads = 1; static constexpr int kQKVDim = 16; // query size == key size == value size diff --git a/gemma/gemma.cc b/gemma/gemma.cc index 5de541b..1716399 100644 --- a/gemma/gemma.cc +++ b/gemma/gemma.cc @@ -88,6 +88,11 @@ struct Activations { att_post2; // accumulation of attention outputs over heads std::array bf_pre_ffw_rms_out; std::array ffw_hidden; + + // For FFW MatMul. + std::array C1; + std::array C2; + // bf_ version can't be used until GeluMulToBF16 issue in FFW() is resolved. // std::array // bf_ffw_hidden; @@ -508,41 +513,70 @@ HWY_NOINLINE void FFW(Activations& activations, static constexpr size_t kFFHiddenDim = TConfig::kFFHiddenDim; float* HWY_RESTRICT even_odd = activations.even_odd.data(); - for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) { - const size_t hidden_offset = batch_idx * kFFHiddenDim * 2; + // TODO: MatMul does not yet support adding another matrix to the result. + if constexpr (!TConfig::kFFBiases) { PROFILER_ZONE("Gen.FFW.GatedGELU"); - const hwy::bfloat16_t* HWY_RESTRICT vec = - activations.bf_pre_ffw_rms_out.data() + batch_idx * kModelDim; - float* HWY_RESTRICT out = activations.ffw_hidden.data() + hidden_offset; - float* HWY_RESTRICT out_mul = out + kFFHiddenDim; - // Same matrix, first and second half of rows. Could fuse into one MatVec. - MatVecT( - layer_weights->gating_einsum_w, kFFHiddenDim * kModelDim, vec, - TConfig::kFFBiases ? - layer_weights->ffw_gating_biases.data() + kFFHiddenDim : nullptr, - even_odd, out_mul, pool); - // Gate, will go through the nonlinearity. - MatVecT( - layer_weights->gating_einsum_w, 0, vec, - layer_weights->ffw_gating_biases.data(), even_odd, out, pool); + // MatMul expects col-major B, which is what we have: kModelDim consecutive + // elements in memory, repeated kFFHiddenDim times. + const auto b1 = layer_weights->gating_einsum_w.data(); + constexpr size_t kColsA = kModelDim; + constexpr size_t kColsB = kFFHiddenDim; + const auto b2 = b1 + kColsA * kColsB; + auto A = activations.bf_pre_ffw_rms_out.data(); + // Will go through GELU. 
+ MatMul_4x4_Batch(num_tokens, A, b1, activations.C1.data(), + pool); + // What to multiply by. + MatMul_4x4_Batch(num_tokens, A, b2, activations.C2.data(), + pool); + // Gelu and multiply by gate. namespace hn = hwy::HWY_NAMESPACE; using DF = hn::ScalableTag; using VF = hn::Vec; - hn::Transform1(DF(), out, kFFHiddenDim, out_mul, - [](DF df, VF v, VF mul) - HWY_ATTR { return hn::Mul(mul, Gelu(df, v)); }); - } + hn::Transform1(DF(), activations.C1.data(), kFFHiddenDim * num_tokens, + activations.C2.data(), [](DF df, VF v, VF mul) HWY_ATTR { + return hn::Mul(mul, Gelu(df, v)); + }); - for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) { - PROFILER_ZONE("Gen.FFW\\GatedGELU"); - const size_t hidden_offset = batch_idx * kFFHiddenDim * 2; - MatVecT( - layer_weights->linear_w, 0, - activations.ffw_hidden.data() + hidden_offset, - layer_weights->ffw_output_biases.data(), even_odd, - activations.ffw_out.data() + batch_idx * kModelDim, pool); + MatMul_4x4_Batch(num_tokens, activations.C1.data(), + layer_weights->linear_w.data(), + activations.ffw_out.data(), pool); + } else { + for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) { + const size_t hidden_offset = batch_idx * kFFHiddenDim * 2; + const hwy::bfloat16_t* HWY_RESTRICT vec = + activations.bf_pre_ffw_rms_out.data() + batch_idx * kModelDim; + float* HWY_RESTRICT out = activations.ffw_hidden.data() + hidden_offset; + float* HWY_RESTRICT out_mul = out + kFFHiddenDim; + + PROFILER_ZONE("Gen.FFW.GatedGELU"); + // Same matrix, first and second half of rows. Could fuse into one MatVec. + MatVecT( + layer_weights->gating_einsum_w, kFFHiddenDim * kModelDim, vec, + TConfig::kFFBiases + ? layer_weights->ffw_gating_biases.data() + kFFHiddenDim + : nullptr, + even_odd, out_mul, pool); + // Gate, will go through the nonlinearity. + MatVecT( + layer_weights->gating_einsum_w, 0, vec, + layer_weights->ffw_gating_biases.data(), even_odd, out, pool); + + namespace hn = hwy::HWY_NAMESPACE; + using DF = hn::ScalableTag; + using VF = hn::Vec; + hn::Transform1(DF(), out, kFFHiddenDim, out_mul, + [](DF df, VF v, VF mul) + HWY_ATTR { return hn::Mul(mul, Gelu(df, v)); }); + + MatVecT( + layer_weights->linear_w, 0, + activations.ffw_hidden.data() + hidden_offset, + layer_weights->ffw_output_biases.data(), even_odd, + activations.ffw_out.data() + batch_idx * kModelDim, pool); + } } } diff --git a/gemma/ops.h b/gemma/ops.h index 93a9a4b..8bb64f2 100644 --- a/gemma/ops.h +++ b/gemma/ops.h @@ -23,6 +23,7 @@ #include #include +#include #include #include // std::enable_if_t @@ -70,6 +71,29 @@ StaticCast(From from) noexcept { return static_cast(from); } +// For testing. +template +void AssertClose(const MatT* HWY_RESTRICT expected, + const MatT* HWY_RESTRICT actual, size_t num) { + for (size_t idx = 0; idx < num; idx++) { + const double expected_value = hwy::ConvertScalarTo(expected[idx]); + const double actual_value = hwy::ConvertScalarTo(actual[idx]); + + const double magnitude = std::abs(expected_value); + + const double tolerance = + 256.0 * hwy::ConvertScalarTo(hwy::Epsilon()) * + HWY_MAX(magnitude, 1.0); + + if (!(expected_value - tolerance <= actual_value && + actual_value <= expected_value + tolerance)) { + fprintf(stderr, "expected[%lu]: %f, actual[%lu]: %f\n", idx, + expected_value, idx, actual_value); + HWY_ASSERT(0); + } + } +} + template HWY_INLINE constexpr size_t RowsPerStrip() { // Aim for 128 work items to reduce pool overhead. 
Must be at least one @@ -362,11 +386,11 @@ HWY_INLINE void GEMM_4x4_Tile(const MatT* HWY_RESTRICT A, c23, c30, c31, c32, c33, tile_c, stride_c); } -// Same as above, but with mixed Mat types: (f32, sfp). +// Same as above, but with mixed Mat types: (f32, compressed). template + HWY_IF_F32(MatTA), typename MatTB, HWY_IF_T_SIZE(MatTB, 1)> HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A, - const SfpStream* HWY_RESTRICT B, + const MatTB* HWY_RESTRICT B, float* HWY_RESTRICT C, const size_t idx_tile, const size_t xtiles, const size_t stride_a, const size_t stride_b, const size_t stride_c) { @@ -406,7 +430,7 @@ HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A, hwy::AlignedFreeUniquePtr tile_b_unique_ptr = hwy::AllocateAligned(kRegRows * kColsA_RowsB); - CompressTraits::Decompress( + CompressTraits::Decompress( d, /*in_capacity=*/0, B, stride_b * row_b_col_c, tile_b_unique_ptr.get(), kRegRows * kColsA_RowsB); @@ -455,11 +479,11 @@ HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A, c23, c30, c31, c32, c33, tile_c, stride_c); } -// Same as above, but with mixed Mat types: (bf16, sfp). +// Same as above, but with mixed Mat types: (bf16, compressed)). template + HWY_IF_BF16(MatTA), typename MatTB, HWY_IF_T_SIZE(MatTB, 1)> HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A, - const SfpStream* HWY_RESTRICT B, + const MatTB* HWY_RESTRICT B, float* HWY_RESTRICT C, const size_t idx_tile, const size_t xtiles, const size_t stride_a, const size_t stride_b, const size_t stride_c) { @@ -504,7 +528,7 @@ HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A, hwy::AlignedFreeUniquePtr tile_b_unique_ptr = hwy::AllocateAligned(kRegRows * kColsA_RowsB); - CompressTraits::Decompress( + CompressTraits::Decompress( d32, /*in_capacity=*/0, B, stride_b * row_b_col_c, tile_b_unique_ptr.get(), kRegRows * kColsA_RowsB); @@ -806,7 +830,37 @@ HWY_NOINLINE void MatMul_4x4_Batch( // Largely unoptimized; reordered innermost loops nets ~5-10X speedup on // ops_test across instruction sets. -template +template +HWY_INLINE void MatMulSlow(const MatTA* HWY_RESTRICT a, + const MatTB* HWY_RESTRICT b, + float* HWY_RESTRICT out) { + for (size_t i = 0; i < kM; ++i) { + for (size_t k = 0; k < kN; ++k) { + for (size_t j = 0; j < kK; ++j) { + const float a1 = hwy::ConvertScalarTo(a[i * kN + k]); + const float b1 = hwy::ConvertScalarTo(b[k * kK + j]); + out[i * kK + j] += a1 * b1; + } + } + } +} + +template +HWY_INLINE void MatMulSlow(const MatTA* HWY_RESTRICT a, + const SfpStream* HWY_RESTRICT b_sfp_stream, + float* HWY_RESTRICT out) { + const hn::ScalableTag d; + hwy::AlignedFreeUniquePtr b = hwy::AllocateAligned(kK * kN); + CompressTraits::Decompress(d, + /*in_capacity=*/0, b_sfp_stream, 0, + b.get(), kK * kN); + MatMulSlow(a, b.get(), out); +} + +// Largely unoptimized; reordered innermost loops nets ~5-10X speedup on +// ops_test across instruction sets. +template HWY_INLINE void MatMulSlowBatch(size_t batch_size, const MatTA* HWY_RESTRICT a, const MatTB* HWY_RESTRICT b, float* HWY_RESTRICT out) { @@ -821,15 +875,18 @@ HWY_INLINE void MatMulSlowBatch(size_t batch_size, const MatTA* HWY_RESTRICT a, } } -template +// The above overload can handle combinations of f32 and bf16, but this one +// is required for MatTB = {SFP, NUQ}. 
+template HWY_INLINE void MatMulSlowBatch(size_t batch_size, const MatTA* HWY_RESTRICT a, - const SfpStream* HWY_RESTRICT b_sfp_stream, + const MatTB* HWY_RESTRICT b_compr, float* HWY_RESTRICT out) { const hn::ScalableTag d; hwy::AlignedFreeUniquePtr b = hwy::AllocateAligned(kK * kN); - CompressTraits::Decompress(d, - /*in_capacity=*/0, b_sfp_stream, 0, - b.get(), kK * kN); + CompressTraits::Decompress(d, + /*in_capacity=*/0, b_compr, 0, b.get(), + kK * kN); MatMulSlowBatch(batch_size, a, b.get(), out); } diff --git a/gemma/ops_test.cc b/gemma/ops_test.cc index c9efde1..fe48833 100644 --- a/gemma/ops_test.cc +++ b/gemma/ops_test.cc @@ -506,28 +506,6 @@ void AssertClose(const hwy::AlignedFreeUniquePtr& a, } } -template -void AssertClose(const MatT* HWY_RESTRICT expected, - const MatT* HWY_RESTRICT actual, size_t num) { - for (size_t idx = 0; idx < num; idx++) { - const double expected_value = hwy::ConvertScalarTo(expected[idx]); - const double actual_value = hwy::ConvertScalarTo(actual[idx]); - - const double magnitude = std::abs(expected_value); - - const double tolerance = - 64.0 * hwy::ConvertScalarTo(hwy::Epsilon()) * - HWY_MAX(magnitude, 1.0); - - if (!(expected_value - tolerance <= actual_value && - actual_value <= expected_value + tolerance)) { - fprintf(stderr, "expected[%lu]: %f, actual[%lu]: %f, tolerance: %f\n", - idx, expected_value, idx, actual_value, tolerance); - HWY_ASSERT(0); - } - } -} - template void TestTiledBatchMatMul() {
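
For readers following the new code path in `FFW()`: the MatMul branch computes the whole gated FFW for all `num_tokens` at once, as three matrix products plus one elementwise pass. Below is a minimal standalone sketch of that data flow, with plain scalar loops standing in for `MatMul_4x4_Batch` and a tanh-approximation GELU standing in for the Highway `Gelu`; the function and parameter names are illustrative, and only the dimension roles and the B layout follow the patch.

```cpp
// Minimal standalone sketch of the gated-FFW data flow the MatMul path
// implements. Illustrative only; not the patch's kernels.
#include <cmath>
#include <cstddef>
#include <vector>

// C (tokens x cols) = A (tokens x k) * B, where B is stored as `cols` rows of
// `k` consecutive elements -- the layout the patch describes for
// gating_einsum_w ("kModelDim consecutive elements, repeated kFFHiddenDim
// times"), i.e. col-major from MatMul's point of view.
static void MatMulColMajorB(const float* A, const float* B, float* C,
                            size_t tokens, size_t k, size_t cols) {
  for (size_t t = 0; t < tokens; ++t) {
    for (size_t c = 0; c < cols; ++c) {
      float sum = 0.0f;
      for (size_t i = 0; i < k; ++i) sum += A[t * k + i] * B[c * k + i];
      C[t * cols + c] = sum;
    }
  }
}

// Tanh-approximation GELU (assumed here to match the Gelu used in ops.h).
static float GeluApprox(float x) {
  constexpr float kSqrt2OverPi = 0.7978845608f;
  return 0.5f * x *
         (1.0f + std::tanh(kSqrt2OverPi * (x + 0.044715f * x * x * x)));
}

// hidden = Gelu(x * W1) * (x * W2); out = hidden * W_linear.
// W1 and W2 correspond to the first and second halves of gating_einsum_w.
void GatedFFWSketch(const float* x, const float* w1, const float* w2,
                    const float* w_linear, float* out, size_t num_tokens,
                    size_t model_dim, size_t ff_hidden_dim) {
  std::vector<float> C1(num_tokens * ff_hidden_dim);  // goes through GELU
  std::vector<float> C2(num_tokens * ff_hidden_dim);  // what to multiply by
  MatMulColMajorB(x, w1, C1.data(), num_tokens, model_dim, ff_hidden_dim);
  MatMulColMajorB(x, w2, C2.data(), num_tokens, model_dim, ff_hidden_dim);
  for (size_t i = 0; i < num_tokens * ff_hidden_dim; ++i) {
    C1[i] = GeluApprox(C1[i]) * C2[i];  // gate through the nonlinearity
  }
  MatMulColMajorB(C1.data(), w_linear, out, num_tokens, ff_hidden_dim,
                  model_dim);
}
```

The per-token `MatVecT` path is kept for configurations with `kFFBiases` because, as the TODO in the patch notes, MatMul does not yet support adding another matrix (the bias) to the result.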
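The patch also moves `AssertClose` into ops.h (with a wider 256·ε tolerance) and adds `MatMulSlow`/`MatMulSlowBatch` reference kernels so the tiled MatMul can be validated in tests. A rough sketch of that checking pattern for plain float/double buffers, using `std::numeric_limits` in place of `hwy::Epsilon`; the names here are illustrative, not the patch's API.

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdio>
#include <limits>

// Element-wise comparison of a fast kernel's output against a slow reference,
// with a tolerance scaled by the element type's epsilon and the expected
// magnitude -- the same shape as the AssertClose the patch adds to ops.h.
template <typename T>
bool AllClose(const T* expected, const T* actual, size_t num,
              double scale = 256.0) {
  const double eps = static_cast<double>(std::numeric_limits<T>::epsilon());
  for (size_t i = 0; i < num; ++i) {
    const double e = static_cast<double>(expected[i]);
    const double a = static_cast<double>(actual[i]);
    const double tolerance = scale * eps * std::max(std::abs(e), 1.0);
    if (!(e - tolerance <= a && a <= e + tolerance)) {
      std::fprintf(stderr, "mismatch at %zu: expected %f, actual %f\n", i, e,
                   a);
      return false;
    }
  }
  return true;
}
```

In a test along the lines of `TestTiledBatchMatMul`, the expected values would come from the slow reference (`MatMulSlowBatch`) and the actual values from `MatMul_4x4_Batch` over the same inputs.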