From ab65b47a52ffabd247c3905e766d0969dfaf31fc Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Thu, 11 Dec 2025 14:12:35 +0200
Subject: [PATCH] tests : run backend sampler tests always on the CPU

---
 src/llama-sampling.cpp         | 22 +++++++++++++---------
 tests/test-backend-sampler.cpp | 17 +++++++++++++----
 2 files changed, 26 insertions(+), 13 deletions(-)

diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp
index 811029052a..15dafcf102 100644
--- a/src/llama-sampling.cpp
+++ b/src/llama-sampling.cpp
@@ -616,8 +616,9 @@ private:
 
 // check if all ggml ops used by the sampler are supported by the backend
 static bool llama_sampler_backend_support(
-           llama_sampler * smpl,
-        ggml_backend_dev_t device) {
+                   llama_sampler * smpl,
+        ggml_backend_buffer_type_t buft) {
+    auto * device = ggml_backend_buft_get_device(buft);
     if (!device) {
         // CPU backend always supported
         return true;
@@ -669,6 +670,9 @@ static bool llama_sampler_backend_support(
         struct ggml_tensor * op = ggml_graph_node(gf, i);
 
         if (!ggml_backend_dev_supports_op(device, op)) {
+            LLAMA_LOG_WARN("%s: device '%s' does not have support for op %s needed for sampler '%s'\n",
+                    __func__, ggml_backend_dev_name(device), ggml_op_name(op->op), smpl->iface->name(smpl));
+
             return false;
         }
     }
@@ -958,7 +962,7 @@ static bool llama_sampler_greedy_backend_init(
         ggml_backend_buffer_type_t buft) {
     auto * sctx = (llama_sampler_greedy *) smpl->ctx;
 
-    const bool res = llama_sampler_backend_support(smpl, ggml_backend_buft_get_device(buft));
+    const bool res = llama_sampler_backend_support(smpl, buft);
 
     sctx->init(res);
 
@@ -1142,7 +1146,7 @@ static bool llama_sampler_dist_backend_init(
         sctx->inp_buf.reset(ggml_backend_alloc_ctx_tensors_from_buft(sctx->inp_ctx.get(), buft));
     }
 
-    const bool res = llama_sampler_backend_support(smpl, ggml_backend_buft_get_device(buft));
+    const bool res = llama_sampler_backend_support(smpl, buft);
 
     sctx->init(res);
 
@@ -1274,7 +1278,7 @@ static bool llama_sampler_top_k_backend_init(
         ggml_backend_buffer_type_t buft) {
     auto * sctx = (llama_sampler_top_k *) smpl->ctx;
 
-    const bool res = llama_sampler_backend_support(smpl, ggml_backend_buft_get_device(buft));
+    const bool res = llama_sampler_backend_support(smpl, buft);
 
     sctx->init(res);
 
@@ -1419,7 +1423,7 @@ static bool llama_sampler_top_p_backend_init(
         ggml_backend_buffer_type_t buft) {
     auto * sctx = (llama_sampler_top_p *) smpl->ctx;
 
-    const bool res = llama_sampler_backend_support(smpl, ggml_backend_buft_get_device(buft));
+    const bool res = llama_sampler_backend_support(smpl, buft);
 
     sctx->init(res);
 
@@ -1613,7 +1617,7 @@ static bool llama_sampler_min_p_backend_init(
         ggml_backend_buffer_type_t buft) {
     auto * sctx = (llama_sampler_min_p *) smpl->ctx;
 
-    const bool res = llama_sampler_backend_support(smpl, ggml_backend_buft_get_device(buft));
+    const bool res = llama_sampler_backend_support(smpl, buft);
 
    sctx->init(res);
 
@@ -1861,7 +1865,7 @@ static bool llama_sampler_temp_backend_init(
         ggml_backend_buffer_type_t buft) {
     auto * sctx = (llama_sampler_temp *) smpl->ctx;
 
-    const bool res = llama_sampler_backend_support(smpl, ggml_backend_buft_get_device(buft));
+    const bool res = llama_sampler_backend_support(smpl, buft);
 
     sctx->init(res);
 
@@ -2004,7 +2008,7 @@ static bool llama_sampler_temp_ext_backend_init(
         ggml_backend_buffer_type_t buft) {
     auto * sctx = (llama_sampler_temp_ext *) smpl->ctx;
 
-    const bool res = llama_sampler_backend_support(smpl, ggml_backend_buft_get_device(buft));
+    const bool res = llama_sampler_backend_support(smpl, buft);
 
     sctx->init(res);
 
diff --git a/tests/test-backend-sampler.cpp b/tests/test-backend-sampler.cpp
index db1a2631f0..b3f202771a 100644
--- a/tests/test-backend-sampler.cpp
+++ b/tests/test-backend-sampler.cpp
@@ -7,6 +7,7 @@
 #undef NDEBUG
 #endif
 
+#include <algorithm>
 #include
 #include
 #include
@@ -31,7 +32,15 @@ struct test_model_context {
 
         llama_backend_init();
 
-        model = llama_model_load_from_file(model_path, llama_model_default_params());
+        // force CPU backend since it always supports all ggml operations
+        ggml_backend_dev_t devs[2];
+        devs[0] = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
+        devs[1] = nullptr;
+
+        auto mparams = llama_model_default_params();
+        mparams.devices = devs;
+
+        model = llama_model_load_from_file(model_path, mparams);
         if (model == nullptr) {
             fprintf(stderr, "Warning: failed to load model '%s', skipping test\n", model_path);
             cleanup();
@@ -63,9 +72,7 @@ struct test_model_context {
         if (n_seq_max < 0) {
             int32_t max_seq_id = 0;
             for (const auto & config : configs) {
-                if (config.seq_id > max_seq_id) {
-                    max_seq_id = config.seq_id;
-                }
+                max_seq_id = std::max(config.seq_id, max_seq_id);
             }
             cparams.n_seq_max = max_seq_id + 1;
         } else {
@@ -859,6 +866,8 @@ static void test_backend_logit_bias_sampling(const char * model_path) {
     printf("logit bias sampled token = %d, string='%s'\n", backend_token, backend_token_str.c_str());
 
     GGML_ASSERT(backend_token == bias_token);
 
+    printf("backend logit bias sampling test PASSED\n");
+
     llama_sampler_free(backend_sampler_chain);
 }
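
For reference, a minimal standalone sketch of the device-forcing pattern the patch introduces in the test, assuming only the public llama.h API; the `main()` wrapper and the "model.gguf" path are placeholders for illustration, not part of the patch:

```cpp
// sketch: load a model restricted to the CPU device, mirroring the test change above
#include "llama.h"

#include <cstdio>

int main() {
    llama_backend_init();

    // NULL-terminated device list containing only the CPU device
    ggml_backend_dev_t devs[2];
    devs[0] = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
    devs[1] = nullptr;

    llama_model_params mparams = llama_model_default_params();
    mparams.devices = devs; // keep weights and compute on the CPU backend

    // "model.gguf" is a placeholder path
    llama_model * model = llama_model_load_from_file("model.gguf", mparams);
    if (model == nullptr) {
        fprintf(stderr, "failed to load model\n");
        llama_backend_free();
        return 1;
    }

    llama_model_free(model);
    llama_backend_free();
    return 0;
}
```

Passing a NULL-terminated device list through `mparams.devices` is the same mechanism the test uses, so the sampler graphs run on a backend that supports every ggml op.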