mirror of https://github.com/google/gemma.cpp.git
Add Adam optimizer.
Drive-by: Fix compilation errors and tests for backprop functions.
parent 12707ade80
commit c004799cdc
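For reference, the AdamUpdater introduced below implements the standard Adam update with bias correction; in the code's naming (alpha, beta1, beta2, epsilon, step t, per-tensor accumulators grad_m and grad_v):

\begin{aligned}
m_t &= \beta_1 m_{t-1} + (1 - \beta_1)\, g_t \\
v_t &= \beta_2 v_{t-1} + (1 - \beta_2)\, g_t^2 \\
\hat{m}_t &= m_t / (1 - \beta_1^t), \qquad \hat{v}_t = v_t / (1 - \beta_2^t) \\
w_t &= w_{t-1} - \alpha\, \hat{m}_t / (\sqrt{\hat{v}_t} + \epsilon)
\end{aligned}

The correction factors 1/(1 - beta1^t) and 1/(1 - beta2^t) are what the updater precomputes as norm1_ and norm2_.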
@@ -50,6 +50,7 @@ set(SOURCES
   backprop/backward.h
   backprop/backward-inl.h
   backprop/backward_scalar.h
+  backprop/common_scalar.h
   backprop/forward.cc
   backprop/forward.h
   backprop/forward-inl.h
@@ -23,6 +23,7 @@
 #include <complex>
 #include <vector>
 
+#include "backprop/common_scalar.h"
 #include "backprop/prompt.h"
 #include "gemma/activations.h"
 #include "gemma/common.h"  // EmbeddingScaling
@@ -0,0 +1,117 @@
+// Copyright 2024 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_SCALAR_H_
+#define THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_SCALAR_H_
+
+#include <complex>
+
+namespace gcpp {
+
+template<typename T, typename U>
+U DotT(const T* a, const U* b, size_t N) {
+  U sum = {};
+  for (size_t i = 0; i < N; ++i) {
+    sum += a[i] * b[i];
+  }
+  return sum;
+}
+
+template<>
+std::complex<double> DotT(const float* a, const std::complex<double>* b,
+                          size_t N) {
+  std::complex<double> sum = {};
+  for (size_t i = 0; i < N; ++i) {
+    sum += static_cast<double>(a[i]) * b[i];
+  }
+  return sum;
+}
+
+template<typename T>
+void MulByConstT(T c, T* x, size_t N) {
+  for (size_t i = 0; i < N; ++i) {
+    x[i] *= c;
+  }
+}
+
+// out += c * x
+template<typename T>
+void MulByConstAndAddT(T c, const T* x, T* out, size_t N) {
+  for (size_t i = 0; i < N; ++i) {
+    out[i] += c * x[i];
+  }
+}
+
+template<typename T, size_t N>
+void MulByConstAndAddT(T c, const std::array<T, N>& x, std::array<T, N>& out) {
+  MulByConstAndAddT(c, x.data(), out.data(), N);
+}
+
+template<typename T>
+void AddFromT(const T* a, T* out, size_t N) {
+  for (size_t i = 0; i < N; ++i) {
+    out[i] += a[i];
+  }
+}
+
+template<typename T>
+T SquaredL2(const T* x, size_t N) {
+  T sum = {};
+  for (size_t i = 0; i < N; ++i) {
+    sum += x[i] * x[i];
+  }
+  return sum;
+}
+
+template<typename T>
+T Gelu(T x) {
+  static const T kMul = 0.044715;
+  static const T kSqrt2OverPi = 0.797884560804236;
+
+  const T x3 = x * x * x;
+  const T arg = kSqrt2OverPi * (kMul * x3 + x);
+  const T cdf = T(0.5) * (T(1.0) + std::tanh(arg));
+  return x * cdf;
+}
+
+template<typename T, typename U>
+void Rope(T* x, U base, size_t N, int i) {
+  const size_t N2 = N / 2;
+  for (size_t dim = 0; dim < N2; ++dim) {
+    const T freq_exponents = T(2 * dim) / T(N);
+    const T timescale = std::pow(base, freq_exponents);
+    const T theta = T(i) / timescale;
+    const T cos_val = std::cos(theta);
+    const T sin_val = std::sin(theta);
+    const T x0 = x[dim];
+    const T x1 = x[dim + N2];
+    x[dim] = x0 * cos_val - x1 * sin_val;
+    x[dim + N2] = x0 * sin_val + x1 * cos_val;
+  }
+}
+
+template<typename T>
+void Rope(T* x, size_t N, int i) {
+  Rope(x, T(10000.0), N, i);
+}
+
+template<typename T>
+void Rope(std::complex<T>* x, size_t N, int i) {
+  Rope(x, T(10000.0), N, i);
+}
+
+}  // namespace gcpp
+
+#endif  // THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_SCALAR_H_
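As a quick illustration of the new scalar helpers, here is a minimal usage sketch. It is not part of the commit; the array size, the values, and the extra standard includes are invented for the example.

#include <array>
#include <cmath>
#include <cstddef>
#include <cstdio>

#include "backprop/common_scalar.h"  // DotT, MulByConstAndAddT, Rope (added above)

int main() {
  std::array<float, 4> a = {1.0f, 2.0f, 3.0f, 4.0f};
  std::array<float, 4> b = {0.5f, 0.5f, 0.5f, 0.5f};

  // Scalar dot product: 0.5 * (1 + 2 + 3 + 4) = 5.
  const float dot = gcpp::DotT(a.data(), b.data(), a.size());

  // b += 2 * a, element-wise (the std::array overload forwards to the pointer one).
  gcpp::MulByConstAndAddT(2.0f, a, b);

  // Rotate the two halves of `a` as RoPE does for position i = 3,
  // using the default base of 10000.
  gcpp::Rope(a.data(), a.size(), /*i=*/3);

  std::printf("dot=%g b[0]=%g a[0]=%g\n", dot, b[0], a[0]);
  return 0;
}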
@@ -23,6 +23,7 @@
 #include <complex>
 #include <vector>
 
+#include "backprop/common_scalar.h"
 #include "backprop/prompt.h"
 #include "gemma/activations.h"
 #include "gemma/common.h"  // EmbeddingScaling
@@ -16,7 +16,6 @@
 #include <iostream>
 #include <string>
 
-#include "gtest/gtest.h"
 #include "backprop/backward.h"
 #include "backprop/forward.h"
 #include "backprop/optimizer.h"
@@ -24,6 +23,7 @@
 #include "gemma/activations.h"
 #include "gemma/gemma.h"
 #include "gemma/weights.h"
+#include "gtest/gtest.h"
 
 namespace gcpp {
@@ -32,9 +32,9 @@ TEST(OptimizeTest, GradientDescent) {
   std::mt19937 gen(42);
 
   Model model_type = Model::GEMMA_TINY;
-  ByteStorageT grad = CallFunctorForModel<AllocateWeights>(model_type, pool);
-  ByteStorageT grad_m = CallFunctorForModel<AllocateWeights>(model_type, pool);
-  ByteStorageT grad_v = CallFunctorForModel<AllocateWeights>(model_type, pool);
+  ByteStorageT grad = CallFunctorForModel<AllocateWeightsF>(model_type, pool);
+  ByteStorageT grad_m = CallFunctorForModel<AllocateWeightsF>(model_type, pool);
+  ByteStorageT grad_v = CallFunctorForModel<AllocateWeightsF>(model_type, pool);
   ByteStorageT forward = CallFunctorForModel<AllocateForwardPass>(model_type);
   ByteStorageT backward = CallFunctorForModel<AllocateForwardPass>(model_type);
   KVCache kv_cache = KVCache::Create(model_type);
@@ -57,7 +57,7 @@ TEST(OptimizeTest, GradientDescent) {
         stream_token, accept_token, ReverseSequenceSampler::kEndToken,
     };
     TimingInfo timing_info;
-    model.Generate(runtime, prompt, 0, kv_cache, timing_info);
+    gemma.Generate(runtime, prompt, 0, kv_cache, timing_info);
     return reply;
   };
@@ -73,15 +73,18 @@ TEST(OptimizeTest, GradientDescent) {
     return ok;
   };
 
-  CallFunctorForModel<RandInitWeights>(model_type, gemma.Weights(), pool, gen);
-  CallFunctorForModel<ZeroInitWeights>(model_type, grad_m, pool);
-  CallFunctorForModel<ZeroInitWeights>(model_type, grad_v, pool);
+  RandInitWeights(model_type, gemma.Weights(), pool, gen);
+  CallFunctorForModel<ZeroInitWeightsF>(model_type, grad_m, pool);
+  CallFunctorForModel<ZeroInitWeightsF>(model_type, grad_v, pool);
 
   printf("Initial weights:\n");
   LogWeightStats(model_type, gemma.Weights());
 
   constexpr size_t kBatchSize = 8;
-  float learning_rate = 0.0005f;
+  const float alpha = 0.001f;
+  const float beta1 = 0.9f;
+  const float beta2 = 0.999f;
+  const float epsilon = 1e-8f;
 
   ReverseSequenceSampler training_task({
       0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1});
@@ -90,7 +93,7 @@ TEST(OptimizeTest, GradientDescent) {
   size_t num_ok;
   for (; steps < 1000000; ++steps) {
     std::mt19937 sgen(42);
-    CallFunctorForModel<ZeroInitWeights>(model_type, grad, pool);
+    CallFunctorForModel<ZeroInitWeightsF>(model_type, grad, pool);
     float total_loss = 0.0f;
     num_ok = 0;
     for (size_t i = 0; i < kBatchSize; ++i) {
@@ -103,8 +106,8 @@ TEST(OptimizeTest, GradientDescent) {
     }
     total_loss /= kBatchSize;
 
-    const float scale = -learning_rate / kBatchSize;
-    UpdateWeights(model_type, grad, scale, gemma.Weights(), pool);
+    AdamUpdate(model_type, grad, alpha, beta1, beta2, epsilon, steps + 1,
+               gemma.Weights(), grad_m, grad_v, pool);
     printf("step: %zu total_loss: %.15f num_ok: %zu/%zu\n",
            steps, total_loss, num_ok, kBatchSize);
     if (steps % 100 == 0) {
@@ -119,7 +122,7 @@ TEST(OptimizeTest, GradientDescent) {
   printf("Num steps: %zu\n", steps);
   printf("Final weights:\n");
   LogWeightStats(model_type, gemma.Weights());
-  EXPECT_LT(steps, 3000);
+  EXPECT_LT(steps, 200);
   EXPECT_EQ(num_ok, kBatchSize);
 }
@@ -25,6 +25,7 @@
 namespace gcpp {
 
 namespace {
 
 class WeightInitializer {
  public:
   WeightInitializer(std::mt19937& gen) : dist_(0.0f, 1.0f), gen_(gen) {}
@@ -41,8 +42,8 @@ class WeightInitializer {
 };
 
 template <typename TConfig>
-struct RandInitWeights {
-  void operator()(ByteStorageT& weights_u8, hwy::ThreadPool& pool,
+struct RandInitWeightsT {
+  void operator()(const ByteStorageT& weights_u8, hwy::ThreadPool& pool,
                   std::mt19937& gen) const {
     auto& weights = *reinterpret_cast<WeightsF<TConfig>*>(weights_u8.get());
     // TODO(szabadka) Use the same weight initialization method as in the python
@@ -52,39 +53,71 @@ struct RandInitWeights {
   }
 };
 
-class WeightUpdater {
+class AdamUpdater {
  public:
-  explicit WeightUpdater(float lr) : lr_(lr) {}
+  explicit AdamUpdater(float alpha, float beta1, float beta2, float epsilon,
+                       size_t t)
+      : alpha_(alpha), beta1_(beta1), beta2_(beta2), cbeta1_(1.0f - beta1),
+        cbeta2_(1.0f - beta2), norm1_(1.0 / (1.0 - std::pow(beta1, t))),
+        norm2_(1.0 / (1.0 - std::pow(beta2, t))), epsilon_(epsilon) {}
 
   template <size_t kCapacity>
   void operator()(const char* name, const std::array<float, kCapacity>& grad,
-                  std::array<float, kCapacity>& weights) {
+                  std::array<float, kCapacity>& weights,
+                  std::array<float, kCapacity>& grad_m,
+                  std::array<float, kCapacity>& grad_v) {
     for (size_t i = 0; i < kCapacity; ++i) {
-      weights[i] += lr_ * grad[i];
+      grad_m[i] *= beta1_;
+      grad_m[i] += cbeta1_ * grad[i];
+      grad_v[i] *= beta2_;
+      grad_v[i] += cbeta2_ * grad[i] * grad[i];
+      const float mhat = grad_m[i] * norm1_;
+      const float vhat = grad_v[i] * norm2_;
+      weights[i] -= alpha_ * mhat / (std::sqrt(vhat) + epsilon_);
     }
   }
 
 private:
-  float lr_;
+  float alpha_;
+  float beta1_;
+  float beta2_;
+  float cbeta1_;
+  float cbeta2_;
+  float norm1_;
+  float norm2_;
+  float epsilon_;
 };
 
 template <typename TConfig>
-struct UpdateWeightsT {
-  void operator()(const ByteStorageT& grad_u8, float scale,
-                  ByteStorageT& weights_u8, hwy::ThreadPool& pool) const {
+struct AdamUpdateT {
+  void operator()(const ByteStorageT& grad_u8, float alpha, float beta1,
+                  float beta2, float epsilon, size_t t,
+                  const ByteStorageT& weights_u8, const ByteStorageT& grad_m_u8,
+                  const ByteStorageT& grad_v_u8, hwy::ThreadPool& pool) const {
     const auto& grad =
         *reinterpret_cast<const WeightsF<TConfig>*>(grad_u8.get());
     auto& weights = *reinterpret_cast<WeightsF<TConfig>*>(weights_u8.get());
-    WeightUpdater updater(scale);
-    ForEachTensor2<float, TConfig>(updater, grad, weights);
+    auto& grad_m = *reinterpret_cast<WeightsF<TConfig>*>(grad_m_u8.get());
+    auto& grad_v = *reinterpret_cast<WeightsF<TConfig>*>(grad_v_u8.get());
+    AdamUpdater updater(alpha, beta1, beta2, epsilon, t);
+    ForEachTensor4<float, TConfig>(updater, grad, weights, grad_m, grad_v);
   }
 };
 
 }  // namespace
 
-void UpdateWeights(Model model, const ByteStorageT& grad, float scale,
-                   ByteStorageT& weights, hwy::ThreadPool& pool) {
-  CallFunctorForModel<UpdateWeightsT>(model, grad, scale, weights, pool);
+void RandInitWeights(Model model, const ByteStorageT& weights,
+                     hwy::ThreadPool& pool,
+                     std::mt19937& gen) {
+  CallFunctorForModel<RandInitWeightsT>(model, weights, pool, gen);
+}
+
+void AdamUpdate(Model model, const ByteStorageT& grad, float alpha, float beta1,
+                float beta2, float epsilon, size_t t,
+                const ByteStorageT& weights, const ByteStorageT& grad_m,
+                const ByteStorageT& grad_v, hwy::ThreadPool& pool) {
+  CallFunctorForModel<AdamUpdateT>(model, grad, alpha, beta1, beta2, epsilon, t,
+                                   weights, grad_m, grad_v, pool);
 }
 
 }  // namespace gcpp
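To make the new update concrete, here is a minimal sketch of a single step on a toy tensor. It is illustrative only: AdamUpdater sits in an unnamed namespace in the optimizer translation unit, so this assumes the class definition above is visible in scope, and the gradient values are invented.

#include <array>
#include <cstdio>

// Assumes the AdamUpdater class defined above is visible here.
int main() {
  std::array<float, 4> weights = {};                       // parameters, start at zero
  std::array<float, 4> grad = {0.1f, -0.2f, 0.3f, -0.4f};  // toy gradient
  std::array<float, 4> grad_m = {};                        // first-moment accumulator
  std::array<float, 4> grad_v = {};                        // second-moment accumulator

  // One optimizer step (t = 1) with the hyper-parameters used in the test.
  AdamUpdater updater(/*alpha=*/0.001f, /*beta1=*/0.9f, /*beta2=*/0.999f,
                      /*epsilon=*/1e-8f, /*t=*/1);
  updater("toy", grad, weights, grad_m, grad_v);

  // At t = 1 the bias-corrected moments reduce to g and g^2, so each weight
  // moves by roughly -alpha * sign(g), i.e. about -0.001 for weights[0].
  std::printf("w[0] = %g\n", weights[0]);
  return 0;
}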
@@ -16,13 +16,21 @@
 #ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_OPTIMIZER_H_
 #define THIRD_PARTY_GEMMA_CPP_GEMMA_OPTIMIZER_H_
 
+#include <random>
+
 #include "gemma/common.h"
+#include "gemma/weights.h"
 #include "hwy/contrib/thread_pool/thread_pool.h"
 
 namespace gcpp {
 
-void UpdateWeights(Model model, const ByteStorageT& grad, float scale,
-                   ByteStorageT& weights, hwy::ThreadPool& pool);
+void RandInitWeights(Model model, const ByteStorageT& weights,
+                     hwy::ThreadPool& pool, std::mt19937& gen);
+
+void AdamUpdate(Model model, const ByteStorageT& grad, float alpha, float beta1,
+                float beta2, float epsilon, size_t t,
+                const ByteStorageT& weights, const ByteStorageT& grad_m,
+                const ByteStorageT& grad_v, hwy::ThreadPool& pool);
 
 }  // namespace gcpp
@@ -16,6 +16,7 @@
 #ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_SAMPLER_H_
 #define THIRD_PARTY_GEMMA_CPP_GEMMA_SAMPLER_H_
 
+#include <random>
 #include <vector>
 
 #include "backprop/prompt.h"
@@ -73,7 +73,7 @@ class ActivationsWrapper {
 
  public:
   ActivationsWrapper()
-      : data_(WrappedT::Allocate()),
+      : data_(AllocateSizeof<WrappedT>()),
        activations_(*reinterpret_cast<WrappedT*>(data_.get())) {}
 
   const WrappedT& get() const { return activations_; }
@@ -125,6 +125,13 @@ GEMMA_CONSTEXPR_EMBSCALING float EmbeddingScaling() {
       Sqrt(static_cast<float>(TConfig::kModelDim))));
 }
 
+static HWY_INLINE GEMMA_CONSTEXPR_EMBSCALING float EmbeddingScaling(
+    size_t model_dim) {
+  // Round to bf16 to match Gemma's Embedder, which casts before mul.
+  return hwy::ConvertScalarTo<float>(hwy::ConvertScalarTo<hwy::bfloat16_t>(
+      Sqrt(static_cast<float>(model_dim))));
+}
+
 }  // namespace gcpp
 
 #endif  // THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_H_
@@ -154,6 +154,7 @@ KVCache KVCache::Create(Model type) {
 
 class GemmaTokenizer::Impl {
  public:
+  Impl() = default;
   explicit Impl(const Path& tokenizer_path) {
     PROFILER_ZONE("Startup.tokenizer");
     spp_ = std::make_unique<sentencepiece::SentencePieceProcessor>();
@@ -164,24 +165,24 @@ class GemmaTokenizer::Impl {
 
   bool Encode(const std::string& input,
               std::vector<std::string>* pieces) const {
-    return spp_->Encode(input, pieces).ok();
+    return spp_ && spp_->Encode(input, pieces).ok();
   }
 
   bool Encode(const std::string& input, std::vector<int>* pieces) const {
     if constexpr (kShowTokenization) {
-      bool is_ok = spp_->Encode(input, pieces).ok();
+      bool is_ok = spp_ && spp_->Encode(input, pieces).ok();
       for (int i = 0; i < static_cast<int>(pieces->size()); i++) {
         fprintf(stderr, "%3d: %d\n", i, (*pieces)[i]);
       }
       return is_ok;
     } else {
-      return spp_->Encode(input, pieces).ok();
+      return spp_ && spp_->Encode(input, pieces).ok();
     }
   }
 
   // Given a sequence of ids, decodes it into a detokenized output.
   bool Decode(const std::vector<int>& ids, std::string* detokenized) const {
-    return spp_->Decode(ids, detokenized).ok();
+    return spp_ && spp_->Decode(ids, detokenized).ok();
   }
 
  private:
@@ -192,6 +193,10 @@ GemmaTokenizer::GemmaTokenizer(const Path& tokenizer_path) {
   impl_ = std::make_unique<Impl>(tokenizer_path);
 }
 
+GemmaTokenizer::GemmaTokenizer() {
+  impl_ = std::make_unique<Impl>();
+}
+
 GemmaTokenizer::~GemmaTokenizer() = default;
 GemmaTokenizer::GemmaTokenizer(GemmaTokenizer&& other) = default;
 GemmaTokenizer& GemmaTokenizer::operator=(GemmaTokenizer&& other) = default;
@@ -942,7 +947,7 @@ Gemma::Gemma(const Path& tokenizer_path, const Path& weights, Model model_type,
 Gemma::Gemma(GemmaTokenizer&& tokenizer, Model model_type,
              hwy::ThreadPool& pool)
     : pool_(pool), tokenizer_(std::move(tokenizer)), model_type_(model_type) {
-  weights_u8_ = CallFunctorForModel<AllocateWeights>(model_type, pool);
+  weights_u8_ = CallFunctorForModel<AllocateWeightsF>(model_type, pool);
   prefill_u8_ = CallFunctorForModel<AllocatePrefill>(model_type);
   decode_u8_ = CallFunctorForModel<AllocateDecode>(model_type);
 }
@@ -48,7 +48,7 @@ constexpr int EOS_ID = 1;
 
 class GemmaTokenizer {
  public:
-  GemmaTokenizer() = default;  // for second Gemma ctor.
+  GemmaTokenizer();
   explicit GemmaTokenizer(const Path& tokenizer_path);
 
   // must come after definition of Impl
@@ -73,7 +73,7 @@ struct LoadRawWeightsT {
               checkpoint.path.c_str());
   }
 
-  ByteStorageT weights_u8 = AllocateWeights<TConfig>()(pool);
+  ByteStorageT weights_u8 = AllocateWeightsF<TConfig>()(pool);
   auto* weights = reinterpret_cast<WeightsF<TConfig>*>(weights_u8.get());
 
   size_t scale_pos = 0;
@@ -242,21 +242,29 @@ using WeightsT = hwy::If<kWeightsAreCompressed, CompressedWeights<TConfig>,
                          WeightsF<TConfig>>;
 
 // Call via CallFunctorForModel.
-template <typename TConfig>
+template <typename T, typename TConfig>
 struct AllocateWeights {
   ByteStorageT operator()(hwy::ThreadPool& pool) const {
-    using TWeights = WeightsF<TConfig>;
+    using TWeights = Weights<T, TConfig>;
     ByteStorageT weights_u8 = AllocateSizeof<TWeights>();
     TWeights* weights = reinterpret_cast<TWeights*>(weights_u8.get());
-    new (&weights->layer_ptrs) LayerPointers<float, TConfig>(pool);
+    new (&weights->layer_ptrs) LayerPointers<T, TConfig>(pool);
     return weights_u8;
   }
 };
 
 template <typename TConfig>
+struct AllocateWeightsF {
+  ByteStorageT operator()(hwy::ThreadPool& pool) const {
+    return AllocateWeights<float, TConfig>()(pool);
+  }
+};
+
+template <typename T, typename TConfig>
 struct ZeroInitWeights {
   void operator()(ByteStorageT& weights, hwy::ThreadPool& pool) const {
-    WeightsF<TConfig>& w = *reinterpret_cast<WeightsF<TConfig>*>(weights.get());
+    Weights<T, TConfig>& w =
+        *reinterpret_cast<Weights<T, TConfig>*>(weights.get());
     hwy::ZeroBytes(&w.embedder_input_embedding,
                    sizeof(w.embedder_input_embedding));
     hwy::ZeroBytes(&w.final_norm_scale, sizeof(w.final_norm_scale));
@@ -267,8 +275,16 @@ struct ZeroInitWeights {
 };
 
 template <typename TConfig>
+struct ZeroInitWeightsF {
+  void operator()(ByteStorageT& weights, hwy::ThreadPool& pool) const {
+    ZeroInitWeights<float, TConfig>()(weights, pool);
+  }
+};
+
+template <typename T, typename TConfig>
 struct CopyWeights {
-  void operator()(WeightsF<TConfig>& dst, const WeightsF<TConfig>& src) const {
+  void operator()(Weights<T, TConfig>& dst,
+                  const Weights<T, TConfig>& src) const {
     hwy::CopyBytes(&src.embedder_input_embedding, &dst.embedder_input_embedding,
                    sizeof(src.embedder_input_embedding));
     hwy::CopyBytes(&src.final_norm_scale, &dst.final_norm_scale,
@@ -298,14 +314,14 @@ class WeightsWrapper {
  public:
   WeightsWrapper()
       : pool_(0),
-        data_(AllocateWeights<TConfig>(pool_)),
+        data_(AllocateWeights<T, TConfig>()(pool_)),
         weights_(reinterpret_cast<Weights<T, TConfig>*>(data_.get())) {}
 
   const Weights<T, TConfig>& get() const { return *weights_; }
   Weights<T, TConfig>& get() { return *weights_; }
-  void clear() { ZeroInitWeights<TConfig>()(get()); }
+  void clear() { ZeroInitWeights<T, TConfig>()(data_, pool_); }
   void copy(const WeightsWrapper<T, TConfig>& other) {
-    CopyWeights<TConfig>()(get(), other.get());
+    CopyWeights<T, TConfig>()(get(), other.get());
   }
 
  private:
@@ -397,7 +413,7 @@ void ForEachTensor(const WeightsF<TConfig>* weights,
 #define GEMMA_CALL_TOP_FUNC3(name, member) \
   func(name, weights1.member, weights2.member, weights3.member)
 #define GEMMA_CALL_TOP_FUNC4(name, member) \
-  func(name, weights1.member, weights2.memeber, \
+  func(name, weights1.member, weights2.member, \
        weights3.member, weights4.member)
 
 #define GEMMA_CALL_LAYER_FUNC1(name, member) \
@@ -414,7 +430,7 @@ void ForEachTensor(const WeightsF<TConfig>* weights,
 
 #define GEMMA_CALL_LAYER_FUNC4(name, member) \
   snprintf(name_buf, sizeof(name_buf), name "_%d", layer_idx); \
-  func(name_buf, layer1.member, layer2.member, layer4.member)
+  func(name_buf, layer1.member, layer2.member, layer3.member, layer4.member)
 
 #define GEMMA_CALL_ALL_LAYER_FUNC(N) \
   if (type == LayerAttentionType::kGemma) { \
@@ -491,6 +507,25 @@ void ForEachTensor2(Func& func, const Weights<T, TConfig>& weights1,
   }
 }
 
+template <typename T, typename TConfig, class Func>
+void ForEachTensor4(Func& func, const Weights<T, TConfig>& weights1,
+                    Weights<T, TConfig>& weights2,
+                    Weights<T, TConfig>& weights3,
+                    Weights<T, TConfig>& weights4) {
+  GEMMA_CALL_TOP_FUNC4("embedding", embedder_input_embedding);
+  GEMMA_CALL_TOP_FUNC4("final_norm", final_norm_scale);
+  char name_buf[16];
+  for (int layer_idx = 0; layer_idx < TConfig::kLayers; ++layer_idx) {
+    auto type = TConfig::kLayerConfig[layer_idx];
+    const size_t idx = static_cast<size_t>(layer_idx);
+    const LayerF<TConfig>& layer1 = *weights1.GetLayer(idx);
+    LayerF<TConfig>& layer2 = *weights2.GetLayer(idx);
+    LayerF<TConfig>& layer3 = *weights3.GetLayer(idx);
+    LayerF<TConfig>& layer4 = *weights4.GetLayer(idx);
+    GEMMA_CALL_ALL_LAYER_FUNC(4)
+  }
+}
+
 #undef GEMMA_CALL_TOP_FUNC1
 #undef GEMMA_CALL_TOP_FUNC2
 #undef GEMMA_CALL_TOP_FUNC3