diff --git a/BUILD.bazel b/BUILD.bazel
index 621233f..11893c1 100644
--- a/BUILD.bazel
+++ b/BUILD.bazel
@@ -53,6 +53,36 @@ cc_test(
     ],
 )
 
+cc_library(
+    name = "common",
+    srcs = ["gemma/common.cc"],
+    hdrs = [
+        "gemma/common.h",
+        "gemma/configs.h",
+    ],
+    deps = [
+        "//compression:compress",
+        "//compression:io",
+        "@hwy//:hwy",  # base.h
+        "@hwy//:thread_pool",
+    ],
+)
+
+cc_library(
+    name = "weights",
+    srcs = ["gemma/weights.cc"],
+    hdrs = ["gemma/weights.h"],
+    deps = [
+        ":common",
+        "//compression:compress",
+        "//compression:io",
+        "@hwy//:hwy",
+        "@hwy//:profiler",
+        "@hwy//:stats",
+        "@hwy//:thread_pool",
+    ],
+)
+
 cc_library(
     name = "gemma_lib",
     srcs = [
@@ -60,14 +90,12 @@ cc_library(
     ],
     hdrs = [
         "gemma/activations.h",
-        "gemma/common.h",
-        "gemma/common-inl.h",
-        "gemma/configs.h",
         "gemma/gemma.h",
-        "gemma/weights.h",
     ],
     deps = [
+        ":common",
         ":ops",
+        ":weights",
         "//compression:compress",
         "//compression:io",
         "@hwy//:hwy",
@@ -93,6 +121,7 @@ cc_library(
     hdrs = ["util/app.h"],
     deps = [
         ":args",
+        ":common",
         ":gemma_lib",
         "//compression:io",
         "@hwy//:hwy",
@@ -126,6 +155,7 @@ cc_binary(
     deps = [
         ":app",
         ":args",
+        ":common",
         ":gemma_lib",
         # "//base",
         "//compression:compress",
@@ -141,7 +171,9 @@ cc_binary(
     srcs = ["gemma/compress_weights.cc"],
     deps = [
         ":args",
+        ":common",
         ":gemma_lib",
+        ":weights",
         # "//base",
         "//compression:compress",
         "@hwy//:hwy",
@@ -157,7 +189,9 @@ cc_binary(
     deps = [
         ":app",
         ":args",
+        ":common",
         ":gemma_lib",
+        "//compression:io",
         "@hwy//:hwy",
         "@hwy//:nanobenchmark",
         "@hwy//:thread_pool",
diff --git a/CMakeLists.txt b/CMakeLists.txt
index dcc9087..6c4d12f 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -50,8 +50,6 @@ set(SOURCES
   backprop/backward.h
   backprop/backward-inl.h
   backprop/backward_scalar.h
-  backprop/common_scalar.cc
-  backprop/common_scalar.h
   backprop/forward.cc
   backprop/forward.h
   backprop/forward-inl.h
@@ -59,10 +57,9 @@ set(SOURCES
   backprop/optimizer.cc
   backprop/optimizer.h
   gemma/configs.h
-  gemma/activations.cc
   gemma/activations.h
+  gemma/common.cc
   gemma/common.h
-  gemma/common-inl.h
   gemma/gemma.cc
   gemma/gemma.h
   gemma/ops.h
diff --git a/DEVELOPERS.md b/DEVELOPERS.md
index 865ac26..324a33e 100644
--- a/DEVELOPERS.md
+++ b/DEVELOPERS.md
@@ -151,9 +151,9 @@ file (usually `tokenizer.spm`), call `Encode()` to go from string prompts to
 token id vectors, or `Decode()` to go from token id vector outputs from the
 model back to strings.
 
-### `GenerateGemma()` is the entrypoint for token generation
+### `model.Generate()` is the entrypoint for token generation
 
-Calling into `GenerateGemma` with a tokenized prompt will 1) mutate the
+Calling into `model.Generate` with a tokenized prompt will 1) mutate the
 activation values in `model` and 2) invoke StreamFunc - a lambda callback for
 each generated token.
 
@@ -170,11 +170,11 @@ no-op which is what `run.cc` does.
 ### `Transformer()` implements the inference (i.e. `forward()` method in
 PyTorch or Jax) computation of the neural network
 
-For high-level applications, you might only call `GenerateGemma()` and never
+For high-level applications, you might only call `model.Generate()` and never
 interact directly with the neural network, but if you're doing something a bit
-more custom you can call transformer which performs a single inference
-operation on a single token and mutates the Activations and the KVCache through
-the neural network computation.
+more custom you can call transformer which performs a single inference operation
+on a single token and mutates the Activations and the KVCache through the neural
+network computation.
 
 ### For low level operations, defining new architectures, call `ops.h`
 functions directly
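[Note] A minimal sketch of the renamed entrypoint described above, assembled only from call sites updated later in this patch (debug_prompt.cc, examples/hello_world/run.cc, gemma/benchmark.cc); the `loader`, `pool`, and `runtime_config` setup is elided, and the string prompt is illustrative:

    // Build the model and a matching KV cache, then generate from a prompt.
    gcpp::Gemma model(loader.tokenizer, loader.weights, loader.ModelType(), pool);
    gcpp::KVCache kv_cache = gcpp::KVCache::Create(loader.ModelType());
    std::vector<int> prompt;
    HWY_ASSERT(model.Tokenizer().Encode("Hello", &prompt));
    gcpp::TimingInfo timing_info;
    model.Generate(runtime_config, prompt, /*start_pos=*/0, kv_cache,
                   timing_info, /*layers_output=*/nullptr);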
diff --git a/backprop/backward-inl.h b/backprop/backward-inl.h
index 6d11fea..a0a6be3 100644
--- a/backprop/backward-inl.h
+++ b/backprop/backward-inl.h
@@ -28,8 +28,8 @@
 
 #include "backprop/prompt.h"
 #include "gemma/activations.h"
+#include "gemma/common.h"
 #include "gemma/weights.h"
-
 #include "hwy/base.h"
 #include "hwy/contrib/thread_pool/thread_pool.h"
 
@@ -43,7 +43,6 @@
 #define THIRD_PARTY_GEMMA_CPP_BACKWARD_TOGGLE
 #endif
 
-#include "gemma/common-inl.h"
 #include "gemma/ops.h"
 
 HWY_BEFORE_NAMESPACE();
diff --git a/backprop/backward.cc b/backprop/backward.cc
index d6987a4..4baeb0a 100644
--- a/backprop/backward.cc
+++ b/backprop/backward.cc
@@ -54,6 +54,7 @@ void CrossEntropyLossBackwardPassT(Model model,
                                    ByteStorageT& grad,
                                    ByteStorageT& backward,
                                    hwy::ThreadPool& pool) {
+  // TODO(janwas): use CallFunctorForModel
   switch (model) {
     case Model::GEMMA_2B:
       CrossEntropyLossBackwardPass<ConfigGemma2B>(
diff --git a/backprop/backward_scalar.h b/backprop/backward_scalar.h
index 1c9fbbc..fe5ca46 100644
--- a/backprop/backward_scalar.h
+++ b/backprop/backward_scalar.h
@@ -23,9 +23,9 @@
 #include <cmath>
 #include <complex>
 
-#include "backprop/common_scalar.h"
 #include "backprop/prompt.h"
 #include "gemma/activations.h"
+#include "gemma/common.h"  // EmbeddingScaling
 #include "gemma/weights.h"
 
 namespace gcpp {
diff --git a/backprop/backward_scalar_test.cc b/backprop/backward_scalar_test.cc
index 17f1b6c..2ae1298 100644
--- a/backprop/backward_scalar_test.cc
+++ b/backprop/backward_scalar_test.cc
@@ -19,10 +19,10 @@
 #include <complex>
 #include <random>
 
+#include "gtest/gtest.h"
 #include "backprop/forward_scalar.h"
 #include "backprop/sampler.h"
 #include "backprop/test_util.h"
-#include "gtest/gtest.h"
 
 namespace gcpp {
 
diff --git a/backprop/backward_test.cc b/backprop/backward_test.cc
index cf13e6e..aad6838 100644
--- a/backprop/backward_test.cc
+++ b/backprop/backward_test.cc
@@ -30,11 +30,11 @@
 #include "backprop/sampler.h"
 #include "backprop/test_util.h"
 #include "compression/compress.h"
+#include "gemma/gemma.h"
+#include "gemma/weights.h"
 #include "hwy/aligned_allocator.h"
 #include "hwy/base.h"
 #include "hwy/contrib/thread_pool/thread_pool.h"
-#include "gemma/gemma.h"
-#include "gemma/weights.h"
 
 // clang-format off
 #undef HWY_TARGET_INCLUDE
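[Note] The "-inl.h" hunks above (and the matching forward-inl.h hunks below) remove an include from inside highway's per-target toggle guard. For readers unfamiliar with the pattern, a minimal sketch; MYLIB_EXAMPLE_TOGGLE is a placeholder name, and the real instance appears in the gemma/common-inl.h deletion later in this patch:

    // Re-evaluated once per SIMD target by hwy/foreach_target.h:
    #if defined(MYLIB_EXAMPLE_TOGGLE) == defined(HWY_TARGET_TOGGLE)
    #ifdef MYLIB_EXAMPLE_TOGGLE
    #undef MYLIB_EXAMPLE_TOGGLE
    #else
    #define MYLIB_EXAMPLE_TOGGLE
    #endif

    // Per-target code goes here, e.g. #include "gemma/ops.h".

    #endif  // MYLIB_EXAMPLE_TOGGLE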
diff --git a/backprop/common_scalar.cc b/backprop/common_scalar.cc
deleted file mode 100644
index b04dffb..0000000
--- a/backprop/common_scalar.cc
+++ /dev/null
@@ -1,51 +0,0 @@
-// Copyright 2024 Google LLC
-// SPDX-License-Identifier: Apache-2.0
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//      https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// Compiles this file for multiple architectures via "foreach_target.h", to
-// which we pass the filename via macro 'argument'.
-#undef HWY_TARGET_INCLUDE
-#define HWY_TARGET_INCLUDE "backprop/common_scalar.cc"  // NOLINT
-#include "hwy/foreach_target.h"  // IWYU pragma: keep
-
-#include "hwy/highway.h"
-// After highway.h
-#include "gemma/ops.h"
-
-HWY_BEFORE_NAMESPACE();
-namespace gcpp {
-namespace HWY_NAMESPACE {
-namespace hn = hwy::HWY_NAMESPACE;
-
-float EmbeddingScaling(int model_dim) {
-  // Round to bf16 to match Gemma's Embedder, which casts before mul.
-  return hwy::ConvertScalarTo<float>(hwy::ConvertScalarTo<hwy::bfloat16_t>(
-      Sqrt(static_cast<float>(model_dim))));
-}
-
-}  // namespace HWY_NAMESPACE
-}  // namespace gcpp
-HWY_AFTER_NAMESPACE();
-
-#if HWY_ONCE
-namespace gcpp {
-
-HWY_EXPORT(EmbeddingScaling);
-
-float EmbeddingScaling(int model_dim) {
-  return HWY_DYNAMIC_DISPATCH(EmbeddingScaling)(model_dim);
-}
-
-}  // namespace gcpp
-#endif  // HWY_ONCE
diff --git a/backprop/common_scalar.h b/backprop/common_scalar.h
deleted file mode 100644
index 628962b..0000000
--- a/backprop/common_scalar.h
+++ /dev/null
@@ -1,119 +0,0 @@
-// Copyright 2024 Google LLC
-// SPDX-License-Identifier: Apache-2.0
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//      https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_SCALAR_H_
-#define THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_SCALAR_H_
-
-#include <complex>
-
-namespace gcpp {
-
-template<typename T, typename U>
-U DotT(const T* a, const U* b, size_t N) {
-  U sum = {};
-  for (size_t i = 0; i < N; ++i) {
-    sum += a[i] * b[i];
-  }
-  return sum;
-}
-
-template<>
-std::complex<double> DotT(const float* a, const std::complex<double>* b,
-                          size_t N) {
-  std::complex<double> sum = {};
-  for (size_t i = 0; i < N; ++i) {
-    sum += static_cast<double>(a[i]) * b[i];
-  }
-  return sum;
-}
-
-template<typename T>
-void MulByConstT(T c, T* x, size_t N) {
-  for (size_t i = 0; i < N; ++i) {
-    x[i] *= c;
-  }
-}
-
-// out += c * x
-template<typename T>
-void MulByConstAndAddT(T c, const T* x, T* out, size_t N) {
-  for (size_t i = 0; i < N; ++i) {
-    out[i] += c * x[i];
-  }
-}
-
-template<typename T, size_t N>
-void MulByConstAndAddT(T c, const std::array<T, N>& x, std::array<T, N>& out) {
-  MulByConstAndAddT(c, x.data(), out.data(), N);
-}
-
-template<typename T>
-void AddFromT(const T* a, T* out, size_t N) {
-  for (size_t i = 0; i < N; ++i) {
-    out[i] += a[i];
-  }
-}
-
-template<typename T>
-T SquaredL2(const T* x, size_t N) {
-  T sum = {};
-  for (size_t i = 0; i < N; ++i) {
-    sum += x[i] * x[i];
-  }
-  return sum;
-}
-
-template<typename T>
-T Gelu(T x) {
-  static const T kMul = 0.044715;
-  static const T kSqrt2OverPi = 0.797884560804236;
-
-  const T x3 = x * x * x;
-  const T arg = kSqrt2OverPi * (kMul * x3 + x);
-  const T cdf = T(0.5) * (T(1.0) + std::tanh(arg));
-  return x * cdf;
-}
-
-template<typename T, typename U>
-void Rope(T* x, U base, size_t N, int i) {
-  const size_t N2 = N / 2;
-  for (size_t dim = 0; dim < N2; ++dim) {
-    const T freq_exponents = T(2 * dim) / T(N);
-    const T timescale = std::pow(base, freq_exponents);
-    const T theta = T(i) / timescale;
-    const T cos_val = std::cos(theta);
-    const T sin_val = std::sin(theta);
-    const T x0 = x[dim];
-    const T x1 = x[dim + N2];
-    x[dim] = x0 * cos_val - x1 * sin_val;
-    x[dim + N2] = x0 * sin_val + x1 * cos_val;
-  }
-}
-
-template<typename T>
-void Rope(T* x, size_t N, int i) {
-  Rope(x, T(10000.0), N, i);
-}
-
-template<typename T>
-void Rope(std::complex<T>* x, size_t N, int i) {
-  Rope(x, T(10000.0), N, i);
-}
-
-float EmbeddingScaling(int model_dim);
-
-}  // namespace gcpp
-
-#endif  // THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_SCALAR_H_
diff --git a/backprop/forward-inl.h b/backprop/forward-inl.h
index 4b7cdf1..253319c 100644
--- a/backprop/forward-inl.h
+++ b/backprop/forward-inl.h
@@ -24,8 +24,8 @@
 #include <vector>
 
 #include "gemma/activations.h"
+#include "gemma/common.h"
 #include "gemma/configs.h"
-
 #include "hwy/base.h"
 #include "hwy/contrib/thread_pool/thread_pool.h"
 
@@ -39,7 +39,6 @@
 #define THIRD_PARTY_GEMMA_CPP_FORWARD_TOGGLE
 #endif
 
-#include "gemma/common-inl.h"
 #include "gemma/ops.h"
 
 HWY_BEFORE_NAMESPACE();
diff --git a/backprop/forward.cc b/backprop/forward.cc
index fb67c2a..b712ce6 100644
--- a/backprop/forward.cc
+++ b/backprop/forward.cc
@@ -48,6 +48,7 @@ float CrossEntropyLossForwardPassT(Model model, const Prompt& prompt,
                                    const ByteStorageT& weights,
                                    ByteStorageT& forward,
                                    hwy::ThreadPool& pool) {
+  // TODO(janwas): use CallFunctorForModel
   switch (model) {
     case Model::GEMMA_2B:
       return CrossEntropyLossForwardPass<ConfigGemma2B>(
diff --git a/backprop/forward_scalar.h b/backprop/forward_scalar.h
index 5643530..44182a8 100644
--- a/backprop/forward_scalar.h
+++ b/backprop/forward_scalar.h
@@ -23,9 +23,9 @@
 #include <cmath>
 #include <complex>
 
-#include "backprop/common_scalar.h"
 #include "backprop/prompt.h"
 #include "gemma/activations.h"
+#include "gemma/common.h"  // EmbeddingScaling
 #include "gemma/weights.h"
 
 namespace gcpp {
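[Note] With common_scalar.cc gone, the scalar and SIMD paths share one EmbeddingScaling definition via gemma/common.h. A sketch of the bf16 round-trip it performs, using the hwy calls visible in the deleted file; the model_dim value is illustrative:

    #include <math.h>      // sqrtf
    #include "hwy/base.h"  // hwy::ConvertScalarTo, hwy::bfloat16_t

    // Round sqrt(model_dim) to bf16, matching the embedder, which casts to
    // bf16 before multiplying. For model_dim = 2048:
    float raw = sqrtf(2048.0f);  // ~45.2548
    float scale = hwy::ConvertScalarTo<float>(
        hwy::ConvertScalarTo<hwy::bfloat16_t>(raw));  // 45.25 after rounding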
diff --git a/backprop/optimize_test.cc b/backprop/optimize_test.cc
index 5c354a9..32d4b93 100644
--- a/backprop/optimize_test.cc
+++ b/backprop/optimize_test.cc
@@ -16,6 +16,7 @@
 #include <random>
 #include <vector>
 
+#include "gtest/gtest.h"
 #include "backprop/backward.h"
 #include "backprop/forward.h"
 #include "backprop/optimizer.h"
@@ -23,7 +24,6 @@
 #include "gemma/activations.h"
 #include "gemma/gemma.h"
 #include "gemma/weights.h"
-#include "gtest/gtest.h"
 
 namespace gcpp {
 
@@ -32,20 +32,20 @@ TEST(OptimizeTest, GradientDescent) {
   std::mt19937 gen(42);
 
   Model model_type = Model::GEMMA_TINY;
-  ByteStorageT weights = AllocateWeights(model_type, pool);
-  ByteStorageT grad = AllocateWeights(model_type, pool);
-  ByteStorageT grad_m = AllocateWeights(model_type, pool);
-  ByteStorageT grad_v = AllocateWeights(model_type, pool);
-  ByteStorageT forward = AllocateForwardPass(model_type);
-  ByteStorageT backward = AllocateForwardPass(model_type);
-  ByteStorageT inference = AllocateInferenceState(model_type);
-  auto kv_cache = CreateKVCache(model_type);
+  ByteStorageT grad = CallFunctorForModel<AllocateWeights>(model_type, pool);
+  ByteStorageT grad_m = CallFunctorForModel<AllocateWeights>(model_type, pool);
+  ByteStorageT grad_v = CallFunctorForModel<AllocateWeights>(model_type, pool);
+  ByteStorageT forward = CallFunctorForModel<AllocateForwardPass>(model_type);
+  ByteStorageT backward = CallFunctorForModel<AllocateForwardPass>(model_type);
+  KVCache kv_cache = KVCache::Create(model_type);
   size_t max_tokens = 32;
   size_t max_generated_tokens = 16;
   float temperature = 1.0f;
   int verbosity = 0;
   const auto accept_token = [](int) { return true; };
+  Gemma gemma(GemmaTokenizer(), model_type, pool);
+
   const auto generate = [&](const std::vector<int>& prompt) {
     std::vector<int> reply;
     auto stream_token = [&reply](int token, float) {
@@ -57,8 +57,7 @@ TEST(OptimizeTest, GradientDescent) {
       stream_token, accept_token, ReverseSequenceSampler::kEndToken,
     };
     TimingInfo timing_info;
-    GenerateGemma(model_type, weights, inference, runtime, prompt, 0,
-                  kv_cache, pool, timing_info);
+    gemma.Generate(runtime, prompt, 0, kv_cache, timing_info);
     return reply;
   };
 
@@ -74,12 +73,12 @@ TEST(OptimizeTest, GradientDescent) {
     return ok;
   };
 
-  RandInitWeights(model_type, weights, pool, gen);
-  ZeroInitWeights(model_type, grad_m, pool);
-  ZeroInitWeights(model_type, grad_v, pool);
+  CallFunctorForModel<RandInitWeights>(model_type, gemma.Weights(), pool, gen);
+  CallFunctorForModel<ZeroInitWeights>(model_type, grad_m, pool);
+  CallFunctorForModel<ZeroInitWeights>(model_type, grad_v, pool);
 
   printf("Initial weights:\n");
-  LogWeightStats(model_type, weights);
+  LogWeightStats(model_type, gemma.Weights());
 
   constexpr size_t kBatchSize = 8;
   float learning_rate = 0.0005f;
@@ -91,21 +90,21 @@ TEST(OptimizeTest, GradientDescent) {
   size_t num_ok;
   for (; steps < 1000000; ++steps) {
     std::mt19937 sgen(42);
-    ZeroInitWeights(model_type, grad, pool);
+    CallFunctorForModel<ZeroInitWeights>(model_type, grad, pool);
     float total_loss = 0.0f;
     num_ok = 0;
     for (size_t i = 0; i < kBatchSize; ++i) {
       Prompt prompt = training_task.Sample(sgen);
-      total_loss += CrossEntropyLossForwardPass(
-          model_type, prompt, weights, forward, pool);
-      CrossEntropyLossBackwardPass(
-          model_type, prompt, weights, forward, grad, backward, pool);
+      total_loss += CrossEntropyLossForwardPass(model_type, prompt,
+                                                gemma.Weights(), forward, pool);
+      CrossEntropyLossBackwardPass(model_type, prompt, gemma.Weights(), forward,
+                                   grad, backward, pool);
       num_ok += verify(prompt) ? 1 : 0;
     }
     total_loss /= kBatchSize;
     const float scale = -learning_rate / kBatchSize;
-    UpdateWeights(model_type, grad, scale, weights, pool);
+    UpdateWeights(model_type, grad, scale, gemma.Weights(), pool);
     printf("step: %zu total_loss: %.15f num_ok: %zu/%zu\n", steps, total_loss,
            num_ok, kBatchSize);
     if (steps % 100 == 0) {
@@ -119,7 +118,7 @@ TEST(OptimizeTest, GradientDescent) {
   }
   printf("Num steps: %zu\n", steps);
   printf("Final weights:\n");
-  LogWeightStats(model_type, weights);
+  LogWeightStats(model_type, gemma.Weights());
   EXPECT_LT(steps, 3000);
   EXPECT_EQ(num_ok, kBatchSize);
 }
diff --git a/backprop/optimizer.cc b/backprop/optimizer.cc
index 5b1c61d..79659b7 100644
--- a/backprop/optimizer.cc
+++ b/backprop/optimizer.cc
@@ -41,14 +41,16 @@ class WeightInitializer {
 };
 
 template <typename TConfig>
-void RandInitWeights(ByteStorageT& weights_u8, hwy::ThreadPool& pool,
-                     std::mt19937& gen) {
-  auto& weights = *reinterpret_cast<Weights<TConfig>*>(weights_u8.get());
-  // TODO(szabadka) Use the same weight initialization method as in the python
-  // version.
-  WeightInitializer init(gen);
-  ForEachTensor1(init, weights);
-}
+struct RandInitWeights {
+  void operator()(ByteStorageT& weights_u8, hwy::ThreadPool& pool,
+                  std::mt19937& gen) const {
+    auto& weights = *reinterpret_cast<Weights<TConfig>*>(weights_u8.get());
+    // TODO(szabadka) Use the same weight initialization method as in the python
+    // version.
+    WeightInitializer init(gen);
+    ForEachTensor1(init, weights);
+  }
+};
 
 class WeightUpdater {
  public:
@@ -67,55 +69,22 @@ class WeightUpdater {
 };
 
 template <typename TConfig>
-void UpdateWeights(const ByteStorageT& grad_u8, float scale,
-                   ByteStorageT& weights_u8, hwy::ThreadPool& pool) {
-  const auto& grad =
-      *reinterpret_cast<const Weights<TConfig>*>(grad_u8.get());
-  auto& weights = *reinterpret_cast<Weights<TConfig>*>(weights_u8.get());
-  WeightUpdater updater(scale);
-  ForEachTensor2(updater, grad, weights);
-}
+struct UpdateWeightsT {
+  void operator()(const ByteStorageT& grad_u8, float scale,
+                  ByteStorageT& weights_u8, hwy::ThreadPool& pool) const {
+    const auto& grad =
+        *reinterpret_cast<const Weights<TConfig>*>(grad_u8.get());
+    auto& weights = *reinterpret_cast<Weights<TConfig>*>(weights_u8.get());
+    WeightUpdater updater(scale);
+    ForEachTensor2(updater, grad, weights);
+  }
+};
 
 }  // namespace
 
-void RandInitWeights(Model model, ByteStorageT& weights_u8,
-                     hwy::ThreadPool& pool, std::mt19937& gen) {
-  switch (model) {
-    case Model::GEMMA_2B:
-      RandInitWeights<ConfigGemma2B>(weights_u8, pool, gen);
-      break;
-    case Model::GEMMA_7B:
-      RandInitWeights<ConfigGemma7B>(weights_u8, pool, gen);
-      break;
-    case Model::GRIFFIN_2B:
-      RandInitWeights<ConfigGriffin2B>(weights_u8, pool, gen);
-      break;
-    case Model::GEMMA_TINY:
-      RandInitWeights<ConfigGemmaTiny>(weights_u8, pool, gen);
-      break;
-    default:
-      HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
-  }
-}
-
 void UpdateWeights(Model model, const ByteStorageT& grad, float scale,
                    ByteStorageT& weights, hwy::ThreadPool& pool) {
-  switch (model) {
-    case Model::GEMMA_2B:
-      UpdateWeights<ConfigGemma2B>(grad, scale, weights, pool);
-      break;
-    case Model::GEMMA_7B:
-      UpdateWeights<ConfigGemma7B>(grad, scale, weights, pool);
-      break;
-    case Model::GRIFFIN_2B:
-      UpdateWeights<ConfigGriffin2B>(grad, scale, weights, pool);
-      break;
-    case Model::GEMMA_TINY:
-      UpdateWeights<ConfigGemmaTiny>(grad, scale, weights, pool);
-      break;
-    default:
-      HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
-  }
+  CallFunctorForModel<UpdateWeightsT>(model, grad, scale, weights, pool);
 }
 
 }  // namespace gcpp
diff --git a/backprop/optimizer.h b/backprop/optimizer.h
index 343db97..157352b 100644
--- a/backprop/optimizer.h
+++ b/backprop/optimizer.h
@@ -16,16 +16,11 @@
 #ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_OPTIMIZER_H_
 #define THIRD_PARTY_GEMMA_CPP_GEMMA_OPTIMIZER_H_
 
-#include <random>
-
 #include "gemma/common.h"
 #include "hwy/contrib/thread_pool/thread_pool.h"
 
 namespace gcpp {
 
-void RandInitWeights(Model model, ByteStorageT& weights, hwy::ThreadPool& pool,
-                     std::mt19937& gen);
-
 void UpdateWeights(Model model, const ByteStorageT& grad, float scale,
                    ByteStorageT& weights, hwy::ThreadPool& pool);
 
diff --git a/debug_prompt.cc b/debug_prompt.cc
index 8954f29..0ceb545 100644
--- a/debug_prompt.cc
+++ b/debug_prompt.cc
@@ -37,7 +37,7 @@ std::pair<std::string, size_t> QueryModel(
     gcpp::KVCache& kv_cache, hwy::ThreadPool& pool, const std::string& input,
     gcpp::LayersOutputT* layers_output) {
   std::vector<int> prompt;
-  HWY_ASSERT(model.Tokenizer()->Encode(input, &prompt));
+  HWY_ASSERT(model.Tokenizer().Encode(input, &prompt));
 
   // For both pre-trained and instruction-tuned models: prepend "<bos>" token
   // if needed.
@@ -48,11 +48,10 @@ std::pair<std::string, size_t> QueryModel(
   std::mt19937 gen;
   gen.seed(42);
 
-  auto stream_token = [&res, &total_tokens, &app,
-                       tokenizer = model.Tokenizer()](int token, float) {
+  auto stream_token = [&res, &total_tokens, &model](int token, float) {
     ++total_tokens;
     std::string token_text;
-    HWY_ASSERT(tokenizer->Decode(std::vector<int>{token}, &token_text));
+    HWY_ASSERT(model.Tokenizer().Decode(std::vector<int>{token}, &token_text));
     res += token_text;
     return true;
   };
@@ -70,8 +69,8 @@ std::pair<std::string, size_t> QueryModel(
       .stream_token = stream_token,
       .accept_token = accept_token,
   };
-  GenerateGemma(model, runtime_config, prompt, /*start_pos=*/0, kv_cache, pool,
-                timing_info, layers_output);
+  model.Generate(runtime_config, prompt, /*start_pos=*/0, kv_cache, timing_info,
+                 layers_output);
   return {res, total_tokens};
 }
 
@@ -115,7 +114,7 @@ int main(int argc, char** argv) {
   }
 
   gcpp::Gemma model(loader.tokenizer, loader.weights, loader.ModelType(), pool);
-  auto kv_cache = CreateKVCache(loader.ModelType());
+  gcpp::KVCache kv_cache = gcpp::KVCache::Create(loader.ModelType());
 
   const std::string& prompt = prompt_args.prompt;
   if (prompt.empty()) {
diff --git a/examples/hello_world/run.cc b/examples/hello_world/run.cc
index 9bb0d9f..9ac7259 100644
--- a/examples/hello_world/run.cc
+++ b/examples/hello_world/run.cc
@@ -40,7 +40,7 @@ int main(int argc, char** argv) {
 
   // Instantiate model and KV Cache
   gcpp::Gemma model(loader.tokenizer, loader.weights, loader.ModelType(), pool);
-  auto kv_cache = CreateKVCache(loader.ModelType());
+  gcpp::KVCache kv_cache = gcpp::KVCache::Create(loader.ModelType());
   size_t pos = 0;  // KV Cache position
 
   // Initialize random number generator
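[Note] Tokenizer() now returns a reference instead of a pointer, so call sites switch from "->" to ".". The assumed round-trip, per the hunks above:

    std::vector<int> ids;
    HWY_ASSERT(model.Tokenizer().Encode("Hello world", &ids));
    std::string text;
    HWY_ASSERT(model.Tokenizer().Decode(ids, &text));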
- -#include "gemma/activations.h" - -#include "gemma/common.h" -#include "gemma/configs.h" - -namespace gcpp { - -ByteStorageT AllocateForwardPass(Model model) { - switch (model) { - case Model::GEMMA_2B: - return ForwardPass::Allocate(); - case Model::GEMMA_7B: - return ForwardPass::Allocate(); - case Model::GRIFFIN_2B: - return ForwardPass::Allocate(); - case Model::GEMMA_TINY: - return ForwardPass::Allocate(); - default: - HWY_ABORT("Model type %d unknown.", static_cast(model)); - } -} - -} // namespace gcpp diff --git a/gemma/activations.h b/gemma/activations.h index 9a43344..6894d67 100644 --- a/gemma/activations.h +++ b/gemma/activations.h @@ -16,8 +16,11 @@ #ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_ACTIVATIONS_H_ #define THIRD_PARTY_GEMMA_CPP_GEMMA_ACTIVATIONS_H_ -#include "gemma/common.h" -#include "hwy/aligned_allocator.h" +#include + +#include + +#include "gemma/common.h" // ByteStorageT namespace gcpp { @@ -54,9 +57,12 @@ struct ForwardPass { std::array final_norm_output; std::array logits; std::array probs; +}; - static ByteStorageT Allocate() { - return hwy::AllocateAligned(sizeof(ForwardPass)); +template +struct AllocateForwardPass { + ByteStorageT operator()() const { + return AllocateSizeof>(); } }; @@ -78,8 +84,6 @@ class ActivationsWrapper { WrappedT& activations_; }; -ByteStorageT AllocateForwardPass(Model model); - } // namespace gcpp #endif // THIRD_PARTY_GEMMA_CPP_GEMMA_ACTIVATIONS_H_ diff --git a/gemma/benchmark.cc b/gemma/benchmark.cc index 2233de9..1d41153 100644 --- a/gemma/benchmark.cc +++ b/gemma/benchmark.cc @@ -1,4 +1,5 @@ #include +#include // EXIT_FAILURE #include #include #include @@ -8,6 +9,7 @@ #include // std::pair #include +#include "compression/io.h" // Path #include "gemma/gemma.h" #include "util/app.h" #include "util/args.h" @@ -61,7 +63,7 @@ std::pair QueryModel( gcpp::Gemma& model, gcpp::InferenceArgs& args, gcpp::AppArgs& app, gcpp::KVCache& kv_cache, hwy::ThreadPool& pool, const std::string& input) { std::vector prompt; - HWY_ASSERT(model.Tokenizer()->Encode(input, &prompt)); + HWY_ASSERT(model.Tokenizer().Encode(input, &prompt)); // For both pre-trained and instruction-tuned models: prepend "" token // if needed. 
diff --git a/gemma/benchmark.cc b/gemma/benchmark.cc
index 2233de9..1d41153 100644
--- a/gemma/benchmark.cc
+++ b/gemma/benchmark.cc
@@ -1,4 +1,5 @@
 #include <stdio.h>
+#include <stdlib.h>  // EXIT_FAILURE
 #include <fstream>
 #include <iostream>
 #include <random>
@@ -8,6 +9,7 @@
 #include <utility>  // std::pair
 #include <vector>
 
+#include "compression/io.h"  // Path
 #include "gemma/gemma.h"
 #include "util/app.h"
 #include "util/args.h"
@@ -61,7 +63,7 @@ std::pair<std::string, size_t> QueryModel(
     gcpp::Gemma& model, gcpp::InferenceArgs& args, gcpp::AppArgs& app,
     gcpp::KVCache& kv_cache, hwy::ThreadPool& pool, const std::string& input) {
   std::vector<int> prompt;
-  HWY_ASSERT(model.Tokenizer()->Encode(input, &prompt));
+  HWY_ASSERT(model.Tokenizer().Encode(input, &prompt));
 
   // For both pre-trained and instruction-tuned models: prepend "<bos>" token
   // if needed.
@@ -73,11 +75,11 @@ std::pair<std::string, size_t> QueryModel(
   gen.seed(42);
 
   const double time_start = hwy::platform::Now();
-  auto stream_token = [&res, &total_tokens, &time_start, &app,
-                       tokenizer = model.Tokenizer()](int token, float) {
+  auto stream_token = [&res, &total_tokens, &time_start, &app, &model](
+                          int token, float) {
     ++total_tokens;
     std::string token_text;
-    HWY_ASSERT(tokenizer->Decode(std::vector<int>{token}, &token_text));
+    HWY_ASSERT(model.Tokenizer().Decode(std::vector<int>{token}, &token_text));
     res += token_text;
     if (app.verbosity >= 1 && total_tokens % 100 == 0) {
       LogSpeedStats(time_start, total_tokens);
@@ -98,8 +100,8 @@ std::pair<std::string, size_t> QueryModel(
       .stream_token = stream_token,
       .accept_token = accept_token,
   };
-  GenerateGemma(model, runtime_config, prompt, /*start_pos=*/0, kv_cache, pool,
-                timing_info, /*layers_output=*/nullptr);
+  model.Generate(runtime_config, prompt, /*start_pos=*/0, kv_cache, timing_info,
+                 /*layers_output=*/nullptr);
   if (app.verbosity >= 1) {
     LogSpeedStats(time_start, total_tokens);
   }
@@ -190,7 +192,7 @@ int BenchmarkCrossEntropy(gcpp::Gemma& model, gcpp::Model model_type,
                           size_t batch_tokens) {
   std::string input = ReadFile(text);
   std::vector<int> prompt;
-  HWY_ASSERT(model.Tokenizer()->Encode(input, &prompt));
+  HWY_ASSERT(model.Tokenizer().Encode(input, &prompt));
   prompt.resize(std::min(args.max_tokens, prompt.size()));
   std::cout << "Number of input tokens: " << prompt.size() << "\n";
   const double time_start = hwy::platform::Now();
@@ -201,14 +203,13 @@ int BenchmarkCrossEntropy(gcpp::Gemma& model, gcpp::Model model_type,
     size_t num_tokens = std::min(prompt.size() - pos, batch_tokens);
     std::vector<int> prompt_slice(prompt.begin() + pos,
                                   prompt.begin() + pos + num_tokens);
-    auto kv_cache = CreateKVCache(model_type);
-    float entropy =
-        ComputeCrossEntropy(model, num_tokens, prompt_slice, kv_cache, pool,
-                            app.verbosity);
+    gcpp::KVCache kv_cache = gcpp::KVCache::Create(model_type);
+    float entropy = model.ComputeCrossEntropy(num_tokens, prompt_slice,
+                                              kv_cache, app.verbosity);
     total_entropy += entropy;
     LogSpeedStats(time_start, pos + num_tokens);
     std::string text_slice;
-    HWY_ASSERT(model.Tokenizer()->Decode(prompt_slice, &text_slice));
+    HWY_ASSERT(model.Tokenizer().Decode(prompt_slice, &text_slice));
     total_input_len += text_slice.size();
     printf("Cross entropy per byte: %f [cumulative: %f]\n",
            entropy / text_slice.size(), total_entropy / total_input_len);
@@ -277,7 +278,7 @@ int main(int argc, char** argv) {
   }
 
   gcpp::Gemma model(loader.tokenizer, loader.weights, loader.ModelType(), pool);
-  auto kv_cache = CreateKVCache(loader.ModelType());
+  gcpp::KVCache kv_cache = gcpp::KVCache::Create(loader.ModelType());
 
   if (!benchmark_args.goldens.path.empty()) {
     const std::string golden_path =
diff --git a/gemma/common-inl.h b/gemma/common-inl.h
deleted file mode 100644
index ac39d73..0000000
--- a/gemma/common-inl.h
+++ /dev/null
@@ -1,68 +0,0 @@
-// Copyright 2024 Google LLC
-// SPDX-License-Identifier: Apache-2.0
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//      https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// Include guard for non-SIMD code.
-#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_INL_H_
-#define THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_INL_H_
-
-#include <math.h>
-#include <stddef.h>
-
-#include <array>
-#include <cmath>
-
-#include "gemma/activations.h"
-
-#include "hwy/base.h"
-#include "hwy/contrib/thread_pool/thread_pool.h"
-
-#endif  // THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_INL_H_
-
-// Include guard for (potentially) SIMD code.
-#if defined(THIRD_PARTY_GEMMA_CPP_COMMON_TOGGLE) == defined(HWY_TARGET_TOGGLE)
-#ifdef THIRD_PARTY_GEMMA_CPP_COMMON_TOGGLE
-#undef THIRD_PARTY_GEMMA_CPP_COMMON_TOGGLE
-#else
-#define THIRD_PARTY_GEMMA_CPP_COMMON_TOGGLE
-#endif
-
-#include "gemma/ops.h"
-
-HWY_BEFORE_NAMESPACE();
-namespace gcpp {
-namespace HWY_NAMESPACE {
-namespace hn = hwy::HWY_NAMESPACE;
-
-// `EmbeddingScaling` can be constexpr only if `Sqrt` and `hwy::ConvertScalarTo`
-// are both constexpr
-#if HWY_COMPILER_GCC_ACTUAL
-#define GEMMA_CONSTEXPR_EMBSCALING HWY_BF16_CONSTEXPR
-#else
-#define GEMMA_CONSTEXPR_EMBSCALING
-#endif
-
-template <typename TConfig>
-GEMMA_CONSTEXPR_EMBSCALING float EmbeddingScaling() {
-  // Round to bf16 to match Gemma's Embedder, which casts before mul.
-  return hwy::ConvertScalarTo<float>(hwy::ConvertScalarTo<hwy::bfloat16_t>(
-      Sqrt(static_cast<float>(TConfig::kModelDim))));
-}
-
-// NOLINTNEXTLINE(google-readability-namespace-comments)
-}  // namespace HWY_NAMESPACE
-}  // namespace gcpp
-HWY_AFTER_NAMESPACE();
-
-#endif  // NOLINT
+ +#include "gemma/common.h" + +#include +#include + +#include // std::transform +#include +#include +#include + +#include "hwy/base.h" +#include "hwy/contrib/thread_pool/thread_pool.h" + +namespace gcpp { + +namespace { +constexpr const char* kModelFlags[] = {"2b-pt", "7b-pt", "gr2b-pt", "2b-it", + "7b-it", "gr2b-it", "tiny"}; +constexpr Model kModelTypes[] = { + Model::GEMMA_2B, Model::GEMMA_7B, Model::GRIFFIN_2B, Model::GEMMA_2B, + Model::GEMMA_7B, Model::GRIFFIN_2B, Model::GEMMA_TINY}; +constexpr ModelTraining kModelTraining[] = { + ModelTraining::GEMMA_PT, ModelTraining::GEMMA_PT, ModelTraining::GEMMA_PT, + ModelTraining::GEMMA_IT, ModelTraining::GEMMA_IT, ModelTraining::GEMMA_IT, + ModelTraining::GEMMA_IT}; +} // namespace + +const char* ParseModelTypeAndTraining(const std::string& model_flag, + Model& model, ModelTraining& training) { + constexpr size_t kNum = std::end(kModelFlags) - std::begin(kModelFlags); + static char kErrorMessageBuffer[kNum * 8 + 1024] = + "Invalid or missing model flag, need to specify one of "; + for (size_t i = 0; i + 1 < kNum; i++) { + strcat(kErrorMessageBuffer, kModelFlags[i]); // NOLINT + strcat(kErrorMessageBuffer, ", "); // NOLINT + } + strcat(kErrorMessageBuffer, kModelFlags[kNum - 1]); // NOLINT + strcat(kErrorMessageBuffer, "."); // NOLINT + std::string model_type_lc = model_flag; + std::transform(begin(model_type_lc), end(model_type_lc), begin(model_type_lc), + [](unsigned char c) { return std::tolower(c); }); + for (size_t i = 0; i < kNum; i++) { + if (kModelFlags[i] == model_type_lc) { + model = kModelTypes[i]; + training = kModelTraining[i]; + return nullptr; + } + } + return kErrorMessageBuffer; +} + +} // namespace gcpp diff --git a/gemma/common.h b/gemma/common.h index d497259..e277095 100644 --- a/gemma/common.h +++ b/gemma/common.h @@ -16,15 +16,115 @@ #ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_H_ #define THIRD_PARTY_GEMMA_CPP_GEMMA_COMMON_H_ +#include // sqrtf +#include + +#include + +#include "gemma/configs.h" // IWYU pragma: export #include "hwy/aligned_allocator.h" +#include "hwy/base.h" // ConvertScalarTo namespace gcpp { using ByteStorageT = hwy::AlignedFreeUniquePtr; +template +ByteStorageT AllocateSizeof() { + return hwy::AllocateAligned(sizeof(T)); +} + // Model variants: see configs.h for details. enum class Model { GEMMA_2B, GEMMA_7B, GRIFFIN_2B, GEMMA_TINY }; +enum class ModelTraining { GEMMA_IT, GEMMA_PT }; + +// Returns the return value of Func().operator() called with `args`, where +// `T` is selected based on `model`. +// +// This is used to implement type-erased functions such as +// LoadCompressedWeights, which can be called from other .cc files, by calling a +// functor LoadCompressedWeightsT, which has a template argument. `Func` must +// be a functor because function templates cannot be passed as a template +// template argument, and we prefer to avoid the overhead of std::function. +// +// This function avoids having to update all call sites when we extend `Model`. +template