From d4050a2917345232867f5e3dd35d8799269b2d6e Mon Sep 17 00:00:00 2001 From: Paul Chang Date: Thu, 7 Nov 2024 10:27:57 -0800 Subject: [PATCH 1/9] Expose BlobReader::Keys() PiperOrigin-RevId: 694166186 --- compression/blob_store.cc | 8 ++++++++ compression/blob_store.h | 3 +++ compression/blob_store_test.cc | 5 +++++ 3 files changed, 16 insertions(+) diff --git a/compression/blob_store.cc b/compression/blob_store.cc index 659946c..06bcb56 100644 --- a/compression/blob_store.cc +++ b/compression/blob_store.cc @@ -201,6 +201,10 @@ class BlobStore { return false; } + hwy::Span Keys() const { + return hwy::Span(keys_, num_blobs_); + } + private: uint32_t magic_; uint32_t num_blobs_; // never 0 @@ -303,6 +307,10 @@ BlobError BlobReader::ReadOne(hwy::uint128_t key, void* data, return 0; } +hwy::Span BlobReader::Keys() const { + return blob_store_->Keys(); +} + BlobError BlobWriter::WriteAll(hwy::ThreadPool& pool, const Path& filename) { HWY_ASSERT(keys_.size() == blobs_.size()); diff --git a/compression/blob_store.h b/compression/blob_store.h index 4aba006..94bbace 100644 --- a/compression/blob_store.h +++ b/compression/blob_store.h @@ -84,6 +84,9 @@ class BlobReader { // Reads one blob directly. BlobError ReadOne(hwy::uint128_t key, void* data, size_t size) const; + // Returns all available blob keys. + hwy::Span Keys() const; + private: BlobStorePtr blob_store_; // holds header, not the entire file std::vector requests_; diff --git a/compression/blob_store_test.cc b/compression/blob_store_test.cc index 6464756..dbba55f 100644 --- a/compression/blob_store_test.cc +++ b/compression/blob_store_test.cc @@ -70,6 +70,11 @@ TEST(BlobStoreTest, TestReadWrite) { HWY_ASSERT_STRING_EQ("DATA", buffer.data()); } + const hwy::Span keys = reader.Keys(); + HWY_ASSERT_EQ(keys.size(), 2); + HWY_ASSERT_EQ(keys[0], keyA); + HWY_ASSERT_EQ(keys[1], keyB); + close(fd); unlink(path_str); } From e54d9cbdddf98b79c698cbb4d08dd3c0c88325fc Mon Sep 17 00:00:00 2001 From: Daniel Keysers Date: Fri, 8 Nov 2024 08:30:04 -0800 Subject: [PATCH 2/9] Fix Griffin model: - use HalfRope position encodings - zero-initialize the caches for each Generate at position 0 The lack of the latter made the tests in gemma_test dependent on each other. PiperOrigin-RevId: 694509054 --- evals/gemma_test.cc | 6 +++--- gemma/configs.cc | 2 +- gemma/configs_test.cc | 6 +++++- gemma/gemma-inl.h | 11 ++++++++--- gemma/kv_cache.cc | 23 ++++++++++++++++------- gemma/kv_cache.h | 6 ++++++ 6 files changed, 39 insertions(+), 15 deletions(-) diff --git a/evals/gemma_test.cc b/evals/gemma_test.cc index 98029fe..676a5d2 100644 --- a/evals/gemma_test.cc +++ b/evals/gemma_test.cc @@ -246,7 +246,7 @@ TEST_F(GemmaTest, CrossEntropySmall) { EXPECT_NEAR(entropy, 2.8f, 0.2f); break; case gcpp::Model::GRIFFIN_2B: - EXPECT_NEAR(entropy, 1.57f, 0.02f); + EXPECT_NEAR(entropy, 2.61f, 0.02f); break; case gcpp::Model::GEMMA2_2B: EXPECT_NEAR(entropy, 1.14f, 0.02f); @@ -277,7 +277,7 @@ TEST_F(GemmaTest, CrossEntropyJingleBells) { EXPECT_NEAR(entropy, 1.07f, 0.05f); break; case gcpp::Model::GRIFFIN_2B: - EXPECT_NEAR(entropy, 2.09f, 0.02f); + EXPECT_NEAR(entropy, 1.62f, 0.02f); break; case gcpp::Model::GEMMA2_2B: EXPECT_NEAR(entropy, 0.49f, 0.02f); @@ -308,7 +308,7 @@ TEST_F(GemmaTest, CrossEntropyGettysburg) { EXPECT_NEAR(entropy, 0.75f, 0.1f); break; case gcpp::Model::GRIFFIN_2B: - EXPECT_NEAR(entropy, 0.86f, 0.02f); + EXPECT_NEAR(entropy, 0.71f, 0.02f); break; case gcpp::Model::GEMMA2_2B: EXPECT_NEAR(entropy, 0.20f, 0.02f); diff --git a/gemma/configs.cc b/gemma/configs.cc index 326b18e..03fce99 100644 --- a/gemma/configs.cc +++ b/gemma/configs.cc @@ -183,7 +183,7 @@ static ModelConfig ConfigGriffin2B() { .softmax_attn_output_biases = true, .type = LayerAttentionType::kGriffinRecurrentBlock, .activation = ActivationType::Gelu, - .post_qk = PostQKType::Rope, + .post_qk = PostQKType::HalfRope, }; config.layer_configs = {26, layer_config}; for (size_t i = 2; i < config.layer_configs.size(); i += 3) { diff --git a/gemma/configs_test.cc b/gemma/configs_test.cc index a6668a4..91bfc53 100644 --- a/gemma/configs_test.cc +++ b/gemma/configs_test.cc @@ -397,7 +397,11 @@ void AssertMatch(const ModelConfig& config) { ASSERT_EQ(TConfig::kPostNorm, config.layer_configs[i].post_norm); ASSERT_EQ(TConfig::kLayerConfig[i], config.layer_configs[i].type); ASSERT_EQ(TConfig::kActivation, config.layer_configs[i].activation); - ASSERT_EQ(TConfig::kPostQK, config.layer_configs[i].post_qk); + PostQKType post_qk = TConfig::kPostQK; + if (TConfig::kUseHalfRope) { + post_qk = PostQKType::HalfRope; + } + ASSERT_EQ(post_qk, config.layer_configs[i].post_qk); } ASSERT_EQ(TConfig::kAttentionWindowSizes.size(), diff --git a/gemma/gemma-inl.h b/gemma/gemma-inl.h index c58f9a8..2b1587d 100644 --- a/gemma/gemma-inl.h +++ b/gemma/gemma-inl.h @@ -1240,8 +1240,12 @@ void GenerateT(const ModelWeightsStorage& model, Activations& activations, const QueriesPos& queries_prefix_end, const size_t query_idx_start, const KVCaches& kv_caches, TimingInfo& timing_info) { - const size_t vocab_size = model.Config().vocab_size; - const ModelWeightsPtrs& weights = *model.GetWeightsOfType(); + // Griffin assumes that the recurrent block cache is zero-initialized. + for (size_t i = 0; i < kv_caches.size(); ++i) { + if (queries_pos_in[i] == 0) { + kv_caches[i].ZeroGriffinCache(); // No-op for non-Griffin models. + } + } // Copy so we can increment without requiring users to pass in a mutable span. std::vector queries_pos_copy(queries_pos_in.cbegin(), @@ -1268,7 +1272,7 @@ void GenerateT(const ModelWeightsStorage& model, Activations& activations, HWY_ASSERT(queries_pos_in.size() == num_queries); HWY_ASSERT(kv_caches.size() == num_queries); const hwy::Divisor div_seq_len(static_cast(kv_caches[0].seq_len)); - + const ModelWeightsPtrs& weights = *model.GetWeightsOfType(); size_t max_prompt_size = MaxQueryLength(queries_prompt); size_t max_generated_tokens = runtime_config.max_generated_tokens; RangeChecks(weights.weights_config, max_generated_tokens, max_prompt_size); @@ -1314,6 +1318,7 @@ void GenerateT(const ModelWeightsStorage& model, Activations& activations, 0.0f); } + const size_t vocab_size = model.Config().vocab_size; const double gen_start = hwy::platform::Now(); for (size_t gen = 0; gen < max_generated_tokens; ++gen) { // Decode generates one token per query and increments queries_mutable_pos. diff --git a/gemma/kv_cache.cc b/gemma/kv_cache.cc index cc9db89..82ee01d 100644 --- a/gemma/kv_cache.cc +++ b/gemma/kv_cache.cc @@ -23,6 +23,17 @@ namespace gcpp { +void KVCache::ZeroGriffinCache() { + if (conv1d_cache_size != 0) { + hwy::ZeroBytes(conv1d_cache.get(), + conv1d_cache_size * sizeof(conv1d_cache[0])); + } + if (rglru_cache_size != 0) { + hwy::ZeroBytes(rglru_cache.get(), + rglru_cache_size * sizeof(rglru_cache[0])); + } +} + // prefill_tbatch_size is the maximum number of tokens from one query to // prefill at a time. KVCache KVCache::Create(const ModelConfig& weights_config, @@ -37,9 +48,9 @@ KVCache KVCache::Create(const ModelConfig& weights_config, kv_cache.kv_cache = hwy::AllocateAligned(kv_cache.seq_len * size_cache_pos); } - size_t num_griffin_layers = weights_config.NumLayersOfType( - LayerAttentionType::kGriffinRecurrentBlock); + const size_t num_griffin_layers = weights_config.NumLayersOfType( + LayerAttentionType::kGriffinRecurrentBlock); // TODO(patrickms): Add query batching support for Griffin. if (num_griffin_layers > 0) { size_t conv1d_width = 0; @@ -49,20 +60,18 @@ KVCache KVCache::Create(const ModelConfig& weights_config, const size_t conv1d_cache_size = num_griffin_layers * (conv1d_width == 0 ? 0 : conv1d_width - 1) * weights_config.model_dim; + kv_cache.conv1d_cache_size = conv1d_cache_size; if (conv1d_cache_size != 0) { kv_cache.conv1d_cache = hwy::AllocateAligned(conv1d_cache_size); - hwy::ZeroBytes(kv_cache.conv1d_cache.get(), - conv1d_cache_size * sizeof(kv_cache.conv1d_cache[0])); } const size_t rglru_cache_size = num_griffin_layers * weights_config.model_dim; + kv_cache.rglru_cache_size = rglru_cache_size; if (rglru_cache_size != 0) { kv_cache.rglru_cache = hwy::AllocateAligned(rglru_cache_size); - hwy::ZeroBytes(kv_cache.rglru_cache.get(), - rglru_cache_size * sizeof(kv_cache.rglru_cache[0])); } - } // kGriffinLayers + } // num_griffin_layers return kv_cache; } diff --git a/gemma/kv_cache.h b/gemma/kv_cache.h index 9c46d93..69f9564 100644 --- a/gemma/kv_cache.h +++ b/gemma/kv_cache.h @@ -31,9 +31,15 @@ struct KVCache { // (kConv1dWidth - 1) * kModelDim * kGriffinLayers hwy::AlignedFreeUniquePtr conv1d_cache; + size_t conv1d_cache_size = 0; // kModelDim * kGriffinLayers hwy::AlignedFreeUniquePtr rglru_cache; + size_t rglru_cache_size = 0; + + // Zero-initialize the Griffin recurrent block cache, i.e. the conv1d_cache + // and rglru_cache. + void ZeroGriffinCache(); static KVCache Create(const ModelConfig& weights_config, size_t prefill_tbatch_size); From b94295b6d909a1512dd8431eefa85d3393a7006b Mon Sep 17 00:00:00 2001 From: Paul Chang Date: Wed, 13 Nov 2024 09:01:02 -0800 Subject: [PATCH 3/9] Internal changes PiperOrigin-RevId: 696155630 --- bazel/{BUILD => BUILD.bazel} | 0 compression/BUILD.bazel | 2 +- compression/python/{BUILD => BUILD.bazel} | 0 paligemma/{BUILD => BUILD.bazel} | 2 +- 4 files changed, 2 insertions(+), 2 deletions(-) rename bazel/{BUILD => BUILD.bazel} (100%) rename compression/python/{BUILD => BUILD.bazel} (100%) rename paligemma/{BUILD => BUILD.bazel} (94%) diff --git a/bazel/BUILD b/bazel/BUILD.bazel similarity index 100% rename from bazel/BUILD rename to bazel/BUILD.bazel diff --git a/compression/BUILD.bazel b/compression/BUILD.bazel index d5c90c3..f16b4d7 100644 --- a/compression/BUILD.bazel +++ b/compression/BUILD.bazel @@ -6,7 +6,7 @@ package( ], default_visibility = [ # Placeholder for internal visibility, - "//:__subpackages__", # Placeholder, do not modify + "//:__subpackages__", ], ) diff --git a/compression/python/BUILD b/compression/python/BUILD.bazel similarity index 100% rename from compression/python/BUILD rename to compression/python/BUILD.bazel diff --git a/paligemma/BUILD b/paligemma/BUILD.bazel similarity index 94% rename from paligemma/BUILD rename to paligemma/BUILD.bazel index 01514e2..0710423 100644 --- a/paligemma/BUILD +++ b/paligemma/BUILD.bazel @@ -3,7 +3,7 @@ package( "//:license", # Placeholder comment, do not modify ], default_visibility = [ - "//:__subpackages__", # Placeholder, do not modify + "//:__subpackages__", ], ) From 719699f132df2733ff5cf74134e703f6d753dbb7 Mon Sep 17 00:00:00 2001 From: Daniel Keysers Date: Wed, 13 Nov 2024 09:48:25 -0800 Subject: [PATCH 4/9] Make top_k a runtime argument (instead of a model argument). PiperOrigin-RevId: 696170691 --- backprop/optimize_test.cc | 2 +- evals/benchmark_helper.cc | 2 +- evals/cross_entropy.cc | 2 +- evals/gemma_test.cc | 2 +- evals/run_mmlu.cc | 2 +- examples/hello_world/run.cc | 2 +- gemma/configs.h | 1 - gemma/configs_test.cc | 2 +- gemma/gemma-inl.h | 18 ++++++++---------- gemma/gemma.h | 8 ++++++-- gemma/run.cc | 6 +++--- paligemma/paligemma_test.cc | 6 +++--- util/app.h | 4 ++++ 13 files changed, 31 insertions(+), 26 deletions(-) diff --git a/backprop/optimize_test.cc b/backprop/optimize_test.cc index 6d83de0..a23ac84 100644 --- a/backprop/optimize_test.cc +++ b/backprop/optimize_test.cc @@ -74,8 +74,8 @@ TEST(OptimizeTest, GradientDescent) { RuntimeConfig runtime = { .max_generated_tokens = 16, .temperature = 1.0f, - .verbosity = 0, .gen = &gen, + .verbosity = 0, .stream_token = stream_token, .eos_id = ReverseSequenceSampler::kEndToken, }; diff --git a/evals/benchmark_helper.cc b/evals/benchmark_helper.cc index 8c84f96..60cc61e 100644 --- a/evals/benchmark_helper.cc +++ b/evals/benchmark_helper.cc @@ -74,8 +74,8 @@ GemmaEnv::GemmaEnv(const LoaderArgs& loader, const InferenceArgs& inference, runtime_config_ = { .max_generated_tokens = inference.max_generated_tokens, .temperature = inference.temperature, - .verbosity = app.verbosity, .gen = &gen_, + .verbosity = app.verbosity, }; } diff --git a/evals/cross_entropy.cc b/evals/cross_entropy.cc index 13ff3d3..6393c53 100644 --- a/evals/cross_entropy.cc +++ b/evals/cross_entropy.cc @@ -139,8 +139,8 @@ float ComputeCrossEntropy(Gemma& gemma, size_t max_generated_tokens, RuntimeConfig runtime = { .max_generated_tokens = max_generated_tokens - 1, .temperature = 0.0f, - .verbosity = verbosity, .gen = nullptr, + .verbosity = verbosity, .stream_token = stream_token, .sample_func = sample_token, }; diff --git a/evals/gemma_test.cc b/evals/gemma_test.cc index 676a5d2..114d5a3 100644 --- a/evals/gemma_test.cc +++ b/evals/gemma_test.cc @@ -169,8 +169,8 @@ TEST_F(GemmaTest, Multiturn) { RuntimeConfig runtime_config{ .max_generated_tokens = 64, .temperature = 0.0f, - .verbosity = 2, .gen = &s_env->MutableGen(), + .verbosity = 2, .stream_token = stream_token, }; TimingInfo timing_info{.verbosity = 0}; diff --git a/evals/run_mmlu.cc b/evals/run_mmlu.cc index d3618db..77c9dcd 100644 --- a/evals/run_mmlu.cc +++ b/evals/run_mmlu.cc @@ -127,8 +127,8 @@ void Run(GemmaEnv& env, JsonArgs& json) { gcpp::RuntimeConfig runtime_config = { .max_generated_tokens = 30, .temperature = 0.0f, - .verbosity = env.Verbosity(), .gen = &env.MutableGen(), + .verbosity = env.Verbosity(), .stream_token = stream_token, }; env.GetModel()->Generate(runtime_config, prompt, /*pos=*/0, diff --git a/examples/hello_world/run.cc b/examples/hello_world/run.cc index 7b2e90f..7e9e561 100644 --- a/examples/hello_world/run.cc +++ b/examples/hello_world/run.cc @@ -89,8 +89,8 @@ int main(int argc, char** argv) { gcpp::RuntimeConfig runtime_config = { .max_generated_tokens = 1024, .temperature = 1.0, - .verbosity = 0, .gen = &gen, + .verbosity = 0, .stream_token = stream_token, .accept_token = [&](int token, float /* prob */) { diff --git a/gemma/configs.h b/gemma/configs.h index f6a4245..e709df7 100644 --- a/gemma/configs.h +++ b/gemma/configs.h @@ -178,7 +178,6 @@ struct ModelConfig { size_t vit_seq_len = 0; size_t num_tensor_scales = 0; size_t num_vit_scales = 0; - size_t top_k = kTopK; float att_cap = 0.0f; float final_cap = 0.0f; bool absolute_pe = false; diff --git a/gemma/configs_test.cc b/gemma/configs_test.cc index 91bfc53..8128baf 100644 --- a/gemma/configs_test.cc +++ b/gemma/configs_test.cc @@ -374,7 +374,7 @@ void AssertMatch(const ModelConfig& config) { } ASSERT_EQ(TConfig::kVocabSize, config.vocab_size); ASSERT_EQ(TConfig::kSeqLen, config.seq_len); - ASSERT_EQ(TConfig::kTopK, config.top_k); + // ASSERT_EQ(TConfig::kTopK, config.top_k); - is now a runtime config value. ASSERT_EQ(TConfig::kAttCap, config.att_cap); ASSERT_EQ(TConfig::kFinalCap, config.final_cap); ASSERT_EQ(TConfig::kAbsolutePE, config.absolute_pe); diff --git a/gemma/gemma-inl.h b/gemma/gemma-inl.h index 2b1587d..51f2999 100644 --- a/gemma/gemma-inl.h +++ b/gemma/gemma-inl.h @@ -1196,13 +1196,12 @@ class TokenStreamer { hwy::BitSet4096<> is_eos_; }; -HWY_INLINE SampleFunc ChooseSampleFunc(int top_k, - const RuntimeConfig& runtime_config) { +HWY_INLINE SampleFunc ChooseSampleFunc(const RuntimeConfig& runtime_config) { // If user provided a sample_func, use it. if (runtime_config.sample_func) return runtime_config.sample_func; // Fast path for top-1 with no accept_token. - if (top_k == 1 && !runtime_config.accept_token) { + if (runtime_config.top_k == 1 && !runtime_config.accept_token) { return [](float* logits, size_t vocab_size) HWY_ATTR -> TokenAndProb { PROFILER_ZONE("Gen.Sample Top1"); return Top1OfSoftmax(logits, vocab_size); @@ -1210,13 +1209,13 @@ HWY_INLINE SampleFunc ChooseSampleFunc(int top_k, } // General case: Softmax with top-k sampling. - return [top_k, &runtime_config](float* logits, - size_t vocab_size) HWY_ATTR -> TokenAndProb { + return [&runtime_config](float* logits, + size_t vocab_size) HWY_ATTR -> TokenAndProb { PROFILER_ZONE("Gen.Sample general"); Softmax(logits, vocab_size); - const int token = - SampleTopK(logits, top_k, vocab_size, *runtime_config.gen, - runtime_config.temperature, runtime_config.accept_token); + const int token = SampleTopK( + logits, runtime_config.top_k, vocab_size, *runtime_config.gen, + runtime_config.temperature, runtime_config.accept_token); return TokenAndProb{.token = token, .prob = logits[token]}; }; } @@ -1276,8 +1275,7 @@ void GenerateT(const ModelWeightsStorage& model, Activations& activations, size_t max_prompt_size = MaxQueryLength(queries_prompt); size_t max_generated_tokens = runtime_config.max_generated_tokens; RangeChecks(weights.weights_config, max_generated_tokens, max_prompt_size); - const SampleFunc sample_token = - ChooseSampleFunc(weights.weights_config.top_k, runtime_config); + const SampleFunc sample_token = ChooseSampleFunc(runtime_config); // Prefill stops before min_prompt_size - 1 because the last prompt // token is the first input token for generation. diff --git a/gemma/gemma.h b/gemma/gemma.h index 5df319f..5b84053 100644 --- a/gemma/gemma.h +++ b/gemma/gemma.h @@ -25,6 +25,7 @@ #include "compression/io.h" // Path #include "gemma/activations.h" #include "gemma/common.h" +#include "gemma/configs.h" #include "gemma/kv_cache.h" #include "gemma/tokenizer.h" #include "gemma/weights.h" @@ -102,9 +103,12 @@ struct RuntimeConfig { // Max queries per batch (one token from each) during decode. size_t decode_qbatch_size = 16; - float temperature; // Temperature for sampling. + // Sampling-related parameters. + float temperature; // Temperature for sampling. + size_t top_k = kTopK; // Top-k for sampling. + std::mt19937* gen; // Random number generator used for sampling. + int verbosity; // Controls verbosity of printed messages. - std::mt19937* gen; // Random number generator used for sampling. // Functions operating on the generated tokens. StreamFunc stream_token; diff --git a/gemma/run.cc b/gemma/run.cc index 2c62bdb..87c7c9d 100644 --- a/gemma/run.cc +++ b/gemma/run.cc @@ -99,7 +99,7 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app, HWY_ASSERT(image.ReadPPM(args.image_file.path)); image.Resize(); RuntimeConfig runtime_config = { - .verbosity = app.verbosity, .gen = &gen, .use_spinning = app.spin}; + .gen = &gen, .verbosity = app.verbosity, .use_spinning = app.spin}; double image_tokens_start = hwy::platform::Now(); model.GenerateImageTokens(runtime_config, image, image_tokens); if (app.verbosity >= 1) { @@ -172,8 +172,8 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app, } TimingInfo timing_info = {.verbosity = app.verbosity}; - RuntimeConfig runtime_config = {.verbosity = app.verbosity, - .gen = &gen, + RuntimeConfig runtime_config = {.gen = &gen, + .verbosity = app.verbosity, .stream_token = stream_token, .accept_token = accept_token, .use_spinning = app.spin}; diff --git a/paligemma/paligemma_test.cc b/paligemma/paligemma_test.cc index b820eec..64c0ee8 100644 --- a/paligemma/paligemma_test.cc +++ b/paligemma/paligemma_test.cc @@ -56,7 +56,7 @@ void PaliGemmaTest::InitVit(const std::string& path) { HWY_ASSERT(model.Info().training == ModelTraining::PALIGEMMA); HWY_ASSERT(image.ReadPPM(path)); image.Resize(); - RuntimeConfig runtime_config = {.verbosity = 0, .gen = &s_env->MutableGen()}; + RuntimeConfig runtime_config = {.gen = &s_env->MutableGen(), .verbosity = 0}; model.GenerateImageTokens(runtime_config, image, image_tokens_); } @@ -64,8 +64,8 @@ std::string PaliGemmaTest::GemmaReply(const std::string& prompt_text) const{ Gemma& model = *(s_env->GetModel()); s_env->MutableGen().seed(0x12345678); RuntimeConfig runtime_config = {.max_generated_tokens = 512, - .verbosity = 0, - .gen = &s_env->MutableGen()}; + .gen = &s_env->MutableGen(), + .verbosity = 0}; runtime_config.image_tokens = &image_tokens_; size_t abs_pos = 0; std::string mutable_prompt = prompt_text; diff --git a/util/app.h b/util/app.h index ebc16b9..5128a38 100644 --- a/util/app.h +++ b/util/app.h @@ -220,6 +220,7 @@ struct InferenceArgs : public ArgsBase { size_t decode_qbatch_size; float temperature; + size_t top_k; bool deterministic; bool multiturn; Path image_file; @@ -244,6 +245,8 @@ struct InferenceArgs : public ArgsBase { "Decode: max queries per batch."); visitor(temperature, "temperature", 1.0f, "Temperature for top-K", 2); + visitor(top_k, "top_k", size_t{1}, "Number of top-K tokens to sample from", + 2); visitor(deterministic, "deterministic", false, "Make top-k sampling deterministic", 2); visitor(multiturn, "multiturn", false, @@ -259,6 +262,7 @@ struct InferenceArgs : public ArgsBase { runtime_config.prefill_tbatch_size = prefill_tbatch_size; runtime_config.decode_qbatch_size = decode_qbatch_size; runtime_config.temperature = temperature; + runtime_config.top_k = top_k; } }; From 5674c33dc51dd264929b3c689a48f1a5960ef0f1 Mon Sep 17 00:00:00 2001 From: Paul Chang Date: Wed, 13 Nov 2024 10:18:11 -0800 Subject: [PATCH 5/9] Replace CLIF SbsWriter with pybind-based gcpp extension Maintains compatibility with previous version. PiperOrigin-RevId: 696181603 --- MODULE.bazel | 1 + compression/python/BUILD.bazel | 11 +++--- compression/python/compression.clif | 14 -------- compression/python/compression_clif_aux.cc | 2 +- compression/python/compression_clif_aux.h | 2 +- compression/python/compression_extension.cc | 38 +++++++++++++++++++++ 6 files changed, 47 insertions(+), 21 deletions(-) delete mode 100644 compression/python/compression.clif create mode 100644 compression/python/compression_extension.cc diff --git a/MODULE.bazel b/MODULE.bazel index 58faa0d..ee63456 100644 --- a/MODULE.bazel +++ b/MODULE.bazel @@ -9,6 +9,7 @@ bazel_dep(name = "googletest", version = "1.15.2") bazel_dep(name = "highway", version = "1.1.0") bazel_dep(name = "nlohmann_json", version = "3.11.3") bazel_dep(name = "platforms", version = "0.0.10") +bazel_dep(name = "pybind11_bazel", version = "2.12.0") bazel_dep(name = "rules_cc", version = "0.0.9") bazel_dep(name = "rules_license", version = "0.0.7") bazel_dep(name = "google_benchmark", version = "1.8.5") diff --git a/compression/python/BUILD.bazel b/compression/python/BUILD.bazel index 89e2222..6b451bd 100644 --- a/compression/python/BUILD.bazel +++ b/compression/python/BUILD.bazel @@ -1,5 +1,5 @@ -load("//devtools/clif/python:clif_build_rule.bzl", "py_clif_cc") # [internal] load strict.bzl +load("@pybind11_bazel//:build_defs.bzl", "pybind_extension") package( default_applicable_licenses = [ @@ -12,8 +12,9 @@ cc_library( name = "compression_clif_aux", srcs = ["compression_clif_aux.cc"], hdrs = ["compression_clif_aux.h"], + visibility = ["//visibility:private"], deps = [ - "//third_party/absl/types:span", + "@abseil-cpp//absl/types:span", "//compression:compress", "//compression:io", "@highway//:hwy", @@ -21,12 +22,12 @@ cc_library( ], ) -py_clif_cc( +pybind_extension( name = "compression", - srcs = ["compression.clif"], + srcs = ["compression_extension.cc"], deps = [ ":compression_clif_aux", - "//third_party/absl/python/numpy:span_clif_lib", + "@abseil-cpp//absl/types:span", ], ) diff --git a/compression/python/compression.clif b/compression/python/compression.clif deleted file mode 100644 index 69dfc9b..0000000 --- a/compression/python/compression.clif +++ /dev/null @@ -1,14 +0,0 @@ -from "third_party/absl/python/numpy/span.h" import * -from "third_party/gemma_cpp/compression/python/compression_clif_aux.h": - namespace `gcpp`: - class SbsWriter: - # NOTE: Individual compression backends may impose constraints on the - # array length, such as a minimum of (say) 32 elements. - def `Insert` as insert(self, name: str, weights: NumpyArray) - def `InsertNUQ` as insert_nuq(self, name: str, weights: NumpyArray) - def `InsertBfloat16` as insert_bf16(self, name: str, weights: NumpyArray) - def `InsertFloat` as insert_float(self, name: str, weights: NumpyArray) - - def `AddScales` as add_scales(self, scales: list) - - def `Write` as write(self, path: str) diff --git a/compression/python/compression_clif_aux.cc b/compression/python/compression_clif_aux.cc index a9d3894..ba91781 100644 --- a/compression/python/compression_clif_aux.cc +++ b/compression/python/compression_clif_aux.cc @@ -20,7 +20,7 @@ #ifndef GEMMA_ONCE #define GEMMA_ONCE -#include "third_party/absl/types/span.h" +#include "absl/types/span.h" #include "compression/io.h" #include "hwy/base.h" #include "hwy/contrib/thread_pool/thread_pool.h" diff --git a/compression/python/compression_clif_aux.h b/compression/python/compression_clif_aux.h index 8dc7a9d..fd4efc8 100644 --- a/compression/python/compression_clif_aux.h +++ b/compression/python/compression_clif_aux.h @@ -5,7 +5,7 @@ #include #include -#include "third_party/absl/types/span.h" +#include "absl/types/span.h" namespace gcpp { diff --git a/compression/python/compression_extension.cc b/compression/python/compression_extension.cc new file mode 100644 index 0000000..c2916a8 --- /dev/null +++ b/compression/python/compression_extension.cc @@ -0,0 +1,38 @@ +#include + +#include +#include +#include + +#include "absl/types/span.h" +#include "compression/python/compression_clif_aux.h" +#include "pybind11/numpy.h" +#include "pybind11/pybind11.h" +#include "pybind11/stl.h" + +using gcpp::SbsWriter; + +namespace py = pybind11; + +namespace { +template +void wrap_span(SbsWriter& writer, std::string name, py::array_t data) { + if (data.ndim() != 1 || data.strides(0) != sizeof(float)) { + throw std::domain_error("Input array must be 1D and densely packed."); + } + std::invoke(Func, writer, name, absl::MakeSpan(data.data(0), data.size())); +} +} // namespace + +PYBIND11_MODULE(compression, m) { + py::class_(m, "SbsWriter") + .def(py::init<>()) + // NOTE: Individual compression backends may impose constraints on the + // array length, such as a minimum of (say) 32 elements. + .def("insert", wrap_span<&SbsWriter::Insert>) + .def("insert_nuq", wrap_span<&SbsWriter::InsertNUQ>) + .def("insert_bf16", wrap_span<&SbsWriter::InsertBfloat16>) + .def("insert_float", wrap_span<&SbsWriter::InsertFloat>) + .def("add_scales", &SbsWriter::AddScales) + .def("write", &SbsWriter::Write); +} From 96513a8dc308dbd81f565d3f98d66742b3178cba Mon Sep 17 00:00:00 2001 From: Ray Smith Date: Thu, 14 Nov 2024 03:26:00 -0800 Subject: [PATCH 6/9] Added a blob_compare tool that compares two sbs files that may have the blobs in a different order PiperOrigin-RevId: 696458888 --- compression/BUILD.bazel | 11 +++++++ compression/blob_compare.cc | 64 +++++++++++++++++++++++++++++++++++++ 2 files changed, 75 insertions(+) create mode 100644 compression/blob_compare.cc diff --git a/compression/BUILD.bazel b/compression/BUILD.bazel index f16b4d7..585a74b 100644 --- a/compression/BUILD.bazel +++ b/compression/BUILD.bazel @@ -259,3 +259,14 @@ cc_binary( "@highway//:thread_pool", ], ) + +cc_binary( + name = "blob_compare", + srcs = ["blob_compare.cc"], + deps = [ + ":blob_store", + ":io", + "@highway//:hwy", + "@highway//:hwy_test_util", + ], +) diff --git a/compression/blob_compare.cc b/compression/blob_compare.cc new file mode 100644 index 0000000..c6f0a00 --- /dev/null +++ b/compression/blob_compare.cc @@ -0,0 +1,64 @@ +#include +#include +#include +#include + +#include "compression/blob_store.h" +#include "compression/io.h" +#include "hwy/aligned_allocator.h" +#include "hwy/base.h" + +namespace gcpp { + +// Compares two sbs files, ignoring the order of the blobs. +// Gives up on the first mismatch. +void CompareBlobs(const char* path1, const char* path2) { + BlobReader reader1; + HWY_ASSERT(reader1.Open(Path(path1)) == 0); + BlobReader reader2; + HWY_ASSERT(reader2.Open(Path(path2)) == 0); + hwy::Span keys1 = reader1.Keys(); + size_t total_matches = 0; + size_t total_fails = 0; + for (size_t i = 0; i < keys1.size(); ++i) { + fprintf(stderr, "key %s, blob1 size=%zu, blob2 size=%zu\n", + StringFromKey(keys1[i]).c_str(), reader1.BlobSize(keys1[i]), + reader2.BlobSize(keys1[i])); + std::vector data1(reader1.BlobSize(keys1[i])); + HWY_ASSERT(reader1.ReadOne(keys1[i], data1.data(), data1.size()) == 0); + HWY_ASSERT(reader2.BlobSize(keys1[i]) == data1.size()); + std::vector data2(reader2.BlobSize(keys1[i])); + HWY_ASSERT(reader2.ReadOne(keys1[i], data2.data(), data2.size()) == 0); + size_t fails = 0; + for (size_t j = 0; j < data1.size(); ++j) { + if (data1[j] != data2[j]) { + if (fails == 0) { + fprintf(stderr, "key %s Mismatch at %zu\n", + StringFromKey(keys1[i]).c_str(), j); + } + ++fails; + } + } + if (fails > 0) { + fprintf(stderr, "key %s has %.2f%% Mismatch!\n", + StringFromKey(keys1[i]).c_str(), 100.0 * fails / data1.size()); + ++total_fails; + } else { + fprintf(stderr, "key %s Matched!\n", StringFromKey(keys1[i]).c_str()); + ++total_matches; + } + } + fprintf(stderr, "Total matches=%zu, mismatches=%zu\n", total_matches, + total_fails); +} + +} // namespace gcpp + +int main(int argc, char** argv) { + if (argc != 3) { + fprintf(stderr, "Usage: %s \n", argv[0]); + return 1; + } + gcpp::CompareBlobs(argv[1], argv[2]); + return 0; +} From 36f02ef89256a49cae88cbe85eded5eee7e32caf Mon Sep 17 00:00:00 2001 From: Jan Wassenberg Date: Fri, 15 Nov 2024 02:21:55 -0800 Subject: [PATCH 7/9] Internal change. PiperOrigin-RevId: 696815335 --- examples/hello_world/run.cc | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/examples/hello_world/run.cc b/examples/hello_world/run.cc index 7e9e561..70c3654 100644 --- a/examples/hello_world/run.cc +++ b/examples/hello_world/run.cc @@ -33,6 +33,10 @@ #include "hwy/contrib/thread_pool/thread_pool.h" int main(int argc, char** argv) { + { + // Placeholder for internal init, do not modify. + } + gcpp::LoaderArgs loader(argc, argv); gcpp::InferenceArgs inference(argc, argv); gcpp::AppArgs app(argc, argv); From 7d685a267f29efb2bc8504da33708e90f385a7ff Mon Sep 17 00:00:00 2001 From: Ray Smith Date: Mon, 18 Nov 2024 04:03:10 -0800 Subject: [PATCH 8/9] Added pybind for configs. Added ability to test configs for equality. PiperOrigin-RevId: 697572671 --- gemma/configs.cc | 117 +++++++++++++++++++++++++++++++++ gemma/configs.h | 36 +++++++++-- gemma/python/BUILD.bazel | 17 +++++ gemma/python/configs.cc | 136 +++++++++++++++++++++++++++++++++++++++ 4 files changed, 302 insertions(+), 4 deletions(-) create mode 100644 gemma/python/BUILD.bazel create mode 100644 gemma/python/configs.cc diff --git a/gemma/configs.cc b/gemma/configs.cc index 03fce99..7724c59 100644 --- a/gemma/configs.cc +++ b/gemma/configs.cc @@ -15,6 +15,8 @@ #include "gemma/configs.h" +#include + #include "hwy/base.h" namespace gcpp { @@ -181,6 +183,7 @@ static ModelConfig ConfigGriffin2B() { .conv1d_width = 4, .ff_biases = true, .softmax_attn_output_biases = true, + .optimized_gating = false, .type = LayerAttentionType::kGriffinRecurrentBlock, .activation = ActivationType::Gelu, .post_qk = PostQKType::HalfRope, @@ -204,6 +207,9 @@ static void AddVitConfig(ModelConfig& config) { config.vocab_size = 256000 + 1024 + 128; // = 257152 config.image_size = 224; config.patch_width = 14; + for (auto& layer_config : config.layer_configs) { + layer_config.optimized_gating = false; + } const size_t num_patches = config.image_size / config.patch_width; config.vit_seq_len = num_patches * num_patches; LayerConfig vit_layer_config = { @@ -260,4 +266,115 @@ ModelConfig ConfigFromModel(Model model) { } } +#define TEST_EQUAL(a, b) \ + if (a != b) { \ + if (debug) \ + std::cerr << #a << "=" << a << " != " << #b << "=" << b << "\n"; \ + result = false; \ + } + +#define RETURN_IF_NOT_EQUAL(a, b) \ + if (a != b) { \ + if (debug) \ + std::cerr << #a << "=" << a << " != " << #b << "=" << b << "\n"; \ + return false; \ + } + +#define WARN_IF_NOT_EQUAL(a, b) \ + if (a != b) { \ + std::cerr << #a << "=" << a << " != " << #b << "=" << b << "\n"; \ + } + +bool LayerConfig::TestEqual(const LayerConfig& other, bool partial, + bool debug) const { + bool result = true; + // Optimized gating may not be set correctly in the c++ configs. + if (debug) { + WARN_IF_NOT_EQUAL(optimized_gating, other.optimized_gating) + } + TEST_EQUAL(model_dim, other.model_dim); + TEST_EQUAL(griffin_dim, other.griffin_dim); + TEST_EQUAL(ff_hidden_dim, other.ff_hidden_dim); + TEST_EQUAL(heads, other.heads); + TEST_EQUAL(kv_heads, other.kv_heads); + TEST_EQUAL(qkv_dim, other.qkv_dim); + TEST_EQUAL(conv1d_width, other.conv1d_width); + if (!partial) { + TEST_EQUAL(ff_biases, other.ff_biases); + TEST_EQUAL(softmax_attn_output_biases, other.softmax_attn_output_biases); + } + TEST_EQUAL(static_cast(post_norm), static_cast(other.post_norm)); + TEST_EQUAL(static_cast(type), static_cast(other.type)); + TEST_EQUAL(static_cast(activation), static_cast(other.activation)); + TEST_EQUAL(static_cast(post_qk), static_cast(other.post_qk)); + return result; +} + +bool ModelConfig::TestEqual(const ModelConfig& other, bool partial, + bool debug) const { + bool result = true; + // We don't care about model_name, model, training, or weight being different, + // but will output in debug mode if they are. + if (debug) { + WARN_IF_NOT_EQUAL(model_name, other.model_name); + WARN_IF_NOT_EQUAL(static_cast(model), static_cast(other.model)); + WARN_IF_NOT_EQUAL(static_cast(training), + static_cast(other.training)); + WARN_IF_NOT_EQUAL(static_cast(weight), static_cast(other.weight)); + } + TEST_EQUAL(model_dim, other.model_dim); + TEST_EQUAL(vit_model_dim, other.vit_model_dim); + TEST_EQUAL(vocab_size, other.vocab_size); + TEST_EQUAL(seq_len, other.seq_len); + TEST_EQUAL(vit_seq_len, other.vit_seq_len); + if (!partial) { + TEST_EQUAL(num_tensor_scales, other.num_tensor_scales); + TEST_EQUAL(num_vit_scales, other.num_vit_scales); + } + TEST_EQUAL(att_cap, other.att_cap); + TEST_EQUAL(final_cap, other.final_cap); + TEST_EQUAL(absolute_pe, other.absolute_pe); + TEST_EQUAL(use_local_attention, other.use_local_attention); + TEST_EQUAL(static_cast(query_scale), + static_cast(other.query_scale)); + RETURN_IF_NOT_EQUAL(layer_configs.size(), other.layer_configs.size()); + for (size_t i = 0; i < layer_configs.size(); ++i) { + result &= + layer_configs[i].TestEqual(other.layer_configs[i], partial, debug); + } + RETURN_IF_NOT_EQUAL(attention_window_sizes.size(), + other.attention_window_sizes.size()); + for (size_t i = 0; i < attention_window_sizes.size(); ++i) { + TEST_EQUAL(attention_window_sizes[i], other.attention_window_sizes[i]); + } + RETURN_IF_NOT_EQUAL(vit_layer_configs.size(), other.vit_layer_configs.size()); + for (size_t i = 0; i < vit_layer_configs.size(); ++i) { + result &= vit_layer_configs[i].TestEqual(other.vit_layer_configs[i], + partial, debug); + } + if (!partial) { + if (scale_names != other.scale_names) { + result = false; + if (debug) { + std::cerr << "scale_names mismatch\n"; + } + } + } + TEST_EQUAL(norm_num_groups, other.norm_num_groups); + TEST_EQUAL(model_family_version, other.model_family_version); + TEST_EQUAL(patch_width, other.patch_width); + TEST_EQUAL(image_size, other.image_size); + return result; +} + +Model ModelFromConfig(const ModelConfig& config) { + for (Model model : kAllModels) { + ModelConfig model_config = ConfigFromModel(model); + if (config.TestEqual(model_config, /*partial=*/true, /*debug=*/false)) { + return model; + } + } + return Model::UNKNOWN; +} + } // namespace gcpp diff --git a/gemma/configs.h b/gemma/configs.h index e709df7..6bbbc45 100644 --- a/gemma/configs.h +++ b/gemma/configs.h @@ -116,7 +116,20 @@ enum class Model { PALIGEMMA_224, }; +// Allows the Model enum to be iterated over. +static constexpr Model kAllModels[] = { + Model::GEMMA_2B, Model::GEMMA_7B, Model::GEMMA2_9B, Model::GEMMA2_27B, + Model::GRIFFIN_2B, Model::GEMMA_TINY, Model::GEMMA2_2B, + Model::PALIGEMMA_224, +}; + struct LayerConfig { + // Returns true if *this and other are equal. + // If partial is true, then we don't check for items that are only set after + // the tensors are loaded from the checkpoint. + // If debug is true, then we output the mismatched fields to stderr. + bool TestEqual(const LayerConfig& other, bool partial, bool debug) const; + size_t CacheLayerSize() const { return kv_heads * qkv_dim * 2; } // Multi-Head Attention? @@ -132,9 +145,10 @@ struct LayerConfig { size_t heads = 0; size_t kv_heads = 0; size_t qkv_dim = 0; - size_t conv1d_width = 0; + size_t conv1d_width = 0; // griffin only bool ff_biases = false; bool softmax_attn_output_biases = false; + bool optimized_gating = true; PostNormType post_norm = PostNormType::None; LayerAttentionType type = LayerAttentionType::kGemma; ActivationType activation = ActivationType::Gelu; @@ -142,6 +156,16 @@ struct LayerConfig { }; struct ModelConfig { + // Returns true if *this and other are equal. + // If partial is true, then we don't check for items that are only set after + // the tensors are loaded from the checkpoint. + // If debug is true, then we output the mismatched fields to stderr. + bool TestEqual(const ModelConfig& other, bool partial, bool debug) const; + + void AddLayerConfig(const LayerConfig& layer_config) { + layer_configs.push_back(layer_config); + } + size_t CachePosSize() const { size_t num_layers = layer_configs.size(); return num_layers * layer_configs[0].CacheLayerSize(); @@ -171,6 +195,7 @@ struct ModelConfig { Model model; ModelTraining training; Type weight; + size_t num_layers = 0; size_t model_dim = 0; size_t vit_model_dim = 0; size_t vocab_size = 0; @@ -181,7 +206,7 @@ struct ModelConfig { float att_cap = 0.0f; float final_cap = 0.0f; bool absolute_pe = false; - bool use_local_attention = false; + bool use_local_attention = false; // griffin only QueryScaleType query_scale = QueryScaleType::SqrtKeySize; std::vector layer_configs; std::vector attention_window_sizes; @@ -190,13 +215,16 @@ struct ModelConfig { int norm_num_groups = 1; int model_family_version = 1; // Dimensions related to image processing. - int patch_width = 14; - int image_size = 224; + size_t patch_width = 14; + size_t image_size = 224; }; // Returns the config for the given model. ModelConfig ConfigFromModel(Model model); +// Returns the model for the given config, if it matches any standard model. +Model ModelFromConfig(const ModelConfig& config); + // Returns the sub-config for the ViT model of the PaliGemma model. ModelConfig VitConfig(const ModelConfig& config); diff --git a/gemma/python/BUILD.bazel b/gemma/python/BUILD.bazel new file mode 100644 index 0000000..d6b09b9 --- /dev/null +++ b/gemma/python/BUILD.bazel @@ -0,0 +1,17 @@ +load("@pybind11_bazel//:build_defs.bzl", "pybind_extension") + +package( + default_applicable_licenses = [ + "//:license", # Placeholder comment, do not modify + ], + default_visibility = ["//visibility:public"], +) + +pybind_extension( + name = "configs", + srcs = ["configs.cc"], + deps = [ + "//:common", + "//compression:sfp", + ], +) diff --git a/gemma/python/configs.cc b/gemma/python/configs.cc new file mode 100644 index 0000000..aff93cc --- /dev/null +++ b/gemma/python/configs.cc @@ -0,0 +1,136 @@ +#include "gemma/configs.h" + +#include +#include + +#include "compression/shared.h" +#include "pybind11/cast.h" + +using gcpp::ActivationType; +using gcpp::LayerAttentionType; +using gcpp::LayerConfig; +using gcpp::Model; +using gcpp::ModelConfig; +using gcpp::ModelTraining; +using gcpp::PostNormType; +using gcpp::PostQKType; +using gcpp::QueryScaleType; +using gcpp::ResidualType; +using gcpp::Type; + +namespace pybind11 { + +PYBIND11_MODULE(configs, py_module) { + enum_(py_module, "ModelTraining") + .value("GEMMA_IT", ModelTraining::GEMMA_IT) + .value("GEMMA_PT", ModelTraining::GEMMA_PT) + .value("PALIGEMMA", ModelTraining::PALIGEMMA); + + enum_(py_module, "Type") + .value("kUnknown", Type::kUnknown) + .value("kF32", Type::kF32) + .value("kBF16", Type::kBF16) + .value("kSFP", Type::kSFP) + .value("kNUQ", Type::kNUQ) + .value("kF64", Type::kF64) + .value("kC64", Type::kC64) + .value("kU128", Type::kU128); + + enum_(py_module, "LayerAttentionType") + .value("kGemma", LayerAttentionType::kGemma) + .value("kGriffinRecurrentBlock", + LayerAttentionType::kGriffinRecurrentBlock) + .value("kVit", LayerAttentionType::kVit); + + enum_(py_module, "PostNormType") + .value("NoPostNorm", PostNormType::None) + .value("Scale", PostNormType::Scale); + + enum_(py_module, "PostQKType") + .value("Rope", PostQKType::Rope) + .value("HalfRope", PostQKType::HalfRope); + + enum_(py_module, "ActivationType") + .value("Gelu", ActivationType::Gelu); + + enum_(py_module, "QueryScaleType") + .value("SqrtKeySize", QueryScaleType::SqrtKeySize) + .value("SqrtModelDimDivNumHeads", + QueryScaleType::SqrtModelDimDivNumHeads); + + enum_(py_module, "ResidualType") + .value("Add", ResidualType::Add); + + enum_(py_module, "Model") + .value("UNKNOWN", Model::UNKNOWN) + .value("GEMMA_2B", Model::GEMMA_2B) + .value("GEMMA_7B", Model::GEMMA_7B) + .value("GEMMA2_9B", Model::GEMMA2_9B) + .value("GEMMA2_27B", Model::GEMMA2_27B) + .value("GRIFFIN_2B", Model::GRIFFIN_2B) + .value("GEMMA_TINY", Model::GEMMA_TINY) + .value("GEMMA2_2B", Model::GEMMA2_2B) + .value("PALIGEMMA_224", Model::PALIGEMMA_224); + + class_(py_module, "LayerConfig") + .def(init()) + .def_readwrite("model_dim", &LayerConfig::model_dim) + .def_readwrite("griffin_dim", &LayerConfig::griffin_dim) + .def_readwrite("ff_hidden_dim", &LayerConfig::ff_hidden_dim) + .def_readwrite("heads", &LayerConfig::heads) + .def_readwrite("kv_heads", &LayerConfig::kv_heads) + .def_readwrite("qkv_dim", &LayerConfig::qkv_dim) + .def_readwrite("conv1d_width", &LayerConfig::conv1d_width) + .def_readwrite("ff_biases", &LayerConfig::ff_biases) + .def_readwrite("softmax_attn_output_biases", + &LayerConfig::softmax_attn_output_biases) + .def_readwrite("optimized_gating", &LayerConfig::optimized_gating) + .def_readwrite("post_norm", &LayerConfig::post_norm) + .def_readwrite("type", &LayerConfig::type) + .def_readwrite("activation", &LayerConfig::activation) + .def_readwrite("post_qk", &LayerConfig::post_qk); + + class_(py_module, "ModelConfig") + .def(init()) + .def_readwrite("model_name", &ModelConfig::model_name) + .def_readwrite("model", &ModelConfig::model) + .def_readwrite("training", &ModelConfig::training) + .def_readwrite("weight", &ModelConfig::weight) + .def_readwrite("num_layers", &ModelConfig::num_layers) + .def_readwrite("model_dim", &ModelConfig::model_dim) + .def_readwrite("vit_model_dim", &ModelConfig::vit_model_dim) + .def_readwrite("vocab_size", &ModelConfig::vocab_size) + .def_readwrite("seq_len", &ModelConfig::seq_len) + .def_readwrite("vit_seq_len", &ModelConfig::vit_seq_len) + .def_readwrite("num_tensor_scales", &ModelConfig::num_tensor_scales) + .def_readwrite("num_vit_scales", &ModelConfig::num_vit_scales) + .def_readwrite("att_cap", &ModelConfig::att_cap) + .def_readwrite("final_cap", &ModelConfig::final_cap) + .def_readwrite("absolute_pe", &ModelConfig::absolute_pe) + .def_readwrite("use_local_attention", &ModelConfig::use_local_attention) + .def_readwrite("query_scale", &ModelConfig::query_scale) + .def_readwrite("layer_configs", &ModelConfig::layer_configs) + .def_readwrite("attention_window_sizes", + &ModelConfig::attention_window_sizes) + .def_readwrite("vit_layer_configs", &ModelConfig::vit_layer_configs) + .def_readwrite("scale_names", &ModelConfig::scale_names) + .def_readwrite("norm_num_groups", &ModelConfig::norm_num_groups) + .def_readwrite("model_family_version", &ModelConfig::model_family_version) + .def_readwrite("patch_width", &ModelConfig::patch_width) + .def_readwrite("image_size", &ModelConfig::image_size) + .def("add_layer_config", &ModelConfig::AddLayerConfig, + arg("layer_config")) + .def("test_equal", &ModelConfig::TestEqual, arg("other"), arg("partial"), + arg("debug")); + + // Returns the config for the given model. + py_module.def("config_from_model", &gcpp::ConfigFromModel, arg("model")); + + // Returns the model for the given config, if it matches any standard model. + py_module.def("model_from_config", &gcpp::ModelFromConfig, arg("config")); + + // Returns the sub-config for the ViT model of the PaliGemma model. + py_module.def("vit_config", &gcpp::VitConfig, arg("config")); +} + +} // namespace pybind11 From 73640d25212b0cb6d24f62741f9ef70d33cb4c1d Mon Sep 17 00:00:00 2001 From: Ray Smith Date: Tue, 19 Nov 2024 00:25:01 -0800 Subject: [PATCH 9/9] Added tensor_index as a single source of truth on tensor shapes/sources and transformations PiperOrigin-RevId: 697903886 --- BUILD.bazel | 18 +- CMakeLists.txt | 3 + compression/BUILD.bazel | 1 + compression/compress.h | 22 ++ gemma/python/configs.cc | 20 ++ gemma/tensor_index.cc | 565 +++++++++++++++++++++++++++++++++++++ gemma/tensor_index.h | 91 ++++++ gemma/tensor_index_test.cc | 73 +++++ gemma/weights.h | 13 +- 9 files changed, 798 insertions(+), 8 deletions(-) create mode 100644 gemma/tensor_index.cc create mode 100644 gemma/tensor_index.h create mode 100644 gemma/tensor_index_test.cc diff --git a/BUILD.bazel b/BUILD.bazel index e5a7939..669caf8 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -194,13 +194,15 @@ cc_library( srcs = [ "gemma/common.cc", "gemma/configs.cc", + "gemma/tensor_index.cc", ], hdrs = [ "gemma/common.h", "gemma/configs.h", + "gemma/tensor_index.h", ], deps = [ - "//compression:compress", + "//compression:sfp", "@highway//:hwy", # base.h "@highway//:thread_pool", ], @@ -215,6 +217,20 @@ cc_test( ], ) +cc_test( + name = "tensor_index_test", + srcs = ["gemma/tensor_index_test.cc"], + deps = [ + ":basics", + ":common", + ":weights", + "@googletest//:gtest_main", + "//compression:compress", + "@highway//:hwy", + "@highway//:thread_pool", + ], +) + cc_library( name = "weights", srcs = ["gemma/weights.cc"], diff --git a/CMakeLists.txt b/CMakeLists.txt index 876da6c..4d03da0 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -84,6 +84,8 @@ set(SOURCES gemma/instantiations/sfp.cc gemma/kv_cache.cc gemma/kv_cache.h + gemma/tensor_index.cc + gemma/tensor_index.h gemma/tokenizer.cc gemma/tokenizer.h gemma/weights.cc @@ -157,6 +159,7 @@ set(GEMMA_TEST_FILES compression/nuq_test.cc compression/sfp_test.cc evals/gemma_test.cc + gemma/tensor_index_test.cc ops/dot_test.cc ops/gemma_matvec_test.cc ops/matmul_test.cc diff --git a/compression/BUILD.bazel b/compression/BUILD.bazel index 585a74b..c7cec0a 100644 --- a/compression/BUILD.bazel +++ b/compression/BUILD.bazel @@ -201,6 +201,7 @@ cc_library( ":sfp", "//:allocator", "//:basics", + "//:common", "@highway//:hwy", "@highway//:nanobenchmark", "@highway//:profiler", diff --git a/compression/compress.h b/compression/compress.h index 9050d53..ff64b49 100644 --- a/compression/compress.h +++ b/compression/compress.h @@ -17,6 +17,7 @@ #ifndef THIRD_PARTY_GEMMA_CPP_COMPRESSION_COMPRESS_H_ #define THIRD_PARTY_GEMMA_CPP_COMPRESSION_COMPRESS_H_ +#include "hwy/base.h" #define COMPRESS_STATS 0 #include @@ -33,6 +34,7 @@ #include "compression/blob_store.h" #include "compression/io.h" #include "compression/shared.h" +#include "gemma/tensor_index.h" #include "util/basics.h" // IWYU pragma: end_exports #include "util/allocator.h" @@ -211,6 +213,26 @@ class MatPtrT : public MatPtr { // Full constructor for dynamic sizing. MatPtrT(const std::string& name, size_t rows, size_t cols) : MatPtr(name, TypeEnum(), sizeof(MatT), rows, cols) {} + // Construction from TensorIndex entry to remove duplication of sizes. + MatPtrT(const std::string& name, const TensorIndex& tensor_index) + : MatPtr(name, TypeEnum(), sizeof(MatT), 0, 0) { + const TensorInfo* tensor = tensor_index.FindName(name); + HWY_ASSERT(tensor != nullptr); + cols_ = tensor->shape.back(); + rows_ = 1; + if (tensor->cols_take_extra_dims) { + // The columns eat the extra dimensions. + rows_ = tensor->shape[0]; + for (size_t i = 1; i < tensor->shape.size() - 1; ++i) { + cols_ *= tensor->shape[i]; + } + } else { + // The rows eat the extra dimensions. + for (size_t i = 0; i < tensor->shape.size() - 1; ++i) { + rows_ *= tensor->shape[i]; + } + } + } // Copying allowed as the metadata is small. MatPtrT(const MatPtr& other) : MatPtr(other) {} diff --git a/gemma/python/configs.cc b/gemma/python/configs.cc index aff93cc..8c37840 100644 --- a/gemma/python/configs.cc +++ b/gemma/python/configs.cc @@ -4,6 +4,7 @@ #include #include "compression/shared.h" +#include "gemma/tensor_index.h" #include "pybind11/cast.h" using gcpp::ActivationType; @@ -16,6 +17,8 @@ using gcpp::PostNormType; using gcpp::PostQKType; using gcpp::QueryScaleType; using gcpp::ResidualType; +using gcpp::TensorIndex; +using gcpp::TensorInfo; using gcpp::Type; namespace pybind11 { @@ -72,6 +75,23 @@ PYBIND11_MODULE(configs, py_module) { .value("GEMMA2_2B", Model::GEMMA2_2B) .value("PALIGEMMA_224", Model::PALIGEMMA_224); + class_(py_module, "TensorInfo") + .def(init()) + .def_readwrite("name", &TensorInfo::name) + .def_readwrite("source_names", &TensorInfo::source_names) + .def_readwrite("preshape", &TensorInfo::preshape) + .def_readwrite("axes", &TensorInfo::axes) + .def_readwrite("shape", &TensorInfo::shape) + .def_readwrite("concat_names", &TensorInfo::concat_names) + .def_readwrite("concat_axis", &TensorInfo::concat_axis) + .def_readwrite("min_size", &TensorInfo::min_size) + .def_readwrite("scaled_softplus", &TensorInfo::scaled_softplus) + .def_readwrite("cols_take_extra_dims", &TensorInfo::cols_take_extra_dims); + + class_(py_module, "TensorIndex") + .def(init()) + .def("get_tensor_info", &TensorIndex::GetTensorInfo, arg("path")); + class_(py_module, "LayerConfig") .def(init()) .def_readwrite("model_dim", &LayerConfig::model_dim) diff --git a/gemma/tensor_index.cc b/gemma/tensor_index.cc new file mode 100644 index 0000000..9f44a9c --- /dev/null +++ b/gemma/tensor_index.cc @@ -0,0 +1,565 @@ +#include "gemma/tensor_index.h" + +#include + +#include +#include +#include +#include +#include +#include + +#include "compression/shared.h" +#include "gemma/configs.h" + +namespace gcpp { +namespace { + +// Returns the non-layer tensors for the model. +std::vector ModelTensors(const ModelConfig& config) { + return { + TensorInfo{ + .name = "c_embedding", + .source_names = {"embedder/input_embedding"}, + .axes = {0, 1}, + .shape = {config.vocab_size, config.model_dim}, + .min_size = Type::kBF16, + }, + TensorInfo{ + .name = "c_final_norm", + .source_names = {"final_norm/scale"}, + .axes = {0}, + .shape = {config.model_dim}, + .min_size = Type::kBF16, + }, + TensorInfo{ + .name = "enc_norm_bias", + .source_names = {"img/Transformer/encoder_norm/bias"}, + .axes = {0}, + .shape = {config.vit_model_dim}, + .min_size = Type::kBF16, + }, + TensorInfo{ + .name = "enc_norm_scale", + .source_names = {"img/Transformer/encoder_norm/scale"}, + .axes = {0}, + .shape = {config.vit_model_dim}, + .min_size = Type::kBF16, + }, + TensorInfo{ + .name = "img_emb_bias", + .source_names = {"img/embedding/bias"}, + .axes = {0}, + .shape = {config.vit_model_dim}, + .min_size = Type::kF32, + }, + TensorInfo{ + .name = "img_emb_kernel", + .source_names = {"img/embedding/kernel"}, + .axes = {3, 0, 1, 2}, + .shape = {config.vit_model_dim, config.patch_width, + config.patch_width, 3}, + .min_size = Type::kBF16, + .cols_take_extra_dims = true, + }, + TensorInfo{ + .name = "img_head_bias", + .source_names = {"img/head/bias"}, + .axes = {0}, + .shape = {config.model_dim}, + .min_size = Type::kF32, + }, + TensorInfo{ + .name = "img_head_kernel", + .source_names = {"img/head/kernel"}, + .axes = {1, 0}, + .shape = {config.model_dim, config.vit_model_dim}, + .min_size = Type::kBF16, + }, + TensorInfo{ + .name = "img_pos_emb", + .source_names = {"img/pos_embedding"}, + .axes = {0, 1}, + .shape = {/*1,*/ 256, config.vit_model_dim}, + .min_size = Type::kF32, + }, + }; +} + +// Returns the tensors for the given image layer config. +std::vector ImageLayerTensors(const ModelConfig& config, + const LayerConfig& layer_config) { + return { + // Vit layers. + TensorInfo{ + .name = "attn_out_w", + .source_names = {"MultiHeadDotProductAttention_0/out/kernel"}, + .axes = {2, 0, 1}, + .shape = {config.vit_model_dim, layer_config.heads, + layer_config.qkv_dim}, + .min_size = Type::kBF16, + .cols_take_extra_dims = true, + }, + TensorInfo{ + .name = "attn_out_b", + .source_names = {"MultiHeadDotProductAttention_0/out/bias"}, + .axes = {0}, + .shape = {config.vit_model_dim}, + .min_size = Type::kF32, + }, + TensorInfo{ + .name = "q_ein_w", + .source_names = {"MultiHeadDotProductAttention_0/query/kernel"}, + .axes = {1, 2, 0}, + .shape = {layer_config.heads, layer_config.qkv_dim, + config.vit_model_dim}, + .concat_names = {"qkv_ein_w", "k_ein_w", "v_ein_w"}, + .concat_axis = 1, + .min_size = Type::kBF16, + }, + TensorInfo{ + .name = "k_ein_w", + .source_names = {"MultiHeadDotProductAttention_0/key/kernel"}, + .axes = {1, 2, 0}, + .shape = {layer_config.heads, layer_config.qkv_dim, + config.vit_model_dim}, + .concat_names = {""}, + .min_size = Type::kBF16, + }, + TensorInfo{ + .name = "v_ein_w", + .source_names = {"MultiHeadDotProductAttention_0/value/kernel"}, + .axes = {1, 2, 0}, + .shape = {layer_config.heads, layer_config.qkv_dim, + config.vit_model_dim}, + .concat_names = {""}, + .min_size = Type::kBF16, + }, + TensorInfo{ + .name = "qkv_ein_w", + .source_names = {"MultiHeadDotProductAttention_0/qkv/kernel"}, + .axes = {2, 0, 3, 1}, + .shape = {layer_config.heads, 3, layer_config.qkv_dim, + config.vit_model_dim}, + .min_size = Type::kBF16, + }, + TensorInfo{ + .name = "q_ein_b", + .source_names = {"MultiHeadDotProductAttention_0/query/bias"}, + .axes = {0, 1}, + .shape = {layer_config.heads, layer_config.qkv_dim}, + .concat_names = {"qkv_ein_b", "k_ein_b", "v_ein_b"}, + .concat_axis = 1, + .min_size = Type::kF32, + }, + TensorInfo{ + .name = "k_ein_b", + .source_names = {"MultiHeadDotProductAttention_0/key/bias"}, + .axes = {0, 1}, + .shape = {layer_config.heads, layer_config.qkv_dim}, + .concat_names = {""}, + .min_size = Type::kF32, + }, + TensorInfo{ + .name = "v_ein_b", + .source_names = {"MultiHeadDotProductAttention_0/value/bias"}, + .axes = {0, 1}, + .shape = {layer_config.heads, layer_config.qkv_dim}, + .concat_names = {""}, + .min_size = Type::kF32, + }, + TensorInfo{ + .name = "qkv_ein_b", + .source_names = {"MultiHeadDotProductAttention_0/qkv/bias"}, + .axes = {1, 0, 2}, + .shape = {layer_config.heads * 3, layer_config.qkv_dim}, + .min_size = Type::kF32, + }, + TensorInfo{ + .name = "linear_0_w", + .source_names = {"MlpBlock_0/Dense_0/kernel"}, + .axes = {1, 0}, + .shape = {layer_config.ff_hidden_dim, config.vit_model_dim}, + .min_size = Type::kBF16, + }, + TensorInfo{ + .name = "linear_0_b", + .source_names = {"MlpBlock_0/Dense_0/bias"}, + .axes = {0}, + .shape = {layer_config.ff_hidden_dim}, + .min_size = Type::kF32, + }, + TensorInfo{ + .name = "linear_1_w", + .source_names = {"MlpBlock_0/Dense_1/kernel"}, + .axes = {1, 0}, + .shape = {config.vit_model_dim, layer_config.ff_hidden_dim}, + .min_size = Type::kBF16, + }, + TensorInfo{ + .name = "linear_1_b", + .source_names = {"MlpBlock_0/Dense_1/bias"}, + .axes = {0}, + .shape = {config.vit_model_dim}, + .min_size = Type::kF32, + }, + TensorInfo{ + .name = "ln_0_bias", + .source_names = {"img/Transformer/encoderblock/LayerNorm_0/bias"}, + .axes = {0}, + .shape = {config.vit_model_dim}, + .min_size = Type::kBF16, + }, + TensorInfo{ + .name = "ln_0_scale", + .source_names = {"img/Transformer/encoderblock/LayerNorm_0/scale"}, + .axes = {0}, + .shape = {config.vit_model_dim}, + .min_size = Type::kBF16, + }, + TensorInfo{ + .name = "ln_1_bias", + .source_names = {"img/Transformer/encoderblock/LayerNorm_1/bias"}, + .axes = {0}, + .shape = {config.vit_model_dim}, + .min_size = Type::kBF16, + }, + TensorInfo{ + .name = "ln_1_scale", + .source_names = {"img/Transformer/encoderblock/LayerNorm_1/scale"}, + .axes = {0}, + .shape = {config.vit_model_dim}, + .min_size = Type::kBF16, + }, + }; +} + +// Returns the tensors for the given LLM layer config. +std::vector LLMLayerTensors(const ModelConfig& config, + const LayerConfig& layer_config, + bool reshape_att) { + std::vector tensors = { + TensorInfo{ + .name = "qkv1_w", + .source_names = {"attn/q_einsum/w"}, + .axes = {0, 2, 1}, + .shape = {layer_config.heads, layer_config.qkv_dim, config.model_dim}, + .concat_names = {"qkv_ein", "qkv2_w"}, + }, + TensorInfo{ + .name = "qkv2_w", + .source_names = {"attn/kv_einsum/w"}, + .axes = {1, 0, 3, 2}, + .shape = {2 * layer_config.kv_heads, layer_config.qkv_dim, + config.model_dim}, + .concat_names = {""}, + }, + TensorInfo{ + .name = "q_ein", + .source_names = {"attention_block/proj_q/kernel"}, + .axes = {1, 0}, + .shape = {layer_config.model_dim, layer_config.model_dim}, + .concat_names = {"qkv_ein", "k_ein", "v_ein"}, + }, + TensorInfo{ + .name = "k_ein", + .source_names = {"attention_block/proj_k/kernel"}, + .axes = {1, 0}, + .shape = {layer_config.qkv_dim, layer_config.model_dim}, + .concat_names = {""}, + }, + TensorInfo{ + .name = "v_ein", + .source_names = {"attention_block/proj_v/kernel"}, + .axes = {1, 0}, + .shape = {layer_config.qkv_dim, layer_config.model_dim}, + .concat_names = {""}, + }, + TensorInfo{ + .name = "qkv_ein", + .source_names = {"attn/qkv_einsum/w"}, + .axes = {1, 0, 3, 2}, + .shape = {(layer_config.heads + 2 * layer_config.kv_heads), + layer_config.qkv_dim, config.model_dim}, + }, + TensorInfo{ + .name = "attn_ob", + .source_names = {"attention_block/proj_final/bias"}, + .axes = {0}, + .shape = {config.model_dim}, + .min_size = Type::kF32, + }, + // Griffin layers. + TensorInfo{ + .name = "gr_lin_x_w", + .source_names = {"recurrent_block/linear_x/kernel"}, + .axes = {1, 0}, + .shape = {layer_config.griffin_dim, layer_config.griffin_dim}, + }, + TensorInfo{ + .name = "gr_lin_x_b", + .source_names = {"recurrent_block/linear_x/bias"}, + .axes = {0}, + .shape = {layer_config.griffin_dim}, + .min_size = Type::kF32, + }, + TensorInfo{ + .name = "gr_lin_y_w", + .source_names = {"recurrent_block/linear_y/kernel"}, + .axes = {1, 0}, + .shape = {layer_config.griffin_dim, layer_config.griffin_dim}, + }, + TensorInfo{ + .name = "gr_lin_y_b", + .source_names = {"recurrent_block/linear_y/bias"}, + .axes = {0}, + .shape = {layer_config.griffin_dim}, + .min_size = Type::kF32, + }, + TensorInfo{ + .name = "gr_lin_out_w", + .source_names = {"recurrent_block/linear_out/kernel"}, + .axes = {1, 0}, + .shape = {layer_config.griffin_dim, layer_config.griffin_dim}, + }, + TensorInfo{ + .name = "gr_lin_out_b", + .source_names = {"recurrent_block/linear_out/bias"}, + .axes = {0}, + .shape = {layer_config.griffin_dim}, + .min_size = Type::kF32, + }, + TensorInfo{ + .name = "gr_conv_w", + .source_names = {"recurrent_block/conv_1d/w"}, + .axes = {0, 1}, + .shape = {layer_config.conv1d_width, layer_config.griffin_dim}, + .min_size = Type::kF32, + }, + TensorInfo{ + .name = "gr_conv_b", + .source_names = {"recurrent_block/conv_1d/b"}, + .axes = {0}, + .shape = {layer_config.griffin_dim}, + .min_size = Type::kF32, + }, + TensorInfo{ + .name = "gr1_gate_w", + .source_names = {"recurrent_block/rg_lru/input_gate/w"}, + .axes = {0, 2, 1}, + .shape = {layer_config.heads, + layer_config.griffin_dim / layer_config.heads, + layer_config.griffin_dim / layer_config.heads}, + .concat_names = {"gr_gate_w", "gr2_gate_w"}, + }, + TensorInfo{ + .name = "gr2_gate_w", + .source_names = {"recurrent_block/rg_lru/a_gate/w"}, + .axes = {0, 2, 1}, + .shape = {layer_config.heads, + layer_config.griffin_dim / layer_config.heads, + layer_config.griffin_dim / layer_config.heads}, + .concat_names = {""}, + }, + TensorInfo{ + .name = "gr_gate_w", + .source_names = {"recurrent_block/rg_lru/gate/w"}, + .axes = {0, 2, 1}, + .shape = {2 * layer_config.heads, + layer_config.griffin_dim / layer_config.heads, + layer_config.griffin_dim / layer_config.heads}, + }, + TensorInfo{ + .name = "gr1_gate_b", + .source_names = {"recurrent_block/rg_lru/input_gate/b"}, + .axes = {0}, + .shape = {layer_config.griffin_dim}, + .concat_names = {"gr_gate_b", "gr2_gate_b"}, + .min_size = Type::kF32, + }, + TensorInfo{ + .name = "gr2_gate_b", + .source_names = {"recurrent_block/rg_lru/a_gate/b"}, + .axes = {0}, + .shape = {layer_config.griffin_dim}, + .concat_names = {""}, + .min_size = Type::kF32, + }, + TensorInfo{ + .name = "gr_gate_b", + .source_names = {"recurrent_block/rg_lru/input_gate/b"}, + .axes = {0, 1}, + .shape = {2 * layer_config.griffin_dim}, + .min_size = Type::kF32, + }, + TensorInfo{ + .name = "gr_a", + .source_names = {"recurrent_block/rg_lru/a_param"}, + .axes = {0}, + .shape = {layer_config.griffin_dim}, + .min_size = Type::kF32, + .scaled_softplus = true, + }, + + TensorInfo{ + .name = "gating_ein", + .source_names = {"mlp/gating_einsum/w", "mlp/gating_einsum", + "mlp_block/ffw_up/w"}, + .axes = {0, layer_config.optimized_gating ? 1u : 2u, + layer_config.optimized_gating ? 2u : 1u}, + .shape = {2, layer_config.ff_hidden_dim, config.model_dim}, + }, + TensorInfo{ + .name = "gating1_w", + .source_names = {"none"}, + .axes = {0, layer_config.optimized_gating ? 1u : 2u, + layer_config.optimized_gating ? 2u : 1u}, + .shape = {layer_config.ff_hidden_dim, config.model_dim}, + }, + TensorInfo{ + .name = "gating2_w", + .source_names = {"none"}, + .axes = {0, layer_config.optimized_gating ? 1u : 2u, + layer_config.optimized_gating ? 2u : 1u}, + .shape = {layer_config.ff_hidden_dim, config.model_dim}, + }, + TensorInfo{ + .name = "linear_w", + .source_names = {"mlp/linear/w", "mlp/linear", + "mlp_block/ffw_down/kernel"}, + .axes = {1, 0}, + .shape = {config.model_dim, layer_config.ff_hidden_dim}, + }, + TensorInfo{ + .name = "pre_att_ns", + .source_names = {"pre_attention_norm/scale", + "temporal_pre_norm/scale"}, + .axes = {0}, + .shape = {config.model_dim}, + .min_size = Type::kBF16, + }, + TensorInfo{ + .name = "pre_ff_ns", + .source_names = {"pre_ffw_norm/scale", "channel_pre_norm/scale"}, + .axes = {0}, + .shape = {config.model_dim}, + .min_size = Type::kBF16, + }, + TensorInfo{ + .name = "post_att_ns", + .source_names = {"post_attention_norm/scale"}, + .axes = {0}, + .shape = {config.model_dim}, + .min_size = Type::kBF16, + }, + TensorInfo{ + .name = "post_ff_ns", + .source_names = {"post_ffw_norm/scale"}, + .axes = {0}, + .shape = {config.model_dim}, + .min_size = Type::kBF16, + }, + TensorInfo{ + .name = "ffw_gat_b", + .source_names = {"mlp_block/ffw_up/b"}, + .axes = {0}, + .shape = {2 * layer_config.ff_hidden_dim}, + .min_size = Type::kF32, + }, + TensorInfo{ + .name = "ffw_out_b", + .source_names = {"mlp_block/ffw_down/bias"}, + .axes = {0}, + .shape = {config.model_dim}, + .min_size = Type::kF32, + }, + }; + if (reshape_att) { + tensors.push_back(TensorInfo{ + .name = "att_w", + .source_names = {"attn/attn_vec_einsum/w", + "attention_block/proj_final/kernel"}, + .axes = {2, 0, 1}, + .shape = {config.model_dim, layer_config.heads, layer_config.qkv_dim}, + .cols_take_extra_dims = true, + }); + tensors.push_back(TensorInfo{ + .name = "att_ein", + .shape = {layer_config.heads, config.model_dim, layer_config.qkv_dim}, + }); + } else { + tensors.push_back(TensorInfo{ + .name = "att_ein", + .source_names = {"attn/attn_vec_einsum/w", + "attention_block/proj_final/kernel"}, + .preshape = {layer_config.heads, layer_config.qkv_dim, + config.model_dim}, + .axes = {0, 2, 1}, + .shape = {layer_config.heads, config.model_dim, layer_config.qkv_dim}, + }); + tensors.push_back(TensorInfo{ + .name = "att_w", + .shape = {config.model_dim, layer_config.heads, layer_config.qkv_dim}, + .cols_take_extra_dims = true, + }); + } + return tensors; +} + +} // namespace + +TensorIndex::TensorIndex(const ModelConfig& config, int llm_layer_idx, + int img_layer_idx, bool reshape_att) + : config_(config), + llm_layer_idx_(llm_layer_idx), + img_layer_idx_(img_layer_idx) { + int layer_idx = std::max(llm_layer_idx_, img_layer_idx_); + std::string suffix; + if (layer_idx >= 0) { + suffix = "_" + std::to_string(layer_idx); + } + if (llm_layer_idx < 0 && img_layer_idx < 0) { + tensors_ = ModelTensors(config); + } else if (llm_layer_idx_ < 0 && 0 <= img_layer_idx && + img_layer_idx < config.vit_layer_configs.size()) { + const auto& layer_config = config.vit_layer_configs[img_layer_idx]; + tensors_ = ImageLayerTensors(config, layer_config); + } else if (0 <= llm_layer_idx && + llm_layer_idx < config.layer_configs.size()) { + const auto& layer_config = config.layer_configs[llm_layer_idx]; + tensors_ = LLMLayerTensors(config, layer_config, reshape_att); + } + for (size_t i = 0; i < tensors_.size(); ++i) { + std::string key = tensors_[i].name + suffix; + name_map_.insert({key, i}); + } +} + +TensorInfo TensorIndex::GetTensorInfo(const std::string& path) const { + for (const auto& tensor : tensors_) { + for (const auto& source_name : tensor.source_names) { + auto pos = path.rfind(source_name); + if (pos != std::string::npos && path.size() == pos + source_name.size()) + return tensor; + } + } + return TensorInfo(); +} + +const TensorInfo* TensorIndex::FindName(const std::string& name) const { + std::string name_to_find = name; + if (!std::isdigit(name[name.size() - 1])) { + if (img_layer_idx_ >= 0 && llm_layer_idx_ < 0) { + name_to_find = name + "_" + std::to_string(img_layer_idx_); + } else if (llm_layer_idx_ >= 0) { + name_to_find = name + "_" + std::to_string(llm_layer_idx_); + } + } + auto it = name_map_.find(name_to_find); + if (it == name_map_.end()) { + return nullptr; + } + return &tensors_[it->second]; +} + +} // namespace gcpp diff --git a/gemma/tensor_index.h b/gemma/tensor_index.h new file mode 100644 index 0000000..a1acfd6 --- /dev/null +++ b/gemma/tensor_index.h @@ -0,0 +1,91 @@ +#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_TENSOR_INDEX_H_ +#define THIRD_PARTY_GEMMA_CPP_GEMMA_TENSOR_INDEX_H_ + +#include + +#include +#include +#include + +#include "compression/shared.h" +#include "gemma/configs.h" + +namespace gcpp { + +// Universal tensor information. Holds enough information to construct a +// tensor in LayerWeightsPtrs/ModelWeightsPtrs, as well as to export the +// tensor from the python model with necessary transpose/reshape info. +struct TensorInfo { + // The name of the tensor in the sbs file + std::string name; + // Strings to match to the end of the name of the tensor in the python model. + std::vector source_names; + // Initial reshape shape. Use only as a last resort when input may have + // dimensions combined that need to be split before the transpose, as it + // defeats the post-transpose shape checking. Normally empty. + std::vector preshape; + // Transpose axes arg. If the input tensor has more dimensions than axes, + // then leading dimensions are collapsed until the number of axes matches. + std::vector axes; + // Expected final shape of the tensor after reshape/transpose. + // Note that this is the shape of the tensor during export, + // not the shape of the tensor in the sbs file, as the sbs file + // is restricted to 2D tensors. With few exceptions, the sbs file + // tensor rows gather all the excess dimensions. See cols_take_extra_dims. + std::vector shape; + // List of names to concatenate with, used only if multiple tensors are + // concatenated into one. The first tensor in the concatenation should have + // concat names thus: The first name is the name of the result, and the + // tensors with the remaining names are concatenated after this. + // The remaining tensors to be concatenated should have just a single + // empty string in concat_names to indicate that they have been consumed. + std::vector concat_names; + // Axis at which to concatenate. + size_t concat_axis = 0; + // The minimum compression weight type for this tensor. The default is + // kNUQ, which provides maximum compression. Other values such as kBF16 + // or kF32 can be used to limit the compression to a specific type. + Type min_size = Type::kNUQ; + // Whether to apply scaled softplus to the data. + bool scaled_softplus = false; + // Whether the columns or the rows take any extra dimensions. + // If false, then [10, 20, 30] -> [10*20, 30] and [30] -> [1, 30]. + // If true, then [10, 20, 30] -> [10, 20*30] and [30] -> [1, 30]. + bool cols_take_extra_dims = false; +}; + +// Universal index of tensor information, which can be built for a specific +// layer_idx. +class TensorIndex { + public: + // Builds a list of TensorInfo for the given layer_idx. + // If reshape_att is true, the attn_vec_einsum tensor is reshaped. + TensorIndex(const ModelConfig& config, int llm_layer_idx, int img_layer_idx, + bool reshape_att); + ~TensorIndex() = default; + + // Returns the TensorInfo whose source_name matches the end of the given path, + // or an empty TensorInfo if not found. + // NOTE: that the returned TensorInfo is a copy, so that the source + // TensorIndex can be destroyed without affecting the returned TensorInfo. + TensorInfo GetTensorInfo(const std::string& path) const; + + // Returns the TensorInfo for the given tensor name, for concise construction + // of ModelWeightsPtrs/LayerWeightsPtrs. + const TensorInfo* FindName(const std::string& name) const; + + private: + // Config that was used to build the tensor index. + const ModelConfig& config_; + // Layer that this tensor index is for - either LLM or image. + int llm_layer_idx_; + int img_layer_idx_; + // List of tensor information for this layer. + std::vector tensors_; + // Map from tensor name to index in tensors_. + std::unordered_map name_map_; +}; + +} // namespace gcpp + +#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_TENSOR_INDEX_H_ diff --git a/gemma/tensor_index_test.cc b/gemma/tensor_index_test.cc new file mode 100644 index 0000000..7fd1268 --- /dev/null +++ b/gemma/tensor_index_test.cc @@ -0,0 +1,73 @@ +#include "gemma/tensor_index.h" + +#include +#include +#include +#include +#include + +#include "gtest/gtest.h" +#include "compression/compress.h" +#include "compression/shared.h" +#include "gemma/configs.h" +#include "gemma/weights.h" +#include "util/basics.h" +#include "hwy/aligned_allocator.h" +#include "hwy/contrib/thread_pool/thread_pool.h" + +namespace gcpp { +namespace { + +// Tests that each tensor in the model can be found by exactly one TensorIndex, +// and that the TensorIndex returns the correct shape and name for the tensor, +// for all models. +TEST(TensorIndexTest, FindName) { + hwy::ThreadPool pool(4); + for (Model model : kAllModels) { + fprintf(stderr, "Testing model %d\n", static_cast(model)); + ModelConfig config = ConfigFromModel(model); + std::vector tensor_indexes; + tensor_indexes.emplace_back(config, /*llm_layer_idx=*/-1, + /*img_layer_idx=*/-1, + /*split_and_reshape=*/false); + for (size_t llm_layer_idx = 0; llm_layer_idx < config.layer_configs.size(); + ++llm_layer_idx) { + tensor_indexes.emplace_back(config, static_cast(llm_layer_idx), + /*img_layer_idx=*/-1, + /*split_and_reshape=*/false); + } + for (size_t img_layer_idx = 0; + img_layer_idx < config.vit_layer_configs.size(); + ++img_layer_idx) { + tensor_indexes.emplace_back(config, /*llm_layer_idx=*/-1, + static_cast(img_layer_idx), + /*split_and_reshape=*/false); + } + // For each tensor in any model, exactly one TensorIndex should find it. + ModelWeightsPtrs weights(config, pool); + ModelWeightsPtrs::ForEachTensor( + {&weights}, ForEachType::kInitNoToc, + [&tensor_indexes](const char* name, hwy::Span tensors) { + int num_found = 0; + const MatPtr& tensor = *tensors[0]; + for (const auto& tensor_index : tensor_indexes) { + // Skip the type marker prefix, but we want the layer index suffix. + std::string name_to_find(name + 1, strlen(name) - 1); + const TensorInfo* info = tensor_index.FindName(name_to_find); + if (info != nullptr) { + // Test that the MatPtr can be constructed from the TensorInfo, + // and that the dimensions match. + MatPtrT mat_ptr(tensor.Name(), tensor_index); + EXPECT_EQ(tensor.Name(), mat_ptr.Name()) << "on tensor " << name; + EXPECT_EQ(tensor.Rows(), mat_ptr.Rows()) << "on tensor " << name; + EXPECT_EQ(tensor.Cols(), mat_ptr.Cols()) << "on tensor " << name; + ++num_found; + } + } + EXPECT_EQ(num_found, 1) << " for tensor " << name; + }); + } +} + +} // namespace +} // namespace gcpp diff --git a/gemma/weights.h b/gemma/weights.h index ce2df43..b9acf89 100644 --- a/gemma/weights.h +++ b/gemma/weights.h @@ -57,8 +57,8 @@ template struct LayerWeightsPtrs { // Large data is constructed separately. explicit LayerWeightsPtrs(const LayerConfig& config) - : attn_vec_einsum_w("att_ein", config.model_dim, - config.heads * config.qkv_dim), + : attn_vec_einsum_w("att_ein", config.heads * config.model_dim, + config.qkv_dim), qkv_einsum_w("qkv_ein", (config.heads + 2 * config.kv_heads) * config.qkv_dim, config.model_dim), @@ -86,8 +86,8 @@ struct LayerWeightsPtrs { .gate_biases = {"gr_gate_b", 1, config.griffin_dim * 2}, .a = {"gr_a", 1, config.griffin_dim}}), // MultiHeadDotProductAttention. - vit({.attn_out_w = {"attn_out_w", config.heads * config.qkv_dim, - config.model_dim}, + vit({.attn_out_w = {"attn_out_w", config.model_dim, + config.heads * config.qkv_dim}, .attn_out_b = {"attn_out_b", 1, config.model_dim}, .qkv_einsum_w = {"qkv_ein_w", (config.heads + 2 * config.kv_heads) * @@ -349,9 +349,8 @@ struct ModelWeightsPtrs { vit_encoder_norm_bias("enc_norm_bias", 1, config.vit_model_dim), vit_encoder_norm_scale("enc_norm_scale", 1, config.vit_model_dim), vit_img_embedding_bias("img_emb_bias", 1, config.vit_model_dim), - vit_img_embedding_kernel("img_emb_kernel", - config.patch_width * config.patch_width * 3, - config.vit_model_dim), + vit_img_embedding_kernel("img_emb_kernel", config.vit_model_dim, + config.patch_width * config.patch_width * 3), vit_img_pos_embedding("img_pos_emb", 256, config.vit_model_dim), vit_img_head_bias("img_head_bias", 1, config.model_dim), vit_img_head_kernel("img_head_kernel", config.model_dim,