diff --git a/CMakeLists.txt b/CMakeLists.txt
index 6c4d12f..72409f6 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -50,6 +50,7 @@ set(SOURCES
   backprop/backward.h
   backprop/backward-inl.h
   backprop/backward_scalar.h
+  backprop/common_scalar.h
   backprop/forward.cc
   backprop/forward.h
   backprop/forward-inl.h
diff --git a/backprop/backward_scalar.h b/backprop/backward_scalar.h
index fe5ca46..445c2ad 100644
--- a/backprop/backward_scalar.h
+++ b/backprop/backward_scalar.h
@@ -23,6 +23,7 @@
 #include <cmath>
 #include <vector>
 
+#include "backprop/common_scalar.h"
 #include "backprop/prompt.h"
 #include "gemma/activations.h"
 #include "gemma/common.h"  // EmbeddingScaling
diff --git a/backprop/common_scalar.h b/backprop/common_scalar.h
new file mode 100644
index 0000000..49c4e1c
--- /dev/null
+++ b/backprop/common_scalar.h
@@ -0,0 +1,117 @@
+// Copyright 2024 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_SCALAR_H_
+#define THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_SCALAR_H_
+
+#include <stddef.h>
+
+#include <array>
+#include <cmath>
+#include <complex>
+
+namespace gcpp {
+
+template<typename T, typename U>
+U DotT(const T* a, const U* b, size_t N) {
+  U sum = {};
+  for (size_t i = 0; i < N; ++i) {
+    sum += a[i] * b[i];
+  }
+  return sum;
+}
+
+template<>
+std::complex<double> DotT(const float* a, const std::complex<double>* b,
+                          size_t N) {
+  std::complex<double> sum = {};
+  for (size_t i = 0; i < N; ++i) {
+    sum += static_cast<double>(a[i]) * b[i];
+  }
+  return sum;
+}
+
+template<typename T>
+void MulByConstT(T c, T* x, size_t N) {
+  for (size_t i = 0; i < N; ++i) {
+    x[i] *= c;
+  }
+}
+
+// out += c * x
+template<typename T>
+void MulByConstAndAddT(T c, const T* x, T* out, size_t N) {
+  for (size_t i = 0; i < N; ++i) {
+    out[i] += c * x[i];
+  }
+}
+
+template<typename T, size_t N>
+void MulByConstAndAddT(T c, const std::array<T, N>& x,
+                       std::array<T, N>& out) {
+  MulByConstAndAddT(c, x.data(), out.data(), N);
+}
+
+template<typename T>
+void AddFromT(const T* a, T* out, size_t N) {
+  for (size_t i = 0; i < N; ++i) {
+    out[i] += a[i];
+  }
+}
+
+template<typename T>
+T SquaredL2(const T* x, size_t N) {
+  T sum = {};
+  for (size_t i = 0; i < N; ++i) {
+    sum += x[i] * x[i];
+  }
+  return sum;
+}
+
+template<typename T>
+T Gelu(T x) {
+  static const T kMul = 0.044715;
+  static const T kSqrt2OverPi = 0.797884560804236;
+
+  const T x3 = x * x * x;
+  const T arg = kSqrt2OverPi * (kMul * x3 + x);
+  const T cdf = T(0.5) * (T(1.0) + std::tanh(arg));
+  return x * cdf;
+}
+
+template<typename T, typename U>
+void Rope(T* x, U base, size_t N, int i) {
+  const size_t N2 = N / 2;
+  for (size_t dim = 0; dim < N2; ++dim) {
+    const T freq_exponents = T(2 * dim) / T(N);
+    const T timescale = std::pow(base, freq_exponents);
+    const T theta = T(i) / timescale;
+    const T cos_val = std::cos(theta);
+    const T sin_val = std::sin(theta);
+    const T x0 = x[dim];
+    const T x1 = x[dim + N2];
+    x[dim] = x0 * cos_val - x1 * sin_val;
+    x[dim + N2] = x0 * sin_val + x1 * cos_val;
+  }
+}
+
+template<typename T>
+void Rope(T* x, size_t N, int i) {
+  Rope(x, T(10000.0), N, i);
+}
+
+template<typename T>
+void Rope(std::complex<T>* x, size_t N, int i) {
+  Rope(x, T(10000.0), N, i);
+}
+
+}  // namespace gcpp
+
+#endif  // THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_SCALAR_H_
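Taken together, these helpers form the scalar reference path that the backprop code checks its vectorized kernels against. Below is a minimal standalone usage sketch; the expected values in the comments are worked out by hand, and the snippet is illustrative rather than repository code:

```cpp
#include <cstdio>

#include "backprop/common_scalar.h"

int main() {
  const float a[4] = {1.0f, 2.0f, 3.0f, 4.0f};
  float out[4] = {0.0f, 0.0f, 0.0f, 0.0f};

  // Inner product: 1 + 4 + 9 + 16 = 30.
  printf("dot = %f\n", gcpp::DotT(a, a, 4));

  // SquaredL2(a, N) computes the same sum as DotT(a, a, N).
  printf("l2  = %f\n", gcpp::SquaredL2(a, 4));

  // out += 0.5f * a, so out becomes {0.5, 1.0, 1.5, 2.0}.
  gcpp::MulByConstAndAddT(0.5f, a, out, 4);

  // Rope rotates pairs (x[dim], x[dim + N/2]) by a position-dependent
  // angle; at position i = 0 every angle is zero, so x is unchanged.
  gcpp::Rope(out, 4, /*i=*/0);
  printf("out = {%f, %f, %f, %f}\n", out[0], out[1], out[2], out[3]);
  return 0;
}
```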
diff --git a/backprop/forward_scalar.h b/backprop/forward_scalar.h
index 44182a8..475f7c3 100644
--- a/backprop/forward_scalar.h
+++ b/backprop/forward_scalar.h
@@ -23,6 +23,7 @@
 #include <cmath>
 #include <vector>
 
+#include "backprop/common_scalar.h"
 #include "backprop/prompt.h"
 #include "gemma/activations.h"
 #include "gemma/common.h"  // EmbeddingScaling
diff --git a/backprop/optimize_test.cc b/backprop/optimize_test.cc
index 32d4b93..0e22adf 100644
--- a/backprop/optimize_test.cc
+++ b/backprop/optimize_test.cc
@@ -16,7 +16,6 @@
 #include <random>
 #include <vector>
 
-#include "gtest/gtest.h"
 #include "backprop/backward.h"
 #include "backprop/forward.h"
 #include "backprop/optimizer.h"
@@ -24,6 +23,7 @@
 #include "gemma/activations.h"
 #include "gemma/gemma.h"
 #include "gemma/weights.h"
+#include "gtest/gtest.h"
 
 namespace gcpp {
 
@@ -32,9 +32,9 @@ TEST(OptimizeTest, GradientDescent) {
   std::mt19937 gen(42);
 
   Model model_type = Model::GEMMA_TINY;
-  ByteStorageT grad = CallFunctorForModel<AllocateWeights>(model_type, pool);
-  ByteStorageT grad_m = CallFunctorForModel<AllocateWeights>(model_type, pool);
-  ByteStorageT grad_v = CallFunctorForModel<AllocateWeights>(model_type, pool);
+  ByteStorageT grad = CallFunctorForModel<AllocateWeightsF>(model_type, pool);
+  ByteStorageT grad_m = CallFunctorForModel<AllocateWeightsF>(model_type, pool);
+  ByteStorageT grad_v = CallFunctorForModel<AllocateWeightsF>(model_type, pool);
   ByteStorageT forward = CallFunctorForModel(model_type);
   ByteStorageT backward = CallFunctorForModel(model_type);
   KVCache kv_cache = KVCache::Create(model_type);
@@ -57,7 +57,7 @@ TEST(OptimizeTest, GradientDescent) {
       stream_token, accept_token, ReverseSequenceSampler::kEndToken,
   };
   TimingInfo timing_info;
-  model.Generate(runtime, prompt, 0, kv_cache, timing_info);
+  gemma.Generate(runtime, prompt, 0, kv_cache, timing_info);
   return reply;
 };
 
@@ -73,15 +73,18 @@ TEST(OptimizeTest, GradientDescent) {
     return ok;
   };
 
-  CallFunctorForModel<RandInitWeights>(model_type, gemma.Weights(), pool, gen);
-  CallFunctorForModel<ZeroInitWeights>(model_type, grad_m, pool);
-  CallFunctorForModel<ZeroInitWeights>(model_type, grad_v, pool);
+  RandInitWeights(model_type, gemma.Weights(), pool, gen);
+  CallFunctorForModel<ZeroInitWeightsF>(model_type, grad_m, pool);
+  CallFunctorForModel<ZeroInitWeightsF>(model_type, grad_v, pool);
 
   printf("Initial weights:\n");
   LogWeightStats(model_type, gemma.Weights());
 
   constexpr size_t kBatchSize = 8;
-  float learning_rate = 0.0005f;
+  const float alpha = 0.001f;
+  const float beta1 = 0.9f;
+  const float beta2 = 0.999f;
+  const float epsilon = 1e-8f;
 
   ReverseSequenceSampler training_task({
       0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1});
@@ -90,7 +93,7 @@ TEST(OptimizeTest, GradientDescent) {
   size_t num_ok;
   for (; steps < 1000000; ++steps) {
     std::mt19937 sgen(42);
-    CallFunctorForModel<ZeroInitWeights>(model_type, grad, pool);
+    CallFunctorForModel<ZeroInitWeightsF>(model_type, grad, pool);
     float total_loss = 0.0f;
     num_ok = 0;
     for (size_t i = 0; i < kBatchSize; ++i) {
@@ -103,8 +106,8 @@ TEST(OptimizeTest, GradientDescent) {
     }
     total_loss /= kBatchSize;
 
-    const float scale = -learning_rate / kBatchSize;
-    UpdateWeights(model_type, grad, scale, gemma.Weights(), pool);
+    AdamUpdate(model_type, grad, alpha, beta1, beta2, epsilon, steps + 1,
+               gemma.Weights(), grad_m, grad_v, pool);
 
     printf("step: %zu total_loss: %.15f num_ok: %zu/%zu\n",
            steps, total_loss, num_ok, kBatchSize);
    if (steps % 100 == 0) {
@@ -119,7 +122,7 @@ TEST(OptimizeTest, GradientDescent) {
   printf("Num steps: %zu\n", steps);
   printf("Final weights:\n");
   LogWeightStats(model_type, gemma.Weights());
-  EXPECT_LT(steps, 3000);
+  EXPECT_LT(steps, 200);
   EXPECT_EQ(num_ok, kBatchSize);
 }
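One detail worth flagging in the loop above: `AdamUpdate` receives `steps + 1` because Adam's bias correction divides the moment estimates by 1 - beta^t, which is zero at t = 0, so the update count must be 1-based. A hand-worked sanity check of the first step (illustrative only, not repository code):

```cpp
#include <cmath>
#include <cstdio>

int main() {
  const float beta1 = 0.9f, beta2 = 0.999f, epsilon = 1e-8f;
  const float g = 0.25f;  // some first-step gradient component

  // t = 1: m = (1 - beta1) * g, and the correction 1 / (1 - beta1^1)
  // cancels the (1 - beta1) factor exactly, so mhat == g. Likewise
  // vhat == g * g, so the first update direction is g / (|g| + epsilon),
  // i.e. approximately sign(g), later scaled by alpha.
  const float mhat = ((1.0f - beta1) * g) / (1.0f - std::pow(beta1, 1.0f));
  const float vhat = ((1.0f - beta2) * g * g) / (1.0f - std::pow(beta2, 1.0f));
  printf("mhat = %f, step direction = %f\n", mhat,
         mhat / (std::sqrt(vhat) + epsilon));
  return 0;
}
```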
diff --git a/backprop/optimizer.cc b/backprop/optimizer.cc
index 79659b7..741070b 100644
--- a/backprop/optimizer.cc
+++ b/backprop/optimizer.cc
@@ -25,6 +25,7 @@ namespace gcpp {
 
 namespace {
+
 class WeightInitializer {
  public:
  WeightInitializer(std::mt19937& gen) : dist_(0.0f, 1.0f), gen_(gen) {}
 
@@ -41,8 +42,8 @@ class WeightInitializer {
 };
 
 template <typename TConfig>
-struct RandInitWeights {
-  void operator()(ByteStorageT& weights_u8, hwy::ThreadPool& pool,
+struct RandInitWeightsT {
+  void operator()(const ByteStorageT& weights_u8, hwy::ThreadPool& pool,
                   std::mt19937& gen) const {
     auto& weights = *reinterpret_cast<WeightsF<TConfig>*>(weights_u8.get());
     // TODO(szabadka) Use the same weight initialization method as in the python
@@ -52,39 +53,71 @@
   }
 };
 
-class WeightUpdater {
+class AdamUpdater {
  public:
-  explicit WeightUpdater(float lr) : lr_(lr) {}
+  explicit AdamUpdater(float alpha, float beta1, float beta2, float epsilon,
+                       size_t t)
+      : alpha_(alpha), beta1_(beta1), beta2_(beta2), cbeta1_(1.0f - beta1),
+        cbeta2_(1.0f - beta2), norm1_(1.0 / (1.0 - std::pow(beta1, t))),
+        norm2_(1.0 / (1.0 - std::pow(beta2, t))), epsilon_(epsilon) {}
 
   template <size_t kCapacity>
   void operator()(const char* name, const std::array<float, kCapacity>& grad,
-                  std::array<float, kCapacity>& weights) {
+                  std::array<float, kCapacity>& weights,
+                  std::array<float, kCapacity>& grad_m,
+                  std::array<float, kCapacity>& grad_v) {
     for (size_t i = 0; i < kCapacity; ++i) {
-      weights[i] += lr_ * grad[i];
+      grad_m[i] *= beta1_;
+      grad_m[i] += cbeta1_ * grad[i];
+      grad_v[i] *= beta2_;
+      grad_v[i] += cbeta2_ * grad[i] * grad[i];
+      const float mhat = grad_m[i] * norm1_;
+      const float vhat = grad_v[i] * norm2_;
+      weights[i] -= alpha_ * mhat / (std::sqrt(vhat) + epsilon_);
     }
   }
 
  private:
-  float lr_;
+  float alpha_;
+  float beta1_;
+  float beta2_;
+  float cbeta1_;
+  float cbeta2_;
+  float norm1_;
+  float norm2_;
+  float epsilon_;
 };
 
 template <typename TConfig>
-struct UpdateWeightsT {
-  void operator()(const ByteStorageT& grad_u8, float scale,
-                  ByteStorageT& weights_u8, hwy::ThreadPool& pool) const {
+struct AdamUpdateT {
+  void operator()(const ByteStorageT& grad_u8, float alpha, float beta1,
+                  float beta2, float epsilon, size_t t,
+                  const ByteStorageT& weights_u8, const ByteStorageT& grad_m_u8,
+                  const ByteStorageT& grad_v_u8, hwy::ThreadPool& pool) const {
     const auto& grad =
        *reinterpret_cast<const WeightsF<TConfig>*>(grad_u8.get());
     auto& weights = *reinterpret_cast<WeightsF<TConfig>*>(weights_u8.get());
-    WeightUpdater updater(scale);
-    ForEachTensor2(updater, grad, weights);
+    auto& grad_m = *reinterpret_cast<WeightsF<TConfig>*>(grad_m_u8.get());
+    auto& grad_v = *reinterpret_cast<WeightsF<TConfig>*>(grad_v_u8.get());
+    AdamUpdater updater(alpha, beta1, beta2, epsilon, t);
+    ForEachTensor4(updater, grad, weights, grad_m, grad_v);
   }
 };
 
 }  // namespace
 
-void UpdateWeights(Model model, const ByteStorageT& grad, float scale,
-                   ByteStorageT& weights, hwy::ThreadPool& pool) {
-  CallFunctorForModel<UpdateWeightsT>(model, grad, scale, weights, pool);
+void RandInitWeights(Model model, const ByteStorageT& weights,
+                     hwy::ThreadPool& pool, std::mt19937& gen) {
+  CallFunctorForModel<RandInitWeightsT>(model, weights, pool, gen);
+}
+
+void AdamUpdate(Model model, const ByteStorageT& grad, float alpha, float beta1,
+                float beta2, float epsilon, size_t t,
+                const ByteStorageT& weights, const ByteStorageT& grad_m,
+                const ByteStorageT& grad_v, hwy::ThreadPool& pool) {
+  CallFunctorForModel<AdamUpdateT>(model, grad, alpha, beta1, beta2, epsilon, t,
+                                   weights, grad_m, grad_v, pool);
 }
 
 }  // namespace gcpp
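For readability, here is the per-tensor update from `AdamUpdater::operator()` restated as a freestanding function over raw arrays. This is a sketch of the same arithmetic under the same hyperparameter names, not a drop-in replacement; the real code is applied tensor by tensor via `ForEachTensor4`:

```cpp
#include <cmath>
#include <cstddef>

// One Adam step over n parameters. g is the gradient, w the weights,
// m/v the first/second moment accumulators, t the 1-based update count.
void AdamStep(const float* g, float* w, float* m, float* v, std::size_t n,
              float alpha, float beta1, float beta2, float epsilon,
              std::size_t t) {
  const float norm1 = 1.0f / (1.0f - std::pow(beta1, static_cast<float>(t)));
  const float norm2 = 1.0f / (1.0f - std::pow(beta2, static_cast<float>(t)));
  for (std::size_t i = 0; i < n; ++i) {
    m[i] = beta1 * m[i] + (1.0f - beta1) * g[i];         // first moment
    v[i] = beta2 * v[i] + (1.0f - beta2) * g[i] * g[i];  // second moment
    const float mhat = m[i] * norm1;                     // bias-corrected
    const float vhat = v[i] * norm2;
    w[i] -= alpha * mhat / (std::sqrt(vhat) + epsilon);
  }
}
```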
diff --git a/backprop/optimizer.h b/backprop/optimizer.h
index 157352b..90fa2c7 100644
--- a/backprop/optimizer.h
+++ b/backprop/optimizer.h
@@ -16,13 +16,21 @@
 #ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_OPTIMIZER_H_
 #define THIRD_PARTY_GEMMA_CPP_GEMMA_OPTIMIZER_H_
 
+#include <random>
+
 #include "gemma/common.h"
+#include "gemma/weights.h"
 #include "hwy/contrib/thread_pool/thread_pool.h"
 
 namespace gcpp {
 
-void UpdateWeights(Model model, const ByteStorageT& grad, float scale,
-                   ByteStorageT& weights, hwy::ThreadPool& pool);
+void RandInitWeights(Model model, const ByteStorageT& weights,
+                     hwy::ThreadPool& pool, std::mt19937& gen);
+
+void AdamUpdate(Model model, const ByteStorageT& grad, float alpha, float beta1,
+                float beta2, float epsilon, size_t t,
+                const ByteStorageT& weights, const ByteStorageT& grad_m,
+                const ByteStorageT& grad_v, hwy::ThreadPool& pool);
 
 }  // namespace gcpp
diff --git a/backprop/sampler.h b/backprop/sampler.h
index 257b993..d65be21 100644
--- a/backprop/sampler.h
+++ b/backprop/sampler.h
@@ -16,6 +16,7 @@
 #ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_SAMPLER_H_
 #define THIRD_PARTY_GEMMA_CPP_GEMMA_SAMPLER_H_
 
+#include <random>
 #include <vector>
 
 #include "backprop/prompt.h"
diff --git a/gemma/activations.h b/gemma/activations.h
index 6894d67..6d2bc22 100644
--- a/gemma/activations.h
+++ b/gemma/activations.h
@@ -73,7 +73,7 @@ class ActivationsWrapper {
 
  public:
   ActivationsWrapper()
-      : data_(WrappedT::Allocate()),
+      : data_(AllocateSizeof<WrappedT>()),
        activations_(*reinterpret_cast<WrappedT*>(data_.get())) {}
 
   const WrappedT& get() const { return activations_; }
diff --git a/gemma/common.h b/gemma/common.h
index e277095..e3c2650 100644
--- a/gemma/common.h
+++ b/gemma/common.h
@@ -125,6 +125,13 @@ GEMMA_CONSTEXPR_EMBSCALING float EmbeddingScaling() {
       Sqrt(static_cast<float>(TConfig::kModelDim))));
 }
 
+static HWY_INLINE GEMMA_CONSTEXPR_EMBSCALING float EmbeddingScaling(
+    size_t model_dim) {
+  // Round to bf16 to match Gemma's Embedder, which casts before mul.
+  return hwy::ConvertScalarTo<float>(hwy::ConvertScalarTo<hwy::bfloat16_t>(
+      Sqrt(static_cast<float>(model_dim))));
+}
+
 }  // namespace gcpp
 
 #endif  // THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_H_
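The bf16 round trip in the new `EmbeddingScaling(size_t)` overload is deliberate: the embedder casts the scale to bf16 before multiplying, so the runtime-dimension overload must round the same way to agree bit for bit with the compile-time template version above it. A small illustration (assuming only `hwy::ConvertScalarTo` from hwy/base.h, which the header already uses):

```cpp
#include <cmath>
#include <cstdio>

#include "hwy/base.h"

int main() {
  const int model_dim = 2048;
  const float exact = std::sqrt(static_cast<float>(model_dim));
  const float rounded = hwy::ConvertScalarTo<float>(
      hwy::ConvertScalarTo<hwy::bfloat16_t>(exact));
  // bf16 keeps an 8-bit significand, so 45.254834 rounds to 45.25.
  printf("exact = %f, after bf16 round trip = %f\n", exact, rounded);
  return 0;
}
```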
diff --git a/gemma/gemma.cc b/gemma/gemma.cc
index 3135b51..23d7922 100644
--- a/gemma/gemma.cc
+++ b/gemma/gemma.cc
@@ -154,6 +154,7 @@ KVCache KVCache::Create(Model type) {
 
 class GemmaTokenizer::Impl {
  public:
+  Impl() = default;
   explicit Impl(const Path& tokenizer_path) {
     PROFILER_ZONE("Startup.tokenizer");
     spp_ = std::make_unique<sentencepiece::SentencePieceProcessor>();
@@ -164,24 +165,24 @@ class GemmaTokenizer::Impl {
 
   bool Encode(const std::string& input,
               std::vector<std::string>* pieces) const {
-    return spp_->Encode(input, pieces).ok();
+    return spp_ && spp_->Encode(input, pieces).ok();
   }
 
   bool Encode(const std::string& input, std::vector<int>* pieces) const {
     if constexpr (kShowTokenization) {
-      bool is_ok = spp_->Encode(input, pieces).ok();
+      bool is_ok = spp_ && spp_->Encode(input, pieces).ok();
       for (int i = 0; i < static_cast<int>(pieces->size()); i++) {
         fprintf(stderr, "%3d: %d\n", i, (*pieces)[i]);
       }
       return is_ok;
     } else {
-      return spp_->Encode(input, pieces).ok();
+      return spp_ && spp_->Encode(input, pieces).ok();
     }
   }
 
   // Given a sequence of ids, decodes it into a detokenized output.
   bool Decode(const std::vector<int>& ids,
               std::string* detokenized) const {
-    return spp_->Decode(ids, detokenized).ok();
+    return spp_ && spp_->Decode(ids, detokenized).ok();
   }
 
  private:
@@ -192,6 +193,10 @@ GemmaTokenizer::GemmaTokenizer(const Path& tokenizer_path) {
   impl_ = std::make_unique<Impl>(tokenizer_path);
 }
 
+GemmaTokenizer::GemmaTokenizer() {
+  impl_ = std::make_unique<Impl>();
+}
+
 GemmaTokenizer::~GemmaTokenizer() = default;
 GemmaTokenizer::GemmaTokenizer(GemmaTokenizer&& other) = default;
 GemmaTokenizer& GemmaTokenizer::operator=(GemmaTokenizer&& other) = default;
@@ -942,7 +947,7 @@ Gemma::Gemma(const Path& tokenizer_path, const Path& weights, Model model_type,
 Gemma::Gemma(GemmaTokenizer&& tokenizer, Model model_type,
              hwy::ThreadPool& pool)
     : pool_(pool), tokenizer_(std::move(tokenizer)), model_type_(model_type) {
-  weights_u8_ = CallFunctorForModel<AllocateWeights>(model_type, pool);
+  weights_u8_ = CallFunctorForModel<AllocateWeightsF>(model_type, pool);
   prefill_u8_ = CallFunctorForModel(model_type);
   decode_u8_ = CallFunctorForModel(model_type);
 }
diff --git a/gemma/gemma.h b/gemma/gemma.h
index 2e7673a..e9fcb2f 100644
--- a/gemma/gemma.h
+++ b/gemma/gemma.h
@@ -48,7 +48,7 @@ constexpr int EOS_ID = 1;
 
 class GemmaTokenizer {
  public:
-  GemmaTokenizer() = default;  // for second Gemma ctor.
+  GemmaTokenizer();
   explicit GemmaTokenizer(const Path& tokenizer_path);
 
   // must come after definition of Impl
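With the new default constructor and the `spp_ &&` guards, a `GemmaTokenizer` that never loaded a tokenizer file now fails gracefully instead of dereferencing a null SentencePiece processor. A usage sketch (hypothetical calling code, assuming the public `Encode` mirrors the `Impl` methods shown above):

```cpp
#include <string>
#include <vector>

#include "gemma/gemma.h"

void TokenizeIfLoaded() {
  gcpp::GemmaTokenizer tokenizer;  // default-constructed: no model loaded
  std::vector<int> ids;
  if (!tokenizer.Encode("hello", &ids)) {
    // Previously this path would crash on the null spp_; now it simply
    // reports failure and the caller decides how to proceed.
  }
}
```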
diff --git a/gemma/weights.cc b/gemma/weights.cc
index cfb2b86..d1e5fd1 100644
--- a/gemma/weights.cc
+++ b/gemma/weights.cc
@@ -73,7 +73,7 @@ struct LoadRawWeightsT {
               checkpoint.path.c_str());
   }
 
-  ByteStorageT weights_u8 = AllocateWeights<TConfig>()(pool);
+  ByteStorageT weights_u8 = AllocateWeightsF<TConfig>()(pool);
   auto* weights = reinterpret_cast<WeightsF<TConfig>*>(weights_u8.get());
 
   size_t scale_pos = 0;
diff --git a/gemma/weights.h b/gemma/weights.h
index d6ebd62..678bcc6 100644
--- a/gemma/weights.h
+++ b/gemma/weights.h
@@ -242,21 +242,29 @@
 using WeightsT = hwy::If<kWeightsAreCompressed, CompressedWeights<TConfig>,
                          WeightsF<TConfig>>;
 
 // Call via CallFunctorForModel.
-template <typename TConfig>
+template <typename TConfig, typename TWeight>
 struct AllocateWeights {
   ByteStorageT operator()(hwy::ThreadPool& pool) const {
-    using TWeights = WeightsF<TConfig>;
+    using TWeights = Weights<TWeight, TConfig>;
     ByteStorageT weights_u8 = AllocateSizeof<TWeights>();
     TWeights* weights = reinterpret_cast<TWeights*>(weights_u8.get());
-    new (&weights->layer_ptrs) LayerPointers<float, TConfig>(pool);
+    new (&weights->layer_ptrs) LayerPointers<TWeight, TConfig>(pool);
     return weights_u8;
   }
 };
 
 template <typename TConfig>
+struct AllocateWeightsF {
+  ByteStorageT operator()(hwy::ThreadPool& pool) const {
+    return AllocateWeights<TConfig, float>()(pool);
+  }
+};
+
+template <typename TConfig, typename TWeight>
 struct ZeroInitWeights {
   void operator()(ByteStorageT& weights, hwy::ThreadPool& pool) const {
-    WeightsF<TConfig>& w = *reinterpret_cast<WeightsF<TConfig>*>(weights.get());
+    Weights<TWeight, TConfig>& w =
+        *reinterpret_cast<Weights<TWeight, TConfig>*>(weights.get());
     hwy::ZeroBytes(&w.embedder_input_embedding,
                    sizeof(w.embedder_input_embedding));
     hwy::ZeroBytes(&w.final_norm_scale, sizeof(w.final_norm_scale));
@@ -267,8 +275,16 @@ struct ZeroInitWeights {
 };
 
 template <typename TConfig>
+struct ZeroInitWeightsF {
+  void operator()(ByteStorageT& weights, hwy::ThreadPool& pool) const {
+    ZeroInitWeights<TConfig, float>()(weights, pool);
+  }
+};
+
+template <typename TConfig, typename TWeight>
 struct CopyWeights {
-  void operator()(WeightsF<TConfig>& dst, const WeightsF<TConfig>& src) const {
+  void operator()(Weights<TWeight, TConfig>& dst,
+                  const Weights<TWeight, TConfig>& src) const {
     hwy::CopyBytes(&src.embedder_input_embedding, &dst.embedder_input_embedding,
                    sizeof(src.embedder_input_embedding));
     hwy::CopyBytes(&src.final_norm_scale, &dst.final_norm_scale,
@@ -298,14 +314,14 @@ class WeightsWrapper {
  public:
   WeightsWrapper()
       : pool_(0),
-        data_(AllocateWeights(pool_)),
+        data_(AllocateWeights<TConfig, T>()(pool_)),
         weights_(reinterpret_cast<Weights<T, TConfig>*>(data_.get())) {}
 
   const Weights<T, TConfig>& get() const { return *weights_; }
   Weights<T, TConfig>& get() { return *weights_; }
-  void clear() { ZeroInitWeights()(get()); }
+  void clear() { ZeroInitWeights<TConfig, T>()(data_, pool_); }
   void copy(const WeightsWrapper& other) {
-    CopyWeights()(get(), other.get());
+    CopyWeights<TConfig, T>()(get(), other.get());
   }
 
  private:
@@ -397,7 +413,7 @@ void ForEachTensor(const WeightsF<TConfig>* weights,
 #define GEMMA_CALL_TOP_FUNC3(name, member) \
   func(name, weights1.member, weights2.member, weights3.member)
 #define GEMMA_CALL_TOP_FUNC4(name, member) \
-  func(name, weights1.member, weights2.memeber, \
+  func(name, weights1.member, weights2.member, \
        weights3.member, weights4.member)
 
 #define GEMMA_CALL_LAYER_FUNC1(name, member) \
@@ -414,7 +430,7 @@
 #define GEMMA_CALL_LAYER_FUNC4(name, member) \
   snprintf(name_buf, sizeof(name_buf), name "_%d", layer_idx); \
-  func(name_buf, layer1.member, layer2.member, layer4.member)
+  func(name_buf, layer1.member, layer2.member, layer3.member, layer4.member)
 
 #define GEMMA_CALL_ALL_LAYER_FUNC(N) \
   if (type == LayerAttentionType::kGemma) { \
@@ -491,6 +507,25 @@ void ForEachTensor2(Func& func, const Weights<float, TConfig>& weights1,
   }
 }
 
+template <typename TConfig, typename Func>
+void ForEachTensor4(Func& func, const Weights<float, TConfig>& weights1,
+                    Weights<float, TConfig>& weights2,
+                    Weights<float, TConfig>& weights3,
+                    Weights<float, TConfig>& weights4) {
+  GEMMA_CALL_TOP_FUNC4("embedding", embedder_input_embedding);
+  GEMMA_CALL_TOP_FUNC4("final_norm", final_norm_scale);
+  char name_buf[16];
+  for (int layer_idx = 0; layer_idx < TConfig::kLayers; ++layer_idx) {
+    auto type = TConfig::kLayerConfig[layer_idx];
+    const size_t idx = static_cast<size_t>(layer_idx);
+    const LayerF<TConfig>& layer1 = *weights1.GetLayer(idx);
+    LayerF<TConfig>& layer2 = *weights2.GetLayer(idx);
+    LayerF<TConfig>& layer3 = *weights3.GetLayer(idx);
+    LayerF<TConfig>& layer4 = *weights4.GetLayer(idx);
+    GEMMA_CALL_ALL_LAYER_FUNC(4)
+  }
+}
+
 #undef GEMMA_CALL_TOP_FUNC1
 #undef GEMMA_CALL_TOP_FUNC2
 #undef GEMMA_CALL_TOP_FUNC3
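For intuition about the `ForEachTensor4` plumbing: `GEMMA_CALL_TOP_FUNC4` simply fans a single functor call out across the four parallel weight structs (gradient, weights, and the two Adam moment accumulators), which is exactly the call shape `AdamUpdater::operator()` expects. A toy expansion with stand-in types (illustrative, not repository code):

```cpp
#include <array>
#include <cstddef>
#include <cstdio>

// Stand-in for the per-model Weights struct; only one tensor here.
struct ToyWeights {
  std::array<float, 4> embedding;
};

// Same call shape as AdamUpdater: (name, grad, weights, grad_m, grad_v).
struct PrintSize {
  template <std::size_t N>
  void operator()(const char* name, const std::array<float, N>& grad,
                  std::array<float, N>& weights, std::array<float, N>& grad_m,
                  std::array<float, N>& grad_v) {
    printf("%s: %zu elements\n", name, N);
  }
};

int main() {
  ToyWeights grad{}, weights{}, m{}, v{};
  PrintSize func;
  // GEMMA_CALL_TOP_FUNC4("embedding", embedding) expands to:
  func("embedding", grad.embedding, weights.embedding, m.embedding,
       v.embedding);
  return 0;
}
```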