parent 30b4d4e1b3
commit dfceb012ee
@@ -235,7 +235,7 @@ int main(int argc, char ** argv) {
 
 // the max batch size is as large as the context to handle cases where we get very long input prompt from multiple
 // users. regardless of the size, the main loop will chunk the batch into a maximum of params.n_batch tokens at a time
-llama_batch batch = llama_batch_init(n_ctx, 0, 1);
+llama_batch batch = llama_batch_init(n_ctx*n_clients, 0, 1);
 
 int32_t n_total_prompt = 0;
 int32_t n_total_gen = 0;
@@ -289,6 +289,7 @@ int main(int argc, char ** argv) {
 // all sequences have ended - clear the entire KV cache
 for (int i = 1; i <= n_clients; ++i) {
 llama_memory_seq_rm(mem, i, -1, -1);
+
 // but keep the system prompt
 llama_memory_seq_cp(mem, 0, i, -1, -1);
 }
@@ -4696,7 +4696,6 @@ struct ggml_tensor * ggml_flash_attn_ext(
 
 if (mask) {
 GGML_ASSERT(ggml_is_contiguous(mask));
-GGML_ASSERT(mask->ne[2] == q->ne[3]);
 GGML_ASSERT(mask->ne[1] >= GGML_PAD(q->ne[1], GGML_KQ_MASK_PAD) &&
 "the Flash-Attention kernel requires the mask to be padded to GGML_KQ_MASK_PAD and at least n_queries big");
 //GGML_ASSERT(ggml_can_repeat_rows(mask, qk));
@@ -166,6 +166,8 @@ bool llama_batch_allocr::init(
 
 // note: tracking the other way around is not necessary for now
 //seq_cpl[s0][s1] = true;
+
+has_cpl = true;
 }
 }
 }
@@ -405,6 +407,10 @@ uint32_t llama_batch_allocr::get_n_outputs() const {
 return n_outputs;
 }
 
+uint32_t llama_batch_allocr::get_n_used() const {
+return n_used;
+}
+
 std::vector<int32_t> & llama_batch_allocr::get_out_ids() {
 return out_ids;
 }
@@ -420,6 +426,8 @@ llama_pos llama_batch_allocr::seq_pos_max(llama_seq_id seq_id) const {
 void llama_batch_allocr::split_reset() {
 out_ids.clear();
 
+n_used = 0;
+
 used.clear();
 used.resize(get_n_tokens(), false);
 
@@ -444,6 +452,7 @@ llama_ubatch llama_batch_allocr::split_simple(uint32_t n_ubatch) {
 idxs.push_back(cur_idx);
 
 used[cur_idx] = true;
+++n_used;
 
 ++cur_idx;
 
@@ -459,9 +468,17 @@ llama_ubatch llama_batch_allocr::split_simple(uint32_t n_ubatch) {
 return ubatch_add(idxs, idxs.size(), false);
 }
 
-llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) {
+llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch, bool sequential) {
+if (sequential && has_cpl) {
+LLAMA_LOG_ERROR("%s: sequential split is not supported when there are coupled sequences in the input batch\n", __func__);
+
+return {};
+}
+
 std::vector<seq_set_t> cur_seq_set;
 
+llama_seq_id last_seq_id = -1;
+
 // determine the non-overlapping sequence sets participating in this ubatch
 for (int32_t i = 0; i < batch.n_tokens; ++i) {
 if (used[i]) {
@@ -478,9 +495,16 @@ llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) {
 }
 }
 
+// accept only increasing sequence ids
+if (sequential) {
+add = add && (cur_seq_set.empty() || batch.seq_id[i][0] == last_seq_id + 1);
+}
+
 if (add) {
 cur_seq_set.push_back(seq_set[i]);
 
+last_seq_id = batch.seq_id[i][0];
+
 if (cur_seq_set.size() > n_ubatch) {
 break;
 }
@@ -529,6 +553,7 @@ llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) {
 idxs_per_seq[s].push_back(idx);
 
 used[idx] = true;
+++n_used;
 
 ++cur_idx[s];
 }
@@ -570,6 +595,7 @@ llama_ubatch llama_batch_allocr::split_seq(uint32_t n_ubatch) {
 idxs.push_back(cur_idx);
 
 used[cur_idx] = true;
+++n_used;
 
 if (idxs.size() >= n_ubatch) {
 break;
@@ -54,6 +54,7 @@ public:
 
 uint32_t get_n_tokens() const;
 uint32_t get_n_outputs() const;
+uint32_t get_n_used() const;
 
 // the array of output indices in the order they were encountered during the ubatch splitting
 std::vector<int32_t> & get_out_ids();
@@ -69,7 +70,8 @@ public:
 llama_ubatch split_simple(uint32_t n_ubatch);
 
 // make ubatches of equal-length sequences sets
-llama_ubatch split_equal(uint32_t n_ubatch);
+// if sequential == true, the tokens in the ubatch will have increasing sequential sequence ids
+llama_ubatch split_equal(uint32_t n_ubatch, bool sequential);
 
 // sequence-set-wise split - each ubatch contains a single sequence-set
 llama_ubatch split_seq(uint32_t n_ubatch);
@@ -112,6 +114,9 @@ private:
 using pos_set_t = std::set<llama_pos>;
 using seq_cpl_t = std::vector<bool>;
 
+// helper flag to quickly determine if there are any coupled sequences in the batch
+bool has_cpl;
+
 std::vector<pos_set_t> seq_pos; // seq_pos[s]: the set of positions in sequence s
 std::vector<seq_cpl_t> seq_cpl; // seq_cpl[s0][s1]: if sequence s0 is coupled to sequence s1
 
@@ -125,6 +130,8 @@ private:
 // batch indices of the output
 std::vector<int32_t> out_ids;
 
+uint32_t n_used;
+
 // used[i] indicates if token i has already been used in a previous ubatch
 std::vector<bool> used;
 
@@ -33,6 +33,9 @@ llama_context::llama_context(
 throw std::runtime_error("n_seq_max must be <= " + std::to_string(LLAMA_MAX_SEQ));
 }
 
+const char * LLAMA_HT = getenv("LLAMA_HT");
+cparams.n_seq_virt = LLAMA_HT ? cparams.n_seq_max : 1;
+
 cparams.n_threads = params.n_threads;
 cparams.n_threads_batch = params.n_threads_batch;
 cparams.yarn_ext_factor = params.yarn_ext_factor;
@@ -1308,7 +1311,8 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
 this->n_outputs = n_outputs;
 
 llama_batch_allocr balloc(model.hparams.n_pos_per_embd());
-llama_ubatch ubatch = balloc.ubatch_reserve(n_tokens/n_seqs, n_seqs);
+//llama_ubatch ubatch = balloc.ubatch_reserve(n_tokens/n_seqs, n_seqs);
+llama_ubatch ubatch = balloc.ubatch_reserve(n_tokens, 1);
 
 auto * gf = graph_init();
 auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mctx);
@@ -11,8 +11,9 @@ struct llama_cparams {
 uint32_t n_batch;
 uint32_t n_ubatch;
 uint32_t n_seq_max;
-int n_threads; // number of threads to use for generation
-int n_threads_batch; // number of threads to use for batch processing
+uint32_t n_seq_virt;
+int32_t n_threads; // number of threads to use for generation
+int32_t n_threads_batch; // number of threads to use for batch processing
 
 float rope_freq_base;
 float rope_freq_scale;
@@ -1001,12 +1001,12 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
 GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Hybrid recurrent is not supported with SWA attention layers");
 
 const auto n_kv = inp->mctx->get_attn()->get_n_kv();
+const auto n_seqs = cparams.n_seq_virt > 1 ? ubatch.n_seqs_unq : 1;
 
 inp->self_k_idxs = mctx_cur->get_attn()->build_input_k_idxs(ctx0, ubatch);
 inp->self_v_idxs = mctx_cur->get_attn()->build_input_v_idxs(ctx0, ubatch);
 
-inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
-//cb(inp->self_kq_mask, "KQ_mask", -1);
+inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_seqs, GGML_KQ_MASK_PAD), 1, n_seqs);
 ggml_set_input(inp->self_kq_mask);
 
 inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
@@ -1033,6 +1033,10 @@ ggml_tensor * llm_graph_context::build_attn_mha(
 float kq_scale) const {
 const bool v_trans = v->nb[1] > v->nb[2];
 
+const auto n_seqs = cparams.n_seq_virt > 1 ? ubatch.n_seqs_unq : 1;
+
+q = ggml_reshape_4d(ctx0, q, q->ne[0], q->ne[1], q->ne[2]/n_seqs, n_seqs);
+
 q = ggml_permute(ctx0, q, 0, 2, 1, 3);
 k = ggml_permute(ctx0, k, 0, 2, 1, 3);
 v = ggml_permute(ctx0, v, 0, 2, 1, 3);
@@ -1081,7 +1085,7 @@ ggml_tensor * llm_graph_context::build_attn_mha(
 #endif
 }
 
-cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens);
+cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens*n_seqs);
 } else {
 ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
 
@@ -1126,7 +1130,7 @@ ggml_tensor * llm_graph_context::build_attn_mha(
 
 cur = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
 
-cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens);
+cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens*n_seqs);
 
 if (!cparams.offload_kqv) {
 // all nodes between the KV store and the attention output are run on the CPU
@@ -1205,11 +1209,12 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified()
 GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified_iswa for SWA");
 
 const auto n_kv = mctx_cur->get_n_kv();
+const auto n_seqs = cparams.n_seq_virt > 1 ? ubatch.n_seqs_unq : 1;
 
 inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
 inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch);
 
-inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_seqs, GGML_KQ_MASK_PAD), 1, n_seqs);
 ggml_set_input(inp->self_kq_mask);
 
 inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
@@ -1451,13 +1456,15 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
 
 auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, mctx_cur);
 
