sampling : fix outputs and device checks
This commit is contained in:
parent
abc19635a3
commit
7864074fdb
|
|
@ -2098,21 +2098,25 @@ void llm_graph_context::build_sampling() const {
|
||||||
sampler->iface->backend_apply(sampler, ctx0, gf, &data);
|
sampler->iface->backend_apply(sampler, ctx0, gf, &data);
|
||||||
|
|
||||||
if (data.sampled != nullptr) {
|
if (data.sampled != nullptr) {
|
||||||
|
ggml_set_output(data.sampled);
|
||||||
res->t_sampled[seq_id] = data.sampled;
|
res->t_sampled[seq_id] = data.sampled;
|
||||||
ggml_build_forward_expand(gf, data.sampled);
|
ggml_build_forward_expand(gf, data.sampled);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (data.probs != nullptr) {
|
if (data.probs != nullptr) {
|
||||||
|
ggml_set_output(data.probs);
|
||||||
res->t_sampled_probs[seq_id] = data.probs;
|
res->t_sampled_probs[seq_id] = data.probs;
|
||||||
ggml_build_forward_expand(gf, data.probs);
|
ggml_build_forward_expand(gf, data.probs);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (data.logits != logits_seq) {
|
if (data.logits != logits_seq) {
|
||||||
|
ggml_set_output(data.logits);
|
||||||
res->t_sampled_logits[seq_id] = data.logits;
|
res->t_sampled_logits[seq_id] = data.logits;
|
||||||
ggml_build_forward_expand(gf, res->t_sampled_logits[seq_id]);
|
ggml_build_forward_expand(gf, res->t_sampled_logits[seq_id]);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (data.candidates != nullptr) {
|
if (data.candidates != nullptr) {
|
||||||
|
ggml_set_output(data.candidates);
|
||||||
res->t_candidates[seq_id] = data.candidates;
|
res->t_candidates[seq_id] = data.candidates;
|
||||||
ggml_build_forward_expand(gf, data.candidates);
|
ggml_build_forward_expand(gf, data.candidates);
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1018,9 +1018,8 @@ static bool llama_sampler_dist_backend_init(
|
||||||
ggml_tensor * op = ggml_cumsum(ctx, probs);
|
ggml_tensor * op = ggml_cumsum(ctx, probs);
|
||||||
|
|
||||||
auto * device = ggml_backend_buft_get_device(buft);
|
auto * device = ggml_backend_buft_get_device(buft);
|
||||||
GGML_ASSERT(device);
|
|
||||||
|
|
||||||
if (!ggml_backend_dev_supports_op(device, op)) {
|
if (device && !ggml_backend_dev_supports_op(device, op)) {
|
||||||
res = false;
|
res = false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -1099,7 +1098,6 @@ static void llama_sampler_dist_backend_apply(
|
||||||
ggml_set_name(sampled_token, "dist_sampled_token");
|
ggml_set_name(sampled_token, "dist_sampled_token");
|
||||||
}
|
}
|
||||||
|
|
||||||
ggml_set_output(sampled_token);
|
|
||||||
data->sampled = sampled_token;
|
data->sampled = sampled_token;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -1192,9 +1190,8 @@ static bool llama_sampler_top_k_backend_init(
|
||||||
ggml_tensor * op = ggml_top_k(ctx, logits, sctx->k);
|
ggml_tensor * op = ggml_top_k(ctx, logits, sctx->k);
|
||||||
|
|
||||||
auto * device = ggml_backend_buft_get_device(buft);
|
auto * device = ggml_backend_buft_get_device(buft);
|
||||||
GGML_ASSERT(device);
|
|
||||||
|
|
||||||
if (!ggml_backend_dev_supports_op(device, op)) {
|
if (device && !ggml_backend_dev_supports_op(device, op)) {
|
||||||
res = false;
|
res = false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -1411,9 +1408,6 @@ static void llama_sampler_top_p_backend_apply(
|
||||||
data->logits = ggml_add(ctx, sorted_logits, top_p_bias);
|
data->logits = ggml_add(ctx, sorted_logits, top_p_bias);
|
||||||
ggml_set_name(data->logits, "top_p_logits");
|
ggml_set_name(data->logits, "top_p_logits");
|
||||||
|
|
||||||
ggml_set_output(data->candidates);
|
|
||||||
ggml_set_output(data->logits);
|
|
||||||
|
|
||||||
GGML_UNUSED(gf);
|
GGML_UNUSED(gf);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue