sampling : remove backend-dist option (wip)

This commit removes the `--backend-dist` option and instead uses the
configured `--samplers` chain to determine which samplers run on the
backend.

Backend sampling is still enabled with `--backend-sampling`, and the
sampler chain, either explicitly specified using `--samplers` or the
default, is automatically analyzed to determine which samplers can run
on the backend. The system finds the longest contiguous chain of
backend-supported samplers from the start of the sampler sequence
(a small standalone sketch of this rule follows the examples below).
For example:

* If the chain is `top-k -> temperature -> top-p`, and both `top-k` and
  `temperature` are backend-supported but `top-p` is not, then `top-k`
  and `temperature` will run on the backend, while `top-p` and
  subsequent samplers run on the CPU.

* If all configured samplers are supported, the final distribution
  sampling will also happen on the backend, transferring only the
  sampled token IDs back to the host.

* If the sampler chain starts with an unsupported sampler (e.g.,
  `penalties`), all sampling runs on the CPU. Note that this is
  currently the case with the default sampler chain, so to use backend
  sampling an explicit sampler chain has to be specified. See below for
  an example.
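
The rule itself is small; the following standalone sketch (simplified
types and names, not the actual llama.cpp declarations) mirrors the
prefix detection added to `common/sampling.cpp`:
```cpp
// Standalone sketch of the backend-prefix rule described above.
// The real implementation uses common_sampler_type and
// sampler_backend_supported() in common/sampling.cpp.
#include <cstdio>
#include <vector>

enum sampler_type { TOP_K, TEMPERATURE, TOP_P, TYP_P, MIN_P, PENALTIES };

// Only top_k and temperature have backend implementations at the moment.
static bool backend_supported(sampler_type t) {
    return t == TOP_K || t == TEMPERATURE;
}

// Number of samplers, counted from the start of the chain, that can run
// on the backend. Everything from the first unsupported sampler onwards
// stays on the CPU.
static size_t backend_prefix_len(const std::vector<sampler_type> & chain) {
    size_t n = 0;
    for (sampler_type t : chain) {
        if (!backend_supported(t)) {
            break;
        }
        n++;
    }
    return n;
}

int main() {
    const std::vector<sampler_type> chain = { TOP_K, TEMPERATURE, TOP_P };
    const size_t n = backend_prefix_len(chain);
    // Prints "backend: 2, cpu: 1" -> top_k and temperature run on the
    // backend, top_p (and the final dist sampler) run on the CPU.
    printf("backend: %zu, cpu: %zu\n", n, chain.size() - n);
    return 0;
}
```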

The following shows how llama-cli can be run with backend sampling:
```console
$ llama-cli -m models/Qwen2.5-VL-3B-Instruct-Q8_0.gguf \
    --prompt 'What is the capital of Sweden?' \
    -n 20 \
    -no-cnv \
    --verbose-prompt \
    -ngl 40 \
    --backend-sampling \
    --samplers 'top_k;temperature'
```
In this case all sampling, including the final distribution sampling,
will happen on the backend since both `top_k` and `temperature` are
backend-supported samplers.

To enable partial backend sampling (hybrid sampling), for example
running `top_k` and `temperature` on the backend and `top_p` on the
CPU, the following sampler chain could be specified:
```console
$ llama-cli -m models/Qwen2.5-VL-3B-Instruct-Q8_0.gguf \
    --prompt 'What is the capital of Sweden?' \
    -n 20 \
    -no-cnv \
    --verbose-prompt \
    -ngl 40 \
    --backend-sampling \
    --samplers 'top_k;temperature;top_p'
```
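
The same selection applies to llama-server. Based on the request fields
in the server changes below, a per-request opt-in could look roughly
like this (sketch only — the request flag is ANDed with the server-side
setting, so the server must also be started with `--backend-sampling`):
```console
$ llama-server -m models/Qwen2.5-VL-3B-Instruct-Q8_0.gguf -ngl 40 \
    --backend-sampling

$ curl -s http://localhost:8080/completion -d '{
    "prompt": "What is the capital of Sweden?",
    "n_predict": 20,
    "backend_sampling": true,
    "samplers": ["top_k", "temperature"]
  }'
```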

If this looks good then I'll follow up with updates to the llama-cli
and llama-server documentation to reflect these changes.
Daniel Bevenius 2025-11-25 13:45:02 +01:00
parent 53dca56d9b
commit 9e5e09d087
15 changed files with 214 additions and 96 deletions


@ -1520,14 +1520,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.sampling.backend_sampling = true;
}
).set_sparam());
add_opt(common_arg(
{"--backend-dist"},
"perform final (distribution) sampling on backend (default: disabled)",
[](common_params & params) {
params.sampling.backend_dist = true;
params.sampling.backend_sampling = true;
}
).set_sparam());
add_opt(common_arg(
{"--pooling"}, "{none,mean,cls,last,rank}",
"pooling type for embeddings, use model default if unspecified",


@ -1021,12 +1021,17 @@ struct common_init_result common_init_from_params(common_params & params) {
// backend sampling initialization
if (params.sampling.backend_sampling) {
iparams.samplers_seq_config.resize(cparams.n_seq_max);
for (int i = 0; i < (int) cparams.n_seq_max; ++i) {
iparams.samplers_seq_config[i] = { i, common_sampler_backend_init(model, params.sampling) };
llama_sampler * backend_chain = common_sampler_backend_init(model, params.sampling);
if (backend_chain != nullptr) {
iparams.samplers_seq_config.resize(cparams.n_seq_max);
for (int i = 0; i < (int) cparams.n_seq_max; ++i) {
iparams.samplers_seq_config[i] = { i, llama_sampler_clone(backend_chain) };
}
cparams.samplers = iparams.samplers_seq_config.data();
cparams.n_samplers = cparams.n_seq_max;
llama_sampler_free(backend_chain);
}
cparams.samplers = iparams.samplers_seq_config.data();
cparams.n_samplers = cparams.n_seq_max;
}
llama_context * lctx = llama_init_from_model(model, cparams);


@ -213,9 +213,7 @@ struct common_params_sampling {
std::vector<llama_logit_bias> logit_bias; // logit biases to apply
std::vector<llama_logit_bias> logit_bias_eog; // pre-calculated logit biases for EOG tokens
// Backend sampling flags
bool backend_sampling = false; // enable backend sampling
bool backend_dist = false; // backend performs final sampling (dist)
// print the parameters into a string
std::string print() const;


@ -105,7 +105,8 @@ struct common_sampler {
common_params_sampling params;
struct llama_sampler * grmr;
struct llama_sampler * chain;
struct llama_sampler * chain; // CPU sampling chain
struct llama_sampler * backend_chain; // Backend sampling chain
ring_buffer<llama_token> prev;
@ -118,6 +119,9 @@ struct common_sampler {
llama_sampler_reset(grmr);
llama_sampler_reset(chain);
if (backend_chain) {
llama_sampler_reset(backend_chain);
}
}
void set_logits(struct llama_context * ctx, int idx) {
@ -165,6 +169,20 @@ static bool sampler_enabled(const struct common_params_sampling & params, enum c
return std::find(params.samplers.begin(), params.samplers.end(), type) != params.samplers.end();
}
static bool sampler_backend_supported(enum common_sampler_type type) {
switch (type) {
case COMMON_SAMPLER_TYPE_TOP_K:
case COMMON_SAMPLER_TYPE_TEMPERATURE:
return true;
default:
return false;
}
}
static bool has_logit_bias(const struct common_params_sampling & params) {
return !params.logit_bias.empty();
}
std::string common_params_sampling::print() const {
char result[1024];
@ -249,22 +267,86 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
}
auto * result = new common_sampler {
/* .params = */ params,
/* .grmr = */ grmr,
/* .chain = */ llama_sampler_chain_init(lparams),
/* .prev = */ ring_buffer<llama_token>(std::max(32, params.n_prev)),
/* .cur = */ {},
/* .cur_p = */ {},
/* .params = */ params,
/* .grmr = */ grmr,
/* .chain = */ llama_sampler_chain_init(lparams),
/* .backend_chain = */ nullptr,
/* .prev = */ ring_buffer<llama_token>(std::max(32, params.n_prev)),
/* .cur = */ {},
/* .cur_p = */ {},
};
llama_sampler_chain_add(result->chain,
llama_sampler_init_logit_bias(
llama_vocab_n_tokens(vocab),
params.logit_bias.size(),
params.logit_bias.data()));
size_t backend_sampler_count = 0;
if (params.backend_sampling && params.mirostat == 0) {
if (has_logit_bias(params)) {
backend_sampler_count++;
}
// Find the longest contiguous chain of backend-supported samplers from the start
for (const auto & sampler_type : params.samplers) {
if (sampler_backend_supported(sampler_type)) {
backend_sampler_count++;
} else {
break;
}
}
}
// If the samplers combination is supported then we can build the backend chain.
if (backend_sampler_count > 0 || (params.backend_sampling && has_logit_bias(params))) {
llama_sampler_chain_params backend_params = llama_sampler_chain_default_params();
backend_params.no_perf = params.no_perf;
result->backend_chain = llama_sampler_chain_init(backend_params);
if (has_logit_bias(params)) {
llama_sampler_chain_add(result->backend_chain,
llama_sampler_backend_init_logit_bias(
llama_vocab_n_tokens(vocab),
params.logit_bias.size(),
params.logit_bias.data()));
}
size_t backend_idx = 0;
for (const auto & sampler_type : params.samplers) {
if (backend_idx >= backend_sampler_count - has_logit_bias(params)) {
break;
}
switch (sampler_type) {
case COMMON_SAMPLER_TYPE_TOP_K:
if (params.top_k > 0) {
llama_sampler_chain_add(result->backend_chain, llama_sampler_backend_init_top_k(params.top_k));
}
backend_idx++;
break;
case COMMON_SAMPLER_TYPE_TEMPERATURE:
if (params.temp > 0.0f) {
llama_sampler_chain_add(result->backend_chain, llama_sampler_backend_init_temp(params.temp));
}
backend_idx++;
break;
default:
GGML_ASSERT(false && "unsupported backend sampler");
}
}
}
size_t cpu_start_idx = backend_sampler_count - has_logit_bias(params);
bool cpu_has_samplers = cpu_start_idx < params.samplers.size();
// Build CPU chain
if (!params.backend_sampling || !has_logit_bias(params)) {
llama_sampler_chain_add(result->chain,
llama_sampler_init_logit_bias(
llama_vocab_n_tokens(vocab),
params.logit_bias.size(),
params.logit_bias.data()));
}
if (params.mirostat == 0) {
for (const auto & cnstr : params.samplers) {
// Add remaining CPU samplers
for (size_t i = cpu_start_idx; i < params.samplers.size(); i++) {
const auto & cnstr = params.samplers[i];
switch (cnstr) {
case COMMON_SAMPLER_TYPE_DRY:
{
@ -308,7 +390,13 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
GGML_ASSERT(false && "unknown sampler type");
}
}
llama_sampler_chain_add(result->chain, llama_sampler_init_dist(params.seed));
// If all samplers are on backend, add dist to backend; otherwise add to CPU
if (result->backend_chain && !cpu_has_samplers) {
llama_sampler_chain_add(result->backend_chain, llama_sampler_backend_init_dist(params.seed));
} else {
llama_sampler_chain_add(result->chain, llama_sampler_init_dist(params.seed));
}
} else if (params.mirostat == 1) {
llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));
llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat(llama_vocab_n_tokens(vocab), params.seed, params.mirostat_tau, params.mirostat_eta, 100));
@ -323,36 +411,74 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
}
struct llama_sampler * common_sampler_backend_init(const struct llama_model * model, const struct common_params_sampling & params) {
if (!params.backend_sampling) {
if (!params.backend_sampling || params.mirostat != 0) {
return nullptr;
}
const llama_vocab * vocab = llama_model_get_vocab(model);
// Determine the split point for backend sampling using the same logic as common_sampler_init
size_t backend_sampler_count = 0;
if (has_logit_bias(params)) {
backend_sampler_count++;
}
// Find the longest contiguous chain of backend-supported samplers from the start
for (const auto & sampler_type : params.samplers) {
if (sampler_backend_supported(sampler_type)) {
backend_sampler_count++;
} else {
break;
}
}
if (backend_sampler_count == 0 && !has_logit_bias(params)) {
return nullptr;
}
llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();
chain_params.no_perf = params.no_perf;
struct llama_sampler * chain = llama_sampler_chain_init(chain_params);
const bool enable_temp = params.temp > 0.0f && sampler_enabled(params, COMMON_SAMPLER_TYPE_TEMPERATURE);
const bool enable_top_k = params.top_k > 0 && sampler_enabled(params, COMMON_SAMPLER_TYPE_TOP_K);
const bool enable_dist = params.backend_dist;
if (!params.logit_bias.empty()) {
// Add logit_bias to backend chain if present
if (has_logit_bias(params)) {
llama_sampler_chain_add(chain, llama_sampler_backend_init_logit_bias(
llama_vocab_n_tokens(vocab),
params.logit_bias.size(),
params.logit_bias.data()));
}
if (enable_temp) {
llama_sampler_chain_add(chain, llama_sampler_backend_init_temp(params.temp));
size_t backend_idx = 0;
for (const auto & sampler_type : params.samplers) {
if (backend_idx >= backend_sampler_count - has_logit_bias(params)) {
break;
}
switch (sampler_type) {
case COMMON_SAMPLER_TYPE_TOP_K:
if (params.top_k > 0) {
llama_sampler_chain_add(chain, llama_sampler_backend_init_top_k(params.top_k));
}
backend_idx++;
break;
case COMMON_SAMPLER_TYPE_TEMPERATURE:
if (params.temp > 0.0f) {
llama_sampler_chain_add(chain, llama_sampler_backend_init_temp(params.temp));
}
backend_idx++;
break;
default:
GGML_ASSERT(false && "unsupported backend sampler");
}
}
if (enable_top_k) {
llama_sampler_chain_add(chain, llama_sampler_backend_init_top_k(params.top_k));
}
// Determine if we should add dist sampler to backend chain
// Only add it if all samplers from params.samplers are on the backend
size_t cpu_start_idx = backend_sampler_count - has_logit_bias(params);
bool cpu_has_samplers = cpu_start_idx < params.samplers.size();
if (enable_dist) {
if (!cpu_has_samplers) {
llama_sampler_chain_add(chain, llama_sampler_backend_init_dist(params.seed));
}
@ -362,9 +488,12 @@ struct llama_sampler * common_sampler_backend_init(const struct llama_model * mo
void common_sampler_free(struct common_sampler * gsmpl) {
if (gsmpl) {
llama_sampler_free(gsmpl->grmr);
llama_sampler_free(gsmpl->chain);
if (gsmpl->backend_chain) {
llama_sampler_free(gsmpl->backend_chain);
}
delete gsmpl;
}
}
@ -387,12 +516,13 @@ void common_sampler_reset(struct common_sampler * gsmpl) {
struct common_sampler * common_sampler_clone(common_sampler * gsmpl) {
return new common_sampler {
/* .params = */ gsmpl->params,
/* .grmr = */ llama_sampler_clone(gsmpl->grmr),
/* .chain = */ llama_sampler_clone(gsmpl->chain),
/* .prev = */ gsmpl->prev,
/* .cur = */ gsmpl->cur,
/* .cur_p = */ gsmpl->cur_p,
/* .params = */ gsmpl->params,
/* .grmr = */ llama_sampler_clone(gsmpl->grmr),
/* .chain = */ llama_sampler_clone(gsmpl->chain),
/* .backend_chain = */ gsmpl->backend_chain ? llama_sampler_clone(gsmpl->backend_chain) : nullptr,
/* .prev = */ gsmpl->prev,
/* .cur = */ gsmpl->cur,
/* .cur_p = */ gsmpl->cur_p,
};
}


@ -45,25 +45,38 @@ llama_print_timings: total time = 4156.04 ms
### Using backend samplers
It is possible to run this example using backend samplers so that sampling is
performed on the backend device, like a GPU.
performed on a backend device, like a GPU.
```bash
./llama-batched \
-m models/Qwen2.5-VL-3B-Instruct-Q8_0.gguf -p "Hello my name is" \
-np 4 -kvu \
--backend_sampling --top-k 80 --backend_dist
-np 4 \
-kvu \
--backend_sampling \
--samplers 'top_k;temperature' \
--top-k 80
```
The `--verbose` flag can be added to see more detailed output and also show
that the backend samplers are being used. The above example will perform distribution
sampling on the backend device and only transfer the sampled token ids back to the host.
The samplers specified with `--samplers` must be supported by the backend and
this is why we are explicitly specifying only `top_k` and `temperature` here as
at the time of writing these are supported.
It is also possible to perform partial sampling on the backend, and then allow CPU samplers
to process those results further. This is sometimes referred to as hybrid sampling.
For an example of this we can remove `--backend_dist` from the above command:
```bash
./llama-batched \
-m models/Qwen2.5-VL-3B-Instruct-Q8_0.gguf -p "Hello my name is" \
-np 4 -kvu \
--backend_sampling --top-k 80 -v
```
This will perform the top-k filtering on the backend device, and then transfer the filtered logits
back to the host for sampling.
The `--verbose` flag can be added to see more detailed output and also show
that the backend samplers are being used.
With `--backend_sampling` enabled, the sampler chain is automatically analyzed
to determine which samplers can run on the backend. The system finds the longest
contiguous chain of backend-supported samplers from the start of the sampler
sequence. For example:
* If the chain is `top-k -> temperature -> top-p`, and both `top-k` and
`temperature` are backend-supported but `top-p` is not, then `top-k` and
`temperature` will run on the backend, while `top-p` and subsequent samplers
run on the CPU.
* If all configured samplers are supported, the final distribution sampling will
also happen on the backend, transferring only the sampled token IDs back to the
host.
* If the sampler chain starts with an unsupported sampler (e.g., `penalties`),
all sampling runs on the CPU.
**Note:** The default sampler chain includes `penalties` as the first sampler,
which is not backend-supported yet. To use backend sampling, you must explicitly
configure a sampler chain that starts with backend-supported samplers using
`--samplers` like shown above.



@ -79,7 +79,6 @@ json task_params::to_json(bool only_metrics) const {
{"timings_per_token", timings_per_token},
{"post_sampling_probs", post_sampling_probs},
{"backend_sampling", sampling.backend_sampling},
{"backend_dist", sampling.backend_dist},
{"lora", lora},
};
}
@ -139,7 +138,6 @@ json task_params::to_json(bool only_metrics) const {
{"timings_per_token", timings_per_token},
{"post_sampling_probs", post_sampling_probs},
{"backend_sampling", sampling.backend_sampling},
{"backend_dist", sampling.backend_dist},
{"lora", lora},
};
}
@ -210,9 +208,7 @@ task_params server_task::params_from_json_cmpl(
params.post_sampling_probs = json_value(data, "post_sampling_probs", defaults.post_sampling_probs);
const bool request_backend_sampling = json_value(data, "backend_sampling", defaults.sampling.backend_sampling);
const bool request_backend_dist = json_value(data, "backend_dist", defaults.sampling.backend_dist);
params.sampling.backend_sampling = defaults.sampling.backend_sampling && request_backend_sampling;
params.sampling.backend_dist = params.sampling.backend_sampling && request_backend_dist;
params.speculative.n_min = json_value(data, "speculative.n_min", defaults.speculative.n_min);
params.speculative.n_max = json_value(data, "speculative.n_max", defaults.speculative.n_max);


@ -1047,9 +1047,16 @@ struct server_context {
}
llama_sampler * backend_chain = common_sampler_backend_init(model, sampling);
// The sampler types configured with --samplers might not be supported
// by backend samplers in which case we disable backend sampling and
// fallback to CPU only sampling.
if (backend_chain == nullptr) {
SLT_ERR(slot, "%s", "failed to initialize backend sampler\n");
return false;
if (slot.backend_sampler != nullptr) {
llama_set_backend_sampler(ctx, slot.id, nullptr);
slot.backend_sampler = nullptr;
}
SLT_INF(slot, "%s", "no backend samplers configured (sampler chain doesn't start with backend-supported samplers)\n");
return true;
}
if (slot.backend_sampler != nullptr) {
@ -1059,6 +1066,7 @@ struct server_context {
slot.backend_sampler = backend_chain;
llama_set_backend_sampler(ctx, slot.id, backend_chain);
SLT_INF(slot, "%s", "configured backend samplers\n");
return true;
}


@ -179,11 +179,6 @@
key: 'backend_sampling',
label: 'Backend sampling',
type: 'checkbox'
},
{
key: 'backend_dist',
label: 'Backend dist sampling',
type: 'checkbox'
}
]
},
@ -297,10 +292,6 @@
function handleConfigChange(key: string, value: string | boolean) {
localConfig[key] = value;
if (key === 'backend_sampling' && value === false) {
localConfig.backend_dist = false;
}
}
function handleReset() {


@ -211,8 +211,7 @@
{/if}
{:else if field.type === 'checkbox'}
{@const pdfDisabled = field.key === 'pdfAsImage' && !supportsVision()}
{@const backendDistDisabled = field.key === 'backend_dist' && !localConfig.backend_sampling}
{@const isDisabled = pdfDisabled || backendDistDisabled}
{@const isDisabled = pdfDisabled}
<div class="flex items-start space-x-3">
<Checkbox
@ -246,10 +245,6 @@
PDF-to-image processing requires a vision-capable model. PDFs will be processed as
text.
</p>
{:else if backendDistDisabled}
<p class="text-xs text-muted-foreground">
Enable GPU sampling to allow GPU dist sampling.
</p>
{/if}
</div>
</div>


@ -20,7 +20,6 @@ export const SETTING_CONFIG_DEFAULT: Record<string, string | number | boolean> =
// make sure these default values are in sync with `common.h`
samplers: 'top_k;typ_p;top_p;min_p;temperature',
backend_sampling: false,
backend_dist: false,
temperature: 0.8,
dynatemp_range: 0.0,
dynatemp_exponent: 1.0,
@ -56,8 +55,6 @@ export const SETTING_CONFIG_INFO: Record<string, string> = {
'The order at which samplers are applied, in simplified way. Default is "top_k;typ_p;top_p;min_p;temperature": top_k->typ_p->top_p->min_p->temperature',
backend_sampling:
'Enable backend-based samplers. When enabled, supported samplers run on the accelerator backend for faster sampling.',
backend_dist:
'Perform the final distribution sampling step on the backend. Requires backend sampling to be enabled.',
temperature:
'Controls the randomness of the generated text by affecting the probability distribution of the output tokens. Higher = more random, lower = more focused.',
dynatemp_range:


@ -99,7 +99,6 @@ export class ChatService {
// Other parameters
samplers,
backend_sampling,
backend_dist,
custom,
timings_per_token
} = options;
@ -185,7 +184,6 @@ export class ChatService {
}
if (backend_sampling !== undefined) requestBody.backend_sampling = backend_sampling;
if (backend_dist !== undefined) requestBody.backend_dist = backend_dist;
if (timings_per_token !== undefined) requestBody.timings_per_token = timings_per_token;


@ -301,9 +301,6 @@ class ChatStore {
if (currentConfig.backend_sampling !== undefined) {
apiOptions.backend_sampling = Boolean(currentConfig.backend_sampling);
}
if (currentConfig.backend_dist !== undefined) {
apiOptions.backend_dist = Boolean(currentConfig.backend_dist);
}
if (currentConfig.custom) {
apiOptions.custom = currentConfig.custom;
}


@ -182,7 +182,6 @@ export interface ApiChatCompletionRequest {
// Sampler configuration
samplers?: string[];
backend_sampling?: boolean;
backend_dist?: boolean;
// Custom parameters (JSON string)
custom?: Record<string, unknown>;
timings_per_token?: boolean;


@ -38,7 +38,6 @@ export interface SettingsChatServiceOptions {
// Sampler configuration
samplers?: string | string[];
backend_sampling?: boolean;
backend_dist?: boolean;
// Custom parameters
custom?: string;
timings_per_token?: boolean;