mirror of https://github.com/google/gemma.cpp.git
Allow interactive use with new single-file weight format.
Add section about new weights format to README.md. Remove model_type_required parameter. Update error handling for flags. PiperOrigin-RevId: 715788822
This commit is contained in:
parent
b93231a47d
commit
493688f6f1
23
README.md
23
README.md
|
|
@ -305,6 +305,24 @@ A tall tree stands in front of the building, and a window on the building is
|
||||||
visible from the water. The water is green, and the sky is blue.
|
visible from the water. The water is green, and the sky is blue.
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### Migrating to single-file format
|
||||||
|
|
||||||
|
There is now a new format for the weights file, which is a single file that
|
||||||
|
can also contain the tokenizer (and the model type) directly. A tool to migrate
|
||||||
|
from the multi-file format to the single-file format is available.
|
||||||
|
|
||||||
|
```sh
|
||||||
|
compression/migrate_weights \
|
||||||
|
--tokenizer .../tokenizer.spm --weights .../gemma2-2b-it-sfp.sbs \
|
||||||
|
--model gemma2-2b-it --output_weights .../gemma2-2b-it-sfp-single.sbs
|
||||||
|
```
|
||||||
|
|
||||||
|
After migration, you can use the new weights file with gemma.cpp like this:
|
||||||
|
|
||||||
|
```sh
|
||||||
|
./gemma --weights .../gemma2-2b-it-sfp-single.sbs
|
||||||
|
```
|
||||||
|
|
||||||
### Troubleshooting and FAQs
|
### Troubleshooting and FAQs
|
||||||
|
|
||||||
**Running `./gemma` fails with "Failed to read cache gating_ein_0 (error 294) ..."**
|
**Running `./gemma` fails with "Failed to read cache gating_ein_0 (error 294) ..."**
|
||||||
|
|
@ -331,9 +349,8 @@ and not a pre-trained model (any model with a `-pt` suffix).
|
||||||
|
|
||||||
**How do I convert my fine-tune to a `.sbs` compressed model file?**
|
**How do I convert my fine-tune to a `.sbs` compressed model file?**
|
||||||
|
|
||||||
We're working on a python script to convert a standard model format to `.sbs`,
|
See compression/convert_weights.py to convert a PyTorch checkpoint. (The code may
|
||||||
and hope to have it available soon. Follow
|
need updates to work with Gemma-2 models.)
|
||||||
[this issue](https://github.com/google/gemma.cpp/issues/11) for updates.
|
|
||||||
|
|
||||||
**What are some easy ways to make the model run faster?**
|
**What are some easy ways to make the model run faster?**
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -55,7 +55,7 @@ int main(int argc, char** argv) {
|
||||||
fprintf(stderr, "Skipping model load because: %s\n", err);
|
fprintf(stderr, "Skipping model load because: %s\n", err);
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
gcpp::GemmaEnv env(argc, argv, /*required=*/true);
|
gcpp::GemmaEnv env(argc, argv);
|
||||||
hwy::ThreadPool pool(0);
|
hwy::ThreadPool pool(0);
|
||||||
env.GetModel()->Save(args.output_weights, pool);
|
env.GetModel()->Save(args.output_weights, pool);
|
||||||
return 0;
|
return 0;
|
||||||
|
|
|
||||||
|
|
@ -92,9 +92,9 @@ static AppArgs MakeAppArgs(int argc, char** argv) {
|
||||||
return AppArgs(argc, argv);
|
return AppArgs(argc, argv);
|
||||||
}
|
}
|
||||||
|
|
||||||
GemmaEnv::GemmaEnv(int argc, char** argv, bool model_type_required)
|
GemmaEnv::GemmaEnv(int argc, char** argv)
|
||||||
: GemmaEnv(LoaderArgs(argc, argv, model_type_required),
|
: GemmaEnv(LoaderArgs(argc, argv), InferenceArgs(argc, argv),
|
||||||
InferenceArgs(argc, argv), MakeAppArgs(argc, argv)) {}
|
MakeAppArgs(argc, argv)) {}
|
||||||
|
|
||||||
QueryResult GemmaEnv::QueryModel(const std::vector<int>& tokens) {
|
QueryResult GemmaEnv::QueryModel(const std::vector<int>& tokens) {
|
||||||
QueryResult result;
|
QueryResult result;
|
||||||
|
|
|
||||||
|
|
@ -44,7 +44,7 @@ struct QueryResult {
|
||||||
class GemmaEnv {
|
class GemmaEnv {
|
||||||
public:
|
public:
|
||||||
// Calls the other constructor with *Args arguments initialized from argv.
|
// Calls the other constructor with *Args arguments initialized from argv.
|
||||||
GemmaEnv(int argc, char** argv, bool model_type_required = false);
|
GemmaEnv(int argc, char** argv);
|
||||||
GemmaEnv(const LoaderArgs& loader, const InferenceArgs& inference,
|
GemmaEnv(const LoaderArgs& loader, const InferenceArgs& inference,
|
||||||
const AppArgs& app);
|
const AppArgs& app);
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -28,6 +28,7 @@
|
||||||
// This test can be run manually with the downloaded gemma weights.
|
// This test can be run manually with the downloaded gemma weights.
|
||||||
// To run the test, pass the following flags:
|
// To run the test, pass the following flags:
|
||||||
// --model <model> --tokenizer <tokenizer_path> --weights <weights_path>
|
// --model <model> --tokenizer <tokenizer_path> --weights <weights_path>
|
||||||
|
// or just use the single-file weights file with --weights <weights_path>.
|
||||||
// It should pass for the following models:
|
// It should pass for the following models:
|
||||||
// Gemma1: 2b-it (v1 and v1.1), 7b-it (v1 and v1.1), gr2b-it,
|
// Gemma1: 2b-it (v1 and v1.1), 7b-it (v1 and v1.1), gr2b-it,
|
||||||
// Gemma2: gemma2-2b-it, 9b-it, 27b-it,
|
// Gemma2: gemma2-2b-it, 9b-it, 27b-it,
|
||||||
|
|
|
||||||
|
|
@ -525,9 +525,9 @@ class ModelWeightsStorage {
|
||||||
|
|
||||||
// Loads the weights from a blob store file. Supports multi-file or
|
// Loads the weights from a blob store file. Supports multi-file or
|
||||||
// single-file format. If the weights file contains a TOC, then it is in
|
// single-file format. If the weights file contains a TOC, then it is in
|
||||||
// single-file format, and model_type, weight_type, training are ignored,
|
// single-file format, and model_type, weight_type, wrapping are ignored,
|
||||||
// and tokenizer_proto is required and written to.
|
// and tokenizer_proto is required and written to.
|
||||||
// With a multi-file format, file, model_type, weight_type, training are
|
// With a multi-file format, file, model_type, weight_type, wrapping are
|
||||||
// required and tokenizer_proto is ignored.
|
// required and tokenizer_proto is ignored.
|
||||||
BlobError Load(const Path& weights, Model model_type, Type weight_type,
|
BlobError Load(const Path& weights, Model model_type, Type weight_type,
|
||||||
PromptWrapping wrapping, hwy::ThreadPool& pool,
|
PromptWrapping wrapping, hwy::ThreadPool& pool,
|
||||||
|
|
|
||||||
|
|
@ -27,6 +27,7 @@
|
||||||
// This test can be run manually with the downloaded PaliGemma weights.
|
// This test can be run manually with the downloaded PaliGemma weights.
|
||||||
// To run the test, pass the following flags:
|
// To run the test, pass the following flags:
|
||||||
// --model paligemma-224 --tokenizer <tokenizer_path> --weights <weights_path>
|
// --model paligemma-224 --tokenizer <tokenizer_path> --weights <weights_path>
|
||||||
|
// or just use the single-file weights file with --weights <weights_path>.
|
||||||
// It should pass for the following models:
|
// It should pass for the following models:
|
||||||
// paligemma-3b-mix-224, paligemma2-3b-pt-448
|
// paligemma-3b-mix-224, paligemma2-3b-pt-448
|
||||||
|
|
||||||
|
|
|
||||||
55
util/app.h
55
util/app.h
|
|
@ -126,8 +126,7 @@ static inline NestedPools CreatePools(const AppArgs& app) {
|
||||||
}
|
}
|
||||||
|
|
||||||
struct LoaderArgs : public ArgsBase<LoaderArgs> {
|
struct LoaderArgs : public ArgsBase<LoaderArgs> {
|
||||||
LoaderArgs(int argc, char* argv[], bool required = true)
|
LoaderArgs(int argc, char* argv[]) {
|
||||||
: model_type_required(required) {
|
|
||||||
InitAndParse(argc, argv);
|
InitAndParse(argc, argv);
|
||||||
}
|
}
|
||||||
LoaderArgs(const std::string& tokenizer_path, const std::string& weights_path,
|
LoaderArgs(const std::string& tokenizer_path, const std::string& weights_path,
|
||||||
|
|
@ -140,25 +139,6 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
|
||||||
|
|
||||||
// Returns error string or nullptr if OK.
|
// Returns error string or nullptr if OK.
|
||||||
const char* Validate() {
|
const char* Validate() {
|
||||||
info_.model = Model::UNKNOWN;
|
|
||||||
info_.wrapping = PromptWrapping::GEMMA_PT;
|
|
||||||
info_.weight = Type::kUnknown;
|
|
||||||
if (const char* err = ParseModelTypeAndWrapping(model_type_str, info_.model,
|
|
||||||
info_.wrapping)) {
|
|
||||||
if (model_type_required) return err;
|
|
||||||
}
|
|
||||||
if (const char* err = ParseType(weight_type_str, info_.weight)) {
|
|
||||||
if (model_type_required) return err;
|
|
||||||
}
|
|
||||||
if (model_type_required) {
|
|
||||||
if (tokenizer.path.empty()) {
|
|
||||||
return "Missing --tokenizer flag, a file for the tokenizer is "
|
|
||||||
"required.";
|
|
||||||
}
|
|
||||||
if (!tokenizer.Exists()) {
|
|
||||||
return "Can't open file specified with --tokenizer flag.";
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (!compressed_weights.path.empty()) {
|
if (!compressed_weights.path.empty()) {
|
||||||
if (weights.path.empty()) {
|
if (weights.path.empty()) {
|
||||||
weights = compressed_weights;
|
weights = compressed_weights;
|
||||||
|
|
@ -174,6 +154,28 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
|
||||||
if (!weights.Exists()) {
|
if (!weights.Exists()) {
|
||||||
return "Can't open file specified with --weights flag.";
|
return "Can't open file specified with --weights flag.";
|
||||||
}
|
}
|
||||||
|
info_.model = Model::UNKNOWN;
|
||||||
|
info_.wrapping = PromptWrapping::GEMMA_PT;
|
||||||
|
info_.weight = Type::kUnknown;
|
||||||
|
if (!model_type_str.empty()) {
|
||||||
|
const char* err = ParseModelTypeAndWrapping(model_type_str, info_.model,
|
||||||
|
info_.wrapping);
|
||||||
|
if (err != nullptr) return err;
|
||||||
|
}
|
||||||
|
if (!weight_type_str.empty()) {
|
||||||
|
const char* err = ParseType(weight_type_str, info_.weight);
|
||||||
|
if (err != nullptr) return err;
|
||||||
|
}
|
||||||
|
if (!tokenizer.path.empty()) {
|
||||||
|
if (!tokenizer.Exists()) {
|
||||||
|
return "Can't open file specified with --tokenizer flag.";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// model_type and tokenizer must be either both present or both absent.
|
||||||
|
// Further checks happen on weight loading.
|
||||||
|
if (model_type_str.empty() != tokenizer.path.empty()) {
|
||||||
|
return "Missing or extra flags for model_type or tokenizer.";
|
||||||
|
}
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -182,7 +184,6 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
|
||||||
Path compressed_weights;
|
Path compressed_weights;
|
||||||
std::string model_type_str;
|
std::string model_type_str;
|
||||||
std::string weight_type_str;
|
std::string weight_type_str;
|
||||||
bool model_type_required = true;
|
|
||||||
|
|
||||||
template <class Visitor>
|
template <class Visitor>
|
||||||
void ForEach(const Visitor& visitor) {
|
void ForEach(const Visitor& visitor) {
|
||||||
|
|
@ -199,7 +200,7 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
|
||||||
"gr2b-it = griffin 2B parameters, instruction-tuned\n "
|
"gr2b-it = griffin 2B parameters, instruction-tuned\n "
|
||||||
"gr2b-pt = griffin 2B parameters, pretrained.");
|
"gr2b-pt = griffin 2B parameters, pretrained.");
|
||||||
visitor(weight_type_str, "weight_type", std::string("sfp"),
|
visitor(weight_type_str, "weight_type", std::string("sfp"),
|
||||||
"Weight type\n f32 = float, bf16 = bfloat16, sfp = 8-bit FP.");
|
"Weight type\n f32 = float, bf16 = bfloat16, sfp = 8-bit SFP.");
|
||||||
}
|
}
|
||||||
|
|
||||||
// Uninitialized before Validate, must call after that.
|
// Uninitialized before Validate, must call after that.
|
||||||
|
|
@ -212,6 +213,11 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
|
||||||
};
|
};
|
||||||
|
|
||||||
static inline Gemma CreateGemma(const LoaderArgs& loader, NestedPools& pools) {
|
static inline Gemma CreateGemma(const LoaderArgs& loader, NestedPools& pools) {
|
||||||
|
if (Type::kUnknown == loader.Info().weight ||
|
||||||
|
Model::UNKNOWN == loader.Info().model || loader.tokenizer.path.empty()) {
|
||||||
|
// New weights file format doesn't need tokenizer path or model/weight info.
|
||||||
|
return Gemma(loader.weights, pools);
|
||||||
|
}
|
||||||
return Gemma(loader.tokenizer, loader.weights, loader.Info(), pools);
|
return Gemma(loader.tokenizer, loader.weights, loader.Info(), pools);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -219,8 +225,7 @@ static inline std::unique_ptr<Gemma> AllocateGemma(const LoaderArgs& loader,
|
||||||
NestedPools& pools) {
|
NestedPools& pools) {
|
||||||
if (Type::kUnknown == loader.Info().weight ||
|
if (Type::kUnknown == loader.Info().weight ||
|
||||||
Model::UNKNOWN == loader.Info().model || loader.tokenizer.path.empty()) {
|
Model::UNKNOWN == loader.Info().model || loader.tokenizer.path.empty()) {
|
||||||
// Newer weights file format doesn't need tokenizer path or model/weight
|
// New weights file format doesn't need tokenizer path or model/weight info.
|
||||||
// info.
|
|
||||||
return std::make_unique<Gemma>(loader.weights, pools);
|
return std::make_unique<Gemma>(loader.weights, pools);
|
||||||
}
|
}
|
||||||
return std::make_unique<Gemma>(loader.tokenizer, loader.weights,
|
return std::make_unique<Gemma>(loader.tokenizer, loader.weights,
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue