Cleanup: add ModelInfo struct, remove gcpp::

PiperOrigin-RevId: 648707763
Jan Wassenberg 2024-07-02 07:10:32 -07:00 committed by Copybara-Service
parent b1c1ec1d59
commit 85fcd3cd80
10 changed files with 101 additions and 107 deletions
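In short: the Model, ModelTraining, and Type values that previously traveled as separate parameters are bundled into a single ModelInfo aggregate, and redundant gcpp:: qualifiers inside the gcpp namespace are dropped. A minimal before/after sketch, using only declarations that appear in the diffs below (variable names and the specific enum values are illustrative):

// New aggregate, added to gemma/gemma.h (with a TODO to move it to common.h).
struct ModelInfo {
  Model model;
  ModelTraining training;
  Type weight;
};

// Before:
//   Gemma gemma(tokenizer_path, weights_path, model_type, weight_type, pool);
// After:
const ModelInfo info = {
    .model = Model::GEMMA_TINY,
    .training = ModelTraining::GEMMA_IT,
    .weight = Type::kF32,
};
Gemma gemma(tokenizer_path, weights_path, info, pool);

Callers read the fields back through the new Info() accessors (Gemma::Info(), GemmaEnv::Info(), LoaderArgs::Info()) instead of the removed ModelType()/ModelTrainingType()/WeightType() getters.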

View File

@ -37,21 +37,24 @@ TEST(OptimizeTest, GradientDescent) {
hwy::ThreadPool pool(0);
std::mt19937 gen(42);
Model model_type = Model::GEMMA_TINY;
Type weight_type = Type::kF32;
const ModelInfo info = {
.model = Model::GEMMA_TINY,
.training = ModelTraining::GEMMA_IT,
.weight = Type::kF32,
};
ByteStorageT grad = CallForModelAndWeight<AllocateCompressedWeights>(
model_type, weight_type, pool);
info.model, info.weight, pool);
ByteStorageT grad_m = CallForModelAndWeight<AllocateCompressedWeights>(
model_type, weight_type, pool);
info.model, info.weight, pool);
ByteStorageT grad_v = CallForModelAndWeight<AllocateCompressedWeights>(
model_type, weight_type, pool);
info.model, info.weight, pool);
ByteStorageT forward =
CallForModelAndWeight<AllocateForwardPass>(model_type, weight_type);
CallForModelAndWeight<AllocateForwardPass>(info.model, info.weight);
ByteStorageT backward =
CallForModelAndWeight<AllocateForwardPass>(model_type, weight_type);
KVCache kv_cache = KVCache::Create(model_type);
CallForModelAndWeight<AllocateForwardPass>(info.model, info.weight);
KVCache kv_cache = KVCache::Create(info.model);
Gemma gemma(GemmaTokenizer(), model_type, weight_type, pool);
Gemma gemma(GemmaTokenizer(), info, pool);
const auto generate = [&](const std::vector<int>& prompt) {
std::vector<int> reply;
@ -85,14 +88,14 @@ TEST(OptimizeTest, GradientDescent) {
return ok;
};
RandInitWeights(model_type, weight_type, gemma.Weights(), pool, gen);
CallForModelAndWeight<ZeroInitCompressedWeights>(
model_type, weight_type, grad_m, pool);
CallForModelAndWeight<ZeroInitCompressedWeights>(
model_type, weight_type, grad_v, pool);
RandInitWeights(info.model, info.weight, gemma.Weights(), pool, gen);
CallForModelAndWeight<ZeroInitCompressedWeights>(info.model, info.weight,
grad_m, pool);
CallForModelAndWeight<ZeroInitCompressedWeights>(info.model, info.weight,
grad_v, pool);
printf("Initial weights:\n");
LogWeightStats(model_type, weight_type, gemma.Weights());
LogWeightStats(info.model, info.weight, gemma.Weights());
constexpr size_t kBatchSize = 8;
const float alpha = 0.001f;
@ -107,27 +110,27 @@ TEST(OptimizeTest, GradientDescent) {
size_t num_ok;
for (; steps < 1000000; ++steps) {
std::mt19937 sgen(42);
CallForModelAndWeight<ZeroInitCompressedWeights>(
model_type, weight_type, grad, pool);
CallForModelAndWeight<ZeroInitCompressedWeights>(info.model, info.weight,
grad, pool);
float total_loss = 0.0f;
num_ok = 0;
for (size_t i = 0; i < kBatchSize; ++i) {
Prompt prompt = training_task.Sample(sgen);
total_loss += CrossEntropyLossForwardPass(model_type, prompt,
total_loss += CrossEntropyLossForwardPass(info.model, prompt,
gemma.Weights(), forward, pool);
CrossEntropyLossBackwardPass(model_type, prompt, gemma.Weights(), forward,
CrossEntropyLossBackwardPass(info.model, prompt, gemma.Weights(), forward,
grad, backward, pool);
num_ok += verify(prompt) ? 1 : 0;
}
total_loss /= kBatchSize;
AdamUpdate(model_type, weight_type, grad, alpha, beta1, beta2, epsilon,
AdamUpdate(info.model, info.weight, grad, alpha, beta1, beta2, epsilon,
steps + 1, gemma.Weights(), grad_m, grad_v, pool);
printf("step: %zu total_loss: %.15f num_ok: %zu/%zu\n",
steps, total_loss, num_ok, kBatchSize);
if (steps % 100 == 0) {
printf("Batch gradient:\n");
LogWeightStats(model_type, weight_type, grad);
LogWeightStats(info.model, info.weight, grad);
}
if (total_loss < 0.5f) {
break;
@ -136,7 +139,7 @@ TEST(OptimizeTest, GradientDescent) {
}
printf("Num steps: %zu\n", steps);
printf("Final weights:\n");
LogWeightStats(model_type, weight_type, gemma.Weights());
LogWeightStats(info.model, info.weight, gemma.Weights());
EXPECT_LT(steps, 300);
EXPECT_EQ(num_ok, kBatchSize);
}

View File

@ -128,7 +128,7 @@ int BenchmarkCrossEntropy(GemmaEnv& env, const Path& text,
size_t num_tokens = std::min<size_t>(prompt.size() - pos, batch_tokens);
std::vector<int> prompt_slice(prompt.begin() + pos,
prompt.begin() + pos + num_tokens);
KVCache kv_cache = KVCache::Create(env.ModelType());
KVCache kv_cache = KVCache::Create(env.Info().model);
float entropy = ComputeCrossEntropy(
*env.GetModel(), num_tokens, prompt_slice, kv_cache, env.Verbosity());
total_entropy += entropy;
@ -185,7 +185,7 @@ int main(int argc, char** argv) {
if (!benchmark_args.goldens.Empty()) {
const std::string golden_path =
benchmark_args.goldens.path + "/" +
gcpp::ModelString(env.ModelType(), env.ModelTrainingType()) + ".txt";
gcpp::ModelString(env.Info().model, env.Info().training) + ".txt";
return BenchmarkGoldens(env, golden_path);
} else if (!benchmark_args.summarize_text.Empty()) {
return BenchmarkSummary(env, benchmark_args.summarize_text);

View File

@ -71,7 +71,7 @@ float ComputeCrossEntropy(Gemma& gemma, size_t max_tokens,
// TWeight is unused, but we have to pass it to Config*.
const int vocab_size =
CallForModel</*TWeight=*/float, GetVocabSize>(gemma.ModelType());
CallForModel</*TWeight=*/float, GetVocabSize>(gemma.Info().model);
float cross_entropy = std::log(vocab_size); // first token
size_t pos = 1;
const SampleFunc sample_token = [&](const float* probs,

View File

@ -104,7 +104,7 @@ void Run(GemmaEnv& env, JsonArgs& json) {
"Do not include any justifications or explanations. Reply only with a "
"letter.";
const std::vector<int> prompt =
WrapAndTokenize(env.GetModel()->Tokenizer(), env.ModelTrainingType(),
WrapAndTokenize(env.GetModel()->Tokenizer(), env.Info(),
/*pos=*/0, prompt_string);
const size_t prompt_size = prompt.size();

View File

@ -64,7 +64,7 @@ GemmaEnv::GemmaEnv(const LoaderArgs& loader, const InferenceArgs& inference,
pool_(app_.num_threads) {
// For many-core, pinning workers to cores helps.
if (app_.num_threads > 10) {
gcpp::PinWorkersToCores(pool_);
PinWorkersToCores(pool_);
}
AbortIfInvalidArgs(inference_args_);
@ -78,7 +78,7 @@ GemmaEnv::GemmaEnv(const LoaderArgs& loader, const InferenceArgs& inference,
kv_caches_.reserve(16);
for (int i = 0; i < 16; ++i) {
kv_caches_.push_back(new KVCache(KVCache::Create(loader_.ModelType())));
kv_caches_.push_back(new KVCache(KVCache::Create(model_->Info().model)));
}
}
@ -181,9 +181,8 @@ std::vector<std::pair<std::string, size_t>> GemmaEnv::BatchQueryModel2(
}
std::pair<std::string, size_t> GemmaEnv::QueryModel(std::string& input) {
const std::vector<int> prompt =
WrapAndTokenize(model_->Tokenizer(), loader_.ModelTrainingType(),
/*pos=*/0, input);
const std::vector<int> prompt = WrapAndTokenize(model_->Tokenizer(), Info(),
/*pos=*/0, input);
return QueryModel(prompt);
}
std::vector<std::pair<std::string, size_t>> GemmaEnv::BatchQueryModel(
@ -193,9 +192,8 @@ std::vector<std::pair<std::string, size_t>> GemmaEnv::BatchQueryModel(
for (auto& input : inputs) {
std::string mutable_prompt = input;
prompts.push_back(std::make_unique<std::vector<int>>(
WrapAndTokenize(model_->Tokenizer(),
loader_.ModelTrainingType(),
/*pos=*/0, mutable_prompt)));
WrapAndTokenize(model_->Tokenizer(), model_->Info(),
/*pos=*/0, mutable_prompt)));
}
std::vector<hwy::Span<int>> prompt_vector;
prompt_vector.reserve(prompts.size());
@ -234,8 +232,7 @@ void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
time_t now = time(nullptr);
char* dt = ctime(&now); // NOLINT
std::cout << "Date & Time : " << dt
<< "Prefill Token Batch Size : " << gcpp::kPrefillBatchSize
<< "\n"
<< "Prefill Token Batch Size : " << kPrefillBatchSize << "\n"
<< "Hardware concurrency : "
<< std::thread::hardware_concurrency() << "\n"
<< "Instruction set : "
@ -247,14 +244,13 @@ void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
}
std::cout << "Compiled config : " << CompiledConfig() << "\n"
<< "Weight Type : "
<< gcpp::StringFromType(loader.WeightType()) << "\n"
<< StringFromType(loader.Info().weight) << "\n"
<< "EmbedderInput Type : "
<< gcpp::TypeName(gcpp::EmbedderInputT()) << "\n";
<< TypeName(EmbedderInputT()) << "\n";
}
}
void ShowHelp(gcpp::LoaderArgs& loader, gcpp::InferenceArgs& inference,
gcpp::AppArgs& app) {
void ShowHelp(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
std::cerr
<< "\n\ngemma.cpp : a lightweight, standalone C++ inference engine\n"
"==========================================================\n\n"

View File

@ -24,7 +24,6 @@
#include <utility>
#include <vector>
#include "gemma/common.h"
#include "gemma/gemma.h"
#include "util/app.h"
#include "hwy/base.h"
@ -83,12 +82,10 @@ class GemmaEnv {
// Returns nullptr if the model failed to load.
Gemma* GetModel() const { return model_.get(); }
Model ModelType() const { return loader_.ModelType(); }
ModelTraining ModelTrainingType() const {
return loader_.ModelTrainingType();
}
int Verbosity() const { return app_.verbosity; }
RuntimeConfig& MutableConfig() { return runtime_config_; }
const ModelInfo& Info() const { return loader_.Info(); }
InferenceArgs& MutableInferenceArgs() { return inference_args_; }
std::mt19937& MutableGen() { return gen_; }
KVCache& MutableKVCache() { return *kv_caches_[0]; }
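Call sites that previously queried GemmaEnv for the model and training type separately now go through the single Info() accessor above; a short sketch of the resulting pattern (assumes a constructed GemmaEnv named env):

KVCache kv_cache = KVCache::Create(env.Info().model);  // was KVCache::Create(env.ModelType())
const std::string golden =
    ModelString(env.Info().model, env.Info().training) + ".txt";
// was ModelString(env.ModelType(), env.ModelTrainingType())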

View File

@ -39,7 +39,7 @@
#include <array>
#include <memory>
#include <string>
#include <utility>
#include <utility> // std::move
#include <vector>
#include "compression/io.h" // Path
@ -1085,32 +1085,26 @@ struct AllocateState {
} // namespace
Gemma::Gemma(const Path& tokenizer_path, const Path& weights, Model model_type,
Type weight_type, hwy::ThreadPool& pool)
: pool_(pool),
tokenizer_(tokenizer_path),
model_type_(model_type),
weight_type_(weight_type) {
weights_u8_ = LoadCompressedWeights(weights, model_type, weight_type, pool);
CallForModelAndWeight<AllocateState>(model_type, weight_type, prefill_u8_,
Gemma::Gemma(const Path& tokenizer_path, const Path& weights,
const ModelInfo& info, hwy::ThreadPool& pool)
: pool_(pool), tokenizer_(tokenizer_path), info_(info) {
weights_u8_ = LoadCompressedWeights(weights, info.model, info.weight, pool);
CallForModelAndWeight<AllocateState>(info.model, info.weight, prefill_u8_,
decode_u8_);
}
Gemma::Gemma(GemmaTokenizer&& tokenizer, Model model_type, Type weight_type,
Gemma::Gemma(GemmaTokenizer&& tokenizer, const ModelInfo& info,
hwy::ThreadPool& pool)
: pool_(pool),
tokenizer_(std::move(tokenizer)),
model_type_(model_type),
weight_type_(weight_type) {
HWY_ASSERT(weight_type == Type::kF32);
weights_u8_ = CallForModel<float, AllocateCompressedWeights>(
model_type, pool);
CallForModelAndWeight<AllocateState>(model_type, weight_type, prefill_u8_,
: pool_(pool), tokenizer_(std::move(tokenizer)), info_(info) {
HWY_ASSERT(info.weight == Type::kF32);
weights_u8_ =
CallForModel<float, AllocateCompressedWeights>(info.model, pool);
CallForModelAndWeight<AllocateState>(info.model, info.weight, prefill_u8_,
decode_u8_);
}
Gemma::~Gemma() {
CallForModelAndWeight<DeleteCompressedWeights>(model_type_, weight_type_,
CallForModelAndWeight<DeleteCompressedWeights>(info_.model, info_.weight,
weights_u8_);
}
@ -1120,7 +1114,7 @@ void Gemma::Generate(const RuntimeConfig& runtime_config,
pool_.SetWaitMode(hwy::PoolWaitMode::kSpin);
GEMMA_EXPORT_AND_DISPATCH(
model_type_, weight_type_, GenerateOneQueryT,
info_.model, info_.weight, GenerateOneQueryT,
(weights_u8_, prefill_u8_, decode_u8_, runtime_config, prompt, start_pos,
kv_cache, pool_, timing_info));
@ -1135,24 +1129,29 @@ void Gemma::GenerateBatch(const RuntimeConfig& runtime_config,
pool_.SetWaitMode(hwy::PoolWaitMode::kSpin);
GEMMA_EXPORT_AND_DISPATCH(
model_type_, weight_type_, GenerateBatchT,
info_.model, info_.weight, GenerateBatchT,
(weights_u8_, prefill_u8_, decode_u8_, runtime_config, prompts, start_pos,
kv_caches, pool_, timing_info));
pool_.SetWaitMode(hwy::PoolWaitMode::kBlock);
}
std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer,
const ModelTraining training, size_t pos,
std::string& prompt) {
void Wrap(const ModelInfo& info, size_t pos, std::string& prompt) {
// Instruction-tuned models are trained to expect control tokens.
if (training == ModelTraining::GEMMA_IT) {
if (info.training == ModelTraining::GEMMA_IT) {
// Prepend "<end_of_turn>" if this is a multi-turn dialogue continuation.
const std::string start = (pos == 0)
? "<start_of_turn>user\n"
: "<end_of_turn>\n<start_of_turn>user\n";
prompt = start + prompt + "<end_of_turn>\n<start_of_turn>model\n";
}
}
std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer,
const ModelInfo& info, size_t pos,
std::string& prompt) {
Wrap(info, pos, prompt);
std::vector<int> tokens;
HWY_ASSERT(tokenizer.Encode(prompt, &tokens));
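The turn markup that WrapAndTokenize used to apply inline is factored into the Wrap() helper above. A sketch of its effect on an instruction-tuned prompt (string literals taken from the code above; the example prompt is made up, and whether Wrap is exposed in a public header is not visible in this diff):

ModelInfo info{.training = ModelTraining::GEMMA_IT};  // Wrap() only reads .training
std::string prompt = "Tell me a joke.";
Wrap(info, /*pos=*/0, prompt);
// prompt is now:
//   "<start_of_turn>user\nTell me a joke.<end_of_turn>\n<start_of_turn>model\n"
// For pos > 0 (a multi-turn continuation), "<end_of_turn>\n" is prepended first;
// non-IT (pre-trained) models leave the prompt unchanged.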

View File

@ -93,6 +93,13 @@ using SampleFunc = std::function<int(const float*, size_t)>;
using LayersOutputFunc =
std::function<void(int, const std::string&, const float*, size_t)>;
// TODO(janwas): move into common.h, merge with parser/ToString.
struct ModelInfo {
Model model;
ModelTraining training;
Type weight;
};
struct RuntimeConfig {
size_t max_tokens;
size_t max_generated_tokens;
@ -115,15 +122,15 @@ struct TimingInfo {
class Gemma {
public:
Gemma(const Path& tokenizer_path, const Path& weights, Model model_type,
Type weight_type, hwy::ThreadPool& pool);
Gemma(const Path& tokenizer_path, const Path& weights, const ModelInfo& info,
hwy::ThreadPool& pool);
// Allocates weights, caller is responsible for filling them.
Gemma(GemmaTokenizer&& tokenizer, Model model_type, Type weight_type,
Gemma(GemmaTokenizer&& tokenizer, const ModelInfo& info,
hwy::ThreadPool& pool);
~Gemma();
Model ModelType() const { return model_type_; }
const ModelInfo& Info() const { return info_; }
const GemmaTokenizer& Tokenizer() const { return tokenizer_; }
const ByteStorageT& Weights() const { return weights_u8_; }
const ByteStorageT& Prefill() const { return prefill_u8_; }
@ -147,15 +154,14 @@ class Gemma {
ByteStorageT weights_u8_;
ByteStorageT prefill_u8_;
ByteStorageT decode_u8_;
Model model_type_;
Type weight_type_;
ModelInfo info_;
};
// Adds BOS token and possibly 'turn' annotations, which depend on `training`
// and `pos`, the number of tokens decoded so far; returns the corresponding
// tokens. Asserts that tokenization is successful.
std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer,
ModelTraining training, size_t pos,
const ModelInfo& info, size_t pos,
std::string& prompt);
// DEPRECATED, call Gemma::Generate directly.
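Putting the new gemma.h interface together, a minimal caller sketch (the paths, prompt text, and the choice of GEMMA_TINY are placeholders; the exact include set may differ):

#include <string>
#include <vector>

#include "compression/io.h"  // Path
#include "gemma/gemma.h"     // Gemma, ModelInfo, WrapAndTokenize
#include "hwy/contrib/thread_pool/thread_pool.h"

int main() {
  hwy::ThreadPool pool(0);  // as in the test above

  const gcpp::ModelInfo info = {
      .model = gcpp::Model::GEMMA_TINY,
      .training = gcpp::ModelTraining::GEMMA_IT,
      .weight = gcpp::Type::kF32,
  };

  gcpp::Path tokenizer_path;
  tokenizer_path.path = "tokenizer.spm";  // placeholder
  gcpp::Path weights_path;
  weights_path.path = "weights.sbs";      // placeholder

  // A single ModelInfo argument replaces the old (Model, Type) pair.
  gcpp::Gemma gemma(tokenizer_path, weights_path, info, pool);

  gcpp::KVCache kv_cache = gcpp::KVCache::Create(gemma.Info().model);

  std::string prompt = "Write a haiku about C++.";
  const std::vector<int> tokens =
      gcpp::WrapAndTokenize(gemma.Tokenizer(), gemma.Info(), /*pos=*/0, prompt);

  // tokens and kv_cache would then be passed to Gemma::Generate(...).
  (void)tokens;
  (void)kv_cache;
  return 0;
}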

View File

@ -53,10 +53,9 @@ static constexpr std::string_view kAsciiArtBanner = R""(
)"";
// The main Read-Eval-Print Loop.
void ReplGemma(gcpp::Gemma& model, ModelTraining training,
gcpp::KVCache& kv_cache, hwy::ThreadPool& pool,
void ReplGemma(Gemma& model, KVCache& kv_cache, hwy::ThreadPool& pool,
const InferenceArgs& args, int verbosity,
const gcpp::AcceptFunc& accept_token, std::string& eot_line) {
const AcceptFunc& accept_token, std::string& eot_line) {
PROFILER_ZONE("Gen.misc");
size_t abs_pos = 0; // absolute token index over all turns
int current_pos = 0; // token index within the current turn
@ -73,7 +72,7 @@ void ReplGemma(gcpp::Gemma& model, ModelTraining training,
// <= since position is incremented before
if (current_pos <= prompt_size) {
std::cerr << "." << std::flush;
} else if (token == gcpp::EOS_ID) {
} else if (token == EOS_ID) {
if (!args.multiturn) {
abs_pos = 0;
if (args.deterministic) {
@ -131,8 +130,8 @@ void ReplGemma(gcpp::Gemma& model, ModelTraining training,
continue;
}
const std::vector<int> prompt =
WrapAndTokenize(model.Tokenizer(), training, abs_pos, prompt_string);
const std::vector<int> prompt = WrapAndTokenize(
model.Tokenizer(), model.Info(), abs_pos, prompt_string);
prompt_size = prompt.size();
std::cerr << "\n"
<< "[ Reading prompt ] " << std::flush;
@ -143,7 +142,7 @@ void ReplGemma(gcpp::Gemma& model, ModelTraining training,
}
TimingInfo timing_info;
gcpp::RuntimeConfig runtime_config = {
RuntimeConfig runtime_config = {
.max_tokens = args.max_tokens,
.max_generated_tokens = args.max_generated_tokens,
.temperature = args.temperature,
@ -179,8 +178,8 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
PinWorkersToCores(pool);
}
gcpp::Gemma model = gcpp::CreateGemma(loader, pool);
KVCache kv_cache = KVCache::Create(loader.ModelType());
Gemma model = CreateGemma(loader, pool);
KVCache kv_cache = KVCache::Create(model.Info().model);
if (app.verbosity >= 1) {
const std::string instructions =
@ -208,8 +207,8 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
std::cout << "\n" << instructions << "\n";
}
ReplGemma(model, loader.ModelTrainingType(), kv_cache, pool, inference,
app.verbosity, AcceptFunc(), app.eot_line);
ReplGemma(model, kv_cache, pool, inference, app.verbosity, AcceptFunc(),
app.eot_line);
}
} // namespace gcpp

View File

@ -18,18 +18,12 @@
#ifndef THIRD_PARTY_GEMMA_CPP_UTIL_APP_H_
#define THIRD_PARTY_GEMMA_CPP_UTIL_APP_H_
#include <memory>
#include "hwy/contrib/thread_pool/thread_pool.h"
#if HWY_OS_LINUX
#include <sched.h>
#endif // HWY_OS_LINUX
#include <stddef.h>
#include <stdio.h>
#include <algorithm> // std::clamp
#include <memory>
#include <string>
#include <thread>  // NOLINT
#include <vector>
#include "compression/io.h" // Path
@ -38,8 +32,13 @@
#include "gemma/gemma.h"
#include "util/args.h"
#include "hwy/base.h" // HWY_ASSERT
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/contrib/thread_pool/topology.h"
#if HWY_OS_LINUX
#include <sched.h>
#endif // HWY_OS_LINUX
namespace gcpp {
static inline const char* CompiledConfig() {
@ -168,11 +167,11 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
// Returns error string or nullptr if OK.
const char* Validate() {
if (const char* err = ParseModelTypeAndTraining(model_type_str, model_type_,
model_training_)) {
if (const char* err = ParseModelTypeAndTraining(model_type_str, info_.model,
info_.training)) {
return err;
}
if (const char* err = ParseType(weight_type_str, weight_type_)) {
if (const char* err = ParseType(weight_type_str, info_.weight)) {
return err;
}
if (tokenizer.path.empty()) {
@ -226,26 +225,21 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
}
// Uninitialized before Validate, must call after that.
gcpp::Model ModelType() const { return model_type_; }
gcpp::ModelTraining ModelTrainingType() const { return model_training_; }
gcpp::Type WeightType() const { return weight_type_; }
const ModelInfo& Info() const { return info_; }
private:
Model model_type_;
ModelTraining model_training_;
Type weight_type_;
ModelInfo info_;
};
static inline Gemma CreateGemma(const LoaderArgs& loader,
hwy::ThreadPool& pool) {
return Gemma(loader.tokenizer, loader.weights, loader.ModelType(),
loader.WeightType(), pool);
return Gemma(loader.tokenizer, loader.weights, loader.Info(), pool);
}
static inline std::unique_ptr<Gemma> AllocateGemma(const LoaderArgs& loader,
hwy::ThreadPool& pool) {
return std::make_unique<Gemma>(loader.tokenizer, loader.weights,
loader.ModelType(), loader.WeightType(), pool);
loader.Info(), pool);
}
struct InferenceArgs : public ArgsBase<InferenceArgs> {