kv-cache : added llama_kv_cache_dsa KV cache specific to DSA composed of llama_kv_cache and new llama_ik_cache (lightning indexer key cache).

model : use the new llama_kv_cache_dsa in DeepseekV32ForCausalLM instead of the previously modified llama_kv_cache with built-in indexer keys
model : removed non-MLA path in DeepseekV32ForCausalLM
This commit is contained in:
Stanisław Szymczyk 2026-03-24 13:51:33 +01:00
parent 1874ac9b86
commit 4309c8486a
11 changed files with 2819 additions and 153 deletions

View File

@ -22,6 +22,8 @@ add_library(llama
llama-io.cpp
llama-kv-cache.cpp
llama-kv-cache-iswa.cpp
llama-ik-cache.cpp
llama-kv-cache-dsa.cpp
llama-memory.cpp
llama-memory-hybrid.cpp
llama-memory-hybrid-iswa.cpp

View File

@ -6,6 +6,7 @@
#include "llama-kv-cache.h"
#include "llama-kv-cache-iswa.h"
#include "llama-kv-cache-dsa.h"
#include "llama-memory-hybrid.h"
#include "llama-memory-hybrid-iswa.h"
#include "llama-memory-recurrent.h"
@ -31,6 +32,18 @@ static ggml_tensor * build_kq_mask(
return ggml_new_tensor_4d(ctx, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream);
}
// allocates the KQ mask tensor for the indexer key (IK) cache:
// F32 [n_kv, n_tokens/n_stream, 1, n_stream] (single stream when the cache is unified)
static ggml_tensor * build_kq_mask(
        ggml_context * ctx,
        const llama_ik_cache_context * mctx,
        const llama_ubatch & ubatch,
        const llama_cparams & cparams) {
    const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;

    return ggml_new_tensor_4d(ctx, GGML_TYPE_F32, mctx->get_n_kv(), ubatch.n_tokens/n_stream, 1, n_stream);
}
static bool can_reuse_kq_mask(
ggml_tensor * kq_mask,
const llama_kv_cache_context * mctx,
@ -50,6 +63,25 @@ static bool can_reuse_kq_mask(
return res;
}
// a previously built KQ mask is reusable only if its shape matches what
// build_kq_mask would allocate for the current ubatch / IK-cache state
static bool can_reuse_kq_mask(
        ggml_tensor * kq_mask,
        const llama_ik_cache_context * mctx,
        const llama_ubatch & ubatch,
        const llama_cparams & cparams) {
    const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;

    return kq_mask->ne[0] == mctx->get_n_kv()          &&
           kq_mask->ne[1] == ubatch.n_tokens/n_stream  &&
           kq_mask->ne[2] == 1                         &&
           kq_mask->ne[3] == n_stream;
}
// impl
void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) {
@ -2108,6 +2140,112 @@ ggml_tensor * llm_graph_context::build_attn(
return cur;
}
// DSA (DeepSeek sparse attention) variant of build_attn:
// - stores only K into the cache (no separate V data)
// - restricts attention to the positions selected by the lightning indexer (top_k)
// - reads V as a view into the cached K (MLA-style layout)
ggml_tensor * llm_graph_context::build_attn(
llm_graph_input_attn_k * inp,
ggml_tensor * wo,
ggml_tensor * wo_b,
ggml_tensor * q_cur,
ggml_tensor * k_cur,
ggml_tensor * v_cur,
ggml_tensor * kq_b,
ggml_tensor * sinks,
ggml_tensor * v_mla,
ggml_tensor * top_k,
float kq_scale,
int il) const {
// these nodes are added to the graph together so that they are not reordered
// by doing so, the number of splits in the graph is reduced
// expand k later to enable rope fusion which directly writes into k-v cache
ggml_build_forward_expand(gf, q_cur);
ggml_build_forward_expand(gf, v_cur);
ggml_build_forward_expand(gf, k_cur);
const auto * mctx_cur = inp->mctx;
// store to KV cache
// note: only K is written - this attention variant keeps no separate V tensor
{
const auto & k_idxs = inp->get_k_idxs();
ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
}
const auto & kq_mask = inp->get_kq_mask();
// work on an F32 copy of the mask so it can be combined with -INFINITY below
ggml_tensor * kq_mask_f32 = ggml_cast(ctx0, kq_mask, GGML_TYPE_F32);
// prepare new kq mask - starts filled with -INFINITY
ggml_tensor * kq_mask_all = ggml_fill(ctx0, kq_mask_f32, -INFINITY);
// modify it by unmasking tokens that are in top_k indices
// (assumed ggml_where_id semantics: positions listed in top_k keep the original
//  mask value, all others stay -INF - TODO confirm)
ggml_tensor * kq_mask_top_k = ggml_where_id(ctx0, kq_mask_f32, kq_mask_all, top_k);
// cast back to the mask type expected by build_attn_mha
kq_mask_top_k = ggml_cast(ctx0, kq_mask_top_k, kq_mask->type);
ggml_tensor * q = q_cur;
ggml_tensor * k = mctx_cur->get_k(ctx0, il);
// V is a view of the first v_cur->ne[0] elements of each cached K row -
// assumes the value data is a prefix of the cached K rows (MLA); TODO confirm
ggml_tensor * v = ggml_view_4d(ctx0, k, v_cur->ne[0], k->ne[1], k->ne[2], k->ne[3], k->nb[1], k->nb[2], k->nb[3], 0);
ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask_top_k, sinks, v_mla, kq_scale, il);
cb(cur, "kqv_out", il);
if (wo) {
cur = build_lora_mm(wo, cur);
if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE) {
// GLM4 and GLM4_MOE seem to have numerical issues with half-precision accumulators
ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
}
}
if (wo_b) {
cur = ggml_add(ctx0, cur, wo_b);
}
return cur;
}
// creates the graph input object for the indexer key (IK) cache:
// the K write indices and the KQ mask (no V - the IK cache is K-only)
static std::unique_ptr<llm_graph_input_attn_ik> build_attn_inp_ik_impl(
ggml_context * ctx0,
const llama_ubatch & ubatch,
const llama_hparams & hparams,
const llama_cparams & cparams,
const llama_ik_cache_context * mctx_cur) {
auto inp = std::make_unique<llm_graph_input_attn_ik>(hparams, cparams, mctx_cur);
{
GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_iswa for SWA");
inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
inp->self_kq_mask = build_kq_mask(ctx0, mctx_cur, ubatch, cparams);
ggml_set_input(inp->self_kq_mask);
// no type conversion of the mask is performed here, so _cnv aliases the mask itself
inp->self_kq_mask_cnv = inp->self_kq_mask;
}
return inp;
}
// uploads the per-ubatch data (K write indices and KQ mask) into the graph
// input tensors, based on the current state of the IK cache context
void llm_graph_input_attn_ik::set_input(const llama_ubatch * ubatch) {
mctx->set_input_k_idxs(self_k_idxs, ubatch);
mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
}
// checks whether this input object (and the graph that references it) can be
// reused for the new parameters; rebinds the memory context either way
bool llm_graph_input_attn_ik::can_reuse(const llm_graph_params & params) {
    const auto * mctx_new = static_cast<const llama_ik_cache_context *>(params.mctx);

    this->mctx = mctx_new;

    const bool idxs_ok = self_k_idxs->ne[0] == params.ubatch.n_tokens;
    const bool mask_ok = can_reuse_kq_mask(self_kq_mask, mctx_new, params.ubatch, params.cparams);

    return idxs_ok && mask_ok;
}
ggml_tensor * llm_graph_context::build_attn(
llm_graph_input_attn_kv_iswa * inp,
ggml_tensor * wo,
@ -2230,6 +2368,17 @@ ggml_tensor * llm_graph_context::build_attn(
return cur;
}
// builds the pair of graph inputs needed by DSA attention:
// - llm_graph_input_attn_k  for the main (V-less) KV cache
// - llm_graph_input_attn_ik for the lightning indexer key cache
// both are registered with the graph result so their set_input() runs per ubatch
std::pair<llm_graph_input_attn_k *, llm_graph_input_attn_ik *> llm_graph_context::build_attn_inp_k_dsa() const {
const auto * mctx_cur = static_cast<const llama_kv_cache_dsa_context *>(mctx);
auto inp_k = build_attn_inp_k_impl(ctx0, ubatch, hparams, cparams, mctx_cur->get_base());
auto inp_ik = build_attn_inp_ik_impl(ctx0, ubatch, hparams, cparams, mctx_cur->get_ik());
return std::make_pair(
(llm_graph_input_attn_k *) res->add_input(std::move(inp_k)),
(llm_graph_input_attn_ik *) res->add_input(std::move(inp_ik)));
}
// TODO: maybe separate the inner implementation into a separate function
// like with the non-sliding window equivalent
// once sliding-window hybrid caches are a thing.

View File

@ -21,6 +21,7 @@ struct llama_cparams;
struct llama_memory_context_i;
class llama_kv_cache_context;
class llama_ik_cache_context;
class llama_kv_cache_iswa_context;
class llama_memory_recurrent_context;
class llama_memory_hybrid_context;
@ -350,6 +351,39 @@ public:
const llama_kv_cache_context * mctx;
};
// V-less input for the indexer KV cache
// graph input for attention over the indexer key (IK) cache:
// holds the K write indices and the KQ mask; there is no V counterpart
class llm_graph_input_attn_ik : public llm_graph_input_i {
public:
llm_graph_input_attn_ik(
const llama_hparams & hparams,
const llama_cparams & cparams,
const llama_ik_cache_context * mctx) :
hparams(hparams),
cparams(cparams),
mctx(mctx) {
}
~llm_graph_input_attn_ik() = default;
// upload per-ubatch data into the input tensors
void set_input(const llama_ubatch * ubatch) override;
// check if the tensors built for a previous graph match the new params
bool can_reuse(const llm_graph_params & params) override;
ggml_tensor * get_k_idxs() const { return self_k_idxs; }
ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream]
// note: hparams/cparams are stored by value (copies) - presumably intentional,
// matching the other attn input classes in this header; TODO confirm
const llama_hparams hparams;
const llama_cparams cparams;
// non-owning; rebound by can_reuse() when the graph is reused
const llama_ik_cache_context * mctx;
};
class llm_graph_input_attn_kv_iswa : public llm_graph_input_i {
public:
llm_graph_input_attn_kv_iswa(
@ -914,6 +948,20 @@ struct llm_graph_context {
float kq_scale,
int il) const;
ggml_tensor * build_attn(
llm_graph_input_attn_k * inp,
ggml_tensor * wo,
ggml_tensor * wo_b,
ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
ggml_tensor * kq_b,
ggml_tensor * sinks, // [n_head_q]
ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
ggml_tensor * top_k, // [n_indexer_top_k, n_tokens]
float kq_scale,
int il) const;
llm_graph_input_attn_kv_iswa * build_attn_inp_kv_iswa() const;
// note: if k_cur or v_cur are not provided, they will not be stored in the memory
@ -945,6 +993,8 @@ struct llm_graph_context {
float kq_scale,
int il) const;
std::pair<llm_graph_input_attn_k *, llm_graph_input_attn_ik *> build_attn_inp_k_dsa() const;
//
// recurrent
//

1885
src/llama-ik-cache.cpp Normal file

File diff suppressed because it is too large Load Diff

306
src/llama-ik-cache.h Normal file
View File

@ -0,0 +1,306 @@
#pragma once
#include "llama-kv-cache.h"
#include "llama-batch.h"
#include "llama-graph.h"
#include "llama-kv-cells.h"
#include "llama-memory.h"
#include <unordered_map>
#include <vector>
struct llama_cparams;
struct llama_hparams;
struct llama_model;
struct llama_context;
//
// llama_ik_cache
//
// K-only KV cache for the lightning indexer keys (DSA models).
// Mirrors the llama_kv_cache interface/structure but stores no V tensors.
class llama_ik_cache : public llama_memory_i {
public:
// reuse the slot/stream bookkeeping types of the regular KV cache so that
// both caches of llama_kv_cache_dsa speak the same vocabulary
using stream_copy_info = llama_kv_cache::stream_copy_info;
using slot_info = llama_kv_cache::slot_info;
using slot_info_vec_t = std::vector<slot_info>;
// note: type_v/v_trans are accepted to keep the signature parallel to
// llama_kv_cache even though this cache stores keys only - TODO confirm
// they are ignored by the implementation
llama_ik_cache(
const llama_model & model,
ggml_type type_k,
ggml_type type_v,
bool v_trans,
bool offload,
bool unified,
uint32_t kv_size,
uint32_t n_seq_max,
uint32_t n_pad,
uint32_t n_swa,
llama_swa_type swa_type,
const layer_filter_cb & filter,
const layer_reuse_cb & reuse);
~llama_ik_cache() = default;
//
// llama_memory_i
//
llama_memory_context_ptr init_batch(
llama_batch_allocr & balloc,
uint32_t n_ubatch,
bool embd_all) override;
llama_memory_context_ptr init_full() override;
llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override;
bool get_can_shift() const override;
void clear(bool data) override;
bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
void seq_keep(llama_seq_id seq_id) override;
void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override;
void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
llama_pos seq_pos_min(llama_seq_id seq_id) const override;
llama_pos seq_pos_max(llama_seq_id seq_id) const override;
std::map<ggml_backend_buffer_type_t, size_t> memory_breakdown() const override;
// state write/load
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override;
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override;
//
// llama_ik_cache specific API
//
uint32_t get_size() const;
uint32_t get_n_stream() const;
bool get_has_shift() const;
//
// graph_build API
//
uint32_t get_n_kv(const slot_info & sinfo) const;
// get views of the current state of the cache
ggml_tensor * get_k(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const;
// store k_cur and v_cur in the cache based on the provided head location
ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const;
//
// preparation API
//
// find places for the provided ubatches in the cache, returns the slot infos
// return empty vector on failure
slot_info_vec_t prepare(const std::vector<llama_ubatch> & ubatches);
bool update(llama_context * lctx, bool do_shift, const stream_copy_info & sc_info);
// find a slot of kv cells that can hold the ubatch
// if cont == true, then the slot must be continuous
// return empty slot_info on failure
slot_info find_slot(const llama_ubatch & ubatch, bool cont) const;
// emplace the ubatch context into slot: [sinfo.idxs[0...ubatch.n_tokens - 1]]
void apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch);
//
// input API
//
ggml_tensor * build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
void set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const;
void set_input_k_shift(ggml_tensor * dst) const;
void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
private:
const llama_model & model;
const llama_hparams & hparams;
// per-layer indexer key tensor (K only - no V member, unlike llama_kv_cache)
struct kv_layer {
// layer index in the model
// note: can be different from the layer index in the KV cache
uint32_t il;
ggml_tensor * k;
std::vector<ggml_tensor *> k_stream;
};
// NOTE(review): carried over from llama_kv_cache, but this cache has no V
// tensors - presumably unused here; TODO confirm
bool v_trans = true; // the value tensor is transposed
const uint32_t n_seq_max = 1;
const uint32_t n_stream = 1;
// required padding
const uint32_t n_pad = 1;
// SWA
const uint32_t n_swa = 0;
// env: LLAMA_KV_CACHE_DEBUG
int debug = 0;
// this is the SWA type of the cache - not to be confused with the model SWA type
const llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;
// ggml contexts for the KV cache along with the allocated backend buffers:
std::vector<std::pair<ggml_context_ptr, ggml_backend_buffer_ptr>> ctxs_bufs;
// the current index from where we start searching for a free slot in the ring buffer of KV cells (see find_slot())
// note: this is not part of the KV state and it's only used to speed-up the find_slot() method
std::vector<uint32_t> v_heads;
std::vector<llama_kv_cells> v_cells;
// maps from a sequence id to a stream id
std::vector<uint32_t> seq_to_stream;
// pending stream copies that will be applied during the next update
stream_copy_info sc_info;
std::vector<kv_layer> layers;
// model layer id -> KV cache layer id
std::unordered_map<int32_t, int32_t> map_layer_ids;
size_t total_size() const;
size_t size_k_bytes() const;
ggml_tensor * build_rope_shift(
const llama_cparams & cparams,
ggml_context * ctx,
ggml_tensor * cur,
ggml_tensor * shift,
ggml_tensor * factors,
float freq_base,
float freq_scale,
uint32_t il) const;
ggml_cgraph * build_graph_shift(
llm_graph_result * res,
llama_context * lctx) const;
struct cell_ranges_t {
uint32_t strm;
std::vector<std::pair<uint32_t, uint32_t>> data; // ranges, from inclusive, to exclusive
};
void state_write_meta(llama_io_write_i & io, const cell_ranges_t & cr, llama_seq_id seq_id = -1) const;
void state_write_data(llama_io_write_i & io, const cell_ranges_t & cr) const;
bool state_read_meta(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, slot_info & sinfo, llama_seq_id dest_seq_id = -1);
bool state_read_data(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, const slot_info & sinfo);
};
// memory context over a llama_ik_cache - same roles as llama_kv_cache_context
// (error / full-cache / update / batch-processing), but K-only
class llama_ik_cache_context : public llama_memory_context_i {
public:
// some shorthands
using slot_info_vec_t = llama_kv_cache::slot_info_vec_t;
using stream_copy_info = llama_kv_cache::stream_copy_info;
// used for errors
llama_ik_cache_context(llama_memory_status status);
// used to create a full-cache context
llama_ik_cache_context(
llama_ik_cache * kv);
// used to create an update context
llama_ik_cache_context(
llama_ik_cache * kv,
llama_context * lctx,
bool do_shift,
stream_copy_info sc_info);
// used to create a batch processing context from a batch
llama_ik_cache_context(
llama_ik_cache * kv,
slot_info_vec_t sinfos,
std::vector<llama_ubatch> ubatches);
virtual ~llama_ik_cache_context();
//
// llama_memory_context_i
//
bool next() override;
bool apply() override;
llama_memory_status get_status() const override;
const llama_ubatch & get_ubatch() const override;
//
// llama_ik_cache_context specific API
//
uint32_t get_n_kv() const;
// get views of the current state of the cache
ggml_tensor * get_k(ggml_context * ctx, int32_t il) const;
// store k_cur and v_cur in the cache based on the provided head location
// note: the heads in k_cur and v_cur should be layed out contiguously in memory
// - k_cur [n_embd_head_k, n_head_k, n_tokens]
// - k_idxs [n_tokens]
ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il) const;
// create destination indices for each head of the current batch for where it would be written in the KV cache
// the indices address the global KV cache (not per stream) - this is not relevant for the user of this API, but
// helps understand the implementation logic of cpy_k
ggml_tensor * build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
void set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const;
void set_input_k_shift (ggml_tensor * dst) const;
void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
private:
llama_memory_status status;
// non-owning back-pointers
llama_ik_cache * kv;
llama_context * lctx;
//
// update context
//
bool do_shift = false;
stream_copy_info sc_info;
//
// batch processing context
//
// the index of the cur ubatch to process
size_t i_cur = 0;
slot_info_vec_t sinfos;
std::vector<llama_ubatch> ubatches;
//
// data needed for building the compute graph for the current ubatch:
//
// a heuristic, to avoid attending the full cache if it is not yet utilized
// as the cache gets filled, the benefit from this heuristic disappears
// NOTE(review): no in-class initializer - presumably set by every ctor in the
// .cpp (not visible here); TODO confirm
int32_t n_kv;
};

251
src/llama-kv-cache-dsa.cpp Normal file
View File

@ -0,0 +1,251 @@
#include "llama-kv-cache-dsa.h"
#include "llama-impl.h"
#include "llama-batch.h"
#include "llama-model.h"
#include <algorithm>
#include <cassert>
//
// llama_kv_cache_dsa
//
// composite DSA cache: creates the main KV cache (kv_base) and the lightning
// indexer key cache (kv_ik) with identical geometry (same kv_size, streams,
// padding) so their cells can stay in correspondence
llama_kv_cache_dsa::llama_kv_cache_dsa(
const llama_model & model,
ggml_type type_k,
ggml_type type_v,
bool v_trans,
bool offload,
bool unified,
uint32_t kv_size,
uint32_t n_seq_max,
uint32_t n_pad,
uint32_t n_swa,
llama_swa_type swa_type,
const layer_filter_cb & filter,
const layer_reuse_cb & reuse) :
n_stream(unified ? 1 : n_seq_max) {
LLAMA_LOG_INFO("%s: creating main KV cache, size = %u cells\n", __func__, kv_size);
kv_base = std::make_unique<llama_kv_cache>(
model, type_k, type_v,
v_trans, offload, unified, kv_size, n_seq_max, n_pad,
n_swa, swa_type, filter, reuse);
LLAMA_LOG_INFO("%s: creating indexer KV cache, size = %u cells\n", __func__, kv_size);
// note: the indexer cache receives the same type_v/v_trans arguments even
// though it stores keys only - presumably ignored there; TODO confirm
kv_ik = std::make_unique<llama_ik_cache>(
model, type_k, type_v,
v_trans, offload, unified, kv_size, n_seq_max, n_pad,
n_swa, swa_type, filter, reuse);
}
// clears both child caches so they remain in sync
void llama_kv_cache_dsa::clear(bool data) {
    kv_base->clear(data);
    kv_ik->clear(data);
}
// removes the sequence range from both child caches
// note: bitwise & (not &&) on purpose - the removal must be attempted on the
// indexer cache even if it fails on the base cache, keeping the two in sync
bool llama_kv_cache_dsa::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
bool res = true;
res = res & kv_base->seq_rm(seq_id, p0, p1);
res = res & kv_ik ->seq_rm(seq_id, p0, p1);
return res;
}
// mirrors the sequence copy in both child caches
void llama_kv_cache_dsa::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
kv_base->seq_cp(seq_id_src, seq_id_dst, p0, p1);
kv_ik ->seq_cp(seq_id_src, seq_id_dst, p0, p1);
}
// keeps only the given sequence in both child caches
void llama_kv_cache_dsa::seq_keep(llama_seq_id seq_id) {
kv_base->seq_keep(seq_id);
kv_ik ->seq_keep(seq_id);
}
// shifts the positions of the given sequence range in both child caches
void llama_kv_cache_dsa::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
kv_base->seq_add(seq_id, p0, p1, shift);
kv_ik ->seq_add(seq_id, p0, p1, shift);
}
// divides the positions of the given sequence range in both child caches
void llama_kv_cache_dsa::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
kv_base->seq_div(seq_id, p0, p1, d);
kv_ik ->seq_div(seq_id, p0, p1, d);
}
// the base cache is treated as authoritative for positions - every sequence
// mutation above is mirrored into both caches, so querying one suffices
llama_pos llama_kv_cache_dsa::seq_pos_min(llama_seq_id seq_id) const {
return kv_base->seq_pos_min(seq_id);
}
// see seq_pos_min - the base cache is authoritative for positions
llama_pos llama_kv_cache_dsa::seq_pos_max(llama_seq_id seq_id) const {
return kv_base->seq_pos_max(seq_id);
}
// combined memory breakdown: the base cache's usage plus, per buffer type,
// the usage of the indexer key cache
std::map<ggml_backend_buffer_type_t, size_t> llama_kv_cache_dsa::memory_breakdown() const {
    auto mb = kv_base->memory_breakdown();

    for (const auto & [buft, size] : kv_ik->memory_breakdown()) {
        mb[buft] += size;
    }

    return mb;
}
// splits the batch into ubatches and reserves slots for them in BOTH child
// caches; returns a failure context if either cache cannot host the split
llama_memory_context_ptr llama_kv_cache_dsa::init_batch(
llama_batch_allocr & balloc,
uint32_t n_ubatch,
bool embd_all) {
GGML_UNUSED(embd_all);
// single-iteration do/while: a breakable scope for the error path below
do {
balloc.split_reset();
// simple split for the unified single-stream case, equal split across
// sequences otherwise
std::vector<llama_ubatch> ubatches;
while (true) {
auto ubatch = n_stream == 1 ? balloc.split_simple(n_ubatch) : balloc.split_equal(n_ubatch, true);
if (ubatch.n_tokens == 0) {
break;
}
ubatches.push_back(std::move(ubatch)); // NOLINT
}
if (balloc.get_n_used() < balloc.get_n_tokens()) {
// failed to find a suitable split
break;
}
// both caches must accept the same set of ubatches
auto sinfos_base = kv_base->prepare(ubatches);
if (sinfos_base.empty()) {
break;
}
auto sinfos_ik = kv_ik->prepare(ubatches);
if (sinfos_ik.empty()) {
break;
}
assert(sinfos_base.size() == sinfos_ik.size());
return std::make_unique<llama_kv_cache_dsa_context>(
this, std::move(sinfos_base), std::move(sinfos_ik), std::move(ubatches));
} while (false);
// any failure above falls through to here
return std::make_unique<llama_kv_cache_dsa_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
}
// creates a context that spans the full state of both child caches
llama_memory_context_ptr llama_kv_cache_dsa::init_full() {
    return std::make_unique<llama_kv_cache_dsa_context>(this);
}
// creates an update (shift/defrag) context covering both child caches
llama_memory_context_ptr llama_kv_cache_dsa::init_update(llama_context * lctx, bool optimize) {
    return std::make_unique<llama_kv_cache_dsa_context>(this, lctx, optimize);
}
// shifting is only possible when both child caches support it AND have the
// same size, so that a shift applies to corresponding cells in each
bool llama_kv_cache_dsa::get_can_shift() const {
return kv_base->get_can_shift() &&
kv_ik->get_can_shift() &&
kv_base->get_size() == kv_ik->get_size();
}
// serializes both child caches; the order (base first, then ik) must match state_read
void llama_kv_cache_dsa::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
kv_base->state_write(io, seq_id, flags);
kv_ik->state_write(io, seq_id, flags);
}
// deserializes both child caches in the same order state_write wrote them
void llama_kv_cache_dsa::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
kv_base->state_read(io, seq_id, flags);
kv_ik->state_read(io, seq_id, flags);
}
// non-owning accessor for the main KV cache
llama_kv_cache * llama_kv_cache_dsa::get_base() const {
return kv_base.get();
}
// non-owning accessor for the lightning indexer key cache
llama_ik_cache * llama_kv_cache_dsa::get_ik() const {
return kv_ik.get();
}
//
// llama_kv_cache_dsa_context
//
// error-path constructor: records the failure status; no sub-contexts are created
llama_kv_cache_dsa_context::llama_kv_cache_dsa_context(llama_memory_status status) : status(status) {}
// full-cache context: one full sub-context per child cache; the combined
// status reflects both
llama_kv_cache_dsa_context::llama_kv_cache_dsa_context(
llama_kv_cache_dsa * kv) :
ctx_base(kv->get_base()->init_full()),
ctx_ik(kv->get_ik()->init_full()),
status(llama_memory_status_combine(ctx_base->get_status(), ctx_ik->get_status())) {
}
// update context: one update sub-context per child cache; the combined
// status reflects both
llama_kv_cache_dsa_context::llama_kv_cache_dsa_context(
llama_kv_cache_dsa * kv,
llama_context * lctx,
bool optimize) :
ctx_base(kv->get_base()->init_update(lctx, optimize)),
ctx_ik(kv->get_ik()->init_update(lctx, optimize)),
status(llama_memory_status_combine(ctx_base->get_status(), ctx_ik->get_status())) {
}
// batch-processing context: one sub-context per child cache, both referring to
// the same (shared) list of ubatches
// note: this->ubatches must be initialized before the sub-contexts reference
// it - the member declaration order in the header guarantees this
llama_kv_cache_dsa_context::llama_kv_cache_dsa_context(
llama_kv_cache_dsa * kv,
slot_info_vec_t sinfos_base,
slot_info_vec_t sinfos_ik,
std::vector<llama_ubatch> ubatches) :
ubatches(std::move(ubatches)),
// note: here we copy the ubatches. not sure if this is ideal
ctx_base(new llama_kv_cache_context(kv->get_base(), std::move(sinfos_base), this->ubatches)),
ctx_ik(new llama_ik_cache_context(kv->get_ik(), std::move(sinfos_ik), this->ubatches)),
status(llama_memory_status_combine(ctx_base->get_status(), ctx_ik->get_status())) {
}
// defaulted out-of-line destructor
llama_kv_cache_dsa_context:: ~llama_kv_cache_dsa_context() = default;
// advances both sub-contexts in lockstep to the next ubatch
// note: the sub-contexts' return values are ignored - the shared i_next
// counter against the shared ubatch list is the single source of truth
bool llama_kv_cache_dsa_context::next() {
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
ctx_base->next();
ctx_ik ->next();
if (++i_next >= ubatches.size()) {
return false;
}
return true;
}
// applies the current ubatch to both sub-contexts
// note: bitwise & (not &&) on purpose - both applies must run even if the
// first one fails, to keep the two caches in sync
bool llama_kv_cache_dsa_context::apply() {
assert(!llama_memory_status_is_fail(status));
bool res = true;
res = res & ctx_base->apply();
res = res & ctx_ik ->apply();
return res;
}
// combined status of both sub-contexts (set once at construction)
llama_memory_status llama_kv_cache_dsa_context::get_status() const {
return status;
}
// returns the current ubatch (indexed by the shared i_next counter)
const llama_ubatch & llama_kv_cache_dsa_context::get_ubatch() const {
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
return ubatches[i_next];
}
// sub-context of the main KV cache; the static_cast is safe because ctx_base
// is always constructed from kv->get_base() (a llama_kv_cache)
const llama_kv_cache_context * llama_kv_cache_dsa_context::get_base() const {
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
return static_cast<const llama_kv_cache_context *>(ctx_base.get());
}
// sub-context of the indexer key cache; the static_cast is safe because
// ctx_ik is always constructed from kv->get_ik() (a llama_ik_cache)
const llama_ik_cache_context * llama_kv_cache_dsa_context::get_ik() const {
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
return static_cast<const llama_ik_cache_context *>(ctx_ik.get());
}

