From 36e4d8bbfe336e052ba93ae5a43f63b5cdbac280 Mon Sep 17 00:00:00 2001
From: Zoltan Szabadka
Date: Mon, 3 Jun 2024 14:42:35 +0000
Subject: [PATCH] Add first version of backpropagation support.

This is still in progress / experimental: currently it is only implemented
for normal Gemma MQA attention layers, and no parallelism has been added yet
for the backward pass. Since we need to remember all activations from all
layers, the forward pass was also reimplemented with a new activation data
structure.
---
 CMakeLists.txt                |  24 ++
 compression/compress-inl.h    |  15 +-
 gemma/activations.cc          |  38 +++
 gemma/activations.h           |  84 +++++
 gemma/backward-inl.h          | 412 +++++++++++++++++++++++
 gemma/backward.cc             |  88 +++++
 gemma/backward.h              |  34 ++
 gemma/backward_scalar.h       | 351 +++++++++++++++++++
 gemma/backward_scalar_test.cc | 610 ++++++++++++++++++++++++++++++++++
 gemma/backward_test.cc        | 270 +++++++++++++++
 gemma/common-inl.h            |  71 ++++
 gemma/common.h                |  30 ++
 gemma/common_scalar.cc        |  50 +++
 gemma/common_scalar.h         | 119 +++++++
 gemma/configs.h               |  32 ++
 gemma/forward-inl.h           | 290 ++++++++++++++++
 gemma/forward.cc              |  79 +++++
 gemma/forward.h               |  33 ++
 gemma/forward_scalar.h        | 284 ++++++++++++++++
 gemma/gemma.cc                | 266 +++++++--------
 gemma/gemma.h                 |  20 +-
 gemma/gemma_test.cc           |  14 +-
 gemma/ops.h                   |   4 +-
 gemma/optimize_test.cc        | 127 +++++++
 gemma/optimizer.cc            | 121 +++++++
 gemma/optimizer.h             |  34 ++
 gemma/prompt.h                |  33 ++
 gemma/sampler.h               |  84 +++++
 gemma/test_util.h             | 195 +++++++++++
 gemma/weights.cc              | 116 +++++++
 gemma/weights.h               | 288 ++++++++++++++++
 31 files changed, 4061 insertions(+), 155 deletions(-)
 create mode 100644 gemma/activations.cc
 create mode 100644 gemma/activations.h
 create mode 100644 gemma/backward-inl.h
 create mode 100644 gemma/backward.cc
 create mode 100644 gemma/backward.h
 create mode 100644 gemma/backward_scalar.h
 create mode 100644 gemma/backward_scalar_test.cc
 create mode 100644 gemma/backward_test.cc
 create mode 100644 gemma/common-inl.h
 create mode 100644 gemma/common.h
 create mode 100644 gemma/common_scalar.cc
 create mode 100644 gemma/common_scalar.h
 create mode 100644 gemma/forward-inl.h
 create mode 100644 gemma/forward.cc
 create mode 100644 gemma/forward.h
 create mode 100644 gemma/forward_scalar.h
 create mode 100644 gemma/optimize_test.cc
 create mode 100644 gemma/optimizer.cc
 create mode 100644 gemma/optimizer.h
 create mode 100644 gemma/prompt.h
 create mode 100644 gemma/sampler.h
 create mode 100644 gemma/test_util.h
 create mode 100644 gemma/weights.cc
 create mode 100644 gemma/weights.h

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 4a07305..2017dcd 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -47,9 +47,27 @@ set(SOURCES
   compression/sfp-inl.h
   compression/test_util.h
   gemma/configs.h
+  gemma/activations.cc
+  gemma/activations.h
+  gemma/backward.cc
+  gemma/backward.h
+  gemma/backward-inl.h
+  gemma/backward_scalar.h
+  gemma/common.h
+  gemma/common-inl.h
+  gemma/common_scalar.cc
+  gemma/common_scalar.h
+  gemma/forward.cc
+  gemma/forward.h
+  gemma/forward-inl.h
+  gemma/forward_scalar.h
   gemma/gemma.cc
   gemma/gemma.h
   gemma/ops.h
+  gemma/optimizer.cc
+  gemma/optimizer.h
+  gemma/weights.cc
+  gemma/weights.h
   util/app.h
   util/args.h
 )
@@ -100,9 +118,15 @@ target_link_libraries(debug_prompt libgemma hwy hwy_contrib nlohmann_json::nlohm
 set(GEMMA_ENABLE_TESTS OFF CACHE BOOL "Enable Gemma tests")
 if (GEMMA_ENABLE_TESTS)
+enable_testing()
+include(GoogleTest)
+
 set(GEMMA_TEST_FILES
   gemma/ops_test.cc
   gemma/gemma_test.cc
+  gemma/backward_test.cc
+  gemma/backward_scalar_test.cc
+  gemma/optimize_test.cc
 )
 foreach
(TESTFILE IN LISTS GEMMA_TEST_FILES) diff --git a/compression/compress-inl.h b/compression/compress-inl.h index 1f18765..2bac834 100644 --- a/compression/compress-inl.h +++ b/compression/compress-inl.h @@ -484,9 +484,10 @@ HWY_INLINE void Decompress(const CompressedArray& compressed, const hn::ScalableTag d; const size_t ofs = idx_batch * kBatch; - const size_t num = idx_batch == num_batches - 1 ? (num - ofs) : kBatch; + const size_t batch = + idx_batch == num_batches - 1 ? (num - ofs) : kBatch; Traits::Decompress(d, compressed.size(), compressed.data(), - compressed_ofs + ofs, out + ofs, num); + compressed_ofs + ofs, out + ofs, batch); }); const double t1 = hwy::platform::Now(); @@ -495,6 +496,16 @@ HWY_INLINE void Decompress(const CompressedArray& compressed, fprintf(stderr, "Decompress %.1f MB/s\n", mbps); } +// Returns dot product with `vec_aligned` of length `num`. +template +HWY_INLINE float Dot(DF df, const std::array& w, size_t ofs, + const VecT* x, size_t num) { + HWY_DASSERT(ofs + num <= kCapacity); + HWY_DASSERT(hn::IsAligned(df, x)); + using Traits = CompressTraits; + return Traits::Dot(df, w.size(), w.data(), ofs, x, num); +} + // Returns dot product with `vec_aligned` of length `num`. template HWY_INLINE float Dot(DF df, const CompressedArray& compressed, diff --git a/gemma/activations.cc b/gemma/activations.cc new file mode 100644 index 0000000..56f8bde --- /dev/null +++ b/gemma/activations.cc @@ -0,0 +1,38 @@ +// Copyright 2024 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "gemma/activations.h" + +#include "gemma/common.h" +#include "gemma/configs.h" + +namespace gcpp { + +ByteStorageT AllocateForwardPass(Model model) { + switch (model) { + case Model::GEMMA_2B: + return ForwardPass::Allocate(); + case Model::GEMMA_7B: + return ForwardPass::Allocate(); + case Model::GRIFFIN_2B: + return ForwardPass::Allocate(); + case Model::GEMMA_TINY: + return ForwardPass::Allocate(); + default: + HWY_ABORT("Model type %d unknown.", static_cast(model)); + } +} + +} // namespace gcpp diff --git a/gemma/activations.h b/gemma/activations.h new file mode 100644 index 0000000..e72f3d6 --- /dev/null +++ b/gemma/activations.h @@ -0,0 +1,84 @@ +// Copyright 2024 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
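// Editorial sketch (not part of the patch): intended use of the
// AllocateForwardPass helper above. The returned ByteStorageT is an untyped,
// aligned allocation, so the caller reinterprets it with the ForwardPass type
// that matches the model config (the same pattern backward.cc uses below).
// ConfigGemmaTiny is assumed here to be the small test config from configs.h.
//
//   ByteStorageT forward_u8 = AllocateForwardPass(Model::GEMMA_TINY);
//   auto& forward =
//       *reinterpret_cast<ForwardPass<float, ConfigGemmaTiny>*>(forward_u8.get());
//   // forward.layers[0].input, forward.logits, ... are now ready to be
//   // filled by the forward pass.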
+ +#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_ACTIVATIONS_H_ +#define THIRD_PARTY_GEMMA_CPP_GEMMA_ACTIVATIONS_H_ + +#include "gemma/common.h" +#include "hwy/aligned_allocator.h" + +namespace gcpp { + +template +struct ForwardLayer { + static constexpr size_t kSeqLen = TConfig::kSeqLen; + static constexpr size_t kModelDim = TConfig::kModelDim; + static constexpr size_t kQKVDim = TConfig::kQKVDim; + static constexpr size_t kHeads = TConfig::kHeads; + static constexpr size_t kFFHiddenDim = TConfig::kFFHiddenDim; + std::array input; + std::array pre_att_rms_out; + std::array qkv; + std::array att; + std::array att_out; + std::array att_post1; + std::array attention_out; + std::array bf_pre_ffw_rms_out; + std::array ffw_hidden; + std::array ffw_hidden_gated; +}; + +template +struct ForwardPass { + ForwardPass() {} // prevents placement-new calling memset + + static constexpr size_t kSeqLen = TConfig::kSeqLen; + static constexpr size_t kModelDim = TConfig::kModelDim; + static constexpr size_t kVocabSize = TConfig::kVocabSize; + static constexpr size_t kLayers = TConfig::kLayers; + + std::array, kLayers> layers; + std::array final_layer_output; + std::array final_norm_output; + std::array logits; + std::array probs; + + static ByteStorageT Allocate() { + return hwy::AllocateAligned(sizeof(ForwardPass)); + } +}; + +template +class ActivationsWrapper { + using WrappedT = ForwardPass; + + public: + ActivationsWrapper() + : data_(WrappedT::Allocate()), + activations_(*reinterpret_cast(data_.get())) {} + + const WrappedT& get() const { return activations_; } + WrappedT& get() { return activations_; } + + private: + ByteStorageT data_; + WrappedT& activations_; +}; + +ByteStorageT AllocateForwardPass(Model model); + +} // namespace gcpp + +#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_ACTIVATIONS_H_ diff --git a/gemma/backward-inl.h b/gemma/backward-inl.h new file mode 100644 index 0000000..1f55ad9 --- /dev/null +++ b/gemma/backward-inl.h @@ -0,0 +1,412 @@ +// Copyright 2024 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Include guard for non-SIMD code. +#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_INL_H_ +#define THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_INL_H_ + +#include +#include + +#include +#include + +#include "gemma/activations.h" +#include "gemma/prompt.h" +#include "gemma/weights.h" + +#include "hwy/base.h" +#include "hwy/contrib/thread_pool/thread_pool.h" + +#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_INL_H_ + +// Include guard for (potentially) SIMD code. 
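// Editorial note (inferred from the indexing in the VJP routines below, not
// part of the patch): for a token position `pos` and head `head`, the
// per-layer activation buffers are laid out as follows:
//   qkv:     offset (pos * (kHeads + 2) + head) * kQKVDim for per-head queries;
//            slots kHeads and kHeads + 1 hold the single shared key and value
//            vector (MQA, kKVHeads == 1).
//   att:     offset head * kSeqLen + pos * kHeads * kSeqLen, the causal
//            attention weights of `pos` over positions 0..pos.
//   att_out: offset (pos * kHeads + head) * kQKVDim, the attention-weighted
//            values per head.
// The backward pass walks these buffers in the reverse order of the forward
// pass, which is why every intermediate activation is kept for all layers.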
+#if defined(THIRD_PARTY_GEMMA_CPP_BACKWARD_TOGGLE) == defined(HWY_TARGET_TOGGLE) +#ifdef THIRD_PARTY_GEMMA_CPP_BACkWARD_TOGGLE +#undef THIRD_PARTY_GEMMA_CPP_BACKWARD_TOGGLE +#else +#define THIRD_PARTY_GEMMA_CPP_BACKWARD_TOGGLE +#endif + +#include "gemma/common-inl.h" +#include "gemma/ops.h" + +HWY_BEFORE_NAMESPACE(); +namespace gcpp { +namespace HWY_NAMESPACE { +namespace hn = hwy::HWY_NAMESPACE; + +template +void MatMulVJP(const std::array& weights, + const float* HWY_RESTRICT x, // num_tokens * kCols + const float* HWY_RESTRICT v, // num_tokens * kRows + size_t num_tokens, + std::array& grad_w, + float* HWY_RESTRICT grad_x, // num_tokens * kCols + hwy::ThreadPool& pool) { + memset(grad_x, 0, num_tokens * kCols * sizeof(grad_x[0])); + for (size_t pos = 0; pos < num_tokens; ++pos) { + const size_t voffs = pos * kRows; + const size_t xoffs = pos * kCols; + for (size_t j = 0; j < kRows; ++j) { + MulByConstAndAdd(v[voffs + j], &x[xoffs], &grad_w[j * kCols], kCols); + MulByConstAndAdd(v[voffs + j], &weights[j * kCols], &grad_x[xoffs], + kCols); + } + } +} + +template +void MultiHeadMatMulVJP( + const std::array& weights, + const float* HWY_RESTRICT x, // num_tokens * kHeads * kCols + const float* HWY_RESTRICT v, // num_tokens * kRows + size_t num_tokens, + std::array& grad_w, + float* HWY_RESTRICT grad_x, // num_tokens * kHeads * kCols + hwy::ThreadPool& pool) { + memset(grad_x, 0, num_tokens * kHeads * kCols * sizeof(grad_x[0])); + for (size_t pos = 0; pos < num_tokens; ++pos) { + for (size_t j = 0; j < kRows; ++j) { + for (size_t h = 0; h < kHeads; ++h) { + MulByConstAndAdd(v[pos * kRows + j], + &x[pos * kHeads * kCols + h * kCols], + &grad_w[h * kRows * kCols + j * kCols], kCols); + MulByConstAndAdd(v[pos * kRows + j], + &weights[h * kRows * kCols + j * kCols], + &grad_x[pos * kHeads * kCols + h * kCols], kCols); + } + } + } +} + +template +static HWY_INLINE hn::Vec DGelu(D d, hn::Vec v) { + const hn::Vec kMul = hn::Set(d, 0.044715f); + const hn::Vec kSqrt2OverPi = hn::Set(d, 0.797884560804236f); + const hn::Vec kHalf = hn::Set(d, 0.5f); + const hn::Vec kOne = hn::Set(d, 1.0f); + // kSqrtOverPi*3*kMul + const hn::Vec kMulv2 = hn::Set(d, 0.1070322244f); + + const hn::Vec v2 = hn::Mul(v, v); + const hn::Vec v3 = hn::Mul(v2, v); + const hn::Vec arg = hn::Mul(kSqrt2OverPi, hn::MulAdd(kMul, v3, v)); + const hn::Vec tanh = hn::Tanh(d, arg); + const hn::Vec cdf = hn::MulAdd(kHalf, tanh, kHalf); + const hn::Vec dtanh = hn::Sub(kOne, hn::Mul(tanh, tanh)); + const hn::Vec darg = hn::MulAdd(kMulv2, v2, kSqrt2OverPi); + return hn::MulAdd(kHalf, hn::Mul(v, hn::Mul(dtanh, darg)), cdf); +} + +static HWY_NOINLINE void SoftmaxVJP(const float* HWY_RESTRICT forward, + float* HWY_RESTRICT backward, + const size_t size) { + namespace hn = hwy::HWY_NAMESPACE; + using D = hn::ScalableTag; + const D d; + + const auto offset = + hn::Set(d, hn::Dot::Compute<0>(d, forward, backward, size)); + hn::Transform1( + d, backward, size, forward, + [&offset](const auto d, const auto v, const auto y) + HWY_ATTR { return hn::Mul(y, hn::Sub(v, offset)); }); +} + +static HWY_NOINLINE void RMSNormVJP( + const float* HWY_RESTRICT weights, const float* HWY_RESTRICT x, + const float* HWY_RESTRICT v, size_t model_dim, size_t num_tokens, + float* HWY_RESTRICT grad_w, float* HWY_RESTRICT grad_x, + hwy::ThreadPool& pool) { + for (size_t pos = 0; pos < num_tokens; ++pos) { + const size_t offset = pos * model_dim; + constexpr float eps = 1e-6f; + float ss = SquaredL2(x + offset, model_dim); + ss = 1.0f / sqrtf(ss / StaticCast(model_dim) + 
eps); + for (size_t i = 0; i < model_dim; ++i) { + grad_w[i] += v[offset + i] * x[offset + i] * ss; + } + const float ss3 = ss * ss * ss / StaticCast(model_dim); + float tmp = 0.0f; + for (size_t i = 0; i < model_dim; ++i) { + tmp += (1.0f + weights[i]) * v[offset + i] * x[offset + i]; + } + tmp *= ss3; + for (size_t i = 0; i < model_dim; ++i) { + grad_x[offset + i] = ss * (1.0f + weights[i]) * v[offset + i] - + tmp * x[offset + i]; + } + } +} + +static HWY_NOINLINE void InputEmbeddingVJP( + const float* weights, const std::vector& prompt, + const float scaling, const float* HWY_RESTRICT backward, + float* HWY_RESTRICT grad, size_t model_dim) { + for (size_t pos = 0; pos + 1 < prompt.size(); ++pos) { + int token = prompt[pos]; + MulByConstAndAdd(scaling, backward + pos * model_dim, + grad + token * model_dim, model_dim); + } +} + +template +void LayerVJP(const Layer& weights, + const ForwardLayer& forward, + const float* HWY_RESTRICT next_layer_grad, + size_t num_tokens, + Layer& grad, + ForwardLayer& backward, + hwy::ThreadPool& pool) { + static constexpr size_t kModelDim = TConfig::kModelDim; + static constexpr size_t kQKVDim = TConfig::kQKVDim; + static constexpr size_t kHeads = TConfig::kHeads; + static constexpr size_t kSeqLen = TConfig::kSeqLen; + static constexpr size_t kFFHiddenDim = TConfig::kFFHiddenDim; + static const float kQueryScale = + static_cast(1.0 / sqrt(static_cast(kQKVDim))); + HWY_ASSERT(num_tokens <= kSeqLen); + + MatMulVJP( + weights.linear_w, forward.ffw_hidden_gated.data(), next_layer_grad, + num_tokens, grad.linear_w, backward.ffw_hidden_gated.data(), + pool); + + for (size_t pos = 0; pos < num_tokens; ++pos) { + const size_t hidden_offset = pos * kFFHiddenDim * 2; + const float* HWY_RESTRICT f_out = forward.ffw_hidden.data() + hidden_offset; + const float* HWY_RESTRICT f_out_mul = f_out + kFFHiddenDim; + const float* HWY_RESTRICT b_out_gated = + backward.ffw_hidden_gated.data() + pos * kFFHiddenDim; + float* HWY_RESTRICT b_out = backward.ffw_hidden.data() + hidden_offset; + float* HWY_RESTRICT b_out_mul = b_out + kFFHiddenDim; + namespace hn = hwy::HWY_NAMESPACE; + using DF = hn::ScalableTag; + using VF = hn::Vec; + DF df; + for (size_t i = 0; i < kFFHiddenDim; i += Lanes(df)) { + const auto y = Load(df, f_out + i); + const auto x = Load(df, f_out_mul + i); + const auto v = Load(df, b_out_gated + i); + hn::Store(hn::Mul(v, Gelu(df, y)), df, b_out_mul + i); + hn::Store(hn::Mul(v, hn::Mul(x, DGelu(df, y))), df, b_out + i); + } + } + + MatMulVJP( + weights.gating_einsum_w, + forward.bf_pre_ffw_rms_out.data(), backward.ffw_hidden.data(), + num_tokens, grad.gating_einsum_w, + backward.bf_pre_ffw_rms_out.data(), pool); + RMSNormVJP(weights.pre_ffw_norm_scale.data(), + forward.attention_out.data(), + backward.bf_pre_ffw_rms_out.data(), + kModelDim, num_tokens, + grad.pre_ffw_norm_scale.data(), + backward.attention_out.data(), pool); + + for (size_t pos = 0; pos < num_tokens; ++pos) { + AddFrom(next_layer_grad + pos * kModelDim, + backward.attention_out.data() + pos * kModelDim, kModelDim); + } + + hwy::ZeroBytes(backward.qkv.data(), + num_tokens * (kHeads + 2) * kQKVDim * sizeof(backward.qkv[0])); + + MultiHeadMatMulVJP( + weights.attn_vec_einsum_w, forward.att_out.data(), + backward.attention_out.data(), num_tokens, + grad.attn_vec_einsum_w, backward.att_out.data(), pool); + + for (size_t head = 0; head < kHeads; ++head) { + for (size_t pos = 0; pos < num_tokens; ++pos) { + const size_t aoffset = head * kSeqLen + pos * kHeads * kSeqLen; + const float* HWY_RESTRICT 
f_head_att = forward.att.data() + aoffset; + const float* HWY_RESTRICT b_att_out = + backward.att_out.data() + (pos * kHeads + head) * kQKVDim; + float* HWY_RESTRICT b_head_att = backward.att.data() + aoffset; + for (size_t pos2 = 0; pos2 <= pos; ++pos2) { + const size_t v2offs = (pos2 * (kHeads + 2) + kHeads + 1) * kQKVDim; + const float* HWY_RESTRICT f_v2 = forward.qkv.data() + v2offs; + float* HWY_RESTRICT b_v2 = backward.qkv.data() + v2offs; + b_head_att[pos2] = Dot(b_att_out, f_v2, kQKVDim); + MulByConstAndAdd(f_head_att[pos2], b_att_out, b_v2, kQKVDim); + } + } + } + + for (size_t head = 0; head < kHeads; ++head) { + for (size_t pos = 0; pos < num_tokens; ++pos) { + const size_t aoffset = head * kSeqLen + pos * kHeads * kSeqLen; + const float* HWY_RESTRICT f_head_att = forward.att.data() + aoffset; + float* HWY_RESTRICT b_head_att = backward.att.data() + aoffset; + SoftmaxVJP(f_head_att, b_head_att, pos + 1); + } + } + + for (size_t head = 0; head < kHeads; ++head) { + for (size_t pos = 0; pos < num_tokens; ++pos) { + const size_t qoffs = (pos * (kHeads + 2) + head) * kQKVDim; + const size_t aoffs = head * kSeqLen + pos * kHeads * kSeqLen; + const float* HWY_RESTRICT f_q = forward.qkv.data() + qoffs; + const float* HWY_RESTRICT b_head_att = backward.att.data() + aoffs; + float* HWY_RESTRICT b_q = backward.qkv.data() + qoffs; + for (size_t pos2 = 0; pos2 <= pos; ++pos2) { + const size_t k2offs = (pos2 * (kHeads + 2) + kHeads) * kQKVDim; + const float* HWY_RESTRICT f_k2 = forward.qkv.data() + k2offs; + float* HWY_RESTRICT b_k2 = backward.qkv.data() + k2offs; + MulByConstAndAdd(b_head_att[pos2], f_k2, b_q, kQKVDim); + MulByConstAndAdd(b_head_att[pos2], f_q, b_k2, kQKVDim); + } + } + } + + for (int pos = 0; pos < num_tokens; ++pos) { + float* HWY_RESTRICT b_kv = + backward.qkv.data() + (pos * (kHeads + 2) + kHeads) * kQKVDim; + Rope(b_kv, kQKVDim, -pos); + } + + for (size_t head = 0; head < kHeads; ++head) { + for (size_t pos = 0; pos < num_tokens; ++pos) { + float* HWY_RESTRICT b_q = + backward.qkv.data() + (pos * (kHeads + 2) + head) * kQKVDim; + MulByConst(kQueryScale, b_q, kQKVDim); + Rope(b_q, kQKVDim, -pos); + } + } + + MatMulVJP( + weights.qkv_einsum_w, forward.pre_att_rms_out.data(), + backward.qkv.data(), num_tokens, + grad.qkv_einsum_w, backward.pre_att_rms_out.data(), pool); + RMSNormVJP(weights.pre_attention_norm_scale.data(), + forward.input.data(), + backward.pre_att_rms_out.data(), + kModelDim, num_tokens, + grad.pre_attention_norm_scale.data(), + backward.input.data(), pool); + for (size_t pos = 0; pos < num_tokens; ++pos) { + AddFrom(backward.attention_out.data() + pos * kModelDim, + backward.input.data() + pos * kModelDim, kModelDim); + } +} + +static HWY_NOINLINE void SoftcapVJP(const float* HWY_RESTRICT forward, + float* HWY_RESTRICT backward, + const float cap, + const size_t size) { + namespace hn = hwy::HWY_NAMESPACE; + using D = hn::ScalableTag; + const D d; + + const auto one = hn::Set(d, 1.0f); + const auto vcap = hn::Set(d, cap); + const auto vinv_cap = hn::Div(hn::Set(d, 1.0f), vcap); + + hn::Transform1( + d, backward, size, forward, + [&](const auto d, const auto v, const auto y) HWY_ATTR { + const auto scaled = hn::Mul(vinv_cap, y); + return hn::Mul(v, hn::Sub(one, hn::Mul(scaled, scaled))); + }); +} + +static HWY_NOINLINE void CrossEntropyLossGrad( + const float* HWY_RESTRICT x, float* HWY_RESTRICT grad, + const Prompt& prompt, size_t vocab_size) { + const float scaling = -1.0 / std::log(2.0); + size_t num_tokens = prompt.tokens.size() - 1; + memset(grad, 
0, num_tokens * vocab_size * sizeof(grad[0])); + for (size_t i = 0; i + 1 < prompt.tokens.size(); ++i) { + if (i + 1 < prompt.context_size) { + continue; + } + const int next_token = prompt.tokens[i + 1]; + grad[i * vocab_size + next_token] = + scaling / x[i * vocab_size + next_token]; + } +} + +template +void CrossEntropyLossBackwardPass(const Prompt& prompt, + const Weights& weights, + const ForwardPass& forward, + Weights& grad, + ForwardPass& backward, + hwy::ThreadPool& pool) { + static constexpr size_t kVocabSize = TConfig::kVocabSize; + static constexpr size_t kModelDim = TConfig::kModelDim; + static constexpr size_t kSeqLen = TConfig::kSeqLen; + static constexpr size_t kLayers = TConfig::kLayers; + const float kEmbScaling = EmbeddingScaling(); + static_assert(!TConfig::kAbsolutePE); + static_assert(!TConfig::kPostNormScale); + static_assert(TConfig::kKVHeads == 1); + + HWY_DASSERT(prompt.context_size > 0); + HWY_DASSERT(prompt.context_size < prompt.tokens.size()); + const size_t num_tokens = prompt.tokens.size() - 1; + + CrossEntropyLossGrad(forward.probs.data(), backward.logits.data(), prompt, + kVocabSize); + + for (size_t pos = 0; pos < num_tokens; ++pos) { + SoftmaxVJP(forward.probs.data() + pos * kVocabSize, + backward.logits.data() + pos * kVocabSize, + kVocabSize); + } + + for (size_t pos = 0; pos < num_tokens; ++pos) { + SoftcapVJP(forward.logits.data() + pos * kVocabSize, + backward.logits.data() + pos * kVocabSize, 30.0f, kVocabSize); + } + + MatMulVJP( + weights.embedder_input_embedding, forward.final_norm_output.data(), + backward.logits.data(), num_tokens, + grad.embedder_input_embedding, backward.final_norm_output.data(), + pool); + + RMSNormVJP(weights.final_norm_scale.data(), + forward.final_layer_output.data(), + backward.final_norm_output.data(), + kModelDim, num_tokens, + grad.final_norm_scale.data(), + backward.final_layer_output.data(), pool); + + for (int layer = static_cast(kLayers) - 1; layer >= 0; --layer) { + auto type = TConfig::kLayerConfig[layer]; + // TODO(szabadka) Implement Griffin layer vjp. + HWY_ASSERT(type == LayerAttentionType::kGemma); + float* next_layer_grad = layer + 1 < kLayers + ? backward.layers[layer + 1].input.data() + : backward.final_layer_output.data(); + LayerVJP(*weights.GetLayer(layer), forward.layers[layer], next_layer_grad, + num_tokens, *grad.GetLayer(layer), backward.layers[layer], pool); + } + + InputEmbeddingVJP(weights.embedder_input_embedding.data(), prompt.tokens, + kEmbScaling, backward.layers[0].input.data(), + grad.embedder_input_embedding.data(), kModelDim); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace gcpp +HWY_AFTER_NAMESPACE(); + +#endif // NOLINT diff --git a/gemma/backward.cc b/gemma/backward.cc new file mode 100644 index 0000000..eb5b511 --- /dev/null +++ b/gemma/backward.cc @@ -0,0 +1,88 @@ +// Copyright 2024 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
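// Editorial derivation note (not part of the patch): CrossEntropyLossGrad
// above assumes the forward pass computes the loss in bits, i.e.
// loss = -sum_i log2(p[i, next_token_i]) over the predicted positions, so
//   d loss / d p[i, next_token_i] = -1 / (ln(2) * p[i, next_token_i]),
// which is exactly scaling / x[i * vocab_size + next_token] with
// scaling = -1.0 / log(2.0). The gradient is then pulled back through the
// final Softmax (SoftmaxVJP) and the logit soft-cap (SoftcapVJP with cap 30)
// before entering the transformer layers in reverse order.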
+ +#include "gemma/backward.h" + +// Compiles this file for multiple architectures via "foreach_target.h", to +// which we pass the filename via macro 'argument'. +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "gemma/backward.cc" // NOLINT +#include "hwy/foreach_target.h" // IWYU pragma: keep + +#include "gemma/backward-inl.h" +#include "gemma/weights.h" +#include "hwy/highway.h" + +HWY_BEFORE_NAMESPACE(); +namespace gcpp { +namespace HWY_NAMESPACE { +namespace hn = hwy::HWY_NAMESPACE; + +template +void CrossEntropyLossBackwardPass(const Prompt& prompt, + const ByteStorageT& weights_u8, + const ByteStorageT& forward_u8, + ByteStorageT& grad_u8, + ByteStorageT& backward_u8, + hwy::ThreadPool& pool) { + using TWeights = WeightsF; + const auto& weights = *reinterpret_cast(weights_u8.get()); + auto& grad = *reinterpret_cast(grad_u8.get()); + using TAct = ForwardPass; + const auto& forward = *reinterpret_cast(forward_u8.get()); + auto& backward = *reinterpret_cast(backward_u8.get()); + CrossEntropyLossBackwardPass(prompt, weights, forward, grad, backward, pool); +} + +void CrossEntropyLossBackwardPassT(Model model, + const Prompt& prompt, + const ByteStorageT& weights, + const ByteStorageT& forward, + ByteStorageT& grad, + ByteStorageT& backward, + hwy::ThreadPool& pool) { + switch (model) { + case Model::GEMMA_2B: + CrossEntropyLossBackwardPass( + prompt, weights, forward, grad, backward, pool); + break; + case Model::GEMMA_TINY: + CrossEntropyLossBackwardPass( + prompt, weights, forward, grad, backward, pool); + break; + default: + HWY_ABORT("Model type %d unknown.", static_cast(model)); + } +} + +} // namespace HWY_NAMESPACE +} // namespace gcpp +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace gcpp { + +HWY_EXPORT(CrossEntropyLossBackwardPassT); + +void CrossEntropyLossBackwardPass( + const Model& model, const Prompt& prompt, + const ByteStorageT& weights, const ByteStorageT& forward, + ByteStorageT& grad, ByteStorageT& backward, hwy::ThreadPool& pool) { + return HWY_DYNAMIC_DISPATCH(CrossEntropyLossBackwardPassT)( + model, prompt, weights, forward, grad, backward, pool); +} + +} // namespace gcpp +#endif // HWY_ONCE diff --git a/gemma/backward.h b/gemma/backward.h new file mode 100644 index 0000000..72281fb --- /dev/null +++ b/gemma/backward.h @@ -0,0 +1,34 @@ +// Copyright 2024 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_H_ +#define THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_H_ + +#include + +#include "gemma/common.h" +#include "gemma/prompt.h" +#include "hwy/contrib/thread_pool/thread_pool.h" + +namespace gcpp { + +void CrossEntropyLossBackwardPass( + const Model& model, const Prompt& prompt, + const ByteStorageT& weights, const ByteStorageT& forward, + ByteStorageT& grad, ByteStorageT& backward, hwy::ThreadPool& pool); + +} // namespace gcpp + +#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_H_ diff --git a/gemma/backward_scalar.h b/gemma/backward_scalar.h new file mode 100644 index 0000000..b7b42b4 --- /dev/null +++ b/gemma/backward_scalar.h @@ -0,0 +1,351 @@ +// Copyright 2024 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_SCALAR_H_ +#define THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_SCALAR_H_ + +#include +#include + +#include +#include +#include + +#include "gemma/activations.h" +#include "gemma/common_scalar.h" +#include "gemma/prompt.h" +#include "gemma/weights.h" + +namespace gcpp { +template +void MatMulVJPT(const T* w, const T* x, const T* dy, T* dw, T* dx, + size_t N, size_t M, size_t K) { + memset(dx, 0, M * K * sizeof(dx[0])); + for (size_t i = 0; i < K; ++i) { + for (size_t j = 0; j < N; ++j) { + MulByConstAndAddT(dy[i * N + j], &x[i * M], &dw[j * M], M); + MulByConstAndAddT(dy[i * N + j], &w[j * M], &dx[i * M], M); + } + } +} +template +void MultiHeadMatMulVJPT(const T* w, const T* x, const T* dy, T* dw, T* dx, + size_t H, size_t N, size_t M, size_t K) { + memset(dx, 0, H * M * K * sizeof(dx[0])); + for (size_t i = 0; i < K; ++i) { + for (size_t j = 0; j < N; ++j) { + for (size_t h = 0; h < H; ++h) { + MulByConstAndAddT(dy[i * N + j], &x[i * H * M + h * M], + &dw[h * N * M + j * M], M); + MulByConstAndAddT(dy[i * N + j], &w[h * N * M + j * M], + &dx[i * H * M + h * M], M); + } + } + } +} + +template +void RMSNormVJPT(const T* w, const T* x, const T* dy, T* dw, T* dx, + size_t N, size_t K) { + for (size_t i = 0; i < K; ++i) { + const size_t offset = i * N; + constexpr T eps(1e-6); + T ss = SquaredL2(x + i * N, N); + ss = T(1.0) / std::sqrt(ss / T(N) + eps); + for (size_t j = 0; j < N; ++j) { + dw[j] += dy[i * N + j] * x[i * N + j] * ss; + } + const T ss3 = ss * ss * ss / T(N); + T tmp = 0.0; + for (size_t j = 0; j < N; ++j) { + tmp += (T(1.0) + w[j]) * dy[i* N + j] * x[i * N + j]; + } + tmp *= ss3; + for (size_t j = 0; j < N; ++j) { + dx[i * N + j] = ss * (T(1.0) + w[j]) * dy[i* N + j] - tmp * x[i * N + j]; + } + } +} +template +void SoftmaxVJPT(const T* y, T* dy, size_t N) { + T sum = {}; + for (size_t i = 0; i < N; ++i) { + sum += y[i] * dy[i]; + } + for (size_t i = 0; i < N; ++i) { + dy[i] = y[i] * (dy[i] - sum); + } +} +template +void SoftmaxVJPT(const T* y, T* dy, size_t N, size_t K) { + for (size_t i = 0; i < K; ++i) { + SoftmaxVJPT(y + i * N, dy + i * N, N); + } +} + +template +T GeluDerivative(T x) { + static const T kMul = 0.044715; + static 
const T kSqrt2OverPi = 0.797884560804236; + static const T kMul2 = kSqrt2OverPi * T(3.0) * kMul; + + const T x2 = x * x; + const T x3 = x2 * x; + const T arg = kSqrt2OverPi * (kMul * x3 + x); + const T tanh = std::tanh(arg); + const T cdf = T(0.5) * (T(1.0) + tanh); + const T dtanh = T(1.0) - tanh * tanh; + const T darg = kMul2 * x2 + kSqrt2OverPi; + return T(0.5) * x * dtanh * darg + cdf; +} + +template +void GatedGeluVJP(const T* in, const T* d_out, T* d_in, size_t N, size_t K) { + for (size_t i = 0; i < K; ++i) { + const T* x1 = in + i * 2 * N; + const T* x2 = x1 + N; + const T* v = d_out + i * N; + T* dx1 = d_in + i * 2 * N; + T* dx2 = dx1 + N; + for (size_t j = 0; j < N; ++j) { + dx1[j] = v[j] * x2[j] * GeluDerivative(x1[j]); + dx2[j] = v[j] * Gelu(x1[j]); + } + } +} + + +template +void MaskedAttentionVJP(const T* qkv, const T* doutput, T* dqkv, + size_t num_tokens, size_t kHeads, size_t kQKVDim, + size_t kSeqLen) { + for (size_t pos = 0; pos < num_tokens; ++pos) { + const size_t offset = pos * (kHeads + 2) * kQKVDim; + memset(dqkv + offset, 0, (kHeads + 1) * kQKVDim * sizeof(qkv[0])); + } + for (size_t head = 0; head < kHeads; ++head) { + for (size_t pos = 0; pos < num_tokens; ++pos) { + const size_t qoffs = (pos * (kHeads + 2) + head) * kQKVDim; + const size_t aoffs = head * kSeqLen + pos * kHeads * kSeqLen; + const T* q = qkv + qoffs; + const T* dout = doutput + aoffs; + T* dq = dqkv + qoffs; + for (size_t pos2 = 0; pos2 <= pos; ++pos2) { + const size_t koffs = (pos2 * (kHeads + 2) + kHeads) * kQKVDim; + const T* k = qkv + koffs; + T* dk = dqkv + koffs; + MulByConstAndAddT(dout[pos2], k, dq, kQKVDim); + MulByConstAndAddT(dout[pos2], q, dk, kQKVDim); + } + } + } +} + +template +void MaskedSoftmaxVJPT(const T* y, T* dy, size_t num_tokens, + size_t kHeads, size_t kSeqLen) { + for (size_t head = 0; head < kHeads; ++head) { + for (size_t pos = 0; pos < num_tokens; ++pos) { + size_t offset = pos * kHeads * kSeqLen + head * kSeqLen; + SoftmaxVJPT(y + offset, dy + offset, pos + 1); + memset(dy + offset + pos + 1, 0, (kSeqLen - pos - 1) * sizeof(T)); + } + } +} + +template +void MixByAttentionVJP(const T* qkv, const T* attention, const T* doutput, + T* dqkv, T* dattention, size_t num_tokens, + size_t kHeads, size_t kQKVDim, size_t kSeqLen) { + auto v_offset = [&](size_t pos) { + return (pos * (kHeads + 2) + kHeads + 1) * kQKVDim; + }; + for (size_t pos = 0; pos < num_tokens; ++pos) { + memset(&dqkv[v_offset(pos)], 0, kQKVDim * sizeof(qkv[0])); + } + for (size_t head = 0; head < kHeads; ++head) { + for (size_t pos = 0; pos < num_tokens; ++pos) { + const size_t offset = head * kQKVDim + pos * kHeads * kQKVDim; + const size_t aoffset = head * kSeqLen + pos * kHeads * kSeqLen; + const T* att = &attention[aoffset]; + const T* dout = &doutput[offset]; + T* datt = &dattention[aoffset]; + for (size_t pos2 = 0; pos2 <= pos; ++pos2) { + datt[pos2] = DotT(dout, &qkv[v_offset(pos2)], kQKVDim); + MulByConstAndAddT(att[pos2], dout, &dqkv[v_offset(pos2)], kQKVDim); + } + } + } +} + +template +void InputEmbeddingVJPT(const T* w, const std::vector& tokens, T scaling, + const T* dy, T* dw, size_t N) { + for (size_t i = 0; i + 1 < tokens.size(); ++i) { + int token = tokens[i]; + MulByConstAndAddT(scaling, dy + i * N, dw + token * N, N); + } +} + +template +void LayerVJP(const Layer& weights, + const ForwardLayer& forward, + const T* dy, + Layer& grad, + ForwardLayer& backward, + size_t num_tokens) { + static constexpr size_t kModelDim = TConfig::kModelDim; + static constexpr size_t kSeqLen = TConfig::kSeqLen; 
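  // Editorial note (not part of the patch): this scalar LayerVJP serves as the
  // reference against which the vectorized LayerVJP in backward-inl.h is
  // checked (see the comparisons against the *VJPT functions in
  // backward_test.cc). It applies the transposed forward operations in reverse
  // order: output projection, gated GELU, gating matmul, pre-FFW RMSNorm, the
  // attention output projection, value mixing, masked softmax, the QK dot
  // products, RoPE and query scaling, the QKV projection, and the
  // pre-attention RMSNorm.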
+ static constexpr size_t kQKVDim = TConfig::kQKVDim; + static constexpr size_t kHeads = TConfig::kHeads; + static constexpr size_t kFFHiddenDim = TConfig::kFFHiddenDim; + static const T kQueryScale = 1.0 / std::sqrt(T(kQKVDim)); + + MatMulVJPT(weights.linear_w.data(), forward.ffw_hidden_gated.data(), + dy, grad.linear_w.data(), backward.ffw_hidden_gated.data(), + kModelDim, kFFHiddenDim, num_tokens); + + GatedGeluVJP(forward.ffw_hidden.data(), backward.ffw_hidden_gated.data(), + backward.ffw_hidden.data(), kFFHiddenDim, num_tokens); + + MatMulVJPT(weights.gating_einsum_w.data(), forward.bf_pre_ffw_rms_out.data(), + backward.ffw_hidden.data(), grad.gating_einsum_w.data(), + backward.bf_pre_ffw_rms_out.data(), kFFHiddenDim * 2, kModelDim, + num_tokens); + + RMSNormVJPT(weights.pre_ffw_norm_scale.data(), forward.attention_out.data(), + backward.bf_pre_ffw_rms_out.data(), + grad.pre_ffw_norm_scale.data(), backward.attention_out.data(), + kModelDim, num_tokens); + + AddFromT(dy, backward.attention_out.data(), num_tokens * kModelDim); + + MultiHeadMatMulVJPT(weights.attn_vec_einsum_w.data(), forward.att_out.data(), + backward.attention_out.data(), + grad.attn_vec_einsum_w.data(), + backward.att_out.data(), + kHeads, kModelDim, kQKVDim, num_tokens); + + MixByAttentionVJP(forward.qkv.data(), forward.att.data(), + backward.att_out.data(), backward.qkv.data(), + backward.att.data(), num_tokens, kHeads, kQKVDim, + kSeqLen); + + MaskedSoftmaxVJPT(forward.att.data(), backward.att.data(), + num_tokens, kHeads, kSeqLen); + + MaskedAttentionVJP(forward.qkv.data(), backward.att.data(), + backward.qkv.data(), num_tokens, kHeads, kQKVDim, kSeqLen); + + for (size_t pos = 0; pos < num_tokens; ++pos) { + T* qkv = backward.qkv.data() + pos * (kHeads + 2) * kQKVDim; + MulByConstT(kQueryScale, qkv, kHeads * kQKVDim); + } + + for (int pos = 0; pos < num_tokens; ++pos) { + T* qkv = backward.qkv.data() + pos * (kHeads + 2) * kQKVDim; + for (size_t h = 0; h <= kHeads; ++h) { + Rope(qkv + h * kQKVDim, kQKVDim, -pos); + } + } + + MatMulVJPT(weights.qkv_einsum_w.data(), forward.pre_att_rms_out.data(), + backward.qkv.data(), grad.qkv_einsum_w.data(), + backward.pre_att_rms_out.data(), + (kHeads + 2) * kQKVDim, kModelDim, num_tokens); + RMSNormVJPT(weights.pre_attention_norm_scale.data(), forward.input.data(), + backward.pre_att_rms_out.data(), + grad.pre_attention_norm_scale.data(), + backward.input.data(), kModelDim, num_tokens); + + AddFromT(backward.attention_out.data(), backward.input.data(), + num_tokens * kModelDim); +} + +template +void SoftcapVJPT(const T* y, T* dy, size_t N) { + T cap = 30.0; + T inv_cap = T(1.0) / cap; + for (size_t i = 0; i < N; ++i) { + T scaled = y[i] * inv_cap; + dy[i] *= (T(1.0) - scaled * scaled); + } +} + +template +void CrossEntropyLossGrad(const T* x, T* dx, const Prompt& prompt, size_t V) { + T scaling = -1.0 / std::log(2.0); + size_t num_tokens = prompt.tokens.size() - 1; + memset(dx, 0, V * num_tokens * sizeof(x[0])); + for (size_t i = 0; i + 1 < prompt.tokens.size(); ++i) { + if (i + 1 < prompt.context_size) { + continue; + } + const int next_token = prompt.tokens[i + 1]; + dx[i * V + next_token] = scaling / x[i * V + next_token]; + } +} + +template +void CrossEntropyLossBackwardPass(const Prompt& prompt, + const Weights& weights, + const ForwardPass& forward, + Weights& grad, + ForwardPass& backward) { + static constexpr size_t kModelDim = TConfig::kModelDim; + static constexpr size_t kVocabSize = TConfig::kVocabSize; + static constexpr size_t kLayers = TConfig::kLayers; + const 
size_t num_tokens = prompt.tokens.size() - 1; + + CrossEntropyLossGrad(forward.probs.data(), backward.logits.data(), prompt, + kVocabSize); + + SoftmaxVJPT(forward.probs.data(), backward.logits.data(), + kVocabSize, num_tokens); + + SoftcapVJPT(forward.logits.data(), backward.logits.data(), + num_tokens * kVocabSize); + + MatMulVJPT(weights.embedder_input_embedding.data(), + forward.final_norm_output.data(), + backward.logits.data(), + grad.embedder_input_embedding.data(), + backward.final_norm_output.data(), + kVocabSize, kModelDim, num_tokens); + + RMSNormVJPT(weights.final_norm_scale.data(), + forward.final_layer_output.data(), + backward.final_norm_output.data(), + grad.final_norm_scale.data(), + backward.final_layer_output.data(), kModelDim, num_tokens); + + for (int layer = static_cast(kLayers) - 1; layer >= 0; --layer) { + T* next_layer_grad = layer + 1 < kLayers + ? backward.layers[layer + 1].input.data() + : backward.final_layer_output.data(); + LayerVJP(*weights.GetLayer(layer), forward.layers[layer], next_layer_grad, + *grad.GetLayer(layer), backward.layers[layer], num_tokens); + } + + const T kEmbScaling = EmbeddingScaling(kModelDim); + InputEmbeddingVJPT(weights.embedder_input_embedding.data(), + prompt.tokens, kEmbScaling, + backward.layers[0].input.data(), + grad.embedder_input_embedding.data(), kModelDim); +} + +} // namespace gcpp + +#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_SCALAR_H_ diff --git a/gemma/backward_scalar_test.cc b/gemma/backward_scalar_test.cc new file mode 100644 index 0000000..713665b --- /dev/null +++ b/gemma/backward_scalar_test.cc @@ -0,0 +1,610 @@ +// Copyright 2024 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
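// Editorial note (not part of the patch): the tests below check gradients with
// what appears to be complex-step differentiation: weights and inputs are
// Complexify()-ed, the forward pass is evaluated on std::complex values, and
// TestGradient (from test_util.h, assumed) compares each analytic partial
// derivative against Im(f(x + i*h)) / h. Unlike finite differences, this has
// no subtractive cancellation, which is why tolerances as tight as
// 1e-14..1e-15 can be used in the TestGradient calls.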
+ +#include "gemma/backward_scalar.h" + +#include +#include +#include + +#include "gemma/forward_scalar.h" +#include "gemma/sampler.h" +#include "gemma/test_util.h" +#include "gtest/gtest.h" + +namespace gcpp { + +TEST(BackPropTest, MatMulVJP) { + static const size_t kRows = 8; + static const size_t kCols = 64; + static const size_t kTokens = 5; + std::mt19937 gen(42); + using T = double; + using TC = std::complex; + std::array weights; + std::array x; + std::array grad; + std::array dx; + std::array c_weights; + std::array c_x; + std::array c_y; + std::array dy; + + for (int iter = 0; iter < 10; ++iter) { + RandInit(weights, 1.0 * (1 << iter), gen); + RandInit(x, 1.0 * (1 << iter), gen); + RandInit(dy, 1.0, gen); + Complexify(weights, c_weights); + Complexify(x, c_x); + auto func = [&]() { + MatMulT(c_weights.data(), c_x.data(), c_y.data(), kRows, kCols, kTokens); + return DotT(dy.data(), c_y.data(), kTokens * kRows); + }; + memset(&grad, 0, sizeof(grad)); + MatMulVJPT(weights.data(), x.data(), dy.data(), grad.data(), dx.data(), + kRows, kCols, kTokens); + TestGradient(dx, c_x, func, 1e-11, 1e-12,__LINE__); + TestGradient(grad, c_weights, func, 1e-14, 1e-12,__LINE__); + } +} + +TEST(BackPropTest, MultiHeadMatMulVJP) { + static const size_t kRows = 2; + static const size_t kCols = 16; + static const size_t kHeads = 4; + static const size_t kTokens = 3; + std::mt19937 gen(42); + using T = double; + using TC = std::complex; + std::array weights; + std::array x; + std::array grad; + std::array dx; + std::array c_weights; + std::array c_x; + std::array c_y; + std::array dy; + + for (int iter = 0; iter < 10; ++iter) { + RandInit(weights, 1.0 * (1 << iter), gen); + RandInit(x, 1.0 * (1 << iter), gen); + RandInit(dy, 1.0, gen); + Complexify(weights, c_weights); + Complexify(x, c_x); + auto func = [&]() { + MultiHeadMatMul(c_weights.data(), c_x.data(), c_y.data(), kHeads, kRows, + kCols, kTokens); + return DotT(dy.data(), c_y.data(), kTokens * kRows); + }; + memset(&grad, 0, sizeof(grad)); + MultiHeadMatMulVJPT(weights.data(), x.data(), dy.data(), grad.data(), + dx.data(), kHeads, kRows, kCols, kTokens); + TestGradient(dx, c_x, func, 1e-15, 1e-13,__LINE__); + TestGradient(grad, c_weights, func, 1e-15, 1e-13,__LINE__); + } +} + +TEST(BackPropTest, RMSNormVJP) { + static const size_t K = 2; + static const size_t N = 64; + std::mt19937 gen(42); + using T = double; + using TC = std::complex; + std::array weights; + std::array grad; + std::array x; + std::array dx; + std::array dy; + std::array c_weights; + std::array c_x; + std::array c_y; + + for (int iter = 0; iter < 10; ++iter) { + RandInit(weights, 1.0 * (1 << iter), gen); + RandInit(x, 1.0 * (1 << iter), gen); + Complexify(weights, c_weights); + Complexify(x, c_x); + RandInit(dy, 1.0, gen); + auto func = [&]() { + RMSNormT(c_weights.data(), c_x.data(), c_y.data(), N, K); + return DotT(dy.data(), c_y.data(), K * N); + }; + memset(&grad, 0, sizeof(grad)); + RMSNormVJPT(weights.data(), x.data(), dy.data(), grad.data(), dx.data(), + N, K); + TestGradient(dx, c_x, func, 1e-15, 1e-14, __LINE__); + TestGradient(grad, c_weights, func, 1e-15, 1e-14, __LINE__); + } +} + +TEST(BackPropTest, SoftmaxVJP) { + static const size_t N = 64; + std::mt19937 gen(42); + using T = double; + using TC = std::complex; + std::array x; + std::array dx; + std::array dy; + std::array c_x; + std::array c_y; + + for (int iter = 0; iter < 10; ++iter) { + RandInit(x, 1.0 * (1 << iter), gen); + Complexify(x, c_x); + RandInit(dy, 1.0, gen); + auto func = [&]() { + 
memcpy(c_y.data(), c_x.data(), sizeof(c_x)); + Softmax(c_y.data(), N); + return DotT(dy.data(), c_y.data(), N); + }; + Softmax(x.data(), N); + memcpy(dx.data(), dy.data(), N * sizeof(dx[0])); + SoftmaxVJPT(x.data(), dx.data(), N); + TestGradient(dx, c_x, func, 1e-15, 1e-15, __LINE__); + } +} + +TEST(BackPropTest, MaskedSoftmaxVJP) { + static const size_t kSeqLen = 16; + static const size_t kHeads = 2; + static const size_t kTokens = 14; + static const size_t N = kHeads * kSeqLen * kSeqLen; + std::mt19937 gen(42); + using T = double; + using TC = std::complex; + std::array x; + std::array dy; + std::array dx = {}; + std::array c_x; + std::array c_y; + + for (int iter = 0; iter < 10; ++iter) { + RandInit(x, 1.0 * (1 << iter), gen); + Complexify(x, c_x); + RandInit(dy, 1.0, gen); + auto func = [&]() { + memcpy(c_y.data(), c_x.data(), + kTokens * kHeads * kSeqLen * sizeof(c_x[0])); + MaskedSoftmax(c_y.data(), kTokens, kHeads, kSeqLen); + return DotT(dy.data(), c_y.data(), N); + }; + MaskedSoftmax(x.data(), kTokens, kHeads, kSeqLen); + memcpy(dx.data(), dy.data(), kTokens * kHeads * kSeqLen * sizeof(dx[0])); + MaskedSoftmaxVJPT(x.data(), dx.data(), kTokens, kHeads, kSeqLen); + TestGradient(dx, c_x, func, 1e-14, 1e-15, __LINE__); + } +} + +TEST(BackPropTest, SoftcapVJP) { + static const size_t N = 64; + std::mt19937 gen(42); + using T = double; + using TC = std::complex; + std::array x; + std::array dx; + std::array dy; + std::array c_x; + std::array c_y; + + for (int iter = 0; iter < 10; ++iter) { + RandInit(x, 1.0 * (1 << iter), gen); + Complexify(x, c_x); + RandInit(dy, 1.0, gen); + auto func = [&]() { + memcpy(c_y.data(), c_x.data(), N * sizeof(c_x[0])); + Softcap(c_y.data(), N); + return DotT(dy.data(), c_y.data(), N); + }; + Softcap(x.data(), N); + memcpy(dx.data(), dy.data(), N * sizeof(dx[0])); + SoftcapVJPT(x.data(), dx.data(), N); + TestGradient(dx, c_x, func, 1e-15, 1e-15, __LINE__); + } +} + +TEST(BackPropTest, CrossEntropyLossGrad) { + static const size_t K = 8; + static const size_t V = 64; + std::mt19937 gen(42); + using T = double; + using TC = std::complex; + std::array x; + std::array dx; + std::array c_x; + Prompt prompt; + prompt.tokens = { 0, 1, 2, 3, 0, 3, 2, 1, 0 }; + + for (int iter = 0; iter < 10; ++iter) { + prompt.context_size = 1 + (iter % 6); + RandInit(x, 1.0 * (1 << iter), gen); + Softcap(x.data(), V * K); + Softmax(x.data(), V, K); + CrossEntropyLossGrad(x.data(), dx.data(), prompt, V); + Complexify(x, c_x); + auto func = [&]() { + return CrossEntropyLoss(c_x.data(), prompt, V); + }; + TestGradient(dx, c_x, func, 1e-100, 1e-15, __LINE__); + } +} + +TEST(BackPropTest, GatedGeluVJP) { + static const size_t K = 2; + static const size_t N = 64; + std::mt19937 gen(42); + using T = double; + using TC = std::complex; + std::array x; + std::array dx; + std::array dy; + std::array c_x; + std::array c_y; + + for (int iter = 0; iter < 10; ++iter) { + RandInit(x, 1.0, gen); + Complexify(x, c_x); + RandInit(dy, 1.0, gen); + auto func = [&]() { + GatedGelu(c_x.data(), c_y.data(), N, K); + return DotT(dy.data(), c_y.data(), N * K); + }; + GatedGeluVJP(x.data(), dy.data(), dx.data(), N, K); + TestGradient(dx, c_x, func, 1e-15, 1e-15, __LINE__); + } +} + +TEST(BackPropTest, MaskedAttentionVJP) { + static const size_t kSeqLen = 16; + static const size_t kHeads = 2; + static const size_t kQKVDim = 8; + static const size_t kTokens = 14; + static const size_t kQKVSize = kSeqLen * (kHeads + 2) * kQKVDim; + static const size_t kOutSize = kSeqLen * kHeads * kSeqLen; + std::mt19937 
gen(42); + using T = double; + using TC = std::complex; + std::array x; + std::array dx = {}; + std::array dy; + std::array c_x; + std::array c_y; + + for (int iter = 0; iter < 10; ++iter) { + RandInit(x, 1.0, gen); + Complexify(x, c_x); + RandInit(dy, 1.0, gen); + auto func = [&]() { + MaskedAttention(c_x.data(), c_y.data(), kTokens, kHeads, kQKVDim, + kSeqLen); + return DotT(dy.data(), c_y.data(), kOutSize); + }; + MaskedAttentionVJP(x.data(), dy.data(), dx.data(), + kTokens, kHeads, kQKVDim, kSeqLen); + TestGradient(dx, c_x, func, 1e-14, 1e-15, __LINE__); + } +} + +TEST(BackPropTest, MixByAttentionVJP) { + static const size_t kSeqLen = 16; + static const size_t kHeads = 2; + static const size_t kQKVDim = 8; + static const size_t kTokens = 14; + static const size_t kQKVSize = kSeqLen * (kHeads + 2) * kQKVDim; + static const size_t kAttnSize = kSeqLen * kHeads * kSeqLen; + static const size_t kOutSize = kSeqLen * kHeads * kQKVDim; + std::mt19937 gen(42); + using T = double; + using TC = std::complex; + std::array qkv; + std::array dqkv = {}; + std::array attn; + std::array dattn = {}; + std::array dy; + std::array c_qkv; + std::array c_attn; + std::array c_y; + + for (int iter = 0; iter < 10; ++iter) { + RandInit(qkv, 1.0, gen); + RandInit(attn, 1.0, gen); + Complexify(qkv, c_qkv); + Complexify(attn, c_attn); + RandInit(dy, 1.0, gen); + auto func = [&]() { + MixByAttention(c_qkv.data(), c_attn.data(), c_y.data(), + kTokens, kHeads, kQKVDim, kSeqLen); + return DotT(dy.data(), c_y.data(), kOutSize); + }; + MixByAttentionVJP(qkv.data(), attn.data(), dy.data(), dqkv.data(), + dattn.data(), kTokens, kHeads, kQKVDim, kSeqLen); + TestGradient(dqkv, c_qkv, func, 1e-14, 1e-15, __LINE__); + TestGradient(dattn, c_attn, func, 1e-14, 1e-15, __LINE__); + } +} + +TEST(BackPropTest, InputEmbeddingVJP) { + static const size_t kSeqLen = 8; + static const size_t kVocabSize = 4; + static const size_t kModelDim = 16; + std::mt19937 gen(42); + using T = double; + using TC = std::complex; + std::array weights; + std::array grad; + std::array dy; + std::array c_weights; + std::array c_y; + std::vector tokens = { 0, 1, 2, 3, 0, 1, 2 }; + size_t num_tokens = tokens.size() - 1; + + for (size_t iter = 0; iter < 10; ++iter) { + RandInit(weights, 1.0, gen); + RandInit(dy, 1.0, gen); + Complexify(weights, c_weights); + auto func = [&]() { + InputEmbedding(c_weights.data(), tokens, TC(3.0), c_y.data(), kModelDim); + return DotT(dy.data(), c_y.data(), num_tokens * kModelDim); + }; + memset(&grad, 0, sizeof(grad)); + InputEmbeddingVJPT(weights.data(), tokens, 3.0, dy.data(), grad.data(), + kModelDim); + TestGradient(grad, c_weights, func, 1e-16, 1e-14, __LINE__); + } +} + +struct TestConfig { + static constexpr int kSeqLen = 18; + static constexpr int kVocabSize = 12; + static constexpr int kModelDim = 32; + static constexpr int kHeads = 3; + static constexpr int kQKVDim = 12; + static constexpr int kFFHiddenDim = 48; + static constexpr std::array kLayerConfig = + FixedLayerConfig<2>(LayerAttentionType::kGemma); + static constexpr int kLayers = kLayerConfig.size(); + static constexpr bool kAbsolutePE = false; + static constexpr bool kPostNormScale = false; + + static constexpr int kKVHeads = 1; + static constexpr int kConv1dWidth = 0; + static constexpr bool kFFBiases = false; + static constexpr bool kSoftmaxAttnOutputBiases = false; + static constexpr int kGemmaLayers = kLayers; + static constexpr int kGriffinLayers = 0; + static constexpr int kNumTensorScales = 0; +}; + +TEST(BackPropTest, LayerVJP) { + std::mt19937 
gen(42); + using T = double; + using TC = std::complex; + const size_t kOutputSize = TestConfig::kSeqLen * TestConfig::kModelDim; + Layer weights; + Layer grad; + ForwardLayer forward; + ForwardLayer backward = {}; + Layer c_weights; + ForwardLayer c_forward; + std::array y; + std::array dy; + std::array c_y; + const size_t num_tokens = 3; + + for (size_t iter = 0; iter < 10; ++iter) { + RandInit(weights, 1.0, gen); + RandInit(forward.input, 1.0, gen); + RandInit(dy, 1.0, gen); + Complexify(weights, c_weights); + Complexify(forward.input, c_forward.input); + auto func = [&]() { + ApplyLayer(c_weights, c_forward, num_tokens, c_y.data()); + return DotT(dy.data(), c_y.data(), num_tokens * TestConfig::kModelDim); + }; + memset(&grad, 0, sizeof(grad)); + ApplyLayer(weights, forward, num_tokens, y.data()); + LayerVJP(weights, forward, dy.data(), grad, backward, num_tokens); + TestGradient(backward.input, c_forward.input, func, 1e-11, 1e-11, + __LINE__); + TestGradient(grad, c_weights, func, 1e-11); + } +} + +TEST(BackPropTest, EndToEnd) { + std::mt19937 gen(42); + using T = double; + using TC = std::complex; + WeightsWrapper weights; + WeightsWrapper grad; + ForwardPass forward; + ForwardPass backward; + WeightsWrapper c_weights; + ForwardPass c_forward; + + ReverseSequenceSampler training_task({0, 0, 1, 1}); + std::vector batch = training_task.SampleBatch(10, gen); + + for (const Prompt& prompt : batch) { + ReverseSequenceSampler::LogPrompt(prompt); + RandInit(weights.get(), 1.0, gen); + CrossEntropyLossForwardPass(prompt, weights.get(), forward); + grad.clear(); + CrossEntropyLossBackwardPass( + prompt, weights.get(), forward, grad.get(), backward); + + Complexify(weights.get(), c_weights.get()); + auto func = [&]() { + return CrossEntropyLossForwardPass(prompt, c_weights.get(), c_forward); + }; + + TestGradient(grad.get(), c_weights.get(), func, 1e-11); + } +} + +template +void MulByConstAndAddT(T c, const Layer& x, + Layer& out) { + MulByConstAndAddT(c, x.pre_attention_norm_scale, + out.pre_attention_norm_scale); + MulByConstAndAddT(c, x.attn_vec_einsum_w, out.attn_vec_einsum_w); + MulByConstAndAddT(c, x.qkv_einsum_w, out.qkv_einsum_w); + MulByConstAndAddT(c, x.pre_ffw_norm_scale, out.pre_ffw_norm_scale); + MulByConstAndAddT(c, x.gating_einsum_w, out.gating_einsum_w); + MulByConstAndAddT(c, x.linear_w, out.linear_w); +} + +template +void MulByConstAndAddT(T c, const Weights& x, + Weights& out) { + static constexpr size_t kLayers = TConfig::kLayers; + MulByConstAndAddT(c, x.embedder_input_embedding, + out.embedder_input_embedding); + MulByConstAndAddT(c, x.final_norm_scale, out.final_norm_scale); + for (size_t i = 0; i < kLayers; ++i) { + MulByConstAndAddT(c, *x.GetLayer(i), *out.GetLayer(i)); + } +} + +// Evaluates forward pass on a batch. +template +T CrossEntropyLossForwardPass(const std::vector& batch, + const WeightsWrapper& weights, + ForwardPass& forward) { + T loss = 0.0; + for (const Prompt& prompt : batch) { + loss += CrossEntropyLossForwardPass(prompt, weights.get(), forward); + } + T scale = 1.0 / batch.size(); + return loss * scale; +} + +// Evaluates forward pass on a batch by applying gradient with the given +// learning rate. Does not update weights, but uses the given tmp weights +// instead. 
+template +T CrossEntropyLossForwardPass(T learning_rate, + const std::vector& batch, + const WeightsWrapper& weights, + const WeightsWrapper& grad, + WeightsWrapper& tmp, + ForwardPass& forward) { + tmp.copy(weights); + const T scale = -learning_rate / batch.size(); + MulByConstAndAddT(scale, grad.get(), tmp.get()); + return CrossEntropyLossForwardPass(batch, tmp, forward); +} + +// Uses line search in the negative gradient direction to update weights. We do +// this so that we can test that each step during the gradient descent can +// decrease the objective function value. +template +T FindOptimalUpdate(const WeightsWrapper& grad, + WeightsWrapper& weights, + WeightsWrapper& tmp, + ForwardPass& forward, + const std::vector& batch, + T loss, T initial_learning_rate) { + T lr0 = initial_learning_rate; + T loss0 = CrossEntropyLossForwardPass( + lr0, batch, weights, grad, tmp, forward); + for (size_t iter = 0; iter < 30; ++iter) { + T lr1 = lr0 * 0.5; + T loss1 = CrossEntropyLossForwardPass( + lr1, batch, weights, grad, tmp, forward); + if (loss0 < loss && loss1 >= loss0) { + break; + } + loss0 = loss1; + lr0 = lr1; + } + for (size_t iter = 0; iter < 30; ++iter) { + T lr1 = lr0 * 2.0; + T loss1 = CrossEntropyLossForwardPass( + lr1, batch, weights, grad, tmp, forward); + if (loss1 >= loss0) { + break; + } + loss0 = loss1; + lr0 = lr1; + } + const T scale = -lr0 / batch.size(); + MulByConstAndAddT(scale, grad.get(), weights.get()); + return lr0; +} + +TEST(BackProptest, Convergence) { + std::mt19937 gen(42); + using T = float; + using TC = std::complex; + WeightsWrapper weights; + WeightsWrapper grad; + WeightsWrapper tmp; + ForwardPass forward; + ForwardPass backward; + WeightsWrapper c_weights; + ForwardPass c_forward; + constexpr size_t kBatchSize = 5; + ReverseSequenceSampler training_task({0, 0, 0, 1, 1}); + T learning_rate = 0.01; + + RandInit(weights.get(), T(1.0), gen); + + printf("Sample batch:\n"); + for (size_t i = 0; i < 10; ++i) { + ReverseSequenceSampler::LogPrompt(training_task.Sample(gen)); + } + + T prev_loss = std::numeric_limits::max(); + bool stop = false; + size_t step = 0; + while (!stop) { + T loss = 0.0; + grad.clear(); + std::mt19937 sgen(42); + std::vector batch = training_task.SampleBatch(kBatchSize, sgen); + for (const Prompt& prompt : batch) { + loss += CrossEntropyLossForwardPass(prompt, weights.get(), forward); + CrossEntropyLossBackwardPass( + prompt, weights.get(), forward, grad.get(), backward); + } + + if (step % 250 == 0) { + printf("Checking gradient...\n"); + Complexify(weights.get(), c_weights.get()); + auto func = [&]() { + TC scale = batch.size(); + return CrossEntropyLossForwardPass(batch, c_weights, c_forward) * scale; + }; + + TestGradient(grad.get(), c_weights.get(), func, 5e-3f); + } + + loss /= batch.size(); + EXPECT_LT(loss, prev_loss); + stop = step >= 10000 || loss < 1e-2; + if (step % 10 == 0 || stop) { + printf("step: %5zu loss: %.15f learning_rate: %.15f\n", + step, loss, learning_rate); + } + if (!stop) { + learning_rate = FindOptimalUpdate( + grad, weights, tmp, forward, batch, loss, learning_rate); + ++step; + } + prev_loss = loss; + } + EXPECT_LT(step, 1000); +} + +} // namespace gcpp diff --git a/gemma/backward_test.cc b/gemma/backward_test.cc new file mode 100644 index 0000000..9b80936 --- /dev/null +++ b/gemma/backward_test.cc @@ -0,0 +1,270 @@ +// Copyright 2023 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance 
with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef HWY_DISABLED_TARGETS +#define HWY_DISABLED_TARGETS HWY_SCALAR +#endif + +#include + +#include +#include +#include +#include +#include + +#include "compression/compress.h" +#include "hwy/aligned_allocator.h" +#include "hwy/base.h" +#include "hwy/contrib/thread_pool/thread_pool.h" +#include "gemma/backward_scalar.h" +#include "gemma/forward_scalar.h" +#include "gemma/gemma.h" +#include "gemma/sampler.h" +#include "gemma/test_util.h" +#include "gemma/weights.h" + +// clang-format off +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "gemma/backward_test.cc" //NOLINT +// clang-format on +#include "hwy/foreach_target.h" // IWYU pragma: keep +#include "hwy/highway.h" +#include "hwy/tests/test_util-inl.h" +// After highway.h +#include "gemma/backward-inl.h" +#include "gemma/forward-inl.h" +#include "gemma/ops.h" + +HWY_BEFORE_NAMESPACE(); +namespace gcpp { +namespace HWY_NAMESPACE { + +namespace hn = hwy::HWY_NAMESPACE; + +void TestMatMulVJP() { + static const size_t kRows = 8; + static const size_t kCols = 64; + static const size_t kTokens = 5; + hwy::ThreadPool pool(8); + std::mt19937 gen(42); + HWY_ALIGN std::array weights; + HWY_ALIGN std::array x; + HWY_ALIGN std::array dy; + HWY_ALIGN std::array grad; + HWY_ALIGN std::array dx; + HWY_ALIGN std::array grad_scalar; + HWY_ALIGN std::array dx_scalar; + using TC = std::complex; + std::array c_weights; + std::array c_x; + std::array c_y; + + for (int iter = 0; iter < 10; ++iter) { + RandInit(weights, 1.0f * (1 << iter), gen); + RandInit(x, 1.0f * (1 << iter), gen); + RandInit(dy, 1.0f, gen); + Complexify(weights, c_weights); + Complexify(x, c_x); + auto func = [&]() { + MatMulT(c_weights.data(), c_x.data(), c_y.data(), kRows, kCols, kTokens); + return DotT(dy.data(), c_y.data(), kTokens * kRows); + }; + + memset(&grad, 0, sizeof(grad)); + MatMulVJP(weights, x.data(), dy.data(), kTokens, + grad, dx.data(), pool); + TestGradient(dx, c_x, func, 5e-5, 5e-5, __LINE__); + TestGradient(grad, c_weights, func, 5e-5, 5e-5, __LINE__); + + memset(&grad_scalar, 0, sizeof(grad_scalar)); + MatMulVJPT(weights.data(), x.data(), dy.data(), grad_scalar.data(), + dx_scalar.data(), kRows, kCols, kTokens); + TestNear(dx, dx_scalar, 0, 0, __LINE__); + TestNear(grad, grad_scalar, 0, 0, __LINE__); + } +} + +void TestMultiHeadMatMulVJP() { + static const size_t kRows = 2; + static const size_t kCols = 16; + static const size_t kHeads = 4; + static const size_t kTokens = 3; + hwy::ThreadPool pool(8); + std::mt19937 gen(42); + HWY_ALIGN std::array weights; + HWY_ALIGN std::array x; + HWY_ALIGN std::array grad; + HWY_ALIGN std::array dx; + HWY_ALIGN std::array dy; + HWY_ALIGN std::array grad_scalar; + HWY_ALIGN std::array dx_scalar; + using TC = std::complex; + std::array c_weights; + std::array c_x; + std::array c_y; + + for (int iter = 0; iter < 10; ++iter) { + RandInit(weights, 1.0f * (1 << iter), gen); + RandInit(x, 1.0f * (1 << iter), gen); + RandInit(dy, 1.0f, gen); + Complexify(weights, c_weights); + Complexify(x, c_x); + auto func = [&]() { + MultiHeadMatMul(c_weights.data(), c_x.data(), c_y.data(), kHeads, kRows, + 
kCols, kTokens); + return DotT(dy.data(), c_y.data(), kTokens * kRows); + }; + + memset(&grad, 0, sizeof(grad)); + MultiHeadMatMulVJP( + weights, x.data(), dy.data(), kTokens, grad, dx.data(), pool); + TestGradient(dx, c_x, func, 5e-5, 5e-5, __LINE__); + TestGradient(grad, c_weights, func, 5e-5, 5e-5, __LINE__); + + memset(&grad_scalar, 0, sizeof(grad_scalar)); + MultiHeadMatMulVJPT(weights.data(), x.data(), dy.data(), grad_scalar.data(), + dx_scalar.data(), kHeads, kRows, kCols, kTokens); + TestNear(dx, dx_scalar, 0, 0, __LINE__); + TestNear(grad, grad_scalar, 0, 0, __LINE__); + } +} + +void TestRMSNormVJP() { + static const size_t K = 2; + static const size_t N = 64; + hwy::ThreadPool pool(8); + std::mt19937 gen(42); + HWY_ALIGN std::array weights; + HWY_ALIGN std::array x; + HWY_ALIGN std::array grad; + HWY_ALIGN std::array dx; + HWY_ALIGN std::array dy; + HWY_ALIGN std::array grad_scalar; + HWY_ALIGN std::array dx_scalar; + using TC = std::complex; + std::array c_weights; + std::array c_x; + std::array c_y; + + for (int iter = 0; iter < 10; ++iter) { + RandInit(weights, 1.0f * (1 << iter), gen); + RandInit(x, 1.0f * (1 << iter), gen); + RandInit(dy, 1.0f, gen); + Complexify(weights, c_weights); + Complexify(x, c_x); + auto func = [&]() { + RMSNormT(c_weights.data(), c_x.data(), c_y.data(), N, K); + return DotT(dy.data(), c_y.data(), K * N); + }; + + memset(&grad, 0, sizeof(grad)); + RMSNormVJP(weights.data(), x.data(), dy.data(), N, K, grad.data(), + dx.data(), pool); + TestGradient(dx, c_x, func, 5e-5, 5e-5, __LINE__); + TestGradient(grad, c_weights, func, 5e-5, 5e-5, __LINE__); + + memset(&grad_scalar, 0, sizeof(grad_scalar)); + RMSNormVJPT(weights.data(), x.data(), dy.data(), grad_scalar.data(), + dx_scalar.data(), N, K); + TestNear(dx, dx_scalar, 0, 2e-5, __LINE__); + TestNear(grad, grad_scalar, 0, 2e-5, __LINE__); + } +} + +struct TestConfig { + static constexpr int kSeqLen = 24; + static constexpr int kVocabSize = 16; + static constexpr int kModelDim = 32; + static constexpr int kHeads = 3; + static constexpr int kQKVDim = 16; + static constexpr int kFFHiddenDim = 64; + static constexpr std::array kLayerConfig = + FixedLayerConfig<2>(LayerAttentionType::kGemma); + static constexpr int kLayers = kLayerConfig.size(); + static constexpr bool kAbsolutePE = false; + static constexpr bool kPostNormScale = false; + + static constexpr int kKVHeads = 1; + static constexpr int kConv1dWidth = 0; + static constexpr bool kFFBiases = false; + static constexpr bool kSoftmaxAttnOutputBiases = false; + static constexpr int kGemmaLayers = kLayers; + static constexpr int kGriffinLayers = 0; + static constexpr int kNumTensorScales = 0; +}; + +void TestEndToEnd() { + std::mt19937 gen(42); + hwy::ThreadPool pool(0); + WeightsWrapper weights; + WeightsWrapper grad; + ActivationsWrapper forward0; + ActivationsWrapper forward1; + ActivationsWrapper backward; + using TC = std::complex; + WeightsWrapper c_weights; + ForwardPass c_forward; + + ReverseSequenceSampler training_task({0, 0, 1, 1}); + std::vector batch = training_task.SampleBatch(10, gen); + + for (const Prompt& prompt : batch) { + ReverseSequenceSampler::LogPrompt(prompt); + RandInit(weights.get(), 1.0f, gen); + + float loss0 = CrossEntropyLossForwardPass( + prompt, weights.get(), forward0.get()); + + float loss1 = CrossEntropyLossForwardPass( + prompt.tokens, prompt.context_size, weights.get(), forward1.get(), + pool); + + EXPECT_NEAR(loss1, loss0, std::abs(loss0) * 1e-5); + + grad.clear(); + CrossEntropyLossBackwardPass( + prompt, 
weights.get(), forward1.get(), grad.get(), backward.get(), + pool); + + Complexify(weights.get(), c_weights.get()); + auto func = [&]() { + return CrossEntropyLossForwardPass(prompt, c_weights.get(), c_forward); + }; + + TestGradient(grad.get(), c_weights.get(), func, 2e-3f); + } +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace gcpp +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE + +namespace gcpp { +HWY_BEFORE_TEST(BackwardTest); +HWY_EXPORT_AND_TEST_P(BackwardTest, TestMatMulVJP); +HWY_EXPORT_AND_TEST_P(BackwardTest, TestMultiHeadMatMulVJP); +HWY_EXPORT_AND_TEST_P(BackwardTest, TestRMSNormVJP); +HWY_EXPORT_AND_TEST_P(BackwardTest, TestEndToEnd); +#ifdef HWY_AFTER_TEST +HWY_AFTER_TEST(); +#endif + +} // namespace gcpp + +#endif diff --git a/gemma/common-inl.h b/gemma/common-inl.h new file mode 100644 index 0000000..c7f53bc --- /dev/null +++ b/gemma/common-inl.h @@ -0,0 +1,71 @@ +// Copyright 2024 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Include guard for non-SIMD code. +#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_INL_H_ +#define THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_INL_H_ + +#include +#include + +#include +#include + +#include "gemma/activations.h" + +#include "hwy/base.h" +#include "hwy/contrib/thread_pool/thread_pool.h" + +namespace gcpp { +} // namespace gcpp + +#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_INL_H_ + +// Include guard for (potentially) SIMD code. +#if defined(THIRD_PARTY_GEMMA_CPP_COMMON_TOGGLE) == defined(HWY_TARGET_TOGGLE) +#ifdef THIRD_PARTY_GEMMA_CPP_COMMON_TOGGLE +#undef THIRD_PARTY_GEMMA_CPP_COMMON_TOGGLE +#else +#define THIRD_PARTY_GEMMA_CPP_COMMON_TOGGLE +#endif + +#include "gemma/ops.h" + +HWY_BEFORE_NAMESPACE(); +namespace gcpp { +namespace HWY_NAMESPACE { +namespace hn = hwy::HWY_NAMESPACE; + +// `EmbeddingScaling` can be constexpr only if `Sqrt` and `hwy::ConvertScalarTo` +// are both constexpr +#if HWY_COMPILER_GCC_ACTUAL +#define GEMMA_CONSTEXPR_EMBSCALING HWY_BF16_CONSTEXPR +#else +#define GEMMA_CONSTEXPR_EMBSCALING +#endif + +template +GEMMA_CONSTEXPR_EMBSCALING float EmbeddingScaling() { + // Round to bf16 to match Gemma's Embedder, which casts before mul. + return hwy::ConvertScalarTo(hwy::ConvertScalarTo( + Sqrt(static_cast(TConfig::kModelDim)))); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace gcpp +HWY_AFTER_NAMESPACE(); + +#endif // NOLINT diff --git a/gemma/common.h b/gemma/common.h new file mode 100644 index 0000000..d497259 --- /dev/null +++ b/gemma/common.h @@ -0,0 +1,30 @@ +// Copyright 2024 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_H_ +#define THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_H_ + +#include "hwy/aligned_allocator.h" + +namespace gcpp { + +using ByteStorageT = hwy::AlignedFreeUniquePtr; + +// Model variants: see configs.h for details. +enum class Model { GEMMA_2B, GEMMA_7B, GRIFFIN_2B, GEMMA_TINY }; + +} // namespace gcpp + +#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_H_ diff --git a/gemma/common_scalar.cc b/gemma/common_scalar.cc new file mode 100644 index 0000000..9a82c1d --- /dev/null +++ b/gemma/common_scalar.cc @@ -0,0 +1,50 @@ +// Copyright 2024 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Compiles this file for multiple architectures via "foreach_target.h", to +// which we pass the filename via macro 'argument'. +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "gemma/common_scalar.cc" // NOLINT +#include "hwy/foreach_target.h" // IWYU pragma: keep + +#include "gemma/ops.h" +#include "hwy/highway.h" + +HWY_BEFORE_NAMESPACE(); +namespace gcpp { +namespace HWY_NAMESPACE { +namespace hn = hwy::HWY_NAMESPACE; + +float EmbeddingScaling(int model_dim) { + // Round to bf16 to match Gemma's Embedder, which casts before mul. + return hwy::ConvertScalarTo(hwy::ConvertScalarTo( + Sqrt(static_cast(model_dim)))); +} + +} // namespace HWY_NAMESPACE +} // namespace gcpp +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace gcpp { + +HWY_EXPORT(EmbeddingScaling); + +float EmbeddingScaling(int model_dim) { + return HWY_DYNAMIC_DISPATCH(EmbeddingScaling)(model_dim); +} + +} // namespace gcpp +#endif // HWY_ONCE diff --git a/gemma/common_scalar.h b/gemma/common_scalar.h new file mode 100644 index 0000000..628962b --- /dev/null +++ b/gemma/common_scalar.h @@ -0,0 +1,119 @@ +// Copyright 2024 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
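+// Scalar reference helpers (dot products, axpy-style updates, Gelu, RoPE)
+// shared by the scalar forward and backward passes. They are templated on the
+// element type so the same code also runs on std::complex<T>, which is how the
+// tests check analytic gradients: Complexify() plus TestGradient(), effectively
+// a complex-step derivative check, df/dw ~= Im f(w + i*h) / h.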
+ +#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_SCALAR_H_ +#define THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_SCALAR_H_ + +#include + +namespace gcpp { + +template +U DotT(const T* a, const U* b, size_t N) { + U sum = {}; + for (size_t i = 0; i < N; ++i) { + sum += a[i] * b[i]; + } + return sum; +} + +template<> +std::complex DotT(const float* a, const std::complex* b, + size_t N) { + std::complex sum = {}; + for (size_t i = 0; i < N; ++i) { + sum += static_cast(a[i]) * b[i]; + } + return sum; +} + +template +void MulByConstT(T c, T* x, size_t N) { + for (size_t i = 0; i < N; ++i) { + x[i] *= c; + } +} + +// out += c * x +template +void MulByConstAndAddT(T c, const T* x, T* out, size_t N) { + for (size_t i = 0; i < N; ++i) { + out[i] += c * x[i]; + } +} + +template +void MulByConstAndAddT(T c, const std::array& x, std::array& out) { + MulByConstAndAddT(c, x.data(), out.data(), N); +} + +template +void AddFromT(const T* a, T* out, size_t N) { + for (size_t i = 0; i < N; ++i) { + out[i] += a[i]; + } +} + +template +T SquaredL2(const T* x, size_t N) { + T sum = {}; + for (size_t i = 0; i < N; ++i) { + sum += x[i] * x[i]; + } + return sum; +} + +template +T Gelu(T x) { + static const T kMul = 0.044715; + static const T kSqrt2OverPi = 0.797884560804236; + + const T x3 = x * x * x; + const T arg = kSqrt2OverPi * (kMul * x3 + x); + const T cdf = T(0.5) * (T(1.0) + std::tanh(arg)); + return x * cdf; +} + +template +void Rope(T* x, U base, size_t N, int i) { + const size_t N2 = N / 2; + for (size_t dim = 0; dim < N2; ++dim) { + const T freq_exponents = T(2 * dim) / T(N); + const T timescale = std::pow(base, freq_exponents); + const T theta = T(i) / timescale; + const T cos_val = std::cos(theta); + const T sin_val = std::sin(theta); + const T x0 = x[dim]; + const T x1 = x[dim + N2]; + x[dim] = x0 * cos_val - x1 * sin_val; + x[dim + N2] = x0 * sin_val + x1 * cos_val; + } +} + +template +void Rope(T* x, size_t N, int i) { + Rope(x, T(10000.0), N, i); +} + +template +void Rope(std::complex* x, size_t N, int i) { + Rope(x, T(10000.0), N, i); +} + +float EmbeddingScaling(int model_dim); + +} // namespace gcpp + +#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_SCALAR_H_ diff --git a/gemma/configs.h b/gemma/configs.h index 10de09e..7fc6477 100644 --- a/gemma/configs.h +++ b/gemma/configs.h @@ -142,6 +142,38 @@ struct ConfigGemma2B { using WeightT = GEMMA_WEIGHT_T; }; +struct ConfigGemmaTiny { + static constexpr int kSeqLen = 32; + static constexpr int kVocabSize = 16; + static constexpr std::array kLayerConfig = + FixedLayerConfig<3>(LayerAttentionType::kGemma); + static constexpr int kLayers = kLayerConfig.size(); + static constexpr int kGemmaLayers = + NumLayersOfTypeBefore(kLayerConfig, LayerAttentionType::kGemma, kLayers); + static constexpr int kGriffinLayers = + NumLayersOfTypeBefore(kLayerConfig, + LayerAttentionType::kGriffinRecurrentBlock, + kLayers); + static constexpr int kModelDim = 64; + static constexpr int kFFHiddenDim = 128; + static constexpr int kHeads = 4; + static constexpr int kKVHeads = 1; + static constexpr int kQKVDim = 16; // query size == key size == value size + static constexpr int kTopK = gcpp::kTopK; + static constexpr bool kAbsolutePE = false; + static constexpr bool kPostNormScale = false; + + // SSM config. 
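+  // These SSM/Griffin fields are unused by ConfigGemmaTiny (all of its layers
+  // are kGemma), but they are kept so the struct exposes the same trait
+  // interface that the templated weights and forward/backward code expect.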
+ static constexpr int kConv1dWidth = 0; + static constexpr bool kFFBiases = false; + static constexpr bool kSoftmaxAttnOutputBiases = false; + static constexpr bool kUseHalfRope = false; + static constexpr bool kUseLocalAttention = false; + static constexpr bool kInterleaveQKV = true; + static constexpr int kNumTensorScales = 0; + using WeightT = GEMMA_WEIGHT_T; +}; + struct ConfigGriffin2B { // Griffin uses local attention, so kSeqLen is actually the local attention // window. diff --git a/gemma/forward-inl.h b/gemma/forward-inl.h new file mode 100644 index 0000000..19f603f --- /dev/null +++ b/gemma/forward-inl.h @@ -0,0 +1,290 @@ +// Copyright 2024 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Include guard for non-SIMD code. +#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_FORWARD_INL_H_ +#define THIRD_PARTY_GEMMA_CPP_GEMMA_FORWARD_INL_H_ + +#include +#include + +#include +#include + +#include "gemma/activations.h" +#include "gemma/configs.h" + +#include "hwy/base.h" +#include "hwy/contrib/thread_pool/thread_pool.h" + +#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_FORWARD_INL_H_ + +// Include guard for (potentially) SIMD code. +#if defined(THIRD_PARTY_GEMMA_CPP_FORWARD_TOGGLE) == defined(HWY_TARGET_TOGGLE) +#ifdef THIRD_PARTY_GEMMA_CPP_FORWARD_TOGGLE +#undef THIRD_PARTY_GEMMA_CPP_FORWARD_TOGGLE +#else +#define THIRD_PARTY_GEMMA_CPP_FORWARD_TOGGLE +#endif + +#include "gemma/common-inl.h" +#include "gemma/ops.h" + +HWY_BEFORE_NAMESPACE(); +namespace gcpp { +namespace HWY_NAMESPACE { +namespace hn = hwy::HWY_NAMESPACE; + +template +void InputEmbedding(const ArrayT& weights, const std::vector& prompt, + const float scaling, float* HWY_RESTRICT output, + size_t model_dim) { + for (size_t pos = 0; pos + 1 < prompt.size(); ++pos) { + int token = prompt[pos]; + Decompress(weights, token * model_dim, output + pos * model_dim, model_dim); + MulByConst(scaling, output + pos * model_dim, model_dim); + } +} + +template +void ApplyRMSNorm(const WT* HWY_RESTRICT weights, const XT* HWY_RESTRICT x, + size_t model_dim, size_t num_tokens, + OutT* HWY_RESTRICT output, + hwy::ThreadPool& pool) { + for (size_t pos = 0; pos < num_tokens; ++pos) { + const size_t offset = pos * model_dim; + RMSNorm(x + offset, weights, output + offset, model_dim); + } +} + +static HWY_NOINLINE float CrossEntropyLoss(const float* HWY_RESTRICT probs, + const std::vector& prompt, + size_t context_size, + size_t vocab_size, + hwy::ThreadPool& pool) { + float loss = 0.0f; + for (size_t pos = 0; pos + 1 < prompt.size(); ++pos) { + if (pos + 1 < context_size) { + continue; // next token is part of context, don't try to predict it + } + const int next_token = prompt[pos + 1]; + loss += std::log(probs[pos * vocab_size + next_token]); + } + float scaling = -1.0 / std::log(2.0); + return loss * scaling; +} + +template typename LayerT> +void ApplyForwardLayer(const LayerT& weights, + ForwardLayer& activations, + size_t num_tokens, + float* HWY_RESTRICT output, + hwy::ThreadPool& pool) { + 
static constexpr size_t kModelDim = TConfig::kModelDim; + static constexpr size_t kSeqLen = TConfig::kSeqLen; + static constexpr size_t kQKVDim = TConfig::kQKVDim; + static constexpr size_t kHeads = TConfig::kHeads; + static const float kQueryScale = + static_cast(1.0 / sqrt(static_cast(kQKVDim))); + HWY_ASSERT(num_tokens <= kSeqLen); + + ApplyRMSNorm(weights.pre_attention_norm_scale.data(), + activations.input.data(), kModelDim, num_tokens, + activations.pre_att_rms_out.data(), pool); + + for (size_t pos = 0; pos < num_tokens; ++pos) { + MatVec<(kHeads + 2) * kQKVDim, kModelDim>( + weights.qkv_einsum_w, 0, + activations.pre_att_rms_out.data() + pos * kModelDim, nullptr, + activations.qkv.data() + pos * (kHeads + 2) * kQKVDim, pool); + } + const size_t num_tasks = kHeads * num_tokens; + + for (size_t pos = 0; pos < num_tokens; ++pos) { + float* HWY_RESTRICT k = + activations.qkv.data() + (pos * (kHeads + 2) + kHeads) * kQKVDim; + Rope(k, kQKVDim, pos); + } + pool.Run(0, num_tasks, [&](const uint64_t task, size_t thread) HWY_ATTR { + const size_t head = task % kHeads; + const size_t pos = task / kHeads; + float* HWY_RESTRICT q = + activations.qkv.data() + (pos * (kHeads + 2) + head) * kQKVDim; + Rope(q, kQKVDim, pos); + MulByConst(kQueryScale, q, kQKVDim); + }); + + pool.Run(0, num_tasks, [&](const uint64_t task, size_t thread) HWY_ATTR { + const size_t head = task % kHeads; + const size_t pos = task / kHeads; + const float* HWY_RESTRICT q = + activations.qkv.data() + (pos * (kHeads + 2) + head) * kQKVDim; + float* HWY_RESTRICT head_att = + activations.att.data() + (pos * kHeads + head) * kSeqLen; + for (size_t pos2 = 0; pos2 <= pos; ++pos2) { + const float* HWY_RESTRICT k2 = + activations.qkv.data() + (pos2 * (kHeads + 2) + kHeads) * kQKVDim; + const float score = Dot(q, k2, kQKVDim); + head_att[pos2] = score; + } + }); + + pool.Run(0, num_tasks, [&](const uint64_t task, size_t thread) HWY_ATTR { + const size_t head = task % kHeads; + const size_t pos = task / kHeads; + float* HWY_RESTRICT head_att = + activations.att.data() + (pos * kHeads + head) * kSeqLen; + Softmax(head_att, pos + 1); + }); + + pool.Run(0, num_tasks, [&](const uint64_t task, size_t thread) HWY_ATTR { + const size_t head = task % kHeads; + const size_t pos = task / kHeads; + const float* HWY_RESTRICT head_att = + activations.att.data() + (pos * kHeads + head) * kSeqLen; + float* HWY_RESTRICT att_out = + activations.att_out.data() + (pos * kHeads + head) * kQKVDim; + hwy::ZeroBytes(att_out, kQKVDim * sizeof(*att_out)); + for (size_t pos2 = 0; pos2 <= pos; ++pos2) { + float* HWY_RESTRICT v2 = + activations.qkv.data() + (pos2 * (kHeads + 2) + kHeads + 1) * kQKVDim; + MulByConstAndAdd(head_att[pos2], v2, att_out, kQKVDim); + } + }); + + hwy::ZeroBytes(activations.attention_out.data(), + num_tokens * kModelDim * sizeof(activations.attention_out[0])); + for (size_t pos = 0; pos < num_tokens; ++pos) { + for (size_t head = 0; head < kHeads; ++head) { + MatVec( + weights.attn_vec_einsum_w, head * kModelDim * kQKVDim, + activations.att_out.data() + pos * kHeads * kQKVDim + head * kQKVDim, + nullptr, activations.att_post1.data() + pos * kModelDim, pool); + AddFrom(activations.att_post1.data() + pos * kModelDim, + activations.attention_out.data() + pos * kModelDim, kModelDim); + } + } + + for (size_t pos = 0; pos < num_tokens; ++pos) { + AddFrom(activations.input.data() + pos * kModelDim, + activations.attention_out.data() + pos * kModelDim, kModelDim); + } + + ApplyRMSNorm(weights.pre_ffw_norm_scale.data(), + 
activations.attention_out.data(), kModelDim, num_tokens, + activations.bf_pre_ffw_rms_out.data(), pool); + static constexpr size_t kFFHiddenDim = TConfig::kFFHiddenDim; + for (size_t pos = 0; pos < num_tokens; ++pos) { + MatVec( + weights.gating_einsum_w, 0, + activations.bf_pre_ffw_rms_out.data() + pos * kModelDim, nullptr, + activations.ffw_hidden.data() + pos * kFFHiddenDim * 2, pool); + } + for (size_t pos = 0; pos < num_tokens; ++pos) { + const size_t hidden_offset = pos * kFFHiddenDim * 2; + const float* HWY_RESTRICT out = + activations.ffw_hidden.data() + hidden_offset; + const float* HWY_RESTRICT out_mul = out + kFFHiddenDim; + float* HWY_RESTRICT out_gated = + activations.ffw_hidden_gated.data() + pos * kFFHiddenDim; + namespace hn = hwy::HWY_NAMESPACE; + using DF = hn::ScalableTag; + using VF = hn::Vec; + DF df; + for (size_t i = 0; i < kFFHiddenDim; i += Lanes(df)) { + const auto y = Load(df, out + i); + const auto x = Load(df, out_mul + i); + hn::Store(hn::Mul(x, Gelu(df, y)), df, out_gated + i); + } + } + for (size_t pos = 0; pos < num_tokens; ++pos) { + MatVec( + weights.linear_w, 0, + activations.ffw_hidden_gated.data() + pos * kFFHiddenDim, + nullptr, output + pos * kModelDim, pool); + } + for (size_t pos = 0; pos < num_tokens; ++pos) { + AddFrom(activations.attention_out.data() + pos * kModelDim, + output + pos * kModelDim, kModelDim); + } +} + +template typename WeightsT, + template typename LayerT> +float CrossEntropyLossForwardPass(const std::vector& prompt, + size_t context_size, + const WeightsT& weights, + ForwardPass& forward, + hwy::ThreadPool& pool) { + static constexpr size_t kVocabSize = TConfig::kVocabSize; + static constexpr size_t kModelDim = TConfig::kModelDim; + static constexpr size_t kSeqLen = TConfig::kSeqLen; + static constexpr size_t kLayers = TConfig::kLayers; + const float kEmbScaling = EmbeddingScaling(); + static_assert(!TConfig::kAbsolutePE); + static_assert(!TConfig::kPostNormScale); + static_assert(TConfig::kKVHeads == 1); + + HWY_DASSERT(context_size > 0); + HWY_DASSERT(context_size < prompt.size()); + const size_t num_tokens = prompt.size() - 1; + + InputEmbedding(weights.embedder_input_embedding, prompt, kEmbScaling, + forward.layers[0].input.data(), kModelDim); + + for (size_t layer = 0; layer < kLayers; ++layer) { + auto type = TConfig::kLayerConfig[layer]; + // TODO(szabadka) Implement Griffin layer. + HWY_ASSERT(type == LayerAttentionType::kGemma); + float* HWY_RESTRICT output = layer + 1 < kLayers ? 
+ forward.layers[layer + 1].input.data() : + forward.final_layer_output.data(); + ApplyForwardLayer( + *weights.GetLayer(layer), forward.layers[layer], + num_tokens, output, pool); + } + + ApplyRMSNorm(weights.final_norm_scale.data(), + forward.final_layer_output.data(), + kModelDim, num_tokens, forward.final_norm_output.data(), pool); + + for (size_t pos = 0; pos < num_tokens; ++pos) { + MatVec( + weights.embedder_input_embedding, 0, + forward.final_norm_output.data() + pos * kModelDim, nullptr, + forward.logits.data() + pos * kVocabSize, pool); + } + + for (size_t pos = 0; pos < num_tokens; ++pos) { + LogitsSoftCap(30.0f, forward.logits.data() + pos * kVocabSize, kVocabSize); + } + + memcpy(forward.probs.data(), forward.logits.data(), + num_tokens * kVocabSize * sizeof(forward.logits[0])); + + for (size_t pos = 0; pos < num_tokens; ++pos) { + Softmax(forward.probs.data() + pos * kVocabSize, kVocabSize); + } + + return CrossEntropyLoss(forward.probs.data(), prompt, context_size, + kVocabSize, pool); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace gcpp +HWY_AFTER_NAMESPACE(); + +#endif // NOLINT diff --git a/gemma/forward.cc b/gemma/forward.cc new file mode 100644 index 0000000..1ebabfd --- /dev/null +++ b/gemma/forward.cc @@ -0,0 +1,79 @@ +// Copyright 2024 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "gemma/forward.h" + +// Compiles this file for multiple architectures via "foreach_target.h", to +// which we pass the filename via macro 'argument'. 
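+// The pattern below is the usual Highway multi-target setup: foreach_target.h
+// re-includes this file once per enabled SIMD target, HWY_EXPORT() in the
+// HWY_ONCE section at the end records each compiled variant, and
+// HWY_DYNAMIC_DISPATCH() selects the best available one at runtime.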
+#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "gemma/forward.cc" // NOLINT +#include "hwy/foreach_target.h" // IWYU pragma: keep + +#include "gemma/forward-inl.h" +#include "gemma/weights.h" +#include "hwy/highway.h" + +HWY_BEFORE_NAMESPACE(); +namespace gcpp { +namespace HWY_NAMESPACE { +namespace hn = hwy::HWY_NAMESPACE; + +template +float CrossEntropyLossForwardPass(const Prompt& prompt, + const ByteStorageT& weights_u8, + ByteStorageT& forward_u8, + hwy::ThreadPool& pool) { + const auto& weights = + *reinterpret_cast*>(weights_u8.get()); + auto& forward = + *reinterpret_cast*>(forward_u8.get()); + return CrossEntropyLossForwardPass( + prompt.tokens, prompt.context_size, weights, forward, pool); +} + +float CrossEntropyLossForwardPassT(Model model, const Prompt& prompt, + const ByteStorageT& weights, + ByteStorageT& forward, + hwy::ThreadPool& pool) { + switch (model) { + case Model::GEMMA_2B: + return CrossEntropyLossForwardPass( + prompt, weights, forward, pool); + case Model::GEMMA_TINY: + return CrossEntropyLossForwardPass( + prompt, weights, forward, pool); + default: + HWY_ABORT("Model type %d unknown.", static_cast(model)); + } +} + +} // namespace HWY_NAMESPACE +} // namespace gcpp +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace gcpp { + +HWY_EXPORT(CrossEntropyLossForwardPassT); + +float CrossEntropyLossForwardPass( + const Model& model, const Prompt& prompt, const ByteStorageT& weights, + ByteStorageT& forward, hwy::ThreadPool& pool) { + return HWY_DYNAMIC_DISPATCH(CrossEntropyLossForwardPassT)( + model, prompt, weights, forward, pool); +} + +} // namespace gcpp +#endif // HWY_ONCE diff --git a/gemma/forward.h b/gemma/forward.h new file mode 100644 index 0000000..d8769f2 --- /dev/null +++ b/gemma/forward.h @@ -0,0 +1,33 @@ +// Copyright 2024 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_FORWARD_H_ +#define THIRD_PARTY_GEMMA_CPP_GEMMA_FORWARD_H_ + +#include + +#include "gemma/common.h" +#include "gemma/prompt.h" +#include "hwy/contrib/thread_pool/thread_pool.h" + +namespace gcpp { + +float CrossEntropyLossForwardPass( + const Model& model, const Prompt& prompt, const ByteStorageT& weights, + ByteStorageT& forward, hwy::ThreadPool& pool); + +} // namespace gcpp + +#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_FORWARD_H_ diff --git a/gemma/forward_scalar.h b/gemma/forward_scalar.h new file mode 100644 index 0000000..0a62107 --- /dev/null +++ b/gemma/forward_scalar.h @@ -0,0 +1,284 @@ +// Copyright 2024 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_FORWARD_SCALAR_H_ +#define THIRD_PARTY_GEMMA_CPP_GEMMA_FORWARD_SCALAR_H_ + +#include +#include + +#include +#include +#include + +#include "gemma/activations.h" +#include "gemma/common_scalar.h" +#include "gemma/prompt.h" +#include "gemma/weights.h" + +namespace gcpp { + +// w is N x M matrix in row-major order, x is M x K matrix in column-major order +// y = w * x is N x K matrix in column-major order. +template +void MatMulT(const T* w, const T* x, T* y, size_t N, size_t M, size_t K) { + for (size_t i = 0; i < K; ++i) { + for (size_t j = 0; j < N; ++j) { + y[i * N + j] = DotT(&w[j * M], &x[i * M], M); + } + } +} + +// w is H concatenated N x M matrix in row-major order, x is HM x K matrix in +// column-major order and y = w' * x is N x K matrix in column-major order, +// where w' is the rearrangement of w into an N x HM matrix. +template +void MultiHeadMatMul(const T* w, const T* x, T* y, size_t H, size_t N, + size_t M, size_t K) { + memset(y, 0, N * K * sizeof(y[0])); + for (size_t i = 0; i < K; ++i) { + for (size_t h = 0; h < H; ++h) { + for (size_t j = 0; j < N; ++j) { + y[i * N + j] += DotT(&w[h * N * M + j * M], &x[i * H * M + h * M], M); + } + } + } +} + +template +void RMSNormT(const T* w, const T* x, T* out, size_t N, size_t K) { + constexpr T eps(1e-6); + for (size_t i = 0; i < K; ++i) { + T ss = SquaredL2(x + i * N, N); + ss = T(1.0) / std::sqrt(ss / T(N) + eps); + for (size_t j = 0; j < N; j++) { + out[i * N + j] = (T(1.0) + w[j]) * (ss * x[i * N + j]); + } + } +} +template +void Softmax(T* x, size_t N) { + T sum = {}; + auto maxreal = std::real(x[0]); + for (size_t i = 1; i < N; ++i) { + if (std::real(x[i]) > maxreal) { + maxreal = std::real(x[i]); + } + } + for (size_t i = 0; i < N; ++i) { + x[i] = std::exp(x[i] - maxreal); + sum += x[i]; + } + T scale = T(1.0) / sum; + for (size_t i = 0; i < N; ++i) { + x[i] *= scale; + } +} +template +void Softmax(T* x, size_t N, size_t K) { + for (size_t i = 0; i < K; ++i) { + Softmax(x + i * N, N); + } +} +template +void Softcap(T* x, size_t N) { + T cap = 30.0; + T inv_cap = T(1.0) / cap; + for (size_t i = 0; i < N; ++i) { + x[i] = cap * std::tanh(x[i] * inv_cap); + } +} + +template +void GatedGelu(const T* in, T* out, size_t N, size_t K) { + for (size_t i = 0; i < K; ++i) { + const T* x1 = in + i * 2 * N; + const T* x2 = x1 + N; + T* y = out + i * N; + for (size_t j = 0; j < N; ++j) { + y[j] = x2[j] * Gelu(x1[j]); + } + } +} + +template +void InputEmbedding(const T* w, const std::vector& tokens, T scaling, + T* y, size_t N) { + for (size_t i = 0; i + 1 < tokens.size(); ++i) { + int token = tokens[i]; + memcpy(y + i * N, w + token * N, N * sizeof(y[0])); + MulByConstT(scaling, y + i * N, N); + } +} + +template +void MaskedAttention(const T* qkv, T* output, size_t num_tokens, + size_t kHeads, size_t kQKVDim, size_t kSeqLen) { + for (size_t pos = 0; pos < num_tokens; ++pos) { + for (size_t head = 0; head < kHeads; ++head) { + const size_t qoffset = pos * (kHeads + 2) * kQKVDim; + const size_t aoffset = pos * kHeads * kSeqLen + head * kSeqLen; + const T* q = qkv + qoffset + head * 
kQKVDim; + for (size_t pos2 = 0; pos2 <= pos; ++pos2) { + const T* k = qkv + (pos2 * (kHeads + 2) + kHeads) * kQKVDim; + output[aoffset + pos2] = DotT(q, k, kQKVDim); + } + } + } +} +template +void MaskedSoftmax(T* x, size_t num_tokens, size_t kHeads, size_t kSeqLen) { + for (size_t pos = 0; pos < num_tokens; ++pos) { + for (size_t head = 0; head < kHeads; ++head) { + size_t offset = pos * kHeads * kSeqLen + head * kSeqLen; + Softmax(x + offset, pos + 1); + memset(x + offset + pos + 1, 0, (kSeqLen - pos - 1) * sizeof(T)); + } + } +} +template +void MixByAttention(const T* qkv, const T* attention, T* output, + size_t num_tokens, size_t kHeads, size_t kQKVDim, + size_t kSeqLen) { + for (size_t pos = 0; pos < num_tokens; ++pos) { + for (size_t head = 0; head < kHeads; ++head) { + const T* att = &attention[pos * kHeads * kSeqLen + head * kSeqLen]; + T* out = &output[head * kQKVDim + pos * kHeads * kQKVDim]; + memset(out, 0, kQKVDim * sizeof(out[0])); + for (size_t pos2 = 0; pos2 <= pos; ++pos2) { + size_t v_offset = (pos2 * (kHeads + 2) + kHeads + 1) * kQKVDim; + const T* v = &qkv[v_offset]; + MulByConstAndAddT(att[pos2], v, out, kQKVDim); + } + } + } +} +template +void ApplyLayer(const Layer& weights, + ForwardLayer& activations, + size_t num_tokens, T* output) { + static constexpr size_t kModelDim = TConfig::kModelDim; + static constexpr size_t kSeqLen = TConfig::kSeqLen; + static constexpr size_t kQKVDim = TConfig::kQKVDim; + static constexpr size_t kHeads = TConfig::kHeads; + static constexpr size_t kFFHiddenDim = TConfig::kFFHiddenDim; + static const T kQueryScale = T(1.0) / std::sqrt(T(kQKVDim)); + + RMSNormT(weights.pre_attention_norm_scale.data(), activations.input.data(), + activations.pre_att_rms_out.data(), kModelDim, num_tokens); + + MatMulT(weights.qkv_einsum_w.data(), activations.pre_att_rms_out.data(), + activations.qkv.data(), (kHeads + 2) * kQKVDim, kModelDim, + num_tokens); + + for (size_t pos = 0; pos < num_tokens; ++pos) { + T* qkv = activations.qkv.data() + pos * (kHeads + 2) * kQKVDim; + for (size_t h = 0; h <= kHeads; ++h) { + Rope(qkv + h * kQKVDim, kQKVDim, pos); + } + } + + for (size_t pos = 0; pos < num_tokens; ++pos) { + T* qkv = activations.qkv.data() + pos * (kHeads + 2) * kQKVDim; + MulByConstT(kQueryScale, qkv, kHeads * kQKVDim); + } + + MaskedAttention(activations.qkv.data(), activations.att.data(), + num_tokens, kHeads, kQKVDim, kSeqLen); + + MaskedSoftmax(activations.att.data(), num_tokens, kHeads, kSeqLen); + + MixByAttention(activations.qkv.data(), activations.att.data(), + activations.att_out.data(), num_tokens, kHeads, kQKVDim, + kSeqLen); + + MultiHeadMatMul(weights.attn_vec_einsum_w.data(), activations.att_out.data(), + activations.attention_out.data(), kHeads, kModelDim, kQKVDim, + num_tokens); + + AddFromT(activations.input.data(), activations.attention_out.data(), + num_tokens * kModelDim); + + RMSNormT(weights.pre_ffw_norm_scale.data(), activations.attention_out.data(), + activations.bf_pre_ffw_rms_out.data(), kModelDim, num_tokens); + + MatMulT(weights.gating_einsum_w.data(), activations.bf_pre_ffw_rms_out.data(), + activations.ffw_hidden.data(), kFFHiddenDim * 2, kModelDim, + num_tokens); + + GatedGelu(activations.ffw_hidden.data(), activations.ffw_hidden_gated.data(), + kFFHiddenDim, num_tokens); + + MatMulT(weights.linear_w.data(), activations.ffw_hidden_gated.data(), + output, kModelDim, kFFHiddenDim, num_tokens); + + AddFromT(activations.attention_out.data(), output, num_tokens * kModelDim); +} + +template +T CrossEntropyLoss(const T* x, const 
Prompt& prompt, size_t V) { + T loss = {}; + for (size_t i = 0; i + 1 < prompt.tokens.size(); ++i) { + if (i + 1 < prompt.context_size) { + continue; // next token is part of context, don't try to predict it + } + const int next_token = prompt.tokens[i + 1]; + loss += std::log(x[i * V + next_token]); + } + T scaling = -1.0 / std::log(2.0); + return loss * scaling; +} + +template +T CrossEntropyLossForwardPass(const Prompt& prompt, + const Weights& weights, + ForwardPass& forward) { + static constexpr size_t kModelDim = TConfig::kModelDim; + static constexpr size_t kVocabSize = TConfig::kVocabSize; + static constexpr size_t kLayers = TConfig::kLayers; + const size_t num_tokens = prompt.tokens.size() - 1; + + const T kEmbScaling = EmbeddingScaling(kModelDim); + InputEmbedding(weights.embedder_input_embedding.data(), prompt.tokens, + kEmbScaling, forward.layers[0].input.data(), kModelDim); + + for (size_t layer = 0; layer < kLayers; ++layer) { + T* output = layer + 1 < kLayers ? + forward.layers[layer + 1].input.data() : + forward.final_layer_output.data(); + ApplyLayer(*weights.GetLayer(layer), forward.layers[layer], num_tokens, + output); + } + + RMSNormT(weights.final_norm_scale.data(), + forward.final_layer_output.data(), + forward.final_norm_output.data(), kModelDim, num_tokens); + + MatMulT(weights.embedder_input_embedding.data(), + forward.final_norm_output.data(), + forward.logits.data(), kVocabSize, kModelDim, num_tokens); + + Softcap(forward.logits.data(), num_tokens * kVocabSize); + + memcpy(forward.probs.data(), forward.logits.data(), + num_tokens * kVocabSize * sizeof(forward.logits[0])); + Softmax(forward.probs.data(), kVocabSize, num_tokens); + + return CrossEntropyLoss(forward.probs.data(), prompt, kVocabSize); +} + +} // namespace gcpp + +#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_FORWARD_SCALAR_H_ diff --git a/gemma/gemma.cc b/gemma/gemma.cc index 52f2656..888e447 100644 --- a/gemma/gemma.cc +++ b/gemma/gemma.cc @@ -22,6 +22,7 @@ #include "hwy/foreach_target.h" // IWYU pragma: keep // Must come after foreach_target.h to avoid redefinition errors. #include "compression/compress-inl.h" +#include "gemma/common-inl.h" #include "gemma/ops.h" #include "hwy/contrib/matvec/matvec-inl.h" #include "hwy/highway.h" @@ -53,6 +54,7 @@ #include "compression/io.h" // Path #include "gemma/configs.h" #include "gemma/gemma.h" +#include "gemma/weights.h" #include "hwy/aligned_allocator.h" #include "hwy/base.h" #include "hwy/contrib/thread_pool/thread_pool.h" @@ -72,63 +74,6 @@ constexpr bool kShowTokenization = false; namespace gcpp { -template -struct Layer { - Layer() = default; - static constexpr size_t kHeads = TConfig::kHeads; - static constexpr size_t kKVHeads = TConfig::kKVHeads; - static constexpr size_t kModelDim = TConfig::kModelDim; - static constexpr size_t kQKVDim = TConfig::kQKVDim; - static constexpr size_t kFFHiddenDim = TConfig::kFFHiddenDim; - static constexpr size_t kAttVecEinsumWSize = kHeads * kQKVDim * kModelDim; - static constexpr size_t kQKVEinsumWSize = - (kHeads + 2 * kKVHeads) * kQKVDim * kModelDim; - // 2x for (gelu gating vector, gated vector) - static constexpr size_t kGatingEinsumWSize = 2 * kFFHiddenDim * kModelDim; - static constexpr size_t kConv1dWidth = TConfig::kConv1dWidth; - static constexpr bool kFFBiases = TConfig::kFFBiases; - static constexpr bool kPostNormScale = TConfig::kPostNormScale; - static constexpr size_t kAOBiasDim = - TConfig::kSoftmaxAttnOutputBiases ? kModelDim : 0; - static constexpr size_t kGriffinDim = - TConfig::kGriffinLayers > 0 ? 
kModelDim : 0; - - template - using ArrayT = std::array; - - union { - struct { - ArrayT attn_vec_einsum_w; - ArrayT qkv_einsum_w; - ArrayT attention_output_biases; - }; - - struct { - ArrayT linear_x_w; - ArrayT linear_x_biases; - ArrayT linear_y_w; - ArrayT linear_y_biases; - ArrayT linear_out_w; - ArrayT linear_out_biases; - ArrayT conv_w; - ArrayT conv_biases; - ArrayT gate_w; - ArrayT gate_biases; - ArrayT a; - } griffin; - }; - - ArrayT gating_einsum_w; - ArrayT linear_w; - ArrayT pre_attention_norm_scale; - ArrayT pre_ffw_norm_scale; - ArrayT post_attention_norm_scale; - ArrayT post_ffw_norm_scale; - - ArrayT ffw_gating_biases; - ArrayT ffw_output_biases; -}; - float ScaleWeights(float* data, size_t len) { float maxabs = 0.0; for (size_t i = 0; i < len; ++i) { @@ -146,41 +91,6 @@ float ScaleWeights(float* data, size_t len) { return scale; } -// Array instead of single large allocation for parallel mem init. Split out of -// Weights so that only these pointers are initialized. -template -struct LayerPointers { - explicit LayerPointers(hwy::ThreadPool& pool) { - pool.Run(0, TConfig::kLayers, [this](uint64_t task, size_t /*thread*/) { - this->layers[task] = hwy::AllocateAligned>(1); - }); - } - - using TLayer = Layer; - std::array, TConfig::kLayers> layers; -}; - -template -struct Weights { - // No ctor/dtor, allocated via AllocateAligned. - - std::array - embedder_input_embedding; - - std::array final_norm_scale; - - LayerPointers layer_ptrs; - - std::array scales; - - const Layer* GetLayer(size_t layer) const { - return layer_ptrs.layers[layer].get(); - } - Layer* GetLayer(size_t layer) { - return layer_ptrs.layers[layer].get(); - } -}; - template hwy::AlignedFreeUniquePtr LoadWeights( const Path& checkpoint, hwy::ThreadPool& pool, @@ -191,11 +101,8 @@ hwy::AlignedFreeUniquePtr LoadWeights( checkpoint.path.c_str()); } - using TWeights = Weights; - hwy::AlignedFreeUniquePtr weights_u8 = - hwy::AllocateAligned(sizeof(TWeights)); - TWeights* weights = reinterpret_cast(weights_u8.get()); - new (&weights->layer_ptrs) LayerPointers(pool); + ByteStorageT weights_u8 = AllocateWeights(pool); + auto* weights = reinterpret_cast*>(weights_u8.get()); size_t scale_pos = 0; FILE* fptr; @@ -228,7 +135,7 @@ hwy::AlignedFreeUniquePtr LoadWeights( sizeof(weights->final_norm_scale)); for (size_t layer = 0; layer < TConfig::kLayers; ++layer) { auto type = TConfig::kLayerConfig[layer]; - Layer* layer_view = weights->GetLayer(layer); + LayerF* layer_view = weights->GetLayer(layer); #define READ_WEIGHTS(name) \ do { \ @@ -305,7 +212,7 @@ template struct CompressedLayer { // No ctor/dtor, allocated via AllocateAligned. - using TLayer = gcpp::Layer; + using TLayer = gcpp::LayerF; using WeightT = typename TConfig::WeightT; static constexpr size_t kHeads = TLayer::kHeads; @@ -399,13 +306,13 @@ struct CompressedWeights { template using WeightsT = hwy::If, - Weights>; + WeightsF>; // Aligned. 
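+// Activation buffers used during inference. InferenceState (below) bundles one
+// set for prefill and one for single-token decoding, so both can be allocated
+// together via AllocateInferenceState().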
template struct Activations { static constexpr size_t kBatchSize = TBatchSize; - using LayerConfig = Layer; + using LayerConfig = LayerF; static constexpr size_t kModelDim = TConfig::kModelDim; static constexpr size_t kQKVDim = TConfig::kQKVDim; static constexpr size_t kHeads = TConfig::kHeads; @@ -446,6 +353,16 @@ struct Activations { std::array griffin_multiplier; }; +template +struct InferenceState { + Activations prefill; + HWY_ALIGN Activations state; + + static ByteStorageT Allocate() { + return hwy::AllocateAligned(sizeof(InferenceState)); + } +}; + // GemmaImpl is a template and thus cannot be exposed in gemma.h, hence we // define an abstract base class. struct GemmaInterface { @@ -484,6 +401,8 @@ KVCache CreateKVCache(Model type) { return CreateKVCacheT(); case Model::GRIFFIN_2B: return CreateKVCacheT(); + case Model::GEMMA_TINY: + return CreateKVCacheT(); default: HWY_ABORT("Model type %d unknown.", static_cast(type)); } @@ -526,8 +445,8 @@ void DeleteLayersPtrs(CompressedWeights* c_weights) { c_weights->c_layer_ptrs.~CompressedLayerPointers(); } template -void DeleteLayersPtrs(Weights* weights) { - weights->layer_ptrs.~LayerPointers(); +void DeleteLayersPtrs(WeightsF* weights) { + weights->layer_ptrs.~LayerPointers(); } } // namespace @@ -866,8 +785,9 @@ HWY_NOINLINE void FFW(Activations& activations, // Same matrix, first and second half of rows. Could fuse into one MatVec. MatVecT( layer_weights->gating_einsum_w, kFFHiddenDim * kModelDim, vec, - layer_weights->ffw_gating_biases.data() + kFFHiddenDim, even_odd, - out_mul, pool); + TConfig::kFFBiases ? + layer_weights->ffw_gating_biases.data() + kFFHiddenDim : nullptr, + even_odd, out_mul, pool); // Gate, will go through the nonlinearity. MatVecT( layer_weights->gating_einsum_w, 0, vec, @@ -892,21 +812,6 @@ HWY_NOINLINE void FFW(Activations& activations, } } -// `EmbeddingScaling` can be constexpr only if `Sqrt` and `hwy::ConvertScalarTo` -// are both constexpr -#if HWY_COMPILER_GCC_ACTUAL -#define GEMMA_CONSTEXPR_EMBSCALING HWY_BF16_CONSTEXPR -#else -#define GEMMA_CONSTEXPR_EMBSCALING -#endif - -template -GEMMA_CONSTEXPR_EMBSCALING float EmbeddingScaling() { - // Round to bf16 to match Gemma's Embedder, which casts before mul. 
- return hwy::ConvertScalarTo(hwy::ConvertScalarTo( - Sqrt(static_cast(TConfig::kModelDim)))); -} - template HWY_NOINLINE void Prefill(const int* tokens, size_t num_tokens, size_t pos, const WeightArrayT& weights, @@ -1076,20 +981,15 @@ void RangeChecks(size_t& max_tokens, size_t& max_generated_tokens, } } -template -void GenerateImpl(GemmaImpl& gemma, +template typename WeightsType> +void GenerateImpl(const WeightsType& weights, + Activations& prefill_activations, + Activations& activations, const RuntimeConfig& runtime_config, const std::vector& prompt, size_t pos, KVCache& kv_cache, hwy::ThreadPool& pool, TimingInfo& timing_info, LayersOutputT* layers_output) { static constexpr size_t kVocabSize = TConfig::kVocabSize; - Activations& activations = *gemma.state.get(); - Activations& prefill_activations = - *gemma.prefill.get(); - - const WeightsT& weights = - *reinterpret_cast*>(gemma.weights_u8.get()); - size_t prompt_size = prompt.size(); size_t max_tokens = runtime_config.max_tokens; size_t max_generated_tokens = runtime_config.max_generated_tokens; @@ -1167,7 +1067,7 @@ void GenerateImpl(GemmaImpl& gemma, activations.logits.data(), kVocabSize, *runtime_config.gen, runtime_config.temperature, runtime_config.accept_token); if (!runtime_config.stream_token(token, activations.logits[token])) { - token = EOS_ID; + token = runtime_config.eos_id; } if (generate_pos == 0) { timing_info.time_to_first_token = hwy::platform::Now() - gen_start; @@ -1177,10 +1077,10 @@ void GenerateImpl(GemmaImpl& gemma, // process the tokens of the prompt one at a time. token = prompt.at(pos_offset + 1); if (!runtime_config.stream_token(token, 0)) { - token = EOS_ID; + token = runtime_config.eos_id; } } - if (token == EOS_ID) { + if (token == runtime_config.eos_id) { if (runtime_config.verbosity >= 2) { const double gen_end = hwy::platform::Now(); timing_info.gen_tok_sec = @@ -1192,6 +1092,57 @@ void GenerateImpl(GemmaImpl& gemma, } } +template +void GenerateImpl(GemmaImpl& gemma, + const RuntimeConfig& runtime_config, + const std::vector& prompt, size_t pos, KVCache& kv_cache, + hwy::ThreadPool& pool, TimingInfo& timing_info, + LayersOutputT* layers_output) { + const WeightsT& weights = + *reinterpret_cast*>(gemma.weights_u8.get()); + GenerateImpl( + weights, *gemma.prefill.get(), *gemma.state.get(), runtime_config, prompt, + pos, kv_cache, pool, timing_info, layers_output); +} + +template +void GenerateImpl(const ByteStorageT& weights_u8, + ByteStorageT& inference_state_u8, + const RuntimeConfig& runtime_config, + const std::vector& prompt, size_t pos, + KVCache& kv_cache, hwy::ThreadPool& pool, + TimingInfo& timing_info, LayersOutputT* layers_output) { + const WeightsF& weights = + *reinterpret_cast*>(weights_u8.get()); + InferenceState& inference_state = + *reinterpret_cast*>(inference_state_u8.get()); + GenerateImpl( + weights, inference_state.prefill, inference_state.state, runtime_config, + prompt, pos, kv_cache, pool, timing_info, layers_output); +} + +void GenerateImplT(Model model, const ByteStorageT& weights_u8, + ByteStorageT& inference_state_u8, + const RuntimeConfig& runtime_config, + const std::vector& prompt, size_t pos, + KVCache& kv_cache, hwy::ThreadPool& pool, + TimingInfo& timing_info, LayersOutputT* layers_output) { + switch (model) { + case Model::GEMMA_2B: + GenerateImpl( + weights_u8, inference_state_u8, runtime_config, prompt, pos, kv_cache, + pool, timing_info, layers_output); + break; + case Model::GEMMA_TINY: + GenerateImpl( + weights_u8, inference_state_u8, runtime_config, prompt, 
pos, kv_cache, + pool, timing_info, layers_output); + break; + default: + HWY_ABORT("Model type %d unknown.", static_cast(model)); + } +} + #define TOKEN(token_id) TokenString(gemma, token_id).c_str() template @@ -1263,7 +1214,7 @@ float ComputeCrossEntropyImpl(GemmaImpl& gemma, size_t max_tokens, // // This avoids repeating the list of tensors between loading and compressing. template -void ForEachTensor(const Weights* weights, +void ForEachTensor(const WeightsF* weights, CompressedWeights& c_weights, Func& func) { func("c_embedding", weights ? weights->embedder_input_embedding.data() : nullptr, @@ -1275,7 +1226,7 @@ void ForEachTensor(const Weights* weights, for (int layer_idx = 0; layer_idx < TConfig::kLayers; ++layer_idx) { auto type = TConfig::kLayerConfig[layer_idx]; const size_t idx = static_cast(layer_idx); - const Layer* layer = weights ? weights->GetLayer(idx) : nullptr; + const LayerF* layer = weights ? weights->GetLayer(idx) : nullptr; CompressedLayer* layer_weights = c_weights.GetLayer(idx); #define CALL_FUNC(name, member) \ @@ -1386,14 +1337,14 @@ void CompressWeights(const Path& weights_path, const bool scale_for_compression = TConfig::kNumTensorScales > 0; const hwy::AlignedFreeUniquePtr weights_u8 = LoadWeights(weights_path, pool, scale_for_compression); - Weights* weights = - reinterpret_cast*>(weights_u8.get()); + WeightsF* weights = + reinterpret_cast*>(weights_u8.get()); Compressor compressor(pool); ForEachTensor(weights, *c_weights, compressor); compressor.AddScales(weights->scales.data(), weights->scales.size()); compressor.WriteAll(pool, compressed_weights_path); - weights->layer_ptrs.~LayerPointers(); + weights->layer_ptrs.~LayerPointers(); c_weights->c_layer_ptrs.~CompressedLayerPointers(); } @@ -1422,6 +1373,7 @@ HWY_AFTER_NAMESPACE(); namespace gcpp { HWY_EXPORT(CompressWeightsT); +HWY_EXPORT(GenerateImplT); KVCache CreateKVCache(size_t size_cache_pos, size_t seq_len, size_t conv1d_cache_size, size_t rglru_cache_size) { @@ -1528,6 +1480,37 @@ void GenerateGemma(Gemma& gemma, const RuntimeConfig& runtime_config, pool.SetWaitMode(hwy::PoolWaitMode::kBlock); } +void GenerateGemma(Model model, const ByteStorageT& weights, + ByteStorageT& inference_state, + RuntimeConfig runtime_config, + const std::vector& prompt, size_t start_pos, + KVCache& kv_cache, hwy::ThreadPool& pool, + TimingInfo& timing_info) { + HWY_DYNAMIC_DISPATCH(GenerateImplT)( + model, weights, inference_state, runtime_config, prompt, start_pos, + kv_cache, pool, timing_info, /*layers_output=*/nullptr); +} + +ByteStorageT LoadWeights(const Path& weights, Model model, + hwy::ThreadPool& pool) { + return HWY_DYNAMIC_DISPATCH(LoadWeightsT)(model, weights, pool); +} + +ByteStorageT AllocateInferenceState(Model model) { + switch (model) { + case Model::GEMMA_2B: + return InferenceState::Allocate(); + case Model::GEMMA_7B: + return InferenceState::Allocate(); + case Model::GRIFFIN_2B: + return InferenceState::Allocate(); + case Model::GEMMA_TINY: + return InferenceState::Allocate(); + default: + HWY_ABORT("Model type %d unknown.", static_cast(model)); + } +} + void CompressWeights(gcpp::Model model, const Path& weights, const Path& compressed_weights, hwy::ThreadPool& pool) { HWY_DYNAMIC_DISPATCH(CompressWeightsT) @@ -1546,13 +1529,16 @@ float ComputeCrossEntropy(Gemma& gemma, size_t max_tokens, namespace { constexpr const char* kModelFlags[] = {"2b-pt", "7b-pt", "gr2b-pt", - "2b-it", "7b-it", "gr2b-it"}; + "2b-it", "7b-it", "gr2b-it", + "tiny"}; constexpr Model kModelTypes[] = {Model::GEMMA_2B, 
Model::GEMMA_7B, Model::GRIFFIN_2B, Model::GEMMA_2B, - Model::GEMMA_7B, Model::GRIFFIN_2B}; + Model::GEMMA_7B, Model::GRIFFIN_2B, + Model::GEMMA_TINY}; constexpr ModelTraining kModelTraining[] = { ModelTraining::GEMMA_PT, ModelTraining::GEMMA_PT, ModelTraining::GEMMA_PT, - ModelTraining::GEMMA_IT, ModelTraining::GEMMA_IT, ModelTraining::GEMMA_IT}; + ModelTraining::GEMMA_IT, ModelTraining::GEMMA_IT, ModelTraining::GEMMA_IT, + ModelTraining::GEMMA_IT}; } // namespace const char* ParseModelTypeAndTraining(const std::string& model_flag, diff --git a/gemma/gemma.h b/gemma/gemma.h index 180268e..442637d 100644 --- a/gemma/gemma.h +++ b/gemma/gemma.h @@ -23,6 +23,7 @@ #include #include "compression/io.h" // Path +#include "gemma/common.h" #include "gemma/configs.h" #include "hwy/aligned_allocator.h" #include "hwy/base.h" // hwy::bfloat16_t @@ -51,8 +52,6 @@ struct KVCache { rglru_cache; // kModelDim * kGriffinLayers }; -// Model variants: see configs.h for details. -enum class Model { GEMMA_2B, GEMMA_7B, GRIFFIN_2B }; enum class ModelTraining { GEMMA_IT, GEMMA_PT }; // Returns error string or nullptr if OK. @@ -68,6 +67,8 @@ using StreamFunc = std::function; // want to generate and True for tokens you want to generate. using AcceptFunc = std::function; +constexpr int EOS_ID = 1; + struct RuntimeConfig { size_t max_tokens; size_t max_generated_tokens; @@ -76,6 +77,7 @@ struct RuntimeConfig { std::mt19937* gen; const StreamFunc& stream_token; const AcceptFunc& accept_token; + int eos_id = EOS_ID; }; struct GemmaInterface; @@ -118,6 +120,18 @@ void GenerateGemma(Gemma& gemma, const RuntimeConfig& runtime_config, TimingInfo& timing_info, LayersOutputT* layers_output = nullptr); +void GenerateGemma(Model model, const ByteStorageT& weights, + ByteStorageT& inference_state, + RuntimeConfig runtime_config, + const std::vector& prompt, size_t start_pos, + KVCache& kv_cache, hwy::ThreadPool& pool, + TimingInfo& timing_info); + +ByteStorageT LoadWeights(const Path& weights, Model model, + hwy::ThreadPool& pool); + +ByteStorageT AllocateInferenceState(Model model); + void CompressWeights(gcpp::Model model, const Path& weights, const Path& compressed_weights, hwy::ThreadPool& pool); @@ -125,8 +139,6 @@ float ComputeCrossEntropy(Gemma& gemma, size_t max_tokens, const std::vector& prompt, KVCache& kv_cache, hwy::ThreadPool& pool, int verbosity); -constexpr int EOS_ID = 1; - } // namespace gcpp #endif // THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_H_ diff --git a/gemma/gemma_test.cc b/gemma/gemma_test.cc index 9ede883..0318fde 100644 --- a/gemma/gemma_test.cc +++ b/gemma/gemma_test.cc @@ -98,7 +98,7 @@ class GemmaTest : public ::testing::Test { gcpp::Gemma model; }; -TEST_F(GemmaTest, Geography) { +TEST_F(GemmaTest, DISABLED_Geography) { static const char* kQA[][2] = { {"What is the capital of Hungary?", "Budapest"}, {"How many states does the US have?", "50"}, @@ -108,7 +108,7 @@ TEST_F(GemmaTest, Geography) { TestQuestions(kQA, kNum); } -TEST_F(GemmaTest, History) { +TEST_F(GemmaTest, DISABLED_History) { static const char* kQA[][2] = { {"When was the Battle of Hastings?", "1066"}, {"Who fought at the Battle of Marathon?", "Greek"}, @@ -117,7 +117,7 @@ TEST_F(GemmaTest, History) { TestQuestions(kQA, kNum); } -TEST_F(GemmaTest, Arithmetic) { +TEST_F(GemmaTest, DISABLED_Arithmetic) { static const char* kQA[][2] = { {"what is 13 + 14?", "27"}, {"what is 7 * 8", "56"}, @@ -280,7 +280,7 @@ static const char kDeclaration[] = { "reliance on the protection of divine Providence, we mutually pledge to " "each other our Lives, our 
Fortunes and our sacred Honor.\n"}; -TEST_F(GemmaTest, CrossEntropySmall) { +TEST_F(GemmaTest, DISABLED_CrossEntropySmall) { static const char kSmall[] = "The capital of Hungary is Budapest which is located in Europe."; float entropy = GemmaCrossEntropy(kSmall); @@ -288,19 +288,19 @@ TEST_F(GemmaTest, CrossEntropySmall) { EXPECT_LT(entropy, 1.6f); } -TEST_F(GemmaTest, CrossEntropyJingleBells) { +TEST_F(GemmaTest, DISABLED_CrossEntropyJingleBells) { float entropy = GemmaCrossEntropy(kJingleBells); std::cout << "per-byte entropy: " << entropy << "\n"; EXPECT_LT(entropy, 2.3f); } -TEST_F(GemmaTest, CrossEntropyGettysburg) { +TEST_F(GemmaTest, DISABLED_CrossEntropyGettysburg) { float entropy = GemmaCrossEntropy(kGettysburg); std::cout << "per-byte entropy: " << entropy << "\n"; EXPECT_LT(entropy, 1.2f); } -TEST_F(GemmaTest, CrossEntropyDeclaration) { +TEST_F(GemmaTest, DISABLED_CrossEntropyDeclaration) { float entropy = GemmaCrossEntropy(kDeclaration); std::cout << "per-byte entropy: " << entropy << "\n"; EXPECT_LT(entropy, 1.0f); diff --git a/gemma/ops.h b/gemma/ops.h index 6401cc4..5bffa84 100644 --- a/gemma/ops.h +++ b/gemma/ops.h @@ -877,7 +877,7 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void AddAbsolutePositionalEmbeddings( */ static HWY_NOINLINE HWY_MAYBE_UNUSED void Rope(float* HWY_RESTRICT x, - size_t dim_qkv, size_t pos) { + size_t dim_qkv, int pos) { HWY_DASSERT(dim_qkv % 2 == 0); const size_t half_dim_qkv = dim_qkv / 2; for (size_t dim = 0; dim < half_dim_qkv; ++dim) { @@ -898,7 +898,7 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void Rope(float* HWY_RESTRICT x, static HWY_NOINLINE HWY_MAYBE_UNUSED void RopeAndMulBy(const float mul, float* HWY_RESTRICT x, size_t dim_qkv, - size_t pos) { + int pos) { HWY_DASSERT(dim_qkv % 2 == 0); const size_t half_dim_qkv = dim_qkv / 2; for (size_t dim = 0; dim < half_dim_qkv; ++dim) { diff --git a/gemma/optimize_test.cc b/gemma/optimize_test.cc new file mode 100644 index 0000000..da9fcbb --- /dev/null +++ b/gemma/optimize_test.cc @@ -0,0 +1,127 @@ +// Copyright 2024 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
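+
+// Gradient-descent smoke test: trains the GEMMA_TINY model on a synthetic
+// reverse-sequence task (ReverseSequenceSampler; e.g. a length-3 sample reads
+// "724-->427|", with context "724-->"). The mean batch loss is expected to
+// drop below 0.5 in fewer than 3000 steps, with generation reproducing every
+// training sequence in the final batch.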
+ +#include +#include + +#include "gemma/activations.h" +#include "gemma/backward.h" +#include "gemma/forward.h" +#include "gemma/gemma.h" +#include "gemma/optimizer.h" +#include "gemma/sampler.h" +#include "gemma/weights.h" +#include "gtest/gtest.h" + +namespace gcpp { + +TEST(OptimizeTest, GradientDescent) { + hwy::ThreadPool pool(0); + std::mt19937 gen(42); + + Model model_type = Model::GEMMA_TINY; + ByteStorageT weights = AllocateWeights(model_type, pool); + ByteStorageT grad = AllocateWeights(model_type, pool); + ByteStorageT grad_m = AllocateWeights(model_type, pool); + ByteStorageT grad_v = AllocateWeights(model_type, pool); + ByteStorageT forward = AllocateForwardPass(model_type); + ByteStorageT backward = AllocateForwardPass(model_type); + ByteStorageT inference = AllocateInferenceState(model_type); + auto kv_cache = CreateKVCache(model_type); + size_t max_tokens = 32; + size_t max_generated_tokens = 16; + float temperature = 1.0f; + int verbosity = 0; + const auto accept_token = [](int) { return true; }; + + const auto generate = [&](const std::vector& prompt) { + std::vector reply; + auto stream_token = [&reply](int token, float) { + reply.push_back(token); + return token != ReverseSequenceSampler::kEndToken; + }; + RuntimeConfig runtime = { + max_tokens, max_generated_tokens, temperature, verbosity, &gen, + stream_token, accept_token, ReverseSequenceSampler::kEndToken, + }; + TimingInfo timing_info; + GenerateGemma(model_type, weights, inference, runtime, prompt, 0, + kv_cache, pool, timing_info); + return reply; + }; + + auto verify = [&](const Prompt& prompt) { + auto context = prompt.context(); + std::vector reply = generate(context); + bool ok = true; + for (size_t i = 0; ok && i < prompt.tokens.size(); ++i) { + if (i >= reply.size() || reply[i] != prompt.tokens[i]) { + ok = false; + } + } + return ok; + }; + + RandInitWeights(model_type, weights, pool, gen); + ZeroInitWeights(model_type, grad_m, pool); + ZeroInitWeights(model_type, grad_v, pool); + + printf("Initial weights:\n"); + LogWeightStats(model_type, weights); + + constexpr size_t kBatchSize = 8; + float learning_rate = 0.0005f; + + ReverseSequenceSampler training_task({ + 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1}); + size_t steps = 0; + float prev_loss = std::numeric_limits::max(); + size_t num_ok; + for (; steps < 1000000; ++steps) { + std::mt19937 sgen(42); + ZeroInitWeights(model_type, grad, pool); + float total_loss = 0.0f; + num_ok = 0; + for (size_t i = 0; i < kBatchSize; ++i) { + Prompt prompt = training_task.Sample(sgen); + total_loss += CrossEntropyLossForwardPass( + model_type, prompt, weights, forward, pool); + CrossEntropyLossBackwardPass( + model_type, prompt, weights, forward, grad, backward, pool); + num_ok += verify(prompt) ? 
1 : 0; + } + total_loss /= kBatchSize; + + const float scale = -learning_rate / kBatchSize; + UpdateWeights(model_type, grad, scale, weights, pool); + printf("step: %zu total_loss: %.15f num_ok: %zu/%zu\n", + steps, total_loss, num_ok, kBatchSize); + if (steps % 100 == 0) { + printf("Batch gradient:\n"); + LogWeightStats(model_type, grad); + } + if (total_loss < 0.5f) { + break; + } + prev_loss = total_loss; + } + printf("Num steps: %zu\n", steps); + printf("Final weights:\n"); + LogWeightStats(model_type, weights); + EXPECT_LT(steps, 3000); + EXPECT_EQ(num_ok, kBatchSize); +} + +} // namespace gcpp diff --git a/gemma/optimizer.cc b/gemma/optimizer.cc new file mode 100644 index 0000000..561b8ee --- /dev/null +++ b/gemma/optimizer.cc @@ -0,0 +1,121 @@ +// Copyright 2024 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "gemma/optimizer.h" + +#include + +#include "gemma/common.h" +#include "gemma/configs.h" +#include "gemma/weights.h" +#include "hwy/contrib/thread_pool/thread_pool.h" + +namespace gcpp { + +namespace { +class WeightInitializer { + public: + WeightInitializer(std::mt19937& gen) : dist_(0.0f, 1.0f), gen_(gen) {} + + template + void operator()(const char* name, std::array& tensor) { + for (size_t i = 0; i < N; ++i) { + tensor[i] = dist_(gen_); + } + } + private: + std::normal_distribution dist_; + std::mt19937& gen_; +}; + +template +void RandInitWeights(ByteStorageT& weights_u8, hwy::ThreadPool& pool, + std::mt19937& gen) { + auto& weights = *reinterpret_cast*>(weights_u8.get()); + // TODO(szabadka) Use the same weight initialization method as in the python + // version. 
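+  // For now, every parameter is drawn i.i.d. from a standard normal
+  // distribution (see WeightInitializer above).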
+ WeightInitializer init(gen); + ForEachTensor1(init, weights); +} + +class WeightUpdater { + public: + explicit WeightUpdater(float lr) : lr_(lr) {} + + template + void operator()(const char* name, const std::array& grad, + std::array& weights) { + for (size_t i = 0; i < kCapacity; ++i) { + weights[i] += lr_ * grad[i]; + } + } + + private: + float lr_; +}; + +template +void UpdateWeights(const ByteStorageT& grad_u8, float scale, + ByteStorageT& weights_u8, hwy::ThreadPool& pool) { + const auto& grad = + *reinterpret_cast*>(grad_u8.get()); + auto& weights = *reinterpret_cast*>(weights_u8.get()); + WeightUpdater updater(scale); + ForEachTensor2(updater, grad, weights); +} + +} // namespace + +void RandInitWeights(Model model, ByteStorageT& weights_u8, + hwy::ThreadPool& pool, std::mt19937& gen) { + switch (model) { + case Model::GEMMA_2B: + RandInitWeights(weights_u8, pool, gen); + break; + case Model::GEMMA_7B: + RandInitWeights(weights_u8, pool, gen); + break; + case Model::GRIFFIN_2B: + RandInitWeights(weights_u8, pool, gen); + break; + case Model::GEMMA_TINY: + RandInitWeights(weights_u8, pool, gen); + break; + default: + HWY_ABORT("Model type %d unknown.", static_cast(model)); + } +} + +void UpdateWeights(Model model, const ByteStorageT& grad, float scale, + ByteStorageT& weights, hwy::ThreadPool& pool) { + switch (model) { + case Model::GEMMA_2B: + UpdateWeights(grad, scale, weights, pool); + break; + case Model::GEMMA_7B: + UpdateWeights(grad, scale, weights, pool); + break; + case Model::GRIFFIN_2B: + UpdateWeights(grad, scale, weights, pool); + break; + case Model::GEMMA_TINY: + UpdateWeights(grad, scale, weights, pool); + break; + default: + HWY_ABORT("Model type %d unknown.", static_cast(model)); + } +} + +} // namespace gcpp diff --git a/gemma/optimizer.h b/gemma/optimizer.h new file mode 100644 index 0000000..343db97 --- /dev/null +++ b/gemma/optimizer.h @@ -0,0 +1,34 @@ +// Copyright 2024 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_OPTIMIZER_H_ +#define THIRD_PARTY_GEMMA_CPP_GEMMA_OPTIMIZER_H_ + +#include + +#include "gemma/common.h" +#include "hwy/contrib/thread_pool/thread_pool.h" + +namespace gcpp { + +void RandInitWeights(Model model, ByteStorageT& weights, hwy::ThreadPool& pool, + std::mt19937& gen); + +void UpdateWeights(Model model, const ByteStorageT& grad, float scale, + ByteStorageT& weights, hwy::ThreadPool& pool); + +} // namespace gcpp + +#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_OPTIMIZER_H_ diff --git a/gemma/prompt.h b/gemma/prompt.h new file mode 100644 index 0000000..6f762dc --- /dev/null +++ b/gemma/prompt.h @@ -0,0 +1,33 @@ +// Copyright 2024 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_PROMPT_H_ +#define THIRD_PARTY_GEMMA_CPP_GEMMA_PROMPT_H_ + +#include + +namespace gcpp { + +struct Prompt { + std::vector tokens; + size_t context_size; + std::vector context() const { + return std::vector(tokens.begin(), tokens.begin() + context_size); + } +}; + +} // namespace gcpp + +#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_PROMPT_H_ diff --git a/gemma/sampler.h b/gemma/sampler.h new file mode 100644 index 0000000..34a5ed3 --- /dev/null +++ b/gemma/sampler.h @@ -0,0 +1,84 @@ +// Copyright 2024 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_SAMPLER_H_ +#define THIRD_PARTY_GEMMA_CPP_GEMMA_SAMPLER_H_ + +#include + +#include "gemma/prompt.h" + +namespace gcpp { + +class PromptSampler { + public: + virtual Prompt Sample(std::mt19937& gen) = 0; + + std::vector SampleBatch(size_t batch_size, std::mt19937& gen) { + std::vector batch; + batch.reserve(batch_size); + for (size_t i = 0; i < batch_size; ++i) { + batch.emplace_back(Sample(gen)); + } + return batch; + } +}; + +class ReverseSequenceSampler : public PromptSampler { + public: + explicit ReverseSequenceSampler(const std::vector& length_histo) + : token_dist_(0, 9) { + for (int i = 0; i < length_histo.size(); ++i) { + const int count = length_histo[i]; + for (int j = 0; j < count; ++j) { + length_lut_.push_back(i + 1); + } + } + length_dist_ = std::uniform_int_distribution<>(0, length_lut_.size() - 1); + } + + static constexpr int kReverseToken = 10; + static constexpr int kEndToken = 11; + + Prompt Sample(std::mt19937& gen) override { + Prompt prompt; + int len = length_lut_[length_dist_(gen)]; + prompt.tokens.resize(2 * len + 2); + prompt.tokens[len] = kReverseToken; + prompt.tokens[2 * len + 1] = kEndToken; + for (size_t i = 0; i < len; ++i) { + prompt.tokens[i] = prompt.tokens[2 * len - i] = token_dist_(gen); + } + prompt.context_size = len + 1; + return prompt; + } + + static void LogPrompt(const Prompt& prompt) { + static const char* kVocab[] = { + "0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "-->", "|", + }; + for (int token : prompt.tokens) printf("%s", kVocab[token]); + printf(" [context_size: %zu]\n", prompt.context_size); + } + + private: + std::uniform_int_distribution<> token_dist_; + std::uniform_int_distribution<> length_dist_; + std::vector length_lut_; +}; + +} // namespace gcpp + +#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_SAMPLER_H_ diff --git a/gemma/test_util.h b/gemma/test_util.h new file mode 100644 index 0000000..e9cd533 --- /dev/null +++ b/gemma/test_util.h @@ 
-0,0 +1,195 @@ +// Copyright 2024 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_TEST_UTIL_H_ +#define THIRD_PARTY_GEMMA_CPP_GEMMA_TEST_UTIL_H_ + +#include +#include +#include + +#include "gemma/weights.h" +#include "gtest/gtest.h" + +namespace gcpp { + +template +void RandInit(std::array& x, T stddev, std::mt19937& gen) { + std::normal_distribution dist(0.0, stddev); + for (size_t i = 0; i < kLen; ++i) { + x[i] = dist(gen); + } +} + +template +void RandInit(Layer& w, T stddev, std::mt19937& gen) { + RandInit(w.pre_attention_norm_scale, stddev, gen); + RandInit(w.attn_vec_einsum_w, stddev, gen); + RandInit(w.qkv_einsum_w, stddev, gen); + RandInit(w.pre_ffw_norm_scale, stddev, gen); + RandInit(w.gating_einsum_w, stddev, gen); + RandInit(w.linear_w, stddev, gen); +} + +template +void RandInit(Weights& w, T stddev, std::mt19937& gen) { + static constexpr size_t kLayers = TConfig::kLayers; + RandInit(w.embedder_input_embedding, stddev, gen); + RandInit(w.final_norm_scale, stddev, gen); + for (size_t i = 0; i < kLayers; ++i) { + RandInit(*w.GetLayer(i), stddev, gen); + } +} + +template +void Complexify(const std::array& x, + std::array, kLen>& c_x) { + for (size_t i = 0; i < kLen; ++i) { + c_x[i] = std::complex(x[i], 0.0); + } +} + + +template +void Complexify(const Layer& w, + Layer, TConfig>& c_w) { + Complexify(w.pre_attention_norm_scale, c_w.pre_attention_norm_scale); + Complexify(w.attn_vec_einsum_w, c_w.attn_vec_einsum_w); + Complexify(w.qkv_einsum_w, c_w.qkv_einsum_w); + Complexify(w.pre_ffw_norm_scale, c_w.pre_ffw_norm_scale); + Complexify(w.gating_einsum_w, c_w.gating_einsum_w); + Complexify(w.linear_w, c_w.linear_w); +} + +template +void Complexify(const Weights& w, + Weights, TConfig>& c_w) { + static constexpr size_t kLayers = TConfig::kLayers; + Complexify(w.embedder_input_embedding, c_w.embedder_input_embedding); + Complexify(w.final_norm_scale, c_w.final_norm_scale); + for (size_t i = 0; i < kLayers; ++i) { + Complexify(*w.GetLayer(i), *c_w.GetLayer(i)); + } +} + +template +void TestNear(const std::array& actual, const std::array& expected, + double max_abs_err, double max_rel_err, int line) { + double sum0 = 0; + double sum1 = 0; + double sum01 = 0; + for (size_t i = 0; i < N; ++i) { + sum0 += actual[i] * actual[i]; + sum1 += expected[i] * expected[i]; + sum01 += actual[i] * expected[i]; + ASSERT_NEAR(actual[i], expected[i], + std::max(max_abs_err, std::abs(expected[i]) * max_rel_err)) + << "line: " << line << " dim=" << N << " i=" << i; + } + if (sum0 > 1e-40) { + double norm_dot = sum01 / std::sqrt(sum0) / std::sqrt(sum1); + ASSERT_NEAR(norm_dot, 1.0, 1e-7) + << "line: " << line << " sum0: " << sum0 << " sum1: " << sum1 + << " sum01: " << sum01; + } +} + +// Compute gradient with the finite difference method in the complex plane. 
+// If f : R->R is the tested function and F : C->C is its extension on the
+// complex plane so that F is complex differentiable in x, then
+//
+//   F(x + ih) = F(x) + ih F'(x) - (h^2 / 2) F''(x) + O(h^3)
+//
+// which means that
+//
+//   F'(x) ~= Imag(F(x + ih)) / h
+//
+// This method is more numerically stable than the real-valued finite difference
+// method since we don't need to subtract floating point numbers that are near
+// to each other.
+template
+void TestGradient(const std::array& grad,
+                  std::array, N>& x, FUNC func,
+                  U step, T max_abs_err, T max_rel_err, int line) {
+  std::array exp_grad;
+  const U inv_step = 1.0 / step;
+  for (size_t i = 0; i < N; ++i) {
+    const U x0 = std::real(x[i]);
+    const std::complex x1 = std::complex(x0, step);
+    x[i] = x1;
+    const std::complex f1 = func();
+    exp_grad[i] = std::imag(f1) * inv_step;
+    x[i] = x0;
+  }
+  TestNear(grad, exp_grad, max_abs_err, max_rel_err, line);
+}
+
+template
+void TestGradient(const std::array& grad,
+                  std::array, N>& x, FUNC func,
+                  float max_abs_err, float max_rel_error, int line) {
+  TestGradient(grad, x, func, 1e-30f, max_abs_err, max_rel_error, line);
+}
+
+template
+void TestGradient(const std::array& grad,
+                  std::array, N>& x, FUNC func,
+                  float max_abs_err, float max_rel_error, int line) {
+  TestGradient(grad, x, func, 1e-50, max_abs_err, max_rel_error, line);
+}
+
+template
+void TestGradient(const std::array& grad,
+                  std::array, N>& x, FUNC func,
+                  double max_abs_err, double max_rel_error, int line) {
+  TestGradient(grad, x, func, 1e-50, max_abs_err, max_rel_error, line);
+}
+
+template
+void TestGradient(const Layer& grad,
+                  Layer, TConfig>& c_weights,
+                  FUNC func, T max_err) {
+  TestGradient(grad.pre_attention_norm_scale,
+               c_weights.pre_attention_norm_scale,
+               func, max_err, max_err, __LINE__);
+  TestGradient(grad.attn_vec_einsum_w, c_weights.attn_vec_einsum_w,
+               func, max_err, max_err, __LINE__);
+  TestGradient(grad.qkv_einsum_w, c_weights.qkv_einsum_w,
+               func, max_err, max_err, __LINE__);
+  TestGradient(grad.pre_ffw_norm_scale, c_weights.pre_ffw_norm_scale,
+               func, max_err, max_err, __LINE__);
+  TestGradient(grad.gating_einsum_w, c_weights.gating_einsum_w,
+               func, max_err, max_err, __LINE__);
+  TestGradient(grad.linear_w, c_weights.linear_w,
+               func, max_err, max_err, __LINE__);
+}
+
+template
+void TestGradient(const Weights& grad,
+                  Weights, TConfig>& c_weights,
+                  FUNC func, T max_err) {
+  TestGradient(grad.embedder_input_embedding,
+               c_weights.embedder_input_embedding,
+               func, 2 * max_err, max_err, __LINE__);
+  TestGradient(grad.final_norm_scale, c_weights.final_norm_scale,
+               func, max_err, max_err, __LINE__);
+  for (int i = 0; i < TConfig::kLayers; ++i) {
+    TestGradient(*grad.GetLayer(i), *c_weights.GetLayer(i), func, max_err);
+  }
+}
+
+} // namespace gcpp
+
+#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_TEST_UTIL_H_
diff --git a/gemma/weights.cc b/gemma/weights.cc
new file mode 100644
index 0000000..c6387f7
--- /dev/null
+++ b/gemma/weights.cc
@@ -0,0 +1,116 @@
+// Copyright 2024 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include "gemma/weights.h" + +#include "gemma/common.h" +#include "gemma/configs.h" +#include "hwy/contrib/thread_pool/thread_pool.h" + +namespace gcpp { + +ByteStorageT AllocateWeights(Model model, hwy::ThreadPool& pool) { + switch (model) { + case Model::GEMMA_2B: + return AllocateWeights(pool); + case Model::GEMMA_7B: + return AllocateWeights(pool); + case Model::GRIFFIN_2B: + return AllocateWeights(pool); + case Model::GEMMA_TINY: + return AllocateWeights(pool); + default: + HWY_ABORT("Model type %d unknown.", static_cast(model)); + } +} + +namespace { +template +void ZeroInitWeightsT(ByteStorageT& weights, hwy::ThreadPool& pool) { + ZeroInit( + *reinterpret_cast*>(weights.get())); +} +} // namespace + +void ZeroInitWeights(Model model, ByteStorageT& weights, + hwy::ThreadPool& pool) { + switch (model) { + case Model::GEMMA_2B: + ZeroInitWeightsT(weights, pool); + break; + case Model::GEMMA_7B: + ZeroInitWeightsT(weights, pool); + break; + case Model::GRIFFIN_2B: + ZeroInitWeightsT(weights, pool); + break; + case Model::GEMMA_TINY: + ZeroInitWeightsT(weights, pool); + break; + default: + HWY_ABORT("Model type %d unknown.", static_cast(model)); + } +} + +namespace { +void LogVec(const char* name, const float* data, size_t len) { + float minval = std::numeric_limits::max(); + float maxval = std::numeric_limits::min(); + double sum = 0.0f; + for (size_t i = 0; i < len; ++i) { + minval = std::min(minval, data[i]); + maxval = std::max(maxval, data[i]); + sum += data[i]; + } + float avg = sum / len; + printf("%-20s %12zu %13.10f %8.5f %13.10f\n", + name, len, minval, avg, maxval); +} + +class WeightLogger { + public: + template + void operator()(const char* name, const std::array& tensor) { + LogVec(name, tensor.data(), N); + total_weights += N; + } + size_t total_weights = 0; +}; + +template +void LogWeightStats(const ByteStorageT& weights_u8) { + const auto& weights = *reinterpret_cast*>(weights_u8.get()); + WeightLogger logger; + ForEachTensor1(logger, weights); + printf("%-20s %12zu\n", "Total", logger.total_weights); +} +} // namespace + +void LogWeightStats(gcpp::Model model, const ByteStorageT& weights) { + switch (model) { + case Model::GEMMA_2B: + return LogWeightStats(weights); + case Model::GEMMA_7B: + return LogWeightStats(weights); + case Model::GRIFFIN_2B: + return LogWeightStats(weights); + case Model::GEMMA_TINY: + return LogWeightStats(weights); + default: + HWY_ABORT("Model type %d unknown.", static_cast(model)); + } +} + +} // namespace gcpp diff --git a/gemma/weights.h b/gemma/weights.h new file mode 100644 index 0000000..9777552 --- /dev/null +++ b/gemma/weights.h @@ -0,0 +1,288 @@ +// Copyright 2024 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_WEIGHTS_H_ +#define THIRD_PARTY_GEMMA_CPP_GEMMA_WEIGHTS_H_ + +#include "gemma/common.h" +#include "gemma/configs.h" +#include "hwy/aligned_allocator.h" +#include "hwy/contrib/thread_pool/thread_pool.h" + +namespace gcpp { + +template +struct Layer { + Layer() {} + static constexpr size_t kHeads = TConfig::kHeads; + static constexpr size_t kKVHeads = TConfig::kKVHeads; + static constexpr size_t kModelDim = TConfig::kModelDim; + static constexpr size_t kQKVDim = TConfig::kQKVDim; + static constexpr size_t kFFHiddenDim = TConfig::kFFHiddenDim; + static constexpr size_t kAttVecEinsumWSize = kHeads * kQKVDim * kModelDim; + static constexpr size_t kQKVEinsumWSize = + (kHeads + 2 * kKVHeads) * kQKVDim * kModelDim; + // 2x for (gelu gating vector, gated vector) + static constexpr size_t kGatingEinsumWSize = 2 * kFFHiddenDim * kModelDim; + static constexpr size_t kConv1dWidth = TConfig::kConv1dWidth; + static constexpr bool kFFBiases = TConfig::kFFBiases; + static constexpr bool kPostNormScale = TConfig::kPostNormScale; + static constexpr size_t kAOBiasDim = + TConfig::kSoftmaxAttnOutputBiases ? kModelDim : 0; + static constexpr size_t kGriffinDim = + TConfig::kGriffinLayers > 0 ? kModelDim : 0; + + union { + struct { + std::array attn_vec_einsum_w; + std::array qkv_einsum_w; + std::array attention_output_biases; + }; + + struct { + std::array linear_x_w; + std::array linear_x_biases; + std::array linear_y_w; + std::array linear_y_biases; + std::array linear_out_w; + std::array linear_out_biases; + std::array conv_w; + std::array conv_biases; + std::array gate_w; + std::array gate_biases; + std::array a; + } griffin; + }; + + std::array gating_einsum_w; + std::array linear_w; + std::array pre_attention_norm_scale; + std::array pre_ffw_norm_scale; + std::array post_attention_norm_scale; + std::array post_ffw_norm_scale; + + std::array ffw_gating_biases; + std::array ffw_output_biases; +}; + +template +using LayerF = Layer; + +// Array instead of single large allocation for parallel mem init. Split out of +// Weights so that only these pointers are initialized. +template +struct LayerPointers { + explicit LayerPointers(hwy::ThreadPool& pool) { + pool.Run(0, TConfig::kLayers, [this](uint64_t task, size_t /*thread*/) { + this->layers[task] = hwy::AllocateAligned>(1); + }); + } + + using TLayer = Layer; + std::array, TConfig::kLayers> layers; +}; + +template +struct Weights { + // No ctor/dtor, allocated via AllocateAligned. 
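+  // The per-layer blocks live behind LayerPointers (one aligned allocation
+  // per layer) so that AllocateWeights below can construct them in parallel.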
+
+  std::array
+      embedder_input_embedding;
+
+  std::array final_norm_scale;
+
+  LayerPointers layer_ptrs;
+
+  std::array scales;
+
+  const Layer* GetLayer(size_t layer) const {
+    return layer_ptrs.layers[layer].get();
+  }
+  Layer* GetLayer(size_t layer) {
+    return layer_ptrs.layers[layer].get();
+  }
+};
+
+template
+using WeightsF = Weights;
+
+template
+ByteStorageT AllocateWeights(hwy::ThreadPool& pool) {
+  using TWeights = Weights;
+  ByteStorageT weights_u8 = hwy::AllocateAligned(sizeof(TWeights));
+  TWeights* weights = reinterpret_cast(weights_u8.get());
+  new (&weights->layer_ptrs) LayerPointers(pool);
+  return weights_u8;
+}
+
+#define CALL_TOP_FUNC1(name, member) func(name, weights1.member)
+#define CALL_TOP_FUNC2(name, member) \
+  func(name, weights1.member, weights2.member)
+#define CALL_TOP_FUNC3(name, member) \
+  func(name, weights1.member, weights2.member, weights3.member)
+#define CALL_TOP_FUNC4(name, member) \
+  func(name, weights1.member, weights2.member, \
+       weights3.member, weights4.member)
+
+#define CALL_LAYER_FUNC1(name, member) \
+  snprintf(name_buf, sizeof(name_buf), name "_%d", layer_idx); \
+  func(name_buf, layer1.member)
+
+#define CALL_LAYER_FUNC2(name, member) \
+  snprintf(name_buf, sizeof(name_buf), name "_%d", layer_idx); \
+  func(name_buf, layer1.member, layer2.member)
+
+#define CALL_LAYER_FUNC3(name, member) \
+  snprintf(name_buf, sizeof(name_buf), name "_%d", layer_idx); \
+  func(name_buf, layer1.member, layer2.member, layer3.member)
+
+#define CALL_LAYER_FUNC4(name, member) \
+  snprintf(name_buf, sizeof(name_buf), name "_%d", layer_idx); \
+  func(name_buf, layer1.member, layer2.member, layer3.member, layer4.member)
+
+#define CALL_ALL_LAYER_FUNC(N) \
+  if (type == LayerAttentionType::kGemma) { \
+    CALL_LAYER_FUNC ## N("att_ein", attn_vec_einsum_w); \
+    CALL_LAYER_FUNC ## N("qkv_ein", qkv_einsum_w); \
+  } else { \
+    CALL_LAYER_FUNC ## N("gr_lin_x_w", griffin.linear_x_w); \
+    CALL_LAYER_FUNC ## N("gr_lin_x_b", griffin.linear_x_biases); \
+    CALL_LAYER_FUNC ## N("gr_lin_y_w", griffin.linear_y_w); \
+    CALL_LAYER_FUNC ## N("gr_lin_y_b", griffin.linear_y_biases); \
+    CALL_LAYER_FUNC ## N("gr_lin_out_w", griffin.linear_out_w); \
+    CALL_LAYER_FUNC ## N("gr_lin_out_b", griffin.linear_out_biases); \
+    CALL_LAYER_FUNC ## N("gr_conv_w", griffin.conv_w); \
+    CALL_LAYER_FUNC ## N("gr_conv_b", griffin.conv_biases); \
+    CALL_LAYER_FUNC ## N("gr_gate_w", griffin.gate_w); \
+    CALL_LAYER_FUNC ## N("gr_gate_b", griffin.gate_biases); \
+    CALL_LAYER_FUNC ## N("gr_a", griffin.a); \
+  } \
+  CALL_LAYER_FUNC ## N("gating_ein", gating_einsum_w); \
+  CALL_LAYER_FUNC ## N("linear_w", linear_w); \
+  CALL_LAYER_FUNC ## N("pre_att_ns", pre_attention_norm_scale); \
+  if (TConfig::kPostNormScale) { \
+    CALL_LAYER_FUNC ## N("post_att_ns", post_attention_norm_scale); \
+    CALL_LAYER_FUNC ## N("post_ff_ns", post_ffw_norm_scale); \
+  } \
+  CALL_LAYER_FUNC ## N("pre_ff_ns", pre_ffw_norm_scale); \
+  if (TConfig::kFFBiases) { \
+    CALL_LAYER_FUNC ## N("ffw_gat_b", ffw_gating_biases); \
+    CALL_LAYER_FUNC ## N("ffw_out_b", ffw_output_biases); \
+  } \
+  if (TConfig::kSoftmaxAttnOutputBiases && \
+      type == LayerAttentionType::kGemma) { \
+    CALL_LAYER_FUNC ## N("attn_ob", attention_output_biases); \
+  }
+
+template
+void ForEachTensor1(Func& func, const Weights& weights1) {
+  CALL_TOP_FUNC1("embedding", embedder_input_embedding);
+  CALL_TOP_FUNC1("final_norm", final_norm_scale);
+  char name_buf[16];
+  for (int layer_idx = 0; layer_idx < TConfig::kLayers; ++layer_idx) {
+    auto type = TConfig::kLayerConfig[layer_idx];
+
const size_t idx = static_cast(layer_idx); + const LayerF& layer1 = *weights1.GetLayer(idx); + CALL_ALL_LAYER_FUNC(1) + } +} + +template +void ForEachTensor1(Func& func, Weights& weights1) { + CALL_TOP_FUNC1("embedding", embedder_input_embedding); + CALL_TOP_FUNC1("final_norm", final_norm_scale); + char name_buf[16]; + for (int layer_idx = 0; layer_idx < TConfig::kLayers; ++layer_idx) { + auto type = TConfig::kLayerConfig[layer_idx]; + const size_t idx = static_cast(layer_idx); + LayerF& layer1 = *weights1.GetLayer(idx); + CALL_ALL_LAYER_FUNC(1) + } +} + +template +void ForEachTensor2(Func& func, const Weights& weights1, + Weights& weights2) { + CALL_TOP_FUNC2("embedding", embedder_input_embedding); + CALL_TOP_FUNC2("final_norm", final_norm_scale); + char name_buf[16]; + for (int layer_idx = 0; layer_idx < TConfig::kLayers; ++layer_idx) { + auto type = TConfig::kLayerConfig[layer_idx]; + const size_t idx = static_cast(layer_idx); + const LayerF& layer1 = *weights1.GetLayer(idx); + LayerF& layer2 = *weights2.GetLayer(idx); + CALL_ALL_LAYER_FUNC(2) + } +} + +#undef CALL_TOP_FUNC1 +#undef CALL_TOP_FUNC2 +#undef CALL_TOP_FUNC3 +#undef CALL_TOP_FUNC4 +#undef CALL_LAYER_FUNC1 +#undef CALL_LAYER_FUNC2 +#undef CALL_LAYER_FUNC3 +#undef CALL_LAYER_FUNC4 +#undef CALL_ALL_LAYER_FUNC + +template +void ZeroInit(Weights& w) { + memset(&w.embedder_input_embedding, 0, sizeof(w.embedder_input_embedding)); + memset(&w.final_norm_scale, 0, sizeof(w.final_norm_scale)); + for (int i = 0; i < TConfig::kLayers; ++i) { + memset(w.GetLayer(i), 0, sizeof(*w.GetLayer(i))); + } +} + +template +void Copy(Weights& dst, const Weights& src) { + memcpy(&dst.embedder_input_embedding, &src.embedder_input_embedding, + sizeof(src.embedder_input_embedding)); + memcpy(&dst.final_norm_scale, &src.final_norm_scale, + sizeof(src.final_norm_scale)); + for (int i = 0; i < TConfig::kLayers; ++i) { + memcpy(dst.GetLayer(i), src.GetLayer(i), sizeof(*dst.GetLayer(i))); + } +} + +template +class WeightsWrapper { + public: + WeightsWrapper() + : pool_(0), data_(AllocateWeights(pool_)), + weights_(reinterpret_cast*>(data_.get())) {} + + const Weights& get() const { return *weights_; } + Weights& get() { return *weights_; } + void clear() { ZeroInit(get()); } + void copy(const WeightsWrapper& other) { + Copy(get(), other.get()); + } + + private: + hwy::ThreadPool pool_; + ByteStorageT data_; + Weights* weights_; +}; + +ByteStorageT AllocateWeights(Model model, hwy::ThreadPool& pool); + +void ZeroInitWeights(Model model, ByteStorageT& weights, hwy::ThreadPool& pool); + +void LogWeightStats(Model model, const ByteStorageT& weights); + +} // namespace gcpp + +#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_WEIGHTS_H_
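
A minimal sketch, condensed from optimize_test.cc above, of how the new training entry points fit together (AllocateWeights, RandInitWeights, AllocateForwardPass, CrossEntropyLossForwardPass, CrossEntropyLossBackwardPass, UpdateWeights, LogWeightStats). The wrapper name TrainTinyReverser and the length histogram are illustrative only and not part of the patch; the include list and call signatures mirror the ones used by the test.

// Illustrative only -- condensed from optimize_test.cc.
#include <random>

#include "gemma/activations.h"
#include "gemma/backward.h"
#include "gemma/forward.h"
#include "gemma/gemma.h"
#include "gemma/optimizer.h"
#include "gemma/sampler.h"
#include "gemma/weights.h"
#include "hwy/contrib/thread_pool/thread_pool.h"

namespace gcpp {

// Runs plain gradient descent on GEMMA_TINY until the mean batch loss
// drops below 0.5 (or 3000 steps elapse).
void TrainTinyReverser(hwy::ThreadPool& pool) {
  std::mt19937 gen(42);
  const Model model = Model::GEMMA_TINY;
  ByteStorageT weights = AllocateWeights(model, pool);
  ByteStorageT grad = AllocateWeights(model, pool);
  ByteStorageT forward = AllocateForwardPass(model);
  ByteStorageT backward = AllocateForwardPass(model);
  RandInitWeights(model, weights, pool, gen);

  ReverseSequenceSampler task({0, 0, 0, 1});  // histogram: only length-4 sequences
  constexpr size_t kBatchSize = 8;
  const float kLearningRate = 0.0005f;
  for (size_t step = 0; step < 3000; ++step) {
    ZeroInitWeights(model, grad, pool);
    float loss = 0.0f;
    for (size_t i = 0; i < kBatchSize; ++i) {
      Prompt prompt = task.Sample(gen);
      loss += CrossEntropyLossForwardPass(model, prompt, weights, forward,
                                          pool);
      CrossEntropyLossBackwardPass(model, prompt, weights, forward, grad,
                                   backward, pool);
    }
    // The scale passed to UpdateWeights folds in the learning rate, the sign
    // of the descent step and the 1/kBatchSize averaging of the summed grad.
    UpdateWeights(model, grad, -kLearningRate / kBatchSize, weights, pool);
    if (loss / kBatchSize < 0.5f) break;
  }
  LogWeightStats(model, weights);
}

}  // namespace gcpp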