diff --git a/examples/hello_world/build/.gitignore b/examples/hello_world/build/.gitignore
index e69de29..d6b7ef3 100644
--- a/examples/hello_world/build/.gitignore
+++ b/examples/hello_world/build/.gitignore
@@ -0,0 +1,2 @@
+*
+!.gitignore
diff --git a/examples/hello_world/run.cc b/examples/hello_world/run.cc
index a017e22..1a4beed 100644
--- a/examples/hello_world/run.cc
+++ b/examples/hello_world/run.cc
@@ -12,6 +12,9 @@
 // copybara:import_next_line:gemma_cpp
 #include "util/args.h"  // HasHelp
 // copybara:end
+// copybara:import_next_line:gemma_cpp
+#include "configs.h"
+// copybara:end
 #include "hwy/base.h"
 #include "hwy/contrib/thread_pool/thread_pool.h"
 #include "hwy/highway.h"
@@ -35,17 +38,13 @@ int main(int argc, char** argv) {
   hwy::ThreadPool pool(app.num_threads);
   hwy::ThreadPool inner_pool(0);
   gcpp::Gemma model(loader, pool);
-
-  std::vector<int> tokens = tokenize("Hello, how are you?", model.Tokenizer());
-
   std::mt19937 gen;
   std::random_device rd;
   gen.seed(rd());
+  std::vector<int> tokens = tokenize("Hello, how are you?", model.Tokenizer());
   size_t ntokens = tokens.size();
-
   size_t pos = 0;
-
   auto stream_token = [&pos, &gen, &ntokens,
                        tokenizer = &model.Tokenizer()](int token, float) {
     ++pos;
     if (pos < ntokens) {
diff --git a/gemma.cc b/gemma.cc
index 9f1e4a0..add3721 100644
--- a/gemma.cc
+++ b/gemma.cc
@@ -19,18 +19,18 @@
 // which we pass the filename via macro 'argument'.
 #undef HWY_TARGET_INCLUDE
 #define HWY_TARGET_INCLUDE "gemma.cc"  // NOLINT
-#include "hwy/foreach_target.h"        // IWYU pragma: keep
+#include "hwy/foreach_target.h"  // IWYU pragma: keep
 // Must come after foreach_target.h to avoid redefinition errors.
 // copybara:import_next_line:gemma_cpp
 #include "compression/compress-inl.h"
 // copybara:import_next_line:gemma_cpp
 #include "ops.h"
 // copybara:import_next_line:gemma_cpp
-#include "util/args.h"  // Path
 #include "hwy/contrib/matvec/matvec-inl.h"
 #include "hwy/highway.h"
 #include "hwy/profiler.h"
 #include "hwy/timer.h"
+#include "util/args.h"  // Path
 
 // Non-SIMD includes and types. Note that HWY_ONCE is only true on the last
 // compile pass, whereas we want this defined in the first.
@@ -231,9 +231,8 @@ struct Activations {
 struct GemmaInterface {
   virtual ~GemmaInterface() = default;
 
-  virtual const sentencepiece::SentencePieceProcessor& Tokenizer() const = 0;
+  virtual const sentencepiece::SentencePieceProcessor* Tokenizer() const = 0;
 
-  // TODO: group pool/callbacks into struct
   virtual void Generate(const InferenceArgs& args,
                         const std::vector<int>& prompt, size_t start_pos,
                         hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
@@ -244,7 +243,10 @@ struct GemmaInterface {
 
 template <class Config>
 struct GemmaImpl : public GemmaInterface {
-  GemmaImpl(const LoaderArgs& args, hwy::ThreadPool& pool);
+  GemmaImpl(  // const LoaderArgs& args,
+      std::unique_ptr<sentencepiece::SentencePieceProcessor>& tokenizer,
+      hwy::AlignedFreeUniquePtr<uint8_t[]>& compressed_weights,
+      hwy::ThreadPool& pool);
   ~GemmaImpl() {
     using CWeights = CompressedWeights<Config>;
@@ -252,8 +254,8 @@ struct GemmaImpl : public GemmaInterface {
     c_weights->c_layer_ptrs.~CompressedLayerPointers();
   }
 
-  const sentencepiece::SentencePieceProcessor& Tokenizer() const {
-    return tokenizer;
+  const sentencepiece::SentencePieceProcessor* Tokenizer() const {
+    return tokenizer.get();
   }
 
   void Generate(const InferenceArgs& args, const std::vector<int>& prompt,
@@ -261,9 +263,8 @@ struct GemmaImpl : public GemmaInterface {
                 hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
                 const AcceptFunc& accept_token, std::mt19937&, int verbosity);
 
-  sentencepiece::SentencePieceProcessor tokenizer;
+  std::unique_ptr<sentencepiece::SentencePieceProcessor> tokenizer;
 
-  // CompressedWeights
   hwy::AlignedFreeUniquePtr<uint8_t[]> compressed_weights;
   hwy::AlignedUniquePtr<Activations<Config, kPrefillBatchSize>> prefill;
   hwy::AlignedUniquePtr<Activations<Config, 1>> state;
@@ -495,7 +496,8 @@ void Transformer(int token, size_t pos,
 }
 
 template <class Config>
-void GenerateImpl(GemmaImpl<Config>& gemma, const InferenceArgs& args,
+void GenerateImpl(GemmaImpl<Config>& gemma, size_t max_tokens,
+                  size_t max_generated_tokens, float temperature,
                   const std::vector<int>& prompt, size_t pos,
                   hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
                   const StreamFunc& stream_token,
@@ -549,7 +551,7 @@ void GenerateImpl(GemmaImpl<Config>& gemma, const InferenceArgs& args,
     // should be available as observable state for frontend code to handle I/O.
     double prefill_end = hwy::platform::Now();
     const double prefill_tok_sec = pos_offset / (prefill_end - prefill_start);
-    std::cout << "\n[ Prefill tokens / sec = " << prefill_tok_sec << " ]\n";
+    std::cout << "\n[ Prefill tokens / sec = " << prefill_tok_sec << " ]";
   }
 
   double gen_start = hwy::platform::Now();
@@ -558,10 +560,10 @@ void GenerateImpl(GemmaImpl<Config>& gemma, const InferenceArgs& args,
 
   if (verbosity >= 2) {
     // Provide usage warnings if max_new_tokens is out of range.
-    if (args.max_generated_tokens > args.max_tokens) {
+    if (max_generated_tokens > max_tokens) {
       std::cout << "Warning: max_new_tokens should be <= max_tokens"
                 << std::endl;
-    } else if ((prompt.size() + args.max_generated_tokens) > args.max_tokens) {
+    } else if ((prompt.size() + max_generated_tokens) > max_tokens) {
       std::cout << "Warning: Prompt size + max_new_tokens exceeds max_tokens."
                 << std::endl;
     }
@@ -570,7 +572,7 @@ void GenerateImpl(GemmaImpl<Config>& gemma, const InferenceArgs& args,
   auto pos_gen_start = pos_offset;
   token = prompt.at(pos_offset);
   size_t generate_pos = 0;
-  for (; pos < args.max_tokens && generate_pos < args.max_generated_tokens;
+  for (; pos < max_tokens && generate_pos < max_generated_tokens;
        ++pos, ++pos_offset, ++generate_pos) {
     Transformer(token, pos, c_weights, activations, kv_cache, pool, inner_pool);
     float* final_activation = activations.x.data();
@@ -583,7 +585,7 @@ void GenerateImpl(GemmaImpl<Config>& gemma, const InferenceArgs& args,
       // Barrier: must have all logits so we can subtract max.
       Softmax(activations.logits.data(), kVocabSize);
       token = SampleTopK(activations.logits.data(), kVocabSize, gen,
-                         args.temperature, accept_token);
+                         temperature, accept_token);
     }
     if (!stream_token(token, activations.logits[token])) {
       token = EOS_ID;
@@ -593,7 +595,7 @@ void GenerateImpl(GemmaImpl<Config>& gemma, const InferenceArgs& args,
     double gen_end = hwy::platform::Now();
     const double gen_tok_sec =
         (pos_offset - pos_gen_start) / (gen_end - gen_start);
-    std::cout << "\n[ Generation tokens / sec = " << gen_tok_sec << " ]\n";
+    std::cout << "[ Generation tokens / sec = " << gen_tok_sec << " ]\n";
   }
   break;
 }
@@ -605,8 +607,9 @@ void Generate2B(GemmaImpl<ConfigGemma2B>& gemma, const InferenceArgs& args,
                 hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
                 const StreamFunc& stream_token, const AcceptFunc& accept_token,
                 std::mt19937& gen, int verbosity) {
-  GenerateImpl(gemma, args, prompt, start_pos, pool, inner_pool, stream_token,
-               accept_token, gen, verbosity);
+  GenerateImpl(gemma, args.max_tokens, args.max_generated_tokens,
+               args.temperature, prompt, start_pos, pool, inner_pool,
+               stream_token, accept_token, gen, verbosity);
 }
 
 void Generate7B(GemmaImpl<ConfigGemma7B>& gemma, const InferenceArgs& args,
@@ -614,8 +617,9 @@ void Generate7B(GemmaImpl<ConfigGemma7B>& gemma, const InferenceArgs& args,
                 hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
                 const StreamFunc& stream_token, const AcceptFunc& accept_token,
                 std::mt19937& gen, int verbosity) {
-  GenerateImpl(gemma, args, prompt, start_pos, pool, inner_pool, stream_token,
-               accept_token, gen, verbosity);
+  GenerateImpl(gemma, args.max_tokens, args.max_generated_tokens,
+               args.temperature, prompt, start_pos, pool, inner_pool,
+               stream_token, accept_token, gen, verbosity);
 }
 
 // Calls func(name, float*, CompressedArray&) for each tensor. float* is null
@@ -729,17 +733,22 @@ KVCache CreateKVCache(size_t size_cache_pos, size_t seq_len) {
 }
 
 template <class Config>
-GemmaImpl<Config>::GemmaImpl(const LoaderArgs& args, hwy::ThreadPool& pool)
-    : compressed_weights(
-          HWY_DYNAMIC_DISPATCH(GetCompressedWeightsT)(args, pool)),
+GemmaImpl<Config>::GemmaImpl(
+    std::unique_ptr<sentencepiece::SentencePieceProcessor>& tokenizer,
+    hwy::AlignedFreeUniquePtr<uint8_t[]>& compressed_weights,
+    hwy::ThreadPool& pool)
+    // GemmaImpl<Config>::GemmaImpl(const LoaderArgs& args, hwy::ThreadPool&
+    // pool)
+    : compressed_weights(std::move(compressed_weights)),
+      // HWY_DYNAMIC_DISPATCH(GetCompressedWeightsT)(args, pool)),
      prefill(hwy::MakeUniqueAligned<Activations<Config, kPrefillBatchSize>>()),
      state(hwy::MakeUniqueAligned<Activations<Config, 1>>()),
      kv_cache(
          CreateKVCache(Config::kLayers * Config::kKVHeads * Config::kQKVDim,
-                        Config::kSeqLen)) {
-  PROFILER_ZONE("Startup.tokenizer");
-
-  HWY_ASSERT(tokenizer.Load(args.tokenizer.path).ok());
+                        Config::kSeqLen)),
+      tokenizer(std::move(tokenizer)) {
+  // PROFILER_ZONE("Startup.tokenizer");
+  // HWY_ASSERT(tokenizer.Load(args.tokenizer.path).ok());
 }
 
 template <>
@@ -770,12 +779,20 @@ void GemmaImpl::Generate(const InferenceArgs& args,
 Gemma::Gemma(const LoaderArgs& args, hwy::ThreadPool& pool) {
   const Model model_type = args.ModelType();
   model_training = args.ModelTraining();
+  PROFILER_ZONE("Startup.tokenizer");
+  std::unique_ptr<sentencepiece::SentencePieceProcessor> tokenizer =
+      std::make_unique<sentencepiece::SentencePieceProcessor>();
+  HWY_ASSERT(tokenizer->Load(args.tokenizer.path).ok());
+  auto compressed_weights =
+      HWY_DYNAMIC_DISPATCH(GetCompressedWeightsT)(args, pool);
   switch (model_type) {
     case Model::GEMMA_2B:
-      impl_.reset(new GemmaImpl<ConfigGemma2B>(args, pool));
+      impl_.reset(
+          new GemmaImpl<ConfigGemma2B>(tokenizer, compressed_weights, pool));
       break;
     case Model::GEMMA_7B:
-      impl_.reset(new GemmaImpl<ConfigGemma7B>(args, pool));
+      impl_.reset(
+          new GemmaImpl<ConfigGemma7B>(tokenizer, compressed_weights, pool));
       break;
     default:
       HWY_ABORT("Model type %d unknown.", static_cast<int>(model_type));
@@ -783,7 +800,7 @@ Gemma::Gemma(const LoaderArgs& args, hwy::ThreadPool& pool) {
 }
 
 Gemma::~Gemma() = default;  // after GemmaInterface is defined
 
-const sentencepiece::SentencePieceProcessor& Gemma::Tokenizer() const {
+const sentencepiece::SentencePieceProcessor* Gemma::Tokenizer() const {
   return impl_->Tokenizer();
 }
diff --git a/gemma.h b/gemma.h
index 7195bc9..3de9f0e 100644
--- a/gemma.h
+++ b/gemma.h
@@ -64,6 +64,15 @@ struct KVCache {
 enum class Model { GEMMA_2B, GEMMA_7B };
 enum class ModelTraining { GEMMA_IT, GEMMA_PT };
 
+// TODO: incorporate
+struct InferenceParams {
+  Model model;
+  ModelTraining model_training;
+  size_t max_generated_tokens;
+  size_t max_tokens;
+  float temperature;
+};
+
 struct LoaderArgs : public ArgsBase<LoaderArgs> {
   LoaderArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
 
@@ -129,9 +138,9 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
             "file if "
             "the compressed weights file does not exist.\n    Required argument.");
     visitor(model_type, "model", std::string(),
-            "Model type\n    2b-it (2B parameters, instruction-tuned)\n    "
-            "2b-pt (2B parameters, pretrained)\n    7b-it (7B parameters "
-            "instruction-tuned)\n    7b-pt (7B parameters, pretrained)\n"
+            "Model type\n    2b-it = 2B parameters, instruction-tuned\n    "
+            "2b-pt = 2B parameters, pretrained\n    7b-it = 7B parameters "
+            "instruction-tuned\n    7b-pt = 7B parameters, pretrained\n"
             "    Required argument.");
     visitor(model, "weights", Path(),
             "Path name of model weights (.sbs) file. Only required if "
@@ -147,7 +156,10 @@ struct Gemma {
   Gemma(const LoaderArgs& args, hwy::ThreadPool& pool);
   ~Gemma();  // must be defined after GemmaInterface's dtor is defined.
 
-  const sentencepiece::SentencePieceProcessor& Tokenizer() const;
+  // TODO: cleanup
+  // const sentencepiece::SentencePieceProcessor& Tokenizer() const;
+  // const std::unique_ptr<sentencepiece::SentencePieceProcessor> Tokenizer() const;
+  const sentencepiece::SentencePieceProcessor* Tokenizer() const;
 
   std::unique_ptr<GemmaInterface> impl_;
   gcpp::ModelTraining model_training;
@@ -192,8 +204,8 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
     visitor(deterministic, "deterministic", false,
             "Make top-k sampling deterministic", 2);
     visitor(multiturn, "multiturn", false,
-            "Multiturn mode (if 0, this clears the KV cache after every "
-            "interaction without quitting)\n    Default : 0 (conversation "
+            "Multiturn mode\n    0 = clear KV cache after every "
+            "interaction\n    1 = continue KV cache after every interaction\n    Default : 0 (conversation "
             "resets every turn)");
   }
 };
diff --git a/run.cc b/run.cc
index 507979d..50b9a24 100644
--- a/run.cc
+++ b/run.cc
@@ -115,7 +115,7 @@ void ReplGemma(gcpp::Gemma& model, hwy::ThreadPool& pool,
 
   // callback function invoked for each generated token.
   auto stream_token = [&abs_pos, &current_pos, &args, &gen, &prompt_size,
-                       tokenizer = &model.Tokenizer(),
+                       tokenizer = model.Tokenizer(),
                        verbosity](int token, float) {
     ++abs_pos;
     ++current_pos;
@@ -129,7 +129,7 @@ void ReplGemma(gcpp::Gemma& model, hwy::ThreadPool& pool,
         }
       }
       if (verbosity >= 2) {
-        std::cout << "\n[ End ]" << std::endl;
+        std::cout << "\n[ End ]\n";
       }
     } else {
       std::string token_text;
@@ -142,7 +142,6 @@ void ReplGemma(gcpp::Gemma& model, hwy::ThreadPool& pool,
           std::cout << std::endl << std::endl;
         }
       }
-      // TODO(austinvhuang): is explicit space necessary?
       std::cout << token_text << std::flush;
     }
     return true;
@@ -191,7 +190,8 @@ void ReplGemma(gcpp::Gemma& model, hwy::ThreadPool& pool,
       }
     }
 
-    HWY_ASSERT(model.Tokenizer().Encode(prompt_string, &prompt).ok());
+    // HWY_ASSERT(model.Tokenizer().Encode(prompt_string, &prompt).ok());
+    HWY_ASSERT(model.Tokenizer()->Encode(prompt_string, &prompt).ok());
 
     // For both pre-trained and instruction-tuned models: prepend "<bos>" token
     // if needed.
diff --git a/util/app.h b/util/app.h
index 7f926a5..754b2fb 100644
--- a/util/app.h
+++ b/util/app.h
@@ -79,9 +79,9 @@ class AppArgs : public ArgsBase<AppArgs> {
   template <class Visitor>
   void ForEach(const Visitor& visitor) {
     visitor(verbosity, "verbosity", 1,
-        "Show verbose developer information\n    0 = only print generation "
-        "output\n    1 = standard user-facing terminal ui\n    2 = show "
-        "developer/debug info).\n    Default = 1.",
+            "Show verbose developer information\n    0 = only print generation "
+            "output\n    1 = standard user-facing terminal ui\n    2 = show "
+            "developer/debug info).\n    Default = 1.",
             2);
     visitor(num_threads, "num_threads", kDefaultNumThreads,  // see ChooseNumThreads
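The calling-convention change is small but easy to miss when reading the hunks, so here is a minimal, illustrative sketch (not part of the patch) of a caller using the pointer-returning `Gemma::Tokenizer()`. Include paths, the thread-pool sizing, and the `main` scaffolding are assumptions chosen to mirror what `run.cc` and `examples/hello_world/run.cc` already do; only the `Gemma`, `LoaderArgs`, and `Tokenizer()->Encode(...)` usage comes from this diff.

```cpp
// Illustrative only: a standalone caller after this change.
#include <string>
#include <vector>

#include "gemma.h"  // gcpp::Gemma, gcpp::LoaderArgs (assumed include path)
#include "hwy/base.h"  // HWY_ASSERT
#include "hwy/contrib/thread_pool/thread_pool.h"

int main(int argc, char** argv) {
  gcpp::LoaderArgs loader(argc, argv);       // --tokenizer / --model / --weights
  hwy::ThreadPool pool(/*num_threads=*/4);   // AppArgs-driven sizing omitted

  gcpp::Gemma model(loader, pool);

  // Tokenizer() now returns a const sentencepiece::SentencePieceProcessor*,
  // so callers use -> and never take ownership; the owning unique_ptr stays
  // inside GemmaImpl.
  std::vector<int> prompt;
  HWY_ASSERT(model.Tokenizer()->Encode("Hello, how are you?", &prompt).ok());
  return prompt.empty() ? 1 : 0;
}
```

Moving the tokenizer behind a `unique_ptr` that `Gemma::Gemma` constructs and loads before dispatching `GetCompressedWeightsT` is what lets `GemmaImpl`'s constructor take both pieces by reference and `std::move` them, while external callers only ever see a non-owning raw pointer.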