From 5b27975479c4d9a0cffc3528af54b8d01ae46bfb Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 17 Dec 2023 16:47:26 +0200 Subject: [PATCH] lookup : fix token positions in the draft batch --- common/common.h | 3 ++- examples/lookup/lookup.cpp | 41 ++++++++++++++++++++++++-------------- 2 files changed, 28 insertions(+), 16 deletions(-) diff --git a/common/common.h b/common/common.h index ef2a61de6c..875e012a21 100644 --- a/common/common.h +++ b/common/common.h @@ -239,4 +239,5 @@ void dump_non_result_info_yaml( void dump_kv_cache_view(const llama_kv_cache_view & view, int row_size = 80); // Dump the KV cache view showing individual sequences in each cell (long output). -void dump_kv_cache_view_seqs(const llama_kv_cache_view & view, int row_size = 40); \ No newline at end of file +void dump_kv_cache_view_seqs(const llama_kv_cache_view & view, int row_size = 40); + diff --git a/examples/lookup/lookup.cpp b/examples/lookup/lookup.cpp index db97d241c7..6b4eb957af 100644 --- a/examples/lookup/lookup.cpp +++ b/examples/lookup/lookup.cpp @@ -19,6 +19,8 @@ int main(int argc, char ** argv){ // length of the candidate / draft sequence, if match is found const int n_draft = 10; + const bool dump_kv_cache = params.dump_kv_cache; + #ifndef LOG_DISABLE_LOGS log_set_target(log_filename_generator("lookup", "log")); LOG_TEE("Log start\n"); @@ -37,7 +39,7 @@ int main(int argc, char ** argv){ // tokenize the prompt const bool add_bos = llama_should_add_bos_token(model); LOG("add_bos tgt: %d\n", add_bos); - + std::vector inp; inp = ::llama_tokenize(ctx, params.prompt, add_bos, true); @@ -69,24 +71,33 @@ int main(int argc, char ** argv){ int n_predict = 0; int n_drafted = 0; int n_accept = 0; - + int n_past = inp.size(); bool has_eos = false; struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams); - std::vector draft(n_draft); + std::vector draft; llama_batch batch_tgt = llama_batch_init(params.n_ctx, 0, 1); + // debug + struct llama_kv_cache_view kvc_view = llama_kv_cache_view_init(ctx, 1); + const auto t_dec_start = ggml_time_us(); - while(true){ + while (true) { + // debug + if (dump_kv_cache) { + llama_kv_cache_view_update(ctx, &kvc_view); + dump_kv_cache_view_seqs(kvc_view, 40); + } + // print current draft sequence LOG("drafted %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, draft).c_str()); - int i_dft = 0; + int i_dft = 0; while (true) { // sample from the target model llama_token id = llama_sampling_sample(ctx_sampling, ctx, NULL, i_dft); @@ -120,13 +131,13 @@ int main(int argc, char ** argv){ } continue; } - + if (params.use_color) { printf("%s", token_str.c_str()); - } + } fflush(stdout); - + LOG("the sampled target token (%d, '%s') did not match, or we ran out of drafted tokens\n", id, token_str.c_str()); draft.clear(); @@ -135,7 +146,7 @@ int main(int argc, char ** argv){ break; } - if (n_predict > params.n_predict || has_eos) { + if ((params.n_predict > 0 && n_predict > params.n_predict) || has_eos) { break; } @@ -149,9 +160,9 @@ int main(int argc, char ** argv){ // generate n_pred tokens through prompt lookup auto prompt_lookup = [&]() -> void { int inp_size = inp.size(); - for (int ngram_size = max_ngram_size ; ngram_size > 0; --ngram_size){ + for (int ngram_size = max_ngram_size ; ngram_size > 0; --ngram_size){ const llama_token * ngram = &inp[inp_size - ngram_size]; - + for (int i = 0; i <= (int) inp_size - (ngram_size * 2); ++i) { bool match = true; for (int j = 0; j < ngram_size; ++j) { @@ -164,11 +175,11 @@ int main(int argc, char ** argv){ if (match) { const int startIdx = i + ngram_size; const int endIdx = startIdx + n_draft; - if (endIdx < inp_size){ + if (endIdx < inp_size) { for (int j = startIdx; j < endIdx; ++j) { LOG(" - draft candidate %d: %d\n", j, inp[j]); draft.push_back(inp[j]); - llama_batch_add(batch_tgt, inp[j], n_past + j + 1, { 0 }, true); + llama_batch_add(batch_tgt, inp[j], n_past + (j - startIdx) + 1, { 0 }, true); ++n_drafted; } return; @@ -180,7 +191,7 @@ int main(int argc, char ** argv){ }; prompt_lookup(); - + llama_decode(ctx, batch_tgt); ++n_past; @@ -215,4 +226,4 @@ int main(int argc, char ** argv){ fprintf(stderr, "\n\n"); return 0; -} \ No newline at end of file +}