diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp
index 2a8e129d75..a5a48a4d7e 100644
--- a/src/llama-sampling.cpp
+++ b/src/llama-sampling.cpp
@@ -614,6 +614,68 @@ private:
     bool support;
 };
 
+// check if all ggml ops used by the sampler are supported by the backend
+static bool llama_sampler_backend_support(
+        llama_sampler * smpl,
+        ggml_backend_dev_t device) {
+    if (!device) {
+        // CPU backend always supported
+        return true;
+    }
+
+    ggml_init_params params = {
+        /*.mem_size =*/ 128*ggml_tensor_overhead() + ggml_graph_overhead(),
+        /*.mem_buffer =*/ NULL,
+        /*.no_alloc =*/ true,
+    };
+
+    ggml_context_ptr ctx_ptr { ggml_init(params) };
+    if (!ctx_ptr) {
+        throw std::runtime_error(format("failed to create ggml context"));
+    }
+
+    ggml_context * ctx = ctx_ptr.get();
+
+    const int64_t n = 1024*1024;
+
+    llama_sampler_data data = {
+        /*.logits = */ ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n),
+        /*.probs = */ nullptr,
+        /*.sampled = */ nullptr,
+        /*.candidates = */ ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n),
+    };
+
+    ggml_cgraph * gf = ggml_new_graph(ctx);
+
+    smpl->iface->backend_apply(smpl, ctx, gf, &data);
+
+    if (data.logits) {
+        ggml_build_forward_expand(gf, data.logits);
+    }
+
+    if (data.probs) {
+        ggml_build_forward_expand(gf, data.probs);
+    }
+
+    if (data.sampled) {
+        ggml_build_forward_expand(gf, data.sampled);
+    }
+
+    if (data.candidates) {
+        ggml_build_forward_expand(gf, data.candidates);
+    }
+
+    for (int i = 0; i < ggml_graph_n_nodes(gf); i++) {
+        struct ggml_tensor * op = ggml_graph_node(gf, i);
+
+        if (!ggml_backend_dev_supports_op(device, op)) {
+            return false;
+        }
+    }
+
+    return true;
+}
+
 // sampler chain
 
 static const char * llama_sampler_chain_name(const struct llama_sampler * /*smpl*/) {
@@ -850,8 +912,12 @@ int llama_sampler_chain_n(const struct llama_sampler * chain) {
 
 // greedy
 
-static const char * llama_sampler_greedy_name(const struct llama_sampler * /*smpl*/) {
-    return "greedy";
+struct llama_sampler_greedy : public llama_sampler_backend {
+};
+
+static const char * llama_sampler_greedy_name(const struct llama_sampler * smpl) {
+    auto * sctx = (llama_sampler_greedy *) smpl->ctx;
+    return sctx->get_name();
 }
 
 static void llama_sampler_greedy_apply(struct llama_sampler * /*smpl*/, llama_token_data_array * cur_p) {
@@ -866,10 +932,13 @@ static void llama_sampler_greedy_apply(struct llama_sampler * /*smpl*/, llama_to
 static bool llama_sampler_greedy_backend_init(
         struct llama_sampler * smpl,
         ggml_backend_buffer_type_t buft) {
-    GGML_UNUSED(smpl);
-    GGML_UNUSED(buft);
+    auto * sctx = (llama_sampler_greedy *) smpl->ctx;
 
-    return true;
+    const bool res = llama_sampler_backend_support(smpl, ggml_backend_buft_get_device(buft));
+
+    sctx->init(res);
+
+    return res;
 }
 
 static void llama_sampler_greedy_backend_apply(
@@ -879,9 +948,11 @@ static void llama_sampler_greedy_backend_apply(
         struct llama_sampler_data * data) {
     GGML_UNUSED(gf);
     GGML_UNUSED(smpl);
-    struct ggml_tensor * argmax_result = ggml_argmax(ctx, data->logits);
-    ggml_set_name(argmax_result, "argmax_result");
-    data->sampled = argmax_result;
+
+    struct ggml_tensor * curl = ggml_argmax(ctx, data->logits);
+    ggml_set_name(curl, "greedy_argmax");
+
+    data->sampled = curl;
 }
 
 static struct llama_sampler_i llama_sampler_greedy_i = {
@@ -900,7 +971,9 @@ static struct llama_sampler_i llama_sampler_greedy_i = {
 struct llama_sampler * llama_sampler_init_greedy() {
     return llama_sampler_init(
         /* .iface = */ &llama_sampler_greedy_i,
-        /* .ctx   = */ nullptr
+        /* .ctx   = */ new llama_sampler_greedy {
+            ("greedy"),
+        }
     );
 }
 
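Every sampler touched below follows the same backend_init pattern that greedy now uses: build the sampler's own graph through backend_apply and let llama_sampler_backend_support check each node against the device. As a rough sketch, a hypothetical new sampler (llama_sampler_foo is illustrative, not part of this patch) would only need:

// hypothetical sampler reusing the shared support check (sketch, not part of the patch)
struct llama_sampler_foo : public llama_sampler_backend {
};

static bool llama_sampler_foo_backend_init(
        struct llama_sampler * smpl,
        ggml_backend_buffer_type_t buft) {
    auto * sctx = (llama_sampler_foo *) smpl->ctx;

    // builds a dummy graph with this sampler's backend_apply and checks every
    // node with ggml_backend_dev_supports_op on the device behind buft
    const bool res = llama_sampler_backend_support(smpl, ggml_backend_buft_get_device(buft));

    sctx->init(res);

    return res;
}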
@@ -1025,36 +1098,8 @@ static bool llama_sampler_dist_backend_init(
         ggml_backend_buffer_type_t buft) {
     auto * sctx = (llama_sampler_dist *) smpl->ctx;
 
-    bool res = true;
-
-    // determine backend support
+    // allocate inputs
     {
-        ggml_init_params params = {
-            /*.mem_size =*/ ggml_tensor_overhead()*8,
-            /*.mem_buffer =*/ NULL,
-            /*.no_alloc =*/ true,
-        };
-
-        ggml_context_ptr ctx_ptr { ggml_init(params) };
-        if (!ctx_ptr) {
-            throw std::runtime_error(format("failed to create ggml context"));
-        }
-
-        ggml_context * ctx = ctx_ptr.get();
-
-        ggml_tensor * probs = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1024*1024);
-        ggml_tensor * op = ggml_cumsum(ctx, probs);
-
-        auto * device = ggml_backend_buft_get_device(buft);
-
-        if (device && !ggml_backend_dev_supports_op(device, op)) {
-            res = false;
-        }
-
-        sctx->init(res);
-    }
-
-    if (res) {
         ggml_init_params params = {
             /*.mem_size =*/ ggml_tensor_overhead(),
             /*.mem_buffer =*/ nullptr,
@@ -1066,13 +1111,22 @@ static bool llama_sampler_dist_backend_init(
         // Create the uniform random scalar input tensor. This will be set by
         // llama_sampler_dist_backend_set_input after this graph is built.
         sctx->inp_uniform = ggml_new_tensor_1d(sctx->inp_ctx.get(), GGML_TYPE_F32, 1);
-        ggml_set_name(sctx->inp_uniform, "uniform");
+        ggml_set_name (sctx->inp_uniform, "uniform");
         ggml_set_input(sctx->inp_uniform);
 
         // Allocate all tensors from our context to the backend
         sctx->inp_buf.reset(ggml_backend_alloc_ctx_tensors_from_buft(sctx->inp_ctx.get(), buft));
     }
 
+    const bool res = llama_sampler_backend_support(smpl, ggml_backend_buft_get_device(buft));
+
+    sctx->init(res);
+
+    if (!res) {
+        sctx->inp_ctx.reset(nullptr);
+        sctx->inp_buf.reset(nullptr);
+    }
+
     return res;
 }
 
@@ -1088,7 +1142,7 @@ static void llama_sampler_dist_backend_apply(
     ggml_set_name(probs, "dist_probs");
 
     struct ggml_tensor * cumsum = ggml_cumsum(ctx, probs);
-    ggml_set_name(cumsum, "cumsum");
+    ggml_set_name(cumsum, "dist_cumsum");
 
     // The uniform tensor has a random value and we subtract this tensor with
     // the cumsum tensor (the uniform tensor will be broadcasted by ggml_sub).
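The dist_cumsum rename also makes the graph read like ordinary inverse-CDF sampling: softmax produces probabilities, ggml_cumsum accumulates them, the uniform scalar is subtracted, and the sampled index is the first position where the cumulative sum reaches the draw. A plain C++ sketch of that idea (illustration only, not the ggml graph itself):

#include <vector>

// inverse-CDF sampling over a probability vector (illustrative sketch)
static int sample_dist(const std::vector<float> & probs, float uniform) {
    float cumsum = 0.0f;
    for (size_t i = 0; i < probs.size(); i++) {
        cumsum += probs[i];
        if (cumsum - uniform >= 0.0f) { // first index where the cumsum reaches the draw
            return (int) i;
        }
    }
    return (int) probs.size() - 1; // guard against floating point rounding
}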
@@ -1196,34 +1250,9 @@ static bool llama_sampler_top_k_backend_init(
         ggml_backend_buffer_type_t buft) {
     auto * sctx = (llama_sampler_top_k *) smpl->ctx;
 
-    bool res = true;
+    const bool res = llama_sampler_backend_support(smpl, ggml_backend_buft_get_device(buft));
 
-    // determine backend support
-    {
-        ggml_init_params params = {
-            /*.mem_size =*/ ggml_tensor_overhead()*8,
-            /*.mem_buffer =*/ NULL,
-            /*.no_alloc =*/ true,
-        };
-
-        ggml_context_ptr ctx_ptr { ggml_init(params) };
-        if (!ctx_ptr) {
-            throw std::runtime_error(format("failed to create ggml context"));
-        }
-
-        ggml_context * ctx = ctx_ptr.get();
-
-        ggml_tensor * logits = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1024*1024);
-        ggml_tensor * op = ggml_top_k(ctx, logits, sctx->k);
-
-        auto * device = ggml_backend_buft_get_device(buft);
-
-        if (device && !ggml_backend_dev_supports_op(device, op)) {
-            res = false;
-        }
-
-        sctx->init(res);
-    }
+    sctx->init(res);
 
     return res;
 }
@@ -1240,6 +1269,7 @@ static void llama_sampler_top_k_backend_apply(
 
     if (data->candidates) {
         data->candidates = ggml_get_rows(ctx, data->candidates, top_k);
+        ggml_set_name(data->candidates, "top_k_candidates");
     } else {
         data->candidates = top_k;
     }
@@ -1363,12 +1393,13 @@ static void llama_sampler_top_p_free(struct llama_sampler * smpl) {
 static bool llama_sampler_top_p_backend_init(
         struct llama_sampler * smpl,
         ggml_backend_buffer_type_t buft) {
-    GGML_UNUSED(buft);
-
     auto * sctx = (llama_sampler_top_p *) smpl->ctx;
 
-    sctx->init(true);
-    return true;
+    const bool res = llama_sampler_backend_support(smpl, ggml_backend_buft_get_device(buft));
+
+    sctx->init(res);
+
+    return res;
 }
 
 static void llama_sampler_top_p_backend_apply(
@@ -1418,7 +1449,7 @@ static void llama_sampler_top_p_backend_apply(
     // Taking the sum of the mask gives us the sum of elements after the threshold
     // we are interested in.
     struct ggml_tensor * idxf = ggml_sum(ctx, mask);
-    ggml_set_name(idxf, "dist_index_f32");
+    ggml_set_name(idxf, "top_p_index_f32");
 
     // prevent out-of-bounds access
     idxf = ggml_clamp(ctx, idxf, 0.0f, mask->ne[0] - 1);
@@ -1556,13 +1587,13 @@ static void llama_sampler_min_p_free(struct llama_sampler * smpl) {
 static bool llama_sampler_min_p_backend_init(
         struct llama_sampler * smpl,
         ggml_backend_buffer_type_t buft) {
-    GGML_UNUSED(buft);
-
     auto * sctx = (llama_sampler_min_p *) smpl->ctx;
 
-    sctx->init(true);
+    const bool res = llama_sampler_backend_support(smpl, ggml_backend_buft_get_device(buft));
 
-    return true;
+    sctx->init(res);
+
+    return res;
 }
 
 static void llama_sampler_min_p_backend_apply(
@@ -1804,13 +1835,13 @@ static void llama_sampler_backend_temp_sampling(
 static bool llama_sampler_temp_backend_init(
         struct llama_sampler * smpl,
         ggml_backend_buffer_type_t buft) {
-    GGML_UNUSED(buft);
-
     auto * sctx = (llama_sampler_temp *) smpl->ctx;
 
-    sctx->init(true);
+    const bool res = llama_sampler_backend_support(smpl, ggml_backend_buft_get_device(buft));
 
-    return true;
+    sctx->init(res);
+
+    return res;
 }
 
 static void llama_sampler_temp_backend_apply(
@@ -1947,13 +1978,13 @@ static void llama_sampler_temp_ext_free(struct llama_sampler * smpl) {
 static bool llama_sampler_temp_ext_backend_init(
         struct llama_sampler * smpl,
         ggml_backend_buffer_type_t buft) {
-    GGML_UNUSED(buft);
-
     auto * sctx = (llama_sampler_temp_ext *) smpl->ctx;
 
-    sctx->init(true);
+    const bool res = llama_sampler_backend_support(smpl, ggml_backend_buft_get_device(buft));
 
-    return true;
+    sctx->init(res);
+
+    return res;
 }
 
 static void llama_sampler_temp_ext_backend_apply(
@@ -3289,7 +3320,6 @@ static void llama_sampler_logit_bias_backend_apply(
         return;
     }
 
-    //struct ggml_tensor * logit_biased = ggml_add(ctx, data->logits, sctx->inp_logit_bias);
     ggml_tensor * cur = ggml_fill(ctx, data->logits, 0.0f);
     cur = ggml_reshape_2d(ctx, cur, 1, ggml_nelements(cur));
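With the per-sampler probes gone, checking whether a device can run a sampler entirely on the backend reduces to one call into the new helper. A sketch of how that might be exercised from this translation unit (device enumeration uses the standard ggml backend registry; the loop and printf are illustrative, not part of the patch):

// probe every registered device for backend-side greedy sampling (sketch, not part of the patch)
static void probe_greedy_backend_support() {
    llama_sampler * smpl = llama_sampler_init_greedy();

    for (size_t i = 0; i < ggml_backend_dev_count(); i++) {
        ggml_backend_dev_t dev = ggml_backend_dev_get(i);

        // builds the greedy graph and checks every node against this device
        const bool ok = llama_sampler_backend_support(smpl, dev);

        printf("%s: backend sampling %s\n", ggml_backend_dev_name(dev), ok ? "supported" : "not supported");
    }

    llama_sampler_free(smpl);
}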