server : handle unsupported cases

This commit is contained in:
Georgi Gerganov 2025-12-09 10:55:11 +02:00
parent f3beb22b17
commit 560ac16f7d
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
2 changed files with 15 additions and 6 deletions

View File

@ -1226,16 +1226,16 @@ static void llama_sampler_top_k_backend_apply(
struct ggml_tensor * top_k = ggml_top_k(ctx, data->logits, sctx->k); struct ggml_tensor * top_k = ggml_top_k(ctx, data->logits, sctx->k);
ggml_set_name(top_k, "top_k"); ggml_set_name(top_k, "top_k");
struct ggml_tensor * logits_rows = ggml_reshape_2d(ctx, data->logits, 1, data->logits->ne[0]);
struct ggml_tensor * top_k_rows = ggml_get_rows(ctx, logits_rows, top_k);
ggml_set_name(top_k_rows, "top_k_rows");
if (data->candidates) { if (data->candidates) {
data->candidates = ggml_get_rows(ctx, data->candidates, top_k); data->candidates = ggml_get_rows(ctx, data->candidates, top_k);
} else { } else {
data->candidates = top_k; data->candidates = top_k;
} }
struct ggml_tensor * logits_rows = ggml_reshape_2d(ctx, data->logits, 1, data->logits->ne[0]);
struct ggml_tensor * top_k_rows = ggml_get_rows(ctx, logits_rows, top_k);
ggml_set_name(top_k_rows, "top_k_rows");
data->logits = ggml_reshape_1d(ctx, top_k_rows, sctx->k); data->logits = ggml_reshape_1d(ctx, top_k_rows, sctx->k);
GGML_UNUSED(gf); GGML_UNUSED(gf);

View File

@ -1056,11 +1056,20 @@ struct server_context_impl {
return false; return false;
} }
// TODO: getting post/pre sampling logits is not yet supported with backend sampling
const bool need_logits = task.params.sampling.n_probs > 0; const bool need_logits = task.params.sampling.n_probs > 0;
bool backend_sampling = true;
backend_sampling &= task.params.sampling.backend_sampling;
// TODO: speculative decoding requires multiple samples per batch - not supported yet
backend_sampling &= !(slot.ctx_dft && task.params.speculative.n_max > 0);
// TODO: getting post/pre sampling logits is not yet supported with backend sampling
backend_sampling &= !need_logits;
// TODO: tmp until backend sampling is fully implemented // TODO: tmp until backend sampling is fully implemented
if (task.params.sampling.backend_sampling && !need_logits) { if (backend_sampling) {
llama_set_sampler(ctx, slot.id, common_sampler_get(slot.smpl.get())); llama_set_sampler(ctx, slot.id, common_sampler_get(slot.smpl.get()));
} else { } else {
llama_set_sampler(ctx, slot.id, nullptr); llama_set_sampler(ctx, slot.id, nullptr);