memory : add llama_memory_hybrid_iswa (#18601)
* memory : add llama_memory_hybrid_iswa * Update src/llama-memory-hybrid-iswa.cpp Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
parent
12a4a47e6a
commit
ad8d85bd94
|
|
@ -24,6 +24,7 @@ add_library(llama
|
||||||
llama-kv-cache-iswa.cpp
|
llama-kv-cache-iswa.cpp
|
||||||
llama-memory.cpp
|
llama-memory.cpp
|
||||||
llama-memory-hybrid.cpp
|
llama-memory-hybrid.cpp
|
||||||
|
llama-memory-hybrid-iswa.cpp
|
||||||
llama-memory-recurrent.cpp
|
llama-memory-recurrent.cpp
|
||||||
llama-mmap.cpp
|
llama-mmap.cpp
|
||||||
llama-model-loader.cpp
|
llama-model-loader.cpp
|
||||||
|
|
|
||||||
|
|
@ -7,6 +7,7 @@
|
||||||
#include "llama-kv-cache.h"
|
#include "llama-kv-cache.h"
|
||||||
#include "llama-kv-cache-iswa.h"
|
#include "llama-kv-cache-iswa.h"
|
||||||
#include "llama-memory-hybrid.h"
|
#include "llama-memory-hybrid.h"
|
||||||
|
#include "llama-memory-hybrid-iswa.h"
|
||||||
#include "llama-memory-recurrent.h"
|
#include "llama-memory-recurrent.h"
|
||||||
|
|
||||||
#include <cassert>
|
#include <cassert>
|
||||||
|
|
@ -510,6 +511,76 @@ bool llm_graph_input_mem_hybrid::can_reuse(const llm_graph_params & params) {
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void llm_graph_input_mem_hybrid_iswa::set_input(const llama_ubatch * ubatch) {
|
||||||
|
const auto * attn_ctx = mctx->get_attn();
|
||||||
|
|
||||||
|
// base tensors may not be allocated if there are no non-SWA attention layers
|
||||||
|
if (inp_attn->self_k_idxs && inp_attn->self_k_idxs->buffer) {
|
||||||
|
attn_ctx->get_base()->set_input_k_idxs(inp_attn->self_k_idxs, ubatch);
|
||||||
|
attn_ctx->get_base()->set_input_v_idxs(inp_attn->self_v_idxs, ubatch);
|
||||||
|
|
||||||
|
attn_ctx->get_base()->set_input_kq_mask(inp_attn->self_kq_mask, ubatch, cparams.causal_attn);
|
||||||
|
}
|
||||||
|
|
||||||
|
// swa tensors may not be allocated if there are no SWA attention layers
|
||||||
|
if (inp_attn->self_k_idxs_swa && inp_attn->self_k_idxs_swa->buffer) {
|
||||||
|
attn_ctx->get_swa()->set_input_k_idxs(inp_attn->self_k_idxs_swa, ubatch);
|
||||||
|
attn_ctx->get_swa()->set_input_v_idxs(inp_attn->self_v_idxs_swa, ubatch);
|
||||||
|
|
||||||
|
attn_ctx->get_swa()->set_input_kq_mask(inp_attn->self_kq_mask_swa, ubatch, cparams.causal_attn);
|
||||||
|
}
|
||||||
|
|
||||||
|
const int64_t n_rs = mctx->get_recr()->get_n_rs();
|
||||||
|
|
||||||
|
if (inp_rs->s_copy) {
|
||||||
|
GGML_ASSERT(ggml_backend_buffer_is_host(inp_rs->s_copy->buffer));
|
||||||
|
int32_t * data = (int32_t *) inp_rs->s_copy->data;
|
||||||
|
|
||||||
|
// assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
|
||||||
|
for (uint32_t i = 0; i < n_rs; ++i) {
|
||||||
|
data[i] = mctx->get_recr()->s_copy(i);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
bool llm_graph_input_mem_hybrid_iswa::can_reuse(const llm_graph_params & params) {
|
||||||
|
const auto * mctx = static_cast<const llama_memory_hybrid_iswa_context *>(params.mctx);
|
||||||
|
|
||||||
|
this->mctx = mctx;
|
||||||
|
|
||||||
|
bool res = true;
|
||||||
|
|
||||||
|
const auto * attn_ctx = mctx->get_attn();
|
||||||
|
|
||||||
|
// base tensors may not be allocated if there are no non-SWA attention layers
|
||||||
|
if (inp_attn->self_k_idxs && inp_attn->self_k_idxs->buffer) {
|
||||||
|
res &= inp_attn->self_k_idxs->ne[0] == params.ubatch.n_tokens;
|
||||||
|
//res &= inp_attn->self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
|
||||||
|
|
||||||
|
res &= inp_attn->self_kq_mask->ne[0] == attn_ctx->get_base()->get_n_kv();
|
||||||
|
res &= inp_attn->self_kq_mask->ne[1] == params.ubatch.n_tokens;
|
||||||
|
}
|
||||||
|
|
||||||
|
// swa tensors may not be allocated if there are no SWA attention layers
|
||||||
|
if (inp_attn->self_k_idxs_swa && inp_attn->self_k_idxs_swa->buffer) {
|
||||||
|
res &= inp_attn->self_k_idxs_swa->ne[0] == params.ubatch.n_tokens;
|
||||||
|
//res &= inp_attn->self_v_idxs_swa->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
|
||||||
|
|
||||||
|
res &= inp_attn->self_kq_mask_swa->ne[0] == attn_ctx->get_swa()->get_n_kv();
|
||||||
|
res &= inp_attn->self_kq_mask_swa->ne[1] == params.ubatch.n_tokens;
|
||||||
|
}
|
||||||
|
|
||||||
|
res &= inp_rs->s_copy->ne[0] == mctx->get_recr()->get_n_rs();
|
||||||
|
|
||||||
|
res &= inp_rs->s_copy_main->ne[0] == params.ubatch.n_seqs;
|
||||||
|
res &= inp_rs->s_copy_extra->ne[0] == mctx->get_recr()->get_n_rs() - params.ubatch.n_seqs;
|
||||||
|
|
||||||
|
res &= inp_rs->head == mctx->get_recr()->get_head();
|
||||||
|
res &= inp_rs->rs_z == mctx->get_recr()->get_rs_z();
|
||||||
|
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
void llm_graph_input_sampling::set_input(const llama_ubatch * ubatch) {
|
void llm_graph_input_sampling::set_input(const llama_ubatch * ubatch) {
|
||||||
// set the inputs only for the active samplers in the current ubatch
|
// set the inputs only for the active samplers in the current ubatch
|
||||||
std::unordered_set<llama_seq_id> active_samplers;
|
std::unordered_set<llama_seq_id> active_samplers;
|
||||||
|
|
@ -2056,6 +2127,47 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
|
||||||
return (llm_graph_input_mem_hybrid *) res->add_input(std::move(inp));
|
return (llm_graph_input_mem_hybrid *) res->add_input(std::move(inp));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Builds the combined graph input for a hybrid memory with iSWA attention:
// a recurrent-state input plus base and SWA attention inputs (index tensors and KQ masks).
// The resulting input is registered with `res` and a raw pointer to it is returned.
llm_graph_input_mem_hybrid_iswa * llm_graph_context::build_inp_mem_hybrid_iswa() const {
    const auto * hybrid_ctx = static_cast<const llama_memory_hybrid_iswa_context *>(mctx);

    // recurrent part
    auto rs_input = build_rs_inp_impl(ctx0, ubatch, hybrid_ctx->get_recr());

    // build iswa attention input
    const auto * kv_iswa = hybrid_ctx->get_attn();

    auto attn_input = std::make_unique<llm_graph_input_attn_kv_iswa>(hparams, cparams, kv_iswa);

    // a unified KV cache uses a single stream, otherwise one stream per unique sequence in the ubatch
    const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;

    // base (non-SWA) attention tensors
    {
        const auto n_kv_base = kv_iswa->get_base()->get_n_kv();

        attn_input->self_k_idxs = kv_iswa->get_base()->build_input_k_idxs(ctx0, ubatch);
        attn_input->self_v_idxs = kv_iswa->get_base()->build_input_v_idxs(ctx0, ubatch);

        attn_input->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv_base, n_tokens/n_stream, 1, n_stream);
        ggml_set_input(attn_input->self_kq_mask);

        // with flash attention the mask is cast to F16, otherwise the F32 mask is used as-is
        if (cparams.flash_attn) {
            attn_input->self_kq_mask_cnv = ggml_cast(ctx0, attn_input->self_kq_mask, GGML_TYPE_F16);
        } else {
            attn_input->self_kq_mask_cnv = attn_input->self_kq_mask;
        }
    }

    // SWA attention tensors
    {
        const auto n_kv_swa = kv_iswa->get_swa()->get_n_kv();

        attn_input->self_k_idxs_swa = kv_iswa->get_swa()->build_input_k_idxs(ctx0, ubatch);
        attn_input->self_v_idxs_swa = kv_iswa->get_swa()->build_input_v_idxs(ctx0, ubatch);

        attn_input->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv_swa, n_tokens/n_stream, 1, n_stream);
        ggml_set_input(attn_input->self_kq_mask_swa);

        // with flash attention the mask is cast to F16, otherwise the F32 mask is used as-is
        if (cparams.flash_attn) {
            attn_input->self_kq_mask_swa_cnv = ggml_cast(ctx0, attn_input->self_kq_mask_swa, GGML_TYPE_F16);
        } else {
            attn_input->self_kq_mask_swa_cnv = attn_input->self_kq_mask_swa;
        }
    }

    auto hybrid_input = std::make_unique<llm_graph_input_mem_hybrid_iswa>(
            cparams, std::move(attn_input), std::move(rs_input), hybrid_ctx);

    return (llm_graph_input_mem_hybrid_iswa *) res->add_input(std::move(hybrid_input));
}
|
||||||
|
|
||||||
void llm_graph_context::build_dense_out(
|
void llm_graph_context::build_dense_out(
|
||||||
ggml_tensor * dense_2,
|
ggml_tensor * dense_2,
|
||||||
ggml_tensor * dense_3) const {
|
ggml_tensor * dense_3) const {
|
||||||
|
|
|
||||||
|
|
@ -24,6 +24,7 @@ class llama_kv_cache_context;
|
||||||
class llama_kv_cache_iswa_context;
|
class llama_kv_cache_iswa_context;
|
||||||
class llama_memory_recurrent_context;
|
class llama_memory_recurrent_context;
|
||||||
class llama_memory_hybrid_context;
|
class llama_memory_hybrid_context;
|
||||||
|
class llama_memory_hybrid_iswa_context;
|
||||||
|
|
||||||
// certain models (typically multi-modal) can produce different types of graphs
|
// certain models (typically multi-modal) can produce different types of graphs
|
||||||
enum llm_graph_type {
|
enum llm_graph_type {
|
||||||
|
|
@ -397,6 +398,34 @@ public:
|
||||||
const llama_memory_hybrid_context * mctx;
|
const llama_memory_hybrid_context * mctx;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
class llm_graph_input_mem_hybrid_iswa : public llm_graph_input_i {
|
||||||
|
public:
|
||||||
|
llm_graph_input_mem_hybrid_iswa(
|
||||||
|
const llama_cparams & cparams,
|
||||||
|
std::unique_ptr<llm_graph_input_attn_kv_iswa> inp_attn,
|
||||||
|
std::unique_ptr<llm_graph_input_rs> inp_rs,
|
||||||
|
const llama_memory_hybrid_iswa_context * mctx) :
|
||||||
|
inp_attn(std::move(inp_attn)),
|
||||||
|
inp_rs(std::move(inp_rs)),
|
||||||
|
cparams(cparams),
|
||||||
|
mctx(mctx) { }
|
||||||
|
virtual ~llm_graph_input_mem_hybrid_iswa() = default;
|
||||||
|
|
||||||
|
void set_input(const llama_ubatch * ubatch) override;
|
||||||
|
|
||||||
|
bool can_reuse(const llm_graph_params & params) override;
|
||||||
|
|
||||||
|
std::unique_ptr<llm_graph_input_attn_kv_iswa> inp_attn;
|
||||||
|
std::unique_ptr<llm_graph_input_rs> inp_rs;
|
||||||
|
|
||||||
|
llm_graph_input_attn_kv_iswa * get_attn() const { return inp_attn.get(); }
|
||||||
|
llm_graph_input_rs * get_recr() const { return inp_rs.get(); }
|
||||||
|
|
||||||
|
const llama_cparams cparams;
|
||||||
|
|
||||||
|
const llama_memory_hybrid_iswa_context * mctx;
|
||||||
|
};
|
||||||
|
|
||||||
class llm_graph_input_sampling : public llm_graph_input_i {
|
class llm_graph_input_sampling : public llm_graph_input_i {
|
||||||
public:
|
public:
|
||||||
llm_graph_input_sampling(std::map<llama_seq_id, llama_sampler *> samplers) :
|
llm_graph_input_sampling(std::map<llama_seq_id, llama_sampler *> samplers) :
|
||||||
|
|
@ -881,6 +910,8 @@ struct llm_graph_context {
|
||||||
|
|
||||||
llm_graph_input_mem_hybrid * build_inp_mem_hybrid() const;
|
llm_graph_input_mem_hybrid * build_inp_mem_hybrid() const;
|
||||||
|
|
||||||
|
llm_graph_input_mem_hybrid_iswa * build_inp_mem_hybrid_iswa() const;
|
||||||
|
|
||||||
//
|
//
|
||||||
// pooling
|
// pooling
|
||||||
//
|
//
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,275 @@
|
||||||
|
#include "llama-memory-hybrid-iswa.h"
|
||||||
|
|
||||||
|
#include "llama-impl.h"
|
||||||
|
#include "llama-model.h"
|
||||||
|
#include "llama-context.h"
|
||||||
|
|
||||||
|
//
|
||||||
|
// llama_memory_hybrid_iswa
|
||||||
|
//
|
||||||
|
|
||||||
|
llama_memory_hybrid_iswa::llama_memory_hybrid_iswa(
|
||||||
|
const llama_model & model,
|
||||||
|
/* attn */
|
||||||
|
ggml_type type_k,
|
||||||
|
ggml_type type_v,
|
||||||
|
bool v_trans,
|
||||||
|
bool swa_full,
|
||||||
|
uint32_t kv_size,
|
||||||
|
uint32_t n_ubatch,
|
||||||
|
uint32_t n_pad,
|
||||||
|
/* recurrent */
|
||||||
|
ggml_type type_r,
|
||||||
|
ggml_type type_s,
|
||||||
|
uint32_t rs_size,
|
||||||
|
/* common */
|
||||||
|
uint32_t n_seq_max,
|
||||||
|
bool offload,
|
||||||
|
bool unified,
|
||||||
|
/* layer filters */
|
||||||
|
const layer_filter_cb & filter_attn,
|
||||||
|
const layer_filter_cb & filter_recr) :
|
||||||
|
hparams(model.hparams),
|
||||||
|
mem_attn(new llama_kv_cache_iswa(
|
||||||
|
model,
|
||||||
|
type_k,
|
||||||
|
type_v,
|
||||||
|
v_trans,
|
||||||
|
offload,
|
||||||
|
swa_full,
|
||||||
|
unified,
|
||||||
|
kv_size,
|
||||||
|
n_seq_max,
|
||||||
|
n_ubatch,
|
||||||
|
n_pad,
|
||||||
|
filter_attn == nullptr ?
|
||||||
|
[&](int32_t il) { return !hparams.is_recurrent(il); }
|
||||||
|
: filter_attn,
|
||||||
|
nullptr
|
||||||
|
)),
|
||||||
|
mem_recr(new llama_memory_recurrent(
|
||||||
|
model,
|
||||||
|
type_r,
|
||||||
|
type_s,
|
||||||
|
offload,
|
||||||
|
rs_size,
|
||||||
|
n_seq_max,
|
||||||
|
filter_recr == nullptr ?
|
||||||
|
[&](int32_t il) { return hparams.is_recurrent(il); }
|
||||||
|
: filter_recr
|
||||||
|
)) {}
|
||||||
|
|
||||||
|
// Splits the batch into ubatches and prepares the recurrent memory and both
// attention caches (base and SWA) for them.
//
// Returns a context holding the ubatches and slot infos on success, or a
// context with status LLAMA_MEMORY_STATUS_FAILED_PREPARE if the split or any
// of the prepare steps fails.
llama_memory_context_ptr llama_memory_hybrid_iswa::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
    balloc.split_reset();

    // follow the recurrent pattern for creating the ubatch splits
    std::vector<llama_ubatch> ubatches;

    for (;;) {
        // - if all tokens are output, split by sequence
        // - otherwise use a sequential equal split
        //   TODO: non-sequential equal split can be done if using unified KV cache
        //   for simplicity, we always use sequential equal split for now
        llama_ubatch ubatch = embd_all
            ? balloc.split_seq(n_ubatch)
            : balloc.split_equal(n_ubatch, true);

        // an empty ubatch marks the end of the split
        if (ubatch.n_tokens == 0) {
            break;
        }

        ubatches.push_back(std::move(ubatch)); // NOLINT
    }

    if (balloc.get_n_used() < balloc.get_n_tokens()) {
        // failed to find a suitable split
        return std::make_unique<llama_memory_hybrid_iswa_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
    }

    // prepare the recurrent batches first
    if (!mem_recr->prepare(ubatches)) {
        // TODO: will the recurrent cache be in an undefined context at this point?
        LLAMA_LOG_ERROR("%s: failed to prepare recurrent ubatches\n", __func__);
        return std::make_unique<llama_memory_hybrid_iswa_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
    }

    // prepare the attention cache: the base and swa caches each produce their own slot infos
    auto sinfos_base = mem_attn->get_base()->prepare(ubatches);
    if (sinfos_base.empty()) {
        LLAMA_LOG_ERROR("%s: failed to prepare attention base ubatches\n", __func__);
        return std::make_unique<llama_memory_hybrid_iswa_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
    }

    auto sinfos_swa = mem_attn->get_swa()->prepare(ubatches);
    if (sinfos_swa.empty()) {
        LLAMA_LOG_ERROR("%s: failed to prepare attention swa ubatches\n", __func__);
        return std::make_unique<llama_memory_hybrid_iswa_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
    }

    return std::make_unique<llama_memory_hybrid_iswa_context>(
            this, std::move(sinfos_base), std::move(sinfos_swa), std::move(ubatches));
}
|
||||||
|
|
||||||
|
// Creates a "full" context: the context constructor calls init_full() on both
// the attention and the recurrent memory.
llama_memory_context_ptr llama_memory_hybrid_iswa::init_full() {
    auto full_ctx = std::make_unique<llama_memory_hybrid_iswa_context>(this);

    return full_ctx;
}
|
||||||
|
|
||||||
|
// Creates an "update" context: the context constructor calls
// init_update(lctx, optimize) on both the attention and the recurrent memory.
llama_memory_context_ptr llama_memory_hybrid_iswa::init_update(llama_context * lctx, bool optimize) {
    auto update_ctx = std::make_unique<llama_memory_hybrid_iswa_context>(this, lctx, optimize);

    return update_ctx;
}
|
||||||
|
|
||||||
|
bool llama_memory_hybrid_iswa::get_can_shift() const {
|
||||||
|
// Shifting is trivially supported for recurrent
|
||||||
|
return mem_attn->get_can_shift();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clears both memories (attention first, then recurrent), forwarding the `data` flag unchanged.
void llama_memory_hybrid_iswa::clear(bool data) {
    mem_attn->clear(data);

    mem_recr->clear(data);
}
|
||||||
|
|
||||||
|
// Removes the positions [p0, p1) of seq_id from both memories.
//
// The recurrent cache is tried first since it may fail; if it does fail, the
// cache will not have been mutated and the attention cache is left untouched.
// Returns false if either removal fails.
bool llama_memory_hybrid_iswa::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
    // short-circuit: the attention cache is only touched after the recurrent removal succeeded
    return mem_recr->seq_rm(seq_id, p0, p1)
        && mem_attn->seq_rm(seq_id, p0, p1);
}
|
||||||
|
|
||||||
|
// Forwards the sequence copy to the attention cache and then to the recurrent memory.
void llama_memory_hybrid_iswa::seq_cp(
        llama_seq_id seq_id_src,
        llama_seq_id seq_id_dst,
        llama_pos    p0,
        llama_pos    p1) {
    mem_attn->seq_cp(seq_id_src, seq_id_dst, p0, p1);

    mem_recr->seq_cp(seq_id_src, seq_id_dst, p0, p1);
}
|
||||||
|
|
||||||
|
// Forwards seq_keep to the attention cache and then to the recurrent memory.
void llama_memory_hybrid_iswa::seq_keep(llama_seq_id seq_id) {
    mem_attn->seq_keep(seq_id);

    mem_recr->seq_keep(seq_id);
}
|
||||||
|
|
||||||
|
// Forwards seq_add (with the same range and shift) to the attention cache and then to the recurrent memory.
void llama_memory_hybrid_iswa::seq_add(
        llama_seq_id seq_id,
        llama_pos    p0,
        llama_pos    p1,
        llama_pos    shift) {
    mem_attn->seq_add(seq_id, p0, p1, shift);

    mem_recr->seq_add(seq_id, p0, p1, shift);
}
|
||||||
|
|
||||||
|
// Forwards seq_div (with the same range and divisor) to the attention cache and then to the recurrent memory.
void llama_memory_hybrid_iswa::seq_div(
        llama_seq_id seq_id,
        llama_pos    p0,
        llama_pos    p1,
        int          d) {
    mem_attn->seq_div(seq_id, p0, p1, d);

    mem_recr->seq_div(seq_id, p0, p1, d);
}
|
||||||
|
|
||||||
|
// Smallest position of seq_id in the combined memory.
// the min of the total cache is the max of the two caches' min values
llama_pos llama_memory_hybrid_iswa::seq_pos_min(llama_seq_id seq_id) const {
    const llama_pos min_attn = mem_attn->seq_pos_min(seq_id);
    const llama_pos min_recr = mem_recr->seq_pos_min(seq_id);

    return std::max(min_attn, min_recr);
}
|
||||||
|
|
||||||
|
// Largest position of seq_id in the combined memory.
// the max of the total cache is the min of the two caches' max values
llama_pos llama_memory_hybrid_iswa::seq_pos_max(llama_seq_id seq_id) const {
    const llama_pos max_attn = mem_attn->seq_pos_max(seq_id);
    const llama_pos max_recr = mem_recr->seq_pos_max(seq_id);

    return std::min(max_attn, max_recr);
}
|
||||||
|
|
||||||
|
std::map<ggml_backend_buffer_type_t, size_t> llama_memory_hybrid_iswa::memory_breakdown() const {
|
||||||
|
std::map<ggml_backend_buffer_type_t, size_t> mb = mem_attn->memory_breakdown();
|
||||||
|
for (const auto & buft_size : mem_recr->memory_breakdown()) {
|
||||||
|
mb[buft_size.first] += buft_size.second;
|
||||||
|
}
|
||||||
|
return mb;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Writes the state of both memories to `io`: attention cache first, recurrent memory second.
// state_read() consumes the two parts in the same order.
void llama_memory_hybrid_iswa::state_write(
        llama_io_write_i &    io,
        llama_seq_id          seq_id,
        llama_state_seq_flags flags) const {
    mem_attn->state_write(io, seq_id, flags);

    mem_recr->state_write(io, seq_id, flags);
}
|
||||||
|
|
||||||
|
// Reads the state of both memories from `io`: attention cache first, recurrent memory second.
// This is the same order in which state_write() emits the two parts.
void llama_memory_hybrid_iswa::state_read(
        llama_io_read_i &     io,
        llama_seq_id          seq_id,
        llama_state_seq_flags flags) {
    mem_attn->state_read(io, seq_id, flags);

    mem_recr->state_read(io, seq_id, flags);
}
|
||||||
|
|
||||||
|
// Non-owning access to the iSWA attention cache (ownership stays with this object).
llama_kv_cache_iswa * llama_memory_hybrid_iswa::get_mem_attn() const {
    llama_kv_cache_iswa * attn = mem_attn.get();

    return attn;
}
|
||||||
|
|
||||||
|
// Non-owning access to the recurrent memory (ownership stays with this object).
llama_memory_recurrent * llama_memory_hybrid_iswa::get_mem_recr() const {
    llama_memory_recurrent * recr = mem_recr.get();

    return recr;
}
|
||||||
|
|
||||||
|
//
|
||||||
|
// llama_memory_hybrid_iswa_context
|
||||||
|
//
|
||||||
|
|
||||||
|
// init failure: only the status is stored; ctx_attn and ctx_recr are left default-initialized
llama_memory_hybrid_iswa_context::llama_memory_hybrid_iswa_context(llama_memory_status status) :
    status(status) {
}
|
||||||
|
|
||||||
|
// init full: both child contexts come from init_full() of the respective memory;
// the overall status is the combination of the two child statuses
llama_memory_hybrid_iswa_context::llama_memory_hybrid_iswa_context(llama_memory_hybrid_iswa * mem) :
    ctx_attn(mem->get_mem_attn()->init_full()),
    ctx_recr(mem->get_mem_recr()->init_full()),
    status  (llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) {
}
|
||||||
|
|
||||||
|
// init update: both child contexts come from init_update(lctx, optimize) of the
// respective memory; the overall status is the combination of the two child statuses
llama_memory_hybrid_iswa_context::llama_memory_hybrid_iswa_context(
        llama_memory_hybrid_iswa * mem,
        llama_context            * lctx,
        bool                       optimize) :
    ctx_attn(mem->get_mem_attn()->init_update(lctx, optimize)),
    ctx_recr(mem->get_mem_recr()->init_update(lctx, optimize)),
    status  (llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) {
}
|
||||||
|
|
||||||
|
// init success: takes ownership of the ubatches and creates the child contexts
// from the already prepared slot infos (attention) and the ubatches (recurrent).
//
// `ubatches` is initialized before the child contexts, which are constructed
// from this->ubatches (the member), not from the moved-from parameter.
llama_memory_hybrid_iswa_context::llama_memory_hybrid_iswa_context(
        llama_memory_hybrid_iswa  * mem,
        slot_info_vec_t             sinfos_base,
        slot_info_vec_t             sinfos_swa,
        std::vector<llama_ubatch>   ubatches) :
    ubatches(std::move(ubatches)),
    // note: here we copy the ubatches. not sure if this is ideal
    ctx_attn(std::make_unique<llama_kv_cache_iswa_context>(
        mem->get_mem_attn(),
        std::move(sinfos_base),
        std::move(sinfos_swa),
        this->ubatches)),
    ctx_recr(std::make_unique<llama_memory_recurrent_context>(
        mem->get_mem_recr(),
        this->ubatches)),
    status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) {
}
|
||||||
|
|
||||||
|
bool llama_memory_hybrid_iswa_context::next() {
|
||||||
|
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
||||||
|
|
||||||
|
ctx_attn->next();
|
||||||
|
ctx_recr->next();
|
||||||
|
|
||||||
|
if (++i_next >= ubatches.size()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Applies both child contexts (attention first, then recurrent).
// Both are always applied; the result is true only if both succeed.
bool llama_memory_hybrid_iswa_context::apply() {
    assert(!llama_memory_status_is_fail(status));

    // deliberately no short-circuit: the recurrent context is applied even if the attention one fails
    const bool ok_attn = ctx_attn->apply();
    const bool ok_recr = ctx_recr->apply();

    return ok_attn && ok_recr;
}
|
||||||
|
|
||||||
|
// Status fixed at construction time (the member is const).
llama_memory_status llama_memory_hybrid_iswa_context::get_status() const {
    const llama_memory_status current = status;

    return current;
}
|
||||||
|
|
||||||
|
// Returns the ubatch at the current index (i_next).
// Only valid for a context whose status is LLAMA_MEMORY_STATUS_SUCCESS.
const llama_ubatch & llama_memory_hybrid_iswa_context::get_ubatch() const {
    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);

    const llama_ubatch & current = ubatches[i_next];

    return current;
}
|
||||||
|
|
||||||
|
// Non-owning access to the attention child context, downcast to its iSWA context type.
const llama_kv_cache_iswa_context * llama_memory_hybrid_iswa_context::get_attn() const {
    const auto * attn = static_cast<const llama_kv_cache_iswa_context *>(ctx_attn.get());

    return attn;
}
|
||||||
|
|
||||||
|
// Non-owning access to the recurrent child context, downcast to its recurrent context type.
const llama_memory_recurrent_context * llama_memory_hybrid_iswa_context::get_recr() const {
    const auto * recr = static_cast<const llama_memory_recurrent_context *>(ctx_recr.get());

    return recr;
}
|
||||||
|
|
@ -0,0 +1,140 @@
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "llama-batch.h"
|
||||||
|
#include "llama-graph.h"
|
||||||
|
#include "llama-kv-cache-iswa.h"
|
||||||
|
#include "llama-memory.h"
|
||||||
|
#include "llama-memory-recurrent.h"
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
//
|
||||||
|
// llama_memory_hybrid_iswa
|
||||||
|
//
|
||||||
|
|
||||||
|
// utilizes instances of llama_memory_recurrent and llama_kv_cache_iswa to
|
||||||
|
// support models where each layer may be either attention-based (with SWA support) or recurrent
|
||||||
|
|
||||||
|
class llama_memory_hybrid_iswa : public llama_memory_i {
|
||||||
|
public:
|
||||||
|
llama_memory_hybrid_iswa(
|
||||||
|
const llama_model & model,
|
||||||
|
/* attn */
|
||||||
|
ggml_type type_k,
|
||||||
|
ggml_type type_v,
|
||||||
|
bool v_trans,
|
||||||
|
bool swa_full,
|
||||||
|
uint32_t kv_size,
|
||||||
|
uint32_t n_ubatch,
|
||||||
|
uint32_t n_pad,
|
||||||
|
/* recurrent */
|
||||||
|
ggml_type type_r,
|
||||||
|
ggml_type type_s,
|
||||||
|
uint32_t rs_size,
|
||||||
|
/* common */
|
||||||
|
uint32_t n_seq_max,
|
||||||
|
bool offload,
|
||||||
|
bool unified,
|
||||||
|
/* layer filters */
|
||||||
|
const layer_filter_cb & filter_attn = nullptr,
|
||||||
|
const layer_filter_cb & filter_recr = nullptr);
|
||||||
|
|
||||||
|
~llama_memory_hybrid_iswa() = default;
|
||||||
|
|
||||||
|
//
|
||||||
|
// llama_memory_i
|
||||||
|
//
|
||||||
|
|
||||||
|
llama_memory_context_ptr init_batch(
|
||||||
|
llama_batch_allocr & balloc,
|
||||||
|
uint32_t n_ubatch,
|
||||||
|
bool embd_all) override;
|
||||||
|
|
||||||
|
llama_memory_context_ptr init_full() override;
|
||||||
|
|
||||||
|
llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override;
|
||||||
|
|
||||||
|
bool get_can_shift() const override;
|
||||||
|
|
||||||
|
void clear(bool data) override;
|
||||||
|
|
||||||
|
bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
|
||||||
|
void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
|
||||||
|
void seq_keep(llama_seq_id seq_id) override;
|
||||||
|
void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override;
|
||||||
|
void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
|
||||||
|
|
||||||
|
llama_pos seq_pos_min(llama_seq_id seq_id) const override;
|
||||||
|
llama_pos seq_pos_max(llama_seq_id seq_id) const override;
|
||||||
|
|
||||||
|
std::map<ggml_backend_buffer_type_t, size_t> memory_breakdown() const override;
|
||||||
|
|
||||||
|
// state write/load
|
||||||
|
|
||||||
|
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override;
|
||||||
|
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override;
|
||||||
|
|
||||||
|
//
|
||||||
|
// llama_memory_hybrid_iswa specific API
|
||||||
|
//
|
||||||
|
|
||||||
|
llama_kv_cache_iswa * get_mem_attn() const;
|
||||||
|
llama_memory_recurrent * get_mem_recr() const;
|
||||||
|
|
||||||
|
private:
|
||||||
|
const llama_hparams & hparams;
|
||||||
|
|
||||||
|
const std::unique_ptr<llama_kv_cache_iswa> mem_attn;
|
||||||
|
const std::unique_ptr<llama_memory_recurrent> mem_recr;
|
||||||
|
};
|
||||||
|
|
||||||
|
class llama_memory_hybrid_iswa_context : public llama_memory_context_i {
|
||||||
|
public:
|
||||||
|
using slot_info_vec_t = llama_kv_cache::slot_info_vec_t;
|
||||||
|
|
||||||
|
// init failure
|
||||||
|
explicit llama_memory_hybrid_iswa_context(llama_memory_status status);
|
||||||
|
|
||||||
|
// init full
|
||||||
|
explicit llama_memory_hybrid_iswa_context(llama_memory_hybrid_iswa * mem);
|
||||||
|
|
||||||
|
// init update
|
||||||
|
explicit llama_memory_hybrid_iswa_context(
|
||||||
|
llama_memory_hybrid_iswa * mem,
|
||||||
|
llama_context * lctx,
|
||||||
|
bool optimize);
|
||||||
|
|
||||||
|
// init success
|
||||||
|
llama_memory_hybrid_iswa_context(
|
||||||
|
llama_memory_hybrid_iswa * mem,
|
||||||
|
slot_info_vec_t sinfos_base,
|
||||||
|
slot_info_vec_t sinfos_swa,
|
||||||
|
std::vector<llama_ubatch> ubatches);
|
||||||
|
|
||||||
|
~llama_memory_hybrid_iswa_context() = default;
|
||||||
|
|
||||||
|
bool next() override;
|
||||||
|
bool apply() override;
|
||||||
|
|
||||||
|
llama_memory_status get_status() const override;
|
||||||
|
const llama_ubatch & get_ubatch() const override;
|
||||||
|
|
||||||
|
//
|
||||||
|
// llama_memory_hybrid_iswa_context
|
||||||
|
//
|
||||||
|
|
||||||
|
const llama_kv_cache_iswa_context * get_attn() const;
|
||||||
|
const llama_memory_recurrent_context * get_recr() const;
|
||||||
|
|
||||||
|
private:
|
||||||
|
// the index of the next ubatch to process
|
||||||
|
size_t i_next = 0;
|
||||||
|
|
||||||
|
std::vector<llama_ubatch> ubatches;
|
||||||
|
|
||||||
|
const llama_memory_context_ptr ctx_attn;
|
||||||
|
const llama_memory_context_ptr ctx_recr;
|
||||||
|
|
||||||
|
const llama_memory_status status;
|
||||||
|
};
|
||||||
|
|
@ -8,6 +8,7 @@
|
||||||
#include "llama-kv-cache.h"
|
#include "llama-kv-cache.h"
|
||||||
#include "llama-kv-cache-iswa.h"
|
#include "llama-kv-cache-iswa.h"
|
||||||
#include "llama-memory-hybrid.h"
|
#include "llama-memory-hybrid.h"
|
||||||
|
#include "llama-memory-hybrid-iswa.h"
|
||||||
#include "llama-memory-recurrent.h"
|
#include "llama-memory-recurrent.h"
|
||||||
|
|
||||||
#include "ggml-cpp.h"
|
#include "ggml-cpp.h"
|
||||||
|
|
@ -7528,23 +7529,44 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
res = new llama_memory_hybrid(
|
if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
|
||||||
/* model */ *this,
|
// Use hybrid-iswa for hybrid models with SWA
|
||||||
/* attn_type_k */ params.type_k,
|
res = new llama_memory_hybrid_iswa(
|
||||||
/* attn_type_v */ params.type_v,
|
/* model */ *this,
|
||||||
/* attn_v_trans */ !cparams.flash_attn,
|
/* attn_type_k */ params.type_k,
|
||||||
/* attn_kv_size */ cparams.n_ctx,
|
/* attn_type_v */ params.type_v,
|
||||||
/* attn_n_pad */ 1,
|
/* attn_v_trans */ !cparams.flash_attn,
|
||||||
/* attn_n_swa */ hparams.n_swa,
|
/* attn_swa_full */ params.swa_full,
|
||||||
/* attn_swa_type */ hparams.swa_type,
|
/* attn_kv_size */ cparams.n_ctx,
|
||||||
/* recurrent_type_k */ GGML_TYPE_F32,
|
/* attn_n_ubatch */ cparams.n_ubatch,
|
||||||
/* recurrent_type_v */ GGML_TYPE_F32,
|
/* attn_n_pad */ 1,
|
||||||
/* recurrent_kv_size */ std::max((uint32_t) 1, cparams.n_seq_max),
|
/* recurrent_type_r */ GGML_TYPE_F32,
|
||||||
/* n_seq_max */ cparams.n_seq_max,
|
/* recurrent_type_s */ GGML_TYPE_F32,
|
||||||
/* offload */ cparams.offload_kqv,
|
/* recurrent_rs_size */ std::max((uint32_t) 1, cparams.n_seq_max),
|
||||||
/* unified */ cparams.kv_unified,
|
/* n_seq_max */ cparams.n_seq_max,
|
||||||
/* filter_attn */ std::move(filter_attn),
|
/* offload */ cparams.offload_kqv,
|
||||||
/* filter_recr */ std::move(filter_recr));
|
/* unified */ cparams.kv_unified,
|
||||||
|
/* filter_attn */ std::move(filter_attn),
|
||||||
|
/* filter_recr */ std::move(filter_recr));
|
||||||
|
} else {
|
||||||
|
res = new llama_memory_hybrid(
|
||||||
|
/* model */ *this,
|
||||||
|
/* attn_type_k */ params.type_k,
|
||||||
|
/* attn_type_v */ params.type_v,
|
||||||
|
/* attn_v_trans */ !cparams.flash_attn,
|
||||||
|
/* attn_kv_size */ cparams.n_ctx,
|
||||||
|
/* attn_n_pad */ 1,
|
||||||
|
/* attn_n_swa */ hparams.n_swa,
|
||||||
|
/* attn_swa_type */ hparams.swa_type,
|
||||||
|
/* recurrent_type_k */ GGML_TYPE_F32,
|
||||||
|
/* recurrent_type_v */ GGML_TYPE_F32,
|
||||||
|
/* recurrent_kv_size */ std::max((uint32_t) 1, cparams.n_seq_max),
|
||||||
|
/* n_seq_max */ cparams.n_seq_max,
|
||||||
|
/* offload */ cparams.offload_kqv,
|
||||||
|
/* unified */ cparams.kv_unified,
|
||||||
|
/* filter_attn */ std::move(filter_attn),
|
||||||
|
/* filter_recr */ std::move(filter_recr));
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
llama_memory_i::layer_reuse_cb reuse = nullptr;
|
llama_memory_i::layer_reuse_cb reuse = nullptr;
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue