diff --git a/src/llama-batch.cpp b/src/llama-batch.cpp index 2700b970c9..cd9bc14fd4 100644 --- a/src/llama-batch.cpp +++ b/src/llama-batch.cpp @@ -714,6 +714,8 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector & idxs, u udata->seq_idx .resize(LLAMA_MAX_SEQ, -1); udata->output .resize(n_tokens); + udata->seq_id_data.resize(n_tokens); + seq_set_t seq_set_unq; for (size_t i = 0; i < idxs.size(); ++i) { @@ -735,7 +737,11 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector & idxs, u } udata->n_seq_id[i] = batch.n_seq_id[idxs[i]]; - udata->seq_id[i] = batch.seq_id[idxs[i]]; + udata->seq_id_data[i].reserve(udata->n_seq_id[i]); + for (int s = 0; s < udata->n_seq_id[i]; ++s) { + udata->seq_id_data[i].push_back(batch.seq_id[idxs[i]][s]); + } + udata->seq_id[i] = udata->seq_id_data[i].data(); udata->output[i] = batch.logits[idxs[i]]; for (int s = 0; s < udata->n_seq_id[i]; ++s) { diff --git a/src/llama-batch.h b/src/llama-batch.h index db7a75b804..97213e2398 100644 --- a/src/llama-batch.h +++ b/src/llama-batch.h @@ -60,6 +60,8 @@ struct llama_ubatch { std::vector seq_id_unq; std::vector seq_idx; std::vector output; + + std::vector> seq_id_data; }; // the llama_ubatch pointers above point to this data if set. otherwise - points to non-owning data