Cleanup: include fixes/comments, fix leak, vector reserve

Also remove unused RowSpan
configs.cc: Assign prompt wrapping to ModelConfig
configs.h: simplify EnumValid via sentinel

PiperOrigin-RevId: 750278497
This commit is contained in:
Jan Wassenberg 2025-04-22 12:01:00 -07:00 committed by Copybara-Service
parent ba10c88a94
commit 160a5824fb
16 changed files with 91 additions and 60 deletions

View File

@ -443,22 +443,19 @@ cc_library(
"mem": "28g", "mem": "28g",
}, },
deps = [ deps = [
":allocator",
":basics",
":benchmark_helper", ":benchmark_helper",
":common", ":common",
":gemma_args", ":gemma_args",
":gemma_lib", ":gemma_lib",
":kv_cache", ":kv_cache",
":mat",
":ops", ":ops",
":threading", ":threading",
":threading_context", ":threading_context",
":tokenizer", ":tokenizer",
":weights",
"//compression:shared",
"//paligemma:image", "//paligemma:image",
"@highway//:hwy", "@highway//:hwy",
"@highway//:profiler",
"@highway//:timer",
], ],
) )

View File

@ -27,7 +27,8 @@
#include "backprop/activations.h" #include "backprop/activations.h"
#include "backprop/prompt.h" #include "backprop/prompt.h"
#include "gemma/common.h" #include "gemma/common.h" // EmbeddingScaling
#include "gemma/configs.h" // LayerConfig, ModelConfig
#include "gemma/weights.h" #include "gemma/weights.h"
#include "util/allocator.h" #include "util/allocator.h"
#include "hwy/base.h" #include "hwy/base.h"

View File

@ -24,7 +24,7 @@
#include <vector> #include <vector>
#include "backprop/activations.h" #include "backprop/activations.h"
#include "gemma/common.h" #include "gemma/common.h" // EmbeddingScaling
#include "gemma/configs.h" #include "gemma/configs.h"
#include "gemma/weights.h" #include "gemma/weights.h"
#include "util/allocator.h" #include "util/allocator.h"

View File

@ -12,7 +12,6 @@
#include "compression/io.h" // Path #include "compression/io.h" // Path
#include "evals/benchmark_helper.h" #include "evals/benchmark_helper.h"
#include "evals/cross_entropy.h" #include "evals/cross_entropy.h"
#include "gemma/common.h"
#include "gemma/gemma.h" #include "gemma/gemma.h"
#include "util/args.h" #include "util/args.h"
#include "hwy/base.h" #include "hwy/base.h"

View File

@ -49,6 +49,8 @@ class GemmaEnv {
GemmaEnv(int argc, char** argv); GemmaEnv(int argc, char** argv);
GemmaEnv(const ThreadingArgs& threading, const LoaderArgs& loader, GemmaEnv(const ThreadingArgs& threading, const LoaderArgs& loader,
const InferenceArgs& inference); const InferenceArgs& inference);
// Avoid memory leaks in test.
~GemmaEnv() { ThreadingContext2::ThreadHostileInvalidate(); }
MatMulEnv& Env() { return env_; } MatMulEnv& Env() { return env_; }

View File

@ -38,7 +38,6 @@
#include <vector> #include <vector>
#include "evals/cross_entropy.h" #include "evals/cross_entropy.h"
#include "gemma/common.h"
#include "gemma/gemma.h" #include "gemma/gemma.h"
#include "hwy/base.h" #include "hwy/base.h"

View File

@ -13,15 +13,14 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "gemma/gemma.h"
#include <stdio.h> #include <stdio.h>
#include <string> #include <string>
#include <vector> #include <vector>
#include "evals/benchmark_helper.h" #include "evals/benchmark_helper.h"
#include "gemma/common.h" #include "gemma/configs.h"
#include "gemma/gemma.h"
#include "hwy/base.h" #include "hwy/base.h"
#include "hwy/tests/hwy_gtest.h" #include "hwy/tests/hwy_gtest.h"
@ -65,6 +64,7 @@ class GemmaTest : public ::testing::Test {
prompts_vector.push_back(s_env->TokenizeAndPrependBOS(input_string)); prompts_vector.push_back(s_env->TokenizeAndPrependBOS(input_string));
} }
std::vector<PromptTokens> prompt_spans; std::vector<PromptTokens> prompt_spans;
prompt_spans.reserve(prompts_vector.size());
for (const auto& prompt : prompts_vector) { for (const auto& prompt : prompts_vector) {
prompt_spans.push_back(PromptTokens(prompt.data(), prompt.size())); prompt_spans.push_back(PromptTokens(prompt.data(), prompt.size()));
} }
@ -79,6 +79,7 @@ class GemmaTest : public ::testing::Test {
ASSERT_NE(s_env->GetGemma(), nullptr); ASSERT_NE(s_env->GetGemma(), nullptr);
std::vector<std::string> inputs; std::vector<std::string> inputs;
inputs.reserve(num_questions);
for (size_t i = 0; i < num_questions; ++i) { for (size_t i = 0; i < num_questions; ++i) {
inputs.push_back(kQA[i]); inputs.push_back(kQA[i]);
} }

View File

@ -187,6 +187,7 @@ static ModelConfig ConfigGemmaTiny() {
ModelConfig config = ConfigNoSSM(); ModelConfig config = ConfigNoSSM();
config.model_name = "GemmaTiny"; config.model_name = "GemmaTiny";
config.model = Model::GEMMA_TINY; config.model = Model::GEMMA_TINY;
config.wrapping = PromptWrapping::GEMMA_IT;
config.model_dim = 128; config.model_dim = 128;
config.vocab_size = 64; config.vocab_size = 64;
config.seq_len = 32; config.seq_len = 32;
@ -277,6 +278,7 @@ static ModelConfig ConfigPaliGemma_224() {
ModelConfig config = ConfigGemma2B(); ModelConfig config = ConfigGemma2B();
config.model_name = "PaliGemma_224"; config.model_name = "PaliGemma_224";
config.model = Model::PALIGEMMA_224; config.model = Model::PALIGEMMA_224;
config.wrapping = PromptWrapping::PALIGEMMA;
AddVitConfig(config); AddVitConfig(config);
return config; return config;
} }
@ -285,6 +287,7 @@ static ModelConfig ConfigPaliGemma_448() {
ModelConfig config = ConfigGemma2B(); ModelConfig config = ConfigGemma2B();
config.model_name = "PaliGemma_448"; config.model_name = "PaliGemma_448";
config.model = Model::PALIGEMMA_448; config.model = Model::PALIGEMMA_448;
config.wrapping = PromptWrapping::PALIGEMMA;
AddVitConfig(config, /*image_size=*/448); AddVitConfig(config, /*image_size=*/448);
return config; return config;
} }
@ -305,6 +308,7 @@ static ModelConfig ConfigPaliGemma2_3B_224() {
ModelConfig config = ConfigGemma2_2B(); ModelConfig config = ConfigGemma2_2B();
config.model_name = "PaliGemma2_3B_224"; config.model_name = "PaliGemma2_3B_224";
config.model = Model::PALIGEMMA2_3B_224; config.model = Model::PALIGEMMA2_3B_224;
config.wrapping = PromptWrapping::PALIGEMMA;
AddVitConfig(config); AddVitConfig(config);
return config; return config;
} }
@ -313,6 +317,7 @@ static ModelConfig ConfigPaliGemma2_3B_448() {
ModelConfig config = ConfigGemma2_2B(); ModelConfig config = ConfigGemma2_2B();
config.model_name = "PaliGemma2_3B_448"; config.model_name = "PaliGemma2_3B_448";
config.model = Model::PALIGEMMA2_3B_448; config.model = Model::PALIGEMMA2_3B_448;
config.wrapping = PromptWrapping::PALIGEMMA;
AddVitConfig(config, /*image_size=*/448); AddVitConfig(config, /*image_size=*/448);
return config; return config;
} }
@ -321,6 +326,7 @@ static ModelConfig ConfigPaliGemma2_10B_224() {
ModelConfig config = ConfigGemma2_9B(); ModelConfig config = ConfigGemma2_9B();
config.model_name = "PaliGemma2_10B_224"; config.model_name = "PaliGemma2_10B_224";
config.model = Model::PALIGEMMA2_10B_224; config.model = Model::PALIGEMMA2_10B_224;
config.wrapping = PromptWrapping::PALIGEMMA;
AddVitConfig(config); AddVitConfig(config);
return config; return config;
} }
@ -329,6 +335,7 @@ static ModelConfig ConfigPaliGemma2_10B_448() {
ModelConfig config = ConfigGemma2_9B(); ModelConfig config = ConfigGemma2_9B();
config.model_name = "PaliGemma2_10B_448"; config.model_name = "PaliGemma2_10B_448";
config.model = Model::PALIGEMMA2_10B_448; config.model = Model::PALIGEMMA2_10B_448;
config.wrapping = PromptWrapping::PALIGEMMA;
AddVitConfig(config, /*image_size=*/448); AddVitConfig(config, /*image_size=*/448);
return config; return config;
} }
@ -360,6 +367,7 @@ static ModelConfig ConfigGemma3_1B() {
ModelConfig config = ConfigBaseGemmaV3(); ModelConfig config = ConfigBaseGemmaV3();
config.model_name = "Gemma3_1B"; config.model_name = "Gemma3_1B";
config.model = Model::GEMMA3_1B; config.model = Model::GEMMA3_1B;
config.wrapping = PromptWrapping::GEMMA_VLM;
config.model_dim = 1152; config.model_dim = 1152;
config.vocab_size = 262144; // new vocab size / tokenizer config.vocab_size = 262144; // new vocab size / tokenizer
config.seq_len = 32 * 1024; config.seq_len = 32 * 1024;
@ -391,6 +399,7 @@ static ModelConfig ConfigGemma3_4B_LM() {
ModelConfig config = ConfigBaseGemmaV3(); ModelConfig config = ConfigBaseGemmaV3();
config.model_name = "Gemma3_4B"; config.model_name = "Gemma3_4B";
config.model = Model::GEMMA3_4B; config.model = Model::GEMMA3_4B;
config.wrapping = PromptWrapping::GEMMA_VLM;
config.model_dim = 2560; config.model_dim = 2560;
config.vocab_size = 262144; // new vocab size / tokenizer config.vocab_size = 262144; // new vocab size / tokenizer
config.seq_len = 32 * 1024; config.seq_len = 32 * 1024;
@ -408,6 +417,7 @@ static ModelConfig ConfigGemma3_4B() {
ModelConfig config = ConfigGemma3_4B_LM(); ModelConfig config = ConfigGemma3_4B_LM();
config.model_name = "Gemma3_4B"; config.model_name = "Gemma3_4B";
config.model = Model::GEMMA3_4B; config.model = Model::GEMMA3_4B;
config.wrapping = PromptWrapping::GEMMA_VLM;
AddVitConfig(config, /*image_size=*/896); AddVitConfig(config, /*image_size=*/896);
config.vocab_size = 262144; config.vocab_size = 262144;
config.vit_config.pool_dim = 4; config.vit_config.pool_dim = 4;
@ -438,6 +448,7 @@ static ModelConfig ConfigGemma3_12B_LM() {
ModelConfig config = ConfigBaseGemmaV3(); ModelConfig config = ConfigBaseGemmaV3();
config.model_name = "Gemma3_12B"; config.model_name = "Gemma3_12B";
config.model = Model::GEMMA3_12B; config.model = Model::GEMMA3_12B;
config.wrapping = PromptWrapping::GEMMA_VLM;
config.model_dim = 3840; config.model_dim = 3840;
config.vocab_size = 262144; // new vocab size / tokenizer config.vocab_size = 262144; // new vocab size / tokenizer
config.seq_len = 32 * 1024; config.seq_len = 32 * 1024;
@ -455,6 +466,7 @@ static ModelConfig ConfigGemma3_12B() {
ModelConfig config = ConfigGemma3_12B_LM(); ModelConfig config = ConfigGemma3_12B_LM();
config.model_name = "Gemma3_12B"; config.model_name = "Gemma3_12B";
config.model = Model::GEMMA3_12B; config.model = Model::GEMMA3_12B;
config.wrapping = PromptWrapping::GEMMA_VLM;
AddVitConfig(config, /*image_size=*/896); AddVitConfig(config, /*image_size=*/896);
config.vocab_size = 262144; config.vocab_size = 262144;
config.vit_config.pool_dim = 4; config.vit_config.pool_dim = 4;
@ -485,6 +497,7 @@ static ModelConfig ConfigGemma3_27B_LM() {
ModelConfig config = ConfigBaseGemmaV3(); ModelConfig config = ConfigBaseGemmaV3();
config.model_name = "Gemma3_27B"; config.model_name = "Gemma3_27B";
config.model = Model::GEMMA3_27B; config.model = Model::GEMMA3_27B;
config.wrapping = PromptWrapping::GEMMA_VLM;
config.model_dim = 5376; config.model_dim = 5376;
config.vocab_size = 262144; // new vocab size / tokenizer config.vocab_size = 262144; // new vocab size / tokenizer
config.seq_len = 32 * 1024; config.seq_len = 32 * 1024;
@ -502,6 +515,7 @@ static ModelConfig ConfigGemma3_27B() {
ModelConfig config = ConfigGemma3_27B_LM(); ModelConfig config = ConfigGemma3_27B_LM();
config.model_name = "Gemma3_27B"; config.model_name = "Gemma3_27B";
config.model = Model::GEMMA3_27B; config.model = Model::GEMMA3_27B;
config.wrapping = PromptWrapping::GEMMA_VLM;
AddVitConfig(config, /*image_size=*/896); AddVitConfig(config, /*image_size=*/896);
config.vocab_size = 262144; config.vocab_size = 262144;
config.vit_config.pool_dim = 4; config.vit_config.pool_dim = 4;

View File

@ -19,10 +19,9 @@
// Model configurations // Model configurations
#include <stddef.h> #include <stddef.h>
#include <stdint.h>
#include <algorithm>
#include <array> #include <array>
#include <cstdint>
#include <string> #include <string>
#include <unordered_set> #include <unordered_set>
#include <vector> #include <vector>
@ -53,11 +52,27 @@ using EmbedderInputT = BF16;
enum class PromptWrapping { enum class PromptWrapping {
GEMMA_IT, GEMMA_IT,
GEMMA_PT, GEMMA_PT,
GEMMA_VLM, GEMMA_VLM, // for >1B Gemma3
PALIGEMMA, PALIGEMMA,
kSentinel // must be last kSentinel // must be last
}; };
// Defined as the suffix for use with `ModelString`.
static inline const char* ToString(PromptWrapping wrapping) {
switch (wrapping) {
case PromptWrapping::GEMMA_IT:
return "-it";
case PromptWrapping::GEMMA_PT:
return "-pt";
case PromptWrapping::GEMMA_VLM:
return "-vlm";
case PromptWrapping::PALIGEMMA:
return "-pg";
default:
return "-?";
}
}
static inline bool EnumValid(PromptWrapping wrapping) { static inline bool EnumValid(PromptWrapping wrapping) {
return static_cast<size_t>(wrapping) < return static_cast<size_t>(wrapping) <
static_cast<size_t>(PromptWrapping::kSentinel); static_cast<size_t>(PromptWrapping::kSentinel);
@ -69,63 +84,68 @@ enum class LayerAttentionType {
kVit, kVit,
}; };
inline bool EnumValid(LayerAttentionType type) { static inline bool EnumValid(LayerAttentionType type) {
return static_cast<int>(type) >= 0 && return type == LayerAttentionType::kGemma ||
static_cast<int>(type) <= static_cast<int>(LayerAttentionType::kVit); type == LayerAttentionType::kGriffinRecurrentBlock ||
type == LayerAttentionType::kVit;
} }
// Post attention and ffw normalization type. // Post attention and ffw normalization type.
enum class PostNormType { enum class PostNormType {
None, None,
Scale, Scale,
kSentinel // must be last
}; };
inline bool EnumValid(PostNormType type) { static inline bool EnumValid(PostNormType type) {
return static_cast<int>(type) >= 0 && return static_cast<size_t>(type) <
static_cast<int>(type) <= static_cast<int>(PostNormType::Scale); static_cast<size_t>(PostNormType::kSentinel);
} }
// Post qk projection operation type. // Post qk projection operation type.
enum class PostQKType { enum class PostQKType {
Rope, Rope,
HalfRope, HalfRope,
kSentinel // must be last
}; };
inline bool EnumValid(PostQKType type) { static inline bool EnumValid(PostQKType type) {
return static_cast<int>(type) >= 0 && return static_cast<size_t>(type) <
static_cast<int>(type) <= static_cast<int>(PostQKType::HalfRope); static_cast<size_t>(PostNormType::kSentinel);
} }
// FFW activation function. // FFW activation function.
enum class ActivationType { enum class ActivationType {
Gelu, Gelu,
kSentinel // must be last
}; };
inline bool EnumValid(ActivationType type) { static inline bool EnumValid(ActivationType type) {
return static_cast<int>(type) >= 0 && return static_cast<size_t>(type) <
static_cast<int>(type) <= static_cast<int>(ActivationType::Gelu); static_cast<size_t>(ActivationType::kSentinel);
} }
// Attention query scale. // Attention query scale.
enum class QueryScaleType { enum class QueryScaleType {
SqrtKeySize, SqrtKeySize,
SqrtModelDimDivNumHeads, SqrtModelDimDivNumHeads,
kSentinel // must be last
}; };
inline bool EnumValid(QueryScaleType type) { static inline bool EnumValid(QueryScaleType type) {
return static_cast<int>(type) >= 0 && return static_cast<size_t>(type) <
static_cast<int>(type) <= static_cast<size_t>(QueryScaleType::kSentinel);
static_cast<int>(QueryScaleType::SqrtModelDimDivNumHeads);
} }
// Residual connection type. // Residual connection type.
enum class ResidualType { enum class ResidualType {
Add, Add,
kSentinel // must be last
}; };
inline bool EnumValid(ResidualType type) { static inline bool EnumValid(ResidualType type) {
return static_cast<int>(type) >= 0 && return static_cast<size_t>(type) <
static_cast<int>(type) <= static_cast<int>(ResidualType::Add); static_cast<size_t>(ResidualType::kSentinel);
} }
template <size_t kNum> template <size_t kNum>
@ -169,6 +189,7 @@ enum class Model {
GEMMA3_1B, GEMMA3_1B,
GEMMA3_12B, GEMMA3_12B,
GEMMA3_27B, GEMMA3_27B,
kSentinel,
}; };
// Allows the Model enum to be iterated over. // Allows the Model enum to be iterated over.
@ -181,9 +202,18 @@ static constexpr Model kAllModels[] = {
Model::GEMMA3_12B, Model::GEMMA3_27B, Model::GEMMA3_12B, Model::GEMMA3_27B,
}; };
inline bool EnumValid(Model model) { template <class Func>
for (Model m : kAllModels) { void ForEachModel(const Func& func) {
if (m == model) return true; for (size_t i = static_cast<size_t>(Model::UNKNOWN) + 1;
i < static_cast<size_t>(Model::kSentinel); ++i) {
func(static_cast<Model>(i));
}
}
static inline bool EnumValid(Model model) {
const size_t i = static_cast<size_t>(model);
if (i < static_cast<size_t>(Model::kSentinel)) {
return true;
} }
return false; return false;
} }
@ -301,7 +331,7 @@ struct ModelConfig : public IFields {
size_t NumHeads() const { size_t NumHeads() const {
uint32_t num_heads = 0; uint32_t num_heads = 0;
for (const auto& layer_config : layer_configs) { for (const auto& layer_config : layer_configs) {
num_heads = std::max(num_heads, layer_config.heads); num_heads = HWY_MAX(num_heads, layer_config.heads);
} }
return num_heads; return num_heads;
} }

View File

@ -26,12 +26,12 @@
#include "compression/io.h" // Path #include "compression/io.h" // Path
#include "compression/shared.h" #include "compression/shared.h"
#include "gemma/common.h" #include "gemma/configs.h"
#include "gemma/gemma.h" // For CreateGemma #include "gemma/gemma.h" // For CreateGemma
#include "hwy/base.h" // HWY_ABORT
#include "ops/matmul.h" #include "ops/matmul.h"
#include "util/args.h" #include "util/args.h"
#include "util/basics.h" // Tristate #include "util/basics.h" // Tristate
#include "hwy/base.h" // HWY_ABORT
namespace gcpp { namespace gcpp {

View File

@ -17,7 +17,7 @@
#include <algorithm> #include <algorithm>
#include "gemma/common.h" // CallForModel #include "gemma/configs.h"
#include "hwy/aligned_allocator.h" #include "hwy/aligned_allocator.h"
#include "hwy/base.h" // ZeroBytes #include "hwy/base.h" // ZeroBytes

View File

@ -18,7 +18,7 @@
#include <stddef.h> #include <stddef.h>
#include "gemma/common.h" // Model #include "gemma/configs.h" // ModelConfig
#include "hwy/aligned_allocator.h" #include "hwy/aligned_allocator.h"
namespace gcpp { namespace gcpp {

View File

@ -31,7 +31,7 @@
#include <random> #include <random>
#include <vector> #include <vector>
#include "gemma/common.h" #include "gemma/common.h" // ChooseQueryScale
#include "util/allocator.h" #include "util/allocator.h"
#include "util/basics.h" // BF16 #include "util/basics.h" // BF16
#include "util/mat.h" // RowVectorBatch #include "util/mat.h" // RowVectorBatch

View File

@ -19,7 +19,6 @@
#include "compression/shared.h" #include "compression/shared.h"
#include "evals/benchmark_helper.h" #include "evals/benchmark_helper.h"
#include "gemma/common.h"
#include "gemma/configs.h" #include "gemma/configs.h"
#include "gemma/gemma.h" #include "gemma/gemma.h"
#include "util/allocator.h" #include "util/allocator.h"
@ -27,11 +26,7 @@
#include "hwy/tests/hwy_gtest.h" #include "hwy/tests/hwy_gtest.h"
// This test can be run manually with the downloaded PaliGemma weights. // This test can be run manually with the downloaded PaliGemma weights.
// To run the test, pass the following flags: // It should pass for `paligemma-3b-mix-224` and `paligemma2-3b-pt-448`.
// --model paligemma-224 --tokenizer <tokenizer_path> --weights <weights_path>
// or just use the single-file weights file with --weights <weights_path>.
// It should pass for the following models:
// paligemma-3b-mix-224, paligemma2-3b-pt-448
namespace gcpp { namespace gcpp {
namespace { namespace {

View File

@ -251,19 +251,11 @@ class MatPtrT : public MatPtr {
HWY_ASSERT(IsPacked()); HWY_ASSERT(IsPacked());
return MakeSpan(Row(0), num_elements_); return MakeSpan(Row(0), num_elements_);
} }
// For when a span of a single row is required. This also works if padded,
// but does not support `GetType() == kNUQ`, because that requires the use of
// offsets instead of a row pointer. Used by `gemma-inl.h` to decompress
// embeddings.
PackedSpan<const MatT> RowSpan(size_t row) const {
HWY_DASSERT(GetType() != Type::kNUQ);
return MakeConstSpan(Row(row), Cols());
}
}; };
// Calls `func` with a dynamic_cast of `MatPtr` to `MatPtrT<T>`, plus the // Calls `func` with a dynamic_cast of `MatPtr` to `MatPtrT<T>`, plus the
// optional `args`. // optional `args`. Currently unused but may be used after we move toward
// type-erased `WeightsPtrs`.
template <class Func, typename... Args> template <class Func, typename... Args>
decltype(auto) CallUpcasted(Type type, MatPtr* base, const Func& func, decltype(auto) CallUpcasted(Type type, MatPtr* base, const Func& func,
Args&&... args) { Args&&... args) {

View File

@ -118,6 +118,7 @@ class ThreadingContext2 {
// changing the arguments between tests. Callers must again call `Get` // changing the arguments between tests. Callers must again call `Get`
// afterwards to obtain an instance. WARNING: must not be called concurrently // afterwards to obtain an instance. WARNING: must not be called concurrently
// with other calls to `Get` and usages of its return value. // with other calls to `Get` and usages of its return value.
// Also useful to suppress memory leak warnings in tests.
static void ThreadHostileInvalidate(); static void ThreadHostileInvalidate();
explicit ThreadingContext2(PrivateToken); // only called via `Get`. explicit ThreadingContext2(PrivateToken); // only called via `Get`.