llama : remove implicit recurrent state rollbacks
This commit is contained in:
parent
124c222f76
commit
8006f3b3c8
|
|
@ -966,7 +966,7 @@ struct common_init_result common_init_from_params(common_params & params) {
|
|||
if (llama_model_has_decoder(model)) {
|
||||
llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0));
|
||||
}
|
||||
llama_past_clear(lctx);
|
||||
llama_kv_cache_clear(lctx);
|
||||
llama_synchronize(lctx);
|
||||
llama_perf_context_reset(lctx);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -133,7 +133,7 @@ int main(int argc, char ** argv) {
|
|||
|
||||
const auto t_pp_start = ggml_time_us();
|
||||
|
||||
llama_past_clear(ctx);
|
||||
llama_kv_cache_clear(ctx);
|
||||
|
||||
if (!decode_helper(ctx, batch, ctx_params.n_batch)) {
|
||||
LOG_ERR("%s: llama_decode() failed\n", __func__);
|
||||
|
|
@ -142,7 +142,7 @@ int main(int argc, char ** argv) {
|
|||
|
||||
if (is_pp_shared) {
|
||||
for (int32_t i = 1; i < pl; ++i) {
|
||||
llama_past_seq_cp(ctx, 0, i, -1, -1);
|
||||
llama_kv_cache_seq_cp(ctx, 0, i, -1, -1);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -111,7 +111,7 @@ if llama_decode(context, batch) != 0 {
|
|||
}
|
||||
|
||||
for i in 1 ..< n_parallel {
|
||||
llama_past_seq_cp(context, 0, Int32(i), -1, -1)
|
||||
llama_kv_cache_seq_cp(context, 0, Int32(i), -1, -1)
|
||||
}
|
||||
|
||||
if n_parallel > 1 {
|
||||
|
|
|
|||
|
|
@ -138,7 +138,7 @@ int main(int argc, char ** argv) {
|
|||
//// assign the system KV cache to all parallel sequences
|
||||
//// this way, the parallel sequences will "reuse" the prompt tokens without having to copy them
|
||||
//for (int32_t i = 1; i < n_parallel; ++i) {
|
||||
// llama_past_seq_cp(ctx, 0, i, -1, -1);
|
||||
// llama_kv_cache_seq_cp(ctx, 0, i, -1, -1);
|
||||
//}
|
||||
|
||||
if (n_parallel > 1) {
|
||||
|
|
|
|||
|
|
@ -338,7 +338,7 @@ static bool cb_eval(struct ggml_tensor * t, bool ask, void * user_data) {
|
|||
}
|
||||
|
||||
static bool get_hidden_layers(llama_context * ctx, std::vector<llama_token> & tokens) {
|
||||
llama_past_clear(ctx);
|
||||
llama_kv_cache_clear(ctx);
|
||||
if (llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size(), 0, 0))) {
|
||||
fprintf(stderr, "%s : failed to eval\n", __func__);
|
||||
return false;
|
||||
|
|
|
|||
|
|
@ -37,7 +37,7 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
|
|||
const struct llama_model * model = llama_get_model(ctx);
|
||||
|
||||
// clear previous kv_cache values (irrelevant for embeddings)
|
||||
llama_past_clear(ctx);
|
||||
llama_kv_cache_clear(ctx);
|
||||
|
||||
// run model
|
||||
LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq);
|
||||
|
|
|
|||
|
|
@ -44,7 +44,7 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
|
|||
}
|
||||
|
||||
// clear previous kv_cache values (irrelevant for embeddings)
|
||||
llama_past_clear(ctx);
|
||||
llama_kv_cache_clear(ctx);
|
||||
llama_set_embeddings(ctx, true);
|
||||
llama_set_causal_attn(ctx, false);
|
||||
|
||||
|
|
@ -99,7 +99,7 @@ static std::string generate(llama_context * ctx, llama_sampler * smpl, const std
|
|||
const llama_model * model = llama_get_model(ctx);
|
||||
llama_token eos_token = llama_token_eos(model);
|
||||
|
||||
llama_past_clear(ctx);
|
||||
llama_kv_cache_clear(ctx);
|
||||
llama_set_embeddings(ctx, false);
|
||||
llama_set_causal_attn(ctx, true);
|
||||
|
||||
|
|
|
|||
|
|
@ -494,7 +494,7 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params) {
|
|||
const auto t_start = std::chrono::high_resolution_clock::now();
|
||||
|
||||
// clear the KV cache
|
||||
llama_past_clear(ctx);
|
||||
llama_kv_cache_clear(ctx);
|
||||
|
||||
for (int j = 0; j < num_batches; ++j) {
|
||||
const int batch_start = start + j * n_batch;
|
||||
|
|
|
|||
|
|
@ -375,8 +375,8 @@ int main(int argc, char ** argv) {
|
|||
LOG_DBG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n",
|
||||
n_past, n_left, n_ctx, params.n_keep, n_discard);
|
||||
|
||||
llama_past_seq_rm (ctx, 0, params.n_keep + 1 , params.n_keep + n_discard + 1);
|
||||
llama_past_seq_add(ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard);
|
||||
llama_kv_cache_seq_rm (ctx, 0, params.n_keep + 1 , params.n_keep + n_discard + 1);
|
||||
llama_kv_cache_seq_add(ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard);
|
||||
|
||||
n_past -= n_discard;
|
||||
|
||||
|
|
|
|||
|
|
@ -1566,7 +1566,7 @@ int main(int argc, char ** argv) {
|
|||
|
||||
test t(inst, lmodel, ctx);
|
||||
|
||||
llama_past_clear(ctx);
|
||||
llama_kv_cache_clear(ctx);
|
||||
|
||||
// cool off before the test
|
||||
if (params.delay) {
|
||||
|
|
@ -1606,7 +1606,7 @@ int main(int argc, char ** argv) {
|
|||
}
|
||||
|
||||
for (int i = 0; i < params.reps; i++) {
|
||||
llama_past_clear(ctx);
|
||||
llama_kv_cache_clear(ctx);
|
||||
|
||||
uint64_t t_start = get_time_ns();
|
||||
|
||||
|
|
|
|||
|
|
@ -194,7 +194,7 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model(
|
|||
}
|
||||
|
||||
batch->logits[batch->n_tokens - 1] = true;
|
||||
llama_past_clear(context);
|
||||
llama_kv_cache_clear(context);
|
||||
|
||||
const auto t_pp_start = ggml_time_us();
|
||||
if (llama_decode(context, *batch) != 0) {
|
||||
|
|
@ -206,7 +206,7 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model(
|
|||
|
||||
LOGi("Benchmark text generation (tg)");
|
||||
|
||||
llama_past_clear(context);
|
||||
llama_kv_cache_clear(context);
|
||||
const auto t_tg_start = ggml_time_us();
|
||||
for (i = 0; i < tg; i++) {
|
||||
|
||||
|
|
@ -223,7 +223,7 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model(
|
|||
|
||||
const auto t_tg_end = ggml_time_us();
|
||||
|
||||
llama_past_clear(context);
|
||||
llama_kv_cache_clear(context);
|
||||
|
||||
const auto t_pp = double(t_pp_end - t_pp_start) / 1000000.0;
|
||||
const auto t_tg = double(t_tg_end - t_tg_start) / 1000000.0;
|
||||
|
|
@ -446,5 +446,5 @@ Java_android_llama_cpp_LLamaAndroid_completion_1loop(
|
|||
extern "C"
|
||||
JNIEXPORT void JNICALL
|
||||
Java_android_llama_cpp_LLamaAndroid_kv_1cache_1clear(JNIEnv *, jobject, jlong context) {
|
||||
llama_past_clear(reinterpret_cast<llama_context *>(context));
|
||||
llama_kv_cache_clear(reinterpret_cast<llama_context *>(context));
|
||||
}
|
||||
|
|
|
|||
|
|
@ -209,7 +209,7 @@ actor LlamaContext {
|
|||
}
|
||||
batch.logits[Int(batch.n_tokens) - 1] = 1 // true
|
||||
|
||||
llama_past_clear(context)
|
||||
llama_kv_cache_clear(context)
|
||||
|
||||
let t_pp_start = ggml_time_us()
|
||||
|
||||
|
|
@ -222,7 +222,7 @@ actor LlamaContext {
|
|||
|
||||
// bench text generation
|
||||
|
||||
llama_past_clear(context)
|
||||
llama_kv_cache_clear(context)
|
||||
|
||||
let t_tg_start = ggml_time_us()
|
||||
|
||||
|
|
@ -241,7 +241,7 @@ actor LlamaContext {
|
|||
|
||||
let t_tg_end = ggml_time_us()
|
||||
|
||||
llama_past_clear(context)
|
||||
llama_kv_cache_clear(context)
|
||||
|
||||
let t_pp = Double(t_pp_end - t_pp_start) / 1000000.0
|
||||
let t_tg = Double(t_tg_end - t_tg_start) / 1000000.0
|
||||
|
|
@ -291,7 +291,7 @@ actor LlamaContext {
|
|||
func clear() {
|
||||
tokens_list.removeAll()
|
||||
temporary_invalid_cchars.removeAll()
|
||||
llama_past_clear(context)
|
||||
llama_kv_cache_clear(context)
|
||||
}
|
||||
|
||||
private func tokenize(text: String, add_bos: Bool) -> [llama_token] {
|
||||
|
|
|
|||
|
|
@ -93,7 +93,7 @@ int main(int argc, char ** argv) {
|
|||
llama_decode(ctx, llama_batch_get_one(&inp.back(), 1, n_input - 1, 0));
|
||||
|
||||
for (int s = 1; s < W + G + 1; ++s) {
|
||||
llama_past_seq_cp(ctx, 0, s, -1, -1);
|
||||
llama_kv_cache_seq_cp(ctx, 0, s, -1, -1);
|
||||
}
|
||||
|
||||
const auto t_enc_end = ggml_time_us();
|
||||
|
|
@ -436,17 +436,17 @@ int main(int argc, char ** argv) {
|
|||
// KV cache management
|
||||
// if no verification token matched, we simply remove all cells from this batch -> no fragmentation
|
||||
// FIXME: recurrent and hybrid models
|
||||
llama_past_seq_rm(ctx, -1, n_past, -1);
|
||||
llama_kv_cache_seq_rm(ctx, -1, n_past, -1);
|
||||
|
||||
if (seq_id_best != 0) {
|
||||
// if a verification token matched, we keep the best sequence and remove the rest
|
||||
// this leads to some KV cache fragmentation
|
||||
llama_past_seq_keep(ctx, seq_id_best);
|
||||
llama_past_seq_cp (ctx, seq_id_best, 0, -1, -1);
|
||||
llama_past_seq_rm (ctx, seq_id_best, -1, -1);
|
||||
llama_kv_cache_seq_keep(ctx, seq_id_best);
|
||||
llama_kv_cache_seq_cp (ctx, seq_id_best, 0, -1, -1);
|
||||
llama_kv_cache_seq_rm (ctx, seq_id_best, -1, -1);
|
||||
|
||||
for (int s = 1; s < W + G + 1; ++s) {
|
||||
llama_past_seq_cp(ctx, 0, s, -1, -1);
|
||||
llama_kv_cache_seq_cp(ctx, 0, s, -1, -1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -191,7 +191,7 @@ int main(int argc, char ** argv){
|
|||
// KV cache management
|
||||
// clean the cache of draft tokens that weren't accepted
|
||||
// FIXME: recurrent and hybrid models
|
||||
llama_past_seq_rm(ctx, 0, n_past, -1);
|
||||
llama_kv_cache_seq_rm(ctx, 0, n_past, -1);
|
||||
|
||||
common_batch_clear(batch_tgt);
|
||||
common_batch_add(batch_tgt, draft[0], n_past, { 0 }, true);
|
||||
|
|
|
|||
|
|
@ -332,10 +332,6 @@ int main(int argc, char ** argv) {
|
|||
}
|
||||
n_matching_session_tokens++;
|
||||
}
|
||||
|
||||
// remove any "future" tokens that we might have inherited from the previous session
|
||||
n_matching_session_tokens = llama_past_seq_rm(ctx, -1, n_matching_session_tokens, -1);
|
||||
|
||||
if (params.prompt.empty() && n_matching_session_tokens == embd_inp.size()) {
|
||||
LOG_INF("%s: using full prompt from session file\n", __func__);
|
||||
} else if (n_matching_session_tokens >= embd_inp.size()) {
|
||||
|
|
@ -347,6 +343,9 @@ int main(int argc, char ** argv) {
|
|||
LOG_INF("%s: session file matches %zu / %zu tokens of prompt\n",
|
||||
__func__, n_matching_session_tokens, embd_inp.size());
|
||||
}
|
||||
|
||||
// remove any "future" tokens that we might have inherited from the previous session
|
||||
llama_kv_cache_seq_rm(ctx, -1, n_matching_session_tokens, -1);
|
||||
}
|
||||
|
||||
LOG_DBG("recalculate the cached logits (check): embd_inp.size() %zu, n_matching_session_tokens %zu, embd_inp.size() %zu, session_tokens.size() %zu\n",
|
||||
|
|
@ -358,8 +357,6 @@ int main(int argc, char ** argv) {
|
|||
LOG_DBG("recalculate the cached logits (do): session_tokens.resize( %zu )\n", embd_inp.size() - 1);
|
||||
|
||||
session_tokens.resize(embd_inp.size() - 1);
|
||||
} else {
|
||||
session_tokens.resize(n_matching_session_tokens);
|
||||
}
|
||||
|
||||
// number of tokens to keep when resetting context
|
||||
|
|
@ -609,9 +606,9 @@ int main(int argc, char ** argv) {
|
|||
LOG_DBG("div: [%6d, %6d] / %6d -> [%6d, %6d]\n", ga_i + ib*bd, ga_i + ib*bd + ga_w, ga_n, (ga_i + ib*bd)/ga_n, (ga_i + ib*bd + ga_w)/ga_n);
|
||||
LOG_DBG("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", ga_i + ib*bd + ga_w, n_past + ib*bd, dd, ga_i + ib*bd + ga_w + dd, n_past + ib*bd + dd);
|
||||
|
||||
llama_past_seq_add(ctx, 0, ga_i, n_past, ib*bd);
|
||||
llama_past_seq_div(ctx, 0, ga_i + ib*bd, ga_i + ib*bd + ga_w, ga_n);
|
||||
llama_past_seq_add(ctx, 0, ga_i + ib*bd + ga_w, n_past + ib*bd, dd);
|
||||
llama_kv_cache_seq_add(ctx, 0, ga_i, n_past, ib*bd);
|
||||
llama_kv_cache_seq_div(ctx, 0, ga_i + ib*bd, ga_i + ib*bd + ga_w, ga_n);
|
||||
llama_kv_cache_seq_add(ctx, 0, ga_i + ib*bd + ga_w, n_past + ib*bd, dd);
|
||||
|
||||
n_past -= bd;
|
||||
|
||||
|
|
@ -625,8 +622,6 @@ int main(int argc, char ** argv) {
|
|||
if (n_session_consumed < (int) session_tokens.size()) {
|
||||
size_t i = 0;
|
||||
for ( ; i < embd.size(); i++) {
|
||||
// TODO: are the session tokens guaranteed to all be matching here?
|
||||
// Should n_matching_session_tokens be re-used instead?
|
||||
if (embd[i] != session_tokens[n_session_consumed]) {
|
||||
session_tokens.resize(n_session_consumed);
|
||||
break;
|
||||
|
|
|
|||
|
|
@ -199,7 +199,7 @@ int main(int argc, char ** argv) {
|
|||
|
||||
// assign the system KV cache to all parallel sequences
|
||||
for (int32_t i = 1; i <= n_clients; ++i) {
|
||||
llama_past_seq_cp(ctx, 0, i, -1, -1);
|
||||
llama_kv_cache_seq_cp(ctx, 0, i, -1, -1);
|
||||
}
|
||||
|
||||
LOG_INF("\n");
|
||||
|
|
@ -231,9 +231,9 @@ int main(int argc, char ** argv) {
|
|||
if (batch.n_tokens == 0) {
|
||||
// all sequences have ended - clear the entire KV cache
|
||||
for (int i = 1; i <= n_clients; ++i) {
|
||||
llama_past_seq_rm(ctx, i, -1, -1);
|
||||
llama_kv_cache_seq_rm(ctx, i, -1, -1);
|
||||
// but keep the system prompt
|
||||
llama_past_seq_cp(ctx, 0, i, -1, -1);
|
||||
llama_kv_cache_seq_cp(ctx, 0, i, -1, -1);
|
||||
}
|
||||
|
||||
LOG_INF("%s: clearing the KV cache\n", __func__);
|
||||
|
|
@ -370,8 +370,8 @@ int main(int argc, char ** argv) {
|
|||
}
|
||||
|
||||
// delete only the generated part of the sequence, i.e. keep the system prompt in the cache
|
||||
llama_past_seq_rm(ctx, client.id + 1, -1, -1);
|
||||
llama_past_seq_cp(ctx, 0, client.id + 1, -1, -1);
|
||||
llama_kv_cache_seq_rm(ctx, client.id + 1, -1, -1);
|
||||
llama_kv_cache_seq_cp(ctx, 0, client.id + 1, -1, -1);
|
||||
|
||||
const auto t_main_end = ggml_time_us();
|
||||
|
||||
|
|
|
|||
|
|
@ -130,11 +130,11 @@ int main(int argc, char ** argv) {
|
|||
const int ib = i/n_batch - 1;
|
||||
const int bd = n_batch_grp*(n_grp - 1);
|
||||
|
||||
llama_past_seq_add (ctx, 0, n_past - n_batch, n_past, ib*bd);
|
||||
llama_past_seq_div (ctx, 0, n_past - n_batch + ib*bd, n_past + ib*bd, n_grp);
|
||||
llama_kv_cache_update(ctx);
|
||||
llama_kv_cache_seq_add (ctx, 0, n_past - n_batch, n_past, ib*bd);
|
||||
llama_kv_cache_seq_div (ctx, 0, n_past - n_batch + ib*bd, n_past + ib*bd, n_grp);
|
||||
llama_kv_cache_update (ctx);
|
||||
|
||||
n_past = llama_past_seq_pos_max(ctx, 0) + 1;
|
||||
n_past = llama_kv_cache_seq_pos_max(ctx, 0) + 1;
|
||||
}
|
||||
|
||||
common_batch_clear(batch);
|
||||
|
|
@ -164,12 +164,12 @@ int main(int argc, char ** argv) {
|
|||
|
||||
LOG_INF("%s: shifting KV cache with %d\n", __func__, n_discard);
|
||||
|
||||
llama_past_seq_rm (ctx, 0, n_keep , n_keep + n_discard);
|
||||
llama_past_seq_add (ctx, 0, n_keep + n_discard, n_ctx, -n_discard);
|
||||
//llama_kv_cache_defrag(ctx);
|
||||
llama_kv_cache_update(ctx);
|
||||
llama_kv_cache_seq_rm (ctx, 0, n_keep , n_keep + n_discard);
|
||||
llama_kv_cache_seq_add(ctx, 0, n_keep + n_discard, n_ctx, -n_discard);
|
||||
//llama_kv_cache_defrag (ctx);
|
||||
llama_kv_cache_update (ctx);
|
||||
|
||||
n_past = llama_past_seq_pos_max(ctx, 0) + 1;
|
||||
n_past = llama_kv_cache_seq_pos_max(ctx, 0) + 1;
|
||||
|
||||
common_batch_clear(batch);
|
||||
|
||||
|
|
@ -195,12 +195,12 @@ int main(int argc, char ** argv) {
|
|||
if (n_discard > 0) {
|
||||
LOG_INF("%s: shifting KV cache with %d to free space for the answer\n", __func__, n_discard);
|
||||
|
||||
llama_past_seq_rm (ctx, 0, n_keep , n_keep + n_discard);
|
||||
llama_past_seq_add (ctx, 0, n_keep + n_discard, n_ctx, -n_discard);
|
||||
//llama_kv_cache_defrag(ctx);
|
||||
llama_kv_cache_update(ctx);
|
||||
llama_kv_cache_seq_rm (ctx, 0, n_keep , n_keep + n_discard);
|
||||
llama_kv_cache_seq_add(ctx, 0, n_keep + n_discard, n_ctx, -n_discard);
|
||||
//llama_kv_cache_defrag (ctx);
|
||||
llama_kv_cache_update (ctx);
|
||||
|
||||
n_past = llama_past_seq_pos_max(ctx, 0) + 1;
|
||||
n_past = llama_kv_cache_seq_pos_max(ctx, 0) + 1;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -406,7 +406,7 @@ static results_perplexity perplexity_v2(llama_context * ctx, const common_params
|
|||
const auto t_start = std::chrono::high_resolution_clock::now();
|
||||
|
||||
// clear the KV cache
|
||||
llama_past_clear(ctx);
|
||||
llama_kv_cache_clear(ctx);
|
||||
|
||||
for (int j = 0; j < num_batches; ++j) {
|
||||
const int batch_start = start + j * n_batch;
|
||||
|
|
@ -580,7 +580,7 @@ static results_perplexity perplexity(llama_context * ctx, const common_params &
|
|||
const auto t_start = std::chrono::high_resolution_clock::now();
|
||||
|
||||
// clear the KV cache
|
||||
llama_past_clear(ctx);
|
||||
llama_kv_cache_clear(ctx);
|
||||
|
||||
for (int j = 0; j < num_batches; ++j) {
|
||||
const int batch_start = start + j * n_batch;
|
||||
|
|
@ -955,7 +955,7 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) {
|
|||
return;
|
||||
}
|
||||
|
||||
llama_past_clear(ctx);
|
||||
llama_kv_cache_clear(ctx);
|
||||
|
||||
// decode all tasks [i0, i1)
|
||||
if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) {
|
||||
|
|
@ -1232,7 +1232,7 @@ static void winogrande_score(llama_context * ctx, const common_params & params)
|
|||
return;
|
||||
}
|
||||
|
||||
llama_past_clear(ctx);
|
||||
llama_kv_cache_clear(ctx);
|
||||
|
||||
// decode all tasks [i0, i1)
|
||||
if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) {
|
||||
|
|
@ -1602,7 +1602,7 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par
|
|||
return;
|
||||
}
|
||||
|
||||
llama_past_clear(ctx);
|
||||
llama_kv_cache_clear(ctx);
|
||||
|
||||
// decode all tasks [i0, i1)
|
||||
if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) {
|
||||
|
|
@ -1789,7 +1789,7 @@ static void kl_divergence(llama_context * ctx, const common_params & params) {
|
|||
}
|
||||
|
||||
// clear the KV cache
|
||||
llama_past_clear(ctx);
|
||||
llama_kv_cache_clear(ctx);
|
||||
|
||||
for (int j = 0; j < num_batches; ++j) {
|
||||
const int batch_start = start + j * n_batch;
|
||||
|
|
|
|||
|
|
@ -83,7 +83,7 @@ static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & toke
|
|||
|
||||
static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd) {
|
||||
// clear previous kv_cache values (irrelevant for embeddings)
|
||||
llama_past_clear(ctx);
|
||||
llama_kv_cache_clear(ctx);
|
||||
|
||||
// run model
|
||||
LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq);
|
||||
|
|
|
|||
|
|
@ -199,7 +199,7 @@ int main(int argc, char ** argv) {
|
|||
fprintf(stderr, "%s : seq 0 copied, %zd bytes\n", __func__, ncopy);
|
||||
|
||||
// erase whole kv
|
||||
llama_past_clear(ctx3);
|
||||
llama_kv_cache_clear(ctx3);
|
||||
fprintf(stderr, "%s : kv cache cleared\n", __func__);
|
||||
|
||||
// restore kv into seq 1
|
||||
|
|
|
|||
|
|
@ -1029,7 +1029,7 @@ struct server_context {
|
|||
SRV_DBG("%s", "clearing KV cache\n");
|
||||
|
||||
// clear the entire KV cache
|
||||
llama_past_clear(ctx);
|
||||
llama_kv_cache_clear(ctx);
|
||||
clean_kv_cache = false;
|
||||
}
|
||||
|
||||
|
|
@ -1760,7 +1760,7 @@ struct server_context {
|
|||
|
||||
// Erase token cache
|
||||
const size_t n_erased = slot->cache_tokens.size();
|
||||
llama_past_seq_rm(ctx, slot->id + 1, -1, -1);
|
||||
llama_kv_cache_seq_rm(ctx, slot->id + 1, -1, -1);
|
||||
slot->cache_tokens.clear();
|
||||
|
||||
server_task_result result;
|
||||
|
|
|
|||
|
|
@ -410,15 +410,15 @@ int main(int argc, char ** argv) {
|
|||
{
|
||||
LOG_DBG("keeping sequence %d, n_past_tgt = %d, n_past_dft = %d\n", s_keep, n_past_tgt, n_past_dft);
|
||||
|
||||
llama_past_seq_keep(ctx_dft, s_keep);
|
||||
llama_past_seq_cp (ctx_dft, s_keep, 0, -1, -1);
|
||||
llama_past_seq_keep(ctx_dft, 0);
|
||||
llama_kv_cache_seq_keep(ctx_dft, s_keep);
|
||||
llama_kv_cache_seq_cp (ctx_dft, s_keep, 0, -1, -1);
|
||||
llama_kv_cache_seq_keep(ctx_dft, 0);
|
||||
|
||||
// FIXME: recurrent and hybrid models
|
||||
llama_past_seq_rm (ctx_tgt, s_keep, n_past_tgt, -1);
|
||||
llama_past_seq_keep(ctx_tgt, s_keep);
|
||||
llama_past_seq_cp (ctx_tgt, s_keep, 0, -1, -1);
|
||||
llama_past_seq_keep(ctx_tgt, 0);
|
||||
llama_kv_cache_seq_rm (ctx_tgt, s_keep, n_past_tgt, -1);
|
||||
llama_kv_cache_seq_keep(ctx_tgt, s_keep);
|
||||
llama_kv_cache_seq_cp (ctx_tgt, s_keep, 0, -1, -1);
|
||||
llama_kv_cache_seq_keep(ctx_tgt, 0);
|
||||
}
|
||||
|
||||
for (int s = 0; s < n_seq_dft; ++s) {
|
||||
|
|
@ -495,8 +495,8 @@ int main(int argc, char ** argv) {
|
|||
if (n_seq_cur < n_seq_dft && cur_p->data[f].p > p_split) {
|
||||
LOG_DBG("splitting seq %3d into %3d\n", s, n_seq_cur);
|
||||
|
||||
llama_past_seq_rm(ctx_dft, n_seq_cur, -1, -1);
|
||||
llama_past_seq_cp(ctx_dft, s, n_seq_cur, -1, -1);
|
||||
llama_kv_cache_seq_rm(ctx_dft, n_seq_cur, -1, -1);
|
||||
llama_kv_cache_seq_cp(ctx_dft, s, n_seq_cur, -1, -1);
|
||||
|
||||
// all previous tokens from this branch are now also part of the new branch
|
||||
for (int t = 0; t < batch_tgt.n_tokens; ++t) {
|
||||
|
|
@ -577,9 +577,9 @@ int main(int argc, char ** argv) {
|
|||
|
||||
// evaluate the target model on the drafted tokens
|
||||
{
|
||||
llama_past_seq_keep(ctx_tgt, 0);
|
||||
llama_kv_cache_seq_keep(ctx_tgt, 0);
|
||||
for (int s = 1; s < n_seq_dft; ++s) {
|
||||
llama_past_seq_cp(ctx_tgt, 0, s, -1, -1);
|
||||
llama_kv_cache_seq_cp(ctx_tgt, 0, s, -1, -1);
|
||||
}
|
||||
|
||||
// LOG_DBG("target batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_tgt, batch_tgt).c_str());
|
||||
|
|
|
|||
|
|
@ -19825,6 +19825,7 @@ struct ggml_cplan ggml_graph_plan(
|
|||
cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2
|
||||
}
|
||||
} break;
|
||||
|
||||
case GGML_OP_CROSS_ENTROPY_LOSS:
|
||||
{
|
||||
cur = ggml_type_size(node->type)*(n_tasks + node->src[0]->ne[0]*n_tasks);
|
||||
|
|
|
|||
|
|
@ -41,7 +41,7 @@
|
|||
#define LLAMA_FILE_MAGIC_GGSQ 0x67677371u // 'ggsq'
|
||||
|
||||
#define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN
|
||||
#define LLAMA_SESSION_VERSION 9
|
||||
#define LLAMA_SESSION_VERSION 10
|
||||
|
||||
#define LLAMA_STATE_SEQ_MAGIC LLAMA_FILE_MAGIC_GGSQ
|
||||
#define LLAMA_STATE_SEQ_VERSION 3
|
||||
|
|
@ -613,58 +613,35 @@ extern "C" {
|
|||
LLAMA_API int32_t llama_get_rs_cache_used_cells(const struct llama_context * ctx);
|
||||
|
||||
// Clear the KV cache and recurrent states - both cell info is erased and KV data is zeroed
|
||||
LLAMA_API void llama_past_clear(
|
||||
LLAMA_API void llama_kv_cache_clear(
|
||||
struct llama_context * ctx);
|
||||
LLAMA_API DEPRECATED(void llama_kv_cache_clear(
|
||||
struct llama_context * ctx),
|
||||
"use llama_past_clear instead");
|
||||
|
||||
// Removes all tokens that belong to the specified sequence and have positions in [p0, p1)
|
||||
// Returns false if a partial sequence cannot be removed. Removing a whole sequence never fails
|
||||
// seq_id < 0 : match any sequence
|
||||
// p0 < 0 : [0, p1]
|
||||
// p1 < 0 : [p0, inf)
|
||||
// Returns n_past (one more than the largest remaining pos in the seq_id)
|
||||
// which is only meaningful to handle for partial removals.
|
||||
LLAMA_API llama_pos llama_past_seq_rm(
|
||||
LLAMA_API bool llama_kv_cache_seq_rm(
|
||||
struct llama_context * ctx,
|
||||
llama_seq_id seq_id,
|
||||
llama_pos p0,
|
||||
llama_pos p1);
|
||||
LLAMA_API DEPRECATED(bool llama_kv_cache_seq_rm(
|
||||
struct llama_context * ctx,
|
||||
llama_seq_id seq_id,
|
||||
llama_pos p0,
|
||||
llama_pos p1),
|
||||
"use llama_past_seq_rm instead, and handle its return value for partial removals");
|
||||
|
||||
// Copy all tokens that belong to the specified sequence to another sequence
|
||||
// Note that this does not allocate extra KV or RS cache memory - it simply assigns the tokens to the new sequence
|
||||
// p0 < 0 : [0, p1]
|
||||
// p1 < 0 : [p0, inf)
|
||||
// Returns n_past (one more than the largest remaining pos in the destination seq_id)
|
||||
// which is only meaningful to handle when partially copying.
|
||||
LLAMA_API llama_pos llama_past_seq_cp(
|
||||
LLAMA_API void llama_kv_cache_seq_cp(
|
||||
struct llama_context * ctx,
|
||||
llama_seq_id seq_id_src,
|
||||
llama_seq_id seq_id_dst,
|
||||
llama_pos p0,
|
||||
llama_pos p1);
|
||||
LLAMA_API DEPRECATED(void llama_kv_cache_seq_cp(
|
||||
struct llama_context * ctx,
|
||||
llama_seq_id seq_id_src,
|
||||
llama_seq_id seq_id_dst,
|
||||
llama_pos p0,
|
||||
llama_pos p1),
|
||||
"use llama_past_seq_cp instead, and handle its return value for partial copies");
|
||||
|
||||
// Removes all tokens that do not belong to the specified sequence
|
||||
LLAMA_API void llama_past_seq_keep(
|
||||
LLAMA_API void llama_kv_cache_seq_keep(
|
||||
struct llama_context * ctx,
|
||||
llama_seq_id seq_id);
|
||||
LLAMA_API DEPRECATED(void llama_kv_cache_seq_keep(
|
||||
struct llama_context * ctx,
|
||||
llama_seq_id seq_id),
|
||||
"use llama_past_seq_keep instead");
|
||||
|
||||
// Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1)
|
||||
// If the KV cache is RoPEd, the KV data is updated accordingly:
|
||||
|
|
@ -672,19 +649,12 @@ extern "C" {
|
|||
// - explicitly with llama_kv_cache_update()
|
||||
// p0 < 0 : [0, p1]
|
||||
// p1 < 0 : [p0, inf)
|
||||
LLAMA_API void llama_past_seq_add(
|
||||
LLAMA_API void llama_kv_cache_seq_add(
|
||||
struct llama_context * ctx,
|
||||
llama_seq_id seq_id,
|
||||
llama_pos p0,
|
||||
llama_pos p1,
|
||||
llama_pos delta);
|
||||
LLAMA_API DEPRECATED(void llama_kv_cache_seq_add(
|
||||
struct llama_context * ctx,
|
||||
llama_seq_id seq_id,
|
||||
llama_pos p0,
|
||||
llama_pos p1,
|
||||
llama_pos delta),
|
||||
"use llama_past_seq_add instead");
|
||||
|
||||
// Integer division of the positions by factor of `d > 1`
|
||||
// If the KV cache is RoPEd, the KV data is updated accordingly:
|
||||
|
|
@ -692,28 +662,17 @@ extern "C" {
|
|||
// - explicitly with llama_kv_cache_update()
|
||||
// p0 < 0 : [0, p1]
|
||||
// p1 < 0 : [p0, inf)
|
||||
LLAMA_API void llama_past_seq_div(
|
||||
LLAMA_API void llama_kv_cache_seq_div(
|
||||
struct llama_context * ctx,
|
||||
llama_seq_id seq_id,
|
||||
llama_pos p0,
|
||||
llama_pos p1,
|
||||
int d);
|
||||
LLAMA_API DEPRECATED(void llama_kv_cache_seq_div(
|
||||
struct llama_context * ctx,
|
||||
llama_seq_id seq_id,
|
||||
llama_pos p0,
|
||||
llama_pos p1,
|
||||
int d),
|
||||
"use llama_past_seq_div instead");
|
||||
|
||||
// Returns the largest position present in the KV and/or RS cache for the specified sequence
|
||||
LLAMA_API llama_pos llama_past_seq_pos_max(
|
||||
LLAMA_API llama_pos llama_kv_cache_seq_pos_max(
|
||||
struct llama_context * ctx,
|
||||
llama_seq_id seq_id);
|
||||
LLAMA_API DEPRECATED(llama_pos llama_kv_cache_seq_pos_max(
|
||||
struct llama_context * ctx,
|
||||
llama_seq_id seq_id),
|
||||
"use llama_past_seq_pos_max instead, which now returns -1 instead of 0 when the seq_id has no cells");
|
||||
|
||||
// Defragment the KV cache
|
||||
// This will be applied:
|
||||
|
|
|
|||
1315
src/llama.cpp
1315
src/llama.cpp
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue