From 2e7ef98f18090b382611c135efc417200b23780b Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Fri, 28 Nov 2025 20:34:51 +0800 Subject: [PATCH 01/13] ggml-cuda: add stricter checking for fusion (#17568) * ggml-cuda: make conditions for fusion more explicit * ggml-cuda: remove size check as std::equal already does it --- ggml/src/ggml-cuda/ggml-cuda.cu | 26 +++++++++++++++----------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 6463921a6e..a844a3d99a 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -3050,7 +3050,12 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list topk_moe_ops_delayed_softmax = ggml_cuda_topk_moe_ops(/*with_norm=*/false, /*delayed_softmax=*/true); - if (ops.size() == topk_moe_ops_with_norm.size() && + const auto is_equal = [](const std::initializer_list & list1, + const std::initializer_list & list2) { + return std::equal(list1.begin(), list1.end(), list2.begin(), list2.end()); + }; + + if (is_equal(topk_moe_ops_with_norm, ops) && ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 3, node_idx + 9 })) { ggml_tensor * softmax = cgraph->nodes[node_idx]; ggml_tensor * weights = cgraph->nodes[node_idx + 9]; @@ -3060,8 +3065,7 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, } } - if (ops.size() == topk_moe_ops.size() && - ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 3, node_idx + 4 })) { + if (is_equal(topk_moe_ops, ops) && ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 3, node_idx + 4 })) { ggml_tensor * softmax = cgraph->nodes[node_idx]; ggml_tensor * weights = cgraph->nodes[node_idx + 4]; if (ggml_cuda_should_use_topk_moe(softmax, weights)) { @@ -3069,7 +3073,7 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, } } - if (ops.size() == topk_moe_ops_delayed_softmax.size() && + if (is_equal(topk_moe_ops_delayed_softmax, ops) && ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 1, node_idx + 5 })) { ggml_tensor * softmax = cgraph->nodes[node_idx + 4]; ggml_tensor * weights = cgraph->nodes[node_idx + 5]; @@ -3085,9 +3089,8 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list mul_mat_id_glu_ops = { GGML_OP_MUL_MAT_ID, GGML_OP_MUL_MAT_ID, GGML_OP_GLU }; std::initializer_list mul_mat_glu_ops = { GGML_OP_MUL_MAT, GGML_OP_MUL_MAT, GGML_OP_GLU }; - if (ops.size() == 5 && (ggml_can_fuse_subgraph(cgraph, node_idx, ops, {node_idx + 4}) || - ggml_can_fuse_subgraph(cgraph, node_idx, ops, {node_idx + 4}))) { - + if ((is_equal(mul_mat_bias_glu_ops, ops) || is_equal(mul_mat_id_bias_glu_ops, ops)) && + ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 4 })) { const ggml_tensor * ffn_gate = cgraph->nodes[node_idx]; const ggml_tensor * ffn_gate_bias = cgraph->nodes[node_idx + 1]; const ggml_tensor * ffn_up = cgraph->nodes[node_idx + 2]; @@ -3099,9 +3102,8 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, } } - if (ops.size() == 3 && (ggml_can_fuse_subgraph(cgraph, node_idx, ops, {node_idx + 2}) || - ggml_can_fuse_subgraph(cgraph, node_idx, ops, {node_idx + 2}))) { - + if ((is_equal(mul_mat_id_glu_ops, ops) || is_equal(mul_mat_glu_ops, ops)) && + ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 2 })) { const ggml_tensor * ffn_gate = cgraph->nodes[node_idx]; const ggml_tensor * ffn_up = cgraph->nodes[node_idx + 1]; const ggml_tensor * glu = cgraph->nodes[node_idx + 2]; @@ -3111,7 +3113,9 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, } } - if (ops.size() == 3 && ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 2 })) { + std::initializer_list rope_set_rows_ops = { GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS }; + + if (is_equal(rope_set_rows_ops, ops) && ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 2 })) { const ggml_tensor * rope = cgraph->nodes[node_idx]; const ggml_tensor * view = cgraph->nodes[node_idx + 1]; const ggml_tensor * set_rows = cgraph->nodes[node_idx + 2]; From c6f7a423c8c87748ef563a99d81c3b1b05cecff0 Mon Sep 17 00:00:00 2001 From: R0CKSTAR Date: Fri, 28 Nov 2025 21:08:29 +0800 Subject: [PATCH 02/13] [MUSA] enable fp16/fast_fp16/bf16_mma on PH1 (#17551) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * [MUSA] enable fp16/fast_fp16/bf16_mma on PH1 Signed-off-by: Xiaodong Ye * Update ggml/src/ggml-cuda/fattn-vec.cuh Co-authored-by: Johannes Gäßler * Update ggml/src/ggml-cuda/fattn-vec.cuh Co-authored-by: Johannes Gäßler * Update ggml/src/ggml-cuda/fattn-tile.cuh Co-authored-by: Johannes Gäßler * Address review comments Signed-off-by: Xiaodong Ye --------- Signed-off-by: Xiaodong Ye Co-authored-by: Johannes Gäßler --- ggml/src/ggml-cuda/common.cuh | 20 ++++++++++++-------- ggml/src/ggml-cuda/cpy.cu | 5 ++++- ggml/src/ggml-cuda/fattn-tile.cuh | 2 +- ggml/src/ggml-cuda/fattn-vec.cuh | 4 ++-- ggml/src/ggml-cuda/mma.cuh | 4 ++-- 5 files changed, 21 insertions(+), 14 deletions(-) diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 81fb312409..0b10e5f6ae 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -84,12 +84,12 @@ #define GGML_CUDA_CC_QY1 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x210) // MTT S80, MTT S3000 #define GGML_CUDA_CC_QY2 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x220) // MTT S4000 -#define GGML_CUDA_CC_NG (GGML_CUDA_CC_OFFSET_MTHREADS + 0x310) // TBD +#define GGML_CUDA_CC_PH1 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x310) // MTT S5000 #define GGML_CUDA_CC_IS_MTHREADS(cc) (cc >= GGML_CUDA_CC_OFFSET_MTHREADS && cc < GGML_CUDA_CC_OFFSET_AMD) #define GGML_CUDA_CC_IS_QY1(cc) (cc >= GGML_CUDA_CC_QY1 && cc < GGML_CUDA_CC_QY2) -#define GGML_CUDA_CC_IS_QY2(cc) (cc >= GGML_CUDA_CC_QY2 && cc < GGML_CUDA_CC_NG) -#define GGML_CUDA_CC_IS_NG(cc) (cc >= GGML_CUDA_CC_NG) +#define GGML_CUDA_CC_IS_QY2(cc) (cc >= GGML_CUDA_CC_QY2 && cc < GGML_CUDA_CC_PH1) +#define GGML_CUDA_CC_IS_PH1(cc) (cc >= GGML_CUDA_CC_PH1) #if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070 # define GGML_CUDA_USE_CUB @@ -212,9 +212,9 @@ static const char * cu_get_error_str(CUresult err) { #define GGML_USE_VMM #endif // (!defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM)) || (defined(GGML_USE_HIP) && !defined(GGML_HIP_NO_VMM)) -#if defined(GGML_USE_HIP) || __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL +#if defined(GGML_USE_HIP) || defined(GGML_USE_MUSA) || __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL #define FP16_AVAILABLE -#endif // defined(GGML_USE_HIP) || __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL +#endif // defined(GGML_USE_HIP) || defined(GGML_USE_MUSA) || __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL #if defined(FP16_AVAILABLE) && __CUDA_ARCH__ != 610 #define FAST_FP16_AVAILABLE @@ -250,12 +250,14 @@ static const char * cu_get_error_str(CUresult err) { #endif // !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ < 220) static bool fp16_available(const int cc) { - return ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_PASCAL; + return ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_PASCAL || + (GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_PH1); } static bool fast_fp16_available(const int cc) { return GGML_CUDA_CC_IS_AMD(cc) || - (GGML_CUDA_CC_IS_NVIDIA(cc) && fp16_available(cc) && ggml_cuda_highest_compiled_arch(cc) != 610); + (GGML_CUDA_CC_IS_NVIDIA(cc) && fp16_available(cc) && ggml_cuda_highest_compiled_arch(cc) != 610) || + (GGML_CUDA_CC_IS_MTHREADS(cc) && fp16_available(cc)); } // To be used for feature selection of external libraries, e.g. cuBLAS. @@ -272,7 +274,9 @@ static bool fp16_mma_hardware_available(const int cc) { } static bool bf16_mma_hardware_available(const int cc) { - return (GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_AMPERE) || GGML_CUDA_CC_IS_CDNA(cc) || cc >= GGML_CUDA_CC_RDNA3; + return (GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_AMPERE) || + GGML_CUDA_CC_IS_CDNA(cc) || cc >= GGML_CUDA_CC_RDNA3 || + (GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_PH1); } static bool fp32_mma_hardware_available(const int cc) { diff --git a/ggml/src/ggml-cuda/cpy.cu b/ggml/src/ggml-cuda/cpy.cu index 82367ad3fb..c4ceb4fc57 100644 --- a/ggml/src/ggml-cuda/cpy.cu +++ b/ggml/src/ggml-cuda/cpy.cu @@ -86,6 +86,9 @@ static __global__ void cpy_scalar_transpose(const char * cx, char * cdst, const } } } + + GGML_UNUSED_VARS(ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, + nb12, nb13); } static __device__ void cpy_blck_q8_0_f32(const char * cxi, char * cdsti) { @@ -202,7 +205,7 @@ static void ggml_cpy_scalar_cuda( ne00n = ne00; ne01n = ne01; ne02n = ne02; - } else if (nb00 > nb02) { + } else { ne00n = ne00; ne01n = ne01*ne02; ne02n = 1; diff --git a/ggml/src/ggml-cuda/fattn-tile.cuh b/ggml/src/ggml-cuda/fattn-tile.cuh index c358aa1e87..3e58d64ff9 100644 --- a/ggml/src/ggml-cuda/fattn-tile.cuh +++ b/ggml/src/ggml-cuda/fattn-tile.cuh @@ -609,7 +609,7 @@ static __device__ __forceinline__ void flash_attn_tile_iter( float KQ_sum_add = 0.0f; #pragma unroll for (int i0 = 0; i0 < nbatch_fa; i0 += np*warp_size) { - const float val = !oob_check || i0 + (threadIdx.y % np)*warp_size + threadIdx.x < k_VKQ_sup ? + const float val = !oob_check || i0 + (threadIdx.y % np)*warp_size + threadIdx.x < static_cast(k_VKQ_sup) ? expf(KQ_acc[(i0/(np*warp_size))*cpw + jc] - KQ_max[jc]) : 0.0f; KQ_sum_add += val; tmp[i0/(np*warp_size)][jc1] = val; diff --git a/ggml/src/ggml-cuda/fattn-vec.cuh b/ggml/src/ggml-cuda/fattn-vec.cuh index 6e63e860ac..67aa67ecb9 100644 --- a/ggml/src/ggml-cuda/fattn-vec.cuh +++ b/ggml/src/ggml-cuda/fattn-vec.cuh @@ -155,7 +155,7 @@ static __global__ void flash_attn_ext_vec( for (int i0 = 0; i0 < int(D/sizeof(int)); i0 += WARP_SIZE) { const int i = i0 + threadIdx.x; - if (i0 + WARP_SIZE <= D/sizeof(int) || i < D/sizeof(int)) { + if (i0 + WARP_SIZE <= int(D/sizeof(int)) || i < int(D/sizeof(int))) { tmp_q_i32[i] = 0; } } @@ -272,7 +272,7 @@ static __global__ void flash_attn_ext_vec( KQ_max_new[j] = fmaxf(KQ_max_new[j], sum); - if ((nthreads_KQ == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_KQ) == i_KQ_0) { + if ((nthreads_KQ == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_KQ) == uint32_t(i_KQ_0)) { KQ_reg[j] = sum; } } diff --git a/ggml/src/ggml-cuda/mma.cuh b/ggml/src/ggml-cuda/mma.cuh index c0a9c2c08a..0ed42e87d3 100644 --- a/ggml/src/ggml-cuda/mma.cuh +++ b/ggml/src/ggml-cuda/mma.cuh @@ -889,8 +889,8 @@ namespace ggml_cuda_mma { : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]), "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7]) : "r"(Axi[6]), "r"(Axi[7]), "r"(Bxi[6]), "r"(Bxi[7])); #else - tile<16, 8, float> * D16 = (tile<16, 8, float> *) &D; - tile<16, 8, half2> * A16 = (tile<16, 8, half2> *) &A; + tile <16, 8, float> * D16 = reinterpret_cast *>(&D); + const tile<16, 8, half2> * A16 = reinterpret_cast *>(&A); mma(D16[0], A16[0], B); mma(D16[1], A16[1], B); #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE From e072b2052e9250395e4a28a28d37806342ac5db1 Mon Sep 17 00:00:00 2001 From: Diego Devesa Date: Fri, 28 Nov 2025 07:33:23 -0800 Subject: [PATCH 03/13] ggml : add GGML_SCHED_NO_REALLOC option to disable reallocations in ggml_backend_sched (#17276) * ggml : add GGML_SCHED_NO_REALLOC option to disable reallocations in ggml_backend_sched Enabled in ggml-ci for testing. * llama : update worst-case graph for unified cache * ci : disable op offload in some tests * fix spelling --------- Co-authored-by: Georgi Gerganov --- ci/run.sh | 16 ++++++++-------- examples/embedding/embedding.cpp | 7 ++++--- ggml/CMakeLists.txt | 1 + ggml/src/CMakeLists.txt | 4 ++++ ggml/src/ggml-alloc.c | 11 ++++++++--- ggml/src/ggml-backend.cpp | 12 +++++++++--- src/llama-context.cpp | 4 ++-- tests/CMakeLists.txt | 2 +- 8 files changed, 37 insertions(+), 20 deletions(-) diff --git a/ci/run.sh b/ci/run.sh index 3fec8e9110..1dd65adeaa 100755 --- a/ci/run.sh +++ b/ci/run.sh @@ -45,7 +45,7 @@ sd=`dirname $0` cd $sd/../ SRC=`pwd` -CMAKE_EXTRA="-DLLAMA_FATAL_WARNINGS=ON -DLLAMA_CURL=ON" +CMAKE_EXTRA="-DLLAMA_FATAL_WARNINGS=ON -DLLAMA_CURL=ON -DGGML_SCHED_NO_REALLOC=ON" if [ ! -z ${GG_BUILD_METAL} ]; then CMAKE_EXTRA="${CMAKE_EXTRA} -DGGML_METAL=ON" @@ -428,10 +428,10 @@ function gg_run_qwen3_0_6b { (time ./bin/llama-imatrix --model ${model_f16} -f ${wiki_test} -ngl 99 -c 1024 -b 512 --chunks 2 ) 2>&1 | tee -a $OUT/${ci}-imatrix.log - (time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 10 -c 1024 -fa off ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log - (time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 10 -c 1024 -fa on ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log - (time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 99 -c 1024 -fa off ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log - (time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 99 -c 1024 -fa on ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log + (time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 10 -c 1024 -fa off --no-op-offload) 2>&1 | tee -a $OUT/${ci}-save-load-state.log + (time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 10 -c 1024 -fa on --no-op-offload) 2>&1 | tee -a $OUT/${ci}-save-load-state.log + (time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 99 -c 1024 -fa off ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log + (time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 99 -c 1024 -fa on ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log function check_ppl { qnt="$1" @@ -523,8 +523,8 @@ function gg_run_embd_bge_small { ./bin/llama-quantize ${model_f16} ${model_q8_0} q8_0 - (time ./bin/llama-embedding --model ${model_f16} -p "I believe the meaning of life is" -ngl 99 -c 0 ) 2>&1 | tee -a $OUT/${ci}-tg-f16.log - (time ./bin/llama-embedding --model ${model_q8_0} -p "I believe the meaning of life is" -ngl 99 -c 0 ) 2>&1 | tee -a $OUT/${ci}-tg-q8_0.log + (time ./bin/llama-embedding --model ${model_f16} -p "I believe the meaning of life is" -ngl 99 -c 0 --no-op-offload) 2>&1 | tee -a $OUT/${ci}-tg-f16.log + (time ./bin/llama-embedding --model ${model_q8_0} -p "I believe the meaning of life is" -ngl 99 -c 0 --no-op-offload) 2>&1 | tee -a $OUT/${ci}-tg-q8_0.log set +e } @@ -564,7 +564,7 @@ function gg_run_rerank_tiny { model_f16="${path_models}/ggml-model-f16.gguf" # for this model, the SEP token is "" - (time ./bin/llama-embedding --model ${model_f16} -p "what is panda?\thi\nwhat is panda?\tit's a bear\nwhat is panda?\tThe giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China." -ngl 99 -c 0 --pooling rank --embd-normalize -1 --verbose-prompt) 2>&1 | tee -a $OUT/${ci}-rk-f16.log + (time ./bin/llama-embedding --model ${model_f16} -p "what is panda?\thi\nwhat is panda?\tit's a bear\nwhat is panda?\tThe giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China." -ngl 99 -c 0 --pooling rank --embd-normalize -1 --no-op-offload --verbose-prompt) 2>&1 | tee -a $OUT/${ci}-rk-f16.log # sample output # rerank score 0: 0.029 diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp index 9e3ab5905b..fe91b308cd 100644 --- a/examples/embedding/embedding.cpp +++ b/examples/embedding/embedding.cpp @@ -104,12 +104,16 @@ int main(int argc, char ** argv) { params.embedding = true; + // get max number of sequences per batch + const int n_seq_max = llama_max_parallel_sequences(); + // if the number of prompts that would be encoded is known in advance, it's more efficient to specify the // --parallel argument accordingly. for convenience, if not specified, we fallback to unified KV cache // in order to support any number of prompts if (params.n_parallel == 1) { LOG_INF("%s: n_parallel == 1 -> unified KV cache is enabled\n", __func__); params.kv_unified = true; + params.n_parallel = n_seq_max; } // utilize the full context @@ -123,9 +127,6 @@ int main(int argc, char ** argv) { params.n_ubatch = params.n_batch; } - // get max number of sequences per batch - const int n_seq_max = llama_max_parallel_sequences(); - llama_backend_init(); llama_numa_init(params.numa); diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index 0211255a76..9b10df00da 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -183,6 +183,7 @@ endif() # ggml core set(GGML_SCHED_MAX_COPIES "4" CACHE STRING "ggml: max input copies for pipeline parallelism") option(GGML_CPU "ggml: enable CPU backend" ON) +option(GGML_SCHED_NO_REALLOC "ggml: disallow reallocations in ggml-alloc (for debugging)" OFF) # 3rd party libs / backends option(GGML_ACCELERATE "ggml: enable Accelerate framework" ON) diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt index a4499509ec..a36f5b6647 100644 --- a/ggml/src/CMakeLists.txt +++ b/ggml/src/CMakeLists.txt @@ -221,6 +221,10 @@ if (GGML_BACKEND_DL) target_compile_definitions(ggml-base PUBLIC GGML_BACKEND_DL) endif() +if (GGML_SCHED_NO_REALLOC) + target_compile_definitions(ggml-base PUBLIC GGML_SCHED_NO_REALLOC) +endif() + add_library(ggml ggml-backend-reg.cpp) add_library(ggml::ggml ALIAS ggml) diff --git a/ggml/src/ggml-alloc.c b/ggml/src/ggml-alloc.c index 91aff205f1..218222ece8 100644 --- a/ggml/src/ggml-alloc.c +++ b/ggml/src/ggml-alloc.c @@ -921,10 +921,15 @@ bool ggml_gallocr_reserve_n(ggml_gallocr_t galloc, struct ggml_cgraph * graph, c } if (realloc) { #ifndef NDEBUG - size_t cur_size = galloc->buffers[i] ? ggml_vbuffer_size(galloc->buffers[i]) : 0; - GGML_LOG_DEBUG("%s: reallocating %s buffer from size %.02f MiB to %.02f MiB\n", __func__, ggml_backend_buft_name(galloc->bufts[i]), cur_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0); + { + size_t cur_size = galloc->buffers[i] ? ggml_vbuffer_size(galloc->buffers[i]) : 0; + if (cur_size > 0) { + GGML_LOG_DEBUG("%s: reallocating %s buffer from size %.02f MiB to %.02f MiB\n", + __func__, ggml_backend_buft_name(galloc->bufts[i]), + cur_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0); + } + } #endif - ggml_vbuffer_free(galloc->buffers[i]); galloc->buffers[i] = ggml_vbuffer_alloc(galloc->bufts[i], galloc->buf_tallocs[i], GGML_BACKEND_BUFFER_USAGE_COMPUTE); if (galloc->buffers[i] == NULL) { diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp index eeaf35c169..4cf377e7f3 100644 --- a/ggml/src/ggml-backend.cpp +++ b/ggml/src/ggml-backend.cpp @@ -1395,14 +1395,20 @@ static bool ggml_backend_sched_alloc_splits(ggml_backend_sched_t sched) { // allocate graph if (backend_ids_changed || !ggml_gallocr_alloc_graph(sched->galloc, &sched->graph)) { +#ifdef GGML_SCHED_NO_REALLOC + GGML_ABORT("%s: failed to allocate graph, but graph re-allocation is disabled by GGML_SCHED_NO_REALLOC\n", __func__); +#endif + +#ifndef NDEBUG + GGML_LOG_DEBUG("%s: failed to allocate graph, reserving (backend_ids_changed = %d)\n", __func__, backend_ids_changed); +#endif + // the re-allocation may cause the split inputs to be moved to a different address // synchronize without ggml_backend_sched_synchronize to avoid changing cur_copy for (int i = 0; i < sched->n_backends; i++) { ggml_backend_synchronize(sched->backends[i]); } -#ifndef NDEBUG - GGML_LOG_DEBUG("%s: failed to allocate graph, reserving (backend_ids_changed = %d)\n", __func__, backend_ids_changed); -#endif + ggml_gallocr_reserve_n(sched->galloc, &sched->graph, sched->node_backend_ids, sched->leaf_backend_ids); if (!ggml_gallocr_alloc_graph(sched->galloc, &sched->graph)) { GGML_LOG_ERROR("%s: failed to allocate graph\n", __func__); diff --git a/src/llama-context.cpp b/src/llama-context.cpp index a58914429c..e04f0fc4f9 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -300,7 +300,7 @@ llama_context::llama_context( cross.v_embd.clear(); - const uint32_t n_seqs = cparams.kv_unified ? 1 : cparams.n_seq_max; + const uint32_t n_seqs = cparams.n_seq_max; const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch); // avoid reserving graphs with zero outputs - assume one output per sequence @@ -543,7 +543,7 @@ bool llama_context::memory_update(bool optimize) { throw std::runtime_error("failed to initialize memory context"); } - const uint32_t n_seqs = cparams.kv_unified ? 1 : cparams.n_seq_max; + const uint32_t n_seqs = cparams.n_seq_max; const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch); auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get()); diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index d9cc5e933f..9361a113a1 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -196,7 +196,7 @@ if (NOT WIN32) llama_build_and_test(test-arg-parser.cpp) endif() -if (NOT LLAMA_SANITIZE_ADDRESS) +if (NOT LLAMA_SANITIZE_ADDRESS AND NOT GGML_SCHED_NO_REALLOC) # TODO: repair known memory leaks llama_build_and_test(test-opt.cpp) endif() From 3ce7a65c2f2529a8fc566b4aead53b088f7faec2 Mon Sep 17 00:00:00 2001 From: o7si <32285332+o7si@users.noreply.github.com> Date: Sat, 29 Nov 2025 02:14:00 +0800 Subject: [PATCH 04/13] server: fix: /metrics endpoint returning JSON-escaped Prometheus format (#17386) * fix: /metrics endpoint returning JSON-escaped Prometheus format * mod: remove string overload from ok() method --- tools/server/server.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tools/server/server.cpp b/tools/server/server.cpp index 05bbe648c1..96b2df27f7 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -2713,7 +2713,8 @@ public: res->headers["Process-Start-Time-Unix"] = std::to_string(res_task->t_start); res->content_type = "text/plain; version=0.0.4"; - res->ok(prometheus.str()); + res->status = 200; + res->data = prometheus.str(); return res; }; From 03914c7ef826caf0b6371a6d1de270cda102b542 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?DAN=E2=84=A2?= Date: Fri, 28 Nov 2025 13:29:36 -0500 Subject: [PATCH 05/13] common : move all common_chat_parse_* to chat-parser.cpp. (#17481) --- common/chat-parser.cpp | 968 +++++++++++++++++++++++++++++++++++++++++ common/chat.cpp | 952 ---------------------------------------- 2 files changed, 968 insertions(+), 952 deletions(-) diff --git a/common/chat-parser.cpp b/common/chat-parser.cpp index ff83102788..301f439a6f 100644 --- a/common/chat-parser.cpp +++ b/common/chat-parser.cpp @@ -13,6 +13,120 @@ using json = nlohmann::ordered_json; +static void parse_prefixed_json_tool_call_array(common_chat_msg_parser & builder, + const common_regex & prefix, + size_t rstrip_prefix = 0) { + static const std::vector> args_paths = { { "arguments" } }; + if (auto res = builder.try_find_regex(prefix)) { + builder.move_back(rstrip_prefix); + auto tool_calls = builder.consume_json_with_dumped_args(args_paths); + if (!builder.add_tool_calls(tool_calls.value) || tool_calls.is_partial) { + throw common_chat_msg_partial_exception("incomplete tool call array"); + } + } else { + builder.add_content(builder.consume_rest()); + } +} + +static std::string wrap_code_as_arguments(common_chat_msg_parser & builder, const std::string & code) { + std::string arguments; + if (builder.is_partial()) { + arguments = (json{ + { "code", code + builder.healing_marker() } + }) + .dump(); + auto idx = arguments.find(builder.healing_marker()); + if (idx != std::string::npos) { + arguments.resize(idx); + } + } else { + arguments = (json{ + { "code", code } + }) + .dump(); + } + return arguments; +} + +/** + * Takes a prefix regex that must have 1 group to capture the function name, a closing suffix, and expects json parameters in between. + * Aggregates the prefix, suffix and in-between text into the content. + */ +static void parse_json_tool_calls( + common_chat_msg_parser & builder, + const std::optional & block_open, + const std::optional & function_regex_start_only, + const std::optional & function_regex, + const common_regex & close_regex, + const std::optional & block_close, + bool allow_raw_python = false, + const std::function & get_function_name = + nullptr) { + auto parse_tool_calls = [&]() { + size_t from = std::string::npos; + auto first = true; + while (true) { + auto start_pos = builder.pos(); + auto res = function_regex_start_only && first ? builder.try_consume_regex(*function_regex_start_only) : + function_regex ? builder.try_find_regex(*function_regex, from) : + std::nullopt; + + if (res) { + std::string name; + if (get_function_name) { + name = get_function_name(*res); + } else { + GGML_ASSERT(res->groups.size() == 2); + name = builder.str(res->groups[1]); + } + first = false; + if (name.empty()) { + // get_function_name signalled us that we should skip this match and treat it as content. + from = res->groups[0].begin + 1; + continue; + } + from = std::string::npos; + + auto maybe_raw_python = name == "python" && allow_raw_python; + if (builder.input()[builder.pos()] == '{' || !maybe_raw_python) { + if (auto arguments = builder.try_consume_json_with_dumped_args({ {} })) { + if (!builder.add_tool_call(name, "", arguments->value) || arguments->is_partial) { + throw common_chat_msg_partial_exception("incomplete tool call"); + } + builder.consume_regex(close_regex); + } + continue; + } + if (maybe_raw_python) { + auto arguments = wrap_code_as_arguments(builder, builder.consume_rest()); + if (!builder.add_tool_call(name, "", arguments)) { + throw common_chat_msg_partial_exception("incomplete tool call"); + } + return; + } + throw common_chat_msg_partial_exception("incomplete tool call"); + } else { + builder.move_to(start_pos); + } + break; + } + if (block_close) { + builder.consume_regex(*block_close); + } + builder.consume_spaces(); + builder.add_content(builder.consume_rest()); + }; + if (block_open) { + if (auto res = builder.try_find_regex(*block_open)) { + parse_tool_calls(); + } else { + builder.add_content(builder.consume_rest()); + } + } else { + parse_tool_calls(); + } +} + common_chat_msg_parser::common_chat_msg_parser(const std::string & input, bool is_partial, const common_chat_syntax & syntax) : input_(input), is_partial_(is_partial), syntax_(syntax) { @@ -532,3 +646,857 @@ std::optional common_chat_msg_parse void common_chat_msg_parser::clear_tools() { result_.tool_calls.clear(); } + +/** + * All common_chat_parse_* moved from chat.cpp to chat-parser.cpp below + * to reduce incremental compile time for parser changes. + */ +static void common_chat_parse_generic(common_chat_msg_parser & builder) { + if (!builder.syntax().parse_tool_calls) { + builder.add_content(builder.consume_rest()); + return; + } + static const std::vector> content_paths = { + {"response"}, + }; + static const std::vector> args_paths = { + {"tool_call", "arguments"}, + {"tool_calls", "arguments"}, + }; + auto data = builder.consume_json_with_dumped_args(args_paths, content_paths); + if (data.value.contains("tool_calls")) { + if (!builder.add_tool_calls(data.value.at("tool_calls")) || data.is_partial) { + throw common_chat_msg_partial_exception("incomplete tool calls"); + } + } else if (data.value.contains("tool_call")) { + if (!builder.add_tool_call(data.value.at("tool_call")) || data.is_partial) { + throw common_chat_msg_partial_exception("incomplete tool call"); + } + } else if (data.value.contains("response")) { + const auto & response = data.value.at("response"); + builder.add_content(response.is_string() ? response.template get() : response.dump(2)); + if (data.is_partial) { + throw common_chat_msg_partial_exception("incomplete response"); + } + } else { + throw common_chat_msg_partial_exception("Expected 'tool_call', 'tool_calls' or 'response' in JSON"); + } +} + +static void common_chat_parse_mistral_nemo(common_chat_msg_parser & builder) { + if (!builder.syntax().parse_tool_calls) { + builder.add_content(builder.consume_rest()); + return; + } + + static const common_regex prefix(regex_escape("[TOOL_CALLS]")); + parse_prefixed_json_tool_call_array(builder, prefix); +} + +static void common_chat_parse_magistral(common_chat_msg_parser & builder) { + builder.try_parse_reasoning("[THINK]", "[/THINK]"); + + if (!builder.syntax().parse_tool_calls) { + builder.add_content(builder.consume_rest()); + return; + } + + static const common_regex prefix(regex_escape("[TOOL_CALLS]")); + parse_prefixed_json_tool_call_array(builder, prefix); +} + +static void common_chat_parse_command_r7b(common_chat_msg_parser & builder) { + builder.try_parse_reasoning("<|START_THINKING|>", "<|END_THINKING|>"); + + static const common_regex start_action_regex("<\\|START_ACTION\\|>"); + static const common_regex end_action_regex("<\\|END_ACTION\\|>"); + static const common_regex start_response_regex("<\\|START_RESPONSE\\|>"); + static const common_regex end_response_regex("<\\|END_RESPONSE\\|>"); + + if (auto res = builder.try_find_regex(start_action_regex)) { + // If we didn't extract thoughts, prelude includes them. + auto tool_calls = builder.consume_json_with_dumped_args({{"parameters"}}); + for (const auto & tool_call : tool_calls.value) { + std::string name = tool_call.contains("tool_name") ? tool_call.at("tool_name") : ""; + std::string id = tool_call.contains("tool_call_id") ? tool_call.at("tool_call_id") : ""; + std::string arguments = tool_call.contains("parameters") ? tool_call.at("parameters") : ""; + if (!builder.add_tool_call(name, id, arguments) || tool_calls.is_partial) { + throw common_chat_msg_partial_exception("incomplete tool call"); + } + } + if (tool_calls.is_partial) { + throw common_chat_msg_partial_exception("incomplete tool call"); + } + builder.consume_regex(end_action_regex); + } else if (auto res = builder.try_find_regex(start_response_regex)) { + if (!builder.try_find_regex(end_response_regex)) { + builder.add_content(builder.consume_rest()); + throw common_chat_msg_partial_exception(end_response_regex.str()); + } + } else { + builder.add_content(builder.consume_rest()); + } +} + +static void common_chat_parse_llama_3_1(common_chat_msg_parser & builder, bool with_builtin_tools = false) { + builder.try_parse_reasoning("", ""); + + if (!builder.syntax().parse_tool_calls) { + builder.add_content(builder.consume_rest()); + return; + } + + static const common_regex function_regex( + "\\s*\\{\\s*(?:\"type\"\\s*:\\s*\"function\"\\s*,\\s*)?\"name\"\\s*:\\s*\"([^\"]+)\"\\s*,\\s*\"parameters\"\\s*: "); + static const common_regex close_regex("\\}\\s*"); + + static const common_regex function_name_regex("\\s*(\\w+)\\s*\\.\\s*call\\("); + static const common_regex arg_name_regex("\\s*(\\w+)\\s*=\\s*"); + + if (with_builtin_tools) { + static const common_regex builtin_call_regex("<\\|python_tag\\|>"); + if (auto res = builder.try_find_regex(builtin_call_regex)) { + auto fun_res = builder.consume_regex(function_name_regex); + auto function_name = builder.str(fun_res.groups[1]); + + common_healing_marker healing_marker; + json args = json::object(); + while (true) { + if (auto arg_res = builder.try_consume_regex(arg_name_regex)) { + auto arg_name = builder.str(arg_res->groups[1]); + auto partial = builder.consume_json(); + args[arg_name] = partial.json; + healing_marker.marker = partial.healing_marker.marker; + healing_marker.json_dump_marker = partial.healing_marker.json_dump_marker; + builder.consume_spaces(); + if (!builder.try_consume_literal(",")) { + break; + } + } else { + break; + } + } + builder.consume_literal(")"); + builder.consume_spaces(); + + auto arguments = args.dump(); + if (!builder.add_tool_call(function_name, "", arguments)) { + throw common_chat_msg_partial_exception("Incomplete tool call"); + } + return; + } + } + parse_json_tool_calls( + builder, + /* block_open= */ std::nullopt, + /* function_regex_start_only= */ function_regex, + /* function_regex= */ std::nullopt, + close_regex, + std::nullopt); + +} + +static void common_chat_parse_deepseek_r1(common_chat_msg_parser & builder) { + builder.try_parse_reasoning("", ""); + if (!builder.syntax().parse_tool_calls) { + builder.add_content(builder.consume_rest()); + return; + } + + static const common_regex tool_calls_begin("(?:<|tool▁calls▁begin|>|<|tool_calls_begin|>|<|tool calls begin|>|<|tool\\\\_calls\\\\_begin|>|<|tool▁calls|>)"); + static const common_regex tool_calls_end("<|tool▁calls▁end|>"); + static const common_regex function_regex("(?:<|tool▁call▁begin|>)?function<|tool▁sep|>([^\n]+)\n```json\n"); + static const common_regex close_regex("```[\\s\\r\\n]*<|tool▁call▁end|>"); + + parse_json_tool_calls( + builder, + /* block_open= */ tool_calls_begin, + /* function_regex_start_only= */ std::nullopt, + function_regex, + close_regex, + tool_calls_end); +} + +static void common_chat_parse_deepseek_v3_1_content(common_chat_msg_parser & builder) { + static const common_regex function_regex("(?:<|tool▁call▁begin|>)?([^\\n<]+)(?:<|tool▁sep|>)"); + + static const common_regex close_regex("(?:[\\s]*)?<|tool▁call▁end|>"); + static const common_regex tool_calls_begin("(?:<|tool▁calls▁begin|>|<|tool_calls_begin|>|<|tool calls begin|>|<|tool\\\\_calls\\\\_begin|>|<|tool▁calls|>)"); + static const common_regex tool_calls_end("<|tool▁calls▁end|>"); + + if (!builder.syntax().parse_tool_calls) { + LOG_DBG("%s: not parse_tool_calls\n", __func__); + builder.add_content(builder.consume_rest()); + return; + } + + LOG_DBG("%s: parse_tool_calls\n", __func__); + + parse_json_tool_calls( + builder, + /* block_open= */ tool_calls_begin, + /* function_regex_start_only= */ std::nullopt, + function_regex, + close_regex, + tool_calls_end); +} + +static void common_chat_parse_deepseek_v3_1(common_chat_msg_parser & builder) { + // DeepSeek V3.1 outputs reasoning content between "" and "" tags, followed by regular content + // First try to parse using the standard reasoning parsing method + LOG_DBG("%s: thinking_forced_open: %s\n", __func__, std::to_string(builder.syntax().thinking_forced_open).c_str()); + + auto start_pos = builder.pos(); + auto found_end_think = builder.try_find_literal(""); + builder.move_to(start_pos); + + if (builder.syntax().thinking_forced_open && !builder.is_partial() && !found_end_think) { + LOG_DBG("%s: no end_think, not partial, adding content\n", __func__); + common_chat_parse_deepseek_v3_1_content(builder); + } else if (builder.try_parse_reasoning("", "")) { + // If reasoning was parsed successfully, the remaining content is regular content + LOG_DBG("%s: parsed reasoning, adding content\n", __func__); + // <|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>NAME\n```json\nJSON\n```<|tool▁call▁end|><|tool▁calls▁end|> + common_chat_parse_deepseek_v3_1_content(builder); + } else { + if (builder.syntax().reasoning_format == COMMON_REASONING_FORMAT_NONE) { + LOG_DBG("%s: reasoning_format none, adding content\n", __func__); + common_chat_parse_deepseek_v3_1_content(builder); + return; + } + // If no reasoning tags found, check if we should treat everything as reasoning + if (builder.syntax().thinking_forced_open) { + // If thinking is forced open but no tags found, treat everything as reasoning + LOG_DBG("%s: thinking_forced_open, adding reasoning content\n", __func__); + builder.add_reasoning_content(builder.consume_rest()); + } else { + LOG_DBG("%s: no thinking_forced_open, adding content\n", __func__); + // <|tool▁call▁begin|>NAME<|tool▁sep|>JSON<|tool▁call▁end|> + common_chat_parse_deepseek_v3_1_content(builder); + } + } +} + +static void common_chat_parse_minimax_m2(common_chat_msg_parser & builder) { + static const xml_tool_call_format form { + /* form.scope_start = */ "", + /* form.tool_start = */ "", + /* form.key_start = */ "", + /* form.val_end = */ "", + /* form.tool_end = */ "", + /* form.scope_end = */ "", + }; + builder.consume_reasoning_with_xml_tool_calls(form, "", ""); +} + +static void common_chat_parse_qwen3_coder_xml(common_chat_msg_parser & builder) { + static const xml_tool_call_format form = ([]() { + xml_tool_call_format form {}; + form.scope_start = ""; + form.tool_start = "", ""); +} + +static void common_chat_parse_apriel_1_5(common_chat_msg_parser & builder) { + static const xml_tool_call_format form = ([]() { + xml_tool_call_format form {}; + form.scope_start = "["; + form.tool_start = "{\"name\": \""; + form.tool_sep = "\", \"arguments\": {"; + form.key_start = "\""; + form.key_val_sep = "\": "; + form.val_end = ", "; + form.tool_end = "}, "; + form.scope_end = "]"; + form.raw_argval = false; + form.last_val_end = ""; + form.last_tool_end = "}"; + return form; + })(); + builder.consume_reasoning_with_xml_tool_calls(form, "", ""); +} + +static void common_chat_parse_xiaomi_mimo(common_chat_msg_parser & builder) { + static const xml_tool_call_format form = ([]() { + xml_tool_call_format form {}; + form.scope_start = ""; + form.tool_start = "\n{\"name\": \""; + form.tool_sep = "\", \"arguments\": {"; + form.key_start = "\""; + form.key_val_sep = "\": "; + form.val_end = ", "; + form.tool_end = "}\n"; + form.scope_end = ""; + form.raw_argval = false; + form.last_val_end = ""; + return form; + })(); + builder.consume_reasoning_with_xml_tool_calls(form); +} + +static void common_chat_parse_gpt_oss(common_chat_msg_parser & builder) { + static const std::string constraint = "(?: (<\\|constrain\\|>)?([a-zA-Z0-9_-]+))"; + static const std::string recipient("(?: to=functions\\.([^<\\s]+))"); + + static const common_regex start_regex("<\\|start\\|>assistant"); + static const common_regex analysis_regex("<\\|channel\\|>analysis"); + static const common_regex final_regex("<\\|channel\\|>final" + constraint + "?"); + static const common_regex preamble_regex("<\\|channel\\|>commentary"); + static const common_regex tool_call1_regex(recipient + "<\\|channel\\|>(analysis|commentary)" + constraint + "?"); + static const common_regex tool_call2_regex("<\\|channel\\|>(analysis|commentary)" + recipient + constraint + "?"); + + auto consume_end = [&](bool include_end = false) { + if (auto res = builder.try_find_literal("<|end|>")) { + return res->prelude + (include_end ? builder.str(res->groups[0]) : ""); + } + return builder.consume_rest(); + }; + + auto handle_tool_call = [&](const std::string & name) { + if (auto args = builder.try_consume_json_with_dumped_args({{}})) { + if (builder.syntax().parse_tool_calls) { + if (!builder.add_tool_call(name, "", args->value) || args->is_partial) { + throw common_chat_msg_partial_exception("incomplete tool call"); + } + } else if (args->is_partial) { + throw common_chat_msg_partial_exception("incomplete tool call"); + } + } + }; + + auto regex_match = [](const common_regex & regex, const std::string & input) -> std::optional { + auto match = regex.search(input, 0, true); + if (match.type == COMMON_REGEX_MATCH_TYPE_FULL) { + return match; + } + return std::nullopt; + }; + + do { + auto header_start_pos = builder.pos(); + auto content_start = builder.try_find_literal("<|message|>"); + if (!content_start) { + throw common_chat_msg_partial_exception("incomplete header"); + } + + auto header = content_start->prelude; + + if (auto match = regex_match(tool_call1_regex, header)) { + auto group = match->groups[1]; + auto name = header.substr(group.begin, group.end - group.begin); + handle_tool_call(name); + continue; + } + + if (auto match = regex_match(tool_call2_regex, header)) { + auto group = match->groups[2]; + auto name = header.substr(group.begin, group.end - group.begin); + handle_tool_call(name); + continue; + } + + if (regex_match(analysis_regex, header)) { + builder.move_to(header_start_pos); + if (builder.syntax().reasoning_format == COMMON_REASONING_FORMAT_NONE || builder.syntax().reasoning_in_content) { + builder.add_content(consume_end(true)); + } else { + builder.try_parse_reasoning("<|channel|>analysis<|message|>", "<|end|>"); + } + continue; + } + + if(regex_match(final_regex, header) || regex_match(preamble_regex, header)) { + builder.add_content(consume_end()); + continue; + } + + // Possibly a malformed message, attempt to recover by rolling + // back to pick up the next <|start|> + LOG_DBG("%s: unknown header from message: %s\n", __func__, header.c_str()); + builder.move_to(header_start_pos); + } while (builder.try_find_regex(start_regex, std::string::npos, false)); + + auto remaining = builder.consume_rest(); + if (!remaining.empty()) { + LOG_DBG("%s: content after last message: %s\n", __func__, remaining.c_str()); + } +} + +static void common_chat_parse_glm_4_5(common_chat_msg_parser & builder) { + static const xml_tool_call_format form { + /* form.scope_start = */ "", + /* form.tool_start = */ "", + /* form.tool_sep = */ "", + /* form.key_start = */ "", + /* form.key_val_sep = */ "", + /* form.val_end = */ "", + /* form.tool_end = */ "", + /* form.scope_end = */ "", + /* form.key_val_sep2 = */ "", + }; + builder.consume_reasoning_with_xml_tool_calls(form, "", ""); +} + +static void common_chat_parse_firefunction_v2(common_chat_msg_parser & builder) { + if (!builder.syntax().parse_tool_calls) { + builder.add_content(builder.consume_rest()); + return; + } + static const common_regex prefix(regex_escape(" functools[")); + parse_prefixed_json_tool_call_array(builder, prefix, /* rstrip_prefix= */ 1); +} + +static void common_chat_parse_functionary_v3_2(common_chat_msg_parser & builder) { + static const common_regex function_regex_start_only(R"((\w+\n\{|python\n|all\n))"); + static const common_regex function_regex(R"(>>>(\w+\n\{|python\n|all\n))"); + static const common_regex close_regex(R"(\s*)"); + + parse_json_tool_calls( + builder, + std::nullopt, + function_regex_start_only, + function_regex, + close_regex, + std::nullopt, + /* allow_raw_python= */ true, + /* get_function_name= */ [&](const auto & res) -> std::string { + auto at_start = res.groups[0].begin == 0; + auto name = builder.str(res.groups[1]); + if (!name.empty() && name.back() == '{') { + // Unconsume the opening brace '{' to ensure the JSON parsing goes well. + builder.move_back(1); + } + auto idx = name.find_last_not_of("\n{"); + name = name.substr(0, idx + 1); + if (at_start && name == "all") { + return ""; + } + return name; + }); +} + +static void common_chat_parse_functionary_v3_1_llama_3_1(common_chat_msg_parser & builder) { + if (!builder.syntax().parse_tool_calls) { + builder.add_content(builder.consume_rest()); + return; + } + // This version of Functionary still supports the llama 3.1 tool call format for the python tool. + static const common_regex python_tag_regex(regex_escape("<|python_tag|>")); + + static const common_regex function_regex(R"()"); + static const common_regex close_regex(R"()"); + + parse_json_tool_calls( + builder, + /* block_open= */ std::nullopt, + /* function_regex_start_only= */ std::nullopt, + function_regex, + close_regex, + std::nullopt); + + if (auto res = builder.try_find_regex(python_tag_regex)) { + auto arguments = wrap_code_as_arguments(builder, builder.consume_rest()); + builder.add_tool_call("python", "", arguments); + return; + } +} + +static void common_chat_parse_hermes_2_pro(common_chat_msg_parser & builder) { + builder.try_parse_reasoning("", ""); + if (!builder.syntax().parse_tool_calls) { + builder.add_content(builder.consume_rest()); + return; + } + + static const common_regex open_regex( + "(?:" + "(```(?:xml|json)?\\n\\s*)?" // match 1 (block_start) + "(" // match 2 (open_tag) + "" + "|" + "|" + "|" + "|" + "|" + "|" + "|" + ")?" + "(\\s*\\{\\s*\"name\")" // match 3 (named tool call) + ")" + "|]+)>" // match 4 (function name) + "|" // match 5 (function name again) + ); + + while (auto res = builder.try_find_regex(open_regex)) { + const auto & block_start = res->groups[1]; + std::string block_end = block_start.empty() ? "" : "```"; + + const auto & open_tag = res->groups[2]; + std::string close_tag; + + if (!res->groups[3].empty()) { + builder.move_to(res->groups[3].begin); + close_tag = open_tag.empty() ? "" : "value) || tool_call->is_partial) { + throw common_chat_msg_partial_exception("incomplete tool call"); + } + builder.consume_spaces(); + builder.consume_literal(close_tag); + builder.consume_spaces(); + if (!block_end.empty()) { + builder.consume_literal(block_end); + builder.consume_spaces(); + } + } else { + throw common_chat_msg_partial_exception("failed to parse tool call"); + } + } else { + auto function_name = builder.str(res->groups[4]); + if (function_name.empty()) { + function_name = builder.str(res->groups[5]); + } + GGML_ASSERT(!function_name.empty()); + + close_tag = ""; + + if (auto arguments = builder.try_consume_json_with_dumped_args({{}})) { + if (!builder.add_tool_call(function_name, "", arguments->value) || arguments->is_partial) { + throw common_chat_msg_partial_exception("incomplete tool call"); + } + builder.consume_spaces(); + builder.consume_literal(close_tag); + builder.consume_spaces(); + if (!block_end.empty()) { + builder.consume_literal(block_end); + builder.consume_spaces(); + } + } + } + } + + builder.add_content(builder.consume_rest()); +} + +static void common_chat_parse_granite(common_chat_msg_parser & builder) { + // Parse thinking tags + static const common_regex start_think_regex(regex_escape("")); + static const common_regex end_think_regex(regex_escape("")); + // Granite models output partial tokens such as "<" and "groups[0].begin); + builder.try_find_regex(end_think_regex, std::string::npos, false); + // Restore position for try_parse_reasoning() + builder.move_to(res->groups[0].begin); + } + builder.try_parse_reasoning("", ""); + + // Parse response tags + static const common_regex start_response_regex(regex_escape("")); + static const common_regex end_response_regex(regex_escape("")); + // Granite models output partial tokens such as "<" and "")); + if (auto res = builder.try_find_regex(tool_call_regex)) { + builder.move_to(res->groups[0].end); + + // Expect JSON array of tool calls + if (auto tool_call = builder.try_consume_json_with_dumped_args({{{"arguments"}}})) { + if (!builder.add_tool_calls(tool_call->value) || tool_call->is_partial) { + throw common_chat_msg_partial_exception("incomplete tool call"); + } + } + } else { + builder.add_content(builder.consume_rest()); + } +} + +static void common_chat_parse_nemotron_v2(common_chat_msg_parser & builder) { + // Parse thinking tags + builder.try_parse_reasoning("", ""); + if (!builder.syntax().parse_tool_calls) { + builder.add_content(builder.consume_rest()); + return; + } + + // Look for tool calls + static const common_regex tool_call_regex(regex_escape("")); + if (auto res = builder.try_find_regex(tool_call_regex)) { + builder.move_to(res->groups[0].end); + + // Expect JSON array of tool calls + auto tool_calls_data = builder.consume_json(); + if (tool_calls_data.json.is_array()) { + if (!builder.try_consume_literal("")) { + throw common_chat_msg_partial_exception("Incomplete tool call"); + } + builder.add_tool_calls(tool_calls_data.json); + } else { + throw common_chat_msg_partial_exception("Incomplete tool call"); + } + } + builder.add_content(builder.consume_rest()); +} + +static void common_chat_parse_apertus(common_chat_msg_parser & builder) { + // Parse thinking tags + builder.try_parse_reasoning("<|inner_prefix|>", "<|inner_suffix|>"); + if (!builder.syntax().parse_tool_calls) { + builder.add_content(builder.consume_rest()); + return; + } + + // Look for tool calls + static const common_regex tool_call_regex(regex_escape("<|tools_prefix|>")); + if (auto res = builder.try_find_regex(tool_call_regex)) { + builder.move_to(res->groups[0].end); + + auto tool_calls_data = builder.consume_json(); + if (tool_calls_data.json.is_array()) { + builder.consume_spaces(); + if (!builder.try_consume_literal("<|tools_suffix|>")) { + throw common_chat_msg_partial_exception("Incomplete tool call"); + } + for (const auto & value : tool_calls_data.json) { + if (value.is_object()) { + builder.add_tool_call_short_form(value); + } + } + } else { + throw common_chat_msg_partial_exception("Incomplete tool call"); + } + } + builder.add_content(builder.consume_rest()); +} + + +static void common_chat_parse_lfm2(common_chat_msg_parser & builder) { + if (!builder.syntax().parse_tool_calls) { + builder.add_content(builder.consume_rest()); + return; + } + + // LFM2 format: <|tool_call_start|>[{"name": "get_current_time", "arguments": {"location": "Paris"}}]<|tool_call_end|> + static const common_regex tool_call_start_regex(regex_escape("<|tool_call_start|>")); + static const common_regex tool_call_end_regex(regex_escape("<|tool_call_end|>")); + + // Loop through all tool calls + while (auto res = builder.try_find_regex(tool_call_start_regex, std::string::npos, /* add_prelude_to_content= */ true)) { + builder.move_to(res->groups[0].end); + + // Parse JSON array format: [{"name": "...", "arguments": {...}}] + auto tool_calls_data = builder.consume_json(); + + // Consume end marker + builder.consume_spaces(); + if (!builder.try_consume_regex(tool_call_end_regex)) { + throw common_chat_msg_partial_exception("Expected <|tool_call_end|>"); + } + + // Process each tool call in the array + if (tool_calls_data.json.is_array()) { + for (const auto & tool_call : tool_calls_data.json) { + if (!tool_call.is_object()) { + throw common_chat_msg_partial_exception("Tool call must be an object"); + } + + if (!tool_call.contains("name")) { + throw common_chat_msg_partial_exception("Tool call missing 'name' field"); + } + + std::string function_name = tool_call.at("name"); + std::string arguments = "{}"; + + if (tool_call.contains("arguments")) { + if (tool_call.at("arguments").is_object()) { + arguments = tool_call.at("arguments").dump(); + } else if (tool_call.at("arguments").is_string()) { + arguments = tool_call.at("arguments"); + } + } + + if (!builder.add_tool_call(function_name, "", arguments)) { + throw common_chat_msg_partial_exception("Incomplete tool call"); + } + } + } else { + throw common_chat_msg_partial_exception("Expected JSON array for tool calls"); + } + + // Consume any trailing whitespace after this tool call + builder.consume_spaces(); + } + + // Consume any remaining content after all tool calls + auto remaining = builder.consume_rest(); + if (!string_strip(remaining).empty()) { + builder.add_content(remaining); + } +} + +static void common_chat_parse_seed_oss(common_chat_msg_parser & builder) { + static const xml_tool_call_format form { + /* form.scope_start = */ "", + /* form.tool_start = */ "", + /* form.key_start = */ "", + /* form.val_end = */ "", + /* form.tool_end = */ "", + /* form.scope_end = */ "", + }; + builder.consume_reasoning_with_xml_tool_calls(form, "", ""); +} + +static void common_chat_parse_content_only(common_chat_msg_parser & builder) { + builder.try_parse_reasoning("", ""); + builder.add_content(builder.consume_rest()); +} + +static void common_chat_parse(common_chat_msg_parser & builder) { + LOG_DBG("Parsing input with format %s: %s\n", common_chat_format_name(builder.syntax().format), builder.input().c_str()); + + switch (builder.syntax().format) { + case COMMON_CHAT_FORMAT_CONTENT_ONLY: + common_chat_parse_content_only(builder); + break; + case COMMON_CHAT_FORMAT_GENERIC: + common_chat_parse_generic(builder); + break; + case COMMON_CHAT_FORMAT_MISTRAL_NEMO: + common_chat_parse_mistral_nemo(builder); + break; + case COMMON_CHAT_FORMAT_MAGISTRAL: + common_chat_parse_magistral(builder); + break; + case COMMON_CHAT_FORMAT_LLAMA_3_X: + common_chat_parse_llama_3_1(builder); + break; + case COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS: + common_chat_parse_llama_3_1(builder, /* with_builtin_tools= */ true); + break; + case COMMON_CHAT_FORMAT_DEEPSEEK_R1: + common_chat_parse_deepseek_r1(builder); + break; + case COMMON_CHAT_FORMAT_DEEPSEEK_V3_1: + common_chat_parse_deepseek_v3_1(builder); + break; + case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2: + common_chat_parse_functionary_v3_2(builder); + break; + case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1: + common_chat_parse_functionary_v3_1_llama_3_1(builder); + break; + case COMMON_CHAT_FORMAT_HERMES_2_PRO: + common_chat_parse_hermes_2_pro(builder); + break; + case COMMON_CHAT_FORMAT_FIREFUNCTION_V2: + common_chat_parse_firefunction_v2(builder); + break; + case COMMON_CHAT_FORMAT_COMMAND_R7B: + common_chat_parse_command_r7b(builder); + break; + case COMMON_CHAT_FORMAT_GRANITE: + common_chat_parse_granite(builder); + break; + case COMMON_CHAT_FORMAT_GPT_OSS: + common_chat_parse_gpt_oss(builder); + break; + case COMMON_CHAT_FORMAT_SEED_OSS: + common_chat_parse_seed_oss(builder); + break; + case COMMON_CHAT_FORMAT_NEMOTRON_V2: + common_chat_parse_nemotron_v2(builder); + break; + case COMMON_CHAT_FORMAT_APERTUS: + common_chat_parse_apertus(builder); + break; + case COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS: + common_chat_parse_lfm2(builder); + break; + case COMMON_CHAT_FORMAT_MINIMAX_M2: + common_chat_parse_minimax_m2(builder); + break; + case COMMON_CHAT_FORMAT_GLM_4_5: + common_chat_parse_glm_4_5(builder); + break; + case COMMON_CHAT_FORMAT_KIMI_K2: + common_chat_parse_kimi_k2(builder); + break; + case COMMON_CHAT_FORMAT_QWEN3_CODER_XML: + common_chat_parse_qwen3_coder_xml(builder); + break; + case COMMON_CHAT_FORMAT_APRIEL_1_5: + common_chat_parse_apriel_1_5(builder); + break; + case COMMON_CHAT_FORMAT_XIAOMI_MIMO: + common_chat_parse_xiaomi_mimo(builder); + break; + default: + throw std::runtime_error(std::string("Unsupported format: ") + common_chat_format_name(builder.syntax().format)); + } + builder.finish(); +} + +common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_syntax & syntax) { + common_chat_msg_parser builder(input, is_partial, syntax); + try { + common_chat_parse(builder); + } catch (const common_chat_msg_partial_exception & ex) { + LOG_DBG("Partial parse: %s\n", ex.what()); + if (!is_partial) { + builder.clear_tools(); + builder.move_to(0); + common_chat_parse_content_only(builder); + } + } + auto msg = builder.result(); + if (!is_partial) { + LOG_DBG("Parsed message: %s\n", common_chat_msgs_to_json_oaicompat({msg}).at(0).dump().c_str()); + } + return msg; +} diff --git a/common/chat.cpp b/common/chat.cpp index 6fa05a6041..b4a0f985e2 100644 --- a/common/chat.cpp +++ b/common/chat.cpp @@ -678,114 +678,6 @@ common_reasoning_format common_reasoning_format_from_name(const std::string & fo throw std::runtime_error("Unknown reasoning format: " + format); } -static std::string wrap_code_as_arguments(common_chat_msg_parser & builder, const std::string & code) { - std::string arguments; - if (builder.is_partial()) { - arguments = (json {{"code", code + builder.healing_marker()}}).dump(); - auto idx = arguments.find(builder.healing_marker()); - if (idx != std::string::npos) { - arguments.resize(idx); - } - } else { - arguments = (json {{"code", code}}).dump(); - } - return arguments; -} - -/** - * Takes a prefix regex that must have 1 group to capture the function name, a closing suffix, and expects json parameters in between. - * Aggregates the prefix, suffix and in-between text into the content. - */ -static void parse_json_tool_calls( - common_chat_msg_parser & builder, - const std::optional & block_open, - const std::optional & function_regex_start_only, - const std::optional & function_regex, - const common_regex & close_regex, - const std::optional & block_close, - bool allow_raw_python = false, - const std::function & get_function_name = nullptr) { - - auto parse_tool_calls = [&]() { - size_t from = std::string::npos; - auto first = true; - while (true) { - auto start_pos = builder.pos(); - auto res = function_regex_start_only && first - ? builder.try_consume_regex(*function_regex_start_only) - : function_regex - ? builder.try_find_regex(*function_regex, from) - : std::nullopt; - - if (res) { - std::string name; - if (get_function_name) { - name = get_function_name(*res); - } else { - GGML_ASSERT(res->groups.size() == 2); - name = builder.str(res->groups[1]); - } - first = false; - if (name.empty()) { - // get_function_name signalled us that we should skip this match and treat it as content. - from = res->groups[0].begin + 1; - continue; - } - from = std::string::npos; - - auto maybe_raw_python = name == "python" && allow_raw_python; - if (builder.input()[builder.pos()] == '{' || !maybe_raw_python) { - if (auto arguments = builder.try_consume_json_with_dumped_args({{}})) { - if (!builder.add_tool_call(name, "", arguments->value) || arguments->is_partial) { - throw common_chat_msg_partial_exception("incomplete tool call"); - } - builder.consume_regex(close_regex); - } - continue; - } - if (maybe_raw_python) { - auto arguments = wrap_code_as_arguments(builder, builder.consume_rest()); - if (!builder.add_tool_call(name, "", arguments)) { - throw common_chat_msg_partial_exception("incomplete tool call"); - } - return; - } - throw common_chat_msg_partial_exception("incomplete tool call"); - } else { - builder.move_to(start_pos); - } - break; - } - if (block_close) { - builder.consume_regex(*block_close); - } - builder.consume_spaces(); - builder.add_content(builder.consume_rest()); - }; - if (block_open) { - if (auto res = builder.try_find_regex(*block_open)) { - parse_tool_calls(); - } else { - builder.add_content(builder.consume_rest()); - } - } else { - parse_tool_calls(); - } -} - -static void parse_prefixed_json_tool_call_array(common_chat_msg_parser & builder, const common_regex & prefix, size_t rstrip_prefix = 0) { - static const std::vector> args_paths = {{"arguments"}}; - if (auto res = builder.try_find_regex(prefix)) { - builder.move_back(rstrip_prefix); - auto tool_calls = builder.consume_json_with_dumped_args(args_paths); - if (!builder.add_tool_calls(tool_calls.value) || tool_calls.is_partial) { - throw common_chat_msg_partial_exception("incomplete tool call array"); - } - } else { - builder.add_content(builder.consume_rest()); - } -} - static void foreach_function(const json & tools, const std::function & fn) { for (const auto & tool : tools) { if (!tool.contains("type") || tool.at("type") != "function" || !tool.contains("function")) { @@ -918,37 +810,6 @@ static common_chat_params common_chat_params_init_generic(const common_chat_temp data.format = COMMON_CHAT_FORMAT_GENERIC; return data; } -static void common_chat_parse_generic(common_chat_msg_parser & builder) { - if (!builder.syntax().parse_tool_calls) { - builder.add_content(builder.consume_rest()); - return; - } - static const std::vector> content_paths = { - {"response"}, - }; - static const std::vector> args_paths = { - {"tool_call", "arguments"}, - {"tool_calls", "arguments"}, - }; - auto data = builder.consume_json_with_dumped_args(args_paths, content_paths); - if (data.value.contains("tool_calls")) { - if (!builder.add_tool_calls(data.value.at("tool_calls")) || data.is_partial) { - throw common_chat_msg_partial_exception("incomplete tool calls"); - } - } else if (data.value.contains("tool_call")) { - if (!builder.add_tool_call(data.value.at("tool_call")) || data.is_partial) { - throw common_chat_msg_partial_exception("incomplete tool call"); - } - } else if (data.value.contains("response")) { - const auto & response = data.value.at("response"); - builder.add_content(response.is_string() ? response.template get() : response.dump(2)); - if (data.is_partial) { - throw common_chat_msg_partial_exception("incomplete response"); - } - } else { - throw common_chat_msg_partial_exception("Expected 'tool_call', 'tool_calls' or 'response' in JSON"); - } -} static common_chat_params common_chat_params_init_mistral_nemo(const common_chat_template & tmpl, const struct templates_params & inputs) { common_chat_params data; @@ -1173,28 +1034,6 @@ static common_chat_params common_chat_params_init_magistral(const common_chat_te return data; } -static void common_chat_parse_mistral_nemo(common_chat_msg_parser & builder) { - if (!builder.syntax().parse_tool_calls) { - builder.add_content(builder.consume_rest()); - return; - } - - static const common_regex prefix(regex_escape("[TOOL_CALLS]")); - parse_prefixed_json_tool_call_array(builder, prefix); -} - -static void common_chat_parse_magistral(common_chat_msg_parser & builder) { - builder.try_parse_reasoning("[THINK]", "[/THINK]"); - - if (!builder.syntax().parse_tool_calls) { - builder.add_content(builder.consume_rest()); - return; - } - - static const common_regex prefix(regex_escape("[TOOL_CALLS]")); - parse_prefixed_json_tool_call_array(builder, prefix); -} - static common_chat_params common_chat_params_init_command_r7b(const common_chat_template & tmpl, const struct templates_params & inputs) { common_chat_params data; @@ -1275,39 +1114,6 @@ static common_chat_params common_chat_params_init_command_r7b(const common_chat_ return data; } -static void common_chat_parse_command_r7b(common_chat_msg_parser & builder) { - builder.try_parse_reasoning("<|START_THINKING|>", "<|END_THINKING|>"); - - static const common_regex start_action_regex("<\\|START_ACTION\\|>"); - static const common_regex end_action_regex("<\\|END_ACTION\\|>"); - static const common_regex start_response_regex("<\\|START_RESPONSE\\|>"); - static const common_regex end_response_regex("<\\|END_RESPONSE\\|>"); - - if (auto res = builder.try_find_regex(start_action_regex)) { - // If we didn't extract thoughts, prelude includes them. - auto tool_calls = builder.consume_json_with_dumped_args({{"parameters"}}); - for (const auto & tool_call : tool_calls.value) { - std::string name = tool_call.contains("tool_name") ? tool_call.at("tool_name") : ""; - std::string id = tool_call.contains("tool_call_id") ? tool_call.at("tool_call_id") : ""; - std::string arguments = tool_call.contains("parameters") ? tool_call.at("parameters") : ""; - if (!builder.add_tool_call(name, id, arguments) || tool_calls.is_partial) { - throw common_chat_msg_partial_exception("incomplete tool call"); - } - } - if (tool_calls.is_partial) { - throw common_chat_msg_partial_exception("incomplete tool call"); - } - builder.consume_regex(end_action_regex); - } else if (auto res = builder.try_find_regex(start_response_regex)) { - if (!builder.try_find_regex(end_response_regex)) { - builder.add_content(builder.consume_rest()); - throw common_chat_msg_partial_exception(end_response_regex.str()); - } - } else { - builder.add_content(builder.consume_rest()); - } -} - static void expect_tool_parameters(const std::string & name, const json & parameters, const std::vector & expected_properties) { if (!parameters.is_object() || !parameters.contains("type") || parameters.at("type") != "object" || !parameters.contains("properties") || !parameters.contains("required")) { throw std::runtime_error("Parameters of tool " + name + " must be an object w/ required properties"); @@ -1536,63 +1342,6 @@ static common_chat_params common_chat_params_init_apertus(const common_chat_temp } return data; } -static void common_chat_parse_llama_3_1(common_chat_msg_parser & builder, bool with_builtin_tools = false) { - builder.try_parse_reasoning("", ""); - - if (!builder.syntax().parse_tool_calls) { - builder.add_content(builder.consume_rest()); - return; - } - - static const common_regex function_regex( - "\\s*\\{\\s*(?:\"type\"\\s*:\\s*\"function\"\\s*,\\s*)?\"name\"\\s*:\\s*\"([^\"]+)\"\\s*,\\s*\"parameters\"\\s*: "); - static const common_regex close_regex("\\}\\s*"); - - static const common_regex function_name_regex("\\s*(\\w+)\\s*\\.\\s*call\\("); - static const common_regex arg_name_regex("\\s*(\\w+)\\s*=\\s*"); - - if (with_builtin_tools) { - static const common_regex builtin_call_regex("<\\|python_tag\\|>"); - if (auto res = builder.try_find_regex(builtin_call_regex)) { - auto fun_res = builder.consume_regex(function_name_regex); - auto function_name = builder.str(fun_res.groups[1]); - - common_healing_marker healing_marker; - json args = json::object(); - while (true) { - if (auto arg_res = builder.try_consume_regex(arg_name_regex)) { - auto arg_name = builder.str(arg_res->groups[1]); - auto partial = builder.consume_json(); - args[arg_name] = partial.json; - healing_marker.marker = partial.healing_marker.marker; - healing_marker.json_dump_marker = partial.healing_marker.json_dump_marker; - builder.consume_spaces(); - if (!builder.try_consume_literal(",")) { - break; - } - } else { - break; - } - } - builder.consume_literal(")"); - builder.consume_spaces(); - - auto arguments = args.dump(); - if (!builder.add_tool_call(function_name, "", arguments)) { - throw common_chat_msg_partial_exception("Incomplete tool call"); - } - return; - } - } - parse_json_tool_calls( - builder, - /* block_open= */ std::nullopt, - /* function_regex_start_only= */ function_regex, - /* function_regex= */ std::nullopt, - close_regex, - std::nullopt); - -} static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_template & tmpl, const struct templates_params & inputs) { common_chat_params data; @@ -1732,88 +1481,6 @@ static common_chat_params common_chat_params_init_deepseek_v3_1(const common_cha return data; } -static void common_chat_parse_deepseek_r1(common_chat_msg_parser & builder) { - builder.try_parse_reasoning("", ""); - if (!builder.syntax().parse_tool_calls) { - builder.add_content(builder.consume_rest()); - return; - } - - static const common_regex tool_calls_begin("(?:<|tool▁calls▁begin|>|<|tool_calls_begin|>|<|tool calls begin|>|<|tool\\\\_calls\\\\_begin|>|<|tool▁calls|>)"); - static const common_regex tool_calls_end("<|tool▁calls▁end|>"); - static const common_regex function_regex("(?:<|tool▁call▁begin|>)?function<|tool▁sep|>([^\n]+)\n```json\n"); - static const common_regex close_regex("```[\\s\\r\\n]*<|tool▁call▁end|>"); - - parse_json_tool_calls( - builder, - /* block_open= */ tool_calls_begin, - /* function_regex_start_only= */ std::nullopt, - function_regex, - close_regex, - tool_calls_end); -} - -static void common_chat_parse_deepseek_v3_1_content(common_chat_msg_parser & builder) { - static const common_regex function_regex("(?:<|tool▁call▁begin|>)?([^\\n<]+)(?:<|tool▁sep|>)"); - - static const common_regex close_regex("(?:[\\s]*)?<|tool▁call▁end|>"); - static const common_regex tool_calls_begin("(?:<|tool▁calls▁begin|>|<|tool_calls_begin|>|<|tool calls begin|>|<|tool\\\\_calls\\\\_begin|>|<|tool▁calls|>)"); - static const common_regex tool_calls_end("<|tool▁calls▁end|>"); - - if (!builder.syntax().parse_tool_calls) { - LOG_DBG("%s: not parse_tool_calls\n", __func__); - builder.add_content(builder.consume_rest()); - return; - } - - LOG_DBG("%s: parse_tool_calls\n", __func__); - - parse_json_tool_calls( - builder, - /* block_open= */ tool_calls_begin, - /* function_regex_start_only= */ std::nullopt, - function_regex, - close_regex, - tool_calls_end); -} - -static void common_chat_parse_deepseek_v3_1(common_chat_msg_parser & builder) { - // DeepSeek V3.1 outputs reasoning content between "" and "" tags, followed by regular content - // First try to parse using the standard reasoning parsing method - LOG_DBG("%s: thinking_forced_open: %s\n", __func__, std::to_string(builder.syntax().thinking_forced_open).c_str()); - - auto start_pos = builder.pos(); - auto found_end_think = builder.try_find_literal(""); - builder.move_to(start_pos); - - if (builder.syntax().thinking_forced_open && !builder.is_partial() && !found_end_think) { - LOG_DBG("%s: no end_think, not partial, adding content\n", __func__); - common_chat_parse_deepseek_v3_1_content(builder); - } else if (builder.try_parse_reasoning("", "")) { - // If reasoning was parsed successfully, the remaining content is regular content - LOG_DBG("%s: parsed reasoning, adding content\n", __func__); - // <|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>NAME\n```json\nJSON\n```<|tool▁call▁end|><|tool▁calls▁end|> - common_chat_parse_deepseek_v3_1_content(builder); - } else { - if (builder.syntax().reasoning_format == COMMON_REASONING_FORMAT_NONE) { - LOG_DBG("%s: reasoning_format none, adding content\n", __func__); - common_chat_parse_deepseek_v3_1_content(builder); - return; - } - // If no reasoning tags found, check if we should treat everything as reasoning - if (builder.syntax().thinking_forced_open) { - // If thinking is forced open but no tags found, treat everything as reasoning - LOG_DBG("%s: thinking_forced_open, adding reasoning content\n", __func__); - builder.add_reasoning_content(builder.consume_rest()); - } else { - LOG_DBG("%s: no thinking_forced_open, adding content\n", __func__); - // <|tool▁call▁begin|>NAME<|tool▁sep|>JSON<|tool▁call▁end|> - common_chat_parse_deepseek_v3_1_content(builder); - } - } -} - - static common_chat_params common_chat_params_init_minimax_m2(const common_chat_template & tmpl, const struct templates_params & params) { common_chat_params data; data.grammar_lazy = params.tools.is_array() && !params.tools.empty() && params.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED; @@ -1856,20 +1523,6 @@ static common_chat_params common_chat_params_init_minimax_m2(const common_chat_t return data; } -static void common_chat_parse_minimax_m2(common_chat_msg_parser & builder) { - static const xml_tool_call_format form { - /* form.scope_start = */ "", - /* form.tool_start = */ "", - /* form.key_start = */ "", - /* form.val_end = */ "", - /* form.tool_end = */ "", - /* form.scope_end = */ "", - }; - builder.consume_reasoning_with_xml_tool_calls(form, "", ""); -} - static common_chat_params common_chat_params_init_qwen3_coder_xml(const common_chat_template & tmpl, const struct templates_params & params) { common_chat_params data; data.grammar_lazy = params.tools.is_array() && !params.tools.empty() && params.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED; @@ -1902,23 +1555,6 @@ static common_chat_params common_chat_params_init_qwen3_coder_xml(const common_c return data; } -static void common_chat_parse_qwen3_coder_xml(common_chat_msg_parser & builder) { - static const xml_tool_call_format form = ([]() { - xml_tool_call_format form {}; - form.scope_start = ""; - form.tool_start = "", ""); -} - static common_chat_params common_chat_params_init_apriel_1_5(const common_chat_template & tmpl, const struct templates_params & params) { common_chat_params data; data.grammar_lazy = params.tools.is_array() && !params.tools.empty() && params.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED; @@ -2016,25 +1634,6 @@ static common_chat_params common_chat_params_init_apriel_1_5(const common_chat_t return data; } -static void common_chat_parse_apriel_1_5(common_chat_msg_parser & builder) { - static const xml_tool_call_format form = ([]() { - xml_tool_call_format form {}; - form.scope_start = "["; - form.tool_start = "{\"name\": \""; - form.tool_sep = "\", \"arguments\": {"; - form.key_start = "\""; - form.key_val_sep = "\": "; - form.val_end = ", "; - form.tool_end = "}, "; - form.scope_end = "]"; - form.raw_argval = false; - form.last_val_end = ""; - form.last_tool_end = "}"; - return form; - })(); - builder.consume_reasoning_with_xml_tool_calls(form, "", ""); -} - static common_chat_params common_chat_params_init_xiaomi_mimo(const common_chat_template & tmpl, const struct templates_params & params) { common_chat_params data; data.grammar_lazy = params.tools.is_array() && !params.tools.empty() && params.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED; @@ -2067,24 +1666,6 @@ static common_chat_params common_chat_params_init_xiaomi_mimo(const common_chat_ return data; } -static void common_chat_parse_xiaomi_mimo(common_chat_msg_parser & builder) { - static const xml_tool_call_format form = ([]() { - xml_tool_call_format form {}; - form.scope_start = ""; - form.tool_start = "\n{\"name\": \""; - form.tool_sep = "\", \"arguments\": {"; - form.key_start = "\""; - form.key_val_sep = "\": "; - form.val_end = ", "; - form.tool_end = "}\n"; - form.scope_end = ""; - form.raw_argval = false; - form.last_val_end = ""; - return form; - })(); - builder.consume_reasoning_with_xml_tool_calls(form); -} - static common_chat_params common_chat_params_init_gpt_oss(const common_chat_template & tmpl, const struct templates_params & inputs) { common_chat_params data; @@ -2231,93 +1812,6 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp return data; } -static void common_chat_parse_gpt_oss(common_chat_msg_parser & builder) { - static const std::string constraint = "(?: (<\\|constrain\\|>)?([a-zA-Z0-9_-]+))"; - static const std::string recipient("(?: to=functions\\.([^<\\s]+))"); - - static const common_regex start_regex("<\\|start\\|>assistant"); - static const common_regex analysis_regex("<\\|channel\\|>analysis"); - static const common_regex final_regex("<\\|channel\\|>final" + constraint + "?"); - static const common_regex preamble_regex("<\\|channel\\|>commentary"); - static const common_regex tool_call1_regex(recipient + "<\\|channel\\|>(analysis|commentary)" + constraint + "?"); - static const common_regex tool_call2_regex("<\\|channel\\|>(analysis|commentary)" + recipient + constraint + "?"); - - auto consume_end = [&](bool include_end = false) { - if (auto res = builder.try_find_literal("<|end|>")) { - return res->prelude + (include_end ? builder.str(res->groups[0]) : ""); - } - return builder.consume_rest(); - }; - - auto handle_tool_call = [&](const std::string & name) { - if (auto args = builder.try_consume_json_with_dumped_args({{}})) { - if (builder.syntax().parse_tool_calls) { - if (!builder.add_tool_call(name, "", args->value) || args->is_partial) { - throw common_chat_msg_partial_exception("incomplete tool call"); - } - } else if (args->is_partial) { - throw common_chat_msg_partial_exception("incomplete tool call"); - } - } - }; - - auto regex_match = [](const common_regex & regex, const std::string & input) -> std::optional { - auto match = regex.search(input, 0, true); - if (match.type == COMMON_REGEX_MATCH_TYPE_FULL) { - return match; - } - return std::nullopt; - }; - - do { - auto header_start_pos = builder.pos(); - auto content_start = builder.try_find_literal("<|message|>"); - if (!content_start) { - throw common_chat_msg_partial_exception("incomplete header"); - } - - auto header = content_start->prelude; - - if (auto match = regex_match(tool_call1_regex, header)) { - auto group = match->groups[1]; - auto name = header.substr(group.begin, group.end - group.begin); - handle_tool_call(name); - continue; - } - - if (auto match = regex_match(tool_call2_regex, header)) { - auto group = match->groups[2]; - auto name = header.substr(group.begin, group.end - group.begin); - handle_tool_call(name); - continue; - } - - if (regex_match(analysis_regex, header)) { - builder.move_to(header_start_pos); - if (builder.syntax().reasoning_format == COMMON_REASONING_FORMAT_NONE || builder.syntax().reasoning_in_content) { - builder.add_content(consume_end(true)); - } else { - builder.try_parse_reasoning("<|channel|>analysis<|message|>", "<|end|>"); - } - continue; - } - - if(regex_match(final_regex, header) || regex_match(preamble_regex, header)) { - builder.add_content(consume_end()); - continue; - } - - // Possibly a malformed message, attempt to recover by rolling - // back to pick up the next <|start|> - LOG_DBG("%s: unknown header from message: %s\n", __func__, header.c_str()); - builder.move_to(header_start_pos); - } while (builder.try_find_regex(start_regex, std::string::npos, false)); - - auto remaining = builder.consume_rest(); - if (!remaining.empty()) { - LOG_DBG("%s: content after last message: %s\n", __func__, remaining.c_str()); - } -} static common_chat_params common_chat_params_init_glm_4_5(const common_chat_template & tmpl, const struct templates_params & inputs) { common_chat_params data; @@ -2398,21 +1892,6 @@ static common_chat_params common_chat_params_init_glm_4_5(const common_chat_temp return data; } -static void common_chat_parse_glm_4_5(common_chat_msg_parser & builder) { - static const xml_tool_call_format form { - /* form.scope_start = */ "", - /* form.tool_start = */ "", - /* form.tool_sep = */ "", - /* form.key_start = */ "", - /* form.key_val_sep = */ "", - /* form.val_end = */ "", - /* form.tool_end = */ "", - /* form.scope_end = */ "", - /* form.key_val_sep2 = */ "", - }; - builder.consume_reasoning_with_xml_tool_calls(form, "", ""); -} - static common_chat_params common_chat_params_init_firefunction_v2(const common_chat_template & tmpl, const struct templates_params & inputs) { LOG_DBG("%s\n", __func__); common_chat_params data; @@ -2460,14 +1939,6 @@ static common_chat_params common_chat_params_init_firefunction_v2(const common_c } return data; } -static void common_chat_parse_firefunction_v2(common_chat_msg_parser & builder) { - if (!builder.syntax().parse_tool_calls) { - builder.add_content(builder.consume_rest()); - return; - } - static const common_regex prefix(regex_escape(" functools[")); - parse_prefixed_json_tool_call_array(builder, prefix, /* rstrip_prefix= */ 1); -} static common_chat_params common_chat_params_init_functionary_v3_2(const common_chat_template & tmpl, const struct templates_params & inputs) { // >>>all\nlet's call functions>>>fn1\n{"arg1": 1...}\n>>>fn2\n{"arg1": 1...}... @@ -2518,34 +1989,6 @@ static common_chat_params common_chat_params_init_functionary_v3_2(const common_ } return data; } -static void common_chat_parse_functionary_v3_2(common_chat_msg_parser & builder) { - static const common_regex function_regex_start_only(R"((\w+\n\{|python\n|all\n))"); - static const common_regex function_regex(R"(>>>(\w+\n\{|python\n|all\n))"); - static const common_regex close_regex(R"(\s*)"); - - parse_json_tool_calls( - builder, - std::nullopt, - function_regex_start_only, - function_regex, - close_regex, - std::nullopt, - /* allow_raw_python= */ true, - /* get_function_name= */ [&](const auto & res) -> std::string { - auto at_start = res.groups[0].begin == 0; - auto name = builder.str(res.groups[1]); - if (!name.empty() && name.back() == '{') { - // Unconsume the opening brace '{' to ensure the JSON parsing goes well. - builder.move_back(1); - } - auto idx = name.find_last_not_of("\n{"); - name = name.substr(0, idx + 1); - if (at_start && name == "all") { - return ""; - } - return name; - }); -} static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(const common_chat_template & tmpl, const struct templates_params & inputs) { // https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3-llama3.1.txt @@ -2605,31 +2048,6 @@ static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(con // TODO: if (has_raw_python) return data; } -static void common_chat_parse_functionary_v3_1_llama_3_1(common_chat_msg_parser & builder) { - if (!builder.syntax().parse_tool_calls) { - builder.add_content(builder.consume_rest()); - return; - } - // This version of Functionary still supports the llama 3.1 tool call format for the python tool. - static const common_regex python_tag_regex(regex_escape("<|python_tag|>")); - - static const common_regex function_regex(R"()"); - static const common_regex close_regex(R"()"); - - parse_json_tool_calls( - builder, - /* block_open= */ std::nullopt, - /* function_regex_start_only= */ std::nullopt, - function_regex, - close_regex, - std::nullopt); - - if (auto res = builder.try_find_regex(python_tag_regex)) { - auto arguments = wrap_code_as_arguments(builder, builder.consume_rest()); - builder.add_tool_call("python", "", arguments); - return; - } -} static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat_template & tmpl, const struct templates_params & inputs) { common_chat_params data; @@ -2746,83 +2164,6 @@ static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat return data; } -static void common_chat_parse_hermes_2_pro(common_chat_msg_parser & builder) { - builder.try_parse_reasoning("", ""); - if (!builder.syntax().parse_tool_calls) { - builder.add_content(builder.consume_rest()); - return; - } - - static const common_regex open_regex( - "(?:" - "(```(?:xml|json)?\\n\\s*)?" // match 1 (block_start) - "(" // match 2 (open_tag) - "" - "|" - "|" - "|" - "|" - "|" - "|" - "|" - ")?" - "(\\s*\\{\\s*\"name\")" // match 3 (named tool call) - ")" - "|]+)>" // match 4 (function name) - "|" // match 5 (function name again) - ); - - while (auto res = builder.try_find_regex(open_regex)) { - const auto & block_start = res->groups[1]; - std::string block_end = block_start.empty() ? "" : "```"; - - const auto & open_tag = res->groups[2]; - std::string close_tag; - - if (!res->groups[3].empty()) { - builder.move_to(res->groups[3].begin); - close_tag = open_tag.empty() ? "" : "value) || tool_call->is_partial) { - throw common_chat_msg_partial_exception("incomplete tool call"); - } - builder.consume_spaces(); - builder.consume_literal(close_tag); - builder.consume_spaces(); - if (!block_end.empty()) { - builder.consume_literal(block_end); - builder.consume_spaces(); - } - } else { - throw common_chat_msg_partial_exception("failed to parse tool call"); - } - } else { - auto function_name = builder.str(res->groups[4]); - if (function_name.empty()) { - function_name = builder.str(res->groups[5]); - } - GGML_ASSERT(!function_name.empty()); - - close_tag = ""; - - if (auto arguments = builder.try_consume_json_with_dumped_args({{}})) { - if (!builder.add_tool_call(function_name, "", arguments->value) || arguments->is_partial) { - throw common_chat_msg_partial_exception("incomplete tool call"); - } - builder.consume_spaces(); - builder.consume_literal(close_tag); - builder.consume_spaces(); - if (!block_end.empty()) { - builder.consume_literal(block_end); - builder.consume_spaces(); - } - } - } - } - - builder.add_content(builder.consume_rest()); -} static common_chat_params common_chat_params_init_granite(const common_chat_template & tmpl, const struct templates_params & inputs) { common_chat_params data; @@ -2905,190 +2246,6 @@ static common_chat_params common_chat_params_init_granite(const common_chat_temp return data; } -static void common_chat_parse_granite(common_chat_msg_parser & builder) { - // Parse thinking tags - static const common_regex start_think_regex(regex_escape("")); - static const common_regex end_think_regex(regex_escape("")); - // Granite models output partial tokens such as "<" and "groups[0].begin); - builder.try_find_regex(end_think_regex, std::string::npos, false); - // Restore position for try_parse_reasoning() - builder.move_to(res->groups[0].begin); - } - builder.try_parse_reasoning("", ""); - - // Parse response tags - static const common_regex start_response_regex(regex_escape("")); - static const common_regex end_response_regex(regex_escape("")); - // Granite models output partial tokens such as "<" and "")); - if (auto res = builder.try_find_regex(tool_call_regex)) { - builder.move_to(res->groups[0].end); - - // Expect JSON array of tool calls - if (auto tool_call = builder.try_consume_json_with_dumped_args({{{"arguments"}}})) { - if (!builder.add_tool_calls(tool_call->value) || tool_call->is_partial) { - throw common_chat_msg_partial_exception("incomplete tool call"); - } - } - } else { - builder.add_content(builder.consume_rest()); - } -} - -static void common_chat_parse_nemotron_v2(common_chat_msg_parser & builder) { - // Parse thinking tags - builder.try_parse_reasoning("", ""); - if (!builder.syntax().parse_tool_calls) { - builder.add_content(builder.consume_rest()); - return; - } - - // Look for tool calls - static const common_regex tool_call_regex(regex_escape("")); - if (auto res = builder.try_find_regex(tool_call_regex)) { - builder.move_to(res->groups[0].end); - - // Expect JSON array of tool calls - auto tool_calls_data = builder.consume_json(); - if (tool_calls_data.json.is_array()) { - if (!builder.try_consume_literal("")) { - throw common_chat_msg_partial_exception("Incomplete tool call"); - } - builder.add_tool_calls(tool_calls_data.json); - } else { - throw common_chat_msg_partial_exception("Incomplete tool call"); - } - } - builder.add_content(builder.consume_rest()); -} - -static void common_chat_parse_apertus(common_chat_msg_parser & builder) { - // Parse thinking tags - builder.try_parse_reasoning("<|inner_prefix|>", "<|inner_suffix|>"); - if (!builder.syntax().parse_tool_calls) { - builder.add_content(builder.consume_rest()); - return; - } - - // Look for tool calls - static const common_regex tool_call_regex(regex_escape("<|tools_prefix|>")); - if (auto res = builder.try_find_regex(tool_call_regex)) { - builder.move_to(res->groups[0].end); - - auto tool_calls_data = builder.consume_json(); - if (tool_calls_data.json.is_array()) { - builder.consume_spaces(); - if (!builder.try_consume_literal("<|tools_suffix|>")) { - throw common_chat_msg_partial_exception("Incomplete tool call"); - } - for (const auto & value : tool_calls_data.json) { - if (value.is_object()) { - builder.add_tool_call_short_form(value); - } - } - } else { - throw common_chat_msg_partial_exception("Incomplete tool call"); - } - } - builder.add_content(builder.consume_rest()); -} - - -static void common_chat_parse_lfm2(common_chat_msg_parser & builder) { - if (!builder.syntax().parse_tool_calls) { - builder.add_content(builder.consume_rest()); - return; - } - - // LFM2 format: <|tool_call_start|>[{"name": "get_current_time", "arguments": {"location": "Paris"}}]<|tool_call_end|> - static const common_regex tool_call_start_regex(regex_escape("<|tool_call_start|>")); - static const common_regex tool_call_end_regex(regex_escape("<|tool_call_end|>")); - - // Loop through all tool calls - while (auto res = builder.try_find_regex(tool_call_start_regex, std::string::npos, /* add_prelude_to_content= */ true)) { - builder.move_to(res->groups[0].end); - - // Parse JSON array format: [{"name": "...", "arguments": {...}}] - auto tool_calls_data = builder.consume_json(); - - // Consume end marker - builder.consume_spaces(); - if (!builder.try_consume_regex(tool_call_end_regex)) { - throw common_chat_msg_partial_exception("Expected <|tool_call_end|>"); - } - - // Process each tool call in the array - if (tool_calls_data.json.is_array()) { - for (const auto & tool_call : tool_calls_data.json) { - if (!tool_call.is_object()) { - throw common_chat_msg_partial_exception("Tool call must be an object"); - } - - if (!tool_call.contains("name")) { - throw common_chat_msg_partial_exception("Tool call missing 'name' field"); - } - - std::string function_name = tool_call.at("name"); - std::string arguments = "{}"; - - if (tool_call.contains("arguments")) { - if (tool_call.at("arguments").is_object()) { - arguments = tool_call.at("arguments").dump(); - } else if (tool_call.at("arguments").is_string()) { - arguments = tool_call.at("arguments"); - } - } - - if (!builder.add_tool_call(function_name, "", arguments)) { - throw common_chat_msg_partial_exception("Incomplete tool call"); - } - } - } else { - throw common_chat_msg_partial_exception("Expected JSON array for tool calls"); - } - - // Consume any trailing whitespace after this tool call - builder.consume_spaces(); - } - - // Consume any remaining content after all tool calls - auto remaining = builder.consume_rest(); - if (!string_strip(remaining).empty()) { - builder.add_content(remaining); - } -} - -static void common_chat_parse_seed_oss(common_chat_msg_parser & builder) { - static const xml_tool_call_format form { - /* form.scope_start = */ "", - /* form.tool_start = */ "", - /* form.key_start = */ "", - /* form.val_end = */ "", - /* form.tool_end = */ "", - /* form.scope_end = */ "", - }; - builder.consume_reasoning_with_xml_tool_calls(form, "", ""); -} - static common_chat_params common_chat_params_init_without_tools(const common_chat_template & tmpl, const struct templates_params & inputs) { common_chat_params data; data.prompt = apply(tmpl, inputs); @@ -3428,112 +2585,3 @@ common_chat_params common_chat_templates_apply( ? common_chat_templates_apply_jinja(tmpls, inputs) : common_chat_templates_apply_legacy(tmpls, inputs); } - -static void common_chat_parse_content_only(common_chat_msg_parser & builder) { - builder.try_parse_reasoning("", ""); - builder.add_content(builder.consume_rest()); -} - -static void common_chat_parse(common_chat_msg_parser & builder) { - LOG_DBG("Parsing input with format %s: %s\n", common_chat_format_name(builder.syntax().format), builder.input().c_str()); - - switch (builder.syntax().format) { - case COMMON_CHAT_FORMAT_CONTENT_ONLY: - common_chat_parse_content_only(builder); - break; - case COMMON_CHAT_FORMAT_GENERIC: - common_chat_parse_generic(builder); - break; - case COMMON_CHAT_FORMAT_MISTRAL_NEMO: - common_chat_parse_mistral_nemo(builder); - break; - case COMMON_CHAT_FORMAT_MAGISTRAL: - common_chat_parse_magistral(builder); - break; - case COMMON_CHAT_FORMAT_LLAMA_3_X: - common_chat_parse_llama_3_1(builder); - break; - case COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS: - common_chat_parse_llama_3_1(builder, /* with_builtin_tools= */ true); - break; - case COMMON_CHAT_FORMAT_DEEPSEEK_R1: - common_chat_parse_deepseek_r1(builder); - break; - case COMMON_CHAT_FORMAT_DEEPSEEK_V3_1: - common_chat_parse_deepseek_v3_1(builder); - break; - case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2: - common_chat_parse_functionary_v3_2(builder); - break; - case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1: - common_chat_parse_functionary_v3_1_llama_3_1(builder); - break; - case COMMON_CHAT_FORMAT_HERMES_2_PRO: - common_chat_parse_hermes_2_pro(builder); - break; - case COMMON_CHAT_FORMAT_FIREFUNCTION_V2: - common_chat_parse_firefunction_v2(builder); - break; - case COMMON_CHAT_FORMAT_COMMAND_R7B: - common_chat_parse_command_r7b(builder); - break; - case COMMON_CHAT_FORMAT_GRANITE: - common_chat_parse_granite(builder); - break; - case COMMON_CHAT_FORMAT_GPT_OSS: - common_chat_parse_gpt_oss(builder); - break; - case COMMON_CHAT_FORMAT_SEED_OSS: - common_chat_parse_seed_oss(builder); - break; - case COMMON_CHAT_FORMAT_NEMOTRON_V2: - common_chat_parse_nemotron_v2(builder); - break; - case COMMON_CHAT_FORMAT_APERTUS: - common_chat_parse_apertus(builder); - break; - case COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS: - common_chat_parse_lfm2(builder); - break; - case COMMON_CHAT_FORMAT_MINIMAX_M2: - common_chat_parse_minimax_m2(builder); - break; - case COMMON_CHAT_FORMAT_GLM_4_5: - common_chat_parse_glm_4_5(builder); - break; - case COMMON_CHAT_FORMAT_KIMI_K2: - common_chat_parse_kimi_k2(builder); - break; - case COMMON_CHAT_FORMAT_QWEN3_CODER_XML: - common_chat_parse_qwen3_coder_xml(builder); - break; - case COMMON_CHAT_FORMAT_APRIEL_1_5: - common_chat_parse_apriel_1_5(builder); - break; - case COMMON_CHAT_FORMAT_XIAOMI_MIMO: - common_chat_parse_xiaomi_mimo(builder); - break; - default: - throw std::runtime_error(std::string("Unsupported format: ") + common_chat_format_name(builder.syntax().format)); - } - builder.finish(); -} - -common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_syntax & syntax) { - common_chat_msg_parser builder(input, is_partial, syntax); - try { - common_chat_parse(builder); - } catch (const common_chat_msg_partial_exception & ex) { - LOG_DBG("Partial parse: %s\n", ex.what()); - if (!is_partial) { - builder.clear_tools(); - builder.move_to(0); - common_chat_parse_content_only(builder); - } - } - auto msg = builder.result(); - if (!is_partial) { - LOG_DBG("Parsed message: %s\n", common_chat_msgs_to_json_oaicompat({msg}).at(0).dump().c_str()); - } - return msg; -} From d82b7a7c1d73c0674698d9601b1bbb0200933f29 Mon Sep 17 00:00:00 2001 From: Aleksei Nikiforov <103434461+AlekseiNikiforovIBM@users.noreply.github.com> Date: Fri, 28 Nov 2025 20:53:01 +0100 Subject: [PATCH 06/13] gguf-py : fix passing non-native endian tensors (editor-gui and new-metadata) (#17553) gguf_new_metadata.py reads data from reader. Reader doesn't byteswap tensors to native endianness. But writer does expect tensors in native endianness to convert them into requested endianness. There are two ways to fix this: update reader and do conversion to native endianness and back, or skip converting endianness in writer in this particular USE-case. gguf_editor_gui.py doesn't allow editing or viewing tensor data. Let's go with skipping excessive byteswapping. If eventually capability to view or edit tensor data is added, tensor data should be instead byteswapped when reading it. --- gguf-py/gguf/gguf_writer.py | 18 ++++++++++++------ gguf-py/gguf/scripts/gguf_editor_gui.py | 2 +- gguf-py/gguf/scripts/gguf_new_metadata.py | 2 +- 3 files changed, 14 insertions(+), 8 deletions(-) diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index 57ca2035fe..8ddd895cb7 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -371,10 +371,13 @@ class GGUFWriter: def add_tensor( self, name: str, tensor: np.ndarray[Any, Any], raw_shape: Sequence[int] | None = None, - raw_dtype: GGMLQuantizationType | None = None, + raw_dtype: GGMLQuantizationType | None = None, tensor_endianess: GGUFEndian | None = None ) -> None: - if (self.endianess == GGUFEndian.BIG and sys.byteorder != 'big') or \ - (self.endianess == GGUFEndian.LITTLE and sys.byteorder != 'little'): + # if tensor endianness is not passed, assume it's native to system + if tensor_endianess is None: + tensor_endianess = GGUFEndian.BIG if sys.byteorder == 'big' else GGUFEndian.LITTLE + + if tensor_endianess != self.endianess: # Don't byteswap inplace since lazy copies cannot handle it tensor = tensor.byteswap(inplace=False) if self.use_temp_file and self.temp_file is None: @@ -397,13 +400,16 @@ class GGUFWriter: if pad != 0: fp.write(bytes([0] * pad)) - def write_tensor_data(self, tensor: np.ndarray[Any, Any]) -> None: + def write_tensor_data(self, tensor: np.ndarray[Any, Any], tensor_endianess: GGUFEndian | None = None) -> None: if self.state is not WriterState.TI_DATA and self.state is not WriterState.WEIGHTS: raise ValueError(f'Expected output file to contain tensor info or weights, got {self.state}') assert self.fout is not None - if (self.endianess == GGUFEndian.BIG and sys.byteorder != 'big') or \ - (self.endianess == GGUFEndian.LITTLE and sys.byteorder != 'little'): + # if tensor endianness is not passed, assume it's native to system + if tensor_endianess is None: + tensor_endianess = GGUFEndian.BIG if sys.byteorder == 'big' else GGUFEndian.LITTLE + + if tensor_endianess != self.endianess: # Don't byteswap inplace since lazy copies cannot handle it tensor = tensor.byteswap(inplace=False) diff --git a/gguf-py/gguf/scripts/gguf_editor_gui.py b/gguf-py/gguf/scripts/gguf_editor_gui.py index 05f4db0f8c..293316afed 100755 --- a/gguf-py/gguf/scripts/gguf_editor_gui.py +++ b/gguf-py/gguf/scripts/gguf_editor_gui.py @@ -1552,7 +1552,7 @@ class GGUFEditorWindow(QMainWindow): # Add tensors (including data) for tensor in self.reader.tensors: - writer.add_tensor(tensor.name, tensor.data, raw_shape=tensor.data.shape, raw_dtype=tensor.tensor_type) + writer.add_tensor(tensor.name, tensor.data, raw_shape=tensor.data.shape, raw_dtype=tensor.tensor_type, tensor_endianess=self.reader.endianess) # Write header and metadata writer.open_output_file(Path(file_path)) diff --git a/gguf-py/gguf/scripts/gguf_new_metadata.py b/gguf-py/gguf/scripts/gguf_new_metadata.py index 2fa5800cf7..c67436bad4 100755 --- a/gguf-py/gguf/scripts/gguf_new_metadata.py +++ b/gguf-py/gguf/scripts/gguf_new_metadata.py @@ -94,7 +94,7 @@ def copy_with_new_metadata(reader: gguf.GGUFReader, writer: gguf.GGUFWriter, new writer.write_ti_data_to_file() for tensor in reader.tensors: - writer.write_tensor_data(tensor.data) + writer.write_tensor_data(tensor.data, tensor_endianess=reader.endianess) bar.update(tensor.n_bytes) writer.close() From 59d8d4e96341eb54f362ac3d583ef522566e2a39 Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Sat, 29 Nov 2025 01:39:57 -0600 Subject: [PATCH 07/13] vulkan: improve topk perf for large k, fix overflow in unit tests (#17582) --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 4 +++- tests/test-backend-ops.cpp | 8 +++++--- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 73562bc1be..f3aba8165b 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -10239,7 +10239,9 @@ static void ggml_vk_topk(ggml_backend_vk_context * ctx, vk_context& subctx, cons // Prefer going as small as num_topk_pipelines - 3 for perf reasons. // But if K is larger, then we need a larger workgroup - uint32_t max_pipeline = num_topk_pipelines - 3; + uint32_t max_pipeline = num_topk_pipelines - 1; + uint32_t preferred_pipeline = std::max(num_topk_pipelines - 3, (uint32_t)log2f(float(k)) + 2); + max_pipeline = std::min(preferred_pipeline, max_pipeline); uint32_t min_pipeline = (uint32_t)log2f(float(k)) + 1; // require full subgroup min_pipeline = std::max(min_pipeline, ctx->device->subgroup_size_log2); diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 60bab47b9f..87a61aa122 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1446,14 +1446,14 @@ struct test_case { const uint64_t target_flops_cpu = 8ULL * GFLOP; const uint64_t target_flops_gpu = 100ULL * GFLOP; uint64_t target_flops = is_cpu ? target_flops_cpu : target_flops_gpu; - n_runs = std::min(ggml_graph_size(gf) - ggml_graph_n_nodes(gf), target_flops / op_flops(out)) + 1; + n_runs = (int)std::min(ggml_graph_size(gf) - ggml_graph_n_nodes(gf), target_flops / op_flops(out)) + 1; } else { // based on memory size const size_t GB = 1ULL << 30; const size_t target_size_cpu = 8 * GB; const size_t target_size_gpu = 32 * GB; size_t target_size = is_cpu ? target_size_cpu : target_size_gpu; - n_runs = std::min(ggml_graph_size(gf) - ggml_graph_n_nodes(gf), target_size / op_size(out)) + 1; + n_runs = (int)std::min(ggml_graph_size(gf) - ggml_graph_n_nodes(gf), target_size / op_size(out)) + 1; } // duplicate the op @@ -8043,7 +8043,9 @@ static std::vector> make_test_cases_perf() { } test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {65000, 16, 1, 1})); - for (auto k : {1, 10, 40}) { + + test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {2, 1, 1, 1}, 1)); + for (auto k : {1, 10, 40, 400}) { for (auto nrows : {1, 16}) { for (auto cols : {k, 1000, 65000, 200000}) { test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {cols, nrows, 1, 1}, k)); From 47a268ea5000fc0f05fc1c5cd0062efebfe84b92 Mon Sep 17 00:00:00 2001 From: Ruben Ortlam Date: Sat, 29 Nov 2025 09:37:22 +0100 Subject: [PATCH 08/13] Vulkan: MMVQ Integer Dot K-Quant and MUL_MAT_ID support (#16900) * vulkan: split mul_mmq_funcs for mul_mat_vecq use * add mxfp4 mmvq * add q2_k mmvq * add q3_k mmvq * add q4_k and q5_k mmvq * add q6_k mmvq * handle 4x4 quants per mmvq thread * enable MUL_MAT_ID mmvq support * enable subgroup optimizations for mul_mat_vec_id shaders * device tuning * request prealloc_y sync after quantization * fix indentation * fix llvmpipe test failures * fix mul_mat_id mmvq condition * fix unused variable warning --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 244 ++++++++--- .../vulkan-shaders/dequant_funcs.glsl | 7 - .../vulkan-shaders/generic_binary_head.glsl | 7 + .../vulkan-shaders/generic_unary_head.glsl | 7 + .../vulkan-shaders/mul_mat_vec.comp | 1 + .../vulkan-shaders/mul_mat_vec_base.glsl | 2 - .../vulkan-shaders/mul_mat_vec_iface.glsl | 8 +- .../vulkan-shaders/mul_mat_vecq.comp | 62 ++- .../vulkan-shaders/mul_mat_vecq_funcs.glsl | 379 ++++++++++++++++++ .../ggml-vulkan/vulkan-shaders/mul_mmq.comp | 2 - .../vulkan-shaders/mul_mmq_funcs.glsl | 229 +++-------- .../vulkan-shaders/vulkan-shaders-gen.cpp | 22 +- 12 files changed, 682 insertions(+), 288 deletions(-) create mode 100644 ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index f3aba8165b..66dd0bfabd 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -613,9 +613,10 @@ struct vk_device_struct { vk_pipeline pipeline_dequant[GGML_TYPE_COUNT]; vk_pipeline pipeline_dequant_mul_mat_vec_f32_f32[DMMV_WG_SIZE_COUNT][GGML_TYPE_COUNT][mul_mat_vec_max_cols]; vk_pipeline pipeline_dequant_mul_mat_vec_f16_f32[DMMV_WG_SIZE_COUNT][GGML_TYPE_COUNT][mul_mat_vec_max_cols]; - vk_pipeline pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_COUNT]; + vk_pipeline pipeline_dequant_mul_mat_vec_id_f32[DMMV_WG_SIZE_COUNT][GGML_TYPE_COUNT]; vk_pipeline pipeline_dequant_mul_mat_vec_q8_1_f32[DMMV_WG_SIZE_COUNT][GGML_TYPE_COUNT][mul_mat_vec_max_cols]; + vk_pipeline pipeline_dequant_mul_mat_vec_id_q8_1_f32[DMMV_WG_SIZE_COUNT][GGML_TYPE_COUNT]; vk_pipeline pipeline_mul_mat_vec_p021_f16_f32[p021_max_gqa_ratio]; vk_pipeline pipeline_mul_mat_vec_nc_f16_f32; @@ -1611,7 +1612,7 @@ class vk_perf_logger { } if (node->op == GGML_OP_MUL_MAT || node->op == GGML_OP_MUL_MAT_ID) { const uint64_t m = node->src[0]->ne[1]; - const uint64_t n = node->ne[1]; + const uint64_t n = (node->op == GGML_OP_MUL_MAT) ? node->ne[1] : node->ne[2]; const uint64_t k = node->src[1]->ne[0]; const uint64_t batch = node->src[1]->ne[2] * node->src[1]->ne[3]; std::string name = ggml_op_name(node->op); @@ -3525,13 +3526,18 @@ static void ggml_vk_load_shaders(vk_device& device) { // the number of rows computed per shader depends on GPU model and quant uint32_t rm_stdq = 1; uint32_t rm_kq = 2; + uint32_t rm_stdq_int = 1; + uint32_t rm_kq_int = 1; if (device->vendor_id == VK_VENDOR_ID_AMD) { if (device->architecture == AMD_GCN) { rm_stdq = 2; rm_kq = 4; + rm_stdq_int = 4; } - } else if (device->vendor_id == VK_VENDOR_ID_INTEL) + } else if (device->vendor_id == VK_VENDOR_ID_INTEL) { rm_stdq = 2; + rm_stdq_int = 2; + } uint32_t rm_iq = 2 * rm_kq; const bool use_subgroups = device->subgroup_arithmetic && device->architecture != vk_device_architecture::AMD_GCN; @@ -3612,39 +3618,73 @@ static void ggml_vk_load_shaders(vk_device& device) { const uint32_t subgroup_size_int = (device->vendor_id == VK_VENDOR_ID_INTEL && device->subgroup_size_control) ? device->subgroup_min_size : device->subgroup_size; const uint32_t wg_size_subgroup_int = (w == DMMV_WG_SIZE_SUBGROUP) ? subgroup_size_int : (subgroup_size_int * 4); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_q8_1_f32", arr_dmmv_q4_0_q8_1_f32_len[reduc], arr_dmmv_q4_0_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup_int, 2*rm_stdq, i+1}, 1, true, use_subgroups, subgroup_size_int); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_q8_1_f32", arr_dmmv_q4_1_q8_1_f32_len[reduc], arr_dmmv_q4_1_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup_int, 2*rm_stdq, i+1}, 1, true, use_subgroups, subgroup_size_int); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_q8_1_f32", arr_dmmv_q5_0_q8_1_f32_len[reduc], arr_dmmv_q5_0_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup_int, 2*rm_stdq, i+1}, 1, true, use_subgroups, subgroup_size_int); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q5_1][i], "mul_mat_vec_q5_1_q8_1_f32", arr_dmmv_q5_1_q8_1_f32_len[reduc], arr_dmmv_q5_1_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup_int, 2*rm_stdq, i+1}, 1, true, use_subgroups, subgroup_size_int); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q8_0][i], "mul_mat_vec_q8_0_q8_1_f32", arr_dmmv_q8_0_q8_1_f32_len[reduc], arr_dmmv_q8_0_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq, i+1}, 1, true, use_subgroups, subgroup_size_int); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_q8_1_f32", arr_dmmv_q4_0_q8_1_f32_len[reduc], arr_dmmv_q4_0_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int, i+1}, 1, true, use_subgroups, subgroup_size_int); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_q8_1_f32", arr_dmmv_q4_1_q8_1_f32_len[reduc], arr_dmmv_q4_1_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int, i+1}, 1, true, use_subgroups, subgroup_size_int); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_q8_1_f32", arr_dmmv_q5_0_q8_1_f32_len[reduc], arr_dmmv_q5_0_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int, i+1}, 1, true, use_subgroups, subgroup_size_int); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q5_1][i], "mul_mat_vec_q5_1_q8_1_f32", arr_dmmv_q5_1_q8_1_f32_len[reduc], arr_dmmv_q5_1_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int, i+1}, 1, true, use_subgroups, subgroup_size_int); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q8_0][i], "mul_mat_vec_q8_0_q8_1_f32", arr_dmmv_q8_0_q8_1_f32_len[reduc], arr_dmmv_q8_0_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int, i+1}, 1, true, use_subgroups, subgroup_size_int); + + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_MXFP4][i], "mul_mat_vec_mxfp4_q8_1_f32", arr_dmmv_mxfp4_q8_1_f32_len[reduc], arr_dmmv_mxfp4_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 2*rm_stdq_int, i+1}, 1, true, use_subgroups, subgroup_size_int); + + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q2_K][i], "mul_mat_vec_q2_k_q8_1_f32", arr_dmmv_q2_k_q8_1_f32_len[reduc], arr_dmmv_q2_k_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 2*rm_kq_int, i+1}, 1, true, use_subgroups, subgroup_size_int); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q3_K][i], "mul_mat_vec_q3_k_q8_1_f32", arr_dmmv_q3_k_q8_1_f32_len[reduc], arr_dmmv_q3_k_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int, i+1}, 1, true, use_subgroups, subgroup_size_int); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q4_K][i], "mul_mat_vec_q4_k_q8_1_f32", arr_dmmv_q4_k_q8_1_f32_len[reduc], arr_dmmv_q4_k_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int, i+1}, 1, true, use_subgroups, subgroup_size_int); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q5_K][i], "mul_mat_vec_q5_k_q8_1_f32", arr_dmmv_q5_k_q8_1_f32_len[reduc], arr_dmmv_q5_k_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int, i+1}, 1, true, use_subgroups, subgroup_size_int); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q6_K][i], "mul_mat_vec_q6_k_q8_1_f32", arr_dmmv_q6_k_q8_1_f32_len[reduc], arr_dmmv_q6_k_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int, i+1}, 1, true, use_subgroups, subgroup_size_int); } #endif // GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT } + + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_F32 ], "mul_mat_vec_id_f32_f32", arr_dmmv_id_f32_f32_f32_len[reduc], arr_dmmv_id_f32_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {wg_size_subgroup, 2}, 1, false, use_subgroups, force_subgroup_size); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_F16 ], "mul_mat_vec_id_f16_f32", arr_dmmv_id_f16_f32_f32_len[reduc], arr_dmmv_id_f16_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {wg_size_subgroup, 2}, 1, false, use_subgroups, force_subgroup_size); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_BF16], "mul_mat_vec_id_bf16_f32", arr_dmmv_id_bf16_f32_f32_len[reduc], arr_dmmv_id_bf16_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {wg_size_subgroup, 2}, 1, false, use_subgroups, force_subgroup_size); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q4_0], "mul_mat_vec_id_q4_0_f32", arr_dmmv_id_q4_0_f32_f32_len[reduc], arr_dmmv_id_q4_0_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq}, 1, true, use_subgroups, force_subgroup_size); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q4_1], "mul_mat_vec_id_q4_1_f32", arr_dmmv_id_q4_1_f32_f32_len[reduc], arr_dmmv_id_q4_1_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq}, 1, true, use_subgroups, force_subgroup_size); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q5_0], "mul_mat_vec_id_q5_0_f32", arr_dmmv_id_q5_0_f32_f32_len[reduc], arr_dmmv_id_q5_0_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq}, 1, true, use_subgroups, force_subgroup_size); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q5_1], "mul_mat_vec_id_q5_1_f32", arr_dmmv_id_q5_1_f32_f32_len[reduc], arr_dmmv_id_q5_1_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq}, 1, true, use_subgroups, force_subgroup_size); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q8_0], "mul_mat_vec_id_q8_0_f32", arr_dmmv_id_q8_0_f32_f32_len[reduc], arr_dmmv_id_q8_0_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {1*rm_stdq, 1, 1}, {wg_size_subgroup, 1*rm_stdq}, 1, true, use_subgroups, force_subgroup_size); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q2_K], "mul_mat_vec_id_q2_k_f32", arr_dmmv_id_q2_k_f32_f32_len[reduc16], arr_dmmv_id_q2_k_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q3_K], "mul_mat_vec_id_q3_k_f32", arr_dmmv_id_q3_k_f32_f32_len[reduc16], arr_dmmv_id_q3_k_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q4_K], "mul_mat_vec_id_q4_k_f32", arr_dmmv_id_q4_k_f32_f32_len[reduc16], arr_dmmv_id_q4_k_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q5_K], "mul_mat_vec_id_q5_k_f32", arr_dmmv_id_q5_k_f32_f32_len[reduc16], arr_dmmv_id_q5_k_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q6_K], "mul_mat_vec_id_q6_k_f32", arr_dmmv_id_q6_k_f32_f32_len[reduc16], arr_dmmv_id_q6_k_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ1_S], "mul_mat_vec_id_iq1_s_f32", arr_dmmv_id_iq1_s_f32_f32_len[reduc16], arr_dmmv_id_iq1_s_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ1_M], "mul_mat_vec_id_iq1_m_f32", arr_dmmv_id_iq1_m_f32_f32_len[reduc16], arr_dmmv_id_iq1_m_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ2_XXS], "mul_mat_vec_id_iq2_xxs_f32", arr_dmmv_id_iq2_xxs_f32_f32_len[reduc16], arr_dmmv_id_iq2_xxs_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ2_XS], "mul_mat_vec_id_iq2_xs_f32", arr_dmmv_id_iq2_xs_f32_f32_len[reduc16], arr_dmmv_id_iq2_xs_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ2_S], "mul_mat_vec_id_iq2_s_f32", arr_dmmv_id_iq2_s_f32_f32_len[reduc16], arr_dmmv_id_iq2_s_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ3_XXS], "mul_mat_vec_id_iq3_xxs_f32", arr_dmmv_id_iq3_xxs_f32_f32_len[reduc16], arr_dmmv_id_iq3_xxs_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ3_S], "mul_mat_vec_id_iq3_s_f32", arr_dmmv_id_iq3_s_f32_f32_len[reduc16], arr_dmmv_id_iq3_s_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ4_XS], "mul_mat_vec_id_iq4_xs_f32", arr_dmmv_id_iq4_xs_f32_f32_len[reduc16], arr_dmmv_id_iq4_xs_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ4_NL], "mul_mat_vec_id_iq4_nl_f32", arr_dmmv_id_iq4_nl_f32_f32_len[reduc16], arr_dmmv_id_iq4_nl_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_MXFP4], "mul_mat_vec_id_mxfp4_f32", arr_dmmv_id_mxfp4_f32_f32_len[reduc16], arr_dmmv_id_mxfp4_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16); + +#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) + if (device->integer_dot_product) { + const uint32_t subgroup_size_int = (device->vendor_id == VK_VENDOR_ID_INTEL && device->subgroup_size_control) ? device->subgroup_min_size : device->subgroup_size; + const uint32_t wg_size_subgroup_int = (w == DMMV_WG_SIZE_SUBGROUP) ? subgroup_size_int : (subgroup_size_int * 4); + + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q4_0], "mul_mat_vec_id_q4_0_q8_1_f32", arr_dmmv_id_q4_0_q8_1_f32_len[reduc], arr_dmmv_id_q4_0_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q4_1], "mul_mat_vec_id_q4_1_q8_1_f32", arr_dmmv_id_q4_1_q8_1_f32_len[reduc], arr_dmmv_id_q4_1_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q5_0], "mul_mat_vec_id_q5_0_q8_1_f32", arr_dmmv_id_q5_0_q8_1_f32_len[reduc], arr_dmmv_id_q5_0_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q5_1], "mul_mat_vec_id_q5_1_q8_1_f32", arr_dmmv_id_q5_1_q8_1_f32_len[reduc], arr_dmmv_id_q5_1_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q8_0], "mul_mat_vec_id_q8_0_q8_1_f32", arr_dmmv_id_q8_0_q8_1_f32_len[reduc], arr_dmmv_id_q8_0_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int); + + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_MXFP4], "mul_mat_vec_id_mxfp4_q8_1_f32", arr_dmmv_id_mxfp4_q8_1_f32_len[reduc], arr_dmmv_id_mxfp4_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 2*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int); + + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q2_K], "mul_mat_vec_id_q2_k_q8_1_f32", arr_dmmv_id_q2_k_q8_1_f32_len[reduc], arr_dmmv_id_q2_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 2*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q3_K], "mul_mat_vec_id_q3_k_q8_1_f32", arr_dmmv_id_q3_k_q8_1_f32_len[reduc], arr_dmmv_id_q3_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q4_K], "mul_mat_vec_id_q4_k_q8_1_f32", arr_dmmv_id_q4_k_q8_1_f32_len[reduc], arr_dmmv_id_q4_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q5_K], "mul_mat_vec_id_q5_k_q8_1_f32", arr_dmmv_id_q5_k_q8_1_f32_len[reduc], arr_dmmv_id_q5_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q6_K], "mul_mat_vec_id_q6_k_q8_1_f32", arr_dmmv_id_q6_k_q8_1_f32_len[reduc], arr_dmmv_id_q6_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int); + } +#endif // GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT } - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F32 ], "mul_mat_vec_id_f32_f32", mul_mat_vec_id_f32_f32_len, mul_mat_vec_id_f32_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F16 ], "mul_mat_vec_id_f16_f32", mul_mat_vec_id_f16_f32_len, mul_mat_vec_id_f16_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_BF16], "mul_mat_vec_id_bf16_f32", mul_mat_vec_id_bf16_f32_len, mul_mat_vec_id_bf16_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_0], "mul_mat_vec_id_q4_0_f32", mul_mat_vec_id_q4_0_f32_len, mul_mat_vec_id_q4_0_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_1], "mul_mat_vec_id_q4_1_f32", mul_mat_vec_id_q4_1_f32_len, mul_mat_vec_id_q4_1_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_0], "mul_mat_vec_id_q5_0_f32", mul_mat_vec_id_q5_0_f32_len, mul_mat_vec_id_q5_0_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_1], "mul_mat_vec_id_q5_1_f32", mul_mat_vec_id_q5_1_f32_len, mul_mat_vec_id_q5_1_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q8_0], "mul_mat_vec_id_q8_0_f32", mul_mat_vec_id_q8_0_f32_len, mul_mat_vec_id_q8_0_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {1*rm_stdq, 1, 1}, {device->subgroup_size, 1*rm_stdq}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q2_K], "mul_mat_vec_id_q2_k_f32", mul_mat_vec_id_q2_k_f32_len, mul_mat_vec_id_q2_k_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q3_K], "mul_mat_vec_id_q3_k_f32", mul_mat_vec_id_q3_k_f32_len, mul_mat_vec_id_q3_k_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_K], "mul_mat_vec_id_q4_k_f32", mul_mat_vec_id_q4_k_f32_len, mul_mat_vec_id_q4_k_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_K], "mul_mat_vec_id_q5_k_f32", mul_mat_vec_id_q5_k_f32_len, mul_mat_vec_id_q5_k_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q6_K], "mul_mat_vec_id_q6_k_f32", mul_mat_vec_id_q6_k_f32_len, mul_mat_vec_id_q6_k_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ1_S], "mul_mat_vec_id_iq1_s_f32", mul_mat_vec_id_iq1_s_f32_len, mul_mat_vec_id_iq1_s_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ1_M], "mul_mat_vec_id_iq1_m_f32", mul_mat_vec_id_iq1_m_f32_len, mul_mat_vec_id_iq1_m_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ2_XXS], "mul_mat_vec_id_iq2_xxs_f32", mul_mat_vec_id_iq2_xxs_f32_len, mul_mat_vec_id_iq2_xxs_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ2_XS], "mul_mat_vec_id_iq2_xs_f32", mul_mat_vec_id_iq2_xs_f32_len, mul_mat_vec_id_iq2_xs_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ2_S], "mul_mat_vec_id_iq2_s_f32", mul_mat_vec_id_iq2_s_f32_len, mul_mat_vec_id_iq2_s_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ3_XXS], "mul_mat_vec_id_iq3_xxs_f32", mul_mat_vec_id_iq3_xxs_f32_len, mul_mat_vec_id_iq3_xxs_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ3_S], "mul_mat_vec_id_iq3_s_f32", mul_mat_vec_id_iq3_s_f32_len, mul_mat_vec_id_iq3_s_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_XS], "mul_mat_vec_id_iq4_xs_f32", mul_mat_vec_id_iq4_xs_f32_len, mul_mat_vec_id_iq4_xs_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_id_iq4_nl_f32", mul_mat_vec_id_iq4_nl_f32_len, mul_mat_vec_id_iq4_nl_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_MXFP4], "mul_mat_vec_id_mxfp4_f32", mul_mat_vec_id_mxfp4_f32_len, mul_mat_vec_id_mxfp4_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true); +#if !defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) + GGML_UNUSED(rm_stdq_int); + GGML_UNUSED(rm_kq_int); +#endif // dequant shaders ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_F32 ], "f32_to_f16", dequant_f32_len, dequant_f32_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); @@ -5453,6 +5493,12 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context * case GGML_TYPE_Q5_0: case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: + case GGML_TYPE_MXFP4: + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + case GGML_TYPE_Q6_K: break; default: return nullptr; @@ -5592,9 +5638,28 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co } } -static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context * ctx, ggml_type a_type, ggml_type b_type) { +static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context * ctx, ggml_type a_type, ggml_type b_type, uint32_t m, uint32_t k) { VK_LOG_DEBUG("ggml_vk_get_dequantize_mul_mat_vec_id()"); - GGML_ASSERT(b_type == GGML_TYPE_F32); + GGML_ASSERT(b_type == GGML_TYPE_F32 || b_type == GGML_TYPE_Q8_1); + + if (b_type == GGML_TYPE_Q8_1) { + switch (a_type) { + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + case GGML_TYPE_MXFP4: + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + case GGML_TYPE_Q6_K: + break; + default: + return nullptr; + } + } switch (a_type) { case GGML_TYPE_F32: @@ -5625,7 +5690,31 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context return nullptr; } - return ctx->device->pipeline_dequant_mul_mat_vec_id_f32[a_type]; + // heuristic to choose workgroup size + uint32_t dmmv_wg = DMMV_WG_SIZE_SUBGROUP; + if ((ctx->device->vendor_id == VK_VENDOR_ID_NVIDIA && ctx->device->architecture != vk_device_architecture::NVIDIA_PRE_TURING) || ctx->device->vendor_id == VK_VENDOR_ID_INTEL) { + // Prefer larger workgroups when M is small, to spread the work out more + // and keep more SMs busy. + // q6_k seems to prefer small workgroup size even for "medium" values of M. + if (a_type == GGML_TYPE_Q6_K) { + if (m < 4096 && k >= 1024) { + dmmv_wg = DMMV_WG_SIZE_LARGE; + } + } else { + if (m <= 8192 && k >= 1024) { + dmmv_wg = DMMV_WG_SIZE_LARGE; + } + } + } + + if (b_type == GGML_TYPE_Q8_1) { + if (ctx->device->vendor_id == VK_VENDOR_ID_INTEL) { + dmmv_wg = DMMV_WG_SIZE_SUBGROUP; + } + return ctx->device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[dmmv_wg][a_type]; + } + + return ctx->device->pipeline_dequant_mul_mat_vec_id_f32[dmmv_wg][a_type]; } static void * ggml_vk_host_malloc(vk_device& device, size_t size) { @@ -6817,20 +6906,35 @@ static bool ggml_vk_should_use_mmvq(const vk_device& device, uint32_t m, uint32_ return false; } + // General performance issue with q3_k and q6_k due to 2-byte alignment + if (src0_type == GGML_TYPE_Q3_K || src0_type == GGML_TYPE_Q6_K) { + return false; + } + // MMVQ is generally good for batches if (n > 1) { return true; } + // Quantization overhead is not worth it for small k switch (device->vendor_id) { case VK_VENDOR_ID_NVIDIA: + if (k <= 4096) { + return false; + } + switch (src0_type) { + case GGML_TYPE_MXFP4: case GGML_TYPE_Q8_0: return device->architecture == vk_device_architecture::NVIDIA_PRE_TURING; default: return true; } case VK_VENDOR_ID_AMD: + if (k < 2048) { + return false; + } + switch (src0_type) { case GGML_TYPE_Q8_0: return device->architecture == vk_device_architecture::AMD_GCN; @@ -6838,6 +6942,10 @@ static bool ggml_vk_should_use_mmvq(const vk_device& device, uint32_t m, uint32_ return true; } case VK_VENDOR_ID_INTEL: + if (k < 2048) { + return false; + } + switch (src0_type) { // From tests on A770 Linux, may need more tuning case GGML_TYPE_Q4_0: @@ -6851,7 +6959,6 @@ static bool ggml_vk_should_use_mmvq(const vk_device& device, uint32_t m, uint32_ } GGML_UNUSED(m); - GGML_UNUSED(k); } static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const struct ggml_cgraph * cgraph, int node_idx) { @@ -7574,7 +7681,7 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& if (x_non_contig || qx_needs_dequant) { ctx->prealloc_x_need_sync = true; } - if (y_non_contig) { + if (y_non_contig || quantize_y) { ctx->prealloc_y_need_sync = true; } } @@ -7600,7 +7707,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte const uint64_t ne10 = src1->ne[0]; const uint64_t ne11 = src1->ne[1]; - // const uint64_t ne12 = src1->ne[2]; + const uint64_t ne12 = src1->ne[2]; // const uint64_t ne13 = src1->ne[3]; const uint64_t nei0 = ids->ne[0]; @@ -7617,19 +7724,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte const bool y_non_contig = !ggml_vk_dim01_contiguous(src1); const bool f16_f32_kernel = src1->type == GGML_TYPE_F32; - - const bool qx_needs_dequant = x_non_contig; - const bool qy_needs_dequant = (src1->type != GGML_TYPE_F16 && !f16_f32_kernel) || y_non_contig; - - // Not implemented - GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT - - const uint64_t x_ne = ggml_nelements(src0); - const uint64_t y_ne = ggml_nelements(src1); - - const uint64_t qx_sz = ggml_vk_align_size(ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type), ctx->device->properties.limits.minStorageBufferOffsetAlignment); - const uint64_t x_sz = x_non_contig ? ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment) : qx_sz; - const uint64_t y_sz = f16_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne; + bool quantize_y = ctx->device->integer_dot_product && src1->type == GGML_TYPE_F32 && ggml_is_contiguous(src1) && !y_non_contig && (ne11 * ne10) % 4 == 0 && ggml_vk_should_use_mmvq(ctx->device, ne01, ne12, ne10, src0->type); vk_pipeline to_fp16_vk_0 = nullptr; vk_pipeline to_fp16_vk_1 = nullptr; @@ -7641,11 +7736,38 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte } else { to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type); } - vk_pipeline dmmv = ggml_vk_get_dequantize_mul_mat_vec_id(ctx, src0->type, src1->type); + + // Check for mmq first + vk_pipeline dmmv = quantize_y ? ggml_vk_get_dequantize_mul_mat_vec_id(ctx, src0->type, GGML_TYPE_Q8_1, ne20, ne00) : nullptr; + vk_pipeline to_q8_1 = nullptr; + + if (dmmv == nullptr) { + // Fall back to f16 dequant mul mat + dmmv = ggml_vk_get_dequantize_mul_mat_vec_id(ctx, src0->type, src1->type, ne20, ne00); + quantize_y = false; + } + + if (quantize_y) { + to_q8_1 = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1); + } + + const bool qx_needs_dequant = x_non_contig; + const bool qy_needs_dequant = !quantize_y && ((src1->type != GGML_TYPE_F16 && !f16_f32_kernel) || y_non_contig); + + // Not implemented + GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr); // NOLINT GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT GGML_ASSERT(dmmv != nullptr); + const uint64_t x_ne = ggml_nelements(src0); + const uint64_t y_ne = ggml_nelements(src1); + + const uint64_t qx_sz = ggml_vk_align_size(ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type), ctx->device->properties.limits.minStorageBufferOffsetAlignment); + const uint64_t x_sz = x_non_contig ? ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment) : qx_sz; + const uint64_t y_sz = quantize_y ? (ggml_vk_align_size(y_ne, 128) * ggml_type_size(GGML_TYPE_Q8_1) / ggml_blck_size(GGML_TYPE_Q8_1)) : + (f16_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne); + { if ( (qx_needs_dequant && x_sz > ctx->device->properties.limits.maxStorageBufferRange) || @@ -7656,7 +7778,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte ctx->prealloc_size_x = x_sz; ggml_vk_preallocate_buffers(ctx, subctx); } - if (qy_needs_dequant && ctx->prealloc_size_y < y_sz) { + if ((qy_needs_dequant || quantize_y) && ctx->prealloc_size_y < y_sz) { ctx->prealloc_size_y = y_sz; ggml_vk_preallocate_buffers(ctx, subctx); } @@ -7668,6 +7790,9 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte if (qy_needs_dequant) { ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_1, 1); } + if (quantize_y) { + ggml_pipeline_request_descriptor_sets(ctx, to_q8_1, 1); + } ggml_pipeline_request_descriptor_sets(ctx, dmmv, 1); } @@ -7683,7 +7808,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte } else { d_X = d_Qx; } - if (qy_needs_dequant) { + if (qy_needs_dequant || quantize_y) { d_Y = { ctx->prealloc_y, 0, ctx->prealloc_y->size }; } else { d_Y = d_Qy; @@ -7711,6 +7836,17 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte ctx->prealloc_y_last_tensor_used = src1; } } + if (quantize_y) { + if (ctx->prealloc_y_last_pipeline_used != to_q8_1.get() || + ctx->prealloc_y_last_tensor_used != src1) { + if (ctx->prealloc_y_need_sync) { + ggml_vk_sync_buffers(ctx, subctx); + } + ggml_vk_quantize_q8_1(ctx, subctx, d_Qy, d_Y, y_ne); + ctx->prealloc_y_last_pipeline_used = to_q8_1.get(); + ctx->prealloc_y_last_tensor_used = src1; + } + } uint32_t stride_batch_y = ne10*ne11; @@ -7772,7 +7908,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte if (x_non_contig) { ctx->prealloc_x_need_sync = true; } - if (y_non_contig) { + if (y_non_contig || quantize_y) { ctx->prealloc_y_need_sync = true; } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl index 09676a623b..70ee542d96 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl @@ -4,13 +4,6 @@ #include "types.glsl" -#if defined(A_TYPE_PACKED16) -layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];}; -#endif -#if defined(A_TYPE_PACKED32) -layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];}; -#endif - #if defined(DATA_A_F32) vec2 dequantize(uint ib, uint iqs, uint a_offset) { return vec2(data_a[a_offset + ib], data_a[a_offset + ib + 1]); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.glsl index c1ad517256..ba7909c4d3 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.glsl @@ -22,6 +22,13 @@ layout (push_constant) uniform parameter #if !RMS_NORM_ROPE_FUSION layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +#if defined(A_TYPE_PACKED16) +layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];}; +#endif +#if defined(A_TYPE_PACKED32) +layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];}; +#endif + layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; #endif diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/generic_unary_head.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/generic_unary_head.glsl index 8dc9d360d5..cc181fda87 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/generic_unary_head.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/generic_unary_head.glsl @@ -18,6 +18,13 @@ layout (push_constant) uniform parameter } p; layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +#if defined(A_TYPE_PACKED16) +layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];}; +#endif +#if defined(A_TYPE_PACKED32) +layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];}; +#endif + layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; uint get_idx() { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp index 9a03925cfd..b3c96576de 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp @@ -3,6 +3,7 @@ #extension GL_EXT_shader_explicit_arithmetic_types_int32 : require #include "mul_mat_vec_base.glsl" +#include "dequant_funcs.glsl" layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl index e4651a683b..cfc8b0c7f4 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl @@ -13,8 +13,6 @@ #include "mul_mat_vec_iface.glsl" -#include "dequant_funcs.glsl" - layout (push_constant) uniform parameter { uint ncols; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl index 14ab1fd74c..337dbd796a 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl @@ -5,13 +5,15 @@ #define MAT_VEC_FUSION_FLAGS_SCALE0 0x4 #define MAT_VEC_FUSION_FLAGS_SCALE1 0x8 -#ifndef MMQ layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; #if defined(A_TYPE_VEC4) layout (binding = 0) readonly buffer AV4 {A_TYPE_VEC4 data_a_v4[];}; #endif -#else -layout (binding = 0) readonly buffer A {A_TYPE_PACKED16 data_a[];}; +#if defined(A_TYPE_PACKED16) +layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];}; +#endif +#if defined(A_TYPE_PACKED32) +layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];}; #endif layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp index 64293f6eca..15f005be3e 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp @@ -10,60 +10,56 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; +#if defined(DATA_A_QUANT_LEGACY) || defined(DATA_A_MXFP4) #define K_PER_ITER 8 - -#include "mul_mmq_funcs.glsl" +#elif defined(DATA_A_QUANT_K) +#define K_PER_ITER 16 +#else +#error unimplemented +#endif uint a_offset, b_offset, d_offset; -int32_t cache_b_qs[2]; +int32_t cache_b_qs[K_PER_ITER / 4]; vec2 cache_b_ds; +#include "mul_mat_vecq_funcs.glsl" + void iter(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const uint num_rows, const uint tid, const uint i) { [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { const uint col = i*BLOCK_SIZE + tid*K_PER_ITER; // Preload data_b block const uint b_block_idx = (j*p.batch_stride_b + col) / QUANT_K_Q8_1 + b_offset; - const uint b_qs_idx = tid % 4; + const uint b_qs_idx = tid % (32 / K_PER_ITER); const uint b_block_idx_outer = b_block_idx / 4; const uint b_block_idx_inner = b_block_idx % 4; cache_b_ds = vec2(data_b[b_block_idx_outer].ds[b_block_idx_inner]); #if QUANT_R == 2 + // Assumes K_PER_ITER == 8 cache_b_qs[0] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx]; cache_b_qs[1] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx + 4]; #else +#if K_PER_ITER == 8 cache_b_qs[0] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 2]; cache_b_qs[1] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 2 + 1]; +#elif K_PER_ITER == 16 + cache_b_qs[0] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 4 ]; + cache_b_qs[1] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 4 + 1]; + cache_b_qs[2] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 4 + 2]; + cache_b_qs[3] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 4 + 3]; +#else +#error unimplemented +#endif #endif uint ibi = first_row*p.ncols; [[unroll]] for (uint n = 0; n < num_rows; ++n) { - const uint a_block_idx = (ibi + col)/QUANT_K + a_offset; + const uint a_block_idx = (ibi + col)/QUANT_K_Q8_1 + a_offset; ibi += p.ncols; - int32_t q_sum = 0; -#if QUANT_R == 2 - const i32vec2 data_a_qs = repack(a_block_idx, b_qs_idx); - q_sum += dotPacked4x8EXT(data_a_qs.x, - cache_b_qs[0]); - q_sum += dotPacked4x8EXT(data_a_qs.y, - cache_b_qs[1]); -#else - int32_t data_a_qs = repack(a_block_idx, b_qs_idx * 2); - q_sum += dotPacked4x8EXT(data_a_qs, - cache_b_qs[0]); - data_a_qs = repack(a_block_idx, b_qs_idx * 2 + 1); - q_sum += dotPacked4x8EXT(data_a_qs, - cache_b_qs[1]); -#endif - -#if QUANT_AUXF == 1 - temp[j][n] += mul_q8_1(q_sum, get_d(a_block_idx), cache_b_ds, 4); -#else - temp[j][n] += mul_q8_1(q_sum, get_dm(a_block_idx), cache_b_ds, 4); -#endif + temp[j][n] += mmvq_dot_product(a_block_idx, b_qs_idx); } } } @@ -72,7 +68,7 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { const uint tid = gl_LocalInvocationID.x; get_offsets(a_offset, b_offset, d_offset); - a_offset /= QUANT_K; + a_offset /= QUANT_K_Q8_1; b_offset /= QUANT_K_Q8_1; FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; @@ -102,14 +98,6 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { unroll_count = 2; unrolled_iters = num_iters & ~(unroll_count - 1); -#if K_PER_ITER == 2 - if ((p.ncols & 1) != 0 && - unrolled_iters == num_iters && - unrolled_iters > 0) { - unrolled_iters -= unroll_count; - } -#endif - while (i < unrolled_iters) { // Manually partially unroll the loop [[unroll]] for (uint k = 0; k < unroll_count; ++k) { @@ -128,6 +116,10 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { void main() { const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z); +#ifdef NEEDS_INIT_IQ_SHMEM + init_iq_shmem(gl_WorkGroupSize); +#endif + // do NUM_ROWS at a time, unless there aren't enough remaining rows if (first_row + NUM_ROWS <= p.stride_d) { compute_outputs(first_row, NUM_ROWS); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl new file mode 100644 index 0000000000..2389ea0b1e --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl @@ -0,0 +1,379 @@ +#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require +#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require +#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require + +#include "types.glsl" + +#if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ1_S) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL) +FLOAT_TYPE get_dm(uint ib) { + return FLOAT_TYPE(data_a[ib].d); +} +#endif + +#if defined(DATA_A_Q4_1) || defined(DATA_A_Q5_1) +FLOAT_TYPE_VEC2 get_dm(uint ib) { + return FLOAT_TYPE_VEC2(data_a_packed32[ib].dm); +} +#endif + +#if defined(DATA_A_MXFP4) +FLOAT_TYPE get_dm(uint ib) { + return FLOAT_TYPE(e8m0_to_fp32(data_a[ib].e)); +} +#endif + +#if defined(DATA_A_Q2_K) +FLOAT_TYPE_VEC2 get_dm(uint ib) { + const uint ib_k = ib / 8; + return FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm); +} +#endif + +// Each iqs value maps to a 32-bit integer +#if defined(DATA_A_Q4_0) +// 2-byte loads for Q4_0 blocks (18 bytes) +i32vec2 repack(uint ib, uint iqs) { + const u16vec2 quants = u16vec2(data_a_packed16[ib].qs[iqs * 2 ], + data_a_packed16[ib].qs[iqs * 2 + 1]); + const uint32_t vui = pack32(quants); + return i32vec2( vui & 0x0F0F0F0F, + (vui >> 4) & 0x0F0F0F0F); +} + +FLOAT_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) { + return FLOAT_TYPE(da * (float(q_sum) * dsb.x - (8 / sum_divisor) * dsb.y)); +} +#endif + +#if defined(DATA_A_Q4_1) +// 4-byte loads for Q4_1 blocks (20 bytes) +i32vec2 repack(uint ib, uint iqs) { + const uint32_t vui = data_a_packed32[ib].qs[iqs]; + return i32vec2( vui & 0x0F0F0F0F, + (vui >> 4) & 0x0F0F0F0F); +} + +FLOAT_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) { + return FLOAT_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y / sum_divisor); +} +#endif + +#if defined(DATA_A_Q5_0) +// 2-byte loads for Q5_0 blocks (22 bytes) +i32vec2 repack(uint ib, uint iqs) { + const u16vec2 quants = u16vec2(data_a_packed16[ib].qs[iqs * 2 ], + data_a_packed16[ib].qs[iqs * 2 + 1]); + const uint32_t vui = pack32(quants); + const int32_t qh = int32_t((uint32_t(data_a_packed16[ib].qh[1]) << 16 | data_a_packed16[ib].qh[0]) >> (4 * iqs)); + const int32_t v0 = int32_t(vui & 0x0F0F0F0F) + | ((qh & 0xF) * 0x02040810) & 0x10101010; // (0,1,2,3) -> (4,12,20,28) + + const int32_t v1 = int32_t((vui >> 4) & 0x0F0F0F0F) + | (((qh >> 16) & 0xF) * 0x02040810) & 0x10101010; // (16,17,18,19) -> (4,12,20,28) + + return i32vec2(v0, v1); +} + +FLOAT_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) { + return FLOAT_TYPE(da * (float(q_sum) * dsb.x - (16 / sum_divisor) * dsb.y)); +} +#endif + +#if defined(DATA_A_Q5_1) +// 4-byte loads for Q5_1 blocks (24 bytes) +i32vec2 repack(uint ib, uint iqs) { + const u16vec2 quants = u16vec2(data_a_packed16[ib].qs[iqs * 2 ], + data_a_packed16[ib].qs[iqs * 2 + 1]); + const uint32_t vui = pack32(quants); + const int32_t qh = int32_t(data_a_packed32[ib].qh >> (4 * iqs)); + const int32_t v0 = int32_t(vui & 0x0F0F0F0F) + | ((qh & 0xF) * 0x02040810) & 0x10101010; // (0,1,2,3) -> (4,12,20,28) + + const int32_t v1 = int32_t((vui >> 4) & 0x0F0F0F0F) + | (((qh >> 16) & 0xF) * 0x02040810) & 0x10101010; // (16,17,18,19) -> (4,12,20,28) + + return i32vec2(v0, v1); +} + +FLOAT_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) { + return FLOAT_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y / sum_divisor); +} +#endif + +#if defined(DATA_A_Q8_0) +// 2-byte loads for Q8_0 blocks (34 bytes) +int32_t repack(uint ib, uint iqs) { + return pack32(i16vec2(data_a_packed16[ib].qs[iqs * 2 ], + data_a_packed16[ib].qs[iqs * 2 + 1])); +} + +FLOAT_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) { + return FLOAT_TYPE(float(q_sum) * da * dsb.x); +} +#endif + +#if defined(DATA_A_MXFP4) +// 1-byte loads for mxfp4 blocks (17 bytes) +i32vec2 repack(uint ib, uint iqs) { + const uint32_t qs = pack32(u8vec4(data_a[ib].qs[iqs * 4 ], + data_a[ib].qs[iqs * 4 + 1], + data_a[ib].qs[iqs * 4 + 2], + data_a[ib].qs[iqs * 4 + 3])); + + const u8vec4 i_a0 = unpack8( qs & 0x0F0F0F0F); + const u8vec4 i_a1 = unpack8((qs >> 4) & 0x0F0F0F0F); + + return i32vec2(pack32(i8vec4(kvalues_mxfp4[i_a0.x], kvalues_mxfp4[i_a0.y], kvalues_mxfp4[i_a0.z], kvalues_mxfp4[i_a0.w])), + pack32(i8vec4(kvalues_mxfp4[i_a1.x], kvalues_mxfp4[i_a1.y], kvalues_mxfp4[i_a1.z], kvalues_mxfp4[i_a1.w]))); +} + +FLOAT_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) { + return FLOAT_TYPE(da * dsb.x * float(q_sum) * 0.5); +} +#endif + +#if defined(DATA_A_QUANT_LEGACY) || defined(DATA_A_MXFP4) +FLOAT_TYPE mmvq_dot_product(const uint ib_a, const uint iqs) { + int32_t q_sum = 0; +#if QUANT_R == 2 + const i32vec2 data_a_qs = repack(ib_a, iqs); + q_sum += dotPacked4x8EXT(data_a_qs.x, + cache_b_qs[0]); + q_sum += dotPacked4x8EXT(data_a_qs.y, + cache_b_qs[1]); +#else + int32_t data_a_qs = repack(ib_a, iqs * 2); + q_sum += dotPacked4x8EXT(data_a_qs, + cache_b_qs[0]); + data_a_qs = repack(ib_a, iqs * 2 + 1); + q_sum += dotPacked4x8EXT(data_a_qs, + cache_b_qs[1]); +#endif + + // 2 quants per call => divide sums by 8/2 = 4 + return mul_q8_1(q_sum, get_dm(ib_a), cache_b_ds, 4); +} +#endif + +#if defined(DATA_A_Q2_K) +// 4-byte loads for Q2_K blocks (84 bytes) +i32vec4 repack4(uint ib, uint iqs) { + const uint ib_k = ib / 8; + const uint iqs_k = (ib % 8) * 8 + iqs; + + const uint qs_idx = (iqs_k / 32) * 8 + (iqs_k % 8); + const uint qs_shift = ((iqs_k % 32) / 8) * 2; + + return i32vec4((data_a_packed32[ib_k].qs[qs_idx ] >> qs_shift) & 0x03030303, + (data_a_packed32[ib_k].qs[qs_idx + 1] >> qs_shift) & 0x03030303, + (data_a_packed32[ib_k].qs[qs_idx + 2] >> qs_shift) & 0x03030303, + (data_a_packed32[ib_k].qs[qs_idx + 3] >> qs_shift) & 0x03030303); +} + +uint8_t get_scale(uint ib, uint iqs) { + const uint ib_k = ib / 8; + const uint iqs_k = (ib % 8) * 8 + iqs; + + return data_a[ib_k].scales[iqs_k / 4]; +} + +FLOAT_TYPE mmvq_dot_product(const uint ib_a, const uint iqs) { + int32_t sum_d = 0; + int32_t sum_m = 0; + + const i32vec4 qs_a = repack4(ib_a, iqs * 4); + const uint8_t scale = get_scale(ib_a, iqs * 4); + const vec2 dm = vec2(get_dm(ib_a)); + const int32_t scale_m = int32_t(scale >> 4) * 0x01010101; // Duplicate 8-bit value across 32-bits. + + sum_d += dotPacked4x8EXT(qs_a.x, cache_b_qs[0]) * (scale & 0xF); + sum_m += dotPacked4x8EXT(scale_m, cache_b_qs[0]); + + sum_d += dotPacked4x8EXT(qs_a.y, cache_b_qs[1]) * (scale & 0xF); + sum_m += dotPacked4x8EXT(scale_m, cache_b_qs[1]); + + sum_d += dotPacked4x8EXT(qs_a.z, cache_b_qs[2]) * (scale & 0xF); + sum_m += dotPacked4x8EXT(scale_m, cache_b_qs[2]); + + sum_d += dotPacked4x8EXT(qs_a.w, cache_b_qs[3]) * (scale & 0xF); + sum_m += dotPacked4x8EXT(scale_m, cache_b_qs[3]); + + return FLOAT_TYPE(float(cache_b_ds.x) * (float(dm.x) * float(sum_d) - float(dm.y) * float(sum_m))); +} +#endif + +#if defined(DATA_A_Q3_K) +// 2-byte loads for Q3_K blocks (110 bytes) +i32vec4 repack4(uint ib, uint iqs) { + const uint ib_k = ib / 8; + const uint iqs_k = (ib % 8) * 8 + iqs; + + const uint qs_idx = (iqs_k / 32) * 8 + (iqs_k % 8); + const uint qs_shift = ((iqs_k % 32) / 8) * 2; + const uint hm_shift = iqs_k / 8; + + // bitwise OR to add 4 if hmask is set, subtract later + const i8vec2 vals00 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 ] >> qs_shift) & uint16_t(0x0303))) | + unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 ] >> hm_shift) & uint16_t(0x0101)) << 2)); + const i8vec2 vals01 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 1] >> qs_shift) & uint16_t(0x0303))) | + unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 + 1] >> hm_shift) & uint16_t(0x0101)) << 2)); + const i8vec2 vals10 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 2] >> qs_shift) & uint16_t(0x0303))) | + unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 + 2] >> hm_shift) & uint16_t(0x0101)) << 2)); + const i8vec2 vals11 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 3] >> qs_shift) & uint16_t(0x0303))) | + unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 + 3] >> hm_shift) & uint16_t(0x0101)) << 2)); + const i8vec2 vals20 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 4] >> qs_shift) & uint16_t(0x0303))) | + unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 + 4] >> hm_shift) & uint16_t(0x0101)) << 2)); + const i8vec2 vals21 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 5] >> qs_shift) & uint16_t(0x0303))) | + unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 + 5] >> hm_shift) & uint16_t(0x0101)) << 2)); + const i8vec2 vals30 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 6] >> qs_shift) & uint16_t(0x0303))) | + unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 + 6] >> hm_shift) & uint16_t(0x0101)) << 2)); + const i8vec2 vals31 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 7] >> qs_shift) & uint16_t(0x0303))) | + unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 + 7] >> hm_shift) & uint16_t(0x0101)) << 2)); + + return i32vec4(pack32(i8vec4(vals00.x, vals00.y, vals01.x, vals01.y) - int8_t(4)), + pack32(i8vec4(vals10.x, vals10.y, vals11.x, vals11.y) - int8_t(4)), + pack32(i8vec4(vals20.x, vals20.y, vals21.x, vals21.y) - int8_t(4)), + pack32(i8vec4(vals30.x, vals30.y, vals31.x, vals31.y) - int8_t(4))); +} + +float get_d_scale(uint ib, uint iqs) { + const uint ib_k = ib / 8; + const uint iqs_k = (ib % 8) * 8 + iqs; + const uint is = iqs_k / 4; + + const int8_t scale = int8_t(((data_a[ib_k].scales[is % 8 ] >> (4 * (is / 8))) & 0x0F0F) | + (((data_a[ib_k].scales[8 + (is % 4)] >> (2 * (is / 4))) & 0x0303) << 4)); + return float(data_a[ib_k].d) * float(scale - 32); +} + +FLOAT_TYPE mmvq_dot_product(const uint ib_a, const uint iqs) { + int32_t q_sum = 0; + + const i32vec4 qs_a = repack4(ib_a, iqs * 4); + const float d_scale = get_d_scale(ib_a, iqs * 4); + + q_sum += dotPacked4x8EXT(qs_a.x, cache_b_qs[0]); + q_sum += dotPacked4x8EXT(qs_a.y, cache_b_qs[1]); + q_sum += dotPacked4x8EXT(qs_a.z, cache_b_qs[2]); + q_sum += dotPacked4x8EXT(qs_a.w, cache_b_qs[3]); + + return FLOAT_TYPE(float(cache_b_ds.x) * d_scale * float(q_sum)); +} +#endif + +#if defined(DATA_A_Q4_K) || defined(DATA_A_Q5_K) +// 4-byte loads for Q4_K blocks (144 bytes) and Q5_K blocks (176 bytes) +i32vec4 repack4(uint ib, uint iqs) { + const uint ib_k = ib / 8; + const uint iqs_k = (ib % 8) * 8 + iqs; + + const uint qs_idx = (iqs_k / 16) * 8 + (iqs_k % 8); + const uint qs_shift = ((iqs_k % 16) / 8) * 4; + +#if defined(DATA_A_Q4_K) + const uint32_t vals0 = (data_a_packed32[ib_k].qs[qs_idx ] >> qs_shift) & 0x0F0F0F0F; + const uint32_t vals1 = (data_a_packed32[ib_k].qs[qs_idx + 1] >> qs_shift) & 0x0F0F0F0F; + const uint32_t vals2 = (data_a_packed32[ib_k].qs[qs_idx + 2] >> qs_shift) & 0x0F0F0F0F; + const uint32_t vals3 = (data_a_packed32[ib_k].qs[qs_idx + 3] >> qs_shift) & 0x0F0F0F0F; + + return i32vec4(vals0, vals1, vals2, vals3); +#else // defined(DATA_A_Q5_K) + const uint qh_idx = iqs; + const uint qh_shift = iqs_k / 8; + + return i32vec4(((data_a_packed32[ib_k].qs[qs_idx ] >> qs_shift) & 0x0F0F0F0F) | + (((data_a_packed32[ib_k].qh[qh_idx ] >> qh_shift) & 0x01010101) << 4), + ((data_a_packed32[ib_k].qs[qs_idx + 1] >> qs_shift) & 0x0F0F0F0F) | + (((data_a_packed32[ib_k].qh[qh_idx + 1] >> qh_shift) & 0x01010101) << 4), + ((data_a_packed32[ib_k].qs[qs_idx + 2] >> qs_shift) & 0x0F0F0F0F) | + (((data_a_packed32[ib_k].qh[qh_idx + 2] >> qh_shift) & 0x01010101) << 4), + ((data_a_packed32[ib_k].qs[qs_idx + 3] >> qs_shift) & 0x0F0F0F0F) | + (((data_a_packed32[ib_k].qh[qh_idx + 3] >> qh_shift) & 0x01010101) << 4)); +#endif +} + +vec2 get_dm_scale(uint ib, uint iqs) { + const uint ib_k = ib / 8; + const uint iqs_k = (ib % 8) * 8 + iqs; + const uint is = iqs_k / 8; + u8vec2 scale_dm; + if (is < 4) { + scale_dm = u8vec2(data_a[ib_k].scales[is] & 0x3F, data_a[ib_k].scales[is + 4] & 0x3F); + } else { + scale_dm = u8vec2((data_a[ib_k].scales[is+4] & 0xF) | ((data_a[ib_k].scales[is-4] & 0xC0) >> 2), + (data_a[ib_k].scales[is+4] >> 4) | ((data_a[ib_k].scales[is ] & 0xC0) >> 2)); + } + + return FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm) * FLOAT_TYPE_VEC2(scale_dm); +} + +FLOAT_TYPE mmvq_dot_product(const uint ib_a, const uint iqs) { + int32_t q_sum = 0; + + const i32vec4 qs_a = repack4(ib_a, iqs * 4); + const vec2 dm_scale = get_dm_scale(ib_a, iqs * 4); + + q_sum += dotPacked4x8EXT(qs_a.x, cache_b_qs[0]); + q_sum += dotPacked4x8EXT(qs_a.y, cache_b_qs[1]); + q_sum += dotPacked4x8EXT(qs_a.z, cache_b_qs[2]); + q_sum += dotPacked4x8EXT(qs_a.w, cache_b_qs[3]); + + return FLOAT_TYPE(float(cache_b_ds.x) * float(dm_scale.x) * float(q_sum) - float(dm_scale.y) * float(cache_b_ds.y / 2)); +} +#endif + +#if defined(DATA_A_Q6_K) +// 2-byte loads for Q6_K blocks (210 bytes) +i32vec4 repack4(uint ib, uint iqs) { + const uint ib_k = ib / 8; + const uint iqs_k = (ib % 8) * 8 + iqs; + + const uint ql_idx = (iqs_k / 32) * 16 + iqs_k % 16; + const uint ql_shift = ((iqs_k % 32) / 16) * 4; + + const uint qh_idx = (iqs_k / 32) * 8 + iqs; + const uint qh_shift = ((iqs_k % 32) / 8) * 2; + + const i8vec2 vals00 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 ] >> ql_shift) & uint16_t(0x0F0F))) | + unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 ] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32); + const i8vec2 vals01 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 1] >> ql_shift) & uint16_t(0x0F0F))) | + unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 1] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32); + const i8vec2 vals10 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 2] >> ql_shift) & uint16_t(0x0F0F))) | + unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 2] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32); + const i8vec2 vals11 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 3] >> ql_shift) & uint16_t(0x0F0F))) | + unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 3] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32); + const i8vec2 vals20 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 4] >> ql_shift) & uint16_t(0x0F0F))) | + unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 4] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32); + const i8vec2 vals21 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 5] >> ql_shift) & uint16_t(0x0F0F))) | + unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 5] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32); + const i8vec2 vals30 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 6] >> ql_shift) & uint16_t(0x0F0F))) | + unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 6] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32); + const i8vec2 vals31 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 7] >> ql_shift) & uint16_t(0x0F0F))) | + unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 7] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32); + + return i32vec4(pack32(i8vec4(vals00.x, vals00.y, vals01.x, vals01.y)), + pack32(i8vec4(vals10.x, vals10.y, vals11.x, vals11.y)), + pack32(i8vec4(vals20.x, vals20.y, vals21.x, vals21.y)), + pack32(i8vec4(vals30.x, vals30.y, vals31.x, vals31.y))); +} + +float get_d_scale(uint ib, uint iqs) { + const uint ib_k = ib / 8; + const uint iqs_k = (ib % 8) * 8 + iqs; + return float(data_a[ib_k].d) * float(data_a[ib_k].scales[iqs_k / 4]); +} + +FLOAT_TYPE mmvq_dot_product(const uint ib_a, const uint iqs) { + int32_t q_sum = 0; + + const i32vec4 qs_a = repack4(ib_a, iqs * 4); + const float d_scale = get_d_scale(ib_a, iqs * 4); + + q_sum += dotPacked4x8EXT(qs_a.x, cache_b_qs[0]); + q_sum += dotPacked4x8EXT(qs_a.y, cache_b_qs[1]); + q_sum += dotPacked4x8EXT(qs_a.z, cache_b_qs[2]); + q_sum += dotPacked4x8EXT(qs_a.w, cache_b_qs[3]); + + return FLOAT_TYPE(float(cache_b_ds.x) * float(d_scale) * float(q_sum)); +} +#endif diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp index 5266e523b9..dc8b3df47b 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp @@ -78,8 +78,6 @@ layout (constant_id = 10) const uint WARP = 32; #define BK 32 -#define MMQ_SHMEM - #include "mul_mmq_shmem_types.glsl" #ifdef MUL_MAT_ID diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl index 4e3a561142..7f32dadf17 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl @@ -9,31 +9,6 @@ #if defined(DATA_A_Q4_0) || defined(DATA_A_Q4_1) // 2-byte loads for Q4_0 blocks (18 bytes) // 4-byte loads for Q4_1 blocks (20 bytes) -i32vec2 repack(uint ib, uint iqs) { -#ifdef DATA_A_Q4_0 - const u16vec2 quants = u16vec2(data_a_packed16[ib].qs[iqs * 2 ], - data_a_packed16[ib].qs[iqs * 2 + 1]); - const uint32_t vui = pack32(quants); - return i32vec2( vui & 0x0F0F0F0F, - (vui >> 4) & 0x0F0F0F0F); -#else // DATA_A_Q4_1 - const uint32_t vui = data_a_packed32[ib].qs[iqs]; - return i32vec2( vui & 0x0F0F0F0F, - (vui >> 4) & 0x0F0F0F0F); -#endif -} - -#ifdef DATA_A_Q4_0 -ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) { - return ACC_TYPE(da * (float(q_sum) * dsb.x - (8 / sum_divisor) * dsb.y)); -} -#else // DATA_A_Q4_1 -ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) { - return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y / sum_divisor); -} -#endif - -#ifdef MMQ_SHMEM void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) { #ifdef DATA_A_Q4_0 buf_a[buf_ib].qs[iqs] = pack32(u16vec2(data_a_packed16[ib].qs[iqs * 2], @@ -73,42 +48,17 @@ ACC_TYPE mmq_dot_product(const uint ib_a) { q_sum += dotPacked4x8EXT(qs_a.y, qs_b1); } - return mul_q8_1(q_sum, cache_a[ib_a].dm, cache_b.ds, 1); +#ifdef DATA_A_Q4_0 + return ACC_TYPE(float(cache_a[ib_a].dm) * (float(q_sum) * float(cache_b.ds.x) - 8.0 * float(cache_b.ds.y))); +#else // DATA_A_Q4_1 + return ACC_TYPE(float(q_sum) * float(cache_a[ib_a].dm.x) * float(cache_b.ds.x) + float(cache_a[ib_a].dm.y) * float(cache_b.ds.y)); +#endif } -#endif // MMQ_SHMEM +#endif -#elif defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1) +#if defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1) // 2-byte loads for Q5_0 blocks (22 bytes) // 4-byte loads for Q5_1 blocks (24 bytes) -i32vec2 repack(uint ib, uint iqs) { - const u16vec2 quants = u16vec2(data_a_packed16[ib].qs[iqs * 2 ], - data_a_packed16[ib].qs[iqs * 2 + 1]); - const uint32_t vui = pack32(quants); -#ifdef DATA_A_Q5_0 - const int32_t qh = int32_t((uint32_t(data_a_packed16[ib].qh[1]) << 16 | data_a_packed16[ib].qh[0]) >> (4 * iqs)); -#else // DATA_A_Q5_1 - const int32_t qh = int32_t(data_a_packed32[ib].qh >> (4 * iqs)); -#endif - const int32_t v0 = int32_t(vui & 0x0F0F0F0F) - | ((qh & 0xF) * 0x02040810) & 0x10101010; // (0,1,2,3) -> (4,12,20,28) - - const int32_t v1 = int32_t((vui >> 4) & 0x0F0F0F0F) - | (((qh >> 16) & 0xF) * 0x02040810) & 0x10101010; // (16,17,18,19) -> (4,12,20,28) - - return i32vec2(v0, v1); -} - -#ifdef DATA_A_Q5_0 -ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) { - return ACC_TYPE(da * (float(q_sum) * dsb.x - (16 / sum_divisor) * dsb.y)); -} -#else // DATA_A_Q5_1 -ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) { - return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y / sum_divisor); -} -#endif - -#ifdef MMQ_SHMEM void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) { #ifdef DATA_A_Q5_0 buf_a[buf_ib].qs[iqs] = pack32(u16vec2(data_a_packed16[ib].qs[iqs * 2], @@ -154,23 +104,16 @@ ACC_TYPE mmq_dot_product(const uint ib_a) { q_sum += dotPacked4x8EXT(qs_a1, qs_b1); } - return mul_q8_1(q_sum, cache_a[ib_a].dm, cache_b.ds, 1); +#ifdef DATA_A_Q5_0 + return ACC_TYPE(float(cache_a[ib_a].dm) * (float(q_sum) * float(cache_b.ds.x) - 16.0 * float(cache_b.ds.y))); +#else // DATA_A_Q5_1 + return ACC_TYPE(float(q_sum) * float(cache_a[ib_a].dm.x) * float(cache_b.ds.x) + float(cache_a[ib_a].dm.y) * float(cache_b.ds.y)); +#endif } -#endif // MMQ_SHMEM #endif #if defined(DATA_A_Q8_0) // 2-byte loads for Q8_0 blocks (34 bytes) -int32_t repack(uint ib, uint iqs) { - return pack32(i16vec2(data_a_packed16[ib].qs[iqs * 2 ], - data_a_packed16[ib].qs[iqs * 2 + 1])); -} - -ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) { - return ACC_TYPE(float(q_sum) * da * dsb.x); -} - -#ifdef MMQ_SHMEM void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) { buf_a[buf_ib].qs[iqs] = pack32(i16vec2(data_a_packed16[ib].qs[iqs * 2], data_a_packed16[ib].qs[iqs * 2 + 1])); @@ -197,28 +140,12 @@ ACC_TYPE mmq_dot_product(const uint ib_a) { q_sum += dotPacked4x8EXT(qs_a, qs_b); } - return mul_q8_1(q_sum, cache_a[ib_a].dm, cache_b.ds, 1); + return ACC_TYPE(float(q_sum) * float(cache_a[ib_a].dm) * float(cache_b.ds.x)); } -#endif // MMQ_SHMEM #endif #if defined(DATA_A_MXFP4) // 1-byte loads for mxfp4 blocks (17 bytes) -i32vec2 repack(uint ib, uint iqs) { - const uint32_t quants = pack32(u8vec4(data_a[ib].qs[iqs * 4 ], - data_a[ib].qs[iqs * 4 + 1], - data_a[ib].qs[iqs * 4 + 2], - data_a[ib].qs[iqs * 4 + 3])); - - return i32vec2( quants & 0x0F0F0F0F, - (quants >> 4) & 0x0F0F0F0F); -} - -ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) { - return ACC_TYPE(da * dsb.x * float(q_sum)); -} - -#ifdef MMQ_SHMEM void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) { const uint32_t qs = pack32(u8vec4(data_a[ib].qs[iqs * 4 ], data_a[ib].qs[iqs * 4 + 1], @@ -252,37 +179,14 @@ ACC_TYPE mmq_dot_product(const uint ib_a) { q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]); } - return mul_q8_1(q_sum, cache_a[ib_a].d, cache_b.ds, 1); + return ACC_TYPE(float(cache_a[ib_a].d) * float(cache_b.ds.x) * float(q_sum)); } -#endif // MMQ_SHMEM #endif // For k-quants, ib and iqs still assume 32-wide blocks, but k-quants are 256-wide // iqs still refers to a 32-bit integer, meaning 0..7 for 32-wide quants #if defined(DATA_A_Q2_K) // 4-byte loads for Q2_K blocks (84 bytes) -int32_t repack(uint ib, uint iqs) { - const uint ib_k = ib / 8; - const uint iqs_k = (ib % 8) * 8 + iqs; - - const uint qs_idx = (iqs_k / 32) * 8 + (iqs_k % 8); - const uint qs_shift = ((iqs_k % 32) / 8) * 2; - - return int32_t((data_a_packed32[ib_k].qs[qs_idx] >> qs_shift) & 0x03030303); -} - -uint8_t get_scale(uint ib, uint iqs) { - const uint ib_k = ib / 8; - const uint iqs_k = (ib % 8) * 8 + iqs; - - return data_a[ib_k].scales[iqs_k / 4]; -} - -ACC_TYPE mul_q8_1(const int32_t sum_d, const int32_t sum_m, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) { - return ACC_TYPE(dsb.x * (dma.x * float(sum_d) - dma.y * float(sum_m))); -} - -#ifdef MMQ_SHMEM void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) { const uint ib_k = ib / 8; const uint iqs_k = (ib % 8) * 8 + iqs * QUANT_R_MMQ; @@ -326,14 +230,12 @@ ACC_TYPE mmq_dot_product(const uint ib_a) { sum_m += dotPacked4x8EXT(scale_m, cache_b.qs[iqs]); } - return mul_q8_1(sum_d, sum_m, cache_a[ib_a].dm, cache_b.ds, 1); + return ACC_TYPE(float(cache_b.ds.x) * (float(cache_a[ib_a].dm.x) * float(sum_d) - float(cache_a[ib_a].dm.y) * float(sum_m))); } -#endif // MMQ_SHMEM #endif #if defined(DATA_A_Q3_K) // 2-byte loads for Q3_K blocks (110 bytes) -#ifdef MMQ_SHMEM void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) { const uint ib_k = ib / 8; const uint hm_idx = iqs * QUANT_R_MMQ; @@ -394,18 +296,12 @@ ACC_TYPE mmq_dot_product(const uint ib_a) { } result += float(cache_a[ib_a].d_scales[1]) * float(q_sum); - return ACC_TYPE(cache_b.ds.x * result); + return ACC_TYPE(float(cache_b.ds.x) * result); } -#endif // MMQ_SHMEM #endif #if defined(DATA_A_Q4_K) || defined(DATA_A_Q5_K) // 4-byte loads for Q4_K blocks (144 bytes) and Q5_K blocks (176 bytes) -ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) { - return ACC_TYPE(dsb.x * dma.x * float(q_sum) - dma.y * dsb.y); -} - -#ifdef MMQ_SHMEM void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) { const uint ib_k = ib / 8; const uint iqs_k = (ib % 8) * 8 + iqs * QUANT_R_MMQ; @@ -427,7 +323,6 @@ void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) { (((data_a_packed32[ib_k].qh[qh_idx] >> qh_shift) & 0x01010101) << 4)); #endif - if (iqs == 0) { // Scale index const uint is = iqs_k / 8; @@ -464,49 +359,12 @@ ACC_TYPE mmq_dot_product(const uint ib_a) { q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]); } - return mul_q8_1(q_sum, cache_a[ib_a].dm, cache_b.ds, 1); -} -#endif // MMQ_SHMEM -#endif - -#ifdef MMQ_SHMEM -void block_b_to_shmem(const uint buf_ib, const uint ib, const uint iqs, const bool is_in_bounds) { - if (is_in_bounds) { - const uint ib_outer = ib / 4; - const uint ib_inner = ib % 4; - - if (iqs == 0) { - buf_b[buf_ib].ds = FLOAT_TYPE_VEC2(data_b[ib_outer].ds[ib_inner]); - } - - const ivec4 values = data_b[ib_outer].qs[ib_inner * 2 + iqs]; - buf_b[buf_ib].qs[iqs * 4 ] = values.x; - buf_b[buf_ib].qs[iqs * 4 + 1] = values.y; - buf_b[buf_ib].qs[iqs * 4 + 2] = values.z; - buf_b[buf_ib].qs[iqs * 4 + 3] = values.w; - } else { - if (iqs == 0) { - buf_b[buf_ib].ds = FLOAT_TYPE_VEC2(0.0f); - } - - buf_b[buf_ib].qs[iqs * 4 ] = 0; - buf_b[buf_ib].qs[iqs * 4 + 1] = 0; - buf_b[buf_ib].qs[iqs * 4 + 2] = 0; - buf_b[buf_ib].qs[iqs * 4 + 3] = 0; - } -} - -void block_b_to_registers(const uint ib) { - cache_b.ds = buf_b[ib].ds; - [[unroll]] for (uint iqs = 0; iqs < BK / 4; iqs++) { - cache_b.qs[iqs] = buf_b[ib].qs[iqs]; - } + return ACC_TYPE(float(cache_b.ds.x) * float(cache_a[ib_a].dm.x) * float(q_sum) - float(cache_a[ib_a].dm.y) * float(cache_b.ds.y)); } #endif #if defined(DATA_A_Q6_K) // 2-byte loads for Q6_K blocks (210 bytes) -#ifdef MMQ_SHMEM void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) { const uint ib_k = ib / 8; const uint iqs_k = (ib % 8) * 8 + iqs; @@ -558,32 +416,39 @@ ACC_TYPE mmq_dot_product(const uint ib_a) { } result += float(cache_a[ib_a].d_scales[1]) * float(q_sum); - return ACC_TYPE(cache_b.ds.x * result); -} -#endif // MMQ_SHMEM -#endif - -#if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ1_S) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL) -FLOAT_TYPE get_d(uint ib) { - return FLOAT_TYPE(data_a[ib].d); + return ACC_TYPE(float(cache_b.ds.x) * result); } #endif -#if defined(DATA_A_MXFP4) -FLOAT_TYPE get_d(uint ib) { - return FLOAT_TYPE(e8m0_to_fp32(data_a[ib].e)); -} -#endif +void block_b_to_shmem(const uint buf_ib, const uint ib, const uint iqs, const bool is_in_bounds) { + if (is_in_bounds) { + const uint ib_outer = ib / 4; + const uint ib_inner = ib % 4; -#if defined(DATA_A_Q4_1) || defined(DATA_A_Q5_1) -FLOAT_TYPE_VEC2 get_dm(uint ib) { - return FLOAT_TYPE_VEC2(data_a_packed32[ib].dm); -} -#endif + if (iqs == 0) { + buf_b[buf_ib].ds = FLOAT_TYPE_VEC2(data_b[ib_outer].ds[ib_inner]); + } -#if defined(DATA_A_Q2_K) -FLOAT_TYPE_VEC2 get_dm(uint ib) { - const uint ib_k = ib / 8; - return FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm); + const ivec4 values = data_b[ib_outer].qs[ib_inner * 2 + iqs]; + buf_b[buf_ib].qs[iqs * 4 ] = values.x; + buf_b[buf_ib].qs[iqs * 4 + 1] = values.y; + buf_b[buf_ib].qs[iqs * 4 + 2] = values.z; + buf_b[buf_ib].qs[iqs * 4 + 3] = values.w; + } else { + if (iqs == 0) { + buf_b[buf_ib].ds = FLOAT_TYPE_VEC2(0.0f); + } + + buf_b[buf_ib].qs[iqs * 4 ] = 0; + buf_b[buf_ib].qs[iqs * 4 + 1] = 0; + buf_b[buf_ib].qs[iqs * 4 + 2] = 0; + buf_b[buf_ib].qs[iqs * 4 + 3] = 0; + } +} + +void block_b_to_registers(const uint ib) { + cache_b.ds = buf_b[ib].ds; + [[unroll]] for (uint iqs = 0; iqs < BK / 4; iqs++) { + cache_b.qs[iqs] = buf_b[ib].qs[iqs]; + } } -#endif diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index 214a743b97..92bae088b2 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -679,14 +679,20 @@ void process_shaders() { string_to_spv("mul_mat_vec_" + tname + "_f32_f32_subgroup_no_shmem", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}})); string_to_spv("mul_mat_vec_" + tname + "_f16_f32_subgroup_no_shmem", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPE_VEC2", "f16vec2"}, {"B_TYPE_VEC4", "f16vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}})); - string_to_spv("mul_mat_vec_id_" + tname + "_f32", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}})); + string_to_spv("mul_mat_vec_id_" + tname + "_f32_f32", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}})); + string_to_spv("mul_mat_vec_id_" + tname + "_f32_f32_subgroup", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}})); + string_to_spv("mul_mat_vec_id_" + tname + "_f32_f32_subgroup_no_shmem", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}})); // mul mat vec with integer dot product #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) - if (is_legacy_quant(tname)) { + if (is_legacy_quant(tname) || tname == "mxfp4" || is_k_quant(tname)) { string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}})); string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32_subgroup", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}})); string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32_subgroup_no_shmem", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}})); + + string_to_spv("mul_mat_vec_id_" + tname + "_q8_1_f32", "mul_mat_vecq.comp", merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}})); + string_to_spv("mul_mat_vec_id_" + tname + "_q8_1_f32_subgroup", "mul_mat_vecq.comp", merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}})); + string_to_spv("mul_mat_vec_id_" + tname + "_q8_1_f32_subgroup_no_shmem", "mul_mat_vecq.comp", merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}})); } #endif @@ -1100,7 +1106,7 @@ void write_output_files() { for (const std::string& btype : btypes) { for (const auto& tname : type_names) { - if (btype == "q8_1" && !is_legacy_quant(tname)) { + if (btype == "q8_1" && !is_legacy_quant(tname) && tname != "mxfp4" && !is_k_quant(tname)) { continue; } hdr << "extern const void * arr_dmmv_" << tname << "_" << btype << "_f32_data[3];\n"; @@ -1109,6 +1115,16 @@ void write_output_files() { src << "const void * arr_dmmv_" << tname << "_" << btype << "_f32_data[3] = {mul_mat_vec_" << tname << "_" << btype << "_f32_data, mul_mat_vec_" << tname << "_" << btype << "_f32_subgroup_data, mul_mat_vec_" << tname << "_" << btype << "_f32_subgroup_no_shmem_data};\n"; src << "const uint64_t arr_dmmv_" << tname << "_" << btype << "_f32_len[3] = {mul_mat_vec_" << tname << "_" << btype << "_f32_len, mul_mat_vec_" << tname << "_" << btype << "_f32_subgroup_len, mul_mat_vec_" << tname << "_" << btype << "_f32_subgroup_no_shmem_len};\n"; } + + if (btype == "f16") { + continue; + } + hdr << "extern const void * arr_dmmv_id_" << tname << "_" << btype << "_f32_data[3];\n"; + hdr << "extern const uint64_t arr_dmmv_id_" << tname << "_" << btype << "_f32_len[3];\n"; + if (basename(input_filepath) == "mul_mat_vec.comp") { + src << "const void * arr_dmmv_id_" << tname << "_" << btype << "_f32_data[3] = {mul_mat_vec_id_" << tname << "_" << btype << "_f32_data, mul_mat_vec_id_" << tname << "_" << btype << "_f32_subgroup_data, mul_mat_vec_id_" << tname << "_" << btype << "_f32_subgroup_no_shmem_data};\n"; + src << "const uint64_t arr_dmmv_id_" << tname << "_" << btype << "_f32_len[3] = {mul_mat_vec_id_" << tname << "_" << btype << "_f32_len, mul_mat_vec_id_" << tname << "_" << btype << "_f32_subgroup_len, mul_mat_vec_id_" << tname << "_" << btype << "_f32_subgroup_no_shmem_len};\n"; + } } } From f698a79c6396771a33e903930c5b5348b87a8715 Mon Sep 17 00:00:00 2001 From: ixgbe <1113177880@qq.com> Date: Sat, 29 Nov 2025 20:56:31 +0800 Subject: [PATCH 09/13] ggml: replace hwcap with riscv_hwprobe for RVV detection (#17567) Signed-off-by: Wang Yang --- ggml/src/ggml-cpu/arch/riscv/cpu-feats.cpp | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/ggml/src/ggml-cpu/arch/riscv/cpu-feats.cpp b/ggml/src/ggml-cpu/arch/riscv/cpu-feats.cpp index b181898818..43c757bd01 100644 --- a/ggml/src/ggml-cpu/arch/riscv/cpu-feats.cpp +++ b/ggml/src/ggml-cpu/arch/riscv/cpu-feats.cpp @@ -1,20 +1,23 @@ #include "ggml-backend-impl.h" #if defined(__riscv) && __riscv_xlen == 64 -#include - -//https://github.com/torvalds/linux/blob/master/arch/riscv/include/uapi/asm/hwcap.h#L24 -#ifndef COMPAT_HWCAP_ISA_V -#define COMPAT_HWCAP_ISA_V (1 << ('V' - 'A')) -#endif +#include +#include +#include struct riscv64_features { bool has_rvv = false; riscv64_features() { - uint32_t hwcap = getauxval(AT_HWCAP); + struct riscv_hwprobe probe; + probe.key = RISCV_HWPROBE_KEY_IMA_EXT_0; + probe.value = 0; - has_rvv = !!(hwcap & COMPAT_HWCAP_ISA_V); + int ret = syscall(__NR_riscv_hwprobe, &probe, 1, 0, NULL, 0); + + if (0 == ret) { + has_rvv = !!(probe.value & RISCV_HWPROBE_IMA_V); + } } }; From 7d2add51d8e3759020d70f2ff3a76b5795ff67bc Mon Sep 17 00:00:00 2001 From: Neo Zhang Date: Sat, 29 Nov 2025 20:59:44 +0800 Subject: [PATCH 10/13] sycl : support to malloc memory on device more than 4GB, update the doc and script (#17566) Co-authored-by: Neo Zhang Jianyu --- docs/backend/SYCL.md | 13 +++++++++++++ examples/sycl/run-llama2.sh | 3 +++ examples/sycl/run-llama3.sh | 9 ++++++--- examples/sycl/win-run-llama2.bat | 2 ++ examples/sycl/win-run-llama3.bat | 4 +++- ggml/src/ggml-sycl/CMakeLists.txt | 8 ++++++-- ggml/src/ggml-sycl/cpy.cpp | 3 --- 7 files changed, 33 insertions(+), 9 deletions(-) diff --git a/docs/backend/SYCL.md b/docs/backend/SYCL.md index 92ab27066b..02a72a9d51 100644 --- a/docs/backend/SYCL.md +++ b/docs/backend/SYCL.md @@ -42,6 +42,9 @@ The following releases are verified and recommended: ## News +- 2025.11 + - Support malloc memory on device more than 4GB. + - 2025.2 - Optimize MUL_MAT Q4_0 on Intel GPU for all dGPUs and built-in GPUs since MTL. Increase the performance of LLM (llama-2-7b.Q4_0.gguf) 21%-87% on Intel GPUs (MTL, ARL-H, Arc, Flex, PVC). |GPU|Base tokens/s|Increased tokens/s|Percent| @@ -789,6 +792,8 @@ use 1 SYCL GPUs: [0] with Max compute units:512 | GGML_SYCL_DISABLE_GRAPH | 0 or 1 (default) | Disable running computations through SYCL Graphs feature. Disabled by default because graph performance isn't yet better than non-graph performance. | | GGML_SYCL_DISABLE_DNN | 0 (default) or 1 | Disable running computations through oneDNN and always use oneMKL. | | ZES_ENABLE_SYSMAN | 0 (default) or 1 | Support to get free memory of GPU by sycl::aspect::ext_intel_free_memory.
Recommended to use when --split-mode = layer | +| UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS | 0 (default) or 1 | Support malloc device memory more than 4GB.| + ## Known Issues @@ -835,6 +840,14 @@ use 1 SYCL GPUs: [0] with Max compute units:512 | The default context is too big. It leads to excessive memory usage.|Set `-c 8192` or a smaller value.| | The model is too big and requires more memory than what is available.|Choose a smaller model or change to a smaller quantization, like Q5 -> Q4;
Alternatively, use more than one device to load model.| +- `ggml_backend_sycl_buffer_type_alloc_buffer: can't allocate 5000000000 Bytes of memory on device` + + You need to enable to support 4GB memory malloc by: + ``` + export UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS=1 + set UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS=1 + ``` + ### **GitHub contribution**: Please add the `SYCL :` prefix/tag in issues/PRs titles to help the SYCL contributors to check/address them without delay. diff --git a/examples/sycl/run-llama2.sh b/examples/sycl/run-llama2.sh index 37195008de..a018e45197 100755 --- a/examples/sycl/run-llama2.sh +++ b/examples/sycl/run-llama2.sh @@ -15,6 +15,9 @@ MODEL_FILE=models/llama-2-7b.Q4_0.gguf NGL=99 CONTEXT=4096 +#support malloc device memory more than 4GB. +export UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS=1 + if [ $# -gt 0 ]; then GGML_SYCL_DEVICE=$1 echo "use $GGML_SYCL_DEVICE as main GPU" diff --git a/examples/sycl/run-llama3.sh b/examples/sycl/run-llama3.sh index 8e21b017f4..4770255703 100755 --- a/examples/sycl/run-llama3.sh +++ b/examples/sycl/run-llama3.sh @@ -6,7 +6,7 @@ # If you want more control, DPC++ Allows selecting a specific device through the # following environment variable -#export ONEAPI_DEVICE_SELECTOR="level_zero:0" +export ONEAPI_DEVICE_SELECTOR="level_zero:0" source /opt/intel/oneapi/setvars.sh #export GGML_SYCL_DEBUG=1 @@ -18,11 +18,14 @@ MODEL_FILE=models/Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf NGL=99 # Layers offloaded to the GPU. If the device runs out of memory, reduce this value according to the model you are using. CONTEXT=4096 +#support malloc device memory more than 4GB. +export UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS=1 + if [ $# -gt 0 ]; then GGML_SYCL_DEVICE=$1 echo "Using $GGML_SYCL_DEVICE as the main GPU" - ZES_ENABLE_SYSMAN=1 ./build/bin/llama-cli -m ${MODEL_FILE} -p "${INPUT_PROMPT}" -n 400 -e -ngl ${NGL} -c ${CONTEXT} -mg $GGML_SYCL_DEVICE -sm none + ZES_ENABLE_SYSMAN=1 ./build/bin/llama-cli -m ${MODEL_FILE} -p "${INPUT_PROMPT}" -n 400 -e -ngl ${NGL} -s 0 -c ${CONTEXT} -mg $GGML_SYCL_DEVICE -sm none else #use multiple GPUs with same max compute units - ZES_ENABLE_SYSMAN=1 ./build/bin/llama-cli -m ${MODEL_FILE} -p "${INPUT_PROMPT}" -n 400 -e -ngl ${NGL} -c ${CONTEXT} + ZES_ENABLE_SYSMAN=1 ./build/bin/llama-cli -m ${MODEL_FILE} -p "${INPUT_PROMPT}" -n 400 -e -ngl ${NGL} -s 0 -c ${CONTEXT} fi diff --git a/examples/sycl/win-run-llama2.bat b/examples/sycl/win-run-llama2.bat index d7564f4161..b654f88f62 100644 --- a/examples/sycl/win-run-llama2.bat +++ b/examples/sycl/win-run-llama2.bat @@ -5,5 +5,7 @@ set INPUT2="Building a website can be done in 10 simple steps:\nStep 1:" @call "C:\Program Files (x86)\Intel\oneAPI\setvars.bat" intel64 --force +:: support malloc device memory more than 4GB. +set UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS=1 .\build\bin\llama-cli.exe -m models\llama-2-7b.Q4_0.gguf -p %INPUT2% -n 400 -e -ngl 99 -s 0 diff --git a/examples/sycl/win-run-llama3.bat b/examples/sycl/win-run-llama3.bat index 4b61aebee5..608b834f60 100644 --- a/examples/sycl/win-run-llama3.bat +++ b/examples/sycl/win-run-llama3.bat @@ -5,5 +5,7 @@ set INPUT2="Building a website can be done in 10 simple steps:\nStep 1:" @call "C:\Program Files (x86)\Intel\oneAPI\setvars.bat" intel64 --force +:: support malloc device memory more than 4GB. +set UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS=1 -.\build\bin\llama-cli.exe -m models\Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf -p %INPUT2% -n 400 -e -ngl 99 +.\build\bin\llama-cli.exe -m models\Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf -p %INPUT2% -n 400 -s 0 -e -ngl 99 diff --git a/ggml/src/ggml-sycl/CMakeLists.txt b/ggml/src/ggml-sycl/CMakeLists.txt index efd78b912c..88f29221bb 100644 --- a/ggml/src/ggml-sycl/CMakeLists.txt +++ b/ggml/src/ggml-sycl/CMakeLists.txt @@ -91,7 +91,10 @@ if (GGML_SYCL_F16) add_compile_definitions(GGML_SYCL_F16) endif() -if (GGML_SYCL_TARGET STREQUAL "NVIDIA") +if (GGML_SYCL_TARGET STREQUAL "INTEL") + add_compile_definitions(GGML_SYCL_WARP_SIZE=16) + target_link_options(ggml-sycl PRIVATE -Xs -ze-intel-greater-than-4GB-buffer-required) +elseif (GGML_SYCL_TARGET STREQUAL "NVIDIA") add_compile_definitions(GGML_SYCL_WARP_SIZE=32) elseif (GGML_SYCL_TARGET STREQUAL "AMD") # INFO: Allowed Sub_group_sizes are not consistent through all @@ -100,7 +103,8 @@ elseif (GGML_SYCL_TARGET STREQUAL "AMD") # Target archs tested working: gfx1030, gfx1031, (Only tested sub_group_size = 32) add_compile_definitions(GGML_SYCL_WARP_SIZE=32) else() - add_compile_definitions(GGML_SYCL_WARP_SIZE=16) + # default for other target + add_compile_definitions(GGML_SYCL_WARP_SIZE=32) endif() if (GGML_SYCL_GRAPH) diff --git a/ggml/src/ggml-sycl/cpy.cpp b/ggml/src/ggml-sycl/cpy.cpp index 1ec99b0a5d..96709554cf 100644 --- a/ggml/src/ggml-sycl/cpy.cpp +++ b/ggml/src/ggml-sycl/cpy.cpp @@ -515,9 +515,6 @@ void ggml_sycl_cpy(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, co const int64_t ne = ggml_nelements(src0); GGML_ASSERT(ne == ggml_nelements(src1)); - GGML_ASSERT(ggml_nbytes(src0) <= INT_MAX); - GGML_ASSERT(ggml_nbytes(src1) <= INT_MAX); - GGML_TENSOR_BINARY_OP_LOCALS01; SYCL_CHECK(ggml_sycl_set_device(ctx.device)); From 0874693b449a847de2d052f2afb5d0cbe9409f92 Mon Sep 17 00:00:00 2001 From: Igor Smirnov Date: Sat, 29 Nov 2025 21:06:32 +0500 Subject: [PATCH 11/13] common : fix json schema with '\' in literals (#17307) * Fix json schema with '\' in literals * Add "literal string with escapes" test --- common/json-schema-to-grammar.cpp | 4 +-- examples/json_schema_to_grammar.py | 4 +-- tests/test-json-schema-to-grammar.cpp | 26 +++++++++++++++++++ .../public_legacy/json-schema-to-grammar.mjs | 4 +-- 4 files changed, 32 insertions(+), 6 deletions(-) diff --git a/common/json-schema-to-grammar.cpp b/common/json-schema-to-grammar.cpp index e64dc059f3..c8421e1e82 100644 --- a/common/json-schema-to-grammar.cpp +++ b/common/json-schema-to-grammar.cpp @@ -268,10 +268,10 @@ static bool is_reserved_name(const std::string & name) { } std::regex INVALID_RULE_CHARS_RE("[^a-zA-Z0-9-]+"); -std::regex GRAMMAR_LITERAL_ESCAPE_RE("[\r\n\"]"); +std::regex GRAMMAR_LITERAL_ESCAPE_RE("[\r\n\"\\\\]"); std::regex GRAMMAR_RANGE_LITERAL_ESCAPE_RE("[\r\n\"\\]\\-\\\\]"); std::unordered_map GRAMMAR_LITERAL_ESCAPES = { - {'\r', "\\r"}, {'\n', "\\n"}, {'"', "\\\""}, {'-', "\\-"}, {']', "\\]"} + {'\r', "\\r"}, {'\n', "\\n"}, {'"', "\\\""}, {'-', "\\-"}, {']', "\\]"}, {'\\', "\\\\"} }; std::unordered_set NON_LITERAL_SET = {'|', '.', '(', ')', '[', ']', '{', '}', '*', '+', '?'}; diff --git a/examples/json_schema_to_grammar.py b/examples/json_schema_to_grammar.py index 26989157fe..886dd3d81e 100755 --- a/examples/json_schema_to_grammar.py +++ b/examples/json_schema_to_grammar.py @@ -231,9 +231,9 @@ DOT = '[^\\x0A\\x0D]' RESERVED_NAMES = set(["root", "dot", *PRIMITIVE_RULES.keys(), *STRING_FORMAT_RULES.keys()]) INVALID_RULE_CHARS_RE = re.compile(r'[^a-zA-Z0-9-]+') -GRAMMAR_LITERAL_ESCAPE_RE = re.compile(r'[\r\n"]') +GRAMMAR_LITERAL_ESCAPE_RE = re.compile(r'[\r\n"\\]') GRAMMAR_RANGE_LITERAL_ESCAPE_RE = re.compile(r'[\r\n"\]\-\\]') -GRAMMAR_LITERAL_ESCAPES = {'\r': '\\r', '\n': '\\n', '"': '\\"', '-': '\\-', ']': '\\]'} +GRAMMAR_LITERAL_ESCAPES = {'\r': '\\r', '\n': '\\n', '"': '\\"', '-': '\\-', ']': '\\]', '\\': '\\\\'} NON_LITERAL_SET = set('|.()[]{}*+?') ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS = set('^$.[]()|{}*+?') diff --git a/tests/test-json-schema-to-grammar.cpp b/tests/test-json-schema-to-grammar.cpp index 8a55bc54ae..1e568219d2 100755 --- a/tests/test-json-schema-to-grammar.cpp +++ b/tests/test-json-schema-to-grammar.cpp @@ -1339,6 +1339,32 @@ static void test_all(const std::string & lang, std::function Date: Sun, 30 Nov 2025 01:43:29 +0800 Subject: [PATCH 12/13] server: explicitly set the function name in lambda (#17538) As [1] explained, the real debug message will be like: "res operator(): operator() : queue result stop" Set the name explicitly, the message is easy for debugging: "res operator(): recv : queue result stop" The left "operator()" is generated by 'RES_DBG() ... __func__' [1]: https://clang.llvm.org/extra/clang-tidy/checks/bugprone/lambda-function-name.html Signed-off-by: Haiyue Wang --- tools/server/server-queue.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tools/server/server-queue.cpp b/tools/server/server-queue.cpp index 5a74fd76ac..65c8a0a9ae 100644 --- a/tools/server/server-queue.cpp +++ b/tools/server/server-queue.cpp @@ -199,7 +199,7 @@ server_task_result_ptr server_response::recv(const std::unordered_set & id_ std::unique_lock lock(mutex_results); condition_results.wait(lock, [&]{ if (!running) { - RES_DBG("%s : queue result stop\n", __func__); + RES_DBG("%s : queue result stop\n", "recv"); std::terminate(); // we cannot return here since the caller is HTTP code } return !queue_results.empty(); From ab49f094d29057b2c4d2b0f1e822c9140df1b0e1 Mon Sep 17 00:00:00 2001 From: Xuan-Son Nguyen Date: Sat, 29 Nov 2025 22:04:44 +0100 Subject: [PATCH 13/13] server: move server-context to its own cpp|h (#17595) * git mv * add server-context.h * add server-context.h * clean up headers * cont : cleanup * also expose server_response_reader (to be used by CLI) * fix windows build * decouple server_routes and server_http --------- Co-authored-by: Georgi Gerganov --- tools/server/CMakeLists.txt | 2 + tools/server/server-context.cpp | 3619 ++++++++++++++++++++++++++++++ tools/server/server-context.h | 83 + tools/server/server-queue.cpp | 83 + tools/server/server-queue.h | 36 + tools/server/server.cpp | 3670 +------------------------------ 6 files changed, 3831 insertions(+), 3662 deletions(-) create mode 100644 tools/server/server-context.cpp create mode 100644 tools/server/server-context.h diff --git a/tools/server/CMakeLists.txt b/tools/server/CMakeLists.txt index 7fbca32016..d8623621f3 100644 --- a/tools/server/CMakeLists.txt +++ b/tools/server/CMakeLists.txt @@ -21,6 +21,8 @@ set(TARGET_SRCS server-queue.h server-common.cpp server-common.h + server-context.cpp + server-context.h ) set(PUBLIC_ASSETS index.html.gz diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp new file mode 100644 index 0000000000..2bf3924df9 --- /dev/null +++ b/tools/server/server-context.cpp @@ -0,0 +1,3619 @@ +#include "server-context.h" +#include "server-common.h" +#include "server-http.h" +#include "server-task.h" +#include "server-queue.h" + +#include "arg.h" +#include "common.h" +#include "llama.h" +#include "log.h" +#include "sampling.h" +#include "speculative.h" +#include "mtmd.h" +#include "mtmd-helper.h" + +#include +#include +#include +#include + +// fix problem with std::min and std::max +#if defined(_WIN32) +#define WIN32_LEAN_AND_MEAN +#ifndef NOMINMAX +# define NOMINMAX +#endif +#include +#endif + +using json = nlohmann::ordered_json; + +constexpr int HTTP_POLLING_SECONDS = 1; + +// state diagram: https://github.com/ggml-org/llama.cpp/pull/9283 +enum slot_state { + SLOT_STATE_IDLE, + SLOT_STATE_STARTED, // TODO: this state is only used for setting up the initial prompt processing; maybe merge it with launch_slot_with_task in the future + SLOT_STATE_PROCESSING_PROMPT, + SLOT_STATE_DONE_PROMPT, + SLOT_STATE_GENERATING, +}; + +enum server_state { + SERVER_STATE_LOADING_MODEL, // Server is starting up, model not fully loaded yet + SERVER_STATE_READY, // Server is ready and model is loaded +}; + +static bool server_task_type_need_embd(server_task_type task_type) { + switch (task_type) { + case SERVER_TASK_TYPE_EMBEDDING: + case SERVER_TASK_TYPE_RERANK: + return true; + default: + return false; + } +} + +static bool server_task_type_need_logits(server_task_type task_type) { + switch (task_type) { + case SERVER_TASK_TYPE_COMPLETION: + case SERVER_TASK_TYPE_INFILL: + return true; + default: + return false; + } +} + +struct server_slot { + int id; + + llama_batch batch_spec = {}; + + // TODO: change to unique_ptrs for consistency: + llama_context * ctx = nullptr; + llama_context * ctx_dft = nullptr; + + // multimodal + mtmd_context * mctx = nullptr; + + common_speculative * spec = nullptr; + + std::unique_ptr task; + std::unique_ptr task_prev; // used for debugging + + // used to determine the slot that has been used the longest + int64_t t_last_used = -1; + + // generation props + int32_t n_ctx = 0; // context size per slot + int32_t n_keep = 0; + int32_t n_decoded = 0; + int32_t n_remaining = -1; + int32_t i_batch = -1; + + int32_t n_prompt_tokens_cache = 0; + int32_t n_prompt_tokens_processed = 0; + + size_t last_nl_pos = 0; + + std::string generated_text; + llama_tokens generated_tokens; + + common_chat_msg chat_msg; + + std::vector generated_token_probs; + + bool has_next_token = true; + bool has_new_line = false; + bool truncated = false; + + stop_type stop; + + std::string stopping_word; + + // state + slot_state state = SLOT_STATE_IDLE; + + server_prompt prompt; + + void prompt_save(server_prompt_cache & prompt_cache) const { + GGML_ASSERT(prompt.data.size() == 0); + + const size_t cur_size = llama_state_seq_get_size_ext(ctx, id, 0); + + SRV_WRN(" - saving prompt with length %d, total state size = %.3f MiB\n", + (int) prompt.tokens.size(), cur_size / (1024.0 * 1024.0)); + + auto * cur = prompt_cache.alloc(prompt, cur_size); + if (cur == nullptr) { + return; + } + + llama_state_seq_get_data_ext(ctx, cur->data.data(), cur_size, id, 0); + } + + bool prompt_load(server_prompt_cache & prompt_cache, const server_tokens & tokens) { + bool res = prompt_cache.load(prompt, tokens, ctx, id); + if (!res) { + SLT_WRN(*this, "%s", "failed to load prompt from cache\n"); + } + + return res; + } + + std::vector lora; + int32_t alora_invocation_start = -1; + + // sampling + json json_schema; + + struct common_sampler * smpl = nullptr; + + llama_token sampled; + + common_chat_format chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY; + std::vector generated_tool_call_ids; + + // stats + size_t n_sent_text = 0; // number of sent text character + + int64_t t_start_process_prompt; + int64_t t_start_generation; + + double t_prompt_processing; // ms + double t_token_generation; // ms + + std::function callback_on_release; + + // Speculative decoding stats + int32_t n_draft_total = 0; // Total draft tokens generated + int32_t n_draft_accepted = 0; // Draft tokens actually accepted + + void reset() { + SLT_DBG(*this, "%s", "\n"); + + n_prompt_tokens_cache = 0; + + last_nl_pos = 0; + generated_text = ""; + has_new_line = false; + truncated = false; + stop = STOP_TYPE_NONE; + stopping_word = ""; + n_sent_text = 0; + chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY; + + generated_tokens.clear(); + generated_token_probs.clear(); + chat_msg = {}; + json_schema = json(); + generated_tool_call_ids.clear(); + + // clear speculative decoding stats + n_draft_total = 0; + n_draft_accepted = 0; + + task.reset(); + task_prev.reset(); + + // clear alora start + alora_invocation_start = -1; + } + + bool need_embd() const { + GGML_ASSERT(task); + + return server_task_type_need_embd(task->type); + } + + bool need_logits() const { + GGML_ASSERT(task); + + return server_task_type_need_logits(task->type); + } + + // if the context does not have a memory module then all embeddings have to be computed within a single ubatch + // also we cannot split if the pooling would require any past tokens + bool can_split() const { + return + !need_embd() || + (llama_get_memory(ctx) && llama_pooling_type(ctx) == LLAMA_POOLING_TYPE_LAST); + } + + bool can_batch_with(server_slot & other_slot) const { + GGML_ASSERT(task); + + return task->type == other_slot.task->type && are_lora_equal(lora, other_slot.lora); + } + + bool has_budget(const common_params & global_params) { + GGML_ASSERT(task); + + if (task->params.n_predict == -1 && global_params.n_predict == -1) { + return true; // limitless + } + + n_remaining = -1; + + if (task->params.n_predict != -1) { + n_remaining = task->params.n_predict - n_decoded; + } else if (global_params.n_predict != -1) { + n_remaining = global_params.n_predict - n_decoded; + } + + return n_remaining > 0; // no budget + } + + bool is_processing() const { + return state != SLOT_STATE_IDLE; + } + + bool can_speculate() const { + return ctx_dft; + } + + void add_token(const completion_token_output & token) { + if (!is_processing()) { + SLT_WRN(*this, "%s", "slot is not processing\n"); + return; + } + generated_token_probs.push_back(token); + } + + void release() { + if (is_processing()) { + GGML_ASSERT(task); + + SLT_INF(*this, "stop processing: n_tokens = %d, truncated = %d\n", prompt.n_tokens(), truncated); + + t_last_used = ggml_time_us(); + t_token_generation = (ggml_time_us() - t_start_generation) / 1e3; + state = SLOT_STATE_IDLE; + + task_prev = std::move(task); + task.reset(); + + callback_on_release(id); + } + } + + result_timings get_timings() const { + result_timings timings; + timings.cache_n = n_prompt_tokens_cache; + + timings.prompt_n = n_prompt_tokens_processed; + timings.prompt_ms = t_prompt_processing; + timings.prompt_per_token_ms = t_prompt_processing / n_prompt_tokens_processed; + timings.prompt_per_second = 1e3 / t_prompt_processing * n_prompt_tokens_processed; + + timings.predicted_n = n_decoded; + timings.predicted_ms = t_token_generation; + timings.predicted_per_token_ms = t_token_generation / n_decoded; + timings.predicted_per_second = 1e3 / t_token_generation * n_decoded; + + // Add speculative metrics + if (n_draft_total > 0) { + timings.draft_n = n_draft_total; + timings.draft_n_accepted = n_draft_accepted; + } + + return timings; + } + + const common_chat_msg & update_chat_msg(std::vector & diffs) { + GGML_ASSERT(task); + + auto previous_msg = chat_msg; + SRV_DBG("Parsing chat message: %s\n", generated_text.c_str()); + auto new_msg = common_chat_parse( + generated_text, + /* is_partial= */ stop != STOP_TYPE_EOS, + task->params.oaicompat_chat_syntax); + if (!new_msg.empty()) { + new_msg.set_tool_call_ids(generated_tool_call_ids, gen_tool_call_id); + chat_msg = new_msg; + diffs = common_chat_msg_diff::compute_diffs(previous_msg, new_msg.empty() ? previous_msg : new_msg); + } + return chat_msg; + } + + size_t find_stopping_strings(const std::string & text, const size_t last_token_size, bool is_full_stop) { + GGML_ASSERT(task); + + size_t stop_pos = std::string::npos; + + for (const std::string & word : task->params.antiprompt) { + size_t pos; + + if (is_full_stop) { + const size_t tmp = word.size() + last_token_size; + const size_t from_pos = text.size() > tmp ? text.size() - tmp : 0; + + pos = text.find(word, from_pos); + } else { + // otherwise, partial stop + pos = string_find_partial_stop(text, word); + } + + if (pos != std::string::npos && (stop_pos == std::string::npos || pos < stop_pos)) { + if (is_full_stop) { + stop = STOP_TYPE_WORD; + stopping_word = word; + has_next_token = false; + } + stop_pos = pos; + } + } + + return stop_pos; + } + + void print_timings() const { + const double t_prompt = t_prompt_processing / n_prompt_tokens_processed; + const double n_prompt_second = 1e3 / t_prompt_processing * n_prompt_tokens_processed; + + const double t_gen = t_token_generation / n_decoded; + const double n_gen_second = 1e3 / t_token_generation * n_decoded; + + SLT_INF(*this, + "\n" + "prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n" + " eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n" + " total time = %10.2f ms / %5d tokens\n", + t_prompt_processing, n_prompt_tokens_processed, t_prompt, n_prompt_second, + t_token_generation, n_decoded, t_gen, n_gen_second, + t_prompt_processing + t_token_generation, n_prompt_tokens_processed + n_decoded); + + if (n_draft_total > 0) { + const float draft_ratio = (float) n_draft_accepted / n_draft_total; + SLT_INF(*this, + "\n" + "draft acceptance rate = %0.5f (%5d accepted / %5d generated)\n", + draft_ratio, n_draft_accepted, n_draft_total + ); + } + } + + json to_json(bool only_metrics = false) const { + json res; + + res = { + {"id", id}, + {"n_ctx", n_ctx}, + {"speculative", can_speculate()}, + {"is_processing", is_processing()}, + }; + + const auto & ptask = task ? task : task_prev; + + if (ptask) { + res["id_task"] = ptask->id; + res["params"] = ptask->params.to_json(only_metrics); + res["next_token"] = { + { + {"has_next_token", has_next_token}, + {"has_new_line", has_new_line}, + {"n_remain", n_remaining}, + {"n_decoded", n_decoded}, + } + }; + + if (!only_metrics) { + res["prompt"] = ptask->tokens.detokenize(ctx, true); + res["generated"] = generated_text; + } + } + + return res; + } +}; + + + +// +// server_metrics +// + +struct server_metrics { + int64_t t_start = 0; + + uint64_t n_prompt_tokens_processed_total = 0; + uint64_t t_prompt_processing_total = 0; + uint64_t n_tokens_predicted_total = 0; + uint64_t t_tokens_generation_total = 0; + + uint64_t n_tokens_max = 0; + + uint64_t n_prompt_tokens_processed = 0; + uint64_t t_prompt_processing = 0; + + uint64_t n_tokens_predicted = 0; + uint64_t t_tokens_generation = 0; + + uint64_t n_decode_total = 0; + uint64_t n_busy_slots_total = 0; + + void init() { + t_start = ggml_time_us(); + } + + void on_prompt_eval(const server_slot & slot) { + n_prompt_tokens_processed_total += slot.n_prompt_tokens_processed; + n_prompt_tokens_processed += slot.n_prompt_tokens_processed; + t_prompt_processing += slot.t_prompt_processing; + t_prompt_processing_total += slot.t_prompt_processing; + + n_tokens_max = std::max(n_tokens_max, (uint64_t) slot.prompt.n_tokens()); + } + + void on_prediction(const server_slot & slot) { + n_tokens_predicted_total += slot.n_decoded; + n_tokens_predicted += slot.n_decoded; + t_tokens_generation += slot.t_token_generation; + t_tokens_generation_total += slot.t_token_generation; + } + + void on_decoded(const std::vector & slots) { + n_decode_total++; + for (const auto & slot : slots) { + if (slot.is_processing()) { + n_busy_slots_total++; + } + n_tokens_max = std::max(n_tokens_max, (uint64_t) slot.prompt.n_tokens()); + } + } + + void reset_bucket() { + n_prompt_tokens_processed = 0; + t_prompt_processing = 0; + n_tokens_predicted = 0; + t_tokens_generation = 0; + } +}; + + +// +// server_context_impl (private implementation) +// + +struct server_context_impl { + common_params params_base; + + // note: keep these alive - they determine the lifetime of the model, context, etc. + common_init_result llama_init; + common_init_result llama_init_dft; + + llama_model * model = nullptr; + llama_context * ctx = nullptr; + + // multimodal + mtmd_context * mctx = nullptr; + + const llama_vocab * vocab = nullptr; + bool vocab_dft_compatible = true; + + llama_model * model_dft = nullptr; + + llama_context_params cparams_dft; + + llama_batch batch {}; + + bool add_bos_token = true; + + int32_t n_ctx; // total context for all clients / slots + + // slots / clients + std::vector slots; + + int slots_debug = 0; + + server_queue queue_tasks; + server_response queue_results; + + std::unique_ptr prompt_cache; + + server_metrics metrics; + + // Necessary similarity of prompt for slot selection + float slot_prompt_similarity = 0.0f; + + common_chat_templates_ptr chat_templates; + oaicompat_parser_options oai_parser_opt; + + ~server_context_impl() { + mtmd_free(mctx); + + // Clear any sampling context + for (server_slot & slot : slots) { + common_sampler_free(slot.smpl); + slot.smpl = nullptr; + + llama_free(slot.ctx_dft); + slot.ctx_dft = nullptr; + + common_speculative_free(slot.spec); + slot.spec = nullptr; + + llama_batch_free(slot.batch_spec); + } + + llama_batch_free(batch); + } + + // load the model and initialize llama_context + bool load_model(const common_params & params) { + SRV_INF("loading model '%s'\n", params.model.path.c_str()); + + params_base = params; + + llama_init = common_init_from_params(params_base); + + model = llama_init.model.get(); + ctx = llama_init.context.get(); + + if (model == nullptr) { + SRV_ERR("failed to load model, '%s'\n", params_base.model.path.c_str()); + return false; + } + + vocab = llama_model_get_vocab(model); + + n_ctx = llama_n_ctx(ctx); + + add_bos_token = llama_vocab_get_add_bos(vocab); + + if (params_base.has_speculative()) { + SRV_INF("loading draft model '%s'\n", params_base.speculative.model.path.c_str()); + + auto params_dft = params_base; + + params_dft.devices = params_base.speculative.devices; + params_dft.model = params_base.speculative.model; + params_dft.n_ctx = params_base.speculative.n_ctx == 0 ? llama_n_ctx_seq(ctx) : params_base.speculative.n_ctx; + params_dft.n_gpu_layers = params_base.speculative.n_gpu_layers; + params_dft.n_parallel = 1; + params_dft.cache_type_k = params_base.speculative.cache_type_k; + params_dft.cache_type_v = params_base.speculative.cache_type_v; + + params_dft.cpuparams.n_threads = params_base.speculative.cpuparams.n_threads; + params_dft.cpuparams_batch.n_threads = params_base.speculative.cpuparams_batch.n_threads; + params_dft.tensor_buft_overrides = params_base.speculative.tensor_buft_overrides; + + llama_init_dft = common_init_from_params(params_dft); + + model_dft = llama_init_dft.model.get(); + + if (model_dft == nullptr) { + SRV_ERR("failed to load draft model, '%s'\n", params_base.speculative.model.path.c_str()); + return false; + } + + vocab_dft_compatible = common_speculative_are_compatible(ctx, llama_init_dft.context.get()); + if (!vocab_dft_compatible) { + SRV_INF("the draft model '%s' is not compatible with the target model '%s'. tokens will be translated between the draft and target models.\n", params_base.speculative.model.path.c_str(), params_base.model.path.c_str()); + } + + const int n_ctx_dft = llama_n_ctx(llama_init_dft.context.get()); + + cparams_dft = common_context_params_to_llama(params_dft); + cparams_dft.n_batch = n_ctx_dft; + + // the context is not needed - we will create one for each slot + llama_init_dft.context.reset(); + } + + chat_templates = common_chat_templates_init(model, params_base.chat_template); + try { + common_chat_format_example(chat_templates.get(), params.use_jinja, params.default_template_kwargs); + } catch (const std::exception & e) { + SRV_WRN("%s: Chat template parsing error: %s\n", __func__, e.what()); + SRV_WRN("%s: The chat template that comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses\n", __func__); + chat_templates = common_chat_templates_init(model, "chatml"); + } + + std::string & mmproj_path = params_base.mmproj.path; + if (!mmproj_path.empty()) { + mtmd_helper_log_set(common_log_default_callback, nullptr); + + mtmd_context_params mparams = mtmd_context_params_default(); + mparams.use_gpu = params_base.mmproj_use_gpu; + mparams.print_timings = false; + mparams.n_threads = params_base.cpuparams.n_threads; + mparams.flash_attn_type = params_base.flash_attn_type; + mparams.image_min_tokens = params_base.image_min_tokens; + mparams.image_max_tokens = params_base.image_max_tokens; + mctx = mtmd_init_from_file(mmproj_path.c_str(), model, mparams); + if (mctx == nullptr) { + SRV_ERR("failed to load multimodal model, '%s'\n", mmproj_path.c_str()); + return false; + } + SRV_INF("loaded multimodal model, '%s'\n", mmproj_path.c_str()); + + if (params_base.ctx_shift) { + params_base.ctx_shift = false; + SRV_WRN("%s\n", "ctx_shift is not supported by multimodal, it will be disabled"); + } + + if (params_base.n_cache_reuse) { + params_base.n_cache_reuse = 0; + SRV_WRN("%s\n", "cache_reuse is not supported by multimodal, it will be disabled"); + } + + if (params_base.has_speculative()) { + SRV_ERR("%s\n", "err: speculative decode is not supported by multimodal"); + return false; + } + } + + if (!llama_memory_can_shift(llama_get_memory(ctx))) { + if (params_base.ctx_shift) { + params_base.ctx_shift = false; + SRV_WRN("%s\n", "ctx_shift is not supported by this context, it will be disabled"); + } + + if (params_base.n_cache_reuse) { + params_base.n_cache_reuse = 0; + SRV_WRN("%s\n", "cache_reuse is not supported by this context, it will be disabled"); + } + } + + return true; + } + + // initialize slots and server-related data + void init() { + // wiring up server queues + queue_tasks.on_new_task([this](server_task && task) { + process_single_task(std::move(task)); + }); + queue_tasks.on_update_slots([this]() { + update_slots(); + }); + + // Necessary similarity of prompt for slot selection + slot_prompt_similarity = params_base.slot_prompt_similarity; + + // setup slots + SRV_INF("initializing slots, n_slots = %d\n", params_base.n_parallel); + + const int n_ctx_train = llama_model_n_ctx_train(model); + + int n_ctx_slot = llama_n_ctx_seq(ctx); + if (n_ctx_slot > n_ctx_train) { + SRV_WRN("the slot context (%d) exceeds the training context of the model (%d) - capping\n", n_ctx_slot, n_ctx_train); + n_ctx_slot = n_ctx_train; + } + + for (int i = 0; i < params_base.n_parallel; i++) { + server_slot slot; + + slot.id = i; + slot.ctx = ctx; + slot.n_ctx = n_ctx_slot; + slot.mctx = mctx; + slot.prompt.tokens.has_mtmd = mctx != nullptr; + + if (model_dft) { + slot.batch_spec = llama_batch_init(params_base.speculative.n_max + 1, 0, 1); + + // TODO: rework speculative decoding [TAG_SERVER_SPEC_REWORK] + slot.ctx_dft = llama_init_from_model(model_dft, cparams_dft); + if (slot.ctx_dft == nullptr) { + SRV_ERR("%s", "failed to create draft context\n"); + return; + } + + slot.spec = common_speculative_init(slot.ctx, slot.ctx_dft); + if (slot.spec == nullptr) { + SRV_ERR("%s", "failed to create speculator\n"); + return; + } + for (auto & pair : params_base.speculative.replacements) { + common_speculative_add_replacement_tgt_dft(slot.spec, pair.first.c_str(), pair.second.c_str()); + } + } + + SLT_INF(slot, "new slot, n_ctx = %d\n", slot.n_ctx); + + slot.callback_on_release = [this](int) { + queue_tasks.pop_deferred_task(); + }; + + slot.reset(); + + slots.push_back(std::move(slot)); + } + + { + const char * LLAMA_SERVER_SLOTS_DEBUG = getenv("LLAMA_SERVER_SLOTS_DEBUG"); + slots_debug = LLAMA_SERVER_SLOTS_DEBUG ? atoi(LLAMA_SERVER_SLOTS_DEBUG) : 0; + + if (slots_debug) { + SRV_WRN("slots debug = %d\n", slots_debug); + } + } + + // the update_slots() logic will always submit a maximum of n_batch or n_parallel tokens + // note that n_batch can be > n_ctx (e.g. for non-causal attention models such as BERT where the KV cache is not used) + { + const int32_t n_batch = llama_n_batch(ctx); + batch = llama_batch_init(std::max(n_batch, params_base.n_parallel), 0, 1); + } + + metrics.init(); + + if (params_base.cache_ram_mib != 0) { + if (params_base.cache_ram_mib < 0) { + SRV_WRN("prompt cache is enabled, size limit: %s\n", "no limit"); + } else { + SRV_WRN("prompt cache is enabled, size limit: %d MiB\n", params_base.cache_ram_mib); + } + SRV_WRN("%s", "use `--cache-ram 0` to disable the prompt cache\n"); + + prompt_cache = std::make_unique(params_base.cache_ram_mib, n_ctx); + } else { + SRV_WRN("%s", "prompt cache is disabled - use `--cache-ram N` to enable it\n"); + } + SRV_WRN("%s", "for more info see https://github.com/ggml-org/llama.cpp/pull/16391\n"); + + // thinking is enabled if: + // 1. It's not explicitly disabled (reasoning_budget == 0) + // 2. The chat template supports it + const bool enable_thinking = params_base.use_jinja && params_base.reasoning_budget != 0 && common_chat_templates_support_enable_thinking(chat_templates.get()); + SRV_INF("thinking = %d\n", enable_thinking); + + oai_parser_opt = { + /* use_jinja */ params_base.use_jinja, + /* prefill_assistant */ params_base.prefill_assistant, + /* reasoning_format */ params_base.reasoning_format, + /* chat_template_kwargs */ params_base.default_template_kwargs, + /* common_chat_templates */ chat_templates.get(), + /* allow_image */ mctx ? mtmd_support_vision(mctx) : false, + /* allow_audio */ mctx ? mtmd_support_audio (mctx) : false, + /* enable_thinking */ enable_thinking, + }; + + // print sample chat example to make it clear which template is used + LOG_INF("%s: chat template, chat_template: %s, example_format: '%s'\n", __func__, + common_chat_templates_source(chat_templates.get()), + common_chat_format_example(chat_templates.get(), params_base.use_jinja, params_base.default_template_kwargs).c_str()); + } + + server_slot * get_slot_by_id(int id) { + for (server_slot & slot : slots) { + if (slot.id == id) { + return &slot; + } + } + + return nullptr; + } + + server_slot * get_available_slot(const server_task & task) { + server_slot * ret = nullptr; + + bool update_cache = false; + + // find the slot that has at least n% prompt similarity + if (ret == nullptr && slot_prompt_similarity != 0.0f) { + float sim_best = 0; + + for (server_slot & slot : slots) { + // skip the slot if it is not available + if (slot.is_processing()) { + continue; + } + + const auto & tokens = slot.prompt.tokens; + + // skip the slot if it does not contains cached tokens + if (tokens.empty()) { + continue; + } + + // fraction of the Longest Common Prefix length with respect to the input prompt length + const float sim_cur = float(tokens.get_common_prefix(task.tokens)) / task.tokens.size(); + + // select the current slot if the criteria match + if (sim_cur > sim_best && sim_cur > slot_prompt_similarity) { + sim_best = sim_cur; + + ret = &slot; + } + } + + if (ret != nullptr) { + const float f_keep = (sim_best*task.tokens.size()) / ret->prompt.tokens.size(); + + SLT_INF(*ret, "selected slot by LCP similarity, sim_best = %.3f (> %.3f thold), f_keep = %.3f\n", + sim_best, slot_prompt_similarity, f_keep); + + // if we are about to lose a large portion of the existing context - save it in the prompt cache + if (f_keep < 0.5f) { + update_cache = true; + } + } + } + + // find the slot that has been least recently used + if (ret == nullptr) { + int64_t t_last = -1; + + for (server_slot & slot : slots) { + // skip the slot if it is not available + if (slot.is_processing()) { + continue; + } + + // select the current slot if the criteria match + if (!ret || slot.t_last_used <= t_last) { + t_last = slot.t_last_used; + ret = &slot; + } + } + + if (ret != nullptr) { + SLT_INF(*ret, "selected slot by LRU, t_last = %" PRId64 "\n", t_last); + + update_cache = true; + } + } + + if (ret) { + const auto & tokens = ret->prompt.tokens; + + update_cache = update_cache && prompt_cache; + + // cache prompts only for completion tasks + update_cache = update_cache && task.type == SERVER_TASK_TYPE_COMPLETION; + + // don't update the cache if the slot's context is empty + update_cache = update_cache && tokens.size() > 0; + + // TODO: mtmd does not support prompt cache + update_cache = update_cache && (ret->mctx == nullptr); + + if (update_cache) { + SRV_WRN("%s", "updating prompt cache\n"); + + const int64_t t_start = ggml_time_us(); + + ret->prompt_save(*prompt_cache); + + if (!ret->prompt_load(*prompt_cache, task.tokens)) { + clear_slot(*ret); + } + + prompt_cache->update(); + + SRV_WRN("prompt cache update took %.2f ms\n", (ggml_time_us() - t_start) / 1000.0); + } + } + + return ret; + } + + void clear_slot(server_slot & slot) const { + GGML_ASSERT(!slot.is_processing()); + + SLT_WRN(slot, "clearing slot with %zu tokens\n", slot.prompt.tokens.size()); + + llama_memory_seq_rm(llama_get_memory(ctx), slot.id, -1, -1); + slot.prompt.tokens.clear(); + } + + // return true if at least one slot has been cleared + // TODO: improve logic + // - smarter decision which slot to clear (LRU or longest prompt?) + // - move slot to level 2 cache instead of removing? + // - instead of purging, try to store and resume later? + bool try_clear_idle_slots() { + bool res = false; + + if (!params_base.kv_unified) { + return res; + } + + for (auto & slot : slots) { + if (slot.is_processing()) { + continue; + } + + if (slot.prompt.n_tokens() > 0) { + SRV_WRN("purging slot %d with %zu tokens\n", slot.id, slot.prompt.tokens.size()); + + clear_slot(slot); + + res = true; + + // clear slots one by one + break; + } + } + + return res; + } + + bool launch_slot_with_task(server_slot & slot, server_task && task) { + slot.reset(); + + if (!are_lora_equal(task.params.lora, slot.lora)) { + // if lora has changed, check to see if the cache should be cleared + if (lora_should_clear_cache(slot.lora, task.params.lora)) { + SLT_INF(slot, "clearing cache for lora change. %zu loras -> %zu loras\n", slot.lora.size(), task.params.lora.size()); + slot.prompt.tokens.clear(); + } else { + SLT_INF(slot, "keeping cache for alora. %zu target loras\n", task.params.lora.size()); + } + slot.lora = task.params.lora; + } + + // if using alora, make sure it's only a single one requested and active + size_t alora_invocation_start = task.tokens.size(); + if (lora_all_alora(slot.lora)) { + const auto & enabled_ids = lora_get_enabled_ids(slot.lora); + // TODO: This will error out if a user requests two aloras, but only + // provides the activation string for one. We could, instead search + // for all requested alora activation strings and then either keep + // only the last one, or reject if multiple are found. + if (enabled_ids.size() != 1) { + send_error(task, "Cannot run multiple aLoRAs in a single request", ERROR_TYPE_INVALID_REQUEST); + return false; + } + const auto & lora = slot.lora[enabled_ids[0]].ptr; + + // get the pointer and count for the invocation tokens + const uint64_t n_invocation_tokens = llama_adapter_get_alora_n_invocation_tokens(lora); + const llama_token * invocation_tokens = llama_adapter_get_alora_invocation_tokens (lora); + + // scan backwards through the prompt tokens to find the last + // occurrence of the invocation sequence + int match_idx = static_cast(n_invocation_tokens) - 1; + for (int i = task.tokens.size() - 1; i >= 0; --i) { + // the token in this position matches the next token to find in + // the invocation sequence + if (task.tokens[i] == invocation_tokens[match_idx]) { + // if it's a full match, we've found the start + if (match_idx == 0) { + alora_invocation_start = i; + break; + } + // otherwise, check the next token in the sequence + --match_idx; + } else { + // no match in this position, so start looking over again + match_idx = static_cast(n_invocation_tokens) - 1; + } + } + + // if the activation string is not found, disable the alora + if (alora_invocation_start == task.tokens.size()) { + SLT_DBG(slot, "alora %zu requested, but not found. deactivating\n", enabled_ids[0]); + slot.lora[enabled_ids[0]].scale = 0.0f; + } else { + SLT_DBG(slot, "alora %zu activated starting at %zu\n", enabled_ids[0], alora_invocation_start); + slot.alora_invocation_start = alora_invocation_start; + } + } + + if (!task.tokens.validate(ctx)) { + send_error(task, "Prompt contains invalid tokens", ERROR_TYPE_INVALID_REQUEST); + return false; + } + + SLT_DBG(slot, "launching slot : %s\n", safe_json_to_str(slot.to_json()).c_str()); + + // initialize samplers + { + if (slot.smpl != nullptr) { + common_sampler_free(slot.smpl); + } + + slot.smpl = common_sampler_init(model, task.params.sampling); + if (slot.smpl == nullptr) { + // for now, the only error that may happen here is invalid grammar + send_error(task, "Failed to parse grammar", ERROR_TYPE_INVALID_REQUEST); + return false; + } + + SLT_INF(slot, "sampler chain: %s\n", common_sampler_print(slot.smpl).c_str()); + } + + // initialize draft batch + // TODO: rework speculative decoding [TAG_SERVER_SPEC_REWORK] + if (slot.ctx_dft) { + llama_batch_free(slot.batch_spec); + + slot.batch_spec = llama_batch_init(task.params.speculative.n_max + 1, 0, 1); + } + + slot.task = std::make_unique(std::move(task)); + + slot.state = SLOT_STATE_STARTED; + + SLT_INF(slot, "%s", "processing task\n"); + + return true; + } + + bool process_token(completion_token_output & result, server_slot & slot) { + // remember which tokens were sampled - used for repetition penalties during sampling + const std::string token_str = result.text_to_send; + slot.sampled = result.tok; + + slot.generated_text += token_str; + if (slot.task->params.return_tokens) { + slot.generated_tokens.push_back(result.tok); + } + slot.has_next_token = true; + + // check if there is incomplete UTF-8 character at the end + bool incomplete = validate_utf8(slot.generated_text) < slot.generated_text.size(); + + // search stop word and delete it + if (!incomplete) { + size_t pos = std::min(slot.n_sent_text, slot.generated_text.size()); + + const std::string str_test = slot.generated_text.substr(pos); + bool send_text = true; + + size_t stop_pos = slot.find_stopping_strings(str_test, token_str.size(), true); + if (stop_pos != std::string::npos) { + slot.generated_text.erase( + slot.generated_text.begin() + pos + stop_pos, + slot.generated_text.end()); + pos = std::min(slot.n_sent_text, slot.generated_text.size()); + } else if (slot.has_next_token && !llama_vocab_is_eog(vocab, result.tok) ) { + stop_pos = slot.find_stopping_strings(str_test, token_str.size(), false); + send_text = stop_pos == std::string::npos; + } + + // check if there is any token to predict + if (send_text) { + // no send the stop word in the response + result.text_to_send = slot.generated_text.substr(pos, std::string::npos); + slot.n_sent_text += result.text_to_send.size(); + // add the token to slot queue and cache + } else { + result.text_to_send = ""; + } + + slot.add_token(result); + if (slot.task->params.stream) { + send_partial_response(slot, result, false); + } + } + + if (incomplete) { + slot.has_next_token = true; + } + + // if context shifting is disabled, make sure that we don't run out of context + if (!params_base.ctx_shift && slot.prompt.n_tokens() + 1 >= slot.n_ctx) { + slot.truncated = true; + slot.stop = STOP_TYPE_LIMIT; + slot.has_next_token = false; + + SLT_DBG(slot, "stopped due to running out of context capacity, prompt.n_tokens() = %d, task.n_tokens = %d, n_decoded = %d, n_ctx = %d\n", + slot.prompt.n_tokens(), slot.task->n_tokens(), slot.n_decoded, slot.n_ctx); + } + + // check the limits + if (slot.n_decoded > 0 && slot.has_next_token && !slot.has_budget(params_base)) { + slot.stop = STOP_TYPE_LIMIT; + slot.has_next_token = false; + + SLT_DBG(slot, "stopped by limit, n_decoded = %d, n_predict = %d\n", slot.n_decoded, slot.task->params.n_predict); + } + + if (slot.has_new_line) { + // require that each new line has a whitespace prefix (i.e. indentation) of at least slot.params.n_indent + if (slot.task->params.n_indent > 0) { + // check the current indentation + // TODO: improve by not doing it more than once for each new line + if (slot.last_nl_pos > 0) { + size_t pos = slot.last_nl_pos; + + int n_indent = 0; + while (pos < slot.generated_text.size() && (slot.generated_text[pos] == ' ' || slot.generated_text[pos] == '\t')) { + n_indent++; + pos++; + } + + if (pos < slot.generated_text.size() && n_indent < slot.task->params.n_indent) { + slot.stop = STOP_TYPE_LIMIT; + slot.has_next_token = false; + + // cut the last line + slot.generated_text.erase(pos, std::string::npos); + + SLT_DBG(slot, "stopped by indentation limit, n_decoded = %d, n_indent = %d\n", slot.n_decoded, n_indent); + } + } + + // find the next new line + { + const size_t pos = slot.generated_text.find('\n', slot.last_nl_pos); + + if (pos != std::string::npos) { + slot.last_nl_pos = pos + 1; + } + } + } + } + + // check if there is a new line in the generated text + if (result.text_to_send.find('\n') != std::string::npos) { + slot.has_new_line = true; + + // if we have seen a new line, we stop after a certain time limit, but only upon another new line + if (slot.task->params.t_max_predict_ms > 0 && (ggml_time_us() - slot.t_start_generation > 1000.0f*slot.task->params.t_max_predict_ms)) { + slot.stop = STOP_TYPE_LIMIT; + slot.has_next_token = false; + + SLT_DBG(slot, "stopped by time limit, n_decoded = %d, t_max_predict_ms = %d ms\n", slot.n_decoded, (int) slot.task->params.t_max_predict_ms); + } + } + + if (llama_vocab_is_eog(vocab, result.tok)) { + slot.stop = STOP_TYPE_EOS; + slot.has_next_token = false; + + SLT_DBG(slot, "%s", "stopped by EOS\n"); + } + + SLT_DBG(slot, "n_decoded = %d, n_remaining = %d, next token: %5d '%s'\n", slot.n_decoded, slot.n_remaining, result.tok, token_str.c_str()); + + return slot.has_next_token; // continue + } + + void populate_token_probs(const server_slot & slot, completion_token_output & result, bool post_sampling, bool special, int idx) const { + size_t n_probs = slot.task->params.sampling.n_probs; + size_t n_vocab = llama_vocab_n_tokens(vocab); + + if (post_sampling) { + const auto * cur_p = common_sampler_get_candidates(slot.smpl, true); + const size_t max_probs = cur_p->size; + + // set probability for sampled token + for (size_t i = 0; i < max_probs; i++) { + if (cur_p->data[i].id == result.tok) { + result.prob = cur_p->data[i].p; + break; + } + } + + // set probability for top n_probs tokens + result.probs.reserve(max_probs); + for (size_t i = 0; i < std::min(max_probs, n_probs); i++) { + result.probs.push_back({ + cur_p->data[i].id, + common_token_to_piece(ctx, cur_p->data[i].id, special), + cur_p->data[i].p + }); + } + } else { + // TODO: optimize this with min-p optimization + std::vector cur = get_token_probabilities(ctx, idx); + + // set probability for sampled token + for (size_t i = 0; i < n_vocab; i++) { + // set probability for sampled token + if (cur[i].id == result.tok) { + result.prob = cur[i].p; + break; + } + } + + // set probability for top n_probs tokens + result.probs.reserve(n_probs); + for (size_t i = 0; i < std::min(n_vocab, n_probs); i++) { + result.probs.push_back({ + cur[i].id, + common_token_to_piece(ctx, cur[i].id, special), + cur[i].p + }); + } + } + } + + void send_error(const server_task & task, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) { + send_error(task.id, error, type); + } + + void send_error(const server_slot & slot, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) { + send_error(slot.task->id, error, type, slot.task->n_tokens(), slot.n_ctx); + } + + void send_error(const int id_task, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER, const int32_t n_prompt_tokens = 0, const int32_t n_ctx = 0) { + SRV_ERR("task id = %d, error: %s\n", id_task, error.c_str()); + + if (type == ERROR_TYPE_EXCEED_CONTEXT_SIZE) { + GGML_ASSERT(n_ctx > 0 && n_prompt_tokens > 0); + } + + auto res = std::make_unique(); + res->id = id_task; + res->err_type = type; + res->err_msg = error; + res->n_prompt_tokens = n_prompt_tokens; + res->n_ctx = n_ctx; + + queue_results.send(std::move(res)); + } + + // if multimodal is enabled, send an error and return false + bool check_no_mtmd(const int id_task) { + if (mctx) { + send_error(id_task, "This feature is not supported by multimodal", ERROR_TYPE_NOT_SUPPORTED); + return false; + } + return true; + } + + void send_partial_response(server_slot & slot, const completion_token_output & tkn, bool is_progress) { + auto res = std::make_unique(); + + res->id = slot.task->id; + res->index = slot.task->index; + + if (is_progress) { + res->is_progress = true; + res->progress.total = slot.task->n_tokens(); + res->progress.cache = slot.n_prompt_tokens_cache; + res->progress.processed = slot.prompt.tokens.size(); + res->progress.time_ms = (ggml_time_us() - slot.t_start_process_prompt) / 1000; + } else { + res->content = tkn.text_to_send; + res->tokens = { tkn.tok }; + + slot.update_chat_msg(res->oaicompat_msg_diffs); + } + + res->n_decoded = slot.n_decoded; + res->n_prompt_tokens = slot.task->n_tokens(); + res->post_sampling_probs = slot.task->params.post_sampling_probs; + + res->verbose = slot.task->params.verbose; + res->res_type = slot.task->params.res_type; + res->oaicompat_model = slot.task->params.oaicompat_model; + res->oaicompat_cmpl_id = slot.task->params.oaicompat_cmpl_id; + + // populate res.probs_output + if (slot.task->params.sampling.n_probs > 0) { + res->prob_output = tkn; // copy the token probs + } + + // populate timings if this is final response or timings_per_token is enabled + if (slot.stop != STOP_TYPE_NONE || slot.task->params.timings_per_token) { + res->timings = slot.get_timings(); + } + + queue_results.send(std::move(res)); + } + + void send_final_response(server_slot & slot) { + auto res = std::make_unique(); + + res->id = slot.task->id; + res->id_slot = slot.id; + + res->index = slot.task->index; + res->content = slot.generated_text; + res->tokens = std::move(slot.generated_tokens); + res->timings = slot.get_timings(); + res->prompt = slot.task->tokens.detokenize(ctx, true); + res->response_fields = std::move(slot.task->params.response_fields); + + res->truncated = slot.truncated; + res->n_decoded = slot.n_decoded; + res->n_prompt_tokens = slot.task->n_tokens(); + res->n_tokens_cached = slot.prompt.n_tokens(); + res->has_new_line = slot.has_new_line; + res->stopping_word = slot.stopping_word; + res->stop = slot.stop; + res->post_sampling_probs = slot.task->params.post_sampling_probs; + + res->verbose = slot.task->params.verbose; + res->stream = slot.task->params.stream; + res->include_usage = slot.task->params.include_usage; + res->res_type = slot.task->params.res_type; + res->oaicompat_model = slot.task->params.oaicompat_model; + res->oaicompat_cmpl_id = slot.task->params.oaicompat_cmpl_id; + res->oaicompat_msg = slot.update_chat_msg(res->oaicompat_msg_diffs); + + // populate res.probs_output + if (slot.task->params.sampling.n_probs > 0) { + if (!slot.task->params.stream && slot.stop == STOP_TYPE_WORD) { + const llama_tokens stop_word_toks = common_tokenize(ctx, slot.stopping_word, false); + + size_t safe_offset = std::min(slot.generated_token_probs.size(), stop_word_toks.size()); + res->probs_output = std::vector( + slot.generated_token_probs.begin(), + slot.generated_token_probs.end() - safe_offset); + } else { + res->probs_output = std::vector( + slot.generated_token_probs.begin(), + slot.generated_token_probs.end()); + } + } + + res->generation_params = slot.task->params; // copy the parameters + + queue_results.send(std::move(res)); + } + + void send_embedding(const server_slot & slot, const llama_batch & batch) { + auto res = std::make_unique(); + res->id = slot.task->id; + res->index = slot.task->index; + res->n_tokens = slot.task->n_tokens(); + res->res_type = slot.task->params.res_type; + + const int n_embd = llama_model_n_embd(model); + + std::vector embd_res(n_embd, 0.0f); + + for (int i = 0; i < batch.n_tokens; ++i) { + if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) { + continue; + } + + const float * embd = nullptr; + if (llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE) { + embd = llama_get_embeddings_ith(ctx, i); + } else { + embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]); + } + + if (embd == nullptr) { + SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]); + + res->embedding.push_back(std::vector(n_embd, 0.0f)); + continue; + } + + // normalize only when there is pooling + if (llama_pooling_type(slot.ctx) != LLAMA_POOLING_TYPE_NONE) { + common_embd_normalize(embd, embd_res.data(), n_embd, slot.task->params.embd_normalize); + res->embedding.push_back(embd_res); + break; + } + + res->embedding.emplace_back(embd, embd + n_embd); + } + + SLT_DBG(slot, "%s", "sending embeddings\n"); + + queue_results.send(std::move(res)); + } + + void send_rerank(const server_slot & slot, const llama_batch & batch) { + auto res = std::make_unique(); + res->id = slot.task->id; + res->index = slot.task->index; + res->n_tokens = slot.task->n_tokens(); + + for (int i = 0; i < batch.n_tokens; ++i) { + if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) { + continue; + } + + const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]); + if (embd == NULL) { + embd = llama_get_embeddings_ith(ctx, i); + } + + if (embd == NULL) { + SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]); + + res->score = -1e6; + continue; + } + + res->score = embd[0]; + } + + SLT_DBG(slot, "sending rerank result, res.score = %f\n", res->score); + + queue_results.send(std::move(res)); + } + + // + // Functions to process the task + // + + void process_single_task(server_task && task) { + switch (task.type) { + case SERVER_TASK_TYPE_COMPLETION: + case SERVER_TASK_TYPE_INFILL: + case SERVER_TASK_TYPE_EMBEDDING: + case SERVER_TASK_TYPE_RERANK: + { + const int id_slot = task.id_slot; + + server_slot * slot = id_slot != -1 ? get_slot_by_id(id_slot) : get_available_slot(task); + + if (slot == nullptr) { + // if no slot is available, we defer this task for processing later + SRV_DBG("no slot is available, defer task, id_task = %d\n", task.id); + queue_tasks.defer(std::move(task)); + break; + } + + if (slot->is_processing()) { + // if requested slot is unavailable, we defer this task for processing later + SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id); + queue_tasks.defer(std::move(task)); + break; + } + + if (!launch_slot_with_task(*slot, std::move(task))) { + SRV_ERR("failed to launch slot with task, id_task = %d\n", task.id); + break; + } + } break; + case SERVER_TASK_TYPE_CANCEL: + { + // release slot linked with the task id + for (auto & slot : slots) { + if (slot.task && slot.task->id == task.id_target) { + slot.release(); + break; + } + } + } break; + case SERVER_TASK_TYPE_NEXT_RESPONSE: + { + // do nothing + } break; + case SERVER_TASK_TYPE_METRICS: + { + json slots_data = json::array(); + + int n_idle_slots = 0; + int n_processing_slots = 0; + + for (server_slot & slot : slots) { + json slot_data = slot.to_json(slots_debug == 0); + + if (slot.is_processing()) { + n_processing_slots++; + } else { + n_idle_slots++; + } + + slots_data.push_back(slot_data); + } + SRV_DBG("n_idle_slots = %d, n_processing_slots = %d\n", n_idle_slots, n_processing_slots); + + auto res = std::make_unique(); + res->id = task.id; + res->slots_data = std::move(slots_data); + res->n_idle_slots = n_idle_slots; + res->n_processing_slots = n_processing_slots; + res->n_tasks_deferred = queue_tasks.queue_tasks_deferred_size(); + res->t_start = metrics.t_start; + + res->n_prompt_tokens_processed_total = metrics.n_prompt_tokens_processed_total; + res->t_prompt_processing_total = metrics.t_prompt_processing_total; + res->n_tokens_predicted_total = metrics.n_tokens_predicted_total; + res->t_tokens_generation_total = metrics.t_tokens_generation_total; + + res->n_tokens_max = metrics.n_tokens_max; + + res->n_prompt_tokens_processed = metrics.n_prompt_tokens_processed; + res->t_prompt_processing = metrics.t_prompt_processing; + res->n_tokens_predicted = metrics.n_tokens_predicted; + res->t_tokens_generation = metrics.t_tokens_generation; + + res->n_decode_total = metrics.n_decode_total; + res->n_busy_slots_total = metrics.n_busy_slots_total; + + if (task.metrics_reset_bucket) { + metrics.reset_bucket(); + } + queue_results.send(std::move(res)); + } break; + case SERVER_TASK_TYPE_SLOT_SAVE: + { + if (!check_no_mtmd(task.id)) { + break; + } + + int id_slot = task.slot_action.slot_id; + server_slot * slot = get_slot_by_id(id_slot); + if (slot == nullptr) { + send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST); + break; + } + if (slot->is_processing()) { + // if requested slot is unavailable, we defer this task for processing later + SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id); + queue_tasks.defer(std::move(task)); + break; + } + + const size_t token_count = slot->prompt.tokens.size(); + const int64_t t_start = ggml_time_us(); + + std::string filename = task.slot_action.filename; + std::string filepath = task.slot_action.filepath; + + const llama_tokens & tokens = slot->prompt.tokens.get_text_tokens(); + const size_t nwrite = llama_state_seq_save_file(ctx, filepath.c_str(), slot->id, tokens.data(), token_count); + + const int64_t t_end = ggml_time_us(); + const double t_save_ms = (t_end - t_start) / 1000.0; + + auto res = std::make_unique(); + res->id = task.id; + res->id_slot = id_slot; + res->filename = filename; + res->is_save = true; + res->n_tokens = token_count; + res->n_bytes = nwrite; + res->t_ms = t_save_ms; + queue_results.send(std::move(res)); + } break; + case SERVER_TASK_TYPE_SLOT_RESTORE: + { + if (!check_no_mtmd(task.id)) break; + int id_slot = task.slot_action.slot_id; + server_slot * slot = get_slot_by_id(id_slot); + if (slot == nullptr) { + send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST); + break; + } + if (slot->is_processing()) { + // if requested slot is unavailable, we defer this task for processing later + SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id); + queue_tasks.defer(std::move(task)); + break; + } + + const int64_t t_start = ggml_time_us(); + + std::string filename = task.slot_action.filename; + std::string filepath = task.slot_action.filepath; + + llama_tokens tokens; + tokens.resize(slot->n_ctx); + size_t token_count = 0; + size_t nread = llama_state_seq_load_file(ctx, filepath.c_str(), slot->id, tokens.data(), tokens.size(), &token_count); + if (nread == 0) { + slot->prompt.tokens.clear(); // KV may already been invalidated? + send_error(task, "Unable to restore slot, no available space in KV cache or invalid slot save file", ERROR_TYPE_INVALID_REQUEST); + break; + } + tokens.resize(token_count); + slot->prompt.tokens.clear(); + slot->prompt.tokens.insert(tokens); + + const int64_t t_end = ggml_time_us(); + const double t_restore_ms = (t_end - t_start) / 1000.0; + + auto res = std::make_unique(); + res->id = task.id; + res->id_slot = id_slot; + res->filename = filename; + res->is_save = false; + res->n_tokens = token_count; + res->n_bytes = nread; + res->t_ms = t_restore_ms; + queue_results.send(std::move(res)); + } break; + case SERVER_TASK_TYPE_SLOT_ERASE: + { + if (!check_no_mtmd(task.id)) { + break; + } + int id_slot = task.slot_action.slot_id; + server_slot * slot = get_slot_by_id(id_slot); + if (slot == nullptr) { + send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST); + break; + } + if (slot->is_processing()) { + // if requested slot is unavailable, we defer this task for processing later + SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id); + queue_tasks.defer(std::move(task)); + break; + } + + // Erase token cache + const size_t n_erased = slot->prompt.tokens.size(); + + clear_slot(*slot); + + auto res = std::make_unique(); + res->id = task.id; + res->id_slot = id_slot; + res->n_erased = n_erased; + queue_results.send(std::move(res)); + } break; + case SERVER_TASK_TYPE_SET_LORA: + { + params_base.lora_adapters = std::move(task.set_lora); + auto res = std::make_unique(); + res->id = task.id; + queue_results.send(std::move(res)); + } break; + + } + } + + void update_slots() { + // check if all slots are idle + { + bool all_idle = true; + + for (auto & slot : slots) { + if (slot.is_processing()) { + all_idle = false; + break; + } + } + + if (all_idle) { + SRV_INF("%s", "all slots are idle\n"); + + return; + } + } + + { + SRV_DBG("%s", "posting NEXT_RESPONSE\n"); + + server_task task(SERVER_TASK_TYPE_NEXT_RESPONSE); + task.id = queue_tasks.get_new_id(); + queue_tasks.post(std::move(task)); + } + + // apply context-shift if needed + // TODO: simplify and improve + for (server_slot & slot : slots) { + if (slot.state == SLOT_STATE_GENERATING && slot.prompt.n_tokens() + 1 >= slot.n_ctx) { + if (!params_base.ctx_shift) { + // this check is redundant (for good) + // we should never get here, because generation should already stopped in process_token() + send_error(slot, "context shift is disabled", ERROR_TYPE_SERVER); + slot.release(); + continue; + } + + if (mctx) { + // we should never reach this because params_base.ctx_shift is automatically disabled if mmproj is loaded + // we don't support ctx_shift because an image chunk may contains multiple tokens + GGML_ABORT("not supported by multimodal"); + } + + // Shift context + int n_keep = slot.task->params.n_keep < 0 ? slot.task->n_tokens() : slot.task->params.n_keep; + + if (add_bos_token) { + n_keep += 1; + } + + n_keep = std::min(slot.n_ctx - 4, n_keep); + + const int n_left = slot.prompt.n_tokens() - n_keep; + const int n_discard = slot.task->params.n_discard ? slot.task->params.n_discard : (n_left / 2); + + SLT_WRN(slot, "slot context shift, n_keep = %d, n_left = %d, n_discard = %d\n", n_keep, n_left, n_discard); + + llama_memory_seq_rm (llama_get_memory(ctx), slot.id, n_keep , n_keep + n_discard); + llama_memory_seq_add(llama_get_memory(ctx), slot.id, n_keep + n_discard, slot.prompt.n_tokens(), -n_discard); + + // add generated tokens to cache + // ref: https://github.com/ggml-org/llama.cpp/pull/16818#discussion_r2473269481 + { + GGML_ASSERT(!slot.prompt.tokens.has_mtmd); + + llama_tokens new_tokens = slot.prompt.tokens.get_text_tokens(); // copy + for (size_t i = n_keep + n_discard; i < new_tokens.size(); i++) { + new_tokens[i - n_discard] = new_tokens[i]; + } + + new_tokens.resize(slot.prompt.tokens.size() - n_discard); + + slot.prompt.tokens.clear(); + slot.prompt.tokens.insert(new_tokens); + } + + slot.truncated = true; + } + } + + // start populating the batch for this iteration + common_batch_clear(batch); + + // track if given slot can be batched with slots already in the batch + server_slot * slot_batched = nullptr; + + auto accept_special_token = [&](server_slot & slot, llama_token token) { + return params_base.special || + slot.task->params.sampling.preserved_tokens.find(token) != slot.task->params.sampling.preserved_tokens.end(); + }; + + // first, add sampled tokens from any ongoing sequences + for (auto & slot : slots) { + if (slot.state != SLOT_STATE_GENERATING) { + continue; + } + + // check if we can batch this slot with the previous one + if (!slot_batched) { + slot_batched = &slot; + } else if (!slot_batched->can_batch_with(slot)) { + continue; + } + + slot.i_batch = batch.n_tokens; + + common_batch_add(batch, slot.sampled, slot.prompt.tokens.pos_next(), { slot.id }, true); + + slot.prompt.tokens.push_back(slot.sampled); + + SLT_DBG(slot, "slot decode token, n_ctx = %d, n_tokens = %d, truncated = %d\n", + slot.n_ctx, slot.prompt.n_tokens(), slot.truncated); + } + + // process in chunks of params.n_batch + int32_t n_batch = llama_n_batch(ctx); + int32_t n_ubatch = llama_n_ubatch(ctx); + + float alora_scale = -1.0f; + size_t alora_disabled_id = 0; + + // next, batch any pending prompts without exceeding n_batch + if (params_base.cont_batching || batch.n_tokens == 0) { + for (auto & slot : slots) { + if (!slot.is_processing()) { + continue; + } + + // check if we can batch this slot with the previous one + if (slot_batched && !slot_batched->can_batch_with(slot)) { + continue; + } + + // this slot still has a prompt to be processed + if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_STARTED) { + const auto & input_tokens = slot.task->tokens; + + // TODO: maybe move branch to outside of this loop in the future + if (slot.state == SLOT_STATE_STARTED) { + slot.t_start_process_prompt = ggml_time_us(); + slot.t_start_generation = 0; + + slot.state = SLOT_STATE_PROCESSING_PROMPT; + + SLT_INF(slot, "new prompt, n_ctx_slot = %d, n_keep = %d, task.n_tokens = %d\n", + slot.n_ctx, slot.task->params.n_keep, slot.task->n_tokens()); + + // print prompt tokens (for debugging) + /*if (1) { + // first 16 tokens (avoid flooding logs) + for (int i = 0; i < std::min(16, input_tokens.size()); i++) { + SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, input_tokens[i], common_token_to_piece(ctx, input_tokens[i]).c_str()); + } + } else { + // all + for (int i = 0; i < (int) input_tokens.size(); i++) { + SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, input_tokens[i], common_token_to_piece(ctx, input_tokens[i]).c_str()); + } + }*/ + + // keep track how many tokens we can reuse from the previous state + int n_past = 0; + + // empty prompt passed -> release the slot and send empty response + if (input_tokens.empty()) { + SLT_WRN(slot, "%s", "empty prompt - releasing slot\n"); + + slot.print_timings(); + send_final_response(slot); + slot.release(); + + continue; + } + + // TODO: support memory-less logits computation + if (slot.need_logits() && !llama_get_memory(ctx)) { + send_error(slot, "the current context does not logits computation. skipping", ERROR_TYPE_SERVER); + slot.release(); + continue; + } + + if (!slot.can_split()) { + if (slot.task->n_tokens() > n_ubatch) { + send_error(slot, "input is too large to process. increase the physical batch size", ERROR_TYPE_SERVER); + slot.release(); + continue; + } + + if (slot.task->n_tokens() > slot.n_ctx) { + send_error(slot, "input is larger than the max context size. skipping", ERROR_TYPE_EXCEED_CONTEXT_SIZE); + slot.release(); + continue; + } + } else { + if (slot.task->n_tokens() >= slot.n_ctx) { + send_error(slot, "the request exceeds the available context size, try increasing it", ERROR_TYPE_EXCEED_CONTEXT_SIZE); + slot.release(); + continue; + } + + if (slot.task->params.cache_prompt) { + // reuse any previously computed tokens that are common with the new prompt + n_past = slot.prompt.tokens.get_common_prefix(input_tokens); + + // if there is an alora invoked, don't cache after the invocation start + if (slot.alora_invocation_start > 0) { + SLT_DBG(slot, "only caching to alora invocation start (n_past = %d, alora_invocation_start = %d)\n", n_past, slot.alora_invocation_start); + n_past = std::min(n_past, slot.alora_invocation_start - 1); + } + + // reuse chunks from the cached prompt by shifting their KV cache in the new position + if (params_base.n_cache_reuse > 0) { + GGML_ASSERT(!slot.prompt.tokens.has_mtmd); + + size_t head_c = n_past; // cache + size_t head_p = n_past; // current prompt + + if (mctx) { + // we should never reach this + GGML_ABORT("not supported by multimodal"); + } + + SLT_DBG(slot, "trying to reuse chunks with size > %d, n_past = %d\n", params_base.n_cache_reuse, n_past); + + while (head_c < slot.prompt.tokens.size() && + head_p < input_tokens.size()) { + + size_t n_match = 0; + while (head_c + n_match < slot.prompt.tokens.size() && + head_p + n_match < input_tokens.size() && + slot.prompt.tokens[head_c + n_match] == input_tokens[head_p + n_match]) { + + n_match++; + } + + if (n_match >= (size_t) params_base.n_cache_reuse) { + SLT_INF(slot, "reusing chunk with size %zu, shifting KV cache [%zu, %zu) -> [%zu, %zu)\n", n_match, head_c, head_c + n_match, head_p, head_p + n_match); + //for (size_t i = head_p; i < head_p + n_match; i++) { + // SLT_DBG(slot, "cache token %3zu: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str()); + //} + + const int64_t kv_shift = (int64_t) head_p - (int64_t) head_c; + + llama_memory_seq_rm (llama_get_memory(ctx), slot.id, head_p, head_c); + llama_memory_seq_add(llama_get_memory(ctx), slot.id, head_c, head_c + n_match, kv_shift); + + for (size_t i = 0; i < n_match; i++) { + slot.prompt.tokens.set_token(head_p + i, slot.prompt.tokens[head_c + i]); + n_past++; + } + + head_c += n_match; + head_p += n_match; + } else { + head_c += 1; + } + } + + SLT_DBG(slot, "after context reuse, new n_past = %d\n", n_past); + } + } else { + // if we don't cache the prompt, we have to remove all previous tokens + n_past = 0; + } + + // note: when n_swa == 0, the model does not use SWA, which is equivalent to a window of 1 + const auto n_swa = std::max(1, llama_model_n_swa(model)); + + // the largest pos_min required for a checkpoint to be useful + const auto pos_min_thold = std::max(0, n_past - n_swa); + + // note: disallow with mtmd contexts for now + // https://github.com/ggml-org/llama.cpp/issues/17043 + if (!mctx && n_past > 0 && n_past < slot.prompt.n_tokens()) { + const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id); + if (pos_min == -1) { + SLT_ERR(slot, "n_past = %d, slot.prompt.tokens.size() = %d, seq_id = %d, pos_min = %d\n", n_past, (int) slot.prompt.tokens.size(), slot.id, pos_min); + GGML_ABORT("pos_min == -1, but n_past > 0 - should not happen: https://github.com/ggml-org/llama.cpp/pull/13833#discussion_r2116181237"); + } + + // when the prompt prefix does not match, print the tokens around the mismatch + // this is useful for debugging prompt caching + if (slots_debug) { + const int np0 = std::max(n_past - 4, 0); + const int np1 = std::min(n_past + 6, std::min(slot.prompt.tokens.size(), slot.task->tokens.size())); + + std::stringstream ss0; + std::stringstream ss1; + + std::stringstream st0; + std::stringstream st1; + + ss0 << "old: ... "; + ss1 << "new: ... "; + + for (int i = np0; i < np1; i++) { + if (i == n_past) { + ss0 << " | "; + ss1 << " | "; + } + + { + const auto token = slot.prompt.tokens[i]; + const auto piece = token != LLAMA_TOKEN_NULL ? common_token_to_piece(ctx, token) : "[mtmd]"; + ss0 << piece; + st0 << std::setw(8) << token; + } + + { + const auto token = slot.task->tokens[i]; + const auto piece = token != LLAMA_TOKEN_NULL ? common_token_to_piece(ctx, token) : "[mtmd]"; + ss1 << piece; + st1 << std::setw(8) << token; + } + } + + SLT_WRN(slot, "%s\n", ss0.str().c_str()); + SLT_WRN(slot, "%s\n", ss1.str().c_str()); + + SLT_WRN(slot, "%s\n", st0.str().c_str()); + SLT_WRN(slot, "%s\n", st1.str().c_str()); + } + + if (pos_min > pos_min_thold) { + // TODO: support can be added in the future when corresponding vision models get released + GGML_ASSERT(!slot.prompt.tokens.has_mtmd); + + SLT_WRN(slot, "n_past = %d, slot.prompt.tokens.size() = %d, seq_id = %d, pos_min = %d, n_swa = %d\n", n_past, (int) slot.prompt.tokens.size(), slot.id, pos_min, n_swa); + + // search for a context checkpoint + const auto it = std::find_if( + slot.prompt.checkpoints.rbegin(), + slot.prompt.checkpoints.rend(), + [&](const auto & cur) { + // guarantee that a checkpoint will result in at least one token being processed [TAG_PROMPT_LOGITS] + return cur.pos_min < pos_min_thold; + } + ); + + bool do_reset = it == slot.prompt.checkpoints.rend(); + + if (!do_reset) { + // restore the context checkpoint + const size_t checkpoint_size = it->data.size(); + const size_t n = llama_state_seq_set_data_ext(ctx, it->data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); + + if (n != checkpoint_size) { + SLT_ERR(slot, "failed to restore context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float) checkpoint_size / 1024 / 1024); + do_reset = true; + //printf("[DEBUG] `do_reset` was set to `true` after failing to restore a checkpoint"); + } else { + n_past = std::min(n_past, std::max(it->pos_min + 1, it->pos_max)); + SLT_WRN(slot, "restored context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float) checkpoint_size / 1024 / 1024); + } + } + + if (do_reset) { + SLT_WRN(slot, "forcing full prompt re-processing due to lack of cache data (likely due to SWA or hybrid/recurrent memory, see %s)\n", + "https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055"); + n_past = 0; + } + } + } + + { + // erase any checkpoints with pos_min > pos_min_thold + for (auto it = slot.prompt.checkpoints.begin(); it != slot.prompt.checkpoints.end();) { + const auto & cur = *it; + if (cur.pos_min > pos_min_thold) { + SLT_WRN(slot, "erased invalidated context checkpoint (pos_min = %d, pos_max = %d, n_swa = %d, size = %.3f MiB)\n", cur.pos_min, cur.pos_max, n_swa, (float) cur.data.size() / 1024 / 1024); + it = slot.prompt.checkpoints.erase(it); + } else { + ++it; + } + } + } + } + + // [TAG_PROMPT_LOGITS] + if (n_past == slot.task->n_tokens() && n_past > 0) { + SLT_WRN(slot, "need to evaluate at least 1 token for each active slot (n_past = %d, task.n_tokens() = %d)\n", n_past, slot.task->n_tokens()); + n_past--; + SLT_WRN(slot, "n_past was set to %d\n", n_past); + } + + slot.n_prompt_tokens_cache = n_past; + slot.n_prompt_tokens_processed = 0; + + slot.prompt.tokens.keep_first(n_past); + } + + if (!slot.can_split()) { + // cannot fit the prompt in the current batch - will try next iter + if (batch.n_tokens + slot.task->n_tokens() > n_batch) { + continue; + } + } + + // truncate any tokens that are beyond n_past for this slot + const llama_pos p0 = slot.prompt.tokens.pos_next(); + + SLT_INF(slot, "n_tokens = %d, memory_seq_rm [%d, end)\n", slot.prompt.n_tokens(), p0); + + if (!llama_memory_seq_rm(llama_get_memory(ctx), slot.id, p0, -1)) { + SLT_WRN(slot, "failed to truncate tokens with position >= %d - clearing the memory\n", p0); + + clear_slot(slot); + + // there is no common part left + slot.n_prompt_tokens_cache = 0; + } + + // check if we should process the image + if (slot.prompt.n_tokens() < slot.task->n_tokens() && input_tokens[slot.prompt.n_tokens()] == LLAMA_TOKEN_NULL) { + // process the image + size_t n_tokens_out = 0; + int32_t res = input_tokens.process_chunk(ctx, mctx, slot.prompt.n_tokens(), slot.prompt.tokens.pos_next(), slot.id, n_tokens_out); + if (res != 0) { + SLT_ERR(slot, "failed to process image, res = %d\n", res); + send_error(slot, "failed to process image", ERROR_TYPE_SERVER); + slot.release(); + continue; + } + + slot.n_prompt_tokens_processed += n_tokens_out; + + // add the image chunk to cache + { + const auto & chunk = input_tokens.find_chunk(slot.prompt.n_tokens()); + slot.prompt.tokens.push_back(chunk.get()); // copy + } + } + + // If using an alora, there may be uncached tokens that come + // before the invocation sequence. When this happens, the + // tokens before the invocation sequence need to be + // processed without the adapter in a separate batch, then + // the adapter needs to be enabled for the remaining tokens. + if (lora_all_alora(slot.lora) && slot.alora_invocation_start - 1 > slot.prompt.n_tokens()) { + SLT_DBG(slot, "processing pre-alora tokens without the adapter (n_tokens = %d, alora_invocation_start = %d)\n", slot.prompt.n_tokens(), slot.alora_invocation_start); + const auto & enabled_loras = lora_get_enabled_ids(slot.lora); + GGML_ASSERT(enabled_loras.size() == 1); + alora_scale = slot.lora[enabled_loras[0]].scale; + slot.lora[enabled_loras[0]].scale = 0.0f; + alora_disabled_id = enabled_loras[0]; + } + + bool do_checkpoint = params_base.n_ctx_checkpoints > 0; + + // make checkpoints only for completion tasks + do_checkpoint = do_checkpoint && slot.task->type == SERVER_TASK_TYPE_COMPLETION; + + // make a checkpoint of the parts of the memory that cannot be rolled back. + // checkpoints are created only if: + // - the model uses SWA and we are not using `swa_full` + // - the model architecture is marked as recurrent or hybrid + // + // TODO: try to make this conditional on the context or the memory module, instead of the model type + do_checkpoint = do_checkpoint && ( + llama_model_is_recurrent(model) || + llama_model_is_hybrid(model) || + (llama_model_n_swa(model) > 0 && !params_base.swa_full) + ); + + // add prompt tokens for processing in the current batch + while (slot.prompt.n_tokens() < slot.task->n_tokens() && batch.n_tokens < n_batch) { + // get next token to process + llama_token cur_tok = input_tokens[slot.prompt.n_tokens()]; + if (cur_tok == LLAMA_TOKEN_NULL) { + break; // end of text chunk + } + + // if this is an alora request with pre-invocation + // tokens that are not cached, we need to stop filling + // this batch at those pre-invocation tokens. + if (alora_scale > 0 && slot.prompt.n_tokens() == slot.alora_invocation_start - 1) { + SLT_DBG(slot, "stop prompt batch filling at (n_tokens = %d, alora_invocation_start = %d)\n", slot.prompt.n_tokens(), slot.alora_invocation_start); + break; + } + + // embedding requires all tokens in the batch to be output + common_batch_add(batch, + cur_tok, + slot.prompt.tokens.pos_next(), + { slot.id }, + slot.need_embd()); + slot.prompt.tokens.push_back(cur_tok); + + slot.n_prompt_tokens_processed++; + + // process the last few tokens of the prompt separately in order to allow for a checkpoint to be created. + if (do_checkpoint && slot.task->n_tokens() - slot.prompt.n_tokens() == 64) { + break; + } + } + + // SLT_INF(slot, "new slot.prompt.tokens: %s\n", slot.slot.prompt.tokens.str().c_str()); + + SLT_INF(slot, "prompt processing progress, n_tokens = %d, batch.n_tokens = %d, progress = %f\n", slot.prompt.n_tokens(), batch.n_tokens, (float) slot.prompt.n_tokens() / slot.task->n_tokens()); + + // entire prompt has been processed + if (slot.prompt.n_tokens() == slot.task->n_tokens()) { + slot.state = SLOT_STATE_DONE_PROMPT; + + GGML_ASSERT(batch.n_tokens > 0); + + common_sampler_reset(slot.smpl); + + // Process all prompt tokens through sampler system + for (int i = 0; i < slot.task->n_tokens(); ++i) { + llama_token id = input_tokens[i]; + if (id != LLAMA_TOKEN_NULL) { + common_sampler_accept(slot.smpl, id, false); + } + } + + // extract the logits only for the last token + batch.logits[batch.n_tokens - 1] = true; + + slot.n_decoded = 0; + slot.i_batch = batch.n_tokens - 1; + + SLT_INF(slot, "prompt done, n_tokens = %d, batch.n_tokens = %d\n", slot.prompt.n_tokens(), batch.n_tokens); + + const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id); + const auto pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx), slot.id); + + // no need for empty or small checkpoints + do_checkpoint = do_checkpoint && (pos_min >= 0 && pos_max >= 64); + + // no need to create checkpoints that are too close together + do_checkpoint = do_checkpoint && (slot.prompt.checkpoints.empty() || pos_max > slot.prompt.checkpoints.back().pos_max + 64); + + if (do_checkpoint) { + while (slot.prompt.checkpoints.size() >= (size_t) params_base.n_ctx_checkpoints) { + // make room for the new checkpoint, if needed + const auto & cur = slot.prompt.checkpoints.front(); + + SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", + cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024); + + slot.prompt.checkpoints.erase(slot.prompt.checkpoints.begin()); + } + + const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); + + auto & cur = slot.prompt.checkpoints.emplace_back(server_prompt_checkpoint{ + /*.pos_min = */ pos_min, + /*.pos_max = */ pos_max, + /*.data = */ std::vector(checkpoint_size), + }); + + llama_state_seq_get_data_ext(ctx, cur.data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); + + SLT_WRN(slot, "created context checkpoint %d of %d (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", + (int) slot.prompt.checkpoints.size(), params_base.n_ctx_checkpoints, cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024); + } + } + } + + if (!slot_batched) { + slot_batched = &slot; + } + + if (batch.n_tokens >= n_batch) { + break; + } + } + } + + if (batch.n_tokens == 0) { + SRV_WRN("%s", "no tokens to decode\n"); + return; + } + + SRV_DBG("decoding batch, n_tokens = %d\n", batch.n_tokens); + + if (slot_batched) { + // apply lora, only need to do it once per batch + common_set_adapter_lora(ctx, slot_batched->lora); + + // if the lora is temporarily disabled for an alora, re-enable it + // for next time + if (alora_scale > 0.0f) { + SRV_DBG("re-enabling alora with scale %f\n", alora_scale); + slot_batched->lora[alora_disabled_id].scale = alora_scale; + } + + llama_set_embeddings(ctx, slot_batched->need_embd()); + } + + int32_t i_next = 0; + + // process the created batch of tokens + for (int32_t i = 0; i < batch.n_tokens; i = i_next) { + const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i); + + llama_batch batch_view = { + n_tokens, + batch.token + i, + nullptr, + batch.pos + i, + batch.n_seq_id + i, + batch.seq_id + i, + batch.logits + i, + }; + + const int ret = llama_decode(ctx, batch_view); + + metrics.on_decoded(slots); + + if (ret != 0) { + { + std::string err; + + if (n_batch == 1 && ret == 1) { + // TODO: try to terminate only the largest active slot/sequence and continue with the rest + // need to remove the tokens from the current batch too + err = "Context size has been exceeded."; + } + + if (ret == -1) { + err = "Invalid input batch."; + } + + if (ret < -1) { + // TODO: update slot state based on llama_memory_seq_pos_min() and llama_memory_seq_pos_max() + err = "Compute error."; + } + + // TODO: handle ret == 2 (abort) when we start aborting + + if (!err.empty()) { + SRV_ERR("%s i = %d, n_batch = %d, ret = %d\n", err.c_str(), i, n_batch, ret); + + for (auto & slot : slots) { + if (slot.is_processing()) { + send_error(slot, err); + slot.release(); + + // note: it's complicated to keep track of how much of the current batch has been + // processed before the error occurred, so we simply clear the entire context + clear_slot(slot); + } + } + + break; + } + } + + // retry with half the batch size to try to find a free slot in the KV cache + if (!try_clear_idle_slots()) { + n_batch /= 2; + } + + SRV_WRN("failed to find free space in the KV cache, retrying with smaller batch size, i = %d, n_batch = %d, ret = %d\n", i, n_batch, ret); + + continue; // continue loop of n_batch + } + + // move the head of the batch forward with the number of tokens we just processed + i_next = i + n_tokens; + + // on successful decode, restore the original batch size + n_batch = llama_n_batch(ctx); + + for (auto & slot : slots) { + // optionally send prompt processing progress + if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_DONE_PROMPT) { + if (slot.task->params.stream && slot.task->params.return_progress) { + send_partial_response(slot, {}, true); + } + } + + if (slot.i_batch < (int) i || slot.i_batch >= (int) (i + n_tokens)) { + continue; // continue loop of slots + } + + if (slot.state == SLOT_STATE_DONE_PROMPT) { + if (slot.task->type == SERVER_TASK_TYPE_EMBEDDING) { + // prompt evaluated for embedding + send_embedding(slot, batch_view); + slot.release(); + slot.i_batch = -1; + continue; // continue loop of slots + } + + if (slot.task->type == SERVER_TASK_TYPE_RERANK) { + send_rerank(slot, batch_view); + slot.release(); + slot.i_batch = -1; + continue; // continue loop of slots + } + + // prompt evaluated for next-token prediction + slot.state = SLOT_STATE_GENERATING; + } else if (slot.state != SLOT_STATE_GENERATING) { + continue; // continue loop of slots + } + + const int tok_idx = slot.i_batch - i; + + llama_token id = common_sampler_sample(slot.smpl, ctx, tok_idx); + + slot.i_batch = -1; + + common_sampler_accept(slot.smpl, id, true); + + slot.n_decoded += 1; + + const int64_t t_current = ggml_time_us(); + + if (slot.n_decoded == 1) { + slot.t_start_generation = t_current; + slot.t_prompt_processing = (slot.t_start_generation - slot.t_start_process_prompt) / 1e3; + metrics.on_prompt_eval(slot); + } + + slot.t_token_generation = std::max(1, t_current - slot.t_start_generation) / 1e3; + + completion_token_output result; + result.tok = id; + result.text_to_send = common_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok)); + result.prob = 1.0f; // TODO: set it here instead of doing inside populate_token_probs + + if (slot.task->params.sampling.n_probs > 0) { + populate_token_probs(slot, result, slot.task->params.post_sampling_probs, params_base.special, tok_idx); + } + + if (!process_token(result, slot)) { + // release slot because of stop condition + slot.print_timings(); + send_final_response(slot); + metrics.on_prediction(slot); + slot.release(); + + continue; + } + } + + // do speculative decoding + // TODO: rework to have a single draft llama_context shared across all slots [TAG_SERVER_SPEC_REWORK] + // perform the speculative drafting for all sequences at the same time in a single batch + for (auto & slot : slots) { + if (!slot.is_processing() || !slot.can_speculate()) { + continue; + } + + if (slot.state != SLOT_STATE_GENERATING) { + continue; + } + + if (mctx) { + // we should never reach this, as speculative is automatically disabled if mmproj is loaded + GGML_ABORT("not supported by multimodal"); + } + + // determine the max draft that fits the current slot state + int n_draft_max = slot.task->params.speculative.n_max; + + // note: slot.prompt is not yet expanded with the `id` token sampled above + // also, need to leave space for 1 extra token to allow context shifts + n_draft_max = std::min(n_draft_max, slot.n_ctx - slot.prompt.n_tokens() - 2); + + if (slot.n_remaining > 0) { + n_draft_max = std::min(n_draft_max, slot.n_remaining - 1); + } + + SLT_DBG(slot, "max possible draft: %d\n", n_draft_max); + + if (n_draft_max < slot.task->params.speculative.n_min) { + SLT_DBG(slot, "the max possible draft is too small: %d < %d - skipping speculative decoding\n", n_draft_max, slot.task->params.speculative.n_min); + + continue; + } + + llama_token id = slot.sampled; + + struct common_speculative_params params_spec; + params_spec.n_draft = n_draft_max; + params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.task->params.speculative.n_max; + params_spec.p_min = slot.task->params.speculative.p_min; + + const llama_tokens & cached_text_tokens = slot.prompt.tokens.get_text_tokens(); + llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, cached_text_tokens, id); + + // ignore small drafts + if (slot.task->params.speculative.n_min > (int) draft.size()) { + SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int) draft.size(), slot.task->params.speculative.n_min); + + continue; + } + + // keep track of total number of drafted tokens tested + slot.n_draft_total += draft.size(); + + // construct the speculation batch + common_batch_clear(slot.batch_spec); + common_batch_add (slot.batch_spec, id, slot.prompt.tokens.pos_next(), { slot.id }, true); + + for (size_t i = 0; i < draft.size(); ++i) { + common_batch_add(slot.batch_spec, draft[i], slot.prompt.tokens.pos_next() + 1 + i, { slot.id }, true); + } + + SLT_DBG(slot, "decoding speculative batch, size = %d\n", slot.batch_spec.n_tokens); + + llama_decode(ctx, slot.batch_spec); + + // the accepted tokens from the speculation + const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, draft); + + slot.n_decoded += ids.size(); + + // update how many tokens out of those tested were accepted + slot.n_draft_accepted += ids.size() - 1; + + slot.prompt.tokens.push_back(id); + slot.prompt.tokens.insert({ids.begin(), ids.end() - 1}); + + llama_memory_seq_rm(llama_get_memory(ctx), slot.id, slot.prompt.n_tokens(), -1); + + for (size_t i = 0; i < ids.size(); ++i) { + completion_token_output result; + + result.tok = ids[i]; + result.text_to_send = common_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok)); + result.prob = 1.0f; // set later + + // TODO: set result.probs + + if (!process_token(result, slot)) { + slot.print_timings(); + send_final_response(slot); + metrics.on_prediction(slot); + slot.release(); + + break; + } + } + + SLT_DBG(slot, "accepted %d/%d draft tokens, new n_tokens = %d\n", (int) ids.size() - 1, (int) draft.size(), slot.prompt.n_tokens()); + } + } + + SRV_DBG("%s", "run slots completed\n"); + } + + json model_meta() const { + return json { + {"vocab_type", llama_vocab_type (vocab)}, + {"n_vocab", llama_vocab_n_tokens (vocab)}, + {"n_ctx_train", llama_model_n_ctx_train(model)}, + {"n_embd", llama_model_n_embd (model)}, + {"n_params", llama_model_n_params (model)}, + {"size", llama_model_size (model)}, + }; + } + + int get_slot_n_ctx() { + return slots.back().n_ctx; + } +}; + +// +// server_context (public API) +// + +server_context::server_context() : impl(new server_context_impl()) {} +server_context::~server_context() = default; + +void server_context::init() { + impl->init(); +} + +bool server_context::load_model(const common_params & params) { + return impl->load_model(params); +} + +void server_context::start_loop() { + impl->queue_tasks.start_loop(); +} + +void server_context::terminate() { + impl->queue_tasks.terminate(); +} + +llama_context * server_context::get_llama_context() const { + return impl->ctx; +} + +std::pair server_context::get_queues() { + return { impl->queue_tasks, impl->queue_results }; +} + + + +// generator-like API for HTTP response generation +struct server_res_generator : server_http_res { + server_response_reader rd; + server_res_generator(server_context_impl & ctx_server) + : rd({ctx_server.queue_tasks, ctx_server.queue_results}, HTTP_POLLING_SECONDS) {} + void ok(const json & response_data) { + status = 200; + data = safe_json_to_str(response_data); + } + void error(const json & error_data) { + status = json_value(error_data, "code", 500); + data = safe_json_to_str({{ "error", error_data }}); + } +}; + + + +// +// server_routes +// + +static std::unique_ptr handle_completions_impl( + server_context_impl & ctx_server, + server_task_type type, + const json & data, + const std::vector & files, + const std::function & should_stop, + task_response_type res_type) { + GGML_ASSERT(type == SERVER_TASK_TYPE_COMPLETION || type == SERVER_TASK_TYPE_INFILL); + + auto res = std::make_unique(ctx_server); + auto completion_id = gen_chatcmplid(); + auto & rd = res->rd; + + try { + std::vector tasks; + + const auto & prompt = data.at("prompt"); + // TODO: this log can become very long, put it behind a flag or think about a more compact format + //SRV_DBG("Prompt: %s\n", prompt.is_string() ? prompt.get().c_str() : prompt.dump(2).c_str()); + + // process prompt + std::vector inputs; + + if (res_type != TASK_RESPONSE_TYPE_NONE && ctx_server.mctx != nullptr) { + // This is the case used by OAI compatible chat path with MTMD. TODO It can be moved to the path below. + inputs.push_back(process_mtmd_prompt(ctx_server.mctx, prompt.get(), files)); + } else { + // Everything else, including multimodal completions. + inputs = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, true, true); + } + tasks.reserve(inputs.size()); + for (size_t i = 0; i < inputs.size(); i++) { + server_task task = server_task(type); + + task.id = ctx_server.queue_tasks.get_new_id(); + task.index = i; + + task.tokens = std::move(inputs[i]); + task.params = server_task::params_from_json_cmpl( + ctx_server.ctx, + ctx_server.params_base, + data); + task.id_slot = json_value(data, "id_slot", -1); + + // OAI-compat + task.params.res_type = res_type; + task.params.oaicompat_cmpl_id = completion_id; + // oaicompat_model is already populated by params_from_json_cmpl + + tasks.push_back(std::move(task)); + } + + rd.post_tasks(std::move(tasks)); + } catch (const std::exception & e) { + res->error(format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST)); + return res; + } + + bool stream = json_value(data, "stream", false); + + if (!stream) { + // non-stream, wait for the results + auto all_results = rd.wait_for_all(should_stop); + if (all_results.is_terminated) { + return res; // connection is closed + } else if (all_results.error) { + res->error(all_results.error->to_json()); + return res; + } else { + json arr = json::array(); + for (auto & res : all_results.results) { + GGML_ASSERT(dynamic_cast(res.get()) != nullptr); + arr.push_back(res->to_json()); + } + // if single request, return single object instead of array + res->ok(arr.size() == 1 ? arr[0] : arr); + } + + } else { + // in streaming mode, the first error must be treated as non-stream response + // this is to match the OAI API behavior + // ref: https://github.com/ggml-org/llama.cpp/pull/16486#discussion_r2419657309 + server_task_result_ptr first_result = rd.next(should_stop); + if (first_result == nullptr) { + return res; // connection is closed + } else if (first_result->is_error()) { + res->error(first_result->to_json()); + return res; + } else { + GGML_ASSERT( + dynamic_cast(first_result.get()) != nullptr + || dynamic_cast(first_result.get()) != nullptr + ); + } + + // next responses are streamed + if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) { + res->data = format_anthropic_sse(first_result->to_json()); + } else { + res->data = format_oai_sse(first_result->to_json()); // to be sent immediately + } + res->status = 200; + res->content_type = "text/event-stream"; + res->next = [res_this = res.get(), res_type, &should_stop](std::string & output) -> bool { + if (should_stop()) { + SRV_DBG("%s", "stopping streaming due to should_stop condition\n"); + return false; // should_stop condition met + } + + if (!res_this->data.empty()) { + // flush the first chunk + output = std::move(res_this->data); + res_this->data.clear(); + return true; + } + + server_response_reader & rd = res_this->rd; + + // check if there is more data + if (!rd.has_next()) { + if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) { + // Anthropic doesn't send [DONE], message_stop was already sent + output = ""; + } else if (res_type != TASK_RESPONSE_TYPE_NONE) { + output = "data: [DONE]\n\n"; + } else { + output = ""; + } + SRV_DBG("%s", "all results received, terminating stream\n"); + return false; // no more data, terminate + } + + // receive subsequent results + auto result = rd.next(should_stop); + if (result == nullptr) { + SRV_DBG("%s", "stopping streaming due to should_stop condition\n"); + return false; // should_stop condition met + } + + // send the results + json res_json = result->to_json(); + if (result->is_error()) { + if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) { + output = format_anthropic_sse({ + {"event", "error"}, + {"data", res_json}, + }); + } else { + output = format_oai_sse(json {{ "error", res_json }}); + } + SRV_DBG("%s", "error received during streaming, terminating stream\n"); + return false; // terminate on error + } else { + GGML_ASSERT( + dynamic_cast(result.get()) != nullptr + || dynamic_cast(result.get()) != nullptr + ); + if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) { + output = format_anthropic_sse(res_json); + } else { + output = format_oai_sse(res_json); + } + } + + // has next data, continue + return true; + }; + } + + return res; +} + +void server_routes::init_routes() { + this->get_health = [this](const server_http_req &) { + // error and loading states are handled by middleware + auto res = std::make_unique(ctx_server); + res->ok({{"status", "ok"}}); + return res; + }; + + this->get_metrics = [this](const server_http_req &) { + auto res = std::make_unique(ctx_server); + if (!params.endpoint_metrics) { + res->error(format_error_response("This server does not support metrics endpoint. Start it with `--metrics`", ERROR_TYPE_NOT_SUPPORTED)); + return res; + } + + // request slots data using task queue + // TODO: use server_response_reader + int task_id = ctx_server.queue_tasks.get_new_id(); + { + server_task task(SERVER_TASK_TYPE_METRICS); + task.id = task_id; + ctx_server.queue_results.add_waiting_task_id(task_id); + ctx_server.queue_tasks.post(std::move(task), true); // high-priority task + } + + // get the result + server_task_result_ptr result = ctx_server.queue_results.recv(task_id); + ctx_server.queue_results.remove_waiting_task_id(task_id); + + if (result->is_error()) { + res->error(result->to_json()); + return res; + } + + // TODO: get rid of this dynamic_cast + auto res_task = dynamic_cast(result.get()); + GGML_ASSERT(res_task != nullptr); + + // metrics definition: https://prometheus.io/docs/practices/naming/#metric-names + json all_metrics_def = json { + {"counter", {{ + {"name", "prompt_tokens_total"}, + {"help", "Number of prompt tokens processed."}, + {"value", (uint64_t) res_task->n_prompt_tokens_processed_total} + }, { + {"name", "prompt_seconds_total"}, + {"help", "Prompt process time"}, + {"value", (uint64_t) res_task->t_prompt_processing_total / 1.e3} + }, { + {"name", "tokens_predicted_total"}, + {"help", "Number of generation tokens processed."}, + {"value", (uint64_t) res_task->n_tokens_predicted_total} + }, { + {"name", "tokens_predicted_seconds_total"}, + {"help", "Predict process time"}, + {"value", (uint64_t) res_task->t_tokens_generation_total / 1.e3} + }, { + {"name", "n_decode_total"}, + {"help", "Total number of llama_decode() calls"}, + {"value", res_task->n_decode_total} + }, { + {"name", "n_tokens_max"}, + {"help", "Largest observed n_tokens."}, + {"value", res_task->n_tokens_max} + }, { + {"name", "n_busy_slots_per_decode"}, + {"help", "Average number of busy slots per llama_decode() call"}, + {"value", (float) res_task->n_busy_slots_total / std::max((float) res_task->n_decode_total, 1.f)} + }}}, + {"gauge", {{ + {"name", "prompt_tokens_seconds"}, + {"help", "Average prompt throughput in tokens/s."}, + {"value", res_task->n_prompt_tokens_processed ? 1.e3 / res_task->t_prompt_processing * res_task->n_prompt_tokens_processed : 0.} + },{ + {"name", "predicted_tokens_seconds"}, + {"help", "Average generation throughput in tokens/s."}, + {"value", res_task->n_tokens_predicted ? 1.e3 / res_task->t_tokens_generation * res_task->n_tokens_predicted : 0.} + },{ + {"name", "requests_processing"}, + {"help", "Number of requests processing."}, + {"value", (uint64_t) res_task->n_processing_slots} + },{ + {"name", "requests_deferred"}, + {"help", "Number of requests deferred."}, + {"value", (uint64_t) res_task->n_tasks_deferred} + }}} + }; + + std::stringstream prometheus; + + for (const auto & el : all_metrics_def.items()) { + const auto & type = el.key(); + const auto & metrics_def = el.value(); + + for (const auto & metric_def : metrics_def) { + const std::string name = metric_def.at("name"); + const std::string help = metric_def.at("help"); + + auto value = json_value(metric_def, "value", 0.); + prometheus << "# HELP llamacpp:" << name << " " << help << "\n" + << "# TYPE llamacpp:" << name << " " << type << "\n" + << "llamacpp:" << name << " " << value << "\n"; + } + } + + res->headers["Process-Start-Time-Unix"] = std::to_string(res_task->t_start); + res->content_type = "text/plain; version=0.0.4"; + res->status = 200; + res->data = prometheus.str(); + return res; + }; + + this->get_slots = [this](const server_http_req & req) { + auto res = std::make_unique(ctx_server); + if (!params.endpoint_slots) { + res->error(format_error_response("This server does not support slots endpoint. Start it with `--slots`", ERROR_TYPE_NOT_SUPPORTED)); + return res; + } + + // request slots data using task queue + int task_id = ctx_server.queue_tasks.get_new_id(); + { + server_task task(SERVER_TASK_TYPE_METRICS); + task.id = task_id; + ctx_server.queue_results.add_waiting_task_id(task_id); + ctx_server.queue_tasks.post(std::move(task), true); // high-priority task + } + + // get the result + server_task_result_ptr result = ctx_server.queue_results.recv(task_id); + ctx_server.queue_results.remove_waiting_task_id(task_id); + + if (result->is_error()) { + res->error(result->to_json()); + return res; + } + + // TODO: get rid of this dynamic_cast + auto res_task = dynamic_cast(result.get()); + GGML_ASSERT(res_task != nullptr); + + // optionally return "fail_on_no_slot" error + if (!req.get_param("fail_on_no_slot").empty()) { + if (res_task->n_idle_slots == 0) { + res->error(format_error_response("no slot available", ERROR_TYPE_UNAVAILABLE)); + return res; + } + } + + res->ok(res_task->slots_data); + return res; + }; + + this->post_slots = [this](const server_http_req & req) { + auto res = std::make_unique(ctx_server); + if (params.slot_save_path.empty()) { + res->error(format_error_response("This server does not support slots action. Start it with `--slot-save-path`", ERROR_TYPE_NOT_SUPPORTED)); + return res; + } + + std::string id_slot_str = req.get_param("id_slot"); + int id_slot; + + try { + id_slot = std::stoi(id_slot_str); + } catch (const std::exception &) { + res->error(format_error_response("Invalid slot ID", ERROR_TYPE_INVALID_REQUEST)); + return res; + } + + std::string action = req.get_param("action"); + + if (action == "save") { + return handle_slots_save(req, id_slot); + } else if (action == "restore") { + return handle_slots_restore(req, id_slot); + } else if (action == "erase") { + return handle_slots_erase(req, id_slot); + } else { + res->error(format_error_response("Invalid action", ERROR_TYPE_INVALID_REQUEST)); + return res; + } + }; + + this->get_props = [this](const server_http_req &) { + auto res = std::make_unique(ctx_server); + json default_generation_settings_for_props; + + { + task_params params; + + params.sampling = ctx_server.params_base.sampling; + + default_generation_settings_for_props = json { + {"params", params.to_json(true)}, + {"n_ctx", ctx_server.get_slot_n_ctx()}, + }; + } + + // this endpoint is publicly available, please only return what is safe to be exposed + json data = { + { "default_generation_settings", default_generation_settings_for_props }, + { "total_slots", ctx_server.params_base.n_parallel }, + { "model_alias", ctx_server.params_base.model_alias }, + { "model_path", ctx_server.params_base.model.path }, + { "modalities", json { + {"vision", ctx_server.oai_parser_opt.allow_image}, + {"audio", ctx_server.oai_parser_opt.allow_audio}, + } }, + { "endpoint_slots", params.endpoint_slots }, + { "endpoint_props", params.endpoint_props }, + { "endpoint_metrics", params.endpoint_metrics }, + { "webui", params.webui }, + { "chat_template", common_chat_templates_source(ctx_server.chat_templates.get()) }, + { "bos_token", common_token_to_piece(ctx_server.ctx, llama_vocab_bos(ctx_server.vocab), /* special= */ true)}, + { "eos_token", common_token_to_piece(ctx_server.ctx, llama_vocab_eos(ctx_server.vocab), /* special= */ true)}, + { "build_info", build_info }, + }; + if (ctx_server.params_base.use_jinja) { + if (auto tool_use_src = common_chat_templates_source(ctx_server.chat_templates.get(), "tool_use")) { + data["chat_template_tool_use"] = tool_use_src; + } + } + + res->ok(data); + return res; + }; + + this->post_props = [this](const server_http_req &) { + auto res = std::make_unique(ctx_server); + if (!params.endpoint_props) { + res->error(format_error_response("This server does not support changing global properties. Start it with `--props`", ERROR_TYPE_NOT_SUPPORTED)); + return res; + } + // update any props here + + res->ok({{ "success", true }}); + return res; + }; + + this->get_api_show = [this](const server_http_req &) { + auto res = std::make_unique(ctx_server); + bool has_mtmd = ctx_server.mctx != nullptr; + json data = { + { + "template", common_chat_templates_source(ctx_server.chat_templates.get()), + }, + { + "model_info", { + { "llama.context_length", ctx_server.get_slot_n_ctx() }, + } + }, + {"modelfile", ""}, + {"parameters", ""}, + {"template", common_chat_templates_source(ctx_server.chat_templates.get())}, + {"details", { + {"parent_model", ""}, + {"format", "gguf"}, + {"family", ""}, + {"families", {""}}, + {"parameter_size", ""}, + {"quantization_level", ""} + }}, + {"model_info", ""}, + {"capabilities", has_mtmd ? json({"completion","multimodal"}) : json({"completion"})} + }; + + res->ok(data); + return res; + }; + + this->post_infill = [this](const server_http_req & req) { + auto res = std::make_unique(ctx_server); + // check model compatibility + std::string err; + if (llama_vocab_fim_pre(ctx_server.vocab) == LLAMA_TOKEN_NULL) { + err += "prefix token is missing. "; + } + if (llama_vocab_fim_suf(ctx_server.vocab) == LLAMA_TOKEN_NULL) { + err += "suffix token is missing. "; + } + if (llama_vocab_fim_mid(ctx_server.vocab) == LLAMA_TOKEN_NULL) { + err += "middle token is missing. "; + } + if (!err.empty()) { + res->error(format_error_response(string_format("Infill is not supported by this model: %s", err.c_str()), ERROR_TYPE_NOT_SUPPORTED)); + return res; + } + + // validate input + json data = json::parse(req.body); + if (data.contains("prompt") && !data.at("prompt").is_string()) { + // prompt is optional + res->error(format_error_response("\"prompt\" must be a string", ERROR_TYPE_INVALID_REQUEST)); + } + + if (!data.contains("input_prefix")) { + res->error(format_error_response("\"input_prefix\" is required", ERROR_TYPE_INVALID_REQUEST)); + } + + if (!data.contains("input_suffix")) { + res->error(format_error_response("\"input_suffix\" is required", ERROR_TYPE_INVALID_REQUEST)); + } + + if (data.contains("input_extra") && !data.at("input_extra").is_array()) { + // input_extra is optional + res->error(format_error_response("\"input_extra\" must be an array of {\"filename\": string, \"text\": string}", ERROR_TYPE_INVALID_REQUEST)); + return res; + } + + json input_extra = json_value(data, "input_extra", json::array()); + for (const auto & chunk : input_extra) { + // { "text": string, "filename": string } + if (!chunk.contains("text") || !chunk.at("text").is_string()) { + res->error(format_error_response("extra_context chunk must contain a \"text\" field with a string value", ERROR_TYPE_INVALID_REQUEST)); + return res; + } + // filename is optional + if (chunk.contains("filename") && !chunk.at("filename").is_string()) { + res->error(format_error_response("extra_context chunk's \"filename\" field must be a string", ERROR_TYPE_INVALID_REQUEST)); + return res; + } + } + data["input_extra"] = input_extra; // default to empty array if it's not exist + + std::string prompt = json_value(data, "prompt", std::string()); + std::vector tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, false, true); + SRV_DBG("creating infill tasks, n_prompts = %d\n", (int) tokenized_prompts.size()); + data["prompt"] = format_prompt_infill( + ctx_server.vocab, + data.at("input_prefix"), + data.at("input_suffix"), + data.at("input_extra"), + ctx_server.params_base.n_batch, + ctx_server.params_base.n_predict, + ctx_server.get_slot_n_ctx(), + ctx_server.params_base.spm_infill, + tokenized_prompts[0].get_text_tokens() // TODO: this could maybe be multimodal. + ); + + std::vector files; // dummy + return handle_completions_impl( + ctx_server, + SERVER_TASK_TYPE_INFILL, + data, + files, + req.should_stop, + TASK_RESPONSE_TYPE_NONE); // infill is not OAI compatible + }; + + this->post_completions = [this](const server_http_req & req) { + std::vector files; // dummy + const json body = json::parse(req.body); + return handle_completions_impl( + ctx_server, + SERVER_TASK_TYPE_COMPLETION, + body, + files, + req.should_stop, + TASK_RESPONSE_TYPE_NONE); + }; + + this->post_completions_oai = [this](const server_http_req & req) { + std::vector files; // dummy + const json body = json::parse(req.body); + return handle_completions_impl( + ctx_server, + SERVER_TASK_TYPE_COMPLETION, + body, + files, + req.should_stop, + TASK_RESPONSE_TYPE_OAI_CMPL); + }; + + this->post_chat_completions = [this](const server_http_req & req) { + std::vector files; + json body = json::parse(req.body); + json body_parsed = oaicompat_chat_params_parse( + body, + ctx_server.oai_parser_opt, + files); + return handle_completions_impl( + ctx_server, + SERVER_TASK_TYPE_COMPLETION, + body_parsed, + files, + req.should_stop, + TASK_RESPONSE_TYPE_OAI_CHAT); + }; + + this->post_anthropic_messages = [this](const server_http_req & req) { + std::vector files; + json body = convert_anthropic_to_oai(json::parse(req.body)); + json body_parsed = oaicompat_chat_params_parse( + body, + ctx_server.oai_parser_opt, + files); + return handle_completions_impl( + ctx_server, + SERVER_TASK_TYPE_COMPLETION, + body_parsed, + files, + req.should_stop, + TASK_RESPONSE_TYPE_ANTHROPIC); + }; + + this->post_anthropic_count_tokens = [this](const server_http_req & req) { + auto res = std::make_unique(ctx_server); + std::vector files; + json body = convert_anthropic_to_oai(json::parse(req.body)); + json body_parsed = oaicompat_chat_params_parse( + body, + ctx_server.oai_parser_opt, + files); + + json prompt = body_parsed.at("prompt"); + llama_tokens tokens = tokenize_mixed(ctx_server.vocab, prompt, true, true); + + res->ok({{"input_tokens", static_cast(tokens.size())}}); + return res; + }; + + // same with handle_chat_completions, but without inference part + this->post_apply_template = [this](const server_http_req & req) { + auto res = std::make_unique(ctx_server); + std::vector files; // dummy, unused + json body = json::parse(req.body); + json data = oaicompat_chat_params_parse( + body, + ctx_server.oai_parser_opt, + files); + res->ok({{ "prompt", std::move(data.at("prompt")) }}); + return res; + }; + + this->get_models = [this](const server_http_req &) { + auto res = std::make_unique(ctx_server); + json model_meta = nullptr; + if (is_ready()) { + model_meta = ctx_server.model_meta(); + } + bool has_mtmd = ctx_server.mctx != nullptr; + json models = { + {"models", { + { + {"name", params.model_alias.empty() ? params.model.path : params.model_alias}, + {"model", params.model_alias.empty() ? params.model.path : params.model_alias}, + {"modified_at", ""}, + {"size", ""}, + {"digest", ""}, // dummy value, llama.cpp does not support managing model file's hash + {"type", "model"}, + {"description", ""}, + {"tags", {""}}, + {"capabilities", has_mtmd ? json({"completion","multimodal"}) : json({"completion"})}, + {"parameters", ""}, + {"details", { + {"parent_model", ""}, + {"format", "gguf"}, + {"family", ""}, + {"families", {""}}, + {"parameter_size", ""}, + {"quantization_level", ""} + }} + } + }}, + {"object", "list"}, + {"data", { + { + {"id", params.model_alias.empty() ? params.model.path : params.model_alias}, + {"object", "model"}, + {"created", std::time(0)}, + {"owned_by", "llamacpp"}, + {"meta", model_meta}, + }, + }} + }; + + res->ok(models); + return res; + }; + + this->post_tokenize = [this](const server_http_req & req) { + auto res = std::make_unique(ctx_server); + const json body = json::parse(req.body); + json tokens_response = json::array(); + if (body.count("content") != 0) { + const bool add_special = json_value(body, "add_special", false); + const bool parse_special = json_value(body, "parse_special", true); + const bool with_pieces = json_value(body, "with_pieces", false); + + llama_tokens tokens = tokenize_mixed(ctx_server.vocab, body.at("content"), add_special, parse_special); + + if (with_pieces) { + for (const auto& token : tokens) { + std::string piece = common_token_to_piece(ctx_server.ctx, token); + json piece_json; + + // Check if the piece is valid UTF-8 + if (is_valid_utf8(piece)) { + piece_json = piece; + } else { + // If not valid UTF-8, store as array of byte values + piece_json = json::array(); + for (unsigned char c : piece) { + piece_json.push_back(static_cast(c)); + } + } + + tokens_response.push_back({ + {"id", token}, + {"piece", piece_json} + }); + } + } else { + tokens_response = tokens; + } + } + + res->ok(json{{"tokens", std::move(tokens_response)}}); + return res; + }; + + this->post_detokenize = [this](const server_http_req & req) { + auto res = std::make_unique(ctx_server); + const json body = json::parse(req.body); + + std::string content; + if (body.count("tokens") != 0) { + const llama_tokens tokens = body.at("tokens"); + content = tokens_to_str(ctx_server.ctx, tokens); + } + + res->ok(json{{"content", std::move(content)}}); + return res; + }; + + this->post_embeddings = [this](const server_http_req & req) { + return handle_embeddings_impl(req, TASK_RESPONSE_TYPE_NONE); + }; + + this->post_embeddings_oai = [this](const server_http_req & req) { + return handle_embeddings_impl(req, TASK_RESPONSE_TYPE_OAI_EMBD); + }; + + this->post_rerank = [this](const server_http_req & req) { + auto res = std::make_unique(ctx_server); + if (!ctx_server.params_base.embedding || ctx_server.params_base.pooling_type != LLAMA_POOLING_TYPE_RANK) { + res->error(format_error_response("This server does not support reranking. Start it with `--reranking`", ERROR_TYPE_NOT_SUPPORTED)); + return res; + } + + const json body = json::parse(req.body); + + // if true, use TEI API format, otherwise use Jina API format + // Jina: https://jina.ai/reranker/ + // TEI: https://huggingface.github.io/text-embeddings-inference/#/Text%20Embeddings%20Inference/rerank + bool is_tei_format = body.contains("texts"); + + json query; + if (body.count("query") == 1) { + query = body.at("query"); + if (!query.is_string()) { + res->error(format_error_response("\"query\" must be a string", ERROR_TYPE_INVALID_REQUEST)); + return res; + } + } else { + res->error(format_error_response("\"query\" must be provided", ERROR_TYPE_INVALID_REQUEST)); + return res; + } + + std::vector documents = json_value(body, "documents", + json_value(body, "texts", std::vector())); + if (documents.empty()) { + res->error(format_error_response("\"documents\" must be a non-empty string array", ERROR_TYPE_INVALID_REQUEST)); + return res; + } + + int top_n = json_value(body, "top_n", (int)documents.size()); + + // create and queue the task + json responses = json::array(); + server_response_reader rd({ctx_server.queue_tasks, ctx_server.queue_results}, HTTP_POLLING_SECONDS); + { + std::vector tasks; + tasks.reserve(documents.size()); + for (size_t i = 0; i < documents.size(); i++) { + auto tmp = format_prompt_rerank(ctx_server.model, ctx_server.vocab, ctx_server.mctx, query, documents[i]); + server_task task = server_task(SERVER_TASK_TYPE_RERANK); + task.id = ctx_server.queue_tasks.get_new_id(); + task.index = i; + task.tokens = std::move(tmp); + tasks.push_back(std::move(task)); + } + rd.post_tasks(std::move(tasks)); + } + + // wait for the results + auto all_results = rd.wait_for_all(req.should_stop); + + // collect results + if (all_results.is_terminated) { + return res; // connection is closed + } else if (all_results.error) { + res->error(all_results.error->to_json()); + return res; + } else { + for (auto & res : all_results.results) { + GGML_ASSERT(dynamic_cast(res.get()) != nullptr); + responses.push_back(res->to_json()); + } + } + + // write JSON response + json root = format_response_rerank( + body, + responses, + is_tei_format, + documents, + top_n); + + res->ok(root); + return res; + }; + + this->get_lora_adapters = [this](const server_http_req &) { + auto res = std::make_unique(ctx_server); + json result = json::array(); + const auto & loras = ctx_server.params_base.lora_adapters; + for (size_t i = 0; i < loras.size(); ++i) { + auto & lora = loras[i]; + json entry = { + {"id", i}, + {"path", lora.path}, + {"scale", lora.scale}, + {"task_name", lora.task_name}, + {"prompt_prefix", lora.prompt_prefix}, + }; + std::string alora_invocation_string = ""; + const uint64_t n_alora_tokens = llama_adapter_get_alora_n_invocation_tokens(lora.ptr); + std::vector alora_invocation_tokens; + if (n_alora_tokens) { + const llama_token * alora_tokens = llama_adapter_get_alora_invocation_tokens(lora.ptr); + for (uint64_t i = 0; i < n_alora_tokens; ++i) { + alora_invocation_string += common_token_to_piece(ctx_server.ctx, alora_tokens[i]); + alora_invocation_tokens.push_back(alora_tokens[i]); + } + entry["alora_invocation_string"] = alora_invocation_string; + entry["alora_invocation_tokens"] = alora_invocation_tokens; + } + result.push_back(std::move(entry)); + } + res->ok(result); + return res; + }; + + this->post_lora_adapters = [this](const server_http_req & req) { + auto res = std::make_unique(ctx_server); + const json body = json::parse(req.body); + if (!body.is_array()) { + res->error(format_error_response("Request body must be an array", ERROR_TYPE_INVALID_REQUEST)); + return res; + } + + int task_id = ctx_server.queue_tasks.get_new_id(); + { + server_task task(SERVER_TASK_TYPE_SET_LORA); + task.id = task_id; + task.set_lora = parse_lora_request(ctx_server.params_base.lora_adapters, body); + ctx_server.queue_results.add_waiting_task_id(task_id); + ctx_server.queue_tasks.post(std::move(task)); + } + + // get the result + server_task_result_ptr result = ctx_server.queue_results.recv(task_id); + ctx_server.queue_results.remove_waiting_task_id(task_id); + + if (result->is_error()) { + res->error(result->to_json()); + return res; + } + + GGML_ASSERT(dynamic_cast(result.get()) != nullptr); + res->ok(result->to_json()); + return res; + }; +} + +std::unique_ptr server_routes::handle_slots_save(const server_http_req & req, int id_slot) { + auto res = std::make_unique(ctx_server); + const json request_data = json::parse(req.body); + std::string filename = request_data.at("filename"); + if (!fs_validate_filename(filename)) { + res->error(format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST)); + return res; + } + std::string filepath = params.slot_save_path + filename; + + int task_id = ctx_server.queue_tasks.get_new_id(); + { + server_task task(SERVER_TASK_TYPE_SLOT_SAVE); + task.id = task_id; + task.slot_action.slot_id = id_slot; + task.slot_action.filename = filename; + task.slot_action.filepath = filepath; + + // TODO: use server_response_reader + ctx_server.queue_results.add_waiting_task_id(task_id); + ctx_server.queue_tasks.post(std::move(task)); + } + + server_task_result_ptr result = ctx_server.queue_results.recv(task_id); + ctx_server.queue_results.remove_waiting_task_id(task_id); + + if (result->is_error()) { + res->error(result->to_json()); + return res; + } + + res->ok(result->to_json()); + return res; +} + +std::unique_ptr server_routes::handle_slots_restore(const server_http_req & req, int id_slot) { + auto res = std::make_unique(ctx_server); + const json request_data = json::parse(req.body); + std::string filename = request_data.at("filename"); + if (!fs_validate_filename(filename)) { + res->error(format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST)); + return res; + } + std::string filepath = params.slot_save_path + filename; + + int task_id = ctx_server.queue_tasks.get_new_id(); + { + server_task task(SERVER_TASK_TYPE_SLOT_RESTORE); + task.id = task_id; + task.slot_action.slot_id = id_slot; + task.slot_action.filename = filename; + task.slot_action.filepath = filepath; + + // TODO: use server_response_reader + ctx_server.queue_results.add_waiting_task_id(task_id); + ctx_server.queue_tasks.post(std::move(task)); + } + + server_task_result_ptr result = ctx_server.queue_results.recv(task_id); + ctx_server.queue_results.remove_waiting_task_id(task_id); + + if (result->is_error()) { + res->error(result->to_json()); + return res; + } + + GGML_ASSERT(dynamic_cast(result.get()) != nullptr); + res->ok(result->to_json()); + return res; +} + +std::unique_ptr server_routes::handle_slots_erase(const server_http_req &, int id_slot) { + auto res = std::make_unique(ctx_server); + int task_id = ctx_server.queue_tasks.get_new_id(); + { + server_task task(SERVER_TASK_TYPE_SLOT_ERASE); + task.id = task_id; + task.slot_action.slot_id = id_slot; + + // TODO: use server_response_reader + ctx_server.queue_results.add_waiting_task_id(task_id); + ctx_server.queue_tasks.post(std::move(task)); + } + + server_task_result_ptr result = ctx_server.queue_results.recv(task_id); + ctx_server.queue_results.remove_waiting_task_id(task_id); + + if (result->is_error()) { + res->error(result->to_json()); + return res; + } + + GGML_ASSERT(dynamic_cast(result.get()) != nullptr); + res->ok(result->to_json()); + return res; +} + +std::unique_ptr server_routes::handle_embeddings_impl(const server_http_req & req, task_response_type res_type) { + auto res = std::make_unique(ctx_server); + if (!ctx_server.params_base.embedding) { + res->error(format_error_response("This server does not support embeddings. Start it with `--embeddings`", ERROR_TYPE_NOT_SUPPORTED)); + return res; + } + + if (res_type != TASK_RESPONSE_TYPE_NONE && llama_pooling_type(ctx_server.ctx) == LLAMA_POOLING_TYPE_NONE) { + res->error(format_error_response("Pooling type 'none' is not OAI compatible. Please use a different pooling type", ERROR_TYPE_INVALID_REQUEST)); + return res; + } + + const json body = json::parse(req.body); + + // for the shape of input/content, see tokenize_input_prompts() + json prompt; + if (body.count("input") != 0) { + prompt = body.at("input"); + } else if (body.contains("content")) { + res_type = TASK_RESPONSE_TYPE_NONE; // "content" field is not OAI compatible + prompt = body.at("content"); + } else { + res->error(format_error_response("\"input\" or \"content\" must be provided", ERROR_TYPE_INVALID_REQUEST)); + return res; + } + + bool use_base64 = false; + if (body.count("encoding_format") != 0) { + const std::string& format = body.at("encoding_format"); + if (format == "base64") { + use_base64 = true; + } else if (format != "float") { + res->error(format_error_response("The format to return the embeddings in. Can be either float or base64", ERROR_TYPE_INVALID_REQUEST)); + return res; + } + } + + auto tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, true, true); + for (const auto & tokens : tokenized_prompts) { + // this check is necessary for models that do not add BOS token to the input + if (tokens.empty()) { + res->error(format_error_response("Input content cannot be empty", ERROR_TYPE_INVALID_REQUEST)); + return res; + } + } + + int embd_normalize = 2; // default to Euclidean/L2 norm + if (body.count("embd_normalize") != 0) { + embd_normalize = body.at("embd_normalize"); + if (llama_pooling_type(ctx_server.ctx) == LLAMA_POOLING_TYPE_NONE) { + SRV_DBG("embd_normalize is not supported by pooling type %d, ignoring it\n", llama_pooling_type(ctx_server.ctx)); + } + } + + // create and queue the task + json responses = json::array(); + server_response_reader rd({ctx_server.queue_tasks, ctx_server.queue_results}, HTTP_POLLING_SECONDS); + { + std::vector tasks; + for (size_t i = 0; i < tokenized_prompts.size(); i++) { + server_task task = server_task(SERVER_TASK_TYPE_EMBEDDING); + + task.id = ctx_server.queue_tasks.get_new_id(); + task.index = i; + task.tokens = std::move(tokenized_prompts[i]); + + // OAI-compat + task.params.res_type = res_type; + task.params.embd_normalize = embd_normalize; + + tasks.push_back(std::move(task)); + } + rd.post_tasks(std::move(tasks)); + } + + // wait for the results + auto all_results = rd.wait_for_all(req.should_stop); + + // collect results + if (all_results.is_terminated) { + return res; // connection is closed + } else if (all_results.error) { + res->error(all_results.error->to_json()); + return res; + } else { + for (auto & res : all_results.results) { + GGML_ASSERT(dynamic_cast(res.get()) != nullptr); + responses.push_back(res->to_json()); + } + } + + // write JSON response + json root = res_type == TASK_RESPONSE_TYPE_OAI_EMBD + ? format_embeddings_response_oaicompat(body, responses, use_base64) + : json(responses); + res->ok(root); + return res; +} diff --git a/tools/server/server-context.h b/tools/server/server-context.h new file mode 100644 index 0000000000..05b4afaeeb --- /dev/null +++ b/tools/server/server-context.h @@ -0,0 +1,83 @@ +#include "server-http.h" +#include "server-task.h" +#include "server-queue.h" + +#include + +#include +#include + +struct server_context_impl; // private implementation + +struct server_context { + std::unique_ptr impl; + + server_context(); + ~server_context(); + + // initialize slots and server-related data + void init(); + + // load the model and initialize llama_context + // returns true on success + bool load_model(const common_params & params); + + // this function will block main thread until termination + void start_loop(); + + // terminate main loop (will unblock start_loop) + void terminate(); + + // get the underlaying llama_context + llama_context * get_llama_context() const; + + // get the underlaying queue_tasks and queue_results + // used by CLI application + std::pair get_queues(); +}; + + +// forward declarations +struct server_res_generator; + +struct server_routes { + server_routes(const common_params & params, server_context & ctx_server, std::function is_ready = []() { return true; }) + : params(params), ctx_server(*ctx_server.impl), is_ready(is_ready) { + init_routes(); + } + + void init_routes(); + // handlers using lambda function, so that they can capture `this` without `std::bind` + server_http_context::handler_t get_health; + server_http_context::handler_t get_metrics; + server_http_context::handler_t get_slots; + server_http_context::handler_t post_slots; + server_http_context::handler_t get_props; + server_http_context::handler_t post_props; + server_http_context::handler_t get_api_show; + server_http_context::handler_t post_infill; + server_http_context::handler_t post_completions; + server_http_context::handler_t post_completions_oai; + server_http_context::handler_t post_chat_completions; + server_http_context::handler_t post_anthropic_messages; + server_http_context::handler_t post_anthropic_count_tokens; + server_http_context::handler_t post_apply_template; + server_http_context::handler_t get_models; + server_http_context::handler_t post_tokenize; + server_http_context::handler_t post_detokenize; + server_http_context::handler_t post_embeddings; + server_http_context::handler_t post_embeddings_oai; + server_http_context::handler_t post_rerank; + server_http_context::handler_t get_lora_adapters; + server_http_context::handler_t post_lora_adapters; +private: + // TODO: move these outside of server_routes? + std::unique_ptr handle_slots_save(const server_http_req & req, int id_slot); + std::unique_ptr handle_slots_restore(const server_http_req & req, int id_slot); + std::unique_ptr handle_slots_erase(const server_http_req &, int id_slot); + std::unique_ptr handle_embeddings_impl(const server_http_req & req, task_response_type res_type); + + const common_params & params; + server_context_impl & ctx_server; + std::function is_ready; +}; diff --git a/tools/server/server-queue.cpp b/tools/server/server-queue.cpp index 65c8a0a9ae..38a4858522 100644 --- a/tools/server/server-queue.cpp +++ b/tools/server/server-queue.cpp @@ -266,3 +266,86 @@ void server_response::terminate() { running = false; condition_results.notify_all(); } + +// +// server_response_reader +// + +void server_response_reader::post_tasks(std::vector && tasks) { + id_tasks = server_task::get_list_id(tasks); + queue_results.add_waiting_tasks(tasks); + queue_tasks.post(std::move(tasks)); +} + +bool server_response_reader::has_next() const { + return !cancelled && received_count < id_tasks.size(); +} + +// return nullptr if should_stop() is true before receiving a result +// note: if one error is received, it will stop further processing and return error result +server_task_result_ptr server_response_reader::next(const std::function & should_stop) { + while (true) { + server_task_result_ptr result = queue_results.recv_with_timeout(id_tasks, polling_interval_seconds); + if (result == nullptr) { + // timeout, check stop condition + if (should_stop()) { + SRV_DBG("%s", "stopping wait for next result due to should_stop condition\n"); + return nullptr; + } + } else { + if (result->is_error()) { + stop(); // cancel remaining tasks + SRV_DBG("%s", "received error result, stopping further processing\n"); + return result; + } + if (result->is_stop()) { + received_count++; + } + return result; + } + } + + // should not reach here +} + +server_response_reader::batch_response server_response_reader::wait_for_all(const std::function & should_stop) { + batch_response batch_res; + batch_res.results.resize(id_tasks.size()); + while (has_next()) { + auto res = next(should_stop); + if (res == nullptr) { + batch_res.is_terminated = true; + return batch_res; + } + if (res->is_error()) { + batch_res.error = std::move(res); + return batch_res; + } + const size_t idx = res->get_index(); + GGML_ASSERT(idx < batch_res.results.size() && "index out of range"); + GGML_ASSERT(batch_res.results[idx] == nullptr && "duplicate result received"); + batch_res.results[idx] = std::move(res); + } + return batch_res; +} + +void server_response_reader::stop() { + queue_results.remove_waiting_task_ids(id_tasks); + if (has_next() && !cancelled) { + // if tasks is not finished yet, cancel them + cancelled = true; + std::vector cancel_tasks; + cancel_tasks.reserve(id_tasks.size()); + for (const auto & id_task : id_tasks) { + SRV_WRN("cancel task, id_task = %d\n", id_task); + server_task task(SERVER_TASK_TYPE_CANCEL); + task.id_target = id_task; + queue_results.remove_waiting_task_id(id_task); + cancel_tasks.push_back(std::move(task)); + } + // push to beginning of the queue, so it has highest priority + queue_tasks.post(std::move(cancel_tasks), true); + } else { + SRV_DBG("%s", "all tasks already finished, no need to cancel\n"); + } +} diff --git a/tools/server/server-queue.h b/tools/server/server-queue.h index 47ef58425e..209d2017c7 100644 --- a/tools/server/server-queue.h +++ b/tools/server/server-queue.h @@ -108,3 +108,39 @@ public: // terminate the waiting loop void terminate(); }; + +// utility class to make working with server_queue and server_response easier +// it provides a generator-like API for server responses +// support pooling connection state and aggregating multiple results +struct server_response_reader { + std::unordered_set id_tasks; + server_queue & queue_tasks; + server_response & queue_results; + size_t received_count = 0; + bool cancelled = false; + int polling_interval_seconds; + + // should_stop function will be called each polling_interval_seconds + server_response_reader(std::pair server_queues, int polling_interval_seconds) + : queue_tasks(server_queues.first), queue_results(server_queues.second), polling_interval_seconds(polling_interval_seconds) {} + ~server_response_reader() { + stop(); + } + + void post_tasks(std::vector && tasks); + bool has_next() const; + + // return nullptr if should_stop() is true before receiving a result + // note: if one error is received, it will stop further processing and return error result + server_task_result_ptr next(const std::function & should_stop); + + struct batch_response { + bool is_terminated = false; // if true, indicates that processing was stopped before all results were received + std::vector results; + server_task_result_ptr error; // nullptr if no error + }; + // aggregate multiple results + batch_response wait_for_all(const std::function & should_stop); + + void stop(); +}; diff --git a/tools/server/server.cpp b/tools/server/server.cpp index 96b2df27f7..5256790db2 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -1,3662 +1,19 @@ -#include "server-common.h" +#include "server-context.h" #include "server-http.h" -#include "server-task.h" -#include "server-queue.h" #include "arg.h" #include "common.h" #include "llama.h" #include "log.h" -#include "sampling.h" -#include "speculative.h" -#include "mtmd.h" -#include "mtmd-helper.h" #include -#include -#include -#include #include -#include -#include +#include // for std::thread::hardware_concurrency -// fix problem with std::min and std::max #if defined(_WIN32) -#define WIN32_LEAN_AND_MEAN -#ifndef NOMINMAX -# define NOMINMAX -#endif #include #endif -using json = nlohmann::ordered_json; - -constexpr int HTTP_POLLING_SECONDS = 1; - -// state diagram: https://github.com/ggml-org/llama.cpp/pull/9283 -enum slot_state { - SLOT_STATE_IDLE, - SLOT_STATE_STARTED, // TODO: this state is only used for setting up the initial prompt processing; maybe merge it with launch_slot_with_task in the future - SLOT_STATE_PROCESSING_PROMPT, - SLOT_STATE_DONE_PROMPT, - SLOT_STATE_GENERATING, -}; - -enum server_state { - SERVER_STATE_LOADING_MODEL, // Server is starting up, model not fully loaded yet - SERVER_STATE_READY, // Server is ready and model is loaded -}; - -static bool server_task_type_need_embd(server_task_type task_type) { - switch (task_type) { - case SERVER_TASK_TYPE_EMBEDDING: - case SERVER_TASK_TYPE_RERANK: - return true; - default: - return false; - } -} - -static bool server_task_type_need_logits(server_task_type task_type) { - switch (task_type) { - case SERVER_TASK_TYPE_COMPLETION: - case SERVER_TASK_TYPE_INFILL: - return true; - default: - return false; - } -} - -struct server_slot { - int id; - - llama_batch batch_spec = {}; - - // TODO: change to unique_ptrs for consistency: - llama_context * ctx = nullptr; - llama_context * ctx_dft = nullptr; - - // multimodal - mtmd_context * mctx = nullptr; - - common_speculative * spec = nullptr; - - std::unique_ptr task; - std::unique_ptr task_prev; // used for debugging - - // used to determine the slot that has been used the longest - int64_t t_last_used = -1; - - // generation props - int32_t n_ctx = 0; // context size per slot - int32_t n_keep = 0; - int32_t n_decoded = 0; - int32_t n_remaining = -1; - int32_t i_batch = -1; - - int32_t n_prompt_tokens_cache = 0; - int32_t n_prompt_tokens_processed = 0; - - size_t last_nl_pos = 0; - - std::string generated_text; - llama_tokens generated_tokens; - - common_chat_msg chat_msg; - - std::vector generated_token_probs; - - bool has_next_token = true; - bool has_new_line = false; - bool truncated = false; - - stop_type stop; - - std::string stopping_word; - - // state - slot_state state = SLOT_STATE_IDLE; - - server_prompt prompt; - - void prompt_save(server_prompt_cache & prompt_cache) const { - GGML_ASSERT(prompt.data.size() == 0); - - const size_t cur_size = llama_state_seq_get_size_ext(ctx, id, 0); - - SRV_WRN(" - saving prompt with length %d, total state size = %.3f MiB\n", - (int) prompt.tokens.size(), cur_size / (1024.0 * 1024.0)); - - auto * cur = prompt_cache.alloc(prompt, cur_size); - if (cur == nullptr) { - return; - } - - llama_state_seq_get_data_ext(ctx, cur->data.data(), cur_size, id, 0); - } - - bool prompt_load(server_prompt_cache & prompt_cache, const server_tokens & tokens) { - bool res = prompt_cache.load(prompt, tokens, ctx, id); - if (!res) { - SLT_WRN(*this, "%s", "failed to load prompt from cache\n"); - } - - return res; - } - - std::vector lora; - int32_t alora_invocation_start = -1; - - // sampling - json json_schema; - - struct common_sampler * smpl = nullptr; - - llama_token sampled; - - common_chat_format chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY; - std::vector generated_tool_call_ids; - - // stats - size_t n_sent_text = 0; // number of sent text character - - int64_t t_start_process_prompt; - int64_t t_start_generation; - - double t_prompt_processing; // ms - double t_token_generation; // ms - - std::function callback_on_release; - - // Speculative decoding stats - int32_t n_draft_total = 0; // Total draft tokens generated - int32_t n_draft_accepted = 0; // Draft tokens actually accepted - - void reset() { - SLT_DBG(*this, "%s", "\n"); - - n_prompt_tokens_cache = 0; - - last_nl_pos = 0; - generated_text = ""; - has_new_line = false; - truncated = false; - stop = STOP_TYPE_NONE; - stopping_word = ""; - n_sent_text = 0; - chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY; - - generated_tokens.clear(); - generated_token_probs.clear(); - chat_msg = {}; - json_schema = json(); - generated_tool_call_ids.clear(); - - // clear speculative decoding stats - n_draft_total = 0; - n_draft_accepted = 0; - - task.reset(); - task_prev.reset(); - - // clear alora start - alora_invocation_start = -1; - } - - bool need_embd() const { - GGML_ASSERT(task); - - return server_task_type_need_embd(task->type); - } - - bool need_logits() const { - GGML_ASSERT(task); - - return server_task_type_need_logits(task->type); - } - - // if the context does not have a memory module then all embeddings have to be computed within a single ubatch - // also we cannot split if the pooling would require any past tokens - bool can_split() const { - return - !need_embd() || - (llama_get_memory(ctx) && llama_pooling_type(ctx) == LLAMA_POOLING_TYPE_LAST); - } - - bool can_batch_with(server_slot & other_slot) const { - GGML_ASSERT(task); - - return task->type == other_slot.task->type && are_lora_equal(lora, other_slot.lora); - } - - bool has_budget(const common_params & global_params) { - GGML_ASSERT(task); - - if (task->params.n_predict == -1 && global_params.n_predict == -1) { - return true; // limitless - } - - n_remaining = -1; - - if (task->params.n_predict != -1) { - n_remaining = task->params.n_predict - n_decoded; - } else if (global_params.n_predict != -1) { - n_remaining = global_params.n_predict - n_decoded; - } - - return n_remaining > 0; // no budget - } - - bool is_processing() const { - return state != SLOT_STATE_IDLE; - } - - bool can_speculate() const { - return ctx_dft; - } - - void add_token(const completion_token_output & token) { - if (!is_processing()) { - SLT_WRN(*this, "%s", "slot is not processing\n"); - return; - } - generated_token_probs.push_back(token); - } - - void release() { - if (is_processing()) { - GGML_ASSERT(task); - - SLT_INF(*this, "stop processing: n_tokens = %d, truncated = %d\n", prompt.n_tokens(), truncated); - - t_last_used = ggml_time_us(); - t_token_generation = (ggml_time_us() - t_start_generation) / 1e3; - state = SLOT_STATE_IDLE; - - task_prev = std::move(task); - task.reset(); - - callback_on_release(id); - } - } - - result_timings get_timings() const { - result_timings timings; - timings.cache_n = n_prompt_tokens_cache; - - timings.prompt_n = n_prompt_tokens_processed; - timings.prompt_ms = t_prompt_processing; - timings.prompt_per_token_ms = t_prompt_processing / n_prompt_tokens_processed; - timings.prompt_per_second = 1e3 / t_prompt_processing * n_prompt_tokens_processed; - - timings.predicted_n = n_decoded; - timings.predicted_ms = t_token_generation; - timings.predicted_per_token_ms = t_token_generation / n_decoded; - timings.predicted_per_second = 1e3 / t_token_generation * n_decoded; - - // Add speculative metrics - if (n_draft_total > 0) { - timings.draft_n = n_draft_total; - timings.draft_n_accepted = n_draft_accepted; - } - - return timings; - } - - const common_chat_msg & update_chat_msg(std::vector & diffs) { - GGML_ASSERT(task); - - auto previous_msg = chat_msg; - SRV_DBG("Parsing chat message: %s\n", generated_text.c_str()); - auto new_msg = common_chat_parse( - generated_text, - /* is_partial= */ stop != STOP_TYPE_EOS, - task->params.oaicompat_chat_syntax); - if (!new_msg.empty()) { - new_msg.set_tool_call_ids(generated_tool_call_ids, gen_tool_call_id); - chat_msg = new_msg; - diffs = common_chat_msg_diff::compute_diffs(previous_msg, new_msg.empty() ? previous_msg : new_msg); - } - return chat_msg; - } - - size_t find_stopping_strings(const std::string & text, const size_t last_token_size, bool is_full_stop) { - GGML_ASSERT(task); - - size_t stop_pos = std::string::npos; - - for (const std::string & word : task->params.antiprompt) { - size_t pos; - - if (is_full_stop) { - const size_t tmp = word.size() + last_token_size; - const size_t from_pos = text.size() > tmp ? text.size() - tmp : 0; - - pos = text.find(word, from_pos); - } else { - // otherwise, partial stop - pos = string_find_partial_stop(text, word); - } - - if (pos != std::string::npos && (stop_pos == std::string::npos || pos < stop_pos)) { - if (is_full_stop) { - stop = STOP_TYPE_WORD; - stopping_word = word; - has_next_token = false; - } - stop_pos = pos; - } - } - - return stop_pos; - } - - void print_timings() const { - const double t_prompt = t_prompt_processing / n_prompt_tokens_processed; - const double n_prompt_second = 1e3 / t_prompt_processing * n_prompt_tokens_processed; - - const double t_gen = t_token_generation / n_decoded; - const double n_gen_second = 1e3 / t_token_generation * n_decoded; - - SLT_INF(*this, - "\n" - "prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n" - " eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n" - " total time = %10.2f ms / %5d tokens\n", - t_prompt_processing, n_prompt_tokens_processed, t_prompt, n_prompt_second, - t_token_generation, n_decoded, t_gen, n_gen_second, - t_prompt_processing + t_token_generation, n_prompt_tokens_processed + n_decoded); - - if (n_draft_total > 0) { - const float draft_ratio = (float) n_draft_accepted / n_draft_total; - SLT_INF(*this, - "\n" - "draft acceptance rate = %0.5f (%5d accepted / %5d generated)\n", - draft_ratio, n_draft_accepted, n_draft_total - ); - } - } - - json to_json(bool only_metrics = false) const { - json res; - - res = { - {"id", id}, - {"n_ctx", n_ctx}, - {"speculative", can_speculate()}, - {"is_processing", is_processing()}, - }; - - const auto & ptask = task ? task : task_prev; - - if (ptask) { - res["id_task"] = ptask->id; - res["params"] = ptask->params.to_json(only_metrics); - res["next_token"] = { - { - {"has_next_token", has_next_token}, - {"has_new_line", has_new_line}, - {"n_remain", n_remaining}, - {"n_decoded", n_decoded}, - } - }; - - if (!only_metrics) { - res["prompt"] = ptask->tokens.detokenize(ctx, true); - res["generated"] = generated_text; - } - } - - return res; - } -}; - -struct server_metrics { - int64_t t_start = 0; - - uint64_t n_prompt_tokens_processed_total = 0; - uint64_t t_prompt_processing_total = 0; - uint64_t n_tokens_predicted_total = 0; - uint64_t t_tokens_generation_total = 0; - - uint64_t n_tokens_max = 0; - - uint64_t n_prompt_tokens_processed = 0; - uint64_t t_prompt_processing = 0; - - uint64_t n_tokens_predicted = 0; - uint64_t t_tokens_generation = 0; - - uint64_t n_decode_total = 0; - uint64_t n_busy_slots_total = 0; - - void init() { - t_start = ggml_time_us(); - } - - void on_prompt_eval(const server_slot & slot) { - n_prompt_tokens_processed_total += slot.n_prompt_tokens_processed; - n_prompt_tokens_processed += slot.n_prompt_tokens_processed; - t_prompt_processing += slot.t_prompt_processing; - t_prompt_processing_total += slot.t_prompt_processing; - - n_tokens_max = std::max(n_tokens_max, (uint64_t) slot.prompt.n_tokens()); - } - - void on_prediction(const server_slot & slot) { - n_tokens_predicted_total += slot.n_decoded; - n_tokens_predicted += slot.n_decoded; - t_tokens_generation += slot.t_token_generation; - t_tokens_generation_total += slot.t_token_generation; - } - - void on_decoded(const std::vector & slots) { - n_decode_total++; - for (const auto & slot : slots) { - if (slot.is_processing()) { - n_busy_slots_total++; - } - n_tokens_max = std::max(n_tokens_max, (uint64_t) slot.prompt.n_tokens()); - } - } - - void reset_bucket() { - n_prompt_tokens_processed = 0; - t_prompt_processing = 0; - n_tokens_predicted = 0; - t_tokens_generation = 0; - } -}; - -struct server_context { - common_params params_base; - - // note: keep these alive - they determine the lifetime of the model, context, etc. - common_init_result llama_init; - common_init_result llama_init_dft; - - llama_model * model = nullptr; - llama_context * ctx = nullptr; - - // multimodal - mtmd_context * mctx = nullptr; - - const llama_vocab * vocab = nullptr; - bool vocab_dft_compatible = true; - - llama_model * model_dft = nullptr; - - llama_context_params cparams_dft; - - llama_batch batch {}; - - bool add_bos_token = true; - - int32_t n_ctx; // total context for all clients / slots - - // slots / clients - std::vector slots; - - int slots_debug = 0; - - server_queue queue_tasks; - server_response queue_results; - - std::unique_ptr prompt_cache; - - server_metrics metrics; - - // Necessary similarity of prompt for slot selection - float slot_prompt_similarity = 0.0f; - - common_chat_templates_ptr chat_templates; - oaicompat_parser_options oai_parser_opt; - - ~server_context() { - mtmd_free(mctx); - - // Clear any sampling context - for (server_slot & slot : slots) { - common_sampler_free(slot.smpl); - slot.smpl = nullptr; - - llama_free(slot.ctx_dft); - slot.ctx_dft = nullptr; - - common_speculative_free(slot.spec); - slot.spec = nullptr; - - llama_batch_free(slot.batch_spec); - } - - llama_batch_free(batch); - } - - // load the model and initialize llama_context - bool load_model(const common_params & params) { - SRV_INF("loading model '%s'\n", params.model.path.c_str()); - - params_base = params; - - llama_init = common_init_from_params(params_base); - - model = llama_init.model.get(); - ctx = llama_init.context.get(); - - if (model == nullptr) { - SRV_ERR("failed to load model, '%s'\n", params_base.model.path.c_str()); - return false; - } - - vocab = llama_model_get_vocab(model); - - n_ctx = llama_n_ctx(ctx); - - add_bos_token = llama_vocab_get_add_bos(vocab); - - if (params_base.has_speculative()) { - SRV_INF("loading draft model '%s'\n", params_base.speculative.model.path.c_str()); - - auto params_dft = params_base; - - params_dft.devices = params_base.speculative.devices; - params_dft.model = params_base.speculative.model; - params_dft.n_ctx = params_base.speculative.n_ctx == 0 ? llama_n_ctx_seq(ctx) : params_base.speculative.n_ctx; - params_dft.n_gpu_layers = params_base.speculative.n_gpu_layers; - params_dft.n_parallel = 1; - params_dft.cache_type_k = params_base.speculative.cache_type_k; - params_dft.cache_type_v = params_base.speculative.cache_type_v; - - params_dft.cpuparams.n_threads = params_base.speculative.cpuparams.n_threads; - params_dft.cpuparams_batch.n_threads = params_base.speculative.cpuparams_batch.n_threads; - params_dft.tensor_buft_overrides = params_base.speculative.tensor_buft_overrides; - - llama_init_dft = common_init_from_params(params_dft); - - model_dft = llama_init_dft.model.get(); - - if (model_dft == nullptr) { - SRV_ERR("failed to load draft model, '%s'\n", params_base.speculative.model.path.c_str()); - return false; - } - - vocab_dft_compatible = common_speculative_are_compatible(ctx, llama_init_dft.context.get()); - if (!vocab_dft_compatible) { - SRV_INF("the draft model '%s' is not compatible with the target model '%s'. tokens will be translated between the draft and target models.\n", params_base.speculative.model.path.c_str(), params_base.model.path.c_str()); - } - - const int n_ctx_dft = llama_n_ctx(llama_init_dft.context.get()); - - cparams_dft = common_context_params_to_llama(params_dft); - cparams_dft.n_batch = n_ctx_dft; - - // the context is not needed - we will create one for each slot - llama_init_dft.context.reset(); - } - - chat_templates = common_chat_templates_init(model, params_base.chat_template); - try { - common_chat_format_example(chat_templates.get(), params.use_jinja, params.default_template_kwargs); - } catch (const std::exception & e) { - SRV_WRN("%s: Chat template parsing error: %s\n", __func__, e.what()); - SRV_WRN("%s: The chat template that comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses\n", __func__); - chat_templates = common_chat_templates_init(model, "chatml"); - } - - std::string & mmproj_path = params_base.mmproj.path; - if (!mmproj_path.empty()) { - mtmd_helper_log_set(common_log_default_callback, nullptr); - - mtmd_context_params mparams = mtmd_context_params_default(); - mparams.use_gpu = params_base.mmproj_use_gpu; - mparams.print_timings = false; - mparams.n_threads = params_base.cpuparams.n_threads; - mparams.flash_attn_type = params_base.flash_attn_type; - mparams.image_min_tokens = params_base.image_min_tokens; - mparams.image_max_tokens = params_base.image_max_tokens; - mctx = mtmd_init_from_file(mmproj_path.c_str(), model, mparams); - if (mctx == nullptr) { - SRV_ERR("failed to load multimodal model, '%s'\n", mmproj_path.c_str()); - return false; - } - SRV_INF("loaded multimodal model, '%s'\n", mmproj_path.c_str()); - - if (params_base.ctx_shift) { - params_base.ctx_shift = false; - SRV_WRN("%s\n", "ctx_shift is not supported by multimodal, it will be disabled"); - } - - if (params_base.n_cache_reuse) { - params_base.n_cache_reuse = 0; - SRV_WRN("%s\n", "cache_reuse is not supported by multimodal, it will be disabled"); - } - - if (params_base.has_speculative()) { - SRV_ERR("%s\n", "err: speculative decode is not supported by multimodal"); - return false; - } - } - - if (!llama_memory_can_shift(llama_get_memory(ctx))) { - if (params_base.ctx_shift) { - params_base.ctx_shift = false; - SRV_WRN("%s\n", "ctx_shift is not supported by this context, it will be disabled"); - } - - if (params_base.n_cache_reuse) { - params_base.n_cache_reuse = 0; - SRV_WRN("%s\n", "cache_reuse is not supported by this context, it will be disabled"); - } - } - - return true; - } - - // initialize slots and server-related data - void init() { - SRV_INF("initializing slots, n_slots = %d\n", params_base.n_parallel); - - const int n_ctx_train = llama_model_n_ctx_train(model); - - int n_ctx_slot = llama_n_ctx_seq(ctx); - if (n_ctx_slot > n_ctx_train) { - SRV_WRN("the slot context (%d) exceeds the training context of the model (%d) - capping\n", n_ctx_slot, n_ctx_train); - n_ctx_slot = n_ctx_train; - } - - for (int i = 0; i < params_base.n_parallel; i++) { - server_slot slot; - - slot.id = i; - slot.ctx = ctx; - slot.n_ctx = n_ctx_slot; - slot.mctx = mctx; - slot.prompt.tokens.has_mtmd = mctx != nullptr; - - if (model_dft) { - slot.batch_spec = llama_batch_init(params_base.speculative.n_max + 1, 0, 1); - - // TODO: rework speculative decoding [TAG_SERVER_SPEC_REWORK] - slot.ctx_dft = llama_init_from_model(model_dft, cparams_dft); - if (slot.ctx_dft == nullptr) { - SRV_ERR("%s", "failed to create draft context\n"); - return; - } - - slot.spec = common_speculative_init(slot.ctx, slot.ctx_dft); - if (slot.spec == nullptr) { - SRV_ERR("%s", "failed to create speculator\n"); - return; - } - for (auto & pair : params_base.speculative.replacements) { - common_speculative_add_replacement_tgt_dft(slot.spec, pair.first.c_str(), pair.second.c_str()); - } - } - - SLT_INF(slot, "new slot, n_ctx = %d\n", slot.n_ctx); - - slot.callback_on_release = [this](int) { - queue_tasks.pop_deferred_task(); - }; - - slot.reset(); - - slots.push_back(std::move(slot)); - } - - { - const char * LLAMA_SERVER_SLOTS_DEBUG = getenv("LLAMA_SERVER_SLOTS_DEBUG"); - slots_debug = LLAMA_SERVER_SLOTS_DEBUG ? atoi(LLAMA_SERVER_SLOTS_DEBUG) : 0; - - if (slots_debug) { - SRV_WRN("slots debug = %d\n", slots_debug); - } - } - - // the update_slots() logic will always submit a maximum of n_batch or n_parallel tokens - // note that n_batch can be > n_ctx (e.g. for non-causal attention models such as BERT where the KV cache is not used) - { - const int32_t n_batch = llama_n_batch(ctx); - batch = llama_batch_init(std::max(n_batch, params_base.n_parallel), 0, 1); - } - - metrics.init(); - - if (params_base.cache_ram_mib != 0) { - if (params_base.cache_ram_mib < 0) { - SRV_WRN("prompt cache is enabled, size limit: %s\n", "no limit"); - } else { - SRV_WRN("prompt cache is enabled, size limit: %d MiB\n", params_base.cache_ram_mib); - } - SRV_WRN("%s", "use `--cache-ram 0` to disable the prompt cache\n"); - - prompt_cache = std::make_unique(params_base.cache_ram_mib, n_ctx); - } else { - SRV_WRN("%s", "prompt cache is disabled - use `--cache-ram N` to enable it\n"); - } - SRV_WRN("%s", "for more info see https://github.com/ggml-org/llama.cpp/pull/16391\n"); - - // thinking is enabled if: - // 1. It's not explicitly disabled (reasoning_budget == 0) - // 2. The chat template supports it - const bool enable_thinking = params_base.use_jinja && params_base.reasoning_budget != 0 && common_chat_templates_support_enable_thinking(chat_templates.get()); - SRV_INF("thinking = %d\n", enable_thinking); - - oai_parser_opt = { - /* use_jinja */ params_base.use_jinja, - /* prefill_assistant */ params_base.prefill_assistant, - /* reasoning_format */ params_base.reasoning_format, - /* chat_template_kwargs */ params_base.default_template_kwargs, - /* common_chat_templates */ chat_templates.get(), - /* allow_image */ mctx ? mtmd_support_vision(mctx) : false, - /* allow_audio */ mctx ? mtmd_support_audio (mctx) : false, - /* enable_thinking */ enable_thinking, - }; - - // print sample chat example to make it clear which template is used - LOG_INF("%s: chat template, chat_template: %s, example_format: '%s'\n", __func__, - common_chat_templates_source(chat_templates.get()), - common_chat_format_example(chat_templates.get(), params_base.use_jinja, params_base.default_template_kwargs).c_str()); - } - - server_slot * get_slot_by_id(int id) { - for (server_slot & slot : slots) { - if (slot.id == id) { - return &slot; - } - } - - return nullptr; - } - - server_slot * get_available_slot(const server_task & task) { - server_slot * ret = nullptr; - - bool update_cache = false; - - // find the slot that has at least n% prompt similarity - if (ret == nullptr && slot_prompt_similarity != 0.0f) { - float sim_best = 0; - - for (server_slot & slot : slots) { - // skip the slot if it is not available - if (slot.is_processing()) { - continue; - } - - const auto & tokens = slot.prompt.tokens; - - // skip the slot if it does not contains cached tokens - if (tokens.empty()) { - continue; - } - - // fraction of the Longest Common Prefix length with respect to the input prompt length - const float sim_cur = float(tokens.get_common_prefix(task.tokens)) / task.tokens.size(); - - // select the current slot if the criteria match - if (sim_cur > sim_best && sim_cur > slot_prompt_similarity) { - sim_best = sim_cur; - - ret = &slot; - } - } - - if (ret != nullptr) { - const float f_keep = (sim_best*task.tokens.size()) / ret->prompt.tokens.size(); - - SLT_INF(*ret, "selected slot by LCP similarity, sim_best = %.3f (> %.3f thold), f_keep = %.3f\n", - sim_best, slot_prompt_similarity, f_keep); - - // if we are about to lose a large portion of the existing context - save it in the prompt cache - if (f_keep < 0.5f) { - update_cache = true; - } - } - } - - // find the slot that has been least recently used - if (ret == nullptr) { - int64_t t_last = -1; - - for (server_slot & slot : slots) { - // skip the slot if it is not available - if (slot.is_processing()) { - continue; - } - - // select the current slot if the criteria match - if (!ret || slot.t_last_used <= t_last) { - t_last = slot.t_last_used; - ret = &slot; - } - } - - if (ret != nullptr) { - SLT_INF(*ret, "selected slot by LRU, t_last = %" PRId64 "\n", t_last); - - update_cache = true; - } - } - - if (ret) { - const auto & tokens = ret->prompt.tokens; - - update_cache = update_cache && prompt_cache; - - // cache prompts only for completion tasks - update_cache = update_cache && task.type == SERVER_TASK_TYPE_COMPLETION; - - // don't update the cache if the slot's context is empty - update_cache = update_cache && tokens.size() > 0; - - // TODO: mtmd does not support prompt cache - update_cache = update_cache && (ret->mctx == nullptr); - - if (update_cache) { - SRV_WRN("%s", "updating prompt cache\n"); - - const int64_t t_start = ggml_time_us(); - - ret->prompt_save(*prompt_cache); - - if (!ret->prompt_load(*prompt_cache, task.tokens)) { - clear_slot(*ret); - } - - prompt_cache->update(); - - SRV_WRN("prompt cache update took %.2f ms\n", (ggml_time_us() - t_start) / 1000.0); - } - } - - return ret; - } - - void clear_slot(server_slot & slot) const { - GGML_ASSERT(!slot.is_processing()); - - SLT_WRN(slot, "clearing slot with %zu tokens\n", slot.prompt.tokens.size()); - - llama_memory_seq_rm(llama_get_memory(ctx), slot.id, -1, -1); - slot.prompt.tokens.clear(); - } - - // return true if at least one slot has been cleared - // TODO: improve logic - // - smarter decision which slot to clear (LRU or longest prompt?) - // - move slot to level 2 cache instead of removing? - // - instead of purging, try to store and resume later? - bool try_clear_idle_slots() { - bool res = false; - - if (!params_base.kv_unified) { - return res; - } - - for (auto & slot : slots) { - if (slot.is_processing()) { - continue; - } - - if (slot.prompt.n_tokens() > 0) { - SRV_WRN("purging slot %d with %zu tokens\n", slot.id, slot.prompt.tokens.size()); - - clear_slot(slot); - - res = true; - - // clear slots one by one - break; - } - } - - return res; - } - - bool launch_slot_with_task(server_slot & slot, server_task && task) { - slot.reset(); - - if (!are_lora_equal(task.params.lora, slot.lora)) { - // if lora has changed, check to see if the cache should be cleared - if (lora_should_clear_cache(slot.lora, task.params.lora)) { - SLT_INF(slot, "clearing cache for lora change. %zu loras -> %zu loras\n", slot.lora.size(), task.params.lora.size()); - slot.prompt.tokens.clear(); - } else { - SLT_INF(slot, "keeping cache for alora. %zu target loras\n", task.params.lora.size()); - } - slot.lora = task.params.lora; - } - - // if using alora, make sure it's only a single one requested and active - size_t alora_invocation_start = task.tokens.size(); - if (lora_all_alora(slot.lora)) { - const auto & enabled_ids = lora_get_enabled_ids(slot.lora); - // TODO: This will error out if a user requests two aloras, but only - // provides the activation string for one. We could, instead search - // for all requested alora activation strings and then either keep - // only the last one, or reject if multiple are found. - if (enabled_ids.size() != 1) { - send_error(task, "Cannot run multiple aLoRAs in a single request", ERROR_TYPE_INVALID_REQUEST); - return false; - } - const auto & lora = slot.lora[enabled_ids[0]].ptr; - - // get the pointer and count for the invocation tokens - const uint64_t n_invocation_tokens = llama_adapter_get_alora_n_invocation_tokens(lora); - const llama_token * invocation_tokens = llama_adapter_get_alora_invocation_tokens (lora); - - // scan backwards through the prompt tokens to find the last - // occurrence of the invocation sequence - int match_idx = static_cast(n_invocation_tokens) - 1; - for (int i = task.tokens.size() - 1; i >= 0; --i) { - // the token in this position matches the next token to find in - // the invocation sequence - if (task.tokens[i] == invocation_tokens[match_idx]) { - // if it's a full match, we've found the start - if (match_idx == 0) { - alora_invocation_start = i; - break; - } - // otherwise, check the next token in the sequence - --match_idx; - } else { - // no match in this position, so start looking over again - match_idx = static_cast(n_invocation_tokens) - 1; - } - } - - // if the activation string is not found, disable the alora - if (alora_invocation_start == task.tokens.size()) { - SLT_DBG(slot, "alora %zu requested, but not found. deactivating\n", enabled_ids[0]); - slot.lora[enabled_ids[0]].scale = 0.0f; - } else { - SLT_DBG(slot, "alora %zu activated starting at %zu\n", enabled_ids[0], alora_invocation_start); - slot.alora_invocation_start = alora_invocation_start; - } - } - - if (!task.tokens.validate(ctx)) { - send_error(task, "Prompt contains invalid tokens", ERROR_TYPE_INVALID_REQUEST); - return false; - } - - SLT_DBG(slot, "launching slot : %s\n", safe_json_to_str(slot.to_json()).c_str()); - - // initialize samplers - { - if (slot.smpl != nullptr) { - common_sampler_free(slot.smpl); - } - - slot.smpl = common_sampler_init(model, task.params.sampling); - if (slot.smpl == nullptr) { - // for now, the only error that may happen here is invalid grammar - send_error(task, "Failed to parse grammar", ERROR_TYPE_INVALID_REQUEST); - return false; - } - - SLT_INF(slot, "sampler chain: %s\n", common_sampler_print(slot.smpl).c_str()); - } - - // initialize draft batch - // TODO: rework speculative decoding [TAG_SERVER_SPEC_REWORK] - if (slot.ctx_dft) { - llama_batch_free(slot.batch_spec); - - slot.batch_spec = llama_batch_init(task.params.speculative.n_max + 1, 0, 1); - } - - slot.task = std::make_unique(std::move(task)); - - slot.state = SLOT_STATE_STARTED; - - SLT_INF(slot, "%s", "processing task\n"); - - return true; - } - - bool process_token(completion_token_output & result, server_slot & slot) { - // remember which tokens were sampled - used for repetition penalties during sampling - const std::string token_str = result.text_to_send; - slot.sampled = result.tok; - - slot.generated_text += token_str; - if (slot.task->params.return_tokens) { - slot.generated_tokens.push_back(result.tok); - } - slot.has_next_token = true; - - // check if there is incomplete UTF-8 character at the end - bool incomplete = validate_utf8(slot.generated_text) < slot.generated_text.size(); - - // search stop word and delete it - if (!incomplete) { - size_t pos = std::min(slot.n_sent_text, slot.generated_text.size()); - - const std::string str_test = slot.generated_text.substr(pos); - bool send_text = true; - - size_t stop_pos = slot.find_stopping_strings(str_test, token_str.size(), true); - if (stop_pos != std::string::npos) { - slot.generated_text.erase( - slot.generated_text.begin() + pos + stop_pos, - slot.generated_text.end()); - pos = std::min(slot.n_sent_text, slot.generated_text.size()); - } else if (slot.has_next_token && !llama_vocab_is_eog(vocab, result.tok) ) { - stop_pos = slot.find_stopping_strings(str_test, token_str.size(), false); - send_text = stop_pos == std::string::npos; - } - - // check if there is any token to predict - if (send_text) { - // no send the stop word in the response - result.text_to_send = slot.generated_text.substr(pos, std::string::npos); - slot.n_sent_text += result.text_to_send.size(); - // add the token to slot queue and cache - } else { - result.text_to_send = ""; - } - - slot.add_token(result); - if (slot.task->params.stream) { - send_partial_response(slot, result, false); - } - } - - if (incomplete) { - slot.has_next_token = true; - } - - // if context shifting is disabled, make sure that we don't run out of context - if (!params_base.ctx_shift && slot.prompt.n_tokens() + 1 >= slot.n_ctx) { - slot.truncated = true; - slot.stop = STOP_TYPE_LIMIT; - slot.has_next_token = false; - - SLT_DBG(slot, "stopped due to running out of context capacity, prompt.n_tokens() = %d, task.n_tokens = %d, n_decoded = %d, n_ctx = %d\n", - slot.prompt.n_tokens(), slot.task->n_tokens(), slot.n_decoded, slot.n_ctx); - } - - // check the limits - if (slot.n_decoded > 0 && slot.has_next_token && !slot.has_budget(params_base)) { - slot.stop = STOP_TYPE_LIMIT; - slot.has_next_token = false; - - SLT_DBG(slot, "stopped by limit, n_decoded = %d, n_predict = %d\n", slot.n_decoded, slot.task->params.n_predict); - } - - if (slot.has_new_line) { - // require that each new line has a whitespace prefix (i.e. indentation) of at least slot.params.n_indent - if (slot.task->params.n_indent > 0) { - // check the current indentation - // TODO: improve by not doing it more than once for each new line - if (slot.last_nl_pos > 0) { - size_t pos = slot.last_nl_pos; - - int n_indent = 0; - while (pos < slot.generated_text.size() && (slot.generated_text[pos] == ' ' || slot.generated_text[pos] == '\t')) { - n_indent++; - pos++; - } - - if (pos < slot.generated_text.size() && n_indent < slot.task->params.n_indent) { - slot.stop = STOP_TYPE_LIMIT; - slot.has_next_token = false; - - // cut the last line - slot.generated_text.erase(pos, std::string::npos); - - SLT_DBG(slot, "stopped by indentation limit, n_decoded = %d, n_indent = %d\n", slot.n_decoded, n_indent); - } - } - - // find the next new line - { - const size_t pos = slot.generated_text.find('\n', slot.last_nl_pos); - - if (pos != std::string::npos) { - slot.last_nl_pos = pos + 1; - } - } - } - } - - // check if there is a new line in the generated text - if (result.text_to_send.find('\n') != std::string::npos) { - slot.has_new_line = true; - - // if we have seen a new line, we stop after a certain time limit, but only upon another new line - if (slot.task->params.t_max_predict_ms > 0 && (ggml_time_us() - slot.t_start_generation > 1000.0f*slot.task->params.t_max_predict_ms)) { - slot.stop = STOP_TYPE_LIMIT; - slot.has_next_token = false; - - SLT_DBG(slot, "stopped by time limit, n_decoded = %d, t_max_predict_ms = %d ms\n", slot.n_decoded, (int) slot.task->params.t_max_predict_ms); - } - } - - if (llama_vocab_is_eog(vocab, result.tok)) { - slot.stop = STOP_TYPE_EOS; - slot.has_next_token = false; - - SLT_DBG(slot, "%s", "stopped by EOS\n"); - } - - SLT_DBG(slot, "n_decoded = %d, n_remaining = %d, next token: %5d '%s'\n", slot.n_decoded, slot.n_remaining, result.tok, token_str.c_str()); - - return slot.has_next_token; // continue - } - - void populate_token_probs(const server_slot & slot, completion_token_output & result, bool post_sampling, bool special, int idx) const { - size_t n_probs = slot.task->params.sampling.n_probs; - size_t n_vocab = llama_vocab_n_tokens(vocab); - - if (post_sampling) { - const auto * cur_p = common_sampler_get_candidates(slot.smpl, true); - const size_t max_probs = cur_p->size; - - // set probability for sampled token - for (size_t i = 0; i < max_probs; i++) { - if (cur_p->data[i].id == result.tok) { - result.prob = cur_p->data[i].p; - break; - } - } - - // set probability for top n_probs tokens - result.probs.reserve(max_probs); - for (size_t i = 0; i < std::min(max_probs, n_probs); i++) { - result.probs.push_back({ - cur_p->data[i].id, - common_token_to_piece(ctx, cur_p->data[i].id, special), - cur_p->data[i].p - }); - } - } else { - // TODO: optimize this with min-p optimization - std::vector cur = get_token_probabilities(ctx, idx); - - // set probability for sampled token - for (size_t i = 0; i < n_vocab; i++) { - // set probability for sampled token - if (cur[i].id == result.tok) { - result.prob = cur[i].p; - break; - } - } - - // set probability for top n_probs tokens - result.probs.reserve(n_probs); - for (size_t i = 0; i < std::min(n_vocab, n_probs); i++) { - result.probs.push_back({ - cur[i].id, - common_token_to_piece(ctx, cur[i].id, special), - cur[i].p - }); - } - } - } - - void send_error(const server_task & task, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) { - send_error(task.id, error, type); - } - - void send_error(const server_slot & slot, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) { - send_error(slot.task->id, error, type, slot.task->n_tokens(), slot.n_ctx); - } - - void send_error(const int id_task, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER, const int32_t n_prompt_tokens = 0, const int32_t n_ctx = 0) { - SRV_ERR("task id = %d, error: %s\n", id_task, error.c_str()); - - if (type == ERROR_TYPE_EXCEED_CONTEXT_SIZE) { - GGML_ASSERT(n_ctx > 0 && n_prompt_tokens > 0); - } - - auto res = std::make_unique(); - res->id = id_task; - res->err_type = type; - res->err_msg = error; - res->n_prompt_tokens = n_prompt_tokens; - res->n_ctx = n_ctx; - - queue_results.send(std::move(res)); - } - - // if multimodal is enabled, send an error and return false - bool check_no_mtmd(const int id_task) { - if (mctx) { - send_error(id_task, "This feature is not supported by multimodal", ERROR_TYPE_NOT_SUPPORTED); - return false; - } - return true; - } - - void send_partial_response(server_slot & slot, const completion_token_output & tkn, bool is_progress) { - auto res = std::make_unique(); - - res->id = slot.task->id; - res->index = slot.task->index; - - if (is_progress) { - res->is_progress = true; - res->progress.total = slot.task->n_tokens(); - res->progress.cache = slot.n_prompt_tokens_cache; - res->progress.processed = slot.prompt.tokens.size(); - res->progress.time_ms = (ggml_time_us() - slot.t_start_process_prompt) / 1000; - } else { - res->content = tkn.text_to_send; - res->tokens = { tkn.tok }; - - slot.update_chat_msg(res->oaicompat_msg_diffs); - } - - res->n_decoded = slot.n_decoded; - res->n_prompt_tokens = slot.task->n_tokens(); - res->post_sampling_probs = slot.task->params.post_sampling_probs; - - res->verbose = slot.task->params.verbose; - res->res_type = slot.task->params.res_type; - res->oaicompat_model = slot.task->params.oaicompat_model; - res->oaicompat_cmpl_id = slot.task->params.oaicompat_cmpl_id; - - // populate res.probs_output - if (slot.task->params.sampling.n_probs > 0) { - res->prob_output = tkn; // copy the token probs - } - - // populate timings if this is final response or timings_per_token is enabled - if (slot.stop != STOP_TYPE_NONE || slot.task->params.timings_per_token) { - res->timings = slot.get_timings(); - } - - queue_results.send(std::move(res)); - } - - void send_final_response(server_slot & slot) { - auto res = std::make_unique(); - - res->id = slot.task->id; - res->id_slot = slot.id; - - res->index = slot.task->index; - res->content = slot.generated_text; - res->tokens = std::move(slot.generated_tokens); - res->timings = slot.get_timings(); - res->prompt = slot.task->tokens.detokenize(ctx, true); - res->response_fields = std::move(slot.task->params.response_fields); - - res->truncated = slot.truncated; - res->n_decoded = slot.n_decoded; - res->n_prompt_tokens = slot.task->n_tokens(); - res->n_tokens_cached = slot.prompt.n_tokens(); - res->has_new_line = slot.has_new_line; - res->stopping_word = slot.stopping_word; - res->stop = slot.stop; - res->post_sampling_probs = slot.task->params.post_sampling_probs; - - res->verbose = slot.task->params.verbose; - res->stream = slot.task->params.stream; - res->include_usage = slot.task->params.include_usage; - res->res_type = slot.task->params.res_type; - res->oaicompat_model = slot.task->params.oaicompat_model; - res->oaicompat_cmpl_id = slot.task->params.oaicompat_cmpl_id; - res->oaicompat_msg = slot.update_chat_msg(res->oaicompat_msg_diffs); - - // populate res.probs_output - if (slot.task->params.sampling.n_probs > 0) { - if (!slot.task->params.stream && slot.stop == STOP_TYPE_WORD) { - const llama_tokens stop_word_toks = common_tokenize(ctx, slot.stopping_word, false); - - size_t safe_offset = std::min(slot.generated_token_probs.size(), stop_word_toks.size()); - res->probs_output = std::vector( - slot.generated_token_probs.begin(), - slot.generated_token_probs.end() - safe_offset); - } else { - res->probs_output = std::vector( - slot.generated_token_probs.begin(), - slot.generated_token_probs.end()); - } - } - - res->generation_params = slot.task->params; // copy the parameters - - queue_results.send(std::move(res)); - } - - void send_embedding(const server_slot & slot, const llama_batch & batch) { - auto res = std::make_unique(); - res->id = slot.task->id; - res->index = slot.task->index; - res->n_tokens = slot.task->n_tokens(); - res->res_type = slot.task->params.res_type; - - const int n_embd = llama_model_n_embd(model); - - std::vector embd_res(n_embd, 0.0f); - - for (int i = 0; i < batch.n_tokens; ++i) { - if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) { - continue; - } - - const float * embd = nullptr; - if (llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE) { - embd = llama_get_embeddings_ith(ctx, i); - } else { - embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]); - } - - if (embd == nullptr) { - SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]); - - res->embedding.push_back(std::vector(n_embd, 0.0f)); - continue; - } - - // normalize only when there is pooling - if (llama_pooling_type(slot.ctx) != LLAMA_POOLING_TYPE_NONE) { - common_embd_normalize(embd, embd_res.data(), n_embd, slot.task->params.embd_normalize); - res->embedding.push_back(embd_res); - break; - } - - res->embedding.emplace_back(embd, embd + n_embd); - } - - SLT_DBG(slot, "%s", "sending embeddings\n"); - - queue_results.send(std::move(res)); - } - - void send_rerank(const server_slot & slot, const llama_batch & batch) { - auto res = std::make_unique(); - res->id = slot.task->id; - res->index = slot.task->index; - res->n_tokens = slot.task->n_tokens(); - - for (int i = 0; i < batch.n_tokens; ++i) { - if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) { - continue; - } - - const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]); - if (embd == NULL) { - embd = llama_get_embeddings_ith(ctx, i); - } - - if (embd == NULL) { - SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]); - - res->score = -1e6; - continue; - } - - res->score = embd[0]; - } - - SLT_DBG(slot, "sending rerank result, res.score = %f\n", res->score); - - queue_results.send(std::move(res)); - } - - // - // Functions to process the task - // - - void process_single_task(server_task && task) { - switch (task.type) { - case SERVER_TASK_TYPE_COMPLETION: - case SERVER_TASK_TYPE_INFILL: - case SERVER_TASK_TYPE_EMBEDDING: - case SERVER_TASK_TYPE_RERANK: - { - const int id_slot = task.id_slot; - - server_slot * slot = id_slot != -1 ? get_slot_by_id(id_slot) : get_available_slot(task); - - if (slot == nullptr) { - // if no slot is available, we defer this task for processing later - SRV_DBG("no slot is available, defer task, id_task = %d\n", task.id); - queue_tasks.defer(std::move(task)); - break; - } - - if (slot->is_processing()) { - // if requested slot is unavailable, we defer this task for processing later - SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id); - queue_tasks.defer(std::move(task)); - break; - } - - if (!launch_slot_with_task(*slot, std::move(task))) { - SRV_ERR("failed to launch slot with task, id_task = %d\n", task.id); - break; - } - } break; - case SERVER_TASK_TYPE_CANCEL: - { - // release slot linked with the task id - for (auto & slot : slots) { - if (slot.task && slot.task->id == task.id_target) { - slot.release(); - break; - } - } - } break; - case SERVER_TASK_TYPE_NEXT_RESPONSE: - { - // do nothing - } break; - case SERVER_TASK_TYPE_METRICS: - { - json slots_data = json::array(); - - int n_idle_slots = 0; - int n_processing_slots = 0; - - for (server_slot & slot : slots) { - json slot_data = slot.to_json(slots_debug == 0); - - if (slot.is_processing()) { - n_processing_slots++; - } else { - n_idle_slots++; - } - - slots_data.push_back(slot_data); - } - SRV_DBG("n_idle_slots = %d, n_processing_slots = %d\n", n_idle_slots, n_processing_slots); - - auto res = std::make_unique(); - res->id = task.id; - res->slots_data = std::move(slots_data); - res->n_idle_slots = n_idle_slots; - res->n_processing_slots = n_processing_slots; - res->n_tasks_deferred = queue_tasks.queue_tasks_deferred_size(); - res->t_start = metrics.t_start; - - res->n_prompt_tokens_processed_total = metrics.n_prompt_tokens_processed_total; - res->t_prompt_processing_total = metrics.t_prompt_processing_total; - res->n_tokens_predicted_total = metrics.n_tokens_predicted_total; - res->t_tokens_generation_total = metrics.t_tokens_generation_total; - - res->n_tokens_max = metrics.n_tokens_max; - - res->n_prompt_tokens_processed = metrics.n_prompt_tokens_processed; - res->t_prompt_processing = metrics.t_prompt_processing; - res->n_tokens_predicted = metrics.n_tokens_predicted; - res->t_tokens_generation = metrics.t_tokens_generation; - - res->n_decode_total = metrics.n_decode_total; - res->n_busy_slots_total = metrics.n_busy_slots_total; - - if (task.metrics_reset_bucket) { - metrics.reset_bucket(); - } - queue_results.send(std::move(res)); - } break; - case SERVER_TASK_TYPE_SLOT_SAVE: - { - if (!check_no_mtmd(task.id)) { - break; - } - - int id_slot = task.slot_action.slot_id; - server_slot * slot = get_slot_by_id(id_slot); - if (slot == nullptr) { - send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST); - break; - } - if (slot->is_processing()) { - // if requested slot is unavailable, we defer this task for processing later - SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id); - queue_tasks.defer(std::move(task)); - break; - } - - const size_t token_count = slot->prompt.tokens.size(); - const int64_t t_start = ggml_time_us(); - - std::string filename = task.slot_action.filename; - std::string filepath = task.slot_action.filepath; - - const llama_tokens & tokens = slot->prompt.tokens.get_text_tokens(); - const size_t nwrite = llama_state_seq_save_file(ctx, filepath.c_str(), slot->id, tokens.data(), token_count); - - const int64_t t_end = ggml_time_us(); - const double t_save_ms = (t_end - t_start) / 1000.0; - - auto res = std::make_unique(); - res->id = task.id; - res->id_slot = id_slot; - res->filename = filename; - res->is_save = true; - res->n_tokens = token_count; - res->n_bytes = nwrite; - res->t_ms = t_save_ms; - queue_results.send(std::move(res)); - } break; - case SERVER_TASK_TYPE_SLOT_RESTORE: - { - if (!check_no_mtmd(task.id)) break; - int id_slot = task.slot_action.slot_id; - server_slot * slot = get_slot_by_id(id_slot); - if (slot == nullptr) { - send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST); - break; - } - if (slot->is_processing()) { - // if requested slot is unavailable, we defer this task for processing later - SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id); - queue_tasks.defer(std::move(task)); - break; - } - - const int64_t t_start = ggml_time_us(); - - std::string filename = task.slot_action.filename; - std::string filepath = task.slot_action.filepath; - - llama_tokens tokens; - tokens.resize(slot->n_ctx); - size_t token_count = 0; - size_t nread = llama_state_seq_load_file(ctx, filepath.c_str(), slot->id, tokens.data(), tokens.size(), &token_count); - if (nread == 0) { - slot->prompt.tokens.clear(); // KV may already been invalidated? - send_error(task, "Unable to restore slot, no available space in KV cache or invalid slot save file", ERROR_TYPE_INVALID_REQUEST); - break; - } - tokens.resize(token_count); - slot->prompt.tokens.clear(); - slot->prompt.tokens.insert(tokens); - - const int64_t t_end = ggml_time_us(); - const double t_restore_ms = (t_end - t_start) / 1000.0; - - auto res = std::make_unique(); - res->id = task.id; - res->id_slot = id_slot; - res->filename = filename; - res->is_save = false; - res->n_tokens = token_count; - res->n_bytes = nread; - res->t_ms = t_restore_ms; - queue_results.send(std::move(res)); - } break; - case SERVER_TASK_TYPE_SLOT_ERASE: - { - if (!check_no_mtmd(task.id)) { - break; - } - int id_slot = task.slot_action.slot_id; - server_slot * slot = get_slot_by_id(id_slot); - if (slot == nullptr) { - send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST); - break; - } - if (slot->is_processing()) { - // if requested slot is unavailable, we defer this task for processing later - SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id); - queue_tasks.defer(std::move(task)); - break; - } - - // Erase token cache - const size_t n_erased = slot->prompt.tokens.size(); - - clear_slot(*slot); - - auto res = std::make_unique(); - res->id = task.id; - res->id_slot = id_slot; - res->n_erased = n_erased; - queue_results.send(std::move(res)); - } break; - case SERVER_TASK_TYPE_SET_LORA: - { - params_base.lora_adapters = std::move(task.set_lora); - auto res = std::make_unique(); - res->id = task.id; - queue_results.send(std::move(res)); - } break; - - } - } - - void update_slots() { - // check if all slots are idle - { - bool all_idle = true; - - for (auto & slot : slots) { - if (slot.is_processing()) { - all_idle = false; - break; - } - } - - if (all_idle) { - SRV_INF("%s", "all slots are idle\n"); - - return; - } - } - - { - SRV_DBG("%s", "posting NEXT_RESPONSE\n"); - - server_task task(SERVER_TASK_TYPE_NEXT_RESPONSE); - task.id = queue_tasks.get_new_id(); - queue_tasks.post(std::move(task)); - } - - // apply context-shift if needed - // TODO: simplify and improve - for (server_slot & slot : slots) { - if (slot.state == SLOT_STATE_GENERATING && slot.prompt.n_tokens() + 1 >= slot.n_ctx) { - if (!params_base.ctx_shift) { - // this check is redundant (for good) - // we should never get here, because generation should already stopped in process_token() - send_error(slot, "context shift is disabled", ERROR_TYPE_SERVER); - slot.release(); - continue; - } - - if (mctx) { - // we should never reach this because params_base.ctx_shift is automatically disabled if mmproj is loaded - // we don't support ctx_shift because an image chunk may contains multiple tokens - GGML_ABORT("not supported by multimodal"); - } - - // Shift context - int n_keep = slot.task->params.n_keep < 0 ? slot.task->n_tokens() : slot.task->params.n_keep; - - if (add_bos_token) { - n_keep += 1; - } - - n_keep = std::min(slot.n_ctx - 4, n_keep); - - const int n_left = slot.prompt.n_tokens() - n_keep; - const int n_discard = slot.task->params.n_discard ? slot.task->params.n_discard : (n_left / 2); - - SLT_WRN(slot, "slot context shift, n_keep = %d, n_left = %d, n_discard = %d\n", n_keep, n_left, n_discard); - - llama_memory_seq_rm (llama_get_memory(ctx), slot.id, n_keep , n_keep + n_discard); - llama_memory_seq_add(llama_get_memory(ctx), slot.id, n_keep + n_discard, slot.prompt.n_tokens(), -n_discard); - - // add generated tokens to cache - // ref: https://github.com/ggml-org/llama.cpp/pull/16818#discussion_r2473269481 - { - GGML_ASSERT(!slot.prompt.tokens.has_mtmd); - - llama_tokens new_tokens = slot.prompt.tokens.get_text_tokens(); // copy - for (size_t i = n_keep + n_discard; i < new_tokens.size(); i++) { - new_tokens[i - n_discard] = new_tokens[i]; - } - - new_tokens.resize(slot.prompt.tokens.size() - n_discard); - - slot.prompt.tokens.clear(); - slot.prompt.tokens.insert(new_tokens); - } - - slot.truncated = true; - } - } - - // start populating the batch for this iteration - common_batch_clear(batch); - - // track if given slot can be batched with slots already in the batch - server_slot * slot_batched = nullptr; - - auto accept_special_token = [&](server_slot & slot, llama_token token) { - return params_base.special || - slot.task->params.sampling.preserved_tokens.find(token) != slot.task->params.sampling.preserved_tokens.end(); - }; - - // first, add sampled tokens from any ongoing sequences - for (auto & slot : slots) { - if (slot.state != SLOT_STATE_GENERATING) { - continue; - } - - // check if we can batch this slot with the previous one - if (!slot_batched) { - slot_batched = &slot; - } else if (!slot_batched->can_batch_with(slot)) { - continue; - } - - slot.i_batch = batch.n_tokens; - - common_batch_add(batch, slot.sampled, slot.prompt.tokens.pos_next(), { slot.id }, true); - - slot.prompt.tokens.push_back(slot.sampled); - - SLT_DBG(slot, "slot decode token, n_ctx = %d, n_tokens = %d, truncated = %d\n", - slot.n_ctx, slot.prompt.n_tokens(), slot.truncated); - } - - // process in chunks of params.n_batch - int32_t n_batch = llama_n_batch(ctx); - int32_t n_ubatch = llama_n_ubatch(ctx); - - float alora_scale = -1.0f; - size_t alora_disabled_id = 0; - - // next, batch any pending prompts without exceeding n_batch - if (params_base.cont_batching || batch.n_tokens == 0) { - for (auto & slot : slots) { - if (!slot.is_processing()) { - continue; - } - - // check if we can batch this slot with the previous one - if (slot_batched && !slot_batched->can_batch_with(slot)) { - continue; - } - - // this slot still has a prompt to be processed - if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_STARTED) { - const auto & input_tokens = slot.task->tokens; - - // TODO: maybe move branch to outside of this loop in the future - if (slot.state == SLOT_STATE_STARTED) { - slot.t_start_process_prompt = ggml_time_us(); - slot.t_start_generation = 0; - - slot.state = SLOT_STATE_PROCESSING_PROMPT; - - SLT_INF(slot, "new prompt, n_ctx_slot = %d, n_keep = %d, task.n_tokens = %d\n", - slot.n_ctx, slot.task->params.n_keep, slot.task->n_tokens()); - - // print prompt tokens (for debugging) - /*if (1) { - // first 16 tokens (avoid flooding logs) - for (int i = 0; i < std::min(16, input_tokens.size()); i++) { - SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, input_tokens[i], common_token_to_piece(ctx, input_tokens[i]).c_str()); - } - } else { - // all - for (int i = 0; i < (int) input_tokens.size(); i++) { - SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, input_tokens[i], common_token_to_piece(ctx, input_tokens[i]).c_str()); - } - }*/ - - // keep track how many tokens we can reuse from the previous state - int n_past = 0; - - // empty prompt passed -> release the slot and send empty response - if (input_tokens.empty()) { - SLT_WRN(slot, "%s", "empty prompt - releasing slot\n"); - - slot.print_timings(); - send_final_response(slot); - slot.release(); - - continue; - } - - // TODO: support memory-less logits computation - if (slot.need_logits() && !llama_get_memory(ctx)) { - send_error(slot, "the current context does not logits computation. skipping", ERROR_TYPE_SERVER); - slot.release(); - continue; - } - - if (!slot.can_split()) { - if (slot.task->n_tokens() > n_ubatch) { - send_error(slot, "input is too large to process. increase the physical batch size", ERROR_TYPE_SERVER); - slot.release(); - continue; - } - - if (slot.task->n_tokens() > slot.n_ctx) { - send_error(slot, "input is larger than the max context size. skipping", ERROR_TYPE_EXCEED_CONTEXT_SIZE); - slot.release(); - continue; - } - } else { - if (slot.task->n_tokens() >= slot.n_ctx) { - send_error(slot, "the request exceeds the available context size, try increasing it", ERROR_TYPE_EXCEED_CONTEXT_SIZE); - slot.release(); - continue; - } - - if (slot.task->params.cache_prompt) { - // reuse any previously computed tokens that are common with the new prompt - n_past = slot.prompt.tokens.get_common_prefix(input_tokens); - - // if there is an alora invoked, don't cache after the invocation start - if (slot.alora_invocation_start > 0) { - SLT_DBG(slot, "only caching to alora invocation start (n_past = %d, alora_invocation_start = %d)\n", n_past, slot.alora_invocation_start); - n_past = std::min(n_past, slot.alora_invocation_start - 1); - } - - // reuse chunks from the cached prompt by shifting their KV cache in the new position - if (params_base.n_cache_reuse > 0) { - GGML_ASSERT(!slot.prompt.tokens.has_mtmd); - - size_t head_c = n_past; // cache - size_t head_p = n_past; // current prompt - - if (mctx) { - // we should never reach this - GGML_ABORT("not supported by multimodal"); - } - - SLT_DBG(slot, "trying to reuse chunks with size > %d, n_past = %d\n", params_base.n_cache_reuse, n_past); - - while (head_c < slot.prompt.tokens.size() && - head_p < input_tokens.size()) { - - size_t n_match = 0; - while (head_c + n_match < slot.prompt.tokens.size() && - head_p + n_match < input_tokens.size() && - slot.prompt.tokens[head_c + n_match] == input_tokens[head_p + n_match]) { - - n_match++; - } - - if (n_match >= (size_t) params_base.n_cache_reuse) { - SLT_INF(slot, "reusing chunk with size %zu, shifting KV cache [%zu, %zu) -> [%zu, %zu)\n", n_match, head_c, head_c + n_match, head_p, head_p + n_match); - //for (size_t i = head_p; i < head_p + n_match; i++) { - // SLT_DBG(slot, "cache token %3zu: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str()); - //} - - const int64_t kv_shift = (int64_t) head_p - (int64_t) head_c; - - llama_memory_seq_rm (llama_get_memory(ctx), slot.id, head_p, head_c); - llama_memory_seq_add(llama_get_memory(ctx), slot.id, head_c, head_c + n_match, kv_shift); - - for (size_t i = 0; i < n_match; i++) { - slot.prompt.tokens.set_token(head_p + i, slot.prompt.tokens[head_c + i]); - n_past++; - } - - head_c += n_match; - head_p += n_match; - } else { - head_c += 1; - } - } - - SLT_DBG(slot, "after context reuse, new n_past = %d\n", n_past); - } - } else { - // if we don't cache the prompt, we have to remove all previous tokens - n_past = 0; - } - - // note: when n_swa == 0, the model does not use SWA, which is equivalent to a window of 1 - const auto n_swa = std::max(1, llama_model_n_swa(model)); - - // the largest pos_min required for a checkpoint to be useful - const auto pos_min_thold = std::max(0, n_past - n_swa); - - // note: disallow with mtmd contexts for now - // https://github.com/ggml-org/llama.cpp/issues/17043 - if (!mctx && n_past > 0 && n_past < slot.prompt.n_tokens()) { - const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id); - if (pos_min == -1) { - SLT_ERR(slot, "n_past = %d, slot.prompt.tokens.size() = %d, seq_id = %d, pos_min = %d\n", n_past, (int) slot.prompt.tokens.size(), slot.id, pos_min); - GGML_ABORT("pos_min == -1, but n_past > 0 - should not happen: https://github.com/ggml-org/llama.cpp/pull/13833#discussion_r2116181237"); - } - - // when the prompt prefix does not match, print the tokens around the mismatch - // this is useful for debugging prompt caching - if (slots_debug) { - const int np0 = std::max(n_past - 4, 0); - const int np1 = std::min(n_past + 6, std::min(slot.prompt.tokens.size(), slot.task->tokens.size())); - - std::stringstream ss0; - std::stringstream ss1; - - std::stringstream st0; - std::stringstream st1; - - ss0 << "old: ... "; - ss1 << "new: ... "; - - for (int i = np0; i < np1; i++) { - if (i == n_past) { - ss0 << " | "; - ss1 << " | "; - } - - { - const auto token = slot.prompt.tokens[i]; - const auto piece = token != LLAMA_TOKEN_NULL ? common_token_to_piece(ctx, token) : "[mtmd]"; - ss0 << piece; - st0 << std::setw(8) << token; - } - - { - const auto token = slot.task->tokens[i]; - const auto piece = token != LLAMA_TOKEN_NULL ? common_token_to_piece(ctx, token) : "[mtmd]"; - ss1 << piece; - st1 << std::setw(8) << token; - } - } - - SLT_WRN(slot, "%s\n", ss0.str().c_str()); - SLT_WRN(slot, "%s\n", ss1.str().c_str()); - - SLT_WRN(slot, "%s\n", st0.str().c_str()); - SLT_WRN(slot, "%s\n", st1.str().c_str()); - } - - if (pos_min > pos_min_thold) { - // TODO: support can be added in the future when corresponding vision models get released - GGML_ASSERT(!slot.prompt.tokens.has_mtmd); - - SLT_WRN(slot, "n_past = %d, slot.prompt.tokens.size() = %d, seq_id = %d, pos_min = %d, n_swa = %d\n", n_past, (int) slot.prompt.tokens.size(), slot.id, pos_min, n_swa); - - // search for a context checkpoint - const auto it = std::find_if( - slot.prompt.checkpoints.rbegin(), - slot.prompt.checkpoints.rend(), - [&](const auto & cur) { - // guarantee that a checkpoint will result in at least one token being processed [TAG_PROMPT_LOGITS] - return cur.pos_min < pos_min_thold; - } - ); - - bool do_reset = it == slot.prompt.checkpoints.rend(); - - if (!do_reset) { - // restore the context checkpoint - const size_t checkpoint_size = it->data.size(); - const size_t n = llama_state_seq_set_data_ext(ctx, it->data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); - - if (n != checkpoint_size) { - SLT_ERR(slot, "failed to restore context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float) checkpoint_size / 1024 / 1024); - do_reset = true; - //printf("[DEBUG] `do_reset` was set to `true` after failing to restore a checkpoint"); - } else { - n_past = std::min(n_past, std::max(it->pos_min + 1, it->pos_max)); - SLT_WRN(slot, "restored context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float) checkpoint_size / 1024 / 1024); - } - } - - if (do_reset) { - SLT_WRN(slot, "forcing full prompt re-processing due to lack of cache data (likely due to SWA or hybrid/recurrent memory, see %s)\n", - "https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055"); - n_past = 0; - } - } - } - - { - // erase any checkpoints with pos_min > pos_min_thold - for (auto it = slot.prompt.checkpoints.begin(); it != slot.prompt.checkpoints.end();) { - const auto & cur = *it; - if (cur.pos_min > pos_min_thold) { - SLT_WRN(slot, "erased invalidated context checkpoint (pos_min = %d, pos_max = %d, n_swa = %d, size = %.3f MiB)\n", cur.pos_min, cur.pos_max, n_swa, (float) cur.data.size() / 1024 / 1024); - it = slot.prompt.checkpoints.erase(it); - } else { - ++it; - } - } - } - } - - // [TAG_PROMPT_LOGITS] - if (n_past == slot.task->n_tokens() && n_past > 0) { - SLT_WRN(slot, "need to evaluate at least 1 token for each active slot (n_past = %d, task.n_tokens() = %d)\n", n_past, slot.task->n_tokens()); - n_past--; - SLT_WRN(slot, "n_past was set to %d\n", n_past); - } - - slot.n_prompt_tokens_cache = n_past; - slot.n_prompt_tokens_processed = 0; - - slot.prompt.tokens.keep_first(n_past); - } - - if (!slot.can_split()) { - // cannot fit the prompt in the current batch - will try next iter - if (batch.n_tokens + slot.task->n_tokens() > n_batch) { - continue; - } - } - - // truncate any tokens that are beyond n_past for this slot - const llama_pos p0 = slot.prompt.tokens.pos_next(); - - SLT_INF(slot, "n_tokens = %d, memory_seq_rm [%d, end)\n", slot.prompt.n_tokens(), p0); - - if (!llama_memory_seq_rm(llama_get_memory(ctx), slot.id, p0, -1)) { - SLT_WRN(slot, "failed to truncate tokens with position >= %d - clearing the memory\n", p0); - - clear_slot(slot); - - // there is no common part left - slot.n_prompt_tokens_cache = 0; - } - - // check if we should process the image - if (slot.prompt.n_tokens() < slot.task->n_tokens() && input_tokens[slot.prompt.n_tokens()] == LLAMA_TOKEN_NULL) { - // process the image - size_t n_tokens_out = 0; - int32_t res = input_tokens.process_chunk(ctx, mctx, slot.prompt.n_tokens(), slot.prompt.tokens.pos_next(), slot.id, n_tokens_out); - if (res != 0) { - SLT_ERR(slot, "failed to process image, res = %d\n", res); - send_error(slot, "failed to process image", ERROR_TYPE_SERVER); - slot.release(); - continue; - } - - slot.n_prompt_tokens_processed += n_tokens_out; - - // add the image chunk to cache - { - const auto & chunk = input_tokens.find_chunk(slot.prompt.n_tokens()); - slot.prompt.tokens.push_back(chunk.get()); // copy - } - } - - // If using an alora, there may be uncached tokens that come - // before the invocation sequence. When this happens, the - // tokens before the invocation sequence need to be - // processed without the adapter in a separate batch, then - // the adapter needs to be enabled for the remaining tokens. - if (lora_all_alora(slot.lora) && slot.alora_invocation_start - 1 > slot.prompt.n_tokens()) { - SLT_DBG(slot, "processing pre-alora tokens without the adapter (n_tokens = %d, alora_invocation_start = %d)\n", slot.prompt.n_tokens(), slot.alora_invocation_start); - const auto & enabled_loras = lora_get_enabled_ids(slot.lora); - GGML_ASSERT(enabled_loras.size() == 1); - alora_scale = slot.lora[enabled_loras[0]].scale; - slot.lora[enabled_loras[0]].scale = 0.0f; - alora_disabled_id = enabled_loras[0]; - } - - bool do_checkpoint = params_base.n_ctx_checkpoints > 0; - - // make checkpoints only for completion tasks - do_checkpoint = do_checkpoint && slot.task->type == SERVER_TASK_TYPE_COMPLETION; - - // make a checkpoint of the parts of the memory that cannot be rolled back. - // checkpoints are created only if: - // - the model uses SWA and we are not using `swa_full` - // - the model architecture is marked as recurrent or hybrid - // - // TODO: try to make this conditional on the context or the memory module, instead of the model type - do_checkpoint = do_checkpoint && ( - llama_model_is_recurrent(model) || - llama_model_is_hybrid(model) || - (llama_model_n_swa(model) > 0 && !params_base.swa_full) - ); - - // add prompt tokens for processing in the current batch - while (slot.prompt.n_tokens() < slot.task->n_tokens() && batch.n_tokens < n_batch) { - // get next token to process - llama_token cur_tok = input_tokens[slot.prompt.n_tokens()]; - if (cur_tok == LLAMA_TOKEN_NULL) { - break; // end of text chunk - } - - // if this is an alora request with pre-invocation - // tokens that are not cached, we need to stop filling - // this batch at those pre-invocation tokens. - if (alora_scale > 0 && slot.prompt.n_tokens() == slot.alora_invocation_start - 1) { - SLT_DBG(slot, "stop prompt batch filling at (n_tokens = %d, alora_invocation_start = %d)\n", slot.prompt.n_tokens(), slot.alora_invocation_start); - break; - } - - // embedding requires all tokens in the batch to be output - common_batch_add(batch, - cur_tok, - slot.prompt.tokens.pos_next(), - { slot.id }, - slot.need_embd()); - slot.prompt.tokens.push_back(cur_tok); - - slot.n_prompt_tokens_processed++; - - // process the last few tokens of the prompt separately in order to allow for a checkpoint to be created. - if (do_checkpoint && slot.task->n_tokens() - slot.prompt.n_tokens() == 64) { - break; - } - } - - // SLT_INF(slot, "new slot.prompt.tokens: %s\n", slot.slot.prompt.tokens.str().c_str()); - - SLT_INF(slot, "prompt processing progress, n_tokens = %d, batch.n_tokens = %d, progress = %f\n", slot.prompt.n_tokens(), batch.n_tokens, (float) slot.prompt.n_tokens() / slot.task->n_tokens()); - - // entire prompt has been processed - if (slot.prompt.n_tokens() == slot.task->n_tokens()) { - slot.state = SLOT_STATE_DONE_PROMPT; - - GGML_ASSERT(batch.n_tokens > 0); - - common_sampler_reset(slot.smpl); - - // Process all prompt tokens through sampler system - for (int i = 0; i < slot.task->n_tokens(); ++i) { - llama_token id = input_tokens[i]; - if (id != LLAMA_TOKEN_NULL) { - common_sampler_accept(slot.smpl, id, false); - } - } - - // extract the logits only for the last token - batch.logits[batch.n_tokens - 1] = true; - - slot.n_decoded = 0; - slot.i_batch = batch.n_tokens - 1; - - SLT_INF(slot, "prompt done, n_tokens = %d, batch.n_tokens = %d\n", slot.prompt.n_tokens(), batch.n_tokens); - - const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id); - const auto pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx), slot.id); - - // no need for empty or small checkpoints - do_checkpoint = do_checkpoint && (pos_min >= 0 && pos_max >= 64); - - // no need to create checkpoints that are too close together - do_checkpoint = do_checkpoint && (slot.prompt.checkpoints.empty() || pos_max > slot.prompt.checkpoints.back().pos_max + 64); - - if (do_checkpoint) { - while (slot.prompt.checkpoints.size() >= (size_t) params_base.n_ctx_checkpoints) { - // make room for the new checkpoint, if needed - const auto & cur = slot.prompt.checkpoints.front(); - - SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", - cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024); - - slot.prompt.checkpoints.erase(slot.prompt.checkpoints.begin()); - } - - const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); - - auto & cur = slot.prompt.checkpoints.emplace_back(server_prompt_checkpoint{ - /*.pos_min = */ pos_min, - /*.pos_max = */ pos_max, - /*.data = */ std::vector(checkpoint_size), - }); - - llama_state_seq_get_data_ext(ctx, cur.data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); - - SLT_WRN(slot, "created context checkpoint %d of %d (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", - (int) slot.prompt.checkpoints.size(), params_base.n_ctx_checkpoints, cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024); - } - } - } - - if (!slot_batched) { - slot_batched = &slot; - } - - if (batch.n_tokens >= n_batch) { - break; - } - } - } - - if (batch.n_tokens == 0) { - SRV_WRN("%s", "no tokens to decode\n"); - return; - } - - SRV_DBG("decoding batch, n_tokens = %d\n", batch.n_tokens); - - if (slot_batched) { - // apply lora, only need to do it once per batch - common_set_adapter_lora(ctx, slot_batched->lora); - - // if the lora is temporarily disabled for an alora, re-enable it - // for next time - if (alora_scale > 0.0f) { - SRV_DBG("re-enabling alora with scale %f\n", alora_scale); - slot_batched->lora[alora_disabled_id].scale = alora_scale; - } - - llama_set_embeddings(ctx, slot_batched->need_embd()); - } - - int32_t i_next = 0; - - // process the created batch of tokens - for (int32_t i = 0; i < batch.n_tokens; i = i_next) { - const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i); - - llama_batch batch_view = { - n_tokens, - batch.token + i, - nullptr, - batch.pos + i, - batch.n_seq_id + i, - batch.seq_id + i, - batch.logits + i, - }; - - const int ret = llama_decode(ctx, batch_view); - - metrics.on_decoded(slots); - - if (ret != 0) { - { - std::string err; - - if (n_batch == 1 && ret == 1) { - // TODO: try to terminate only the largest active slot/sequence and continue with the rest - // need to remove the tokens from the current batch too - err = "Context size has been exceeded."; - } - - if (ret == -1) { - err = "Invalid input batch."; - } - - if (ret < -1) { - // TODO: update slot state based on llama_memory_seq_pos_min() and llama_memory_seq_pos_max() - err = "Compute error."; - } - - // TODO: handle ret == 2 (abort) when we start aborting - - if (!err.empty()) { - SRV_ERR("%s i = %d, n_batch = %d, ret = %d\n", err.c_str(), i, n_batch, ret); - - for (auto & slot : slots) { - if (slot.is_processing()) { - send_error(slot, err); - slot.release(); - - // note: it's complicated to keep track of how much of the current batch has been - // processed before the error occurred, so we simply clear the entire context - clear_slot(slot); - } - } - - break; - } - } - - // retry with half the batch size to try to find a free slot in the KV cache - if (!try_clear_idle_slots()) { - n_batch /= 2; - } - - SRV_WRN("failed to find free space in the KV cache, retrying with smaller batch size, i = %d, n_batch = %d, ret = %d\n", i, n_batch, ret); - - continue; // continue loop of n_batch - } - - // move the head of the batch forward with the number of tokens we just processed - i_next = i + n_tokens; - - // on successful decode, restore the original batch size - n_batch = llama_n_batch(ctx); - - for (auto & slot : slots) { - // optionally send prompt processing progress - if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_DONE_PROMPT) { - if (slot.task->params.stream && slot.task->params.return_progress) { - send_partial_response(slot, {}, true); - } - } - - if (slot.i_batch < (int) i || slot.i_batch >= (int) (i + n_tokens)) { - continue; // continue loop of slots - } - - if (slot.state == SLOT_STATE_DONE_PROMPT) { - if (slot.task->type == SERVER_TASK_TYPE_EMBEDDING) { - // prompt evaluated for embedding - send_embedding(slot, batch_view); - slot.release(); - slot.i_batch = -1; - continue; // continue loop of slots - } - - if (slot.task->type == SERVER_TASK_TYPE_RERANK) { - send_rerank(slot, batch_view); - slot.release(); - slot.i_batch = -1; - continue; // continue loop of slots - } - - // prompt evaluated for next-token prediction - slot.state = SLOT_STATE_GENERATING; - } else if (slot.state != SLOT_STATE_GENERATING) { - continue; // continue loop of slots - } - - const int tok_idx = slot.i_batch - i; - - llama_token id = common_sampler_sample(slot.smpl, ctx, tok_idx); - - slot.i_batch = -1; - - common_sampler_accept(slot.smpl, id, true); - - slot.n_decoded += 1; - - const int64_t t_current = ggml_time_us(); - - if (slot.n_decoded == 1) { - slot.t_start_generation = t_current; - slot.t_prompt_processing = (slot.t_start_generation - slot.t_start_process_prompt) / 1e3; - metrics.on_prompt_eval(slot); - } - - slot.t_token_generation = std::max(1, t_current - slot.t_start_generation) / 1e3; - - completion_token_output result; - result.tok = id; - result.text_to_send = common_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok)); - result.prob = 1.0f; // TODO: set it here instead of doing inside populate_token_probs - - if (slot.task->params.sampling.n_probs > 0) { - populate_token_probs(slot, result, slot.task->params.post_sampling_probs, params_base.special, tok_idx); - } - - if (!process_token(result, slot)) { - // release slot because of stop condition - slot.print_timings(); - send_final_response(slot); - metrics.on_prediction(slot); - slot.release(); - - continue; - } - } - - // do speculative decoding - // TODO: rework to have a single draft llama_context shared across all slots [TAG_SERVER_SPEC_REWORK] - // perform the speculative drafting for all sequences at the same time in a single batch - for (auto & slot : slots) { - if (!slot.is_processing() || !slot.can_speculate()) { - continue; - } - - if (slot.state != SLOT_STATE_GENERATING) { - continue; - } - - if (mctx) { - // we should never reach this, as speculative is automatically disabled if mmproj is loaded - GGML_ABORT("not supported by multimodal"); - } - - // determine the max draft that fits the current slot state - int n_draft_max = slot.task->params.speculative.n_max; - - // note: slot.prompt is not yet expanded with the `id` token sampled above - // also, need to leave space for 1 extra token to allow context shifts - n_draft_max = std::min(n_draft_max, slot.n_ctx - slot.prompt.n_tokens() - 2); - - if (slot.n_remaining > 0) { - n_draft_max = std::min(n_draft_max, slot.n_remaining - 1); - } - - SLT_DBG(slot, "max possible draft: %d\n", n_draft_max); - - if (n_draft_max < slot.task->params.speculative.n_min) { - SLT_DBG(slot, "the max possible draft is too small: %d < %d - skipping speculative decoding\n", n_draft_max, slot.task->params.speculative.n_min); - - continue; - } - - llama_token id = slot.sampled; - - struct common_speculative_params params_spec; - params_spec.n_draft = n_draft_max; - params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.task->params.speculative.n_max; - params_spec.p_min = slot.task->params.speculative.p_min; - - const llama_tokens & cached_text_tokens = slot.prompt.tokens.get_text_tokens(); - llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, cached_text_tokens, id); - - // ignore small drafts - if (slot.task->params.speculative.n_min > (int) draft.size()) { - SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int) draft.size(), slot.task->params.speculative.n_min); - - continue; - } - - // keep track of total number of drafted tokens tested - slot.n_draft_total += draft.size(); - - // construct the speculation batch - common_batch_clear(slot.batch_spec); - common_batch_add (slot.batch_spec, id, slot.prompt.tokens.pos_next(), { slot.id }, true); - - for (size_t i = 0; i < draft.size(); ++i) { - common_batch_add(slot.batch_spec, draft[i], slot.prompt.tokens.pos_next() + 1 + i, { slot.id }, true); - } - - SLT_DBG(slot, "decoding speculative batch, size = %d\n", slot.batch_spec.n_tokens); - - llama_decode(ctx, slot.batch_spec); - - // the accepted tokens from the speculation - const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, draft); - - slot.n_decoded += ids.size(); - - // update how many tokens out of those tested were accepted - slot.n_draft_accepted += ids.size() - 1; - - slot.prompt.tokens.push_back(id); - slot.prompt.tokens.insert({ids.begin(), ids.end() - 1}); - - llama_memory_seq_rm(llama_get_memory(ctx), slot.id, slot.prompt.n_tokens(), -1); - - for (size_t i = 0; i < ids.size(); ++i) { - completion_token_output result; - - result.tok = ids[i]; - result.text_to_send = common_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok)); - result.prob = 1.0f; // set later - - // TODO: set result.probs - - if (!process_token(result, slot)) { - slot.print_timings(); - send_final_response(slot); - metrics.on_prediction(slot); - slot.release(); - - break; - } - } - - SLT_DBG(slot, "accepted %d/%d draft tokens, new n_tokens = %d\n", (int) ids.size() - 1, (int) draft.size(), slot.prompt.n_tokens()); - } - } - - SRV_DBG("%s", "run slots completed\n"); - } - - json model_meta() const { - return json { - {"vocab_type", llama_vocab_type (vocab)}, - {"n_vocab", llama_vocab_n_tokens (vocab)}, - {"n_ctx_train", llama_model_n_ctx_train(model)}, - {"n_embd", llama_model_n_embd (model)}, - {"n_params", llama_model_n_params (model)}, - {"size", llama_model_size (model)}, - }; - } -}; - - -// generator-like API for server responses, support pooling connection state and aggregating results -struct server_response_reader { - std::unordered_set id_tasks; - server_context & ctx_server; - size_t received_count = 0; - bool cancelled = false; - - server_response_reader(server_context & ctx_server) : ctx_server(ctx_server) {} - ~server_response_reader() { - stop(); - } - - void post_tasks(std::vector && tasks) { - id_tasks = server_task::get_list_id(tasks); - ctx_server.queue_results.add_waiting_tasks(tasks); - ctx_server.queue_tasks.post(std::move(tasks)); - } - - bool has_next() const { - return !cancelled && received_count < id_tasks.size(); - } - - // return nullptr if should_stop() is true before receiving a result - // note: if one error is received, it will stop further processing and return error result - server_task_result_ptr next(const std::function & should_stop) { - while (true) { - server_task_result_ptr result = ctx_server.queue_results.recv_with_timeout(id_tasks, HTTP_POLLING_SECONDS); - if (result == nullptr) { - // timeout, check stop condition - if (should_stop()) { - SRV_DBG("%s", "stopping wait for next result due to should_stop condition\n"); - return nullptr; - } - } else { - if (result->is_error()) { - stop(); // cancel remaining tasks - SRV_DBG("%s", "received error result, stopping further processing\n"); - return result; - } - if (result->is_stop()) { - received_count++; - } - return result; - } - } - - // should not reach here - } - - struct batch_response { - bool is_terminated = false; // if true, indicates that processing was stopped before all results were received - std::vector results; - server_task_result_ptr error; // nullptr if no error - }; - - batch_response wait_for_all(const std::function & should_stop) { - batch_response batch_res; - batch_res.results.resize(id_tasks.size()); - while (has_next()) { - auto res = next(should_stop); - if (res == nullptr) { - batch_res.is_terminated = true; - return batch_res; - } - if (res->is_error()) { - batch_res.error = std::move(res); - return batch_res; - } - const size_t idx = res->get_index(); - GGML_ASSERT(idx < batch_res.results.size() && "index out of range"); - GGML_ASSERT(batch_res.results[idx] == nullptr && "duplicate result received"); - batch_res.results[idx] = std::move(res); - } - return batch_res; - } - - void stop() { - ctx_server.queue_results.remove_waiting_task_ids(id_tasks); - if (has_next() && !cancelled) { - // if tasks is not finished yet, cancel them - cancelled = true; - std::vector cancel_tasks; - cancel_tasks.reserve(id_tasks.size()); - for (const auto & id_task : id_tasks) { - SRV_WRN("cancel task, id_task = %d\n", id_task); - server_task task(SERVER_TASK_TYPE_CANCEL); - task.id_target = id_task; - ctx_server.queue_results.remove_waiting_task_id(id_task); - cancel_tasks.push_back(std::move(task)); - } - // push to beginning of the queue, so it has highest priority - ctx_server.queue_tasks.post(std::move(cancel_tasks), true); - } else { - SRV_DBG("%s", "all tasks already finished, no need to cancel\n"); - } - } -}; - -// generator-like API for HTTP response generation -struct server_res_generator : server_http_res { - server_response_reader rd; - server_res_generator(server_context & ctx_server_) : rd(ctx_server_) {} - void ok(const json & response_data) { - status = 200; - data = safe_json_to_str(response_data); - } - void error(const json & error_data) { - status = json_value(error_data, "code", 500); - data = safe_json_to_str({{ "error", error_data }}); - } -}; - -struct server_routes { - const common_params & params; - server_context & ctx_server; - server_http_context & ctx_http; // for reading is_ready - server_routes(const common_params & params, server_context & ctx_server, server_http_context & ctx_http) - : params(params), ctx_server(ctx_server), ctx_http(ctx_http) {} - -public: - // handlers using lambda function, so that they can capture `this` without `std::bind` - - server_http_context::handler_t get_health = [this](const server_http_req &) { - // error and loading states are handled by middleware - auto res = std::make_unique(ctx_server); - res->ok({{"status", "ok"}}); - return res; - }; - - server_http_context::handler_t get_metrics = [this](const server_http_req &) { - auto res = std::make_unique(ctx_server); - if (!params.endpoint_metrics) { - res->error(format_error_response("This server does not support metrics endpoint. Start it with `--metrics`", ERROR_TYPE_NOT_SUPPORTED)); - return res; - } - - // request slots data using task queue - // TODO: use server_response_reader - int task_id = ctx_server.queue_tasks.get_new_id(); - { - server_task task(SERVER_TASK_TYPE_METRICS); - task.id = task_id; - ctx_server.queue_results.add_waiting_task_id(task_id); - ctx_server.queue_tasks.post(std::move(task), true); // high-priority task - } - - // get the result - server_task_result_ptr result = ctx_server.queue_results.recv(task_id); - ctx_server.queue_results.remove_waiting_task_id(task_id); - - if (result->is_error()) { - res->error(result->to_json()); - return res; - } - - // TODO: get rid of this dynamic_cast - auto res_task = dynamic_cast(result.get()); - GGML_ASSERT(res_task != nullptr); - - // metrics definition: https://prometheus.io/docs/practices/naming/#metric-names - json all_metrics_def = json { - {"counter", {{ - {"name", "prompt_tokens_total"}, - {"help", "Number of prompt tokens processed."}, - {"value", (uint64_t) res_task->n_prompt_tokens_processed_total} - }, { - {"name", "prompt_seconds_total"}, - {"help", "Prompt process time"}, - {"value", (uint64_t) res_task->t_prompt_processing_total / 1.e3} - }, { - {"name", "tokens_predicted_total"}, - {"help", "Number of generation tokens processed."}, - {"value", (uint64_t) res_task->n_tokens_predicted_total} - }, { - {"name", "tokens_predicted_seconds_total"}, - {"help", "Predict process time"}, - {"value", (uint64_t) res_task->t_tokens_generation_total / 1.e3} - }, { - {"name", "n_decode_total"}, - {"help", "Total number of llama_decode() calls"}, - {"value", res_task->n_decode_total} - }, { - {"name", "n_tokens_max"}, - {"help", "Largest observed n_tokens."}, - {"value", res_task->n_tokens_max} - }, { - {"name", "n_busy_slots_per_decode"}, - {"help", "Average number of busy slots per llama_decode() call"}, - {"value", (float) res_task->n_busy_slots_total / std::max((float) res_task->n_decode_total, 1.f)} - }}}, - {"gauge", {{ - {"name", "prompt_tokens_seconds"}, - {"help", "Average prompt throughput in tokens/s."}, - {"value", res_task->n_prompt_tokens_processed ? 1.e3 / res_task->t_prompt_processing * res_task->n_prompt_tokens_processed : 0.} - },{ - {"name", "predicted_tokens_seconds"}, - {"help", "Average generation throughput in tokens/s."}, - {"value", res_task->n_tokens_predicted ? 1.e3 / res_task->t_tokens_generation * res_task->n_tokens_predicted : 0.} - },{ - {"name", "requests_processing"}, - {"help", "Number of requests processing."}, - {"value", (uint64_t) res_task->n_processing_slots} - },{ - {"name", "requests_deferred"}, - {"help", "Number of requests deferred."}, - {"value", (uint64_t) res_task->n_tasks_deferred} - }}} - }; - - std::stringstream prometheus; - - for (const auto & el : all_metrics_def.items()) { - const auto & type = el.key(); - const auto & metrics_def = el.value(); - - for (const auto & metric_def : metrics_def) { - const std::string name = metric_def.at("name"); - const std::string help = metric_def.at("help"); - - auto value = json_value(metric_def, "value", 0.); - prometheus << "# HELP llamacpp:" << name << " " << help << "\n" - << "# TYPE llamacpp:" << name << " " << type << "\n" - << "llamacpp:" << name << " " << value << "\n"; - } - } - - res->headers["Process-Start-Time-Unix"] = std::to_string(res_task->t_start); - res->content_type = "text/plain; version=0.0.4"; - res->status = 200; - res->data = prometheus.str(); - return res; - }; - - server_http_context::handler_t get_slots = [this](const server_http_req & req) { - auto res = std::make_unique(ctx_server); - if (!params.endpoint_slots) { - res->error(format_error_response("This server does not support slots endpoint. Start it with `--slots`", ERROR_TYPE_NOT_SUPPORTED)); - return res; - } - - // request slots data using task queue - int task_id = ctx_server.queue_tasks.get_new_id(); - { - server_task task(SERVER_TASK_TYPE_METRICS); - task.id = task_id; - ctx_server.queue_results.add_waiting_task_id(task_id); - ctx_server.queue_tasks.post(std::move(task), true); // high-priority task - } - - // get the result - server_task_result_ptr result = ctx_server.queue_results.recv(task_id); - ctx_server.queue_results.remove_waiting_task_id(task_id); - - if (result->is_error()) { - res->error(result->to_json()); - return res; - } - - // TODO: get rid of this dynamic_cast - auto res_task = dynamic_cast(result.get()); - GGML_ASSERT(res_task != nullptr); - - // optionally return "fail_on_no_slot" error - if (!req.get_param("fail_on_no_slot").empty()) { - if (res_task->n_idle_slots == 0) { - res->error(format_error_response("no slot available", ERROR_TYPE_UNAVAILABLE)); - return res; - } - } - - res->ok(res_task->slots_data); - return res; - }; - - server_http_context::handler_t post_slots = [this](const server_http_req & req) { - auto res = std::make_unique(ctx_server); - if (params.slot_save_path.empty()) { - res->error(format_error_response("This server does not support slots action. Start it with `--slot-save-path`", ERROR_TYPE_NOT_SUPPORTED)); - return res; - } - - std::string id_slot_str = req.get_param("id_slot"); - int id_slot; - - try { - id_slot = std::stoi(id_slot_str); - } catch (const std::exception &) { - res->error(format_error_response("Invalid slot ID", ERROR_TYPE_INVALID_REQUEST)); - return res; - } - - std::string action = req.get_param("action"); - - if (action == "save") { - return handle_slots_save(req, id_slot); - } else if (action == "restore") { - return handle_slots_restore(req, id_slot); - } else if (action == "erase") { - return handle_slots_erase(req, id_slot); - } else { - res->error(format_error_response("Invalid action", ERROR_TYPE_INVALID_REQUEST)); - return res; - } - }; - - server_http_context::handler_t get_props = [this](const server_http_req &) { - auto res = std::make_unique(ctx_server); - json default_generation_settings_for_props; - - { - task_params params; - - params.sampling = ctx_server.params_base.sampling; - - default_generation_settings_for_props = json { - {"params", params.to_json(true)}, - {"n_ctx", ctx_server.slots[0].n_ctx}, - }; - } - - // this endpoint is publicly available, please only return what is safe to be exposed - json data = { - { "default_generation_settings", default_generation_settings_for_props }, - { "total_slots", ctx_server.params_base.n_parallel }, - { "model_alias", ctx_server.params_base.model_alias }, - { "model_path", ctx_server.params_base.model.path }, - { "modalities", json { - {"vision", ctx_server.oai_parser_opt.allow_image}, - {"audio", ctx_server.oai_parser_opt.allow_audio}, - } }, - { "endpoint_slots", params.endpoint_slots }, - { "endpoint_props", params.endpoint_props }, - { "endpoint_metrics", params.endpoint_metrics }, - { "webui", params.webui }, - { "chat_template", common_chat_templates_source(ctx_server.chat_templates.get()) }, - { "bos_token", common_token_to_piece(ctx_server.ctx, llama_vocab_bos(ctx_server.vocab), /* special= */ true)}, - { "eos_token", common_token_to_piece(ctx_server.ctx, llama_vocab_eos(ctx_server.vocab), /* special= */ true)}, - { "build_info", build_info }, - }; - if (ctx_server.params_base.use_jinja) { - if (auto tool_use_src = common_chat_templates_source(ctx_server.chat_templates.get(), "tool_use")) { - data["chat_template_tool_use"] = tool_use_src; - } - } - - res->ok(data); - return res; - }; - - server_http_context::handler_t post_props = [this](const server_http_req &) { - auto res = std::make_unique(ctx_server); - if (!params.endpoint_props) { - res->error(format_error_response("This server does not support changing global properties. Start it with `--props`", ERROR_TYPE_NOT_SUPPORTED)); - return res; - } - // update any props here - - res->ok({{ "success", true }}); - return res; - }; - - server_http_context::handler_t get_api_show = [this](const server_http_req &) { - auto res = std::make_unique(ctx_server); - bool has_mtmd = ctx_server.mctx != nullptr; - json data = { - { - "template", common_chat_templates_source(ctx_server.chat_templates.get()), - }, - { - "model_info", { - { "llama.context_length", ctx_server.slots.back().n_ctx, }, - } - }, - {"modelfile", ""}, - {"parameters", ""}, - {"template", common_chat_templates_source(ctx_server.chat_templates.get())}, - {"details", { - {"parent_model", ""}, - {"format", "gguf"}, - {"family", ""}, - {"families", {""}}, - {"parameter_size", ""}, - {"quantization_level", ""} - }}, - {"model_info", ""}, - {"capabilities", has_mtmd ? json({"completion","multimodal"}) : json({"completion"})} - }; - - res->ok(data); - return res; - }; - - server_http_context::handler_t post_infill = [this](const server_http_req & req) { - auto res = std::make_unique(ctx_server); - // check model compatibility - std::string err; - if (llama_vocab_fim_pre(ctx_server.vocab) == LLAMA_TOKEN_NULL) { - err += "prefix token is missing. "; - } - if (llama_vocab_fim_suf(ctx_server.vocab) == LLAMA_TOKEN_NULL) { - err += "suffix token is missing. "; - } - if (llama_vocab_fim_mid(ctx_server.vocab) == LLAMA_TOKEN_NULL) { - err += "middle token is missing. "; - } - if (!err.empty()) { - res->error(format_error_response(string_format("Infill is not supported by this model: %s", err.c_str()), ERROR_TYPE_NOT_SUPPORTED)); - return res; - } - - // validate input - json data = json::parse(req.body); - if (data.contains("prompt") && !data.at("prompt").is_string()) { - // prompt is optional - res->error(format_error_response("\"prompt\" must be a string", ERROR_TYPE_INVALID_REQUEST)); - } - - if (!data.contains("input_prefix")) { - res->error(format_error_response("\"input_prefix\" is required", ERROR_TYPE_INVALID_REQUEST)); - } - - if (!data.contains("input_suffix")) { - res->error(format_error_response("\"input_suffix\" is required", ERROR_TYPE_INVALID_REQUEST)); - } - - if (data.contains("input_extra") && !data.at("input_extra").is_array()) { - // input_extra is optional - res->error(format_error_response("\"input_extra\" must be an array of {\"filename\": string, \"text\": string}", ERROR_TYPE_INVALID_REQUEST)); - return res; - } - - json input_extra = json_value(data, "input_extra", json::array()); - for (const auto & chunk : input_extra) { - // { "text": string, "filename": string } - if (!chunk.contains("text") || !chunk.at("text").is_string()) { - res->error(format_error_response("extra_context chunk must contain a \"text\" field with a string value", ERROR_TYPE_INVALID_REQUEST)); - return res; - } - // filename is optional - if (chunk.contains("filename") && !chunk.at("filename").is_string()) { - res->error(format_error_response("extra_context chunk's \"filename\" field must be a string", ERROR_TYPE_INVALID_REQUEST)); - return res; - } - } - data["input_extra"] = input_extra; // default to empty array if it's not exist - - std::string prompt = json_value(data, "prompt", std::string()); - std::vector tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, false, true); - SRV_DBG("creating infill tasks, n_prompts = %d\n", (int) tokenized_prompts.size()); - data["prompt"] = format_prompt_infill( - ctx_server.vocab, - data.at("input_prefix"), - data.at("input_suffix"), - data.at("input_extra"), - ctx_server.params_base.n_batch, - ctx_server.params_base.n_predict, - ctx_server.slots[0].n_ctx, // TODO: there should be a better way - ctx_server.params_base.spm_infill, - tokenized_prompts[0].get_text_tokens() // TODO: this could maybe be multimodal. - ); - - std::vector files; // dummy - return handle_completions_impl( - SERVER_TASK_TYPE_INFILL, - data, - files, - req.should_stop, - TASK_RESPONSE_TYPE_NONE); // infill is not OAI compatible - }; - - server_http_context::handler_t post_completions = [this](const server_http_req & req) { - std::vector files; // dummy - const json body = json::parse(req.body); - return handle_completions_impl( - SERVER_TASK_TYPE_COMPLETION, - body, - files, - req.should_stop, - TASK_RESPONSE_TYPE_NONE); - }; - - server_http_context::handler_t post_completions_oai = [this](const server_http_req & req) { - std::vector files; // dummy - const json body = json::parse(req.body); - return handle_completions_impl( - SERVER_TASK_TYPE_COMPLETION, - body, - files, - req.should_stop, - TASK_RESPONSE_TYPE_OAI_CMPL); - }; - - server_http_context::handler_t post_chat_completions = [this](const server_http_req & req) { - std::vector files; - json body = json::parse(req.body); - json body_parsed = oaicompat_chat_params_parse( - body, - ctx_server.oai_parser_opt, - files); - return handle_completions_impl( - SERVER_TASK_TYPE_COMPLETION, - body_parsed, - files, - req.should_stop, - TASK_RESPONSE_TYPE_OAI_CHAT); - }; - - server_http_context::handler_t post_anthropic_messages = [this](const server_http_req & req) { - std::vector files; - json body = convert_anthropic_to_oai(json::parse(req.body)); - json body_parsed = oaicompat_chat_params_parse( - body, - ctx_server.oai_parser_opt, - files); - return handle_completions_impl( - SERVER_TASK_TYPE_COMPLETION, - body_parsed, - files, - req.should_stop, - TASK_RESPONSE_TYPE_ANTHROPIC); - }; - - server_http_context::handler_t post_anthropic_count_tokens = [this](const server_http_req & req) { - auto res = std::make_unique(ctx_server); - std::vector files; - json body = convert_anthropic_to_oai(json::parse(req.body)); - json body_parsed = oaicompat_chat_params_parse( - body, - ctx_server.oai_parser_opt, - files); - - json prompt = body_parsed.at("prompt"); - llama_tokens tokens = tokenize_mixed(ctx_server.vocab, prompt, true, true); - - res->ok({{"input_tokens", static_cast(tokens.size())}}); - return res; - }; - - // same with handle_chat_completions, but without inference part - server_http_context::handler_t post_apply_template = [this](const server_http_req & req) { - auto res = std::make_unique(ctx_server); - std::vector files; // dummy, unused - json body = json::parse(req.body); - json data = oaicompat_chat_params_parse( - body, - ctx_server.oai_parser_opt, - files); - res->ok({{ "prompt", std::move(data.at("prompt")) }}); - return res; - }; - - server_http_context::handler_t get_models = [this](const server_http_req &) { - auto res = std::make_unique(ctx_server); - bool is_model_ready = ctx_http.is_ready.load(); - json model_meta = nullptr; - if (is_model_ready) { - model_meta = ctx_server.model_meta(); - } - bool has_mtmd = ctx_server.mctx != nullptr; - json models = { - {"models", { - { - {"name", params.model_alias.empty() ? params.model.path : params.model_alias}, - {"model", params.model_alias.empty() ? params.model.path : params.model_alias}, - {"modified_at", ""}, - {"size", ""}, - {"digest", ""}, // dummy value, llama.cpp does not support managing model file's hash - {"type", "model"}, - {"description", ""}, - {"tags", {""}}, - {"capabilities", has_mtmd ? json({"completion","multimodal"}) : json({"completion"})}, - {"parameters", ""}, - {"details", { - {"parent_model", ""}, - {"format", "gguf"}, - {"family", ""}, - {"families", {""}}, - {"parameter_size", ""}, - {"quantization_level", ""} - }} - } - }}, - {"object", "list"}, - {"data", { - { - {"id", params.model_alias.empty() ? params.model.path : params.model_alias}, - {"object", "model"}, - {"created", std::time(0)}, - {"owned_by", "llamacpp"}, - {"meta", model_meta}, - }, - }} - }; - - res->ok(models); - return res; - }; - - server_http_context::handler_t post_tokenize = [this](const server_http_req & req) { - auto res = std::make_unique(ctx_server); - const json body = json::parse(req.body); - json tokens_response = json::array(); - if (body.count("content") != 0) { - const bool add_special = json_value(body, "add_special", false); - const bool parse_special = json_value(body, "parse_special", true); - const bool with_pieces = json_value(body, "with_pieces", false); - - llama_tokens tokens = tokenize_mixed(ctx_server.vocab, body.at("content"), add_special, parse_special); - - if (with_pieces) { - for (const auto& token : tokens) { - std::string piece = common_token_to_piece(ctx_server.ctx, token); - json piece_json; - - // Check if the piece is valid UTF-8 - if (is_valid_utf8(piece)) { - piece_json = piece; - } else { - // If not valid UTF-8, store as array of byte values - piece_json = json::array(); - for (unsigned char c : piece) { - piece_json.push_back(static_cast(c)); - } - } - - tokens_response.push_back({ - {"id", token}, - {"piece", piece_json} - }); - } - } else { - tokens_response = tokens; - } - } - - res->ok(json{{"tokens", std::move(tokens_response)}}); - return res; - }; - - server_http_context::handler_t post_detokenize = [this](const server_http_req & req) { - auto res = std::make_unique(ctx_server); - const json body = json::parse(req.body); - - std::string content; - if (body.count("tokens") != 0) { - const llama_tokens tokens = body.at("tokens"); - content = tokens_to_str(ctx_server.ctx, tokens); - } - - res->ok(json{{"content", std::move(content)}}); - return res; - }; - - server_http_context::handler_t post_embeddings = [this](const server_http_req & req) { - return handle_embeddings_impl(req, TASK_RESPONSE_TYPE_NONE); - }; - - server_http_context::handler_t post_embeddings_oai = [this](const server_http_req & req) { - return handle_embeddings_impl(req, TASK_RESPONSE_TYPE_OAI_EMBD); - }; - - server_http_context::handler_t post_rerank = [this](const server_http_req & req) { - auto res = std::make_unique(ctx_server); - if (!ctx_server.params_base.embedding || ctx_server.params_base.pooling_type != LLAMA_POOLING_TYPE_RANK) { - res->error(format_error_response("This server does not support reranking. Start it with `--reranking`", ERROR_TYPE_NOT_SUPPORTED)); - return res; - } - - const json body = json::parse(req.body); - - // if true, use TEI API format, otherwise use Jina API format - // Jina: https://jina.ai/reranker/ - // TEI: https://huggingface.github.io/text-embeddings-inference/#/Text%20Embeddings%20Inference/rerank - bool is_tei_format = body.contains("texts"); - - json query; - if (body.count("query") == 1) { - query = body.at("query"); - if (!query.is_string()) { - res->error(format_error_response("\"query\" must be a string", ERROR_TYPE_INVALID_REQUEST)); - return res; - } - } else { - res->error(format_error_response("\"query\" must be provided", ERROR_TYPE_INVALID_REQUEST)); - return res; - } - - std::vector documents = json_value(body, "documents", - json_value(body, "texts", std::vector())); - if (documents.empty()) { - res->error(format_error_response("\"documents\" must be a non-empty string array", ERROR_TYPE_INVALID_REQUEST)); - return res; - } - - int top_n = json_value(body, "top_n", (int)documents.size()); - - // create and queue the task - json responses = json::array(); - server_response_reader rd(ctx_server); - { - std::vector tasks; - tasks.reserve(documents.size()); - for (size_t i = 0; i < documents.size(); i++) { - auto tmp = format_prompt_rerank(ctx_server.model, ctx_server.vocab, ctx_server.mctx, query, documents[i]); - server_task task = server_task(SERVER_TASK_TYPE_RERANK); - task.id = ctx_server.queue_tasks.get_new_id(); - task.index = i; - task.tokens = std::move(tmp); - tasks.push_back(std::move(task)); - } - rd.post_tasks(std::move(tasks)); - } - - // wait for the results - auto all_results = rd.wait_for_all(req.should_stop); - - // collect results - if (all_results.is_terminated) { - return res; // connection is closed - } else if (all_results.error) { - res->error(all_results.error->to_json()); - return res; - } else { - for (auto & res : all_results.results) { - GGML_ASSERT(dynamic_cast(res.get()) != nullptr); - responses.push_back(res->to_json()); - } - } - - // write JSON response - json root = format_response_rerank( - body, - responses, - is_tei_format, - documents, - top_n); - - res->ok(root); - return res; - }; - - server_http_context::handler_t get_lora_adapters = [this](const server_http_req &) { - auto res = std::make_unique(ctx_server); - json result = json::array(); - const auto & loras = ctx_server.params_base.lora_adapters; - for (size_t i = 0; i < loras.size(); ++i) { - auto & lora = loras[i]; - json entry = { - {"id", i}, - {"path", lora.path}, - {"scale", lora.scale}, - {"task_name", lora.task_name}, - {"prompt_prefix", lora.prompt_prefix}, - }; - std::string alora_invocation_string = ""; - const uint64_t n_alora_tokens = llama_adapter_get_alora_n_invocation_tokens(lora.ptr); - std::vector alora_invocation_tokens; - if (n_alora_tokens) { - const llama_token * alora_tokens = llama_adapter_get_alora_invocation_tokens(lora.ptr); - for (uint64_t i = 0; i < n_alora_tokens; ++i) { - alora_invocation_string += common_token_to_piece(ctx_server.ctx, alora_tokens[i]); - alora_invocation_tokens.push_back(alora_tokens[i]); - } - entry["alora_invocation_string"] = alora_invocation_string; - entry["alora_invocation_tokens"] = alora_invocation_tokens; - } - result.push_back(std::move(entry)); - } - res->ok(result); - return res; - }; - - server_http_context::handler_t post_lora_adapters = [this](const server_http_req & req) { - auto res = std::make_unique(ctx_server); - const json body = json::parse(req.body); - if (!body.is_array()) { - res->error(format_error_response("Request body must be an array", ERROR_TYPE_INVALID_REQUEST)); - return res; - } - - int task_id = ctx_server.queue_tasks.get_new_id(); - { - server_task task(SERVER_TASK_TYPE_SET_LORA); - task.id = task_id; - task.set_lora = parse_lora_request(ctx_server.params_base.lora_adapters, body); - ctx_server.queue_results.add_waiting_task_id(task_id); - ctx_server.queue_tasks.post(std::move(task)); - } - - // get the result - server_task_result_ptr result = ctx_server.queue_results.recv(task_id); - ctx_server.queue_results.remove_waiting_task_id(task_id); - - if (result->is_error()) { - res->error(result->to_json()); - return res; - } - - GGML_ASSERT(dynamic_cast(result.get()) != nullptr); - res->ok(result->to_json()); - return res; - }; - -private: - std::unique_ptr handle_completions_impl( - server_task_type type, - const json & data, - const std::vector & files, - const std::function & should_stop, - task_response_type res_type) { - GGML_ASSERT(type == SERVER_TASK_TYPE_COMPLETION || type == SERVER_TASK_TYPE_INFILL); - - auto res = std::make_unique(ctx_server); - auto completion_id = gen_chatcmplid(); - auto & rd = res->rd; - - try { - std::vector tasks; - - const auto & prompt = data.at("prompt"); - // TODO: this log can become very long, put it behind a flag or think about a more compact format - //SRV_DBG("Prompt: %s\n", prompt.is_string() ? prompt.get().c_str() : prompt.dump(2).c_str()); - - // process prompt - std::vector inputs; - - if (res_type != TASK_RESPONSE_TYPE_NONE && ctx_server.mctx != nullptr) { - // This is the case used by OAI compatible chat path with MTMD. TODO It can be moved to the path below. - inputs.push_back(process_mtmd_prompt(ctx_server.mctx, prompt.get(), files)); - } else { - // Everything else, including multimodal completions. - inputs = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, true, true); - } - tasks.reserve(inputs.size()); - for (size_t i = 0; i < inputs.size(); i++) { - server_task task = server_task(type); - - task.id = ctx_server.queue_tasks.get_new_id(); - task.index = i; - - task.tokens = std::move(inputs[i]); - task.params = server_task::params_from_json_cmpl( - ctx_server.ctx, - ctx_server.params_base, - data); - task.id_slot = json_value(data, "id_slot", -1); - - // OAI-compat - task.params.res_type = res_type; - task.params.oaicompat_cmpl_id = completion_id; - // oaicompat_model is already populated by params_from_json_cmpl - - tasks.push_back(std::move(task)); - } - - rd.post_tasks(std::move(tasks)); - } catch (const std::exception & e) { - res->error(format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST)); - return res; - } - - bool stream = json_value(data, "stream", false); - - if (!stream) { - // non-stream, wait for the results - auto all_results = rd.wait_for_all(should_stop); - if (all_results.is_terminated) { - return res; // connection is closed - } else if (all_results.error) { - res->error(all_results.error->to_json()); - return res; - } else { - json arr = json::array(); - for (auto & res : all_results.results) { - GGML_ASSERT(dynamic_cast(res.get()) != nullptr); - arr.push_back(res->to_json()); - } - // if single request, return single object instead of array - res->ok(arr.size() == 1 ? arr[0] : arr); - } - - } else { - // in streaming mode, the first error must be treated as non-stream response - // this is to match the OAI API behavior - // ref: https://github.com/ggml-org/llama.cpp/pull/16486#discussion_r2419657309 - server_task_result_ptr first_result = rd.next(should_stop); - if (first_result == nullptr) { - return res; // connection is closed - } else if (first_result->is_error()) { - res->error(first_result->to_json()); - return res; - } else { - GGML_ASSERT( - dynamic_cast(first_result.get()) != nullptr - || dynamic_cast(first_result.get()) != nullptr - ); - } - - // next responses are streamed - if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) { - res->data = format_anthropic_sse(first_result->to_json()); - } else { - res->data = format_oai_sse(first_result->to_json()); // to be sent immediately - } - res->status = 200; - res->content_type = "text/event-stream"; - res->next = [res_this = res.get(), res_type, &should_stop](std::string & output) -> bool { - if (should_stop()) { - SRV_DBG("%s", "stopping streaming due to should_stop condition\n"); - return false; // should_stop condition met - } - - if (!res_this->data.empty()) { - // flush the first chunk - output = std::move(res_this->data); - res_this->data.clear(); - return true; - } - - server_response_reader & rd = res_this->rd; - - // check if there is more data - if (!rd.has_next()) { - if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) { - // Anthropic doesn't send [DONE], message_stop was already sent - output = ""; - } else if (res_type != TASK_RESPONSE_TYPE_NONE) { - output = "data: [DONE]\n\n"; - } else { - output = ""; - } - SRV_DBG("%s", "all results received, terminating stream\n"); - return false; // no more data, terminate - } - - // receive subsequent results - auto result = rd.next(should_stop); - if (result == nullptr) { - SRV_DBG("%s", "stopping streaming due to should_stop condition\n"); - return false; // should_stop condition met - } - - // send the results - json res_json = result->to_json(); - if (result->is_error()) { - if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) { - output = format_anthropic_sse({ - {"event", "error"}, - {"data", res_json}, - }); - } else { - output = format_oai_sse(json {{ "error", res_json }}); - } - SRV_DBG("%s", "error received during streaming, terminating stream\n"); - return false; // terminate on error - } else { - GGML_ASSERT( - dynamic_cast(result.get()) != nullptr - || dynamic_cast(result.get()) != nullptr - ); - if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) { - output = format_anthropic_sse(res_json); - } else { - output = format_oai_sse(res_json); - } - } - - // has next data, continue - return true; - }; - } - - return res; - } - - std::unique_ptr handle_slots_save(const server_http_req & req, int id_slot) { - auto res = std::make_unique(ctx_server); - const json request_data = json::parse(req.body); - std::string filename = request_data.at("filename"); - if (!fs_validate_filename(filename)) { - res->error(format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST)); - return res; - } - std::string filepath = params.slot_save_path + filename; - - int task_id = ctx_server.queue_tasks.get_new_id(); - { - server_task task(SERVER_TASK_TYPE_SLOT_SAVE); - task.id = task_id; - task.slot_action.slot_id = id_slot; - task.slot_action.filename = filename; - task.slot_action.filepath = filepath; - - // TODO: use server_response_reader - ctx_server.queue_results.add_waiting_task_id(task_id); - ctx_server.queue_tasks.post(std::move(task)); - } - - server_task_result_ptr result = ctx_server.queue_results.recv(task_id); - ctx_server.queue_results.remove_waiting_task_id(task_id); - - if (result->is_error()) { - res->error(result->to_json()); - return res; - } - - res->ok(result->to_json()); - return res; - } - - std::unique_ptr handle_slots_restore(const server_http_req & req, int id_slot) { - auto res = std::make_unique(ctx_server); - const json request_data = json::parse(req.body); - std::string filename = request_data.at("filename"); - if (!fs_validate_filename(filename)) { - res->error(format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST)); - return res; - } - std::string filepath = params.slot_save_path + filename; - - int task_id = ctx_server.queue_tasks.get_new_id(); - { - server_task task(SERVER_TASK_TYPE_SLOT_RESTORE); - task.id = task_id; - task.slot_action.slot_id = id_slot; - task.slot_action.filename = filename; - task.slot_action.filepath = filepath; - - // TODO: use server_response_reader - ctx_server.queue_results.add_waiting_task_id(task_id); - ctx_server.queue_tasks.post(std::move(task)); - } - - server_task_result_ptr result = ctx_server.queue_results.recv(task_id); - ctx_server.queue_results.remove_waiting_task_id(task_id); - - if (result->is_error()) { - res->error(result->to_json()); - return res; - } - - GGML_ASSERT(dynamic_cast(result.get()) != nullptr); - res->ok(result->to_json()); - return res; - } - - std::unique_ptr handle_slots_erase(const server_http_req &, int id_slot) { - auto res = std::make_unique(ctx_server); - int task_id = ctx_server.queue_tasks.get_new_id(); - { - server_task task(SERVER_TASK_TYPE_SLOT_ERASE); - task.id = task_id; - task.slot_action.slot_id = id_slot; - - // TODO: use server_response_reader - ctx_server.queue_results.add_waiting_task_id(task_id); - ctx_server.queue_tasks.post(std::move(task)); - } - - server_task_result_ptr result = ctx_server.queue_results.recv(task_id); - ctx_server.queue_results.remove_waiting_task_id(task_id); - - if (result->is_error()) { - res->error(result->to_json()); - return res; - } - - GGML_ASSERT(dynamic_cast(result.get()) != nullptr); - res->ok(result->to_json()); - return res; - } - - std::unique_ptr handle_embeddings_impl(const server_http_req & req, task_response_type res_type) { - auto res = std::make_unique(ctx_server); - if (!ctx_server.params_base.embedding) { - res->error(format_error_response("This server does not support embeddings. Start it with `--embeddings`", ERROR_TYPE_NOT_SUPPORTED)); - return res; - } - - if (res_type != TASK_RESPONSE_TYPE_NONE && llama_pooling_type(ctx_server.ctx) == LLAMA_POOLING_TYPE_NONE) { - res->error(format_error_response("Pooling type 'none' is not OAI compatible. Please use a different pooling type", ERROR_TYPE_INVALID_REQUEST)); - return res; - } - - const json body = json::parse(req.body); - - // for the shape of input/content, see tokenize_input_prompts() - json prompt; - if (body.count("input") != 0) { - prompt = body.at("input"); - } else if (body.contains("content")) { - res_type = TASK_RESPONSE_TYPE_NONE; // "content" field is not OAI compatible - prompt = body.at("content"); - } else { - res->error(format_error_response("\"input\" or \"content\" must be provided", ERROR_TYPE_INVALID_REQUEST)); - return res; - } - - bool use_base64 = false; - if (body.count("encoding_format") != 0) { - const std::string& format = body.at("encoding_format"); - if (format == "base64") { - use_base64 = true; - } else if (format != "float") { - res->error(format_error_response("The format to return the embeddings in. Can be either float or base64", ERROR_TYPE_INVALID_REQUEST)); - return res; - } - } - - auto tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, true, true); - for (const auto & tokens : tokenized_prompts) { - // this check is necessary for models that do not add BOS token to the input - if (tokens.empty()) { - res->error(format_error_response("Input content cannot be empty", ERROR_TYPE_INVALID_REQUEST)); - return res; - } - } - - int embd_normalize = 2; // default to Euclidean/L2 norm - if (body.count("embd_normalize") != 0) { - embd_normalize = body.at("embd_normalize"); - if (llama_pooling_type(ctx_server.ctx) == LLAMA_POOLING_TYPE_NONE) { - SRV_DBG("embd_normalize is not supported by pooling type %d, ignoring it\n", llama_pooling_type(ctx_server.ctx)); - } - } - - // create and queue the task - json responses = json::array(); - server_response_reader rd(ctx_server); - { - std::vector tasks; - for (size_t i = 0; i < tokenized_prompts.size(); i++) { - server_task task = server_task(SERVER_TASK_TYPE_EMBEDDING); - - task.id = ctx_server.queue_tasks.get_new_id(); - task.index = i; - task.tokens = std::move(tokenized_prompts[i]); - - // OAI-compat - task.params.res_type = res_type; - task.params.embd_normalize = embd_normalize; - - tasks.push_back(std::move(task)); - } - rd.post_tasks(std::move(tasks)); - } - - // wait for the results - auto all_results = rd.wait_for_all(req.should_stop); - - // collect results - if (all_results.is_terminated) { - return res; // connection is closed - } else if (all_results.error) { - res->error(all_results.error->to_json()); - return res; - } else { - for (auto & res : all_results.results) { - GGML_ASSERT(dynamic_cast(res.get()) != nullptr); - responses.push_back(res->to_json()); - } - } - - // write JSON response - json root = res_type == TASK_RESPONSE_TYPE_OAI_EMBD - ? format_embeddings_response_oaicompat(body, responses, use_base64) - : json(responses); - res->ok(root); - return res; - } -}; - static std::function shutdown_handler; static std::atomic_flag is_terminating = ATOMIC_FLAG_INIT; @@ -3723,9 +80,6 @@ int main(int argc, char ** argv) { // struct that contains llama context and inference server_context ctx_server; - // Necessary similarity of prompt for slot selection - ctx_server.slot_prompt_similarity = params.slot_prompt_similarity; - llama_backend_init(); llama_numa_init(params.numa); @@ -3745,7 +99,7 @@ int main(int argc, char ** argv) { // // register API routes - server_routes routes(params, ctx_server, ctx_http); + server_routes routes(params, ctx_server, [&ctx_http]() { return ctx_http.is_ready.load(); }); ctx_http.get ("/health", ex_wrapper(routes.get_health)); // public endpoint (no API key check) ctx_http.get ("/v1/health", ex_wrapper(routes.get_health)); // public endpoint (no API key check) @@ -3790,7 +144,7 @@ int main(int argc, char ** argv) { auto clean_up = [&ctx_http, &ctx_server]() { SRV_INF("%s: cleaning up before exit...\n", __func__); ctx_http.stop(); - ctx_server.queue_results.terminate(); + ctx_server.terminate(); llama_backend_free(); }; @@ -3818,17 +172,9 @@ int main(int argc, char ** argv) { LOG_INF("%s: model loaded\n", __func__); - ctx_server.queue_tasks.on_new_task([&ctx_server](server_task && task) { - ctx_server.process_single_task(std::move(task)); - }); - - ctx_server.queue_tasks.on_update_slots([&ctx_server]() { - ctx_server.update_slots(); - }); - shutdown_handler = [&](int) { // this will unblock start_loop() - ctx_server.queue_tasks.terminate(); + ctx_server.terminate(); }; // TODO: refactor in common/console @@ -3848,14 +194,14 @@ int main(int argc, char ** argv) { LOG_INF("%s: server is listening on %s\n", __func__, ctx_http.listening_address.c_str()); LOG_INF("%s: starting the main loop...\n", __func__); - // this call blocks the main thread until queue_tasks.terminate() is called - ctx_server.queue_tasks.start_loop(); + // this call blocks the main thread until ctx_server.terminate() is called + ctx_server.start_loop(); clean_up(); if (ctx_http.thread.joinable()) { ctx_http.thread.join(); } - llama_memory_breakdown_print(ctx_server.ctx); + llama_memory_breakdown_print(ctx_server.get_llama_context()); return 0; }