This commit is contained in:
Daniel Bevenius 2026-02-07 07:00:42 +09:00 committed by GitHub
commit a078f09648
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 136 additions and 200 deletions

View File

@ -1863,3 +1863,56 @@ float lr_opt::get_lr(float epoch) const {
LOG_INF("epoch %.2g lr=%.2g\n", epoch, r);
return r;
}
bool common_prompt_batch_decode(
struct llama_context * ctx,
const std::vector<llama_token> & tokens,
int & n_past,
int n_batch,
const std::filesystem::path & state_path,
bool save_state,
bool is_last_batch) {
const int n_eval = tokens.size();
if (n_eval == 0) {
return true;
}
if (save_state && is_last_batch && n_eval > 1) {
const int n_tokens_before_last = n_eval - 1;
GGML_ASSERT(n_eval <= n_batch);
// Decode all but the last token so we can save the memory state before decoding the last token.
// This is done so we can restore the session state later and replay the last token.
// Memory implementations in recurrent/hybrid models don't support removing tokens from their
// memory, so we can't just remove the last token from the memory and replay the last token which
// is the reason for this logic.
if (llama_decode(ctx, llama_batch_get_one(const_cast<llama_token*>(tokens.data()), n_tokens_before_last))) {
LOG_ERR("%s : failed to eval\n", __func__);
return false;
}
n_past += n_tokens_before_last;
llama_state_save_file(ctx, state_path.string().c_str(), tokens.data(), n_tokens_before_last);
LOG_INF("saved session before last token to %s, n_tokens = %d\n", state_path.string().c_str(), n_tokens_before_last);
llama_token last_token = tokens.back();
llama_batch batch = llama_batch_get_one(&last_token, 1);
int32_t pos = n_past;
batch.pos = &pos;
if (llama_decode(ctx, batch)) {
LOG_ERR("%s : failed to eval last token\n", __func__);
return false;
}
n_past++;
} else {
if (llama_decode(ctx, llama_batch_get_one(const_cast<llama_token*>(tokens.data()), n_eval))) {
LOG_ERR("%s : failed to eval\n", __func__);
return false;
}
n_past += n_eval;
}
return true;
}

View File

@ -5,6 +5,7 @@
#include "ggml-opt.h"
#include "llama-cpp.h"
#include <filesystem>
#include <set>
#include <sstream>
#include <string>
@ -780,6 +781,20 @@ void common_batch_add(
const std::vector<llama_seq_id> & seq_ids,
bool logits);
// decodes a single batch of tokens for a prompt and manages session tokens
//
// Note: We save state before the last token so that we can replay it to ensure
// compatibility with all memory types. Recurrent/hybrid models cannot remove
// tokens from memory, so this approach works across all model architectures.
bool common_prompt_batch_decode(
struct llama_context * ctx,
const std::vector<llama_token> & embd,
int & n_past,
int n_batch,
const std::filesystem::path & state_path,
bool save_state,
bool is_last_batch = true);
//
// Token utils
//

View File

