common : add common_prompt_batch_decode function
This commit adds a new function that decodes a batch of prompt tokens and optionally handles saving the session data.
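To illustrate the intended call pattern, the sketch below shows how a caller could feed a whole prompt through the new helper in n_batch-sized chunks. This is only an illustration and not part of the change: the wrapper name decode_full_prompt and its variables are placeholders, and it assumes the declaration added to the common header in this commit.

#include <algorithm>
#include <filesystem>
#include <vector>

#include "common.h"
#include "llama.h"

// Hypothetical caller (not part of this commit): decode a whole prompt in
// n_batch-sized chunks, letting the final chunk trigger the
// save-before-last-token path added by this commit.
static bool decode_full_prompt(
        llama_context * ctx,
        const std::vector<llama_token> & prompt_tokens,
        int & n_past,
        int n_batch,
        const std::filesystem::path & state_path,
        bool save_state) {
    for (size_t i = 0; i < prompt_tokens.size(); i += (size_t) n_batch) {
        const size_t n_chunk = std::min((size_t) n_batch, prompt_tokens.size() - i);
        const std::vector<llama_token> chunk(prompt_tokens.begin() + i, prompt_tokens.begin() + i + n_chunk);

        // only the final chunk should save state / replay its last token
        const bool is_last_batch = i + n_chunk >= prompt_tokens.size();

        if (!common_prompt_batch_decode(ctx, chunk, n_past, n_batch, state_path, save_state, is_last_batch)) {
            return false;
        }
    }
    return true;
}

Only the final chunk passes is_last_batch = true, which is what triggers the save-before-last-token logic described below.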
This commit is contained in:
parent 44bddc0a89
commit e1373fd89c
@@ -1863,3 +1863,56 @@ float lr_opt::get_lr(float epoch) const {
    LOG_INF("epoch %.2g lr=%.2g\n", epoch, r);
    return r;
}

bool common_prompt_batch_decode(
        struct llama_context * ctx,
        const std::vector<llama_token> & tokens,
        int & n_past,
        int n_batch,
        const std::filesystem::path & state_path,
        bool save_state,
        bool is_last_batch) {
    const int n_eval = tokens.size();
    if (n_eval == 0) {
        return true;
    }

    if (save_state && is_last_batch && n_eval > 1) {
        const int n_tokens_before_last = n_eval - 1;

        GGML_ASSERT(n_eval <= n_batch);

        // Decode all but the last token so we can save the memory state before decoding the last token.
        // This is done so we can restore the session state later and replay the last token.
        // Memory implementations in recurrent/hybrid models don't support removing tokens from their
        // memory, so we can't just remove the last token from the memory and replay the last token which
        // is the reason for this logic.
        if (llama_decode(ctx, llama_batch_get_one(const_cast<llama_token*>(tokens.data()), n_tokens_before_last))) {
            LOG_ERR("%s : failed to eval\n", __func__);
            return false;
        }
        n_past += n_tokens_before_last;

        llama_state_save_file(ctx, state_path.string().c_str(), tokens.data(), n_tokens_before_last);
        LOG_INF("saved session before last token to %s, n_tokens = %d\n", state_path.string().c_str(), n_tokens_before_last);

        llama_token last_token = tokens.back();
        llama_batch batch = llama_batch_get_one(&last_token, 1);
        int32_t pos = n_past;
        batch.pos = &pos;

        if (llama_decode(ctx, batch)) {
            LOG_ERR("%s : failed to eval last token\n", __func__);
            return false;
        }
        n_past++;
    } else {
        if (llama_decode(ctx, llama_batch_get_one(const_cast<llama_token*>(tokens.data()), n_eval))) {
            LOG_ERR("%s : failed to eval\n", __func__);
            return false;
        }
        n_past += n_eval;
    }

    return true;
}
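A note on the single-token replay in the implementation above: llama_batch_get_one() returns a batch whose pos field is nullptr, in which case llama_decode() tracks the positions automatically, so the code points batch.pos at an explicit int32_t to decode the replayed token at a chosen position. A minimal sketch of that pattern (not part of the commit; decode_token_at is a hypothetical name):

#include "llama.h"
#include "log.h"

// Hypothetical helper: decode a single token at an explicit position.
// llama_batch_get_one() leaves batch.pos as nullptr (automatic position tracking),
// so pointing it at our own int32_t overrides the position for this one token.
static bool decode_token_at(llama_context * ctx, llama_token token, int32_t pos) {
    llama_batch batch = llama_batch_get_one(&token, 1);
    batch.pos = &pos;

    if (llama_decode(ctx, batch)) {
        LOG_ERR("%s: failed to decode token at position %d\n", __func__, pos);
        return false;
    }
    return true;
}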
@@ -5,6 +5,7 @@
#include "ggml-opt.h"
#include "llama-cpp.h"

#include <filesystem>
#include <set>
#include <sstream>
#include <string>

@@ -780,6 +781,20 @@ void common_batch_add(
        const std::vector<llama_seq_id> & seq_ids,
        bool logits);

// decodes a single batch of tokens for a prompt and manages session tokens
//
// Note: We save state before the last token so that we can replay it to ensure
// compatibility with all memory types. Recurrent/hybrid models cannot remove
// tokens from memory, so this approach works across all model architectures.
bool common_prompt_batch_decode(
        struct llama_context * ctx,
        const std::vector<llama_token> & embd,
        int & n_past,
        int n_batch,
        const std::filesystem::path & state_path,
        bool save_state,
        bool is_last_batch = true);

//
// Token utils
//
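Because is_last_batch defaults to true, a prompt that already fits into a single batch can be handled with one call, with the state saved just before its last token. A hedged example of such a call site (ctx, prompt_tokens, and params are assumed to be in scope; "prompt.state" is an arbitrary illustrative file name; the prompt must not exceed n_batch because of the GGML_ASSERT in the implementation):

    // illustrative one-shot call for a prompt that fits in a single batch
    int n_past = 0;
    const std::filesystem::path state_path = "prompt.state"; // illustrative path

    if (!common_prompt_batch_decode(ctx, prompt_tokens, n_past, params.n_batch,
                                    state_path, /*save_state =*/ true)) {
        LOG_ERR("%s : failed to eval\n", __func__);
        return 1;
    }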
@@ -392,19 +392,7 @@ int main(int argc, char ** argv) {
        // "replay" the last token to get logits for sampling.
        if (!session_tokens.empty() && n_match > 0 && n_match == session_tokens.size()) {
            llama_token last_token = session_tokens.back();
            int32_t pos;

            if (llama_model_is_recurrent(model) || llama_model_is_hybrid(model)) {
                LOG_INF("%s: recurrent/hybrid model: decode using next position: %d\n", __func__, (int)n_match);
                pos = n_match; // use next position for decoding
            } else {
                LOG_INF("%s: non-recurrent model: removing and re-decode last position: %d\n", __func__, (int)n_match - 1);
                if (!llama_memory_seq_rm(mem, 0, n_match - 1, n_match)) {
                    LOG_ERR("%s: failed to remove last position from KV cache\n", __func__);
                    return 1;
                }
                pos = n_match - 1;
            }
            int32_t pos = n_match;

            llama_batch batch = llama_batch_get_one(&last_token, 1);
            batch.pos = &pos;
@@ -412,6 +400,9 @@ int main(int argc, char ** argv) {
                LOG_ERR("%s: failed to regenerate logits after loading state\n", __func__);
                return 1;
            }

            session_do_save = false;
            LOG_INF("%s: replayed last token from session\n", __func__);
        }
    }
@@ -701,40 +692,26 @@ int main(int argc, char ** argv) {
        }

        if (!embd.empty()) {
            int n_eval = (int) embd.size();
            LOG_DBG("eval: %s\n", string_from(ctx, embd).c_str());

            GGML_ASSERT(n_eval <= params.n_batch);
            if (llama_decode(ctx, llama_batch_get_one(embd.data(), n_eval))) {
                LOG_ERR("%s : failed to eval\n", __func__);
            const bool is_last_batch = (n_consumed >= (int) embd_inp.size());
            if (!common_prompt_batch_decode(ctx, embd, n_past, params.n_batch, path_session, session_do_save, is_last_batch)) {
                return 1;
            }

            n_past += n_eval;
            session_tokens.insert(session_tokens.end(), embd.begin(), embd.end());
            n_session_consumed = session_tokens.size();
            session_do_save = false;

            LOG_DBG("n_past = %d\n", n_past);

            // Display total tokens alongside total time
            if (params.n_print > 0 && n_past % params.n_print == 0) {
                LOG_DBG("\n\033[31mTokens consumed so far = %d / %d \033[0m\n", n_past, n_ctx);
            }
        }

        if (!embd.empty() && !path_session.empty()) {
            session_tokens.insert(session_tokens.end(), embd.begin(), embd.end());
            n_session_consumed = session_tokens.size();
        }
        }

        embd.clear();

        if ((int) embd_inp.size() <= n_consumed && !is_interacting) {
            // optionally save the session on first sample (for faster prompt loading next time)
            if (session_do_save) {
                session_do_save = false;
                llama_state_save_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.size());

                LOG_DBG("saved session to %s\n", path_session.c_str());
            }

            const llama_token id = common_sampler_sample(smpl, ctx, -1);