graph: add f_attn_temp_offset

Xuan Son Nguyen 2025-12-14 12:12:12 +01:00
parent 254098a279
commit 357f999381
4 changed files with 11 additions and 4 deletions

src/llama-graph.cpp

@@ -78,7 +78,7 @@ void llm_graph_input_attn_temp::set_input(const llama_ubatch * ubatch) {
     for (int i = 0; i < n_tokens; ++i) {
         const float pos = ubatch->pos[i];
         attn_scale_data[i] = std::log(
-            std::floor((pos + 1.0f) / n_attn_temp_floor_scale) + 1.0
+            std::floor((pos + f_attn_temp_offset) / n_attn_temp_floor_scale) + 1.0
         ) * f_attn_temp_scale + 1.0;
     }
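The deleted line hard-coded a 1.0f position bias; the replacement reads it from the new f_attn_temp_offset hyperparameter instead. A minimal standalone sketch of the per-token scale, assuming the Llama 4 defaults set later in this commit (attn_temp_scale is an illustrative name, not part of the codebase):

#include <cmath>
#include <cstdio>

// Mirrors the expression in llm_graph_input_attn_temp::set_input().
static float attn_temp_scale(float pos, float offset, float floor_scale, float scale) {
    return std::log(std::floor((pos + offset) / floor_scale) + 1.0f) * scale + 1.0f;
}

int main() {
    // Llama 4: n_attn_temp_floor_scale = 8192, f_attn_temp_scale = 0.1f, f_attn_temp_offset = 1.0f
    std::printf("%.6f\n", attn_temp_scale(   0.0f, 1.0f, 8192.0f, 0.1f)); // 1.000000 (first window)
    std::printf("%.6f\n", attn_temp_scale(8191.0f, 1.0f, 8192.0f, 0.1f)); // 1.069315 (the offset pushes pos 8191 into the next window)
    return 0;
}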
@@ -1203,7 +1203,7 @@ ggml_tensor * llm_graph_context::build_inp_pos() const {
 }
 
 ggml_tensor * llm_graph_context::build_inp_attn_scale() const {
-    auto inp = std::make_unique<llm_graph_input_attn_temp>(hparams.n_attn_temp_floor_scale, hparams.f_attn_temp_scale);
+    auto inp = std::make_unique<llm_graph_input_attn_temp>(hparams.n_attn_temp_floor_scale, hparams.f_attn_temp_scale, hparams.f_attn_temp_offset);
 
     auto & cur = inp->attn_scale;

src/llama-graph.h

@@ -132,8 +132,8 @@ public:
 // temperature tuning, used by llama4
 class llm_graph_input_attn_temp : public llm_graph_input_i {
 public:
-    llm_graph_input_attn_temp(uint32_t n_attn_temp_floor_scale, float f_attn_temp_scale)
-        : n_attn_temp_floor_scale(n_attn_temp_floor_scale), f_attn_temp_scale(f_attn_temp_scale) {}
+    llm_graph_input_attn_temp(uint32_t n_attn_temp_floor_scale, float f_attn_temp_scale, float f_attn_temp_offset)
+        : n_attn_temp_floor_scale(n_attn_temp_floor_scale), f_attn_temp_scale(f_attn_temp_scale), f_attn_temp_offset(f_attn_temp_offset) {}
 
     virtual ~llm_graph_input_attn_temp() = default;
 
     void set_input(const llama_ubatch * ubatch) override;
@@ -142,6 +142,7 @@ public:
     const uint32_t n_attn_temp_floor_scale;
     const float    f_attn_temp_scale;
+    const float    f_attn_temp_offset;
 };
 
 class llm_graph_input_pos_bucket : public llm_graph_input_i {

src/llama-hparams.h

@@ -165,6 +165,7 @@ struct llama_hparams {
     uint32_t n_no_rope_layer_step    = 4;
     uint32_t n_attn_temp_floor_scale = 0;
     float    f_attn_temp_scale       = 0.0f;
+    float    f_attn_temp_offset      = 0.0f; // offset position index
 
     // gemma3n altup
     uint32_t n_altup = 4; // altup_num_inputs

src/llama-model.cpp

@@ -668,6 +668,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                 hparams.n_swa                   = 8192;
                 hparams.n_attn_temp_floor_scale = 8192;
                 hparams.f_attn_temp_scale       = 0.1f;
+                hparams.f_attn_temp_offset      = 1.0f;
                 hparams.set_swa_pattern(4); // pattern: 3 chunked - 1 full
             }
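With f_attn_temp_offset = 1.0f, std::floor((pos + f_attn_temp_offset) / n_attn_temp_floor_scale) evaluates exactly as the old hard-coded (pos + 1.0f) did, so this commit leaves Llama 4's attention scales unchanged.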
@@ -1646,6 +1647,8 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                 ml.get_key(LLM_KV_ATTENTION_TEMPERATURE_SCALE,  hparams.f_attn_temp_scale,       false);
                 ml.get_key(LLM_KV_ATTENTION_TEMPERATURE_LENGTH, hparams.n_attn_temp_floor_scale, false);
+                hparams.f_attn_temp_offset = 0.0f;
+
                 switch (hparams.n_layer) {
                     case 27: type = LLM_TYPE_16B; break;
                     case 60: type = LLM_TYPE_236B; break;
@@ -2276,6 +2279,8 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                 ml.get_key(LLM_KV_ROPE_SCALING_YARN_BETA_SLOW, hparams.yarn_beta_slow,    false);
                 ml.get_key(LLM_KV_ROPE_SCALING_YARN_LOG_MUL,   hparams.rope_yarn_log_mul, 0.0f);
+                hparams.f_attn_temp_offset = 0.0f;
+
                 // TODO: maybe add n_attn_temp_floor_scale as a separate KV?
                 if (hparams.f_attn_temp_scale != 0.0f) {
                     hparams.n_attn_temp_floor_scale = hparams.n_ctx_orig_yarn;
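The two architectures above opt into an offset of 0.0f instead, so their bucket index becomes std::floor(pos / n_attn_temp_floor_scale): position 0 starts in bucket 0 and bucket boundaries fall on exact multiples of n_attn_temp_floor_scale, rather than inheriting Llama 4's +1 bias.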