Rename ModelTraining to PromptWrapping, which is a more accurate name.

PiperOrigin-RevId: 705881500
This commit is contained in:
Daniel Keysers 2024-12-13 07:45:25 -08:00 committed by Copybara-Service
parent 6254f2e5ca
commit 62c70d6715
13 changed files with 49 additions and 49 deletions

View File

@@ -49,7 +49,7 @@ TEST(OptimizeTest, GradientDescent) {
const ModelInfo info = { const ModelInfo info = {
.model = Model::GEMMA_TINY, .model = Model::GEMMA_TINY,
.training = ModelTraining::GEMMA_IT, .wrapping = PromptWrapping::GEMMA_IT,
.weight = Type::kF32, .weight = Type::kF32,
}; };
ModelConfig config = ConfigFromModel(info.model); ModelConfig config = ConfigFromModel(info.model);

View File

@@ -40,8 +40,8 @@
#include <vector> #include <vector>
#include "compression/compress.h" #include "compression/compress.h"
#include "compression/shared.h" // ModelTraining
#include "compression/io.h" // Path #include "compression/io.h" // Path
#include "compression/shared.h" // PromptWrapping
#include "gemma/common.h" // Model #include "gemma/common.h" // Model
#include "gemma/weights.h" #include "gemma/weights.h"
#include "util/allocator.h" #include "util/allocator.h"
@@ -74,8 +74,8 @@ struct Args : public ArgsBase<Args> {
// Returns error string or nullptr if OK. // Returns error string or nullptr if OK.
const char* Validate() { const char* Validate() {
if (const char* err = ParseModelTypeAndTraining(model_type_str, model_type_, if (const char* err = ParseModelTypeAndWrapping(model_type_str, model_type_,
model_training_)) { prompt_wrapping_)) {
return err; return err;
} }
if (const char* err = ParseType(weight_type_str, weight_type_)) { if (const char* err = ParseType(weight_type_str, weight_type_)) {
@@ -127,12 +127,12 @@ struct Args : public ArgsBase<Args> {
// Uninitialized before Validate, must call after that. // Uninitialized before Validate, must call after that.
gcpp::Model ModelType() const { return model_type_; } gcpp::Model ModelType() const { return model_type_; }
gcpp::ModelTraining ModelTrainingType() const { return model_training_; } gcpp::PromptWrapping PromptWrappingType() const { return prompt_wrapping_; }
gcpp::Type WeightType() const { return weight_type_; } gcpp::Type WeightType() const { return weight_type_; }
private: private:
Model model_type_; Model model_type_;
ModelTraining model_training_; PromptWrapping prompt_wrapping_;
Type weight_type_; Type weight_type_;
}; };
@@ -212,7 +212,7 @@ namespace gcpp {
void Run(Args& args) { void Run(Args& args) {
hwy::ThreadPool pool(args.num_threads); hwy::ThreadPool pool(args.num_threads);
if (args.ModelTrainingType() == ModelTraining::PALIGEMMA) { if (args.PromptWrappingType() == PromptWrapping::PALIGEMMA) {
HWY_ABORT("PaliGemma is not supported in compress_weights."); HWY_ABORT("PaliGemma is not supported in compress_weights.");
} }
const Model model_type = args.ModelType(); const Model model_type = args.ModelType();

View File

@@ -196,7 +196,7 @@ constexpr bool IsNuqStream() {
} }
// Instruction-tuned models require extra 'turn structure' tokens in prompts. // Instruction-tuned models require extra 'turn structure' tokens in prompts.
enum class ModelTraining { GEMMA_IT, GEMMA_PT, PALIGEMMA }; enum class PromptWrapping { GEMMA_IT, GEMMA_PT, PALIGEMMA };
// Tensor types for loading weights. Note that not all types are supported as // Tensor types for loading weights. Note that not all types are supported as
// weights for a model, but can be used for other purposes, such as types for // weights for a model, but can be used for other purposes, such as types for

View File

@@ -187,7 +187,7 @@ int main(int argc, char** argv) {
const std::string golden_path = const std::string golden_path =
benchmark_args.goldens.path + "/" + benchmark_args.goldens.path + "/" +
gcpp::ModelString(env.GetModel()->Info().model, gcpp::ModelString(env.GetModel()->Info().model,
env.GetModel()->Info().training) + env.GetModel()->Info().wrapping) +
".txt"; ".txt";
return BenchmarkGoldens(env, golden_path); return BenchmarkGoldens(env, golden_path);
} else if (!benchmark_args.summarize_text.Empty()) { } else if (!benchmark_args.summarize_text.Empty()) {

View File

@@ -60,25 +60,25 @@ constexpr Model kModelTypes[] = {
Model::PALIGEMMA2_10B_224, // PaliGemma2 10B 224 Model::PALIGEMMA2_10B_224, // PaliGemma2 10B 224
Model::PALIGEMMA2_10B_448, // PaliGemma2 10B 448 Model::PALIGEMMA2_10B_448, // PaliGemma2 10B 448
}; };
constexpr ModelTraining kModelTraining[] = { constexpr PromptWrapping kPromptWrapping[] = {
ModelTraining::GEMMA_PT, ModelTraining::GEMMA_IT, // Gemma 2B PromptWrapping::GEMMA_PT, PromptWrapping::GEMMA_IT, // Gemma 2B
ModelTraining::GEMMA_PT, ModelTraining::GEMMA_IT, // Gemma 7B PromptWrapping::GEMMA_PT, PromptWrapping::GEMMA_IT, // Gemma 7B
ModelTraining::GEMMA_PT, ModelTraining::GEMMA_IT, // RecurrentGemma PromptWrapping::GEMMA_PT, PromptWrapping::GEMMA_IT, // RecurrentGemma
ModelTraining::GEMMA_IT, // Gemma Tiny PromptWrapping::GEMMA_IT, // Gemma Tiny
ModelTraining::GEMMA_PT, ModelTraining::GEMMA_IT, // Gemma2 2B PromptWrapping::GEMMA_PT, PromptWrapping::GEMMA_IT, // Gemma2 2B
ModelTraining::GEMMA_PT, ModelTraining::GEMMA_IT, // Gemma2 9B PromptWrapping::GEMMA_PT, PromptWrapping::GEMMA_IT, // Gemma2 9B
ModelTraining::GEMMA_PT, ModelTraining::GEMMA_IT, // Gemma2 27B PromptWrapping::GEMMA_PT, PromptWrapping::GEMMA_IT, // Gemma2 27B
ModelTraining::PALIGEMMA, ModelTraining::PALIGEMMA, // PaliGemma 224 / 448 PromptWrapping::PALIGEMMA, PromptWrapping::PALIGEMMA, // PaliGemma 224/448
ModelTraining::PALIGEMMA, ModelTraining::PALIGEMMA, // PG2 3B 224 / 448 PromptWrapping::PALIGEMMA, PromptWrapping::PALIGEMMA, // PG2 3B 224/448
ModelTraining::PALIGEMMA, ModelTraining::PALIGEMMA, // PG2 10B 224 / 448 PromptWrapping::PALIGEMMA, PromptWrapping::PALIGEMMA, // PG2 10B 224/448
}; };
constexpr size_t kNumModelFlags = std::size(kModelFlags); constexpr size_t kNumModelFlags = std::size(kModelFlags);
static_assert(kNumModelFlags == std::size(kModelTypes)); static_assert(kNumModelFlags == std::size(kModelTypes));
static_assert(kNumModelFlags == std::size(kModelTraining)); static_assert(kNumModelFlags == std::size(kPromptWrapping));
const char* ParseModelTypeAndTraining(const std::string& model_flag, const char* ParseModelTypeAndWrapping(const std::string& model_flag,
Model& model, ModelTraining& training) { Model& model, PromptWrapping& wrapping) {
static std::string kErrorMessageBuffer = static std::string kErrorMessageBuffer =
"Invalid or missing model flag, need to specify one of "; "Invalid or missing model flag, need to specify one of ";
for (size_t i = 0; i + 1 < kNumModelFlags; ++i) { for (size_t i = 0; i + 1 < kNumModelFlags; ++i) {
@@ -93,21 +93,21 @@ const char* ParseModelTypeAndTraining(const std::string& model_flag,
for (size_t i = 0; i < kNumModelFlags; ++i) { for (size_t i = 0; i < kNumModelFlags; ++i) {
if (kModelFlags[i] == model_type_lc) { if (kModelFlags[i] == model_type_lc) {
model = kModelTypes[i]; model = kModelTypes[i];
training = kModelTraining[i]; wrapping = kPromptWrapping[i];
HWY_ASSERT(std::string(ModelString(model, training)) == model_type_lc); HWY_ASSERT(std::string(ModelString(model, wrapping)) == model_type_lc);
return nullptr; return nullptr;
} }
} }
return kErrorMessageBuffer.c_str(); return kErrorMessageBuffer.c_str();
} }
const char* ModelString(Model model, ModelTraining training) { const char* ModelString(Model model, PromptWrapping wrapping) {
for (size_t i = 0; i < kNumModelFlags; i++) { for (size_t i = 0; i < kNumModelFlags; i++) {
if (kModelTypes[i] == model && kModelTraining[i] == training) if (kModelTypes[i] == model && kPromptWrapping[i] == wrapping)
return kModelFlags[i]; return kModelFlags[i];
} }
HWY_ABORT("Unknown model %d training %d\n", static_cast<int>(model), HWY_ABORT("Unknown model %d wrapping %d\n", static_cast<int>(model),
static_cast<int>(training)); static_cast<int>(wrapping));
} }
const char* StringFromType(Type type) { const char* StringFromType(Type type) {
@@ -139,7 +139,7 @@ const char* ParseType(const std::string& type_string, Type& type) {
void Wrap(const ModelInfo& info, size_t pos, std::string& prompt) { void Wrap(const ModelInfo& info, size_t pos, std::string& prompt) {
// Instruction-tuned models are trained to expect control tokens. // Instruction-tuned models are trained to expect control tokens.
if (info.training == ModelTraining::GEMMA_IT) { if (info.wrapping == PromptWrapping::GEMMA_IT) {
// Prepend "<end_of_turn>" if this is a multi-turn dialogue continuation. // Prepend "<end_of_turn>" if this is a multi-turn dialogue continuation.
const std::string start = (pos == 0) const std::string start = (pos == 0)
? "<start_of_turn>user\n" ? "<start_of_turn>user\n"

View File

@@ -20,7 +20,7 @@
#include <string> #include <string>
#include "compression/shared.h" // ModelTraining #include "compression/shared.h" // PromptWrapping
#include "gemma/configs.h" // IWYU pragma: export #include "gemma/configs.h" // IWYU pragma: export
#include "hwy/base.h" // ConvertScalarTo #include "hwy/base.h" // ConvertScalarTo
@@ -29,18 +29,18 @@ namespace gcpp {
// Struct to bundle model information. // Struct to bundle model information.
struct ModelInfo { struct ModelInfo {
Model model; Model model;
ModelTraining training; PromptWrapping wrapping;
Type weight; Type weight;
}; };
// Returns error string or nullptr if OK. // Returns error string or nullptr if OK.
// Thread-hostile. // Thread-hostile.
const char* ParseModelTypeAndTraining(const std::string& model_flag, const char* ParseModelTypeAndWrapping(const std::string& model_flag,
Model& model, ModelTraining& training); Model& model, PromptWrapping& wrapping);
const char* ParseType(const std::string& type_string, Type& type); const char* ParseType(const std::string& type_string, Type& type);
// Inverse of ParseModelTypeAndTraining. // Inverse of ParseModelTypeAndWrapping.
const char* ModelString(Model model, ModelTraining training); const char* ModelString(Model model, PromptWrapping wrapping);
const char* StringFromType(Type type); const char* StringFromType(Type type);
// Wraps the given prompt using the expected control tokens for IT models. // Wraps the given prompt using the expected control tokens for IT models.

View File

@@ -366,13 +366,13 @@ bool LayerConfig::TestEqual(const LayerConfig& other, bool partial,
bool ModelConfig::TestEqual(const ModelConfig& other, bool partial, bool ModelConfig::TestEqual(const ModelConfig& other, bool partial,
bool debug) const { bool debug) const {
bool result = true; bool result = true;
// We don't care about model_name, model, training, or weight being different, // We don't care about model_name, model, wrapping, or weight being different,
// but will output in debug mode if they are. // but will output in debug mode if they are.
if (debug) { if (debug) {
WARN_IF_NOT_EQUAL(model_name, other.model_name); WARN_IF_NOT_EQUAL(model_name, other.model_name);
WARN_IF_NOT_EQUAL(static_cast<int>(model), static_cast<int>(other.model)); WARN_IF_NOT_EQUAL(static_cast<int>(model), static_cast<int>(other.model));
WARN_IF_NOT_EQUAL(static_cast<int>(training), WARN_IF_NOT_EQUAL(static_cast<int>(wrapping),
static_cast<int>(other.training)); static_cast<int>(other.wrapping));
WARN_IF_NOT_EQUAL(static_cast<int>(weight), static_cast<int>(other.weight)); WARN_IF_NOT_EQUAL(static_cast<int>(weight), static_cast<int>(other.weight));
} }
TEST_EQUAL(model_dim, other.model_dim); TEST_EQUAL(model_dim, other.model_dim);

View File

@@ -200,7 +200,7 @@ struct ModelConfig {
std::string model_name; std::string model_name;
Model model; Model model;
ModelTraining training; PromptWrapping wrapping;
Type weight; Type weight;
size_t num_layers = 0; size_t num_layers = 0;
size_t model_dim = 0; size_t model_dim = 0;

View File

@@ -244,7 +244,7 @@ class Gemma {
ModelInfo info_; ModelInfo info_;
}; };
// Adds BOS token and possibly 'turn' annotations, which depend on `training` // Adds BOS token and possibly 'turn' annotations, which depend on `info`
// and `pos`, the number of tokens decoded so far; returns the corresponding // and `pos`, the number of tokens decoded so far; returns the corresponding
// tokens. Asserts that tokenization is successful. // tokens. Asserts that tokenization is successful.
std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer, std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer,

View File

@@ -22,7 +22,7 @@
#include <vector> #include <vector>
// Placeholder for internal header, do not modify. // Placeholder for internal header, do not modify.
#include "compression/shared.h" // ModelTraining #include "compression/shared.h" // PromptWrapping
#include "evals/benchmark_helper.h" #include "evals/benchmark_helper.h"
#include "gemma/common.h" #include "gemma/common.h"
#include "gemma/gemma.h" // Gemma #include "gemma/gemma.h" // Gemma
@@ -96,7 +96,7 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
if (have_image) { if (have_image) {
image_tokens = ImageTokens(Extents2D(model.GetModelConfig().vit_seq_len, image_tokens = ImageTokens(Extents2D(model.GetModelConfig().vit_seq_len,
model.GetModelConfig().model_dim)); model.GetModelConfig().model_dim));
HWY_ASSERT(model.Info().training == ModelTraining::PALIGEMMA); HWY_ASSERT(model.Info().wrapping == PromptWrapping::PALIGEMMA);
HWY_ASSERT(image.ReadPPM(args.image_file.path)); HWY_ASSERT(image.ReadPPM(args.image_file.path));
const size_t image_size = model.GetModelConfig().image_size; const size_t image_size = model.GetModelConfig().image_size;
image.Resize(image_size, image_size); image.Resize(image_size, image_size);
@@ -207,7 +207,7 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
std::cout << "\n\n"; std::cout << "\n\n";
// Prepare for the next turn. // Prepare for the next turn.
if (!args.multiturn || model.Info().training == ModelTraining::PALIGEMMA) { if (!args.multiturn || model.Info().wrapping == PromptWrapping::PALIGEMMA) {
abs_pos = 0; // Start a new turn at position 0. abs_pos = 0; // Start a new turn at position 0.
InitGenerator(args, gen); InitGenerator(args, gen);
} else { } else {

View File

@@ -22,7 +22,7 @@
#include <vector> #include <vector>
#include "compression/io.h" // Path #include "compression/io.h" // Path
#include "compression/shared.h" // ModelTraining #include "compression/shared.h" // PromptWrapping
#include "gemma/common.h" // Wrap #include "gemma/common.h" // Wrap
#include "hwy/base.h" // HWY_ASSERT #include "hwy/base.h" // HWY_ASSERT
#include "hwy/profiler.h" #include "hwy/profiler.h"
@@ -110,7 +110,7 @@ std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer,
} }
// PaliGemma separator. The SEP token "\n" is always tokenized separately. // PaliGemma separator. The SEP token "\n" is always tokenized separately.
if (info.training == ModelTraining::PALIGEMMA) { if (info.wrapping == PromptWrapping::PALIGEMMA) {
std::vector<int> sep_tokens; std::vector<int> sep_tokens;
HWY_ASSERT(tokenizer.Encode("\n", &sep_tokens)); HWY_ASSERT(tokenizer.Encode("\n", &sep_tokens));
tokens.insert(tokens.end(), sep_tokens.begin(), sep_tokens.end()); tokens.insert(tokens.end(), sep_tokens.begin(), sep_tokens.end());

View File

@@ -53,7 +53,7 @@ void PaliGemmaTest::InitVit(const std::string& path) {
image_tokens_ = ImageTokens(Extents2D(model.GetModelConfig().vit_seq_len, image_tokens_ = ImageTokens(Extents2D(model.GetModelConfig().vit_seq_len,
model.GetModelConfig().model_dim)); model.GetModelConfig().model_dim));
Image image; Image image;
HWY_ASSERT(model.Info().training == ModelTraining::PALIGEMMA); HWY_ASSERT(model.Info().wrapping == PromptWrapping::PALIGEMMA);
HWY_ASSERT(image.ReadPPM(path)); HWY_ASSERT(image.ReadPPM(path));
const size_t image_size = model.GetModelConfig().image_size; const size_t image_size = model.GetModelConfig().image_size;
image.Resize(image_size, image_size); image.Resize(image_size, image_size);

View File

@@ -136,8 +136,8 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
// Returns error string or nullptr if OK. // Returns error string or nullptr if OK.
const char* Validate() { const char* Validate() {
if (const char* err = ParseModelTypeAndTraining(model_type_str, info_.model, if (const char* err = ParseModelTypeAndWrapping(model_type_str, info_.model,
info_.training)) { info_.wrapping)) {
return err; return err;
} }
if (const char* err = ParseType(weight_type_str, info_.weight)) { if (const char* err = ParseType(weight_type_str, info_.weight)) {