From 7664390bc8682498770e8807ee2818d1d58eff99 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 23 Jun 2025 16:29:02 +0300 Subject: [PATCH] wip --- src/llama-context.cpp | 3 + src/llama-cparams.h | 5 +- src/llama-kv-cache-unified-iswa.cpp | 14 +++-- src/llama-kv-cache-unified-iswa.h | 3 + src/llama-kv-cache-unified.cpp | 91 ++++++++++++++++------------- src/llama-kv-cache-unified.h | 7 +-- src/llama-memory-hybrid.cpp | 1 + src/llama-model.cpp | 2 + 8 files changed, 76 insertions(+), 50 deletions(-) diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 06e93b19cb..96c8d817cc 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -33,6 +33,9 @@ llama_context::llama_context( throw std::runtime_error("n_seq_max must be <= " + std::to_string(LLAMA_MAX_SEQ)); } + const char * LLAMA_HT = getenv("LLAMA_HT"); + cparams.n_seq_virt = LLAMA_HT ? cparams.n_seq_max : 1; + cparams.n_threads = params.n_threads; cparams.n_threads_batch = params.n_threads_batch; cparams.yarn_ext_factor = params.yarn_ext_factor; diff --git a/src/llama-cparams.h b/src/llama-cparams.h index 118615d5bd..c746337067 100644 --- a/src/llama-cparams.h +++ b/src/llama-cparams.h @@ -11,8 +11,9 @@ struct llama_cparams { uint32_t n_batch; uint32_t n_ubatch; uint32_t n_seq_max; - int n_threads; // number of threads to use for generation - int n_threads_batch; // number of threads to use for batch processing + uint32_t n_seq_virt; + int32_t n_threads; // number of threads to use for generation + int32_t n_threads_batch; // number of threads to use for batch processing float rope_freq_base; float rope_freq_scale; diff --git a/src/llama-kv-cache-unified-iswa.cpp b/src/llama-kv-cache-unified-iswa.cpp index 2c31cc4622..f0aac929c1 100644 --- a/src/llama-kv-cache-unified-iswa.cpp +++ b/src/llama-kv-cache-unified-iswa.cpp @@ -20,14 +20,15 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa( bool swa_full, uint32_t kv_size, uint32_t n_seq_max, + uint32_t n_seq_virt, uint32_t n_ubatch, - uint32_t n_pad) : hparams(model.hparams) { + uint32_t n_pad) : hparams(model.hparams), n_seq_virt(n_seq_virt) { llama_kv_cache_unified::layer_filter_cb filter_base = [&](int32_t il) { return !model.hparams.is_swa(il); }; llama_kv_cache_unified::layer_filter_cb filter_swa = [&](int32_t il) { return model.hparams.is_swa(il); }; const uint32_t size_base = kv_size; - uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*n_seq_max + n_ubatch, n_pad)); + uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*(n_seq_max/n_seq_virt) + n_ubatch, n_pad)); // when using full-size SWA cache, we set the SWA cache size to be equal to the base cache size if (swa_full) { @@ -41,14 +42,14 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa( kv_base = std::make_unique( model, std::move(filter_base), type_k, type_v, - v_trans, offload, size_base, n_seq_max, n_pad, + v_trans, offload, size_base, n_seq_max, n_seq_virt, n_pad, 0, LLAMA_SWA_TYPE_NONE); LLAMA_LOG_INFO("%s: creating SWA KV cache, size = %u cells\n", __func__, size_swa); kv_swa = std::make_unique( model, std::move(filter_swa), type_k, type_v, - v_trans, offload, size_swa, n_seq_max, n_pad, + v_trans, offload, size_swa, n_seq_max, n_seq_virt, n_pad, hparams.n_swa, hparams.swa_type); } @@ -100,6 +101,11 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all // first try simple split do { + if (n_seq_virt > 1) { + // requires equal splits + break; + } + balloc.split_reset(); std::vector ubatches; diff --git a/src/llama-kv-cache-unified-iswa.h b/src/llama-kv-cache-unified-iswa.h index 23205d826b..8fbc5bab29 100644 --- a/src/llama-kv-cache-unified-iswa.h +++ b/src/llama-kv-cache-unified-iswa.h @@ -22,6 +22,7 @@ public: bool swa_full, uint32_t kv_size, uint32_t n_seq_max, + uint32_t n_seq_virt, uint32_t n_ubatch, uint32_t n_pad); @@ -68,6 +69,8 @@ public: private: const llama_hparams & hparams; + const uint32_t n_seq_virt = 1; + std::unique_ptr kv_base; std::unique_ptr kv_swa; }; diff --git a/src/llama-kv-cache-unified.cpp b/src/llama-kv-cache-unified.cpp index f2bc035977..26bdd390b6 100644 --- a/src/llama-kv-cache-unified.cpp +++ b/src/llama-kv-cache-unified.cpp @@ -25,11 +25,12 @@ llama_kv_cache_unified::llama_kv_cache_unified( bool offload, uint32_t kv_size, uint32_t n_seq_max, + uint32_t n_seq_virt, uint32_t n_pad, uint32_t n_swa, llama_swa_type swa_type) : model(model), hparams(model.hparams), v_trans(v_trans), - n_seq_max(n_seq_max), n_pad(n_pad), n_swa(n_swa), swa_type(swa_type) { + n_seq_max(n_seq_max), n_seq_virt(n_seq_virt), n_pad(n_pad), n_swa(n_swa), swa_type(swa_type) { GGML_ASSERT(kv_size % n_pad == 0); @@ -92,8 +93,8 @@ llama_kv_cache_unified::llama_kv_cache_unified( ggml_tensor * k; ggml_tensor * v; - k = ggml_new_tensor_2d(ctx, type_k, n_embd_k_gqa, kv_size); - v = ggml_new_tensor_2d(ctx, type_v, n_embd_v_gqa, kv_size); + k = ggml_new_tensor_3d(ctx, type_k, n_embd_k_gqa, kv_size, n_seq_virt); + v = ggml_new_tensor_3d(ctx, type_v, n_embd_v_gqa, kv_size, n_seq_virt); ggml_format_name(k, "cache_k_l%d", il); ggml_format_name(v, "cache_v_l%d", il); @@ -122,8 +123,8 @@ llama_kv_cache_unified::llama_kv_cache_unified( const size_t memory_size_k = size_k_bytes(); const size_t memory_size_v = size_v_bytes(); - LLAMA_LOG_INFO("%s: size = %7.2f MiB (%6u cells, %3d layers, %2u seqs), K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__, - (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f), kv_size, (int) layers.size(), n_seq_max, + LLAMA_LOG_INFO("%s: size = %7.2f MiB (%6u cells, %3d layers, %2u/%2u seqs), K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__, + (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f), kv_size, (int) layers.size(), n_seq_max, n_seq_virt, ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f), ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f)); } @@ -325,7 +326,7 @@ llama_memory_context_ptr llama_kv_cache_unified::init_batch( std::vector ubatches; while (true) { - auto ubatch = balloc.split_simple(n_ubatch); + auto ubatch = n_seq_virt == 1 ? balloc.split_simple(n_ubatch) : balloc.split_equal(n_ubatch); if (ubatch.n_tokens == 0) { break; @@ -525,6 +526,10 @@ bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const d } llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch, bool cont) const { + if (n_seq_virt > 1) { + GGML_ASSERT(!cont && "n_seq_virt > 1 does not support continuous slots"); + } + const uint32_t n_tokens = ubatch.n_tokens; uint32_t head_cur = this->head; @@ -617,45 +622,51 @@ llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_ for (uint32_t i = 0; i < n_test; i++) { const auto idx = head_cur; - //const llama_pos pos = ubatch.pos[i]; - //const llama_seq_id seq_id = ubatch.seq_id[i][0]; - - // can we use this cell? either: - // - the cell is empty - // - the cell is occupied only by one sequence: - // - (disabled) mask causally, if the sequence is the same as the one we are inserting - // - mask SWA, using current max pos for that sequence in the cache - // always insert in the cell with minimum pos - bool can_use = cells.is_empty(idx); - - if (!can_use && cells.seq_count(idx) == 1) { - const llama_pos pos_cell = cells.pos_get(idx); - - // (disabled) causal mask - // note: it's better to purge any "future" tokens beforehand - //if (cells.seq_has(idx, seq_id)) { - // can_use = pos_cell >= pos; - //} - - if (!can_use) { - const llama_seq_id seq_id_cell = cells.seq_get(idx); - - // SWA mask - if (is_masked_swa(pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) { - can_use = true; - } - } - } - head_cur++; n_tested++; - if (can_use) { - res.idxs[n_found] = idx; + if (n_seq_virt == 1) { + //const llama_pos pos = ubatch.pos[i]; + //const llama_seq_id seq_id = ubatch.seq_id[i][0]; - n_found++; + // can we use this cell? either: + // - the cell is empty + // - the cell is occupied only by one sequence: + // - (disabled) mask causally, if the sequence is the same as the one we are inserting + // - mask SWA, using current max pos for that sequence in the cache + // always insert in the cell with minimum pos + bool can_use = cells.is_empty(idx); + + if (!can_use && cells.seq_count(idx) == 1) { + const llama_pos pos_cell = cells.pos_get(idx); + + // (disabled) causal mask + // note: it's better to purge any "future" tokens beforehand + //if (cells.seq_has(idx, seq_id)) { + // can_use = pos_cell >= pos; + //} + + if (!can_use) { + const llama_seq_id seq_id_cell = cells.seq_get(idx); + + // SWA mask + if (is_masked_swa(pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) { + can_use = true; + } + } + } + + if (can_use) { + res.idxs[n_found] = idx; + + n_found++; + } else { + if (cont) { + break; + } + } } else { - break; + GGML_ABORT("WIP"); } } diff --git a/src/llama-kv-cache-unified.h b/src/llama-kv-cache-unified.h index 1f72afd21c..2d361549fe 100644 --- a/src/llama-kv-cache-unified.h +++ b/src/llama-kv-cache-unified.h @@ -52,9 +52,6 @@ public: void clear() { idxs.clear(); } - - // TODO: implement - //std::vector seq_idxs; }; using slot_info_vec_t = std::vector; @@ -68,6 +65,7 @@ public: bool offload, uint32_t kv_size, uint32_t n_seq_max, + uint32_t n_seq_virt, uint32_t n_pad, uint32_t n_swa, llama_swa_type swa_type); @@ -173,7 +171,8 @@ private: // note: this is not part of the KV state and it's only used to speed-up the find_slot() method uint32_t head = 0; - const uint32_t n_seq_max = 1; + const uint32_t n_seq_max = 1; + const uint32_t n_seq_virt = 1; // required padding const uint32_t n_pad = 1; diff --git a/src/llama-memory-hybrid.cpp b/src/llama-memory-hybrid.cpp index 0333b8b232..e8d3b581ae 100644 --- a/src/llama-memory-hybrid.cpp +++ b/src/llama-memory-hybrid.cpp @@ -40,6 +40,7 @@ llama_memory_hybrid::llama_memory_hybrid( offload, kv_size, n_seq_max, + 1, n_pad, n_swa, swa_type diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 9b19da9840..01b4266c7f 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -13814,6 +13814,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, params.swa_full, cparams.n_ctx, cparams.n_seq_max, + cparams.n_seq_virt, cparams.n_ubatch, padding); } else { @@ -13828,6 +13829,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, cparams.offload_kqv, cparams.n_ctx, cparams.n_seq_max, + cparams.n_seq_virt, padding, hparams.n_swa, hparams.swa_type);