// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_ACTIVATIONS_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_ACTIVATIONS_H_

#include <math.h>  // sqrtf
#include <stddef.h>
#include <stdint.h>

#include <atomic>
#include <vector>

#include "gemma/configs.h"     // ModelConfig
#include "gemma/gemma_args.h"  // AttentionImpl
#include "ops/ops.h"           // CreateInvTimescale
#include "util/basics.h"       // BF16
#include "util/mat.h"          // MatStorageT
#include "util/threading_context.h"

namespace gcpp {

// Returns the scale value to use for the query in the attention computation.
// Also called by ops_test.
static inline float ChooseQueryScale(const ModelConfig& config) {
  const LayerConfig& layer_config = config.layer_configs[0];
  if (config.query_scale == QueryScaleType::SqrtModelDimDivNumHeads)
    return 1.0f /
           sqrtf(static_cast<float>(config.model_dim / layer_config.heads));
  // QueryScaleType::SqrtKeySize
  return 1.0f / sqrtf(static_cast<float>(layer_config.qkv_dim));
}
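
// Worked example (hypothetical shape, for illustration): with model_dim =
// 2048, heads = 8 and qkv_dim = 256,
//   SqrtModelDimDivNumHeads: 1 / sqrt(2048 / 8) = 1 / 16 = 0.0625
//   SqrtKeySize:             1 / sqrt(256)      = 1 / 16 = 0.0625
// The two coincide here because model_dim / heads == qkv_dim; they differ for
// models whose qkv_dim is not model_dim / heads.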

struct AttentionActivations {
  AttentionActivations(
      const ModelConfig& config, const LayerConfig& layer_config,
      size_t batch_size, size_t seq_len, const Allocator& allocator,
      std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>>& row_ptrs)
      :  // `vocab_size == 0` means this is for the ViT part; VitAttention is
         // still MHA and does not use an external KV cache.
        q(MatFactory("q", batch_size,
                     config.vocab_size == 0
                         ? layer_config.heads * 3 * layer_config.qkv_dim
                         : layer_config.heads * layer_config.qkv_dim,
                     allocator)),
        q_bf(MatFactory("q_bf", batch_size,
                        config.vocab_size == 0
                            ? layer_config.heads * 3 * layer_config.qkv_dim
                            : layer_config.heads * layer_config.qkv_dim,
                        allocator)),
        q_T(MatFactory("q_T", layer_config.qkv_dim,
                       config.vocab_size == 0
                           ? batch_size * layer_config.heads * 3
                           : batch_size * layer_config.heads,
                       allocator)),
        pre_att_rms_out(MatFactory("pre_att_rms_out", batch_size,
                                   config.model_dim, allocator)),
        att(MatFactory("att", batch_size, layer_config.heads * seq_len,
                       allocator)),
        att_out(MatFactory("att_out", batch_size,
                           layer_config.heads * layer_config.qkv_dim,
                           allocator)),
        att_sums(
            MatFactory("att_sums", batch_size, config.model_dim, allocator)),
        inv_timescale(
            CreateInvTimescale(allocator, layer_config.qkv_dim,
                               layer_config.post_qk == PostQKType::HalfRope)),
        inv_timescale_global(CreateInvTimescale(
            allocator, layer_config.qkv_dim,
            layer_config.post_qk == PostQKType::HalfRope, 1000000.0)) {
    // Batch size can be 0 in experimental code, so do not assert.
    if (batch_size == 0) {
      static std::atomic_flag warned = ATOMIC_FLAG_INIT;
      if (!warned.test_and_set()) {
        HWY_WARN("Creating mostly empty activations with a batch_size of 0.");
      }
      return;
    }

    // For MatMul outputs, precompute their row pointers.
    // If we forget any MatMul outputs here, debug builds print a warning but
    // fill them in each MatMul call.
    q.AllocateAndAttachRowPtrs(row_ptrs);
    q_bf.AllocateAndAttachRowPtrs(row_ptrs);
    q_T.AllocateAndAttachRowPtrs(row_ptrs);
    att_sums.AllocateAndAttachRowPtrs(row_ptrs);
  }

  void SetBatchSize(size_t batch_size) {
    q.OverrideRows(batch_size);
    q_bf.OverrideRows(batch_size);
    // q_T rows are always qkv_dim!
    pre_att_rms_out.OverrideRows(batch_size);
    att.OverrideRows(batch_size);
    att_out.OverrideRows(batch_size);
    att_sums.OverrideRows(batch_size);
    // `inv_timescale*` are not batched.
  }

  MatStorageT<float> q;  // query
  MatStorageT<BF16> q_bf;
  MatStorageT<BF16> q_T;  // Transposed to maximize attention speed.

  MatStorageT<float> pre_att_rms_out;
  MatStorageT<float> att;      // attention vector
  MatStorageT<float> att_out;  // attention output
  // Accumulation of attention outputs over heads.
  MatStorageT<BF16> att_sums;

  // RoPE
  MatStorageT<float> inv_timescale;
  MatStorageT<float> inv_timescale_global;
};

// A non-owning view of AttentionActivations.
struct AttentionActivationsPtrs {
  AttentionActivationsPtrs(const ModelConfig& config, size_t seq_len)
      : config(config),
        div_seq_len(static_cast<uint32_t>(seq_len)),
        div_heads(static_cast<uint32_t>(config.layer_configs[0].heads)),
        query_scale(ChooseQueryScale(config)) {}

  AttentionActivationsPtrs(const ModelConfig& config, size_t seq_len,
                           const AttentionActivations& activations)
      : AttentionActivationsPtrs(config, seq_len) {
    q = activations.q;
    q_bf = activations.q_bf;
    q_T = activations.q_T;
    pre_att_rms_out = activations.pre_att_rms_out;
    att = activations.att;
    att_out = activations.att_out;
    att_sums = activations.att_sums;
    inv_timescale = activations.inv_timescale;
    inv_timescale_global = activations.inv_timescale_global;
  }

  void SetBatchSize(size_t batch_size) {
    q.OverrideRows(batch_size);
    q_bf.OverrideRows(batch_size);
    // q_T rows are always qkv_dim!
    pre_att_rms_out.OverrideRows(batch_size);
    att.OverrideRows(batch_size);
    att_out.OverrideRows(batch_size);
    att_sums.OverrideRows(batch_size);
    // `inv_timescale*` are not batched.
  }

  size_t SeqLen() const {
    return static_cast<size_t>(div_seq_len.GetDivisor());
  }

  const ModelConfig& config;
  MatPtrT<float> q;
  MatPtrT<BF16> q_bf;
  MatPtrT<BF16> q_T;
  MatPtrT<float> pre_att_rms_out;
  MatPtrT<float> att;
  MatPtrT<float> att_out;
  MatPtrT<BF16> att_sums;
  MatPtrT<float> inv_timescale;
  MatPtrT<float> inv_timescale_global;
  hwy::Divisor div_seq_len;
  hwy::Divisor div_heads;
  float query_scale;
};
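
// Usage sketch (illustrative only; `config`, `allocator` and `row_ptrs` are
// assumed to exist in the caller, and the sizes are hypothetical): pair the
// owning storage with a non-owning view, then shrink both for single-query
// decode.
//
//   AttentionActivations storage(config, config.layer_configs[0],
//                                /*batch_size=*/16, /*seq_len=*/4096,
//                                allocator, row_ptrs);
//   AttentionActivationsPtrs view(config, /*seq_len=*/4096, storage);
//   storage.SetBatchSize(1);  // shrinks the owning MatStorageT
//   view.SetBatchSize(1);     // MatPtrT views must be updated separately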

struct Activations {
  Activations(const RuntimeConfig& runtime_config, const ModelConfig& config,
              size_t batch_size, size_t seq_len, ThreadingContext& ctx,
              std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>>& row_ptrs)
      : layer_config(config.layer_configs[0]),
        x(MatFactory("x", batch_size, config.model_dim, ctx.allocator)),
        x_bf(MatFactory("x_bf", batch_size, config.model_dim, ctx.allocator)),
        logits(MatFactory("logits", batch_size, config.vocab_size,
                          ctx.allocator)),
        sampled(MatFactory("sampled", batch_size, 3, ctx.allocator)),
        pre_ffw_rms_out(MatFactory("pre_ffw_rms_out", batch_size,
                                   config.model_dim, ctx.allocator)),
        C1(MatFactory("C1", batch_size, layer_config.ff_hidden_dim,
                      ctx.allocator)),
        C2(MatFactory("C2", batch_size, layer_config.ff_hidden_dim,
                      ctx.allocator)),
        ffw_out(MatFactory("ffw_out", batch_size, config.model_dim,
                           ctx.allocator)),
        attention_impl(runtime_config.attention_impl),
        attention_storage(config, layer_config, batch_size, seq_len,
                          ctx.allocator, row_ptrs),
        attention(config, seq_len, attention_storage) {
    HWY_ASSERT(batch_size != 0);

    // For MatMul outputs, precompute their row pointers.
    // If we forget any MatMul outputs here, debug builds print a warning but
    // fill them in each MatMul call.
    x.AllocateAndAttachRowPtrs(row_ptrs);
    x_bf.AllocateAndAttachRowPtrs(row_ptrs);
    logits.AllocateAndAttachRowPtrs(row_ptrs);
    C1.AllocateAndAttachRowPtrs(row_ptrs);
    C2.AllocateAndAttachRowPtrs(row_ptrs);
    ffw_out.AllocateAndAttachRowPtrs(row_ptrs);
    // Note that BindC on any MatMul output considerably slows down Prefill.
  }

  // Negligible CPU time.
  void SetBatchSize(size_t batch_size) {
    x.OverrideRows(batch_size);
    x_bf.OverrideRows(batch_size);
    logits.OverrideRows(batch_size);
    sampled.OverrideRows(batch_size);
    pre_ffw_rms_out.OverrideRows(batch_size);
    C1.OverrideRows(batch_size);
    C2.OverrideRows(batch_size);
    ffw_out.OverrideRows(batch_size);

    attention_storage.SetBatchSize(batch_size);
    // `AttentionActivationsPtrs` holds `MatPtrT` views that also require
    // updating; their row overrides are not propagated from the underlying
    // storage.
    attention.SetBatchSize(batch_size);
  }

  const LayerConfig& layer_config;

  MatStorageT<float> x;    // input
  MatStorageT<BF16> x_bf;  // output of final RMSNorm, input to EmbeddingMatmul
  MatStorageT<float> logits;   // TODO: BF16 after Softmax supports that.
  MatStorageT<float> sampled;  // batch_size x 3 (padded)

  // Gated FFW
  MatStorageT<float> pre_ffw_rms_out;
  MatStorageT<float> C1;
  MatStorageT<float> C2;
  MatStorageT<float> ffw_out;

  AttentionImpl attention_impl;

  AttentionActivations attention_storage;
  AttentionActivationsPtrs attention;
};

}  // namespace gcpp

#endif  // THIRD_PARTY_GEMMA_CPP_GEMMA_ACTIVATIONS_H_
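
// Usage sketch (illustrative only; `runtime_config`, `config` and `ctx` are
// assumed to be set up by the caller, and the sizes are hypothetical):
//
//   std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>> row_ptrs;
//   Activations activations(runtime_config, config, /*batch_size=*/16,
//                           /*seq_len=*/4096, ctx, row_ptrs);
//   // Later, e.g. when switching from a prefill batch to a single query:
//   activations.SetBatchSize(1);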