diff --git a/src/llama-quant.h b/src/llama-quant.h index fc7da927cd..7c89d81d70 100644 --- a/src/llama-quant.h +++ b/src/llama-quant.h @@ -29,13 +29,23 @@ struct quantize_state_impl { bool has_imatrix = false; - // used to figure out if a model shares tok_embd with the output weight - bool has_output = false; + // used to figure out if a model has tied embeddings (tok_embd shares weights with output) + bool has_tied_embeddings = false; // assume tied until we see output.weight - quantize_state_impl(const llama_model & model, const llama_model_quantize_params * params) - : model(model) - , params(params) - {} + // tensor type override patterns (compiled once, used twice) + std::vector> tensor_type_patterns; + + quantize_state_impl(const llama_model & model, const llama_model_quantize_params * params): + model(model), params(params) + { + // compile regex patterns once - they are expensive + if (params->tensor_types) { + const auto & tensor_types = *static_cast *>(params->tensor_types); + for (const auto & [tname, qtype] : tensor_types) { + tensor_type_patterns.emplace_back(std::regex(tname), qtype); + } + } + } }; ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_type, const ggml_tensor * tensor, llama_ftype ftype);