diff --git a/backprop/optimize_test.cc b/backprop/optimize_test.cc
index 5a8e343..e0ccd90 100644
--- a/backprop/optimize_test.cc
+++ b/backprop/optimize_test.cc
@@ -37,21 +37,24 @@ TEST(OptimizeTest, GradientDescent) {
   hwy::ThreadPool pool(0);
   std::mt19937 gen(42);
 
-  Model model_type = Model::GEMMA_TINY;
-  Type weight_type = Type::kF32;
+  const ModelInfo info = {
+      .model = Model::GEMMA_TINY,
+      .training = ModelTraining::GEMMA_IT,
+      .weight = Type::kF32,
+  };
   ByteStorageT grad = CallForModelAndWeight(
-      model_type, weight_type, pool);
+      info.model, info.weight, pool);
   ByteStorageT grad_m = CallForModelAndWeight(
-      model_type, weight_type, pool);
+      info.model, info.weight, pool);
   ByteStorageT grad_v = CallForModelAndWeight(
-      model_type, weight_type, pool);
+      info.model, info.weight, pool);
   ByteStorageT forward =
-      CallForModelAndWeight(model_type, weight_type);
+      CallForModelAndWeight(info.model, info.weight);
   ByteStorageT backward =
-      CallForModelAndWeight(model_type, weight_type);
-  KVCache kv_cache = KVCache::Create(model_type);
+      CallForModelAndWeight(info.model, info.weight);
+  KVCache kv_cache = KVCache::Create(info.model);
 
-  Gemma gemma(GemmaTokenizer(), model_type, weight_type, pool);
+  Gemma gemma(GemmaTokenizer(), info, pool);
 
   const auto generate = [&](const std::vector<int>& prompt) {
     std::vector<int> reply;
@@ -85,14 +88,14 @@ TEST(OptimizeTest, GradientDescent) {
     return ok;
   };
 
-  RandInitWeights(model_type, weight_type, gemma.Weights(), pool, gen);
-  CallForModelAndWeight(
-      model_type, weight_type, grad_m, pool);
-  CallForModelAndWeight(
-      model_type, weight_type, grad_v, pool);
+  RandInitWeights(info.model, info.weight, gemma.Weights(), pool, gen);
+  CallForModelAndWeight(info.model, info.weight,
+                        grad_m, pool);
+  CallForModelAndWeight(info.model, info.weight,
+                        grad_v, pool);
 
   printf("Initial weights:\n");
-  LogWeightStats(model_type, weight_type, gemma.Weights());
+  LogWeightStats(info.model, info.weight, gemma.Weights());
 
   constexpr size_t kBatchSize = 8;
   const float alpha = 0.001f;
@@ -107,27 +110,27 @@ TEST(OptimizeTest, GradientDescent) {
   size_t num_ok;
   for (; steps < 1000000; ++steps) {
     std::mt19937 sgen(42);
-    CallForModelAndWeight(
-        model_type, weight_type, grad, pool);
+    CallForModelAndWeight(info.model, info.weight,
+                          grad, pool);
     float total_loss = 0.0f;
     num_ok = 0;
     for (size_t i = 0; i < kBatchSize; ++i) {
       Prompt prompt = training_task.Sample(sgen);
-      total_loss += CrossEntropyLossForwardPass(model_type, prompt,
+      total_loss += CrossEntropyLossForwardPass(info.model, prompt,
                                                 gemma.Weights(), forward, pool);
-      CrossEntropyLossBackwardPass(model_type, prompt, gemma.Weights(), forward,
+      CrossEntropyLossBackwardPass(info.model, prompt, gemma.Weights(), forward,
                                    grad, backward, pool);
       num_ok += verify(prompt) ? 1 : 0;
     }
     total_loss /= kBatchSize;
-    AdamUpdate(model_type, weight_type, grad, alpha, beta1, beta2, epsilon,
+    AdamUpdate(info.model, info.weight, grad, alpha, beta1, beta2, epsilon,
                steps + 1, gemma.Weights(), grad_m, grad_v, pool);
     printf("step: %zu total_loss: %.15f num_ok: %zu/%zu\n",
            steps, total_loss, num_ok, kBatchSize);
     if (steps % 100 == 0) {
       printf("Batch gradient:\n");
-      LogWeightStats(model_type, weight_type, grad);
+      LogWeightStats(info.model, info.weight, grad);
     }
     if (total_loss < 0.5f) {
      break;
@@ -136,7 +139,7 @@ TEST(OptimizeTest, GradientDescent) {
   }
   printf("Num steps: %zu\n", steps);
   printf("Final weights:\n");
-  LogWeightStats(model_type, weight_type, gemma.Weights());
+  LogWeightStats(info.model, info.weight, gemma.Weights());
   EXPECT_LT(steps, 300);
   EXPECT_EQ(num_ok, kBatchSize);
 }
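
Aside on the Adam state used above: grad_m and grad_v hold the first- and
second-moment estimates that AdamUpdate consumes, and the test passes
steps + 1 so bias correction starts at t = 1. A minimal, hypothetical scalar
sketch of the per-parameter update (AdamStepSketch is not part of the patch;
the real AdamUpdate dispatches over model and weight type):

#include <cmath>
#include <cstddef>

// Hypothetical scalar Adam step over a flat parameter array; alpha, beta1,
// beta2 and epsilon correspond to the hyperparameters used in the test.
void AdamStepSketch(const float* grad, float* w, float* m, float* v, size_t n,
                    size_t step, float alpha, float beta1, float beta2,
                    float epsilon) {
  const float bias1 = 1.0f - std::pow(beta1, static_cast<float>(step));
  const float bias2 = 1.0f - std::pow(beta2, static_cast<float>(step));
  for (size_t i = 0; i < n; ++i) {
    m[i] = beta1 * m[i] + (1.0f - beta1) * grad[i];            // first moment
    v[i] = beta2 * v[i] + (1.0f - beta2) * grad[i] * grad[i];  // second moment
    w[i] -= alpha * (m[i] / bias1) / (std::sqrt(v[i] / bias2) + epsilon);
  }
}
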
diff --git a/evals/benchmark.cc b/evals/benchmark.cc
index 323cc96..d9d354d 100644
--- a/evals/benchmark.cc
+++ b/evals/benchmark.cc
@@ -128,7 +128,7 @@ int BenchmarkCrossEntropy(GemmaEnv& env, const Path& text,
     size_t num_tokens = std::min(prompt.size() - pos, batch_tokens);
     std::vector<int> prompt_slice(prompt.begin() + pos,
                                   prompt.begin() + pos + num_tokens);
-    KVCache kv_cache = KVCache::Create(env.ModelType());
+    KVCache kv_cache = KVCache::Create(env.Info().model);
     float entropy = ComputeCrossEntropy(
         *env.GetModel(), num_tokens, prompt_slice, kv_cache, env.Verbosity());
     total_entropy += entropy;
@@ -185,7 +185,7 @@ int main(int argc, char** argv) {
   if (!benchmark_args.goldens.Empty()) {
     const std::string golden_path =
         benchmark_args.goldens.path + "/" +
-        gcpp::ModelString(env.ModelType(), env.ModelTrainingType()) + ".txt";
+        gcpp::ModelString(env.Info().model, env.Info().training) + ".txt";
     return BenchmarkGoldens(env, golden_path);
   } else if (!benchmark_args.summarize_text.Empty()) {
     return BenchmarkSummary(env, benchmark_args.summarize_text);
diff --git a/evals/cross_entropy.cc b/evals/cross_entropy.cc
index 627dd84..566ab85 100644
--- a/evals/cross_entropy.cc
+++ b/evals/cross_entropy.cc
@@ -71,7 +71,7 @@ float ComputeCrossEntropy(Gemma& gemma, size_t max_tokens,
   // TWeight is unused, but we have to pass it to Config*.
   const int vocab_size =
-      CallForModel(gemma.ModelType());
+      CallForModel(gemma.Info().model);
   float cross_entropy = std::log(vocab_size);  // first token
   size_t pos = 1;
   const SampleFunc sample_token = [&](const float* probs,
diff --git a/evals/run_mmlu.cc b/evals/run_mmlu.cc
index 814ecd6..f6f4640 100644
--- a/evals/run_mmlu.cc
+++ b/evals/run_mmlu.cc
@@ -104,7 +104,7 @@ void Run(GemmaEnv& env, JsonArgs& json) {
         "Do not include any justifications or explanations. Reply only with a "
         "letter.";
     const std::vector<int> prompt =
-        WrapAndTokenize(env.GetModel()->Tokenizer(), env.ModelTrainingType(),
+        WrapAndTokenize(env.GetModel()->Tokenizer(), env.Info(),
                         /*pos=*/0, prompt_string);
     const size_t prompt_size = prompt.size();
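
Context for the cross_entropy.cc hunk: the sum starts at log(vocab_size),
i.e. the first token is charged a uniform prior. A hypothetical sketch of the
overall accumulation, assuming each later token is charged -log p(token) under
the model's predicted distribution (the real code captures those probabilities
via the SampleFunc lambda shown above):

#include <cmath>
#include <cstddef>
#include <vector>

// probs[i] is the model's distribution over the vocabulary when predicting
// tokens[i + 1]; vocab_size is the value CallForModel(...) returns above.
float CrossEntropySketch(const std::vector<std::vector<float>>& probs,
                         const std::vector<int>& tokens, int vocab_size) {
  float total = std::log(static_cast<float>(vocab_size));  // first token
  for (size_t pos = 1; pos < tokens.size(); ++pos) {
    total += -std::log(probs[pos - 1][tokens[pos]]);
  }
  return total;
}
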
diff --git a/gemma/benchmark_helper.cc b/gemma/benchmark_helper.cc
index d5c3fea..a5d994c 100644
--- a/gemma/benchmark_helper.cc
+++ b/gemma/benchmark_helper.cc
@@ -64,7 +64,7 @@ GemmaEnv::GemmaEnv(const LoaderArgs& loader, const InferenceArgs& inference,
       pool_(app_.num_threads) {
   // For many-core, pinning workers to cores helps.
   if (app_.num_threads > 10) {
-    gcpp::PinWorkersToCores(pool_);
+    PinWorkersToCores(pool_);
   }
 
   AbortIfInvalidArgs(inference_args_);
@@ -78,7 +78,7 @@ GemmaEnv::GemmaEnv(const LoaderArgs& loader, const InferenceArgs& inference,
 
   kv_caches_.reserve(16);
   for (int i = 0; i < 16; ++i) {
-    kv_caches_.push_back(new KVCache(KVCache::Create(loader_.ModelType())));
+    kv_caches_.push_back(new KVCache(KVCache::Create(model_->Info().model)));
   }
 }
@@ -181,9 +181,8 @@ std::vector<std::pair<std::string, size_t>> GemmaEnv::BatchQueryModel2(
 }
 
 std::pair<std::string, size_t> GemmaEnv::QueryModel(std::string& input) {
-  const std::vector<int> prompt =
-      WrapAndTokenize(model_->Tokenizer(), loader_.ModelTrainingType(),
-                      /*pos=*/0, input);
+  const std::vector<int> prompt = WrapAndTokenize(model_->Tokenizer(), Info(),
+                                                  /*pos=*/0, input);
   return QueryModel(prompt);
 }
 
 std::vector<std::pair<std::string, size_t>> GemmaEnv::BatchQueryModel(
@@ -193,9 +192,8 @@ std::vector<std::pair<std::string, size_t>> GemmaEnv::BatchQueryModel(
   for (auto& input : inputs) {
     std::string mutable_prompt = input;
     prompts.push_back(std::make_unique<std::vector<int>>(
-        WrapAndTokenize(model_->Tokenizer(),
-                        loader_.ModelTrainingType(),
-                        /*pos=*/0, mutable_prompt)));
+        WrapAndTokenize(model_->Tokenizer(), model_->Info(),
+                        /*pos=*/0, mutable_prompt)));
   }
   std::vector<std::vector<int>*> prompt_vector;
   prompt_vector.reserve(prompts.size());
@@ -234,8 +232,7 @@ void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
     time_t now = time(nullptr);
     char* dt = ctime(&now);  // NOLINT
     std::cout << "Date & Time : " << dt
-              << "Prefill Token Batch Size : " << gcpp::kPrefillBatchSize
-              << "\n"
+              << "Prefill Token Batch Size : " << kPrefillBatchSize << "\n"
              << "Hardware concurrency : "
              << std::thread::hardware_concurrency() << "\n"
              << "Instruction set : "
@@ -247,14 +244,13 @@ void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
     }
     std::cout << "Compiled config : " << CompiledConfig() << "\n"
              << "Weight Type : "
-              << gcpp::StringFromType(loader.WeightType()) << "\n"
+              << StringFromType(loader.Info().weight) << "\n"
              << "EmbedderInput Type : "
-              << gcpp::TypeName(gcpp::EmbedderInputT()) << "\n";
+              << TypeName(EmbedderInputT()) << "\n";
   }
 }
 
-void ShowHelp(gcpp::LoaderArgs& loader, gcpp::InferenceArgs& inference,
-              gcpp::AppArgs& app) {
+void ShowHelp(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
   std::cerr
       << "\n\ngemma.cpp : a lightweight, standalone C++ inference engine\n"
         "==========================================================\n\n"
diff --git a/gemma/benchmark_helper.h b/gemma/benchmark_helper.h
index 3909606..25f40b4 100644
--- a/gemma/benchmark_helper.h
+++ b/gemma/benchmark_helper.h
@@ -24,7 +24,6 @@
 #include
 #include
 
-#include "gemma/common.h"
 #include "gemma/gemma.h"
 #include "util/app.h"
 #include "hwy/base.h"
@@ -83,12 +82,10 @@ class GemmaEnv {
   // Returns nullptr if the model failed to load.
   Gemma* GetModel() const { return model_.get(); }
 
-  Model ModelType() const { return loader_.ModelType(); }
-  ModelTraining ModelTrainingType() const {
-    return loader_.ModelTrainingType();
-  }
+
   int Verbosity() const { return app_.verbosity; }
   RuntimeConfig& MutableConfig() { return runtime_config_; }
+  const ModelInfo& Info() const { return loader_.Info(); }
   InferenceArgs& MutableInferenceArgs() { return inference_args_; }
   std::mt19937& MutableGen() { return gen_; }
   KVCache& MutableKVCache() { return *kv_caches_[0]; }
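
Call-site note for GemmaEnv: the ModelType()/ModelTrainingType() accessors are
gone, and both fields are now read through Info(). A hypothetical helper
illustrating the migration; the ModelString declaration is assumed to live in
gemma/common.h, as used by benchmark.cc:

#include <string>

#include "gemma/benchmark_helper.h"
#include "gemma/common.h"  // ModelString (assumed header)

// Before: gcpp::ModelString(env.ModelType(), env.ModelTrainingType())
std::string DescribeModel(const gcpp::GemmaEnv& env) {
  const gcpp::ModelInfo& info = env.Info();
  return gcpp::ModelString(info.model, info.training);
}
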
diff --git a/gemma/gemma.cc b/gemma/gemma.cc
index f6f7a58..48b4ddc 100644
--- a/gemma/gemma.cc
+++ b/gemma/gemma.cc
@@ -39,7 +39,7 @@
 #include
 #include
 #include
-#include
+#include <utility>  // std::move
 #include
 
 #include "compression/io.h"  // Path
@@ -1085,32 +1085,26 @@ struct AllocateState {
 
 }  // namespace
 
-Gemma::Gemma(const Path& tokenizer_path, const Path& weights, Model model_type,
-             Type weight_type, hwy::ThreadPool& pool)
-    : pool_(pool),
-      tokenizer_(tokenizer_path),
-      model_type_(model_type),
-      weight_type_(weight_type) {
-  weights_u8_ = LoadCompressedWeights(weights, model_type, weight_type, pool);
-  CallForModelAndWeight(model_type, weight_type, prefill_u8_,
+Gemma::Gemma(const Path& tokenizer_path, const Path& weights,
+             const ModelInfo& info, hwy::ThreadPool& pool)
+    : pool_(pool), tokenizer_(tokenizer_path), info_(info) {
+  weights_u8_ = LoadCompressedWeights(weights, info.model, info.weight, pool);
+  CallForModelAndWeight(info.model, info.weight, prefill_u8_,
                         decode_u8_);
 }
 
-Gemma::Gemma(GemmaTokenizer&& tokenizer, Model model_type, Type weight_type,
+Gemma::Gemma(GemmaTokenizer&& tokenizer, const ModelInfo& info,
              hwy::ThreadPool& pool)
-    : pool_(pool),
-      tokenizer_(std::move(tokenizer)),
-      model_type_(model_type),
-      weight_type_(weight_type) {
-  HWY_ASSERT(weight_type == Type::kF32);
-  weights_u8_ = CallForModel(
-      model_type, pool);
-  CallForModelAndWeight(model_type, weight_type, prefill_u8_,
+    : pool_(pool), tokenizer_(std::move(tokenizer)), info_(info) {
+  HWY_ASSERT(info.weight == Type::kF32);
+  weights_u8_ =
+      CallForModel(info.model, pool);
+  CallForModelAndWeight(info.model, info.weight, prefill_u8_,
                         decode_u8_);
 }
 
 Gemma::~Gemma() {
-  CallForModelAndWeight(model_type_, weight_type_,
+  CallForModelAndWeight(info_.model, info_.weight,
                         weights_u8_);
 }
@@ -1120,7 +1114,7 @@ void Gemma::Generate(const RuntimeConfig& runtime_config,
   pool_.SetWaitMode(hwy::PoolWaitMode::kSpin);
 
   GEMMA_EXPORT_AND_DISPATCH(
-      model_type_, weight_type_, GenerateOneQueryT,
+      info_.model, info_.weight, GenerateOneQueryT,
      (weights_u8_, prefill_u8_, decode_u8_, runtime_config, prompt, start_pos,
       kv_cache, pool_, timing_info));
 
@@ -1135,24 +1129,29 @@ void Gemma::GenerateBatch(const RuntimeConfig& runtime_config,
   pool_.SetWaitMode(hwy::PoolWaitMode::kSpin);
 
   GEMMA_EXPORT_AND_DISPATCH(
-      model_type_, weight_type_, GenerateBatchT,
+      info_.model, info_.weight, GenerateBatchT,
      (weights_u8_, prefill_u8_, decode_u8_, runtime_config, prompts, start_pos,
       kv_caches, pool_, timing_info));
 
   pool_.SetWaitMode(hwy::PoolWaitMode::kBlock);
 }
 
-std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer,
-                                 const ModelTraining training, size_t pos,
-                                 std::string& prompt) {
+void Wrap(const ModelInfo& info, size_t pos, std::string& prompt) {
+  // Instruction-tuned models are trained to expect control tokens.
-  if (training == ModelTraining::GEMMA_IT) {
+  if (info.training == ModelTraining::GEMMA_IT) {
     // Prepend "<end_of_turn>" if this is a multi-turn dialogue continuation.
     const std::string start = (pos == 0)
                                   ? "<start_of_turn>user\n"
                                   : "<end_of_turn>\n<start_of_turn>user\n";
     prompt = start + prompt + "<end_of_turn>\n<start_of_turn>model\n";
   }
+}
+
+std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer,
+                                 const ModelInfo& info, size_t pos,
+                                 std::string& prompt) {
+  Wrap(info, pos, prompt);
 
   std::vector<int> tokens;
   HWY_ASSERT(tokenizer.Encode(prompt, &tokens));
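
Illustration of the new Wrap() helper for instruction-tuned checkpoints.
WrapForInstructionTuned is a standalone sketch that mirrors the logic added
above rather than calling into the library, so the expected turn framing is
visible at a glance:

#include <cstddef>
#include <string>

// Mirrors Wrap() when info.training == ModelTraining::GEMMA_IT: turn-control
// tokens are added, and a continuation (pos > 0) first closes the prior turn.
std::string WrapForInstructionTuned(size_t pos, const std::string& prompt) {
  const std::string start = (pos == 0)
                                ? "<start_of_turn>user\n"
                                : "<end_of_turn>\n<start_of_turn>user\n";
  return start + prompt + "<end_of_turn>\n<start_of_turn>model\n";
}
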
"user\n" : "\nuser\n"; prompt = start + prompt + "\nmodel\n"; } +} + +std::vector WrapAndTokenize(const GemmaTokenizer& tokenizer, + const ModelInfo& info, size_t pos, + std::string& prompt) { + Wrap(info, pos, prompt); std::vector tokens; HWY_ASSERT(tokenizer.Encode(prompt, &tokens)); diff --git a/gemma/gemma.h b/gemma/gemma.h index 6509d17..4b2afc2 100644 --- a/gemma/gemma.h +++ b/gemma/gemma.h @@ -93,6 +93,13 @@ using SampleFunc = std::function; using LayersOutputFunc = std::function; +// TODO(janwas): move into common.h, merge with parser/ToString. +struct ModelInfo { + Model model; + ModelTraining training; + Type weight; +}; + struct RuntimeConfig { size_t max_tokens; size_t max_generated_tokens; @@ -115,15 +122,15 @@ struct TimingInfo { class Gemma { public: - Gemma(const Path& tokenizer_path, const Path& weights, Model model_type, - Type weight_type, hwy::ThreadPool& pool); + Gemma(const Path& tokenizer_path, const Path& weights, const ModelInfo& info, + hwy::ThreadPool& pool); // Allocates weights, caller is responsible for filling them. - Gemma(GemmaTokenizer&& tokenizer, Model model_type, Type weight_type, + Gemma(GemmaTokenizer&& tokenizer, const ModelInfo& info, hwy::ThreadPool& pool); ~Gemma(); - Model ModelType() const { return model_type_; } + const ModelInfo& Info() const { return info_; } const GemmaTokenizer& Tokenizer() const { return tokenizer_; } const ByteStorageT& Weights() const { return weights_u8_; } const ByteStorageT& Prefill() const { return prefill_u8_; } @@ -147,15 +154,14 @@ class Gemma { ByteStorageT weights_u8_; ByteStorageT prefill_u8_; ByteStorageT decode_u8_; - Model model_type_; - Type weight_type_; + ModelInfo info_; }; // Adds BOS token and possibly 'turn' annotations, which depend on `training` // and `pos`, the number of tokens decoded so far; returns the corresponding // tokens. Asserts that tokenization is successful. std::vector WrapAndTokenize(const GemmaTokenizer& tokenizer, - ModelTraining training, size_t pos, + const ModelInfo& info, size_t pos, std::string& prompt); // DEPRECATED, call Gemma::Generate directly. diff --git a/gemma/run.cc b/gemma/run.cc index 78e6734..ccad159 100644 --- a/gemma/run.cc +++ b/gemma/run.cc @@ -53,10 +53,9 @@ static constexpr std::string_view kAsciiArtBanner = R""( )""; // The main Read-Eval-Print Loop. -void ReplGemma(gcpp::Gemma& model, ModelTraining training, - gcpp::KVCache& kv_cache, hwy::ThreadPool& pool, +void ReplGemma(Gemma& model, KVCache& kv_cache, hwy::ThreadPool& pool, const InferenceArgs& args, int verbosity, - const gcpp::AcceptFunc& accept_token, std::string& eot_line) { + const AcceptFunc& accept_token, std::string& eot_line) { PROFILER_ZONE("Gen.misc"); size_t abs_pos = 0; // absolute token index over all turns int current_pos = 0; // token index within the current turn @@ -73,7 +72,7 @@ void ReplGemma(gcpp::Gemma& model, ModelTraining training, // <= since position is incremented before if (current_pos <= prompt_size) { std::cerr << "." 
diff --git a/gemma/run.cc b/gemma/run.cc
index 78e6734..ccad159 100644
--- a/gemma/run.cc
+++ b/gemma/run.cc
@@ -53,10 +53,9 @@ static constexpr std::string_view kAsciiArtBanner = R""(
 )"";
 
 // The main Read-Eval-Print Loop.
-void ReplGemma(gcpp::Gemma& model, ModelTraining training,
-               gcpp::KVCache& kv_cache, hwy::ThreadPool& pool,
+void ReplGemma(Gemma& model, KVCache& kv_cache, hwy::ThreadPool& pool,
                const InferenceArgs& args, int verbosity,
-               const gcpp::AcceptFunc& accept_token, std::string& eot_line) {
+               const AcceptFunc& accept_token, std::string& eot_line) {
   PROFILER_ZONE("Gen.misc");
   size_t abs_pos = 0;   // absolute token index over all turns
   int current_pos = 0;  // token index within the current turn
@@ -73,7 +72,7 @@ void ReplGemma(gcpp::Gemma& model, ModelTraining training,
       // <= since position is incremented before
       if (current_pos <= prompt_size) {
         std::cerr << "." << std::flush;
-      } else if (token == gcpp::EOS_ID) {
+      } else if (token == EOS_ID) {
         if (!args.multiturn) {
           abs_pos = 0;
           if (args.deterministic) {
@@ -131,8 +130,8 @@ void ReplGemma(gcpp::Gemma& model, ModelTraining training,
       continue;
     }
 
-    const std::vector<int> prompt =
-        WrapAndTokenize(model.Tokenizer(), training, abs_pos, prompt_string);
+    const std::vector<int> prompt = WrapAndTokenize(
+        model.Tokenizer(), model.Info(), abs_pos, prompt_string);
     prompt_size = prompt.size();
     std::cerr << "\n"
               << "[ Reading prompt ] " << std::flush;
@@ -143,7 +142,7 @@ void ReplGemma(gcpp::Gemma& model, ModelTraining training,
     }
 
     TimingInfo timing_info;
-    gcpp::RuntimeConfig runtime_config = {
+    RuntimeConfig runtime_config = {
        .max_tokens = args.max_tokens,
        .max_generated_tokens = args.max_generated_tokens,
        .temperature = args.temperature,
@@ -179,8 +178,8 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
     PinWorkersToCores(pool);
   }
 
-  gcpp::Gemma model = gcpp::CreateGemma(loader, pool);
-  KVCache kv_cache = KVCache::Create(loader.ModelType());
+  Gemma model = CreateGemma(loader, pool);
+  KVCache kv_cache = KVCache::Create(model.Info().model);
 
   if (app.verbosity >= 1) {
     const std::string instructions =
@@ -208,8 +207,8 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
     std::cout << "\n" << instructions << "\n";
   }
 
-  ReplGemma(model, loader.ModelTrainingType(), kv_cache, pool, inference,
-            app.verbosity, AcceptFunc(), app.eot_line);
+  ReplGemma(model, kv_cache, pool, inference, app.verbosity, AcceptFunc(),
+            app.eot_line);
 }
 
 }  // namespace gcpp
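
Per-turn sketch of the updated REPL call: WrapAndTokenize now takes the whole
ModelInfo from the loaded model, and abs_pos (the absolute token index over
all turns) selects first-turn vs. continuation framing. PrepareTurn is a
hypothetical helper:

#include <cstddef>
#include <string>
#include <vector>

#include "gemma/gemma.h"

std::vector<int> PrepareTurn(const gcpp::Gemma& model, size_t abs_pos,
                             std::string prompt_string) {
  return gcpp::WrapAndTokenize(model.Tokenizer(), model.Info(), abs_pos,
                               prompt_string);
}
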
diff --git a/util/app.h b/util/app.h
index a5a7dfd..c1eff00 100644
--- a/util/app.h
+++ b/util/app.h
@@ -18,18 +18,12 @@
 #ifndef THIRD_PARTY_GEMMA_CPP_UTIL_APP_H_
 #define THIRD_PARTY_GEMMA_CPP_UTIL_APP_H_
 
-#include
-
-#include "hwy/contrib/thread_pool/thread_pool.h"
-#if HWY_OS_LINUX
-#include
-#endif  // HWY_OS_LINUX
 #include
 #include
 
 #include <algorithm>  // std::clamp
+#include
 #include
-#include  // NOLINT
 #include
 
 #include "compression/io.h"  // Path
@@ -38,8 +32,13 @@
 #include "gemma/gemma.h"
 #include "util/args.h"
 #include "hwy/base.h"  // HWY_ASSERT
+#include "hwy/contrib/thread_pool/thread_pool.h"
 #include "hwy/contrib/thread_pool/topology.h"
 
+#if HWY_OS_LINUX
+#include
+#endif  // HWY_OS_LINUX
+
 namespace gcpp {
 
 static inline const char* CompiledConfig() {
@@ -168,11 +167,11 @@ struct LoaderArgs : public ArgsBase {
   // Returns error string or nullptr if OK.
   const char* Validate() {
-    if (const char* err = ParseModelTypeAndTraining(model_type_str, model_type_,
-                                                    model_training_)) {
+    if (const char* err = ParseModelTypeAndTraining(model_type_str, info_.model,
+                                                    info_.training)) {
      return err;
    }
-    if (const char* err = ParseType(weight_type_str, weight_type_)) {
+    if (const char* err = ParseType(weight_type_str, info_.weight)) {
      return err;
    }
    if (tokenizer.path.empty()) {
@@ -226,26 +225,21 @@ struct LoaderArgs : public ArgsBase {
   }
 
   // Uninitialized before Validate, must call after that.
-  gcpp::Model ModelType() const { return model_type_; }
-  gcpp::ModelTraining ModelTrainingType() const { return model_training_; }
-  gcpp::Type WeightType() const { return weight_type_; }
+  const ModelInfo& Info() const { return info_; }
 
  private:
-  Model model_type_;
-  ModelTraining model_training_;
-  Type weight_type_;
+  ModelInfo info_;
 };
 
 static inline Gemma CreateGemma(const LoaderArgs& loader,
                                 hwy::ThreadPool& pool) {
-  return Gemma(loader.tokenizer, loader.weights, loader.ModelType(),
-               loader.WeightType(), pool);
+  return Gemma(loader.tokenizer, loader.weights, loader.Info(), pool);
 }
 
 static inline std::unique_ptr<Gemma> AllocateGemma(const LoaderArgs& loader,
                                                    hwy::ThreadPool& pool) {
   return std::make_unique<Gemma>(loader.tokenizer, loader.weights,
-                                 loader.ModelType(), loader.WeightType(), pool);
+                                 loader.Info(), pool);
 }
 
 struct InferenceArgs : public ArgsBase {
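
End-to-end sketch of the LoaderArgs path after this change: Validate() must
run before Info(), since that is where the ModelInfo fields are parsed from
model_type_str and weight_type_str. LoadFromFlags is hypothetical and the
argument-parsing constructor is assumed to follow the existing ArgsBase
pattern:

#include <cstdio>

#include "gemma/gemma.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "util/app.h"

int LoadFromFlags(int argc, char** argv) {
  gcpp::LoaderArgs loader(argc, argv);  // assumed ArgsBase-style constructor
  if (const char* error = loader.Validate()) {
    std::fprintf(stderr, "%s\n", error);
    return 1;
  }
  hwy::ThreadPool pool(0);
  gcpp::Gemma model = gcpp::CreateGemma(loader, pool);
  gcpp::KVCache kv_cache = gcpp::KVCache::Create(loader.Info().model);
  (void)model;     // used by Generate()/ReplGemma in the real code
  (void)kv_cache;  // one cache per concurrent query
  return 0;
}
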