137
src/llama-kv-cache-dsa.h Normal file
View File

@ -0,0 +1,137 @@
#pragma once
#include "llama-kv-cache.h"
#include "llama-ik-cache.h"
#include <vector>
//
// llama_kv_cache_dsa
//
// utilizes two KV cache instances: llama_kv_cache and llama_ik_cache
// the first instance is for caching key tensors of the model,
// the second instance is for caching lightning indexer key tensors
class llama_kv_cache_dsa : public llama_memory_i {
public:
// the ctor parameters mirror llama_kv_cache's and are forwarded to both
// child caches, which are created with identical geometry
llama_kv_cache_dsa(
const llama_model & model,
ggml_type type_k,
ggml_type type_v,
bool v_trans,
bool offload,
bool unified,
uint32_t kv_size,
uint32_t n_seq_max,
uint32_t n_pad,
uint32_t n_swa,
llama_swa_type swa_type,
const layer_filter_cb & filter,
const layer_reuse_cb & reuse);
~llama_kv_cache_dsa() = default;
//
// llama_memory_i
//
// every mutating operation below is applied to both child caches so that
// their cell layouts stay in sync
llama_memory_context_ptr init_batch(
llama_batch_allocr & balloc,
uint32_t n_ubatch,
bool embd_all) override;
llama_memory_context_ptr init_full() override;
llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override;
bool get_can_shift() const override;
void clear(bool data) override;
bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
void seq_keep(llama_seq_id seq_id) override;
void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override;
void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
llama_pos seq_pos_min(llama_seq_id seq_id) const override;
llama_pos seq_pos_max(llama_seq_id seq_id) const override;
std::map<ggml_backend_buffer_type_t, size_t> memory_breakdown() const override;
// state write/load
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override;
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override;
//
// llama_kv_cache_dsa specific API
//
// non-owning accessors for the two child caches
llama_kv_cache * get_base() const;
llama_ik_cache * get_ik () const;
private:
// 1 when the cache is unified, n_seq_max otherwise (set in the ctor)
const uint32_t n_stream = 1;
// main KV cache (model keys/values)
std::unique_ptr<llama_kv_cache> kv_base;
// lightning indexer key cache
std::unique_ptr<llama_ik_cache> kv_ik;
};
// memory context over a llama_kv_cache_dsa: pairs one sub-context per child
// cache and keeps them advancing in lockstep over a shared ubatch list
class llama_kv_cache_dsa_context : public llama_memory_context_i {
public:
using slot_info_vec_t = llama_kv_cache::slot_info_vec_t;
// used for errors
llama_kv_cache_dsa_context(llama_memory_status status);
// used to create a full-cache context
llama_kv_cache_dsa_context(
llama_kv_cache_dsa * kv);
// used to create an update context
llama_kv_cache_dsa_context(
llama_kv_cache_dsa * kv,
llama_context * lctx,
bool optimize);
// used to create a batch processing context from a batch
llama_kv_cache_dsa_context(
llama_kv_cache_dsa * kv,
slot_info_vec_t sinfos_base,
slot_info_vec_t sinfos_ik,
std::vector<llama_ubatch> ubatches);
virtual ~llama_kv_cache_dsa_context();
//
// llama_memory_context_i
//
bool next() override;
bool apply() override;
llama_memory_status get_status() const override;
const llama_ubatch & get_ubatch() const override;
//
// llama_kv_cache_dsa_context specific API
//
// typed accessors for the two sub-contexts
const llama_kv_cache_context * get_base() const;
const llama_ik_cache_context * get_ik() const;
private:
//llama_kv_cache_dsa * kv;
// the index of the next ubatch to process
size_t i_next = 0;
// shared by both sub-contexts below; must be declared before them so it is
// initialized first (see the batch-processing ctor)
std::vector<llama_ubatch> ubatches;
const llama_memory_context_ptr ctx_base;
const llama_memory_context_ptr ctx_ik;
const llama_memory_status status;
};

View File

@ -51,7 +51,7 @@ llama_kv_cache::llama_kv_cache(
auto it = ctx_map.find(buft);
if (it == ctx_map.end()) {
ggml_init_params params = {
/*.mem_size =*/ size_t(3u*(1 + n_stream)*n_layer_kv*ggml_tensor_overhead()),
/*.mem_size =*/ size_t(2u*(1 + n_stream)*n_layer_kv*ggml_tensor_overhead()),
/*.mem_buffer =*/ NULL,
/*.no_alloc =*/ true,
};
@ -113,7 +113,6 @@ llama_kv_cache::llama_kv_cache(
// [TAG_V_CACHE_VARIABLE]
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
const uint32_t n_embd_v_gqa = !v_trans ? hparams.n_embd_v_gqa(il) : hparams.n_embd_v_gqa_max();
const uint32_t n_embd_indexer_head = hparams.indexer_head_size;
const char * dev_name = "CPU";
@ -135,29 +134,24 @@ llama_kv_cache::llama_kv_cache(
const bool has_k = true;
const bool has_v = !is_mla;
const bool has_ik = hparams.indexer_top_k > 0;
ggml_tensor * k = has_k ? ggml_new_tensor_3d(ctx, type_k, n_embd_k_gqa, kv_size, n_stream) : nullptr;
ggml_tensor * v = has_v ? ggml_new_tensor_3d(ctx, type_v, n_embd_v_gqa, kv_size, n_stream) : nullptr;
ggml_tensor * ik = has_ik ? ggml_new_tensor_3d(ctx, type_k, n_embd_indexer_head, kv_size, n_stream) : nullptr;
has_k && ggml_format_name(k, "cache_k_l%d", il);
has_v && ggml_format_name(v, "cache_v_l%d", il);
has_ik && ggml_format_name(ik, "cache_ik_l%d", il);
std::vector<ggml_tensor *> k_stream;
std::vector<ggml_tensor *> v_stream;
std::vector<ggml_tensor *> ik_stream;
for (uint32_t s = 0; s < n_stream; ++s) {
k_stream.push_back(has_k ? ggml_view_2d(ctx, k, n_embd_k_gqa, kv_size, k->nb[1], s*k->nb[2]) : nullptr);
v_stream.push_back(has_v ? ggml_view_2d(ctx, v, n_embd_v_gqa, kv_size, v->nb[1], s*v->nb[2]) : nullptr);
ik_stream.push_back(has_ik ? ggml_view_2d(ctx, ik, n_embd_indexer_head, kv_size, ik->nb[1], s*ik->nb[2]) : nullptr);
}
map_layer_ids[il] = layers.size();
layers.push_back({ il, k, v, ik, k_stream, v_stream, ik_stream });
layers.push_back({ il, k, v, k_stream, v_stream, });
}
if (reuse) {
@ -208,13 +202,11 @@ llama_kv_cache::llama_kv_cache(
{
const size_t memory_size_k = size_k_bytes();
const size_t memory_size_v = size_v_bytes();
const size_t memory_size_ik = size_ik_bytes();
LLAMA_LOG_INFO("%s: size = %7.2f MiB (%6u cells, %3d layers, %2u/%u seqs), K (%s): %7.2f MiB, V (%s): %7.2f MiB, IK (%s): %7.2f MiB\n", __func__,
LLAMA_LOG_INFO("%s: size = %7.2f MiB (%6u cells, %3d layers, %2u/%u seqs), K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
(float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f), kv_size, (int) layers.size(), n_seq_max, n_stream,
ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f),
ggml_type_name(type_k), (float)memory_size_ik / (1024.0f * 1024.0f));
ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
}
const char * LLAMA_KV_CACHE_DEBUG = getenv("LLAMA_KV_CACHE_DEBUG");
@ -664,10 +656,6 @@ bool llama_kv_cache::update(llama_context * lctx, bool do_shift, const stream_co
if (layer.v_stream[ssrc]) {
ggml_backend_tensor_copy(layer.v_stream[ssrc], layer.v_stream[sdst]);
}
if (layer.ik_stream[ssrc]) {
ggml_backend_tensor_copy(layer.ik_stream[ssrc], layer.ik_stream[sdst]);
}
}
}
}
@ -1084,26 +1072,6 @@ ggml_tensor * llama_kv_cache::get_v(ggml_context * ctx, int32_t il, uint32_t n_k
ggml_row_size(v->type, kv_size*n_embd_v_gqa)*sinfo.s0);
}
ggml_tensor * llama_kv_cache::get_ik(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const {
const int32_t ikv = map_layer_ids.at(il);
auto * ik = layers[ikv].ik;
const uint64_t kv_size = get_size();
const uint64_t n_embd_indexer_head = ik->ne[0];
assert(n_embd_indexer_head == hparams.indexer_head_size);
const uint32_t ns = sinfo.s1 - sinfo.s0 + 1;
return ggml_view_4d(ctx, ik,
n_embd_indexer_head, 1, n_kv, ns,
ggml_row_size(ik->type, n_embd_indexer_head),
ggml_row_size(ik->type, n_embd_indexer_head),
ggml_row_size(ik->type, n_embd_indexer_head*kv_size),
ggml_row_size(ik->type, n_embd_indexer_head*kv_size)*sinfo.s0);
}
ggml_tensor * llama_kv_cache::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const {
GGML_UNUSED(sinfo);
@ -1195,41 +1163,6 @@ ggml_tensor * llama_kv_cache::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggm
return ggml_set_rows(ctx, v_view, v_cur, v_idxs);
}
// Build the graph op that scatters the current batch's indexer keys (ik_cur)
// into the ik cache of layer `il` at the destination rows given by k_idxs.
// Returns the ggml_set_rows node to be expanded into the forward graph.
ggml_tensor * llama_kv_cache::cpy_ik(ggml_context * ctx, ggml_tensor * ik_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const {
GGML_UNUSED(sinfo);
const int32_t ikv = map_layer_ids.at(il);
ggml_tensor * ik = layers[ikv].ik;
// incoming shape: [n_embd_indexer_head, n_head, n_tokens]
const int64_t n_embd_indexer_head = ik_cur->ne[0];
const int64_t n_head = ik_cur->ne[1];
const int64_t n_tokens = ik_cur->ne[2];
const int64_t n_embd_gqa = n_embd_indexer_head*n_head;
// we can merge dims 0 and 1
// TODO: add ggml helper function for this?
// merging is only valid when dims 0/1 are contiguous in memory
GGML_ASSERT(ggml_row_size(ik_cur->type, n_embd_indexer_head) == ik_cur->nb[1]);
ik_cur = ggml_view_2d(ctx, ik_cur, n_embd_gqa, n_tokens, ik_cur->nb[2], 0);
const int64_t n_stream = ik->ne[2];
if (n_stream > 1) {
const int64_t kv_size = get_size();
assert(n_embd_gqa == ik->ne[0]);
assert(kv_size == ik->ne[1]);
// merge the buffer across all streams because the idxs are global
ik = ggml_reshape_2d(ctx, ik, n_embd_gqa, kv_size*n_stream);
}
// store the current K values into the cache
return ggml_set_rows(ctx, ik, ik_cur, k_idxs);
}
ggml_tensor * llama_kv_cache::build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
const uint32_t n_tokens = ubatch.n_tokens;
@ -1604,16 +1537,6 @@ size_t llama_kv_cache::size_v_bytes() const {
return size_v_bytes;
}
// Total number of bytes occupied by the indexer-key (ik) tensors of all layers.
// Layers without an ik tensor contribute nothing.
size_t llama_kv_cache::size_ik_bytes() const {
size_t total = 0;
for (const auto & layer : layers) {
    if (layer.ik) {
        total += ggml_nbytes(layer.ik);
    }
}
return total;
}
ggml_tensor * llama_kv_cache::build_rope_shift(
const llama_cparams & cparams,
ggml_context * ctx,
@ -2319,10 +2242,6 @@ ggml_tensor * llama_kv_cache_context::get_v(ggml_context * ctx, int32_t il) cons
return kv->get_v(ctx, il, n_kv, sinfos[i_cur]);
}
// Forward to the underlying cache using this context's current slot selection.
ggml_tensor * llama_kv_cache_context::get_ik(ggml_context * ctx, int32_t il) const {
return kv->get_ik(ctx, il, n_kv, sinfos[i_cur]);
}
// Forward to the underlying cache using this context's current slot selection.
ggml_tensor * llama_kv_cache_context::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il) const {
return kv->cpy_k(ctx, k_cur, k_idxs, il, sinfos[i_cur]);
}
@ -2331,10 +2250,6 @@ ggml_tensor * llama_kv_cache_context::cpy_v(ggml_context * ctx, ggml_tensor * v_
return kv->cpy_v(ctx, v_cur, v_idxs, il, sinfos[i_cur]);
}
// Forward to the underlying cache using this context's current slot selection.
ggml_tensor * llama_kv_cache_context::cpy_ik(ggml_context * ctx, ggml_tensor * ik_cur, ggml_tensor * k_idxs, int32_t il) const {
return kv->cpy_ik(ctx, ik_cur, k_idxs, il, sinfos[i_cur]);
}
// Forward to the underlying cache; the destination indices are slot-independent.
ggml_tensor * llama_kv_cache_context::build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
return kv->build_input_k_idxs(ctx, ubatch);
}

View File

@ -161,12 +161,10 @@ public:
// get views of the current state of the cache
ggml_tensor * get_k(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const;
ggml_tensor * get_v(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const;
ggml_tensor * get_ik(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const;
// store k_cur and v_cur in the cache based on the provided head location
ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const;
ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il, const slot_info & sinfo) const;
ggml_tensor * cpy_ik(ggml_context * ctx, ggml_tensor * ik_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const;
//
// preparation API
@ -212,11 +210,9 @@ private:
ggml_tensor * k;
ggml_tensor * v;
ggml_tensor * ik;
std::vector<ggml_tensor *> k_stream;
std::vector<ggml_tensor *> v_stream;
std::vector<ggml_tensor *> ik_stream;
};
bool v_trans = true; // the value tensor is transposed
@ -260,7 +256,6 @@ private:
size_t size_k_bytes() const;
size_t size_v_bytes() const;
size_t size_ik_bytes() const;
ggml_tensor * build_rope_shift(
const llama_cparams & cparams,
@ -336,7 +331,6 @@ public:
// get views of the current state of the cache
ggml_tensor * get_k(ggml_context * ctx, int32_t il) const;
ggml_tensor * get_v(ggml_context * ctx, int32_t il) const;
ggml_tensor * get_ik(ggml_context * ctx, int32_t il) const;
// store k_cur and v_cur in the cache based on the provided head location
// note: the heads in k_cur and v_cur should be layed out contiguously in memory
@ -346,7 +340,6 @@ public:
// - v_idxs [n_tokens] or [n_tokens*n_embd_v_gqa] depending if V cache is transposed
ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il) const;
ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il) const;
ggml_tensor * cpy_ik(ggml_context * ctx, ggml_tensor * ik_cur, ggml_tensor * k_idxs, int32_t il) const;
// create destination indices for each head of the current batch for where it would be written in the KV cache
// the indices address the global KV cache (not per stream) - this is not relevant for the user of this API, but

View File

@ -8,6 +8,7 @@
#include "llama-kv-cache.h"
#include "llama-kv-cache-iswa.h"
#include "llama-kv-cache-dsa.h"
#include "llama-memory-hybrid.h"
#include "llama-memory-hybrid-iswa.h"
#include "llama-memory-recurrent.h"
@ -8111,6 +8112,23 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
{
res = nullptr;
} break;
case LLM_ARCH_DEEPSEEK32:
{
res = new llama_kv_cache_dsa(
*this,
params.type_k,
params.type_v,
!cparams.flash_attn,
cparams.offload_kqv,
cparams.kv_unified,
cparams.n_ctx_seq,
cparams.n_seq_max,
1,
hparams.n_swa,
hparams.swa_type,
nullptr,
nullptr);
} break;
// Models that need standard caching should rely on recurrent/hybrid
// checks
default:

View File

@ -1,14 +1,17 @@
#include "models.h"
#include "llama-kv-cache.h"
#include "llama-ik-cache.h"
llm_build_deepseek32::llm_build_deepseek32(const llama_model & model, const llm_graph_params & params) :
llm_graph_context(params) {
const bool is_mla = hparams.is_mla();
GGML_ASSERT(is_mla);
// note: these are the actual head sizes you get when treating as MHA or after "decompression" using wv_b for MLA
const int64_t n_embd_head_k = hparams.n_embd_head_k_mla();
const int64_t n_embd_head_v = hparams.n_embd_head_v_mla();
GGML_UNUSED(n_embd_head_v);
const int64_t n_embd_head_qk_rope = hparams.n_rot();
const int64_t n_embd_head_qk_nope = n_embd_head_k - n_embd_head_qk_rope;
@ -42,8 +45,9 @@ llm_build_deepseek32::llm_build_deepseek32(const llama_model & model, const llm_
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();
auto * inp_attn_kv = !is_mla ? build_attn_inp_kv() : nullptr;
auto * inp_attn_k = is_mla ? build_attn_inp_k() : nullptr;
std::pair<llm_graph_input_attn_k*, llm_graph_input_attn_ik*> inp_attn_dsa = build_attn_inp_k_dsa();
auto * inp_attn_k = inp_attn_dsa.first;
auto * inp_attn_ik = inp_attn_dsa.second;
ggml_tensor * inp_out_ids = build_inp_out_ids();
@ -63,9 +67,7 @@ llm_build_deepseek32::llm_build_deepseek32(const llama_model & model, const llm_
qr = build_norm(qr, model.layers[il].attn_q_a_norm, nullptr, LLM_NORM_RMS, il);
cb(qr, "qr", il);
ggml_tensor * kq_mask = is_mla ? inp_attn_k->get_kq_mask() : inp_attn_kv->get_kq_mask();
ggml_tensor * kq_mask_bak = ggml_dup(ctx0, kq_mask);
ggml_build_forward_expand(gf, kq_mask_bak);
ggml_tensor * top_k = nullptr;
// lightning indexer
{
@ -133,9 +135,9 @@ llm_build_deepseek32::llm_build_deepseek32(const llama_model & model, const llm_
cb(indexer_k, "indexer_k", il);
// store indexer keys to KV cache
const auto * mctx_cur = is_mla ? inp_attn_k->mctx : inp_attn_kv->mctx;
const auto & k_idxs = is_mla ? inp_attn_k->get_k_idxs() : inp_attn_kv->get_k_idxs();
ggml_build_forward_expand(gf, mctx_cur->cpy_ik(ctx0, indexer_k, k_idxs, il));
const auto * mctx_cur = inp_attn_ik->mctx;
const auto & k_idxs = inp_attn_ik->get_k_idxs();
ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, indexer_k, k_idxs, il));
// prepare indexer weights
ggml_tensor * indexer_weights = ggml_mul_mat(ctx0, model.layers[il].indexer_proj, cur);
@ -145,7 +147,7 @@ llm_build_deepseek32::llm_build_deepseek32(const llama_model & model, const llm_
cb(indexer_weights, "indexer_weights", il);
// get cached indexer keys
indexer_k = mctx_cur->get_ik(ctx0, il);
indexer_k = mctx_cur->get_k(ctx0, il);
// split the batch into streams if needed
const auto n_stream = indexer_k->ne[3];
@ -188,24 +190,14 @@ llm_build_deepseek32::llm_build_deepseek32(const llama_model & model, const llm_
cb(indexer_score, "indexer_score", il);
// mask indexer scores
ggml_tensor * kq_mask_f32 = ggml_cast(ctx0, kq_mask, GGML_TYPE_F32);
indexer_score = ggml_add(ctx0, indexer_score, kq_mask_f32);
ggml_tensor * indexer_kq_mask = inp_attn_ik->get_kq_mask();
indexer_score = ggml_add(ctx0, indexer_score, indexer_kq_mask);
cb(indexer_score, "indexer_score", il);
// get indices of top k indexer scores
uint32_t n_top_k = indexer_score->ne[0] < n_indexer_top_k ? indexer_score->ne[0] : n_indexer_top_k;
ggml_tensor * top_k = ggml_cont(ctx0, ggml_top_k(ctx0, indexer_score, n_top_k));
top_k = ggml_cont(ctx0, ggml_top_k(ctx0, indexer_score, n_top_k));
cb(top_k, "top_k", il);
// prepare new kq mask - starts filled with -INFINITY
ggml_tensor * kq_mask_all = ggml_fill(ctx0, kq_mask_f32, -INFINITY);
cb(kq_mask_all, "kq_mask_all", il);
// modify it by unmasking tokens that are in top_k indices
ggml_tensor * kq_mask_top_k = ggml_where_id(ctx0, kq_mask_f32, kq_mask_all, top_k);
cb(kq_mask_top_k, "kq_mask_top_k", il);
ggml_build_forward_expand(gf, ggml_cpy(ctx0, ggml_cast(ctx0, kq_mask_top_k, kq_mask->type), kq_mask));
}
ggml_tensor * q = ggml_mul_mat(ctx0, model.layers[il].wq_b, qr);
@ -250,7 +242,8 @@ llm_build_deepseek32::llm_build_deepseek32(const llama_model & model, const llm_
kv_cmpr = build_norm(kv_cmpr, model.layers[il].attn_kv_a_norm, nullptr, LLM_NORM_RMS, il);
cb(kv_cmpr, "kv_cmpr", il);
if (is_mla) {
// MLA attention
{
// {n_embd_head_qk_nope, n_tokens, n_head}
q_nope = ggml_permute(ctx0, q_nope, 0, 2, 1, 3);
cb(q_nope, "q_nope_perm", il);
@ -282,41 +275,8 @@ llm_build_deepseek32::llm_build_deepseek32(const llama_model & model, const llm_
// note: MLA with the absorption optimization converts into MQA (ie: GQA with 1 group)
cur = build_attn(inp_attn_k,
model.layers[il].wo, NULL,
Qcur, Kcur, Vcur, nullptr, nullptr, model.layers[il].wv_b, kq_scale, il);
} else {
ggml_tensor * kv = ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_cmpr);
cb(kv, "kv", il);
// split into {n_embd_head_qk_nope, n_head, n_tokens}
ggml_tensor * k_nope =
ggml_view_3d(ctx0, kv, n_embd_head_qk_nope, n_head, n_tokens,
ggml_row_size(kv->type, n_embd_head_qk_nope + n_embd_head_v),
ggml_row_size(kv->type, n_embd_head_qk_nope + n_embd_head_v) * n_head, 0);
cb(k_nope, "k_nope_view", il);
// and {n_embd_head_v, n_head, n_tokens}
ggml_tensor * Vcur = ggml_view_3d(ctx0, kv, n_embd_head_v, n_head, n_tokens,
ggml_row_size(kv->type, n_embd_head_qk_nope + n_embd_head_v),
ggml_row_size(kv->type, n_embd_head_qk_nope + n_embd_head_v) * n_head,
ggml_row_size(kv->type, n_embd_head_qk_nope));
cb(Vcur, "Vcur_view", il);
Vcur = ggml_cont(ctx0, Vcur);
cb(Vcur, "Vcur_cont", il);
ggml_tensor * Qcur = ggml_concat(ctx0, q_nope, q_pe, 0);
cb(Qcur, "Qcur", il);
ggml_tensor * Kcur = ggml_concat(ctx0, k_nope, ggml_repeat(ctx0, k_pe, q_pe), 0);
cb(Kcur, "Kcur", il);
// note: MLA without the absorption optimization converts into MHA (ie: GQA with full n_head groups)
cur = build_attn(inp_attn_kv,
model.layers[il].wo, NULL,
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
Qcur, Kcur, Vcur, nullptr, nullptr, model.layers[il].wv_b, top_k, kq_scale, il);
}
ggml_build_forward_expand(gf, ggml_cpy(ctx0, kq_mask_bak, kq_mask));
}
if (il == effective_n_layers - 1 && inp_out_ids) {
cur = ggml_get_rows(ctx0, cur, inp_out_ids);