Cleanup: include fixes/comments, fix leak, vector reserve

Also remove unused RowSpan
configs.cc: Assign prompt wrapping to ModelConfig
configs.h: simplify EnumValid via sentinel

PiperOrigin-RevId: 750278497
This commit is contained in:
Jan Wassenberg 2025-04-22 12:01:00 -07:00 committed by Copybara-Service
parent ba10c88a94
commit 160a5824fb
16 changed files with 91 additions and 60 deletions

View File

@ -443,22 +443,19 @@ cc_library(
"mem": "28g",
},
deps = [
":allocator",
":basics",
":benchmark_helper",
":common",
":gemma_args",
":gemma_lib",
":kv_cache",
":mat",
":ops",
":threading",
":threading_context",
":tokenizer",
":weights",
"//compression:shared",
"//paligemma:image",
"@highway//:hwy",
"@highway//:profiler",
"@highway//:timer",
],
)

View File

@ -27,7 +27,8 @@
#include "backprop/activations.h"
#include "backprop/prompt.h"
#include "gemma/common.h"
#include "gemma/common.h" // EmbeddingScaling
#include "gemma/configs.h" // LayerConfig, ModelConfig
#include "gemma/weights.h"
#include "util/allocator.h"
#include "hwy/base.h"

View File

@ -24,7 +24,7 @@
#include <vector>
#include "backprop/activations.h"
#include "gemma/common.h"
#include "gemma/common.h" // EmbeddingScaling
#include "gemma/configs.h"
#include "gemma/weights.h"
#include "util/allocator.h"

View File

@ -12,7 +12,6 @@
#include "compression/io.h" // Path
#include "evals/benchmark_helper.h"
#include "evals/cross_entropy.h"
#include "gemma/common.h"
#include "gemma/gemma.h"
#include "util/args.h"
#include "hwy/base.h"

View File

@ -49,6 +49,8 @@ class GemmaEnv {
GemmaEnv(int argc, char** argv);
GemmaEnv(const ThreadingArgs& threading, const LoaderArgs& loader,
const InferenceArgs& inference);
// Avoid memory leaks in test.
~GemmaEnv() { ThreadingContext2::ThreadHostileInvalidate(); }
MatMulEnv& Env() { return env_; }

View File

@ -38,7 +38,6 @@
#include <vector>
#include "evals/cross_entropy.h"
#include "gemma/common.h"
#include "gemma/gemma.h"
#include "hwy/base.h"

View File

@ -13,15 +13,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "gemma/gemma.h"
#include <stdio.h>
#include <string>
#include <vector>
#include "evals/benchmark_helper.h"
#include "gemma/common.h"
#include "gemma/configs.h"
#include "gemma/gemma.h"
#include "hwy/base.h"
#include "hwy/tests/hwy_gtest.h"
@ -65,6 +64,7 @@ class GemmaTest : public ::testing::Test {
prompts_vector.push_back(s_env->TokenizeAndPrependBOS(input_string));
}
std::vector<PromptTokens> prompt_spans;
prompt_spans.reserve(prompts_vector.size());
for (const auto& prompt : prompts_vector) {
prompt_spans.push_back(PromptTokens(prompt.data(), prompt.size()));
}
@ -79,6 +79,7 @@ class GemmaTest : public ::testing::Test {
ASSERT_NE(s_env->GetGemma(), nullptr);
std::vector<std::string> inputs;
inputs.reserve(num_questions);
for (size_t i = 0; i < num_questions; ++i) {
inputs.push_back(kQA[i]);
}

View File

@ -187,6 +187,7 @@ static ModelConfig ConfigGemmaTiny() {
ModelConfig config = ConfigNoSSM();
config.model_name = "GemmaTiny";
config.model = Model::GEMMA_TINY;
config.wrapping = PromptWrapping::GEMMA_IT;
config.model_dim = 128;
config.vocab_size = 64;
config.seq_len = 32;
@ -277,6 +278,7 @@ static ModelConfig ConfigPaliGemma_224() {
ModelConfig config = ConfigGemma2B();
config.model_name = "PaliGemma_224";
config.model = Model::PALIGEMMA_224;
config.wrapping = PromptWrapping::PALIGEMMA;
AddVitConfig(config);
return config;
}
@ -285,6 +287,7 @@ static ModelConfig ConfigPaliGemma_448() {
ModelConfig config = ConfigGemma2B();
config.model_name = "PaliGemma_448";
config.model = Model::PALIGEMMA_448;
config.wrapping = PromptWrapping::PALIGEMMA;
AddVitConfig(config, /*image_size=*/448);
return config;
}
@ -305,6 +308,7 @@ static ModelConfig ConfigPaliGemma2_3B_224() {
ModelConfig config = ConfigGemma2_2B();
config.model_name = "PaliGemma2_3B_224";
config.model = Model::PALIGEMMA2_3B_224;
config.wrapping = PromptWrapping::PALIGEMMA;
AddVitConfig(config);
return config;
}
@ -313,6 +317,7 @@ static ModelConfig ConfigPaliGemma2_3B_448() {
ModelConfig config = ConfigGemma2_2B();
config.model_name = "PaliGemma2_3B_448";
config.model = Model::PALIGEMMA2_3B_448;
config.wrapping = PromptWrapping::PALIGEMMA;
AddVitConfig(config, /*image_size=*/448);
return config;
}
@ -321,6 +326,7 @@ static ModelConfig ConfigPaliGemma2_10B_224() {
ModelConfig config = ConfigGemma2_9B();
config.model_name = "PaliGemma2_10B_224";
config.model = Model::PALIGEMMA2_10B_224;
config.wrapping = PromptWrapping::PALIGEMMA;
AddVitConfig(config);
return config;
}
@ -329,6 +335,7 @@ static ModelConfig ConfigPaliGemma2_10B_448() {
ModelConfig config = ConfigGemma2_9B();
config.model_name = "PaliGemma2_10B_448";
config.model = Model::PALIGEMMA2_10B_448;
config.wrapping = PromptWrapping::PALIGEMMA;
AddVitConfig(config, /*image_size=*/448);
return config;
}
@ -360,6 +367,7 @@ static ModelConfig ConfigGemma3_1B() {
ModelConfig config = ConfigBaseGemmaV3();
config.model_name = "Gemma3_1B";
config.model = Model::GEMMA3_1B;
config.wrapping = PromptWrapping::GEMMA_VLM;
config.model_dim = 1152;
config.vocab_size = 262144; // new vocab size / tokenizer
config.seq_len = 32 * 1024;
@ -391,6 +399,7 @@ static ModelConfig ConfigGemma3_4B_LM() {
ModelConfig config = ConfigBaseGemmaV3();
config.model_name = "Gemma3_4B";
config.model = Model::GEMMA3_4B;
config.wrapping = PromptWrapping::GEMMA_VLM;
config.model_dim = 2560;
config.vocab_size = 262144; // new vocab size / tokenizer
config.seq_len = 32 * 1024;
@ -408,6 +417,7 @@ static ModelConfig ConfigGemma3_4B() {
ModelConfig config = ConfigGemma3_4B_LM();
config.model_name = "Gemma3_4B";
config.model = Model::GEMMA3_4B;
config.wrapping = PromptWrapping::GEMMA_VLM;
AddVitConfig(config, /*image_size=*/896);
config.vocab_size = 262144;
config.vit_config.pool_dim = 4;
@ -438,6 +448,7 @@ static ModelConfig ConfigGemma3_12B_LM() {
ModelConfig config = ConfigBaseGemmaV3();
config.model_name = "Gemma3_12B";
config.model = Model::GEMMA3_12B;
config.wrapping = PromptWrapping::GEMMA_VLM;
config.model_dim = 3840;
config.vocab_size = 262144; // new vocab size / tokenizer
config.seq_len = 32 * 1024;
@ -455,6 +466,7 @@ static ModelConfig ConfigGemma3_12B() {
ModelConfig config = ConfigGemma3_12B_LM();
config.model_name = "Gemma3_12B";
config.model = Model::GEMMA3_12B;
config.wrapping = PromptWrapping::GEMMA_VLM;
AddVitConfig(config, /*image_size=*/896);
config.vocab_size = 262144;
config.vit_config.pool_dim = 4;
@ -485,6 +497,7 @@ static ModelConfig ConfigGemma3_27B_LM() {
ModelConfig config = ConfigBaseGemmaV3();
config.model_name = "Gemma3_27B";
config.model = Model::GEMMA3_27B;
config.wrapping = PromptWrapping::GEMMA_VLM;
config.model_dim = 5376;
config.vocab_size = 262144; // new vocab size / tokenizer
config.seq_len = 32 * 1024;
@ -502,6 +515,7 @@ static ModelConfig ConfigGemma3_27B() {
ModelConfig config = ConfigGemma3_27B_LM();
config.model_name = "Gemma3_27B";
config.model = Model::GEMMA3_27B;
config.wrapping = PromptWrapping::GEMMA_VLM;
AddVitConfig(config, /*image_size=*/896);
config.vocab_size = 262144;
config.vit_config.pool_dim = 4;

View File

@ -19,10 +19,9 @@
// Model configurations
#include <stddef.h>
#include <stdint.h>
#include <algorithm>
#include <array>
#include <cstdint>
#include <string>
#include <unordered_set>
#include <vector>
@ -53,11 +52,27 @@ using EmbedderInputT = BF16;
enum class PromptWrapping {
GEMMA_IT,
GEMMA_PT,
GEMMA_VLM,
GEMMA_VLM, // for >1B Gemma3
PALIGEMMA,
kSentinel // must be last
};
// Defined as the suffix for use with `ModelString`.
static inline const char* ToString(PromptWrapping wrapping) {
switch (wrapping) {
case PromptWrapping::GEMMA_IT:
return "-it";
case PromptWrapping::GEMMA_PT:
return "-pt";
case PromptWrapping::GEMMA_VLM:
return "-vlm";
case PromptWrapping::PALIGEMMA:
return "-pg";
default:
return "-?";
}
}
static inline bool EnumValid(PromptWrapping wrapping) {
return static_cast<size_t>(wrapping) <
static_cast<size_t>(PromptWrapping::kSentinel);
@ -69,63 +84,68 @@ enum class LayerAttentionType {
kVit,
};
inline bool EnumValid(LayerAttentionType type) {
return static_cast<int>(type) >= 0 &&
static_cast<int>(type) <= static_cast<int>(LayerAttentionType::kVit);
static inline bool EnumValid(LayerAttentionType type) {
return type == LayerAttentionType::kGemma ||
type == LayerAttentionType::kGriffinRecurrentBlock ||
type == LayerAttentionType::kVit;
}
// Post attention and ffw normalization type.
enum class PostNormType {
None,
Scale,
kSentinel // must be last
};
inline bool EnumValid(PostNormType type) {
return static_cast<int>(type) >= 0 &&
static_cast<int>(type) <= static_cast<int>(PostNormType::Scale);
static inline bool EnumValid(PostNormType type) {
return static_cast<size_t>(type) <
static_cast<size_t>(PostNormType::kSentinel);
}
// Post qk projection operation type.
enum class PostQKType {
Rope,
HalfRope,
kSentinel // must be last
};
inline bool EnumValid(PostQKType type) {
return static_cast<int>(type) >= 0 &&
static_cast<int>(type) <= static_cast<int>(PostQKType::HalfRope);
static inline bool EnumValid(PostQKType type) {
return static_cast<size_t>(type) <
static_cast<size_t>(PostNormType::kSentinel);
}
// FFW activation function.
enum class ActivationType {
Gelu,
kSentinel // must be last
};
inline bool EnumValid(ActivationType type) {
return static_cast<int>(type) >= 0 &&
static_cast<int>(type) <= static_cast<int>(ActivationType::Gelu);
static inline bool EnumValid(ActivationType type) {
return static_cast<size_t>(type) <
static_cast<size_t>(ActivationType::kSentinel);
}
// Attention query scale.
enum class QueryScaleType {
SqrtKeySize,
SqrtModelDimDivNumHeads,
kSentinel // must be last
};
inline bool EnumValid(QueryScaleType type) {
return static_cast<int>(type) >= 0 &&
static_cast<int>(type) <=
static_cast<int>(QueryScaleType::SqrtModelDimDivNumHeads);
static inline bool EnumValid(QueryScaleType type) {
return static_cast<size_t>(type) <
static_cast<size_t>(QueryScaleType::kSentinel);
}
// Residual connection type.
enum class ResidualType {
Add,
kSentinel // must be last
};
inline bool EnumValid(ResidualType type) {
return static_cast<int>(type) >= 0 &&
static_cast<int>(type) <= static_cast<int>(ResidualType::Add);
static inline bool EnumValid(ResidualType type) {
return static_cast<size_t>(type) <
static_cast<size_t>(ResidualType::kSentinel);
}
template <size_t kNum>
@ -169,6 +189,7 @@ enum class Model {
GEMMA3_1B,
GEMMA3_12B,
GEMMA3_27B,
kSentinel,
};
// Allows the Model enum to be iterated over.
@ -181,9 +202,18 @@ static constexpr Model kAllModels[] = {
Model::GEMMA3_12B, Model::GEMMA3_27B,
};
inline bool EnumValid(Model model) {
for (Model m : kAllModels) {
if (m == model) return true;
template <class Func>
void ForEachModel(const Func& func) {
for (size_t i = static_cast<size_t>(Model::UNKNOWN) + 1;
i < static_cast<size_t>(Model::kSentinel); ++i) {
func(static_cast<Model>(i));
}
}
static inline bool EnumValid(Model model) {
const size_t i = static_cast<size_t>(model);
if (i < static_cast<size_t>(Model::kSentinel)) {
return true;
}
return false;
}
@ -301,7 +331,7 @@ struct ModelConfig : public IFields {
size_t NumHeads() const {
uint32_t num_heads = 0;
for (const auto& layer_config : layer_configs) {
num_heads = std::max(num_heads, layer_config.heads);
num_heads = HWY_MAX(num_heads, layer_config.heads);
}
return num_heads;
}

View File

@ -26,12 +26,12 @@
#include "compression/io.h" // Path
#include "compression/shared.h"
#include "gemma/common.h"
#include "gemma/configs.h"
#include "gemma/gemma.h" // For CreateGemma
#include "hwy/base.h" // HWY_ABORT
#include "ops/matmul.h"
#include "util/args.h"
#include "util/basics.h" // Tristate
#include "hwy/base.h" // HWY_ABORT
namespace gcpp {
@ -237,4 +237,4 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_ARGS_H_
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_ARGS_H_

View File

@ -17,7 +17,7 @@
#include <algorithm>
#include "gemma/common.h" // CallForModel
#include "gemma/configs.h"
#include "hwy/aligned_allocator.h"
#include "hwy/base.h" // ZeroBytes

View File

@ -18,7 +18,7 @@
#include <stddef.h>
#include "gemma/common.h" // Model
#include "gemma/configs.h" // ModelConfig
#include "hwy/aligned_allocator.h"
namespace gcpp {

View File

@ -31,7 +31,7 @@
#include <random>
#include <vector>
#include "gemma/common.h"
#include "gemma/common.h" // ChooseQueryScale
#include "util/allocator.h"
#include "util/basics.h" // BF16
#include "util/mat.h" // RowVectorBatch

View File

@ -19,7 +19,6 @@
#include "compression/shared.h"
#include "evals/benchmark_helper.h"
#include "gemma/common.h"
#include "gemma/configs.h"
#include "gemma/gemma.h"
#include "util/allocator.h"
@ -27,11 +26,7 @@
#include "hwy/tests/hwy_gtest.h"
// This test can be run manually with the downloaded PaliGemma weights.
// To run the test, pass the following flags:
// --model paligemma-224 --tokenizer <tokenizer_path> --weights <weights_path>
// or just use the single-file weights file with --weights <weights_path>.
// It should pass for the following models:
// paligemma-3b-mix-224, paligemma2-3b-pt-448
// It should pass for `paligemma-3b-mix-224` and `paligemma2-3b-pt-448`.
namespace gcpp {
namespace {

View File

@ -251,19 +251,11 @@ class MatPtrT : public MatPtr {
HWY_ASSERT(IsPacked());
return MakeSpan(Row(0), num_elements_);
}
// For when a span of a single row is required. This also works if padded,
// but does not support `GetType() == kNUQ`, because that requires the use of
// offsets instead of a row pointer. Used by `gemma-inl.h` to decompress
// embeddings.
PackedSpan<const MatT> RowSpan(size_t row) const {
HWY_DASSERT(GetType() != Type::kNUQ);
return MakeConstSpan(Row(row), Cols());
}
};
// Calls `func` with a dynamic_cast of `MatPtr` to `MatPtrT<T>`, plus the
// optional `args`.
// optional `args`. Currently unused but may be used after we move toward
// type-erased `WeightsPtrs`.
template <class Func, typename... Args>
decltype(auto) CallUpcasted(Type type, MatPtr* base, const Func& func,
Args&&... args) {

View File

@ -118,6 +118,7 @@ class ThreadingContext2 {
// changing the arguments between tests. Callers must again call `Get`
// afterwards to obtain an instance. WARNING: must not be called concurrently
// with other calls to `Get` and usages of its return value.
// Also useful to suppress memory leak warnings in tests.
static void ThreadHostileInvalidate();
explicit ThreadingContext2(PrivateToken); // only called via `Get`.