diff --git a/BUILD.bazel b/BUILD.bazel
index 59ffa69..96f9c12 100644
--- a/BUILD.bazel
+++ b/BUILD.bazel
@@ -59,8 +59,12 @@ cc_library(
         "gemma/gemma.cc",
     ],
     hdrs = [
+        "gemma/activations.h",
+        "gemma/common.h",
+        "gemma/common-inl.h",
         "gemma/configs.h",
         "gemma/gemma.h",
+        "gemma/weights.h",
     ],
     deps = [
         ":ops",
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 4a07305..dcc9087 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -46,10 +46,28 @@ set(SOURCES
   compression/sfp.h
   compression/sfp-inl.h
   compression/test_util.h
+  backprop/backward.cc
+  backprop/backward.h
+  backprop/backward-inl.h
+  backprop/backward_scalar.h
+  backprop/common_scalar.cc
+  backprop/common_scalar.h
+  backprop/forward.cc
+  backprop/forward.h
+  backprop/forward-inl.h
+  backprop/forward_scalar.h
+  backprop/optimizer.cc
+  backprop/optimizer.h
   gemma/configs.h
+  gemma/activations.cc
+  gemma/activations.h
+  gemma/common.h
+  gemma/common-inl.h
   gemma/gemma.cc
   gemma/gemma.h
   gemma/ops.h
+  gemma/weights.cc
+  gemma/weights.h
   util/app.h
   util/args.h
 )
@@ -100,7 +118,13 @@ target_link_libraries(debug_prompt libgemma hwy hwy_contrib nlohmann_json::nlohmann_json)
 
 set(GEMMA_ENABLE_TESTS OFF CACHE BOOL "Enable Gemma tests")
 if (GEMMA_ENABLE_TESTS)
+enable_testing()
+include(GoogleTest)
+
 set(GEMMA_TEST_FILES
+  backprop/backward_test.cc
+  backprop/backward_scalar_test.cc
+  backprop/optimize_test.cc
   gemma/ops_test.cc
   gemma/gemma_test.cc
 )
diff --git a/backprop/backward-inl.h b/backprop/backward-inl.h
new file mode 100644
index 0000000..6d11fea
--- /dev/null
+++ b/backprop/backward-inl.h
@@ -0,0 +1,417 @@
+// Copyright 2024 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Implementation of the Vector-Jacobian Products (VJP) of the individual
+// operations of the forward pass.
+
+// Include guard for non-SIMD code.
+#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_INL_H_
+#define THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_INL_H_
+
+#include <stddef.h>
+#include <stdint.h>
+
+#include <cmath>
+#include <vector>
+
+#include "backprop/prompt.h"
+#include "gemma/activations.h"
+#include "gemma/weights.h"
+
+#include "hwy/base.h"
+#include "hwy/contrib/thread_pool/thread_pool.h"
+
+#endif  // THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_INL_H_
+
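As the header comment says, every routine in this file is the vector-Jacobian product of one forward op. For a per-token matmul y = W·x, the VJPs are dW += v·xᵀ and dx = Wᵀ·v, which is exactly what MatMulVJP below accumulates. A minimal standalone sketch (hypothetical sizes, plain C++ rather than the Highway ops used in the patch) that checks the dW identity against finite differences:

    #include <cstdio>
    #include <vector>

    // Forward: y[r] = sum_c w[r*C + c] * x[c].
    // VJP:     dw[r*C + c] += v[r] * x[c], dx[c] = sum_r v[r] * w[r*C + c].
    int main() {
      const int R = 2, C = 3;
      std::vector<double> w = {0.1, -0.2, 0.3, 0.4, 0.5, -0.6};
      std::vector<double> x = {1.0, 2.0, -1.0}, v = {0.7, -0.3};  // v = dL/dy
      auto loss = [&](const std::vector<double>& wi) {
        double L = 0.0;
        for (int r = 0; r < R; ++r) {
          double y = 0.0;
          for (int c = 0; c < C; ++c) y += wi[r * C + c] * x[c];
          L += v[r] * y;  // L = <v, y>
        }
        return L;
      };
      const double h = 1e-6;
      for (int r = 0; r < R; ++r) {
        for (int c = 0; c < C; ++c) {
          std::vector<double> wp = w;
          wp[r * C + c] += h;
          const double fd = (loss(wp) - loss(w)) / h;  // finite difference
          const double vjp = v[r] * x[c];              // analytic VJP
          std::printf("dW[%d,%d]: fd=%.6f vjp=%.6f\n", r, c, fd, vjp);
        }
      }
    }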
+// Include guard for (potentially) SIMD code.
+#if defined(THIRD_PARTY_GEMMA_CPP_BACKWARD_TOGGLE) == defined(HWY_TARGET_TOGGLE)
+#ifdef THIRD_PARTY_GEMMA_CPP_BACKWARD_TOGGLE
+#undef THIRD_PARTY_GEMMA_CPP_BACKWARD_TOGGLE
+#else
+#define THIRD_PARTY_GEMMA_CPP_BACKWARD_TOGGLE
+#endif
+
+#include "gemma/common-inl.h"
+#include "gemma/ops.h"
+
+HWY_BEFORE_NAMESPACE();
+namespace gcpp {
+namespace HWY_NAMESPACE {
+namespace hn = hwy::HWY_NAMESPACE;
+
+template <size_t kCols, size_t kRows>
+void MatMulVJP(const std::array<float, kRows * kCols>& weights,
+               const float* HWY_RESTRICT x,  // num_tokens * kCols
+               const float* HWY_RESTRICT v,  // num_tokens * kRows
+               size_t num_tokens,
+               std::array<float, kRows * kCols>& grad_w,
+               float* HWY_RESTRICT grad_x,  // num_tokens * kCols
+               hwy::ThreadPool& pool) {
+  hwy::ZeroBytes(grad_x, num_tokens * kCols * sizeof(grad_x[0]));
+  for (size_t pos = 0; pos < num_tokens; ++pos) {
+    const size_t voffs = pos * kRows;
+    const size_t xoffs = pos * kCols;
+    for (size_t j = 0; j < kRows; ++j) {
+      MulByConstAndAdd(v[voffs + j], &x[xoffs], &grad_w[j * kCols], kCols);
+      MulByConstAndAdd(v[voffs + j], &weights[j * kCols], &grad_x[xoffs],
+                       kCols);
+    }
+  }
+}
+
+template <size_t kHeads, size_t kCols, size_t kRows>
+void MultiHeadMatMulVJP(
+    const std::array<float, kHeads * kRows * kCols>& weights,
+    const float* HWY_RESTRICT x,  // num_tokens * kHeads * kCols
+    const float* HWY_RESTRICT v,  // num_tokens * kRows
+    size_t num_tokens,
+    std::array<float, kHeads * kRows * kCols>& grad_w,
+    float* HWY_RESTRICT grad_x,  // num_tokens * kHeads * kCols
+    hwy::ThreadPool& pool) {
+  hwy::ZeroBytes(grad_x, num_tokens * kHeads * kCols * sizeof(grad_x[0]));
+  for (size_t pos = 0; pos < num_tokens; ++pos) {
+    for (size_t j = 0; j < kRows; ++j) {
+      for (size_t h = 0; h < kHeads; ++h) {
+        MulByConstAndAdd(v[pos * kRows + j],
+                         &x[pos * kHeads * kCols + h * kCols],
+                         &grad_w[h * kRows * kCols + j * kCols], kCols);
+        MulByConstAndAdd(v[pos * kRows + j],
+                         &weights[h * kRows * kCols + j * kCols],
+                         &grad_x[pos * kHeads * kCols + h * kCols], kCols);
+      }
+    }
+  }
+}
+
+template <class D>
+static HWY_INLINE hn::Vec<D> DGelu(D d, hn::Vec<D> v) {
+  const hn::Vec<D> kMul = hn::Set(d, 0.044715f);
+  const hn::Vec<D> kSqrt2OverPi = hn::Set(d, 0.797884560804236f);
+  const hn::Vec<D> kHalf = hn::Set(d, 0.5f);
+  const hn::Vec<D> kOne = hn::Set(d, 1.0f);
+  // kSqrt2OverPi * 3 * kMul
+  const hn::Vec<D> kMulv2 = hn::Set(d, 0.1070322244f);
+
+  const hn::Vec<D> v2 = hn::Mul(v, v);
+  const hn::Vec<D> v3 = hn::Mul(v2, v);
+  const hn::Vec<D> arg = hn::Mul(kSqrt2OverPi, hn::MulAdd(kMul, v3, v));
+  const hn::Vec<D> tanh = hn::Tanh(d, arg);
+  const hn::Vec<D> cdf = hn::MulAdd(kHalf, tanh, kHalf);
+  const hn::Vec<D> dtanh = hn::Sub(kOne, hn::Mul(tanh, tanh));
+  const hn::Vec<D> darg = hn::MulAdd(kMulv2, v2, kSqrt2OverPi);
+  return hn::MulAdd(kHalf, hn::Mul(v, hn::Mul(dtanh, darg)), cdf);
+}
+
+static HWY_NOINLINE void SoftmaxVJP(const float* HWY_RESTRICT forward,
+                                    float* HWY_RESTRICT backward,
+                                    const size_t size) {
+  namespace hn = hwy::HWY_NAMESPACE;
+  using D = hn::ScalableTag<float>;
+  const D d;
+
+  const auto offset =
+      hn::Set(d, hn::Dot::Compute<0>(d, forward, backward, size));
+  hn::Transform1(
+      d, backward, size, forward,
+      [&offset](const auto d, const auto v, const auto y)
+          HWY_ATTR { return hn::Mul(y, hn::Sub(v, offset)); });
+}
+
+static HWY_NOINLINE void RMSNormVJP(
+    const float* HWY_RESTRICT weights, const float* HWY_RESTRICT x,
+    const float* HWY_RESTRICT v, size_t model_dim, size_t num_tokens,
+    float* HWY_RESTRICT grad_w, float* HWY_RESTRICT grad_x,
+    hwy::ThreadPool& pool) {
+  for (size_t pos = 0; pos < num_tokens; ++pos) {
+    const size_t offset = pos * model_dim;
+    constexpr float eps = 1e-6f;
+    float ss = SquaredL2(x + offset, model_dim);
+    ss = 1.0f / sqrtf(ss / StaticCast<float>(model_dim) + eps);
+    for (size_t i = 0; i < model_dim; ++i) {
+      grad_w[i] += v[offset + i] * x[offset + i] * ss;
+    }
+    const float ss3 = ss * ss * ss / StaticCast<float>(model_dim);
+    float tmp = 0.0f;
+    for (size_t i = 0; i < model_dim; ++i) {
+      tmp += (1.0f + weights[i]) * v[offset + i] * x[offset + i];
+    }
+    tmp *= ss3;
+    for (size_t i = 0; i < model_dim; ++i) {
+      grad_x[offset + i] = ss * (1.0f + weights[i]) * v[offset + i] -
+                           tmp * x[offset + i];
+    }
+  }
+}
+
+static HWY_NOINLINE void InputEmbeddingVJP(
+    const float* weights, const std::vector<int>& prompt,
+    const float scaling, const float* HWY_RESTRICT v,
+    float* HWY_RESTRICT grad, size_t model_dim) {
+  HWY_ASSERT(!prompt.empty());
+  for (size_t pos = 0; pos < prompt.size() - 1; ++pos) {
+    int token = prompt[pos];
+    MulByConstAndAdd(scaling, v + pos * model_dim,
+                     grad + token * model_dim, model_dim);
+  }
+}
+
+template <typename TConfig>
+void LayerVJP(const Layer<float, TConfig>& weights,
+              const ForwardLayer<float, TConfig>& forward,
+              const float* HWY_RESTRICT next_layer_grad,
+              size_t num_tokens,
+              Layer<float, TConfig>& grad,
+              ForwardLayer<float, TConfig>& backward,
+              hwy::ThreadPool& pool) {
+  static constexpr size_t kModelDim = TConfig::kModelDim;
+  static constexpr size_t kQKVDim = TConfig::kQKVDim;
+  static constexpr size_t kHeads = TConfig::kHeads;
+  static constexpr size_t kSeqLen = TConfig::kSeqLen;
+  static constexpr size_t kFFHiddenDim = TConfig::kFFHiddenDim;
+  static const float kQueryScale =
+      static_cast<float>(1.0 / sqrt(static_cast<double>(kQKVDim)));
+  HWY_ASSERT(num_tokens <= kSeqLen);
+
+  MatMulVJP<kFFHiddenDim, kModelDim>(
+      weights.linear_w, forward.ffw_hidden_gated.data(), next_layer_grad,
+      num_tokens, grad.linear_w, backward.ffw_hidden_gated.data(),
+      pool);
+
+  for (size_t pos = 0; pos < num_tokens; ++pos) {
+    const size_t hidden_offset = pos * kFFHiddenDim * 2;
+    const float* HWY_RESTRICT f_out = forward.ffw_hidden.data() + hidden_offset;
+    const float* HWY_RESTRICT f_out_mul = f_out + kFFHiddenDim;
+    const float* HWY_RESTRICT b_out_gated =
+        backward.ffw_hidden_gated.data() + pos * kFFHiddenDim;
+    float* HWY_RESTRICT b_out = backward.ffw_hidden.data() + hidden_offset;
+    float* HWY_RESTRICT b_out_mul = b_out + kFFHiddenDim;
+    namespace hn = hwy::HWY_NAMESPACE;
+    using DF = hn::ScalableTag<float>;
+    using VF = hn::Vec<DF>;
+    DF df;
+    for (size_t i = 0; i < kFFHiddenDim; i += Lanes(df)) {
+      const auto y = Load(df, f_out + i);
+      const auto x = Load(df, f_out_mul + i);
+      const auto v = Load(df, b_out_gated + i);
+      hn::Store(hn::Mul(v, Gelu(df, y)), df, b_out_mul + i);
+      hn::Store(hn::Mul(v, hn::Mul(x, DGelu(df, y))), df, b_out + i);
+    }
+  }
+
+  MatMulVJP<kModelDim, kFFHiddenDim * 2>(
+      weights.gating_einsum_w,
+      forward.bf_pre_ffw_rms_out.data(), backward.ffw_hidden.data(),
+      num_tokens, grad.gating_einsum_w,
+      backward.bf_pre_ffw_rms_out.data(), pool);
+  RMSNormVJP(weights.pre_ffw_norm_scale.data(),
+             forward.attention_out.data(),
+             backward.bf_pre_ffw_rms_out.data(),
+             kModelDim, num_tokens,
+             grad.pre_ffw_norm_scale.data(),
+             backward.attention_out.data(), pool);
+
+  for (size_t pos = 0; pos < num_tokens; ++pos) {
+    AddFrom(next_layer_grad + pos * kModelDim,
+            backward.attention_out.data() + pos * kModelDim, kModelDim);
+  }
+
+  hwy::ZeroBytes(backward.qkv.data(),
+                 num_tokens * (kHeads + 2) * kQKVDim * sizeof(backward.qkv[0]));
+
+  MultiHeadMatMulVJP<kHeads, kQKVDim, kModelDim>(
+      weights.attn_vec_einsum_w, forward.att_out.data(),
+      backward.attention_out.data(), num_tokens,
+      grad.attn_vec_einsum_w, backward.att_out.data(), pool);
+
+  for (size_t head = 0; head < kHeads; ++head) {
+    for (size_t pos = 0; pos < num_tokens; ++pos) {
+      const size_t aoffset = head * kSeqLen + pos * kHeads * kSeqLen;
+      const float* HWY_RESTRICT f_head_att = forward.att.data() + aoffset;
+      const float* HWY_RESTRICT b_att_out =
+          backward.att_out.data() + (pos * kHeads + head) * kQKVDim;
+      float* HWY_RESTRICT b_head_att = backward.att.data() + aoffset;
+      for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
+        const size_t v2offs = (pos2 * (kHeads + 2) + kHeads + 1) * kQKVDim;
+        const float* HWY_RESTRICT f_v2 = forward.qkv.data() + v2offs;
+        float* HWY_RESTRICT b_v2 = backward.qkv.data() + v2offs;
+        b_head_att[pos2] = Dot(b_att_out, f_v2, kQKVDim);
+        MulByConstAndAdd(f_head_att[pos2], b_att_out, b_v2, kQKVDim);
+      }
+    }
+  }
+
+  for (size_t head = 0; head < kHeads; ++head) {
+    for (size_t pos = 0; pos < num_tokens; ++pos) {
+      const size_t aoffset = head * kSeqLen + pos * kHeads * kSeqLen;
+      const float* HWY_RESTRICT f_head_att = forward.att.data() + aoffset;
+      float* HWY_RESTRICT b_head_att = backward.att.data() + aoffset;
+      SoftmaxVJP(f_head_att, b_head_att, pos + 1);
+    }
+  }
+
+  for (size_t head = 0; head < kHeads; ++head) {
+    for (size_t pos = 0; pos < num_tokens; ++pos) {
+      const size_t qoffs = (pos * (kHeads + 2) + head) * kQKVDim;
+      const size_t aoffs = head * kSeqLen + pos * kHeads * kSeqLen;
+      const float* HWY_RESTRICT f_q = forward.qkv.data() + qoffs;
+      const float* HWY_RESTRICT b_head_att = backward.att.data() + aoffs;
+      float* HWY_RESTRICT b_q = backward.qkv.data() + qoffs;
+      for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
+        const size_t k2offs = (pos2 * (kHeads + 2) + kHeads) * kQKVDim;
+        const float* HWY_RESTRICT f_k2 = forward.qkv.data() + k2offs;
+        float* HWY_RESTRICT b_k2 = backward.qkv.data() + k2offs;
+        MulByConstAndAdd(b_head_att[pos2], f_k2, b_q, kQKVDim);
+        MulByConstAndAdd(b_head_att[pos2], f_q, b_k2, kQKVDim);
+      }
+    }
+  }
+
+  for (int pos = 0; pos < static_cast<int>(num_tokens); ++pos) {
+    float* HWY_RESTRICT b_kv =
+        backward.qkv.data() + (pos * (kHeads + 2) + kHeads) * kQKVDim;
+    Rope(b_kv, kQKVDim, -pos);
+  }
+
+  for (size_t head = 0; head < kHeads; ++head) {
+    for (size_t pos = 0; pos < num_tokens; ++pos) {
+      float* HWY_RESTRICT b_q =
+          backward.qkv.data() + (pos * (kHeads + 2) + head) * kQKVDim;
+      MulByConst(kQueryScale, b_q, kQKVDim);
+      Rope(b_q, kQKVDim, -pos);
+    }
+  }
+
+  MatMulVJP<kModelDim, (kHeads + 2) * kQKVDim>(
+      weights.qkv_einsum_w, forward.pre_att_rms_out.data(),
+      backward.qkv.data(), num_tokens,
+      grad.qkv_einsum_w, backward.pre_att_rms_out.data(), pool);
+  RMSNormVJP(weights.pre_attention_norm_scale.data(),
+             forward.input.data(),
+             backward.pre_att_rms_out.data(),
+             kModelDim, num_tokens,
+             grad.pre_attention_norm_scale.data(),
+             backward.input.data(), pool);
+  for (size_t pos = 0; pos < num_tokens; ++pos) {
+    AddFrom(backward.attention_out.data() + pos * kModelDim,
+            backward.input.data() + pos * kModelDim, kModelDim);
+  }
+}
+
+static HWY_NOINLINE void SoftcapVJP(const float* HWY_RESTRICT forward,
+                                    float* HWY_RESTRICT backward,
+                                    const float cap,
+                                    const size_t size) {
+  namespace hn = hwy::HWY_NAMESPACE;
+  using D = hn::ScalableTag<float>;
+  const D d;
+
+  const auto one = hn::Set(d, 1.0f);
+  const auto vcap = hn::Set(d, cap);
+  const auto vinv_cap = hn::Div(hn::Set(d, 1.0f), vcap);
+
+  hn::Transform1(
+      d, backward, size, forward,
+      [&](const auto d, const auto v, const auto y) HWY_ATTR {
+        const auto scaled = hn::Mul(vinv_cap, y);
+        return hn::Mul(v, hn::Sub(one, hn::Mul(scaled, scaled)));
+      });
+}
+
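SoftcapVJP relies on the derivative being recoverable from the forward output alone: for y = cap * tanh(x / cap), dy/dx = 1 - tanh^2(x / cap) = 1 - (y / cap)^2, so the backward pass never needs x. A tiny standalone check (plain C++, hypothetical values):

    #include <cmath>
    #include <cstdio>

    int main() {
      const double cap = 30.0, x = 12.3, h = 1e-6;
      const double y = cap * std::tanh(x / cap);
      const double analytic = 1.0 - (y / cap) * (y / cap);  // uses only y
      const double fd = (cap * std::tanh((x + h) / cap) - y) / h;
      std::printf("analytic=%.8f fd=%.8f\n", analytic, fd);
    }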
+static HWY_NOINLINE void CrossEntropyLossGrad(
+    const float* HWY_RESTRICT x, float* HWY_RESTRICT grad,
+    const Prompt& prompt, size_t vocab_size) {
+  HWY_ASSERT(!prompt.tokens.empty());
+  const float scaling = -1.0 / std::log(2.0);
+  size_t num_tokens = prompt.tokens.size() - 1;
+  hwy::ZeroBytes(grad, num_tokens * vocab_size * sizeof(grad[0]));
+  for (size_t pos = 0; pos < num_tokens; ++pos) {
+    if (pos + 1 < prompt.context_size) {
+      continue;
+    }
+    const int next_token = prompt.tokens[pos + 1];
+    grad[pos * vocab_size + next_token] =
+        scaling / x[pos * vocab_size + next_token];
+  }
+}
+
+template <typename TConfig>
+void CrossEntropyLossBackwardPass(const Prompt& prompt,
+                                  const Weights<float, TConfig>& weights,
+                                  const ForwardPass<float, TConfig>& forward,
+                                  Weights<float, TConfig>& grad,
+                                  ForwardPass<float, TConfig>& backward,
+                                  hwy::ThreadPool& pool) {
+  static constexpr size_t kVocabSize = TConfig::kVocabSize;
+  static constexpr size_t kModelDim = TConfig::kModelDim;
+  static constexpr size_t kSeqLen = TConfig::kSeqLen;
+  static constexpr size_t kLayers = TConfig::kLayers;
+  const float kEmbScaling = EmbeddingScaling<TConfig>();
+  static_assert(!TConfig::kAbsolutePE);
+  static_assert(!TConfig::kPostNormScale);
+  static_assert(TConfig::kKVHeads == 1);
+
+  HWY_DASSERT(prompt.context_size > 0);
+  HWY_DASSERT(prompt.context_size < prompt.tokens.size());
+  const size_t num_tokens = prompt.tokens.size() - 1;
+
+  CrossEntropyLossGrad(forward.probs.data(), backward.logits.data(), prompt,
+                       kVocabSize);
+
+  for (size_t pos = 0; pos < num_tokens; ++pos) {
+    SoftmaxVJP(forward.probs.data() + pos * kVocabSize,
+               backward.logits.data() + pos * kVocabSize,
+               kVocabSize);
+  }
+
+  for (size_t pos = 0; pos < num_tokens; ++pos) {
+    SoftcapVJP(forward.logits.data() + pos * kVocabSize,
+               backward.logits.data() + pos * kVocabSize, 30.0f, kVocabSize);
+  }
+
+  MatMulVJP<kModelDim, kVocabSize>(
+      weights.embedder_input_embedding, forward.final_norm_output.data(),
+      backward.logits.data(), num_tokens,
+      grad.embedder_input_embedding, backward.final_norm_output.data(),
+      pool);
+
+  RMSNormVJP(weights.final_norm_scale.data(),
+             forward.final_layer_output.data(),
+             backward.final_norm_output.data(),
+             kModelDim, num_tokens,
+             grad.final_norm_scale.data(),
+             backward.final_layer_output.data(), pool);
+
+  for (int layer = static_cast<int>(kLayers) - 1; layer >= 0; --layer) {
+    auto type = TConfig::kLayerConfig[layer];
+    // TODO(szabadka) Implement Griffin layer vjp.
+    HWY_ASSERT(type == LayerAttentionType::kGemma);
+    float* next_layer_grad = layer + 1 < kLayers
+                                 ? backward.layers[layer + 1].input.data()
+                                 : backward.final_layer_output.data();
+    LayerVJP(*weights.GetLayer(layer), forward.layers[layer], next_layer_grad,
+             num_tokens, *grad.GetLayer(layer), backward.layers[layer], pool);
+  }
+
+  InputEmbeddingVJP(weights.embedder_input_embedding.data(), prompt.tokens,
+                    kEmbScaling, backward.layers[0].input.data(),
+                    grad.embedder_input_embedding.data(), kModelDim);
+}
+
+// NOLINTNEXTLINE(google-readability-namespace-comments)
+}  // namespace HWY_NAMESPACE
+}  // namespace gcpp
+HWY_AFTER_NAMESPACE();
+
+#endif  // NOLINT
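The `scaling = -1.0 / std::log(2.0)` factor in CrossEntropyLossGrad corresponds to a loss measured in bits rather than nats: L = -log2 p = -ln p / ln 2, hence dL/dp = -1 / (p ln 2), which is exactly `scaling / x[...]`. A sketch of the correspondence (hypothetical probability value):

    #include <cmath>
    #include <cstdio>

    int main() {
      const double p = 0.25, h = 1e-8;
      const double analytic = (-1.0 / std::log(2.0)) / p;        // scaling / p
      const double fd = (-std::log2(p + h) + std::log2(p)) / h;  // d(-log2 p)/dp
      std::printf("analytic=%.6f fd=%.6f\n", analytic, fd);      // both ~ -5.77
    }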
diff --git a/backprop/backward.cc b/backprop/backward.cc
new file mode 100644
index 0000000..d6987a4
--- /dev/null
+++ b/backprop/backward.cc
@@ -0,0 +1,89 @@
+// Copyright 2024 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "backprop/backward.h"
+
+// Compiles this file for multiple architectures via "foreach_target.h", to
+// which we pass the filename via macro 'argument'.
+#undef HWY_TARGET_INCLUDE
+#define HWY_TARGET_INCLUDE "backprop/backward.cc"  // NOLINT
+#include "hwy/foreach_target.h"  // IWYU pragma: keep
+
+#include "hwy/highway.h"
+// After highway.h
+#include "backprop/backward-inl.h"
+#include "gemma/weights.h"
+
+HWY_BEFORE_NAMESPACE();
+namespace gcpp {
+namespace HWY_NAMESPACE {
+namespace hn = hwy::HWY_NAMESPACE;
+
+template <typename TConfig>
+void CrossEntropyLossBackwardPass(const Prompt& prompt,
+                                  const ByteStorageT& weights_u8,
+                                  const ByteStorageT& forward_u8,
+                                  ByteStorageT& grad_u8,
+                                  ByteStorageT& backward_u8,
+                                  hwy::ThreadPool& pool) {
+  using TWeights = WeightsF<TConfig>;
+  const auto& weights = *reinterpret_cast<const TWeights*>(weights_u8.get());
+  auto& grad = *reinterpret_cast<TWeights*>(grad_u8.get());
+  using TAct = ForwardPass<float, TConfig>;
+  const auto& forward = *reinterpret_cast<const TAct*>(forward_u8.get());
+  auto& backward = *reinterpret_cast<TAct*>(backward_u8.get());
+  CrossEntropyLossBackwardPass(prompt, weights, forward, grad, backward, pool);
+}
+
+void CrossEntropyLossBackwardPassT(Model model,
+                                   const Prompt& prompt,
+                                   const ByteStorageT& weights,
+                                   const ByteStorageT& forward,
+                                   ByteStorageT& grad,
+                                   ByteStorageT& backward,
+                                   hwy::ThreadPool& pool) {
+  switch (model) {
+    case Model::GEMMA_2B:
+      CrossEntropyLossBackwardPass<ConfigGemma2B>(
+          prompt, weights, forward, grad, backward, pool);
+      break;
+    case Model::GEMMA_TINY:
+      CrossEntropyLossBackwardPass<ConfigGemmaTiny>(
+          prompt, weights, forward, grad, backward, pool);
+      break;
+    default:
+      HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
+  }
+}
+
+}  // namespace HWY_NAMESPACE
+}  // namespace gcpp
+HWY_AFTER_NAMESPACE();
+
+#if HWY_ONCE
+namespace gcpp {
+
+HWY_EXPORT(CrossEntropyLossBackwardPassT);
+
+void CrossEntropyLossBackwardPass(
+    const Model& model, const Prompt& prompt,
+    const ByteStorageT& weights, const ByteStorageT& forward,
+    ByteStorageT& grad, ByteStorageT& backward, hwy::ThreadPool& pool) {
+  return HWY_DYNAMIC_DISPATCH(CrossEntropyLossBackwardPassT)(
+      model, prompt, weights, forward, grad, backward, pool);
+}
+
+}  // namespace gcpp
+#endif  // HWY_ONCE
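The HWY_EXPORT / HWY_DYNAMIC_DISPATCH pair above is Highway's standard runtime-dispatch pattern: foreach_target.h recompiles the file once per SIMD target, each copy lands in its own HWY_NAMESPACE, and the exported table picks the best implementation at run time. A minimal sketch of the same pattern in isolation (the function Square and namespace demo are hypothetical; the macros are real Highway API):

    // square.cc - minimal Highway runtime-dispatch skeleton (illustrative only).
    #undef HWY_TARGET_INCLUDE
    #define HWY_TARGET_INCLUDE "square.cc"
    #include "hwy/foreach_target.h"  // re-includes this file once per target
    #include "hwy/highway.h"

    HWY_BEFORE_NAMESPACE();
    namespace demo {
    namespace HWY_NAMESPACE {  // distinct namespace per SIMD target
    namespace hn = hwy::HWY_NAMESPACE;

    // Squares n floats in place; n is assumed a multiple of the vector size.
    void Square(float* data, size_t n) {
      const hn::ScalableTag<float> d;
      for (size_t i = 0; i < n; i += hn::Lanes(d)) {
        const auto v = hn::LoadU(d, data + i);
        hn::StoreU(hn::Mul(v, v), d, data + i);
      }
    }

    }  // namespace HWY_NAMESPACE
    }  // namespace demo
    HWY_AFTER_NAMESPACE();

    #if HWY_ONCE  // compiled exactly once, after all targets
    namespace demo {
    HWY_EXPORT(Square);  // table of per-target implementations

    void CallSquare(float* data, size_t n) {
      HWY_DYNAMIC_DISPATCH(Square)(data, n);  // best target chosen at run time
    }
    }  // namespace demo
    #endif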
diff --git a/backprop/backward.h b/backprop/backward.h
new file mode 100644
index 0000000..6917f20
--- /dev/null
+++ b/backprop/backward.h
@@ -0,0 +1,34 @@
+// Copyright 2024 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_H_
+#define THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_H_
+
+#include <stddef.h>
+
+#include "backprop/prompt.h"
+#include "gemma/common.h"
+#include "hwy/contrib/thread_pool/thread_pool.h"
+
+namespace gcpp {
+
+void CrossEntropyLossBackwardPass(
+    const Model& model, const Prompt& prompt,
+    const ByteStorageT& weights, const ByteStorageT& forward,
+    ByteStorageT& grad, ByteStorageT& backward, hwy::ThreadPool& pool);
+
+}  // namespace gcpp
+
+#endif  // THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_H_
diff --git a/backprop/backward_scalar.h b/backprop/backward_scalar.h
new file mode 100644
index 0000000..1c9fbbc
--- /dev/null
+++ b/backprop/backward_scalar.h
@@ -0,0 +1,353 @@
+// Copyright 2024 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_SCALAR_H_
+#define THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_SCALAR_H_
+
+#include <stddef.h>
+#include <string.h>
+
+#include <cmath>
+#include <complex>
+#include <vector>
+
+#include "backprop/common_scalar.h"
+#include "backprop/prompt.h"
+#include "gemma/activations.h"
+#include "gemma/weights.h"
+
+namespace gcpp {
+template <typename T>
+void MatMulVJPT(const T* w, const T* x, const T* dy, T* dw, T* dx,
+                size_t N, size_t M, size_t K) {
+  memset(dx, 0, M * K * sizeof(dx[0]));
+  for (size_t i = 0; i < K; ++i) {
+    for (size_t j = 0; j < N; ++j) {
+      MulByConstAndAddT(dy[i * N + j], &x[i * M], &dw[j * M], M);
+      MulByConstAndAddT(dy[i * N + j], &w[j * M], &dx[i * M], M);
+    }
+  }
+}
+template <typename T>
+void MultiHeadMatMulVJPT(const T* w, const T* x, const T* dy, T* dw, T* dx,
+                         size_t H, size_t N, size_t M, size_t K) {
+  memset(dx, 0, H * M * K * sizeof(dx[0]));
+  for (size_t i = 0; i < K; ++i) {
+    for (size_t j = 0; j < N; ++j) {
+      for (size_t h = 0; h < H; ++h) {
+        MulByConstAndAddT(dy[i * N + j], &x[i * H * M + h * M],
+                          &dw[h * N * M + j * M], M);
+        MulByConstAndAddT(dy[i * N + j], &w[h * N * M + j * M],
+                          &dx[i * H * M + h * M], M);
+      }
+    }
+  }
+}
+
+template <typename T>
+void RMSNormVJPT(const T* w, const T* x, const T* dy, T* dw, T* dx,
+                 size_t N, size_t K) {
+  for (size_t i = 0; i < K; ++i) {
+    const size_t offset = i * N;
+    constexpr T eps(1e-6);
+    T ss = SquaredL2(x + offset, N);
+    ss = T(1.0) / std::sqrt(ss / T(N) + eps);
+    for (size_t j = 0; j < N; ++j) {
+      dw[j] += dy[i * N + j] * x[i * N + j] * ss;
+    }
+    const T ss3 = ss * ss * ss / T(N);
+    T tmp = 0.0;
+    for (size_t j = 0; j < N; ++j) {
+      tmp += (T(1.0) + w[j]) * dy[i * N + j] * x[i * N + j];
+    }
+    tmp *= ss3;
+    for (size_t j = 0; j < N; ++j) {
+      dx[i * N + j] = ss * (T(1.0) + w[j]) * dy[i * N + j] - tmp * x[i * N + j];
+    }
+  }
+}
+template <typename T>
+void SoftmaxVJPT(const T* y, T* dy, size_t N) {
+  T sum = {};
+  for (size_t i = 0; i < N; ++i) {
+    sum += y[i] * dy[i];
+  }
+  for (size_t i = 0; i < N; ++i) {
+    dy[i] = y[i] * (dy[i] - sum);
+  }
+}
+template <typename T>
+void SoftmaxVJPT(const T* y, T* dy, size_t N, size_t K) {
+  for (size_t i = 0; i < K; ++i) {
+    SoftmaxVJPT(y + i * N, dy + i * N, N);
+  }
+}
+
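SoftmaxVJPT above implements the standard softmax VJP: with y = softmax(x), the Jacobian is diag(y) - y yᵀ, so dx = y ⊙ (dy - ⟨y, dy⟩). A small self-contained check (hypothetical 3-way distribution):

    #include <cmath>
    #include <cstdio>

    int main() {
      // y = softmax(x) for x = {0, 1, 2}; dy is an arbitrary upstream gradient.
      double x[3] = {0.0, 1.0, 2.0}, y[3], dy[3] = {0.3, -0.1, 0.4};
      double z = 0.0;
      for (int i = 0; i < 3; ++i) z += std::exp(x[i]);
      for (int i = 0; i < 3; ++i) y[i] = std::exp(x[i]) / z;
      double dot = 0.0;  // <y, dy>
      for (int i = 0; i < 3; ++i) dot += y[i] * dy[i];
      for (int i = 0; i < 3; ++i) {
        const double dx = y[i] * (dy[i] - dot);  // VJP, as in SoftmaxVJPT
        // Finite-difference check of dL/dx[i] with L = <dy, softmax(x)>.
        const double h = 1e-7;
        double zp = 0.0, Lp = 0.0, L = 0.0;
        x[i] += h;
        for (int j = 0; j < 3; ++j) zp += std::exp(x[j]);
        for (int j = 0; j < 3; ++j) Lp += dy[j] * std::exp(x[j]) / zp;
        x[i] -= h;
        for (int j = 0; j < 3; ++j) L += dy[j] * y[j];
        std::printf("dx[%d]: analytic=%.7f fd=%.7f\n", i, dx, (Lp - L) / h);
      }
    }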
+template <typename T>
+T GeluDerivative(T x) {
+  static const T kMul = 0.044715;
+  static const T kSqrt2OverPi = 0.797884560804236;
+  static const T kMul2 = kSqrt2OverPi * T(3.0) * kMul;
+
+  const T x2 = x * x;
+  const T x3 = x2 * x;
+  const T arg = kSqrt2OverPi * (kMul * x3 + x);
+  const T tanh = std::tanh(arg);
+  const T cdf = T(0.5) * (T(1.0) + tanh);
+  const T dtanh = T(1.0) - tanh * tanh;
+  const T darg = kMul2 * x2 + kSqrt2OverPi;
+  return T(0.5) * x * dtanh * darg + cdf;
+}
+
+template <typename T>
+void GatedGeluVJP(const T* in, const T* d_out, T* d_in, size_t N, size_t K) {
+  for (size_t i = 0; i < K; ++i) {
+    const T* x1 = in + i * 2 * N;
+    const T* x2 = x1 + N;
+    const T* v = d_out + i * N;
+    T* dx1 = d_in + i * 2 * N;
+    T* dx2 = dx1 + N;
+    for (size_t j = 0; j < N; ++j) {
+      dx1[j] = v[j] * x2[j] * GeluDerivative(x1[j]);
+      dx2[j] = v[j] * Gelu(x1[j]);
+    }
+  }
+}
+
+template <typename T>
+void MaskedAttentionVJP(const T* qkv, const T* doutput, T* dqkv,
+                        size_t num_tokens, size_t kHeads, size_t kQKVDim,
+                        size_t kSeqLen) {
+  for (size_t pos = 0; pos < num_tokens; ++pos) {
+    const size_t offset = pos * (kHeads + 2) * kQKVDim;
+    memset(dqkv + offset, 0, (kHeads + 1) * kQKVDim * sizeof(qkv[0]));
+  }
+  for (size_t head = 0; head < kHeads; ++head) {
+    for (size_t pos = 0; pos < num_tokens; ++pos) {
+      const size_t qoffs = (pos * (kHeads + 2) + head) * kQKVDim;
+      const size_t aoffs = head * kSeqLen + pos * kHeads * kSeqLen;
+      const T* q = qkv + qoffs;
+      const T* dout = doutput + aoffs;
+      T* dq = dqkv + qoffs;
+      for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
+        const size_t koffs = (pos2 * (kHeads + 2) + kHeads) * kQKVDim;
+        const T* k = qkv + koffs;
+        T* dk = dqkv + koffs;
+        MulByConstAndAddT(dout[pos2], k, dq, kQKVDim);
+        MulByConstAndAddT(dout[pos2], q, dk, kQKVDim);
+      }
+    }
+  }
+}
+
+template <typename T>
+void MaskedSoftmaxVJPT(const T* y, T* dy, size_t num_tokens,
+                       size_t kHeads, size_t kSeqLen) {
+  for (size_t head = 0; head < kHeads; ++head) {
+    for (size_t pos = 0; pos < num_tokens; ++pos) {
+      size_t offset = pos * kHeads * kSeqLen + head * kSeqLen;
+      SoftmaxVJPT(y + offset, dy + offset, pos + 1);
+      memset(dy + offset + pos + 1, 0, (kSeqLen - pos - 1) * sizeof(T));
+    }
+  }
+}
+
+template <typename T>
+void MixByAttentionVJP(const T* qkv, const T* attention, const T* doutput,
+                       T* dqkv, T* dattention, size_t num_tokens,
+                       size_t kHeads, size_t kQKVDim, size_t kSeqLen) {
+  auto v_offset = [&](size_t pos) {
+    return (pos * (kHeads + 2) + kHeads + 1) * kQKVDim;
+  };
+  for (size_t pos = 0; pos < num_tokens; ++pos) {
+    memset(&dqkv[v_offset(pos)], 0, kQKVDim * sizeof(qkv[0]));
+  }
+  for (size_t head = 0; head < kHeads; ++head) {
+    for (size_t pos = 0; pos < num_tokens; ++pos) {
+      const size_t offset = head * kQKVDim + pos * kHeads * kQKVDim;
+      const size_t aoffset = head * kSeqLen + pos * kHeads * kSeqLen;
+      const T* att = &attention[aoffset];
+      const T* dout = &doutput[offset];
+      T* datt = &dattention[aoffset];
+      for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
+        datt[pos2] = DotT(dout, &qkv[v_offset(pos2)], kQKVDim);
+        MulByConstAndAddT(att[pos2], dout, &dqkv[v_offset(pos2)], kQKVDim);
+      }
+    }
+  }
+}
+
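The (kHeads + 2) stride that recurs in these attention helpers reflects the patch's multi-query attention layout (kKVHeads == 1): each token stores kHeads query vectors followed by one shared key and one shared value, each kQKVDim wide. A compile-time sketch of the offsets (names and sizes hypothetical):

    #include <cstddef>

    // Per-token QKV block layout assumed by MaskedAttentionVJP / MixByAttentionVJP:
    // [q_0 .. q_{kHeads-1} | k | v], each slice kQKVDim floats wide.
    constexpr size_t kHeads = 4, kQKVDim = 8;
    constexpr size_t QOffset(size_t pos, size_t head) {
      return (pos * (kHeads + 2) + head) * kQKVDim;
    }
    constexpr size_t KOffset(size_t pos) {
      return (pos * (kHeads + 2) + kHeads) * kQKVDim;
    }
    constexpr size_t VOffset(size_t pos) {
      return (pos * (kHeads + 2) + kHeads + 1) * kQKVDim;
    }

    static_assert(QOffset(0, 0) == 0);
    static_assert(KOffset(0) == kHeads * kQKVDim);         // key follows the queries
    static_assert(VOffset(0) == (kHeads + 1) * kQKVDim);   // value follows the key
    static_assert(QOffset(1, 0) == (kHeads + 2) * kQKVDim);  // next token's block

    int main() { return 0; }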
+template <typename T>
+void InputEmbeddingVJPT(const T* w, const std::vector<int>& tokens, T scaling,
+                        const T* dy, T* dw, size_t N) {
+  const size_t num_tokens = tokens.empty() ? 0 : tokens.size() - 1;
+  for (size_t i = 0; i < num_tokens; ++i) {
+    int token = tokens[i];
+    MulByConstAndAddT(scaling, dy + i * N, dw + token * N, N);
+  }
+}
+
+template <typename T, typename TConfig>
+void LayerVJP(const Layer<T, TConfig>& weights,
+              const ForwardLayer<T, TConfig>& forward,
+              const T* dy,
+              Layer<T, TConfig>& grad,
+              ForwardLayer<T, TConfig>& backward,
+              size_t num_tokens) {
+  static constexpr size_t kModelDim = TConfig::kModelDim;
+  static constexpr size_t kSeqLen = TConfig::kSeqLen;
+  static constexpr size_t kQKVDim = TConfig::kQKVDim;
+  static constexpr size_t kHeads = TConfig::kHeads;
+  static constexpr size_t kFFHiddenDim = TConfig::kFFHiddenDim;
+  static const T kQueryScale = 1.0 / std::sqrt(T(kQKVDim));
+
+  MatMulVJPT(weights.linear_w.data(), forward.ffw_hidden_gated.data(),
+             dy, grad.linear_w.data(), backward.ffw_hidden_gated.data(),
+             kModelDim, kFFHiddenDim, num_tokens);
+
+  GatedGeluVJP(forward.ffw_hidden.data(), backward.ffw_hidden_gated.data(),
+               backward.ffw_hidden.data(), kFFHiddenDim, num_tokens);
+
+  MatMulVJPT(weights.gating_einsum_w.data(), forward.bf_pre_ffw_rms_out.data(),
+             backward.ffw_hidden.data(), grad.gating_einsum_w.data(),
+             backward.bf_pre_ffw_rms_out.data(), kFFHiddenDim * 2, kModelDim,
+             num_tokens);
+
+  RMSNormVJPT(weights.pre_ffw_norm_scale.data(), forward.attention_out.data(),
+              backward.bf_pre_ffw_rms_out.data(),
+              grad.pre_ffw_norm_scale.data(), backward.attention_out.data(),
+              kModelDim, num_tokens);
+
+  AddFromT(dy, backward.attention_out.data(), num_tokens * kModelDim);
+
+  MultiHeadMatMulVJPT(weights.attn_vec_einsum_w.data(), forward.att_out.data(),
+                      backward.attention_out.data(),
+                      grad.attn_vec_einsum_w.data(),
+                      backward.att_out.data(),
+                      kHeads, kModelDim, kQKVDim, num_tokens);
+
+  MixByAttentionVJP(forward.qkv.data(), forward.att.data(),
+                    backward.att_out.data(), backward.qkv.data(),
+                    backward.att.data(), num_tokens, kHeads, kQKVDim,
+                    kSeqLen);
+
+  MaskedSoftmaxVJPT(forward.att.data(), backward.att.data(),
+                    num_tokens, kHeads, kSeqLen);
+
+  MaskedAttentionVJP(forward.qkv.data(), backward.att.data(),
+                     backward.qkv.data(), num_tokens, kHeads, kQKVDim, kSeqLen);
+
+  for (size_t pos = 0; pos < num_tokens; ++pos) {
+    T* qkv = backward.qkv.data() + pos * (kHeads + 2) * kQKVDim;
+    MulByConstT(kQueryScale, qkv, kHeads * kQKVDim);
+  }
+
+  for (int pos = 0; pos < static_cast<int>(num_tokens); ++pos) {
+    T* qkv = backward.qkv.data() + pos * (kHeads + 2) * kQKVDim;
+    for (size_t h = 0; h <= kHeads; ++h) {  // all queries plus the key
+      Rope(qkv + h * kQKVDim, kQKVDim, -pos);
+    }
+  }
+
+  MatMulVJPT(weights.qkv_einsum_w.data(), forward.pre_att_rms_out.data(),
+             backward.qkv.data(), grad.qkv_einsum_w.data(),
+             backward.pre_att_rms_out.data(),
+             (kHeads + 2) * kQKVDim, kModelDim, num_tokens);
+  RMSNormVJPT(weights.pre_attention_norm_scale.data(), forward.input.data(),
+              backward.pre_att_rms_out.data(),
+              grad.pre_attention_norm_scale.data(),
+              backward.input.data(), kModelDim, num_tokens);
+
+  AddFromT(backward.attention_out.data(), backward.input.data(),
+           num_tokens * kModelDim);
+}
+
+template <typename T>
+void SoftcapVJPT(const T* y, T* dy, size_t N) {
+  T cap = 30.0;
+  T inv_cap = T(1.0) / cap;
+  for (size_t i = 0; i < N; ++i) {
+    T scaled = y[i] * inv_cap;
+    dy[i] *= (T(1.0) - scaled * scaled);
+  }
+}
+
+template <typename T>
+void CrossEntropyLossGrad(const T* x, T* dx, const Prompt& prompt, size_t V) {
+  T scaling = -1.0 / std::log(2.0);
+  const std::vector<int> tokens = prompt.tokens;
+  const size_t num_tokens = tokens.empty() ? 0 : tokens.size() - 1;
+  memset(dx, 0, V * num_tokens * sizeof(x[0]));
+  for (size_t i = 0; i < num_tokens; ++i) {
+    if (i + 1 < prompt.context_size) {
+      continue;
+    }
+    const int next_token = tokens[i + 1];
+    dx[i * V + next_token] = scaling / x[i * V + next_token];
+  }
+}
+
+template <typename T, typename TConfig>
+void CrossEntropyLossBackwardPass(const Prompt& prompt,
+                                  const Weights<T, TConfig>& weights,
+                                  const ForwardPass<T, TConfig>& forward,
+                                  Weights<T, TConfig>& grad,
+                                  ForwardPass<T, TConfig>& backward) {
+  static constexpr size_t kModelDim = TConfig::kModelDim;
+  static constexpr size_t kVocabSize = TConfig::kVocabSize;
+  static constexpr size_t kLayers = TConfig::kLayers;
+  const std::vector<int> tokens = prompt.tokens;
+  const size_t num_tokens = tokens.empty() ? 0 : tokens.size() - 1;
+
+  CrossEntropyLossGrad(forward.probs.data(), backward.logits.data(), prompt,
+                       kVocabSize);
+
+  SoftmaxVJPT(forward.probs.data(), backward.logits.data(),
+              kVocabSize, num_tokens);
+
+  SoftcapVJPT(forward.logits.data(), backward.logits.data(),
+              num_tokens * kVocabSize);
+
+  MatMulVJPT(weights.embedder_input_embedding.data(),
+             forward.final_norm_output.data(),
+             backward.logits.data(),
+             grad.embedder_input_embedding.data(),
+             backward.final_norm_output.data(),
+             kVocabSize, kModelDim, num_tokens);
+
+  RMSNormVJPT(weights.final_norm_scale.data(),
+              forward.final_layer_output.data(),
+              backward.final_norm_output.data(),
+              grad.final_norm_scale.data(),
+              backward.final_layer_output.data(), kModelDim, num_tokens);
+
+  for (int layer = static_cast<int>(kLayers) - 1; layer >= 0; --layer) {
+    T* next_layer_grad = layer + 1 < kLayers
+                             ? backward.layers[layer + 1].input.data()
+                             : backward.final_layer_output.data();
+    LayerVJP(*weights.GetLayer(layer), forward.layers[layer], next_layer_grad,
+             *grad.GetLayer(layer), backward.layers[layer], num_tokens);
+  }
+
+  const T kEmbScaling = EmbeddingScaling(kModelDim);
+  InputEmbeddingVJPT(weights.embedder_input_embedding.data(),
+                     tokens, kEmbScaling, backward.layers[0].input.data(),
+                     grad.embedder_input_embedding.data(), kModelDim);
+}
+
+}  // namespace gcpp
+
+#endif  // THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_SCALAR_H_
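The tests below validate these scalar kernels with complex-step differentiation: Complexify promotes inputs to std::complex, and TestGradient compares each analytic gradient entry against Im f(x + ih) / h, which is accurate to near machine precision because it involves no subtractive cancellation. The idea in miniature (hypothetical f; the real helpers live in backprop/test_util.h):

    #include <complex>
    #include <cstdio>

    int main() {
      // f(x) = x * tanh(x); f'(x) = tanh(x) + x * (1 - tanh(x)^2).
      auto f = [](std::complex<double> x) { return x * std::tanh(x); };
      const double x = 0.7, h = 1e-50;  // step can be tiny: no cancellation
      const double grad = std::imag(f({x, h})) / h;  // complex-step derivative
      const double t = std::tanh(x);
      std::printf("complex-step=%.15f analytic=%.15f\n", grad,
                  t + x * (1 - t * t));
    }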
+ +#include "backprop/backward_scalar.h" + +#include +#include +#include + +#include "backprop/forward_scalar.h" +#include "backprop/sampler.h" +#include "backprop/test_util.h" +#include "gtest/gtest.h" + +namespace gcpp { + +TEST(BackPropTest, MatMulVJP) { + static const size_t kRows = 8; + static const size_t kCols = 64; + static const size_t kTokens = 5; + std::mt19937 gen(42); + using T = double; + using TC = std::complex; + std::array weights; + std::array x; + std::array grad; + std::array dx; + std::array c_weights; + std::array c_x; + std::array c_y; + std::array dy; + + for (int iter = 0; iter < 10; ++iter) { + RandInit(weights, 1.0 * (1 << iter), gen); + RandInit(x, 1.0 * (1 << iter), gen); + RandInit(dy, 1.0, gen); + Complexify(weights, c_weights); + Complexify(x, c_x); + auto func = [&]() { + MatMulT(c_weights.data(), c_x.data(), c_y.data(), kRows, kCols, kTokens); + return DotT(dy.data(), c_y.data(), kTokens * kRows); + }; + memset(&grad, 0, sizeof(grad)); + MatMulVJPT(weights.data(), x.data(), dy.data(), grad.data(), dx.data(), + kRows, kCols, kTokens); + TestGradient(dx, c_x, func, 1e-11, 1e-12,__LINE__); + TestGradient(grad, c_weights, func, 1e-14, 1e-12,__LINE__); + } +} + +TEST(BackPropTest, MultiHeadMatMulVJP) { + static const size_t kRows = 2; + static const size_t kCols = 16; + static const size_t kHeads = 4; + static const size_t kTokens = 3; + std::mt19937 gen(42); + using T = double; + using TC = std::complex; + std::array weights; + std::array x; + std::array grad; + std::array dx; + std::array c_weights; + std::array c_x; + std::array c_y; + std::array dy; + + for (int iter = 0; iter < 10; ++iter) { + RandInit(weights, 1.0 * (1 << iter), gen); + RandInit(x, 1.0 * (1 << iter), gen); + RandInit(dy, 1.0, gen); + Complexify(weights, c_weights); + Complexify(x, c_x); + auto func = [&]() { + MultiHeadMatMul(c_weights.data(), c_x.data(), c_y.data(), kHeads, kRows, + kCols, kTokens); + return DotT(dy.data(), c_y.data(), kTokens * kRows); + }; + memset(&grad, 0, sizeof(grad)); + MultiHeadMatMulVJPT(weights.data(), x.data(), dy.data(), grad.data(), + dx.data(), kHeads, kRows, kCols, kTokens); + TestGradient(dx, c_x, func, 1e-15, 1e-13,__LINE__); + TestGradient(grad, c_weights, func, 1e-15, 1e-13,__LINE__); + } +} + +TEST(BackPropTest, RMSNormVJP) { + static const size_t K = 2; + static const size_t N = 64; + std::mt19937 gen(42); + using T = double; + using TC = std::complex; + std::array weights; + std::array grad; + std::array x; + std::array dx; + std::array dy; + std::array c_weights; + std::array c_x; + std::array c_y; + + for (int iter = 0; iter < 10; ++iter) { + RandInit(weights, 1.0 * (1 << iter), gen); + RandInit(x, 1.0 * (1 << iter), gen); + Complexify(weights, c_weights); + Complexify(x, c_x); + RandInit(dy, 1.0, gen); + auto func = [&]() { + RMSNormT(c_weights.data(), c_x.data(), c_y.data(), N, K); + return DotT(dy.data(), c_y.data(), K * N); + }; + memset(&grad, 0, sizeof(grad)); + RMSNormVJPT(weights.data(), x.data(), dy.data(), grad.data(), dx.data(), + N, K); + TestGradient(dx, c_x, func, 1e-15, 1e-14, __LINE__); + TestGradient(grad, c_weights, func, 1e-15, 1e-14, __LINE__); + } +} + +TEST(BackPropTest, SoftmaxVJP) { + static const size_t N = 64; + std::mt19937 gen(42); + using T = double; + using TC = std::complex; + std::array x; + std::array dx; + std::array dy; + std::array c_x; + std::array c_y; + + for (int iter = 0; iter < 10; ++iter) { + RandInit(x, 1.0 * (1 << iter), gen); + Complexify(x, c_x); + RandInit(dy, 1.0, gen); + auto func = [&]() { + 
+TEST(BackPropTest, SoftmaxVJP) {
+  static const size_t N = 64;
+  std::mt19937 gen(42);
+  using T = double;
+  using TC = std::complex<T>;
+  std::array<T, N> x;
+  std::array<T, N> dx;
+  std::array<T, N> dy;
+  std::array<TC, N> c_x;
+  std::array<TC, N> c_y;
+
+  for (int iter = 0; iter < 10; ++iter) {
+    RandInit(x, 1.0 * (1 << iter), gen);
+    Complexify(x, c_x);
+    RandInit(dy, 1.0, gen);
+    auto func = [&]() {
+      memcpy(c_y.data(), c_x.data(), sizeof(c_x));
+      Softmax(c_y.data(), N);
+      return DotT(dy.data(), c_y.data(), N);
+    };
+    Softmax(x.data(), N);
+    memcpy(dx.data(), dy.data(), N * sizeof(dx[0]));
+    SoftmaxVJPT(x.data(), dx.data(), N);
+    TestGradient(dx, c_x, func, 1e-15, 1e-15, __LINE__);
+  }
+}
+
+TEST(BackPropTest, MaskedSoftmaxVJP) {
+  static const size_t kSeqLen = 16;
+  static const size_t kHeads = 2;
+  static const size_t kTokens = 14;
+  static const size_t N = kHeads * kSeqLen * kSeqLen;
+  std::mt19937 gen(42);
+  using T = double;
+  using TC = std::complex<T>;
+  std::array<T, N> x;
+  std::array<T, N> dy;
+  std::array<T, N> dx = {};
+  std::array<TC, N> c_x;
+  std::array<TC, N> c_y;
+
+  for (int iter = 0; iter < 10; ++iter) {
+    RandInit(x, 1.0 * (1 << iter), gen);
+    Complexify(x, c_x);
+    RandInit(dy, 1.0, gen);
+    auto func = [&]() {
+      memcpy(c_y.data(), c_x.data(),
+             kTokens * kHeads * kSeqLen * sizeof(c_x[0]));
+      MaskedSoftmax(c_y.data(), kTokens, kHeads, kSeqLen);
+      return DotT(dy.data(), c_y.data(), N);
+    };
+    MaskedSoftmax(x.data(), kTokens, kHeads, kSeqLen);
+    memcpy(dx.data(), dy.data(), kTokens * kHeads * kSeqLen * sizeof(dx[0]));
+    MaskedSoftmaxVJPT(x.data(), dx.data(), kTokens, kHeads, kSeqLen);
+    TestGradient(dx, c_x, func, 1e-14, 1e-15, __LINE__);
+  }
+}
+
+TEST(BackPropTest, SoftcapVJP) {
+  static const size_t N = 64;
+  std::mt19937 gen(42);
+  using T = double;
+  using TC = std::complex<T>;
+  std::array<T, N> x;
+  std::array<T, N> dx;
+  std::array<T, N> dy;
+  std::array<TC, N> c_x;
+  std::array<TC, N> c_y;
+
+  for (int iter = 0; iter < 10; ++iter) {
+    RandInit(x, 1.0 * (1 << iter), gen);
+    Complexify(x, c_x);
+    RandInit(dy, 1.0, gen);
+    auto func = [&]() {
+      memcpy(c_y.data(), c_x.data(), N * sizeof(c_x[0]));
+      Softcap(c_y.data(), N);
+      return DotT(dy.data(), c_y.data(), N);
+    };
+    Softcap(x.data(), N);
+    memcpy(dx.data(), dy.data(), N * sizeof(dx[0]));
+    SoftcapVJPT(x.data(), dx.data(), N);
+    TestGradient(dx, c_x, func, 1e-15, 1e-15, __LINE__);
+  }
+}
+
+TEST(BackPropTest, CrossEntropyLossGrad) {
+  static const size_t K = 8;
+  static const size_t V = 64;
+  std::mt19937 gen(42);
+  using T = double;
+  using TC = std::complex<T>;
+  std::array<T, K * V> x;
+  std::array<T, K * V> dx;
+  std::array<TC, K * V> c_x;
+  Prompt prompt;
+  prompt.tokens = { 0, 1, 2, 3, 0, 3, 2, 1, 0 };
+
+  for (int iter = 0; iter < 10; ++iter) {
+    prompt.context_size = 1 + (iter % 6);
+    RandInit(x, 1.0 * (1 << iter), gen);
+    Softcap(x.data(), V * K);
+    Softmax(x.data(), V, K);
+    CrossEntropyLossGrad(x.data(), dx.data(), prompt, V);
+    Complexify(x, c_x);
+    auto func = [&]() {
+      return CrossEntropyLoss(c_x.data(), prompt, V);
+    };
+    TestGradient(dx, c_x, func, 1e-100, 1e-15, __LINE__);
+  }
+}
+
+TEST(BackPropTest, GatedGeluVJP) {
+  static const size_t K = 2;
+  static const size_t N = 64;
+  std::mt19937 gen(42);
+  using T = double;
+  using TC = std::complex<T>;
+  std::array<T, K * 2 * N> x;
+  std::array<T, K * 2 * N> dx;
+  std::array<T, K * N> dy;
+  std::array<TC, K * 2 * N> c_x;
+  std::array<TC, K * N> c_y;
+
+  for (int iter = 0; iter < 10; ++iter) {
+    RandInit(x, 1.0, gen);
+    Complexify(x, c_x);
+    RandInit(dy, 1.0, gen);
+    auto func = [&]() {
+      GatedGelu(c_x.data(), c_y.data(), N, K);
+      return DotT(dy.data(), c_y.data(), N * K);
+    };
+    GatedGeluVJP(x.data(), dy.data(), dx.data(), N, K);
+    TestGradient(dx, c_x, func, 1e-15, 1e-15, __LINE__);
+  }
+}
+
+TEST(BackPropTest, MaskedAttentionVJP) {
+  static const size_t kSeqLen = 16;
+  static const size_t kHeads = 2;
+  static const size_t kQKVDim = 8;
+  static const size_t kTokens = 14;
+  static const size_t kQKVSize = kSeqLen * (kHeads + 2) * kQKVDim;
+  static const size_t kOutSize = kSeqLen * kHeads * kSeqLen;
+  std::mt19937 gen(42);
+  using T = double;
+  using TC = std::complex<T>;
+  std::array<T, kQKVSize> x;
+  std::array<T, kQKVSize> dx = {};
+  std::array<T, kOutSize> dy;
+  std::array<TC, kQKVSize> c_x;
+  std::array<TC, kOutSize> c_y;
+
+  for (int iter = 0; iter < 10; ++iter) {
+    RandInit(x, 1.0, gen);
+    Complexify(x, c_x);
+    RandInit(dy, 1.0, gen);
+    auto func = [&]() {
+      MaskedAttention(c_x.data(), c_y.data(), kTokens, kHeads, kQKVDim,
+                      kSeqLen);
+      return DotT(dy.data(), c_y.data(), kOutSize);
+    };
+    MaskedAttentionVJP(x.data(), dy.data(), dx.data(),
+                       kTokens, kHeads, kQKVDim, kSeqLen);
+    TestGradient(dx, c_x, func, 1e-14, 1e-15, __LINE__);
+  }
+}
+
+TEST(BackPropTest, MixByAttentionVJP) {
+  static const size_t kSeqLen = 16;
+  static const size_t kHeads = 2;
+  static const size_t kQKVDim = 8;
+  static const size_t kTokens = 14;
+  static const size_t kQKVSize = kSeqLen * (kHeads + 2) * kQKVDim;
+  static const size_t kAttnSize = kSeqLen * kHeads * kSeqLen;
+  static const size_t kOutSize = kSeqLen * kHeads * kQKVDim;
+  std::mt19937 gen(42);
+  using T = double;
+  using TC = std::complex<T>;
+  std::array<T, kQKVSize> qkv;
+  std::array<T, kQKVSize> dqkv = {};
+  std::array<T, kAttnSize> attn;
+  std::array<T, kAttnSize> dattn = {};
+  std::array<T, kOutSize> dy;
+  std::array<TC, kQKVSize> c_qkv;
+  std::array<TC, kAttnSize> c_attn;
+  std::array<TC, kOutSize> c_y;
+
+  for (int iter = 0; iter < 10; ++iter) {
+    RandInit(qkv, 1.0, gen);
+    RandInit(attn, 1.0, gen);
+    Complexify(qkv, c_qkv);
+    Complexify(attn, c_attn);
+    RandInit(dy, 1.0, gen);
+    auto func = [&]() {
+      MixByAttention(c_qkv.data(), c_attn.data(), c_y.data(),
+                     kTokens, kHeads, kQKVDim, kSeqLen);
+      return DotT(dy.data(), c_y.data(), kOutSize);
+    };
+    MixByAttentionVJP(qkv.data(), attn.data(), dy.data(), dqkv.data(),
+                      dattn.data(), kTokens, kHeads, kQKVDim, kSeqLen);
+    TestGradient(dqkv, c_qkv, func, 1e-14, 1e-15, __LINE__);
+    TestGradient(dattn, c_attn, func, 1e-14, 1e-15, __LINE__);
+  }
+}
+
+TEST(BackPropTest, InputEmbeddingVJP) {
+  static const size_t kSeqLen = 8;
+  static const size_t kVocabSize = 4;
+  static const size_t kModelDim = 16;
+  std::mt19937 gen(42);
+  using T = double;
+  using TC = std::complex<T>;
+  std::array<T, kVocabSize * kModelDim> weights;
+  std::array<T, kVocabSize * kModelDim> grad;
+  std::array<T, kSeqLen * kModelDim> dy;
+  std::array<TC, kVocabSize * kModelDim> c_weights;
+  std::array<TC, kSeqLen * kModelDim> c_y;
+  std::vector<int> tokens = { 0, 1, 2, 3, 0, 1, 2 };
+  size_t num_tokens = tokens.size() - 1;
+
+  for (size_t iter = 0; iter < 10; ++iter) {
+    RandInit(weights, 1.0, gen);
+    RandInit(dy, 1.0, gen);
+    Complexify(weights, c_weights);
+    auto func = [&]() {
+      InputEmbedding(c_weights.data(), tokens, TC(3.0), c_y.data(), kModelDim);
+      return DotT(dy.data(), c_y.data(), num_tokens * kModelDim);
+    };
+    memset(&grad, 0, sizeof(grad));
+    InputEmbeddingVJPT(weights.data(), tokens, 3.0, dy.data(), grad.data(),
+                       kModelDim);
+    TestGradient(grad, c_weights, func, 1e-16, 1e-14, __LINE__);
+  }
+}
+
+struct TestConfig {
+  static constexpr int kSeqLen = 18;
+  static constexpr int kVocabSize = 12;
+  static constexpr int kModelDim = 32;
+  static constexpr int kHeads = 3;
+  static constexpr int kQKVDim = 12;
+  static constexpr int kFFHiddenDim = 48;
+  static constexpr std::array<LayerAttentionType, 2> kLayerConfig =
+      FixedLayerConfig<2>(LayerAttentionType::kGemma);
+  static constexpr int kLayers = kLayerConfig.size();
+  static constexpr bool kAbsolutePE = false;
+  static constexpr bool kPostNormScale = false;
+
+  static constexpr int kKVHeads = 1;
+  static constexpr int kConv1dWidth = 0;
+  static constexpr bool kFFBiases = false;
+  static constexpr bool kSoftmaxAttnOutputBiases = false;
+  static constexpr int kGemmaLayers = kLayers;
+  static constexpr int kGriffinLayers = 0;
+  static constexpr int kNumTensorScales = 0;
+};
+
+TEST(BackPropTest, LayerVJP) {
+  std::mt19937 gen(42);
+  using T = double;
+  using TC = std::complex<T>;
+  const size_t kOutputSize = TestConfig::kSeqLen * TestConfig::kModelDim;
+  Layer<T, TestConfig> weights;
+  Layer<T, TestConfig> grad;
+  ForwardLayer<T, TestConfig> forward;
+  ForwardLayer<T, TestConfig> backward = {};
+  Layer<TC, TestConfig> c_weights;
+  ForwardLayer<TC, TestConfig> c_forward;
+  std::array<T, kOutputSize> y;
+  std::array<T, kOutputSize> dy;
+  std::array<TC, kOutputSize> c_y;
+  const size_t num_tokens = 3;
+
+  for (size_t iter = 0; iter < 10; ++iter) {
+    RandInit(weights, 1.0, gen);
+    RandInit(forward.input, 1.0, gen);
+    RandInit(dy, 1.0, gen);
+    Complexify(weights, c_weights);
+    Complexify(forward.input, c_forward.input);
+    auto func = [&]() {
+      ApplyLayer(c_weights, c_forward, num_tokens, c_y.data());
+      return DotT(dy.data(), c_y.data(), num_tokens * TestConfig::kModelDim);
+    };
+    memset(&grad, 0, sizeof(grad));
+    ApplyLayer(weights, forward, num_tokens, y.data());
+    LayerVJP(weights, forward, dy.data(), grad, backward, num_tokens);
+    TestGradient(backward.input, c_forward.input, func, 1e-11, 1e-11,
+                 __LINE__);
+    TestGradient(grad, c_weights, func, 1e-11);
+  }
+}
+
+TEST(BackPropTest, EndToEnd) {
+  std::mt19937 gen(42);
+  using T = double;
+  using TC = std::complex<T>;
+  WeightsWrapper<T, TestConfig> weights;
+  WeightsWrapper<T, TestConfig> grad;
+  ForwardPass<T, TestConfig> forward;
+  ForwardPass<T, TestConfig> backward;
+  WeightsWrapper<TC, TestConfig> c_weights;
+  ForwardPass<TC, TestConfig> c_forward;
+
+  ReverseSequenceSampler training_task({0, 0, 1, 1});
+  std::vector<Prompt> batch = training_task.SampleBatch(10, gen);
+
+  for (const Prompt& prompt : batch) {
+    ReverseSequenceSampler::LogPrompt(prompt);
+    RandInit(weights.get(), 1.0, gen);
+    CrossEntropyLossForwardPass(prompt, weights.get(), forward);
+    grad.clear();
+    CrossEntropyLossBackwardPass(
+        prompt, weights.get(), forward, grad.get(), backward);
+
+    Complexify(weights.get(), c_weights.get());
+    auto func = [&]() {
+      return CrossEntropyLossForwardPass(prompt, c_weights.get(), c_forward);
+    };
+
+    TestGradient(grad.get(), c_weights.get(), func, 1e-11);
+  }
+}
+
+template <typename T, typename TConfig>
+void MulByConstAndAddT(T c, const Layer<T, TConfig>& x,
+                       Layer<T, TConfig>& out) {
+  MulByConstAndAddT(c, x.pre_attention_norm_scale,
+                    out.pre_attention_norm_scale);
+  MulByConstAndAddT(c, x.attn_vec_einsum_w, out.attn_vec_einsum_w);
+  MulByConstAndAddT(c, x.qkv_einsum_w, out.qkv_einsum_w);
+  MulByConstAndAddT(c, x.pre_ffw_norm_scale, out.pre_ffw_norm_scale);
+  MulByConstAndAddT(c, x.gating_einsum_w, out.gating_einsum_w);
+  MulByConstAndAddT(c, x.linear_w, out.linear_w);
+}
+
+template <typename T, typename TConfig>
+void MulByConstAndAddT(T c, const Weights<T, TConfig>& x,
+                       Weights<T, TConfig>& out) {
+  static constexpr size_t kLayers = TConfig::kLayers;
+  MulByConstAndAddT(c, x.embedder_input_embedding,
+                    out.embedder_input_embedding);
+  MulByConstAndAddT(c, x.final_norm_scale, out.final_norm_scale);
+  for (size_t i = 0; i < kLayers; ++i) {
+    MulByConstAndAddT(c, *x.GetLayer(i), *out.GetLayer(i));
+  }
+}
+
+// Evaluates forward pass on a batch.
+template <typename T, typename TConfig>
+T CrossEntropyLossForwardPass(const std::vector<Prompt>& batch,
+                              const WeightsWrapper<T, TConfig>& weights,
+                              ForwardPass<T, TConfig>& forward) {
+  T loss = 0.0;
+  for (const Prompt& prompt : batch) {
+    loss += CrossEntropyLossForwardPass(prompt, weights.get(), forward);
+  }
+  T scale = 1.0 / batch.size();
+  return loss * scale;
+}
+
+// Evaluates forward pass on a batch by applying gradient with the given
+// learning rate. Does not update weights, but uses the given tmp weights
+// instead.
+template <typename T, typename TConfig>
+T CrossEntropyLossForwardPass(T learning_rate,
+                              const std::vector<Prompt>& batch,
+                              const WeightsWrapper<T, TConfig>& weights,
+                              const WeightsWrapper<T, TConfig>& grad,
+                              WeightsWrapper<T, TConfig>& tmp,
+                              ForwardPass<T, TConfig>& forward) {
+  tmp.copy(weights);
+  const T scale = -learning_rate / batch.size();
+  MulByConstAndAddT(scale, grad.get(), tmp.get());
+  return CrossEntropyLossForwardPass(batch, tmp, forward);
+}
+
+// Uses line search in the negative gradient direction to update weights. We do
+// this so that we can test that each step during the gradient descent can
+// decrease the objective function value.
+template <typename T, typename TConfig>
+T FindOptimalUpdate(const WeightsWrapper<T, TConfig>& grad,
+                    WeightsWrapper<T, TConfig>& weights,
+                    WeightsWrapper<T, TConfig>& tmp,
+                    ForwardPass<T, TConfig>& forward,
+                    const std::vector<Prompt>& batch,
+                    T loss, T initial_learning_rate) {
+  T lr0 = initial_learning_rate;
+  T loss0 = CrossEntropyLossForwardPass(
+      lr0, batch, weights, grad, tmp, forward);
+  for (size_t iter = 0; iter < 30; ++iter) {
+    T lr1 = lr0 * 0.5;
+    T loss1 = CrossEntropyLossForwardPass(
+        lr1, batch, weights, grad, tmp, forward);
+    if (loss0 < loss && loss1 >= loss0) {
+      break;
+    }
+    loss0 = loss1;
+    lr0 = lr1;
+  }
+  for (size_t iter = 0; iter < 30; ++iter) {
+    T lr1 = lr0 * 2.0;
+    T loss1 = CrossEntropyLossForwardPass(
+        lr1, batch, weights, grad, tmp, forward);
+    if (loss1 >= loss0) {
+      break;
+    }
+    loss0 = loss1;
+    lr0 = lr1;
+  }
+  const T scale = -lr0 / batch.size();
+  MulByConstAndAddT(scale, grad.get(), weights.get());
+  return lr0;
+}
+
+TEST(BackPropTest, Convergence) {
+  std::mt19937 gen(42);
+  using T = float;
+  using TC = std::complex<T>;
+  WeightsWrapper<T, TestConfig> weights;
+  WeightsWrapper<T, TestConfig> grad;
+  WeightsWrapper<T, TestConfig> tmp;
+  ForwardPass<T, TestConfig> forward;
+  ForwardPass<T, TestConfig> backward;
+  WeightsWrapper<TC, TestConfig> c_weights;
+  ForwardPass<TC, TestConfig> c_forward;
+  constexpr size_t kBatchSize = 5;
+  ReverseSequenceSampler training_task({0, 0, 0, 1, 1});
+  T learning_rate = 0.01;
+
+  RandInit(weights.get(), T(1.0), gen);
+
+  printf("Sample batch:\n");
+  for (size_t i = 0; i < 10; ++i) {
+    ReverseSequenceSampler::LogPrompt(training_task.Sample(gen));
+  }
+
+  T prev_loss = std::numeric_limits<T>::max();
+  bool stop = false;
+  size_t step = 0;
+  while (!stop) {
+    T loss = 0.0;
+    grad.clear();
+    std::mt19937 sgen(42);
+    std::vector<Prompt> batch = training_task.SampleBatch(kBatchSize, sgen);
+    for (const Prompt& prompt : batch) {
+      loss += CrossEntropyLossForwardPass(prompt, weights.get(), forward);
+      CrossEntropyLossBackwardPass(
+          prompt, weights.get(), forward, grad.get(), backward);
+    }
+
+    if (step % 250 == 0) {
+      printf("Checking gradient...\n");
+      Complexify(weights.get(), c_weights.get());
+      auto func = [&]() {
+        TC scale = batch.size();
+        return CrossEntropyLossForwardPass(batch, c_weights, c_forward) * scale;
+      };
+
+      TestGradient(grad.get(), c_weights.get(), func, 5e-3f);
+    }
+
+    loss /= batch.size();
+    EXPECT_LT(loss, prev_loss);
+    stop = step >= 10000 || loss < 1e-2;
+    if (step % 10 == 0 || stop) {
+      printf("step: %5zu  loss: %.15f  learning_rate: %.15f\n",
+             step, loss, learning_rate);
+    }
+    if (!stop) {
+      learning_rate = FindOptimalUpdate(
+          grad, weights, tmp, forward, batch, loss, learning_rate);
+      ++step;
+    }
+    prev_loss = loss;
+  }
+  EXPECT_LT(step, 1000);
+}
+
+}  // namespace gcpp
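FindOptimalUpdate brackets a good learning rate by repeated halving (while the loss is not yet improving) followed by repeated doubling (while it keeps improving). The same bracketing logic on a 1-D quadratic, as a standalone sketch (hypothetical objective, doubling phase only):

    #include <cstdio>

    int main() {
      // Minimize f(w) = (w - 3)^2 along the negative gradient from w = 0.
      auto f = [](double w) { return (w - 3.0) * (w - 3.0); };
      const double w = 0.0, g = 2.0 * (w - 3.0);  // f'(w)
      double lr = 1e-3, best = f(w - lr * g);
      for (int i = 0; i < 30; ++i) {  // grow the step while it keeps helping
        const double next = f(w - 2.0 * lr * g);
        if (next >= best) break;
        best = next;
        lr *= 2.0;
      }
      std::printf("chosen lr=%.6f loss=%.6f\n", lr, best);  // lr ~ 0.512
    }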
diff --git a/backprop/backward_test.cc b/backprop/backward_test.cc
new file mode 100644
index 0000000..cf13e6e
--- /dev/null
+++ b/backprop/backward_test.cc
@@ -0,0 +1,268 @@
+// Copyright 2023 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef HWY_DISABLED_TARGETS
+#define HWY_DISABLED_TARGETS HWY_SCALAR
+#endif
+
+#include <stddef.h>
+
+#include <array>
+#include <cmath>
+#include <complex>
+#include <random>
+#include <vector>
+
+#include "backprop/backward_scalar.h"
+#include "backprop/forward_scalar.h"
+#include "backprop/sampler.h"
+#include "backprop/test_util.h"
+#include "compression/compress.h"
+#include "hwy/aligned_allocator.h"
+#include "hwy/base.h"
+#include "hwy/contrib/thread_pool/thread_pool.h"
+#include "gemma/gemma.h"
+#include "gemma/weights.h"
+
+// clang-format off
+#undef HWY_TARGET_INCLUDE
+#define HWY_TARGET_INCLUDE "backprop/backward_test.cc"  // NOLINT
+// clang-format on
+#include "hwy/foreach_target.h"  // IWYU pragma: keep
+#include "hwy/highway.h"
+#include "hwy/tests/test_util-inl.h"
+// After highway.h
+#include "backprop/backward-inl.h"
+#include "backprop/forward-inl.h"
+#include "gemma/ops.h"
+
+HWY_BEFORE_NAMESPACE();
+namespace gcpp {
+namespace HWY_NAMESPACE {
+
+namespace hn = hwy::HWY_NAMESPACE;
+
+void TestMatMulVJP() {
+  static const size_t kRows = 8;
+  static const size_t kCols = 64;
+  static const size_t kTokens = 5;
+  hwy::ThreadPool pool(8);
+  std::mt19937 gen(42);
+  HWY_ALIGN std::array<float, kRows * kCols> weights;
+  HWY_ALIGN std::array<float, kTokens * kCols> x;
+  HWY_ALIGN std::array<float, kTokens * kRows> dy;
+  HWY_ALIGN std::array<float, kRows * kCols> grad;
+  HWY_ALIGN std::array<float, kTokens * kCols> dx;
+  HWY_ALIGN std::array<float, kRows * kCols> grad_scalar;
+  HWY_ALIGN std::array<float, kTokens * kCols> dx_scalar;
+  using TC = std::complex<double>;
+  std::array<TC, kRows * kCols> c_weights;
+  std::array<TC, kTokens * kCols> c_x;
+  std::array<TC, kTokens * kRows> c_y;
+
+  for (int iter = 0; iter < 10; ++iter) {
+    RandInit(weights, 1.0f * (1 << iter), gen);
+    RandInit(x, 1.0f * (1 << iter), gen);
+    RandInit(dy, 1.0f, gen);
+    Complexify(weights, c_weights);
+    Complexify(x, c_x);
+    auto func = [&]() {
+      MatMulT(c_weights.data(), c_x.data(), c_y.data(), kRows, kCols, kTokens);
+      return DotT(dy.data(), c_y.data(), kTokens * kRows);
+    };
+
+    hwy::ZeroBytes(&grad, sizeof(grad));
+    MatMulVJP<kCols, kRows>(weights, x.data(), dy.data(), kTokens,
+                            grad, dx.data(), pool);
+    TestGradient(dx, c_x, func, 5e-5, 5e-5, __LINE__);
+    TestGradient(grad, c_weights, func, 5e-5, 5e-5, __LINE__);
+
+    hwy::ZeroBytes(&grad_scalar, sizeof(grad_scalar));
+    MatMulVJPT(weights.data(), x.data(), dy.data(), grad_scalar.data(),
+               dx_scalar.data(), kRows, kCols, kTokens);
+    TestNear(dx, dx_scalar, 5e-5, 5e-5, __LINE__);
+    TestNear(grad, grad_scalar, 5e-5, 5e-5, __LINE__);
+  }
+}
+
+void TestMultiHeadMatMulVJP() {
+  static const size_t kRows = 2;
+  static const size_t kCols = 16;
+  static const size_t kHeads = 4;
+  static const size_t kTokens = 3;
+  hwy::ThreadPool pool(8);
+  std::mt19937 gen(42);
+  HWY_ALIGN std::array<float, kHeads * kRows * kCols> weights;
+  HWY_ALIGN std::array<float, kTokens * kHeads * kCols> x;
+  HWY_ALIGN std::array<float, kHeads * kRows * kCols> grad;
+  HWY_ALIGN std::array<float, kTokens * kHeads * kCols> dx;
+  HWY_ALIGN std::array<float, kTokens * kRows> dy;
+  HWY_ALIGN std::array<float, kHeads * kRows * kCols> grad_scalar;
+  HWY_ALIGN std::array<float, kTokens * kHeads * kCols> dx_scalar;
+  using TC = std::complex<double>;
+  std::array<TC, kHeads * kRows * kCols> c_weights;
+  std::array<TC, kTokens * kHeads * kCols> c_x;
+  std::array<TC, kTokens * kRows> c_y;
+
+  for (int iter = 0; iter < 10; ++iter) {
+    RandInit(weights, 1.0f * (1 << iter), gen);
+    RandInit(x, 1.0f * (1 << iter), gen);
+    RandInit(dy, 1.0f, gen);
+    Complexify(weights, c_weights);
+    Complexify(x, c_x);
+    auto func = [&]() {
+      MultiHeadMatMul(c_weights.data(), c_x.data(), c_y.data(), kHeads, kRows,
+                      kCols, kTokens);
+      return DotT(dy.data(), c_y.data(), kTokens * kRows);
+    };
+
+    hwy::ZeroBytes(&grad, sizeof(grad));
+    MultiHeadMatMulVJP<kHeads, kCols, kRows>(
+        weights, x.data(), dy.data(), kTokens, grad, dx.data(), pool);
+    TestGradient(dx, c_x, func, 5e-5, 5e-5, __LINE__);
+    TestGradient(grad, c_weights, func, 5e-5, 5e-5, __LINE__);
+
+    hwy::ZeroBytes(&grad_scalar, sizeof(grad_scalar));
+    MultiHeadMatMulVJPT(weights.data(), x.data(), dy.data(), grad_scalar.data(),
+                        dx_scalar.data(), kHeads, kRows, kCols, kTokens);
+    TestNear(dx, dx_scalar, 5e-5, 5e-5, __LINE__);
+    TestNear(grad, grad_scalar, 5e-5, 5e-5, __LINE__);
+  }
+}
+
+void TestRMSNormVJP() {
+  static const size_t K = 2;
+  static const size_t N = 64;
+  hwy::ThreadPool pool(8);
+  std::mt19937 gen(42);
+  HWY_ALIGN std::array<float, N> weights;
+  HWY_ALIGN std::array<float, K * N> x;
+  HWY_ALIGN std::array<float, N> grad;
+  HWY_ALIGN std::array<float, K * N> dx;
+  HWY_ALIGN std::array<float, K * N> dy;
+  HWY_ALIGN std::array<float, N> grad_scalar;
+  HWY_ALIGN std::array<float, K * N> dx_scalar;
+  using TC = std::complex<double>;
+  std::array<TC, N> c_weights;
+  std::array<TC, K * N> c_x;
+  std::array<TC, K * N> c_y;
+
+  for (int iter = 0; iter < 10; ++iter) {
+    RandInit(weights, 1.0f * (1 << iter), gen);
+    RandInit(x, 1.0f * (1 << iter), gen);
+    RandInit(dy, 1.0f, gen);
+    Complexify(weights, c_weights);
+    Complexify(x, c_x);
+    auto func = [&]() {
+      RMSNormT(c_weights.data(), c_x.data(), c_y.data(), N, K);
+      return DotT(dy.data(), c_y.data(), K * N);
+    };
+
+    hwy::ZeroBytes(&grad, sizeof(grad));
+    RMSNormVJP(weights.data(), x.data(), dy.data(), N, K, grad.data(),
+               dx.data(), pool);
+    TestGradient(dx, c_x, func, 5e-5, 5e-5, __LINE__);
+    TestGradient(grad, c_weights, func, 5e-5, 5e-5, __LINE__);
+
+    hwy::ZeroBytes(&grad_scalar, sizeof(grad_scalar));
+    RMSNormVJPT(weights.data(), x.data(), dy.data(), grad_scalar.data(),
+                dx_scalar.data(), N, K);
+    TestNear(dx, dx_scalar, 0, 2e-5, __LINE__);
+    TestNear(grad, grad_scalar, 0, 2e-5, __LINE__);
+  }
+}
+
+struct TestConfig {
+  static constexpr int kSeqLen = 24;
+  static constexpr int kVocabSize = 16;
+  static constexpr int kModelDim = 32;
+  static constexpr int kHeads = 3;
+  static constexpr int kQKVDim = 16;
+  static constexpr int kFFHiddenDim = 64;
+  static constexpr std::array<LayerAttentionType, 2> kLayerConfig =
+      FixedLayerConfig<2>(LayerAttentionType::kGemma);
+  static constexpr int kLayers = kLayerConfig.size();
+  static constexpr bool kAbsolutePE = false;
+  static constexpr bool kPostNormScale = false;
+
+  static constexpr int kKVHeads = 1;
+  static constexpr int kConv1dWidth = 0;
+  static constexpr bool kFFBiases = false;
+  static constexpr bool kSoftmaxAttnOutputBiases = false;
+  static constexpr int kGemmaLayers = kLayers;
+  static constexpr int kGriffinLayers = 0;
+  static constexpr int kNumTensorScales = 0;
+};
+
+void TestEndToEnd() {
+  std::mt19937 gen(42);
+  hwy::ThreadPool pool(0);
+  WeightsWrapper<float, TestConfig> weights;
+  WeightsWrapper<float, TestConfig> grad;
+  ActivationsWrapper<float, TestConfig> forward0;
+  ActivationsWrapper<float, TestConfig> forward1;
+  ActivationsWrapper<float, TestConfig> backward;
+  using TC = std::complex<double>;
+  WeightsWrapper<TC, TestConfig> c_weights;
+  ForwardPass<TC, TestConfig> c_forward;
+
+  ReverseSequenceSampler training_task({0, 0, 1, 1});
+  std::vector<Prompt> batch = training_task.SampleBatch(10, gen);
+
+  for (const Prompt& prompt : batch) {
+    ReverseSequenceSampler::LogPrompt(prompt);
+    RandInit(weights.get(), 1.0f, gen);
+
+    float loss0 = CrossEntropyLossForwardPass(
+        prompt, weights.get(), forward0.get());
+
+    float loss1 = CrossEntropyLossForwardPass(
+        prompt.tokens, prompt.context_size, weights.get(), forward1.get(),
+        pool);
+
EXPECT_NEAR(loss1, loss0, std::abs(loss0) * 1e-5); + + grad.clear(); + CrossEntropyLossBackwardPass( + prompt, weights.get(), forward1.get(), grad.get(), backward.get(), + pool); + + Complexify(weights.get(), c_weights.get()); + auto func = [&]() { + return CrossEntropyLossForwardPass(prompt, c_weights.get(), c_forward); + }; + + TestGradient(grad.get(), c_weights.get(), func, 2e-3f); + } +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace gcpp +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE + +namespace gcpp { +HWY_BEFORE_TEST(BackwardTest); +HWY_EXPORT_AND_TEST_P(BackwardTest, TestMatMulVJP); +HWY_EXPORT_AND_TEST_P(BackwardTest, TestMultiHeadMatMulVJP); +HWY_EXPORT_AND_TEST_P(BackwardTest, TestRMSNormVJP); +HWY_EXPORT_AND_TEST_P(BackwardTest, TestEndToEnd); +HWY_AFTER_TEST(); + +} // namespace gcpp + +#endif diff --git a/backprop/common_scalar.cc b/backprop/common_scalar.cc new file mode 100644 index 0000000..b04dffb --- /dev/null +++ b/backprop/common_scalar.cc @@ -0,0 +1,51 @@ +// Copyright 2024 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Compiles this file for multiple architectures via "foreach_target.h", to +// which we pass the filename via macro 'argument'. +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "backprop/common_scalar.cc" // NOLINT +#include "hwy/foreach_target.h" // IWYU pragma: keep + +#include "hwy/highway.h" +// After highway.h +#include "gemma/ops.h" + +HWY_BEFORE_NAMESPACE(); +namespace gcpp { +namespace HWY_NAMESPACE { +namespace hn = hwy::HWY_NAMESPACE; + +float EmbeddingScaling(int model_dim) { + // Round to bf16 to match Gemma's Embedder, which casts before mul. + return hwy::ConvertScalarTo(hwy::ConvertScalarTo( + Sqrt(static_cast(model_dim)))); +} + +} // namespace HWY_NAMESPACE +} // namespace gcpp +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace gcpp { + +HWY_EXPORT(EmbeddingScaling); + +float EmbeddingScaling(int model_dim) { + return HWY_DYNAMIC_DISPATCH(EmbeddingScaling)(model_dim); +} + +} // namespace gcpp +#endif // HWY_ONCE diff --git a/backprop/common_scalar.h b/backprop/common_scalar.h new file mode 100644 index 0000000..628962b --- /dev/null +++ b/backprop/common_scalar.h @@ -0,0 +1,119 @@ +// Copyright 2024 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
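A quick aside on the technique the tests above rely on: complex-step differentiation. Extend f to the complex plane, evaluate at x + ih, and Imag(f(x + ih))/h recovers f'(x) without subtracting nearby floating-point values, so h can be made tiny with no cancellation error. A self-contained sketch (the function f, the point x, and the step h here are invented for illustration, not taken from the patch):

    #include <cmath>
    #include <complex>
    #include <cstdio>

    // Complex-step derivative: f'(x) ~= Imag(f(x + ih)) / h. Nothing is
    // subtracted, so h = 1e-30 loses no precision.
    static std::complex<double> f(std::complex<double> x) {
      return x * std::sin(x);
    }

    int main() {
      const double x = 0.7, h = 1e-30;
      const double grad = std::imag(f(std::complex<double>(x, h))) / h;
      const double analytic = std::sin(x) + x * std::cos(x);  // closed-form f'
      std::printf("complex-step %.17g  analytic %.17g\n", grad, analytic);
      return 0;
    }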
+
+#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_SCALAR_H_
+#define THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_SCALAR_H_
+
+#include <stddef.h>
+
+namespace gcpp {
+
+template <typename T, typename U>
+U DotT(const T* a, const U* b, size_t N) {
+  U sum = {};
+  for (size_t i = 0; i < N; ++i) {
+    sum += a[i] * b[i];
+  }
+  return sum;
+}
+
+template <>
+std::complex<double> DotT(const float* a, const std::complex<double>* b,
+                          size_t N) {
+  std::complex<double> sum = {};
+  for (size_t i = 0; i < N; ++i) {
+    sum += static_cast<double>(a[i]) * b[i];
+  }
+  return sum;
+}
+
+template <typename T>
+void MulByConstT(T c, T* x, size_t N) {
+  for (size_t i = 0; i < N; ++i) {
+    x[i] *= c;
+  }
+}
+
+// out += c * x
+template <typename T>
+void MulByConstAndAddT(T c, const T* x, T* out, size_t N) {
+  for (size_t i = 0; i < N; ++i) {
+    out[i] += c * x[i];
+  }
+}
+
+template <typename T, size_t N>
+void MulByConstAndAddT(T c, const std::array<T, N>& x,
+                       std::array<T, N>& out) {
+  MulByConstAndAddT(c, x.data(), out.data(), N);
+}
+
+template <typename T>
+void AddFromT(const T* a, T* out, size_t N) {
+  for (size_t i = 0; i < N; ++i) {
+    out[i] += a[i];
+  }
+}
+
+template <typename T>
+T SquaredL2(const T* x, size_t N) {
+  T sum = {};
+  for (size_t i = 0; i < N; ++i) {
+    sum += x[i] * x[i];
+  }
+  return sum;
+}
+
+template <typename T>
+T Gelu(T x) {
+  static const T kMul = 0.044715;
+  static const T kSqrt2OverPi = 0.797884560804236;
+
+  const T x3 = x * x * x;
+  const T arg = kSqrt2OverPi * (kMul * x3 + x);
+  const T cdf = T(0.5) * (T(1.0) + std::tanh(arg));
+  return x * cdf;
+}
+
+template <typename T, typename U>
+void Rope(T* x, U base, size_t N, int i) {
+  const size_t N2 = N / 2;
+  for (size_t dim = 0; dim < N2; ++dim) {
+    const T freq_exponents = T(2 * dim) / T(N);
+    const T timescale = std::pow(base, freq_exponents);
+    const T theta = T(i) / timescale;
+    const T cos_val = std::cos(theta);
+    const T sin_val = std::sin(theta);
+    const T x0 = x[dim];
+    const T x1 = x[dim + N2];
+    x[dim] = x0 * cos_val - x1 * sin_val;
+    x[dim + N2] = x0 * sin_val + x1 * cos_val;
+  }
+}
+
+template <typename T>
+void Rope(T* x, size_t N, int i) {
+  Rope(x, T(10000.0), N, i);
+}
+
+template <typename T>
+void Rope(std::complex<T>* x, size_t N, int i) {
+  Rope(x, T(10000.0), N, i);
+}
+
+float EmbeddingScaling(int model_dim);
+
+}  // namespace gcpp
+
+#endif  // THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_SCALAR_H_
diff --git a/backprop/forward-inl.h b/backprop/forward-inl.h
new file mode 100644
index 0000000..4b7cdf1
--- /dev/null
+++ b/backprop/forward-inl.h
@@ -0,0 +1,292 @@
+// Copyright 2024 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Include guard for non-SIMD code.
+#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_FORWARD_INL_H_
+#define THIRD_PARTY_GEMMA_CPP_GEMMA_FORWARD_INL_H_
+
+#include <stddef.h>
+#include <stdint.h>
+
+#include <cmath>
+#include <vector>
+
+#include "gemma/activations.h"
+#include "gemma/configs.h"
+
+#include "hwy/base.h"
+#include "hwy/contrib/thread_pool/thread_pool.h"
+
+#endif  // THIRD_PARTY_GEMMA_CPP_GEMMA_FORWARD_INL_H_
+
+// Include guard for (potentially) SIMD code.
+#if defined(THIRD_PARTY_GEMMA_CPP_FORWARD_TOGGLE) == defined(HWY_TARGET_TOGGLE) +#ifdef THIRD_PARTY_GEMMA_CPP_FORWARD_TOGGLE +#undef THIRD_PARTY_GEMMA_CPP_FORWARD_TOGGLE +#else +#define THIRD_PARTY_GEMMA_CPP_FORWARD_TOGGLE +#endif + +#include "gemma/common-inl.h" +#include "gemma/ops.h" + +HWY_BEFORE_NAMESPACE(); +namespace gcpp { +namespace HWY_NAMESPACE { +namespace hn = hwy::HWY_NAMESPACE; + +template +void InputEmbedding(const ArrayT& weights, const std::vector& prompt, + const float scaling, float* HWY_RESTRICT output, + size_t model_dim) { + HWY_ASSERT(!prompt.empty()); + for (size_t pos = 0; pos < prompt.size() - 1; ++pos) { + int token = prompt[pos]; + Decompress(weights, token * model_dim, output + pos * model_dim, model_dim); + MulByConst(scaling, output + pos * model_dim, model_dim); + } +} + +template +void ApplyRMSNorm(const WT* HWY_RESTRICT weights, const XT* HWY_RESTRICT x, + size_t model_dim, size_t num_tokens, + OutT* HWY_RESTRICT output, + hwy::ThreadPool& pool) { + for (size_t pos = 0; pos < num_tokens; ++pos) { + const size_t offset = pos * model_dim; + RMSNorm(x + offset, weights, output + offset, model_dim); + } +} + +static HWY_NOINLINE float CrossEntropyLoss(const float* HWY_RESTRICT probs, + const std::vector& prompt, + size_t context_size, + size_t vocab_size, + hwy::ThreadPool& pool) { + HWY_ASSERT(!prompt.empty()); + float loss = 0.0f; + for (size_t pos = 0; pos < prompt.size() - 1; ++pos) { + if (pos + 1 < context_size) { + continue; // next token is part of context, don't try to predict it + } + const int next_token = prompt[pos + 1]; + loss += std::log(probs[pos * vocab_size + next_token]); + } + float scaling = -1.0 / std::log(2.0); + return loss * scaling; +} + +template typename LayerT> +void ApplyForwardLayer(const LayerT& weights, + ForwardLayer& activations, + size_t num_tokens, + float* HWY_RESTRICT output, + hwy::ThreadPool& pool) { + static constexpr size_t kModelDim = TConfig::kModelDim; + static constexpr size_t kSeqLen = TConfig::kSeqLen; + static constexpr size_t kQKVDim = TConfig::kQKVDim; + static constexpr size_t kHeads = TConfig::kHeads; + static const float kQueryScale = + static_cast(1.0 / sqrt(static_cast(kQKVDim))); + HWY_ASSERT(num_tokens <= kSeqLen); + + ApplyRMSNorm(weights.pre_attention_norm_scale.data(), + activations.input.data(), kModelDim, num_tokens, + activations.pre_att_rms_out.data(), pool); + + for (size_t pos = 0; pos < num_tokens; ++pos) { + MatVec<(kHeads + 2) * kQKVDim, kModelDim>( + weights.qkv_einsum_w, 0, + activations.pre_att_rms_out.data() + pos * kModelDim, nullptr, + activations.qkv.data() + pos * (kHeads + 2) * kQKVDim, pool); + } + const size_t num_tasks = kHeads * num_tokens; + + for (size_t pos = 0; pos < num_tokens; ++pos) { + float* HWY_RESTRICT k = + activations.qkv.data() + (pos * (kHeads + 2) + kHeads) * kQKVDim; + Rope(k, kQKVDim, pos); + } + pool.Run(0, num_tasks, [&](const uint64_t task, size_t thread) HWY_ATTR { + const size_t head = task % kHeads; + const size_t pos = task / kHeads; + float* HWY_RESTRICT q = + activations.qkv.data() + (pos * (kHeads + 2) + head) * kQKVDim; + Rope(q, kQKVDim, pos); + MulByConst(kQueryScale, q, kQKVDim); + }); + + pool.Run(0, num_tasks, [&](const uint64_t task, size_t thread) HWY_ATTR { + const size_t head = task % kHeads; + const size_t pos = task / kHeads; + const float* HWY_RESTRICT q = + activations.qkv.data() + (pos * (kHeads + 2) + head) * kQKVDim; + float* HWY_RESTRICT head_att = + activations.att.data() + (pos * kHeads + head) * kSeqLen; + for (size_t pos2 = 0; 
pos2 <= pos; ++pos2) { + const float* HWY_RESTRICT k2 = + activations.qkv.data() + (pos2 * (kHeads + 2) + kHeads) * kQKVDim; + const float score = Dot(q, k2, kQKVDim); + head_att[pos2] = score; + } + }); + + pool.Run(0, num_tasks, [&](const uint64_t task, size_t thread) HWY_ATTR { + const size_t head = task % kHeads; + const size_t pos = task / kHeads; + float* HWY_RESTRICT head_att = + activations.att.data() + (pos * kHeads + head) * kSeqLen; + Softmax(head_att, pos + 1); + }); + + pool.Run(0, num_tasks, [&](const uint64_t task, size_t thread) HWY_ATTR { + const size_t head = task % kHeads; + const size_t pos = task / kHeads; + const float* HWY_RESTRICT head_att = + activations.att.data() + (pos * kHeads + head) * kSeqLen; + float* HWY_RESTRICT att_out = + activations.att_out.data() + (pos * kHeads + head) * kQKVDim; + hwy::ZeroBytes(att_out, kQKVDim * sizeof(*att_out)); + for (size_t pos2 = 0; pos2 <= pos; ++pos2) { + float* HWY_RESTRICT v2 = + activations.qkv.data() + (pos2 * (kHeads + 2) + kHeads + 1) * kQKVDim; + MulByConstAndAdd(head_att[pos2], v2, att_out, kQKVDim); + } + }); + + hwy::ZeroBytes(activations.attention_out.data(), + num_tokens * kModelDim * sizeof(activations.attention_out[0])); + for (size_t pos = 0; pos < num_tokens; ++pos) { + for (size_t head = 0; head < kHeads; ++head) { + MatVec( + weights.attn_vec_einsum_w, head * kModelDim * kQKVDim, + activations.att_out.data() + pos * kHeads * kQKVDim + head * kQKVDim, + nullptr, activations.att_post1.data() + pos * kModelDim, pool); + AddFrom(activations.att_post1.data() + pos * kModelDim, + activations.attention_out.data() + pos * kModelDim, kModelDim); + } + } + + for (size_t pos = 0; pos < num_tokens; ++pos) { + AddFrom(activations.input.data() + pos * kModelDim, + activations.attention_out.data() + pos * kModelDim, kModelDim); + } + + ApplyRMSNorm(weights.pre_ffw_norm_scale.data(), + activations.attention_out.data(), kModelDim, num_tokens, + activations.bf_pre_ffw_rms_out.data(), pool); + static constexpr size_t kFFHiddenDim = TConfig::kFFHiddenDim; + for (size_t pos = 0; pos < num_tokens; ++pos) { + MatVec( + weights.gating_einsum_w, 0, + activations.bf_pre_ffw_rms_out.data() + pos * kModelDim, nullptr, + activations.ffw_hidden.data() + pos * kFFHiddenDim * 2, pool); + } + for (size_t pos = 0; pos < num_tokens; ++pos) { + const size_t hidden_offset = pos * kFFHiddenDim * 2; + const float* HWY_RESTRICT out = + activations.ffw_hidden.data() + hidden_offset; + const float* HWY_RESTRICT out_mul = out + kFFHiddenDim; + float* HWY_RESTRICT out_gated = + activations.ffw_hidden_gated.data() + pos * kFFHiddenDim; + namespace hn = hwy::HWY_NAMESPACE; + using DF = hn::ScalableTag; + using VF = hn::Vec; + DF df; + for (size_t i = 0; i < kFFHiddenDim; i += Lanes(df)) { + const auto y = Load(df, out + i); + const auto x = Load(df, out_mul + i); + hn::Store(hn::Mul(x, Gelu(df, y)), df, out_gated + i); + } + } + for (size_t pos = 0; pos < num_tokens; ++pos) { + MatVec( + weights.linear_w, 0, + activations.ffw_hidden_gated.data() + pos * kFFHiddenDim, + nullptr, output + pos * kModelDim, pool); + } + for (size_t pos = 0; pos < num_tokens; ++pos) { + AddFrom(activations.attention_out.data() + pos * kModelDim, + output + pos * kModelDim, kModelDim); + } +} + +template typename WeightsT, + template typename LayerT> +float CrossEntropyLossForwardPass(const std::vector& prompt, + size_t context_size, + const WeightsT& weights, + ForwardPass& forward, + hwy::ThreadPool& pool) { + static constexpr size_t kVocabSize = TConfig::kVocabSize; + 
static constexpr size_t kModelDim = TConfig::kModelDim; + static constexpr size_t kSeqLen = TConfig::kSeqLen; + static constexpr size_t kLayers = TConfig::kLayers; + const float kEmbScaling = EmbeddingScaling(); + static_assert(!TConfig::kAbsolutePE); + static_assert(!TConfig::kPostNormScale); + static_assert(TConfig::kKVHeads == 1); + + HWY_DASSERT(context_size > 0); + HWY_DASSERT(context_size < prompt.size()); + const size_t num_tokens = prompt.size() - 1; + + InputEmbedding(weights.embedder_input_embedding, prompt, kEmbScaling, + forward.layers[0].input.data(), kModelDim); + + for (size_t layer = 0; layer < kLayers; ++layer) { + auto type = TConfig::kLayerConfig[layer]; + // TODO(szabadka) Implement Griffin layer. + HWY_ASSERT(type == LayerAttentionType::kGemma); + float* HWY_RESTRICT output = layer + 1 < kLayers ? + forward.layers[layer + 1].input.data() : + forward.final_layer_output.data(); + ApplyForwardLayer( + *weights.GetLayer(layer), forward.layers[layer], + num_tokens, output, pool); + } + + ApplyRMSNorm(weights.final_norm_scale.data(), + forward.final_layer_output.data(), + kModelDim, num_tokens, forward.final_norm_output.data(), pool); + + for (size_t pos = 0; pos < num_tokens; ++pos) { + MatVec( + weights.embedder_input_embedding, 0, + forward.final_norm_output.data() + pos * kModelDim, nullptr, + forward.logits.data() + pos * kVocabSize, pool); + } + + for (size_t pos = 0; pos < num_tokens; ++pos) { + LogitsSoftCap(30.0f, forward.logits.data() + pos * kVocabSize, kVocabSize); + } + + hwy::CopyBytes(forward.logits.data(), forward.probs.data(), + num_tokens * kVocabSize * sizeof(forward.logits[0])); + + for (size_t pos = 0; pos < num_tokens; ++pos) { + Softmax(forward.probs.data() + pos * kVocabSize, kVocabSize); + } + + return CrossEntropyLoss(forward.probs.data(), prompt, context_size, + kVocabSize, pool); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace gcpp +HWY_AFTER_NAMESPACE(); + +#endif // NOLINT diff --git a/backprop/forward.cc b/backprop/forward.cc new file mode 100644 index 0000000..fb67c2a --- /dev/null +++ b/backprop/forward.cc @@ -0,0 +1,80 @@ +// Copyright 2024 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "backprop/forward.h" + +// Compiles this file for multiple architectures via "foreach_target.h", to +// which we pass the filename via macro 'argument'. 
+#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "backprop/forward.cc" // NOLINT +#include "hwy/foreach_target.h" // IWYU pragma: keep + +#include "hwy/highway.h" +// After highway.h +#include "backprop/forward-inl.h" +#include "gemma/weights.h" + +HWY_BEFORE_NAMESPACE(); +namespace gcpp { +namespace HWY_NAMESPACE { +namespace hn = hwy::HWY_NAMESPACE; + +template +float CrossEntropyLossForwardPass(const Prompt& prompt, + const ByteStorageT& weights_u8, + ByteStorageT& forward_u8, + hwy::ThreadPool& pool) { + const auto& weights = + *reinterpret_cast*>(weights_u8.get()); + auto& forward = + *reinterpret_cast*>(forward_u8.get()); + return CrossEntropyLossForwardPass( + prompt.tokens, prompt.context_size, weights, forward, pool); +} + +float CrossEntropyLossForwardPassT(Model model, const Prompt& prompt, + const ByteStorageT& weights, + ByteStorageT& forward, + hwy::ThreadPool& pool) { + switch (model) { + case Model::GEMMA_2B: + return CrossEntropyLossForwardPass( + prompt, weights, forward, pool); + case Model::GEMMA_TINY: + return CrossEntropyLossForwardPass( + prompt, weights, forward, pool); + default: + HWY_ABORT("Model type %d unknown.", static_cast(model)); + } +} + +} // namespace HWY_NAMESPACE +} // namespace gcpp +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace gcpp { + +HWY_EXPORT(CrossEntropyLossForwardPassT); + +float CrossEntropyLossForwardPass( + const Model& model, const Prompt& prompt, const ByteStorageT& weights, + ByteStorageT& forward, hwy::ThreadPool& pool) { + return HWY_DYNAMIC_DISPATCH(CrossEntropyLossForwardPassT)( + model, prompt, weights, forward, pool); +} + +} // namespace gcpp +#endif // HWY_ONCE diff --git a/backprop/forward.h b/backprop/forward.h new file mode 100644 index 0000000..f17c898 --- /dev/null +++ b/backprop/forward.h @@ -0,0 +1,33 @@ +// Copyright 2024 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_FORWARD_H_ +#define THIRD_PARTY_GEMMA_CPP_GEMMA_FORWARD_H_ + +#include + +#include "backprop/prompt.h" +#include "gemma/common.h" +#include "hwy/contrib/thread_pool/thread_pool.h" + +namespace gcpp { + +float CrossEntropyLossForwardPass( + const Model& model, const Prompt& prompt, const ByteStorageT& weights, + ByteStorageT& forward, hwy::ThreadPool& pool); + +} // namespace gcpp + +#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_FORWARD_H_ diff --git a/backprop/forward_scalar.h b/backprop/forward_scalar.h new file mode 100644 index 0000000..5643530 --- /dev/null +++ b/backprop/forward_scalar.h @@ -0,0 +1,288 @@ +// Copyright 2024 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_FORWARD_SCALAR_H_ +#define THIRD_PARTY_GEMMA_CPP_GEMMA_FORWARD_SCALAR_H_ + +#include +#include + +#include +#include +#include + +#include "backprop/common_scalar.h" +#include "backprop/prompt.h" +#include "gemma/activations.h" +#include "gemma/weights.h" + +namespace gcpp { + +// w is N x M matrix in row-major order, x is M x K matrix in column-major order +// y = w * x is N x K matrix in column-major order. +template +void MatMulT(const T* w, const T* x, T* y, size_t N, size_t M, size_t K) { + for (size_t i = 0; i < K; ++i) { + for (size_t j = 0; j < N; ++j) { + y[i * N + j] = DotT(&w[j * M], &x[i * M], M); + } + } +} + +// w is H concatenated N x M matrix in row-major order, x is HM x K matrix in +// column-major order and y = w' * x is N x K matrix in column-major order, +// where w' is the rearrangement of w into an N x HM matrix. +template +void MultiHeadMatMul(const T* w, const T* x, T* y, size_t H, size_t N, + size_t M, size_t K) { + memset(y, 0, N * K * sizeof(y[0])); + for (size_t i = 0; i < K; ++i) { + for (size_t h = 0; h < H; ++h) { + for (size_t j = 0; j < N; ++j) { + y[i * N + j] += DotT(&w[h * N * M + j * M], &x[i * H * M + h * M], M); + } + } + } +} + +template +void RMSNormT(const T* w, const T* x, T* out, size_t N, size_t K) { + constexpr T eps(1e-6); + for (size_t i = 0; i < K; ++i) { + T ss = SquaredL2(x + i * N, N); + ss = T(1.0) / std::sqrt(ss / T(N) + eps); + for (size_t j = 0; j < N; j++) { + out[i * N + j] = (T(1.0) + w[j]) * (ss * x[i * N + j]); + } + } +} +template +void Softmax(T* x, size_t N) { + T sum = {}; + auto maxreal = std::real(x[0]); + for (size_t i = 1; i < N; ++i) { + if (std::real(x[i]) > maxreal) { + maxreal = std::real(x[i]); + } + } + for (size_t i = 0; i < N; ++i) { + x[i] = std::exp(x[i] - maxreal); + sum += x[i]; + } + T scale = T(1.0) / sum; + for (size_t i = 0; i < N; ++i) { + x[i] *= scale; + } +} +template +void Softmax(T* x, size_t N, size_t K) { + for (size_t i = 0; i < K; ++i) { + Softmax(x + i * N, N); + } +} +template +void Softcap(T* x, size_t N) { + T cap = 30.0; + T inv_cap = T(1.0) / cap; + for (size_t i = 0; i < N; ++i) { + x[i] = cap * std::tanh(x[i] * inv_cap); + } +} + +template +void GatedGelu(const T* in, T* out, size_t N, size_t K) { + for (size_t i = 0; i < K; ++i) { + const T* x1 = in + i * 2 * N; + const T* x2 = x1 + N; + T* y = out + i * N; + for (size_t j = 0; j < N; ++j) { + y[j] = x2[j] * Gelu(x1[j]); + } + } +} + +template +void InputEmbedding(const T* w, const std::vector& tokens, T scaling, + T* y, size_t N) { + const size_t num_tokens = tokens.empty() ? 
0 : tokens.size() - 1; + for (size_t i = 0; i < num_tokens; ++i) { + int token = tokens[i]; + memcpy(y + i * N, w + token * N, N * sizeof(y[0])); + MulByConstT(scaling, y + i * N, N); + } +} + +template +void MaskedAttention(const T* qkv, T* output, size_t num_tokens, + size_t kHeads, size_t kQKVDim, size_t kSeqLen) { + for (size_t pos = 0; pos < num_tokens; ++pos) { + for (size_t head = 0; head < kHeads; ++head) { + const size_t qoffset = pos * (kHeads + 2) * kQKVDim; + const size_t aoffset = pos * kHeads * kSeqLen + head * kSeqLen; + const T* q = qkv + qoffset + head * kQKVDim; + for (size_t pos2 = 0; pos2 <= pos; ++pos2) { + const T* k = qkv + (pos2 * (kHeads + 2) + kHeads) * kQKVDim; + output[aoffset + pos2] = DotT(q, k, kQKVDim); + } + } + } +} +template +void MaskedSoftmax(T* x, size_t num_tokens, size_t kHeads, size_t kSeqLen) { + for (size_t pos = 0; pos < num_tokens; ++pos) { + for (size_t head = 0; head < kHeads; ++head) { + size_t offset = pos * kHeads * kSeqLen + head * kSeqLen; + Softmax(x + offset, pos + 1); + memset(x + offset + pos + 1, 0, (kSeqLen - pos - 1) * sizeof(T)); + } + } +} +template +void MixByAttention(const T* qkv, const T* attention, T* output, + size_t num_tokens, size_t kHeads, size_t kQKVDim, + size_t kSeqLen) { + for (size_t pos = 0; pos < num_tokens; ++pos) { + for (size_t head = 0; head < kHeads; ++head) { + const T* att = &attention[pos * kHeads * kSeqLen + head * kSeqLen]; + T* out = &output[head * kQKVDim + pos * kHeads * kQKVDim]; + memset(out, 0, kQKVDim * sizeof(out[0])); + for (size_t pos2 = 0; pos2 <= pos; ++pos2) { + size_t v_offset = (pos2 * (kHeads + 2) + kHeads + 1) * kQKVDim; + const T* v = &qkv[v_offset]; + MulByConstAndAddT(att[pos2], v, out, kQKVDim); + } + } + } +} +template +void ApplyLayer(const Layer& weights, + ForwardLayer& activations, + size_t num_tokens, T* output) { + static constexpr size_t kModelDim = TConfig::kModelDim; + static constexpr size_t kSeqLen = TConfig::kSeqLen; + static constexpr size_t kQKVDim = TConfig::kQKVDim; + static constexpr size_t kHeads = TConfig::kHeads; + static constexpr size_t kFFHiddenDim = TConfig::kFFHiddenDim; + static const T kQueryScale = T(1.0) / std::sqrt(T(kQKVDim)); + + RMSNormT(weights.pre_attention_norm_scale.data(), activations.input.data(), + activations.pre_att_rms_out.data(), kModelDim, num_tokens); + + MatMulT(weights.qkv_einsum_w.data(), activations.pre_att_rms_out.data(), + activations.qkv.data(), (kHeads + 2) * kQKVDim, kModelDim, + num_tokens); + + for (size_t pos = 0; pos < num_tokens; ++pos) { + T* qkv = activations.qkv.data() + pos * (kHeads + 2) * kQKVDim; + for (size_t h = 0; h <= kHeads; ++h) { + Rope(qkv + h * kQKVDim, kQKVDim, pos); + } + } + + for (size_t pos = 0; pos < num_tokens; ++pos) { + T* qkv = activations.qkv.data() + pos * (kHeads + 2) * kQKVDim; + MulByConstT(kQueryScale, qkv, kHeads * kQKVDim); + } + + MaskedAttention(activations.qkv.data(), activations.att.data(), + num_tokens, kHeads, kQKVDim, kSeqLen); + + MaskedSoftmax(activations.att.data(), num_tokens, kHeads, kSeqLen); + + MixByAttention(activations.qkv.data(), activations.att.data(), + activations.att_out.data(), num_tokens, kHeads, kQKVDim, + kSeqLen); + + MultiHeadMatMul(weights.attn_vec_einsum_w.data(), activations.att_out.data(), + activations.attention_out.data(), kHeads, kModelDim, kQKVDim, + num_tokens); + + AddFromT(activations.input.data(), activations.attention_out.data(), + num_tokens * kModelDim); + + RMSNormT(weights.pre_ffw_norm_scale.data(), activations.attention_out.data(), + 
activations.bf_pre_ffw_rms_out.data(), kModelDim, num_tokens); + + MatMulT(weights.gating_einsum_w.data(), activations.bf_pre_ffw_rms_out.data(), + activations.ffw_hidden.data(), kFFHiddenDim * 2, kModelDim, + num_tokens); + + GatedGelu(activations.ffw_hidden.data(), activations.ffw_hidden_gated.data(), + kFFHiddenDim, num_tokens); + + MatMulT(weights.linear_w.data(), activations.ffw_hidden_gated.data(), + output, kModelDim, kFFHiddenDim, num_tokens); + + AddFromT(activations.attention_out.data(), output, num_tokens * kModelDim); +} + +template +T CrossEntropyLoss(const T* x, const Prompt& prompt, size_t V) { + T loss = {}; + const std::vector tokens = prompt.tokens; + const size_t num_tokens = tokens.empty() ? 0 : tokens.size() - 1; + for (size_t i = 0; i < num_tokens; ++i) { + if (i + 1 < prompt.context_size) { + continue; // next token is part of context, don't try to predict it + } + const int next_token = tokens[i + 1]; + loss += std::log(x[i * V + next_token]); + } + T scaling = -1.0 / std::log(2.0); + return loss * scaling; +} + +template +T CrossEntropyLossForwardPass(const Prompt& prompt, + const Weights& weights, + ForwardPass& forward) { + static constexpr size_t kModelDim = TConfig::kModelDim; + static constexpr size_t kVocabSize = TConfig::kVocabSize; + static constexpr size_t kLayers = TConfig::kLayers; + const std::vector tokens = prompt.tokens; + const size_t num_tokens = tokens.empty() ? 0 : tokens.size() - 1; + + const T kEmbScaling = EmbeddingScaling(kModelDim); + InputEmbedding(weights.embedder_input_embedding.data(), tokens, + kEmbScaling, forward.layers[0].input.data(), kModelDim); + + for (size_t layer = 0; layer < kLayers; ++layer) { + T* output = layer + 1 < kLayers ? + forward.layers[layer + 1].input.data() : + forward.final_layer_output.data(); + ApplyLayer(*weights.GetLayer(layer), forward.layers[layer], num_tokens, + output); + } + + RMSNormT(weights.final_norm_scale.data(), + forward.final_layer_output.data(), + forward.final_norm_output.data(), kModelDim, num_tokens); + + MatMulT(weights.embedder_input_embedding.data(), + forward.final_norm_output.data(), + forward.logits.data(), kVocabSize, kModelDim, num_tokens); + + Softcap(forward.logits.data(), num_tokens * kVocabSize); + + memcpy(forward.probs.data(), forward.logits.data(), + num_tokens * kVocabSize * sizeof(forward.logits[0])); + Softmax(forward.probs.data(), kVocabSize, num_tokens); + + return CrossEntropyLoss(forward.probs.data(), prompt, kVocabSize); +} + +} // namespace gcpp + +#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_FORWARD_SCALAR_H_ diff --git a/backprop/optimize_test.cc b/backprop/optimize_test.cc new file mode 100644 index 0000000..5c354a9 --- /dev/null +++ b/backprop/optimize_test.cc @@ -0,0 +1,127 @@ +// Copyright 2024 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
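A note on units: CrossEntropyLoss above multiplies the summed log-probabilities by -1/ln(2), so losses (including the total_loss printed by the test below) are measured in bits. A toy computation mirroring that loop, with invented probabilities and tokens:

    #include <stddef.h>

    #include <cmath>
    #include <cstdio>
    #include <vector>

    int main() {
      const size_t kVocab = 4;
      // Softmax outputs for two predicted positions (values invented).
      const std::vector<double> probs = {0.1, 0.6, 0.2, 0.1,
                                         0.7, 0.1, 0.1, 0.1};
      const std::vector<int> tokens = {3, 1, 0};  // position i predicts tokens[i+1]
      const size_t context_size = 1;  // tokens inside the context are not scored
      double loss = 0.0;
      for (size_t i = 0; i + 1 < tokens.size(); ++i) {
        if (i + 1 < context_size) continue;
        loss += std::log(probs[i * kVocab + tokens[i + 1]]);
      }
      // -(log2(0.6) + log2(0.7)) ~= 1.25 bits.
      std::printf("loss: %.3f bits\n", loss * (-1.0 / std::log(2.0)));
      return 0;
    }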
+ +#include +#include + +#include "backprop/backward.h" +#include "backprop/forward.h" +#include "backprop/optimizer.h" +#include "backprop/sampler.h" +#include "gemma/activations.h" +#include "gemma/gemma.h" +#include "gemma/weights.h" +#include "gtest/gtest.h" + +namespace gcpp { + +TEST(OptimizeTest, GradientDescent) { + hwy::ThreadPool pool(0); + std::mt19937 gen(42); + + Model model_type = Model::GEMMA_TINY; + ByteStorageT weights = AllocateWeights(model_type, pool); + ByteStorageT grad = AllocateWeights(model_type, pool); + ByteStorageT grad_m = AllocateWeights(model_type, pool); + ByteStorageT grad_v = AllocateWeights(model_type, pool); + ByteStorageT forward = AllocateForwardPass(model_type); + ByteStorageT backward = AllocateForwardPass(model_type); + ByteStorageT inference = AllocateInferenceState(model_type); + auto kv_cache = CreateKVCache(model_type); + size_t max_tokens = 32; + size_t max_generated_tokens = 16; + float temperature = 1.0f; + int verbosity = 0; + const auto accept_token = [](int) { return true; }; + + const auto generate = [&](const std::vector& prompt) { + std::vector reply; + auto stream_token = [&reply](int token, float) { + reply.push_back(token); + return token != ReverseSequenceSampler::kEndToken; + }; + RuntimeConfig runtime = { + max_tokens, max_generated_tokens, temperature, verbosity, &gen, + stream_token, accept_token, ReverseSequenceSampler::kEndToken, + }; + TimingInfo timing_info; + GenerateGemma(model_type, weights, inference, runtime, prompt, 0, + kv_cache, pool, timing_info); + return reply; + }; + + auto verify = [&](const Prompt& prompt) { + auto context = prompt.context(); + std::vector reply = generate(context); + bool ok = true; + for (size_t i = 0; ok && i < prompt.tokens.size(); ++i) { + if (i >= reply.size() || reply[i] != prompt.tokens[i]) { + ok = false; + } + } + return ok; + }; + + RandInitWeights(model_type, weights, pool, gen); + ZeroInitWeights(model_type, grad_m, pool); + ZeroInitWeights(model_type, grad_v, pool); + + printf("Initial weights:\n"); + LogWeightStats(model_type, weights); + + constexpr size_t kBatchSize = 8; + float learning_rate = 0.0005f; + + ReverseSequenceSampler training_task({ + 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1}); + size_t steps = 0; + float prev_loss = std::numeric_limits::max(); + size_t num_ok; + for (; steps < 1000000; ++steps) { + std::mt19937 sgen(42); + ZeroInitWeights(model_type, grad, pool); + float total_loss = 0.0f; + num_ok = 0; + for (size_t i = 0; i < kBatchSize; ++i) { + Prompt prompt = training_task.Sample(sgen); + total_loss += CrossEntropyLossForwardPass( + model_type, prompt, weights, forward, pool); + CrossEntropyLossBackwardPass( + model_type, prompt, weights, forward, grad, backward, pool); + num_ok += verify(prompt) ? 
1 : 0; + } + total_loss /= kBatchSize; + + const float scale = -learning_rate / kBatchSize; + UpdateWeights(model_type, grad, scale, weights, pool); + printf("step: %zu total_loss: %.15f num_ok: %zu/%zu\n", + steps, total_loss, num_ok, kBatchSize); + if (steps % 100 == 0) { + printf("Batch gradient:\n"); + LogWeightStats(model_type, grad); + } + if (total_loss < 0.5f) { + break; + } + prev_loss = total_loss; + } + printf("Num steps: %zu\n", steps); + printf("Final weights:\n"); + LogWeightStats(model_type, weights); + EXPECT_LT(steps, 3000); + EXPECT_EQ(num_ok, kBatchSize); +} + +} // namespace gcpp diff --git a/backprop/optimizer.cc b/backprop/optimizer.cc new file mode 100644 index 0000000..5b1c61d --- /dev/null +++ b/backprop/optimizer.cc @@ -0,0 +1,121 @@ +// Copyright 2024 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "backprop/optimizer.h" + +#include + +#include "gemma/common.h" +#include "gemma/configs.h" +#include "gemma/weights.h" +#include "hwy/contrib/thread_pool/thread_pool.h" + +namespace gcpp { + +namespace { +class WeightInitializer { + public: + WeightInitializer(std::mt19937& gen) : dist_(0.0f, 1.0f), gen_(gen) {} + + template + void operator()(const char* name, std::array& tensor) { + for (size_t i = 0; i < N; ++i) { + tensor[i] = dist_(gen_); + } + } + private: + std::normal_distribution dist_; + std::mt19937& gen_; +}; + +template +void RandInitWeights(ByteStorageT& weights_u8, hwy::ThreadPool& pool, + std::mt19937& gen) { + auto& weights = *reinterpret_cast*>(weights_u8.get()); + // TODO(szabadka) Use the same weight initialization method as in the python + // version. 
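+  // For now, every tensor element is drawn i.i.d. from N(0, 1).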
+  WeightInitializer init(gen);
+  ForEachTensor1(init, weights);
+}
+
+class WeightUpdater {
+ public:
+  explicit WeightUpdater(float lr) : lr_(lr) {}
+
+  template <size_t kCapacity>
+  void operator()(const char* name, const std::array<float, kCapacity>& grad,
+                  std::array<float, kCapacity>& weights) {
+    for (size_t i = 0; i < kCapacity; ++i) {
+      weights[i] += lr_ * grad[i];
+    }
+  }
+
+ private:
+  float lr_;
+};
+
+template <typename TConfig>
+void UpdateWeights(const ByteStorageT& grad_u8, float scale,
+                   ByteStorageT& weights_u8, hwy::ThreadPool& pool) {
+  const auto& grad =
+      *reinterpret_cast<const WeightsF<TConfig>*>(grad_u8.get());
+  auto& weights = *reinterpret_cast<WeightsF<TConfig>*>(weights_u8.get());
+  WeightUpdater updater(scale);
+  ForEachTensor2(updater, grad, weights);
+}
+
+}  // namespace
+
+void RandInitWeights(Model model, ByteStorageT& weights_u8,
+                     hwy::ThreadPool& pool, std::mt19937& gen) {
+  switch (model) {
+    case Model::GEMMA_2B:
+      RandInitWeights<ConfigGemma2B>(weights_u8, pool, gen);
+      break;
+    case Model::GEMMA_7B:
+      RandInitWeights<ConfigGemma7B>(weights_u8, pool, gen);
+      break;
+    case Model::GRIFFIN_2B:
+      RandInitWeights<ConfigGriffin2B>(weights_u8, pool, gen);
+      break;
+    case Model::GEMMA_TINY:
+      RandInitWeights<ConfigGemmaTiny>(weights_u8, pool, gen);
+      break;
+    default:
+      HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
+  }
+}
+
+void UpdateWeights(Model model, const ByteStorageT& grad, float scale,
+                   ByteStorageT& weights, hwy::ThreadPool& pool) {
+  switch (model) {
+    case Model::GEMMA_2B:
+      UpdateWeights<ConfigGemma2B>(grad, scale, weights, pool);
+      break;
+    case Model::GEMMA_7B:
+      UpdateWeights<ConfigGemma7B>(grad, scale, weights, pool);
+      break;
+    case Model::GRIFFIN_2B:
+      UpdateWeights<ConfigGriffin2B>(grad, scale, weights, pool);
+      break;
+    case Model::GEMMA_TINY:
+      UpdateWeights<ConfigGemmaTiny>(grad, scale, weights, pool);
+      break;
+    default:
+      HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
+  }
+}
+
+}  // namespace gcpp
diff --git a/backprop/optimizer.h b/backprop/optimizer.h
new file mode 100644
index 0000000..343db97
--- /dev/null
+++ b/backprop/optimizer.h
@@ -0,0 +1,34 @@
+// Copyright 2024 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_OPTIMIZER_H_
+#define THIRD_PARTY_GEMMA_CPP_GEMMA_OPTIMIZER_H_
+
+#include <random>
+
+#include "gemma/common.h"
+#include "hwy/contrib/thread_pool/thread_pool.h"
+
+namespace gcpp {
+
+void RandInitWeights(Model model, ByteStorageT& weights, hwy::ThreadPool& pool,
+                     std::mt19937& gen);
+
+void UpdateWeights(Model model, const ByteStorageT& grad, float scale,
+                   ByteStorageT& weights, hwy::ThreadPool& pool);
+
+}  // namespace gcpp
+
+#endif  // THIRD_PARTY_GEMMA_CPP_GEMMA_OPTIMIZER_H_
diff --git a/backprop/prompt.h b/backprop/prompt.h
new file mode 100644
index 0000000..76acb56
--- /dev/null
+++ b/backprop/prompt.h
@@ -0,0 +1,34 @@
+// Copyright 2024 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_PROMPT_H_
+#define THIRD_PARTY_GEMMA_CPP_GEMMA_PROMPT_H_
+
+#include <stddef.h>
+#include <vector>
+
+namespace gcpp {
+
+struct Prompt {
+  std::vector<int> tokens;
+  size_t context_size;
+  std::vector<int> context() const {
+    return std::vector<int>(tokens.begin(), tokens.begin() + context_size);
+  }
+};
+
+}  // namespace gcpp
+
+#endif  // THIRD_PARTY_GEMMA_CPP_GEMMA_PROMPT_H_
diff --git a/backprop/sampler.h b/backprop/sampler.h
new file mode 100644
index 0000000..257b993
--- /dev/null
+++ b/backprop/sampler.h
@@ -0,0 +1,84 @@
+// Copyright 2024 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_SAMPLER_H_
+#define THIRD_PARTY_GEMMA_CPP_GEMMA_SAMPLER_H_
+
+#include <random>
+
+#include "backprop/prompt.h"
+
+namespace gcpp {
+
+class PromptSampler {
+ public:
+  virtual Prompt Sample(std::mt19937& gen) = 0;
+
+  std::vector<Prompt> SampleBatch(size_t batch_size, std::mt19937& gen) {
+    std::vector<Prompt> batch;
+    batch.reserve(batch_size);
+    for (size_t i = 0; i < batch_size; ++i) {
+      batch.emplace_back(Sample(gen));
+    }
+    return batch;
+  }
+};
+
+class ReverseSequenceSampler : public PromptSampler {
+ public:
+  explicit ReverseSequenceSampler(const std::vector<int>& length_histo)
+      : token_dist_(0, 9) {
+    for (int i = 0; i < length_histo.size(); ++i) {
+      const int count = length_histo[i];
+      for (int j = 0; j < count; ++j) {
+        length_lut_.push_back(i + 1);
+      }
+    }
+    length_dist_ = std::uniform_int_distribution<>(0, length_lut_.size() - 1);
+  }
+
+  static constexpr int kReverseToken = 10;
+  static constexpr int kEndToken = 11;
+
+  Prompt Sample(std::mt19937& gen) override {
+    Prompt prompt;
+    int len = length_lut_[length_dist_(gen)];
+    prompt.tokens.resize(2 * len + 2);
+    prompt.tokens[len] = kReverseToken;
+    prompt.tokens[2 * len + 1] = kEndToken;
+    for (size_t i = 0; i < len; ++i) {
+      prompt.tokens[i] = prompt.tokens[2 * len - i] = token_dist_(gen);
+    }
+    prompt.context_size = len + 1;
+    return prompt;
+  }
+
+  static void LogPrompt(const Prompt& prompt) {
+    static const char* kVocab[] = {
+        "0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "-->", "|",
+    };
+    for (int token : prompt.tokens) printf("%s", kVocab[token]);
+    printf(" [context_size: %zu]\n", prompt.context_size);
+  }
+
+ private:
+  std::uniform_int_distribution<> token_dist_;
+  std::uniform_int_distribution<> length_dist_;
+  std::vector<int> length_lut_;
+};
+
+}  // namespace gcpp
+
+#endif  // THIRD_PARTY_GEMMA_CPP_GEMMA_SAMPLER_H_
diff --git a/backprop/test_util.h b/backprop/test_util.h
new file mode 100644
index 0000000..939411d
--- /dev/null
+++ b/backprop/test_util.h @@ -0,0 +1,195 @@ +// Copyright 2024 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_TEST_UTIL_H_ +#define THIRD_PARTY_GEMMA_CPP_GEMMA_TEST_UTIL_H_ + +#include +#include +#include + +#include "gemma/weights.h" +#include "gtest/gtest.h" + +namespace gcpp { + +template +void RandInit(std::array& x, T stddev, std::mt19937& gen) { + std::normal_distribution dist(0.0, stddev); + for (size_t i = 0; i < kLen; ++i) { + x[i] = dist(gen); + } +} + +template +void RandInit(Layer& w, T stddev, std::mt19937& gen) { + RandInit(w.pre_attention_norm_scale, stddev, gen); + RandInit(w.attn_vec_einsum_w, stddev, gen); + RandInit(w.qkv_einsum_w, stddev, gen); + RandInit(w.pre_ffw_norm_scale, stddev, gen); + RandInit(w.gating_einsum_w, stddev, gen); + RandInit(w.linear_w, stddev, gen); +} + +template +void RandInit(Weights& w, T stddev, std::mt19937& gen) { + static constexpr size_t kLayers = TConfig::kLayers; + RandInit(w.embedder_input_embedding, stddev, gen); + RandInit(w.final_norm_scale, stddev, gen); + for (size_t i = 0; i < kLayers; ++i) { + RandInit(*w.GetLayer(i), stddev, gen); + } +} + +template +void Complexify(const std::array& x, + std::array, kLen>& c_x) { + for (size_t i = 0; i < kLen; ++i) { + c_x[i] = std::complex(x[i], 0.0); + } +} + + +template +void Complexify(const Layer& w, + Layer, TConfig>& c_w) { + Complexify(w.pre_attention_norm_scale, c_w.pre_attention_norm_scale); + Complexify(w.attn_vec_einsum_w, c_w.attn_vec_einsum_w); + Complexify(w.qkv_einsum_w, c_w.qkv_einsum_w); + Complexify(w.pre_ffw_norm_scale, c_w.pre_ffw_norm_scale); + Complexify(w.gating_einsum_w, c_w.gating_einsum_w); + Complexify(w.linear_w, c_w.linear_w); +} + +template +void Complexify(const Weights& w, + Weights, TConfig>& c_w) { + static constexpr size_t kLayers = TConfig::kLayers; + Complexify(w.embedder_input_embedding, c_w.embedder_input_embedding); + Complexify(w.final_norm_scale, c_w.final_norm_scale); + for (size_t i = 0; i < kLayers; ++i) { + Complexify(*w.GetLayer(i), *c_w.GetLayer(i)); + } +} + +template +void TestNear(const std::array& actual, const std::array& expected, + double max_abs_err, double max_rel_err, int line) { + double sum0 = 0; + double sum1 = 0; + double sum01 = 0; + for (size_t i = 0; i < N; ++i) { + sum0 += actual[i] * actual[i]; + sum1 += expected[i] * expected[i]; + sum01 += actual[i] * expected[i]; + ASSERT_NEAR(actual[i], expected[i], + std::max(max_abs_err, std::abs(expected[i]) * max_rel_err)) + << "line: " << line << " dim=" << N << " i=" << i; + } + if (sum0 > 1e-40) { + double norm_dot = sum01 / std::sqrt(sum0) / std::sqrt(sum1); + ASSERT_NEAR(norm_dot, 1.0, 1e-7) + << "line: " << line << " sum0: " << sum0 << " sum1: " << sum1 + << " sum01: " << sum01; + } +} + +// Compute gradient with the finite difference method in the complex plane. 
+// If f : R->R is the tested function and F : C->C is its extension on the +// complex plane so that F is complex differentiable in x, then +// +// F(x + ih) = F(x) + ih F'(x) + O(h^2) F''(x) +// +// which means that +// +// F'(x) ~= Imag(F(x + ih)) / h +// +// This method is more numerically stable than the real-valued finite difference +// method since we don't need to subtract floating point numbers that are near +// to each other. +template +void TestGradient(const std::array& grad, + std::array, N>& x, FUNC func, + U step, T max_abs_err, T max_rel_err, int line) { + std::array exp_grad; + const U inv_step = 1.0 / step; + for (size_t i = 0; i < N; ++i) { + const U x0 = std::real(x[i]); + const std::complex x1 = std::complex(x0, step); + x[i] = x1; + const std::complex f1 = func(); + exp_grad [i] = std::imag(f1) * inv_step; + x[i] = x0; + } + TestNear(grad, exp_grad, max_abs_err, max_rel_err, line); +} + +template +void TestGradient(const std::array& grad, + std::array, N>& x, FUNC func, + float max_abs_err, float max_rel_error, int line) { + TestGradient(grad, x, func, 1e-30f, max_abs_err, max_rel_error, line); +} + +template +void TestGradient(const std::array& grad, + std::array, N>& x, FUNC func, + float max_abs_err, float max_rel_error, int line) { + TestGradient(grad, x, func, 1e-50, max_abs_err, max_rel_error, line); +} + +template +void TestGradient(const std::array& grad, + std::array, N>& x, FUNC func, + double max_abs_err, double max_rel_error, int line) { + TestGradient(grad, x, func, 1e-50, max_abs_err, max_rel_error, line); +} + +template +void TestGradient(const Layer& grad, + Layer, TConfig>& c_weights, + FUNC func, T max_err) { + TestGradient(grad.pre_attention_norm_scale, + c_weights.pre_attention_norm_scale, + func, max_err, max_err, __LINE__); + TestGradient(grad.attn_vec_einsum_w, c_weights.attn_vec_einsum_w, + func, max_err, max_err, __LINE__); + TestGradient(grad.qkv_einsum_w, c_weights.qkv_einsum_w, + func, max_err, max_err, __LINE__); + TestGradient(grad.pre_ffw_norm_scale, c_weights.pre_ffw_norm_scale, + func, max_err, max_err, __LINE__); + TestGradient(grad.gating_einsum_w, c_weights.gating_einsum_w, + func, max_err, max_err, __LINE__); + TestGradient(grad.linear_w, c_weights.linear_w, + func, max_err, max_err, __LINE__); +} + +template +void TestGradient(const Weights& grad, + Weights, TConfig>& c_weights, + FUNC func, T max_err) { + TestGradient(grad.embedder_input_embedding, + c_weights.embedder_input_embedding, + func, 2 * max_err, max_err, __LINE__); + TestGradient(grad.final_norm_scale, c_weights.final_norm_scale, + func, max_err, max_err, __LINE__); + for (int i = 0; i < TConfig::kLayers; ++i) { + TestGradient(*grad.GetLayer(i), *c_weights.GetLayer(i), func, max_err); + } +} + +} // namespace gcpp + +#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_TEST_UTIL_H_ diff --git a/compression/compress-inl.h b/compression/compress-inl.h index 1f18765..2bac834 100644 --- a/compression/compress-inl.h +++ b/compression/compress-inl.h @@ -484,9 +484,10 @@ HWY_INLINE void Decompress(const CompressedArray& compressed, const hn::ScalableTag d; const size_t ofs = idx_batch * kBatch; - const size_t num = idx_batch == num_batches - 1 ? (num - ofs) : kBatch; + const size_t batch = + idx_batch == num_batches - 1 ? 
(num - ofs) : kBatch; Traits::Decompress(d, compressed.size(), compressed.data(), - compressed_ofs + ofs, out + ofs, num); + compressed_ofs + ofs, out + ofs, batch); }); const double t1 = hwy::platform::Now(); @@ -495,6 +496,16 @@ HWY_INLINE void Decompress(const CompressedArray& compressed, fprintf(stderr, "Decompress %.1f MB/s\n", mbps); } +// Returns dot product with `vec_aligned` of length `num`. +template +HWY_INLINE float Dot(DF df, const std::array& w, size_t ofs, + const VecT* x, size_t num) { + HWY_DASSERT(ofs + num <= kCapacity); + HWY_DASSERT(hn::IsAligned(df, x)); + using Traits = CompressTraits; + return Traits::Dot(df, w.size(), w.data(), ofs, x, num); +} + // Returns dot product with `vec_aligned` of length `num`. template HWY_INLINE float Dot(DF df, const CompressedArray& compressed, diff --git a/gemma/activations.cc b/gemma/activations.cc new file mode 100644 index 0000000..56f8bde --- /dev/null +++ b/gemma/activations.cc @@ -0,0 +1,38 @@ +// Copyright 2024 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "gemma/activations.h" + +#include "gemma/common.h" +#include "gemma/configs.h" + +namespace gcpp { + +ByteStorageT AllocateForwardPass(Model model) { + switch (model) { + case Model::GEMMA_2B: + return ForwardPass::Allocate(); + case Model::GEMMA_7B: + return ForwardPass::Allocate(); + case Model::GRIFFIN_2B: + return ForwardPass::Allocate(); + case Model::GEMMA_TINY: + return ForwardPass::Allocate(); + default: + HWY_ABORT("Model type %d unknown.", static_cast(model)); + } +} + +} // namespace gcpp diff --git a/gemma/activations.h b/gemma/activations.h new file mode 100644 index 0000000..9a43344 --- /dev/null +++ b/gemma/activations.h @@ -0,0 +1,85 @@ +// Copyright 2024 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
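The header below is built around the same type-erasure pattern as the weights: a ForwardPass is allocated as raw aligned bytes (ByteStorageT), so code that doesn't know TConfig can still own it and pass it around, while typed code casts it back. A minimal sketch of the pattern, with an invented Blob standing in for ForwardPass:

    #include <stdint.h>
    #include <stdio.h>

    #include "hwy/aligned_allocator.h"

    // Stands in for ForwardPass<float, TConfig>: a large aggregate of plain
    // arrays that needs no real construction or destruction.
    struct Blob { float data[1024]; };

    int main() {
      // Type-erased owner: raw aligned bytes, as in ByteStorageT.
      hwy::AlignedFreeUniquePtr<uint8_t[]> storage =
          hwy::AllocateAligned<uint8_t>(sizeof(Blob));
      // Typed view used by code that knows the concrete config.
      Blob& blob = *reinterpret_cast<Blob*>(storage.get());
      blob.data[0] = 1.0f;
      printf("%g\n", blob.data[0]);
      return 0;
    }

This is safe only for payloads of plain arrays; ForwardPass's empty user-provided constructor exists just to keep placement-new from memset-ing the huge arrays.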
+
+#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_ACTIVATIONS_H_
+#define THIRD_PARTY_GEMMA_CPP_GEMMA_ACTIVATIONS_H_
+
+#include "gemma/common.h"
+#include "hwy/aligned_allocator.h"
+
+namespace gcpp {
+
+template <typename T, typename TConfig>
+struct ForwardLayer {
+  static constexpr size_t kSeqLen = TConfig::kSeqLen;
+  static constexpr size_t kModelDim = TConfig::kModelDim;
+  static constexpr size_t kQKVDim = TConfig::kQKVDim;
+  static constexpr size_t kHeads = TConfig::kHeads;
+  static constexpr size_t kFFHiddenDim = TConfig::kFFHiddenDim;
+  std::array<T, kSeqLen * kModelDim> input;
+  std::array<T, kSeqLen * kModelDim> pre_att_rms_out;
+  std::array<T, kSeqLen * (kHeads + 2) * kQKVDim> qkv;
+  std::array<T, kSeqLen * kHeads * kSeqLen> att;
+  std::array<T, kSeqLen * kHeads * kQKVDim> att_out;
+  std::array<T, kSeqLen * kModelDim> att_post1;
+  std::array<T, kSeqLen * kModelDim> attention_out;
+  std::array<T, kSeqLen * kModelDim> bf_pre_ffw_rms_out;
+  std::array<T, kSeqLen * kFFHiddenDim * 2> ffw_hidden;
+  std::array<T, kSeqLen * kFFHiddenDim> ffw_hidden_gated;
+};
+
+template <typename T, typename TConfig>
+struct ForwardPass {
+  ForwardPass() {}  // prevents placement-new calling memset
+
+  static constexpr size_t kSeqLen = TConfig::kSeqLen;
+  static constexpr size_t kModelDim = TConfig::kModelDim;
+  static constexpr size_t kVocabSize = TConfig::kVocabSize;
+  static constexpr size_t kLayers = TConfig::kLayers;
+
+  std::array<ForwardLayer<T, TConfig>, kLayers> layers;
+  std::array<T, kSeqLen * kModelDim> final_layer_output;
+  std::array<T, kSeqLen * kModelDim> final_norm_output;
+  std::array<T, kSeqLen * kVocabSize> logits;
+  std::array<T, kSeqLen * kVocabSize> probs;
+
+  static ByteStorageT Allocate() {
+    return hwy::AllocateAligned<uint8_t>(sizeof(ForwardPass));
+  }
+};
+
+// Owns activations and undoes the type erasure of AllocateAligned.
+template <typename T, typename TConfig>
+class ActivationsWrapper {
+  using WrappedT = ForwardPass<T, TConfig>;
+
+ public:
+  ActivationsWrapper()
+      : data_(WrappedT::Allocate()),
+        activations_(*reinterpret_cast<WrappedT*>(data_.get())) {}
+
+  const WrappedT& get() const { return activations_; }
+  WrappedT& get() { return activations_; }
+
+ private:
+  ByteStorageT data_;
+  WrappedT& activations_;
+};
+
+ByteStorageT AllocateForwardPass(Model model);
+
+}  // namespace gcpp
+
+#endif  // THIRD_PARTY_GEMMA_CPP_GEMMA_ACTIVATIONS_H_
diff --git a/gemma/common-inl.h b/gemma/common-inl.h
new file mode 100644
index 0000000..ac39d73
--- /dev/null
+++ b/gemma/common-inl.h
@@ -0,0 +1,68 @@
+// Copyright 2024 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Include guard for non-SIMD code.
+#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_INL_H_
+#define THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_INL_H_
+
+#include <stddef.h>
+#include <stdint.h>
+
+#include <cmath>
+#include <vector>
+
+#include "gemma/activations.h"
+
+#include "hwy/base.h"
+#include "hwy/contrib/thread_pool/thread_pool.h"
+
+#endif  // THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_INL_H_
+
+// Include guard for (potentially) SIMD code.
+#if defined(THIRD_PARTY_GEMMA_CPP_COMMON_TOGGLE) == defined(HWY_TARGET_TOGGLE)
+#ifdef THIRD_PARTY_GEMMA_CPP_COMMON_TOGGLE
+#undef THIRD_PARTY_GEMMA_CPP_COMMON_TOGGLE
+#else
+#define THIRD_PARTY_GEMMA_CPP_COMMON_TOGGLE
+#endif
+
+#include "gemma/ops.h"
+
+HWY_BEFORE_NAMESPACE();
+namespace gcpp {
+namespace HWY_NAMESPACE {
+namespace hn = hwy::HWY_NAMESPACE;
+
+// `EmbeddingScaling` can be constexpr only if `Sqrt` and
+// `hwy::ConvertScalarTo` are both constexpr.
+#if HWY_COMPILER_GCC_ACTUAL
+#define GEMMA_CONSTEXPR_EMBSCALING HWY_BF16_CONSTEXPR
+#else
+#define GEMMA_CONSTEXPR_EMBSCALING
+#endif
+
+template <typename TConfig>
+GEMMA_CONSTEXPR_EMBSCALING float EmbeddingScaling() {
+  // Round to bf16 to match Gemma's Embedder, which casts before mul.
+  return hwy::ConvertScalarTo<float>(hwy::ConvertScalarTo<hwy::bfloat16_t>(
+      Sqrt(static_cast<float>(TConfig::kModelDim))));
+}
+
+// NOLINTNEXTLINE(google-readability-namespace-comments)
+}  // namespace HWY_NAMESPACE
+}  // namespace gcpp
+HWY_AFTER_NAMESPACE();
+
+#endif  // NOLINT
diff --git a/gemma/common.h b/gemma/common.h
new file mode 100644
index 0000000..d497259
--- /dev/null
+++ b/gemma/common.h
@@ -0,0 +1,30 @@
+// Copyright 2024 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_H_
+#define THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_H_
+
+#include "hwy/aligned_allocator.h"
+
+namespace gcpp {
+
+using ByteStorageT = hwy::AlignedFreeUniquePtr<uint8_t[]>;
+
+// Model variants: see configs.h for details.
+enum class Model { GEMMA_2B, GEMMA_7B, GRIFFIN_2B, GEMMA_TINY };
+
+}  // namespace gcpp
+
+#endif  // THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_H_
diff --git a/gemma/configs.h b/gemma/configs.h
index 10de09e..7fc6477 100644
--- a/gemma/configs.h
+++ b/gemma/configs.h
@@ -142,6 +142,38 @@ struct ConfigGemma2B {
   using WeightT = GEMMA_WEIGHT_T;
 };
 
+struct ConfigGemmaTiny {
+  static constexpr int kSeqLen = 32;
+  static constexpr int kVocabSize = 16;
+  static constexpr std::array<LayerAttentionType, 3> kLayerConfig =
+      FixedLayerConfig<3>(LayerAttentionType::kGemma);
+  static constexpr int kLayers = kLayerConfig.size();
+  static constexpr int kGemmaLayers =
+      NumLayersOfTypeBefore(kLayerConfig, LayerAttentionType::kGemma, kLayers);
+  static constexpr int kGriffinLayers =
+      NumLayersOfTypeBefore(kLayerConfig,
+                            LayerAttentionType::kGriffinRecurrentBlock,
+                            kLayers);
+  static constexpr int kModelDim = 64;
+  static constexpr int kFFHiddenDim = 128;
+  static constexpr int kHeads = 4;
+  static constexpr int kKVHeads = 1;
+  static constexpr int kQKVDim = 16;  // query size == key size == value size
+  static constexpr int kTopK = gcpp::kTopK;
+  static constexpr bool kAbsolutePE = false;
+  static constexpr bool kPostNormScale = false;
+
+  // SSM config.
+ static constexpr int kConv1dWidth = 0; + static constexpr bool kFFBiases = false; + static constexpr bool kSoftmaxAttnOutputBiases = false; + static constexpr bool kUseHalfRope = false; + static constexpr bool kUseLocalAttention = false; + static constexpr bool kInterleaveQKV = true; + static constexpr int kNumTensorScales = 0; + using WeightT = GEMMA_WEIGHT_T; +}; + struct ConfigGriffin2B { // Griffin uses local attention, so kSeqLen is actually the local attention // window. diff --git a/gemma/gemma.cc b/gemma/gemma.cc index 52f2656..d168a9d 100644 --- a/gemma/gemma.cc +++ b/gemma/gemma.cc @@ -22,6 +22,7 @@ #include "hwy/foreach_target.h" // IWYU pragma: keep // Must come after foreach_target.h to avoid redefinition errors. #include "compression/compress-inl.h" +#include "gemma/common-inl.h" #include "gemma/ops.h" #include "hwy/contrib/matvec/matvec-inl.h" #include "hwy/highway.h" @@ -53,6 +54,7 @@ #include "compression/io.h" // Path #include "gemma/configs.h" #include "gemma/gemma.h" +#include "gemma/weights.h" #include "hwy/aligned_allocator.h" #include "hwy/base.h" #include "hwy/contrib/thread_pool/thread_pool.h" @@ -72,63 +74,6 @@ constexpr bool kShowTokenization = false; namespace gcpp { -template -struct Layer { - Layer() = default; - static constexpr size_t kHeads = TConfig::kHeads; - static constexpr size_t kKVHeads = TConfig::kKVHeads; - static constexpr size_t kModelDim = TConfig::kModelDim; - static constexpr size_t kQKVDim = TConfig::kQKVDim; - static constexpr size_t kFFHiddenDim = TConfig::kFFHiddenDim; - static constexpr size_t kAttVecEinsumWSize = kHeads * kQKVDim * kModelDim; - static constexpr size_t kQKVEinsumWSize = - (kHeads + 2 * kKVHeads) * kQKVDim * kModelDim; - // 2x for (gelu gating vector, gated vector) - static constexpr size_t kGatingEinsumWSize = 2 * kFFHiddenDim * kModelDim; - static constexpr size_t kConv1dWidth = TConfig::kConv1dWidth; - static constexpr bool kFFBiases = TConfig::kFFBiases; - static constexpr bool kPostNormScale = TConfig::kPostNormScale; - static constexpr size_t kAOBiasDim = - TConfig::kSoftmaxAttnOutputBiases ? kModelDim : 0; - static constexpr size_t kGriffinDim = - TConfig::kGriffinLayers > 0 ? kModelDim : 0; - - template - using ArrayT = std::array; - - union { - struct { - ArrayT attn_vec_einsum_w; - ArrayT qkv_einsum_w; - ArrayT attention_output_biases; - }; - - struct { - ArrayT linear_x_w; - ArrayT linear_x_biases; - ArrayT linear_y_w; - ArrayT linear_y_biases; - ArrayT linear_out_w; - ArrayT linear_out_biases; - ArrayT conv_w; - ArrayT conv_biases; - ArrayT gate_w; - ArrayT gate_biases; - ArrayT a; - } griffin; - }; - - ArrayT gating_einsum_w; - ArrayT linear_w; - ArrayT pre_attention_norm_scale; - ArrayT pre_ffw_norm_scale; - ArrayT post_attention_norm_scale; - ArrayT post_ffw_norm_scale; - - ArrayT ffw_gating_biases; - ArrayT ffw_output_biases; -}; - float ScaleWeights(float* data, size_t len) { float maxabs = 0.0; for (size_t i = 0; i < len; ++i) { @@ -146,41 +91,6 @@ float ScaleWeights(float* data, size_t len) { return scale; } -// Array instead of single large allocation for parallel mem init. Split out of -// Weights so that only these pointers are initialized. 
-template -struct LayerPointers { - explicit LayerPointers(hwy::ThreadPool& pool) { - pool.Run(0, TConfig::kLayers, [this](uint64_t task, size_t /*thread*/) { - this->layers[task] = hwy::AllocateAligned>(1); - }); - } - - using TLayer = Layer; - std::array, TConfig::kLayers> layers; -}; - -template -struct Weights { - // No ctor/dtor, allocated via AllocateAligned. - - std::array - embedder_input_embedding; - - std::array final_norm_scale; - - LayerPointers layer_ptrs; - - std::array scales; - - const Layer* GetLayer(size_t layer) const { - return layer_ptrs.layers[layer].get(); - } - Layer* GetLayer(size_t layer) { - return layer_ptrs.layers[layer].get(); - } -}; - template hwy::AlignedFreeUniquePtr LoadWeights( const Path& checkpoint, hwy::ThreadPool& pool, @@ -191,11 +101,8 @@ hwy::AlignedFreeUniquePtr LoadWeights( checkpoint.path.c_str()); } - using TWeights = Weights; - hwy::AlignedFreeUniquePtr weights_u8 = - hwy::AllocateAligned(sizeof(TWeights)); - TWeights* weights = reinterpret_cast(weights_u8.get()); - new (&weights->layer_ptrs) LayerPointers(pool); + ByteStorageT weights_u8 = AllocateWeights(pool); + auto* weights = reinterpret_cast*>(weights_u8.get()); size_t scale_pos = 0; FILE* fptr; @@ -228,7 +135,7 @@ hwy::AlignedFreeUniquePtr LoadWeights( sizeof(weights->final_norm_scale)); for (size_t layer = 0; layer < TConfig::kLayers; ++layer) { auto type = TConfig::kLayerConfig[layer]; - Layer* layer_view = weights->GetLayer(layer); + LayerF* layer_view = weights->GetLayer(layer); #define READ_WEIGHTS(name) \ do { \ @@ -305,7 +212,7 @@ template struct CompressedLayer { // No ctor/dtor, allocated via AllocateAligned. - using TLayer = gcpp::Layer; + using TLayer = gcpp::LayerF; using WeightT = typename TConfig::WeightT; static constexpr size_t kHeads = TLayer::kHeads; @@ -399,13 +306,13 @@ struct CompressedWeights { template using WeightsT = hwy::If, - Weights>; + WeightsF>; // Aligned. template struct Activations { static constexpr size_t kBatchSize = TBatchSize; - using LayerConfig = Layer; + using LayerConfig = LayerF; static constexpr size_t kModelDim = TConfig::kModelDim; static constexpr size_t kQKVDim = TConfig::kQKVDim; static constexpr size_t kHeads = TConfig::kHeads; @@ -446,6 +353,16 @@ struct Activations { std::array griffin_multiplier; }; +template +struct InferenceState { + Activations prefill; + HWY_ALIGN Activations state; + + static ByteStorageT Allocate() { + return hwy::AllocateAligned(sizeof(InferenceState)); + } +}; + // GemmaImpl is a template and thus cannot be exposed in gemma.h, hence we // define an abstract base class. struct GemmaInterface { @@ -484,6 +401,8 @@ KVCache CreateKVCache(Model type) { return CreateKVCacheT(); case Model::GRIFFIN_2B: return CreateKVCacheT(); + case Model::GEMMA_TINY: + return CreateKVCacheT(); default: HWY_ABORT("Model type %d unknown.", static_cast(type)); } @@ -526,8 +445,8 @@ void DeleteLayersPtrs(CompressedWeights* c_weights) { c_weights->c_layer_ptrs.~CompressedLayerPointers(); } template -void DeleteLayersPtrs(Weights* weights) { - weights->layer_ptrs.~LayerPointers(); +void DeleteLayersPtrs(WeightsF* weights) { + weights->layer_ptrs.~LayerPointers(); } } // namespace @@ -866,8 +785,9 @@ HWY_NOINLINE void FFW(Activations& activations, // Same matrix, first and second half of rows. Could fuse into one MatVec. MatVecT( layer_weights->gating_einsum_w, kFFHiddenDim * kModelDim, vec, - layer_weights->ffw_gating_biases.data() + kFFHiddenDim, even_odd, - out_mul, pool); + TConfig::kFFBiases ? 
+ layer_weights->ffw_gating_biases.data() + kFFHiddenDim : nullptr, + even_odd, out_mul, pool); // Gate, will go through the nonlinearity. MatVecT( layer_weights->gating_einsum_w, 0, vec, @@ -892,21 +812,6 @@ HWY_NOINLINE void FFW(Activations& activations, } } -// `EmbeddingScaling` can be constexpr only if `Sqrt` and `hwy::ConvertScalarTo` -// are both constexpr -#if HWY_COMPILER_GCC_ACTUAL -#define GEMMA_CONSTEXPR_EMBSCALING HWY_BF16_CONSTEXPR -#else -#define GEMMA_CONSTEXPR_EMBSCALING -#endif - -template -GEMMA_CONSTEXPR_EMBSCALING float EmbeddingScaling() { - // Round to bf16 to match Gemma's Embedder, which casts before mul. - return hwy::ConvertScalarTo(hwy::ConvertScalarTo( - Sqrt(static_cast(TConfig::kModelDim)))); -} - template HWY_NOINLINE void Prefill(const int* tokens, size_t num_tokens, size_t pos, const WeightArrayT& weights, @@ -1076,20 +981,15 @@ void RangeChecks(size_t& max_tokens, size_t& max_generated_tokens, } } -template -void GenerateImpl(GemmaImpl& gemma, +template typename WeightsType> +void GenerateImpl(const WeightsType& weights, + Activations& prefill_activations, + Activations& activations, const RuntimeConfig& runtime_config, const std::vector& prompt, size_t pos, KVCache& kv_cache, hwy::ThreadPool& pool, TimingInfo& timing_info, LayersOutputT* layers_output) { static constexpr size_t kVocabSize = TConfig::kVocabSize; - Activations& activations = *gemma.state.get(); - Activations& prefill_activations = - *gemma.prefill.get(); - - const WeightsT& weights = - *reinterpret_cast*>(gemma.weights_u8.get()); - size_t prompt_size = prompt.size(); size_t max_tokens = runtime_config.max_tokens; size_t max_generated_tokens = runtime_config.max_generated_tokens; @@ -1167,7 +1067,7 @@ void GenerateImpl(GemmaImpl& gemma, activations.logits.data(), kVocabSize, *runtime_config.gen, runtime_config.temperature, runtime_config.accept_token); if (!runtime_config.stream_token(token, activations.logits[token])) { - token = EOS_ID; + token = runtime_config.eos_id; } if (generate_pos == 0) { timing_info.time_to_first_token = hwy::platform::Now() - gen_start; @@ -1177,10 +1077,10 @@ void GenerateImpl(GemmaImpl& gemma, // process the tokens of the prompt one at a time. 
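+      // (Forced prompt tokens are streamed with probability 0.0f because they
+      // were not sampled from the logits.)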
token = prompt.at(pos_offset + 1); if (!runtime_config.stream_token(token, 0)) { - token = EOS_ID; + token = runtime_config.eos_id; } } - if (token == EOS_ID) { + if (token == runtime_config.eos_id) { if (runtime_config.verbosity >= 2) { const double gen_end = hwy::platform::Now(); timing_info.gen_tok_sec = @@ -1192,6 +1092,35 @@ void GenerateImpl(GemmaImpl& gemma, } } +template +void GenerateImpl(GemmaImpl& gemma, + const RuntimeConfig& runtime_config, + const std::vector& prompt, size_t pos, KVCache& kv_cache, + hwy::ThreadPool& pool, TimingInfo& timing_info, + LayersOutputT* layers_output) { + const WeightsT& weights = + *reinterpret_cast*>(gemma.weights_u8.get()); + GenerateImpl( + weights, *gemma.prefill.get(), *gemma.state.get(), runtime_config, prompt, + pos, kv_cache, pool, timing_info, layers_output); +} + +template +void GenerateGemma(const ByteStorageT& weights_u8, + ByteStorageT& inference_state_u8, + const RuntimeConfig& runtime_config, + const std::vector& prompt, size_t pos, + KVCache& kv_cache, hwy::ThreadPool& pool, + TimingInfo& timing_info, LayersOutputT* layers_output) { + const WeightsF& weights = + *reinterpret_cast*>(weights_u8.get()); + InferenceState& inference_state = + *reinterpret_cast*>(inference_state_u8.get()); + GenerateImpl( + weights, inference_state.prefill, inference_state.state, runtime_config, + prompt, pos, kv_cache, pool, timing_info, layers_output); +} + #define TOKEN(token_id) TokenString(gemma, token_id).c_str() template @@ -1263,7 +1192,7 @@ float ComputeCrossEntropyImpl(GemmaImpl& gemma, size_t max_tokens, // // This avoids repeating the list of tensors between loading and compressing. template -void ForEachTensor(const Weights* weights, +void ForEachTensor(const WeightsF* weights, CompressedWeights& c_weights, Func& func) { func("c_embedding", weights ? weights->embedder_input_embedding.data() : nullptr, @@ -1275,7 +1204,7 @@ void ForEachTensor(const Weights* weights, for (int layer_idx = 0; layer_idx < TConfig::kLayers; ++layer_idx) { auto type = TConfig::kLayerConfig[layer_idx]; const size_t idx = static_cast(layer_idx); - const Layer* layer = weights ? weights->GetLayer(idx) : nullptr; + const LayerF* layer = weights ? 
weights->GetLayer(idx) : nullptr; CompressedLayer* layer_weights = c_weights.GetLayer(idx); #define CALL_FUNC(name, member) \ @@ -1386,34 +1315,17 @@ void CompressWeights(const Path& weights_path, const bool scale_for_compression = TConfig::kNumTensorScales > 0; const hwy::AlignedFreeUniquePtr weights_u8 = LoadWeights(weights_path, pool, scale_for_compression); - Weights* weights = - reinterpret_cast*>(weights_u8.get()); + WeightsF* weights = + reinterpret_cast*>(weights_u8.get()); Compressor compressor(pool); ForEachTensor(weights, *c_weights, compressor); compressor.AddScales(weights->scales.data(), weights->scales.size()); compressor.WriteAll(pool, compressed_weights_path); - weights->layer_ptrs.~LayerPointers(); + weights->layer_ptrs.~LayerPointers(); c_weights->c_layer_ptrs.~CompressedLayerPointers(); } -void CompressWeightsT(gcpp::Model model, const Path& weights, - const Path& compressed_weights, hwy::ThreadPool& pool) { - switch (model) { - case Model::GEMMA_2B: - CompressWeights(weights, compressed_weights, pool); - break; - case Model::GEMMA_7B: - CompressWeights(weights, compressed_weights, pool); - break; - case Model::GRIFFIN_2B: - CompressWeights(weights, compressed_weights, pool); - break; - default: - HWY_ABORT("Model type %d unknown.", static_cast(model)); - } -} - } // namespace HWY_NAMESPACE } // namespace gcpp HWY_AFTER_NAMESPACE(); @@ -1421,8 +1333,6 @@ HWY_AFTER_NAMESPACE(); #if HWY_ONCE namespace gcpp { -HWY_EXPORT(CompressWeightsT); - KVCache CreateKVCache(size_t size_cache_pos, size_t seq_len, size_t conv1d_cache_size, size_t rglru_cache_size) { KVCache kv_cache = {}; @@ -1528,10 +1438,87 @@ void GenerateGemma(Gemma& gemma, const RuntimeConfig& runtime_config, pool.SetWaitMode(hwy::PoolWaitMode::kBlock); } +void GenerateGemma(Model model, const ByteStorageT& weights, + ByteStorageT& inference_state, + RuntimeConfig runtime_config, + const std::vector& prompt, size_t start_pos, + KVCache& kv_cache, hwy::ThreadPool& pool, + TimingInfo& timing_info) { + switch (model) { + case Model::GEMMA_2B: + HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(GenerateGemma)( + weights, inference_state, runtime_config, prompt, start_pos, kv_cache, + pool, timing_info, /*layers_output=*/nullptr); + break; + case Model::GEMMA_7B: + HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(GenerateGemma)( + weights, inference_state, runtime_config, prompt, start_pos, kv_cache, + pool, timing_info, /*layers_output=*/nullptr); + break; + case Model::GRIFFIN_2B: + HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(GenerateGemma)( + weights, inference_state, runtime_config, prompt, start_pos, kv_cache, + pool, timing_info, /*layers_output=*/nullptr); + break; + case Model::GEMMA_TINY: + HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(GenerateGemma)( + weights, inference_state, runtime_config, prompt, start_pos, kv_cache, + pool, timing_info, /*layers_output=*/nullptr); + break; + default: + HWY_ABORT("Model type %d unknown.", static_cast(model)); + } +} + +ByteStorageT LoadWeights(const Path& weights, Model model, + hwy::ThreadPool& pool) { + switch (model) { + case Model::GEMMA_2B: + return LoadWeights(weights, pool); + case Model::GEMMA_7B: + return LoadWeights(weights, pool); + case Model::GRIFFIN_2B: + return LoadWeights(weights, pool); + case Model::GEMMA_TINY: + return LoadWeights(weights, pool); + default: + HWY_ABORT("Model type %d unknown.", static_cast(model)); + } +} + +ByteStorageT AllocateInferenceState(Model model) { + switch (model) { + case Model::GEMMA_2B: + return InferenceState::Allocate(); + case Model::GEMMA_7B: + return 
InferenceState::Allocate(); + case Model::GRIFFIN_2B: + return InferenceState::Allocate(); + case Model::GEMMA_TINY: + return InferenceState::Allocate(); + default: + HWY_ABORT("Model type %d unknown.", static_cast(model)); + } +} + void CompressWeights(gcpp::Model model, const Path& weights, const Path& compressed_weights, hwy::ThreadPool& pool) { - HWY_DYNAMIC_DISPATCH(CompressWeightsT) - (model, weights, compressed_weights, pool); + switch (model) { + case Model::GEMMA_2B: + HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(CompressWeights)( + weights, compressed_weights, pool); + break; + case Model::GEMMA_7B: + HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(CompressWeights)( + weights, compressed_weights, pool); + break; + case Model::GRIFFIN_2B: + HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(CompressWeights)( + weights, compressed_weights, pool); + break; + default: + HWY_ABORT("Model type %d unknown.", static_cast(model)); + } } float ComputeCrossEntropy(Gemma& gemma, size_t max_tokens, @@ -1546,13 +1533,16 @@ float ComputeCrossEntropy(Gemma& gemma, size_t max_tokens, namespace { constexpr const char* kModelFlags[] = {"2b-pt", "7b-pt", "gr2b-pt", - "2b-it", "7b-it", "gr2b-it"}; + "2b-it", "7b-it", "gr2b-it", + "tiny"}; constexpr Model kModelTypes[] = {Model::GEMMA_2B, Model::GEMMA_7B, Model::GRIFFIN_2B, Model::GEMMA_2B, - Model::GEMMA_7B, Model::GRIFFIN_2B}; + Model::GEMMA_7B, Model::GRIFFIN_2B, + Model::GEMMA_TINY}; constexpr ModelTraining kModelTraining[] = { ModelTraining::GEMMA_PT, ModelTraining::GEMMA_PT, ModelTraining::GEMMA_PT, - ModelTraining::GEMMA_IT, ModelTraining::GEMMA_IT, ModelTraining::GEMMA_IT}; + ModelTraining::GEMMA_IT, ModelTraining::GEMMA_IT, ModelTraining::GEMMA_IT, + ModelTraining::GEMMA_IT}; } // namespace const char* ParseModelTypeAndTraining(const std::string& model_flag, diff --git a/gemma/gemma.h b/gemma/gemma.h index 180268e..442637d 100644 --- a/gemma/gemma.h +++ b/gemma/gemma.h @@ -23,6 +23,7 @@ #include #include "compression/io.h" // Path +#include "gemma/common.h" #include "gemma/configs.h" #include "hwy/aligned_allocator.h" #include "hwy/base.h" // hwy::bfloat16_t @@ -51,8 +52,6 @@ struct KVCache { rglru_cache; // kModelDim * kGriffinLayers }; -// Model variants: see configs.h for details. -enum class Model { GEMMA_2B, GEMMA_7B, GRIFFIN_2B }; enum class ModelTraining { GEMMA_IT, GEMMA_PT }; // Returns error string or nullptr if OK. @@ -68,6 +67,8 @@ using StreamFunc = std::function; // want to generate and True for tokens you want to generate. 
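+// For example, an application doing constrained decoding could reject token
+// ids outside its allowed vocabulary and accept everything else.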
using AcceptFunc = std::function; +constexpr int EOS_ID = 1; + struct RuntimeConfig { size_t max_tokens; size_t max_generated_tokens; @@ -76,6 +77,7 @@ struct RuntimeConfig { std::mt19937* gen; const StreamFunc& stream_token; const AcceptFunc& accept_token; + int eos_id = EOS_ID; }; struct GemmaInterface; @@ -118,6 +120,18 @@ void GenerateGemma(Gemma& gemma, const RuntimeConfig& runtime_config, TimingInfo& timing_info, LayersOutputT* layers_output = nullptr); +void GenerateGemma(Model model, const ByteStorageT& weights, + ByteStorageT& inference_state, + RuntimeConfig runtime_config, + const std::vector& prompt, size_t start_pos, + KVCache& kv_cache, hwy::ThreadPool& pool, + TimingInfo& timing_info); + +ByteStorageT LoadWeights(const Path& weights, Model model, + hwy::ThreadPool& pool); + +ByteStorageT AllocateInferenceState(Model model); + void CompressWeights(gcpp::Model model, const Path& weights, const Path& compressed_weights, hwy::ThreadPool& pool); @@ -125,8 +139,6 @@ float ComputeCrossEntropy(Gemma& gemma, size_t max_tokens, const std::vector& prompt, KVCache& kv_cache, hwy::ThreadPool& pool, int verbosity); -constexpr int EOS_ID = 1; - } // namespace gcpp #endif // THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_H_ diff --git a/gemma/gemma_test.cc b/gemma/gemma_test.cc index 9ede883..0318fde 100644 --- a/gemma/gemma_test.cc +++ b/gemma/gemma_test.cc @@ -98,7 +98,7 @@ class GemmaTest : public ::testing::Test { gcpp::Gemma model; }; -TEST_F(GemmaTest, Geography) { +TEST_F(GemmaTest, DISABLED_Geography) { static const char* kQA[][2] = { {"What is the capital of Hungary?", "Budapest"}, {"How many states does the US have?", "50"}, @@ -108,7 +108,7 @@ TEST_F(GemmaTest, Geography) { TestQuestions(kQA, kNum); } -TEST_F(GemmaTest, History) { +TEST_F(GemmaTest, DISABLED_History) { static const char* kQA[][2] = { {"When was the Battle of Hastings?", "1066"}, {"Who fought at the Battle of Marathon?", "Greek"}, @@ -117,7 +117,7 @@ TEST_F(GemmaTest, History) { TestQuestions(kQA, kNum); } -TEST_F(GemmaTest, Arithmetic) { +TEST_F(GemmaTest, DISABLED_Arithmetic) { static const char* kQA[][2] = { {"what is 13 + 14?", "27"}, {"what is 7 * 8", "56"}, @@ -280,7 +280,7 @@ static const char kDeclaration[] = { "reliance on the protection of divine Providence, we mutually pledge to " "each other our Lives, our Fortunes and our sacred Honor.\n"}; -TEST_F(GemmaTest, CrossEntropySmall) { +TEST_F(GemmaTest, DISABLED_CrossEntropySmall) { static const char kSmall[] = "The capital of Hungary is Budapest which is located in Europe."; float entropy = GemmaCrossEntropy(kSmall); @@ -288,19 +288,19 @@ TEST_F(GemmaTest, CrossEntropySmall) { EXPECT_LT(entropy, 1.6f); } -TEST_F(GemmaTest, CrossEntropyJingleBells) { +TEST_F(GemmaTest, DISABLED_CrossEntropyJingleBells) { float entropy = GemmaCrossEntropy(kJingleBells); std::cout << "per-byte entropy: " << entropy << "\n"; EXPECT_LT(entropy, 2.3f); } -TEST_F(GemmaTest, CrossEntropyGettysburg) { +TEST_F(GemmaTest, DISABLED_CrossEntropyGettysburg) { float entropy = GemmaCrossEntropy(kGettysburg); std::cout << "per-byte entropy: " << entropy << "\n"; EXPECT_LT(entropy, 1.2f); } -TEST_F(GemmaTest, CrossEntropyDeclaration) { +TEST_F(GemmaTest, DISABLED_CrossEntropyDeclaration) { float entropy = GemmaCrossEntropy(kDeclaration); std::cout << "per-byte entropy: " << entropy << "\n"; EXPECT_LT(entropy, 1.0f); diff --git a/gemma/ops.h b/gemma/ops.h index 6401cc4..7ac73df 100644 --- a/gemma/ops.h +++ b/gemma/ops.h @@ -874,10 +874,14 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void 
AddAbsolutePositionalEmbeddings(
    consecutive pair of dimensions of v i.e. v_{2i} and v_{2i+1} by an angle
    m*theta_i. However in the Gemma implementation we choose to rotate the
    pairs of dimensions v_{i} and v_{i + d//2} instead.
+
+  The pos parameter is deliberately an int: the backward pass calls this with
+  negative values, because the VJP needs the transpose of this rotation
+  matrix, which is simply the same matrix evaluated at -pos.
 */
 static HWY_NOINLINE HWY_MAYBE_UNUSED void Rope(float* HWY_RESTRICT x,
-                                               size_t dim_qkv, size_t pos) {
+                                               size_t dim_qkv, int pos) {
   HWY_DASSERT(dim_qkv % 2 == 0);
   const size_t half_dim_qkv = dim_qkv / 2;
   for (size_t dim = 0; dim < half_dim_qkv; ++dim) {
@@ -898,7 +902,7 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void Rope(float* HWY_RESTRICT x,
 static HWY_NOINLINE HWY_MAYBE_UNUSED void RopeAndMulBy(const float mul,
                                                        float* HWY_RESTRICT x,
                                                        size_t dim_qkv,
-                                                       size_t pos) {
+                                                       int pos) {
   HWY_DASSERT(dim_qkv % 2 == 0);
   const size_t half_dim_qkv = dim_qkv / 2;
   for (size_t dim = 0; dim < half_dim_qkv; ++dim) {
diff --git a/gemma/weights.cc b/gemma/weights.cc
new file mode 100644
index 0000000..aa302e2
--- /dev/null
+++ b/gemma/weights.cc
@@ -0,0 +1,112 @@
+// Copyright 2024 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
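+
+// Runtime dispatch from the Model enum to the per-config weight helpers
+// declared in weights.h. A minimal usage sketch, assuming `model` holds a
+// valid Model value:
+//   hwy::ThreadPool pool(0);
+//   ByteStorageT weights = AllocateWeights(model, pool);
+//   ZeroInitWeights(model, weights, pool);
+//   LogWeightStats(model, weights);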
+ +#include "gemma/weights.h" + +#include "gemma/common.h" +#include "gemma/configs.h" +#include "hwy/contrib/thread_pool/thread_pool.h" +#include "hwy/stats.h" + +namespace gcpp { + +ByteStorageT AllocateWeights(Model model, hwy::ThreadPool& pool) { + switch (model) { + case Model::GEMMA_2B: + return AllocateWeights(pool); + case Model::GEMMA_7B: + return AllocateWeights(pool); + case Model::GRIFFIN_2B: + return AllocateWeights(pool); + case Model::GEMMA_TINY: + return AllocateWeights(pool); + default: + HWY_ABORT("Model type %d unknown.", static_cast(model)); + } +} + +namespace { +template +void ZeroInitWeightsT(ByteStorageT& weights, hwy::ThreadPool& pool) { + ZeroInit( + *reinterpret_cast*>(weights.get())); +} +} // namespace + +void ZeroInitWeights(Model model, ByteStorageT& weights, + hwy::ThreadPool& pool) { + switch (model) { + case Model::GEMMA_2B: + ZeroInitWeightsT(weights, pool); + break; + case Model::GEMMA_7B: + ZeroInitWeightsT(weights, pool); + break; + case Model::GRIFFIN_2B: + ZeroInitWeightsT(weights, pool); + break; + case Model::GEMMA_TINY: + ZeroInitWeightsT(weights, pool); + break; + default: + HWY_ABORT("Model type %d unknown.", static_cast(model)); + } +} + +namespace { +void LogVec(const char* name, const float* data, size_t len) { + hwy::Stats stats; + for (size_t i = 0; i < len; ++i) { + stats.Notify(data[i]); + } + printf("%-20s %12zu %13.10f %8.5f %13.10f\n", + name, len, stats.Min(), stats.Mean(), stats.Max()); +} + +class WeightLogger { + public: + template + void operator()(const char* name, const std::array& tensor) { + LogVec(name, tensor.data(), N); + total_weights += N; + } + size_t total_weights = 0; +}; + +template +void LogWeightStats(const ByteStorageT& weights_u8) { + const auto& weights = *reinterpret_cast*>(weights_u8.get()); + WeightLogger logger; + ForEachTensor1(logger, weights); + printf("%-20s %12zu\n", "Total", logger.total_weights); +} +} // namespace + +void LogWeightStats(gcpp::Model model, const ByteStorageT& weights) { + switch (model) { + case Model::GEMMA_2B: + return LogWeightStats(weights); + case Model::GEMMA_7B: + return LogWeightStats(weights); + case Model::GRIFFIN_2B: + return LogWeightStats(weights); + case Model::GEMMA_TINY: + return LogWeightStats(weights); + default: + HWY_ABORT("Model type %d unknown.", static_cast(model)); + } +} + +} // namespace gcpp diff --git a/gemma/weights.h b/gemma/weights.h new file mode 100644 index 0000000..8d25759 --- /dev/null +++ b/gemma/weights.h @@ -0,0 +1,290 @@ +// Copyright 2024 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
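+
+// Uncompressed (float) model weights with helpers to allocate, zero-init,
+// copy and iterate over every tensor. All tensor sizes are compile-time
+// constants derived from a config struct in configs.h; per-layer tensors
+// live in separately allocated Layer blocks (see LayerPointers below).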
+ +#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_WEIGHTS_H_ +#define THIRD_PARTY_GEMMA_CPP_GEMMA_WEIGHTS_H_ + +#include "gemma/common.h" +#include "gemma/configs.h" +#include "hwy/aligned_allocator.h" +#include "hwy/contrib/thread_pool/thread_pool.h" + +namespace gcpp { + +template +struct Layer { + Layer() {} + static constexpr size_t kHeads = TConfig::kHeads; + static constexpr size_t kKVHeads = TConfig::kKVHeads; + static constexpr size_t kModelDim = TConfig::kModelDim; + static constexpr size_t kQKVDim = TConfig::kQKVDim; + static constexpr size_t kFFHiddenDim = TConfig::kFFHiddenDim; + static constexpr size_t kAttVecEinsumWSize = kHeads * kQKVDim * kModelDim; + static constexpr size_t kQKVEinsumWSize = + (kHeads + 2 * kKVHeads) * kQKVDim * kModelDim; + // 2x for (gelu gating vector, gated vector) + static constexpr size_t kGatingEinsumWSize = 2 * kFFHiddenDim * kModelDim; + static constexpr size_t kConv1dWidth = TConfig::kConv1dWidth; + static constexpr bool kFFBiases = TConfig::kFFBiases; + static constexpr bool kPostNormScale = TConfig::kPostNormScale; + static constexpr size_t kAOBiasDim = + TConfig::kSoftmaxAttnOutputBiases ? kModelDim : 0; + static constexpr size_t kGriffinDim = + TConfig::kGriffinLayers > 0 ? kModelDim : 0; + + union { + struct { + std::array attn_vec_einsum_w; + std::array qkv_einsum_w; + std::array attention_output_biases; + }; + + struct { + std::array linear_x_w; + std::array linear_x_biases; + std::array linear_y_w; + std::array linear_y_biases; + std::array linear_out_w; + std::array linear_out_biases; + std::array conv_w; + std::array conv_biases; + std::array gate_w; + std::array gate_biases; + std::array a; + } griffin; + }; + + std::array gating_einsum_w; + std::array linear_w; + std::array pre_attention_norm_scale; + std::array pre_ffw_norm_scale; + std::array post_attention_norm_scale; + std::array post_ffw_norm_scale; + + std::array ffw_gating_biases; + std::array ffw_output_biases; +}; + +template +using LayerF = Layer; + +// Array instead of single large allocation for parallel mem init. Split out of +// Weights so that only these pointers are initialized. +template +struct LayerPointers { + explicit LayerPointers(hwy::ThreadPool& pool) { + pool.Run(0, TConfig::kLayers, [this](uint64_t task, size_t /*thread*/) { + this->layers[task] = hwy::AllocateAligned>(1); + }); + } + + using TLayer = Layer; + std::array, TConfig::kLayers> layers; +}; + +template +struct Weights { + // No ctor/dtor, allocated via AllocateAligned. 
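+  // (AllocateWeights below placement-news only layer_ptrs; the tensor arrays
+  // stay uninitialized until a checkpoint is loaded into them.)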
+
+  std::array
+      embedder_input_embedding;
+
+  std::array final_norm_scale;
+
+  LayerPointers layer_ptrs;
+
+  std::array scales;
+
+  const Layer* GetLayer(size_t layer) const {
+    return layer_ptrs.layers[layer].get();
+  }
+  Layer* GetLayer(size_t layer) {
+    return layer_ptrs.layers[layer].get();
+  }
+};
+
+template
+using WeightsF = Weights;
+
+template
+ByteStorageT AllocateWeights(hwy::ThreadPool& pool) {
+  using TWeights = Weights;
+  ByteStorageT weights_u8 = hwy::AllocateAligned(sizeof(TWeights));
+  TWeights* weights = reinterpret_cast(weights_u8.get());
+  new (&weights->layer_ptrs) LayerPointers(pool);
+  return weights_u8;
+}
+
+#define GEMMA_CALL_TOP_FUNC1(name, member) func(name, weights1.member)
+#define GEMMA_CALL_TOP_FUNC2(name, member) \
+  func(name, weights1.member, weights2.member)
+#define GEMMA_CALL_TOP_FUNC3(name, member) \
+  func(name, weights1.member, weights2.member, weights3.member)
+#define GEMMA_CALL_TOP_FUNC4(name, member) \
+  func(name, weights1.member, weights2.member, \
+       weights3.member, weights4.member)
+
+#define GEMMA_CALL_LAYER_FUNC1(name, member) \
+  snprintf(name_buf, sizeof(name_buf), name "_%d", layer_idx); \
+  func(name_buf, layer1.member)
+
+#define GEMMA_CALL_LAYER_FUNC2(name, member) \
+  snprintf(name_buf, sizeof(name_buf), name "_%d", layer_idx); \
+  func(name_buf, layer1.member, layer2.member)
+
+#define GEMMA_CALL_LAYER_FUNC3(name, member) \
+  snprintf(name_buf, sizeof(name_buf), name "_%d", layer_idx); \
+  func(name_buf, layer1.member, layer2.member, layer3.member)
+
+#define GEMMA_CALL_LAYER_FUNC4(name, member) \
+  snprintf(name_buf, sizeof(name_buf), name "_%d", layer_idx); \
+  func(name_buf, layer1.member, layer2.member, layer3.member, layer4.member)
+
+#define GEMMA_CALL_ALL_LAYER_FUNC(N) \
+  if (type == LayerAttentionType::kGemma) { \
+    GEMMA_CALL_LAYER_FUNC ## N("att_ein", attn_vec_einsum_w); \
+    GEMMA_CALL_LAYER_FUNC ## N("qkv_ein", qkv_einsum_w); \
+  } else { \
+    GEMMA_CALL_LAYER_FUNC ## N("gr_lin_x_w", griffin.linear_x_w); \
+    GEMMA_CALL_LAYER_FUNC ## N("gr_lin_x_b", griffin.linear_x_biases); \
+    GEMMA_CALL_LAYER_FUNC ## N("gr_lin_y_w", griffin.linear_y_w); \
+    GEMMA_CALL_LAYER_FUNC ## N("gr_lin_y_b", griffin.linear_y_biases); \
+    GEMMA_CALL_LAYER_FUNC ## N("gr_lin_out_w", griffin.linear_out_w); \
+    GEMMA_CALL_LAYER_FUNC ## N("gr_lin_out_b", griffin.linear_out_biases); \
+    GEMMA_CALL_LAYER_FUNC ## N("gr_conv_w", griffin.conv_w); \
+    GEMMA_CALL_LAYER_FUNC ## N("gr_conv_b", griffin.conv_biases); \
+    GEMMA_CALL_LAYER_FUNC ## N("gr_gate_w", griffin.gate_w); \
+    GEMMA_CALL_LAYER_FUNC ## N("gr_gate_b", griffin.gate_biases); \
+    GEMMA_CALL_LAYER_FUNC ## N("gr_a", griffin.a); \
+  } \
+  GEMMA_CALL_LAYER_FUNC ## N("gating_ein", gating_einsum_w); \
+  GEMMA_CALL_LAYER_FUNC ## N("linear_w", linear_w); \
+  GEMMA_CALL_LAYER_FUNC ## N("pre_att_ns", pre_attention_norm_scale); \
+  if (TConfig::kPostNormScale) { \
+    GEMMA_CALL_LAYER_FUNC ## N("post_att_ns", post_attention_norm_scale); \
+    GEMMA_CALL_LAYER_FUNC ## N("post_ff_ns", post_ffw_norm_scale); \
+  } \
+  GEMMA_CALL_LAYER_FUNC ## N("pre_ff_ns", pre_ffw_norm_scale); \
+  if (TConfig::kFFBiases) { \
+    GEMMA_CALL_LAYER_FUNC ## N("ffw_gat_b", ffw_gating_biases); \
+    GEMMA_CALL_LAYER_FUNC ## N("ffw_out_b", ffw_output_biases); \
+  } \
+  if (TConfig::kSoftmaxAttnOutputBiases && \
+      type == LayerAttentionType::kGemma) { \
+    GEMMA_CALL_LAYER_FUNC ## N("attn_ob", attention_output_biases); \
+  }
+
+template
+void ForEachTensor1(Func& func, const Weights& weights1) {
+  GEMMA_CALL_TOP_FUNC1("embedding", embedder_input_embedding);
+ GEMMA_CALL_TOP_FUNC1("final_norm", final_norm_scale); + char name_buf[16]; + for (int layer_idx = 0; layer_idx < TConfig::kLayers; ++layer_idx) { + auto type = TConfig::kLayerConfig[layer_idx]; + const size_t idx = static_cast(layer_idx); + const LayerF& layer1 = *weights1.GetLayer(idx); + GEMMA_CALL_ALL_LAYER_FUNC(1) + } +} + +template +void ForEachTensor1(Func& func, Weights& weights1) { + GEMMA_CALL_TOP_FUNC1("embedding", embedder_input_embedding); + GEMMA_CALL_TOP_FUNC1("final_norm", final_norm_scale); + char name_buf[16]; + for (int layer_idx = 0; layer_idx < TConfig::kLayers; ++layer_idx) { + auto type = TConfig::kLayerConfig[layer_idx]; + const size_t idx = static_cast(layer_idx); + LayerF& layer1 = *weights1.GetLayer(idx); + GEMMA_CALL_ALL_LAYER_FUNC(1) + } +} + +template +void ForEachTensor2(Func& func, const Weights& weights1, + Weights& weights2) { + GEMMA_CALL_TOP_FUNC2("embedding", embedder_input_embedding); + GEMMA_CALL_TOP_FUNC2("final_norm", final_norm_scale); + char name_buf[16]; + for (int layer_idx = 0; layer_idx < TConfig::kLayers; ++layer_idx) { + auto type = TConfig::kLayerConfig[layer_idx]; + const size_t idx = static_cast(layer_idx); + const LayerF& layer1 = *weights1.GetLayer(idx); + LayerF& layer2 = *weights2.GetLayer(idx); + GEMMA_CALL_ALL_LAYER_FUNC(2) + } +} + +#undef GEMMA_CALL_TOP_FUNC1 +#undef GEMMA_CALL_TOP_FUNC2 +#undef GEMMA_CALL_TOP_FUNC3 +#undef GEMMA_CALL_TOP_FUNC4 +#undef GEMMA_CALL_LAYER_FUNC1 +#undef GEMMA_CALL_LAYER_FUNC2 +#undef GEMMA_CALL_LAYER_FUNC3 +#undef GEMMA_CALL_LAYER_FUNC4 +#undef GEMMA_CALL_ALL_LAYER_FUNC + +template +void ZeroInit(Weights& w) { + hwy::ZeroBytes(&w.embedder_input_embedding, + sizeof(w.embedder_input_embedding)); + hwy::ZeroBytes(&w.final_norm_scale, sizeof(w.final_norm_scale)); + for (int i = 0; i < TConfig::kLayers; ++i) { + hwy::ZeroBytes(w.GetLayer(i), sizeof(*w.GetLayer(i))); + } +} + +template +void Copy(Weights& dst, const Weights& src) { + hwy::CopyBytes(&src.embedder_input_embedding, &dst.embedder_input_embedding, + sizeof(src.embedder_input_embedding)); + hwy::CopyBytes(&src.final_norm_scale, &dst.final_norm_scale, + sizeof(src.final_norm_scale)); + for (int i = 0; i < TConfig::kLayers; ++i) { + hwy::CopyBytes(src.GetLayer(i), dst.GetLayer(i), sizeof(*dst.GetLayer(i))); + } +} + +// Owns weights and undoes the type erasure of AllocateWeights. +template +class WeightsWrapper { + public: + WeightsWrapper() + : pool_(0), data_(AllocateWeights(pool_)), + weights_(reinterpret_cast*>(data_.get())) {} + + const Weights& get() const { return *weights_; } + Weights& get() { return *weights_; } + void clear() { ZeroInit(get()); } + void copy(const WeightsWrapper& other) { + Copy(get(), other.get()); + } + + private: + hwy::ThreadPool pool_; + ByteStorageT data_; + Weights* weights_; +}; + +ByteStorageT AllocateWeights(Model model, hwy::ThreadPool& pool); + +void ZeroInitWeights(Model model, ByteStorageT& weights, hwy::ThreadPool& pool); + +void LogWeightStats(Model model, const ByteStorageT& weights); + +} // namespace gcpp + +#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_WEIGHTS_H_