mirror of https://github.com/google/gemma.cpp.git
Cleanup: include fixes/comments, fix leak, vector reserve
Also remove unused RowSpan configs.cc: Assign prompt wrapping to ModelConfig configs.h: simplify EnumValid via sentinel PiperOrigin-RevId: 750278497
This commit is contained in:
parent
ba10c88a94
commit
160a5824fb
|
|
@ -443,22 +443,19 @@ cc_library(
|
||||||
"mem": "28g",
|
"mem": "28g",
|
||||||
},
|
},
|
||||||
deps = [
|
deps = [
|
||||||
":allocator",
|
|
||||||
":basics",
|
|
||||||
":benchmark_helper",
|
":benchmark_helper",
|
||||||
":common",
|
":common",
|
||||||
":gemma_args",
|
":gemma_args",
|
||||||
":gemma_lib",
|
":gemma_lib",
|
||||||
":kv_cache",
|
":kv_cache",
|
||||||
":mat",
|
|
||||||
":ops",
|
":ops",
|
||||||
":threading",
|
":threading",
|
||||||
":threading_context",
|
":threading_context",
|
||||||
":tokenizer",
|
":tokenizer",
|
||||||
":weights",
|
|
||||||
"//compression:shared",
|
|
||||||
"//paligemma:image",
|
"//paligemma:image",
|
||||||
"@highway//:hwy",
|
"@highway//:hwy",
|
||||||
|
"@highway//:profiler",
|
||||||
|
"@highway//:timer",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -27,7 +27,8 @@
|
||||||
|
|
||||||
#include "backprop/activations.h"
|
#include "backprop/activations.h"
|
||||||
#include "backprop/prompt.h"
|
#include "backprop/prompt.h"
|
||||||
#include "gemma/common.h"
|
#include "gemma/common.h" // EmbeddingScaling
|
||||||
|
#include "gemma/configs.h" // LayerConfig, ModelConfig
|
||||||
#include "gemma/weights.h"
|
#include "gemma/weights.h"
|
||||||
#include "util/allocator.h"
|
#include "util/allocator.h"
|
||||||
#include "hwy/base.h"
|
#include "hwy/base.h"
|
||||||
|
|
|
||||||
|
|
@ -24,7 +24,7 @@
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "backprop/activations.h"
|
#include "backprop/activations.h"
|
||||||
#include "gemma/common.h"
|
#include "gemma/common.h" // EmbeddingScaling
|
||||||
#include "gemma/configs.h"
|
#include "gemma/configs.h"
|
||||||
#include "gemma/weights.h"
|
#include "gemma/weights.h"
|
||||||
#include "util/allocator.h"
|
#include "util/allocator.h"
|
||||||
|
|
|
||||||
|
|
@ -12,7 +12,6 @@
|
||||||
#include "compression/io.h" // Path
|
#include "compression/io.h" // Path
|
||||||
#include "evals/benchmark_helper.h"
|
#include "evals/benchmark_helper.h"
|
||||||
#include "evals/cross_entropy.h"
|
#include "evals/cross_entropy.h"
|
||||||
#include "gemma/common.h"
|
|
||||||
#include "gemma/gemma.h"
|
#include "gemma/gemma.h"
|
||||||
#include "util/args.h"
|
#include "util/args.h"
|
||||||
#include "hwy/base.h"
|
#include "hwy/base.h"
|
||||||
|
|
|
||||||
|
|
@ -49,6 +49,8 @@ class GemmaEnv {
|
||||||
GemmaEnv(int argc, char** argv);
|
GemmaEnv(int argc, char** argv);
|
||||||
GemmaEnv(const ThreadingArgs& threading, const LoaderArgs& loader,
|
GemmaEnv(const ThreadingArgs& threading, const LoaderArgs& loader,
|
||||||
const InferenceArgs& inference);
|
const InferenceArgs& inference);
|
||||||
|
// Avoid memory leaks in test.
|
||||||
|
~GemmaEnv() { ThreadingContext2::ThreadHostileInvalidate(); }
|
||||||
|
|
||||||
MatMulEnv& Env() { return env_; }
|
MatMulEnv& Env() { return env_; }
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -38,7 +38,6 @@
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "evals/cross_entropy.h"
|
#include "evals/cross_entropy.h"
|
||||||
#include "gemma/common.h"
|
|
||||||
#include "gemma/gemma.h"
|
#include "gemma/gemma.h"
|
||||||
#include "hwy/base.h"
|
#include "hwy/base.h"
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -13,15 +13,14 @@
|
||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
#include "gemma/gemma.h"
|
|
||||||
|
|
||||||
#include <stdio.h>
|
#include <stdio.h>
|
||||||
|
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "evals/benchmark_helper.h"
|
#include "evals/benchmark_helper.h"
|
||||||
#include "gemma/common.h"
|
#include "gemma/configs.h"
|
||||||
|
#include "gemma/gemma.h"
|
||||||
#include "hwy/base.h"
|
#include "hwy/base.h"
|
||||||
#include "hwy/tests/hwy_gtest.h"
|
#include "hwy/tests/hwy_gtest.h"
|
||||||
|
|
||||||
|
|
@ -65,6 +64,7 @@ class GemmaTest : public ::testing::Test {
|
||||||
prompts_vector.push_back(s_env->TokenizeAndPrependBOS(input_string));
|
prompts_vector.push_back(s_env->TokenizeAndPrependBOS(input_string));
|
||||||
}
|
}
|
||||||
std::vector<PromptTokens> prompt_spans;
|
std::vector<PromptTokens> prompt_spans;
|
||||||
|
prompt_spans.reserve(prompts_vector.size());
|
||||||
for (const auto& prompt : prompts_vector) {
|
for (const auto& prompt : prompts_vector) {
|
||||||
prompt_spans.push_back(PromptTokens(prompt.data(), prompt.size()));
|
prompt_spans.push_back(PromptTokens(prompt.data(), prompt.size()));
|
||||||
}
|
}
|
||||||
|
|
@ -79,6 +79,7 @@ class GemmaTest : public ::testing::Test {
|
||||||
ASSERT_NE(s_env->GetGemma(), nullptr);
|
ASSERT_NE(s_env->GetGemma(), nullptr);
|
||||||
|
|
||||||
std::vector<std::string> inputs;
|
std::vector<std::string> inputs;
|
||||||
|
inputs.reserve(num_questions);
|
||||||
for (size_t i = 0; i < num_questions; ++i) {
|
for (size_t i = 0; i < num_questions; ++i) {
|
||||||
inputs.push_back(kQA[i]);
|
inputs.push_back(kQA[i]);
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -187,6 +187,7 @@ static ModelConfig ConfigGemmaTiny() {
|
||||||
ModelConfig config = ConfigNoSSM();
|
ModelConfig config = ConfigNoSSM();
|
||||||
config.model_name = "GemmaTiny";
|
config.model_name = "GemmaTiny";
|
||||||
config.model = Model::GEMMA_TINY;
|
config.model = Model::GEMMA_TINY;
|
||||||
|
config.wrapping = PromptWrapping::GEMMA_IT;
|
||||||
config.model_dim = 128;
|
config.model_dim = 128;
|
||||||
config.vocab_size = 64;
|
config.vocab_size = 64;
|
||||||
config.seq_len = 32;
|
config.seq_len = 32;
|
||||||
|
|
@ -277,6 +278,7 @@ static ModelConfig ConfigPaliGemma_224() {
|
||||||
ModelConfig config = ConfigGemma2B();
|
ModelConfig config = ConfigGemma2B();
|
||||||
config.model_name = "PaliGemma_224";
|
config.model_name = "PaliGemma_224";
|
||||||
config.model = Model::PALIGEMMA_224;
|
config.model = Model::PALIGEMMA_224;
|
||||||
|
config.wrapping = PromptWrapping::PALIGEMMA;
|
||||||
AddVitConfig(config);
|
AddVitConfig(config);
|
||||||
return config;
|
return config;
|
||||||
}
|
}
|
||||||
|
|
@ -285,6 +287,7 @@ static ModelConfig ConfigPaliGemma_448() {
|
||||||
ModelConfig config = ConfigGemma2B();
|
ModelConfig config = ConfigGemma2B();
|
||||||
config.model_name = "PaliGemma_448";
|
config.model_name = "PaliGemma_448";
|
||||||
config.model = Model::PALIGEMMA_448;
|
config.model = Model::PALIGEMMA_448;
|
||||||
|
config.wrapping = PromptWrapping::PALIGEMMA;
|
||||||
AddVitConfig(config, /*image_size=*/448);
|
AddVitConfig(config, /*image_size=*/448);
|
||||||
return config;
|
return config;
|
||||||
}
|
}
|
||||||
|
|
@ -305,6 +308,7 @@ static ModelConfig ConfigPaliGemma2_3B_224() {
|
||||||
ModelConfig config = ConfigGemma2_2B();
|
ModelConfig config = ConfigGemma2_2B();
|
||||||
config.model_name = "PaliGemma2_3B_224";
|
config.model_name = "PaliGemma2_3B_224";
|
||||||
config.model = Model::PALIGEMMA2_3B_224;
|
config.model = Model::PALIGEMMA2_3B_224;
|
||||||
|
config.wrapping = PromptWrapping::PALIGEMMA;
|
||||||
AddVitConfig(config);
|
AddVitConfig(config);
|
||||||
return config;
|
return config;
|
||||||
}
|
}
|
||||||
|
|
@ -313,6 +317,7 @@ static ModelConfig ConfigPaliGemma2_3B_448() {
|
||||||
ModelConfig config = ConfigGemma2_2B();
|
ModelConfig config = ConfigGemma2_2B();
|
||||||
config.model_name = "PaliGemma2_3B_448";
|
config.model_name = "PaliGemma2_3B_448";
|
||||||
config.model = Model::PALIGEMMA2_3B_448;
|
config.model = Model::PALIGEMMA2_3B_448;
|
||||||
|
config.wrapping = PromptWrapping::PALIGEMMA;
|
||||||
AddVitConfig(config, /*image_size=*/448);
|
AddVitConfig(config, /*image_size=*/448);
|
||||||
return config;
|
return config;
|
||||||
}
|
}
|
||||||
|
|
@ -321,6 +326,7 @@ static ModelConfig ConfigPaliGemma2_10B_224() {
|
||||||
ModelConfig config = ConfigGemma2_9B();
|
ModelConfig config = ConfigGemma2_9B();
|
||||||
config.model_name = "PaliGemma2_10B_224";
|
config.model_name = "PaliGemma2_10B_224";
|
||||||
config.model = Model::PALIGEMMA2_10B_224;
|
config.model = Model::PALIGEMMA2_10B_224;
|
||||||
|
config.wrapping = PromptWrapping::PALIGEMMA;
|
||||||
AddVitConfig(config);
|
AddVitConfig(config);
|
||||||
return config;
|
return config;
|
||||||
}
|
}
|
||||||
|
|
@ -329,6 +335,7 @@ static ModelConfig ConfigPaliGemma2_10B_448() {
|
||||||
ModelConfig config = ConfigGemma2_9B();
|
ModelConfig config = ConfigGemma2_9B();
|
||||||
config.model_name = "PaliGemma2_10B_448";
|
config.model_name = "PaliGemma2_10B_448";
|
||||||
config.model = Model::PALIGEMMA2_10B_448;
|
config.model = Model::PALIGEMMA2_10B_448;
|
||||||
|
config.wrapping = PromptWrapping::PALIGEMMA;
|
||||||
AddVitConfig(config, /*image_size=*/448);
|
AddVitConfig(config, /*image_size=*/448);
|
||||||
return config;
|
return config;
|
||||||
}
|
}
|
||||||
|
|
@ -360,6 +367,7 @@ static ModelConfig ConfigGemma3_1B() {
|
||||||
ModelConfig config = ConfigBaseGemmaV3();
|
ModelConfig config = ConfigBaseGemmaV3();
|
||||||
config.model_name = "Gemma3_1B";
|
config.model_name = "Gemma3_1B";
|
||||||
config.model = Model::GEMMA3_1B;
|
config.model = Model::GEMMA3_1B;
|
||||||
|
config.wrapping = PromptWrapping::GEMMA_VLM;
|
||||||
config.model_dim = 1152;
|
config.model_dim = 1152;
|
||||||
config.vocab_size = 262144; // new vocab size / tokenizer
|
config.vocab_size = 262144; // new vocab size / tokenizer
|
||||||
config.seq_len = 32 * 1024;
|
config.seq_len = 32 * 1024;
|
||||||
|
|
@ -391,6 +399,7 @@ static ModelConfig ConfigGemma3_4B_LM() {
|
||||||
ModelConfig config = ConfigBaseGemmaV3();
|
ModelConfig config = ConfigBaseGemmaV3();
|
||||||
config.model_name = "Gemma3_4B";
|
config.model_name = "Gemma3_4B";
|
||||||
config.model = Model::GEMMA3_4B;
|
config.model = Model::GEMMA3_4B;
|
||||||
|
config.wrapping = PromptWrapping::GEMMA_VLM;
|
||||||
config.model_dim = 2560;
|
config.model_dim = 2560;
|
||||||
config.vocab_size = 262144; // new vocab size / tokenizer
|
config.vocab_size = 262144; // new vocab size / tokenizer
|
||||||
config.seq_len = 32 * 1024;
|
config.seq_len = 32 * 1024;
|
||||||
|
|
@ -408,6 +417,7 @@ static ModelConfig ConfigGemma3_4B() {
|
||||||
ModelConfig config = ConfigGemma3_4B_LM();
|
ModelConfig config = ConfigGemma3_4B_LM();
|
||||||
config.model_name = "Gemma3_4B";
|
config.model_name = "Gemma3_4B";
|
||||||
config.model = Model::GEMMA3_4B;
|
config.model = Model::GEMMA3_4B;
|
||||||
|
config.wrapping = PromptWrapping::GEMMA_VLM;
|
||||||
AddVitConfig(config, /*image_size=*/896);
|
AddVitConfig(config, /*image_size=*/896);
|
||||||
config.vocab_size = 262144;
|
config.vocab_size = 262144;
|
||||||
config.vit_config.pool_dim = 4;
|
config.vit_config.pool_dim = 4;
|
||||||
|
|
@ -438,6 +448,7 @@ static ModelConfig ConfigGemma3_12B_LM() {
|
||||||
ModelConfig config = ConfigBaseGemmaV3();
|
ModelConfig config = ConfigBaseGemmaV3();
|
||||||
config.model_name = "Gemma3_12B";
|
config.model_name = "Gemma3_12B";
|
||||||
config.model = Model::GEMMA3_12B;
|
config.model = Model::GEMMA3_12B;
|
||||||
|
config.wrapping = PromptWrapping::GEMMA_VLM;
|
||||||
config.model_dim = 3840;
|
config.model_dim = 3840;
|
||||||
config.vocab_size = 262144; // new vocab size / tokenizer
|
config.vocab_size = 262144; // new vocab size / tokenizer
|
||||||
config.seq_len = 32 * 1024;
|
config.seq_len = 32 * 1024;
|
||||||
|
|
@ -455,6 +466,7 @@ static ModelConfig ConfigGemma3_12B() {
|
||||||
ModelConfig config = ConfigGemma3_12B_LM();
|
ModelConfig config = ConfigGemma3_12B_LM();
|
||||||
config.model_name = "Gemma3_12B";
|
config.model_name = "Gemma3_12B";
|
||||||
config.model = Model::GEMMA3_12B;
|
config.model = Model::GEMMA3_12B;
|
||||||
|
config.wrapping = PromptWrapping::GEMMA_VLM;
|
||||||
AddVitConfig(config, /*image_size=*/896);
|
AddVitConfig(config, /*image_size=*/896);
|
||||||
config.vocab_size = 262144;
|
config.vocab_size = 262144;
|
||||||
config.vit_config.pool_dim = 4;
|
config.vit_config.pool_dim = 4;
|
||||||
|
|
@ -485,6 +497,7 @@ static ModelConfig ConfigGemma3_27B_LM() {
|
||||||
ModelConfig config = ConfigBaseGemmaV3();
|
ModelConfig config = ConfigBaseGemmaV3();
|
||||||
config.model_name = "Gemma3_27B";
|
config.model_name = "Gemma3_27B";
|
||||||
config.model = Model::GEMMA3_27B;
|
config.model = Model::GEMMA3_27B;
|
||||||
|
config.wrapping = PromptWrapping::GEMMA_VLM;
|
||||||
config.model_dim = 5376;
|
config.model_dim = 5376;
|
||||||
config.vocab_size = 262144; // new vocab size / tokenizer
|
config.vocab_size = 262144; // new vocab size / tokenizer
|
||||||
config.seq_len = 32 * 1024;
|
config.seq_len = 32 * 1024;
|
||||||
|
|
@ -502,6 +515,7 @@ static ModelConfig ConfigGemma3_27B() {
|
||||||
ModelConfig config = ConfigGemma3_27B_LM();
|
ModelConfig config = ConfigGemma3_27B_LM();
|
||||||
config.model_name = "Gemma3_27B";
|
config.model_name = "Gemma3_27B";
|
||||||
config.model = Model::GEMMA3_27B;
|
config.model = Model::GEMMA3_27B;
|
||||||
|
config.wrapping = PromptWrapping::GEMMA_VLM;
|
||||||
AddVitConfig(config, /*image_size=*/896);
|
AddVitConfig(config, /*image_size=*/896);
|
||||||
config.vocab_size = 262144;
|
config.vocab_size = 262144;
|
||||||
config.vit_config.pool_dim = 4;
|
config.vit_config.pool_dim = 4;
|
||||||
|
|
|
||||||
|
|
@ -19,10 +19,9 @@
|
||||||
// Model configurations
|
// Model configurations
|
||||||
|
|
||||||
#include <stddef.h>
|
#include <stddef.h>
|
||||||
|
#include <stdint.h>
|
||||||
|
|
||||||
#include <algorithm>
|
|
||||||
#include <array>
|
#include <array>
|
||||||
#include <cstdint>
|
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <unordered_set>
|
#include <unordered_set>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
@ -53,11 +52,27 @@ using EmbedderInputT = BF16;
|
||||||
enum class PromptWrapping {
|
enum class PromptWrapping {
|
||||||
GEMMA_IT,
|
GEMMA_IT,
|
||||||
GEMMA_PT,
|
GEMMA_PT,
|
||||||
GEMMA_VLM,
|
GEMMA_VLM, // for >1B Gemma3
|
||||||
PALIGEMMA,
|
PALIGEMMA,
|
||||||
kSentinel // must be last
|
kSentinel // must be last
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Defined as the suffix for use with `ModelString`.
|
||||||
|
static inline const char* ToString(PromptWrapping wrapping) {
|
||||||
|
switch (wrapping) {
|
||||||
|
case PromptWrapping::GEMMA_IT:
|
||||||
|
return "-it";
|
||||||
|
case PromptWrapping::GEMMA_PT:
|
||||||
|
return "-pt";
|
||||||
|
case PromptWrapping::GEMMA_VLM:
|
||||||
|
return "-vlm";
|
||||||
|
case PromptWrapping::PALIGEMMA:
|
||||||
|
return "-pg";
|
||||||
|
default:
|
||||||
|
return "-?";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
static inline bool EnumValid(PromptWrapping wrapping) {
|
static inline bool EnumValid(PromptWrapping wrapping) {
|
||||||
return static_cast<size_t>(wrapping) <
|
return static_cast<size_t>(wrapping) <
|
||||||
static_cast<size_t>(PromptWrapping::kSentinel);
|
static_cast<size_t>(PromptWrapping::kSentinel);
|
||||||
|
|
@ -69,63 +84,68 @@ enum class LayerAttentionType {
|
||||||
kVit,
|
kVit,
|
||||||
};
|
};
|
||||||
|
|
||||||
inline bool EnumValid(LayerAttentionType type) {
|
static inline bool EnumValid(LayerAttentionType type) {
|
||||||
return static_cast<int>(type) >= 0 &&
|
return type == LayerAttentionType::kGemma ||
|
||||||
static_cast<int>(type) <= static_cast<int>(LayerAttentionType::kVit);
|
type == LayerAttentionType::kGriffinRecurrentBlock ||
|
||||||
|
type == LayerAttentionType::kVit;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Post attention and ffw normalization type.
|
// Post attention and ffw normalization type.
|
||||||
enum class PostNormType {
|
enum class PostNormType {
|
||||||
None,
|
None,
|
||||||
Scale,
|
Scale,
|
||||||
|
kSentinel // must be last
|
||||||
};
|
};
|
||||||
|
|
||||||
inline bool EnumValid(PostNormType type) {
|
static inline bool EnumValid(PostNormType type) {
|
||||||
return static_cast<int>(type) >= 0 &&
|
return static_cast<size_t>(type) <
|
||||||
static_cast<int>(type) <= static_cast<int>(PostNormType::Scale);
|
static_cast<size_t>(PostNormType::kSentinel);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Post qk projection operation type.
|
// Post qk projection operation type.
|
||||||
enum class PostQKType {
|
enum class PostQKType {
|
||||||
Rope,
|
Rope,
|
||||||
HalfRope,
|
HalfRope,
|
||||||
|
kSentinel // must be last
|
||||||
};
|
};
|
||||||
|
|
||||||
inline bool EnumValid(PostQKType type) {
|
static inline bool EnumValid(PostQKType type) {
|
||||||
return static_cast<int>(type) >= 0 &&
|
return static_cast<size_t>(type) <
|
||||||
static_cast<int>(type) <= static_cast<int>(PostQKType::HalfRope);
|
static_cast<size_t>(PostNormType::kSentinel);
|
||||||
}
|
}
|
||||||
|
|
||||||
// FFW activation function.
|
// FFW activation function.
|
||||||
enum class ActivationType {
|
enum class ActivationType {
|
||||||
Gelu,
|
Gelu,
|
||||||
|
kSentinel // must be last
|
||||||
};
|
};
|
||||||
|
|
||||||
inline bool EnumValid(ActivationType type) {
|
static inline bool EnumValid(ActivationType type) {
|
||||||
return static_cast<int>(type) >= 0 &&
|
return static_cast<size_t>(type) <
|
||||||
static_cast<int>(type) <= static_cast<int>(ActivationType::Gelu);
|
static_cast<size_t>(ActivationType::kSentinel);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Attention query scale.
|
// Attention query scale.
|
||||||
enum class QueryScaleType {
|
enum class QueryScaleType {
|
||||||
SqrtKeySize,
|
SqrtKeySize,
|
||||||
SqrtModelDimDivNumHeads,
|
SqrtModelDimDivNumHeads,
|
||||||
|
kSentinel // must be last
|
||||||
};
|
};
|
||||||
|
|
||||||
inline bool EnumValid(QueryScaleType type) {
|
static inline bool EnumValid(QueryScaleType type) {
|
||||||
return static_cast<int>(type) >= 0 &&
|
return static_cast<size_t>(type) <
|
||||||
static_cast<int>(type) <=
|
static_cast<size_t>(QueryScaleType::kSentinel);
|
||||||
static_cast<int>(QueryScaleType::SqrtModelDimDivNumHeads);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Residual connection type.
|
// Residual connection type.
|
||||||
enum class ResidualType {
|
enum class ResidualType {
|
||||||
Add,
|
Add,
|
||||||
|
kSentinel // must be last
|
||||||
};
|
};
|
||||||
|
|
||||||
inline bool EnumValid(ResidualType type) {
|
static inline bool EnumValid(ResidualType type) {
|
||||||
return static_cast<int>(type) >= 0 &&
|
return static_cast<size_t>(type) <
|
||||||
static_cast<int>(type) <= static_cast<int>(ResidualType::Add);
|
static_cast<size_t>(ResidualType::kSentinel);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <size_t kNum>
|
template <size_t kNum>
|
||||||
|
|
@ -169,6 +189,7 @@ enum class Model {
|
||||||
GEMMA3_1B,
|
GEMMA3_1B,
|
||||||
GEMMA3_12B,
|
GEMMA3_12B,
|
||||||
GEMMA3_27B,
|
GEMMA3_27B,
|
||||||
|
kSentinel,
|
||||||
};
|
};
|
||||||
|
|
||||||
// Allows the Model enum to be iterated over.
|
// Allows the Model enum to be iterated over.
|
||||||
|
|
@ -181,9 +202,18 @@ static constexpr Model kAllModels[] = {
|
||||||
Model::GEMMA3_12B, Model::GEMMA3_27B,
|
Model::GEMMA3_12B, Model::GEMMA3_27B,
|
||||||
};
|
};
|
||||||
|
|
||||||
inline bool EnumValid(Model model) {
|
template <class Func>
|
||||||
for (Model m : kAllModels) {
|
void ForEachModel(const Func& func) {
|
||||||
if (m == model) return true;
|
for (size_t i = static_cast<size_t>(Model::UNKNOWN) + 1;
|
||||||
|
i < static_cast<size_t>(Model::kSentinel); ++i) {
|
||||||
|
func(static_cast<Model>(i));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static inline bool EnumValid(Model model) {
|
||||||
|
const size_t i = static_cast<size_t>(model);
|
||||||
|
if (i < static_cast<size_t>(Model::kSentinel)) {
|
||||||
|
return true;
|
||||||
}
|
}
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
@ -301,7 +331,7 @@ struct ModelConfig : public IFields {
|
||||||
size_t NumHeads() const {
|
size_t NumHeads() const {
|
||||||
uint32_t num_heads = 0;
|
uint32_t num_heads = 0;
|
||||||
for (const auto& layer_config : layer_configs) {
|
for (const auto& layer_config : layer_configs) {
|
||||||
num_heads = std::max(num_heads, layer_config.heads);
|
num_heads = HWY_MAX(num_heads, layer_config.heads);
|
||||||
}
|
}
|
||||||
return num_heads;
|
return num_heads;
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -26,12 +26,12 @@
|
||||||
|
|
||||||
#include "compression/io.h" // Path
|
#include "compression/io.h" // Path
|
||||||
#include "compression/shared.h"
|
#include "compression/shared.h"
|
||||||
#include "gemma/common.h"
|
#include "gemma/configs.h"
|
||||||
#include "gemma/gemma.h" // For CreateGemma
|
#include "gemma/gemma.h" // For CreateGemma
|
||||||
#include "hwy/base.h" // HWY_ABORT
|
|
||||||
#include "ops/matmul.h"
|
#include "ops/matmul.h"
|
||||||
#include "util/args.h"
|
#include "util/args.h"
|
||||||
#include "util/basics.h" // Tristate
|
#include "util/basics.h" // Tristate
|
||||||
|
#include "hwy/base.h" // HWY_ABORT
|
||||||
|
|
||||||
namespace gcpp {
|
namespace gcpp {
|
||||||
|
|
||||||
|
|
@ -237,4 +237,4 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
|
||||||
|
|
||||||
} // namespace gcpp
|
} // namespace gcpp
|
||||||
|
|
||||||
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_ARGS_H_
|
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_ARGS_H_
|
||||||
|
|
|
||||||
|
|
@ -17,7 +17,7 @@
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
|
|
||||||
#include "gemma/common.h" // CallForModel
|
#include "gemma/configs.h"
|
||||||
#include "hwy/aligned_allocator.h"
|
#include "hwy/aligned_allocator.h"
|
||||||
#include "hwy/base.h" // ZeroBytes
|
#include "hwy/base.h" // ZeroBytes
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -18,7 +18,7 @@
|
||||||
|
|
||||||
#include <stddef.h>
|
#include <stddef.h>
|
||||||
|
|
||||||
#include "gemma/common.h" // Model
|
#include "gemma/configs.h" // ModelConfig
|
||||||
#include "hwy/aligned_allocator.h"
|
#include "hwy/aligned_allocator.h"
|
||||||
|
|
||||||
namespace gcpp {
|
namespace gcpp {
|
||||||
|
|
|
||||||
|
|
@ -31,7 +31,7 @@
|
||||||
#include <random>
|
#include <random>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "gemma/common.h"
|
#include "gemma/common.h" // ChooseQueryScale
|
||||||
#include "util/allocator.h"
|
#include "util/allocator.h"
|
||||||
#include "util/basics.h" // BF16
|
#include "util/basics.h" // BF16
|
||||||
#include "util/mat.h" // RowVectorBatch
|
#include "util/mat.h" // RowVectorBatch
|
||||||
|
|
|
||||||
|
|
@ -19,7 +19,6 @@
|
||||||
|
|
||||||
#include "compression/shared.h"
|
#include "compression/shared.h"
|
||||||
#include "evals/benchmark_helper.h"
|
#include "evals/benchmark_helper.h"
|
||||||
#include "gemma/common.h"
|
|
||||||
#include "gemma/configs.h"
|
#include "gemma/configs.h"
|
||||||
#include "gemma/gemma.h"
|
#include "gemma/gemma.h"
|
||||||
#include "util/allocator.h"
|
#include "util/allocator.h"
|
||||||
|
|
@ -27,11 +26,7 @@
|
||||||
#include "hwy/tests/hwy_gtest.h"
|
#include "hwy/tests/hwy_gtest.h"
|
||||||
|
|
||||||
// This test can be run manually with the downloaded PaliGemma weights.
|
// This test can be run manually with the downloaded PaliGemma weights.
|
||||||
// To run the test, pass the following flags:
|
// It should pass for `paligemma-3b-mix-224` and `paligemma2-3b-pt-448`.
|
||||||
// --model paligemma-224 --tokenizer <tokenizer_path> --weights <weights_path>
|
|
||||||
// or just use the single-file weights file with --weights <weights_path>.
|
|
||||||
// It should pass for the following models:
|
|
||||||
// paligemma-3b-mix-224, paligemma2-3b-pt-448
|
|
||||||
|
|
||||||
namespace gcpp {
|
namespace gcpp {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
|
||||||
12
util/mat.h
12
util/mat.h
|
|
@ -251,19 +251,11 @@ class MatPtrT : public MatPtr {
|
||||||
HWY_ASSERT(IsPacked());
|
HWY_ASSERT(IsPacked());
|
||||||
return MakeSpan(Row(0), num_elements_);
|
return MakeSpan(Row(0), num_elements_);
|
||||||
}
|
}
|
||||||
|
|
||||||
// For when a span of a single row is required. This also works if padded,
|
|
||||||
// but does not support `GetType() == kNUQ`, because that requires the use of
|
|
||||||
// offsets instead of a row pointer. Used by `gemma-inl.h` to decompress
|
|
||||||
// embeddings.
|
|
||||||
PackedSpan<const MatT> RowSpan(size_t row) const {
|
|
||||||
HWY_DASSERT(GetType() != Type::kNUQ);
|
|
||||||
return MakeConstSpan(Row(row), Cols());
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
// Calls `func` with a dynamic_cast of `MatPtr` to `MatPtrT<T>`, plus the
|
// Calls `func` with a dynamic_cast of `MatPtr` to `MatPtrT<T>`, plus the
|
||||||
// optional `args`.
|
// optional `args`. Currently unused but may be used after we move toward
|
||||||
|
// type-erased `WeightsPtrs`.
|
||||||
template <class Func, typename... Args>
|
template <class Func, typename... Args>
|
||||||
decltype(auto) CallUpcasted(Type type, MatPtr* base, const Func& func,
|
decltype(auto) CallUpcasted(Type type, MatPtr* base, const Func& func,
|
||||||
Args&&... args) {
|
Args&&... args) {
|
||||||
|
|
|
||||||
|
|
@ -118,6 +118,7 @@ class ThreadingContext2 {
|
||||||
// changing the arguments between tests. Callers must again call `Get`
|
// changing the arguments between tests. Callers must again call `Get`
|
||||||
// afterwards to obtain an instance. WARNING: must not be called concurrently
|
// afterwards to obtain an instance. WARNING: must not be called concurrently
|
||||||
// with other calls to `Get` and usages of its return value.
|
// with other calls to `Get` and usages of its return value.
|
||||||
|
// Also useful to suppress memory leak warnings in tests.
|
||||||
static void ThreadHostileInvalidate();
|
static void ThreadHostileInvalidate();
|
||||||
|
|
||||||
explicit ThreadingContext2(PrivateToken); // only called via `Get`.
|
explicit ThreadingContext2(PrivateToken); // only called via `Get`.
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue