completion : simplify batch (embd) processing (#19286)
* completion : simplify batch (embd) processing This commit simplifies the processing of embd by removing the for loop that currently exists which uses params.n_batch as its increment. This commit also removes the clamping of n_eval as the size of embd is always at most the size of params.n_batch. The motivation is to clarify the code as it is currently a little confusing when looking at this for loop in isolation and thinking that it can process multiple batches. * add an assert to verify n_eval is not greater than n_batch
This commit is contained in:
parent
015deb9048
commit
25f40ca65f
|
|
@ -674,15 +674,12 @@ int main(int argc, char ** argv) {
|
|||
}
|
||||
}
|
||||
|
||||
for (int i = 0; i < (int) embd.size(); i += params.n_batch) {
|
||||
int n_eval = (int) embd.size() - i;
|
||||
if (n_eval > params.n_batch) {
|
||||
n_eval = params.n_batch;
|
||||
}
|
||||
|
||||
if (!embd.empty()) {
|
||||
int n_eval = (int) embd.size();
|
||||
LOG_DBG("eval: %s\n", string_from(ctx, embd).c_str());
|
||||
|
||||
if (llama_decode(ctx, llama_batch_get_one(&embd[i], n_eval))) {
|
||||
GGML_ASSERT(n_eval <= params.n_batch);
|
||||
if (llama_decode(ctx, llama_batch_get_one(embd.data(), n_eval))) {
|
||||
LOG_ERR("%s : failed to eval\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
|
|
@ -743,7 +740,7 @@ int main(int argc, char ** argv) {
|
|||
common_sampler_accept(smpl, embd_inp[n_consumed], /* accept_grammar= */ false);
|
||||
|
||||
++n_consumed;
|
||||
if ((int) embd.size() >= params.n_batch) {
|
||||
if ((int) embd.size() == params.n_batch) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue