// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_ACTIVATIONS_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_ACTIVATIONS_H_

#include <math.h>  // sqrtf
#include <stddef.h>
#include <stdint.h>

#include <atomic>
#include <vector>

#include "gemma/configs.h"     // ModelConfig
#include "gemma/gemma_args.h"  // AttentionImpl
#include "gemma/kv_cache.h"
#include "gemma/tensor_stats.h"
#include "ops/ops.h"      // CreateInvTimescale
#include "util/basics.h"  // BF16
#include "util/mat.h"     // MatStorageT
#include "util/threading_context.h"

namespace gcpp {

// Returns the scale value to use for the query in the attention computation.
// Also called by ops_test.
static inline float ChooseQueryScale(const ModelConfig& config) {
  const LayerConfig& layer_config = config.layer_configs[0];
  if (config.query_scale == QueryScaleType::SqrtModelDimDivNumHeads)
    return 1.0f /
           sqrtf(static_cast<float>(config.model_dim / layer_config.heads));
  // QueryScaleType::SqrtKeySize
  return 1.0f / sqrtf(static_cast<float>(layer_config.qkv_dim));
}

struct AttentionActivations {
  AttentionActivations(
      const ModelConfig& config, const LayerConfig& layer_config,
      size_t batch_size, size_t seq_len, const RuntimeConfig& runtime_config,
      const Allocator& allocator,
      std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>>& row_ptrs)
      :  // `vocab_size == 0` means it is for the Vit part; VitAttention is
         // still MHA and does not use an external KV cache.
        q(MatFactory("q", batch_size,
                     config.vocab_size == 0
                         ? layer_config.heads * 3 * layer_config.qkv_dim
                         : layer_config.heads * layer_config.qkv_dim,
                     allocator)),
        q_bf(MatFactory("q_bf", batch_size,
                        config.vocab_size == 0
                            ? layer_config.heads * 3 * layer_config.qkv_dim
                            : layer_config.heads * layer_config.qkv_dim,
                        allocator)),
        q_T(MatFactory("q_T", layer_config.qkv_dim,
                       config.vocab_size == 0
                           ? batch_size * layer_config.heads * 3
                           : batch_size * layer_config.heads,
                       allocator)),

        vit_Q(MatFactory("Q2", batch_size, layer_config.qkv_dim, allocator)),
        vit_K(MatFactory("K2", seq_len, layer_config.qkv_dim, allocator)),
        vit_C(MatFactory("C2", batch_size, seq_len, allocator)),

        pre_att_rms_out(MatFactory("pre_att_rms_out", batch_size,
                                   config.model_dim, allocator)),
        att(MatFactory("att", batch_size, layer_config.heads * seq_len,
                       allocator)),
        att_out(MatFactory("att_out", batch_size,
                           layer_config.heads * layer_config.qkv_dim,
                           allocator)),
        softmax_max(MatFactory("softmax_max", batch_size, layer_config.heads,
                               allocator)),
        softmax_d(
            MatFactory("softmax_d", batch_size, layer_config.heads,
                       allocator)),
        att_sums(
            MatFactory("att_sums", batch_size, config.model_dim, allocator)),

        inv_timescale(
            CreateInvTimescale(allocator, layer_config.qkv_dim,
                               layer_config.post_qk == PostQKType::HalfRope)),
        inv_timescale_global(CreateInvTimescale(
            allocator, layer_config.qkv_dim,
            layer_config.post_qk == PostQKType::HalfRope, 1000000.0)) {
    // Batch size can be 0 in experimental code so do not assert.
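    // When it is 0, warn only once (guarded by the atomic_flag below) rather
    // than on every construction, and skip attaching row pointers.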
    if (batch_size == 0) {
      static std::atomic_flag warned = ATOMIC_FLAG_INIT;
      if (!warned.test_and_set()) {
        HWY_WARN("Creating mostly empty activations with a batch_size of 0.");
      }
      return;
    }

    // For MatMul outputs, precompute their row pointers.
    // If we forget any MatMul outputs here, debug builds print a warning but
    // fill them in each MatMul call.
    q.AllocateAndAttachRowPtrs(row_ptrs);
    q_bf.AllocateAndAttachRowPtrs(row_ptrs);
    q_T.AllocateAndAttachRowPtrs(row_ptrs);
    vit_C.AllocateAndAttachRowPtrs(row_ptrs);
    att_sums.AllocateAndAttachRowPtrs(row_ptrs);
  }

  void SetBatchSize(size_t batch_size) {
    q.OverrideRows(batch_size);
    q_bf.OverrideRows(batch_size);
    // q_T rows are always qkv_dim!

    vit_Q.OverrideRows(batch_size);
    // vit_K stays seq_len!
    vit_C.OverrideRows(batch_size);

    pre_att_rms_out.OverrideRows(batch_size);
    att.OverrideRows(batch_size);
    att_out.OverrideRows(batch_size);
    softmax_max.OverrideRows(batch_size);
    softmax_d.OverrideRows(batch_size);
    att_sums.OverrideRows(batch_size);

    // `inv_timescale*` are not batched.
  }

  MatStorageT<float> q;  // query
  MatStorageT<BF16> q_bf;
  MatStorageT<BF16> q_T;  // Transposed to maximize attention speed.

  MatStorageT<float> vit_Q;
  MatStorageT<float> vit_K;
  MatStorageT<float> vit_C;

  MatStorageT<float> pre_att_rms_out;
  MatStorageT<float> att;          // attention vector
  MatStorageT<float> att_out;      // attention output
  MatStorageT<float> softmax_max;  // see OnlineSoftmaxState
  MatStorageT<float> softmax_d;    // see OnlineSoftmaxState
  // Accumulation of attention outputs over heads
  MatStorageT<float> att_sums;

  // Rope
  MatStorageT<float> inv_timescale;
  MatStorageT<float> inv_timescale_global;
};

// A non-owning view of AttentionActivations.
struct AttentionActivationsPtrs {
  AttentionActivationsPtrs(const ModelConfig& config, size_t seq_len)
      : config(config),
        div_seq_len(static_cast<uint32_t>(seq_len)),
        div_heads(static_cast<uint32_t>(config.layer_configs[0].heads)),
        query_scale(ChooseQueryScale(config)) {}

  AttentionActivationsPtrs(const ModelConfig& config, size_t seq_len,
                           const AttentionActivations& activations)
      : AttentionActivationsPtrs(config, seq_len) {
    q = activations.q;
    q_bf = activations.q_bf;
    q_T = activations.q_T;
    vit_Q = activations.vit_Q;
    vit_K = activations.vit_K;
    vit_C = activations.vit_C;
    pre_att_rms_out = activations.pre_att_rms_out;
    att = activations.att;
    att_out = activations.att_out;
    softmax_max = activations.softmax_max;
    softmax_d = activations.softmax_d;
    att_sums = activations.att_sums;
    inv_timescale = activations.inv_timescale;
    inv_timescale_global = activations.inv_timescale_global;
  }

  void SetBatchSize(size_t batch_size) {
    q.OverrideRows(batch_size);
    q_bf.OverrideRows(batch_size);
    // q_T rows are always qkv_dim!

    vit_Q.OverrideRows(batch_size);
    // vit_K stays seq_len!
    vit_C.OverrideRows(batch_size);

    pre_att_rms_out.OverrideRows(batch_size);
    att.OverrideRows(batch_size);
    att_out.OverrideRows(batch_size);
    softmax_max.OverrideRows(batch_size);
    softmax_d.OverrideRows(batch_size);
    att_sums.OverrideRows(batch_size);

    // `inv_timescale*` are not batched.
  }

  size_t SeqLen() const {
    return static_cast<size_t>(div_seq_len.GetDivisor());
  }

  const ModelConfig& config;

  // For the matrices below, the batch_size dimension is really qbatch.Size() *
  // token_batch_size, but in all known uses, one of those is 1. Specifically,
  // during PrefillTBatch, it is the prompt length (up to some max batch size)
  // and otherwise it is qbatch.Size().
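  // Example: prefilling a single 100-token prompt with a maximum batch size of
  // 64 would use a batch_size of 64 and then 36, whereas decoding 4
  // interleaved queries uses a batch_size of 4.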
  // Query matrix of size batch_size x (q_heads * qkv_dim).
  MatPtrT<float> q;
  // Query matrix of size batch_size x (q_heads * qkv_dim).
  MatPtrT<BF16> q_bf;
  // Transposed query matrix for faster Q*K^T.
  MatPtrT<BF16> q_T;

  MatPtrT<float> vit_Q;
  MatPtrT<float> vit_K;
  MatPtrT<float> vit_C;

  // Output of RMSNorm before attention, size batch_size x model_dim.
  MatPtrT<float> pre_att_rms_out;
  // Attention scores computed from Q*K^T, size batch_size x (q_heads *
  // seq_len).
  MatPtrT<float> att;
  // Attention output computed from att * V, size batch_size x (q_heads *
  // qkv_dim).
  MatPtrT<float> att_out;
  // The maximum logit value encountered when computing att_out from att,
  // size batch_size x q_heads. See OnlineSoftmaxState for details.
  // WARNING: Only filled in for AttentionImpl::kOld.
  MatPtrT<float> softmax_max;
  // The sum of scaled exponentials when computing att_out from att,
  // size batch_size x q_heads. See OnlineSoftmaxState for details.
  // WARNING: Only filled in for AttentionImpl::kOld.
  MatPtrT<float> softmax_d;
  // Accumulation of attention outputs over heads, size batch_size x
  // model_dim.
  MatPtrT<float> att_sums;
  // Inverse timescales for RoPE computation.
  MatPtrT<float> inv_timescale;
  // Inverse timescales for global RoPE computation.
  MatPtrT<float> inv_timescale_global;

  // Divisor for faster division by sequence length.
  hwy::Divisor div_seq_len;
  // Divisor for faster division by number of heads.
  hwy::Divisor div_heads;
  // Query scaling factor for attention computation.
  float query_scale;
};

struct Activations {
  Activations(const RuntimeConfig& runtime_config, const ModelConfig& config,
              size_t batch_size, size_t seq_len, ThreadingContext& ctx,
              std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>>& row_ptrs)
      : layer_config(config.layer_configs[0]),
        x(MatFactory("x", batch_size, config.model_dim, ctx.allocator)),
        x_bf(MatFactory("x_bf", batch_size, config.model_dim, ctx.allocator)),
        logits(MatFactory("logits", batch_size, config.vocab_size,
                          ctx.allocator)),
        sampled(MatFactory("sampled", batch_size, 3, ctx.allocator)),

        pre_ffw_rms_out(MatFactory("pre_ffw_rms_out", batch_size,
                                   config.model_dim, ctx.allocator)),
        C1(MatFactory("C1", batch_size, layer_config.ff_hidden_dim,
                      ctx.allocator)),
        C2(MatFactory("C2", batch_size, layer_config.ff_hidden_dim,
                      ctx.allocator)),
        ffw_out(MatFactory("ffw_out", batch_size, config.model_dim,
                           ctx.allocator)),

        max_workers(ctx.pools.MaxWorkers()),
        s_ffw_in(config.num_layers, max_workers),
        s_ffw_hidden(config.num_layers, max_workers),
        s_ffw_out(config.num_layers, max_workers),
        s_w_gating_einsum_w1(config.num_layers, max_workers),
        s_w_gating_einsum_w2(config.num_layers, max_workers),
        s_w_linear_w(config.num_layers, max_workers),

        attention_impl(runtime_config.attention_impl),
        attention_storage(config, layer_config, batch_size, seq_len,
                          runtime_config, ctx.allocator, row_ptrs),
        attention(config, seq_len, attention_storage) {
    HWY_ASSERT(batch_size != 0);

    // For MatMul outputs, precompute their row pointers.
    // If we forget any MatMul outputs here, debug builds print a warning but
    // fill them in each MatMul call.
    x.AllocateAndAttachRowPtrs(row_ptrs);
    x_bf.AllocateAndAttachRowPtrs(row_ptrs);
    logits.AllocateAndAttachRowPtrs(row_ptrs);
    C1.AllocateAndAttachRowPtrs(row_ptrs);
    C2.AllocateAndAttachRowPtrs(row_ptrs);
    ffw_out.AllocateAndAttachRowPtrs(row_ptrs);
    // Note that BindC on any MatMul output considerably slows down Prefill.
  }
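  // The destructor reduces the FFW tensor statistics accumulated below and
  // prints a summary (see TensorStats::ReduceAndPrint).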
  ~Activations() {
    s_ffw_in.ReduceAndPrint("ffw_in");
    s_ffw_hidden.ReduceAndPrint("ffw_hidden");
    s_ffw_out.ReduceAndPrint("ffw_out");
  }

  // Negligible CPU time.
  void SetBatchSize(size_t batch_size) {
    x.OverrideRows(batch_size);
    x_bf.OverrideRows(batch_size);
    logits.OverrideRows(batch_size);
    sampled.OverrideRows(batch_size);

    pre_ffw_rms_out.OverrideRows(batch_size);
    C1.OverrideRows(batch_size);
    C2.OverrideRows(batch_size);
    ffw_out.OverrideRows(batch_size);

    attention_storage.SetBatchSize(batch_size);
    // `AttentionActivationsPtrs` holds `MatPtrT` views which also require
    // updating; their row override is not updated when the underlying storage
    // changes.
    attention.SetBatchSize(batch_size);
  }

  const LayerConfig& layer_config;

  MatStorageT<float> x;     // input
  MatStorageT<BF16> x_bf;   // output of final RMSNorm, input to EmbeddingMatmul
  MatStorageT<float> logits;   // TODO: BF16 after Softmax supports that.
  MatStorageT<float> sampled;  // batch_size x 3 (padded)

  // Gated FFW
  MatStorageT<float> pre_ffw_rms_out;
  MatStorageT<float> C1;
  MatStorageT<float> C2;
  MatStorageT<float> ffw_out;

  const size_t max_workers;
  TensorStats s_ffw_in;
  TensorStats s_ffw_hidden;  // after Activation+gating
  TensorStats s_ffw_out;
  TensorStats s_w_gating_einsum_w1;
  TensorStats s_w_gating_einsum_w2;
  TensorStats s_w_linear_w;

  AttentionImpl attention_impl;

  AttentionActivations attention_storage;
  AttentionActivationsPtrs attention;
};

}  // namespace gcpp

#endif  // THIRD_PARTY_GEMMA_CPP_GEMMA_ACTIVATIONS_H_