sampling : fix outputs and device checks
This commit is contained in:
parent
abc19635a3
commit
7864074fdb
|
|
@ -2098,21 +2098,25 @@ void llm_graph_context::build_sampling() const {
|
|||
sampler->iface->backend_apply(sampler, ctx0, gf, &data);
|
||||
|
||||
if (data.sampled != nullptr) {
|
||||
ggml_set_output(data.sampled);
|
||||
res->t_sampled[seq_id] = data.sampled;
|
||||
ggml_build_forward_expand(gf, data.sampled);
|
||||
}
|
||||
|
||||
if (data.probs != nullptr) {
|
||||
ggml_set_output(data.probs);
|
||||
res->t_sampled_probs[seq_id] = data.probs;
|
||||
ggml_build_forward_expand(gf, data.probs);
|
||||
}
|
||||
|
||||
if (data.logits != logits_seq) {
|
||||
ggml_set_output(data.logits);
|
||||
res->t_sampled_logits[seq_id] = data.logits;
|
||||
ggml_build_forward_expand(gf, res->t_sampled_logits[seq_id]);
|
||||
}
|
||||
|
||||
if (data.candidates != nullptr) {
|
||||
ggml_set_output(data.candidates);
|
||||
res->t_candidates[seq_id] = data.candidates;
|
||||
ggml_build_forward_expand(gf, data.candidates);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1018,9 +1018,8 @@ static bool llama_sampler_dist_backend_init(
|
|||
ggml_tensor * op = ggml_cumsum(ctx, probs);
|
||||
|
||||
auto * device = ggml_backend_buft_get_device(buft);
|
||||
GGML_ASSERT(device);
|
||||
|
||||
if (!ggml_backend_dev_supports_op(device, op)) {
|
||||
if (device && !ggml_backend_dev_supports_op(device, op)) {
|
||||
res = false;
|
||||
}
|
||||
|
||||
|
|
@ -1099,7 +1098,6 @@ static void llama_sampler_dist_backend_apply(
|
|||
ggml_set_name(sampled_token, "dist_sampled_token");
|
||||
}
|
||||
|
||||
ggml_set_output(sampled_token);
|
||||
data->sampled = sampled_token;
|
||||
}
|
||||
|
||||
|
|
@ -1192,9 +1190,8 @@ static bool llama_sampler_top_k_backend_init(
|
|||
ggml_tensor * op = ggml_top_k(ctx, logits, sctx->k);
|
||||
|
||||
auto * device = ggml_backend_buft_get_device(buft);
|
||||
GGML_ASSERT(device);
|
||||
|
||||
if (!ggml_backend_dev_supports_op(device, op)) {
|
||||
if (device && !ggml_backend_dev_supports_op(device, op)) {
|
||||
res = false;
|
||||
}
|
||||
|
||||
|
|
@ -1411,9 +1408,6 @@ static void llama_sampler_top_p_backend_apply(
|
|||
data->logits = ggml_add(ctx, sorted_logits, top_p_bias);
|
||||
ggml_set_name(data->logits, "top_p_logits");
|
||||
|
||||
ggml_set_output(data->candidates);
|
||||
ggml_set_output(data->logits);
|
||||
|
||||
GGML_UNUSED(gf);
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue