mirror of https://github.com/google/gemma.cpp.git
Cleanup: include fixes/comments, fix leak, vector reserve
Also remove unused RowSpan configs.cc: Assign prompt wrapping to ModelConfig configs.h: simplify EnumValid via sentinel PiperOrigin-RevId: 750278497
This commit is contained in:
parent
ba10c88a94
commit
160a5824fb
|
|
@ -443,22 +443,19 @@ cc_library(
|
|||
"mem": "28g",
|
||||
},
|
||||
deps = [
|
||||
":allocator",
|
||||
":basics",
|
||||
":benchmark_helper",
|
||||
":common",
|
||||
":gemma_args",
|
||||
":gemma_lib",
|
||||
":kv_cache",
|
||||
":mat",
|
||||
":ops",
|
||||
":threading",
|
||||
":threading_context",
|
||||
":tokenizer",
|
||||
":weights",
|
||||
"//compression:shared",
|
||||
"//paligemma:image",
|
||||
"@highway//:hwy",
|
||||
"@highway//:profiler",
|
||||
"@highway//:timer",
|
||||
],
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -27,7 +27,8 @@
|
|||
|
||||
#include "backprop/activations.h"
|
||||
#include "backprop/prompt.h"
|
||||
#include "gemma/common.h"
|
||||
#include "gemma/common.h" // EmbeddingScaling
|
||||
#include "gemma/configs.h" // LayerConfig, ModelConfig
|
||||
#include "gemma/weights.h"
|
||||
#include "util/allocator.h"
|
||||
#include "hwy/base.h"
|
||||
|
|
|
|||
|
|
@ -24,7 +24,7 @@
|
|||
#include <vector>
|
||||
|
||||
#include "backprop/activations.h"
|
||||
#include "gemma/common.h"
|
||||
#include "gemma/common.h" // EmbeddingScaling
|
||||
#include "gemma/configs.h"
|
||||
#include "gemma/weights.h"
|
||||
#include "util/allocator.h"
|
||||
|
|
|
|||
|
|
@ -12,7 +12,6 @@
|
|||
#include "compression/io.h" // Path
|
||||
#include "evals/benchmark_helper.h"
|
||||
#include "evals/cross_entropy.h"
|
||||
#include "gemma/common.h"
|
||||
#include "gemma/gemma.h"
|
||||
#include "util/args.h"
|
||||
#include "hwy/base.h"
|
||||
|
|
|
|||
|
|
@ -49,6 +49,8 @@ class GemmaEnv {
|
|||
GemmaEnv(int argc, char** argv);
|
||||
GemmaEnv(const ThreadingArgs& threading, const LoaderArgs& loader,
|
||||
const InferenceArgs& inference);
|
||||
// Avoid memory leaks in test.
|
||||
~GemmaEnv() { ThreadingContext2::ThreadHostileInvalidate(); }
|
||||
|
||||
MatMulEnv& Env() { return env_; }
|
||||
|
||||
|
|
|
|||
|
|
@ -38,7 +38,6 @@
|
|||
#include <vector>
|
||||
|
||||
#include "evals/cross_entropy.h"
|
||||
#include "gemma/common.h"
|
||||
#include "gemma/gemma.h"
|
||||
#include "hwy/base.h"
|
||||
|
||||
|
|
|
|||
|
|
@ -13,15 +13,14 @@
|
|||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "gemma/gemma.h"
|
||||
|
||||
#include <stdio.h>
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "evals/benchmark_helper.h"
|
||||
#include "gemma/common.h"
|
||||
#include "gemma/configs.h"
|
||||
#include "gemma/gemma.h"
|
||||
#include "hwy/base.h"
|
||||
#include "hwy/tests/hwy_gtest.h"
|
||||
|
||||
|
|
@ -65,6 +64,7 @@ class GemmaTest : public ::testing::Test {
|
|||
prompts_vector.push_back(s_env->TokenizeAndPrependBOS(input_string));
|
||||
}
|
||||
std::vector<PromptTokens> prompt_spans;
|
||||
prompt_spans.reserve(prompts_vector.size());
|
||||
for (const auto& prompt : prompts_vector) {
|
||||
prompt_spans.push_back(PromptTokens(prompt.data(), prompt.size()));
|
||||
}
|
||||
|
|
@ -79,6 +79,7 @@ class GemmaTest : public ::testing::Test {
|
|||
ASSERT_NE(s_env->GetGemma(), nullptr);
|
||||
|
||||
std::vector<std::string> inputs;
|
||||
inputs.reserve(num_questions);
|
||||
for (size_t i = 0; i < num_questions; ++i) {
|
||||
inputs.push_back(kQA[i]);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -187,6 +187,7 @@ static ModelConfig ConfigGemmaTiny() {
|
|||
ModelConfig config = ConfigNoSSM();
|
||||
config.model_name = "GemmaTiny";
|
||||
config.model = Model::GEMMA_TINY;
|
||||
config.wrapping = PromptWrapping::GEMMA_IT;
|
||||
config.model_dim = 128;
|
||||
config.vocab_size = 64;
|
||||
config.seq_len = 32;
|
||||
|
|
@ -277,6 +278,7 @@ static ModelConfig ConfigPaliGemma_224() {
|
|||
ModelConfig config = ConfigGemma2B();
|
||||
config.model_name = "PaliGemma_224";
|
||||
config.model = Model::PALIGEMMA_224;
|
||||
config.wrapping = PromptWrapping::PALIGEMMA;
|
||||
AddVitConfig(config);
|
||||
return config;
|
||||
}
|
||||
|
|
@ -285,6 +287,7 @@ static ModelConfig ConfigPaliGemma_448() {
|
|||
ModelConfig config = ConfigGemma2B();
|
||||
config.model_name = "PaliGemma_448";
|
||||
config.model = Model::PALIGEMMA_448;
|
||||
config.wrapping = PromptWrapping::PALIGEMMA;
|
||||
AddVitConfig(config, /*image_size=*/448);
|
||||
return config;
|
||||
}
|
||||
|
|
@ -305,6 +308,7 @@ static ModelConfig ConfigPaliGemma2_3B_224() {
|
|||
ModelConfig config = ConfigGemma2_2B();
|
||||
config.model_name = "PaliGemma2_3B_224";
|
||||
config.model = Model::PALIGEMMA2_3B_224;
|
||||
config.wrapping = PromptWrapping::PALIGEMMA;
|
||||
AddVitConfig(config);
|
||||
return config;
|
||||
}
|
||||
|
|
@ -313,6 +317,7 @@ static ModelConfig ConfigPaliGemma2_3B_448() {
|
|||
ModelConfig config = ConfigGemma2_2B();
|
||||
config.model_name = "PaliGemma2_3B_448";
|
||||
config.model = Model::PALIGEMMA2_3B_448;
|
||||
config.wrapping = PromptWrapping::PALIGEMMA;
|
||||
AddVitConfig(config, /*image_size=*/448);
|
||||
return config;
|
||||
}
|
||||
|
|
@ -321,6 +326,7 @@ static ModelConfig ConfigPaliGemma2_10B_224() {
|
|||
ModelConfig config = ConfigGemma2_9B();
|
||||
config.model_name = "PaliGemma2_10B_224";
|
||||
config.model = Model::PALIGEMMA2_10B_224;
|
||||
config.wrapping = PromptWrapping::PALIGEMMA;
|
||||
AddVitConfig(config);
|
||||
return config;
|
||||
}
|
||||
|
|
@ -329,6 +335,7 @@ static ModelConfig ConfigPaliGemma2_10B_448() {
|
|||
ModelConfig config = ConfigGemma2_9B();
|
||||
config.model_name = "PaliGemma2_10B_448";
|
||||
config.model = Model::PALIGEMMA2_10B_448;
|
||||
config.wrapping = PromptWrapping::PALIGEMMA;
|
||||
AddVitConfig(config, /*image_size=*/448);
|
||||
return config;
|
||||
}
|
||||
|
|
@ -360,6 +367,7 @@ static ModelConfig ConfigGemma3_1B() {
|
|||
ModelConfig config = ConfigBaseGemmaV3();
|
||||
config.model_name = "Gemma3_1B";
|
||||
config.model = Model::GEMMA3_1B;
|
||||
config.wrapping = PromptWrapping::GEMMA_VLM;
|
||||
config.model_dim = 1152;
|
||||
config.vocab_size = 262144; // new vocab size / tokenizer
|
||||
config.seq_len = 32 * 1024;
|
||||
|
|
@ -391,6 +399,7 @@ static ModelConfig ConfigGemma3_4B_LM() {
|
|||
ModelConfig config = ConfigBaseGemmaV3();
|
||||
config.model_name = "Gemma3_4B";
|
||||
config.model = Model::GEMMA3_4B;
|
||||
config.wrapping = PromptWrapping::GEMMA_VLM;
|
||||
config.model_dim = 2560;
|
||||
config.vocab_size = 262144; // new vocab size / tokenizer
|
||||
config.seq_len = 32 * 1024;
|
||||
|
|
@ -408,6 +417,7 @@ static ModelConfig ConfigGemma3_4B() {
|
|||
ModelConfig config = ConfigGemma3_4B_LM();
|
||||
config.model_name = "Gemma3_4B";
|
||||
config.model = Model::GEMMA3_4B;
|
||||
config.wrapping = PromptWrapping::GEMMA_VLM;
|
||||
AddVitConfig(config, /*image_size=*/896);
|
||||
config.vocab_size = 262144;
|
||||
config.vit_config.pool_dim = 4;
|
||||
|
|
@ -438,6 +448,7 @@ static ModelConfig ConfigGemma3_12B_LM() {
|
|||
ModelConfig config = ConfigBaseGemmaV3();
|
||||
config.model_name = "Gemma3_12B";
|
||||
config.model = Model::GEMMA3_12B;
|
||||
config.wrapping = PromptWrapping::GEMMA_VLM;
|
||||
config.model_dim = 3840;
|
||||
config.vocab_size = 262144; // new vocab size / tokenizer
|
||||
config.seq_len = 32 * 1024;
|
||||
|
|
@ -455,6 +466,7 @@ static ModelConfig ConfigGemma3_12B() {
|
|||
ModelConfig config = ConfigGemma3_12B_LM();
|
||||
config.model_name = "Gemma3_12B";
|
||||
config.model = Model::GEMMA3_12B;
|
||||
config.wrapping = PromptWrapping::GEMMA_VLM;
|
||||
AddVitConfig(config, /*image_size=*/896);
|
||||
config.vocab_size = 262144;
|
||||
config.vit_config.pool_dim = 4;
|
||||
|
|
@ -485,6 +497,7 @@ static ModelConfig ConfigGemma3_27B_LM() {
|
|||
ModelConfig config = ConfigBaseGemmaV3();
|
||||
config.model_name = "Gemma3_27B";
|
||||
config.model = Model::GEMMA3_27B;
|
||||
config.wrapping = PromptWrapping::GEMMA_VLM;
|
||||
config.model_dim = 5376;
|
||||
config.vocab_size = 262144; // new vocab size / tokenizer
|
||||
config.seq_len = 32 * 1024;
|
||||
|
|
@ -502,6 +515,7 @@ static ModelConfig ConfigGemma3_27B() {
|
|||
ModelConfig config = ConfigGemma3_27B_LM();
|
||||
config.model_name = "Gemma3_27B";
|
||||
config.model = Model::GEMMA3_27B;
|
||||
config.wrapping = PromptWrapping::GEMMA_VLM;
|
||||
AddVitConfig(config, /*image_size=*/896);
|
||||
config.vocab_size = 262144;
|
||||
config.vit_config.pool_dim = 4;
|
||||
|
|
|
|||
|
|
@ -19,10 +19,9 @@
|
|||
// Model configurations
|
||||
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <array>
|
||||
#include <cstdint>
|
||||
#include <string>
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
|
|
@ -53,11 +52,27 @@ using EmbedderInputT = BF16;
|
|||
enum class PromptWrapping {
|
||||
GEMMA_IT,
|
||||
GEMMA_PT,
|
||||
GEMMA_VLM,
|
||||
GEMMA_VLM, // for >1B Gemma3
|
||||
PALIGEMMA,
|
||||
kSentinel // must be last
|
||||
};
|
||||
|
||||
// Defined as the suffix for use with `ModelString`.
|
||||
static inline const char* ToString(PromptWrapping wrapping) {
|
||||
switch (wrapping) {
|
||||
case PromptWrapping::GEMMA_IT:
|
||||
return "-it";
|
||||
case PromptWrapping::GEMMA_PT:
|
||||
return "-pt";
|
||||
case PromptWrapping::GEMMA_VLM:
|
||||
return "-vlm";
|
||||
case PromptWrapping::PALIGEMMA:
|
||||
return "-pg";
|
||||
default:
|
||||
return "-?";
|
||||
}
|
||||
}
|
||||
|
||||
static inline bool EnumValid(PromptWrapping wrapping) {
|
||||
return static_cast<size_t>(wrapping) <
|
||||
static_cast<size_t>(PromptWrapping::kSentinel);
|
||||
|
|
@ -69,63 +84,68 @@ enum class LayerAttentionType {
|
|||
kVit,
|
||||
};
|
||||
|
||||
inline bool EnumValid(LayerAttentionType type) {
|
||||
return static_cast<int>(type) >= 0 &&
|
||||
static_cast<int>(type) <= static_cast<int>(LayerAttentionType::kVit);
|
||||
static inline bool EnumValid(LayerAttentionType type) {
|
||||
return type == LayerAttentionType::kGemma ||
|
||||
type == LayerAttentionType::kGriffinRecurrentBlock ||
|
||||
type == LayerAttentionType::kVit;
|
||||
}
|
||||
|
||||
// Post attention and ffw normalization type.
|
||||
enum class PostNormType {
|
||||
None,
|
||||
Scale,
|
||||
kSentinel // must be last
|
||||
};
|
||||
|
||||
inline bool EnumValid(PostNormType type) {
|
||||
return static_cast<int>(type) >= 0 &&
|
||||
static_cast<int>(type) <= static_cast<int>(PostNormType::Scale);
|
||||
static inline bool EnumValid(PostNormType type) {
|
||||
return static_cast<size_t>(type) <
|
||||
static_cast<size_t>(PostNormType::kSentinel);
|
||||
}
|
||||
|
||||
// Post qk projection operation type.
|
||||
enum class PostQKType {
|
||||
Rope,
|
||||
HalfRope,
|
||||
kSentinel // must be last
|
||||
};
|
||||
|
||||
inline bool EnumValid(PostQKType type) {
|
||||
return static_cast<int>(type) >= 0 &&
|
||||
static_cast<int>(type) <= static_cast<int>(PostQKType::HalfRope);
|
||||
static inline bool EnumValid(PostQKType type) {
|
||||
return static_cast<size_t>(type) <
|
||||
static_cast<size_t>(PostNormType::kSentinel);
|
||||
}
|
||||
|
||||
// FFW activation function.
|
||||
enum class ActivationType {
|
||||
Gelu,
|
||||
kSentinel // must be last
|
||||
};
|
||||
|
||||
inline bool EnumValid(ActivationType type) {
|
||||
return static_cast<int>(type) >= 0 &&
|
||||
static_cast<int>(type) <= static_cast<int>(ActivationType::Gelu);
|
||||
static inline bool EnumValid(ActivationType type) {
|
||||
return static_cast<size_t>(type) <
|
||||
static_cast<size_t>(ActivationType::kSentinel);
|
||||
}
|
||||
|
||||
// Attention query scale.
|
||||
enum class QueryScaleType {
|
||||
SqrtKeySize,
|
||||
SqrtModelDimDivNumHeads,
|
||||
kSentinel // must be last
|
||||
};
|
||||
|
||||
inline bool EnumValid(QueryScaleType type) {
|
||||
return static_cast<int>(type) >= 0 &&
|
||||
static_cast<int>(type) <=
|
||||
static_cast<int>(QueryScaleType::SqrtModelDimDivNumHeads);
|
||||
static inline bool EnumValid(QueryScaleType type) {
|
||||
return static_cast<size_t>(type) <
|
||||
static_cast<size_t>(QueryScaleType::kSentinel);
|
||||
}
|
||||
|
||||
// Residual connection type.
|
||||
enum class ResidualType {
|
||||
Add,
|
||||
kSentinel // must be last
|
||||
};
|
||||
|
||||
inline bool EnumValid(ResidualType type) {
|
||||
return static_cast<int>(type) >= 0 &&
|
||||
static_cast<int>(type) <= static_cast<int>(ResidualType::Add);
|
||||
static inline bool EnumValid(ResidualType type) {
|
||||
return static_cast<size_t>(type) <
|
||||
static_cast<size_t>(ResidualType::kSentinel);
|
||||
}
|
||||
|
||||
template <size_t kNum>
|
||||
|
|
@ -169,6 +189,7 @@ enum class Model {
|
|||
GEMMA3_1B,
|
||||
GEMMA3_12B,
|
||||
GEMMA3_27B,
|
||||
kSentinel,
|
||||
};
|
||||
|
||||
// Allows the Model enum to be iterated over.
|
||||
|
|
@ -181,9 +202,18 @@ static constexpr Model kAllModels[] = {
|
|||
Model::GEMMA3_12B, Model::GEMMA3_27B,
|
||||
};
|
||||
|
||||
inline bool EnumValid(Model model) {
|
||||
for (Model m : kAllModels) {
|
||||
if (m == model) return true;
|
||||
template <class Func>
|
||||
void ForEachModel(const Func& func) {
|
||||
for (size_t i = static_cast<size_t>(Model::UNKNOWN) + 1;
|
||||
i < static_cast<size_t>(Model::kSentinel); ++i) {
|
||||
func(static_cast<Model>(i));
|
||||
}
|
||||
}
|
||||
|
||||
static inline bool EnumValid(Model model) {
|
||||
const size_t i = static_cast<size_t>(model);
|
||||
if (i < static_cast<size_t>(Model::kSentinel)) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
|
@ -301,7 +331,7 @@ struct ModelConfig : public IFields {
|
|||
size_t NumHeads() const {
|
||||
uint32_t num_heads = 0;
|
||||
for (const auto& layer_config : layer_configs) {
|
||||
num_heads = std::max(num_heads, layer_config.heads);
|
||||
num_heads = HWY_MAX(num_heads, layer_config.heads);
|
||||
}
|
||||
return num_heads;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -26,12 +26,12 @@
|
|||
|
||||
#include "compression/io.h" // Path
|
||||
#include "compression/shared.h"
|
||||
#include "gemma/common.h"
|
||||
#include "gemma/configs.h"
|
||||
#include "gemma/gemma.h" // For CreateGemma
|
||||
#include "hwy/base.h" // HWY_ABORT
|
||||
#include "ops/matmul.h"
|
||||
#include "util/args.h"
|
||||
#include "util/basics.h" // Tristate
|
||||
#include "hwy/base.h" // HWY_ABORT
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@
|
|||
|
||||
#include <algorithm>
|
||||
|
||||
#include "gemma/common.h" // CallForModel
|
||||
#include "gemma/configs.h"
|
||||
#include "hwy/aligned_allocator.h"
|
||||
#include "hwy/base.h" // ZeroBytes
|
||||
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@
|
|||
|
||||
#include <stddef.h>
|
||||
|
||||
#include "gemma/common.h" // Model
|
||||
#include "gemma/configs.h" // ModelConfig
|
||||
#include "hwy/aligned_allocator.h"
|
||||
|
||||
namespace gcpp {
|
||||
|
|
|
|||
|
|
@ -31,7 +31,7 @@
|
|||
#include <random>
|
||||
#include <vector>
|
||||
|
||||
#include "gemma/common.h"
|
||||
#include "gemma/common.h" // ChooseQueryScale
|
||||
#include "util/allocator.h"
|
||||
#include "util/basics.h" // BF16
|
||||
#include "util/mat.h" // RowVectorBatch
|
||||
|
|
|
|||
|
|
@ -19,7 +19,6 @@
|
|||
|
||||
#include "compression/shared.h"
|
||||
#include "evals/benchmark_helper.h"
|
||||
#include "gemma/common.h"
|
||||
#include "gemma/configs.h"
|
||||
#include "gemma/gemma.h"
|
||||
#include "util/allocator.h"
|
||||
|
|
@ -27,11 +26,7 @@
|
|||
#include "hwy/tests/hwy_gtest.h"
|
||||
|
||||
// This test can be run manually with the downloaded PaliGemma weights.
|
||||
// To run the test, pass the following flags:
|
||||
// --model paligemma-224 --tokenizer <tokenizer_path> --weights <weights_path>
|
||||
// or just use the single-file weights file with --weights <weights_path>.
|
||||
// It should pass for the following models:
|
||||
// paligemma-3b-mix-224, paligemma2-3b-pt-448
|
||||
// It should pass for `paligemma-3b-mix-224` and `paligemma2-3b-pt-448`.
|
||||
|
||||
namespace gcpp {
|
||||
namespace {
|
||||
|
|
|
|||
12
util/mat.h
12
util/mat.h
|
|
@ -251,19 +251,11 @@ class MatPtrT : public MatPtr {
|
|||
HWY_ASSERT(IsPacked());
|
||||
return MakeSpan(Row(0), num_elements_);
|
||||
}
|
||||
|
||||
// For when a span of a single row is required. This also works if padded,
|
||||
// but does not support `GetType() == kNUQ`, because that requires the use of
|
||||
// offsets instead of a row pointer. Used by `gemma-inl.h` to decompress
|
||||
// embeddings.
|
||||
PackedSpan<const MatT> RowSpan(size_t row) const {
|
||||
HWY_DASSERT(GetType() != Type::kNUQ);
|
||||
return MakeConstSpan(Row(row), Cols());
|
||||
}
|
||||
};
|
||||
|
||||
// Calls `func` with a dynamic_cast of `MatPtr` to `MatPtrT<T>`, plus the
|
||||
// optional `args`.
|
||||
// optional `args`. Currently unused but may be used after we move toward
|
||||
// type-erased `WeightsPtrs`.
|
||||
template <class Func, typename... Args>
|
||||
decltype(auto) CallUpcasted(Type type, MatPtr* base, const Func& func,
|
||||
Args&&... args) {
|
||||
|
|
|
|||
|
|
@ -118,6 +118,7 @@ class ThreadingContext2 {
|
|||
// changing the arguments between tests. Callers must again call `Get`
|
||||
// afterwards to obtain an instance. WARNING: must not be called concurrently
|
||||
// with other calls to `Get` and usages of its return value.
|
||||
// Also useful to suppress memory leak warnings in tests.
|
||||
static void ThreadHostileInvalidate();
|
||||
|
||||
explicit ThreadingContext2(PrivateToken); // only called via `Get`.
|
||||
|
|
|
|||
Loading…
Reference in New Issue