mirror of https://github.com/google/gemma.cpp.git
Add Adam optimizer.
Drive-by: Fix compilation errors and tests for backprop functions.

parent 12707ade80
commit c004799cdc
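
For reference, the AdamUpdater introduced in this commit implements the standard Adam update (Kingma & Ba, 2015): grad_m and grad_v hold the per-parameter first and second moment estimates, and norm1_/norm2_ are the bias-correction factors 1/(1 - beta1^t) and 1/(1 - beta2^t). Per element, with gradient g at step t:

    m     <- beta1 * m + (1 - beta1) * g
    v     <- beta2 * v + (1 - beta2) * g^2
    m_hat  = m / (1 - beta1^t)
    v_hat  = v / (1 - beta2^t)
    w     <- w - alpha * m_hat / (sqrt(v_hat) + epsilon)
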
@@ -50,6 +50,7 @@ set(SOURCES
   backprop/backward.h
   backprop/backward-inl.h
   backprop/backward_scalar.h
+  backprop/common_scalar.h
   backprop/forward.cc
   backprop/forward.h
   backprop/forward-inl.h
@@ -23,6 +23,7 @@
 #include <complex>
 #include <vector>

+#include "backprop/common_scalar.h"
 #include "backprop/prompt.h"
 #include "gemma/activations.h"
 #include "gemma/common.h"  // EmbeddingScaling
@@ -0,0 +1,117 @@
+// Copyright 2024 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_SCALAR_H_
+#define THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_SCALAR_H_
+
+#include <complex>
+
+namespace gcpp {
+
+template<typename T, typename U>
+U DotT(const T* a, const U* b, size_t N) {
+  U sum = {};
+  for (size_t i = 0; i < N; ++i) {
+    sum += a[i] * b[i];
+  }
+  return sum;
+}
+
+template<>
+std::complex<double> DotT(const float* a, const std::complex<double>* b,
+                          size_t N) {
+  std::complex<double> sum = {};
+  for (size_t i = 0; i < N; ++i) {
+    sum += static_cast<double>(a[i]) * b[i];
+  }
+  return sum;
+}
+
+template<typename T>
+void MulByConstT(T c, T* x, size_t N) {
+  for (size_t i = 0; i < N; ++i) {
+    x[i] *= c;
+  }
+}
+
+// out += c * x
+template<typename T>
+void MulByConstAndAddT(T c, const T* x, T* out, size_t N) {
+  for (size_t i = 0; i < N; ++i) {
+    out[i] += c * x[i];
+  }
+}
+
+template<typename T, size_t N>
+void MulByConstAndAddT(T c, const std::array<T, N>& x, std::array<T, N>& out) {
+  MulByConstAndAddT(c, x.data(), out.data(), N);
+}
+
+template<typename T>
+void AddFromT(const T* a, T* out, size_t N) {
+  for (size_t i = 0; i < N; ++i) {
+    out[i] += a[i];
+  }
+}
+
+template<typename T>
+T SquaredL2(const T* x, size_t N) {
+  T sum = {};
+  for (size_t i = 0; i < N; ++i) {
+    sum += x[i] * x[i];
+  }
+  return sum;
+}
+
+template<typename T>
+T Gelu(T x) {
+  static const T kMul = 0.044715;
+  static const T kSqrt2OverPi = 0.797884560804236;
+
+  const T x3 = x * x * x;
+  const T arg = kSqrt2OverPi * (kMul * x3 + x);
+  const T cdf = T(0.5) * (T(1.0) + std::tanh(arg));
+  return x * cdf;
+}
+
+template<typename T, typename U>
+void Rope(T* x, U base, size_t N, int i) {
+  const size_t N2 = N / 2;
+  for (size_t dim = 0; dim < N2; ++dim) {
+    const T freq_exponents = T(2 * dim) / T(N);
+    const T timescale = std::pow(base, freq_exponents);
+    const T theta = T(i) / timescale;
+    const T cos_val = std::cos(theta);
+    const T sin_val = std::sin(theta);
+    const T x0 = x[dim];
+    const T x1 = x[dim + N2];
+    x[dim] = x0 * cos_val - x1 * sin_val;
+    x[dim + N2] = x0 * sin_val + x1 * cos_val;
+  }
+}
+
+template<typename T>
+void Rope(T* x, size_t N, int i) {
+  Rope(x, T(10000.0), N, i);
+}
+
+template<typename T>
+void Rope(std::complex<T>* x, size_t N, int i) {
+  Rope(x, T(10000.0), N, i);
+}
+
+}  // namespace gcpp
+
+#endif  // THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_SCALAR_H_
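
A minimal usage sketch of the new scalar helpers (illustrative only, not part of the commit; assumes the header compiles standalone via its transitive includes):

    #include <cstdio>
    #include <vector>
    #include "backprop/common_scalar.h"

    int main() {
      std::vector<float> a = {1.0f, 2.0f, 3.0f, 4.0f};
      std::vector<float> b = {0.5f, 0.5f, 0.5f, 0.5f};
      // Dot product and squared L2 norm of small vectors.
      const float dot = gcpp::DotT(a.data(), b.data(), a.size());
      const float l2 = gcpp::SquaredL2(a.data(), a.size());
      // b += 2 * a, then apply the rotary position encoding in-place at pos 3.
      gcpp::MulByConstAndAddT(2.0f, a.data(), b.data(), a.size());
      gcpp::Rope(b.data(), b.size(), /*i=*/3);
      std::printf("dot=%g l2=%g gelu(1)=%g\n", dot, l2, gcpp::Gelu(1.0f));
      return 0;
    }
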
@@ -23,6 +23,7 @@
 #include <complex>
 #include <vector>

+#include "backprop/common_scalar.h"
 #include "backprop/prompt.h"
 #include "gemma/activations.h"
 #include "gemma/common.h"  // EmbeddingScaling
@@ -16,7 +16,6 @@
 #include <iostream>
 #include <string>

-#include "gtest/gtest.h"
 #include "backprop/backward.h"
 #include "backprop/forward.h"
 #include "backprop/optimizer.h"
@@ -24,6 +23,7 @@
 #include "gemma/activations.h"
 #include "gemma/gemma.h"
 #include "gemma/weights.h"
+#include "gtest/gtest.h"

 namespace gcpp {

@@ -32,9 +32,9 @@ TEST(OptimizeTest, GradientDescent) {
   std::mt19937 gen(42);

   Model model_type = Model::GEMMA_TINY;
-  ByteStorageT grad = CallFunctorForModel<AllocateWeights>(model_type, pool);
-  ByteStorageT grad_m = CallFunctorForModel<AllocateWeights>(model_type, pool);
-  ByteStorageT grad_v = CallFunctorForModel<AllocateWeights>(model_type, pool);
+  ByteStorageT grad = CallFunctorForModel<AllocateWeightsF>(model_type, pool);
+  ByteStorageT grad_m = CallFunctorForModel<AllocateWeightsF>(model_type, pool);
+  ByteStorageT grad_v = CallFunctorForModel<AllocateWeightsF>(model_type, pool);
   ByteStorageT forward = CallFunctorForModel<AllocateForwardPass>(model_type);
   ByteStorageT backward = CallFunctorForModel<AllocateForwardPass>(model_type);
   KVCache kv_cache = KVCache::Create(model_type);
@@ -57,7 +57,7 @@ TEST(OptimizeTest, GradientDescent) {
       stream_token, accept_token, ReverseSequenceSampler::kEndToken,
   };
   TimingInfo timing_info;
-  model.Generate(runtime, prompt, 0, kv_cache, timing_info);
+  gemma.Generate(runtime, prompt, 0, kv_cache, timing_info);
   return reply;
 };

@@ -73,15 +73,18 @@ TEST(OptimizeTest, GradientDescent) {
     return ok;
   };

-  CallFunctorForModel<RandInitWeights>(model_type, gemma.Weights(), pool, gen);
-  CallFunctorForModel<ZeroInitWeights>(model_type, grad_m, pool);
-  CallFunctorForModel<ZeroInitWeights>(model_type, grad_v, pool);
+  RandInitWeights(model_type, gemma.Weights(), pool, gen);
+  CallFunctorForModel<ZeroInitWeightsF>(model_type, grad_m, pool);
+  CallFunctorForModel<ZeroInitWeightsF>(model_type, grad_v, pool);

   printf("Initial weights:\n");
   LogWeightStats(model_type, gemma.Weights());

   constexpr size_t kBatchSize = 8;
-  float learning_rate = 0.0005f;
+  const float alpha = 0.001f;
+  const float beta1 = 0.9f;
+  const float beta2 = 0.999f;
+  const float epsilon = 1e-8f;

   ReverseSequenceSampler training_task({
       0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1});
@@ -90,7 +93,7 @@ TEST(OptimizeTest, GradientDescent) {
   size_t num_ok;
   for (; steps < 1000000; ++steps) {
     std::mt19937 sgen(42);
-    CallFunctorForModel<ZeroInitWeights>(model_type, grad, pool);
+    CallFunctorForModel<ZeroInitWeightsF>(model_type, grad, pool);
     float total_loss = 0.0f;
     num_ok = 0;
     for (size_t i = 0; i < kBatchSize; ++i) {
@@ -103,8 +106,8 @@ TEST(OptimizeTest, GradientDescent) {
     }
     total_loss /= kBatchSize;

-    const float scale = -learning_rate / kBatchSize;
-    UpdateWeights(model_type, grad, scale, gemma.Weights(), pool);
+    AdamUpdate(model_type, grad, alpha, beta1, beta2, epsilon, steps + 1,
+               gemma.Weights(), grad_m, grad_v, pool);
     printf("step: %zu total_loss: %.15f num_ok: %zu/%zu\n",
            steps, total_loss, num_ok, kBatchSize);
     if (steps % 100 == 0) {
@@ -119,7 +122,7 @@ TEST(OptimizeTest, GradientDescent) {
   printf("Num steps: %zu\n", steps);
   printf("Final weights:\n");
   LogWeightStats(model_type, gemma.Weights());
-  EXPECT_LT(steps, 3000);
+  EXPECT_LT(steps, 200);
   EXPECT_EQ(num_ok, kBatchSize);
 }

@@ -25,6 +25,7 @@
 namespace gcpp {

+namespace {

 class WeightInitializer {
  public:
   WeightInitializer(std::mt19937& gen) : dist_(0.0f, 1.0f), gen_(gen) {}
@@ -41,8 +42,8 @@ class WeightInitializer {
 };

 template <typename TConfig>
-struct RandInitWeights {
-  void operator()(ByteStorageT& weights_u8, hwy::ThreadPool& pool,
+struct RandInitWeightsT {
+  void operator()(const ByteStorageT& weights_u8, hwy::ThreadPool& pool,
                   std::mt19937& gen) const {
     auto& weights = *reinterpret_cast<WeightsF<TConfig>*>(weights_u8.get());
     // TODO(szabadka) Use the same weight initialization method as in the python
@@ -52,39 +53,71 @@ struct RandInitWeights {
   }
 };

-class WeightUpdater {
+class AdamUpdater {
  public:
-  explicit WeightUpdater(float lr) : lr_(lr) {}
+  explicit AdamUpdater(float alpha, float beta1, float beta2, float epsilon,
+                       size_t t)
+      : alpha_(alpha), beta1_(beta1), beta2_(beta2), cbeta1_(1.0f - beta1),
+        cbeta2_(1.0f - beta2), norm1_(1.0 / (1.0 - std::pow(beta1, t))),
+        norm2_(1.0 / (1.0 - std::pow(beta2, t))), epsilon_(epsilon) {}

   template <size_t kCapacity>
   void operator()(const char* name, const std::array<float, kCapacity>& grad,
-                  std::array<float, kCapacity>& weights) {
+                  std::array<float, kCapacity>& weights,
+                  std::array<float, kCapacity>& grad_m,
+                  std::array<float, kCapacity>& grad_v) {
     for (size_t i = 0; i < kCapacity; ++i) {
-      weights[i] += lr_ * grad[i];
+      grad_m[i] *= beta1_;
+      grad_m[i] += cbeta1_ * grad[i];
+      grad_v[i] *= beta2_;
+      grad_v[i] += cbeta2_ * grad[i] * grad[i];
+      const float mhat = grad_m[i] * norm1_;
+      const float vhat = grad_v[i] * norm2_;
+      weights[i] -= alpha_ * mhat / (std::sqrt(vhat) + epsilon_);
     }
   }

  private:
-  float lr_;
+  float alpha_;
+  float beta1_;
+  float beta2_;
+  float cbeta1_;
+  float cbeta2_;
+  float norm1_;
+  float norm2_;
+  float epsilon_;
 };

 template <typename TConfig>
-struct UpdateWeightsT {
-  void operator()(const ByteStorageT& grad_u8, float scale,
-                  ByteStorageT& weights_u8, hwy::ThreadPool& pool) const {
+struct AdamUpdateT {
+  void operator()(const ByteStorageT& grad_u8, float alpha, float beta1,
+                  float beta2, float epsilon, size_t t,
+                  const ByteStorageT& weights_u8, const ByteStorageT& grad_m_u8,
+                  const ByteStorageT& grad_v_u8, hwy::ThreadPool& pool) const {
     const auto& grad =
         *reinterpret_cast<const WeightsF<TConfig>*>(grad_u8.get());
     auto& weights = *reinterpret_cast<WeightsF<TConfig>*>(weights_u8.get());
-    WeightUpdater updater(scale);
-    ForEachTensor2<float, TConfig>(updater, grad, weights);
+    auto& grad_m = *reinterpret_cast<WeightsF<TConfig>*>(grad_m_u8.get());
+    auto& grad_v = *reinterpret_cast<WeightsF<TConfig>*>(grad_v_u8.get());
+    AdamUpdater updater(alpha, beta1, beta2, epsilon, t);
+    ForEachTensor4<float, TConfig>(updater, grad, weights, grad_m, grad_v);
   }
 };

 } // namespace

-void UpdateWeights(Model model, const ByteStorageT& grad, float scale,
-                   ByteStorageT& weights, hwy::ThreadPool& pool) {
-  CallFunctorForModel<UpdateWeightsT>(model, grad, scale, weights, pool);
+void RandInitWeights(Model model, const ByteStorageT& weights,
+                     hwy::ThreadPool& pool,
+                     std::mt19937& gen) {
+  CallFunctorForModel<RandInitWeightsT>(model, weights, pool, gen);
+}
+
+void AdamUpdate(Model model, const ByteStorageT& grad, float alpha, float beta1,
+                float beta2, float epsilon, size_t t,
+                const ByteStorageT& weights, const ByteStorageT& grad_m,
+                const ByteStorageT& grad_v, hwy::ThreadPool& pool) {
+  CallFunctorForModel<AdamUpdateT>(model, grad, alpha, beta1, beta2, epsilon, t,
+                                   weights, grad_m, grad_v, pool);
+}

 } // namespace gcpp
@@ -16,13 +16,21 @@
 #ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_OPTIMIZER_H_
 #define THIRD_PARTY_GEMMA_CPP_GEMMA_OPTIMIZER_H_

+#include <random>
+
 #include "gemma/common.h"
 #include "gemma/weights.h"
 #include "hwy/contrib/thread_pool/thread_pool.h"

 namespace gcpp {

-void UpdateWeights(Model model, const ByteStorageT& grad, float scale,
-                   ByteStorageT& weights, hwy::ThreadPool& pool);
+void RandInitWeights(Model model, const ByteStorageT& weights,
+                     hwy::ThreadPool& pool, std::mt19937& gen);
+
+void AdamUpdate(Model model, const ByteStorageT& grad, float alpha, float beta1,
+                float beta2, float epsilon, size_t t,
+                const ByteStorageT& weights, const ByteStorageT& grad_m,
+                const ByteStorageT& grad_v, hwy::ThreadPool& pool);

 } // namespace gcpp

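
As a usage reference (illustrative only; it simply mirrors the call site in the optimizer test above), one training step with the declarations in this header looks like:

    // After accumulating gradients for the batch into `grad`:
    AdamUpdate(model_type, grad, /*alpha=*/0.001f, /*beta1=*/0.9f,
               /*beta2=*/0.999f, /*epsilon=*/1e-8f, /*t=*/steps + 1,
               gemma.Weights(), grad_m, grad_v, pool);
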
@@ -16,6 +16,7 @@
 #ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_SAMPLER_H_
 #define THIRD_PARTY_GEMMA_CPP_GEMMA_SAMPLER_H_

 #include <random>
+#include <vector>

 #include "backprop/prompt.h"
@@ -73,7 +73,7 @@ class ActivationsWrapper {

  public:
   ActivationsWrapper()
-      : data_(WrappedT::Allocate()),
+      : data_(AllocateSizeof<WrappedT>()),
         activations_(*reinterpret_cast<WrappedT*>(data_.get())) {}

   const WrappedT& get() const { return activations_; }
@@ -125,6 +125,13 @@ GEMMA_CONSTEXPR_EMBSCALING float EmbeddingScaling() {
       Sqrt(static_cast<float>(TConfig::kModelDim))));
 }

+static HWY_INLINE GEMMA_CONSTEXPR_EMBSCALING float EmbeddingScaling(
+    size_t model_dim) {
+  // Round to bf16 to match Gemma's Embedder, which casts before mul.
+  return hwy::ConvertScalarTo<float>(hwy::ConvertScalarTo<hwy::bfloat16_t>(
+      Sqrt(static_cast<float>(model_dim))));
+}
+
 } // namespace gcpp

 #endif  // THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_H_
@@ -154,6 +154,7 @@ KVCache KVCache::Create(Model type) {

 class GemmaTokenizer::Impl {
  public:
+  Impl() = default;
   explicit Impl(const Path& tokenizer_path) {
     PROFILER_ZONE("Startup.tokenizer");
     spp_ = std::make_unique<sentencepiece::SentencePieceProcessor>();
|
|||
|
||||
bool Encode(const std::string& input,
|
||||
std::vector<std::string>* pieces) const {
|
||||
return spp_->Encode(input, pieces).ok();
|
||||
return spp_ && spp_->Encode(input, pieces).ok();
|
||||
}
|
||||
|
||||
bool Encode(const std::string& input, std::vector<int>* pieces) const {
|
||||
if constexpr (kShowTokenization) {
|
||||
bool is_ok = spp_->Encode(input, pieces).ok();
|
||||
bool is_ok = spp_ && spp_->Encode(input, pieces).ok();
|
||||
for (int i = 0; i < static_cast<int>(pieces->size()); i++) {
|
||||
fprintf(stderr, "%3d: %d\n", i, (*pieces)[i]);
|
||||
}
|
||||
return is_ok;
|
||||
} else {
|
||||
return spp_->Encode(input, pieces).ok();
|
||||
return spp_ && spp_->Encode(input, pieces).ok();
|
||||
}
|
||||
}
|
||||
|
||||
// Given a sequence of ids, decodes it into a detokenized output.
|
||||
bool Decode(const std::vector<int>& ids, std::string* detokenized) const {
|
||||
return spp_->Decode(ids, detokenized).ok();
|
||||
return spp_ && spp_->Decode(ids, detokenized).ok();
|
||||
}
|
||||
|
||||
private:
|
||||
|
|
@@ -192,6 +193,10 @@ GemmaTokenizer::GemmaTokenizer(const Path& tokenizer_path) {
   impl_ = std::make_unique<Impl>(tokenizer_path);
 }

+GemmaTokenizer::GemmaTokenizer() {
+  impl_ = std::make_unique<Impl>();
+}
+
 GemmaTokenizer::~GemmaTokenizer() = default;
 GemmaTokenizer::GemmaTokenizer(GemmaTokenizer&& other) = default;
 GemmaTokenizer& GemmaTokenizer::operator=(GemmaTokenizer&& other) = default;
@@ -942,7 +947,7 @@ Gemma::Gemma(const Path& tokenizer_path, const Path& weights, Model model_type,
 Gemma::Gemma(GemmaTokenizer&& tokenizer, Model model_type,
              hwy::ThreadPool& pool)
     : pool_(pool), tokenizer_(std::move(tokenizer)), model_type_(model_type) {
-  weights_u8_ = CallFunctorForModel<AllocateWeights>(model_type, pool);
+  weights_u8_ = CallFunctorForModel<AllocateWeightsF>(model_type, pool);
   prefill_u8_ = CallFunctorForModel<AllocatePrefill>(model_type);
   decode_u8_ = CallFunctorForModel<AllocateDecode>(model_type);
 }
@@ -48,7 +48,7 @@ constexpr int EOS_ID = 1;

 class GemmaTokenizer {
  public:
-  GemmaTokenizer() = default; // for second Gemma ctor.
+  GemmaTokenizer();
   explicit GemmaTokenizer(const Path& tokenizer_path);

   // must come after definition of Impl
@@ -73,7 +73,7 @@ struct LoadRawWeightsT {
                checkpoint.path.c_str());
   }

-  ByteStorageT weights_u8 = AllocateWeights<TConfig>()(pool);
+  ByteStorageT weights_u8 = AllocateWeightsF<TConfig>()(pool);
   auto* weights = reinterpret_cast<WeightsF<TConfig>*>(weights_u8.get());

   size_t scale_pos = 0;
@@ -242,21 +242,29 @@ using WeightsT = hwy::If<kWeightsAreCompressed, CompressedWeights<TConfig>,
                          WeightsF<TConfig>>;

 // Call via CallFunctorForModel.
-template <typename TConfig>
+template <typename T, typename TConfig>
 struct AllocateWeights {
   ByteStorageT operator()(hwy::ThreadPool& pool) const {
-    using TWeights = WeightsF<TConfig>;
+    using TWeights = Weights<T, TConfig>;
     ByteStorageT weights_u8 = AllocateSizeof<TWeights>();
     TWeights* weights = reinterpret_cast<TWeights*>(weights_u8.get());
-    new (&weights->layer_ptrs) LayerPointers<float, TConfig>(pool);
+    new (&weights->layer_ptrs) LayerPointers<T, TConfig>(pool);
     return weights_u8;
   }
 };

 template <typename TConfig>
+struct AllocateWeightsF {
+  ByteStorageT operator()(hwy::ThreadPool& pool) const {
+    return AllocateWeights<float, TConfig>()(pool);
+  }
+};
+
+template <typename T, typename TConfig>
 struct ZeroInitWeights {
   void operator()(ByteStorageT& weights, hwy::ThreadPool& pool) const {
-    WeightsF<TConfig>& w = *reinterpret_cast<WeightsF<TConfig>*>(weights.get());
+    Weights<T, TConfig>& w =
+        *reinterpret_cast<Weights<T, TConfig>*>(weights.get());
     hwy::ZeroBytes(&w.embedder_input_embedding,
                    sizeof(w.embedder_input_embedding));
     hwy::ZeroBytes(&w.final_norm_scale, sizeof(w.final_norm_scale));
@@ -267,8 +275,16 @@ struct ZeroInitWeights {
 };

 template <typename TConfig>
+struct ZeroInitWeightsF {
+  void operator()(ByteStorageT& weights, hwy::ThreadPool& pool) const {
+    ZeroInitWeights<float, TConfig>()(weights, pool);
+  }
+};
+
+template <typename T, typename TConfig>
 struct CopyWeights {
-  void operator()(WeightsF<TConfig>& dst, const WeightsF<TConfig>& src) const {
+  void operator()(Weights<T, TConfig>& dst,
+                  const Weights<T, TConfig>& src) const {
     hwy::CopyBytes(&src.embedder_input_embedding, &dst.embedder_input_embedding,
                    sizeof(src.embedder_input_embedding));
     hwy::CopyBytes(&src.final_norm_scale, &dst.final_norm_scale,
@@ -298,14 +314,14 @@ class WeightsWrapper {
  public:
   WeightsWrapper()
       : pool_(0),
-        data_(AllocateWeights<TConfig>(pool_)),
+        data_(AllocateWeights<T, TConfig>()(pool_)),
         weights_(reinterpret_cast<Weights<T, TConfig>*>(data_.get())) {}

   const Weights<T, TConfig>& get() const { return *weights_; }
   Weights<T, TConfig>& get() { return *weights_; }
-  void clear() { ZeroInitWeights<TConfig>()(get()); }
+  void clear() { ZeroInitWeights<T, TConfig>()(data_, pool_); }
   void copy(const WeightsWrapper<T, TConfig>& other) {
-    CopyWeights<TConfig>()(get(), other.get());
+    CopyWeights<T, TConfig>()(get(), other.get());
   }

  private:
@@ -397,7 +413,7 @@ void ForEachTensor(const WeightsF<TConfig>* weights,
 #define GEMMA_CALL_TOP_FUNC3(name, member) \
   func(name, weights1.member, weights2.member, weights3.member)
 #define GEMMA_CALL_TOP_FUNC4(name, member) \
-  func(name, weights1.member, weights2.memeber, \
+  func(name, weights1.member, weights2.member, \
        weights3.member, weights4.member)

 #define GEMMA_CALL_LAYER_FUNC1(name, member) \
|
|||
|
||||
#define GEMMA_CALL_LAYER_FUNC4(name, member) \
|
||||
snprintf(name_buf, sizeof(name_buf), name "_%d", layer_idx); \
|
||||
func(name_buf, layer1.member, layer2.member, layer4.member)
|
||||
func(name_buf, layer1.member, layer2.member, layer3.member, layer4.member)
|
||||
|
||||
#define GEMMA_CALL_ALL_LAYER_FUNC(N) \
|
||||
if (type == LayerAttentionType::kGemma) { \
|
||||
|
|
@ -491,6 +507,25 @@ void ForEachTensor2(Func& func, const Weights<T, TConfig>& weights1,
|
|||
}
|
||||
}
|
||||
|
||||
template <typename T, typename TConfig, class Func>
|
||||
void ForEachTensor4(Func& func, const Weights<T, TConfig>& weights1,
|
||||
Weights<T, TConfig>& weights2,
|
||||
Weights<T, TConfig>& weights3,
|
||||
Weights<T, TConfig>& weights4) {
|
||||
GEMMA_CALL_TOP_FUNC4("embedding", embedder_input_embedding);
|
||||
GEMMA_CALL_TOP_FUNC4("final_norm", final_norm_scale);
|
||||
char name_buf[16];
|
||||
for (int layer_idx = 0; layer_idx < TConfig::kLayers; ++layer_idx) {
|
||||
auto type = TConfig::kLayerConfig[layer_idx];
|
||||
const size_t idx = static_cast<size_t>(layer_idx);
|
||||
const LayerF<TConfig>& layer1 = *weights1.GetLayer(idx);
|
||||
LayerF<TConfig>& layer2 = *weights2.GetLayer(idx);
|
||||
LayerF<TConfig>& layer3 = *weights3.GetLayer(idx);
|
||||
LayerF<TConfig>& layer4 = *weights4.GetLayer(idx);
|
||||
GEMMA_CALL_ALL_LAYER_FUNC(4)
|
||||
}
|
||||
}
|
||||
|
||||
#undef GEMMA_CALL_TOP_FUNC1
|
||||
#undef GEMMA_CALL_TOP_FUNC2
|
||||
#undef GEMMA_CALL_TOP_FUNC3
|
||||
|