From d13ed9f2f14ac08d8e0bdd769bc53f7a92fb4f34 Mon Sep 17 00:00:00 2001 From: Aristeidis Stathopoulos Date: Sat, 7 Mar 2026 11:39:43 +0200 Subject: [PATCH] llama: add cb_pre_alloc callback for pre-allocation backend reassignment Add a new llama_pre_alloc_callback that fires after graph construction but before memory allocation in llama_decode/llama_encode. This allows downstream consumers to call ggml_backend_sched_set_tensor_backend() to route specific ops (e.g. attention) to a different backend without modifying llama.cpp internals. Changes: - Add llama_pre_alloc_callback typedef to llama.h - Add cb_pre_alloc + cb_pre_alloc_user_data to llama_context_params and llama_cparams - Invoke callback in process_ubatch() between build_graph and alloc_graph - Add test that verifies callback invocation and backend reassignment Co-Authored-By: Claude Opus 4.6 --- include/llama.h | 11 +++++ src/llama-context.cpp | 9 ++++ src/llama-cparams.h | 3 ++ tests/CMakeLists.txt | 1 + tests/test-pre-alloc-callback.cpp | 71 +++++++++++++++++++++++++++++++ 5 files changed, 95 insertions(+) create mode 100644 tests/test-pre-alloc-callback.cpp diff --git a/include/llama.h b/include/llama.h index a84d56a885..2b1fdcbded 100644 --- a/include/llama.h +++ b/include/llama.h @@ -212,6 +212,12 @@ extern "C" { typedef bool (*llama_progress_callback)(float progress, void * user_data); + // called after graph build but before memory allocation in llama_decode/llama_encode + // use ggml_backend_sched_set_tensor_backend() to reassign ops to a different backend + // NOTE: not called when a previous graph is reused; assignments from the last invocation + // persist. set LLAMA_GRAPH_REUSE_DISABLE=1 for per-decode control. + typedef void (*llama_pre_alloc_callback)(ggml_backend_sched_t sched, struct ggml_cgraph * gf, void * user_data); + // Input data for llama_encode/llama_decode // A llama_batch object can contain input about one or many sequences // The provided arrays (i.e. token, embd, pos, etc.) 
must have size of n_tokens @@ -350,6 +356,11 @@ extern "C" { ggml_backend_sched_eval_callback cb_eval; void * cb_eval_user_data; + // called after graph build but before memory allocation + // allows reassigning tensor backends via ggml_backend_sched_set_tensor_backend() + llama_pre_alloc_callback cb_pre_alloc; + void * cb_pre_alloc_user_data; + enum ggml_type type_k; // data type for K cache [EXPERIMENTAL] enum ggml_type type_v; // data type for V cache [EXPERIMENTAL] diff --git a/src/llama-context.cpp b/src/llama-context.cpp index abaa5c0f8d..4460ff4db7 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -62,6 +62,9 @@ llama_context::llama_context( cparams.cb_eval = params.cb_eval; cparams.cb_eval_user_data = params.cb_eval_user_data; + cparams.cb_pre_alloc = params.cb_pre_alloc; + cparams.cb_pre_alloc_user_data = params.cb_pre_alloc_user_data; + // Initialize backend samplers here so they are part of the sampling graph // before the reserve passes run later in this function. This avoids a later // re-reserve when graph nodes change. 
@@ -1147,6 +1150,10 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll return nullptr; } + if (cparams.cb_pre_alloc) { + cparams.cb_pre_alloc(sched.get(), gf, cparams.cb_pre_alloc_user_data); + } + if (!ggml_backend_sched_alloc_graph(sched.get(), gf)) { LLAMA_LOG_ERROR("%s: failed to allocate graph\n", __func__); ret = GGML_STATUS_ALLOC_FAILED; @@ -2833,6 +2840,8 @@ llama_context_params llama_context_default_params() { /*.defrag_thold =*/ -1.0f, /*.cb_eval =*/ nullptr, /*.cb_eval_user_data =*/ nullptr, + /*.cb_pre_alloc =*/ nullptr, + /*.cb_pre_alloc_user_data =*/ nullptr, /*.type_k =*/ GGML_TYPE_F16, /*.type_v =*/ GGML_TYPE_F16, /*.abort_callback =*/ nullptr, diff --git a/src/llama-cparams.h b/src/llama-cparams.h index 333922468c..301d906fb2 100644 --- a/src/llama-cparams.h +++ b/src/llama-cparams.h @@ -43,4 +43,7 @@ struct llama_cparams { ggml_backend_sched_eval_callback cb_eval; void * cb_eval_user_data; + + llama_pre_alloc_callback cb_pre_alloc; + void * cb_pre_alloc_user_data; }; diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 46ab7a0cef..c97f925389 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -238,6 +238,7 @@ llama_build_and_test(test-backend-ops.cpp) llama_build_and_test(test-model-load-cancel.cpp LABEL "model") llama_build_and_test(test-autorelease.cpp LABEL "model") +llama_build_and_test(test-pre-alloc-callback.cpp LABEL "model") llama_build_and_test(test-backend-sampler.cpp LABEL "model") # Test for state restore with fragmented KV cache diff --git a/tests/test-pre-alloc-callback.cpp b/tests/test-pre-alloc-callback.cpp new file mode 100644 index 0000000000..97b215cef7 --- /dev/null +++ b/tests/test-pre-alloc-callback.cpp @@ -0,0 +1,71 @@ +#include +#include + +#include "llama.h" +#include "get-model.h" + +struct callback_state { + bool called; + bool reassign_ok; +}; + +static void pre_alloc_cb(ggml_backend_sched_t sched, struct ggml_cgraph * gf, void * user_data) { + auto * state = 
static_cast<callback_state *>(user_data); + state->called = true; + + // reassign the first node to the last backend (CPU) and verify + int n_backends = ggml_backend_sched_get_n_backends(sched); + if (n_backends < 1 || ggml_graph_n_nodes(gf) <= 0) { + return; + } + + ggml_backend_t target = ggml_backend_sched_get_backend(sched, n_backends - 1); + struct ggml_tensor * node = ggml_graph_node(gf, 0); + ggml_backend_sched_set_tensor_backend(sched, node, target); + state->reassign_ok = (ggml_backend_sched_get_tensor_backend(sched, node) == target); +} + +int main(int argc, char ** argv) { + auto * model_path = get_model_or_exit(argc, argv); + + llama_backend_init(); + auto * model = llama_model_load_from_file(model_path, llama_model_default_params()); + if (!model) { + fprintf(stderr, "FAIL: could not load model\n"); + return 1; + } + + callback_state state = { false, false }; + + auto params = llama_context_default_params(); + params.n_ctx = 64; + params.n_batch = 1; + params.cb_pre_alloc = pre_alloc_cb; + params.cb_pre_alloc_user_data = &state; + + auto * ctx = llama_init_from_model(model, params); + if (!ctx) { + fprintf(stderr, "FAIL: could not create context\n"); + llama_model_free(model); + llama_backend_free(); + return 1; + } + + llama_token token = 0; + if (llama_decode(ctx, llama_batch_get_one(&token, 1)) != 0) { + fprintf(stderr, "FAIL: llama_decode failed\n"); + llama_free(ctx); + llama_model_free(model); + llama_backend_free(); + return 1; + } + + fprintf(stderr, "called=%d reassign_ok=%d\n", state.called, state.reassign_ok); + + int ret = (state.called && state.reassign_ok) ? 0 : 1; + + llama_free(ctx); + llama_model_free(model); + llama_backend_free(); + return ret; +}