diff --git a/src/llama-model.cpp b/src/llama-model.cpp index bd687036a9..d43d637672 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -8098,6 +8098,11 @@ struct llm_build_modern_bert : public llm_graph_context { Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); cb(cur, "kqv_out", il); + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); + } + // re-add the layer input cur = ggml_add(ctx0, cur, inpL);