wip
This commit is contained in:
parent
36f8e20d08
commit
7664390bc8
|
|
@ -33,6 +33,9 @@ llama_context::llama_context(
|
|||
throw std::runtime_error("n_seq_max must be <= " + std::to_string(LLAMA_MAX_SEQ));
|
||||
}
|
||||
|
||||
const char * LLAMA_HT = getenv("LLAMA_HT");
|
||||
cparams.n_seq_virt = LLAMA_HT ? cparams.n_seq_max : 1;
|
||||
|
||||
cparams.n_threads = params.n_threads;
|
||||
cparams.n_threads_batch = params.n_threads_batch;
|
||||
cparams.yarn_ext_factor = params.yarn_ext_factor;
|
||||
|
|
|
|||
|
|
@ -11,8 +11,9 @@ struct llama_cparams {
|
|||
uint32_t n_batch;
|
||||
uint32_t n_ubatch;
|
||||
uint32_t n_seq_max;
|
||||
int n_threads; // number of threads to use for generation
|
||||
int n_threads_batch; // number of threads to use for batch processing
|
||||
uint32_t n_seq_virt;
|
||||
int32_t n_threads; // number of threads to use for generation
|
||||
int32_t n_threads_batch; // number of threads to use for batch processing
|
||||
|
||||
float rope_freq_base;
|
||||
float rope_freq_scale;
|
||||
|
|
|
|||
|
|
@ -20,14 +20,15 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
|
|||
bool swa_full,
|
||||
uint32_t kv_size,
|
||||
uint32_t n_seq_max,
|
||||
uint32_t n_seq_virt,
|
||||
uint32_t n_ubatch,
|
||||
uint32_t n_pad) : hparams(model.hparams) {
|
||||
uint32_t n_pad) : hparams(model.hparams), n_seq_virt(n_seq_virt) {
|
||||
llama_kv_cache_unified::layer_filter_cb filter_base = [&](int32_t il) { return !model.hparams.is_swa(il); };
|
||||
llama_kv_cache_unified::layer_filter_cb filter_swa = [&](int32_t il) { return model.hparams.is_swa(il); };
|
||||
|
||||
const uint32_t size_base = kv_size;
|
||||
|
||||
uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*n_seq_max + n_ubatch, n_pad));
|
||||
uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*(n_seq_max/n_seq_virt) + n_ubatch, n_pad));
|
||||
|
||||
// when using full-size SWA cache, we set the SWA cache size to be equal to the base cache size
|
||||
if (swa_full) {
|
||||
|
|
@ -41,14 +42,14 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
|
|||
|
||||
kv_base = std::make_unique<llama_kv_cache_unified>(
|
||||
model, std::move(filter_base), type_k, type_v,
|
||||
v_trans, offload, size_base, n_seq_max, n_pad,
|
||||
v_trans, offload, size_base, n_seq_max, n_seq_virt, n_pad,
|
||||
0, LLAMA_SWA_TYPE_NONE);
|
||||
|
||||
LLAMA_LOG_INFO("%s: creating SWA KV cache, size = %u cells\n", __func__, size_swa);
|
||||
|
||||
kv_swa = std::make_unique<llama_kv_cache_unified>(
|
||||
model, std::move(filter_swa), type_k, type_v,
|
||||
v_trans, offload, size_swa, n_seq_max, n_pad,
|
||||
v_trans, offload, size_swa, n_seq_max, n_seq_virt, n_pad,
|
||||
hparams.n_swa, hparams.swa_type);
|
||||
}
|
||||
|
||||
|
|
@ -100,6 +101,11 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
|
|||
|
||||
// first try simple split
|
||||
do {
|
||||
if (n_seq_virt > 1) {
|
||||
// requires equal splits
|
||||
break;
|
||||
}
|
||||
|
||||
balloc.split_reset();
|
||||
|
||||
std::vector<llama_ubatch> ubatches;
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@ public:
|
|||
bool swa_full,
|
||||
uint32_t kv_size,
|
||||
uint32_t n_seq_max,
|
||||
uint32_t n_seq_virt,
|
||||
uint32_t n_ubatch,
|
||||
uint32_t n_pad);
|
||||
|
||||
|
|
@ -68,6 +69,8 @@ public:
|
|||
private:
|
||||
const llama_hparams & hparams;
|
||||
|
||||
const uint32_t n_seq_virt = 1;
|
||||
|
||||
std::unique_ptr<llama_kv_cache_unified> kv_base;
|
||||
std::unique_ptr<llama_kv_cache_unified> kv_swa;
|
||||
};
|
||||
|
|
|
|||
|
|
@ -25,11 +25,12 @@ llama_kv_cache_unified::llama_kv_cache_unified(
|
|||
bool offload,
|
||||
uint32_t kv_size,
|
||||
uint32_t n_seq_max,
|
||||
uint32_t n_seq_virt,
|
||||
uint32_t n_pad,
|
||||
uint32_t n_swa,
|
||||
llama_swa_type swa_type) :
|
||||
model(model), hparams(model.hparams), v_trans(v_trans),
|
||||
n_seq_max(n_seq_max), n_pad(n_pad), n_swa(n_swa), swa_type(swa_type) {
|
||||
n_seq_max(n_seq_max), n_seq_virt(n_seq_virt), n_pad(n_pad), n_swa(n_swa), swa_type(swa_type) {
|
||||
|
||||
GGML_ASSERT(kv_size % n_pad == 0);
|
||||
|
||||
|
|
@ -92,8 +93,8 @@ llama_kv_cache_unified::llama_kv_cache_unified(
|
|||
ggml_tensor * k;
|
||||
ggml_tensor * v;
|
||||
|
||||
k = ggml_new_tensor_2d(ctx, type_k, n_embd_k_gqa, kv_size);
|
||||
v = ggml_new_tensor_2d(ctx, type_v, n_embd_v_gqa, kv_size);
|
||||
k = ggml_new_tensor_3d(ctx, type_k, n_embd_k_gqa, kv_size, n_seq_virt);
|
||||
v = ggml_new_tensor_3d(ctx, type_v, n_embd_v_gqa, kv_size, n_seq_virt);
|
||||
|
||||
ggml_format_name(k, "cache_k_l%d", il);
|
||||
ggml_format_name(v, "cache_v_l%d", il);
|
||||
|
|
@ -122,8 +123,8 @@ llama_kv_cache_unified::llama_kv_cache_unified(
|
|||
const size_t memory_size_k = size_k_bytes();
|
||||
const size_t memory_size_v = size_v_bytes();
|
||||
|
||||
LLAMA_LOG_INFO("%s: size = %7.2f MiB (%6u cells, %3d layers, %2u seqs), K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
|
||||
(float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f), kv_size, (int) layers.size(), n_seq_max,
|
||||
LLAMA_LOG_INFO("%s: size = %7.2f MiB (%6u cells, %3d layers, %2u/%2u seqs), K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
|
||||
(float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f), kv_size, (int) layers.size(), n_seq_max, n_seq_virt,
|
||||
ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
|
||||
ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
|
||||
}
|
||||
|
|
@ -325,7 +326,7 @@ llama_memory_context_ptr llama_kv_cache_unified::init_batch(
|
|||
|
||||
std::vector<llama_ubatch> ubatches;
|
||||
while (true) {
|
||||
auto ubatch = balloc.split_simple(n_ubatch);
|
||||
auto ubatch = n_seq_virt == 1 ? balloc.split_simple(n_ubatch) : balloc.split_equal(n_ubatch);
|
||||
|
||||
if (ubatch.n_tokens == 0) {
|
||||
break;
|
||||
|
|
@ -525,6 +526,10 @@ bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const d
|
|||
}
|
||||
|
||||
llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch, bool cont) const {
|
||||
if (n_seq_virt > 1) {
|
||||
GGML_ASSERT(!cont && "n_seq_virt > 1 does not support continuous slots");
|
||||
}
|
||||
|
||||
const uint32_t n_tokens = ubatch.n_tokens;
|
||||
|
||||
uint32_t head_cur = this->head;
|
||||
|
|
@ -617,45 +622,51 @@ llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_
|
|||
for (uint32_t i = 0; i < n_test; i++) {
|
||||
const auto idx = head_cur;
|
||||
|
||||
//const llama_pos pos = ubatch.pos[i];
|
||||
//const llama_seq_id seq_id = ubatch.seq_id[i][0];
|
||||
|
||||
// can we use this cell? either:
|
||||
// - the cell is empty
|
||||
// - the cell is occupied only by one sequence:
|
||||
// - (disabled) mask causally, if the sequence is the same as the one we are inserting
|
||||
// - mask SWA, using current max pos for that sequence in the cache
|
||||
// always insert in the cell with minimum pos
|
||||
bool can_use = cells.is_empty(idx);
|
||||
|
||||
if (!can_use && cells.seq_count(idx) == 1) {
|
||||
const llama_pos pos_cell = cells.pos_get(idx);
|
||||
|
||||
// (disabled) causal mask
|
||||
// note: it's better to purge any "future" tokens beforehand
|
||||
//if (cells.seq_has(idx, seq_id)) {
|
||||
// can_use = pos_cell >= pos;
|
||||
//}
|
||||
|
||||
if (!can_use) {
|
||||
const llama_seq_id seq_id_cell = cells.seq_get(idx);
|
||||
|
||||
// SWA mask
|
||||
if (is_masked_swa(pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) {
|
||||
can_use = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
head_cur++;
|
||||
n_tested++;
|
||||
|
||||
if (can_use) {
|
||||
res.idxs[n_found] = idx;
|
||||
if (n_seq_virt == 1) {
|
||||
//const llama_pos pos = ubatch.pos[i];
|
||||
//const llama_seq_id seq_id = ubatch.seq_id[i][0];
|
||||
|
||||
n_found++;
|
||||
// can we use this cell? either:
|
||||
// - the cell is empty
|
||||
// - the cell is occupied only by one sequence:
|
||||
// - (disabled) mask causally, if the sequence is the same as the one we are inserting
|
||||
// - mask SWA, using current max pos for that sequence in the cache
|
||||
// always insert in the cell with minimum pos
|
||||
bool can_use = cells.is_empty(idx);
|
||||
|
||||
if (!can_use && cells.seq_count(idx) == 1) {
|
||||
const llama_pos pos_cell = cells.pos_get(idx);
|
||||
|
||||
// (disabled) causal mask
|
||||
// note: it's better to purge any "future" tokens beforehand
|
||||
//if (cells.seq_has(idx, seq_id)) {
|
||||
// can_use = pos_cell >= pos;
|
||||
//}
|
||||
|
||||
if (!can_use) {
|
||||
const llama_seq_id seq_id_cell = cells.seq_get(idx);
|
||||
|
||||
// SWA mask
|
||||
if (is_masked_swa(pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) {
|
||||
can_use = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (can_use) {
|
||||
res.idxs[n_found] = idx;
|
||||
|
||||
n_found++;
|
||||
} else {
|
||||
if (cont) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
break;
|
||||
GGML_ABORT("WIP");
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -52,9 +52,6 @@ public:
|
|||
void clear() {
|
||||
idxs.clear();
|
||||
}
|
||||
|
||||
// TODO: implement
|
||||
//std::vector<idx_vec_t> seq_idxs;
|
||||
};
|
||||
|
||||
using slot_info_vec_t = std::vector<slot_info>;
|
||||
|
|
@ -68,6 +65,7 @@ public:
|
|||
bool offload,
|
||||
uint32_t kv_size,
|
||||
uint32_t n_seq_max,
|
||||
uint32_t n_seq_virt,
|
||||
uint32_t n_pad,
|
||||
uint32_t n_swa,
|
||||
llama_swa_type swa_type);
|
||||
|
|
@ -173,7 +171,8 @@ private:
|
|||
// note: this is not part of the KV state and it's only used to speed-up the find_slot() method
|
||||
uint32_t head = 0;
|
||||
|
||||
const uint32_t n_seq_max = 1;
|
||||
const uint32_t n_seq_max = 1;
|
||||
const uint32_t n_seq_virt = 1;
|
||||
|
||||
// required padding
|
||||
const uint32_t n_pad = 1;
|
||||
|
|
|
|||
|
|
@ -40,6 +40,7 @@ llama_memory_hybrid::llama_memory_hybrid(
|
|||
offload,
|
||||
kv_size,
|
||||
n_seq_max,
|
||||
1,
|
||||
n_pad,
|
||||
n_swa,
|
||||
swa_type
|
||||
|
|
|
|||
|
|
@ -13814,6 +13814,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
|
|||
params.swa_full,
|
||||
cparams.n_ctx,
|
||||
cparams.n_seq_max,
|
||||
cparams.n_seq_virt,
|
||||
cparams.n_ubatch,
|
||||
padding);
|
||||
} else {
|
||||
|
|
@ -13828,6 +13829,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
|
|||
cparams.offload_kqv,
|
||||
cparams.n_ctx,
|
||||
cparams.n_seq_max,
|
||||
cparams.n_seq_virt,
|
||||
padding,
|
||||
hparams.n_swa,
|
||||
hparams.swa_type);
|
||||
|
|
|
|||
Loading…
Reference in New Issue