common : add common_prompt_batch_decode function

This commit adds a new function which is responsible for decoding prompt
and optionally handle the saving for session data.
This commit is contained in:
Daniel Bevenius 2026-02-03 10:13:04 +01:00
parent 44bddc0a89
commit e1373fd89c
No known key found for this signature in database
3 changed files with 78 additions and 33 deletions

View File

@ -1863,3 +1863,56 @@ float lr_opt::get_lr(float epoch) const {
LOG_INF("epoch %.2g lr=%.2g\n", epoch, r);
return r;
}
bool common_prompt_batch_decode(
struct llama_context * ctx,
const std::vector<llama_token> & tokens,
int & n_past,
int n_batch,
const std::filesystem::path & state_path,
bool save_state,
bool is_last_batch) {
const int n_eval = tokens.size();
if (n_eval == 0) {
return true;
}
if (save_state && is_last_batch && n_eval > 1) {
const int n_tokens_before_last = n_eval - 1;
GGML_ASSERT(n_eval <= n_batch);
// Decode all but the last token so we can save the memory state before decoding the last token.
// This is done so we can restore the session state later and replay the last token.
// Memory implementations in recurrent/hybrid models don't support removing tokens from their
// memory, so we can't just remove the last token from the memory and replay the last token which
// is the reason for this logic.
if (llama_decode(ctx, llama_batch_get_one(const_cast<llama_token*>(tokens.data()), n_tokens_before_last))) {
LOG_ERR("%s : failed to eval\n", __func__);
return false;
}
n_past += n_tokens_before_last;
llama_state_save_file(ctx, state_path.string().c_str(), tokens.data(), n_tokens_before_last);
LOG_INF("saved session before last token to %s, n_tokens = %d\n", state_path.string().c_str(), n_tokens_before_last);
llama_token last_token = tokens.back();
llama_batch batch = llama_batch_get_one(&last_token, 1);
int32_t pos = n_past;
batch.pos = &pos;
if (llama_decode(ctx, batch)) {
LOG_ERR("%s : failed to eval last token\n", __func__);
return false;
}
n_past++;
} else {
if (llama_decode(ctx, llama_batch_get_one(const_cast<llama_token*>(tokens.data()), n_eval))) {
LOG_ERR("%s : failed to eval\n", __func__);
return false;
}
n_past += n_eval;
}
return true;
}

View File

@ -5,6 +5,7 @@
#include "ggml-opt.h"
#include "llama-cpp.h"
#include <filesystem>
#include <set>
#include <sstream>
#include <string>
@ -780,6 +781,20 @@ void common_batch_add(
const std::vector<llama_seq_id> & seq_ids,
bool logits);
// decodes a single batch of tokens for a prompt and manages session tokens
//
// Note: We save state before the last token so that we can replay it to ensure
// compatibility with all memory types. Recurrent/hybrid models cannot remove
// tokens from memory, so this approach works across all model architectures.
bool common_prompt_batch_decode(
struct llama_context * ctx,
const std::vector<llama_token> & embd,
int & n_past,
int n_batch,
const std::filesystem::path & state_path,
bool save_state,
bool is_last_batch = true);
//
// Token utils
//

View File

@ -392,19 +392,7 @@ int main(int argc, char ** argv) {
// "replay" the last token to get logits for sampling.
if (!session_tokens.empty() && n_match > 0 && n_match == session_tokens.size()) {
llama_token last_token = session_tokens.back();
int32_t pos;
if (llama_model_is_recurrent(model) || llama_model_is_hybrid(model)) {
LOG_INF("%s: recurrent/hybrid model: decode using next position: %d\n", __func__, (int)n_match);
pos = n_match; // use next position for decoding
} else {
LOG_INF("%s: non-recurrent model: removing and re-decode last position: %d\n", __func__, (int)n_match - 1);
if (!llama_memory_seq_rm(mem, 0, n_match - 1, n_match)) {
LOG_ERR("%s: failed to remove last position from KV cache\n", __func__);
return 1;
}
pos = n_match - 1;
}
int32_t pos = n_match;
llama_batch batch = llama_batch_get_one(&last_token, 1);
batch.pos = &pos;
@ -412,6 +400,9 @@ int main(int argc, char ** argv) {
LOG_ERR("%s: failed to regenerate logits after loading state\n", __func__);
return 1;
}
session_do_save = false;
LOG_INF("%s: replayed last token from session\n", __func__);
}
}
@ -701,40 +692,26 @@ int main(int argc, char ** argv) {
}
if (!embd.empty()) {
int n_eval = (int) embd.size();
LOG_DBG("eval: %s\n", string_from(ctx, embd).c_str());
GGML_ASSERT(n_eval <= params.n_batch);
if (llama_decode(ctx, llama_batch_get_one(embd.data(), n_eval))) {
LOG_ERR("%s : failed to eval\n", __func__);
const bool is_last_batch = (n_consumed >= (int) embd_inp.size());
if (!common_prompt_batch_decode(ctx, embd, n_past, params.n_batch, path_session, session_do_save, is_last_batch)) {
return 1;
}
n_past += n_eval;
session_tokens.insert(session_tokens.end(), embd.begin(), embd.begin());
n_session_consumed = session_tokens.size();
session_do_save = false;
LOG_DBG("n_past = %d\n", n_past);
// Display total tokens alongside total time
if (params.n_print > 0 && n_past % params.n_print == 0) {
LOG_DBG("\n\033[31mTokens consumed so far = %d / %d \033[0m\n", n_past, n_ctx);
}
}
if (!embd.empty() && !path_session.empty()) {
session_tokens.insert(session_tokens.end(), embd.begin(), embd.end());
n_session_consumed = session_tokens.size();
}
}
embd.clear();
if ((int) embd_inp.size() <= n_consumed && !is_interacting) {
// optionally save the session on first sample (for faster prompt loading next time)
if (session_do_save) {
session_do_save = false;
llama_state_save_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.size());
LOG_DBG("saved session to %s\n", path_session.c_str());
}
const llama_token id = common_sampler_sample(smpl, ctx, -1);