From ff34370aac66aaef2cb80700a6fd620539970c45 Mon Sep 17 00:00:00 2001
From: Daniel Keysers
Date: Tue, 16 Jul 2024 09:40:38 -0700
Subject: [PATCH] Simplify FFW by using MatMul_4x4_Batch_Add.

Affects only the griffin model, where prefill TPS improves by about 70%.

PiperOrigin-RevId: 652878176
---
 compression/compress.h |  2 +
 gemma/activations.h    |  1 +
 gemma/gemma-inl.h      | 86 +++++++++++++++++------------------------
 3 files changed, 37 insertions(+), 52 deletions(-)

diff --git a/compression/compress.h b/compression/compress.h
index 7df9b73..fc17f2b 100644
--- a/compression/compress.h
+++ b/compression/compress.h
@@ -76,6 +76,8 @@ class CompressedArray {
  public:
   using value_type = MatT;

+  // Note that whenever you access data(), you have to consider a scale() that
+  // may be different from 1.0f.
   MatT* data() { return data_.data(); }
   const MatT* data() const { return data_.data(); }

diff --git a/gemma/activations.h b/gemma/activations.h
index ffaa726..4b88bd4 100644
--- a/gemma/activations.h
+++ b/gemma/activations.h
@@ -59,6 +59,7 @@ struct Activations {

   // For bf16/f32 vectors * bf16 matrix: faster to unpack once beforehand, into
   // per-thread storage.
+  // TODO: only used for MatVec, remove once that is gone.
   std::array even_odd;

   // Griffin layer internal activations
diff --git a/gemma/gemma-inl.h b/gemma/gemma-inl.h
index a3812da..3426a63 100644
--- a/gemma/gemma-inl.h
+++ b/gemma/gemma-inl.h
@@ -68,11 +68,11 @@ HWY_NOINLINE void GriffinRecurrent(
   PROFILER_ZONE("Gen.Griffin");
   static_assert(kQueryBatchSize == 1,
                 "Griffin does not support batched queries.");
-  HWY_DASSERT(num_queries == 1);  // TODO: add batch query support for Griffin.
+  HWY_ASSERT(num_queries == 1);  // TODO: add batch query support for Griffin.
   KVCache& kv_cache = *kv_caches[0];
   namespace hn = hwy::HWY_NAMESPACE;
   using D = hn::ScalableTag;
-  HWY_DASSERT(num_tokens <= kBatchSize);
+  HWY_ASSERT(num_tokens <= kBatchSize);

   static constexpr size_t kModelDim = gcpp::Activations::kModelDim;
   static constexpr size_t kConv1dWidth = TConfig::kConv1dWidth;
@@ -397,64 +397,46 @@ HWY_NOINLINE void FFW(Activations& activations, size_t num_tokens,
                       const CompressedLayer* layer_weights,
                       hwy::ThreadPool& pool) {
+  PROFILER_ZONE("Gen.FFW");
   HWY_DASSERT(num_tokens <= kBatchSize);
   constexpr size_t kModelDim = TConfig::kModelDim;
   constexpr size_t kFFHiddenDim = TConfig::kFFHiddenDim;
-  float* HWY_RESTRICT even_odd = activations.even_odd.data();

-  // TODO: MatMul does not yet support adding another matrix to the result.
-  if constexpr (!TConfig::kFFBiases) {
-    PROFILER_ZONE("Gen.FFW.GatedGELU");
+  // MatMul expects col-major B, which is what we have: kModelDim consecutive
+  // elements in memory, repeated kFFHiddenDim times.
+  constexpr size_t kColsA = kModelDim;
+  constexpr size_t kColsB = kFFHiddenDim;
+  const auto A = activations.bf_pre_ffw_rms_out.data();
+  const auto B1 = layer_weights->gating_einsum_w.data();
+  const auto B2 = B1 + kColsA * kColsB;
+  auto C1 = activations.C1.data();
+  auto C2 = activations.C2.data();
+  constexpr bool kAddBias = TConfig::kFFBiases;
+  const auto bias = layer_weights->ffw_gating_biases.data();

-    // MatMul expects col-major B, which is what we have: kModelDim consecutive
-    // elements in memory, repeated kFFHiddenDim times.
-    const auto b1 = layer_weights->gating_einsum_w.data();
-    constexpr size_t kColsA = kModelDim;
-    constexpr size_t kColsB = kFFHiddenDim;
-    const auto b2 = b1 + kColsA * kColsB;
-    auto A = activations.bf_pre_ffw_rms_out.data();
-    // Will go through GELU.
-    MatMul_4x4_Batch(num_tokens, A, b1, activations.C1.data(),
-                     pool);
-    // What to multiply by.
-    MatMul_4x4_Batch(num_tokens, A, b2, activations.C2.data(),
-                     pool);
+  // Will go through GELU.
+  MatMul_4x4_Batch_Add(num_tokens, A, B1, C1,
+                       bias, pool);
+  // What to multiply by.
+  MatMul_4x4_Batch_Add(num_tokens, A, B2, C2,
+                       bias + kFFHiddenDim, pool);

-    // Activation (Gelu) and multiply by gate.
-    Activation(activations.C1.data(), activations.C2.data(),
-               kFFHiddenDim * num_tokens);
+  // Activation (Gelu) and multiply by gate. Store activations in C1.
+  Activation(activations.C1.data(), activations.C2.data(),
+             kFFHiddenDim * num_tokens);

-    MatMul_4x4_Batch(num_tokens, activations.C1.data(),
-                     layer_weights->linear_w.data(),
-                     activations.ffw_out.data(), pool);
-  } else {  // TConfig::kFFBiases == true
-    for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
-      const size_t hidden_offset = batch_idx * kFFHiddenDim * 2;
-      const hwy::bfloat16_t* HWY_RESTRICT vec =
-          activations.bf_pre_ffw_rms_out.data() + batch_idx * kModelDim;
-      float* HWY_RESTRICT out = activations.ffw_hidden.data() + hidden_offset;
-      float* HWY_RESTRICT out_mul = out + kFFHiddenDim;
-
-      PROFILER_ZONE("Gen.FFW.GatedGELU");
-      // Same matrix, first and second half of rows. Could fuse into one MatVec.
-      MatVecT(
-          layer_weights->gating_einsum_w, kFFHiddenDim * kModelDim, vec,
-          layer_weights->ffw_gating_biases.data() + kFFHiddenDim, even_odd,
-          out_mul, pool);
-      // Gate, will go through the nonlinearity.
-      MatVecT(
-          layer_weights->gating_einsum_w, 0, vec,
-          layer_weights->ffw_gating_biases.data(), even_odd, out, pool);
-
-      Activation(out, out_mul, kFFHiddenDim);
-
-      MatVecT(
-          layer_weights->linear_w, 0,
-          activations.ffw_hidden.data() + hidden_offset,
-          layer_weights->ffw_output_biases.data(), even_odd,
-          activations.ffw_out.data() + batch_idx * kModelDim, pool);
-    }
+  // linear_w may have a scale value different from 1, apply that here.
+  // We multiply all activations by the scale value to compensate for the
+  // missing scale value in the weights.
+  if (layer_weights->linear_w.scale() != 1.0f) {
+    MulByConst(layer_weights->linear_w.scale(), C1, kFFHiddenDim * num_tokens);
   }
+
+  // Hidden layer -> output layer.
+  MatMul_4x4_Batch_Add(
+      num_tokens, C1, layer_weights->linear_w.data(),
+      activations.ffw_out.data(), layer_weights->ffw_output_biases.data(),
+      pool);
 }

 template
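
Note (not part of the patch): below is a minimal scalar sketch of the pattern the new FFW code relies on, for readers without the gemma.cpp sources at hand. The names MatMulAdd, ScaleInPlace and FfwSketch are illustrative, not gemma.cpp APIs; the real MatMul_4x4_Batch_Add is a vectorized, multi-threaded Highway kernel, and the matrix layouts here are only assumed to match the "col-major B" description in the patch comments.

// Scalar reference sketch. kAdd mirrors the kAddBias template argument: when
// false, `bias` is ignored and may be null.
#include <cmath>
#include <cstddef>
#include <vector>

template <bool kAdd>
void MatMulAdd(size_t num_rows, size_t cols_a, size_t cols_b, const float* A,
               const float* B, const float* bias, float* C) {
  for (size_t r = 0; r < num_rows; ++r) {
    for (size_t j = 0; j < cols_b; ++j) {
      float sum = 0.0f;
      if constexpr (kAdd) sum = bias[j];  // fused bias add, as in *_Batch_Add
      // Column j of B is cols_a consecutive elements (col-major layout).
      for (size_t k = 0; k < cols_a; ++k) {
        sum += A[r * cols_a + k] * B[j * cols_a + k];
      }
      C[r * cols_b + j] = sum;
    }
  }
}

// Folds a scalar weight scale into the activations once, the way the patched
// FFW uses MulByConst for linear_w.scale().
void ScaleInPlace(float c, float* x, size_t n) {
  for (size_t i = 0; i < n; ++i) x[i] *= c;
}

// Same step order as the patched FFW: two gating matmuls with an optional
// bias, gated GELU stored back into c1, scale fold, then hidden -> output.
template <bool kFFBiases>
void FfwSketch(size_t num_tokens, size_t model_dim, size_t hidden_dim,
               const float* x, const float* gating_w, const float* gating_bias,
               const float* linear_w, float linear_scale,
               const float* output_bias, float* out) {
  std::vector<float> c1(num_tokens * hidden_dim), c2(num_tokens * hidden_dim);
  const float* b1 = gating_w;                           // gate half of rows
  const float* b2 = gating_w + model_dim * hidden_dim;  // "multiply by" half
  MatMulAdd<kFFBiases>(num_tokens, model_dim, hidden_dim, x, b1, gating_bias,
                       c1.data());
  MatMulAdd<kFFBiases>(num_tokens, model_dim, hidden_dim, x, b2,
                       kFFBiases ? gating_bias + hidden_dim : nullptr,
                       c2.data());
  for (size_t i = 0; i < num_tokens * hidden_dim; ++i) {
    const float g = c1[i];  // tanh approximation of GELU
    const float gelu = 0.5f * g *
        (1.0f + std::tanh(0.7978845608f * (g + 0.044715f * g * g * g)));
    c1[i] = gelu * c2[i];
  }
  if (linear_scale != 1.0f) ScaleInPlace(linear_scale, c1.data(), c1.size());
  MatMulAdd<kFFBiases>(num_tokens, hidden_dim, model_dim, c1.data(), linear_w,
                       output_bias, out);
}

Folding the bias into the matmul epilogue and hoisting the weight scale out of the inner loop is what lets the patch delete the per-token MatVec path, so the biased and unbiased configurations now share one batched code path; presumably this is why the commit message cites a prefill speedup only for the griffin model, whose config needed the bias branch.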