diff --git a/src/llama-quant.cpp b/src/llama-quant.cpp index 6d67448fda..7f854d0cb0 100644 --- a/src/llama-quant.cpp +++ b/src/llama-quant.cpp @@ -259,12 +259,27 @@ static ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_t else new_type = ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M ? GGML_TYPE_IQ3_S : GGML_TYPE_Q2_K; ++qs.i_attention_wv; } - else if (qs.model.hparams.n_expert >= 8 && name.find("attn_k") != std::string::npos) { + else if (qs.model.hparams.n_expert >= 8 && name.find("attn_k.weight") != std::string::npos) { new_type = GGML_TYPE_Q4_K; } - else if (qs.model.hparams.n_expert >= 8 && name.find("attn_q") != std::string::npos) { + else if (qs.model.hparams.n_expert >= 8 && name.find("attn_kv_a_mqa.weight") != std::string::npos) { new_type = GGML_TYPE_Q4_K; } + else if (qs.model.hparams.n_expert >= 8 && name.find("attn_kv_b.weight") != std::string::npos) { + if (qs.i_attention_wv < qs.n_attention_wv/16) { + new_type = GGML_TYPE_Q4_K; + } + else if (use_more_bits(qs.i_attention_wv, qs.n_attention_wv)) { + new_type = ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M ? GGML_TYPE_IQ3_S : GGML_TYPE_Q2_K; + } + ++qs.i_attention_wv; + } + else if (qs.model.hparams.n_expert >= 8 && name.find("attn_q_a.weight") != std::string::npos) { + new_type = GGML_TYPE_Q4_K; + } + else if (qs.model.hparams.n_expert >= 8 && name.find("attn_q_b.weight") != std::string::npos) { + new_type = ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M ? GGML_TYPE_IQ3_S : GGML_TYPE_Q2_K; + } else if (qs.model.hparams.n_expert >= 8 && name.find("ffn_down") != std::string::npos) { if (qs.i_ffn_down < qs.n_ffn_down/16) { new_type = GGML_TYPE_Q4_K;