diff --git a/src/llama-quant.cpp b/src/llama-quant.cpp
index e0aa84379f..a103b15a4c 100644
--- a/src/llama-quant.cpp
+++ b/src/llama-quant.cpp
@@ -134,6 +134,26 @@ static bool category_is_attn_v(tensor_category cat) {
            cat == tensor_category::ATTENTION_KV_B;
 }
 
+struct compiled_tensor_type_patterns {
+    std::vector<std::pair<std::regex, ggml_type>> patterns;
+};
+
+quantize_state_impl::quantize_state_impl(const llama_model & model, const llama_model_quantize_params * params)
+    : model(model), params(params)
+{
+    if (params->tensor_types) {
+        const auto & tensor_types = *static_cast<const std::vector<tensor_type_option> *>(params->tensor_types);
+        if (!tensor_types.empty()) {
+            tensor_type_patterns = std::make_unique<compiled_tensor_type_patterns>();
+            for (const auto & [tname, qtype] : tensor_types) {
+                tensor_type_patterns->patterns.emplace_back(std::regex(tname), qtype);
+            }
+        }
+    }
+}
+
+quantize_state_impl::~quantize_state_impl() = default;
+
 //
 // dequantization
 //
@@ -598,9 +618,9 @@ ggml_type llama_tensor_get_type(quantize_state_impl & qs, const llama_model_quan
     if (!params->pure && ggml_is_quantized(default_type)) {
         // if the user provided tensor types - use those
         bool manual = false;
-        if (!qs.tensor_type_patterns.empty()) {
+        if (qs.tensor_type_patterns) {
             const std::string tensor_name(tensor->name);
-            for (const auto & [pattern, qtype] : qs.tensor_type_patterns) {
+            for (const auto & [pattern, qtype] : qs.tensor_type_patterns->patterns) {
                 if (std::regex_search(tensor_name, pattern)) {
                     if (qtype != new_type) {
                         LLAMA_LOG_WARN("%s: %-36s - applying manual override: %s -> %s\n",
@@ -940,8 +960,6 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
         const auto * it = tensors[i];
         const struct ggml_tensor * tensor = it->tensor;
 
-        metadata[i].category = tensor_get_category(name);
-
         uint16_t i_split = params->keep_split ? it->idx : 0;
         if (!ctx_outs[i_split]) {
             ctx_outs[i_split].reset(gguf_init_empty());
diff --git a/src/llama-quant.h b/src/llama-quant.h
index 1614aa4013..b32ce2d92d 100644
--- a/src/llama-quant.h
+++ b/src/llama-quant.h
@@ -6,7 +6,7 @@
 
 #include "llama-arch.h"
 
-#include <regex>
+#include <memory>
 #include <string>
 #include <vector>
 
@@ -46,6 +46,8 @@ struct tensor_type_option {
     ggml_type type = GGML_TYPE_COUNT;
 };
 
+struct compiled_tensor_type_patterns;
+
 struct quantize_state_impl {
     const llama_model & model;
     const llama_model_quantize_params * params;
@@ -67,20 +69,11 @@ struct quantize_state_impl {
     // used to figure out if a model has tied embeddings (tok_embd shares weights with output)
     bool has_tied_embeddings = true; // assume tied until we see output.weight
 
-    // tensor type override patterns (compiled once, used twice)
-    std::vector<std::pair<std::regex, ggml_type>> tensor_type_patterns;
+    // tensor type override patterns (compiled once, used in llama_tensor_get_type)
+    std::unique_ptr<compiled_tensor_type_patterns> tensor_type_patterns;
 
-    quantize_state_impl(const llama_model & model, const llama_model_quantize_params * params):
-        model(model), params(params)
-    {
-        // compile regex patterns once - they are expensive
-        if (params->tensor_types) {
-            const auto & tensor_types = *static_cast<const std::vector<tensor_type_option> *>(params->tensor_types);
-            for (const auto & [tname, qtype] : tensor_types) {
-                tensor_type_patterns.emplace_back(std::regex(tname), qtype);
-            }
-        }
-    }
+    quantize_state_impl(const llama_model & model, const llama_model_quantize_params * params);
+    ~quantize_state_impl();
 };
 
 ggml_type llama_tensor_get_type(quantize_state_impl & qs, const llama_model_quantize_params * params, const ggml_tensor * tensor, ggml_type default_type, const tensor_metadata & tm);