From 4c3d5422ad23a9ddb16803094fdc2cda79ee9151 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov <ggerganov@gmail.com>
Date: Wed, 31 Dec 2025 16:59:42 +0200
Subject: [PATCH] minor : add comments + some cleanup

---
 common/arg.cpp             |  2 +-
 include/llama.h            | 27 ++++++++++++++++-----------
 tests/CMakeLists.txt       |  7 +++----
 tests/test-backend-ops.cpp | 19 ++++++++++---------
 4 files changed, 30 insertions(+), 25 deletions(-)

diff --git a/common/arg.cpp b/common/arg.cpp
index 3caee0da69..b52b3e70b7 100644
--- a/common/arg.cpp
+++ b/common/arg.cpp
@@ -1697,7 +1697,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
     ).set_sparam());
     add_opt(common_arg(
         {"-bs", "--backend-sampling"},
-        "enable backend sampling (default: disabled)",
+        "enable backend sampling (experimental) (default: disabled)",
        [](common_params & params) {
            params.sampling.backend_sampling = true;
        }
diff --git a/include/llama.h b/include/llama.h
index b5f6657bc9..a5a065ed84 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -370,7 +370,8 @@ extern "C" {
         // try to disable when n_seq_max > 1 for improved performance when the sequences do not share a large prefix
         // ref: https://github.com/ggml-org/llama.cpp/pull/14363

-        // backend sampler chain configuration (make sure the caller keeps the sampler chains alive) [EXPERIMENTAL]
+        // [EXPERIMENTAL]
+        // backend sampler chain configuration (make sure the caller keeps the sampler chains alive)
         // note: the samplers must be sampler chains (i.e. use llama_sampler_chain_init)
         struct llama_sampler_seq_config * samplers;
         size_t n_samplers;
@@ -1000,6 +1001,11 @@ extern "C" {
     // otherwise: float[n_embd] (1-dimensional)
     LLAMA_API float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id);

+    //
+    // backend sampling API [EXPERIMENTAL]
+    // note: use only if the llama_context was created with at least one llama_sampler_seq_config
+    //
+
     // Get the backend sampled token for the ith token.
     // Returns LLAMA_TOKEN_NULL if no token was sampled.
     LLAMA_API llama_token llama_get_sampled_token_ith(struct llama_context * ctx, int32_t i);
@@ -1007,24 +1013,18 @@ extern "C" {
     // Get the backend sampled probabilites for the ith token
     // The index matches llama_get_sampled_token_ith().
     // Returns NULL if no probabilites were generated.
-    LLAMA_API float * llama_get_sampled_probs_ith(struct llama_context * ctx, int32_t i);
-    //
-    // Get the number of backend sampled probabilites for the ith token.
+    LLAMA_API float *  llama_get_sampled_probs_ith      (struct llama_context * ctx, int32_t i);
     LLAMA_API uint32_t llama_get_sampled_probs_count_ith(struct llama_context * ctx, int32_t i);

     // Get the backend sampled logits for the ith token
     // Returns NULL if no logits were sampled.
-    LLAMA_API float * llama_get_sampled_logits_ith(struct llama_context * ctx, int32_t i);
-    //
-    // Get the number of backend sampled logits for the ith token.
+    LLAMA_API float *  llama_get_sampled_logits_ith      (struct llama_context * ctx, int32_t i);
     LLAMA_API uint32_t llama_get_sampled_logits_count_ith(struct llama_context * ctx, int32_t i);

     // Get the backend sampled candidates (token ids) for the ith token
     // Returns NULL if no candidates were sampled.
-    LLAMA_API llama_token * llama_get_sampled_candidates_ith(struct llama_context * ctx, int32_t i);
-    //
-    // Get the number of backend sampled candidates for the ith token.
-    LLAMA_API uint32_t llama_get_sampled_candidates_count_ith(struct llama_context * ctx, int32_t i);
+    LLAMA_API llama_token * llama_get_sampled_candidates_ith      (struct llama_context * ctx, int32_t i);
+    LLAMA_API uint32_t      llama_get_sampled_candidates_count_ith(struct llama_context * ctx, int32_t i);

     //
     // Vocab
@@ -1216,6 +1216,7 @@ extern "C" {
         struct llama_sampler * (*clone) (const struct llama_sampler * smpl); // can be NULL if ctx is NULL
         void                   (*free)  (      struct llama_sampler * smpl); // can be NULL if ctx is NULL

+        // [EXPERIMENTAL]
         // backend sampling interface:

         // return true if the backend supports all ops needed by the sampler
@@ -1246,6 +1247,10 @@ extern "C" {
         llama_sampler_context_t ctx;
     };

+    // [EXPERIMENTAL]
+    // attach a sampler to the context
+    // note: prefer initializing the context with llama_context_params.samplers when possible
+    // note: changing the samplers of a context can cause graph reallocations and degraded performance
     LLAMA_API bool llama_set_sampler(struct llama_context * ctx, llama_seq_id seq_id, struct llama_sampler * smpl);

     // mirror of llama_sampler_i:
diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt
index ff4b7205aa..6245cd967a 100644
--- a/tests/CMakeLists.txt
+++ b/tests/CMakeLists.txt
@@ -219,11 +219,10 @@ endif()
 llama_build_and_test(test-gguf.cpp)
 llama_build_and_test(test-backend-ops.cpp)

-llama_build_and_test(test-model-load-cancel.cpp LABEL "model")
-llama_build_and_test(test-autorelease.cpp LABEL "model")
+llama_build_and_test(test-model-load-cancel.cpp LABEL "model")
+llama_build_and_test(test-autorelease.cpp       LABEL "model")
+llama_build_and_test(test-backend-sampler.cpp   LABEL "model")

-llama_build_and_test(test-backend-sampler.cpp LABEL "model")
-target_include_directories(test-backend-sampler PRIVATE ${PROJECT_SOURCE_DIR}/src)
 llama_test(test-backend-sampler NAME test-backend-sampler-greedy ARGS --test greedy)
 llama_test(test-backend-sampler NAME test-backend-sampler-temp ARGS --test temp)
 llama_test(test-backend-sampler NAME test-backend-sampler-top_k ARGS --test top_k)
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index fab2916713..a753e8a2b8 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -7712,9 +7712,6 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
         exponent <<= 1;
     }
 #endif
-    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {200000, 1, 1, 1}, false, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
-    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {200000, 4, 1, 1}, false, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
-    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {643251, 3, 1, 1}, false, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
     for (bool mask : {false, true}) {
         for (bool sinks : {false, true}) {
             for (float max_bias : {0.0f, 8.0f}) {
@@ -7754,8 +7751,11 @@
     test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, true, GGML_TYPE_F32, {1, 1}, 0.1f, 8.0f));
     test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, true, GGML_TYPE_F16, {1, 1}, 0.1f, 8.0f));

-    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {200001, 2, 3, 1}, true, true, GGML_TYPE_F32, {1, 1}, 0.1f, 8.0f));
-    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {200001, 2, 3, 1}, true, true, GGML_TYPE_F16, {1, 1}, 0.1f, 8.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {200001, 2, 3, 1}, true,  true,  GGML_TYPE_F32, {1, 1}, 0.1f, 8.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {200001, 2, 3, 1}, true,  true,  GGML_TYPE_F16, {1, 1}, 0.1f, 8.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {200000, 1, 1, 1}, false, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {200000, 4, 1, 1}, false, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {643251, 3, 1, 1}, false, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));

     for (float max_bias : {0.0f, 8.0f}) {
         for (float scale : {1.0f, 0.1f}) {
@@ -7768,6 +7768,7 @@
             }
         }
     }
+
     for (bool fw : {true, false}) { // fw == forward
         bool all = true;

@@ -8273,8 +8274,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
         }
     }

-    for (int col: {8192, 16384, 32768, 65536, 131072, 262144, 524288}) {
-        for (int rows: {1, 4, 16}){
+    for (int col : {8192, 16384, 32768, 65536, 131072, 262144, 524288}) {
+        for (int rows : {1, 4, 16}){
             test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {col, rows, 1, 1}, false, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
         }
     }
@@ -8322,8 +8323,8 @@
         test_cases.emplace_back(new test_sum(GGML_TYPE_F32, it));
     }

-    test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {65000, 16, 1, 1}));
-    test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {200000, 1, 1, 1}));
+    test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {65000,  16, 1, 1}));
+    test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {200000,  1, 1, 1}));
     test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {200000, 16, 1, 1}));

     test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {2, 1, 1, 1}, 1));
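
Usage note (not part of the patch): the llama.h hunks above expose an experimental backend sampling API. The following is a minimal sketch of how a caller might exercise it: attach a sampler chain to sequence 0 via llama_set_sampler() and read the backend-sampled result after decoding. Model/context creation, batch setup, and llama_decode() are elided; the chain composition (top_k, temp, dist), the output index semantics (assumed to match llama_get_logits_ith()), and the pairing of candidates with probabilities are assumptions for illustration, not mandated by the patch.

    #include "llama.h"

    // sketch: drive the experimental backend sampling API declared in llama.h
    // assumes: ctx was created from a loaded model, and i indexes an output of
    // the most recent llama_decode() call (assumption: same indexing as
    // llama_get_logits_ith)
    static void backend_sampling_sketch(struct llama_context * ctx, int32_t i) {
        // the attached sampler must be a chain (see llama_sampler_chain_init);
        // per the patch, the caller must keep the chain alive while attached
        struct llama_sampler * chain = llama_sampler_chain_init(llama_sampler_chain_default_params());
        llama_sampler_chain_add(chain, llama_sampler_init_top_k(40));  // example
        llama_sampler_chain_add(chain, llama_sampler_init_temp(0.8f)); // chain
        llama_sampler_chain_add(chain, llama_sampler_init_dist(1234)); // composition

        // per the patch notes: prefer llama_context_params.samplers at creation;
        // calling llama_set_sampler() later can cause graph reallocations
        if (!llama_set_sampler(ctx, /*seq_id =*/ 0, chain)) {
            return; // assumption: false means the backend cannot run this chain
        }

        // ... llama_decode(ctx, batch) ...

        const llama_token tok = llama_get_sampled_token_ith(ctx, i);
        if (tok == LLAMA_TOKEN_NULL) {
            return; // nothing was sampled on the backend for this output
        }

        // optional: inspect the probabilities computed on the backend;
        // assumption: probs[k] corresponds to the candidate token cands[k]
        const float       * probs   = llama_get_sampled_probs_ith           (ctx, i);
        const uint32_t      n_probs = llama_get_sampled_probs_count_ith     (ctx, i);
        const llama_token * cands   = llama_get_sampled_candidates_ith      (ctx, i);
        const uint32_t      n_cands = llama_get_sampled_candidates_count_ith(ctx, i);

        if (probs != NULL && cands != NULL && n_probs == n_cands) {
            for (uint32_t k = 0; k < n_cands; ++k) {
                // e.g. log or accumulate: token cands[k] with probability probs[k]
            }
        }
    }

The fallback path is the point of the bool return and the NULL/LLAMA_TOKEN_NULL results: when the backend cannot run the chain, or no token was sampled for a given output, the caller can sample on the CPU from llama_get_logits_ith() exactly as before the patch.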