tests : run backend sampler tests always on the CPU
parent 74b112e3e7
commit ab65b47a52
@@ -616,8 +616,9 @@ private:
 // check if all ggml ops used by the sampler are supported by the backend
 static bool llama_sampler_backend_support(
-        llama_sampler * smpl,
-        ggml_backend_dev_t device) {
+        llama_sampler * smpl,
+        ggml_backend_buffer_type_t buft) {
+    auto * device = ggml_backend_buft_get_device(buft);
     if (!device) {
         // CPU backend always supported
         return true;
     }
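Note: llama_sampler_backend_support() now takes the buffer type and resolves the device itself; when ggml_backend_buft_get_device() returns null (a buffer type with no device attached), the sampler is treated as CPU and accepted before any op checking happens. A minimal sketch of the new call pattern — my_backend_init is hypothetical, not part of this patch:

    // sketch: callers forward the buffer type; device resolution and the
    // no-device (CPU) short-circuit both happen inside the helper
    static bool my_backend_init(llama_sampler * smpl, ggml_backend_buffer_type_t buft) {
        return llama_sampler_backend_support(smpl, buft);
    }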
@@ -669,6 +670,9 @@ static bool llama_sampler_backend_support(
         struct ggml_tensor * op = ggml_graph_node(gf, i);
 
         if (!ggml_backend_dev_supports_op(device, op)) {
+            LLAMA_LOG_WARN("%s: device '%s' does not have support for op %s needed for sampler '%s'\n",
+                    __func__, ggml_backend_dev_name(device), ggml_op_name(op->op), smpl->iface->name(smpl));
+
             return false;
         }
     }
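For context, a simplified sketch of the loop this hunk lands in (reconstructed from the visible fragment, so an assumption rather than the exact source): the sampler's ggml graph gf is walked node by node and each op is checked against the device, with the new warning explaining any fallback:

    for (int i = 0; i < ggml_graph_n_nodes(gf); i++) {
        struct ggml_tensor * op = ggml_graph_node(gf, i);
        if (!ggml_backend_dev_supports_op(device, op)) {
            // the added LLAMA_LOG_WARN fires here, naming the device, op and sampler
            return false;
        }
    }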
@@ -958,7 +962,7 @@ static bool llama_sampler_greedy_backend_init(
         ggml_backend_buffer_type_t buft) {
     auto * sctx = (llama_sampler_greedy *) smpl->ctx;
 
-    const bool res = llama_sampler_backend_support(smpl, ggml_backend_buft_get_device(buft));
+    const bool res = llama_sampler_backend_support(smpl, buft);
 
     sctx->init(res);
 
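The same one-line call-site change — passing buft through instead of pre-resolving the device with ggml_backend_buft_get_device() — repeats for the dist, top_k, top_p, min_p, temp and temp_ext sampler init functions in the hunks below.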
@@ -1142,7 +1146,7 @@ static bool llama_sampler_dist_backend_init(
         sctx->inp_buf.reset(ggml_backend_alloc_ctx_tensors_from_buft(sctx->inp_ctx.get(), buft));
     }
 
-    const bool res = llama_sampler_backend_support(smpl, ggml_backend_buft_get_device(buft));
+    const bool res = llama_sampler_backend_support(smpl, buft);
 
     sctx->init(res);
 
@@ -1274,7 +1278,7 @@ static bool llama_sampler_top_k_backend_init(
         ggml_backend_buffer_type_t buft) {
     auto * sctx = (llama_sampler_top_k *) smpl->ctx;
 
-    const bool res = llama_sampler_backend_support(smpl, ggml_backend_buft_get_device(buft));
+    const bool res = llama_sampler_backend_support(smpl, buft);
 
    sctx->init(res);
 
@@ -1419,7 +1423,7 @@ static bool llama_sampler_top_p_backend_init(
         ggml_backend_buffer_type_t buft) {
     auto * sctx = (llama_sampler_top_p *) smpl->ctx;
 
-    const bool res = llama_sampler_backend_support(smpl, ggml_backend_buft_get_device(buft));
+    const bool res = llama_sampler_backend_support(smpl, buft);
 
     sctx->init(res);
 
@@ -1613,7 +1617,7 @@ static bool llama_sampler_min_p_backend_init(
         ggml_backend_buffer_type_t buft) {
     auto * sctx = (llama_sampler_min_p *) smpl->ctx;
 
-    const bool res = llama_sampler_backend_support(smpl, ggml_backend_buft_get_device(buft));
+    const bool res = llama_sampler_backend_support(smpl, buft);
 
     sctx->init(res);
 
@@ -1861,7 +1865,7 @@ static bool llama_sampler_temp_backend_init(
         ggml_backend_buffer_type_t buft) {
     auto * sctx = (llama_sampler_temp *) smpl->ctx;
 
-    const bool res = llama_sampler_backend_support(smpl, ggml_backend_buft_get_device(buft));
+    const bool res = llama_sampler_backend_support(smpl, buft);
 
     sctx->init(res);
 
@@ -2004,7 +2008,7 @@ static bool llama_sampler_temp_ext_backend_init(
         ggml_backend_buffer_type_t buft) {
     auto * sctx = (llama_sampler_temp_ext *) smpl->ctx;
 
-    const bool res = llama_sampler_backend_support(smpl, ggml_backend_buft_get_device(buft));
+    const bool res = llama_sampler_backend_support(smpl, buft);
 
     sctx->init(res);
 
@@ -7,6 +7,7 @@
 #undef NDEBUG
 #endif
 
+#include <algorithm>
 #include <cstdlib>
 #include <cstring>
 #include <array>
@@ -31,7 +32,15 @@ struct test_model_context {
 
         llama_backend_init();
 
-        model = llama_model_load_from_file(model_path, llama_model_default_params());
+        // force CPU backend since it always supports all ggml operations
+        ggml_backend_dev_t devs[2];
+        devs[0] = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
+        devs[1] = nullptr;
+
+        auto mparams = llama_model_default_params();
+        mparams.devices = devs;
+
+        model = llama_model_load_from_file(model_path, mparams);
         if (model == nullptr) {
             fprintf(stderr, "Warning: failed to load model '%s', skipping test\n", model_path);
             cleanup();
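Since llama_model_params.devices is read as a null-terminated list, restricting it to the CPU device means every sampler support check runs against a backend that, as the new comment notes, supports all ggml operations, so these tests behave the same regardless of which GPU backends the binary was built with.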
@@ -63,9 +72,7 @@ struct test_model_context {
         if (n_seq_max < 0) {
             int32_t max_seq_id = 0;
             for (const auto & config : configs) {
-                if (config.seq_id > max_seq_id) {
-                    max_seq_id = config.seq_id;
-                }
+                max_seq_id = std::max(config.seq_id, max_seq_id);
             }
             cparams.n_seq_max = max_seq_id + 1;
         } else {
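This std::max is what the #include <algorithm> added in the earlier hunk is for; the hand-rolled compare-and-assign it replaces needed no extra header.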
@@ -859,6 +866,8 @@ static void test_backend_logit_bias_sampling(const char * model_path) {
     printf("logit bias sampled token = %d, string='%s'\n", backend_token, backend_token_str.c_str());
     GGML_ASSERT(backend_token == bias_token);
 
+    printf("backend logit bias sampling test PASSED\n");
+
     llama_sampler_free(backend_sampler_chain);
 }