batch : fix sequence id ownage
This commit is contained in:
parent
38882247d3
commit
44d5c4b592
|
|
@ -714,6 +714,8 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector<int32_t> & idxs, u
|
|||
udata->seq_idx .resize(LLAMA_MAX_SEQ, -1);
|
||||
udata->output .resize(n_tokens);
|
||||
|
||||
udata->seq_id_data.resize(n_tokens);
|
||||
|
||||
seq_set_t seq_set_unq;
|
||||
|
||||
for (size_t i = 0; i < idxs.size(); ++i) {
|
||||
|
|
@ -735,7 +737,11 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector<int32_t> & idxs, u
|
|||
}
|
||||
|
||||
udata->n_seq_id[i] = batch.n_seq_id[idxs[i]];
|
||||
udata->seq_id[i] = batch.seq_id[idxs[i]];
|
||||
udata->seq_id_data[i].reserve(udata->n_seq_id[i]);
|
||||
for (int s = 0; s < udata->n_seq_id[i]; ++s) {
|
||||
udata->seq_id_data[i].push_back(batch.seq_id[idxs[i]][s]);
|
||||
}
|
||||
udata->seq_id[i] = udata->seq_id_data[i].data();
|
||||
udata->output[i] = batch.logits[idxs[i]];
|
||||
|
||||
for (int s = 0; s < udata->n_seq_id[i]; ++s) {
|
||||
|
|
|
|||
|
|
@ -60,6 +60,8 @@ struct llama_ubatch {
|
|||
std::vector<llama_seq_id> seq_id_unq;
|
||||
std::vector<int32_t> seq_idx;
|
||||
std::vector<int8_t> output;
|
||||
|
||||
std::vector<std::vector<llama_seq_id>> seq_id_data;
|
||||
};
|
||||
|
||||
// the llama_ubatch pointers above point to this data if set. otherwise - points to non-owning data
|
||||
|
|
|
|||
Loading…
Reference in New Issue