models : fix the attn_factor for mistral3 graphs + improve consistency (#17945)
* models : fix the attn_factor for mistral3 graphs
* cont : rework attn_factor correction logic
* cont : make deepseek2 consistent
* cont : add TODO
* cont : special-case DSv2
* cont : revert Mistral 3 Large changes
* cont : fix DS2 to use the original attn_factor
* cont : minor comments
This commit is contained in:
parent
dcb7d17758
commit
7bed317f53
|
|
@ -7286,6 +7286,10 @@ class DeepseekV2Model(TextModel):
|
||||||
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
|
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
|
||||||
self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"])
|
self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"])
|
||||||
self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling["original_max_position_embeddings"])
|
self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling["original_max_position_embeddings"])
|
||||||
|
|
||||||
|
# [TAG_DEEPSEEK2_YARN_LOG_MUL_FIX]
|
||||||
|
# note: for legacy reasons, this is not consistent with the other usages of self.gguf_writer.add_rope_scaling_yarn_log_mul
|
||||||
|
# ref https://github.com/ggml-org/llama.cpp/pull/17945
|
||||||
self.gguf_writer.add_rope_scaling_yarn_log_mul(0.1 * rope_scaling["mscale_all_dim"])
|
self.gguf_writer.add_rope_scaling_yarn_log_mul(0.1 * rope_scaling["mscale_all_dim"])
|
||||||
|
|
||||||
_experts: list[dict[str, Tensor]] | None = None
|
_experts: list[dict[str, Tensor]] | None = None
|
||||||
|
|
@ -10041,6 +10045,10 @@ class MistralMoeModel(DeepseekV2Model):
|
||||||
MistralModel.set_mistral_config(self.gguf_writer, self.hparams)
|
MistralModel.set_mistral_config(self.gguf_writer, self.hparams)
|
||||||
yarn_params = self.hparams["yarn"]
|
yarn_params = self.hparams["yarn"]
|
||||||
self.gguf_writer.add_attn_temperature_length(yarn_params["original_max_position_embeddings"])
|
self.gguf_writer.add_attn_temperature_length(yarn_params["original_max_position_embeddings"])
|
||||||
|
|
||||||
|
# [TAG_DEEPSEEK2_YARN_LOG_MUL_FIX]
|
||||||
|
# note: for legacy reasons, this is not consistent with the other usages of self.gguf_writer.add_rope_scaling_yarn_log_mul
|
||||||
|
# ref https://github.com/ggml-org/llama.cpp/pull/17945
|
||||||
self.gguf_writer.add_rope_scaling_yarn_log_mul(0.1) # mscale_all_dim * 0.1
|
self.gguf_writer.add_rope_scaling_yarn_log_mul(0.1) # mscale_all_dim * 0.1
|
||||||
|
|
||||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None):
|
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None):
|
||||||
|
|
|
||||||
|
|
@ -574,7 +574,7 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
|
||||||
freq_base (cparams.rope_freq_base),
|
freq_base (cparams.rope_freq_base),
|
||||||
freq_scale (cparams.rope_freq_scale),
|
freq_scale (cparams.rope_freq_scale),
|
||||||
ext_factor (cparams.yarn_ext_factor),
|
ext_factor (cparams.yarn_ext_factor),
|
||||||
attn_factor (cparams.yarn_attn_factor),
|
attn_factor (llama_hparams::yarn_attn_factor_adjust(cparams.yarn_attn_factor, cparams.rope_freq_scale, cparams.yarn_ext_factor)),
|
||||||
beta_fast (cparams.yarn_beta_fast),
|
beta_fast (cparams.yarn_beta_fast),
|
||||||
beta_slow (cparams.yarn_beta_slow),
|
beta_slow (cparams.yarn_beta_slow),
|
||||||
norm_eps (hparams.f_norm_eps),
|
norm_eps (hparams.f_norm_eps),
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,9 @@
|
||||||
#include "llama-hparams.h"
|
#include "llama-hparams.h"
|
||||||
|
|
||||||
#include "ggml.h"
|
#include "ggml.h"
|
||||||
|
|
||||||
#include <cassert>
|
#include <cassert>
|
||||||
|
#include <cmath>
|
||||||
|
|
||||||
void llama_hparams::set_swa_pattern(uint32_t n_pattern, bool dense_first) {
|
void llama_hparams::set_swa_pattern(uint32_t n_pattern, bool dense_first) {
|
||||||
if (dense_first) {
|
if (dense_first) {
|
||||||
|
|
@ -229,3 +231,13 @@ bool llama_hparams::is_masked_swa(uint32_t n_swa, llama_swa_type swa_type, llama
|
||||||
|
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Adjust the YaRN attention factor before it is handed to the RoPE op.
//
// attn_factor : the attention (magnitude) factor to adjust
// freq_scale  : the RoPE frequency scale; 1.0f / freq_scale is the context scaling factor
//               NOTE(review): assumes freq_scale > 0 - it is not validated here, confirm at the call sites
// ext_factor  : the YaRN extrapolation mix factor; must be non-negative (asserted below)
//
// Returns attn_factor unchanged when ext_factor == 0.0f. Otherwise returns
// attn_factor / (1.0f + 0.1f * ln(1.0f / freq_scale)), which is intended to cancel the
// scaling applied inside the ggml RoPE implementation when ext_factor != 0.0f
// (ref: https://github.com/ggml-org/llama.cpp/pull/17945).
float llama_hparams::yarn_attn_factor_adjust(float attn_factor, float freq_scale, float ext_factor) {
|
||||||
|
// a negative extrapolation factor is treated as a programming error
GGML_ASSERT(ext_factor >= 0.0f);
|
||||||
|
|
||||||
|
// the correction is only applied when the extrapolation factor is non-zero
if (ext_factor != 0.0f) {
|
||||||
|
// divide by (1 + 0.1 * ln(1 / freq_scale))
attn_factor *= 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale));
|
||||||
|
}
|
||||||
|
|
||||||
|
return attn_factor;
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -107,6 +107,7 @@ struct llama_hparams {
|
||||||
float rope_freq_base_train_swa;
|
float rope_freq_base_train_swa;
|
||||||
float rope_freq_scale_train;
|
float rope_freq_scale_train;
|
||||||
float rope_freq_scale_train_swa;
|
float rope_freq_scale_train_swa;
|
||||||
|
|
||||||
uint32_t n_ctx_orig_yarn;
|
uint32_t n_ctx_orig_yarn;
|
||||||
float rope_yarn_log_mul = 0.0f;
|
float rope_yarn_log_mul = 0.0f;
|
||||||
|
|
||||||
|
|
@ -267,7 +268,13 @@ struct llama_hparams {
|
||||||
// TODO: think of a better place for this function
|
// TODO: think of a better place for this function
|
||||||
// TODO: pack the SWA params in a struct?
|
// TODO: pack the SWA params in a struct?
|
||||||
static bool is_masked_swa(uint32_t n_swa, llama_swa_type swa_type, llama_pos p0, llama_pos p1);
|
static bool is_masked_swa(uint32_t n_swa, llama_swa_type swa_type, llama_pos p0, llama_pos p1);
|
||||||
|
|
||||||
|
// when YARN is applied with yarn_ext_factor != 0.0f, we need to cancel this factor:
|
||||||
|
// https://github.com/ggml-org/llama.cpp/blob/a81a569577cc38b32558958b048228150be63eae/ggml/src/ggml-cpu/ops.cpp#L5541-L5544
|
||||||
|
//
|
||||||
|
// ref: https://github.com/ggml-org/llama.cpp/discussions/7416
|
||||||
|
// https://github.com/ggml-org/llama.cpp/pull/17945
|
||||||
|
static float yarn_attn_factor_adjust(float attn_factor, float freq_scale, float ext_factor);
|
||||||
};
|
};
|
||||||
|
|
||||||
static_assert(std::is_trivially_copyable<llama_hparams>::value, "llama_hparams must be trivially copyable");
|
static_assert(std::is_trivially_copyable<llama_hparams>::value, "llama_hparams must be trivially copyable");
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1369,9 +1369,10 @@ ggml_tensor * llama_kv_cache::build_rope_shift(
|
||||||
float freq_scale) const {
|
float freq_scale) const {
|
||||||
const auto & n_ctx_orig = cparams.n_ctx_orig_yarn;
|
const auto & n_ctx_orig = cparams.n_ctx_orig_yarn;
|
||||||
|
|
||||||
const auto & yarn_ext_factor = cparams.yarn_ext_factor;
|
const auto & yarn_ext_factor = cparams.yarn_ext_factor;
|
||||||
const auto & yarn_beta_fast = cparams.yarn_beta_fast;
|
const auto & yarn_beta_fast = cparams.yarn_beta_fast;
|
||||||
const auto & yarn_beta_slow = cparams.yarn_beta_slow;
|
const auto & yarn_beta_slow = cparams.yarn_beta_slow;
|
||||||
|
const auto & yarn_attn_factor = llama_hparams::yarn_attn_factor_adjust(cparams.yarn_attn_factor, cparams.rope_freq_scale, cparams.yarn_ext_factor);
|
||||||
|
|
||||||
const auto & n_rot = hparams.n_rot;
|
const auto & n_rot = hparams.n_rot;
|
||||||
const auto & rope_type = hparams.rope_type == LLAMA_ROPE_TYPE_MROPE || hparams.rope_type == LLAMA_ROPE_TYPE_IMROPE
|
const auto & rope_type = hparams.rope_type == LLAMA_ROPE_TYPE_MROPE || hparams.rope_type == LLAMA_ROPE_TYPE_IMROPE
|
||||||
|
|
@ -1382,12 +1383,6 @@ ggml_tensor * llama_kv_cache::build_rope_shift(
|
||||||
? LLAMA_ROPE_TYPE_NEOX
|
? LLAMA_ROPE_TYPE_NEOX
|
||||||
: hparams.rope_type;
|
: hparams.rope_type;
|
||||||
|
|
||||||
// See llm_build_deepseek2() for why attn_factor has to be scaled for YaRN RoPE to work correctly.
|
|
||||||
// See https://github.com/ggerganov/llama.cpp/discussions/7416 for detailed explanation.
|
|
||||||
const float yarn_attn_factor = model.arch == LLM_ARCH_DEEPSEEK2
|
|
||||||
? 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale))
|
|
||||||
: cparams.yarn_attn_factor;
|
|
||||||
|
|
||||||
ggml_tensor * tmp;
|
ggml_tensor * tmp;
|
||||||
|
|
||||||
if (ggml_is_quantized(cur->type)) {
|
if (ggml_is_quantized(cur->type)) {
|
||||||
|
|
|
||||||
|
|
@ -1635,7 +1635,12 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
||||||
// that have no expert_gating_func model parameter set
|
// that have no expert_gating_func model parameter set
|
||||||
hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX;
|
hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX;
|
||||||
}
|
}
|
||||||
ml.get_key(LLM_KV_ROPE_SCALING_YARN_LOG_MUL, hparams.rope_yarn_log_mul, false);
|
|
||||||
|
if (ml.get_key(LLM_KV_ROPE_SCALING_YARN_LOG_MUL, hparams.rope_yarn_log_mul, 0.0f)) {
|
||||||
|
// [TAG_DEEPSEEK2_YARN_LOG_MUL_FIX]
|
||||||
|
// cancel the factor from the convert script
|
||||||
|
hparams.rope_yarn_log_mul /= 0.1f;
|
||||||
|
}
|
||||||
|
|
||||||
// (optional) temperature tuning - used by mistral-large
|
// (optional) temperature tuning - used by mistral-large
|
||||||
ml.get_key(LLM_KV_ATTENTION_TEMPERATURE_SCALE, hparams.f_attn_temp_scale, false);
|
ml.get_key(LLM_KV_ATTENTION_TEMPERATURE_SCALE, hparams.f_attn_temp_scale, false);
|
||||||
|
|
@ -2267,9 +2272,9 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
||||||
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
|
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
|
||||||
ml.get_key(LLM_KV_ATTENTION_TEMPERATURE_SCALE, hparams.f_attn_temp_scale, false);
|
ml.get_key(LLM_KV_ATTENTION_TEMPERATURE_SCALE, hparams.f_attn_temp_scale, false);
|
||||||
|
|
||||||
ml.get_key(LLM_KV_ROPE_SCALING_YARN_BETA_FAST, hparams.yarn_beta_fast, false);
|
ml.get_key(LLM_KV_ROPE_SCALING_YARN_BETA_FAST, hparams.yarn_beta_fast, false);
|
||||||
ml.get_key(LLM_KV_ROPE_SCALING_YARN_BETA_SLOW, hparams.yarn_beta_slow, false);
|
ml.get_key(LLM_KV_ROPE_SCALING_YARN_BETA_SLOW, hparams.yarn_beta_slow, false);
|
||||||
ml.get_key(LLM_KV_ROPE_SCALING_YARN_LOG_MUL, hparams.rope_yarn_log_mul, false);
|
ml.get_key(LLM_KV_ROPE_SCALING_YARN_LOG_MUL, hparams.rope_yarn_log_mul, 0.0f);
|
||||||
|
|
||||||
// TODO: maybe add n_attn_temp_floor_scale as a separate KV?
|
// TODO: maybe add n_attn_temp_floor_scale as a separate KV?
|
||||||
if (hparams.f_attn_temp_scale != 0.0f) {
|
if (hparams.f_attn_temp_scale != 0.0f) {
|
||||||
|
|
@ -2279,18 +2284,6 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: this seems to be correct with the case of mscale == mscale_all_dims == 1.0f
|
|
||||||
// but may need further verification with other values
|
|
||||||
if (hparams.rope_yarn_log_mul != 0.0f) {
|
|
||||||
float factor = 1.0f / hparams.rope_freq_scale_train;
|
|
||||||
float mscale = 1.0f;
|
|
||||||
float mscale_all_dims = hparams.rope_yarn_log_mul;
|
|
||||||
static auto get_mscale = [](float scale, float mscale) {
|
|
||||||
return scale <= 1.0f ? 1.0f : (0.1f * mscale * logf(scale) + 1.0f);
|
|
||||||
};
|
|
||||||
hparams.yarn_attn_factor = get_mscale(factor, mscale) / get_mscale(factor, mscale_all_dims);
|
|
||||||
}
|
|
||||||
|
|
||||||
switch (hparams.n_layer) {
|
switch (hparams.n_layer) {
|
||||||
case 26: type = LLM_TYPE_3B; break;
|
case 26: type = LLM_TYPE_3B; break;
|
||||||
case 34: type = LLM_TYPE_8B; break;
|
case 34: type = LLM_TYPE_8B; break;
|
||||||
|
|
@ -2301,6 +2294,32 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
||||||
default: throw std::runtime_error("unsupported model architecture");
|
default: throw std::runtime_error("unsupported model architecture");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ref: https://github.com/huggingface/transformers/blob/6d00f6b0a5679c36510f203e4226e36f517c3032/src/transformers/modeling_rope_utils.py#L336-L348
|
||||||
|
if (hparams.rope_yarn_log_mul != 0.0f) {
|
||||||
|
const float factor = 1.0f / hparams.rope_freq_scale_train;
|
||||||
|
|
||||||
|
// note: here we assume `mscale == 1.0f`
|
||||||
|
// TODO: start reading the actual value of mscale and handle the case where it is not 1.0f
|
||||||
|
float mscale = 1.0f;
|
||||||
|
const float mscale_all_dims = hparams.rope_yarn_log_mul;
|
||||||
|
|
||||||
|
// [TAG_DEEPSEEK2_YARN_LOG_MUL_FIX]
|
||||||
|
// special-case DEEPSEEK v2:
|
||||||
|
// https://huggingface.co/deepseek-ai/DeepSeek-V2-Lite-Chat/blob/main/config.json#L42-L43
|
||||||
|
if (arch == LLM_ARCH_DEEPSEEK2 && mscale_all_dims != 1.0f) {
|
||||||
|
mscale = mscale_all_dims;
|
||||||
|
}
|
||||||
|
|
||||||
|
static auto get_mscale = [](float scale, float mscale) {
|
||||||
|
return scale <= 1.0f ? 1.0f : (0.1f * mscale * logf(scale) + 1.0f);
|
||||||
|
};
|
||||||
|
|
||||||
|
hparams.yarn_attn_factor = get_mscale(factor, mscale) / get_mscale(factor, mscale_all_dims);
|
||||||
|
|
||||||
|
LLAMA_LOG_WARN("%s: setting new yarn_attn_factor = %.4f (mscale == %.1f, mscale_all_dim = %.1f)\n",
|
||||||
|
__func__, hparams.yarn_attn_factor, mscale, mscale_all_dims);
|
||||||
|
}
|
||||||
|
|
||||||
pimpl->n_bytes = ml.n_bytes;
|
pimpl->n_bytes = ml.n_bytes;
|
||||||
|
|
||||||
pimpl->desc_str = arch_name() + " " + type_name() + " " + ml.ftype_name();
|
pimpl->desc_str = arch_name() + " " + type_name() + " " + ml.ftype_name();
|
||||||
|
|
@ -6806,6 +6825,7 @@ void llama_model::print_info() const {
|
||||||
LLAMA_LOG_INFO("%s: freq_base_train = %.1f\n", __func__, hparams.rope_freq_base_train);
|
LLAMA_LOG_INFO("%s: freq_base_train = %.1f\n", __func__, hparams.rope_freq_base_train);
|
||||||
LLAMA_LOG_INFO("%s: freq_scale_train = %g\n", __func__, hparams.rope_freq_scale_train);
|
LLAMA_LOG_INFO("%s: freq_scale_train = %g\n", __func__, hparams.rope_freq_scale_train);
|
||||||
LLAMA_LOG_INFO("%s: n_ctx_orig_yarn = %u\n", __func__, hparams.n_ctx_orig_yarn);
|
LLAMA_LOG_INFO("%s: n_ctx_orig_yarn = %u\n", __func__, hparams.n_ctx_orig_yarn);
|
||||||
|
LLAMA_LOG_INFO("%s: rope_yarn_log_mul= %.4f\n", __func__, hparams.rope_yarn_log_mul);
|
||||||
LLAMA_LOG_INFO("%s: rope_finetuned = %s\n", __func__, hparams.rope_finetuned ? "yes" : "unknown");
|
LLAMA_LOG_INFO("%s: rope_finetuned = %s\n", __func__, hparams.rope_finetuned ? "yes" : "unknown");
|
||||||
// MRoPE (Multi-axis Rotary Position Embedding) sections
|
// MRoPE (Multi-axis Rotary Position Embedding) sections
|
||||||
if (const auto & s = hparams.rope_sections; s[0] || s[1] || s[2] || s[3]) {
|
if (const auto & s = hparams.rope_sections; s[0] || s[1] || s[2] || s[3]) {
|
||||||
|
|
@ -6869,7 +6889,6 @@ void llama_model::print_info() const {
|
||||||
LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale);
|
LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale);
|
||||||
LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm);
|
LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm);
|
||||||
LLAMA_LOG_INFO("%s: expert_gating_func = %s\n", __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func));
|
LLAMA_LOG_INFO("%s: expert_gating_func = %s\n", __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func));
|
||||||
LLAMA_LOG_INFO("%s: rope_yarn_log_mul = %.4f\n", __func__, hparams.rope_yarn_log_mul);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if (arch == LLM_ARCH_QWEN2MOE) {
|
if (arch == LLM_ARCH_QWEN2MOE) {
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,5 @@
|
||||||
#include "models.h"
|
#include "models.h"
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_graph_params & params) :
|
llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_graph_params & params) :
|
||||||
llm_graph_context(params) {
|
llm_graph_context(params) {
|
||||||
// lite variants include DeepSeek-V2-Lite, GigaChat3-10B-A1.8B
|
// lite variants include DeepSeek-V2-Lite, GigaChat3-10B-A1.8B
|
||||||
|
|
@ -20,9 +18,15 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr
|
||||||
|
|
||||||
// We have to pre-scale kq_scale and attn_factor to make the YaRN RoPE work correctly.
|
// We have to pre-scale kq_scale and attn_factor to make the YaRN RoPE work correctly.
|
||||||
// See https://github.com/ggerganov/llama.cpp/discussions/7416 for detailed explanation.
|
// See https://github.com/ggerganov/llama.cpp/discussions/7416 for detailed explanation.
|
||||||
const float mscale = attn_factor * (1.0f + hparams.rope_yarn_log_mul * logf(1.0f / freq_scale));
|
// And also: https://github.com/ggml-org/llama.cpp/pull/17945 [TAG_DEEPSEEK2_YARN_LOG_MUL_FIX]
|
||||||
const float kq_scale = 1.0f * mscale * mscale / sqrtf(float(n_embd_head_k));
|
|
||||||
const float attn_factor = 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale));
|
// first cancel the adjustment from llama_hparams::yarn_attn_factor_adjust to get the original attn_factor
|
||||||
|
GGML_ASSERT(ext_factor >= 0.0f);
|
||||||
|
const float attn_factor_org = attn_factor * (1.0f + 0.1f * logf(1.0f / freq_scale));
|
||||||
|
|
||||||
|
// use the original attn_factor to pre-scale the kq_scale
|
||||||
|
const float mscale = attn_factor_org * (1.0f + 0.1f * hparams.rope_yarn_log_mul * logf(1.0f / freq_scale));
|
||||||
|
const float kq_scale = 1.0f * mscale * mscale / sqrtf(float(n_embd_head_k));
|
||||||
|
|
||||||
ggml_tensor * cur;
|
ggml_tensor * cur;
|
||||||
ggml_tensor * inpL;
|
ggml_tensor * inpL;
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue