diff --git a/examples/hello_world/CMakeLists.txt b/examples/hello_world/CMakeLists.txt index be9a108..2f5d648 100644 --- a/examples/hello_world/CMakeLists.txt +++ b/examples/hello_world/CMakeLists.txt @@ -18,9 +18,9 @@ set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_STANDARD_REQUIRED ON) include(FetchContent) -FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 9414b48aeec251b69e6cadbfa42bebb5ddae1c34) +FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG bfc36a6e633af94e63ac4b91c687bf0354cb24e0) FetchContent_MakeAvailable(highway) -FetchContent_Declare(sentencepiece GIT_REPOSITORY https://github.com/google/sentencepiece GIT_TAG 53de76561cfc149d3c01037f0595669ad32a5e7c) +FetchContent_Declare(sentencepiece GIT_REPOSITORY https://github.com/google/sentencepiece GIT_TAG 9045b2f60fa2b323dfac0eaef8fc17565036f9f9) FetchContent_MakeAvailable(sentencepiece) @@ -34,7 +34,7 @@ if (BUILD_MODE STREQUAL "local") # Relative path to gemma.cpp from examples/hello_world/build/ FetchContent_Declare(gemma SOURCE_DIR ../../..) else() - FetchContent_Declare(gemma GIT_REPOSITORY https://github.com/google/gemma.cpp.git GIT_TAG a9aa63fd2ea6b786ed0706d619588bfe2d43370e) + FetchContent_Declare(gemma GIT_REPOSITORY https://github.com/google/gemma.cpp.git GIT_TAG 4a924f179448dc83e46a2af9520c61b4ef56174c) endif() FetchContent_MakeAvailable(gemma) diff --git a/examples/hello_world/run.cc b/examples/hello_world/run.cc index 52bd507..193903f 100644 --- a/examples/hello_world/run.cc +++ b/examples/hello_world/run.cc @@ -44,9 +44,7 @@ int main(int argc, char** argv) { for (int arg = 0; arg < argc; ++arg) { // Find a --reject flag and consume everything after it. if (strcmp(argv[arg], "--reject") == 0) { - while (++arg < argc) { - reject_tokens.insert(atoi(argv[arg])); // NOLINT - } + while (++arg < argc) reject_tokens.insert(atoi(argv[arg])); } } @@ -90,9 +88,9 @@ int main(int argc, char** argv) { .verbosity = 0, .stream_token = stream_token, .accept_token = - [&](int token, float /* prob */) { - return !reject_tokens.contains(token); - }, + std::function([&](int token, float /* prob */) { + return reject_tokens.find(token) == reject_tokens.end(); + }), }; gemma.Generate(runtime_config, tokens, 0, kv_cache, env, timing_info); }