From 6773e4517c9a0053fcfc4e75a607f1532dc768b1 Mon Sep 17 00:00:00 2001 From: Jan Wassenberg Date: Mon, 16 Jun 2025 07:52:27 -0700 Subject: [PATCH] Split Activations into Griffin/Attention to reduce memory usage for attention-only tests. PiperOrigin-RevId: 772025282 --- BUILD.bazel | 1 + gemma/activations.h | 190 +++++++++++++++++++++++++------------------- gemma/attention.cc | 23 +++--- gemma/attention.h | 11 +-- gemma/gemma.cc | 13 +-- gemma/griffin.cc | 25 +++--- gemma/vit.cc | 49 ++++++------ ops/ops_test.cc | 3 +- 8 files changed, 174 insertions(+), 141 deletions(-) diff --git a/BUILD.bazel b/BUILD.bazel index 3b894be..3c6ec5d 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -360,6 +360,7 @@ cc_test( deps = [ ":allocator", ":basics", + ":configs", ":gemma_lib", ":mat", ":ops", diff --git a/gemma/activations.h b/gemma/activations.h index 56c799b..6f14900 100644 --- a/gemma/activations.h +++ b/gemma/activations.h @@ -31,25 +31,46 @@ namespace gcpp { -// Returns the scale value to use for the query in the attention computation. -// Also called by ops_test. -static inline float ChooseQueryScale(const ModelConfig& config) { - if (config.query_scale == QueryScaleType::SqrtModelDimDivNumHeads) - return 1.0f / sqrtf(static_cast(config.model_dim / - config.layer_configs[0].heads)); - // QueryScaleType::SqrtKeySize - return 1.0f / sqrtf(static_cast(config.layer_configs[0].qkv_dim)); -} +struct GriffinActivations { + GriffinActivations(const ModelConfig& config, size_t batch_size, + MatPadding pad) + : griffin_x("griffin_x", Extents2D(batch_size, config.model_dim), pad), + griffin_y("griffin_y", Extents2D(batch_size, config.model_dim), pad), + griffin_gate_x("griffin_gate_x", + Extents2D(batch_size, config.model_dim), pad), + griffin_multiplier("griffin_mul", + Extents2D(batch_size, config.model_dim), pad) {} -struct Activations { - Activations(const ModelConfig& config, size_t batch_size, size_t seq_len, - std::vector>& row_ptrs) - : weights_config(config), - layer_config(config.layer_configs[0]), - div_seq_len(static_cast(seq_len)), - is_griffin(config.model == Model::GRIFFIN_2B), + void SetBatchSize(size_t batch_size) { + griffin_x.OverrideRows(batch_size); + griffin_y.OverrideRows(batch_size); + griffin_gate_x.OverrideRows(batch_size); + griffin_multiplier.OverrideRows(batch_size); + } + + MatStorageT griffin_x; + MatStorageT griffin_y; + MatStorageT griffin_gate_x; + MatStorageT griffin_multiplier; +}; + +struct AttentionActivations { + // Returns the scale value to use for the query in the attention computation. + // Also called by ops_test. + static inline float ChooseQueryScale(const ModelConfig& config) { + if (config.query_scale == QueryScaleType::SqrtModelDimDivNumHeads) + return 1.0f / sqrtf(static_cast(config.model_dim / + config.layer_configs[0].heads)); + // QueryScaleType::SqrtKeySize + return 1.0f / sqrtf(static_cast(config.layer_configs[0].qkv_dim)); + } + + AttentionActivations( + const ModelConfig& config, const LayerConfig& layer_config, + size_t batch_size, size_t seq_len, MatPadding pad, + std::vector>& row_ptrs) + : config(config), - x("x", Extents2D(batch_size, config.model_dim), pad_), // `vocab_size == 0` means it is for Vit part, VitAttention is still MHA // and does not use an external KV cache. q("q", @@ -57,36 +78,16 @@ struct Activations { config.vocab_size == 0 ? 
layer_config.heads * 3 * layer_config.qkv_dim : layer_config.heads * layer_config.qkv_dim), - pad_), - logits("logits", Extents2D(batch_size, config.vocab_size), pad_), + pad), pre_att_rms_out("pre_att_rms_out", - Extents2D(batch_size, config.model_dim), pad_), - att("att", Extents2D(batch_size, layer_config.heads * seq_len), pad_), + Extents2D(batch_size, config.model_dim), pad), + att("att", Extents2D(batch_size, layer_config.heads * seq_len), pad), att_out( "att_out", Extents2D(batch_size, layer_config.heads * layer_config.qkv_dim), - pad_), - att_sums("att_sums", Extents2D(batch_size, config.model_dim), pad_), - - pre_ffw_rms_out("pre_ffw_rms_out", - Extents2D(batch_size, config.model_dim), pad_), - C1("C1", Extents2D(batch_size, layer_config.ff_hidden_dim), pad_), - C2("C2", Extents2D(batch_size, layer_config.ff_hidden_dim), pad_), - ffw_out("ffw_out", Extents2D(batch_size, config.model_dim), pad_), - - griffin_x("griffin_x", - is_griffin ? Extents2D(batch_size, config.model_dim) : none_, - pad_), - griffin_y("griffin_y", - is_griffin ? Extents2D(batch_size, config.model_dim) : none_, - pad_), - griffin_gate_x( - "griffin_gate_x", - is_griffin ? Extents2D(batch_size, config.model_dim) : none_, pad_), - griffin_multiplier( - "griffin_mul", - is_griffin ? Extents2D(batch_size, config.model_dim) : none_, pad_), + pad), + att_sums("att_sums", Extents2D(batch_size, config.model_dim), pad), inv_timescale( CreateInvTimescale(layer_config.qkv_dim, @@ -95,16 +96,73 @@ struct Activations { layer_config.qkv_dim, layer_config.post_qk == PostQKType::HalfRope, 1000000.0)), + div_seq_len(static_cast(seq_len)), query_scale(ChooseQueryScale(config)) { HWY_ASSERT(batch_size != 0); // For MatMul outputs, precompute their row pointers. // If we forget any MatMul outputs here, debug builds print a warning but // fill them in each MatMul call. 
- x.AllocateAndAttachRowPtrs(row_ptrs); q.AllocateAndAttachRowPtrs(row_ptrs); - logits.AllocateAndAttachRowPtrs(row_ptrs); att_sums.AllocateAndAttachRowPtrs(row_ptrs); + } + + void SetBatchSize(size_t batch_size) { + q.OverrideRows(batch_size); + + pre_att_rms_out.OverrideRows(batch_size); + att.OverrideRows(batch_size); + att_out.OverrideRows(batch_size); + att_sums.OverrideRows(batch_size); + } + + bool IsGlobalLayer(size_t layer_idx) const { + return config.attention_window_sizes[layer_idx] == div_seq_len.GetDivisor(); + } + + const ModelConfig& config; + + MatStorageT q; // query + + MatStorageT pre_att_rms_out; + MatStorageT att; // attention vector + MatStorageT att_out; // attention output + // Accumulation of attention outputs over heads + MatStorageT att_sums; + + // Rope + MatStorageT inv_timescale; + MatStorageT inv_timescale_global; + + hwy::Divisor div_seq_len; + float query_scale; +}; + +struct Activations { + Activations(const ModelConfig& config, size_t batch_size, size_t seq_len, + std::vector>& row_ptrs) + : layer_config(config.layer_configs[0]), + + x("x", Extents2D(batch_size, config.model_dim), pad_), + logits("logits", Extents2D(batch_size, config.vocab_size), pad_), + + pre_ffw_rms_out("pre_ffw_rms_out", + Extents2D(batch_size, config.model_dim), pad_), + C1("C1", Extents2D(batch_size, layer_config.ff_hidden_dim), pad_), + C2("C2", Extents2D(batch_size, layer_config.ff_hidden_dim), pad_), + ffw_out("ffw_out", Extents2D(batch_size, config.model_dim), pad_), + + attention(config, layer_config, batch_size, seq_len, pad_, row_ptrs) { + HWY_ASSERT(batch_size != 0); + if (config.model == Model::GRIFFIN_2B) { + griffin = std::make_unique(config, batch_size, pad_); + } + + // For MatMul outputs, precompute their row pointers. + // If we forget any MatMul outputs here, debug builds print a warning but + // fill them in each MatMul call. 
+ x.AllocateAndAttachRowPtrs(row_ptrs); + logits.AllocateAndAttachRowPtrs(row_ptrs); C1.AllocateAndAttachRowPtrs(row_ptrs); C2.AllocateAndAttachRowPtrs(row_ptrs); ffw_out.AllocateAndAttachRowPtrs(row_ptrs); @@ -115,67 +173,35 @@ struct Activations { void SetBatchSize(size_t batch_size) { PROFILER_ZONE("SetBatchSize"); x.OverrideRows(batch_size); - q.OverrideRows(batch_size); logits.OverrideRows(batch_size); - pre_att_rms_out.OverrideRows(batch_size); - att.OverrideRows(batch_size); - att_out.OverrideRows(batch_size); - att_sums.OverrideRows(batch_size); - pre_ffw_rms_out.OverrideRows(batch_size); C1.OverrideRows(batch_size); C2.OverrideRows(batch_size); ffw_out.OverrideRows(batch_size); - if (is_griffin) { - griffin_x.OverrideRows(batch_size); - griffin_y.OverrideRows(batch_size); - griffin_gate_x.OverrideRows(batch_size); - griffin_multiplier.OverrideRows(batch_size); + attention.SetBatchSize(batch_size); + + if (griffin) { + griffin->SetBatchSize(batch_size); } } - bool IsGlobalLayer(size_t layer_idx) const { - return weights_config.attention_window_sizes[layer_idx] == - div_seq_len.GetDivisor(); - } - - const ModelConfig& weights_config; const LayerConfig& layer_config; - hwy::Divisor div_seq_len; - bool is_griffin; const Extents2D none_ = Extents2D(); const MatPadding pad_ = MatPadding::kOdd; MatStorageT x; // input - MatStorageT q; // query MatStorageT logits; - // Attention - MatStorageT pre_att_rms_out; - MatStorageT att; // attention vector - MatStorageT att_out; // attention output - // Accumulation of attention outputs over heads - MatStorageT att_sums; - // Gated FFW MatStorageT pre_ffw_rms_out; MatStorageT C1; // TODO: BF16 after Activation() supports it MatStorageT C2; MatStorageT ffw_out; - // Griffin - MatStorageT griffin_x; - MatStorageT griffin_y; - MatStorageT griffin_gate_x; - MatStorageT griffin_multiplier; - - // Rope - MatStorageT inv_timescale; - MatStorageT inv_timescale_global; - - float query_scale; + AttentionActivations attention; + std::unique_ptr griffin; }; } // namespace gcpp diff --git a/gemma/attention.cc b/gemma/attention.cc index 87793b8..6a6734a 100644 --- a/gemma/attention.cc +++ b/gemma/attention.cc @@ -20,7 +20,6 @@ #include "gemma/activations.h" #include "gemma/gemma.h" -#include "gemma/gemma_args.h" #include "gemma/weights.h" #include "util/threading.h" #include "hwy/contrib/thread_pool/thread_pool.h" @@ -66,14 +65,14 @@ template static void PositionalEncodingQK(U* qk, const size_t qkv_dim, const size_t layer_idx, const LayerWeightsPtrs& layer, - const Activations& activations, + const AttentionActivations& activations, const size_t pos, const float mul = 1.0f) { const PostQKType& post_qk = layer.layer_config.post_qk; // qk is either q or k, so qkv_dim is the length we operate on. const float* inv_timescale = activations.inv_timescale.PackedScale1(); bool is_global_layer = activations.IsGlobalLayer(layer_idx); // TODO: add a config flag instead of hardcoding the model. 
- if (is_global_layer && IsVLM(activations.weights_config.model)) { + if (is_global_layer && IsVLM(activations.config.model)) { inv_timescale = activations.inv_timescale_global.PackedScale1(); } // PostQKType::Rope @@ -118,10 +117,10 @@ void SingleDotSoftmaxWeightedSum( const size_t pos, const size_t start_pos, const size_t last_pos, float* HWY_RESTRICT q, const MatPtrT& k, const MatPtrT& v, const size_t layer_idx, const LayerWeightsPtrs& layer, - const Activations& activations, float* HWY_RESTRICT att, + const AttentionActivations& activations, float* HWY_RESTRICT att, float* HWY_RESTRICT att_out) { const size_t qkv_dim = layer.layer_config.qkv_dim; - const float att_cap = activations.weights_config.att_cap; + const float att_cap = activations.config.att_cap; const float query_scale = activations.query_scale; const size_t seq_len = static_cast(activations.div_seq_len.GetDivisor()); @@ -155,7 +154,7 @@ static HWY_INLINE size_t StartPos(size_t pos, const ModelConfig& config, void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx, const LayerWeightsPtrs& layer, - Activations& activations, QBatch& qbatch, + AttentionActivations& activations, QBatch& qbatch, NestedPools& pools) { PROFILER_ZONE("Gen.Attention.DotSoftmax"); const hwy::Divisor div_qbatch(qbatch.Size()); @@ -190,7 +189,7 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx, // the range of cache positions to attend to. const size_t pos = qbatch.Pos(qi) + batch_idx; const size_t start_pos = - StartPos(pos, activations.weights_config, layer_idx); + StartPos(pos, activations.config, layer_idx); size_t last_pos = pos; const size_t prefix_end = qbatch.PrefixEnd(qi); if (prefix_end > 0 && prefix_end - 1 > last_pos) { @@ -241,7 +240,7 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx, // Fills activations.q and writes to KV cache. static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx, const LayerWeightsPtrs& layer, - Activations& activations, + AttentionActivations& activations, const QBatch& qbatch, const int flags, MatMulEnv& env) { PROFILER_ZONE("Gen.Attention.QKV"); @@ -306,7 +305,8 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx, // Sums encoded (`att_out`) over num_heads (`layer_config.heads`) and // head_dim (`qkv_dim`) into output (`layer_out`). static HWY_INLINE void SumHeads(const LayerWeightsPtrs& layer, - Activations& activations, MatMulEnv& env) { + AttentionActivations& activations, + MatMulEnv& env) { PROFILER_ZONE("Gen.Attention.SumHeads"); const LayerConfig& layer_config = layer.layer_config; // att_weights and att_out are concatenated heads, each of length @@ -324,8 +324,9 @@ static HWY_INLINE void SumHeads(const LayerWeightsPtrs& layer, } void GemmaAttention(size_t num_tokens, const size_t layer_idx, - const LayerWeightsPtrs& layer, Activations& activations, - QBatch& qbatch, MatMulEnv& env, int flags) { + const LayerWeightsPtrs& layer, + AttentionActivations& activations, QBatch& qbatch, + MatMulEnv& env, int flags) { const LayerConfig& layer_config = layer.layer_config; HWY_DASSERT(!layer_config.IsMHA()); // No longer supported. 
HWY_DASSERT_M((layer_config.heads % layer_config.kv_heads) == 0, diff --git a/gemma/attention.h b/gemma/attention.h index 12809e9..589cdb1 100644 --- a/gemma/attention.h +++ b/gemma/attention.h @@ -32,17 +32,18 @@ namespace gcpp { const size_t pos, const size_t start_pos, const size_t last_pos, \ float* HWY_RESTRICT q, const MatPtrT& k, const MatPtrT& v, \ size_t layer_idx, const LayerWeightsPtrs& layer, \ - const Activations& activations, float* HWY_RESTRICT att, \ + const AttentionActivations& activations, float* HWY_RESTRICT att, \ float* HWY_RESTRICT att_out); \ \ void DotSoftmaxWeightedSum(const size_t num_tokens, size_t layer_idx, \ const LayerWeightsPtrs& layer, \ - Activations& activations, QBatch& qbatch, \ - NestedPools& pools); \ + AttentionActivations& activations, \ + QBatch& qbatch, NestedPools& pools); \ \ void GemmaAttention(size_t num_tokens, const size_t layer_idx, \ - const LayerWeightsPtrs& layer, Activations& activations, \ - QBatch& qbatch, MatMulEnv& env, int flags); \ + const LayerWeightsPtrs& layer, \ + AttentionActivations& activations, QBatch& qbatch, \ + MatMulEnv& env, int flags); \ /* NOLINTNEXTLINE(google-readability-namespace-comments) */ \ } // namespace NAMESPACE diff --git a/gemma/gemma.cc b/gemma/gemma.cc index 87e241f..d635dbb 100644 --- a/gemma/gemma.cc +++ b/gemma/gemma.cc @@ -65,14 +65,15 @@ void Attention(LayerAttentionType type, const size_t num_tokens, const size_t layer_idx, const LayerWeightsPtrs& layer, Activations& activations, QBatch& qbatch, MatMulEnv& env) { if (type == LayerAttentionType::kGemma) { - GemmaAttention(num_tokens, layer_idx, layer, activations, qbatch, env, + GemmaAttention(num_tokens, layer_idx, layer, activations.attention, qbatch, + env, /*flags=*/0); } else { HWY_DASSERT(type == LayerAttentionType::kGriffinRecurrentBlock); // KVCache conv1d_cache and rglru_cache have one row per *Griffin* layer, // so map `layer` to the Griffin layer index. const size_t griffin_layer = - activations.weights_config.NumLayersOfTypeBefore(type, layer_idx); + activations.attention.config.NumLayersOfTypeBefore(type, layer_idx); GriffinRecurrent(num_tokens, griffin_layer, &layer, activations, qbatch, env); } @@ -86,15 +87,15 @@ static HWY_NOINLINE void TransformerLayer(const size_t num_tokens, const LayerConfig& layer_config = layer.layer_config; RMSNormBatched(activations.x, layer.pre_attention_norm_scale, - activations.pre_att_rms_out); + activations.attention.pre_att_rms_out); Attention(layer_config.type, num_tokens, layer_idx, layer, activations, qbatch, env); PostNorm(layer_config.post_norm, layer.post_attention_norm_scale, - activations.att_sums); + activations.attention.att_sums); - ResidualConnection(activations.att_sums, activations.x, layer, + ResidualConnection(activations.attention.att_sums, activations.x, layer, /*is_attention=*/true); RMSNormBatched(activations.x, layer.pre_ffw_norm_scale, @@ -470,7 +471,7 @@ static void GenerateT(const ModelConfig& config, HWY_ASSERT(qbatch.KV(qi).SeqLen() == seq_len); } HWY_ASSERT(prefill_tokens < seq_len); - activations.div_seq_len = hwy::Divisor(static_cast(seq_len)); + HWY_ASSERT(activations.attention.div_seq_len.GetDivisor() == seq_len); // Lacks a constructor to bulk-set, hence initialized by Prefill* which have // qi loops anyway. 
diff --git a/gemma/griffin.cc b/gemma/griffin.cc index f7b02e2..e1afff5 100644 --- a/gemma/griffin.cc +++ b/gemma/griffin.cc @@ -60,20 +60,21 @@ void GriffinRecurrent(size_t num_tokens, size_t griffin_layer, const size_t num_interleaved = num_tokens * qbatch.Size(); const hwy::Divisor div_qbatch(static_cast(qbatch.Size())); + GriffinActivations& griffin = *activations.griffin; // X / Y linear layers. // TODO: MatMul - HWY_DASSERT(activations.griffin_y.Rows() == activations.griffin_x.Rows()); - HWY_DASSERT(num_interleaved == activations.griffin_y.Rows()); + HWY_DASSERT(griffin.griffin_y.Rows() == griffin.griffin_x.Rows()); + HWY_DASSERT(num_interleaved == griffin.griffin_y.Rows()); CallUpcastedSame( &layer_weights->griffin.linear_x_w, &layer_weights->griffin.linear_y_w, [&](const auto* wx, const auto* wy) { for (size_t r = 0; r < num_interleaved; ++r) { - float* HWY_RESTRICT y = activations.griffin_y.Row(r); - float* HWY_RESTRICT x = activations.griffin_x.Row(r); + float* HWY_RESTRICT y = griffin.griffin_y.Row(r); + float* HWY_RESTRICT x = griffin.griffin_x.Row(r); TwoMatVecAdd( *wx, *wy, 0, model_dim, model_dim, - activations.pre_att_rms_out.Row(r), + activations.attention.pre_att_rms_out.Row(r), /*add0=*/layer_weights->griffin.linear_x_biases.PackedScale1(), /*add1=*/layer_weights->griffin.linear_y_biases.PackedScale1(), /*out0=*/x, /*out1=*/y, pool); @@ -87,7 +88,7 @@ void GriffinRecurrent(size_t num_tokens, size_t griffin_layer, const size_t qi = div_qbatch.Remainder(interleaved_idx); const size_t batch_idx = div_qbatch.Divide(interleaved_idx); const size_t pos = qbatch.Pos(qi) + batch_idx; - float* HWY_RESTRICT x = activations.griffin_x.Row(qi); + float* HWY_RESTRICT x = griffin.griffin_x.Row(qi); // cache[i] = input at time t-i. float* HWY_RESTRICT cache[kMaxConv1DWidth]; @@ -124,10 +125,10 @@ void GriffinRecurrent(size_t num_tokens, size_t griffin_layer, const size_t batch_idx = div_qbatch.Divide(interleaved_idx); const size_t pos = qbatch.Pos(qi) + batch_idx; - float* HWY_RESTRICT x = activations.griffin_x.Row(qi); - float* HWY_RESTRICT y = activations.griffin_y.Row(qi); - float* HWY_RESTRICT gate_x = activations.griffin_gate_x.Row(qi); - float* HWY_RESTRICT a = activations.griffin_multiplier.Row(qi); + float* HWY_RESTRICT x = griffin.griffin_x.Row(qi); + float* HWY_RESTRICT y = griffin.griffin_y.Row(qi); + float* HWY_RESTRICT gate_x = griffin.griffin_gate_x.Row(qi); + float* HWY_RESTRICT a = griffin.griffin_multiplier.Row(qi); float* HWY_RESTRICT rnn_state = qbatch.KV(qi).rglru_cache.Row(griffin_layer); @@ -175,9 +176,9 @@ void GriffinRecurrent(size_t num_tokens, size_t griffin_layer, } // interleaved_idx // Final linear layer. - CallMatMul(activations.griffin_x, layer_weights->griffin.linear_out_w, + CallMatMul(griffin.griffin_x, layer_weights->griffin.linear_out_w, layer_weights->griffin.linear_out_biases.PackedScale1(), env, - activations.att_sums); + activations.attention.att_sums); } // GriffinRecurrent // NOLINTNEXTLINE(google-readability-namespace-comments) diff --git a/gemma/vit.cc b/gemma/vit.cc index fe17c3f..ddbd963 100644 --- a/gemma/vit.cc +++ b/gemma/vit.cc @@ -56,10 +56,10 @@ class VitAttention { // Computes Q, K, V for all heads, stored in activations_.q. 
HWY_NOINLINE void ComputeQKV() { PROFILER_ZONE("Gen.VitAttention.QKV"); - auto& qkv = activations_.q; + auto& qkv = activations_.attention.q; HWY_ASSERT(qkv.Rows() == num_tokens_); HWY_ASSERT(qkv.Cols() == layer_config_.heads * 3 * layer_config_.qkv_dim); - CallMatMul(activations_.pre_att_rms_out, layer_.vit.qkv_einsum_w, + CallMatMul(activations_.attention.pre_att_rms_out, layer_.vit.qkv_einsum_w, layer_.vit.qkv_einsum_b.PackedScale1(), env_, qkv); } @@ -69,7 +69,7 @@ class VitAttention { const size_t heads = layer_config_.heads; HWY_ASSERT_M(heads == layer_config_.kv_heads, "Vit expects MHA"); const size_t seq_len = - static_cast(activations_.div_seq_len.GetDivisor()); + static_cast(activations_.attention.div_seq_len.GetDivisor()); const float query_scale = 1.0f / sqrtf(static_cast(qkv_dim)); PROFILER_ZONE("Gen.VitAttention.DotSoftmax"); @@ -82,12 +82,13 @@ class VitAttention { MatPadding::kPacked); // Initialize att_out to zero prior to head loop. - ZeroInit(activations_.att_out); + ZeroInit(activations_.attention.att_out); for (size_t head = 0; head < heads; ++head) { pool_.Run(0, num_tokens_, [&](uint64_t task, size_t /*thread*/) HWY_ATTR { const size_t token = task; - float* HWY_RESTRICT q = activations_.q.Row(token) + head * 3 * qkv_dim; + float* HWY_RESTRICT q = + activations_.attention.q.Row(token) + head * 3 * qkv_dim; // TODO: shift to MatMul with A.scale once MatMul is confirmed working MulByConst(query_scale, q, qkv_dim); hwy::CopyBytes(q, Q.Row(token), qkv_dim * sizeof(float)); @@ -95,8 +96,8 @@ class VitAttention { pool_.Run(0, seq_len, [&](uint64_t task, size_t /*thread*/) HWY_ATTR { const size_t seq_idx = task; - float* HWY_RESTRICT k = - activations_.q.Row(seq_idx) + head * 3 * qkv_dim + qkv_dim; + float* HWY_RESTRICT k = activations_.attention.q.Row(seq_idx) + + head * 3 * qkv_dim + qkv_dim; hwy::CopyBytes(k, K.Row(seq_idx), qkv_dim * sizeof(float)); }); @@ -111,10 +112,10 @@ class VitAttention { pool_.Run(0, num_tokens_, [&](uint64_t task, size_t /*thread*/) HWY_ATTR { size_t token = task; float* HWY_RESTRICT att_out = - activations_.att_out.Row(token) + head * qkv_dim; + activations_.attention.att_out.Row(token) + head * qkv_dim; for (size_t i = 0; i < seq_len; ++i) { - float* HWY_RESTRICT v = - activations_.q.Row(i) + head * 3 * qkv_dim + 2 * qkv_dim; + float* HWY_RESTRICT v = activations_.attention.q.Row(i) + + head * 3 * qkv_dim + 2 * qkv_dim; MulByConstAndAdd(C.Row(token)[i], v, att_out, qkv_dim); } }); @@ -126,7 +127,7 @@ class VitAttention { const size_t heads = layer_config_.heads; HWY_ASSERT_M(heads == layer_config_.kv_heads, "Vit expects MHA"); const size_t seq_len = - static_cast(activations_.div_seq_len.GetDivisor()); + static_cast(activations_.attention.div_seq_len.GetDivisor()); const float query_scale = 1.0f / sqrtf(static_cast(qkv_dim)); PROFILER_ZONE("Gen.VitAttention.DotSoftmax"); @@ -137,24 +138,24 @@ class VitAttention { const size_t token = task / layer_config_.heads; // Compute Q.K scores, which are "logits" stored in head_att. 
float* HWY_RESTRICT q = - activations_.q.Row(token) + head * 3 * qkv_dim; + activations_.attention.q.Row(token) + head * 3 * qkv_dim; MulByConst(query_scale, q, qkv_dim); float* HWY_RESTRICT head_att = - activations_.att.Row(token) + head * seq_len; + activations_.attention.att.Row(token) + head * seq_len; for (size_t i = 0; i < seq_len; ++i) { - float* HWY_RESTRICT k = - activations_.q.Row(i) + head * 3 * qkv_dim + qkv_dim; + float* HWY_RESTRICT k = activations_.attention.q.Row(i) + + head * 3 * qkv_dim + qkv_dim; head_att[i] = Dot(q, k, qkv_dim); // score = q.k } // SoftMax yields "probabilities" in head_att. Softmax(head_att, seq_len); // Compute weighted sum of v into att_out. float* HWY_RESTRICT att_out = - activations_.att_out.Row(token) + head * qkv_dim; + activations_.attention.att_out.Row(token) + head * qkv_dim; hwy::ZeroBytes(att_out, qkv_dim * sizeof(*att_out)); for (size_t i = 0; i < seq_len; ++i) { - float* HWY_RESTRICT v = - activations_.q.Row(i) + head * 3 * qkv_dim + 2 * qkv_dim; + float* HWY_RESTRICT v = activations_.attention.q.Row(i) + + head * 3 * qkv_dim + 2 * qkv_dim; MulByConstAndAdd(head_att[i], v, att_out, qkv_dim); } }); @@ -168,8 +169,8 @@ class VitAttention { // att_weights and att_out are concatenated heads, each of length // qkv_dim. Thus the [num_tokens_, layer_config_.model_dim] // matmul output is the sum over heads. - CallMatMul(activations_.att_out, layer_.vit.attn_out_w, bias, env_, - activations_.att_sums); + CallMatMul(activations_.attention.att_out, layer_.vit.attn_out_w, bias, + env_, activations_.attention.att_sums); } public: @@ -184,7 +185,7 @@ class VitAttention { HWY_INLINE void operator()() { ComputeQKV(); - if (activations_.weights_config.wrapping == PromptWrapping::GEMMA_VLM) { + if (activations_.attention.config.wrapping == PromptWrapping::GEMMA_VLM) { DotSoftmaxWeightedSumMatrix(); } else { DotSoftmaxWeightedSum(); @@ -233,7 +234,7 @@ void FFWVit(const LayerWeightsPtrs& layer, Activations& activations, void VitTransformerLayer(size_t num_tokens, const size_t layer_idx, const LayerWeightsPtrs& layer, Activations& activations, MatMulEnv& env) { - const size_t model_dim = activations.weights_config.model_dim; + const size_t model_dim = activations.attention.config.model_dim; auto type = layer.layer_config.type; HWY_DASSERT(type == LayerAttentionType::kVit); (void)type; @@ -246,14 +247,14 @@ void VitTransformerLayer(size_t num_tokens, const size_t layer_idx, // y = nn.LayerNorm()(x) // y ~ pre_att_rms_out LayerNormBatched(x, layer.vit.layer_norm_0_scale, layer.vit.layer_norm_0_bias, - activations.pre_att_rms_out); + activations.attention.pre_att_rms_out); // y = out["sa"] = nn.MultiHeadDotProductAttention(...)(y, y) // y ~ att_sums VitAttention(num_tokens, layer_idx, activations, layer, env)(); // x = out["+sa"] = x + y - AddFromBatched(activations.att_sums, x); + AddFromBatched(activations.attention.att_sums, x); // y = nn.LayerNorm()(x) // y ~ pre_ffw_rms_out diff --git a/ops/ops_test.cc b/ops/ops_test.cc index 67890ec..a0ff314 100644 --- a/ops/ops_test.cc +++ b/ops/ops_test.cc @@ -32,6 +32,7 @@ #include #include "gemma/activations.h" // ChooseQueryScale +#include "gemma/configs.h" #include "util/allocator.h" #include "util/basics.h" // BF16 #include "util/mat.h" // MatStorageT @@ -400,7 +401,7 @@ void TestRopeAndMulBy() { x.Row(0)[i] = random_float(); } - const float qmul = ChooseQueryScale(config); + const float qmul = AttentionActivations::ChooseQueryScale(config); const float kmul = 1.0; MatStorageT qexpected("qexpected", dim_qkv);
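
Usage sketch (assumptions noted inline, not part of the patch itself): the split above lets an attention-only test construct just AttentionActivations, skipping the x/logits/FFW buffers and the GriffinActivations that the full Activations struct now allocates only when config.model == Model::GRIFFIN_2B. Constructor and member names follow the gemma/activations.h hunks above; RowPtrVec is a hypothetical stand-in for the row-pointer vector type whose template arguments the diff does not show.

template <typename RowPtrVec>  // hypothetical stand-in; see note above
void AttentionOnlySketch(const gcpp::ModelConfig& config, size_t batch_size,
                         size_t seq_len, RowPtrVec& row_ptrs) {
  const gcpp::LayerConfig& layer_config = config.layer_configs[0];
  // Allocates only the attention buffers (q, pre_att_rms_out, att, att_out,
  // att_sums, rope timescales) -- no x/logits/FFW and no Griffin buffers.
  gcpp::AttentionActivations attention(config, layer_config, batch_size,
                                       seq_len, gcpp::MatPadding::kOdd,
                                       row_ptrs);
  // Rows can later be shrunk for smaller batches without reallocating.
  attention.SetBatchSize(batch_size);
  // ChooseQueryScale is now a static member, as ops_test.cc above uses it.
  const float query_scale =
      gcpp::AttentionActivations::ChooseQueryScale(config);
  (void)query_scale;
}

For the full model path, Activations keeps one AttentionActivations as its `attention` member plus an optional `griffin` pointer that is non-null only for GRIFFIN_2B; the gemma.cc, griffin.cc, and vit.cc hunks above access the moved fields through `activations.attention` and `*activations.griffin` accordingly.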