From 981475fedc2eff543703ceb6f6f79d4b3a2a2150 Mon Sep 17 00:00:00 2001
From: Daniel Bevenius
Date: Wed, 17 Dec 2025 15:27:23 +0100
Subject: [PATCH] tests : add --device option support to backend sampler tests

This commit adds support for specifying the device (cpu or gpu) to run the
tests on.
---
 tests/test-backend-sampler.cpp | 157 ++++++++++++++++++++-------------
 1 file changed, 95 insertions(+), 62 deletions(-)

diff --git a/tests/test-backend-sampler.cpp b/tests/test-backend-sampler.cpp
index 65a9e718e7..24ece9d4b1 100644
--- a/tests/test-backend-sampler.cpp
+++ b/tests/test-backend-sampler.cpp
@@ -11,12 +11,19 @@
 #include
 #include
 #include
-#include
+#include <fstream>
+#include <string_view>
 #include
 #include
 #include
 #include

+struct backend_cli_args {
+    const char * model  = nullptr;
+    const char * test   = nullptr;
+    const char * device = "cpu";
+};
+
 struct test_model_context {
     llama_model_ptr model;
     llama_context_ptr ctx;
@@ -25,25 +32,39 @@ struct test_model_context {
     std::unordered_map seq_positions;
     std::unordered_map last_batch_info;

-    bool load_model(const char * model_path) {
+    bool load_model(const backend_cli_args & args) {
         if (model) {
             return true;
         }

         llama_backend_init();

-        // force CPU backend since it always supports all ggml operations
-        ggml_backend_dev_t devs[2];
-        devs[0] = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
-        devs[1] = nullptr;
-
         auto mparams = llama_model_default_params();
-        mparams.devices = devs;

-        model.reset(llama_model_load_from_file(model_path, mparams));
+        ggml_backend_dev_t devs[2];
+        if (std::string_view(args.device) == "gpu") {
+            ggml_backend_dev_t gpu = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_GPU);
+            if (gpu == nullptr) {
+                fprintf(stderr, "Error: GPU requested but not available\n");
+                return false;
+            }
+            devs[0] = gpu;
+            devs[1] = nullptr; // null terminator
+            mparams.devices = devs;
+            mparams.n_gpu_layers = 999;
+        } else if (std::string_view(args.device) == "cpu") {
+            ggml_backend_dev_t cpu = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
+            devs[0] = cpu;
+            devs[1] = nullptr; // null terminator
+            mparams.devices = devs;
+        }
+
+        fprintf(stderr, "Using device: %s\n", ggml_backend_dev_name(devs[0]));
+
+        model.reset(llama_model_load_from_file(args.model, mparams));
         if (!model) {
-            fprintf(stderr, "Warning: failed to load model '%s', skipping test\n", model_path);
+            fprintf(stderr, "Warning: failed to load model '%s', skipping test\n", args.model);
             return false;
         }
         n_vocab = llama_vocab_n_tokens(get_vocab());
@@ -52,9 +73,9 @@ struct test_model_context {
         return true;
     }

-    bool setup(const char * model_path, std::vector & configs, int32_t n_seq_max = -1) {
+    bool setup(const backend_cli_args & args, std::vector & configs, int32_t n_seq_max = -1) {
         if (!model) {
-            load_model(model_path);
+            load_model(args);
         }

         if (ctx) {
@@ -257,7 +278,7 @@ struct test_model_context {
 };

-static void test_backend_greedy_sampling(const char * model_path) {
+static void test_backend_greedy_sampling(const backend_cli_args & args) {
     test_model_context test_ctx;

     const int seq_id = 0;
@@ -268,7 +289,7 @@ static void test_backend_greedy_sampling(const char * model_path) {
     llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_greedy());

     std::vector backend_sampler_configs = {{ seq_id, backend_sampler_chain.get() }};
-    if (!test_ctx.setup(model_path, backend_sampler_configs)) {
+    if (!test_ctx.setup(args, backend_sampler_configs)) {
         return;
     }

@@ -296,7 +317,7 @@ static void test_backend_greedy_sampling(const char * model_path) {
     }
 }

-static void test_backend_top_k_sampling(const char * model_path) {
+static void test_backend_top_k_sampling(const backend_cli_args & args) {
     test_model_context test_ctx;

     const int seq_id = 0;
@@ -306,7 +327,7 @@ static void test_backend_top_k_sampling(const char * model_path) {
     llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_top_k(k));

     std::vector backend_sampler_configs = {{ seq_id, backend_sampler_chain.get() }};
-    if (!test_ctx.setup(model_path, backend_sampler_configs)) {
+    if (!test_ctx.setup(args, backend_sampler_configs)) {
         return;
     }

@@ -343,9 +364,10 @@ static void test_backend_top_k_sampling(const char * model_path) {
     printf("backend top-k hybrid sampling test PASSED\n");
 }

-static void test_backend_temp_sampling(const char * model_path) {
+static void test_backend_temp_sampling(const backend_cli_args & args) {
     test_model_context test_ctx;

+    {
     const float temp_0 = 0.8f;
     struct llama_sampler_chain_params backend_chain_params_0 = llama_sampler_chain_default_params();
@@ -362,7 +384,7 @@ static void test_backend_temp_sampling(const char * model_path) {
         { 1, backend_sampler_chain_1.get() }
     };

-    if (!test_ctx.setup(model_path, backend_sampler_configs)) {
+    if (!test_ctx.setup(args, backend_sampler_configs)) {
         return;
     }

@@ -419,7 +441,7 @@ static void test_backend_temp_sampling(const char * model_path) {
         { seq_id, backend_sampler_chain.get() },
     };

-    if (!test_ctx.setup(model_path, backend_sampler_configs)) {
+    if (!test_ctx.setup(args, backend_sampler_configs)) {
         return;
     }

@@ -440,7 +462,7 @@ static void test_backend_temp_sampling(const char * model_path) {
 }

-static void test_backend_temp_ext_sampling(const char * model_path) {
+static void test_backend_temp_ext_sampling(const backend_cli_args & args) {
     test_model_context test_ctx;

     {
@@ -456,7 +478,7 @@ static void test_backend_temp_ext_sampling(const char * model_path) {
         { seq_id, backend_sampler_chain.get() },
     };

-    if (!test_ctx.setup(model_path, backend_sampler_configs)) {
+    if (!test_ctx.setup(args, backend_sampler_configs)) {
         return;
     }

@@ -489,7 +511,7 @@ static void test_backend_temp_ext_sampling(const char * model_path) {
         { seq_id, backend_sampler_chain.get() },
     };

-    if (!test_ctx.setup(model_path, backend_sampler_configs)) {
+    if (!test_ctx.setup(args, backend_sampler_configs)) {
         return;
     }

@@ -516,7 +538,7 @@ static void test_backend_temp_ext_sampling(const char * model_path) {
 }

-static void test_backend_min_p_sampling(const char * model_path) {
+static void test_backend_min_p_sampling(const backend_cli_args & args) {
     test_model_context test_ctx;

     const int seq_id = 0;
@@ -526,7 +548,7 @@ static void test_backend_min_p_sampling(const char * model_path) {
     llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_min_p(p, 0));

     std::vector backend_sampler_configs = {{ seq_id, backend_sampler_chain.get() }};
-    if (!test_ctx.setup(model_path, backend_sampler_configs)) {
+    if (!test_ctx.setup(args, backend_sampler_configs)) {
         return;
     }

@@ -572,7 +594,7 @@ static void test_backend_min_p_sampling(const char * model_path) {
     printf("min-p sampling test PASSED\n");
 }

-static void test_backend_top_p_sampling(const char * model_path) {
+static void test_backend_top_p_sampling(const backend_cli_args & args) {
     test_model_context test_ctx;

     const int seq_id = 0;
@@ -582,7 +604,7 @@ static void test_backend_top_p_sampling(const char * model_path) {
     llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_top_p(p, 0));

     std::vector backend_sampler_configs = {{ seq_id, backend_sampler_chain.get() }};
-    if (!test_ctx.setup(model_path, backend_sampler_configs)) {
+    if (!test_ctx.setup(args, backend_sampler_configs)) {
         return;
     }

@@ -626,7 +648,7 @@ static void test_backend_top_p_sampling(const char * model_path) {
     printf("top-p sampling test PASSED\n");
 }

-static void test_backend_multi_sequence_sampling(const char * model_path) {
+static void test_backend_multi_sequence_sampling(const backend_cli_args & args) {
     test_model_context test_ctx;

     struct llama_sampler_chain_params chain_params_0 = llama_sampler_chain_default_params();
@@ -643,7 +665,7 @@ static void test_backend_multi_sequence_sampling(const char * model_path) {
         { 1, sampler_chain_1.get() }
     };

-    if (!test_ctx.setup(model_path, backend_sampler_configs)) {
+    if (!test_ctx.setup(args, backend_sampler_configs)) {
         return;
     }

@@ -696,7 +718,7 @@ static void test_backend_multi_sequence_sampling(const char * model_path) {
     printf("backend multi-sequence sampling test PASSED\n");
 }

-static void test_backend_dist_sampling(const char * model_path) {
+static void test_backend_dist_sampling(const backend_cli_args & args) {
     test_model_context test_ctx;

     const int seq_id = 189;
@@ -706,7 +728,7 @@ static void test_backend_dist_sampling(const char * model_path) {
     llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_dist(seed));

     std::vector backend_sampler_configs = {{ seq_id, backend_sampler_chain.get() }};
-    if (!test_ctx.setup(model_path, backend_sampler_configs)) {
+    if (!test_ctx.setup(args, backend_sampler_configs)) {
         return;
     }

@@ -727,7 +749,7 @@ static void test_backend_dist_sampling(const char * model_path) {
     printf("backend dist sampling test PASSED\n");
 }

-static void test_backend_dist_sampling_and_cpu(const char * model_path) {
+static void test_backend_dist_sampling_and_cpu(const backend_cli_args & args) {
     test_model_context test_ctx;

     const int seq_id = 0;
@@ -737,7 +759,7 @@ static void test_backend_dist_sampling_and_cpu(const char * model_path) {
     llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_dist(seed));

     std::vector backend_sampler_configs = {{ seq_id, backend_sampler_chain.get() }};
-    if (!test_ctx.setup(model_path, backend_sampler_configs)) {
+    if (!test_ctx.setup(args, backend_sampler_configs)) {
         return;
     }

@@ -760,11 +782,11 @@ static void test_backend_dist_sampling_and_cpu(const char * model_path) {
     printf("backend dist & cpu sampling test PASSED\n");
 }

-static void test_backend_logit_bias_sampling(const char * model_path) {
+static void test_backend_logit_bias_sampling(const backend_cli_args & args) {
     test_model_context test_ctx;

     // Calling load_model to ensure vocab is loaded and can be accessed
-    if (!test_ctx.load_model(model_path)) {
+    if (!test_ctx.load_model(args)) {
         return;
     }

@@ -793,7 +815,7 @@ static void test_backend_logit_bias_sampling(const char * model_path) {
         { seq_id, backend_sampler_chain.get() },
     };

-    if (!test_ctx.setup(model_path, backend_sampler_configs)) {
+    if (!test_ctx.setup(args, backend_sampler_configs)) {
         return;
     }

@@ -811,7 +833,7 @@ static void test_backend_logit_bias_sampling(const char * model_path) {

 // This test verifies that it is possible to have two different backend sampler,
 // one that uses the backend dist sampler, and another that uses CPU dist sampler.
-static void test_backend_mixed_sampling(const char * model_path) {
+static void test_backend_mixed_sampling(const backend_cli_args & args) {
     test_model_context test_ctx;

     struct llama_sampler_chain_params chain_params_0 = llama_sampler_chain_default_params();
@@ -828,7 +850,7 @@ static void test_backend_mixed_sampling(const char * model_path) {
         { 1, sampler_chain_1.get() }
     };

-    if (!test_ctx.setup(model_path, backend_sampler_configs)) {
+    if (!test_ctx.setup(args, backend_sampler_configs)) {
         return;
     }

@@ -865,7 +887,7 @@ static void test_backend_mixed_sampling(const char * model_path) {
     printf("backend mixed sampling test PASSED\n");
 }

-static void test_backend_set_sampler(const char * model_path) {
+static void test_backend_set_sampler(const backend_cli_args & args) {
     test_model_context test_ctx;

     const int32_t seed = 88;
@@ -875,7 +897,7 @@ static void test_backend_set_sampler(const char * model_path) {
     llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_dist(seed));

     std::vector backend_sampler_configs = {{ seq_id, backend_sampler_chain.get() }};
-    if (!test_ctx.setup(model_path, backend_sampler_configs)) {
+    if (!test_ctx.setup(args, backend_sampler_configs)) {
         return;
     }

@@ -933,7 +955,7 @@ static void test_backend_set_sampler(const char * model_path) {
     printf("backend set sampler test PASSED\n");
 }

-static void test_backend_cpu_mixed_batch(const char * model_path) {
+static void test_backend_cpu_mixed_batch(const backend_cli_args & args) {
     test_model_context test_ctx;

     // Sequence 0 uses backend sampling
@@ -946,7 +968,7 @@ static void test_backend_cpu_mixed_batch(const char * model_path) {
     };

     // We need 2 sequences: seq 0 with backend sampling, seq 1 with CPU sampling
-    if (!test_ctx.setup(model_path, backend_sampler_configs, 2)) {
+    if (!test_ctx.setup(args, backend_sampler_configs, 2)) {
         return;
     }

@@ -1025,7 +1047,7 @@ static void test_backend_cpu_mixed_batch(const char * model_path) {
     printf("backend-cpu mixed batch test PASSED\n");
 }

-static void test_backend_max_outputs(const char * model_path) {
+static void test_backend_max_outputs(const backend_cli_args & args) {
     test_model_context test_ctx;

     const int seq_id = 0;
@@ -1035,7 +1057,7 @@ static void test_backend_max_outputs(const char * model_path) {
     llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_dist(seed));

     std::vector backend_sampler_configs = {{ seq_id, backend_sampler_chain.get() }};
-    if (!test_ctx.setup(model_path, backend_sampler_configs)) {
+    if (!test_ctx.setup(args, backend_sampler_configs)) {
         return;
     }

@@ -1069,7 +1091,7 @@ static void test_backend_max_outputs(const char * model_path) {

 struct backend_test_case {
     const char * name;
-    void (*fn)(const char *);
+    void (*fn)(const backend_cli_args &);
     bool enabled_by_default;
 };

@@ -1090,11 +1112,6 @@ static const backend_test_case BACKEND_TESTS[] = {
     { "top_p", test_backend_top_p_sampling, true },
 };

-struct backend_cli_args {
-    const char * model = nullptr;
-    const char * test = nullptr;
-};
-
 static backend_cli_args parse_backend_cli(int argc, char ** argv) {
     backend_cli_args out;

@@ -1125,6 +1142,18 @@ static backend_cli_args parse_backend_cli(int argc, char ** argv) {
             out.model = arg + 8;
             continue;
         }
+        if (std::strcmp(arg, "--device") == 0) {
+            if (i + 1 >= argc) {
+                fprintf(stderr, "--device expects a value (cpu or gpu)\n");
+                exit(EXIT_FAILURE);
+            }
+            out.device = argv[++i];
+            continue;
+        }
+        if (std::strncmp(arg, "--device=", 9) == 0) {
+            out.device = arg + 9;
+            continue;
+        }
         if (!out.model) {
             out.model = arg;
             continue;
@@ -1134,6 +1163,11 @@ static backend_cli_args parse_backend_cli(int argc, char ** argv) {
         exit(EXIT_FAILURE);
     }

+    if (std::strcmp(out.device, "cpu") != 0 && std::strcmp(out.device, "gpu") != 0) {
+        fprintf(stderr, "Invalid device '%s'. Must be 'cpu' or 'gpu'\n", out.device);
+        exit(EXIT_FAILURE);
+    }
+
     return out;
 }

@@ -1169,35 +1203,34 @@ static std::vector collect_tests_to_run(const char *
     return selected;
 }

-static void run_tests(const std::vector & tests, const char * model_path) {
+static void run_tests(const std::vector & tests, const backend_cli_args & args) {
     for (const auto * test : tests) {
         fprintf(stderr, "\n=== %s ===\n", test->name);
-        test->fn(model_path);
+        test->fn(args);
     }
 }

-int main(int argc, char *argv[] ) {
-    const backend_cli_args args = parse_backend_cli(argc, argv);
+int main(int argc, char ** argv) {
+    backend_cli_args args = parse_backend_cli(argc, argv);

-    std::array model_argv { argv[0], const_cast(args.model) };
-    const int model_argc = args.model ? 2 : 1;
-    char * model_path = get_model_or_exit(model_argc, model_argv.data());
+    if (args.model == nullptr) {
+        args.model = get_model_or_exit(1, argv);
+    }

-    auto * file = fopen(model_path, "r");
-    if (file == nullptr) {
-        fprintf(stderr, "no model at '%s' found\n", model_path);
+    std::ifstream file(args.model);
+    if (!file.is_open()) {
+        fprintf(stderr, "no model '%s' found\n", args.model);
         return EXIT_FAILURE;
     }
-    fprintf(stderr, "using '%s'\n", model_path);
-    fclose(file);
+    fprintf(stderr, "using '%s'\n", args.model);

     ggml_time_init();

     const std::vector tests = collect_tests_to_run(args.test);
     if (!tests.empty()) {
-        run_tests(tests, model_path);
+        run_tests(tests, args);
     }

     return 0;
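
Example usage: with this change the test binary accepts an optional --device
argument (cpu or gpu) in addition to the model path. Assuming the binary is
built as build/bin/test-backend-sampler and a GGUF model is available at
models/test.gguf (both are placeholder paths that depend on the local build
setup), it could be invoked as:

    # default: run on the CPU backend
    ./build/bin/test-backend-sampler models/test.gguf

    # request a GPU backend; the test exits with an error if none is available
    ./build/bin/test-backend-sampler --device gpu models/test.gguf

    # the '=' forms (--device=gpu, --model=PATH) are parsed as well
    ./build/bin/test-backend-sampler --device=gpu --model=models/test.gguf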