squash! sampling : simplify backend sampling logic decode
Fix condition to check if backend actually sampled tokens, not just that backend samplers are available.
This commit is contained in:
parent
7e98ebcc6b
commit
d74eb61aa7
|
|
@ -1361,7 +1361,9 @@ int llama_context::decode(const llama_batch & batch_inp) {
|
|||
// ggml_graph_dump_dot(gf, NULL, "llama.dot");
|
||||
//}
|
||||
|
||||
if (has_backend_samplers) {
|
||||
const bool backend_has_sampled = !res->t_sampled_tokens.empty() || !res->t_sampled_probs.empty() || !res->t_sampled_logits.empty();
|
||||
|
||||
if (has_backend_samplers && backend_has_sampled) {
|
||||
const auto seq_to_batch_idx = build_seq_to_batch_idx(ubatch);
|
||||
|
||||
// If a backend sampler has sampled a token we only want to copy the
|
||||
|
|
@ -1381,7 +1383,9 @@ int llama_context::decode(const llama_batch & batch_inp) {
|
|||
// 2) CPU samplers to associate filtered logits with their token ids.
|
||||
copy_tensor_async_token_ids(res->t_sampled_token_ids, sampled_token_ids_map, seq_to_batch_idx, sched.get());
|
||||
|
||||
} else {
|
||||
}
|
||||
|
||||
if (!backend_has_sampled) {
|
||||
auto * t_logits = res->get_logits();
|
||||
auto * t_embd = cparams.embeddings ? res->get_embd() : nullptr;
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue