This commit is contained in:
Georgi Gerganov 2025-06-23 16:29:02 +03:00
parent 36f8e20d08
commit 7664390bc8
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
8 changed files with 76 additions and 50 deletions

View File

@ -33,6 +33,9 @@ llama_context::llama_context(
throw std::runtime_error("n_seq_max must be <= " + std::to_string(LLAMA_MAX_SEQ)); throw std::runtime_error("n_seq_max must be <= " + std::to_string(LLAMA_MAX_SEQ));
} }
const char * LLAMA_HT = getenv("LLAMA_HT");
cparams.n_seq_virt = LLAMA_HT ? cparams.n_seq_max : 1;
cparams.n_threads = params.n_threads; cparams.n_threads = params.n_threads;
cparams.n_threads_batch = params.n_threads_batch; cparams.n_threads_batch = params.n_threads_batch;
cparams.yarn_ext_factor = params.yarn_ext_factor; cparams.yarn_ext_factor = params.yarn_ext_factor;

View File

@ -11,8 +11,9 @@ struct llama_cparams {
uint32_t n_batch; uint32_t n_batch;
uint32_t n_ubatch; uint32_t n_ubatch;
uint32_t n_seq_max; uint32_t n_seq_max;
int n_threads; // number of threads to use for generation uint32_t n_seq_virt;
int n_threads_batch; // number of threads to use for batch processing int32_t n_threads; // number of threads to use for generation
int32_t n_threads_batch; // number of threads to use for batch processing
float rope_freq_base; float rope_freq_base;
float rope_freq_scale; float rope_freq_scale;

View File

@ -20,14 +20,15 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
bool swa_full, bool swa_full,
uint32_t kv_size, uint32_t kv_size,
uint32_t n_seq_max, uint32_t n_seq_max,
uint32_t n_seq_virt,
uint32_t n_ubatch, uint32_t n_ubatch,
uint32_t n_pad) : hparams(model.hparams) { uint32_t n_pad) : hparams(model.hparams), n_seq_virt(n_seq_virt) {
llama_kv_cache_unified::layer_filter_cb filter_base = [&](int32_t il) { return !model.hparams.is_swa(il); }; llama_kv_cache_unified::layer_filter_cb filter_base = [&](int32_t il) { return !model.hparams.is_swa(il); };
llama_kv_cache_unified::layer_filter_cb filter_swa = [&](int32_t il) { return model.hparams.is_swa(il); }; llama_kv_cache_unified::layer_filter_cb filter_swa = [&](int32_t il) { return model.hparams.is_swa(il); };
const uint32_t size_base = kv_size; const uint32_t size_base = kv_size;
uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*n_seq_max + n_ubatch, n_pad)); uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*(n_seq_max/n_seq_virt) + n_ubatch, n_pad));
// when using full-size SWA cache, we set the SWA cache size to be equal to the base cache size // when using full-size SWA cache, we set the SWA cache size to be equal to the base cache size
if (swa_full) { if (swa_full) {
@ -41,14 +42,14 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
kv_base = std::make_unique<llama_kv_cache_unified>( kv_base = std::make_unique<llama_kv_cache_unified>(
model, std::move(filter_base), type_k, type_v, model, std::move(filter_base), type_k, type_v,
v_trans, offload, size_base, n_seq_max, n_pad, v_trans, offload, size_base, n_seq_max, n_seq_virt, n_pad,
0, LLAMA_SWA_TYPE_NONE); 0, LLAMA_SWA_TYPE_NONE);
LLAMA_LOG_INFO("%s: creating SWA KV cache, size = %u cells\n", __func__, size_swa); LLAMA_LOG_INFO("%s: creating SWA KV cache, size = %u cells\n", __func__, size_swa);
kv_swa = std::make_unique<llama_kv_cache_unified>( kv_swa = std::make_unique<llama_kv_cache_unified>(
model, std::move(filter_swa), type_k, type_v, model, std::move(filter_swa), type_k, type_v,
v_trans, offload, size_swa, n_seq_max, n_pad, v_trans, offload, size_swa, n_seq_max, n_seq_virt, n_pad,
hparams.n_swa, hparams.swa_type); hparams.n_swa, hparams.swa_type);
} }
@ -100,6 +101,11 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
// first try simple split // first try simple split
do { do {
if (n_seq_virt > 1) {
// requires equal splits
break;
}
balloc.split_reset(); balloc.split_reset();
std::vector<llama_ubatch> ubatches; std::vector<llama_ubatch> ubatches;

View File

@ -22,6 +22,7 @@ public:
bool swa_full, bool swa_full,
uint32_t kv_size, uint32_t kv_size,
uint32_t n_seq_max, uint32_t n_seq_max,
uint32_t n_seq_virt,
uint32_t n_ubatch, uint32_t n_ubatch,
uint32_t n_pad); uint32_t n_pad);
@ -68,6 +69,8 @@ public:
private: private:
const llama_hparams & hparams; const llama_hparams & hparams;
const uint32_t n_seq_virt = 1;
std::unique_ptr<llama_kv_cache_unified> kv_base; std::unique_ptr<llama_kv_cache_unified> kv_base;
std::unique_ptr<llama_kv_cache_unified> kv_swa; std::unique_ptr<llama_kv_cache_unified> kv_swa;
}; };

View File

@ -25,11 +25,12 @@ llama_kv_cache_unified::llama_kv_cache_unified(
bool offload, bool offload,
uint32_t kv_size, uint32_t kv_size,
uint32_t n_seq_max, uint32_t n_seq_max,
uint32_t n_seq_virt,
uint32_t n_pad, uint32_t n_pad,
uint32_t n_swa, uint32_t n_swa,
llama_swa_type swa_type) : llama_swa_type swa_type) :
model(model), hparams(model.hparams), v_trans(v_trans), model(model), hparams(model.hparams), v_trans(v_trans),
n_seq_max(n_seq_max), n_pad(n_pad), n_swa(n_swa), swa_type(swa_type) { n_seq_max(n_seq_max), n_seq_virt(n_seq_virt), n_pad(n_pad), n_swa(n_swa), swa_type(swa_type) {
GGML_ASSERT(kv_size % n_pad == 0); GGML_ASSERT(kv_size % n_pad == 0);
@ -92,8 +93,8 @@ llama_kv_cache_unified::llama_kv_cache_unified(
ggml_tensor * k; ggml_tensor * k;
ggml_tensor * v; ggml_tensor * v;
k = ggml_new_tensor_2d(ctx, type_k, n_embd_k_gqa, kv_size); k = ggml_new_tensor_3d(ctx, type_k, n_embd_k_gqa, kv_size, n_seq_virt);
v = ggml_new_tensor_2d(ctx, type_v, n_embd_v_gqa, kv_size); v = ggml_new_tensor_3d(ctx, type_v, n_embd_v_gqa, kv_size, n_seq_virt);
ggml_format_name(k, "cache_k_l%d", il); ggml_format_name(k, "cache_k_l%d", il);
ggml_format_name(v, "cache_v_l%d", il); ggml_format_name(v, "cache_v_l%d", il);
@ -122,8 +123,8 @@ llama_kv_cache_unified::llama_kv_cache_unified(
const size_t memory_size_k = size_k_bytes(); const size_t memory_size_k = size_k_bytes();
const size_t memory_size_v = size_v_bytes(); const size_t memory_size_v = size_v_bytes();
LLAMA_LOG_INFO("%s: size = %7.2f MiB (%6u cells, %3d layers, %2u seqs), K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__, LLAMA_LOG_INFO("%s: size = %7.2f MiB (%6u cells, %3d layers, %2u/%2u seqs), K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
(float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f), kv_size, (int) layers.size(), n_seq_max, (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f), kv_size, (int) layers.size(), n_seq_max, n_seq_virt,
ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f), ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f)); ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
} }
@ -325,7 +326,7 @@ llama_memory_context_ptr llama_kv_cache_unified::init_batch(
std::vector<llama_ubatch> ubatches; std::vector<llama_ubatch> ubatches;
while (true) { while (true) {
auto ubatch = balloc.split_simple(n_ubatch); auto ubatch = n_seq_virt == 1 ? balloc.split_simple(n_ubatch) : balloc.split_equal(n_ubatch);
if (ubatch.n_tokens == 0) { if (ubatch.n_tokens == 0) {
break; break;
@ -525,6 +526,10 @@ bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const d
} }
llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch, bool cont) const { llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch, bool cont) const {
if (n_seq_virt > 1) {
GGML_ASSERT(!cont && "n_seq_virt > 1 does not support continuous slots");
}
const uint32_t n_tokens = ubatch.n_tokens; const uint32_t n_tokens = ubatch.n_tokens;
uint32_t head_cur = this->head; uint32_t head_cur = this->head;
@ -617,45 +622,51 @@ llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_
for (uint32_t i = 0; i < n_test; i++) { for (uint32_t i = 0; i < n_test; i++) {
const auto idx = head_cur; const auto idx = head_cur;
//const llama_pos pos = ubatch.pos[i];
//const llama_seq_id seq_id = ubatch.seq_id[i][0];
// can we use this cell? either:
// - the cell is empty
// - the cell is occupied only by one sequence:
// - (disabled) mask causally, if the sequence is the same as the one we are inserting
// - mask SWA, using current max pos for that sequence in the cache
// always insert in the cell with minimum pos
bool can_use = cells.is_empty(idx);
if (!can_use && cells.seq_count(idx) == 1) {
const llama_pos pos_cell = cells.pos_get(idx);
// (disabled) causal mask
// note: it's better to purge any "future" tokens beforehand
//if (cells.seq_has(idx, seq_id)) {
// can_use = pos_cell >= pos;
//}
if (!can_use) {
const llama_seq_id seq_id_cell = cells.seq_get(idx);
// SWA mask
if (is_masked_swa(pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) {
can_use = true;
}
}
}
head_cur++; head_cur++;
n_tested++; n_tested++;
if (can_use) { if (n_seq_virt == 1) {
res.idxs[n_found] = idx; //const llama_pos pos = ubatch.pos[i];
//const llama_seq_id seq_id = ubatch.seq_id[i][0];
n_found++; // can we use this cell? either:
// - the cell is empty
// - the cell is occupied only by one sequence:
// - (disabled) mask causally, if the sequence is the same as the one we are inserting
// - mask SWA, using current max pos for that sequence in the cache
// always insert in the cell with minimum pos
bool can_use = cells.is_empty(idx);
if (!can_use && cells.seq_count(idx) == 1) {
const llama_pos pos_cell = cells.pos_get(idx);
// (disabled) causal mask
// note: it's better to purge any "future" tokens beforehand
//if (cells.seq_has(idx, seq_id)) {
// can_use = pos_cell >= pos;
//}
if (!can_use) {
const llama_seq_id seq_id_cell = cells.seq_get(idx);
// SWA mask
if (is_masked_swa(pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) {
can_use = true;
}
}
}
if (can_use) {
res.idxs[n_found] = idx;
n_found++;
} else {
if (cont) {
break;
}
}
} else { } else {
break; GGML_ABORT("WIP");
} }
} }

View File

@ -52,9 +52,6 @@ public:
void clear() { void clear() {
idxs.clear(); idxs.clear();
} }
// TODO: implement
//std::vector<idx_vec_t> seq_idxs;
}; };
using slot_info_vec_t = std::vector<slot_info>; using slot_info_vec_t = std::vector<slot_info>;
@ -68,6 +65,7 @@ public:
bool offload, bool offload,
uint32_t kv_size, uint32_t kv_size,
uint32_t n_seq_max, uint32_t n_seq_max,
uint32_t n_seq_virt,
uint32_t n_pad, uint32_t n_pad,
uint32_t n_swa, uint32_t n_swa,
llama_swa_type swa_type); llama_swa_type swa_type);
@ -173,7 +171,8 @@ private:
// note: this is not part of the KV state and it's only used to speed-up the find_slot() method // note: this is not part of the KV state and it's only used to speed-up the find_slot() method
uint32_t head = 0; uint32_t head = 0;
const uint32_t n_seq_max = 1; const uint32_t n_seq_max = 1;
const uint32_t n_seq_virt = 1;
// required padding // required padding
const uint32_t n_pad = 1; const uint32_t n_pad = 1;

View File

@ -40,6 +40,7 @@ llama_memory_hybrid::llama_memory_hybrid(
offload, offload,
kv_size, kv_size,
n_seq_max, n_seq_max,
1,
n_pad, n_pad,
n_swa, n_swa,
swa_type swa_type

View File

@ -13814,6 +13814,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
params.swa_full, params.swa_full,
cparams.n_ctx, cparams.n_ctx,
cparams.n_seq_max, cparams.n_seq_max,
cparams.n_seq_virt,
cparams.n_ubatch, cparams.n_ubatch,
padding); padding);
} else { } else {
@ -13828,6 +13829,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
cparams.offload_kqv, cparams.offload_kqv,
cparams.n_ctx, cparams.n_ctx,
cparams.n_seq_max, cparams.n_seq_max,
cparams.n_seq_virt,
padding, padding,
hparams.n_swa, hparams.n_swa,
hparams.swa_type); hparams.swa_type);