From 4df6e859e92dac52536f735785bb9e0a3bc63e2b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sigbj=C3=B8rn=20Skj=C3=A6ret?= Date: Wed, 10 Dec 2025 16:16:20 +0100 Subject: [PATCH 01/11] cuda : add missing support check for xielu (#17895) --- ggml/src/ggml-cuda/ggml-cuda.cu | 1 + 1 file changed, 1 insertion(+) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 279679a4ea..8d17bc669a 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -4313,6 +4313,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_UNARY_OP_EXPM1: case GGML_UNARY_OP_SOFTPLUS: case GGML_UNARY_OP_ELU: + case GGML_UNARY_OP_XIELU: case GGML_UNARY_OP_FLOOR: case GGML_UNARY_OP_CEIL: case GGML_UNARY_OP_ROUND: From 4dff236a522bd0ed949331d6cb1ee2a1b3615c35 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 10 Dec 2025 20:53:16 +0200 Subject: [PATCH 02/11] ggml : remove GGML_KQ_MASK_PAD constant (#17910) * ggml : remove GGML_KQ_MASK_PAD constant * cont : remove comment --- ggml/include/ggml.h | 12 +++++------- ggml/src/ggml.c | 2 -- src/llama-context.cpp | 8 -------- src/llama-graph.cpp | 20 ++++++++++---------- src/llama-kv-cache.cpp | 5 ++--- tests/test-backend-ops.cpp | 2 +- tools/mtmd/clip.cpp | 6 +----- 7 files changed, 19 insertions(+), 36 deletions(-) diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 6bc762c069..686da3dbd1 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -2305,13 +2305,11 @@ extern "C" { float stop, float step); -#define GGML_KQ_MASK_PAD 1 - - // q: [n_embd_k, n_batch, n_head, ne3 ] - // k: [n_embd_k, n_kv, n_head_kv, ne3 ] - // v: [n_embd_v, n_kv, n_head_kv, ne3 ] !! not transposed !! - // mask: [n_kv, n_batch_pad, ne32, ne33] !! n_batch_pad = GGML_PAD(n_batch, GGML_KQ_MASK_PAD) !! - // res: [n_embd_v, n_head, n_batch, ne3 ] !! permuted !! + // q: [n_embd_k, n_batch, n_head, ne3 ] + // k: [n_embd_k, n_kv, n_head_kv, ne3 ] + // v: [n_embd_v, n_kv, n_head_kv, ne3 ] !! not transposed !! + // mask: [n_kv, n_batch, ne32, ne33] + // res: [n_embd_v, n_head, n_batch, ne3 ] !! permuted !! // // broadcast: // n_head % n_head_kv == 0 diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 530ff7b953..f0913cd359 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -5260,8 +5260,6 @@ struct ggml_tensor * ggml_flash_attn_ext( if (mask) { GGML_ASSERT(ggml_is_contiguous(mask)); - GGML_ASSERT(mask->ne[1] >= GGML_PAD(q->ne[1], GGML_KQ_MASK_PAD) && - "the Flash-Attention kernel requires the mask to be padded to GGML_KQ_MASK_PAD and at least n_queries big"); //GGML_ASSERT(ggml_can_repeat_rows(mask, qk)); GGML_ASSERT(q->ne[2] % mask->ne[2] == 0); diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 4171400713..2692297dca 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -93,14 +93,6 @@ llama_context::llama_context( // with causal attention, the batch size is limited by the context size cparams.n_batch = cparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch; - // the batch has to be at least GGML_KQ_MASK_PAD because we will be padding the KQ_mask - // this is required by GPU kernels in order to avoid out-of-bounds accesses (e.g. 
ggml_flash_attn_ext) - // ref: https://github.com/ggerganov/llama.cpp/pull/5021 - // TODO: this padding is not needed for the cache-less context so we should probably move it to llama_memory - if (cparams.n_batch < GGML_KQ_MASK_PAD) { - LLAMA_LOG_WARN("%s: n_batch is less than GGML_KQ_MASK_PAD - increasing to %d\n", __func__, GGML_KQ_MASK_PAD); - cparams.n_batch = GGML_KQ_MASK_PAD; - } cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch); cparams.op_offload = params.op_offload; diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 43620df780..6cf9a883a6 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -385,7 +385,7 @@ bool llm_graph_input_attn_kv::can_reuse(const llm_graph_params & params) { //res &= self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there res &= self_kq_mask->ne[0] == mctx->get_n_kv(); - res &= self_kq_mask->ne[1] == GGML_PAD(params.ubatch.n_tokens, GGML_KQ_MASK_PAD); + res &= self_kq_mask->ne[1] == params.ubatch.n_tokens; return res; } @@ -416,10 +416,10 @@ bool llm_graph_input_attn_kv_iswa::can_reuse(const llm_graph_params & params) { //res &= self_v_idxs_swa->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there res &= self_kq_mask->ne[0] == mctx->get_base()->get_n_kv(); - res &= self_kq_mask->ne[1] == GGML_PAD(params.ubatch.n_tokens, GGML_KQ_MASK_PAD); + res &= self_kq_mask->ne[1] == params.ubatch.n_tokens; res &= self_kq_mask_swa->ne[0] == mctx->get_swa()->get_n_kv(); - res &= self_kq_mask_swa->ne[1] == GGML_PAD(params.ubatch.n_tokens, GGML_KQ_MASK_PAD); + res &= self_kq_mask_swa->ne[1] == params.ubatch.n_tokens; return res; } @@ -452,7 +452,7 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) { } } - for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) { + for (int i = n_tokens; i < n_tokens; ++i) { for (int j = 0; j < n_enc; ++j) { data[h*(n_enc*n_tokens) + i*n_enc + j] = -INFINITY; } @@ -1470,13 +1470,13 @@ llm_graph_input_attn_no_cache * llm_graph_context::build_attn_inp_no_cache() con auto inp = std::make_unique(hparams, cparams); // note: there is no KV cache, so the number of KV values is equal to the number of tokens in the batch - inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1); + inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, n_tokens, 1, 1); ggml_set_input(inp->self_kq_mask); inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask; if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) { - inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1); + inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, n_tokens, 1, 1); ggml_set_input(inp->self_kq_mask_swa); inp->self_kq_mask_swa_cnv = cparams.flash_attn ? 
ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa; @@ -1558,7 +1558,7 @@ static std::unique_ptr build_attn_inp_kv_impl( inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch); inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch); - inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream); + inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream); ggml_set_input(inp->self_kq_mask); inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask; @@ -1701,7 +1701,7 @@ llm_graph_input_attn_cross * llm_graph_context::build_attn_inp_cross() const { const int32_t n_enc = !cross->v_embd.empty() ? cross->n_enc : hparams.n_ctx_train; - inp->cross_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_enc, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1); + inp->cross_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_enc, n_tokens, 1, 1); ggml_set_input(inp->cross_kq_mask); inp->cross_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->cross_kq_mask, GGML_TYPE_F16) : inp->cross_kq_mask; @@ -1767,7 +1767,7 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const inp->self_k_idxs = mctx_cur->get_base()->build_input_k_idxs(ctx0, ubatch); inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch); - inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream); + inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream); ggml_set_input(inp->self_kq_mask); inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask; @@ -1781,7 +1781,7 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const inp->self_k_idxs_swa = mctx_cur->get_swa()->build_input_k_idxs(ctx0, ubatch); inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch); - inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream); + inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream); ggml_set_input(inp->self_kq_mask_swa); inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa; diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index e26385a1fe..3e02bd6297 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -1232,8 +1232,7 @@ void llama_kv_cache::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * u GGML_ASSERT(n_tokens%n_stream == 0); // n_tps == n_tokens_per_stream - const int64_t n_tps = n_tokens/n_stream; - const int64_t n_tps_pad = GGML_PAD(n_tps, GGML_KQ_MASK_PAD); + const int64_t n_tps = n_tokens/n_stream; std::fill(data, data + ggml_nelements(dst), -INFINITY); @@ -1266,7 +1265,7 @@ void llama_kv_cache::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * u const llama_pos p1_x = is_2d ? ubatch->pos[i + ubatch->n_tokens*2] : 0; const llama_pos p1_y = is_2d ? 
ubatch->pos[i + ubatch->n_tokens] : 0; - const uint64_t idst = n_kv*(h*n_stream*n_tps_pad + s*n_tps_pad + ii); + const uint64_t idst = n_kv*(h*n_stream*n_tps + s*n_tps + ii); for (uint32_t j = 0; j < n_kv; ++j) { if (cells.is_empty(j)) { diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index a6f266601f..7be1f66038 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -5875,7 +5875,7 @@ struct test_flash_attn_ext : public test_case { ggml_tensor * m = nullptr; if (mask) { - m = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, GGML_PAD(nb, GGML_KQ_MASK_PAD), 1, nr23[1]); + m = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, nb, 1, nr23[1]); ggml_set_name(m, "m"); } diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index 3ed08a0fec..e5f7117dbf 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -775,10 +775,6 @@ struct clip_graph { // if flash attn is used, we need to pad the mask and cast to f16 if (ctx->flash_attn_type == CLIP_FLASH_ATTN_TYPE_ENABLED) { - int n_pad = GGML_PAD(window_mask->ne[1], GGML_KQ_MASK_PAD) - window_mask->ne[1]; - if (n_pad > 0) { - window_mask = ggml_pad(ctx0, window_mask, 0, n_pad, 0, 0); - } window_mask = ggml_cast(ctx0, window_mask, GGML_TYPE_F16); } @@ -791,7 +787,7 @@ struct clip_graph { // loop over layers for (int il = 0; il < n_layer; il++) { - auto & layer = model.layers[il]; + const auto & layer = model.layers[il]; const bool full_attn = use_window_attn ? (il + 1) % n_wa_pattern == 0 : true; ggml_tensor * cur = inpL; // inpL = residual, cur = hidden_states From e1f4921980444b30d3b4fa1c6e18cdecd85b0690 Mon Sep 17 00:00:00 2001 From: Max Krasnyansky Date: Wed, 10 Dec 2025 12:32:23 -0800 Subject: [PATCH 03/11] Fix race conditions in threadpool when dealing with dynamic/frequent n_threads changes (#17748) * tests: update barrier test to check for race condition in active threads * cpu: combine n_graph and n_threads into a single atomic update * tests: add multi-graph test for test_barrier --- ggml/src/ggml-cpu/ggml-cpu.c | 73 +++++++-------- tests/test-barrier.cpp | 170 ++++++++++++++++++++++++++++++++--- 2 files changed, 190 insertions(+), 53 deletions(-) diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index b468b115a1..c47511adcb 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -187,6 +187,9 @@ typedef void * thread_ret_t; typedef pthread_t ggml_thread_t; +#define GGML_THREADPOOL_N_THREADS_MASK (0xffffU) +#define GGML_THREADPOOL_N_THREADS_BITS (16) + #if defined(__APPLE__) #include #include @@ -449,7 +452,7 @@ struct ggml_threadpool { struct ggml_cplan * cplan; // synchronization primitives - atomic_int n_graph; // incremented when there is work to be done (i.e each graph) + atomic_int n_graph; // updated when there is work to be done (i.e each graph) holds graph and active thread counts. atomic_int GGML_CACHE_ALIGN n_barrier; atomic_int GGML_CACHE_ALIGN n_barrier_passed; atomic_int GGML_CACHE_ALIGN current_chunk; // currently processing chunk during Mat_Mul, shared between all the threads. 
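Context for the threadpool change above: the race came from `n_graph` and `n_threads_cur` being two separate atomics that workers could observe out of sync. After this patch a single atomic carries both values, using the 16-bit split defined by the `GGML_THREADPOOL_N_THREADS_BITS`/`MASK` macros above. A minimal stand-alone sketch of that packing (illustrative C++ only, not the ggml source; the type and function names are made up):

```cpp
#include <atomic>
#include <cstdint>
#include <utility>

static constexpr uint32_t N_THREADS_BITS = 16;
static constexpr uint32_t N_THREADS_MASK = 0xffffu;

struct packed_graph_state {
    std::atomic<uint32_t> n_graph{0}; // [ graph counter | active thread count ]

    // publisher (main thread): bump the graph counter and set the thread count
    // in one store, so a polling worker never sees a new graph paired with a
    // stale thread count
    void kickoff(uint32_t n_threads) {
        const uint32_t g = (n_graph.load(std::memory_order_relaxed) >> N_THREADS_BITS) + 1;
        n_graph.store((g << N_THREADS_BITS) | (n_threads & N_THREADS_MASK),
                      std::memory_order_seq_cst);
    }

    // worker: one relaxed load yields a consistent (graph, n_threads) pair
    std::pair<uint32_t, uint32_t> snapshot() const {
        const uint32_t v = n_graph.load(std::memory_order_relaxed);
        return { v >> N_THREADS_BITS, v & N_THREADS_MASK };
    }
};
```

Workers compare the whole packed value against their `last_graph`, so a change in either field is enough to signal new work, which is what the hunks below rely on.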
@@ -457,12 +460,10 @@ struct ggml_threadpool { // these are atomic as an annotation for thread-sanitizer atomic_bool stop; // Used for stopping the threadpool altogether atomic_bool pause; // Used for pausing the threadpool or individual threads - atomic_int abort; // Used for aborting processing of a graph + atomic_int abort; // Used for aborting processing of a graph struct ggml_compute_state * workers; // per thread state - int n_threads_max; // number of threads in the pool - atomic_int n_threads_cur; // number of threads used in the current graph - + int n_threads; // Number of threads in the pool int32_t prio; // Scheduling priority uint32_t poll; // Polling level (0 - no polling) @@ -539,7 +540,7 @@ struct ggml_state { static struct ggml_state g_state = {0}; void ggml_barrier(struct ggml_threadpool * tp) { - int n_threads = atomic_load_explicit(&tp->n_threads_cur, memory_order_relaxed); + int n_threads = atomic_load_explicit(&tp->n_graph, memory_order_relaxed) & GGML_THREADPOOL_N_THREADS_MASK; if (n_threads == 1) { return; } @@ -556,7 +557,7 @@ void ggml_barrier(struct ggml_threadpool * tp) { // last thread atomic_store_explicit(&tp->n_barrier, 0, memory_order_relaxed); - // exit barrier (fill seq-cst fence) + // exit barrier (full seq-cst fence) atomic_fetch_add_explicit(&tp->n_barrier_passed, 1, memory_order_seq_cst); return; } @@ -2628,7 +2629,7 @@ static void ggml_thread_cpumask_next(const bool * global_mask, bool * local_mask void ggml_threadpool_free(struct ggml_threadpool* threadpool) { if (!threadpool) return; - const int n_threads = threadpool->n_threads_max; + const int n_threads = threadpool->n_threads; #ifndef GGML_USE_OPENMP struct ggml_compute_state* workers = threadpool->workers; @@ -2704,7 +2705,7 @@ struct ggml_cplan ggml_graph_plan( //GGML_PRINT_DEBUG("Threadpool is not specified. Will create a disposable threadpool : n_threads %d\n", n_threads); } if (n_threads <= 0) { - n_threads = threadpool ? threadpool->n_threads_max : GGML_DEFAULT_N_THREADS; + n_threads = threadpool ? 
threadpool->n_threads : GGML_DEFAULT_N_THREADS; } #if defined(__EMSCRIPTEN__) && !defined(__EMSCRIPTEN_PTHREADS__) @@ -2912,12 +2913,14 @@ static thread_ret_t ggml_graph_compute_thread(void * data) { struct ggml_compute_params params = { /*.ith =*/ state->ith, - /*.nth =*/ atomic_load_explicit(&tp->n_threads_cur, memory_order_relaxed), + /*.nth =*/ atomic_load_explicit(&tp->n_graph, memory_order_relaxed) & GGML_THREADPOOL_N_THREADS_MASK, /*.wsize =*/ cplan->work_size, /*.wdata =*/ cplan->work_data, /*.threadpool=*/ tp, }; + GGML_PRINT_DEBUG("thread #%d compute-start cplan %p last-graph %d \n", state->ith, cplan, state->last_graph); + for (int node_n = 0; node_n < cgraph->n_nodes && atomic_load_explicit(&tp->abort, memory_order_relaxed) != node_n; node_n++) { struct ggml_tensor * node = cgraph->nodes[node_n]; @@ -2939,6 +2942,8 @@ static thread_ret_t ggml_graph_compute_thread(void * data) { } } + GGML_PRINT_DEBUG("thread #%d compute-done cplan %p last-graph %d \n", state->ith, cplan, state->last_graph); + ggml_barrier(state->threadpool); return 0; @@ -2946,27 +2951,23 @@ static thread_ret_t ggml_graph_compute_thread(void * data) { #ifndef GGML_USE_OPENMP -// check if thread is active -static inline bool ggml_graph_compute_thread_active(struct ggml_compute_state * state) { - struct ggml_threadpool * threadpool = state->threadpool; - int n_threads = atomic_load_explicit(&threadpool->n_threads_cur, memory_order_relaxed); - return (state->ith < n_threads); -} - // check if thread is ready to proceed (exit from polling or sleeping) +// returns true if loops should exit, sets state->pending to indicate new work static inline bool ggml_graph_compute_thread_ready(struct ggml_compute_state * state) { struct ggml_threadpool * threadpool = state->threadpool; if (state->pending || threadpool->stop || threadpool->pause) { return true; } // check for new graph/work - int new_graph = atomic_load_explicit(&threadpool->n_graph, memory_order_relaxed); - if (new_graph != state->last_graph) { - state->pending = ggml_graph_compute_thread_active(state); - state->last_graph = new_graph; + int n_graph = atomic_load_explicit(&threadpool->n_graph, memory_order_relaxed); + int n_threads = n_graph & GGML_THREADPOOL_N_THREADS_MASK; + if (n_graph != state->last_graph) { + state->pending = (state->ith < n_threads); + state->last_graph = n_graph; + return true; } - return state->pending; + return false; } // sync thread state after polling @@ -2983,11 +2984,6 @@ static inline void ggml_graph_compute_thread_sync(struct ggml_compute_state * st static inline bool ggml_graph_compute_poll_for_work(struct ggml_compute_state * state) { struct ggml_threadpool * threadpool = state->threadpool; - // Skip polling for unused threads - if (!ggml_graph_compute_thread_active(state)) { - return state->pending; - } - // This seems to make 0 ... 100 a decent range for polling level across modern processors. // Perhaps, we can adjust it dynamically based on load and things. 
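    // (so each unit of `poll` buys 128*1024 busy-wait rounds before the worker
    // gives up spinning and blocks until it is woken for new work;
    // poll == 0 skips spinning entirely)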
const uint64_t n_rounds = 1024UL * 128 * threadpool->poll; @@ -3049,7 +3045,6 @@ static thread_ret_t ggml_graph_compute_secondary_thread(void* data) { ggml_graph_compute_check_for_work(state); if (state->pending) { state->pending = false; - ggml_graph_compute_thread(state); } } @@ -3064,14 +3059,15 @@ static void ggml_graph_compute_kickoff(struct ggml_threadpool * threadpool, int ggml_mutex_lock(&threadpool->mutex); - GGML_PRINT_DEBUG("threadpool: n_threads_cur %d n_threads %d\n", threadpool->n_threads_cur, n_threads); + // Update the number of active threads and the graph count + int n_graph = atomic_load_explicit(&threadpool->n_graph, memory_order_relaxed) >> GGML_THREADPOOL_N_THREADS_BITS; + n_graph = ((n_graph + 1) << GGML_THREADPOOL_N_THREADS_BITS) | (n_threads & GGML_THREADPOOL_N_THREADS_MASK); - // Update the number of active threads - atomic_store_explicit(&threadpool->n_threads_cur, n_threads, memory_order_relaxed); + GGML_PRINT_DEBUG("compute-kickoff: n_threads %d n_graph %d\n", n_threads, n_graph); // Indicate the graph is ready to be processed // We need the full seq-cst fence here because of the polling threads (used in thread_sync) - atomic_fetch_add_explicit(&threadpool->n_graph, 1, memory_order_seq_cst); + atomic_store_explicit(&threadpool->n_graph, n_graph, memory_order_seq_cst); if (threadpool->pause) { // Update main thread prio and affinity to match the threadpool settings @@ -3109,8 +3105,7 @@ static struct ggml_threadpool * ggml_threadpool_new_impl( threadpool->pause = tpp->paused; threadpool->abort = -1; threadpool->workers = NULL; - threadpool->n_threads_max = tpp->n_threads; - threadpool->n_threads_cur = tpp->n_threads; + threadpool->n_threads = tpp->n_threads; threadpool->poll = tpp->poll; threadpool->prio = tpp->prio; threadpool->ec = GGML_STATUS_SUCCESS; @@ -3205,7 +3200,7 @@ enum ggml_status ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cpl { // update the number of threads from the actual number of threads that we got from OpenMP n_threads = omp_get_num_threads(); - atomic_store_explicit(&threadpool->n_threads_cur, n_threads, memory_order_relaxed); + atomic_store_explicit(&threadpool->n_graph, n_threads, memory_order_relaxed); } // Apply thread CPU mask and priority @@ -3218,13 +3213,13 @@ enum ggml_status ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cpl ggml_graph_compute_thread(&threadpool->workers[ith]); } } else { - atomic_store_explicit(&threadpool->n_threads_cur, 1, memory_order_relaxed); + atomic_store_explicit(&threadpool->n_graph, 1, memory_order_relaxed); ggml_graph_compute_thread(&threadpool->workers[0]); } #else - if (n_threads > threadpool->n_threads_max) { - GGML_LOG_WARN("cplan requested more threads (%d) than available (%d)\n", n_threads, threadpool->n_threads_max); - n_threads = threadpool->n_threads_max; + if (n_threads > threadpool->n_threads) { + GGML_LOG_WARN("cplan requested more threads (%d) than available (%d)\n", n_threads, threadpool->n_threads); + n_threads = threadpool->n_threads; } // Kick all threads to start the new graph diff --git a/tests/test-barrier.cpp b/tests/test-barrier.cpp index 04c27761dc..61f73adfd2 100644 --- a/tests/test-barrier.cpp +++ b/tests/test-barrier.cpp @@ -11,19 +11,7 @@ #define MAX_NARGS 2 -int main(int argc, char *argv[]) { - - int n_threads = std::max(1, std::min(4, (int) std::thread::hardware_concurrency())); - int n_rounds = 100; - - if (argc > 1) { - n_threads = std::atoi(argv[1]); - } - - if (argc > 2) { - n_rounds = std::atoi(argv[2]); - } - +static void test_barrier(int 
n_threads, int n_rounds) { struct ggml_init_params params = { /* .mem_size = */ 1024*1024*1024, /* .mem_buffer = */ NULL, @@ -56,7 +44,7 @@ int main(int argc, char *argv[]) { exit(1); } - // Create compute plan + // The test runs with constant number of threads struct ggml_cplan cplan = ggml_graph_plan(gf, n_threads, threadpool); std::vector work_data(cplan.work_size); @@ -89,6 +77,160 @@ int main(int argc, char *argv[]) { ggml_threadpool_free(threadpool); ggml_free(ctx); +} + +static void test_active(int n_threads, int n_rounds) { + struct ggml_init_params params = { + /* .mem_size = */ 1024*1024*1024, + /* .mem_buffer = */ NULL, + /* .no_alloc = */ false, + }; + + struct ggml_context * ctx = ggml_init(params); + + // Create graph + struct ggml_cgraph * gf = ggml_new_graph(ctx); + + // Small graph with, parallel ops with barriers + struct ggml_tensor * out = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 64); + for (int i = 0; i < 2; i++) { + struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, 64, 128); + out = ggml_mul_mat(ctx, a, out); + + struct ggml_tensor * d = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, 128, 64); + out = ggml_mul_mat(ctx, d, out); + } + + ggml_build_forward_expand(gf, out); + int n_nodes = ggml_graph_n_nodes(gf); + + // Create threadpool + struct ggml_threadpool_params tpp = ggml_threadpool_params_default(n_threads); + struct ggml_threadpool* threadpool = ggml_threadpool_new(&tpp); + if (!threadpool) { + fprintf(stderr, "threadpool create failed : n_threads %d\n", n_threads); + exit(1); + } + + std::cerr << "graph-compute with" + << "\n n_threads: " << n_threads + << "\n n_nodes: " << n_nodes + << "\n n_rounds: " << n_rounds + << "\n"; + // ggml_graph_print(gf); + + // In this test we keep changing the number of threads every 4th iteration + // to test for race conditions in that path + + for (int i=0; i < n_rounds; i++) { + struct ggml_cplan cplan = ggml_graph_plan(gf, (i % 4) == 0 ? 
1 : n_threads, threadpool); + + std::vector work_data(cplan.work_size); + cplan.work_data = work_data.data(); + + ggml_graph_compute(gf, &cplan); + } + + ggml_threadpool_free(threadpool); + ggml_free(ctx); +} + +static void test_multi_graph(int n_threads, int n_rounds) { + struct ggml_init_params params = { + /* .mem_size = */ 1024*1024*1024, + /* .mem_buffer = */ NULL, + /* .no_alloc = */ false, + }; + + struct ggml_context * ctx = ggml_init(params); + + // Create graphs + struct ggml_cgraph * gf0 = ggml_new_graph(ctx); + { + // Small graph with parallel ops with barriers + struct ggml_tensor * out = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 64); + for (int i = 0; i < 2; i++) { + struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, 64, 128); + out = ggml_mul_mat(ctx, a, out); + + struct ggml_tensor * d = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, 128, 64); + out = ggml_mul_mat(ctx, d, out); + } + + ggml_build_forward_expand(gf0, out); + } + + struct ggml_cgraph * gf1 = ggml_new_graph(ctx); + { + // Small graph with parallel ops with barriers + // Use larger tensors to make sure work_data size is larger than gf0 + struct ggml_tensor * out = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 256); + for (int i = 0; i < 4; i++) { + struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, 256, 128); + out = ggml_mul_mat(ctx, a, out); + + struct ggml_tensor * d = ggml_new_tensor_2d(ctx, GGML_TYPE_Q4_0, 128, 256); + out = ggml_mul_mat(ctx, d, out); + } + + ggml_build_forward_expand(gf1, out); + } + + + // Create threadpool + struct ggml_threadpool_params tpp = ggml_threadpool_params_default(n_threads); + struct ggml_threadpool* threadpool = ggml_threadpool_new(&tpp); + if (!threadpool) { + fprintf(stderr, "threadpool create failed : n_threads %d\n", n_threads); + exit(1); + } + + std::cerr << "graph-compute with" + << "\n gf0 n_nodes: " << ggml_graph_n_nodes(gf0) + << "\n gf1 n_nodes: " << ggml_graph_n_nodes(gf1) + << "\n n_threads: " << n_threads + << "\n n_rounds: " << n_rounds + << "\n"; + + // In this test we keep changing the number of threads every 4th iteration + // and we compute two graphs back to back to test graph frequent graph switching + + for (int i=0; i < n_rounds; i++) { + struct ggml_cplan cplan0 = ggml_graph_plan(gf0, (i % 4) == 0 ? 1 : n_threads, threadpool); + std::vector work_data0(cplan0.work_size); + cplan0.work_data = work_data0.data(); + + struct ggml_cplan cplan1 = ggml_graph_plan(gf1, (i % 4) == 0 ? 1 : n_threads, threadpool); + std::vector work_data1(cplan1.work_size); + cplan1.work_data = work_data1.data(); + + ggml_graph_compute(gf0, &cplan0); + ggml_graph_compute(gf1, &cplan1); + } + + ggml_threadpool_free(threadpool); + ggml_free(ctx); +} + + +int main(int argc, char *argv[]) { + + int n_threads = std::max(1, std::min(4, (int) std::thread::hardware_concurrency())); + int n_rounds = 100; + + if (argc > 1) { + n_threads = std::atoi(argv[1]); + } + + if (argc > 2) { + n_rounds = std::atoi(argv[2]); + } + + test_barrier(n_threads, n_rounds); + + test_active(n_threads, n_rounds * 100); + + test_multi_graph(n_threads, n_rounds * 10); return 0; } From f32ca51bfeb74b8d1b735c2744c0e4a6224f6b7c Mon Sep 17 00:00:00 2001 From: Pascal Date: Wed, 10 Dec 2025 22:18:21 +0100 Subject: [PATCH 04/11] server: add presets (config) when using multiple models (#17859) * llama-server: recursive GGUF loading Replace flat directory scan with recursive traversal using std::filesystem::recursive_directory_iterator. Support for nested vendor/model layouts (e.g. vendor/model/*.gguf). 
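As a rough sketch of the traversal described here (illustrative only; the real logic lives in `list_local_models()` in server-models.cpp, and the helper name and return type below are made up), a C++17 scan that groups `*.gguf` files by their directory relative to `--models-dir` could look like:

```cpp
#include <filesystem>
#include <map>
#include <string>
#include <vector>

namespace fs = std::filesystem;

// collect *.gguf files under `root`, keyed by their parent directory
// expressed relative to the models root (e.g. "vendor/model")
static std::map<std::string, std::vector<fs::path>> scan_gguf(const fs::path & root) {
    std::map<std::string, std::vector<fs::path>> by_dir;
    for (const auto & entry : fs::recursive_directory_iterator(root)) {
        if (!entry.is_regular_file() || entry.path().extension() != ".gguf") {
            continue;
        }
        const std::string rel = fs::relative(entry.path().parent_path(), root).generic_string();
        by_dir[rel].push_back(entry.path());
    }
    return by_dir;
}
```

That relative key is the basis for the model name described next.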
Model name now reflects the relative path within --models-dir instead of just the filename. Aggregate files by parent directory via std::map before constructing local_model * server : router config POC (INI-based per-model settings) * server: address review feedback from @aldehir and @ngxson PEG parser usage improvements: - Simplify parser instantiation (remove arena indirection) - Optimize grammar usage (ws instead of zero_or_more, remove optional wrapping) - Fix last line without newline bug (+ operator instead of <<) - Remove redundant end position check Feature scope: - Remove auto-reload feature (will be separate PR per @ngxson) - Keep config.ini auto-creation and template generation - Preserve per-model customization logic Co-authored-by: aldehir Co-authored-by: ngxson * server: adopt aldehir's line-oriented PEG parser Complete rewrite of INI parser grammar and visitor: - Use p.chars(), p.negate(), p.any() instead of p.until() - Support end-of-line comments (key=value # comment) - Handle EOF without trailing newline correctly - Strict identifier validation ([a-zA-Z_][a-zA-Z0-9_.-]*) - Simplified visitor (no pending state, no trim needed) - Grammar handles whitespace natively via eol rule Business validation preserved: - Reject section names starting with LLAMA_ARG_* - Accept only keys starting with LLAMA_ARG_* - Require explicit section before key-value pairs Co-authored-by: aldehir * server: fix CLI/env duplication in child processes Children now receive minimal CLI args (executable, model, port, alias) instead of inheriting all router args. Global settings pass through LLAMA_ARG_* environment variables only, eliminating duplicate config warnings. Fixes: Router args like -ngl, -fa were passed both via CLI and env, causing 'will be overwritten' warnings on every child spawn * add common/preset.cpp * fix compile * cont * allow custom-path models * add falsey check * server: fix router model discovery and child process spawning - Sanitize model names: replace / and \ with _ for display - Recursive directory scan with relative path storage - Convert relative paths to absolute when spawning children - Filter router control args from child processes - Refresh args after port assignment for correct port value - Fallback preset lookup for compatibility - Fix missing argv[0]: store server binary path before base_args parsing * Revert "server: fix router model discovery and child process spawning" This reverts commit e3832b42eeea7fcb108995966c7584479f745857. 
* clarify about "no-" prefix * correct render_args() to include binary path * also remove arg LLAMA_ARG_MODELS_PRESET for child * add co-author for ini parser code Co-authored-by: aldehir * also set LLAMA_ARG_HOST * add CHILD_ADDR * Remove dead code --------- Co-authored-by: aldehir Co-authored-by: ngxson Co-authored-by: Xuan Son Nguyen Co-authored-by: aldehir --- common/CMakeLists.txt | 2 + common/arg.cpp | 72 ++++++++- common/arg.h | 32 +++- common/common.h | 7 +- common/preset.cpp | 180 ++++++++++++++++++++++ common/preset.h | 32 ++++ tools/server/CMakeLists.txt | 8 + tools/server/README.md | 50 ++++++ tools/server/server-models.cpp | 267 ++++++++++++++++++++++----------- tools/server/server-models.h | 27 +++- 10 files changed, 580 insertions(+), 97 deletions(-) create mode 100644 common/preset.cpp create mode 100644 common/preset.h diff --git a/common/CMakeLists.txt b/common/CMakeLists.txt index 377b26846b..0182767c2b 100644 --- a/common/CMakeLists.txt +++ b/common/CMakeLists.txt @@ -73,6 +73,8 @@ add_library(${TARGET} STATIC ngram-cache.h peg-parser.cpp peg-parser.h + preset.cpp + preset.h regex-partial.cpp regex-partial.h sampling.cpp diff --git a/common/arg.cpp b/common/arg.cpp index 210ef8d621..b333f45c96 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -47,6 +47,7 @@ #define LLAMA_MAX_URL_LENGTH 2084 // Maximum URL Length in Chrome: 2083 using json = nlohmann::ordered_json; +using namespace common_arg_utils; static std::initializer_list mmproj_examples = { LLAMA_EXAMPLE_MTMD, @@ -64,6 +65,15 @@ static std::string read_file(const std::string & fname) { return content; } +static const std::vector & get_common_arg_defs() { + static const std::vector options = [] { + common_params params; + auto ctx = common_params_parser_init(params, LLAMA_EXAMPLE_SERVER, nullptr); + return ctx.options; + }(); + return options; +} + common_arg & common_arg::set_examples(std::initializer_list examples) { this->examples = examples; return *this; @@ -134,7 +144,7 @@ static std::vector break_str_into_lines(std::string input, size_t m return result; } -std::string common_arg::to_string() { +std::string common_arg::to_string() const { // params for printing to console const static int n_leading_spaces = 40; const static int n_char_per_line_help = 70; // TODO: detect this based on current console @@ -647,6 +657,53 @@ static void add_rpc_devices(const std::string & servers) { } } +bool common_params_parse(int argc, char ** argv, llama_example ex, std::map & out_map) { + common_params dummy_params; + common_params_context ctx_arg = common_params_parser_init(dummy_params, ex, nullptr); + + std::unordered_map arg_to_options; + for (auto & opt : ctx_arg.options) { + for (const auto & arg : opt.args) { + arg_to_options[arg] = &opt; + } + } + + // TODO @ngxson : find a way to deduplicate this code + + // handle command line arguments + auto check_arg = [&](int i) { + if (i+1 >= argc) { + throw std::invalid_argument("expected value for argument"); + } + }; + + for (int i = 1; i < argc; i++) { + const std::string arg_prefix = "--"; + + std::string arg = argv[i]; + if (arg.compare(0, arg_prefix.size(), arg_prefix) == 0) { + std::replace(arg.begin(), arg.end(), '_', '-'); + } + if (arg_to_options.find(arg) == arg_to_options.end()) { + throw std::invalid_argument(string_format("error: invalid argument: %s", arg.c_str())); + } + auto opt = *arg_to_options[arg]; + std::string val; + if (opt.value_hint != nullptr) { + // arg with single value + check_arg(i); + val = argv[++i]; + } + if (opt.value_hint_2 != nullptr) { + 
// TODO: support arg with 2 values + throw std::invalid_argument("error: argument with 2 values is not yet supported\n"); + } + out_map[opt] = val; + } + + return true; +} + bool common_params_parse(int argc, char ** argv, common_params & params, llama_example ex, void(*print_usage)(int, char **)) { auto ctx_arg = common_params_parser_init(params, ex, print_usage); const common_params params_org = ctx_arg.params; // the example can modify the default params @@ -692,15 +749,15 @@ static std::string list_builtin_chat_templates() { return msg.str(); } -static bool is_truthy(const std::string & value) { +bool common_arg_utils::is_truthy(const std::string & value) { return value == "on" || value == "enabled" || value == "1"; } -static bool is_falsey(const std::string & value) { +bool common_arg_utils::is_falsey(const std::string & value) { return value == "off" || value == "disabled" || value == "0"; } -static bool is_autoy(const std::string & value) { +bool common_arg_utils::is_autoy(const std::string & value) { return value == "auto" || value == "-1"; } @@ -2543,6 +2600,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex params.models_dir = value; } ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_MODELS_DIR")); + add_opt(common_arg( + {"--models-preset"}, "PATH", + "path to INI file containing model presets for the router server (default: disabled)", + [](common_params & params, const std::string & value) { + params.models_preset = value; + } + ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_MODELS_PRESET")); add_opt(common_arg( {"--models-max"}, "N", string_format("for router server, maximum number of models to load simultaneously (default: %d, 0 = unlimited)", params.models_max), diff --git a/common/arg.h b/common/arg.h index 7ab7e2cea4..219c115e63 100644 --- a/common/arg.h +++ b/common/arg.h @@ -3,8 +3,10 @@ #include "common.h" #include +#include #include #include +#include // // CLI argument parsing @@ -24,6 +26,8 @@ struct common_arg { void (*handler_str_str)(common_params & params, const std::string &, const std::string &) = nullptr; void (*handler_int) (common_params & params, int) = nullptr; + common_arg() = default; + common_arg( const std::initializer_list & args, const char * value_hint, @@ -61,9 +65,29 @@ struct common_arg { bool is_exclude(enum llama_example ex); bool get_value_from_env(std::string & output) const; bool has_value_from_env() const; - std::string to_string(); + std::string to_string() const; + + // for using as key in std::map + bool operator<(const common_arg& other) const { + if (args.empty() || other.args.empty()) { + return false; + } + return strcmp(args[0], other.args[0]) < 0; + } + bool operator==(const common_arg& other) const { + if (args.empty() || other.args.empty()) { + return false; + } + return strcmp(args[0], other.args[0]) == 0; + } }; +namespace common_arg_utils { + bool is_truthy(const std::string & value); + bool is_falsey(const std::string & value); + bool is_autoy(const std::string & value); +} + struct common_params_context { enum llama_example ex = LLAMA_EXAMPLE_COMMON; common_params & params; @@ -76,7 +100,11 @@ struct common_params_context { // if one argument has invalid value, it will automatically display usage of the specific argument (and not the full usage message) bool common_params_parse(int argc, char ** argv, common_params & params, llama_example ex, void(*print_usage)(int, char **) = nullptr); -// function to be used by test-arg-parser +// parse input arguments from CLI into 
a map +// TODO: support repeated args in the future +bool common_params_parse(int argc, char ** argv, llama_example ex, std::map & out_map); + +// initialize argument parser context - used by test-arg-parser and preset common_params_context common_params_parser_init(common_params & params, llama_example ex, void(*print_usage)(int, char **) = nullptr); struct common_remote_params { diff --git a/common/common.h b/common/common.h index ad79f5b425..6119adcc0f 100644 --- a/common/common.h +++ b/common/common.h @@ -484,9 +484,10 @@ struct common_params { bool endpoint_metrics = false; // router server configs - std::string models_dir = ""; // directory containing models for the router server - int models_max = 4; // maximum number of models to load simultaneously - bool models_autoload = true; // automatically load models when requested via the router server + std::string models_dir = ""; // directory containing models for the router server + std::string models_preset = ""; // directory containing model presets for the router server + int models_max = 4; // maximum number of models to load simultaneously + bool models_autoload = true; // automatically load models when requested via the router server bool log_json = false; diff --git a/common/preset.cpp b/common/preset.cpp new file mode 100644 index 0000000000..09ac171b72 --- /dev/null +++ b/common/preset.cpp @@ -0,0 +1,180 @@ +#include "arg.h" +#include "preset.h" +#include "peg-parser.h" +#include "log.h" + +#include +#include +#include + +static std::string rm_leading_dashes(const std::string & str) { + size_t pos = 0; + while (pos < str.size() && str[pos] == '-') { + ++pos; + } + return str.substr(pos); +} + +std::vector common_preset::to_args() const { + std::vector args; + + for (const auto & [opt, value] : options) { + args.push_back(opt.args.back()); // use the last arg as the main arg + if (opt.value_hint == nullptr && opt.value_hint_2 == nullptr) { + // flag option, no value + if (common_arg_utils::is_falsey(value)) { + // skip the flag + args.pop_back(); + } + } + if (opt.value_hint != nullptr) { + // single value + args.push_back(value); + } + if (opt.value_hint != nullptr && opt.value_hint_2 != nullptr) { + throw std::runtime_error(string_format( + "common_preset::to_args(): option '%s' has two values, which is not supported yet", + opt.args.back() + )); + } + } + + return args; +} + +std::string common_preset::to_ini() const { + std::ostringstream ss; + + ss << "[" << name << "]\n"; + for (const auto & [opt, value] : options) { + auto espaced_value = value; + string_replace_all(espaced_value, "\n", "\\\n"); + ss << rm_leading_dashes(opt.args.back()) << " = "; + ss << espaced_value << "\n"; + } + ss << "\n"; + + return ss.str(); +} + +static std::map> parse_ini_from_file(const std::string & path) { + std::map> parsed; + + if (!std::filesystem::exists(path)) { + throw std::runtime_error("preset file does not exist: " + path); + } + + std::ifstream file(path); + if (!file.good()) { + throw std::runtime_error("failed to open server preset file: " + path); + } + + std::string contents((std::istreambuf_iterator(file)), std::istreambuf_iterator()); + + static const auto parser = build_peg_parser([](auto & p) { + // newline ::= "\r\n" / "\n" / "\r" + auto newline = p.rule("newline", p.literal("\r\n") | p.literal("\n") | p.literal("\r")); + + // ws ::= [ \t]* + auto ws = p.rule("ws", p.chars("[ \t]", 0, -1)); + + // comment ::= [;#] (!newline .)* + auto comment = p.rule("comment", p.chars("[;#]", 1, 1) + p.zero_or_more(p.negate(newline) + 
p.any())); + + // eol ::= ws comment? (newline / EOF) + auto eol = p.rule("eol", ws + p.optional(comment) + (newline | p.end())); + + // ident ::= [a-zA-Z_] [a-zA-Z0-9_.-]* + auto ident = p.rule("ident", p.chars("[a-zA-Z_]", 1, 1) + p.chars("[a-zA-Z0-9_.-]", 0, -1)); + + // value ::= (!eol-start .)* + auto eol_start = p.rule("eol-start", ws + (p.chars("[;#]", 1, 1) | newline | p.end())); + auto value = p.rule("value", p.zero_or_more(p.negate(eol_start) + p.any())); + + // header-line ::= "[" ws ident ws "]" eol + auto header_line = p.rule("header-line", "[" + ws + p.tag("section-name", p.chars("[^]]")) + ws + "]" + eol); + + // kv-line ::= ident ws "=" ws value eol + auto kv_line = p.rule("kv-line", p.tag("key", ident) + ws + "=" + ws + p.tag("value", value) + eol); + + // comment-line ::= ws comment (newline / EOF) + auto comment_line = p.rule("comment-line", ws + comment + (newline | p.end())); + + // blank-line ::= ws (newline / EOF) + auto blank_line = p.rule("blank-line", ws + (newline | p.end())); + + // line ::= header-line / kv-line / comment-line / blank-line + auto line = p.rule("line", header_line | kv_line | comment_line | blank_line); + + // ini ::= line* EOF + auto ini = p.rule("ini", p.zero_or_more(line) + p.end()); + + return ini; + }); + + common_peg_parse_context ctx(contents); + const auto result = parser.parse(ctx); + if (!result.success()) { + throw std::runtime_error("failed to parse server config file: " + path); + } + + std::string current_section = COMMON_PRESET_DEFAULT_NAME; + std::string current_key; + + ctx.ast.visit(result, [&](const auto & node) { + if (node.tag == "section-name") { + const std::string section = std::string(node.text); + current_section = section; + parsed[current_section] = {}; + } else if (node.tag == "key") { + const std::string key = std::string(node.text); + current_key = key; + } else if (node.tag == "value" && !current_key.empty() && !current_section.empty()) { + parsed[current_section][current_key] = std::string(node.text); + current_key.clear(); + } + }); + + return parsed; +} + +static std::map get_map_key_opt(common_params_context & ctx_params) { + std::map mapping; + for (const auto & opt : ctx_params.options) { + if (opt.env != nullptr) { + mapping[opt.env] = opt; + } + for (const auto & arg : opt.args) { + mapping[rm_leading_dashes(arg)] = opt; + } + } + return mapping; +} + +common_presets common_presets_load(const std::string & path, common_params_context & ctx_params) { + common_presets out; + auto key_to_opt = get_map_key_opt(ctx_params); + auto ini_data = parse_ini_from_file(path); + + for (auto section : ini_data) { + common_preset preset; + if (section.first.empty()) { + preset.name = COMMON_PRESET_DEFAULT_NAME; + } else { + preset.name = section.first; + } + LOG_DBG("loading preset: %s\n", preset.name.c_str()); + for (const auto & [key, value] : section.second) { + LOG_DBG("option: %s = %s\n", key.c_str(), value.c_str()); + if (key_to_opt.find(key) != key_to_opt.end()) { + preset.options[key_to_opt[key]] = value; + LOG_DBG("accepted option: %s = %s\n", key.c_str(), value.c_str()); + } else { + // TODO: maybe warn about unknown key? 
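                // note: get_map_key_opt() indexes each option by its env-var
                // name and by every dashless argument spelling, so a key that
                // reaches this branch is simply not a recognized option for
                // the current example and is silently dropped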
+ } + } + out[preset.name] = preset; + } + + return out; +} diff --git a/common/preset.h b/common/preset.h new file mode 100644 index 0000000000..dceb849eb8 --- /dev/null +++ b/common/preset.h @@ -0,0 +1,32 @@ +#pragma once + +#include "common.h" +#include "arg.h" + +#include +#include +#include + +// +// INI preset parser and writer +// + +constexpr const char * COMMON_PRESET_DEFAULT_NAME = "default"; + +struct common_preset { + std::string name; + // TODO: support repeated args in the future + std::map options; + + // convert preset to CLI argument list + std::vector to_args() const; + + // convert preset to INI format string + std::string to_ini() const; + + // TODO: maybe implement to_env() if needed +}; + +// interface for multiple presets in one file +using common_presets = std::map; +common_presets common_presets_load(const std::string & path, common_params_context & ctx_params); diff --git a/tools/server/CMakeLists.txt b/tools/server/CMakeLists.txt index a39b4c5b35..ae1a497be6 100644 --- a/tools/server/CMakeLists.txt +++ b/tools/server/CMakeLists.txt @@ -38,6 +38,14 @@ set(TARGET_SRCS server-http.h server-models.cpp server-models.h + server-task.cpp + server-task.h + server-queue.cpp + server-queue.h + server-common.cpp + server-common.h + server-context.cpp + server-context.h ) set(PUBLIC_ASSETS index.html.gz diff --git a/tools/server/README.md b/tools/server/README.md index f98fb44c7b..d6b9b87dcf 100644 --- a/tools/server/README.md +++ b/tools/server/README.md @@ -1369,6 +1369,11 @@ llama-server ### Model sources +There are 3 possible sources for model files: +1. Cached models (controlled by the `LLAMA_CACHE` environment variable) +2. Custom model directory (set via the `--models-dir` argument) +3. Custom preset (set via the `--models-preset` argument) + By default, the router looks for models in the cache. You can add Hugging Face models to the cache with: ```sh @@ -1413,6 +1418,51 @@ llama-server -ctx 8192 -n 1024 -np 2 Note: model instances inherit both command line arguments and environment variables from the router server. +Alternatively, you can also add GGUF based preset (see next section) + +### Model presets + +Model presets allow advanced users to define custom configurations using an `.ini` file: + +```sh +llama-server --models-preset ./my-models.ini +``` + +Each section in the file defines a new preset. Keys within a section correspond to command-line arguments (without leading dashes). For example, the argument `--n-gpu-layer 123` is written as `n-gpu-layer = 123`. + +Short argument forms (e.g., `c`, `ngl`) and environment variable names (e.g., `LLAMA_ARG_N_GPU_LAYERS`) are also supported as keys. 
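For instance, the following keys are interchangeable spellings of the same option (the value here is only an illustration):

```ini
c = 4096
ctx-size = 4096
LLAMA_ARG_CTX_SIZE = 4096
```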
+ +Example: + +```ini +version = 1 + +; If the key corresponds to an existing model on the server, +; this will be used as the default config for that model +[ggml-org/MY-MODEL-GGUF:Q8_0] +; string value +chat-template = chatml +; numeric value +n-gpu-layer = 123 +; flag value (for certain flags, you need to use the "no-" prefix for negation) +jinja = true +; shorthand argument (for example, context size) +c = 4096 +; environment variable name +LLAMA_ARG_CACHE_RAM = 0 +; file paths are relative to server's CWD +model-draft = ./my-models/draft.gguf +; but it's RECOMMENDED to use absolute path +model-draft = /Users/abc/my-models/draft.gguf + +; If the key does NOT correspond to an existing model, +; you need to specify at least the model path +[custom_model] +model = /Users/abc/my-awesome-model-Q4_K_M.gguf +``` + +Note: some arguments are controlled by router (e.g., host, port, API key, HF repo, model alias). They will be removed or overwritten upload loading. + ### Routing requests Requests are routed according to the requested model name. diff --git a/tools/server/server-models.cpp b/tools/server/server-models.cpp index 6f88e93c4b..6c618a673c 100644 --- a/tools/server/server-models.cpp +++ b/tools/server/server-models.cpp @@ -1,6 +1,7 @@ #include "server-common.h" #include "server-models.h" +#include "preset.h" #include "download.h" #include // TODO: remove this once we use HTTP client from download.h @@ -33,6 +34,10 @@ #define CMD_EXIT "exit" +// address for child process, this is needed because router may run on 0.0.0.0 +// ref: https://github.com/ggml-org/llama.cpp/issues/17862 +#define CHILD_ADDR "127.0.0.1" + static std::filesystem::path get_server_exec_path() { #if defined(_WIN32) wchar_t buf[32768] = { 0 }; // Large buffer to handle long paths @@ -132,6 +137,93 @@ static std::vector list_local_models(const std::string & dir) { return models; } +// +// server_presets +// + + +server_presets::server_presets(int argc, char ** argv, common_params & base_params, const std::string & presets_path) + : ctx_params(common_params_parser_init(base_params, LLAMA_EXAMPLE_SERVER)) { + if (!presets_path.empty()) { + presets = common_presets_load(presets_path, ctx_params); + SRV_INF("Loaded %zu presets from %s\n", presets.size(), presets_path.c_str()); + } + + // populate reserved args (will be appended by the router) + for (auto & opt : ctx_params.options) { + if (opt.env == nullptr) { + continue; + } + std::string env = opt.env; + if (env == "LLAMA_ARG_PORT" || + env == "LLAMA_ARG_HOST" || + env == "LLAMA_ARG_ALIAS" || + env == "LLAMA_ARG_API_KEY" || + env == "LLAMA_ARG_MODELS_DIR" || + env == "LLAMA_ARG_MODELS_MAX" || + env == "LLAMA_ARG_MODELS_PRESET" || + env == "LLAMA_ARG_MODEL" || + env == "LLAMA_ARG_MMPROJ" || + env == "LLAMA_ARG_HF_REPO" || + env == "LLAMA_ARG_NO_MODELS_AUTOLOAD") { + control_args[env] = opt; + } + } + + // read base args from router's argv + common_params_parse(argc, argv, LLAMA_EXAMPLE_SERVER, base_args); + + // remove any router-controlled args from base_args + for (const auto & cargs : control_args) { + auto it = base_args.find(cargs.second); + if (it != base_args.end()) { + base_args.erase(it); + } + } +} + +common_preset server_presets::get_preset(const std::string & name) { + auto it = presets.find(name); + if (it != presets.end()) { + return it->second; + } + return common_preset(); +} + +void server_presets::render_args(server_model_meta & meta) { + common_preset preset = meta.preset; // copy + // merging 3 kinds of args: + // 1. 
model-specific args (from preset) + // force removing control args if any + for (auto & cargs : control_args) { + if (preset.options.find(cargs.second) != preset.options.end()) { + SRV_WRN("Preset '%s' contains reserved arg '%s', removing it\n", preset.name.c_str(), cargs.second.args[0]); + preset.options.erase(cargs.second); + } + } + // 2. base args (from router) + // inherit from base args + for (const auto & [arg, value] : base_args) { + preset.options[arg] = value; + } + // 3. control args (from router) + // set control values + preset.options[control_args["LLAMA_ARG_HOST"]] = CHILD_ADDR; + preset.options[control_args["LLAMA_ARG_PORT"]] = std::to_string(meta.port); + preset.options[control_args["LLAMA_ARG_ALIAS"]] = meta.name; + if (meta.in_cache) { + preset.options[control_args["LLAMA_ARG_HF_REPO"]] = meta.name; + } else { + preset.options[control_args["LLAMA_ARG_MODEL"]] = meta.path; + if (!meta.path_mmproj.empty()) { + preset.options[control_args["LLAMA_ARG_MMPROJ"]] = meta.path_mmproj; + } + } + meta.args = preset.to_args(); + // add back the binary path at the front + meta.args.insert(meta.args.begin(), get_server_exec_path().string()); +} + // // server_models // @@ -140,7 +232,7 @@ server_models::server_models( const common_params & params, int argc, char ** argv, - char ** envp) : base_params(params) { + char ** envp) : base_params(params), presets(argc, argv, base_params, params.models_preset) { for (int i = 0; i < argc; i++) { base_args.push_back(std::string(argv[i])); } @@ -155,11 +247,58 @@ server_models::server_models( LOG_WRN("failed to get server executable path: %s\n", e.what()); LOG_WRN("using original argv[0] as fallback: %s\n", base_args[0].c_str()); } - // TODO: allow refreshing cached model list - // add cached models + load_models(); +} + +void server_models::add_model(server_model_meta && meta) { + if (mapping.find(meta.name) != mapping.end()) { + throw std::runtime_error(string_format("model '%s' appears multiple times", meta.name.c_str())); + } + presets.render_args(meta); // populate meta.args + std::string name = meta.name; + mapping[name] = instance_t{ + /* subproc */ std::make_shared(), + /* th */ std::thread(), + /* meta */ std::move(meta) + }; +} + +static std::vector list_custom_path_models(server_presets & presets) { + // detect any custom-path models in presets + std::vector custom_models; + for (auto & [model_name, preset] : presets.presets) { + local_model model; + model.name = model_name; + std::vector to_erase; + for (auto & [arg, value] : preset.options) { + std::string env(arg.env ? arg.env : ""); + if (env == "LLAMA_ARG_MODEL") { + model.path = value; + to_erase.push_back(arg); + } + if (env == "LLAMA_ARG_MMPROJ") { + model.path_mmproj = value; + to_erase.push_back(arg); + } + } + for (auto & arg : to_erase) { + preset.options.erase(arg); + } + if (!model.name.empty() && !model.path.empty()) { + custom_models.push_back(model); + } + } + return custom_models; +} + +// TODO: allow refreshing cached model list +void server_models::load_models() { + // loading models from 3 sources: + // 1. 
cached models auto cached_models = common_list_cached_models(); for (const auto & model : cached_models) { server_model_meta meta{ + /* preset */ presets.get_preset(model.to_string()), /* name */ model.to_string(), /* path */ model.manifest_path, /* path_mmproj */ "", // auto-detected when loading @@ -170,21 +309,18 @@ server_models::server_models( /* args */ std::vector(), /* exit_code */ 0 }; - mapping[meta.name] = instance_t{ - /* subproc */ std::make_shared(), - /* th */ std::thread(), - /* meta */ meta - }; + add_model(std::move(meta)); } - // add local models specificed via --models-dir - if (!params.models_dir.empty()) { - auto local_models = list_local_models(params.models_dir); + // 2. local models specificed via --models-dir + if (!base_params.models_dir.empty()) { + auto local_models = list_local_models(base_params.models_dir); for (const auto & model : local_models) { if (mapping.find(model.name) != mapping.end()) { // already exists in cached models, skip continue; } server_model_meta meta{ + /* preset */ presets.get_preset(model.name), /* name */ model.name, /* path */ model.path, /* path_mmproj */ model.path_mmproj, @@ -195,13 +331,31 @@ server_models::server_models( /* args */ std::vector(), /* exit_code */ 0 }; - mapping[meta.name] = instance_t{ - /* subproc */ std::make_shared(), - /* th */ std::thread(), - /* meta */ meta - }; + add_model(std::move(meta)); } } + // 3. custom-path models specified in presets + auto custom_models = list_custom_path_models(presets); + for (const auto & model : custom_models) { + server_model_meta meta{ + /* preset */ presets.get_preset(model.name), + /* name */ model.name, + /* path */ model.path, + /* path_mmproj */ model.path_mmproj, + /* in_cache */ false, + /* port */ 0, + /* status */ SERVER_MODEL_STATUS_UNLOADED, + /* last_used */ 0, + /* args */ std::vector(), + /* exit_code */ 0 + }; + add_model(std::move(meta)); + } + // log available models + SRV_INF("Available models (%zu) (*: custom preset)\n", mapping.size()); + for (const auto & [name, inst] : mapping) { + SRV_INF(" %c %s\n", inst.meta.preset.name.empty() ? 
' ' : '*', name.c_str()); + } } void server_models::update_meta(const std::string & name, const server_model_meta & meta) { @@ -335,19 +489,7 @@ void server_models::unload_lru() { } } -static void add_or_replace_arg(std::vector & args, const std::string & key, const std::string & value) { - for (size_t i = 0; i < args.size(); i++) { - if (args[i] == key && i + 1 < args.size()) { - args[i + 1] = value; - return; - } - } - // not found, append - args.push_back(key); - args.push_back(value); -} - -void server_models::load(const std::string & name, bool auto_load) { +void server_models::load(const std::string & name) { if (!has_model(name)) { throw std::runtime_error("model name=" + name + " is not found"); } @@ -376,26 +518,10 @@ void server_models::load(const std::string & name, bool auto_load) { { SRV_INF("spawning server instance with name=%s on port %d\n", inst.meta.name.c_str(), inst.meta.port); - std::vector child_args; - if (auto_load && !meta.args.empty()) { - child_args = meta.args; // copy previous args - } else { - child_args = base_args; // copy - if (inst.meta.in_cache) { - add_or_replace_arg(child_args, "-hf", inst.meta.name); - } else { - add_or_replace_arg(child_args, "-m", inst.meta.path); - if (!inst.meta.path_mmproj.empty()) { - add_or_replace_arg(child_args, "--mmproj", inst.meta.path_mmproj); - } - } - } + presets.render_args(inst.meta); // update meta.args - // set model args - add_or_replace_arg(child_args, "--port", std::to_string(inst.meta.port)); - add_or_replace_arg(child_args, "--alias", inst.meta.name); - - std::vector child_env = base_env; // copy + std::vector child_args = inst.meta.args; // copy + std::vector child_env = base_env; // copy child_env.push_back("LLAMA_SERVER_ROUTER_PORT=" + std::to_string(base_params.port)); SRV_INF("%s", "spawning server instance with args:\n"); @@ -541,7 +667,7 @@ bool server_models::ensure_model_loaded(const std::string & name) { } if (meta->status == SERVER_MODEL_STATUS_UNLOADED) { SRV_INF("model name=%s is not loaded, loading...\n", name.c_str()); - load(name, true); + load(name); } SRV_INF("waiting until model name=%s is fully loaded...\n", name.c_str()); @@ -571,7 +697,7 @@ server_http_res_ptr server_models::proxy_request(const server_http_req & req, co SRV_INF("proxying request to model %s on port %d\n", name.c_str(), meta->port); auto proxy = std::make_unique( method, - base_params.hostname, + CHILD_ADDR, meta->port, req.path, req.headers, @@ -724,38 +850,6 @@ void server_models_routes::init_routes() { return models.proxy_request(req, method, name, true); // update last usage for POST request only }; - this->get_router_models = [this](const server_http_req &) { - auto res = std::make_unique(); - json models_json = json::array(); - auto all_models = models.get_all_meta(); - std::time_t t = std::time(0); - for (const auto & meta : all_models) { - json status { - {"value", server_model_status_to_string(meta.status)}, - {"args", meta.args}, - }; - if (meta.is_failed()) { - status["exit_code"] = meta.exit_code; - status["failed"] = true; - } - models_json.push_back(json { - {"id", meta.name}, - {"object", "model"}, // for OAI-compat - {"owned_by", "llamacpp"}, // for OAI-compat - {"created", t}, // for OAI-compat - {"in_cache", meta.in_cache}, - {"path", meta.path}, - {"status", status}, - // TODO: add other fields, may require reading GGUF metadata - }); - } - res_ok(res, { - {"data", models_json}, - {"object", "list"}, - }); - return res; - }; - this->post_router_models_load = [this](const server_http_req & req) { auto res = 
std::make_unique(); json body = json::parse(req.body); @@ -769,7 +863,7 @@ void server_models_routes::init_routes() { res_err(res, format_error_response("model is already loaded", ERROR_TYPE_INVALID_REQUEST)); return res; } - models.load(name, false); + models.load(name); res_ok(res, {{"success", true}}); return res; }; @@ -793,9 +887,12 @@ void server_models_routes::init_routes() { std::time_t t = std::time(0); for (const auto & meta : all_models) { json status { - {"value", server_model_status_to_string(meta.status)}, - {"args", meta.args}, + {"value", server_model_status_to_string(meta.status)}, + {"args", meta.args}, }; + if (!meta.preset.name.empty()) { + status["preset"] = meta.preset.to_ini(); + } if (meta.is_failed()) { status["exit_code"] = meta.exit_code; status["failed"] = true; diff --git a/tools/server/server-models.h b/tools/server/server-models.h index 526e7488dc..9cdbbad9b6 100644 --- a/tools/server/server-models.h +++ b/tools/server/server-models.h @@ -1,6 +1,7 @@ #pragma once #include "common.h" +#include "preset.h" #include "server-http.h" #include @@ -47,6 +48,7 @@ static std::string server_model_status_to_string(server_model_status status) { } struct server_model_meta { + common_preset preset; std::string name; std::string path; std::string path_mmproj; // only available if in_cache=false @@ -54,7 +56,7 @@ struct server_model_meta { int port = 0; server_model_status status = SERVER_MODEL_STATUS_UNLOADED; int64_t last_used = 0; // for LRU unloading - std::vector args; // additional args passed to the model instance (used for debugging) + std::vector args; // args passed to the model instance, will be populated by render_args() int exit_code = 0; // exit code of the model instance process (only valid if status == FAILED) bool is_active() const { @@ -66,6 +68,19 @@ struct server_model_meta { } }; +// the server_presets struct holds the presets read from presets.ini +// as well as base args from the router server +struct server_presets { + common_presets presets; + common_params_context ctx_params; + std::map base_args; + std::map control_args; // args reserved for server control + + server_presets(int argc, char ** argv, common_params & base_params, const std::string & models_dir); + common_preset get_preset(const std::string & name); + void render_args(server_model_meta & meta); +}; + struct subprocess_s; struct server_models { @@ -85,14 +100,21 @@ private: std::vector base_args; std::vector base_env; + server_presets presets; + void update_meta(const std::string & name, const server_model_meta & meta); // unload least recently used models if the limit is reached void unload_lru(); + // not thread-safe, caller must hold mutex + void add_model(server_model_meta && meta); + public: server_models(const common_params & params, int argc, char ** argv, char ** envp); + void load_models(); + // check if a model instance exists bool has_model(const std::string & name); @@ -102,8 +124,7 @@ public: // return a copy of all model metadata std::vector get_all_meta(); - // if auto_load is true, load the model with previous args if any - void load(const std::string & name, bool auto_load); + void load(const std::string & name); void unload(const std::string & name); void unload_all(); From 34a6d86982b54314516fd40ef5110525247528b8 Mon Sep 17 00:00:00 2001 From: Xuan-Son Nguyen Date: Wed, 10 Dec 2025 22:19:42 +0100 Subject: [PATCH 05/11] cli: enable jinja by default (#17911) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * cli: enable jinja by 
default * Update common/arg.cpp Co-authored-by: Sigbjørn Skjæret --------- Co-authored-by: Sigbjørn Skjæret --- common/arg.cpp | 10 ++-------- common/common.h | 2 +- tools/completion/completion.cpp | 4 ++++ tools/mtmd/mtmd-cli.cpp | 1 + 4 files changed, 8 insertions(+), 9 deletions(-) diff --git a/common/arg.cpp b/common/arg.cpp index b333f45c96..a31dcbc689 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -762,12 +762,6 @@ bool common_arg_utils::is_autoy(const std::string & value) { } common_params_context common_params_parser_init(common_params & params, llama_example ex, void(*print_usage)(int, char **)) { - // default values specific to example - // note: we place it here instead of inside server.cpp to allow llama-gen-docs to pick it up - if (ex == LLAMA_EXAMPLE_SERVER) { - params.use_jinja = true; - } - params.use_color = tty_can_use_colors(); // load dynamic backends @@ -2623,14 +2617,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_NO_MODELS_AUTOLOAD")); add_opt(common_arg( {"--jinja"}, - string_format("use jinja template for chat (default: %s)\n", params.use_jinja ? "enabled" : "disabled"), + string_format("use jinja template for chat (default: %s)", params.use_jinja ? "enabled" : "disabled"), [](common_params & params) { params.use_jinja = true; } ).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_COMPLETION, LLAMA_EXAMPLE_CLI, LLAMA_EXAMPLE_MTMD}).set_env("LLAMA_ARG_JINJA")); add_opt(common_arg( {"--no-jinja"}, - string_format("disable jinja template for chat (default: %s)\n", params.use_jinja ? "enabled" : "disabled"), + string_format("disable jinja template for chat (default: %s)", params.use_jinja ? "disabled" : "enabled"), [](common_params & params) { params.use_jinja = false; } diff --git a/common/common.h b/common/common.h index 6119adcc0f..2fd83f0cf9 100644 --- a/common/common.h +++ b/common/common.h @@ -464,7 +464,7 @@ struct common_params { std::string public_path = ""; // NOLINT std::string api_prefix = ""; // NOLINT std::string chat_template = ""; // NOLINT - bool use_jinja = false; // NOLINT + bool use_jinja = true; // NOLINT bool enable_chat_template = true; common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK; int reasoning_budget = -1; diff --git a/tools/completion/completion.cpp b/tools/completion/completion.cpp index 79581eacb5..cb2641ae0a 100644 --- a/tools/completion/completion.cpp +++ b/tools/completion/completion.cpp @@ -86,6 +86,10 @@ static void sigint_handler(int signo) { int main(int argc, char ** argv) { common_params params; g_params = ¶ms; + + // disable jinja by default + params.use_jinja = false; + if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_COMPLETION, print_usage)) { return 1; } diff --git a/tools/mtmd/mtmd-cli.cpp b/tools/mtmd/mtmd-cli.cpp index a75af406cd..ab7203d170 100644 --- a/tools/mtmd/mtmd-cli.cpp +++ b/tools/mtmd/mtmd-cli.cpp @@ -270,6 +270,7 @@ int main(int argc, char ** argv) { ggml_time_init(); common_params params; + params.use_jinja = false; // disable jinja by default params.sampling.temp = 0.2; // lower temp by default for better quality if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_MTMD, show_additional_info)) { From c6b2c9310cc53e5cd4f65bbb8d0cb498caf2ed1e Mon Sep 17 00:00:00 2001 From: Xuan-Son Nguyen Date: Wed, 10 Dec 2025 22:20:06 +0100 Subject: [PATCH 06/11] mtmd: some small clean up (#17909) * clip: add support for fused qkv in build_vit * use bulid_ffn whenever possible * fix 
internvl * mtmd-cli: move image to beginning * test script: support custom args --- tools/mtmd/clip.cpp | 152 ++++++++++++++++++++++++---------------- tools/mtmd/mtmd-cli.cpp | 4 +- tools/mtmd/tests.sh | 51 +++++++++----- 3 files changed, 126 insertions(+), 81 deletions(-) diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index e5f7117dbf..7360e8e09d 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -595,11 +595,12 @@ struct clip_graph { cur = ggml_mul(ctx0, cur, model.mm_input_norm_w); cur = ggml_add(ctx0, cur, model.mm_input_norm_b); - cur = ggml_mul_mat(ctx0, model.mm_1_w, cur); - cur = ggml_add(ctx0, cur, model.mm_1_b); - cur = ggml_gelu(ctx0, cur); - cur = ggml_mul_mat(ctx0, model.mm_2_w, cur); - cur = ggml_add(ctx0, cur, model.mm_2_b); + cur = build_ffn(cur, + model.mm_1_w, model.mm_1_b, + nullptr, nullptr, + model.mm_2_w, model.mm_2_b, + FFN_GELU, + -1); } else if (ctx->proj_type() == PROJECTOR_TYPE_JANUS_PRO) { cur = build_ffn(cur, @@ -667,16 +668,12 @@ struct clip_graph { // LlavaMultiModalProjector (always using GELU activation) { - cur = ggml_mul_mat(ctx0, model.mm_1_w, cur); - if (model.mm_1_b) { - cur = ggml_add(ctx0, cur, model.mm_1_b); - } - - cur = ggml_gelu(ctx0, cur); - cur = ggml_mul_mat(ctx0, model.mm_2_w, cur); - if (model.mm_2_b) { - cur = ggml_add(ctx0, cur, model.mm_2_b); - } + cur = build_ffn(cur, + model.mm_1_w, model.mm_1_b, + nullptr, nullptr, + model.mm_2_w, model.mm_2_b, + FFN_GELU, + -1); } // arrangement of the [IMG_BREAK] token @@ -866,16 +863,12 @@ struct clip_graph { // multimodal projection ggml_tensor * embeddings = inpL; embeddings = ggml_reshape_3d(ctx0, embeddings, n_embd * 4, n_pos / 4, batch_size); - - embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings); - embeddings = ggml_add(ctx0, embeddings, model.mm_0_b); - - // GELU activation - embeddings = ggml_gelu(ctx0, embeddings); - - // Second linear layer - embeddings = ggml_mul_mat(ctx0, model.mm_1_w, embeddings); - embeddings = ggml_add(ctx0, embeddings, model.mm_1_b); + embeddings = build_ffn(embeddings, + model.mm_0_w, model.mm_0_b, + nullptr, nullptr, + model.mm_1_w, model.mm_1_b, + FFN_GELU, + -1); if (use_window_attn) { window_idx = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_pos / 4); @@ -1253,11 +1246,12 @@ struct clip_graph { // projector LayerNorm uses pytorch's default eps = 1e-5 // ref: https://huggingface.co/OpenGVLab/InternVL3-8B-Instruct/blob/a34d3e4e129a5856abfd6aa6de79776484caa14e/modeling_internvl_chat.py#L79 cur = build_norm(cur, model.mm_0_w, model.mm_0_b, NORM_TYPE_NORMAL, 1e-5, -1); - cur = ggml_mul_mat(ctx0, model.mm_1_w, cur); - cur = ggml_add(ctx0, cur, model.mm_1_b); - cur = ggml_gelu(ctx0, cur); - cur = ggml_mul_mat(ctx0, model.mm_3_w, cur); - cur = ggml_add(ctx0, cur, model.mm_3_b); + cur = build_ffn(cur, + model.mm_1_w, model.mm_1_b, + nullptr, nullptr, + model.mm_3_w, model.mm_3_b, + FFN_GELU, + -1); } // build the graph @@ -1408,11 +1402,12 @@ struct clip_graph { cb(cur, "proj_inp_normed", -1); // projection mlp - cur = ggml_mul_mat(ctx0, model.mm_1_w, cur); - cur = ggml_add(ctx0, cur, model.mm_1_b); - cur = ggml_gelu(ctx0, cur); - cur = ggml_mul_mat(ctx0, model.mm_2_w, cur); - cur = ggml_add(ctx0, cur, model.mm_2_b); + cur = build_ffn(cur, + model.mm_1_w, model.mm_1_b, + nullptr, nullptr, + model.mm_2_w, model.mm_2_b, + FFN_GELU, + -1); cb(cur, "proj_out", -1); } @@ -1883,9 +1878,12 @@ struct clip_graph { } else if (ctx->proj_type() == PROJECTOR_TYPE_VOXTRAL) { // projector - cur = ggml_mul_mat(ctx0, model.mm_1_w, cur); - cur = ggml_gelu_erf(ctx0, 
cur); - cur = ggml_mul_mat(ctx0, model.mm_2_w, cur); + cur = build_ffn(cur, + model.mm_1_w, model.mm_1_b, + nullptr, nullptr, + model.mm_2_w, model.mm_2_b, + FFN_GELU_ERF, + -1); } else { GGML_ABORT("%s: unknown projector type", __func__); @@ -2070,34 +2068,66 @@ private: // self-attention { - ggml_tensor * Qcur = ggml_mul_mat(ctx0, layer.q_w, cur); - if (layer.q_b) { - Qcur = ggml_add(ctx0, Qcur, layer.q_b); - } + ggml_tensor * Qcur = nullptr; + ggml_tensor * Kcur = nullptr; + ggml_tensor * Vcur = nullptr; + if (layer.qkv_w != nullptr) { + // fused qkv + cur = ggml_mul_mat(ctx0, layer.qkv_w, cur); + if (layer.qkv_b != nullptr) { + cur = ggml_add(ctx0, cur, layer.qkv_b); + } - ggml_tensor * Kcur = ggml_mul_mat(ctx0, layer.k_w, cur); - if (layer.k_b) { - Kcur = ggml_add(ctx0, Kcur, layer.k_b); - } + Qcur = ggml_view_3d(ctx0, cur, d_head, n_head, n_pos, + /* nb1 */ ggml_row_size(cur->type, d_head), + /* nb2 */ cur->nb[1], + /* offset */ 0); - ggml_tensor * Vcur = ggml_mul_mat(ctx0, layer.v_w, cur); - if (layer.v_b) { - Vcur = ggml_add(ctx0, Vcur, layer.v_b); - } + Kcur = ggml_view_3d(ctx0, cur, d_head, n_head, n_pos, + /* nb1 */ ggml_row_size(cur->type, d_head), + /* nb2 */ cur->nb[1], + /* offset */ ggml_row_size(cur->type, n_embd)); - if (layer.q_norm) { - Qcur = build_norm(Qcur, layer.q_norm, NULL, norm_t, eps, il); - cb(Qcur, "Qcur_norm", il); - } + Vcur = ggml_view_3d(ctx0, cur, d_head, n_head, n_pos, + /* nb1 */ ggml_row_size(cur->type, d_head), + /* nb2 */ cur->nb[1], + /* offset */ ggml_row_size(cur->type, 2 * n_embd)); - if (layer.k_norm) { - Kcur = build_norm(Kcur, layer.k_norm, NULL, norm_t, eps, il); - cb(Kcur, "Kcur_norm", il); - } + // TODO: q/k norm requires row size == n_embd, while here it's d_head + // we can add support in the future if needed + GGML_ASSERT(layer.q_norm == nullptr && layer.k_norm == nullptr); - Qcur = ggml_reshape_3d(ctx0, Qcur, d_head, n_head, n_pos); - Kcur = ggml_reshape_3d(ctx0, Kcur, d_head, n_head, n_pos); - Vcur = ggml_reshape_3d(ctx0, Vcur, d_head, n_head, n_pos); + } else { + // separate q, k, v + Qcur = ggml_mul_mat(ctx0, layer.q_w, cur); + if (layer.q_b) { + Qcur = ggml_add(ctx0, Qcur, layer.q_b); + } + + Kcur = ggml_mul_mat(ctx0, layer.k_w, cur); + if (layer.k_b) { + Kcur = ggml_add(ctx0, Kcur, layer.k_b); + } + + Vcur = ggml_mul_mat(ctx0, layer.v_w, cur); + if (layer.v_b) { + Vcur = ggml_add(ctx0, Vcur, layer.v_b); + } + + if (layer.q_norm) { + Qcur = build_norm(Qcur, layer.q_norm, NULL, norm_t, eps, il); + cb(Qcur, "Qcur_norm", il); + } + + if (layer.k_norm) { + Kcur = build_norm(Kcur, layer.k_norm, NULL, norm_t, eps, il); + cb(Kcur, "Kcur_norm", il); + } + + Qcur = ggml_reshape_3d(ctx0, Qcur, d_head, n_head, n_pos); + Kcur = ggml_reshape_3d(ctx0, Kcur, d_head, n_head, n_pos); + Vcur = ggml_reshape_3d(ctx0, Vcur, d_head, n_head, n_pos); + } cb(Qcur, "Qcur", il); cb(Kcur, "Kcur", il); diff --git a/tools/mtmd/mtmd-cli.cpp b/tools/mtmd/mtmd-cli.cpp index ab7203d170..25d24603db 100644 --- a/tools/mtmd/mtmd-cli.cpp +++ b/tools/mtmd/mtmd-cli.cpp @@ -318,7 +318,9 @@ int main(int argc, char ** argv) { g_is_generating = true; if (params.prompt.find(mtmd_default_marker()) == std::string::npos) { for (size_t i = 0; i < params.image.size(); i++) { - params.prompt += mtmd_default_marker(); + // most models require the marker before each image + // ref: https://github.com/ggml-org/llama.cpp/pull/17616 + params.prompt = mtmd_default_marker() + params.prompt; } } common_chat_msg msg; diff --git a/tools/mtmd/tests.sh b/tools/mtmd/tests.sh index 
472f7d821c..82b486ec93 100755 --- a/tools/mtmd/tests.sh +++ b/tools/mtmd/tests.sh @@ -32,23 +32,32 @@ fi arr_prefix=() arr_hf=() -arr_tmpl=() # chat template +arr_extra_args=() arr_file=() add_test_vision() { local hf=$1 - local tmpl=${2:-""} # default to empty string if not provided + shift + local extra_args="" + if [ $# -gt 0 ]; then + extra_args=$(printf " %q" "$@") + fi arr_prefix+=("[vision]") arr_hf+=("$hf") - arr_tmpl+=("$tmpl") + arr_extra_args+=("$extra_args") arr_file+=("test-1.jpeg") } add_test_audio() { local hf=$1 + shift + local extra_args="" + if [ $# -gt 0 ]; then + extra_args=$(printf " %q" "$@") + fi arr_prefix+=("[audio] ") arr_hf+=("$hf") - arr_tmpl+=("") # no need for chat tmpl + arr_extra_args+=("$extra_args") arr_file+=("test-2.mp3") } @@ -56,9 +65,9 @@ add_test_vision "ggml-org/SmolVLM-500M-Instruct-GGUF:Q8_0" add_test_vision "ggml-org/SmolVLM2-2.2B-Instruct-GGUF:Q4_K_M" add_test_vision "ggml-org/SmolVLM2-500M-Video-Instruct-GGUF:Q8_0" add_test_vision "ggml-org/gemma-3-4b-it-GGUF:Q4_K_M" -add_test_vision "THUDM/glm-edge-v-5b-gguf:Q4_K_M" -add_test_vision "second-state/Llava-v1.5-7B-GGUF:Q2_K" "vicuna" -add_test_vision "cjpais/llava-1.6-mistral-7b-gguf:Q3_K_M" "vicuna" +add_test_vision "THUDM/glm-edge-v-5b-gguf:Q4_K_M" -p "name of the newspaper?<__media__>" +add_test_vision "second-state/Llava-v1.5-7B-GGUF:Q2_K" --chat-template vicuna +add_test_vision "cjpais/llava-1.6-mistral-7b-gguf:Q3_K_M" --chat-template vicuna add_test_vision "ibm-research/granite-vision-3.2-2b-GGUF:Q4_K_M" add_test_vision "second-state/MiniCPM-Llama3-V-2_5-GGUF:Q2_K" # model from openbmb is corrupted add_test_vision "openbmb/MiniCPM-V-2_6-gguf:Q2_K" @@ -79,7 +88,7 @@ add_test_audio "ggml-org/Voxtral-Mini-3B-2507-GGUF:Q4_K_M" # to test the big models, run: ./tests.sh big if [ "$RUN_BIG_TESTS" = true ]; then add_test_vision "ggml-org/pixtral-12b-GGUF:Q4_K_M" - add_test_vision "ggml-org/Mistral-Small-3.1-24B-Instruct-2503-GGUF" "mistral-v7" + add_test_vision "ggml-org/Mistral-Small-3.1-24B-Instruct-2503-GGUF" --chat-template mistral-v7 add_test_vision "ggml-org/Qwen2-VL-2B-Instruct-GGUF:Q4_K_M" add_test_vision "ggml-org/Qwen2-VL-7B-Instruct-GGUF:Q4_K_M" add_test_vision "ggml-org/Qwen2.5-VL-3B-Instruct-GGUF:Q4_K_M" @@ -89,7 +98,7 @@ if [ "$RUN_BIG_TESTS" = true ]; then add_test_vision "ggml-org/InternVL3-14B-Instruct-GGUF:Q4_K_M" add_test_vision "ggml-org/Qwen2.5-Omni-7B-GGUF:Q4_K_M" # add_test_vision "ggml-org/Qwen2.5-VL-32B-Instruct-GGUF:Q4_K_M" # does not work on my mac M3 Ultra - add_test_vision "ggml-org/Kimi-VL-A3B-Thinking-2506-GGUF:Q4_K_M" + # add_test_vision "ggml-org/Kimi-VL-A3B-Thinking-2506-GGUF:Q4_K_M" # not always working add_test_audio "ggml-org/ultravox-v0_5-llama-3_1-8b-GGUF:Q4_K_M" add_test_audio "ggml-org/Qwen2.5-Omni-7B-GGUF:Q4_K_M" @@ -122,21 +131,25 @@ for i in "${!arr_hf[@]}"; do bin="llama-mtmd-cli" prefix="${arr_prefix[$i]}" hf="${arr_hf[$i]}" - tmpl="${arr_tmpl[$i]}" + extra_args="${arr_extra_args[$i]}" inp_file="${arr_file[$i]}" echo "Running test with binary: $bin and HF model: $hf" echo "" echo "" - output=$(\ - "$PROJ_ROOT/build/bin/$bin" \ - -hf "$hf" \ - --image $SCRIPT_DIR/$inp_file \ - -p "what is the publisher name of the newspaper?" \ + cmd="$(printf %q "$PROJ_ROOT/build/bin/$bin") \ + -hf $(printf %q "$hf") \ + --image $(printf %q "$SCRIPT_DIR/$inp_file") \ --temp 0 -n 128 \ - ${tmpl:+--chat-template "$tmpl"} \ - 2>&1 | tee /dev/tty) + ${extra_args}" + + # if extra_args does not contain -p, we add a default prompt + if ! 
[[ "$extra_args" =~ "-p" ]]; then + cmd+=" -p \"what is the publisher name of the newspaper?\"" + fi + + output=$(eval "$cmd" 2>&1 | tee /dev/tty) echo "$output" > $SCRIPT_DIR/output/$bin-$(echo "$hf" | tr '/' '-').log @@ -144,9 +157,9 @@ for i in "${!arr_hf[@]}"; do if echo "$output" | grep -iq "new york" \ || (echo "$output" | grep -iq "men" && echo "$output" | grep -iq "walk") then - result="$prefix \033[32mOK\033[0m: $bin $hf" + result="$prefix \033[32mOK\033[0m: $hf" else - result="$prefix \033[31mFAIL\033[0m: $bin $hf" + result="$prefix \033[31mFAIL\033[0m: $hf" fi echo -e "$result" arr_res+=("$result") From 45e350e3d3a663714d7d3d2397e14c2904534338 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sigbj=C3=B8rn=20Skj=C3=A6ret?= Date: Wed, 10 Dec 2025 23:24:31 +0100 Subject: [PATCH 07/11] ci: fix riscv64-native build (#17916) --- .github/workflows/build.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 182d433b1b..383427f36f 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -1770,7 +1770,7 @@ jobs: echo "Fetch llama2c model" wget https://huggingface.co/karpathy/tinyllamas/resolve/main/stories260K/stories260K.bin ./bin/llama-convert-llama2c-to-ggml --copy-vocab-from-model ./tok512.bin --llama2c-model stories260K.bin --llama2c-output-model stories260K.gguf - ./bin/llama-cli -m stories260K.gguf -p "One day, Lily met a Shoggoth" -n 500 -c 256 + ./bin/llama-completion -m stories260K.gguf -p "One day, Lily met a Shoggoth" -n 500 -c 256 ubuntu-cmake-sanitizer-riscv64-native: runs-on: RISCV64 From 34ce48d97a8cd5497ee418224da5bf422ed96673 Mon Sep 17 00:00:00 2001 From: nullname Date: Thu, 11 Dec 2025 06:45:43 +0800 Subject: [PATCH 08/11] ggml-hexagon: fix `rope` failure at `test-backend-ops` (#17565) * fix test failure * fix: correct scaling calculations in rope_cache_init * fix: optimize element copying in rope_hex_f32 using memcpy * fix: optimize loop boundaries in rope_hex_f32 for better performance * feat: add profiling macros for performance measurement in operations --- ggml/src/ggml-hexagon/htp/rope-ops.c | 78 +++++++++++++--------------- 1 file changed, 37 insertions(+), 41 deletions(-) diff --git a/ggml/src/ggml-hexagon/htp/rope-ops.c b/ggml/src/ggml-hexagon/htp/rope-ops.c index 00419bcba6..a4399704fc 100644 --- a/ggml/src/ggml-hexagon/htp/rope-ops.c +++ b/ggml/src/ggml-hexagon/htp/rope-ops.c @@ -73,15 +73,15 @@ static float rope_yarn_ramp(const float low, const float high, const int i0) { return (1 - MIN(1, MAX(0, y))); } -static void rope_cache_init(const float theta_base, - float freq_scale, - const float * freq_factors, - float * corr_dims, - uint32_t ne0, - float ext_factor, - float mscale, - float * cache, - float theta_scale) { +static void rope_cache_init(const float theta_base, + const float freq_scale, + const float * freq_factors, + float * corr_dims, + const uint32_t ne0, + const float ext_factor, + const float mscale, + float * cache, + const float theta_scale) { // ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py float theta = theta_base; @@ -92,18 +92,19 @@ static void rope_cache_init(const float theta_base, // Get n-d rotational scaling corrected for extrapolation float theta_interp = freq_scale * theta_extrap; - float theta2 = theta_interp; + float theta_final = theta_interp; + float mscale_final = mscale; if (ext_factor != 0.0f) { float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor; - theta2 = 
theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix; + theta_final = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix; // Get n-d magnitude scaling corrected for interpolation - mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale); + mscale_final *= 1.0f + 0.1f * logf(1.0f / freq_scale); } - cache[i0 + 0] = cosf(theta2) * mscale; - cache[i0 + 1] = sinf(theta2) * mscale; + cache[i0 + 0] = cosf(theta_final) * mscale_final; + cache[i0 + 1] = sinf(theta_final) * mscale_final; theta *= theta_scale; } @@ -151,9 +152,9 @@ static void init_rope_ctx(struct rope_th_ctx * rope_ctx, struct htp_ops_context } static void hvx_calc_rope_neox_f32(const float * restrict src0, - float * restrict dst, - const int num_elems, - const float * restrict theta_cache) { + float * restrict dst, + const int num_elems, + const float * restrict theta_cache) { // for (int i = 0; i < num_elems; i += 2) { //const float cos_theta = theta_cache[i + 0]; //const float sin_theta = theta_cache[i + 1]; @@ -192,7 +193,7 @@ static void hvx_calc_rope_neox_f32(const float * restrict src0, HVX_Vector v4 = Q6_Vqf32_vsub_Vqf32Vqf32(vx0_c, vx1_s); HVX_Vector v5 = Q6_Vqf32_vadd_Vqf32Vqf32(vx0_s, vx1_c); - *(HVX_Vector *) dst_curr = Q6_Vsf_equals_Vqf32(v4); + *(HVX_Vector *) dst_curr = Q6_Vsf_equals_Vqf32(v4); *(HVX_Vector *) (dst_curr + half_size) = Q6_Vsf_equals_Vqf32(v5); src0_curr += VLEN; @@ -259,7 +260,7 @@ static void rope_hex_f32(struct rope_th_ctx * rope_ctx, const uint32_t ir1, int nth, int ith, - int opt_path) { + const int opt_path) { struct htp_ops_context * octx = rope_ctx->octx; const struct htp_tensor * src0 = &octx->src0; @@ -267,8 +268,8 @@ static void rope_hex_f32(struct rope_th_ctx * rope_ctx, const struct htp_tensor * src2 = &octx->src2; struct htp_tensor * dst = &octx->dst; - const int32_t mode = rope_ctx->mode; - const bool is_neox = mode & HTP_ROPE_TYPE_NEOX; + const int32_t mode = rope_ctx->mode; + const bool is_neox = mode & HTP_ROPE_TYPE_NEOX; htp_rope_preamble; @@ -281,8 +282,9 @@ static void rope_hex_f32(struct rope_th_ctx * rope_ctx, freq_factors = (const float *) src2->data; } - int ir = 0; - + const uint32_t i1_end = MIN(ir1, ne1); + const int32_t half_dims = rope_ctx->n_dims / 2; + const size_t remain_bytes = (ne0 - rope_ctx->n_dims) * sizeof(float); for (uint32_t i3 = 0; i3 < ne3; i3++) { // batch for (uint32_t i2 = 0; i2 < ne2; i2++) { // seq-len const int32_t p = pos[i2]; @@ -290,14 +292,7 @@ static void rope_hex_f32(struct rope_th_ctx * rope_ctx, rope_cache_init(p, rope_ctx->freq_scale, freq_factors, rope_ctx->corr_dims, ne0, rope_ctx->ext_factor, rope_ctx->attn_factor, wp0, rope_ctx->theta_scale); - for (uint32_t i1 = 0; i1 < ne1; i1++) { // attn-heads - if (ir++ < ir0) { - continue; - } - if (ir > ir1) { - break; - } - + for (uint32_t i1 = ir0; i1 < i1_end; i1++) { // attn-heads const float * src = (float *) ((char *) src0->data + i3 * nb03 + i2 * nb02 + i1 * nb01); float * dst_data = (float *) ((char *) dst->data + i3 * nb3 + i2 * nb2 + i1 * nb1); @@ -310,6 +305,9 @@ static void rope_hex_f32(struct rope_th_ctx * rope_ctx, } else { hvx_calc_rope_f32(src_loc, dst_data_loc, rope_ctx->n_dims, wp0); } + + src_loc += rope_ctx->n_dims; + dst_data_loc += rope_ctx->n_dims; } else { for (uint32_t i0 = 0; i0 < rope_ctx->n_dims; i0 += 2) { const float cos_theta = wp0[i0 + 0]; @@ -317,10 +315,10 @@ static void rope_hex_f32(struct rope_th_ctx * rope_ctx, if (is_neox) { const float x0 = src_loc[0]; - const float x1 = src_loc[rope_ctx->n_dims/2]; + const float x1 = src_loc[half_dims]; - dst_data_loc[0] = x0 * 
cos_theta - x1 * sin_theta; - dst_data_loc[rope_ctx->n_dims/2] = x0 * sin_theta + x1 * cos_theta; + dst_data_loc[0] = x0 * cos_theta - x1 * sin_theta; + dst_data_loc[half_dims] = x0 * sin_theta + x1 * cos_theta; src_loc += 1; dst_data_loc += 1; @@ -335,15 +333,13 @@ static void rope_hex_f32(struct rope_th_ctx * rope_ctx, dst_data_loc += 2; } } + + src_loc += (is_neox ? half_dims : 0); + dst_data_loc += (is_neox ? half_dims : 0); } - for (uint32_t i0 = rope_ctx->n_dims; i0 < ne0; i0 += 2) { - dst_data_loc[0] = src_loc[0]; - dst_data_loc[1] = src_loc[1]; - - src_loc += 2; - dst_data_loc += 2; - } + // TODO: use simd to speed up the remaining elements copy + memcpy(dst_data_loc, src_loc, remain_bytes); } } } From e4ae38331702aeb43b6ecc3f912d626171c9862a Mon Sep 17 00:00:00 2001 From: Yuichiro Utsumi <81412151+utsumi-fj@users.noreply.github.com> Date: Thu, 11 Dec 2025 18:12:07 +0900 Subject: [PATCH 09/11] docs: use port 8080 in Docker examples (#17903) --- docs/docker.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/docker.md b/docs/docker.md index 98502a0c50..b9e5015396 100644 --- a/docs/docker.md +++ b/docs/docker.md @@ -56,7 +56,7 @@ docker run -v /path/to/models:/models ghcr.io/ggml-org/llama.cpp:light -m /model or with a server image: ```bash -docker run -v /path/to/models:/models -p 8000:8000 ghcr.io/ggml-org/llama.cpp:server -m /models/7B/ggml-model-q4_0.gguf --port 8000 --host 0.0.0.0 -n 512 +docker run -v /path/to/models:/models -p 8080:8080 ghcr.io/ggml-org/llama.cpp:server -m /models/7B/ggml-model-q4_0.gguf --port 8080 --host 0.0.0.0 -n 512 ``` ## Docker With CUDA @@ -91,7 +91,7 @@ After building locally, Usage is similar to the non-CUDA examples, but you'll ne ```bash docker run --gpus all -v /path/to/models:/models local/llama.cpp:full-cuda --run -m /models/7B/ggml-model-q4_0.gguf -p "Building a website can be done in 10 simple steps:" -n 512 --n-gpu-layers 1 docker run --gpus all -v /path/to/models:/models local/llama.cpp:light-cuda -m /models/7B/ggml-model-q4_0.gguf -p "Building a website can be done in 10 simple steps:" -n 512 --n-gpu-layers 1 -docker run --gpus all -v /path/to/models:/models local/llama.cpp:server-cuda -m /models/7B/ggml-model-q4_0.gguf --port 8000 --host 0.0.0.0 -n 512 --n-gpu-layers 1 +docker run --gpus all -v /path/to/models:/models local/llama.cpp:server-cuda -m /models/7B/ggml-model-q4_0.gguf --port 8080 --host 0.0.0.0 -n 512 --n-gpu-layers 1 ``` ## Docker With MUSA @@ -125,5 +125,5 @@ After building locally, Usage is similar to the non-MUSA examples, but you'll ne ```bash docker run -v /path/to/models:/models local/llama.cpp:full-musa --run -m /models/7B/ggml-model-q4_0.gguf -p "Building a website can be done in 10 simple steps:" -n 512 --n-gpu-layers 1 docker run -v /path/to/models:/models local/llama.cpp:light-musa -m /models/7B/ggml-model-q4_0.gguf -p "Building a website can be done in 10 simple steps:" -n 512 --n-gpu-layers 1 -docker run -v /path/to/models:/models local/llama.cpp:server-musa -m /models/7B/ggml-model-q4_0.gguf --port 8000 --host 0.0.0.0 -n 512 --n-gpu-layers 1 +docker run -v /path/to/models:/models local/llama.cpp:server-musa -m /models/7B/ggml-model-q4_0.gguf --port 8080 --host 0.0.0.0 -n 512 --n-gpu-layers 1 ``` From d9f8f60618a1df2797cb7df4ad1272f71d6bd7b2 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 11 Dec 2025 14:29:47 +0200 Subject: [PATCH 10/11] batch : fix sequence id ownership (#17915) * batch : fix sequence id ownage * cont : reduce allocations --- src/llama-batch.cpp | 14 
++++++++++++-- src/llama-batch.h | 6 ++++-- 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/src/llama-batch.cpp b/src/llama-batch.cpp index 86a1a4ba18..386fab04ac 100644 --- a/src/llama-batch.cpp +++ b/src/llama-batch.cpp @@ -695,6 +695,8 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector & idxs, u udata->seq_idx .resize(LLAMA_MAX_SEQ, -1); udata->output .resize(n_tokens); + udata->seq_id_data.reserve(n_tokens); + seq_set_t seq_set_unq; for (size_t i = 0; i < idxs.size(); ++i) { @@ -716,11 +718,13 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector & idxs, u } udata->n_seq_id[i] = batch.n_seq_id[idxs[i]]; - udata->seq_id[i] = batch.seq_id[idxs[i]]; udata->output[i] = batch.logits[idxs[i]]; for (int s = 0; s < udata->n_seq_id[i]; ++s) { - seq_set_unq.set(udata->seq_id[i][s]); + const llama_seq_id seq_id = batch.seq_id[idxs[i]][s]; + + udata->seq_id_data.push_back(seq_id); + seq_set_unq.set(seq_id); } if (udata->output[i]) { @@ -728,6 +732,12 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector & idxs, u } } + llama_seq_id * seq_id_ptr = udata->seq_id_data.data(); + for (size_t i = 0; i < idxs.size(); ++i) { + udata->seq_id[i] = seq_id_ptr; + seq_id_ptr += udata->n_seq_id[i]; + } + for (uint32_t s = 0; s < n_seq_max; ++s) { if (seq_set_unq.test(s)) { udata->seq_idx[s] = udata->seq_id_unq.size(); diff --git a/src/llama-batch.h b/src/llama-batch.h index 209cf3699d..8e6fac0efa 100644 --- a/src/llama-batch.h +++ b/src/llama-batch.h @@ -56,13 +56,15 @@ struct llama_ubatch { std::vector embd; std::vector pos; std::vector n_seq_id; - std::vector seq_id; + std::vector seq_id; // these point into the seq_id_data below std::vector seq_id_unq; std::vector seq_idx; std::vector output; + + std::vector seq_id_data; }; - // the llama_ubatch pointers above point to this data if set. otherwise - points to non-owning data + // the llama_ubatch pointers above point to this data if set. 
otherwise - point to external non-owning data std::shared_ptr data; }; From c6f6e4f96a7f7bce49f5c21d19ee69fb8b72f84d Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 11 Dec 2025 14:30:10 +0200 Subject: [PATCH 11/11] ggml-alloc : fix reuse-parent logic for misaligned sizes (#17884) --- ggml/src/ggml-alloc.c | 28 +++++++++++++++------------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/ggml/src/ggml-alloc.c b/ggml/src/ggml-alloc.c index a5995fdc2c..ec16cbda9f 100644 --- a/ggml/src/ggml-alloc.c +++ b/ggml/src/ggml-alloc.c @@ -312,16 +312,9 @@ static struct buffer_address ggml_dyn_tallocr_alloc(struct ggml_dyn_tallocr * al } // this is a very naive implementation, but for our case the number of free blocks should be very small -static void ggml_dyn_tallocr_free_tensor(struct ggml_dyn_tallocr * alloc, struct buffer_address addr, size_t size, const struct ggml_tensor * tensor) { +static void ggml_dyn_tallocr_free_bytes(struct ggml_dyn_tallocr * alloc, struct buffer_address addr, size_t size) { size = aligned_offset(NULL, size, alloc->alignment); - AT_PRINTF("%s: freeing %s at {chunk=%d, offset=%zu} (%zu bytes) - n_free_blocks = %d\n", - __func__, tensor->name, addr.chunk, addr.offset, size, alloc->chunks[addr.chunk]->n_free_blocks); - -#ifdef GGML_ALLOCATOR_DEBUG - remove_allocated_tensor(alloc, addr, tensor); -#endif - struct tallocr_chunk * chunk = alloc->chunks[addr.chunk]; // see if we can merge with an existing block @@ -357,8 +350,6 @@ static void ggml_dyn_tallocr_free_tensor(struct ggml_dyn_tallocr * alloc, struct } // otherwise, add a new block ggml_dyn_tallocr_insert_block(chunk, addr.offset, size); - - GGML_UNUSED(tensor); } static void ggml_dyn_tallocr_reset(struct ggml_dyn_tallocr * alloc) { @@ -616,13 +607,17 @@ static void ggml_gallocr_free_extra_space(ggml_gallocr_t galloc, struct ggml_ten GGML_ASSERT(parent_size >= node_size); + // note: we want after the freeing the chunks to continue to be aligned + struct ggml_dyn_tallocr * p_alloc = galloc->buf_tallocs[p_hn->buffer_id]; + parent_size = aligned_offset(NULL, parent_size, p_alloc->alignment); + node_size = aligned_offset(NULL, node_size, p_alloc->alignment); + if (parent_size > node_size) { - struct ggml_dyn_tallocr * p_alloc = galloc->buf_tallocs[p_hn->buffer_id]; struct buffer_address p_addr = p_hn->addr; p_addr.offset += node_size; size_t extra_size = parent_size - node_size; AT_PRINTF("freeing extra %zu bytes from parent %s for %s\n", extra_size, parent->name, node->name); - ggml_dyn_tallocr_free_tensor(p_alloc, p_addr, extra_size, parent); + ggml_dyn_tallocr_free_bytes(p_alloc, p_addr, extra_size); } } @@ -706,7 +701,14 @@ static void ggml_gallocr_free_node(ggml_gallocr_t galloc, struct ggml_tensor * n struct ggml_dyn_tallocr * alloc = galloc->buf_tallocs[buffer_id]; ggml_backend_buffer_type_t buft = galloc->bufts[buffer_id]; size_t size = ggml_backend_buft_get_alloc_size(buft, node); - ggml_dyn_tallocr_free_tensor(alloc, hn->addr, size, node); + + AT_PRINTF("%s: freeing %s at {chunk=%d, offset=%zu} (%zu bytes) - n_free_blocks = %d\n", + __func__, node->name, hn->addr.chunk, hn->addr.offset, size, alloc->chunks[hn->addr.chunk]->n_free_blocks); +#ifdef GGML_ALLOCATOR_DEBUG + remove_allocated_tensor(alloc, hn->addr, node); +#endif + + ggml_dyn_tallocr_free_bytes(alloc, hn->addr, size); hn->allocated = false; }
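
A self-contained sketch of the idea behind the llama-batch patch above: instead of storing pointers into the caller-owned llama_batch, the ubatch copies every token's sequence ids into one flattened vector that it owns, and only fills the per-token pointer table once that vector has its final contents, so later growth cannot invalidate the pointers and a single reserve() replaces per-token allocations. The type alias and the sample ids are illustrative, not the actual llama.cpp definitions.

    #include <cassert>
    #include <cstddef>
    #include <cstdint>
    #include <vector>

    using llama_seq_id = int32_t; // illustrative alias

    int main() {
        // per-token sequence ids as they might arrive from an external llama_batch (illustrative)
        const std::vector<std::vector<llama_seq_id>> batch_seq_ids = {{0}, {0, 1}, {1}};

        std::vector<int32_t>        n_seq_id;    // number of seq ids per token
        std::vector<llama_seq_id *> seq_id;      // per-token pointers into seq_id_data
        std::vector<llama_seq_id>   seq_id_data; // owning, flattened storage

        size_t total = 0;
        for (const auto & ids : batch_seq_ids) {
            total += ids.size();
        }
        seq_id_data.reserve(total); // one allocation up front

        for (const auto & ids : batch_seq_ids) {
            n_seq_id.push_back((int32_t) ids.size());
            seq_id_data.insert(seq_id_data.end(), ids.begin(), ids.end());
        }

        // build the pointer table only after seq_id_data is final, so vector
        // growth can no longer invalidate the stored pointers
        llama_seq_id * ptr = seq_id_data.data();
        for (size_t i = 0; i < batch_seq_ids.size(); ++i) {
            seq_id.push_back(ptr);
            ptr += n_seq_id[i];
        }

        assert(seq_id[1][1] == 1); // token 1 belongs to sequences 0 and 1
        return 0;
    }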
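
The ggml-alloc patch rounds both the parent's and the node's sizes up to the allocator alignment before computing the leftover tail that is returned to the free list, so the freed block's offset and size stay aligned. A sketch of that arithmetic, assuming aligned_offset(NULL, size, alignment) reduces to a plain round-up as the hunk suggests; the align_up helper and the byte counts are illustrative, not the actual ggml-alloc internals.

    #include <cassert>
    #include <cstddef>
    #include <cstdio>

    // round size up to the next multiple of a power-of-two alignment
    static size_t align_up(size_t size, size_t alignment) {
        return (size + alignment - 1) & ~(alignment - 1);
    }

    int main() {
        const size_t alignment = 32; // illustrative tensor alignment
        size_t parent_size = 1000;   // bytes reserved for the parent tensor
        size_t node_size   =  900;   // bytes the reusing node actually needs

        // without the rounding, the freed tail would start at offset 900 and be
        // 100 bytes long, leaving a misaligned block in the free list
        parent_size = align_up(parent_size, alignment); // 1024
        node_size   = align_up(node_size,   alignment); //  928

        if (parent_size > node_size) {
            const size_t freed_offset = node_size;               // 928
            const size_t freed_size   = parent_size - node_size; //  96
            assert(freed_offset % alignment == 0);
            assert(freed_size   % alignment == 0);
            std::printf("free %zu bytes at offset %zu\n", freed_size, freed_offset);
        }
        return 0;
    }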