squash! sampling : support intermixed backend/cpu samplers
Add a check that logits is not null, which can happen for embeddings.
This commit is contained in:
parent
74be332e24
commit
9ad6522be6
|
|
@ -1662,16 +1662,18 @@ uint32_t llama_context::output_reserve(int32_t n_outputs, const llama_batch & ba
|
|||
bool batch_has_backend_sampling = false;
|
||||
bool batch_needs_cpu_logits = false;
|
||||
|
||||
for (int32_t i = 0; i < batch.n_tokens; i++) {
|
||||
if (!batch.logits[i]) {
|
||||
continue;
|
||||
}
|
||||
for (int32_t j = 0; j < batch.n_seq_id[i]; j++) {
|
||||
llama_seq_id seq_id = batch.seq_id[i][j];
|
||||
if (sampling.samplers.find(seq_id) != sampling.samplers.end()) {
|
||||
batch_has_backend_sampling = true;
|
||||
} else {
|
||||
batch_needs_cpu_logits = true;
|
||||
if (batch.logits) {
|
||||
for (int32_t i = 0; i < batch.n_tokens; i++) {
|
||||
if (!batch.logits[i]) {
|
||||
continue;
|
||||
}
|
||||
for (int32_t j = 0; j < batch.n_seq_id[i]; j++) {
|
||||
llama_seq_id seq_id = batch.seq_id[i][j];
|
||||
if (sampling.samplers.find(seq_id) != sampling.samplers.end()) {
|
||||
batch_has_backend_sampling = true;
|
||||
} else {
|
||||
batch_needs_cpu_logits = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue