diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index 72211db17b..0b2b05c419 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -2173,13 +2173,6 @@ llm_graph_cb llama_context::graph_get_cb() const {
             ggml_set_name(cur, name);
         }
 
-        if (!cparams.offload_kqv) {
-            if (strcmp(name, "kqv_merged_cont") == 0) {
-                // all nodes between the KV store and the attention output are run on the CPU
-                ggml_backend_sched_set_tensor_backend(sched.get(), cur, backend_cpu);
-            }
-        }
-
         // norm may be automatically assigned to the backend of the previous layer, increasing data transfer between backends
         // FIXME: fix in ggml_backend_sched
         const bool full_offload = model.n_gpu_layers() > model.hparams.n_layer;
diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp
index 16d42c4ae3..b3198b7e3a 100644
--- a/src/llama-graph.cpp
+++ b/src/llama-graph.cpp
@@ -1630,6 +1630,11 @@ ggml_tensor * llm_graph_context::build_attn_mha(
                 hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
         cb(cur, LLAMA_TENSOR_NAME_FATTN, il);
 
+        if (!cparams.offload_kqv) {
+            // all nodes between the KV store and the attention output are run on the CPU
+            ggml_backend_sched_set_tensor_backend(sched, cur, backend_cpu);
+        }
+
         ggml_flash_attn_ext_add_sinks(cur, sinks);
         ggml_flash_attn_ext_set_prec (cur, GGML_PREC_F32);
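
Note on the moved block (not part of the patch): the sketch below only isolates the scheduler call it relies on. The helper name pin_attn_output_to_cpu and its parameter list are hypothetical; sched and backend_cpu stand in for the llm_graph_context members used in the diff, and the only real API call is ggml_backend_sched_set_tensor_backend.

    #include "ggml.h"
    #include "ggml-backend.h"

    // Hypothetical helper: pin the attention output node to the CPU backend when
    // offload_kqv is disabled. Per the comment in the patch, the scheduler then
    // runs all nodes between the KV store and the attention output on the CPU.
    static void pin_attn_output_to_cpu(
            ggml_backend_sched_t sched,
            ggml_backend_t       backend_cpu,
            struct ggml_tensor * attn_out,
            bool                 offload_kqv) {
        if (!offload_kqv) {
            ggml_backend_sched_set_tensor_backend(sched, attn_out, backend_cpu);
        }
    }

Placing the call in build_attn_mha, right where the node is created, means it also takes effect on the flash-attention path, where the graph contains no "kqv_merged_cont" node for the removed string match in the graph callback to catch.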