mirror of https://github.com/google/gemma.cpp.git
Update run.cc, CMakeLists and README for incompatible code, dependency changes and argument updates
This commit is contained in:
parent
4a924f1794
commit
0ea118ebbe
|
|
@ -14,12 +14,13 @@
|
||||||
|
|
||||||
cmake_minimum_required(VERSION 3.11)
|
cmake_minimum_required(VERSION 3.11)
|
||||||
project(hello_world)
|
project(hello_world)
|
||||||
|
set(CMAKE_CXX_STANDARD 17)
|
||||||
set(CMAKE_CXX_STANDARD_REQUIRED ON)
|
set(CMAKE_CXX_STANDARD_REQUIRED ON)
|
||||||
|
|
||||||
include(FetchContent)
|
include(FetchContent)
|
||||||
FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG c5bebf84ad01edec97e336f5c97ca4e0df6b4d06)
|
FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG bfc36a6e633af94e63ac4b91c687bf0354cb24e0)
|
||||||
FetchContent_MakeAvailable(highway)
|
FetchContent_MakeAvailable(highway)
|
||||||
FetchContent_Declare(sentencepiece GIT_REPOSITORY https://github.com/google/sentencepiece GIT_TAG 53de76561cfc149d3c01037f0595669ad32a5e7c)
|
FetchContent_Declare(sentencepiece GIT_REPOSITORY https://github.com/google/sentencepiece GIT_TAG 9045b2f60fa2b323dfac0eaef8fc17565036f9f9)
|
||||||
FetchContent_MakeAvailable(sentencepiece)
|
FetchContent_MakeAvailable(sentencepiece)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -31,9 +32,9 @@ if (NOT BUILD_MODE)
|
||||||
endif()
|
endif()
|
||||||
if (BUILD_MODE STREQUAL "local")
|
if (BUILD_MODE STREQUAL "local")
|
||||||
# Relative path to gemma.cpp from examples/hello_world/build/
|
# Relative path to gemma.cpp from examples/hello_world/build/
|
||||||
FetchContent_Declare(gemma SOURCE_DIR ../../..)
|
FetchContent_Declare(gemma SOURCE_DIR ../../..)
|
||||||
else()
|
else()
|
||||||
FetchContent_Declare(gemma GIT_REPOSITORY https://github.com/google/gemma.cpp.git GIT_TAG a9aa63fd2ea6b786ed0706d619588bfe2d43370e)
|
FetchContent_Declare(gemma GIT_REPOSITORY https://github.com/google/gemma.cpp.git GIT_TAG 4a924f179448dc83e46a2af9520c61b4ef56174c)
|
||||||
endif()
|
endif()
|
||||||
FetchContent_MakeAvailable(gemma)
|
FetchContent_MakeAvailable(gemma)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -41,7 +41,7 @@ gemma.cpp specifying the tokenizer, compressed weights file, and model type, for
|
||||||
example:
|
example:
|
||||||
|
|
||||||
```sh
|
```sh
|
||||||
./hello_world --tokenizer tokenizer.spm --compressed_weights 2b-it-sfp.sbs --model 2b-it
|
./hello_world --tokenizer tokenizer.spm --weights 2b-it-sfp.sbs --model 2b-it
|
||||||
```
|
```
|
||||||
|
|
||||||
Should print a greeting to the terminal:
|
Should print a greeting to the terminal:
|
||||||
|
|
|
||||||
|
|
@ -32,76 +32,76 @@
|
||||||
#include "hwy/base.h"
|
#include "hwy/base.h"
|
||||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||||
|
|
||||||
int main(int argc, char** argv) {
|
int main(int argc, char **argv) { {
|
||||||
{
|
// Placeholder for internal init, do not modify.
|
||||||
// Placeholder for internal init, do not modify.
|
|
||||||
}
|
|
||||||
|
|
||||||
gcpp::LoaderArgs loader(argc, argv);
|
|
||||||
gcpp::InferenceArgs inference(argc, argv);
|
|
||||||
gcpp::AppArgs app(argc, argv);
|
|
||||||
if (gcpp::HasHelp(argc, argv)) {
|
|
||||||
loader.Help();
|
|
||||||
return 0;
|
|
||||||
} else if (const char* error = loader.Validate()) {
|
|
||||||
loader.Help();
|
|
||||||
HWY_ABORT("\nInvalid args: %s", error);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Demonstrate constrained decoding by never outputting certain tokens.
|
|
||||||
std::set<int> reject_tokens;
|
|
||||||
for (int arg = 0; arg < argc; ++arg) {
|
|
||||||
// Find a --reject flag and consume everything after it.
|
|
||||||
if (strcmp(argv[arg], "--reject") == 0) {
|
|
||||||
while (++arg < argc) reject_tokens.insert(atoi(argv[arg]));
|
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
// Instantiate model and KV Cache
|
gcpp::LoaderArgs loader(argc, argv);
|
||||||
gcpp::BoundedTopology topology(gcpp::CreateTopology(app));
|
gcpp::InferenceArgs inference(argc, argv);
|
||||||
gcpp::NestedPools pools = gcpp::CreatePools(topology, app);
|
gcpp::AppArgs app(argc, argv);
|
||||||
gcpp::MatMulEnv env(topology, pools);
|
if (gcpp::HasHelp(argc, argv)) {
|
||||||
gcpp::Gemma model = gcpp::CreateGemma(loader, env);
|
loader.Help();
|
||||||
gcpp::KVCache kv_cache =
|
return 0;
|
||||||
gcpp::KVCache::Create(model.GetModelConfig(),
|
} else if (const char *error = loader.Validate()) {
|
||||||
inference.prefill_tbatch_size);
|
loader.Help();
|
||||||
size_t generated = 0;
|
HWY_ABORT("\nInvalid args: %s", error);
|
||||||
|
|
||||||
// Initialize random number generator
|
|
||||||
std::mt19937 gen;
|
|
||||||
std::random_device rd;
|
|
||||||
gen.seed(rd());
|
|
||||||
|
|
||||||
// Tokenize instructions.
|
|
||||||
std::string prompt = "Write a greeting to the world.";
|
|
||||||
const std::vector<int> tokens = gcpp::WrapAndTokenize(
|
|
||||||
model.Tokenizer(), loader.Info(), generated, prompt);
|
|
||||||
const size_t prompt_size = tokens.size();
|
|
||||||
|
|
||||||
// This callback function gets invoked every time a token is generated
|
|
||||||
auto stream_token = [&generated, &prompt_size, &model](int token, float) {
|
|
||||||
++generated;
|
|
||||||
if (generated < prompt_size) {
|
|
||||||
// print feedback
|
|
||||||
} else if (!model.GetModelConfig().IsEOS(token)) {
|
|
||||||
std::string token_text;
|
|
||||||
HWY_ASSERT(model.Tokenizer().Decode({token}, &token_text));
|
|
||||||
std::cout << token_text << std::flush;
|
|
||||||
}
|
}
|
||||||
return true;
|
|
||||||
};
|
|
||||||
|
|
||||||
gcpp::TimingInfo timing_info;
|
// Demonstrate constrained decoding by never outputting certain tokens.
|
||||||
gcpp::RuntimeConfig runtime_config = {
|
std::set<int> reject_tokens;
|
||||||
.max_generated_tokens = 1024,
|
for (int arg = 0; arg < argc; ++arg) {
|
||||||
.temperature = 1.0,
|
// Find a --reject flag and consume everything after it.
|
||||||
.gen = &gen,
|
if (strcmp(argv[arg], "--reject") == 0) {
|
||||||
.verbosity = 0,
|
while (++arg < argc) reject_tokens.insert(atoi(argv[arg]));
|
||||||
.stream_token = stream_token,
|
}
|
||||||
.accept_token =
|
}
|
||||||
[&](int token, float /* prob */) {
|
|
||||||
return !reject_tokens.contains(token);
|
// Instantiate model and KV Cache
|
||||||
},
|
gcpp::BoundedTopology topology(gcpp::CreateTopology(app));
|
||||||
};
|
gcpp::NestedPools pools = gcpp::CreatePools(topology, app);
|
||||||
model.Generate(runtime_config, tokens, 0, kv_cache, timing_info);
|
gcpp::MatMulEnv env(topology, pools);
|
||||||
|
gcpp::Gemma model = gcpp::CreateGemma(loader, env);
|
||||||
|
gcpp::KVCache kv_cache =
|
||||||
|
gcpp::KVCache::Create(model.GetModelConfig(),
|
||||||
|
inference.prefill_tbatch_size);
|
||||||
|
size_t generated = 0;
|
||||||
|
|
||||||
|
// Initialize random number generator
|
||||||
|
std::mt19937 gen;
|
||||||
|
std::random_device rd;
|
||||||
|
gen.seed(rd());
|
||||||
|
|
||||||
|
// Tokenize instructions.
|
||||||
|
std::string prompt = "Write a greeting to the world.";
|
||||||
|
const std::vector<int> tokens = gcpp::WrapAndTokenize(
|
||||||
|
model.Tokenizer(), loader.Info(), generated, prompt);
|
||||||
|
const size_t prompt_size = tokens.size();
|
||||||
|
|
||||||
|
// This callback function gets invoked every time a token is generated
|
||||||
|
auto stream_token = [&generated, &prompt_size, &model](int token, float) {
|
||||||
|
++generated;
|
||||||
|
if (generated < prompt_size) {
|
||||||
|
// print feedback
|
||||||
|
} else if (!model.GetModelConfig().IsEOS(token)) {
|
||||||
|
std::string token_text;
|
||||||
|
HWY_ASSERT(model.Tokenizer().Decode({token}, &token_text));
|
||||||
|
std::cout << token_text << std::flush;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
};
|
||||||
|
|
||||||
|
gcpp::TimingInfo timing_info;
|
||||||
|
gcpp::RuntimeConfig runtime_config = {
|
||||||
|
.max_generated_tokens = 1024,
|
||||||
|
.temperature = 1.0,
|
||||||
|
.gen = &gen,
|
||||||
|
.verbosity = 0,
|
||||||
|
.stream_token = stream_token,
|
||||||
|
.accept_token =
|
||||||
|
std::function<bool(int, float)>(
|
||||||
|
[&](int token, float /* prob */) {
|
||||||
|
return reject_tokens.find(token) == reject_tokens.end();
|
||||||
|
}),
|
||||||
|
};
|
||||||
|
model.Generate(runtime_config, tokens, 0, kv_cache, timing_info);
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue