Integrate matmul into FFW: 4.3x prefill speedup

```
before, bf16:
27.2929 prefill tokens / sec
17.2114 tokens / sec

after, bf16:
116.496 prefill tokens / sec
17.5391 tokens / sec
```

PiperOrigin-RevId: 643328437
Jan Wassenberg authored 2024-06-14 06:32:00 -07:00; committed by Copybara-Service
parent 198326a682
commit 29c0c574e6
4 changed files with 136 additions and 67 deletions
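
Note (not part of the commit): the win is prefill-only because prefill has all prompt tokens' activations available at once, so each FFW projection can be issued as one matrix-matrix product over num_tokens rows instead of one MatVec per token, and a register-blocked matmul kernel then reuses every loaded weight value across several token rows. Decode still processes a single token per step, which matches the essentially unchanged second (non-prefill) figure above. A standalone sketch of that reuse, with a hypothetical helper and plain row-major float arrays rather than the repo's tiled MatMul_4x4_Batch:

```
#include <cstddef>

// Illustrative only: process 4 token rows of A per pass so each weight element
// b[k * cols_b + j] is loaded once and reused for 4 outputs; a per-token
// MatVec instead re-reads the whole weight matrix for every token.
void MatMulRows4(const float* a, const float* b, float* c,
                 size_t rows_a, size_t cols_a, size_t cols_b) {
  for (size_t i0 = 0; i0 + 4 <= rows_a; i0 += 4) {
    for (size_t j = 0; j < cols_b; ++j) {
      float acc[4] = {0.0f, 0.0f, 0.0f, 0.0f};
      for (size_t k = 0; k < cols_a; ++k) {
        const float bkj = b[k * cols_b + j];  // one load, four uses
        for (size_t r = 0; r < 4; ++r) {
          acc[r] += a[(i0 + r) * cols_a + k] * bkj;
        }
      }
      for (size_t r = 0; r < 4; ++r) {
        c[(i0 + r) * cols_b + j] = acc[r];
      }
    }
  }
  // Remainder rows (rows_a % 4) are omitted for brevity.
}
```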

View File

@@ -136,7 +136,7 @@ struct ConfigGemmaTiny : public ConfigNoSSM {
using Weight = TWeight; // make accessible where we only have a TConfig
static constexpr int kSeqLen = 32;
static constexpr int kVocabSize = 16;
static constexpr int kVocabSize = 64;
static constexpr std::array<LayerAttentionType, 3> kLayerConfig =
FixedLayerConfig<3>(LayerAttentionType::kGemma);
static constexpr int kLayers = kLayerConfig.size();
@@ -146,8 +146,8 @@ struct ConfigGemmaTiny : public ConfigNoSSM {
NumLayersOfTypeBefore(kLayerConfig,
LayerAttentionType::kGriffinRecurrentBlock,
kLayers);
static constexpr int kModelDim = 64;
static constexpr int kFFHiddenDim = 128;
static constexpr int kModelDim = 128;
static constexpr int kFFHiddenDim = 256;
static constexpr int kHeads = 4;
static constexpr int kKVHeads = 1;
static constexpr int kQKVDim = 16; // query size == key size == value size

View File

@@ -88,6 +88,11 @@ struct Activations {
att_post2; // accumulation of attention outputs over heads
std::array<hwy::bfloat16_t, kBatchSize * kModelDim> bf_pre_ffw_rms_out;
std::array<float, kBatchSize * TConfig::kFFHiddenDim * 2> ffw_hidden;
// For FFW MatMul.
std::array<float, kBatchSize * TConfig::kFFHiddenDim> C1;
std::array<float, kBatchSize * TConfig::kFFHiddenDim> C2;
// bf_ version can't be used until GeluMulToBF16 issue in FFW() is resolved.
// std::array<hwy::bfloat16_t, kBatchSize * 2 * TConfig::kFFHiddenDim>
// bf_ffw_hidden;
@@ -508,41 +513,70 @@ HWY_NOINLINE void FFW(Activations<TConfig, kBatchSize>& activations,
static constexpr size_t kFFHiddenDim = TConfig::kFFHiddenDim;
float* HWY_RESTRICT even_odd = activations.even_odd.data();
for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
const size_t hidden_offset = batch_idx * kFFHiddenDim * 2;
// TODO: MatMul does not yet support adding another matrix to the result.
if constexpr (!TConfig::kFFBiases) {
PROFILER_ZONE("Gen.FFW.GatedGELU");
const hwy::bfloat16_t* HWY_RESTRICT vec =
activations.bf_pre_ffw_rms_out.data() + batch_idx * kModelDim;
float* HWY_RESTRICT out = activations.ffw_hidden.data() + hidden_offset;
float* HWY_RESTRICT out_mul = out + kFFHiddenDim;
// Same matrix, first and second half of rows. Could fuse into one MatVec.
MatVecT</*kAdd=*/TConfig::kFFBiases, kFFHiddenDim, kModelDim>(
layer_weights->gating_einsum_w, kFFHiddenDim * kModelDim, vec,
TConfig::kFFBiases ?
layer_weights->ffw_gating_biases.data() + kFFHiddenDim : nullptr,
even_odd, out_mul, pool);
// Gate, will go through the nonlinearity.
MatVecT</*kAdd=*/TConfig::kFFBiases, kFFHiddenDim, kModelDim>(
layer_weights->gating_einsum_w, 0, vec,
layer_weights->ffw_gating_biases.data(), even_odd, out, pool);
// MatMul expects col-major B, which is what we have: kModelDim consecutive
// elements in memory, repeated kFFHiddenDim times.
const auto b1 = layer_weights->gating_einsum_w.data();
constexpr size_t kColsA = kModelDim;
constexpr size_t kColsB = kFFHiddenDim;
const auto b2 = b1 + kColsA * kColsB;
auto A = activations.bf_pre_ffw_rms_out.data();
// Will go through GELU.
MatMul_4x4_Batch<kColsA, kColsB>(num_tokens, A, b1, activations.C1.data(),
pool);
// What to multiply by.
MatMul_4x4_Batch<kColsA, kColsB>(num_tokens, A, b2, activations.C2.data(),
pool);
// Gelu and multiply by gate.
namespace hn = hwy::HWY_NAMESPACE;
using DF = hn::ScalableTag<float>;
using VF = hn::Vec<DF>;
hn::Transform1(DF(), out, kFFHiddenDim, out_mul,
[](DF df, VF v, VF mul)
HWY_ATTR { return hn::Mul(mul, Gelu(df, v)); });
}
hn::Transform1(DF(), activations.C1.data(), kFFHiddenDim * num_tokens,
activations.C2.data(), [](DF df, VF v, VF mul) HWY_ATTR {
return hn::Mul(mul, Gelu(df, v));
});
for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
PROFILER_ZONE("Gen.FFW\\GatedGELU");
const size_t hidden_offset = batch_idx * kFFHiddenDim * 2;
MatVecT</*kAdd=*/TConfig::kFFBiases, kModelDim, kFFHiddenDim>(
layer_weights->linear_w, 0,
activations.ffw_hidden.data() + hidden_offset,
layer_weights->ffw_output_biases.data(), even_odd,
activations.ffw_out.data() + batch_idx * kModelDim, pool);
MatMul_4x4_Batch<kFFHiddenDim, kModelDim>(num_tokens, activations.C1.data(),
layer_weights->linear_w.data(),
activations.ffw_out.data(), pool);
} else {
for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
const size_t hidden_offset = batch_idx * kFFHiddenDim * 2;
const hwy::bfloat16_t* HWY_RESTRICT vec =
activations.bf_pre_ffw_rms_out.data() + batch_idx * kModelDim;
float* HWY_RESTRICT out = activations.ffw_hidden.data() + hidden_offset;
float* HWY_RESTRICT out_mul = out + kFFHiddenDim;
PROFILER_ZONE("Gen.FFW.GatedGELU");
// Same matrix, first and second half of rows. Could fuse into one MatVec.
MatVecT<TConfig::kFFBiases, kFFHiddenDim, kModelDim>(
layer_weights->gating_einsum_w, kFFHiddenDim * kModelDim, vec,
TConfig::kFFBiases
? layer_weights->ffw_gating_biases.data() + kFFHiddenDim
: nullptr,
even_odd, out_mul, pool);
// Gate, will go through the nonlinearity.
MatVecT<TConfig::kFFBiases, kFFHiddenDim, kModelDim>(
layer_weights->gating_einsum_w, 0, vec,
layer_weights->ffw_gating_biases.data(), even_odd, out, pool);
namespace hn = hwy::HWY_NAMESPACE;
using DF = hn::ScalableTag<float>;
using VF = hn::Vec<DF>;
hn::Transform1(DF(), out, kFFHiddenDim, out_mul,
[](DF df, VF v, VF mul)
HWY_ATTR { return hn::Mul(mul, Gelu(df, v)); });
MatVecT</*kAdd=*/TConfig::kFFBiases, kModelDim, kFFHiddenDim>(
layer_weights->linear_w, 0,
activations.ffw_hidden.data() + hidden_offset,
layer_weights->ffw_output_biases.data(), even_odd,
activations.ffw_out.data() + batch_idx * kModelDim, pool);
}
}
}
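
Note (not part of the commit): the new bias-free path above computes C1 = A * B1 and C2 = A * B2, where B1 and B2 are the first and second halves of gating_einsum_w (kModelDim elements contiguous per row, i.e. the col-major layout MatMul expects), then overwrites C1 with Gelu(C1) * C2 over all num_tokens * kFFHiddenDim elements, and finally multiplies by linear_w. A scalar sketch of that dataflow, assuming plain row-major float matrices, hypothetical helper names, and the tanh-based GELU approximation (the real code uses bf16/compressed weights, MatMul_4x4_Batch, Highway's Transform1, and the library's Gelu):

```
#include <cmath>
#include <cstddef>
#include <vector>

// Scalar stand-in for a matmul: C[m x p] = A[m x n] * B[n x p], row-major.
static void MatMulRef(const float* A, const float* B, float* C,
                      size_t m, size_t n, size_t p) {
  for (size_t i = 0; i < m; ++i) {
    for (size_t j = 0; j < p; ++j) {
      float sum = 0.0f;
      for (size_t k = 0; k < n; ++k) sum += A[i * n + k] * B[k * p + j];
      C[i * p + j] = sum;
    }
  }
}

// tanh approximation of GELU (assumed here; the common Gemma-style choice).
static float GeluApprox(float x) {
  const float kSqrt2OverPi = 0.7978845608f;
  return 0.5f * x *
         (1.0f + std::tanh(kSqrt2OverPi * (x + 0.044715f * x * x * x)));
}

// Dataflow of the bias-free FFW path: C1 = A*B1, C2 = A*B2,
// C1 = Gelu(C1) * C2, out = C1 * w_linear. `out` holds num_tokens x model_dim.
void FfwSketch(const float* A, const float* B1, const float* B2,
               const float* w_linear, float* out, size_t num_tokens,
               size_t model_dim, size_t ff_hidden_dim) {
  std::vector<float> C1(num_tokens * ff_hidden_dim);
  std::vector<float> C2(num_tokens * ff_hidden_dim);
  MatMulRef(A, B1, C1.data(), num_tokens, model_dim, ff_hidden_dim);
  MatMulRef(A, B2, C2.data(), num_tokens, model_dim, ff_hidden_dim);
  for (size_t i = 0; i < C1.size(); ++i) C1[i] = GeluApprox(C1[i]) * C2[i];
  MatMulRef(C1.data(), w_linear, out, num_tokens, ff_hidden_dim, model_dim);
}
```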

View File

@@ -23,6 +23,7 @@
#include <stdio.h>
#include <array>
#include <cmath>
#include <random>
#include <type_traits> // std::enable_if_t
@@ -70,6 +71,29 @@ StaticCast(From from) noexcept {
return static_cast<To>(from);
}
// For testing.
template <typename MatT>
void AssertClose(const MatT* HWY_RESTRICT expected,
const MatT* HWY_RESTRICT actual, size_t num) {
for (size_t idx = 0; idx < num; idx++) {
const double expected_value = hwy::ConvertScalarTo<double>(expected[idx]);
const double actual_value = hwy::ConvertScalarTo<double>(actual[idx]);
const double magnitude = std::abs(expected_value);
const double tolerance =
256.0 * hwy::ConvertScalarTo<double>(hwy::Epsilon<MatT>()) *
HWY_MAX(magnitude, 1.0);
if (!(expected_value - tolerance <= actual_value &&
actual_value <= expected_value + tolerance)) {
fprintf(stderr, "expected[%lu]: %f, actual[%lu]: %f\n", idx,
expected_value, idx, actual_value);
HWY_ASSERT(0);
}
}
}
template <size_t kOuter>
HWY_INLINE constexpr size_t RowsPerStrip() {
// Aim for 128 work items to reduce pool overhead. Must be at least one
@@ -362,11 +386,11 @@ HWY_INLINE void GEMM_4x4_Tile(const MatT* HWY_RESTRICT A,
c23, c30, c31, c32, c33, tile_c, stride_c);
}
// Same as above, but with mixed Mat types: (f32, sfp).
// Same as above, but with mixed Mat types: (f32, compressed).
template <size_t kNumRows, size_t kColsA_RowsB, typename MatTA,
HWY_IF_F32(MatTA)>
HWY_IF_F32(MatTA), typename MatTB, HWY_IF_T_SIZE(MatTB, 1)>
HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
const SfpStream* HWY_RESTRICT B,
const MatTB* HWY_RESTRICT B,
float* HWY_RESTRICT C, const size_t idx_tile,
const size_t xtiles, const size_t stride_a,
const size_t stride_b, const size_t stride_c) {
@@ -406,7 +430,7 @@ HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
hwy::AlignedFreeUniquePtr<float[]> tile_b_unique_ptr =
hwy::AllocateAligned<float>(kRegRows * kColsA_RowsB);
CompressTraits<SfpStream>::Decompress(
CompressTraits<MatTB>::Decompress(
d,
/*in_capacity=*/0, B, stride_b * row_b_col_c, tile_b_unique_ptr.get(),
kRegRows * kColsA_RowsB);
@@ -455,11 +479,11 @@ HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
c23, c30, c31, c32, c33, tile_c, stride_c);
}
// Same as above, but with mixed Mat types: (bf16, sfp).
// Same as above, but with mixed Mat types: (bf16, compressed).
template <size_t kNumRows, size_t kColsA_RowsB, typename MatTA,
HWY_IF_BF16(MatTA)>
HWY_IF_BF16(MatTA), typename MatTB, HWY_IF_T_SIZE(MatTB, 1)>
HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
const SfpStream* HWY_RESTRICT B,
const MatTB* HWY_RESTRICT B,
float* HWY_RESTRICT C, const size_t idx_tile,
const size_t xtiles, const size_t stride_a,
const size_t stride_b, const size_t stride_c) {
@@ -504,7 +528,7 @@ HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
hwy::AlignedFreeUniquePtr<float[]> tile_b_unique_ptr =
hwy::AllocateAligned<float>(kRegRows * kColsA_RowsB);
CompressTraits<SfpStream>::Decompress(
CompressTraits<MatTB>::Decompress(
d32,
/*in_capacity=*/0, B, stride_b * row_b_col_c, tile_b_unique_ptr.get(),
kRegRows * kColsA_RowsB);
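
Note (not part of the commit): the hunks above generalize the mixed-type tile kernels from the hard-coded SfpStream to any 1-byte compressed type (HWY_IF_T_SIZE(MatTB, 1)), with CompressTraits<MatTB>::Decompress expanding just the strip of B needed for the current tile into f32 scratch before the dense multiply. A toy, self-contained sketch of that decompress-then-multiply pattern, using a made-up 1-byte codec in place of SFP/NUQ and one row of A per call instead of a 4x4 register tile:

```
#include <cstddef>
#include <cstdint>
#include <vector>

// Made-up 1-byte codec standing in for SFP/NUQ; the real formats are defined
// by CompressTraits<SfpStream> / CompressTraits<NuqStream>.
struct ToyQ8 { uint8_t v; };
static float Decode(ToyQ8 q) { return (static_cast<int>(q.v) - 128) / 127.0f; }

// B is stored with the inner dimension contiguous (one column of length
// cols_a_rows_b per output j). Decompress that strip to f32 scratch, then do
// the dense dot product, as the tile kernels above do per 4-row strip of B.
void RowTimesCompressedB(const float* a_row, const ToyQ8* b, float* c_row,
                         size_t cols_a_rows_b, size_t cols_b) {
  std::vector<float> b_f32(cols_a_rows_b);  // scratch, refilled per strip
  for (size_t j = 0; j < cols_b; ++j) {
    for (size_t k = 0; k < cols_a_rows_b; ++k) {
      b_f32[k] = Decode(b[j * cols_a_rows_b + k]);
    }
    float sum = 0.0f;
    for (size_t k = 0; k < cols_a_rows_b; ++k) sum += a_row[k] * b_f32[k];
    c_row[j] = sum;
  }
}
```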
@@ -806,7 +830,37 @@ HWY_NOINLINE void MatMul_4x4_Batch(
// Largely unoptimized; reordered innermost loops nets ~5-10X speedup on
// ops_test across instruction sets.
template <size_t kN, size_t kK, typename MatTA, typename MatTB>
template <size_t kM, size_t kN, size_t kK, typename MatTA, typename MatTB>
HWY_INLINE void MatMulSlow(const MatTA* HWY_RESTRICT a,
const MatTB* HWY_RESTRICT b,
float* HWY_RESTRICT out) {
for (size_t i = 0; i < kM; ++i) {
for (size_t k = 0; k < kN; ++k) {
for (size_t j = 0; j < kK; ++j) {
const float a1 = hwy::ConvertScalarTo<float>(a[i * kN + k]);
const float b1 = hwy::ConvertScalarTo<float>(b[k * kK + j]);
out[i * kK + j] += a1 * b1;
}
}
}
}
template <size_t kM, size_t kN, size_t kK, typename MatTA>
HWY_INLINE void MatMulSlow(const MatTA* HWY_RESTRICT a,
const SfpStream* HWY_RESTRICT b_sfp_stream,
float* HWY_RESTRICT out) {
const hn::ScalableTag<float> d;
hwy::AlignedFreeUniquePtr<float[]> b = hwy::AllocateAligned<float>(kK * kN);
CompressTraits<SfpStream>::Decompress(d,
/*in_capacity=*/0, b_sfp_stream, 0,
b.get(), kK * kN);
MatMulSlow<kM, kN, kK>(a, b.get(), out);
}
// Largely unoptimized; reordered innermost loops nets ~5-10X speedup on
// ops_test across instruction sets.
template <size_t kN, size_t kK, typename MatTA, typename MatTB,
HWY_IF_T_SIZE_GT(MatTB, 1)>
HWY_INLINE void MatMulSlowBatch(size_t batch_size, const MatTA* HWY_RESTRICT a,
const MatTB* HWY_RESTRICT b,
float* HWY_RESTRICT out) {
@@ -821,15 +875,18 @@ HWY_INLINE void MatMulSlowBatch(size_t batch_size, const MatTA* HWY_RESTRICT a,
}
}
template <size_t kN, size_t kK, typename MatTA>
// The above overload can handle combinations of f32 and bf16, but this one
// is required for MatTB = {SFP, NUQ}.
template <size_t kN, size_t kK, typename MatTA, typename MatTB,
HWY_IF_T_SIZE(MatTB, 1)>
HWY_INLINE void MatMulSlowBatch(size_t batch_size, const MatTA* HWY_RESTRICT a,
const SfpStream* HWY_RESTRICT b_sfp_stream,
const MatTB* HWY_RESTRICT b_compr,
float* HWY_RESTRICT out) {
const hn::ScalableTag<float> d;
hwy::AlignedFreeUniquePtr<float[]> b = hwy::AllocateAligned<float>(kK * kN);
CompressTraits<SfpStream>::Decompress(d,
/*in_capacity=*/0, b_sfp_stream, 0,
b.get(), kK * kN);
CompressTraits<MatTB>::Decompress(d,
/*in_capacity=*/0, b_compr, 0, b.get(),
kK * kN);
MatMulSlowBatch<kN, kK>(batch_size, a, b.get(), out);
}
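
Note (not part of the commit): MatMulSlow accumulates into out with +=, so the caller must zero the output first, and it indexes b row-major (b[k * kK + j]), whereas the FFW comment says the tiled kernel expects B with the inner dimension contiguous, so a comparison presumably feeds each side its matching layout. A hedged sketch of such a comparison; the function names come from this diff, but the sizes, fill values, explicit transpose, and surrounding includes/namespace are assumptions, not the actual TestTiledBatchMatMul body:

```
// Assumes the usual ops_test context (Highway headers, SIMD namespace, and
// <algorithm> for std::fill); everything besides MatMulSlowBatch,
// MatMul_4x4_Batch and AssertClose is hypothetical.
void CompareMatMulSketch() {
  constexpr size_t kM = 4;    // batch size (rows of A)
  constexpr size_t kN = 64;   // inner dimension (cols of A)
  constexpr size_t kK = 128;  // cols of C
  hwy::ThreadPool pool(0);    // no worker threads: run on the calling thread
  auto a = hwy::AllocateAligned<float>(kM * kN);
  auto b = hwy::AllocateAligned<float>(kN * kK);   // row-major, for the reference
  auto bt = hwy::AllocateAligned<float>(kK * kN);  // inner-dim contiguous, for the kernel
  auto expected = hwy::AllocateAligned<float>(kM * kK);
  auto actual = hwy::AllocateAligned<float>(kM * kK);
  std::fill(a.get(), a.get() + kM * kN, 0.5f);
  std::fill(b.get(), b.get() + kN * kK, 0.25f);
  for (size_t k = 0; k < kN; ++k) {
    for (size_t j = 0; j < kK; ++j) bt[j * kN + k] = b[k * kK + j];
  }
  std::fill(expected.get(), expected.get() + kM * kK, 0.0f);  // reference uses +=
  MatMulSlowBatch<kN, kK>(kM, a.get(), b.get(), expected.get());
  MatMul_4x4_Batch<kN, kK>(kM, a.get(), bt.get(), actual.get(), pool);
  AssertClose(expected.get(), actual.get(), kM * kK);
}
```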

View File

@@ -506,28 +506,6 @@ void AssertClose(const hwy::AlignedFreeUniquePtr<float[]>& a,
}
}
template <typename MatT>
void AssertClose(const MatT* HWY_RESTRICT expected,
const MatT* HWY_RESTRICT actual, size_t num) {
for (size_t idx = 0; idx < num; idx++) {
const double expected_value = hwy::ConvertScalarTo<double>(expected[idx]);
const double actual_value = hwy::ConvertScalarTo<double>(actual[idx]);
const double magnitude = std::abs(expected_value);
const double tolerance =
64.0 * hwy::ConvertScalarTo<double>(hwy::Epsilon<MatT>()) *
HWY_MAX(magnitude, 1.0);
if (!(expected_value - tolerance <= actual_value &&
actual_value <= expected_value + tolerance)) {
fprintf(stderr, "expected[%lu]: %f, actual[%lu]: %f, tolerance: %f\n",
idx, expected_value, idx, actual_value, tolerance);
HWY_ASSERT(0);
}
}
}
template <size_t kM, size_t kN, size_t kK, typename MatTA,
typename MatTB = MatTA>
void TestTiledBatchMatMul() {