From 2ebbe4076f331b01035d91916e92f4e8776dd9bb Mon Sep 17 00:00:00 2001 From: Jan Wassenberg Date: Fri, 9 Aug 2024 01:22:46 -0700 Subject: [PATCH] 1.03-1.08x decode speedup: precompute Rope theta, fuse Split attention into functions, move into class. Fuse Rope and MulBy, allow non-in-place version to avoid copy from q to KV. Sink if() into MaybeLogitsSoftCap. PiperOrigin-RevId: 661168418 --- BUILD.bazel | 1 + backprop/backward-inl.h | 26 +- backprop/backward.cc | 27 +- backprop/backward.h | 11 +- backprop/backward_test.cc | 9 +- backprop/forward-inl.h | 22 +- backprop/forward.cc | 26 +- backprop/forward.h | 9 +- backprop/optimize_test.cc | 10 +- gemma/activations.h | 23 ++ gemma/gemma-inl.h | 509 +++++++++++++++++++++++--------------- ops/ops-inl.h | 48 ++-- 12 files changed, 443 insertions(+), 278 deletions(-) diff --git a/BUILD.bazel b/BUILD.bazel index 6620122..77f199b 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -455,6 +455,7 @@ cc_test( ":common", ":gemma_lib", ":ops", + ":prompt", ":sampler", "@googletest//:gtest_main", "//compression:weights_raw", diff --git a/backprop/backward-inl.h b/backprop/backward-inl.h index 62e2d13..2b10cc4 100644 --- a/backprop/backward-inl.h +++ b/backprop/backward-inl.h @@ -27,6 +27,7 @@ #include "backprop/activations.h" #include "backprop/prompt.h" +#include "gemma/activations.h" // CreateInvTimescale #include "gemma/common.h" #include "hwy/base.h" #include "hwy/contrib/thread_pool/thread_pool.h" @@ -166,13 +167,12 @@ static HWY_NOINLINE void InputEmbeddingVJP( } } -template typename LayerT> +template typename LayerT> void LayerVJP(const LayerT& weights, const ForwardLayer& forward, - const float* HWY_RESTRICT next_layer_grad, - size_t num_tokens, - LayerT& grad, - ForwardLayer& backward, + const float* HWY_RESTRICT next_layer_grad, size_t num_tokens, + LayerT& grad, ForwardLayer& backward, + const RowVectorBatch& inv_timescale, hwy::ThreadPool& pool) { static constexpr size_t kModelDim = TConfig::kModelDim; static constexpr size_t kQKVDim = TConfig::kQKVDim; @@ -279,7 +279,7 @@ void LayerVJP(const LayerT& weights, for (int pos = 0; pos < static_cast(num_tokens); ++pos) { float* HWY_RESTRICT b_kv = backward.qkv.data() + (pos * (kHeads + 2) + kHeads) * kQKVDim; - Rope(b_kv, kQKVDim, -pos); + Rope(b_kv, kQKVDim, inv_timescale.Const(), -pos); } for (size_t head = 0; head < kHeads; ++head) { @@ -287,7 +287,7 @@ void LayerVJP(const LayerT& weights, float* HWY_RESTRICT b_q = backward.qkv.data() + (pos * (kHeads + 2) + head) * kQKVDim; MulByConst(kQueryScale, b_q, kQKVDim); - Rope(b_q, kQKVDim, -pos); + Rope(b_q, kQKVDim, inv_timescale.Const(), -pos); } } @@ -342,13 +342,14 @@ static HWY_NOINLINE void CrossEntropyLossGrad( } } -template typename WeightsT, - template typename LayerT> +template typename WeightsT, + template typename LayerT> void CrossEntropyLossBackwardPass(const Prompt& prompt, const WeightsT& weights, const ForwardPass& forward, WeightsT& grad, ForwardPass& backward, + RowVectorBatch& inv_timescale, hwy::ThreadPool& pool) { static constexpr size_t kVocabSize = TConfig::kVocabSize; static constexpr size_t kModelDim = TConfig::kModelDim; @@ -398,9 +399,10 @@ void CrossEntropyLossBackwardPass(const Prompt& prompt, float* next_layer_grad = layer + 1 < kLayers ? 
backward.layers[layer + 1].input.data() : backward.final_layer_output.data(); - LayerVJP( - *weights.GetLayer(layer), forward.layers[layer], next_layer_grad, - num_tokens, *grad.GetLayer(layer), backward.layers[layer], pool); + LayerVJP(*weights.GetLayer(layer), forward.layers[layer], + next_layer_grad, num_tokens, + *grad.GetLayer(layer), backward.layers[layer], + inv_timescale, pool); } InputEmbeddingVJP(weights.embedder_input_embedding.data(), prompt.tokens, diff --git a/backprop/backward.cc b/backprop/backward.cc index 89bbef3..27bbd0e 100644 --- a/backprop/backward.cc +++ b/backprop/backward.cc @@ -29,6 +29,7 @@ #include "hwy/highway.h" // After highway.h #include "backprop/backward-inl.h" +#include "gemma/activations.h" #include "gemma/weights.h" HWY_BEFORE_NAMESPACE(); @@ -41,6 +42,7 @@ void CrossEntropyLossBackwardPass(const Prompt& prompt, const ByteStorageT& forward_u8, ByteStorageT& grad_u8, ByteStorageT& backward_u8, + RowVectorBatch& inv_timescale, hwy::ThreadPool& pool) { using TWeights = CompressedWeights; const auto& weights = *reinterpret_cast(weights_u8.get()); @@ -49,25 +51,24 @@ void CrossEntropyLossBackwardPass(const Prompt& prompt, const auto& forward = *reinterpret_cast(forward_u8.get()); auto& backward = *reinterpret_cast(backward_u8.get()); CrossEntropyLossBackwardPass( - prompt, weights, forward, grad, backward, pool); + prompt, weights, forward, grad, backward, inv_timescale, pool); } -void CrossEntropyLossBackwardPassT(Model model, - const Prompt& prompt, +void CrossEntropyLossBackwardPassT(Model model, const Prompt& prompt, const ByteStorageT& weights, const ByteStorageT& forward, - ByteStorageT& grad, - ByteStorageT& backward, + ByteStorageT& grad, ByteStorageT& backward, + RowVectorBatch& inv_timescale, hwy::ThreadPool& pool) { // TODO(janwas): use CallFunctorForModel switch (model) { case Model::GEMMA_2B: CrossEntropyLossBackwardPass>( - prompt, weights, forward, grad, backward, pool); + prompt, weights, forward, grad, backward, inv_timescale, pool); break; case Model::GEMMA_TINY: CrossEntropyLossBackwardPass>( - prompt, weights, forward, grad, backward, pool); + prompt, weights, forward, grad, backward, inv_timescale, pool); break; default: HWY_ABORT("Model type %d unknown.", static_cast(model)); @@ -83,12 +84,14 @@ namespace gcpp { HWY_EXPORT(CrossEntropyLossBackwardPassT); -void CrossEntropyLossBackwardPass( - const Model& model, const Prompt& prompt, - const ByteStorageT& weights, const ByteStorageT& forward, - ByteStorageT& grad, ByteStorageT& backward, hwy::ThreadPool& pool) { +void CrossEntropyLossBackwardPass(const Model& model, const Prompt& prompt, + const ByteStorageT& weights, + const ByteStorageT& forward, + ByteStorageT& grad, ByteStorageT& backward, + RowVectorBatch& inv_timescale, + hwy::ThreadPool& pool) { return HWY_DYNAMIC_DISPATCH(CrossEntropyLossBackwardPassT)( - model, prompt, weights, forward, grad, backward, pool); + model, prompt, weights, forward, grad, backward, inv_timescale, pool); } } // namespace gcpp diff --git a/backprop/backward.h b/backprop/backward.h index aac2122..0ac218a 100644 --- a/backprop/backward.h +++ b/backprop/backward.h @@ -17,15 +17,18 @@ #define THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_H_ #include "backprop/prompt.h" +#include "gemma/activations.h" #include "gemma/common.h" #include "hwy/contrib/thread_pool/thread_pool.h" namespace gcpp { -void CrossEntropyLossBackwardPass( - const Model& model, const Prompt& prompt, - const ByteStorageT& weights, const ByteStorageT& forward, - ByteStorageT& grad, ByteStorageT& 
backward, hwy::ThreadPool& pool); +void CrossEntropyLossBackwardPass(const Model& model, const Prompt& prompt, + const ByteStorageT& weights, + const ByteStorageT& forward, + ByteStorageT& grad, ByteStorageT& backward, + RowVectorBatch& inv_timescale, + hwy::ThreadPool& pool); } // namespace gcpp diff --git a/backprop/backward_test.cc b/backprop/backward_test.cc index bf8cf5f..b6c0780 100644 --- a/backprop/backward_test.cc +++ b/backprop/backward_test.cc @@ -21,6 +21,7 @@ #include #include +#include // std::abs #include #include @@ -28,9 +29,11 @@ #include "backprop/backward_scalar.h" #include "backprop/common_scalar.h" #include "backprop/forward_scalar.h" +#include "backprop/prompt.h" #include "backprop/sampler.h" #include "backprop/test_util.h" #include "compression/weights_raw.h" +#include "gemma/activations.h" #include "gemma/configs.h" #include "hwy/base.h" #include "hwy/contrib/thread_pool/thread_pool.h" @@ -214,6 +217,8 @@ void TestEndToEnd() { ReverseSequenceSampler training_task({0, 0, 1, 1}); std::vector batch = training_task.SampleBatch(3, gen); + RowVectorBatch inv_timescale = + Activations::CreateInvTimescale(); for (const Prompt& prompt : batch) { ReverseSequenceSampler::LogPrompt(prompt); RandInit(weights.get(), 1.0f, gen); @@ -223,14 +228,14 @@ void TestEndToEnd() { float loss1 = CrossEntropyLossForwardPass( prompt.tokens, prompt.context_size, weights.get(), forward1.get(), - pool); + inv_timescale, pool); EXPECT_NEAR(loss1, loss0, std::abs(loss0) * 2e-5); grad.clear(); CrossEntropyLossBackwardPass( prompt, weights.get(), forward1.get(), grad.get(), backward.get(), - pool); + inv_timescale, pool); Complexify(weights.get(), c_weights.get()); auto func = [&]() { diff --git a/backprop/forward-inl.h b/backprop/forward-inl.h index b3ffd92..7dec634 100644 --- a/backprop/forward-inl.h +++ b/backprop/forward-inl.h @@ -24,6 +24,7 @@ #include #include "backprop/activations.h" +#include "gemma/activations.h" #include "gemma/common.h" #include "gemma/configs.h" #include "hwy/base.h" @@ -88,11 +89,11 @@ static HWY_NOINLINE float CrossEntropyLoss(const float* HWY_RESTRICT probs, return loss * scaling; } -template typename LayerT> +template typename LayerT> void ApplyForwardLayer(const LayerT& weights, ForwardLayer& activations, - size_t num_tokens, - float* HWY_RESTRICT output, + size_t num_tokens, float* HWY_RESTRICT output, + const RowVectorBatch& inv_timescale, hwy::ThreadPool& pool) { static constexpr size_t kModelDim = TConfig::kModelDim; static constexpr size_t kSeqLen = TConfig::kSeqLen; @@ -117,14 +118,14 @@ void ApplyForwardLayer(const LayerT& weights, for (size_t pos = 0; pos < num_tokens; ++pos) { float* HWY_RESTRICT k = activations.qkv.data() + (pos * (kHeads + 2) + kHeads) * kQKVDim; - Rope(k, kQKVDim, pos); + Rope(k, kQKVDim, inv_timescale.Const(), pos); } pool.Run(0, num_tasks, [&](const uint64_t task, size_t thread) HWY_ATTR { const size_t head = task % kHeads; const size_t pos = task / kHeads; float* HWY_RESTRICT q = activations.qkv.data() + (pos * (kHeads + 2) + head) * kQKVDim; - Rope(q, kQKVDim, pos); + Rope(q, kQKVDim, inv_timescale.Const(), pos); MulByConst(kQueryScale, q, kQKVDim); }); @@ -222,12 +223,13 @@ void ApplyForwardLayer(const LayerT& weights, } } -template typename WeightsT, - template typename LayerT> +template typename WeightsT, + template typename LayerT> float CrossEntropyLossForwardPass(const std::vector& prompt, size_t context_size, const WeightsT& weights, ForwardPass& forward, + const RowVectorBatch& inv_timescale, hwy::ThreadPool& pool) { static 
constexpr size_t kVocabSize = TConfig::kVocabSize; static constexpr size_t kModelDim = TConfig::kModelDim; @@ -251,9 +253,9 @@ float CrossEntropyLossForwardPass(const std::vector& prompt, float* HWY_RESTRICT output = layer + 1 < kLayers ? forward.layers[layer + 1].input.data() : forward.final_layer_output.data(); - ApplyForwardLayer( - *weights.GetLayer(layer), forward.layers[layer], - num_tokens, output, pool); + ApplyForwardLayer(*weights.GetLayer(layer), + forward.layers[layer], num_tokens, + output, inv_timescale, pool); } ApplyRMSNorm(weights.final_norm_scale.data(), diff --git a/backprop/forward.cc b/backprop/forward.cc index 1357276..29721d2 100644 --- a/backprop/forward.cc +++ b/backprop/forward.cc @@ -17,6 +17,7 @@ #include "backprop/activations.h" #include "backprop/prompt.h" +#include "gemma/activations.h" #include "gemma/common.h" #include "hwy/contrib/thread_pool/thread_pool.h" @@ -39,28 +40,31 @@ template float CrossEntropyLossForwardPass(const Prompt& prompt, const ByteStorageT& weights_u8, ByteStorageT& forward_u8, + RowVectorBatch& inv_timescale, hwy::ThreadPool& pool) { const auto& weights = *reinterpret_cast*>(weights_u8.get()); auto& forward = *reinterpret_cast*>(forward_u8.get()); - return - CrossEntropyLossForwardPass( - prompt.tokens, prompt.context_size, weights, forward, pool); + return CrossEntropyLossForwardPass( + prompt.tokens, prompt.context_size, weights, forward, inv_timescale, + pool); } float CrossEntropyLossForwardPassT(Model model, const Prompt& prompt, const ByteStorageT& weights, ByteStorageT& forward, + RowVectorBatch& inv_timescale, hwy::ThreadPool& pool) { // TODO(janwas): use CallFunctorForModel switch (model) { case Model::GEMMA_2B: - return CrossEntropyLossForwardPass>(prompt, weights, - forward, pool); + return CrossEntropyLossForwardPass>( + prompt, weights, forward, inv_timescale, pool); case Model::GEMMA_TINY: return CrossEntropyLossForwardPass>( - prompt, weights, forward, pool); + prompt, weights, forward, inv_timescale, pool); default: HWY_ABORT("Model type %d unknown.", static_cast(model)); } @@ -75,11 +79,13 @@ namespace gcpp { HWY_EXPORT(CrossEntropyLossForwardPassT); -float CrossEntropyLossForwardPass( - const Model& model, const Prompt& prompt, const ByteStorageT& weights, - ByteStorageT& forward, hwy::ThreadPool& pool) { +float CrossEntropyLossForwardPass(const Model& model, const Prompt& prompt, + const ByteStorageT& weights, + ByteStorageT& forward, + RowVectorBatch& inv_timescale, + hwy::ThreadPool& pool) { return HWY_DYNAMIC_DISPATCH(CrossEntropyLossForwardPassT)( - model, prompt, weights, forward, pool); + model, prompt, weights, forward, inv_timescale, pool); } } // namespace gcpp diff --git a/backprop/forward.h b/backprop/forward.h index 4950f37..92ca371 100644 --- a/backprop/forward.h +++ b/backprop/forward.h @@ -17,14 +17,17 @@ #define THIRD_PARTY_GEMMA_CPP_GEMMA_FORWARD_H_ #include "backprop/prompt.h" +#include "gemma/activations.h" #include "gemma/common.h" #include "hwy/contrib/thread_pool/thread_pool.h" namespace gcpp { -float CrossEntropyLossForwardPass( - const Model& model, const Prompt& prompt, const ByteStorageT& weights, - ByteStorageT& forward, hwy::ThreadPool& pool); +float CrossEntropyLossForwardPass(const Model& model, const Prompt& prompt, + const ByteStorageT& weights, + ByteStorageT& forward, + RowVectorBatch& inv_timescale, + hwy::ThreadPool& pool); } // namespace gcpp diff --git a/backprop/optimize_test.cc b/backprop/optimize_test.cc index 6a4f030..23174a0 100644 --- a/backprop/optimize_test.cc +++ 
b/backprop/optimize_test.cc @@ -26,6 +26,7 @@ #include "backprop/optimizer.h" #include "backprop/prompt.h" #include "backprop/sampler.h" +#include "gemma/activations.h" #include "gemma/common.h" #include "gemma/gemma.h" #include "gemma/weights.h" @@ -56,6 +57,9 @@ TEST(OptimizeTest, GradientDescent) { CallForModelAndWeight(info.model, info.weight); KVCache kv_cache = KVCache::Create(info.model, /*prefill_tbatch_size=*/16); + RowVectorBatch inv_timescale = + Activations::CreateInvTimescale>(); + Gemma gemma(GemmaTokenizer(), info, pools); const auto generate = [&](const std::vector& prompt) { @@ -118,10 +122,10 @@ TEST(OptimizeTest, GradientDescent) { num_ok = 0; for (size_t i = 0; i < kBatchSize; ++i) { Prompt prompt = training_task.Sample(sgen); - total_loss += CrossEntropyLossForwardPass(info.model, prompt, - gemma.Weights(), forward, pool); + total_loss += CrossEntropyLossForwardPass( + info.model, prompt, gemma.Weights(), forward, inv_timescale, pool); CrossEntropyLossBackwardPass(info.model, prompt, gemma.Weights(), forward, - grad, backward, pool); + grad, backward, inv_timescale, pool); num_ok += verify(prompt) ? 1 : 0; } total_loss /= kBatchSize; diff --git a/gemma/activations.h b/gemma/activations.h index 9e3cc4e..72a0ab8 100644 --- a/gemma/activations.h +++ b/gemma/activations.h @@ -18,6 +18,8 @@ #include +#include + #include "gemma/common.h" // kMaxThreads - TODO: remove #include "hwy/aligned_allocator.h" #include "hwy/base.h" // HWY_DASSERT @@ -54,6 +56,7 @@ class RowVectorBatch { // For MatMul or other operations that process the entire batch at once. T* All() { return mem_.get(); } + const T* Const() const { return mem_.get(); } size_t NumBytes() const { return batch_size_ * len_ * sizeof(T); } private: @@ -88,6 +91,9 @@ struct Activations { RowVectorBatch griffin_gate_x; RowVectorBatch griffin_multiplier; + // Rope + RowVectorBatch inv_timescale; + // For bf16/f32 vectors * bf16 matrix: faster to unpack once beforehand, into // per-thread storage. // TODO: remove once MatVec is gone. @@ -106,6 +112,21 @@ struct Activations { return TConfig::kQKVDim * (IsMHA() ? 3 : 1); } + template + static RowVectorBatch CreateInvTimescale() { + constexpr size_t kQKVDim = TConfig::kQKVDim; + const size_t rope_dim = TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim; + RowVectorBatch inv_timescale(1, rope_dim / 2); + for (size_t dim = 0; dim < rope_dim / 2; ++dim) { + const float freq_exponents = + static_cast(2 * dim) / static_cast(rope_dim); + // Replacing with expf(ln(1E4) * freq_exponents) changes results + // noticeably. + inv_timescale.Batch(0)[dim] = 1.0f / std::pow(10000.0f, freq_exponents); + } + return inv_timescale; + } + template void Allocate(size_t batch_size) { constexpr size_t kModelDim = TConfig::kModelDim; @@ -138,6 +159,8 @@ struct Activations { griffin_multiplier = RowVectorBatch(batch_size, kModelDim); } + inv_timescale = CreateInvTimescale(); + even_odd = RowVectorBatch(1, kModelDim * kMaxThreads); } }; diff --git a/gemma/gemma-inl.h b/gemma/gemma-inl.h index 6c31356..10d7350 100644 --- a/gemma/gemma-inl.h +++ b/gemma/gemma-inl.h @@ -196,216 +196,317 @@ HWY_NOINLINE void GriffinRecurrent( } } -template -HWY_NOINLINE void PostQK(T* HWY_RESTRICT inout, size_t pos, size_t layer) { - constexpr size_t kQKVDim = TConfig::kQKVDim; - // PostQKType::Rope - Rope(inout, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos); -} - +// Wrapper class; holds arguments in member variables to shorten call sites. 
template -HWY_NOINLINE void GemmaAttention(size_t interleaved_start, size_t num_tokens, - size_t num_queries, size_t layer, - Activations& activations, - const CompressedLayer* layer_weights, - const KVCaches& kv_caches, - hwy::ThreadPool& pool) { - PROFILER_ZONE("Gen.Attention"); - HWY_DASSERT(interleaved_start % num_queries == 0); - constexpr size_t kQKVDim = TConfig::kQKVDim; - constexpr size_t kQStride = Activations::QStride(); - constexpr size_t kCachePosSize = CachePosSize()(); - constexpr size_t kCacheLayerSize = CacheLayerSize()(); - constexpr size_t kModelDim = TConfig::kModelDim; - constexpr size_t kHeads = TConfig::kHeads; - constexpr size_t kKVHeads = TConfig::kKVHeads; - constexpr size_t kSeqLen = TConfig::kSeqLen; - GEMMA_CONSTEXPR_SQRT float kQueryScale = ChooseQueryScale(); +class GemmaAttention { + static constexpr size_t kCacheLayerSize = CacheLayerSize()(); + static constexpr size_t kCachePosSize = CachePosSize()(); + static constexpr size_t kHeads = TConfig::kHeads; + static constexpr size_t kKVHeads = TConfig::kKVHeads; + static constexpr size_t kModelDim = TConfig::kModelDim; + static constexpr size_t kQKVDim = TConfig::kQKVDim; + static constexpr size_t kQStride = Activations::QStride(); + static constexpr size_t kSeqLen = TConfig::kSeqLen; + static constexpr bool kIsMHA = Activations::IsMHA(); - HWY_ASSERT(num_queries <= kv_caches.size()); - const hwy::Divisor div_seq_len(static_cast(kv_caches[0].seq_len)); + // The attention window usually starts at 0 unless unless `pos` is larger than + // the attention window size, then it is `pos` - window_size + 1. + static HWY_INLINE size_t StartPos(size_t pos, size_t layer) { + const size_t att_window_size = TConfig::kAttentionWindowSizes[layer]; + return pos - std::min(att_window_size - 1, pos); + } - // Multi-Head Attention a.k.a. "use_qkv_einsum". - constexpr bool kIsMHA = Activations::IsMHA(); - static_assert(!kIsMHA || TConfig::kInterleaveQKV); // MHA => interleaved - const size_t batch_start = interleaved_start / num_queries; - const size_t num_interleaved = num_tokens * num_queries; - - // For the computation of Q, K, and V, it is useful to remember that - // qkv_einsum_w has shape [(kHeads + kKVHeads * 2), kKQVDim, kModelDim] - // and kQStride = kQKVDim * (kIsMHA ? 3 : 1); - // - // Compute Q only or QKV (if MHA). - // If MHA, this also computes KV, which we copy to the KV cache below. - MatMul_4x4( - num_interleaved, MakeMat(activations.pre_att_rms_out.All(), kModelDim), - MakeMat(layer_weights->qkv_einsum_w.data(), kModelDim), - layer_weights->qkv_einsum_w.scale(), /*add=*/nullptr, - MakeMat(activations.q.All(), kHeads * kQStride), pool); - - // Compute KV if not MHA. - if constexpr (!kIsMHA) { - // Single query and no wraparound means we can use a matmul and write - // directly into the KV cache with a stride of kCachePosSize. - if (num_queries == 1 && - batch_start + num_tokens <= div_seq_len.GetDivisor()) { - const size_t kv_ofs = - batch_start * kCachePosSize + layer * kCacheLayerSize; - // KV structure is [k, v, k, v, ....] = kKVHeads pairs of (k, v). 
- float* HWY_RESTRICT kv = kv_caches[0].kv_cache.get() + kv_ofs; - MatMul_4x4( - num_tokens, MakeMat(activations.pre_att_rms_out.All(), kModelDim), - MakeMat(layer_weights->qkv_einsum_w.data(), kModelDim, kModelDim, - kHeads * kQKVDim * kModelDim), - layer_weights->qkv_einsum_w.scale(), /*add=*/nullptr, - MakeMat(kv, kKVHeads * 2 * kQKVDim, kCachePosSize), pool); + template + HWY_INLINE void PositionalEncodingQK(const T* qk, size_t pos, size_t layer, + const float mul, T* qk_out) { + const float* inv_timescale = activations_.inv_timescale.Const(); + // PostQKType::Rope + (void)layer; + if (TConfig::kUseHalfRope) { + hwy::CopyBytes(qk, qk_out, kQKVDim * sizeof(*qk)); + Rope(qk_out, kQKVDim / 2, inv_timescale, pos); + MulByConst(mul, qk_out, kQKVDim); } else { - for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved; - ++interleaved_idx) { - const float* x = activations.pre_att_rms_out.Batch(interleaved_idx); - const size_t query_idx = interleaved_idx % num_queries; - const size_t batch_idx = interleaved_idx / num_queries; - KVCache& kv_cache = kv_caches[query_idx]; - const size_t cache_pos = div_seq_len.Remainder(batch_start + batch_idx); - const size_t kv_offset = - cache_pos * kCachePosSize + layer * kCacheLayerSize; - float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset; + RopeAndMulBy(mul, qk, kQKVDim, inv_timescale, pos, qk_out); + } + } + + // Fills activations.q and computes KV. For kIsMHA, a single MatMul suffices + // and we later copy KV from q to KVCache. Otherwise, a second MatMul writes + // KV directly to KVCache. + HWY_NOINLINE void ComputeQKV(const size_t batch_start, + const size_t num_interleaved) { + PROFILER_ZONE("Gen.Attention.QKV"); + // For the computation of Q, K, and V, it is useful to remember that + // qkv_einsum_w has shape [(kHeads + kKVHeads * 2), kKQVDim, kModelDim] + // and kQStride = kQKVDim * (kIsMHA ? 3 : 1); + + const auto pre_att_rms_out = + MakeMat(activations_.pre_att_rms_out.All(), kModelDim); + MatMul_4x4( + num_interleaved, pre_att_rms_out, + MakeMat(layer_weights_.qkv_einsum_w.data(), kModelDim), + layer_weights_.qkv_einsum_w.scale(), /*add=*/nullptr, + MakeMat(activations_.q.All(), kHeads * kQStride), pool_); + + if constexpr (kIsMHA) { + static_assert(TConfig::kInterleaveQKV, "MHA implies interleaved"); + // Multi-Head Attention a.k.a. "use_qkv_einsum" computed QKV already. + } else { + // Single query and no wraparound means we can use a matmul and write + // directly into the KV cache with a stride of kCachePosSize. + if (num_queries_ == 1 && + batch_start + num_tokens_ <= div_seq_len_.GetDivisor()) { + const size_t kv_ofs = + batch_start * kCachePosSize + layer_ * kCacheLayerSize; // KV structure is [k, v, k, v, ....] = kKVHeads pairs of (k, v). - MatVec( - layer_weights->qkv_einsum_w, kHeads * kQKVDim * kModelDim, x, - activations.even_odd.All(), kv, pool); + float* HWY_RESTRICT kv = kv_caches_[0].kv_cache.get() + kv_ofs; + MatMul_4x4( + num_tokens_, pre_att_rms_out, + MakeMat(layer_weights_.qkv_einsum_w.data(), kModelDim, kModelDim, + kHeads * kQKVDim * kModelDim), + layer_weights_.qkv_einsum_w.scale(), /*add=*/nullptr, + MakeMat(kv, kKVHeads * 2 * kQKVDim, kCachePosSize), pool_); + } else { + // Proceed row by row because there will be wraparound. 
+ for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved; + ++interleaved_idx) { + const float* x = activations_.pre_att_rms_out.Batch(interleaved_idx); + const size_t query_idx = interleaved_idx % num_queries_; + const size_t batch_idx = interleaved_idx / num_queries_; + KVCache& kv_cache = kv_caches_[query_idx]; + const size_t cache_pos = + div_seq_len_.Remainder(batch_start + batch_idx); + const size_t kv_offset = + cache_pos * kCachePosSize + layer_ * kCacheLayerSize; + float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset; + // KV structure is [k, v, k, v, ....] = kKVHeads pairs of (k, v). + MatVec( + layer_weights_.qkv_einsum_w, kHeads * kQKVDim * kModelDim, x, + activations_.even_odd.All(), kv, pool_); + } + } + } + + // Apply positional encodings for K (and copy KV to cache if MHA). + pool_.Run( + 0, kKVHeads * num_interleaved, + [&](uint64_t task, size_t /*thread*/) HWY_ATTR { + const size_t head = task % kKVHeads; + const size_t interleaved_idx = task / kKVHeads; + const size_t query_idx = interleaved_idx % num_queries_; + const size_t batch_idx = interleaved_idx / num_queries_; + const size_t pos = batch_start + batch_idx; + const size_t cache_pos = div_seq_len_.Remainder(pos); + const size_t kv_offset = cache_pos * kCachePosSize + + layer_ * kCacheLayerSize + + head * kQKVDim * 2; + KVCache& kv_cache = kv_caches_[query_idx]; + float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset; + const float* HWY_RESTRICT mha_kv = + activations_.q.Batch(interleaved_idx) + head * kQStride + kQKVDim; + + // Copy from `q` if MHA, or apply in-place. + PositionalEncodingQK(kIsMHA ? mha_kv : kv, pos, layer_, 1.0f, kv); + + // If MHA, also copy V into KVCache. + if (kIsMHA) { + hwy::CopyBytes(mha_kv + kQKVDim, kv + kQKVDim, + kQKVDim * sizeof(*kv)); + } + }); + } + + // Computes Q.K scores, which are "logits" (or scores) stored to head_att. + HWY_INLINE void QDotK(const size_t start_pos, const size_t pos, + const size_t head_offset, const float* HWY_RESTRICT q, + const KVCache& kv_cache, float* HWY_RESTRICT head_att) { + if (HWY_LIKELY(pos <= kSeqLen)) { + // Slightly faster: no wraparound. + for (size_t pos2 = start_pos; pos2 <= pos; ++pos2) { + const size_t kv_offset = + pos2 * kCachePosSize + layer_ * kCacheLayerSize + head_offset; + const float* HWY_RESTRICT k = &kv_cache.kv_cache[kv_offset]; + const float score = Dot(q, k, kQKVDim); + head_att[pos2] = score; + } + } else { + for (size_t pos2 = start_pos; pos2 <= pos; ++pos2) { + const size_t cache_pos = div_seq_len_.Remainder(pos2); + const size_t kv_offset = + cache_pos * kCachePosSize + layer_ * kCacheLayerSize + head_offset; + const float* HWY_RESTRICT k = &kv_cache.kv_cache[kv_offset]; + const float score = Dot(q, k, kQKVDim); + head_att[pos2 % kSeqLen] = score; } } } - // Apply positional encodings for K (and copy KV to cache if MHA). 
- pool.Run( - 0, kKVHeads * num_interleaved, - [&](uint64_t task, size_t /*thread*/) HWY_ATTR { - const size_t head = task % kKVHeads; - const size_t interleaved_idx = task / kKVHeads; - const size_t query_idx = interleaved_idx % num_queries; - const size_t batch_idx = interleaved_idx / num_queries; - const size_t pos = batch_start + batch_idx; - const size_t cache_pos = div_seq_len.Remainder(pos); - const size_t kv_offset = cache_pos * kCachePosSize + - layer * kCacheLayerSize + head * kQKVDim * 2; - KVCache& kv_cache = kv_caches[query_idx]; - float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset; - if constexpr (kIsMHA) { - // For MHA, copy KV into the KV cache from scratch space (see above). - const float* HWY_RESTRICT q = - activations.q.Batch(interleaved_idx) + head * kQStride; - // Skip past the Q part of `q`, and copy KV to `kv`. - hwy::CopyBytes(q + kQKVDim, kv, 2 * kQKVDim * sizeof(float)); - } - PostQK(kv, pos, layer); - }); + // Accumulates the sum of v (from `kv_cache`) * probability (`head_att`) into + // `att_out`. Equivalent in gemma/modules.py: + // encoded = jnp.einsum('BTNS,BSNH->BTNH', probs, value_proj) + static HWY_INLINE void WeightedSumV(const size_t start_pos, const size_t pos, + const float* HWY_RESTRICT head_att, + const size_t layer, + const size_t head_offset, + const hwy::Divisor& div_seq_len, + const KVCache& kv_cache, + float* HWY_RESTRICT att_out) { + hwy::ZeroBytes(att_out, kQKVDim * sizeof(*att_out)); - // A "head group" in the context of GQA refers to a collection of query heads - // that share the same key and value heads. - static_assert((kHeads % kKVHeads) == 0, - "query heads must be a multiple of key-value heads"); - constexpr size_t kHeadGroups = kHeads / kKVHeads; - // For each head (token, query), compute Q.K, softmax, and weighted V. - pool.Run( - 0, kHeads * num_interleaved, - [&](uint64_t task, size_t /*thread*/) HWY_ATTR { - const size_t head = task % kHeads; - const size_t interleaved_idx = task / kHeads; - const size_t query_idx = interleaved_idx % num_queries; - const size_t batch_idx = interleaved_idx / num_queries; - const size_t head_offset = (head / kHeadGroups) * kQKVDim * 2; - KVCache& kv_cache = kv_caches[query_idx]; - float* HWY_RESTRICT q = - activations.q.Batch(interleaved_idx) + head * kQStride; - - // Apply rope and scaling to Q. - const size_t pos = batch_start + batch_idx; - PostQK(q, pos, layer); - MulByConst(kQueryScale, q, kQKVDim); - - // Compute Q.K scores, yielding "logits" (or scores) in head_att. - float* HWY_RESTRICT head_att = - activations.att.Batch(interleaved_idx) + head * kSeqLen; - // Usually start_pos is 0, unless pos is larger than the attention - // window size, then it is pos - window_size + 1. - const size_t start_pos = - pos - std::min(TConfig::kAttentionWindowSizes[layer] - 1, pos); - for (size_t pos2 = start_pos; pos2 <= pos; ++pos2) { - const size_t cache_pos = div_seq_len.Remainder(pos2); - const size_t kv_offset = - cache_pos * kCachePosSize + layer * kCacheLayerSize + head_offset; - const float* HWY_RESTRICT k = &kv_cache.kv_cache[kv_offset]; - const float score = Dot(q, k, kQKVDim); - head_att[pos2 % kSeqLen] = score; - } - - // SoftMax. May be preceded by SoftCap. Yields "probabilities" in - // head_att. - const size_t head_att_len = std::min(pos + 1, kSeqLen); - if constexpr (TConfig::kAttCap > 0.0f) { - LogitsSoftCap(TConfig::kAttCap, head_att, head_att_len); - } - Softmax(head_att, head_att_len); - - // Summation of v (kv_cache) weighted by probs (head_att) - // into "encoded" (att_out). 
Compare gemma/modules.py: - // encoded = jnp.einsum('BTNS,BSNH->BTNH', probs, value_proj) - float* HWY_RESTRICT att_out = - activations.att_out.Batch(interleaved_idx) + head * kQKVDim; - hwy::ZeroBytes(att_out, kQKVDim * sizeof(*att_out)); - for (size_t pos2 = start_pos; pos2 <= pos; ++pos2) { - const size_t cache_pos = div_seq_len.Remainder(pos2); - const size_t kv_offset = - cache_pos * kCachePosSize + layer * kCacheLayerSize + head_offset; - float* HWY_RESTRICT v = - kv_cache.kv_cache.get() + kv_offset + kQKVDim; - MulByConstAndAdd(head_att[pos2 % kSeqLen], v, att_out, kQKVDim); - } - }); - - // Sum encoded (att_out) over num_heads and head_dim (kQKVDim) - // into output (layer_out). Compare gemma/modules.py: - // attn_output = self.attn_vec_einsum('BTNH,NHD->BTD', encoded) - for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved; - ++interleaved_idx) { - // TODO(szabadka) Use a single MatVecAdd like in GriffinRecurrent() after - // rearranging the weights. - float* HWY_RESTRICT att_out = activations.att_out.Batch(interleaved_idx); - float* HWY_RESTRICT layer_out = - activations.att_post2.Batch(interleaved_idx); - // Head 0 (and potentially biases) -> layer_out. - // attn_vec_einsum_w has shape [kHeads, kQKVDim, kModelDim]. - constexpr bool kAdd = TConfig::kSoftmaxAttnOutputBiases; - const float* bias = - kAdd ? layer_weights->attention_output_biases.data_scale1() : nullptr; - MatVecT( - layer_weights->attn_vec_einsum_w, 0, att_out, bias, - activations.even_odd.All(), layer_out, pool); - // Head 1 and following are added to layer_out. - for (size_t head = 1; head < kHeads; ++head) { - // NOTE: this is a single kModelDim temp output. If parallelized or using - // MatMul, add per-thread storage. - float* HWY_RESTRICT head_out = activations.att_post1.All(); - // TODO: requires MatMul support for offsets. - MatVec( - layer_weights->attn_vec_einsum_w, head * kModelDim * kQKVDim, - att_out + head * kQKVDim, activations.even_odd.All(), head_out, pool); - AddFrom(head_out, layer_out, kModelDim); + if (HWY_LIKELY(pos <= kSeqLen)) { + // Slightly faster: no wraparound. + for (size_t pos2 = start_pos; pos2 <= pos; ++pos2) { + const size_t kv_offset = + pos2 * kCachePosSize + layer * kCacheLayerSize + head_offset; + const float* HWY_RESTRICT v = + kv_cache.kv_cache.get() + kv_offset + kQKVDim; + MulByConstAndAdd(head_att[pos2], v, att_out, kQKVDim); + } + } else { + for (size_t pos2 = start_pos; pos2 <= pos; ++pos2) { + const size_t cache_pos = div_seq_len.Remainder(pos2); + const size_t kv_offset = + cache_pos * kCachePosSize + layer * kCacheLayerSize + head_offset; + const float* HWY_RESTRICT v = + kv_cache.kv_cache.get() + kv_offset + kQKVDim; + MulByConstAndAdd(head_att[pos2 % kSeqLen], v, att_out, kQKVDim); + } } } -} + + HWY_NOINLINE void DotSoftmaxWeightedSum(const size_t batch_start, + const size_t num_interleaved) { + PROFILER_ZONE("Gen.Attention.DotSoftmax"); + GEMMA_CONSTEXPR_SQRT float kQueryScale = ChooseQueryScale(); + + // A "head group" in the context of GQA refers to a collection of query + // heads that share the same key and value heads. + static_assert((kHeads % kKVHeads) == 0, + "query heads must be a multiple of key-value heads"); + constexpr size_t kHeadGroups = kHeads / kKVHeads; + + // For each head (token, query), compute Q.K, softmax, and weighted V. 
+ pool_.Run(0, kHeads * num_interleaved, + [&](uint64_t task, size_t /*thread*/) HWY_ATTR { + const size_t head = task % kHeads; + const size_t interleaved_idx = task / kHeads; + const size_t query_idx = interleaved_idx % num_queries_; + const size_t batch_idx = interleaved_idx / num_queries_; + const size_t head_offset = (head / kHeadGroups) * kQKVDim * 2; + KVCache& kv_cache = kv_caches_[query_idx]; + float* HWY_RESTRICT q = + activations_.q.Batch(interleaved_idx) + head * kQStride; + + // Apply rope and scaling to Q. + const size_t pos = batch_start + batch_idx; + PositionalEncodingQK(q, pos, layer_, kQueryScale, q); + + const size_t start_pos = StartPos(pos, layer_); + + float* HWY_RESTRICT head_att = + activations_.att.Batch(interleaved_idx) + head * kSeqLen; + QDotK(start_pos, pos, head_offset, q, kv_cache, head_att); + // SoftMax with optional SoftCap yields "probabilities" in + // head_att. + const size_t head_att_len = std::min(pos + 1, kSeqLen); + MaybeLogitsSoftCap(TConfig::kAttCap, head_att, head_att_len); + Softmax(head_att, head_att_len); + + float* HWY_RESTRICT att_out = + activations_.att_out.Batch(interleaved_idx) + + head * kQKVDim; + WeightedSumV(start_pos, pos, head_att, layer_, head_offset, + div_seq_len_, kv_cache, att_out); + }); + } + + // Sums encoded (`att_out`) over num_heads and head_dim (kQKVDim) into output + // (`layer_out`). Compare gemma/modules.py: + // attn_output = self.attn_vec_einsum('BTNH,NHD->BTD', encoded) + HWY_NOINLINE void SumHeads(const size_t num_interleaved) { + PROFILER_ZONE("Gen.Attention.SumHeads"); + for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved; + ++interleaved_idx) { + // TODO(szabadka) Use a single MatVecAdd like in GriffinRecurrent() after + // rearranging the weights. + float* HWY_RESTRICT att_out = activations_.att_out.Batch(interleaved_idx); + float* HWY_RESTRICT layer_out = + activations_.att_post2.Batch(interleaved_idx); + // Head 0 (and potentially biases) -> layer_out. + // attn_vec_einsum_w has shape [kHeads, kQKVDim, kModelDim]. + constexpr bool kAdd = TConfig::kSoftmaxAttnOutputBiases; + const float* bias = + kAdd ? layer_weights_.attention_output_biases.data_scale1() : nullptr; + MatVecT( + layer_weights_.attn_vec_einsum_w, 0, att_out, bias, + activations_.even_odd.All(), layer_out, pool_); + // Head 1 and following are added to layer_out. + for (size_t head = 1; head < kHeads; ++head) { + // NOTE: this is a single kModelDim temp output. If parallelized or + // using MatMul, add per-thread storage. + float* HWY_RESTRICT head_out = activations_.att_post1.All(); + // TODO: requires MatMul support for offsets. 
+ MatVec( + layer_weights_.attn_vec_einsum_w, head * kModelDim * kQKVDim, + att_out + head * kQKVDim, activations_.even_odd.All(), head_out, + pool_); + AddFrom(head_out, layer_out, kModelDim); + } + } + } + + public: + GemmaAttention(size_t interleaved_start, size_t num_tokens, + size_t num_queries, size_t layer, Activations& activations, + const CompressedLayer* layer_weights, + const hwy::Divisor& div_seq_len, const KVCaches& kv_caches, + hwy::ThreadPool& pool) + : interleaved_start_(interleaved_start), + num_tokens_(num_tokens), + num_queries_(num_queries), + layer_(layer), + activations_(activations), + layer_weights_(*layer_weights), + div_seq_len_(div_seq_len), + kv_caches_(kv_caches), + pool_(pool) { + HWY_DASSERT(interleaved_start_ % num_queries_ == 0); + HWY_DASSERT(num_queries_ <= kv_caches_.size()); + } + + HWY_INLINE void operator()() { + const size_t batch_start = interleaved_start_ / num_queries_; + const size_t num_interleaved = num_tokens_ * num_queries_; + + ComputeQKV(batch_start, num_interleaved); + DotSoftmaxWeightedSum(batch_start, num_interleaved); + SumHeads(num_interleaved); + } + + private: + const size_t interleaved_start_; + const size_t num_tokens_; + const size_t num_queries_; + const size_t layer_; + Activations& activations_; + const CompressedLayer& layer_weights_; + const hwy::Divisor& div_seq_len_; + const KVCaches& kv_caches_; + hwy::ThreadPool& pool_; +}; template HWY_NOINLINE void Attention(LayerAttentionType type, size_t interleaved_start, size_t num_tokens, size_t num_queries, size_t layer, Activations& activations, const CompressedLayer* layer_weights, + const hwy::Divisor& div_seq_len, const KVCaches& kv_caches, hwy::ThreadPool& pool) { if (type == LayerAttentionType::kGemma) { GemmaAttention(interleaved_start, num_tokens, num_queries, layer, - activations, layer_weights, kv_caches, pool); + activations, layer_weights, div_seq_len, kv_caches, + pool)(); } else { // Only reached if the model is Griffin. `if constexpr` prevents generating // this code for non-Griffin models. 
@@ -421,6 +522,7 @@ HWY_NOINLINE void Attention(LayerAttentionType type, size_t interleaved_start, template HWY_NOINLINE void Activation(T* HWY_RESTRICT c1, T* HWY_RESTRICT c2, size_t count) { + PROFILER_ZONE("Gen.Activation"); namespace hn = hwy::HWY_NAMESPACE; using DF = hn::ScalableTag; using VF = hn::Vec; @@ -516,7 +618,8 @@ template HWY_NOINLINE void TransformerLayer( size_t num_tokens, size_t num_queries, size_t pos, size_t layer, const CompressedLayer* layer_weights, Activations& activations, - const KVCaches& kv_caches, hwy::ThreadPool& pool) { + const hwy::Divisor& div_seq_len, const KVCaches& kv_caches, + hwy::ThreadPool& pool) { constexpr size_t kModelDim = TConfig::kModelDim; const size_t num_interleaved = num_tokens * num_queries; auto type = TConfig::kLayerConfig[layer]; @@ -528,7 +631,7 @@ HWY_NOINLINE void TransformerLayer( activations.pre_att_rms_out.All(), kModelDim); Attention(type, pos, num_tokens, num_queries, layer_of_type, - activations, layer_weights, kv_caches, pool); + activations, layer_weights, div_seq_len, kv_caches, pool); PostNorm(num_interleaved, layer_weights->post_attention_norm_scale, activations.att_post2.All()); @@ -606,6 +709,7 @@ class PrefillState { const size_t query_idx_start, const CompressedWeights& weights, const RuntimeConfig& runtime_config, + const hwy::Divisor& div_seq_len, const KVCaches& kv_caches, PerClusterPools& pools) { PROFILER_ZONE("Gen.Prefill"); const size_t num_queries = prompts.size(); @@ -638,9 +742,10 @@ class PrefillState { // Transformer with one batch of tokens from a single query. for (size_t layer = 0; layer < TConfig::kLayers; ++layer) { const auto* layer_weights = weights.GetLayer(layer); - TransformerLayer( - tbatch_size, kPrefillQueries, pos + tbatch_start, layer, - layer_weights, activations, prefill_kv_caches, inner_pool); + TransformerLayer(tbatch_size, kPrefillQueries, + pos + tbatch_start, layer, + layer_weights, activations, div_seq_len, + prefill_kv_caches, inner_pool); } // NOTE: we unconditionally call StreamToken, even if EOS. @@ -664,6 +769,7 @@ HWY_NOINLINE void Transformer(const int* tokens, size_t num_tokens, size_t num_queries, size_t pos, const CompressedWeights& weights, Activations& activations, + const hwy::Divisor& div_seq_len, const KVCaches& kv_caches, hwy::ThreadPool& pool, const LayersOutputFunc& layers_output) { const size_t num_interleaved = num_tokens * num_queries; @@ -684,7 +790,8 @@ HWY_NOINLINE void Transformer(const int* tokens, size_t num_tokens, for (size_t layer = 0; layer < TConfig::kLayers; ++layer) { const CompressedLayer* layer_weights = weights.GetLayer(layer); TransformerLayer(num_tokens, num_queries, pos, layer, - layer_weights, activations, kv_caches, pool); + layer_weights, activations, div_seq_len, + kv_caches, pool); if (layers_output) { for (size_t token_idx = 0; token_idx < num_interleaved; ++token_idx) { @@ -822,6 +929,7 @@ void GenerateT(const ByteStorageT& weights_u8, Activations& activations, HWY_ASSERT(num_queries <= 4096); // TokenStreamer uses BitSet4096. 
HWY_ASSERT(num_queries <= activations.x.BatchSize()); HWY_ASSERT(kv_caches.size() == num_queries); + const hwy::Divisor div_seq_len(static_cast(kv_caches[0].seq_len)); size_t min_prompt_size, max_prompt_size; const std::vector prompt = InterleaveQueries( @@ -857,7 +965,8 @@ void GenerateT(const ByteStorageT& weights_u8, Activations& activations, pools); prefill_start = hwy::platform::Now(); prefill.Prefill(prompts, prefill_per_query, pos, query_idx_start, - weights, runtime_config, kv_caches, pools); + weights, runtime_config, div_seq_len, kv_caches, + pools); timing_info.NotifyPrefill(prefill_per_query * num_queries, prefill_start); } @@ -881,8 +990,8 @@ void GenerateT(const ByteStorageT& weights_u8, Activations& activations, ++gen_per_query) { // Decode: generate one token for each query. Transformer(gen_tokens.data(), /*num_tokens=*/1, num_queries, - interleaved_pos, weights, activations, kv_caches, pool, - runtime_config.layers_output); + interleaved_pos, weights, activations, div_seq_len, + kv_caches, pool, runtime_config.layers_output); interleaved_pos += num_queries; bool all_queries_eos = true; @@ -895,9 +1004,7 @@ void GenerateT(const ByteStorageT& weights_u8, Activations& activations, MakeMat(activations.logits.All(), kVocabSize), pool); for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) { float* HWY_RESTRICT logits = activations.logits.Batch(query_idx); - if constexpr (TConfig::kFinalCap > 0.0f) { - LogitsSoftCap(TConfig::kFinalCap, logits, kVocabSize); - } + MaybeLogitsSoftCap(TConfig::kFinalCap, logits, kVocabSize); Softmax(logits, kVocabSize); const int token = sample_token(logits, kVocabSize); timing_info.NotifyGenerated(prefill_start, gen_start); diff --git a/ops/ops-inl.h b/ops/ops-inl.h index eb5c15e..9a8e7c0 100644 --- a/ops/ops-inl.h +++ b/ops/ops-inl.h @@ -381,16 +381,16 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void AddAbsolutePositionalEmbeddings( of this rotation matrix which is simply the same matrix with -pos parameter) */ -static HWY_NOINLINE HWY_MAYBE_UNUSED void Rope(float* HWY_RESTRICT x, - size_t dim_qkv, int pos) { +// `inv_timescale[dim_qkv / 2]` is precomputed in Activations::Allocate. +// This overload is called from backprop/ and if kUseHalfRope. +static HWY_NOINLINE HWY_MAYBE_UNUSED void Rope( + float* HWY_RESTRICT x, size_t dim_qkv, + const float* HWY_RESTRICT inv_timescale, int pos) { + PROFILER_FUNC; HWY_DASSERT(dim_qkv % 2 == 0); const size_t half_dim_qkv = dim_qkv / 2; for (size_t dim = 0; dim < half_dim_qkv; ++dim) { - const float freq_exponents = - StaticCast(2 * dim) / StaticCast(dim_qkv); - // Replacing with expf(ln(1E4) * freq_exponents) changes results noticeably. - const float timescale = powf(10000.0f, freq_exponents); - const float theta = StaticCast(pos) / timescale; + const float theta = StaticCast(pos) * inv_timescale[dim]; const float cos_val = cosf(theta); const float sin_val = sinf(theta); const float x0 = x[dim]; @@ -400,24 +400,23 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void Rope(float* HWY_RESTRICT x, } } -static HWY_NOINLINE HWY_MAYBE_UNUSED void RopeAndMulBy(const float mul, - float* HWY_RESTRICT x, - size_t dim_qkv, - int pos) { +// TODO(janwas): vectorize +// `inv_timescale[dim_qkv / 2]` is precomputed in Activations::Allocate. 
+static HWY_NOINLINE HWY_MAYBE_UNUSED void RopeAndMulBy( + const float mul, const float* HWY_RESTRICT x, size_t dim_qkv, + const float* HWY_RESTRICT inv_timescale, int pos, + float* HWY_RESTRICT x_out) { + PROFILER_FUNC; HWY_DASSERT(dim_qkv % 2 == 0); const size_t half_dim_qkv = dim_qkv / 2; for (size_t dim = 0; dim < half_dim_qkv; ++dim) { - const float freq_exponents = - StaticCast(2 * dim) / StaticCast(dim_qkv); - // Replacing with expf(ln(1E4) * freq_exponents) changes results noticeably. - const float timescale = powf(10000.0f, freq_exponents); - const float theta = StaticCast(pos) / timescale; + const float theta = StaticCast(pos) * inv_timescale[dim]; const float cos_val = cosf(theta); const float sin_val = sinf(theta); const float x0 = x[dim]; const float x1 = x[dim + half_dim_qkv]; - x[dim] = mul * (x0 * cos_val - x1 * sin_val); - x[dim + half_dim_qkv] = mul * (x0 * sin_val + x1 * cos_val); + x_out[dim] = mul * (x0 * cos_val - x1 * sin_val); + x_out[dim + half_dim_qkv] = mul * (x0 * sin_val + x1 * cos_val); } } @@ -577,12 +576,19 @@ static HWY_NOINLINE void LogitsSoftCap(const float cap, float* HWY_RESTRICT x, }); } -static HWY_INLINE HWY_MAYBE_UNUSED void LogitsSoftCap(const float cap, - float* HWY_RESTRICT x, - const size_t size) { +static HWY_INLINE void LogitsSoftCap(const float cap, float* HWY_RESTRICT x, + const size_t size) { LogitsSoftCap(cap, x, size, size); } +// Calls LogitsSoftCap if cap != 0.0f. +static HWY_INLINE HWY_MAYBE_UNUSED void MaybeLogitsSoftCap( + const float cap, float* HWY_RESTRICT x, const size_t size) { + if (cap != 0.0f) { + LogitsSoftCap(cap, x, size, size); + } +} + static HWY_NOINLINE HWY_MAYBE_UNUSED size_t SampleArgmax(const float* probabilities, size_t vocab_size) { size_t max_index = 0;
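
For reference, a minimal standalone sketch of the patch's central idea: the reciprocal Rope timescales are computed once per model (cf. Activations::CreateInvTimescale) instead of calling powf(10000, 2*dim/dim_qkv) for every position, and the rotation is fused with the scaling and made non-in-place so the result can be written straight into a KV cache slot. This is simplified from the patch (scalar loops, std::vector in place of RowVectorBatch, no HWY_RESTRICT or profiler zones):

#include <cmath>
#include <cstddef>
#include <vector>

// Precompute 1 / timescale once per model.
std::vector<float> CreateInvTimescale(size_t rope_dim) {
  std::vector<float> inv_timescale(rope_dim / 2);
  for (size_t dim = 0; dim < rope_dim / 2; ++dim) {
    const float freq_exponent =
        static_cast<float>(2 * dim) / static_cast<float>(rope_dim);
    inv_timescale[dim] = 1.0f / std::pow(10000.0f, freq_exponent);
  }
  return inv_timescale;
}

// Fused rotate-and-scale, non-in-place: reads `x`, writes `mul * rotate(x)` to
// `out`, e.g. directly into a KV cache slot, so no copy + MulByConst pass is
// needed. In this restrict-free sketch, out == x and mul == 1.0f degenerate to
// plain in-place Rope, because each (x0, x1) pair is read before it is written.
void RopeAndMulBy(float mul, const float* x, size_t dim_qkv,
                  const float* inv_timescale, int pos, float* out) {
  const size_t half = dim_qkv / 2;
  for (size_t dim = 0; dim < half; ++dim) {
    // Only a multiply per dimension; powf is gone from the per-token path.
    const float theta = static_cast<float>(pos) * inv_timescale[dim];
    const float c = std::cos(theta);
    const float s = std::sin(theta);
    const float x0 = x[dim];
    const float x1 = x[dim + half];
    out[dim] = mul * (x0 * c - x1 * s);
    out[dim + half] = mul * (x0 * s + x1 * c);
  }
}

In the patch, PositionalEncodingQK uses this shape on the non-kUseHalfRope path: mul is 1.0f for K and kQueryScale for Q, and the output pointer is either the KV cache slot or q itself.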
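
The "sink if() into MaybeLogitsSoftCap" change replaces the per-call-site `if constexpr (TConfig::kAttCap > 0.0f)` / `kFinalCap > 0.0f` guards with a helper that checks the cap itself. A scalar sketch of the shape (the real LogitsSoftCap in ops-inl.h is vectorized and takes an extra size argument):

#include <cmath>
#include <cstddef>

// Scalar stand-in for the vectorized LogitsSoftCap: x <- cap * tanh(x / cap).
void LogitsSoftCapScalar(float cap, float* x, size_t size) {
  const float inv_cap = 1.0f / cap;
  for (size_t i = 0; i < size; ++i) {
    x[i] = cap * std::tanh(x[i] * inv_cap);
  }
}

// The zero-cap check lives in the helper, so models without an attention or
// final cap simply pass 0.0f and the call site needs no `if constexpr`.
void MaybeLogitsSoftCap(float cap, float* x, size_t size) {
  if (cap != 0.0f) LogitsSoftCapScalar(cap, x, size);
}

Call sites shrink to a single line, e.g. MaybeLogitsSoftCap(TConfig::kAttCap, head_att, head_att_len) before the Softmax.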
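
Finally, "move into class" refers to the GemmaAttention wrapper above: arguments shared by all phases are captured once in the constructor, the former function body is split into ComputeQKV, DotSoftmaxWeightedSum and SumHeads, and operator() chains them. A toy, self-contained version of the pattern follows; all types and phase bodies here are placeholders, not the patch's code:

#include <cstddef>
#include <vector>

class AttentionPass {
 public:
  AttentionPass(size_t num_tokens, size_t num_queries, std::vector<float>& q)
      : num_tokens_(num_tokens), num_queries_(num_queries), q_(q) {}

  // Call sites construct and immediately invoke: AttentionPass(t, nq, buf)();
  void operator()() {
    const size_t num_interleaved = num_tokens_ * num_queries_;
    ComputeQKV(num_interleaved);
    DotSoftmaxWeightedSum(num_interleaved);
    SumHeads(num_interleaved);
  }

 private:
  // Placeholder phases; in the patch these hold the QKV projection, the
  // Q.K / softmax / weighted-V loop, and the per-head output summation.
  void ComputeQKV(size_t /*num_interleaved*/) {}
  void DotSoftmaxWeightedSum(size_t /*num_interleaved*/) {}
  void SumHeads(size_t /*num_interleaved*/) {}

  const size_t num_tokens_;
  const size_t num_queries_;
  std::vector<float>& q_;
};

In the patch the three phases are HWY_NOINLINE while operator() is HWY_INLINE, so the call site stays as compact as the old single-function version.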