// Copyright 2024 Google LLC // SPDX-License-Identifier: Apache-2.0 // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // https://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // Include guard for non-SIMD code. #ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_FORWARD_INL_H_ #define THIRD_PARTY_GEMMA_CPP_GEMMA_FORWARD_INL_H_ #include #include #include #include #include "backprop/activations.h" #include "gemma/common.h" #include "gemma/configs.h" #include "hwy/base.h" #include "hwy/contrib/thread_pool/thread_pool.h" #endif // THIRD_PARTY_GEMMA_CPP_GEMMA_FORWARD_INL_H_ // Include guard for (potentially) SIMD code. #if defined(THIRD_PARTY_GEMMA_CPP_FORWARD_TOGGLE) == defined(HWY_TARGET_TOGGLE) #ifdef THIRD_PARTY_GEMMA_CPP_FORWARD_TOGGLE #undef THIRD_PARTY_GEMMA_CPP_FORWARD_TOGGLE #else #define THIRD_PARTY_GEMMA_CPP_FORWARD_TOGGLE #endif #include "gemma/ops.h" #include "hwy/highway.h" HWY_BEFORE_NAMESPACE(); namespace gcpp { namespace HWY_NAMESPACE { template void InputEmbedding(const ArrayT& weights, const std::vector& prompt, const float scaling, float* HWY_RESTRICT output, size_t model_dim) { HWY_ASSERT(!prompt.empty()); for (size_t pos = 0; pos < prompt.size() - 1; ++pos) { int token = prompt[pos]; Decompress(weights, token * model_dim, output + pos * model_dim, model_dim); MulByConst(scaling, output + pos * model_dim, model_dim); } } template void ApplyRMSNorm(const WT* HWY_RESTRICT weights, const XT* HWY_RESTRICT x, size_t model_dim, size_t num_tokens, OutT* HWY_RESTRICT output, hwy::ThreadPool& pool) { for (size_t pos = 0; pos < num_tokens; ++pos) { const size_t offset = pos * model_dim; RMSNorm(x + offset, weights, output + offset, model_dim); } } static HWY_NOINLINE float CrossEntropyLoss(const float* HWY_RESTRICT probs, const std::vector& prompt, size_t context_size, size_t vocab_size, hwy::ThreadPool& pool) { HWY_ASSERT(!prompt.empty()); float loss = 0.0f; for (size_t pos = 0; pos < prompt.size() - 1; ++pos) { if (pos + 1 < context_size) { continue; // next token is part of context, don't try to predict it } const int next_token = prompt[pos + 1]; loss += std::log(probs[pos * vocab_size + next_token]); } float scaling = -1.0 / std::log(2.0); return loss * scaling; } template typename LayerT> void ApplyForwardLayer(const LayerT& weights, ForwardLayer& activations, size_t num_tokens, float* HWY_RESTRICT output, hwy::ThreadPool& pool) { static constexpr size_t kModelDim = TConfig::kModelDim; static constexpr size_t kSeqLen = TConfig::kSeqLen; static constexpr size_t kQKVDim = TConfig::kQKVDim; static constexpr size_t kHeads = TConfig::kHeads; static const float kQueryScale = static_cast(1.0 / sqrt(static_cast(kQKVDim))); HWY_ASSERT(num_tokens <= kSeqLen); ApplyRMSNorm(weights.pre_attention_norm_scale.data(), activations.input.data(), kModelDim, num_tokens, activations.pre_att_rms_out.data(), pool); for (size_t pos = 0; pos < num_tokens; ++pos) { MatVec<(kHeads + 2) * kQKVDim, kModelDim>( weights.qkv_einsum_w, 0, activations.pre_att_rms_out.data() + pos * kModelDim, nullptr, activations.qkv.data() + pos * (kHeads + 2) * kQKVDim, pool); } const size_t num_tasks = kHeads * num_tokens; for (size_t pos = 0; pos < num_tokens; ++pos) { float* HWY_RESTRICT k = activations.qkv.data() + (pos * (kHeads + 2) + kHeads) * kQKVDim; Rope(k, kQKVDim, pos); } pool.Run(0, num_tasks, [&](const uint64_t task, size_t thread) HWY_ATTR { const size_t head = task % kHeads; const size_t pos = task / kHeads; float* HWY_RESTRICT q = activations.qkv.data() + (pos * (kHeads + 2) + head) * kQKVDim; Rope(q, kQKVDim, pos); MulByConst(kQueryScale, q, kQKVDim); }); pool.Run(0, num_tasks, [&](const uint64_t task, size_t thread) HWY_ATTR { const size_t head = task % kHeads; const size_t pos = task / kHeads; const float* HWY_RESTRICT q = activations.qkv.data() + (pos * (kHeads + 2) + head) * kQKVDim; float* HWY_RESTRICT head_att = activations.att.data() + (pos * kHeads + head) * kSeqLen; for (size_t pos2 = 0; pos2 <= pos; ++pos2) { const float* HWY_RESTRICT k2 = activations.qkv.data() + (pos2 * (kHeads + 2) + kHeads) * kQKVDim; const float score = Dot(q, k2, kQKVDim); head_att[pos2] = score; } }); pool.Run(0, num_tasks, [&](const uint64_t task, size_t thread) HWY_ATTR { const size_t head = task % kHeads; const size_t pos = task / kHeads; float* HWY_RESTRICT head_att = activations.att.data() + (pos * kHeads + head) * kSeqLen; Softmax(head_att, pos + 1); }); pool.Run(0, num_tasks, [&](const uint64_t task, size_t thread) HWY_ATTR { const size_t head = task % kHeads; const size_t pos = task / kHeads; const float* HWY_RESTRICT head_att = activations.att.data() + (pos * kHeads + head) * kSeqLen; float* HWY_RESTRICT att_out = activations.att_out.data() + (pos * kHeads + head) * kQKVDim; hwy::ZeroBytes(att_out, kQKVDim * sizeof(*att_out)); for (size_t pos2 = 0; pos2 <= pos; ++pos2) { float* HWY_RESTRICT v2 = activations.qkv.data() + (pos2 * (kHeads + 2) + kHeads + 1) * kQKVDim; MulByConstAndAdd(head_att[pos2], v2, att_out, kQKVDim); } }); hwy::ZeroBytes(activations.attention_out.data(), num_tokens * kModelDim * sizeof(activations.attention_out[0])); for (size_t pos = 0; pos < num_tokens; ++pos) { for (size_t head = 0; head < kHeads; ++head) { MatVec( weights.attn_vec_einsum_w, head * kModelDim * kQKVDim, activations.att_out.data() + pos * kHeads * kQKVDim + head * kQKVDim, nullptr, activations.att_post1.data() + pos * kModelDim, pool); AddFrom(activations.att_post1.data() + pos * kModelDim, activations.attention_out.data() + pos * kModelDim, kModelDim); } } for (size_t pos = 0; pos < num_tokens; ++pos) { AddFrom(activations.input.data() + pos * kModelDim, activations.attention_out.data() + pos * kModelDim, kModelDim); } ApplyRMSNorm(weights.pre_ffw_norm_scale.data(), activations.attention_out.data(), kModelDim, num_tokens, activations.bf_pre_ffw_rms_out.data(), pool); static constexpr size_t kFFHiddenDim = TConfig::kFFHiddenDim; for (size_t pos = 0; pos < num_tokens; ++pos) { MatVec( weights.gating_einsum_w, 0, activations.bf_pre_ffw_rms_out.data() + pos * kModelDim, nullptr, activations.ffw_hidden.data() + pos * kFFHiddenDim * 2, pool); } for (size_t pos = 0; pos < num_tokens; ++pos) { const size_t hidden_offset = pos * kFFHiddenDim * 2; const float* HWY_RESTRICT out = activations.ffw_hidden.data() + hidden_offset; const float* HWY_RESTRICT out_mul = out + kFFHiddenDim; float* HWY_RESTRICT out_gated = activations.ffw_hidden_gated.data() + pos * kFFHiddenDim; namespace hn = hwy::HWY_NAMESPACE; using DF = hn::ScalableTag; DF df; for (size_t i = 0; i < kFFHiddenDim; i += Lanes(df)) { const auto y = hn::Load(df, out + i); const auto x = hn::Load(df, out_mul + i); hn::Store(hn::Mul(x, Gelu(df, y)), df, out_gated + i); } } for (size_t pos = 0; pos < num_tokens; ++pos) { MatVec( weights.linear_w, 0, activations.ffw_hidden_gated.data() + pos * kFFHiddenDim, nullptr, output + pos * kModelDim, pool); } for (size_t pos = 0; pos < num_tokens; ++pos) { AddFrom(activations.attention_out.data() + pos * kModelDim, output + pos * kModelDim, kModelDim); } } template typename WeightsT, template typename LayerT> float CrossEntropyLossForwardPass(const std::vector& prompt, size_t context_size, const WeightsT& weights, ForwardPass& forward, hwy::ThreadPool& pool) { static constexpr size_t kVocabSize = TConfig::kVocabSize; static constexpr size_t kModelDim = TConfig::kModelDim; static constexpr size_t kLayers = TConfig::kLayers; const float kEmbScaling = EmbeddingScaling(); static_assert(!TConfig::kAbsolutePE); static_assert(!TConfig::kPostNormScale); static_assert(TConfig::kKVHeads == 1); HWY_DASSERT(context_size > 0); HWY_DASSERT(context_size < prompt.size()); const size_t num_tokens = prompt.size() - 1; InputEmbedding(weights.embedder_input_embedding, prompt, kEmbScaling, forward.layers[0].input.data(), kModelDim); for (size_t layer = 0; layer < kLayers; ++layer) { auto type = TConfig::kLayerConfig[layer]; // TODO(szabadka) Implement Griffin layer. HWY_ASSERT(type == LayerAttentionType::kGemma); float* HWY_RESTRICT output = layer + 1 < kLayers ? forward.layers[layer + 1].input.data() : forward.final_layer_output.data(); ApplyForwardLayer( *weights.GetLayer(layer), forward.layers[layer], num_tokens, output, pool); } ApplyRMSNorm(weights.final_norm_scale.data(), forward.final_layer_output.data(), kModelDim, num_tokens, forward.final_norm_output.data(), pool); for (size_t pos = 0; pos < num_tokens; ++pos) { MatVec( weights.embedder_input_embedding, 0, forward.final_norm_output.data() + pos * kModelDim, nullptr, forward.logits.data() + pos * kVocabSize, pool); } if constexpr (TConfig::kFinalCap > 0.0f) { for (size_t pos = 0; pos < num_tokens; ++pos) { LogitsSoftCap(TConfig::kFinalCap, forward.logits.data() + pos * kVocabSize, kVocabSize); } } hwy::CopyBytes(forward.logits.data(), forward.probs.data(), num_tokens * kVocabSize * sizeof(forward.logits[0])); for (size_t pos = 0; pos < num_tokens; ++pos) { Softmax(forward.probs.data() + pos * kVocabSize, kVocabSize); } return CrossEntropyLoss(forward.probs.data(), prompt, context_size, kVocabSize, pool); } // NOLINTNEXTLINE(google-readability-namespace-comments) } // namespace HWY_NAMESPACE } // namespace gcpp HWY_AFTER_NAMESPACE(); #endif // NOLINT