diff --git a/examples/hello_world/README.md b/examples/hello_world/README.md
index 63c319e..f396c05 100644
--- a/examples/hello_world/README.md
+++ b/examples/hello_world/README.md
@@ -49,3 +49,12 @@ Should print a greeting to the terminal:
 ```
 "Hello, world! It's a pleasure to greet you all. May your day be filled with joy, peace, and all the things that make your heart soar.
 ```
+
+For a demonstration of constrained decoding, add the `--reject` flag followed by
+a list of token IDs (note that it must be the last flag, since it consumes every
+subsequent argument). For example, to reject variations of the word "greeting",
+run:
+
+```sh
+./hello_world [...] --reject 32338 42360 78107 106837 132832 143859 154230 190205
+```
diff --git a/examples/hello_world/run.cc b/examples/hello_world/run.cc
index 4e0f5f9..1850f58 100644
--- a/examples/hello_world/run.cc
+++ b/examples/hello_world/run.cc
@@ -15,8 +15,11 @@
 
 #include <stdio.h>
 
+#include <cstdlib>
+#include <cstring>
 #include <iostream>
 #include <random>
+#include <set>
 #include <string>
 #include <vector>
 
@@ -42,6 +45,25 @@ int main(int argc, char** argv) {
     HWY_ABORT("\nInvalid args: %s", error);
   }
 
+  // Demonstrate constrained decoding by never outputting certain tokens.
+  std::set<int> reject_tokens;
+  // Start at 1: argv[0] is the program name, never a flag.
+  for (int arg = 1; arg < argc; ++arg) {
+    // Find a --reject flag and consume everything after it.
+    if (strcmp(argv[arg], "--reject") != 0) continue;
+    while (++arg < argc) {
+      // Parse strictly; atoi() would silently map a non-numeric argument to
+      // token 0 and reject it.
+      char* end = nullptr;
+      const long token = strtol(argv[arg], &end, 10);
+      if (end == argv[arg] || *end != '\0' || token < 0 ||
+          token != static_cast<long>(static_cast<int>(token))) {
+        HWY_ABORT("\nInvalid --reject token ID: %s", argv[arg]);
+      }
+      reject_tokens.insert(static_cast<int>(token));
+    }
+  }
+
   // Instantiate model and KV Cache
   gcpp::PerClusterPools pools(app.max_clusters, app.num_threads, app.pin);
   gcpp::Gemma model = gcpp::CreateGemma(loader, pools);
@@ -81,6 +103,10 @@ int main(int argc, char** argv) {
       .verbosity = 0,
       .gen = &gen,
       .stream_token = stream_token,
+      .accept_token =
+          [&](int token, float /* prob */) {
+            return !reject_tokens.contains(token);
+          },
   };
   model.Generate(runtime_config, tokens, 0, kv_cache, timing_info);
 }