diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 6cf9a883a6..8909bbfb95 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -78,7 +78,7 @@ void llm_graph_input_attn_temp::set_input(const llama_ubatch * ubatch) { for (int i = 0; i < n_tokens; ++i) { const float pos = ubatch->pos[i]; attn_scale_data[i] = std::log( - std::floor((pos + 1.0f) / n_attn_temp_floor_scale) + 1.0 + std::floor((pos + f_attn_temp_offset) / n_attn_temp_floor_scale) + 1.0 ) * f_attn_temp_scale + 1.0; } @@ -1203,7 +1203,7 @@ ggml_tensor * llm_graph_context::build_inp_pos() const { } ggml_tensor * llm_graph_context::build_inp_attn_scale() const { - auto inp = std::make_unique(hparams.n_attn_temp_floor_scale, hparams.f_attn_temp_scale); + auto inp = std::make_unique(hparams.n_attn_temp_floor_scale, hparams.f_attn_temp_scale, hparams.f_attn_temp_offset); auto & cur = inp->attn_scale; diff --git a/src/llama-graph.h b/src/llama-graph.h index d0c3934f67..e9d387bd7c 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -132,8 +132,8 @@ public: // temperature tuning, used by llama4 class llm_graph_input_attn_temp : public llm_graph_input_i { public: - llm_graph_input_attn_temp(uint32_t n_attn_temp_floor_scale, float f_attn_temp_scale) - : n_attn_temp_floor_scale(n_attn_temp_floor_scale), f_attn_temp_scale(f_attn_temp_scale) {} + llm_graph_input_attn_temp(uint32_t n_attn_temp_floor_scale, float f_attn_temp_scale, float f_attn_temp_offset) + : n_attn_temp_floor_scale(n_attn_temp_floor_scale), f_attn_temp_scale(f_attn_temp_scale), f_attn_temp_offset(f_attn_temp_offset) {} virtual ~llm_graph_input_attn_temp() = default; void set_input(const llama_ubatch * ubatch) override; @@ -142,6 +142,7 @@ public: const uint32_t n_attn_temp_floor_scale; const float f_attn_temp_scale; + const float f_attn_temp_offset; }; class llm_graph_input_pos_bucket : public llm_graph_input_i { diff --git a/src/llama-hparams.h b/src/llama-hparams.h index aab319754e..a467c64a14 100644 --- a/src/llama-hparams.h +++ b/src/llama-hparams.h @@ -165,6 +165,7 @@ struct llama_hparams { uint32_t n_no_rope_layer_step = 4; uint32_t n_attn_temp_floor_scale = 0; float f_attn_temp_scale = 0.0f; + float f_attn_temp_offset = 0.0f; // offset position index // gemma3n altup uint32_t n_altup = 4; // altup_num_inputs diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 5da1dd6dbb..28f06b4e61 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -668,6 +668,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { hparams.n_swa = 8192; hparams.n_attn_temp_floor_scale = 8192; hparams.f_attn_temp_scale = 0.1f; + hparams.f_attn_temp_offset = 1.0f; hparams.set_swa_pattern(4); // pattern: 3 chunked - 1 full } @@ -1646,6 +1647,8 @@ void llama_model::load_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_TEMPERATURE_SCALE, hparams.f_attn_temp_scale, false); ml.get_key(LLM_KV_ATTENTION_TEMPERATURE_LENGTH, hparams.n_attn_temp_floor_scale, false); + hparams.f_attn_temp_offset = 0.0f; + switch (hparams.n_layer) { case 27: type = LLM_TYPE_16B; break; case 60: type = LLM_TYPE_236B; break; @@ -2276,6 +2279,8 @@ void llama_model::load_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ROPE_SCALING_YARN_BETA_SLOW, hparams.yarn_beta_slow, false); ml.get_key(LLM_KV_ROPE_SCALING_YARN_LOG_MUL, hparams.rope_yarn_log_mul, 0.0f); + hparams.f_attn_temp_offset = 0.0f; + // TODO: maybe add n_attn_temp_floor_scale as a separate KV? if (hparams.f_attn_temp_scale != 0.0f) { hparams.n_attn_temp_floor_scale = hparams.n_ctx_orig_yarn;