From e296a0b6e694283ef0fc52c1bf9a780cff930c77 Mon Sep 17 00:00:00 2001
From: ryan-mangeno
Date: Mon, 8 Sep 2025 15:38:13 -0400
Subject: [PATCH] starting to work, and some cleanup, currently failing on
 last layer construction in graph build

---
 src/llama-graph.cpp |  2 +-
 src/llama-model.cpp | 21 +++++++++++++--------
 2 files changed, 14 insertions(+), 9 deletions(-)

diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp
index 8760046c84..9ca2e579d7 100644
--- a/src/llama-graph.cpp
+++ b/src/llama-graph.cpp
@@ -1547,7 +1547,7 @@ ggml_tensor * llm_graph_context::build_attn_with_sinks(
     // optionally store to KV cache
     if (k_cur) {
         const auto & k_idxs = is_swa ? inp->get_k_idxs_swa() : inp->get_k_idxs();
-        LLAMA_LOG_INFO("k_cur.shape = {%lld, %lld, %lld, %lld}\n", k_cur->ne[0], k_cur->ne[1], k_cur->ne[2], k_cur->ne[3]);
+
         ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
     }
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index 8966cdcf12..34cd49083b 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -7598,6 +7598,7 @@ struct llm_build_modern_bert : public llm_graph_context {

         for (int il = 0; il < n_layer; ++il) {
+            LLAMA_LOG_INFO("Setting layer %d\n", il);
             ggml_tensor * x = inpL;

             // pre attn LayerNorm
@@ -7656,8 +7657,6 @@ struct llm_build_modern_bert : public llm_graph_context {
                 K_work = ggml_get_rows(ctx0, Kcur, idx_2d);
                 V_work = ggml_get_rows(ctx0, Vcur, idx_2d);
-
-
                 ggml_tensor * pos_rows = ggml_get_rows(ctx0, inp_pos_global, idx_2d);
@@ -7675,7 +7674,6 @@ struct llm_build_modern_bert : public llm_graph_context {
                 // final pos_k to pass to rope
                 pos_k = pos_rows;
-                LLAMA_LOG_INFO("pos_k final: ne[0]=%lld, ne[1]=%lld type=%d\n", pos_k->ne[0], pos_k->ne[1], pos_k->type);
             }

             if( !ggml_is_vector(pos_q) ) {
@@ -7683,9 +7681,6 @@ struct llm_build_modern_bert : public llm_graph_context {
                 pos_q = ggml_view_1d(ctx0, pos_q, n_el, 0);
                 pos_q = ggml_cont(ctx0, pos_q);
             }
-            if( !ggml_is_vector(pos_q) ) {
-            }
-
             // apply rope
             Qcur = ggml_rope_ext(ctx0, Qcur, pos_q, nullptr,
@@ -7705,6 +7700,16 @@ struct llm_build_modern_bert : public llm_graph_context {
             // choseing mask, global vs swa
             ggml_tensor * kq_mask = is_global ? inp_attn->self_kq_mask : inp_attn->self_kq_mask_swa;

+            // flatten K/V back to full embedding dim
+            int64_t n_embd   = n_embd_head * n_head_kv;
+            int64_t n_tokens = Kcur->ne[2];
+
+            ggml_tensor * K_2d = ggml_reshape_2d(ctx0, Kcur, n_embd, n_tokens);
+
+            ggml_tensor * K_flat = ggml_view_3d(ctx0, K_2d, n_embd, 1, n_tokens,
+                K_2d->nb[0], K_2d->nb[1], 0);
+            K_flat = ggml_cont(ctx0, K_flat);
+            ggml_tensor * V_flat = ggml_reshape_2d(ctx0, Vcur, n_embd, n_tokens);

             ggml_tensor * attn_out = build_attn(
@@ -7712,8 +7717,8 @@ struct llm_build_modern_bert : public llm_graph_context {
                 model.layers[il].wo, model.layers[il].bo,
                 Qcur,
-                K_work,
-                V_work,
+                K_flat,
+                V_flat,
                 kq_mask, nullptr, 1.0f / sqrtf(float(n_embd_head)),