sampling : generic ggml op support detection
commit 8544aba37f
parent d5d16651a8
@@ -614,6 +614,68 @@ private:
     bool support;
 };
 
+// check if all ggml ops used by the sampler are supported by the backend
+static bool llama_sampler_backend_support(
+        llama_sampler * smpl,
+        ggml_backend_dev_t device) {
+    if (!device) {
+        // CPU backend always supported
+        return true;
+    }
+
+    ggml_init_params params = {
+        /*.mem_size   =*/ 128*ggml_tensor_overhead() + ggml_graph_overhead(),
+        /*.mem_buffer =*/ NULL,
+        /*.no_alloc   =*/ true,
+    };
+
+    ggml_context_ptr ctx_ptr { ggml_init(params) };
+    if (!ctx_ptr) {
+        throw std::runtime_error(format("failed to create ggml context"));
+    }
+
+    ggml_context * ctx = ctx_ptr.get();
+
+    const int64_t n = 1024*1024;
+
+    llama_sampler_data data = {
+        /*.logits     = */ ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n),
+        /*.probs      = */ nullptr,
+        /*.sampled    = */ nullptr,
+        /*.candidates = */ ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n),
+    };
+
+    ggml_cgraph * gf = ggml_new_graph(ctx);
+
+    smpl->iface->backend_apply(smpl, ctx, gf, &data);
+
+    if (data.logits) {
+        ggml_build_forward_expand(gf, data.logits);
+    }
+
+    if (data.probs) {
+        ggml_build_forward_expand(gf, data.probs);
+    }
+
+    if (data.sampled) {
+        ggml_build_forward_expand(gf, data.sampled);
+    }
+
+    if (data.candidates) {
+        ggml_build_forward_expand(gf, data.candidates);
+    }
+
+    for (int i = 0; i < ggml_graph_n_nodes(gf); i++) {
+        struct ggml_tensor * op = ggml_graph_node(gf, i);
+
+        if (!ggml_backend_dev_supports_op(device, op)) {
+            return false;
+        }
+    }
+
+    return true;
+}
+
 // sampler chain
 
 static const char * llama_sampler_chain_name(const struct llama_sampler * /*smpl*/) {
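Note: the helper can probe support without allocating or running anything, because a no_alloc = true context only records tensor metadata, which is all ggml_backend_dev_supports_op() needs. A minimal standalone sketch of the same idea (the function name device_supports_argmax is illustrative, not part of this commit):

#include "ggml.h"
#include "ggml-backend.h"

static bool device_supports_argmax(ggml_backend_dev_t device) {
    ggml_init_params params = {
        /*.mem_size   =*/ 8*ggml_tensor_overhead() + ggml_graph_overhead(),
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ true, // metadata only - no tensor data is allocated
    };

    ggml_context * ctx = ggml_init(params);
    if (!ctx) {
        return false;
    }

    // build a throwaway graph containing the op(s) we want to test
    ggml_tensor * logits = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1024*1024);
    ggml_tensor * out    = ggml_argmax(ctx, logits);

    ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, out);

    // ask the device about every node instead of hard-coding a single op
    bool ok = true;
    for (int i = 0; i < ggml_graph_n_nodes(gf); i++) {
        if (!ggml_backend_dev_supports_op(device, ggml_graph_node(gf, i))) {
            ok = false;
            break;
        }
    }

    ggml_free(ctx);
    return ok;
}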
@@ -850,8 +912,12 @@ int llama_sampler_chain_n(const struct llama_sampler * chain) {
 
 // greedy
 
-static const char * llama_sampler_greedy_name(const struct llama_sampler * /*smpl*/) {
-    return "greedy";
+struct llama_sampler_greedy : public llama_sampler_backend {
+};
+
+static const char * llama_sampler_greedy_name(const struct llama_sampler * smpl) {
+    auto * sctx = (llama_sampler_greedy *) smpl->ctx;
+    return sctx->get_name();
 }
 
 static void llama_sampler_greedy_apply(struct llama_sampler * /*smpl*/, llama_token_data_array * cur_p) {
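llama_sampler_greedy now derives from llama_sampler_backend, whose full definition is outside this diff; only its bool support member is visible in the first hunk. Judging from the sctx->init(res) and sctx->get_name() calls, a plausible shape of the base is roughly (an inferred sketch, not the actual definition):

// Inferred from usage in this diff - the real definition is not shown here.
struct llama_sampler_backend {
    llama_sampler_backend(const char * name) : name(name) {}

    void init(bool res) { support = res; }         // assumed: records the probe result
    const char * get_name() const { return name; } // assumed: backs the *_name() callbacks

    const char * name;
    bool support;
};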
@@ -866,10 +932,13 @@ static void llama_sampler_greedy_apply(struct llama_sampler * /*smpl*/, llama_to
 static bool llama_sampler_greedy_backend_init(
         struct llama_sampler * smpl,
         ggml_backend_buffer_type_t buft) {
-    GGML_UNUSED(smpl);
-    GGML_UNUSED(buft);
+    auto * sctx = (llama_sampler_greedy *) smpl->ctx;
 
-    return true;
+    const bool res = llama_sampler_backend_support(smpl, ggml_backend_buft_get_device(buft));
+
+    sctx->init(res);
+
+    return res;
 }
 
 static void llama_sampler_greedy_backend_apply(
@@ -879,9 +948,11 @@ static void llama_sampler_greedy_backend_apply(
         struct llama_sampler_data * data) {
     GGML_UNUSED(gf);
     GGML_UNUSED(smpl);
-    struct ggml_tensor * argmax_result = ggml_argmax(ctx, data->logits);
-    ggml_set_name(argmax_result, "argmax_result");
-    data->sampled = argmax_result;
+
+    struct ggml_tensor * curl = ggml_argmax(ctx, data->logits);
+    ggml_set_name(curl, "greedy_argmax");
+
+    data->sampled = curl;
 }
 
 static struct llama_sampler_i llama_sampler_greedy_i = {
@@ -900,7 +971,9 @@ static struct llama_sampler_i llama_sampler_greedy_i = {
 struct llama_sampler * llama_sampler_init_greedy() {
     return llama_sampler_init(
        /* .iface = */ &llama_sampler_greedy_i,
-       /* .ctx   = */ nullptr
+       /* .ctx   = */ new llama_sampler_greedy {
+           ("greedy"),
+       }
    );
 }
 
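Callers are unaffected: llama_sampler_init_greedy() keeps its signature, so the usual chain setup from llama.h still applies. A usage sketch (lctx is assumed to be an existing llama_context *):

// Typical usage - unchanged by this commit.
llama_sampler * chain = llama_sampler_chain_init(llama_sampler_chain_default_params());
llama_sampler_chain_add(chain, llama_sampler_init_greedy());

// ... decode, then: llama_token id = llama_sampler_sample(chain, lctx, -1);

llama_sampler_free(chain);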
@@ -1025,36 +1098,8 @@ static bool llama_sampler_dist_backend_init(
         ggml_backend_buffer_type_t buft) {
     auto * sctx = (llama_sampler_dist *) smpl->ctx;
 
-    bool res = true;
-
-    // determine backend support
+    // allocate inputs
     {
-        ggml_init_params params = {
-            /*.mem_size   =*/ ggml_tensor_overhead()*8,
-            /*.mem_buffer =*/ NULL,
-            /*.no_alloc   =*/ true,
-        };
-
-        ggml_context_ptr ctx_ptr { ggml_init(params) };
-        if (!ctx_ptr) {
-            throw std::runtime_error(format("failed to create ggml context"));
-        }
-
-        ggml_context * ctx = ctx_ptr.get();
-
-        ggml_tensor * probs = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1024*1024);
-        ggml_tensor * op = ggml_cumsum(ctx, probs);
-
-        auto * device = ggml_backend_buft_get_device(buft);
-
-        if (device && !ggml_backend_dev_supports_op(device, op)) {
-            res = false;
-        }
-
-        sctx->init(res);
-    }
-
-    if (res) {
         ggml_init_params params = {
             /*.mem_size   =*/ ggml_tensor_overhead(),
             /*.mem_buffer =*/ nullptr,
@@ -1066,13 +1111,22 @@ static bool llama_sampler_dist_backend_init(
         // Create the uniform random scalar input tensor. This will be set by
         // llama_sampler_dist_backend_set_input after this graph is built.
         sctx->inp_uniform = ggml_new_tensor_1d(sctx->inp_ctx.get(), GGML_TYPE_F32, 1);
-        ggml_set_name(sctx->inp_uniform, "uniform");
+        ggml_set_name (sctx->inp_uniform, "uniform");
+        ggml_set_input(sctx->inp_uniform);
 
         // Allocate all tensors from our context to the backend
        sctx->inp_buf.reset(ggml_backend_alloc_ctx_tensors_from_buft(sctx->inp_ctx.get(), buft));
     }
 
+    const bool res = llama_sampler_backend_support(smpl, ggml_backend_buft_get_device(buft));
+
+    sctx->init(res);
+
+    if (!res) {
+        sctx->inp_ctx.reset(nullptr);
+        sctx->inp_buf.reset(nullptr);
+    }
+
     return res;
 }
 
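When the probe fails, the dist sampler now releases the input context and buffer it just allocated. ggml_context_ptr and the buffer pointer are the RAII aliases from ggml-cpp.h, so reset(nullptr) frees the underlying resources immediately; roughly (paraphrased, see ggml-cpp.h for the exact definitions):

// unique_ptr aliases with ggml deleters - reset(nullptr) calls
// ggml_free() / ggml_backend_buffer_free() on the held pointer.
typedef std::unique_ptr<ggml_context, ggml_context_deleter>               ggml_context_ptr;
typedef std::unique_ptr<ggml_backend_buffer, ggml_backend_buffer_deleter> ggml_backend_buffer_ptr;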
@@ -1088,7 +1142,7 @@ static void llama_sampler_dist_backend_apply(
     ggml_set_name(probs, "dist_probs");
 
     struct ggml_tensor * cumsum = ggml_cumsum(ctx, probs);
-    ggml_set_name(cumsum, "cumsum");
+    ggml_set_name(cumsum, "dist_cumsum");
 
     // The uniform tensor has a random value and we subtract this tensor with
     // the cumsum tensor (the uniform tensor will be broadcasted by ggml_sub).
@@ -1196,34 +1250,9 @@ static bool llama_sampler_top_k_backend_init(
         ggml_backend_buffer_type_t buft) {
     auto * sctx = (llama_sampler_top_k *) smpl->ctx;
 
-    bool res = true;
+    const bool res = llama_sampler_backend_support(smpl, ggml_backend_buft_get_device(buft));
 
-    // determine backend support
-    {
-        ggml_init_params params = {
-            /*.mem_size   =*/ ggml_tensor_overhead()*8,
-            /*.mem_buffer =*/ NULL,
-            /*.no_alloc   =*/ true,
-        };
-
-        ggml_context_ptr ctx_ptr { ggml_init(params) };
-        if (!ctx_ptr) {
-            throw std::runtime_error(format("failed to create ggml context"));
-        }
-
-        ggml_context * ctx = ctx_ptr.get();
-
-        ggml_tensor * logits = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1024*1024);
-        ggml_tensor * op = ggml_top_k(ctx, logits, sctx->k);
-
-        auto * device = ggml_backend_buft_get_device(buft);
-
-        if (device && !ggml_backend_dev_supports_op(device, op)) {
-            res = false;
-        }
-
-        sctx->init(res);
-    }
+    sctx->init(res);
 
     return res;
 }
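The deleted block probed one hand-picked op (ggml_top_k); the generic helper instead replays the sampler's own backend_apply into a throwaway graph, so every op the sampler emits is checked and the probe can never drift out of sync with the implementation. For example, a hypothetical sampler emitting two ops is covered automatically (a sketch; the signature follows the smpl->iface->backend_apply call in the first hunk):

// Hypothetical backend_apply - the generic probe walks both ops it emits.
static void my_sampler_backend_apply(
        struct llama_sampler * /*smpl*/,
        ggml_context * ctx,
        ggml_cgraph * /*gf*/,
        struct llama_sampler_data * data) {
    ggml_tensor * probs = ggml_soft_max(ctx, data->logits); // node 1: SOFT_MAX
    data->probs = ggml_cumsum(ctx, probs);                  // node 2: CUMSUM
}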
@@ -1240,6 +1269,7 @@ static void llama_sampler_top_k_backend_apply(
 
     if (data->candidates) {
         data->candidates = ggml_get_rows(ctx, data->candidates, top_k);
+        ggml_set_name(data->candidates, "top_k_candidates");
     } else {
         data->candidates = top_k;
     }
@@ -1363,12 +1393,13 @@ static void llama_sampler_top_p_free(struct llama_sampler * smpl) {
 static bool llama_sampler_top_p_backend_init(
         struct llama_sampler * smpl,
         ggml_backend_buffer_type_t buft) {
-    GGML_UNUSED(buft);
-
     auto * sctx = (llama_sampler_top_p *) smpl->ctx;
-    sctx->init(true);
 
-    return true;
+    const bool res = llama_sampler_backend_support(smpl, ggml_backend_buft_get_device(buft));
+
+    sctx->init(res);
+
+    return res;
 }
 
 static void llama_sampler_top_p_backend_apply(
@@ -1418,7 +1449,7 @@ static void llama_sampler_top_p_backend_apply(
     // Taking the sum of the mask gives us the sum of elements after the threshold
     // we are interested in.
     struct ggml_tensor * idxf = ggml_sum(ctx, mask);
-    ggml_set_name(idxf, "dist_index_f32");
+    ggml_set_name(idxf, "top_p_index_f32");
 
     // prevent out-of-bounds access
     idxf = ggml_clamp(ctx, idxf, 0.0f, mask->ne[0] - 1);
@@ -1556,13 +1587,13 @@ static void llama_sampler_min_p_free(struct llama_sampler * smpl) {
 static bool llama_sampler_min_p_backend_init(
         struct llama_sampler * smpl,
         ggml_backend_buffer_type_t buft) {
-    GGML_UNUSED(buft);
-
     auto * sctx = (llama_sampler_min_p *) smpl->ctx;
 
-    sctx->init(true);
+    const bool res = llama_sampler_backend_support(smpl, ggml_backend_buft_get_device(buft));
 
-    return true;
+    sctx->init(res);
+
+    return res;
 }
 
 static void llama_sampler_min_p_backend_apply(
@@ -1804,13 +1835,13 @@ static void llama_sampler_backend_temp_sampling(
 static bool llama_sampler_temp_backend_init(
         struct llama_sampler * smpl,
         ggml_backend_buffer_type_t buft) {
-    GGML_UNUSED(buft);
-
     auto * sctx = (llama_sampler_temp *) smpl->ctx;
 
-    sctx->init(true);
+    const bool res = llama_sampler_backend_support(smpl, ggml_backend_buft_get_device(buft));
 
-    return true;
+    sctx->init(res);
+
+    return res;
 }
 
 static void llama_sampler_temp_backend_apply(
@@ -1947,13 +1978,13 @@ static void llama_sampler_temp_ext_free(struct llama_sampler * smpl) {
 static bool llama_sampler_temp_ext_backend_init(
         struct llama_sampler * smpl,
         ggml_backend_buffer_type_t buft) {
-    GGML_UNUSED(buft);
-
     auto * sctx = (llama_sampler_temp_ext *) smpl->ctx;
 
-    sctx->init(true);
+    const bool res = llama_sampler_backend_support(smpl, ggml_backend_buft_get_device(buft));
 
-    return true;
+    sctx->init(res);
+
+    return res;
 }
 
 static void llama_sampler_temp_ext_backend_apply(
@@ -3289,7 +3320,6 @@ static void llama_sampler_logit_bias_backend_apply(
         return;
     }
 
-    //struct ggml_tensor * logit_biased = ggml_add(ctx, data->logits, sctx->inp_logit_bias);
     ggml_tensor * cur = ggml_fill(ctx, data->logits, 0.0f);
 
     cur = ggml_reshape_2d(ctx, cur, 1, ggml_nelements(cur));