Add Adam optimizer.

Drive-by: Fix compilation errors and tests for backprop functions.
This commit is contained in:
Zoltan Szabadka 2024-06-06 18:41:36 +00:00
parent 12707ade80
commit c004799cdc
14 changed files with 260 additions and 48 deletions

View File

@ -50,6 +50,7 @@ set(SOURCES
backprop/backward.h backprop/backward.h
backprop/backward-inl.h backprop/backward-inl.h
backprop/backward_scalar.h backprop/backward_scalar.h
backprop/common_scalar.h
backprop/forward.cc backprop/forward.cc
backprop/forward.h backprop/forward.h
backprop/forward-inl.h backprop/forward-inl.h

View File

@ -23,6 +23,7 @@
#include <complex> #include <complex>
#include <vector> #include <vector>
#include "backprop/common_scalar.h"
#include "backprop/prompt.h" #include "backprop/prompt.h"
#include "gemma/activations.h" #include "gemma/activations.h"
#include "gemma/common.h" // EmbeddingScaling #include "gemma/common.h" // EmbeddingScaling

117
backprop/common_scalar.h Normal file
View File

@ -0,0 +1,117 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_SCALAR_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_SCALAR_H_
#include <complex>
namespace gcpp {
template<typename T, typename U>
U DotT(const T* a, const U* b, size_t N) {
U sum = {};
for (size_t i = 0; i < N; ++i) {
sum += a[i] * b[i];
}
return sum;
}
template<>
std::complex<double> DotT(const float* a, const std::complex<double>* b,
size_t N) {
std::complex<double> sum = {};
for (size_t i = 0; i < N; ++i) {
sum += static_cast<double>(a[i]) * b[i];
}
return sum;
}
template<typename T>
void MulByConstT(T c, T* x, size_t N) {
for (size_t i = 0; i < N; ++i) {
x[i] *= c;
}
}
// out += c * x
template<typename T>
void MulByConstAndAddT(T c, const T* x, T* out, size_t N) {
for (size_t i = 0; i < N; ++i) {
out[i] += c * x[i];
}
}
template<typename T, size_t N>
void MulByConstAndAddT(T c, const std::array<T, N>& x, std::array<T, N>& out) {
MulByConstAndAddT(c, x.data(), out.data(), N);
}
template<typename T>
void AddFromT(const T* a, T* out, size_t N) {
for (size_t i = 0; i < N; ++i) {
out[i] += a[i];
}
}
template<typename T>
T SquaredL2(const T* x, size_t N) {
T sum = {};
for (size_t i = 0; i < N; ++i) {
sum += x[i] * x[i];
}
return sum;
}
template<typename T>
T Gelu(T x) {
static const T kMul = 0.044715;
static const T kSqrt2OverPi = 0.797884560804236;
const T x3 = x * x * x;
const T arg = kSqrt2OverPi * (kMul * x3 + x);
const T cdf = T(0.5) * (T(1.0) + std::tanh(arg));
return x * cdf;
}
template<typename T, typename U>
void Rope(T* x, U base, size_t N, int i) {
const size_t N2 = N / 2;
for (size_t dim = 0; dim < N2; ++dim) {
const T freq_exponents = T(2 * dim) / T(N);
const T timescale = std::pow(base, freq_exponents);
const T theta = T(i) / timescale;
const T cos_val = std::cos(theta);
const T sin_val = std::sin(theta);
const T x0 = x[dim];
const T x1 = x[dim + N2];
x[dim] = x0 * cos_val - x1 * sin_val;
x[dim + N2] = x0 * sin_val + x1 * cos_val;
}
}
template<typename T>
void Rope(T* x, size_t N, int i) {
Rope(x, T(10000.0), N, i);
}
template<typename T>
void Rope(std::complex<T>* x, size_t N, int i) {
Rope(x, T(10000.0), N, i);
}
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_SCALAR_H_

View File

@ -23,6 +23,7 @@
#include <complex> #include <complex>
#include <vector> #include <vector>
#include "backprop/common_scalar.h"
#include "backprop/prompt.h" #include "backprop/prompt.h"
#include "gemma/activations.h" #include "gemma/activations.h"
#include "gemma/common.h" // EmbeddingScaling #include "gemma/common.h" // EmbeddingScaling

View File

@ -16,7 +16,6 @@
#include <iostream> #include <iostream>
#include <string> #include <string>
#include "gtest/gtest.h"
#include "backprop/backward.h" #include "backprop/backward.h"
#include "backprop/forward.h" #include "backprop/forward.h"
#include "backprop/optimizer.h" #include "backprop/optimizer.h"
@ -24,6 +23,7 @@
#include "gemma/activations.h" #include "gemma/activations.h"
#include "gemma/gemma.h" #include "gemma/gemma.h"
#include "gemma/weights.h" #include "gemma/weights.h"
#include "gtest/gtest.h"
namespace gcpp { namespace gcpp {
@ -32,9 +32,9 @@ TEST(OptimizeTest, GradientDescent) {
std::mt19937 gen(42); std::mt19937 gen(42);
Model model_type = Model::GEMMA_TINY; Model model_type = Model::GEMMA_TINY;
ByteStorageT grad = CallFunctorForModel<AllocateWeights>(model_type, pool); ByteStorageT grad = CallFunctorForModel<AllocateWeightsF>(model_type, pool);
ByteStorageT grad_m = CallFunctorForModel<AllocateWeights>(model_type, pool); ByteStorageT grad_m = CallFunctorForModel<AllocateWeightsF>(model_type, pool);
ByteStorageT grad_v = CallFunctorForModel<AllocateWeights>(model_type, pool); ByteStorageT grad_v = CallFunctorForModel<AllocateWeightsF>(model_type, pool);
ByteStorageT forward = CallFunctorForModel<AllocateForwardPass>(model_type); ByteStorageT forward = CallFunctorForModel<AllocateForwardPass>(model_type);
ByteStorageT backward = CallFunctorForModel<AllocateForwardPass>(model_type); ByteStorageT backward = CallFunctorForModel<AllocateForwardPass>(model_type);
KVCache kv_cache = KVCache::Create(model_type); KVCache kv_cache = KVCache::Create(model_type);
@ -57,7 +57,7 @@ TEST(OptimizeTest, GradientDescent) {
stream_token, accept_token, ReverseSequenceSampler::kEndToken, stream_token, accept_token, ReverseSequenceSampler::kEndToken,
}; };
TimingInfo timing_info; TimingInfo timing_info;
model.Generate(runtime, prompt, 0, kv_cache, timing_info); gemma.Generate(runtime, prompt, 0, kv_cache, timing_info);
return reply; return reply;
}; };
@ -73,15 +73,18 @@ TEST(OptimizeTest, GradientDescent) {
return ok; return ok;
}; };
CallFunctorForModel<RandInitWeights>(model_type, gemma.Weights(), pool, gen); RandInitWeights(model_type, gemma.Weights(), pool, gen);
CallFunctorForModel<ZeroInitWeights>(model_type, grad_m, pool); CallFunctorForModel<ZeroInitWeightsF>(model_type, grad_m, pool);
CallFunctorForModel<ZeroInitWeights>(model_type, grad_v, pool); CallFunctorForModel<ZeroInitWeightsF>(model_type, grad_v, pool);
printf("Initial weights:\n"); printf("Initial weights:\n");
LogWeightStats(model_type, gemma.Weights()); LogWeightStats(model_type, gemma.Weights());
constexpr size_t kBatchSize = 8; constexpr size_t kBatchSize = 8;
float learning_rate = 0.0005f; const float alpha = 0.001f;
const float beta1 = 0.9f;
const float beta2 = 0.999f;
const float epsilon = 1e-8f;
ReverseSequenceSampler training_task({ ReverseSequenceSampler training_task({
0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1}); 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1});
@ -90,7 +93,7 @@ TEST(OptimizeTest, GradientDescent) {
size_t num_ok; size_t num_ok;
for (; steps < 1000000; ++steps) { for (; steps < 1000000; ++steps) {
std::mt19937 sgen(42); std::mt19937 sgen(42);
CallFunctorForModel<ZeroInitWeights>(model_type, grad, pool); CallFunctorForModel<ZeroInitWeightsF>(model_type, grad, pool);
float total_loss = 0.0f; float total_loss = 0.0f;
num_ok = 0; num_ok = 0;
for (size_t i = 0; i < kBatchSize; ++i) { for (size_t i = 0; i < kBatchSize; ++i) {
@ -103,8 +106,8 @@ TEST(OptimizeTest, GradientDescent) {
} }
total_loss /= kBatchSize; total_loss /= kBatchSize;
const float scale = -learning_rate / kBatchSize; AdamUpdate(model_type, grad, alpha, beta1, beta2, epsilon, steps + 1,
UpdateWeights(model_type, grad, scale, gemma.Weights(), pool); gemma.Weights(), grad_m, grad_v, pool);
printf("step: %zu total_loss: %.15f num_ok: %zu/%zu\n", printf("step: %zu total_loss: %.15f num_ok: %zu/%zu\n",
steps, total_loss, num_ok, kBatchSize); steps, total_loss, num_ok, kBatchSize);
if (steps % 100 == 0) { if (steps % 100 == 0) {
@ -119,7 +122,7 @@ TEST(OptimizeTest, GradientDescent) {
printf("Num steps: %zu\n", steps); printf("Num steps: %zu\n", steps);
printf("Final weights:\n"); printf("Final weights:\n");
LogWeightStats(model_type, gemma.Weights()); LogWeightStats(model_type, gemma.Weights());
EXPECT_LT(steps, 3000); EXPECT_LT(steps, 200);
EXPECT_EQ(num_ok, kBatchSize); EXPECT_EQ(num_ok, kBatchSize);
} }

View File

@ -25,6 +25,7 @@
namespace gcpp { namespace gcpp {
namespace { namespace {
class WeightInitializer { class WeightInitializer {
public: public:
WeightInitializer(std::mt19937& gen) : dist_(0.0f, 1.0f), gen_(gen) {} WeightInitializer(std::mt19937& gen) : dist_(0.0f, 1.0f), gen_(gen) {}
@ -41,8 +42,8 @@ class WeightInitializer {
}; };
template <typename TConfig> template <typename TConfig>
struct RandInitWeights { struct RandInitWeightsT {
void operator()(ByteStorageT& weights_u8, hwy::ThreadPool& pool, void operator()(const ByteStorageT& weights_u8, hwy::ThreadPool& pool,
std::mt19937& gen) const { std::mt19937& gen) const {
auto& weights = *reinterpret_cast<WeightsF<TConfig>*>(weights_u8.get()); auto& weights = *reinterpret_cast<WeightsF<TConfig>*>(weights_u8.get());
// TODO(szabadka) Use the same weight initialization method as in the python // TODO(szabadka) Use the same weight initialization method as in the python
@ -52,39 +53,71 @@ struct RandInitWeights {
} }
}; };
class WeightUpdater { class AdamUpdater {
public: public:
explicit WeightUpdater(float lr) : lr_(lr) {} explicit AdamUpdater(float alpha, float beta1, float beta2, float epsilon,
size_t t)
: alpha_(alpha), beta1_(beta1), beta2_(beta2), cbeta1_(1.0f - beta1),
cbeta2_(1.0f - beta2), norm1_(1.0 / (1.0 - std::pow(beta1, t))),
norm2_(1.0 / (1.0 - std::pow(beta2, t))), epsilon_(epsilon) {}
template <size_t kCapacity> template <size_t kCapacity>
void operator()(const char* name, const std::array<float, kCapacity>& grad, void operator()(const char* name, const std::array<float, kCapacity>& grad,
std::array<float, kCapacity>& weights) { std::array<float, kCapacity>& weights,
std::array<float, kCapacity>& grad_m,
std::array<float, kCapacity>& grad_v) {
for (size_t i = 0; i < kCapacity; ++i) { for (size_t i = 0; i < kCapacity; ++i) {
weights[i] += lr_ * grad[i]; grad_m[i] *= beta1_;
grad_m[i] += cbeta1_ * grad[i];
grad_v[i] *= beta2_;
grad_v[i] += cbeta2_ * grad[i] * grad[i];
const float mhat = grad_m[i] * norm1_;
const float vhat = grad_v[i] * norm2_;
weights[i] -= alpha_ * mhat / (std::sqrt(vhat) + epsilon_);
} }
} }
private: private:
float lr_; float alpha_;
float beta1_;
float beta2_;
float cbeta1_;
float cbeta2_;
float norm1_;
float norm2_;
float epsilon_;
}; };
template <typename TConfig> template <typename TConfig>
struct UpdateWeightsT { struct AdamUpdateT {
void operator()(const ByteStorageT& grad_u8, float scale, void operator()(const ByteStorageT& grad_u8, float alpha, float beta1,
ByteStorageT& weights_u8, hwy::ThreadPool& pool) const { float beta2, float epsilon, size_t t,
const ByteStorageT& weights_u8, const ByteStorageT& grad_m_u8,
const ByteStorageT& grad_v_u8, hwy::ThreadPool& pool) const {
const auto& grad = const auto& grad =
*reinterpret_cast<const WeightsF<TConfig>*>(grad_u8.get()); *reinterpret_cast<const WeightsF<TConfig>*>(grad_u8.get());
auto& weights = *reinterpret_cast<WeightsF<TConfig>*>(weights_u8.get()); auto& weights = *reinterpret_cast<WeightsF<TConfig>*>(weights_u8.get());
WeightUpdater updater(scale); auto& grad_m = *reinterpret_cast<WeightsF<TConfig>*>(grad_m_u8.get());
ForEachTensor2<float, TConfig>(updater, grad, weights); auto& grad_v = *reinterpret_cast<WeightsF<TConfig>*>(grad_v_u8.get());
AdamUpdater updater(alpha, beta1, beta2, epsilon, t);
ForEachTensor4<float, TConfig>(updater, grad, weights, grad_m, grad_v);
} }
}; };
} // namespace } // namespace
void UpdateWeights(Model model, const ByteStorageT& grad, float scale, void RandInitWeights(Model model, const ByteStorageT& weights,
ByteStorageT& weights, hwy::ThreadPool& pool) { hwy::ThreadPool& pool,
CallFunctorForModel<UpdateWeightsT>(model, grad, scale, weights, pool); std::mt19937& gen) {
CallFunctorForModel<RandInitWeightsT>(model, weights, pool, gen);
}
void AdamUpdate(Model model, const ByteStorageT& grad, float alpha, float beta1,
float beta2, float epsilon, size_t t,
const ByteStorageT& weights, const ByteStorageT& grad_m,
const ByteStorageT& grad_v, hwy::ThreadPool& pool) {
CallFunctorForModel<AdamUpdateT>(model, grad, alpha, beta1, beta2, epsilon, t,
weights, grad_m, grad_v, pool);
} }
} // namespace gcpp } // namespace gcpp

View File

@ -16,13 +16,21 @@
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_OPTIMIZER_H_ #ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_OPTIMIZER_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_OPTIMIZER_H_ #define THIRD_PARTY_GEMMA_CPP_GEMMA_OPTIMIZER_H_
#include <random>
#include "gemma/common.h" #include "gemma/common.h"
#include "gemma/weights.h"
#include "hwy/contrib/thread_pool/thread_pool.h" #include "hwy/contrib/thread_pool/thread_pool.h"
namespace gcpp { namespace gcpp {
void UpdateWeights(Model model, const ByteStorageT& grad, float scale, void RandInitWeights(Model model, const ByteStorageT& weights,
ByteStorageT& weights, hwy::ThreadPool& pool); hwy::ThreadPool& pool, std::mt19937& gen);
void AdamUpdate(Model model, const ByteStorageT& grad, float alpha, float beta1,
float beta2, float epsilon, size_t t,
const ByteStorageT& weights, const ByteStorageT& grad_m,
const ByteStorageT& grad_v, hwy::ThreadPool& pool);
} // namespace gcpp } // namespace gcpp

View File

@ -16,6 +16,7 @@
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_SAMPLER_H_ #ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_SAMPLER_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_SAMPLER_H_ #define THIRD_PARTY_GEMMA_CPP_GEMMA_SAMPLER_H_
#include <random>
#include <vector> #include <vector>
#include "backprop/prompt.h" #include "backprop/prompt.h"

View File

@ -73,7 +73,7 @@ class ActivationsWrapper {
public: public:
ActivationsWrapper() ActivationsWrapper()
: data_(WrappedT::Allocate()), : data_(AllocateSizeof<WrappedT>()),
activations_(*reinterpret_cast<WrappedT*>(data_.get())) {} activations_(*reinterpret_cast<WrappedT*>(data_.get())) {}
const WrappedT& get() const { return activations_; } const WrappedT& get() const { return activations_; }

View File

@ -125,6 +125,13 @@ GEMMA_CONSTEXPR_EMBSCALING float EmbeddingScaling() {
Sqrt(static_cast<float>(TConfig::kModelDim)))); Sqrt(static_cast<float>(TConfig::kModelDim))));
} }
static HWY_INLINE GEMMA_CONSTEXPR_EMBSCALING float EmbeddingScaling(
size_t model_dim) {
// Round to bf16 to match Gemma's Embedder, which casts before mul.
return hwy::ConvertScalarTo<float>(hwy::ConvertScalarTo<hwy::bfloat16_t>(
Sqrt(static_cast<float>(model_dim))));
}
} // namespace gcpp } // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_H_ #endif // THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_H_

View File

@ -154,6 +154,7 @@ KVCache KVCache::Create(Model type) {
class GemmaTokenizer::Impl { class GemmaTokenizer::Impl {
public: public:
Impl() = default;
explicit Impl(const Path& tokenizer_path) { explicit Impl(const Path& tokenizer_path) {
PROFILER_ZONE("Startup.tokenizer"); PROFILER_ZONE("Startup.tokenizer");
spp_ = std::make_unique<sentencepiece::SentencePieceProcessor>(); spp_ = std::make_unique<sentencepiece::SentencePieceProcessor>();
@ -164,24 +165,24 @@ class GemmaTokenizer::Impl {
bool Encode(const std::string& input, bool Encode(const std::string& input,
std::vector<std::string>* pieces) const { std::vector<std::string>* pieces) const {
return spp_->Encode(input, pieces).ok(); return spp_ && spp_->Encode(input, pieces).ok();
} }
bool Encode(const std::string& input, std::vector<int>* pieces) const { bool Encode(const std::string& input, std::vector<int>* pieces) const {
if constexpr (kShowTokenization) { if constexpr (kShowTokenization) {
bool is_ok = spp_->Encode(input, pieces).ok(); bool is_ok = spp_ && spp_->Encode(input, pieces).ok();
for (int i = 0; i < static_cast<int>(pieces->size()); i++) { for (int i = 0; i < static_cast<int>(pieces->size()); i++) {
fprintf(stderr, "%3d: %d\n", i, (*pieces)[i]); fprintf(stderr, "%3d: %d\n", i, (*pieces)[i]);
} }
return is_ok; return is_ok;
} else { } else {
return spp_->Encode(input, pieces).ok(); return spp_ && spp_->Encode(input, pieces).ok();
} }
} }
// Given a sequence of ids, decodes it into a detokenized output. // Given a sequence of ids, decodes it into a detokenized output.
bool Decode(const std::vector<int>& ids, std::string* detokenized) const { bool Decode(const std::vector<int>& ids, std::string* detokenized) const {
return spp_->Decode(ids, detokenized).ok(); return spp_ && spp_->Decode(ids, detokenized).ok();
} }
private: private:
@ -192,6 +193,10 @@ GemmaTokenizer::GemmaTokenizer(const Path& tokenizer_path) {
impl_ = std::make_unique<Impl>(tokenizer_path); impl_ = std::make_unique<Impl>(tokenizer_path);
} }
GemmaTokenizer::GemmaTokenizer() {
impl_ = std::make_unique<Impl>();
}
GemmaTokenizer::~GemmaTokenizer() = default; GemmaTokenizer::~GemmaTokenizer() = default;
GemmaTokenizer::GemmaTokenizer(GemmaTokenizer&& other) = default; GemmaTokenizer::GemmaTokenizer(GemmaTokenizer&& other) = default;
GemmaTokenizer& GemmaTokenizer::operator=(GemmaTokenizer&& other) = default; GemmaTokenizer& GemmaTokenizer::operator=(GemmaTokenizer&& other) = default;
@ -942,7 +947,7 @@ Gemma::Gemma(const Path& tokenizer_path, const Path& weights, Model model_type,
Gemma::Gemma(GemmaTokenizer&& tokenizer, Model model_type, Gemma::Gemma(GemmaTokenizer&& tokenizer, Model model_type,
hwy::ThreadPool& pool) hwy::ThreadPool& pool)
: pool_(pool), tokenizer_(std::move(tokenizer)), model_type_(model_type) { : pool_(pool), tokenizer_(std::move(tokenizer)), model_type_(model_type) {
weights_u8_ = CallFunctorForModel<AllocateWeights>(model_type, pool); weights_u8_ = CallFunctorForModel<AllocateWeightsF>(model_type, pool);
prefill_u8_ = CallFunctorForModel<AllocatePrefill>(model_type); prefill_u8_ = CallFunctorForModel<AllocatePrefill>(model_type);
decode_u8_ = CallFunctorForModel<AllocateDecode>(model_type); decode_u8_ = CallFunctorForModel<AllocateDecode>(model_type);
} }

View File

@ -48,7 +48,7 @@ constexpr int EOS_ID = 1;
class GemmaTokenizer { class GemmaTokenizer {
public: public:
GemmaTokenizer() = default; // for second Gemma ctor. GemmaTokenizer();
explicit GemmaTokenizer(const Path& tokenizer_path); explicit GemmaTokenizer(const Path& tokenizer_path);
// must come after definition of Impl // must come after definition of Impl

View File

@ -73,7 +73,7 @@ struct LoadRawWeightsT {
checkpoint.path.c_str()); checkpoint.path.c_str());
} }
ByteStorageT weights_u8 = AllocateWeights<TConfig>()(pool); ByteStorageT weights_u8 = AllocateWeightsF<TConfig>()(pool);
auto* weights = reinterpret_cast<WeightsF<TConfig>*>(weights_u8.get()); auto* weights = reinterpret_cast<WeightsF<TConfig>*>(weights_u8.get());
size_t scale_pos = 0; size_t scale_pos = 0;

View File

@ -242,21 +242,29 @@ using WeightsT = hwy::If<kWeightsAreCompressed, CompressedWeights<TConfig>,
WeightsF<TConfig>>; WeightsF<TConfig>>;
// Call via CallFunctorForModel. // Call via CallFunctorForModel.
template <typename TConfig> template <typename T, typename TConfig>
struct AllocateWeights { struct AllocateWeights {
ByteStorageT operator()(hwy::ThreadPool& pool) const { ByteStorageT operator()(hwy::ThreadPool& pool) const {
using TWeights = WeightsF<TConfig>; using TWeights = Weights<T, TConfig>;
ByteStorageT weights_u8 = AllocateSizeof<TWeights>(); ByteStorageT weights_u8 = AllocateSizeof<TWeights>();
TWeights* weights = reinterpret_cast<TWeights*>(weights_u8.get()); TWeights* weights = reinterpret_cast<TWeights*>(weights_u8.get());
new (&weights->layer_ptrs) LayerPointers<float, TConfig>(pool); new (&weights->layer_ptrs) LayerPointers<T, TConfig>(pool);
return weights_u8; return weights_u8;
} }
}; };
template <typename TConfig> template <typename TConfig>
struct AllocateWeightsF {
ByteStorageT operator()(hwy::ThreadPool& pool) const {
return AllocateWeights<float, TConfig>()(pool);
}
};
template <typename T, typename TConfig>
struct ZeroInitWeights { struct ZeroInitWeights {
void operator()(ByteStorageT& weights, hwy::ThreadPool& pool) const { void operator()(ByteStorageT& weights, hwy::ThreadPool& pool) const {
WeightsF<TConfig>& w = *reinterpret_cast<WeightsF<TConfig>*>(weights.get()); Weights<T, TConfig>& w =
*reinterpret_cast<Weights<T, TConfig>*>(weights.get());
hwy::ZeroBytes(&w.embedder_input_embedding, hwy::ZeroBytes(&w.embedder_input_embedding,
sizeof(w.embedder_input_embedding)); sizeof(w.embedder_input_embedding));
hwy::ZeroBytes(&w.final_norm_scale, sizeof(w.final_norm_scale)); hwy::ZeroBytes(&w.final_norm_scale, sizeof(w.final_norm_scale));
@ -267,8 +275,16 @@ struct ZeroInitWeights {
}; };
template <typename TConfig> template <typename TConfig>
struct ZeroInitWeightsF {
void operator()(ByteStorageT& weights, hwy::ThreadPool& pool) const {
ZeroInitWeights<float, TConfig>()(weights, pool);
}
};
template <typename T, typename TConfig>
struct CopyWeights { struct CopyWeights {
void operator()(WeightsF<TConfig>& dst, const WeightsF<TConfig>& src) const { void operator()(Weights<T, TConfig>& dst,
const Weights<T, TConfig>& src) const {
hwy::CopyBytes(&src.embedder_input_embedding, &dst.embedder_input_embedding, hwy::CopyBytes(&src.embedder_input_embedding, &dst.embedder_input_embedding,
sizeof(src.embedder_input_embedding)); sizeof(src.embedder_input_embedding));
hwy::CopyBytes(&src.final_norm_scale, &dst.final_norm_scale, hwy::CopyBytes(&src.final_norm_scale, &dst.final_norm_scale,
@ -298,14 +314,14 @@ class WeightsWrapper {
public: public:
WeightsWrapper() WeightsWrapper()
: pool_(0), : pool_(0),
data_(AllocateWeights<TConfig>(pool_)), data_(AllocateWeights<T, TConfig>()(pool_)),
weights_(reinterpret_cast<Weights<T, TConfig>*>(data_.get())) {} weights_(reinterpret_cast<Weights<T, TConfig>*>(data_.get())) {}
const Weights<T, TConfig>& get() const { return *weights_; } const Weights<T, TConfig>& get() const { return *weights_; }
Weights<T, TConfig>& get() { return *weights_; } Weights<T, TConfig>& get() { return *weights_; }
void clear() { ZeroInitWeights<TConfig>()(get()); } void clear() { ZeroInitWeights<T, TConfig>()(data_, pool_); }
void copy(const WeightsWrapper<T, TConfig>& other) { void copy(const WeightsWrapper<T, TConfig>& other) {
CopyWeights<TConfig>()(get(), other.get()); CopyWeights<T, TConfig>()(get(), other.get());
} }
private: private:
@ -397,7 +413,7 @@ void ForEachTensor(const WeightsF<TConfig>* weights,
#define GEMMA_CALL_TOP_FUNC3(name, member) \ #define GEMMA_CALL_TOP_FUNC3(name, member) \
func(name, weights1.member, weights2.member, weights3.member) func(name, weights1.member, weights2.member, weights3.member)
#define GEMMA_CALL_TOP_FUNC4(name, member) \ #define GEMMA_CALL_TOP_FUNC4(name, member) \
func(name, weights1.member, weights2.memeber, \ func(name, weights1.member, weights2.member, \
weights3.member, weights4.member) weights3.member, weights4.member)
#define GEMMA_CALL_LAYER_FUNC1(name, member) \ #define GEMMA_CALL_LAYER_FUNC1(name, member) \
@ -414,7 +430,7 @@ void ForEachTensor(const WeightsF<TConfig>* weights,
#define GEMMA_CALL_LAYER_FUNC4(name, member) \ #define GEMMA_CALL_LAYER_FUNC4(name, member) \
snprintf(name_buf, sizeof(name_buf), name "_%d", layer_idx); \ snprintf(name_buf, sizeof(name_buf), name "_%d", layer_idx); \
func(name_buf, layer1.member, layer2.member, layer4.member) func(name_buf, layer1.member, layer2.member, layer3.member, layer4.member)
#define GEMMA_CALL_ALL_LAYER_FUNC(N) \ #define GEMMA_CALL_ALL_LAYER_FUNC(N) \
if (type == LayerAttentionType::kGemma) { \ if (type == LayerAttentionType::kGemma) { \
@ -491,6 +507,25 @@ void ForEachTensor2(Func& func, const Weights<T, TConfig>& weights1,
} }
} }
template <typename T, typename TConfig, class Func>
void ForEachTensor4(Func& func, const Weights<T, TConfig>& weights1,
Weights<T, TConfig>& weights2,
Weights<T, TConfig>& weights3,
Weights<T, TConfig>& weights4) {
GEMMA_CALL_TOP_FUNC4("embedding", embedder_input_embedding);
GEMMA_CALL_TOP_FUNC4("final_norm", final_norm_scale);
char name_buf[16];
for (int layer_idx = 0; layer_idx < TConfig::kLayers; ++layer_idx) {
auto type = TConfig::kLayerConfig[layer_idx];
const size_t idx = static_cast<size_t>(layer_idx);
const LayerF<TConfig>& layer1 = *weights1.GetLayer(idx);
LayerF<TConfig>& layer2 = *weights2.GetLayer(idx);
LayerF<TConfig>& layer3 = *weights3.GetLayer(idx);
LayerF<TConfig>& layer4 = *weights4.GetLayer(idx);
GEMMA_CALL_ALL_LAYER_FUNC(4)
}
}
#undef GEMMA_CALL_TOP_FUNC1 #undef GEMMA_CALL_TOP_FUNC1
#undef GEMMA_CALL_TOP_FUNC2 #undef GEMMA_CALL_TOP_FUNC2
#undef GEMMA_CALL_TOP_FUNC3 #undef GEMMA_CALL_TOP_FUNC3