Rename ModelTraining to PromptWrapping which is a more accurate name.

PiperOrigin-RevId: 705881500
This commit is contained in:
Daniel Keysers 2024-12-13 07:45:25 -08:00 committed by Copybara-Service
parent 6254f2e5ca
commit 62c70d6715
13 changed files with 49 additions and 49 deletions

View File

@ -49,7 +49,7 @@ TEST(OptimizeTest, GradientDescent) {
const ModelInfo info = {
.model = Model::GEMMA_TINY,
.training = ModelTraining::GEMMA_IT,
.wrapping = PromptWrapping::GEMMA_IT,
.weight = Type::kF32,
};
ModelConfig config = ConfigFromModel(info.model);

View File

@ -40,8 +40,8 @@
#include <vector>
#include "compression/compress.h"
#include "compression/shared.h" // ModelTraining
#include "compression/io.h" // Path
#include "compression/shared.h" // PromptWrapping
#include "gemma/common.h" // Model
#include "gemma/weights.h"
#include "util/allocator.h"
@ -74,8 +74,8 @@ struct Args : public ArgsBase<Args> {
// Returns error string or nullptr if OK.
const char* Validate() {
if (const char* err = ParseModelTypeAndTraining(model_type_str, model_type_,
model_training_)) {
if (const char* err = ParseModelTypeAndWrapping(model_type_str, model_type_,
prompt_wrapping_)) {
return err;
}
if (const char* err = ParseType(weight_type_str, weight_type_)) {
@ -127,12 +127,12 @@ struct Args : public ArgsBase<Args> {
// Uninitialized before Validate, must call after that.
gcpp::Model ModelType() const { return model_type_; }
gcpp::ModelTraining ModelTrainingType() const { return model_training_; }
gcpp::PromptWrapping PromptWrappingType() const { return prompt_wrapping_; }
gcpp::Type WeightType() const { return weight_type_; }
private:
Model model_type_;
ModelTraining model_training_;
PromptWrapping prompt_wrapping_;
Type weight_type_;
};
@ -212,7 +212,7 @@ namespace gcpp {
void Run(Args& args) {
hwy::ThreadPool pool(args.num_threads);
if (args.ModelTrainingType() == ModelTraining::PALIGEMMA) {
if (args.PromptWrappingType() == PromptWrapping::PALIGEMMA) {
HWY_ABORT("PaliGemma is not supported in compress_weights.");
}
const Model model_type = args.ModelType();

View File

@ -196,7 +196,7 @@ constexpr bool IsNuqStream() {
}
// Instruction-tuned models require extra 'turn structure' tokens in prompts.
enum class ModelTraining { GEMMA_IT, GEMMA_PT, PALIGEMMA };
enum class PromptWrapping { GEMMA_IT, GEMMA_PT, PALIGEMMA };
// Tensor types for loading weights. Note that not all types are supported as
// weights for a model, but can be used for other purposes, such as types for

View File

@ -187,7 +187,7 @@ int main(int argc, char** argv) {
const std::string golden_path =
benchmark_args.goldens.path + "/" +
gcpp::ModelString(env.GetModel()->Info().model,
env.GetModel()->Info().training) +
env.GetModel()->Info().wrapping) +
".txt";
return BenchmarkGoldens(env, golden_path);
} else if (!benchmark_args.summarize_text.Empty()) {

View File

@ -60,25 +60,25 @@ constexpr Model kModelTypes[] = {
Model::PALIGEMMA2_10B_224, // PaliGemma2 10B 224
Model::PALIGEMMA2_10B_448, // PaliGemma2 10B 448
};
constexpr ModelTraining kModelTraining[] = {
ModelTraining::GEMMA_PT, ModelTraining::GEMMA_IT, // Gemma 2B
ModelTraining::GEMMA_PT, ModelTraining::GEMMA_IT, // Gemma 7B
ModelTraining::GEMMA_PT, ModelTraining::GEMMA_IT, // RecurrentGemma
ModelTraining::GEMMA_IT, // Gemma Tiny
ModelTraining::GEMMA_PT, ModelTraining::GEMMA_IT, // Gemma2 2B
ModelTraining::GEMMA_PT, ModelTraining::GEMMA_IT, // Gemma2 9B
ModelTraining::GEMMA_PT, ModelTraining::GEMMA_IT, // Gemma2 27B
ModelTraining::PALIGEMMA, ModelTraining::PALIGEMMA, // PaliGemma 224 / 448
ModelTraining::PALIGEMMA, ModelTraining::PALIGEMMA, // PG2 3B 224 / 448
ModelTraining::PALIGEMMA, ModelTraining::PALIGEMMA, // PG2 10B 224 / 448
constexpr PromptWrapping kPromptWrapping[] = {
PromptWrapping::GEMMA_PT, PromptWrapping::GEMMA_IT, // Gemma 2B
PromptWrapping::GEMMA_PT, PromptWrapping::GEMMA_IT, // Gemma 7B
PromptWrapping::GEMMA_PT, PromptWrapping::GEMMA_IT, // RecurrentGemma
PromptWrapping::GEMMA_IT, // Gemma Tiny
PromptWrapping::GEMMA_PT, PromptWrapping::GEMMA_IT, // Gemma2 2B
PromptWrapping::GEMMA_PT, PromptWrapping::GEMMA_IT, // Gemma2 9B
PromptWrapping::GEMMA_PT, PromptWrapping::GEMMA_IT, // Gemma2 27B
PromptWrapping::PALIGEMMA, PromptWrapping::PALIGEMMA, // PaliGemma 224/448
PromptWrapping::PALIGEMMA, PromptWrapping::PALIGEMMA, // PG2 3B 224/448
PromptWrapping::PALIGEMMA, PromptWrapping::PALIGEMMA, // PG2 10B 224/448
};
constexpr size_t kNumModelFlags = std::size(kModelFlags);
static_assert(kNumModelFlags == std::size(kModelTypes));
static_assert(kNumModelFlags == std::size(kModelTraining));
static_assert(kNumModelFlags == std::size(kPromptWrapping));
const char* ParseModelTypeAndTraining(const std::string& model_flag,
Model& model, ModelTraining& training) {
const char* ParseModelTypeAndWrapping(const std::string& model_flag,
Model& model, PromptWrapping& wrapping) {
static std::string kErrorMessageBuffer =
"Invalid or missing model flag, need to specify one of ";
for (size_t i = 0; i + 1 < kNumModelFlags; ++i) {
@ -93,21 +93,21 @@ const char* ParseModelTypeAndTraining(const std::string& model_flag,
for (size_t i = 0; i < kNumModelFlags; ++i) {
if (kModelFlags[i] == model_type_lc) {
model = kModelTypes[i];
training = kModelTraining[i];
HWY_ASSERT(std::string(ModelString(model, training)) == model_type_lc);
wrapping = kPromptWrapping[i];
HWY_ASSERT(std::string(ModelString(model, wrapping)) == model_type_lc);
return nullptr;
}
}
return kErrorMessageBuffer.c_str();
}
const char* ModelString(Model model, ModelTraining training) {
const char* ModelString(Model model, PromptWrapping wrapping) {
for (size_t i = 0; i < kNumModelFlags; i++) {
if (kModelTypes[i] == model && kModelTraining[i] == training)
if (kModelTypes[i] == model && kPromptWrapping[i] == wrapping)
return kModelFlags[i];
}
HWY_ABORT("Unknown model %d training %d\n", static_cast<int>(model),
static_cast<int>(training));
HWY_ABORT("Unknown model %d wrapping %d\n", static_cast<int>(model),
static_cast<int>(wrapping));
}
const char* StringFromType(Type type) {
@ -139,7 +139,7 @@ const char* ParseType(const std::string& type_string, Type& type) {
void Wrap(const ModelInfo& info, size_t pos, std::string& prompt) {
// Instruction-tuned models are trained to expect control tokens.
if (info.training == ModelTraining::GEMMA_IT) {
if (info.wrapping == PromptWrapping::GEMMA_IT) {
// Prepend "<end_of_turn>" if this is a multi-turn dialogue continuation.
const std::string start = (pos == 0)
? "<start_of_turn>user\n"

View File

@ -20,7 +20,7 @@
#include <string>
#include "compression/shared.h" // ModelTraining
#include "compression/shared.h" // PromptWrapping
#include "gemma/configs.h" // IWYU pragma: export
#include "hwy/base.h" // ConvertScalarTo
@ -29,18 +29,18 @@ namespace gcpp {
// Struct to bundle model information.
struct ModelInfo {
Model model;
ModelTraining training;
PromptWrapping wrapping;
Type weight;
};
// Returns error string or nullptr if OK.
// Thread-hostile.
const char* ParseModelTypeAndTraining(const std::string& model_flag,
Model& model, ModelTraining& training);
const char* ParseModelTypeAndWrapping(const std::string& model_flag,
Model& model, PromptWrapping& wrapping);
const char* ParseType(const std::string& type_string, Type& type);
// Inverse of ParseModelTypeAndTraining.
const char* ModelString(Model model, ModelTraining training);
// Inverse of ParseModelTypeAndWrapping.
const char* ModelString(Model model, PromptWrapping wrapping);
const char* StringFromType(Type type);
// Wraps the given prompt using the expected control tokens for IT models.

View File

@ -366,13 +366,13 @@ bool LayerConfig::TestEqual(const LayerConfig& other, bool partial,
bool ModelConfig::TestEqual(const ModelConfig& other, bool partial,
bool debug) const {
bool result = true;
// We don't care about model_name, model, training, or weight being different,
// We don't care about model_name, model, wrapping, or weight being different,
// but will output in debug mode if they are.
if (debug) {
WARN_IF_NOT_EQUAL(model_name, other.model_name);
WARN_IF_NOT_EQUAL(static_cast<int>(model), static_cast<int>(other.model));
WARN_IF_NOT_EQUAL(static_cast<int>(training),
static_cast<int>(other.training));
WARN_IF_NOT_EQUAL(static_cast<int>(wrapping),
static_cast<int>(other.wrapping));
WARN_IF_NOT_EQUAL(static_cast<int>(weight), static_cast<int>(other.weight));
}
TEST_EQUAL(model_dim, other.model_dim);

View File

@ -200,7 +200,7 @@ struct ModelConfig {
std::string model_name;
Model model;
ModelTraining training;
PromptWrapping wrapping;
Type weight;
size_t num_layers = 0;
size_t model_dim = 0;

View File

@ -244,7 +244,7 @@ class Gemma {
ModelInfo info_;
};
// Adds BOS token and possibly 'turn' annotations, which depend on `training`
// Adds BOS token and possibly 'turn' annotations, which depend on `info`
// and `pos`, the number of tokens decoded so far; returns the corresponding
// tokens. Asserts that tokenization is successful.
std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer,

View File

@ -22,7 +22,7 @@
#include <vector>
// Placeholder for internal header, do not modify.
#include "compression/shared.h" // ModelTraining
#include "compression/shared.h" // PromptWrapping
#include "evals/benchmark_helper.h"
#include "gemma/common.h"
#include "gemma/gemma.h" // Gemma
@ -96,7 +96,7 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
if (have_image) {
image_tokens = ImageTokens(Extents2D(model.GetModelConfig().vit_seq_len,
model.GetModelConfig().model_dim));
HWY_ASSERT(model.Info().training == ModelTraining::PALIGEMMA);
HWY_ASSERT(model.Info().wrapping == PromptWrapping::PALIGEMMA);
HWY_ASSERT(image.ReadPPM(args.image_file.path));
const size_t image_size = model.GetModelConfig().image_size;
image.Resize(image_size, image_size);
@ -207,7 +207,7 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
std::cout << "\n\n";
// Prepare for the next turn.
if (!args.multiturn || model.Info().training == ModelTraining::PALIGEMMA) {
if (!args.multiturn || model.Info().wrapping == PromptWrapping::PALIGEMMA) {
abs_pos = 0; // Start a new turn at position 0.
InitGenerator(args, gen);
} else {

View File

@ -22,7 +22,7 @@
#include <vector>
#include "compression/io.h" // Path
#include "compression/shared.h" // ModelTraining
#include "compression/shared.h" // PromptWrapping
#include "gemma/common.h" // Wrap
#include "hwy/base.h" // HWY_ASSERT
#include "hwy/profiler.h"
@ -110,7 +110,7 @@ std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer,
}
// PaliGemma separator. The SEP token "\n" is always tokenized separately.
if (info.training == ModelTraining::PALIGEMMA) {
if (info.wrapping == PromptWrapping::PALIGEMMA) {
std::vector<int> sep_tokens;
HWY_ASSERT(tokenizer.Encode("\n", &sep_tokens));
tokens.insert(tokens.end(), sep_tokens.begin(), sep_tokens.end());

View File

@ -53,7 +53,7 @@ void PaliGemmaTest::InitVit(const std::string& path) {
image_tokens_ = ImageTokens(Extents2D(model.GetModelConfig().vit_seq_len,
model.GetModelConfig().model_dim));
Image image;
HWY_ASSERT(model.Info().training == ModelTraining::PALIGEMMA);
HWY_ASSERT(model.Info().wrapping == PromptWrapping::PALIGEMMA);
HWY_ASSERT(image.ReadPPM(path));
const size_t image_size = model.GetModelConfig().image_size;
image.Resize(image_size, image_size);

View File

@ -136,8 +136,8 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
// Returns error string or nullptr if OK.
const char* Validate() {
if (const char* err = ParseModelTypeAndTraining(model_type_str, info_.model,
info_.training)) {
if (const char* err = ParseModelTypeAndWrapping(model_type_str, info_.model,
info_.wrapping)) {
return err;
}
if (const char* err = ParseType(weight_type_str, info_.weight)) {