diff --git a/examples/retrieval/retrieval.cpp b/examples/retrieval/retrieval.cpp index 2c2143ad10..8f92ff9057 100644 --- a/examples/retrieval/retrieval.cpp +++ b/examples/retrieval/retrieval.cpp @@ -222,8 +222,8 @@ int main(int argc, char ** argv) { float * emb = embeddings.data(); // break into batches - int p = 0; // number of prompts processed already - int s = 0; // number of prompts in current batch + unsigned int p = 0; // number of prompts processed already + unsigned int s = 0; // number of prompts in current batch for (int k = 0; k < n_chunks; k++) { // clamp to n_batch tokens auto & inp = chunks[k].tokens; @@ -231,7 +231,7 @@ int main(int argc, char ** argv) { const uint64_t n_toks = inp.size(); // encode if at capacity - if (batch.n_tokens + n_toks > n_batch) { + if (batch.n_tokens + n_toks > n_batch || s >= llama_n_seq_max(ctx)) { float * out = emb + p * n_embd; batch_process(ctx, batch, out, s, n_embd); common_batch_clear(batch);