diff --git a/common/arg.cpp b/common/arg.cpp index f2aec895ba..32e875972b 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -1671,6 +1671,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex params.sampling.grammar = json_schema_to_grammar(json::parse(schema)); } ).set_sparam()); + add_opt(common_arg( + {"--backend-sampling"}, + "enable backend sampling (default: disabled)", + [](common_params & params) { + params.sampling.backend_sampling = true; + } + ).set_sparam().set_env("LLAMA_ARG_BACKEND_SAMPLING")); add_opt(common_arg( {"--pooling"}, "{none,mean,cls,last,rank}", "pooling type for embeddings, use model default if unspecified", diff --git a/common/common.cpp b/common/common.cpp index 5a8cf52485..8e4490ba71 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1084,6 +1084,7 @@ struct common_init_result::impl { std::vector lora; std::vector samplers; + std::vector samplers_seq_config; }; common_init_result::common_init_result(common_params & params) : @@ -1141,10 +1142,19 @@ common_init_result::common_init_result(common_params & params) : // params.sampling.dry_penalty_last_n = llama_n_ctx(lctx); //} + // init the backend samplers as part of the context creation pimpl->samplers.resize(cparams.n_seq_max); + pimpl->samplers_seq_config.resize(cparams.n_seq_max); for (int i = 0; i < (int) cparams.n_seq_max; ++i) { pimpl->samplers[i].reset(common_sampler_init(model, params.sampling)); + pimpl->samplers_seq_config[i] = { i, common_sampler_get(pimpl->samplers[i].get()) }; + } + + // TODO: temporarily gated behind a flag + if (params.sampling.backend_sampling) { + cparams.samplers = pimpl->samplers_seq_config.data(); + cparams.n_samplers = pimpl->samplers_seq_config.size(); } llama_context * lctx = llama_init_from_model(model, cparams); diff --git a/common/common.h b/common/common.h index d70744840f..2896dc388f 100644 --- a/common/common.h +++ b/common/common.h @@ -216,6 +216,8 @@ struct common_params_sampling { std::vector logit_bias; // logit biases to apply std::vector logit_bias_eog; // pre-calculated logit biases for EOG tokens + bool backend_sampling = false; + bool has_logit_bias() const { return !logit_bias.empty(); } diff --git a/common/llguidance.cpp b/common/llguidance.cpp index adce620e4d..d58f147a76 100644 --- a/common/llguidance.cpp +++ b/common/llguidance.cpp @@ -106,12 +106,16 @@ static void llama_sampler_llg_free(llama_sampler * smpl) { } static llama_sampler_i llama_sampler_llg_i = { - /* .name = */ llama_sampler_llg_name, - /* .accept = */ llama_sampler_llg_accept_impl, - /* .apply = */ llama_sampler_llg_apply, - /* .reset = */ llama_sampler_llg_reset, - /* .clone = */ llama_sampler_llg_clone, - /* .free = */ llama_sampler_llg_free, + /* .name = */ llama_sampler_llg_name, + /* .accept = */ llama_sampler_llg_accept_impl, + /* .apply = */ llama_sampler_llg_apply, + /* .reset = */ llama_sampler_llg_reset, + /* .clone = */ llama_sampler_llg_clone, + /* .free = */ llama_sampler_llg_free, + /* .backend_init = */ NULL, + /* .backend_accept = */ NULL, + /* .backend_apply = */ NULL, + /* .backend_set_input = */ NULL, }; static size_t llama_sampler_llg_tokenize_fn(const void * user_data, const uint8_t * bytes, size_t bytes_len, diff --git a/common/sampling.cpp b/common/sampling.cpp index 6935d84e22..aefc596443 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -121,17 +121,34 @@ struct common_sampler { } void set_logits(struct llama_context * ctx, int idx) { - const auto * logits = llama_get_logits_ith(ctx, idx); + const float * 
sampled_probs = llama_get_sampled_probs_ith (ctx, idx); + const float * sampled_logits = llama_get_sampled_logits_ith (ctx, idx); + const llama_token * sampled_ids = llama_get_sampled_candidates_ith(ctx, idx); const llama_model * model = llama_get_model(ctx); const llama_vocab * vocab = llama_model_get_vocab(model); const int n_vocab = llama_vocab_n_tokens(vocab); - cur.resize(n_vocab); - - for (llama_token token_id = 0; token_id < n_vocab; token_id++) { - cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f}; + if (sampled_probs) { + const uint32_t sampled_probs_count = llama_get_sampled_probs_count_ith(ctx, idx); + cur.resize(sampled_probs_count); + for (uint32_t i = 0; i < sampled_probs_count; ++i) { + cur[i] = llama_token_data{sampled_ids[i], sampled_logits[i], sampled_probs[i]}; + } + } else if (sampled_logits) { + const uint32_t sampled_logits_count = llama_get_sampled_logits_count_ith(ctx, idx); + cur.resize(sampled_logits_count); + for (uint32_t i = 0; i < sampled_logits_count; i++) { + cur[i] = llama_token_data{sampled_ids[i], sampled_logits[i], 0.0f}; + } + } else { + const auto * logits = llama_get_logits_ith(ctx, idx); + GGML_ASSERT(logits != nullptr); + cur.resize(n_vocab); + for (llama_token token_id = 0; token_id < n_vocab; token_id++) { + cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f}; + } } cur_p = { cur.data(), cur.size(), -1, false }; @@ -421,6 +438,23 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co auto & chain = gsmpl->chain; auto & cur_p = gsmpl->cur_p; // initialized by set_logits + // Check if a backend sampler has already sampled a token in which case we + // return that token id directly. + { + id = llama_get_sampled_token_ith(ctx, idx); + + if (id != LLAMA_TOKEN_NULL) { + LOG_DBG("%s: Backend sampler selected token: '%d'. Will not run any CPU samplers\n", __func__, id); + + // TODO: simplify + gsmpl->cur.resize(1); + gsmpl->cur[0] = { id, 0.0f, 1.0f }; + cur_p = { gsmpl->cur.data(), gsmpl->cur.size(), 0, true }; + + return id; + } + } + gsmpl->set_logits(ctx, idx); llama_sampler_apply(chain, &cur_p); diff --git a/examples/batched/batched.cpp b/examples/batched/batched.cpp index 36a12d299f..6b134b4f6f 100644 --- a/examples/batched/batched.cpp +++ b/examples/batched/batched.cpp @@ -68,7 +68,7 @@ int main(int argc, char ** argv) { auto sparams = llama_sampler_chain_default_params(); sparams.no_perf = false; - std::vector samplers; + std::vector sampler_configs; for (int32_t i = 0; i < n_parallel; ++i) { llama_sampler * smpl = llama_sampler_chain_init(sparams); @@ -78,7 +78,13 @@ int main(int argc, char ** argv) { llama_sampler_chain_add(smpl, llama_sampler_init_temp (params.sampling.temp)); llama_sampler_chain_add(smpl, llama_sampler_init_dist (params.sampling.seed)); - samplers.push_back(smpl); + sampler_configs.push_back({ i, smpl }); + } + + // TODO: temporarily gated behind a flag + if (params.sampling.backend_sampling) { + ctx_params.samplers = sampler_configs.data(); + ctx_params.n_samplers = sampler_configs.size(); } llama_context * ctx = llama_init_from_model(model, ctx_params); @@ -180,7 +186,7 @@ int main(int argc, char ** argv) { continue; } - const llama_token new_token_id = llama_sampler_sample(samplers[i], ctx, i_batch[i]); + const llama_token new_token_id = llama_sampler_sample(sampler_configs[i].sampler, ctx, i_batch[i]); // is it an end of generation? 
-> mark the stream as finished if (llama_vocab_is_eog(vocab, new_token_id) || n_cur == n_predict) { @@ -236,15 +242,15 @@ int main(int argc, char ** argv) { __func__, n_decode, (t_main_end - t_main_start) / 1000000.0f, n_decode / ((t_main_end - t_main_start) / 1000000.0f)); LOG("\n"); - llama_perf_sampler_print(samplers[0]); + llama_perf_sampler_print(sampler_configs[0].sampler); llama_perf_context_print(ctx); fprintf(stderr, "\n"); llama_batch_free(batch); - for (auto & sampler_config : samplers) { - llama_sampler_free(sampler_config); + for (auto & sampler_config : sampler_configs) { + llama_sampler_free(sampler_config.sampler); } llama_free(ctx); diff --git a/ggml/src/ggml-cuda/CMakeLists.txt b/ggml/src/ggml-cuda/CMakeLists.txt index 67af1d8ccc..9e30ec4556 100644 --- a/ggml/src/ggml-cuda/CMakeLists.txt +++ b/ggml/src/ggml-cuda/CMakeLists.txt @@ -40,6 +40,20 @@ if (CUDAToolkit_FOUND) enable_language(CUDA) + # Remove once CCCL 3.2 has been released and bundled with CUDA Toolkit + if (GGML_CUDA_CUB_3DOT2) + include(FetchContent) + + FetchContent_Declare( + CCCL + GIT_REPOSITORY https://github.com/nvidia/cccl.git + GIT_TAG v3.2.0-rc1 + GIT_SHALLOW TRUE + ) + + FetchContent_MakeAvailable(CCCL) + endif() + file(GLOB GGML_HEADERS_CUDA "*.cuh") list(APPEND GGML_HEADERS_CUDA "../../include/ggml-cuda.h") @@ -102,6 +116,9 @@ if (CUDAToolkit_FOUND) # As of 12.3.1 CUDA Toolkit for Windows does not offer a static cublas library target_link_libraries(ggml-cuda PRIVATE CUDA::cudart_static CUDA::cublas) else () + if (GGML_CUDA_CUB_3DOT2) + target_link_libraries(ggml-cuda PRIVATE CCCL::CCCL) + endif() if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL "10.1") target_link_libraries(ggml-cuda PRIVATE CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static) else() @@ -109,6 +126,9 @@ if (CUDAToolkit_FOUND) endif() endif() else() + if (GGML_CUDA_CUB_3DOT2) + target_link_libraries(ggml-cuda PRIVATE CCCL::CCCL) + endif() target_link_libraries(ggml-cuda PRIVATE CUDA::cudart CUDA::cublas) endif() @@ -177,6 +197,10 @@ if (CUDAToolkit_FOUND) if (NOT MSVC) list(APPEND CUDA_CXX_FLAGS -Wno-pedantic) + else() + # CCCL 3.2 onwards will require a cpp-standard-compliant preprocessor for MSVC + # https://github.com/NVIDIA/cccl/pull/6827 + list(APPEND CUDA_CXX_FLAGS /Zc:preprocessor) endif() list(JOIN CUDA_CXX_FLAGS " " CUDA_CXX_FLAGS_JOINED) # pass host compiler flags as a single argument diff --git a/ggml/src/ggml-cuda/argsort.cu b/ggml/src/ggml-cuda/argsort.cu index da9652c3be..57c8a99a28 100644 --- a/ggml/src/ggml-cuda/argsort.cu +++ b/ggml/src/ggml-cuda/argsort.cu @@ -22,13 +22,13 @@ static __global__ void init_offsets(int * offsets, const int ncols, const int nr } #ifdef GGML_CUDA_USE_CUB -static void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool, - const float * x, - int * dst, - const int ncols, - const int nrows, - ggml_sort_order order, - cudaStream_t stream) { +void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool, + const float * x, + int * dst, + const int ncols, + const int nrows, + ggml_sort_order order, + cudaStream_t stream) { ggml_cuda_pool_alloc temp_indices_alloc(pool, ncols * nrows); ggml_cuda_pool_alloc temp_keys_alloc(pool, ncols * nrows); ggml_cuda_pool_alloc offsets_alloc(pool, nrows + 1); @@ -49,28 +49,49 @@ static void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool, size_t temp_storage_bytes = 0; if (order == GGML_SORT_ORDER_ASC) { - DeviceSegmentedRadixSort::SortPairs(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place) - temp_indices, dst, // values (indices) - ncols * nrows, 
nrows, // num items, num segments - d_offsets, d_offsets + 1, 0, sizeof(float) * 8, // all bits - stream); + if (nrows == 1) { + DeviceRadixSort::SortPairs(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place) + temp_indices, dst, // values (indices) + ncols, 0, sizeof(float) * 8, stream); + } else { + DeviceSegmentedSort::SortPairs(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place) + temp_indices, dst, // values (indices) + ncols * nrows, nrows, // num items, num segments + d_offsets, d_offsets + 1, stream); + } } else { - DeviceSegmentedRadixSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, temp_keys, temp_indices, - dst, ncols * nrows, nrows, d_offsets, d_offsets + 1, 0, - sizeof(float) * 8, stream); + if (nrows == 1) { + DeviceRadixSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place) + temp_indices, dst, // values (indices) + ncols, 0, sizeof(float) * 8, stream); + } else { + DeviceSegmentedSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, temp_keys, temp_indices, + dst, ncols * nrows, nrows, d_offsets, d_offsets + 1, stream); + } } ggml_cuda_pool_alloc temp_storage_alloc(pool, temp_storage_bytes); void * d_temp_storage = temp_storage_alloc.get(); if (order == GGML_SORT_ORDER_ASC) { - DeviceSegmentedRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, temp_indices, dst, - ncols * nrows, nrows, d_offsets, d_offsets + 1, 0, sizeof(float) * 8, - stream); + if (nrows == 1) { + DeviceRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place) + temp_indices, dst, // values (indices) + ncols, 0, sizeof(float) * 8, stream); + } else { + DeviceSegmentedSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, temp_indices, dst, + ncols * nrows, nrows, d_offsets, d_offsets + 1, stream); + } } else { - DeviceSegmentedRadixSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, - temp_indices, dst, ncols * nrows, nrows, d_offsets, d_offsets + 1, - 0, sizeof(float) * 8, stream); + if (nrows == 1) { + DeviceRadixSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place) + temp_indices, dst, // values (indices) + ncols, 0, sizeof(float) * 8, stream); + } else { + DeviceSegmentedSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, + temp_indices, dst, ncols * nrows, nrows, d_offsets, d_offsets + 1, + stream); + } } } #endif // GGML_CUDA_USE_CUB @@ -141,12 +162,12 @@ static int next_power_of_2(int x) { return n; } -static void argsort_f32_i32_cuda_bitonic(const float * x, - int * dst, - const int ncols, - const int nrows, - ggml_sort_order order, - cudaStream_t stream) { +void argsort_f32_i32_cuda_bitonic(const float * x, + int * dst, + const int ncols, + const int nrows, + ggml_sort_order order, + cudaStream_t stream) { // bitonic sort requires ncols to be power of 2 const int ncols_pad = next_power_of_2(ncols); diff --git a/ggml/src/ggml-cuda/argsort.cuh b/ggml/src/ggml-cuda/argsort.cuh index 68a001547f..22b7306f20 100644 --- a/ggml/src/ggml-cuda/argsort.cuh +++ b/ggml/src/ggml-cuda/argsort.cuh @@ -1,3 +1,19 @@ #include "common.cuh" void ggml_cuda_op_argsort(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +#ifdef GGML_CUDA_USE_CUB +void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool, + const float * x, + int * dst, + const int ncols, + const int nrows, + ggml_sort_order order, + cudaStream_t stream); 
+#endif // GGML_CUDA_USE_CUB +void argsort_f32_i32_cuda_bitonic(const float * x, + int * dst, + const int ncols, + const int nrows, + ggml_sort_order order, + cudaStream_t stream); diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 9fcb2f9fd2..bf60f27c07 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -915,15 +915,16 @@ struct ggml_cuda_device_info { int device_count; struct cuda_device_info { - int cc; // compute capability - int nsm; // number of streaming multiprocessors - size_t smpb; // max. shared memory per block - size_t smpbo; // max. shared memory per block (with opt-in) - bool integrated; // Device is integrated as opposed to discrete - bool vmm; // virtual memory support - size_t vmm_granularity; // granularity of virtual memory + int cc; // compute capability + int nsm; // number of streaming multiprocessors + size_t smpb; // max. shared memory per block + size_t smpbo; // max. shared memory per block (with opt-in) + bool integrated; // Device is integrated as opposed to discrete + bool vmm; // virtual memory support + size_t vmm_granularity; // granularity of virtual memory size_t total_vram; - int warp_size; // Number of threads in a dispatch + int warp_size; // Number of threads in a dispatch + bool supports_cooperative_launch; // whether cooperative launch is supported }; cuda_device_info devices[GGML_CUDA_MAX_DEVICES] = {}; diff --git a/ggml/src/ggml-cuda/cumsum.cu b/ggml/src/ggml-cuda/cumsum.cu index d2f2def8bd..1463bfa4f0 100644 --- a/ggml/src/ggml-cuda/cumsum.cu +++ b/ggml/src/ggml-cuda/cumsum.cu @@ -149,9 +149,34 @@ static __global__ void cumsum_kernel( } } +#ifdef GGML_CUDA_USE_CUB +template +static void cumsum_cub(ggml_cuda_pool & pool, + const T * src, + T * dst, + int64_t ne, + cudaStream_t stream) { + size_t tmp_size = 0; + + // Query how much temp storage CUDA UnBound (CUB) needs + cub::DeviceScan::InclusiveSum(nullptr, // d_temp_storage (null = just query size) + tmp_size, // reference to size (will be set by CUB) + src, // input pointer + dst, // output pointer + ne, // number of elements + stream // CUDA stream to use + ); + + ggml_cuda_pool_alloc tmp_alloc(pool, tmp_size); + + // Perform the inclusive scan + cub::DeviceScan::InclusiveSum((void *) tmp_alloc.get(), tmp_size, src, dst, ne, stream); +} +#endif // GGML_CUDA_USE_CUB + template static void cumsum_cuda( - const T * src, T * dst, + [[maybe_unused]] ggml_backend_cuda_context & ctx, const T * src, T * dst, const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03, const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03, const int64_t nb0, const int64_t nb1, const int64_t nb2, const int64_t nb3, @@ -165,6 +190,15 @@ static void cumsum_cuda( if (is_contiguous) { use_cub = true; + const int64_t nrows = ne01 * ne02 * ne03; + // TODO: Compare with DeviceSegmentedScan::InclusiveSegmentedSum for nrows > 1 once InclusiveSegmentedSum is released + // Heuristics were determined as part of https://github.com/ggml-org/llama.cpp/pull/17004 + if (((nrows == 1) && (ne00 > 1024)) || (ne00 / nrows > 4096)) { + for (int i=0; idata, (float *)dst->data, + ctx, (const float *)src0->data, (float *)dst->data, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3], diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index ab0f6fe9ce..d70367bc0b 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ 
b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -19,6 +19,7 @@ #include "ggml-cuda/count-equal.cuh" #include "ggml-cuda/cpy.cuh" #include "ggml-cuda/cross-entropy-loss.cuh" +#include "ggml-cuda/cumsum.cuh" #include "ggml-cuda/diagmask.cuh" #include "ggml-cuda/diag.cuh" #include "ggml-cuda/fattn.cuh" @@ -44,6 +45,7 @@ #include "ggml-cuda/ssm-scan.cuh" #include "ggml-cuda/sum.cuh" #include "ggml-cuda/sumrows.cuh" +#include "ggml-cuda/top-k.cuh" #include "ggml-cuda/mean.cuh" #include "ggml-cuda/tsembd.cuh" #include "ggml-cuda/topk-moe.cuh" @@ -241,6 +243,14 @@ static ggml_cuda_device_info ggml_cuda_init() { info.devices[id].nsm = prop.multiProcessorCount; info.devices[id].smpb = prop.sharedMemPerBlock; info.devices[id].warp_size = prop.warpSize; + +#ifndef GGML_USE_MUSA + int supports_coop_launch = 0; + CUDA_CHECK(cudaDeviceGetAttribute(&supports_coop_launch, cudaDevAttrCooperativeLaunch, id)); + info.devices[id].supports_cooperative_launch = !!supports_coop_launch; +#else + info.devices[id].supports_cooperative_launch = false; +#endif // !(GGML_USE_MUSA) #if defined(GGML_USE_HIP) info.devices[id].smpbo = prop.sharedMemPerBlock; @@ -2687,6 +2697,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_OP_SUM: ggml_cuda_op_sum(ctx, dst); break; + case GGML_OP_CUMSUM: + ggml_cuda_op_cumsum(ctx, dst); + break; case GGML_OP_SUM_ROWS: ggml_cuda_op_sum_rows(ctx, dst); break; @@ -2699,6 +2712,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_OP_SSM_SCAN: ggml_cuda_op_ssm_scan(ctx, dst); break; + case GGML_OP_TOP_K: + ggml_cuda_op_top_k(ctx, dst); + break; case GGML_OP_ARGSORT: ggml_cuda_op_argsort(ctx, dst); break; @@ -2708,9 +2724,6 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_OP_CROSS_ENTROPY_LOSS: ggml_cuda_cross_entropy_loss(ctx, dst); break; - case GGML_OP_CUMSUM: - ggml_cuda_op_cumsum(ctx, dst); - break; case GGML_OP_TRI: ggml_cuda_op_tri(ctx, dst); break; @@ -4600,6 +4613,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g return true; case GGML_OP_SUM: return ggml_is_contiguous_rows(op->src[0]); + case GGML_OP_TOP_K: case GGML_OP_ARGSORT: #ifndef GGML_CUDA_USE_CUB return op->src[0]->ne[0] <= 1024; diff --git a/ggml/src/ggml-cuda/softmax.cu b/ggml/src/ggml-cuda/softmax.cu index eeacde0bdb..1ae84ebf63 100644 --- a/ggml/src/ggml-cuda/softmax.cu +++ b/ggml/src/ggml-cuda/softmax.cu @@ -1,6 +1,14 @@ #include "common.cuh" #include "ggml.h" #include "softmax.cuh" + +#ifdef GGML_USE_HIP +#include +#else +#include +#include +#endif // GGML_USE_HIP + #include #include @@ -160,6 +168,156 @@ static __global__ void soft_max_f32( dst[col] = vals[col] * inv_sum; } } + + +// TODO: This is a common pattern used across kernels that could be moved to common.cuh + templated +static __device__ float two_stage_warp_reduce_max(float val) { + val = warp_reduce_max(val); + if (blockDim.x > WARP_SIZE) { + assert((blockDim.x <= 1024) && (blockDim.x % WARP_SIZE) == 0); + __shared__ float local_vals[32]; + const int warp_id = threadIdx.x / WARP_SIZE; + const int lane_id = threadIdx.x % WARP_SIZE; + if (lane_id == 0) { + local_vals[warp_id] = val; + } + __syncthreads(); + val = -INFINITY; + if (lane_id < (static_cast(blockDim.x) / WARP_SIZE)) { + val = local_vals[lane_id]; + } + return warp_reduce_max(val); + } else { + return val; + } +} + +static __device__ float two_stage_warp_reduce_sum(float val) { + val = warp_reduce_sum(val); + if (blockDim.x > WARP_SIZE) { + 
assert((blockDim.x <= 1024) && (blockDim.x % WARP_SIZE) == 0); + __shared__ float local_vals[32]; + const int warp_id = threadIdx.x / WARP_SIZE; + const int lane_id = threadIdx.x % WARP_SIZE; + if (lane_id == 0) { + local_vals[warp_id] = val; + } + __syncthreads(); + val = 0.0f; + if (lane_id < (static_cast(blockDim.x) / WARP_SIZE)) { + val = local_vals[lane_id]; + } + return warp_reduce_sum(val); + } else { + return val; + } +} + +// TODO: Template to allow keeping ncols in registers if they fit +static __device__ void soft_max_f32_parallelize_cols_single_row(const float * __restrict__ x, + float * __restrict__ dst, + float * __restrict__ tmp_maxs, + float * __restrict__ tmp_sums, + const soft_max_params p) { + namespace cg = cooperative_groups; + + const cg::grid_group g = cg::this_grid(); + + const int tid = threadIdx.x; + const int col_start = blockIdx.x * blockDim.x + tid; + const int n_elem_per_thread = 4; + + float local_vals[n_elem_per_thread] = { -INFINITY, -INFINITY, -INFINITY, -INFINITY }; + float local_max = -INFINITY; + const int step_size = gridDim.x * blockDim.x; + + // Compute thread-local max + for (int col = col_start; col < p.ncols;) { +#pragma unroll + for (int i = 0; i < n_elem_per_thread; i++) { + const int idx = col + i * step_size; + local_vals[i] = idx < p.ncols ? x[idx] : -INFINITY; + } +#pragma unroll + for (int i = 0; i < n_elem_per_thread; i++) { + local_max = fmaxf(local_max, local_vals[i]); + } + col += step_size * n_elem_per_thread; + } + + // Compute CTA-level max + local_max = two_stage_warp_reduce_max(local_max); + + // Store CTA-level max to GMEM + if (tid == 0) { + tmp_maxs[blockIdx.x] = local_max; + } + g.sync(); + + // Compute compute global max from CTA-level maxs + assert(gridDim.x < blockDim.x); // currently we only support this case + if (tid < gridDim.x) { + local_max = tmp_maxs[tid]; + } else { + local_max = -INFINITY; + } + local_max = two_stage_warp_reduce_max(local_max); + + // Compute softmax dividends, accumulate divisor + float tmp_expf = 0.0f; + for (int col = col_start; col < p.ncols;) { +#pragma unroll + for (int i = 0; i < n_elem_per_thread; i++) { + const int idx = col + i * step_size; + local_vals[i] = idx < p.ncols ? x[idx] : -INFINITY; + } +#pragma unroll + for (int i = 0; i < n_elem_per_thread; i++) { + const int idx = col + i * step_size; + if (idx < p.ncols) { + const float tmp = expf(local_vals[i] - local_max); + tmp_expf += tmp; + dst[idx] = tmp; + } + } + col += step_size * n_elem_per_thread; + } + + // Reduce divisor within CTA + tmp_expf = two_stage_warp_reduce_sum(tmp_expf); + + // Store CTA-level sum to GMEM + if (tid == 0) { + tmp_sums[blockIdx.x] = tmp_expf; + } + g.sync(); + + // Compute global sum from CTA-level sums + if (tid < gridDim.x) { + tmp_expf = tmp_sums[tid]; + } else { + tmp_expf = 0.0f; + } + tmp_expf = two_stage_warp_reduce_sum(tmp_expf); + + // Divide dividend by global sum + store data + for (int col = col_start; col < p.ncols;) { +#pragma unroll + for (int i = 0; i < n_elem_per_thread; i++) { + const int idx = col + i * step_size; + local_vals[i] = idx < p.ncols ? 
dst[idx] : -INFINITY; + } +#pragma unroll + for (int i = 0; i < n_elem_per_thread; i++) { + const int idx = col + i * step_size; + if (idx < p.ncols) { + dst[idx] = local_vals[i] / tmp_expf; + } + } + col += step_size * n_elem_per_thread; + } +} + #ifdef __clang__ #pragma clang diagnostic pop #endif // __clang__ @@ -216,9 +374,31 @@ static void launch_soft_max_kernels(const float * x, const T * mask, const float soft_max_f32<<>>(x, mask, sinks, dst, p); } +__launch_bounds__(8*WARP_SIZE, 1) static __global__ void soft_max_f32_parallelize_cols(const float * __restrict__ x, + float * __restrict__ dst, + float * __restrict__ tmp_maxs, + float * __restrict__ tmp_sums, + const soft_max_params p) +// We loop over all instead of parallelizing across gridDim.y as cooperative groups +// currently only support synchronizing the complete grid if not launched as a cluster group +// (which requires CC > 9.0) +// https://docs.nvidia.com/cuda/cuda-programming-guide/05-appendices/device-callable-apis.html#grid-synchronization +// https://docs.nvidia.com/cuda/cuda-programming-guide/05-appendices/device-callable-apis.html#class-cluster-group +{ + for (int rowx = 0; rowx < p.ne01 * p.ne02 * p.ne03; rowx++) { + soft_max_f32_parallelize_cols_single_row(x + int64_t(rowx) * p.ncols, dst + int64_t(rowx) * p.ncols, tmp_maxs, + tmp_sums, p); + } +} -template -static void soft_max_f32_cuda(const float * x, const T * mask, const float * sinks, float * dst, const soft_max_params & params, cudaStream_t stream) { +template +static void soft_max_f32_cuda(const float * x, + const T * mask, + const float * sinks, + float * dst, + const soft_max_params & params, + cudaStream_t stream, + [[maybe_unused]] ggml_backend_cuda_context & ctx) { int nth = WARP_SIZE; const int64_t ncols_x = params.ncols; @@ -236,8 +416,25 @@ static void soft_max_f32_cuda(const float * x, const T * mask, const float * sin if (nbytes_shared <= smpbo) { launch_soft_max_kernels<32, 64, 128, 256, 512, 1024, 2048, 4096>(x, mask, sinks, dst, params, stream, block_dims, block_nums, nbytes_shared); } else { - const size_t nbytes_shared_low = WARP_SIZE*sizeof(float); - soft_max_f32<<>>(x, mask, sinks, dst, params); + // Parallelize across SMs for top-p/dist-sampling + // The heuristic for parallelizing rows across SMs vs parallelizing single row & looping over all rows was done on the basis of a B6000 GPU and + // Can be adapted further for lower-SM-count GPUs, though keeping data in registers should be implemented first as that is the optimal solution. 
+ if (ggml_cuda_info().devices[id].supports_cooperative_launch && + ncols_x / (params.ne01 * params.ne02 * params.ne03) > 8192 && mask == nullptr && sinks == nullptr && + params.scale == 1.0f && params.max_bias == 0.0f) { + ggml_cuda_pool_alloc tmp_maxs_alloc(ctx.pool(), ggml_cuda_info().devices[id].nsm * sizeof(float)); + ggml_cuda_pool_alloc tmp_sums_alloc(ctx.pool(), ggml_cuda_info().devices[id].nsm * sizeof(float)); + + void * kernel_args[] = { (void *) &x, (void *) &dst, (void *) &tmp_maxs_alloc.ptr, + (void *) &tmp_sums_alloc.ptr, (void *) const_cast(¶ms) }; + CUDA_CHECK(cudaLaunchCooperativeKernel((void *) soft_max_f32_parallelize_cols, + dim3(ggml_cuda_info().devices[id].nsm, 1, 1), + dim3(WARP_SIZE * 8, 1, 1), kernel_args, 0, stream)); + } else { + const size_t nbytes_shared_low = WARP_SIZE * sizeof(float); + soft_max_f32 + <<>>(x, mask, sinks, dst, params); + } } } @@ -315,9 +512,9 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { params.m1 = m1; if (use_f16) { - soft_max_f32_cuda(src0_d, (const half *) src1_d, (const float *) src2_d, dst_d, params, stream); + soft_max_f32_cuda(src0_d, (const half *) src1_d, (const float *) src2_d, dst_d, params, stream, ctx); } else { - soft_max_f32_cuda(src0_d, (const float *) src1_d, (const float *) src2_d, dst_d, params, stream); + soft_max_f32_cuda(src0_d, (const float *) src1_d, (const float *) src2_d, dst_d, params, stream, ctx); } } diff --git a/ggml/src/ggml-cuda/top-k.cu b/ggml/src/ggml-cuda/top-k.cu new file mode 100644 index 0000000000..7d66fec495 --- /dev/null +++ b/ggml/src/ggml-cuda/top-k.cu @@ -0,0 +1,110 @@ +#include "argsort.cuh" +#include "top-k.cuh" + +#ifdef GGML_CUDA_USE_CUB +# include +# if (CCCL_MAJOR_VERSION >= 3 && CCCL_MINOR_VERSION >= 2) +# define CUB_TOP_K_AVAILABLE +using namespace cub; +# endif // CCCL_MAJOR_VERSION >= 3 && CCCL_MINOR_VERSION >= 2 +#endif // GGML_CUDA_USE_CUB + +#ifdef CUB_TOP_K_AVAILABLE +static __global__ void init_indices(int * indices, const int ncols) { + const int col = blockIdx.x * blockDim.x + threadIdx.x; + + if (col < ncols) { + indices[col] = col; + } +} + +static void top_k_cub(ggml_cuda_pool & pool, + const float * src, + int * dst, + const int ncols, + const int k, + cudaStream_t stream) { + auto requirements = cuda::execution::require(cuda::execution::determinism::not_guaranteed, + cuda::execution::output_ordering::unsorted); + auto stream_env = cuda::stream_ref{ stream }; + auto env = cuda::std::execution::env{ stream_env, requirements }; + + ggml_cuda_pool_alloc temp_indices_alloc(pool, ncols); + ggml_cuda_pool_alloc temp_keys_alloc(pool, ncols); + + int * temp_indices = temp_indices_alloc.get(); + float * temp_keys = temp_keys_alloc.get(); + + static const int block_size = 256; + const dim3 grid_size((ncols + block_size - 1) / block_size, 1); + init_indices<<>>(temp_indices, ncols); + + CUDA_CHECK(cudaMemcpyAsync(temp_keys, src, ncols * sizeof(float), cudaMemcpyDeviceToDevice, stream)); + + size_t temp_storage_bytes = 0; + DeviceTopK::MaxPairs(nullptr, temp_storage_bytes, temp_keys, temp_keys, temp_indices, dst, ncols, k, env); + + ggml_cuda_pool_alloc temp_storage_alloc(pool, temp_storage_bytes); + void * d_temp_storage = temp_storage_alloc.get(); + + DeviceTopK::MaxPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, temp_indices, dst, ncols, k, env); +} + +#elif defined(GGML_CUDA_USE_CUB) // CUB_TOP_K_AVAILABLE + +static int next_power_of_2(int x) { + int n = 1; + while (n < x) { + n *= 2; + } + return n; +} + +#endif // 
CUB_TOP_K_AVAILABLE + +void ggml_cuda_op_top_k(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const float * src0_d = (const float *) src0->data; + int * dst_d = (int *) dst->data; + cudaStream_t stream = ctx.stream(); + + // are these asserts truly necessary? + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_I32); + GGML_ASSERT(ggml_is_contiguous(src0)); + + const int64_t ncols = src0->ne[0]; + const int64_t nrows = ggml_nrows(src0); + const int64_t k = dst->ne[0]; + ggml_cuda_pool & pool = ctx.pool(); +#ifdef CUB_TOP_K_AVAILABLE + // TODO: Switch to `DeviceSegmentedTopK` for multi-row TopK once implemented + // https://github.com/NVIDIA/cccl/issues/6391 + // TODO: investigate if there exists a point where parallelized argsort is faster than sequential top-k + for (int i = 0; i < nrows; i++) { + top_k_cub(pool, src0_d + i * ncols, dst_d + i * k, ncols, k, stream); + } +#elif defined(GGML_CUDA_USE_CUB) // CUB_TOP_K_AVAILABLE + // Fall back to argsort + copy + const int ncols_pad = next_power_of_2(ncols); + const size_t shared_mem = ncols_pad * sizeof(int); + const size_t max_shared_mem = ggml_cuda_info().devices[ggml_cuda_get_device()].smpb; + + ggml_cuda_pool_alloc temp_dst_alloc(pool, ncols * nrows); + int * tmp_dst = temp_dst_alloc.get(); + + if (shared_mem > max_shared_mem || ncols > 1024) { + argsort_f32_i32_cuda_cub(pool, src0_d, tmp_dst, ncols, nrows, GGML_SORT_ORDER_DESC, stream); + } else { + argsort_f32_i32_cuda_bitonic(src0_d, tmp_dst, ncols, nrows, GGML_SORT_ORDER_DESC, stream); + } + CUDA_CHECK(cudaMemcpy2DAsync(dst_d, k * sizeof(int), tmp_dst, ncols * sizeof(int), k * sizeof(int), nrows, + cudaMemcpyDeviceToDevice, stream)); +#else // GGML_CUDA_USE_CUB + ggml_cuda_pool_alloc temp_dst_alloc(pool, ncols * nrows); + int * tmp_dst = temp_dst_alloc.get(); + argsort_f32_i32_cuda_bitonic(src0_d, tmp_dst, ncols, nrows, GGML_SORT_ORDER_DESC, stream); + CUDA_CHECK(cudaMemcpy2DAsync(dst_d, k * sizeof(int), tmp_dst, ncols * sizeof(int), k * sizeof(int), nrows, + cudaMemcpyDeviceToDevice, stream)); +#endif +} diff --git a/ggml/src/ggml-cuda/top-k.cuh b/ggml/src/ggml-cuda/top-k.cuh new file mode 100644 index 0000000000..f4d8f61e5b --- /dev/null +++ b/ggml/src/ggml-cuda/top-k.cuh @@ -0,0 +1,3 @@ +#include "common.cuh" + +void ggml_cuda_op_top_k(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/vendors/hip.h b/ggml/src/ggml-cuda/vendors/hip.h index 951a88d567..016b04e5a0 100644 --- a/ggml/src/ggml-cuda/vendors/hip.h +++ b/ggml/src/ggml-cuda/vendors/hip.h @@ -45,9 +45,11 @@ #define cublasSgemm hipblasSgemm #define cublasStatus_t hipblasStatus_t #define cublasOperation_t hipblasOperation_t +#define cudaDevAttrCooperativeLaunch hipDeviceAttributeCooperativeLaunch #define cudaDeviceCanAccessPeer hipDeviceCanAccessPeer #define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess #define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess +#define cudaDeviceGetAttribute hipDeviceGetAttribute #define cudaDeviceProp hipDeviceProp_t #define cudaDeviceSynchronize hipDeviceSynchronize #define cudaError_t hipError_t @@ -70,6 +72,7 @@ #define cudaHostRegisterPortable hipHostRegisterPortable #define cudaHostRegisterReadOnly hipHostRegisterReadOnly #define cudaHostUnregister hipHostUnregister +#define cudaLaunchCooperativeKernel hipLaunchCooperativeKernel #define cudaLaunchHostFunc hipLaunchHostFunc #define cudaMalloc hipMalloc #define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size, 
hipHostMallocDefault)
diff --git a/ggml/src/ggml-cuda/vendors/musa.h b/ggml/src/ggml-cuda/vendors/musa.h
index 221e67f96a..1abb8acfd4 100644
--- a/ggml/src/ggml-cuda/vendors/musa.h
+++ b/ggml/src/ggml-cuda/vendors/musa.h
@@ -61,6 +61,7 @@
 #define cudaHostRegisterPortable musaHostRegisterPortable
 #define cudaHostRegisterReadOnly musaHostRegisterReadOnly
 #define cudaHostUnregister musaHostUnregister
+#define cudaLaunchCooperativeKernel musaLaunchCooperativeKernel
 #define cudaLaunchHostFunc musaLaunchHostFunc
 #define cudaMalloc musaMalloc
 #define cudaMallocHost musaMallocHost
diff --git a/include/llama.h b/include/llama.h
index f862930099..5ba9d6cb32 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -316,6 +316,11 @@ extern "C" {
         bool no_alloc; // only load metadata and simulate memory allocations
     };
 
+    struct llama_sampler_seq_config {
+        llama_seq_id           seq_id;
+        struct llama_sampler * sampler;
+    };
+
     // NOTE: changing the default values of parameters marked as [EXPERIMENTAL] may cause crashes or incorrect results in certain configurations
     // https://github.com/ggml-org/llama.cpp/pull/7544
     struct llama_context_params {
@@ -364,6 +369,11 @@ extern "C" {
         bool kv_unified;  // use a unified buffer across the input sequences when computing the attention
                           // try to disable when n_seq_max > 1 for improved performance when the sequences do not share a large prefix
                           // ref: https://github.com/ggml-org/llama.cpp/pull/14363
+
+        // backend sampler chain configuration (make sure the caller keeps the sampler chains alive) [EXPERIMENTAL]
+        // note: the samplers must be sampler chains (i.e. use llama_sampler_chain_init)
+        struct llama_sampler_seq_config * samplers;
+        size_t n_samplers;
     };
 
     // model quantization parameters
@@ -983,6 +993,32 @@ extern "C" {
     // otherwise: float[n_embd] (1-dimensional)
     LLAMA_API float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id);
 
+    // Get the backend sampled token for the ith token.
+    // Returns LLAMA_TOKEN_NULL if no token was sampled.
+    LLAMA_API llama_token llama_get_sampled_token_ith(struct llama_context * ctx, int32_t i);
+
+    // Get the backend sampled probabilities for the ith token.
+    // The index matches llama_get_sampled_token_ith().
+    // Returns NULL if no probabilities were generated.
+    LLAMA_API float * llama_get_sampled_probs_ith(struct llama_context * ctx, int32_t i);
+    //
+    // Get the number of backend sampled probabilities for the ith token.
+    LLAMA_API uint32_t llama_get_sampled_probs_count_ith(struct llama_context * ctx, int32_t i);
+
+    // Get the backend sampled logits for the ith token.
+    // Returns NULL if no logits were sampled.
+    LLAMA_API float * llama_get_sampled_logits_ith(struct llama_context * ctx, int32_t i);
+    //
+    // Get the number of backend sampled logits for the ith token.
+    LLAMA_API uint32_t llama_get_sampled_logits_count_ith(struct llama_context * ctx, int32_t i);
+
+    // Get the backend sampled candidates (token ids) for the ith token.
+    // Returns NULL if no candidates were sampled.
+    LLAMA_API llama_token * llama_get_sampled_candidates_ith(struct llama_context * ctx, int32_t i);
+    //
+    // Get the number of backend sampled candidates for the ith token.
+    LLAMA_API uint32_t llama_get_sampled_candidates_count_ith(struct llama_context * ctx, int32_t i);
+
     //
     // Vocab
     //
@@ -1154,11 +1190,16 @@ extern "C" {
     //
     //    llama_sampler_free(smpl);
     //
-    // TODO: In the future, llama_sampler will be utilized to offload the sampling to the backends (e.g. GPU).
- // typedef void * llama_sampler_context_t; + struct llama_sampler_data { + struct ggml_tensor * logits; + struct ggml_tensor * probs; + struct ggml_tensor * sampled; + struct ggml_tensor * candidates; + }; + // user code can implement the interface below in order to create custom llama_sampler struct llama_sampler_i { const char * (*name) (const struct llama_sampler * smpl); // can be NULL @@ -1168,17 +1209,40 @@ extern "C" { struct llama_sampler * (*clone) (const struct llama_sampler * smpl); // can be NULL if ctx is NULL void (*free) ( struct llama_sampler * smpl); // can be NULL if ctx is NULL - // TODO: API for internal libllama usage for appending the sampling to an existing ggml_cgraph - //void (*apply_ggml) (struct llama_sampler * smpl, ...); + // backend sampling interface: + + // return true if the backend supports all ops needed by the sampler + // note: call once per sampler + bool (*backend_init)(struct llama_sampler * smpl, ggml_backend_buffer_type_t buft); + + // call after .backend_accept() + void (*backend_accept)( + struct llama_sampler * smpl, + struct ggml_context * ctx, + struct ggml_cgraph * gf, + struct ggml_tensor * selected_token); + + // call after .backend_init() + void (*backend_apply)( + struct llama_sampler * smpl, + struct ggml_context * ctx, + struct ggml_cgraph * gf, + struct llama_sampler_data * data); + + // call before .backend_apply() + void (*backend_set_input)(struct llama_sampler * smpl); }; struct llama_sampler { - const struct llama_sampler_i * iface; - llama_sampler_context_t ctx; + struct llama_sampler_i * iface; + + llama_sampler_context_t ctx; }; + LLAMA_API bool llama_set_sampler(struct llama_context * ctx, llama_seq_id seq_id, struct llama_sampler * smpl); + // mirror of llama_sampler_i: - LLAMA_API struct llama_sampler * llama_sampler_init (const struct llama_sampler_i * iface, llama_sampler_context_t ctx); + LLAMA_API struct llama_sampler * llama_sampler_init ( struct llama_sampler_i * iface, llama_sampler_context_t ctx); LLAMA_API const char * llama_sampler_name (const struct llama_sampler * smpl); LLAMA_API void llama_sampler_accept( struct llama_sampler * smpl, llama_token token); LLAMA_API void llama_sampler_apply ( struct llama_sampler * smpl, llama_token_data_array * cur_p); @@ -1194,7 +1258,15 @@ extern "C" { // important: takes ownership of the sampler object and will free it when llama_sampler_free is called LLAMA_API void llama_sampler_chain_add( struct llama_sampler * chain, struct llama_sampler * smpl); - LLAMA_API struct llama_sampler * llama_sampler_chain_get(const struct llama_sampler * chain, int32_t i); + + // return NULL if: + // - the sampler is NULL + // - the sampler is not a llama_sampler_chain + // - the index is out of bounds, unless i == -1 + // - if i == -1, returns the chain itself (can be used to check if the sampler is a chain) + LLAMA_API struct llama_sampler * llama_sampler_chain_get( struct llama_sampler * chain, int32_t i); + + // the total number of samplers in the chain LLAMA_API int llama_sampler_chain_n (const struct llama_sampler * chain); // after removing a sampler, the chain will no longer own it, and it will not be freed when the chain is freed diff --git a/src/llama-batch.cpp b/src/llama-batch.cpp index 386fab04ac..627ffca916 100644 --- a/src/llama-batch.cpp +++ b/src/llama-batch.cpp @@ -28,7 +28,8 @@ bool llama_batch_allocr::init( const llama_memory_i * memory, uint32_t n_embd, uint32_t n_seq_max, - bool output_all) { + bool output_all, + bool sampling) { clear(); batch = batch_inp; @@ -145,6 
+146,24 @@ bool llama_batch_allocr::init( } } + if (sampling) { + std::vector seq_output_count(n_seq_max, 0); + + for (int32_t i = 0; i < batch.n_tokens; ++i) { + if (batch.logits[i] == 0) { + continue; + } + for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) { + const llama_seq_id seq_id = batch.seq_id[i][s]; + seq_output_count[seq_id]++; + if (seq_output_count[seq_id] > 1) { + LLAMA_LOG_ERROR("%s: backend sampling requires at most one output token per sequence (%d)\n", __func__, seq_id); + return false; + } + } + } + } + // // compute stats // diff --git a/src/llama-batch.h b/src/llama-batch.h index 8e6fac0efa..05c03d018d 100644 --- a/src/llama-batch.h +++ b/src/llama-batch.h @@ -81,7 +81,8 @@ public: const llama_memory_i * memory, uint32_t n_embd, uint32_t n_seq_max, - bool output_all); + bool output_all, + bool sampling = false); const llama_batch & get_batch() const; diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 8786d4ee3e..133124b4f5 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -60,6 +60,25 @@ llama_context::llama_context( cparams.cb_eval = params.cb_eval; cparams.cb_eval_user_data = params.cb_eval_user_data; + // Initialize backend samplers here so they are part of the sampling graph + // before the reserve passes run later in this function. This avoids a later + // re-reserve when graph nodes change. + if (params.samplers != nullptr && params.n_samplers > 0) { + for (size_t i = 0; i < params.n_samplers; ++i) { + const auto & config = params.samplers[i]; + + if (llama_sampler_chain_get(config.sampler, -1) == nullptr) { + throw std::runtime_error("the backend samplers must be of type llama_sampler_chain"); + } + + if (set_sampler(config.seq_id, config.sampler)) { + const int n_samplers = llama_sampler_chain_n(config.sampler); + + LLAMA_LOG_INFO("%s: setting backend sampler for seq_id %d (n = %d)\n", __func__, config.seq_id, n_samplers); + } + } + } + auto rope_scaling_type = params.rope_scaling_type; if (rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED) { rope_scaling_type = hparams.rope_scaling_type_train; @@ -231,7 +250,10 @@ llama_context::llama_context( // graph outputs buffer { // resized during inference when a batch uses more outputs - if (output_reserve(params.n_seq_max) < params.n_seq_max) { + // Create a dummy batch for initialization. + llama_batch dummy_batch = {}; + dummy_batch.n_tokens = 0; + if (output_reserve(params.n_seq_max, dummy_batch) < params.n_seq_max) { throw std::runtime_error("failed to reserve initial output buffer"); } @@ -456,6 +478,16 @@ llama_context::llama_context( LLAMA_LOG_INFO("%s: graph splits = %d (with bs=%d), %d (with bs=1)\n", __func__, n_splits_pp, n_tokens, n_splits_tg); } } + + // Initialize the full vocabulary token ids for backend samplers. 
+ { + const int n_vocab = model.vocab.n_tokens(); + + sampling.token_ids_full_vocab.resize(n_vocab); + for (int i = 0; i < n_vocab; ++i) { + sampling.token_ids_full_vocab[i] = i; + } + } } llama_context::~llama_context() { @@ -617,6 +649,35 @@ float * llama_context::get_logits() { return logits; } +int64_t llama_context::resolve_output_row(int32_t i) const { + int64_t j = -1; + + // support negative indices (last output row) + if (i < 0) { + j = n_outputs + i; + if (j < 0) { + throw std::runtime_error(format("negative index out of range [0, %d)", n_outputs)); + } + } else if ((size_t) i >= output_ids.size()) { + throw std::runtime_error(format("out of range [0, %zu)", output_ids.size())); + } else { + // use output_ids to translate the batch token index into a row number + // that holds this token's data. + j = output_ids[i]; + } + + if (j < 0) { + // the batch token was not configured to output anything + throw std::runtime_error(format("batch.logits[%d] != true", i)); + } + + if (j >= n_outputs) { + throw std::runtime_error(format("corrupt output buffer (j=%" PRId64 ", n_outputs=%d)", j, n_outputs)); + } + + return j; +} + float * llama_context::get_logits_ith(int32_t i) { int64_t j = -1; @@ -663,6 +724,10 @@ float * llama_context::get_embeddings() { return embd; } +llama_token * llama_context::get_sampled_tokens() { + return sampling.sampled; +} + float * llama_context::get_embeddings_ith(int32_t i) { int64_t j = -1; @@ -712,6 +777,136 @@ float * llama_context::get_embeddings_seq(llama_seq_id seq_id) { return it->second.data(); } +llama_token llama_context::get_sampled_token_ith(int32_t idx) { + output_reorder(); + + if (sampling.sampled == nullptr) { + return LLAMA_TOKEN_NULL; + } + + try { + const int64_t row = resolve_output_row(idx); + GGML_ASSERT(row < (int64_t) sampling.sampled_size); + return sampling.sampled[row]; + } catch (const std::exception & err) { + LLAMA_LOG_ERROR("%s: invalid backend sampled token id %d, reason: %s\n", __func__, idx, err.what()); + return LLAMA_TOKEN_NULL; + } +} + +float * llama_context::get_sampled_probs_ith(int32_t idx) { + output_reorder(); + + if (sampling.probs == nullptr) { + return nullptr; + } + + try { + const int64_t row = resolve_output_row(idx); + if ((size_t) row >= sampling.probs_count.size() || sampling.probs_count[row] == 0) { + return nullptr; + } + return sampling.probs + row*model.vocab.n_tokens(); + } catch (const std::exception & err) { + LLAMA_LOG_ERROR("%s: invalid backend sampled probs id %d, reason: %s\n", __func__, idx, err.what()); + return nullptr; + } +} + +float * llama_context::get_sampled_logits_ith(int32_t idx) { + output_reorder(); + + if (sampling.logits == nullptr) { + return nullptr; + } + + try { + const int64_t row = resolve_output_row(idx); + if ((size_t) row >= sampling.logits_count.size() || sampling.logits_count[row] == 0) { + return nullptr; + } + return sampling.logits + row*model.vocab.n_tokens(); + } catch (const std::exception & err) { + LLAMA_LOG_ERROR("%s: invalid backend sampled logits id %d, reason: %s\n", __func__, idx, err.what()); + return nullptr; + } +} + +const llama_token * llama_context::get_sampled_candidates_ith(int32_t idx) { + output_reorder(); + + try { + const int64_t row = resolve_output_row(idx); + if (sampling.candidates != nullptr && + (size_t) row < sampling.candidates_count.size() && + sampling.candidates_count[row] > 0) { + return sampling.candidates + row*model.vocab.n_tokens(); + } + } catch (const std::exception & err) { + // fallback to full vocab list + } + + return 
sampling.token_ids_full_vocab.data(); +} + +size_t llama_context::get_sampled_candidates_count(int32_t idx) { + output_reorder(); + + if (sampling.candidates == nullptr) { + return 0; + } + + try { + const int64_t row = resolve_output_row(idx); + if ((size_t) row >= sampling.candidates_count.size()) { + return 0; + } + return sampling.candidates_count[row]; + } catch (const std::exception & err) { + LLAMA_LOG_ERROR("%s: invalid backend sampled candidates count id %d, reason: %s\n", __func__, idx, err.what()); + return 0; + } +} + +size_t llama_context::get_sampled_logits_count(int32_t idx) { + output_reorder(); + + if (sampling.logits == nullptr) { + return model.vocab.n_tokens(); + } + + try { + const int64_t row = resolve_output_row(idx); + if ((size_t) row >= sampling.logits_count.size()) { + return 0; + } + return sampling.logits_count[row]; + } catch (const std::exception & err) { + LLAMA_LOG_ERROR("%s: invalid backend sampled logits count id %d, reason: %s\n", __func__, idx, err.what()); + return 0; + } +} + +size_t llama_context::get_sampled_probs_count(int32_t idx) { + output_reorder(); + + if (sampling.probs == nullptr) { + return 0; + } + + try { + const int64_t row = resolve_output_row(idx); + if ((size_t) row >= sampling.probs_count.size()) { + return 0; + } + return sampling.probs_count[row]; + } catch (const std::exception & err) { + LLAMA_LOG_ERROR("%s: invalid backend sampled probs count id %d, reason: %s\n", __func__, idx, err.what()); + return 0; + } +} + + void llama_context::attach_threadpool( ggml_threadpool_t threadpool, ggml_threadpool_t threadpool_batch) { @@ -768,6 +963,42 @@ void llama_context::set_warmup(bool value) { cparams.warmup = value; } +bool llama_context::set_sampler(llama_seq_id seq_id, llama_sampler * sampler) { + LLAMA_LOG_DEBUG("%s: seq_id = %d, sampler = %p\n", __func__, (int) seq_id, (void *) sampler); + + const bool can_offload = + sampler && + sampler->iface->backend_init && + sampler->iface->backend_apply && + llama_sampler_chain_n(sampler) > 0; + + if (sampler && can_offload) { + ggml_backend_buffer_type_t buft = ggml_backend_dev_buffer_type(model.dev_output()); + auto * host_buft = ggml_backend_dev_host_buffer_type(model.dev_output()); + if (host_buft) { + buft = host_buft; + } + + sampler->iface->backend_init(sampler, buft); + + sampling.samplers[seq_id] = sampler; + + return true; + } + + if (sampler && !can_offload) { + LLAMA_LOG_WARN("%s: sampler '%s' for seq_id = %d, cannot be offloaded to the backend\n", __func__, llama_sampler_name(sampler), seq_id); + + sampling.samplers.erase(seq_id); + + return false; + } + + sampling.samplers.erase(seq_id); + + return true; +} + void llama_context::set_adapter_lora( llama_adapter_lora * adapter, float scale) { @@ -908,7 +1139,7 @@ int llama_context::encode(const llama_batch & batch_inp) { n_queued_tokens += n_tokens; // reserve output buffer - if (output_reserve(n_tokens) < n_tokens) { + if (output_reserve(n_tokens, batch_inp) < n_tokens) { LLAMA_LOG_ERROR("%s: could not reserve space for batch with %u outputs\n", __func__, n_tokens); return -2; }; @@ -1032,6 +1263,112 @@ int llama_context::encode(const llama_batch & batch_inp) { return 0; } +static std::map build_seq_to_output_row(const llama_ubatch & ubatch, uint32_t row_offset) { + std::map seq_to_row; + // how many output tokens we have seen so far for this ubatch. + uint32_t local = 0; + for (uint32_t i = 0; i < ubatch.n_tokens; ++i) { + // skip tokens that are not output. 
+ if (!ubatch.output[i]) { + continue; + } + + const llama_seq_id seq_id = ubatch.seq_id[i][0]; + // row_offset is the number of output tokens before this ubatch. + seq_to_row[seq_id] = row_offset + local; + ++local; + } + return seq_to_row; +} + +static void copy_tensor_async_ints( + const std::map & tensor_map, + llama_token * sampled, + size_t sampled_size, + const std::map & seq_to_row, + ggml_backend_sched_t sched) { + if (sampled == nullptr) { + return; + } + + for (const auto & [seq_id, tensor] : tensor_map) { + auto it = seq_to_row.find(seq_id); + if (it == seq_to_row.end()) { + continue; + } + + const uint32_t row = it->second; + GGML_ASSERT(row < sampled_size); + + GGML_ASSERT(ggml_is_contiguous(tensor) && "sampled tokens tensor must be contiguous for async copy"); + + ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, tensor); + ggml_backend_tensor_get_async(backend, tensor, sampled + row, 0, sizeof(sampled[row])); + } +} + +static void copy_tensor_async_floats( + const std::map & tensor_map, + float * dst, + size_t stride, + std::vector & counts, + const std::map & seq_to_row, + ggml_backend_sched_t sched) { + if (dst == nullptr) { + return; + } + + for (const auto & [seq_id, tensor] : tensor_map) { + auto it = seq_to_row.find(seq_id); + if (it == seq_to_row.end()) { + continue; + } + + const uint32_t row = it->second; + GGML_ASSERT(row < counts.size()); + + GGML_ASSERT(ggml_is_contiguous(tensor) && "logits/probs tensor must be contiguous for async copy"); + + ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, tensor); + float * row_ptr = dst + (size_t) row * stride; + ggml_backend_tensor_get_async(backend, tensor, row_ptr, 0, ggml_nbytes(tensor)); + + // Update the actual number of logits/probabilities that were written for this row. + counts[row] = ggml_nelements(tensor); + } +} + +static void copy_tensor_async_candidates( + const std::map & tensor_map, + llama_token * dst, + size_t stride, + std::vector & counts, + const std::map & seq_to_row, + ggml_backend_sched_t sched) { + if (dst == nullptr) { + return; + } + + for (const auto & [seq_id, tensor] : tensor_map) { + auto it = seq_to_row.find(seq_id); + if (it == seq_to_row.end()) { + continue; + } + + const uint32_t row = it->second; + GGML_ASSERT(row < counts.size()); + + GGML_ASSERT(ggml_is_contiguous(tensor) && "candidates tensor must be contiguous for async copy"); + + ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, tensor); + llama_token * row_ptr = dst + (size_t) row * stride; + ggml_backend_tensor_get_async(backend, tensor, row_ptr, 0, ggml_nbytes(tensor)); + + // Update the actual number of candidates that were written. + counts[row] = ggml_nelements(tensor); + } +} + int llama_context::decode(const llama_batch & batch_inp) { GGML_ASSERT((!batch_inp.token && batch_inp.embd) || (batch_inp.token && !batch_inp.embd)); // NOLINT @@ -1053,8 +1390,12 @@ int llama_context::decode(const llama_batch & batch_inp) { // when computing embeddings, all tokens are output const bool output_all = cparams.embeddings; + const bool has_samplers = !sampling.samplers.empty(); - if (!balloc->init(batch_inp, vocab, memory.get(), n_embd, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, output_all)) { + if (!balloc->init(batch_inp, vocab, memory.get(), n_embd, + cparams.kv_unified ? 
LLAMA_MAX_SEQ : cparams.n_seq_max,
+                      output_all,
+                      has_samplers)) {
         LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
         return -1;
     }
@@ -1135,7 +1476,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
     }
 
     // reserve output buffer
-    if (output_reserve(n_outputs_all) < n_outputs_all) {
+    if (output_reserve(n_outputs_all, balloc->get_batch()) < n_outputs_all) {
         LLAMA_LOG_ERROR("%s: could not reserve space for batch with %d outputs\n", __func__, n_outputs_all);
         return -2;
     };
@@ -1200,6 +1541,28 @@
         //    ggml_graph_dump_dot(gf, NULL, "llama.dot");
         //}
 
+        // This flag indicates whether a backend sampler has actually sampled a specific
+        // token, or if it has produced probabilities. If true, we can skip the normal copying of logits and embeddings.
+        const bool has_sampled = !res->t_sampled.empty() || !res->t_sampled_probs.empty() || !res->t_sampled_logits.empty();
+
+        if (has_samplers && has_sampled) {
+            const auto seq_to_output_row = build_seq_to_output_row(ubatch, n_outputs_prev);
+            const auto stride = n_vocab;
+
+            // async copy the sampled tokens from the backend to the host.
+            copy_tensor_async_ints(res->t_sampled, sampling.sampled, sampling.sampled_size, seq_to_output_row, sched.get());
+
+            // async copy the sampled logits from the backend to the host.
+            copy_tensor_async_floats(res->t_sampled_logits, sampling.logits, stride, sampling.logits_count, seq_to_output_row, sched.get());
+
+            // async copy the sampled probabilities from the backend to the host.
+            copy_tensor_async_floats(res->t_sampled_probs, sampling.probs, stride, sampling.probs_count, seq_to_output_row, sched.get());
+
+            // async copy the candidate token ids from the backend to the host.
+            // These are needed by CPU samplers to map probability/logit indices to vocab token ids.
+            copy_tensor_async_candidates(res->t_candidates, sampling.candidates, stride, sampling.candidates_count, seq_to_output_row, sched.get());
+        }
+
         auto * t_logits = res->get_logits();
         auto * t_embd   = cparams.embeddings ? res->get_embd() : nullptr;
 
@@ -1208,7 +1571,10 @@
         }
 
         // extract logits
-        if (t_logits && n_outputs > 0) {
+        // For multi-sequence batches that mix backend samplers and CPU samplers
+        // this is currently inefficient as we copy all logits even for the
+        // backend sampled tokens.
+        if (logits && t_logits && n_outputs > 0) {
             ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits);
             GGML_ASSERT(backend_res != nullptr);
             GGML_ASSERT(logits != nullptr);
@@ -1223,7 +1589,7 @@
         }
 
         // extract embeddings
-        if (t_embd && n_outputs > 0) {
+        if (embd && t_embd && n_outputs > 0) {
             ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd);
             GGML_ASSERT(backend_embd != nullptr);
 
@@ -1340,7 +1706,7 @@
 // output
 //
 
-uint32_t llama_context::output_reserve(int32_t n_outputs) {
+uint32_t llama_context::output_reserve(int32_t n_outputs, const llama_batch & batch) {
     const auto & hparams = model.hparams;
     const auto & vocab   = model.vocab;
 
@@ -1359,8 +1725,51 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
         has_embd = true;
     }
 
-    logits_size = has_logits ? n_vocab*n_outputs_max : 0;
-    embd_size   = has_embd   ?  n_embd*n_outputs_max : 0;
+    // Check which sampling modes are needed by sequences in the current batch.
+ bool batch_has_sampling = false; + bool batch_needs_cpu_logits = false; + + if (batch.logits) { + for (int32_t i = 0; i < batch.n_tokens; i++) { + if (!batch.logits[i]) { + continue; + } + for (int32_t j = 0; j < batch.n_seq_id[i]; j++) { + llama_seq_id seq_id = batch.seq_id[i][j]; + if (sampling.samplers.find(seq_id) != sampling.samplers.end()) { + batch_has_sampling = true; + } else { + batch_needs_cpu_logits = true; + } + } + } + } else { + // When batch.logits is nullptr (when loading state with a dummy batch), + // allocate CPU logits. + batch_needs_cpu_logits = true; + } + + size_t backend_float_count = 0; + size_t backend_token_count = 0; + + // Allocate CPU logits buffer only if needed by sequences in this batch + logits_size = (has_logits && batch_needs_cpu_logits) ? n_vocab*n_outputs_max : 0; + embd_size = has_embd ? n_embd*n_outputs_max : 0; + + if (!batch_has_sampling) { + sampling.logits_size = 0; + sampling.probs_size = 0; + sampling.sampled_size = 0; + sampling.candidates_size = 0; + } else { + sampling.logits_size = n_vocab*n_outputs_max; + sampling.probs_size = n_vocab*n_outputs_max; + sampling.sampled_size = n_outputs_max; + sampling.candidates_size = n_vocab*n_outputs_max; + + backend_float_count = sampling.logits_size + sampling.probs_size; + backend_token_count = sampling.sampled_size + sampling.candidates_size; + } if (output_ids.empty()) { // init, never resized afterwards @@ -1368,7 +1777,8 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { } const size_t prev_size = buf_output ? ggml_backend_buffer_get_size(buf_output.get()) : 0; - const size_t new_size = (logits_size + embd_size) * sizeof(float); + const size_t new_size = (logits_size + embd_size + backend_float_count) * sizeof(float) + + backend_token_count * sizeof(llama_token); // alloc only when more than the current capacity is required // TODO: also consider shrinking the buffer @@ -1376,7 +1786,7 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { if (buf_output) { #ifndef NDEBUG // This doesn't happen often, but may be annoying in some cases (like the HellaSwag benchmark) - LLAMA_LOG_INFO("%s: reallocating output buffer from size %.02f MiB to %.02f MiB\n", __func__, prev_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0); + LLAMA_LOG_DEBUG("%s: reallocating output buffer from size %.02f MiB to %.02f MiB\n", __func__, prev_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0); #endif synchronize(); buf_output = nullptr; @@ -1400,8 +1810,58 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { float * output_base = (float *) ggml_backend_buffer_get_base(buf_output.get()); - logits = has_logits ? output_base : nullptr; - embd = has_embd ? output_base + logits_size : nullptr; + logits = nullptr; + embd = nullptr; + + // reset sampling pointers. + sampling.logits = nullptr; + sampling.probs = nullptr; + sampling.sampled = nullptr; + sampling.candidates = nullptr; + + size_t offset = 0; + uint8_t * base = (uint8_t *) output_base; + + logits = (has_logits && batch_needs_cpu_logits) ? output_base : nullptr; + offset += logits_size * sizeof(float); + + embd = has_embd ? 
(float *) (base + offset) : nullptr; + offset += embd_size * sizeof(float); + + if (batch_has_sampling) { + sampling.logits = (float *) (base + offset); + offset += sampling.logits_size * sizeof(float); + + sampling.probs = (float *) (base + offset); + offset += sampling.probs_size * sizeof(float); + + sampling.sampled = (llama_token *) (base + offset); + offset += sampling.sampled_size * sizeof(llama_token); + + sampling.candidates = (llama_token *) (base + offset); + offset += sampling.candidates_size * sizeof(llama_token); + + // The count vectors keep track of the actual number of logits/probs/candidates + // copied from the backend for each output row. + const size_t n_rows = (size_t) n_outputs_max; + if (sampling.outputs_capacity < n_rows) { + // The output size has increased, so resize and reset the count vectors. + sampling.outputs_capacity = n_rows; + + sampling.logits_count.assign(n_rows, 0); + sampling.probs_count.assign(n_rows, 0); + sampling.candidates_count.assign(n_rows, 0); + } else { + // The output size has not increased so just reset the counts to zero. + std::fill(sampling.logits_count.begin(), sampling.logits_count.end(), 0); + std::fill(sampling.probs_count.begin(), sampling.probs_count.end(), 0); + std::fill(sampling.candidates_count.begin(), sampling.candidates_count.end(), 0); + } + + if (sampling.sampled) { + std::fill_n(sampling.sampled, sampling.sampled_size, LLAMA_TOKEN_NULL); + } + } // set all ids as invalid (negative) std::fill(output_ids.begin(), output_ids.end(), -1); @@ -1430,6 +1890,40 @@ void llama_context::output_reorder() { std::swap(embd[i0*n_embd + k], embd[i1*n_embd + k]); } } + + if (sampling.logits && sampling.logits_size > 0) { + for (uint64_t k = 0; k < n_vocab; ++k) { + std::swap(sampling.logits[i0*n_vocab + k], sampling.logits[i1*n_vocab + k]); + } + } + + if (sampling.probs && sampling.probs_size > 0) { + for (uint64_t k = 0; k < n_vocab; ++k) { + std::swap(sampling.probs[i0*n_vocab + k], sampling.probs[i1*n_vocab + k]); + } + } + + if (sampling.candidates && sampling.candidates_size > 0) { + for (uint64_t k = 0; k < n_vocab; ++k) { + std::swap(sampling.candidates[i0*n_vocab + k], sampling.candidates[i1*n_vocab + k]); + } + } + + if (sampling.sampled && sampling.sampled_size > 0) { + std::swap(sampling.sampled[i0], sampling.sampled[i1]); + } + + if (!sampling.logits_count.empty()) { + std::swap(sampling.logits_count[i0], sampling.logits_count[i1]); + } + + if (!sampling.probs_count.empty()) { + std::swap(sampling.probs_count[i0], sampling.probs_count[i1]); + } + + if (!sampling.candidates_count.empty()) { + std::swap(sampling.candidates_count[i0], sampling.candidates_count[i1]); + } } output_swaps.clear(); @@ -1476,6 +1970,15 @@ ggml_cgraph * llama_context::graph_reserve( llama_batch_allocr balloc(model.hparams.n_pos_per_embd()); llama_ubatch ubatch = balloc.ubatch_reserve(n_tokens/n_seqs, n_seqs); + // set one output token per sequence in order to activate all backend samplers + std::vector seq_ids(n_seqs); + for (uint32_t i = 0; i < n_seqs; ++i) { + seq_ids[i] = i; + ubatch.n_seq_id[i] = 1; + ubatch.seq_id[i] = &seq_ids[i]; + ubatch.output[i] = true; + } + auto * res = gf_res_reserve.get(); const auto gparams = graph_params(res, ubatch, mctx, LLM_GRAPH_TYPE_DEFAULT); @@ -1506,7 +2009,7 @@ llm_graph_params llama_context::graph_params( llm_graph_result * res, const llama_ubatch & ubatch, const llama_memory_context_i * mctx, - llm_graph_type gtype) const { + llm_graph_type gtype) const { return { /*.arch =*/ model.arch, /*.hparams =*/ 
model.hparams, @@ -1519,6 +2022,7 @@ llm_graph_params llama_context::graph_params( /*.loras =*/ &loras, /*.mctx =*/ mctx, /*.cross =*/ &cross, + /*.samplers =*/ sampling.samplers, /*.n_outputs =*/ n_outputs, /*.cb =*/ graph_get_cb(), /*.res =*/ res, @@ -2006,7 +2510,10 @@ size_t llama_context::state_read_data(llama_io_read_i & io) { auto n_outputs = this->n_outputs; io.read_to(&n_outputs, sizeof(n_outputs)); - if (n_outputs > output_reserve(n_outputs)) { + // Create a dummy batch for state loading. + llama_batch dummy_batch = {}; + dummy_batch.n_tokens = 0; + if (n_outputs > output_reserve(n_outputs, dummy_batch)) { throw std::runtime_error("could not reserve outputs"); } @@ -2248,7 +2755,7 @@ void llama_context::opt_epoch_iter( } // reserve output buffer - if (output_reserve(n_outputs_all) < n_outputs_all) { + if (output_reserve(n_outputs_all, balloc->get_batch()) < n_outputs_all) { LLAMA_LOG_ERROR("%s: could not reserve space for batch with %d outputs\n", __func__, n_outputs_all); GGML_ABORT("TODO: handle this error"); }; @@ -2393,6 +2900,8 @@ llama_context_params llama_context_default_params() { /*.op_offload =*/ true, /*.swa_full =*/ true, /*.kv_unified =*/ false, + /*.sampler =*/ nullptr, + /*.n_sampler =*/ 0, }; return result; @@ -2552,7 +3061,15 @@ float * llama_get_logits(llama_context * ctx) { float * llama_get_logits_ith(llama_context * ctx, int32_t i) { ctx->synchronize(); - return ctx->get_logits_ith(i); + float * res = nullptr; + + res = ctx->get_sampled_logits_ith(i); + + if (!res) { + res = ctx->get_logits_ith(i); + } + + return res; } float * llama_get_embeddings(llama_context * ctx) { @@ -2573,6 +3090,52 @@ float * llama_get_embeddings_seq(llama_context * ctx, llama_seq_id seq_id) { return ctx->get_embeddings_seq(seq_id); } +bool llama_set_sampler(llama_context * ctx, llama_seq_id seq_id, llama_sampler * smpl) { + return ctx->set_sampler(seq_id, smpl); +} + +llama_token llama_get_sampled_token_ith(llama_context * ctx, int32_t i) { + ctx->synchronize(); + + return ctx->get_sampled_token_ith(i); +} + +float * llama_get_sampled_probs_ith(llama_context * ctx, int32_t i) { + ctx->synchronize(); + + return ctx->get_sampled_probs_ith(i); +} + +float * llama_get_sampled_logits_ith(llama_context * ctx, int32_t i) { + ctx->synchronize(); + + return ctx->get_sampled_logits_ith(i); +} + +llama_token * llama_get_sampled_candidates_ith(llama_context * ctx, int32_t i) { + ctx->synchronize(); + + return const_cast(ctx->get_sampled_candidates_ith(i)); +} + +uint32_t llama_get_sampled_candidates_count_ith(llama_context * ctx, int32_t i) { + ctx->synchronize(); + + return static_cast(ctx->get_sampled_candidates_count(i)); +} + +uint32_t llama_get_sampled_logits_count_ith(llama_context * ctx, int32_t i) { + ctx->synchronize(); + + return static_cast(ctx->get_sampled_logits_count(i)); +} + +uint32_t llama_get_sampled_probs_count_ith(llama_context * ctx, int32_t i) { + ctx->synchronize(); + + return static_cast(ctx->get_sampled_probs_count(i)); +} + // llama adapter API int32_t llama_set_adapter_lora( diff --git a/src/llama-context.h b/src/llama-context.h index c31101330e..602a55e4ce 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -70,6 +70,18 @@ struct llama_context { float * get_embeddings_ith(int32_t i); float * get_embeddings_seq(llama_seq_id seq_id); + llama_token * get_sampled_tokens(); + llama_token get_sampled_token_ith(int32_t idx); + + float * get_sampled_logits_ith(int32_t idx); + size_t get_sampled_logits_count(int32_t idx); + + float * get_sampled_probs_ith(int32_t 
idx); + size_t get_sampled_probs_count(int32_t idx); + + const llama_token * get_sampled_candidates_ith(int32_t idx); + size_t get_sampled_candidates_count(int32_t idx); + void attach_threadpool( ggml_threadpool_t threadpool, ggml_threadpool_t threadpool_batch); @@ -192,9 +204,10 @@ private: // Make sure enough space is available for outputs. // Returns max number of outputs for which space was reserved. - uint32_t output_reserve(int32_t n_outputs); + uint32_t output_reserve(int32_t n_outputs, const llama_batch & batch); void output_reorder(); + int64_t resolve_output_row(int32_t i) const; // // graph @@ -213,6 +226,8 @@ public: ggml_cgraph * graph_reserve( uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx, bool split_only = false, size_t * sizes = nullptr); + bool set_sampler(llama_seq_id seq_id, llama_sampler * sampler); + private: llm_graph_params graph_params( llm_graph_result * res, @@ -247,6 +262,31 @@ private: size_t logits_size = 0; // capacity (of floats) for logits float * logits = nullptr; + struct sampling_info { + std::map samplers; + + float * logits = nullptr; + size_t logits_size = 0; + + llama_token * sampled = nullptr; + size_t sampled_size = 0; + + float * probs = nullptr; + size_t probs_size = 0; + + llama_token * candidates = nullptr; + size_t candidates_size = 0; + + size_t outputs_capacity = 0; + std::vector logits_count; + std::vector probs_count; + std::vector candidates_count; + + std::vector token_ids_full_vocab; + }; + + sampling_info sampling; + // embeddings output (2-dimensional array: [n_outputs][n_embd]) // populated only when pooling_type == LLAMA_POOLING_TYPE_NONE size_t embd_size = 0; // capacity (of floats) for embeddings diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 1d0d7197e1..ed757c27da 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -12,6 +12,7 @@ #include #include #include +#include void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) { if (ubatch->token) { @@ -521,6 +522,43 @@ bool llm_graph_input_mem_hybrid::can_reuse(const llm_graph_params & params) { return res; } +void llm_graph_input_sampling::set_input(const llama_ubatch * ubatch) { + // set the inputs only for the active samplers in the current ubatch + std::unordered_set active_samplers; + for (uint32_t i = 0; i < ubatch->n_tokens; i++) { + if (ubatch->output[i]) { + llama_seq_id seq_id = ubatch->seq_id[i][0]; + active_samplers.insert(seq_id); + } + } + + for (auto seq_id : active_samplers) { + if (samplers.find(seq_id) == samplers.end()) { + continue; + } + + auto & sampler = samplers[seq_id]; + + if (sampler->iface->backend_set_input) { + sampler->iface->backend_set_input(sampler); + } + } +} + +bool llm_graph_input_sampling::can_reuse(const llm_graph_params & params) { + if (samplers.size() != params.samplers.size()) { + return false; + } + + for (const auto & [seq_id, sampler] : params.samplers) { + if (samplers[seq_id] != sampler) { + return false; + } + } + + return true; +} + // // llm_graph_result // @@ -541,6 +579,10 @@ void llm_graph_result::reset() { t_logits = nullptr; t_embd = nullptr; t_embd_pooled = nullptr; + t_sampled.clear(); + t_sampled_probs.clear(); + t_sampled_logits.clear(); + t_candidates.clear(); params = {}; @@ -565,6 +607,38 @@ void llm_graph_result::set_inputs(const llama_ubatch * ubatch) { } } +void llm_graph_result::set_outputs() { + if (t_logits != nullptr) { + ggml_set_output(t_logits); + } + if (t_embd != nullptr) { + ggml_set_output(t_embd); + } + if (t_embd_pooled != 
nullptr) { + ggml_set_output(t_embd_pooled); + } + for (auto & [seq_id, t] : t_sampled) { + if (t != nullptr) { + ggml_set_output(t); + } + } + for (auto & [seq_id, t] : t_sampled_probs) { + if (t != nullptr) { + ggml_set_output(t); + } + } + for (auto & [seq_id, t] : t_sampled_logits) { + if (t != nullptr) { + ggml_set_output(t); + } + } + for (auto & [seq_id, t] : t_candidates) { + if (t != nullptr) { + ggml_set_output(t); + } + } +} + bool llm_graph_result::can_reuse(const llm_graph_params & params) { if (!this->params.allow_reuse(params)) { if (debug > 1) { @@ -646,6 +720,7 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) : loras (params.loras), mctx (params.mctx), cross (params.cross), + samplers (params.samplers), cb_func (params.cb), res (params.res), ctx0 (res->get_ctx()), @@ -1834,8 +1909,10 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream); ggml_set_input(inp->self_kq_mask); + ggml_set_name(inp->self_kq_mask, "self_kq_mask"); inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask; + ggml_set_name(inp->self_kq_mask_cnv, "self_kq_mask_cnv"); } { @@ -1848,8 +1925,10 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream); ggml_set_input(inp->self_kq_mask_swa); + ggml_set_name(inp->self_kq_mask_swa, "self_kq_mask_swa"); inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa; + ggml_set_name(inp->self_kq_mask_swa_cnv, "self_kq_mask_swa_cnv"); } return (llm_graph_input_attn_kv_iswa *) res->add_input(std::move(inp)); @@ -2086,6 +2165,86 @@ void llm_graph_context::build_pooling( ggml_build_forward_expand(gf, cur); } +void llm_graph_context::build_sampling() const { + if (samplers.empty() || !res->t_logits) { + return; + } + + auto inp_sampling = std::make_unique(samplers); + res->add_input(std::move(inp_sampling)); + + std::map seq_to_logit_row; + int32_t logit_row_idx = 0; + + for (uint32_t i = 0; i < ubatch.n_tokens; i++) { + if (ubatch.output[i]) { + llama_seq_id seq_id = ubatch.seq_id[i][0]; + seq_to_logit_row[seq_id] = logit_row_idx; + logit_row_idx++; + } + } + + // res->t_logits will contain logits for all tokens that want the logits calculated (logits=1 or output=1) + GGML_ASSERT(res->t_logits != nullptr && "missing t_logits tensor"); + + // add a dummy row of logits + // this trick makes the graph static, regardless of which samplers are activated + // this is important in order to minimize graph reallocations + ggml_tensor * logits_t = ggml_pad(ctx0, res->t_logits, 0, 1, 0, 0); + + for (const auto & [seq_id, sampler] : samplers) { + const auto it = seq_to_logit_row.find(seq_id); + + // inactive samplers alawys work on the first row + const auto row_idx = seq_to_logit_row.find(seq_id) != seq_to_logit_row.end() ? 
it->second : 0; + + ggml_tensor * logits_seq = ggml_view_1d(ctx0, logits_t, logits_t->ne[0], row_idx * logits_t->nb[1]); + ggml_format_name(logits_seq, "logits_seq_%d", seq_id); + + struct llama_sampler_data data = { + /*.logits =*/ logits_seq, + /*.probs =*/ nullptr, + /*.sampled =*/ nullptr, + /*.candidates =*/ nullptr, + }; + + assert(sampler->iface->backend_apply); + sampler->iface->backend_apply(sampler, ctx0, gf, &data); + + if (data.sampled != nullptr) { + res->t_sampled[seq_id] = data.sampled; + ggml_build_forward_expand(gf, data.sampled); + } + + if (data.probs != nullptr) { + res->t_sampled_probs[seq_id] = data.probs; + ggml_build_forward_expand(gf, data.probs); + } + + if (data.logits != nullptr) { + res->t_sampled_logits[seq_id] = data.logits; + ggml_build_forward_expand(gf, data.logits); + } + + if (data.candidates != nullptr) { + res->t_candidates[seq_id] = data.candidates; + ggml_build_forward_expand(gf, data.candidates); + } + } + + // TODO: Call llama_sampler_accept_ggml after all samplers have been applied. + /* + for (const auto & [seq_id, sampler] : samplers) { + if (auto it = res->t_sampled.find(seq_id); it != res->t_sampled.end()) { + ggml_tensor * selected_token = it->second; + if (selected_token != nullptr) { + llama_sampler_accept_ggml(sampler, ctx0, gf, selected_token); + } + } + } + */ +} + int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) { // TODO move to hparams if a T5 variant appears that uses a different value const int64_t max_distance = 128; diff --git a/src/llama-graph.h b/src/llama-graph.h index 81ac329cc3..503ffd695a 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -10,6 +10,7 @@ #include #include #include +#include struct ggml_cgraph; struct ggml_context; @@ -396,6 +397,18 @@ public: const llama_memory_hybrid_context * mctx; }; +class llm_graph_input_sampling : public llm_graph_input_i { +public: + llm_graph_input_sampling(std::map samplers) : + samplers(std::move(samplers)) { } + virtual ~llm_graph_input_sampling() = default; + + void set_input(const llama_ubatch * ubatch) override; + bool can_reuse(const llm_graph_params & params) override; + + std::map samplers; +}; + // // llm_graph_result // @@ -429,6 +442,23 @@ struct llm_graph_params { const llama_memory_context_i * mctx; const llama_cross * cross; + std::map samplers; + + static bool samplers_equal( + const std::map & lhs, + const std::map & rhs) { + if (lhs.size() != rhs.size()) { + return false; + } + for (const auto & [seq_id, sampler] : lhs) { + auto it = rhs.find(seq_id); + if (it == rhs.end() || it->second != sampler) { + return false; + } + } + return true; + } + uint32_t n_outputs; llm_graph_cb cb; @@ -468,15 +498,36 @@ struct llm_graph_params { return false; } + if (n_outputs != other.n_outputs) { + return false; + } + + if (!samplers_equal(samplers, other.samplers)) { + return false; + } + + if (samplers.size() > 0) { + if (!ubatch.data || !other.ubatch.data) { + return false; + } + + // check that the outputs are the same for all samplers + for (uint32_t i = 0; i < ubatch.n_tokens; ++i) { + if (ubatch.output[i] != other.ubatch.output[i] || + ubatch.seq_id[i][0] != other.ubatch.seq_id[i][0]) { + return false; + } + } + } + return cparams.embeddings == other.cparams.embeddings && cparams.causal_attn == other.cparams.causal_attn && - arch == other.arch && - gtype == other.gtype && - cvec == other.cvec && - loras == other.loras && - cross == other.cross && - n_outputs == other.n_outputs; + arch == other.arch && + gtype == 
other.gtype && + cvec == other.cvec && + loras == other.loras && + cross == other.cross; } }; @@ -499,6 +550,7 @@ public: void reset(); void set_inputs(const llama_ubatch * ubatch); + void set_outputs(); // try to update the existing graph result using the new graph parameters in order to reuse it // this can only be done if we determine that the resulting graph using the new graph parameters @@ -517,6 +569,11 @@ public: ggml_tensor * t_embd = nullptr; ggml_tensor * t_embd_pooled = nullptr; + std::map t_sampled_logits; + std::map t_candidates; + std::map t_sampled; + std::map t_sampled_probs; + std::vector inputs; ggml_context_ptr ctx_compute; @@ -592,6 +649,8 @@ struct llm_graph_context { const llama_memory_context_i * mctx; const llama_cross * cross; + std::map samplers; + const llm_graph_cb & cb_func; llm_graph_result * res; @@ -832,6 +891,12 @@ struct llm_graph_context { ggml_tensor * cls_out, ggml_tensor * cls_out_b) const; + // + // sampling (backend sampling) + // + + void build_sampling() const; + // // dense (out) // diff --git a/src/llama-model.cpp b/src/llama-model.cpp index c9a3c5dfa2..efe77d7324 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -7642,12 +7642,17 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { // add on pooling layer llm->build_pooling(cls, cls_b, cls_out, cls_out_b); + // add backend sampling layers (if any) + llm->build_sampling(); + // if the gguf model was converted with --sentence-transformers-dense-modules // there will be two additional dense projection layers // dense linear projections are applied after pooling // TODO: move reranking logic here and generalize llm->build_dense_out(dense_2_out_layers, dense_3_out_layers); + llm->res->set_outputs(); + return llm->res->get_gf(); } diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index 3f4a729bc3..15dafcf102 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -4,6 +4,8 @@ #include "llama-vocab.h" #include "llama-grammar.h" +#include "ggml-cpp.h" + #include #include #include @@ -346,7 +348,9 @@ static uint32_t get_rng_seed(uint32_t seed) { // llama_sampler API -struct llama_sampler * llama_sampler_init(const struct llama_sampler_i * iface, llama_sampler_context_t ctx) { +struct llama_sampler * llama_sampler_init( + struct llama_sampler_i * iface, + llama_sampler_context_t ctx) { return new llama_sampler { /* .iface = */ iface, /* .ctx = */ ctx, @@ -362,23 +366,39 @@ const char * llama_sampler_name(const struct llama_sampler * smpl) { } void llama_sampler_accept(struct llama_sampler * smpl, llama_token token) { + if (!smpl) { + return; + } + if (smpl->iface->accept) { smpl->iface->accept(smpl, token); } } void llama_sampler_apply(struct llama_sampler * smpl, struct llama_token_data_array * cur_p) { + if (!smpl) { + return; + } + GGML_ASSERT(smpl->iface->apply); smpl->iface->apply(smpl, cur_p); } void llama_sampler_reset(struct llama_sampler * smpl) { + if (!smpl) { + return; + } + if (smpl->iface->reset) { smpl->iface->reset(smpl); } } struct llama_sampler * llama_sampler_clone(const struct llama_sampler * smpl) { + if (!smpl) { + return nullptr; + } + if (smpl->iface->clone) { return smpl->iface->clone(smpl); } @@ -406,7 +426,16 @@ void llama_sampler_free(struct llama_sampler * smpl) { } llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_context * ctx, int32_t idx) { - const auto * logits = llama_get_logits_ith(ctx, idx); + const llama_token sampled_token = llama_get_sampled_token_ith (ctx, idx); + const float 
* sampled_probs = llama_get_sampled_probs_ith (ctx, idx); + const float * sampled_logits = llama_get_sampled_logits_ith (ctx, idx); + const llama_token * sampled_ids = llama_get_sampled_candidates_ith(ctx, idx); + + // If a backend sampler has already sampled a token, return it. + if (sampled_token != LLAMA_TOKEN_NULL) { + LLAMA_LOG_DEBUG("%s: Backend sampler selected token for idx %d. Skipping CPU samplers\n", __func__, idx); + return sampled_token; + } const llama_model * model = llama_get_model(ctx); const llama_vocab * vocab = llama_model_get_vocab(model); @@ -415,9 +444,26 @@ llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_conte // TODO: do not allocate each time std::vector cur; - cur.reserve(n_vocab); - for (llama_token token_id = 0; token_id < n_vocab; token_id++) { - cur.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f}); + + if (sampled_probs) { + const uint32_t sampled_probs_count = llama_get_sampled_probs_count_ith(ctx, idx); + cur.resize(sampled_probs_count); + for (uint32_t i = 0; i < sampled_probs_count; ++i) { + cur[i] = llama_token_data{sampled_ids[i], sampled_logits[i], sampled_probs[i]}; + } + } else if (sampled_logits) { + const uint32_t sampled_logits_count = llama_get_sampled_logits_count_ith(ctx, idx); + cur.resize(sampled_logits_count); + for (llama_token i = 0; i < (int)sampled_logits_count; i++) { + cur[i] = llama_token_data{sampled_ids[i], sampled_logits[i], 0.0f}; + } + } else { + const auto * logits = llama_get_logits_ith(ctx, idx); + GGML_ASSERT(logits != nullptr); + cur.resize(n_vocab); + for (llama_token token_id = 0; token_id < n_vocab; token_id++) { + cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f}; + } } llama_token_data_array cur_p = { @@ -438,6 +484,202 @@ llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_conte return token; } +// empty sampler + +struct llama_sampler_empty { + const char * name; +}; + +static struct llama_sampler * llama_sampler_init_empty(const char * name); + +static const char * llama_sampler_empty_name(const struct llama_sampler * smpl) { + auto * ctx = (llama_sampler_empty *) smpl->ctx; + return ctx->name; +} + +static void llama_sampler_empty_accept(struct llama_sampler * smpl, llama_token token) { + GGML_UNUSED(smpl); + GGML_UNUSED(token); +} + +static void llama_sampler_empty_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { + GGML_UNUSED(smpl); + GGML_UNUSED(cur_p); +} + +static void llama_sampler_empty_reset(struct llama_sampler * smpl) { + GGML_UNUSED(smpl); +} + +static struct llama_sampler * llama_sampler_empty_clone(const struct llama_sampler * smpl) { + auto * ctx = (llama_sampler_empty *) smpl->ctx; + return llama_sampler_init_empty(ctx->name); +} + +static void llama_sampler_empty_free(struct llama_sampler * smpl) { + delete (llama_sampler_empty *) smpl->ctx; +} + +static bool llama_sampler_empty_backend_init( + struct llama_sampler * smpl, + ggml_backend_buffer_type_t buft) { + GGML_UNUSED(smpl); + GGML_UNUSED(buft); + + return true; +} + +static void llama_sampler_empty_backend_accept( + struct llama_sampler * smpl, + ggml_context * ctx, + ggml_cgraph * gf, + struct ggml_tensor * selected_token) { + GGML_UNUSED(smpl); + GGML_UNUSED(ctx); + GGML_UNUSED(gf); + GGML_UNUSED(selected_token); +} + +static void llama_sampler_empty_backend_apply( + struct llama_sampler * smpl, + struct ggml_context * ctx, + struct ggml_cgraph * gf, + struct llama_sampler_data * data) { + GGML_UNUSED(smpl); + GGML_UNUSED(ctx); + 
GGML_UNUSED(gf); + GGML_UNUSED(data); +} + +static void llama_sampler_empty_backend_set_input(struct llama_sampler * smpl) { + GGML_UNUSED(smpl); +} + +static struct llama_sampler_i llama_sampler_empty_i = { + /* .name = */ llama_sampler_empty_name, + /* .accept = */ llama_sampler_empty_accept, + /* .apply = */ llama_sampler_empty_apply, + /* .reset = */ llama_sampler_empty_reset, + /* .clone = */ llama_sampler_empty_clone, + /* .free = */ llama_sampler_empty_free, + /* .backend_init = */ llama_sampler_empty_backend_init, + /* .backend_accept = */ llama_sampler_empty_backend_accept, + /* .backend_apply = */ llama_sampler_empty_backend_apply, + /* .backend_set_input = */ llama_sampler_empty_backend_set_input, +}; + +struct llama_sampler * llama_sampler_init_empty(const char * name) { + return llama_sampler_init( + /* .iface = */ &llama_sampler_empty_i, + /* .ctx = */ new llama_sampler_empty { + /* .name = */ name, + } + ); +} + +// common backend sampler functionality +// +// +name : means that the sampler is support and will run on the backend +// -name : means that a ggml operator is not supported by the backend +// +struct llama_sampler_backend { + llama_sampler_backend(const char * name) : name(name), name_ext(name), is_init(false), support(false) {} + + const char * get_name() { + if (!is_init) { + return name.c_str(); + } + + if (support) { + name_ext = "+" + name; + } else { + name_ext = "-" + name; + } + + return name_ext.c_str(); + } + + void init(bool support) { + GGML_ASSERT(this->is_init == false); + + this->is_init = true; + this->support = support; + } + +private: + std::string name; + std::string name_ext; + + bool is_init; + bool support; +}; + +// check if all ggml ops used by the sampler are supported by the backend +static bool llama_sampler_backend_support( + llama_sampler * smpl, + ggml_backend_buffer_type_t buft) { + auto * device = ggml_backend_buft_get_device(buft); + if (!device) { + // CPU backend always supported + return true; + } + + ggml_init_params params = { + /*.mem_size =*/ 128*ggml_tensor_overhead() + ggml_graph_overhead(), + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, + }; + + ggml_context_ptr ctx_ptr { ggml_init(params) }; + if (!ctx_ptr) { + throw std::runtime_error(format("failed to create ggml context")); + } + + ggml_context * ctx = ctx_ptr.get(); + + const int64_t n = 1024*1024; + + llama_sampler_data data = { + /*.logits = */ ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n), + /*.probs = */ nullptr, + /*.sampled = */ nullptr, + /*.candidates = */ ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n), + }; + + ggml_cgraph * gf = ggml_new_graph(ctx); + + smpl->iface->backend_apply(smpl, ctx, gf, &data); + + if (data.logits) { + ggml_build_forward_expand(gf, data.logits); + } + + if (data.probs) { + ggml_build_forward_expand(gf, data.probs); + } + + if (data.sampled) { + ggml_build_forward_expand(gf, data.sampled); + } + + if (data.candidates) { + ggml_build_forward_expand(gf, data.candidates); + } + + for (int i = 0; i < ggml_graph_n_nodes(gf); i++) { + struct ggml_tensor * op = ggml_graph_node(gf, i); + + if (!ggml_backend_dev_supports_op(device, op)) { + LLAMA_LOG_WARN("%s: device '%s' does not have support for op %s needed for sampler '%s'\n", + __func__, ggml_backend_dev_name(device), ggml_op_name(op->op), smpl->iface->name(smpl)); + + return false; + } + } + + return true; +} + // sampler chain static const char * llama_sampler_chain_name(const struct llama_sampler * /*smpl*/) { @@ -449,8 +691,8 @@ static void llama_sampler_chain_accept(struct llama_sampler * 
smpl, llama_token time_meas tm(chain->t_sample_us, chain->params.no_perf); - for (auto * smpl : chain->samplers) { - llama_sampler_accept(smpl, token); + for (auto & smpl : chain->samplers) { + llama_sampler_accept(smpl.ptr, token); } chain->n_sample++; @@ -461,16 +703,28 @@ static void llama_sampler_chain_apply(struct llama_sampler * smpl, llama_token_d time_meas tm(chain->t_sample_us, chain->params.no_perf); - for (auto * smpl : chain->samplers) { - llama_sampler_apply(smpl, cur_p); + bool is_backend = chain->is_init; + + for (auto & smpl : chain->samplers) { + if (is_backend && smpl.is_backend) { + continue; + } + + is_backend = false; + + if (smpl.ptr->iface->apply == nullptr) { + continue; + } + + llama_sampler_apply(smpl.ptr, cur_p); } } static void llama_sampler_chain_reset(struct llama_sampler * smpl) { auto * chain = (llama_sampler_chain *) smpl->ctx; - for (auto * smpl : chain->samplers) { - llama_sampler_reset(smpl); + for (auto & smpl : chain->samplers) { + llama_sampler_reset(smpl.ptr); } } @@ -479,8 +733,8 @@ static struct llama_sampler * llama_sampler_chain_clone(const struct llama_sampl auto * result = llama_sampler_chain_init(chain_src->params); - for (auto * smpl : chain_src->samplers) { - llama_sampler_chain_add(result, llama_sampler_clone(smpl)); + for (const auto & smpl : chain_src->samplers) { + llama_sampler_chain_add(result, llama_sampler_clone(smpl.ptr)); } return result; @@ -489,20 +743,109 @@ static struct llama_sampler * llama_sampler_chain_clone(const struct llama_sampl static void llama_sampler_chain_free(struct llama_sampler * smpl) { auto * chain = (llama_sampler_chain *) smpl->ctx; - for (auto * smpl : chain->samplers) { - llama_sampler_free(smpl); + for (auto & smpl : chain->samplers) { + llama_sampler_free(smpl.ptr); } delete chain; } +static bool llama_sampler_chain_backend_init( + struct llama_sampler * smpl, + ggml_backend_buffer_type_t buft) { + auto * chain = (llama_sampler_chain *) smpl->ctx; + + GGML_ASSERT(chain->is_init == false && "llama_sampler_chain_backend_init() called twice"); + + chain->is_init = true; + + bool res = true; + + for (auto & smpl : chain->samplers) { + bool res_cur = true; + + // to be able to run a sampler on the backend, it has to: + // - have the .backend_init() API implemented + // - return true during .backend_init() + if (smpl.ptr->iface->backend_init) { + if (!smpl.ptr->iface->backend_init(smpl.ptr, buft)) { + res_cur = false; + } + } else { + res_cur = false; + } + + smpl.is_backend = res_cur; + + res = res && res_cur; + } + + return res; +} + +static void llama_sampler_chain_backend_accept( + struct llama_sampler * smpl, + ggml_context * ctx, + ggml_cgraph * gf, + struct ggml_tensor * selected_token) { + auto * chain = (llama_sampler_chain *) smpl->ctx; + + for (auto & smpl : chain->samplers) { + if (!smpl.is_backend) { + break; + } + + if (smpl.ptr->iface->backend_accept) { + smpl.ptr->iface->backend_accept(smpl.ptr, ctx, gf, selected_token); + } + } +} + +static void llama_sampler_chain_backend_apply( + struct llama_sampler * smpl, + struct ggml_context * ctx, + struct ggml_cgraph * gf, + struct llama_sampler_data * data) { + auto * chain = (llama_sampler_chain *) smpl->ctx; + + GGML_ASSERT(chain->is_init && "llama_sampler_chain_backend_init() not called"); + + for (auto & smpl : chain->samplers) { + if (!smpl.is_backend) { + break; + } + + if (smpl.ptr->iface->backend_apply) { + smpl.ptr->iface->backend_apply(smpl.ptr, ctx, gf, data); + } + } +} + +static void llama_sampler_chain_backend_set_input(struct 
llama_sampler * smpl) { + auto * chain = (llama_sampler_chain *) smpl->ctx; + + for (auto & smpl : chain->samplers) { + if (!smpl.is_backend) { + break; + } + + if (smpl.ptr->iface->backend_set_input) { + smpl.ptr->iface->backend_set_input(smpl.ptr); + } + } +} + static struct llama_sampler_i llama_sampler_chain_i = { - /* .name = */ llama_sampler_chain_name, - /* .accept = */ llama_sampler_chain_accept, - /* .apply = */ llama_sampler_chain_apply, - /* .reset = */ llama_sampler_chain_reset, - /* .clone = */ llama_sampler_chain_clone, - /* .free = */ llama_sampler_chain_free, + /* .name = */ llama_sampler_chain_name, + /* .accept = */ llama_sampler_chain_accept, + /* .apply = */ llama_sampler_chain_apply, + /* .reset = */ llama_sampler_chain_reset, + /* .clone = */ llama_sampler_chain_clone, + /* .free = */ llama_sampler_chain_free, + /* .backend_init = */ llama_sampler_chain_backend_init, + /* .backend_accept = */ llama_sampler_chain_backend_accept, + /* .backend_apply = */ llama_sampler_chain_backend_apply, + /* .backend_set_input = */ llama_sampler_chain_backend_set_input, }; struct llama_sampler * llama_sampler_chain_init(struct llama_sampler_chain_params params) { @@ -510,6 +853,7 @@ struct llama_sampler * llama_sampler_chain_init(struct llama_sampler_chain_param /* .iface = */ &llama_sampler_chain_i, /* .ctx = */ new llama_sampler_chain { /* .params = */ params, + /* .is_init = */ false, /* .samplers = */ {}, /* .t_sample_us = */ 0, /* .n_sample = */ 0, @@ -519,17 +863,32 @@ struct llama_sampler * llama_sampler_chain_init(struct llama_sampler_chain_param void llama_sampler_chain_add(struct llama_sampler * chain, struct llama_sampler * smpl) { auto * p = (llama_sampler_chain *) chain->ctx; - p->samplers.push_back(smpl); + p->samplers.push_back({ + /* .is_backend = */ false, + /* .ptr = */ smpl, + }); } -struct llama_sampler * llama_sampler_chain_get(const struct llama_sampler * chain, int32_t i) { +struct llama_sampler * llama_sampler_chain_get(struct llama_sampler * chain, int32_t i) { + if (chain == nullptr) { + return nullptr; + } + + if (chain->iface != &llama_sampler_chain_i) { + return nullptr; + } + + if (i == -1) { + return chain; + } + const auto * p = (const llama_sampler_chain *) chain->ctx; if (i < 0 || (size_t) i >= p->samplers.size()) { return nullptr; } - return p->samplers[i]; + return p->samplers[i].ptr; } struct llama_sampler * llama_sampler_chain_remove(struct llama_sampler * chain, int32_t i) { @@ -539,7 +898,7 @@ struct llama_sampler * llama_sampler_chain_remove(struct llama_sampler * chain, return nullptr; } - auto * result = p->samplers[i]; + auto * result = p->samplers[i].ptr; p->samplers.erase(p->samplers.begin() + i); return result; @@ -557,8 +916,36 @@ int llama_sampler_chain_n(const struct llama_sampler * chain) { // greedy -static const char * llama_sampler_greedy_name(const struct llama_sampler * /*smpl*/) { - return "greedy"; +struct llama_sampler_greedy : public llama_sampler_backend { +}; + +static const char * llama_sampler_greedy_name(const struct llama_sampler * smpl) { + auto * sctx = (llama_sampler_greedy *) smpl->ctx; + return sctx->get_name(); +} + +static void llama_sampler_greedy_reset(struct llama_sampler * smpl) { + auto * ctx = (llama_sampler_greedy *) smpl->ctx; + GGML_UNUSED(ctx); +} + +static struct llama_sampler * llama_sampler_greedy_clone(const struct llama_sampler * smpl) { + const auto * ctx = (const llama_sampler_greedy *) smpl->ctx; + auto * result = llama_sampler_init_greedy(); + + // copy the state + { + auto * result_ctx = 
(llama_sampler_greedy *) result->ctx; + + GGML_UNUSED(ctx); + GGML_UNUSED(result_ctx); + } + + return result; +} + +static void llama_sampler_greedy_free(struct llama_sampler * smpl) { + delete (llama_sampler_greedy *) smpl->ctx; } static void llama_sampler_greedy_apply(struct llama_sampler * /*smpl*/, llama_token_data_array * cur_p) { @@ -570,33 +957,72 @@ static void llama_sampler_greedy_apply(struct llama_sampler * /*smpl*/, llama_to } } +static bool llama_sampler_greedy_backend_init( + struct llama_sampler * smpl, + ggml_backend_buffer_type_t buft) { + auto * sctx = (llama_sampler_greedy *) smpl->ctx; + + const bool res = llama_sampler_backend_support(smpl, buft); + + sctx->init(res); + + return res; +} + +static void llama_sampler_greedy_backend_apply( + struct llama_sampler * smpl, + struct ggml_context * ctx, + struct ggml_cgraph * gf, + struct llama_sampler_data * data) { + GGML_UNUSED(gf); + GGML_UNUSED(smpl); + + struct ggml_tensor * curl = ggml_argmax(ctx, data->logits); + ggml_set_name(curl, "greedy_argmax"); + + data->sampled = curl; +} + static struct llama_sampler_i llama_sampler_greedy_i = { - /* .name = */ llama_sampler_greedy_name, - /* .accept = */ nullptr, - /* .apply = */ llama_sampler_greedy_apply, - /* .reset = */ nullptr, - /* .clone = */ nullptr, - /* .free = */ nullptr, + /* .name = */ llama_sampler_greedy_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_greedy_apply, + /* .reset = */ llama_sampler_greedy_reset, + /* .clone = */ llama_sampler_greedy_clone, + /* .free = */ llama_sampler_greedy_free, + /* .backend_init = */ llama_sampler_greedy_backend_init, + /* .backend_accept = */ nullptr, + /* .backend_apply = */ llama_sampler_greedy_backend_apply, + /* .backend_set_input = */ nullptr, }; struct llama_sampler * llama_sampler_init_greedy() { return llama_sampler_init( /* .iface = */ &llama_sampler_greedy_i, - /* .ctx = */ nullptr + /* .ctx = */ new llama_sampler_greedy { + ("greedy"), + } ); } // dist -struct llama_sampler_dist { +struct llama_sampler_dist : public llama_sampler_backend { const uint32_t seed; uint32_t seed_cur; std::mt19937 rng; + + // backend input + struct ggml_tensor * inp_uniform; + + ggml_context_ptr inp_ctx; + ggml_backend_buffer_ptr inp_buf; }; -static const char * llama_sampler_dist_name(const struct llama_sampler * /*smpl*/) { - return "dist"; +static const char * llama_sampler_dist_name(const struct llama_sampler * smpl) { + auto * sctx = (llama_sampler_dist *) smpl->ctx; + return sctx->get_name(); } static void llama_sampler_dist_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { @@ -671,6 +1097,12 @@ static void llama_sampler_dist_apply(struct llama_sampler * smpl, llama_token_da #endif } +static void llama_sampler_dist_reset(struct llama_sampler * smpl) { + auto * ctx = (llama_sampler_dist *) smpl->ctx; + ctx->seed_cur = get_rng_seed(ctx->seed); + ctx->rng.seed(ctx->seed_cur); +} + static struct llama_sampler * llama_sampler_dist_clone(const struct llama_sampler * smpl) { const auto * ctx = (const llama_sampler_dist *) smpl->ctx; auto * result = llama_sampler_init_dist(ctx->seed); @@ -685,23 +1117,119 @@ static struct llama_sampler * llama_sampler_dist_clone(const struct llama_sample return result; } -static void llama_sampler_dist_reset(struct llama_sampler * smpl) { - auto * ctx = (llama_sampler_dist *) smpl->ctx; - ctx->seed_cur = get_rng_seed(ctx->seed); - ctx->rng.seed(ctx->seed_cur); -} - static void llama_sampler_dist_free(struct llama_sampler * smpl) { delete (llama_sampler_dist *) smpl->ctx; 
} +static bool llama_sampler_dist_backend_init( + struct llama_sampler * smpl, + ggml_backend_buffer_type_t buft) { + auto * sctx = (llama_sampler_dist *) smpl->ctx; + + // allocate inputs + { + ggml_init_params params = { + /*.mem_size =*/ ggml_tensor_overhead(), + /*.mem_buffer =*/ nullptr, + /*.no_alloc =*/ true, + }; + + sctx->inp_ctx.reset(ggml_init(params)); + + // Create the uniform random scalar input tensor. This will be set by + // llama_sampler_dist_backend_set_input after this graph is built. + sctx->inp_uniform = ggml_new_tensor_1d(sctx->inp_ctx.get(), GGML_TYPE_F32, 1); + ggml_set_name (sctx->inp_uniform, "uniform"); + ggml_set_input(sctx->inp_uniform); + + // Allocate all tensors from our context to the backend + sctx->inp_buf.reset(ggml_backend_alloc_ctx_tensors_from_buft(sctx->inp_ctx.get(), buft)); + } + + const bool res = llama_sampler_backend_support(smpl, buft); + + sctx->init(res); + + if (!res) { + sctx->inp_ctx.reset(nullptr); + sctx->inp_buf.reset(nullptr); + } + + return res; +} + +static void llama_sampler_dist_backend_apply( + struct llama_sampler * smpl, + struct ggml_context * ctx, + struct ggml_cgraph * gf, + struct llama_sampler_data * data) { + GGML_UNUSED(gf); + auto * sctx = (llama_sampler_dist *) smpl->ctx; + + struct ggml_tensor * probs = ggml_soft_max(ctx, data->logits); + ggml_set_name(probs, "dist_probs"); + + struct ggml_tensor * cumsum = ggml_cumsum(ctx, probs); + ggml_set_name(cumsum, "dist_cumsum"); + + // The uniform tensor has a random value and we subtract this tensor with + // the cumsum tensor (the uniform tensor will be broadcasted by ggml_sub). + // Recall that each entry in cumsum is the cumulative probability up to that + // index so values stay negative while the cumulative total is below the + // random value, and become zero/positive once the threshold is crossed. + struct ggml_tensor * diff = ggml_sub(ctx, cumsum, sctx->inp_uniform); + ggml_set_name(diff, "dist_cumsum"); + + // The ggml_step function produces a tensor where entries are 1 if the + // corresponding entry in diff is > 0, and 0 otherwise. So all values up to + // the index where the cumulative probability exceeds the random value are 0, + // and all entries after that are 1. + struct ggml_tensor * mask = ggml_step(ctx, diff); + ggml_set_name(mask, "dist_mask"); + + // Taking the sum of the mask gives us the sum of elements after the threshold + // we are interested in. + struct ggml_tensor * idxf = ggml_sum(ctx, mask); + ggml_set_name(idxf, "dist_index_f32"); + + // Use ggml_scale_bias to scale the index value by -1 and then add the size + // of the mask to that value so we get the correct index ((-1 * idxf) + n). + struct ggml_tensor * idx = ggml_cast(ctx, ggml_scale_bias(ctx, idxf, -1.0f, mask->ne[0]), GGML_TYPE_I32); + ggml_set_name(idx, "dist_index_i32"); + + // Map back to original vocab ids if a candidates tensor is available. 
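+    // Earlier backend samplers (e.g. top-k/top-p) may have reduced the logits to a
+    // subset of the vocabulary, in which case idx indexes into that subset rather
+    // than the full vocab. data->candidates holds the original token id for each
+    // position, so a ggml_get_rows lookup translates the position back to a vocab id.
+    // For example, if candidates = [42, 7, 1905] and idx = 2, the sampled token is 1905.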
+ struct ggml_tensor * sampled_token = idx; + if (data->candidates != nullptr) { + struct ggml_tensor * candidates = ggml_reshape_2d(ctx, data->candidates, 1, ggml_nelements(data->candidates)); + + sampled_token = ggml_get_rows(ctx, candidates, idx); + ggml_set_name(sampled_token, "dist_sampled_token"); + } + + data->sampled = sampled_token; + data->probs = probs; +} + +static void llama_sampler_dist_backend_set_input(struct llama_sampler * smpl) { + auto * sctx = (llama_sampler_dist *) smpl->ctx; + GGML_ASSERT(sctx->inp_uniform != nullptr); + + std::uniform_real_distribution dist(0.0f, 1.0f); + const float rnd = dist(sctx->rng); + ggml_backend_tensor_set(sctx->inp_uniform, &rnd, 0, sizeof(float)); +} + static struct llama_sampler_i llama_sampler_dist_i = { - /* .name = */ llama_sampler_dist_name, - /* .accept = */ nullptr, - /* .apply = */ llama_sampler_dist_apply, - /* .reset = */ llama_sampler_dist_reset, - /* .clone = */ llama_sampler_dist_clone, - /* .free = */ llama_sampler_dist_free, + /* .name = */ llama_sampler_dist_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_dist_apply, + /* .reset = */ llama_sampler_dist_reset, + /* .clone = */ llama_sampler_dist_clone, + /* .free = */ llama_sampler_dist_free, + /* .backend_init = */ llama_sampler_dist_backend_init, + /* .backend_accept = */ nullptr, + /* .backend_apply = */ llama_sampler_dist_backend_apply, + /* .backend_set_input = */ llama_sampler_dist_backend_set_input, }; struct llama_sampler * llama_sampler_init_dist(uint32_t seed) { @@ -709,21 +1237,26 @@ struct llama_sampler * llama_sampler_init_dist(uint32_t seed) { return llama_sampler_init( /* .iface = */ &llama_sampler_dist_i, /* .ctx = */ new llama_sampler_dist { - /* .seed = */ seed, - /* .seed_cur = */ seed_cur, - /* .rng = */ std::mt19937(seed_cur), + ("dist"), + /* .seed = */ seed, + /* .seed_cur = */ seed_cur, + /* .rng = */ std::mt19937(seed_cur), + /* .inp_uniform = */ nullptr, + /* .inp_ctx = */ nullptr, + /* .inp_buf = */ nullptr, } ); } // top-k -struct llama_sampler_top_k { +struct llama_sampler_top_k : public llama_sampler_backend { const int32_t k; }; -static const char * llama_sampler_top_k_name(const struct llama_sampler * /*smpl*/) { - return "top-k"; +static const char * llama_sampler_top_k_name(const struct llama_sampler * smpl) { + auto * sctx = (llama_sampler_top_k *) smpl->ctx; + return sctx->get_name(); } static void llama_sampler_top_k_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { @@ -740,19 +1273,68 @@ static void llama_sampler_top_k_free(struct llama_sampler * smpl) { delete (llama_sampler_top_k *) smpl->ctx; } +static bool llama_sampler_top_k_backend_init( + struct llama_sampler * smpl, + ggml_backend_buffer_type_t buft) { + auto * sctx = (llama_sampler_top_k *) smpl->ctx; + + const bool res = llama_sampler_backend_support(smpl, buft); + + sctx->init(res); + + return res; +} + +static void llama_sampler_top_k_backend_apply( + struct llama_sampler * smpl, + struct ggml_context * ctx, + struct ggml_cgraph * gf, + struct llama_sampler_data * data) { + auto * sctx = (llama_sampler_top_k *) smpl->ctx; + + struct ggml_tensor * top_k = ggml_top_k(ctx, data->logits, sctx->k); + ggml_set_name(top_k, "top_k"); + + if (data->candidates) { + data->candidates = ggml_get_rows(ctx, data->candidates, top_k); + ggml_set_name(data->candidates, "top_k_candidates"); + } else { + data->candidates = top_k; + } + + struct ggml_tensor * logits_rows = ggml_reshape_2d(ctx, data->logits, 1, data->logits->ne[0]); + struct ggml_tensor * 
top_k_rows = ggml_get_rows(ctx, logits_rows, top_k); + ggml_set_name(top_k_rows, "top_k_rows"); + + data->logits = ggml_reshape_1d(ctx, top_k_rows, sctx->k); + + GGML_UNUSED(gf); +} + static struct llama_sampler_i llama_sampler_top_k_i = { - /* .name = */ llama_sampler_top_k_name, - /* .accept = */ nullptr, - /* .apply = */ llama_sampler_top_k_apply, - /* .reset = */ nullptr, - /* .clone = */ llama_sampler_top_k_clone, - /* .free = */ llama_sampler_top_k_free, + /* .name = */ llama_sampler_top_k_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_top_k_apply, + /* .reset = */ nullptr, + /* .clone = */ llama_sampler_top_k_clone, + /* .free = */ llama_sampler_top_k_free, + /* .backend_init = */ llama_sampler_top_k_backend_init, + /* .backend_accept = */ nullptr, + /* .backend_apply = */ llama_sampler_top_k_backend_apply, + /* .backend_set_input = */ nullptr, }; struct llama_sampler * llama_sampler_init_top_k(int32_t k) { + const bool is_empty = (k <= 0); + + if (is_empty) { + return llama_sampler_init_empty("?top-k"); + } + return llama_sampler_init( /* .iface = */ &llama_sampler_top_k_i, /* .ctx = */ new llama_sampler_top_k { + ("top-k"), /* .k = */ k, } ); @@ -760,15 +1342,16 @@ struct llama_sampler * llama_sampler_init_top_k(int32_t k) { // top-p -struct llama_sampler_top_p { +struct llama_sampler_top_p : public llama_sampler_backend { const float p; const size_t min_keep; std::vector buf_sort; }; -static const char * llama_sampler_top_p_name(const struct llama_sampler * /*smpl*/) { - return "top-p"; +static const char * llama_sampler_top_p_name(const struct llama_sampler * smpl) { + auto * sctx = (llama_sampler_top_p *) smpl->ctx; + return sctx->get_name(); } static void llama_sampler_top_p_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { @@ -835,19 +1418,118 @@ static void llama_sampler_top_p_free(struct llama_sampler * smpl) { delete (llama_sampler_top_p *) smpl->ctx; } +static bool llama_sampler_top_p_backend_init( + struct llama_sampler * smpl, + ggml_backend_buffer_type_t buft) { + auto * sctx = (llama_sampler_top_p *) smpl->ctx; + + const bool res = llama_sampler_backend_support(smpl, buft); + + sctx->init(res); + + return res; +} + +static void llama_sampler_top_p_backend_apply( + struct llama_sampler * smpl, + struct ggml_context * ctx, + struct ggml_cgraph * gf, + struct llama_sampler_data * data) { + auto * sctx = (llama_sampler_top_p *) smpl->ctx; + + auto ggml_sort = [ctx](struct ggml_tensor * a, struct ggml_tensor * b) { + GGML_ASSERT(ggml_nrows(a) == 1); + struct ggml_tensor * a_reshaped = ggml_reshape_2d(ctx, a, 1, a->ne[0]); + struct ggml_tensor * a_sorted = ggml_get_rows(ctx, a_reshaped, b); + return ggml_reshape_1d(ctx, a_sorted, a->ne[0]); + }; + + // Get the sorted logits in descending order. + struct ggml_tensor * sorted_idx = ggml_argsort(ctx, data->logits, GGML_SORT_ORDER_DESC); + ggml_set_name(sorted_idx, "top_p_sorted_idx"); + + // Do the sorting via reshape + get_rows + struct ggml_tensor * sorted_logits = ggml_sort(data->logits, sorted_idx); + ggml_set_name(sorted_logits, "top_p_sorted_logits"); + + struct ggml_tensor * softmax = ggml_soft_max(ctx, sorted_logits); + ggml_set_name(softmax, "top_p_softmax"); + + // If candidates are provided, sort them as well. Otherwise, set sorted indices as candidates. 
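+    // Sorting the candidates with the same permutation keeps them aligned with the
+    // sorted logits, so positions can still be mapped back to vocab ids later (by the
+    // dist sampler, or by the CPU side via llama_get_sampled_candidates_ith). For
+    // example, with logits [0.5, 2.0, 1.0] and candidates [42, 7, 1905] from an
+    // earlier top-k, sorted_idx = [1, 2, 0] and the candidates become [7, 1905, 42].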
+ if (data->candidates) { + data->candidates = ggml_sort(data->candidates, sorted_idx); + } else { + data->candidates = sorted_idx; + } + ggml_set_name(data->candidates, "top_p_candidates"); + + // Compute Cumulative Distribution Function (CDF) by means of GGML_OP_CUMSUM. + struct ggml_tensor * cdf = ggml_cumsum(ctx, softmax); + ggml_set_name(cdf, "top_p_cdf"); + + // Invert CDF and add top-p value so that ggml_step yields 1 for values we want to keep + struct ggml_tensor * cdf_scaled = ggml_scale_bias(ctx, cdf, -1.0f, sctx->p); + ggml_set_name(cdf_scaled, "top_p_cdf_scaled"); + + struct ggml_tensor * mask = ggml_step(ctx, cdf_scaled); + ggml_set_name(mask, "top_p_mask"); + + // Taking the sum of the mask gives us the sum of elements after the threshold + // we are interested in. + struct ggml_tensor * idxf = ggml_sum(ctx, mask); + ggml_set_name(idxf, "top_p_index_f32"); + + // prevent out-of-bounds access + idxf = ggml_clamp(ctx, idxf, 0.0f, mask->ne[0] - 1); + + // construct ones tensor to set the value in the mask + struct ggml_tensor * ones = ggml_scale_bias(ctx, idxf, 0.0f, 1.0f); + ggml_set_name(ones, "top_p_ones"); + + // Make top-p inclusive (i.e. return all values such that cum_sum/cdf >= p) + struct ggml_tensor * mask_reshaped = ggml_reshape_2d(ctx, mask, 1, mask->ne[0]); + + mask_reshaped = ggml_set_rows(ctx, mask_reshaped, ones, ggml_cast(ctx, idxf, GGML_TYPE_I32)); + mask = ggml_reshape_1d(ctx, mask_reshaped, mask->ne[0]); + + // Use ggml_scale_bias (output = (a * s) + b) which in this case becomes: + // top_p_bias = (mask * 1e9f) - 1e9f. + // So entries in the mask that we want to discard will become -1e9f, and + // others will be 0 (meaning that will not effect the logits). + const float large_val = 1e9f; + struct ggml_tensor * top_p_bias = ggml_scale_bias(ctx, mask, large_val, -large_val); + ggml_set_name(top_p_bias, "top_p_bias"); + + data->logits = ggml_add(ctx, sorted_logits, top_p_bias); + ggml_set_name(data->logits, "top_p_logits"); + + GGML_UNUSED(gf); +} + static struct llama_sampler_i llama_sampler_top_p_i = { - /* .name = */ llama_sampler_top_p_name, - /* .accept = */ nullptr, - /* .apply = */ llama_sampler_top_p_apply, - /* .reset = */ nullptr, - /* .clone = */ llama_sampler_top_p_clone, - /* .free = */ llama_sampler_top_p_free, + /* .name = */ llama_sampler_top_p_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_top_p_apply, + /* .reset = */ nullptr, + /* .clone = */ llama_sampler_top_p_clone, + /* .free = */ llama_sampler_top_p_free, + /* .backend_init = */ llama_sampler_top_p_backend_init, + /* .backend_accept = */ nullptr, + /* .backend_apply = */ llama_sampler_top_p_backend_apply, + /* .backend_set_input = */ nullptr, }; struct llama_sampler * llama_sampler_init_top_p(float p, size_t min_keep) { + const bool is_empty = p >= 1.0f; + + if (is_empty) { + return llama_sampler_init_empty("?top-p"); + } + return llama_sampler_init( /* .iface = */ &llama_sampler_top_p_i, /* .ctx = */ new llama_sampler_top_p { + ("top-p"), /* .p = */ p, /* .min_keep = */ min_keep, /* .buf_sort = */ {}, @@ -857,13 +1539,14 @@ struct llama_sampler * llama_sampler_init_top_p(float p, size_t min_keep) { // min-p -struct llama_sampler_min_p { +struct llama_sampler_min_p : public llama_sampler_backend { const float p; const size_t min_keep; }; -static const char * llama_sampler_min_p_name(const struct llama_sampler * /*smpl*/) { - return "min-p"; +static const char * llama_sampler_min_p_name(const struct llama_sampler * smpl) { + auto * sctx = (llama_sampler_min_p *) 
smpl->ctx; + return sctx->get_name(); } static void llama_sampler_min_p_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { @@ -929,19 +1612,85 @@ static void llama_sampler_min_p_free(struct llama_sampler * smpl) { delete (llama_sampler_min_p *) smpl->ctx; } +static bool llama_sampler_min_p_backend_init( + struct llama_sampler * smpl, + ggml_backend_buffer_type_t buft) { + auto * sctx = (llama_sampler_min_p *) smpl->ctx; + + const bool res = llama_sampler_backend_support(smpl, buft); + + sctx->init(res); + + return res; +} + +static void llama_sampler_min_p_backend_apply( + struct llama_sampler * smpl, + struct ggml_context * ctx, + struct ggml_cgraph * gf, + struct llama_sampler_data * data) { + auto * sctx = (llama_sampler_min_p *) smpl->ctx; + + struct ggml_tensor * max_idx = ggml_argmax(ctx, data->logits); + ggml_set_name(max_idx, "max_idx"); + + struct ggml_tensor * logits_rows = ggml_reshape_2d(ctx, data->logits, 1, data->logits->ne[0]); + ggml_set_name(logits_rows, "logits_rows"); + + struct ggml_tensor * max_logit = ggml_get_rows(ctx, logits_rows, max_idx); + ggml_set_name(max_logit, "max_logit"); + + // Calculate the threshold value. + struct ggml_tensor * threshold = ggml_scale_bias(ctx, max_logit, 1.0f, logf(sctx->p)); + ggml_set_name(threshold, "min_p_threshold"); + + // Subtract the threshold from logits. + struct ggml_tensor * sub = ggml_sub(ctx, data->logits, threshold); + + // Create a mask where logits below the threshold are 0 (discard), + // and others are 1 (keep). + struct ggml_tensor * mask = ggml_step(ctx, sub); + ggml_set_name(mask, "min_p_mask"); + + // Use ggml_scale_bias (output = (a * s) + b) which in this case becomes: + // min_p_bias = (mask * 1e9f) - 1e9f. + // So entries in the mask that we want to discard will become -1e9f, and + // others will be 0 (meaning that will not effect the logits). + const float large_val = 1e9f; + struct ggml_tensor * min_p_bias = ggml_scale_bias(ctx, mask, large_val, -large_val); + ggml_set_name(min_p_bias, "min_p_bias"); + + // Add the min_p bias to the logits. 
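+    // Kept logits are left untouched (bias 0) while discarded ones are pushed to
+    // -1e9, so a later softmax assigns them effectively zero probability. Unlike the
+    // CPU implementation the tensor keeps its full length; filtering is done purely
+    // by masking. For example, with p = 0.05 and a max logit of 10.0 the threshold is
+    // 10.0 + ln(0.05) ~= 7.0, so every logit below ~7.0 is masked out.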
+ data->logits = ggml_add(ctx, data->logits, min_p_bias); + ggml_set_name(data->logits, "min_p_logits"); + + GGML_UNUSED(gf); +} + static struct llama_sampler_i llama_sampler_min_p_i = { - /* .name = */ llama_sampler_min_p_name, - /* .accept = */ nullptr, - /* .apply = */ llama_sampler_min_p_apply, - /* .reset = */ nullptr, - /* .clone = */ llama_sampler_min_p_clone, - /* .free = */ llama_sampler_min_p_free, + /* .name = */ llama_sampler_min_p_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_min_p_apply, + /* .reset = */ nullptr, + /* .clone = */ llama_sampler_min_p_clone, + /* .free = */ llama_sampler_min_p_free, + /* .backend_init = */ llama_sampler_min_p_backend_init, + /* .backend_accept = */ nullptr, + /* .backend_apply = */ llama_sampler_min_p_backend_apply, + /* .backend_set_input = */ nullptr, }; struct llama_sampler * llama_sampler_init_min_p(float p, size_t min_keep) { + const bool is_empty = (p <= 0.0f); + + if (is_empty) { + return llama_sampler_init_empty("?min-p"); + } + return llama_sampler_init( /* .iface = */ &llama_sampler_min_p_i, /* .ctx = */ new llama_sampler_min_p { + ("min-p"), /* .p = */ p, /* .min_keep = */ min_keep, } @@ -1029,15 +1778,25 @@ static void llama_sampler_typical_free(struct llama_sampler * smpl) { } static struct llama_sampler_i llama_sampler_typical_i = { - /* .name = */ llama_sampler_typical_name, - /* .accept = */ nullptr, - /* .apply = */ llama_sampler_typical_apply, - /* .reset = */ nullptr, - /* .clone = */ llama_sampler_typical_clone, - /* .free = */ llama_sampler_typical_free, + /* .name = */ llama_sampler_typical_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_typical_apply, + /* .reset = */ nullptr, + /* .clone = */ llama_sampler_typical_clone, + /* .free = */ llama_sampler_typical_free, + /* .backend_init = */ nullptr, + /* .backend_accept = */ nullptr, + /* .backend_apply = */ nullptr, + /* .backend_set_input = */ nullptr, }; struct llama_sampler * llama_sampler_init_typical(float p, size_t min_keep) { + const bool is_empty = (p >= 1.0f); + + if (is_empty) { + return llama_sampler_init_empty("?typical"); + } + return llama_sampler_init( /* .iface = */ &llama_sampler_typical_i, /* .ctx = */ new llama_sampler_typical { @@ -1049,12 +1808,13 @@ struct llama_sampler * llama_sampler_init_typical(float p, size_t min_keep) { // temp -struct llama_sampler_temp { +struct llama_sampler_temp : public llama_sampler_backend { const float temp; }; -static const char * llama_sampler_temp_name(const struct llama_sampler * /*smpl*/) { - return "temp"; +static const char * llama_sampler_temp_name(const struct llama_sampler * smpl) { + auto * sctx = (llama_sampler_temp *) smpl->ctx; + return sctx->get_name(); } static void llama_sampler_temp_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { @@ -1072,19 +1832,79 @@ static void llama_sampler_temp_free(struct llama_sampler * smpl) { delete (llama_sampler_temp *) smpl->ctx; } +static void llama_sampler_backend_temp_sampling( + struct ggml_context * ctx, + struct ggml_cgraph * gf, + struct llama_sampler_data * data, + float temp) { + if (temp <= 0.0f) { + // Find the most probable token index. 
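+        // (A non-positive temperature degenerates to greedy selection: only the
+        // single best candidate and its logit are kept below, so consumers of
+        // the sampled logits for this sequence see exactly one entry.)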
+ struct ggml_tensor * max_idx = ggml_argmax(ctx, data->logits); + ggml_set_name(max_idx, "temp_max_idx"); + + if (data->candidates) { + data->candidates = ggml_get_rows(ctx, data->candidates, max_idx); + } else { + data->candidates = max_idx; + } + + struct ggml_tensor * logits = ggml_reshape_2d(ctx, data->logits, 1, data->logits->ne[0]); + + data->logits = ggml_get_rows(ctx, logits, max_idx); + + return; + } + + data->logits = ggml_scale(ctx, data->logits, 1.0f / temp); + + GGML_UNUSED(gf); +} + +static bool llama_sampler_temp_backend_init( + struct llama_sampler * smpl, + ggml_backend_buffer_type_t buft) { + auto * sctx = (llama_sampler_temp *) smpl->ctx; + + const bool res = llama_sampler_backend_support(smpl, buft); + + sctx->init(res); + + return res; +} + +static void llama_sampler_temp_backend_apply( + struct llama_sampler * smpl, + struct ggml_context * ctx, + struct ggml_cgraph * gf, + struct llama_sampler_data * data) { + auto * sctx = (llama_sampler_temp *) smpl->ctx; + llama_sampler_backend_temp_sampling(ctx, gf, data, sctx->temp); +} + static struct llama_sampler_i llama_sampler_temp_i = { - /* .name = */ llama_sampler_temp_name, - /* .accept = */ nullptr, - /* .apply = */ llama_sampler_temp_apply, - /* .reset = */ nullptr, - /* .clone = */ llama_sampler_temp_clone, - /* .free = */ llama_sampler_temp_free, + /* .name = */ llama_sampler_temp_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_temp_apply, + /* .reset = */ nullptr, + /* .clone = */ llama_sampler_temp_clone, + /* .free = */ llama_sampler_temp_free, + /* .backend_init = */ llama_sampler_temp_backend_init, + /* .backend_accept = */ nullptr, + /* .backend_apply = */ llama_sampler_temp_backend_apply, + /* .backend_set_input = */ nullptr, }; struct llama_sampler * llama_sampler_init_temp(float temp) { + const bool is_empty = temp == 1.0f; + + if (is_empty) { + return llama_sampler_init_empty("?temp"); + } + return llama_sampler_init( /* .iface = */ &llama_sampler_temp_i, /* .ctx = */ new llama_sampler_temp { + ("temp"), /*.temp = */ temp, } ); @@ -1092,14 +1912,15 @@ struct llama_sampler * llama_sampler_init_temp(float temp) { // temp-ext -struct llama_sampler_temp_ext { +struct llama_sampler_temp_ext : public llama_sampler_backend { const float temp; const float delta; const float exponent; }; -static const char * llama_sampler_temp_ext_name(const struct llama_sampler * /*smpl*/) { - return "temp-ext"; +static const char * llama_sampler_temp_ext_name(const struct llama_sampler * smpl) { + auto * sctx = (llama_sampler_temp_ext *) smpl->ctx; + return sctx->get_name(); } static void llama_sampler_temp_ext_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { @@ -1182,24 +2003,112 @@ static void llama_sampler_temp_ext_free(struct llama_sampler * smpl) { delete (llama_sampler_temp_ext *) smpl->ctx; } +static bool llama_sampler_temp_ext_backend_init( + struct llama_sampler * smpl, + ggml_backend_buffer_type_t buft) { + auto * sctx = (llama_sampler_temp_ext *) smpl->ctx; + + const bool res = llama_sampler_backend_support(smpl, buft); + + sctx->init(res); + + return res; +} + +static void llama_sampler_temp_ext_backend_apply( + struct llama_sampler * smpl, + struct ggml_context * ctx, + struct ggml_cgraph * gf, + struct llama_sampler_data * data) { + auto * sctx = (llama_sampler_temp_ext *) smpl->ctx; + + // Revert to standard temperature scaling if delta or temp are non-positive. 
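+    // (A non-positive delta disables the entropy-based adjustment, and a
+    // non-positive temp reduces to argmax inside llama_sampler_backend_temp_sampling.)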
+ if (sctx->delta <= 0.0f || sctx->temp <= 0.0f) { + llama_sampler_backend_temp_sampling(ctx, gf, data, sctx->temp); + return; + } + + // Calculate min_temp, max_temp, and max_entropy. + const float min_temp = std::max(0.0f, sctx->temp - sctx->delta); + const float max_temp = sctx->temp + sctx->delta; + const float max_entropy = logf(data->logits->ne[0]); + + // Calculate the probabilities. + struct ggml_tensor * probs = ggml_soft_max(ctx, data->logits); + ggml_set_name(probs, "temp_ext_softmax_probs"); + + // Clamp probabilities to avoid log(0) which would give -inf + struct ggml_tensor * probs_clamped = ggml_clamp(ctx, probs, 1e-10f, 1.0f); + ggml_set_name(probs_clamped, "temp_ext_probs_clamped"); + + // Calculate the entropy, entropy = -Σ(p * log(p)). + struct ggml_tensor * log_probs = ggml_log(ctx, probs_clamped); + struct ggml_tensor * p_log_p = ggml_mul(ctx, probs_clamped, log_probs); + struct ggml_tensor * sum_p_log_p = ggml_sum(ctx, p_log_p); + struct ggml_tensor * entropy = ggml_scale(ctx, sum_p_log_p, -1.0f); + ggml_set_name(log_probs, "temp_ext_log_probs"); + ggml_set_name(p_log_p, "temp_ext_p_log_p"); + ggml_set_name(sum_p_log_p, "temp_ext_sum_p_log_p"); + ggml_set_name(entropy, "temp_ext_entropy"); + + // Normalize the entropy, norm_entropy = entropy / max_entropy + struct ggml_tensor * norm_entropy = ggml_scale(ctx, entropy, 1.0f / max_entropy); + ggml_set_name(norm_entropy, "temp_ext_norm_entropy"); + + // Calculate the dynamic temperature: + // dyn_temp = min_temp + (max_temp - min_temp) * powf(normalized_entropy, exponent); + // + // Calculate powf(normalized_entropy, exponent) as + // norm_entropy^exponent = exp(exponent * log(norm_entropy)) + struct ggml_tensor * log_norm_entropy = ggml_log(ctx, norm_entropy); + struct ggml_tensor * scaled_log = ggml_scale(ctx, log_norm_entropy, sctx->exponent); + struct ggml_tensor * pow_entropy = ggml_exp(ctx, scaled_log); + // With pow_entropy computed we can now compute dyn_temp, scaling by + // (max_temp - min_temp) and then adding min_temp. 
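+    // Illustrative example: with temp = 0.8f, delta = 0.5f and exponent = 1.5f,
+    // min_temp = 0.3 and max_temp = 1.3; a normalized entropy of 0.5 gives
+    // dyn_temp = 0.3 + 1.0 * 0.5^1.5 ~= 0.65.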
+ struct ggml_tensor * dyn_temp = ggml_scale_bias(ctx, pow_entropy, max_temp - min_temp, min_temp); + ggml_set_name(log_norm_entropy, "temp_ext_log_norm_entropy"); + ggml_set_name(scaled_log, "temp_ext_scaled_log"); + ggml_set_name(pow_entropy, "temp_ext_pow_entropy"); + ggml_set_name(dyn_temp, "temp_ext_dyn_temp"); + + // Scale the logits by the dynamic temperature + struct ggml_tensor * scaled_logits = ggml_div(ctx, data->logits, dyn_temp); + ggml_set_name(scaled_logits, "temp_ext_scaled_logits"); + + data->logits = scaled_logits; +} + static struct llama_sampler_i llama_sampler_temp_ext_i = { - /* .name = */ llama_sampler_temp_ext_name, - /* .accept = */ nullptr, - /* .apply = */ llama_sampler_temp_ext_apply, - /* .reset = */ nullptr, - /* .clone = */ llama_sampler_temp_ext_clone, - /* .free = */ llama_sampler_temp_ext_free, + /* .name = */ llama_sampler_temp_ext_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_temp_ext_apply, + /* .reset = */ nullptr, + /* .clone = */ llama_sampler_temp_ext_clone, + /* .free = */ llama_sampler_temp_ext_free, + /* .backend_init = */ llama_sampler_temp_ext_backend_init, + /* .backend_accept = */ nullptr, + /* .backend_apply = */ llama_sampler_temp_ext_backend_apply, + /* .backend_set_input = */ nullptr, }; struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, float exponent) { - return llama_sampler_init( + const bool is_empty = temp == 1.0f && delta <= 0.0f; + + if (is_empty) { + return llama_sampler_init_empty("?temp-ext"); + } + + auto * res = llama_sampler_init( /* .iface = */ &llama_sampler_temp_ext_i, /* .ctx = */ new llama_sampler_temp_ext { + ("temp-ext"), /* .temp = */ temp, /* .delta = */ delta, /* .exponent = */ exponent, } ); + + return res; } // xtc @@ -1277,16 +2186,27 @@ static void llama_sampler_xtc_reset(struct llama_sampler * smpl) { } static struct llama_sampler_i llama_sampler_xtc_i = { - /* .name = */ llama_sampler_xtc_name, - /* .accept = */ nullptr, - /* .apply = */ llama_sample_xtc_apply, - /* .reset = */ llama_sampler_xtc_reset, - /* .clone = */ llama_sampler_xtc_clone, - /* .free = */ llama_sampler_xtc_free, + /* .name = */ llama_sampler_xtc_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sample_xtc_apply, + /* .reset = */ llama_sampler_xtc_reset, + /* .clone = */ llama_sampler_xtc_clone, + /* .free = */ llama_sampler_xtc_free, + /* .backend_init = */ nullptr, + /* .backend_accept = */ nullptr, + /* .backend_apply = */ nullptr, + /* .backend_set_input = */ nullptr, }; struct llama_sampler * llama_sampler_init_xtc(float p, float t, size_t min_keep, uint32_t seed) { - auto seed_cur = get_rng_seed(seed); + const bool is_empty = (p <= 0.0f || t > 0.5f); + + if (is_empty) { + return llama_sampler_init_empty("?xtc"); + } + + const auto seed_cur = get_rng_seed(seed); + return llama_sampler_init( /* .iface = */ &llama_sampler_xtc_i, /* .ctx = */ new llama_sampler_xtc { @@ -1385,16 +2305,21 @@ static void llama_sampler_mirostat_free(struct llama_sampler * smpl) { } static struct llama_sampler_i llama_sampler_mirostat_i = { - /* .name = */ llama_sampler_mirostat_name, - /* .accept = */ nullptr, - /* .apply = */ llama_sampler_mirostat_apply, - /* .reset = */ llama_sampler_mirostat_reset, - /* .clone = */ llama_sampler_mirostat_clone, - /* .free = */ llama_sampler_mirostat_free, + /* .name = */ llama_sampler_mirostat_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_mirostat_apply, + /* .reset = */ llama_sampler_mirostat_reset, + /* .clone = */ llama_sampler_mirostat_clone, + /* 
.free = */ llama_sampler_mirostat_free, + /* .backend_init = */ nullptr, + /* .backend_accept = */ nullptr, + /* .backend_apply = */ nullptr, + /* .backend_set_input = */ nullptr, }; struct llama_sampler * llama_sampler_init_mirostat(int32_t n_vocab, uint32_t seed, float tau, float eta, int32_t m) { - auto seed_cur = get_rng_seed(seed); + const auto seed_cur = get_rng_seed(seed); + return llama_sampler_init( /* .iface = */ &llama_sampler_mirostat_i, /* .ctx = */ new llama_sampler_mirostat { @@ -1484,12 +2409,16 @@ static void llama_sampler_mirostat_v2_free(struct llama_sampler * smpl) { } static struct llama_sampler_i llama_sampler_mirostat_v2_i = { - /* .name = */ llama_sampler_mirostat_v2_name, - /* .accept = */ nullptr, - /* .apply = */ llama_sampler_mirostat_v2_apply, - /* .reset = */ llama_sampler_mirostat_v2_reset, - /* .clone = */ llama_sampler_mirostat_v2_clone, - /* .free = */ llama_sampler_mirostat_v2_free, + /* .name = */ llama_sampler_mirostat_v2_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_mirostat_v2_apply, + /* .reset = */ llama_sampler_mirostat_v2_reset, + /* .clone = */ llama_sampler_mirostat_v2_clone, + /* .free = */ llama_sampler_mirostat_v2_free, + /* .backend_init = */ nullptr, + /* .backend_accept = */ nullptr, + /* .backend_apply = */ nullptr, + /* .backend_set_input = */ nullptr, }; struct llama_sampler * llama_sampler_init_mirostat_v2(uint32_t seed, float tau, float eta) { @@ -1601,12 +2530,16 @@ static void llama_sampler_grammar_free(struct llama_sampler * smpl) { } static struct llama_sampler_i llama_sampler_grammar_i = { - /* .name = */ llama_sampler_grammar_name, - /* .accept = */ llama_sampler_grammar_accept_impl, - /* .apply = */ llama_sampler_grammar_apply, - /* .reset = */ llama_sampler_grammar_reset, - /* .clone = */ llama_sampler_grammar_clone, - /* .free = */ llama_sampler_grammar_free, + /* .name = */ llama_sampler_grammar_name, + /* .accept = */ llama_sampler_grammar_accept_impl, + /* .apply = */ llama_sampler_grammar_apply, + /* .reset = */ llama_sampler_grammar_reset, + /* .clone = */ llama_sampler_grammar_clone, + /* .free = */ llama_sampler_grammar_free, + /* .backend_init = */ nullptr, + /* .backend_accept = */ nullptr, + /* .backend_apply = */ nullptr, + /* .backend_set_input = */ nullptr, }; static struct llama_sampler * llama_sampler_init_grammar_impl( @@ -1808,12 +2741,16 @@ static void llama_sampler_penalties_free(struct llama_sampler * smpl) { } static struct llama_sampler_i llama_sampler_penalties_i = { - /* .name = */ llama_sampler_penalties_name, - /* .accept = */ llama_sampler_penalties_accept, - /* .apply = */ llama_sampler_penalties_apply, - /* .reset = */ llama_sampler_penalties_reset, - /* .clone = */ llama_sampler_penalties_clone, - /* .free = */ llama_sampler_penalties_free, + /* .name = */ llama_sampler_penalties_name, + /* .accept = */ llama_sampler_penalties_accept, + /* .apply = */ llama_sampler_penalties_apply, + /* .reset = */ llama_sampler_penalties_reset, + /* .clone = */ llama_sampler_penalties_clone, + /* .free = */ llama_sampler_penalties_free, + /* .backend_init = */ nullptr, + /* .backend_accept = */ nullptr, + /* .backend_apply = */ nullptr, + /* .backend_set_input = */ nullptr, }; struct llama_sampler * llama_sampler_init_penalties( @@ -1823,6 +2760,12 @@ struct llama_sampler * llama_sampler_init_penalties( float penalty_present) { penalty_last_n = std::max(penalty_last_n, 0); + const bool is_empty = (penalty_last_n == 0 || (penalty_repeat == 1.0f && penalty_freq == 0.0f && penalty_present == 
0.0f)); + + if (is_empty) { + return llama_sampler_init_empty("?penalties"); + } + return llama_sampler_init( /* .iface = */ &llama_sampler_penalties_i, /* .ctx = */ new llama_sampler_penalties { @@ -1860,9 +2803,7 @@ static void llama_sampler_top_n_sigma_apply(struct llama_sampler * smpl, llama_t for (size_t i = 0; i < cur_p->size; ++i) { // Only count non-negative infinity values if (cur_p->data[i].logit != -INFINITY) { - if (cur_p->data[i].logit > max) { - max = cur_p->data[i].logit; - } + max = std::max(max, cur_p->data[i].logit); logits_sum += cur_p->data[i].logit; valid_count++; } @@ -1899,15 +2840,25 @@ static void llama_sampler_top_n_sigma_free(struct llama_sampler * smpl) { } static struct llama_sampler_i llama_sampler_top_n_sigma_i = { - /* .name = */ llama_sampler_top_n_sigma_name, - /* .accept = */ nullptr, - /* .apply = */ llama_sampler_top_n_sigma_apply, - /* .reset = */ nullptr, - /* .clone = */ llama_sampler_top_n_sigma_clone, - /* .free = */ llama_sampler_top_n_sigma_free, + /* .name = */ llama_sampler_top_n_sigma_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_top_n_sigma_apply, + /* .reset = */ nullptr, + /* .clone = */ llama_sampler_top_n_sigma_clone, + /* .free = */ llama_sampler_top_n_sigma_free, + /* .backend_init = */ nullptr, + /* .backend_accept = */ nullptr, + /* .backend_apply = */ nullptr, + /* .backend_set_input = */ nullptr, }; struct llama_sampler * llama_sampler_init_top_n_sigma(float n) { + const bool is_empty = (n <= 0.0f); + + if (is_empty) { + return llama_sampler_init_empty("?top-n-sigma"); + } + return llama_sampler_init( /* .iface = */ &llama_sampler_top_n_sigma_i, /* .ctx = */ new llama_sampler_top_n_sigma { @@ -2229,12 +3180,16 @@ static void llama_sampler_dry_free(struct llama_sampler * smpl) { } static struct llama_sampler_i llama_sampler_dry_i = { - /* .name = */ llama_sampler_dry_name, - /* .accept = */ llama_sampler_dry_accept, - /* .apply = */ llama_sampler_dry_apply, - /* .reset = */ llama_sampler_dry_reset, - /* .clone = */ llama_sampler_dry_clone, - /* .free = */ llama_sampler_dry_free, + /* .name = */ llama_sampler_dry_name, + /* .accept = */ llama_sampler_dry_accept, + /* .apply = */ llama_sampler_dry_apply, + /* .reset = */ llama_sampler_dry_reset, + /* .clone = */ llama_sampler_dry_clone, + /* .free = */ llama_sampler_dry_free, + /* .backend_init = */ nullptr, + /* .backend_accept = */ nullptr, + /* .backend_apply = */ nullptr, + /* .backend_set_input = */ nullptr, }; struct llama_sampler * llama_sampler_init_dry(const struct llama_vocab * vocab, int32_t n_ctx_train, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const char** seq_breakers, size_t num_breakers) { @@ -2245,6 +3200,10 @@ struct llama_sampler * llama_sampler_init_dry(const struct llama_vocab * vocab, const bool dry_enabled = (dry_multiplier != 0.0f && dry_base >= 1.0f && dry_penalty_last_n != 0); + if (!dry_enabled) { + return llama_sampler_init_empty("?dry"); + } + if (dry_enabled && seq_breakers != nullptr && num_breakers > 0) { // Process sequence breakers for (size_t i = 0; i < num_breakers; ++i) { @@ -2315,16 +3274,23 @@ struct llama_sampler * llama_sampler_init_dry_testing(int32_t context_size, floa // logit-bias -struct llama_sampler_logit_bias { +struct llama_sampler_logit_bias : public llama_sampler_backend { const int32_t n_vocab; const std::vector logit_bias; std::vector to_search; + + struct ggml_tensor * inp_logit_bias; + struct ggml_tensor * inp_logit_idxs; + + ggml_context_ptr inp_ctx; + 
ggml_backend_buffer_ptr inp_buf; }; -static const char * llama_sampler_logit_bias_name(const struct llama_sampler * /*smpl*/) { - return "logit-bias"; +static const char * llama_sampler_logit_bias_name(const struct llama_sampler * smpl) { + auto * ctx = (llama_sampler_logit_bias *) smpl->ctx; + return ctx->get_name(); } static void llama_sampler_logit_bias_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { @@ -2369,25 +3335,121 @@ static void llama_sampler_logit_bias_free(struct llama_sampler * smpl) { delete (llama_sampler_logit_bias *) smpl->ctx; } +static void llama_sampler_logit_bias_backend_apply( + struct llama_sampler * smpl, + struct ggml_context * ctx, + struct ggml_cgraph * gf, + struct llama_sampler_data * data) { + GGML_UNUSED(gf); + GGML_UNUSED(ctx); + + auto * sctx = (llama_sampler_logit_bias *) smpl->ctx; + if (sctx->logit_bias.empty()) { + return; + } + + ggml_tensor * cur = ggml_fill(ctx, data->logits, 0.0f); + + cur = ggml_reshape_2d(ctx, cur, 1, ggml_nelements(cur)); + cur = ggml_set_rows(ctx, cur, sctx->inp_logit_bias, sctx->inp_logit_idxs); + cur = ggml_reshape_1d(ctx, cur, ggml_nelements(cur)); + + data->logits = ggml_add(ctx, data->logits, cur); +} + +static void llama_sampler_logit_bias_backend_set_input(struct llama_sampler * smpl) { + auto * sctx = (llama_sampler_logit_bias *) smpl->ctx; + if (sctx->logit_bias.empty()) { + return; + } + + GGML_ASSERT(sctx->inp_logit_bias != nullptr); + GGML_ASSERT(sctx->inp_logit_idxs != nullptr); + + const size_t n = sctx->logit_bias.size(); + + std::vector data_logit_bias(n, 0.0f); + std::vector data_logit_idxs(n, 0); + for (size_t i = 0; i < n; ++i) { + const auto & lb = sctx->logit_bias[i]; + GGML_ASSERT(lb.token >= 0 && lb.token < (int32_t) sctx->n_vocab); + data_logit_bias[i] = lb.bias; + data_logit_idxs[i] = lb.token; + } + + ggml_backend_tensor_set(sctx->inp_logit_bias, data_logit_bias.data(), 0, ggml_nbytes(sctx->inp_logit_bias)); + ggml_backend_tensor_set(sctx->inp_logit_idxs, data_logit_idxs.data(), 0, ggml_nbytes(sctx->inp_logit_idxs)); +} + +static bool llama_sampler_logit_bias_backend_init( + struct llama_sampler * smpl, + ggml_backend_buffer_type_t buft) { + auto * sctx = (llama_sampler_logit_bias *) smpl->ctx; + + sctx->init(true); + + if (sctx->logit_bias.empty()) { + return true; + } + + ggml_init_params params = { + /*.mem_size =*/ 2*ggml_tensor_overhead(), + /*.mem_buffer =*/ nullptr, + /*.no_alloc =*/ true, + }; + + sctx->inp_ctx.reset(ggml_init(params)); + + const size_t n = sctx->logit_bias.size(); + + sctx->inp_logit_bias = ggml_new_tensor_2d(sctx->inp_ctx.get(), GGML_TYPE_F32, 1, n); + ggml_set_name(sctx->inp_logit_bias, "logit_bias"); + ggml_set_input(sctx->inp_logit_bias); + + sctx->inp_logit_idxs = ggml_new_tensor_1d(sctx->inp_ctx.get(), GGML_TYPE_I32, n); + ggml_set_name(sctx->inp_logit_idxs, "logit_idxs"); + ggml_set_input(sctx->inp_logit_idxs); + + // Allocate all tensors from our context to the backend + sctx->inp_buf.reset(ggml_backend_alloc_ctx_tensors_from_buft(sctx->inp_ctx.get(), buft)); + + return true; +} + static struct llama_sampler_i llama_sampler_logit_bias_i = { - /* .name = */ llama_sampler_logit_bias_name, - /* .accept = */ nullptr, - /* .apply = */ llama_sampler_logit_bias_apply, - /* .reset = */ nullptr, - /* .clone = */ llama_sampler_logit_bias_clone, - /* .free = */ llama_sampler_logit_bias_free, + /* .name = */ llama_sampler_logit_bias_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_logit_bias_apply, + /* .reset = */ nullptr, + /* .clone = */ 
llama_sampler_logit_bias_clone,
+    /* .free              = */ llama_sampler_logit_bias_free,
+    /* .backend_init      = */ llama_sampler_logit_bias_backend_init,
+    /* .backend_accept    = */ nullptr,
+    /* .backend_apply     = */ llama_sampler_logit_bias_backend_apply,
+    /* .backend_set_input = */ llama_sampler_logit_bias_backend_set_input,
 };
 
 struct llama_sampler * llama_sampler_init_logit_bias(
         int32_t   n_vocab,
         int32_t   n_logit_bias,
         const llama_logit_bias * logit_bias) {
+    const bool is_empty = n_logit_bias <= 0;
+
+    if (is_empty) {
+        return llama_sampler_init_empty("?logit-bias");
+    }
+
     return llama_sampler_init(
         /* .iface = */ &llama_sampler_logit_bias_i,
         /* .ctx   = */ new llama_sampler_logit_bias {
-            /* .n_vocab    = */ n_vocab,
-            /* .logit_bias = */ std::vector<llama_logit_bias>(logit_bias, logit_bias + n_logit_bias),
-            /* .to_search  = */ {},
+            ("logit-bias"),
+            /* .n_vocab        = */ n_vocab,
+            /* .logit_bias     = */ std::vector<llama_logit_bias>(logit_bias, logit_bias + n_logit_bias),
+            /* .to_search      = */ {},
+            /* .inp_logit_bias = */ nullptr,
+            /* .inp_logit_idxs = */ nullptr,
+            /* .inp_ctx        = */ nullptr,
+            /* .inp_buf        = */ nullptr,
         }
     );
 }
@@ -2600,12 +3662,16 @@ static void llama_sampler_infill_free(struct llama_sampler * smpl) {
 }
 
 static struct llama_sampler_i llama_sampler_infill_i = {
-    /* .name   = */ llama_sampler_infill_name,
-    /* .accept = */ nullptr,
-    /* .apply  = */ llama_sampler_infill_apply,
-    /* .reset  = */ nullptr,
-    /* .clone  = */ llama_sampler_infill_clone,
-    /* .free   = */ llama_sampler_infill_free,
+    /* .name              = */ llama_sampler_infill_name,
+    /* .accept            = */ nullptr,
+    /* .apply             = */ llama_sampler_infill_apply,
+    /* .reset             = */ nullptr,
+    /* .clone             = */ llama_sampler_infill_clone,
+    /* .free              = */ llama_sampler_infill_free,
+    /* .backend_init      = */ nullptr,
+    /* .backend_accept    = */ nullptr,
+    /* .backend_apply     = */ nullptr,
+    /* .backend_set_input = */ nullptr,
 };
 
 struct llama_sampler * llama_sampler_init_infill(const struct llama_vocab * vocab) {
@@ -2637,7 +3703,7 @@ uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl) {
     if (smpl->iface == &llama_sampler_chain_i) {
         const auto * ctx = (const llama_sampler_chain *) smpl->ctx;
         for (auto it = ctx->samplers.rbegin(); it != ctx->samplers.rend(); ++it) {
-            const uint32_t seed = llama_sampler_get_seed(*it);
+            const uint32_t seed = llama_sampler_get_seed(it->ptr);
             if (seed != LLAMA_DEFAULT_SEED) {
                 return seed;
             }
diff --git a/src/llama-sampling.h b/src/llama-sampling.h
index 759dd7dcb7..18cae29ece 100644
--- a/src/llama-sampling.h
+++ b/src/llama-sampling.h
@@ -14,7 +14,16 @@ struct llama_grammar;
 
 struct llama_sampler_chain {
     llama_sampler_chain_params params;
 
-    std::vector<llama_sampler *> samplers;
+    // has .backend_init() been called?
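+    // (backend_init() is where a sampler allocates any backend-resident input
+    // tensors it needs, e.g. the logit-bias inputs, so it has to run once
+    // before the first backend apply.)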
+ bool is_init = false; + + struct info { + bool is_backend; + + llama_sampler * ptr; + }; + + std::vector samplers; // timing @@ -24,9 +33,9 @@ struct llama_sampler_chain { }; struct llama_sampler * llama_sampler_init_dry_testing( - int32_t context_size, - float dry_multiplier, - float dry_base, - int32_t dry_allowed_length, - int32_t dry_penalty_last_n, - const std::vector>& seq_breakers); + int32_t context_size, + float dry_multiplier, + float dry_base, + int32_t dry_allowed_length, + int32_t dry_penalty_last_n, + const std::vector> & seq_breakers); diff --git a/src/llama.cpp b/src/llama.cpp index c8b5febe70..567dc9aa4e 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -708,7 +708,7 @@ bool llama_params_fit( struct llama_sampler_chain_params llama_sampler_chain_default_params() { struct llama_sampler_chain_params result = { - /*.no_perf =*/ true, + /*.no_perf =*/ true, }; return result; diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index c3d9f9c324..ff4b7205aa 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -222,6 +222,17 @@ llama_build_and_test(test-backend-ops.cpp) llama_build_and_test(test-model-load-cancel.cpp LABEL "model") llama_build_and_test(test-autorelease.cpp LABEL "model") +llama_build_and_test(test-backend-sampler.cpp LABEL "model") +target_include_directories(test-backend-sampler PRIVATE ${PROJECT_SOURCE_DIR}/src) +llama_test(test-backend-sampler NAME test-backend-sampler-greedy ARGS --test greedy) +llama_test(test-backend-sampler NAME test-backend-sampler-temp ARGS --test temp) +llama_test(test-backend-sampler NAME test-backend-sampler-top_k ARGS --test top_k) +llama_test(test-backend-sampler NAME test-backend-sampler-dist ARGS --test dist) +llama_test(test-backend-sampler NAME test-backend-sampler-dist-and-cpu ARGS --test dist_and_cpu) +llama_test(test-backend-sampler NAME test-backend-sampler-logit-bias ARGS --test logit_bias) +llama_test(test-backend-sampler NAME test-backend-sampler-mul_seq ARGS --test multi_sequence) +llama_test(test-backend-sampler NAME test-backend-sampler-set-sampler ARGS --test set_sampler) + # Test for state restore with fragmented KV cache # Requires a model, uses same args pattern as test-thread-safety if (NOT ${CMAKE_SYSTEM_PROCESSOR} MATCHES "s390x") diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 416218b5b8..eb699661f7 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -7613,6 +7613,9 @@ static std::vector> make_test_cases_eval() { exponent <<= 1; } #endif + test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {200000, 1, 1, 1}, false, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f)); + test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {200000, 4, 1, 1}, false, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f)); + test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {643251, 3, 1, 1}, false, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f)); for (bool mask : {false, true}) { for (bool sinks : {false, true}) { for (float max_bias : {0.0f, 8.0f}) { @@ -7666,7 +7669,6 @@ static std::vector> make_test_cases_eval() { } } } - for (bool fw : {true, false}) { // fw == forward bool all = true; @@ -7841,6 +7843,7 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 2048, 5, 4, 3 })); test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 201*1204, 1, 1, 1 })); test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 312*1205, 1, 1, 1 })); + test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 20481, 4, 1, 1 })); 
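+    // vocab-sized rows (e.g. 20481, 200000) cover the shapes that backend
+    // sampling pushes through cumsum/soft_max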
test_cases.emplace_back(new test_xielu()); @@ -8162,6 +8165,12 @@ static std::vector> make_test_cases_perf() { } } + for (int col: {8192, 16384, 32768, 65536, 131072, 262144, 524288}) { + for (int rows: {1, 4, 16}){ + test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {col, rows, 1, 1}, false, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f)); + } + } + test_cases.emplace_back(new test_conv_2d_dw({512, 512, 256, 1}, {3, 3, 1, 256}, 1, 1, 1, false)); test_cases.emplace_back(new test_conv_2d_dw({512, 512, 256, 1}, {3, 3, 1, 256}, 1, 1, 1, true)); @@ -8206,6 +8215,8 @@ static std::vector> make_test_cases_perf() { } test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {65000, 16, 1, 1})); + test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {200000, 1, 1, 1})); + test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {200000, 16, 1, 1})); test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {2, 1, 1, 1}, 1)); for (auto k : {1, 10, 40, 400}) { @@ -8216,13 +8227,18 @@ static std::vector> make_test_cases_perf() { } } + for (auto nrows : {1, 4, 8, 16}) { + for (auto cols : {128, 1024, 4096, 8192, 16384, 32768, 65536, 131072, 200000, 2000000}) { + test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, {cols, nrows, 1, 1})); + } + } + // Examples from granite-4.0-h-1b/ggml-model-Q8_0.gguf test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {515, 3328, 1, 1}, {4, 3328, 1, 1})); // prefill test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {4, 3328, 1, 1}, {4, 3328, 1, 1})); // generate test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 128, 64, 48, 1, 512, 1)); // prefill test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 128, 64, 48, 1, 1, 1)); // generate - return test_cases; } diff --git a/tests/test-backend-sampler.cpp b/tests/test-backend-sampler.cpp new file mode 100644 index 0000000000..b3f202771a --- /dev/null +++ b/tests/test-backend-sampler.cpp @@ -0,0 +1,1281 @@ +#include "ggml.h" +#include "llama.h" +#include "get-model.h" +#include "common.h" + +#ifdef NDEBUG +#undef NDEBUG +#endif + +#include +#include +#include +#include +#include +#include +#include +#include + +struct test_model_context { + llama_model * model = nullptr; + llama_context * ctx = nullptr; + const llama_vocab * vocab = nullptr; + int n_vocab = 0; + + std::unordered_map seq_positions; + std::unordered_map last_batch_info; + + bool load_model(const char * model_path) { + if (model != nullptr) { + return true; + } + + llama_backend_init(); + + // force CPU backend since it always supports all ggml operations + ggml_backend_dev_t devs[2]; + devs[0] = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); + devs[1] = nullptr; + + auto mparams = llama_model_default_params(); + mparams.devices = devs; + + model = llama_model_load_from_file(model_path, mparams); + if (model == nullptr) { + fprintf(stderr, "Warning: failed to load model '%s', skipping test\n", model_path); + cleanup(); + return false; + } + vocab = llama_model_get_vocab(model); + n_vocab = llama_vocab_n_tokens(vocab); + fprintf(stderr, "Vocabulary size: %d\n", n_vocab); + + return true; + } + + bool setup(const char * model_path, std::vector & configs, int32_t n_seq_max = -1) { + if (model == nullptr) { + load_model(model_path); + } + + if (ctx != nullptr) { + return true; + } + + llama_context_params cparams = llama_context_default_params(); + cparams.n_ctx = 512; + cparams.n_batch = 512; + cparams.samplers = configs.data(); + cparams.n_samplers = configs.size(); + + // If n_seq_max is not specified, calculate it from configs + 
if (n_seq_max < 0) { + int32_t max_seq_id = 0; + for (const auto & config : configs) { + max_seq_id = std::max(config.seq_id, max_seq_id); + } + cparams.n_seq_max = max_seq_id + 1; + } else { + cparams.n_seq_max = n_seq_max; + } + + ctx = llama_init_from_model(model, cparams); + if (ctx == nullptr) { + fprintf(stderr, "Warning: failed to create context, skipping test\n"); + cleanup(); + return false; + } + llama_set_warmup(ctx, false); + + return true; + } + + bool decode(const std::map & prompts) { + if (ctx == nullptr || vocab == nullptr) { + fprintf(stderr, "Error: context not initialized, call setup() first\n"); + return false; + } + + last_batch_info.clear(); + llama_batch batch = llama_batch_init(512, 0, prompts.size()); + + int n_tokens_per_prompt = 0; + + for (const auto & [seq_id, prompt] : prompts) { + std::vector tokens; + tokens.push_back(llama_vocab_bos(vocab)); + + std::vector prompt_tokens(32); + int n_tokens = llama_tokenize(vocab, prompt.c_str(), prompt.length(), + prompt_tokens.data(), prompt_tokens.size(), + false, false); + //TODO: refactor this function to just handle a single prompt at a time + // to avoid this check and complexity. + if (n_tokens_per_prompt == 0) { + n_tokens_per_prompt = n_tokens; + } else { + if (n_tokens != n_tokens_per_prompt) { + fprintf(stderr, "Error: prompts must have the same number of tokens\n"); + llama_batch_free(batch); + return false; + } + n_tokens_per_prompt = n_tokens; + } + if (n_tokens < 0) { + fprintf(stderr, "Warning: tokenization failed for seq_id %d\n", seq_id); + llama_batch_free(batch); + return false; + } + + for (int i = 0; i < n_tokens; i++) { + tokens.push_back(prompt_tokens[i]); + } + + for (size_t i = 0; i < tokens.size(); i++) { + common_batch_add(batch, tokens[i], i, { seq_id }, i == tokens.size() - 1); + } + + seq_positions[seq_id] = tokens.size(); + } + + + printf("Batch contents:\n"); + printf("n_tokens: %d\n", batch.n_tokens); + for (int i = 0; i < batch.n_tokens; i++) { + printf("token[%d]: tok=%-5d, pos=%d, n_seq_id=%d, seq_ids=[", i, batch.token[i], batch.pos[i], batch.n_seq_id[i]); + + for (int j = 0; j < batch.n_seq_id[i]; j++) { + printf("%d%s", batch.seq_id[i][j], j < batch.n_seq_id[i]-1 ? 
", " : ""); + } + printf("], logits=%d\n", batch.logits[i]); + } + + if (llama_decode(ctx, batch) != 0) { + fprintf(stderr, "Warning: llama_decode failed\n"); + llama_batch_free(batch); + return false; + } + + // Build mapping from seq id to batch token idx + for (int i = 0; i < batch.n_tokens; i++) { + if (batch.logits[i]) { + llama_seq_id seq_id = batch.seq_id[i][0]; + last_batch_info[seq_id] = i; + } + } + + llama_batch_free(batch); + return true; + } + + int32_t idx_for_seq(llama_seq_id seq_id) { + auto it = last_batch_info.find(seq_id); + if (it == last_batch_info.end()) { + fprintf(stderr, "Error: no batch index found for seq_id %d\n", seq_id); + return -1; + } + return it->second; + } + + bool decode_token(llama_token token, llama_seq_id seq_id = 0) { + if (ctx == nullptr) { + fprintf(stderr, "Error: context not initialized, call setup() first\n"); + return false; + } + + llama_batch batch = llama_batch_init(1, 0, 1); + int32_t pos = seq_positions[seq_id]; + common_batch_add(batch, token, pos, { seq_id }, true); + + if (llama_decode(ctx, batch) != 0) { + fprintf(stderr, "Warning: llama_decode failed for token %d in seq %d\n", token, seq_id); + llama_batch_free(batch); + return false; + } + + last_batch_info.clear(); + for (int i = 0; i < batch.n_tokens; i++) { + if (batch.logits[i]) { + llama_seq_id cur_seq = batch.seq_id[i][0]; + last_batch_info[cur_seq] = i; + } + } + + seq_positions[seq_id]++; + llama_batch_free(batch); + return true; + } + + bool decode_tokens(const std::map & seq_tokens) { + if (ctx == nullptr) { + fprintf(stderr, "Error: context not initialized, call setup() first\n"); + return false; + } + + llama_batch batch = llama_batch_init(seq_tokens.size(), 0, seq_tokens.size()); + + for (const auto & [seq_id, token] : seq_tokens) { + int32_t pos = seq_positions[seq_id]; + common_batch_add(batch, token, pos, { seq_id }, true); + } + + if (llama_decode(ctx, batch) != 0) { + fprintf(stderr, "Warning: llama_decode failed for batch tokens\n"); + llama_batch_free(batch); + return false; + } + + for (const auto & [seq_id, _] : seq_tokens) { + seq_positions[seq_id]++; + } + + last_batch_info.clear(); + for (int i = 0; i < batch.n_tokens; i++) { + if (batch.logits[i]) { + llama_seq_id cur_seq = batch.seq_id[i][0]; + last_batch_info[cur_seq] = i; + } + } + + llama_batch_free(batch); + return true; + } + + std::string token_to_piece(llama_token token, bool special) { + std::string piece; + piece.resize(piece.capacity()); // using string internal cache, 15 bytes + '\n' + const int n_chars = llama_token_to_piece(vocab, token, &piece[0], piece.size(), 0, special); + if (n_chars < 0) { + piece.resize(-n_chars); + int check = llama_token_to_piece(vocab, token, &piece[0], piece.size(), 0, special); + GGML_ASSERT(check == -n_chars); + } + else { + piece.resize(n_chars); + } + + return piece; + } + + void reset() { + if (ctx) { + llama_free(ctx); + ctx = nullptr; + } + seq_positions.clear(); + last_batch_info.clear(); + } + + void cleanup() { + if (ctx) { + llama_free(ctx); + } + if (model) { + llama_model_free(model); + } + + ctx = nullptr; + model = nullptr; + vocab = nullptr; + } + + ~test_model_context() { + cleanup(); + } +}; + +static void test_backend_greedy_sampling(const char * model_path) { + test_model_context test_ctx; + + const int seq_id = 0; + + struct llama_sampler_chain_params backend_sampler_params = llama_sampler_chain_default_params(); + struct llama_sampler * backend_sampler_chain = llama_sampler_chain_init(backend_sampler_params); + + 
llama_sampler_chain_add(backend_sampler_chain, llama_sampler_init_greedy()); + std::vector backend_sampler_configs = {{ seq_id, backend_sampler_chain }}; + + if (!test_ctx.setup(model_path, backend_sampler_configs)) { + return; + } + + if (!test_ctx.decode({{seq_id, "Some"}})) { + GGML_ASSERT(false && "Failed to decode token"); + } + + int32_t batch_idx = test_ctx.idx_for_seq(seq_id); + + llama_token token = llama_get_sampled_token_ith(test_ctx.ctx, batch_idx); + printf("greedy sampled id:%d, string:'%s'\n", token, test_ctx.token_to_piece(token, false).c_str()); + GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab); + + token = llama_get_sampled_token_ith(test_ctx.ctx, -1); + printf("greedy sampled id:%d, string:'%s'\n", token, test_ctx.token_to_piece(token, false).c_str()); + GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab); + + for (int i = 0; i < 10; i++) { + int32_t loop_idx = test_ctx.idx_for_seq(seq_id); + llama_token token = llama_get_sampled_token_ith(test_ctx.ctx, loop_idx); + printf("Generation step %d: token id:%d, string: %s\n", i, token, test_ctx.token_to_piece(token, false).c_str()); + if (!test_ctx.decode_token(token, 0)) { + GGML_ASSERT(false && "Failed to decode token"); + } + } + + llama_sampler_free(backend_sampler_chain); +} + +static void test_backend_top_k_sampling(const char * model_path) { + test_model_context test_ctx; + + const int seq_id = 0; + const int32_t k = 8; + struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params(); + struct llama_sampler * backend_sampler_chain = llama_sampler_chain_init(backend_chain_params); + llama_sampler_chain_add(backend_sampler_chain, llama_sampler_init_top_k(k)); + std::vector backend_sampler_configs = {{ seq_id, backend_sampler_chain }}; + + if (!test_ctx.setup(model_path, backend_sampler_configs)) { + return; + } + + if (!test_ctx.decode({{seq_id, "Hello"}})) { + GGML_ASSERT(false && "Failed to decode token"); + } + + int32_t batch_idx = test_ctx.idx_for_seq(seq_id); + + float * logits = llama_get_sampled_logits_ith(test_ctx.ctx, batch_idx); + uint32_t n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx, batch_idx); + for (size_t i = 0; i < n_logits; ++i) { + printf("top_k logit[%zu] = %.6f\n", i, logits[i]); + } + + llama_token * candidates = llama_get_sampled_candidates_ith(test_ctx.ctx, batch_idx); + uint32_t n_candidates = llama_get_sampled_candidates_count_ith(test_ctx.ctx, batch_idx); + for (size_t i = 0; i < n_candidates; ++i) { + printf("top_k candidate[%zu] = %d : %s\n", i, candidates[i], + test_ctx.token_to_piece(candidates[i], false).c_str()); + } + + llama_sampler_free(backend_sampler_chain); + + // Sample using CPU sampler for verification that it is possible to do hybrid + // sampling, first top_k on the backend and then dist on the CPU. 
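+    // The backend chain above only narrowed the distribution to the k candidates
+    // printed above; the CPU dist sampler below then draws the final token from
+    // what remains for this batch index.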
+    struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();
+    struct llama_sampler * chain = llama_sampler_chain_init(chain_params);
+    GGML_ASSERT(chain->iface->backend_apply != nullptr);
+
+    llama_sampler_chain_add(chain, llama_sampler_init_dist(18));
+    llama_token token = llama_sampler_sample(chain, test_ctx.ctx, batch_idx);
+    const std::string token_str = test_ctx.token_to_piece(token, false);
+    GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
+
+    printf("backend top-k hybrid sampling test PASSED\n");
+
+    llama_sampler_free(chain);
+}
+
+static void test_backend_temp_sampling(const char * model_path) {
+    test_model_context test_ctx;
+
+    {
+        const float temp_0 = 0.8f;
+        struct llama_sampler_chain_params backend_chain_params_0 = llama_sampler_chain_default_params();
+        struct llama_sampler * backend_sampler_chain_0 = llama_sampler_chain_init(backend_chain_params_0);
+        llama_sampler_chain_add(backend_sampler_chain_0, llama_sampler_init_temp(temp_0));
+
+        const float temp_1 = 0.1f;
+        struct llama_sampler_chain_params backend_chain_params_1 = llama_sampler_chain_default_params();
+        struct llama_sampler * backend_sampler_chain_1 = llama_sampler_chain_init(backend_chain_params_1);
+        llama_sampler_chain_add(backend_sampler_chain_1, llama_sampler_init_temp(temp_1));
+
+        std::vector backend_sampler_configs = {
+            { 0, backend_sampler_chain_0 },
+            { 1, backend_sampler_chain_1 }
+        };
+
+        if (!test_ctx.setup(model_path, backend_sampler_configs)) {
+            return;
+        }
+
+        if (!test_ctx.decode({{0, "Some where over"}, {1, "Once upon a"}})) {
+            GGML_ASSERT(false && "Failed to decode token");
+        }
+
+        llama_sampler_free(backend_sampler_chain_0);
+        llama_sampler_free(backend_sampler_chain_1);
+
+        // Verify sequence 0
+        {
+            int32_t batch_idx = test_ctx.idx_for_seq(0);
+            int n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx, batch_idx);
+            GGML_ASSERT(n_logits == test_ctx.n_vocab);
+
+            // Sample from sequence 0 using CPU sampler
+            struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();
+            struct llama_sampler * chain = llama_sampler_chain_init(chain_params);
+            llama_sampler_chain_add(chain, llama_sampler_init_dist(18));
+
+            llama_token token = llama_sampler_sample(chain, test_ctx.ctx, batch_idx);
+            const std::string token_str = test_ctx.token_to_piece(token, false);
+            printf("Sequence 0 sampled token id:%d, string: '%s'\n", token, token_str.c_str());
+            GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
+
+            llama_sampler_free(chain);
+        }
+
+        // Verify sequence 1
+        {
+            int32_t batch_idx = test_ctx.idx_for_seq(1);
+
+            // Sample from sequence 1 using CPU sampler
+            struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();
+            struct llama_sampler * chain = llama_sampler_chain_init(chain_params);
+            llama_sampler_chain_add(chain, llama_sampler_init_dist(18));
+
+            llama_token token = llama_sampler_sample(chain, test_ctx.ctx, batch_idx);
+            const std::string token_str = test_ctx.token_to_piece(token, false);
+            printf("Sequence 1 sampled token id:%d, string: '%s'\n", token, token_str.c_str());
+            GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
+
+            llama_sampler_free(chain);
+        }
+    }
+
+    // lambda for testing non-positive temperature values.
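+    // For temp <= 0 the backend temp sampler falls back to argmax, so exactly
+    // one sampled logit should be reported for the sequence.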
+ auto test_argmax_temp = [&](float temp) { + printf("\nTesting temperature = %.1f\n", temp); + + test_ctx.reset(); + + int seq_id = 0; + struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params(); + struct llama_sampler * backend_sampler_chain = llama_sampler_chain_init(backend_chain_params); + llama_sampler_chain_add(backend_sampler_chain, llama_sampler_init_temp(temp)); + + std::vector backend_sampler_configs = { + { seq_id, backend_sampler_chain }, + }; + + if (!test_ctx.setup(model_path, backend_sampler_configs)) { + return; + } + + if (!test_ctx.decode({{seq_id, "Once"}})) { + GGML_ASSERT(false && "Failed to decode token"); + } + + int32_t batch_idx = test_ctx.idx_for_seq(seq_id); + + uint32_t n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx, batch_idx); + GGML_ASSERT(n_logits == 1); + + llama_sampler_free(backend_sampler_chain); + }; + + test_argmax_temp(0.0f); + test_argmax_temp(-1.0f); + + printf("backend temp sampling test PASSED\n"); + +} + +static void test_backend_temp_ext_sampling(const char * model_path) { + test_model_context test_ctx; + + { + int seq_id = 0; + const float temp = 0.8f; + const float delta = 0.5f; + const float exponent = 1.5f; + struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params(); + struct llama_sampler * backend_sampler_chain = llama_sampler_chain_init(backend_chain_params); + llama_sampler_chain_add(backend_sampler_chain, llama_sampler_init_temp_ext(temp, delta, exponent)); + + std::vector backend_sampler_configs = { + { seq_id, backend_sampler_chain }, + }; + + if (!test_ctx.setup(model_path, backend_sampler_configs)) { + return; + } + + if (!test_ctx.decode({{seq_id, "Once upon a"}})) { + GGML_ASSERT(false && "Failed to decode token"); + } + + // Verify sequence 0 + { + int32_t batch_idx = test_ctx.idx_for_seq(seq_id); + int n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx, batch_idx); + GGML_ASSERT(n_logits == test_ctx.n_vocab); + } + + llama_sampler_free(backend_sampler_chain); + } + + test_ctx.reset(); + + // lambda to testing non-positive temp/delta/exponent values. 
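+    // temp <= 0 collapses the output to a single argmax logit, while delta <= 0
+    // with a positive temp keeps the full vocabulary and only applies plain
+    // temperature scaling.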
+ auto test_argmax_temp = [&](float temp, float delta, float exponent) { + printf("\nTesting temperature = %.1f, delta = %1.f, exponent = %1.f\n", temp, delta, exponent); + + test_ctx.reset(); + + int seq_id = 0; + struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params(); + struct llama_sampler * backend_sampler_chain = llama_sampler_chain_init(backend_chain_params); + llama_sampler_chain_add(backend_sampler_chain, llama_sampler_init_temp_ext(temp, delta, exponent)); + + std::vector backend_sampler_configs = { + { seq_id, backend_sampler_chain }, + }; + + if (!test_ctx.setup(model_path, backend_sampler_configs)) { + return; + } + + if (!test_ctx.decode({{seq_id, "Once"}})) { + GGML_ASSERT(false && "Failed to decode token"); + } + + int32_t batch_idx = test_ctx.idx_for_seq(seq_id); + + uint32_t n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx, batch_idx); + + if (temp <= 0.0f && delta >= 0.0f) { + GGML_ASSERT(n_logits == 1); + } else { + GGML_ASSERT(n_logits == (uint32_t) test_ctx.n_vocab); + } + + llama_sampler_free(backend_sampler_chain); + }; + + test_argmax_temp(0.0f, 0.3f, 1.0f); // Greedy (temp=0) + test_argmax_temp(-1.0f, 0.3f, 2.0f); // Greedy (temp<0) + test_argmax_temp(0.8f, 0.0f, 2.0f); // Temperature scaling + + printf("backend temp_ext sampling test PASSED\n"); + +} + +static void test_backend_min_p_sampling(const char * model_path) { + test_model_context test_ctx; + + const int seq_id = 0; + const float p = 0.1; + struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params(); + struct llama_sampler * backend_sampler_chain = llama_sampler_chain_init(backend_chain_params); + llama_sampler_chain_add(backend_sampler_chain, llama_sampler_init_min_p(p, 0)); + std::vector backend_sampler_configs = {{ seq_id, backend_sampler_chain }}; + + if (!test_ctx.setup(model_path, backend_sampler_configs)) { + return; + } + + if (!test_ctx.decode({{seq_id, "Hello"}})) { + GGML_ASSERT(false && "Failed to decode token"); + } + + int32_t batch_idx = test_ctx.idx_for_seq(seq_id); + + float * logits = llama_get_sampled_logits_ith(test_ctx.ctx, batch_idx); + uint32_t n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx, batch_idx); + + // Print the logits that are above the min-p threshold + std::vector filtered_logits; + for (size_t i = 0; i < n_logits; ++i) { + if (logits[i] > -1e9f) { + filtered_logits.push_back(logits[i]); + //printf("min_p logit[%zu] = %.6f\n", i, logits[i]); + } + } + GGML_ASSERT(filtered_logits.size() < (size_t) test_ctx.n_vocab); + + // Sample using CPU sampler for verification to inspect they are reasonable + struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params(); + struct llama_sampler * chain = llama_sampler_chain_init(chain_params); + llama_sampler_chain_add(chain, llama_sampler_init_dist(88)); + + llama_token token = llama_sampler_sample(chain, test_ctx.ctx, batch_idx); + const std::string token_str = test_ctx.token_to_piece(token, false); + printf("min-p cpu sampled token id:%d, string: '%s'\n", token, token_str.c_str()); + GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab); + + // Decode and sampler 10 more tokens + for (int i = 0; i < 10; i++) { + int32_t loop_idx = test_ctx.idx_for_seq(seq_id); + llama_token token = llama_sampler_sample(chain, test_ctx.ctx, loop_idx); + printf("min-p gen step %d: token id :%5.d, string: %s\n", i, token, test_ctx.token_to_piece(token, false).c_str()); + if (!test_ctx.decode_token(token, 0)) { + GGML_ASSERT(false && 
"Failed to decode token"); + } + } + + printf("min-p sampling test PASSED\n"); + + llama_sampler_free(backend_sampler_chain); + llama_sampler_free(chain); +} + +static void test_backend_top_p_sampling(const char * model_path) { + test_model_context test_ctx; + + const int seq_id = 0; + const float p = 0.9; + struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params(); + struct llama_sampler * backend_sampler_chain = llama_sampler_chain_init(backend_chain_params); + llama_sampler_chain_add(backend_sampler_chain, llama_sampler_init_top_p(p, 0)); + std::vector backend_sampler_configs = {{ seq_id, backend_sampler_chain }}; + + if (!test_ctx.setup(model_path, backend_sampler_configs)) { + return; + } + + if (!test_ctx.decode({{seq_id, "Hello"}})) { + return; + } + + int32_t batch_idx = test_ctx.idx_for_seq(seq_id); + + float * logits = llama_get_sampled_logits_ith(test_ctx.ctx, batch_idx); + uint32_t n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx, batch_idx); + + // Print the logits that are above the min-p threshold + std::vector filtered_logits; + for (size_t i = 0; i < n_logits; ++i) { + if (logits[i] > -1e9f) { + filtered_logits.push_back(logits[i]); + } + } + GGML_ASSERT(filtered_logits.size() < (size_t) test_ctx.n_vocab); + GGML_ASSERT(filtered_logits.size() > 0); + + // Sample using CPU sampler for verification to inspect they are reasonable + struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params(); + struct llama_sampler * chain = llama_sampler_chain_init(chain_params); + llama_sampler_chain_add(chain, llama_sampler_init_dist(88)); + + llama_token token = llama_sampler_sample(chain, test_ctx.ctx, batch_idx); + const std::string token_str = test_ctx.token_to_piece(token, false); + printf("top-p cpu sampled token id:%d, string: '%s'\n", token, token_str.c_str()); + GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab); + + // Decode and sampler 10 more tokens + for (int i = 0; i < 10; i++) { + int32_t loop_idx = test_ctx.idx_for_seq(seq_id); + llama_token token = llama_sampler_sample(chain, test_ctx.ctx, loop_idx); + printf("top-p gen step %d: token id :%5.d, string: %s\n", i, token, test_ctx.token_to_piece(token, false).c_str()); + test_ctx.decode_token(token, 0); + } + + printf("top-p sampling test PASSED\n"); + + llama_sampler_free(backend_sampler_chain); + llama_sampler_free(chain); +} + +static void test_backend_multi_sequence_sampling(const char * model_path) { + test_model_context test_ctx; + + struct llama_sampler_chain_params chain_params_0 = llama_sampler_chain_default_params(); + struct llama_sampler * sampler_chain_0 = llama_sampler_chain_init(chain_params_0); + llama_sampler_chain_add(sampler_chain_0, llama_sampler_init_greedy()); + + struct llama_sampler_chain_params chain_params_1 = llama_sampler_chain_default_params(); + struct llama_sampler * sampler_chain_1 = llama_sampler_chain_init(chain_params_1); + llama_sampler_chain_add(sampler_chain_1, llama_sampler_init_temp(0.8f)); + llama_sampler_chain_add(sampler_chain_1, llama_sampler_init_greedy()); + + std::vector backend_sampler_configs = { + { 0, sampler_chain_0 }, + { 1, sampler_chain_1 } + }; + + if (!test_ctx.setup(model_path, backend_sampler_configs)) { + return; + } + + std::map prompts = { + {0, "Hello"}, + {1, "Some"} + }; + + if (!test_ctx.decode(prompts)) { + GGML_ASSERT(false && "Failed to decode token"); + } + + // Verfiy sequence 0 + { + int32_t batch_idx = test_ctx.idx_for_seq(0); + llama_token token = 
llama_get_sampled_token_ith(test_ctx.ctx, batch_idx); + const std::string token_str = test_ctx.token_to_piece(token, false); + printf("Seq 0 sampled token id=%d, string='%s'\n", token, token_str.c_str()); + GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab); + } + + // Verify sequence 1 + { + int32_t batch_idx= test_ctx.idx_for_seq(1); + llama_token token = llama_get_sampled_token_ith(test_ctx.ctx, batch_idx); + const std::string token_str = test_ctx.token_to_piece(token, false); + printf("Seq 1 sampled token id=%d, string='%s'\n", token, token_str.c_str()); + GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab); + } + + // Generate tokens for each sequence + printf("\nMulti-sequence generation:\n"); + for (int step = 0; step < 4; step++) { + std::map tokens; + + for (llama_seq_id seq_id : {0, 1}) { + int32_t idx = test_ctx.idx_for_seq(seq_id); + llama_token token = llama_get_sampled_token_ith(test_ctx.ctx, idx); + const std::string token_str = test_ctx.token_to_piece(token, false); + printf(" Seq %d, step %d: token id=%d, string='%s'\n", seq_id, step, token, token_str.c_str()); + tokens[seq_id] = token; + } + + // Decode all tokens in a single batch + if (!test_ctx.decode_tokens(tokens)) { + GGML_ASSERT(false && "Failed to decode token"); + } + } + + llama_sampler_free(sampler_chain_0); + llama_sampler_free(sampler_chain_1); + + printf("backend multi-sequence sampling test PASSED\n"); +} + +static void test_backend_dist_sampling(const char * model_path) { + test_model_context test_ctx; + + const int seq_id = 189; + const int32_t seed = 88; + struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params(); + struct llama_sampler * backend_sampler_chain = llama_sampler_chain_init(backend_chain_params); + llama_sampler_chain_add(backend_sampler_chain, llama_sampler_init_dist(seed)); + std::vector backend_sampler_configs = {{ seq_id, backend_sampler_chain }}; + + if (!test_ctx.setup(model_path, backend_sampler_configs)) { + return; + } + + if (!test_ctx.decode({{seq_id, "Some"}})) { + GGML_ASSERT(false && "Failed to decode token"); + } + + int32_t batch_idx = test_ctx.idx_for_seq(seq_id); + llama_token token = llama_get_sampled_token_ith(test_ctx.ctx, batch_idx); + printf("dist sampled id:%d, string:'%s'\n", token, test_ctx.token_to_piece(token, false).c_str()); + GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab); + //GGML_ASSERT(llama_get_sampled_logits_ith(test_ctx.ctx, batch_idx) == nullptr); + + token = llama_get_sampled_token_ith(test_ctx.ctx, -1); + printf("dist sampled id:%d, string:'%s'\n", token, test_ctx.token_to_piece(token, false).c_str()); + GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab); + + llama_sampler_free(backend_sampler_chain); + + printf("backend dist sampling test PASSED\n"); +} + +static void test_backend_dist_sampling_and_cpu(const char * model_path) { + test_model_context test_ctx; + + const int seq_id = 0; + const int32_t seed = 88; + struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params(); + struct llama_sampler * backend_sampler_chain = llama_sampler_chain_init(backend_chain_params); + llama_sampler_chain_add(backend_sampler_chain, llama_sampler_init_dist(seed)); + std::vector backend_sampler_configs = {{ seq_id, backend_sampler_chain }}; + + if (!test_ctx.setup(model_path, backend_sampler_configs)) { + return; + } + + if (!test_ctx.decode({{seq_id, "Some"}})) { + GGML_ASSERT(false && "Failed to decode token"); + } + + int32_t batch_idx = test_ctx.idx_for_seq(seq_id); + + // Sample using CPU 
sampler
+    struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();
+    struct llama_sampler * chain = llama_sampler_chain_init(chain_params);
+    llama_sampler_chain_add(chain, llama_sampler_init_dist(18));
+
+    llama_token backend_token = llama_get_sampled_token_ith(test_ctx.ctx, batch_idx);
+    llama_token cpu_token = llama_sampler_sample(chain, test_ctx.ctx, batch_idx);
+    printf("dist & cpu sampled id:%d, string:'%s'\n", cpu_token, test_ctx.token_to_piece(cpu_token, false).c_str());
+    GGML_ASSERT(backend_token == cpu_token);
+
+    llama_sampler_free(backend_sampler_chain);
+    llama_sampler_free(chain);
+
+    printf("backend dist & cpu sampling test PASSED\n");
+}
+
+static void test_backend_logit_bias_sampling(const char * model_path) {
+    test_model_context test_ctx;
+
+    // Calling load_model to ensure vocab is loaded and can be accessed
+    if (!test_ctx.load_model(model_path)) {
+        return;
+    }
+
+    const int seq_id = 0;
+
+    // Create the logit biases vector.
+    std::vector logit_bias;
+
+    // Get the token for the piece "World".
+    const std::string piece = "World";
+    std::vector tokens(16);
+    llama_tokenize(test_ctx.vocab, piece.c_str(), piece.size(), tokens.data(), tokens.size(), false, false);
+    llama_token bias_token = tokens[0];
+    logit_bias.push_back({ bias_token, +100.0f });
+    printf("biasing token piece '%s' -> token id %d\n", piece.c_str(), bias_token);
+
+    struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params();
+    struct llama_sampler * backend_sampler_chain = llama_sampler_chain_init(backend_chain_params);
+    llama_sampler_chain_add(backend_sampler_chain, llama_sampler_init_logit_bias(
+        llama_vocab_n_tokens(test_ctx.vocab),
+        logit_bias.size(),
+        logit_bias.data()));
+    llama_sampler_chain_add(backend_sampler_chain, llama_sampler_init_dist(88));
+
+    std::vector backend_sampler_configs = {
+        { seq_id, backend_sampler_chain },
+    };
+
+    if (!test_ctx.setup(model_path, backend_sampler_configs)) {
+        return;
+    }
+
+    if (!test_ctx.decode({{seq_id, "Hello"}})) {
+        GGML_ASSERT(false && "Failed to decode token");
+    }
+
+    llama_token backend_token = llama_get_sampled_token_ith(test_ctx.ctx, test_ctx.idx_for_seq(seq_id));
+    const std::string backend_token_str = test_ctx.token_to_piece(backend_token, false);
+    printf("logit bias sampled token = %d, string='%s'\n", backend_token, backend_token_str.c_str());
+    GGML_ASSERT(backend_token == bias_token);
+
+    printf("backend logit bias sampling test PASSED\n");
+
+    llama_sampler_free(backend_sampler_chain);
+}
+
+// This test verifies that it is possible to have two different backend samplers:
+// one sequence uses the backend dist sampler, while the other uses the backend
+// top-k sampler and leaves the final token selection to a CPU sampler.
+static void test_backend_mixed_sampling(const char * model_path) {
+    test_model_context test_ctx;
+
+    struct llama_sampler_chain_params chain_params_0 = llama_sampler_chain_default_params();
+    struct llama_sampler * sampler_chain_0 = llama_sampler_chain_init(chain_params_0);
+    llama_sampler_chain_add(sampler_chain_0, llama_sampler_init_dist(88));
+
+    int k = 40;
+    struct llama_sampler_chain_params chain_params_1 = llama_sampler_chain_default_params();
+    struct llama_sampler * sampler_chain_1 = llama_sampler_chain_init(chain_params_1);
+    llama_sampler_chain_add(sampler_chain_1, llama_sampler_init_top_k(k));
+
+    std::vector backend_sampler_configs = {
+        { 0, sampler_chain_0 },
+        { 1, sampler_chain_1 }
+    };
+
+    if (!test_ctx.setup(model_path, backend_sampler_configs)) {
+        return;
+    }
+
+    std::map prompts = {
+        {0, "Hello"},
+        {1, "Some"}
+    };
+
+    if (!test_ctx.decode(prompts)) {
+        GGML_ASSERT(false && "Failed to decode token");
+    }
+
+    // Verify sequence 0 that used the dist backend sampler.
+    {
+        int32_t batch_idx = test_ctx.idx_for_seq(0);
+        llama_token token = llama_get_sampled_token_ith(test_ctx.ctx, batch_idx);
+        const std::string token_str = test_ctx.token_to_piece(token, false);
+        printf("sampled token id=%d, string='%s'\n", token, token_str.c_str());
+        GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
+        //GGML_ASSERT(llama_get_sampled_logits_ith(test_ctx.ctx, batch_idx) == nullptr);
+        //GGML_ASSERT(llama_get_sampled_logits_count_ith(test_ctx.ctx, batch_idx) == 0);
+    }
+
+    // Verify sequence 1 that used the top-k backend sampler.
+    {
+        int32_t batch_idx = test_ctx.idx_for_seq(1);
+        float * logits = llama_get_sampled_logits_ith(test_ctx.ctx, batch_idx);
+        GGML_ASSERT(logits != nullptr);
+        size_t n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx, batch_idx);
+        GGML_ASSERT(n_logits == (size_t) k);
+        GGML_ASSERT(llama_get_sampled_token_ith(test_ctx.ctx, batch_idx) == LLAMA_TOKEN_NULL);
+    }
+
+    llama_sampler_free(sampler_chain_0);
+    llama_sampler_free(sampler_chain_1);
+
+    printf("backend mixed sampling test PASSED\n");
+}
+
+static void test_backend_set_sampler(const char * model_path) {
+    test_model_context test_ctx;
+
+    const int32_t seed = 88;
+    const int seq_id = 0;
+    struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params();
+    struct llama_sampler * backend_sampler_chain = llama_sampler_chain_init(backend_chain_params);
+    llama_sampler_chain_add(backend_sampler_chain, llama_sampler_init_dist(seed));
+    std::vector backend_sampler_configs = {{ seq_id, backend_sampler_chain }};
+
+    if (!test_ctx.setup(model_path, backend_sampler_configs)) {
+        return;
+    }
+
+    if (!test_ctx.decode({{seq_id, "Hello"}})) {
+        GGML_ASSERT(false && "Failed to decode token");
+    }
+
+    int32_t batch_idx = test_ctx.idx_for_seq(seq_id);
+
+    // Sample using the backend sampler configured above
+    llama_token backend_token = llama_get_sampled_token_ith(test_ctx.ctx, batch_idx);
+    const std::string backend_token_str = test_ctx.token_to_piece(backend_token, false);
+    printf("dist sampled token = %d, string='%s'\n", backend_token, backend_token_str.c_str());
+
+    // Now clear the backend sampler for this sequence.
+ llama_set_sampler(test_ctx.ctx, seq_id, nullptr); + printf("Cleared backend sampler for seq_id %d\n", seq_id); + + // Sample using CPU sampler + struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params(); + struct llama_sampler * chain = llama_sampler_chain_init(chain_params); + llama_sampler_chain_add(chain, llama_sampler_init_dist(18)); + + std::map tokens = { { seq_id, backend_token}, }; + if (!test_ctx.decode_tokens(tokens)) { + GGML_ASSERT(false && "Failed to decode token"); + } + + // Should not have any sampled token or probs after clearing the backend sampler. + const int32_t idx = test_ctx.idx_for_seq(seq_id); + GGML_ASSERT(llama_get_sampled_token_ith(test_ctx.ctx, idx) == LLAMA_TOKEN_NULL); + GGML_ASSERT(llama_get_sampled_probs_ith(test_ctx.ctx, idx) == nullptr); + + // Sample the token using the CPU sampler chain. + llama_token token2 = llama_sampler_sample(chain, test_ctx.ctx, seq_id); + const std::string token2_str = test_ctx.token_to_piece(token2, false); + printf("CPU sampled token after clearing backend sampler: id=%d, string='%s'\n", token2, token2_str.c_str()); + std::map tokens2 = { { seq_id, token2}, }; + + // Set a new backend sampler for the sequence. + struct llama_sampler_chain_params new_backend_chain_params = llama_sampler_chain_default_params(); + struct llama_sampler * new_backend_sampler_chain = llama_sampler_chain_init(new_backend_chain_params); + llama_sampler_chain_add(new_backend_sampler_chain, llama_sampler_init_top_k(20)); + llama_sampler_chain_add(new_backend_sampler_chain, llama_sampler_init_dist(seed)); + llama_set_sampler(test_ctx.ctx, seq_id, new_backend_sampler_chain); + + if (!test_ctx.decode_tokens(tokens2)) { + GGML_ASSERT(false && "Failed to decode token"); + } + + llama_token new_backend_token = llama_get_sampled_token_ith(test_ctx.ctx, test_ctx.idx_for_seq(seq_id)); + const std::string new_backend_token_str = test_ctx.token_to_piece(new_backend_token, false); + printf("dist sampled token = %d, string='%s'\n", new_backend_token, new_backend_token_str.c_str()); + + llama_sampler_free(backend_sampler_chain); + llama_sampler_free(chain); + llama_sampler_free(new_backend_sampler_chain); + + printf("backend set sampler test PASSED\n"); +} + +static void test_backend_cpu_mixed_batch(const char * model_path) { + test_model_context test_ctx; + + // Sequence 0 uses backend sampling + struct llama_sampler_chain_params chain_params_0 = llama_sampler_chain_default_params(); + struct llama_sampler * sampler_chain_0 = llama_sampler_chain_init(chain_params_0); + llama_sampler_chain_add(sampler_chain_0, llama_sampler_init_dist(88)); + + std::vector backend_sampler_configs = { + { 0, sampler_chain_0 }, + }; + + // We need 2 sequences: seq 0 with backend sampling, seq 1 with CPU sampling + if (!test_ctx.setup(model_path, backend_sampler_configs, 2)) { + return; + } + + std::map prompts = { + {0, "Hello"}, // Will use backend sampling + {1, "Some"} // Will use CPU sampling + }; + + if (!test_ctx.decode(prompts)) { + GGML_ASSERT(false && "Failed to decode token"); + } + + // Verify sequence 0 (backend sampled) + { + int32_t batch_idx = test_ctx.idx_for_seq(0); + llama_token token = llama_get_sampled_token_ith(test_ctx.ctx, batch_idx); + const std::string token_str = test_ctx.token_to_piece(token, false); + printf("Seq 0 (backend) sampled token id=%d, string='%s'\n", token, token_str.c_str()); + GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab); + } + + // Verify sequence 1 (CPU sampled) + { + int32_t batch_idx = 
test_ctx.idx_for_seq(1);
+
+        llama_token backend_token = llama_get_sampled_token_ith(test_ctx.ctx, batch_idx);
+        GGML_ASSERT(backend_token == LLAMA_TOKEN_NULL);
+
+        struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();
+        struct llama_sampler * chain = llama_sampler_chain_init(chain_params);
+        llama_sampler_chain_add(chain, llama_sampler_init_greedy());
+
+        llama_token token = llama_sampler_sample(chain, test_ctx.ctx, batch_idx);
+        const std::string token_str = test_ctx.token_to_piece(token, false);
+        printf("Seq 1 (CPU) sampled token id=%d, string='%s'\n", token, token_str.c_str());
+        GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
+        llama_sampler_free(chain);
+    }
+
+    // Clear/remove the backend sampler, and sample again
+    {
+        // clear the backend sampler for seq 0 so that there are no backend
+        // samplers.
+        llama_set_sampler(test_ctx.ctx, 0, nullptr);
+
+        // Create a CPU sampler and verify we can sample from it.
+        struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();
+        struct llama_sampler * chain = llama_sampler_chain_init(chain_params);
+        llama_sampler_chain_add(chain, llama_sampler_init_greedy());
+
+        int32_t batch_idx = test_ctx.idx_for_seq(1);
+        llama_token token = llama_sampler_sample(chain, test_ctx.ctx, batch_idx);
+        if (!test_ctx.decode_token(token, 1)) {
+            GGML_ASSERT(false && "Failed to decode token");
+        }
+
+        llama_sampler_free(chain);
+    }
+
+    // Set a backend sampler so that we can verify that it can be reset
+    {
+        struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();
+        struct llama_sampler * sampler_chain = llama_sampler_chain_init(chain_params);
+        llama_sampler_chain_add(sampler_chain, llama_sampler_init_dist(88));
+
+        llama_set_sampler(test_ctx.ctx, 0, sampler_chain);
+
+        if (!test_ctx.decode_token(3834, 0)) {
+            GGML_ASSERT(false && "Failed to decode token");
+        }
+
+        int32_t batch_idx = test_ctx.idx_for_seq(0);
+        llama_token token = llama_get_sampled_token_ith(test_ctx.ctx, batch_idx);
+        const std::string token_str = test_ctx.token_to_piece(token, false);
+        printf("re-added backend sampled token id=%d, string='%s'\n", token, token_str.c_str());
+        GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
+
+        llama_sampler_free(sampler_chain);
+    }
+
+    llama_sampler_free(sampler_chain_0);
+
+    printf("backend-cpu mixed batch test PASSED\n");
+}
+
+static void test_backend_max_outputs(const char * model_path) {
+    test_model_context test_ctx;
+
+    const int seq_id = 0;
+    const int32_t seed = 88;
+    llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params();
+    llama_sampler * backend_sampler_chain = llama_sampler_chain_init(backend_chain_params);
+    llama_sampler_chain_add(backend_sampler_chain, llama_sampler_init_dist(seed));
+    std::vector backend_sampler_configs = {{ seq_id, backend_sampler_chain }};
+
+    if (!test_ctx.setup(model_path, backend_sampler_configs)) {
+        return;
+    }
+
+    llama_batch batch = llama_batch_init(512, 0, 1);
+    std::string prompt = "Hello";
+
+    std::vector tokens;
+    tokens.push_back(llama_vocab_bos(test_ctx.vocab));
+
+    std::vector prompt_tokens(32);
+    int n_tokens = llama_tokenize(test_ctx.vocab, prompt.c_str(), prompt.length(),
+                                  prompt_tokens.data(), prompt_tokens.size(),
+                                  false, false);
+    for (int i = 0; i < n_tokens; i++) {
+        tokens.push_back(prompt_tokens[i]);
+    }
+
+    for (size_t i = 0; i < tokens.size(); i++) {
+        // set all tokens as output to trigger error
+        common_batch_add(batch, tokens[i], i, { seq_id }, true);
+    }
+
>> test_max_outputs">
+    printf(">>> test_max_outputs expected error start:\n");
+    const int ret = llama_decode(test_ctx.ctx, batch);
+    GGML_ASSERT(ret != 0 && "llama_decode should not succeed with multiple outputs per sequence");
+    printf("<<< test_max_outputs expected error end.\n");
+    llama_batch_free(batch);
+
+    llama_sampler_free(backend_sampler_chain);
+    printf("backend max outputs test PASSED\n");
+}
+
+struct backend_test_case {
+    const char * name;
+    void (*fn)(const char *);
+    bool enabled_by_default;
+};
+
+static const backend_test_case BACKEND_TESTS[] = {
+    { "greedy", test_backend_greedy_sampling, true },
+    { "logit_bias", test_backend_logit_bias_sampling, true },
+    { "temp", test_backend_temp_sampling, true },
+    { "temp_ext", test_backend_temp_ext_sampling, true },
+    { "top_k", test_backend_top_k_sampling, true },
+    { "multi_sequence", test_backend_multi_sequence_sampling, true },
+    { "dist", test_backend_dist_sampling, true },
+    { "dist_and_cpu", test_backend_dist_sampling_and_cpu, true },
+    { "set_sampler", test_backend_set_sampler, true },
+    { "max_outputs", test_backend_max_outputs, true },
+    { "mixed", test_backend_mixed_sampling, true },
+    { "min_p", test_backend_min_p_sampling, true },
+    { "cpu_mixed", test_backend_cpu_mixed_batch, true },
+    { "top_p", test_backend_top_p_sampling, true },
+};
+
+struct backend_cli_args {
+    const char * model = nullptr;
+    const char * test = nullptr;
+};
+
+static backend_cli_args parse_backend_cli(int argc, char ** argv) {
+    backend_cli_args out;
+
+    for (int i = 1; i < argc; ++i) {
+        const char * arg = argv[i];
+
+        if (std::strcmp(arg, "--test") == 0) {
+            if (i + 1 >= argc) {
+                fprintf(stderr, "--test expects a value\n");
+                exit(EXIT_FAILURE);
+            }
+            out.test = argv[++i];
+            continue;
+        }
+        if (std::strncmp(arg, "--test=", 7) == 0) {
+            out.test = arg + 7;
+            continue;
+        }
+        if (std::strcmp(arg, "--model") == 0) {
+            if (i + 1 >= argc) {
+                fprintf(stderr, "--model expects a value\n");
+                exit(EXIT_FAILURE);
+            }
+            out.model = argv[++i];
+            continue;
+        }
+        if (std::strncmp(arg, "--model=", 8) == 0) {
+            out.model = arg + 8;
+            continue;
+        }
+        if (!out.model) {
+            out.model = arg;
+            continue;
+        }
+
+        fprintf(stderr, "Unexpected argument: %s\n", arg);
+        exit(EXIT_FAILURE);
+    }
+
+    return out;
+}
+
+static std::vector collect_tests_to_run(const char * requested) {
+    std::vector selected;
+
+    if (requested != nullptr) {
+        for (const auto & test : BACKEND_TESTS) {
+            if (std::strcmp(test.name, requested) == 0) {
+                selected.push_back(&test);
+                break;
+            }
+        }
+        if (selected.empty()) {
+            fprintf(stderr, "Unknown test '%s'. Available tests:\n", requested);
+            for (const auto & test : BACKEND_TESTS) {
+                fprintf(stderr, "  %s\n", test.name);
+            }
+            exit(EXIT_FAILURE);
+        }
+    } else {
+        for (const auto & test : BACKEND_TESTS) {
+            if (test.enabled_by_default) {
+                selected.push_back(&test);
+            }
+        }
+    }
+
+    if (selected.empty()) {
+        fprintf(stderr, "No backend sampling tests selected. Use --test= to pick one.\n");
+    }
+
+    return selected;
+}
+
+static void run_tests(const std::vector & tests, const char * model_path) {
+    for (const auto * test : tests) {
+        fprintf(stderr, "\n=== %s ===\n", test->name);
+        test->fn(model_path);
+    }
+}
+
+
+int main(int argc, char * argv[]) {
+    const backend_cli_args args = parse_backend_cli(argc, argv);
+
+    std::array model_argv { argv[0], const_cast(args.model) };
+    const int model_argc = args.model ?
2 : 1; + char * model_path = get_model_or_exit(model_argc, model_argv.data()); + + auto * file = fopen(model_path, "r"); + if (file == nullptr) { + fprintf(stderr, "no model at '%s' found\n", model_path); + return EXIT_FAILURE; + } + + fprintf(stderr, "using '%s'\n", model_path); + fclose(file); + + ggml_time_init(); + + const std::vector tests = collect_tests_to_run(args.test); + if (!tests.empty()) { + run_tests(tests, model_path); + } + + return 0; +} diff --git a/tools/server/public/index.html.gz b/tools/server/public/index.html.gz index 2ff90e800a..49891e0b2e 100644 Binary files a/tools/server/public/index.html.gz and b/tools/server/public/index.html.gz differ diff --git a/tools/server/server-common.cpp b/tools/server/server-common.cpp index ab6b3aa7ce..af21e3d45c 100644 --- a/tools/server/server-common.cpp +++ b/tools/server/server-common.cpp @@ -1397,16 +1397,21 @@ json format_response_rerank( std::vector get_token_probabilities(llama_context * ctx, int idx) { std::vector cur; + const auto * logits = llama_get_logits_ith(ctx, idx); + const llama_token * sampled_ids = llama_get_sampled_candidates_ith(ctx, idx); - const llama_model * model = llama_get_model(ctx); - const llama_vocab * vocab = llama_model_get_vocab(model); + const int n_logits = llama_get_sampled_logits_count_ith(ctx, idx); - const int n_vocab = llama_vocab_n_tokens(vocab); - - cur.resize(n_vocab); - for (llama_token token_id = 0; token_id < n_vocab; token_id++) { - cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f}; + cur.resize(n_logits); + if (sampled_ids) { + for (int i = 0; i < n_logits; i++) { + cur[i] = llama_token_data{sampled_ids[i], logits[i], 0.0f}; + } + } else { + for (llama_token token_id = 0; token_id < n_logits; token_id++) { + cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f}; + } } // sort tokens by logits diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index 90898b5ec4..0faf5df159 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -1056,6 +1056,25 @@ struct server_context_impl { return false; } + const bool need_logits = task.params.sampling.n_probs > 0; + + bool backend_sampling = true; + + backend_sampling &= task.params.sampling.backend_sampling; + + // TODO: speculative decoding requires multiple samples per batch - not supported yet + backend_sampling &= !(slot.ctx_dft && task.params.speculative.n_max > 0); + + // TODO: getting post/pre sampling logits is not yet supported with backend sampling + backend_sampling &= !need_logits; + + // TODO: tmp until backend sampling is fully implemented + if (backend_sampling) { + llama_set_sampler(ctx, slot.id, common_sampler_get(slot.smpl.get())); + } else { + llama_set_sampler(ctx, slot.id, nullptr); + } + SLT_INF(slot, "sampler chain: %s\n", common_sampler_print(slot.smpl.get()).c_str()); } diff --git a/tools/server/server-task.cpp b/tools/server/server-task.cpp index 360826062b..337895a5ef 100644 --- a/tools/server/server-task.cpp +++ b/tools/server/server-task.cpp @@ -78,6 +78,7 @@ json task_params::to_json(bool only_metrics) const { {"speculative.p_min", speculative.p_min}, {"timings_per_token", timings_per_token}, {"post_sampling_probs", post_sampling_probs}, + {"backend_sampling", sampling.backend_sampling}, {"lora", lora}, }; } @@ -136,6 +137,7 @@ json task_params::to_json(bool only_metrics) const { {"speculative.p_min", speculative.p_min}, {"timings_per_token", timings_per_token}, {"post_sampling_probs", post_sampling_probs}, + {"backend_sampling", 
sampling.backend_sampling}, {"lora", lora}, }; } @@ -206,8 +208,12 @@ task_params server_task::params_from_json_cmpl( params.sampling.seed = json_value(data, "seed", defaults.sampling.seed); params.sampling.n_probs = json_value(data, "n_probs", defaults.sampling.n_probs); params.sampling.min_keep = json_value(data, "min_keep", defaults.sampling.min_keep); + params.sampling.backend_sampling = json_value(data, "backend_sampling", defaults.sampling.backend_sampling); params.post_sampling_probs = json_value(data, "post_sampling_probs", defaults.post_sampling_probs); + printf("params.sampling.backend_sampling = %d\n", params.sampling.backend_sampling); + printf("defaults.sampling.backend_sampling = %d\n", defaults.sampling.backend_sampling); + params.speculative.n_min = json_value(data, "speculative.n_min", defaults.speculative.n_min); params.speculative.n_max = json_value(data, "speculative.n_max", defaults.speculative.n_max); params.speculative.p_min = json_value(data, "speculative.p_min", defaults.speculative.p_min); diff --git a/tools/server/tests/unit/test_chat_completion.py b/tools/server/tests/unit/test_chat_completion.py index 64f3158b98..08b5265d48 100644 --- a/tools/server/tests/unit/test_chat_completion.py +++ b/tools/server/tests/unit/test_chat_completion.py @@ -13,16 +13,16 @@ def create_server(): @pytest.mark.parametrize( "model,system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,finish_reason,jinja,chat_template", [ - (None, "Book", "Hey", 8, "But she couldn't", 69, 8, "length", False, None), - (None, "Book", "Hey", 8, "But she couldn't", 69, 8, "length", True, None), - (None, "Book", "What is the best book", 8, "(Suddenly)+|\\{ \" Sarax.", 77, 8, "length", False, None), - (None, "Book", "What is the best book", 8, "(Suddenly)+|\\{ \" Sarax.", 77, 8, "length", True, None), - (None, "Book", "What is the best book", 8, "(Suddenly)+|\\{ \" Sarax.", 77, 8, "length", True, 'chatml'), - (None, "Book", "What is the best book", 8, "^ blue", 23, 8, "length", True, "This is not a chat template, it is"), - ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 128, "length", False, None), - ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 128, "length", True, None), - (None, "Book", [{"type": "text", "text": "What is"}, {"type": "text", "text": "the best book"}], 8, "Whillicter", 79, 8, "length", False, None), - (None, "Book", [{"type": "text", "text": "What is"}, {"type": "text", "text": "the best book"}], 8, "Whillicter", 79, 8, "length", True, None), + (None, "Book", "Hey", 8, "But she couldn't|Some of her", 69, 8, "length", False, None), + (None, "Book", "Hey", 8, "But she couldn't|Some of her", 69, 8, "length", True, None), + (None, "Book", "What is the best book", 8, "(Suddenly)+|\\{ \" Sarax.|Timmy", 77, 8, "length", False, None), + (None, "Book", "What is the best book", 8, "(Suddenly)+|\\{ \" Sarax.|Timmy", 77, 8, "length", True, None), + (None, "Book", "What is the best book", 8, "(Suddenly)+|\\{ \" Sarax.|Timmy", 77, 8, "length", True, 'chatml'), + (None, "Book", "What is the best book", 8, "^ blue|very teaful|very busy", 23, 8, "length", True, "This is not a chat template, it is"), + ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger|shake)+", 104, 128, "length", False, None), + ("codellama70b", "You are a coding assistant.", "Write the 
fibonacci function in c++.", 128, "(Aside|she|felter|alonger|shake)+", 104, 128, "length", True, None), + (None, "Book", [{"type": "text", "text": "What is"}, {"type": "text", "text": "the best book"}], 8, "Whillicter|Some", 79, 8, "length", False, None), + (None, "Book", [{"type": "text", "text": "What is"}, {"type": "text", "text": "the best book"}], 8, "Whillicter|Some", 79, 8, "length", True, None), ] ) def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, finish_reason, jinja, chat_template): @@ -54,8 +54,8 @@ def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_conte @pytest.mark.parametrize( "system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,finish_reason", [ - ("Book", "What is the best book", 8, "(Suddenly)+", 77, 8, "length"), - ("You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 128, "length"), + ("Book", "What is the best book", 8, "(Suddenly)+|Timmy", 77, 8, "length"), + ("You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger|shake)+", 104, 128, "length"), ] ) def test_chat_completion_stream(system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, finish_reason): @@ -115,7 +115,7 @@ def test_chat_completion_with_openai_library(): assert res.system_fingerprint is not None and res.system_fingerprint.startswith("b") assert res.choices[0].finish_reason == "length" assert res.choices[0].message.content is not None - assert match_regex("(Suddenly)+", res.choices[0].message.content) + assert match_regex("(Suddenly)+|Timmy", res.choices[0].message.content) def test_chat_template(): @@ -494,5 +494,5 @@ def test_chat_completions_multiple_choices(): assert len(res.body["choices"]) == 2 for choice in res.body["choices"]: assert "assistant" == choice["message"]["role"] - assert match_regex("Suddenly", choice["message"]["content"]) + assert match_regex("Suddenly|Timmy", choice["message"]["content"]) assert choice["finish_reason"] == "length" diff --git a/tools/server/tests/unit/test_completion.py b/tools/server/tests/unit/test_completion.py index ef1757db21..daaa6e5a90 100644 --- a/tools/server/tests/unit/test_completion.py +++ b/tools/server/tests/unit/test_completion.py @@ -17,7 +17,7 @@ def create_server(): server = ServerPreset.tinyllama2() @pytest.mark.parametrize("prompt,n_predict,re_content,n_prompt,n_predicted,truncated,return_tokens", [ - ("I believe the meaning of life is", 8, "(going|bed)+", 18, 8, False, False), + ("I believe the meaning of life is", 8, "(going|bed)+|froze and every|froze and bri", 18, 8, False, False), ("Write a joke about AI from a very long prompt which will not be truncated", 64, "(princesses|everyone|kids|Anna|forest)+", 46, 64, False, True), ]) def test_completion(prompt: str, n_predict: int, re_content: str, n_prompt: int, n_predicted: int, truncated: bool, return_tokens: bool): @@ -42,7 +42,7 @@ def test_completion(prompt: str, n_predict: int, re_content: str, n_prompt: int, @pytest.mark.parametrize("prompt,n_predict,re_content,n_prompt,n_predicted,truncated", [ - ("I believe the meaning of life is", 8, "(going|bed)+", 18, 8, False), + ("I believe the meaning of life is", 8, "(going|bed)+|froze and every|froze and bri", 18, 8, False), ("Write a joke about AI from a very long prompt which will not be truncated", 64, "(princesses|everyone|kids|Anna|forest)+", 46, 64, False), ]) def test_completion_stream(prompt: str, n_predict: int, 
re_content: str, n_prompt: int, n_predicted: int, truncated: bool): @@ -103,7 +103,7 @@ def test_completion_with_openai_library(): assert res.system_fingerprint is not None and res.system_fingerprint.startswith("b") assert res.choices[0].finish_reason == "length" assert res.choices[0].text is not None - assert match_regex("(going|bed)+", res.choices[0].text) + assert match_regex("(going|bed)+|froze and every|froze and bri", res.choices[0].text) def test_completion_stream_with_openai_library(): @@ -122,7 +122,7 @@ def test_completion_stream_with_openai_library(): if choice.finish_reason is None: assert choice.text is not None output_text += choice.text - assert match_regex("(going|bed)+", output_text) + assert match_regex("(going|bed)+|froze and every|froze and bri", output_text) # Test case from https://github.com/ggml-org/llama.cpp/issues/13780 @@ -146,7 +146,7 @@ def test_completion_stream_with_openai_library_stops(): if choice.finish_reason is None: assert choice.text is not None output_text += choice.text - assert match_regex("Sure, here's one for[\\s\\S]*", output_text), f'Unexpected output: {output_text}' + assert match_regex("Sure, here's one for[\\s\\S]*|Sure thing..Why don't|Sure! Here's one for you:", output_text), f'Unexpected output: {output_text}' @pytest.mark.parametrize("n_slots", [1, 2]) @@ -511,8 +511,8 @@ def test_n_probs_post_sampling(): assert "token" in prob and type(prob["token"]) == str assert "prob" in prob and 0.0 <= prob["prob"] <= 1.0 assert "bytes" in prob and type(prob["bytes"]) == list - # because the test model usually output token with either 100% or 0% probability, we need to check all the top_probs - assert any(prob["prob"] == 1.0 for prob in tok["top_probs"]) + # at low temperature, one of the token has a very high probability + assert any(prob["prob"] >= 0.99 for prob in tok["top_probs"]) @pytest.mark.parametrize("tokenize,openai_style", [(False, False), (False, True), (True, False), (True, True)]) diff --git a/tools/server/webui/src/lib/components/app/chat/ChatSettings/ChatSettings.svelte b/tools/server/webui/src/lib/components/app/chat/ChatSettings/ChatSettings.svelte index 4ec9b478fd..5a668aa300 100644 --- a/tools/server/webui/src/lib/components/app/chat/ChatSettings/ChatSettings.svelte +++ b/tools/server/webui/src/lib/components/app/chat/ChatSettings/ChatSettings.svelte @@ -185,6 +185,11 @@ key: 'samplers', label: 'Samplers', type: 'input' + }, + { + key: 'backend_sampling', + label: 'Backend sampling', + type: 'checkbox' } ] }, diff --git a/tools/server/webui/src/lib/constants/settings-config.ts b/tools/server/webui/src/lib/constants/settings-config.ts index f9584d01d7..cac48a557c 100644 --- a/tools/server/webui/src/lib/constants/settings-config.ts +++ b/tools/server/webui/src/lib/constants/settings-config.ts @@ -21,6 +21,7 @@ export const SETTING_CONFIG_DEFAULT: Record = autoMicOnEmpty: false, // make sure these default values are in sync with `common.h` samplers: 'top_k;typ_p;top_p;min_p;temperature', + backend_sampling: false, temperature: 0.8, dynatemp_range: 0.0, dynatemp_exponent: 1.0, @@ -57,6 +58,8 @@ export const SETTING_CONFIG_INFO: Record = { 'When copying a message with text attachments, combine them into a single plain text string instead of a special format that can be pasted back as attachments.', samplers: 'The order at which samplers are applied, in simplified way. Default is "top_k;typ_p;top_p;min_p;temperature": top_k->typ_p->top_p->min_p->temperature', + backend_sampling: + 'Enable backend-based samplers. 
When enabled, supported samplers run on the accelerator backend for faster sampling.', temperature: 'Controls the randomness of the generated text by affecting the probability distribution of the output tokens. Higher = more random, lower = more focused.', dynatemp_range: diff --git a/tools/server/webui/src/lib/services/chat.ts b/tools/server/webui/src/lib/services/chat.ts index c03b764419..fb98d2c995 100644 --- a/tools/server/webui/src/lib/services/chat.ts +++ b/tools/server/webui/src/lib/services/chat.ts @@ -86,6 +86,7 @@ export class ChatService { dry_penalty_last_n, // Other parameters samplers, + backend_sampling, custom, timings_per_token, // Config options @@ -158,6 +159,8 @@ export class ChatService { : samplers; } + if (backend_sampling !== undefined) requestBody.backend_sampling = backend_sampling; + if (timings_per_token !== undefined) requestBody.timings_per_token = timings_per_token; if (custom) { diff --git a/tools/server/webui/src/lib/stores/chat.svelte.ts b/tools/server/webui/src/lib/stores/chat.svelte.ts index 4f78840a57..0a551e890a 100644 --- a/tools/server/webui/src/lib/stores/chat.svelte.ts +++ b/tools/server/webui/src/lib/stores/chat.svelte.ts @@ -1401,6 +1401,8 @@ class ChatStore { if (hasValue(currentConfig.dry_penalty_last_n)) apiOptions.dry_penalty_last_n = Number(currentConfig.dry_penalty_last_n); if (currentConfig.samplers) apiOptions.samplers = currentConfig.samplers; + if (currentConfig.backend_sampling) + apiOptions.backend_sampling = currentConfig.backend_sampling; if (currentConfig.custom) apiOptions.custom = currentConfig.custom; return apiOptions; diff --git a/tools/server/webui/src/lib/types/api.d.ts b/tools/server/webui/src/lib/types/api.d.ts index 4bc92b57bc..26d2bcc0d8 100644 --- a/tools/server/webui/src/lib/types/api.d.ts +++ b/tools/server/webui/src/lib/types/api.d.ts @@ -149,6 +149,7 @@ export interface ApiLlamaCppServerProps { reasoning_in_content: boolean; thinking_forced_open: boolean; samplers: string[]; + backend_sampling: boolean; 'speculative.n_max': number; 'speculative.n_min': number; 'speculative.p_min': number; @@ -210,6 +211,7 @@ export interface ApiChatCompletionRequest { dry_penalty_last_n?: number; // Sampler configuration samplers?: string[]; + backend_sampling?: boolean; // Custom parameters (JSON string) custom?: Record; timings_per_token?: boolean; @@ -310,6 +312,7 @@ export interface ApiSlotData { reasoning_in_content: boolean; thinking_forced_open: boolean; samplers: string[]; + backend_sampling: boolean; 'speculative.n_max': number; 'speculative.n_min': number; 'speculative.p_min': number; diff --git a/tools/server/webui/src/lib/types/settings.d.ts b/tools/server/webui/src/lib/types/settings.d.ts index 40de98b708..ecd5802fb6 100644 --- a/tools/server/webui/src/lib/types/settings.d.ts +++ b/tools/server/webui/src/lib/types/settings.d.ts @@ -43,6 +43,7 @@ export interface SettingsChatServiceOptions { dry_penalty_last_n?: number; // Sampler configuration samplers?: string | string[]; + backend_sampling?: boolean; // Custom parameters custom?: string; timings_per_token?: boolean;
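
For reference, the following standalone sketch (not part of the patch) shows how a client could drive the backend-sampling path end to end. It assumes the per-sequence { seq_id, sampler } config struct added by this patch (named llama_sampler_seq_config here purely for illustration, since the diff only shows its shape), the samplers/n_samplers context parameters, and llama_get_sampled_token_ith() as exercised by the tests above; everything else is the existing llama.h API.

// backend_sampling_sketch.cpp -- illustrative only, built on the API added in this patch
#include "llama.h"

#include <cstdio>
#include <string>
#include <vector>

int main(int argc, char ** argv) {
    if (argc < 2) {
        fprintf(stderr, "usage: %s <model.gguf>\n", argv[0]);
        return 1;
    }

    llama_backend_init();

    llama_model_params mparams = llama_model_default_params();
    llama_model * model = llama_model_load_from_file(argv[1], mparams);
    if (model == nullptr) {
        return 1;
    }
    const llama_vocab * vocab = llama_model_get_vocab(model);

    // backend sampler chain for sequence 0: top-k filtering followed by dist selection
    llama_sampler_chain_params sc = llama_sampler_chain_default_params();
    llama_sampler * chain = llama_sampler_chain_init(sc);
    llama_sampler_chain_add(chain, llama_sampler_init_top_k(40));
    llama_sampler_chain_add(chain, llama_sampler_init_dist(1234));

    // register the chain for seq_id 0 at context creation time
    // (llama_sampler_seq_config stands in for the { seq_id, sampler } config struct from this patch)
    std::vector<llama_sampler_seq_config> configs = { { 0, chain } };

    llama_context_params cparams = llama_context_default_params();
    cparams.samplers   = configs.data();
    cparams.n_samplers = configs.size();

    llama_context * ctx = llama_init_from_model(model, cparams);

    // tokenize a short prompt; llama_batch_get_one() targets seq 0 and requests
    // output for the last token only
    const std::string prompt = "Hello";
    std::vector<llama_token> toks(64);
    const int n = llama_tokenize(vocab, prompt.c_str(), (int32_t) prompt.size(),
                                 toks.data(), (int32_t) toks.size(),
                                 /*add_special*/ true, /*parse_special*/ false);
    if (n <= 0 || llama_decode(ctx, llama_batch_get_one(toks.data(), n)) != 0) {
        fprintf(stderr, "tokenize/decode failed\n");
        return 1;
    }

    // with a backend sampler attached, the token is already selected during decode
    const llama_token id = llama_get_sampled_token_ith(ctx, -1);
    if (id != LLAMA_TOKEN_NULL) {
        char buf[256];
        const int len = llama_token_to_piece(vocab, id, buf, sizeof(buf), 0, true);
        printf("backend sampled token %d '%.*s'\n", id, len > 0 ? len : 0, buf);
    }

    llama_sampler_free(chain);
    llama_free(ctx);
    llama_model_free(model);
    llama_backend_free();
    return 0;
}

When the chain ends in a dist sampler, the token is selected entirely on the backend, as in the dist tests above; a chain that only filters (for example top-k) leaves the reduced candidate set to be finished by a CPU sampler, which is what the mixed and cpu_mixed tests verify. On the server side, the same path is toggled per request with the "backend_sampling" JSON field handled in server-task.cpp.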