squash! sampling : simplify backend sampling logic decode
Fix the condition to check whether the backend actually sampled tokens, not just whether backend samplers are available.
This commit is contained in:
parent
7e98ebcc6b
commit
d74eb61aa7
|
|
@ -1361,7 +1361,9 @@ int llama_context::decode(const llama_batch & batch_inp) {
|
||||||
// ggml_graph_dump_dot(gf, NULL, "llama.dot");
|
// ggml_graph_dump_dot(gf, NULL, "llama.dot");
|
||||||
//}
|
//}
|
||||||
|
|
||||||
if (has_backend_samplers) {
|
const bool backend_has_sampled = !res->t_sampled_tokens.empty() || !res->t_sampled_probs.empty() || !res->t_sampled_logits.empty();
|
||||||
|
|
||||||
|
if (has_backend_samplers && backend_has_sampled) {
|
||||||
const auto seq_to_batch_idx = build_seq_to_batch_idx(ubatch);
|
const auto seq_to_batch_idx = build_seq_to_batch_idx(ubatch);
|
||||||
|
|
||||||
// If a backend sampler has sampled a token we only want to copy the
|
// If a backend sampler has sampled a token we only want to copy the
|
||||||
|
|
@ -1381,7 +1383,9 @@ int llama_context::decode(const llama_batch & batch_inp) {
|
||||||
// 2) CPU samplers to associate filtered logits with their token ids.
|
// 2) CPU samplers to associate filtered logits with their token ids.
|
||||||
copy_tensor_async_token_ids(res->t_sampled_token_ids, sampled_token_ids_map, seq_to_batch_idx, sched.get());
|
copy_tensor_async_token_ids(res->t_sampled_token_ids, sampled_token_ids_map, seq_to_batch_idx, sched.get());
|
||||||
|
|
||||||
} else {
|
}
|
||||||
|
|
||||||
|
if (!backend_has_sampled) {
|
||||||
auto * t_logits = res->get_logits();
|
auto * t_logits = res->get_logits();
|
||||||
auto * t_embd = cparams.embeddings ? res->get_embd() : nullptr;
|
auto * t_embd = cparams.embeddings ? res->get_embd() : nullptr;
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue