diff --git a/examples/model-conversion/scripts/causal/compare-embeddings-logits.sh b/examples/model-conversion/scripts/causal/compare-embeddings-logits.sh
index c53c89d48a..2ae4dc7061 100755
--- a/examples/model-conversion/scripts/causal/compare-embeddings-logits.sh
+++ b/examples/model-conversion/scripts/causal/compare-embeddings-logits.sh
@@ -5,8 +5,11 @@ set -e

 MODEL_PATH="${1:-"$MODEL_PATH"}"
 MODEL_NAME="${2:-$(basename "$MODEL_PATH")}"
+CONVERTED_MODEL_PATH="${1:-"$CONVERTED_MODEL"}"
+CONVERTED_MODEL_NAME="${2:-$(basename "$CONVERTED_MODEL_PATH" ".gguf")}"
+
 if [ -t 0 ]; then
-    CPP_EMBEDDINGS="data/llamacpp-${MODEL_NAME}-embeddings.bin"
+    CPP_EMBEDDINGS="data/llamacpp-${CONVERTED_MODEL_NAME}-embeddings.bin"
 else
     # Process piped JSON data and convert to binary (matching logits.cpp format)
     TEMP_FILE=$(mktemp /tmp/tmp.XXXXXX.binn)
diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt
index 25f25c4236..6192a87046 100644
--- a/ggml/src/CMakeLists.txt
+++ b/ggml/src/CMakeLists.txt
@@ -401,8 +401,8 @@ if (GGML_CPU_ALL_VARIANTS)
             ggml_add_cpu_backend_variant(android_armv8.2_2 DOTPROD FP16_VECTOR_ARITHMETIC)
             ggml_add_cpu_backend_variant(android_armv8.6_1 DOTPROD FP16_VECTOR_ARITHMETIC MATMUL_INT8)
             ggml_add_cpu_backend_variant(android_armv9.0_1 DOTPROD MATMUL_INT8 FP16_VECTOR_ARITHMETIC SVE2)
-            ggml_add_cpu_backend_variant(android_armv9.2_1 DOTPROD MATMUL_INT8 FP16_VECTOR_ARITHMETIC SME)
-            ggml_add_cpu_backend_variant(android_armv9.2_2 DOTPROD MATMUL_INT8 FP16_VECTOR_ARITHMETIC SVE SME)
+            ggml_add_cpu_backend_variant(android_armv9.2_1 DOTPROD MATMUL_INT8 FP16_VECTOR_ARITHMETIC SVE SME)
+            ggml_add_cpu_backend_variant(android_armv9.2_2 DOTPROD MATMUL_INT8 FP16_VECTOR_ARITHMETIC SVE SVE2 SME)
         elseif (APPLE)
             ggml_add_cpu_backend_variant(apple_m1 DOTPROD)
             ggml_add_cpu_backend_variant(apple_m2_m3 DOTPROD MATMUL_INT8)
diff --git a/ggml/src/ggml-cpu/CMakeLists.txt b/ggml/src/ggml-cpu/CMakeLists.txt
index 28fb7612e5..7622d0bf49 100644
--- a/ggml/src/ggml-cpu/CMakeLists.txt
+++ b/ggml/src/ggml-cpu/CMakeLists.txt
@@ -561,9 +561,9 @@ function(ggml_add_cpu_backend_variant_impl tag_name)

         # Fetch KleidiAI sources:
         include(FetchContent)
-        set(KLEIDIAI_COMMIT_TAG "v1.14.0")
+        set(KLEIDIAI_COMMIT_TAG "v1.16.0")
         set(KLEIDIAI_DOWNLOAD_URL "https://github.com/ARM-software/kleidiai/archive/refs/tags/${KLEIDIAI_COMMIT_TAG}.tar.gz")
-        set(KLEIDIAI_ARCHIVE_MD5  "45e110675d93f99f82c23a1afcca76bc")
+        set(KLEIDIAI_ARCHIVE_MD5  "0a9e9008adb6031f9e8cf70dff4a3321")

         if (POLICY CMP0135)
             cmake_policy(SET CMP0135 NEW)
@@ -615,6 +615,7 @@ function(ggml_add_cpu_backend_variant_impl tag_name)

             string(FIND "${ARCH_FLAGS_TEMP}" "+dotprod" DOTPROD_ENABLED)
             string(FIND "${ARCH_FLAGS_TEMP}" "+i8mm" I8MM_ENABLED)
             string(FIND "${ARCH_FLAGS_TEMP}" "+sme" SME_ENABLED)
+            string(FIND "${ARCH_FLAGS_TEMP}" "+sve" SVE_ENABLED)

             set(PRIVATE_ARCH_FLAGS ${ARCH_FLAGS_TEMP})
@@ -659,6 +660,15 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
                set(PRIVATE_ARCH_FLAGS "-fno-tree-vectorize;${PRIVATE_ARCH_FLAGS}+sve+sve2")
            endif()

+            if (NOT SVE_ENABLED MATCHES -1)
+                list(APPEND GGML_KLEIDIAI_SOURCES
+                    ${KLEIDIAI_SRC}/kai/kai_common_sve_asm.S
+                    ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p8x8_1x8_sve_dotprod_asm.S
+                    ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p8x8_1x8_sve_dotprod.c
+                    ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p8x8_16x8_sve_i8mm_asm.S
+                    ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p8x8_16x8_sve_i8mm.c)
+            endif()
+
            set_source_files_properties(${GGML_KLEIDIAI_SOURCES} PROPERTIES COMPILE_OPTIONS "${PRIVATE_ARCH_FLAGS}")
            list(APPEND GGML_CPU_SOURCES ${GGML_KLEIDIAI_SOURCES})
        endif()
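
Note on the kernels.cpp diff that follows: the dispatch tables reach every KleidiAI entry point through small template adapters (kernel_offs_fn3, kernel_run_fn11, lhs_offs_fn6, rhs_pack_fn12, and friends) that wrap C functions of varying signatures into the fixed shapes the table fields store. Below is a minimal sketch of that pattern, assuming the adapters take the wrapped function as a non-type template parameter, as the fnN naming suggests; the kai_* symbol here is a stand-in, not a real KleidiAI function.

    #include <cstddef>
    #include <cstdio>

    // Stand-in for a KleidiAI getter such as
    // kai_get_lhs_packed_offset_matmul_..._sve_i8mm (illustrative only).
    static size_t kai_demo_get_offset(size_t m_idx, size_t k, size_t bl) {
        return (m_idx / bl) * k; // placeholder arithmetic
    }

    // Adapter: any 3-argument offset getter becomes a function with the one
    // signature the kernel table stores. Each instantiation is a real function,
    // so &kernel_offs_fn3<kai_demo_get_offset> is an ordinary function pointer.
    template <auto Fn>
    static size_t kernel_offs_fn3(size_t idx, size_t k, size_t bl) {
        return Fn(idx, k, bl);
    }

    // The table field has one fixed function-pointer type.
    struct kern_info {
        size_t (*get_lhs_offset_ex)(size_t, size_t, size_t);
    };

    static const kern_info demo_entry = {
        /* .get_lhs_offset_ex = */ &kernel_offs_fn3<kai_demo_get_offset>,
    };

    int main() {
        // 64 rows into a K=128, block-32 layout (numbers are arbitrary).
        std::printf("offset = %zu\n", demo_entry.get_lhs_offset_ex(64, 128, 32));
    }
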
diff --git a/ggml/src/ggml-cpu/kleidiai/kernels.cpp b/ggml/src/ggml-cpu/kleidiai/kernels.cpp
index 55a00f008a..d114f2d49b 100644
--- a/ggml/src/ggml-cpu/kleidiai/kernels.cpp
+++ b/ggml/src/ggml-cpu/kleidiai/kernels.cpp
@@ -18,6 +18,8 @@
 #include "kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod.h"
 #include "kai_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod.h"
 #include "kai_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm.h"
+#include "kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p8x8_16x8_sve_i8mm.h"
+#include "kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p8x8_1x8_sve_dotprod.h"

 #include "kai_lhs_pack_bf16p2vlx2_f32_sme.h"
 #include "kai_lhs_quant_pack_qsi8d32p_f32.h"
@@ -69,9 +71,9 @@ static inline void kernel_run_fn10(size_t m, size_t n, size_t k, size_t /*bl*/,

 template <auto Fn>
 static inline void kernel_run_float_fn10(size_t m, size_t n, size_t k, size_t /*bl*/,
-                              const void* lhs, const void* rhs, void* dst,
-                              size_t dst_stride_row, size_t dst_stride_col,
-                              float clamp_min, float clamp_max) {
+                                         const void* lhs, const void* rhs, void* dst,
+                                         size_t dst_stride_row, size_t dst_stride_col,
+                                         float clamp_min, float clamp_max) {
     Fn(m, n, k, lhs, rhs, static_cast<float *>(dst), dst_stride_row, dst_stride_col, clamp_min, clamp_max);
 }
@@ -152,8 +154,8 @@ static inline void rhs_pack_fn12(size_t num_groups, size_t n, size_t k, size_t n

 template <auto Fn>
 static inline void rhs_pack_scale_fn12(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t /*bl*/,
-                            size_t /*rhs_stride*/, const void* rhs, const void* bias, const void* scale,
-                            void* rhs_packed, size_t extra_bytes, const void* params) {
+                                       size_t /*rhs_stride*/, const void* rhs, const void* bias, const void* scale,
+                                       void* rhs_packed, size_t extra_bytes, const void* params) {
     Fn(num_groups, n, k, nr, kr, sr,
        static_cast<const int8_t *>(rhs),
        static_cast<const float *>(bias),
@@ -524,6 +526,61 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
     },
 #endif
 #else
+#if defined(__ARM_FEATURE_SVE)
+    {
+        /* SVE i8mm GEMM */
+        /* .kern_info = */ {
+            /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p8x8_16x8_sve_i8mm,
+            /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p8x8_16x8_sve_i8mm,
+            /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p8x8_16x8_sve_i8mm,
+            /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p8x8_16x8_sve_i8mm,
+            /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p8x8_16x8_sve_i8mm,
+            /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p8x8_16x8_sve_i8mm,
+            /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p8x8_16x8_sve_i8mm,
+            /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p8x8_16x8_sve_i8mm,
+            /* .get_lhs_offset_ex = */ &kernel_offs_fn3<kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p8x8_16x8_sve_i8mm>,
+            /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3<kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p8x8_16x8_sve_i8mm>,
+            /* .run_kernel_ex = */ &kernel_run_fn11<kai_run_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p8x8_16x8_sve_i8mm>,
+        },
+        /* .gemm_lhs_info = */ {
+            /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p4x8sb_f32_neon,
+            /* .get_packed_offset_ex = */ &lhs_offs_fn6<kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p4x8sb_f32_neon>,
+            /* .packed_size_ex = */ &lhs_ps_fn6<kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p4x8sb_f32_neon>,
+            /* .pack_func_ex = */ &lhs_pack_float_fn10<kai_run_lhs_quant_pack_qsi8d32p4x8sb_f32_neon>,
+        },
+        /* SVE dotprod GEMV */
+        /* .kern_info = */ {
+            /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p8x8_1x8_sve_dotprod,
+            /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p8x8_1x8_sve_dotprod,
+            /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p8x8_1x8_sve_dotprod,
+            /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p8x8_1x8_sve_dotprod,
+            /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p8x8_1x8_sve_dotprod,
+            /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p8x8_1x8_sve_dotprod,
+            /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p8x8_1x8_sve_dotprod,
+            /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p8x8_1x8_sve_dotprod,
+            /* .get_lhs_offset_ex = */ &kernel_offs_fn3<kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p8x8_1x8_sve_dotprod>,
+            /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3<kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p8x8_1x8_sve_dotprod>,
+            /* .run_kernel_ex = */ &kernel_run_fn11<kai_run_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p8x8_1x8_sve_dotprod>,
+        },
+        /* .gemv_lhs_info = */ {
+            /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,
+            /* .get_packed_offset_ex = */ &lhs_offs_fn6<kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32>,
+            /* .packed_size_ex = */ &lhs_ps_fn6<kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32>,
+            /* .pack_func_ex = */ &lhs_pack_float_fn10<kai_run_lhs_quant_pack_qsi8d32p_f32>,
+        },
+        /* .rhs_info = */ {
+            /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
+            /* .to_float = */ dequantize_row_qsi4c32pscalef16,
+            /* .packed_size_ex = */ &rhs_ps_fn5<kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>,
+            /* .packed_stride_ex = */ &rhs_stride_fn4<kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>,
+            /* .pack_func_ex = */ &rhs_pack_fn12<kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>,
+        },
+        /* .required_cpu = */ CPU_FEATURE_SVE | CPU_FEATURE_I8MM | CPU_FEATURE_DOTPROD,
+        /* .lhs_type = */ GGML_TYPE_F32,
+        /* .rhs_type = */ GGML_TYPE_Q4_0,
+        /* .op_type = */ GGML_TYPE_F32,
+    },
+#endif
 #if defined(__ARM_FEATURE_MATMUL_INT8)
     {
         /* i8mm GEMM */
@@ -578,7 +635,7 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
         /* .rhs_type = */ GGML_TYPE_Q4_0,
         /* .op_type = */ GGML_TYPE_F32,
     },
-#endif
+#endif // __ARM_FEATURE_MATMUL_INT8
 #if defined(__ARM_FEATURE_DOTPROD)
     {
         /* DOTPROD GEMM */
@@ -811,26 +868,27 @@ ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, c
     ggml_kleidiai_kernels * kernel = nullptr;

     if (tensor->op == GGML_OP_MUL_MAT && tensor->src[0] != nullptr && tensor->src[1] != nullptr) {
-#if defined(__ARM_FEATURE_SME) || defined(__ARM_FEATURE_DOTPROD) || defined(__ARM_FEATURE_MATMUL_INT8)
-        for (size_t i = 0; i < NELEMS(gemm_gemv_kernels) - 1; ++i) {
-            if ((cpu_features & gemm_gemv_kernels[i].required_cpu) == gemm_gemv_kernels[i].required_cpu &&
-                gemm_gemv_kernels[i].lhs_type == tensor->src[1]->type &&
-                gemm_gemv_kernels[i].rhs_type == tensor->src[0]->type &&
-                gemm_gemv_kernels[i].op_type == tensor->type) {
-                kernel = &gemm_gemv_kernels[i];
-                break;
-            }
-        }
-        if (!kernel) {
-            for (size_t i = 0; i < NELEMS(gemm_gemv_kernels_q8) - 1; ++i) {
-                if ((cpu_features & gemm_gemv_kernels_q8[i].required_cpu) == gemm_gemv_kernels_q8[i].required_cpu &&
-                    gemm_gemv_kernels_q8[i].lhs_type == tensor->src[1]->type &&
-                    gemm_gemv_kernels_q8[i].rhs_type == tensor->src[0]->type &&
-                    gemm_gemv_kernels_q8[i].op_type == tensor->type) {
-                    kernel = &gemm_gemv_kernels_q8[i];
-                    break;
+#if defined(__ARM_FEATURE_SME)         || \
+    defined(__ARM_FEATURE_DOTPROD)     || \
+    defined(__ARM_FEATURE_MATMUL_INT8) || \
+    defined(__ARM_FEATURE_SVE)
+        auto try_table = [&](auto & table) {
+            for (size_t i = 0; i < NELEMS(table) - 1; ++i) {
+                if ((cpu_features & table[i].required_cpu) == table[i].required_cpu &&
+                    table[i].lhs_type == tensor->src[1]->type &&
+                    table[i].rhs_type == tensor->src[0]->type &&
+                    table[i].op_type == tensor->type) {
+                    kernel = &table[i];
+                    return true;
                 }
             }
+            return false;
+        };
+
+        if (tensor->src[0]->type == GGML_TYPE_Q8_0) {
+            try_table(gemm_gemv_kernels_q8);
+        } else {
+            try_table(gemm_gemv_kernels);
         }
 #else
         GGML_UNUSED(gemm_gemv_kernels);
@@ -845,7 +903,10 @@ ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, c
 ggml_kleidiai_kernels * ggml_kleidiai_select_kernels_q4_0(cpu_feature features) {
     ggml_kleidiai_kernels * kernels = nullptr;

-#if defined(__ARM_FEATURE_SME) || defined(__ARM_FEATURE_DOTPROD) || defined(__ARM_FEATURE_MATMUL_INT8)
+#if defined(__ARM_FEATURE_SME)         || \
+    defined(__ARM_FEATURE_DOTPROD)     || \
+    defined(__ARM_FEATURE_MATMUL_INT8) || \
+    defined(__ARM_FEATURE_SVE)
     for (size_t i = 0; i < NELEMS(gemm_gemv_kernels) - 1; ++i) {
         if ((features & gemm_gemv_kernels[i].required_cpu) == gemm_gemv_kernels[i].required_cpu) {
             kernels = &gemm_gemv_kernels[i];
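
Note on the selection logic above: each table ends in a sentinel entry that the loops skip (hence the NELEMS(...) - 1 bound), and an entry is accepted only when the runtime feature mask contains every bit of its required_cpu field, so placing the new SVE entry first makes the most capable kernel win. A self-contained sketch of that subset test follows; the flag values and table contents are illustrative, not the real ggml definitions.

    #include <cstdio>
    #include <cstdint>
    #include <cstddef>

    // Illustrative flag values; ggml defines its own cpu_feature enum.
    enum : uint32_t {
        FEAT_DOTPROD = 1u << 0,
        FEAT_I8MM    = 1u << 1,
        FEAT_SVE     = 1u << 2,
    };

    struct entry {
        uint32_t     required_cpu;
        const char * name;
    };

    // Most demanding entries first, sentinel last (mirrors the "- 1" in the loops).
    static const entry table[] = {
        { FEAT_SVE | FEAT_I8MM | FEAT_DOTPROD, "sve_i8mm"     },
        { FEAT_I8MM | FEAT_DOTPROD,            "neon_i8mm"    },
        { FEAT_DOTPROD,                        "neon_dotprod" },
        { 0,                                   "sentinel"     },
    };

    static const char * select_kernel(uint32_t features) {
        for (size_t i = 0; i < sizeof(table)/sizeof(table[0]) - 1; ++i) {
            // subset test: all required bits must be present in the runtime mask
            if ((features & table[i].required_cpu) == table[i].required_cpu) {
                return table[i].name;
            }
        }
        return "none";
    }

    int main() {
        std::printf("%s\n", select_kernel(FEAT_DOTPROD | FEAT_I8MM));            // neon_i8mm
        std::printf("%s\n", select_kernel(FEAT_DOTPROD | FEAT_I8MM | FEAT_SVE)); // sve_i8mm
    }
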
diff --git a/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp b/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp
index 6f2a90fbda..ad23e73184 100644
--- a/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp
+++ b/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp
@@ -46,13 +46,20 @@ struct ggml_kleidiai_context {
 } static ctx = { CPU_FEATURE_NONE, NULL, NULL };

 static const char* cpu_feature_to_string(cpu_feature f) {
-    switch (f) {
-        case CPU_FEATURE_NONE:    return "NONE";
-        case CPU_FEATURE_DOTPROD: return "DOTPROD";
-        case CPU_FEATURE_I8MM:    return "I8MM";
-        case CPU_FEATURE_SVE:     return "SVE";
-        case CPU_FEATURE_SME:     return "SME";
-        default:                  return "UNKNOWN";
+    if (f == CPU_FEATURE_NONE) {
+        return "NONE";
+    } else if ((f & CPU_FEATURE_SME) == CPU_FEATURE_SME) {
+        return "SME";
+    } else if ((f & CPU_FEATURE_SVE) == CPU_FEATURE_SVE) {
+        return "SVE";
+    }
+    else if ((f & CPU_FEATURE_I8MM) == CPU_FEATURE_I8MM) {
+        return "I8MM";
+    } else if ((f & CPU_FEATURE_DOTPROD) == CPU_FEATURE_DOTPROD) {
+        return "DOTPROD";
+    }
+    else {
+        return "UNKNOWN";
     }
 }

@@ -68,7 +75,7 @@ static void init_kleidiai_context(void) {

         ctx.features = (ggml_cpu_has_dotprod()     ? CPU_FEATURE_DOTPROD : CPU_FEATURE_NONE) |
                        (ggml_cpu_has_matmul_int8() ? CPU_FEATURE_I8MM    : CPU_FEATURE_NONE) |
-                       (ggml_cpu_has_sve()         ? CPU_FEATURE_SVE     : CPU_FEATURE_NONE);
+                       ((ggml_cpu_has_sve() && ggml_cpu_get_sve_cnt() == QK8_0) ? CPU_FEATURE_SVE : CPU_FEATURE_NONE);

         if (env_var) {
             sme_enabled = atoi(env_var);
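
Note on the cpu_feature_to_string rewrite above: ctx.features is an OR-ed combination of flags, and the old switch matched exact values only, so any combined mask (for example DOTPROD|SVE) printed as "UNKNOWN"; the masked if/else chain instead reports the most capable bit that is set. The init_kleidiai_context hunk additionally advertises SVE only when ggml_cpu_get_sve_cnt() equals QK8_0 (32), i.e. when the runtime SVE vector length is the 256 bits the new kernels expect. A small demonstration of the switch-vs-mask difference, with illustrative flag values:

    #include <cstdio>
    #include <cstdint>

    // Illustrative values; the real enum lives in the kleidiai backend.
    enum : uint32_t {
        F_NONE    = 0,
        F_DOTPROD = 1u << 0,
        F_I8MM    = 1u << 1,
        F_SVE     = 1u << 2,
        F_SME     = 1u << 3,
    };

    // Old behavior: exact matches only, so a combined mask is "UNKNOWN".
    static const char * to_string_switch(uint32_t f) {
        switch (f) {
            case F_NONE:    return "NONE";
            case F_DOTPROD: return "DOTPROD";
            case F_SVE:     return "SVE";
            default:        return "UNKNOWN";
        }
    }

    // New behavior: check bits from most to least capable.
    static const char * to_string_masked(uint32_t f) {
        if (f == F_NONE)                  return "NONE";
        if ((f & F_SME)     == F_SME)     return "SME";
        if ((f & F_SVE)     == F_SVE)     return "SVE";
        if ((f & F_I8MM)    == F_I8MM)    return "I8MM";
        if ((f & F_DOTPROD) == F_DOTPROD) return "DOTPROD";
        return "UNKNOWN";
    }

    int main() {
        const uint32_t f = F_DOTPROD | F_SVE; // the kind of mask init actually builds
        std::printf("switch: %s, masked: %s\n", to_string_switch(f), to_string_masked(f));
        // prints: switch: UNKNOWN, masked: SVE
    }
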
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index d3caf0ff75..931b021882 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -4799,6 +4799,16 @@ static ggml_backend_feature * ggml_backend_cuda_get_features(ggml_backend_reg_t
     features.push_back({ "FA_ALL_QUANTS", "1" });
 #endif

+    {
+        const auto & info = ggml_cuda_info();
+        for (int id = 0; id < info.device_count; ++id) {
+            if (blackwell_mma_available(info.devices[id].cc)) {
+                features.push_back({ "BLACKWELL_NATIVE_FP4", "1"});
+                break;
+            }
+        }
+    }
+
 #undef _STRINGIFY
 #undef STRINGIFY
diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp
index 650a6f026b..042e40de20 100644
--- a/src/llama-sampling.cpp
+++ b/src/llama-sampling.cpp
@@ -425,65 +425,6 @@ void llama_sampler_free(struct llama_sampler * smpl) {
     delete smpl;
 }

-llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_context * ctx, int32_t idx) {
-    const llama_token   sampled_token  = llama_get_sampled_token_ith     (ctx, idx);
-    const float       * sampled_probs  = llama_get_sampled_probs_ith     (ctx, idx);
-    const float       * sampled_logits = llama_get_sampled_logits_ith    (ctx, idx);
-    const llama_token * sampled_ids    = llama_get_sampled_candidates_ith(ctx, idx);
-
-    // If a backend sampler has already sampled a token, return it.
-    if (sampled_token != LLAMA_TOKEN_NULL) {
-        LLAMA_LOG_DEBUG("%s: Backend sampler selected token for idx %d. Skipping CPU samplers\n", __func__, idx);
-        return sampled_token;
-    }
-
-    const llama_model * model = llama_get_model(ctx);
-    const llama_vocab * vocab = llama_model_get_vocab(model);
-
-    const int n_vocab = llama_vocab_n_tokens(vocab);
-
-    // TODO: do not allocate each time
-    std::vector<llama_token_data> cur;
-
-    if (sampled_probs) {
-        const uint32_t sampled_probs_count = llama_get_sampled_probs_count_ith(ctx, idx);
-        cur.resize(sampled_probs_count);
-        for (uint32_t i = 0; i < sampled_probs_count; ++i) {
-            cur[i] = llama_token_data{sampled_ids[i], sampled_logits[i], sampled_probs[i]};
-        }
-    } else if (sampled_logits) {
-        const uint32_t sampled_logits_count = llama_get_sampled_logits_count_ith(ctx, idx);
-        cur.resize(sampled_logits_count);
-        for (llama_token i = 0; i < (int)sampled_logits_count; i++) {
-            cur[i] = llama_token_data{sampled_ids[i], sampled_logits[i], 0.0f};
-        }
-    } else {
-        const auto * logits = llama_get_logits_ith(ctx, idx);
-        GGML_ASSERT(logits != nullptr);
-        cur.resize(n_vocab);
-        for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
-            cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
-        }
-    }
-
-    llama_token_data_array cur_p = {
-        /* .data     = */ cur.data(),
-        /* .size     = */ cur.size(),
-        /* .selected = */ -1,
-        /* .sorted   = */ false,
-    };
-
-    llama_sampler_apply(smpl, &cur_p);
-
-    GGML_ASSERT(cur_p.selected >= 0 && cur_p.selected < (int32_t) cur_p.size);
-
-    auto token = cur_p.data[cur_p.selected].id;
-
-    llama_sampler_accept(smpl, token);
-
-    return token;
-}
-
 // empty sampler

 struct llama_sampler_empty {
@@ -855,12 +796,83 @@ struct llama_sampler * llama_sampler_chain_init(struct llama_sampler_chain_param
             /* .params      = */ params,
             /* .is_init     = */ false,
             /* .samplers    = */ {},
+            /* .cur         = */ {},
             /* .t_sample_us = */ 0,
             /* .n_sample    = */ 0,
         }
     );
 }

+llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_context * ctx, int32_t idx) {
+    const llama_token   sampled_token  = llama_get_sampled_token_ith     (ctx, idx);
+    const float       * sampled_probs  = llama_get_sampled_probs_ith     (ctx, idx);
+    const float       * sampled_logits = llama_get_sampled_logits_ith    (ctx, idx);
+    const llama_token * sampled_ids    = llama_get_sampled_candidates_ith(ctx, idx);
+
+    // If a backend sampler has already sampled a token, return it.
+    if (sampled_token != LLAMA_TOKEN_NULL) {
+        LLAMA_LOG_DEBUG("%s: Backend sampler selected token for idx %d. Skipping CPU samplers\n", __func__, idx);
+        return sampled_token;
+    }
+
+    const llama_model * model = llama_get_model(ctx);
+    const llama_vocab * vocab = llama_model_get_vocab(model);
+
+    const int n_vocab = llama_vocab_n_tokens(vocab);
+
+    // use pre-allocated buffer from chain if available, otherwise allocate locally
+    std::vector<llama_token_data> * cur_ptr;
+    std::vector<llama_token_data>   cur_local;
+
+    if (smpl->iface == &llama_sampler_chain_i) {
+        auto * chain = (llama_sampler_chain *) smpl->ctx;
+        cur_ptr = &chain->cur;
+    } else {
+        cur_ptr = &cur_local;
+    }
+
+    auto & cur = *cur_ptr;
+
+    if (sampled_probs) {
+        const uint32_t sampled_probs_count = llama_get_sampled_probs_count_ith(ctx, idx);
+        cur.resize(sampled_probs_count);
+        for (uint32_t i = 0; i < sampled_probs_count; ++i) {
+            cur[i] = llama_token_data{sampled_ids[i], sampled_logits[i], sampled_probs[i]};
+        }
+    } else if (sampled_logits) {
+        const uint32_t sampled_logits_count = llama_get_sampled_logits_count_ith(ctx, idx);
+        cur.resize(sampled_logits_count);
+        for (llama_token i = 0; i < (int)sampled_logits_count; i++) {
+            cur[i] = llama_token_data{sampled_ids[i], sampled_logits[i], 0.0f};
+        }
+    } else {
+        const auto * logits = llama_get_logits_ith(ctx, idx);
+        GGML_ASSERT(logits != nullptr);
+        cur.resize(n_vocab);
+        for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
+            cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
+        }
+    }
+
+    llama_token_data_array cur_p = {
+        /* .data     = */ cur.data(),
+        /* .size     = */ cur.size(),
+        /* .selected = */ -1,
+        /* .sorted   = */ false,
+    };
+
+    llama_sampler_apply(smpl, &cur_p);
+
+    GGML_ASSERT(cur_p.selected >= 0 && cur_p.selected < (int32_t) cur_p.size);
+
+    auto token = cur_p.data[cur_p.selected].id;
+
+    llama_sampler_accept(smpl, token);
+
+    return token;
+}
+
+
 void llama_sampler_chain_add(struct llama_sampler * chain, struct llama_sampler * smpl) {
     auto * p = (llama_sampler_chain *) chain->ctx;
     p->samplers.push_back({
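
Note on the llama_sampler_sample changes above: the function moved below llama_sampler_chain_init so the chain type is visible, and the per-call std::vector flagged by the old "TODO: do not allocate each time" is replaced with a scratch buffer owned by the chain; resize() on a reused vector touches the heap only when capacity must grow. A minimal sketch of the reuse pattern, with simplified stand-in types:

    #include <cstdio>
    #include <vector>

    struct token_data { int id; float logit; float p; };

    // Stands in for llama_sampler_chain: owns a candidate buffer that outlives
    // individual sampling calls, like the new llama_sampler_chain::cur.
    struct chain {
        std::vector<token_data> cur;
    };

    // Fill the scratch buffer for one sampling step. After the first call,
    // resize() reuses the existing allocation as long as n_vocab does not grow.
    static void build_candidates(chain & c, int n_vocab) {
        c.cur.resize(n_vocab);
        for (int id = 0; id < n_vocab; ++id) {
            c.cur[id] = token_data{id, 0.0f, 0.0f};
        }
    }

    int main() {
        chain c;
        build_candidates(c, 32000);
        const void * before = c.cur.data();
        build_candidates(c, 32000); // steady state: no new allocation
        std::printf("buffer reused: %s\n", c.cur.data() == before ? "yes" : "no");
    }
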
diff --git a/src/llama-sampling.h b/src/llama-sampling.h
index 18cae29ece..6a963c0bb7 100644
--- a/src/llama-sampling.h
+++ b/src/llama-sampling.h
@@ -25,6 +25,9 @@ struct llama_sampler_chain {

     std::vector samplers;

+    // pre-allocated buffer for llama_sampler_sample to avoid repeated allocations
+    std::vector<llama_token_data> cur;
+
     // timing
     mutable int64_t t_sample_us;
diff --git a/tools/server/CMakeLists.txt b/tools/server/CMakeLists.txt
index ae1a497be6..a39b4c5b35 100644
--- a/tools/server/CMakeLists.txt
+++ b/tools/server/CMakeLists.txt
@@ -38,14 +38,6 @@ set(TARGET_SRCS
     server-http.h
     server-models.cpp
     server-models.h
-    server-task.cpp
-    server-task.h
-    server-queue.cpp
-    server-queue.h
-    server-common.cpp
-    server-common.h
-    server-context.cpp
-    server-context.h
 )
 set(PUBLIC_ASSETS
     index.html.gz