+const auto n_seqs = cparams.n_seq_virt > 1 ? ubatch.n_seqs_unq : 1;
+
 {
 const auto n_kv = mctx_cur->get_base()->get_n_kv();
 
 inp->self_k_idxs = mctx_cur->get_base()->build_input_k_idxs(ctx0, ubatch);
 inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch);
 
-inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_seqs, GGML_KQ_MASK_PAD), 1, n_seqs);
 ggml_set_input(inp->self_kq_mask);
 
 inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
@@ -1471,7 +1478,7 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
 inp->self_k_idxs_swa = mctx_cur->get_swa()->build_input_k_idxs(ctx0, ubatch);
 inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch);
 
-inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_seqs, GGML_KQ_MASK_PAD), 1, n_seqs);
 ggml_set_input(inp->self_kq_mask_swa);
 
 inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
@@ -255,10 +255,10 @@ public:
 ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
 
 ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
-ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch]
+ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa]
 
-ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch]
-ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch]
+ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch/n_seqs, n_seqs]
+ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_seqs, n_seqs]
 
 const llama_hparams & hparams;
 const llama_cparams & cparams;
@@ -289,14 +289,14 @@ public:
 ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; }
 
 ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
-ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch]
+ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa]
 ggml_tensor * self_k_idxs_swa = nullptr; // I64 [n_batch]
-ggml_tensor * self_v_idxs_swa = nullptr; // I64 [n_batch]
+ggml_tensor * self_v_idxs_swa = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa]
 
-ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch]
-ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch]
-ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_kv, n_batch]
-ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_kv, n_batch]
+ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch/n_seqs, n_seqs]
+ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_seqs, n_seqs]
+ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_kv, n_batch/n_seqs, n_seqs]
+ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_kv, n_batch/n_seqs, n_seqs]
 
 const llama_hparams & hparams;
 const llama_cparams & cparams;
@@ -20,14 +20,15 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
 bool swa_full,
 uint32_t kv_size,
 uint32_t n_seq_max,
+uint32_t n_seq_virt,
 uint32_t n_ubatch,
-uint32_t n_pad) : hparams(model.hparams) {
+uint32_t n_pad) : hparams(model.hparams), n_seq_virt(n_seq_virt) {
 llama_kv_cache_unified::layer_filter_cb filter_base = [&](int32_t il) { return !model.hparams.is_swa(il); };
 llama_kv_cache_unified::layer_filter_cb filter_swa = [&](int32_t il) { return model.hparams.is_swa(il); };
 
 const uint32_t size_base = kv_size;
 
-uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*n_seq_max + n_ubatch, n_pad));
+uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*(n_seq_max/n_seq_virt) + n_ubatch, n_pad));
 
 // when using full-size SWA cache, we set the SWA cache size to be equal to the base cache size
 if (swa_full) {
@@ -41,14 +42,14 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
 
 kv_base = std::make_unique<llama_kv_cache_unified>(
 model, std::move(filter_base), type_k, type_v,
-v_trans, offload, size_base, n_seq_max, n_pad,
+v_trans, offload, size_base, n_seq_max, n_seq_virt, n_pad,
 0, LLAMA_SWA_TYPE_NONE);
 
 LLAMA_LOG_INFO("%s: creating SWA KV cache, size = %u cells\n", __func__, size_swa);
 
 kv_swa = std::make_unique<llama_kv_cache_unified>(
 model, std::move(filter_swa), type_k, type_v,
-v_trans, offload, size_swa, n_seq_max, n_pad,
+v_trans, offload, size_swa, n_seq_max, n_seq_virt, n_pad,
 hparams.n_swa, hparams.swa_type);
 }
 
@@ -100,6 +101,11 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
 
 // first try simple split
 do {
+if (n_seq_virt > 1) {
+// requires equal splits, so we skip the simple split
+break;
+}
+
 balloc.split_reset();
 
 std::vector<llama_ubatch> ubatches;
@@ -113,6 +119,11 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
 ubatches.push_back(std::move(ubatch)); // NOLINT
 }
 
