Add specific attn_qkv logic
This commit is contained in:
parent
fc0a02df27
commit
bf34e75799
|
|
@ -81,6 +81,7 @@ struct quantize_state_impl {
|
||||||
const llama_model_quantize_params * params;
|
const llama_model_quantize_params * params;
|
||||||
|
|
||||||
int n_attention_wv = 0;
|
int n_attention_wv = 0;
|
||||||
|
int n_attn_qkv = 0;
|
||||||
int n_ffn_down = 0;
|
int n_ffn_down = 0;
|
||||||
int n_ffn_gate = 0;
|
int n_ffn_gate = 0;
|
||||||
int n_ffn_up = 0;
|
int n_ffn_up = 0;
|
||||||
|
|
@ -92,6 +93,7 @@ struct quantize_state_impl {
|
||||||
int n_ffn_up_shexp = 0;
|
int n_ffn_up_shexp = 0;
|
||||||
int n_ssm_out = 0;
|
int n_ssm_out = 0;
|
||||||
int n_attn_q = 0;
|
int n_attn_q = 0;
|
||||||
|
int i_attn_qkv = 0;
|
||||||
int i_attention_wv = 0;
|
int i_attention_wv = 0;
|
||||||
int i_ffn_down = 0;
|
int i_ffn_down = 0;
|
||||||
int i_ffn_gate = 0;
|
int i_ffn_gate = 0;
|
||||||
|
|
@ -333,7 +335,7 @@ static ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_t
|
||||||
}
|
}
|
||||||
} else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ1_S ||
|
} else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ1_S ||
|
||||||
ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M || ftype == LLAMA_FTYPE_MOSTLY_IQ1_M) {
|
ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M || ftype == LLAMA_FTYPE_MOSTLY_IQ1_M) {
|
||||||
if (name.find("attn_v.weight") != std::string::npos || name.find("attn_qkv.weight") != std::string::npos) {
|
if (name.find("attn_v.weight") != std::string::npos) {
|
||||||
if (use_more_bits(qs.i_attention_wv, qs.n_attention_wv)) {
|
if (use_more_bits(qs.i_attention_wv, qs.n_attention_wv)) {
|
||||||
new_type = GGML_TYPE_Q6_K;
|
new_type = GGML_TYPE_Q6_K;
|
||||||
}
|
}
|
||||||
|
|
@ -342,6 +344,20 @@ static ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_t
|
||||||
}
|
}
|
||||||
++qs.i_attention_wv;
|
++qs.i_attention_wv;
|
||||||
}
|
}
|
||||||
|
else if (name.find("attn_qkv.weight") != std::string::npos) {
|
||||||
|
if (qs.model.hparams.n_expert >= 8) {
|
||||||
|
if (use_more_bits(qs.i_attn_qkv, qs.n_attn_qkv)) {
|
||||||
|
new_type = GGML_TYPE_Q6_K;
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
new_type = GGML_TYPE_Q4_K;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else if (use_more_bits(qs.i_attn_qkv, qs.n_attn_qkv)) {
|
||||||
|
new_type = GGML_TYPE_Q4_K;
|
||||||
|
}
|
||||||
|
++qs.i_attn_qkv;
|
||||||
|
}
|
||||||
else if (qs.model.hparams.n_expert >= 8 && name.find("ssm_out.weight") != std::string::npos) {
|
else if (qs.model.hparams.n_expert >= 8 && name.find("ssm_out.weight") != std::string::npos) {
|
||||||
if (use_more_bits(qs.i_ssm_out, qs.n_ssm_out)) {
|
if (use_more_bits(qs.i_ssm_out, qs.n_ssm_out)) {
|
||||||
new_type = GGML_TYPE_Q4_K;
|
new_type = GGML_TYPE_Q4_K;
|
||||||
|
|
@ -640,7 +656,7 @@ static ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_t
|
||||||
if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) {
|
if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) {
|
||||||
new_type = GGML_TYPE_Q4_K;
|
new_type = GGML_TYPE_Q4_K;
|
||||||
}
|
}
|
||||||
else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M || ftype == LLAMA_FTYPE_MOSTLY_IQ3_XS ||
|
else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M || ftype == LLAMA_FTYPE_MOSTLY_IQ3_XS ||
|
||||||
ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L || ftype == LLAMA_FTYPE_MOSTLY_IQ3_M ||
|
ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L || ftype == LLAMA_FTYPE_MOSTLY_IQ3_M ||
|
||||||
ftype == LLAMA_FTYPE_MOSTLY_Q3_K_S || ftype == LLAMA_FTYPE_MOSTLY_IQ3_S ||
|
ftype == LLAMA_FTYPE_MOSTLY_Q3_K_S || ftype == LLAMA_FTYPE_MOSTLY_IQ3_S ||
|
||||||
ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS || ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S) {
|
ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS || ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S) {
|
||||||
|
|
@ -673,7 +689,7 @@ static ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_t
|
||||||
new_type = GGML_TYPE_Q5_K;
|
new_type = GGML_TYPE_Q5_K;
|
||||||
}
|
}
|
||||||
else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_M || ftype == LLAMA_FTYPE_MOSTLY_IQ3_S ||
|
else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_M || ftype == LLAMA_FTYPE_MOSTLY_IQ3_S ||
|
||||||
ftype == LLAMA_FTYPE_MOSTLY_IQ3_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS){
|
ftype == LLAMA_FTYPE_MOSTLY_IQ3_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS) {
|
||||||
new_type = GGML_TYPE_IQ4_XS;
|
new_type = GGML_TYPE_IQ4_XS;
|
||||||
}
|
}
|
||||||
else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q5_K_S) {
|
else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q5_K_S) {
|
||||||
|
|
@ -705,7 +721,7 @@ static ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_t
|
||||||
new_type = GGML_TYPE_Q4_K;
|
new_type = GGML_TYPE_Q4_K;
|
||||||
}
|
}
|
||||||
else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_M || ftype == LLAMA_FTYPE_MOSTLY_IQ3_S ||
|
else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_M || ftype == LLAMA_FTYPE_MOSTLY_IQ3_S ||
|
||||||
ftype == LLAMA_FTYPE_MOSTLY_IQ3_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS){
|
ftype == LLAMA_FTYPE_MOSTLY_IQ3_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS) {
|
||||||
new_type = GGML_TYPE_IQ4_XS;
|
new_type = GGML_TYPE_IQ4_XS;
|
||||||
}
|
}
|
||||||
else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q5_K_S) {
|
else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q5_K_S) {
|
||||||
|
|
@ -720,7 +736,7 @@ static ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_t
|
||||||
new_type = GGML_TYPE_Q5_K;
|
new_type = GGML_TYPE_Q5_K;
|
||||||
}
|
}
|
||||||
else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_M || ftype == LLAMA_FTYPE_MOSTLY_IQ3_S ||
|
else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_M || ftype == LLAMA_FTYPE_MOSTLY_IQ3_S ||
|
||||||
ftype == LLAMA_FTYPE_MOSTLY_IQ3_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS){
|
ftype == LLAMA_FTYPE_MOSTLY_IQ3_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS) {
|
||||||
new_type = GGML_TYPE_IQ4_XS;
|
new_type = GGML_TYPE_IQ4_XS;
|
||||||
}
|
}
|
||||||
else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q5_K_S) {
|
else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q5_K_S) {
|
||||||
|
|
@ -973,11 +989,12 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
|
||||||
|
|
||||||
// TODO: avoid hardcoded tensor names - use the TN_* constants
|
// TODO: avoid hardcoded tensor names - use the TN_* constants
|
||||||
if (name.find("attn_v.weight") != std::string::npos ||
|
if (name.find("attn_v.weight") != std::string::npos ||
|
||||||
name.find("attn_qkv.weight") != std::string::npos ||
|
|
||||||
name.find("attn_kv_b.weight")!= std::string::npos) {
|
name.find("attn_kv_b.weight")!= std::string::npos) {
|
||||||
++qs.n_attention_wv;
|
++qs.n_attention_wv;
|
||||||
} else if (name == LLM_TN(model.arch)(LLM_TENSOR_OUTPUT, "weight")) {
|
} else if (name == LLM_TN(model.arch)(LLM_TENSOR_OUTPUT, "weight")) {
|
||||||
qs.has_output = true;
|
qs.has_output = true;
|
||||||
|
} else if (name.find("attn_qkv.weight") != std::string::npos) {
|
||||||
|
++qs.n_attn_qkv;
|
||||||
} else if (name.find("ffn_gate_exps.weight") != std::string::npos) {
|
} else if (name.find("ffn_gate_exps.weight") != std::string::npos) {
|
||||||
++qs.n_ffn_gate_exps;
|
++qs.n_ffn_gate_exps;
|
||||||
} else if (name.find("ffn_gate_shexp.weight") != std::string::npos) {
|
} else if (name.find("ffn_gate_shexp.weight") != std::string::npos) {
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue