mirror of https://github.com/google/gemma.cpp.git
update loader arg names: cache -> compressed_weights, model -> weights
This commit is contained in:
parent
dfd2fdc1dd
commit
03147effbd
|
|
@ -32,7 +32,7 @@ endif()
|
||||||
if (BUILD_MODE STREQUAL "local")
|
if (BUILD_MODE STREQUAL "local")
|
||||||
FetchContent_Declare(gemma SOURCE_DIR ../../..)
|
FetchContent_Declare(gemma SOURCE_DIR ../../..)
|
||||||
else()
|
else()
|
||||||
FetchContent_Declare(gemma GIT_REPOSITORY https://github.com/google/gemma.cpp.git GIT_TAG 42e53e2da89f80dc46399c7037fbbfb15cdc3de3)
|
FetchContent_Declare(gemma GIT_REPOSITORY https://github.com/google/gemma.cpp.git GIT_TAG dfd2fdc1dd8e7a84d2e2f9618334b87a79ba02b1)
|
||||||
endif()
|
endif()
|
||||||
FetchContent_MakeAvailable(gemma)
|
FetchContent_MakeAvailable(gemma)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -28,7 +28,8 @@ int main(int argc, char** argv) {
|
||||||
hwy::ThreadPool pool(num_threads);
|
hwy::ThreadPool pool(num_threads);
|
||||||
|
|
||||||
// Instantiate model and KV Cache
|
// Instantiate model and KV Cache
|
||||||
gcpp::Gemma model(loader.tokenizer, loader.cache, loader.ModelType(), pool);
|
gcpp::Gemma model(loader.tokenizer, loader.compressed_weights,
|
||||||
|
loader.ModelType(), pool);
|
||||||
auto kv_cache = CreateKVCache(loader.ModelType());
|
auto kv_cache = CreateKVCache(loader.ModelType());
|
||||||
size_t pos = 0; // KV Cache position
|
size_t pos = 0; // KV Cache position
|
||||||
|
|
||||||
|
|
|
||||||
10
gemma.h
10
gemma.h
|
|
@ -114,7 +114,7 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
|
||||||
if (tokenizer.path.empty()) {
|
if (tokenizer.path.empty()) {
|
||||||
return "Missing --tokenizer flag, a file for the tokenizer is required.";
|
return "Missing --tokenizer flag, a file for the tokenizer is required.";
|
||||||
}
|
}
|
||||||
if (cache.path.empty()) {
|
if (compressed_weights.path.empty()) {
|
||||||
return "Missing --compressed_weights flag, a file for the compressed "
|
return "Missing --compressed_weights flag, a file for the compressed "
|
||||||
"model.";
|
"model.";
|
||||||
}
|
}
|
||||||
|
|
@ -122,8 +122,8 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
|
||||||
}
|
}
|
||||||
|
|
||||||
Path tokenizer;
|
Path tokenizer;
|
||||||
Path model; // uncompressed weights OR
|
Path weights; // uncompressed weights file location
|
||||||
Path cache; // compressed weights (TODO: update name)
|
Path compressed_weights; // compressed weights file location
|
||||||
std::string model_type;
|
std::string model_type;
|
||||||
|
|
||||||
template <class Visitor>
|
template <class Visitor>
|
||||||
|
|
@ -131,7 +131,7 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
|
||||||
visitor(tokenizer, "tokenizer", Path(),
|
visitor(tokenizer, "tokenizer", Path(),
|
||||||
"Path name of tokenizer model file.\n Required argument.");
|
"Path name of tokenizer model file.\n Required argument.");
|
||||||
visitor(
|
visitor(
|
||||||
cache, "compressed_weights", Path(),
|
compressed_weights, "compressed_weights", Path(),
|
||||||
"Path name of compressed weights file, regenerated from `--weights` "
|
"Path name of compressed weights file, regenerated from `--weights` "
|
||||||
"file if "
|
"file if "
|
||||||
"the compressed weights file does not exist.\n Required argument.");
|
"the compressed weights file does not exist.\n Required argument.");
|
||||||
|
|
@ -140,7 +140,7 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
|
||||||
"2b-pt = 2B parameters, pretrained\n 7b-it = 7B parameters "
|
"2b-pt = 2B parameters, pretrained\n 7b-it = 7B parameters "
|
||||||
"instruction-tuned\n 7b-pt = 7B parameters, pretrained\n"
|
"instruction-tuned\n 7b-pt = 7B parameters, pretrained\n"
|
||||||
" Required argument.");
|
" Required argument.");
|
||||||
visitor(model, "weights", Path(),
|
visitor(weights, "weights", Path(),
|
||||||
"Path name of model weights (.sbs) file. Only required if "
|
"Path name of model weights (.sbs) file. Only required if "
|
||||||
"compressed_weights file is not present and needs to be "
|
"compressed_weights file is not present and needs to be "
|
||||||
"regenerated. This parameter is only required for compressing"
|
"regenerated. This parameter is only required for compressing"
|
||||||
|
|
|
||||||
3
run.cc
3
run.cc
|
|
@ -235,7 +235,8 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
|
||||||
[](uint64_t /*task*/, size_t thread) { PinThreadToCore(thread); });
|
[](uint64_t /*task*/, size_t thread) { PinThreadToCore(thread); });
|
||||||
}
|
}
|
||||||
|
|
||||||
gcpp::Gemma model(loader.tokenizer, loader.cache, loader.ModelType(), pool);
|
gcpp::Gemma model(loader.tokenizer, loader.compressed_weights,
|
||||||
|
loader.ModelType(), pool);
|
||||||
|
|
||||||
auto kv_cache = CreateKVCache(loader.ModelType());
|
auto kv_cache = CreateKVCache(loader.ModelType());
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue