feat: Construct hybrid recurrent cache for hybrid recurrent models
This includes a refactor of the create_memory logic so that the arch enum only needs to be consulted when a model requires cache instantiation logic beyond the standard handling for recurrent, hybrid, unified, and iSWA caches.

Branch: HybridRecurrentCache

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
parent c71eaa37a0
commit 423c89401d
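The new default branch in the diff below dispatches on two architecture predicates, llm_arch_is_recurrent and llm_arch_is_hybrid_recurrent, whose definitions are not part of this diff. As a rough sketch of the intended dispatch (the exact architecture lists and the body of the hybrid check are assumptions, not taken from this commit):

    // Sketch only: the real predicates live outside this diff and may
    // differ. The recurrent list mirrors the cases removed below; the
    // hybrid list is left empty here as a placeholder.
    bool llm_arch_is_recurrent(llm_arch arch) {
        switch (arch) {
            case LLM_ARCH_MAMBA:
            case LLM_ARCH_RWKV6:
            case LLM_ARCH_RWKV6QWEN2:
            case LLM_ARCH_RWKV7:
            case LLM_ARCH_ARWKV7:
                return true;
            default:
                return false;
        }
    }

    bool llm_arch_is_hybrid_recurrent(llm_arch arch) {
        switch (arch) {
            // architectures mixing attention and recurrent layers
            // would be listed here
            default:
                return false;
        }
    }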
@@ -9,6 +9,7 @@
 #include "llama-kv-cache-unified.h"
 #include "llama-kv-cache-unified-iswa.h"
 #include "llama-kv-cache-recurrent.h"
+#include "llama-kv-cache-hybrid-recurrent.h"
 
 #include "ggml-cpp.h"
 
@@ -13742,6 +13743,8 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
     llama_memory_i * res;
 
     switch (arch) {
+        // Models that need specific instantiation should be handled in the
+        // switch statement
         case LLM_ARCH_BERT:
         case LLM_ARCH_JINA_BERT_V2:
         case LLM_ARCH_NOMIC_BERT:
@@ -13751,58 +13754,71 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
             {
                 res = nullptr;
             } break;
-        case LLM_ARCH_MAMBA:
-        case LLM_ARCH_RWKV6:
-        case LLM_ARCH_RWKV6QWEN2:
-        case LLM_ARCH_RWKV7:
-        case LLM_ARCH_ARWKV7:
-            {
-                res = new llama_kv_cache_recurrent(
-                        *this,
-                        nullptr,
-                        GGML_TYPE_F32,
-                        GGML_TYPE_F32,
-                        cparams.offload_kqv,
-                        std::max((uint32_t) 1, cparams.n_seq_max),
-                        cparams.n_seq_max);
-            } break;
+        // Models that need standard caching should rely on recurrent/hybrid
+        // checks
         default:
             {
-                const auto padding = llama_kv_cache_unified::get_padding(cparams);
+                if (llm_arch_is_recurrent(arch)) {
+                    res = new llama_kv_cache_recurrent(
+                            *this,
+                            nullptr,
+                            GGML_TYPE_F32,
+                            GGML_TYPE_F32,
+                            cparams.offload_kqv,
+                            std::max((uint32_t) 1, cparams.n_seq_max),
+                            cparams.n_seq_max);
+                } else if (llm_arch_is_hybrid_recurrent(arch)) {
+                    res = new llama_kv_cache_hybrid_recurrent(
+                        /* model             */ *this,
+                        /* attn_type_k       */ params.type_k,
+                        /* attn_type_v       */ params.type_v,
+                        /* attn_v_trans      */ !cparams.flash_attn,
+                        /* attn_kv_size      */ cparams.n_ctx,
+                        /* attn_n_pad        */ llama_kv_cache_unified::get_padding(cparams),
+                        /* attn_n_swa        */ hparams.n_swa,
+                        /* attn_swa_type     */ hparams.swa_type,
+                        /* recurrent_type_k  */ GGML_TYPE_F32,
+                        /* recurrent_type_v  */ GGML_TYPE_F32,
+                        /* recurrent_kv_size */ std::max((uint32_t) 1, cparams.n_seq_max),
+                        /* n_seq_max         */ cparams.n_seq_max,
+                        /* offload           */ cparams.offload_kqv);
+                } else {
+                    const auto padding = llama_kv_cache_unified::get_padding(cparams);
 
-                cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding);
+                    cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding);
 
-                LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
+                    LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
 
-                if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
-                    GGML_ASSERT(hparams.is_swa_any());
+                    if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
+                        GGML_ASSERT(hparams.is_swa_any());
 
-                    res = new llama_kv_cache_unified_iswa(
-                            *this,
-                            params.type_k,
-                            params.type_v,
-                            !cparams.flash_attn,
-                            cparams.offload_kqv,
-                            params.swa_full,
-                            cparams.n_ctx,
-                            cparams.n_seq_max,
-                            cparams.n_ubatch,
-                            padding);
-                } else {
-                    GGML_ASSERT(!hparams.is_swa_any());
+                        res = new llama_kv_cache_unified_iswa(
+                                *this,
+                                params.type_k,
+                                params.type_v,
+                                !cparams.flash_attn,
+                                cparams.offload_kqv,
+                                params.swa_full,
+                                cparams.n_ctx,
+                                cparams.n_seq_max,
+                                cparams.n_ubatch,
+                                padding);
+                    } else {
+                        GGML_ASSERT(!hparams.is_swa_any());
 
-                    res = new llama_kv_cache_unified(
-                            *this,
-                            nullptr,
-                            params.type_k,
-                            params.type_v,
-                            !cparams.flash_attn,
-                            cparams.offload_kqv,
-                            cparams.n_ctx,
-                            cparams.n_seq_max,
-                            padding,
-                            hparams.n_swa,
-                            hparams.swa_type);
+                        res = new llama_kv_cache_unified(
+                                *this,
+                                nullptr,
+                                params.type_k,
+                                params.type_v,
+                                !cparams.flash_attn,
+                                cparams.offload_kqv,
+                                cparams.n_ctx,
+                                cparams.n_seq_max,
+                                padding,
+                                hparams.n_swa,
+                                hparams.swa_type);
+                    }
                 }
             }
     }
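For reference, the thirteen commented arguments at the new call site imply a constructor shape roughly like the following. This is a reconstruction from the call site alone; the actual declaration (presumably in llama-kv-cache-hybrid-recurrent.h) may differ in types and defaults. The pattern is the union of the unified (attention) cache parameters, prefixed attn_*, and the recurrent cache parameters, prefixed recurrent_*, plus the shared sequence and offload settings:

    // Reconstructed from the call site; types for sizes and flags are assumed.
    llama_kv_cache_hybrid_recurrent(
            const llama_model & model,
            // attention (unified) cache parameters
            ggml_type      attn_type_k,
            ggml_type      attn_type_v,
            bool           attn_v_trans,
            uint32_t       attn_kv_size,
            uint32_t       attn_n_pad,
            uint32_t       attn_n_swa,
            llama_swa_type attn_swa_type,
            // recurrent cache parameters
            ggml_type      recurrent_type_k,
            ggml_type      recurrent_type_v,
            uint32_t       recurrent_kv_size,
            // shared
            uint32_t       n_seq_max,
            bool           offload);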