mirror of https://github.com/google/gemma.cpp.git
Demonstrate constrained decoding in gemma_cpp's hello world example
PiperOrigin-RevId: 669327521
This commit is contained in:
parent
4033ed9e78
commit
22d9476aad
|
|
@ -49,3 +49,12 @@ Should print a greeting to the terminal:
|
|||
```
|
||||
"Hello, world! It's a pleasure to greet you all. May your day be filled with joy, peace, and all the things that make your heart soar.
|
||||
```
|
||||
|
||||
For a demonstration of constrained decoding, add the `--reject` flag followed by
|
||||
a list of token IDs (note that it must be the last flag, since it consumes every
|
||||
subsequent argument). For example, to reject variations of the word "greeting",
|
||||
run:
|
||||
|
||||
```sh
|
||||
./hello_world [...] --reject 32338 42360 78107 106837 132832 143859 154230 190205
|
||||
```
|
||||
|
|
|
|||
|
|
@ -15,8 +15,11 @@
|
|||
|
||||
#include <stddef.h>
|
||||
|
||||
#include <cstdlib>
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
#include <random>
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
|
|
@ -42,6 +45,15 @@ int main(int argc, char** argv) {
|
|||
HWY_ABORT("\nInvalid args: %s", error);
|
||||
}
|
||||
|
||||
// Demonstrate constrained decoding by never outputting certain tokens.
|
||||
std::set<int> reject_tokens;
|
||||
for (int arg = 0; arg < argc; ++arg) {
|
||||
// Find a --reject flag and treat every argument after it as a token ID to
// reject. atoi performs no validation, so a non-numeric argument becomes 0.
|
||||
if (strcmp(argv[arg], "--reject") == 0) {
|
||||
while (++arg < argc) reject_tokens.insert(atoi(argv[arg]));
|
||||
}
|
||||
}
|
||||
|
||||
// Instantiate model and KV Cache
|
||||
gcpp::PerClusterPools pools(app.max_clusters, app.num_threads, app.pin);
|
||||
gcpp::Gemma model = gcpp::CreateGemma(loader, pools);
|
||||
|
|
@ -81,6 +93,10 @@ int main(int argc, char** argv) {
|
|||
.verbosity = 0,
|
||||
.gen = &gen,
|
||||
.stream_token = stream_token,
|
||||
.accept_token =
|
||||
[&](int token, float /* prob */) {
|
||||
return !reject_tokens.contains(token);
|
||||
},
|
||||
};
|
||||
model.Generate(runtime_config, tokens, 0, kv_cache, timing_info);
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue