From 160a5824fb9c49bfa1176b308400ffb9018bfc73 Mon Sep 17 00:00:00 2001 From: Jan Wassenberg Date: Tue, 22 Apr 2025 12:01:00 -0700 Subject: [PATCH] Cleanup: include fixes/comments, fix leak, vector reserve Also remove unused RowSpan configs.cc: Assign prompt wrapping to ModelConfig configs.h: simplify EnumValid via sentinel PiperOrigin-RevId: 750278497 --- BUILD.bazel | 7 +--- backprop/backward-inl.h | 3 +- backprop/forward-inl.h | 2 +- evals/benchmark.cc | 1 - evals/benchmark_helper.h | 2 + evals/cross_entropy.cc | 1 - evals/gemma_batch_bench.cc | 7 ++-- gemma/configs.cc | 14 +++++++ gemma/configs.h | 82 +++++++++++++++++++++++++------------ gemma/gemma_args.h | 6 +-- gemma/kv_cache.cc | 2 +- gemma/kv_cache.h | 2 +- ops/ops_test.cc | 2 +- paligemma/paligemma_test.cc | 7 +--- util/mat.h | 12 +----- util/threading_context.h | 1 + 16 files changed, 91 insertions(+), 60 deletions(-) diff --git a/BUILD.bazel b/BUILD.bazel index 85ba3f7..970e2f8 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -443,22 +443,19 @@ cc_library( "mem": "28g", }, deps = [ - ":allocator", - ":basics", ":benchmark_helper", ":common", ":gemma_args", ":gemma_lib", ":kv_cache", - ":mat", ":ops", ":threading", ":threading_context", ":tokenizer", - ":weights", - "//compression:shared", "//paligemma:image", "@highway//:hwy", + "@highway//:profiler", + "@highway//:timer", ], ) diff --git a/backprop/backward-inl.h b/backprop/backward-inl.h index 9716d87..fbc59e2 100644 --- a/backprop/backward-inl.h +++ b/backprop/backward-inl.h @@ -27,7 +27,8 @@ #include "backprop/activations.h" #include "backprop/prompt.h" -#include "gemma/common.h" +#include "gemma/common.h" // EmbeddingScaling +#include "gemma/configs.h" // LayerConfig, ModelConfig #include "gemma/weights.h" #include "util/allocator.h" #include "hwy/base.h" diff --git a/backprop/forward-inl.h b/backprop/forward-inl.h index 75de9a2..0730dbe 100644 --- a/backprop/forward-inl.h +++ b/backprop/forward-inl.h @@ -24,7 +24,7 @@ #include #include "backprop/activations.h" -#include "gemma/common.h" +#include "gemma/common.h" // EmbeddingScaling #include "gemma/configs.h" #include "gemma/weights.h" #include "util/allocator.h" diff --git a/evals/benchmark.cc b/evals/benchmark.cc index 579a64f..18f39e0 100644 --- a/evals/benchmark.cc +++ b/evals/benchmark.cc @@ -12,7 +12,6 @@ #include "compression/io.h" // Path #include "evals/benchmark_helper.h" #include "evals/cross_entropy.h" -#include "gemma/common.h" #include "gemma/gemma.h" #include "util/args.h" #include "hwy/base.h" diff --git a/evals/benchmark_helper.h b/evals/benchmark_helper.h index c2772f8..75379d9 100644 --- a/evals/benchmark_helper.h +++ b/evals/benchmark_helper.h @@ -49,6 +49,8 @@ class GemmaEnv { GemmaEnv(int argc, char** argv); GemmaEnv(const ThreadingArgs& threading, const LoaderArgs& loader, const InferenceArgs& inference); + // Avoid memory leaks in test. + ~GemmaEnv() { ThreadingContext2::ThreadHostileInvalidate(); } MatMulEnv& Env() { return env_; } diff --git a/evals/cross_entropy.cc b/evals/cross_entropy.cc index e4bf1b1..a32873c 100644 --- a/evals/cross_entropy.cc +++ b/evals/cross_entropy.cc @@ -38,7 +38,6 @@ #include #include "evals/cross_entropy.h" -#include "gemma/common.h" #include "gemma/gemma.h" #include "hwy/base.h" diff --git a/evals/gemma_batch_bench.cc b/evals/gemma_batch_bench.cc index c92194c..f2b3a3b 100644 --- a/evals/gemma_batch_bench.cc +++ b/evals/gemma_batch_bench.cc @@ -13,15 +13,14 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "gemma/gemma.h" - #include #include #include #include "evals/benchmark_helper.h" -#include "gemma/common.h" +#include "gemma/configs.h" +#include "gemma/gemma.h" #include "hwy/base.h" #include "hwy/tests/hwy_gtest.h" @@ -65,6 +64,7 @@ class GemmaTest : public ::testing::Test { prompts_vector.push_back(s_env->TokenizeAndPrependBOS(input_string)); } std::vector prompt_spans; + prompt_spans.reserve(prompts_vector.size()); for (const auto& prompt : prompts_vector) { prompt_spans.push_back(PromptTokens(prompt.data(), prompt.size())); } @@ -79,6 +79,7 @@ class GemmaTest : public ::testing::Test { ASSERT_NE(s_env->GetGemma(), nullptr); std::vector inputs; + inputs.reserve(num_questions); for (size_t i = 0; i < num_questions; ++i) { inputs.push_back(kQA[i]); } diff --git a/gemma/configs.cc b/gemma/configs.cc index 276c8f9..2f18c0b 100644 --- a/gemma/configs.cc +++ b/gemma/configs.cc @@ -187,6 +187,7 @@ static ModelConfig ConfigGemmaTiny() { ModelConfig config = ConfigNoSSM(); config.model_name = "GemmaTiny"; config.model = Model::GEMMA_TINY; + config.wrapping = PromptWrapping::GEMMA_IT; config.model_dim = 128; config.vocab_size = 64; config.seq_len = 32; @@ -277,6 +278,7 @@ static ModelConfig ConfigPaliGemma_224() { ModelConfig config = ConfigGemma2B(); config.model_name = "PaliGemma_224"; config.model = Model::PALIGEMMA_224; + config.wrapping = PromptWrapping::PALIGEMMA; AddVitConfig(config); return config; } @@ -285,6 +287,7 @@ static ModelConfig ConfigPaliGemma_448() { ModelConfig config = ConfigGemma2B(); config.model_name = "PaliGemma_448"; config.model = Model::PALIGEMMA_448; + config.wrapping = PromptWrapping::PALIGEMMA; AddVitConfig(config, /*image_size=*/448); return config; } @@ -305,6 +308,7 @@ static ModelConfig ConfigPaliGemma2_3B_224() { ModelConfig config = ConfigGemma2_2B(); config.model_name = "PaliGemma2_3B_224"; config.model = Model::PALIGEMMA2_3B_224; + config.wrapping = PromptWrapping::PALIGEMMA; AddVitConfig(config); return config; } @@ -313,6 +317,7 @@ static ModelConfig ConfigPaliGemma2_3B_448() { ModelConfig config = ConfigGemma2_2B(); config.model_name = "PaliGemma2_3B_448"; config.model = Model::PALIGEMMA2_3B_448; + config.wrapping = PromptWrapping::PALIGEMMA; AddVitConfig(config, /*image_size=*/448); return config; } @@ -321,6 +326,7 @@ static ModelConfig ConfigPaliGemma2_10B_224() { ModelConfig config = ConfigGemma2_9B(); config.model_name = "PaliGemma2_10B_224"; config.model = Model::PALIGEMMA2_10B_224; + config.wrapping = PromptWrapping::PALIGEMMA; AddVitConfig(config); return config; } @@ -329,6 +335,7 @@ static ModelConfig ConfigPaliGemma2_10B_448() { ModelConfig config = ConfigGemma2_9B(); config.model_name = "PaliGemma2_10B_448"; config.model = Model::PALIGEMMA2_10B_448; + config.wrapping = PromptWrapping::PALIGEMMA; AddVitConfig(config, /*image_size=*/448); return config; } @@ -360,6 +367,7 @@ static ModelConfig ConfigGemma3_1B() { ModelConfig config = ConfigBaseGemmaV3(); config.model_name = "Gemma3_1B"; config.model = Model::GEMMA3_1B; + config.wrapping = PromptWrapping::GEMMA_VLM; config.model_dim = 1152; config.vocab_size = 262144; // new vocab size / tokenizer config.seq_len = 32 * 1024; @@ -391,6 +399,7 @@ static ModelConfig ConfigGemma3_4B_LM() { ModelConfig config = ConfigBaseGemmaV3(); config.model_name = "Gemma3_4B"; config.model = Model::GEMMA3_4B; + config.wrapping = PromptWrapping::GEMMA_VLM; config.model_dim = 2560; config.vocab_size = 262144; // new vocab size / tokenizer config.seq_len = 32 * 1024; @@ -408,6 +417,7 @@ static ModelConfig ConfigGemma3_4B() { ModelConfig config = ConfigGemma3_4B_LM(); config.model_name = "Gemma3_4B"; config.model = Model::GEMMA3_4B; + config.wrapping = PromptWrapping::GEMMA_VLM; AddVitConfig(config, /*image_size=*/896); config.vocab_size = 262144; config.vit_config.pool_dim = 4; @@ -438,6 +448,7 @@ static ModelConfig ConfigGemma3_12B_LM() { ModelConfig config = ConfigBaseGemmaV3(); config.model_name = "Gemma3_12B"; config.model = Model::GEMMA3_12B; + config.wrapping = PromptWrapping::GEMMA_VLM; config.model_dim = 3840; config.vocab_size = 262144; // new vocab size / tokenizer config.seq_len = 32 * 1024; @@ -455,6 +466,7 @@ static ModelConfig ConfigGemma3_12B() { ModelConfig config = ConfigGemma3_12B_LM(); config.model_name = "Gemma3_12B"; config.model = Model::GEMMA3_12B; + config.wrapping = PromptWrapping::GEMMA_VLM; AddVitConfig(config, /*image_size=*/896); config.vocab_size = 262144; config.vit_config.pool_dim = 4; @@ -485,6 +497,7 @@ static ModelConfig ConfigGemma3_27B_LM() { ModelConfig config = ConfigBaseGemmaV3(); config.model_name = "Gemma3_27B"; config.model = Model::GEMMA3_27B; + config.wrapping = PromptWrapping::GEMMA_VLM; config.model_dim = 5376; config.vocab_size = 262144; // new vocab size / tokenizer config.seq_len = 32 * 1024; @@ -502,6 +515,7 @@ static ModelConfig ConfigGemma3_27B() { ModelConfig config = ConfigGemma3_27B_LM(); config.model_name = "Gemma3_27B"; config.model = Model::GEMMA3_27B; + config.wrapping = PromptWrapping::GEMMA_VLM; AddVitConfig(config, /*image_size=*/896); config.vocab_size = 262144; config.vit_config.pool_dim = 4; diff --git a/gemma/configs.h b/gemma/configs.h index 77d063a..483b35b 100644 --- a/gemma/configs.h +++ b/gemma/configs.h @@ -19,10 +19,9 @@ // Model configurations #include +#include -#include #include -#include #include #include #include @@ -53,11 +52,27 @@ using EmbedderInputT = BF16; enum class PromptWrapping { GEMMA_IT, GEMMA_PT, - GEMMA_VLM, + GEMMA_VLM, // for >1B Gemma3 PALIGEMMA, kSentinel // must be last }; +// Defined as the suffix for use with `ModelString`. +static inline const char* ToString(PromptWrapping wrapping) { + switch (wrapping) { + case PromptWrapping::GEMMA_IT: + return "-it"; + case PromptWrapping::GEMMA_PT: + return "-pt"; + case PromptWrapping::GEMMA_VLM: + return "-vlm"; + case PromptWrapping::PALIGEMMA: + return "-pg"; + default: + return "-?"; + } +} + static inline bool EnumValid(PromptWrapping wrapping) { return static_cast(wrapping) < static_cast(PromptWrapping::kSentinel); @@ -69,63 +84,68 @@ enum class LayerAttentionType { kVit, }; -inline bool EnumValid(LayerAttentionType type) { - return static_cast(type) >= 0 && - static_cast(type) <= static_cast(LayerAttentionType::kVit); +static inline bool EnumValid(LayerAttentionType type) { + return type == LayerAttentionType::kGemma || + type == LayerAttentionType::kGriffinRecurrentBlock || + type == LayerAttentionType::kVit; } // Post attention and ffw normalization type. enum class PostNormType { None, Scale, + kSentinel // must be last }; -inline bool EnumValid(PostNormType type) { - return static_cast(type) >= 0 && - static_cast(type) <= static_cast(PostNormType::Scale); +static inline bool EnumValid(PostNormType type) { + return static_cast(type) < + static_cast(PostNormType::kSentinel); } // Post qk projection operation type. enum class PostQKType { Rope, HalfRope, + kSentinel // must be last }; -inline bool EnumValid(PostQKType type) { - return static_cast(type) >= 0 && - static_cast(type) <= static_cast(PostQKType::HalfRope); +static inline bool EnumValid(PostQKType type) { + return static_cast(type) < + static_cast(PostNormType::kSentinel); } // FFW activation function. enum class ActivationType { Gelu, + kSentinel // must be last }; -inline bool EnumValid(ActivationType type) { - return static_cast(type) >= 0 && - static_cast(type) <= static_cast(ActivationType::Gelu); +static inline bool EnumValid(ActivationType type) { + return static_cast(type) < + static_cast(ActivationType::kSentinel); } // Attention query scale. enum class QueryScaleType { SqrtKeySize, SqrtModelDimDivNumHeads, + kSentinel // must be last }; -inline bool EnumValid(QueryScaleType type) { - return static_cast(type) >= 0 && - static_cast(type) <= - static_cast(QueryScaleType::SqrtModelDimDivNumHeads); +static inline bool EnumValid(QueryScaleType type) { + return static_cast(type) < + static_cast(QueryScaleType::kSentinel); } // Residual connection type. enum class ResidualType { Add, + kSentinel // must be last }; -inline bool EnumValid(ResidualType type) { - return static_cast(type) >= 0 && - static_cast(type) <= static_cast(ResidualType::Add); +static inline bool EnumValid(ResidualType type) { + return static_cast(type) < + static_cast(ResidualType::kSentinel); } template @@ -169,6 +189,7 @@ enum class Model { GEMMA3_1B, GEMMA3_12B, GEMMA3_27B, + kSentinel, }; // Allows the Model enum to be iterated over. @@ -181,9 +202,18 @@ static constexpr Model kAllModels[] = { Model::GEMMA3_12B, Model::GEMMA3_27B, }; -inline bool EnumValid(Model model) { - for (Model m : kAllModels) { - if (m == model) return true; +template +void ForEachModel(const Func& func) { + for (size_t i = static_cast(Model::UNKNOWN) + 1; + i < static_cast(Model::kSentinel); ++i) { + func(static_cast(i)); + } +} + +static inline bool EnumValid(Model model) { + const size_t i = static_cast(model); + if (i < static_cast(Model::kSentinel)) { + return true; } return false; } @@ -301,7 +331,7 @@ struct ModelConfig : public IFields { size_t NumHeads() const { uint32_t num_heads = 0; for (const auto& layer_config : layer_configs) { - num_heads = std::max(num_heads, layer_config.heads); + num_heads = HWY_MAX(num_heads, layer_config.heads); } return num_heads; } diff --git a/gemma/gemma_args.h b/gemma/gemma_args.h index d02dece..63f191a 100644 --- a/gemma/gemma_args.h +++ b/gemma/gemma_args.h @@ -26,12 +26,12 @@ #include "compression/io.h" // Path #include "compression/shared.h" -#include "gemma/common.h" +#include "gemma/configs.h" #include "gemma/gemma.h" // For CreateGemma -#include "hwy/base.h" // HWY_ABORT #include "ops/matmul.h" #include "util/args.h" #include "util/basics.h" // Tristate +#include "hwy/base.h" // HWY_ABORT namespace gcpp { @@ -237,4 +237,4 @@ struct InferenceArgs : public ArgsBase { } // namespace gcpp -#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_ARGS_H_ \ No newline at end of file +#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_ARGS_H_ diff --git a/gemma/kv_cache.cc b/gemma/kv_cache.cc index 60ad5dd..d3c2372 100644 --- a/gemma/kv_cache.cc +++ b/gemma/kv_cache.cc @@ -17,7 +17,7 @@ #include -#include "gemma/common.h" // CallForModel +#include "gemma/configs.h" #include "hwy/aligned_allocator.h" #include "hwy/base.h" // ZeroBytes diff --git a/gemma/kv_cache.h b/gemma/kv_cache.h index 6052d0b..907bee3 100644 --- a/gemma/kv_cache.h +++ b/gemma/kv_cache.h @@ -18,7 +18,7 @@ #include -#include "gemma/common.h" // Model +#include "gemma/configs.h" // ModelConfig #include "hwy/aligned_allocator.h" namespace gcpp { diff --git a/ops/ops_test.cc b/ops/ops_test.cc index b44c3f7..6ff7816 100644 --- a/ops/ops_test.cc +++ b/ops/ops_test.cc @@ -31,7 +31,7 @@ #include #include -#include "gemma/common.h" +#include "gemma/common.h" // ChooseQueryScale #include "util/allocator.h" #include "util/basics.h" // BF16 #include "util/mat.h" // RowVectorBatch diff --git a/paligemma/paligemma_test.cc b/paligemma/paligemma_test.cc index 2453822..95dce0d 100644 --- a/paligemma/paligemma_test.cc +++ b/paligemma/paligemma_test.cc @@ -19,7 +19,6 @@ #include "compression/shared.h" #include "evals/benchmark_helper.h" -#include "gemma/common.h" #include "gemma/configs.h" #include "gemma/gemma.h" #include "util/allocator.h" @@ -27,11 +26,7 @@ #include "hwy/tests/hwy_gtest.h" // This test can be run manually with the downloaded PaliGemma weights. -// To run the test, pass the following flags: -// --model paligemma-224 --tokenizer --weights -// or just use the single-file weights file with --weights . -// It should pass for the following models: -// paligemma-3b-mix-224, paligemma2-3b-pt-448 +// It should pass for `paligemma-3b-mix-224` and `paligemma2-3b-pt-448`. namespace gcpp { namespace { diff --git a/util/mat.h b/util/mat.h index d1c7c9d..e9b5189 100644 --- a/util/mat.h +++ b/util/mat.h @@ -251,19 +251,11 @@ class MatPtrT : public MatPtr { HWY_ASSERT(IsPacked()); return MakeSpan(Row(0), num_elements_); } - - // For when a span of a single row is required. This also works if padded, - // but does not support `GetType() == kNUQ`, because that requires the use of - // offsets instead of a row pointer. Used by `gemma-inl.h` to decompress - // embeddings. - PackedSpan RowSpan(size_t row) const { - HWY_DASSERT(GetType() != Type::kNUQ); - return MakeConstSpan(Row(row), Cols()); - } }; // Calls `func` with a dynamic_cast of `MatPtr` to `MatPtrT`, plus the -// optional `args`. +// optional `args`. Currently unused but may be used after we move toward +// type-erased `WeightsPtrs`. template decltype(auto) CallUpcasted(Type type, MatPtr* base, const Func& func, Args&&... args) { diff --git a/util/threading_context.h b/util/threading_context.h index a59dcdd..0f9d569 100644 --- a/util/threading_context.h +++ b/util/threading_context.h @@ -118,6 +118,7 @@ class ThreadingContext2 { // changing the arguments between tests. Callers must again call `Get` // afterwards to obtain an instance. WARNING: must not be called concurrently // with other calls to `Get` and usages of its return value. + // Also useful to suppress memory leak warnings in tests. static void ThreadHostileInvalidate(); explicit ThreadingContext2(PrivateToken); // only called via `Get`.