diff --git a/examples/hello_world/BUILD b/examples/hello_world/BUILD
new file mode 100644
index 0000000..98fe5fd
--- /dev/null
+++ b/examples/hello_world/BUILD
@@ -0,0 +1,22 @@
+# Hello World example frontend to gemma.cpp.
+package(
+    default_applicable_licenses = [
+        "//:license",  # Placeholder comment, do not modify
+    ],
+    default_visibility = ["//visibility:public"],
+)
+
+cc_binary(
+    name = "hello_world",
+    srcs = ["run.cc"],
+    deps = [
+        # Placeholder for internal dep, do not remove.,
+        "//:app",
+        "//:args",
+        "//:common",
+        "//:gemma_lib",
+        "//:tokenizer",
+        "@hwy//:hwy",
+        "@hwy//:thread_pool",
+    ],
+)
diff --git a/examples/hello_world/run.cc b/examples/hello_world/run.cc
index 5984fa6..65844dd 100644
--- a/examples/hello_world/run.cc
+++ b/examples/hello_world/run.cc
@@ -15,33 +15,40 @@
 #include <iostream>
+#include <random>
 #include <string>
 #include <vector>
 
-#include "third_party/gemma_cpp/gemma.h"
+// Placeholder for internal header, do not modify.
+#include "gemma/common.h"
+#include "gemma/gemma.h"
+#include "gemma/tokenizer.h"
 #include "util/app.h"  // LoaderArgs
+#include "util/args.h"
 #include "hwy/base.h"
 #include "hwy/contrib/thread_pool/thread_pool.h"
 
-std::vector<int> tokenize(const std::string& prompt_string,
-                          const gcpp::GemmaTokenizer* tokenizer) {
-  std::string formatted = "<start_of_turn>user\n" + prompt_string +
-                          "<end_of_turn>\n<start_of_turn>model\n";
-  std::vector<int> tokens;
-  HWY_ASSERT(tokenizer->Encode(formatted, &tokens));
-  tokens.insert(tokens.begin(), BOS_ID);
-  return tokens;
-}
-
 int main(int argc, char** argv) {
+  int argc_dummy = 1;
+  // Required because sentencepiece uses Google I/O, which requires InitGoogle.
+  // argc_dummy = 1 avoids sentencepiece absl flags attempting to parse
+  // arguments.
+  InitGoogle("usage", &argc_dummy, &argv, false);
+
   gcpp::LoaderArgs loader(argc, argv);
-  gcpp::AppArgs app(argc, argv);
+  if (gcpp::HasHelp(argc, argv)) {
+    loader.Help();
+    return 0;
+  } else if (const char* error = loader.Validate()) {
+    loader.Help();
+    HWY_ABORT("\nInvalid args: %s", error);
+  }
 
   // Instantiate model and KV Cache
-  hwy::ThreadPool pool(app.num_threads);
+  hwy::ThreadPool pool(gcpp::AppArgs::GetSupportedThreadCount());
   gcpp::Gemma model = gcpp::CreateGemma(loader, pool);
-  gcpp::KVCache kv_cache = gcpp::KVCache::Create(loader.ModelType());
+  gcpp::KVCache kv_cache = gcpp::KVCache::Create(loader.Info().model);
   size_t pos = 0;  // KV Cache position
 
   // Initialize random number generator
@@ -49,30 +56,33 @@ int main(int argc, char** argv) {
   std::random_device rd;
   gen.seed(rd());
 
-  // Tokenize instruction
-  std::vector<int> tokens =
-      tokenize("Write a greeting to the world.", model.Tokenizer());
+  // Tokenize instructions.
+  std::string prompt = "Write a greeting to the world.";
+  const std::vector<int> tokens =
+      gcpp::WrapAndTokenize(model.Tokenizer(), loader.Info(), pos, prompt);
   size_t ntokens = tokens.size();
 
   // This callback function gets invoked every time a token is generated
-  auto stream_token = [&pos, &ntokens, tokenizer = model.Tokenizer()](int token,
-                                                                      float) {
+  auto stream_token = [&pos, &ntokens, &model](int token, float) {
     ++pos;
     if (pos < ntokens) {
       // print feedback
     } else if (token != gcpp::EOS_ID) {
       std::string token_text;
-      HWY_ASSERT(tokenizer->Decode(std::vector<int>{token}, &token_text));
+      HWY_ASSERT(model.Tokenizer().Decode({token}, &token_text));
      std::cout << token_text << std::flush;
     }
    return true;
   };
 
-  GenerateGemma(model,
-                {.max_tokens = 2048,
-                 .max_generated_tokens = 1024,
-                 .temperature = 1.0,
-                 .verbosity = 0},
-                tokens, /*KV cache position = */ 0, kv_cache, pool,
-                stream_token, gen);
+  gcpp::TimingInfo timing_info;
+  gcpp::RuntimeConfig runtime_config = {
+      .max_tokens = 1536,
+      .max_generated_tokens = 1024,
+      .temperature = 1.0,
+      .verbosity = 0,
+      .gen = &gen,
+      .stream_token = stream_token,
+  };
+  model.Generate(runtime_config, tokens, 0, kv_cache, timing_info);
 }
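
Note on the API migration in this patch: the hand-rolled tokenize() helper is deleted because gcpp::WrapAndTokenize (from gemma/tokenizer.h) now does the same job, namely wrapping the prompt in Gemma's instruction-tuning turn markers and prepending BOS before encoding. As a rough orientation only, here is a sketch reconstructed from the deleted helper; it is not the actual WrapAndTokenize implementation, and the name WrapPromptSketch is hypothetical (BOS_ID is the constant the deleted code used):

    #include <string>
    #include <vector>
    #include "gemma/tokenizer.h"
    #include "hwy/base.h"

    // Sketch: approximately what the deleted helper did and what
    // WrapAndTokenize now handles for an instruction-tuned model at pos 0.
    std::vector<int> WrapPromptSketch(const gcpp::GemmaTokenizer& tokenizer,
                                      const std::string& prompt) {
      const std::string wrapped = "<start_of_turn>user\n" + prompt +
                                  "<end_of_turn>\n<start_of_turn>model\n";
      std::vector<int> tokens;
      HWY_ASSERT(tokenizer.Encode(wrapped, &tokens));
      tokens.insert(tokens.begin(), BOS_ID);  // sequence must start with BOS
      return tokens;
    }

The generation call changes in the same spirit: the free function GenerateGemma(model, {...}, tokens, pos, kv_cache, pool, stream_token, gen) becomes the member call model.Generate(runtime_config, tokens, 0, kv_cache, timing_info), with the RNG and the stream_token callback moving into gcpp::RuntimeConfig.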