From 992a2cbbc0e726cb2b4e30026d17abd0d7b742f6 Mon Sep 17 00:00:00 2001 From: Jan Wassenberg Date: Wed, 17 Jul 2024 04:37:40 -0700 Subject: [PATCH] De-templatize Activations, add RowVectorBatch class Also remove most kBatchSize args. PiperOrigin-RevId: 653185525 --- gemma/activations.h | 165 ++++++++++++++------- gemma/gemma-inl.h | 346 +++++++++++++++++++------------------------- gemma/gemma.cc | 65 +++++---- gemma/gemma.h | 9 +- gemma/ops.h | 13 +- 5 files changed, 305 insertions(+), 293 deletions(-) diff --git a/gemma/activations.h b/gemma/activations.h index 4b88bd4..9e3cc4e 100644 --- a/gemma/activations.h +++ b/gemma/activations.h @@ -18,77 +18,130 @@ #include -#include - -#include "gemma/common.h" // AllocateSizeof -#include "hwy/base.h" // hwy::bfloat16_t +#include "gemma/common.h" // kMaxThreads - TODO: remove +#include "hwy/aligned_allocator.h" +#include "hwy/base.h" // HWY_DASSERT namespace gcpp { -// Must be aligned. -template +// Owns dynamically-allocated aligned memory for a batch of row vectors. +// This can be seen as a (batch_size x len) matrix. +template +class RowVectorBatch { + public: + // Default ctor for Activations ctor. + RowVectorBatch() : batch_size_(0), len_(0) {} + // Main ctor, called from Activations::Allocate. + RowVectorBatch(size_t batch_size, size_t len) + : batch_size_(batch_size), len_(len) { + mem_ = hwy::AllocateAligned(batch_size * len); + } + + // Move-only + RowVectorBatch(RowVectorBatch&) noexcept = delete; + RowVectorBatch& operator=(RowVectorBatch&) noexcept = delete; + RowVectorBatch(RowVectorBatch&&) noexcept = default; + RowVectorBatch& operator=(RowVectorBatch&&) noexcept = default; + + size_t BatchSize() const { return batch_size_; } + size_t Len() const { return len_; } + + // Returns the given row vector of length `Len()`. + T* Batch(size_t batch_idx) { + HWY_DASSERT(batch_idx < batch_size_); + return mem_.get() + batch_idx * len_; + } + + // For MatMul or other operations that process the entire batch at once. + T* All() { return mem_.get(); } + size_t NumBytes() const { return batch_size_ * len_ * sizeof(T); } + + private: + hwy::AlignedFreeUniquePtr mem_; + size_t batch_size_; // rows in the matrix + size_t len_; // columns in the matrix = vector length +}; + struct Activations { - static constexpr size_t kModelDim = TConfig::kModelDim; - static constexpr size_t kQKVDim = TConfig::kQKVDim; - static constexpr size_t kHeads = TConfig::kHeads; - static constexpr size_t kKVHeads = TConfig::kKVHeads; - static constexpr bool kIsMHA = kHeads == kKVHeads; // Multi-Head Attention - // Stride between subsequent queries. Each of Q, K, V are of length kQKVDim, - // but for MHA we store them as Q,K,V, Q,K,V, .. instead of Q..Q, K..K, V..V. - static constexpr size_t kQStride = kQKVDim * (kIsMHA ? 3 : 1); + RowVectorBatch x; // input + RowVectorBatch q; // query, also KV if MHA. + RowVectorBatch logits; - std::array x; // input - std::array pre_att_rms_out; - std::array q; // query vector - std::array - att; // attention vector - std::array att_out; // attention output - std::array - att_post1; // attention output after linear transformation, per head - std::array - att_post2; // accumulation of attention outputs over heads - std::array bf_pre_ffw_rms_out; - std::array ffw_hidden; + // Attention + RowVectorBatch pre_att_rms_out; + RowVectorBatch att; // attention vector + RowVectorBatch att_out; // attention output + // After linear transformation, shared by all heads + RowVectorBatch att_post1; + // Accumulation of attention outputs over heads + RowVectorBatch att_post2; - // For FFW MatMul. - std::array C1; - std::array C2; + // Gated FFW + RowVectorBatch bf_pre_ffw_rms_out; + RowVectorBatch C1; + RowVectorBatch C2; + RowVectorBatch ffw_out; - std::array ffw_out; - std::array logits; + // Griffin + RowVectorBatch griffin_x; + RowVectorBatch griffin_y; + RowVectorBatch griffin_gate_x; + RowVectorBatch griffin_multiplier; // For bf16/f32 vectors * bf16 matrix: faster to unpack once beforehand, into // per-thread storage. - // TODO: only used for MatVec, remove once that is gone. - std::array even_odd; + // TODO: remove once MatVec is gone. + RowVectorBatch even_odd; - // Griffin layer internal activations - static constexpr size_t kGriffinDim = - TConfig::kGriffinLayers > 0 ? kModelDim : 0; - std::array griffin_x; - std::array griffin_y; - std::array griffin_gate_x; - std::array griffin_multiplier; -}; + // Multi-Head Attention? + template + static constexpr bool IsMHA() { + return TConfig::kHeads == TConfig::kKVHeads; + } -template -struct AllocateState { - void operator()(ByteStorageT& prefill, ByteStorageT& decode) const { - // When batching queries, the prefill batch size is reduced by a factor - // of kBatchedQueryBatchSize - prefill = - AllocateSizeof>(); - decode = AllocateSizeof< - Activations>(); + // Stride between subsequent queries. Each of Q, K, V are of length kQKVDim, + // but for MHA we store them as Q,K,V, Q,K,V, .. instead of Q..Q, K..K, V..V. + template + static constexpr size_t QStride() { + return TConfig::kQKVDim * (IsMHA() ? 3 : 1); + } + + template + void Allocate(size_t batch_size) { + constexpr size_t kModelDim = TConfig::kModelDim; + constexpr size_t kQKVDim = TConfig::kQKVDim; + constexpr size_t kHeads = TConfig::kHeads; + constexpr size_t kFFHiddenDim = TConfig::kFFHiddenDim; + constexpr size_t kVocabSize = TConfig::kVocabSize; + constexpr size_t kSeqLen = TConfig::kSeqLen; + constexpr size_t kGriffinLayers = TConfig::kGriffinLayers; + + x = RowVectorBatch(batch_size, kModelDim); + q = RowVectorBatch(batch_size, kHeads * QStride()); + logits = RowVectorBatch(batch_size, kVocabSize); + + pre_att_rms_out = RowVectorBatch(batch_size, kModelDim); + att = RowVectorBatch(batch_size, kHeads * kSeqLen); + att_out = RowVectorBatch(batch_size, kHeads * kQKVDim); + att_post1 = RowVectorBatch(1, kModelDim); + att_post2 = RowVectorBatch(batch_size, kModelDim); + + bf_pre_ffw_rms_out = RowVectorBatch(batch_size, kModelDim); + C1 = RowVectorBatch(batch_size, kFFHiddenDim); + C2 = RowVectorBatch(batch_size, kFFHiddenDim); + ffw_out = RowVectorBatch(batch_size, kModelDim); + + if (kGriffinLayers > 0) { + griffin_x = RowVectorBatch(batch_size, kModelDim); + griffin_y = RowVectorBatch(batch_size, kModelDim); + griffin_gate_x = RowVectorBatch(batch_size, kModelDim); + griffin_multiplier = RowVectorBatch(batch_size, kModelDim); + } + + even_odd = RowVectorBatch(1, kModelDim * kMaxThreads); } }; -template -Activations& GetActivations(const ByteStorageT& state_u8) { - return *reinterpret_cast*>(state_u8.get()); -} - } // namespace gcpp #endif // THIRD_PARTY_GEMMA_CPP_GEMMA_ACTIVATIONS_H_ diff --git a/gemma/gemma-inl.h b/gemma/gemma-inl.h index 3426a63..cfc6edf 100644 --- a/gemma/gemma-inl.h +++ b/gemma/gemma-inl.h @@ -30,6 +30,7 @@ #include #include +#include #include #include "gemma/activations.h" @@ -59,33 +60,27 @@ HWY_BEFORE_NAMESPACE(); namespace gcpp { namespace HWY_NAMESPACE { -template +template HWY_NOINLINE void GriffinRecurrent( size_t batch_start, size_t num_tokens, size_t num_queries, size_t layer, - Activations& activations, - const CompressedLayer* layer_weights, + Activations& activations, const CompressedLayer* layer_weights, const std::vector& kv_caches, hwy::ThreadPool& pool) { PROFILER_ZONE("Gen.Griffin"); - static_assert(kQueryBatchSize == 1, - "Griffin does not support batched queries."); HWY_ASSERT(num_queries == 1); // TODO: add batch query support for Griffin. KVCache& kv_cache = *kv_caches[0]; namespace hn = hwy::HWY_NAMESPACE; using D = hn::ScalableTag; - HWY_ASSERT(num_tokens <= kBatchSize); - static constexpr size_t kModelDim = - gcpp::Activations::kModelDim; + static constexpr size_t kModelDim = TConfig::kModelDim; static constexpr size_t kConv1dWidth = TConfig::kConv1dWidth; static constexpr size_t kHeads = TConfig::kHeads; // X / Y linear layers. for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) { - const size_t batch_offset = batch_idx * kModelDim; - float* HWY_RESTRICT y = activations.griffin_y.data() + batch_offset; - float* HWY_RESTRICT x = activations.griffin_x.data() + batch_offset; + float* HWY_RESTRICT y = activations.griffin_y.Batch(batch_idx); + float* HWY_RESTRICT x = activations.griffin_x.Batch(batch_idx); TwoMatVecAdd( layer_weights->griffin.linear_x_w, layer_weights->griffin.linear_y_w, 0, - activations.pre_att_rms_out.data() + batch_offset, + activations.pre_att_rms_out.Batch(batch_idx), /*add0=*/layer_weights->griffin.linear_x_biases.data(), /*add1=*/layer_weights->griffin.linear_y_biases.data(), /*out0=*/x, /*out1=*/y, pool); @@ -94,9 +89,8 @@ HWY_NOINLINE void GriffinRecurrent( // Conv1D. for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) { - const size_t batch_offset = batch_idx * kModelDim; const size_t pos = batch_start + batch_idx; - float* HWY_RESTRICT x = activations.griffin_x.data() + batch_offset; + float* HWY_RESTRICT x = activations.griffin_x.Batch(batch_idx); HWY_FULL(float) df; HWY_DASSERT(kModelDim % hn::Lanes(df) == 0); const size_t layer_offset = layer * kModelDim * (kConv1dWidth - 1); @@ -130,14 +124,11 @@ HWY_NOINLINE void GriffinRecurrent( // RGLRU for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) { - const size_t batch_offset = batch_idx * kModelDim; const size_t pos = batch_start + batch_idx; - float* HWY_RESTRICT y = activations.griffin_y.data() + batch_offset; - float* HWY_RESTRICT x = activations.griffin_x.data() + batch_offset; - float* HWY_RESTRICT gate_x = - activations.griffin_gate_x.data() + batch_offset; - float* HWY_RESTRICT a = - activations.griffin_multiplier.data() + batch_offset; + float* HWY_RESTRICT y = activations.griffin_y.Batch(batch_idx); + float* HWY_RESTRICT x = activations.griffin_x.Batch(batch_idx); + float* HWY_RESTRICT gate_x = activations.griffin_gate_x.Batch(batch_idx); + float* HWY_RESTRICT a = activations.griffin_multiplier.Batch(batch_idx); float* HWY_RESTRICT rnn_state = kv_cache.rglru_cache.get() + layer * kModelDim; @@ -185,13 +176,12 @@ HWY_NOINLINE void GriffinRecurrent( // Final linear layer. for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) { - const size_t batch_offset = batch_idx * kModelDim; - float* HWY_RESTRICT x = activations.griffin_x.data() + batch_offset; - float* out_ptr = activations.att_post2.data() + batch_idx * kModelDim; + float* HWY_RESTRICT x = activations.griffin_x.Batch(batch_idx); + float* out_ptr = activations.att_post2.Batch(batch_idx); MatVecAdd( layer_weights->griffin.linear_out_w, 0, x, layer_weights->griffin.linear_out_biases.data(), - activations.even_odd.data(), out_ptr, pool); + activations.even_odd.All(), out_ptr, pool); } } @@ -202,30 +192,26 @@ HWY_NOINLINE void PostQK(T* HWY_RESTRICT t, size_t pos, size_t layer) { Rope(t, TConfig::kUseHalfRope ? kQKVDim / 2 : kQKVDim, pos); } -template -HWY_NOINLINE void Attention( - size_t batch_and_query_start, size_t num_tokens, size_t num_queries, - size_t layer, - Activations& activations, - const CompressedLayer* layer_weights, - const std::vector& kv_caches, - hwy::ThreadPool& pool) { +template +HWY_NOINLINE void Attention(size_t batch_and_query_start, size_t num_tokens, + size_t num_queries, size_t layer, + Activations& activations, + const CompressedLayer* layer_weights, + const std::vector& kv_caches, + hwy::ThreadPool& pool) { PROFILER_ZONE("Gen.Attention"); - HWY_DASSERT(num_tokens <= kBatchSize); - HWY_DASSERT(num_queries <= kQueryBatchSize); HWY_DASSERT(batch_and_query_start % num_queries == 0); - using TActivations = Activations; - constexpr size_t kQKVDim = TActivations::kQKVDim; - constexpr size_t kQStride = TActivations::kQStride; + constexpr size_t kQKVDim = TConfig::kQKVDim; + constexpr size_t kQStride = Activations::QStride(); constexpr size_t kCachePosSize = CachePosSize()(); constexpr size_t kCacheLayerSize = CacheLayerSize()(); - constexpr size_t kModelDim = TActivations::kModelDim; + constexpr size_t kModelDim = TConfig::kModelDim; constexpr size_t kHeads = TConfig::kHeads; constexpr size_t kKVHeads = TConfig::kKVHeads; constexpr size_t kSeqLen = TConfig::kSeqLen; GEMMA_CONSTEXPR_SQRT float kQueryScale = ChooseQueryScale(); // Multi-Head Attention a.k.a. "use_qkv_einsum". - constexpr bool kIsMHA = TActivations::kIsMHA; + constexpr bool kIsMHA = Activations::IsMHA(); static_assert(!kIsMHA || TConfig::kInterleaveQKV); // MHA => interleaved const size_t batch_start = batch_and_query_start / num_queries; const size_t num_tokens_and_queries = num_tokens * num_queries; @@ -237,15 +223,14 @@ HWY_NOINLINE void Attention( // Compute Q only or QKV (if MHA). // If MHA, this also computes KV, which we copy to the KV cache below. MatMul_4x4_Batch( - num_tokens_and_queries, activations.pre_att_rms_out.data(), - layer_weights->qkv_einsum_w.data(), activations.q.data(), pool); + num_tokens_and_queries, activations.pre_att_rms_out.All(), + layer_weights->qkv_einsum_w.data(), activations.q.All(), pool); // Compute KV if not MHA. if constexpr (!kIsMHA) { for (size_t batch_and_query_idx = 0; batch_and_query_idx < num_tokens_and_queries; ++batch_and_query_idx) { - const float* x = - activations.pre_att_rms_out.data() + batch_and_query_idx * kModelDim; + const float* x = activations.pre_att_rms_out.Batch(batch_and_query_idx); const size_t query_idx = batch_and_query_idx % num_queries; const size_t batch_idx = batch_and_query_idx / num_queries; KVCache& kv_cache = *kv_caches[query_idx]; @@ -258,7 +243,7 @@ HWY_NOINLINE void Attention( // TODO: requires MatMul support for offsets. MatVec( layer_weights->qkv_einsum_w, kHeads * kQKVDim * kModelDim, x, - activations.even_odd.data(), kv, pool); + activations.even_odd.All(), kv, pool); } } @@ -279,8 +264,7 @@ HWY_NOINLINE void Attention( if constexpr (kIsMHA) { // For MHA, copy KV into the KV cache from scratch space (see above). const float* HWY_RESTRICT q = - activations.q.data() + (batch_and_query_idx * kHeads - + head) * kQStride; + activations.q.Batch(batch_and_query_idx) + head * kQStride; // Skip past the Q part of `q`, and copy KV to `kv`. memcpy(kv, q + kQKVDim, 2 * kQKVDim * sizeof(float)); } @@ -300,7 +284,7 @@ HWY_NOINLINE void Attention( const size_t head_offset = (head / kGroupHeads) * kQKVDim * 2; KVCache& kv_cache = *kv_caches[query_idx]; float* HWY_RESTRICT q = - activations.q.data() + (batch_and_query_idx * kHeads + head) * kQStride; + activations.q.Batch(batch_and_query_idx) + head * kQStride; // Apply rope and scaling to Q. const size_t pos = batch_start + batch_idx; @@ -309,11 +293,7 @@ HWY_NOINLINE void Attention( // Compute Q.K scores, yielding "logits" (or scores) in head_att. float* HWY_RESTRICT head_att = - activations.att.data() + head * kSeqLen - + batch_and_query_idx * kHeads * kSeqLen; - - - // Compute Q dot K scores + activations.att.Batch(batch_and_query_idx) + head * kSeqLen; const size_t start_pos = pos - std::min(TConfig::kAttentionWindowSizes[layer] - 1, pos); for (size_t pos2 = start_pos; pos2 <= pos; ++pos2) { @@ -335,8 +315,8 @@ HWY_NOINLINE void Attention( // Summation of v (kv_cache) weighted by probs (head_att) // into "encoded" (att_out). Compare gemma/modules.py: // encoded = jnp.einsum('BTNS,BSNH->BTNH', probs, value_proj) - float* HWY_RESTRICT att_out = activations.att_out.data() + head * kQKVDim + - batch_and_query_idx * kHeads * kQKVDim; + float* HWY_RESTRICT att_out = + activations.att_out.Batch(batch_and_query_idx) + head * kQKVDim; hwy::ZeroBytes(att_out, kQKVDim * sizeof(*att_out)); for (size_t pos2 = start_pos; pos2 <= pos; ++pos2) { const size_t cache_pos = pos2 % (kSeqLen + kPrefillBatchSize); @@ -355,26 +335,24 @@ HWY_NOINLINE void Attention( // TODO(szabadka) Use a single MatVecAdd like in GriffinRecurrent() after // rearranging the weights. float* HWY_RESTRICT att_out = - activations.att_out.data() + batch_and_query_idx * kHeads * kQKVDim; + activations.att_out.Batch(batch_and_query_idx); float* HWY_RESTRICT layer_out = - activations.att_post2.data() + batch_and_query_idx * kModelDim; + activations.att_post2.Batch(batch_and_query_idx); // Head 0 (and potentially biases) -> layer_out. // attn_vec_einsum_w has shape [kHeads, kQKVDim, kModelDim]. MatVecT( layer_weights->attn_vec_einsum_w, 0, att_out, layer_weights->attention_output_biases.data(), - activations.even_odd.data(), layer_out, pool); + activations.even_odd.All(), layer_out, pool); // Head 1 and following are added to layer_out. for (size_t head = 1; head < kHeads; ++head) { - // TODO(patrickms): Check this calculation - float* HWY_RESTRICT head_out = - activations.att_post1.data() + - head * kBatchSize * kQueryBatchSize * kModelDim; + // NOTE: this is a single kModelDim temp output. If parallelized or using + // MatMul, add per-thread storage. + float* HWY_RESTRICT head_out = activations.att_post1.All(); // TODO: requires MatMul support for offsets. MatVec( layer_weights->attn_vec_einsum_w, head * kModelDim * kQKVDim, - att_out + head * kQKVDim, - activations.even_odd.data(), head_out, pool); + att_out + head * kQKVDim, activations.even_odd.All(), head_out, pool); AddFrom(head_out, layer_out, kModelDim); } } @@ -392,13 +370,11 @@ HWY_NOINLINE void Activation(T* HWY_RESTRICT c1, T* HWY_RESTRICT c2, }); } -template -HWY_NOINLINE void FFW(Activations& activations, - size_t num_tokens, +template +HWY_NOINLINE void FFW(Activations& activations, size_t num_tokens, const CompressedLayer* layer_weights, hwy::ThreadPool& pool) { PROFILER_ZONE("Gen.FFW"); - HWY_DASSERT(num_tokens <= kBatchSize); constexpr size_t kModelDim = TConfig::kModelDim; constexpr size_t kFFHiddenDim = TConfig::kFFHiddenDim; @@ -406,11 +382,12 @@ HWY_NOINLINE void FFW(Activations& activations, // elements in memory, repeated kFFHiddenDim times. constexpr size_t kColsA = kModelDim; constexpr size_t kColsB = kFFHiddenDim; - const auto A = activations.bf_pre_ffw_rms_out.data(); + HWY_DASSERT(num_tokens <= activations.bf_pre_ffw_rms_out.BatchSize()); + const auto A = activations.bf_pre_ffw_rms_out.All(); const auto B1 = layer_weights->gating_einsum_w.data(); const auto B2 = B1 + kColsA * kColsB; - auto C1 = activations.C1.data(); - auto C2 = activations.C2.data(); + auto C1 = activations.C1.All(); + auto C2 = activations.C2.All(); constexpr bool kAddBias = TConfig::kFFBiases; const auto bias = layer_weights->ffw_gating_biases.data(); @@ -422,8 +399,7 @@ HWY_NOINLINE void FFW(Activations& activations, bias + kFFHiddenDim, pool); // Activation (Gelu) and multiply by gate. Store activations in C1. - Activation(activations.C1.data(), activations.C2.data(), - kFFHiddenDim * num_tokens); + Activation(C1, C2, kFFHiddenDim * num_tokens); // linear_w may have a scale value different from 1, apply that here. // We multiply all activations by the scale value to compensate for the @@ -434,123 +410,115 @@ HWY_NOINLINE void FFW(Activations& activations, // Hidden layer -> output layer. MatMul_4x4_Batch_Add( - num_tokens, C1, layer_weights->linear_w.data(), - activations.ffw_out.data(), layer_weights->ffw_output_biases.data(), - pool); + num_tokens, C1, layer_weights->linear_w.data(), activations.ffw_out.All(), + layer_weights->ffw_output_biases.data(), pool); } -template -HWY_NOINLINE void EmbedToken(int token, size_t token_idx, size_t pos, +// TODO: pass Activations.x instead of Activations. +template +HWY_NOINLINE void EmbedToken(int token, size_t batch_idx, size_t pos, const CompressedWeights& weights, - Activations& activations) { + Activations& activations) { constexpr size_t kModelDim = TConfig::kModelDim; GEMMA_CONSTEXPR_EMBSCALING const float kEmbScaling = EmbeddingScaling(); HWY_DASSERT(token >= 0); HWY_DASSERT(token < TConfig::kVocabSize); Decompress(weights.embedder_input_embedding, token * kModelDim, - activations.x.data() + token_idx * kModelDim, kModelDim); - MulByConst(kEmbScaling, activations.x.data() + token_idx * kModelDim, - kModelDim); + activations.x.Batch(batch_idx), kModelDim); + MulByConst(kEmbScaling, activations.x.Batch(batch_idx), kModelDim); if constexpr (TConfig::kAbsolutePE) { - AddAbsolutePositionalEmbeddings( - activations.x.data() + token_idx * kModelDim, kModelDim, - pos + token_idx); + AddAbsolutePositionalEmbeddings(activations.x.Batch(batch_idx), kModelDim, + pos + batch_idx); }; } -template +template HWY_NOINLINE void ResidualConnection( size_t num_tokens_and_queries, T* HWY_RESTRICT other, T* HWY_RESTRICT x, const CompressedLayer* layer_weights, bool is_attention) { constexpr size_t kModelDim = TConfig::kModelDim; // ResidualType::Add - AddFromBatched(num_tokens_and_queries, other, x, - kModelDim); + AddFromBatched(num_tokens_and_queries, other, x, kModelDim); } -template +template HWY_NOINLINE void TransformerLayer( size_t num_tokens, size_t num_queries, size_t pos, size_t layer, - const CompressedLayer* layer_weights, - Activations& activations, + const CompressedLayer* layer_weights, Activations& activations, const std::vector& kv_caches, hwy::ThreadPool& pool) { constexpr size_t kModelDim = TConfig::kModelDim; const size_t num_tokens_and_queries = num_tokens * num_queries; auto type = TConfig::kLayerConfig[layer]; size_t layer_of_type = NumLayersOfTypeBefore(TConfig::kLayerConfig, type, layer); - RMSNormBatched( - num_tokens_and_queries, activations.x.data(), - layer_weights->pre_attention_norm_scale.data(), - activations.pre_att_rms_out.data(), kModelDim); + RMSNormBatched(num_tokens_and_queries, activations.x.All(), + layer_weights->pre_attention_norm_scale.data(), + activations.pre_att_rms_out.All(), kModelDim); if (type == LayerAttentionType::kGemma) { - Attention( - pos, num_tokens, num_queries, layer_of_type, activations, - layer_weights, kv_caches, pool); + Attention(pos, num_tokens, num_queries, layer_of_type, activations, + layer_weights, kv_caches, pool); } else { // This Griffin layers should never exist unless the model is a Griffin // model. This conditional prevents the compiler from generating code for // this branch when building a non-Griffin model, since we have static // asserts about the query batch size for Griffin layers. if constexpr (TConfig::kGriffinLayers > 0) { - GriffinRecurrent( - pos, num_tokens, num_queries, layer_of_type, activations, - layer_weights, kv_caches, pool); + static_assert(kQueryBatchSize == 1, + "Griffin does not support batched queries."); + GriffinRecurrent(pos, num_tokens, num_queries, layer_of_type, + activations, layer_weights, kv_caches, pool); } } if (TConfig::kPostNorm == PostNormType::Scale) { - RMSNormInplaceBatched( - num_tokens_and_queries, - layer_weights->post_attention_norm_scale.data(), - activations.att_post2.data(), kModelDim); + RMSNormInplaceBatched(num_tokens_and_queries, + layer_weights->post_attention_norm_scale.data(), + activations.att_post2.All(), kModelDim); } - ResidualConnection( - num_tokens_and_queries, activations.att_post2.data(), - activations.x.data(), layer_weights, /*is_attention*/ true); - RMSNormBatched( - num_tokens_and_queries, activations.x.data(), - layer_weights->pre_ffw_norm_scale.data(), - activations.bf_pre_ffw_rms_out.data(), kModelDim); - FFW( - activations, num_tokens_and_queries, layer_weights, pool); + ResidualConnection(num_tokens_and_queries, + activations.att_post2.All(), activations.x.All(), + layer_weights, /*is_attention=*/true); + RMSNormBatched(num_tokens_and_queries, activations.x.All(), + layer_weights->pre_ffw_norm_scale.data(), + activations.bf_pre_ffw_rms_out.All(), kModelDim); + FFW(activations, num_tokens_and_queries, layer_weights, pool); if (TConfig::kPostNorm == PostNormType::Scale) { - RMSNormInplaceBatched( - num_tokens_and_queries, layer_weights->post_ffw_norm_scale.data(), - activations.ffw_out.data(), kModelDim); + RMSNormInplaceBatched(num_tokens_and_queries, + layer_weights->post_ffw_norm_scale.data(), + activations.ffw_out.All(), kModelDim); } - ResidualConnection( - num_tokens_and_queries, activations.ffw_out.data(), activations.x.data(), - layer_weights, /*is_attention*/ false); + ResidualConnection(num_tokens_and_queries, activations.ffw_out.All(), + activations.x.All(), layer_weights, + /*is_attention=*/false); } template -HWY_NOINLINE void Prefill( - const int* tokens, size_t num_tokens, size_t num_queries, size_t pos, - const CompressedWeights& weights, - Activations& activations, - const std::vector& kv_caches, hwy::ThreadPool& pool) { +HWY_NOINLINE void Prefill(const int* tokens, size_t num_tokens, + size_t num_queries, size_t pos, + const CompressedWeights& weights, + Activations& activations, + const std::vector& kv_caches, + hwy::ThreadPool& pool) { + PROFILER_ZONE("Gen.Prefill"); HWY_DASSERT(num_queries <= kQueryBatchSize); const size_t minibatch_size = std::min(num_tokens, kBatchSize); - PROFILER_ZONE("Gen.Prefill\\Att\\FFW"); - // TODO(patrickms): Try to hoist pool.Run out of the loop. + // TODO: hoist pool.Run out of the loop, change the unit of work to batches. for (size_t i = 0; i < num_tokens; i += minibatch_size) { const size_t offset = i * num_queries; const size_t current_token_count = std::min( minibatch_size, num_tokens - i); pool.Run(0, current_token_count * num_queries, - [&](const uint64_t token_idx, size_t /*thread*/) HWY_ATTR { - EmbedToken( - tokens[token_idx + offset], token_idx, pos + offset, - weights, activations); - }); + [&](const uint64_t token_idx, size_t /*thread*/) HWY_ATTR { + EmbedToken(tokens[token_idx + offset], token_idx, + pos + offset, weights, activations); + }); for (size_t layer = 0; layer < TConfig::kLayers; ++layer) { const auto* layer_weights = weights.GetLayer(layer); - TransformerLayer( - current_token_count, num_queries, pos + offset , layer, layer_weights, + TransformerLayer( + current_token_count, num_queries, pos + offset, layer, layer_weights, activations, kv_caches, pool); } } @@ -558,15 +526,14 @@ HWY_NOINLINE void Prefill( // Compute the transformer for a batch of input tokens. During generation, // we usually have num_tokens == 1 (and also kBatchSize == 1). -template -HWY_NOINLINE void Transformer( - const int* tokens, size_t num_tokens, size_t num_queries, size_t pos, - const CompressedWeights& weights, - Activations& activations, - const std::vector& kv_caches, - hwy::ThreadPool& pool, - const LayersOutputFunc& layers_output) { - HWY_ASSERT(num_tokens <= kBatchSize); +template +HWY_NOINLINE void Transformer(const int* tokens, size_t num_tokens, + size_t num_queries, size_t pos, + const CompressedWeights& weights, + Activations& activations, + const std::vector& kv_caches, + hwy::ThreadPool& pool, + const LayersOutputFunc& layers_output) { const size_t num_tokens_and_queries = num_tokens * num_queries; if (layers_output) { for (size_t token_idx = 0; token_idx < num_tokens_and_queries; @@ -577,34 +544,33 @@ HWY_NOINLINE void Transformer( } constexpr size_t kModelDim = TConfig::kModelDim; for (size_t token_idx = 0; token_idx < num_tokens_and_queries; ++token_idx) { - EmbedToken( - tokens[token_idx], token_idx, pos, weights, activations); + EmbedToken(tokens[token_idx], token_idx, pos, weights, + activations); } for (size_t layer = 0; layer < TConfig::kLayers; ++layer) { const CompressedLayer* layer_weights = weights.GetLayer(layer); - TransformerLayer( - num_tokens, num_queries, pos, layer, layer_weights, - activations, kv_caches, pool); + TransformerLayer(num_tokens, num_queries, pos, + layer, layer_weights, + activations, kv_caches, pool); if (layers_output) { const std::string block_name = "blocks." + std::to_string(layer); for (size_t token_idx = 0; token_idx < num_tokens_and_queries; ++token_idx) { layers_output(pos + token_idx, block_name, - activations.x.data() + token_idx * kModelDim, kModelDim); + activations.x.Batch(token_idx), kModelDim); } } } - RMSNormInplaceBatched( - num_tokens * num_queries, weights.final_norm_scale.data(), - activations.x.data(), kModelDim); + RMSNormInplaceBatched(num_tokens_and_queries, weights.final_norm_scale.data(), + activations.x.All(), kModelDim); if (layers_output) { for (size_t token_idx = 0; token_idx < num_tokens_and_queries; ++token_idx) { layers_output(pos + token_idx, "final_norm", - activations.x.data() + token_idx * kModelDim, kModelDim); + activations.x.Batch(token_idx), kModelDim); } } } @@ -644,9 +610,8 @@ void RangeChecks(size_t& max_tokens, size_t& max_generated_tokens, // Placeholder for internal test3, do not remove template -void GenerateT(const ByteStorageT& weights_u8, const ByteStorageT& prefill_u8, - const ByteStorageT& decode_u8, - const RuntimeConfig& runtime_config, +void GenerateT(const ByteStorageT& weights_u8, Activations& prefill, + Activations& activations, const RuntimeConfig& runtime_config, const hwy::Span>& prompts, size_t pos, const size_t query_index_offset, const std::vector& kv_caches, hwy::ThreadPool& pool, @@ -659,10 +624,6 @@ void GenerateT(const ByteStorageT& weights_u8, const ByteStorageT& prefill_u8, pos *= num_queries; // position in (num_queries) interleaved token sequence. const CompressedWeights& weights = *reinterpret_cast*>(weights_u8.get()); - auto& prefill_activations = - GetActivations(prefill_u8); - auto& activations = GetActivations(decode_u8); size_t min_prompt_size = (size_t)-1; size_t max_prompt_size = 0; @@ -735,8 +696,8 @@ void GenerateT(const ByteStorageT& weights_u8, const ByteStorageT& prefill_u8, HWY_DASSERT(pos_offset + batch_size <= (min_prompt_size - 1) * num_queries); const int* batch_tokens = prompt.data() + pos_offset; Prefill( - batch_tokens, batch_size, num_queries, pos, weights, - prefill_activations, kv_caches, pool); + batch_tokens, batch_size, num_queries, pos, weights, prefill, kv_caches, + pool); for (size_t idx = 0; idx < batch_size; ++idx) { bool all_tokens_eos = true; for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) { @@ -788,7 +749,7 @@ void GenerateT(const ByteStorageT& weights_u8, const ByteStorageT& prefill_u8, for (size_t generate_pos = 0; generate_pos < max_tokens && generate_pos < max_generated_tokens; ++single_prompt_pos_offset, ++generate_pos) { - Transformer( + Transformer( gen_tokens.data(), kDecodeBatchSize, num_queries, pos, weights, activations, kv_caches, pool, runtime_config.layers_output); float token_logit = 0.0f; @@ -796,22 +757,20 @@ void GenerateT(const ByteStorageT& weights_u8, const ByteStorageT& prefill_u8, // We keep it here for clarity so that the code is correct even if Prefill // is disabled. bool all_tokens_eos = true; - float* x = activations.x.data(); - float* logits = activations.logits.data(); - for (size_t i = 0; i < num_queries; ++i, ++pos, ++pos_offset, - x += TConfig::kModelDim, logits += kVocabSize) { + for (size_t i = 0; i < num_queries; ++i, ++pos, ++pos_offset) { + const float* HWY_RESTRICT x = activations.x.Batch(i); + float* HWY_RESTRICT logits = activations.logits.Batch(i); const size_t prompt_size = prompts[i].size(); const bool is_generating_phase = (single_prompt_pos_offset >= prompt_size - 1); if (is_generating_phase) { PROFILER_ZONE("Gen.Embedding"); // Compute logits from last layer activations. - MatVec( - weights.embedder_input_embedding, 0, x, activations.even_odd.data(), - logits, pool); + MatVec(weights.embedder_input_embedding, + 0, x, activations.even_odd.All(), + logits, pool); if constexpr (TConfig::kFinalCap > 0.0f) { - LogitsSoftCap(TConfig::kFinalCap, activations.logits.data(), - kVocabSize); + LogitsSoftCap(TConfig::kFinalCap, logits, kVocabSize); } // Barrier: must have all logits so we can subtract max. Softmax(logits, kVocabSize); @@ -850,9 +809,8 @@ void GenerateT(const ByteStorageT& weights_u8, const ByteStorageT& prefill_u8, } template -void GenerateSingleT(const ByteStorageT& weights_u8, - const ByteStorageT& prefill_u8, - const ByteStorageT& decode_u8, +void GenerateSingleT(const ByteStorageT& weights_u8, Activations& prefill, + Activations& activations, const RuntimeConfig& runtime_config, const std::vector& prompt, size_t pos, KVCache& kv_cache, hwy::ThreadPool& pool, @@ -865,19 +823,17 @@ void GenerateSingleT(const ByteStorageT& weights_u8, std::vector kv_caches = {&kv_cache}; const size_t query_index_offset = 0; GenerateT( - weights_u8, prefill_u8, decode_u8, runtime_config, prompts, pos, + weights_u8, prefill, activations, runtime_config, prompts, pos, query_index_offset, kv_caches, pool, timing_info); } template -void GenerateBatchT(const ByteStorageT& weights_u8, - const ByteStorageT& prefill_u8, - const ByteStorageT& decode_u8, +void GenerateBatchT(const ByteStorageT& weights_u8, Activations& prefill, + Activations& activations, const RuntimeConfig& runtime_config, - const hwy::Span>& prompts, - size_t pos, const std::vector& kv_caches, - hwy::ThreadPool& pool, - TimingInfo& timing_info) { + const hwy::Span>& prompts, size_t pos, + const std::vector& kv_caches, + hwy::ThreadPool& pool, TimingInfo& timing_info) { // Disable query batching for Griffin models. constexpr size_t kQueryBatchSize = (TConfig::kGriffinLayers > 0) ? 1 : kBatchedQueryBatchSize; @@ -885,9 +841,9 @@ void GenerateBatchT(const ByteStorageT& weights_u8, const size_t num_queries = std::min(prompts.size() - i, kQueryBatchSize); const hwy::Span> current_prompts( prompts.data() + i, num_queries); - GenerateT(weights_u8, prefill_u8, decode_u8, - runtime_config, current_prompts, - pos, i, kv_caches, pool, timing_info); + GenerateT(weights_u8, prefill, activations, + runtime_config, current_prompts, pos, i, + kv_caches, pool, timing_info); } } @@ -898,25 +854,23 @@ void GenerateBatchT(const ByteStorageT& weights_u8, // These are extern functions defined by instantiations/*.cc, which include this // 'header' after defining GEMMA_CONFIG, which is for function overloading. void GenerateSingle( // NOLINT(misc-definitions-in-headers) - GEMMA_CONFIG, const ByteStorageT& weights_u8, - const ByteStorageT& prefill_u8, const ByteStorageT& decode_u8, - const RuntimeConfig& runtime_config, const std::vector& prompt, - size_t pos, KVCache& kv_cache, hwy::ThreadPool& pool, - TimingInfo& timing_info) { + GEMMA_CONFIG, const ByteStorageT& weights_u8, Activations& prefill, + Activations& activations, const RuntimeConfig& runtime_config, + const std::vector& prompt, size_t pos, KVCache& kv_cache, + hwy::ThreadPool& pool, TimingInfo& timing_info) { HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(GenerateSingleT) - (weights_u8, prefill_u8, decode_u8, runtime_config, prompt, pos, kv_cache, + (weights_u8, prefill, activations, runtime_config, prompt, pos, kv_cache, pool, timing_info); } void GenerateBatch( // NOLINT(misc-definitions-in-headers) - GEMMA_CONFIG, const ByteStorageT& weights_u8, - const ByteStorageT& prefill_u8, const ByteStorageT& decode_u8, - const RuntimeConfig& runtime_config, + GEMMA_CONFIG, const ByteStorageT& weights_u8, Activations& prefill, + Activations& activations, const RuntimeConfig& runtime_config, const hwy::Span>& prompts, size_t pos, const std::vector& kv_caches, hwy::ThreadPool& pool, TimingInfo& timing_info) { HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(GenerateBatchT) - (weights_u8, prefill_u8, decode_u8, runtime_config, prompts, pos, kv_caches, + (weights_u8, prefill, activations, runtime_config, prompts, pos, kv_caches, pool, timing_info); } diff --git a/gemma/gemma.cc b/gemma/gemma.cc index e1aaf7b..6df0a15 100644 --- a/gemma/gemma.cc +++ b/gemma/gemma.cc @@ -36,12 +36,23 @@ namespace gcpp { +template +struct AllocateState { + void operator()(Activations& prefill, Activations& decode) const { + // When batching queries, the prefill batch size is reduced by a factor + // of kBatchedQueryBatchSize + prefill.Allocate(kMinAdjustedPrefillBatchSize * + kBatchedQueryBatchSize); + decode.Allocate(kDecodeBatchSize * kBatchedQueryBatchSize); + } +}; + Gemma::Gemma(const Path& tokenizer_path, const Path& weights, const ModelInfo& info, hwy::ThreadPool& pool) : pool_(pool), tokenizer_(tokenizer_path), info_(info) { weights_u8_ = LoadCompressedWeights(weights, info.model, info.weight, pool); - CallForModelAndWeight(info.model, info.weight, prefill_u8_, - decode_u8_); + CallForModelAndWeight(info.model, info.weight, prefill_, + decode_); } Gemma::Gemma(GemmaTokenizer&& tokenizer, const ModelInfo& info, @@ -50,8 +61,8 @@ Gemma::Gemma(GemmaTokenizer&& tokenizer, const ModelInfo& info, HWY_ASSERT(info.weight == Type::kF32); weights_u8_ = CallForModel(info.model, pool); - CallForModelAndWeight(info.model, info.weight, prefill_u8_, - decode_u8_); + CallForModelAndWeight(info.model, info.weight, prefill_, + decode_); } Gemma::~Gemma() { @@ -63,19 +74,17 @@ Gemma::~Gemma() { // we shard them across multiple translation units in instantiations/*.cc. // This declares the functions defined there. We use overloading because // explicit instantiations are still too slow to compile. -#define GEMMA_DECLARE(CONFIGT, TWEIGHT) \ - extern void GenerateSingle( \ - CONFIGT, const ByteStorageT& weights_u8, \ - const ByteStorageT& prefill_u8, const ByteStorageT& decode_u8, \ - const RuntimeConfig& runtime_config, const std::vector& prompt, \ - size_t pos, KVCache& kv_cache, hwy::ThreadPool& pool, \ - TimingInfo& timing_info); \ - extern void GenerateBatch( \ - CONFIGT, const ByteStorageT& weights_u8, \ - const ByteStorageT& prefill_u8, const ByteStorageT& decode_u8, \ - const RuntimeConfig& runtime_config, \ - const hwy::Span>& prompts, size_t pos, \ - const std::vector& kv_caches, hwy::ThreadPool& pool, \ +#define GEMMA_DECLARE(CONFIGT, TWEIGHT) \ + extern void GenerateSingle( \ + CONFIGT, const ByteStorageT& weights_u8, Activations& prefill, \ + Activations& decode, const RuntimeConfig& runtime_config, \ + const std::vector& prompt, size_t pos, KVCache& kv_cache, \ + hwy::ThreadPool& pool, TimingInfo& timing_info); \ + extern void GenerateBatch( \ + CONFIGT, const ByteStorageT& weights_u8, Activations& prefill, \ + Activations& decode, const RuntimeConfig& runtime_config, \ + const hwy::Span>& prompts, size_t pos, \ + const std::vector& kv_caches, hwy::ThreadPool& pool, \ TimingInfo& timing_info); GEMMA_FOREACH_CONFIG_AND_WEIGHT(GEMMA_DECLARE); @@ -83,25 +92,23 @@ GEMMA_FOREACH_CONFIG_AND_WEIGHT(GEMMA_DECLARE); // TODO: gather all ByteStorageT into a type-erased model struct? template struct GenerateSingleT { - void operator()(const ByteStorageT& weights_u8, - const ByteStorageT& prefill_u8, const ByteStorageT& decode_u8, - const RuntimeConfig& runtime_config, + void operator()(const ByteStorageT& weights_u8, Activations& prefill, + Activations& decode, const RuntimeConfig& runtime_config, const std::vector& prompt, size_t pos, KVCache& kv_cache, hwy::ThreadPool& pool, TimingInfo& timing_info) const { - GenerateSingle(TConfig(), weights_u8, prefill_u8, decode_u8, runtime_config, + GenerateSingle(TConfig(), weights_u8, prefill, decode, runtime_config, prompt, pos, kv_cache, pool, timing_info); } }; template struct GenerateBatchT { - void operator()(const ByteStorageT& weights_u8, - const ByteStorageT& prefill_u8, const ByteStorageT& decode_u8, - const RuntimeConfig& runtime_config, + void operator()(const ByteStorageT& weights_u8, Activations& prefill, + Activations& decode, const RuntimeConfig& runtime_config, const hwy::Span>& prompts, size_t pos, const std::vector& kv_caches, hwy::ThreadPool& pool, TimingInfo& timing_info) const { - GenerateBatch(TConfig(), weights_u8, prefill_u8, decode_u8, runtime_config, + GenerateBatch(TConfig(), weights_u8, prefill, decode, runtime_config, prompts, pos, kv_caches, pool, timing_info); } }; @@ -112,8 +119,8 @@ void Gemma::Generate(const RuntimeConfig& runtime_config, pool_.SetWaitMode(hwy::PoolWaitMode::kSpin); CallForModelAndWeight( - info_.model, info_.weight, weights_u8_, prefill_u8_, decode_u8_, - runtime_config, prompt, start_pos, kv_cache, pool_, timing_info); + info_.model, info_.weight, weights_u8_, prefill_, decode_, runtime_config, + prompt, start_pos, kv_cache, pool_, timing_info); pool_.SetWaitMode(hwy::PoolWaitMode::kBlock); } @@ -126,8 +133,8 @@ void Gemma::GenerateBatch(const RuntimeConfig& runtime_config, pool_.SetWaitMode(hwy::PoolWaitMode::kSpin); CallForModelAndWeight( - info_.model, info_.weight, weights_u8_, prefill_u8_, decode_u8_, - runtime_config, prompts, start_pos, kv_caches, pool_, timing_info); + info_.model, info_.weight, weights_u8_, prefill_, decode_, runtime_config, + prompts, start_pos, kv_caches, pool_, timing_info); pool_.SetWaitMode(hwy::PoolWaitMode::kBlock); } diff --git a/gemma/gemma.h b/gemma/gemma.h index f0091e3..477bcf8 100644 --- a/gemma/gemma.h +++ b/gemma/gemma.h @@ -23,6 +23,7 @@ // IWYU pragma: begin_exports #include "compression/io.h" // Path +#include "gemma/activations.h" #include "gemma/common.h" #include "gemma/kv_cache.h" #include "gemma/tokenizer.h" @@ -95,8 +96,8 @@ class Gemma { const ModelInfo& Info() const { return info_; } const GemmaTokenizer& Tokenizer() const { return tokenizer_; } const ByteStorageT& Weights() const { return weights_u8_; } - const ByteStorageT& Prefill() const { return prefill_u8_; } - const ByteStorageT& Decode() const { return decode_u8_; } + const Activations& Prefill() const { return prefill_; } + const Activations& Decode() const { return decode_; } void Generate(const RuntimeConfig& runtime_config, const std::vector& prompt, size_t start_pos, @@ -114,8 +115,8 @@ class Gemma { // Type-erased so that this can be defined in the header, without requiring // forwarding functions. ByteStorageT weights_u8_; - ByteStorageT prefill_u8_; - ByteStorageT decode_u8_; + Activations prefill_; + Activations decode_; ModelInfo info_; }; diff --git a/gemma/ops.h b/gemma/ops.h index ef82f64..51dcc56 100644 --- a/gemma/ops.h +++ b/gemma/ops.h @@ -1140,29 +1140,26 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void AddFrom( } // Simple loops unless/until batch sizes are large enough to parallelize. -template +template void RMSNormBatched(size_t num_tokens, const float* activations, const WeightT* weights, OutT* out, const size_t model_dim) { - HWY_DASSERT(num_tokens <= kBatchSize); for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) { RMSNorm(activations + token_idx * model_dim, weights, out + token_idx * model_dim, model_dim); } } -template +// TODO: pass RowVectorBatch argument. +template void RMSNormInplaceBatched(size_t num_tokens, const WeightT* weights, InOutT* inout, const size_t model_dim) { - HWY_DASSERT(num_tokens <= kBatchSize); for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) { RMSNormInplace(weights, inout + token_idx * model_dim, model_dim); } } -template -void AddFromBatched(size_t num_tokens, const float* other, float* x, - const size_t model_dim) { - HWY_DASSERT(num_tokens <= kBatchSize); +static HWY_INLINE void AddFromBatched(size_t num_tokens, const float* other, + float* x, const size_t model_dim) { for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) { AddFrom(other + token_idx * model_dim, x + token_idx * model_dim, model_dim);