Cleanup: add ModelInfo struct, remove gcpp::

PiperOrigin-RevId: 648707763
Authored by Jan Wassenberg on 2024-07-02 07:10:32 -07:00; committed by Copybara-Service
parent b1c1ec1d59
commit 85fcd3cd80
10 changed files with 101 additions and 107 deletions
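
In short, the Model, ModelTraining, and Type values that were previously threaded through the API as separate parameters are now bundled into a single ModelInfo struct, and redundant gcpp:: qualifiers inside the gcpp namespace are dropped. A minimal sketch of the updated call pattern, mirroring the test change below (not part of this commit; GEMMA_TINY and the unfilled weights are only illustrative):

#include "gemma/gemma.h"  // Gemma, ModelInfo, KVCache
#include "hwy/contrib/thread_pool/thread_pool.h"

void ExampleUsage() {
  hwy::ThreadPool pool(0);
  // One struct instead of separate Model/ModelTraining/Type arguments.
  const gcpp::ModelInfo info = {
      .model = gcpp::Model::GEMMA_TINY,
      .training = gcpp::ModelTraining::GEMMA_IT,
      .weight = gcpp::Type::kF32,
  };
  // Previously: Gemma(GemmaTokenizer(), model_type, weight_type, pool).
  // This constructor allocates weights; the caller is responsible for filling them.
  gcpp::Gemma gemma(gcpp::GemmaTokenizer(), info, pool);
  gcpp::KVCache kv_cache = gcpp::KVCache::Create(gemma.Info().model);
  (void)kv_cache;
}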

View File

@@ -37,21 +37,24 @@ TEST(OptimizeTest, GradientDescent) {
   hwy::ThreadPool pool(0);
   std::mt19937 gen(42);
 
-  Model model_type = Model::GEMMA_TINY;
-  Type weight_type = Type::kF32;
+  const ModelInfo info = {
+      .model = Model::GEMMA_TINY,
+      .training = ModelTraining::GEMMA_IT,
+      .weight = Type::kF32,
+  };
 
   ByteStorageT grad = CallForModelAndWeight<AllocateCompressedWeights>(
-      model_type, weight_type, pool);
+      info.model, info.weight, pool);
   ByteStorageT grad_m = CallForModelAndWeight<AllocateCompressedWeights>(
-      model_type, weight_type, pool);
+      info.model, info.weight, pool);
   ByteStorageT grad_v = CallForModelAndWeight<AllocateCompressedWeights>(
-      model_type, weight_type, pool);
+      info.model, info.weight, pool);
   ByteStorageT forward =
-      CallForModelAndWeight<AllocateForwardPass>(model_type, weight_type);
+      CallForModelAndWeight<AllocateForwardPass>(info.model, info.weight);
   ByteStorageT backward =
-      CallForModelAndWeight<AllocateForwardPass>(model_type, weight_type);
-  KVCache kv_cache = KVCache::Create(model_type);
-  Gemma gemma(GemmaTokenizer(), model_type, weight_type, pool);
+      CallForModelAndWeight<AllocateForwardPass>(info.model, info.weight);
+  KVCache kv_cache = KVCache::Create(info.model);
+  Gemma gemma(GemmaTokenizer(), info, pool);
 
   const auto generate = [&](const std::vector<int>& prompt) {
     std::vector<int> reply;

@@ -85,14 +88,14 @@ TEST(OptimizeTest, GradientDescent) {
     return ok;
   };
 
-  RandInitWeights(model_type, weight_type, gemma.Weights(), pool, gen);
-  CallForModelAndWeight<ZeroInitCompressedWeights>(
-      model_type, weight_type, grad_m, pool);
-  CallForModelAndWeight<ZeroInitCompressedWeights>(
-      model_type, weight_type, grad_v, pool);
+  RandInitWeights(info.model, info.weight, gemma.Weights(), pool, gen);
+  CallForModelAndWeight<ZeroInitCompressedWeights>(info.model, info.weight,
+                                                   grad_m, pool);
+  CallForModelAndWeight<ZeroInitCompressedWeights>(info.model, info.weight,
+                                                   grad_v, pool);
 
   printf("Initial weights:\n");
-  LogWeightStats(model_type, weight_type, gemma.Weights());
+  LogWeightStats(info.model, info.weight, gemma.Weights());
 
   constexpr size_t kBatchSize = 8;
   const float alpha = 0.001f;

@@ -107,27 +110,27 @@ TEST(OptimizeTest, GradientDescent) {
   size_t num_ok;
   for (; steps < 1000000; ++steps) {
     std::mt19937 sgen(42);
-    CallForModelAndWeight<ZeroInitCompressedWeights>(
-        model_type, weight_type, grad, pool);
+    CallForModelAndWeight<ZeroInitCompressedWeights>(info.model, info.weight,
+                                                     grad, pool);
     float total_loss = 0.0f;
     num_ok = 0;
     for (size_t i = 0; i < kBatchSize; ++i) {
       Prompt prompt = training_task.Sample(sgen);
-      total_loss += CrossEntropyLossForwardPass(model_type, prompt,
+      total_loss += CrossEntropyLossForwardPass(info.model, prompt,
                                                 gemma.Weights(), forward, pool);
-      CrossEntropyLossBackwardPass(model_type, prompt, gemma.Weights(), forward,
+      CrossEntropyLossBackwardPass(info.model, prompt, gemma.Weights(), forward,
                                    grad, backward, pool);
       num_ok += verify(prompt) ? 1 : 0;
     }
     total_loss /= kBatchSize;
-    AdamUpdate(model_type, weight_type, grad, alpha, beta1, beta2, epsilon,
+    AdamUpdate(info.model, info.weight, grad, alpha, beta1, beta2, epsilon,
                steps + 1, gemma.Weights(), grad_m, grad_v, pool);
     printf("step: %zu total_loss: %.15f num_ok: %zu/%zu\n",
            steps, total_loss, num_ok, kBatchSize);
     if (steps % 100 == 0) {
       printf("Batch gradient:\n");
-      LogWeightStats(model_type, weight_type, grad);
+      LogWeightStats(info.model, info.weight, grad);
     }
     if (total_loss < 0.5f) {
       break;

@@ -136,7 +139,7 @@ TEST(OptimizeTest, GradientDescent) {
   }
   printf("Num steps: %zu\n", steps);
   printf("Final weights:\n");
-  LogWeightStats(model_type, weight_type, gemma.Weights());
+  LogWeightStats(info.model, info.weight, gemma.Weights());
   EXPECT_LT(steps, 300);
   EXPECT_EQ(num_ok, kBatchSize);
 }

View File

@@ -128,7 +128,7 @@ int BenchmarkCrossEntropy(GemmaEnv& env, const Path& text,
     size_t num_tokens = std::min<size_t>(prompt.size() - pos, batch_tokens);
     std::vector<int> prompt_slice(prompt.begin() + pos,
                                   prompt.begin() + pos + num_tokens);
-    KVCache kv_cache = KVCache::Create(env.ModelType());
+    KVCache kv_cache = KVCache::Create(env.Info().model);
    float entropy = ComputeCrossEntropy(
        *env.GetModel(), num_tokens, prompt_slice, kv_cache, env.Verbosity());
     total_entropy += entropy;

@@ -185,7 +185,7 @@ int main(int argc, char** argv) {
   if (!benchmark_args.goldens.Empty()) {
     const std::string golden_path =
         benchmark_args.goldens.path + "/" +
-        gcpp::ModelString(env.ModelType(), env.ModelTrainingType()) + ".txt";
+        gcpp::ModelString(env.Info().model, env.Info().training) + ".txt";
     return BenchmarkGoldens(env, golden_path);
   } else if (!benchmark_args.summarize_text.Empty()) {
     return BenchmarkSummary(env, benchmark_args.summarize_text);

View File

@@ -71,7 +71,7 @@ float ComputeCrossEntropy(Gemma& gemma, size_t max_tokens,
   // TWeight is unused, but we have to pass it to Config*.
   const int vocab_size =
-      CallForModel</*TWeight=*/float, GetVocabSize>(gemma.ModelType());
+      CallForModel</*TWeight=*/float, GetVocabSize>(gemma.Info().model);
   float cross_entropy = std::log(vocab_size);  // first token
   size_t pos = 1;
   const SampleFunc sample_token = [&](const float* probs,

View File

@@ -104,7 +104,7 @@ void Run(GemmaEnv& env, JsonArgs& json) {
       "Do not include any justifications or explanations. Reply only with a "
       "letter.";
   const std::vector<int> prompt =
-      WrapAndTokenize(env.GetModel()->Tokenizer(), env.ModelTrainingType(),
+      WrapAndTokenize(env.GetModel()->Tokenizer(), env.Info(),
                       /*pos=*/0, prompt_string);
   const size_t prompt_size = prompt.size();

View File

@@ -64,7 +64,7 @@ GemmaEnv::GemmaEnv(const LoaderArgs& loader, const InferenceArgs& inference,
       pool_(app_.num_threads) {
   // For many-core, pinning workers to cores helps.
   if (app_.num_threads > 10) {
-    gcpp::PinWorkersToCores(pool_);
+    PinWorkersToCores(pool_);
   }
   AbortIfInvalidArgs(inference_args_);

@@ -78,7 +78,7 @@ GemmaEnv::GemmaEnv(const LoaderArgs& loader, const InferenceArgs& inference,
   kv_caches_.reserve(16);
   for (int i = 0; i < 16; ++i) {
-    kv_caches_.push_back(new KVCache(KVCache::Create(loader_.ModelType())));
+    kv_caches_.push_back(new KVCache(KVCache::Create(model_->Info().model)));
   }
 }

@@ -181,9 +181,8 @@ std::vector<std::pair<std::string, size_t>> GemmaEnv::BatchQueryModel2(
 }
 
 std::pair<std::string, size_t> GemmaEnv::QueryModel(std::string& input) {
-  const std::vector<int> prompt =
-      WrapAndTokenize(model_->Tokenizer(), loader_.ModelTrainingType(),
-                      /*pos=*/0, input);
+  const std::vector<int> prompt = WrapAndTokenize(model_->Tokenizer(), Info(),
+                                                  /*pos=*/0, input);
   return QueryModel(prompt);
 }
 
 std::vector<std::pair<std::string, size_t>> GemmaEnv::BatchQueryModel(

@@ -193,9 +192,8 @@ std::vector<std::pair<std::string, size_t>> GemmaEnv::BatchQueryModel(
   for (auto& input : inputs) {
     std::string mutable_prompt = input;
     prompts.push_back(std::make_unique<std::vector<int>>(
-        WrapAndTokenize(model_->Tokenizer(),
-                        loader_.ModelTrainingType(),
-                        /*pos=*/0, mutable_prompt)));
+        WrapAndTokenize(model_->Tokenizer(), model_->Info(),
+                        /*pos=*/0, mutable_prompt)));
   }
   std::vector<hwy::Span<int>> prompt_vector;
   prompt_vector.reserve(prompts.size());

@@ -234,8 +232,7 @@ void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
     time_t now = time(nullptr);
     char* dt = ctime(&now);  // NOLINT
     std::cout << "Date & Time : " << dt
-              << "Prefill Token Batch Size : " << gcpp::kPrefillBatchSize
-              << "\n"
+              << "Prefill Token Batch Size : " << kPrefillBatchSize << "\n"
               << "Hardware concurrency : "
               << std::thread::hardware_concurrency() << "\n"
               << "Instruction set : "

@@ -247,14 +244,13 @@ void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
     }
     std::cout << "Compiled config : " << CompiledConfig() << "\n"
               << "Weight Type : "
-              << gcpp::StringFromType(loader.WeightType()) << "\n"
+              << StringFromType(loader.Info().weight) << "\n"
               << "EmbedderInput Type : "
-              << gcpp::TypeName(gcpp::EmbedderInputT()) << "\n";
+              << TypeName(EmbedderInputT()) << "\n";
   }
 }
 
-void ShowHelp(gcpp::LoaderArgs& loader, gcpp::InferenceArgs& inference,
-              gcpp::AppArgs& app) {
+void ShowHelp(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
   std::cerr
       << "\n\ngemma.cpp : a lightweight, standalone C++ inference engine\n"
          "==========================================================\n\n"

View File

@@ -24,7 +24,6 @@
 #include <utility>
 #include <vector>
 
-#include "gemma/common.h"
 #include "gemma/gemma.h"
 #include "util/app.h"
 #include "hwy/base.h"

@@ -83,12 +82,10 @@ class GemmaEnv {
   // Returns nullptr if the model failed to load.
   Gemma* GetModel() const { return model_.get(); }
 
-  Model ModelType() const { return loader_.ModelType(); }
-  ModelTraining ModelTrainingType() const {
-    return loader_.ModelTrainingType();
-  }
   int Verbosity() const { return app_.verbosity; }
   RuntimeConfig& MutableConfig() { return runtime_config_; }
+  const ModelInfo& Info() const { return loader_.Info(); }
   InferenceArgs& MutableInferenceArgs() { return inference_args_; }
   std::mt19937& MutableGen() { return gen_; }
   KVCache& MutableKVCache() { return *kv_caches_[0]; }

View File

@@ -39,7 +39,7 @@
 #include <array>
 #include <memory>
 #include <string>
-#include <utility>
+#include <utility>  // std::move
 #include <vector>
 
 #include "compression/io.h"  // Path

@@ -1085,32 +1085,26 @@ struct AllocateState {
 }  // namespace
 
-Gemma::Gemma(const Path& tokenizer_path, const Path& weights, Model model_type,
-             Type weight_type, hwy::ThreadPool& pool)
-    : pool_(pool),
-      tokenizer_(tokenizer_path),
-      model_type_(model_type),
-      weight_type_(weight_type) {
-  weights_u8_ = LoadCompressedWeights(weights, model_type, weight_type, pool);
-  CallForModelAndWeight<AllocateState>(model_type, weight_type, prefill_u8_,
+Gemma::Gemma(const Path& tokenizer_path, const Path& weights,
+             const ModelInfo& info, hwy::ThreadPool& pool)
+    : pool_(pool), tokenizer_(tokenizer_path), info_(info) {
+  weights_u8_ = LoadCompressedWeights(weights, info.model, info.weight, pool);
+  CallForModelAndWeight<AllocateState>(info.model, info.weight, prefill_u8_,
                                        decode_u8_);
 }
 
-Gemma::Gemma(GemmaTokenizer&& tokenizer, Model model_type, Type weight_type,
+Gemma::Gemma(GemmaTokenizer&& tokenizer, const ModelInfo& info,
              hwy::ThreadPool& pool)
-    : pool_(pool),
-      tokenizer_(std::move(tokenizer)),
-      model_type_(model_type),
-      weight_type_(weight_type) {
-  HWY_ASSERT(weight_type == Type::kF32);
-  weights_u8_ = CallForModel<float, AllocateCompressedWeights>(
-      model_type, pool);
-  CallForModelAndWeight<AllocateState>(model_type, weight_type, prefill_u8_,
+    : pool_(pool), tokenizer_(std::move(tokenizer)), info_(info) {
+  HWY_ASSERT(info.weight == Type::kF32);
+  weights_u8_ =
+      CallForModel<float, AllocateCompressedWeights>(info.model, pool);
+  CallForModelAndWeight<AllocateState>(info.model, info.weight, prefill_u8_,
                                        decode_u8_);
 }
 
 Gemma::~Gemma() {
-  CallForModelAndWeight<DeleteCompressedWeights>(model_type_, weight_type_,
+  CallForModelAndWeight<DeleteCompressedWeights>(info_.model, info_.weight,
                                                  weights_u8_);
 }

@@ -1120,7 +1114,7 @@ void Gemma::Generate(const RuntimeConfig& runtime_config,
   pool_.SetWaitMode(hwy::PoolWaitMode::kSpin);
 
   GEMMA_EXPORT_AND_DISPATCH(
-      model_type_, weight_type_, GenerateOneQueryT,
+      info_.model, info_.weight, GenerateOneQueryT,
       (weights_u8_, prefill_u8_, decode_u8_, runtime_config, prompt, start_pos,
        kv_cache, pool_, timing_info));

@@ -1135,24 +1129,29 @@ void Gemma::GenerateBatch(const RuntimeConfig& runtime_config,
   pool_.SetWaitMode(hwy::PoolWaitMode::kSpin);
 
   GEMMA_EXPORT_AND_DISPATCH(
-      model_type_, weight_type_, GenerateBatchT,
+      info_.model, info_.weight, GenerateBatchT,
      (weights_u8_, prefill_u8_, decode_u8_, runtime_config, prompts, start_pos,
       kv_caches, pool_, timing_info));
 
   pool_.SetWaitMode(hwy::PoolWaitMode::kBlock);
 }
 
-std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer,
-                                 const ModelTraining training, size_t pos,
-                                 std::string& prompt) {
+void Wrap(const ModelInfo& info, size_t pos, std::string& prompt) {
   // Instruction-tuned models are trained to expect control tokens.
-  if (training == ModelTraining::GEMMA_IT) {
+  if (info.training == ModelTraining::GEMMA_IT) {
     // Prepend "<end_of_turn>" if this is a multi-turn dialogue continuation.
     const std::string start = (pos == 0)
                                   ? "<start_of_turn>user\n"
                                   : "<end_of_turn>\n<start_of_turn>user\n";
     prompt = start + prompt + "<end_of_turn>\n<start_of_turn>model\n";
   }
+}
+
+std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer,
+                                 const ModelInfo& info, size_t pos,
+                                 std::string& prompt) {
+  Wrap(info, pos, prompt);
   std::vector<int> tokens;
   HWY_ASSERT(tokenizer.Encode(prompt, &tokens));
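
The turn-wrapping logic above is now factored into Wrap(), and WrapAndTokenize() takes the whole ModelInfo rather than just ModelTraining. A hedged usage sketch (model is assumed to be a loaded gcpp::Gemma as constructed elsewhere in this diff; the prompt text is illustrative):

// For info.training == ModelTraining::GEMMA_IT, the prompt is surrounded by
// <start_of_turn>user ... <end_of_turn>\n<start_of_turn>model markers before tokenization.
std::string prompt_string = "Write a haiku about tensors.";
const std::vector<int> prompt = gcpp::WrapAndTokenize(
    model.Tokenizer(), model.Info(), /*pos=*/0, prompt_string);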

View File

@@ -93,6 +93,13 @@ using SampleFunc = std::function<int(const float*, size_t)>;
 using LayersOutputFunc =
     std::function<void(int, const std::string&, const float*, size_t)>;
 
+// TODO(janwas): move into common.h, merge with parser/ToString.
+struct ModelInfo {
+  Model model;
+  ModelTraining training;
+  Type weight;
+};
+
 struct RuntimeConfig {
   size_t max_tokens;
   size_t max_generated_tokens;

@@ -115,15 +122,15 @@ struct TimingInfo {
 class Gemma {
  public:
-  Gemma(const Path& tokenizer_path, const Path& weights, Model model_type,
-        Type weight_type, hwy::ThreadPool& pool);
+  Gemma(const Path& tokenizer_path, const Path& weights, const ModelInfo& info,
+        hwy::ThreadPool& pool);
 
   // Allocates weights, caller is responsible for filling them.
-  Gemma(GemmaTokenizer&& tokenizer, Model model_type, Type weight_type,
+  Gemma(GemmaTokenizer&& tokenizer, const ModelInfo& info,
         hwy::ThreadPool& pool);
 
   ~Gemma();
 
-  Model ModelType() const { return model_type_; }
+  const ModelInfo& Info() const { return info_; }
   const GemmaTokenizer& Tokenizer() const { return tokenizer_; }
   const ByteStorageT& Weights() const { return weights_u8_; }
   const ByteStorageT& Prefill() const { return prefill_u8_; }

@@ -147,15 +154,14 @@ class Gemma {
   ByteStorageT weights_u8_;
   ByteStorageT prefill_u8_;
   ByteStorageT decode_u8_;
-  Model model_type_;
-  Type weight_type_;
+  ModelInfo info_;
 };
 
 // Adds BOS token and possibly 'turn' annotations, which depend on `training`
 // and `pos`, the number of tokens decoded so far; returns the corresponding
 // tokens. Asserts that tokenization is successful.
 std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer,
-                                 ModelTraining training, size_t pos,
+                                 const ModelInfo& info, size_t pos,
                                  std::string& prompt);
 
 // DEPRECATED, call Gemma::Generate directly.

View File

@@ -53,10 +53,9 @@ static constexpr std::string_view kAsciiArtBanner = R""(
 )"";
 
 // The main Read-Eval-Print Loop.
-void ReplGemma(gcpp::Gemma& model, ModelTraining training,
-               gcpp::KVCache& kv_cache, hwy::ThreadPool& pool,
+void ReplGemma(Gemma& model, KVCache& kv_cache, hwy::ThreadPool& pool,
                const InferenceArgs& args, int verbosity,
-               const gcpp::AcceptFunc& accept_token, std::string& eot_line) {
+               const AcceptFunc& accept_token, std::string& eot_line) {
   PROFILER_ZONE("Gen.misc");
   size_t abs_pos = 0;   // absolute token index over all turns
   int current_pos = 0;  // token index within the current turn

@@ -73,7 +72,7 @@ void ReplGemma(gcpp::Gemma& model, ModelTraining training,
     // <= since position is incremented before
     if (current_pos <= prompt_size) {
       std::cerr << "." << std::flush;
-    } else if (token == gcpp::EOS_ID) {
+    } else if (token == EOS_ID) {
       if (!args.multiturn) {
         abs_pos = 0;
         if (args.deterministic) {

@@ -131,8 +130,8 @@ void ReplGemma(gcpp::Gemma& model, ModelTraining training,
       continue;
     }
 
-    const std::vector<int> prompt =
-        WrapAndTokenize(model.Tokenizer(), training, abs_pos, prompt_string);
+    const std::vector<int> prompt = WrapAndTokenize(
+        model.Tokenizer(), model.Info(), abs_pos, prompt_string);
     prompt_size = prompt.size();
     std::cerr << "\n"
               << "[ Reading prompt ] " << std::flush;

@@ -143,7 +142,7 @@ void ReplGemma(gcpp::Gemma& model, ModelTraining training,
     }
 
     TimingInfo timing_info;
-    gcpp::RuntimeConfig runtime_config = {
+    RuntimeConfig runtime_config = {
        .max_tokens = args.max_tokens,
        .max_generated_tokens = args.max_generated_tokens,
        .temperature = args.temperature,

@@ -179,8 +178,8 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
     PinWorkersToCores(pool);
   }
 
-  gcpp::Gemma model = gcpp::CreateGemma(loader, pool);
-  KVCache kv_cache = KVCache::Create(loader.ModelType());
+  Gemma model = CreateGemma(loader, pool);
+  KVCache kv_cache = KVCache::Create(model.Info().model);
 
   if (app.verbosity >= 1) {
     const std::string instructions =

@@ -208,8 +207,8 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
     std::cout << "\n" << instructions << "\n";
   }
 
-  ReplGemma(model, loader.ModelTrainingType(), kv_cache, pool, inference,
-            app.verbosity, AcceptFunc(), app.eot_line);
+  ReplGemma(model, kv_cache, pool, inference, app.verbosity, AcceptFunc(),
+            app.eot_line);
 }
 
 }  // namespace gcpp

View File

@@ -18,18 +18,12 @@
 #ifndef THIRD_PARTY_GEMMA_CPP_UTIL_APP_H_
 #define THIRD_PARTY_GEMMA_CPP_UTIL_APP_H_
 
-#include <memory>
-
-#include "hwy/contrib/thread_pool/thread_pool.h"
-
-#if HWY_OS_LINUX
-#include <sched.h>
-#endif  // HWY_OS_LINUX
-
 #include <stddef.h>
 #include <stdio.h>
 
 #include <algorithm>  // std::clamp
+#include <memory>
 #include <string>
-#include <thread>  // NOLINT
 #include <vector>
 
 #include "compression/io.h"  // Path

@@ -38,8 +32,13 @@
 #include "gemma/gemma.h"
 #include "util/args.h"
 #include "hwy/base.h"  // HWY_ASSERT
+#include "hwy/contrib/thread_pool/thread_pool.h"
 #include "hwy/contrib/thread_pool/topology.h"
 
+#if HWY_OS_LINUX
+#include <sched.h>
+#endif  // HWY_OS_LINUX
+
 namespace gcpp {
 
 static inline const char* CompiledConfig() {

@@ -168,11 +167,11 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
   // Returns error string or nullptr if OK.
   const char* Validate() {
-    if (const char* err = ParseModelTypeAndTraining(model_type_str, model_type_,
-                                                    model_training_)) {
+    if (const char* err = ParseModelTypeAndTraining(model_type_str, info_.model,
+                                                    info_.training)) {
      return err;
    }
-    if (const char* err = ParseType(weight_type_str, weight_type_)) {
+    if (const char* err = ParseType(weight_type_str, info_.weight)) {
      return err;
    }
    if (tokenizer.path.empty()) {

@@ -226,26 +225,21 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
   }
 
   // Uninitialized before Validate, must call after that.
-  gcpp::Model ModelType() const { return model_type_; }
-  gcpp::ModelTraining ModelTrainingType() const { return model_training_; }
-  gcpp::Type WeightType() const { return weight_type_; }
+  const ModelInfo& Info() const { return info_; }
 
  private:
-  Model model_type_;
-  ModelTraining model_training_;
-  Type weight_type_;
+  ModelInfo info_;
 };
 
 static inline Gemma CreateGemma(const LoaderArgs& loader,
                                 hwy::ThreadPool& pool) {
-  return Gemma(loader.tokenizer, loader.weights, loader.ModelType(),
-               loader.WeightType(), pool);
+  return Gemma(loader.tokenizer, loader.weights, loader.Info(), pool);
 }
 
 static inline std::unique_ptr<Gemma> AllocateGemma(const LoaderArgs& loader,
                                                    hwy::ThreadPool& pool) {
   return std::make_unique<Gemma>(loader.tokenizer, loader.weights,
-                                 loader.ModelType(), loader.WeightType(), pool);
+                                 loader.Info(), pool);
 }
 
 struct InferenceArgs : public ArgsBase<InferenceArgs> {