sampling : add stride variable for clarity

This commit is contained in:
Daniel Bevenius 2025-11-23 11:27:54 +01:00
parent 79b8cf2a75
commit 65500d05ab
No known key found for this signature in database
1 changed file with 4 additions and 3 deletions

View File

@ -1468,6 +1468,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
if (has_backend_samplers && backend_has_sampled) {
const auto seq_to_output_row = build_seq_to_output_row(ubatch, n_outputs_prev);
const auto stride = n_vocab;
// If a backend sampler has sampled a token we only want to copy the
// sampled tokens and avoid copying logits and probabilities.
@ -1476,15 +1477,15 @@ int llama_context::decode(const llama_batch & batch_inp) {
copy_tensor_async_ints(res->t_sampled, sampling.sampled, sampling.sampled_size, seq_to_output_row, sched.get());
} else {
// async copy the sampled logits/probs from the backend to the host.
copy_tensor_async_floats(res->t_sampled_logits, sampling.logits, n_vocab, sampling.logits_count, seq_to_output_row, sched.get());
copy_tensor_async_floats(res->t_sampled_probs, sampling.probs, n_vocab, sampling.probs_count, seq_to_output_row, sched.get());
copy_tensor_async_floats(res->t_sampled_logits, sampling.logits, stride, sampling.logits_count, seq_to_output_row, sched.get());
copy_tensor_async_floats(res->t_sampled_probs, sampling.probs, stride, sampling.probs_count, seq_to_output_row, sched.get());
}
// async copy the candidate token ids from the backend to the host.
// These are needed for:
// 1) Backend dist sampler to map indices to vocab token ids.
// 2) CPU samplers to associate candidate logits with their token ids.
copy_tensor_async_candidates(res->t_candidates, sampling.candidates, n_vocab, sampling.candidates_count, seq_to_output_row, sched.get());
copy_tensor_async_candidates(res->t_candidates, sampling.candidates, stride, sampling.candidates_count, seq_to_output_row, sched.get());
}