From 6d86944cb494bdc105b86fabfdcc0d4a06c19dc3 Mon Sep 17 00:00:00 2001
From: ryan-mangeno
Date: Wed, 3 Sep 2025 14:32:39 -0400
Subject: [PATCH] working through previous attempt: implemented more accurate
 conversion, added local sliding window attention that alternates every third
 layer

---
 convert_hf_to_gguf.py          |  47 +++++----
 src/llama-hparams.h            |   1 +
 src/llama-kv-cache-unified.cpp |  12 +++
 src/llama-model.cpp            | 168 +++++++++++++--------------------
 4 files changed, 102 insertions(+), 126 deletions(-)

diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index 3c483e1028..cff934ba23 100755
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -8308,37 +8308,32 @@ class SmallThinkerModel(TextModel):
                 raise ValueError(f"Unprocessed experts: {experts}")
 
 
-@ModelBase.register("ModernBertModel")
-class ModernBertModel(TextModel):
+@ModelBase.register("ModernBertModel", "ModernBertForMaskedLM", "ModernBertForSequenceClassification")
+class ModernBertModel(BertModel):
     model_arch = gguf.MODEL_ARCH.MODERN_BERT
 
-    def set_gguf_parameters(self) -> None:
-        # Determine block count (number of hidden layers)
-        block_count = self.hparams.get("num_hidden_layers") or self.hparams.get("num_hidden_layers_alt")
-        if block_count is None:
-            raise ValueError("Could not determine number of hidden layers from hparams")
+    def set_vocab(self):
+        self._set_vocab_gpt2()
+        self.gguf_writer.add_add_bos_token(True)
+        self.gguf_writer.add_add_eos_token(True)
 
-        # Attention heads and dimensions
-        n_head = self.hparams.get("num_attention_heads")
-        if n_head is None:
-            raise ValueError("Missing 'num_attention_heads' in hparams")
-
-        hidden_size = self.hparams["hidden_size"]
-        head_dim = hidden_size // n_head
-        ffn_dim = self.hparams.get("intermediate_size", 4 * hidden_size)
-
-        # GGUF parameter assignment
-        self.gguf_writer.add_context_length(self.hparams.get("max_position_embeddings", 512))
-        self.gguf_writer.add_embedding_length(hidden_size)
-        self.gguf_writer.add_feed_forward_length(ffn_dim)
-        self.gguf_writer.add_block_count(block_count)
-        self.gguf_writer.add_head_count(n_head)
-        self.gguf_writer.add_layer_norm_eps(self.hparams.get("layer_norm_eps", 1e-12))
-        self.gguf_writer.add_file_type(self.ftype)
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        self.gguf_writer.add_sliding_window(self.hparams["local_attention"])
+        self.gguf_writer.add_rope_freq_base(self.hparams["global_rope_theta"])
+        self.gguf_writer.add_rope_freq_base_swa(self.hparams["local_rope_theta"])
+        self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE)
+        self.gguf_writer.add_vocab_size(self.hparams["vocab_size"])
 
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
-        # Directly map tensor names without QKV splitting or reordering
-        return [(self.map_tensor_name(name), data_torch)]
+        # These layers act as the MLM head, so we don't need them
+        if name.startswith("decoder."):
+            return []
+
+        if name.startswith("model."):
+            name = name[6:]
+
+        return super().modify_tensors(data_torch, name, bid)
 
 
 ###### CONVERSION LOGIC ######
diff --git a/src/llama-hparams.h b/src/llama-hparams.h
index bd23122443..2e13a7732c 100644
--- a/src/llama-hparams.h
+++ b/src/llama-hparams.h
@@ -19,6 +19,7 @@ enum llama_swa_type {
     LLAMA_SWA_TYPE_NONE     = 0,
    LLAMA_SWA_TYPE_STANDARD = 1,
     LLAMA_SWA_TYPE_CHUNKED  = 2,
+    LLAMA_SWA_TYPE_LOCAL    = 3,
 };
 
 struct llama_hparams_posnet {
diff --git a/src/llama-kv-cache-unified.cpp b/src/llama-kv-cache-unified.cpp
index e539142e6b..678a7d23ad 100644
--- a/src/llama-kv-cache-unified.cpp
+++ b/src/llama-kv-cache-unified.cpp
@@ -1807,6 +1807,18 @@ bool llama_kv_cache_unified::is_masked_swa(llama_pos p0, llama_pos p1) const {
                     return true;
                 }
             } break;
+        case LLAMA_SWA_TYPE_LOCAL:
+            {
+                const int32_t half_n_swa = (int32_t) n_swa / 2;
+                const int32_t pos_diff   = p1 - p0;
+
+                // mask if outside the window
+                if (pos_diff < -half_n_swa || pos_diff > half_n_swa) {
+                    return true;
+                }
+            } break;
+
+
     }
 
     return false;
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index a159eb3472..6f70335647 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -759,11 +759,20 @@ void llama_model::load_hparams(llama_model_loader & ml) {
             } break;
         case LLM_ARCH_MODERN_BERT:
             {
-                //ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn);
-                LLAMA_LOG_INFO("Switching Modern Bert Arch\n");
+
+                hparams.swa_type = LLAMA_SWA_TYPE_LOCAL;
+
+                hparams.set_swa_pattern(3, 0);
+                hparams.n_swa = 128;
+
+                ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa);
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS,  hparams.f_norm_eps);
+                ml.get_key(LLM_KV_ATTENTION_CAUSAL,         hparams.causal_attn);
+                ml.get_key(LLM_KV_POOLING_TYPE,             hparams.pooling_type, false);
+
                 switch (hparams.n_layer) {
                     case 12:
-                        type = LLM_TYPE_47M; break; // granite-embeddings-mall
+                        type = LLM_TYPE_47M; break; // granite-embeddings-small
                     default: type = LLM_TYPE_UNKNOWN;
                 }
             } break;
@@ -7544,152 +7553,111 @@ struct llm_build_bert : public llm_graph_context {
 
 struct llm_build_modern_bert : public llm_graph_context {
     llm_build_modern_bert(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
-        const int64_t n_embd = hparams.n_embd;
-        const int64_t n_layer = hparams.n_layer;
-        const int64_t n_head = hparams.n_head();
-        const int64_t n_head_kv = hparams.n_head_kv();
-        const int64_t n_embd_head = hparams.n_embd_head_v;
-        const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
-        const int64_t n_tokens = ubatch.n_tokens;
+        const int64_t n_embd      = hparams.n_embd;
+        const int64_t n_layer     = hparams.n_layer;
+        const int64_t n_head      = hparams.n_head();
+        const int64_t n_head_kv   = hparams.n_head_kv();
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+        const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
+        const int64_t n_tokens    = ubatch.n_tokens;
 
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
 
-        // RoPE params
-        const int32_t rope_type = LLAMA_ROPE_TYPE_NEOX; // uses rotary
-        const int32_t n_rot = hparams.n_rot;
-        const int32_t n_ctx_orig = hparams.n_ctx_train;
+        // rope params
+        const int32_t rope_type   = LLAMA_ROPE_TYPE_NEOX;
+        const int32_t n_rot       = hparams.n_rot;
+        const int32_t n_ctx_orig  = hparams.n_ctx_train;
+        const float   freq_base   = hparams.rope_freq_base_train;
+        const float   freq_scale  = hparams.rope_freq_scale_train;
+        const float   attn_factor = 1.0f;
+        const float   ext_factor  = 1.0f;
+        const float   beta_fast   = 0.0f;
+        const float   beta_slow   = 0.0f;
 
-        ggml_tensor * cur;
-        ggml_tensor * inpL;
-        ggml_tensor * inp_pos = nullptr;
-
-        // needs positions for RoPE
-        inp_pos = build_inp_pos();
-
-        // embeddings (token + optional type), NO absolute pos embed
-        inpL = build_inp_embd(model.tok_embd);
+        ggml_tensor * inp_pos = build_inp_pos();
+        ggml_tensor * inpL    = build_inp_embd(model.tok_embd);
 
         if (model.type_embd) {
-            ggml_tensor * type_row0 = ggml_view_1d(ctx0, model.type_embd, n_embd, 0);
-            inpL = ggml_add(ctx0, inpL, type_row0);
+            inpL = ggml_add(ctx0, inpL, ggml_view_1d(ctx0, model.type_embd, n_embd, 0));
         }
-        cb(inpL, "inp_embd", -1);
-
-        // embeddings LayerNorm (embeddings.norm)
         inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, -1);
-        cb(inpL, "inp_norm", -1);
 
-        auto * inp_attn = build_attn_inp_no_cache();
+        auto * inp_attn           = build_attn_inp_no_cache();
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
         for (int il = 0; il < n_layer; ++il) {
             ggml_tensor * x = inpL;
 
-            // pre attention norm (attn_norm). Layer 0 may be Identity() -> nullptr
+            // pre-attention layer norm
             ggml_tensor * x_attn_in = x;
             if (model.layers[il].attn_norm) {
-                x_attn_in = build_norm(x,
-                        model.layers[il].attn_norm,
-                        model.layers[il].attn_norm_b,
-                        LLM_NORM, il);
-                cb(x_attn_in, "attn_pre_norm", il);
-            } else {
-                cb(x_attn_in, "attn_pre_norm_identity", il);
+                x_attn_in = build_norm(x, model.layers[il].attn_norm, model.layers[il].attn_norm_b, LLM_NORM, il);
             }
 
-            // Attention: fused Wqkv -> split -> heads -> RoPE(Q,K) -> attn -> Wo
-            ggml_tensor * qkv = nullptr;
-            ggml_tensor * Qcur;
-            ggml_tensor * Kcur;
-            ggml_tensor * Vcur;
-
-            GGML_ASSERT(model.layers[il].wqkv); // fused QKV
-            qkv = build_lora_mm(model.layers[il].wqkv, x_attn_in);
-            cb(qkv, "wqkv", il);
-
+            // fused qkv
+            GGML_ASSERT(model.layers[il].wqkv);
+            ggml_tensor * qkv = build_lora_mm(model.layers[il].wqkv, x_attn_in);
             if (model.layers[il].bqkv) {
                 qkv = ggml_add(ctx0, qkv, model.layers[il].bqkv);
-                cb(qkv, "bqkv", il);
             }
 
-            // Fused layout: [ (n_embd + 2*n_embd_gqa), n_tokens ]
-            Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, qkv, n_embd, n_tokens, qkv->nb[1], 0*sizeof(float)*(n_embd)));
-            Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, qkv, n_embd_gqa, n_tokens, qkv->nb[1], 1*sizeof(float)*(n_embd)));
-            Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, qkv, n_embd_gqa, n_tokens, qkv->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
+            // split the fused qkv; the view offsets are in bytes, so scale by sizeof(float)
+            ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, qkv, n_embd,     n_tokens, qkv->nb[1], 0));
+            ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, qkv, n_embd_gqa, n_tokens, qkv->nb[1], sizeof(float)*(n_embd)));
+            ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, qkv, n_embd_gqa, n_tokens, qkv->nb[1], sizeof(float)*(n_embd + n_embd_gqa)));
 
-            // optional per Q/K
-            if (model.layers[il].attn_q_norm) {
-                Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, model.layers[il].attn_q_norm_b, LLM_NORM, il);
-            }
-            if (model.layers[il].attn_k_norm) {
-                Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, model.layers[il].attn_k_norm_b, LLM_NORM, il);
-            }
+            // optional q/k LayerNorm
+            if (model.layers[il].attn_q_norm) Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, model.layers[il].attn_q_norm_b, LLM_NORM, il);
+            if (model.layers[il].attn_k_norm) Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, model.layers[il].attn_k_norm_b, LLM_NORM, il);
 
-            // heads
+            // reshape for multi-head attention
             Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
             Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
             Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
 
-            // RoPE (NEOX ... maybe?) on Q and K
+            // rope embedding
             Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr,
-                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                    ext_factor, attn_factor, beta_fast, beta_slow);
+                                 n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                                 ext_factor, attn_factor, beta_fast, beta_slow);
 
             Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr,
-                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                    ext_factor, attn_factor, beta_fast, beta_slow);
-
-            cb(Qcur, "Qcur_rope", il);
-            cb(Kcur, "Kcur_rope", il);
-            cb(Vcur, "Vcur", il);
+                                 n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                                 ext_factor, attn_factor, beta_fast, beta_slow);
 
             ggml_tensor * attn_out = build_attn(
                     inp_attn,
-                    model.layers[il].wo, model.layers[il].bo, // Wo, optional bias
+                    model.layers[il].wo, model.layers[il].bo,
                     Qcur, Kcur, Vcur,
-                    /*K_cache*/ nullptr,
-                    /*V_cache*/ nullptr,
+                    /*kq_b*/  nullptr,
+                    /*v_mla*/ nullptr,
                     1.0f / sqrtf(float(n_embd_head)),
-                    il);
-            cb(attn_out, "attn_out", il);
+                    il
+            );
 
-            // residual after attention
             ggml_tensor * cur_attn = ggml_add(ctx0, attn_out, x);
 
-            // ifwe subselect outputs, do it at the last layer after attn resid
+            // optionally subselect the output tokens (inp_out_ids)
            if (il == n_layer - 1 && inp_out_ids) {
-                cur_attn = ggml_get_rows(ctx0, cur_attn, inp_out_ids);
-                x = ggml_get_rows(ctx0, x, inp_out_ids);
+                cur_attn = ggml_get_rows(ctx0, cur_attn, inp_out_ids);
+                inpL     = ggml_get_rows(ctx0, inpL,     inp_out_ids);
             }
 
-            // pre mlp norm
-            ggml_tensor * h = build_norm(cur_attn,
-                    model.layers[il].ffn_norm,
-                    model.layers[il].ffn_norm_b,
-                    LLM_NORM, il);
-            cb(h, "mlp_pre_norm", il);
+            // pre-mlp LayerNorm
+            ggml_tensor * h = build_norm(cur_attn, model.layers[il].ffn_norm, model.layers[il].ffn_norm_b, LLM_NORM, il);
 
-            // GEGLU because we will split ffn_up which has shape [n_embd, n_ff * 2] and ffn_down has shape [n_ff, n_embd]
+            // geglu FFN
             ggml_tensor * mlp_out = build_ffn(
                     h,
-                    model.layers[il].ffn_up, /*up_b*/ NULL, /*up_shexp*/ NULL,
-                    /*gate*/ NULL , /*gate_b*/ NULL, /*gate_shexp*/ NULL,
-                    model.layers[il].ffn_down, /*down_b*/ NULL, /*down_shexp*/ NULL,
-                    /*act_scales*/ NULL,
+                    model.layers[il].ffn_up,   NULL, NULL,
+                    NULL,                      NULL, NULL,
+                    model.layers[il].ffn_down, NULL, NULL,
+                    NULL,
                     LLM_FFN_GEGLU, LLM_FFN_PAR, il
             );
-            cb(mlp_out, "ffn_out_geglu", il);
 
-            // Residual after MLP
-            ggml_tensor * cur_layer = ggml_add(ctx0, mlp_out, cur_attn);
-
-            // feed into next layer
-            inpL = cur_layer;
+            // residual addition
+            inpL = ggml_add(ctx0, mlp_out, cur_attn);
         }
 
-        // final model norm (final_norm)
-        cur = build_norm(inpL, model.output_norm, model.output_norm_b, LLM_NORM, -1);
-        cb(cur, "final_norm", -1);
-
+        ggml_tensor * cur = build_norm(inpL, model.output_norm, model.output_norm_b, LLM_NORM, -1);
         res->t_embd = cur;
 
         ggml_build_forward_expand(gf, cur);
     }
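
Note for reviewers (not part of the patch): the LLAMA_SWA_TYPE_LOCAL case above masks
symmetrically around each position, so a token sees at most n_swa/2 neighbours on either
side. Below is a minimal standalone C++ sketch of that rule for sanity-checking the window
edges; the helper name is_masked_local is hypothetical and only mirrors the logic added in
is_masked_swa.

    // standalone sketch, assuming only the masking rule from the patch above
    #include <cassert>
    #include <cstdint>

    // true when the key at p0 is masked for the query at p1, given a
    // symmetric local window of n_swa tokens (n_swa/2 on either side)
    static bool is_masked_local(int32_t p0, int32_t p1, uint32_t n_swa) {
        const int32_t half_n_swa = (int32_t) n_swa / 2;
        const int32_t pos_diff   = p1 - p0;
        return pos_diff < -half_n_swa || pos_diff > half_n_swa;
    }

    int main() {
        // with n_swa = 128 (the default set in load_hparams above),
        // a token attends to neighbours at most 64 positions away
        assert(!is_masked_local(100, 100, 128)); // same position: kept
        assert(!is_masked_local( 36, 100, 128)); // 64 behind: kept
        assert( is_masked_local( 35, 100, 128)); // 65 behind: masked
        assert( is_masked_local(165, 100, 128)); // 65 ahead: masked
        return 0;
    }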