diff --git a/examples/hello_world/CMakeLists.txt b/examples/hello_world/CMakeLists.txt
index 2eb39c6..2f5d648 100644
--- a/examples/hello_world/CMakeLists.txt
+++ b/examples/hello_world/CMakeLists.txt
@@ -14,12 +14,13 @@
 cmake_minimum_required(VERSION 3.11)
 
 project(hello_world)
 
+set(CMAKE_CXX_STANDARD 17)
 set(CMAKE_CXX_STANDARD_REQUIRED ON)
 include(FetchContent)
 
-FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG c5bebf84ad01edec97e336f5c97ca4e0df6b4d06)
+FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG bfc36a6e633af94e63ac4b91c687bf0354cb24e0)
 FetchContent_MakeAvailable(highway)
 
-FetchContent_Declare(sentencepiece GIT_REPOSITORY https://github.com/google/sentencepiece GIT_TAG 53de76561cfc149d3c01037f0595669ad32a5e7c)
+FetchContent_Declare(sentencepiece GIT_REPOSITORY https://github.com/google/sentencepiece GIT_TAG 9045b2f60fa2b323dfac0eaef8fc17565036f9f9)
 FetchContent_MakeAvailable(sentencepiece)
 
@@ -31,9 +32,9 @@
 if (NOT BUILD_MODE)
 endif()
 if (BUILD_MODE STREQUAL "local")
   # Relative path to gemma.cpp from examples/hello_world/build/
-  FetchContent_Declare(gemma SOURCE_DIR ../../..)
+  FetchContent_Declare(gemma SOURCE_DIR ../../..)
 else()
-  FetchContent_Declare(gemma GIT_REPOSITORY https://github.com/google/gemma.cpp.git GIT_TAG a9aa63fd2ea6b786ed0706d619588bfe2d43370e)
+  FetchContent_Declare(gemma GIT_REPOSITORY https://github.com/google/gemma.cpp.git GIT_TAG 4a924f179448dc83e46a2af9520c61b4ef56174c)
 endif()
 FetchContent_MakeAvailable(gemma)
diff --git a/examples/hello_world/README.md b/examples/hello_world/README.md
index f396c05..acfaa48 100644
--- a/examples/hello_world/README.md
+++ b/examples/hello_world/README.md
@@ -41,7 +41,7 @@ gemma.cpp specifying the tokenizer, compressed weights file, and model type,
 for example:
 
 ```sh
-./hello_world --tokenizer tokenizer.spm --compressed_weights 2b-it-sfp.sbs --model 2b-it
+./hello_world --tokenizer tokenizer.spm --weights 2b-it-sfp.sbs --model 2b-it
 ```
 
 Should print a greeting to the terminal:
diff --git a/examples/hello_world/run.cc b/examples/hello_world/run.cc
index fb2fea3..4eb8647 100644
--- a/examples/hello_world/run.cc
+++ b/examples/hello_world/run.cc
@@ -32,76 +32,76 @@
 #include "hwy/base.h"
 #include "hwy/contrib/thread_pool/thread_pool.h"
 
-int main(int argc, char** argv) {
-  {
-    // Placeholder for internal init, do not modify.
-  }
-
-  gcpp::LoaderArgs loader(argc, argv);
-  gcpp::InferenceArgs inference(argc, argv);
-  gcpp::AppArgs app(argc, argv);
-  if (gcpp::HasHelp(argc, argv)) {
-    loader.Help();
-    return 0;
-  } else if (const char* error = loader.Validate()) {
-    loader.Help();
-    HWY_ABORT("\nInvalid args: %s", error);
-  }
-
-  // Demonstrate constrained decoding by never outputting certain tokens.
-  std::set<int> reject_tokens;
-  for (int arg = 0; arg < argc; ++arg) {
-    // Find a --reject flag and consume everything after it.
-    if (strcmp(argv[arg], "--reject") == 0) {
-      while (++arg < argc) reject_tokens.insert(atoi(argv[arg]));
-    }
-  }
-
-  // Instantiate model and KV Cache
-  gcpp::BoundedTopology topology(gcpp::CreateTopology(app));
-  gcpp::NestedPools pools = gcpp::CreatePools(topology, app);
-  gcpp::MatMulEnv env(topology, pools);
-  gcpp::Gemma model = gcpp::CreateGemma(loader, env);
-  gcpp::KVCache kv_cache =
-      gcpp::KVCache::Create(model.GetModelConfig(),
-                            inference.prefill_tbatch_size);
-  size_t generated = 0;
-
-  // Initialize random number generator
-  std::mt19937 gen;
-  std::random_device rd;
-  gen.seed(rd());
-
-  // Tokenize instructions.
-  std::string prompt = "Write a greeting to the world.";
-  const std::vector<int> tokens = gcpp::WrapAndTokenize(
-      model.Tokenizer(), loader.Info(), generated, prompt);
-  const size_t prompt_size = tokens.size();
-
-  // This callback function gets invoked every time a token is generated
-  auto stream_token = [&generated, &prompt_size, &model](int token, float) {
-    ++generated;
-    if (generated < prompt_size) {
-      // print feedback
-    } else if (!model.GetModelConfig().IsEOS(token)) {
-      std::string token_text;
-      HWY_ASSERT(model.Tokenizer().Decode({token}, &token_text));
-      std::cout << token_text << std::flush;
-    }
-    return true;
-  };
-
-  gcpp::TimingInfo timing_info;
-  gcpp::RuntimeConfig runtime_config = {
-      .max_generated_tokens = 1024,
-      .temperature = 1.0,
-      .gen = &gen,
-      .verbosity = 0,
-      .stream_token = stream_token,
-      .accept_token =
-          [&](int token, float /* prob */) {
-            return !reject_tokens.contains(token);
-          },
-  };
-  model.Generate(runtime_config, tokens, 0, kv_cache, timing_info);
-}
+int main(int argc, char **argv) {
+  {
+    // Placeholder for internal init, do not modify.
+  }
+
+  gcpp::LoaderArgs loader(argc, argv);
+  gcpp::InferenceArgs inference(argc, argv);
+  gcpp::AppArgs app(argc, argv);
+  if (gcpp::HasHelp(argc, argv)) {
+    loader.Help();
+    return 0;
+  } else if (const char *error = loader.Validate()) {
+    loader.Help();
+    HWY_ABORT("\nInvalid args: %s", error);
+  }
+
+  // Demonstrate constrained decoding by never outputting certain tokens.
+  std::set<int> reject_tokens;
+  for (int arg = 0; arg < argc; ++arg) {
+    // Find a --reject flag and consume everything after it.
+    if (strcmp(argv[arg], "--reject") == 0) {
+      while (++arg < argc) reject_tokens.insert(atoi(argv[arg]));
+    }
+  }
+
+  // Instantiate model and KV Cache
+  gcpp::BoundedTopology topology(gcpp::CreateTopology(app));
+  gcpp::NestedPools pools = gcpp::CreatePools(topology, app);
+  gcpp::MatMulEnv env(topology, pools);
+  gcpp::Gemma model = gcpp::CreateGemma(loader, env);
+  gcpp::KVCache kv_cache =
+      gcpp::KVCache::Create(model.GetModelConfig(),
+                            inference.prefill_tbatch_size);
+  size_t generated = 0;
+
+  // Initialize random number generator
+  std::mt19937 gen;
+  std::random_device rd;
+  gen.seed(rd());
+
+  // Tokenize instructions.
+  std::string prompt = "Write a greeting to the world.";
+  const std::vector<int> tokens = gcpp::WrapAndTokenize(
+      model.Tokenizer(), loader.Info(), generated, prompt);
+  const size_t prompt_size = tokens.size();
+
+  // This callback function gets invoked every time a token is generated
+  auto stream_token = [&generated, &prompt_size, &model](int token, float) {
+    ++generated;
+    if (generated < prompt_size) {
+      // print feedback
+    } else if (!model.GetModelConfig().IsEOS(token)) {
+      std::string token_text;
+      HWY_ASSERT(model.Tokenizer().Decode({token}, &token_text));
+      std::cout << token_text << std::flush;
+    }
+    return true;
+  };
+
+  gcpp::TimingInfo timing_info;
+  gcpp::RuntimeConfig runtime_config = {
+      .max_generated_tokens = 1024,
+      .temperature = 1.0,
+      .gen = &gen,
+      .verbosity = 0,
+      .stream_token = stream_token,
+      .accept_token =
+          std::function<bool(int, float)>(
+              [&](int token, float /* prob */) {
+                return reject_tokens.find(token) == reject_tokens.end();
+              }),
+  };
+  model.Generate(runtime_config, tokens, 0, kv_cache, timing_info);
+}
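For reference, a minimal sketch of configuring, building, and running the example after this change, assuming the `local` build mode with a `build/` directory under `examples/hello_world/` (which the relative `SOURCE_DIR ../../..` above expects) and the model files named in the README:

```sh
cd examples/hello_world
cmake -B build -DBUILD_MODE=local
cmake --build build -j
./build/hello_world --tokenizer tokenizer.spm --weights 2b-it-sfp.sbs --model 2b-it
```

Token ids listed after `--reject` (e.g. `--reject 42 4242`; the ids are illustrative) are filtered out by the `accept_token` callback in run.cc, demonstrating the constrained decoding described in its comments.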