Allow interactive use with new single-file weight format.

Add section about new weights format to README.md.
Remove model_type_required parameter.
Update error handling for flags.
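
In practice (see the README change below), interactive use now needs only the
weights flag. A sketch with placeholder paths:

```sh
# Before: multi-file format, with tokenizer and model type as separate flags.
./gemma --tokenizer tokenizer.spm --weights gemma2-2b-it-sfp.sbs --model gemma2-2b-it

# After: single-file format; tokenizer and model type are read from the file.
./gemma --weights gemma2-2b-it-sfp-single.sbs
```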

PiperOrigin-RevId: 715788822
Daniel Keysers 2025-01-15 07:22:00 -08:00 committed by Copybara-Service
parent b93231a47d
commit 493688f6f1
8 changed files with 59 additions and 35 deletions

View File

@@ -305,6 +305,24 @@ A tall tree stands in front of the building, and a window on the building is
 visible from the water. The water is green, and the sky is blue.
 ```
+
+### Migrating to single-file format
+
+There is now a new format for the weights file: a single file that can also
+contain the tokenizer (and the model type) directly. A tool to migrate from
+the multi-file format to the single-file format is available:
+
+```sh
+compression/migrate_weights \
+  --tokenizer .../tokenizer.spm --weights .../gemma2-2b-it-sfp.sbs \
+  --model gemma2-2b-it --output_weights .../gemma2-2b-it-sfp-single.sbs
+```
+
+After migration, you can use the new weights file with gemma.cpp like this:
+
+```sh
+./gemma --weights .../gemma2-2b-it-sfp-single.sbs
+```
 
 ### Troubleshooting and FAQs
 
 **Running `./gemma` fails with "Failed to read cache gating_ein_0 (error 294) ..."**
@@ -331,9 +349,8 @@ and not a pre-trained model (any model with a `-pt` suffix).
 
 **How do I convert my fine-tune to a `.sbs` compressed model file?**
 
-We're working on a python script to convert a standard model format to `.sbs`,
-and hope have it available soon. Follow
-[this issue](https://github.com/google/gemma.cpp/issues/11) for updates.
+See compression/convert_weights.py to convert a PyTorch checkpoint. (The code
+may need updates to work with Gemma-2 models.)
 
 **What are some easy ways to make the model run faster?**

View File

@@ -55,7 +55,7 @@ int main(int argc, char** argv) {
     fprintf(stderr, "Skipping model load because: %s\n", err);
     return 1;
   }
-  gcpp::GemmaEnv env(argc, argv, /*required=*/true);
+  gcpp::GemmaEnv env(argc, argv);
   hwy::ThreadPool pool(0);
   env.GetModel()->Save(args.output_weights, pool);
   return 0;

View File

@@ -92,9 +92,9 @@ static AppArgs MakeAppArgs(int argc, char** argv) {
   return AppArgs(argc, argv);
 }
 
-GemmaEnv::GemmaEnv(int argc, char** argv, bool model_type_required)
-    : GemmaEnv(LoaderArgs(argc, argv, model_type_required),
-               InferenceArgs(argc, argv), MakeAppArgs(argc, argv)) {}
+GemmaEnv::GemmaEnv(int argc, char** argv)
+    : GemmaEnv(LoaderArgs(argc, argv), InferenceArgs(argc, argv),
+               MakeAppArgs(argc, argv)) {}
 
 QueryResult GemmaEnv::QueryModel(const std::vector<int>& tokens) {
   QueryResult result;

View File

@@ -44,7 +44,7 @@ struct QueryResult {
 class GemmaEnv {
  public:
   // Calls the other constructor with *Args arguments initialized from argv.
-  GemmaEnv(int argc, char** argv, bool model_type_required = false);
+  GemmaEnv(int argc, char** argv);
   GemmaEnv(const LoaderArgs& loader, const InferenceArgs& inference,
            const AppArgs& app);

View File

@@ -28,6 +28,7 @@
 // This test can be run manually with the downloaded gemma weights.
 // To run the test, pass the following flags:
 // --model <model> --tokenizer <tokenizer_path> --weights <weights_path>
+// or just use the single-file weights file with --weights <weights_path>.
 // It should pass for the following models:
 // Gemma1: 2b-it (v1 and v1.1), 7b-it (v1 and v1.1), gr2b-it,
 // Gemma2: gemma2-2b-it, 9b-it, 27b-it,
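
For example, a single-file run of this test might look like the sketch below
(the test binary name and paths are placeholders, not taken from this diff):

```sh
# Multi-file invocation, as before:
./gemma_test --model gemma2-2b-it --tokenizer tokenizer.spm \
  --weights gemma2-2b-it-sfp.sbs

# Single-file invocation enabled by this change:
./gemma_test --weights gemma2-2b-it-sfp-single.sbs
```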

View File

@@ -525,9 +525,9 @@ class ModelWeightsStorage {
   // Loads the weights from a blob store file. Supports multi-file or
   // single-file format. If the weights file contains a TOC, then it is in
-  // single-file format, and model_type, weight_type, training are ignored,
+  // single-file format, and model_type, weight_type, wrapping are ignored,
   // and tokenizer_proto is required and written to.
-  // With a multi-file format, file, model_type, weight_type, training are
+  // With a multi-file format, file, model_type, weight_type, wrapping are
   // required and tokenizer_proto is ignored.
   BlobError Load(const Path& weights, Model model_type, Type weight_type,
                  PromptWrapping wrapping, hwy::ThreadPool& pool,

View File

@@ -27,6 +27,7 @@
 // This test can be run manually with the downloaded PaliGemma weights.
 // To run the test, pass the following flags:
 // --model paligemma-224 --tokenizer <tokenizer_path> --weights <weights_path>
+// or just use the single-file weights file with --weights <weights_path>.
 // It should pass for the following models:
 // paligemma-3b-mix-224, paligemma2-3b-pt-448

View File

@@ -126,8 +126,7 @@ static inline NestedPools CreatePools(const AppArgs& app) {
 }
 
 struct LoaderArgs : public ArgsBase<LoaderArgs> {
-  LoaderArgs(int argc, char* argv[], bool required = true)
-      : model_type_required(required) {
+  LoaderArgs(int argc, char* argv[]) {
     InitAndParse(argc, argv);
   }
   LoaderArgs(const std::string& tokenizer_path, const std::string& weights_path,
@@ -140,25 +139,6 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
 
   // Returns error string or nullptr if OK.
   const char* Validate() {
-    info_.model = Model::UNKNOWN;
-    info_.wrapping = PromptWrapping::GEMMA_PT;
-    info_.weight = Type::kUnknown;
-    if (const char* err = ParseModelTypeAndWrapping(model_type_str, info_.model,
-                                                    info_.wrapping)) {
-      if (model_type_required) return err;
-    }
-    if (const char* err = ParseType(weight_type_str, info_.weight)) {
-      if (model_type_required) return err;
-    }
-    if (model_type_required) {
-      if (tokenizer.path.empty()) {
-        return "Missing --tokenizer flag, a file for the tokenizer is "
-               "required.";
-      }
-      if (!tokenizer.Exists()) {
-        return "Can't open file specified with --tokenizer flag.";
-      }
-    }
     if (!compressed_weights.path.empty()) {
       if (weights.path.empty()) {
         weights = compressed_weights;
@@ -174,6 +154,28 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
     if (!weights.Exists()) {
       return "Can't open file specified with --weights flag.";
     }
+    info_.model = Model::UNKNOWN;
+    info_.wrapping = PromptWrapping::GEMMA_PT;
+    info_.weight = Type::kUnknown;
+    if (!model_type_str.empty()) {
+      const char* err = ParseModelTypeAndWrapping(model_type_str, info_.model,
+                                                  info_.wrapping);
+      if (err != nullptr) return err;
+    }
+    if (!weight_type_str.empty()) {
+      const char* err = ParseType(weight_type_str, info_.weight);
+      if (err != nullptr) return err;
+    }
+    if (!tokenizer.path.empty()) {
+      if (!tokenizer.Exists()) {
+        return "Can't open file specified with --tokenizer flag.";
+      }
+    }
+    // model_type and tokenizer must be either both present or both absent.
+    // Further checks happen on weight loading.
+    if (model_type_str.empty() != tokenizer.path.empty()) {
+      return "Missing or extra flags for model_type or tokenizer.";
+    }
     return nullptr;
   }
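
The net effect of the reordered Validate() on the command line (a sketch;
paths are placeholders): --weights is always required, while --model and
--tokenizer must now be given together or not at all.

```sh
# Accepted: single-file weights carry the tokenizer and model type.
./gemma --weights gemma2-2b-it-sfp-single.sbs

# Accepted: multi-file weights with the paired flags.
./gemma --weights gemma2-2b-it-sfp.sbs --tokenizer tokenizer.spm --model gemma2-2b-it

# Rejected: --model without --tokenizer (or vice versa) now fails with
# "Missing or extra flags for model_type or tokenizer."
./gemma --weights gemma2-2b-it-sfp.sbs --model gemma2-2b-it
```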
@@ -182,7 +184,6 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
   Path compressed_weights;
   std::string model_type_str;
   std::string weight_type_str;
-  bool model_type_required = true;
 
   template <class Visitor>
   void ForEach(const Visitor& visitor) {
@@ -199,7 +200,7 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
             "gr2b-it = griffin 2B parameters, instruction-tuned\n    "
             "gr2b-pt = griffin 2B parameters, pretrained.");
     visitor(weight_type_str, "weight_type", std::string("sfp"),
-            "Weight type\n    f32 = float, bf16 = bfloat16, sfp = 8-bit FP.");
+            "Weight type\n    f32 = float, bf16 = bfloat16, sfp = 8-bit SFP.");
   }
 
   // Uninitialized before Validate, must call after that.
@@ -212,6 +213,11 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
 };
 
 static inline Gemma CreateGemma(const LoaderArgs& loader, NestedPools& pools) {
+  if (Type::kUnknown == loader.Info().weight ||
+      Model::UNKNOWN == loader.Info().model || loader.tokenizer.path.empty()) {
+    // New weights file format doesn't need tokenizer path or model/weight info.
+    return Gemma(loader.weights, pools);
+  }
   return Gemma(loader.tokenizer, loader.weights, loader.Info(), pools);
 }
@@ -219,8 +225,7 @@ static inline std::unique_ptr<Gemma> AllocateGemma(const LoaderArgs& loader,
                                                    NestedPools& pools) {
   if (Type::kUnknown == loader.Info().weight ||
       Model::UNKNOWN == loader.Info().model || loader.tokenizer.path.empty()) {
-    // Newer weights file format doesn't need tokenizer path or model/weight
-    // info.
+    // New weights file format doesn't need tokenizer path or model/weight info.
     return std::make_unique<Gemma>(loader.weights, pools);
   }
   return std::make_unique<Gemma>(loader.tokenizer, loader.weights,