diff --git a/tools/server/server.cpp b/tools/server/server.cpp
index 0fc3cf9195..21d2f0f417 100644
--- a/tools/server/server.cpp
+++ b/tools/server/server.cpp
@@ -197,6 +197,8 @@ struct slot_params {
             {"speculative.p_min",   speculative.p_min},
             {"timings_per_token",   timings_per_token},
             {"post_sampling_probs", post_sampling_probs},
+            {"backend_sampling",    sampling.backend_sampling},
+            {"backend_dist",        sampling.backend_dist},
             {"lora",                lora},
         };
     }
@@ -255,6 +257,8 @@ struct slot_params {
             {"speculative.p_min",   speculative.p_min},
             {"timings_per_token",   timings_per_token},
             {"post_sampling_probs", post_sampling_probs},
+            {"backend_sampling",    sampling.backend_sampling},
+            {"backend_dist",        sampling.backend_dist},
             {"lora",                lora},
         };
     }
@@ -357,6 +361,11 @@ struct server_task {
         params.sampling.min_keep   = json_value(data, "min_keep",            defaults.sampling.min_keep);
         params.post_sampling_probs = json_value(data, "post_sampling_probs", defaults.post_sampling_probs);
 
+        const bool request_backend_sampling = json_value(data, "backend_sampling", defaults.sampling.backend_sampling);
+        const bool request_backend_dist     = json_value(data, "backend_dist",     defaults.sampling.backend_dist);
+        params.sampling.backend_sampling    = defaults.sampling.backend_sampling && request_backend_sampling;
+        params.sampling.backend_dist        = params.sampling.backend_sampling && request_backend_dist;
+
         params.speculative.n_min = json_value(data, "speculative.n_min", defaults.speculative.n_min);
         params.speculative.n_max = json_value(data, "speculative.n_max", defaults.speculative.n_max);
         params.speculative.p_min = json_value(data, "speculative.p_min", defaults.speculative.p_min);
@@ -1702,6 +1711,7 @@ struct server_slot {
     json json_schema;
 
     struct common_sampler * smpl = nullptr;
+    llama_sampler * backend_sampler = nullptr;
 
     llama_token sampled;
 
@@ -1747,6 +1757,13 @@
         n_draft_total    = 0;
         n_draft_accepted = 0;
 
+        if (backend_sampler != nullptr) {
+            if (ctx != nullptr) {
+                llama_set_backend_sampler(ctx, id, nullptr);
+            }
+            backend_sampler = nullptr;
+        }
+
         task.reset();
         task_prev.reset();
 
@@ -2368,6 +2385,13 @@ struct server_context {
             common_sampler_free(slot.smpl);
            slot.smpl = nullptr;
 
+            if (slot.backend_sampler != nullptr) {
+                if (ctx != nullptr) {
+                    llama_set_backend_sampler(ctx, slot.id, nullptr);
+                }
+                slot.backend_sampler = nullptr;
+            }
+
             llama_free(slot.ctx_dft);
             slot.ctx_dft = nullptr;
 
@@ -2840,6 +2864,11 @@ struct server_context {
             SLT_INF(slot, "sampler chain: %s\n", common_sampler_print(slot.smpl).c_str());
         }
 
+        if (!configure_slot_backend_sampler(slot, task.params.sampling)) {
+            send_error(task, "Failed to configure backend samplers", ERROR_TYPE_SERVER);
+            return false;
+        }
+
         // initialize draft batch
         // TODO: rework speculative decoding [TAG_SERVER_SPEC_REWORK]
         if (slot.ctx_dft) {
@@ -2857,6 +2886,31 @@
         return true;
     }
 
+    bool configure_slot_backend_sampler(server_slot & slot, const common_params_sampling & sampling) {
+        if (!sampling.backend_sampling) {
+            if (slot.backend_sampler != nullptr) {
+                llama_set_backend_sampler(ctx, slot.id, nullptr);
+                slot.backend_sampler = nullptr;
+            }
+            return true;
+        }
+
+        llama_sampler * backend_chain = common_sampler_backend_init(model, sampling);
+        if (backend_chain == nullptr) {
+            SLT_ERR(slot, "%s", "failed to initialize backend sampler\n");
+            return false;
+        }
+
+        if (slot.backend_sampler != nullptr) {
+            llama_set_backend_sampler(ctx, slot.id, nullptr);
+            slot.backend_sampler = nullptr;
+        }
+
+        slot.backend_sampler = backend_chain;
+        llama_set_backend_sampler(ctx, slot.id, backend_chain);
+        return true;
+    }
+
     bool process_token(completion_token_output & result, server_slot & slot) {
         // remember which tokens were sampled - used for repetition penalties during sampling
         const std::string token_str = result.text_to_send;
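
Usage note (not part of the patch): with this change, backend sampling can be toggled per request through the new "backend_sampling" and "backend_dist" JSON fields, which the task parser ANDs with the server-side defaults, so a request can only opt in when the server itself has backend sampling enabled. A minimal illustrative request against the server's /completion endpoint is sketched below; the host, port, prompt, and n_predict values are placeholders, not taken from the patch.

    # hypothetical request; backend_dist is ignored unless backend_sampling is also active
    curl http://localhost:8080/completion -d '{
        "prompt": "Hello, world",
        "n_predict": 16,
        "backend_sampling": true,
        "backend_dist": true
    }'

The effective values are echoed back in slot_params::to_json(), so the response's generation settings can be used to confirm whether the backend sampler was actually engaged for the request.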