sampling : fix outputs and device checks

This commit is contained in:
Georgi Gerganov 2025-12-04 19:33:01 +02:00
parent abc19635a3
commit 7864074fdb
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
2 changed files with 6 additions and 8 deletions

View File

@ -2098,21 +2098,25 @@ void llm_graph_context::build_sampling() const {
sampler->iface->backend_apply(sampler, ctx0, gf, &data);
if (data.sampled != nullptr) {
ggml_set_output(data.sampled);
res->t_sampled[seq_id] = data.sampled;
ggml_build_forward_expand(gf, data.sampled);
}
if (data.probs != nullptr) {
ggml_set_output(data.probs);
res->t_sampled_probs[seq_id] = data.probs;
ggml_build_forward_expand(gf, data.probs);
}
if (data.logits != logits_seq) {
ggml_set_output(data.logits);
res->t_sampled_logits[seq_id] = data.logits;
ggml_build_forward_expand(gf, res->t_sampled_logits[seq_id]);
}
if (data.candidates != nullptr) {
ggml_set_output(data.candidates);
res->t_candidates[seq_id] = data.candidates;
ggml_build_forward_expand(gf, data.candidates);
}

View File

@ -1018,9 +1018,8 @@ static bool llama_sampler_dist_backend_init(
ggml_tensor * op = ggml_cumsum(ctx, probs);
auto * device = ggml_backend_buft_get_device(buft);
GGML_ASSERT(device);
if (!ggml_backend_dev_supports_op(device, op)) {
if (device && !ggml_backend_dev_supports_op(device, op)) {
res = false;
}
@ -1099,7 +1098,6 @@ static void llama_sampler_dist_backend_apply(
ggml_set_name(sampled_token, "dist_sampled_token");
}
ggml_set_output(sampled_token);
data->sampled = sampled_token;
}
@ -1192,9 +1190,8 @@ static bool llama_sampler_top_k_backend_init(
ggml_tensor * op = ggml_top_k(ctx, logits, sctx->k);
auto * device = ggml_backend_buft_get_device(buft);
GGML_ASSERT(device);
if (!ggml_backend_dev_supports_op(device, op)) {
if (device && !ggml_backend_dev_supports_op(device, op)) {
res = false;
}
@ -1411,9 +1408,6 @@ static void llama_sampler_top_p_backend_apply(
data->logits = ggml_add(ctx, sorted_logits, top_p_bias);
ggml_set_name(data->logits, "top_p_logits");
ggml_set_output(data->candidates);
ggml_set_output(data->logits);
GGML_UNUSED(gf);
}