kv-cache : added llama_kv_cache_dsa, a DSA-specific KV cache composed of llama_kv_cache and the new llama_ik_cache (lightning indexer key cache).
model : use the new llama_kv_cache_dsa (instead of a modified llama_kv_cache with indexer keys) in DeepseekV32ForCausalLM; removed the non-MLA path in DeepseekV32ForCausalLM.
This commit is contained in:
parent
1874ac9b86
commit
4309c8486a
|
|
@ -22,6 +22,8 @@ add_library(llama
|
|||
llama-io.cpp
|
||||
llama-kv-cache.cpp
|
||||
llama-kv-cache-iswa.cpp
|
||||
llama-ik-cache.cpp
|
||||
llama-kv-cache-dsa.cpp
|
||||
llama-memory.cpp
|
||||
llama-memory-hybrid.cpp
|
||||
llama-memory-hybrid-iswa.cpp
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@
|
|||
|
||||
#include "llama-kv-cache.h"
|
||||
#include "llama-kv-cache-iswa.h"
|
||||
#include "llama-kv-cache-dsa.h"
|
||||
#include "llama-memory-hybrid.h"
|
||||
#include "llama-memory-hybrid-iswa.h"
|
||||
#include "llama-memory-recurrent.h"
|
||||
|
|
@ -31,6 +32,18 @@ static ggml_tensor * build_kq_mask(
|
|||
return ggml_new_tensor_4d(ctx, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream);
|
||||
}
|
||||
|
||||
// Allocate the KQ mask tensor for attention over the indexer key cache.
// Shape: [n_kv, n_tokens/n_stream, 1, n_stream]
static ggml_tensor * build_kq_mask(
        ggml_context * ctx,
        const llama_ik_cache_context * mctx,
        const llama_ubatch & ubatch,
        const llama_cparams & cparams) {
    // with a unified cache all sequences share a single stream
    const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;

    return ggml_new_tensor_4d(ctx, GGML_TYPE_F32, mctx->get_n_kv(), ubatch.n_tokens/n_stream, 1, n_stream);
}
|
||||
|
||||
static bool can_reuse_kq_mask(
|
||||
ggml_tensor * kq_mask,
|
||||
const llama_kv_cache_context * mctx,
|
||||
|
|
@ -50,6 +63,25 @@ static bool can_reuse_kq_mask(
|
|||
return res;
|
||||
}
|
||||
|
||||
// Check whether an existing KQ mask tensor already has the dimensions required
// for the new ubatch/indexer-cache state, so the graph can be reused as-is.
static bool can_reuse_kq_mask(
        ggml_tensor * kq_mask,
        const llama_ik_cache_context * mctx,
        const llama_ubatch & ubatch,
        const llama_cparams & cparams) {
    const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;

    return kq_mask->ne[0] == mctx->get_n_kv()         &&
           kq_mask->ne[1] == ubatch.n_tokens/n_stream &&
           kq_mask->ne[2] == 1                        &&
           kq_mask->ne[3] == n_stream;
}
|
||||
|
||||
// impl
|
||||
|
||||
void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) {
|
||||
|
|
@ -2108,6 +2140,112 @@ ggml_tensor * llm_graph_context::build_attn(
|
|||
return cur;
|
||||
}
|
||||
|
||||
// DSA (sparse attention) variant of build_attn: attends over the main K cache
// while masking out every cached position that is not listed in the top_k
// index tensor produced by the lightning indexer.
ggml_tensor * llm_graph_context::build_attn(
        llm_graph_input_attn_k * inp,
        ggml_tensor * wo,
        ggml_tensor * wo_b,
        ggml_tensor * q_cur,
        ggml_tensor * k_cur,
        ggml_tensor * v_cur,
        ggml_tensor * kq_b,
        ggml_tensor * sinks,
        ggml_tensor * v_mla,
        ggml_tensor * top_k,
        float kq_scale,
        int il) const {
    // these nodes are added to the graph together so that they are not reordered
    // by doing so, the number of splits in the graph is reduced
    // expand k later to enable rope fusion which directly writes into k-v cache
    ggml_build_forward_expand(gf, q_cur);
    ggml_build_forward_expand(gf, v_cur);
    ggml_build_forward_expand(gf, k_cur);

    const auto * mctx_cur = inp->mctx;

    // store to KV cache (keys only - this cache path has no separate V tensor)
    {
        const auto & k_idxs = inp->get_k_idxs();

        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
    }

    const auto & kq_mask = inp->get_kq_mask();

    ggml_tensor * kq_mask_f32 = ggml_cast(ctx0, kq_mask, GGML_TYPE_F32);

    // prepare new kq mask - starts filled with -INFINITY
    ggml_tensor * kq_mask_all = ggml_fill(ctx0, kq_mask_f32, -INFINITY);

    // modify it by unmasking tokens that are in top_k indices
    ggml_tensor * kq_mask_top_k = ggml_where_id(ctx0, kq_mask_f32, kq_mask_all, top_k);
    // cast back to the original mask type expected by build_attn_mha
    kq_mask_top_k = ggml_cast(ctx0, kq_mask_top_k, kq_mask->type);

    ggml_tensor * q = q_cur;
    ggml_tensor * k = mctx_cur->get_k(ctx0, il);
    // V is taken as a view of the cached K rows (first v_cur->ne[0] elements of
    // each row, K strides) - assumes MLA-style layout where V data is a prefix
    // of the cached K entry; TODO(review): confirm against the cache layout
    ggml_tensor * v = ggml_view_4d(ctx0, k, v_cur->ne[0], k->ne[1], k->ne[2], k->ne[3], k->nb[1], k->nb[2], k->nb[3], 0);

    ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask_top_k, sinks, v_mla, kq_scale, il);
    cb(cur, "kqv_out", il);

    if (wo) {
        cur = build_lora_mm(wo, cur);
        if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE) {
            // GLM4 and GLM4_MOE seem to have numerical issues with half-precision accumulators
            ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
        }
    }

    if (wo_b) {
        cur = ggml_add(ctx0, cur, wo_b);
    }

    return cur;
}
|
||||
|
||||
|
||||
// Construct and wire up the graph-input object for attention over the
// lightning-indexer key cache: destination cell indices + KQ mask.
static std::unique_ptr<llm_graph_input_attn_ik> build_attn_inp_ik_impl(
        ggml_context * ctx0,
        const llama_ubatch & ubatch,
        const llama_hparams & hparams,
        const llama_cparams & cparams,
        const llama_ik_cache_context * mctx_cur) {
    auto inp = std::make_unique<llm_graph_input_attn_ik>(hparams, cparams, mctx_cur);

    GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_iswa for SWA");

    // indices of the cells each token's key will be written to
    inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);

    // attention mask over the indexer cache
    inp->self_kq_mask = build_kq_mask(ctx0, mctx_cur, ubatch, cparams);
    ggml_set_input(inp->self_kq_mask);

    // no type conversion needed - expose the mask directly
    inp->self_kq_mask_cnv = inp->self_kq_mask;

    return inp;
}
|
||||
|
||||
// Upload the indexer-cache graph inputs for the current ubatch:
// the destination cell indices for the K writes and the attention mask.
void llm_graph_input_attn_ik::set_input(const llama_ubatch * ubatch) {
    mctx->set_input_k_idxs(self_k_idxs, ubatch);

    mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
}
|
||||
|
||||
bool llm_graph_input_attn_ik::can_reuse(const llm_graph_params & params) {
|
||||
const auto * mctx = static_cast<const llama_ik_cache_context *>(params.mctx);
|
||||
|
||||
this->mctx = mctx;
|
||||
|
||||
bool res = true;
|
||||
|
||||
res &= self_k_idxs->ne[0] == params.ubatch.n_tokens;
|
||||
|
||||
res &= can_reuse_kq_mask(self_kq_mask, mctx, params.ubatch, params.cparams);
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
ggml_tensor * llm_graph_context::build_attn(
|
||||
llm_graph_input_attn_kv_iswa * inp,
|
||||
ggml_tensor * wo,
|
||||
|
|
@ -2230,6 +2368,17 @@ ggml_tensor * llm_graph_context::build_attn(
|
|||
return cur;
|
||||
}
|
||||
|
||||
// Build both attention inputs needed by the DSA path: one over the main
// (K-only) cache and one over the lightning-indexer key cache.
std::pair<llm_graph_input_attn_k *, llm_graph_input_attn_ik *> llm_graph_context::build_attn_inp_k_dsa() const {
    const auto * mctx_cur = static_cast<const llama_kv_cache_dsa_context *>(mctx);

    auto inp_k  = build_attn_inp_k_impl (ctx0, ubatch, hparams, cparams, mctx_cur->get_base());
    auto inp_ik = build_attn_inp_ik_impl(ctx0, ubatch, hparams, cparams, mctx_cur->get_ik());

    // register both inputs with the graph result; ownership moves to res
    return {
        (llm_graph_input_attn_k  *) res->add_input(std::move(inp_k)),
        (llm_graph_input_attn_ik *) res->add_input(std::move(inp_ik)),
    };
}
|
||||
|
||||
// TODO: maybe separate the inner implementation into a separate function
|
||||
// like with the non-sliding window equivalent
|
||||
// once sliding-window hybrid caches are a thing.
|
||||
|
|
|
|||
|
|
@ -21,6 +21,7 @@ struct llama_cparams;
|
|||
struct llama_memory_context_i;
|
||||
|
||||
class llama_kv_cache_context;
|
||||
class llama_ik_cache_context;
|
||||
class llama_kv_cache_iswa_context;
|
||||
class llama_memory_recurrent_context;
|
||||
class llama_memory_hybrid_context;
|
||||
|
|
@ -350,6 +351,39 @@ public:
|
|||
const llama_kv_cache_context * mctx;
|
||||
};
|
||||
|
||||
// graph input for attention over the lightning-indexer key cache
// "V-less": the indexer attends with keys only, so there is no V tensor here
class llm_graph_input_attn_ik : public llm_graph_input_i {
public:
    llm_graph_input_attn_ik(
            const llama_hparams & hparams,
            const llama_cparams & cparams,
            const llama_ik_cache_context * mctx) :
        hparams(hparams),
        cparams(cparams),
        mctx(mctx) {
    }
    ~llm_graph_input_attn_ik() = default;

    // upload k_idxs and the kq mask for the current ubatch
    void set_input(const llama_ubatch * ubatch) override;

    // true if the tensor shapes still match the new params (graph reuse)
    bool can_reuse(const llm_graph_params & params) override;

    ggml_tensor * get_k_idxs() const { return self_k_idxs; }

    ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }

    // destination cell indices for the key writes
    ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]

    ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
    ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream]

    const llama_hparams hparams;
    const llama_cparams cparams;

    // non-owning; rebound by can_reuse()
    const llama_ik_cache_context * mctx;
};
|
||||
|
||||
|
||||
class llm_graph_input_attn_kv_iswa : public llm_graph_input_i {
|
||||
public:
|
||||
llm_graph_input_attn_kv_iswa(
|
||||
|
|
@ -914,6 +948,20 @@ struct llm_graph_context {
|
|||
float kq_scale,
|
||||
int il) const;
|
||||
|
||||
ggml_tensor * build_attn(
|
||||
llm_graph_input_attn_k * inp,
|
||||
ggml_tensor * wo,
|
||||
ggml_tensor * wo_b,
|
||||
ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
|
||||
ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
|
||||
ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
|
||||
ggml_tensor * kq_b,
|
||||
ggml_tensor * sinks, // [n_head_q]
|
||||
ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
|
||||
ggml_tensor * top_k, // [n_indexer_top_k, n_tokens]
|
||||
float kq_scale,
|
||||
int il) const;
|
||||
|
||||
llm_graph_input_attn_kv_iswa * build_attn_inp_kv_iswa() const;
|
||||
|
||||
// note: if k_cur or v_cur are not provided, they will not be stored in the memory
|
||||
|
|
@ -945,6 +993,8 @@ struct llm_graph_context {
|
|||
float kq_scale,
|
||||
int il) const;
|
||||
|
||||
std::pair<llm_graph_input_attn_k *, llm_graph_input_attn_ik *> build_attn_inp_k_dsa() const;
|
||||
|
||||
//
|
||||
// recurrent
|
||||
//
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load Diff
|
|
@ -0,0 +1,306 @@
|
|||
#pragma once
|
||||
|
||||
#include "llama-kv-cache.h"
|
||||
|
||||
#include "llama-batch.h"
|
||||
#include "llama-graph.h"
|
||||
#include "llama-kv-cells.h"
|
||||
#include "llama-memory.h"
|
||||
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
struct llama_cparams;
|
||||
struct llama_hparams;
|
||||
struct llama_model;
|
||||
struct llama_context;
|
||||
|
||||
//
|
||||
// llama_ik_cache
|
||||
//
|
||||
|
||||
// K-only cache for the DSA lightning-indexer keys.
// Mirrors the cell bookkeeping of llama_kv_cache (same slot/stream machinery)
// but stores a single K tensor per layer and no V tensor.
class llama_ik_cache : public llama_memory_i {
public:
    // reuse the bookkeeping types of the main KV cache so that the two caches
    // inside llama_kv_cache_dsa stay index-compatible
    using stream_copy_info = llama_kv_cache::stream_copy_info;
    using slot_info = llama_kv_cache::slot_info;
    using slot_info_vec_t = std::vector<slot_info>;

    // NOTE(review): type_v/v_trans appear to exist only for signature parity
    // with llama_kv_cache - presumably unused since no V is stored; verify impl
    llama_ik_cache(
            const llama_model & model,
            ggml_type type_k,
            ggml_type type_v,
            bool v_trans,
            bool offload,
            bool unified,
            uint32_t kv_size,
            uint32_t n_seq_max,
            uint32_t n_pad,
            uint32_t n_swa,
            llama_swa_type swa_type,
            const layer_filter_cb & filter,
            const layer_reuse_cb & reuse);

    ~llama_ik_cache() = default;

    //
    // llama_memory_i
    //

    llama_memory_context_ptr init_batch(
            llama_batch_allocr & balloc,
            uint32_t n_ubatch,
            bool embd_all) override;

    llama_memory_context_ptr init_full() override;

    llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override;

    bool get_can_shift() const override;

    void clear(bool data) override;

    bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
    void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
    void seq_keep(llama_seq_id seq_id) override;
    void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override;
    void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;

    llama_pos seq_pos_min(llama_seq_id seq_id) const override;
    llama_pos seq_pos_max(llama_seq_id seq_id) const override;

    std::map<ggml_backend_buffer_type_t, size_t> memory_breakdown() const override;

    // state write/load

    void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override;
    void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override;

    //
    // llama_ik_cache specific API
    //

    uint32_t get_size() const;     // total number of cells
    uint32_t get_n_stream() const; // number of KV streams

    bool get_has_shift() const;

    //
    // graph_build API
    //

    uint32_t get_n_kv(const slot_info & sinfo) const;

    // get views of the current state of the cache
    ggml_tensor * get_k(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const;

    // store k_cur in the cache based on the provided head location
    // (K-only cache: there is no v_cur counterpart)
    ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const;

    //
    // preparation API
    //

    // find places for the provided ubatches in the cache, returns the slot infos
    // return empty vector on failure
    slot_info_vec_t prepare(const std::vector<llama_ubatch> & ubatches);

    bool update(llama_context * lctx, bool do_shift, const stream_copy_info & sc_info);

    // find a slot of kv cells that can hold the ubatch
    // if cont == true, then the slot must be continuous
    // return empty slot_info on failure
    slot_info find_slot(const llama_ubatch & ubatch, bool cont) const;

    // emplace the ubatch context into slot: [sinfo.idxs[0...ubatch.n_tokens - 1]]
    void apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch);

    //
    // input API
    //

    ggml_tensor * build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;

    void set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const;

    void set_input_k_shift(ggml_tensor * dst) const;

    void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;

private:
    const llama_model & model;
    const llama_hparams & hparams;

    struct kv_layer {
        // layer index in the model
        // note: can be different from the layer index in the KV cache
        uint32_t il;

        // indexer key tensor for this layer
        ggml_tensor * k;

        // per-stream views into k
        std::vector<ggml_tensor *> k_stream;
    };

    // NOTE(review): carried over from llama_kv_cache - a K-only cache has no V
    // tensor to transpose; presumably unused, verify
    bool v_trans = true; // the value tensor is transposed

    const uint32_t n_seq_max = 1;
    const uint32_t n_stream = 1;

    // required padding
    const uint32_t n_pad = 1;

    // SWA
    const uint32_t n_swa = 0;

    // env: LLAMA_KV_CACHE_DEBUG
    int debug = 0;

    // this is the SWA type of the cache - not to be confused with the model SWA type
    const llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;

    // ggml contexts for the KV cache along with the allocated backend buffers:
    std::vector<std::pair<ggml_context_ptr, ggml_backend_buffer_ptr>> ctxs_bufs;

    // the current index from where we start searching for a free slot in the ring buffer of KV cells (see find_slot())
    // note: this is not part of the KV state and it's only used to speed-up the find_slot() method
    std::vector<uint32_t> v_heads;

    // per-stream cell metadata (positions, sequence membership)
    std::vector<llama_kv_cells> v_cells;

    // maps from a sequence id to a stream id
    std::vector<uint32_t> seq_to_stream;

    // pending stream copies that will be applied during the next update
    stream_copy_info sc_info;

    std::vector<kv_layer> layers;

    // model layer id -> KV cache layer id
    std::unordered_map<int32_t, int32_t> map_layer_ids;

    size_t total_size() const;

    size_t size_k_bytes() const;

    // apply a RoPE shift to a cached K tensor (used after seq_add)
    ggml_tensor * build_rope_shift(
            const llama_cparams & cparams,
            ggml_context * ctx,
            ggml_tensor * cur,
            ggml_tensor * shift,
            ggml_tensor * factors,
            float freq_base,
            float freq_scale,
            uint32_t il) const;

    ggml_cgraph * build_graph_shift(
            llm_graph_result * res,
            llama_context * lctx) const;

    // contiguous cell ranges of one stream, used for state save/load
    struct cell_ranges_t {
        uint32_t strm;

        std::vector<std::pair<uint32_t, uint32_t>> data; // ranges, from inclusive, to exclusive
    };

    void state_write_meta(llama_io_write_i & io, const cell_ranges_t & cr, llama_seq_id seq_id = -1) const;
    void state_write_data(llama_io_write_i & io, const cell_ranges_t & cr) const;

    bool state_read_meta(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, slot_info & sinfo, llama_seq_id dest_seq_id = -1);
    bool state_read_data(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, const slot_info & sinfo);
};
|
||||
|
||||
// memory context over a llama_ik_cache, covering one batch / update / full view
class llama_ik_cache_context : public llama_memory_context_i {
public:
    // some shorthands
    using slot_info_vec_t = llama_kv_cache::slot_info_vec_t;
    using stream_copy_info = llama_kv_cache::stream_copy_info;

    // used for errors
    llama_ik_cache_context(llama_memory_status status);

    // used to create a full-cache context
    llama_ik_cache_context(
            llama_ik_cache * kv);

    // used to create an update context
    llama_ik_cache_context(
            llama_ik_cache * kv,
            llama_context * lctx,
            bool do_shift,
            stream_copy_info sc_info);

    // used to create a batch processing context from a batch
    llama_ik_cache_context(
            llama_ik_cache * kv,
            slot_info_vec_t sinfos,
            std::vector<llama_ubatch> ubatches);

    virtual ~llama_ik_cache_context();

    //
    // llama_memory_context_i
    //

    bool next() override;
    bool apply() override;

    llama_memory_status get_status() const override;
    const llama_ubatch & get_ubatch() const override;

    //
    // llama_ik_cache_context specific API
    //

    uint32_t get_n_kv() const;

    // get views of the current state of the cache
    ggml_tensor * get_k(ggml_context * ctx, int32_t il) const;

    // store k_cur in the cache based on the provided head location
    // note: the heads in k_cur should be layed out contiguously in memory
    // - k_cur [n_embd_head_k, n_head_k, n_tokens]
    // - k_idxs [n_tokens]
    ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il) const;

    // create destination indices for each head of the current batch for where it would be written in the KV cache
    // the indices address the global KV cache (not per stream) - this is not relevant for the user of this API, but
    // helps understand the implementation logic of cpy_k
    ggml_tensor * build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;

    void set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const;

    void set_input_k_shift (ggml_tensor * dst) const;
    void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;

private:
    llama_memory_status status;

    // non-owning
    llama_ik_cache * kv;
    llama_context * lctx;

    //
    // update context
    //

    bool do_shift = false;

    stream_copy_info sc_info;

    //
    // batch processing context
    //

    // the index of the cur ubatch to process
    size_t i_cur = 0;

    slot_info_vec_t sinfos;

    std::vector<llama_ubatch> ubatches;

    //
    // data needed for building the compute graph for the current ubatch:
    //

    // a heuristic, to avoid attending the full cache if it is not yet utilized
    // as the cache gets filled, the benefit from this heuristic disappears
    int32_t n_kv;
};
|
||||
|
|
@ -0,0 +1,251 @@
|
|||
#include "llama-kv-cache-dsa.h"
|
||||
|
||||
#include "llama-impl.h"
|
||||
#include "llama-batch.h"
|
||||
#include "llama-model.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
|
||||
//
|
||||
// llama_kv_cache_dsa
|
||||
//
|
||||
|
||||
// Construct the composite DSA cache: a standard llama_kv_cache for the model's
// attention keys/values plus a llama_ik_cache for the lightning-indexer keys.
// Both sub-caches get identical size/stream/padding configuration so that
// their cell indices stay in sync.
llama_kv_cache_dsa::llama_kv_cache_dsa(
        const llama_model & model,
        ggml_type type_k,
        ggml_type type_v,
        bool v_trans,
        bool offload,
        bool unified,
        uint32_t kv_size,
        uint32_t n_seq_max,
        uint32_t n_pad,
        uint32_t n_swa,
        llama_swa_type swa_type,
        const layer_filter_cb & filter,
        const layer_reuse_cb & reuse) :
    n_stream(unified ? 1 : n_seq_max) {

    LLAMA_LOG_INFO("%s: creating main KV cache, size = %u cells\n", __func__, kv_size);

    kv_base = std::make_unique<llama_kv_cache>(
            model, type_k, type_v,
            v_trans, offload, unified, kv_size, n_seq_max, n_pad,
            n_swa, swa_type, filter, reuse);

    LLAMA_LOG_INFO("%s: creating indexer KV cache, size = %u cells\n", __func__, kv_size);

    // NOTE(review): the indexer cache receives the same type_v/v_trans even
    // though it stores keys only - presumably ignored there; verify against
    // llama_ik_cache
    kv_ik = std::make_unique<llama_ik_cache>(
            model, type_k, type_v,
            v_trans, offload, unified, kv_size, n_seq_max, n_pad,
            n_swa, swa_type, filter, reuse);
}
|
||||
|
||||
// The llama_memory_i operations below are forwarded to both sub-caches so that
// the main K/V cells and the indexer key cells stay in sync.

void llama_kv_cache_dsa::clear(bool data) {
    kv_base->clear(data);
    kv_ik ->clear(data);
}

bool llama_kv_cache_dsa::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
    bool res = true;

    // bitwise & (not &&) on purpose: both calls must run even if the first fails
    res = res & kv_base->seq_rm(seq_id, p0, p1);
    res = res & kv_ik ->seq_rm(seq_id, p0, p1);

    return res;
}

void llama_kv_cache_dsa::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
    kv_base->seq_cp(seq_id_src, seq_id_dst, p0, p1);
    kv_ik ->seq_cp(seq_id_src, seq_id_dst, p0, p1);
}

void llama_kv_cache_dsa::seq_keep(llama_seq_id seq_id) {
    kv_base->seq_keep(seq_id);
    kv_ik ->seq_keep(seq_id);
}

void llama_kv_cache_dsa::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
    kv_base->seq_add(seq_id, p0, p1, shift);
    kv_ik ->seq_add(seq_id, p0, p1, shift);
}

void llama_kv_cache_dsa::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
    kv_base->seq_div(seq_id, p0, p1, d);
    kv_ik ->seq_div(seq_id, p0, p1, d);
}

// position queries consult only the main cache - the indexer cache mirrors its
// cell layout, so the answers are identical
llama_pos llama_kv_cache_dsa::seq_pos_min(llama_seq_id seq_id) const {
    return kv_base->seq_pos_min(seq_id);
}

llama_pos llama_kv_cache_dsa::seq_pos_max(llama_seq_id seq_id) const {
    return kv_base->seq_pos_max(seq_id);
}
|
||||
|
||||
// Combined memory breakdown: per-buffer-type totals of both sub-caches.
std::map<ggml_backend_buffer_type_t, size_t> llama_kv_cache_dsa::memory_breakdown() const {
    auto mb = kv_base->memory_breakdown();

    // fold the indexer cache usage into the same per-buft totals
    for (const auto & [buft, size] : kv_ik->memory_breakdown()) {
        mb[buft] += size;
    }

    return mb;
}
|
||||
|
||||
// Split the batch into ubatches and reserve matching slots in BOTH sub-caches.
// Returns a batch-processing context on success, or a FAILED_PREPARE context
// if the split or either reservation fails.
llama_memory_context_ptr llama_kv_cache_dsa::init_batch(
        llama_batch_allocr & balloc,
        uint32_t n_ubatch,
        bool embd_all) {
    GGML_UNUSED(embd_all);

    // single-iteration loop: any failure breaks out to the error return below
    do {
        balloc.split_reset();

        // split the batch into ubatches
        std::vector<llama_ubatch> ubatches;
        while (true) {
            auto ubatch = n_stream == 1 ? balloc.split_simple(n_ubatch) : balloc.split_equal(n_ubatch, true);

            if (ubatch.n_tokens == 0) {
                break;
            }

            ubatches.push_back(std::move(ubatch)); // NOLINT
        }

        if (balloc.get_n_used() < balloc.get_n_tokens()) {
            // failed to find a suitable split
            break;
        }

        // reserve slots in both sub-caches; both reservations must succeed
        auto sinfos_base = kv_base->prepare(ubatches);
        if (sinfos_base.empty()) {
            break;
        }

        auto sinfos_ik = kv_ik->prepare(ubatches);
        if (sinfos_ik.empty()) {
            break;
        }

        assert(sinfos_base.size() == sinfos_ik.size());

        return std::make_unique<llama_kv_cache_dsa_context>(
                this, std::move(sinfos_base), std::move(sinfos_ik), std::move(ubatches));
    } while (false);

    return std::make_unique<llama_kv_cache_dsa_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
}
|
||||
|
||||
// full-cache view (e.g. for worst-case graph reservation)
llama_memory_context_ptr llama_kv_cache_dsa::init_full() {
    return std::make_unique<llama_kv_cache_dsa_context>(this);
}

// update view: applies pending shifts/copies to both sub-caches
llama_memory_context_ptr llama_kv_cache_dsa::init_update(llama_context * lctx, bool optimize) {
    return std::make_unique<llama_kv_cache_dsa_context>(this, lctx, optimize);
}

bool llama_kv_cache_dsa::get_can_shift() const {
    // shifting is possible only if both sub-caches support it AND their cell
    // counts match, so a shift keeps them aligned
    return kv_base->get_can_shift() &&
           kv_ik->get_can_shift() &&
           kv_base->get_size() == kv_ik->get_size();
}

// state is serialized as the base cache state followed by the indexer state;
// state_read reads in the same order
void llama_kv_cache_dsa::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
    kv_base->state_write(io, seq_id, flags);
    kv_ik->state_write(io, seq_id, flags);
}

void llama_kv_cache_dsa::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
    kv_base->state_read(io, seq_id, flags);
    kv_ik->state_read(io, seq_id, flags);
}

// non-owning accessors to the two sub-caches

llama_kv_cache * llama_kv_cache_dsa::get_base() const {
    return kv_base.get();
}

llama_ik_cache * llama_kv_cache_dsa::get_ik() const {
    return kv_ik.get();
}
|
||||
|
||||
//
|
||||
// llama_kv_cache_dsa_context
|
||||
//
|
||||
|
||||
// error context: carries only the failure status
llama_kv_cache_dsa_context::llama_kv_cache_dsa_context(llama_memory_status status) : status(status) {}

// full-cache context: wraps full contexts of both sub-caches
llama_kv_cache_dsa_context::llama_kv_cache_dsa_context(
        llama_kv_cache_dsa * kv) :
    ctx_base(kv->get_base()->init_full()),
    ctx_ik(kv->get_ik()->init_full()),
    status(llama_memory_status_combine(ctx_base->get_status(), ctx_ik->get_status())) {
}

// update context: wraps update contexts of both sub-caches
llama_kv_cache_dsa_context::llama_kv_cache_dsa_context(
        llama_kv_cache_dsa * kv,
        llama_context * lctx,
        bool optimize) :
    ctx_base(kv->get_base()->init_update(lctx, optimize)),
    ctx_ik(kv->get_ik()->init_update(lctx, optimize)),
    status(llama_memory_status_combine(ctx_base->get_status(), ctx_ik->get_status())) {
}

// batch-processing context
// note: ubatches must be initialized before ctx_base/ctx_ik (member declaration
// order guarantees this) because this->ubatches is passed to both sub-contexts
llama_kv_cache_dsa_context::llama_kv_cache_dsa_context(
        llama_kv_cache_dsa * kv,
        slot_info_vec_t sinfos_base,
        slot_info_vec_t sinfos_ik,
        std::vector<llama_ubatch> ubatches) :
    ubatches(std::move(ubatches)),
    // note: here we copy the ubatches. not sure if this is ideal
    ctx_base(new llama_kv_cache_context(kv->get_base(), std::move(sinfos_base), this->ubatches)),
    ctx_ik(new llama_ik_cache_context(kv->get_ik(), std::move(sinfos_ik), this->ubatches)),
    status(llama_memory_status_combine(ctx_base->get_status(), ctx_ik->get_status())) {
}

llama_kv_cache_dsa_context:: ~llama_kv_cache_dsa_context() = default;
|
||||
|
||||
bool llama_kv_cache_dsa_context::next() {
|
||||
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
||||
|
||||
ctx_base->next();
|
||||
ctx_ik ->next();
|
||||
|
||||
if (++i_next >= ubatches.size()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool llama_kv_cache_dsa_context::apply() {
|
||||
assert(!llama_memory_status_is_fail(status));
|
||||
|
||||
bool res = true;
|
||||
|
||||
res = res & ctx_base->apply();
|
||||
res = res & ctx_ik ->apply();
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
llama_memory_status llama_kv_cache_dsa_context::get_status() const {
|
||||
return status;
|
||||
}
|
||||
|
||||
const llama_ubatch & llama_kv_cache_dsa_context::get_ubatch() const {
|
||||
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
||||
|
||||
return ubatches[i_next];
|
||||
}
|
||||
|
||||
const llama_kv_cache_context * llama_kv_cache_dsa_context::get_base() const {
|
||||
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
||||
|
||||
return static_cast<const llama_kv_cache_context *>(ctx_base.get());
|
||||
}
|
||||
|
||||
const llama_ik_cache_context * llama_kv_cache_dsa_context::get_ik() const {
|
||||
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
||||
|
||||
return static_cast<const llama_ik_cache_context *>(ctx_ik.get());
|
||||
}
|
||||
|
|
@ -0,0 +1,137 @@
|
|||
#pragma once
|
||||
|
||||
#include "llama-kv-cache.h"
|
||||
#include "llama-ik-cache.h"
|
||||
|
||||
#include <vector>
|
||||
|
||||
//
|
||||
// llama_kv_cache_dsa
|
||||
//
|
||||
|
||||
// utilizes two KV cache instances: llama_kv_cache and llama_ik_cache
|
||||
// the first instance is for caching key tensors of the model,
|
||||
// the second instance is for caching lightning indexer key tensors
|
||||
|
||||
class llama_kv_cache_dsa : public llama_memory_i {
public:
    // the common parameters are forwarded to both sub-caches
    llama_kv_cache_dsa(
            const llama_model & model,
            ggml_type type_k,
            ggml_type type_v,
            bool v_trans,
            bool offload,
            bool unified,
            uint32_t kv_size,
            uint32_t n_seq_max,
            uint32_t n_pad,
            uint32_t n_swa,
            llama_swa_type swa_type,
            const layer_filter_cb & filter,
            const layer_reuse_cb & reuse);

    ~llama_kv_cache_dsa() = default;

    //
    // llama_memory_i
    //

    llama_memory_context_ptr init_batch(
            llama_batch_allocr & balloc,
            uint32_t n_ubatch,
            bool embd_all) override;

    llama_memory_context_ptr init_full() override;

    llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override;

    bool get_can_shift() const override;

    void clear(bool data) override;

    // sequence operations are applied to both sub-caches
    bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
    void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
    void seq_keep(llama_seq_id seq_id) override;
    void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override;
    void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;

    llama_pos seq_pos_min(llama_seq_id seq_id) const override;
    llama_pos seq_pos_max(llama_seq_id seq_id) const override;

    std::map<ggml_backend_buffer_type_t, size_t> memory_breakdown() const override;

    // state write/load

    void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override;
    void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override;

    //
    // llama_kv_cache_dsa specific API
    //

    // non-owning accessors to the two sub-caches
    llama_kv_cache * get_base() const;
    llama_ik_cache * get_ik () const;

private:
    // number of KV streams (1 when the cache is unified)
    const uint32_t n_stream = 1;

    std::unique_ptr<llama_kv_cache> kv_base; // main model K/V cache
    std::unique_ptr<llama_ik_cache> kv_ik;   // lightning indexer key cache
};
|
||||
|
||||
// memory context that advances the base and indexer sub-contexts in lockstep
class llama_kv_cache_dsa_context : public llama_memory_context_i {
public:
    using slot_info_vec_t = llama_kv_cache::slot_info_vec_t;

    // used for errors
    llama_kv_cache_dsa_context(llama_memory_status status);

    // used to create a full-cache context
    llama_kv_cache_dsa_context(
            llama_kv_cache_dsa * kv);

    // used to create an update context
    llama_kv_cache_dsa_context(
            llama_kv_cache_dsa * kv,
            llama_context * lctx,
            bool optimize);

    // used to create a batch processing context from a batch
    llama_kv_cache_dsa_context(
            llama_kv_cache_dsa * kv,
            slot_info_vec_t sinfos_base,
            slot_info_vec_t sinfos_ik,
            std::vector<llama_ubatch> ubatches);

    virtual ~llama_kv_cache_dsa_context();

    //
    // llama_memory_context_i
    //

    bool next() override;
    bool apply() override;

    llama_memory_status get_status() const override;
    const llama_ubatch & get_ubatch() const override;

    //
    // llama_kv_cache_dsa_context specific API
    //

    // typed accessors to the wrapped sub-contexts
    const llama_kv_cache_context * get_base() const;
    const llama_ik_cache_context * get_ik() const;

private:
    //llama_kv_cache_dsa * kv;

    // the index of the next ubatch to process
    size_t i_next = 0;

    // note: declared before ctx_base/ctx_ik - the batch-processing constructor
    // passes this->ubatches to them, so member initialization order matters
    std::vector<llama_ubatch> ubatches;

    const llama_memory_context_ptr ctx_base;
    const llama_memory_context_ptr ctx_ik;

    const llama_memory_status status;
};
|
||||
|
|
@ -51,7 +51,7 @@ llama_kv_cache::llama_kv_cache(
|
|||
auto it = ctx_map.find(buft);
|
||||
if (it == ctx_map.end()) {
|
||||
ggml_init_params params = {
|
||||
/*.mem_size =*/ size_t(3u*(1 + n_stream)*n_layer_kv*ggml_tensor_overhead()),
|
||||
/*.mem_size =*/ size_t(2u*(1 + n_stream)*n_layer_kv*ggml_tensor_overhead()),
|
||||
/*.mem_buffer =*/ NULL,
|
||||
/*.no_alloc =*/ true,
|
||||
};
|
||||
|
|
@ -113,7 +113,6 @@ llama_kv_cache::llama_kv_cache(
|
|||
// [TAG_V_CACHE_VARIABLE]
|
||||
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
|
||||
const uint32_t n_embd_v_gqa = !v_trans ? hparams.n_embd_v_gqa(il) : hparams.n_embd_v_gqa_max();
|
||||
const uint32_t n_embd_indexer_head = hparams.indexer_head_size;
|
||||
|
||||
const char * dev_name = "CPU";
|
||||
|
||||
|
|
@ -135,29 +134,24 @@ llama_kv_cache::llama_kv_cache(
|
|||
|
||||
const bool has_k = true;
|
||||
const bool has_v = !is_mla;
|
||||
const bool has_ik = hparams.indexer_top_k > 0;
|
||||
|
||||
ggml_tensor * k = has_k ? ggml_new_tensor_3d(ctx, type_k, n_embd_k_gqa, kv_size, n_stream) : nullptr;
|
||||
ggml_tensor * v = has_v ? ggml_new_tensor_3d(ctx, type_v, n_embd_v_gqa, kv_size, n_stream) : nullptr;
|
||||
ggml_tensor * ik = has_ik ? ggml_new_tensor_3d(ctx, type_k, n_embd_indexer_head, kv_size, n_stream) : nullptr;
|
||||
|
||||
has_k && ggml_format_name(k, "cache_k_l%d", il);
|
||||
has_v && ggml_format_name(v, "cache_v_l%d", il);
|
||||
has_ik && ggml_format_name(ik, "cache_ik_l%d", il);
|
||||
|
||||
std::vector<ggml_tensor *> k_stream;
|
||||
std::vector<ggml_tensor *> v_stream;
|
||||
std::vector<ggml_tensor *> ik_stream;
|
||||
|
||||
for (uint32_t s = 0; s < n_stream; ++s) {
|
||||
k_stream.push_back(has_k ? ggml_view_2d(ctx, k, n_embd_k_gqa, kv_size, k->nb[1], s*k->nb[2]) : nullptr);
|
||||
v_stream.push_back(has_v ? ggml_view_2d(ctx, v, n_embd_v_gqa, kv_size, v->nb[1], s*v->nb[2]) : nullptr);
|
||||
ik_stream.push_back(has_ik ? ggml_view_2d(ctx, ik, n_embd_indexer_head, kv_size, ik->nb[1], s*ik->nb[2]) : nullptr);
|
||||
}
|
||||
|
||||
map_layer_ids[il] = layers.size();
|
||||
|
||||
layers.push_back({ il, k, v, ik, k_stream, v_stream, ik_stream });
|
||||
layers.push_back({ il, k, v, k_stream, v_stream, });
|
||||
}
|
||||
|
||||
if (reuse) {
|
||||
|
|
@ -208,13 +202,11 @@ llama_kv_cache::llama_kv_cache(
|
|||
{
|
||||
const size_t memory_size_k = size_k_bytes();
|
||||
const size_t memory_size_v = size_v_bytes();
|
||||
const size_t memory_size_ik = size_ik_bytes();
|
||||
|
||||
LLAMA_LOG_INFO("%s: size = %7.2f MiB (%6u cells, %3d layers, %2u/%u seqs), K (%s): %7.2f MiB, V (%s): %7.2f MiB, IK (%s): %7.2f MiB\n", __func__,
|
||||
LLAMA_LOG_INFO("%s: size = %7.2f MiB (%6u cells, %3d layers, %2u/%u seqs), K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
|
||||
(float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f), kv_size, (int) layers.size(), n_seq_max, n_stream,
|
||||
ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
|
||||
ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f),
|
||||
ggml_type_name(type_k), (float)memory_size_ik / (1024.0f * 1024.0f));
|
||||
ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
|
||||
}
|
||||
|
||||
const char * LLAMA_KV_CACHE_DEBUG = getenv("LLAMA_KV_CACHE_DEBUG");
|
||||
|
|
@ -664,10 +656,6 @@ bool llama_kv_cache::update(llama_context * lctx, bool do_shift, const stream_co
|
|||
if (layer.v_stream[ssrc]) {
|
||||
ggml_backend_tensor_copy(layer.v_stream[ssrc], layer.v_stream[sdst]);
|
||||
}
|
||||
|
||||
if (layer.ik_stream[ssrc]) {
|
||||
ggml_backend_tensor_copy(layer.ik_stream[ssrc], layer.ik_stream[sdst]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -1084,26 +1072,6 @@ ggml_tensor * llama_kv_cache::get_v(ggml_context * ctx, int32_t il, uint32_t n_k
|
|||
ggml_row_size(v->type, kv_size*n_embd_v_gqa)*sinfo.s0);
|
||||
}
|
||||
|
||||
ggml_tensor * llama_kv_cache::get_ik(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const {
|
||||
const int32_t ikv = map_layer_ids.at(il);
|
||||
|
||||
auto * ik = layers[ikv].ik;
|
||||
|
||||
const uint64_t kv_size = get_size();
|
||||
const uint64_t n_embd_indexer_head = ik->ne[0];
|
||||
|
||||
assert(n_embd_indexer_head == hparams.indexer_head_size);
|
||||
|
||||
const uint32_t ns = sinfo.s1 - sinfo.s0 + 1;
|
||||
|
||||
return ggml_view_4d(ctx, ik,
|
||||
n_embd_indexer_head, 1, n_kv, ns,
|
||||
ggml_row_size(ik->type, n_embd_indexer_head),
|
||||
ggml_row_size(ik->type, n_embd_indexer_head),
|
||||
ggml_row_size(ik->type, n_embd_indexer_head*kv_size),
|
||||
ggml_row_size(ik->type, n_embd_indexer_head*kv_size)*sinfo.s0);
|
||||
}
|
||||
|
||||
ggml_tensor * llama_kv_cache::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const {
|
||||
GGML_UNUSED(sinfo);
|
||||
|
||||
|
|
@ -1195,41 +1163,6 @@ ggml_tensor * llama_kv_cache::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggm
|
|||
return ggml_set_rows(ctx, v_view, v_cur, v_idxs);
|
||||
}
|
||||
|
||||
ggml_tensor * llama_kv_cache::cpy_ik(ggml_context * ctx, ggml_tensor * ik_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const {
|
||||
GGML_UNUSED(sinfo);
|
||||
|
||||
const int32_t ikv = map_layer_ids.at(il);
|
||||
|
||||
ggml_tensor * ik = layers[ikv].ik;
|
||||
|
||||
const int64_t n_embd_indexer_head = ik_cur->ne[0];
|
||||
const int64_t n_head = ik_cur->ne[1];
|
||||
const int64_t n_tokens = ik_cur->ne[2];
|
||||
|
||||
const int64_t n_embd_gqa = n_embd_indexer_head*n_head;
|
||||
|
||||
// we can merge dims 0 and 1
|
||||
// TODO: add ggml helper function for this?
|
||||
GGML_ASSERT(ggml_row_size(ik_cur->type, n_embd_indexer_head) == ik_cur->nb[1]);
|
||||
|
||||
ik_cur = ggml_view_2d(ctx, ik_cur, n_embd_gqa, n_tokens, ik_cur->nb[2], 0);
|
||||
|
||||
const int64_t n_stream = ik->ne[2];
|
||||
|
||||
if (n_stream > 1) {
|
||||
const int64_t kv_size = get_size();
|
||||
|
||||
assert(n_embd_gqa == ik->ne[0]);
|
||||
assert(kv_size == ik->ne[1]);
|
||||
|
||||
// merge the buffer across all streams because the idxs are global
|
||||
ik = ggml_reshape_2d(ctx, ik, n_embd_gqa, kv_size*n_stream);
|
||||
}
|
||||
|
||||
// store the current K values into the cache
|
||||
return ggml_set_rows(ctx, ik, ik_cur, k_idxs);
|
||||
}
|
||||
|
||||
ggml_tensor * llama_kv_cache::build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
|
||||
const uint32_t n_tokens = ubatch.n_tokens;
|
||||
|
||||
|
|
@ -1604,16 +1537,6 @@ size_t llama_kv_cache::size_v_bytes() const {
|
|||
return size_v_bytes;
|
||||
}
|
||||
|
||||
size_t llama_kv_cache::size_ik_bytes() const {
|
||||
size_t size_ik_bytes = 0;
|
||||
|
||||
for (const auto & layer : layers) {
|
||||
size_ik_bytes += layer.ik ? ggml_nbytes(layer.ik) : 0;
|
||||
}
|
||||
|
||||
return size_ik_bytes;
|
||||
}
|
||||
|
||||
ggml_tensor * llama_kv_cache::build_rope_shift(
|
||||
const llama_cparams & cparams,
|
||||
ggml_context * ctx,
|
||||
|
|
@ -2319,10 +2242,6 @@ ggml_tensor * llama_kv_cache_context::get_v(ggml_context * ctx, int32_t il) cons
|
|||
return kv->get_v(ctx, il, n_kv, sinfos[i_cur]);
|
||||
}
|
||||
|
||||
ggml_tensor * llama_kv_cache_context::get_ik(ggml_context * ctx, int32_t il) const {
|
||||
return kv->get_ik(ctx, il, n_kv, sinfos[i_cur]);
|
||||
}
|
||||
|
||||
ggml_tensor * llama_kv_cache_context::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il) const {
|
||||
return kv->cpy_k(ctx, k_cur, k_idxs, il, sinfos[i_cur]);
|
||||
}
|
||||
|
|
@ -2331,10 +2250,6 @@ ggml_tensor * llama_kv_cache_context::cpy_v(ggml_context * ctx, ggml_tensor * v_
|
|||
return kv->cpy_v(ctx, v_cur, v_idxs, il, sinfos[i_cur]);
|
||||
}
|
||||
|
||||
ggml_tensor * llama_kv_cache_context::cpy_ik(ggml_context * ctx, ggml_tensor * ik_cur, ggml_tensor * k_idxs, int32_t il) const {
|
||||
return kv->cpy_ik(ctx, ik_cur, k_idxs, il, sinfos[i_cur]);
|
||||
}
|
||||
|
||||
ggml_tensor * llama_kv_cache_context::build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
|
||||
return kv->build_input_k_idxs(ctx, ubatch);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -161,12 +161,10 @@ public:
|
|||
// get views of the current state of the cache
|
||||
ggml_tensor * get_k(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const;
|
||||
ggml_tensor * get_v(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const;
|
||||
ggml_tensor * get_ik(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const;
|
||||
|
||||
// store k_cur and v_cur in the cache based on the provided head location
|
||||
ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const;
|
||||
ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il, const slot_info & sinfo) const;
|
||||
ggml_tensor * cpy_ik(ggml_context * ctx, ggml_tensor * ik_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const;
|
||||
|
||||
//
|
||||
// preparation API
|
||||
|
|
@ -212,11 +210,9 @@ private:
|
|||
|
||||
ggml_tensor * k;
|
||||
ggml_tensor * v;
|
||||
ggml_tensor * ik;
|
||||
|
||||
std::vector<ggml_tensor *> k_stream;
|
||||
std::vector<ggml_tensor *> v_stream;
|
||||
std::vector<ggml_tensor *> ik_stream;
|
||||
};
|
||||
|
||||
bool v_trans = true; // the value tensor is transposed
|
||||
|
|
@ -260,7 +256,6 @@ private:
|
|||
|
||||
size_t size_k_bytes() const;
|
||||
size_t size_v_bytes() const;
|
||||
size_t size_ik_bytes() const;
|
||||
|
||||
ggml_tensor * build_rope_shift(
|
||||
const llama_cparams & cparams,
|
||||
|
|
@ -336,7 +331,6 @@ public:
|
|||
// get views of the current state of the cache
|
||||
ggml_tensor * get_k(ggml_context * ctx, int32_t il) const;
|
||||
ggml_tensor * get_v(ggml_context * ctx, int32_t il) const;
|
||||
ggml_tensor * get_ik(ggml_context * ctx, int32_t il) const;
|
||||
|
||||
// store k_cur and v_cur in the cache based on the provided head location
|
||||
// note: the heads in k_cur and v_cur should be layed out contiguously in memory
|
||||
|
|
@ -346,7 +340,6 @@ public:
|
|||
// - v_idxs [n_tokens] or [n_tokens*n_embd_v_gqa] depending if V cache is transposed
|
||||
ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il) const;
|
||||
ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il) const;
|
||||
ggml_tensor * cpy_ik(ggml_context * ctx, ggml_tensor * ik_cur, ggml_tensor * k_idxs, int32_t il) const;
|
||||
|
||||
// create destination indices for each head of the current batch for where it would be written in the KV cache
|
||||
// the indices address the global KV cache (not per stream) - this is not relevant for the user of this API, but
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@
|
|||
|
||||
#include "llama-kv-cache.h"
|
||||
#include "llama-kv-cache-iswa.h"
|
||||
#include "llama-kv-cache-dsa.h"
|
||||
#include "llama-memory-hybrid.h"
|
||||
#include "llama-memory-hybrid-iswa.h"
|
||||
#include "llama-memory-recurrent.h"
|
||||
|
|
@ -8111,6 +8112,23 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
|
|||
{
|
||||
res = nullptr;
|
||||
} break;
|
||||
case LLM_ARCH_DEEPSEEK32:
|
||||
{
|
||||
res = new llama_kv_cache_dsa(
|
||||
*this,
|
||||
params.type_k,
|
||||
params.type_v,
|
||||
!cparams.flash_attn,
|
||||
cparams.offload_kqv,
|
||||
cparams.kv_unified,
|
||||
cparams.n_ctx_seq,
|
||||
cparams.n_seq_max,
|
||||
1,
|
||||
hparams.n_swa,
|
||||
hparams.swa_type,
|
||||
nullptr,
|
||||
nullptr);
|
||||
} break;
|
||||
// Models that need standard caching should rely on recurrent/hybrid
|
||||
// checks
|
||||
default:
|
||||
|
|
|
|||
|
|
@ -1,14 +1,17 @@
|
|||
#include "models.h"
|
||||
|
||||
#include "llama-kv-cache.h"
|
||||
#include "llama-ik-cache.h"
|
||||
|
||||
llm_build_deepseek32::llm_build_deepseek32(const llama_model & model, const llm_graph_params & params) :
|
||||
llm_graph_context(params) {
|
||||
const bool is_mla = hparams.is_mla();
|
||||
GGML_ASSERT(is_mla);
|
||||
|
||||
// note: these are the actual head sizes you get when treating as MHA or after "decompression" using wv_b for MLA
|
||||
const int64_t n_embd_head_k = hparams.n_embd_head_k_mla();
|
||||
const int64_t n_embd_head_v = hparams.n_embd_head_v_mla();
|
||||
GGML_UNUSED(n_embd_head_v);
|
||||
|
||||
const int64_t n_embd_head_qk_rope = hparams.n_rot();
|
||||
const int64_t n_embd_head_qk_nope = n_embd_head_k - n_embd_head_qk_rope;
|
||||
|
|
@ -42,8 +45,9 @@ llm_build_deepseek32::llm_build_deepseek32(const llama_model & model, const llm_
|
|||
// inp_pos - contains the positions
|
||||
ggml_tensor * inp_pos = build_inp_pos();
|
||||
|
||||
auto * inp_attn_kv = !is_mla ? build_attn_inp_kv() : nullptr;
|
||||
auto * inp_attn_k = is_mla ? build_attn_inp_k() : nullptr;
|
||||
std::pair<llm_graph_input_attn_k*, llm_graph_input_attn_ik*> inp_attn_dsa = build_attn_inp_k_dsa();
|
||||
auto * inp_attn_k = inp_attn_dsa.first;
|
||||
auto * inp_attn_ik = inp_attn_dsa.second;
|
||||
|
||||
ggml_tensor * inp_out_ids = build_inp_out_ids();
|
||||
|
||||
|
|
@ -63,9 +67,7 @@ llm_build_deepseek32::llm_build_deepseek32(const llama_model & model, const llm_
|
|||
qr = build_norm(qr, model.layers[il].attn_q_a_norm, nullptr, LLM_NORM_RMS, il);
|
||||
cb(qr, "qr", il);
|
||||
|
||||
ggml_tensor * kq_mask = is_mla ? inp_attn_k->get_kq_mask() : inp_attn_kv->get_kq_mask();
|
||||
ggml_tensor * kq_mask_bak = ggml_dup(ctx0, kq_mask);
|
||||
ggml_build_forward_expand(gf, kq_mask_bak);
|
||||
ggml_tensor * top_k = nullptr;
|
||||
|
||||
// lightning indexer
|
||||
{
|
||||
|
|
@ -133,9 +135,9 @@ llm_build_deepseek32::llm_build_deepseek32(const llama_model & model, const llm_
|
|||
cb(indexer_k, "indexer_k", il);
|
||||
|
||||
// store indexer keys to KV cache
|
||||
const auto * mctx_cur = is_mla ? inp_attn_k->mctx : inp_attn_kv->mctx;
|
||||
const auto & k_idxs = is_mla ? inp_attn_k->get_k_idxs() : inp_attn_kv->get_k_idxs();
|
||||
ggml_build_forward_expand(gf, mctx_cur->cpy_ik(ctx0, indexer_k, k_idxs, il));
|
||||
const auto * mctx_cur = inp_attn_ik->mctx;
|
||||
const auto & k_idxs = inp_attn_ik->get_k_idxs();
|
||||
ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, indexer_k, k_idxs, il));
|
||||
|
||||
// prepare indexer weights
|
||||
ggml_tensor * indexer_weights = ggml_mul_mat(ctx0, model.layers[il].indexer_proj, cur);
|
||||
|
|
@ -145,7 +147,7 @@ llm_build_deepseek32::llm_build_deepseek32(const llama_model & model, const llm_
|
|||
cb(indexer_weights, "indexer_weights", il);
|
||||
|
||||
// get cached indexer keys
|
||||
indexer_k = mctx_cur->get_ik(ctx0, il);
|
||||
indexer_k = mctx_cur->get_k(ctx0, il);
|
||||
|
||||
// split the batch into streams if needed
|
||||
const auto n_stream = indexer_k->ne[3];
|
||||
|
|
@ -188,24 +190,14 @@ llm_build_deepseek32::llm_build_deepseek32(const llama_model & model, const llm_
|
|||
cb(indexer_score, "indexer_score", il);
|
||||
|
||||
// mask indexer scores
|
||||
ggml_tensor * kq_mask_f32 = ggml_cast(ctx0, kq_mask, GGML_TYPE_F32);
|
||||
indexer_score = ggml_add(ctx0, indexer_score, kq_mask_f32);
|
||||
ggml_tensor * indexer_kq_mask = inp_attn_ik->get_kq_mask();
|
||||
indexer_score = ggml_add(ctx0, indexer_score, indexer_kq_mask);
|
||||
cb(indexer_score, "indexer_score", il);
|
||||
|
||||
// get indices of top k indexer scores
|
||||
uint32_t n_top_k = indexer_score->ne[0] < n_indexer_top_k ? indexer_score->ne[0] : n_indexer_top_k;
|
||||
ggml_tensor * top_k = ggml_cont(ctx0, ggml_top_k(ctx0, indexer_score, n_top_k));
|
||||
top_k = ggml_cont(ctx0, ggml_top_k(ctx0, indexer_score, n_top_k));
|
||||
cb(top_k, "top_k", il);
|
||||
|
||||
// prepare new kq mask - starts filled with -INFINITY
|
||||
ggml_tensor * kq_mask_all = ggml_fill(ctx0, kq_mask_f32, -INFINITY);
|
||||
cb(kq_mask_all, "kq_mask_all", il);
|
||||
|
||||
// modify it by unmasking tokens that are in top_k indices
|
||||
ggml_tensor * kq_mask_top_k = ggml_where_id(ctx0, kq_mask_f32, kq_mask_all, top_k);
|
||||
cb(kq_mask_top_k, "kq_mask_top_k", il);
|
||||
|
||||
ggml_build_forward_expand(gf, ggml_cpy(ctx0, ggml_cast(ctx0, kq_mask_top_k, kq_mask->type), kq_mask));
|
||||
}
|
||||
|
||||
ggml_tensor * q = ggml_mul_mat(ctx0, model.layers[il].wq_b, qr);
|
||||
|
|
@ -250,7 +242,8 @@ llm_build_deepseek32::llm_build_deepseek32(const llama_model & model, const llm_
|
|||
kv_cmpr = build_norm(kv_cmpr, model.layers[il].attn_kv_a_norm, nullptr, LLM_NORM_RMS, il);
|
||||
cb(kv_cmpr, "kv_cmpr", il);
|
||||
|
||||
if (is_mla) {
|
||||
// MLA attention
|
||||
{
|
||||
// {n_embd_head_qk_nope, n_tokens, n_head}
|
||||
q_nope = ggml_permute(ctx0, q_nope, 0, 2, 1, 3);
|
||||
cb(q_nope, "q_nope_perm", il);
|
||||
|
|
@ -282,41 +275,8 @@ llm_build_deepseek32::llm_build_deepseek32(const llama_model & model, const llm_
|
|||
// note: MLA with the absorption optimization converts into MQA (ie: GQA with 1 group)
|
||||
cur = build_attn(inp_attn_k,
|
||||
model.layers[il].wo, NULL,
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, model.layers[il].wv_b, kq_scale, il);
|
||||
} else {
|
||||
ggml_tensor * kv = ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_cmpr);
|
||||
cb(kv, "kv", il);
|
||||
|
||||
// split into {n_embd_head_qk_nope, n_head, n_tokens}
|
||||
ggml_tensor * k_nope =
|
||||
ggml_view_3d(ctx0, kv, n_embd_head_qk_nope, n_head, n_tokens,
|
||||
ggml_row_size(kv->type, n_embd_head_qk_nope + n_embd_head_v),
|
||||
ggml_row_size(kv->type, n_embd_head_qk_nope + n_embd_head_v) * n_head, 0);
|
||||
cb(k_nope, "k_nope_view", il);
|
||||
|
||||
// and {n_embd_head_v, n_head, n_tokens}
|
||||
ggml_tensor * Vcur = ggml_view_3d(ctx0, kv, n_embd_head_v, n_head, n_tokens,
|
||||
ggml_row_size(kv->type, n_embd_head_qk_nope + n_embd_head_v),
|
||||
ggml_row_size(kv->type, n_embd_head_qk_nope + n_embd_head_v) * n_head,
|
||||
ggml_row_size(kv->type, n_embd_head_qk_nope));
|
||||
cb(Vcur, "Vcur_view", il);
|
||||
|
||||
Vcur = ggml_cont(ctx0, Vcur);
|
||||
cb(Vcur, "Vcur_cont", il);
|
||||
|
||||
ggml_tensor * Qcur = ggml_concat(ctx0, q_nope, q_pe, 0);
|
||||
cb(Qcur, "Qcur", il);
|
||||
|
||||
ggml_tensor * Kcur = ggml_concat(ctx0, k_nope, ggml_repeat(ctx0, k_pe, q_pe), 0);
|
||||
cb(Kcur, "Kcur", il);
|
||||
|
||||
// note: MLA without the absorption optimization converts into MHA (ie: GQA with full n_head groups)
|
||||
cur = build_attn(inp_attn_kv,
|
||||
model.layers[il].wo, NULL,
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, model.layers[il].wv_b, top_k, kq_scale, il);
|
||||
}
|
||||
|
||||
ggml_build_forward_expand(gf, ggml_cpy(ctx0, kq_mask_bak, kq_mask));
|
||||
}
|
||||
if (il == effective_n_layers - 1 && inp_out_ids) {
|
||||
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
||||
|
|
|
|||
Loading…
Reference in New Issue