server : add backend sampling options/configuration
parent 9fe9a00a8a
commit f1f3e68511
@@ -197,6 +197,8 @@ struct slot_params {
             {"speculative.p_min",   speculative.p_min},
             {"timings_per_token",   timings_per_token},
             {"post_sampling_probs", post_sampling_probs},
+            {"backend_sampling",    sampling.backend_sampling},
+            {"backend_dist",        sampling.backend_dist},
             {"lora",                lora},
         };
     }
@@ -255,6 +257,8 @@ struct slot_params {
             {"speculative.p_min",   speculative.p_min},
             {"timings_per_token",   timings_per_token},
             {"post_sampling_probs", post_sampling_probs},
+            {"backend_sampling",    sampling.backend_sampling},
+            {"backend_dist",        sampling.backend_dist},
             {"lora",                lora},
         };
     }
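Both to_json overloads above now report the resolved flags alongside the other sampling settings, so they appear wherever slot parameters are serialized (e.g. in a response's generation settings). A fragment of that output might look like this (values illustrative):

    "speculative.p_min":   0.75,
    "timings_per_token":   false,
    "post_sampling_probs": false,
    "backend_sampling":    true,
    "backend_dist":        true,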
@@ -357,6 +361,11 @@ struct server_task {
         params.sampling.min_keep   = json_value(data, "min_keep",            defaults.sampling.min_keep);
         params.post_sampling_probs = json_value(data, "post_sampling_probs", defaults.post_sampling_probs);
 
+        const bool request_backend_sampling = json_value(data, "backend_sampling", defaults.sampling.backend_sampling);
+        const bool request_backend_dist     = json_value(data, "backend_dist",     defaults.sampling.backend_dist);
+        params.sampling.backend_sampling = defaults.sampling.backend_sampling && request_backend_sampling;
+        params.sampling.backend_dist     = params.sampling.backend_sampling   && request_backend_dist;
+
         params.speculative.n_min = json_value(data, "speculative.n_min", defaults.speculative.n_min);
         params.speculative.n_max = json_value(data, "speculative.n_max", defaults.speculative.n_max);
         params.speculative.p_min = json_value(data, "speculative.p_min", defaults.speculative.p_min);
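The two request flags are deliberately gated: a request can switch backend sampling off, but it can only switch it on when the server-side default already allows it, and backend_dist is only honored while backend sampling is active. A standalone sketch of that resolution logic (plain bools, no server types; all names hypothetical):

    // resolve.cpp - mirrors the two '&&' lines above with plain bools
    #include <cassert>

    struct flags { bool backend_sampling; bool backend_dist; };

    static flags resolve(const flags & defaults, bool req_sampling, bool req_dist) {
        flags out;
        out.backend_sampling = defaults.backend_sampling && req_sampling;
        out.backend_dist     = out.backend_sampling      && req_dist;
        return out;
    }

    int main() {
        assert(!resolve({false, true}, true,  true).backend_sampling); // server default wins
        assert(!resolve({true,  true}, false, true).backend_dist);     // dist needs sampling
        assert( resolve({true,  true}, true,  true).backend_dist);     // both requested, both allowed
        return 0;
    }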
@@ -1702,6 +1711,7 @@ struct server_slot {
     json json_schema;
 
     struct common_sampler * smpl = nullptr;
+    llama_sampler * backend_sampler = nullptr;
 
     llama_token sampled;
 
@@ -1747,6 +1757,13 @@ struct server_slot {
         n_draft_total    = 0;
         n_draft_accepted = 0;
 
+        if (backend_sampler != nullptr) {
+            if (ctx != nullptr) {
+                llama_set_backend_sampler(ctx, id, nullptr);
+            }
+            backend_sampler = nullptr;
+        }
+
         task.reset();
         task_prev.reset();
 
@@ -2368,6 +2385,13 @@ struct server_context {
             common_sampler_free(slot.smpl);
             slot.smpl = nullptr;
 
+            if (slot.backend_sampler != nullptr) {
+                if (ctx != nullptr) {
+                    llama_set_backend_sampler(ctx, slot.id, nullptr);
+                }
+                slot.backend_sampler = nullptr;
+            }
+
             llama_free(slot.ctx_dft);
             slot.ctx_dft = nullptr;
 
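Both cleanup paths (the slot reset above and this per-slot teardown) follow the same ordering: detach the sampler from the context first, then clear the slot's own pointer, so the context never keeps a registration for a sampler the slot has already forgotten. A minimal standalone sketch of that invariant (all types hypothetical, not the llama.cpp API):

    // detach-then-clear: the ordering the two cleanup sites above rely on
    struct sampler { };
    struct context { sampler * per_slot[8] = {}; };

    static void clear_slot_sampler(context * ctx, int slot_id, sampler *& slot_sampler) {
        if (slot_sampler != nullptr) {
            if (ctx != nullptr) {
                ctx->per_slot[slot_id] = nullptr; // detach from the context first
            }
            slot_sampler = nullptr;               // then drop the slot's reference
        }
    }

    int main() {
        context ctx;
        sampler s;
        sampler * slot_sampler = &s;
        ctx.per_slot[0] = slot_sampler;           // slot 0 registered
        clear_slot_sampler(&ctx, 0, slot_sampler);
        // both the context entry and the slot pointer are now null
        return 0;
    }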
@@ -2840,6 +2864,11 @@ struct server_context {
             SLT_INF(slot, "sampler chain: %s\n", common_sampler_print(slot.smpl).c_str());
         }
 
+        if (!configure_slot_backend_sampler(slot, task.params.sampling)) {
+            send_error(task, "Failed to configure backend samplers", ERROR_TYPE_SERVER);
+            return false;
+        }
+
         // initialize draft batch
         // TODO: rework speculative decoding [TAG_SERVER_SPEC_REWORK]
         if (slot.ctx_dft) {
@@ -2857,6 +2886,31 @@ struct server_context {
         return true;
     }
 
+    bool configure_slot_backend_sampler(server_slot & slot, const common_params_sampling & sampling) {
+        if (!sampling.backend_sampling) {
+            if (slot.backend_sampler != nullptr) {
+                llama_set_backend_sampler(ctx, slot.id, nullptr);
+                slot.backend_sampler = nullptr;
+            }
+            return true;
+        }
+
+        llama_sampler * backend_chain = common_sampler_backend_init(model, sampling);
+        if (backend_chain == nullptr) {
+            SLT_ERR(slot, "%s", "failed to initialize backend sampler\n");
+            return false;
+        }
+
+        if (slot.backend_sampler != nullptr) {
+            llama_set_backend_sampler(ctx, slot.id, nullptr);
+            slot.backend_sampler = nullptr;
+        }
+
+        slot.backend_sampler = backend_chain;
+        llama_set_backend_sampler(ctx, slot.id, backend_chain);
+        return true;
+    }
+
     bool process_token(completion_token_output & result, server_slot & slot) {
         // remember which tokens were sampled - used for repetition penalties during sampling
         const std::string token_str = result.text_to_send;
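With configure_slot_backend_sampler in place, each slot either carries a freshly built backend sampler chain or has its registration cleared before the task runs, and a client can toggle the behavior per request. A hypothetical completion request body (only backend_sampling and backend_dist come from the parsing code above; the other fields are illustrative):

    {
        "prompt": "Once upon a time",
        "n_predict": 32,
        "backend_sampling": true,
        "backend_dist": false
    }

Note that a request can only enable these flags when the server-side defaults already permit them, and asking for backend_dist without backend_sampling has no effect.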