retrieval : use at most n_seq_max chunks (#18400)
This commit is contained in:
parent
daa242dfc8
commit
0c8986403b
|
|
@ -222,8 +222,8 @@ int main(int argc, char ** argv) {
|
||||||
float * emb = embeddings.data();
|
float * emb = embeddings.data();
|
||||||
|
|
||||||
// break into batches
|
// break into batches
|
||||||
int p = 0; // number of prompts processed already
|
unsigned int p = 0; // number of prompts processed already
|
||||||
int s = 0; // number of prompts in current batch
|
unsigned int s = 0; // number of prompts in current batch
|
||||||
for (int k = 0; k < n_chunks; k++) {
|
for (int k = 0; k < n_chunks; k++) {
|
||||||
// clamp to n_batch tokens
|
// clamp to n_batch tokens
|
||||||
auto & inp = chunks[k].tokens;
|
auto & inp = chunks[k].tokens;
|
||||||
|
|
@ -231,7 +231,7 @@ int main(int argc, char ** argv) {
|
||||||
const uint64_t n_toks = inp.size();
|
const uint64_t n_toks = inp.size();
|
||||||
|
|
||||||
// encode if at capacity
|
// encode if at capacity
|
||||||
if (batch.n_tokens + n_toks > n_batch) {
|
if (batch.n_tokens + n_toks > n_batch || s >= llama_n_seq_max(ctx)) {
|
||||||
float * out = emb + p * n_embd;
|
float * out = emb + p * n_embd;
|
||||||
batch_process(ctx, batch, out, s, n_embd);
|
batch_process(ctx, batch, out, s, n_embd);
|
||||||
common_batch_clear(batch);
|
common_batch_clear(batch);
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue