Demonstrate constrained decoding in gemma_cpp's hello world example

PiperOrigin-RevId: 669327521
This commit is contained in:
Paul Chang 2024-08-30 08:02:09 -07:00 committed by Copybara-Service
parent 4033ed9e78
commit 22d9476aad
2 changed files with 25 additions and 0 deletions

View File

@ -49,3 +49,12 @@ Should print a greeting to the terminal:
```
"Hello, world! It's a pleasure to greet you all. May your day be filled with joy, peace, and all the things that make your heart soar.
```
For a demonstration of constrained decoding, add the `--reject` flag followed by
a space-separated list of token IDs (note that it must be the last flag, since
it consumes every subsequent argument). For example, to reject variations of the
word "greeting", run:
```sh
./hello_world [...] --reject 32338 42360 78107 106837 132832 143859 154230 190205
```

View File

@ -15,8 +15,11 @@
#include <stddef.h>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <random>
#include <set>
#include <string>
#include <vector>
@ -42,6 +45,15 @@ int main(int argc, char** argv) {
HWY_ABORT("\nInvalid args: %s", error);
}
// Demonstrate constrained decoding by never outputting certain tokens.
// Every argument after the first `--reject` flag is parsed as a token ID and
// added to `reject_tokens`; this is why `--reject` has to be the last flag.
std::set<int> reject_tokens;
// Start at 1: argv[0] is the program name, not a flag.
for (int arg = 1; arg < argc; ++arg) {
  if (strcmp(argv[arg], "--reject") != 0) continue;
  // Found the flag: consume everything after it.
  while (++arg < argc) {
    // strtol reports where parsing stopped, so (unlike atoi) a typo such as
    // "32x38" or "foo" is detected instead of silently becoming 32 or 0.
    char* end = nullptr;
    const long token = strtol(argv[arg], &end, 10);
    const bool parsed_all = (end != argv[arg]) && (*end == '\0');
    // Round-trip through int to make sure the value is not truncated.
    const bool fits_int = static_cast<long>(static_cast<int>(token)) == token;
    if (!parsed_all || !fits_int) {
      HWY_ABORT("\nInvalid --reject token ID: %s", argv[arg]);
    }
    reject_tokens.insert(static_cast<int>(token));
  }
}
// Instantiate model and KV Cache
gcpp::PerClusterPools pools(app.max_clusters, app.num_threads, app.pin);
gcpp::Gemma model = gcpp::CreateGemma(loader, pools);
@ -81,6 +93,10 @@ int main(int argc, char** argv) {
.verbosity = 0,
.gen = &gen,
.stream_token = stream_token,
.accept_token =
[&](int token, float /* prob */) {
return !reject_tokens.contains(token);
},
};
model.Generate(runtime_config, tokens, 0, kv_cache, timing_info);
}