diff --git a/include/llama.h b/include/llama.h
index c6e102abe5..27db6feb86 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -214,6 +214,12 @@ extern "C" {
 
     typedef bool (*llama_progress_callback)(float progress, void * user_data);
 
+    // called after graph build but before memory allocation in llama_decode/llama_encode
+    // use ggml_backend_sched_set_tensor_backend() to reassign graph nodes to a different backend
+    // NOTE: not called when a previous graph is reused; assignments from the last invocation
+    //       persist. set LLAMA_GRAPH_REUSE_DISABLE=1 for per-decode control.
+    typedef void (*llama_pre_alloc_callback)(ggml_backend_sched_t sched, struct ggml_cgraph * gf, void * user_data);
+
     // Input data for llama_encode/llama_decode
     // A llama_batch object can contain input about one or many sequences
     // The provided arrays (i.e. token, embd, pos, etc.) must have size of n_tokens
@@ -352,6 +358,11 @@ extern "C" {
         ggml_backend_sched_eval_callback cb_eval;
         void * cb_eval_user_data;
 
+        // called after graph build but before memory allocation
+        // allows reassigning tensor backends via ggml_backend_sched_set_tensor_backend()
+        llama_pre_alloc_callback cb_pre_alloc;
+        void * cb_pre_alloc_user_data;
+
         enum ggml_type type_k; // data type for K cache [EXPERIMENTAL]
         enum ggml_type type_v; // data type for V cache [EXPERIMENTAL]
 
diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index 1f7a52d789..8ed3b89d1d 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -63,6 +63,9 @@ llama_context::llama_context(
 
     cparams.cb_eval = params.cb_eval;
     cparams.cb_eval_user_data = params.cb_eval_user_data;
+    cparams.cb_pre_alloc = params.cb_pre_alloc;
+    cparams.cb_pre_alloc_user_data = params.cb_pre_alloc_user_data;
+
     // Initialize backend samplers here so they are part of the sampling graph
     // before the reserve passes run later in this function. This avoids a later
     // re-reserve when graph nodes change.
@@ -1206,6 +1209,10 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll
         return nullptr;
     }
 
+    if (cparams.cb_pre_alloc) {
+        cparams.cb_pre_alloc(sched.get(), gf, cparams.cb_pre_alloc_user_data);
+    }
+
     if (!ggml_backend_sched_alloc_graph(sched.get(), gf)) {
         LLAMA_LOG_ERROR("%s: failed to allocate graph\n", __func__);
         ret = GGML_STATUS_ALLOC_FAILED;
@@ -2893,6 +2900,8 @@ llama_context_params llama_context_default_params() {
         /*.defrag_thold =*/ -1.0f,
         /*.cb_eval =*/ nullptr,
         /*.cb_eval_user_data =*/ nullptr,
+        /*.cb_pre_alloc =*/ nullptr,
+        /*.cb_pre_alloc_user_data =*/ nullptr,
         /*.type_k =*/ GGML_TYPE_F16,
         /*.type_v =*/ GGML_TYPE_F16,
         /*.abort_callback =*/ nullptr,
diff --git a/src/llama-cparams.h b/src/llama-cparams.h
index 9d35947413..9f3b3ed181 100644
--- a/src/llama-cparams.h
+++ b/src/llama-cparams.h
@@ -44,4 +44,7 @@ struct llama_cparams {
 
     ggml_backend_sched_eval_callback cb_eval;
     void * cb_eval_user_data;
+
+    llama_pre_alloc_callback cb_pre_alloc;
+    void * cb_pre_alloc_user_data;
 };
diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt
index 9582164b58..bc02285293 100644
--- a/tests/CMakeLists.txt
+++ b/tests/CMakeLists.txt
@@ -241,6 +241,7 @@ llama_build_and_test(test-backend-ops.cpp)
 
 llama_build_and_test(test-model-load-cancel.cpp LABEL "model")
 llama_build_and_test(test-autorelease.cpp LABEL "model")
+llama_build_and_test(test-pre-alloc-callback.cpp LABEL "model")
 llama_build_and_test(test-backend-sampler.cpp LABEL "model")
 
 # Test for state restore with fragmented KV cache
diff --git a/tests/test-pre-alloc-callback.cpp b/tests/test-pre-alloc-callback.cpp
new file mode 100644
index 0000000000..0385fa84aa
--- /dev/null
+++ b/tests/test-pre-alloc-callback.cpp
@@ -0,0 +1,88 @@
+#include <cstdio>
+
+#include "llama.h"
+#include "get-model.h"
+
+struct callback_state {
+    bool called;
+    bool reassign_ok;
+};
+
+static void pre_alloc_cb(ggml_backend_sched_t sched, struct ggml_cgraph * gf, void * user_data) {
+    auto * state = static_cast<callback_state *>(user_data);
+    state->called = true;
+
+    // reassign the first node to a different backend and verify
+    int n_backends = ggml_backend_sched_get_n_backends(sched);
+    if (n_backends < 1 || ggml_graph_n_nodes(gf) <= 0) {
+        return;
+    }
+
+    struct ggml_tensor * node = ggml_graph_node(gf, 0);
+    ggml_backend_t current = ggml_backend_sched_get_tensor_backend(sched, node);
+    ggml_backend_t target = current;
+
+    for (int i = 0; i < n_backends; i++) {
+        ggml_backend_t candidate = ggml_backend_sched_get_backend(sched, i);
+        if (candidate != current) {
+            target = candidate;
+            break;
+        }
+    }
+
+    if (target != current) {
+        ggml_backend_sched_set_tensor_backend(sched, node, target);
+        state->reassign_ok = (ggml_backend_sched_get_tensor_backend(sched, node) == target);
+    } else {
+        // only one backend available; can't test reassignment, just verify the callback was called
+        state->reassign_ok = true;
+    }
+}
+
+int main(int argc, char ** argv) {
+    auto * model_path = get_model_or_exit(argc, argv);
+
+    llama_backend_init();
+    auto * model = llama_model_load_from_file(model_path, llama_model_default_params());
+    if (!model) {
+        fprintf(stderr, "FAIL: could not load model\n");
+        llama_backend_free();
+        return 1;
+    }
+
+    callback_state state = { false, false };
+
+    auto params = llama_context_default_params();
+    params.n_ctx = 64;
+    params.n_batch = 1;
+    params.cb_pre_alloc = pre_alloc_cb;
+    params.cb_pre_alloc_user_data = &state;
+
+    auto * ctx = llama_init_from_model(model, params);
+    if (!ctx) {
+        fprintf(stderr, "FAIL: could not create context\n");
+        llama_model_free(model);
+        llama_backend_free();
+        return 1;
+    }
+
+    llama_token token = 0;
+    if (llama_decode(ctx, llama_batch_get_one(&token, 1)) != 0) {
+        fprintf(stderr, "FAIL: llama_decode failed\n");
+        llama_free(ctx);
+        llama_model_free(model);
+        llama_backend_free();
+        return 1;
+    }
+
+    int ret = (state.called && state.reassign_ok) ? 0 : 1;
+
+    if (ret != 0) {
+        fprintf(stderr, "FAIL: called=%d reassign_ok=%d\n", state.called, state.reassign_ok);
+    }
+
+    llama_free(ctx);
+    llama_model_free(model);
+    llama_backend_free();
+    return ret;
+}
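
Usage sketch (illustrative only, not part of the patch): one way an application could register the new cb_pre_alloc callback, here pinning the first graph node to the CPU backend by matching the scheduler backend named "CPU". The helper name pin_first_node_to_cpu and the model path "model.gguf" are placeholders.

// pin_first_node_to_cpu.cpp (sketch)
#include <cstring>

#include "llama.h"

// runs after graph build, before ggml_backend_sched_alloc_graph(), which is
// exactly when cb_pre_alloc is invoked
static void pin_first_node_to_cpu(ggml_backend_sched_t sched, struct ggml_cgraph * gf, void * /*user_data*/) {
    if (ggml_graph_n_nodes(gf) <= 0) {
        return;
    }
    // look for the CPU backend among those registered with the scheduler
    for (int i = 0; i < ggml_backend_sched_get_n_backends(sched); i++) {
        ggml_backend_t backend = ggml_backend_sched_get_backend(sched, i);
        if (strcmp(ggml_backend_name(backend), "CPU") == 0) {
            ggml_backend_sched_set_tensor_backend(sched, ggml_graph_node(gf, 0), backend);
            break;
        }
    }
}

int main() {
    llama_backend_init();

    llama_model * model = llama_model_load_from_file("model.gguf", llama_model_default_params()); // placeholder path
    if (!model) {
        llama_backend_free();
        return 1;
    }

    llama_context_params cparams = llama_context_default_params();
    cparams.cb_pre_alloc = pin_first_node_to_cpu;
    cparams.cb_pre_alloc_user_data = nullptr;

    llama_context * ctx = llama_init_from_model(model, cparams);
    // ... run llama_decode() as usual; the callback fires whenever a new graph is built ...

    llama_free(ctx);
    llama_model_free(model);
    llama_backend_free();
    return 0;
}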