From e101005d1af7ff184234a6f7d0ca054ef2dfb2dd Mon Sep 17 00:00:00 2001
From: ryan-mangeno
Date: Sun, 7 Sep 2025 21:00:38 -0400
Subject: [PATCH] working on SWA with local and global alternating attention

---
 src/llama-arch.cpp  |   1 +
 src/llama-arch.h    |   1 +
 src/llama-model.cpp | 130 +++++++++++++++++++++++++++++++++++++-------
 3 files changed, 113 insertions(+), 19 deletions(-)

diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp
index 9a009ac902..cbb1f3d8f6 100644
--- a/src/llama-arch.cpp
+++ b/src/llama-arch.cpp
@@ -171,6 +171,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_ROPE_DIMENSION_COUNT,      "%s.rope.dimension_count" },
     { LLM_KV_ROPE_DIMENSION_SECTIONS,   "%s.rope.dimension_sections" },
     { LLM_KV_ROPE_FREQ_BASE,            "%s.rope.freq_base" },
+    { LLM_KV_ROPE_FREQ_BASE_SWA,        "%s.rope.freq_base_swa" },
     { LLM_KV_ROPE_SCALE_LINEAR,         "%s.rope.scale_linear" },
     { LLM_KV_ROPE_SCALING_TYPE,         "%s.rope.scaling.type" },
     { LLM_KV_ROPE_SCALING_FACTOR,       "%s.rope.scaling.factor" },
diff --git a/src/llama-arch.h b/src/llama-arch.h
index c99448e78f..8422cbe2a1 100644
--- a/src/llama-arch.h
+++ b/src/llama-arch.h
@@ -176,6 +176,7 @@ enum llm_kv {
     LLM_KV_ROPE_DIMENSION_SECTIONS,
     LLM_KV_ROPE_FREQ_BASE,
     LLM_KV_ROPE_SCALE_LINEAR,
+    LLM_KV_ROPE_FREQ_BASE_SWA,
     LLM_KV_ROPE_SCALING_TYPE,
     LLM_KV_ROPE_SCALING_FACTOR,
     LLM_KV_ROPE_SCALING_ATTN_FACTOR,
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index 6f70335647..a3b4646f0b 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -7559,6 +7559,7 @@ struct llm_build_modern_bert : public llm_graph_context {
         const int64_t n_head_kv   = hparams.n_head_kv();
         const int64_t n_embd_head = hparams.n_embd_head_v;
         const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
+        const int64_t n_local_swa = hparams.n_swa;
         const int64_t n_tokens    = ubatch.n_tokens;
 
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -7574,7 +7575,17 @@
         const float beta_fast = 0.0f;
         const float beta_slow = 0.0f;
 
-        ggml_tensor * inp_pos = build_inp_pos();
+
+        // global position input; TODO: the maximum position count (4096) is hardcoded for now (WIP)
+        ggml_tensor * inp_pos_global = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, 4096, 1);
+        ggml_set_input(inp_pos_global);
+        size_t element_size = ggml_type_size(inp_pos_global->type);
+
+        size_t nb1 = element_size;
+        size_t nb2 = nb1;
+
+        inp_pos_global = ggml_view_3d(ctx0, inp_pos_global, 1, 1, 4096, nb1, nb2, 0);
+        inp_pos_global = ggml_cont(ctx0, inp_pos_global);
+
         ggml_tensor * inpL = build_inp_embd(model.tok_embd);
 
         if (model.type_embd) {
@@ -7582,19 +7593,20 @@
         }
         inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, -1);
 
-        auto * inp_attn = build_attn_inp_no_cache();
+        auto * inp_attn = build_attn_inp_kv_unified_iswa();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
+
         for (int il = 0; il < n_layer; ++il) {
             ggml_tensor * x = inpL;
 
-            // Pre attention Layer norm
+            // pre-attention LayerNorm
             ggml_tensor * x_attn_in = x;
             if (model.layers[il].attn_norm) {
                 x_attn_in = build_norm(x, model.layers[il].attn_norm, model.layers[il].attn_norm_b, LLM_NORM, il);
             }
 
-            // fused qkv
+            // fused QKV
             GGML_ASSERT(model.layers[il].wqkv);
             ggml_tensor * qkv = build_lora_mm(model.layers[il].wqkv, x_attn_in);
             if (model.layers[il].bqkv) {
@@ -7609,41 +7621,120 @@
             if (model.layers[il].attn_q_norm) Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, model.layers[il].attn_q_norm_b, LLM_NORM, il);
             if (model.layers[il].attn_k_norm) Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, model.layers[il].attn_k_norm_b, LLM_NORM, il);
 
-            // reshape for multi head
+            // reshape for multi-head attention
             Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
             Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
             Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
 
-            // rope embedding
-            Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr,
-                n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                ext_factor, attn_factor, beta_fast, beta_slow);
-            Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr,
-                n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+            // global or local (SWA) layer: every third layer runs global attention
+            bool is_global = ((il + 1) % 3 == 0);
+            // RoPE theta: 160k for global layers, 10k for local layers
+            // TODO: read these from hparams (LLM_KV_ROPE_FREQ_BASE / LLM_KV_ROPE_FREQ_BASE_SWA) instead of hardcoding
+            float freq_base_l  = is_global ? 160000.0f : 10000.0f;
+            float freq_scale_l = 1.0f;
+
+            ggml_tensor * pos_q = inp_pos_global;
+
+            ggml_tensor * K_work = Kcur;
+            ggml_tensor * V_work = Vcur;
+            ggml_tensor * pos_k  = inp_pos_global;
+
+            if (!is_global) {
+                // gather the K/V rows selected by the SWA index input
+                ggml_tensor * idx_src = inp_attn->self_k_idxs_swa;
+
+                ggml_tensor * idx_view1d = ggml_view_1d(ctx0, idx_src, idx_src->ne[0], 0);
+                ggml_tensor * idx_cont   = ggml_cont(ctx0, idx_view1d);
+
+                ggml_tensor * idx_i32 = idx_cont;
+                if (idx_i32->type != GGML_TYPE_I32) {
+                    idx_i32 = ggml_cast(ctx0, idx_cont, GGML_TYPE_I32);
+                }
+
+                const int64_t n_indices = idx_i32->ne[0];
+                ggml_tensor * idx_2d = ggml_view_2d(ctx0, idx_i32, 1, n_indices, sizeof(int32_t), 0);
+
+                idx_2d = ggml_cont(ctx0, idx_2d);
+                if (idx_2d->type != GGML_TYPE_I32) idx_2d = ggml_cast(ctx0, idx_2d, GGML_TYPE_I32);
+
+                // debug: log the shapes of Kcur and the gather indices
+                LLAMA_LOG_INFO("Kcur: [%lld, %lld, %lld], idx_2d: [%lld, %lld, %lld, %lld], type=%d\n",
+                    Kcur->ne[0], Kcur->ne[1], Kcur->ne[2],
+                    idx_2d->ne[0], idx_2d->ne[1], idx_2d->ne[2], idx_2d->ne[3],
+                    idx_2d->type);
+
+                K_work = ggml_get_rows(ctx0, Kcur, idx_2d);
+                V_work = ggml_get_rows(ctx0, Vcur, idx_2d);
+
+                // gather the positions that correspond to the selected K rows
+                ggml_tensor * pos_rows = ggml_get_rows(ctx0, inp_pos_global, idx_2d);
+
+                // flatten to a contiguous 1D vector for rope
+                if (!ggml_is_vector(pos_rows)) {
+                    const int64_t n_el = ggml_nelements(pos_rows);
+                    pos_rows = ggml_view_1d(ctx0, pos_rows, n_el, 0);
+                }
+                pos_rows = ggml_cont(ctx0, pos_rows);
+
+                // ensure I32
+                if (pos_rows->type != GGML_TYPE_I32) {
+                    pos_rows = ggml_cast(ctx0, pos_rows, GGML_TYPE_I32);
+                }
+
+                // final pos_k to pass to rope
+                pos_k = pos_rows;
+                LLAMA_LOG_INFO("pos_k final: ne[0]=%lld, type=%d\n", pos_k->ne[0], pos_k->type);
+            }
+
+            // flatten pos_q to a contiguous 1D vector for rope
+            if (!ggml_is_vector(pos_q)) {
+                const int64_t n_el = ggml_nelements(pos_q);
+                pos_q = ggml_view_1d(ctx0, pos_q, n_el, 0);
+                pos_q = ggml_cont(ctx0, pos_q);
+            }
+
+            // apply rope
+            Qcur = ggml_rope_ext(ctx0, Qcur, pos_q, nullptr,
+                n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
+                ext_factor, attn_factor, beta_fast, beta_slow);
+
+            if (!ggml_is_vector(pos_k)) {
+                const int64_t n_el = ggml_nelements(pos_k);
+                pos_k = ggml_view_1d(ctx0, pos_k, n_el, 0);
+                pos_k = ggml_cont(ctx0, pos_k);
+            }
+
+            K_work = ggml_rope_ext(ctx0, K_work, pos_k, nullptr,
+                n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
                 ext_factor, attn_factor, beta_fast, beta_slow);
 
+            // choose the mask: global attention vs. SWA
+            ggml_tensor * kq_b_layer = is_global ? inp_attn->self_kq_mask : inp_attn->self_kq_mask_swa;
+
             ggml_tensor * attn_out = build_attn(
                 inp_attn,
-                model.layers[il].wo, model.layers[il].bo,
-                Qcur, Kcur, Vcur,
-                /*k cache*/ nullptr,
-                /*v cache*/ nullptr,
+                model.layers[il].wo,
+                model.layers[il].bo,
+                Qcur,
+                K_work,
+                V_work,
+                kq_b_layer,
+                nullptr,
                 1.0f / sqrtf(float(n_embd_head)),
                 il
             );
 
+            // residual addition
             ggml_tensor * cur_attn = ggml_add(ctx0, attn_out, x);
 
-            // optional subselect output tokens (inp_out_ids)
+            // optionally select output tokens (inp_out_ids)
             if (il == n_layer - 1 && inp_out_ids) {
                 cur_attn = ggml_get_rows(ctx0, cur_attn, inp_out_ids);
                 inpL     = ggml_get_rows(ctx0, inpL, inp_out_ids);
             }
 
-            // pre mlp LayerNorm
+            // pre-MLP LayerNorm
             ggml_tensor * h = build_norm(cur_attn, model.layers[il].ffn_norm, model.layers[il].ffn_norm_b, LLM_NORM, il);
 
-            // geglu FFN
+            // GEGLU FFN
             ggml_tensor * mlp_out = build_ffn(
                 h,
                 model.layers[il].ffn_up, NULL, NULL,
@@ -7653,10 +7744,11 @@
                 LLM_FFN_GEGLU, LLM_FFN_PAR,
                 il
             );
-            // resid addition
+            // residual addition after FFN
             inpL = ggml_add(ctx0, mlp_out, cur_attn);
         }
 
+
         ggml_tensor * cur = build_norm(inpL, model.output_norm, model.output_norm_b, LLM_NORM, -1);
         res->t_embd = cur;
         ggml_build_forward_expand(gf, cur);
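
Note for reviewers (not part of the patch): the scheme the diff encodes is that every third layer runs global attention with the larger RoPE base (160000), while the remaining layers run local sliding-window attention with the smaller base (10000), using the SWA mask and K/V indices provided by build_attn_inp_kv_unified_iswa(). The standalone sketch below restates that layer-alternation rule; the helper names and the symmetric window check are illustrative assumptions, not llama.cpp API, and the actual mask in the patch comes from inp_attn->self_kq_mask_swa.

// swa_layer_pattern_sketch.cpp -- illustration only, not part of the patch.
// Mirrors the constants in the diff above: a period of 3 layers and RoPE
// bases of 160000 (global) / 10000 (local). All helper names are hypothetical.
#include <cstdint>
#include <cstdio>

// every third layer (il = 2, 5, 8, ...) runs global attention
static bool layer_is_global(int il) {
    return (il + 1) % 3 == 0;
}

// per-layer RoPE base, as hardcoded in the diff
static float rope_freq_base_for_layer(int il) {
    return layer_is_global(il) ? 160000.0f : 10000.0f;
}

// assumed symmetric sliding-window rule for the local (encoder) layers:
// token i may attend to token j only if the two are within n_swa positions
// of each other; in the patch the real mask is built by the graph input,
// not by this helper
static bool swa_allows(int64_t i, int64_t j, int64_t n_swa) {
    const int64_t d = i > j ? i - j : j - i;
    return d < n_swa;
}

int main() {
    const int64_t n_swa = 128; // illustrative window size

    for (int il = 0; il < 6; ++il) {
        std::printf("layer %d: %-11s rope_base=%.0f\n",
            il,
            layer_is_global(il) ? "global" : "local(SWA)",
            rope_freq_base_for_layer(il));
    }

    // with a 128-token window, position 300 can attend to 250 but not to 100
    std::printf("300 -> 250: %s\n", swa_allows(300, 250, n_swa) ? "allowed" : "masked");
    std::printf("300 -> 100: %s\n", swa_allows(300, 100, n_swa) ? "allowed" : "masked");

    return 0;
}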