tests : add --device option support to backend sampler tests

This commit adds support for specifying the device (cpu or gpu) on which the backend sampler tests run, via a new --device option.
Daniel Bevenius 2025-12-17 15:27:23 +01:00
parent a519aea35c
commit 981475fedc
1 changed file with 95 additions and 62 deletions

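The core of the change is the device-selection logic in load_model() below. As a standalone illustration (not part of the commit, and assuming ggml-backend.h is available and backends have already been registered, e.g. via llama_backend_init()), the pattern looks roughly like this:

```cpp
// Sketch only: resolve a ggml device by type and report its name, failing
// when the requested device type is not available.
#include <cstdio>
#include "ggml-backend.h"

static ggml_backend_dev_t pick_device(bool want_gpu) {
    ggml_backend_dev_t dev = ggml_backend_dev_by_type(
        want_gpu ? GGML_BACKEND_DEVICE_TYPE_GPU : GGML_BACKEND_DEVICE_TYPE_CPU);
    if (dev == nullptr) {
        fprintf(stderr, "Error: requested device type not available\n");
        return nullptr;
    }
    fprintf(stderr, "Using device: %s\n", ggml_backend_dev_name(dev));
    return dev;
}
```

Note that the commit fails hard when a GPU is requested but none is found, rather than silently falling back to the CPU, and the GPU path sets n_gpu_layers = 999 so that all layers are offloaded.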

@@ -11,12 +11,19 @@
#include <algorithm>
#include <cstdlib>
#include <cstring>
#include <array>
#include <iostream>
#include <fstream>
#include <map>
#include <string>
#include <unordered_map>
#include <vector>
struct backend_cli_args {
const char * model = nullptr;
const char * test = nullptr;
const char * device = "cpu";
};
struct test_model_context {
llama_model_ptr model;
llama_context_ptr ctx;
@@ -25,25 +32,39 @@ struct test_model_context {
std::unordered_map<llama_seq_id, int32_t> seq_positions;
std::unordered_map<llama_seq_id, int32_t> last_batch_info;
bool load_model(const char * model_path) {
bool load_model(const backend_cli_args & args) {
if (model) {
return true;
}
llama_backend_init();
// force CPU backend since it always supports all ggml operations
ggml_backend_dev_t devs[2];
devs[0] = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
devs[1] = nullptr;
auto mparams = llama_model_default_params();
mparams.devices = devs;
model.reset(llama_model_load_from_file(model_path, mparams));
ggml_backend_dev_t devs[2];
if (std::string_view(args.device) == "gpu") {
ggml_backend_dev_t gpu = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_GPU);
if (gpu == nullptr) {
fprintf(stderr, "Error: GPU requested but not available\n");
return false;
}
devs[0] = gpu;
devs[1] = nullptr; // null terminator
mparams.devices = devs;
mparams.n_gpu_layers = 999;
} else if (std::string_view(args.device) == "cpu") {
ggml_backend_dev_t cpu = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
devs[0] = cpu;
devs[1] = nullptr; // null terminator
mparams.devices = devs;
}
fprintf(stderr, "Using device: %s\n", ggml_backend_dev_name(devs[0]));
model.reset(llama_model_load_from_file(args.model, mparams));
if (!model) {
fprintf(stderr, "Warning: failed to load model '%s', skipping test\n", model_path);
fprintf(stderr, "Warning: failed to load model '%s', skipping test\n", args.model);
return false;
}
n_vocab = llama_vocab_n_tokens(get_vocab());
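For context, ggml_backend_dev_by_type() as used above returns a single device of the requested type. A small sketch like the following (not part of the commit, same assumptions as before) can help verify which devices are actually registered when choosing between cpu and gpu:

```cpp
// Illustrative only: enumerate the registered ggml devices and their types.
#include <cstdio>
#include "ggml-backend.h"

static void list_devices() {
    for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
        ggml_backend_dev_t dev = ggml_backend_dev_get(i);
        fprintf(stderr, "device %zu: %s (type %d)\n",
                i, ggml_backend_dev_name(dev), (int) ggml_backend_dev_type(dev));
    }
}
```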
@@ -52,9 +73,9 @@ struct test_model_context {
return true;
}
bool setup(const char * model_path, std::vector<llama_sampler_seq_config> & configs, int32_t n_seq_max = -1) {
bool setup(const backend_cli_args & args, std::vector<llama_sampler_seq_config> & configs, int32_t n_seq_max = -1) {
if (!model) {
load_model(model_path);
load_model(args);
}
if (ctx) {
@@ -257,7 +278,7 @@ struct test_model_context {
};
static void test_backend_greedy_sampling(const char * model_path) {
static void test_backend_greedy_sampling(const backend_cli_args & args) {
test_model_context test_ctx;
const int seq_id = 0;
@@ -268,7 +289,7 @@ static void test_backend_greedy_sampling(const char * model_path) {
llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_greedy());
std::vector<llama_sampler_seq_config> backend_sampler_configs = {{ seq_id, backend_sampler_chain.get() }};
if (!test_ctx.setup(model_path, backend_sampler_configs)) {
if (!test_ctx.setup(args, backend_sampler_configs)) {
return;
}
@@ -296,7 +317,7 @@ static void test_backend_greedy_sampling(const char * model_path) {
}
}
static void test_backend_top_k_sampling(const char * model_path) {
static void test_backend_top_k_sampling(const backend_cli_args & args) {
test_model_context test_ctx;
const int seq_id = 0;
@@ -306,7 +327,7 @@ static void test_backend_top_k_sampling(const char * model_path) {
llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_top_k(k));
std::vector<llama_sampler_seq_config> backend_sampler_configs = {{ seq_id, backend_sampler_chain.get() }};
if (!test_ctx.setup(model_path, backend_sampler_configs)) {
if (!test_ctx.setup(args, backend_sampler_configs)) {
return;
}
@@ -343,9 +364,10 @@ static void test_backend_top_k_sampling(const char * model_path) {
printf("backend top-k hybrid sampling test PASSED\n");
}
static void test_backend_temp_sampling(const char * model_path) {
static void test_backend_temp_sampling(const backend_cli_args & args) {
test_model_context test_ctx;
{
const float temp_0 = 0.8f;
struct llama_sampler_chain_params backend_chain_params_0 = llama_sampler_chain_default_params();
@@ -362,7 +384,7 @@ static void test_backend_temp_sampling(const char * model_path) {
{ 1, backend_sampler_chain_1.get() }
};
if (!test_ctx.setup(model_path, backend_sampler_configs)) {
if (!test_ctx.setup(args, backend_sampler_configs)) {
return;
}
@@ -419,7 +441,7 @@ static void test_backend_temp_sampling(const char * model_path) {
{ seq_id, backend_sampler_chain.get() },
};
if (!test_ctx.setup(model_path, backend_sampler_configs)) {
if (!test_ctx.setup(args, backend_sampler_configs)) {
return;
}
@@ -440,7 +462,7 @@ static void test_backend_temp_sampling(const char * model_path) {
}
static void test_backend_temp_ext_sampling(const char * model_path) {
static void test_backend_temp_ext_sampling(const backend_cli_args & args) {
test_model_context test_ctx;
{
@@ -456,7 +478,7 @@ static void test_backend_temp_ext_sampling(const char * model_path) {
{ seq_id, backend_sampler_chain.get() },
};
if (!test_ctx.setup(model_path, backend_sampler_configs)) {
if (!test_ctx.setup(args, backend_sampler_configs)) {
return;
}
@@ -489,7 +511,7 @@ static void test_backend_temp_ext_sampling(const char * model_path) {
{ seq_id, backend_sampler_chain.get() },
};
if (!test_ctx.setup(model_path, backend_sampler_configs)) {
if (!test_ctx.setup(args, backend_sampler_configs)) {
return;
}
@@ -516,7 +538,7 @@ static void test_backend_temp_ext_sampling(const char * model_path) {
}
static void test_backend_min_p_sampling(const char * model_path) {
static void test_backend_min_p_sampling(const backend_cli_args & args) {
test_model_context test_ctx;
const int seq_id = 0;
@@ -526,7 +548,7 @@ static void test_backend_min_p_sampling(const char * model_path) {
llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_min_p(p, 0));
std::vector<llama_sampler_seq_config> backend_sampler_configs = {{ seq_id, backend_sampler_chain.get() }};
if (!test_ctx.setup(model_path, backend_sampler_configs)) {
if (!test_ctx.setup(args, backend_sampler_configs)) {
return;
}
@@ -572,7 +594,7 @@ static void test_backend_min_p_sampling(const char * model_path) {
printf("min-p sampling test PASSED\n");
}
static void test_backend_top_p_sampling(const char * model_path) {
static void test_backend_top_p_sampling(const backend_cli_args & args) {
test_model_context test_ctx;
const int seq_id = 0;
@@ -582,7 +604,7 @@ static void test_backend_top_p_sampling(const char * model_path) {
llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_top_p(p, 0));
std::vector<llama_sampler_seq_config> backend_sampler_configs = {{ seq_id, backend_sampler_chain.get() }};
if (!test_ctx.setup(model_path, backend_sampler_configs)) {
if (!test_ctx.setup(args, backend_sampler_configs)) {
return;
}
@@ -626,7 +648,7 @@ static void test_backend_top_p_sampling(const char * model_path) {
printf("top-p sampling test PASSED\n");
}
static void test_backend_multi_sequence_sampling(const char * model_path) {
static void test_backend_multi_sequence_sampling(const backend_cli_args & args) {
test_model_context test_ctx;
struct llama_sampler_chain_params chain_params_0 = llama_sampler_chain_default_params();
@@ -643,7 +665,7 @@ static void test_backend_multi_sequence_sampling(const char * model_path) {
{ 1, sampler_chain_1.get() }
};
if (!test_ctx.setup(model_path, backend_sampler_configs)) {
if (!test_ctx.setup(args, backend_sampler_configs)) {
return;
}
@@ -696,7 +718,7 @@ static void test_backend_multi_sequence_sampling(const char * model_path) {
printf("backend multi-sequence sampling test PASSED\n");
}
static void test_backend_dist_sampling(const char * model_path) {
static void test_backend_dist_sampling(const backend_cli_args & args) {
test_model_context test_ctx;
const int seq_id = 189;
@@ -706,7 +728,7 @@ static void test_backend_dist_sampling(const char * model_path) {
llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_dist(seed));
std::vector<llama_sampler_seq_config> backend_sampler_configs = {{ seq_id, backend_sampler_chain.get() }};
if (!test_ctx.setup(model_path, backend_sampler_configs)) {
if (!test_ctx.setup(args, backend_sampler_configs)) {
return;
}
@@ -727,7 +749,7 @@ static void test_backend_dist_sampling(const char * model_path) {
printf("backend dist sampling test PASSED\n");
}
static void test_backend_dist_sampling_and_cpu(const char * model_path) {
static void test_backend_dist_sampling_and_cpu(const backend_cli_args & args) {
test_model_context test_ctx;
const int seq_id = 0;
@@ -737,7 +759,7 @@ static void test_backend_dist_sampling_and_cpu(const char * model_path) {
llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_dist(seed));
std::vector<llama_sampler_seq_config> backend_sampler_configs = {{ seq_id, backend_sampler_chain.get() }};
if (!test_ctx.setup(model_path, backend_sampler_configs)) {
if (!test_ctx.setup(args, backend_sampler_configs)) {
return;
}
@@ -760,11 +782,11 @@ static void test_backend_dist_sampling_and_cpu(const char * model_path) {
printf("backend dist & cpu sampling test PASSED\n");
}
static void test_backend_logit_bias_sampling(const char * model_path) {
static void test_backend_logit_bias_sampling(const backend_cli_args & args) {
test_model_context test_ctx;
// Calling load_model to ensure vocab is loaded and can be accessed
if (!test_ctx.load_model(model_path)) {
if (!test_ctx.load_model(args)) {
return;
}
@@ -793,7 +815,7 @@ static void test_backend_logit_bias_sampling(const char * model_path) {
{ seq_id, backend_sampler_chain.get() },
};
if (!test_ctx.setup(model_path, backend_sampler_configs)) {
if (!test_ctx.setup(args, backend_sampler_configs)) {
return;
}
@@ -811,7 +833,7 @@ static void test_backend_logit_bias_sampling(const char * model_path) {
// This test verifies that it is possible to have two different backend samplers,
// one that uses the backend dist sampler, and another that uses the CPU dist sampler.
static void test_backend_mixed_sampling(const char * model_path) {
static void test_backend_mixed_sampling(const backend_cli_args & args) {
test_model_context test_ctx;
struct llama_sampler_chain_params chain_params_0 = llama_sampler_chain_default_params();
@@ -828,7 +850,7 @@ static void test_backend_mixed_sampling(const char * model_path) {
{ 1, sampler_chain_1.get() }
};
if (!test_ctx.setup(model_path, backend_sampler_configs)) {
if (!test_ctx.setup(args, backend_sampler_configs)) {
return;
}
@@ -865,7 +887,7 @@ static void test_backend_mixed_sampling(const char * model_path) {
printf("backend mixed sampling test PASSED\n");
}
static void test_backend_set_sampler(const char * model_path) {
static void test_backend_set_sampler(const backend_cli_args & args) {
test_model_context test_ctx;
const int32_t seed = 88;
@@ -875,7 +897,7 @@ static void test_backend_set_sampler(const char * model_path) {
llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_dist(seed));
std::vector<llama_sampler_seq_config> backend_sampler_configs = {{ seq_id, backend_sampler_chain.get() }};
if (!test_ctx.setup(model_path, backend_sampler_configs)) {
if (!test_ctx.setup(args, backend_sampler_configs)) {
return;
}
@@ -933,7 +955,7 @@ static void test_backend_set_sampler(const char * model_path) {
printf("backend set sampler test PASSED\n");
}
static void test_backend_cpu_mixed_batch(const char * model_path) {
static void test_backend_cpu_mixed_batch(const backend_cli_args & args) {
test_model_context test_ctx;
// Sequence 0 uses backend sampling
@@ -946,7 +968,7 @@ static void test_backend_cpu_mixed_batch(const char * model_path) {
};
// We need 2 sequences: seq 0 with backend sampling, seq 1 with CPU sampling
if (!test_ctx.setup(model_path, backend_sampler_configs, 2)) {
if (!test_ctx.setup(args, backend_sampler_configs, 2)) {
return;
}
@@ -1025,7 +1047,7 @@ static void test_backend_cpu_mixed_batch(const char * model_path) {
printf("backend-cpu mixed batch test PASSED\n");
}
static void test_backend_max_outputs(const char * model_path) {
static void test_backend_max_outputs(const backend_cli_args & args) {
test_model_context test_ctx;
const int seq_id = 0;
@@ -1035,7 +1057,7 @@ static void test_backend_max_outputs(const char * model_path) {
llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_dist(seed));
std::vector<llama_sampler_seq_config> backend_sampler_configs = {{ seq_id, backend_sampler_chain.get() }};
if (!test_ctx.setup(model_path, backend_sampler_configs)) {
if (!test_ctx.setup(args, backend_sampler_configs)) {
return;
}
@@ -1069,7 +1091,7 @@ static void test_backend_max_outputs(const char * model_path) {
struct backend_test_case {
const char * name;
void (*fn)(const char *);
void (*fn)(const backend_cli_args &);
bool enabled_by_default;
};
@@ -1090,11 +1112,6 @@ static const backend_test_case BACKEND_TESTS[] = {
{ "top_p", test_backend_top_p_sampling, true },
};
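With the function pointer type in backend_test_case updated, adding a new test only requires the new signature. A hypothetical entry (not in the commit; the test name and function below are placeholders, and the types come from this file) would look like this:

```cpp
// Hypothetical example of extending BACKEND_TESTS with the new signature.
static void test_backend_example(const backend_cli_args & args) {
    test_model_context test_ctx;                   // types from this file
    std::vector<llama_sampler_seq_config> configs; // fill in sampler configs here
    if (!test_ctx.setup(args, configs)) {
        return;
    }
}

// ... and the corresponding table entry:
// { "example", test_backend_example, /*enabled_by_default=*/ false },
```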
struct backend_cli_args {
const char * model = nullptr;
const char * test = nullptr;
};
static backend_cli_args parse_backend_cli(int argc, char ** argv) {
backend_cli_args out;
@@ -1125,6 +1142,18 @@ static backend_cli_args parse_backend_cli(int argc, char ** argv) {
out.model = arg + 8;
continue;
}
if (std::strcmp(arg, "--device") == 0) {
if (i + 1 >= argc) {
fprintf(stderr, "--device expects a value (cpu or gpu)\n");
exit(EXIT_FAILURE);
}
out.device = argv[++i];
continue;
}
if (std::strncmp(arg, "--device=", 9) == 0) {
out.device = arg + 9;
continue;
}
if (!out.model) {
out.model = arg;
continue;
@@ -1134,6 +1163,11 @@ static backend_cli_args parse_backend_cli(int argc, char ** argv) {
exit(EXIT_FAILURE);
}
if (std::strcmp(out.device, "cpu") != 0 && std::strcmp(out.device, "gpu") != 0) {
fprintf(stderr, "Invalid device '%s'. Must be 'cpu' or 'gpu'\n", out.device);
exit(EXIT_FAILURE);
}
return out;
}
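To make the two accepted spellings of the new flag concrete, a small self-check along these lines could be written (not part of the commit; the binary name and model path are placeholders, and it assumes parse_backend_cli and backend_cli_args from this file are in scope):

```cpp
// Illustrative only: both `--device <value>` and `--device=<value>` should
// populate args.device; anything other than cpu/gpu is rejected above.
#include <cassert>
#include <cstring>

static void check_device_flag_forms() {
    char prog[]        = "test-backend-sampler"; // placeholder binary name
    char model[]       = "model.gguf";           // placeholder model path
    char flag_split[]  = "--device";
    char gpu[]         = "gpu";
    char flag_joined[] = "--device=cpu";

    char * argv_split[]  = { prog, flag_split, gpu, model };
    char * argv_joined[] = { prog, flag_joined, model };

    assert(std::strcmp(parse_backend_cli(4, argv_split).device,  "gpu") == 0);
    assert(std::strcmp(parse_backend_cli(3, argv_joined).device, "cpu") == 0);
}
```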
@@ -1169,35 +1203,34 @@ static std::vector<const backend_test_case *> collect_tests_to_run(const char *
return selected;
}
static void run_tests(const std::vector<const backend_test_case *> & tests, const char * model_path) {
static void run_tests(const std::vector<const backend_test_case *> & tests, const backend_cli_args & args) {
for (const auto * test : tests) {
fprintf(stderr, "\n=== %s ===\n", test->name);
test->fn(model_path);
test->fn(args);
}
}
int main(int argc, char *argv[] ) {
const backend_cli_args args = parse_backend_cli(argc, argv);
int main(int argc, char ** argv) {
backend_cli_args args = parse_backend_cli(argc, argv);
std::array<char *, 2> model_argv { argv[0], const_cast<char *>(args.model) };
const int model_argc = args.model ? 2 : 1;
char * model_path = get_model_or_exit(model_argc, model_argv.data());
if (args.model == nullptr) {
args.model = get_model_or_exit(1, argv);
}
auto * file = fopen(model_path, "r");
if (file == nullptr) {
fprintf(stderr, "no model at '%s' found\n", model_path);
std::ifstream file(args.model);
if (!file.is_open()) {
fprintf(stderr, "no model '%s' found\n", args.model);
return EXIT_FAILURE;
}
fprintf(stderr, "using '%s'\n", model_path);
fclose(file);
fprintf(stderr, "using '%s'\n", args.model);
ggml_time_init();
const std::vector<const backend_test_case *> tests = collect_tests_to_run(args.test);
if (!tests.empty()) {
run_tests(tests, model_path);
run_tests(tests, args);
}
return 0;