diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp
index 86f82d1691..9eee48f753 100644
--- a/src/llama-sampling.cpp
+++ b/src/llama-sampling.cpp
@@ -1226,16 +1226,16 @@ static void llama_sampler_top_k_backend_apply(
     struct ggml_tensor * top_k = ggml_top_k(ctx, data->logits, sctx->k);
     ggml_set_name(top_k, "top_k");
 
-    struct ggml_tensor * logits_rows = ggml_reshape_2d(ctx, data->logits, 1, data->logits->ne[0]);
-    struct ggml_tensor * top_k_rows = ggml_get_rows(ctx, logits_rows, top_k);
-    ggml_set_name(top_k_rows, "top_k_rows");
-
     if (data->candidates) {
         data->candidates = ggml_get_rows(ctx, data->candidates, top_k);
     } else {
         data->candidates = top_k;
     }
 
+    struct ggml_tensor * logits_rows = ggml_reshape_2d(ctx, data->logits, 1, data->logits->ne[0]);
+    struct ggml_tensor * top_k_rows = ggml_get_rows(ctx, logits_rows, top_k);
+    ggml_set_name(top_k_rows, "top_k_rows");
+
     data->logits = ggml_reshape_1d(ctx, top_k_rows, sctx->k);
 
     GGML_UNUSED(gf);
diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp
index 0e4224e115..8b529781c6 100644
--- a/tools/server/server-context.cpp
+++ b/tools/server/server-context.cpp
@@ -1056,11 +1056,20 @@ struct server_context_impl {
             return false;
         }
 
-        // TODO: getting post/pre sampling logits is not yet supported with backend sampling
        const bool need_logits = task.params.sampling.n_probs > 0;
 
+        bool backend_sampling = true;
+
+        backend_sampling &= task.params.sampling.backend_sampling;
+
+        // TODO: speculative decoding requires multiple samples per batch - not supported yet
+        backend_sampling &= !(slot.ctx_dft && task.params.speculative.n_max > 0);
+
+        // TODO: getting post/pre sampling logits is not yet supported with backend sampling
+        backend_sampling &= !need_logits;
+
         // TODO: tmp until backend sampling is fully implemented
-        if (task.params.sampling.backend_sampling && !need_logits) {
+        if (backend_sampling) {
             llama_set_sampler(ctx, slot.id, common_sampler_get(slot.smpl.get()));
         } else {
             llama_set_sampler(ctx, slot.id, nullptr);
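
For context, the first hunk only reorders graph construction in the backend top-k sampler; the gather pattern itself is unchanged: ggml_top_k produces the indices of the k largest logits, the 1-D logits tensor is viewed as one element per row so ggml_get_rows can pick out the selected values, and the result is flattened back to k entries. A minimal sketch of that pattern, under the assumption of a standalone ggml context, is below; the helper name build_top_k and the top_k_result struct are illustrative only, while the ggml calls mirror the ones in the diff.

// Sketch of the gather pattern used by llama_sampler_top_k_backend_apply.
// The wrapper name and result struct are hypothetical; only the ggml ops mirror the diff.
#include "ggml.h"

struct top_k_result {
    struct ggml_tensor * indices; // [k] indices of the k largest logits
    struct ggml_tensor * values;  // [k] the corresponding logit values
};

static struct top_k_result build_top_k(struct ggml_context * ctx, struct ggml_tensor * logits, int k) {
    // indices of the k largest entries of the 1-D logits tensor [n_vocab]
    struct ggml_tensor * top_k = ggml_top_k(ctx, logits, k);

    // view the logits as [1, n_vocab] so each logit is its own row,
    // gather the rows selected by top_k, then flatten back to [k]
    struct ggml_tensor * logits_rows = ggml_reshape_2d(ctx, logits, 1, logits->ne[0]);
    struct ggml_tensor * top_k_rows  = ggml_get_rows(ctx, logits_rows, top_k);

    struct top_k_result res = {
        /*indices =*/ top_k,
        /*values  =*/ ggml_reshape_1d(ctx, top_k_rows, k),
    };
    return res;
}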