common : extract replay_last_token to common.h
This commit extracts the replay_last_token function from save-load-state.cpp to common.h. The motivation for this is to allow reuse of the function but also to clarify the intent of code that replays the last token after loading the session state.
This commit is contained in:
parent
a70867c19e
commit
d9a23126bf
|
|
@ -1789,6 +1789,16 @@ float lr_opt::get_lr(float epoch) const {
|
|||
return r;
|
||||
}
|
||||
|
||||
bool common_replay_last_token(struct llama_context * ctx, llama_token last_token, int32_t pos) {
|
||||
llama_batch batch = llama_batch_get_one(&last_token, 1);
|
||||
batch.pos = &pos;
|
||||
if (llama_decode(ctx, batch)) {
|
||||
LOG_ERR("%s: failed to replay last token\n", __func__);
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool common_prompt_batch_decode(
|
||||
struct llama_context * ctx,
|
||||
const std::vector<llama_token> & tokens,
|
||||
|
|
|
|||
|
|
@ -794,6 +794,10 @@ bool common_prompt_batch_decode(
|
|||
bool save_state,
|
||||
bool is_last_batch = true);
|
||||
|
||||
// replays the last token after loading state to regenerate logits
|
||||
// used after loading session state to ensure the sampling context has valid logits
|
||||
bool common_replay_last_token(struct llama_context * ctx, llama_token last_token, int32_t pos);
|
||||
|
||||
//
|
||||
// Vocab utils
|
||||
//
|
||||
|
|
|
|||
|
|
@ -6,17 +6,6 @@
|
|||
#include <vector>
|
||||
#include <cstdio>
|
||||
|
||||
static bool replay_last_token(llama_context * ctx, llama_token last_token, int & n_past) {
|
||||
llama_batch batch = llama_batch_get_one(&last_token, 1);
|
||||
int pos = n_past;
|
||||
batch.pos = &pos;
|
||||
if (llama_decode(ctx, batch)) {
|
||||
fprintf(stderr, "%s: failed to replay last token after loading state\n", __func__);
|
||||
return false;
|
||||
}
|
||||
++n_past;
|
||||
return true;
|
||||
}
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
common_params params;
|
||||
|
|
@ -120,9 +109,10 @@ int main(int argc, char ** argv) {
|
|||
|
||||
// restore state (last tokens)
|
||||
n_past = n_token_count_out;
|
||||
if (!replay_last_token(ctx2, tokens.back(), n_past)) {
|
||||
if (!common_replay_last_token(ctx2, tokens.back(), n_past)) {
|
||||
return 1;
|
||||
}
|
||||
++n_past;
|
||||
|
||||
// second run
|
||||
for (auto i = 0; i < params.n_predict; i++) {
|
||||
|
|
@ -173,9 +163,10 @@ int main(int argc, char ** argv) {
|
|||
|
||||
// restore state (last tokens)
|
||||
n_past = n_token_count_out;
|
||||
if (!replay_last_token(ctx3, tokens.back(), n_past)) {
|
||||
if (!common_replay_last_token(ctx3, tokens.back(), n_past)) {
|
||||
return 1;
|
||||
}
|
||||
++n_past;
|
||||
|
||||
// save seq 0 and load into seq 1
|
||||
{
|
||||
|
|
|
|||
|
|
@ -391,13 +391,7 @@ int main(int argc, char ** argv) {
|
|||
// Logits are not stored as part of the session state so we need to
|
||||
// "replay" the last token to get logits for sampling.
|
||||
if (!session_tokens.empty() && n_match > 0 && n_match == session_tokens.size()) {
|
||||
llama_token last_token = session_tokens.back();
|
||||
int32_t pos = n_match;
|
||||
|
||||
llama_batch batch = llama_batch_get_one(&last_token, 1);
|
||||
batch.pos = &pos;
|
||||
if (llama_decode(ctx, batch)) {
|
||||
LOG_ERR("%s: failed to regenerate logits after loading state\n", __func__);
|
||||
if (!common_replay_last_token(ctx, session_tokens.back(), n_match)) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue