diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index a621c4ebf5..164195d802 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -2098,21 +2098,25 @@ void llm_graph_context::build_sampling() const { sampler->iface->backend_apply(sampler, ctx0, gf, &data); if (data.sampled != nullptr) { + ggml_set_output(data.sampled); res->t_sampled[seq_id] = data.sampled; ggml_build_forward_expand(gf, data.sampled); } if (data.probs != nullptr) { + ggml_set_output(data.probs); res->t_sampled_probs[seq_id] = data.probs; ggml_build_forward_expand(gf, data.probs); } if (data.logits != logits_seq) { + ggml_set_output(data.logits); res->t_sampled_logits[seq_id] = data.logits; ggml_build_forward_expand(gf, res->t_sampled_logits[seq_id]); } if (data.candidates != nullptr) { + ggml_set_output(data.candidates); res->t_candidates[seq_id] = data.candidates; ggml_build_forward_expand(gf, data.candidates); } diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index e910b6e14e..004284c6be 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -1018,9 +1018,8 @@ static bool llama_sampler_dist_backend_init( ggml_tensor * op = ggml_cumsum(ctx, probs); auto * device = ggml_backend_buft_get_device(buft); - GGML_ASSERT(device); - if (!ggml_backend_dev_supports_op(device, op)) { + if (device && !ggml_backend_dev_supports_op(device, op)) { res = false; } @@ -1099,7 +1098,6 @@ static void llama_sampler_dist_backend_apply( ggml_set_name(sampled_token, "dist_sampled_token"); } - ggml_set_output(sampled_token); data->sampled = sampled_token; } @@ -1192,9 +1190,8 @@ static bool llama_sampler_top_k_backend_init( ggml_tensor * op = ggml_top_k(ctx, logits, sctx->k); auto * device = ggml_backend_buft_get_device(buft); - GGML_ASSERT(device); - if (!ggml_backend_dev_supports_op(device, op)) { + if (device && !ggml_backend_dev_supports_op(device, op)) { res = false; } @@ -1411,9 +1408,6 @@ static void llama_sampler_top_p_backend_apply( data->logits = ggml_add(ctx, sorted_logits, top_p_bias); ggml_set_name(data->logits, "top_p_logits"); - ggml_set_output(data->candidates); - ggml_set_output(data->logits); - GGML_UNUSED(gf); }