mirror of https://github.com/google/gemma.cpp.git
Rename ModelTraining to PromptWrapping, which is a more accurate name.
PiperOrigin-RevId: 705881500
This commit is contained in:
parent 6254f2e5ca
commit 62c70d6715
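The change is mechanical: the enum, the ModelInfo field, and the parse/print helpers are renamed, while the enumerator values stay the same. A condensed sketch of the renamed surface, abridged from the hunks below:

// Before: named after how the model was trained.
// enum class ModelTraining { GEMMA_IT, GEMMA_PT, PALIGEMMA };
// After: named after what callers actually use it for, wrapping prompts.
enum class PromptWrapping { GEMMA_IT, GEMMA_PT, PALIGEMMA };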
@@ -49,7 +49,7 @@ TEST(OptimizeTest, GradientDescent) {

   const ModelInfo info = {
       .model = Model::GEMMA_TINY,
-      .training = ModelTraining::GEMMA_IT,
+      .wrapping = PromptWrapping::GEMMA_IT,
       .weight = Type::kF32,
   };
   ModelConfig config = ConfigFromModel(info.model);
@@ -40,8 +40,8 @@
 #include <vector>

 #include "compression/compress.h"
-#include "compression/shared.h"  // ModelTraining
 #include "compression/io.h"  // Path
+#include "compression/shared.h"  // PromptWrapping
 #include "gemma/common.h"  // Model
 #include "gemma/weights.h"
 #include "util/allocator.h"
@@ -74,8 +74,8 @@ struct Args : public ArgsBase<Args> {

   // Returns error string or nullptr if OK.
   const char* Validate() {
-    if (const char* err = ParseModelTypeAndTraining(model_type_str, model_type_,
-                                                    model_training_)) {
+    if (const char* err = ParseModelTypeAndWrapping(model_type_str, model_type_,
+                                                    prompt_wrapping_)) {
       return err;
     }
     if (const char* err = ParseType(weight_type_str, weight_type_)) {
@@ -127,12 +127,12 @@ struct Args : public ArgsBase<Args> {

   // Uninitialized before Validate, must call after that.
   gcpp::Model ModelType() const { return model_type_; }
-  gcpp::ModelTraining ModelTrainingType() const { return model_training_; }
+  gcpp::PromptWrapping PromptWrappingType() const { return prompt_wrapping_; }
   gcpp::Type WeightType() const { return weight_type_; }

  private:
   Model model_type_;
-  ModelTraining model_training_;
+  PromptWrapping prompt_wrapping_;
   Type weight_type_;
 };

@@ -212,7 +212,7 @@ namespace gcpp {

 void Run(Args& args) {
   hwy::ThreadPool pool(args.num_threads);
-  if (args.ModelTrainingType() == ModelTraining::PALIGEMMA) {
+  if (args.PromptWrappingType() == PromptWrapping::PALIGEMMA) {
     HWY_ABORT("PaliGemma is not supported in compress_weights.");
   }
   const Model model_type = args.ModelType();
@@ -196,7 +196,7 @@ constexpr bool IsNuqStream() {
 }

 // Instruction-tuned models require extra 'turn structure' tokens in prompts.
-enum class ModelTraining { GEMMA_IT, GEMMA_PT, PALIGEMMA };
+enum class PromptWrapping { GEMMA_IT, GEMMA_PT, PALIGEMMA };

 // Tensor types for loading weights. Note that not all types are supported as
 // weights for a model, but can be used for other purposes, such as types for
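As a hedged illustration of why the new name fits: callers branch on this enum to decide prompt formatting, not training details. The helper below is hypothetical (not part of gemma.cpp); the control tokens match the Wrap() hunk further down.

#include <string>

// Hypothetical helper: only instruction-tuned (GEMMA_IT) prompts receive
// turn-structure tokens; GEMMA_PT text is passed through unchanged.
std::string MaybeWrap(PromptWrapping wrapping, const std::string& prompt) {
  if (wrapping != PromptWrapping::GEMMA_IT) return prompt;
  return "<start_of_turn>user\n" + prompt +
         "<end_of_turn>\n<start_of_turn>model\n";
}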
@@ -187,7 +187,7 @@ int main(int argc, char** argv) {
     const std::string golden_path =
         benchmark_args.goldens.path + "/" +
         gcpp::ModelString(env.GetModel()->Info().model,
-                          env.GetModel()->Info().training) +
+                          env.GetModel()->Info().wrapping) +
         ".txt";
     return BenchmarkGoldens(env, golden_path);
   } else if (!benchmark_args.summarize_text.Empty()) {
@@ -60,25 +60,25 @@ constexpr Model kModelTypes[] = {
     Model::PALIGEMMA2_10B_224,  // PaliGemma2 10B 224
     Model::PALIGEMMA2_10B_448,  // PaliGemma2 10B 448
 };
-constexpr ModelTraining kModelTraining[] = {
-    ModelTraining::GEMMA_PT, ModelTraining::GEMMA_IT,    // Gemma 2B
-    ModelTraining::GEMMA_PT, ModelTraining::GEMMA_IT,    // Gemma 7B
-    ModelTraining::GEMMA_PT, ModelTraining::GEMMA_IT,    // RecurrentGemma
-    ModelTraining::GEMMA_IT,                             // Gemma Tiny
-    ModelTraining::GEMMA_PT, ModelTraining::GEMMA_IT,    // Gemma2 2B
-    ModelTraining::GEMMA_PT, ModelTraining::GEMMA_IT,    // Gemma2 9B
-    ModelTraining::GEMMA_PT, ModelTraining::GEMMA_IT,    // Gemma2 27B
-    ModelTraining::PALIGEMMA, ModelTraining::PALIGEMMA,  // PaliGemma 224 / 448
-    ModelTraining::PALIGEMMA, ModelTraining::PALIGEMMA,  // PG2 3B 224 / 448
-    ModelTraining::PALIGEMMA, ModelTraining::PALIGEMMA,  // PG2 10B 224 / 448
+constexpr PromptWrapping kPromptWrapping[] = {
+    PromptWrapping::GEMMA_PT, PromptWrapping::GEMMA_IT,    // Gemma 2B
+    PromptWrapping::GEMMA_PT, PromptWrapping::GEMMA_IT,    // Gemma 7B
+    PromptWrapping::GEMMA_PT, PromptWrapping::GEMMA_IT,    // RecurrentGemma
+    PromptWrapping::GEMMA_IT,                              // Gemma Tiny
+    PromptWrapping::GEMMA_PT, PromptWrapping::GEMMA_IT,    // Gemma2 2B
+    PromptWrapping::GEMMA_PT, PromptWrapping::GEMMA_IT,    // Gemma2 9B
+    PromptWrapping::GEMMA_PT, PromptWrapping::GEMMA_IT,    // Gemma2 27B
+    PromptWrapping::PALIGEMMA, PromptWrapping::PALIGEMMA,  // PaliGemma 224/448
+    PromptWrapping::PALIGEMMA, PromptWrapping::PALIGEMMA,  // PG2 3B 224/448
+    PromptWrapping::PALIGEMMA, PromptWrapping::PALIGEMMA,  // PG2 10B 224/448
 };

 constexpr size_t kNumModelFlags = std::size(kModelFlags);
 static_assert(kNumModelFlags == std::size(kModelTypes));
-static_assert(kNumModelFlags == std::size(kModelTraining));
+static_assert(kNumModelFlags == std::size(kPromptWrapping));

-const char* ParseModelTypeAndTraining(const std::string& model_flag,
-                                      Model& model, ModelTraining& training) {
+const char* ParseModelTypeAndWrapping(const std::string& model_flag,
+                                      Model& model, PromptWrapping& wrapping) {
   static std::string kErrorMessageBuffer =
       "Invalid or missing model flag, need to specify one of ";
   for (size_t i = 0; i + 1 < kNumModelFlags; ++i) {
@@ -93,21 +93,21 @@ const char* ParseModelTypeAndTraining(const std::string& model_flag,
   for (size_t i = 0; i < kNumModelFlags; ++i) {
     if (kModelFlags[i] == model_type_lc) {
       model = kModelTypes[i];
-      training = kModelTraining[i];
-      HWY_ASSERT(std::string(ModelString(model, training)) == model_type_lc);
+      wrapping = kPromptWrapping[i];
+      HWY_ASSERT(std::string(ModelString(model, wrapping)) == model_type_lc);
       return nullptr;
     }
   }
   return kErrorMessageBuffer.c_str();
 }

-const char* ModelString(Model model, ModelTraining training) {
+const char* ModelString(Model model, PromptWrapping wrapping) {
   for (size_t i = 0; i < kNumModelFlags; i++) {
-    if (kModelTypes[i] == model && kModelTraining[i] == training)
+    if (kModelTypes[i] == model && kPromptWrapping[i] == wrapping)
       return kModelFlags[i];
   }
-  HWY_ABORT("Unknown model %d training %d\n", static_cast<int>(model),
-            static_cast<int>(training));
+  HWY_ABORT("Unknown model %d wrapping %d\n", static_cast<int>(model),
+            static_cast<int>(wrapping));
 }

 const char* StringFromType(Type type) {
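A hedged usage sketch of the renamed pair; the flag string is illustrative, valid values come from kModelFlags:

Model model;
PromptWrapping wrapping;
if (const char* err =
        ParseModelTypeAndWrapping("gemma2-2b-it", model, wrapping)) {
  fprintf(stderr, "%s\n", err);  // Error text lists the accepted flag names.
} else {
  // ModelString inverts the lookup over the same parallel arrays, which is
  // exactly what the HWY_ASSERT round trip above checks.
  printf("canonical flag: %s\n", ModelString(model, wrapping));
}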
@@ -139,7 +139,7 @@ const char* ParseType(const std::string& type_string, Type& type) {
 void Wrap(const ModelInfo& info, size_t pos, std::string& prompt) {

   // Instruction-tuned models are trained to expect control tokens.
-  if (info.training == ModelTraining::GEMMA_IT) {
+  if (info.wrapping == PromptWrapping::GEMMA_IT) {
     // Prepend "<end_of_turn>" if this is a multi-turn dialogue continuation.
     const std::string start = (pos == 0)
                                   ? "<start_of_turn>user\n"
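A hedged call-site sketch, reusing the ModelInfo values from the test hunk at the top of this commit:

const ModelInfo info = {
    .model = Model::GEMMA_TINY,
    .wrapping = PromptWrapping::GEMMA_IT,
    .weight = Type::kF32,
};
std::string prompt = "Hello!";
Wrap(info, /*pos=*/0, prompt);
// prompt is now wrapped as "<start_of_turn>user\nHello!<end_of_turn>\n..."
// per the template above; with pos != 0, Wrap would first emit
// "<end_of_turn>" to close the previous model turn.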
@@ -20,7 +20,7 @@

 #include <string>

-#include "compression/shared.h"  // ModelTraining
+#include "compression/shared.h"  // PromptWrapping
 #include "gemma/configs.h"  // IWYU pragma: export
 #include "hwy/base.h"  // ConvertScalarTo

@@ -29,18 +29,18 @@ namespace gcpp {
 // Struct to bundle model information.
 struct ModelInfo {
   Model model;
-  ModelTraining training;
+  PromptWrapping wrapping;
   Type weight;
 };

 // Returns error string or nullptr if OK.
 // Thread-hostile.
-const char* ParseModelTypeAndTraining(const std::string& model_flag,
-                                      Model& model, ModelTraining& training);
+const char* ParseModelTypeAndWrapping(const std::string& model_flag,
+                                      Model& model, PromptWrapping& wrapping);
 const char* ParseType(const std::string& type_string, Type& type);

-// Inverse of ParseModelTypeAndTraining.
-const char* ModelString(Model model, ModelTraining training);
+// Inverse of ParseModelTypeAndWrapping.
+const char* ModelString(Model model, PromptWrapping wrapping);
 const char* StringFromType(Type type);

 // Wraps the given prompt using the expected control tokens for IT models.
@@ -366,13 +366,13 @@ bool LayerConfig::TestEqual(const LayerConfig& other, bool partial,
 bool ModelConfig::TestEqual(const ModelConfig& other, bool partial,
                             bool debug) const {
   bool result = true;
-  // We don't care about model_name, model, training, or weight being different,
+  // We don't care about model_name, model, wrapping, or weight being different,
   // but will output in debug mode if they are.
   if (debug) {
     WARN_IF_NOT_EQUAL(model_name, other.model_name);
     WARN_IF_NOT_EQUAL(static_cast<int>(model), static_cast<int>(other.model));
-    WARN_IF_NOT_EQUAL(static_cast<int>(training),
-                      static_cast<int>(other.training));
+    WARN_IF_NOT_EQUAL(static_cast<int>(wrapping),
+                      static_cast<int>(other.wrapping));
     WARN_IF_NOT_EQUAL(static_cast<int>(weight), static_cast<int>(other.weight));
   }
   TEST_EQUAL(model_dim, other.model_dim);
@@ -200,7 +200,7 @@ struct ModelConfig {

   std::string model_name;
   Model model;
-  ModelTraining training;
+  PromptWrapping wrapping;
   Type weight;
   size_t num_layers = 0;
   size_t model_dim = 0;
@@ -244,7 +244,7 @@ class Gemma {
   ModelInfo info_;
 };

-// Adds BOS token and possibly 'turn' annotations, which depend on `training`
+// Adds BOS token and possibly 'turn' annotations, which depend on `info`
 // and `pos`, the number of tokens decoded so far; returns the corresponding
 // tokens. Asserts that tokenization is successful.
 std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer,
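A hedged call-site sketch; the trailing parameters (pos and the prompt text) are assumed from this declaration's context rather than shown in the diff:

// Hypothetical call: wraps per info.wrapping, prepends BOS, tokenizes.
const std::vector<int> tokens =
    WrapAndTokenize(model.Tokenizer(), model.Info(), /*pos=*/0, "Hello!");
// Per the comment above, tokenization failures assert rather than return.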
@@ -22,7 +22,7 @@
 #include <vector>

 // Placeholder for internal header, do not modify.
-#include "compression/shared.h"  // ModelTraining
+#include "compression/shared.h"  // PromptWrapping
 #include "evals/benchmark_helper.h"
 #include "gemma/common.h"
 #include "gemma/gemma.h"  // Gemma
@@ -96,7 +96,7 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
   if (have_image) {
     image_tokens = ImageTokens(Extents2D(model.GetModelConfig().vit_seq_len,
                                          model.GetModelConfig().model_dim));
-    HWY_ASSERT(model.Info().training == ModelTraining::PALIGEMMA);
+    HWY_ASSERT(model.Info().wrapping == PromptWrapping::PALIGEMMA);
     HWY_ASSERT(image.ReadPPM(args.image_file.path));
     const size_t image_size = model.GetModelConfig().image_size;
     image.Resize(image_size, image_size);
@@ -207,7 +207,7 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
   std::cout << "\n\n";

   // Prepare for the next turn.
-  if (!args.multiturn || model.Info().training == ModelTraining::PALIGEMMA) {
+  if (!args.multiturn || model.Info().wrapping == PromptWrapping::PALIGEMMA) {
     abs_pos = 0;  // Start a new turn at position 0.
     InitGenerator(args, gen);
   } else {
@@ -22,7 +22,7 @@
 #include <vector>

 #include "compression/io.h"  // Path
-#include "compression/shared.h"  // ModelTraining
+#include "compression/shared.h"  // PromptWrapping
 #include "gemma/common.h"  // Wrap
 #include "hwy/base.h"  // HWY_ASSERT
 #include "hwy/profiler.h"
@@ -110,7 +110,7 @@ std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer,
   }

   // PaliGemma separator. The SEP token "\n" is always tokenized separately.
-  if (info.training == ModelTraining::PALIGEMMA) {
+  if (info.wrapping == PromptWrapping::PALIGEMMA) {
     std::vector<int> sep_tokens;
     HWY_ASSERT(tokenizer.Encode("\n", &sep_tokens));
     tokens.insert(tokens.end(), sep_tokens.begin(), sep_tokens.end());
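Why the separator gets its own Encode call (a hedged explanation; the tokenizer is assumed SentencePiece-like): encoding "\n" together with adjacent prompt text could merge it into a larger piece, whereas encoding it in isolation reliably yields the SEP token(s), which are then appended after the prompt tokens.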
@@ -53,7 +53,7 @@ void PaliGemmaTest::InitVit(const std::string& path) {
   image_tokens_ = ImageTokens(Extents2D(model.GetModelConfig().vit_seq_len,
                                         model.GetModelConfig().model_dim));
   Image image;
-  HWY_ASSERT(model.Info().training == ModelTraining::PALIGEMMA);
+  HWY_ASSERT(model.Info().wrapping == PromptWrapping::PALIGEMMA);
   HWY_ASSERT(image.ReadPPM(path));
   const size_t image_size = model.GetModelConfig().image_size;
   image.Resize(image_size, image_size);
@@ -136,8 +136,8 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {

   // Returns error string or nullptr if OK.
   const char* Validate() {
-    if (const char* err = ParseModelTypeAndTraining(model_type_str, info_.model,
-                                                    info_.training)) {
+    if (const char* err = ParseModelTypeAndWrapping(model_type_str, info_.model,
+                                                    info_.wrapping)) {
       return err;
     }
     if (const char* err = ParseType(weight_type_str, info_.weight)) {