server : add backend sampling options/configuration

commit f1f3e68511
parent 9fe9a00a8a
Author: Daniel Bevenius
Date:   2025-11-17 15:31:30 +01:00
1 changed file with 54 additions and 0 deletions


@@ -197,6 +197,8 @@ struct slot_params {
{"speculative.p_min", speculative.p_min},
{"timings_per_token", timings_per_token},
{"post_sampling_probs", post_sampling_probs},
{"backend_sampling", sampling.backend_sampling},
{"backend_dist", sampling.backend_dist},
{"lora", lora},
};
}
@@ -255,6 +257,8 @@ struct slot_params {
{"speculative.p_min", speculative.p_min},
{"timings_per_token", timings_per_token},
{"post_sampling_probs", post_sampling_probs},
{"backend_sampling", sampling.backend_sampling},
{"backend_dist", sampling.backend_dist},
{"lora", lora},
};
}
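
Both serialization hunks above export the effective flags, so clients can see in the returned generation settings whether backend sampling was actually applied to their request. A minimal sketch of the relevant slice, in the nlohmann-style json used above (neighboring fields elided, values assumed):

    json settings = {
        {"post_sampling_probs", false},
        {"backend_sampling",    true},   // server default allowed it and the request asked for it
        {"backend_dist",        false},  // final distribution sampling stays on the CPU sampler
    };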
@@ -357,6 +361,11 @@ struct server_task {
params.sampling.min_keep = json_value(data, "min_keep", defaults.sampling.min_keep);
params.post_sampling_probs = json_value(data, "post_sampling_probs", defaults.post_sampling_probs);
const bool request_backend_sampling = json_value(data, "backend_sampling", defaults.sampling.backend_sampling);
const bool request_backend_dist = json_value(data, "backend_dist", defaults.sampling.backend_dist);
params.sampling.backend_sampling = defaults.sampling.backend_sampling && request_backend_sampling;
params.sampling.backend_dist = params.sampling.backend_sampling && request_backend_dist;
params.speculative.n_min = json_value(data, "speculative.n_min", defaults.speculative.n_min);
params.speculative.n_max = json_value(data, "speculative.n_max", defaults.speculative.n_max);
params.speculative.p_min = json_value(data, "speculative.p_min", defaults.speculative.p_min);
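
The request flags can only narrow the server-side defaults, never widen them: backend_sampling takes effect only when the server default also enables it, and backend_dist is additionally gated on the effective backend_sampling. A hedged sketch of a request body exercising both options (the prompt field and values are assumed):

    // hypothetical client request body enabling both options
    json data = {
        {"prompt",           "Hello"},
        {"backend_sampling", true},   // ignored unless the server default enables it too
        {"backend_dist",     true},   // honored only when backend_sampling is effective
    };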
@@ -1702,6 +1711,7 @@ struct server_slot {
json json_schema;
struct common_sampler * smpl = nullptr;
llama_sampler * backend_sampler = nullptr;
llama_token sampled;
@@ -1747,6 +1757,13 @@ struct server_slot {
n_draft_total = 0;
n_draft_accepted = 0;
if (backend_sampler != nullptr) {
if (ctx != nullptr) {
llama_set_backend_sampler(ctx, id, nullptr);
}
backend_sampler = nullptr;
}
task.reset();
task_prev.reset();
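
Note the ordering here, and in the matching teardown hunk below: the chain is detached from the context before the slot's pointer is cleared, so the context is never left referencing a chain the slot has already forgotten. A condensed sketch of the detach idiom, assuming the context takes care of a previously attached chain once it is replaced:

    if (backend_sampler != nullptr) {
        if (ctx != nullptr) {
            llama_set_backend_sampler(ctx, id, nullptr);  // detach from the context first
        }
        backend_sampler = nullptr;                        // then drop the slot's handle
    }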
@@ -2368,6 +2385,13 @@ struct server_context {
common_sampler_free(slot.smpl);
slot.smpl = nullptr;
if (slot.backend_sampler != nullptr) {
if (ctx != nullptr) {
llama_set_backend_sampler(ctx, slot.id, nullptr);
}
slot.backend_sampler = nullptr;
}
llama_free(slot.ctx_dft);
slot.ctx_dft = nullptr;
@@ -2840,6 +2864,11 @@ struct server_context {
SLT_INF(slot, "sampler chain: %s\n", common_sampler_print(slot.smpl).c_str());
}
if (!configure_slot_backend_sampler(slot, task.params.sampling)) {
send_error(task, "Failed to configure backend samplers", ERROR_TYPE_SERVER);
return false;
}
// initialize draft batch
// TODO: rework speculative decoding [TAG_SERVER_SPEC_REWORK]
if (slot.ctx_dft) {
@@ -2857,6 +2886,31 @@ struct server_context {
return true;
}
bool configure_slot_backend_sampler(server_slot & slot, const common_params_sampling & sampling) {
if (!sampling.backend_sampling) {
if (slot.backend_sampler != nullptr) {
llama_set_backend_sampler(ctx, slot.id, nullptr);
slot.backend_sampler = nullptr;
}
return true;
}
llama_sampler * backend_chain = common_sampler_backend_init(model, sampling);
if (backend_chain == nullptr) {
SLT_ERR(slot, "%s", "failed to initialize backend sampler\n");
return false;
}
if (slot.backend_sampler != nullptr) {
llama_set_backend_sampler(ctx, slot.id, nullptr);
slot.backend_sampler = nullptr;
}
slot.backend_sampler = backend_chain;
llama_set_backend_sampler(ctx, slot.id, backend_chain);
return true;
}
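
Because configure_slot_backend_sampler always detaches any previously attached chain before installing a new one, and detaches without replacement when backend sampling is off, it can be called once per task to toggle the feature per request. A hypothetical usage sketch, assuming defaults is a common_params_sampling with backend sampling enabled:

    common_params_sampling s_on  = defaults;      // backend_sampling == true
    common_params_sampling s_off = defaults;
    s_off.backend_sampling = false;

    configure_slot_backend_sampler(slot, s_on);   // attaches a freshly built chain
    configure_slot_backend_sampler(slot, s_on);   // replaces it: detach old, attach new
    configure_slot_backend_sampler(slot, s_off);  // detaches; slot.backend_sampler == nullptr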
bool process_token(completion_token_output & result, server_slot & slot) {
// remember which tokens were sampled - used for repetition penalties during sampling
const std::string token_str = result.text_to_send;