remove is_last_batch parameter from common_prompt_batch_decode
This commit is contained in:
parent
7902ae7380
commit
d9a6e49844
|
|
@ -1805,14 +1805,13 @@ bool common_prompt_batch_decode(
|
||||||
int & n_past,
|
int & n_past,
|
||||||
int n_batch,
|
int n_batch,
|
||||||
const std::string_view & state_path,
|
const std::string_view & state_path,
|
||||||
bool save_state,
|
bool save_state) {
|
||||||
bool is_last_batch) {
|
|
||||||
const int n_eval = tokens.size();
|
const int n_eval = tokens.size();
|
||||||
if (n_eval == 0) {
|
if (n_eval == 0) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (save_state && is_last_batch && n_eval > 1) {
|
if (save_state && n_eval > 1) {
|
||||||
const int n_tokens_before_last = n_eval - 1;
|
const int n_tokens_before_last = n_eval - 1;
|
||||||
|
|
||||||
GGML_ASSERT(n_eval <= n_batch);
|
GGML_ASSERT(n_eval <= n_batch);
|
||||||
|
|
|
||||||
|
|
@ -790,8 +790,7 @@ bool common_prompt_batch_decode(
|
||||||
int & n_past,
|
int & n_past,
|
||||||
int n_batch,
|
int n_batch,
|
||||||
const std::string_view & state_path,
|
const std::string_view & state_path,
|
||||||
bool save_state,
|
bool save_state);
|
||||||
bool is_last_batch = true);
|
|
||||||
|
|
||||||
// replays the last token after loading state to regenerate logits
|
// replays the last token after loading state to regenerate logits
|
||||||
// used after loading session state to ensure the sampling context has valid logits
|
// used after loading session state to ensure the sampling context has valid logits
|
||||||
|
|
|
||||||
|
|
@ -687,7 +687,8 @@ int main(int argc, char ** argv) {
|
||||||
|
|
||||||
if (!embd.empty()) {
|
if (!embd.empty()) {
|
||||||
const bool is_last_batch = (n_consumed >= (int) embd_inp.size());
|
const bool is_last_batch = (n_consumed >= (int) embd_inp.size());
|
||||||
if (!common_prompt_batch_decode(ctx, embd, n_past, params.n_batch, path_session, session_do_save, is_last_batch)) {
|
const bool save_now = session_do_save && is_last_batch;
|
||||||
|
if (!common_prompt_batch_decode(ctx, embd, n_past, params.n_batch, path_session, save_now)) {
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
session_tokens.insert(session_tokens.end(), embd.begin(), embd.begin());
|
session_tokens.insert(session_tokens.end(), embd.begin(), embd.begin());
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue