diff --git a/examples/hello_world/CMakeLists.txt b/examples/hello_world/CMakeLists.txt index 43159d7..97686dd 100644 --- a/examples/hello_world/CMakeLists.txt +++ b/examples/hello_world/CMakeLists.txt @@ -32,7 +32,7 @@ endif() if (BUILD_MODE STREQUAL "local") FetchContent_Declare(gemma SOURCE_DIR ../../..) else() - FetchContent_Declare(gemma GIT_REPOSITORY https://github.com/google/gemma.cpp.git GIT_TAG 42e53e2da89f80dc46399c7037fbbfb15cdc3de3) + FetchContent_Declare(gemma GIT_REPOSITORY https://github.com/google/gemma.cpp.git GIT_TAG dfd2fdc1dd8e7a84d2e2f9618334b87a79ba02b1) endif() FetchContent_MakeAvailable(gemma) diff --git a/examples/hello_world/run.cc b/examples/hello_world/run.cc index 8b5de24..7b57403 100644 --- a/examples/hello_world/run.cc +++ b/examples/hello_world/run.cc @@ -28,7 +28,8 @@ int main(int argc, char** argv) { hwy::ThreadPool pool(num_threads); // Instantiate model and KV Cache - gcpp::Gemma model(loader.tokenizer, loader.cache, loader.ModelType(), pool); + gcpp::Gemma model(loader.tokenizer, loader.compressed_weights, + loader.ModelType(), pool); auto kv_cache = CreateKVCache(loader.ModelType()); size_t pos = 0; // KV Cache position diff --git a/gemma.h b/gemma.h index 7c08412..48fb52b 100644 --- a/gemma.h +++ b/gemma.h @@ -114,7 +114,7 @@ struct LoaderArgs : public ArgsBase { if (tokenizer.path.empty()) { return "Missing --tokenizer flag, a file for the tokenizer is required."; } - if (cache.path.empty()) { + if (compressed_weights.path.empty()) { return "Missing --compressed_weights flag, a file for the compressed " "model."; } @@ -122,8 +122,8 @@ struct LoaderArgs : public ArgsBase { } Path tokenizer; - Path model; // uncompressed weights OR - Path cache; // compressed weights (TODO: update name) + Path weights; // uncompressed weights file location + Path compressed_weights; // compressed weights file location std::string model_type; template @@ -131,7 +131,7 @@ struct LoaderArgs : public ArgsBase { visitor(tokenizer, "tokenizer", Path(), "Path name of tokenizer model file.\n Required argument."); visitor( - cache, "compressed_weights", Path(), + compressed_weights, "compressed_weights", Path(), "Path name of compressed weights file, regenerated from `--weights` " "file if " "the compressed weights file does not exist.\n Required argument."); @@ -140,7 +140,7 @@ struct LoaderArgs : public ArgsBase { "2b-pt = 2B parameters, pretrained\n 7b-it = 7B parameters " "instruction-tuned\n 7b-pt = 7B parameters, pretrained\n" " Required argument."); - visitor(model, "weights", Path(), + visitor(weights, "weights", Path(), "Path name of model weights (.sbs) file. Only required if " "compressed_weights file is not present and needs to be " "regenerated. This parameter is only required for compressing" diff --git a/run.cc b/run.cc index cdeb95e..c9fa78d 100644 --- a/run.cc +++ b/run.cc @@ -235,7 +235,8 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) { [](uint64_t /*task*/, size_t thread) { PinThreadToCore(thread); }); } - gcpp::Gemma model(loader.tokenizer, loader.cache, loader.ModelType(), pool); + gcpp::Gemma model(loader.tokenizer, loader.compressed_weights, + loader.ModelType(), pool); auto kv_cache = CreateKVCache(loader.ModelType());