Integrate matmul into FFW: 4.3x prefill speedup

```
before, bf16:
27.2929 prefill tokens / sec
17.2114 tokens / sec

after, bf16:
116.496 prefill tokens / sec
17.5391 tokens / sec
```
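The headline factor follows directly from the prefill numbers above; decode throughput is essentially unchanged (17.21 → 17.54 tokens / sec), which is consistent with the change mainly paying off for the multi-token prefill batch:

```
116.496 / 27.2929 ≈ 4.27, i.e. roughly the quoted 4.3x
```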

PiperOrigin-RevId: 643328437
Jan Wassenberg 2024-06-14 06:32:00 -07:00 committed by Copybara-Service
parent 198326a682
commit 29c0c574e6
4 changed files with 136 additions and 67 deletions

View File

@@ -136,7 +136,7 @@ struct ConfigGemmaTiny : public ConfigNoSSM {
   using Weight = TWeight;  // make accessible where we only have a TConfig
 
   static constexpr int kSeqLen = 32;
-  static constexpr int kVocabSize = 16;
+  static constexpr int kVocabSize = 64;
   static constexpr std::array<LayerAttentionType, 3> kLayerConfig =
       FixedLayerConfig<3>(LayerAttentionType::kGemma);
   static constexpr int kLayers = kLayerConfig.size();
@@ -146,8 +146,8 @@ struct ConfigGemmaTiny : public ConfigNoSSM {
       NumLayersOfTypeBefore(kLayerConfig,
                             LayerAttentionType::kGriffinRecurrentBlock,
                             kLayers);
-  static constexpr int kModelDim = 64;
-  static constexpr int kFFHiddenDim = 128;
+  static constexpr int kModelDim = 128;
+  static constexpr int kFFHiddenDim = 256;
   static constexpr int kHeads = 4;
   static constexpr int kKVHeads = 1;
   static constexpr int kQKVDim = 16;  // query size == key size == value size

View File

@@ -88,6 +88,11 @@ struct Activations {
       att_post2;  // accumulation of attention outputs over heads
   std::array<hwy::bfloat16_t, kBatchSize * kModelDim> bf_pre_ffw_rms_out;
   std::array<float, kBatchSize * TConfig::kFFHiddenDim * 2> ffw_hidden;
+  // For FFW MatMul.
+  std::array<float, kBatchSize * TConfig::kFFHiddenDim> C1;
+  std::array<float, kBatchSize * TConfig::kFFHiddenDim> C2;
   // bf_ version can't be used until GeluMulToBF16 issue in FFW() is resolved.
   // std::array<hwy::bfloat16_t, kBatchSize * 2 * TConfig::kFFHiddenDim>
   //     bf_ffw_hidden;
@@ -508,41 +513,70 @@ HWY_NOINLINE void FFW(Activations<TConfig, kBatchSize>& activations,
   static constexpr size_t kFFHiddenDim = TConfig::kFFHiddenDim;
   float* HWY_RESTRICT even_odd = activations.even_odd.data();
-  for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
-    const size_t hidden_offset = batch_idx * kFFHiddenDim * 2;
-    PROFILER_ZONE("Gen.FFW.GatedGELU");
-    const hwy::bfloat16_t* HWY_RESTRICT vec =
-        activations.bf_pre_ffw_rms_out.data() + batch_idx * kModelDim;
-    float* HWY_RESTRICT out = activations.ffw_hidden.data() + hidden_offset;
-    float* HWY_RESTRICT out_mul = out + kFFHiddenDim;
-    // Same matrix, first and second half of rows. Could fuse into one MatVec.
-    MatVecT</*kAdd=*/TConfig::kFFBiases, kFFHiddenDim, kModelDim>(
-        layer_weights->gating_einsum_w, kFFHiddenDim * kModelDim, vec,
-        TConfig::kFFBiases ?
-            layer_weights->ffw_gating_biases.data() + kFFHiddenDim : nullptr,
-        even_odd, out_mul, pool);
-    // Gate, will go through the nonlinearity.
-    MatVecT</*kAdd=*/TConfig::kFFBiases, kFFHiddenDim, kModelDim>(
-        layer_weights->gating_einsum_w, 0, vec,
-        layer_weights->ffw_gating_biases.data(), even_odd, out, pool);
-    namespace hn = hwy::HWY_NAMESPACE;
-    using DF = hn::ScalableTag<float>;
-    using VF = hn::Vec<DF>;
-    hn::Transform1(DF(), out, kFFHiddenDim, out_mul,
-                   [](DF df, VF v, VF mul)
-                       HWY_ATTR { return hn::Mul(mul, Gelu(df, v)); });
-  }
-  for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
-    PROFILER_ZONE("Gen.FFW\\GatedGELU");
-    const size_t hidden_offset = batch_idx * kFFHiddenDim * 2;
-    MatVecT</*kAdd=*/TConfig::kFFBiases, kModelDim, kFFHiddenDim>(
-        layer_weights->linear_w, 0,
-        activations.ffw_hidden.data() + hidden_offset,
-        layer_weights->ffw_output_biases.data(), even_odd,
-        activations.ffw_out.data() + batch_idx * kModelDim, pool);
-  }
+  // TODO: MatMul does not yet support adding another matrix to the result.
+  if constexpr (!TConfig::kFFBiases) {
+    PROFILER_ZONE("Gen.FFW.GatedGELU");
+    // MatMul expects col-major B, which is what we have: kModelDim consecutive
+    // elements in memory, repeated kFFHiddenDim times.
+    const auto b1 = layer_weights->gating_einsum_w.data();
+    constexpr size_t kColsA = kModelDim;
+    constexpr size_t kColsB = kFFHiddenDim;
+    const auto b2 = b1 + kColsA * kColsB;
+    auto A = activations.bf_pre_ffw_rms_out.data();
+    // Will go through GELU.
+    MatMul_4x4_Batch<kColsA, kColsB>(num_tokens, A, b1, activations.C1.data(),
+                                     pool);
+    // What to multiply by.
+    MatMul_4x4_Batch<kColsA, kColsB>(num_tokens, A, b2, activations.C2.data(),
+                                     pool);
+    // Gelu and multiply by gate.
+    namespace hn = hwy::HWY_NAMESPACE;
+    using DF = hn::ScalableTag<float>;
+    using VF = hn::Vec<DF>;
+    hn::Transform1(DF(), activations.C1.data(), kFFHiddenDim * num_tokens,
+                   activations.C2.data(), [](DF df, VF v, VF mul) HWY_ATTR {
+                     return hn::Mul(mul, Gelu(df, v));
+                   });
+    MatMul_4x4_Batch<kFFHiddenDim, kModelDim>(num_tokens, activations.C1.data(),
+                                              layer_weights->linear_w.data(),
+                                              activations.ffw_out.data(), pool);
+  } else {
+    for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
+      const size_t hidden_offset = batch_idx * kFFHiddenDim * 2;
+      const hwy::bfloat16_t* HWY_RESTRICT vec =
+          activations.bf_pre_ffw_rms_out.data() + batch_idx * kModelDim;
+      float* HWY_RESTRICT out = activations.ffw_hidden.data() + hidden_offset;
+      float* HWY_RESTRICT out_mul = out + kFFHiddenDim;
+      PROFILER_ZONE("Gen.FFW.GatedGELU");
+      // Same matrix, first and second half of rows. Could fuse into one MatVec.
+      MatVecT<TConfig::kFFBiases, kFFHiddenDim, kModelDim>(
+          layer_weights->gating_einsum_w, kFFHiddenDim * kModelDim, vec,
+          TConfig::kFFBiases
+              ? layer_weights->ffw_gating_biases.data() + kFFHiddenDim
+              : nullptr,
+          even_odd, out_mul, pool);
+      // Gate, will go through the nonlinearity.
+      MatVecT<TConfig::kFFBiases, kFFHiddenDim, kModelDim>(
+          layer_weights->gating_einsum_w, 0, vec,
+          layer_weights->ffw_gating_biases.data(), even_odd, out, pool);
+      namespace hn = hwy::HWY_NAMESPACE;
+      using DF = hn::ScalableTag<float>;
+      using VF = hn::Vec<DF>;
+      hn::Transform1(DF(), out, kFFHiddenDim, out_mul,
+                     [](DF df, VF v, VF mul)
+                         HWY_ATTR { return hn::Mul(mul, Gelu(df, v)); });
+      MatVecT</*kAdd=*/TConfig::kFFBiases, kModelDim, kFFHiddenDim>(
+          layer_weights->linear_w, 0,
+          activations.ffw_hidden.data() + hidden_offset,
+          layer_weights->ffw_output_biases.data(), even_odd,
+          activations.ffw_out.data() + batch_idx * kModelDim, pool);
+    }
+  }
 }
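For readers skimming the diff, the new non-bias path computes C1 = A·B1 and C2 = A·B2 from the two halves of gating_einsum_w, then overwrites C1 with Gelu(C1)·C2, and finally forms ffw_out = C1·linear_w. Below is a minimal scalar sketch of that dataflow, assuming the column-major B layout described in the comment; the function name, std::vector buffers, plain loops, and the tanh-based GELU approximation are illustrative only, not code from the repository:

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// Scalar reference for the gated-FFW MatMul path (illustration only).
// A is row-major [num_tokens x model_dim]. gating_w stores two halves back to
// back, each with ff_hidden columns of model_dim consecutive floats.
// linear_w stores model_dim columns of ff_hidden consecutive floats.
void GatedFfwSketch(size_t num_tokens, size_t model_dim, size_t ff_hidden,
                    const std::vector<float>& A,
                    const std::vector<float>& gating_w,
                    const std::vector<float>& linear_w,
                    std::vector<float>& out) {
  const float* b1 = gating_w.data();             // gate half, goes through GELU
  const float* b2 = b1 + model_dim * ff_hidden;  // "what to multiply by" half
  std::vector<float> hidden(num_tokens * ff_hidden);
  for (size_t t = 0; t < num_tokens; ++t) {
    for (size_t h = 0; h < ff_hidden; ++h) {
      float c1 = 0.0f, c2 = 0.0f;  // C1 and C2 entries for (t, h)
      for (size_t m = 0; m < model_dim; ++m) {
        c1 += A[t * model_dim + m] * b1[h * model_dim + m];
        c2 += A[t * model_dim + m] * b2[h * model_dim + m];
      }
      // Tanh-based GELU approximation of the gate, then multiply by c2,
      // mirroring the hn::Transform1 call in the diff.
      const float gelu =
          0.5f * c1 *
          (1.0f + std::tanh(0.7978845608f * (c1 + 0.044715f * c1 * c1 * c1)));
      hidden[t * ff_hidden + h] = gelu * c2;
    }
  }
  // Down-projection: out[t] = hidden[t] * linear_w, one row per token.
  for (size_t t = 0; t < num_tokens; ++t) {
    for (size_t m = 0; m < model_dim; ++m) {
      float acc = 0.0f;
      for (size_t h = 0; h < ff_hidden; ++h) {
        acc += hidden[t * ff_hidden + h] * linear_w[m * ff_hidden + h];
      }
      out[t * model_dim + m] = acc;
    }
  }
}
```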

View File

@@ -23,6 +23,7 @@
 #include <stdio.h>
 #include <array>
+#include <cmath>
 #include <random>
 #include <type_traits>  // std::enable_if_t
@@ -70,6 +71,29 @@ StaticCast(From from) noexcept {
   return static_cast<To>(from);
 }
+
+// For testing.
+template <typename MatT>
+void AssertClose(const MatT* HWY_RESTRICT expected,
+                 const MatT* HWY_RESTRICT actual, size_t num) {
+  for (size_t idx = 0; idx < num; idx++) {
+    const double expected_value = hwy::ConvertScalarTo<double>(expected[idx]);
+    const double actual_value = hwy::ConvertScalarTo<double>(actual[idx]);
+    const double magnitude = std::abs(expected_value);
+    const double tolerance =
+        256.0 * hwy::ConvertScalarTo<double>(hwy::Epsilon<MatT>()) *
+        HWY_MAX(magnitude, 1.0);
+    if (!(expected_value - tolerance <= actual_value &&
+          actual_value <= expected_value + tolerance)) {
+      fprintf(stderr, "expected[%lu]: %f, actual[%lu]: %f\n", idx,
+              expected_value, idx, actual_value);
+      HWY_ASSERT(0);
+    }
+  }
+}
+
 template <size_t kOuter>
 HWY_INLINE constexpr size_t RowsPerStrip() {
   // Aim for 128 work items to reduce pool overhead. Must be at least one
@@ -362,11 +386,11 @@ HWY_INLINE void GEMM_4x4_Tile(const MatT* HWY_RESTRICT A,
                       c23, c30, c31, c32, c33, tile_c, stride_c);
 }
 
-// Same as above, but with mixed Mat types: (f32, sfp).
+// Same as above, but with mixed Mat types: (f32, compressed).
 template <size_t kNumRows, size_t kColsA_RowsB, typename MatTA,
-          HWY_IF_F32(MatTA)>
+          HWY_IF_F32(MatTA), typename MatTB, HWY_IF_T_SIZE(MatTB, 1)>
 HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
-                              const SfpStream* HWY_RESTRICT B,
+                              const MatTB* HWY_RESTRICT B,
                               float* HWY_RESTRICT C, const size_t idx_tile,
                               const size_t xtiles, const size_t stride_a,
                               const size_t stride_b, const size_t stride_c) {
@@ -406,7 +430,7 @@ HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
   hwy::AlignedFreeUniquePtr<float[]> tile_b_unique_ptr =
       hwy::AllocateAligned<float>(kRegRows * kColsA_RowsB);
-  CompressTraits<SfpStream>::Decompress(
+  CompressTraits<MatTB>::Decompress(
       d,
       /*in_capacity=*/0, B, stride_b * row_b_col_c, tile_b_unique_ptr.get(),
       kRegRows * kColsA_RowsB);
@@ -455,11 +479,11 @@ HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
                       c23, c30, c31, c32, c33, tile_c, stride_c);
 }
 
-// Same as above, but with mixed Mat types: (bf16, sfp).
+// Same as above, but with mixed Mat types: (bf16, compressed).
 template <size_t kNumRows, size_t kColsA_RowsB, typename MatTA,
-          HWY_IF_BF16(MatTA)>
+          HWY_IF_BF16(MatTA), typename MatTB, HWY_IF_T_SIZE(MatTB, 1)>
 HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
-                              const SfpStream* HWY_RESTRICT B,
+                              const MatTB* HWY_RESTRICT B,
                               float* HWY_RESTRICT C, const size_t idx_tile,
                               const size_t xtiles, const size_t stride_a,
                               const size_t stride_b, const size_t stride_c) {
@@ -504,7 +528,7 @@ HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
   hwy::AlignedFreeUniquePtr<float[]> tile_b_unique_ptr =
       hwy::AllocateAligned<float>(kRegRows * kColsA_RowsB);
-  CompressTraits<SfpStream>::Decompress(
+  CompressTraits<MatTB>::Decompress(
       d32,
       /*in_capacity=*/0, B, stride_b * row_b_col_c, tile_b_unique_ptr.get(),
       kRegRows * kColsA_RowsB);
@@ -806,7 +830,37 @@ HWY_NOINLINE void MatMul_4x4_Batch(
 // Largely unoptimized; reordered innermost loops nets ~5-10X speedup on
 // ops_test across instruction sets.
-template <size_t kN, size_t kK, typename MatTA, typename MatTB>
+template <size_t kM, size_t kN, size_t kK, typename MatTA, typename MatTB>
+HWY_INLINE void MatMulSlow(const MatTA* HWY_RESTRICT a,
+                           const MatTB* HWY_RESTRICT b,
+                           float* HWY_RESTRICT out) {
+  for (size_t i = 0; i < kM; ++i) {
+    for (size_t k = 0; k < kN; ++k) {
+      for (size_t j = 0; j < kK; ++j) {
+        const float a1 = hwy::ConvertScalarTo<float>(a[i * kN + k]);
+        const float b1 = hwy::ConvertScalarTo<float>(b[k * kK + j]);
+        out[i * kK + j] += a1 * b1;
+      }
+    }
+  }
+}
+
+template <size_t kM, size_t kN, size_t kK, typename MatTA>
+HWY_INLINE void MatMulSlow(const MatTA* HWY_RESTRICT a,
+                           const SfpStream* HWY_RESTRICT b_sfp_stream,
+                           float* HWY_RESTRICT out) {
+  const hn::ScalableTag<float> d;
+  hwy::AlignedFreeUniquePtr<float[]> b = hwy::AllocateAligned<float>(kK * kN);
+  CompressTraits<SfpStream>::Decompress(d,
+                                        /*in_capacity=*/0, b_sfp_stream, 0,
+                                        b.get(), kK * kN);
+  MatMulSlow<kM, kN, kK>(a, b.get(), out);
+}
+
+// Largely unoptimized; reordered innermost loops nets ~5-10X speedup on
+// ops_test across instruction sets.
+template <size_t kN, size_t kK, typename MatTA, typename MatTB,
+          HWY_IF_T_SIZE_GT(MatTB, 1)>
 HWY_INLINE void MatMulSlowBatch(size_t batch_size, const MatTA* HWY_RESTRICT a,
                                 const MatTB* HWY_RESTRICT b,
                                 float* HWY_RESTRICT out) {
@@ -821,15 +875,18 @@ HWY_INLINE void MatMulSlowBatch(size_t batch_size, const MatTA* HWY_RESTRICT a,
   }
 }
 
-template <size_t kN, size_t kK, typename MatTA>
+// The above overload can handle combinations of f32 and bf16, but this one
+// is required for MatTB = {SFP, NUQ}.
+template <size_t kN, size_t kK, typename MatTA, typename MatTB,
+          HWY_IF_T_SIZE(MatTB, 1)>
 HWY_INLINE void MatMulSlowBatch(size_t batch_size, const MatTA* HWY_RESTRICT a,
-                                const SfpStream* HWY_RESTRICT b_sfp_stream,
+                                const MatTB* HWY_RESTRICT b_compr,
                                 float* HWY_RESTRICT out) {
   const hn::ScalableTag<float> d;
   hwy::AlignedFreeUniquePtr<float[]> b = hwy::AllocateAligned<float>(kK * kN);
-  CompressTraits<SfpStream>::Decompress(d,
-                                        /*in_capacity=*/0, b_sfp_stream, 0,
-                                        b.get(), kK * kN);
+  CompressTraits<MatTB>::Decompress(d,
+                                    /*in_capacity=*/0, b_compr, 0, b.get(),
+                                    kK * kN);
   MatMulSlowBatch<kN, kK>(batch_size, a, b.get(), out);
 }
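Together with the AssertClose helper moved into this header, these slow overloads support checking the tiled kernel against a naive reference. A hypothetical sketch of such a check, assuming it sits where MatMulSlow, MatMul_4x4_Batch, and AssertClose are visible; the helper name and the float/float combination are illustrative, not the actual test:

```cpp
// Hypothetical check: compare the tiled batched MatMul against MatMulSlow.
// kM = rows of A (batch), kN = cols of A (= rows of B), kK = cols of B and C.
template <size_t kM, size_t kN, size_t kK>
void CheckMatMulAgainstSlow(const float* HWY_RESTRICT a,
                            const float* HWY_RESTRICT b,
                            hwy::ThreadPool& pool) {
  hwy::AlignedFreeUniquePtr<float[]> expected =
      hwy::AllocateAligned<float>(kM * kK);
  hwy::AlignedFreeUniquePtr<float[]> actual =
      hwy::AllocateAligned<float>(kM * kK);
  // MatMulSlow accumulates into its output, so zero it first.
  for (size_t i = 0; i < kM * kK; ++i) expected[i] = 0.0f;
  MatMulSlow<kM, kN, kK>(a, b, expected.get());
  MatMul_4x4_Batch<kN, kK>(kM, a, b, actual.get(), pool);
  AssertClose(expected.get(), actual.get(), kM * kK);  // 256*epsilon tolerance
}
```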

View File

@@ -506,28 +506,6 @@ void AssertClose(const hwy::AlignedFreeUniquePtr<float[]>& a,
   }
 }
 
-template <typename MatT>
-void AssertClose(const MatT* HWY_RESTRICT expected,
-                 const MatT* HWY_RESTRICT actual, size_t num) {
-  for (size_t idx = 0; idx < num; idx++) {
-    const double expected_value = hwy::ConvertScalarTo<double>(expected[idx]);
-    const double actual_value = hwy::ConvertScalarTo<double>(actual[idx]);
-    const double magnitude = std::abs(expected_value);
-    const double tolerance =
-        64.0 * hwy::ConvertScalarTo<double>(hwy::Epsilon<MatT>()) *
-        HWY_MAX(magnitude, 1.0);
-    if (!(expected_value - tolerance <= actual_value &&
-          actual_value <= expected_value + tolerance)) {
-      fprintf(stderr, "expected[%lu]: %f, actual[%lu]: %f, tolerance: %f\n",
-              idx, expected_value, idx, actual_value, tolerance);
-      HWY_ASSERT(0);
-    }
-  }
-}
-
 template <size_t kM, size_t kN, size_t kK, typename MatTA,
           typename MatTB = MatTA>
 void TestTiledBatchMatMul() {
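With MatTB now generalized to any 1-byte compressed type, the test template above can in principle cover the new mixed combinations as well. Hypothetical instantiations (the wrapper name and dimensions are illustrative; the actual test chooses its own sizes and combinations):

```cpp
void TestAllTiledBatchMatMul() {
  // Same-type and mixed-type combinations accepted by the generalized kernels.
  TestTiledBatchMatMul<4, 128, 256, float>();
  TestTiledBatchMatMul<4, 128, 256, hwy::bfloat16_t>();
  TestTiledBatchMatMul<4, 128, 256, float, SfpStream>();
  TestTiledBatchMatMul<4, 128, 256, hwy::bfloat16_t, SfpStream>();
}
```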