Simplifications: remove GemmaInterface and GemmaImpl

Split common and weights into separate lib
Remove common-inl (does not have to be SIMD code), activations.cc
Centralize switch(Model) to avoid duplication
Move CompressWeightsT to compress_weights.cc
Move LoadWeights to weights.cc

PiperOrigin-RevId: 640869202
Jan Wassenberg 2024-06-06 05:53:54 -07:00 committed by Copybara-Service
parent 5c3e5f7038
commit 57c2cd8b52
34 changed files with 1164 additions and 1425 deletions
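A minimal sketch (not a file in this diff) of the centralized dispatch mentioned above: `CallFunctorForModel`, added in gemma/common.h below, owns the single `switch (Model)`, and call sites pass a functor templated on the per-model config. `AllocateWeights` is assumed here to be the functor used in gemma/weights.cc and optimize_test.cc in this diff.

```cpp
// Sketch only: one switch(Model) now lives inside CallFunctorForModel, so
// call sites no longer duplicate it per operation.
#include "gemma/common.h"   // CallFunctorForModel, Model, ByteStorageT
#include "gemma/weights.h"  // AllocateWeights functor (assumed)
#include "hwy/contrib/thread_pool/thread_pool.h"

namespace gcpp {

ByteStorageT AllocateWeightsFor(Model model, hwy::ThreadPool& pool) {
  // Replaces the removed per-site switch over GEMMA_2B/7B/GRIFFIN_2B/TINY.
  return CallFunctorForModel<AllocateWeights>(model, pool);
}

}  // namespace gcpp
```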


@ -53,6 +53,36 @@ cc_test(
],
)
cc_library(
name = "common",
srcs = ["gemma/common.cc"],
hdrs = [
"gemma/common.h",
"gemma/configs.h",
],
deps = [
"//compression:compress",
"//compression:io",
"@hwy//:hwy", # base.h
"@hwy//:thread_pool",
],
)
cc_library(
name = "weights",
srcs = ["gemma/weights.cc"],
hdrs = ["gemma/weights.h"],
deps = [
":common",
"//compression:compress",
"//compression:io",
"@hwy//:hwy",
"@hwy//:profiler",
"@hwy//:stats",
"@hwy//:thread_pool",
],
)
cc_library(
name = "gemma_lib",
srcs = [
@ -60,14 +90,12 @@ cc_library(
],
hdrs = [
"gemma/activations.h",
"gemma/common.h",
"gemma/common-inl.h",
"gemma/configs.h",
"gemma/gemma.h",
"gemma/weights.h",
],
deps = [
":common",
":ops",
":weights",
"//compression:compress",
"//compression:io",
"@hwy//:hwy",
@ -93,6 +121,7 @@ cc_library(
hdrs = ["util/app.h"],
deps = [
":args",
":common",
":gemma_lib",
"//compression:io",
"@hwy//:hwy",
@ -126,6 +155,7 @@ cc_binary(
deps = [
":app",
":args",
":common",
":gemma_lib",
# "//base",
"//compression:compress",
@ -141,7 +171,9 @@ cc_binary(
srcs = ["gemma/compress_weights.cc"],
deps = [
":args",
":common",
":gemma_lib",
":weights",
# "//base",
"//compression:compress",
"@hwy//:hwy",
@ -157,7 +189,9 @@ cc_binary(
deps = [
":app",
":args",
":common",
":gemma_lib",
"//compression:io",
"@hwy//:hwy",
"@hwy//:nanobenchmark",
"@hwy//:thread_pool",


@ -50,8 +50,6 @@ set(SOURCES
backprop/backward.h
backprop/backward-inl.h
backprop/backward_scalar.h
backprop/common_scalar.cc
backprop/common_scalar.h
backprop/forward.cc
backprop/forward.h
backprop/forward-inl.h
@ -59,10 +57,9 @@ set(SOURCES
backprop/optimizer.cc
backprop/optimizer.h
gemma/configs.h
gemma/activations.cc
gemma/activations.h
gemma/common.cc
gemma/common.h
gemma/common-inl.h
gemma/gemma.cc
gemma/gemma.h
gemma/ops.h


@ -151,9 +151,9 @@ file (usually `tokenizer.spm`), call `Encode()` to go from string prompts to
token id vectors, or `Decode()` to go from token id vector outputs from the
model back to strings.
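For example, a minimal round trip might look like the following sketch, assuming a constructed `gcpp::Gemma` named `model` (with this change, `Tokenizer()` returns a reference rather than a pointer):

```cpp
std::vector<int> ids;
HWY_ASSERT(model.Tokenizer().Encode("Write a poem", &ids));  // string -> token ids

std::string text;
HWY_ASSERT(model.Tokenizer().Decode(ids, &text));  // token ids -> string
```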
### `GenerateGemma()` is the entrypoint for token generation
### `model.Generate()` is the entrypoint for token generation
Calling into `GenerateGemma` with a tokenized prompt will 1) mutate the
Calling into `model.Generate` with a tokenized prompt will 1) mutate the
activation values in `model` and 2) invoke StreamFunc - a lambda callback for
each generated token.
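A minimal sketch of the new call pattern, assuming the `Gemma`, `KVCache`, and `RuntimeConfig` declarations shown later in this diff; the tokenizer/weights paths, model type, and thread pool are placeholders:

```cpp
gcpp::Gemma model(tokenizer_path, weights_path, model_type, pool);
gcpp::KVCache kv_cache = gcpp::KVCache::Create(model_type);

std::vector<int> prompt;
HWY_ASSERT(model.Tokenizer().Encode("Hello", &prompt));

gcpp::RuntimeConfig runtime_config = {
    .max_tokens = 256,
    .max_generated_tokens = 128,
    .stream_token = [](int /*token*/, float /*prob*/) { return true; },
    .accept_token = [](int) { return true; },
};
gcpp::TimingInfo timing_info;
// Mutates activations in `model`, fills `kv_cache`, and invokes stream_token
// once per generated token.
model.Generate(runtime_config, prompt, /*start_pos=*/0, kv_cache, timing_info);
```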
@ -170,11 +170,11 @@ no-op which is what `run.cc` does.
### `Transformer()` implements the inference (i.e. `forward()` method in PyTorch or Jax) computation of the neural network
For high-level applications, you might only call `GenerateGemma()` and never
For high-level applications, you might only call `model.Generate()` and never
interact directly with the neural network, but if you're doing something a bit
more custom you can call transformer which performs a single inference
operation on a single token and mutates the Activations and the KVCache through
the neural network computation.
more custom you can call transformer which performs a single inference operation
on a single token and mutates the Activations and the KVCache through the neural
network computation.
### For low level operations, defining new architectures, call `ops.h` functions directly


@ -28,8 +28,8 @@
#include "backprop/prompt.h"
#include "gemma/activations.h"
#include "gemma/common.h"
#include "gemma/weights.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
@ -43,7 +43,6 @@
#define THIRD_PARTY_GEMMA_CPP_BACKWARD_TOGGLE
#endif
#include "gemma/common-inl.h"
#include "gemma/ops.h"
HWY_BEFORE_NAMESPACE();


@ -54,6 +54,7 @@ void CrossEntropyLossBackwardPassT(Model model,
ByteStorageT& grad,
ByteStorageT& backward,
hwy::ThreadPool& pool) {
// TODO(janwas): use CallFunctorForModel
switch (model) {
case Model::GEMMA_2B:
CrossEntropyLossBackwardPass<ConfigGemma2B>(


@ -23,9 +23,9 @@
#include <complex>
#include <vector>
#include "backprop/common_scalar.h"
#include "backprop/prompt.h"
#include "gemma/activations.h"
#include "gemma/common.h" // EmbeddingScaling
#include "gemma/weights.h"
namespace gcpp {


@ -19,10 +19,10 @@
#include <complex>
#include <random>
#include "gtest/gtest.h"
#include "backprop/forward_scalar.h"
#include "backprop/sampler.h"
#include "backprop/test_util.h"
#include "gtest/gtest.h"
namespace gcpp {


@ -30,11 +30,11 @@
#include "backprop/sampler.h"
#include "backprop/test_util.h"
#include "compression/compress.h"
#include "gemma/gemma.h"
#include "gemma/weights.h"
#include "hwy/aligned_allocator.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "gemma/gemma.h"
#include "gemma/weights.h"
// clang-format off
#undef HWY_TARGET_INCLUDE


@ -1,51 +0,0 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Compiles this file for multiple architectures via "foreach_target.h", to
// which we pass the filename via macro 'argument'.
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE "backprop/common_scalar.cc" // NOLINT
#include "hwy/foreach_target.h" // IWYU pragma: keep
#include "hwy/highway.h"
// After highway.h
#include "gemma/ops.h"
HWY_BEFORE_NAMESPACE();
namespace gcpp {
namespace HWY_NAMESPACE {
namespace hn = hwy::HWY_NAMESPACE;
float EmbeddingScaling(int model_dim) {
// Round to bf16 to match Gemma's Embedder, which casts before mul.
return hwy::ConvertScalarTo<float>(hwy::ConvertScalarTo<hwy::bfloat16_t>(
Sqrt(static_cast<float>(model_dim))));
}
} // namespace HWY_NAMESPACE
} // namespace gcpp
HWY_AFTER_NAMESPACE();
#if HWY_ONCE
namespace gcpp {
HWY_EXPORT(EmbeddingScaling);
float EmbeddingScaling(int model_dim) {
return HWY_DYNAMIC_DISPATCH(EmbeddingScaling)(model_dim);
}
} // namespace gcpp
#endif // HWY_ONCE


@ -1,119 +0,0 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_SCALAR_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_SCALAR_H_
#include <complex>
namespace gcpp {
template<typename T, typename U>
U DotT(const T* a, const U* b, size_t N) {
U sum = {};
for (size_t i = 0; i < N; ++i) {
sum += a[i] * b[i];
}
return sum;
}
template<>
std::complex<double> DotT(const float* a, const std::complex<double>* b,
size_t N) {
std::complex<double> sum = {};
for (size_t i = 0; i < N; ++i) {
sum += static_cast<double>(a[i]) * b[i];
}
return sum;
}
template<typename T>
void MulByConstT(T c, T* x, size_t N) {
for (size_t i = 0; i < N; ++i) {
x[i] *= c;
}
}
// out += c * x
template<typename T>
void MulByConstAndAddT(T c, const T* x, T* out, size_t N) {
for (size_t i = 0; i < N; ++i) {
out[i] += c * x[i];
}
}
template<typename T, size_t N>
void MulByConstAndAddT(T c, const std::array<T, N>& x, std::array<T, N>& out) {
MulByConstAndAddT(c, x.data(), out.data(), N);
}
template<typename T>
void AddFromT(const T* a, T* out, size_t N) {
for (size_t i = 0; i < N; ++i) {
out[i] += a[i];
}
}
template<typename T>
T SquaredL2(const T* x, size_t N) {
T sum = {};
for (size_t i = 0; i < N; ++i) {
sum += x[i] * x[i];
}
return sum;
}
template<typename T>
T Gelu(T x) {
static const T kMul = 0.044715;
static const T kSqrt2OverPi = 0.797884560804236;
const T x3 = x * x * x;
const T arg = kSqrt2OverPi * (kMul * x3 + x);
const T cdf = T(0.5) * (T(1.0) + std::tanh(arg));
return x * cdf;
}
template<typename T, typename U>
void Rope(T* x, U base, size_t N, int i) {
const size_t N2 = N / 2;
for (size_t dim = 0; dim < N2; ++dim) {
const T freq_exponents = T(2 * dim) / T(N);
const T timescale = std::pow(base, freq_exponents);
const T theta = T(i) / timescale;
const T cos_val = std::cos(theta);
const T sin_val = std::sin(theta);
const T x0 = x[dim];
const T x1 = x[dim + N2];
x[dim] = x0 * cos_val - x1 * sin_val;
x[dim + N2] = x0 * sin_val + x1 * cos_val;
}
}
template<typename T>
void Rope(T* x, size_t N, int i) {
Rope(x, T(10000.0), N, i);
}
template<typename T>
void Rope(std::complex<T>* x, size_t N, int i) {
Rope(x, T(10000.0), N, i);
}
float EmbeddingScaling(int model_dim);
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_SCALAR_H_


@ -24,8 +24,8 @@
#include <cmath>
#include "gemma/activations.h"
#include "gemma/common.h"
#include "gemma/configs.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
@ -39,7 +39,6 @@
#define THIRD_PARTY_GEMMA_CPP_FORWARD_TOGGLE
#endif
#include "gemma/common-inl.h"
#include "gemma/ops.h"
HWY_BEFORE_NAMESPACE();


@ -48,6 +48,7 @@ float CrossEntropyLossForwardPassT(Model model, const Prompt& prompt,
const ByteStorageT& weights,
ByteStorageT& forward,
hwy::ThreadPool& pool) {
// TODO(janwas): use CallFunctorForModel
switch (model) {
case Model::GEMMA_2B:
return CrossEntropyLossForwardPass<ConfigGemma2B>(


@ -23,9 +23,9 @@
#include <complex>
#include <vector>
#include "backprop/common_scalar.h"
#include "backprop/prompt.h"
#include "gemma/activations.h"
#include "gemma/common.h" // EmbeddingScaling
#include "gemma/weights.h"
namespace gcpp {


@ -16,6 +16,7 @@
#include <iostream>
#include <string>
#include "gtest/gtest.h"
#include "backprop/backward.h"
#include "backprop/forward.h"
#include "backprop/optimizer.h"
@ -23,7 +24,6 @@
#include "gemma/activations.h"
#include "gemma/gemma.h"
#include "gemma/weights.h"
#include "gtest/gtest.h"
namespace gcpp {
@ -32,20 +32,20 @@ TEST(OptimizeTest, GradientDescent) {
std::mt19937 gen(42);
Model model_type = Model::GEMMA_TINY;
ByteStorageT weights = AllocateWeights(model_type, pool);
ByteStorageT grad = AllocateWeights(model_type, pool);
ByteStorageT grad_m = AllocateWeights(model_type, pool);
ByteStorageT grad_v = AllocateWeights(model_type, pool);
ByteStorageT forward = AllocateForwardPass(model_type);
ByteStorageT backward = AllocateForwardPass(model_type);
ByteStorageT inference = AllocateInferenceState(model_type);
auto kv_cache = CreateKVCache(model_type);
ByteStorageT grad = CallFunctorForModel<AllocateWeights>(model_type, pool);
ByteStorageT grad_m = CallFunctorForModel<AllocateWeights>(model_type, pool);
ByteStorageT grad_v = CallFunctorForModel<AllocateWeights>(model_type, pool);
ByteStorageT forward = CallFunctorForModel<AllocateForwardPass>(model_type);
ByteStorageT backward = CallFunctorForModel<AllocateForwardPass>(model_type);
KVCache kv_cache = KVCache::Create(model_type);
size_t max_tokens = 32;
size_t max_generated_tokens = 16;
float temperature = 1.0f;
int verbosity = 0;
const auto accept_token = [](int) { return true; };
Gemma gemma(GemmaTokenizer(), model_type, pool);
const auto generate = [&](const std::vector<int>& prompt) {
std::vector<int> reply;
auto stream_token = [&reply](int token, float) {
@ -57,8 +57,7 @@ TEST(OptimizeTest, GradientDescent) {
stream_token, accept_token, ReverseSequenceSampler::kEndToken,
};
TimingInfo timing_info;
GenerateGemma(model_type, weights, inference, runtime, prompt, 0,
kv_cache, pool, timing_info);
model.Generate(runtime, prompt, 0, kv_cache, timing_info);
return reply;
};
@ -74,12 +73,12 @@ TEST(OptimizeTest, GradientDescent) {
return ok;
};
RandInitWeights(model_type, weights, pool, gen);
ZeroInitWeights(model_type, grad_m, pool);
ZeroInitWeights(model_type, grad_v, pool);
CallFunctorForModel<RandInitWeights>(model_type, gemma.Weights(), pool, gen);
CallFunctorForModel<ZeroInitWeights>(model_type, grad_m, pool);
CallFunctorForModel<ZeroInitWeights>(model_type, grad_v, pool);
printf("Initial weights:\n");
LogWeightStats(model_type, weights);
LogWeightStats(model_type, gemma.Weights());
constexpr size_t kBatchSize = 8;
float learning_rate = 0.0005f;
@ -91,21 +90,21 @@ TEST(OptimizeTest, GradientDescent) {
size_t num_ok;
for (; steps < 1000000; ++steps) {
std::mt19937 sgen(42);
ZeroInitWeights(model_type, grad, pool);
CallFunctorForModel<ZeroInitWeights>(model_type, grad, pool);
float total_loss = 0.0f;
num_ok = 0;
for (size_t i = 0; i < kBatchSize; ++i) {
Prompt prompt = training_task.Sample(sgen);
total_loss += CrossEntropyLossForwardPass(
model_type, prompt, weights, forward, pool);
CrossEntropyLossBackwardPass(
model_type, prompt, weights, forward, grad, backward, pool);
total_loss += CrossEntropyLossForwardPass(model_type, prompt,
gemma.Weights(), forward, pool);
CrossEntropyLossBackwardPass(model_type, prompt, gemma.Weights(), forward,
grad, backward, pool);
num_ok += verify(prompt) ? 1 : 0;
}
total_loss /= kBatchSize;
const float scale = -learning_rate / kBatchSize;
UpdateWeights(model_type, grad, scale, weights, pool);
UpdateWeights(model_type, grad, scale, gemma.Weights(), pool);
printf("step: %zu total_loss: %.15f num_ok: %zu/%zu\n",
steps, total_loss, num_ok, kBatchSize);
if (steps % 100 == 0) {
@ -119,7 +118,7 @@ TEST(OptimizeTest, GradientDescent) {
}
printf("Num steps: %zu\n", steps);
printf("Final weights:\n");
LogWeightStats(model_type, weights);
LogWeightStats(model_type, gemma.Weights());
EXPECT_LT(steps, 3000);
EXPECT_EQ(num_ok, kBatchSize);
}


@ -41,14 +41,16 @@ class WeightInitializer {
};
template <typename TConfig>
void RandInitWeights(ByteStorageT& weights_u8, hwy::ThreadPool& pool,
std::mt19937& gen) {
auto& weights = *reinterpret_cast<WeightsF<TConfig>*>(weights_u8.get());
// TODO(szabadka) Use the same weight initialization method as in the python
// version.
WeightInitializer init(gen);
ForEachTensor1<float, TConfig>(init, weights);
}
struct RandInitWeights {
void operator()(ByteStorageT& weights_u8, hwy::ThreadPool& pool,
std::mt19937& gen) const {
auto& weights = *reinterpret_cast<WeightsF<TConfig>*>(weights_u8.get());
// TODO(szabadka) Use the same weight initialization method as in the python
// version.
WeightInitializer init(gen);
ForEachTensor1<float, TConfig>(init, weights);
}
};
class WeightUpdater {
public:
@ -67,55 +69,22 @@ class WeightUpdater {
};
template <typename TConfig>
void UpdateWeights(const ByteStorageT& grad_u8, float scale,
ByteStorageT& weights_u8, hwy::ThreadPool& pool) {
const auto& grad =
*reinterpret_cast<const WeightsF<TConfig>*>(grad_u8.get());
auto& weights = *reinterpret_cast<WeightsF<TConfig>*>(weights_u8.get());
WeightUpdater updater(scale);
ForEachTensor2<float, TConfig>(updater, grad, weights);
}
struct UpdateWeightsT {
void operator()(const ByteStorageT& grad_u8, float scale,
ByteStorageT& weights_u8, hwy::ThreadPool& pool) const {
const auto& grad =
*reinterpret_cast<const WeightsF<TConfig>*>(grad_u8.get());
auto& weights = *reinterpret_cast<WeightsF<TConfig>*>(weights_u8.get());
WeightUpdater updater(scale);
ForEachTensor2<float, TConfig>(updater, grad, weights);
}
};
} // namespace
void RandInitWeights(Model model, ByteStorageT& weights_u8,
hwy::ThreadPool& pool, std::mt19937& gen) {
switch (model) {
case Model::GEMMA_2B:
RandInitWeights<ConfigGemma2B>(weights_u8, pool, gen);
break;
case Model::GEMMA_7B:
RandInitWeights<ConfigGemma7B>(weights_u8, pool, gen);
break;
case Model::GRIFFIN_2B:
RandInitWeights<ConfigGriffin2B>(weights_u8, pool, gen);
break;
case Model::GEMMA_TINY:
RandInitWeights<ConfigGemmaTiny>(weights_u8, pool, gen);
break;
default:
HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
}
}
void UpdateWeights(Model model, const ByteStorageT& grad, float scale,
ByteStorageT& weights, hwy::ThreadPool& pool) {
switch (model) {
case Model::GEMMA_2B:
UpdateWeights<ConfigGemma2B>(grad, scale, weights, pool);
break;
case Model::GEMMA_7B:
UpdateWeights<ConfigGemma7B>(grad, scale, weights, pool);
break;
case Model::GRIFFIN_2B:
UpdateWeights<ConfigGriffin2B>(grad, scale, weights, pool);
break;
case Model::GEMMA_TINY:
UpdateWeights<ConfigGemmaTiny>(grad, scale, weights, pool);
break;
default:
HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
}
CallFunctorForModel<UpdateWeightsT>(model, grad, scale, weights, pool);
}
} // namespace gcpp


@ -16,16 +16,11 @@
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_OPTIMIZER_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_OPTIMIZER_H_
#include <random>
#include "gemma/common.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
namespace gcpp {
void RandInitWeights(Model model, ByteStorageT& weights, hwy::ThreadPool& pool,
std::mt19937& gen);
void UpdateWeights(Model model, const ByteStorageT& grad, float scale,
ByteStorageT& weights, hwy::ThreadPool& pool);


@ -37,7 +37,7 @@ std::pair<std::string, int> QueryModel(
gcpp::KVCache& kv_cache, hwy::ThreadPool& pool, const std::string& input,
gcpp::LayersOutputT* layers_output) {
std::vector<int> prompt;
HWY_ASSERT(model.Tokenizer()->Encode(input, &prompt));
HWY_ASSERT(model.Tokenizer().Encode(input, &prompt));
// For both pre-trained and instruction-tuned models: prepend "<bos>" token
// if needed.
@ -48,11 +48,10 @@ std::pair<std::string, int> QueryModel(
std::mt19937 gen;
gen.seed(42);
auto stream_token = [&res, &total_tokens, &app,
tokenizer = model.Tokenizer()](int token, float) {
auto stream_token = [&res, &total_tokens, &model](int token, float) {
++total_tokens;
std::string token_text;
HWY_ASSERT(tokenizer->Decode(std::vector<int>{token}, &token_text));
HWY_ASSERT(model.Tokenizer().Decode(std::vector<int>{token}, &token_text));
res += token_text;
return true;
};
@ -70,8 +69,8 @@ std::pair<std::string, int> QueryModel(
.stream_token = stream_token,
.accept_token = accept_token,
};
GenerateGemma(model, runtime_config, prompt, /*start_pos=*/0, kv_cache, pool,
timing_info, layers_output);
model.Generate(runtime_config, prompt, /*start_pos=*/0, kv_cache, timing_info,
layers_output);
return {res, total_tokens};
}
@ -115,7 +114,7 @@ int main(int argc, char** argv) {
}
gcpp::Gemma model(loader.tokenizer, loader.weights, loader.ModelType(), pool);
auto kv_cache = CreateKVCache(loader.ModelType());
gcpp::KVCache kv_cache = gcpp::KVCache::Create(loader.ModelType());
const std::string& prompt = prompt_args.prompt;
if (prompt.empty()) {


@ -40,7 +40,7 @@ int main(int argc, char** argv) {
// Instantiate model and KV Cache
gcpp::Gemma model(loader.tokenizer, loader.weights, loader.ModelType(), pool);
auto kv_cache = CreateKVCache(loader.ModelType());
gcpp::KVCache kv_cache = gcpp::KVCache::Create(loader.ModelType());
size_t pos = 0; // KV Cache position
// Initialize random number generator


@ -1,38 +0,0 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "gemma/activations.h"
#include "gemma/common.h"
#include "gemma/configs.h"
namespace gcpp {
ByteStorageT AllocateForwardPass(Model model) {
switch (model) {
case Model::GEMMA_2B:
return ForwardPass<float, ConfigGemma2B>::Allocate();
case Model::GEMMA_7B:
return ForwardPass<float, ConfigGemma7B>::Allocate();
case Model::GRIFFIN_2B:
return ForwardPass<float, ConfigGriffin2B>::Allocate();
case Model::GEMMA_TINY:
return ForwardPass<float, ConfigGemmaTiny>::Allocate();
default:
HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
}
}
} // namespace gcpp


@ -16,8 +16,11 @@
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_ACTIVATIONS_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_ACTIVATIONS_H_
#include "gemma/common.h"
#include "hwy/aligned_allocator.h"
#include <stddef.h>
#include <array>
#include "gemma/common.h" // ByteStorageT
namespace gcpp {
@ -54,9 +57,12 @@ struct ForwardPass {
std::array<T, kSeqLen * kModelDim> final_norm_output;
std::array<T, kSeqLen * kVocabSize> logits;
std::array<T, kSeqLen * kVocabSize> probs;
};
static ByteStorageT Allocate() {
return hwy::AllocateAligned<uint8_t>(sizeof(ForwardPass<T, TConfig>));
template <typename TConfig>
struct AllocateForwardPass {
ByteStorageT operator()() const {
return AllocateSizeof<ForwardPass<float, TConfig>>();
}
};
@ -78,8 +84,6 @@ class ActivationsWrapper {
WrappedT& activations_;
};
ByteStorageT AllocateForwardPass(Model model);
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_ACTIVATIONS_H_


@ -1,4 +1,5 @@
#include <algorithm>
#include <cstdlib> // EXIT_FAILURE
#include <fstream>
#include <iostream>
#include <ostream>
@ -8,6 +9,7 @@
#include <utility> // std::pair
#include <vector>
#include "compression/io.h" // Path
#include "gemma/gemma.h"
#include "util/app.h"
#include "util/args.h"
@ -61,7 +63,7 @@ std::pair<std::string, int> QueryModel(
gcpp::Gemma& model, gcpp::InferenceArgs& args, gcpp::AppArgs& app,
gcpp::KVCache& kv_cache, hwy::ThreadPool& pool, const std::string& input) {
std::vector<int> prompt;
HWY_ASSERT(model.Tokenizer()->Encode(input, &prompt));
HWY_ASSERT(model.Tokenizer().Encode(input, &prompt));
// For both pre-trained and instruction-tuned models: prepend "<bos>" token
// if needed.
@ -73,11 +75,11 @@ std::pair<std::string, int> QueryModel(
gen.seed(42);
const double time_start = hwy::platform::Now();
auto stream_token = [&res, &total_tokens, &time_start, &app,
tokenizer = model.Tokenizer()](int token, float) {
auto stream_token = [&res, &total_tokens, &time_start, &app, &model](
int token, float) {
++total_tokens;
std::string token_text;
HWY_ASSERT(tokenizer->Decode(std::vector<int>{token}, &token_text));
HWY_ASSERT(model.Tokenizer().Decode(std::vector<int>{token}, &token_text));
res += token_text;
if (app.verbosity >= 1 && total_tokens % 100 == 0) {
LogSpeedStats(time_start, total_tokens);
@ -98,8 +100,8 @@ std::pair<std::string, int> QueryModel(
.stream_token = stream_token,
.accept_token = accept_token,
};
GenerateGemma(model, runtime_config, prompt, /*start_pos=*/0, kv_cache, pool,
timing_info, /*layers_output=*/nullptr);
model.Generate(runtime_config, prompt, /*start_pos=*/0, kv_cache, timing_info,
/*layers_output=*/nullptr);
if (app.verbosity >= 1) {
LogSpeedStats(time_start, total_tokens);
}
@ -190,7 +192,7 @@ int BenchmarkCrossEntropy(gcpp::Gemma& model, gcpp::Model model_type,
size_t batch_tokens) {
std::string input = ReadFile(text);
std::vector<int> prompt;
HWY_ASSERT(model.Tokenizer()->Encode(input, &prompt));
HWY_ASSERT(model.Tokenizer().Encode(input, &prompt));
prompt.resize(std::min<size_t>(args.max_tokens, prompt.size()));
std::cout << "Number of input tokens: " << prompt.size() << "\n";
const double time_start = hwy::platform::Now();
@ -201,14 +203,13 @@ int BenchmarkCrossEntropy(gcpp::Gemma& model, gcpp::Model model_type,
size_t num_tokens = std::min<size_t>(prompt.size() - pos, batch_tokens);
std::vector<int> prompt_slice(prompt.begin() + pos,
prompt.begin() + pos + num_tokens);
auto kv_cache = CreateKVCache(model_type);
float entropy =
ComputeCrossEntropy(model, num_tokens, prompt_slice, kv_cache, pool,
app.verbosity);
gcpp::KVCache kv_cache = gcpp::KVCache::Create(model_type);
float entropy = model.ComputeCrossEntropy(num_tokens, prompt_slice,
kv_cache, app.verbosity);
total_entropy += entropy;
LogSpeedStats(time_start, pos + num_tokens);
std::string text_slice;
HWY_ASSERT(model.Tokenizer()->Decode(prompt_slice, &text_slice));
HWY_ASSERT(model.Tokenizer().Decode(prompt_slice, &text_slice));
total_input_len += text_slice.size();
printf("Cross entropy per byte: %f [cumulative: %f]\n",
entropy / text_slice.size(), total_entropy / total_input_len);
@ -277,7 +278,7 @@ int main(int argc, char** argv) {
}
gcpp::Gemma model(loader.tokenizer, loader.weights, loader.ModelType(), pool);
auto kv_cache = CreateKVCache(loader.ModelType());
gcpp::KVCache kv_cache = gcpp::KVCache::Create(loader.ModelType());
if (!benchmark_args.goldens.path.empty()) {
const std::string golden_path =


@ -1,68 +0,0 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Include guard for non-SIMD code.
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_INL_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_INL_H_
#include <stddef.h>
#include <stdint.h>
#include <array>
#include <cmath>
#include "gemma/activations.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_INL_H_
// Include guard for (potentially) SIMD code.
#if defined(THIRD_PARTY_GEMMA_CPP_COMMON_TOGGLE) == defined(HWY_TARGET_TOGGLE)
#ifdef THIRD_PARTY_GEMMA_CPP_COMMON_TOGGLE
#undef THIRD_PARTY_GEMMA_CPP_COMMON_TOGGLE
#else
#define THIRD_PARTY_GEMMA_CPP_COMMON_TOGGLE
#endif
#include "gemma/ops.h"
HWY_BEFORE_NAMESPACE();
namespace gcpp {
namespace HWY_NAMESPACE {
namespace hn = hwy::HWY_NAMESPACE;
// `EmbeddingScaling` can be constexpr only if `Sqrt` and `hwy::ConvertScalarTo`
// are both constexpr
#if HWY_COMPILER_GCC_ACTUAL
#define GEMMA_CONSTEXPR_EMBSCALING HWY_BF16_CONSTEXPR
#else
#define GEMMA_CONSTEXPR_EMBSCALING
#endif
template <typename TConfig>
GEMMA_CONSTEXPR_EMBSCALING float EmbeddingScaling() {
// Round to bf16 to match Gemma's Embedder, which casts before mul.
return hwy::ConvertScalarTo<float>(hwy::ConvertScalarTo<hwy::bfloat16_t>(
Sqrt(static_cast<float>(TConfig::kModelDim))));
}
// NOLINTNEXTLINE(google-readability-namespace-comments)
} // namespace HWY_NAMESPACE
} // namespace gcpp
HWY_AFTER_NAMESPACE();
#endif // NOLINT

gemma/common.cc (new file)

@ -0,0 +1,67 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "gemma/common.h"
#include <stddef.h>
#include <string.h>
#include <algorithm> // std::transform
#include <cctype>
#include <string>
#include <vector>
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
namespace gcpp {
namespace {
constexpr const char* kModelFlags[] = {"2b-pt", "7b-pt", "gr2b-pt", "2b-it",
"7b-it", "gr2b-it", "tiny"};
constexpr Model kModelTypes[] = {
Model::GEMMA_2B, Model::GEMMA_7B, Model::GRIFFIN_2B, Model::GEMMA_2B,
Model::GEMMA_7B, Model::GRIFFIN_2B, Model::GEMMA_TINY};
constexpr ModelTraining kModelTraining[] = {
ModelTraining::GEMMA_PT, ModelTraining::GEMMA_PT, ModelTraining::GEMMA_PT,
ModelTraining::GEMMA_IT, ModelTraining::GEMMA_IT, ModelTraining::GEMMA_IT,
ModelTraining::GEMMA_IT};
} // namespace
const char* ParseModelTypeAndTraining(const std::string& model_flag,
Model& model, ModelTraining& training) {
constexpr size_t kNum = std::end(kModelFlags) - std::begin(kModelFlags);
static char kErrorMessageBuffer[kNum * 8 + 1024] =
"Invalid or missing model flag, need to specify one of ";
for (size_t i = 0; i + 1 < kNum; i++) {
strcat(kErrorMessageBuffer, kModelFlags[i]); // NOLINT
strcat(kErrorMessageBuffer, ", "); // NOLINT
}
strcat(kErrorMessageBuffer, kModelFlags[kNum - 1]); // NOLINT
strcat(kErrorMessageBuffer, "."); // NOLINT
std::string model_type_lc = model_flag;
std::transform(begin(model_type_lc), end(model_type_lc), begin(model_type_lc),
[](unsigned char c) { return std::tolower(c); });
for (size_t i = 0; i < kNum; i++) {
if (kModelFlags[i] == model_type_lc) {
model = kModelTypes[i];
training = kModelTraining[i];
return nullptr;
}
}
return kErrorMessageBuffer;
}
} // namespace gcpp


@ -16,15 +16,115 @@
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_H_
#include <math.h> // sqrtf
#include <stdint.h>
#include <string>
#include "gemma/configs.h" // IWYU pragma: export
#include "hwy/aligned_allocator.h"
#include "hwy/base.h" // ConvertScalarTo
namespace gcpp {
using ByteStorageT = hwy::AlignedFreeUniquePtr<uint8_t[]>;
template <typename T>
ByteStorageT AllocateSizeof() {
return hwy::AllocateAligned<uint8_t>(sizeof(T));
}
// Model variants: see configs.h for details.
enum class Model { GEMMA_2B, GEMMA_7B, GRIFFIN_2B, GEMMA_TINY };
enum class ModelTraining { GEMMA_IT, GEMMA_PT };
// Returns the return value of Func<T>().operator() called with `args`, where
// `T` is selected based on `model`.
//
// This is used to implement type-erased functions such as
// LoadCompressedWeights, which can be called from other .cc files, by calling a
// functor LoadCompressedWeightsT, which has a template argument. `Func` must
// be a functor because function templates cannot be passed as a template
// template argument, and we prefer to avoid the overhead of std::function.
//
// This function avoids having to update all call sites when we extend `Model`.
template <template <typename Config> class Func, typename... Args>
decltype(auto) CallFunctorForModel(Model model, Args&&... args) {
switch (model) {
case Model::GEMMA_TINY:
return Func<ConfigGemmaTiny>()(std::forward<Args>(args)...);
case Model::GEMMA_2B:
return Func<ConfigGemma2B>()(std::forward<Args>(args)...);
case Model::GEMMA_7B:
return Func<ConfigGemma7B>()(std::forward<Args>(args)...);
case Model::GRIFFIN_2B:
return Func<ConfigGriffin2B>()(std::forward<Args>(args)...);
default:
HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
}
}
// Like CallFunctorForModel, but for SIMD function templates. This is a macro
// because it boils down to N_SSE4::FUNC, which would not work if FUNC was a
// normal function argument.
#define GEMMA_EXPORT_AND_DISPATCH_MODEL(MODEL, FUNC, ARGS) \
switch (MODEL) { \
case Model::GEMMA_TINY: { \
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(FUNC<ConfigGemmaTiny>) \
ARGS; \
break; \
} \
case Model::GEMMA_2B: { \
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(FUNC<ConfigGemma2B>) \
ARGS; \
break; \
} \
case Model::GEMMA_7B: { \
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(FUNC<ConfigGemma7B>) \
ARGS; \
break; \
} \
case Model::GRIFFIN_2B: { \
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(FUNC<ConfigGriffin2B>) \
ARGS; \
break; \
} \
default: \
HWY_ABORT("Model type %d unknown.", static_cast<int>(MODEL)); \
}
// Returns error string or nullptr if OK.
// Thread-hostile.
const char* ParseModelTypeAndTraining(const std::string& model_flag,
Model& model, ModelTraining& training);
// __builtin_sqrt is not constexpr as of Clang 17.
#if HWY_COMPILER_GCC_ACTUAL
#define GEMMA_CONSTEXPR_SQRT constexpr
static GEMMA_CONSTEXPR_SQRT HWY_INLINE float Sqrt(float x) {
return __builtin_sqrt(x);
}
#else
#define GEMMA_CONSTEXPR_SQRT
static GEMMA_CONSTEXPR_SQRT HWY_INLINE float Sqrt(float x) { return sqrtf(x); }
#endif
// `EmbeddingScaling` can be constexpr only if `Sqrt` and `hwy::ConvertScalarTo`
// are both constexpr
#if HWY_COMPILER_GCC_ACTUAL
#define GEMMA_CONSTEXPR_EMBSCALING HWY_BF16_CONSTEXPR
#else
#define GEMMA_CONSTEXPR_EMBSCALING
#endif
template <typename TConfig>
GEMMA_CONSTEXPR_EMBSCALING float EmbeddingScaling() {
// Round to bf16 to match Gemma's Embedder, which casts before mul.
return hwy::ConvertScalarTo<float>(hwy::ConvertScalarTo<hwy::bfloat16_t>(
Sqrt(static_cast<float>(TConfig::kModelDim))));
}
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_H_


@ -15,11 +15,33 @@
// Command line tool to create compressed weights.
// Compiles this file for multiple architectures via "foreach_target.h", to
// which we pass the filename via macro 'argument'.
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE \
"gemma/compress_weights.cc" // NOLINT
#include "hwy/foreach_target.h" // IWYU pragma: keep
// Must come after foreach_target.h to avoid redefinition errors.
#include "compression/compress-inl.h"
#include "hwy/highway.h"
#ifndef GEMMA_COMPRESS_WEIGHTS_ONCE
#define GEMMA_COMPRESS_WEIGHTS_ONCE
#include <stddef.h>
#include <stdio.h>
#include <algorithm> // std::clamp
#include <iostream>
#include <string>
#include <thread> // NOLINT
#include "gemma/gemma.h" // Gemma
#include "compression/io.h" // Path
#include "gemma/common.h" // Model
#include "gemma/weights.h"
#include "util/args.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
namespace gcpp {
@ -99,10 +121,57 @@ void ShowHelp(gcpp::Args& args) {
std::cerr << "\n";
}
} // namespace gcpp
#endif // GEMMA_COMPRESS_WEIGHTS_ONCE
// SIMD code, compiled once per target.
HWY_BEFORE_NAMESPACE();
namespace gcpp {
namespace HWY_NAMESPACE {
template <class TConfig>
void CompressWeights(const Path& weights_path,
const Path& compressed_weights_path, Model model_type,
hwy::ThreadPool& pool) {
if (!weights_path.Exists()) {
HWY_ABORT("The model weights file '%s' does not exist.",
weights_path.path.c_str());
}
// Allocate compressed weights.
using CWeights = CompressedWeights<TConfig>;
ByteStorageT c_weights_u8 = AllocateSizeof<CWeights>();
CWeights* c_weights = reinterpret_cast<CWeights*>(c_weights_u8.get());
new (&c_weights->c_layer_ptrs) CompressedLayerPointers<TConfig>(pool);
// Get weights, compress, and store.
const bool scale_for_compression = TConfig::kNumTensorScales > 0;
const ByteStorageT weights_u8 = gcpp::LoadRawWeights(
weights_path, model_type, pool, scale_for_compression);
WeightsF<TConfig>* weights =
reinterpret_cast<WeightsF<TConfig>*>(weights_u8.get());
Compressor compressor(pool);
ForEachTensor<TConfig>(weights, *c_weights, compressor);
compressor.AddScales(weights->scales.data(), weights->scales.size());
compressor.WriteAll(pool, compressed_weights_path);
weights->layer_ptrs.~LayerPointers<float, TConfig>();
c_weights->c_layer_ptrs.~CompressedLayerPointers<TConfig>();
}
} // namespace HWY_NAMESPACE
} // namespace gcpp
HWY_AFTER_NAMESPACE();
#if HWY_ONCE
namespace gcpp {
void Run(Args& args) {
hwy::ThreadPool pool(args.num_threads);
gcpp::CompressWeights(args.ModelType(), args.weights, args.compressed_weights,
pool);
const Model model_type = args.ModelType();
GEMMA_EXPORT_AND_DISPATCH_MODEL(
model_type, CompressWeights,
(args.weights, args.compressed_weights, model_type, pool));
}
} // namespace gcpp
@ -111,12 +180,12 @@ int main(int argc, char** argv) {
gcpp::Args args(argc, argv);
if (gcpp::HasHelp(argc, argv)) {
ShowHelp(args);
gcpp::ShowHelp(args);
return 0;
}
if (const char* error = args.Validate()) {
ShowHelp(args);
gcpp::ShowHelp(args);
HWY_ABORT("\nInvalid args: %s", error);
}
@ -124,3 +193,5 @@ int main(int argc, char** argv) {
return 0;
}
#endif // HWY_ONCE


@ -13,11 +13,20 @@
// See the License for the specific language governing permissions and
// limitations under the License.
// Model configurations
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_CONFIGS_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_CONFIGS_H_
// Model configurations
#include <stddef.h>
#include <array>
#include "compression/compress.h" // SfpStream
#include "hwy/base.h" // hwy::bfloat16_t
namespace gcpp {
// Allow changing pre-allocated kv cache size as a compiler flag
#ifndef GEMMA_MAX_SEQLEN
#define GEMMA_MAX_SEQLEN 4096
@ -33,25 +42,20 @@
#define GEMMA_MAX_THREADS 128
#endif // !GEMMA_MAX_THREADS
#include <stddef.h>
#include <array>
#include "compression/sfp.h"
#include "hwy/base.h" // hwy::bfloat16_t
// Allowable types for GEMMA_WEIGHT_T (can be specified at compilation time):
// float, hwy::bfloat16_t, SfpStream, NuqStream
#ifndef GEMMA_WEIGHT_T
#define GEMMA_WEIGHT_T SfpStream
#endif // !GEMMA_WEIGHT_T
namespace gcpp {
static constexpr size_t kSeqLen = GEMMA_MAX_SEQLEN;
static constexpr size_t kTopK = GEMMA_TOPK;
static constexpr size_t kMaxThreads = GEMMA_MAX_THREADS;
using GemmaWeightT = GEMMA_WEIGHT_T;
using EmbedderInputT = hwy::bfloat16_t;
enum class LayerAttentionType {
kGemma,
kGriffinRecurrentBlock,

File diff suppressed because it is too large.


@ -24,22 +24,12 @@
#include "compression/io.h" // Path
#include "gemma/common.h"
#include "gemma/configs.h"
#include "hwy/aligned_allocator.h"
#include "hwy/base.h" // hwy::bfloat16_t
#include "hwy/contrib/thread_pool/thread_pool.h"
namespace gcpp {
using GemmaWeightT = GEMMA_WEIGHT_T;
using EmbedderInputT = hwy::bfloat16_t;
// Will be called for layers output with:
// - position in the tokens sequence
// - name of the data, p.ex. "tokens", "block.1", "final_norm"
// - pointer to the data array
// - size of the data array
using LayersOutputT =
std::function<void(int, const std::string&, const float*, size_t)>;
constexpr size_t kPrefillBatchSize = 16;
constexpr bool kSystemPrompt = false;
@ -50,14 +40,30 @@ struct KVCache {
conv1d_cache; // (kConv1dWidth - 1) * kModelDim * kGriffinLayers
hwy::AlignedFreeUniquePtr<float[]>
rglru_cache; // kModelDim * kGriffinLayers
static KVCache Create(Model type);
};
enum class ModelTraining { GEMMA_IT, GEMMA_PT };
constexpr int EOS_ID = 1;
// Returns error string or nullptr if OK.
// Thread-hostile.
const char* ParseModelTypeAndTraining(const std::string& model_flag,
Model& model, ModelTraining& training);
class GemmaTokenizer {
public:
GemmaTokenizer() = default; // for second Gemma ctor.
explicit GemmaTokenizer(const Path& tokenizer_path);
// must come after definition of Impl
~GemmaTokenizer();
GemmaTokenizer(GemmaTokenizer&& other);
GemmaTokenizer& operator=(GemmaTokenizer&& other);
bool Encode(const std::string& input, std::vector<std::string>* pieces) const;
bool Encode(const std::string& input, std::vector<int>* pieces) const;
bool Decode(const std::vector<int>& ids, std::string* detokenized) const;
private:
class Impl;
std::unique_ptr<Impl> impl_;
};
// StreamFunc is called with (token, probability). For prompt tokens,
// probability is 0.0f. StreamFunc should return False to stop generation and
@ -67,8 +73,6 @@ using StreamFunc = std::function<bool(int, float)>;
// want to generate and True for tokens you want to generate.
using AcceptFunc = std::function<bool(int)>;
constexpr int EOS_ID = 1;
struct RuntimeConfig {
size_t max_tokens;
size_t max_generated_tokens;
@ -80,64 +84,65 @@ struct RuntimeConfig {
int eos_id = EOS_ID;
};
struct GemmaInterface;
class GemmaTokenizer {
public:
virtual ~GemmaTokenizer() = default;
virtual bool Encode(const std::string& input,
std::vector<std::string>* pieces) const = 0;
virtual bool Encode(const std::string& input,
std::vector<int>* pieces) const = 0;
virtual bool Decode(const std::vector<int>& ids,
std::string* detokenized) const = 0;
};
struct Gemma {
Gemma(const Path& tokenizer_path, const Path& weights, Model model_type,
hwy::ThreadPool& pool);
~Gemma(); // must be defined after the GemmaInterface dtor is defined.
const GemmaTokenizer* Tokenizer() const;
std::unique_ptr<GemmaInterface> impl_;
};
struct TimingInfo {
double prefill_tok_sec = 0.0;
double gen_tok_sec = 0.0;
double time_to_first_token = 0;
};
KVCache CreateKVCache(Model type); // convenient workaround for now
KVCache CreateKVCache(size_t size_cache_pos, size_t seq_len,
size_t conv1d_cache_size, size_t rglru_cache_size);
// Will be called for layers output with:
// - position in the tokens sequence
// - name of the data, p.ex. "tokens", "block.1", "final_norm"
// - pointer to the data array
// - size of the data array
using LayersOutputT =
std::function<void(int, const std::string&, const float*, size_t)>;
// Bundle runtime parameters as RuntimeConfig
// layers_output is optional; if set - it will be called with the activations
// output after applying each layer.
void GenerateGemma(Gemma& gemma, const RuntimeConfig& runtime_config,
const std::vector<int>& prompt, size_t start_pos,
KVCache& kv_cache, hwy::ThreadPool& pool,
TimingInfo& timing_info,
LayersOutputT* layers_output = nullptr);
class Gemma {
public:
Gemma(const Path& tokenizer_path, const Path& weights, Model model_type,
hwy::ThreadPool& pool);
void GenerateGemma(Model model, const ByteStorageT& weights,
ByteStorageT& inference_state,
RuntimeConfig runtime_config,
const std::vector<int>& prompt, size_t start_pos,
KVCache& kv_cache, hwy::ThreadPool& pool,
TimingInfo& timing_info);
// Allocates weights, caller is responsible for filling them.
Gemma(GemmaTokenizer&& tokenizer, Model model_type, hwy::ThreadPool& pool);
~Gemma();
ByteStorageT LoadWeights(const Path& weights, Model model,
hwy::ThreadPool& pool);
const GemmaTokenizer& Tokenizer() const { return tokenizer_; }
const ByteStorageT& Weights() const { return weights_u8_; }
const ByteStorageT& Prefill() const { return prefill_u8_; }
const ByteStorageT& Decode() const { return decode_u8_; }
ByteStorageT AllocateInferenceState(Model model);
// layers_output is optional; if set - it will be called with the activations
// output after applying each layer.
void Generate(const RuntimeConfig& runtime_config,
const std::vector<int>& prompt, size_t start_pos,
KVCache& kv_cache, TimingInfo& timing_info,
LayersOutputT* layers_output = nullptr);
void CompressWeights(gcpp::Model model, const Path& weights,
const Path& compressed_weights, hwy::ThreadPool& pool);
float ComputeCrossEntropy(size_t max_tokens, const std::vector<int>& prompt,
KVCache& kv_cache, int verbosity);
float ComputeCrossEntropy(Gemma& gemma, size_t max_tokens,
const std::vector<int>& prompt, KVCache& kv_cache,
hwy::ThreadPool& pool, int verbosity);
private:
hwy::ThreadPool& pool_;
GemmaTokenizer tokenizer_;
// Type-erased so that this can be defined in the header, without requiring
// forwarding functions.
ByteStorageT weights_u8_;
ByteStorageT prefill_u8_;
ByteStorageT decode_u8_;
Model model_type_;
};
// DEPRECATED, call Gemma::Generate directly.
HWY_INLINE void GenerateGemma(Gemma& gemma, const RuntimeConfig& runtime_config,
const std::vector<int>& prompt, size_t start_pos,
KVCache& kv_cache, hwy::ThreadPool& /*pool*/,
TimingInfo& timing_info,
LayersOutputT* layers_output) {
gemma.Generate(runtime_config, prompt, start_pos, kv_cache, timing_info,
layers_output);
}
} // namespace gcpp


@ -22,8 +22,9 @@
#include <thread> // NOLINT
#include <vector>
#include "compression/io.h" // Path
#include "gemma/common.h"
#include "gemma/ops.h"
#include "util/args.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/tests/test_util-inl.h"
@ -38,7 +39,7 @@ class GemmaTest : public ::testing::Test {
pool(std::min<int>(20, (std::thread::hardware_concurrency() - 1) / 2)),
model_type(gcpp::Model::GEMMA_2B),
model(tokenizer, weights, model_type, pool) {
kv_cache = CreateKVCache(model_type);
KVCache kv_cache = KVCache::Create(model_type);
}
std::string GemmaReply(const std::string& prompt_string) {
@ -46,7 +47,7 @@ class GemmaTest : public ::testing::Test {
gen.seed(42);
std::vector<int> prompt;
HWY_ASSERT(model.Tokenizer()->Encode(prompt_string, &prompt));
HWY_ASSERT(model.Tokenizer().Encode(prompt_string, &prompt));
// For both pre-trained and instruction-tuned models: prepend "<bos>" token
// if needed.
prompt.insert(prompt.begin(), 2);
@ -66,18 +67,18 @@ class GemmaTest : public ::testing::Test {
.accept_token = [](int) { return true; },
};
gcpp::TimingInfo timing_info;
gcpp::GenerateGemma(model, runtime_config, prompt, /*start_pos=*/0,
kv_cache, pool, timing_info, /*layers_output=*/nullptr);
model.Generate(runtime_config, prompt, /*start_pos=*/0, kv_cache,
timing_info, /*layers_output=*/nullptr);
std::string response_text;
HWY_ASSERT(model.Tokenizer()->Decode(response, &response_text));
HWY_ASSERT(model.Tokenizer().Decode(response, &response_text));
return response_text;
}
float GemmaCrossEntropy(const std::string& prompt_string) {
std::vector<int> prompt;
HWY_ASSERT(model.Tokenizer()->Encode(prompt_string, &prompt));
return gcpp::ComputeCrossEntropy(model, /*max_tokens=*/3072, prompt,
kv_cache, pool, /*verbosity=*/0) /
HWY_ASSERT(model.Tokenizer().Encode(prompt_string, &prompt));
return model.ComputeCrossEntropy(/*max_tokens=*/3072, prompt, kv_cache,
/*verbosity=*/0) /
prompt_string.size();
}


@ -17,12 +17,12 @@
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_OPS_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_OPS_H_
#include <math.h>
#include <stddef.h>
#include <stdint.h>
#include <stdio.h>
#include <array>
#include <cmath>
#include <cstdio>
#include <random>
#include <type_traits> // std::enable_if_t
@ -31,21 +31,6 @@
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/profiler.h"
namespace gcpp {
// __builtin_sqrt is not constexpr as of Clang 17.
#if HWY_COMPILER_GCC_ACTUAL
#define GEMMA_CONSTEXPR_SQRT constexpr
static GEMMA_CONSTEXPR_SQRT HWY_INLINE float Sqrt(float x) {
return __builtin_sqrt(x);
}
#else
#define GEMMA_CONSTEXPR_SQRT
static GEMMA_CONSTEXPR_SQRT HWY_INLINE float Sqrt(float x) { return sqrtf(x); }
#endif
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_OPS_H_
// Include guard for (potentially) SIMD code.


@ -25,6 +25,7 @@
// Placeholder for internal header, do not modify.
#include "compression/compress.h"
#include "gemma/configs.h"
#include "gemma/gemma.h" // Gemma
#include "util/app.h"
#include "util/args.h" // HasHelp
@ -116,8 +117,7 @@ void ReplGemma(gcpp::Gemma& model, ModelTraining training,
// callback function invoked for each generated token.
auto stream_token = [&abs_pos, &current_pos, &args, &gen, &prompt_size,
tokenizer = model.Tokenizer(),
verbosity](int token, float) {
&model, verbosity](int token, float) {
++abs_pos;
++current_pos;
// <= since position is incremented before
@ -135,7 +135,8 @@ void ReplGemma(gcpp::Gemma& model, ModelTraining training,
}
} else {
std::string token_text;
HWY_ASSERT(tokenizer->Decode(std::vector<int>{token}, &token_text));
HWY_ASSERT(
model.Tokenizer().Decode(std::vector<int>{token}, &token_text));
// +1 since position is incremented above
if (current_pos == prompt_size + 1) {
// first token of response
@ -192,7 +193,7 @@ void ReplGemma(gcpp::Gemma& model, ModelTraining training,
}
}
HWY_ASSERT(model.Tokenizer()->Encode(prompt_string, &prompt));
HWY_ASSERT(model.Tokenizer().Encode(prompt_string, &prompt));
// For both pre-trained and instruction-tuned models: prepend "<bos>" token
// if needed.
@ -221,8 +222,7 @@ void ReplGemma(gcpp::Gemma& model, ModelTraining training,
.stream_token = stream_token,
.accept_token = accept_token,
};
GenerateGemma(model, runtime_config, prompt, abs_pos, kv_cache, pool,
timing_info);
model.Generate(runtime_config, prompt, abs_pos, kv_cache, timing_info);
if (verbosity >= 2) {
std::cout << current_pos << " tokens (" << abs_pos << " total tokens)"
<< "\n"
@ -251,7 +251,7 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
gcpp::Gemma model(loader.tokenizer, loader.weights, loader.ModelType(), pool);
auto kv_cache = CreateKVCache(loader.ModelType());
KVCache kv_cache = KVCache::Create(loader.ModelType());
if (const char* error = inference.Validate()) {
ShowHelp(loader, inference, app);


@ -15,7 +15,6 @@
// Command line text interface to gemma.
#include <ctime>
#include <fstream>
#include <iostream>
#include <random>
@ -63,7 +62,7 @@ void JsonGemma(gcpp::Gemma& model, gcpp::KVCache& kv_cache,
std::set<int> accept_token_set{};
for (const std::string& accept_token : accept_tokens) {
std::vector<int> accept_token_ids;
HWY_ASSERT(model.Tokenizer()->Encode(accept_token, &accept_token_ids));
HWY_ASSERT(model.Tokenizer().Encode(accept_token, &accept_token_ids));
accept_token_set.insert(accept_token_ids.begin(), accept_token_ids.end());
}
@ -76,7 +75,7 @@ void JsonGemma(gcpp::Gemma& model, gcpp::KVCache& kv_cache,
const std::string& prompt_string = sample["prompt"];
std::vector<int> prompt;
HWY_ASSERT(model.Tokenizer()->Encode(prompt_string, &prompt));
HWY_ASSERT(model.Tokenizer().Encode(prompt_string, &prompt));
prompt_size = prompt.size();
const std::string& correct_answer = accept_tokens[sample["input_label"]];
@ -124,11 +123,10 @@ void JsonGemma(gcpp::Gemma& model, gcpp::KVCache& kv_cache,
.stream_token = stream_token,
.accept_token = accept_token,
};
GenerateGemma(model, runtime_config, prompt, abs_pos, kv_cache, pool,
timing_info);
model.Generate(runtime_config, prompt, abs_pos, kv_cache, timing_info);
std::string output_string;
HWY_ASSERT(model.Tokenizer()->Decode(predicted_token_ids, &output_string));
HWY_ASSERT(model.Tokenizer().Decode(predicted_token_ids, &output_string));
std::cout << "QuestionId: " << sample["i"] << "; "
<< "Predicted Answer: " << output_string << "; "
<< "Correct Answer: " << correct_answer << std::endl;
@ -161,7 +159,7 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
gcpp::Gemma model(loader.tokenizer, loader.weights, loader.ModelType(), pool);
auto kv_cache = CreateKVCache(loader.ModelType());
gcpp::KVCache kv_cache = gcpp::KVCache::Create(loader.ModelType());
JsonGemma(model, kv_cache, pool, inference, app.verbosity, app.eot_line);
}


@ -15,53 +15,230 @@
#include "gemma/weights.h"
#include <algorithm>
#include <cstdlib>
#include "compression/compress.h"
#include "compression/io.h" // Path
#include "gemma/common.h"
#include "gemma/configs.h"
#include "hwy/base.h" // HWY_ABORT
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/profiler.h"
#include "hwy/stats.h"
namespace gcpp {
ByteStorageT AllocateWeights(Model model, hwy::ThreadPool& pool) {
switch (model) {
case Model::GEMMA_2B:
return AllocateWeights<float, ConfigGemma2B>(pool);
case Model::GEMMA_7B:
return AllocateWeights<float, ConfigGemma7B>(pool);
case Model::GRIFFIN_2B:
return AllocateWeights<float, ConfigGriffin2B>(pool);
case Model::GEMMA_TINY:
return AllocateWeights<float, ConfigGemmaTiny>(pool);
default:
HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
// Setting this to true disables fread() calls that read the model file.
constexpr bool kDryRunFread = false;
namespace {
float ScaleWeights(float* data, size_t len) {
float maxabs = 0.0;
for (size_t i = 0; i < len; ++i) {
maxabs = std::max(maxabs, std::abs(data[i]));
}
const float kMaxRange = 1.875f;
if (maxabs <= kMaxRange) {
return 1.0f;
}
const float scale = maxabs / kMaxRange;
const float inv_scale = 1.0f / scale;
for (size_t i = 0; i < len; ++i) {
data[i] *= inv_scale;
}
return scale;
}
#define READ_WEIGHTS(name) \
do { \
do_fread(&(layer_view->name), layer, #name, sizeof(layer_view->name)); \
} while (0)
#define SCALE_WEIGHTS(name) \
do { \
if (ok && !kDryRunFread && scale_for_compression) { \
weights->scales[scale_pos++] = \
ScaleWeights(layer_view->name.data(), layer_view->name.size()); \
} \
} while (0)
template <typename TConfig>
struct LoadRawWeightsT {
ByteStorageT operator()(const Path& checkpoint, hwy::ThreadPool& pool,
bool scale_for_compression) const {
PROFILER_ZONE("Startup.LoadWeights");
if (!checkpoint.Exists()) {
HWY_ABORT("The model weights file '%s' does not exist.",
checkpoint.path.c_str());
}
ByteStorageT weights_u8 = AllocateWeights<TConfig>()(pool);
auto* weights = reinterpret_cast<WeightsF<TConfig>*>(weights_u8.get());
size_t scale_pos = 0;
FILE* fptr;
if constexpr (kDryRunFread) {
fprintf(stderr, "Dry-Run, not reading model-file.\n");
} else {
fptr = fopen(checkpoint.path.c_str(), "rb");
if (fptr == nullptr) {
HWY_ABORT("Failed to open model file %s - does it exist?",
checkpoint.path.c_str());
}
}
bool ok = true;
uint64_t total_size = 0;
auto do_fread = [&](void* var, int layer, const char* name, size_t size) {
if (layer == -1) {
fprintf(stderr, "Loading Parameters (size %zu): %s\n", size, name);
} else {
fprintf(stderr, "Loading Parameters (layer=%d, size %zu): %s\n", layer,
size, name);
}
if constexpr (!kDryRunFread) {
ok &= 1 == fread(var, size, 1, fptr);
total_size += size;
}
};
do_fread(&(weights->embedder_input_embedding), -1,
"embedder_input_embedding",
sizeof(weights->embedder_input_embedding));
do_fread(&(weights->final_norm_scale), -1, "final_norm_scale",
sizeof(weights->final_norm_scale));
for (size_t layer = 0; layer < TConfig::kLayers; ++layer) {
auto type = TConfig::kLayerConfig[layer];
LayerF<TConfig>* layer_view = weights->GetLayer(layer);
// Make sure we don't have uninitialized memory.
hwy::ZeroBytes(layer_view, sizeof(*layer_view));
if (type == LayerAttentionType::kGemma) {
READ_WEIGHTS(attn_vec_einsum_w);
READ_WEIGHTS(qkv_einsum_w);
SCALE_WEIGHTS(attn_vec_einsum_w);
SCALE_WEIGHTS(qkv_einsum_w);
} else {
READ_WEIGHTS(griffin.linear_x_w);
READ_WEIGHTS(griffin.linear_x_biases);
READ_WEIGHTS(griffin.linear_y_w);
READ_WEIGHTS(griffin.linear_y_biases);
READ_WEIGHTS(griffin.linear_out_w);
READ_WEIGHTS(griffin.linear_out_biases);
READ_WEIGHTS(griffin.conv_w);
READ_WEIGHTS(griffin.conv_biases);
READ_WEIGHTS(griffin.gate_w);
READ_WEIGHTS(griffin.gate_biases);
READ_WEIGHTS(griffin.a);
SCALE_WEIGHTS(griffin.linear_x_w);
SCALE_WEIGHTS(griffin.linear_y_w);
SCALE_WEIGHTS(griffin.linear_out_w);
SCALE_WEIGHTS(griffin.gate_w);
}
READ_WEIGHTS(gating_einsum_w);
READ_WEIGHTS(linear_w);
SCALE_WEIGHTS(gating_einsum_w);
SCALE_WEIGHTS(linear_w);
READ_WEIGHTS(pre_attention_norm_scale);
READ_WEIGHTS(pre_ffw_norm_scale);
if (TConfig::kPostNormScale) {
READ_WEIGHTS(post_attention_norm_scale);
READ_WEIGHTS(post_ffw_norm_scale);
}
if (TConfig::kFFBiases) {
READ_WEIGHTS(ffw_gating_biases);
READ_WEIGHTS(ffw_output_biases);
}
if (TConfig::kSoftmaxAttnOutputBiases &&
type == LayerAttentionType::kGemma) {
READ_WEIGHTS(attention_output_biases);
}
}
if (!ok) {
HWY_ABORT(
"Failed to read from %s - might be a directory, or too small? "
"expected size: %d kB",
checkpoint.path.c_str(), static_cast<uint32_t>(total_size >> 10));
}
if (!kDryRunFread) {
HWY_ASSERT(0 == fclose(fptr));
if (scale_for_compression) {
HWY_ASSERT(scale_pos == TConfig::kNumTensorScales);
}
}
return weights_u8;
}
};
#undef READ_WEIGHTS
#undef SCALE_WEIGHTS
} // namespace
ByteStorageT LoadRawWeights(const Path& weights, Model model,
hwy::ThreadPool& pool, bool scale_for_compression) {
return CallFunctorForModel<LoadRawWeightsT>(model, weights, pool,
scale_for_compression);
}
namespace {
template <typename TConfig>
void ZeroInitWeightsT(ByteStorageT& weights, hwy::ThreadPool& pool) {
ZeroInit<float, TConfig>(
*reinterpret_cast<Weights<float, TConfig>*>(weights.get()));
}
template <class TConfig>
struct LoadCompressedWeightsT {
ByteStorageT operator()(const Path& weights, hwy::ThreadPool& pool) const {
PROFILER_ZONE("Startup.LoadCompressedWeights");
if (!weights.Exists()) {
HWY_ABORT("The model weights file '%s' does not exist.",
weights.path.c_str());
}
// Allocate compressed weights.
using CWeights = CompressedWeights<TConfig>;
ByteStorageT c_weights_u8 = AllocateSizeof<CWeights>();
CWeights* c_weights = reinterpret_cast<CWeights*>(c_weights_u8.get());
new (&c_weights->c_layer_ptrs) CompressedLayerPointers<TConfig>(pool);
std::array<float, TConfig::kNumTensorScales> scales;
CacheLoader loader(weights);
ForEachTensor<TConfig>(nullptr, *c_weights, loader);
loader.LoadScales(scales.data(), scales.size());
if (!loader.ReadAll(pool)) {
HWY_ABORT("Failed to load model weights.");
}
if (TConfig::kNumTensorScales > 0) {
size_t scale_pos = 0;
for (int layer_idx = 0; layer_idx < TConfig::kLayers; ++layer_idx) {
auto type = TConfig::kLayerConfig[layer_idx];
const size_t idx = static_cast<size_t>(layer_idx);
CompressedLayer<TConfig>* layer_weights = c_weights->GetLayer(idx);
if (type == LayerAttentionType::kGemma) {
layer_weights->attn_vec_einsum_w.set_scale(scales[scale_pos++]);
layer_weights->qkv_einsum_w.set_scale(scales[scale_pos++]);
} else {
layer_weights->griffin.linear_x_w.set_scale(scales[scale_pos++]);
layer_weights->griffin.linear_y_w.set_scale(scales[scale_pos++]);
layer_weights->griffin.linear_out_w.set_scale(scales[scale_pos++]);
layer_weights->griffin.gate_w.set_scale(scales[scale_pos++]);
}
layer_weights->gating_einsum_w.set_scale(scales[scale_pos++]);
layer_weights->linear_w.set_scale(scales[scale_pos++]);
}
HWY_ASSERT(scale_pos == TConfig::kNumTensorScales);
}
return c_weights_u8;
}
};
} // namespace
void ZeroInitWeights(Model model, ByteStorageT& weights,
hwy::ThreadPool& pool) {
switch (model) {
case Model::GEMMA_2B:
ZeroInitWeightsT<ConfigGemma2B>(weights, pool);
break;
case Model::GEMMA_7B:
ZeroInitWeightsT<ConfigGemma7B>(weights, pool);
break;
case Model::GRIFFIN_2B:
ZeroInitWeightsT<ConfigGriffin2B>(weights, pool);
break;
case Model::GEMMA_TINY:
ZeroInitWeightsT<ConfigGemmaTiny>(weights, pool);
break;
default:
      HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
  }
}

ByteStorageT LoadCompressedWeights(const Path& weights, Model model,
hwy::ThreadPool& pool) {
return CallFunctorForModel<LoadCompressedWeightsT>(model, weights, pool);
}
ByteStorageT LoadWeights(const Path& weights, Model model,
hwy::ThreadPool& pool) {
if constexpr (kWeightsAreCompressed) {
return LoadCompressedWeights(weights, model, pool);
} else {
return LoadRawWeights(weights, model, pool,
/*scale_for_compression=*/false);
}
}
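
// Illustrative caller sketch (not part of this file's API; the function name
// and thread count are only for the sketch): the typical call sequence once
// the Model enum has been parsed from the command line.
ByteStorageT LoadWeightsSketch(const Path& weights_path, Model model) {
  hwy::ThreadPool pool(4);
  // Dispatches to LoadCompressedWeights or LoadRawWeights depending on
  // kWeightsAreCompressed.
  return LoadWeights(weights_path, model, pool);
}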
@ -86,27 +263,19 @@ class WeightLogger {
};
template <typename TConfig>
void LogWeightStats(const ByteStorageT& weights_u8) {
const auto& weights = *reinterpret_cast<WeightsF<TConfig>*>(weights_u8.get());
WeightLogger logger;
ForEachTensor1<float, TConfig>(logger, weights);
printf("%-20s %12zu\n", "Total", logger.total_weights);
}
template <typename TConfig>
struct LogWeightStatsT {
void operator()(const ByteStorageT& weights_u8) const {
const auto& weights =
*reinterpret_cast<WeightsF<TConfig>*>(weights_u8.get());
WeightLogger logger;
ForEachTensor1<float, TConfig>(logger, weights);
printf("%-20s %12zu\n", "Total", logger.total_weights);
}
};
} // namespace
void LogWeightStats(gcpp::Model model, const ByteStorageT& weights) {
switch (model) {
case Model::GEMMA_2B:
return LogWeightStats<ConfigGemma2B>(weights);
case Model::GEMMA_7B:
return LogWeightStats<ConfigGemma7B>(weights);
case Model::GRIFFIN_2B:
return LogWeightStats<ConfigGriffin2B>(weights);
case Model::GEMMA_TINY:
return LogWeightStats<ConfigGemmaTiny>(weights);
default:
HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
}
CallFunctorForModel<LogWeightStatsT>(model, weights);
}
} // namespace gcpp

View File

@ -16,13 +16,21 @@
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_WEIGHTS_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_WEIGHTS_H_
#include "compression/compress.h"
#include "gemma/common.h"
#include "gemma/configs.h"
#include "hwy/aligned_allocator.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
namespace gcpp {
// Setting this to false will load and use uncompressed weights.
constexpr bool kWeightsAreCompressed = true;
// ----------------------------------------------------------------------------
// Uncompressed
template <typename T, class TConfig>
struct Layer {
Layer() {}
@ -118,15 +126,262 @@ struct Weights {
template <class TConfig>
using WeightsF = Weights<float, TConfig>;
// ----------------------------------------------------------------------------
// Compressed
template <class TConfig>
struct CompressedLayer {
// No ctor/dtor, allocated via AllocateAligned.
using TLayer = gcpp::LayerF<TConfig>;
using WeightT = typename TConfig::WeightT;
static constexpr size_t kHeads = TLayer::kHeads;
static constexpr size_t kKVHeads = TLayer::kKVHeads;
static constexpr size_t kModelDim = TLayer::kModelDim;
static constexpr size_t kQKVDim = TLayer::kQKVDim;
static constexpr size_t kFFHiddenDim = TLayer::kFFHiddenDim;
static constexpr size_t kAttVecEinsumWSize = TLayer::kAttVecEinsumWSize;
static constexpr size_t kQKVEinsumWSize = TLayer::kQKVEinsumWSize;
static constexpr size_t kGatingEinsumWSize = TLayer::kGatingEinsumWSize;
static constexpr size_t kConv1dWidth = TLayer::kConv1dWidth;
static constexpr bool kFFBiases = TLayer::kFFBiases;
static constexpr bool kPostNormScale = TConfig::kPostNormScale;
static constexpr size_t kAOBiasDim = TLayer::kAOBiasDim;
static constexpr size_t kGriffinDim = TLayer::kGriffinDim;
// Compressed Parameters
template <class T, size_t N>
using ArrayT = CompressedArray<T, N>;
union {
struct {
ArrayT<WeightT, kAttVecEinsumWSize> attn_vec_einsum_w;
ArrayT<WeightT, kQKVEinsumWSize> qkv_einsum_w;
ArrayT<float, kAOBiasDim> attention_output_biases;
};
struct {
ArrayT<WeightT, kGriffinDim * kGriffinDim> linear_x_w;
ArrayT<float, kGriffinDim> linear_x_biases;
ArrayT<WeightT, kGriffinDim * kGriffinDim> linear_y_w;
ArrayT<float, kGriffinDim> linear_y_biases;
ArrayT<WeightT, kGriffinDim * kGriffinDim> linear_out_w;
ArrayT<float, kGriffinDim> linear_out_biases;
ArrayT<float, TConfig::kConv1dWidth * kGriffinDim> conv_w;
ArrayT<float, kGriffinDim> conv_biases;
ArrayT<WeightT, kGriffinDim * kGriffinDim / kHeads * 2> gate_w;
ArrayT<float, kGriffinDim * 2> gate_biases;
ArrayT<float, kGriffinDim> a;
} griffin;
};
ArrayT<WeightT, TLayer::kGatingEinsumWSize> gating_einsum_w;
ArrayT<WeightT, kModelDim * kFFHiddenDim> linear_w;
// We don't yet have an RMSNorm that accepts all WeightT.
ArrayT<hwy::bfloat16_t, kModelDim> pre_attention_norm_scale;
ArrayT<hwy::bfloat16_t, kModelDim> pre_ffw_norm_scale;
ArrayT<hwy::bfloat16_t, kPostNormScale ? kModelDim : 0>
post_attention_norm_scale;
ArrayT<hwy::bfloat16_t, kPostNormScale ? kModelDim : 0> post_ffw_norm_scale;
ArrayT<float, kFFBiases ? 2 * kFFHiddenDim : 0> ffw_gating_biases;
ArrayT<float, kFFBiases ? kModelDim : 0> ffw_output_biases;
};
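
// Illustrative sketch (not used by this header; the helper name is
// hypothetical): the union above aliases the Gemma attention tensors with the
// Griffin recurrent tensors, so only one group is valid per layer and callers
// must branch on the per-layer type from the config before touching either.
template <class TConfig>
constexpr bool LayerUsesGriffinTensors(size_t layer_idx) {
  return TConfig::kLayerConfig[layer_idx] != LayerAttentionType::kGemma;
}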
// Array instead of single large allocation for parallel mem init. Split out
// of CompressedWeights so that only these pointers are initialized, not the
// CompressedArray.
template <class TConfig>
struct CompressedLayerPointers {
explicit CompressedLayerPointers(hwy::ThreadPool& pool) {
pool.Run(0, TConfig::kLayers, [this](uint64_t task, size_t /*thread*/) {
this->c_layers[task] = hwy::AllocateAligned<CompressedLayer<TConfig>>(1);
});
}
using CLayer = CompressedLayer<TConfig>;
std::array<hwy::AlignedFreeUniquePtr<CLayer[]>, TConfig::kLayers> c_layers;
};
template <class TConfig>
struct CompressedWeights {
// No ctor/dtor, allocated via AllocateAligned.
CompressedArray<EmbedderInputT, TConfig::kVocabSize * TConfig::kModelDim>
embedder_input_embedding;
CompressedArray<hwy::bfloat16_t, TConfig::kModelDim> final_norm_scale;
// Must be last so that the other arrays remain aligned.
CompressedLayerPointers<TConfig> c_layer_ptrs;
const CompressedLayer<TConfig>* GetLayer(size_t layer) const {
return c_layer_ptrs.c_layers[layer].get();
}
CompressedLayer<TConfig>* GetLayer(size_t layer) {
return c_layer_ptrs.c_layers[layer].get();
}
};
// ----------------------------------------------------------------------------
// Interface
template <class TConfig>
using WeightsT = hwy::If<kWeightsAreCompressed, CompressedWeights<TConfig>,
WeightsF<TConfig>>;
// Call via CallFunctorForModel.
template <typename TConfig>
struct AllocateWeights {
ByteStorageT operator()(hwy::ThreadPool& pool) const {
using TWeights = WeightsF<TConfig>;
ByteStorageT weights_u8 = AllocateSizeof<TWeights>();
TWeights* weights = reinterpret_cast<TWeights*>(weights_u8.get());
new (&weights->layer_ptrs) LayerPointers<float, TConfig>(pool);
return weights_u8;
}
};
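
// Example of the intended dispatch (a sketch; the wrapper name is hypothetical
// and the call shape of CallFunctorForModel, declared in gemma/common.h, is
// inferred from its uses in weights.cc): the functor above is instantiated for
// whichever config matches `model` at runtime.
inline ByteStorageT AllocateWeightsForModel(Model model, hwy::ThreadPool& pool) {
  return CallFunctorForModel<AllocateWeights>(model, pool);
}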
template <typename TConfig>
struct ZeroInitWeights {
void operator()(ByteStorageT& weights, hwy::ThreadPool& pool) const {
WeightsF<TConfig>& w = *reinterpret_cast<WeightsF<TConfig>*>(weights.get());
hwy::ZeroBytes(&w.embedder_input_embedding,
sizeof(w.embedder_input_embedding));
hwy::ZeroBytes(&w.final_norm_scale, sizeof(w.final_norm_scale));
for (int i = 0; i < TConfig::kLayers; ++i) {
hwy::ZeroBytes(w.GetLayer(i), sizeof(*w.GetLayer(i)));
}
}
};
template <typename TConfig>
struct CopyWeights {
void operator()(WeightsF<TConfig>& dst, const WeightsF<TConfig>& src) const {
hwy::CopyBytes(&src.embedder_input_embedding, &dst.embedder_input_embedding,
sizeof(src.embedder_input_embedding));
hwy::CopyBytes(&src.final_norm_scale, &dst.final_norm_scale,
sizeof(src.final_norm_scale));
for (int i = 0; i < TConfig::kLayers; ++i) {
hwy::CopyBytes(src.GetLayer(i), dst.GetLayer(i),
sizeof(*dst.GetLayer(i)));
}
}
};
template <class TConfig>
struct DeleteLayersPtrs {
void operator()(ByteStorageT& weights_u8) const {
auto* weights = reinterpret_cast<WeightsT<TConfig>*>(weights_u8.get());
if constexpr (kWeightsAreCompressed) {
weights->c_layer_ptrs.~CompressedLayerPointers<TConfig>();
} else {
weights->layer_ptrs.~LayerPointers<float, TConfig>();
}
}
};
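
// Sketch of the matching teardown (illustrative; the function name is
// hypothetical and this assumes ByteStorageT is the aligned unique_ptr used
// elsewhere): the layer-pointer arrays were created with placement new inside
// the type-erased allocation, so they must be destroyed explicitly before the
// storage itself is released.
inline void DestroyWeightsSketch(Model model, ByteStorageT& weights_u8) {
  CallFunctorForModel<DeleteLayersPtrs>(model, weights_u8);
  weights_u8.reset();
}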
template <typename T, typename TConfig>
ByteStorageT AllocateWeights(hwy::ThreadPool& pool) {
  using TWeights = Weights<T, TConfig>;
  ByteStorageT weights_u8 = hwy::AllocateAligned<uint8_t>(sizeof(TWeights));
  TWeights* weights = reinterpret_cast<TWeights*>(weights_u8.get());
  new (&weights->layer_ptrs) LayerPointers<T, TConfig>(pool);
  return weights_u8;
}

// Owns weights and provides access to TConfig.
template <typename T, typename TConfig>
class WeightsWrapper {
public:
WeightsWrapper()
: pool_(0),
        data_(AllocateWeights<TConfig>()(pool_)),
weights_(reinterpret_cast<Weights<T, TConfig>*>(data_.get())) {}
const Weights<T, TConfig>& get() const { return *weights_; }
Weights<T, TConfig>& get() { return *weights_; }
  void clear() { ZeroInitWeights<TConfig>()(data_, pool_); }
void copy(const WeightsWrapper<T, TConfig>& other) {
CopyWeights<TConfig>()(get(), other.get());
}
private:
hwy::ThreadPool pool_;
ByteStorageT data_;
Weights<T, TConfig>* weights_;
};
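
// Illustrative use of WeightsWrapper, e.g. in tests; ConfigGemmaTiny keeps the
// allocation small, and the function name is only for this sketch.
inline void WeightsWrapperSketch() {
  WeightsWrapper<float, ConfigGemmaTiny> grad;
  grad.clear();  // zero-initialize all tensors
  WeightsWrapper<float, ConfigGemmaTiny> snapshot;
  snapshot.copy(grad);  // deep copy via CopyWeights
}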
// For use by compress_weights.cc.
ByteStorageT LoadRawWeights(const Path& weights, Model model,
hwy::ThreadPool& pool, bool scale_for_compression);
// For gemma.cc; calls LoadRawWeights if !kWeightsAreCompressed.
ByteStorageT LoadWeights(const Path& weights, Model model,
hwy::ThreadPool& pool);
void LogWeightStats(Model model, const ByteStorageT& weights);
// ----------------------------------------------------------------------------
// Iterators
#define GEMMA_CALL_FUNC(name, member) \
snprintf(name_buf, sizeof(name_buf), name "_%d", layer_idx); \
func(name_buf, layer ? layer->member.data() : nullptr, layer_weights->member)
// Calls func(name, float*, CompressedArray&) for each tensor. The float* is
// null if `weights` is null, which happens during the first call, where we
// attempt to load from the cache.
//
// This avoids repeating the list of tensors between loading and compressing.
template <class TConfig, class Func>
void ForEachTensor(const WeightsF<TConfig>* weights,
CompressedWeights<TConfig>& c_weights, Func& func) {
func("c_embedding",
weights ? weights->embedder_input_embedding.data() : nullptr,
c_weights.embedder_input_embedding);
func("c_final_norm", weights ? weights->final_norm_scale.data() : nullptr,
c_weights.final_norm_scale);
char name_buf[16];
for (int layer_idx = 0; layer_idx < TConfig::kLayers; ++layer_idx) {
auto type = TConfig::kLayerConfig[layer_idx];
const size_t idx = static_cast<size_t>(layer_idx);
const LayerF<TConfig>* layer = weights ? weights->GetLayer(idx) : nullptr;
CompressedLayer<TConfig>* layer_weights = c_weights.GetLayer(idx);
GEMMA_CALL_FUNC("pre_ff_ns", pre_ffw_norm_scale);
GEMMA_CALL_FUNC("gating_ein", gating_einsum_w);
GEMMA_CALL_FUNC("linear_w", linear_w);
if (type == LayerAttentionType::kGemma) {
GEMMA_CALL_FUNC("qkv_ein", qkv_einsum_w);
GEMMA_CALL_FUNC("att_ein", attn_vec_einsum_w);
} else {
GEMMA_CALL_FUNC("gr_lin_x_w", griffin.linear_x_w);
GEMMA_CALL_FUNC("gr_lin_x_b", griffin.linear_x_biases);
GEMMA_CALL_FUNC("gr_lin_y_w", griffin.linear_y_w);
GEMMA_CALL_FUNC("gr_lin_y_b", griffin.linear_y_biases);
GEMMA_CALL_FUNC("gr_lin_out_w", griffin.linear_out_w);
GEMMA_CALL_FUNC("gr_lin_out_b", griffin.linear_out_biases);
GEMMA_CALL_FUNC("gr_conv_w", griffin.conv_w);
GEMMA_CALL_FUNC("gr_conv_b", griffin.conv_biases);
GEMMA_CALL_FUNC("gr_gate_w", griffin.gate_w);
GEMMA_CALL_FUNC("gr_gate_b", griffin.gate_biases);
GEMMA_CALL_FUNC("gr_a", griffin.a);
}
GEMMA_CALL_FUNC("pre_att_ns", pre_attention_norm_scale);
if (TConfig::kPostNormScale) {
GEMMA_CALL_FUNC("post_att_ns", post_attention_norm_scale);
GEMMA_CALL_FUNC("post_ff_ns", post_ffw_norm_scale);
}
if (TConfig::kFFBiases) {
GEMMA_CALL_FUNC("ffw_gat_b", ffw_gating_biases);
GEMMA_CALL_FUNC("ffw_out_b", ffw_output_biases);
}
if (TConfig::kSoftmaxAttnOutputBiases &&
type == LayerAttentionType::kGemma) {
GEMMA_CALL_FUNC("attn_ob", attention_output_biases);
}
}
}
#undef GEMMA_CALL_FUNC
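
// Sketch of a Func compatible with ForEachTensor above (illustrative; the
// struct name is hypothetical, and the real users are the cache loader in
// weights.cc and the compressor in compress_weights.cc). It only counts
// tensors and checks whether raw data was supplied.
struct TensorCounter {
  template <typename CompressedArrayT>
  void operator()(const char* /*name*/, const float* weights_or_null,
                  CompressedArrayT& /*compressed*/) {
    ++num_tensors;
    if (weights_or_null != nullptr) ++num_with_raw_data;
  }
  size_t num_tensors = 0;
  size_t num_with_raw_data = 0;
};
// Usage sketch: TensorCounter counter; ForEachTensor<TConfig>(nullptr, c_weights, counter);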
#define GEMMA_CALL_TOP_FUNC1(name, member) func(name, weights1.member)
#define GEMMA_CALL_TOP_FUNC2(name, member) \
func(name, weights1.member, weights2.member)
@ -237,54 +492,6 @@ void ForEachTensor2(Func& func, const Weights<T, TConfig>& weights1,
#undef GEMMA_CALL_LAYER_FUNC4
#undef GEMMA_CALL_ALL_LAYER_FUNC
template<typename T, typename TConfig>
void ZeroInit(Weights<T, TConfig>& w) {
hwy::ZeroBytes(&w.embedder_input_embedding,
sizeof(w.embedder_input_embedding));
hwy::ZeroBytes(&w.final_norm_scale, sizeof(w.final_norm_scale));
for (int i = 0; i < TConfig::kLayers; ++i) {
hwy::ZeroBytes(w.GetLayer(i), sizeof(*w.GetLayer(i)));
}
}
template<typename T, typename TConfig>
void Copy(Weights<T, TConfig>& dst, const Weights<T, TConfig>& src) {
hwy::CopyBytes(&src.embedder_input_embedding, &dst.embedder_input_embedding,
sizeof(src.embedder_input_embedding));
hwy::CopyBytes(&src.final_norm_scale, &dst.final_norm_scale,
sizeof(src.final_norm_scale));
for (int i = 0; i < TConfig::kLayers; ++i) {
hwy::CopyBytes(src.GetLayer(i), dst.GetLayer(i), sizeof(*dst.GetLayer(i)));
}
}
// Owns weights and undoes the type erasure of AllocateWeights.
template<typename T, typename TConfig>
class WeightsWrapper {
public:
WeightsWrapper()
: pool_(0), data_(AllocateWeights<T, TConfig>(pool_)),
weights_(reinterpret_cast<Weights<T, TConfig>*>(data_.get())) {}
const Weights<T, TConfig>& get() const { return *weights_; }
Weights<T, TConfig>& get() { return *weights_; }
void clear() { ZeroInit(get()); }
void copy(const WeightsWrapper<T, TConfig>& other) {
Copy(get(), other.get());
}
private:
hwy::ThreadPool pool_;
ByteStorageT data_;
Weights<T, TConfig>* weights_;
};
ByteStorageT AllocateWeights(Model model, hwy::ThreadPool& pool);
void ZeroInitWeights(Model model, ByteStorageT& weights, hwy::ThreadPool& pool);
void LogWeightStats(Model model, const ByteStorageT& weights);
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_WEIGHTS_H_