// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Implementation of the Vector-Jacobian Products (VJP) of the individual
// operations of the forward pass.

// Include guard for non-SIMD code.
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_INL_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_INL_H_

// NOTE(review): the header names of these system includes were lost in a
// bad merge; restored to the headers the code below actually uses
// (size_t, sqrt/std::log, std::vector).
#include <stddef.h>

#include <cmath>
#include <vector>

#include "backprop/activations.h"
#include "backprop/prompt.h"
#include "gemma/common.h"   // EmbeddingScaling
#include "gemma/configs.h"  // LayerConfig, ModelConfig
#include "gemma/weights.h"
#include "util/allocator.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"

#endif  // THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_INL_H_

// Include guard for (potentially) SIMD code.
#if defined(THIRD_PARTY_GEMMA_CPP_BACKWARD_TOGGLE) == defined(HWY_TARGET_TOGGLE) #ifdef THIRD_PARTY_GEMMA_CPP_BACKWARD_TOGGLE #undef THIRD_PARTY_GEMMA_CPP_BACKWARD_TOGGLE #else #define THIRD_PARTY_GEMMA_CPP_BACKWARD_TOGGLE #endif #include "hwy/highway.h" // After highway.h #include "ops/matmul-inl.h" #include "ops/ops-inl.h" #include "hwy/contrib/dot/dot-inl.h" HWY_BEFORE_NAMESPACE(); namespace gcpp { namespace HWY_NAMESPACE { namespace hn = hwy::HWY_NAMESPACE; HWY_INLINE void MatMulVJP(const float* HWY_RESTRICT weights, // kRows * kCols, const float* HWY_RESTRICT x, // num_tokens * kCols const float* HWY_RESTRICT v, // num_tokens * kRows size_t cols, size_t rows, size_t num_tokens, float* HWY_RESTRICT grad_w, // kRows * kCols, float* HWY_RESTRICT grad_x, // num_tokens * kCols hwy::ThreadPool& pool) { hwy::ZeroBytes(grad_x, num_tokens * cols * sizeof(grad_x[0])); for (size_t pos = 0; pos < num_tokens; ++pos) { const size_t voffs = pos * rows; const size_t xoffs = pos * cols; for (size_t j = 0; j < rows; ++j) { MulByConstAndAdd(v[voffs + j], &x[xoffs], &grad_w[j * cols], cols); MulByConstAndAdd(v[voffs + j], &weights[j * cols], &grad_x[xoffs], cols); } } } HWY_INLINE void MultiHeadMatMulVJP( const float* HWY_RESTRICT weights, // heads * kRows * kCols const float* HWY_RESTRICT x, // num_tokens * heads * kCols const float* HWY_RESTRICT v, // num_tokens * kRows size_t heads, size_t cols, size_t rows, size_t num_tokens, float* HWY_RESTRICT grad_w, // heads * kRows * kCols float* HWY_RESTRICT grad_x, // num_tokens * heads * kCols hwy::ThreadPool& pool) { hwy::ZeroBytes(grad_x, num_tokens * heads * cols * sizeof(grad_x[0])); for (size_t pos = 0; pos < num_tokens; ++pos) { for (size_t j = 0; j < rows; ++j) { for (size_t h = 0; h < heads; ++h) { MulByConstAndAdd(v[pos * rows + j], &x[pos * heads * cols + h * cols], &grad_w[h * rows * cols + j * cols], cols); MulByConstAndAdd(v[pos * rows + j], &weights[h * rows * cols + j * cols], &grad_x[pos * heads * cols + h * cols], 
cols); } } } } template static HWY_INLINE hn::Vec DGelu(D d, hn::Vec v) { const hn::Vec kMul = hn::Set(d, 0.044715f); const hn::Vec kSqrt2OverPi = hn::Set(d, 0.797884560804236f); const hn::Vec kHalf = hn::Set(d, 0.5f); const hn::Vec kOne = hn::Set(d, 1.0f); // kSqrtOverPi*3*kMul const hn::Vec kMulv2 = hn::Set(d, 0.1070322244f); const hn::Vec v2 = hn::Mul(v, v); const hn::Vec v3 = hn::Mul(v2, v); const hn::Vec arg = hn::Mul(kSqrt2OverPi, hn::MulAdd(kMul, v3, v)); const hn::Vec tanh = hn::Tanh(d, arg); const hn::Vec cdf = hn::MulAdd(kHalf, tanh, kHalf); const hn::Vec dtanh = hn::Sub(kOne, hn::Mul(tanh, tanh)); const hn::Vec darg = hn::MulAdd(kMulv2, v2, kSqrt2OverPi); return hn::MulAdd(kHalf, hn::Mul(v, hn::Mul(dtanh, darg)), cdf); } static HWY_NOINLINE void SoftmaxVJP(const float* HWY_RESTRICT forward, float* HWY_RESTRICT backward, const size_t size) { namespace hn = hwy::HWY_NAMESPACE; using D = hn::ScalableTag; const D d; const auto offset = hn::Set(d, hn::Dot::Compute<0>(d, forward, backward, size)); hn::Transform1( d, backward, size, forward, [&offset](const auto d, const auto v, const auto y) HWY_ATTR { return hn::Mul(y, hn::Sub(v, offset)); }); } static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNormVJP( const float* HWY_RESTRICT weights, const float* HWY_RESTRICT x, const float* HWY_RESTRICT v, size_t model_dim, size_t num_tokens, float* HWY_RESTRICT grad_w, float* HWY_RESTRICT grad_x, hwy::ThreadPool& pool) { for (size_t pos = 0; pos < num_tokens; ++pos) { const size_t offset = pos * model_dim; const float ss = detail::RMSNormMul(x + offset, model_dim); for (size_t i = 0; i < model_dim; ++i) { grad_w[i] += v[offset + i] * x[offset + i] * ss; } const float ss3 = ss * ss * ss / StaticCast(model_dim); float tmp = 0.0f; for (size_t i = 0; i < model_dim; ++i) { tmp += (1.0f + weights[i]) * v[offset + i] * x[offset + i]; } tmp *= ss3; for (size_t i = 0; i < model_dim; ++i) { grad_x[offset + i] = ss * (1.0f + weights[i]) * v[offset + i] - tmp * x[offset + i]; } } } 
static HWY_NOINLINE HWY_MAYBE_UNUSED void InputEmbeddingVJP( const float* weights, const std::vector& prompt, const float scaling, const float* HWY_RESTRICT v, float* HWY_RESTRICT grad, size_t model_dim) { HWY_ASSERT(!prompt.empty()); for (size_t pos = 0; pos < prompt.size() - 1; ++pos) { int token = prompt[pos]; MulByConstAndAdd(scaling, v + pos * model_dim, grad + token * model_dim, model_dim); } } template void LayerVJP(const LayerWeightsPtrs& weights, const ForwardLayer& forward, const float* HWY_RESTRICT next_layer_grad, size_t num_tokens, LayerWeightsPtrs& grad, ForwardLayer& backward, const RowVectorBatch& inv_timescale, hwy::ThreadPool& pool) { const LayerConfig& config = weights.layer_config; const size_t model_dim = config.model_dim; const size_t qkv_dim = config.qkv_dim; const size_t heads = config.heads; const size_t seq_len = forward.input.Rows(); const size_t ff_hidden_dim = config.ff_hidden_dim; const float query_scale = static_cast(1.0 / sqrt(static_cast(qkv_dim))); HWY_ASSERT(num_tokens <= seq_len); MatMulVJP(weights.linear_w.Packed(), forward.ffw_hidden_gated.Packed(), next_layer_grad, ff_hidden_dim, model_dim, num_tokens, grad.linear_w.Packed(), backward.ffw_hidden_gated.Packed(), pool); for (size_t pos = 0; pos < num_tokens; ++pos) { const size_t hidden_offset = pos * ff_hidden_dim * 2; const float* HWY_RESTRICT f_out = forward.ffw_hidden.Packed() + hidden_offset; const float* HWY_RESTRICT f_out_mul = f_out + ff_hidden_dim; const float* HWY_RESTRICT b_out_gated = backward.ffw_hidden_gated.Packed() + pos * ff_hidden_dim; float* HWY_RESTRICT b_out = backward.ffw_hidden.Packed() + hidden_offset; float* HWY_RESTRICT b_out_mul = b_out + ff_hidden_dim; namespace hn = hwy::HWY_NAMESPACE; using DF = hn::ScalableTag; DF df; for (size_t i = 0; i < ff_hidden_dim; i += Lanes(df)) { const auto y = Load(df, f_out + i); const auto x = Load(df, f_out_mul + i); const auto v = Load(df, b_out_gated + i); hn::Store(hn::Mul(v, Gelu(df, y)), df, b_out_mul + i); 
hn::Store(hn::Mul(v, hn::Mul(x, DGelu(df, y))), df, b_out + i); } } MatMulVJP(weights.gating_einsum_w.Packed(), forward.bf_pre_ffw_rms_out.Packed(), backward.ffw_hidden.Packed(), model_dim, ff_hidden_dim * 2, num_tokens, grad.gating_einsum_w.Packed(), backward.bf_pre_ffw_rms_out.Packed(), pool); RMSNormVJP( weights.pre_ffw_norm_scale.Packed(), forward.attention_out.Packed(), backward.bf_pre_ffw_rms_out.Packed(), model_dim, num_tokens, grad.pre_ffw_norm_scale.Packed(), backward.attention_out.Packed(), pool); for (size_t pos = 0; pos < num_tokens; ++pos) { AddFrom(next_layer_grad + pos * model_dim, backward.attention_out.Packed() + pos * model_dim, model_dim); } ZeroInit(backward.qkv); MultiHeadMatMulVJP( weights.attn_vec_einsum_w.Packed(), forward.att_out.Packed(), backward.attention_out.Packed(), heads, qkv_dim, model_dim, num_tokens, grad.attn_vec_einsum_w.Packed(), backward.att_out.Packed(), pool); for (size_t head = 0; head < heads; ++head) { for (size_t pos = 0; pos < num_tokens; ++pos) { const size_t aoffset = head * seq_len + pos * heads * seq_len; const float* HWY_RESTRICT f_head_att = forward.att.Packed() + aoffset; const float* HWY_RESTRICT b_att_out = backward.att_out.Packed() + (pos * heads + head) * qkv_dim; float* HWY_RESTRICT b_head_att = backward.att.Packed() + aoffset; for (size_t pos2 = 0; pos2 <= pos; ++pos2) { const size_t v2offs = (pos2 * (heads + 2) + heads + 1) * qkv_dim; const float* HWY_RESTRICT f_v2 = forward.qkv.Packed() + v2offs; float* HWY_RESTRICT b_v2 = backward.qkv.Packed() + v2offs; b_head_att[pos2] = Dot(b_att_out, f_v2, qkv_dim); MulByConstAndAdd(f_head_att[pos2], b_att_out, b_v2, qkv_dim); } } } for (size_t head = 0; head < heads; ++head) { for (size_t pos = 0; pos < num_tokens; ++pos) { const size_t aoffset = head * seq_len + pos * heads * seq_len; const float* HWY_RESTRICT f_head_att = forward.att.Packed() + aoffset; float* HWY_RESTRICT b_head_att = backward.att.Packed() + aoffset; SoftmaxVJP(f_head_att, b_head_att, pos + 1); } 
} for (size_t head = 0; head < heads; ++head) { for (size_t pos = 0; pos < num_tokens; ++pos) { const size_t qoffs = (pos * (heads + 2) + head) * qkv_dim; const size_t aoffs = head * seq_len + pos * heads * seq_len; const float* HWY_RESTRICT f_q = forward.qkv.Packed() + qoffs; const float* HWY_RESTRICT b_head_att = backward.att.Packed() + aoffs; float* HWY_RESTRICT b_q = backward.qkv.Packed() + qoffs; for (size_t pos2 = 0; pos2 <= pos; ++pos2) { const size_t k2offs = (pos2 * (heads + 2) + heads) * qkv_dim; const float* HWY_RESTRICT f_k2 = forward.qkv.Packed() + k2offs; float* HWY_RESTRICT b_k2 = backward.qkv.Packed() + k2offs; MulByConstAndAdd(b_head_att[pos2], f_k2, b_q, qkv_dim); MulByConstAndAdd(b_head_att[pos2], f_q, b_k2, qkv_dim); } } } for (int pos = 0; pos < static_cast(num_tokens); ++pos) { float* HWY_RESTRICT b_kv = backward.qkv.Packed() + (pos * (heads + 2) + heads) * qkv_dim; Rope(b_kv, qkv_dim, inv_timescale.Const(), -pos); } for (size_t head = 0; head < heads; ++head) { for (size_t pos = 0; pos < num_tokens; ++pos) { float* HWY_RESTRICT b_q = backward.qkv.Packed() + (pos * (heads + 2) + head) * qkv_dim; MulByConst(query_scale, b_q, qkv_dim); Rope(b_q, qkv_dim, inv_timescale.Const(), -pos); } } MatMulVJP(weights.qkv_einsum_w.Packed(), forward.pre_att_rms_out.Packed(), backward.qkv.Packed(), model_dim, (heads + 2) * qkv_dim, num_tokens, grad.qkv_einsum_w.Packed(), backward.pre_att_rms_out.Packed(), pool); RMSNormVJP(weights.pre_attention_norm_scale.Packed(), forward.input.Packed(), backward.pre_att_rms_out.Packed(), model_dim, num_tokens, grad.pre_attention_norm_scale.Packed(), backward.input.Packed(), pool); for (size_t pos = 0; pos < num_tokens; ++pos) { AddFrom(backward.attention_out.Packed() + pos * model_dim, backward.input.Packed() + pos * model_dim, model_dim); } } static HWY_NOINLINE void SoftcapVJP(const float cap, const float* HWY_RESTRICT forward, float* HWY_RESTRICT backward, const size_t size) { namespace hn = hwy::HWY_NAMESPACE; using D = 
hn::ScalableTag; const D d; const auto one = hn::Set(d, 1.0f); const auto vcap = hn::Set(d, cap); const auto vinv_cap = hn::Div(hn::Set(d, 1.0f), vcap); hn::Transform1(d, backward, size, forward, [&](const auto d, const auto v, const auto y) HWY_ATTR { const auto scaled = hn::Mul(vinv_cap, y); // = tanh return hn::Mul(v, hn::Sub(one, hn::Mul(scaled, scaled))); }); } static HWY_NOINLINE void CrossEntropyLossGrad( const float* HWY_RESTRICT x, float* HWY_RESTRICT grad, const Prompt& prompt, size_t vocab_size) { HWY_ASSERT(!prompt.tokens.empty()); const float scaling = -1.0 / std::log(2.0); size_t num_tokens = prompt.tokens.size() - 1; hwy::ZeroBytes(grad, num_tokens * vocab_size * sizeof(grad[0])); for (size_t pos = 0; pos < num_tokens; ++pos) { if (pos + 1 < prompt.context_size) { continue; } const int next_token = prompt.tokens[pos + 1]; grad[pos * vocab_size + next_token] = scaling / x[pos * vocab_size + next_token]; } } template void CrossEntropyLossBackwardPassInl(const Prompt& prompt, const ModelWeightsPtrs& weights, const ForwardPass& forward, ModelWeightsPtrs& grad, ForwardPass& backward, RowVectorBatch& inv_timescale, hwy::ThreadPool& pool) { const ModelConfig& config = weights.weights_config; const size_t kVocabSize = config.vocab_size; const size_t model_dim = config.model_dim; const size_t kLayers = config.layer_configs.size(); const float kEmbScaling = EmbeddingScaling(model_dim); HWY_ASSERT(!config.absolute_pe); HWY_ASSERT(config.layer_configs[0].post_norm == PostNormType::None); HWY_ASSERT(config.layer_configs[0].kv_heads == 1); HWY_DASSERT(prompt.context_size > 0); HWY_DASSERT(prompt.context_size < prompt.tokens.size()); const size_t num_tokens = prompt.tokens.size() - 1; CrossEntropyLossGrad(forward.probs.Packed(), backward.logits.Packed(), prompt, kVocabSize); for (size_t pos = 0; pos < num_tokens; ++pos) { SoftmaxVJP(forward.probs.Packed() + pos * kVocabSize, backward.logits.Packed() + pos * kVocabSize, kVocabSize); } if (config.final_cap > 0.0f) { 
for (size_t pos = 0; pos < num_tokens; ++pos) { SoftcapVJP(config.final_cap, forward.logits.Packed() + pos * kVocabSize, backward.logits.Packed() + pos * kVocabSize, kVocabSize); } } MatMulVJP(weights.embedder_input_embedding.Packed(), forward.final_norm_output.Packed(), backward.logits.Packed(), model_dim, kVocabSize, num_tokens, grad.embedder_input_embedding.Packed(), backward.final_norm_output.Packed(), pool); RMSNormVJP(weights.final_norm_scale.Packed(), forward.final_layer_output.Packed(), backward.final_norm_output.Packed(), model_dim, num_tokens, grad.final_norm_scale.Packed(), backward.final_layer_output.Packed(), pool); for (int layer = static_cast(kLayers) - 1; layer >= 0; --layer) { auto layer_config = config.layer_configs[layer]; // TODO(szabadka) Implement Griffin layer vjp. HWY_ASSERT(layer_config.type == LayerAttentionType::kGemma); float* next_layer_grad = layer + 1 < kLayers ? backward.layers[layer + 1].input.Packed() : backward.final_layer_output.Packed(); LayerVJP(*weights.GetLayer(layer), forward.layers[layer], next_layer_grad, num_tokens, *grad.GetLayer(layer), backward.layers[layer], inv_timescale, pool); } InputEmbeddingVJP(weights.embedder_input_embedding.Packed(), prompt.tokens, kEmbScaling, backward.layers[0].input.Packed(), grad.embedder_input_embedding.Packed(), model_dim); } // NOLINTNEXTLINE(google-readability-namespace-comments) } // namespace HWY_NAMESPACE } // namespace gcpp HWY_AFTER_NAMESPACE(); #endif // NOLINT