// Copyright 2024 Google LLC // SPDX-License-Identifier: Apache-2.0 // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // Transformer components shared between vit.cc and attention.cc. #include #include #include "gemma/activations.h" #include "gemma/configs.h" #include "gemma/tensor_stats.h" #include "gemma/weights.h" #include "ops/matmul.h" #include "util/mat.h" #include "util/threading.h" #include "util/zones.h" #include "hwy/profiler.h" // Include guard (still compiled once per target) #if defined(THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_INL_H_) == \ defined(HWY_TARGET_TOGGLE) #ifdef THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_INL_H_ #undef THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_INL_H_ #else #define THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_INL_H_ #endif #include "hwy/highway.h" // After highway.h #include "ops/ops-inl.h" HWY_BEFORE_NAMESPACE(); namespace gcpp { namespace HWY_NAMESPACE { // For use by Vit even if !GEMMA_FUSED_FFN. template void Activation(ActivationType activation, T1* HWY_RESTRICT c1, const T2* HWY_RESTRICT c2, const size_t count, ThreadingContext& ctx, const size_t worker) { GCPP_ZONE(ctx, worker, Zones::kGenActivation); namespace hn = hwy::HWY_NAMESPACE; using DF = hn::ScalableTag; using VF = hn::Vec; // ActivationType::Gelu if (c2 == nullptr) { // No multiplier, just Gelu. Gelu(c1, count); return; }; // Has multiplier, Gelu(c1) * c2. Decompress1AndCompressInplace(DF(), c1, count, c2, /*p1_ofs=*/0, [](DF df, VF v1, VF v2) HWY_ATTR -> VF { return hn::Mul(v2, Gelu(df, v1)); }); } // No C2 multiplier - used by Vit. template void ActivationBatched( ActivationType activation, Mat& c1, ThreadingContext& ctx, size_t cluster_idx = 0, Parallelism parallelism = Parallelism::kFlat) { using T = typename Mat::T; ParallelFor(parallelism, c1.Rows(), ctx, cluster_idx, Callers::kActivationBatched, [&](uint64_t task, size_t worker) { // Cast to correct type so type deduction works. Activation(activation, c1.Row(task), static_cast(nullptr), c1.Cols(), ctx, worker); }); } #if GEMMA_FUSED_FFN // Called during `TwoMatMul`. static inline void Activation(ActivationType activation, const RowPtrsBF C1, const IndexRange range_r, const IndexRange range_c, const StridedViewBF C2, ThreadingContext& ctx, const size_t worker) { GCPP_ZONE(ctx, worker, Zones::kGenActivationFused); const size_t cols = range_c.Num(); HWY_DASSERT(C2.Cols() == cols); namespace hn = hwy::HWY_NAMESPACE; using DF = hn::ScalableTag; using VF = hn::Vec; // ActivationType::Gelu // Gated: Gelu(c1) * c2. for (size_t ir = 0; ir < range_r.Num(); ++ir) { Decompress1AndCompressInplace( DF(), C1.Row(range_r.begin() + ir) + range_c.begin(), cols, C2.Row(ir), /*p1_ofs*/ 0, [](DF df, VF v1, VF v2) HWY_ATTR -> VF { return hn::Mul(v2, Gelu(df, v1)); }); } } #endif // GEMMA_FUSED_FFN // Only used if !GEMMA_FUSED_FFN, but define anyway so that we can check // using if constexpr rather than #if, which interferes with code folding. template HWY_NOINLINE void ActivationBatched( ActivationType activation, Mat1& c1, const Mat2* c2, ThreadingContext& ctx, size_t cluster_idx = 0, Parallelism parallelism = Parallelism::kFlat) { HWY_DASSERT(c1.SameShape(*c2)); if (c2 && c2->HasPtr()) { ParallelFor(parallelism, c1.Rows(), ctx, cluster_idx, Callers::kActivationBatched, [&](uint64_t task, size_t worker) { Activation(activation, c1.Row(task), c2->Row(task), c1.Cols(), ctx, worker); }); } else { // No multiplier ParallelFor(parallelism, c1.Rows(), ctx, cluster_idx, Callers::kActivationBatched, [&](uint64_t task, size_t worker) { Activation(activation, c1.Row(task), static_cast(nullptr), c1.Cols(), ctx, worker); }); } } template HWY_NOINLINE void ResidualConnection(const MatPtrT& other, MatPtrT& HWY_RESTRICT x, const LayerWeights& layer, bool is_attention, ThreadingContext& ctx) { // ResidualType::Add AddFromBatched(other, x, ctx); } template void PostNorm(PostNormType post_norm, const MatPtr& weights, MatPtrT& inout, ThreadingContext& ctx) { HWY_DASSERT(weights.Rows() == 1); if (post_norm == PostNormType::Scale) { RMSNormInplaceBatched(weights, inout, ctx); } } static inline void FFWNoVit(const LayerWeightsPtrs& layer, Activations& activations, MatMulEnv& env) { GCPP_ZONE(env.ctx, hwy::Profiler::GlobalIdx(), Zones::kGenFFW); const LayerConfig& layer_config = layer.layer_config; HWY_DASSERT(!layer_config.ff_biases); // Only used in Vit. activations.s_ffw_in.Notify(layer.layer_idx, activations.pre_ffw_rms_out, env.ctx); #if GEMMA_FUSED_FFN const auto fused = [&](RowPtrsBF C1, IndexRange range_r, IndexRange range_c, StridedViewBF C2, size_t worker) { Activation(layer_config.activation, C1, range_r, range_c, C2, env.ctx, worker); }; MMOptions options; options.SetFunc(fused); CallTwoMatMul(activations.pre_ffw_rms_out, layer.gating_einsum_w1, layer.gating_einsum_w2, env, activations.C1, options); #else // Compute the hidden layer activations. CallMatMul(activations.pre_ffw_rms_out, layer.gating_einsum_w1, nullptr, env, activations.C1); CallMatMul(activations.pre_ffw_rms_out, layer.gating_einsum_w2, nullptr, env, activations.C2); // Activation (Gelu) and maybe multiply by gate. Store activations in act. ActivationBatched(layer_config.activation, activations.C1, &activations.C2, env.ctx); #endif activations.s_ffw_hidden.Notify(layer.layer_idx, activations.C1, env.ctx); // Hidden layer -> output layer. CallMatMul(activations.C1, layer.linear_w, nullptr, env, activations.ffw_out); activations.s_ffw_out.Notify(layer.layer_idx, activations.ffw_out, env.ctx); } // Sums encoded (`att_out`) over num_heads (`layer_config.heads`) and // head_dim (`qkv_dim`) into output (`layer_out`). static HWY_INLINE void SumHeads(const LayerWeightsPtrs& layer, AttentionActivationsPtrs& activations, MatMulEnv& env) { GCPP_ZONE(env.ctx, hwy::Profiler::GlobalIdx(), Zones::kGenAttentionSumHeads); const LayerConfig& layer_config = layer.layer_config; (void)layer_config; // For HWY_DASSERT // att_weights and att_out are concatenated heads, each of length // layer_config.qkv_dim. Thus the [num_interleaved, // layer_config.model_dim] matmul output is the sum over heads. Compare // gemma/modules.py: attn_output = self.attn_vec_einsum('BTNH,NHD->BTD', // encoded) HWY_DASSERT(layer_config.model_dim != 0 && layer_config.heads != 0 && layer_config.qkv_dim != 0); CallMatMul(activations.att_out, layer.att_weights, /*add=*/nullptr, env, activations.att_sums); } // NOLINTNEXTLINE(google-readability-namespace-comments) } // namespace HWY_NAMESPACE } // namespace gcpp HWY_AFTER_NAMESPACE(); #endif // THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_INL_H_