diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp
index b3198b7e3a..0698f2210c 100644
--- a/src/llama-graph.cpp
+++ b/src/llama-graph.cpp
@@ -40,6 +40,33 @@ bool llm_graph_input_embd::can_reuse(const llm_graph_params & params) {
     return res;
 }
 
+void llm_graph_input_ngram_ids::set_input(const llama_ubatch * ubatch) {
+    GGML_ASSERT(!ubatch->embd);
+    GGML_ASSERT(ubatch->token);
+    const int64_t n_tokens = ubatch->n_tokens;
+
+    // each token has a context of ngram_k ids
+    std::vector<std::vector<llama_token>> ngrams;
+    ngrams.reserve(ubatch->n_tokens);
+    for (size_t i = 0; i < (size_t) n_tokens; ++i) {
+        auto ngram = mctx->get_last_n_tokens(ngram_n,
+            ubatch->pos[i],
+            ubatch->seq_id[i][0] /* FIXME: support multiple seq ids */);
+
+        printf("token[%zu] = %d : ngram =", i, ubatch->token[i]);
+        for (size_t j = 0; j < ngram.size(); ++j) {
+            printf(" %d", ngram[j]);
+        }
+        printf("\n");
+        ngrams.push_back(std::move(ngram));
+    }
+
+    if (ubatch->pos) { exit(1); } // TEST ONLY
+
+    if (ubatch->pos && pos_ngram) {
+    }
+}
+
 void llm_graph_input_pos::set_input(const llama_ubatch * ubatch) {
     if (ubatch->pos && pos) {
         const int64_t n_tokens = ubatch->n_tokens;
@@ -1471,6 +1498,15 @@ ggml_tensor * llm_graph_context::build_inp_attn_scale() const {
     return cur;
 }
 
+ggml_tensor * llm_graph_context::build_inp_ngram_ids() const {
+    const auto * mctx_cur = static_cast<const llama_kv_cache_context *>(mctx);
+
+    auto inp = std::make_unique<llm_graph_input_ngram_ids>(4, 4, mctx_cur);
+    res->add_input(std::move(inp));
+
+    return nullptr; // TODO
+}
+
 ggml_tensor * llm_graph_context::build_inp_out_ids() const {
     // note: when all tokens are output, we could skip this optimization to spare the ggml_get_rows() calls,
     //       but this would make the graph topology depend on the number of output tokens, which can interere with
diff --git a/src/llama-graph.h b/src/llama-graph.h
index 4090d8116c..1a2b5fa254 100644
--- a/src/llama-graph.h
+++ b/src/llama-graph.h
@@ -119,6 +119,21 @@ public:
     const int64_t n_embd = 0;
 };
 
+class llm_graph_input_ngram_ids : public llm_graph_input_i {
+public:
+    llm_graph_input_ngram_ids(uint32_t ngram_n, uint32_t ngram_k, const llama_kv_cache_context * mctx)
+        : ngram_n(ngram_n), ngram_k(ngram_k), mctx(mctx) {}
+    virtual ~llm_graph_input_ngram_ids() = default;
+
+    void set_input(const llama_ubatch * ubatch) override;
+
+    ggml_tensor * pos_ngram = nullptr; // I32 [n_batch, ngram_k]
+
+    uint32_t ngram_n = 0;
+    uint32_t ngram_k = 0;
+    const llama_kv_cache_context * mctx;
+};
+
 class llm_graph_input_pos : public llm_graph_input_i {
 public:
     llm_graph_input_pos(uint32_t n_pos_per_embd) : n_pos_per_embd(n_pos_per_embd) {}
@@ -816,6 +831,7 @@ struct llm_graph_context {
     ggml_tensor * build_inp_embd(ggml_tensor * tok_embd) const;
     ggml_tensor * build_inp_pos() const;
     ggml_tensor * build_inp_attn_scale() const;
+    ggml_tensor * build_inp_ngram_ids() const;
     ggml_tensor * build_inp_out_ids() const;
     ggml_tensor * build_inp_mean() const;
     ggml_tensor * build_inp_cls() const;
diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp
index f3c9b49f30..1bbe2ed99d 100644
--- a/src/llama-kv-cache.cpp
+++ b/src/llama-kv-cache.cpp
@@ -933,8 +933,20 @@ void llama_kv_cache::apply_ubatch(const slot_info & sinfo, const llama_ubatch &
 
             if (ubatch.is_pos_2d()) {
                 llama_kv_cell_ext ext {
-                    /*.x =*/ ubatch.pos[i + ubatch.n_tokens*2],
-                    /*.y =*/ ubatch.pos[i + ubatch.n_tokens],
+                    /*.x =*/ ubatch.pos[i + ubatch.n_tokens*2],
+                    /*.y =*/ ubatch.pos[i + ubatch.n_tokens],
+                    /*.id =*/ 0, // unused
+                };
+                cells.ext_set(idx, ext);
+            }
+
+            if (ubatch.token) {
+                // save token id for ngram embeddings
+                GGML_ASSERT(!ubatch.embd);
+                llama_kv_cell_ext ext {
+                    /*.x =*/ 0, // unused
+                    /*.y =*/ 0, // unused
+                    /*.id =*/ ubatch.token[i],
                 };
                 cells.ext_set(idx, ext);
             }
@@ -1500,6 +1512,40 @@ void llama_kv_cache::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch
     }
 }
 
+std::vector<llama_token> llama_kv_cache::get_last_n_tokens(size_t n, llama_pos pos, llama_seq_id seq_id) const {
+    std::vector<llama_token> result;
+    result.resize(n, 0);
+
+    for (uint32_t s = 0; s < n_stream; ++s) {
+        const auto & cell = v_cells[s];
+
+        // TODO: linear scan is inefficient, optimize this later
+        for (uint32_t i = 0; i < cell.size(); ++i) {
+            if (!cell.seq_has(i, seq_id)) {
+                continue;
+            }
+
+            const llama_pos p = cell.pos_get(i);
+            const llama_token tok = cell.ext_get(i).id;
+
+            // check distance: (pos - n) <= p < pos
+            if (pos - (llama_pos) n <= p && p < pos) {
+                // make sure last token goes last
+                size_t insert_pos = n - (size_t)(pos - p);
+                // this assert should mathematically hold, but added for clarity
+                GGML_ASSERT(insert_pos < n);
+                result[insert_pos] = tok;
+            }
+        }
+
+        if (result.size() >= n) {
+            break;
+        }
+    }
+
+    return result;
+}
+
 size_t llama_kv_cache::total_size() const {
     size_t size = 0;
 
@@ -2264,3 +2310,7 @@ void llama_kv_cache_context::set_input_kq_mask(ggml_tensor * dst, const llama_ub
 void llama_kv_cache_context::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const {
     kv->set_input_pos_bucket(dst, ubatch);
 }
+
+std::vector<llama_token> llama_kv_cache_context::get_last_n_tokens(size_t n, llama_pos pos, llama_seq_id seq_id) const {
+    return kv->get_last_n_tokens(n, pos, seq_id);
+}
diff --git a/src/llama-kv-cache.h b/src/llama-kv-cache.h
index e194bf3e26..925f6eada0 100644
--- a/src/llama-kv-cache.h
+++ b/src/llama-kv-cache.h
@@ -199,6 +199,10 @@ public:
     void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
     void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;
 
+    // used by ngram embeddings
+    // output order: token with higher pos goes last
+    std::vector<llama_token> get_last_n_tokens(size_t n, llama_pos pos, llama_seq_id seq_id) const;
+
 private:
     const llama_model & model;
     const llama_hparams & hparams;
@@ -353,6 +357,9 @@ public:
     void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
     void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;
 
+    // used by ngram embeddings
+    std::vector<llama_token> get_last_n_tokens(size_t n, llama_pos pos, llama_seq_id seq_id) const;
+
 private:
     llama_memory_status status;
 
diff --git a/src/llama-kv-cells.h b/src/llama-kv-cells.h
index 10063bf427..dac68909d8 100644
--- a/src/llama-kv-cells.h
+++ b/src/llama-kv-cells.h
@@ -15,6 +15,10 @@ struct llama_kv_cell_ext {
     llama_pos x = 0;
     llama_pos y = 0;
 
+    // token ID, used by ngram embeddings
+    // currently defaults to 0, following the longcat-ngram implementation
+    llama_token id = 0;
+
     // return true if the current 2D spatial position is greater than other
     bool is_2d_gt(llama_pos ox, llama_pos oy) const {
         return (y > oy) || (y == oy && x > ox);
diff --git a/src/models/llama.cpp b/src/models/llama.cpp
index 42b5fcdf42..986b83bee5 100644
--- a/src/models/llama.cpp
+++ b/src/models/llama.cpp
@@ -24,6 +24,9 @@ llm_build_llama::llm_build_llama(const llama_model & model, const llm_gra
         inp_attn = build_attn_inp_kv();
     }
 
+    // TEST ONLY
+    build_inp_ngram_ids();
+
     const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
 
     ggml_tensor * inp_out_ids = build_inp_out_ids();
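Note on the lookup order: get_last_n_tokens(n, pos, seq_id) returns a fixed-size window of n token ids drawn from positions [pos - n, pos), leaving unfilled slots at 0 and placing the most recent token last via insert_pos = n - (pos - p). The following standalone sketch reproduces only that window/index math outside of llama.cpp; the last_n_tokens helper and the flat cache vector are hypothetical stand-ins for the per-stream cell scan, and the llama_token/llama_pos aliases are assumed to match the int32_t typedefs in llama.h.

#include <cstdint>
#include <cstdio>
#include <vector>

using llama_token = int32_t; // assumed: matches the typedef in llama.h
using llama_pos   = int32_t; // assumed: matches the typedef in llama.h

// Hypothetical flat "cache": cache[p] holds the token id stored at position p.
// Mirrors only the windowing math of llama_kv_cache::get_last_n_tokens,
// not its per-stream, per-sequence cell scan.
static std::vector<llama_token> last_n_tokens(const std::vector<llama_token> & cache, size_t n, llama_pos pos) {
    std::vector<llama_token> result(n, 0); // unfilled slots stay 0, as in the patch

    for (llama_pos p = 0; p < (llama_pos) cache.size(); ++p) {
        // keep tokens with (pos - n) <= p < pos
        if (pos - (llama_pos) n <= p && p < pos) {
            const size_t insert_pos = n - (size_t)(pos - p); // most recent token goes last
            result[insert_pos] = cache[p];
        }
    }

    return result;
}

int main() {
    // positions:                          0   1   2   3   4
    const std::vector<llama_token> cache = {11, 22, 33, 44, 55};

    for (llama_token t : last_n_tokens(cache, 4, 3)) {
        printf(" %d", t); // prints: 0 11 22 33 (left zero-padding, most recent token last)
    }
    printf("\n");

    for (llama_token t : last_n_tokens(cache, 3, 5)) {
        printf(" %d", t); // prints: 33 44 55
    }
    printf("\n");

    return 0;
}

Zero is also the default of the new llama_kv_cell_ext::id field in the patch, so a padded slot and a cell with no recorded token id are currently indistinguishable in the returned window.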