sampling : generic ggml op support detection

Georgi Gerganov 2025-12-11 13:19:43 +02:00
parent d5d16651a8
commit 8544aba37f
1 changed file with 115 additions and 85 deletions

@@ -614,6 +614,68 @@ private:
bool support;
};
// check if all ggml ops used by the sampler are supported by the backend
static bool llama_sampler_backend_support(
llama_sampler * smpl,
ggml_backend_dev_t device) {
if (!device) {
// CPU backend always supported
return true;
}
ggml_init_params params = {
/*.mem_size =*/ 128*ggml_tensor_overhead() + ggml_graph_overhead(),
/*.mem_buffer =*/ NULL,
/*.no_alloc =*/ true,
};
ggml_context_ptr ctx_ptr { ggml_init(params) };
if (!ctx_ptr) {
throw std::runtime_error(format("failed to create ggml context"));
}
ggml_context * ctx = ctx_ptr.get();
const int64_t n = 1024*1024;
llama_sampler_data data = {
/*.logits = */ ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n),
/*.probs = */ nullptr,
/*.sampled = */ nullptr,
/*.candidates = */ ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n),
};
ggml_cgraph * gf = ggml_new_graph(ctx);
smpl->iface->backend_apply(smpl, ctx, gf, &data);
if (data.logits) {
ggml_build_forward_expand(gf, data.logits);
}
if (data.probs) {
ggml_build_forward_expand(gf, data.probs);
}
if (data.sampled) {
ggml_build_forward_expand(gf, data.sampled);
}
if (data.candidates) {
ggml_build_forward_expand(gf, data.candidates);
}
for (int i = 0; i < ggml_graph_n_nodes(gf); i++) {
struct ggml_tensor * op = ggml_graph_node(gf, i);
if (!ggml_backend_dev_supports_op(device, op)) {
return false;
}
}
return true;
}
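
// Note on the helper above (illustrative, not part of this commit): because
// the context is created with .no_alloc = true, the 1M-element dummy logits
// and candidates tensors only reserve metadata and are never actually
// allocated; the throwaway graph exists purely so that every op the sampler
// records can be passed to ggml_backend_dev_supports_op.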
// sampler chain
static const char * llama_sampler_chain_name(const struct llama_sampler * /*smpl*/) {
@@ -850,8 +912,12 @@ int llama_sampler_chain_n(const struct llama_sampler * chain) {
// greedy
static const char * llama_sampler_greedy_name(const struct llama_sampler * /*smpl*/) {
return "greedy";
struct llama_sampler_greedy : public llama_sampler_backend {
};
static const char * llama_sampler_greedy_name(const struct llama_sampler * smpl) {
auto * sctx = (llama_sampler_greedy *) smpl->ctx;
return sctx->get_name();
}
static void llama_sampler_greedy_apply(struct llama_sampler * /*smpl*/, llama_token_data_array * cur_p) {
@@ -866,10 +932,13 @@ static void llama_sampler_greedy_apply(struct llama_sampler * /*smpl*/, llama_to
static bool llama_sampler_greedy_backend_init(
struct llama_sampler * smpl,
ggml_backend_buffer_type_t buft) {
GGML_UNUSED(smpl);
GGML_UNUSED(buft);
auto * sctx = (llama_sampler_greedy *) smpl->ctx;
return true;
const bool res = llama_sampler_backend_support(smpl, ggml_backend_buft_get_device(buft));
sctx->init(res);
return res;
}
static void llama_sampler_greedy_backend_apply(
@@ -879,9 +948,11 @@ static void llama_sampler_greedy_backend_apply(
struct llama_sampler_data * data) {
GGML_UNUSED(gf);
GGML_UNUSED(smpl);
struct ggml_tensor * argmax_result = ggml_argmax(ctx, data->logits);
ggml_set_name(argmax_result, "argmax_result");
data->sampled = argmax_result;
struct ggml_tensor * cur = ggml_argmax(ctx, data->logits);
ggml_set_name(cur, "greedy_argmax");
data->sampled = cur;
}
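
// Note (illustrative, not part of this commit): ggml_argmax produces an I32
// tensor holding the index of the maximum of each row of the logits, so
// backend-side greedy sampling materializes no probabilities; the argmax
// result is the sampled token.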
static struct llama_sampler_i llama_sampler_greedy_i = {
@@ -900,7 +971,9 @@ static struct llama_sampler_i llama_sampler_greedy_i = {
struct llama_sampler * llama_sampler_init_greedy() {
return llama_sampler_init(
/* .iface = */ &llama_sampler_greedy_i,
/* .ctx = */ nullptr
/* .ctx = */ new llama_sampler_greedy {
("greedy"),
}
);
}
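
// Illustrative usage, not part of this commit, assuming the public sampler
// chain API from llama.h ("lctx" is a placeholder for an existing
// llama_context):

llama_sampler * chain = llama_sampler_chain_init(llama_sampler_chain_default_params());
llama_sampler_chain_add(chain, llama_sampler_init_greedy());

const llama_token tok = llama_sampler_sample(chain, lctx, -1); // sample from the last logits

llama_sampler_free(chain);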
@@ -1025,36 +1098,8 @@ static bool llama_sampler_dist_backend_init(
ggml_backend_buffer_type_t buft) {
auto * sctx = (llama_sampler_dist *) smpl->ctx;
bool res = true;
// determine backend support
// allocate inputs
{
ggml_init_params params = {
/*.mem_size =*/ ggml_tensor_overhead()*8,
/*.mem_buffer =*/ NULL,
/*.no_alloc =*/ true,
};
ggml_context_ptr ctx_ptr { ggml_init(params) };
if (!ctx_ptr) {
throw std::runtime_error(format("failed to create ggml context"));
}
ggml_context * ctx = ctx_ptr.get();
ggml_tensor * probs = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1024*1024);
ggml_tensor * op = ggml_cumsum(ctx, probs);
auto * device = ggml_backend_buft_get_device(buft);
if (device && !ggml_backend_dev_supports_op(device, op)) {
res = false;
}
sctx->init(res);
}
if (res) {
ggml_init_params params = {
/*.mem_size =*/ ggml_tensor_overhead(),
/*.mem_buffer =*/ nullptr,
@@ -1066,13 +1111,22 @@ static bool llama_sampler_dist_backend_init(
// Create the uniform random scalar input tensor. This will be set by
// llama_sampler_dist_backend_set_input after this graph is built.
sctx->inp_uniform = ggml_new_tensor_1d(sctx->inp_ctx.get(), GGML_TYPE_F32, 1);
ggml_set_name(sctx->inp_uniform, "uniform");
ggml_set_name (sctx->inp_uniform, "uniform");
ggml_set_input(sctx->inp_uniform);
// Allocate all tensors from our context to the backend
sctx->inp_buf.reset(ggml_backend_alloc_ctx_tensors_from_buft(sctx->inp_ctx.get(), buft));
}
const bool res = llama_sampler_backend_support(smpl, ggml_backend_buft_get_device(buft));
sctx->init(res);
if (!res) {
sctx->inp_ctx.reset(nullptr);
sctx->inp_buf.reset(nullptr);
}
return res;
}
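
// Illustrative sketch, not part of this commit, of how the "uniform" input
// is expected to be refreshed before each run (per the comment above;
// "sctx->rng" is an assumed member name and <random> is required):

float u = std::uniform_real_distribution<float>(0.0f, 1.0f)(sctx->rng);
ggml_backend_tensor_set(sctx->inp_uniform, &u, 0, sizeof(u));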
@@ -1088,7 +1142,7 @@ static void llama_sampler_dist_backend_apply(
ggml_set_name(probs, "dist_probs");
struct ggml_tensor * cumsum = ggml_cumsum(ctx, probs);
ggml_set_name(cumsum, "cumsum");
ggml_set_name(cumsum, "dist_cumsum");
// The uniform tensor holds a random value which we subtract from the
// cumsum tensor (the uniform scalar is broadcast by ggml_sub).
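
// Worked example (illustrative, not from the commit): with probs =
// [0.5, 0.3, 0.2] the cumsum is [0.5, 0.8, 1.0]; for a uniform draw
// u = 0.65, cumsum - u = [-0.15, 0.15, 0.35] and the first non-negative
// position, index 1, is the sampled token (inverse-CDF sampling).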
@@ -1196,34 +1250,9 @@ static bool llama_sampler_top_k_backend_init(
ggml_backend_buffer_type_t buft) {
auto * sctx = (llama_sampler_top_k *) smpl->ctx;
bool res = true;
const bool res = llama_sampler_backend_support(smpl, ggml_backend_buft_get_device(buft));
// determine backend support
{
ggml_init_params params = {
/*.mem_size =*/ ggml_tensor_overhead()*8,
/*.mem_buffer =*/ NULL,
/*.no_alloc =*/ true,
};
ggml_context_ptr ctx_ptr { ggml_init(params) };
if (!ctx_ptr) {
throw std::runtime_error(format("failed to create ggml context"));
}
ggml_context * ctx = ctx_ptr.get();
ggml_tensor * logits = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1024*1024);
ggml_tensor * op = ggml_top_k(ctx, logits, sctx->k);
auto * device = ggml_backend_buft_get_device(buft);
if (device && !ggml_backend_dev_supports_op(device, op)) {
res = false;
}
sctx->init(res);
}
sctx->init(res);
return res;
}
@@ -1240,6 +1269,7 @@ static void llama_sampler_top_k_backend_apply(
if (data->candidates) {
data->candidates = ggml_get_rows(ctx, data->candidates, top_k);
ggml_set_name(data->candidates, "top_k_candidates");
} else {
data->candidates = top_k;
}
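
// Worked example (illustrative): if an earlier sampler left candidates =
// [7, 3, 9, 1] and top_k selected positions [2, 0], ggml_get_rows maps the
// positions back to vocabulary ids [9, 7]; with no prior candidates the
// top_k positions already are the ids.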
@@ -1363,12 +1393,13 @@ static void llama_sampler_top_p_free(struct llama_sampler * smpl) {
static bool llama_sampler_top_p_backend_init(
struct llama_sampler * smpl,
ggml_backend_buffer_type_t buft) {
GGML_UNUSED(buft);
auto * sctx = (llama_sampler_top_p *) smpl->ctx;
sctx->init(true);
return true;
const bool res = llama_sampler_backend_support(smpl, ggml_backend_buft_get_device(buft));
sctx->init(res);
return res;
}
static void llama_sampler_top_p_backend_apply(
@@ -1418,7 +1449,7 @@ static void llama_sampler_top_p_backend_apply(
// Since the mask is 0/1, taking its sum counts the elements past the
// threshold we are interested in.
struct ggml_tensor * idxf = ggml_sum(ctx, mask);
ggml_set_name(idxf, "dist_index_f32");
ggml_set_name(idxf, "top_p_index_f32");
// prevent out-of-bounds access
idxf = ggml_clamp(ctx, idxf, 0.0f, mask->ne[0] - 1);
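
// Note (illustrative): the mask count is reused directly as an index into
// the sorted candidate list; the clamp keeps that index inside [0, n-1]
// when every element lands on one side of the threshold.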
@@ -1556,13 +1587,13 @@ static void llama_sampler_min_p_free(struct llama_sampler * smpl) {
static bool llama_sampler_min_p_backend_init(
struct llama_sampler * smpl,
ggml_backend_buffer_type_t buft) {
GGML_UNUSED(buft);
auto * sctx = (llama_sampler_min_p *) smpl->ctx;
sctx->init(true);
const bool res = llama_sampler_backend_support(smpl, ggml_backend_buft_get_device(buft));
return true;
sctx->init(res);
return res;
}
static void llama_sampler_min_p_backend_apply(
@@ -1804,13 +1835,13 @@ static void llama_sampler_backend_temp_sampling(
static bool llama_sampler_temp_backend_init(
struct llama_sampler * smpl,
ggml_backend_buffer_type_t buft) {
GGML_UNUSED(buft);
auto * sctx = (llama_sampler_temp *) smpl->ctx;
sctx->init(true);
const bool res = llama_sampler_backend_support(smpl, ggml_backend_buft_get_device(buft));
return true;
sctx->init(res);
return res;
}
static void llama_sampler_temp_backend_apply(
@@ -1947,13 +1978,13 @@ static void llama_sampler_temp_ext_free(struct llama_sampler * smpl) {
static bool llama_sampler_temp_ext_backend_init(
struct llama_sampler * smpl,
ggml_backend_buffer_type_t buft) {
GGML_UNUSED(buft);
auto * sctx = (llama_sampler_temp_ext *) smpl->ctx;
sctx->init(true);
const bool res = llama_sampler_backend_support(smpl, ggml_backend_buft_get_device(buft));
return true;
sctx->init(res);
return res;
}
static void llama_sampler_temp_ext_backend_apply(
@@ -3289,7 +3320,6 @@ static void llama_sampler_logit_bias_backend_apply(
return;
}
//struct ggml_tensor * logit_biased = ggml_add(ctx, data->logits, sctx->inp_logit_bias);
ggml_tensor * cur = ggml_fill(ctx, data->logits, 0.0f);
cur = ggml_reshape_2d(ctx, cur, 1, ggml_nelements(cur));