@ -2,15 +2,30 @@
#include "common.h"
#include "llama.h"
#include <filesystem>
#include <vector>
#include <cstdio>
static bool replay_last_token(llama_context * ctx, llama_token last_token, int & n_past) {
llama_batch batch = llama_batch_get_one(&last_token, 1);
int pos = n_past;
batch.pos = &pos;
if (llama_decode(ctx, batch)) {
fprintf(stderr, "%s: failed to replay last token after loading state\n", __func__);
return false;
}
++n_past;
return true;
}
int main(int argc, char ** argv) {
common_params params;
params.prompt = "The quick brown fox";
params.sampling.seed = 1234;
std::filesystem::path state_file = "dump_state.bin";
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_COMMON)) {
return 1;
}
@ -53,35 +68,16 @@ int main(int argc, char ** argv) {
// tokenize prompt
auto tokens = common_tokenize(ctx, params.prompt, true);
// prepare the batch
llama_batch batch = llama_batch_init(tokens.size(), 0, 1);
for (size_t i = 0; i < tokens.size(); i++) {
common_batch_add(batch, tokens[i], i, {0}, false);
const bool save_state = true;
if (!common_prompt_batch_decode(ctx, tokens, n_past, params.n_batch, state_file, save_state)) {
return 1;
}
batch.logits[batch.n_tokens - 1] = true; // generate next token
// evaluate prompt
llama_decode(ctx, batch);
n_past += batch.n_tokens;
// save state (rng, logits, embedding and kv_cache) to file
{
std::vector<uint8_t> state_mem(llama_state_get_size(ctx));
const size_t written = llama_state_get_data(ctx, state_mem.data(), state_mem.size());
FILE *fp_write = fopen("dump_state.bin", "wb");
fwrite(state_mem.data(), 1, written, fp_write);
fclose(fp_write);
fprintf(stderr, "%s : serialized state into %zd out of a maximum of %zd bytes\n", __func__, written, state_mem.size());
}
// save state (last tokens)
const auto n_past_saved = n_past;
// first run
printf("\nfirst run: %s", params.prompt.c_str());
llama_batch batch = llama_batch_init(1, 0, 1);
for (auto i = 0; i < params.n_predict; i++) {
auto next_token = llama_sampler_sample(smpl, ctx, -1);
auto next_token_str = common_token_to_piece(ctx, next_token);
@ -111,27 +107,22 @@ int main(int argc, char ** argv) {
printf("\nsecond run: %s", params.prompt.c_str());
// load state (rng, logits, embedding and kv_cache) from file
{
std::vector<uint8_t> state_mem;
// load state from file
std::vector<llama_token> unused_sts(tokens.size()); // unused session tokens.
size_t n_token_count_out = 0;
FILE * fp_read = fopen("dump_state.bin", "rb");
fseek(fp_read, 0, SEEK_END);
state_mem.resize(ftell(fp_read));
fseek(fp_read, 0, SEEK_SET);
const size_t read = fread(state_mem.data(), 1, state_mem.size(), fp_read);
fclose(fp_read);
if (read != llama_state_set_data(ctx2, state_mem.data(), state_mem.size())) {
fprintf(stderr, "\n%s : failed to read state\n", __func__);
return 1;
}
fprintf(stderr, "%s : deserialized state from %zd out of a maximum of %zd bytes\n", __func__, read, state_mem.size());
if (!llama_state_load_file(ctx2, state_file.string().c_str(), unused_sts.data(), unused_sts.size(), &n_token_count_out)) {
fprintf(stderr, "\n%s : failed to load state\n", __func__);
return 1;
}
fprintf(stderr, "%s : loaded state with %zu tokens\n", __func__, n_token_count_out);
// restore state (last tokens)
n_past = n_past_saved;
n_past = n_token_count_out;
if (!replay_last_token(ctx2, tokens.back(), n_past)) {
return 1;
}
// second run
for (auto i = 0; i < params.n_predict; i++) {
@ -160,7 +151,9 @@ int main(int argc, char ** argv) {
}
// make new context
llama_context * ctx3 = llama_init_from_model(model, common_context_params_to_llama(params));
auto params_ctx3 = common_context_params_to_llama(params);
params_ctx3.n_seq_max = 2;
llama_context * ctx3 = llama_init_from_model(model, params_ctx3);
llama_sampler * smpl3 = llama_sampler_chain_init(sparams);
@ -169,26 +162,20 @@ int main(int argc, char ** argv) {
printf("\nsingle seq run: %s", params.prompt.c_str());
// load state (rng, logits, embedding and kv_cache) from file
{
std::vector<uint8_t> state_mem;
n_token_count_out = 0;
FILE * fp_read = fopen("dump_state.bin", "rb");
fseek(fp_read, 0, SEEK_END);
state_mem.resize(ftell(fp_read));
fseek(fp_read, 0, SEEK_SET);
const size_t read = fread(state_mem.data(), 1, state_mem.size(), fp_read);
fclose(fp_read);
if (read != llama_state_set_data(ctx3, state_mem.data(), state_mem.size())) {
fprintf(stderr, "\n%s : failed to read state\n", __func__);
return 1;
}
fprintf(stderr, "%s : deserialized state from %zd out of a maximum of %zd bytes\n", __func__, read, state_mem.size());
if (!llama_state_load_file(ctx3, state_file.string().c_str(), unused_sts.data(), unused_sts.size(), &n_token_count_out)) {
fprintf(stderr, "\n%s : failed to load state\n", __func__);
return 1;
}
fprintf(stderr, "%s : loaded state with %zu tokens\n", __func__, n_token_count_out);
// restore state (last tokens)
n_past = n_past_saved;
n_past = n_token_count_out;
if (!replay_last_token(ctx3, tokens.back(), n_past)) {
return 1;
}
// save seq 0 and load into seq 1
{

View File

@ -2500,64 +2500,6 @@ size_t llama_context::state_write_data(llama_io_write_i & io) {
// TODO: add more model-specific info which should prevent loading the session file if not identical
}
// write output ids
{
LLAMA_LOG_DEBUG("%s: - writing output ids\n", __func__);
const auto n_outputs = this->n_outputs;
const auto & output_ids = this->output_ids;
std::vector<int32_t> w_output_pos;
w_output_pos.resize(n_outputs);
// build a more compact representation of the output ids
for (size_t i = 0; i < n_batch(); ++i) {
// map an output id to a position in the batch
int64_t pos = output_ids[i];
if (pos >= 0) {
GGML_ASSERT(pos < n_outputs);
w_output_pos[pos] = i;
}
}
io.write(&n_outputs, sizeof(n_outputs));
if (n_outputs) {
io.write(w_output_pos.data(), n_outputs * sizeof(int32_t));
}
}
// [TAG_CONTEXT_STATE_LOGITS]
// write logits
{
LLAMA_LOG_DEBUG("%s: - writing logits\n", __func__);
const uint64_t logits_size = std::min((uint64_t) this->logits_size, (uint64_t) n_outputs * model.vocab.n_tokens());
io.write(&logits_size, sizeof(logits_size));
if (logits_size) {
io.write(logits, logits_size * sizeof(float));
}
}
// write embeddings
{
LLAMA_LOG_DEBUG("%s: - writing embeddings\n", __func__);
const uint64_t embd_size = std::min((uint64_t) this->embd_size, (uint64_t) n_outputs * model.hparams.n_embd);
io.write(&embd_size, sizeof(embd_size));
if (embd_size) {
io.write(embd, embd_size * sizeof(float));
}
}
// TODO: handle sampling buffers and samplers state ?
// https://github.com/ggml-org/llama.cpp/pull/17004
if (memory != nullptr) {
LLAMA_LOG_DEBUG("%s: - writing memory module\n", __func__);
memory->state_write(io);
@ -2583,70 +2525,6 @@ size_t llama_context::state_read_data(llama_io_read_i & io) {
// TODO: add more info which needs to be identical but which is not verified otherwise
}
// read output ids
{
LLAMA_LOG_DEBUG("%s: - reading output ids\n", __func__);
auto n_outputs = this->n_outputs;
io.read_to(&n_outputs, sizeof(n_outputs));
if (n_outputs > output_reserve(n_outputs)) {
throw std::runtime_error("could not reserve outputs");
}
std::vector<int32_t> output_pos;
if (n_outputs) {
output_pos.resize(n_outputs);
io.read_to(output_pos.data(), n_outputs * sizeof(int32_t));
for (int32_t i = 0; i < (int32_t) output_pos.size(); ++i) {
int32_t id = output_pos[i];
if ((uint32_t) id >= n_batch()) {
throw std::runtime_error(format("invalid output id, %d does not fit in batch size of %u", id, n_batch()));
}
this->output_ids[id] = i;
}
this->n_outputs = n_outputs;
}
}
// read logits
{
LLAMA_LOG_DEBUG("%s: - reading logits\n", __func__);
uint64_t logits_size;
io.read_to(&logits_size, sizeof(logits_size));
if (this->logits_size < logits_size) {
throw std::runtime_error("logits buffer too small");
}
if (logits_size) {
io.read_to(this->logits, logits_size * sizeof(float));
}
}
// read embeddings
{
LLAMA_LOG_DEBUG("%s: - reading embeddings\n", __func__);
uint64_t embd_size;
io.read_to(&embd_size, sizeof(embd_size));
if (this->embd_size < embd_size) {
throw std::runtime_error("embeddings buffer too small");
}
if (embd_size) {
io.read_to(this->embd, embd_size * sizeof(float));
}
}
// TODO: handle sampling buffers and samplers state ?
// https://github.com/ggml-org/llama.cpp/pull/17004
if (memory) {
LLAMA_LOG_DEBUG("%s: - reading memory module\n", __func__);

View File

@ -387,6 +387,23 @@ int main(int argc, char ** argv) {
}
session_do_save = !path_session.empty() && n_match < embd_inp.size() && !params.prompt_cache_ro;
// Logits are not stored as part of the session state so we need to
// "replay" the last token to get logits for sampling.
if (!session_tokens.empty() && n_match > 0 && n_match == session_tokens.size()) {
llama_token last_token = session_tokens.back();
int32_t pos = n_match;
llama_batch batch = llama_batch_get_one(&last_token, 1);
batch.pos = &pos;
if (llama_decode(ctx, batch)) {
LOG_ERR("%s: failed to regenerate logits after loading state\n", __func__);
return 1;
}
session_do_save = false;
LOG_INF("%s: replayed last token from session\n", __func__);
}
}
// number of tokens to keep when resetting context
@ -675,40 +692,26 @@ int main(int argc, char ** argv) {
}
if (!embd.empty()) {
int n_eval = (int) embd.size();
LOG_DBG("eval: %s\n", string_from(ctx, embd).c_str());
GGML_ASSERT(n_eval <= params.n_batch);
if (llama_decode(ctx, llama_batch_get_one(embd.data(), n_eval))) {
LOG_ERR("%s : failed to eval\n", __func__);
const bool is_last_batch = (n_consumed >= (int) embd_inp.size());
if (!common_prompt_batch_decode(ctx, embd, n_past, params.n_batch, path_session, session_do_save, is_last_batch)) {
return 1;
}
n_past += n_eval;
session_tokens.insert(session_tokens.end(), embd.begin(), embd.begin());
n_session_consumed = session_tokens.size();
session_do_save = false;
LOG_DBG("n_past = %d\n", n_past);
// Display total tokens alongside total time
if (params.n_print > 0 && n_past % params.n_print == 0) {
LOG_DBG("\n\033[31mTokens consumed so far = %d / %d \033[0m\n", n_past, n_ctx);
}
}
if (!embd.empty() && !path_session.empty()) {
session_tokens.insert(session_tokens.end(), embd.begin(), embd.end());
n_session_consumed = session_tokens.size();
}
}
embd.clear();
if ((int) embd_inp.size() <= n_consumed && !is_interacting) {
// optionally save the session on first sample (for faster prompt loading next time)
if (session_do_save) {
session_do_save = false;
llama_state_save_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.size());
LOG_DBG("saved session to %s\n", path_session.c_str());
}
const llama_token id = common_sampler_sample(smpl, ctx, -1);