From a864fb1c14cf187831a1806c3b57e47c730afc1d Mon Sep 17 00:00:00 2001
From: Daniel Bevenius
Date: Tue, 30 Dec 2025 10:13:12 +0100
Subject: [PATCH 1/5] model-conversion : use CONVERTED_MODEL for compare-embeddings (#18461)

This commit updates the causal model verification script to use the
CONVERTED_MODEL environment variable instead of using MODEL_PATH (the
original model path) as the basis for the converted model file name.

The motivation for this is that currently, if the converted model file
name differs from the original model directory/name, the verification
script will look for the wrong .bin file, not the one that was
generated when running the converted model.

This is similar to the change made for the embeddings model script in
commit db81d5ec4b0a9cb19e98c4533731c9554eb025db ("model-conversion :
use CONVERTED_EMBEDDING_MODEL for embedding_verify_logits (#18079)"),
since the embeddings are also verified for causal models.
---
 .../scripts/causal/compare-embeddings-logits.sh | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/examples/model-conversion/scripts/causal/compare-embeddings-logits.sh b/examples/model-conversion/scripts/causal/compare-embeddings-logits.sh
index c53c89d48a..2ae4dc7061 100755
--- a/examples/model-conversion/scripts/causal/compare-embeddings-logits.sh
+++ b/examples/model-conversion/scripts/causal/compare-embeddings-logits.sh
@@ -5,8 +5,11 @@ set -e
 MODEL_PATH="${1:-"$MODEL_PATH"}"
 MODEL_NAME="${2:-$(basename "$MODEL_PATH")}"
 
+CONVERTED_MODEL_PATH="${1:-"$CONVERTED_MODEL"}"
+CONVERTED_MODEL_NAME="${2:-$(basename "$CONVERTED_MODEL_PATH" ".gguf")}"
+
 if [ -t 0 ]; then
-    CPP_EMBEDDINGS="data/llamacpp-${MODEL_NAME}-embeddings.bin"
+    CPP_EMBEDDINGS="data/llamacpp-${CONVERTED_MODEL_NAME}-embeddings.bin"
 else
     # Process piped JSON data and convert to binary (matching logits.cpp format)
     TEMP_FILE=$(mktemp /tmp/tmp.XXXXXX.binn)

From d77d7c5c0654dc52b51f03941b12ae85d7227608 Mon Sep 17 00:00:00 2001
From: Aman Gupta
Date: Tue, 30 Dec 2025 17:40:46 +0800
Subject: [PATCH 2/5] CUDA: add log line when mxfp4 acceleration is used (#18483)

* CUDA: add log line when mxfp4 acceleration is used

* add in backend_get_features
---
 ggml/src/ggml-cuda/ggml-cuda.cu | 10 ++++++++++
 1 file changed, 10 insertions(+)

diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index 40ffe92c57..55e1c20c96 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -4785,6 +4785,16 @@ static ggml_backend_feature * ggml_backend_cuda_get_features(ggml_backend_reg_t
         features.push_back({ "FA_ALL_QUANTS", "1" });
     #endif
 
+    {
+        const auto & info = ggml_cuda_info();
+        for (int id = 0; id < info.device_count; ++id) {
+            if (blackwell_mma_available(info.devices[id].cc)) {
+                features.push_back({ "BLACKWELL_NATIVE_FP4", "1"});
+                break;
+            }
+        }
+    }
+
     #undef _STRINGIFY
     #undef STRINGIFY
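The BLACKWELL_NATIVE_FP4 entry added above is exposed through the generic ggml backend-feature mechanism, so it can be read back at runtime. Below is a minimal sketch of such a query, assuming a build with the CUDA backend and the registry helpers ggml_backend_reg_by_name / ggml_backend_reg_get_proc_address from ggml-backend.h; everything outside the diff above is an assumption, not part of the patch. With dynamically loaded backends, a prior ggml_backend_load_all() call may also be needed.

    // Sketch: print the features the CUDA backend reports, e.g. to check
    // whether BLACKWELL_NATIVE_FP4 (native mxfp4 acceleration) is present.
    #include <cstdio>
    #include "ggml-backend.h"

    int main() {
        ggml_backend_reg_t reg = ggml_backend_reg_by_name("CUDA");
        if (reg == nullptr) {
            printf("CUDA backend not available\n");
            return 1;
        }
        auto * get_features = (ggml_backend_get_features_t)
            ggml_backend_reg_get_proc_address(reg, "ggml_backend_get_features");
        if (get_features != nullptr) {
            // the feature array is terminated by an entry with a null name
            for (ggml_backend_feature * f = get_features(reg); f->name != nullptr; f++) {
                printf("%s = %s\n", f->name, f->value);
            }
        }
        return 0;
    }
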
From 2d6c00a9b8bb4d72f4f43c6521ff1061088e2c2c Mon Sep 17 00:00:00 2001
From: Charles Xu
Date: Tue, 30 Dec 2025 13:04:53 +0100
Subject: [PATCH 3/5] kleidiai: add and integrate SVE 256-bit vector-length kernel (#18458)

* kleidiai: add and integrate SVE 256-bit vector-length kernel

* updated for review comments
---
 ggml/src/CMakeLists.txt                 |   4 +-
 ggml/src/ggml-cpu/CMakeLists.txt        |  14 ++-
 ggml/src/ggml-cpu/kleidiai/kernels.cpp  | 111 ++++++++++++++++++------
 ggml/src/ggml-cpu/kleidiai/kleidiai.cpp |  23 +++--
 4 files changed, 115 insertions(+), 37 deletions(-)

diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt
index 25f25c4236..6192a87046 100644
--- a/ggml/src/CMakeLists.txt
+++ b/ggml/src/CMakeLists.txt
@@ -401,8 +401,8 @@ if (GGML_CPU_ALL_VARIANTS)
         ggml_add_cpu_backend_variant(android_armv8.2_2 DOTPROD FP16_VECTOR_ARITHMETIC)
         ggml_add_cpu_backend_variant(android_armv8.6_1 DOTPROD FP16_VECTOR_ARITHMETIC MATMUL_INT8)
         ggml_add_cpu_backend_variant(android_armv9.0_1 DOTPROD MATMUL_INT8 FP16_VECTOR_ARITHMETIC SVE2)
-        ggml_add_cpu_backend_variant(android_armv9.2_1 DOTPROD MATMUL_INT8 FP16_VECTOR_ARITHMETIC SME)
-        ggml_add_cpu_backend_variant(android_armv9.2_2 DOTPROD MATMUL_INT8 FP16_VECTOR_ARITHMETIC SVE SME)
+        ggml_add_cpu_backend_variant(android_armv9.2_1 DOTPROD MATMUL_INT8 FP16_VECTOR_ARITHMETIC SVE SME)
+        ggml_add_cpu_backend_variant(android_armv9.2_2 DOTPROD MATMUL_INT8 FP16_VECTOR_ARITHMETIC SVE SVE2 SME)
     elseif (APPLE)
         ggml_add_cpu_backend_variant(apple_m1 DOTPROD)
         ggml_add_cpu_backend_variant(apple_m2_m3 DOTPROD MATMUL_INT8)
diff --git a/ggml/src/ggml-cpu/CMakeLists.txt b/ggml/src/ggml-cpu/CMakeLists.txt
index 28fb7612e5..7622d0bf49 100644
--- a/ggml/src/ggml-cpu/CMakeLists.txt
+++ b/ggml/src/ggml-cpu/CMakeLists.txt
@@ -561,9 +561,9 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
         # Fetch KleidiAI sources:
         include(FetchContent)
 
-        set(KLEIDIAI_COMMIT_TAG "v1.14.0")
+        set(KLEIDIAI_COMMIT_TAG "v1.16.0")
         set(KLEIDIAI_DOWNLOAD_URL "https://github.com/ARM-software/kleidiai/archive/refs/tags/${KLEIDIAI_COMMIT_TAG}.tar.gz")
-        set(KLEIDIAI_ARCHIVE_MD5 "45e110675d93f99f82c23a1afcca76bc")
+        set(KLEIDIAI_ARCHIVE_MD5 "0a9e9008adb6031f9e8cf70dff4a3321")
 
         if (POLICY CMP0135)
             cmake_policy(SET CMP0135 NEW)
@@ -615,6 +615,7 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
         string(FIND "${ARCH_FLAGS_TEMP}" "+dotprod" DOTPROD_ENABLED)
         string(FIND "${ARCH_FLAGS_TEMP}" "+i8mm" I8MM_ENABLED)
         string(FIND "${ARCH_FLAGS_TEMP}" "+sme" SME_ENABLED)
+        string(FIND "${ARCH_FLAGS_TEMP}" "+sve" SVE_ENABLED)
 
         set(PRIVATE_ARCH_FLAGS ${ARCH_FLAGS_TEMP})
 
@@ -659,6 +660,15 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
             set(PRIVATE_ARCH_FLAGS "-fno-tree-vectorize;${PRIVATE_ARCH_FLAGS}+sve+sve2")
         endif()
 
+        if (NOT SVE_ENABLED MATCHES -1)
+            list(APPEND GGML_KLEIDIAI_SOURCES
+                ${KLEIDIAI_SRC}/kai/kai_common_sve_asm.S
+                ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p8x8_1x8_sve_dotprod_asm.S
+                ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p8x8_1x8_sve_dotprod.c
+                ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p8x8_16x8_sve_i8mm_asm.S
+                ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p8x8_16x8_sve_i8mm.c)
+        endif()
+
         set_source_files_properties(${GGML_KLEIDIAI_SOURCES} PROPERTIES COMPILE_OPTIONS "${PRIVATE_ARCH_FLAGS}")
         list(APPEND GGML_CPU_SOURCES ${GGML_KLEIDIAI_SOURCES})
     endif()
diff --git a/ggml/src/ggml-cpu/kleidiai/kernels.cpp b/ggml/src/ggml-cpu/kleidiai/kernels.cpp
index 55a00f008a..d114f2d49b 100644
--- a/ggml/src/ggml-cpu/kleidiai/kernels.cpp
+++ b/ggml/src/ggml-cpu/kleidiai/kernels.cpp
@@ -18,6 +18,8 @@
 #include "kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod.h"
 #include "kai_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod.h"
 #include "kai_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm.h"
+#include "kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p8x8_16x8_sve_i8mm.h"
+#include "kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p8x8_1x8_sve_dotprod.h"
"kai_lhs_pack_bf16p2vlx2_f32_sme.h" #include "kai_lhs_quant_pack_qsi8d32p_f32.h" @@ -69,9 +71,9 @@ static inline void kernel_run_fn10(size_t m, size_t n, size_t k, size_t /*bl*/, template static inline void kernel_run_float_fn10(size_t m, size_t n, size_t k, size_t /*bl*/, - const void* lhs, const void* rhs, void* dst, - size_t dst_stride_row, size_t dst_stride_col, - float clamp_min, float clamp_max) { + const void* lhs, const void* rhs, void* dst, + size_t dst_stride_row, size_t dst_stride_col, + float clamp_min, float clamp_max) { Fn(m, n, k, lhs, rhs, static_cast(dst), dst_stride_row, dst_stride_col, clamp_min, clamp_max); } @@ -152,8 +154,8 @@ static inline void rhs_pack_fn12(size_t num_groups, size_t n, size_t k, size_t n template static inline void rhs_pack_scale_fn12(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t /*bl*/, - size_t /*rhs_stride*/, const void* rhs, const void* bias, const void* scale, - void* rhs_packed, size_t extra_bytes, const void* params) { + size_t /*rhs_stride*/, const void* rhs, const void* bias, const void* scale, + void* rhs_packed, size_t extra_bytes, const void* params) { Fn(num_groups, n, k, nr, kr, sr, static_cast(rhs), static_cast(bias), @@ -524,6 +526,61 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = { }, #endif #else +#if defined(__ARM_FEATURE_SVE) + { + /* SVE i8mm GEMM */ + /* .kern_info = */ { + /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p8x8_16x8_sve_i8mm, + /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p8x8_16x8_sve_i8mm, + /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p8x8_16x8_sve_i8mm, + /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p8x8_16x8_sve_i8mm, + /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p8x8_16x8_sve_i8mm, + /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p8x8_16x8_sve_i8mm, + /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p8x8_16x8_sve_i8mm, + /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p8x8_16x8_sve_i8mm, + /* .get_lhs_offset_ex = */ &kernel_offs_fn3, + /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3, + /* .run_kernel_ex = */ &kernel_run_fn11, + }, + /* .gemm_lhs_info = */ { + /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p4x8sb_f32_neon, + /* .get_packed_offset_ex = */ &lhs_offs_fn6, + /* .packed_size_ex = */ &lhs_ps_fn6, + /* .pack_func_ex = */ &lhs_pack_float_fn10, + }, + /* SVE dotprod GEMV */ + /* .kern_info = */ { + /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p8x8_1x8_sve_dotprod, + /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p8x8_1x8_sve_dotprod, + /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p8x8_1x8_sve_dotprod, + /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p8x8_1x8_sve_dotprod, + /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p8x8_1x8_sve_dotprod, + /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p8x8_1x8_sve_dotprod, + /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p8x8_1x8_sve_dotprod, + /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p8x8_1x8_sve_dotprod, + /* .get_lhs_offset_ex = */ &kernel_offs_fn3, + /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3, + /* .run_kernel_ex = */ &kernel_run_fn11, + }, + /* .gemv_lhs_info = */ { + /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32, + /* 
+            /* .get_packed_offset_ex = */ &lhs_offs_fn6,
+            /* .packed_size_ex = */ &lhs_ps_fn6,
+            /* .pack_func_ex = */ &lhs_pack_float_fn10,
+        },
+        /* .rhs_info = */ {
+            /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+            /* .to_float = */ dequantize_row_qsi4c32pscalef16,
+            /* .packed_size_ex = */ &rhs_ps_fn5,
+            /* .packed_stride_ex = */ &rhs_stride_fn4,
+            /* .pack_func_ex = */ &rhs_pack_fn12,
+        },
+        /* .required_cpu = */ CPU_FEATURE_SVE | CPU_FEATURE_I8MM | CPU_FEATURE_DOTPROD,
+        /* .lhs_type = */ GGML_TYPE_F32,
+        /* .rhs_type = */ GGML_TYPE_Q4_0,
+        /* .op_type = */ GGML_TYPE_F32,
+    },
+#endif
 #if defined(__ARM_FEATURE_MATMUL_INT8)
     {
         /* i8mm GEMM */
@@ -578,7 +635,7 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
         /* .rhs_type = */ GGML_TYPE_Q4_0,
         /* .op_type = */ GGML_TYPE_F32,
     },
-#endif
+#endif // __ARM_FEATURE_MATMUL_INT8
 #if defined(__ARM_FEATURE_DOTPROD)
     {
         /* DOTPROD GEMM */
@@ -811,26 +868,27 @@ ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, c
     ggml_kleidiai_kernels * kernel = nullptr;
 
     if (tensor->op == GGML_OP_MUL_MAT && tensor->src[0] != nullptr && tensor->src[1] != nullptr) {
-#if defined(__ARM_FEATURE_SME) || defined(__ARM_FEATURE_DOTPROD) || defined(__ARM_FEATURE_MATMUL_INT8)
-        for (size_t i = 0; i < NELEMS(gemm_gemv_kernels) - 1; ++i) {
-            if ((cpu_features & gemm_gemv_kernels[i].required_cpu) == gemm_gemv_kernels[i].required_cpu &&
-                gemm_gemv_kernels[i].lhs_type == tensor->src[1]->type &&
-                gemm_gemv_kernels[i].rhs_type == tensor->src[0]->type &&
-                gemm_gemv_kernels[i].op_type == tensor->type) {
-                kernel = &gemm_gemv_kernels[i];
-                break;
-            }
-        }
-        if (!kernel) {
-            for (size_t i = 0; i < NELEMS(gemm_gemv_kernels_q8) - 1; ++i) {
-                if ((cpu_features & gemm_gemv_kernels_q8[i].required_cpu) == gemm_gemv_kernels_q8[i].required_cpu &&
-                    gemm_gemv_kernels_q8[i].lhs_type == tensor->src[1]->type &&
-                    gemm_gemv_kernels_q8[i].rhs_type == tensor->src[0]->type &&
-                    gemm_gemv_kernels_q8[i].op_type == tensor->type) {
-                    kernel = &gemm_gemv_kernels_q8[i];
-                    break;
+#if defined(__ARM_FEATURE_SME) || \
+    defined(__ARM_FEATURE_DOTPROD) || \
+    defined(__ARM_FEATURE_MATMUL_INT8) || \
+    defined(__ARM_FEATURE_SVE)
+        auto try_table = [&](auto & table) {
+            for (size_t i = 0; i < NELEMS(table) - 1; ++i) {
+                if ((cpu_features & table[i].required_cpu) == table[i].required_cpu &&
+                    table[i].lhs_type == tensor->src[1]->type &&
+                    table[i].rhs_type == tensor->src[0]->type &&
+                    table[i].op_type == tensor->type) {
+                    kernel = &table[i];
+                    return true;
                 }
             }
+            return false;
+        };
+
+        if (tensor->src[0]->type == GGML_TYPE_Q8_0) {
+            try_table(gemm_gemv_kernels_q8);
+        } else {
+            try_table(gemm_gemv_kernels);
         }
 #else
     GGML_UNUSED(gemm_gemv_kernels);
@@ -845,7 +903,10 @@ ggml_kleidiai_kernels * ggml_kleidiai_select_kernels_q4_0(cpu_feature features) {
     ggml_kleidiai_kernels * kernels = nullptr;
 
-#if defined(__ARM_FEATURE_SME) || defined(__ARM_FEATURE_DOTPROD) || defined(__ARM_FEATURE_MATMUL_INT8)
+#if defined(__ARM_FEATURE_SME) || \
+    defined(__ARM_FEATURE_DOTPROD) || \
+    defined(__ARM_FEATURE_MATMUL_INT8) || \
+    defined(__ARM_FEATURE_SVE)
     for (size_t i = 0; i < NELEMS(gemm_gemv_kernels) - 1; ++i) {
         if ((features & gemm_gemv_kernels[i].required_cpu) == gemm_gemv_kernels[i].required_cpu) {
             kernels = &gemm_gemv_kernels[i];
diff --git a/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp b/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp
index 6f2a90fbda..ad23e73184 100644
--- a/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp
+++ b/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp
@@ -46,13 +46,20 @@ struct ggml_kleidiai_context {
 } static ctx = { CPU_FEATURE_NONE, NULL, NULL };
 
 static const char* cpu_feature_to_string(cpu_feature f) {
-    switch (f) {
-    case CPU_FEATURE_NONE: return "NONE";
-    case CPU_FEATURE_DOTPROD: return "DOTPROD";
-    case CPU_FEATURE_I8MM: return "I8MM";
-    case CPU_FEATURE_SVE: return "SVE";
-    case CPU_FEATURE_SME: return "SME";
-    default: return "UNKNOWN";
+    if (f == CPU_FEATURE_NONE) {
+        return "NONE";
+    } else if ((f & CPU_FEATURE_SME) == CPU_FEATURE_SME) {
+        return "SME";
+    } else if ((f & CPU_FEATURE_SVE) == CPU_FEATURE_SVE) {
+        return "SVE";
+    }
+    else if ((f & CPU_FEATURE_I8MM) == CPU_FEATURE_I8MM) {
+        return "I8MM";
+    } else if ((f & CPU_FEATURE_DOTPROD) == CPU_FEATURE_DOTPROD) {
+        return "DOTPROD";
+    }
+    else {
+        return "UNKNOWN";
     }
 }
 
@@ -68,7 +75,7 @@ static void init_kleidiai_context(void) {
 
         ctx.features = (ggml_cpu_has_dotprod() ? CPU_FEATURE_DOTPROD : CPU_FEATURE_NONE) |
                        (ggml_cpu_has_matmul_int8() ? CPU_FEATURE_I8MM : CPU_FEATURE_NONE) |
-                       (ggml_cpu_has_sve() ? CPU_FEATURE_SVE : CPU_FEATURE_NONE);
+                       ((ggml_cpu_has_sve() && ggml_cpu_get_sve_cnt() == QK8_0) ? CPU_FEATURE_SVE : CPU_FEATURE_NONE);
 
         if (env_var) {
             sme_enabled = atoi(env_var);

From f14f4e421b2177fadcf9d15ebccb0492e5464d86 Mon Sep 17 00:00:00 2001
From: Jeff Bolz
Date: Tue, 30 Dec 2025 06:11:13 -0600
Subject: [PATCH 4/5] server: fix files built redundantly (#18474)

---
 tools/server/CMakeLists.txt | 8 --------
 1 file changed, 8 deletions(-)

diff --git a/tools/server/CMakeLists.txt b/tools/server/CMakeLists.txt
index ae1a497be6..a39b4c5b35 100644
--- a/tools/server/CMakeLists.txt
+++ b/tools/server/CMakeLists.txt
@@ -38,14 +38,6 @@ set(TARGET_SRCS
     server-http.h
     server-models.cpp
     server-models.h
-    server-task.cpp
-    server-task.h
-    server-queue.cpp
-    server-queue.h
-    server-common.cpp
-    server-common.h
-    server-context.cpp
-    server-context.h
 )
 set(PUBLIC_ASSETS
     index.html.gz
From c32fa21db8a631e9127e55f69a3d2bdaa9f71824 Mon Sep 17 00:00:00 2001
From: Jay Zenith <162098309+JayZenith@users.noreply.github.com>
Date: Tue, 30 Dec 2025 06:27:49 -0800
Subject: [PATCH 5/5] sampling: reuse token data buffer in llama_sampler_sample (#18365)

* sampling: reuse token data buffer in llama_sampler_sample

* move cur buffer before timing section, after samplers

* minor : fix build

---------

Co-authored-by: Georgi Gerganov
---
 src/llama-sampling.cpp | 77 ++++++++++++++++++++++++------------------
 src/llama-sampling.h   |  3 ++
 2 files changed, 47 insertions(+), 33 deletions(-)

diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp
index d96f619ae1..f3891453e4 100644
--- a/src/llama-sampling.cpp
+++ b/src/llama-sampling.cpp
@@ -421,39 +421,6 @@ void llama_sampler_free(struct llama_sampler * smpl) {
     delete smpl;
 }
 
-llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_context * ctx, int32_t idx) {
-    const auto * logits = llama_get_logits_ith(ctx, idx);
-
-    const llama_model * model = llama_get_model(ctx);
-    const llama_vocab * vocab = llama_model_get_vocab(model);
-
-    const int n_vocab = llama_vocab_n_tokens(vocab);
-
-    // TODO: do not allocate each time
-    std::vector<llama_token_data> cur;
-    cur.reserve(n_vocab);
-    for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
-        cur.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
-    }
-
-    llama_token_data_array cur_p = {
-        /* .data = */ cur.data(),
-        /* .size = */ cur.size(),
-        /* .selected = */ -1,
-        /* .sorted = */ false,
-    };
-
-    llama_sampler_apply(smpl, &cur_p);
-
-    GGML_ASSERT(cur_p.selected >= 0 && cur_p.selected < (int32_t) cur_p.size);
-
-    auto token = cur_p.data[cur_p.selected].id;
-
-    llama_sampler_accept(smpl, token);
-
-    return token;
-}
-
 // sampler chain
 
 static const char * llama_sampler_chain_name(const struct llama_sampler * /*smpl*/) {
@@ -527,12 +494,56 @@ struct llama_sampler * llama_sampler_chain_init(struct llama_sampler_chain_param
         /* .ctx = */ new llama_sampler_chain {
             /* .params = */ params,
             /* .samplers = */ {},
+            /* .cur = */ {},
             /* .t_sample_us = */ 0,
             /* .n_sample = */ 0,
         }
     );
 }
 
+llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_context * ctx, int32_t idx) {
+    const auto * logits = llama_get_logits_ith(ctx, idx);
+
+    const llama_model * model = llama_get_model(ctx);
+    const llama_vocab * vocab = llama_model_get_vocab(model);
+
+    const int n_vocab = llama_vocab_n_tokens(vocab);
+
+    // use pre-allocated buffer from chain if available, otherwise allocate locally
+    std::vector<llama_token_data> * cur_ptr;
+    std::vector<llama_token_data> cur_local;
+
+    if (smpl->iface == &llama_sampler_chain_i) {
+        auto * chain = (llama_sampler_chain *) smpl->ctx;
+        cur_ptr = &chain->cur;
+    } else {
+        cur_ptr = &cur_local;
+    }
+
+    auto & cur = *cur_ptr;
+    cur.resize(n_vocab);
+    for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
+        cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
+    }
+
+    llama_token_data_array cur_p = {
+        /* .data = */ cur.data(),
+        /* .size = */ cur.size(),
+        /* .selected = */ -1,
+        /* .sorted = */ false,
+    };
+
+    llama_sampler_apply(smpl, &cur_p);
+
+    GGML_ASSERT(cur_p.selected >= 0 && cur_p.selected < (int32_t) cur_p.size);
+
+    auto token = cur_p.data[cur_p.selected].id;
+
+    llama_sampler_accept(smpl, token);
+
+    return token;
+}
+
 void llama_sampler_chain_add(struct llama_sampler * chain, struct llama_sampler * smpl) {
     auto * p = (llama_sampler_chain *) chain->ctx;
     p->samplers.push_back(smpl);
diff --git a/src/llama-sampling.h b/src/llama-sampling.h
index 759dd7dcb7..1e3de4e2ec 100644
--- a/src/llama-sampling.h
+++ b/src/llama-sampling.h
@@ -16,6 +16,9 @@ struct llama_sampler_chain {
 
     std::vector<llama_sampler *> samplers;
 
+    // pre-allocated buffer for llama_sampler_sample to avoid repeated allocations
+    std::vector<llama_token_data> cur;
+
     // timing
 
     mutable int64_t t_sample_us;
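
To see how the reused buffer is exercised, here is a minimal sketch of a decode-and-sample loop driving the chain API from this patch (llama_sampler_chain_init, llama_sampler_chain_add, llama_sampler_sample, llama_sampler_free). The context/vocab setup and the llama_sampler_init_greedy, llama_batch_get_one, llama_decode and llama_vocab_is_eog calls are assumptions about the surrounding llama.h API, not part of the change:

    // Sketch: sample n_predict tokens with a sampler chain; after this patch
    // the chain's internal `cur` buffer is reused across llama_sampler_sample
    // calls instead of being reallocated for every token.
    llama_sampler * chain = llama_sampler_chain_init(llama_sampler_chain_default_params());
    llama_sampler_chain_add(chain, llama_sampler_init_greedy());

    llama_token tok = last_prompt_token;               // assumed: last token of the prompt
    for (int i = 0; i < n_predict; i++) {
        llama_batch batch = llama_batch_get_one(&tok, 1);
        if (llama_decode(ctx, batch) != 0) {           // ctx assumed to be set up already
            break;                                     // decode failed
        }
        tok = llama_sampler_sample(chain, ctx, -1);    // -1: logits of the last token
        if (llama_vocab_is_eog(vocab, tok)) {          // vocab assumed to be set up already
            break;                                     // end of generation
        }
    }
    llama_sampler_free(chain);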