graph : add back hybrid memory graph input
But this time it contains the sub-cache graph inputs. This *should* make it easier to handle updating the inputs when caching the graph (eventually).
commit 20f8e43e63
parent 4682e21c46
@@ -335,6 +335,11 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
     }
 }
 
+void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
+    inp_attn->set_input(ubatch);
+    inp_rs->set_input(ubatch);
+}
+
 void llm_graph_input_one::set_input(const llama_ubatch * ubatch) {
     GGML_UNUSED(ubatch);
     GGML_ASSERT(one && ggml_nelements(one) == 1);
@@ -1147,10 +1152,12 @@ ggml_tensor * llm_graph_context::build_attn(
     return cur;
 }
 
-llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified(const llama_kv_cache_unified_context * mctx_cur) const {
-    if (!mctx_cur) {
-        mctx_cur = static_cast<const llama_kv_cache_unified_context *>(mctx);
-    }
+static std::unique_ptr<llm_graph_input_attn_kv_unified> build_attn_inp_kv_unified_impl(
+        ggml_context * ctx0,
+        const llama_ubatch & ubatch,
+        const llama_hparams & hparams,
+        const llama_cparams & cparams,
+        const llama_kv_cache_unified_context * mctx_cur) {
 
     auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, mctx_cur);
 
@@ -1158,6 +1165,7 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified(c
     GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified_iswa for SWA");
 
     const auto n_kv = mctx_cur->get_n_kv();
+    const auto n_tokens = ubatch.n_tokens;
 
     inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
     inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch);
@@ -1168,6 +1176,14 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified(c
         inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
     }
 
+    return inp;
+}
+
+llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified() const {
+    const auto * mctx_cur = static_cast<const llama_kv_cache_unified_context *>(mctx);
+
+    auto inp = build_attn_inp_kv_unified_impl(ctx0, ubatch, hparams, cparams, mctx_cur);
+
     return (llm_graph_input_attn_kv_unified *) res->add_input(std::move(inp));
 }
 
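
The two hunks above apply a small refactoring pattern: the body of build_attn_inp_kv_unified() moves into a file-local build_attn_inp_kv_unified_impl() helper that takes the KV-cache context explicitly, and the public member function becomes a thin wrapper that casts mctx and forwards to the helper. The same split is applied to build_rs_inp() further down, and build_inp_mem_hybrid() then reuses both helpers with the sub-contexts of the hybrid cache. A rough standalone sketch of the pattern, with hypothetical names rather than the real llama.cpp signatures:

#include <memory>

// Hypothetical context/input types, only to illustrate the impl/wrapper split.
struct kv_context { int n_kv = 0; };
struct attn_input { explicit attn_input(const kv_context * c) : ctx(c) {} const kv_context * ctx; };

// File-local helper: builds the input from an explicitly supplied context, so it
// can be reused with either a standalone KV cache or the attention part of a
// hybrid cache.
static std::unique_ptr<attn_input> build_attn_input_impl(const kv_context * mctx_cur) {
    return std::make_unique<attn_input>(mctx_cur);
}

struct graph_context {
    const kv_context * mctx = nullptr;

    // Thin public wrapper: resolves the context from the graph state and defers
    // to the helper, replacing the old optional mctx_cur parameter.
    std::unique_ptr<attn_input> build_attn_input() const {
        return build_attn_input_impl(mctx);
    }
};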
@@ -1346,10 +1362,11 @@ ggml_tensor * llm_graph_context::build_attn(
     return cur;
 }
 
-llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa(const llama_kv_cache_unified_iswa_context * mctx_cur) const {
-    if (!mctx_cur) {
-        mctx_cur = static_cast<const llama_kv_cache_unified_iswa_context *>(mctx);
-    }
+// TODO: maybe separate the inner implementation into a separate function
+//       like with the non-sliding window equivalent
+//       once sliding-window hybrid caches are a thing.
+llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
+    const auto * mctx_cur = static_cast<const llama_kv_cache_unified_iswa_context *>(mctx);
 
     auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, mctx_cur);
 
@@ -1417,10 +1434,9 @@ ggml_tensor * llm_graph_context::build_rs(
     return output_states;
 }
 
-llm_graph_input_rs * llm_graph_context::build_rs_inp(const llama_memory_recurrent_context * mctx_cur) const {
-    if (!mctx_cur) {
-        mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
-    }
+static std::unique_ptr<llm_graph_input_rs> build_rs_inp_impl(
+        ggml_context * ctx0,
+        const llama_memory_recurrent_context * mctx_cur) {
 
     auto inp = std::make_unique<llm_graph_input_rs>(mctx_cur);
 
@@ -1429,6 +1445,14 @@ llm_graph_input_rs * llm_graph_context::build_rs_inp(const llama_memory_recurren
     inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs);
     ggml_set_input(inp->s_copy);
 
+    return inp;
+}
+
+llm_graph_input_rs * llm_graph_context::build_rs_inp() const {
+    const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
+
+    auto inp = build_rs_inp_impl(ctx0, mctx_cur);
+
     return (llm_graph_input_rs *) res->add_input(std::move(inp));
 }
 
@@ -1486,6 +1510,17 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
     );
 }
 
+llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
+    const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx);
+
+    auto inp_rs = build_rs_inp_impl(ctx0, mctx_cur->get_recr());
+    auto inp_attn = build_attn_inp_kv_unified_impl(ctx0, ubatch, hparams, cparams, mctx_cur->get_attn());
+
+    auto inp = std::make_unique<llm_graph_input_mem_hybrid>(std::move(inp_attn), std::move(inp_rs), mctx_cur);
+
+    return (llm_graph_input_mem_hybrid *) res->add_input(std::move(inp));
+}
+
 void llm_graph_context::build_pooling(
         ggml_cgraph * gf,
         ggml_tensor * cls,
@@ -319,6 +319,28 @@ public:
     const llama_cross * cross = nullptr;
 };
 
+class llm_graph_input_mem_hybrid : public llm_graph_input_i {
+public:
+    llm_graph_input_mem_hybrid(
+            std::unique_ptr<llm_graph_input_attn_kv_unified> inp_attn,
+            std::unique_ptr<llm_graph_input_rs> inp_rs,
+            const llama_memory_hybrid_context * mctx) :
+        inp_attn(std::move(inp_attn)),
+        inp_rs(std::move(inp_rs)),
+        mctx(mctx) { }
+    virtual ~llm_graph_input_mem_hybrid() = default;
+
+    void set_input(const llama_ubatch * ubatch) override;
+
+    std::unique_ptr<llm_graph_input_attn_kv_unified> inp_attn;
+    std::unique_ptr<llm_graph_input_rs> inp_rs;
+
+    llm_graph_input_attn_kv_unified * get_attn() const { return inp_attn.get(); }
+    llm_graph_input_rs * get_recr() const { return inp_rs.get(); }
+
+    const llama_memory_hybrid_context * mctx;
+};
+
 // TODO: remove this when ggml_scale_add is implemented
 class llm_graph_input_one : public llm_graph_input_i {
 public:
@@ -575,7 +597,7 @@ struct llm_graph_context {
             float kq_scale,
             int il) const;
 
-    llm_graph_input_attn_kv_unified * build_attn_inp_kv_unified(const llama_kv_cache_unified_context * mctx_cur = nullptr) const;
+    llm_graph_input_attn_kv_unified * build_attn_inp_kv_unified() const;
 
     ggml_tensor * build_attn(
             llm_graph_input_attn_kv_unified * inp,
@@ -590,7 +612,7 @@ struct llm_graph_context {
             float kq_scale,
             int il) const;
 
-    llm_graph_input_attn_kv_unified_iswa * build_attn_inp_kv_unified_iswa(const llama_kv_cache_unified_iswa_context * mctx_cur = nullptr) const;
+    llm_graph_input_attn_kv_unified_iswa * build_attn_inp_kv_unified_iswa() const;
 
     // note: if k_cur or v_cur are not provided, they will not be stored in the memory
     ggml_tensor * build_attn(
@@ -643,7 +665,7 @@ struct llm_graph_context {
             int32_t rs_zero,
             const llm_graph_get_rows_fn & get_state_rows = ggml_get_rows) const;
 
-    llm_graph_input_rs * build_rs_inp(const llama_memory_recurrent_context * mctx_cur = nullptr) const;
+    llm_graph_input_rs * build_rs_inp() const;
 
     ggml_tensor * build_rs(
             llm_graph_input_rs * inp,
@@ -663,6 +685,11 @@ struct llm_graph_context {
             ggml_tensor * token_shift,
             const llama_ubatch & ubatch,
             int il) const;
+    //
+    // hybrid
+    //
+
+    llm_graph_input_mem_hybrid * build_inp_mem_hybrid() const;
 
     //
     // pooling
@@ -10220,11 +10220,7 @@ struct llm_build_jamba : public llm_graph_context_mamba {
         // {n_embd, n_tokens}
        inpL = build_inp_embd(model.tok_embd);
 
-        const auto * mctx_hyb = static_cast<const llama_memory_hybrid_context *>(mctx);
-
-        auto * inp_rs = build_rs_inp(mctx_hyb->get_recr());
-
-        auto * inp_attn = build_attn_inp_kv_unified(mctx_hyb->get_attn());
+        auto * inp_hybrid = build_inp_mem_hybrid();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -10235,7 +10231,7 @@ struct llm_build_jamba : public llm_graph_context_mamba {
             cb(cur, "attn_norm", il);
 
             if (n_head_kv == 0) {
-                cur = build_mamba_layer(inp_rs, gf, cur, model, ubatch, il);
+                cur = build_mamba_layer(inp_hybrid->get_recr(), gf, cur, model, ubatch, il);
             } else {
                 // Attention
 
@@ -10256,7 +10252,7 @@ struct llm_build_jamba : public llm_graph_context_mamba {
                 cb(Vcur, "Vcur", il);
 
                 // No RoPE :)
-                cur = build_attn(inp_attn, gf, model.layers[il].wo, NULL, Qcur, Kcur, Vcur, NULL, NULL, 1.0f/sqrtf(float(n_embd_head)), il);
+                cur = build_attn(inp_hybrid->get_attn(), gf, model.layers[il].wo, NULL, Qcur, Kcur, Vcur, NULL, NULL, 1.0f/sqrtf(float(n_embd_head)), il);
             }
 
             if (il == n_layer - 1 && inp_out_ids) {
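
On the "(eventually)" in the commit message: once graphs are cached and reused, the input objects stored alongside a cached graph have to be refreshed for every new micro-batch before the graph is evaluated again. Because the hybrid input now owns its sub-inputs, such a loop only needs one set_input() call per stored input. A speculative sketch of that direction, using an invented cached_graph container that does not exist in this diff:

#include <memory>
#include <vector>

// Toy input interface playing the role of llm_graph_input_i (illustration only).
struct graph_input {
    virtual ~graph_input() = default;
    virtual void set_input(const int * ubatch) = 0;
};

// Invented container: a reusable graph plus the inputs it was built with.
struct cached_graph {
    std::vector<std::unique_ptr<graph_input>> inputs;

    // Refresh every stored input for the new micro-batch before re-running the
    // cached graph; a hybrid input updates both of its sub-caches in one call.
    void update(const int * ubatch) {
        for (auto & inp : inputs) {
            inp->set_input(ubatch);
        }
    }
};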