mirror of https://github.com/google/gemma.cpp.git
Demonstrate constrained decoding in gemma_cpp's hello world example
PiperOrigin-RevId: 669327521
This commit is contained in:
parent
4033ed9e78
commit
22d9476aad
|
|
@ -49,3 +49,12 @@ Should print a greeting to the terminal:
|
||||||
```
|
```
|
||||||
"Hello, world! It's a pleasure to greet you all. May your day be filled with joy, peace, and all the things that make your heart soar.
|
"Hello, world! It's a pleasure to greet you all. May your day be filled with joy, peace, and all the things that make your heart soar.
|
||||||
```
|
```
|
||||||
|
|
||||||
|
For a demonstration of constrained decoding, add the `--reject` flag followed by
|
||||||
|
a list of token IDs (note that it must be the last flag, since it consumes every
|
||||||
|
subsequent argument). For example, to reject variations of the word "greeting",
|
||||||
|
run:
|
||||||
|
|
||||||
|
```sh
|
||||||
|
./hello_world [...] --reject 32338 42360 78107 106837 132832 143859 154230 190205
|
||||||
|
```
|
||||||
|
|
|
||||||
|
|
@ -15,8 +15,11 @@
|
||||||
|
|
||||||
#include <stddef.h>
|
#include <stddef.h>
|
||||||
|
|
||||||
|
#include <cstdlib>
|
||||||
|
#include <cstring>
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
#include <random>
|
#include <random>
|
||||||
|
#include <set>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
|
@ -42,6 +45,15 @@ int main(int argc, char** argv) {
|
||||||
HWY_ABORT("\nInvalid args: %s", error);
|
HWY_ABORT("\nInvalid args: %s", error);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Demonstrate constrained decoding by never outputting certain tokens.
|
||||||
|
std::set<int> reject_tokens;
|
||||||
|
for (int arg = 0; arg < argc; ++arg) {
|
||||||
|
// Find a --reject flag and consume everything after it.
|
||||||
|
if (strcmp(argv[arg], "--reject") == 0) {
|
||||||
|
while (++arg < argc) reject_tokens.insert(atoi(argv[arg]));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Instantiate model and KV Cache
|
// Instantiate model and KV Cache
|
||||||
gcpp::PerClusterPools pools(app.max_clusters, app.num_threads, app.pin);
|
gcpp::PerClusterPools pools(app.max_clusters, app.num_threads, app.pin);
|
||||||
gcpp::Gemma model = gcpp::CreateGemma(loader, pools);
|
gcpp::Gemma model = gcpp::CreateGemma(loader, pools);
|
||||||
|
|
@ -81,6 +93,10 @@ int main(int argc, char** argv) {
|
||||||
.verbosity = 0,
|
.verbosity = 0,
|
||||||
.gen = &gen,
|
.gen = &gen,
|
||||||
.stream_token = stream_token,
|
.stream_token = stream_token,
|
||||||
|
.accept_token =
|
||||||
|
[&](int token, float /* prob */) {
|
||||||
|
return !reject_tokens.contains(token);
|
||||||
|
},
|
||||||
};
|
};
|
||||||
model.Generate(runtime_config, tokens, 0, kv_cache, timing_info);
|
model.Generate(runtime_config, tokens, 0, kv_cache, timing_info);
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue