From 810b5a0cc23e36fe0f120c6071b50b6d33a28c28 Mon Sep 17 00:00:00 2001 From: austinvhuang Date: Fri, 15 Mar 2024 14:10:24 -0400 Subject: [PATCH 01/25] Update README with more details on contributing code, add experimental/ directory, add READMEs for subdirectories, clean up DEVELOPER notes --- DEVELOPERS.md | 15 ++++++++------- README.md | 7 ++++++- examples/README.md | 7 +++++++ experimental/.gitkeep | 0 experimental/README.md | 3 +++ 5 files changed, 24 insertions(+), 8 deletions(-) create mode 100644 examples/README.md create mode 100644 experimental/.gitkeep create mode 100644 experimental/README.md diff --git a/DEVELOPERS.md b/DEVELOPERS.md index 43b3187..4e104b9 100644 --- a/DEVELOPERS.md +++ b/DEVELOPERS.md @@ -118,8 +118,7 @@ jax / pytorch / keras for NN deployments. ### Gemma struct contains all the state of the inference engine - tokenizer, weights, and activations -`Gemma(...)` - constructor, creates a gemma model object, which is a wrapper -around 3 things - the tokenizer object, weights, activations, and KV Cache. +`Gemma(...)` - constructor, creates a gemma model object. In a standard LLM chat app, you'll probably use a Gemma object directly, in more exotic data processing or research applications, you might decompose @@ -129,11 +128,13 @@ only using a Gemma object. ### Use the tokenizer in the Gemma object (or interact with the Tokenizer object directly) -You pretty much only do things with the tokenizer, call `Encode()` to go from -string prompts to token id vectors, or `Decode()` to go from token id vector -outputs from the model back to strings. +The Gemma object contains a pointer to a Tokenizer object. The main +operations performed on the tokenizer are to load the tokenizer model from a +file (usually `tokenizer.spm`), call `Encode()` to go from string prompts to +token id vectors, or `Decode()` to go from token id vector outputs from the +model back to strings. 
-### The main entrypoint for generation is `GenerateGemma()` +### `GenerateGemma()` is the entrypoint for token generation Calling into `GenerateGemma` with a tokenized prompt will 1) mutate the activation values in `model` and 2) invoke StreamFunc - a lambda callback for @@ -150,7 +151,7 @@ constrained decoding type of use cases where you want to force the generation to fit a grammar. If you're not doing this, you can send an empty lambda as a no-op which is what `run.cc` does. -### If you want to invoke the neural network forward function directly call the `Transformer()` function +### `Transformer()` implements the inference (i.e. `forward()` method in PyTorch or Jax) computation of the neural network For high-level applications, you might only call `GenerateGemma()` and never interact directly with the neural network, but if you're doing something a bit diff --git a/README.md b/README.md index b0b9aad..5b0df18 100644 --- a/README.md +++ b/README.md @@ -36,7 +36,12 @@ For production-oriented edge deployments we recommend standard deployment pathways using Python frameworks like JAX, Keras, PyTorch, and Transformers ([all model variations here](https://www.kaggle.com/models/google/gemma)). -Community contributions large and small are welcome. This project follows +## Contributing + +Community contributions large and small are welcome. See +[DEVELOPERS.md](https://github.com/google/gemma.cpp/blob/main/DEVELOPERS.md) +for additional notes for contributing developers and [join the discord by following +this invite link](https://discord.gg/H5jCBAWxAe). This project follows [Google's Open Source Community Guidelines](https://opensource.google.com/conduct/). diff --git a/examples/README.md b/examples/README.md new file mode 100644 index 0000000..87eb54d --- /dev/null +++ b/examples/README.md @@ -0,0 +1,7 @@ +# Examples + +In this directory are some simple examples illustrating usage of `gemma.cpp` as +a library beyond the interactive `gemma` app implemented in `run.cc`. 
+ +- `hello_world/` - minimal/template project for using `gemma.cpp` as a library. + It sets up the model state and generates text for a single hard coded prompt. diff --git a/experimental/.gitkeep b/experimental/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/experimental/README.md b/experimental/README.md new file mode 100644 index 0000000..2b6ff83 --- /dev/null +++ b/experimental/README.md @@ -0,0 +1,3 @@ +# Experimental + +This directory is for experimental code and features. From fdb1091b9ca65b5265cb83632c2b4f685beafe7a Mon Sep 17 00:00:00 2001 From: Eric Ye Date: Tue, 19 Mar 2024 16:07:47 -0700 Subject: [PATCH 02/25] Connect "--weights" parameter to Gemma PiperOrigin-RevId: 617323257 --- gemma.cc | 5 ----- gemma.h | 2 -- run.cc | 2 +- 3 files changed, 1 insertion(+), 8 deletions(-) diff --git a/gemma.cc b/gemma.cc index f2a2275..35a4a47 100644 --- a/gemma.cc +++ b/gemma.cc @@ -839,11 +839,6 @@ Gemma::Gemma(const Path& tokenizer_path, const Path& compressed_weights_path, } } -Gemma::Gemma(const Path& tokenizer_path, const Path& compressed_weights_path, - Model model_type, hwy::ThreadPool& pool) - : Gemma(tokenizer_path, compressed_weights_path, Path{""}, model_type, - pool) {} - Gemma::~Gemma() = default; // after GemmaInterface is defined const sentencepiece::SentencePieceProcessor* Gemma::Tokenizer() const { diff --git a/gemma.h b/gemma.h index cdd4873..f5e88fa 100644 --- a/gemma.h +++ b/gemma.h @@ -76,8 +76,6 @@ struct GemmaInterface; struct Gemma { Gemma(const Path& tokenizer_path, const Path& compressed_weights_path, const Path& weights_path, Model model_type, hwy::ThreadPool& pool); - Gemma(const Path& tokenizer_path, const Path& compressed_weights_path, - Model model_type, hwy::ThreadPool& pool); ~Gemma(); // must be defined after GemmaInterface's dtor is defined. 
const sentencepiece::SentencePieceProcessor* Tokenizer() const; std::unique_ptr impl_; diff --git a/run.cc b/run.cc index b08e4ca..fcf974b 100644 --- a/run.cc +++ b/run.cc @@ -234,7 +234,7 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) { [](uint64_t /*task*/, size_t thread) { PinThreadToCore(thread); }); } - gcpp::Gemma model(loader.tokenizer, loader.compressed_weights, + gcpp::Gemma model(loader.tokenizer, loader.compressed_weights, loader.weights, loader.ModelType(), pool); auto kv_cache = CreateKVCache(loader.ModelType()); From 5e0cafbdc2a421a4ab3f8dda906cba79437a2439 Mon Sep 17 00:00:00 2001 From: Jan Wassenberg Date: Tue, 19 Mar 2024 21:12:06 -0700 Subject: [PATCH 03/25] Fix msan error, uninitialized model_training This arose during the unpacking of LoaderArgs into individual ctor args. Probably better to pass LoaderArgs in, and have only a single ctor to reduce confusion. Also fix includes. PiperOrigin-RevId: 617386447 --- gemma.cc | 7 +++++-- gemma.h | 8 +++----- run.cc | 2 +- 3 files changed, 9 insertions(+), 8 deletions(-) diff --git a/gemma.cc b/gemma.cc index 35a4a47..7c9d187 100644 --- a/gemma.cc +++ b/gemma.cc @@ -25,6 +25,8 @@ #include "compression/compress-inl.h" // copybara:import_next_line:gemma_cpp #include "ops.h" +// copybara:import_next_line:gemma_cpp +#include "util/args.h" // Path #include "hwy/contrib/matvec/matvec-inl.h" #include "hwy/highway.h" #include "hwy/profiler.h" @@ -813,8 +815,9 @@ void GemmaImpl::Generate( } Gemma::Gemma(const Path& tokenizer_path, const Path& compressed_weights_path, - const Path& weights_path, Model model_type, - hwy::ThreadPool& pool) { + const Path& weights_path, Model model_type, ModelTraining training, + hwy::ThreadPool& pool) + : model_training(training) { std::unique_ptr tokenizer; { PROFILER_ZONE("Startup.tokenizer"); diff --git a/gemma.h b/gemma.h index f5e88fa..d52356e 100644 --- a/gemma.h +++ b/gemma.h @@ -16,12 +16,9 @@ #ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_H_ #define 
THIRD_PARTY_GEMMA_CPP_GEMMA_H_ -#include -#include #include #include #include -#include #include // copybara:import_next_line:gemma_cpp @@ -31,7 +28,7 @@ #include "configs.h" // kSeqLen // copybara:end // copybara:import_next_line:gemma_cpp -#include "util/args.h" // ArgsBase +#include "util/args.h" // Path // copybara:end #include "hwy/aligned_allocator.h" #include "hwy/base.h" // hwy::bfloat16_t @@ -75,7 +72,8 @@ struct GemmaInterface; struct Gemma { Gemma(const Path& tokenizer_path, const Path& compressed_weights_path, - const Path& weights_path, Model model_type, hwy::ThreadPool& pool); + const Path& weights_path, Model model_type, ModelTraining training, + hwy::ThreadPool& pool); ~Gemma(); // must be defined after GemmaInterface's dtor is defined. const sentencepiece::SentencePieceProcessor* Tokenizer() const; std::unique_ptr impl_; diff --git a/run.cc b/run.cc index fcf974b..ce9de93 100644 --- a/run.cc +++ b/run.cc @@ -235,7 +235,7 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) { } gcpp::Gemma model(loader.tokenizer, loader.compressed_weights, loader.weights, - loader.ModelType(), pool); + loader.ModelType(), loader.ModelTraining(), pool); auto kv_cache = CreateKVCache(loader.ModelType()); From 130e1f678fbc5cfb6c17a7f15e7c882ff21a4d46 Mon Sep 17 00:00:00 2001 From: RangerUFO Date: Tue, 19 Mar 2024 22:00:52 +0800 Subject: [PATCH 04/25] Adjust vocab size to be the same as gemma_pytorch --- configs.h | 4 ++-- util/convert_weights.py | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/configs.h b/configs.h index 7b420b5..58c053f 100644 --- a/configs.h +++ b/configs.h @@ -37,7 +37,7 @@ static constexpr size_t kTopK = GEMMA_TOPK; struct ConfigGemma7B { static constexpr int kSeqLen = gcpp::kSeqLen; - static constexpr int kVocabSize = 256128; + static constexpr int kVocabSize = 256000; static constexpr int kLayers = 28; static constexpr int kModelDim = 3072; static constexpr int kFFHiddenDim = 16 * 3072 / 2; // = 24576 
@@ -49,7 +49,7 @@ struct ConfigGemma7B { struct ConfigGemma2B { static constexpr int kSeqLen = gcpp::kSeqLen; - static constexpr int kVocabSize = 256128; + static constexpr int kVocabSize = 256000; static constexpr int kLayers = 18; static constexpr int kModelDim = 2048; static constexpr int kFFHiddenDim = 16 * 2048 / 2; // = 16384 diff --git a/util/convert_weights.py b/util/convert_weights.py index bd6750a..6552d89 100644 --- a/util/convert_weights.py +++ b/util/convert_weights.py @@ -90,7 +90,7 @@ TRANSFORMATIONS = { "2b":defaultdict( lambda: lambda x: x, { - "embedder.weight": lambda x: np.concatenate([x, np.zeros([128, 2048])], 0), + "embedder.weight": lambda x: x, "self_attn.qkv_proj.weight": expand_qkv, "self_attn.o_proj.weight": lambda x: x.reshape((2048, 8, 256)).transpose([1,0,2]), "mlp.gate_proj.weight": lambda x: x[np.newaxis, :, :], @@ -101,7 +101,7 @@ TRANSFORMATIONS = { "7b":defaultdict( lambda: lambda x: x, { - "embedder.weight": lambda x: np.concatenate([x, np.zeros([128, 3072])], 0), + "embedder.weight": lambda x: x, "self_attn.qkv_proj.weight": lambda x: x.reshape((3, 16, 256, 3072)).transpose([1,0,2,3]), "self_attn.o_proj.weight": lambda x: x.reshape((3072, 16, 256)).transpose([1,0,2]), "mlp.gate_proj.weight": lambda x: x[np.newaxis, :, :], @@ -113,7 +113,7 @@ TRANSFORMATIONS = { VALIDATIONS = { "2b": { - "embedder.weight": lambda x: x.shape == (256128, 2048), + "embedder.weight": lambda x: x.shape == (256000, 2048), "model.norm.weight": lambda x: x.shape == (2048,), "self_attn.qkv_proj.weight": lambda x: x.shape == (8, 3, 256, 2048), "self_attn.o_proj.weight": lambda x: x.shape == (8, 2048, 256), @@ -124,7 +124,7 @@ VALIDATIONS = { "post_attention_layernorm.weight": lambda x: x.shape == (2048,), }, "7b": { - "embedder.weight": lambda x: x.shape == (256128, 3072), + "embedder.weight": lambda x: x.shape == (256000, 3072), "model.norm.weight": lambda x: x.shape == (3072,), "self_attn.qkv_proj.weight": lambda x: x.shape == (16, 3, 256, 3072), 
"self_attn.o_proj.weight": lambda x: x.shape == (16, 3072, 256), From 6923aec853f2d8df5038855e1a52ab56ee99d06f Mon Sep 17 00:00:00 2001 From: RangerUFO Date: Wed, 20 Mar 2024 18:14:09 +0800 Subject: [PATCH 05/25] Add MQA support --- configs.h | 2 +- gemma.cc | 56 ++++++++++++++++++++++++++++++----------- util/convert_weights.py | 18 ++----------- 3 files changed, 44 insertions(+), 32 deletions(-) diff --git a/configs.h b/configs.h index 58c053f..e704664 100644 --- a/configs.h +++ b/configs.h @@ -54,7 +54,7 @@ struct ConfigGemma2B { static constexpr int kModelDim = 2048; static constexpr int kFFHiddenDim = 16 * 2048 / 2; // = 16384 static constexpr int kHeads = 8; - static constexpr int kKVHeads = 8; // TODO(austinvhuang): add MQA support + static constexpr int kKVHeads = 1; static constexpr int kQKVDim = 256; // query size == key size == value size static constexpr int kTopK = gcpp::kTopK; }; diff --git a/gemma.cc b/gemma.cc index 7c9d187..1867fbf 100644 --- a/gemma.cc +++ b/gemma.cc @@ -70,12 +70,13 @@ template struct Layer { Layer() = default; static constexpr size_t kHeads = TConfig::kHeads; + static constexpr size_t kKVHeads = TConfig::kKVHeads; static constexpr size_t kModelDim = TConfig::kModelDim; static constexpr size_t kQKVDim = TConfig::kQKVDim; static constexpr size_t kFFHiddenDim = TConfig::kFFHiddenDim; static constexpr size_t kAttVecEinsumWSize = kHeads * kQKVDim * kModelDim; - // 3x for (query, key, value) - static constexpr size_t kQKVEinsumWSize = 3 * kHeads * kQKVDim * kModelDim; + static constexpr size_t kQKVEinsumWSize = + (kHeads + 2 * kKVHeads) * kQKVDim * kModelDim; // 2x for (gelu gating vector, gated vector) static constexpr size_t kGatingEinsumWSize = 2 * kFFHiddenDim * kModelDim; @@ -313,28 +314,46 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer, static constexpr size_t kModelDim = gcpp::Activations::kModelDim; static constexpr size_t kHeads = TConfig::kHeads; + static constexpr size_t kKVHeads = 
TConfig::kKVHeads; static const float kQueryScale = static_cast(1.0 / sqrt(static_cast(kQKVDim))); + const size_t batch_offset = batch_idx * kModelDim; + pool.Run(0, kHeads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR { // linear projections to QKV - const size_t head_offset = - 3 * kQKVDim * kModelDim; // 3x for QKV dimensions + constexpr const size_t head_offset = + kHeads == kKVHeads ? 3 * kQKVDim * kModelDim : kQKVDim * kModelDim; const size_t q_offset = head * head_offset + 0 * kQKVDim * kModelDim; - const size_t k_offset = head * head_offset + 1 * kQKVDim * kModelDim; - const size_t v_offset = head * head_offset + 2 * kQKVDim * kModelDim; float* HWY_RESTRICT q = activations.q.data() + head * kQKVDim + batch_idx * kHeads * kQKVDim; - const size_t batch_offset = batch_idx * kModelDim; - MatVecLoop( c_layer->c_qkv_einsum_w, q_offset, activations.pre_att_rms_out.data() + batch_offset, q); - const size_t kv_offset = - pos * kCachePosSize + layer * kCacheLayerSize + head * kQKVDim; + if constexpr (kHeads == kKVHeads) { + const size_t k_offset = head * head_offset + 1 * kQKVDim * kModelDim; + const size_t v_offset = head * head_offset + 2 * kQKVDim * kModelDim; + const size_t kv_offset = + pos * kCachePosSize + layer * kCacheLayerSize + head * kQKVDim; + + TwoOfsMatVecLoop( + c_layer->c_qkv_einsum_w, k_offset, v_offset, + activations.pre_att_rms_out.data() + batch_offset, + kv_cache.key_cache.get() + kv_offset, + kv_cache.value_cache.get() + kv_offset); + + Rope(kv_cache.key_cache.get() + kv_offset, kQKVDim, pos); + } + }); + + if constexpr (kHeads != kKVHeads) { + constexpr const size_t q_offset = kHeads * kQKVDim * kModelDim; + constexpr const size_t k_offset = q_offset + 0 * kQKVDim * kModelDim; + constexpr const size_t v_offset = q_offset + 1 * kQKVDim * kModelDim; + const size_t kv_offset = pos * kCachePosSize + layer * kCacheLayerSize; TwoOfsMatVecLoop( c_layer->c_qkv_einsum_w, k_offset, v_offset, @@ -342,18 +361,24 @@ HWY_NOINLINE void 
Attention(size_t batch_start, size_t batch_idx, size_t layer, kv_cache.key_cache.get() + kv_offset, kv_cache.value_cache.get() + kv_offset); + Rope(kv_cache.key_cache.get() + kv_offset, kQKVDim, pos); + } + + pool.Run(0, kHeads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR { // Calculate scores + float* HWY_RESTRICT q = + activations.q.data() + head * kQKVDim + batch_idx * kHeads * kQKVDim; float* HWY_RESTRICT head_att = activations.att.data() + head * TConfig::kSeqLen + batch_idx * kHeads * kQKVDim; Rope(q, kQKVDim, pos); - Rope(kv_cache.key_cache.get() + kv_offset, kQKVDim, pos); MulByConst(kQueryScale, q, kQKVDim); // Compute Q dot K scores for (size_t pos2 = 0; pos2 <= pos; ++pos2) { - const size_t cache_offset = - pos2 * kCachePosSize + layer * kCacheLayerSize + head * kQKVDim; + const size_t cache_offset = kHeads == kKVHeads + ? pos2 * kCachePosSize + layer * kCacheLayerSize + head * kQKVDim + : pos2 * kCachePosSize + layer * kCacheLayerSize; const float* HWY_RESTRICT k2 = kv_cache.key_cache.get() + cache_offset; const float score = Dot(q, k2, kQKVDim); head_att[pos2] = score; @@ -365,8 +390,9 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer, batch_idx * kHeads * kQKVDim; hwy::ZeroBytes(att_out, kQKVDim * sizeof(*att_out)); for (size_t pos2 = 0; pos2 <= pos; ++pos2) { - const size_t cache_offset = - pos2 * kCachePosSize + layer * kCacheLayerSize + head * kQKVDim; + const size_t cache_offset = kHeads == kKVHeads + ? 
pos2 * kCachePosSize + layer * kCacheLayerSize + head * kQKVDim + : pos2 * kCachePosSize + layer * kCacheLayerSize; float* HWY_RESTRICT v2 = kv_cache.value_cache.get() + cache_offset; MulByConstAndAdd(head_att[pos2], v2, att_out, kQKVDim); } diff --git a/util/convert_weights.py b/util/convert_weights.py index 6552d89..0211c01 100644 --- a/util/convert_weights.py +++ b/util/convert_weights.py @@ -72,26 +72,12 @@ parser.add_argument( args = parser.parse_args() -def expand_qkv(qkv_proj: np.array) -> np.array: - """This won't be needed anymore when MQA is implemented""" - assert qkv_proj.shape == (2560, 2048) - qkv = qkv_proj.reshape((10, 256, 2048)) - - q_proj = qkv[:8].reshape((1,8,256,2048)) - kv_proj = qkv[8:] - kv_proj = kv_proj[:, np.newaxis, :, :] - kv_proj = np.repeat(kv_proj, 8, axis=1) - - qkv = np.concatenate([q_proj, kv_proj]) - qkv = np.transpose(qkv, axes=[1,0,2,3]) - return qkv - TRANSFORMATIONS = { "2b":defaultdict( lambda: lambda x: x, { "embedder.weight": lambda x: x, - "self_attn.qkv_proj.weight": expand_qkv, + "self_attn.qkv_proj.weight": lambda x: x.reshape((10, 256, 2048)), "self_attn.o_proj.weight": lambda x: x.reshape((2048, 8, 256)).transpose([1,0,2]), "mlp.gate_proj.weight": lambda x: x[np.newaxis, :, :], "mlp.up_proj.weight": lambda x: x[np.newaxis, :, :], @@ -115,7 +101,7 @@ VALIDATIONS = { "2b": { "embedder.weight": lambda x: x.shape == (256000, 2048), "model.norm.weight": lambda x: x.shape == (2048,), - "self_attn.qkv_proj.weight": lambda x: x.shape == (8, 3, 256, 2048), + "self_attn.qkv_proj.weight": lambda x: x.shape == (10, 256, 2048), "self_attn.o_proj.weight": lambda x: x.shape == (8, 2048, 256), "mlp.gate_proj.weight": lambda x: x.shape == (1, 16384, 2048), "mlp.up_proj.weight": lambda x: x.shape == (1, 16384, 2048), From fdc3812446327b60d3a187c1153ec6791fc773ad Mon Sep 17 00:00:00 2001 From: Eric Ye Date: Tue, 19 Mar 2024 23:35:58 +0100 Subject: [PATCH 06/25] No public description PiperOrigin-RevId: 617315030 --- gemma.cc | 12 
+++++++----- gemma.h | 10 +++++++--- run.cc | 4 ++-- 3 files changed, 16 insertions(+), 10 deletions(-) diff --git a/gemma.cc b/gemma.cc index 7c9d187..f2a2275 100644 --- a/gemma.cc +++ b/gemma.cc @@ -25,8 +25,6 @@ #include "compression/compress-inl.h" // copybara:import_next_line:gemma_cpp #include "ops.h" -// copybara:import_next_line:gemma_cpp -#include "util/args.h" // Path #include "hwy/contrib/matvec/matvec-inl.h" #include "hwy/highway.h" #include "hwy/profiler.h" @@ -815,9 +813,8 @@ void GemmaImpl::Generate( } Gemma::Gemma(const Path& tokenizer_path, const Path& compressed_weights_path, - const Path& weights_path, Model model_type, ModelTraining training, - hwy::ThreadPool& pool) - : model_training(training) { + const Path& weights_path, Model model_type, + hwy::ThreadPool& pool) { std::unique_ptr tokenizer; { PROFILER_ZONE("Startup.tokenizer"); @@ -842,6 +839,11 @@ Gemma::Gemma(const Path& tokenizer_path, const Path& compressed_weights_path, } } +Gemma::Gemma(const Path& tokenizer_path, const Path& compressed_weights_path, + Model model_type, hwy::ThreadPool& pool) + : Gemma(tokenizer_path, compressed_weights_path, Path{""}, model_type, + pool) {} + Gemma::~Gemma() = default; // after GemmaInterface is defined const sentencepiece::SentencePieceProcessor* Gemma::Tokenizer() const { diff --git a/gemma.h b/gemma.h index d52356e..cdd4873 100644 --- a/gemma.h +++ b/gemma.h @@ -16,9 +16,12 @@ #ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_H_ #define THIRD_PARTY_GEMMA_CPP_GEMMA_H_ +#include +#include #include #include #include +#include #include // copybara:import_next_line:gemma_cpp @@ -28,7 +31,7 @@ #include "configs.h" // kSeqLen // copybara:end // copybara:import_next_line:gemma_cpp -#include "util/args.h" // Path +#include "util/args.h" // ArgsBase // copybara:end #include "hwy/aligned_allocator.h" #include "hwy/base.h" // hwy::bfloat16_t @@ -72,8 +75,9 @@ struct GemmaInterface; struct Gemma { Gemma(const Path& tokenizer_path, const Path& compressed_weights_path, - 
const Path& weights_path, Model model_type, ModelTraining training, - hwy::ThreadPool& pool); + const Path& weights_path, Model model_type, hwy::ThreadPool& pool); + Gemma(const Path& tokenizer_path, const Path& compressed_weights_path, + Model model_type, hwy::ThreadPool& pool); ~Gemma(); // must be defined after GemmaInterface's dtor is defined. const sentencepiece::SentencePieceProcessor* Tokenizer() const; std::unique_ptr impl_; diff --git a/run.cc b/run.cc index ce9de93..b08e4ca 100644 --- a/run.cc +++ b/run.cc @@ -234,8 +234,8 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) { [](uint64_t /*task*/, size_t thread) { PinThreadToCore(thread); }); } - gcpp::Gemma model(loader.tokenizer, loader.compressed_weights, loader.weights, - loader.ModelType(), loader.ModelTraining(), pool); + gcpp::Gemma model(loader.tokenizer, loader.compressed_weights, + loader.ModelType(), pool); auto kv_cache = CreateKVCache(loader.ModelType()); From 6865819bb72f765d42915d69a6def106366de46d Mon Sep 17 00:00:00 2001 From: Eric Ye Date: Wed, 20 Mar 2024 00:07:47 +0100 Subject: [PATCH 07/25] Connect "--weights" parameter to Gemma PiperOrigin-RevId: 617323257 --- gemma.cc | 5 ----- gemma.h | 2 -- run.cc | 2 +- 3 files changed, 1 insertion(+), 8 deletions(-) diff --git a/gemma.cc b/gemma.cc index f2a2275..35a4a47 100644 --- a/gemma.cc +++ b/gemma.cc @@ -839,11 +839,6 @@ Gemma::Gemma(const Path& tokenizer_path, const Path& compressed_weights_path, } } -Gemma::Gemma(const Path& tokenizer_path, const Path& compressed_weights_path, - Model model_type, hwy::ThreadPool& pool) - : Gemma(tokenizer_path, compressed_weights_path, Path{""}, model_type, - pool) {} - Gemma::~Gemma() = default; // after GemmaInterface is defined const sentencepiece::SentencePieceProcessor* Gemma::Tokenizer() const { diff --git a/gemma.h b/gemma.h index cdd4873..f5e88fa 100644 --- a/gemma.h +++ b/gemma.h @@ -76,8 +76,6 @@ struct GemmaInterface; struct Gemma { Gemma(const Path& tokenizer_path, const 
Path& compressed_weights_path, const Path& weights_path, Model model_type, hwy::ThreadPool& pool); - Gemma(const Path& tokenizer_path, const Path& compressed_weights_path, - Model model_type, hwy::ThreadPool& pool); ~Gemma(); // must be defined after GemmaInterface's dtor is defined. const sentencepiece::SentencePieceProcessor* Tokenizer() const; std::unique_ptr impl_; diff --git a/run.cc b/run.cc index b08e4ca..fcf974b 100644 --- a/run.cc +++ b/run.cc @@ -234,7 +234,7 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) { [](uint64_t /*task*/, size_t thread) { PinThreadToCore(thread); }); } - gcpp::Gemma model(loader.tokenizer, loader.compressed_weights, + gcpp::Gemma model(loader.tokenizer, loader.compressed_weights, loader.weights, loader.ModelType(), pool); auto kv_cache = CreateKVCache(loader.ModelType()); From 11d9c51473d7a838896a81b0fc9d4c4976c534e7 Mon Sep 17 00:00:00 2001 From: Jan Wassenberg Date: Wed, 20 Mar 2024 05:12:06 +0100 Subject: [PATCH 08/25] Fix msan error, uninitialized model_training This arose during the unpacking of LoaderArgs into individual ctor args. Probably better to pass LoaderArgs in, and have only a single ctor to reduce confusion. Also fix includes. 
PiperOrigin-RevId: 617386447 --- gemma.cc | 7 +++++-- gemma.h | 8 +++----- run.cc | 2 +- 3 files changed, 9 insertions(+), 8 deletions(-) diff --git a/gemma.cc b/gemma.cc index 35a4a47..7c9d187 100644 --- a/gemma.cc +++ b/gemma.cc @@ -25,6 +25,8 @@ #include "compression/compress-inl.h" // copybara:import_next_line:gemma_cpp #include "ops.h" +// copybara:import_next_line:gemma_cpp +#include "util/args.h" // Path #include "hwy/contrib/matvec/matvec-inl.h" #include "hwy/highway.h" #include "hwy/profiler.h" @@ -813,8 +815,9 @@ void GemmaImpl::Generate( } Gemma::Gemma(const Path& tokenizer_path, const Path& compressed_weights_path, - const Path& weights_path, Model model_type, - hwy::ThreadPool& pool) { + const Path& weights_path, Model model_type, ModelTraining training, + hwy::ThreadPool& pool) + : model_training(training) { std::unique_ptr tokenizer; { PROFILER_ZONE("Startup.tokenizer"); diff --git a/gemma.h b/gemma.h index f5e88fa..d52356e 100644 --- a/gemma.h +++ b/gemma.h @@ -16,12 +16,9 @@ #ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_H_ #define THIRD_PARTY_GEMMA_CPP_GEMMA_H_ -#include -#include #include #include #include -#include #include // copybara:import_next_line:gemma_cpp @@ -31,7 +28,7 @@ #include "configs.h" // kSeqLen // copybara:end // copybara:import_next_line:gemma_cpp -#include "util/args.h" // ArgsBase +#include "util/args.h" // Path // copybara:end #include "hwy/aligned_allocator.h" #include "hwy/base.h" // hwy::bfloat16_t @@ -75,7 +72,8 @@ struct GemmaInterface; struct Gemma { Gemma(const Path& tokenizer_path, const Path& compressed_weights_path, - const Path& weights_path, Model model_type, hwy::ThreadPool& pool); + const Path& weights_path, Model model_type, ModelTraining training, + hwy::ThreadPool& pool); ~Gemma(); // must be defined after GemmaInterface's dtor is defined. 
const sentencepiece::SentencePieceProcessor* Tokenizer() const; std::unique_ptr impl_; diff --git a/run.cc b/run.cc index fcf974b..ce9de93 100644 --- a/run.cc +++ b/run.cc @@ -235,7 +235,7 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) { } gcpp::Gemma model(loader.tokenizer, loader.compressed_weights, loader.weights, - loader.ModelType(), pool); + loader.ModelType(), loader.ModelTraining(), pool); auto kv_cache = CreateKVCache(loader.ModelType()); From ce32f4db81f9ac91ac18fff42516e3c1a3f12b24 Mon Sep 17 00:00:00 2001 From: RangerUFO Date: Wed, 20 Mar 2024 22:39:31 +0800 Subject: [PATCH 09/25] Streamline the implementation --- gemma.cc | 40 ++++++++++++++++++++-------------------- 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/gemma.cc b/gemma.cc index 1867fbf..76086d9 100644 --- a/gemma.cc +++ b/gemma.cc @@ -320,6 +320,16 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer, const size_t batch_offset = batch_idx * kModelDim; + auto ProjKV = [&](size_t k_offset, size_t v_offset, size_t kv_offset) { + TwoOfsMatVecLoop( + c_layer->c_qkv_einsum_w, k_offset, v_offset, + activations.pre_att_rms_out.data() + batch_offset, + kv_cache.key_cache.get() + kv_offset, + kv_cache.value_cache.get() + kv_offset); + + Rope(kv_cache.key_cache.get() + kv_offset, kQKVDim, pos); + }; + pool.Run(0, kHeads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR { // linear projections to QKV constexpr const size_t head_offset = @@ -339,13 +349,7 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer, const size_t kv_offset = pos * kCachePosSize + layer * kCacheLayerSize + head * kQKVDim; - TwoOfsMatVecLoop( - c_layer->c_qkv_einsum_w, k_offset, v_offset, - activations.pre_att_rms_out.data() + batch_offset, - kv_cache.key_cache.get() + kv_offset, - kv_cache.value_cache.get() + kv_offset); - - Rope(kv_cache.key_cache.get() + kv_offset, kQKVDim, pos); + ProjKV(k_offset, v_offset, kv_offset); } }); 
@@ -355,13 +359,7 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer, constexpr const size_t v_offset = q_offset + 1 * kQKVDim * kModelDim; const size_t kv_offset = pos * kCachePosSize + layer * kCacheLayerSize; - TwoOfsMatVecLoop( - c_layer->c_qkv_einsum_w, k_offset, v_offset, - activations.pre_att_rms_out.data() + batch_offset, - kv_cache.key_cache.get() + kv_offset, - kv_cache.value_cache.get() + kv_offset); - - Rope(kv_cache.key_cache.get() + kv_offset, kQKVDim, pos); + ProjKV(k_offset, v_offset, kv_offset); } pool.Run(0, kHeads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR { @@ -376,9 +374,10 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer, MulByConst(kQueryScale, q, kQKVDim); // Compute Q dot K scores for (size_t pos2 = 0; pos2 <= pos; ++pos2) { - const size_t cache_offset = kHeads == kKVHeads - ? pos2 * kCachePosSize + layer * kCacheLayerSize + head * kQKVDim - : pos2 * kCachePosSize + layer * kCacheLayerSize; + const size_t cache_offset = + kHeads == kKVHeads + ? pos2 * kCachePosSize + layer * kCacheLayerSize + head * kQKVDim + : pos2 * kCachePosSize + layer * kCacheLayerSize; const float* HWY_RESTRICT k2 = kv_cache.key_cache.get() + cache_offset; const float score = Dot(q, k2, kQKVDim); head_att[pos2] = score; @@ -390,9 +389,10 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer, batch_idx * kHeads * kQKVDim; hwy::ZeroBytes(att_out, kQKVDim * sizeof(*att_out)); for (size_t pos2 = 0; pos2 <= pos; ++pos2) { - const size_t cache_offset = kHeads == kKVHeads - ? pos2 * kCachePosSize + layer * kCacheLayerSize + head * kQKVDim - : pos2 * kCachePosSize + layer * kCacheLayerSize; + const size_t cache_offset = + kHeads == kKVHeads + ? 
pos2 * kCachePosSize + layer * kCacheLayerSize + head * kQKVDim + : pos2 * kCachePosSize + layer * kCacheLayerSize; float* HWY_RESTRICT v2 = kv_cache.value_cache.get() + cache_offset; MulByConstAndAdd(head_att[pos2], v2, att_out, kQKVDim); } From c75d2eb63549fe844c61ee6a80f968f3af34f995 Mon Sep 17 00:00:00 2001 From: RangerUFO Date: Wed, 20 Mar 2024 23:21:43 +0800 Subject: [PATCH 10/25] Add the missing `HWY_ATTR` of `ProjKV` --- gemma.cc | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/gemma.cc b/gemma.cc index 76086d9..877a3dc 100644 --- a/gemma.cc +++ b/gemma.cc @@ -320,15 +320,16 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer, const size_t batch_offset = batch_idx * kModelDim; - auto ProjKV = [&](size_t k_offset, size_t v_offset, size_t kv_offset) { - TwoOfsMatVecLoop( - c_layer->c_qkv_einsum_w, k_offset, v_offset, - activations.pre_att_rms_out.data() + batch_offset, - kv_cache.key_cache.get() + kv_offset, - kv_cache.value_cache.get() + kv_offset); + auto ProjKV = + [&](size_t k_offset, size_t v_offset, size_t kv_offset) HWY_ATTR { + TwoOfsMatVecLoop( + c_layer->c_qkv_einsum_w, k_offset, v_offset, + activations.pre_att_rms_out.data() + batch_offset, + kv_cache.key_cache.get() + kv_offset, + kv_cache.value_cache.get() + kv_offset); - Rope(kv_cache.key_cache.get() + kv_offset, kQKVDim, pos); - }; + Rope(kv_cache.key_cache.get() + kv_offset, kQKVDim, pos); + }; pool.Run(0, kHeads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR { // linear projections to QKV From 8fc6959950df9a2e9b5fa0d95096d1b6d511e7b5 Mon Sep 17 00:00:00 2001 From: RangerUFO Date: Wed, 20 Mar 2024 23:50:14 +0800 Subject: [PATCH 11/25] Move conditional branch out of `pos2` loop --- gemma.cc | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/gemma.cc b/gemma.cc index 877a3dc..9baccd7 100644 --- a/gemma.cc +++ b/gemma.cc @@ -373,12 +373,13 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t 
batch_idx, size_t layer, Rope(q, kQKVDim, pos); MulByConst(kQueryScale, q, kQKVDim); + + const size_t head_offset = kHeads == kKVHeads ? head * kQKVDim : 0; + // Compute Q dot K scores for (size_t pos2 = 0; pos2 <= pos; ++pos2) { const size_t cache_offset = - kHeads == kKVHeads - ? pos2 * kCachePosSize + layer * kCacheLayerSize + head * kQKVDim - : pos2 * kCachePosSize + layer * kCacheLayerSize; + pos2 * kCachePosSize + layer * kCacheLayerSize + head_offset; const float* HWY_RESTRICT k2 = kv_cache.key_cache.get() + cache_offset; const float score = Dot(q, k2, kQKVDim); head_att[pos2] = score; @@ -391,9 +392,7 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer, hwy::ZeroBytes(att_out, kQKVDim * sizeof(*att_out)); for (size_t pos2 = 0; pos2 <= pos; ++pos2) { const size_t cache_offset = - kHeads == kKVHeads - ? pos2 * kCachePosSize + layer * kCacheLayerSize + head * kQKVDim - : pos2 * kCachePosSize + layer * kCacheLayerSize; + pos2 * kCachePosSize + layer * kCacheLayerSize + head_offset; float* HWY_RESTRICT v2 = kv_cache.value_cache.get() + cache_offset; MulByConstAndAdd(head_att[pos2], v2, att_out, kQKVDim); } From 7d5364bb802e91c9e7c770266592f79c53d1267a Mon Sep 17 00:00:00 2001 From: Jan Wassenberg Date: Wed, 20 Mar 2024 11:31:23 -0700 Subject: [PATCH 12/25] Remove obsolete copybara tags, faster bazel builds (debug) PiperOrigin-RevId: 617576799 --- .github/workflows/build.yml | 2 +- BUILD.bazel | 4 +++- bazel/BUILD | 1 + compression/BUILD | 8 +++++--- examples/hello_world/run.cc | 7 ++----- gemma.cc | 2 ++ gemma.h | 8 +------- run.cc | 7 +++---- util/app.h | 5 ----- 9 files changed, 18 insertions(+), 26 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index a0e9dc2..6fa432c 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -72,4 +72,4 @@ jobs: with: path: ~/.cache/bazel key: bazel-${{ runner.os }} - - run: bazel build -c opt --cxxopt=-std=c++20 //... 
\ No newline at end of file + - run: bazel build --cxxopt=-std=c++20 //... diff --git a/BUILD.bazel b/BUILD.bazel index 7f9dfce..319421f 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -4,7 +4,9 @@ load("@rules_license//rules:license.bzl", "license") package( - default_applicable_licenses = ["//:license"], + default_applicable_licenses = [ + "//:license", # Placeholder comment, do not modify + ], default_visibility = ["//visibility:public"], ) diff --git a/bazel/BUILD b/bazel/BUILD index 952624f..194a082 100644 --- a/bazel/BUILD +++ b/bazel/BUILD @@ -1,3 +1,4 @@ +# Required for referencing bazel:com_google_sentencepiece.patch package( default_applicable_licenses = ["//:license"], default_visibility = ["//:__subpackages__"], diff --git a/compression/BUILD b/compression/BUILD index 6c7e9a0..cfbeb99 100644 --- a/compression/BUILD +++ b/compression/BUILD @@ -1,10 +1,12 @@ # Weight compression, I/O and analysis package( - default_applicable_licenses = ["//:license"], + default_applicable_licenses = [ + "//:license", # Placeholder comment, do not modify + ], default_visibility = [ - "//learning/gemini/prod/contrib/gemini_cpp:__subpackages__", - "//:__subpackages__", + # Placeholder for internal visibility, + "//:__subpackages__", # Placeholder, do not modify ], ) diff --git a/examples/hello_world/run.cc b/examples/hello_world/run.cc index a994f31..a352250 100644 --- a/examples/hello_world/run.cc +++ b/examples/hello_world/run.cc @@ -17,13 +17,10 @@ // copybara:import_next_line:gemma_cpp #include "gemma.h" -// copybara:end +// copybara:import_next_line:gemma_cpp +#include "util/app.h" // LoaderArgs // copybara:import_next_line:gemma_cpp #include "util/args.h" -// copybara:end -// copybara:import_next_line:gemma_cpp -#include "util/app.h" // LoaderArgs -// copybara:end #include "hwy/contrib/thread_pool/thread_pool.h" std::vector tokenize( diff --git a/gemma.cc b/gemma.cc index 7c9d187..a230bea 100644 --- a/gemma.cc +++ b/gemma.cc @@ -52,6 +52,8 @@ #include #include +// 
Placeholder for internal header, do not modify. + // copybara:import_next_line:gemma_cpp #include "compression/compress.h" // copybara:import_next_line:gemma_cpp diff --git a/gemma.h b/gemma.h index d52356e..8c4cab8 100644 --- a/gemma.h +++ b/gemma.h @@ -23,19 +23,13 @@ // copybara:import_next_line:gemma_cpp #include "compression/compress.h" // SfpStream/NuqStream -// copybara:end // copybara:import_next_line:gemma_cpp -#include "configs.h" // kSeqLen -// copybara:end -// copybara:import_next_line:gemma_cpp -#include "util/args.h" // Path -// copybara:end +#include "util/args.h" // Path #include "hwy/aligned_allocator.h" #include "hwy/base.h" // hwy::bfloat16_t #include "hwy/contrib/thread_pool/thread_pool.h" // copybara:import_next_line:sentencepiece #include "src/sentencepiece_processor.h" -// copybara:end namespace gcpp { diff --git a/run.cc b/run.cc index ce9de93..3f38031 100644 --- a/run.cc +++ b/run.cc @@ -22,18 +22,15 @@ #include // NOLINT #include +// Placeholder for internal header, do not modify. // copybara:import_next_line:gemma_cpp #include "compression/compress.h" -// copybara:end // copybara:import_next_line:gemma_cpp #include "gemma.h" // Gemma -// copybara:end // copybara:import_next_line:gemma_cpp #include "util/app.h" -// copybara:end // copybara:import_next_line:gemma_cpp #include "util/args.h" // HasHelp -// copybara:end #include "hwy/base.h" #include "hwy/contrib/thread_pool/thread_pool.h" #include "hwy/highway.h" @@ -277,6 +274,8 @@ int main(int argc, char** argv) { { PROFILER_ZONE("Startup.misc"); + // Placeholder for internal init, do not modify. 
+ gcpp::LoaderArgs loader(argc, argv); gcpp::InferenceArgs inference(argc, argv); gcpp::AppArgs app(argc, argv); diff --git a/util/app.h b/util/app.h index cd6cb6c..ac37971 100644 --- a/util/app.h +++ b/util/app.h @@ -34,15 +34,10 @@ // copybara:import_next_line:gemma_cpp #include "configs.h" -// copybara:end - // copybara:import_next_line:gemma_cpp #include "gemma.h" -// copybara:end - // copybara:import_next_line:gemma_cpp #include "util/args.h" -// copybara:end #include "hwy/base.h" // HWY_ASSERT namespace gcpp { From ffd02c59adaacd92d5607a467825c18723a19ad1 Mon Sep 17 00:00:00 2001 From: Eric Ye Date: Tue, 19 Mar 2024 23:35:58 +0100 Subject: [PATCH 13/25] No public description PiperOrigin-RevId: 617315030 --- .github/workflows/build.yml | 2 +- BUILD.bazel | 4 +--- bazel/BUILD | 1 - compression/BUILD | 8 +++----- examples/hello_world/run.cc | 7 +++++-- gemma.cc | 14 +++++++------- gemma.h | 16 +++++++++++++--- run.cc | 11 ++++++----- util/app.h | 5 +++++ 9 files changed, 41 insertions(+), 27 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 6fa432c..a0e9dc2 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -72,4 +72,4 @@ jobs: with: path: ~/.cache/bazel key: bazel-${{ runner.os }} - - run: bazel build --cxxopt=-std=c++20 //... + - run: bazel build -c opt --cxxopt=-std=c++20 //... 
\ No newline at end of file diff --git a/BUILD.bazel b/BUILD.bazel index 319421f..7f9dfce 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -4,9 +4,7 @@ load("@rules_license//rules:license.bzl", "license") package( - default_applicable_licenses = [ - "//:license", # Placeholder comment, do not modify - ], + default_applicable_licenses = ["//:license"], default_visibility = ["//visibility:public"], ) diff --git a/bazel/BUILD b/bazel/BUILD index 194a082..952624f 100644 --- a/bazel/BUILD +++ b/bazel/BUILD @@ -1,4 +1,3 @@ -# Required for referencing bazel:com_google_sentencepiece.patch package( default_applicable_licenses = ["//:license"], default_visibility = ["//:__subpackages__"], diff --git a/compression/BUILD b/compression/BUILD index cfbeb99..6c7e9a0 100644 --- a/compression/BUILD +++ b/compression/BUILD @@ -1,12 +1,10 @@ # Weight compression, I/O and analysis package( - default_applicable_licenses = [ - "//:license", # Placeholder comment, do not modify - ], + default_applicable_licenses = ["//:license"], default_visibility = [ - # Placeholder for internal visibility, - "//:__subpackages__", # Placeholder, do not modify + "//learning/gemini/prod/contrib/gemini_cpp:__subpackages__", + "//:__subpackages__", ], ) diff --git a/examples/hello_world/run.cc b/examples/hello_world/run.cc index a352250..a994f31 100644 --- a/examples/hello_world/run.cc +++ b/examples/hello_world/run.cc @@ -17,10 +17,13 @@ // copybara:import_next_line:gemma_cpp #include "gemma.h" -// copybara:import_next_line:gemma_cpp -#include "util/app.h" // LoaderArgs +// copybara:end // copybara:import_next_line:gemma_cpp #include "util/args.h" +// copybara:end +// copybara:import_next_line:gemma_cpp +#include "util/app.h" // LoaderArgs +// copybara:end #include "hwy/contrib/thread_pool/thread_pool.h" std::vector tokenize( diff --git a/gemma.cc b/gemma.cc index a230bea..f2a2275 100644 --- a/gemma.cc +++ b/gemma.cc @@ -25,8 +25,6 @@ #include "compression/compress-inl.h" // 
copybara:import_next_line:gemma_cpp #include "ops.h" -// copybara:import_next_line:gemma_cpp -#include "util/args.h" // Path #include "hwy/contrib/matvec/matvec-inl.h" #include "hwy/highway.h" #include "hwy/profiler.h" @@ -52,8 +50,6 @@ #include #include -// Placeholder for internal header, do not modify. - // copybara:import_next_line:gemma_cpp #include "compression/compress.h" // copybara:import_next_line:gemma_cpp @@ -817,9 +813,8 @@ void GemmaImpl::Generate( } Gemma::Gemma(const Path& tokenizer_path, const Path& compressed_weights_path, - const Path& weights_path, Model model_type, ModelTraining training, - hwy::ThreadPool& pool) - : model_training(training) { + const Path& weights_path, Model model_type, + hwy::ThreadPool& pool) { std::unique_ptr tokenizer; { PROFILER_ZONE("Startup.tokenizer"); @@ -844,6 +839,11 @@ Gemma::Gemma(const Path& tokenizer_path, const Path& compressed_weights_path, } } +Gemma::Gemma(const Path& tokenizer_path, const Path& compressed_weights_path, + Model model_type, hwy::ThreadPool& pool) + : Gemma(tokenizer_path, compressed_weights_path, Path{""}, model_type, + pool) {} + Gemma::~Gemma() = default; // after GemmaInterface is defined const sentencepiece::SentencePieceProcessor* Gemma::Tokenizer() const { diff --git a/gemma.h b/gemma.h index 8c4cab8..cdd4873 100644 --- a/gemma.h +++ b/gemma.h @@ -16,20 +16,29 @@ #ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_H_ #define THIRD_PARTY_GEMMA_CPP_GEMMA_H_ +#include +#include #include #include #include +#include #include // copybara:import_next_line:gemma_cpp #include "compression/compress.h" // SfpStream/NuqStream +// copybara:end // copybara:import_next_line:gemma_cpp -#include "util/args.h" // Path +#include "configs.h" // kSeqLen +// copybara:end +// copybara:import_next_line:gemma_cpp +#include "util/args.h" // ArgsBase +// copybara:end #include "hwy/aligned_allocator.h" #include "hwy/base.h" // hwy::bfloat16_t #include "hwy/contrib/thread_pool/thread_pool.h" // 
copybara:import_next_line:sentencepiece #include "src/sentencepiece_processor.h" +// copybara:end namespace gcpp { @@ -66,8 +75,9 @@ struct GemmaInterface; struct Gemma { Gemma(const Path& tokenizer_path, const Path& compressed_weights_path, - const Path& weights_path, Model model_type, ModelTraining training, - hwy::ThreadPool& pool); + const Path& weights_path, Model model_type, hwy::ThreadPool& pool); + Gemma(const Path& tokenizer_path, const Path& compressed_weights_path, + Model model_type, hwy::ThreadPool& pool); ~Gemma(); // must be defined after GemmaInterface's dtor is defined. const sentencepiece::SentencePieceProcessor* Tokenizer() const; std::unique_ptr impl_; diff --git a/run.cc b/run.cc index 3f38031..b08e4ca 100644 --- a/run.cc +++ b/run.cc @@ -22,15 +22,18 @@ #include // NOLINT #include -// Placeholder for internal header, do not modify. // copybara:import_next_line:gemma_cpp #include "compression/compress.h" +// copybara:end // copybara:import_next_line:gemma_cpp #include "gemma.h" // Gemma +// copybara:end // copybara:import_next_line:gemma_cpp #include "util/app.h" +// copybara:end // copybara:import_next_line:gemma_cpp #include "util/args.h" // HasHelp +// copybara:end #include "hwy/base.h" #include "hwy/contrib/thread_pool/thread_pool.h" #include "hwy/highway.h" @@ -231,8 +234,8 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) { [](uint64_t /*task*/, size_t thread) { PinThreadToCore(thread); }); } - gcpp::Gemma model(loader.tokenizer, loader.compressed_weights, loader.weights, - loader.ModelType(), loader.ModelTraining(), pool); + gcpp::Gemma model(loader.tokenizer, loader.compressed_weights, + loader.ModelType(), pool); auto kv_cache = CreateKVCache(loader.ModelType()); @@ -274,8 +277,6 @@ int main(int argc, char** argv) { { PROFILER_ZONE("Startup.misc"); - // Placeholder for internal init, do not modify. 
- gcpp::LoaderArgs loader(argc, argv); gcpp::InferenceArgs inference(argc, argv); gcpp::AppArgs app(argc, argv); diff --git a/util/app.h b/util/app.h index ac37971..cd6cb6c 100644 --- a/util/app.h +++ b/util/app.h @@ -34,10 +34,15 @@ // copybara:import_next_line:gemma_cpp #include "configs.h" +// copybara:end + // copybara:import_next_line:gemma_cpp #include "gemma.h" +// copybara:end + // copybara:import_next_line:gemma_cpp #include "util/args.h" +// copybara:end #include "hwy/base.h" // HWY_ASSERT namespace gcpp { From e2a04b79ed7090b26a1cfa0f3e9a6893f74849db Mon Sep 17 00:00:00 2001 From: Eric Ye Date: Wed, 20 Mar 2024 00:07:47 +0100 Subject: [PATCH 14/25] Connect "--weights" parameter to Gemma PiperOrigin-RevId: 617323257 --- gemma.cc | 5 ----- gemma.h | 2 -- run.cc | 2 +- 3 files changed, 1 insertion(+), 8 deletions(-) diff --git a/gemma.cc b/gemma.cc index f2a2275..35a4a47 100644 --- a/gemma.cc +++ b/gemma.cc @@ -839,11 +839,6 @@ Gemma::Gemma(const Path& tokenizer_path, const Path& compressed_weights_path, } } -Gemma::Gemma(const Path& tokenizer_path, const Path& compressed_weights_path, - Model model_type, hwy::ThreadPool& pool) - : Gemma(tokenizer_path, compressed_weights_path, Path{""}, model_type, - pool) {} - Gemma::~Gemma() = default; // after GemmaInterface is defined const sentencepiece::SentencePieceProcessor* Gemma::Tokenizer() const { diff --git a/gemma.h b/gemma.h index cdd4873..f5e88fa 100644 --- a/gemma.h +++ b/gemma.h @@ -76,8 +76,6 @@ struct GemmaInterface; struct Gemma { Gemma(const Path& tokenizer_path, const Path& compressed_weights_path, const Path& weights_path, Model model_type, hwy::ThreadPool& pool); - Gemma(const Path& tokenizer_path, const Path& compressed_weights_path, - Model model_type, hwy::ThreadPool& pool); ~Gemma(); // must be defined after GemmaInterface's dtor is defined. 
const sentencepiece::SentencePieceProcessor* Tokenizer() const; std::unique_ptr impl_; diff --git a/run.cc b/run.cc index b08e4ca..fcf974b 100644 --- a/run.cc +++ b/run.cc @@ -234,7 +234,7 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) { [](uint64_t /*task*/, size_t thread) { PinThreadToCore(thread); }); } - gcpp::Gemma model(loader.tokenizer, loader.compressed_weights, + gcpp::Gemma model(loader.tokenizer, loader.compressed_weights, loader.weights, loader.ModelType(), pool); auto kv_cache = CreateKVCache(loader.ModelType()); From edaafe335f4c151778061c8ed6964c77caaaf637 Mon Sep 17 00:00:00 2001 From: Jan Wassenberg Date: Wed, 20 Mar 2024 05:12:06 +0100 Subject: [PATCH 15/25] Fix msan error, uninitialized model_training This arose during the unpacking of LoaderArgs into individual ctor args. Probably better to pass LoaderArgs in, and have only a single ctor to reduce confusion. Also fix includes. PiperOrigin-RevId: 617386447 --- gemma.cc | 7 +++++-- gemma.h | 8 +++----- run.cc | 2 +- 3 files changed, 9 insertions(+), 8 deletions(-) diff --git a/gemma.cc b/gemma.cc index 35a4a47..7c9d187 100644 --- a/gemma.cc +++ b/gemma.cc @@ -25,6 +25,8 @@ #include "compression/compress-inl.h" // copybara:import_next_line:gemma_cpp #include "ops.h" +// copybara:import_next_line:gemma_cpp +#include "util/args.h" // Path #include "hwy/contrib/matvec/matvec-inl.h" #include "hwy/highway.h" #include "hwy/profiler.h" @@ -813,8 +815,9 @@ void GemmaImpl::Generate( } Gemma::Gemma(const Path& tokenizer_path, const Path& compressed_weights_path, - const Path& weights_path, Model model_type, - hwy::ThreadPool& pool) { + const Path& weights_path, Model model_type, ModelTraining training, + hwy::ThreadPool& pool) + : model_training(training) { std::unique_ptr tokenizer; { PROFILER_ZONE("Startup.tokenizer"); diff --git a/gemma.h b/gemma.h index f5e88fa..d52356e 100644 --- a/gemma.h +++ b/gemma.h @@ -16,12 +16,9 @@ #ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_H_ #define 
THIRD_PARTY_GEMMA_CPP_GEMMA_H_ -#include -#include #include #include #include -#include #include // copybara:import_next_line:gemma_cpp @@ -31,7 +28,7 @@ #include "configs.h" // kSeqLen // copybara:end // copybara:import_next_line:gemma_cpp -#include "util/args.h" // ArgsBase +#include "util/args.h" // Path // copybara:end #include "hwy/aligned_allocator.h" #include "hwy/base.h" // hwy::bfloat16_t @@ -75,7 +72,8 @@ struct GemmaInterface; struct Gemma { Gemma(const Path& tokenizer_path, const Path& compressed_weights_path, - const Path& weights_path, Model model_type, hwy::ThreadPool& pool); + const Path& weights_path, Model model_type, ModelTraining training, + hwy::ThreadPool& pool); ~Gemma(); // must be defined after GemmaInterface's dtor is defined. const sentencepiece::SentencePieceProcessor* Tokenizer() const; std::unique_ptr impl_; diff --git a/run.cc b/run.cc index fcf974b..ce9de93 100644 --- a/run.cc +++ b/run.cc @@ -235,7 +235,7 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) { } gcpp::Gemma model(loader.tokenizer, loader.compressed_weights, loader.weights, - loader.ModelType(), pool); + loader.ModelType(), loader.ModelTraining(), pool); auto kv_cache = CreateKVCache(loader.ModelType()); From 06cea2bcdb8bcf22cdd8a19d3c7f4d618f193de6 Mon Sep 17 00:00:00 2001 From: Jan Wassenberg Date: Wed, 20 Mar 2024 19:31:23 +0100 Subject: [PATCH 16/25] Remove obsolete copybara tags, faster bazel builds (debug) PiperOrigin-RevId: 617576799 --- .github/workflows/build.yml | 2 +- BUILD.bazel | 4 +++- bazel/BUILD | 1 + compression/BUILD | 6 ++++-- examples/hello_world/run.cc | 7 ++----- gemma.cc | 2 ++ gemma.h | 8 +------- run.cc | 11 +++++++---- util/app.h | 5 ----- 9 files changed, 21 insertions(+), 25 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index a0e9dc2..6fa432c 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -72,4 +72,4 @@ jobs: with: path: ~/.cache/bazel key: bazel-${{ 
runner.os }} - - run: bazel build -c opt --cxxopt=-std=c++20 //... \ No newline at end of file + - run: bazel build --cxxopt=-std=c++20 //... diff --git a/BUILD.bazel b/BUILD.bazel index 7f9dfce..319421f 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -4,7 +4,9 @@ load("@rules_license//rules:license.bzl", "license") package( - default_applicable_licenses = ["//:license"], + default_applicable_licenses = [ + "//:license", # Placeholder comment, do not modify + ], default_visibility = ["//visibility:public"], ) diff --git a/bazel/BUILD b/bazel/BUILD index 952624f..194a082 100644 --- a/bazel/BUILD +++ b/bazel/BUILD @@ -1,3 +1,4 @@ +# Required for referencing bazel:com_google_sentencepiece.patch package( default_applicable_licenses = ["//:license"], default_visibility = ["//:__subpackages__"], diff --git a/compression/BUILD b/compression/BUILD index 6c7e9a0..7b7040b 100644 --- a/compression/BUILD +++ b/compression/BUILD @@ -1,10 +1,12 @@ # Weight compression, I/O and analysis package( - default_applicable_licenses = ["//:license"], + default_applicable_licenses = [ + "//:license", # Placeholder comment, do not modify + ], default_visibility = [ "//learning/gemini/prod/contrib/gemini_cpp:__subpackages__", - "//:__subpackages__", + "//:__subpackages__", # Placeholder, do not modify ], ) diff --git a/examples/hello_world/run.cc b/examples/hello_world/run.cc index a994f31..a352250 100644 --- a/examples/hello_world/run.cc +++ b/examples/hello_world/run.cc @@ -17,13 +17,10 @@ // copybara:import_next_line:gemma_cpp #include "gemma.h" -// copybara:end +// copybara:import_next_line:gemma_cpp +#include "util/app.h" // LoaderArgs // copybara:import_next_line:gemma_cpp #include "util/args.h" -// copybara:end -// copybara:import_next_line:gemma_cpp -#include "util/app.h" // LoaderArgs -// copybara:end #include "hwy/contrib/thread_pool/thread_pool.h" std::vector tokenize( diff --git a/gemma.cc b/gemma.cc index 7c9d187..743e13b 100644 --- a/gemma.cc +++ b/gemma.cc @@ -52,6 +52,8 @@ 
#include #include +#include "base/init_google.h" + // copybara:import_next_line:gemma_cpp #include "compression/compress.h" // copybara:import_next_line:gemma_cpp diff --git a/gemma.h b/gemma.h index d52356e..8c4cab8 100644 --- a/gemma.h +++ b/gemma.h @@ -23,19 +23,13 @@ // copybara:import_next_line:gemma_cpp #include "compression/compress.h" // SfpStream/NuqStream -// copybara:end // copybara:import_next_line:gemma_cpp -#include "configs.h" // kSeqLen -// copybara:end -// copybara:import_next_line:gemma_cpp -#include "util/args.h" // Path -// copybara:end +#include "util/args.h" // Path #include "hwy/aligned_allocator.h" #include "hwy/base.h" // hwy::bfloat16_t #include "hwy/contrib/thread_pool/thread_pool.h" // copybara:import_next_line:sentencepiece #include "src/sentencepiece_processor.h" -// copybara:end namespace gcpp { diff --git a/run.cc b/run.cc index ce9de93..04767d7 100644 --- a/run.cc +++ b/run.cc @@ -22,18 +22,15 @@ #include // NOLINT #include +#include "base/init_google.h" // copybara:import_next_line:gemma_cpp #include "compression/compress.h" -// copybara:end // copybara:import_next_line:gemma_cpp #include "gemma.h" // Gemma -// copybara:end // copybara:import_next_line:gemma_cpp #include "util/app.h" -// copybara:end // copybara:import_next_line:gemma_cpp #include "util/args.h" // HasHelp -// copybara:end #include "hwy/base.h" #include "hwy/contrib/thread_pool/thread_pool.h" #include "hwy/highway.h" @@ -277,6 +274,12 @@ int main(int argc, char** argv) { { PROFILER_ZONE("Startup.misc"); + int argc_dummy = 1; + // Required because sentencepiece uses Google I/O which requires InitGoogle. 
+ // argc_dummy = 1 avoids sentencepiece absl flags attempting to parse + // arguments + InitGoogle("usage", &argc_dummy, &argv, false); + gcpp::LoaderArgs loader(argc, argv); gcpp::InferenceArgs inference(argc, argv); gcpp::AppArgs app(argc, argv); diff --git a/util/app.h b/util/app.h index cd6cb6c..ac37971 100644 --- a/util/app.h +++ b/util/app.h @@ -34,15 +34,10 @@ // copybara:import_next_line:gemma_cpp #include "configs.h" -// copybara:end - // copybara:import_next_line:gemma_cpp #include "gemma.h" -// copybara:end - // copybara:import_next_line:gemma_cpp #include "util/args.h" -// copybara:end #include "hwy/base.h" // HWY_ASSERT namespace gcpp { From 30b8a3c1acd8c1461dd4725f483b906b323c1d0d Mon Sep 17 00:00:00 2001 From: Jan Wassenberg Date: Wed, 20 Mar 2024 20:07:12 -0700 Subject: [PATCH 17/25] Fix build for RPi, missing hn::. Refs #112, thanks long568 PiperOrigin-RevId: 617704418 --- compression/BUILD | 2 +- gemma.cc | 2 +- ops.h | 15 ++++++++------- run.cc | 8 ++------ 4 files changed, 12 insertions(+), 15 deletions(-) diff --git a/compression/BUILD b/compression/BUILD index 7b7040b..cfbeb99 100644 --- a/compression/BUILD +++ b/compression/BUILD @@ -5,7 +5,7 @@ package( "//:license", # Placeholder comment, do not modify ], default_visibility = [ - "//learning/gemini/prod/contrib/gemini_cpp:__subpackages__", + # Placeholder for internal visibility, "//:__subpackages__", # Placeholder, do not modify ], ) diff --git a/gemma.cc b/gemma.cc index 743e13b..a230bea 100644 --- a/gemma.cc +++ b/gemma.cc @@ -52,7 +52,7 @@ #include #include -#include "base/init_google.h" +// Placeholder for internal header, do not modify. 
// copybara:import_next_line:gemma_cpp #include "compression/compress.h" diff --git a/ops.h b/ops.h index 481e1d7..7aa7b62 100644 --- a/ops.h +++ b/ops.h @@ -341,20 +341,21 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED float Dot(const float* HWY_RESTRICT a, static HWY_NOINLINE HWY_MAYBE_UNUSED float SquaredL2( const float* HWY_RESTRICT a, size_t size) { const hn::ScalableTag d; + using V = hn::Vec; const size_t N = hn::Lanes(d); HWY_DASSERT(size >= 2 * N); HWY_DASSERT(size % (2 * N) == 0); - auto sum0 = hn::Zero(d); - auto sum1 = hn::Zero(d); + V sum0 = hn::Zero(d); + V sum1 = hn::Zero(d); for (size_t i = 0; i <= size - 2 * N; i += 2 * N) { - const auto a0 = LoadU(d, a + i); - sum0 = MulAdd(a0, a0, sum0); - const auto a1 = LoadU(d, a + i + N); - sum1 = MulAdd(a1, a1, sum1); + const V a0 = hn::LoadU(d, a + i); + sum0 = hn::MulAdd(a0, a0, sum0); + const V a1 = hn::LoadU(d, a + i + N); + sum1 = hn::MulAdd(a1, a1, sum1); } - return ReduceSum(d, Add(sum0, sum1)); + return hn::ReduceSum(d, hn::Add(sum0, sum1)); } static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm( diff --git a/run.cc b/run.cc index 04767d7..3f38031 100644 --- a/run.cc +++ b/run.cc @@ -22,7 +22,7 @@ #include // NOLINT #include -#include "base/init_google.h" +// Placeholder for internal header, do not modify. // copybara:import_next_line:gemma_cpp #include "compression/compress.h" // copybara:import_next_line:gemma_cpp @@ -274,11 +274,7 @@ int main(int argc, char** argv) { { PROFILER_ZONE("Startup.misc"); - int argc_dummy = 1; - // Required because sentencepiece uses Google I/O which requires InitGoogle. - // argc_dummy = 1 avoids sentencepiece absl flags attempting to parse - // arguments - InitGoogle("usage", &argc_dummy, &argv, false); + // Placeholder for internal init, do not modify. 
gcpp::LoaderArgs loader(argc, argv); gcpp::InferenceArgs inference(argc, argv); From 89be4c3de8f911257c0d3c15831a3aaaf696da77 Mon Sep 17 00:00:00 2001 From: Eric Ye Date: Tue, 19 Mar 2024 23:35:58 +0100 Subject: [PATCH 18/25] No public description PiperOrigin-RevId: 617315030 --- .github/workflows/build.yml | 2 +- BUILD.bazel | 4 +--- bazel/BUILD | 1 - compression/BUILD | 6 ++---- examples/hello_world/run.cc | 7 +++++-- gemma.cc | 15 ++++++++++----- gemma.h | 16 +++++++++++++--- ops.h | 15 +++++++-------- run.cc | 12 +++++++++--- util/app.h | 5 +++++ 10 files changed, 53 insertions(+), 30 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 6fa432c..a0e9dc2 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -72,4 +72,4 @@ jobs: with: path: ~/.cache/bazel key: bazel-${{ runner.os }} - - run: bazel build --cxxopt=-std=c++20 //... + - run: bazel build -c opt --cxxopt=-std=c++20 //... \ No newline at end of file diff --git a/BUILD.bazel b/BUILD.bazel index 319421f..84b393c 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -4,9 +4,7 @@ load("@rules_license//rules:license.bzl", "license") package( - default_applicable_licenses = [ - "//:license", # Placeholder comment, do not modify - ], + default_applicable_licenses = ["//third_party/gemma_cpp:license"], default_visibility = ["//visibility:public"], ) diff --git a/bazel/BUILD b/bazel/BUILD index 194a082..952624f 100644 --- a/bazel/BUILD +++ b/bazel/BUILD @@ -1,4 +1,3 @@ -# Required for referencing bazel:com_google_sentencepiece.patch package( default_applicable_licenses = ["//:license"], default_visibility = ["//:__subpackages__"], diff --git a/compression/BUILD b/compression/BUILD index cfbeb99..ddf30c5 100644 --- a/compression/BUILD +++ b/compression/BUILD @@ -1,12 +1,10 @@ # Weight compression, I/O and analysis package( - default_applicable_licenses = [ - "//:license", # Placeholder comment, do not modify - ], + default_applicable_licenses = 
["//third_party/gemma_cpp:license"], default_visibility = [ # Placeholder for internal visibility, - "//:__subpackages__", # Placeholder, do not modify + "//third_party/gemma_cpp:__subpackages__", ], ) diff --git a/examples/hello_world/run.cc b/examples/hello_world/run.cc index a352250..a994f31 100644 --- a/examples/hello_world/run.cc +++ b/examples/hello_world/run.cc @@ -17,10 +17,13 @@ // copybara:import_next_line:gemma_cpp #include "gemma.h" -// copybara:import_next_line:gemma_cpp -#include "util/app.h" // LoaderArgs +// copybara:end // copybara:import_next_line:gemma_cpp #include "util/args.h" +// copybara:end +// copybara:import_next_line:gemma_cpp +#include "util/app.h" // LoaderArgs +// copybara:end #include "hwy/contrib/thread_pool/thread_pool.h" std::vector tokenize( diff --git a/gemma.cc b/gemma.cc index a230bea..b41bd9c 100644 --- a/gemma.cc +++ b/gemma.cc @@ -25,8 +25,6 @@ #include "compression/compress-inl.h" // copybara:import_next_line:gemma_cpp #include "ops.h" -// copybara:import_next_line:gemma_cpp -#include "util/args.h" // Path #include "hwy/contrib/matvec/matvec-inl.h" #include "hwy/highway.h" #include "hwy/profiler.h" @@ -52,7 +50,10 @@ #include #include +// copybara:strip_begin +// Required because sentencepiece uses Google I/O which requires InitGoogle. // Placeholder for internal header, do not modify. 
+// copybara:strip_end // copybara:import_next_line:gemma_cpp #include "compression/compress.h" @@ -817,9 +818,8 @@ void GemmaImpl::Generate( } Gemma::Gemma(const Path& tokenizer_path, const Path& compressed_weights_path, - const Path& weights_path, Model model_type, ModelTraining training, - hwy::ThreadPool& pool) - : model_training(training) { + const Path& weights_path, Model model_type, + hwy::ThreadPool& pool) { std::unique_ptr tokenizer; { PROFILER_ZONE("Startup.tokenizer"); @@ -844,6 +844,11 @@ Gemma::Gemma(const Path& tokenizer_path, const Path& compressed_weights_path, } } +Gemma::Gemma(const Path& tokenizer_path, const Path& compressed_weights_path, + Model model_type, hwy::ThreadPool& pool) + : Gemma(tokenizer_path, compressed_weights_path, Path{""}, model_type, + pool) {} + Gemma::~Gemma() = default; // after GemmaInterface is defined const sentencepiece::SentencePieceProcessor* Gemma::Tokenizer() const { diff --git a/gemma.h b/gemma.h index 8c4cab8..cdd4873 100644 --- a/gemma.h +++ b/gemma.h @@ -16,20 +16,29 @@ #ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_H_ #define THIRD_PARTY_GEMMA_CPP_GEMMA_H_ +#include +#include #include #include #include +#include #include // copybara:import_next_line:gemma_cpp #include "compression/compress.h" // SfpStream/NuqStream +// copybara:end // copybara:import_next_line:gemma_cpp -#include "util/args.h" // Path +#include "configs.h" // kSeqLen +// copybara:end +// copybara:import_next_line:gemma_cpp +#include "util/args.h" // ArgsBase +// copybara:end #include "hwy/aligned_allocator.h" #include "hwy/base.h" // hwy::bfloat16_t #include "hwy/contrib/thread_pool/thread_pool.h" // copybara:import_next_line:sentencepiece #include "src/sentencepiece_processor.h" +// copybara:end namespace gcpp { @@ -66,8 +75,9 @@ struct GemmaInterface; struct Gemma { Gemma(const Path& tokenizer_path, const Path& compressed_weights_path, - const Path& weights_path, Model model_type, ModelTraining training, - hwy::ThreadPool& pool); + const Path& 
weights_path, Model model_type, hwy::ThreadPool& pool); + Gemma(const Path& tokenizer_path, const Path& compressed_weights_path, + Model model_type, hwy::ThreadPool& pool); ~Gemma(); // must be defined after GemmaInterface's dtor is defined. const sentencepiece::SentencePieceProcessor* Tokenizer() const; std::unique_ptr impl_; diff --git a/ops.h b/ops.h index 7aa7b62..481e1d7 100644 --- a/ops.h +++ b/ops.h @@ -341,21 +341,20 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED float Dot(const float* HWY_RESTRICT a, static HWY_NOINLINE HWY_MAYBE_UNUSED float SquaredL2( const float* HWY_RESTRICT a, size_t size) { const hn::ScalableTag d; - using V = hn::Vec; const size_t N = hn::Lanes(d); HWY_DASSERT(size >= 2 * N); HWY_DASSERT(size % (2 * N) == 0); - V sum0 = hn::Zero(d); - V sum1 = hn::Zero(d); + auto sum0 = hn::Zero(d); + auto sum1 = hn::Zero(d); for (size_t i = 0; i <= size - 2 * N; i += 2 * N) { - const V a0 = hn::LoadU(d, a + i); - sum0 = hn::MulAdd(a0, a0, sum0); - const V a1 = hn::LoadU(d, a + i + N); - sum1 = hn::MulAdd(a1, a1, sum1); + const auto a0 = LoadU(d, a + i); + sum0 = MulAdd(a0, a0, sum0); + const auto a1 = LoadU(d, a + i + N); + sum1 = MulAdd(a1, a1, sum1); } - return hn::ReduceSum(d, hn::Add(sum0, sum1)); + return ReduceSum(d, Add(sum0, sum1)); } static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm( diff --git a/run.cc b/run.cc index 3f38031..45f6783 100644 --- a/run.cc +++ b/run.cc @@ -22,15 +22,19 @@ #include // NOLINT #include -// Placeholder for internal header, do not modify. +// Placeholder for internal header, do not modify. 
// copybara:strip // copybara:import_next_line:gemma_cpp #include "compression/compress.h" +// copybara:end // copybara:import_next_line:gemma_cpp #include "gemma.h" // Gemma +// copybara:end // copybara:import_next_line:gemma_cpp #include "util/app.h" +// copybara:end // copybara:import_next_line:gemma_cpp #include "util/args.h" // HasHelp +// copybara:end #include "hwy/base.h" #include "hwy/contrib/thread_pool/thread_pool.h" #include "hwy/highway.h" @@ -231,8 +235,8 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) { [](uint64_t /*task*/, size_t thread) { PinThreadToCore(thread); }); } - gcpp::Gemma model(loader.tokenizer, loader.compressed_weights, loader.weights, - loader.ModelType(), loader.ModelTraining(), pool); + gcpp::Gemma model(loader.tokenizer, loader.compressed_weights, + loader.ModelType(), pool); auto kv_cache = CreateKVCache(loader.ModelType()); @@ -274,7 +278,9 @@ int main(int argc, char** argv) { { PROFILER_ZONE("Startup.misc"); + // copybara:strip_begin // Placeholder for internal init, do not modify. 
+ // copybara:strip_end gcpp::LoaderArgs loader(argc, argv); gcpp::InferenceArgs inference(argc, argv); diff --git a/util/app.h b/util/app.h index ac37971..cd6cb6c 100644 --- a/util/app.h +++ b/util/app.h @@ -34,10 +34,15 @@ // copybara:import_next_line:gemma_cpp #include "configs.h" +// copybara:end + // copybara:import_next_line:gemma_cpp #include "gemma.h" +// copybara:end + // copybara:import_next_line:gemma_cpp #include "util/args.h" +// copybara:end #include "hwy/base.h" // HWY_ASSERT namespace gcpp { From 52940d435f83ba91dc6dd913288f80babe3448e0 Mon Sep 17 00:00:00 2001 From: Eric Ye Date: Wed, 20 Mar 2024 00:07:47 +0100 Subject: [PATCH 19/25] Connect "--weights" parameter to Gemma PiperOrigin-RevId: 617323257 --- gemma.cc | 5 ----- gemma.h | 2 -- run.cc | 2 +- 3 files changed, 1 insertion(+), 8 deletions(-) diff --git a/gemma.cc b/gemma.cc index b41bd9c..68f9645 100644 --- a/gemma.cc +++ b/gemma.cc @@ -844,11 +844,6 @@ Gemma::Gemma(const Path& tokenizer_path, const Path& compressed_weights_path, } } -Gemma::Gemma(const Path& tokenizer_path, const Path& compressed_weights_path, - Model model_type, hwy::ThreadPool& pool) - : Gemma(tokenizer_path, compressed_weights_path, Path{""}, model_type, - pool) {} - Gemma::~Gemma() = default; // after GemmaInterface is defined const sentencepiece::SentencePieceProcessor* Gemma::Tokenizer() const { diff --git a/gemma.h b/gemma.h index cdd4873..f5e88fa 100644 --- a/gemma.h +++ b/gemma.h @@ -76,8 +76,6 @@ struct GemmaInterface; struct Gemma { Gemma(const Path& tokenizer_path, const Path& compressed_weights_path, const Path& weights_path, Model model_type, hwy::ThreadPool& pool); - Gemma(const Path& tokenizer_path, const Path& compressed_weights_path, - Model model_type, hwy::ThreadPool& pool); ~Gemma(); // must be defined after GemmaInterface's dtor is defined. 
const sentencepiece::SentencePieceProcessor* Tokenizer() const; std::unique_ptr impl_; diff --git a/run.cc b/run.cc index 45f6783..c111d47 100644 --- a/run.cc +++ b/run.cc @@ -235,7 +235,7 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) { [](uint64_t /*task*/, size_t thread) { PinThreadToCore(thread); }); } - gcpp::Gemma model(loader.tokenizer, loader.compressed_weights, + gcpp::Gemma model(loader.tokenizer, loader.compressed_weights, loader.weights, loader.ModelType(), pool); auto kv_cache = CreateKVCache(loader.ModelType()); From f8baac80f968fff197e9f081988363220361f057 Mon Sep 17 00:00:00 2001 From: Jan Wassenberg Date: Wed, 20 Mar 2024 05:12:06 +0100 Subject: [PATCH 20/25] Fix msan error, uninitialized model_training This arose during the unpacking of LoaderArgs into individual ctor args. Probably better to pass LoaderArgs in, and have only a single ctor to reduce confusion. Also fix includes. PiperOrigin-RevId: 617386447 --- gemma.cc | 7 +++++-- gemma.h | 8 +++----- run.cc | 2 +- 3 files changed, 9 insertions(+), 8 deletions(-) diff --git a/gemma.cc b/gemma.cc index 68f9645..032177e 100644 --- a/gemma.cc +++ b/gemma.cc @@ -25,6 +25,8 @@ #include "compression/compress-inl.h" // copybara:import_next_line:gemma_cpp #include "ops.h" +// copybara:import_next_line:gemma_cpp +#include "util/args.h" // Path #include "hwy/contrib/matvec/matvec-inl.h" #include "hwy/highway.h" #include "hwy/profiler.h" @@ -818,8 +820,9 @@ void GemmaImpl::Generate( } Gemma::Gemma(const Path& tokenizer_path, const Path& compressed_weights_path, - const Path& weights_path, Model model_type, - hwy::ThreadPool& pool) { + const Path& weights_path, Model model_type, ModelTraining training, + hwy::ThreadPool& pool) + : model_training(training) { std::unique_ptr tokenizer; { PROFILER_ZONE("Startup.tokenizer"); diff --git a/gemma.h b/gemma.h index f5e88fa..d52356e 100644 --- a/gemma.h +++ b/gemma.h @@ -16,12 +16,9 @@ #ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_H_ #define 
THIRD_PARTY_GEMMA_CPP_GEMMA_H_ -#include -#include #include #include #include -#include #include // copybara:import_next_line:gemma_cpp @@ -31,7 +28,7 @@ #include "configs.h" // kSeqLen // copybara:end // copybara:import_next_line:gemma_cpp -#include "util/args.h" // ArgsBase +#include "util/args.h" // Path // copybara:end #include "hwy/aligned_allocator.h" #include "hwy/base.h" // hwy::bfloat16_t @@ -75,7 +72,8 @@ struct GemmaInterface; struct Gemma { Gemma(const Path& tokenizer_path, const Path& compressed_weights_path, - const Path& weights_path, Model model_type, hwy::ThreadPool& pool); + const Path& weights_path, Model model_type, ModelTraining training, + hwy::ThreadPool& pool); ~Gemma(); // must be defined after GemmaInterface's dtor is defined. const sentencepiece::SentencePieceProcessor* Tokenizer() const; std::unique_ptr impl_; diff --git a/run.cc b/run.cc index c111d47..dd7cd25 100644 --- a/run.cc +++ b/run.cc @@ -236,7 +236,7 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) { } gcpp::Gemma model(loader.tokenizer, loader.compressed_weights, loader.weights, - loader.ModelType(), pool); + loader.ModelType(), loader.ModelTraining(), pool); auto kv_cache = CreateKVCache(loader.ModelType()); From ba86c8d590728b36b3e020bfbf7840102ec53c77 Mon Sep 17 00:00:00 2001 From: Jan Wassenberg Date: Wed, 20 Mar 2024 19:31:23 +0100 Subject: [PATCH 21/25] Remove obsolete copybara tags, faster bazel builds (debug) PiperOrigin-RevId: 617576799 --- .github/workflows/build.yml | 2 +- BUILD.bazel | 4 +++- bazel/BUILD | 1 + compression/BUILD | 6 ++++-- examples/hello_world/run.cc | 7 ++----- gemma.cc | 3 --- gemma.h | 8 +------- run.cc | 8 +------- util/app.h | 5 ----- 9 files changed, 13 insertions(+), 31 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index a0e9dc2..6fa432c 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -72,4 +72,4 @@ jobs: with: path: ~/.cache/bazel key: bazel-${{ 
runner.os }} - - run: bazel build -c opt --cxxopt=-std=c++20 //... \ No newline at end of file + - run: bazel build --cxxopt=-std=c++20 //... diff --git a/BUILD.bazel b/BUILD.bazel index 84b393c..319421f 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -4,7 +4,9 @@ load("@rules_license//rules:license.bzl", "license") package( - default_applicable_licenses = ["//third_party/gemma_cpp:license"], + default_applicable_licenses = [ + "//:license", # Placeholder comment, do not modify + ], default_visibility = ["//visibility:public"], ) diff --git a/bazel/BUILD b/bazel/BUILD index 952624f..194a082 100644 --- a/bazel/BUILD +++ b/bazel/BUILD @@ -1,3 +1,4 @@ +# Required for referencing bazel:com_google_sentencepiece.patch package( default_applicable_licenses = ["//:license"], default_visibility = ["//:__subpackages__"], diff --git a/compression/BUILD b/compression/BUILD index ddf30c5..cfbeb99 100644 --- a/compression/BUILD +++ b/compression/BUILD @@ -1,10 +1,12 @@ # Weight compression, I/O and analysis package( - default_applicable_licenses = ["//third_party/gemma_cpp:license"], + default_applicable_licenses = [ + "//:license", # Placeholder comment, do not modify + ], default_visibility = [ # Placeholder for internal visibility, - "//third_party/gemma_cpp:__subpackages__", + "//:__subpackages__", # Placeholder, do not modify ], ) diff --git a/examples/hello_world/run.cc b/examples/hello_world/run.cc index a994f31..a352250 100644 --- a/examples/hello_world/run.cc +++ b/examples/hello_world/run.cc @@ -17,13 +17,10 @@ // copybara:import_next_line:gemma_cpp #include "gemma.h" -// copybara:end +// copybara:import_next_line:gemma_cpp +#include "util/app.h" // LoaderArgs // copybara:import_next_line:gemma_cpp #include "util/args.h" -// copybara:end -// copybara:import_next_line:gemma_cpp -#include "util/app.h" // LoaderArgs -// copybara:end #include "hwy/contrib/thread_pool/thread_pool.h" std::vector tokenize( diff --git a/gemma.cc b/gemma.cc index 032177e..a230bea 100644 --- 
a/gemma.cc +++ b/gemma.cc @@ -52,10 +52,7 @@ #include #include -// copybara:strip_begin -// Required because sentencepiece uses Google I/O which requires InitGoogle. // Placeholder for internal header, do not modify. -// copybara:strip_end // copybara:import_next_line:gemma_cpp #include "compression/compress.h" diff --git a/gemma.h b/gemma.h index d52356e..8c4cab8 100644 --- a/gemma.h +++ b/gemma.h @@ -23,19 +23,13 @@ // copybara:import_next_line:gemma_cpp #include "compression/compress.h" // SfpStream/NuqStream -// copybara:end // copybara:import_next_line:gemma_cpp -#include "configs.h" // kSeqLen -// copybara:end -// copybara:import_next_line:gemma_cpp -#include "util/args.h" // Path -// copybara:end +#include "util/args.h" // Path #include "hwy/aligned_allocator.h" #include "hwy/base.h" // hwy::bfloat16_t #include "hwy/contrib/thread_pool/thread_pool.h" // copybara:import_next_line:sentencepiece #include "src/sentencepiece_processor.h" -// copybara:end namespace gcpp { diff --git a/run.cc b/run.cc index dd7cd25..3f38031 100644 --- a/run.cc +++ b/run.cc @@ -22,19 +22,15 @@ #include // NOLINT #include -// Placeholder for internal header, do not modify. // copybara:strip +// Placeholder for internal header, do not modify. // copybara:import_next_line:gemma_cpp #include "compression/compress.h" -// copybara:end // copybara:import_next_line:gemma_cpp #include "gemma.h" // Gemma -// copybara:end // copybara:import_next_line:gemma_cpp #include "util/app.h" -// copybara:end // copybara:import_next_line:gemma_cpp #include "util/args.h" // HasHelp -// copybara:end #include "hwy/base.h" #include "hwy/contrib/thread_pool/thread_pool.h" #include "hwy/highway.h" @@ -278,9 +274,7 @@ int main(int argc, char** argv) { { PROFILER_ZONE("Startup.misc"); - // copybara:strip_begin // Placeholder for internal init, do not modify. 
- // copybara:strip_end gcpp::LoaderArgs loader(argc, argv); gcpp::InferenceArgs inference(argc, argv); diff --git a/util/app.h b/util/app.h index cd6cb6c..ac37971 100644 --- a/util/app.h +++ b/util/app.h @@ -34,15 +34,10 @@ // copybara:import_next_line:gemma_cpp #include "configs.h" -// copybara:end - // copybara:import_next_line:gemma_cpp #include "gemma.h" -// copybara:end - // copybara:import_next_line:gemma_cpp #include "util/args.h" -// copybara:end #include "hwy/base.h" // HWY_ASSERT namespace gcpp { From a135bc1e47aedf6505853fb66a1033f754c7a283 Mon Sep 17 00:00:00 2001 From: Jan Wassenberg Date: Thu, 21 Mar 2024 04:07:12 +0100 Subject: [PATCH 22/25] Fix build for RPi, missing hn::. Refs #112, thanks long568 PiperOrigin-RevId: 617704418 --- ops.h | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/ops.h b/ops.h index 481e1d7..7aa7b62 100644 --- a/ops.h +++ b/ops.h @@ -341,20 +341,21 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED float Dot(const float* HWY_RESTRICT a, static HWY_NOINLINE HWY_MAYBE_UNUSED float SquaredL2( const float* HWY_RESTRICT a, size_t size) { const hn::ScalableTag d; + using V = hn::Vec; const size_t N = hn::Lanes(d); HWY_DASSERT(size >= 2 * N); HWY_DASSERT(size % (2 * N) == 0); - auto sum0 = hn::Zero(d); - auto sum1 = hn::Zero(d); + V sum0 = hn::Zero(d); + V sum1 = hn::Zero(d); for (size_t i = 0; i <= size - 2 * N; i += 2 * N) { - const auto a0 = LoadU(d, a + i); - sum0 = MulAdd(a0, a0, sum0); - const auto a1 = LoadU(d, a + i + N); - sum1 = MulAdd(a1, a1, sum1); + const V a0 = hn::LoadU(d, a + i); + sum0 = hn::MulAdd(a0, a0, sum0); + const V a1 = hn::LoadU(d, a + i + N); + sum1 = hn::MulAdd(a1, a1, sum1); } - return ReduceSum(d, Add(sum0, sum1)); + return hn::ReduceSum(d, hn::Add(sum0, sum1)); } static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm( From 90b0e9fd7ac3dbe73d63abd3ed7eeb14ad6d013b Mon Sep 17 00:00:00 2001 From: RangerUFO Date: Thu, 21 Mar 2024 14:40:56 +0800 Subject: [PATCH 23/25] Refactor the 
implementation of `Attention` --- gemma.cc | 81 +++++++++++++++++++++++++++++++------------------------- 1 file changed, 45 insertions(+), 36 deletions(-) diff --git a/gemma.cc b/gemma.cc index 9baccd7..533854c 100644 --- a/gemma.cc +++ b/gemma.cc @@ -320,6 +320,15 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer, const size_t batch_offset = batch_idx * kModelDim; + auto ProjQ = [&](uint64_t head, size_t head_offset) HWY_ATTR { + float* HWY_RESTRICT q = + activations.q.data() + head * kQKVDim + batch_idx * kHeads * kQKVDim; + + MatVecLoop( + c_layer->c_qkv_einsum_w, head_offset + 0 * kQKVDim * kModelDim, + activations.pre_att_rms_out.data() + batch_offset, q); + }; + auto ProjKV = [&](size_t k_offset, size_t v_offset, size_t kv_offset) HWY_ATTR { TwoOfsMatVecLoop( @@ -331,39 +340,7 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer, Rope(kv_cache.key_cache.get() + kv_offset, kQKVDim, pos); }; - pool.Run(0, kHeads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR { - // linear projections to QKV - constexpr const size_t head_offset = - kHeads == kKVHeads ? 
3 * kQKVDim * kModelDim : kQKVDim * kModelDim; - const size_t q_offset = head * head_offset + 0 * kQKVDim * kModelDim; - - float* HWY_RESTRICT q = - activations.q.data() + head * kQKVDim + batch_idx * kHeads * kQKVDim; - - MatVecLoop( - c_layer->c_qkv_einsum_w, q_offset, - activations.pre_att_rms_out.data() + batch_offset, q); - - if constexpr (kHeads == kKVHeads) { - const size_t k_offset = head * head_offset + 1 * kQKVDim * kModelDim; - const size_t v_offset = head * head_offset + 2 * kQKVDim * kModelDim; - const size_t kv_offset = - pos * kCachePosSize + layer * kCacheLayerSize + head * kQKVDim; - - ProjKV(k_offset, v_offset, kv_offset); - } - }); - - if constexpr (kHeads != kKVHeads) { - constexpr const size_t q_offset = kHeads * kQKVDim * kModelDim; - constexpr const size_t k_offset = q_offset + 0 * kQKVDim * kModelDim; - constexpr const size_t v_offset = q_offset + 1 * kQKVDim * kModelDim; - const size_t kv_offset = pos * kCachePosSize + layer * kCacheLayerSize; - - ProjKV(k_offset, v_offset, kv_offset); - } - - pool.Run(0, kHeads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR { + auto Attn = [&](uint64_t head, size_t head_offset) HWY_ATTR { // Calculate scores float* HWY_RESTRICT q = activations.q.data() + head * kQKVDim + batch_idx * kHeads * kQKVDim; @@ -374,8 +351,6 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer, Rope(q, kQKVDim, pos); MulByConst(kQueryScale, q, kQKVDim); - const size_t head_offset = kHeads == kKVHeads ? 
head * kQKVDim : 0; - // Compute Q dot K scores for (size_t pos2 = 0; pos2 <= pos; ++pos2) { const size_t cache_offset = @@ -405,7 +380,41 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer, MatVecLoop(c_layer->c_attn_vec_einsum_w, head * kModelDim * kQKVDim, att_out, head_out); - }); + }; + + if constexpr (kHeads == kKVHeads) { + // Multi-Head Attention + pool.Run(0, kHeads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR { + const size_t head_offset = head * 3 * kQKVDim * kModelDim; + + ProjQ(head, head_offset); + + const size_t k_offset = head_offset + 1 * kQKVDim * kModelDim; + const size_t v_offset = head_offset + 2 * kQKVDim * kModelDim; + const size_t kv_offset = + pos * kCachePosSize + layer * kCacheLayerSize + head * kQKVDim; + + ProjKV(k_offset, v_offset, kv_offset); + + Attn(head, head * kQKVDim); + }); + } else { + // Multi-Query Attention + pool.Run(0, kHeads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR { + ProjQ(head, head * kQKVDim * kModelDim); + }); + + constexpr const size_t q_offset = kHeads * kQKVDim * kModelDim; + constexpr const size_t k_offset = q_offset + 0 * kQKVDim * kModelDim; + constexpr const size_t v_offset = q_offset + 1 * kQKVDim * kModelDim; + const size_t kv_offset = pos * kCachePosSize + layer * kCacheLayerSize; + + ProjKV(k_offset, v_offset, kv_offset); + + pool.Run(0, kHeads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR { + Attn(head, 0); + }); + } // accumulate output across all heads into att_post2. head 0 already wrote // directly to att_post2. From 24add61dd9c1c2a6a6cc8092124cd1b1d7cd909a Mon Sep 17 00:00:00 2001 From: Jan Wassenberg Date: Thu, 21 Mar 2024 19:05:44 -0700 Subject: [PATCH 24/25] Fix SFP/NUQ for bf16 rounding in Highway SFP: Avoid rounding twice, and more robust TestDot. NUQ: also more robust SNR, minor touchups to header. 
PiperOrigin-RevId: 618030096 --- CMakeLists.txt | 1 + compression/BUILD | 37 ++++++++++++----- compression/nuq-inl.h | 23 ++++++----- compression/nuq_test.cc | 89 ++++++++++++++++++++++++++--------------- compression/sfp-inl.h | 19 +++++++-- compression/sfp_test.cc | 82 +++++++++++++++++++++++-------------- compression/test_util.h | 64 +++++++++++++++++++++++++++++ 7 files changed, 230 insertions(+), 85 deletions(-) create mode 100644 compression/test_util.h diff --git a/CMakeLists.txt b/CMakeLists.txt index 1d2a7a0..9efce80 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -40,6 +40,7 @@ set(SOURCES compression/nuq-inl.h compression/sfp.h compression/sfp-inl.h + compression/test_util.h util/app.h util/args.h ) diff --git a/compression/BUILD b/compression/BUILD index cfbeb99..5b525a5 100644 --- a/compression/BUILD +++ b/compression/BUILD @@ -24,20 +24,36 @@ cc_library( ], ) +# Deprecated because it is also implemented in Highway; will be removed once +# that Highway version is sufficiently widespread. 
cc_library( name = "stats", - srcs = [ - "stats.cc", - ], - hdrs = [ - "distortion.h", - "stats.h", - ], + srcs = ["stats.cc"], + hdrs = ["stats.h"], deps = [ "@hwy//:hwy", ], ) +cc_library( + name = "distortion", + hdrs = ["distortion.h"], + deps = [ + "@hwy//:hwy", + ], +) + +cc_library( + name = "test_util", + hdrs = ["test_util.h"], + deps = [ + ":distortion", + ":stats", + "@hwy//:hwy", + "@hwy//:hwy_test_util", + ], +) + cc_library( name = "sfp", hdrs = [ @@ -62,12 +78,11 @@ cc_test( tags = ["hwy_ops_test"], deps = [ ":sfp", - ":stats", + ":test_util", "@googletest//:gtest_main", "@hwy//:hwy", "@hwy//:hwy_test_util", "@hwy//:nanobenchmark", - "@hwy//:thread_pool", ], ) @@ -98,7 +113,7 @@ cc_test( deps = [ ":nuq", ":sfp", - ":stats", + ":test_util", "@googletest//:gtest_main", "@hwy//:hwy", "@hwy//:hwy_test_util", @@ -118,6 +133,7 @@ cc_library( ], deps = [ ":blob_store", + ":distortion", ":nuq", ":sfp", ":stats", @@ -134,6 +150,7 @@ cc_library( "analyze.h", ], deps = [ + ":distortion", ":nuq", ":sfp", ":stats", diff --git a/compression/nuq-inl.h b/compression/nuq-inl.h index e7d85a7..1c8bdf1 100644 --- a/compression/nuq-inl.h +++ b/compression/nuq-inl.h @@ -42,6 +42,10 @@ #include "hwy/contrib/sort/vqsort-inl.h" #include "hwy/highway.h" +#ifndef HWY_IF_CONSTEXPR +#define HWY_IF_CONSTEXPR if +#endif + HWY_BEFORE_NAMESPACE(); namespace gcpp { namespace HWY_NAMESPACE { @@ -124,7 +128,7 @@ class NuqClustering { } private: - // Float has enough precision for our relatively small kGroupSize (128). + // Float has enough precision for our relatively small kGroupSize (256). float cumsum_[kGroupSize + 1]; float cumsum2_[kGroupSize + 1]; float inv_len_[kGroupSize + 1]; @@ -168,8 +172,8 @@ class NuqClustering { // `centers`; prior centers are zero-initialized. 
// // O(kClusters * kGroupSize * kGroupSize), but the constant factors are so low - // that this is about 10 times as fast as the O(kClusters * kGroupSize) SMAWK - // as implemented in FAISS, for our kGroupSize <= 128. + // that this is about 5 times as fast as the O(kClusters * kGroupSize) SMAWK + // as implemented in FAISS, for our kGroupSize of 256. template static HWY_NOINLINE size_t ClusterExactL2(DF df, const float* x, ClusterBuf& buf, @@ -228,7 +232,7 @@ class NuqClustering { // Center = mean, O(1) thanks to cumulative sums. const float sum = cc.SumOfSorted(start, last); const int size = static_cast(last) - static_cast(start) + 1; - HWY_DASSERT(0 < size && size <= kGroupSize); + HWY_DASSERT(0 < size && size <= static_cast(kGroupSize)); centers[k] = sum / static_cast(size); // We know the range inside sorted_and_i[]; translate to original indices, @@ -427,7 +431,7 @@ class NuqCodec { // instead of TableLookupBytes, which requires extra interleaving of lo/hi. HWY_DASSERT(hn::Lanes(du) >= 8); - if (NumTables(du) == 2) { + HWY_IF_CONSTEXPR(NumTables(du) == 2) { // Reduce cap for second half to avoid loading past the end of the table. const hn::CappedTag d_table2; *tbl1 = hn::ResizeBitCast(du, hn::LoadU(d_table2, table + kClusters / 2)); @@ -449,11 +453,12 @@ class NuqCodec { const auto indices0 = hn::IndicesFromVec(du, idx0); const auto indices1 = hn::IndicesFromVec(du, idx1); - if (NumTables(du) == 1) { + HWY_IF_CONSTEXPR(NumTables(du) == 1) { (void)tbl1; c0 = hn::TableLookupLanes(tbl0, indices0); c1 = hn::TableLookupLanes(tbl0, indices1); - } else { + } + HWY_IF_CONSTEXPR(NumTables(du) == 2) { // `else` is poorly formatted. c0 = hn::TwoTablesLookupLanes(du, tbl0, tbl1, indices0); c1 = hn::TwoTablesLookupLanes(du, tbl0, tbl1, indices1); } @@ -521,8 +526,8 @@ class NuqCodec { // Decodes `num` values from the stream `in`, starting at the offset `in_ofs` // (in units of values), to bf16 in `out`. 
`in_capacity`, `in_ofs` and `num` // must all be multiples of `kGroupSize`. - template - static HWY_INLINE void Dec(DF dbf, const size_t in_capacity, + template + static HWY_INLINE void Dec(DBF dbf, const size_t in_capacity, const NuqStream* const in, const size_t in_ofs, hwy::bfloat16_t* const out, const size_t num) { const hn::RebindToUnsigned d16; diff --git a/compression/nuq_test.cc b/compression/nuq_test.cc index d679376..7045ba4 100644 --- a/compression/nuq_test.cc +++ b/compression/nuq_test.cc @@ -27,6 +27,7 @@ #include "hwy/aligned_allocator.h" #include "hwy/base.h" +#include "hwy/timer.h" // clang-format off #undef HWY_TARGET_INCLUDE @@ -35,15 +36,14 @@ #include "hwy/foreach_target.h" // IWYU pragma: keep // Other headers that include Highway must come after foreach_target.h // copybara:import_next_line:gemma_cpp -#include "compression/distortion.h" -// copybara:import_next_line:gemma_cpp #include "compression/nuq-inl.h" // copybara:import_next_line:gemma_cpp #include "compression/nuq.h" +// copybara:import_next_line:gemma_cpp +#include "compression/test_util.h" #include "hwy/highway.h" #include "hwy/tests/hwy_gtest.h" #include "hwy/tests/test_util-inl.h" -#include "hwy/timer.h" HWY_BEFORE_NAMESPACE(); namespace gcpp { @@ -181,12 +181,14 @@ struct TestNormal { auto in = hwy::AllocateAligned(kGroupSize); HWY_ASSERT(in); - std::mt19937 rng(123); - std::normal_distribution dist{0.001f, 0.3f}; + hwy::RandomState rng; + Stats in_stats; for (size_t i = 0; i < kGroupSize; ++i) { - in[i] = dist(rng); + const double r = RandomGaussian(rng); + in_stats.Notify(r); + in[i] = hwy::ConvertScalarTo(r); } - std::shuffle(in.get(), in.get() + kGroupSize, rng); + VerifyGaussian(in_stats); ClusterBuf buf; float centers[kClusters]; @@ -212,9 +214,9 @@ struct TestNormal { const float snr = stats.GeomeanValueDivL1(); fprintf(stderr, "p-norm %.3E snr %.2f @%zu = %.4E\n", pnorm, snr, stats.MaxIndex(), stats.MaxL1()); - static_assert(kGroupSize == 128 || kGroupSize == 256, "Update 
expected"); - const float expected_pnorm = kGroupSize == 128 ? 3E-2f : 3.4E-2f; - const float expected_snr = kGroupSize == 128 ? 17.4f : 13.1f; + static_assert(kGroupSize == 256, "Update expected"); + const float expected_pnorm = 3.68E-2f; + const float expected_snr = 12.7f; HWY_ASSERT(expected_pnorm <= pnorm && pnorm < 1.02f * expected_pnorm); HWY_ASSERT(expected_snr <= snr && snr < 1.01f * expected_snr); } @@ -345,21 +347,27 @@ struct TestDot { auto nuq = hwy::AllocateAligned(NuqStream::PackedEnd(num)); HWY_ASSERT(in && dec && vec && nuq); - std::mt19937 rng(123); - std::normal_distribution dist{0.001f, 0.3f}; + // Generate inputs and verify their distribution. + hwy::RandomState rng; + Stats in_stats; for (size_t i = 0; i < num; ++i) { - in[i] = dist(rng); - vec[i] = hwy::ConvertScalarTo(dist(rng)); + const float r = static_cast(RandomGaussian(rng)); + in_stats.Notify(r); + in[i] = r; } - // This changes the correlation between in and vec, which considerably - // affects the error of the result. - std::shuffle(in.get(), in.get() + num, rng); + for (size_t i = 0; i < num; ++i) { + const float r = static_cast(RandomGaussian(rng)); + in_stats.Notify(r); + vec[i] = hwy::ConvertScalarTo(r); + } + VerifyGaussian(in_stats); ClusterBuf buf; const size_t unused_clusters = NuqCodec::Enc(df, in.get(), num, buf, num, nuq.get(), 0); HWY_ASSERT(unused_clusters == 0); + // Compute dot product without decompression. double actual = 0.0; double elapsed = hwy::HighestValue(); for (size_t rep = 0; rep < 20; ++rep) { @@ -380,24 +388,39 @@ struct TestDot { fprintf(stderr, "Vec %zu Dec %.2f MB/s\n", Lanes(d) * sizeof(T), num * sizeof(in[0]) * 1E-6 / elapsed); - double expected = 0.0; // using original input - double expected2 = 0.0; // using decoded NUQ + // Exact and decompressed dot products for comparison. 
+ double exact = 0.0; // using original input + double expected = 0.0; // using decoded NUQ + DistortionStats dec_stats; + Stats ratios; for (size_t i = 0; i < num; ++i) { - expected += in[i] * hwy::ConvertScalarTo(vec[i]); - expected2 += dec[i] * hwy::ConvertScalarTo(vec[i]); + dec_stats.Notify(in[i], dec[i]); + const float v1 = hwy::ConvertScalarTo(vec[i]); + exact += in[i] * v1; + expected += dec[i] * v1; + if (expected != 0.0f) { + ratios.Notify(exact / expected); + } } - const double l1 = hwy::ScalarAbs(expected - actual); - const double snr = 1.0 + hwy::ScalarAbs(expected) / l1; - fprintf(stderr, "expected %.3f e2 %.4f actual %.4f l1 %E snr %.2f\n", - expected, expected2, actual, l1, snr); - HWY_ASSERT(hwy::ScalarAbs(expected2 - actual) < 1E-4); - static_assert(kGroupSize == 128 || kGroupSize == 256, "Update expected"); - const double expected_l1 = kGroupSize == 128 ? 7.3E-2 : 4.34E-2; - const double expected_snr = kGroupSize == 128 ? 9.7f - : sizeof(T) == 2 ? 14.5f - : 14.9f; - HWY_ASSERT(expected_l1 <= l1 && l1 < 1.02f * expected_l1); - HWY_ASSERT(expected_snr <= snr && snr < 1.01f * expected_snr); + const double dec_snr = dec_stats.GeomeanValueDivL1(); + const double dot_snr = 1.0 / hwy::ScalarAbs(1.0 - ratios.GeometricMean()); + // exact and actual fluctuate due to the combination of NUQ imprecision, + // and whether vec[i] is negative or positive, so this is quite loose. + const float final_ratio = HWY_MIN(exact / actual, actual / exact); + fprintf(stderr, "ratios %s\n", ratios.ToString().c_str()); + fprintf(stderr, + "exact %.3f e2 %.4f actual %.4f final_ratio %.3f dec_snr %.2f " + "dot_snr %.2f\n", + exact, expected, actual, final_ratio, dec_snr, dot_snr); + // Final values are not too far apart. + HWY_ASSERT(0.88f <= final_ratio && final_ratio <= 1.0f); + // Decompressed and uncompressed dot should match exactly. + HWY_ASSERT(hwy::ScalarAbs(expected - actual) < 1E-4f); + // dec[] is close to in[], but we already check that in TestStream. 
+ HWY_ASSERT(dec_snr >= 13.0); + // Geomean of ratios for each i is an approximation of the actual SNR. + HWY_ASSERT(dot_snr >= (sizeof(T) == 2 ? 17.0 : 14.0)); + static_assert(kGroupSize == 256, "Update expected*"); } }; diff --git a/compression/sfp-inl.h b/compression/sfp-inl.h index 62b8955..77f7ede 100644 --- a/compression/sfp-inl.h +++ b/compression/sfp-inl.h @@ -449,6 +449,18 @@ class SfpCodec { return Enc2U(d16, w0, w1); } + // Truncates two f32 to bf16, in lane order, without rounding (see Enc4F). + template > + static HWY_INLINE hn::Vec Truncate2To(DBF dbf, hn::Vec f0, + hn::Vec f1) { + const hn::RebindToUnsigned d16; + using V16 = hn::Vec; + const V16 u0 = BitCast(d16, f0); + const V16 u1 = BitCast(d16, f1); + return BitCast(DBF(), HWY_IS_LITTLE_ENDIAN ? ConcatOdd(d16, u1, u0) + : ConcatEven(d16, u1, u0)); + } + template >> static HWY_INLINE V8 Enc4F(DF df, const float* HWY_RESTRICT in) { @@ -462,9 +474,10 @@ class SfpCodec { const VF f1 = hn::LoadU(df, in + NF * 1); const VF f2 = hn::LoadU(df, in + NF * 2); const VF f3 = hn::LoadU(df, in + NF * 3); - // Chop off the lower 16 bits; EncBytes still rounds properly. - const V16 w0 = hn::BitCast(d16, hn::OrderedDemote2To(dbf, f0, f1)); - const V16 w1 = hn::BitCast(d16, hn::OrderedDemote2To(dbf, f2, f3)); + // Chop off the lower 16 bits instead of OrderedDemote2To, which rounds to + // the nearest bf16, because EncBytes will round again. 
+ const V16 w0 = hn::BitCast(d16, Truncate2To(dbf, f0, f1)); + const V16 w1 = hn::BitCast(d16, Truncate2To(dbf, f2, f3)); return Enc2U(d16, w0, w1); } diff --git a/compression/sfp_test.cc b/compression/sfp_test.cc index b51505f..b362728 100644 --- a/compression/sfp_test.cc +++ b/compression/sfp_test.cc @@ -25,12 +25,11 @@ #include #include -#include -#include #include #include "hwy/aligned_allocator.h" #include "hwy/base.h" +#include "hwy/timer.h" // clang-format off #undef HWY_TARGET_INCLUDE @@ -39,13 +38,12 @@ #include "hwy/foreach_target.h" // IWYU pragma: keep // Any highway.h must come after foreach_target.h // copybara:import_next_line:gemma_cpp -#include "compression/distortion.h" -// copybara:import_next_line:gemma_cpp #include "compression/sfp-inl.h" +// copybara:import_next_line:gemma_cpp +#include "compression/test_util.h" #include "hwy/highway.h" #include "hwy/tests/hwy_gtest.h" #include "hwy/tests/test_util-inl.h" -#include "hwy/timer.h" HWY_BEFORE_NAMESPACE(); namespace gcpp { @@ -358,25 +356,31 @@ struct TestDot { template HWY_INLINE void operator()(T /*unused*/, D d) { const hn::Repartition df; - const size_t num = 384; + const size_t num = 1024; // not too many for GeometricMean overflow. auto in = hwy::AllocateAligned(num); auto dec = hwy::AllocateAligned(num); auto vec = hwy::AllocateAligned(num); auto sfp = hwy::AllocateAligned(num); HWY_ASSERT(in && dec && vec && sfp); - std::mt19937 rng(123); - std::normal_distribution dist{0.001f, 0.3f}; + // Generate inputs and verify their distribution. + hwy::RandomState rng; + Stats in_stats; for (size_t i = 0; i < num; ++i) { - in[i] = hwy::ConvertScalarTo(dist(rng)); - vec[i] = hwy::ConvertScalarTo(dist(rng)); + const float r = static_cast(RandomGaussian(rng)); + in_stats.Notify(r); + in[i] = hwy::ConvertScalarTo(r); } - // This changes the correlation between in and vec, which considerably - // affects the error of the result. 
- std::shuffle(in.get(), in.get() + num, rng); + for (size_t i = 0; i < num; ++i) { + const float r = static_cast(RandomGaussian(rng)); + in_stats.Notify(r); + vec[i] = hwy::ConvertScalarTo(r); + } + VerifyGaussian(in_stats); SfpCodec::Enc(d, in.get(), num, sfp.get()); + // Compute dot product without decompression. double actual = 0.0; double elapsed = hwy::HighestValue(); for (size_t rep = 0; rep < 200; ++rep) { @@ -393,26 +397,44 @@ struct TestDot { } SfpCodec::Dec(d, sfp.get(), num, dec.get()); - fprintf(stderr, "Vec %zu Dot %.2f MB/s\n", Lanes(d) * sizeof(T), - num * sizeof(T) * 1E-6 / elapsed); + fprintf(stderr, "Vec %zu Dot %zu-bit %.2f MB/s\n", Lanes(d) * sizeof(T), + sizeof(T) * 8, num * sizeof(T) * 1E-6 / elapsed); - double expected = 0.0; // using original input - double expected2 = 0.0; // using decoded SFP + // Exact and decompressed dot products for comparison. + float exact = 0.0f; // using original input + float expected = 0.0f; // using decoded SFP + DistortionStats dec_stats; + Stats ratios; for (size_t i = 0; i < num; ++i) { - expected += hwy::ConvertScalarTo(in[i]) * - hwy::ConvertScalarTo(vec[i]); - expected2 += hwy::ConvertScalarTo(dec[i]) * - hwy::ConvertScalarTo(vec[i]); + const float in1 = hwy::ConvertScalarTo(in[i]); + const float dec1 = hwy::ConvertScalarTo(dec[i]); + const float vec1 = hwy::ConvertScalarTo(vec[i]); + dec_stats.Notify(in1, dec1); + + exact += in1 * vec1; + expected += dec1 * vec1; + if (expected != 0.0f) { + ratios.Notify(exact / expected); + } } - const double l1 = hwy::ScalarAbs(expected - actual); - const double snr = 1.0 + hwy::ScalarAbs(expected) / l1; - fprintf(stderr, "expected %.3f e2 %.4f actual %.4f l1 %E snr %.2f\n", - expected, expected2, actual, l1, snr); - HWY_ASSERT(hwy::ScalarAbs(expected2 - actual) < 1E-4); - const double expected_l1 = sizeof(T) == 2 ? 1.52E-2 : 1.15E-2; - const double expected_snr = sizeof(T) == 2 ? 
80.1f : 104.9f; - HWY_ASSERT(expected_l1 <= l1 && l1 < 1.02f * expected_l1); - HWY_ASSERT(expected_snr <= snr && snr < 1.01f * expected_snr); + const double dec_snr = dec_stats.GeomeanValueDivL1(); + const double dot_snr = 1.0 / hwy::ScalarAbs(1.0 - ratios.GeometricMean()); + // exact and actual fluctuate due to the combination of SFP imprecision, + // and whether vec[i] is negative or positive, so this is quite loose. + const float final_ratio = HWY_MIN(exact / actual, actual / exact); + fprintf(stderr, "ratios %s\n", ratios.ToString().c_str()); + fprintf(stderr, + "exact %.3f e2 %.4f actual %.4f final_ratio %.3f dec_snr %.2f " + "dot_snr %.2f\n", + exact, expected, actual, final_ratio, dec_snr, dot_snr); + // Final values are not too far apart. + HWY_ASSERT(0.87f <= final_ratio && final_ratio <= 1.0f); + // Decompressed and uncompressed dot should match exactly. + HWY_ASSERT(hwy::ScalarAbs(expected - actual) < 1E-4f); + // dec[] is close to in[], but we already check that in TestEncDec. + HWY_ASSERT(dec_snr >= 50.0); + // Geomean of ratios for each i should be very close to one. + HWY_ASSERT(dot_snr >= (sizeof(T) == 2 ? 70.0 : 1000.0)); } }; diff --git a/compression/test_util.h b/compression/test_util.h new file mode 100644 index 0000000..b1e4026 --- /dev/null +++ b/compression/test_util.h @@ -0,0 +1,64 @@ +// Copyright 2023 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef THIRD_PARTY_GEMMA_CPP_COMPRESSION_TEST_UTIL_H_ +#define THIRD_PARTY_GEMMA_CPP_COMPRESSION_TEST_UTIL_H_ + +#include +#include + +#include + +#include "hwy/base.h" + +// IWYU pragma: begin_exports +// copybara:import_next_line:gemma_cpp +#include "compression/distortion.h" +// copybara:import_next_line:gemma_cpp +#include "compression/stats.h" +#include "hwy/tests/test_util.h" // RandomState +// IWYU pragma: end_exports + +namespace gcpp { + +// Returns random Gaussian (mean=0, stddev=1/3 similar to expected weights) +// using the central limit theorem. Avoid std::normal_distribution for +// consistent cross-platform output. +HWY_INLINE double RandomGaussian(hwy::RandomState& rng) { + uint64_t sum = 0; + constexpr int kReps = 40; + for (int rep = 0; rep < kReps; ++rep) { + sum += hwy::Random32(&rng) & 0xFFFFF; + } + const double sum_f = + static_cast(sum) / static_cast(0xFFFFF * kReps); + HWY_ASSERT(0.0 <= sum_f && sum_f <= 1.0); + const double plus_minus_1 = 2.0 * sum_f - 1.0; + HWY_ASSERT(-1.0 <= plus_minus_1 && plus_minus_1 <= 1.0); + // Normalize by stddev of sum of uniform random scaled to [-1, 1]. 
+ return plus_minus_1 * std::sqrt(kReps / 3.0); +}; + +HWY_INLINE void VerifyGaussian(Stats& stats) { + const double stddev = stats.StandardDeviation(); + HWY_ASSERT(-0.01 <= stats.Mean() && stats.Mean() <= 0.01); + HWY_ASSERT(0.30 <= stddev && stddev <= 0.35); + HWY_ASSERT(-1.1 <= stats.Min() && stats.Min() <= -0.9); + HWY_ASSERT(0.9 <= stats.Max() && stats.Max() <= 1.1); +} + +} // namespace gcpp + +#endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_TEST_UTIL_H_ From 61e031fe984d199da53f839948064811c46a77c5 Mon Sep 17 00:00:00 2001 From: Jan Wassenberg Date: Thu, 21 Mar 2024 19:26:27 -0700 Subject: [PATCH 25/25] Towards building tests without GUnit Refs #29 PiperOrigin-RevId: 618032987 --- compression/nuq_test.cc | 3 +++ compression/sfp_test.cc | 3 +++ 2 files changed, 6 insertions(+) diff --git a/compression/nuq_test.cc b/compression/nuq_test.cc index 7045ba4..8e34b4d 100644 --- a/compression/nuq_test.cc +++ b/compression/nuq_test.cc @@ -452,6 +452,9 @@ HWY_EXPORT_AND_TEST_P(NuqTest, TestAllStreamF32); HWY_EXPORT_AND_TEST_P(NuqTest, TestAllStreamBF16); HWY_EXPORT_AND_TEST_P(NuqTest, TestAllDotF32); HWY_EXPORT_AND_TEST_P(NuqTest, TestAllDotBF16); +#ifdef HWY_AFTER_TEST +HWY_AFTER_TEST(); +#endif } // namespace gcpp #endif diff --git a/compression/sfp_test.cc b/compression/sfp_test.cc index b362728..1a4e4ec 100644 --- a/compression/sfp_test.cc +++ b/compression/sfp_test.cc @@ -463,6 +463,9 @@ HWY_EXPORT_AND_TEST_P(SfpTest, TestAllEncDec); HWY_EXPORT_AND_TEST_P(SfpTest, TestAllOrder); HWY_EXPORT_AND_TEST_P(SfpTest, TestAllDotF32); HWY_EXPORT_AND_TEST_P(SfpTest, TestAllDotBF16); +#ifdef HWY_AFTER_TEST +HWY_AFTER_TEST(); +#endif } // namespace gcpp #endif