mirror of https://github.com/google/gemma.cpp.git
Cleanup: add ModelInfo struct, remove gcpp::
PiperOrigin-RevId: 648707763
parent b1c1ec1d59
commit 85fcd3cd80
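In short: the three loose parameters (Model, ModelTraining, Type) that were threaded through every API are bundled into one ModelInfo aggregate, and redundant gcpp:: qualifiers inside namespace gcpp are dropped. A minimal sketch of the new struct and its effect on call sites (the enum definitions below are stand-ins; the real ones live in gemma/common.h):

```cpp
// Stand-in enums; the real definitions live in gemma/common.h.
enum class Model { GEMMA_TINY };
enum class ModelTraining { GEMMA_IT };
enum class Type { kF32 };

// Mirrors the struct this commit adds to gemma/gemma.h.
struct ModelInfo {
  Model model;
  ModelTraining training;
  Type weight;
};

// Call sites change from loose arguments to one aggregate:
//   before: Gemma gemma(tokenizer, model_type, weight_type, pool);
//   after:  Gemma gemma(tokenizer, info, pool);
```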
@@ -37,21 +37,24 @@ TEST(OptimizeTest, GradientDescent) {
   hwy::ThreadPool pool(0);
   std::mt19937 gen(42);

-  Model model_type = Model::GEMMA_TINY;
-  Type weight_type = Type::kF32;
+  const ModelInfo info = {
+      .model = Model::GEMMA_TINY,
+      .training = ModelTraining::GEMMA_IT,
+      .weight = Type::kF32,
+  };
   ByteStorageT grad = CallForModelAndWeight<AllocateCompressedWeights>(
-      model_type, weight_type, pool);
+      info.model, info.weight, pool);
   ByteStorageT grad_m = CallForModelAndWeight<AllocateCompressedWeights>(
-      model_type, weight_type, pool);
+      info.model, info.weight, pool);
   ByteStorageT grad_v = CallForModelAndWeight<AllocateCompressedWeights>(
-      model_type, weight_type, pool);
+      info.model, info.weight, pool);
   ByteStorageT forward =
-      CallForModelAndWeight<AllocateForwardPass>(model_type, weight_type);
+      CallForModelAndWeight<AllocateForwardPass>(info.model, info.weight);
   ByteStorageT backward =
-      CallForModelAndWeight<AllocateForwardPass>(model_type, weight_type);
+      CallForModelAndWeight<AllocateForwardPass>(info.model, info.weight);
-  KVCache kv_cache = KVCache::Create(model_type);
+  KVCache kv_cache = KVCache::Create(info.model);

-  Gemma gemma(GemmaTokenizer(), model_type, weight_type, pool);
+  Gemma gemma(GemmaTokenizer(), info, pool);

   const auto generate = [&](const std::vector<int>& prompt) {
     std::vector<int> reply;
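The test now builds its ModelInfo with C++20 designated initializers, which must name members in declaration order (model, training, weight). A self-contained sketch with a hypothetical stand-in struct:

```cpp
struct InfoSketch {  // hypothetical stand-in for gcpp::ModelInfo
  int model;
  int training;
  int weight;
};

// OK: designators follow the declaration order.
constexpr InfoSketch a{.model = 1, .training = 2, .weight = 3};
// OK: omitted members (here .training and .weight) are value-initialized.
constexpr InfoSketch b{.model = 1};
// Ill-formed, will not compile: designators out of declaration order.
// constexpr InfoSketch c{.weight = 3, .model = 1};
```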
@@ -85,14 +88,14 @@ TEST(OptimizeTest, GradientDescent) {
     return ok;
   };

-  RandInitWeights(model_type, weight_type, gemma.Weights(), pool, gen);
+  RandInitWeights(info.model, info.weight, gemma.Weights(), pool, gen);
-  CallForModelAndWeight<ZeroInitCompressedWeights>(
-      model_type, weight_type, grad_m, pool);
+  CallForModelAndWeight<ZeroInitCompressedWeights>(info.model, info.weight,
+                                                   grad_m, pool);
-  CallForModelAndWeight<ZeroInitCompressedWeights>(
-      model_type, weight_type, grad_v, pool);
+  CallForModelAndWeight<ZeroInitCompressedWeights>(info.model, info.weight,
+                                                   grad_v, pool);

   printf("Initial weights:\n");
-  LogWeightStats(model_type, weight_type, gemma.Weights());
+  LogWeightStats(info.model, info.weight, gemma.Weights());

   constexpr size_t kBatchSize = 8;
   const float alpha = 0.001f;
@@ -107,27 +110,27 @@ TEST(OptimizeTest, GradientDescent) {
   size_t num_ok;
   for (; steps < 1000000; ++steps) {
     std::mt19937 sgen(42);
-    CallForModelAndWeight<ZeroInitCompressedWeights>(
-        model_type, weight_type, grad, pool);
+    CallForModelAndWeight<ZeroInitCompressedWeights>(info.model, info.weight,
+                                                     grad, pool);
     float total_loss = 0.0f;
     num_ok = 0;
     for (size_t i = 0; i < kBatchSize; ++i) {
       Prompt prompt = training_task.Sample(sgen);
-      total_loss += CrossEntropyLossForwardPass(model_type, prompt,
+      total_loss += CrossEntropyLossForwardPass(info.model, prompt,
                                                 gemma.Weights(), forward, pool);
-      CrossEntropyLossBackwardPass(model_type, prompt, gemma.Weights(), forward,
+      CrossEntropyLossBackwardPass(info.model, prompt, gemma.Weights(), forward,
                                    grad, backward, pool);
       num_ok += verify(prompt) ? 1 : 0;
     }
     total_loss /= kBatchSize;

-    AdamUpdate(model_type, weight_type, grad, alpha, beta1, beta2, epsilon,
+    AdamUpdate(info.model, info.weight, grad, alpha, beta1, beta2, epsilon,
                steps + 1, gemma.Weights(), grad_m, grad_v, pool);
     printf("step: %zu total_loss: %.15f num_ok: %zu/%zu\n",
            steps, total_loss, num_ok, kBatchSize);
     if (steps % 100 == 0) {
       printf("Batch gradient:\n");
-      LogWeightStats(model_type, weight_type, grad);
+      LogWeightStats(info.model, info.weight, grad);
     }
     if (total_loss < 0.5f) {
       break;
@@ -136,7 +139,7 @@ TEST(OptimizeTest, GradientDescent) {
   }
   printf("Num steps: %zu\n", steps);
   printf("Final weights:\n");
-  LogWeightStats(model_type, weight_type, gemma.Weights());
+  LogWeightStats(info.model, info.weight, gemma.Weights());
   EXPECT_LT(steps, 300);
   EXPECT_EQ(num_ok, kBatchSize);
 }
@@ -128,7 +128,7 @@ int BenchmarkCrossEntropy(GemmaEnv& env, const Path& text,
     size_t num_tokens = std::min<size_t>(prompt.size() - pos, batch_tokens);
     std::vector<int> prompt_slice(prompt.begin() + pos,
                                   prompt.begin() + pos + num_tokens);
-    KVCache kv_cache = KVCache::Create(env.ModelType());
+    KVCache kv_cache = KVCache::Create(env.Info().model);
     float entropy = ComputeCrossEntropy(
         *env.GetModel(), num_tokens, prompt_slice, kv_cache, env.Verbosity());
     total_entropy += entropy;
@@ -185,7 +185,7 @@ int main(int argc, char** argv) {
   if (!benchmark_args.goldens.Empty()) {
     const std::string golden_path =
         benchmark_args.goldens.path + "/" +
-        gcpp::ModelString(env.ModelType(), env.ModelTrainingType()) + ".txt";
+        gcpp::ModelString(env.Info().model, env.Info().training) + ".txt";
     return BenchmarkGoldens(env, golden_path);
   } else if (!benchmark_args.summarize_text.Empty()) {
     return BenchmarkSummary(env, benchmark_args.summarize_text);
@@ -71,7 +71,7 @@ float ComputeCrossEntropy(Gemma& gemma, size_t max_tokens,

   // TWeight is unused, but we have to pass it to Config*.
   const int vocab_size =
-      CallForModel</*TWeight=*/float, GetVocabSize>(gemma.ModelType());
+      CallForModel</*TWeight=*/float, GetVocabSize>(gemma.Info().model);
   float cross_entropy = std::log(vocab_size);  // first token
   size_t pos = 1;
   const SampleFunc sample_token = [&](const float* probs,
@@ -104,7 +104,7 @@ void Run(GemmaEnv& env, JsonArgs& json) {
       "Do not include any justifications or explanations. Reply only with a "
      "letter.";
   const std::vector<int> prompt =
-      WrapAndTokenize(env.GetModel()->Tokenizer(), env.ModelTrainingType(),
+      WrapAndTokenize(env.GetModel()->Tokenizer(), env.Info(),
                       /*pos=*/0, prompt_string);
   const size_t prompt_size = prompt.size();

@@ -64,7 +64,7 @@ GemmaEnv::GemmaEnv(const LoaderArgs& loader, const InferenceArgs& inference,
       pool_(app_.num_threads) {
   // For many-core, pinning workers to cores helps.
   if (app_.num_threads > 10) {
-    gcpp::PinWorkersToCores(pool_);
+    PinWorkersToCores(pool_);
   }

   AbortIfInvalidArgs(inference_args_);
@@ -78,7 +78,7 @@ GemmaEnv::GemmaEnv(const LoaderArgs& loader, const InferenceArgs& inference,

   kv_caches_.reserve(16);
   for (int i = 0; i < 16; ++i) {
-    kv_caches_.push_back(new KVCache(KVCache::Create(loader_.ModelType())));
+    kv_caches_.push_back(new KVCache(KVCache::Create(model_->Info().model)));
   }
 }

@@ -181,8 +181,7 @@ std::vector<std::pair<std::string, size_t>> GemmaEnv::BatchQueryModel2(
 }

 std::pair<std::string, size_t> GemmaEnv::QueryModel(std::string& input) {
-  const std::vector<int> prompt =
-      WrapAndTokenize(model_->Tokenizer(), loader_.ModelTrainingType(),
+  const std::vector<int> prompt = WrapAndTokenize(model_->Tokenizer(), Info(),
                       /*pos=*/0, input);
   return QueryModel(prompt);
 }

@@ -193,8 +192,7 @@ std::vector<std::pair<std::string, size_t>> GemmaEnv::BatchQueryModel(
   for (auto& input : inputs) {
     std::string mutable_prompt = input;
     prompts.push_back(std::make_unique<std::vector<int>>(
-        WrapAndTokenize(model_->Tokenizer(),
-                        loader_.ModelTrainingType(),
+        WrapAndTokenize(model_->Tokenizer(), model_->Info(),
                         /*pos=*/0, mutable_prompt)));
   }
   std::vector<hwy::Span<int>> prompt_vector;
@@ -234,8 +232,7 @@ void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
     time_t now = time(nullptr);
     char* dt = ctime(&now);  // NOLINT
     std::cout << "Date & Time : " << dt
-              << "Prefill Token Batch Size : " << gcpp::kPrefillBatchSize
-              << "\n"
+              << "Prefill Token Batch Size : " << kPrefillBatchSize << "\n"
               << "Hardware concurrency : "
               << std::thread::hardware_concurrency() << "\n"
               << "Instruction set : "
@@ -247,14 +244,13 @@ void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
     }
     std::cout << "Compiled config : " << CompiledConfig() << "\n"
               << "Weight Type : "
-              << gcpp::StringFromType(loader.WeightType()) << "\n"
+              << StringFromType(loader.Info().weight) << "\n"
               << "EmbedderInput Type : "
-              << gcpp::TypeName(gcpp::EmbedderInputT()) << "\n";
+              << TypeName(EmbedderInputT()) << "\n";
   }
 }

-void ShowHelp(gcpp::LoaderArgs& loader, gcpp::InferenceArgs& inference,
-              gcpp::AppArgs& app) {
+void ShowHelp(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
   std::cerr
       << "\n\ngemma.cpp : a lightweight, standalone C++ inference engine\n"
          "==========================================================\n\n"
@@ -24,7 +24,6 @@
 #include <utility>
 #include <vector>

-#include "gemma/common.h"
 #include "gemma/gemma.h"
 #include "util/app.h"
 #include "hwy/base.h"
@@ -83,12 +82,10 @@ class GemmaEnv {

   // Returns nullptr if the model failed to load.
   Gemma* GetModel() const { return model_.get(); }
-  Model ModelType() const { return loader_.ModelType(); }
-  ModelTraining ModelTrainingType() const {
-    return loader_.ModelTrainingType();
-  }
   int Verbosity() const { return app_.verbosity; }
   RuntimeConfig& MutableConfig() { return runtime_config_; }
+  const ModelInfo& Info() const { return loader_.Info(); }
   InferenceArgs& MutableInferenceArgs() { return inference_args_; }
   std::mt19937& MutableGen() { return gen_; }
   KVCache& MutableKVCache() { return *kv_caches_[0]; }
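For callers of GemmaEnv, the accessor migration implied by this hunk is mechanical; a sketch (env is an already-constructed GemmaEnv):

```cpp
// Before: one accessor per field.
//   Model m = env.ModelType();
//   ModelTraining t = env.ModelTrainingType();

// After: a single Info() accessor exposes all three fields.
//   Model m = env.Info().model;
//   ModelTraining t = env.Info().training;
//   Type w = env.Info().weight;
```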
@@ -39,7 +39,7 @@
 #include <array>
 #include <memory>
 #include <string>
-#include <utility>
+#include <utility>  // std::move
 #include <vector>

 #include "compression/io.h"  // Path
@@ -1085,32 +1085,26 @@ struct AllocateState {

 }  // namespace

-Gemma::Gemma(const Path& tokenizer_path, const Path& weights, Model model_type,
-             Type weight_type, hwy::ThreadPool& pool)
-    : pool_(pool),
-      tokenizer_(tokenizer_path),
-      model_type_(model_type),
-      weight_type_(weight_type) {
-  weights_u8_ = LoadCompressedWeights(weights, model_type, weight_type, pool);
-  CallForModelAndWeight<AllocateState>(model_type, weight_type, prefill_u8_,
+Gemma::Gemma(const Path& tokenizer_path, const Path& weights,
+             const ModelInfo& info, hwy::ThreadPool& pool)
+    : pool_(pool), tokenizer_(tokenizer_path), info_(info) {
+  weights_u8_ = LoadCompressedWeights(weights, info.model, info.weight, pool);
+  CallForModelAndWeight<AllocateState>(info.model, info.weight, prefill_u8_,
                                        decode_u8_);
 }

-Gemma::Gemma(GemmaTokenizer&& tokenizer, Model model_type, Type weight_type,
+Gemma::Gemma(GemmaTokenizer&& tokenizer, const ModelInfo& info,
              hwy::ThreadPool& pool)
-    : pool_(pool),
-      tokenizer_(std::move(tokenizer)),
-      model_type_(model_type),
-      weight_type_(weight_type) {
-  HWY_ASSERT(weight_type == Type::kF32);
-  weights_u8_ = CallForModel<float, AllocateCompressedWeights>(
-      model_type, pool);
-  CallForModelAndWeight<AllocateState>(model_type, weight_type, prefill_u8_,
+    : pool_(pool), tokenizer_(std::move(tokenizer)), info_(info) {
+  HWY_ASSERT(info.weight == Type::kF32);
+  weights_u8_ =
+      CallForModel<float, AllocateCompressedWeights>(info.model, pool);
+  CallForModelAndWeight<AllocateState>(info.model, info.weight, prefill_u8_,
                                        decode_u8_);
 }

 Gemma::~Gemma() {
-  CallForModelAndWeight<DeleteCompressedWeights>(model_type_, weight_type_,
+  CallForModelAndWeight<DeleteCompressedWeights>(info_.model, info_.weight,
                                                  weights_u8_);
 }

@@ -1120,7 +1114,7 @@ void Gemma::Generate(const RuntimeConfig& runtime_config,
   pool_.SetWaitMode(hwy::PoolWaitMode::kSpin);

   GEMMA_EXPORT_AND_DISPATCH(
-      model_type_, weight_type_, GenerateOneQueryT,
+      info_.model, info_.weight, GenerateOneQueryT,
       (weights_u8_, prefill_u8_, decode_u8_, runtime_config, prompt, start_pos,
        kv_cache, pool_, timing_info));

@@ -1135,24 +1129,29 @@ void Gemma::GenerateBatch(const RuntimeConfig& runtime_config,
   pool_.SetWaitMode(hwy::PoolWaitMode::kSpin);

   GEMMA_EXPORT_AND_DISPATCH(
-      model_type_, weight_type_, GenerateBatchT,
+      info_.model, info_.weight, GenerateBatchT,
       (weights_u8_, prefill_u8_, decode_u8_, runtime_config, prompts, start_pos,
        kv_caches, pool_, timing_info));

   pool_.SetWaitMode(hwy::PoolWaitMode::kBlock);
 }

-std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer,
-                                 const ModelTraining training, size_t pos,
-                                 std::string& prompt) {
+void Wrap(const ModelInfo& info, size_t pos, std::string& prompt) {
   // Instruction-tuned models are trained to expect control tokens.
-  if (training == ModelTraining::GEMMA_IT) {
+  if (info.training == ModelTraining::GEMMA_IT) {
     // Prepend "<end_of_turn>" if this is a multi-turn dialogue continuation.
     const std::string start = (pos == 0)
                                   ? "<start_of_turn>user\n"
                                   : "<end_of_turn>\n<start_of_turn>user\n";
     prompt = start + prompt + "<end_of_turn>\n<start_of_turn>model\n";
   }
+}
+
+std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer,
+                                 const ModelInfo& info, size_t pos,
+                                 std::string& prompt) {
+  Wrap(info, pos, prompt);

   std::vector<int> tokens;
   HWY_ASSERT(tokenizer.Encode(prompt, &tokens));
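The extracted Wrap logic can be checked in isolation. A runnable sketch that reproduces just the string manipulation from the hunk above, with the control-token literals taken verbatim from the diff:

```cpp
#include <cstdio>
#include <string>

int main() {
  std::string prompt = "Hello";
  const size_t pos = 0;  // 0 = first turn; nonzero prepends <end_of_turn>
  const std::string start = (pos == 0)
                                ? "<start_of_turn>user\n"
                                : "<end_of_turn>\n<start_of_turn>user\n";
  prompt = start + prompt + "<end_of_turn>\n<start_of_turn>model\n";
  std::printf("%s", prompt.c_str());
  // Output:
  // <start_of_turn>user
  // Hello<end_of_turn>
  // <start_of_turn>model
}
```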
@@ -93,6 +93,13 @@ using SampleFunc = std::function<int(const float*, size_t)>;
 using LayersOutputFunc =
     std::function<void(int, const std::string&, const float*, size_t)>;

+// TODO(janwas): move into common.h, merge with parser/ToString.
+struct ModelInfo {
+  Model model;
+  ModelTraining training;
+  Type weight;
+};
+
 struct RuntimeConfig {
   size_t max_tokens;
   size_t max_generated_tokens;
@@ -115,15 +122,15 @@ struct TimingInfo {

 class Gemma {
  public:
-  Gemma(const Path& tokenizer_path, const Path& weights, Model model_type,
-        Type weight_type, hwy::ThreadPool& pool);
+  Gemma(const Path& tokenizer_path, const Path& weights, const ModelInfo& info,
+        hwy::ThreadPool& pool);

   // Allocates weights, caller is responsible for filling them.
-  Gemma(GemmaTokenizer&& tokenizer, Model model_type, Type weight_type,
+  Gemma(GemmaTokenizer&& tokenizer, const ModelInfo& info,
         hwy::ThreadPool& pool);
   ~Gemma();

-  Model ModelType() const { return model_type_; }
+  const ModelInfo& Info() const { return info_; }
   const GemmaTokenizer& Tokenizer() const { return tokenizer_; }
   const ByteStorageT& Weights() const { return weights_u8_; }
   const ByteStorageT& Prefill() const { return prefill_u8_; }
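Taken together, client construction changes as in this sketch (tokenizer_path, weights_path and pool are assumed to exist; the enumerators are the ones this diff itself uses):

```cpp
// Before this commit (loose type parameters):
//   Gemma gemma(tokenizer_path, weights_path, Model::GEMMA_TINY, Type::kF32, pool);

// After this commit (one aggregate):
const gcpp::ModelInfo info = {
    .model = gcpp::Model::GEMMA_TINY,
    .training = gcpp::ModelTraining::GEMMA_IT,
    .weight = gcpp::Type::kF32,
};
gcpp::Gemma gemma(tokenizer_path, weights_path, info, pool);
gcpp::Model m = gemma.Info().model;  // replaces the removed ModelType()
```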
@@ -147,15 +154,14 @@ class Gemma {
   ByteStorageT weights_u8_;
   ByteStorageT prefill_u8_;
   ByteStorageT decode_u8_;
-  Model model_type_;
-  Type weight_type_;
+  ModelInfo info_;
 };

 // Adds BOS token and possibly 'turn' annotations, which depend on `training`
 // and `pos`, the number of tokens decoded so far; returns the corresponding
 // tokens. Asserts that tokenization is successful.
 std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer,
-                                 ModelTraining training, size_t pos,
+                                 const ModelInfo& info, size_t pos,
                                  std::string& prompt);

 // DEPRECATED, call Gemma::Generate directly.
gemma/run.cc (21 changed lines)
@ -53,10 +53,9 @@ static constexpr std::string_view kAsciiArtBanner = R""(
|
||||||
)"";
|
)"";
|
||||||
|
|
||||||
// The main Read-Eval-Print Loop.
|
// The main Read-Eval-Print Loop.
|
||||||
void ReplGemma(gcpp::Gemma& model, ModelTraining training,
|
void ReplGemma(Gemma& model, KVCache& kv_cache, hwy::ThreadPool& pool,
|
||||||
gcpp::KVCache& kv_cache, hwy::ThreadPool& pool,
|
|
||||||
const InferenceArgs& args, int verbosity,
|
const InferenceArgs& args, int verbosity,
|
||||||
const gcpp::AcceptFunc& accept_token, std::string& eot_line) {
|
const AcceptFunc& accept_token, std::string& eot_line) {
|
||||||
PROFILER_ZONE("Gen.misc");
|
PROFILER_ZONE("Gen.misc");
|
||||||
size_t abs_pos = 0; // absolute token index over all turns
|
size_t abs_pos = 0; // absolute token index over all turns
|
||||||
int current_pos = 0; // token index within the current turn
|
int current_pos = 0; // token index within the current turn
|
||||||
|
|
@@ -73,7 +72,7 @@ void ReplGemma(gcpp::Gemma& model, ModelTraining training,
       // <= since position is incremented before
       if (current_pos <= prompt_size) {
         std::cerr << "." << std::flush;
-      } else if (token == gcpp::EOS_ID) {
+      } else if (token == EOS_ID) {
         if (!args.multiturn) {
           abs_pos = 0;
           if (args.deterministic) {
@@ -131,8 +130,8 @@ void ReplGemma(gcpp::Gemma& model, ModelTraining training,
      continue;
     }

-    const std::vector<int> prompt =
-        WrapAndTokenize(model.Tokenizer(), training, abs_pos, prompt_string);
+    const std::vector<int> prompt = WrapAndTokenize(
+        model.Tokenizer(), model.Info(), abs_pos, prompt_string);
     prompt_size = prompt.size();
     std::cerr << "\n"
               << "[ Reading prompt ] " << std::flush;
@@ -143,7 +142,7 @@ void ReplGemma(gcpp::Gemma& model, ModelTraining training,
     }

     TimingInfo timing_info;
-    gcpp::RuntimeConfig runtime_config = {
+    RuntimeConfig runtime_config = {
         .max_tokens = args.max_tokens,
         .max_generated_tokens = args.max_generated_tokens,
         .temperature = args.temperature,
@@ -179,8 +178,8 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
     PinWorkersToCores(pool);
   }

-  gcpp::Gemma model = gcpp::CreateGemma(loader, pool);
-  KVCache kv_cache = KVCache::Create(loader.ModelType());
+  Gemma model = CreateGemma(loader, pool);
+  KVCache kv_cache = KVCache::Create(model.Info().model);

   if (app.verbosity >= 1) {
     const std::string instructions =
@@ -208,8 +207,8 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
     std::cout << "\n" << instructions << "\n";
   }

-  ReplGemma(model, loader.ModelTrainingType(), kv_cache, pool, inference,
-            app.verbosity, AcceptFunc(), app.eot_line);
+  ReplGemma(model, kv_cache, pool, inference, app.verbosity, AcceptFunc(),
+            app.eot_line);
 }

 }  // namespace gcpp
util/app.h (32 changed lines)
@@ -18,18 +18,12 @@
 #ifndef THIRD_PARTY_GEMMA_CPP_UTIL_APP_H_
 #define THIRD_PARTY_GEMMA_CPP_UTIL_APP_H_

-#include <memory>
-
-#include "hwy/contrib/thread_pool/thread_pool.h"
-#if HWY_OS_LINUX
-#include <sched.h>
-#endif  // HWY_OS_LINUX
 #include <stddef.h>
 #include <stdio.h>

 #include <algorithm>  // std::clamp
+#include <memory>
 #include <string>
-#include <thread>  // NOLINT
 #include <vector>

 #include "compression/io.h"  // Path
@@ -38,8 +32,13 @@
 #include "gemma/gemma.h"
 #include "util/args.h"
 #include "hwy/base.h"  // HWY_ASSERT
+#include "hwy/contrib/thread_pool/thread_pool.h"
 #include "hwy/contrib/thread_pool/topology.h"

+#if HWY_OS_LINUX
+#include <sched.h>
+#endif  // HWY_OS_LINUX
+
 namespace gcpp {

 static inline const char* CompiledConfig() {
@ -168,11 +167,11 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
|
||||||
|
|
||||||
// Returns error string or nullptr if OK.
|
// Returns error string or nullptr if OK.
|
||||||
const char* Validate() {
|
const char* Validate() {
|
||||||
if (const char* err = ParseModelTypeAndTraining(model_type_str, model_type_,
|
if (const char* err = ParseModelTypeAndTraining(model_type_str, info_.model,
|
||||||
model_training_)) {
|
info_.training)) {
|
||||||
return err;
|
return err;
|
||||||
}
|
}
|
||||||
if (const char* err = ParseType(weight_type_str, weight_type_)) {
|
if (const char* err = ParseType(weight_type_str, info_.weight)) {
|
||||||
return err;
|
return err;
|
||||||
}
|
}
|
||||||
if (tokenizer.path.empty()) {
|
if (tokenizer.path.empty()) {
|
||||||
|
|
@ -226,26 +225,21 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Uninitialized before Validate, must call after that.
|
// Uninitialized before Validate, must call after that.
|
||||||
gcpp::Model ModelType() const { return model_type_; }
|
const ModelInfo& Info() const { return info_; }
|
||||||
gcpp::ModelTraining ModelTrainingType() const { return model_training_; }
|
|
||||||
gcpp::Type WeightType() const { return weight_type_; }
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
Model model_type_;
|
ModelInfo info_;
|
||||||
ModelTraining model_training_;
|
|
||||||
Type weight_type_;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
static inline Gemma CreateGemma(const LoaderArgs& loader,
|
static inline Gemma CreateGemma(const LoaderArgs& loader,
|
||||||
hwy::ThreadPool& pool) {
|
hwy::ThreadPool& pool) {
|
||||||
return Gemma(loader.tokenizer, loader.weights, loader.ModelType(),
|
return Gemma(loader.tokenizer, loader.weights, loader.Info(), pool);
|
||||||
loader.WeightType(), pool);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
static inline std::unique_ptr<Gemma> AllocateGemma(const LoaderArgs& loader,
|
static inline std::unique_ptr<Gemma> AllocateGemma(const LoaderArgs& loader,
|
||||||
hwy::ThreadPool& pool) {
|
hwy::ThreadPool& pool) {
|
||||||
return std::make_unique<Gemma>(loader.tokenizer, loader.weights,
|
return std::make_unique<Gemma>(loader.tokenizer, loader.weights,
|
||||||
loader.ModelType(), loader.WeightType(), pool);
|
loader.Info(), pool);
|
||||||
}
|
}
|
||||||
|
|
||||||
struct InferenceArgs : public ArgsBase<InferenceArgs> {
|
struct InferenceArgs : public ArgsBase<InferenceArgs> {
|
||||||
|
|
|
||||||
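End to end, the updated loading flow looks roughly like this sketch (error handling abbreviated; treat the exact LoaderArgs constructor signature as an assumption, though run.cc drives it the same way):

```cpp
#include <cstdio>

#include "gemma/gemma.h"  // Gemma, KVCache, ModelInfo
#include "util/app.h"     // LoaderArgs, CreateGemma
#include "hwy/contrib/thread_pool/thread_pool.h"

int main(int argc, char** argv) {
  gcpp::LoaderArgs loader(argc, argv);        // parses --model, --weights, ...
  if (const char* err = loader.Validate()) {  // fills loader.Info()
    std::fprintf(stderr, "%s\n", err);
    return 1;
  }
  hwy::ThreadPool pool(0);
  gcpp::Gemma model = gcpp::CreateGemma(loader, pool);
  // KV-cache sizing now keys off the bundled ModelInfo:
  gcpp::KVCache kv_cache = gcpp::KVCache::Create(model.Info().model);
  return 0;
}
```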