tests : run backend sampler tests always on the CPU
parent 74b112e3e7
commit ab65b47a52
The first set of hunks applies to the sampler implementation:

@@ -617,7 +617,8 @@ private:
 // check if all ggml ops used by the sampler are supported by the backend
 static bool llama_sampler_backend_support(
         llama_sampler * smpl,
-        ggml_backend_dev_t device) {
+        ggml_backend_buffer_type_t buft) {
+    auto * device = ggml_backend_buft_get_device(buft);
     if (!device) {
         // CPU backend always supported
         return true;
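Since the helper now takes a ggml_backend_buffer_type_t and resolves the device itself, call sites no longer need their own ggml_backend_buft_get_device call. A sketch of how a caller could use it (the pick_buft helper below is hypothetical, not part of this commit):

    // Prefer the given buffer type if the sampler's ops are supported there,
    // otherwise fall back to the CPU buffer type, which supports all ggml ops.
    static ggml_backend_buffer_type_t pick_buft(llama_sampler * smpl,
                                                ggml_backend_buffer_type_t preferred) {
        if (llama_sampler_backend_support(smpl, preferred)) {
            return preferred;
        }
        return ggml_backend_cpu_buffer_type();
    }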
@@ -669,6 +670,9 @@ static bool llama_sampler_backend_support(
         struct ggml_tensor * op = ggml_graph_node(gf, i);
 
         if (!ggml_backend_dev_supports_op(device, op)) {
+            LLAMA_LOG_WARN("%s: device '%s' does not have support for op %s needed for sampler '%s'\n",
+                    __func__, ggml_backend_dev_name(device), ggml_op_name(op->op), smpl->iface->name(smpl));
+
             return false;
         }
     }
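For illustration, with a hypothetical non-CPU device named 'Vulkan0' that lacks the SOFT_MAX op, the new warning would come out along these lines (device, op, and sampler names are invented here):

    llama_sampler_backend_support: device 'Vulkan0' does not have support for op SOFT_MAX needed for sampler 'top-k'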
@@ -958,7 +962,7 @@ static bool llama_sampler_greedy_backend_init(
         ggml_backend_buffer_type_t buft) {
     auto * sctx = (llama_sampler_greedy *) smpl->ctx;
 
-    const bool res = llama_sampler_backend_support(smpl, ggml_backend_buft_get_device(buft));
+    const bool res = llama_sampler_backend_support(smpl, buft);
 
     sctx->init(res);
 
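The same one-line simplification repeats below for the dist, top_k, top_p, min_p, temp, and temp_ext samplers. Each *_backend_init now has roughly this shape (a sketch; llama_sampler_example is a stand-in context type, and the return value is assumed since it falls outside the hunks):

    static bool llama_sampler_example_backend_init(
            llama_sampler * smpl,
            ggml_backend_buffer_type_t buft) {
        auto * sctx = (llama_sampler_example *) smpl->ctx;

        // the helper now resolves the device from the buffer type internally
        const bool res = llama_sampler_backend_support(smpl, buft);

        sctx->init(res);
        return res;
    }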
@@ -1142,7 +1146,7 @@ static bool llama_sampler_dist_backend_init(
         sctx->inp_buf.reset(ggml_backend_alloc_ctx_tensors_from_buft(sctx->inp_ctx.get(), buft));
     }
 
-    const bool res = llama_sampler_backend_support(smpl, ggml_backend_buft_get_device(buft));
+    const bool res = llama_sampler_backend_support(smpl, buft);
 
     sctx->init(res);
 
@@ -1274,7 +1278,7 @@ static bool llama_sampler_top_k_backend_init(
         ggml_backend_buffer_type_t buft) {
     auto * sctx = (llama_sampler_top_k *) smpl->ctx;
 
-    const bool res = llama_sampler_backend_support(smpl, ggml_backend_buft_get_device(buft));
+    const bool res = llama_sampler_backend_support(smpl, buft);
 
     sctx->init(res);
 
@@ -1419,7 +1423,7 @@ static bool llama_sampler_top_p_backend_init(
         ggml_backend_buffer_type_t buft) {
     auto * sctx = (llama_sampler_top_p *) smpl->ctx;
 
-    const bool res = llama_sampler_backend_support(smpl, ggml_backend_buft_get_device(buft));
+    const bool res = llama_sampler_backend_support(smpl, buft);
 
     sctx->init(res);
 
@@ -1613,7 +1617,7 @@ static bool llama_sampler_min_p_backend_init(
         ggml_backend_buffer_type_t buft) {
     auto * sctx = (llama_sampler_min_p *) smpl->ctx;
 
-    const bool res = llama_sampler_backend_support(smpl, ggml_backend_buft_get_device(buft));
+    const bool res = llama_sampler_backend_support(smpl, buft);
 
    sctx->init(res);
 
@@ -1861,7 +1865,7 @@ static bool llama_sampler_temp_backend_init(
         ggml_backend_buffer_type_t buft) {
     auto * sctx = (llama_sampler_temp *) smpl->ctx;
 
-    const bool res = llama_sampler_backend_support(smpl, ggml_backend_buft_get_device(buft));
+    const bool res = llama_sampler_backend_support(smpl, buft);
 
     sctx->init(res);
 
@@ -2004,7 +2008,7 @@ static bool llama_sampler_temp_ext_backend_init(
         ggml_backend_buffer_type_t buft) {
     auto * sctx = (llama_sampler_temp_ext *) smpl->ctx;
 
-    const bool res = llama_sampler_backend_support(smpl, ggml_backend_buft_get_device(buft));
+    const bool res = llama_sampler_backend_support(smpl, buft);
 
     sctx->init(res);
 
The remaining hunks apply to the backend sampler test file:

@@ -7,6 +7,7 @@
 #undef NDEBUG
 #endif
 
+#include <algorithm>
 #include <cstdlib>
 #include <cstring>
 #include <array>
@@ -31,7 +32,15 @@ struct test_model_context {
 
         llama_backend_init();
 
-        model = llama_model_load_from_file(model_path, llama_model_default_params());
+        // force CPU backend since it always supports all ggml operations
+        ggml_backend_dev_t devs[2];
+        devs[0] = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
+        devs[1] = nullptr;
+
+        auto mparams = llama_model_default_params();
+        mparams.devices = devs;
+
+        model = llama_model_load_from_file(model_path, mparams);
         if (model == nullptr) {
             fprintf(stderr, "Warning: failed to load model '%s', skipping test\n", model_path);
             cleanup();
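The devices array passed via mparams.devices is nullptr-terminated, so the model loader only considers the CPU device here. If desired, the test could also assert that the device lookup succeeded before loading (a sketch, not part of the commit):

    // confirm the CPU device is actually available before forcing it
    ggml_backend_dev_t cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
    GGML_ASSERT(cpu_dev != nullptr);
    printf("forcing device: %s\n", ggml_backend_dev_name(cpu_dev));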
@@ -63,9 +72,7 @@ struct test_model_context {
         if (n_seq_max < 0) {
             int32_t max_seq_id = 0;
             for (const auto & config : configs) {
-                if (config.seq_id > max_seq_id) {
-                    max_seq_id = config.seq_id;
-                }
+                max_seq_id = std::max(config.seq_id, max_seq_id);
             }
             cparams.n_seq_max = max_seq_id + 1;
         } else {
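The std::max form is what the #include <algorithm> added earlier is for; it computes the same maximum as the removed branch. A standalone illustration with made-up seq_id values:

    #include <algorithm>
    #include <cstdint>

    int32_t ids[] = {0, 3, 1};   // stand-ins for config.seq_id
    int32_t max_seq_id = 0;
    for (int32_t seq_id : ids) {
        max_seq_id = std::max(seq_id, max_seq_id); // replaces the if/assign branch
    }
    // max_seq_id == 3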
@@ -859,6 +866,8 @@ static void test_backend_logit_bias_sampling(const char * model_path) {
     printf("logit bias sampled token = %d, string='%s'\n", backend_token, backend_token_str.c_str());
     GGML_ASSERT(backend_token == bias_token);
 
+    printf("backend logit bias sampling test PASSED\n");
+
     llama_sampler_free(backend_sampler_chain);
 }
 