diff --git a/src/llama-quant.h b/src/llama-quant.h
index 7c89d81d70..f3c1290c58 100644
--- a/src/llama-quant.h
+++ b/src/llama-quant.h
@@ -6,11 +6,19 @@
 
 #include "llama-arch.h"
 
+#include <string>
 #include <unordered_map>
 #include <vector>
 
 struct llama_model;
 
+// result of parsing --tensor-type option
+// (changes to this struct must be reflected in tools/quantize/quantize.cpp)
+struct tensor_type_option {
+    std::string name;
+    ggml_type type = GGML_TYPE_COUNT;
+};
+
 struct quantize_state_impl {
     const llama_model & model;
     const llama_model_quantize_params * params;
@@ -30,7 +38,7 @@ struct quantize_state_impl {
     bool has_imatrix = false;
 
     // used to figure out if a model has tied embeddings (tok_embd shares weights with output)
-    bool has_tied_embeddings = false; // assume tied until we see output.weight
+    bool has_tied_embeddings = true; // assume tied until we see output.weight
 
     // tensor type override patterns (compiled once, used twice)
     std::vector<std::pair<std::string, ggml_type>> tensor_type_patterns;
diff --git a/tests/test-quant-type-selection.cpp b/tests/test-quant-type-selection.cpp
index 8dbfe34567..3384e4af11 100644
--- a/tests/test-quant-type-selection.cpp
+++ b/tests/test-quant-type-selection.cpp
@@ -309,7 +309,7 @@ static std::string generate_snapshot(const std::string & name,
 
     for (int i = 0; i < LLAMA_FTYPE_GUESSED; i++) {
         llama_ftype ft = (llama_ftype) i;
-        ggml_type default_type = llama_ftype_default_type(ft);
+        ggml_type default_type = llama_ftype_get_default_type(ft);
         if (default_type == GGML_TYPE_COUNT) {
             continue;
         }
@@ -384,8 +384,8 @@ static int run_generate(const std::string & snapshot_dir) {
 
 static bool run_test_section(llama_model & mdl,
                              const std::vector<tensor_info> & tensors,
                              const snapshot_section & section) {
-    // verify default_type matches what llama_ftype_default_type returns
-    ggml_type computed_default = llama_ftype_default_type(section.ftype);
+    // verify default_type matches what llama_ftype_get_default_type returns
+    ggml_type computed_default = llama_ftype_get_default_type(section.ftype);
     if (computed_default != section.default_type) {
         printf("    FAIL [%s] default type mismatch: file says %s, code says %s\n", llama_ftype_to_name(section.ftype), ggml_type_name(section.default_type), ggml_type_name(computed_default));
@@ -408,7 +408,7 @@
         }
 
         if (got != expected) {
-            printf("    FAIL %-50s expected %s, got %s\n", name.c_str(), ggml_type_name(expected), ggml_type_name(got));
+            printf("    FAIL [%s] %-50s expected %s, got %s\n", llama_ftype_to_name(section.ftype), name.c_str(), ggml_type_name(expected), ggml_type_name(got));
            all_pass = false;
        }
    }