mirror of https://github.com/google/gemma.cpp.git
Simplifications: remove GemmaInterface and GemmaImpl
Split common and weights into separate libs.
Remove common-inl (does not have to be SIMD code) and activations.cc.
Centralize switch(Model) to avoid duplication.
Move CompressWeightsT to compress_weights.cc.
Move LoadWeights to weights.cc.
PiperOrigin-RevId: 640869202
This commit is contained in:
parent
5c3e5f7038
commit
57c2cd8b52
42
BUILD.bazel
42
BUILD.bazel
|
|
@ -53,6 +53,36 @@ cc_test(
|
|||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "common",
|
||||
srcs = ["gemma/common.cc"],
|
||||
hdrs = [
|
||||
"gemma/common.h",
|
||||
"gemma/configs.h",
|
||||
],
|
||||
deps = [
|
||||
"//compression:compress",
|
||||
"//compression:io",
|
||||
"@hwy//:hwy", # base.h
|
||||
"@hwy//:thread_pool",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "weights",
|
||||
srcs = ["gemma/weights.cc"],
|
||||
hdrs = ["gemma/weights.h"],
|
||||
deps = [
|
||||
":common",
|
||||
"//compression:compress",
|
||||
"//compression:io",
|
||||
"@hwy//:hwy",
|
||||
"@hwy//:profiler",
|
||||
"@hwy//:stats",
|
||||
"@hwy//:thread_pool",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "gemma_lib",
|
||||
srcs = [
|
||||
|
|
@ -60,14 +90,12 @@ cc_library(
|
|||
],
|
||||
hdrs = [
|
||||
"gemma/activations.h",
|
||||
"gemma/common.h",
|
||||
"gemma/common-inl.h",
|
||||
"gemma/configs.h",
|
||||
"gemma/gemma.h",
|
||||
"gemma/weights.h",
|
||||
],
|
||||
deps = [
|
||||
":common",
|
||||
":ops",
|
||||
":weights",
|
||||
"//compression:compress",
|
||||
"//compression:io",
|
||||
"@hwy//:hwy",
|
||||
|
|
@ -93,6 +121,7 @@ cc_library(
|
|||
hdrs = ["util/app.h"],
|
||||
deps = [
|
||||
":args",
|
||||
":common",
|
||||
":gemma_lib",
|
||||
"//compression:io",
|
||||
"@hwy//:hwy",
|
||||
|
|
@ -126,6 +155,7 @@ cc_binary(
|
|||
deps = [
|
||||
":app",
|
||||
":args",
|
||||
":common",
|
||||
":gemma_lib",
|
||||
# "//base",
|
||||
"//compression:compress",
|
||||
|
|
@ -141,7 +171,9 @@ cc_binary(
|
|||
srcs = ["gemma/compress_weights.cc"],
|
||||
deps = [
|
||||
":args",
|
||||
":common",
|
||||
":gemma_lib",
|
||||
":weights",
|
||||
# "//base",
|
||||
"//compression:compress",
|
||||
"@hwy//:hwy",
|
||||
|
|
@ -157,7 +189,9 @@ cc_binary(
|
|||
deps = [
|
||||
":app",
|
||||
":args",
|
||||
":common",
|
||||
":gemma_lib",
|
||||
"//compression:io",
|
||||
"@hwy//:hwy",
|
||||
"@hwy//:nanobenchmark",
|
||||
"@hwy//:thread_pool",
|
||||
|
|
|
|||
|
|
@ -50,8 +50,6 @@ set(SOURCES
|
|||
backprop/backward.h
|
||||
backprop/backward-inl.h
|
||||
backprop/backward_scalar.h
|
||||
backprop/common_scalar.cc
|
||||
backprop/common_scalar.h
|
||||
backprop/forward.cc
|
||||
backprop/forward.h
|
||||
backprop/forward-inl.h
|
||||
|
|
@ -59,10 +57,9 @@ set(SOURCES
|
|||
backprop/optimizer.cc
|
||||
backprop/optimizer.h
|
||||
gemma/configs.h
|
||||
gemma/activations.cc
|
||||
gemma/activations.h
|
||||
gemma/common.cc
|
||||
gemma/common.h
|
||||
gemma/common-inl.h
|
||||
gemma/gemma.cc
|
||||
gemma/gemma.h
|
||||
gemma/ops.h
|
||||
|
|
|
|||
|
|
@ -151,9 +151,9 @@ file (usually `tokenizer.spm`), call `Encode()` to go from string prompts to
|
|||
token id vectors, or `Decode()` to go from token id vector outputs from the
|
||||
model back to strings.
|
||||
|
||||
### `GenerateGemma()` is the entrypoint for token generation
|
||||
### `model.Generate()` is the entrypoint for token generation
|
||||
|
||||
Calling into `GenerateGemma` with a tokenized prompt will 1) mutate the
|
||||
Calling into `model.Generate` with a tokenized prompt will 1) mutate the
|
||||
activation values in `model` and 2) invoke StreamFunc - a lambda callback for
|
||||
each generated token.
|
||||
|
||||
|
|
@ -170,11 +170,11 @@ no-op which is what `run.cc` does.
|
|||
|
||||
### `Transformer()` implements the inference (i.e. `forward()` method in PyTorch or Jax) computation of the neural network
|
||||
|
||||
For high-level applications, you might only call `GenerateGemma()` and never
|
||||
For high-level applications, you might only call `model.Generate()` and never
|
||||
interact directly with the neural network, but if you're doing something a bit
|
||||
more custom you can call transformer which performs a single inference
|
||||
operation on a single token and mutates the Activations and the KVCache through
|
||||
the neural network computation.
|
||||
more custom you can call `Transformer()`, which performs a single inference operation
|
||||
on a single token and mutates the Activations and the KVCache through the neural
|
||||
network computation.
|
||||
|
||||
### For low level operations, defining new architectures, call `ops.h` functions directly
|
||||
|
||||
|
|
|
|||
|
|
@ -28,8 +28,8 @@
|
|||
|
||||
#include "backprop/prompt.h"
|
||||
#include "gemma/activations.h"
|
||||
#include "gemma/common.h"
|
||||
#include "gemma/weights.h"
|
||||
|
||||
#include "hwy/base.h"
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
|
||||
|
|
@ -43,7 +43,6 @@
|
|||
#define THIRD_PARTY_GEMMA_CPP_BACKWARD_TOGGLE
|
||||
#endif
|
||||
|
||||
#include "gemma/common-inl.h"
|
||||
#include "gemma/ops.h"
|
||||
|
||||
HWY_BEFORE_NAMESPACE();
|
||||
|
|
|
|||
|
|
@ -54,6 +54,7 @@ void CrossEntropyLossBackwardPassT(Model model,
|
|||
ByteStorageT& grad,
|
||||
ByteStorageT& backward,
|
||||
hwy::ThreadPool& pool) {
|
||||
// TODO(janwas): use CallFunctorForModel
|
||||
switch (model) {
|
||||
case Model::GEMMA_2B:
|
||||
CrossEntropyLossBackwardPass<ConfigGemma2B>(
|
||||
|
|
|
|||
|
|
@ -23,9 +23,9 @@
|
|||
#include <complex>
|
||||
#include <vector>
|
||||
|
||||
#include "backprop/common_scalar.h"
|
||||
#include "backprop/prompt.h"
|
||||
#include "gemma/activations.h"
|
||||
#include "gemma/common.h" // EmbeddingScaling
|
||||
#include "gemma/weights.h"
|
||||
|
||||
namespace gcpp {
|
||||
|
|
|
|||
|
|
@ -19,10 +19,10 @@
|
|||
#include <complex>
|
||||
#include <random>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "backprop/forward_scalar.h"
|
||||
#include "backprop/sampler.h"
|
||||
#include "backprop/test_util.h"
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
|
|
|
|||
|
|
@ -30,11 +30,11 @@
|
|||
#include "backprop/sampler.h"
|
||||
#include "backprop/test_util.h"
|
||||
#include "compression/compress.h"
|
||||
#include "gemma/gemma.h"
|
||||
#include "gemma/weights.h"
|
||||
#include "hwy/aligned_allocator.h"
|
||||
#include "hwy/base.h"
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
#include "gemma/gemma.h"
|
||||
#include "gemma/weights.h"
|
||||
|
||||
// clang-format off
|
||||
#undef HWY_TARGET_INCLUDE
|
||||
|
|
|
|||
|
|
@ -1,51 +0,0 @@
|
|||
// Copyright 2024 Google LLC
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// https://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// Compiles this file for multiple architectures via "foreach_target.h", to
|
||||
// which we pass the filename via macro 'argument'.
|
||||
#undef HWY_TARGET_INCLUDE
|
||||
#define HWY_TARGET_INCLUDE "backprop/common_scalar.cc" // NOLINT
|
||||
#include "hwy/foreach_target.h" // IWYU pragma: keep
|
||||
|
||||
#include "hwy/highway.h"
|
||||
// After highway.h
|
||||
#include "gemma/ops.h"
|
||||
|
||||
HWY_BEFORE_NAMESPACE();
|
||||
namespace gcpp {
|
||||
namespace HWY_NAMESPACE {
|
||||
namespace hn = hwy::HWY_NAMESPACE;
|
||||
|
||||
float EmbeddingScaling(int model_dim) {
|
||||
// Round to bf16 to match Gemma's Embedder, which casts before mul.
|
||||
return hwy::ConvertScalarTo<float>(hwy::ConvertScalarTo<hwy::bfloat16_t>(
|
||||
Sqrt(static_cast<float>(model_dim))));
|
||||
}
|
||||
|
||||
} // namespace HWY_NAMESPACE
|
||||
} // namespace gcpp
|
||||
HWY_AFTER_NAMESPACE();
|
||||
|
||||
#if HWY_ONCE
|
||||
namespace gcpp {
|
||||
|
||||
HWY_EXPORT(EmbeddingScaling);
|
||||
|
||||
float EmbeddingScaling(int model_dim) {
|
||||
return HWY_DYNAMIC_DISPATCH(EmbeddingScaling)(model_dim);
|
||||
}
|
||||
|
||||
} // namespace gcpp
|
||||
#endif // HWY_ONCE
|
||||
|
|
@ -1,119 +0,0 @@
|
|||
// Copyright 2024 Google LLC
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// https://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_SCALAR_H_
|
||||
#define THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_SCALAR_H_
|
||||
|
||||
#include <complex>
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
// Returns the dot product sum(a[i] * b[i]) over i in [0, N), accumulated in
// and returned as U. Yields a value-initialized U when N == 0.
template <typename T, typename U>
U DotT(const T* a, const U* b, size_t N) {
  U sum = {};
  for (size_t i = 0; i < N; ++i) {
    sum += a[i] * b[i];
  }
  return sum;
}

// Specialization for float weights against complex<double> values: each float
// is widened to double before the multiply.
//
// Marked `inline` because an explicit specialization of a function template is
// an ordinary function definition; without it, including this header from more
// than one translation unit produces duplicate symbols at link time.
template <>
inline std::complex<double> DotT(const float* a, const std::complex<double>* b,
                                 size_t N) {
  std::complex<double> sum = {};
  for (size_t i = 0; i < N; ++i) {
    sum += static_cast<double>(a[i]) * b[i];
  }
  return sum;
}
|
||||
|
||||
// Scales the N elements starting at `x` in place: x[i] *= c.
template <typename T>
void MulByConstT(T c, T* x, size_t N) {
  for (size_t i = 0; i < N; ++i) {
    x[i] *= c;
  }
}
|
||||
|
||||
// out += c * x, element-wise over the first N elements of `x` and `out`.
template <typename T>
void MulByConstAndAddT(T c, const T* x, T* out, size_t N) {
  for (size_t i = 0; i < N; ++i) {
    out[i] += c * x[i];
  }
}

// Same as above for two equally sized std::array operands; forwards to the
// pointer overload with the arrays' compile-time length.
template <typename T, size_t N>
void MulByConstAndAddT(T c, const std::array<T, N>& x, std::array<T, N>& out) {
  MulByConstAndAddT(c, x.data(), out.data(), N);
}
|
||||
|
||||
// Accumulates `a` into `out` element-wise: out[i] += a[i] for i in [0, N).
template <typename T>
void AddFromT(const T* a, T* out, size_t N) {
  for (size_t i = 0; i < N; ++i) {
    out[i] += a[i];
  }
}
|
||||
|
||||
// Returns the sum of x[i] * x[i] over i in [0, N), i.e. the squared L2 norm
// for real T. Yields a value-initialized T when N == 0.
template <typename T>
T SquaredL2(const T* x, size_t N) {
  T sum = {};
  for (size_t i = 0; i < N; ++i) {
    sum += x[i] * x[i];
  }
  return sum;
}
|
||||
|
||||
// Tanh-based GELU approximation:
//   Gelu(x) = x * 0.5 * (1 + tanh(kSqrt2OverPi * (x + kMul * x^3))).
// T only needs to support the arithmetic below and std::tanh.
template <typename T>
T Gelu(T x) {
  // Plain locals rather than function-local statics: the values are identical
  // and no static-initialization guard is needed on each call.
  const T kMul = T(0.044715);
  const T kSqrt2OverPi = T(0.797884560804236);

  const T x3 = x * x * x;
  const T arg = kSqrt2OverPi * (kMul * x3 + x);
  const T cdf = T(0.5) * (T(1.0) + std::tanh(arg));
  return x * cdf;
}
|
||||
|
||||
// Applies a rotary position embedding to x[0..N) in place.
//
// The vector is treated as N/2 pairs (x[dim], x[dim + N/2]). Pair `dim` is
// rotated by the angle theta = i / base^(2 * dim / N), where `i` is the
// position. When N is odd, the final element is left untouched.
template <typename T, typename U>
void Rope(T* x, U base, size_t N, int i) {
  const size_t N2 = N / 2;
  for (size_t dim = 0; dim < N2; ++dim) {
    const T freq_exponents = T(2 * dim) / T(N);
    const T timescale = std::pow(base, freq_exponents);
    const T theta = T(i) / timescale;
    const T cos_val = std::cos(theta);
    const T sin_val = std::sin(theta);
    // Read both halves before writing: each output depends on both inputs.
    const T x0 = x[dim];
    const T x1 = x[dim + N2];
    x[dim] = x0 * cos_val - x1 * sin_val;
    x[dim + N2] = x0 * sin_val + x1 * cos_val;
  }
}

// Convenience overload using the fixed base 10000.
template <typename T>
void Rope(T* x, size_t N, int i) {
  Rope(x, T(10000.0), N, i);
}

// Overload for complex-valued vectors: the base stays a real scalar of the
// complex type's underlying value type.
template <typename T>
void Rope(std::complex<T>* x, size_t N, int i) {
  Rope(x, T(10000.0), N, i);
}
|
||||
|
||||
float EmbeddingScaling(int model_dim);
|
||||
|
||||
} // namespace gcpp
|
||||
|
||||
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_SCALAR_H_
|
||||
|
|
@ -24,8 +24,8 @@
|
|||
#include <cmath>
|
||||
|
||||
#include "gemma/activations.h"
|
||||
#include "gemma/common.h"
|
||||
#include "gemma/configs.h"
|
||||
|
||||
#include "hwy/base.h"
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
|
||||
|
|
@ -39,7 +39,6 @@
|
|||
#define THIRD_PARTY_GEMMA_CPP_FORWARD_TOGGLE
|
||||
#endif
|
||||
|
||||
#include "gemma/common-inl.h"
|
||||
#include "gemma/ops.h"
|
||||
|
||||
HWY_BEFORE_NAMESPACE();
|
||||
|
|
|
|||
|
|
@ -48,6 +48,7 @@ float CrossEntropyLossForwardPassT(Model model, const Prompt& prompt,
|
|||
const ByteStorageT& weights,
|
||||
ByteStorageT& forward,
|
||||
hwy::ThreadPool& pool) {
|
||||
// TODO(janwas): use CallFunctorForModel
|
||||
switch (model) {
|
||||
case Model::GEMMA_2B:
|
||||
return CrossEntropyLossForwardPass<ConfigGemma2B>(
|
||||
|
|
|
|||
|
|
@ -23,9 +23,9 @@
|
|||
#include <complex>
|
||||
#include <vector>
|
||||
|
||||
#include "backprop/common_scalar.h"
|
||||
#include "backprop/prompt.h"
|
||||
#include "gemma/activations.h"
|
||||
#include "gemma/common.h" // EmbeddingScaling
|
||||
#include "gemma/weights.h"
|
||||
|
||||
namespace gcpp {
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@
|
|||
#include <iostream>
|
||||
#include <string>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "backprop/backward.h"
|
||||
#include "backprop/forward.h"
|
||||
#include "backprop/optimizer.h"
|
||||
|
|
@ -23,7 +24,6 @@
|
|||
#include "gemma/activations.h"
|
||||
#include "gemma/gemma.h"
|
||||
#include "gemma/weights.h"
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
|
|
@ -32,20 +32,20 @@ TEST(OptimizeTest, GradientDescent) {
|
|||
std::mt19937 gen(42);
|
||||
|
||||
Model model_type = Model::GEMMA_TINY;
|
||||
ByteStorageT weights = AllocateWeights(model_type, pool);
|
||||
ByteStorageT grad = AllocateWeights(model_type, pool);
|
||||
ByteStorageT grad_m = AllocateWeights(model_type, pool);
|
||||
ByteStorageT grad_v = AllocateWeights(model_type, pool);
|
||||
ByteStorageT forward = AllocateForwardPass(model_type);
|
||||
ByteStorageT backward = AllocateForwardPass(model_type);
|
||||
ByteStorageT inference = AllocateInferenceState(model_type);
|
||||
auto kv_cache = CreateKVCache(model_type);
|
||||
ByteStorageT grad = CallFunctorForModel<AllocateWeights>(model_type, pool);
|
||||
ByteStorageT grad_m = CallFunctorForModel<AllocateWeights>(model_type, pool);
|
||||
ByteStorageT grad_v = CallFunctorForModel<AllocateWeights>(model_type, pool);
|
||||
ByteStorageT forward = CallFunctorForModel<AllocateForwardPass>(model_type);
|
||||
ByteStorageT backward = CallFunctorForModel<AllocateForwardPass>(model_type);
|
||||
KVCache kv_cache = KVCache::Create(model_type);
|
||||
size_t max_tokens = 32;
|
||||
size_t max_generated_tokens = 16;
|
||||
float temperature = 1.0f;
|
||||
int verbosity = 0;
|
||||
const auto accept_token = [](int) { return true; };
|
||||
|
||||
Gemma gemma(GemmaTokenizer(), model_type, pool);
|
||||
|
||||
const auto generate = [&](const std::vector<int>& prompt) {
|
||||
std::vector<int> reply;
|
||||
auto stream_token = [&reply](int token, float) {
|
||||
|
|
@ -57,8 +57,7 @@ TEST(OptimizeTest, GradientDescent) {
|
|||
stream_token, accept_token, ReverseSequenceSampler::kEndToken,
|
||||
};
|
||||
TimingInfo timing_info;
|
||||
GenerateGemma(model_type, weights, inference, runtime, prompt, 0,
|
||||
kv_cache, pool, timing_info);
|
||||
model.Generate(runtime, prompt, 0, kv_cache, timing_info);
|
||||
return reply;
|
||||
};
|
||||
|
||||
|
|
@ -74,12 +73,12 @@ TEST(OptimizeTest, GradientDescent) {
|
|||
return ok;
|
||||
};
|
||||
|
||||
RandInitWeights(model_type, weights, pool, gen);
|
||||
ZeroInitWeights(model_type, grad_m, pool);
|
||||
ZeroInitWeights(model_type, grad_v, pool);
|
||||
CallFunctorForModel<RandInitWeights>(model_type, gemma.Weights(), pool, gen);
|
||||
CallFunctorForModel<ZeroInitWeights>(model_type, grad_m, pool);
|
||||
CallFunctorForModel<ZeroInitWeights>(model_type, grad_v, pool);
|
||||
|
||||
printf("Initial weights:\n");
|
||||
LogWeightStats(model_type, weights);
|
||||
LogWeightStats(model_type, gemma.Weights());
|
||||
|
||||
constexpr size_t kBatchSize = 8;
|
||||
float learning_rate = 0.0005f;
|
||||
|
|
@ -91,21 +90,21 @@ TEST(OptimizeTest, GradientDescent) {
|
|||
size_t num_ok;
|
||||
for (; steps < 1000000; ++steps) {
|
||||
std::mt19937 sgen(42);
|
||||
ZeroInitWeights(model_type, grad, pool);
|
||||
CallFunctorForModel<ZeroInitWeights>(model_type, grad, pool);
|
||||
float total_loss = 0.0f;
|
||||
num_ok = 0;
|
||||
for (size_t i = 0; i < kBatchSize; ++i) {
|
||||
Prompt prompt = training_task.Sample(sgen);
|
||||
total_loss += CrossEntropyLossForwardPass(
|
||||
model_type, prompt, weights, forward, pool);
|
||||
CrossEntropyLossBackwardPass(
|
||||
model_type, prompt, weights, forward, grad, backward, pool);
|
||||
total_loss += CrossEntropyLossForwardPass(model_type, prompt,
|
||||
gemma.Weights(), forward, pool);
|
||||
CrossEntropyLossBackwardPass(model_type, prompt, gemma.Weights(), forward,
|
||||
grad, backward, pool);
|
||||
num_ok += verify(prompt) ? 1 : 0;
|
||||
}
|
||||
total_loss /= kBatchSize;
|
||||
|
||||
const float scale = -learning_rate / kBatchSize;
|
||||
UpdateWeights(model_type, grad, scale, weights, pool);
|
||||
UpdateWeights(model_type, grad, scale, gemma.Weights(), pool);
|
||||
printf("step: %zu total_loss: %.15f num_ok: %zu/%zu\n",
|
||||
steps, total_loss, num_ok, kBatchSize);
|
||||
if (steps % 100 == 0) {
|
||||
|
|
@ -119,7 +118,7 @@ TEST(OptimizeTest, GradientDescent) {
|
|||
}
|
||||
printf("Num steps: %zu\n", steps);
|
||||
printf("Final weights:\n");
|
||||
LogWeightStats(model_type, weights);
|
||||
LogWeightStats(model_type, gemma.Weights());
|
||||
EXPECT_LT(steps, 3000);
|
||||
EXPECT_EQ(num_ok, kBatchSize);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -41,14 +41,16 @@ class WeightInitializer {
|
|||
};
|
||||
|
||||
template <typename TConfig>
|
||||
void RandInitWeights(ByteStorageT& weights_u8, hwy::ThreadPool& pool,
|
||||
std::mt19937& gen) {
|
||||
struct RandInitWeights {
|
||||
void operator()(ByteStorageT& weights_u8, hwy::ThreadPool& pool,
|
||||
std::mt19937& gen) const {
|
||||
auto& weights = *reinterpret_cast<WeightsF<TConfig>*>(weights_u8.get());
|
||||
// TODO(szabadka) Use the same weight initialization method as in the python
|
||||
// version.
|
||||
WeightInitializer init(gen);
|
||||
ForEachTensor1<float, TConfig>(init, weights);
|
||||
}
|
||||
};
|
||||
|
||||
class WeightUpdater {
|
||||
public:
|
||||
|
|
@ -67,55 +69,22 @@ class WeightUpdater {
|
|||
};
|
||||
|
||||
template <typename TConfig>
|
||||
void UpdateWeights(const ByteStorageT& grad_u8, float scale,
|
||||
ByteStorageT& weights_u8, hwy::ThreadPool& pool) {
|
||||
struct UpdateWeightsT {
|
||||
void operator()(const ByteStorageT& grad_u8, float scale,
|
||||
ByteStorageT& weights_u8, hwy::ThreadPool& pool) const {
|
||||
const auto& grad =
|
||||
*reinterpret_cast<const WeightsF<TConfig>*>(grad_u8.get());
|
||||
auto& weights = *reinterpret_cast<WeightsF<TConfig>*>(weights_u8.get());
|
||||
WeightUpdater updater(scale);
|
||||
ForEachTensor2<float, TConfig>(updater, grad, weights);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
// Runtime-to-compile-time dispatch: calls the RandInitWeights<TConfig>
// template matching `model`, forwarding the weight storage, thread pool and
// random generator unchanged. Aborts via HWY_ABORT for an unrecognized model.
void RandInitWeights(Model model, ByteStorageT& weights_u8,
                     hwy::ThreadPool& pool, std::mt19937& gen) {
  switch (model) {
    case Model::GEMMA_2B:
      RandInitWeights<ConfigGemma2B>(weights_u8, pool, gen);
      break;
    case Model::GEMMA_7B:
      RandInitWeights<ConfigGemma7B>(weights_u8, pool, gen);
      break;
    case Model::GRIFFIN_2B:
      RandInitWeights<ConfigGriffin2B>(weights_u8, pool, gen);
      break;
    case Model::GEMMA_TINY:
      RandInitWeights<ConfigGemmaTiny>(weights_u8, pool, gen);
      break;
    default:
      HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
  }
}
|
||||
|
||||
void UpdateWeights(Model model, const ByteStorageT& grad, float scale,
|
||||
ByteStorageT& weights, hwy::ThreadPool& pool) {
|
||||
switch (model) {
|
||||
case Model::GEMMA_2B:
|
||||
UpdateWeights<ConfigGemma2B>(grad, scale, weights, pool);
|
||||
break;
|
||||
case Model::GEMMA_7B:
|
||||
UpdateWeights<ConfigGemma7B>(grad, scale, weights, pool);
|
||||
break;
|
||||
case Model::GRIFFIN_2B:
|
||||
UpdateWeights<ConfigGriffin2B>(grad, scale, weights, pool);
|
||||
break;
|
||||
case Model::GEMMA_TINY:
|
||||
UpdateWeights<ConfigGemmaTiny>(grad, scale, weights, pool);
|
||||
break;
|
||||
default:
|
||||
HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
|
||||
}
|
||||
CallFunctorForModel<UpdateWeightsT>(model, grad, scale, weights, pool);
|
||||
}
|
||||
|
||||
} // namespace gcpp
|
||||
|
|
|
|||
|
|
@ -16,16 +16,11 @@
|
|||
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_OPTIMIZER_H_
|
||||
#define THIRD_PARTY_GEMMA_CPP_GEMMA_OPTIMIZER_H_
|
||||
|
||||
#include <random>
|
||||
|
||||
#include "gemma/common.h"
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
void RandInitWeights(Model model, ByteStorageT& weights, hwy::ThreadPool& pool,
|
||||
std::mt19937& gen);
|
||||
|
||||
void UpdateWeights(Model model, const ByteStorageT& grad, float scale,
|
||||
ByteStorageT& weights, hwy::ThreadPool& pool);
|
||||
|
||||
|
|
|
|||
|
|
@ -37,7 +37,7 @@ std::pair<std::string, int> QueryModel(
|
|||
gcpp::KVCache& kv_cache, hwy::ThreadPool& pool, const std::string& input,
|
||||
gcpp::LayersOutputT* layers_output) {
|
||||
std::vector<int> prompt;
|
||||
HWY_ASSERT(model.Tokenizer()->Encode(input, &prompt));
|
||||
HWY_ASSERT(model.Tokenizer().Encode(input, &prompt));
|
||||
|
||||
// For both pre-trained and instruction-tuned models: prepend "<bos>" token
|
||||
// if needed.
|
||||
|
|
@ -48,11 +48,10 @@ std::pair<std::string, int> QueryModel(
|
|||
std::mt19937 gen;
|
||||
gen.seed(42);
|
||||
|
||||
auto stream_token = [&res, &total_tokens, &app,
|
||||
tokenizer = model.Tokenizer()](int token, float) {
|
||||
auto stream_token = [&res, &total_tokens, &model](int token, float) {
|
||||
++total_tokens;
|
||||
std::string token_text;
|
||||
HWY_ASSERT(tokenizer->Decode(std::vector<int>{token}, &token_text));
|
||||
HWY_ASSERT(model.Tokenizer().Decode(std::vector<int>{token}, &token_text));
|
||||
res += token_text;
|
||||
return true;
|
||||
};
|
||||
|
|
@ -70,8 +69,8 @@ std::pair<std::string, int> QueryModel(
|
|||
.stream_token = stream_token,
|
||||
.accept_token = accept_token,
|
||||
};
|
||||
GenerateGemma(model, runtime_config, prompt, /*start_pos=*/0, kv_cache, pool,
|
||||
timing_info, layers_output);
|
||||
model.Generate(runtime_config, prompt, /*start_pos=*/0, kv_cache, timing_info,
|
||||
layers_output);
|
||||
return {res, total_tokens};
|
||||
}
|
||||
|
||||
|
|
@ -115,7 +114,7 @@ int main(int argc, char** argv) {
|
|||
}
|
||||
|
||||
gcpp::Gemma model(loader.tokenizer, loader.weights, loader.ModelType(), pool);
|
||||
auto kv_cache = CreateKVCache(loader.ModelType());
|
||||
gcpp::KVCache kv_cache = gcpp::KVCache::Create(loader.ModelType());
|
||||
|
||||
const std::string& prompt = prompt_args.prompt;
|
||||
if (prompt.empty()) {
|
||||
|
|
|
|||
|
|
@ -40,7 +40,7 @@ int main(int argc, char** argv) {
|
|||
|
||||
// Instantiate model and KV Cache
|
||||
gcpp::Gemma model(loader.tokenizer, loader.weights, loader.ModelType(), pool);
|
||||
auto kv_cache = CreateKVCache(loader.ModelType());
|
||||
gcpp::KVCache kv_cache = gcpp::KVCache::Create(loader.ModelType());
|
||||
size_t pos = 0; // KV Cache position
|
||||
|
||||
// Initialize random number generator
|
||||
|
|
|
|||
|
|
@ -1,38 +0,0 @@
|
|||
// Copyright 2024 Google LLC
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// https://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "gemma/activations.h"
|
||||
|
||||
#include "gemma/common.h"
|
||||
#include "gemma/configs.h"
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
// Returns the storage produced by ForwardPass<float, TConfig>::Allocate() for
// the config matching `model`. Aborts via HWY_ABORT for an unrecognized model.
ByteStorageT AllocateForwardPass(Model model) {
  switch (model) {
    case Model::GEMMA_2B:
      return ForwardPass<float, ConfigGemma2B>::Allocate();
    case Model::GEMMA_7B:
      return ForwardPass<float, ConfigGemma7B>::Allocate();
    case Model::GRIFFIN_2B:
      return ForwardPass<float, ConfigGriffin2B>::Allocate();
    case Model::GEMMA_TINY:
      return ForwardPass<float, ConfigGemmaTiny>::Allocate();
    default:
      HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
  }
}
|
||||
|
||||
} // namespace gcpp
|
||||
|
|
@ -16,8 +16,11 @@
|
|||
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_ACTIVATIONS_H_
|
||||
#define THIRD_PARTY_GEMMA_CPP_GEMMA_ACTIVATIONS_H_
|
||||
|
||||
#include "gemma/common.h"
|
||||
#include "hwy/aligned_allocator.h"
|
||||
#include <stddef.h>
|
||||
|
||||
#include <array>
|
||||
|
||||
#include "gemma/common.h" // ByteStorageT
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
|
|
@ -54,9 +57,12 @@ struct ForwardPass {
|
|||
std::array<T, kSeqLen * kModelDim> final_norm_output;
|
||||
std::array<T, kSeqLen * kVocabSize> logits;
|
||||
std::array<T, kSeqLen * kVocabSize> probs;
|
||||
};
|
||||
|
||||
static ByteStorageT Allocate() {
|
||||
return hwy::AllocateAligned<uint8_t>(sizeof(ForwardPass<T, TConfig>));
|
||||
template <typename TConfig>
|
||||
struct AllocateForwardPass {
|
||||
ByteStorageT operator()() const {
|
||||
return AllocateSizeof<ForwardPass<float, TConfig>>();
|
||||
}
|
||||
};
|
||||
|
||||
|
|
@ -78,8 +84,6 @@ class ActivationsWrapper {
|
|||
WrappedT& activations_;
|
||||
};
|
||||
|
||||
ByteStorageT AllocateForwardPass(Model model);
|
||||
|
||||
} // namespace gcpp
|
||||
|
||||
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_ACTIVATIONS_H_
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
#include <algorithm>
|
||||
#include <cstdlib> // EXIT_FAILURE
|
||||
#include <fstream>
|
||||
#include <iostream>
|
||||
#include <ostream>
|
||||
|
|
@ -8,6 +9,7 @@
|
|||
#include <utility> // std::pair
|
||||
#include <vector>
|
||||
|
||||
#include "compression/io.h" // Path
|
||||
#include "gemma/gemma.h"
|
||||
#include "util/app.h"
|
||||
#include "util/args.h"
|
||||
|
|
@ -61,7 +63,7 @@ std::pair<std::string, int> QueryModel(
|
|||
gcpp::Gemma& model, gcpp::InferenceArgs& args, gcpp::AppArgs& app,
|
||||
gcpp::KVCache& kv_cache, hwy::ThreadPool& pool, const std::string& input) {
|
||||
std::vector<int> prompt;
|
||||
HWY_ASSERT(model.Tokenizer()->Encode(input, &prompt));
|
||||
HWY_ASSERT(model.Tokenizer().Encode(input, &prompt));
|
||||
|
||||
// For both pre-trained and instruction-tuned models: prepend "<bos>" token
|
||||
// if needed.
|
||||
|
|
@ -73,11 +75,11 @@ std::pair<std::string, int> QueryModel(
|
|||
gen.seed(42);
|
||||
|
||||
const double time_start = hwy::platform::Now();
|
||||
auto stream_token = [&res, &total_tokens, &time_start, &app,
|
||||
tokenizer = model.Tokenizer()](int token, float) {
|
||||
auto stream_token = [&res, &total_tokens, &time_start, &app, &model](
|
||||
int token, float) {
|
||||
++total_tokens;
|
||||
std::string token_text;
|
||||
HWY_ASSERT(tokenizer->Decode(std::vector<int>{token}, &token_text));
|
||||
HWY_ASSERT(model.Tokenizer().Decode(std::vector<int>{token}, &token_text));
|
||||
res += token_text;
|
||||
if (app.verbosity >= 1 && total_tokens % 100 == 0) {
|
||||
LogSpeedStats(time_start, total_tokens);
|
||||
|
|
@ -98,8 +100,8 @@ std::pair<std::string, int> QueryModel(
|
|||
.stream_token = stream_token,
|
||||
.accept_token = accept_token,
|
||||
};
|
||||
GenerateGemma(model, runtime_config, prompt, /*start_pos=*/0, kv_cache, pool,
|
||||
timing_info, /*layers_output=*/nullptr);
|
||||
model.Generate(runtime_config, prompt, /*start_pos=*/0, kv_cache, timing_info,
|
||||
/*layers_output=*/nullptr);
|
||||
if (app.verbosity >= 1) {
|
||||
LogSpeedStats(time_start, total_tokens);
|
||||
}
|
||||
|
|
@ -190,7 +192,7 @@ int BenchmarkCrossEntropy(gcpp::Gemma& model, gcpp::Model model_type,
|
|||
size_t batch_tokens) {
|
||||
std::string input = ReadFile(text);
|
||||
std::vector<int> prompt;
|
||||
HWY_ASSERT(model.Tokenizer()->Encode(input, &prompt));
|
||||
HWY_ASSERT(model.Tokenizer().Encode(input, &prompt));
|
||||
prompt.resize(std::min<size_t>(args.max_tokens, prompt.size()));
|
||||
std::cout << "Number of input tokens: " << prompt.size() << "\n";
|
||||
const double time_start = hwy::platform::Now();
|
||||
|
|
@ -201,14 +203,13 @@ int BenchmarkCrossEntropy(gcpp::Gemma& model, gcpp::Model model_type,
|
|||
size_t num_tokens = std::min<size_t>(prompt.size() - pos, batch_tokens);
|
||||
std::vector<int> prompt_slice(prompt.begin() + pos,
|
||||
prompt.begin() + pos + num_tokens);
|
||||
auto kv_cache = CreateKVCache(model_type);
|
||||
float entropy =
|
||||
ComputeCrossEntropy(model, num_tokens, prompt_slice, kv_cache, pool,
|
||||
app.verbosity);
|
||||
gcpp::KVCache kv_cache = gcpp::KVCache::Create(model_type);
|
||||
float entropy = model.ComputeCrossEntropy(num_tokens, prompt_slice,
|
||||
kv_cache, app.verbosity);
|
||||
total_entropy += entropy;
|
||||
LogSpeedStats(time_start, pos + num_tokens);
|
||||
std::string text_slice;
|
||||
HWY_ASSERT(model.Tokenizer()->Decode(prompt_slice, &text_slice));
|
||||
HWY_ASSERT(model.Tokenizer().Decode(prompt_slice, &text_slice));
|
||||
total_input_len += text_slice.size();
|
||||
printf("Cross entropy per byte: %f [cumulative: %f]\n",
|
||||
entropy / text_slice.size(), total_entropy / total_input_len);
|
||||
|
|
@ -277,7 +278,7 @@ int main(int argc, char** argv) {
|
|||
}
|
||||
|
||||
gcpp::Gemma model(loader.tokenizer, loader.weights, loader.ModelType(), pool);
|
||||
auto kv_cache = CreateKVCache(loader.ModelType());
|
||||
gcpp::KVCache kv_cache = gcpp::KVCache::Create(loader.ModelType());
|
||||
|
||||
if (!benchmark_args.goldens.path.empty()) {
|
||||
const std::string golden_path =
|
||||
|
|
|
|||
|
|
@ -1,68 +0,0 @@
|
|||
// Copyright 2024 Google LLC
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// https://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// Include guard for non-SIMD code.
|
||||
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_INL_H_
|
||||
#define THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_INL_H_
|
||||
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
|
||||
#include <array>
|
||||
#include <cmath>
|
||||
|
||||
#include "gemma/activations.h"
|
||||
|
||||
#include "hwy/base.h"
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
|
||||
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_INL_H_
|
||||
|
||||
// Include guard for (potentially) SIMD code.
|
||||
#if defined(THIRD_PARTY_GEMMA_CPP_COMMON_TOGGLE) == defined(HWY_TARGET_TOGGLE)
|
||||
#ifdef THIRD_PARTY_GEMMA_CPP_COMMON_TOGGLE
|
||||
#undef THIRD_PARTY_GEMMA_CPP_COMMON_TOGGLE
|
||||
#else
|
||||
#define THIRD_PARTY_GEMMA_CPP_COMMON_TOGGLE
|
||||
#endif
|
||||
|
||||
#include "gemma/ops.h"
|
||||
|
||||
HWY_BEFORE_NAMESPACE();
|
||||
namespace gcpp {
|
||||
namespace HWY_NAMESPACE {
|
||||
namespace hn = hwy::HWY_NAMESPACE;
|
||||
|
||||
// `EmbeddingScaling` can be constexpr only if `Sqrt` and `hwy::ConvertScalarTo`
|
||||
// are both constexpr
|
||||
#if HWY_COMPILER_GCC_ACTUAL
|
||||
#define GEMMA_CONSTEXPR_EMBSCALING HWY_BF16_CONSTEXPR
|
||||
#else
|
||||
#define GEMMA_CONSTEXPR_EMBSCALING
|
||||
#endif
|
||||
|
||||
template <typename TConfig>
|
||||
GEMMA_CONSTEXPR_EMBSCALING float EmbeddingScaling() {
|
||||
// Round to bf16 to match Gemma's Embedder, which casts before mul.
|
||||
return hwy::ConvertScalarTo<float>(hwy::ConvertScalarTo<hwy::bfloat16_t>(
|
||||
Sqrt(static_cast<float>(TConfig::kModelDim))));
|
||||
}
|
||||
|
||||
// NOLINTNEXTLINE(google-readability-namespace-comments)
|
||||
} // namespace HWY_NAMESPACE
|
||||
} // namespace gcpp
|
||||
HWY_AFTER_NAMESPACE();
|
||||
|
||||
#endif // NOLINT
|
||||
|
|
@ -0,0 +1,67 @@
|
|||
// Copyright 2024 Google LLC
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// https://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "gemma/common.h"
|
||||
|
||||
#include <stddef.h>
|
||||
#include <string.h>
|
||||
|
||||
#include <algorithm> // std::transform
|
||||
#include <cctype>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "hwy/base.h"
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
namespace {
|
||||
constexpr const char* kModelFlags[] = {"2b-pt", "7b-pt", "gr2b-pt", "2b-it",
|
||||
"7b-it", "gr2b-it", "tiny"};
|
||||
constexpr Model kModelTypes[] = {
|
||||
Model::GEMMA_2B, Model::GEMMA_7B, Model::GRIFFIN_2B, Model::GEMMA_2B,
|
||||
Model::GEMMA_7B, Model::GRIFFIN_2B, Model::GEMMA_TINY};
|
||||
constexpr ModelTraining kModelTraining[] = {
|
||||
ModelTraining::GEMMA_PT, ModelTraining::GEMMA_PT, ModelTraining::GEMMA_PT,
|
||||
ModelTraining::GEMMA_IT, ModelTraining::GEMMA_IT, ModelTraining::GEMMA_IT,
|
||||
ModelTraining::GEMMA_IT};
|
||||
} // namespace
|
||||
|
||||
const char* ParseModelTypeAndTraining(const std::string& model_flag,
|
||||
Model& model, ModelTraining& training) {
|
||||
constexpr size_t kNum = std::end(kModelFlags) - std::begin(kModelFlags);
|
||||
static char kErrorMessageBuffer[kNum * 8 + 1024] =
|
||||
"Invalid or missing model flag, need to specify one of ";
|
||||
for (size_t i = 0; i + 1 < kNum; i++) {
|
||||
strcat(kErrorMessageBuffer, kModelFlags[i]); // NOLINT
|
||||
strcat(kErrorMessageBuffer, ", "); // NOLINT
|
||||
}
|
||||
strcat(kErrorMessageBuffer, kModelFlags[kNum - 1]); // NOLINT
|
||||
strcat(kErrorMessageBuffer, "."); // NOLINT
|
||||
std::string model_type_lc = model_flag;
|
||||
std::transform(begin(model_type_lc), end(model_type_lc), begin(model_type_lc),
|
||||
[](unsigned char c) { return std::tolower(c); });
|
||||
for (size_t i = 0; i < kNum; i++) {
|
||||
if (kModelFlags[i] == model_type_lc) {
|
||||
model = kModelTypes[i];
|
||||
training = kModelTraining[i];
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
return kErrorMessageBuffer;
|
||||
}
|
||||
|
||||
} // namespace gcpp
|
||||
100
gemma/common.h
100
gemma/common.h
|
|
@ -16,15 +16,115 @@
|
|||
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_H_
|
||||
#define THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_H_
|
||||
|
||||
#include <math.h> // sqrtf
|
||||
#include <stdint.h>
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "gemma/configs.h" // IWYU pragma: export
|
||||
#include "hwy/aligned_allocator.h"
|
||||
#include "hwy/base.h" // ConvertScalarTo
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
using ByteStorageT = hwy::AlignedFreeUniquePtr<uint8_t[]>;
|
||||
|
||||
template <typename T>
|
||||
ByteStorageT AllocateSizeof() {
|
||||
return hwy::AllocateAligned<uint8_t>(sizeof(T));
|
||||
}
|
||||
|
||||
// Model variants: see configs.h for details.
|
||||
enum class Model { GEMMA_2B, GEMMA_7B, GRIFFIN_2B, GEMMA_TINY };
|
||||
|
||||
enum class ModelTraining { GEMMA_IT, GEMMA_PT };
|
||||
|
||||
// Returns the return value of Func<T>().operator() called with `args`, where
|
||||
// `T` is selected based on `model`.
|
||||
//
|
||||
// This is used to implement type-erased functions such as
|
||||
// LoadCompressedWeights, which can be called from other .cc files, by calling a
|
||||
// functor LoadCompressedWeightsT, which has a template argument. `Func` must
|
||||
// be a functor because function templates cannot be passed as a template
|
||||
// template argument, and we prefer to avoid the overhead of std::function.
|
||||
//
|
||||
// This function avoids having to update all call sites when we extend `Model`.
|
||||
template <template <typename Config> class Func, typename... Args>
|
||||
decltype(auto) CallFunctorForModel(Model model, Args&&... args) {
|
||||
switch (model) {
|
||||
case Model::GEMMA_TINY:
|
||||
return Func<ConfigGemmaTiny>()(std::forward<Args>(args)...);
|
||||
case Model::GEMMA_2B:
|
||||
return Func<ConfigGemma2B>()(std::forward<Args>(args)...);
|
||||
case Model::GEMMA_7B:
|
||||
return Func<ConfigGemma7B>()(std::forward<Args>(args)...);
|
||||
case Model::GRIFFIN_2B:
|
||||
return Func<ConfigGriffin2B>()(std::forward<Args>(args)...);
|
||||
default:
|
||||
HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
|
||||
}
|
||||
}
|
||||
|
||||
// Like CallFunctorForModel, but for SIMD function templates. This is a macro
|
||||
// because it boils down to N_SSE4::FUNC, which would not work if FUNC was a
|
||||
// normal function argument.
|
||||
// Expands to a `switch` statement over MODEL. Each case applies
// HWY_EXPORT_AND_DYNAMIC_DISPATCH_T to FUNC instantiated with the config type
// for that model, then pastes ARGS directly after it to form the call, so ARGS
// must be a parenthesized argument list such as `(a, b)`. Any value returned by
// FUNC is discarded. A MODEL that matches no case aborts via HWY_ABORT. MODEL
// appears twice in the expansion (switch condition and abort message).
#define GEMMA_EXPORT_AND_DISPATCH_MODEL(MODEL, FUNC, ARGS) \
|
||||
switch (MODEL) { \
|
||||
case Model::GEMMA_TINY: { \
|
||||
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(FUNC<ConfigGemmaTiny>) \
|
||||
ARGS; \
|
||||
break; \
|
||||
} \
|
||||
case Model::GEMMA_2B: { \
|
||||
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(FUNC<ConfigGemma2B>) \
|
||||
ARGS; \
|
||||
break; \
|
||||
} \
|
||||
case Model::GEMMA_7B: { \
|
||||
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(FUNC<ConfigGemma7B>) \
|
||||
ARGS; \
|
||||
break; \
|
||||
} \
|
||||
case Model::GRIFFIN_2B: { \
|
||||
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(FUNC<ConfigGriffin2B>) \
|
||||
ARGS; \
|
||||
break; \
|
||||
} \
|
||||
default: \
|
||||
HWY_ABORT("Model type %d unknown.", static_cast<int>(MODEL)); \
|
||||
}
|
||||
|
||||
// Returns error string or nullptr if OK.
|
||||
// Thread-hostile.
|
||||
const char* ParseModelTypeAndTraining(const std::string& model_flag,
|
||||
Model& model, ModelTraining& training);
|
||||
|
||||
// __builtin_sqrt is not constexpr as of Clang 17.
|
||||
#if HWY_COMPILER_GCC_ACTUAL
|
||||
#define GEMMA_CONSTEXPR_SQRT constexpr
|
||||
static GEMMA_CONSTEXPR_SQRT HWY_INLINE float Sqrt(float x) {
|
||||
return __builtin_sqrt(x);
|
||||
}
|
||||
#else
|
||||
#define GEMMA_CONSTEXPR_SQRT
|
||||
static GEMMA_CONSTEXPR_SQRT HWY_INLINE float Sqrt(float x) { return sqrtf(x); }
|
||||
#endif
|
||||
|
||||
// `EmbeddingScaling` can be constexpr only if `Sqrt` and `hwy::ConvertScalarTo`
|
||||
// are both constexpr
|
||||
#if HWY_COMPILER_GCC_ACTUAL
|
||||
#define GEMMA_CONSTEXPR_EMBSCALING HWY_BF16_CONSTEXPR
|
||||
#else
|
||||
#define GEMMA_CONSTEXPR_EMBSCALING
|
||||
#endif
|
||||
|
||||
template <typename TConfig>
|
||||
GEMMA_CONSTEXPR_EMBSCALING float EmbeddingScaling() {
|
||||
// Round to bf16 to match Gemma's Embedder, which casts before mul.
|
||||
return hwy::ConvertScalarTo<float>(hwy::ConvertScalarTo<hwy::bfloat16_t>(
|
||||
Sqrt(static_cast<float>(TConfig::kModelDim))));
|
||||
}
|
||||
|
||||
} // namespace gcpp
|
||||
|
||||
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_H_
|
||||
|
|
|
|||
|
|
@ -15,11 +15,33 @@
|
|||
|
||||
// Command line tool to create compressed weights.
|
||||
|
||||
// Compiles this file for multiple architectures via "foreach_target.h", to
|
||||
// which we pass the filename via macro 'argument'.
|
||||
#undef HWY_TARGET_INCLUDE
|
||||
#define HWY_TARGET_INCLUDE \
|
||||
"gemma/compress_weights.cc" // NOLINT
|
||||
#include "hwy/foreach_target.h" // IWYU pragma: keep
|
||||
// Must come after foreach_target.h to avoid redefinition errors.
|
||||
#include "compression/compress-inl.h"
|
||||
#include "hwy/highway.h"
|
||||
|
||||
#ifndef GEMMA_COMPRESS_WEIGHTS_ONCE
|
||||
#define GEMMA_COMPRESS_WEIGHTS_ONCE
|
||||
|
||||
#include <stddef.h>
|
||||
#include <stdio.h>
|
||||
|
||||
#include <algorithm> // std::clamp
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
#include <thread> // NOLINT
|
||||
|
||||
#include "gemma/gemma.h" // Gemma
|
||||
#include "compression/io.h" // Path
|
||||
#include "gemma/common.h" // Model
|
||||
#include "gemma/weights.h"
|
||||
#include "util/args.h"
|
||||
#include "hwy/base.h"
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
|
|
@ -99,10 +121,57 @@ void ShowHelp(gcpp::Args& args) {
|
|||
std::cerr << "\n";
|
||||
}
|
||||
|
||||
} // namespace gcpp
|
||||
#endif // GEMMA_COMPRESS_WEIGHTS_ONCE
|
||||
|
||||
// SIMD code, compiled once per target.
|
||||
HWY_BEFORE_NAMESPACE();
|
||||
namespace gcpp {
|
||||
namespace HWY_NAMESPACE {
|
||||
|
||||
template <class TConfig>
|
||||
void CompressWeights(const Path& weights_path,
|
||||
const Path& compressed_weights_path, Model model_type,
|
||||
hwy::ThreadPool& pool) {
|
||||
if (!weights_path.Exists()) {
|
||||
HWY_ABORT("The model weights file '%s' does not exist.",
|
||||
weights_path.path.c_str());
|
||||
}
|
||||
|
||||
// Allocate compressed weights.
|
||||
using CWeights = CompressedWeights<TConfig>;
|
||||
ByteStorageT c_weights_u8 = AllocateSizeof<CWeights>();
|
||||
CWeights* c_weights = reinterpret_cast<CWeights*>(c_weights_u8.get());
|
||||
new (&c_weights->c_layer_ptrs) CompressedLayerPointers<TConfig>(pool);
|
||||
|
||||
// Get weights, compress, and store.
|
||||
const bool scale_for_compression = TConfig::kNumTensorScales > 0;
|
||||
const ByteStorageT weights_u8 = gcpp::LoadRawWeights(
|
||||
weights_path, model_type, pool, scale_for_compression);
|
||||
WeightsF<TConfig>* weights =
|
||||
reinterpret_cast<WeightsF<TConfig>*>(weights_u8.get());
|
||||
Compressor compressor(pool);
|
||||
ForEachTensor<TConfig>(weights, *c_weights, compressor);
|
||||
compressor.AddScales(weights->scales.data(), weights->scales.size());
|
||||
compressor.WriteAll(pool, compressed_weights_path);
|
||||
|
||||
weights->layer_ptrs.~LayerPointers<float, TConfig>();
|
||||
c_weights->c_layer_ptrs.~CompressedLayerPointers<TConfig>();
|
||||
}
|
||||
|
||||
} // namespace HWY_NAMESPACE
|
||||
} // namespace gcpp
|
||||
HWY_AFTER_NAMESPACE();
|
||||
|
||||
#if HWY_ONCE
|
||||
namespace gcpp {
|
||||
|
||||
void Run(Args& args) {
|
||||
hwy::ThreadPool pool(args.num_threads);
|
||||
gcpp::CompressWeights(args.ModelType(), args.weights, args.compressed_weights,
|
||||
pool);
|
||||
const Model model_type = args.ModelType();
|
||||
GEMMA_EXPORT_AND_DISPATCH_MODEL(
|
||||
model_type, CompressWeights,
|
||||
(args.weights, args.compressed_weights, model_type, pool));
|
||||
}
|
||||
|
||||
} // namespace gcpp
|
||||
|
|
@ -111,12 +180,12 @@ int main(int argc, char** argv) {
|
|||
gcpp::Args args(argc, argv);
|
||||
|
||||
if (gcpp::HasHelp(argc, argv)) {
|
||||
ShowHelp(args);
|
||||
gcpp::ShowHelp(args);
|
||||
return 0;
|
||||
}
|
||||
|
||||
if (const char* error = args.Validate()) {
|
||||
ShowHelp(args);
|
||||
gcpp::ShowHelp(args);
|
||||
HWY_ABORT("\nInvalid args: %s", error);
|
||||
}
|
||||
|
||||
|
|
@ -124,3 +193,5 @@ int main(int argc, char** argv) {
|
|||
|
||||
return 0;
|
||||
}
|
||||
|
||||
#endif // HWY_ONCE
|
||||
|
|
|
|||
|
|
@ -13,11 +13,20 @@
|
|||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// Model configurations
|
||||
|
||||
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_CONFIGS_H_
|
||||
#define THIRD_PARTY_GEMMA_CPP_GEMMA_CONFIGS_H_
|
||||
|
||||
// Model configurations
|
||||
|
||||
#include <stddef.h>
|
||||
|
||||
#include <array>
|
||||
|
||||
#include "compression/compress.h" // SfpStream
|
||||
#include "hwy/base.h" // hwy::bfloat16_t
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
// Allow changing pre-allocated kv cache size as a compiler flag
|
||||
#ifndef GEMMA_MAX_SEQLEN
|
||||
#define GEMMA_MAX_SEQLEN 4096
|
||||
|
|
@ -33,25 +42,20 @@
|
|||
#define GEMMA_MAX_THREADS 128
|
||||
#endif // !GEMMA_MAX_THREADS
|
||||
|
||||
#include <stddef.h>
|
||||
|
||||
#include <array>
|
||||
|
||||
#include "compression/sfp.h"
|
||||
#include "hwy/base.h" // hwy::bfloat16_t
|
||||
|
||||
// Allowable types for GEMMA_WEIGHT_T (can be specified at compilation time):
|
||||
// float, hwy::bfloat16_t, SfpStream, NuqStream
|
||||
#ifndef GEMMA_WEIGHT_T
|
||||
#define GEMMA_WEIGHT_T SfpStream
|
||||
#endif // !GEMMA_WEIGHT_T
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
static constexpr size_t kSeqLen = GEMMA_MAX_SEQLEN;
|
||||
static constexpr size_t kTopK = GEMMA_TOPK;
|
||||
static constexpr size_t kMaxThreads = GEMMA_MAX_THREADS;
|
||||
|
||||
using GemmaWeightT = GEMMA_WEIGHT_T;
|
||||
|
||||
using EmbedderInputT = hwy::bfloat16_t;
|
||||
|
||||
enum class LayerAttentionType {
|
||||
kGemma,
|
||||
kGriffinRecurrentBlock,
|
||||
|
|
|
|||
984
gemma/gemma.cc
984
gemma/gemma.cc
File diff suppressed because it is too large
Load Diff
133
gemma/gemma.h
133
gemma/gemma.h
|
|
@ -24,22 +24,12 @@
|
|||
|
||||
#include "compression/io.h" // Path
|
||||
#include "gemma/common.h"
|
||||
#include "gemma/configs.h"
|
||||
#include "hwy/aligned_allocator.h"
|
||||
#include "hwy/base.h" // hwy::bfloat16_t
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
using GemmaWeightT = GEMMA_WEIGHT_T;
|
||||
using EmbedderInputT = hwy::bfloat16_t;
|
||||
// Will be called for layers output with:
|
||||
// - position in the tokens sequence
|
||||
// - name of the data, p.ex. "tokens", "block.1", "final_norm"
|
||||
// - pointer to the data array
|
||||
// - size of the data array
|
||||
using LayersOutputT =
|
||||
std::function<void(int, const std::string&, const float*, size_t)>;
|
||||
constexpr size_t kPrefillBatchSize = 16;
|
||||
constexpr bool kSystemPrompt = false;
|
||||
|
||||
|
|
@ -50,14 +40,30 @@ struct KVCache {
|
|||
conv1d_cache; // (kConv1dWidth - 1) * kModelDim * kGriffinLayers
|
||||
hwy::AlignedFreeUniquePtr<float[]>
|
||||
rglru_cache; // kModelDim * kGriffinLayers
|
||||
|
||||
static KVCache Create(Model type);
|
||||
};
|
||||
|
||||
enum class ModelTraining { GEMMA_IT, GEMMA_PT };
|
||||
constexpr int EOS_ID = 1;
|
||||
|
||||
// Returns error string or nullptr if OK.
|
||||
// Thread-hostile.
|
||||
const char* ParseModelTypeAndTraining(const std::string& model_flag,
|
||||
Model& model, ModelTraining& training);
|
||||
// Tokenizer handle whose implementation lives in the forward-declared `Impl`
// class, held through a unique_ptr. Move construction and move assignment are
// declared; no copy operations are declared.
class GemmaTokenizer {
|
||||
public:
|
||||
GemmaTokenizer() = default; // for second Gemma ctor.
|
||||
// Constructs from the tokenizer file at `tokenizer_path`.
explicit GemmaTokenizer(const Path& tokenizer_path);
|
||||
|
||||
// must come after definition of Impl
|
||||
~GemmaTokenizer();
|
||||
GemmaTokenizer(GemmaTokenizer&& other);
|
||||
GemmaTokenizer& operator=(GemmaTokenizer&& other);
|
||||
|
||||
// Encode `input` into string pieces or integer ids, written to `*pieces`.
// Callers in this codebase assert that the returned bool is true.
bool Encode(const std::string& input, std::vector<std::string>* pieces) const;
|
||||
bool Encode(const std::string& input, std::vector<int>* pieces) const;
|
||||
// Converts token `ids` back to text in `*detokenized`; callers likewise
// assert that the result is true.
bool Decode(const std::vector<int>& ids, std::string* detokenized) const;
|
||||
|
||||
private:
|
||||
// Defined out of line, which is why the destructor and move operations are
// only declared here.
class Impl;
|
||||
std::unique_ptr<Impl> impl_;
|
||||
};
|
||||
|
||||
// StreamFunc is called with (token, probability). For prompt tokens,
|
||||
// probability is 0.0f. StreamFunc should return False to stop generation and
|
||||
|
|
@ -67,8 +73,6 @@ using StreamFunc = std::function<bool(int, float)>;
|
|||
// want to generate and True for tokens you want to generate.
|
||||
using AcceptFunc = std::function<bool(int)>;
|
||||
|
||||
constexpr int EOS_ID = 1;
|
||||
|
||||
struct RuntimeConfig {
|
||||
size_t max_tokens;
|
||||
size_t max_generated_tokens;
|
||||
|
|
@ -80,64 +84,65 @@ struct RuntimeConfig {
|
|||
int eos_id = EOS_ID;
|
||||
};
|
||||
|
||||
struct GemmaInterface;
|
||||
|
||||
class GemmaTokenizer {
|
||||
public:
|
||||
virtual ~GemmaTokenizer() = default;
|
||||
virtual bool Encode(const std::string& input,
|
||||
std::vector<std::string>* pieces) const = 0;
|
||||
virtual bool Encode(const std::string& input,
|
||||
std::vector<int>* pieces) const = 0;
|
||||
virtual bool Decode(const std::vector<int>& ids,
|
||||
std::string* detokenized) const = 0;
|
||||
};
|
||||
|
||||
struct Gemma {
|
||||
Gemma(const Path& tokenizer_path, const Path& weights, Model model_type,
|
||||
hwy::ThreadPool& pool);
|
||||
~Gemma(); // must be defined after the GemmaInterface dtor is defined.
|
||||
const GemmaTokenizer* Tokenizer() const;
|
||||
std::unique_ptr<GemmaInterface> impl_;
|
||||
};
|
||||
|
||||
// Timing results reported by generation; all fields start at zero.
// NOTE(review): the *_tok_sec names suggest tokens per second, and the unit of
// time_to_first_token is not visible here; confirm against the code that
// fills this struct.
struct TimingInfo {
|
||||
double prefill_tok_sec = 0.0;
|
||||
double gen_tok_sec = 0.0;
|
||||
double time_to_first_token = 0;
|
||||
};
|
||||
|
||||
KVCache CreateKVCache(Model type); // convenient workaround for now
|
||||
KVCache CreateKVCache(size_t size_cache_pos, size_t seq_len,
|
||||
size_t conv1d_cache_size, size_t rglru_cache_size);
|
||||
// Will be called for layers output with:
|
||||
// - position in the tokens sequence
|
||||
// - name of the data, p.ex. "tokens", "block.1", "final_norm"
|
||||
// - pointer to the data array
|
||||
// - size of the data array
|
||||
using LayersOutputT =
|
||||
std::function<void(int, const std::string&, const float*, size_t)>;
|
||||
|
||||
// Bundle runtime parameters as RuntimeConfig
|
||||
// layers_output is optional; if set - it will be called with the activations
|
||||
// output after applying each layer.
|
||||
void GenerateGemma(Gemma& gemma, const RuntimeConfig& runtime_config,
|
||||
const std::vector<int>& prompt, size_t start_pos,
|
||||
KVCache& kv_cache, hwy::ThreadPool& pool,
|
||||
TimingInfo& timing_info,
|
||||
LayersOutputT* layers_output = nullptr);
|
||||
|
||||
void GenerateGemma(Model model, const ByteStorageT& weights,
|
||||
ByteStorageT& inference_state,
|
||||
RuntimeConfig runtime_config,
|
||||
const std::vector<int>& prompt, size_t start_pos,
|
||||
KVCache& kv_cache, hwy::ThreadPool& pool,
|
||||
TimingInfo& timing_info);
|
||||
|
||||
ByteStorageT LoadWeights(const Path& weights, Model model,
|
||||
class Gemma {
|
||||
public:
|
||||
Gemma(const Path& tokenizer_path, const Path& weights, Model model_type,
|
||||
hwy::ThreadPool& pool);
|
||||
|
||||
ByteStorageT AllocateInferenceState(Model model);
|
||||
// Allocates weights, caller is responsible for filling them.
|
||||
Gemma(GemmaTokenizer&& tokenizer, Model model_type, hwy::ThreadPool& pool);
|
||||
~Gemma();
|
||||
|
||||
void CompressWeights(gcpp::Model model, const Path& weights,
|
||||
const Path& compressed_weights, hwy::ThreadPool& pool);
|
||||
const GemmaTokenizer& Tokenizer() const { return tokenizer_; }
|
||||
const ByteStorageT& Weights() const { return weights_u8_; }
|
||||
const ByteStorageT& Prefill() const { return prefill_u8_; }
|
||||
const ByteStorageT& Decode() const { return decode_u8_; }
|
||||
|
||||
float ComputeCrossEntropy(Gemma& gemma, size_t max_tokens,
|
||||
const std::vector<int>& prompt, KVCache& kv_cache,
|
||||
hwy::ThreadPool& pool, int verbosity);
|
||||
// layers_output is optional; if set - it will be called with the activations
|
||||
// output after applying each layer.
|
||||
void Generate(const RuntimeConfig& runtime_config,
|
||||
const std::vector<int>& prompt, size_t start_pos,
|
||||
KVCache& kv_cache, TimingInfo& timing_info,
|
||||
LayersOutputT* layers_output = nullptr);
|
||||
|
||||
float ComputeCrossEntropy(size_t max_tokens, const std::vector<int>& prompt,
|
||||
KVCache& kv_cache, int verbosity);
|
||||
|
||||
private:
|
||||
hwy::ThreadPool& pool_;
|
||||
|
||||
GemmaTokenizer tokenizer_;
|
||||
// Type-erased so that this can be defined in the header, without requiring
|
||||
// forwarding functions.
|
||||
ByteStorageT weights_u8_;
|
||||
ByteStorageT prefill_u8_;
|
||||
ByteStorageT decode_u8_;
|
||||
Model model_type_;
|
||||
};
|
||||
|
||||
// DEPRECATED, call Gemma::Generate directly.
|
||||
HWY_INLINE void GenerateGemma(Gemma& gemma, const RuntimeConfig& runtime_config,
|
||||
const std::vector<int>& prompt, size_t start_pos,
|
||||
KVCache& kv_cache, hwy::ThreadPool& /*pool*/,
|
||||
TimingInfo& timing_info,
|
||||
LayersOutputT* layers_output) {
|
||||
gemma.Generate(runtime_config, prompt, start_pos, kv_cache, timing_info,
|
||||
layers_output);
|
||||
}
|
||||
|
||||
} // namespace gcpp
|
||||
|
||||
|
|
|
|||
|
|
@ -22,8 +22,9 @@
|
|||
#include <thread> // NOLINT
|
||||
#include <vector>
|
||||
|
||||
#include "compression/io.h" // Path
|
||||
#include "gemma/common.h"
|
||||
#include "gemma/ops.h"
|
||||
#include "util/args.h"
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
#include "hwy/tests/test_util-inl.h"
|
||||
|
||||
|
|
@ -38,7 +39,7 @@ class GemmaTest : public ::testing::Test {
|
|||
pool(std::min<int>(20, (std::thread::hardware_concurrency() - 1) / 2)),
|
||||
model_type(gcpp::Model::GEMMA_2B),
|
||||
model(tokenizer, weights, model_type, pool) {
|
||||
kv_cache = CreateKVCache(model_type);
|
||||
KVCache kv_cache = KVCache::Create(model_type);
|
||||
}
|
||||
|
||||
std::string GemmaReply(const std::string& prompt_string) {
|
||||
|
|
@ -46,7 +47,7 @@ class GemmaTest : public ::testing::Test {
|
|||
gen.seed(42);
|
||||
|
||||
std::vector<int> prompt;
|
||||
HWY_ASSERT(model.Tokenizer()->Encode(prompt_string, &prompt));
|
||||
HWY_ASSERT(model.Tokenizer().Encode(prompt_string, &prompt));
|
||||
// For both pre-trained and instruction-tuned models: prepend "<bos>" token
|
||||
// if needed.
|
||||
prompt.insert(prompt.begin(), 2);
|
||||
|
|
@ -66,18 +67,18 @@ class GemmaTest : public ::testing::Test {
|
|||
.accept_token = [](int) { return true; },
|
||||
};
|
||||
gcpp::TimingInfo timing_info;
|
||||
gcpp::GenerateGemma(model, runtime_config, prompt, /*start_pos=*/0,
|
||||
kv_cache, pool, timing_info, /*layers_output=*/nullptr);
|
||||
model.Generate(runtime_config, prompt, /*start_pos=*/0, kv_cache,
|
||||
timing_info, /*layers_output=*/nullptr);
|
||||
std::string response_text;
|
||||
HWY_ASSERT(model.Tokenizer()->Decode(response, &response_text));
|
||||
HWY_ASSERT(model.Tokenizer().Decode(response, &response_text));
|
||||
return response_text;
|
||||
}
|
||||
|
||||
float GemmaCrossEntropy(const std::string& prompt_string) {
|
||||
std::vector<int> prompt;
|
||||
HWY_ASSERT(model.Tokenizer()->Encode(prompt_string, &prompt));
|
||||
return gcpp::ComputeCrossEntropy(model, /*max_tokens=*/3072, prompt,
|
||||
kv_cache, pool, /*verbosity=*/0) /
|
||||
HWY_ASSERT(model.Tokenizer().Encode(prompt_string, &prompt));
|
||||
return model.ComputeCrossEntropy(/*max_tokens=*/3072, prompt, kv_cache,
|
||||
/*verbosity=*/0) /
|
||||
prompt_string.size();
|
||||
}
|
||||
|
||||
|
|
|
|||
19
gemma/ops.h
19
gemma/ops.h
|
|
@ -17,12 +17,12 @@
|
|||
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_OPS_H_
|
||||
#define THIRD_PARTY_GEMMA_CPP_GEMMA_OPS_H_
|
||||
|
||||
#include <math.h>
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
#include <stdio.h>
|
||||
|
||||
#include <array>
|
||||
#include <cmath>
|
||||
#include <cstdio>
|
||||
#include <random>
|
||||
#include <type_traits> // std::enable_if_t
|
||||
|
||||
|
|
@ -31,21 +31,6 @@
|
|||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
#include "hwy/profiler.h"
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
// __builtin_sqrt is not constexpr as of Clang 17.
|
||||
#if HWY_COMPILER_GCC_ACTUAL
|
||||
#define GEMMA_CONSTEXPR_SQRT constexpr
|
||||
static GEMMA_CONSTEXPR_SQRT HWY_INLINE float Sqrt(float x) {
|
||||
return __builtin_sqrt(x);
|
||||
}
|
||||
#else
|
||||
#define GEMMA_CONSTEXPR_SQRT
|
||||
static GEMMA_CONSTEXPR_SQRT HWY_INLINE float Sqrt(float x) { return sqrtf(x); }
|
||||
#endif
|
||||
|
||||
} // namespace gcpp
|
||||
|
||||
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_OPS_H_
|
||||
|
||||
// Include guard for (potentially) SIMD code.
|
||||
|
|
|
|||
14
gemma/run.cc
14
gemma/run.cc
|
|
@ -25,6 +25,7 @@
|
|||
|
||||
// Placeholder for internal header, do not modify.
|
||||
#include "compression/compress.h"
|
||||
#include "gemma/configs.h"
|
||||
#include "gemma/gemma.h" // Gemma
|
||||
#include "util/app.h"
|
||||
#include "util/args.h" // HasHelp
|
||||
|
|
@ -116,8 +117,7 @@ void ReplGemma(gcpp::Gemma& model, ModelTraining training,
|
|||
|
||||
// callback function invoked for each generated token.
|
||||
auto stream_token = [&abs_pos, ¤t_pos, &args, &gen, &prompt_size,
|
||||
tokenizer = model.Tokenizer(),
|
||||
verbosity](int token, float) {
|
||||
&model, verbosity](int token, float) {
|
||||
++abs_pos;
|
||||
++current_pos;
|
||||
// <= since position is incremented before
|
||||
|
|
@ -135,7 +135,8 @@ void ReplGemma(gcpp::Gemma& model, ModelTraining training,
|
|||
}
|
||||
} else {
|
||||
std::string token_text;
|
||||
HWY_ASSERT(tokenizer->Decode(std::vector<int>{token}, &token_text));
|
||||
HWY_ASSERT(
|
||||
model.Tokenizer().Decode(std::vector<int>{token}, &token_text));
|
||||
// +1 since position is incremented above
|
||||
if (current_pos == prompt_size + 1) {
|
||||
// first token of response
|
||||
|
|
@ -192,7 +193,7 @@ void ReplGemma(gcpp::Gemma& model, ModelTraining training,
|
|||
}
|
||||
}
|
||||
|
||||
HWY_ASSERT(model.Tokenizer()->Encode(prompt_string, &prompt));
|
||||
HWY_ASSERT(model.Tokenizer().Encode(prompt_string, &prompt));
|
||||
|
||||
// For both pre-trained and instruction-tuned models: prepend "<bos>" token
|
||||
// if needed.
|
||||
|
|
@ -221,8 +222,7 @@ void ReplGemma(gcpp::Gemma& model, ModelTraining training,
|
|||
.stream_token = stream_token,
|
||||
.accept_token = accept_token,
|
||||
};
|
||||
GenerateGemma(model, runtime_config, prompt, abs_pos, kv_cache, pool,
|
||||
timing_info);
|
||||
model.Generate(runtime_config, prompt, abs_pos, kv_cache, timing_info);
|
||||
if (verbosity >= 2) {
|
||||
std::cout << current_pos << " tokens (" << abs_pos << " total tokens)"
|
||||
<< "\n"
|
||||
|
|
@ -251,7 +251,7 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
|
|||
|
||||
gcpp::Gemma model(loader.tokenizer, loader.weights, loader.ModelType(), pool);
|
||||
|
||||
auto kv_cache = CreateKVCache(loader.ModelType());
|
||||
KVCache kv_cache = KVCache::Create(loader.ModelType());
|
||||
|
||||
if (const char* error = inference.Validate()) {
|
||||
ShowHelp(loader, inference, app);
|
||||
|
|
|
|||
|
|
@ -15,7 +15,6 @@
|
|||
|
||||
// Command line text interface to gemma.
|
||||
|
||||
#include <ctime>
|
||||
#include <fstream>
|
||||
#include <iostream>
|
||||
#include <random>
|
||||
|
|
@ -63,7 +62,7 @@ void JsonGemma(gcpp::Gemma& model, gcpp::KVCache& kv_cache,
|
|||
std::set<int> accept_token_set{};
|
||||
for (const std::string& accept_token : accept_tokens) {
|
||||
std::vector<int> accept_token_ids;
|
||||
HWY_ASSERT(model.Tokenizer()->Encode(accept_token, &accept_token_ids));
|
||||
HWY_ASSERT(model.Tokenizer().Encode(accept_token, &accept_token_ids));
|
||||
accept_token_set.insert(accept_token_ids.begin(), accept_token_ids.end());
|
||||
}
|
||||
|
||||
|
|
@ -76,7 +75,7 @@ void JsonGemma(gcpp::Gemma& model, gcpp::KVCache& kv_cache,
|
|||
const std::string& prompt_string = sample["prompt"];
|
||||
std::vector<int> prompt;
|
||||
|
||||
HWY_ASSERT(model.Tokenizer()->Encode(prompt_string, &prompt));
|
||||
HWY_ASSERT(model.Tokenizer().Encode(prompt_string, &prompt));
|
||||
prompt_size = prompt.size();
|
||||
|
||||
const std::string& correct_answer = accept_tokens[sample["input_label"]];
|
||||
|
|
@ -124,11 +123,10 @@ void JsonGemma(gcpp::Gemma& model, gcpp::KVCache& kv_cache,
|
|||
.stream_token = stream_token,
|
||||
.accept_token = accept_token,
|
||||
};
|
||||
GenerateGemma(model, runtime_config, prompt, abs_pos, kv_cache, pool,
|
||||
timing_info);
|
||||
model.Generate(runtime_config, prompt, abs_pos, kv_cache, timing_info);
|
||||
|
||||
std::string output_string;
|
||||
HWY_ASSERT(model.Tokenizer()->Decode(predicted_token_ids, &output_string));
|
||||
HWY_ASSERT(model.Tokenizer().Decode(predicted_token_ids, &output_string));
|
||||
std::cout << "QuestionId: " << sample["i"] << "; "
|
||||
<< "Predicted Answer: " << output_string << "; "
|
||||
<< "Correct Answer: " << correct_answer << std::endl;
|
||||
|
|
@ -161,7 +159,7 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
|
|||
|
||||
gcpp::Gemma model(loader.tokenizer, loader.weights, loader.ModelType(), pool);
|
||||
|
||||
auto kv_cache = CreateKVCache(loader.ModelType());
|
||||
gcpp::KVCache kv_cache = gcpp::KVCache::Create(loader.ModelType());
|
||||
|
||||
JsonGemma(model, kv_cache, pool, inference, app.verbosity, app.eot_line);
|
||||
}
|
||||
|
|
|
|||
261
gemma/weights.cc
261
gemma/weights.cc
|
|
@ -15,53 +15,230 @@
|
|||
|
||||
#include "gemma/weights.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstdlib>
|
||||
|
||||
#include "compression/compress.h"
|
||||
#include "compression/io.h" // Path
|
||||
#include "gemma/common.h"
|
||||
#include "gemma/configs.h"
|
||||
#include "hwy/base.h" // HWY_ABORT
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
#include "hwy/profiler.h"
|
||||
#include "hwy/stats.h"
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
ByteStorageT AllocateWeights(Model model, hwy::ThreadPool& pool) {
|
||||
switch (model) {
|
||||
case Model::GEMMA_2B:
|
||||
return AllocateWeights<float, ConfigGemma2B>(pool);
|
||||
case Model::GEMMA_7B:
|
||||
return AllocateWeights<float, ConfigGemma7B>(pool);
|
||||
case Model::GRIFFIN_2B:
|
||||
return AllocateWeights<float, ConfigGriffin2B>(pool);
|
||||
case Model::GEMMA_TINY:
|
||||
return AllocateWeights<float, ConfigGemmaTiny>(pool);
|
||||
default:
|
||||
HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
|
||||
// Setting this to true disables fread() calls that read the model file.
|
||||
constexpr bool kDryRunFread = false;
|
||||
|
||||
namespace {
|
||||
// Brings the largest magnitude in `data[0, len)` down to roughly kMaxRange by
// multiplying every element by 1 / scale, where scale = maxabs / kMaxRange.
// Returns `scale`. If the largest magnitude is already <= kMaxRange (which
// includes len == 0), `data` is left untouched and 1.0f is returned.
float ScaleWeights(float* data, size_t len) {
  constexpr float kMaxRange = 1.875f;
  float* const end = data + len;

  // Pass 1: largest absolute value.
  float largest = 0.0f;
  for (const float* it = data; it != end; ++it) {
    largest = std::max(largest, std::abs(*it));
  }
  if (largest <= kMaxRange) return 1.0f;

  // Pass 2: scale in place using the precomputed reciprocal.
  const float scale = largest / kMaxRange;
  const float inv_scale = 1.0f / scale;
  for (float* it = data; it != end; ++it) {
    *it *= inv_scale;
  }
  return scale;
}
|
||||
|
||||
// Reads the member `name` of `*layer_view` through `do_fread`, passing the
// member's address, the current `layer`, the stringized member name, and the
// member's size in bytes. Expects `layer_view`, `layer` and `do_fread` to be in
// scope at the expansion site.
#define READ_WEIGHTS(name) \
|
||||
do { \
|
||||
do_fread(&(layer_view->name), layer, #name, sizeof(layer_view->name)); \
|
||||
} while (0)
|
||||
|
||||
// When all reads so far succeeded (`ok`), this is not a dry run, and
// `scale_for_compression` is set: rescales the member `name` of `*layer_view`
// in place via ScaleWeights and stores the returned scale in the next slot of
// `weights->scales`, advancing `scale_pos`. Otherwise does nothing. Expects
// `ok`, `scale_for_compression`, `weights`, `scale_pos` and `layer_view` to be
// in scope at the expansion site.
#define SCALE_WEIGHTS(name) \
|
||||
do { \
|
||||
if (ok && !kDryRunFread && scale_for_compression) { \
|
||||
weights->scales[scale_pos++] = \
|
||||
ScaleWeights(layer_view->name.data(), layer_view->name.size()); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
template <typename TConfig>
|
||||
struct LoadRawWeightsT {
|
||||
ByteStorageT operator()(const Path& checkpoint, hwy::ThreadPool& pool,
|
||||
bool scale_for_compression) const {
|
||||
PROFILER_ZONE("Startup.LoadWeights");
|
||||
if (!checkpoint.Exists()) {
|
||||
HWY_ABORT("The model weights file '%s' does not exist.",
|
||||
checkpoint.path.c_str());
|
||||
}
|
||||
|
||||
ByteStorageT weights_u8 = AllocateWeights<TConfig>()(pool);
|
||||
auto* weights = reinterpret_cast<WeightsF<TConfig>*>(weights_u8.get());
|
||||
|
||||
size_t scale_pos = 0;
|
||||
FILE* fptr;
|
||||
if constexpr (kDryRunFread) {
|
||||
fprintf(stderr, "Dry-Run, not reading model-file.\n");
|
||||
} else {
|
||||
fptr = fopen(checkpoint.path.c_str(), "rb");
|
||||
if (fptr == nullptr) {
|
||||
HWY_ABORT("Failed to open model file %s - does it exist?",
|
||||
checkpoint.path.c_str());
|
||||
}
|
||||
}
|
||||
bool ok = true;
|
||||
uint64_t total_size = 0;
|
||||
auto do_fread = [&](void* var, int layer, const char* name, size_t size) {
|
||||
if (layer == -1) {
|
||||
fprintf(stderr, "Loading Parameters (size %zu): %s\n", size, name);
|
||||
} else {
|
||||
fprintf(stderr, "Loading Parameters (layer=%d, size %zu): %s\n", layer,
|
||||
size, name);
|
||||
}
|
||||
if constexpr (!kDryRunFread) {
|
||||
ok &= 1 == fread(var, size, 1, fptr);
|
||||
total_size += size;
|
||||
}
|
||||
};
|
||||
do_fread(&(weights->embedder_input_embedding), -1,
|
||||
"embedder_input_embedding",
|
||||
sizeof(weights->embedder_input_embedding));
|
||||
do_fread(&(weights->final_norm_scale), -1, "final_norm_scale",
|
||||
sizeof(weights->final_norm_scale));
|
||||
for (size_t layer = 0; layer < TConfig::kLayers; ++layer) {
|
||||
auto type = TConfig::kLayerConfig[layer];
|
||||
LayerF<TConfig>* layer_view = weights->GetLayer(layer);
|
||||
|
||||
// Make sure we don't have uninitialized memory.
|
||||
hwy::ZeroBytes(layer_view, sizeof(*layer_view));
|
||||
if (type == LayerAttentionType::kGemma) {
|
||||
READ_WEIGHTS(attn_vec_einsum_w);
|
||||
READ_WEIGHTS(qkv_einsum_w);
|
||||
SCALE_WEIGHTS(attn_vec_einsum_w);
|
||||
SCALE_WEIGHTS(qkv_einsum_w);
|
||||
} else {
|
||||
READ_WEIGHTS(griffin.linear_x_w);
|
||||
READ_WEIGHTS(griffin.linear_x_biases);
|
||||
READ_WEIGHTS(griffin.linear_y_w);
|
||||
READ_WEIGHTS(griffin.linear_y_biases);
|
||||
READ_WEIGHTS(griffin.linear_out_w);
|
||||
READ_WEIGHTS(griffin.linear_out_biases);
|
||||
READ_WEIGHTS(griffin.conv_w);
|
||||
READ_WEIGHTS(griffin.conv_biases);
|
||||
READ_WEIGHTS(griffin.gate_w);
|
||||
READ_WEIGHTS(griffin.gate_biases);
|
||||
READ_WEIGHTS(griffin.a);
|
||||
SCALE_WEIGHTS(griffin.linear_x_w);
|
||||
SCALE_WEIGHTS(griffin.linear_y_w);
|
||||
SCALE_WEIGHTS(griffin.linear_out_w);
|
||||
SCALE_WEIGHTS(griffin.gate_w);
|
||||
}
|
||||
READ_WEIGHTS(gating_einsum_w);
|
||||
READ_WEIGHTS(linear_w);
|
||||
SCALE_WEIGHTS(gating_einsum_w);
|
||||
SCALE_WEIGHTS(linear_w);
|
||||
READ_WEIGHTS(pre_attention_norm_scale);
|
||||
READ_WEIGHTS(pre_ffw_norm_scale);
|
||||
if (TConfig::kPostNormScale) {
|
||||
READ_WEIGHTS(post_attention_norm_scale);
|
||||
READ_WEIGHTS(post_ffw_norm_scale);
|
||||
}
|
||||
if (TConfig::kFFBiases) {
|
||||
READ_WEIGHTS(ffw_gating_biases);
|
||||
READ_WEIGHTS(ffw_output_biases);
|
||||
}
|
||||
if (TConfig::kSoftmaxAttnOutputBiases &&
|
||||
type == LayerAttentionType::kGemma) {
|
||||
READ_WEIGHTS(attention_output_biases);
|
||||
}
|
||||
}
|
||||
if (!ok) {
|
||||
HWY_ABORT(
|
||||
"Failed to read from %s - might be a directory, or too small? "
|
||||
"expected size: %d kB",
|
||||
checkpoint.path.c_str(), static_cast<uint32_t>(total_size >> 10));
|
||||
}
|
||||
if (!kDryRunFread) {
|
||||
HWY_ASSERT(0 == fclose(fptr));
|
||||
if (scale_for_compression) {
|
||||
HWY_ASSERT(scale_pos == TConfig::kNumTensorScales);
|
||||
}
|
||||
}
|
||||
return weights_u8;
|
||||
}
|
||||
};
|
||||
|
||||
#undef READ_WEIGHTS
|
||||
#undef SCALE_WEIGHTS
|
||||
} // namespace
|
||||
|
||||
// Loads uncompressed weights from the file `weights` by dispatching
// LoadRawWeightsT on the runtime `model`.
ByteStorageT LoadRawWeights(const Path& weights, Model model,
                            hwy::ThreadPool& pool, bool scale_for_compression) {
  ByteStorageT raw_weights = CallFunctorForModel<LoadRawWeightsT>(
      model, weights, pool, scale_for_compression);
  return raw_weights;
}
|
||||
|
||||
namespace {
|
||||
template <typename TConfig>
|
||||
void ZeroInitWeightsT(ByteStorageT& weights, hwy::ThreadPool& pool) {
|
||||
ZeroInit<float, TConfig>(
|
||||
*reinterpret_cast<Weights<float, TConfig>*>(weights.get()));
|
||||
template <class TConfig>
|
||||
struct LoadCompressedWeightsT {
|
||||
ByteStorageT operator()(const Path& weights, hwy::ThreadPool& pool) const {
|
||||
PROFILER_ZONE("Startup.LoadCompressedWeights");
|
||||
if (!weights.Exists()) {
|
||||
HWY_ABORT("The model weights file '%s' does not exist.",
|
||||
weights.path.c_str());
|
||||
}
|
||||
|
||||
// Allocate compressed weights.
|
||||
using CWeights = CompressedWeights<TConfig>;
|
||||
ByteStorageT c_weights_u8 = AllocateSizeof<CWeights>();
|
||||
CWeights* c_weights = reinterpret_cast<CWeights*>(c_weights_u8.get());
|
||||
new (&c_weights->c_layer_ptrs) CompressedLayerPointers<TConfig>(pool);
|
||||
|
||||
std::array<float, TConfig::kNumTensorScales> scales;
|
||||
CacheLoader loader(weights);
|
||||
ForEachTensor<TConfig>(nullptr, *c_weights, loader);
|
||||
loader.LoadScales(scales.data(), scales.size());
|
||||
if (!loader.ReadAll(pool)) {
|
||||
HWY_ABORT("Failed to load model weights.");
|
||||
}
|
||||
if (TConfig::kNumTensorScales > 0) {
|
||||
size_t scale_pos = 0;
|
||||
for (int layer_idx = 0; layer_idx < TConfig::kLayers; ++layer_idx) {
|
||||
auto type = TConfig::kLayerConfig[layer_idx];
|
||||
const size_t idx = static_cast<size_t>(layer_idx);
|
||||
CompressedLayer<TConfig>* layer_weights = c_weights->GetLayer(idx);
|
||||
if (type == LayerAttentionType::kGemma) {
|
||||
layer_weights->attn_vec_einsum_w.set_scale(scales[scale_pos++]);
|
||||
layer_weights->qkv_einsum_w.set_scale(scales[scale_pos++]);
|
||||
} else {
|
||||
layer_weights->griffin.linear_x_w.set_scale(scales[scale_pos++]);
|
||||
layer_weights->griffin.linear_y_w.set_scale(scales[scale_pos++]);
|
||||
layer_weights->griffin.linear_out_w.set_scale(scales[scale_pos++]);
|
||||
layer_weights->griffin.gate_w.set_scale(scales[scale_pos++]);
|
||||
}
|
||||
layer_weights->gating_einsum_w.set_scale(scales[scale_pos++]);
|
||||
layer_weights->linear_w.set_scale(scales[scale_pos++]);
|
||||
}
|
||||
HWY_ASSERT(scale_pos == TConfig::kNumTensorScales);
|
||||
}
|
||||
return c_weights_u8;
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
void ZeroInitWeights(Model model, ByteStorageT& weights,
|
||||
ByteStorageT LoadCompressedWeights(const Path& weights, Model model,
|
||||
hwy::ThreadPool& pool) {
|
||||
switch (model) {
|
||||
case Model::GEMMA_2B:
|
||||
ZeroInitWeightsT<ConfigGemma2B>(weights, pool);
|
||||
break;
|
||||
case Model::GEMMA_7B:
|
||||
ZeroInitWeightsT<ConfigGemma7B>(weights, pool);
|
||||
break;
|
||||
case Model::GRIFFIN_2B:
|
||||
ZeroInitWeightsT<ConfigGriffin2B>(weights, pool);
|
||||
break;
|
||||
case Model::GEMMA_TINY:
|
||||
ZeroInitWeightsT<ConfigGemmaTiny>(weights, pool);
|
||||
break;
|
||||
default:
|
||||
HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
|
||||
return CallFunctorForModel<LoadCompressedWeightsT>(model, weights, pool);
|
||||
}
|
||||
|
||||
// Loads weights for `model` from the file `weights`: raw (without scaling)
// when kWeightsAreCompressed is false, compressed otherwise.
ByteStorageT LoadWeights(const Path& weights, Model model,
                         hwy::ThreadPool& pool) {
  if constexpr (!kWeightsAreCompressed) {
    return LoadRawWeights(weights, model, pool,
                          /*scale_for_compression=*/false);
  } else {
    return LoadCompressedWeights(weights, model, pool);
  }
}
|
||||
|
||||
|
|
@ -86,27 +263,19 @@ class WeightLogger {
|
|||
};
|
||||
|
||||
template <typename TConfig>
|
||||
void LogWeightStats(const ByteStorageT& weights_u8) {
|
||||
const auto& weights = *reinterpret_cast<WeightsF<TConfig>*>(weights_u8.get());
|
||||
struct LogWeightStatsT {
|
||||
void operator()(const ByteStorageT& weights_u8) const {
|
||||
const auto& weights =
|
||||
*reinterpret_cast<WeightsF<TConfig>*>(weights_u8.get());
|
||||
WeightLogger logger;
|
||||
ForEachTensor1<float, TConfig>(logger, weights);
|
||||
printf("%-20s %12zu\n", "Total", logger.total_weights);
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
void LogWeightStats(gcpp::Model model, const ByteStorageT& weights) {
|
||||
switch (model) {
|
||||
case Model::GEMMA_2B:
|
||||
return LogWeightStats<ConfigGemma2B>(weights);
|
||||
case Model::GEMMA_7B:
|
||||
return LogWeightStats<ConfigGemma7B>(weights);
|
||||
case Model::GRIFFIN_2B:
|
||||
return LogWeightStats<ConfigGriffin2B>(weights);
|
||||
case Model::GEMMA_TINY:
|
||||
return LogWeightStats<ConfigGemmaTiny>(weights);
|
||||
default:
|
||||
HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
|
||||
}
|
||||
CallFunctorForModel<LogWeightStatsT>(model, weights);
|
||||
}
|
||||
|
||||
} // namespace gcpp
|
||||
|
|
|
|||
313
gemma/weights.h
313
gemma/weights.h
|
|
@ -16,13 +16,21 @@
|
|||
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_WEIGHTS_H_
|
||||
#define THIRD_PARTY_GEMMA_CPP_GEMMA_WEIGHTS_H_
|
||||
|
||||
#include "compression/compress.h"
|
||||
#include "gemma/common.h"
|
||||
#include "gemma/configs.h"
|
||||
#include "hwy/aligned_allocator.h"
|
||||
#include "hwy/base.h"
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
// Setting this to false will load and use uncompressed weights.
|
||||
constexpr bool kWeightsAreCompressed = true;
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// Uncompressed
|
||||
|
||||
template <typename T, class TConfig>
|
||||
struct Layer {
|
||||
Layer() {}
|
||||
|
|
@ -118,14 +126,261 @@ struct Weights {
|
|||
// Shorthand for Weights instantiated with float elements.
template <class TConfig>
|
||||
using WeightsF = Weights<float, TConfig>;
|
||||
|
||||
template <typename T, typename TConfig>
|
||||
ByteStorageT AllocateWeights(hwy::ThreadPool& pool) {
|
||||
using TWeights = Weights<T, TConfig>;
|
||||
ByteStorageT weights_u8 = hwy::AllocateAligned<uint8_t>(sizeof(TWeights));
|
||||
// ----------------------------------------------------------------------------
|
||||
// Compressed
|
||||
|
||||
template <class TConfig>
|
||||
struct CompressedLayer {
|
||||
// No ctor/dtor, allocated via AllocateAligned.
|
||||
|
||||
using TLayer = gcpp::LayerF<TConfig>;
|
||||
using WeightT = typename TConfig::WeightT;
|
||||
|
||||
static constexpr size_t kHeads = TLayer::kHeads;
|
||||
static constexpr size_t kKVHeads = TLayer::kKVHeads;
|
||||
static constexpr size_t kModelDim = TLayer::kModelDim;
|
||||
static constexpr size_t kQKVDim = TLayer::kQKVDim;
|
||||
static constexpr size_t kFFHiddenDim = TLayer::kFFHiddenDim;
|
||||
static constexpr size_t kAttVecEinsumWSize = TLayer::kAttVecEinsumWSize;
|
||||
static constexpr size_t kQKVEinsumWSize = TLayer::kQKVEinsumWSize;
|
||||
static constexpr size_t kGatingEinsumWSize = TLayer::kGatingEinsumWSize;
|
||||
static constexpr size_t kConv1dWidth = TLayer::kConv1dWidth;
|
||||
static constexpr bool kFFBiases = TLayer::kFFBiases;
|
||||
static constexpr bool kPostNormScale = TConfig::kPostNormScale;
|
||||
static constexpr size_t kAOBiasDim = TLayer::kAOBiasDim;
|
||||
static constexpr size_t kGriffinDim = TLayer::kGriffinDim;
|
||||
|
||||
// Compressed Parameters
|
||||
|
||||
template <class T, size_t N>
|
||||
using ArrayT = CompressedArray<T, N>;
|
||||
|
||||
union {
|
||||
struct {
|
||||
ArrayT<WeightT, kAttVecEinsumWSize> attn_vec_einsum_w;
|
||||
ArrayT<WeightT, kQKVEinsumWSize> qkv_einsum_w;
|
||||
ArrayT<float, kAOBiasDim> attention_output_biases;
|
||||
};
|
||||
|
||||
struct {
|
||||
ArrayT<WeightT, kGriffinDim * kGriffinDim> linear_x_w;
|
||||
ArrayT<float, kGriffinDim> linear_x_biases;
|
||||
ArrayT<WeightT, kGriffinDim * kGriffinDim> linear_y_w;
|
||||
ArrayT<float, kGriffinDim> linear_y_biases;
|
||||
ArrayT<WeightT, kGriffinDim * kGriffinDim> linear_out_w;
|
||||
ArrayT<float, kGriffinDim> linear_out_biases;
|
||||
ArrayT<float, TConfig::kConv1dWidth * kGriffinDim> conv_w;
|
||||
ArrayT<float, kGriffinDim> conv_biases;
|
||||
ArrayT<WeightT, kGriffinDim * kGriffinDim / kHeads * 2> gate_w;
|
||||
ArrayT<float, kGriffinDim * 2> gate_biases;
|
||||
ArrayT<float, kGriffinDim> a;
|
||||
} griffin;
|
||||
};
|
||||
|
||||
ArrayT<WeightT, TLayer::kGatingEinsumWSize> gating_einsum_w;
|
||||
ArrayT<WeightT, kModelDim * kFFHiddenDim> linear_w;
|
||||
// We don't yet have an RMSNorm that accepts all WeightT.
|
||||
ArrayT<hwy::bfloat16_t, kModelDim> pre_attention_norm_scale;
|
||||
ArrayT<hwy::bfloat16_t, kModelDim> pre_ffw_norm_scale;
|
||||
ArrayT<hwy::bfloat16_t, kPostNormScale ? kModelDim : 0>
|
||||
post_attention_norm_scale;
|
||||
ArrayT<hwy::bfloat16_t, kPostNormScale ? kModelDim : 0> post_ffw_norm_scale;
|
||||
|
||||
ArrayT<float, kFFBiases ? 2 * kFFHiddenDim : 0> ffw_gating_biases;
|
||||
ArrayT<float, kFFBiases ? kModelDim : 0> ffw_output_biases;
|
||||
};
|
||||
|
||||
// Array instead of single large allocation for parallel mem init. Split out
|
||||
// of CompressedWeights so that only these pointers are initialized, not the
|
||||
// CompressedArray.
|
||||
template <class TConfig>
|
||||
struct CompressedLayerPointers {
|
||||
explicit CompressedLayerPointers(hwy::ThreadPool& pool) {
|
||||
pool.Run(0, TConfig::kLayers, [this](uint64_t task, size_t /*thread*/) {
|
||||
this->c_layers[task] = hwy::AllocateAligned<CompressedLayer<TConfig>>(1);
|
||||
});
|
||||
}
|
||||
|
||||
using CLayer = CompressedLayer<TConfig>;
|
||||
std::array<hwy::AlignedFreeUniquePtr<CLayer[]>, TConfig::kLayers> c_layers;
|
||||
};
|
||||
|
||||
template <class TConfig>
|
||||
struct CompressedWeights {
|
||||
// No ctor/dtor, allocated via AllocateAligned.
|
||||
|
||||
CompressedArray<EmbedderInputT, TConfig::kVocabSize * TConfig::kModelDim>
|
||||
embedder_input_embedding;
|
||||
|
||||
CompressedArray<hwy::bfloat16_t, TConfig::kModelDim> final_norm_scale;
|
||||
|
||||
// Must be last so that the other arrays remain aligned.
|
||||
CompressedLayerPointers<TConfig> c_layer_ptrs;
|
||||
|
||||
const CompressedLayer<TConfig>* GetLayer(size_t layer) const {
|
||||
return c_layer_ptrs.c_layers[layer].get();
|
||||
}
|
||||
CompressedLayer<TConfig>* GetLayer(size_t layer) {
|
||||
return c_layer_ptrs.c_layers[layer].get();
|
||||
}
|
||||
};
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// Interface
|
||||
|
||||
// Weight type in use: CompressedWeights<TConfig> if kWeightsAreCompressed,
// otherwise WeightsF<TConfig>.
template <class TConfig>
|
||||
using WeightsT = hwy::If<kWeightsAreCompressed, CompressedWeights<TConfig>,
|
||||
WeightsF<TConfig>>;
|
||||
|
||||
// Call via CallFunctorForModel.
|
||||
template <typename TConfig>
|
||||
struct AllocateWeights {
|
||||
ByteStorageT operator()(hwy::ThreadPool& pool) const {
|
||||
using TWeights = WeightsF<TConfig>;
|
||||
ByteStorageT weights_u8 = AllocateSizeof<TWeights>();
|
||||
TWeights* weights = reinterpret_cast<TWeights*>(weights_u8.get());
|
||||
new (&weights->layer_ptrs) LayerPointers<T, TConfig>(pool);
|
||||
new (&weights->layer_ptrs) LayerPointers<float, TConfig>(pool);
|
||||
return weights_u8;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename TConfig>
|
||||
struct ZeroInitWeights {
|
||||
void operator()(ByteStorageT& weights, hwy::ThreadPool& pool) const {
|
||||
WeightsF<TConfig>& w = *reinterpret_cast<WeightsF<TConfig>*>(weights.get());
|
||||
hwy::ZeroBytes(&w.embedder_input_embedding,
|
||||
sizeof(w.embedder_input_embedding));
|
||||
hwy::ZeroBytes(&w.final_norm_scale, sizeof(w.final_norm_scale));
|
||||
for (int i = 0; i < TConfig::kLayers; ++i) {
|
||||
hwy::ZeroBytes(w.GetLayer(i), sizeof(*w.GetLayer(i)));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename TConfig>
|
||||
struct CopyWeights {
|
||||
void operator()(WeightsF<TConfig>& dst, const WeightsF<TConfig>& src) const {
|
||||
hwy::CopyBytes(&src.embedder_input_embedding, &dst.embedder_input_embedding,
|
||||
sizeof(src.embedder_input_embedding));
|
||||
hwy::CopyBytes(&src.final_norm_scale, &dst.final_norm_scale,
|
||||
sizeof(src.final_norm_scale));
|
||||
for (int i = 0; i < TConfig::kLayers; ++i) {
|
||||
hwy::CopyBytes(src.GetLayer(i), dst.GetLayer(i),
|
||||
sizeof(*dst.GetLayer(i)));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <class TConfig>
|
||||
struct DeleteLayersPtrs {
|
||||
void operator()(ByteStorageT& weights_u8) const {
|
||||
auto* weights = reinterpret_cast<WeightsT<TConfig>*>(weights_u8.get());
|
||||
if constexpr (kWeightsAreCompressed) {
|
||||
weights->c_layer_ptrs.~CompressedLayerPointers<TConfig>();
|
||||
} else {
|
||||
weights->layer_ptrs.~LayerPointers<float, TConfig>();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Owns weights and provides access to TConfig.
|
||||
template <typename T, typename TConfig>
|
||||
class WeightsWrapper {
|
||||
public:
|
||||
WeightsWrapper()
|
||||
: pool_(0),
|
||||
data_(AllocateWeights<TConfig>(pool_)),
|
||||
weights_(reinterpret_cast<Weights<T, TConfig>*>(data_.get())) {}
|
||||
|
||||
const Weights<T, TConfig>& get() const { return *weights_; }
|
||||
Weights<T, TConfig>& get() { return *weights_; }
|
||||
void clear() { ZeroInitWeights<TConfig>()(get()); }
|
||||
void copy(const WeightsWrapper<T, TConfig>& other) {
|
||||
CopyWeights<TConfig>()(get(), other.get());
|
||||
}
|
||||
|
||||
private:
|
||||
hwy::ThreadPool pool_;
|
||||
ByteStorageT data_;
|
||||
Weights<T, TConfig>* weights_;
|
||||
};
|
||||
|
||||
// For use by compress_weights.cc.
|
||||
ByteStorageT LoadRawWeights(const Path& weights, Model model,
|
||||
hwy::ThreadPool& pool, bool scale_for_compression);
|
||||
|
||||
// For gemma.cc; calls LoadRawWeights if !kWeightsAreCompressed.
|
||||
ByteStorageT LoadWeights(const Path& weights, Model model,
|
||||
hwy::ThreadPool& pool);
|
||||
|
||||
void LogWeightStats(Model model, const ByteStorageT& weights);
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// Iterators
|
||||
|
||||
#define GEMMA_CALL_FUNC(name, member) \
|
||||
snprintf(name_buf, sizeof(name_buf), name "_%d", layer_idx); \
|
||||
func(name_buf, layer ? layer->member.data() : nullptr, layer_weights->member)
|
||||
|
||||
// Calls func(name, float*, CompressedArray&) for each tensor. float* is null
|
||||
// if weights = null, which happens during the first call where we attempt to
|
||||
// load from cache.
|
||||
//
|
||||
// This avoids repeating the list of tensors between loading and compressing.
|
||||
template <class TConfig, class Func>
|
||||
void ForEachTensor(const WeightsF<TConfig>* weights,
|
||||
CompressedWeights<TConfig>& c_weights, Func& func) {
|
||||
func("c_embedding",
|
||||
weights ? weights->embedder_input_embedding.data() : nullptr,
|
||||
c_weights.embedder_input_embedding);
|
||||
func("c_final_norm", weights ? weights->final_norm_scale.data() : nullptr,
|
||||
c_weights.final_norm_scale);
|
||||
|
||||
char name_buf[16];
|
||||
for (int layer_idx = 0; layer_idx < TConfig::kLayers; ++layer_idx) {
|
||||
auto type = TConfig::kLayerConfig[layer_idx];
|
||||
const size_t idx = static_cast<size_t>(layer_idx);
|
||||
const LayerF<TConfig>* layer = weights ? weights->GetLayer(idx) : nullptr;
|
||||
CompressedLayer<TConfig>* layer_weights = c_weights.GetLayer(idx);
|
||||
|
||||
GEMMA_CALL_FUNC("pre_ff_ns", pre_ffw_norm_scale);
|
||||
GEMMA_CALL_FUNC("gating_ein", gating_einsum_w);
|
||||
GEMMA_CALL_FUNC("linear_w", linear_w);
|
||||
if (type == LayerAttentionType::kGemma) {
|
||||
GEMMA_CALL_FUNC("qkv_ein", qkv_einsum_w);
|
||||
GEMMA_CALL_FUNC("att_ein", attn_vec_einsum_w);
|
||||
} else {
|
||||
GEMMA_CALL_FUNC("gr_lin_x_w", griffin.linear_x_w);
|
||||
GEMMA_CALL_FUNC("gr_lin_x_b", griffin.linear_x_biases);
|
||||
GEMMA_CALL_FUNC("gr_lin_y_w", griffin.linear_y_w);
|
||||
GEMMA_CALL_FUNC("gr_lin_y_b", griffin.linear_y_biases);
|
||||
GEMMA_CALL_FUNC("gr_lin_out_w", griffin.linear_out_w);
|
||||
GEMMA_CALL_FUNC("gr_lin_out_b", griffin.linear_out_biases);
|
||||
GEMMA_CALL_FUNC("gr_conv_w", griffin.conv_w);
|
||||
GEMMA_CALL_FUNC("gr_conv_b", griffin.conv_biases);
|
||||
GEMMA_CALL_FUNC("gr_gate_w", griffin.gate_w);
|
||||
GEMMA_CALL_FUNC("gr_gate_b", griffin.gate_biases);
|
||||
GEMMA_CALL_FUNC("gr_a", griffin.a);
|
||||
}
|
||||
GEMMA_CALL_FUNC("pre_att_ns", pre_attention_norm_scale);
|
||||
if (TConfig::kPostNormScale) {
|
||||
GEMMA_CALL_FUNC("post_att_ns", post_attention_norm_scale);
|
||||
GEMMA_CALL_FUNC("post_ff_ns", post_ffw_norm_scale);
|
||||
}
|
||||
|
||||
if (TConfig::kFFBiases) {
|
||||
GEMMA_CALL_FUNC("ffw_gat_b", ffw_gating_biases);
|
||||
GEMMA_CALL_FUNC("ffw_out_b", ffw_output_biases);
|
||||
}
|
||||
|
||||
if (TConfig::kSoftmaxAttnOutputBiases &&
|
||||
type == LayerAttentionType::kGemma) {
|
||||
GEMMA_CALL_FUNC("attn_ob", attention_output_biases);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#undef GEMMA_CALL_FUNC
|
||||
|
||||
#define GEMMA_CALL_TOP_FUNC1(name, member) func(name, weights1.member)
|
||||
#define GEMMA_CALL_TOP_FUNC2(name, member) \
|
||||
|
|
@ -237,54 +492,6 @@ void ForEachTensor2(Func& func, const Weights<T, TConfig>& weights1,
|
|||
#undef GEMMA_CALL_LAYER_FUNC4
|
||||
#undef GEMMA_CALL_ALL_LAYER_FUNC
|
||||
|
||||
template<typename T, typename TConfig>
|
||||
void ZeroInit(Weights<T, TConfig>& w) {
|
||||
hwy::ZeroBytes(&w.embedder_input_embedding,
|
||||
sizeof(w.embedder_input_embedding));
|
||||
hwy::ZeroBytes(&w.final_norm_scale, sizeof(w.final_norm_scale));
|
||||
for (int i = 0; i < TConfig::kLayers; ++i) {
|
||||
hwy::ZeroBytes(w.GetLayer(i), sizeof(*w.GetLayer(i)));
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T, typename TConfig>
|
||||
void Copy(Weights<T, TConfig>& dst, const Weights<T, TConfig>& src) {
|
||||
hwy::CopyBytes(&src.embedder_input_embedding, &dst.embedder_input_embedding,
|
||||
sizeof(src.embedder_input_embedding));
|
||||
hwy::CopyBytes(&src.final_norm_scale, &dst.final_norm_scale,
|
||||
sizeof(src.final_norm_scale));
|
||||
for (int i = 0; i < TConfig::kLayers; ++i) {
|
||||
hwy::CopyBytes(src.GetLayer(i), dst.GetLayer(i), sizeof(*dst.GetLayer(i)));
|
||||
}
|
||||
}
|
||||
|
||||
// Owns weights and undoes the type erasure of AllocateWeights.
|
||||
template<typename T, typename TConfig>
|
||||
class WeightsWrapper {
|
||||
public:
|
||||
WeightsWrapper()
|
||||
: pool_(0), data_(AllocateWeights<T, TConfig>(pool_)),
|
||||
weights_(reinterpret_cast<Weights<T, TConfig>*>(data_.get())) {}
|
||||
|
||||
const Weights<T, TConfig>& get() const { return *weights_; }
|
||||
Weights<T, TConfig>& get() { return *weights_; }
|
||||
void clear() { ZeroInit(get()); }
|
||||
void copy(const WeightsWrapper<T, TConfig>& other) {
|
||||
Copy(get(), other.get());
|
||||
}
|
||||
|
||||
private:
|
||||
hwy::ThreadPool pool_;
|
||||
ByteStorageT data_;
|
||||
Weights<T, TConfig>* weights_;
|
||||
};
|
||||
|
||||
ByteStorageT AllocateWeights(Model model, hwy::ThreadPool& pool);
|
||||
|
||||
void ZeroInitWeights(Model model, ByteStorageT& weights, hwy::ThreadPool& pool);
|
||||
|
||||
void LogWeightStats(Model model, const ByteStorageT& weights);
|
||||
|
||||
} // namespace gcpp
|
||||
|
||||
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_WEIGHTS_H_
|
||||
|
|
|
|||
Loading…
Reference in New Issue