From f9b390b134260d51970a2089a0066f1b645a65fd Mon Sep 17 00:00:00 2001 From: Jan Wassenberg Date: Fri, 7 Jun 2024 09:04:06 -0700 Subject: [PATCH] Support all weight types in a single binary. This changes the command line flags, but the default value retains the previous behavior. Also add a CreateGemma helper to enable extra args without interface changes. PiperOrigin-RevId: 641266411 --- BUILD.bazel | 1 + CMakeLists.txt | 11 --- DEVELOPERS.md | 10 +-- README.md | 52 +++++-------- backprop/backward-inl.h | 2 +- backprop/backward.cc | 10 ++- backprop/forward.cc | 12 ++- backprop/optimize_test.cc | 50 ++++++++----- backprop/optimizer.cc | 16 ++-- backprop/optimizer.h | 10 +-- debug_prompt.cc | 2 +- examples/hello_world/run.cc | 2 +- gemma/benchmark.cc | 2 +- gemma/common.cc | 24 ++++++ gemma/common.h | 143 ++++++++++++++++++++++++++---------- gemma/compress_weights.cc | 36 ++++++--- gemma/configs.h | 25 +++---- gemma/cross_entropy.cc | 11 ++- gemma/cross_entropy.h | 2 + gemma/gemma.cc | 52 +++++++------ gemma/gemma.h | 6 +- gemma/gemma_test.cc | 6 +- gemma/run.cc | 8 +- gemma/run_mmlu.cc | 3 +- gemma/weights.cc | 29 ++++---- gemma/weights.h | 47 ++++++------ util/app.h | 48 +++++++----- 27 files changed, 372 insertions(+), 248 deletions(-) diff --git a/BUILD.bazel b/BUILD.bazel index b3fd70d..f244d13 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -116,6 +116,7 @@ cc_library( "gemma/cross_entropy.h", ], deps = [ + ":common", ":gemma_lib", ], ) diff --git a/CMakeLists.txt b/CMakeLists.txt index 6a24ba0..5739561 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -76,17 +76,6 @@ if(NOT CMAKE_BUILD_TYPE) set(CMAKE_BUILD_TYPE "Release") endif() -# Allowable types for WEIGHT_TYPE: -# float - slow, not recommended -# hwy::bfloat16_t - bfloat16 as implemented by https://github.com/google/highway -# SfpStream - 8-bit switched floating point (recommended) -# NuqStream - experimental, work-in-progress -option(WEIGHT_TYPE "Set weight type" "") - -if (WEIGHT_TYPE) - add_definitions(-DGEMMA_WEIGHT_T=${WEIGHT_TYPE}) -endif() - FetchContent_GetProperties(sentencepiece) ## Library Target diff --git a/DEVELOPERS.md b/DEVELOPERS.md index 324a33e..32bc729 100644 --- a/DEVELOPERS.md +++ b/DEVELOPERS.md @@ -105,18 +105,12 @@ the resulting file as `--weights` and the desired .sbs name as the There are several compile-time flags to be aware of (note these may or may not be exposed to the build system): -- `GEMMA_WEIGHT_T` : Sets the level of compression for weights (surfaced as - WEIGHT_TYPE in CMakeLists.txt). Currently this should be set to `SfpStream` - (default, if no flag is specified) for 8-bit SFP, or `hwy::bfloat16_t` to - enable for higher-fidelity (but slower) bfloat16 support. This is defined in - `gemma.h`. - `GEMMA_MAX_SEQ_LEN` : Sets maximum sequence length to preallocate for the KV Cache. The default is 4096 tokens but can be overridden. This is not exposed through `CMakeLists.txt` yet. -In the medium term both of these will likely be deprecated in favor of handling -options at runtime - allowing for multiple weight compression schemes in a single -build and dynamically resizes the KV cache as needed. +In the medium term this will likely be deprecated in favor of handling options +at runtime - dynamically resizing the KV cache as needed. ## Using gemma.cpp as a Library (Advanced) diff --git a/README.md b/README.md index 1240616..4425cc4 100644 --- a/README.md +++ b/README.md @@ -138,33 +138,16 @@ convenient directory location (e.g. the `build/` directory in this repo). 
The build system uses [CMake](https://cmake.org/). To build the gemma inference runtime, create a build directory and generate the build files using `cmake` from the top-level project directory. Note if you previous ran `cmake` and are -re-running with a different setting, be sure to clean out the `build/` directory -with `rm -rf build/*` (warning this will delete any other files in the `build/` -directory. - -For the 8-bit switched floating point weights (sfp), run cmake with no options: +re-running with a different setting, be sure to delete all files in the `build/` +directory with `rm -rf build/*`. #### Unix-like Platforms ```sh cmake -B build ``` -**or** if you downloaded bfloat16 weights (any model *without* `-sfp` in the -name), instead of running cmake with no options as above, run cmake with -WEIGHT_TYPE set to [highway's](https://github.com/google/highway) -`hwy::bfloat16_t` type. Alternatively, you can also add -`-DGEMMA_WEIGHT_T=hwy::bfloat16_t` to the C++ compiler flags. - -We intend to soon support all weight types without requiring extra flags. Note -that we recommend using `-sfp` weights instead of bfloat16 for faster inference. - -```sh -cmake -B build -DWEIGHT_TYPE=hwy::bfloat16_t -``` - -After running whichever of the above `cmake` invocations that is appropriate for -your weights, you can enter the `build/` directory and run `make` to build the -`./gemma` executable: +After running `cmake`, you can enter the `build/` directory and run `make` to +build the `./gemma` executable: ```sh # Configure `build` directory @@ -221,11 +204,12 @@ You can now run `gemma` from inside the `build/` directory. `gemma` has the following required arguments: -| Argument | Description | Example value | -| ------------- | ---------------------------- | -------------------------- | -| `--model` | The model type. | `2b-it`, `2b-pt`, `7b-it`, `7b-pt`, ... (see above) | -| `--weights` | The compressed weights file. | `2b-it-sfp.sbs`, ... (see above) | -| `--tokenizer` | The tokenizer file. | `tokenizer.spm` | +Argument | Description | Example value +--------------- | ---------------------------- | ----------------------- +`--model` | The model type. | `2b-it` ... (see below) +`--weights` | The compressed weights file. | `2b-it-sfp.sbs` +`--weight_type` | The compressed weight type. | `sfp` +`--tokenizer` | The tokenizer file. | `tokenizer.spm` `gemma` is invoked as: @@ -233,6 +217,7 @@ You can now run `gemma` from inside the `build/` directory. ./gemma \ --tokenizer [tokenizer file] \ --weights [compressed weights file] \ +--weight_type [f32 or bf16 or sfp] \ --model [2b-it or 2b-pt or 7b-it or 7b-pt or ...] ``` @@ -245,8 +230,7 @@ Example invocation for the following configuration: ```sh ./gemma \ --tokenizer tokenizer.spm \ ---weights 2b-it-sfp.sbs \ ---model 2b-it +--weights 2b-it-sfp.sbs --weight_type sfp --model 2b-it ``` ### RecurrentGemma @@ -270,14 +254,12 @@ Step 1, and run the binary as follows: **Running `./gemma` fails with "Failed to read cache gating_ein_0 (error 294) ..."** -The most common problem is that `cmake` was built with the wrong weight type and -`gemma` is attempting to load `bfloat16` weights (`2b-it`, `2b-pt`, `7b-it`, -`7b-pt`) using the default switched floating point (sfp) or vice versa. Revisit -step #3 and check that the `cmake` command used to build `gemma` was correct for -the weights that you downloaded. +The most common problem is that the `--weight_type` argument does not match that +of the model file. Revisit step #3 and check which weights you downloaded. 
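For illustration only (not part of the patch itself, and the weights file name below is a placeholder): the `--weight_type` value must match how the `.sbs` file was exported, using one of the `f32`, `bf16`, or `sfp` values accepted by the new flag. For example, weights exported as bfloat16 would be loaded as:

```sh
# Hypothetical example: the runtime --weight_type flag must match the
# type the .sbs weights file was exported with (f32, bf16, or sfp).
./gemma \
--tokenizer tokenizer.spm \
--weights 2b-it-bf16.sbs \
--weight_type bf16 \
--model 2b-it
```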
-In the future we will handle model format handling from compile time to runtime -to simplify this. +Note that we have already moved weight type from a compile-time decision to a +runtime argument. In a subsequent step, we plan to bake this information into +the weights. **Problems building in Windows / Visual Studio** diff --git a/backprop/backward-inl.h b/backprop/backward-inl.h index 3249db0..fb7a3a1 100644 --- a/backprop/backward-inl.h +++ b/backprop/backward-inl.h @@ -21,7 +21,6 @@ #define THIRD_PARTY_GEMMA_CPP_GEMMA_BACKWARD_INL_H_ #include -#include #include #include @@ -44,6 +43,7 @@ #endif #include "gemma/ops.h" +#include "hwy/highway.h" HWY_BEFORE_NAMESPACE(); namespace gcpp { diff --git a/backprop/backward.cc b/backprop/backward.cc index 4baeb0a..bc1a630 100644 --- a/backprop/backward.cc +++ b/backprop/backward.cc @@ -15,6 +15,11 @@ #include "backprop/backward.h" +#include "backprop/prompt.h" +#include "gemma/activations.h" +#include "gemma/common.h" +#include "hwy/contrib/thread_pool/thread_pool.h" + // Compiles this file for multiple architectures via "foreach_target.h", to // which we pass the filename via macro 'argument'. #undef HWY_TARGET_INCLUDE @@ -29,7 +34,6 @@ HWY_BEFORE_NAMESPACE(); namespace gcpp { namespace HWY_NAMESPACE { -namespace hn = hwy::HWY_NAMESPACE; template void CrossEntropyLossBackwardPass(const Prompt& prompt, @@ -57,11 +61,11 @@ void CrossEntropyLossBackwardPassT(Model model, // TODO(janwas): use CallFunctorForModel switch (model) { case Model::GEMMA_2B: - CrossEntropyLossBackwardPass( + CrossEntropyLossBackwardPass>( prompt, weights, forward, grad, backward, pool); break; case Model::GEMMA_TINY: - CrossEntropyLossBackwardPass( + CrossEntropyLossBackwardPass>( prompt, weights, forward, grad, backward, pool); break; default: diff --git a/backprop/forward.cc b/backprop/forward.cc index b712ce6..0880ee2 100644 --- a/backprop/forward.cc +++ b/backprop/forward.cc @@ -15,6 +15,11 @@ #include "backprop/forward.h" +#include "backprop/prompt.h" +#include "gemma/activations.h" +#include "gemma/common.h" +#include "hwy/contrib/thread_pool/thread_pool.h" + // Compiles this file for multiple architectures via "foreach_target.h", to // which we pass the filename via macro 'argument'. #undef HWY_TARGET_INCLUDE @@ -29,7 +34,6 @@ HWY_BEFORE_NAMESPACE(); namespace gcpp { namespace HWY_NAMESPACE { -namespace hn = hwy::HWY_NAMESPACE; template float CrossEntropyLossForwardPass(const Prompt& prompt, @@ -51,10 +55,10 @@ float CrossEntropyLossForwardPassT(Model model, const Prompt& prompt, // TODO(janwas): use CallFunctorForModel switch (model) { case Model::GEMMA_2B: - return CrossEntropyLossForwardPass( - prompt, weights, forward, pool); + return CrossEntropyLossForwardPass>(prompt, weights, + forward, pool); case Model::GEMMA_TINY: - return CrossEntropyLossForwardPass( + return CrossEntropyLossForwardPass>( prompt, weights, forward, pool); default: HWY_ABORT("Model type %d unknown.", static_cast(model)); diff --git a/backprop/optimize_test.cc b/backprop/optimize_test.cc index 3698ceb..2a79049 100644 --- a/backprop/optimize_test.cc +++ b/backprop/optimize_test.cc @@ -13,18 +13,23 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#include -#include +#include +#include +#include +#include + +#include "gtest/gtest.h" #include "backprop/backward.h" #include "backprop/forward.h" #include "backprop/optimizer.h" +#include "backprop/prompt.h" #include "backprop/sampler.h" #include "gemma/activations.h" #include "gemma/common.h" #include "gemma/gemma.h" #include "gemma/weights.h" -#include "gtest/gtest.h" +#include "hwy/contrib/thread_pool/thread_pool.h" namespace gcpp { @@ -35,11 +40,17 @@ TEST(OptimizeTest, GradientDescent) { std::mt19937 gen(42); Model model_type = Model::GEMMA_TINY; - ByteStorageT grad = CallFunctorForModel(model_type, pool); - ByteStorageT grad_m = CallFunctorForModel(model_type, pool); - ByteStorageT grad_v = CallFunctorForModel(model_type, pool); - ByteStorageT forward = CallFunctorForModel(model_type); - ByteStorageT backward = CallFunctorForModel(model_type); + Type weight_type = Type::kF32; + ByteStorageT grad = + CallForModelAndWeight(model_type, weight_type, pool); + ByteStorageT grad_m = + CallForModelAndWeight(model_type, weight_type, pool); + ByteStorageT grad_v = + CallForModelAndWeight(model_type, weight_type, pool); + ByteStorageT forward = + CallForModelAndWeight(model_type, weight_type); + ByteStorageT backward = + CallForModelAndWeight(model_type, weight_type); KVCache kv_cache = KVCache::Create(model_type); size_t max_tokens = 32; size_t max_generated_tokens = 16; @@ -47,7 +58,7 @@ TEST(OptimizeTest, GradientDescent) { int verbosity = 0; const auto accept_token = [](int) { return true; }; - Gemma gemma(GemmaTokenizer(), model_type, pool); + Gemma gemma(GemmaTokenizer(), model_type, weight_type, pool); const auto generate = [&](const std::vector& prompt) { std::vector reply; @@ -76,12 +87,14 @@ TEST(OptimizeTest, GradientDescent) { return ok; }; - RandInitWeights(model_type, gemma.Weights(), pool, gen); - CallFunctorForModel(model_type, grad_m, pool); - CallFunctorForModel(model_type, grad_v, pool); + RandInitWeights(model_type, weight_type, gemma.Weights(), pool, gen); + CallForModelAndWeight(model_type, weight_type, grad_m, + pool); + CallForModelAndWeight(model_type, weight_type, grad_v, + pool); printf("Initial weights:\n"); - LogWeightStats(model_type, gemma.Weights()); + LogWeightStats(model_type, weight_type, gemma.Weights()); constexpr size_t kBatchSize = 8; const float alpha = 0.001f; @@ -96,7 +109,8 @@ TEST(OptimizeTest, GradientDescent) { size_t num_ok; for (; steps < 1000000; ++steps) { std::mt19937 sgen(42); - CallFunctorForModel(model_type, grad, pool); + CallForModelAndWeight(model_type, weight_type, grad, + pool); float total_loss = 0.0f; num_ok = 0; for (size_t i = 0; i < kBatchSize; ++i) { @@ -109,13 +123,13 @@ TEST(OptimizeTest, GradientDescent) { } total_loss /= kBatchSize; - AdamUpdate(model_type, grad, alpha, beta1, beta2, epsilon, steps + 1, - gemma.Weights(), grad_m, grad_v, pool); + AdamUpdate(model_type, weight_type, grad, alpha, beta1, beta2, epsilon, + steps + 1, gemma.Weights(), grad_m, grad_v, pool); printf("step: %zu total_loss: %.15f num_ok: %zu/%zu\n", steps, total_loss, num_ok, kBatchSize); if (steps % 100 == 0) { printf("Batch gradient:\n"); - LogWeightStats(model_type, grad); + LogWeightStats(model_type, weight_type, grad); } if (total_loss < 0.5f) { break; @@ -124,7 +138,7 @@ TEST(OptimizeTest, GradientDescent) { } printf("Num steps: %zu\n", steps); printf("Final weights:\n"); - LogWeightStats(model_type, gemma.Weights()); + LogWeightStats(model_type, weight_type, gemma.Weights()); EXPECT_LT(steps, 200); EXPECT_EQ(num_ok, kBatchSize); } diff --git 
a/backprop/optimizer.cc b/backprop/optimizer.cc index 5b97cfc..4302f9c 100644 --- a/backprop/optimizer.cc +++ b/backprop/optimizer.cc @@ -107,18 +107,20 @@ struct AdamUpdateT { } // namespace -void RandInitWeights(Model model, const ByteStorageT& weights, - hwy::ThreadPool& pool, +void RandInitWeights(Model model_type, Type weight_type, + const ByteStorageT& weights, hwy::ThreadPool& pool, std::mt19937& gen) { - CallFunctorForModel(model, weights, pool, gen); + CallForModelAndWeight(model_type, weight_type, weights, + pool, gen); } -void AdamUpdate(Model model, const ByteStorageT& grad, float alpha, float beta1, - float beta2, float epsilon, size_t t, +void AdamUpdate(Model model_type, Type weight_type, const ByteStorageT& grad, + float alpha, float beta1, float beta2, float epsilon, size_t t, const ByteStorageT& weights, const ByteStorageT& grad_m, const ByteStorageT& grad_v, hwy::ThreadPool& pool) { - CallFunctorForModel(model, grad, alpha, beta1, beta2, epsilon, t, - weights, grad_m, grad_v, pool); + CallForModelAndWeight(model_type, weight_type, grad, alpha, + beta1, beta2, epsilon, t, weights, grad_m, + grad_v, pool); } } // namespace gcpp diff --git a/backprop/optimizer.h b/backprop/optimizer.h index 90fa2c7..9157fa8 100644 --- a/backprop/optimizer.h +++ b/backprop/optimizer.h @@ -19,16 +19,16 @@ #include #include "gemma/common.h" -#include "gemma/weights.h" #include "hwy/contrib/thread_pool/thread_pool.h" namespace gcpp { -void RandInitWeights(Model model, const ByteStorageT& weights, - hwy::ThreadPool& pool, std::mt19937& gen); +void RandInitWeights(Model model_type, Type weight_type, + const ByteStorageT& weights, hwy::ThreadPool& pool, + std::mt19937& gen); -void AdamUpdate(Model model, const ByteStorageT& grad, float alpha, float beta1, - float beta2, float epsilon, size_t t, +void AdamUpdate(Model model_type, Type weight_type, const ByteStorageT& grad, + float alpha, float beta1, float beta2, float epsilon, size_t t, const ByteStorageT& weights, const ByteStorageT& grad_m, const ByteStorageT& grad_v, hwy::ThreadPool& pool); diff --git a/debug_prompt.cc b/debug_prompt.cc index 0ceb545..cb49fe3 100644 --- a/debug_prompt.cc +++ b/debug_prompt.cc @@ -113,7 +113,7 @@ int main(int argc, char** argv) { gcpp::PinWorkersToCores(pool); } - gcpp::Gemma model(loader.tokenizer, loader.weights, loader.ModelType(), pool); + gcpp::Gemma model = gcpp::CreateGemma(loader, pool); gcpp::KVCache kv_cache = gcpp::KVCache::Create(loader.ModelType()); const std::string& prompt = prompt_args.prompt; diff --git a/examples/hello_world/run.cc b/examples/hello_world/run.cc index 9ac7259..ac917af 100644 --- a/examples/hello_world/run.cc +++ b/examples/hello_world/run.cc @@ -39,7 +39,7 @@ int main(int argc, char** argv) { hwy::ThreadPool pool(num_threads); // Instantiate model and KV Cache - gcpp::Gemma model(loader.tokenizer, loader.weights, loader.ModelType(), pool); + gcpp::Gemma model = gcpp::CreateGemma(loader, pool); gcpp::KVCache kv_cache = gcpp::KVCache::Create(loader.ModelType()); size_t pos = 0; // KV Cache position diff --git a/gemma/benchmark.cc b/gemma/benchmark.cc index 2f0b132..7a04f1f 100644 --- a/gemma/benchmark.cc +++ b/gemma/benchmark.cc @@ -280,7 +280,7 @@ int main(int argc, char** argv) { gcpp::PinWorkersToCores(pool); } - gcpp::Gemma model(loader.tokenizer, loader.weights, loader.ModelType(), pool); + gcpp::Gemma model = gcpp::CreateGemma(loader, pool); gcpp::KVCache kv_cache = gcpp::KVCache::Create(loader.ModelType()); if (!benchmark_args.goldens.path.empty()) { diff --git 
a/gemma/common.cc b/gemma/common.cc index 8ed8718..45f98d7 100644 --- a/gemma/common.cc +++ b/gemma/common.cc @@ -64,4 +64,28 @@ const char* ParseModelTypeAndTraining(const std::string& model_flag, return kErrorMessageBuffer; } +const char* ParseType(const std::string& type_string, Type& type) { + constexpr Type kTypes[] = {Type::kF32, Type::kBF16, Type::kSFP}; + constexpr const char* kStrings[] = {"f32", "bf16", "sfp"}; + constexpr size_t kNum = std::end(kStrings) - std::begin(kStrings); + static char kErrorMessageBuffer[kNum * 8 + 100] = + "Invalid or missing type, need to specify one of "; + for (size_t i = 0; i + 1 < kNum; i++) { + strcat(kErrorMessageBuffer, kStrings[i]); // NOLINT + strcat(kErrorMessageBuffer, ", "); // NOLINT + } + strcat(kErrorMessageBuffer, kStrings[kNum - 1]); // NOLINT + strcat(kErrorMessageBuffer, "."); // NOLINT + std::string type_lc = type_string; + std::transform(begin(type_lc), end(type_lc), begin(type_lc), + [](unsigned char c) { return std::tolower(c); }); + for (size_t i = 0; i < kNum; i++) { + if (kStrings[i] == type_lc) { + type = kTypes[i]; + return nullptr; + } + } + return kErrorMessageBuffer; +} + } // namespace gcpp diff --git a/gemma/common.h b/gemma/common.h index e3c2650..e5a581e 100644 --- a/gemma/common.h +++ b/gemma/common.h @@ -21,6 +21,7 @@ #include +#include "compression/compress.h" #include "gemma/configs.h" // IWYU pragma: export #include "hwy/aligned_allocator.h" #include "hwy/base.h" // ConvertScalarTo @@ -37,67 +38,129 @@ ByteStorageT AllocateSizeof() { // Model variants: see configs.h for details. enum class Model { GEMMA_2B, GEMMA_7B, GRIFFIN_2B, GEMMA_TINY }; +// Instruction-tuned models require extra 'turn structure' tokens in prompts. enum class ModelTraining { GEMMA_IT, GEMMA_PT }; -// Returns the return value of Func().operator() called with `args`, where -// `T` is selected based on `model`. +// Tensor types for loading weights. +enum class Type { kF32, kBF16, kSFP }; + +// Returns the return value of FuncT>().operator()(args), where +// Config* is selected via `model`. Typically called by CallForModelAndWeight, +// but can also be called directly when FuncT does not actually use TWeight. // -// This is used to implement type-erased functions such as -// LoadCompressedWeights, which can be called from other .cc files, by calling a -// functor LoadCompressedWeightsT, which has a template argument. `Func` must -// be a functor because function templates cannot be passed as a template -// template argument, and we prefer to avoid the overhead of std::function. +// Note that a T prefix indicates a concrete type template argument, whereas a +// T suffix indicates the argument is itself a template. // -// This function avoids having to update all call sites when we extend `Model`. -template