diff --git a/common/arg.cpp b/common/arg.cpp
index c7d4b22c9b..e29c3619b6 100644
--- a/common/arg.cpp
+++ b/common/arg.cpp
@@ -1520,14 +1520,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.sampling.backend_sampling = true;
         }
     ).set_sparam());
-    add_opt(common_arg(
-        {"--backend-dist"},
-        "perform final (distribution) sampling on backend (default: disabled)",
-        [](common_params & params) {
-            params.sampling.backend_dist = true;
-            params.sampling.backend_sampling = true;
-        }
-    ).set_sparam());
     add_opt(common_arg(
         {"--pooling"}, "{none,mean,cls,last,rank}",
         "pooling type for embeddings, use model default if unspecified",
diff --git a/common/common.cpp b/common/common.cpp
index ff07ba4f23..bf81370730 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -1021,12 +1021,17 @@ struct common_init_result common_init_from_params(common_params & params) {
 
     // backend sampling initialization
     if (params.sampling.backend_sampling) {
-        iparams.samplers_seq_config.resize(cparams.n_seq_max);
-        for (int i = 0; i < (int) cparams.n_seq_max; ++i) {
-            iparams.samplers_seq_config[i] = { i, common_sampler_backend_init(model, params.sampling) };
+        llama_sampler * backend_chain = common_sampler_backend_init(model, params.sampling);
+        if (backend_chain != nullptr) {
+            iparams.samplers_seq_config.resize(cparams.n_seq_max);
+            for (int i = 0; i < (int) cparams.n_seq_max; ++i) {
+                iparams.samplers_seq_config[i] = { i, llama_sampler_clone(backend_chain) };
+            }
+            cparams.samplers = iparams.samplers_seq_config.data();
+            cparams.n_samplers = cparams.n_seq_max;
+
+            llama_sampler_free(backend_chain);
         }
-        cparams.samplers = iparams.samplers_seq_config.data();
-        cparams.n_samplers = cparams.n_seq_max;
     }
 
     llama_context * lctx = llama_init_from_model(model, cparams);
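The `common_init_from_params` change above builds a single prototype backend chain, clones it once per sequence, and then frees the prototype, instead of calling `common_sampler_backend_init` once per sequence. A minimal sketch of that ownership pattern, using only the public `llama_sampler_clone`/`llama_sampler_free` API (the helper name is hypothetical, for illustration):

```cpp
#include "llama.h"

#include <vector>

// Clone one prototype sampler chain per sequence, then release the prototype.
// Each clone is an independent chain; whoever consumes the vector owns the
// clones and must free them.
static std::vector<llama_sampler *> clone_per_sequence(llama_sampler * prototype, int n_seq) {
    std::vector<llama_sampler *> chains;
    chains.reserve(n_seq);
    for (int i = 0; i < n_seq; ++i) {
        chains.push_back(llama_sampler_clone(prototype));
    }
    llama_sampler_free(prototype); // the clones do not reference the prototype
    return chains;
}
```

Cloning keeps per-sequence sampler state (e.g. any RNG state) independent while guaranteeing that every sequence starts from an identical configuration.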
diff --git a/common/common.h b/common/common.h
index e48fe336ea..8e78ab32ea 100644
--- a/common/common.h
+++ b/common/common.h
@@ -213,9 +213,7 @@ struct common_params_sampling {
     std::vector<llama_logit_bias> logit_bias;     // logit biases to apply
     std::vector<llama_logit_bias> logit_bias_eog; // pre-calculated logit biases for EOG tokens
 
-    // Backend sampling flags
     bool backend_sampling = false; // enable backend sampling
-    bool backend_dist     = false; // backend performs final sampling (dist)
 
     // print the parameters into a string
     std::string print() const;
diff --git a/common/sampling.cpp b/common/sampling.cpp
index 0a2be6bf7d..c116771981 100644
--- a/common/sampling.cpp
+++ b/common/sampling.cpp
@@ -105,7 +105,8 @@ struct common_sampler {
     common_params_sampling params;
 
     struct llama_sampler * grmr;
-    struct llama_sampler * chain;
+    struct llama_sampler * chain;         // CPU sampling chain
+    struct llama_sampler * backend_chain; // Backend sampling chain
 
     ring_buffer<llama_token> prev;
 
@@ -118,6 +119,9 @@ struct common_sampler {
         llama_sampler_reset(grmr);
         llama_sampler_reset(chain);
+        if (backend_chain) {
+            llama_sampler_reset(backend_chain);
+        }
     }
 
     void set_logits(struct llama_context * ctx, int idx) {
@@ -165,6 +169,20 @@ static bool sampler_enabled(const struct common_params_sampling & params, enum c
     return std::find(params.samplers.begin(), params.samplers.end(), type) != params.samplers.end();
 }
 
+static bool sampler_backend_supported(enum common_sampler_type type) {
+    switch (type) {
+        case COMMON_SAMPLER_TYPE_TOP_K:
+        case COMMON_SAMPLER_TYPE_TEMPERATURE:
+            return true;
+        default:
+            return false;
+    }
+}
+
+static bool has_logit_bias(const struct common_params_sampling & params) {
+    return !params.logit_bias.empty();
+}
+
 std::string common_params_sampling::print() const {
     char result[1024];
@@ -249,22 +267,86 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
     }
 
     auto * result = new common_sampler {
-        /* .params = */ params,
-        /* .grmr   = */ grmr,
-        /* .chain  = */ llama_sampler_chain_init(lparams),
-        /* .prev   = */ ring_buffer<llama_token>(std::max(32, params.n_prev)),
-        /* .cur    = */ {},
-        /* .cur_p  = */ {},
+        /* .params        = */ params,
+        /* .grmr          = */ grmr,
+        /* .chain         = */ llama_sampler_chain_init(lparams),
+        /* .backend_chain = */ nullptr,
+        /* .prev          = */ ring_buffer<llama_token>(std::max(32, params.n_prev)),
+        /* .cur           = */ {},
+        /* .cur_p         = */ {},
     };
 
-    llama_sampler_chain_add(result->chain,
-            llama_sampler_init_logit_bias(
-                llama_vocab_n_tokens(vocab),
-                params.logit_bias.size(),
-                params.logit_bias.data()));
+    size_t backend_sampler_count = 0;
+    if (params.backend_sampling && params.mirostat == 0) {
+        if (has_logit_bias(params)) {
+            backend_sampler_count++;
+        }
+
+        // Find the longest contiguous chain of backend-supported samplers from the start
+        for (const auto & sampler_type : params.samplers) {
+            if (sampler_backend_supported(sampler_type)) {
+                backend_sampler_count++;
+            } else {
+                break;
+            }
+        }
+    }
+
+    // If the sampler combination is supported then we can build the backend chain.
+    if (backend_sampler_count > 0 || (params.backend_sampling && has_logit_bias(params))) {
+        llama_sampler_chain_params backend_params = llama_sampler_chain_default_params();
+        backend_params.no_perf = params.no_perf;
+        result->backend_chain = llama_sampler_chain_init(backend_params);
+
+        if (has_logit_bias(params)) {
+            llama_sampler_chain_add(result->backend_chain,
+                    llama_sampler_backend_init_logit_bias(
+                        llama_vocab_n_tokens(vocab),
+                        params.logit_bias.size(),
+                        params.logit_bias.data()));
+        }
+
+        size_t backend_idx = 0;
+        for (const auto & sampler_type : params.samplers) {
+            if (backend_idx >= backend_sampler_count - has_logit_bias(params)) {
+                break;
+            }
+
+            switch (sampler_type) {
+                case COMMON_SAMPLER_TYPE_TOP_K:
+                    if (params.top_k > 0) {
+                        llama_sampler_chain_add(result->backend_chain, llama_sampler_backend_init_top_k(params.top_k));
+                    }
+                    backend_idx++;
+                    break;
+                case COMMON_SAMPLER_TYPE_TEMPERATURE:
+                    if (params.temp > 0.0f) {
+                        llama_sampler_chain_add(result->backend_chain, llama_sampler_backend_init_temp(params.temp));
+                    }
+                    backend_idx++;
+                    break;
+                default:
+                    GGML_ASSERT(false && "unsupported backend sampler");
+            }
+        }
+    }
+
+    size_t cpu_start_idx = backend_sampler_count - has_logit_bias(params);
+    bool cpu_has_samplers = cpu_start_idx < params.samplers.size();
+
+    // Build CPU chain
+    if (!params.backend_sampling || !has_logit_bias(params)) {
+        llama_sampler_chain_add(result->chain,
+                llama_sampler_init_logit_bias(
+                    llama_vocab_n_tokens(vocab),
+                    params.logit_bias.size(),
+                    params.logit_bias.data()));
+    }
 
     if (params.mirostat == 0) {
-        for (const auto & cnstr : params.samplers) {
+        // Add remaining CPU samplers
+        for (size_t i = cpu_start_idx; i < params.samplers.size(); i++) {
+            const auto & cnstr = params.samplers[i];
             switch (cnstr) {
                 case COMMON_SAMPLER_TYPE_DRY:
                     {
@@ -308,7 +390,13 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
                     GGML_ASSERT(false && "unknown sampler type");
             }
         }
-        llama_sampler_chain_add(result->chain, llama_sampler_init_dist(params.seed));
+
+        // If all samplers run on the backend, add dist to the backend chain; otherwise add it to the CPU chain
+        if (result->backend_chain && !cpu_has_samplers) {
+            llama_sampler_chain_add(result->backend_chain, llama_sampler_backend_init_dist(params.seed));
+        } else {
+            llama_sampler_chain_add(result->chain, llama_sampler_init_dist(params.seed));
+        }
     } else if (params.mirostat == 1) {
         llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));
         llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat(llama_vocab_n_tokens(vocab), params.seed, params.mirostat_tau, params.mirostat_eta, 100));
@@ -323,36 +411,74 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
 }
 
 struct llama_sampler * common_sampler_backend_init(const struct llama_model * model, const struct common_params_sampling & params) {
-    if (!params.backend_sampling) {
+    if (!params.backend_sampling || params.mirostat != 0) {
         return nullptr;
     }
 
     const llama_vocab * vocab = llama_model_get_vocab(model);
 
+    // Determine the split point for backend sampling using the same logic as common_sampler_init
+    size_t backend_sampler_count = 0;
+    if (has_logit_bias(params)) {
+        backend_sampler_count++;
+    }
+
+    // Find the longest contiguous chain of backend-supported samplers from the start
+    for (const auto & sampler_type : params.samplers) {
+        if (sampler_backend_supported(sampler_type)) {
+            backend_sampler_count++;
+        } else {
+            break;
+        }
+    }
+
+    if (backend_sampler_count == 0 && !has_logit_bias(params)) {
+        return nullptr;
+    }
+
     llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();
     chain_params.no_perf = params.no_perf;
 
     struct llama_sampler * chain = llama_sampler_chain_init(chain_params);
 
-    const bool enable_temp  = params.temp > 0.0f && sampler_enabled(params, COMMON_SAMPLER_TYPE_TEMPERATURE);
-    const bool enable_top_k = params.top_k > 0   && sampler_enabled(params, COMMON_SAMPLER_TYPE_TOP_K);
-    const bool enable_dist  = params.backend_dist;
-
-    if (!params.logit_bias.empty()) {
+    // Add logit_bias to the backend chain if present
+    if (has_logit_bias(params)) {
         llama_sampler_chain_add(chain,
                 llama_sampler_backend_init_logit_bias(
                     llama_vocab_n_tokens(vocab),
                     params.logit_bias.size(),
                     params.logit_bias.data()));
     }
 
-    if (enable_temp) {
-        llama_sampler_chain_add(chain, llama_sampler_backend_init_temp(params.temp));
+    size_t backend_idx = 0;
+    for (const auto & sampler_type : params.samplers) {
+        if (backend_idx >= backend_sampler_count - has_logit_bias(params)) {
+            break;
+        }
+
+        switch (sampler_type) {
+            case COMMON_SAMPLER_TYPE_TOP_K:
+                if (params.top_k > 0) {
+                    llama_sampler_chain_add(chain, llama_sampler_backend_init_top_k(params.top_k));
+                }
+                backend_idx++;
+                break;
+            case COMMON_SAMPLER_TYPE_TEMPERATURE:
+                if (params.temp > 0.0f) {
+                    llama_sampler_chain_add(chain, llama_sampler_backend_init_temp(params.temp));
+                }
+                backend_idx++;
+                break;
+            default:
+                GGML_ASSERT(false && "unsupported backend sampler");
+        }
     }
 
-    if (enable_top_k) {
-        llama_sampler_chain_add(chain, llama_sampler_backend_init_top_k(params.top_k));
-    }
+    // Determine whether to add the dist sampler to the backend chain:
+    // only add it if all samplers from params.samplers run on the backend.
+    size_t cpu_start_idx = backend_sampler_count - has_logit_bias(params);
+    bool cpu_has_samplers = cpu_start_idx < params.samplers.size();
 
-    if (enable_dist) {
+    if (!cpu_has_samplers) {
         llama_sampler_chain_add(chain, llama_sampler_backend_init_dist(params.seed));
     }
 
@@ -362,9 +488,13 @@ struct llama_sampler * common_sampler_backend_init(const struct llama_model * mo
 void common_sampler_free(struct common_sampler * gsmpl) {
     if (gsmpl) {
         llama_sampler_free(gsmpl->grmr);
         llama_sampler_free(gsmpl->chain);
+        if (gsmpl->backend_chain) {
+            llama_sampler_free(gsmpl->backend_chain);
+        }
+
         delete gsmpl;
     }
 }
 
@@ -387,12 +517,13 @@ void common_sampler_reset(struct common_sampler * gsmpl) {
 
 struct common_sampler * common_sampler_clone(common_sampler * gsmpl) {
     return new common_sampler {
-        /* .params = */ gsmpl->params,
-        /* .grmr   = */ llama_sampler_clone(gsmpl->grmr),
-        /* .chain  = */ llama_sampler_clone(gsmpl->chain),
-        /* .prev   = */ gsmpl->prev,
-        /* .cur    = */ gsmpl->cur,
-        /* .cur_p  = */ gsmpl->cur_p,
+        /* .params        = */ gsmpl->params,
+        /* .grmr          = */ llama_sampler_clone(gsmpl->grmr),
+        /* .chain         = */ llama_sampler_clone(gsmpl->chain),
+        /* .backend_chain = */ gsmpl->backend_chain ? llama_sampler_clone(gsmpl->backend_chain) : nullptr,
+        /* .prev          = */ gsmpl->prev,
+        /* .cur           = */ gsmpl->cur,
+        /* .cur_p         = */ gsmpl->cur_p,
     };
 }
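Both `common_sampler_init` and `common_sampler_backend_init` now derive the backend/CPU split with the same rule: the backend takes the longest prefix of backend-supported samplers, and the first unsupported sampler (plus everything after it) stays on the CPU. A self-contained sketch of that rule; the enum values here are stand-ins, not the real `common_sampler_type` constants:

```cpp
#include <cstddef>
#include <vector>

// Stand-ins for the relevant COMMON_SAMPLER_TYPE_* values.
enum sampler_type { TOP_K, TEMPERATURE, TOP_P, MIN_P, PENALTIES };

// Mirrors sampler_backend_supported() above: only top-k and temperature
// currently have backend implementations.
static bool backend_supported(sampler_type t) {
    return t == TOP_K || t == TEMPERATURE;
}

// Length of the longest backend-supported prefix; samplers[n..] run on the CPU.
static size_t backend_prefix_len(const std::vector<sampler_type> & samplers) {
    size_t n = 0;
    while (n < samplers.size() && backend_supported(samplers[n])) {
        n++;
    }
    return n;
}
```

For `{TOP_K, TEMPERATURE, TOP_P}` this returns 2 (hybrid sampling: top-p and the final dist step run on the CPU); for `{PENALTIES, TOP_K}` it returns 0 and everything runs on the CPU, which is why a null return from `common_sampler_backend_init` is now a normal outcome rather than an error.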
diff --git a/examples/batched/README.md b/examples/batched/README.md
index f10639220e..a68b45b290 100644
--- a/examples/batched/README.md
+++ b/examples/batched/README.md
@@ -45,25 +45,38 @@ llama_print_timings: total time = 4156.04 ms
 
 ### Using backend samplers
 
 It is possible to run this example using backend samplers so that sampling is
-performed on the backend device, like a GPU.
+performed on a backend device, like a GPU.
 
 ```bash
 ./llama-batched \
     -m models/Qwen2.5-VL-3B-Instruct-Q8_0.gguf -p "Hello my name is" \
-    -np 4 -kvu \
-    --backend_sampling --top-k 80 --backend_dist
+    -np 4 \
+    -kvu \
+    --backend_sampling \
+    --samplers 'top_k;temperature' \
+    --top-k 80
 ```
 
-The `--verbose` flag can be added to see more detailed output and also show
-that the backend samplers are being used. The above example will perform distribution
-sampling on the backend device and only transfer the sampled token ids back to the host.
+The samplers specified with `--samplers` must be supported by the backend,
+which is why this example explicitly specifies only `top_k` and `temperature`:
+at the time of writing, these are the only samplers with backend support.
 
-It is also possible to perform partial sampling on the backend, and then allow CPU samplers
-to process those results further. This is sometimes referred to as hybrid sampling.
-For an example of this we can remove `--backend_dist` from the above command:
-```bash
-./llama-batched \
-    -m models/Qwen2.5-VL-3B-Instruct-Q8_0.gguf -p "Hello my name is" \
-    -np 4 -kvu \
-    --backend_sampling --top-k 80 -v
-```
-This will perform the top-k filtering on the backend device, and then transfer the filtered logits
-back to the host for sampling.
+The `--verbose` flag can be added to see more detailed output and to show
+that the backend samplers are being used.
+
+With `--backend_sampling` enabled, the sampler chain is automatically analyzed
+to determine which samplers can run on the backend. The system finds the longest
+contiguous chain of backend-supported samplers from the start of the sampler
+sequence. For example:
+* If the chain is `top_k -> temperature -> top_p`, and both `top_k` and
+  `temperature` are backend-supported but `top_p` is not, then `top_k` and
+  `temperature` will run on the backend, while `top_p` and subsequent samplers
+  run on the CPU.
+* If all configured samplers are supported, the final distribution sampling will
+  also happen on the backend, transferring only the sampled token IDs back to
+  the host.
+* If the sampler chain starts with an unsupported sampler (e.g., `penalties`),
+  all sampling runs on the CPU.
+
+**Note:** The default sampler chain includes `penalties` as the first sampler,
+which is not backend-supported yet. To use backend sampling, you must explicitly
+configure a sampler chain that starts with backend-supported samplers using
+`--samplers`, as shown above.
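Concretely, the README command above (`--samplers 'top_k;temperature' --top-k 80`) produces a backend chain roughly equivalent to the following sketch, built with the backend sampler constructors introduced by this change; the `0.8f` assumes the default temperature, and the helper name is hypothetical:

```cpp
#include "llama.h"

// Roughly the backend chain built for --samplers 'top_k;temperature' --top-k 80.
static llama_sampler * make_readme_backend_chain() {
    llama_sampler * chain = llama_sampler_chain_init(llama_sampler_chain_default_params());

    llama_sampler_chain_add(chain, llama_sampler_backend_init_top_k(80));
    llama_sampler_chain_add(chain, llama_sampler_backend_init_temp(0.8f));

    // Every configured sampler is backend-supported, so the final distribution
    // sampling runs on the backend too and only token ids return to the host.
    llama_sampler_chain_add(chain, llama_sampler_backend_init_dist(LLAMA_DEFAULT_SEED));

    return chain;
}
```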
diff --git a/tools/server/public/index.html.gz b/tools/server/public/index.html.gz
index 14009a7a39..1a40fe16d1 100644
Binary files a/tools/server/public/index.html.gz and b/tools/server/public/index.html.gz differ
diff --git a/tools/server/server-task.cpp b/tools/server/server-task.cpp
index 9503262fe4..d083777fa7 100644
--- a/tools/server/server-task.cpp
+++ b/tools/server/server-task.cpp
@@ -79,7 +79,6 @@ json task_params::to_json(bool only_metrics) const {
             {"timings_per_token",   timings_per_token},
             {"post_sampling_probs", post_sampling_probs},
             {"backend_sampling",    sampling.backend_sampling},
-            {"backend_dist",        sampling.backend_dist},
             {"lora",                lora},
         };
     }
@@ -139,7 +138,6 @@ json task_params::to_json(bool only_metrics) const {
         {"timings_per_token",   timings_per_token},
         {"post_sampling_probs", post_sampling_probs},
         {"backend_sampling",    sampling.backend_sampling},
-        {"backend_dist",        sampling.backend_dist},
         {"lora",                lora},
     };
 }
@@ -210,9 +208,7 @@ task_params server_task::params_from_json_cmpl(
     params.post_sampling_probs = json_value(data, "post_sampling_probs", defaults.post_sampling_probs);
 
     const bool request_backend_sampling = json_value(data, "backend_sampling", defaults.sampling.backend_sampling);
-    const bool request_backend_dist     = json_value(data, "backend_dist", defaults.sampling.backend_dist);
     params.sampling.backend_sampling = defaults.sampling.backend_sampling && request_backend_sampling;
-    params.sampling.backend_dist     = params.sampling.backend_sampling && request_backend_dist;
 
     params.speculative.n_min = json_value(data, "speculative.n_min", defaults.speculative.n_min);
     params.speculative.n_max = json_value(data, "speculative.n_max", defaults.speculative.n_max);
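In `params_from_json_cmpl` above, the per-request flag is ANDed with the server default, so a client can opt out of backend sampling but can never switch it on when the server was started without it. The rule in isolation (the helper is illustrative, not server code):

```cpp
// Effective per-request setting, as computed in params_from_json_cmpl:
//
//   server default | request flag | effective
//   ---------------+--------------+----------
//   false          | true         | false
//   true           | false        | false
//   true           | true         | true
static bool effective_backend_sampling(bool server_default, bool request_flag) {
    return server_default && request_flag;
}
```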
diff --git a/tools/server/server.cpp b/tools/server/server.cpp
index 37882f5ff1..48a7cf2c82 100644
--- a/tools/server/server.cpp
+++ b/tools/server/server.cpp
@@ -1047,9 +1047,16 @@ struct server_context {
         }
 
         llama_sampler * backend_chain = common_sampler_backend_init(model, sampling);
+        // The sampler types configured with --samplers might not be supported
+        // by the backend samplers, in which case we disable backend sampling
+        // and fall back to CPU-only sampling.
        if (backend_chain == nullptr) {
-            SLT_ERR(slot, "%s", "failed to initialize backend sampler\n");
-            return false;
+            if (slot.backend_sampler != nullptr) {
+                llama_set_backend_sampler(ctx, slot.id, nullptr);
+                slot.backend_sampler = nullptr;
+            }
+            SLT_INF(slot, "%s", "no backend samplers configured (sampler chain doesn't start with backend-supported samplers)\n");
+            return true;
         }
 
         if (slot.backend_sampler != nullptr) {
@@ -1059,6 +1066,7 @@ struct server_context {
 
         slot.backend_sampler = backend_chain;
         llama_set_backend_sampler(ctx, slot.id, backend_chain);
+        SLT_INF(slot, "%s", "configured backend samplers\n");
 
         return true;
     }
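The server change above treats a null chain from `common_sampler_backend_init` as a normal "CPU-only" outcome rather than an error, taking care to clear any sampler left over from a previous request on the same slot. A condensed sketch of that contract; `apply_backend_sampling` is a hypothetical helper, and the second argument of `llama_set_backend_sampler` is assumed to be the slot/sequence id, as in the server code:

```cpp
#include "llama.h"
#include "sampling.h"

// nullptr from common_sampler_backend_init() now means "sample on the CPU":
// backend sampling is disabled, mirostat is active, or the sampler chain does
// not start with a backend-supported sampler.
static void apply_backend_sampling(llama_context * ctx, const llama_model * model,
                                   const common_params_sampling & sparams, int seq_id) {
    llama_sampler * chain = common_sampler_backend_init(model, sparams);
    if (chain == nullptr) {
        // clear any previously configured backend sampler for this sequence
        llama_set_backend_sampler(ctx, seq_id, nullptr);
        return;
    }
    llama_set_backend_sampler(ctx, seq_id, chain);
}
```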
diff --git a/tools/server/webui/src/lib/components/app/chat/ChatSettings/ChatSettings.svelte b/tools/server/webui/src/lib/components/app/chat/ChatSettings/ChatSettings.svelte
index 78ad97a41c..b6f345c241 100644
--- a/tools/server/webui/src/lib/components/app/chat/ChatSettings/ChatSettings.svelte
+++ b/tools/server/webui/src/lib/components/app/chat/ChatSettings/ChatSettings.svelte
@@ -179,11 +179,6 @@
                     key: 'backend_sampling',
                     label: 'Backend sampling',
                     type: 'checkbox'
-                },
-                {
-                    key: 'backend_dist',
-                    label: 'Backend dist sampling',
-                    type: 'checkbox'
                 }
             ]
         },
@@ -297,10 +292,6 @@
     function handleConfigChange(key: string, value: string | boolean) {
         localConfig[key] = value;
-
-        if (key === 'backend_sampling' && value === false) {
-            localConfig.backend_dist = false;
-        }
     }
 
     function handleReset() {
diff --git a/tools/server/webui/src/lib/components/app/chat/ChatSettings/ChatSettingsFields.svelte b/tools/server/webui/src/lib/components/app/chat/ChatSettings/ChatSettingsFields.svelte
index 1bafaf137a..57862bc05e 100644
--- a/tools/server/webui/src/lib/components/app/chat/ChatSettings/ChatSettingsFields.svelte
+++ b/tools/server/webui/src/lib/components/app/chat/ChatSettings/ChatSettingsFields.svelte
@@ -211,8 +211,7 @@
             {/if}
         {:else if field.type === 'checkbox'}
             {@const pdfDisabled = field.key === 'pdfAsImage' && !supportsVision()}
-            {@const backendDistDisabled = field.key === 'backend_dist' && !localConfig.backend_sampling}
-            {@const isDisabled = pdfDisabled || backendDistDisabled}
+            {@const isDisabled = pdfDisabled}
-            {:else if backendDistDisabled}
-                <p>
-                    Enable GPU sampling to allow GPU dist sampling.
-                </p>
             {/if}
diff --git a/tools/server/webui/src/lib/constants/settings-config.ts b/tools/server/webui/src/lib/constants/settings-config.ts
index 67323bb720..f8187bbf49 100644
--- a/tools/server/webui/src/lib/constants/settings-config.ts
+++ b/tools/server/webui/src/lib/constants/settings-config.ts
@@ -20,7 +20,6 @@ export const SETTING_CONFIG_DEFAULT: Record<string, string | number | boolean> =
     // make sure these default values are in sync with `common.h`
     samplers: 'top_k;typ_p;top_p;min_p;temperature',
     backend_sampling: false,
-    backend_dist: false,
     temperature: 0.8,
     dynatemp_range: 0.0,
     dynatemp_exponent: 1.0,
@@ -56,8 +55,6 @@ export const SETTING_CONFIG_INFO: Record<string, string> = {
         'The order at which samplers are applied, in simplified way. Default is "top_k;typ_p;top_p;min_p;temperature": top_k->typ_p->top_p->min_p->temperature',
     backend_sampling:
         'Enable backend-based samplers. When enabled, supported samplers run on the accelerator backend for faster sampling.',
-    backend_dist:
-        'Perform the final distribution sampling step on the backend. Requires backend sampling to be enabled.',
     temperature:
         'Controls the randomness of the generated text by affecting the probability distribution of the output tokens. Higher = more random, lower = more focused.',
     dynatemp_range:
diff --git a/tools/server/webui/src/lib/services/chat.ts b/tools/server/webui/src/lib/services/chat.ts
index 5cda8c8868..66134a524e 100644
--- a/tools/server/webui/src/lib/services/chat.ts
+++ b/tools/server/webui/src/lib/services/chat.ts
@@ -99,7 +99,6 @@ export class ChatService {
             // Other parameters
             samplers,
             backend_sampling,
-            backend_dist,
             custom,
             timings_per_token
         } = options;
@@ -185,7 +184,6 @@ export class ChatService {
         }
 
         if (backend_sampling !== undefined) requestBody.backend_sampling = backend_sampling;
-        if (backend_dist !== undefined) requestBody.backend_dist = backend_dist;
 
         if (timings_per_token !== undefined) requestBody.timings_per_token = timings_per_token;
 
diff --git a/tools/server/webui/src/lib/stores/chat.svelte.ts b/tools/server/webui/src/lib/stores/chat.svelte.ts
index e00994bc46..d8978558e1 100644
--- a/tools/server/webui/src/lib/stores/chat.svelte.ts
+++ b/tools/server/webui/src/lib/stores/chat.svelte.ts
@@ -301,9 +301,6 @@ class ChatStore {
         if (currentConfig.backend_sampling !== undefined) {
             apiOptions.backend_sampling = Boolean(currentConfig.backend_sampling);
         }
-        if (currentConfig.backend_dist !== undefined) {
-            apiOptions.backend_dist = Boolean(currentConfig.backend_dist);
-        }
         if (currentConfig.custom) {
             apiOptions.custom = currentConfig.custom;
         }
diff --git a/tools/server/webui/src/lib/types/api.d.ts b/tools/server/webui/src/lib/types/api.d.ts
index 149d4fb118..592923b653 100644
--- a/tools/server/webui/src/lib/types/api.d.ts
+++ b/tools/server/webui/src/lib/types/api.d.ts
@@ -182,7 +182,6 @@ export interface ApiChatCompletionRequest {
     // Sampler configuration
     samplers?: string[];
     backend_sampling?: boolean;
-    backend_dist?: boolean;
     // Custom parameters (JSON string)
     custom?: Record<string, unknown>;
     timings_per_token?: boolean;
diff --git a/tools/server/webui/src/lib/types/settings.d.ts b/tools/server/webui/src/lib/types/settings.d.ts
index e68d107faa..25b232582b 100644
--- a/tools/server/webui/src/lib/types/settings.d.ts
+++ b/tools/server/webui/src/lib/types/settings.d.ts
@@ -38,7 +38,6 @@ export interface SettingsChatServiceOptions {
     // Sampler configuration
     samplers?: string | string[];
     backend_sampling?: boolean;
-    backend_dist?: boolean;
     // Custom parameters
     custom?: string;
     timings_per_token?: boolean;