+if (balloc.get_n_used() < balloc.get_n_tokens()) {
+// failed to find a suitable split
+break;
+}
+
 auto sinfos_base = kv_base->prepare(ubatches);
 if (sinfos_base.empty()) {
 break;
@@ -135,7 +146,7 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
 
 std::vector<llama_ubatch> ubatches;
 while (true) {
-auto ubatch = balloc.split_equal(n_ubatch);
+auto ubatch = balloc.split_equal(n_ubatch, n_seq_virt > 1);
 
 if (ubatch.n_tokens == 0) {
 break;
@@ -144,6 +155,11 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
 ubatches.push_back(std::move(ubatch)); // NOLINT
 }
 
+if (balloc.get_n_used() < balloc.get_n_tokens()) {
+// failed to find a suitable split
+break;
+}
+
 auto sinfos_base = kv_base->prepare(ubatches);
 if (sinfos_base.empty()) {
 break;
@@ -22,6 +22,7 @@ public:
 bool swa_full,
 uint32_t kv_size,
 uint32_t n_seq_max,
+uint32_t n_seq_virt,
 uint32_t n_ubatch,
 uint32_t n_pad);
 
@@ -68,6 +69,8 @@ public:
 private:
 const llama_hparams & hparams;
 
+const uint32_t n_seq_virt = 1;
+
 std::unique_ptr<llama_kv_cache_unified> kv_base;
 std::unique_ptr<llama_kv_cache_unified> kv_swa;
 };
@@ -25,11 +25,12 @@ llama_kv_cache_unified::llama_kv_cache_unified(
 bool offload,
 uint32_t kv_size,
 uint32_t n_seq_max,
+uint32_t n_seq_virt,
 uint32_t n_pad,
 uint32_t n_swa,
 llama_swa_type swa_type) :
 model(model), hparams(model.hparams), v_trans(v_trans),
-n_seq_max(n_seq_max), n_pad(n_pad), n_swa(n_swa), swa_type(swa_type) {
+n_seq_max(n_seq_max), n_seq_virt(n_seq_virt), n_pad(n_pad), n_swa(n_swa), swa_type(swa_type) {
 
 GGML_ASSERT(kv_size % n_pad == 0);
 
@@ -45,7 +46,7 @@ llama_kv_cache_unified::llama_kv_cache_unified(
 auto it = ctx_map.find(buft);
 if (it == ctx_map.end()) {
 ggml_init_params params = {
-/*.mem_size =*/ size_t(2u*n_layer_cache*ggml_tensor_overhead()),
+/*.mem_size =*/ size_t(2u*(1 + n_seq_virt)*n_layer_cache*ggml_tensor_overhead()),
 /*.mem_buffer =*/ NULL,
 /*.no_alloc =*/ true,
 };
@@ -64,9 +65,27 @@ llama_kv_cache_unified::llama_kv_cache_unified(
 return it->second;
 };
 
-head = 0;
+GGML_ASSERT(n_seq_virt == 1 || n_seq_virt == n_seq_max);
 
-cells.resize(kv_size);
+v_heads.resize(n_seq_virt);
+for (uint32_t s = 0; s < n_seq_virt; ++s) {
+v_heads[s] = 0;
+}
+
+v_cells.resize(n_seq_virt);
+for (uint32_t s = 0; s < n_seq_virt; ++s) {
+v_cells[s].resize(kv_size);
+}
+
+// by default, all sequence ids are mapped to the 0th virtual sequence
+seq_virt_idx.resize(LLAMA_MAX_SEQ, 0);
+
+if (n_seq_virt > 1) {
+seq_virt_idx.resize(n_seq_virt, 0);
+for (uint32_t s = 0; s < n_seq_virt; ++s) {
+seq_virt_idx[s] = s;
+}
+}
 
 for (uint32_t il = 0; il < n_layer_cache; il++) {
 if (filter && !filter(il)) {
@@ -98,14 +117,23 @@ llama_kv_cache_unified::llama_kv_cache_unified(
 ggml_tensor * k;
 ggml_tensor * v;
 
-k = ggml_new_tensor_2d(ctx, type_k, n_embd_k_gqa, kv_size);
-v = ggml_new_tensor_2d(ctx, type_v, n_embd_v_gqa, kv_size);
+k = ggml_new_tensor_3d(ctx, type_k, n_embd_k_gqa, kv_size, n_seq_virt);
+v = ggml_new_tensor_3d(ctx, type_v, n_embd_v_gqa, kv_size, n_seq_virt);
 
 ggml_format_name(k, "cache_k_l%d", il);
 ggml_format_name(v, "cache_v_l%d", il);
 
+std::vector<ggml_tensor *> k_seq;
+std::vector<ggml_tensor *> v_seq;
+
+for (uint32_t s = 0; s < n_seq_virt; ++s) {
+k_seq.push_back(ggml_view_2d(ctx, k, n_embd_k_gqa, kv_size, k->nb[1], s*k->nb[2]));
+v_seq.push_back(ggml_view_2d(ctx, v, n_embd_v_gqa, kv_size, v->nb[1], s*v->nb[2]));
+}
+
 map_layer_ids[il] = layers.size();
-layers.push_back({ il, k, v });
+
+layers.push_back({ il, k, v, k_seq, v_seq, });
 }
 
 // TODO: this is temporary until we support passing reuse layer filters [KV_REUSE]
@@ -148,8 +176,8 @@ llama_kv_cache_unified::llama_kv_cache_unified(
 const size_t memory_size_k = size_k_bytes();
 const size_t memory_size_v = size_v_bytes();
 
-LLAMA_LOG_INFO("%s: size = %7.2f MiB (%6u cells, %3d layers, %2u seqs), K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
-(float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f), kv_size, (int) layers.size(), n_seq_max,
+LLAMA_LOG_INFO("%s: size = %7.2f MiB (%6u cells, %3d layers, %2u/%2u seqs), K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
+(float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f), kv_size, (int) layers.size(), n_seq_max, n_seq_virt,
 ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
 ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
 }
@@ -166,9 +194,10 @@ llama_kv_cache_unified::llama_kv_cache_unified(
 }
 
 void llama_kv_cache_unified::clear(bool data) {
-cells.reset();
-
-head = 0;
+for (uint32_t s = 0; s < n_seq_virt; ++s) {
+v_cells[s].reset();
+v_heads[s] = 0;
+}
 
 if (data) {
 for (auto & buf : bufs) {
@@ -178,6 +207,9 @@ void llama_kv_cache_unified::clear(bool data) {
 }
 
 bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
+auto & cells = v_cells[seq_virt_idx[seq_id]];
+auto & head = v_heads[seq_virt_idx[seq_id]];
+
 uint32_t new_head = cells.size();
 
 if (p0 < 0) {
@@ -224,6 +256,12 @@ bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos
 }
 
 void llama_kv_cache_unified::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
+const auto s0 = seq_virt_idx[seq_id_src];
+const auto s1 = seq_virt_idx[seq_id_dst];
+
+if (s0 == s1) {
+auto & cells = v_cells[s0];
+
 if (seq_id_src == seq_id_dst) {
 return;
 }
@@ -245,9 +283,55 @@ void llama_kv_cache_unified::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id
 cells.seq_add(i, seq_id_dst);
 }
 }
+
+return;
+}
+
+bool is_full = true;
+
+if (p0 > 0 && p0 + 1 < (int) get_size()) {
+is_full = false;
+}
+
+if (p1 > 0 && p1 + 1 < (int) get_size()) {
+is_full = false;
+}
+
+GGML_ASSERT(is_full && "seq_cp() is only supported for full KV buffers");
+
+//LLAMA_LOG_WARN("%s: copying KV buffer from %d (virt = %d) to %d (virt = %d)\n", __func__, seq_id_src, s0, seq_id_dst, s1);
+
+for (uint32_t il = 0; il < layers.size(); ++il) {
+const auto & layer = layers[il];
+
+ggml_backend_tensor_copy(layer.k_seq[s0], layer.k_seq[s1]);
+ggml_backend_tensor_copy(layer.v_seq[s0], layer.v_seq[s1]);
+
+// TODO: do we need synchronization here?
+}
+
+// TODO: support this:
+GGML_ASSERT(v_cells[s0].get_has_shift() == false && "cannot copy a KV buffer that has a pending shift");
+
+v_cells[s1].reset();
+for (uint32_t i = 0; i < v_cells[s0].size(); ++i) {
+if (v_cells[s0].seq_has(i, seq_id_src)) {
+v_cells[s1].pos_set(i, v_cells[s0].pos_get(i));
+v_cells[s1].seq_add(i, seq_id_dst);
+}
+}
+
+v_heads[s1] = v_heads[s0];
+
+//for (uint32_t s = 0; s < n_seq_virt; ++s) {
+// LLAMA_LOG_WARN("%s: seq %d: min = %d, max = %d\n", __func__, s, v_cells[s].seq_pos_min(s), v_cells[s].seq_pos_max(s));
+//}
 }
 
 void llama_kv_cache_unified::seq_keep(llama_seq_id seq_id) {
+auto & cells = v_cells[seq_virt_idx[seq_id]];
+auto & head = v_heads[seq_virt_idx[seq_id]];
+
 uint32_t new_head = cells.size();
 
 for (uint32_t i = 0; i < cells.size(); ++i) {
@@ -265,6 +349,9 @@ void llama_kv_cache_unified::seq_keep(llama_seq_id seq_id) {
 }
 
 void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
+auto & cells = v_cells[seq_virt_idx[seq_id]];
+auto & head = v_heads[seq_virt_idx[seq_id]];
+
 if (shift == 0) {
 return;
 }
@@ -304,6 +391,8 @@ void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_po
 }
 
 void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
+auto & cells = v_cells[seq_virt_idx[seq_id]];
+
 if (d == 1) {
 return;
 }
@@ -333,10 +422,14 @@ void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_po
 }
 
 llama_pos llama_kv_cache_unified::seq_pos_min(llama_seq_id seq_id) const {
+const auto & cells = v_cells[seq_virt_idx[seq_id]];
+
 return cells.seq_pos_min(seq_id);
 }
 
 llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
+const auto & cells = v_cells[seq_virt_idx[seq_id]];
+
 return cells.seq_pos_max(seq_id);
 }
 
@@ -351,7 +444,7 @@ llama_memory_context_ptr llama_kv_cache_unified::init_batch(
 
 std::vector<llama_ubatch> ubatches;
 while (true) {
-auto ubatch = balloc.split_simple(n_ubatch);
+auto ubatch = n_seq_virt == 1 ? balloc.split_simple(n_ubatch) : balloc.split_equal(n_ubatch, true);
 
 if (ubatch.n_tokens == 0) {
 break;
@@ -360,6 +453,11 @@ llama_memory_context_ptr llama_kv_cache_unified::init_batch(
 ubatches.push_back(std::move(ubatch)); // NOLINT
 }
 
+if (balloc.get_n_used() < balloc.get_n_tokens()) {
+// failed to find a suitable split
+break;
+}
+
 auto sinfos = prepare(ubatches);
 if (sinfos.empty()) {
 break;
@@ -382,7 +480,10 @@ llama_memory_context_ptr llama_kv_cache_unified::init_update(llama_context * lct
 defrag_info dinfo;
 
 // see if we need to defrag
-{
+if (n_seq_virt == 1) {
+// note : for now do not consider defrag for n_seq_virt > 1
+const auto & cells = v_cells[seq_virt_idx[0]];
+
 bool do_defrag = optimize;
 
 const auto thold = lctx->get_cparams().defrag_thold;
@@ -412,16 +513,16 @@ llama_memory_context_ptr llama_kv_cache_unified::init_update(llama_context * lct
 llama_kv_cache_unified::slot_info_vec_t llama_kv_cache_unified::prepare(const std::vector<llama_ubatch> & ubatches) {
 llama_kv_cache_unified::slot_info_vec_t res;
 
-struct state {
-uint32_t head_old; // old position of the head, before placing the ubatch
-
+struct state_t {
 slot_info sinfo; // slot info for the ubatch
 
-llama_kv_cells_unified cells; // copy of the old cells, before placing the ubatch
+std::vector<uint32_t> v_heads_old; // old positions of the heads, before placing the ubatch
+
+std::vector<llama_kv_cells_unified> v_cells; // copy of the old cells, before placing the ubatch
 };
 
 // remember the old state of the cells so we can restore it in the end
-std::vector<state> states;
+std::vector<state_t> states;
 
 bool success = true;
 
@@ -440,16 +541,35 @@ llama_kv_cache_unified::slot_info_vec_t llama_kv_cache_unified::prepare(const st
 res.push_back(sinfo_new);
 
 // store the old state of the cells in the recovery stack
-states.push_back({head, sinfo_new, cells.cp(sinfo_new.idxs)});
+{
+state_t state = { sinfo_new, v_heads, {} };
+
+for (uint32_t s = 0; s < sinfo_new.n_seq_virt(); ++s) {
+auto & cells = v_cells[sinfo_new.seq_id_virt[s]];
+
+state.v_cells.push_back(cells.cp(sinfo_new.idxs[s]));
+}
+
+states.push_back(std::move(state));
+}
 
 // now emplace the ubatch
 apply_ubatch(sinfo_new, ubatch);
 }
 
+GGML_ASSERT(!states.empty());
+
 // iterate backwards and restore the cells to their original state
 for (auto it = states.rbegin(); it != states.rend(); ++it) {
-cells.set(it->sinfo.idxs, it->cells);
-head = it->head_old;
+const auto & sinfo = it->sinfo;
+
+for (uint32_t s = 0; s < sinfo.n_seq_virt(); ++s) {
+auto & cells = v_cells[sinfo.seq_id_virt[s]];
+auto & head = v_heads[sinfo.seq_id_virt[s]];
+
+cells.set(sinfo.idxs[s], it->v_cells[s]);
+head = it->v_heads_old[s];
+}
 }
 
 if (!success) {
@@ -498,12 +618,20 @@ bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const d
 updated = true;
 }
 
+for (uint32_t s = 0; s < n_seq_virt; ++s) {
+auto & cells = v_cells[s];
+
 cells.reset_shift();
+}
 }
 
 if (!dinfo.empty()) {
 LLAMA_LOG_DEBUG("%s: defragmenting KV cache\n", __func__);
 
+// note: for now do not consider defrag for n_seq_virt > 1
+auto & cells = v_cells[seq_virt_idx[0]];
+auto & head = v_heads[seq_virt_idx[0]];
+
 // apply moves:
 {
 const auto n_kv = dinfo.ids.size();
@@ -551,23 +679,13 @@ bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const d
 }
 
 llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch, bool cont) const {
-const uint32_t n_tokens = ubatch.n_tokens;
-
-uint32_t head_cur = this->head;
-
-// if we have enough unused cells before the current head ->
-// better to start searching from the beginning of the cache, hoping to fill it
-if (head_cur > cells.get_used() + 2*ubatch.n_tokens) {
-head_cur = 0;
-}
-
-if (n_tokens > cells.size()) {
-LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %u\n", __func__, n_tokens, cells.size());
-return { };
-}
-
 if (debug > 0) {
-LLAMA_LOG_DEBUG("%s: n = %5d, used = %5d, head = %5d, size = %5d, n_swa = %5d\n", __func__, cells.used_max_p1(), cells.get_used(), head, get_size(), n_swa);
+const auto & cells = v_cells[seq_virt_idx[1]];
+
+const uint32_t head_cur = v_heads[1];
+
+LLAMA_LOG_DEBUG("%s: n = %5d, used = %5d, head = %5d, size = %5d, n_swa = %5d\n",
+__func__, cells.used_max_p1(), cells.get_used(), head_cur, get_size(), n_swa);
 
 if ((debug == 2 && n_swa > 0) || debug > 2) {
 std::string ss;
@@ -624,18 +742,61 @@ llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_
 }
 }
 
+uint32_t n_tokens = ubatch.n_tokens;
+uint32_t n_seqs = 1;
+
+if (n_seq_virt > 1) {
+GGML_ASSERT(n_tokens % ubatch.n_seqs_unq == 0);
+
+n_seqs = ubatch.n_seqs_unq;
+n_tokens = n_tokens / n_seqs;
+}
+
+slot_info res = {
+/*.s0 =*/ LLAMA_MAX_SEQ,
+/*.s1 =*/ 0,
+/*.seq_id_virt =*/ { },
+/*.idxs =*/ { },
+};
+
+res.resize(n_seqs);
+
+for (uint32_t s = 0; s < n_seqs; ++s) {
+const auto seq_id = ubatch.seq_id_unq[s];
+
+if (n_seq_virt > 1) {
+GGML_ASSERT(ubatch.n_seq_id[s*n_tokens] == 1);
+GGML_ASSERT(ubatch.seq_id [s*n_tokens][0] == seq_id);
+}
+
+res.s0 = std::min<llama_seq_id>(res.s0, seq_virt_idx[seq_id]);
+res.s1 = std::max<llama_seq_id>(res.s1, seq_virt_idx[seq_id]);
+
+res.seq_id_virt[s] = seq_virt_idx[seq_id];
+res.idxs[s].resize(n_tokens);
+
+const auto & cells = v_cells[seq_virt_idx[seq_id]];
+
+uint32_t head_cur = v_heads[seq_virt_idx[seq_id]];
+
+// if we have enough unused cells before the current head ->
+// better to start searching from the beginning of the cache, hoping to fill it
+if (head_cur > cells.get_used() + 2*n_tokens) {
+head_cur = 0;
+}
+
+if (n_tokens > cells.size()) {
+LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %u\n", __func__, n_tokens, cells.size());
+return { };
+}
+
+uint32_t n_found = 0;
 uint32_t n_tested = 0;
 
 // for continuous slots, we test that all tokens in the ubatch fit, starting from the current head
 // for non-continuous slots, we test the tokens one by one
 const uint32_t n_test = cont ? n_tokens : 1;
 
-slot_info res;
-
-auto & idxs = res.idxs;
-
-idxs.reserve(n_tokens);
-
 while (true) {
 if (head_cur + n_test > cells.size()) {
 n_tested += cells.size() - head_cur;
@@ -646,6 +807,9 @@ llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_
 for (uint32_t i = 0; i < n_test; i++) {
 const auto idx = head_cur;
 
+head_cur++;
+n_tested++;
+
 //const llama_pos pos = ubatch.pos[i];
 //const llama_seq_id seq_id = ubatch.seq_id[i][0];
 
@@ -676,22 +840,23 @@ llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_
 }
 }
 
-head_cur++;
-n_tested++;
-
 if (can_use) {
-idxs.push_back(idx);
+res.idxs[s][n_found] = idx;
+
+n_found++;
 } else {
 if (cont) {
 break;
 }
 }
 }
 
-if (idxs.size() == n_tokens) {
+if (n_found == n_tokens) {
 break;
 }
 
 if (cont) {
-idxs.clear();
+n_found = 0;
 }
 
 if (n_tested >= cells.size()) {
@@ -701,9 +866,12 @@ llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_
 }
 
 // we didn't find a suitable slot - return empty result
-if (idxs.size() < n_tokens) {
-res.clear();
+if (n_found < n_tokens) {
 return { };
 }
+}
+
+assert(res.s1 >= res.s0);
+
 return res;
 }
@@ -712,14 +880,19 @@ void llama_kv_cache_unified::apply_ubatch(const slot_info & sinfo, const llama_u
 // keep track of the max sequence position that we would overwrite with this ubatch
 // for non-SWA cache, this would be always empty
 llama_seq_id seq_pos_max_rm[LLAMA_MAX_SEQ];
-for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
+for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
 seq_pos_max_rm[s] = -1;
 }
 
-assert(ubatch.n_tokens == sinfo.idxs.size());
+assert(ubatch.n_tokens == sinfo.n_seq_virt()*sinfo.size());
 
-for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
-const auto idx = sinfo.idxs.at(i);
+for (uint32_t s = 0; s < sinfo.n_seq_virt(); ++s) {
+for (uint32_t ii = 0; ii < sinfo.size(); ++ii) {
+const uint32_t i = s*sinfo.size() + ii;
+
+auto & cells = v_cells[sinfo.seq_id_virt[s]];
+
+const auto idx = sinfo.idxs.at(s).at(ii);
 
 if (!cells.is_empty(idx)) {
 assert(cells.seq_count(idx) == 1);
@@ -738,15 +911,20 @@ void llama_kv_cache_unified::apply_ubatch(const slot_info & sinfo, const llama_u
 cells.seq_add(idx, ubatch.seq_id[i][s]);
 }
 }
+}
 
 // note: we want to preserve the invariant that all positions between [pos_min, pos_max] for each sequence
 // will be present in the cache. so we have to purge any position which is less than those we would overwrite
 // ref: https://github.com/ggml-org/llama.cpp/pull/13746#issuecomment-2916057092
-for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
+for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
 if (seq_pos_max_rm[s] == -1) {
 continue;
 }
 
+GGML_ASSERT(s < seq_virt_idx.size());
+
+auto & cells = v_cells[seq_virt_idx[s]];
+
 if (cells.seq_pos_min(s) <= seq_pos_max_rm[s]) {
 LLAMA_LOG_DEBUG("%s: purging positions [%d, %d] of sequence %d from KV cache\n",
 __func__, cells.seq_pos_min(s), seq_pos_max_rm[s], s);
@@ -756,7 +934,11 @@ void llama_kv_cache_unified::apply_ubatch(const slot_info & sinfo, const llama_u
 }
 
 // move the head at the end of the slot
-head = sinfo.idxs.back() + 1;
+for (uint32_t s = 0; s < sinfo.n_seq_virt(); ++s) {
+auto & head = v_heads[sinfo.seq_id_virt[s]];
+
+head = sinfo.idxs[s].back() + 1;
+}
 }
 
 bool llama_kv_cache_unified::get_can_shift() const {
@@ -764,49 +946,76 @@ bool llama_kv_cache_unified::get_can_shift() const {
 }
 
 uint32_t llama_kv_cache_unified::get_size() const {
+const auto & cells = v_cells[seq_virt_idx[0]];
+
 return cells.size();
 }
 
 bool llama_kv_cache_unified::get_has_shift() const {
-return cells.get_has_shift();
+bool result = false;
+
+for (uint32_t s = 0; s < n_seq_virt; ++s) {
+result |= v_cells[s].get_has_shift();
+}
+
+return result;
 }
 
 uint32_t llama_kv_cache_unified::get_n_kv() const {
-return std::min(cells.size(), std::max(n_pad, GGML_PAD(cells.used_max_p1(), n_pad)));
+uint32_t result = 0;
+
+for (uint32_t s = 0; s < n_seq_virt; ++s) {
+const auto & cells = v_cells[s];
+
+result = std::max(std::min(cells.size(), std::max(n_pad, GGML_PAD(cells.used_max_p1(), n_pad))), result);
 }
 
-ggml_tensor * llama_kv_cache_unified::get_k(ggml_context * ctx, int32_t il, uint32_t n_kv) const {
+return result;
+}
+
+ggml_tensor * llama_kv_cache_unified::get_k(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const {
 const int32_t ikv = map_layer_ids.at(il);
 
 auto * k = layers[ikv].k;
 
-return ggml_view_3d(ctx, k,
-hparams.n_embd_head_k, hparams.n_head_kv(il), n_kv,
+const uint32_t ns = sinfo.s1 - sinfo.s0 + 1;
+
+const uint64_t size_virt = ggml_row_size(k->type, hparams.n_embd_k_gqa(il)*get_size());
+
+return ggml_view_4d(ctx, k,
+hparams.n_embd_head_k, hparams.n_head_kv(il), n_kv, ns,
 ggml_row_size(k->type, hparams.n_embd_head_k),
 ggml_row_size(k->type, hparams.n_embd_k_gqa(il)),
-0);
+size_virt,
+size_virt*sinfo.s0);
 }
 
-ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il, uint32_t n_kv) const {
+ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const {
 const int32_t ikv = map_layer_ids.at(il);
 
 auto * v = layers[ikv].v;
 
+const uint32_t ns = sinfo.s1 - sinfo.s0 + 1;
+
+const uint64_t kv_size = get_size();
+
 if (!v_trans) {
 // note: v->nb[1] <= v->nb[2]
-return ggml_view_3d(ctx, v,
-hparams.n_embd_head_v, hparams.n_head_kv(il), n_kv,
+return ggml_view_4d(ctx, v,
+hparams.n_embd_head_v, hparams.n_head_kv(il), n_kv, ns,
 ggml_row_size(v->type, hparams.n_embd_head_v), // v->nb[1]
 ggml_row_size(v->type, hparams.n_embd_v_gqa(il)), // v->nb[2]
-0);
+ggml_row_size(v->type, kv_size*hparams.n_embd_v_gqa(il)), // v->nb[3]
+ggml_row_size(v->type, kv_size*hparams.n_embd_v_gqa(il)*sinfo.s0));
 }
 
 // note: v->nb[1] > v->nb[2]
-return ggml_view_3d(ctx, v,
-n_kv, hparams.n_head_kv(il), hparams.n_embd_head_v,
-ggml_row_size(v->type, v->ne[1]*hparams.n_embd_head_v), // v->nb[1]
-ggml_row_size(v->type, v->ne[1]), // v->nb[2]
-0);
+return ggml_view_4d(ctx, v,
+n_kv, hparams.n_head_kv(il), hparams.n_embd_head_v, ns,
+ggml_row_size(v->type, kv_size*hparams.n_embd_head_v), // v->nb[1]
+ggml_row_size(v->type, kv_size), // v->nb[2]
+ggml_row_size(v->type, kv_size*hparams.n_embd_v_gqa(il)), // v->nb[3]
+ggml_row_size(v->type, kv_size*hparams.n_embd_v_gqa(il)*sinfo.s0));
 }
 
 ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const {
@@ -820,12 +1029,16 @@ ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_
 k_cur = ggml_reshape_2d(ctx, k_cur, k->ne[0], n_tokens);
 
 if (k_idxs && supports_set_rows) {
+k = ggml_reshape_2d(ctx, k, k->ne[0], k->ne[1]*k->ne[2]);
+
 return ggml_set_rows(ctx, k, k_cur, k_idxs);
 }
 
 // TODO: fallback to old ggml_cpy() method for backwards compatibility
 // will be removed when ggml_set_rows() is adopted by all backends
 
+GGML_ASSERT(n_seq_virt == 1 && "n_seq_virt > 1 not supported");
+
 ggml_tensor * k_view = ggml_view_1d(ctx, k,
 n_tokens*n_embd_k_gqa,
 ggml_row_size(k->type, n_embd_k_gqa)*sinfo.head());
@@ -845,30 +1058,24 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
 
 if (v_idxs && supports_set_rows) {
 if (!v_trans) {
+v = ggml_reshape_2d(ctx, v, v->ne[0], v->ne[1]*v->ne[2]);
+
 return ggml_set_rows(ctx, v, v_cur, v_idxs);
 }
 
 // the row becomes a single element
-ggml_tensor * v_view = ggml_reshape_3d(ctx, v, 1, v->ne[1], v->ne[0]);
+ggml_tensor * v_view = ggml_reshape_2d(ctx, v, 1, v->ne[0]*v->ne[1]*v->ne[2]);
 
 // note: the V cache is transposed when not using flash attention
-v_cur = ggml_permute(ctx, ggml_reshape_3d(ctx, v_cur, v_cur->ne[0], 1, v_cur->ne[1]), 2, 0, 1, 3);
+v_cur = ggml_reshape_2d(ctx, v_cur, 1, v_cur->ne[0]*v_cur->ne[1]);
 
-// note: we can be more explicit here at the cost of extra cont
-// however, above we take advantage that a row of single element is always continuous regardless of the row stride
-//v_cur = ggml_transpose(ctx, v_cur);
-//v_cur = ggml_cont_3d(ctx, v_cur, 1, v_cur->ne[0], v_cur->ne[1]);
-
+// we broadcast the KV indices n_embd_v_gqa times
+// v [1, n_kv, n_embd_v_gqa]
+// v_cur [1, n_tokens, n_embd_v_gqa]
+// v_idxs [n_tokens, 1, 1]
 return ggml_set_rows(ctx, v_view, v_cur, v_idxs);
 }
 
 // TODO: fallback to old ggml_cpy() method for backwards compatibility
 // will be removed when ggml_set_rows() is adopted by all backends
 
+GGML_ASSERT(n_seq_virt == 1 && "n_seq_virt > 1 not supported");
+
 ggml_tensor * v_view = nullptr;
 
 if (!v_trans) {
@@ -899,7 +1106,14 @@ ggml_tensor * llama_kv_cache_unified::build_input_k_idxs(ggml_context * ctx, con
 ggml_tensor * llama_kv_cache_unified::build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
 const uint32_t n_tokens = ubatch.n_tokens;
 
-ggml_tensor * v_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens);
+ggml_tensor * v_idxs;
+
+if (!v_trans) {
+v_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens);
+} else {
+// TODO: assert that n_embd_v_gqa is the same for all layers, or take the max
+v_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens*hparams.n_embd_v_gqa());
+}
 
 ggml_set_input(v_idxs);
 
@@ -912,12 +1126,17 @@ void llama_kv_cache_unified::set_input_k_idxs(ggml_tensor * dst, const llama_uba
 }
 
 const uint32_t n_tokens = ubatch->n_tokens;
+GGML_ASSERT(n_tokens == (int64_t) sinfo.size()*sinfo.n_seq_virt());
 
 GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
 int64_t * data = (int64_t *) dst->data;
 
-for (int64_t i = 0; i < n_tokens; ++i) {
-data[i] = sinfo.idxs.at(i);
+for (uint32_t s = 0; s < sinfo.n_seq_virt(); ++s) {
+const int64_t offs = sinfo.seq_id_virt[s]*get_size();
+
+for (uint32_t i = 0; i < sinfo.size(); ++i) {
+data[s*sinfo.size() + i] = offs + sinfo.idxs.at(s).at(i);
+}
 }
 }
 
@@ -927,12 +1146,49 @@ void llama_kv_cache_unified::set_input_v_idxs(ggml_tensor * dst, const llama_uba
 }
 
 const uint32_t n_tokens = ubatch->n_tokens;
+GGML_ASSERT(n_tokens == (int64_t) sinfo.size()*sinfo.n_seq_virt());
 
 GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
 int64_t * data = (int64_t *) dst->data;
 
-for (int64_t i = 0; i < n_tokens; ++i) {
-data[i] = sinfo.idxs.at(i);
+if (!v_trans) {
+for (uint32_t s = 0; s < sinfo.n_seq_virt(); ++s) {
+const int64_t offs = sinfo.seq_id_virt[s]*get_size();
+
+for (uint32_t i = 0; i < sinfo.size(); ++i) {
+data[s*sinfo.size() + i] = offs + sinfo.idxs.at(s).at(i);
+}
+}
+} else {
+// note: the V cache is transposed when not using flash attention
+const int64_t kv_size = get_size();
+
+// TODO: assert that n_embd_v_gqa is the same for all layers, or take the max
+const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa();
+
+for (uint32_t s = 0; s < sinfo.n_seq_virt(); ++s) {
+const int64_t offs = sinfo.seq_id_virt[s]*kv_size*n_embd_v_gqa;
+
+for (uint32_t i = 0; i < sinfo.size(); ++i) {
+for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
+data[s*sinfo.size()*n_embd_v_gqa + i*n_embd_v_gqa + j] = offs + j*kv_size + sinfo.idxs.at(s).at(i);
+}
+}
+}
+}
 }
 }
 
+void llama_kv_cache_unified::set_input_k_shift(ggml_tensor * dst) const {
+GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
+
+int32_t * data = (int32_t *) dst->data;
+
+for (uint32_t s = 0; s < n_seq_virt; ++s) {
+const auto & cells = v_cells[s];
+
+for (uint32_t i = 0; i < cells.size(); ++i) {
+data[i] = cells.is_empty(i) ? 0 : cells.get_shift(i);
+}
+}
+}
+
@@ -943,6 +1199,12 @@ void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ub
 float * data = (float *) dst->data;
 
 const int64_t n_kv = dst->ne[0];
+const int64_t n_seq_virt = dst->ne[3]; // num virtual sequences in the current ubatch
+
+GGML_ASSERT(n_tokens%n_seq_virt == 0);
+
+const int64_t n_tokens_per_seq = n_tokens/n_seq_virt;
+const int64_t n_tokens_per_seq_pad = GGML_PAD(n_tokens_per_seq, GGML_KQ_MASK_PAD);
 
 // Use only the previous KV cells of the correct sequence for each token of the ubatch.
 // It's assumed that if a token in the batch has multiple sequences, they are equivalent.
@@ -957,9 +1219,14 @@ void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ub
 // xxxxx-----
 // To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615
 for (uint32_t h = 0; h < 1; ++h) {
-for (uint32_t i = 0; i < n_tokens; ++i) {
+for (uint32_t s = 0; s < n_seq_virt; ++s) {
+for (uint32_t ii = 0; ii < n_tokens_per_seq; ++ii) {
+const uint32_t i = s*n_tokens_per_seq + ii;
+
 const llama_seq_id seq_id = ubatch->seq_id[i][0];
 
+const auto & cells = v_cells[seq_virt_idx[seq_id]];
+
 const llama_pos p1 = ubatch->pos[i];
 
 for (uint32_t j = 0; j < n_kv; ++j) {
@@ -990,34 +1257,28 @@ void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ub
 f = -INFINITY;
 }
 
-data[h*(n_kv*n_tokens) + i*n_kv + j] = f;
-}
+data[h*n_seq_virt*n_tokens_per_seq_pad*n_kv + s*n_tokens_per_seq_pad*n_kv + ii*n_kv + j] = f;
+}
 
 // mask padded tokens
 if (data) {
-for (uint32_t i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
+for (uint32_t ii = n_tokens_per_seq; ii < n_tokens_per_seq_pad; ++ii) {
 for (uint32_t j = 0; j < n_kv; ++j) {
-data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
+data[h*n_seq_virt*n_tokens_per_seq_pad*n_kv + s*n_tokens_per_seq_pad*n_kv + ii*n_kv + j] = -INFINITY;
 }
 }
 }
 }
 }
 
-void llama_kv_cache_unified::set_input_k_shift(ggml_tensor * dst) const {
-GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
-
-int32_t * data = (int32_t *) dst->data;
-
-for (uint32_t i = 0; i < cells.size(); ++i) {
-data[i] = cells.is_empty(i) ? 0 : cells.get_shift(i);
-}
-}
-
 void llama_kv_cache_unified::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const {
 const int64_t n_tokens = ubatch->n_tokens;
 
+GGML_ASSERT(n_seq_virt == 1 && "TODO: support multiple virtual sequences");
+const auto & cells = v_cells[0];
+
 GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
 GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing
 
@@ -1124,7 +1385,7 @@ public:
 
 void set_input(const llama_ubatch * ubatch) override;
 
-ggml_tensor * k_shift; // I32 [kv_size]
+ggml_tensor * k_shift; // I32 [kv_size*n_seq_virt]
 
 const llama_kv_cache_unified * kv_self;
 };
@@ -1148,7 +1409,7 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift(
 
 auto inp = std::make_unique<llm_graph_input_k_shift>(this);
 
-inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, cells.size());
+inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, (int64_t) get_size()*n_seq_virt);
 ggml_set_input(inp->k_shift);
 
 for (const auto & layer : layers) {
@@ -1164,7 +1425,7 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift(
 
 ggml_tensor * k =
 ggml_view_3d(ctx, layer.k,
-n_embd_head_k, n_head_kv, cells.size(),
+n_embd_head_k, n_head_kv, get_size()*n_seq_virt,
 ggml_row_size(layer.k->type, n_embd_head_k),
 ggml_row_size(layer.k->type, n_embd_k_gqa),
 0);
@@ -1186,6 +1447,10 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag(
 const defrag_info & dinfo) const {
 auto res = std::make_unique<llm_graph_result>();
 
+GGML_ASSERT(n_seq_virt == 1 && "n_seq_virt > 1 does not support defrag");
+
+const auto & cells = v_cells[0];
+
 const auto & ids = dinfo.ids;
 
 #if 0
@@ -1328,6 +1593,10 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag(
 }
 
 llama_kv_cache_unified::defrag_info llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) const {
+GGML_ASSERT(n_seq_virt == 1 && "n_seq_virt > 1 does not support defrag");
+
+const auto & cells = v_cells[0];
+
 const uint32_t n_layer = layers.size();
 
 const uint32_t n_kv = cells.used_max_p1();
@@ -1476,6 +1745,10 @@ void llama_kv_cache_unified::state_write(llama_io_write_i & io, llama_seq_id seq
 std::vector<std::pair<uint32_t, uint32_t>> cell_ranges; // ranges, from inclusive, to exclusive
 uint32_t cell_count = 0;
 
+GGML_ASSERT(n_seq_virt == 1 && "n_seq_virt > 1 not implemented yet");
+
+const auto & cells = v_cells[0];
+
 // Count the number of cells with the specified seq_id
 // Find all the ranges of cells with this seq id (or all, when -1)
 uint32_t cell_range_begin = cells.size();
@@ -1530,6 +1803,10 @@ void llama_kv_cache_unified::state_read(llama_io_read_i & io, llama_seq_id seq_i
 }
 
 void llama_kv_cache_unified::state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id) const {
+GGML_ASSERT(n_seq_virt == 1 && "n_seq_virt > 1 not implemented yet");
+
+const auto & cells = v_cells[0];
+
 for (const auto & range : cell_ranges) {
 for (uint32_t i = range.first; i < range.second; ++i) {
 std::vector<llama_seq_id> seq_ids;
@@ -1556,6 +1833,10 @@ void llama_kv_cache_unified::state_write_meta(llama_io_write_i & io, const std::
 }
 
 void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const {
+GGML_ASSERT(n_seq_virt == 1 && "n_seq_virt > 1 not implemented yet");
+
+const auto & cells = v_cells[0];
+
 const uint32_t v_trans = this->v_trans ? 1 : 0;
 const uint32_t n_layer = layers.size();
 
@@ -1643,6 +1924,11 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::
 }
 
 bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id) {
+GGML_ASSERT(n_seq_virt == 1 && "n_seq_virt > 1 not implemented yet");
+
+auto & cells = v_cells[0];
+auto & head = v_heads[0];
+
 if (dest_seq_id != -1) {
 // single sequence
 
@@ -1734,6 +2020,11 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
 }
 
 bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell_count) {
+GGML_ASSERT(n_seq_virt == 1 && "n_seq_virt > 1 not implemented yet");
+
+auto & cells = v_cells[0];
+auto & head = v_heads[0];
+
 uint32_t v_trans;
 uint32_t n_layer;
 
@@ -1872,8 +2163,9 @@ llama_kv_cache_unified_context::llama_kv_cache_unified_context(
 
 // create a dummy slot info - the actual data is irrelevant. we just need to build the graph
 sinfos.resize(1);
+sinfos[0].seq_id_virt.resize(1, 0);
 sinfos[0].idxs.resize(1);
-sinfos[0].idxs[0] = 0;
+sinfos[0].idxs[0].resize(1, 0);
 }
 
 llama_kv_cache_unified_context::llama_kv_cache_unified_context(
@@ -1936,11 +2228,11 @@ uint32_t llama_kv_cache_unified_context::get_n_kv() const {
 }
 
 ggml_tensor * llama_kv_cache_unified_context::get_k(ggml_context * ctx, int32_t il) const {
-return kv->get_k(ctx, il, n_kv);
+return kv->get_k(ctx, il, n_kv, sinfos[i_cur]);
 }
 
 ggml_tensor * llama_kv_cache_unified_context::get_v(ggml_context * ctx, int32_t il) const {
-return kv->get_v(ctx, il, n_kv);
+return kv->get_v(ctx, il, n_kv, sinfos[i_cur]);
 }
 
 ggml_tensor * llama_kv_cache_unified_context::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il) const {
@@ -41,10 +41,31 @@ public:
 // data for ggml_set_rows
 using idx_vec_t = std::vector<uint32_t>;
 
-idx_vec_t idxs;
+llama_seq_id s0;
+llama_seq_id s1;
+
+std::vector<llama_seq_id> seq_id_virt;
+std::vector<idx_vec_t> idxs;
 
 uint32_t head() const {
-return idxs.at(0);
+GGML_ASSERT(idxs.size() == 1);
+
+return idxs.at(0).at(0);
 }
 
+void resize(size_t n) {
+seq_id_virt.resize(n);
+idxs.resize(n);
+}
+
+size_t size() const {
+GGML_ASSERT(idxs.size() == seq_id_virt.size());
+
+return idxs.at(0).size();
+}
+
+size_t n_seq_virt() const {
+return seq_id_virt.size();
+}
+
 bool empty() const {
@@ -54,9 +75,6 @@ public:
 void clear() {
 idxs.clear();
 }
 
-// TODO: implement
-//std::vector<idx_vec_t> seq_idxs;
 };
 
 using slot_info_vec_t = std::vector<slot_info>;
 
@@ -70,6 +88,7 @@ public:
 bool offload,
 uint32_t kv_size,
 uint32_t n_seq_max,
+uint32_t n_seq_virt,
 uint32_t n_pad,
 uint32_t n_swa,
 llama_swa_type swa_type);
@@ -122,8 +141,8 @@ public:
 uint32_t get_n_kv() const;
 
 // get views of the current state of the cache
-ggml_tensor * get_k(ggml_context * ctx, int32_t il, uint32_t n_kv) const;
-ggml_tensor * get_v(ggml_context * ctx, int32_t il, uint32_t n_kv) const;
+ggml_tensor * get_k(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const;
+ggml_tensor * get_v(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const;
 
 // store k_cur and v_cur in the cache based on the provided head location
 ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const;
@@ -157,8 +176,9 @@ public:
 void set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const;
 void set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const;
 
-void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
 void set_input_k_shift(ggml_tensor * dst) const;
+
+void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
 void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;
 
 private:
@@ -172,15 +192,15 @@ private:
 
 ggml_tensor * k;
 ggml_tensor * v;
 
+std::vector<ggml_tensor *> k_seq;
+std::vector<ggml_tensor *> v_seq;
 };
 
 bool v_trans = true; // the value tensor is transposed
 
-// the current index from where we start searching for a free slot in the ring buffer of KV cells (see find_slot())
-// note: this is not part of the KV state and it's only used to speed-up the find_slot() method
-uint32_t head = 0;
-
 const uint32_t n_seq_max = 1;
+const uint32_t n_seq_virt = 1;
 
 // required padding
 const uint32_t n_pad = 1;
@@ -200,7 +220,14 @@ private:
 std::vector<ggml_context_ptr> ctxs;
 std::vector<ggml_backend_buffer_ptr> bufs;
 
-llama_kv_cells_unified cells;
+// the current index from where we start searching for a free slot in the ring buffer of KV cells (see find_slot())
+// note: this is not part of the KV state and it's only used to speed-up the find_slot() method
+std::vector<uint32_t> v_heads;
+
+std::vector<llama_kv_cells_unified> v_cells;
+
+// maps from a sequence id to a virtual sequence id
+std::vector<uint32_t> seq_virt_idx;
 
 std::vector<kv_layer> layers;
 
@@ -40,6 +40,7 @@ llama_memory_hybrid::llama_memory_hybrid(
 offload,
 kv_size,
 n_seq_max,
+1,
 n_pad,
 n_swa,
 swa_type
@@ -70,7 +71,7 @@ llama_memory_context_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & ba
 // if all tokens are output, split by sequence
 ubatch = balloc.split_seq(n_ubatch);
 } else {
-ubatch = balloc.split_equal(n_ubatch);
+ubatch = balloc.split_equal(n_ubatch, false);
 }
 
 if (ubatch.n_tokens == 0) {
@@ -80,6 +81,11 @@ llama_memory_context_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & ba
 ubatches.push_back(std::move(ubatch)); // NOLINT
 }
 
+if (balloc.get_n_used() < balloc.get_n_tokens()) {
+// failed to find a suitable split
+break;
+}
+
 // prepare the recurrent batches first
 if (!mem_recr->prepare(ubatches)) {
 // TODO: will the recurrent cache be in an undefined context at this point?
@@ -374,10 +374,11 @@ llama_memory_context_ptr llama_memory_recurrent::init_batch(llama_batch_allocr &
 // if all tokens are output, split by sequence
 ubatch = balloc.split_seq(n_ubatch);
 } else {
-ubatch = balloc.split_equal(n_ubatch);
+ubatch = balloc.split_equal(n_ubatch, false);
 }
 
-if (ubatch.n_tokens == 0) {
+if (balloc.get_n_used() < balloc.get_n_tokens()) {
+// failed to find a suitable split
 break;
 }
 
@@ -14499,6 +14499,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
 params.swa_full,
 cparams.n_ctx,
 cparams.n_seq_max,
+cparams.n_seq_virt,
 cparams.n_ubatch,
 padding);
 } else {
@@ -14513,6 +14514,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
 cparams.offload_kqv,
 cparams.n_ctx,
 cparams.n_seq_max,
+cparams.n_seq_virt,
 padding,
 hparams.n_swa,
 hparams.swa_type);
@@ -61,7 +61,7 @@ int main(int argc, char ** argv) {
 
 const int32_t n_kv_max = llama_n_ctx(ctx);
 
-llama_batch batch = llama_batch_init(n_kv_max, 0, 1);
+llama_batch batch = llama_batch_init(n_kv_max*8, 0, 1); // TODO: tmp!!!
 
 // decode in batches of ctx_params.n_batch tokens
 auto decode_helper = [](llama_context * ctx, llama_batch & batch, int32_t n_batch) {
@@ -119,9 +119,9 @@ int main(int argc, char ** argv) {
 
 const int n_ctx_req = is_pp_shared ? pp + pl*tg : pl*(pp + tg);
 
-if (n_ctx_req > n_kv_max) {
-continue;
-}
+//if (n_ctx_req > n_kv_max) {
+// continue;
+//}
 
 common_batch_clear(batch);
 