tests : add --device option support to backend sampler tests
This commit adds support for specifying the device (cpu or gpu) that the backend sampler tests run on; the device defaults to cpu.
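Example invocations (a minimal sketch; the binary name and model path are illustrative assumptions, not part of the commit). Both the separate-argument form and the --device= form below come from the new parser:

    ./test-backend-sampler --model=models/test-model.gguf --device gpu
    ./test-backend-sampler models/test-model.gguf --device=cpu

Any value other than cpu or gpu is rejected by parse_backend_cli.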
parent a519aea35c
commit 981475fedc
@@ -11,12 +11,19 @@
 #include <algorithm>
+#include <cstdlib>
 #include <cstring>
-#include <array>
 #include <iostream>
+#include <fstream>
 #include <map>
 #include <string>
 #include <unordered_map>
 #include <vector>

+struct backend_cli_args {
+    const char * model = nullptr;
+    const char * test = nullptr;
+    const char * device = "cpu";
+};
+
 struct test_model_context {
     llama_model_ptr model;
     llama_context_ptr ctx;
@@ -25,25 +32,39 @@ struct test_model_context {
     std::unordered_map<llama_seq_id, int32_t> seq_positions;
     std::unordered_map<llama_seq_id, int32_t> last_batch_info;

-    bool load_model(const char * model_path) {
+    bool load_model(const backend_cli_args & args) {
         if (model) {
             return true;
         }

         llama_backend_init();

-        // force CPU backend since it always supports all ggml operations
-        ggml_backend_dev_t devs[2];
-        devs[0] = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
-        devs[1] = nullptr;
-
         auto mparams = llama_model_default_params();
-        mparams.devices = devs;

-        model.reset(llama_model_load_from_file(model_path, mparams));
+        ggml_backend_dev_t devs[2];
+        if (std::string_view(args.device) == "gpu") {
+            ggml_backend_dev_t gpu = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_GPU);
+            if (gpu == nullptr) {
+                fprintf(stderr, "Error: GPU requested but not available\n");
+                return false;
+            }
+            devs[0] = gpu;
+            devs[1] = nullptr; // null terminator
+            mparams.devices = devs;
+            mparams.n_gpu_layers = 999;
+        } else if (std::string_view(args.device) == "cpu") {
+            ggml_backend_dev_t cpu = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
+            devs[0] = cpu;
+            devs[1] = nullptr; // null terminator
+            mparams.devices = devs;
+        }
+
+        fprintf(stderr, "Using device: %s\n", ggml_backend_dev_name(devs[0]));
+
+        model.reset(llama_model_load_from_file(args.model, mparams));

         if (!model) {
-            fprintf(stderr, "Warning: failed to load model '%s', skipping test\n", model_path);
+            fprintf(stderr, "Warning: failed to load model '%s', skipping test\n", args.model);
             return false;
         }
         n_vocab = llama_vocab_n_tokens(get_vocab());
@@ -52,9 +73,9 @@ struct test_model_context {
         return true;
     }

-    bool setup(const char * model_path, std::vector<llama_sampler_seq_config> & configs, int32_t n_seq_max = -1) {
+    bool setup(const backend_cli_args & args, std::vector<llama_sampler_seq_config> & configs, int32_t n_seq_max = -1) {
         if (!model) {
-            load_model(model_path);
+            load_model(args);
         }

         if (ctx) {
@@ -257,7 +278,7 @@ struct test_model_context {

 };

-static void test_backend_greedy_sampling(const char * model_path) {
+static void test_backend_greedy_sampling(const backend_cli_args & args) {
     test_model_context test_ctx;

     const int seq_id = 0;
@@ -268,7 +289,7 @@ static void test_backend_greedy_sampling(const char * model_path) {
     llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_greedy());
     std::vector<llama_sampler_seq_config> backend_sampler_configs = {{ seq_id, backend_sampler_chain.get() }};

-    if (!test_ctx.setup(model_path, backend_sampler_configs)) {
+    if (!test_ctx.setup(args, backend_sampler_configs)) {
         return;
     }

@@ -296,7 +317,7 @@ static void test_backend_greedy_sampling(const char * model_path) {
     }
 }

-static void test_backend_top_k_sampling(const char * model_path) {
+static void test_backend_top_k_sampling(const backend_cli_args & args) {
     test_model_context test_ctx;

     const int seq_id = 0;
@@ -306,7 +327,7 @@ static void test_backend_top_k_sampling(const char * model_path) {
     llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_top_k(k));
     std::vector<llama_sampler_seq_config> backend_sampler_configs = {{ seq_id, backend_sampler_chain.get() }};

-    if (!test_ctx.setup(model_path, backend_sampler_configs)) {
+    if (!test_ctx.setup(args, backend_sampler_configs)) {
         return;
     }

@@ -343,9 +364,10 @@ static void test_backend_top_k_sampling(const char * model_path) {
     printf("backend top-k hybrid sampling test PASSED\n");
 }

-static void test_backend_temp_sampling(const char * model_path) {
+static void test_backend_temp_sampling(const backend_cli_args & args) {
     test_model_context test_ctx;

+
     {
         const float temp_0 = 0.8f;
         struct llama_sampler_chain_params backend_chain_params_0 = llama_sampler_chain_default_params();
@@ -362,7 +384,7 @@ static void test_backend_temp_sampling(const char * model_path) {
             { 1, backend_sampler_chain_1.get() }
         };

-        if (!test_ctx.setup(model_path, backend_sampler_configs)) {
+        if (!test_ctx.setup(args, backend_sampler_configs)) {
             return;
         }

@@ -419,7 +441,7 @@ static void test_backend_temp_sampling(const char * model_path) {
             { seq_id, backend_sampler_chain.get() },
         };

-        if (!test_ctx.setup(model_path, backend_sampler_configs)) {
+        if (!test_ctx.setup(args, backend_sampler_configs)) {
             return;
         }

@@ -440,7 +462,7 @@ static void test_backend_temp_sampling(const char * model_path) {

 }

-static void test_backend_temp_ext_sampling(const char * model_path) {
+static void test_backend_temp_ext_sampling(const backend_cli_args & args) {
     test_model_context test_ctx;

     {
@@ -456,7 +478,7 @@ static void test_backend_temp_ext_sampling(const char * model_path) {
             { seq_id, backend_sampler_chain.get() },
         };

-        if (!test_ctx.setup(model_path, backend_sampler_configs)) {
+        if (!test_ctx.setup(args, backend_sampler_configs)) {
             return;
         }

@@ -489,7 +511,7 @@ static void test_backend_temp_ext_sampling(const char * model_path) {
             { seq_id, backend_sampler_chain.get() },
         };

-        if (!test_ctx.setup(model_path, backend_sampler_configs)) {
+        if (!test_ctx.setup(args, backend_sampler_configs)) {
             return;
         }

@@ -516,7 +538,7 @@ static void test_backend_temp_ext_sampling(const char * model_path) {

 }

-static void test_backend_min_p_sampling(const char * model_path) {
+static void test_backend_min_p_sampling(const backend_cli_args & args) {
     test_model_context test_ctx;

     const int seq_id = 0;
@@ -526,7 +548,7 @@ static void test_backend_min_p_sampling(const char * model_path) {
     llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_min_p(p, 0));
    std::vector<llama_sampler_seq_config> backend_sampler_configs = {{ seq_id, backend_sampler_chain.get() }};

-    if (!test_ctx.setup(model_path, backend_sampler_configs)) {
+    if (!test_ctx.setup(args, backend_sampler_configs)) {
         return;
     }

@@ -572,7 +594,7 @@ static void test_backend_min_p_sampling(const char * model_path) {
     printf("min-p sampling test PASSED\n");
 }

-static void test_backend_top_p_sampling(const char * model_path) {
+static void test_backend_top_p_sampling(const backend_cli_args & args) {
     test_model_context test_ctx;

     const int seq_id = 0;
@@ -582,7 +604,7 @@ static void test_backend_top_p_sampling(const char * model_path) {
     llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_top_p(p, 0));
     std::vector<llama_sampler_seq_config> backend_sampler_configs = {{ seq_id, backend_sampler_chain.get() }};

-    if (!test_ctx.setup(model_path, backend_sampler_configs)) {
+    if (!test_ctx.setup(args, backend_sampler_configs)) {
         return;
     }

@@ -626,7 +648,7 @@ static void test_backend_top_p_sampling(const char * model_path) {
     printf("top-p sampling test PASSED\n");
 }

-static void test_backend_multi_sequence_sampling(const char * model_path) {
+static void test_backend_multi_sequence_sampling(const backend_cli_args & args) {
     test_model_context test_ctx;

     struct llama_sampler_chain_params chain_params_0 = llama_sampler_chain_default_params();
@@ -643,7 +665,7 @@ static void test_backend_multi_sequence_sampling(const char * model_path) {
         { 1, sampler_chain_1.get() }
     };

-    if (!test_ctx.setup(model_path, backend_sampler_configs)) {
+    if (!test_ctx.setup(args, backend_sampler_configs)) {
         return;
     }

@@ -696,7 +718,7 @@ static void test_backend_multi_sequence_sampling(const char * model_path) {
     printf("backend multi-sequence sampling test PASSED\n");
 }

-static void test_backend_dist_sampling(const char * model_path) {
+static void test_backend_dist_sampling(const backend_cli_args & args) {
     test_model_context test_ctx;

     const int seq_id = 189;
@@ -706,7 +728,7 @@ static void test_backend_dist_sampling(const char * model_path) {
     llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_dist(seed));
     std::vector<llama_sampler_seq_config> backend_sampler_configs = {{ seq_id, backend_sampler_chain.get() }};

-    if (!test_ctx.setup(model_path, backend_sampler_configs)) {
+    if (!test_ctx.setup(args, backend_sampler_configs)) {
         return;
     }

@@ -727,7 +749,7 @@ static void test_backend_dist_sampling(const char * model_path) {
     printf("backend dist sampling test PASSED\n");
 }

-static void test_backend_dist_sampling_and_cpu(const char * model_path) {
+static void test_backend_dist_sampling_and_cpu(const backend_cli_args & args) {
     test_model_context test_ctx;

     const int seq_id = 0;
@@ -737,7 +759,7 @@ static void test_backend_dist_sampling_and_cpu(const char * model_path) {
     llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_dist(seed));
     std::vector<llama_sampler_seq_config> backend_sampler_configs = {{ seq_id, backend_sampler_chain.get() }};

-    if (!test_ctx.setup(model_path, backend_sampler_configs)) {
+    if (!test_ctx.setup(args, backend_sampler_configs)) {
         return;
     }

@@ -760,11 +782,11 @@ static void test_backend_dist_sampling_and_cpu(const char * model_path) {
     printf("backend dist & cpu sampling test PASSED\n");
 }

-static void test_backend_logit_bias_sampling(const char * model_path) {
+static void test_backend_logit_bias_sampling(const backend_cli_args & args) {
     test_model_context test_ctx;

     // Calling load_model to ensure vocab is loaded and can be accessed
-    if (!test_ctx.load_model(model_path)) {
+    if (!test_ctx.load_model(args)) {
         return;
     }

@@ -793,7 +815,7 @@ static void test_backend_logit_bias_sampling(const char * model_path) {
         { seq_id, backend_sampler_chain.get() },
     };

-    if (!test_ctx.setup(model_path, backend_sampler_configs)) {
+    if (!test_ctx.setup(args, backend_sampler_configs)) {
         return;
     }

@@ -811,7 +833,7 @@ static void test_backend_logit_bias_sampling(const char * model_path) {

 // This test verifies that it is possible to have two different backend sampler,
 // one that uses the backend dist sampler, and another that uses CPU dist sampler.
-static void test_backend_mixed_sampling(const char * model_path) {
+static void test_backend_mixed_sampling(const backend_cli_args & args) {
     test_model_context test_ctx;

     struct llama_sampler_chain_params chain_params_0 = llama_sampler_chain_default_params();
@@ -828,7 +850,7 @@ static void test_backend_mixed_sampling(const char * model_path) {
         { 1, sampler_chain_1.get() }
     };

-    if (!test_ctx.setup(model_path, backend_sampler_configs)) {
+    if (!test_ctx.setup(args, backend_sampler_configs)) {
         return;
     }

@@ -865,7 +887,7 @@ static void test_backend_mixed_sampling(const char * model_path) {
     printf("backend mixed sampling test PASSED\n");
 }

-static void test_backend_set_sampler(const char * model_path) {
+static void test_backend_set_sampler(const backend_cli_args & args) {
     test_model_context test_ctx;

     const int32_t seed = 88;
@@ -875,7 +897,7 @@ static void test_backend_set_sampler(const char * model_path) {
     llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_dist(seed));
     std::vector<llama_sampler_seq_config> backend_sampler_configs = {{ seq_id, backend_sampler_chain.get() }};

-    if (!test_ctx.setup(model_path, backend_sampler_configs)) {
+    if (!test_ctx.setup(args, backend_sampler_configs)) {
         return;
     }

@@ -933,7 +955,7 @@ static void test_backend_set_sampler(const char * model_path) {
     printf("backend set sampler test PASSED\n");
 }

-static void test_backend_cpu_mixed_batch(const char * model_path) {
+static void test_backend_cpu_mixed_batch(const backend_cli_args & args) {
     test_model_context test_ctx;

     // Sequence 0 uses backend sampling
@@ -946,7 +968,7 @@ static void test_backend_cpu_mixed_batch(const char * model_path) {
     };

     // We need 2 sequences: seq 0 with backend sampling, seq 1 with CPU sampling
-    if (!test_ctx.setup(model_path, backend_sampler_configs, 2)) {
+    if (!test_ctx.setup(args, backend_sampler_configs, 2)) {
         return;
     }

@@ -1025,7 +1047,7 @@ static void test_backend_cpu_mixed_batch(const char * model_path) {
     printf("backend-cpu mixed batch test PASSED\n");
 }

-static void test_backend_max_outputs(const char * model_path) {
+static void test_backend_max_outputs(const backend_cli_args & args) {
     test_model_context test_ctx;

     const int seq_id = 0;
@@ -1035,7 +1057,7 @@ static void test_backend_max_outputs(const char * model_path) {
     llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_dist(seed));
     std::vector<llama_sampler_seq_config> backend_sampler_configs = {{ seq_id, backend_sampler_chain.get() }};

-    if (!test_ctx.setup(model_path, backend_sampler_configs)) {
+    if (!test_ctx.setup(args, backend_sampler_configs)) {
         return;
     }

@@ -1069,7 +1091,7 @@ static void test_backend_max_outputs(const char * model_path) {

 struct backend_test_case {
     const char * name;
-    void (*fn)(const char *);
+    void (*fn)(const backend_cli_args &);
     bool enabled_by_default;
 };

@@ -1090,11 +1112,6 @@ static const backend_test_case BACKEND_TESTS[] = {
     { "top_p", test_backend_top_p_sampling, true },
 };

-struct backend_cli_args {
-    const char * model = nullptr;
-    const char * test = nullptr;
-};
-
 static backend_cli_args parse_backend_cli(int argc, char ** argv) {
     backend_cli_args out;

@@ -1125,6 +1142,18 @@ static backend_cli_args parse_backend_cli(int argc, char ** argv) {
             out.model = arg + 8;
             continue;
         }
+        if (std::strcmp(arg, "--device") == 0) {
+            if (i + 1 >= argc) {
+                fprintf(stderr, "--device expects a value (cpu or gpu)\n");
+                exit(EXIT_FAILURE);
+            }
+            out.device = argv[++i];
+            continue;
+        }
+        if (std::strncmp(arg, "--device=", 9) == 0) {
+            out.device = arg + 9;
+            continue;
+        }
         if (!out.model) {
             out.model = arg;
             continue;
@@ -1134,6 +1163,11 @@ static backend_cli_args parse_backend_cli(int argc, char ** argv) {
         exit(EXIT_FAILURE);
     }

+    if (std::strcmp(out.device, "cpu") != 0 && std::strcmp(out.device, "gpu") != 0) {
+        fprintf(stderr, "Invalid device '%s'. Must be 'cpu' or 'gpu'\n", out.device);
+        exit(EXIT_FAILURE);
+    }
+
     return out;
 }

@@ -1169,35 +1203,34 @@ static std::vector<const backend_test_case *> collect_tests_to_run(const char *
     return selected;
 }

-static void run_tests(const std::vector<const backend_test_case *> & tests, const char * model_path) {
+static void run_tests(const std::vector<const backend_test_case *> & tests, const backend_cli_args & args) {
     for (const auto * test : tests) {
         fprintf(stderr, "\n=== %s ===\n", test->name);
-        test->fn(model_path);
+        test->fn(args);
     }
 }


-int main(int argc, char *argv[] ) {
-    const backend_cli_args args = parse_backend_cli(argc, argv);
+int main(int argc, char ** argv) {
+    backend_cli_args args = parse_backend_cli(argc, argv);

-    std::array<char *, 2> model_argv { argv[0], const_cast<char *>(args.model) };
-    const int model_argc = args.model ? 2 : 1;
-    char * model_path = get_model_or_exit(model_argc, model_argv.data());
+    if (args.model == nullptr) {
+        args.model = get_model_or_exit(1, argv);
+    }

-    auto * file = fopen(model_path, "r");
-    if (file == nullptr) {
-        fprintf(stderr, "no model at '%s' found\n", model_path);
+    std::ifstream file(args.model);
+    if (!file.is_open()) {
+        fprintf(stderr, "no model '%s' found\n", args.model);
         return EXIT_FAILURE;
     }

-    fprintf(stderr, "using '%s'\n", model_path);
-    fclose(file);
+    fprintf(stderr, "using '%s'\n", args.model);

     ggml_time_init();

     const std::vector<const backend_test_case *> tests = collect_tests_to_run(args.test);
     if (!tests.empty()) {
-        run_tests(tests, model_path);
+        run_tests(tests, args);
     }

     return 0;