From c378ac2c565081cfc9bf0e57531d9531948a6e32 Mon Sep 17 00:00:00 2001 From: austinvhuang Date: Sun, 3 Mar 2024 11:36:48 -0500 Subject: [PATCH] [WIP] hello world example working. TODO: refactor interfaces to decouple arguments --- examples/{look => hello_world}/CMakeLists.txt | 18 ++--- .../{look => hello_world}/build/.gitignore | 0 examples/hello_world/run.cc | 74 +++++++++++++++++++ examples/look/run.cc | 28 ------- models/.gitignore | 0 5 files changed, 83 insertions(+), 37 deletions(-) rename examples/{look => hello_world}/CMakeLists.txt (70%) rename examples/{look => hello_world}/build/.gitignore (100%) create mode 100644 examples/hello_world/run.cc delete mode 100644 examples/look/run.cc create mode 100644 models/.gitignore diff --git a/examples/look/CMakeLists.txt b/examples/hello_world/CMakeLists.txt similarity index 70% rename from examples/look/CMakeLists.txt rename to examples/hello_world/CMakeLists.txt index 8666ce7..69e4e98 100644 --- a/examples/look/CMakeLists.txt +++ b/examples/hello_world/CMakeLists.txt @@ -14,7 +14,7 @@ cmake_minimum_required(VERSION 3.11) -project(look) +project(hello_world) set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_STANDARD_REQUIRED ON) @@ -34,12 +34,12 @@ if(NOT CMAKE_BUILD_TYPE) set(CMAKE_BUILD_TYPE "Release") endif() -add_executable(look run.cc) -target_sources(look PRIVATE ${SOURCES}) -set_property(TARGET look PROPERTY CXX_STANDARD 17) -target_link_libraries(look hwy hwy_contrib sentencepiece libgemma) -target_include_directories(look PRIVATE ./) +add_executable(hello_world run.cc) +target_sources(hello_world PRIVATE ${SOURCES}) +set_property(TARGET hello_world PROPERTY CXX_STANDARD 17) +target_link_libraries(hello_world hwy hwy_contrib sentencepiece libgemma) +target_include_directories(hello_world PRIVATE ./) FetchContent_GetProperties(sentencepiece) -target_include_directories(look PRIVATE ${sentencepiece_SOURCE_DIR}) -target_compile_definitions(look PRIVATE $<$:_CRT_SECURE_NO_WARNINGS NOMINMAX>) -target_compile_options(look PRIVATE $<$:-Wno-deprecated-declarations>) +target_include_directories(hello_world PRIVATE ${sentencepiece_SOURCE_DIR}) +target_compile_definitions(hello_world PRIVATE $<$:_CRT_SECURE_NO_WARNINGS NOMINMAX>) +target_compile_options(hello_world PRIVATE $<$:-Wno-deprecated-declarations>) diff --git a/examples/look/build/.gitignore b/examples/hello_world/build/.gitignore similarity index 100% rename from examples/look/build/.gitignore rename to examples/hello_world/build/.gitignore diff --git a/examples/hello_world/run.cc b/examples/hello_world/run.cc new file mode 100644 index 0000000..a017e22 --- /dev/null +++ b/examples/hello_world/run.cc @@ -0,0 +1,74 @@ +#include + +// copybara:import_next_line:gemma_cpp +#include "compression/compress.h" +// copybara:end +// copybara:import_next_line:gemma_cpp +#include "gemma.h" // Gemma +// copybara:end +// copybara:import_next_line:gemma_cpp +#include "util/app.h" +// copybara:end +// copybara:import_next_line:gemma_cpp +#include "util/args.h" // HasHelp +// copybara:end +#include "hwy/base.h" +#include "hwy/contrib/thread_pool/thread_pool.h" +#include "hwy/highway.h" +#include "hwy/per_target.h" +#include "hwy/profiler.h" +#include "hwy/timer.h" + +std::vector tokenize(std::string prompt_string, const sentencepiece::SentencePieceProcessor& tokenizer) { + prompt_string = "user\n" + prompt_string + + "\nmodel\n"; + std::vector tokens; + HWY_ASSERT(tokenizer.Encode(prompt_string, &tokens).ok()); + tokens.insert(tokens.begin(), 2); // BOS token + return tokens; +} + +int main(int argc, char** argv) { + gcpp::InferenceArgs inference(argc, argv); + gcpp::LoaderArgs loader(argc, argv); + gcpp::AppArgs app(argc, argv); + hwy::ThreadPool pool(app.num_threads); + hwy::ThreadPool inner_pool(0); + gcpp::Gemma model(loader, pool); + + std::vector tokens = tokenize("Hello, how are you?", model.Tokenizer()); + + std::mt19937 gen; + std::random_device rd; + gen.seed(rd()); + + size_t ntokens = tokens.size(); + + size_t pos = 0; + + auto stream_token = [&pos, &gen, &ntokens, tokenizer = &model.Tokenizer()](int token, float) { + ++pos; + if (pos < ntokens) { + // print feedback + } else if (token != gcpp::EOS_ID) { + std::string token_text; + HWY_ASSERT(tokenizer->Decode(std::vector{token}, &token_text).ok()); + if (pos == ntokens + 1) { + // first token of response + token_text.erase(0, token_text.find_first_not_of(" \t\n\n")); + } + std::cout << token_text << std::flush; + } + return true; + }; + + inference.temperature = 1.0f; + inference.deterministic = true; + inference.multiturn = false; + + GenerateGemma( + model, inference, tokens, 0, pool, inner_pool, stream_token, + [](int) {return true;}, gen, 0); + + std::cout << std::endl; +} diff --git a/examples/look/run.cc b/examples/look/run.cc deleted file mode 100644 index 85b0da4..0000000 --- a/examples/look/run.cc +++ /dev/null @@ -1,28 +0,0 @@ -#include - -// copybara:import_next_line:gemma_cpp -#include "compression/compress.h" -// copybara:end -// copybara:import_next_line:gemma_cpp -#include "gemma.h" // Gemma -// copybara:end -// copybara:import_next_line:gemma_cpp -#include "util/app.h" -// copybara:end -// copybara:import_next_line:gemma_cpp -#include "util/args.h" // HasHelp -// copybara:end -#include "hwy/base.h" -#include "hwy/contrib/thread_pool/thread_pool.h" -#include "hwy/highway.h" -#include "hwy/per_target.h" -#include "hwy/profiler.h" -#include "hwy/timer.h" - -int main(int argc, char** argv) { - gcpp::LoaderArgs loader(argc, argv); - gcpp::AppArgs app(argc, argv); - hwy::ThreadPool pool(app.num_threads); - gcpp::Gemma model(loader, pool); - std::cout << "Done" << std::endl; -} diff --git a/models/.gitignore b/models/.gitignore new file mode 100644 index 0000000..e69de29