Merge branch 'master' into compilade/mamba2

commit 830e5542c2
@@ -499,6 +499,7 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_COS,
     GGML_METAL_KERNEL_TYPE_NEG,
     GGML_METAL_KERNEL_TYPE_SUM_ROWS,
+    GGML_METAL_KERNEL_TYPE_MEAN,
     GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
     GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32,
     GGML_METAL_KERNEL_TYPE_ARGMAX,
@@ -1456,6 +1457,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS, cos, true);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NEG, neg, true);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
+       GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MEAN, mean, true);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGMAX, argmax, true);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, pool_2d_avg_f32, true);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32, pool_2d_max_f32, true);
@@ -1655,6 +1657,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
        case GGML_OP_LOG:
            return false; // TODO: implement
        case GGML_OP_SUM_ROWS:
+       case GGML_OP_MEAN:
        case GGML_OP_SOFT_MAX:
        case GGML_OP_GROUP_NORM:
            return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]);
@@ -2402,11 +2405,30 @@ static bool ggml_metal_encode_node(
                [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
            } break;
        case GGML_OP_SUM_ROWS:
+       case GGML_OP_MEAN:
            {
                GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));

-               id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline;
+               id<MTLComputePipelineState> pipeline = nil;
+
+               switch (dst->op) {
+                   case GGML_OP_SUM_ROWS:
+                       pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline;
+                       break;
+                   case GGML_OP_MEAN:
+                       pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MEAN].pipeline;
+                       break;
+                   default:
+                       GGML_ABORT("fatal error");
+               }
+
+               int nth = 32; // SIMD width
+
+               while (nth < ne00 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) {
+                   nth *= 2;
+               }
+
+               nth = MIN(nth, ne00);

                ggml_metal_kargs_sum_rows args = {
                    /*.ne00 =*/ ne00,
@@ -2436,11 +2458,12 @@ static bool ggml_metal_encode_node(
                };

                [encoder setComputePipelineState:pipeline];
-               [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-               [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
-               [encoder setBytes:&args length:sizeof(args) atIndex:2];
+               [encoder setBytes:&args length:sizeof(args) atIndex:0];
+               [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
+               [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+               [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];

-               [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+               [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
            } break;
        case GGML_OP_SOFT_MAX:
            {
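The host-side change above switches SUM_ROWS/MEAN from a single-thread dispatch to one threadgroup per row. The thread count starts at one SIMD width and doubles until it covers the row or hits the pipeline limit. A minimal sketch of that selection in plain C++ (max_threads stands in for pipeline.maxTotalThreadsPerThreadgroup):

    #include <algorithm>

    // Sketch of the thread-count selection above: start at one SIMD width (32)
    // and double until the row is covered or the pipeline limit is reached.
    static int pick_nth(int ne00, int max_threads) {
        int nth = 32; // SIMD width
        while (nth < ne00 && nth < max_threads) {
            nth *= 2;
        }
        return std::min(nth, ne00); // never use more threads than columns
    }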
@@ -993,31 +993,61 @@ kernel void kernel_neg(
     dst[tpig] = -src0[tpig];
 }

+template <bool norm>
 kernel void kernel_sum_rows(
+        constant ggml_metal_kargs_sum_rows & args,
         device const float * src0,
         device       float * dst,
-        constant ggml_metal_kargs_sum_rows & args,
-        uint3 tpig[[thread_position_in_grid]]) {
-    int64_t i3 = tpig.z;
-    int64_t i2 = tpig.y;
-    int64_t i1 = tpig.x;
+        threadgroup  float * shmem_f32 [[threadgroup(0)]],
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        ushort3 tpitg[[thread_position_in_threadgroup]],
+        ushort  sgitg[[simdgroup_index_in_threadgroup]],
+        ushort  tiisg[[thread_index_in_simdgroup]],
+        ushort3   ntg[[threads_per_threadgroup]]) {
+    int64_t i3 = tgpig.z;
+    int64_t i2 = tgpig.y;
+    int64_t i1 = tgpig.x;

     if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01) {
         return;
     }

+    if (sgitg == 0) {
+        shmem_f32[tiisg] = 0.0f;
+    }
+
     device const float * src_row = (device const float *) ((device const char *) src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03);
     device       float * dst_row = (device       float *) ((device       char *) dst  + i1*args.nb1  + i2*args.nb2  + i3*args.nb3);

-    float row_sum = 0;
+    float sumf = 0;

-    for (int64_t i0 = 0; i0 < args.ne00; i0++) {
-        row_sum += src_row[i0];
+    for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) {
+        sumf += src_row[i0];
     }

-    dst_row[0] = row_sum;
+    sumf = simd_sum(sumf);
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    if (tiisg == 0) {
+        shmem_f32[sgitg] = sumf;
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    sumf = shmem_f32[tiisg];
+    sumf = simd_sum(sumf);
+
+    if (tpitg.x == 0) {
+        dst_row[0] = norm ? sumf / args.ne00 : sumf;
+    }
 }
+
+typedef decltype(kernel_sum_rows<false>) kernel_sum_rows_t;
+
+template [[host_name("kernel_sum_rows")]] kernel kernel_sum_rows_t kernel_sum_rows<false>;
+template [[host_name("kernel_mean")]]     kernel kernel_sum_rows_t kernel_sum_rows<true>;

 template<typename T>
 kernel void kernel_soft_max(
         device const char * src0,
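The rewritten kernel is templated on norm so that kernel_sum_rows<false> and kernel_sum_rows<true> share one parallel reduction; the host_name attributes expose them as kernel_sum_rows and kernel_mean. A CPU reference for what each row produces (a sketch for illustration, not part of the commit):

    #include <cstdint>

    // Reference semantics of kernel_sum_rows<norm>: one value per row, either
    // the plain sum (norm == false) or the mean (norm == true).
    template <bool norm>
    void sum_rows_ref(const float * src, float * dst, int64_t ne00, int64_t nrows) {
        for (int64_t r = 0; r < nrows; ++r) {
            float sumf = 0.0f;
            for (int64_t i0 = 0; i0 < ne00; ++i0) {
                sumf += src[r*ne00 + i0];
            }
            dst[r] = norm ? sumf / ne00 : sumf;
        }
    }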
@@ -22,8 +22,9 @@ add_library(llama
             llama-io.cpp
             llama-kv-cache-unified.cpp
             llama-kv-cache-unified-iswa.cpp
-            llama-kv-cache-recurrent.cpp
             llama-memory.cpp
+            llama-memory-hybrid.cpp
+            llama-memory-recurrent.cpp
             llama-mmap.cpp
             llama-model-loader.cpp
             llama-model-saver.cpp
@@ -148,6 +148,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_ATTENTION_SCALE, "%s.attention.scale" },
     { LLM_KV_ATTENTION_KEY_LENGTH_MLA, "%s.attention.key_length_mla" },
     { LLM_KV_ATTENTION_VALUE_LENGTH_MLA, "%s.attention.value_length_mla" },
+    { LLM_KV_ATTENTION_LAYER_INDICES, "%s.attention.layer_indices" },

     { LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" },
     { LLM_KV_ROPE_DIMENSION_SECTIONS, "%s.rope.dimension_sections" },
@@ -1835,3 +1836,26 @@ llm_arch llm_arch_from_string(const std::string & name) {
 const llm_tensor_info & llm_tensor_info_for(llm_tensor tensor) {
     return LLM_TENSOR_INFOS.at(tensor);
 }
+
+bool llm_arch_is_recurrent(const llm_arch & arch) {
+    switch (arch) {
+        case LLM_ARCH_MAMBA:
+        case LLM_ARCH_MAMBA2:
+        case LLM_ARCH_RWKV6:
+        case LLM_ARCH_RWKV6QWEN2:
+        case LLM_ARCH_RWKV7:
+        case LLM_ARCH_ARWKV7:
+            return true;
+        default:
+            return false;
+    }
+}
+
+bool llm_arch_is_hybrid(const llm_arch & arch) {
+    // TODO: There are currently no hybrid models! Once there are, this will be
+    //       the place to identify them
+    switch (arch) {
+        default:
+            return false;
+    }
+}
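The two predicates give callers a single place to map an architecture to a memory implementation. A hypothetical dispatch sketch (the selection logic is illustrative, not from this commit):

    // Hypothetical mapping from architecture to memory kind: hybrids first,
    // then purely recurrent archs, then the default unified KV cache.
    enum class memory_kind { kv_unified, recurrent, hybrid };

    static memory_kind pick_memory_kind(llm_arch arch) {
        if (llm_arch_is_hybrid(arch)) {
            return memory_kind::hybrid;
        }
        if (llm_arch_is_recurrent(arch)) {
            return memory_kind::recurrent;
        }
        return memory_kind::kv_unified;
    }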
@@ -152,6 +152,7 @@ enum llm_kv {
     LLM_KV_ATTENTION_SCALE,
     LLM_KV_ATTENTION_KEY_LENGTH_MLA,
     LLM_KV_ATTENTION_VALUE_LENGTH_MLA,
+    LLM_KV_ATTENTION_LAYER_INDICES,

     LLM_KV_ROPE_DIMENSION_COUNT,
     LLM_KV_ROPE_DIMENSION_SECTIONS,
@@ -442,3 +443,6 @@ const char * llm_arch_name(llm_arch arch);
 llm_arch llm_arch_from_string(const std::string & name);

 const llm_tensor_info & llm_tensor_info_for(llm_tensor tensor);
+
+bool llm_arch_is_recurrent(const llm_arch & arch);
+bool llm_arch_is_hybrid   (const llm_arch & arch);
@@ -6,7 +6,8 @@

 #include "llama-kv-cache-unified.h"
 #include "llama-kv-cache-unified-iswa.h"
-#include "llama-kv-cache-recurrent.h"
+#include "llama-memory-hybrid.h"
+#include "llama-memory-recurrent.h"

 #include <cassert>
 #include <cmath>
@@ -238,18 +239,18 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
     }
 }

-void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) {
+void llm_graph_input_rs::set_input(const llama_ubatch * ubatch) {
     GGML_UNUSED(ubatch);

-    const int64_t n_kv = kv_state->get_n_kv();
+    const int64_t n_rs = mem_state->get_n_rs();

     if (s_copy) {
         GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer));
         int32_t * data = (int32_t *) s_copy->data;

         // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
-        for (uint32_t i = 0; i < n_kv; ++i) {
-            data[i] = kv_state->s_copy(i);
+        for (uint32_t i = 0; i < n_rs; ++i) {
+            data[i] = mem_state->s_copy(i);
         }
     }
 }
@@ -403,6 +404,24 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
     }
 }

+void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
+    if (self_kq_mask) {
+        mem_state->get_state_attn()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
+    }
+
+    const int64_t n_rs = mem_state->get_state_recr()->get_n_rs();
+
+    if (s_copy) {
+        GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer));
+        int32_t * data = (int32_t *) s_copy->data;
+
+        // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
+        for (uint32_t i = 0; i < n_rs; ++i) {
+            data[i] = mem_state->get_state_recr()->s_copy(i);
+        }
+    }
+}
+
 //
 // llm_graph_context
 //
@@ -961,23 +980,6 @@ ggml_tensor * llm_graph_context::build_inp_cls() const {
     return cur;
 }

-ggml_tensor * llm_graph_context::build_inp_s_copy() const {
-    const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
-
-    auto inp = std::make_unique<llm_graph_input_s_copy>(kv_state);
-
-    const auto n_kv = kv_state->get_n_kv();
-
-    auto & cur = inp->s_copy;
-
-    cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_kv);
-    ggml_set_input(cur);
-
-    res->add_input(std::move(inp));
-
-    return cur;
-}
-
 ggml_tensor * llm_graph_context::build_inp_cross_embd() const {
     auto inp = std::make_unique<llm_graph_input_cross_embd>(cross);

@@ -1047,6 +1049,33 @@ ggml_tensor * llm_graph_context::build_pos_bias(ggml_tensor * pos_bucket, ggml_t
     return pos_bias;
 }

+llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
+    const auto * mem_state = static_cast<const llama_memory_hybrid_state *>(mstate);
+
+    auto inp = std::make_unique<llm_graph_input_mem_hybrid>(hparams, cparams, mem_state);
+
+    {
+        GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Hybrid recurrent is not supported with SWA attention layers");
+
+        const auto n_kv = inp->mem_state->get_state_attn()->get_n_kv();
+
+        inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+        //cb(inp->self_kq_mask, "KQ_mask", -1);
+        ggml_set_input(inp->self_kq_mask);
+
+        inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
+    }
+
+    {
+        const auto n_rs = mem_state->get_state_recr()->get_n_rs();
+
+        inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs);
+        ggml_set_input(inp->s_copy);
+    }
+
+    return (llm_graph_input_mem_hybrid *) res->add_input(std::move(inp));
+}
+
 ggml_tensor * llm_graph_context::build_attn_mha(
         ggml_cgraph * gf,
         ggml_tensor * q,
@@ -1291,36 +1320,6 @@ ggml_tensor * llm_graph_context::build_attn(
     return cur;
 }

-llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
-    const auto * kv_state = static_cast<const llama_kv_cache_unified_iswa_state *>(mstate);
-
-    auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, kv_state);
-
-    {
-        const auto n_kv = kv_state->get_base()->get_n_kv();
-
-        inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
-        //cb(inp->self_kq_mask, "KQ_mask", -1);
-        ggml_set_input(inp->self_kq_mask);
-
-        inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
-    }
-
-    {
-        GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified for non-SWA");
-
-        const auto n_kv = kv_state->get_swa()->get_n_kv();
-
-        inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
-        //cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
-        ggml_set_input(inp->self_kq_mask_swa);
-
-        inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
-    }
-
-    return (llm_graph_input_attn_kv_unified_iswa *) res->add_input(std::move(inp));
-}
-
 ggml_tensor * llm_graph_context::build_attn(
         llm_graph_input_attn_kv_unified_iswa * inp,
         ggml_cgraph * gf,
@@ -1430,33 +1429,109 @@ ggml_tensor * llm_graph_context::build_attn(
     return cur;
 }

-ggml_tensor * llm_graph_context::build_recurrent_state(
+ggml_tensor * llm_graph_context::build_attn(
+        llm_graph_input_mem_hybrid * inp,
+        ggml_cgraph * gf,
+        ggml_tensor * wo,
+        ggml_tensor * wo_b,
+        ggml_tensor * q_cur,
+        ggml_tensor * k_cur,
+        ggml_tensor * v_cur,
+        ggml_tensor * kq_b,
+        ggml_tensor * v_mla,
+        float kq_scale,
+        int il) const {
+    // these nodes are added to the graph together so that they are not reordered
+    // by doing so, the number of splits in the graph is reduced
+    ggml_build_forward_expand(gf, q_cur);
+    ggml_build_forward_expand(gf, k_cur);
+    ggml_build_forward_expand(gf, v_cur);
+
+    const auto * kv_state = static_cast<const llama_memory_hybrid_state *>(mstate)->get_state_attn();
+
+    // store to KV cache
+    {
+        ggml_build_forward_expand(gf, kv_state->cpy_k(ctx0, k_cur, il));
+        ggml_build_forward_expand(gf, kv_state->cpy_v(ctx0, v_cur, il));
+    }
+
+    const auto & kq_mask = inp->get_kq_mask();
+
+    ggml_tensor * q = q_cur;
+    ggml_tensor * k = kv_state->get_k(ctx0, il);
+    ggml_tensor * v = kv_state->get_v(ctx0, il);
+
+    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
+    cb(cur, "kqv_out", il);
+
+    if (wo) {
+        cur = build_lora_mm(wo, cur);
+        if (arch == LLM_ARCH_GLM4) {
+            // GLM4 seems to have numerical issues with half-precision accumulators
+            ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
+        }
+    }
+
+    if (wo_b) {
+        cur = ggml_add(ctx0, cur, wo_b);
+    }
+
+    return cur;
+}
+
+llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
+    const auto * kv_state = static_cast<const llama_kv_cache_unified_iswa_state *>(mstate);
+
+    auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, kv_state);
+
+    {
+        const auto n_kv = kv_state->get_base()->get_n_kv();
+
+        inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+        //cb(inp->self_kq_mask, "KQ_mask", -1);
+        ggml_set_input(inp->self_kq_mask);
+
+        inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
+    }
+
+    {
+        GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified for non-SWA");
+
+        const auto n_kv = kv_state->get_swa()->get_n_kv();
+
+        inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+        //cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
+        ggml_set_input(inp->self_kq_mask_swa);
+
+        inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
+    }
+
+    return (llm_graph_input_attn_kv_unified_iswa *) res->add_input(std::move(inp));
+}
+
+ggml_tensor * llm_graph_context::build_rs(
         ggml_cgraph * gf,
         ggml_tensor * s,
         ggml_tensor * state_copy,
         int32_t state_size,
         int32_t n_seqs,
-        const std::function<ggml_tensor * (ggml_context *, ggml_tensor * states, ggml_tensor * ids)> & get_state_rows) const {
-
-    const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
-
-    const auto n_kv = kv_state->get_n_kv();
-    const auto kv_head = kv_state->get_head();
-    const auto rs_zero = kv_state->get_rs_z();
-
-    ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, kv_state->get_size());
+        uint32_t n_kv,
+        uint32_t kv_head,
+        uint32_t kv_size,
+        int32_t rs_zero,
+        const llm_graph_get_rows_fn & get_state_rows) const {
+
+    ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, kv_size);

     // Clear a single state which will then be copied to the other cleared states.
     // Note that this is a no-op when the view is zero-sized.
     ggml_tensor * state_zero = ggml_view_1d(ctx0, states, state_size*(rs_zero >= 0), rs_zero*states->nb[1]*(rs_zero >= 0));
     ggml_build_forward_expand(gf, ggml_scale_inplace(ctx0, state_zero, 0));

-    ggml_tensor * output_states;
-
     // copy states
     // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
     // {state_size, kv_size} -> {state_size, n_seqs}
-    output_states = get_state_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_seqs, 0));
+    ggml_tensor * output_states = get_state_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_seqs, 0));
     ggml_build_forward_expand(gf, output_states);

     // copy extra states which won't be changed further (between n_seqs and n_kv)
@@ -1469,22 +1544,59 @@ ggml_tensor * llm_graph_context::build_recurrent_state(
     return output_states;
 }

-ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
+llm_graph_input_rs * llm_graph_context::build_rs_inp() const {
+    const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);
+
+    auto inp = std::make_unique<llm_graph_input_rs>(kv_state);
+
+    const auto n_rs = kv_state->get_n_rs();
+
+    inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs);
+    ggml_set_input(inp->s_copy);
+
+    return (llm_graph_input_rs *) res->add_input(std::move(inp));
+}
+
+ggml_tensor * llm_graph_context::build_rs(
+        llm_graph_input_rs * inp,
+        ggml_cgraph * gf,
+        ggml_tensor * s,
+        int32_t state_size,
+        int32_t n_seqs,
+        const llm_graph_get_rows_fn & get_state_rows) const {
+    const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);
+
+    return build_rs(gf, s, inp->s_copy, state_size, n_seqs, kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(), get_state_rows);
+}
+
+ggml_tensor * llm_graph_context::build_rs(
+        llm_graph_input_mem_hybrid * inp,
+        ggml_cgraph * gf,
+        ggml_tensor * s,
+        int32_t state_size,
+        int32_t n_seqs,
+        const llm_graph_get_rows_fn & get_state_rows) const {
+    const auto * kv_state = static_cast<const llama_memory_hybrid_state *>(mstate)->get_state_recr();
+
+    return build_rs(gf, s, inp->s_copy, state_size, n_seqs, kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(), get_state_rows);
+}
+
+ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
+        llm_graph_input_rs * inp,
         ggml_cgraph * gf,
         ggml_tensor * state_copy,
         const llama_ubatch & ubatch,
         int il) const {
-    const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
+    const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);

     const auto token_shift_count = hparams.token_shift_count;

     const int64_t n_seqs = ubatch.n_seqs;

-    ggml_tensor * token_shift_all = kv_state->get_k_l(il);
+    ggml_tensor * token_shift_all = kv_state->get_r_l(il);

-    ggml_tensor * token_shift = build_recurrent_state(
-            gf, token_shift_all, state_copy,
-            hparams.n_embd_k_s(), n_seqs);
+    ggml_tensor * token_shift = build_rs(
+            inp, gf, token_shift_all,
+            hparams.n_embd_r(), n_seqs);

     token_shift = ggml_reshape_3d(ctx0, token_shift, hparams.n_embd, token_shift_count, n_seqs);

@@ -1495,7 +1607,7 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
         ggml_tensor * token_shift,
         const llama_ubatch & ubatch,
         int il) const {
-    const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
+    const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);

     const auto token_shift_count = hparams.token_shift_count;
     const auto n_embd = hparams.n_embd;
@@ -1507,7 +1619,7 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
     return ggml_cpy(
         ctx0,
         ggml_view_1d(ctx0, token_shift, n_embd * n_seqs * token_shift_count, 0),
-        ggml_view_1d(ctx0, kv_state->get_k_l(il), hparams.n_embd_k_s()*n_seqs, hparams.n_embd_k_s()*kv_head*ggml_element_size(kv_state->get_k_l(il)))
+        ggml_view_1d(ctx0, kv_state->get_r_l(il), hparams.n_embd_r()*n_seqs, hparams.n_embd_r()*kv_head*ggml_element_size(kv_state->get_r_l(il)))
     );
 }
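At its core, build_rs is a gather: s_copy holds, for each active cell, the index of the cell whose state it should continue from, and get_state_rows (ggml_get_rows by default) materializes the copy. A plain-array model of that step (a sketch; the real code operates on ggml tensors inside the graph):

    #include <algorithm>
    #include <cstdint>
    #include <vector>

    // states is conceptually {state_size, kv_size}; the output takes the first
    // n_seqs entries of s_copy as row indices, giving {state_size, n_seqs}.
    static std::vector<float> copy_states(
            const std::vector<float>   & states, // kv_size * state_size values
            const std::vector<int32_t> & s_copy, // at least n_seqs indices
            int32_t state_size,
            int32_t n_seqs) {
        std::vector<float> out((size_t) n_seqs * state_size);
        for (int32_t i = 0; i < n_seqs; ++i) {
            const float * src = states.data() + (size_t) s_copy[i] * state_size;
            std::copy(src, src + state_size, out.begin() + (size_t) i * state_size);
        }
        return out;
    }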
@@ -21,7 +21,8 @@ struct llama_memory_state_i;

 class llama_kv_cache_unified_state;
 class llama_kv_cache_unified_iswa_state;
-class llama_kv_cache_recurrent_state;
+class llama_memory_recurrent_state;
+class llama_memory_hybrid_state;

 // certain models (typically multi-modal) can produce different types of graphs
 enum llm_graph_type {
@@ -188,16 +189,16 @@ public:
     const llama_cparams & cparams;
 };

-class llm_graph_input_s_copy : public llm_graph_input_i {
+class llm_graph_input_rs : public llm_graph_input_i {
 public:
-    llm_graph_input_s_copy(const llama_kv_cache_recurrent_state * kv_state) : kv_state(kv_state) {}
-    virtual ~llm_graph_input_s_copy() = default;
+    llm_graph_input_rs(const llama_memory_recurrent_state * mem_state) : mem_state(mem_state) {}
+    virtual ~llm_graph_input_rs() = default;

     void set_input(const llama_ubatch * ubatch) override;

     ggml_tensor * s_copy; // I32 [kv_size]

-    const llama_kv_cache_recurrent_state * kv_state;
+    const llama_memory_recurrent_state * mem_state;
 };

 class llm_graph_input_cross_embd : public llm_graph_input_i {
@@ -300,6 +301,33 @@ public:
     const llama_cross * cross = nullptr;
 };

+class llm_graph_input_mem_hybrid : public llm_graph_input_i {
+public:
+    llm_graph_input_mem_hybrid(
+            const llama_hparams & hparams,
+            const llama_cparams & cparams,
+            const llama_memory_hybrid_state * mem_state) :
+        hparams(hparams),
+        cparams(cparams),
+        mem_state(mem_state) {
+    }
+    virtual ~llm_graph_input_mem_hybrid() = default;
+
+    void set_input(const llama_ubatch * ubatch) override;
+
+    ggml_tensor * s_copy; // I32 [kv_size]
+
+    ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
+
+    ggml_tensor * self_kq_mask     = nullptr; // F32 [n_kv, n_batch]
+    ggml_tensor * self_kq_mask_cnv = nullptr; //     [n_kv, n_batch]
+
+    const llama_hparams & hparams;
+    const llama_cparams & cparams;
+
+    const llama_memory_hybrid_state * mem_state;
+};
+
 //
 // llm_graph_result
 //
@@ -383,6 +411,9 @@ struct llm_graph_params {
     const llm_graph_cb & cb;
 };

+// used in build_rs to properly order writes and avoid unnecessary copies
+using llm_graph_get_rows_fn = std::function<ggml_tensor * (ggml_context *, ggml_tensor * states, ggml_tensor * ids)>;
+
 struct llm_graph_context {
     const llm_arch arch;

@@ -508,13 +539,14 @@ struct llm_graph_context {
     ggml_tensor * build_inp_out_ids() const;
     ggml_tensor * build_inp_mean() const;
     ggml_tensor * build_inp_cls() const;
-    ggml_tensor * build_inp_s_copy() const;

     ggml_tensor * build_inp_cross_embd() const;
     ggml_tensor * build_inp_pos_bucket_enc() const;
     ggml_tensor * build_inp_pos_bucket_dec() const;
     ggml_tensor * build_pos_bias(ggml_tensor * pos_bucket, ggml_tensor * attn_rel_b) const;

+    llm_graph_input_mem_hybrid * build_inp_mem_hybrid() const;
+
     //
     // attention
     //
@@ -589,22 +621,61 @@ struct llm_graph_context {
             float kq_scale,
             int il) const;

+    ggml_tensor * build_attn(
+            llm_graph_input_mem_hybrid * inp,
+            ggml_cgraph * gf,
+            ggml_tensor * wo,
+            ggml_tensor * wo_b,
+            ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
+            ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
+            ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
+            ggml_tensor * kq_b,
+            ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
+            float kq_scale,
+            int il) const;
+
     //
     // recurrent
     //

-    ggml_tensor * build_recurrent_state(
+    // TODO: avoid notion of "kv"
+    // TODO: move this implementation to llama_memory_recurrent.
+    //       this is analogous to llama_kv_cache_unified::cpy_k / cpy_v
+    //       when moving, avoid passing `ggml_cgraph` - only pass `ggml_context`. would likely need to split the
+    //       implementation in 2 separate methods. the goal is to avoid calling `ggml_build_forward_expand` in
+    //       `llama_memory_recurrent`
+    ggml_tensor * build_rs(
             ggml_cgraph * gf,
             ggml_tensor * s,
             ggml_tensor * state_copy,
             int32_t state_size,
             int32_t n_seqs,
-            const std::function<ggml_tensor * (ggml_context *, ggml_tensor * states, ggml_tensor * ids)>
-                    & get_state_rows = ggml_get_rows) const;
+            uint32_t n_kv,
+            uint32_t kv_head,
+            uint32_t kv_size,
+            int32_t rs_zero,
+            const llm_graph_get_rows_fn & get_state_rows = ggml_get_rows) const;
+
+    llm_graph_input_rs * build_rs_inp() const;
+
+    ggml_tensor * build_rs(
+            llm_graph_input_rs * inp,
+            ggml_cgraph * gf,
+            ggml_tensor * s,
+            int32_t state_size,
+            int32_t n_seqs,
+            const llm_graph_get_rows_fn & get_state_rows = ggml_get_rows) const;
+
+    ggml_tensor * build_rs(
+            llm_graph_input_mem_hybrid * inp,
+            ggml_cgraph * gf,
+            ggml_tensor * s,
+            int32_t state_size,
+            int32_t n_seqs,
+            const llm_graph_get_rows_fn & get_state_rows = ggml_get_rows) const;

     ggml_tensor * build_rwkv_token_shift_load(
+            llm_graph_input_rs * inp,
             ggml_cgraph * gf,
             ggml_tensor * state_copy,
             const llama_ubatch & ubatch,
             int il) const;
@@ -65,7 +65,7 @@ uint32_t llama_hparams::n_embd_v_gqa(uint32_t il) const {
     return n_embd_head_v * n_head_kv;
 }

-uint32_t llama_hparams::n_embd_k_s() const {
+uint32_t llama_hparams::n_embd_r() const {
     if (wkv_head_size != 0) {
         // for RWKV models
         return token_shift_count * n_embd;
@@ -77,7 +77,7 @@ uint32_t llama_hparams::n_embd_k_s() const {
     return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * (ssm_d_inner + 2*ssm_n_group*ssm_d_state);
 }

-uint32_t llama_hparams::n_embd_v_s() const {
+uint32_t llama_hparams::n_embd_s() const {
     if (wkv_head_size != 0) {
         // corresponds to RWKV's wkv_states size
         return n_embd * wkv_head_size;
@@ -87,6 +87,10 @@ uint32_t llama_hparams::n_embd_v_s() const {
     return ssm_d_state * ssm_d_inner;
 }

+bool llama_hparams::is_recurrent(uint32_t il) const {
+    return recurrent_layer_arr[il];
+}
+
 bool llama_hparams::is_swa(uint32_t il) const {
     if (il < n_layer) {
         return swa_layers[il];
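The renamed accessors keep their formulas: for SSMs, n_embd_r() is (ssm_d_conv - 1) * (ssm_d_inner + 2*ssm_n_group*ssm_d_state) and n_embd_s() is ssm_d_state * ssm_d_inner, while RWKV models use token_shift_count * n_embd and n_embd * wkv_head_size instead. A worked example with hypothetical Mamba-2-style numbers (invented for illustration):

    #include <cstdint>
    #include <cstdio>

    int main() {
        // hypothetical SSM hyperparameters, for illustration only
        const uint32_t d_conv  = 4;
        const uint32_t d_inner = 4096;
        const uint32_t n_group = 8;
        const uint32_t d_state = 128;

        // rolling ("r") state per layer: the conv window minus the current column
        const uint32_t n_embd_r = (d_conv - 1) * (d_inner + 2*n_group*d_state);
        // recurrent ("s") state per layer: the SSM state itself
        const uint32_t n_embd_s = d_state * d_inner;

        std::printf("n_embd_r = %u\n", n_embd_r); // 3 * (4096 + 2048) = 18432
        std::printf("n_embd_s = %u\n", n_embd_s); // 128 * 4096 = 524288
        return 0;
    }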
@@ -116,6 +116,9 @@ struct llama_hparams {
     uint32_t ssm_dt_rank = 0;
     uint32_t ssm_n_group = 0;

+    // for hybrid state space models
+    std::array<bool, LLAMA_MAX_LAYERS> recurrent_layer_arr;
+
     bool ssm_dt_b_c_rms = false;

     float f_clamp_kqv = 0.0f;
@@ -182,10 +185,13 @@ struct llama_hparams {

     // dimension of the rolling state embeddings
     // corresponds to Mamba's conv_states size or RWKV's token_shift states size
-    uint32_t n_embd_k_s() const;
+    uint32_t n_embd_r() const;

     // dimension of the recurrent state embeddings
-    uint32_t n_embd_v_s() const;
+    uint32_t n_embd_s() const;
+
+    // whether or not the given layer is recurrent (for hybrid models)
+    bool is_recurrent(uint32_t il) const;

     bool is_swa(uint32_t il) const;
 };
@@ -197,21 +197,19 @@ llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_swa() const {
 llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(llama_memory_status status) : status(status) {}

 llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
-        llama_kv_cache_unified_iswa * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS) {
-    state_base = kv->get_base()->init_full();
-    state_swa  = kv->get_swa ()->init_full();
-
-    status = llama_memory_status_combine(state_base->get_status(), state_swa->get_status());
+        llama_kv_cache_unified_iswa * kv) :
+    state_base(kv->get_base()->init_full()),
+    state_swa (kv->get_swa ()->init_full()),
+    status(llama_memory_status_combine(state_base->get_status(), state_swa->get_status())) {
 }

 llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
         llama_kv_cache_unified_iswa * kv,
         llama_context * lctx,
-        bool optimize) : status(LLAMA_MEMORY_STATUS_SUCCESS) {
-    state_base = kv->get_base()->init_update(lctx, optimize);
-    state_swa  = kv->get_swa ()->init_update(lctx, optimize);
-
-    status = llama_memory_status_combine(state_base->get_status(), state_swa->get_status());
+        bool optimize) :
+    state_base(kv->get_base()->init_update(lctx, optimize)),
+    state_swa (kv->get_swa ()->init_update(lctx, optimize)),
+    status(llama_memory_status_combine(state_base->get_status(), state_swa->get_status())) {
 }

 llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
@@ -219,15 +217,13 @@ llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
         llama_sbatch sbatch,
         std::vector<uint32_t> heads_base,
         std::vector<uint32_t> heads_swa,
-        std::vector<llama_ubatch> ubatches)
-    : status(LLAMA_MEMORY_STATUS_SUCCESS),
+        std::vector<llama_ubatch> ubatches) :
     sbatch(std::move(sbatch)),
-    ubatches(std::move(ubatches)) {
+    ubatches(std::move(ubatches)),
     // note: here we copy the ubatches. not sure if this is ideal
-    state_base.reset(new llama_kv_cache_unified_state(kv->get_base(), {}, std::move(heads_base), this->ubatches));
-    state_swa .reset(new llama_kv_cache_unified_state(kv->get_swa (), {}, std::move(heads_swa), this->ubatches));
-
-    status = llama_memory_status_combine(state_base->get_status(), state_swa->get_status());
+    state_base(new llama_kv_cache_unified_state(kv->get_base(), {}, std::move(heads_base), this->ubatches)),
+    state_swa (new llama_kv_cache_unified_state(kv->get_swa (), {}, std::move(heads_swa), this->ubatches)),
+    status(llama_memory_status_combine(state_base->get_status(), state_swa->get_status())) {
 }

 llama_kv_cache_unified_iswa_state:: ~llama_kv_cache_unified_iswa_state() = default;
@@ -117,8 +117,6 @@ public:
     const llama_kv_cache_unified_state * get_swa() const;

 private:
-    llama_memory_status status;
-
     //llama_kv_cache_unified_iswa * kv;

     llama_sbatch sbatch;
@@ -128,6 +126,8 @@ private:

     std::vector<llama_ubatch> ubatches;

-    llama_memory_state_ptr state_base;
-    llama_memory_state_ptr state_swa;
+    const llama_memory_state_ptr state_base;
+    const llama_memory_state_ptr state_swa;
+
+    const llama_memory_status status;
 };
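The header reorders these members because C++ initializes members in declaration order, not in initializer-list order: status can be initialized from state_base->get_status() and state_swa->get_status() only because it is now declared after them. A minimal illustration of the rule (hypothetical type):

    // C++ initializes members top to bottom as declared, so `combined` may
    // safely read `a` and `b` only because it is declared after them -- the
    // same reason `status` moved below state_base/state_swa above.
    struct example_state {
        const int a;
        const int b;
        const int combined; // must be declared last

        example_state(int a_, int b_) :
            a(a_),
            b(b_),
            combined(a + b) {} // ok: a and b are already initialized here
    };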
@@ -68,8 +68,8 @@ llama_kv_cache_unified::llama_kv_cache_unified(
             continue;
         }

-        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
-        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
+        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);

         const char * dev_name = "CPU";

@@ -1430,7 +1430,7 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::
     for (const auto & layer : layers) {
         const uint32_t il = layer.il;

-        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
+        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);

         // Write key type
         const int32_t k_type_i = (int32_t)layer.k->type;
@@ -1452,7 +1452,7 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::
     for (const auto & layer : layers) {
         const uint32_t il = layer.il;

-        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);

         // Write value type
         const int32_t v_type_i = (int32_t)layer.v->type;
@@ -1476,7 +1476,7 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::
     for (const auto & layer : layers) {
         const uint32_t il = layer.il;

-        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);

         // Write value type
         const int32_t v_type_i = (int32_t)layer.v->type;
@@ -1621,7 +1621,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
     for (const auto & layer : layers) {
         const uint32_t il = layer.il;

-        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
+        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);

         // Read type of key
         int32_t k_type_i_ref;
@@ -1651,7 +1651,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
     for (const auto & layer : layers) {
         const uint32_t il = layer.il;

-        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);

         // Read type of value
         int32_t v_type_i_ref;
@@ -1681,7 +1681,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
     for (const auto & layer : layers) {
         const uint32_t il = layer.il;

-        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);

         // Read type of value
         int32_t v_type_i_ref;
@@ -0,0 +1,247 @@
+#include "llama-memory-hybrid.h"
+
+#include "llama-impl.h"
+#include "llama-model.h"
+#include "llama-context.h"
+
+//
+// llama_memory_hybrid
+//
+
+llama_memory_hybrid::llama_memory_hybrid(
+        const llama_model & model,
+        /* attn */
+        ggml_type type_k,
+        ggml_type type_v,
+        bool v_trans,
+        uint32_t kv_size,
+        uint32_t n_pad,
+        uint32_t n_swa,
+        llama_swa_type swa_type,
+        /* recurrent */
+        ggml_type type_r,
+        ggml_type type_s,
+        uint32_t rs_size,
+        /* common */
+        uint32_t n_seq_max,
+        bool offload,
+        /* layer filters */
+        layer_filter_cb && filter_attn,
+        layer_filter_cb && filter_recr) :
+    hparams(model.hparams),
+    mem_attn(new llama_kv_cache_unified(
+        model,
+        filter_attn == nullptr ?
+            [&](int32_t il) { return !model.hparams.is_recurrent(il); }
+            : filter_attn,
+        type_k,
+        type_v,
+        v_trans,
+        offload,
+        kv_size,
+        n_seq_max,
+        n_pad,
+        n_swa,
+        swa_type
+    )),
+    mem_recr(new llama_memory_recurrent(
+        model,
+        filter_recr == nullptr ?
+            [&](int32_t il) { return model.hparams.is_recurrent(il); }
+            : filter_recr,
+        type_r,
+        type_s,
+        offload,
+        rs_size,
+        n_seq_max
+    )) {}
+
+llama_memory_state_ptr llama_memory_hybrid::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled) {
+
+    // since this includes a recurrent cache, we cannot use split_simple
+    auto sbatch = llama_sbatch(batch, hparams.n_embd, false);
+
+    // follow the recurrent pattern for creating the ubatch splits
+    std::vector<llama_ubatch> ubatches;
+    while (sbatch.n_tokens > 0) {
+        llama_ubatch ubatch;
+
+        if (embd_pooled) {
+            // Pooled embeddings cannot be split across ubatches (yet)
+            ubatch = sbatch.split_seq(n_ubatch);
+        } else {
+            ubatch = sbatch.split_equal(n_ubatch);
+        }
+
+        ubatches.push_back(ubatch);
+    }
+
+    // prepare the recurrent batches first
+    if (!mem_recr->prepare(ubatches)) {
+        // TODO: will the recurrent cache be in an undefined state at this point?
+        LLAMA_LOG_ERROR("%s: failed to prepare recurrent ubatches\n", __func__);
+        return std::make_unique<llama_memory_hybrid_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
+    }
+
+    // prepare the attention cache
+    auto heads_attn = mem_attn->prepare(ubatches);
+    if (heads_attn.empty()) {
+        LLAMA_LOG_ERROR("%s: failed to prepare attention ubatches\n", __func__);
+        return std::make_unique<llama_memory_hybrid_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
+    }
+
+    return std::make_unique<llama_memory_hybrid_state>(
+        this, std::move(sbatch), std::move(heads_attn), std::move(ubatches));
+}
+
+llama_memory_state_ptr llama_memory_hybrid::init_full() {
+    return std::make_unique<llama_memory_hybrid_state>(this);
+}
+
+llama_memory_state_ptr llama_memory_hybrid::init_update(llama_context * lctx, bool optimize) {
+    return std::make_unique<llama_memory_hybrid_state>(this, lctx, optimize);
+}
+
+bool llama_memory_hybrid::get_can_shift() const {
+    // Shifting is trivially supported for recurrent
+    return mem_attn->get_can_shift();
+}
+
+void llama_memory_hybrid::clear(bool data) {
+    mem_attn->clear(data);
+    mem_recr->clear(data);
+}
+
+bool llama_memory_hybrid::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
+    // Try removing from the recurrent cache first since it may fail. If it does
+    // fail, the cache will not have been mutated.
+    if (!mem_recr->seq_rm(seq_id, p0, p1)) {
+        return false;
+    }
+    return mem_attn->seq_rm(seq_id, p0, p1);
+}
+
+void llama_memory_hybrid::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
+    mem_attn->seq_cp(seq_id_src, seq_id_dst, p0, p1);
+    mem_recr->seq_cp(seq_id_src, seq_id_dst, p0, p1);
+}
+
+void llama_memory_hybrid::seq_keep(llama_seq_id seq_id) {
+    mem_attn->seq_keep(seq_id);
+    mem_recr->seq_keep(seq_id);
+}
+
+void llama_memory_hybrid::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
+    mem_attn->seq_add(seq_id, p0, p1, shift);
+    mem_recr->seq_add(seq_id, p0, p1, shift);
+}
+
+void llama_memory_hybrid::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
+    mem_attn->seq_div(seq_id, p0, p1, d);
+    mem_recr->seq_div(seq_id, p0, p1, d);
+}
+
+llama_pos llama_memory_hybrid::seq_pos_min(llama_seq_id seq_id) const {
+    // the min of the total cache is the max of the two caches' min values
+    return std::max(mem_attn->seq_pos_min(seq_id), mem_recr->seq_pos_min(seq_id));
+}
+
+llama_pos llama_memory_hybrid::seq_pos_max(llama_seq_id seq_id) const {
+    // the max of the total cache is the min of the two caches' max values
+    return std::min(mem_attn->seq_pos_max(seq_id), mem_recr->seq_pos_max(seq_id));
+}
+
+void llama_memory_hybrid::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
+    mem_attn->state_write(io, seq_id);
+    mem_recr->state_write(io, seq_id);
+}
+
+void llama_memory_hybrid::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
+    mem_attn->state_read(io, seq_id);
+    mem_recr->state_read(io, seq_id);
+}
+
+llama_kv_cache_unified * llama_memory_hybrid::get_mem_attn() const {
+    return mem_attn.get();
+}
+
+llama_memory_recurrent * llama_memory_hybrid::get_mem_recr() const {
+    return mem_recr.get();
+}
+
+llama_memory_hybrid_state::llama_memory_hybrid_state(llama_memory_status status) : status(status) {}
+
+llama_memory_hybrid_state::llama_memory_hybrid_state(llama_memory_hybrid * mem) :
+    state_attn(mem->get_mem_attn()->init_full()),
+    state_recr(mem->get_mem_recr()->init_full()),
+    status(llama_memory_status_combine(state_attn->get_status(), state_recr->get_status())) {
+}
+
+llama_memory_hybrid_state::llama_memory_hybrid_state(
+        llama_memory_hybrid * mem,
+        llama_context * lctx,
+        bool optimize) :
+    state_attn(mem->get_mem_attn()->init_update(lctx, optimize)),
+    state_recr(mem->get_mem_recr()->init_update(lctx, optimize)),
+    status(llama_memory_status_combine(state_attn->get_status(), state_recr->get_status())) {
+}
+
+llama_memory_hybrid_state::llama_memory_hybrid_state(
+        llama_memory_hybrid * mem,
+        llama_sbatch sbatch,
+        std::vector<uint32_t> heads_attn,
+        std::vector<llama_ubatch> ubatches) :
+    sbatch(std::move(sbatch)),
+    ubatches(std::move(ubatches)),
+    // note: here we copy the ubatches. not sure if this is ideal
+    state_attn(new llama_kv_cache_unified_state(mem->get_mem_attn(), {}, std::move(heads_attn), this->ubatches)),
+    state_recr(new llama_memory_recurrent_state(mem->get_mem_recr(), {}, this->ubatches)),
+    status(LLAMA_MEMORY_STATUS_SUCCESS) {
+}
+
+bool llama_memory_hybrid_state::next() {
+    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
+
+    state_attn->next();
+    state_recr->next();
+
+    if (++i_next >= ubatches.size()) {
+        return false;
+    }
+
+    return true;
+}
+
+bool llama_memory_hybrid_state::apply() {
+    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
+
+    bool res = true;
+
+    res = res & state_attn->apply();
+    res = res & state_recr->apply();
+
+    return res;
+}
+
+std::vector<int64_t> & llama_memory_hybrid_state::out_ids() {
+    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
+
+    return sbatch.out_ids;
+}
+
+llama_memory_status llama_memory_hybrid_state::get_status() const {
+    return status;
+}
+
+const llama_ubatch & llama_memory_hybrid_state::get_ubatch() const {
+    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
+    return ubatches[i_next];
+}
+
+const llama_kv_cache_unified_state * llama_memory_hybrid_state::get_state_attn() const {
+    return static_cast<const llama_kv_cache_unified_state *>(state_attn.get());
+}
+
+const llama_memory_recurrent_state * llama_memory_hybrid_state::get_state_recr() const {
+    return static_cast<const llama_memory_recurrent_state *>(state_recr.get());
+}
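One subtlety in seq_pos_min/seq_pos_max above: the hybrid cache reports the intersection of the two sub-caches' valid position ranges, so the minimum takes the larger of the two minima and the maximum takes the smaller of the two maxima. A toy check (numbers invented):

    #include <algorithm>
    #include <cassert>

    int main() {
        // invented ranges: attention cache holds [0, 100], recurrent holds [5, 90]
        const int attn_min = 0, attn_max = 100;
        const int recr_min = 5, recr_max = 90;

        // only positions valid in *both* caches count: the intersection [5, 90]
        const int pos_min = std::max(attn_min, recr_min);
        const int pos_max = std::min(attn_max, recr_max);

        assert(pos_min == 5 && pos_max == 90);
        return 0;
    }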
@@ -0,0 +1,143 @@
+#pragma once
+
+#include "llama-batch.h"
+#include "llama-graph.h"
+#include "llama-kv-cache-unified.h"
+#include "llama-memory.h"
+#include "llama-memory-recurrent.h"
+
+#include <memory>
+#include <vector>
+
+//
+// llama_memory_hybrid
+//
+
+// utilizes instances of llama_memory_recurrent and llama_kv_cache_unified to
+// support models where each layer may be either attention-based or recurrent
+
+class llama_memory_hybrid : public llama_memory_i {
+public:
+
+    // this callback is used to filter out layers that should not be included in the cache
+    using layer_filter_cb = std::function<bool(int32_t il)>;
+
+    llama_memory_hybrid(
+        const llama_model & model,
+        /* attn */
+        ggml_type type_k,
+        ggml_type type_v,
+        bool v_trans,
+        uint32_t kv_size,
+        uint32_t n_pad,
+        uint32_t n_swa,
+        llama_swa_type swa_type,
+        /* recurrent */
+        ggml_type type_r,
+        ggml_type type_s,
+        uint32_t rs_size,
+        /* common */
+        uint32_t n_seq_max,
+        bool offload,
+        /* layer filters */
+        layer_filter_cb && filter_attn = nullptr,
+        layer_filter_cb && filter_recr = nullptr);
+
+    ~llama_memory_hybrid() = default;
+
+    //
+    // llama_memory_i
+    //
+
+    llama_memory_state_ptr init_batch(
+        const llama_batch & batch,
+        uint32_t n_ubatch,
+        bool embd_pooled) override;
+
+    llama_memory_state_ptr init_full() override;
+
+    llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override;
+
+    bool get_can_shift() const override;
+
+    void clear(bool data) override;
+
+    bool seq_rm  (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
+    void seq_cp  (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
+    void seq_keep(llama_seq_id seq_id) override;
+    void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override;
+    void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;
+
+    llama_pos seq_pos_min(llama_seq_id seq_id) const override;
+    llama_pos seq_pos_max(llama_seq_id seq_id) const override;
+
+    // state write/load
+
+    void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
+    void state_read (llama_io_read_i  & io, llama_seq_id seq_id = -1) override;
+
+    //
+    // llama_memory_hybrid specific API
+    //
+
+    llama_kv_cache_unified * get_mem_attn() const;
+    llama_memory_recurrent * get_mem_recr() const;
+
+private:
+    const llama_hparams & hparams;
+
+    const std::unique_ptr<llama_kv_cache_unified> mem_attn;
+    const std::unique_ptr<llama_memory_recurrent> mem_recr;
+};
+
+class llama_memory_hybrid_state : public llama_memory_state_i {
+public:
+    // init failure
+    explicit llama_memory_hybrid_state(llama_memory_status status);
+
+    // init full
+    explicit llama_memory_hybrid_state(llama_memory_hybrid * mem);
+
+    // init update
+    explicit llama_memory_hybrid_state(
+        llama_memory_hybrid * mem,
+        llama_context * lctx,
+        bool optimize);
+
+    // init success
+    llama_memory_hybrid_state(
+        llama_memory_hybrid * mem,
+        llama_sbatch sbatch,
+        std::vector<uint32_t> heads_attn,
+        std::vector<llama_ubatch> ubatches);
+
+    ~llama_memory_hybrid_state() = default;
+
+    bool next()  override;
+    bool apply() override;
+
+    std::vector<int64_t> & out_ids() override;
+
+    llama_memory_status get_status() const override;
+    const llama_ubatch & get_ubatch() const override;
+
+    //
+    // llama_memory_hybrid_state
+    //
+
+    const llama_kv_cache_unified_state * get_state_attn() const;
+    const llama_memory_recurrent_state * get_state_recr() const;
+
+private:
+    llama_sbatch sbatch;
+
+    // the index of the next ubatch to process
+    size_t i_next = 0;
+
+    std::vector<llama_ubatch> ubatches;
+
+    const llama_memory_state_ptr state_attn;
+    const llama_memory_state_ptr state_recr;
+
+    const llama_memory_status status;
+};
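Based on the next()/apply()/get_ubatch() interface declared above, a caller would drive a hybrid state roughly as follows (a sketch; the graph execution is a placeholder and error handling is simplified):

    // Hypothetical consumption loop for a llama_memory_state_i such as
    // llama_memory_hybrid_state: apply the current ubatch to both sub-states,
    // process it, then advance both sub-states in lockstep via next().
    static bool process_all(llama_memory_state_i & st) {
        if (st.get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
            return false;
        }
        do {
            if (!st.apply()) {
                return false;
            }
            const llama_ubatch & ub = st.get_ubatch();
            (void) ub; // ... build and run the graph for `ub` here ...
        } while (st.next());
        return true;
    }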
@ -1,4 +1,4 @@
|
|||
#include "llama-kv-cache-recurrent.h"
|
||||
#include "llama-memory-recurrent.h"
|
||||
|
||||
#include "llama-impl.h"
|
||||
#include "llama-io.h"
|
||||
|
|
@ -12,27 +12,28 @@
|
|||
#include <stdexcept>
|
||||
|
||||
//
|
||||
// llama_kv_cache_recurrent
|
||||
// llama_memory_recurrent
|
||||
//
|
||||
|
||||
llama_kv_cache_recurrent::llama_kv_cache_recurrent(
|
||||
llama_memory_recurrent::llama_memory_recurrent(
|
||||
const llama_model & model,
|
||||
ggml_type type_k,
|
||||
ggml_type type_v,
|
||||
layer_filter_cb && filter,
|
||||
ggml_type type_r,
|
||||
ggml_type type_s,
|
||||
bool offload,
|
||||
uint32_t kv_size,
|
||||
uint32_t mem_size,
|
||||
uint32_t n_seq_max) : hparams(model.hparams), n_seq_max(n_seq_max) {
|
||||
const int32_t n_layer = hparams.n_layer;
|
||||
|
||||
LLAMA_LOG_INFO("%s: kv_size = %u, n_seq_max = %u, type_k = '%s', type_v = '%s', n_layer = %d\n",
|
||||
__func__, kv_size, n_seq_max, ggml_type_name(type_k), ggml_type_name(type_v), n_layer);
|
||||
LLAMA_LOG_INFO("%s: mem_size = %u, n_seq_max = %u, type_r = '%s', type_s = '%s', n_layer = %d\n",
|
||||
__func__, mem_size, n_seq_max, ggml_type_name(type_r), ggml_type_name(type_s), n_layer);
|
||||
|
||||
head = 0;
|
||||
size = kv_size;
|
||||
size = mem_size;
|
||||
used = 0;
|
||||
|
||||
cells.clear();
|
||||
cells.resize(kv_size);
|
||||
cells.resize(mem_size);
|
||||
|
||||
// create a context for each buffer type
|
||||
std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
|
||||
|
|
@ -59,12 +60,14 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent(
|
|||
return it->second;
|
||||
};
|
||||
|
||||
k_l.reserve(n_layer);
|
||||
v_l.reserve(n_layer);
|
||||
r_l.resize(n_layer);
|
||||
s_l.resize(n_layer);
|
||||
|
||||
for (int i = 0; i < n_layer; i++) {
|
||||
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
|
||||
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();
|
||||
if (filter && !filter(i)) {
|
||||
LLAMA_LOG_DEBUG("%s: layer %3d: skipped\n", __func__, i);
|
||||
continue;
|
||||
}
|
||||
|
||||
const char * dev_name = "CPU";
|
||||
|
||||
|
|
@ -84,12 +87,12 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent(
|
|||
throw std::runtime_error("failed to create ggml context for kv cache");
|
||||
}
|
||||
|
||||
ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size);
|
||||
ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size);
|
||||
ggml_format_name(k, "cache_k_l%d", i);
|
||||
ggml_format_name(v, "cache_v_l%d", i);
|
||||
k_l.push_back(k);
|
||||
v_l.push_back(v);
|
||||
ggml_tensor * r = ggml_new_tensor_1d(ctx, type_r, hparams.n_embd_r()*mem_size);
|
||||
ggml_tensor * s = ggml_new_tensor_1d(ctx, type_s, hparams.n_embd_s()*mem_size);
|
||||
ggml_format_name(r, "cache_r_l%d", i);
|
||||
ggml_format_name(s, "cache_s_l%d", i);
|
||||
r_l[i] = r;
|
||||
s_l[i] = s;
|
||||
}
|
||||
|
||||
// allocate tensors and initialize the buffers to avoid NaNs in the padding
|
||||
|
|
@ -107,17 +110,17 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent(
|
|||
}
|
||||
|
||||
{
|
||||
const size_t memory_size_k = size_k_bytes();
|
||||
const size_t memory_size_v = size_v_bytes();
|
||||
const size_t memory_size_r = size_r_bytes();
|
||||
const size_t memory_size_s = size_s_bytes();
|
||||
|
||||
LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
|
||||
(float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f),
|
||||
ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
|
||||
ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
|
||||
LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, R (%s): %7.2f MiB, S (%s): %7.2f MiB\n", __func__,
|
||||
(float)(memory_size_r + memory_size_s) / (1024.0f * 1024.0f),
|
||||
ggml_type_name(type_r), (float)memory_size_r / (1024.0f * 1024.0f),
|
||||
ggml_type_name(type_s), (float)memory_size_s / (1024.0f * 1024.0f));
|
||||
}
|
||||
}
|
||||
|
||||
void llama_kv_cache_recurrent::clear(bool data) {
|
||||
void llama_memory_recurrent::clear(bool data) {
|
||||
for (int32_t i = 0; i < (int32_t) size; ++i) {
|
||||
cells[i].pos = -1;
|
||||
cells[i].seq_id.clear();
|
||||
|
|
@ -135,7 +138,7 @@ void llama_kv_cache_recurrent::clear(bool data) {
|
|||
}
|
||||
}
|
||||
|
||||
bool llama_kv_cache_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
|
||||
bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
|
||||
uint32_t new_head = size;
|
||||
|
||||
if (p0 < 0) {
|
||||
|
|
@ -154,7 +157,7 @@ bool llama_kv_cache_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_p
|
|||
if (0 <= seq_id) {
|
||||
int32_t & tail_id = cells[seq_id].tail;
|
||||
if (tail_id >= 0) {
|
||||
const kv_cell & cell = cells[tail_id];
|
||||
const auto & cell = cells[tail_id];
|
||||
// partial intersection is invalid
|
||||
if ((0 < p0 && p0 <= cell.pos) || (0 < p1 && p1 <= cell.pos)) {
|
||||
return false;
|
||||
|
|
@ -202,7 +205,7 @@ bool llama_kv_cache_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_p
|
|||
return true;
|
||||
}
|
||||
|
||||
void llama_kv_cache_recurrent::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
|
||||
void llama_memory_recurrent::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
|
||||
if (seq_id_src == seq_id_dst) {
|
||||
return;
|
||||
}
|
||||
|
|
@ -216,11 +219,11 @@ void llama_kv_cache_recurrent::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_
|
|||
}
|
||||
|
||||
if ((uint32_t) seq_id_dst < size && (uint32_t) seq_id_src < size) {
|
||||
kv_cell & tail_src = cells[seq_id_src];
|
||||
kv_cell & tail_dst = cells[seq_id_dst];
|
||||
auto & tail_src = cells[seq_id_src];
|
||||
auto & tail_dst = cells[seq_id_dst];
|
||||
if (tail_dst.tail >= 0) {
|
||||
// clear destination seq_id if it wasn't empty
|
||||
kv_cell & cell_dst = cells[tail_dst.tail];
|
||||
auto & cell_dst = cells[tail_dst.tail];
|
||||
|
||||
cell_dst.seq_id.erase(seq_id_dst);
|
||||
tail_dst.tail = -1;
|
||||
|
|
@ -231,7 +234,7 @@ void llama_kv_cache_recurrent::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_
|
|||
}
|
||||
}
|
||||
if (tail_src.tail >= 0) {
|
||||
kv_cell & cell_src = cells[tail_src.tail];
|
||||
auto & cell_src = cells[tail_src.tail];
|
||||
|
||||
cell_src.seq_id.insert(seq_id_dst);
|
||||
tail_dst.tail = tail_src.tail;
|
||||
|
|
@@ -239,7 +242,7 @@ void llama_kv_cache_recurrent::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_
}
}

void llama_kv_cache_recurrent::seq_keep(llama_seq_id seq_id) {
void llama_memory_recurrent::seq_keep(llama_seq_id seq_id) {
uint32_t new_head = size;

for (uint32_t i = 0; i < size; ++i) {
@@ -271,7 +274,7 @@ void llama_kv_cache_recurrent::seq_keep(llama_seq_id seq_id) {
}
}

void llama_kv_cache_recurrent::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
void llama_memory_recurrent::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
if (shift == 0) {
return;
}
@@ -293,7 +296,7 @@ void llama_kv_cache_recurrent::seq_add(llama_seq_id seq_id, llama_pos p0, llama_
if (0 <= seq_id && seq_id < (int64_t) size) {
const int32_t tail_id = cells[seq_id].tail;
if (tail_id >= 0) {
kv_cell & cell = cells[tail_id];
auto & cell = cells[tail_id];
if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
cell.pos += shift;
}
@@ -301,7 +304,7 @@ void llama_kv_cache_recurrent::seq_add(llama_seq_id seq_id, llama_pos p0, llama_
}
}

void llama_kv_cache_recurrent::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
void llama_memory_recurrent::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
if (d == 1) {
return;
}
@ -323,7 +326,7 @@ void llama_kv_cache_recurrent::seq_div(llama_seq_id seq_id, llama_pos p0, llama_
|
|||
if (0 <= seq_id && seq_id < (int64_t) size) {
|
||||
const int32_t tail_id = cells[seq_id].tail;
|
||||
if (tail_id >= 0) {
|
||||
kv_cell & cell = cells[tail_id];
|
||||
auto & cell = cells[tail_id];
|
||||
if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
|
||||
cell.pos /= d;
|
||||
}
|
||||
|
|
@ -331,7 +334,7 @@ void llama_kv_cache_recurrent::seq_div(llama_seq_id seq_id, llama_pos p0, llama_
|
|||
}
|
||||
}
|
||||
|
||||
llama_pos llama_kv_cache_recurrent::seq_pos_min(llama_seq_id seq_id) const {
|
||||
llama_pos llama_memory_recurrent::seq_pos_min(llama_seq_id seq_id) const {
|
||||
llama_pos result = std::numeric_limits<llama_pos>::max();
|
||||
|
||||
for (uint32_t i = 0; i < size; ++i) {
|
||||
|
|
@ -347,7 +350,7 @@ llama_pos llama_kv_cache_recurrent::seq_pos_min(llama_seq_id seq_id) const {
|
|||
return result;
|
||||
}
|
||||
|
||||
llama_pos llama_kv_cache_recurrent::seq_pos_max(llama_seq_id seq_id) const {
|
||||
llama_pos llama_memory_recurrent::seq_pos_max(llama_seq_id seq_id) const {
|
||||
llama_pos result = -1;
|
||||
|
||||
for (uint32_t i = 0; i < size; ++i) {
|
||||
|
|
@ -359,7 +362,7 @@ llama_pos llama_kv_cache_recurrent::seq_pos_max(llama_seq_id seq_id) const {
|
|||
return result;
|
||||
}
|
||||
|
||||
llama_memory_state_ptr llama_kv_cache_recurrent::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_all) {
|
||||
llama_memory_state_ptr llama_memory_recurrent::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_all) {
|
||||
auto sbatch = llama_sbatch(batch, hparams.n_embd, false);
|
||||
|
||||
std::vector<llama_ubatch> ubatches;
|
||||
|
|
@ -378,24 +381,24 @@ llama_memory_state_ptr llama_kv_cache_recurrent::init_batch(const llama_batch &
|
|||
}
|
||||
|
||||
if (!prepare(ubatches)) {
|
||||
return std::make_unique<llama_kv_cache_recurrent_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
|
||||
return std::make_unique<llama_memory_recurrent_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
|
||||
}
|
||||
|
||||
return std::make_unique<llama_kv_cache_recurrent_state>(LLAMA_MEMORY_STATUS_SUCCESS, this, std::move(sbatch), std::move(ubatches));
|
||||
return std::make_unique<llama_memory_recurrent_state>(this, std::move(sbatch), std::move(ubatches));
|
||||
}
|
||||
|
||||
llama_memory_state_ptr llama_kv_cache_recurrent::init_full() {
|
||||
return std::make_unique<llama_kv_cache_recurrent_state>(LLAMA_MEMORY_STATUS_SUCCESS, this);
|
||||
llama_memory_state_ptr llama_memory_recurrent::init_full() {
|
||||
return std::make_unique<llama_memory_recurrent_state>(this);
|
||||
}
|
||||
|
||||
llama_memory_state_ptr llama_kv_cache_recurrent::init_update(llama_context * lctx, bool optimize) {
|
||||
llama_memory_state_ptr llama_memory_recurrent::init_update(llama_context * lctx, bool optimize) {
|
||||
GGML_UNUSED(lctx);
|
||||
GGML_UNUSED(optimize);
|
||||
|
||||
return std::make_unique<llama_kv_cache_recurrent_state>(LLAMA_MEMORY_STATUS_NO_UPDATE);
|
||||
return std::make_unique<llama_memory_recurrent_state>(LLAMA_MEMORY_STATUS_NO_UPDATE);
|
||||
}
|
||||
|
||||
bool llama_kv_cache_recurrent::prepare(const std::vector<llama_ubatch> & ubatches) {
|
||||
bool llama_memory_recurrent::prepare(const std::vector<llama_ubatch> & ubatches) {
|
||||
// simply remember the full state because it is very small for this type of cache
|
||||
// TODO: optimize
|
||||
auto org_cells = cells;
|
||||
|
|
@ -419,7 +422,7 @@ bool llama_kv_cache_recurrent::prepare(const std::vector<llama_ubatch> & ubatche
|
|||
return success;
|
||||
}
|
||||
|
||||
bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
|
||||
bool llama_memory_recurrent::find_slot(const llama_ubatch & ubatch) {
|
||||
const uint32_t n_seqs = ubatch.n_seqs;
|
||||
|
||||
const uint32_t n_seq_tokens = ubatch.n_seq_tokens;
|
||||
|
|
@ -453,9 +456,9 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
|
|||
return false;
|
||||
}
|
||||
if (j > 0) {
|
||||
kv_cell & seq = cells[seq_id];
|
||||
auto & seq = cells[seq_id];
|
||||
if (seq.tail >= 0) {
|
||||
kv_cell & cell = cells[seq.tail];
|
||||
auto & cell = cells[seq.tail];
|
||||
// clear cells from seq_ids that become shared
|
||||
// (should not normally happen, but let's handle it anyway)
|
||||
cell.seq_id.erase(seq_id);
|
||||
|
|
@ -475,7 +478,7 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
|
|||
std::vector<int32_t> tails_verif;
|
||||
tails_verif.assign(size, -1);
|
||||
for (uint32_t i = 0; i < size; ++i) {
|
||||
kv_cell & cell = cells[i];
|
||||
auto & cell = cells[i];
|
||||
for (llama_seq_id seq_id : cell.seq_id) {
|
||||
if (tails_verif[seq_id] != -1) {
|
||||
LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tails_verif[seq_id]);
|
||||
|
|
@ -496,7 +499,7 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
|
|||
|
||||
for (uint32_t i = 0; i < size; ++i) {
|
||||
if (next_empty_cell >= size) { next_empty_cell -= size; }
|
||||
kv_cell & cell = cells[next_empty_cell];
|
||||
auto & cell = cells[next_empty_cell];
|
||||
if (cell.is_empty()) { break; }
|
||||
next_empty_cell += 1;
|
||||
}
|
||||
|
|
@ -504,20 +507,20 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
|
|||
// find usable cell range
|
||||
for (uint32_t s = 0; s < n_seqs; ++s) {
|
||||
const llama_seq_id seq_id = ubatch.seq_id[s][0];
|
||||
kv_cell & seq_meta = cells[seq_id];
|
||||
auto & seq_meta = cells[seq_id];
|
||||
bool has_cell = false;
|
||||
if (seq_meta.tail >= 0) {
|
||||
kv_cell & cell = cells[seq_meta.tail];
|
||||
auto & cell = cells[seq_meta.tail];
|
||||
GGML_ASSERT(cell.has_seq_id(seq_id));
|
||||
// does this seq_id "own" the cell?
|
||||
if (cell.seq_id.size() == 1) { has_cell = true; }
|
||||
}
|
||||
if (!has_cell) {
|
||||
kv_cell & empty_cell = cells[next_empty_cell];
|
||||
auto & empty_cell = cells[next_empty_cell];
|
||||
GGML_ASSERT(empty_cell.is_empty());
|
||||
// copy old tail into the empty cell
|
||||
if (seq_meta.tail >= 0) {
|
||||
kv_cell & orig_cell = cells[seq_meta.tail];
|
||||
auto & orig_cell = cells[seq_meta.tail];
|
||||
empty_cell.pos = orig_cell.pos;
|
||||
empty_cell.src = orig_cell.src;
|
||||
orig_cell.seq_id.erase(seq_id);
|
||||
|
|
@ -530,7 +533,7 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
|
|||
for (uint32_t i = 0; i < size; ++i) {
|
||||
next_empty_cell += 1;
|
||||
if (next_empty_cell >= size) { next_empty_cell -= size; }
|
||||
kv_cell & cell = cells[next_empty_cell];
|
||||
auto & cell = cells[next_empty_cell];
|
||||
if (cell.is_empty()) { break; }
|
||||
}
|
||||
}
|
||||
|
|
@ -544,8 +547,8 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
|
|||
const int32_t dst_id = s + min;
|
||||
const int32_t src_id = cells[ubatch.seq_id[s][0]].tail;
|
||||
if (dst_id != src_id) {
|
||||
kv_cell & dst_cell = cells[dst_id];
|
||||
kv_cell & src_cell = cells[src_id];
|
||||
auto & dst_cell = cells[dst_id];
|
||||
auto & src_cell = cells[src_id];
|
||||
|
||||
std::swap(dst_cell.pos, src_cell.pos);
|
||||
std::swap(dst_cell.src, src_cell.src);
|
||||
|
|
@ -567,7 +570,7 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
|
|||
for (uint32_t s = 0; s < n_seqs; ++s) {
|
||||
const llama_pos last_pos = ubatch.pos[n_seq_tokens * s + n_seq_tokens - 1];
|
||||
const int32_t cell_id = s + min;
|
||||
kv_cell & cell = cells[cell_id];
|
||||
auto & cell = cells[cell_id];
|
||||
|
||||
if (cell.pos >= 0 && last_pos != cell.pos + (llama_pos) n_seq_tokens) {
|
||||
// What should happen when the pos backtracks or skips a value?
|
||||
|
|
@ -620,18 +623,18 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
|
|||
head = min;
|
||||
n = max - min + 1;
|
||||
used = std::count_if(cells.begin(), cells.end(),
|
||||
[](const kv_cell & cell){ return !cell.is_empty(); });
|
||||
[](const mem_cell & cell){ return !cell.is_empty(); });
|
||||
|
||||
// sanity check
|
||||
return n >= n_seqs;
|
||||
}
|
||||
|
||||
bool llama_kv_cache_recurrent::get_can_shift() const {
|
||||
bool llama_memory_recurrent::get_can_shift() const {
|
||||
// shifting the pos is trivial for recurrent models
|
||||
return true;
|
||||
}
|
||||
|
||||
size_t llama_kv_cache_recurrent::total_size() const {
|
||||
size_t llama_memory_recurrent::total_size() const {
|
||||
size_t size = 0;
|
||||
for (const auto & buf : bufs) {
|
||||
size += ggml_backend_buffer_get_size(buf.get());
|
||||
|
|
@@ -640,27 +643,31 @@ size_t llama_kv_cache_recurrent::total_size() const {
return size;
}

size_t llama_kv_cache_recurrent::size_k_bytes() const {
size_t size_k_bytes = 0;
size_t llama_memory_recurrent::size_r_bytes() const {
size_t size_r_bytes = 0;

for (const auto & k : k_l) {
size_k_bytes += ggml_nbytes(k);
for (const auto & r : r_l) {
if (r != nullptr) {
size_r_bytes += ggml_nbytes(r);
}
}

return size_k_bytes;
return size_r_bytes;
}

size_t llama_kv_cache_recurrent::size_v_bytes() const {
size_t size_v_bytes = 0;
size_t llama_memory_recurrent::size_s_bytes() const {
size_t size_s_bytes = 0;

for (const auto & v : v_l) {
size_v_bytes += ggml_nbytes(v);
for (const auto & s : s_l) {
if (s != nullptr) {
size_s_bytes += ggml_nbytes(s);
}
}

return size_v_bytes;
return size_s_bytes;
}

void llama_kv_cache_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
void llama_memory_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
std::vector<std::pair<uint32_t, uint32_t>> cell_ranges; // ranges, from inclusive, to exclusive
uint32_t cell_count = 0;
@ -698,7 +705,7 @@ void llama_kv_cache_recurrent::state_write(llama_io_write_i & io, llama_seq_id s
|
|||
state_write_data(io, cell_ranges);
|
||||
}
|
||||
|
||||
void llama_kv_cache_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
|
||||
void llama_memory_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
|
||||
uint32_t cell_count;
|
||||
io.read_to(&cell_count, sizeof(cell_count));
|
||||
|
||||
|
|
@ -717,7 +724,7 @@ void llama_kv_cache_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq
|
|||
}
|
||||
}
|
||||
|
||||
void llama_kv_cache_recurrent::state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id) const {
|
||||
void llama_memory_recurrent::state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id) const {
|
||||
for (const auto & range : cell_ranges) {
|
||||
for (uint32_t i = range.first; i < range.second; ++i) {
|
||||
const auto & cell = cells[i];
|
||||
|
|
@ -736,11 +743,11 @@ void llama_kv_cache_recurrent::state_write_meta(llama_io_write_i & io, const std
|
|||
}
|
||||
}
|
||||
|
||||
void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const {
|
||||
const uint32_t v_trans = 0;
|
||||
void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const {
|
||||
const uint32_t s_trans = 0;
|
||||
const uint32_t n_layer = hparams.n_layer;
|
||||
|
||||
io.write(&v_trans, sizeof(v_trans));
|
||||
io.write(&s_trans, sizeof(s_trans));
|
||||
io.write(&n_layer, sizeof(n_layer));
|
||||
|
||||
std::vector<uint8_t> tmp_buf;
|
||||
|
|
@ -748,75 +755,73 @@ void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std
|
|||
// Iterate and write all the keys first, each row is a cell
|
||||
// Get whole range at a time
|
||||
for (uint32_t il = 0; il < n_layer; ++il) {
|
||||
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
|
||||
|
||||
// Write key type
|
||||
const int32_t k_type_i = (int32_t)k_l[il]->type;
|
||||
io.write(&k_type_i, sizeof(k_type_i));
|
||||
const int32_t r_type_i = (int32_t)r_l[il]->type;
|
||||
io.write(&r_type_i, sizeof(r_type_i));
|
||||
|
||||
// Write row size of key
|
||||
const uint64_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa);
|
||||
io.write(&k_size_row, sizeof(k_size_row));
|
||||
const uint64_t r_size_row = ggml_row_size(r_l[il]->type, hparams.n_embd_r());
|
||||
io.write(&r_size_row, sizeof(r_size_row));
|
||||
|
||||
// Read each range of cells of k_size length each into tmp_buf and write out
|
||||
for (const auto & range : cell_ranges) {
|
||||
const size_t range_size = range.second - range.first;
|
||||
const size_t buf_size = range_size * k_size_row;
|
||||
io.write_tensor(k_l[il], range.first * k_size_row, buf_size);
|
||||
const size_t buf_size = range_size * r_size_row;
|
||||
io.write_tensor(r_l[il], range.first * r_size_row, buf_size);
|
||||
}
|
||||
}
|
||||
|
||||
if (!v_trans) {
|
||||
if (!s_trans) {
|
||||
for (uint32_t il = 0; il < n_layer; ++il) {
|
||||
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
|
||||
|
||||
// Write value type
|
||||
const int32_t v_type_i = (int32_t)v_l[il]->type;
|
||||
io.write(&v_type_i, sizeof(v_type_i));
|
||||
const int32_t s_type_i = (int32_t)s_l[il]->type;
|
||||
io.write(&s_type_i, sizeof(s_type_i));
|
||||
|
||||
// Write row size of value
|
||||
const uint64_t v_size_row = ggml_row_size(v_l[il]->type, n_embd_v_gqa);
|
||||
io.write(&v_size_row, sizeof(v_size_row));
|
||||
const uint64_t s_size_row = ggml_row_size(s_l[il]->type, hparams.n_embd_s());
|
||||
io.write(&s_size_row, sizeof(s_size_row));
|
||||
|
||||
// Read each range of cells of v_size length each into tmp_buf and write out
|
||||
// Read each range of cells of s_size length each into tmp_buf and write out
|
||||
for (const auto & range : cell_ranges) {
|
||||
const size_t range_size = range.second - range.first;
|
||||
const size_t buf_size = range_size * v_size_row;
|
||||
io.write_tensor(v_l[il], range.first * v_size_row, buf_size);
|
||||
const size_t buf_size = range_size * s_size_row;
|
||||
io.write_tensor(s_l[il], range.first * s_size_row, buf_size);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// When v is transposed, we also need the element size and get the element ranges from each row
|
||||
const uint32_t kv_size = size;
|
||||
const uint32_t mem_size = size;
|
||||
for (uint32_t il = 0; il < n_layer; ++il) {
|
||||
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
|
||||
const uint32_t n_embd_s = hparams.n_embd_s();
|
||||
|
||||
// Write value type
|
||||
const int32_t v_type_i = (int32_t)v_l[il]->type;
|
||||
io.write(&v_type_i, sizeof(v_type_i));
|
||||
const int32_t s_type_i = (int32_t)s_l[il]->type;
|
||||
io.write(&s_type_i, sizeof(s_type_i));
|
||||
|
||||
// Write element size
|
||||
const uint32_t v_size_el = ggml_type_size(v_l[il]->type);
|
||||
io.write(&v_size_el, sizeof(v_size_el));
|
||||
const uint32_t s_size_el = ggml_type_size(s_l[il]->type);
|
||||
io.write(&s_size_el, sizeof(s_size_el));
|
||||
|
||||
// Write GQA embedding size
|
||||
io.write(&n_embd_v_gqa, sizeof(n_embd_v_gqa));
|
||||
io.write(&n_embd_s, sizeof(n_embd_s));
|
||||
|
||||
// For each row, we get the element values of each cell
|
||||
for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
|
||||
for (uint32_t j = 0; j < n_embd_s; ++j) {
|
||||
// Read each range of cells of v_size_el length each into tmp_buf and write out
|
||||
for (const auto & range : cell_ranges) {
|
||||
const size_t range_size = range.second - range.first;
|
||||
const size_t src_offset = (range.first + j * kv_size) * v_size_el;
|
||||
const size_t buf_size = range_size * v_size_el;
|
||||
io.write_tensor(v_l[il], src_offset, buf_size);
|
||||
const size_t src_offset = (range.first + j * mem_size) * s_size_el;
|
||||
const size_t buf_size = range_size * s_size_el;
|
||||
io.write_tensor(s_l[il], src_offset, buf_size);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool llama_kv_cache_recurrent::state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id) {
|
||||
bool llama_memory_recurrent::state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id) {
|
||||
if (dest_seq_id != -1) {
|
||||
// single sequence
|
||||
|
||||
|
|
@ -869,7 +874,7 @@ bool llama_kv_cache_recurrent::state_read_meta(llama_io_read_i & io, uint32_t ce
|
|||
clear(true);
|
||||
|
||||
for (uint32_t i = 0; i < cell_count; ++i) {
|
||||
kv_cell & cell = cells[i];
|
||||
auto & cell = cells[i];
|
||||
|
||||
llama_pos pos;
|
||||
uint32_t n_seq_id;
|
||||
|
|
@ -883,7 +888,7 @@ bool llama_kv_cache_recurrent::state_read_meta(llama_io_read_i & io, uint32_t ce
|
|||
llama_seq_id seq_id;
|
||||
io.read_to(&seq_id, sizeof(seq_id));
|
||||
|
||||
// TODO: llama_kv_cache_recurrent should have a notion of max sequences
|
||||
// TODO: llama_memory_recurrent should have a notion of max sequences
|
||||
//if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) {
|
||||
if (seq_id < 0) {
|
||||
//LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, llama_n_seq_max(ctx));
|
||||
|
|
@ -915,10 +920,10 @@ bool llama_kv_cache_recurrent::state_read_meta(llama_io_read_i & io, uint32_t ce
|
|||
return true;
|
||||
}
|
||||
|
||||
bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell_count) {
|
||||
uint32_t v_trans;
|
||||
bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell_count) {
|
||||
uint32_t s_trans;
|
||||
uint32_t n_layer;
|
||||
io.read_to(&v_trans, sizeof(v_trans));
|
||||
io.read_to(&s_trans, sizeof(s_trans));
|
||||
io.read_to(&n_layer, sizeof(n_layer));
|
||||
|
||||
if (n_layer != hparams.n_layer) {
|
||||
|
|
@ -929,102 +934,100 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
|
|||
LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state (%u > %u)\n", __func__, cell_count, size);
|
||||
return false;
|
||||
}
|
||||
if (false != (bool) v_trans) {
|
||||
LLAMA_LOG_ERROR("%s: incompatible V transposition\n", __func__);
|
||||
if (false != (bool) s_trans) {
|
||||
LLAMA_LOG_ERROR("%s: incompatible s transposition\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
// For each layer, read the keys for each cell, one row is one cell, read as one contiguous block
|
||||
for (uint32_t il = 0; il < n_layer; ++il) {
|
||||
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
|
||||
|
||||
// Read type of key
|
||||
int32_t k_type_i_ref;
|
||||
io.read_to(&k_type_i_ref, sizeof(k_type_i_ref));
|
||||
const int32_t k_type_i = (int32_t) k_l[il]->type;
|
||||
if (k_type_i != k_type_i_ref) {
|
||||
LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, k_type_i, k_type_i_ref, il);
|
||||
int32_t r_type_i_ref;
|
||||
io.read_to(&r_type_i_ref, sizeof(r_type_i_ref));
|
||||
const int32_t r_type_i = (int32_t) r_l[il]->type;
|
||||
if (r_type_i != r_type_i_ref) {
|
||||
LLAMA_LOG_ERROR("%s: mismatched r type (%d != %d, layer %d)\n", __func__, r_type_i, r_type_i_ref, il);
|
||||
return false;
|
||||
}
|
||||
|
||||
// Read row size of key
|
||||
uint64_t k_size_row_ref;
|
||||
io.read_to(&k_size_row_ref, sizeof(k_size_row_ref));
|
||||
const size_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa);
|
||||
if (k_size_row != k_size_row_ref) {
|
||||
LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, (size_t) k_size_row_ref, il);
|
||||
uint64_t r_size_row_ref;
|
||||
io.read_to(&r_size_row_ref, sizeof(r_size_row_ref));
|
||||
const size_t r_size_row = ggml_row_size(r_l[il]->type, hparams.n_embd_r());
|
||||
if (r_size_row != r_size_row_ref) {
|
||||
LLAMA_LOG_ERROR("%s: mismatched r row size (%zu != %zu, layer %d)\n", __func__, r_size_row, (size_t) r_size_row_ref, il);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (cell_count) {
|
||||
// Read and set the keys for the whole cell range
|
||||
ggml_backend_tensor_set(k_l[il], io.read(cell_count * k_size_row), head * k_size_row, cell_count * k_size_row);
|
||||
ggml_backend_tensor_set(r_l[il], io.read(cell_count * r_size_row), head * r_size_row, cell_count * r_size_row);
|
||||
}
|
||||
}
|
||||
|
||||
if (!v_trans) {
|
||||
if (!s_trans) {
|
||||
for (uint32_t il = 0; il < n_layer; ++il) {
|
||||
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
|
||||
|
||||
// Read type of value
|
||||
int32_t v_type_i_ref;
|
||||
io.read_to(&v_type_i_ref, sizeof(v_type_i_ref));
|
||||
const int32_t v_type_i = (int32_t)v_l[il]->type;
|
||||
if (v_type_i != v_type_i_ref) {
|
||||
LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
|
||||
int32_t s_type_i_ref;
|
||||
io.read_to(&s_type_i_ref, sizeof(s_type_i_ref));
|
||||
const int32_t s_type_i = (int32_t)s_l[il]->type;
|
||||
if (s_type_i != s_type_i_ref) {
|
||||
LLAMA_LOG_ERROR("%s: mismatched s type (%d != %d, layer %d)\n", __func__, s_type_i, s_type_i_ref, il);
|
||||
return false;
|
||||
}
|
||||
|
||||
// Read row size of value
|
||||
uint64_t v_size_row_ref;
|
||||
io.read_to(&v_size_row_ref, sizeof(v_size_row_ref));
|
||||
const size_t v_size_row = ggml_row_size(v_l[il]->type, n_embd_v_gqa);
|
||||
if (v_size_row != v_size_row_ref) {
|
||||
LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, (size_t) v_size_row_ref, il);
|
||||
uint64_t s_size_row_ref;
|
||||
io.read_to(&s_size_row_ref, sizeof(s_size_row_ref));
|
||||
const size_t s_size_row = ggml_row_size(s_l[il]->type, hparams.n_embd_s());
|
||||
if (s_size_row != s_size_row_ref) {
|
||||
LLAMA_LOG_ERROR("%s: mismatched s row size (%zu != %zu, layer %d)\n", __func__, s_size_row, (size_t) s_size_row_ref, il);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (cell_count) {
|
||||
// Read and set the values for the whole cell range
|
||||
ggml_backend_tensor_set(v_l[il], io.read(cell_count * v_size_row), head * v_size_row, cell_count * v_size_row);
|
||||
ggml_backend_tensor_set(s_l[il], io.read(cell_count * s_size_row), head * s_size_row, cell_count * s_size_row);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// For each layer, read the values for each cell (transposed)
|
||||
for (uint32_t il = 0; il < n_layer; ++il) {
|
||||
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
|
||||
const uint32_t n_embd_s = hparams.n_embd_s();
|
||||
|
||||
// Read type of value
|
||||
int32_t v_type_i_ref;
|
||||
io.read_to(&v_type_i_ref, sizeof(v_type_i_ref));
|
||||
const int32_t v_type_i = (int32_t)v_l[il]->type;
|
||||
if (v_type_i != v_type_i_ref) {
|
||||
LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
|
||||
int32_t s_type_i_ref;
|
||||
io.read_to(&s_type_i_ref, sizeof(s_type_i_ref));
|
||||
const int32_t s_type_i = (int32_t)s_l[il]->type;
|
||||
if (s_type_i != s_type_i_ref) {
|
||||
LLAMA_LOG_ERROR("%s: mismatched s type (%d != %d, layer %d)\n", __func__, s_type_i, s_type_i_ref, il);
|
||||
return false;
|
||||
}
|
||||
|
||||
// Read element size of value
|
||||
uint32_t v_size_el_ref;
|
||||
io.read_to(&v_size_el_ref, sizeof(v_size_el_ref));
|
||||
const size_t v_size_el = ggml_type_size(v_l[il]->type);
|
||||
if (v_size_el != v_size_el_ref) {
|
||||
LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, (size_t) v_size_el_ref, il);
|
||||
uint32_t s_size_el_ref;
|
||||
io.read_to(&s_size_el_ref, sizeof(s_size_el_ref));
|
||||
const size_t s_size_el = ggml_type_size(s_l[il]->type);
|
||||
if (s_size_el != s_size_el_ref) {
|
||||
LLAMA_LOG_ERROR("%s: mismatched s element size (%zu != %zu, layer %d)\n", __func__, s_size_el, (size_t) s_size_el_ref, il);
|
||||
return false;
|
||||
}
|
||||
|
||||
// Read GQA embedding size
|
||||
uint32_t n_embd_v_gqa_ref;
|
||||
io.read_to(&n_embd_v_gqa_ref, sizeof(n_embd_v_gqa_ref));
|
||||
if (n_embd_v_gqa != n_embd_v_gqa_ref) {
|
||||
LLAMA_LOG_ERROR("%s: mismatched GQA embedding size (%u != %u, layer %d)\n", __func__, n_embd_v_gqa, n_embd_v_gqa_ref, il);
|
||||
// Read state embedding size
|
||||
uint32_t n_embd_s_ref;
|
||||
io.read_to(&n_embd_s_ref, sizeof(n_embd_s_ref));
|
||||
if (n_embd_s != n_embd_s_ref) {
|
||||
LLAMA_LOG_ERROR("%s: mismatched s embedding size (%u != %u, layer %d)\n", __func__, n_embd_s, n_embd_s_ref, il);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (cell_count) {
|
||||
// For each row in the transposed matrix, read the values for the whole cell range
|
||||
for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
|
||||
const size_t dst_offset = (head + j * size) * v_size_el;
|
||||
ggml_backend_tensor_set(v_l[il], io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el);
|
||||
for (uint32_t j = 0; j < n_embd_s; ++j) {
|
||||
const size_t dst_offset = (head + j * size) * s_size_el;
|
||||
ggml_backend_tensor_set(s_l[il], io.read(cell_count * s_size_el), dst_offset, cell_count * s_size_el);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@@ -1034,25 +1037,23 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
}

//
// llama_kv_cache_recurrent_state
// llama_memory_recurrent_state
//

llama_kv_cache_recurrent_state::llama_kv_cache_recurrent_state(llama_memory_status status) : status(status) {}
llama_memory_recurrent_state::llama_memory_recurrent_state(llama_memory_status status) : status(status) {}

llama_kv_cache_recurrent_state::llama_kv_cache_recurrent_state(
llama_memory_status status,
llama_kv_cache_recurrent * kv) : status(status), kv(kv), is_full(true) {
llama_memory_recurrent_state::llama_memory_recurrent_state(
llama_memory_recurrent * mem) : status(LLAMA_MEMORY_STATUS_SUCCESS), mem(mem), is_full(true) {
}

llama_kv_cache_recurrent_state::llama_kv_cache_recurrent_state(
llama_memory_status status,
llama_kv_cache_recurrent * kv,
llama_memory_recurrent_state::llama_memory_recurrent_state(
llama_memory_recurrent * mem,
llama_sbatch sbatch,
std::vector<llama_ubatch> ubatches) : status(status), kv(kv), sbatch(std::move(sbatch)), ubatches(std::move(ubatches)) {}
std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), mem(mem), sbatch(std::move(sbatch)), ubatches(std::move(ubatches)) {}

llama_kv_cache_recurrent_state::~llama_kv_cache_recurrent_state() = default;
llama_memory_recurrent_state::~llama_memory_recurrent_state() = default;

bool llama_kv_cache_recurrent_state::next() {
bool llama_memory_recurrent_state::next() {
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);

if (++i_next >= ubatches.size()) {
@@ -1062,54 +1063,54 @@ bool llama_kv_cache_recurrent_state::next() {
return true;
}

bool llama_kv_cache_recurrent_state::apply() {
bool llama_memory_recurrent_state::apply() {
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);

kv->find_slot(ubatches[i_next]);
mem->find_slot(ubatches[i_next]);

return true;
}

std::vector<int64_t> & llama_kv_cache_recurrent_state::out_ids() {
std::vector<int64_t> & llama_memory_recurrent_state::out_ids() {
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);

return sbatch.out_ids;
}

llama_memory_status llama_kv_cache_recurrent_state::get_status() const {
llama_memory_status llama_memory_recurrent_state::get_status() const {
return status;
}

const llama_ubatch & llama_kv_cache_recurrent_state::get_ubatch() const {
const llama_ubatch & llama_memory_recurrent_state::get_ubatch() const {
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);

return ubatches[i_next];
}

uint32_t llama_kv_cache_recurrent_state::get_n_kv() const {
return is_full ? kv->size : kv->n;
uint32_t llama_memory_recurrent_state::get_n_rs() const {
return is_full ? mem->size : mem->n;
}

uint32_t llama_kv_cache_recurrent_state::get_head() const {
return is_full ? 0 : kv->head;
uint32_t llama_memory_recurrent_state::get_head() const {
return is_full ? 0 : mem->head;
}

int32_t llama_kv_cache_recurrent_state::get_rs_z() const {
return is_full ? 0 : kv->rs_z;
int32_t llama_memory_recurrent_state::get_rs_z() const {
return is_full ? 0 : mem->rs_z;
}

uint32_t llama_kv_cache_recurrent_state::get_size() const {
return kv->size;
uint32_t llama_memory_recurrent_state::get_size() const {
return mem->size;
}

ggml_tensor * llama_kv_cache_recurrent_state::get_k_l(int32_t il) const {
return kv->k_l[il];
ggml_tensor * llama_memory_recurrent_state::get_r_l(int32_t il) const {
return mem->r_l[il];
}

ggml_tensor * llama_kv_cache_recurrent_state::get_v_l(int32_t il) const {
return kv->v_l[il];
ggml_tensor * llama_memory_recurrent_state::get_s_l(int32_t il) const {
return mem->s_l[il];
}

int32_t llama_kv_cache_recurrent_state::s_copy(int i) const {
return kv->cells[i + kv->head].src0;
int32_t llama_memory_recurrent_state::s_copy(int i) const {
return mem->cells[i + mem->head].src0;
}
@@ -8,22 +8,27 @@
#include <vector>

//
// llama_kv_cache_recurrent
// llama_memory_recurrent
//

// TODO: extract the KV cache state used for graph computation into llama_kv_cache_recurrent_state_i
// TODO: extract the cache state used for graph computation into llama_memory_recurrent_state_i
// see the implementation of llama_kv_cache_unified_state_i for an example how to do it
class llama_kv_cache_recurrent : public llama_memory_i {
class llama_memory_recurrent : public llama_memory_i {
public:
llama_kv_cache_recurrent(

// this callback is used to filter out layers that should not be included in the cache
using layer_filter_cb = std::function<bool(int32_t il)>;

llama_memory_recurrent(
const llama_model & model,
ggml_type type_k,
ggml_type type_v,
layer_filter_cb && filter,
ggml_type type_r,
ggml_type type_s,
bool offload,
uint32_t kv_size,
uint32_t mem_size,
uint32_t n_seq_max);

~llama_kv_cache_recurrent() = default;
~llama_memory_recurrent() = default;

//
// llama_memory_i
@ -51,7 +56,7 @@ public:
|
|||
|
||||
bool prepare(const std::vector<llama_ubatch> & ubatches);
|
||||
|
||||
// find a contiguous slot of kv cells and emplace the ubatch there
|
||||
// find a contiguous slot of memory cells and emplace the ubatch there
|
||||
bool find_slot(const llama_ubatch & ubatch);
|
||||
|
||||
bool get_can_shift() const override;
|
||||
|
|
@ -72,7 +77,7 @@ public:
|
|||
int32_t rs_z = -1;
|
||||
|
||||
// TODO: optimize for recurrent state needs
|
||||
struct kv_cell {
|
||||
struct mem_cell {
|
||||
llama_pos pos = -1;
|
||||
int32_t src = -1; // used to know where states should be copied from
|
||||
int32_t src0 = -1; // like src, but only used when setting the inputs (allowing to copy once)
|
||||
|
|
@@ -88,15 +93,16 @@ public:
return seq_id.empty();
}

bool is_same_seq(const kv_cell & other) const {
bool is_same_seq(const mem_cell & other) const {
return seq_id == other.seq_id;
}
};

std::vector<kv_cell> cells;
std::vector<mem_cell> cells;

std::vector<ggml_tensor *> k_l; // per layer
std::vector<ggml_tensor *> v_l;
// per layer
std::vector<ggml_tensor *> r_l;
std::vector<ggml_tensor *> s_l;

private:
//const llama_model & model;
@@ -109,8 +115,8 @@ private:

size_t total_size() const;

size_t size_k_bytes() const;
size_t size_v_bytes() const;
size_t size_r_bytes() const;
size_t size_s_bytes() const;

void state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) const;
void state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const;
@ -119,24 +125,22 @@ private:
|
|||
bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
|
||||
};
|
||||
|
||||
class llama_kv_cache_recurrent_state : public llama_memory_state_i {
|
||||
class llama_memory_recurrent_state : public llama_memory_state_i {
|
||||
public:
|
||||
// used for errors
|
||||
llama_kv_cache_recurrent_state(llama_memory_status status);
|
||||
llama_memory_recurrent_state(llama_memory_status status);
|
||||
|
||||
// used to create a full-cache state
|
||||
llama_kv_cache_recurrent_state(
|
||||
llama_memory_status status,
|
||||
llama_kv_cache_recurrent * kv);
|
||||
llama_memory_recurrent_state(
|
||||
llama_memory_recurrent * mem);
|
||||
|
||||
// used to create a state from a batch
|
||||
llama_kv_cache_recurrent_state(
|
||||
llama_memory_status status,
|
||||
llama_kv_cache_recurrent * kv,
|
||||
llama_memory_recurrent_state(
|
||||
llama_memory_recurrent * mem,
|
||||
llama_sbatch sbatch,
|
||||
std::vector<llama_ubatch> ubatches);
|
||||
|
||||
virtual ~llama_kv_cache_recurrent_state();
|
||||
virtual ~llama_memory_recurrent_state();
|
||||
|
||||
//
|
||||
// llama_memory_state_i
|
||||
|
|
@ -151,23 +155,23 @@ public:
|
|||
const llama_ubatch & get_ubatch() const override;
|
||||
|
||||
//
|
||||
// llama_kv_cache_recurrent_state specific API
|
||||
// llama_memory_recurrent_state specific API
|
||||
//
|
||||
|
||||
uint32_t get_n_kv() const;
|
||||
uint32_t get_n_rs() const;
|
||||
uint32_t get_head() const;
|
||||
int32_t get_rs_z() const;
|
||||
uint32_t get_size() const;
|
||||
|
||||
ggml_tensor * get_k_l(int32_t il) const;
|
||||
ggml_tensor * get_v_l(int32_t il) const;
|
||||
ggml_tensor * get_r_l(int32_t il) const;
|
||||
ggml_tensor * get_s_l(int32_t il) const;
|
||||
|
||||
int32_t s_copy(int i) const;
|
||||
|
||||
private:
|
||||
const llama_memory_status status;
|
||||
|
||||
llama_kv_cache_recurrent * kv;
|
||||
llama_memory_recurrent * mem;
|
||||
|
||||
llama_sbatch sbatch;
|
||||
|
||||
|
|
@ -8,7 +8,8 @@
|
|||
|
||||
#include "llama-kv-cache-unified.h"
|
||||
#include "llama-kv-cache-unified-iswa.h"
|
||||
#include "llama-kv-cache-recurrent.h"
|
||||
#include "llama-memory-hybrid.h"
|
||||
#include "llama-memory-recurrent.h"
|
||||
|
||||
#include "ggml-cpp.h"
|
||||
|
||||
|
|
@ -474,6 +475,10 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
|||
std::fill(hparams.n_head_arr.begin(), hparams.n_head_arr.end(), 0);
|
||||
std::fill(hparams.n_head_kv_arr.begin(), hparams.n_head_kv_arr.end(), 0);
|
||||
std::fill(hparams.n_ff_arr.begin(), hparams.n_ff_arr.end(), 0);
|
||||
std::fill(
|
||||
hparams.recurrent_layer_arr.begin(),
|
||||
hparams.recurrent_layer_arr.end(),
|
||||
llm_arch_is_recurrent(ml.get_arch()));
|
||||
|
||||
std::fill(hparams.rope_sections.begin(), hparams.rope_sections.end(), 0);
|
||||
|
||||
|
|
@ -9199,7 +9204,7 @@ struct llm_build_mamba : public llm_graph_context {
|
|||
// {n_embd, n_tokens}
|
||||
inpL = build_inp_embd(model.tok_embd);
|
||||
|
||||
ggml_tensor * state_copy = build_inp_s_copy();
|
||||
auto * rs_inp = build_rs_inp();
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
// norm
|
||||
|
|
@ -9209,9 +9214,9 @@ struct llm_build_mamba : public llm_graph_context {
|
|||
cb(cur, "attn_norm", il);
|
||||
|
||||
if (model.arch == LLM_ARCH_MAMBA2) {
|
||||
cur = build_mamba2_layer(gf, cur, state_copy, ubatch, il);
|
||||
cur = build_mamba2_layer(rs_inp, gf, cur, ubatch, il);
|
||||
} else {
|
||||
cur = build_mamba_layer(gf, cur, state_copy, ubatch, il);
|
||||
cur = build_mamba_layer(rs_inp, gf, cur, ubatch, il);
|
||||
}
|
||||
|
||||
if (il == n_layer - 1) {
|
||||
|
|
@ -9249,12 +9254,12 @@ struct llm_build_mamba : public llm_graph_context {
|
|||
}
|
||||
|
||||
ggml_tensor * build_mamba_layer(
|
||||
llm_graph_input_rs * inp,
|
||||
ggml_cgraph * gf,
|
||||
ggml_tensor * cur,
|
||||
ggml_tensor * state_copy,
|
||||
const llama_ubatch & ubatch,
|
||||
int il) const {
|
||||
const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
|
||||
const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);
|
||||
|
||||
const auto kv_head = kv_state->get_head();
|
||||
|
||||
|
|
@ -9276,10 +9281,10 @@ struct llm_build_mamba : public llm_graph_context {
|
|||
GGML_ASSERT(ubatch.equal_seqs);
|
||||
GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
|
||||
|
||||
ggml_tensor * conv_states_all = kv_state->get_k_l(il);
|
||||
ggml_tensor * ssm_states_all = kv_state->get_v_l(il);
|
||||
ggml_tensor * conv_states_all = kv_state->get_r_l(il);
|
||||
ggml_tensor * ssm_states_all = kv_state->get_s_l(il);
|
||||
|
||||
ggml_tensor * conv = build_recurrent_state(gf, conv_states_all, state_copy, hparams.n_embd_k_s(), n_seqs);
|
||||
ggml_tensor * conv = build_rs(inp, gf, conv_states_all, hparams.n_embd_r(), n_seqs);
|
||||
conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner, n_seqs);
|
||||
|
||||
// {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs}
|
||||
|
|
@ -9360,7 +9365,7 @@ struct llm_build_mamba : public llm_graph_context {
|
|||
return ggml_ssm_scan(ctx, ssm, x, dt, A, B, C, ids);
|
||||
};
|
||||
|
||||
ggml_tensor * y_ssm = build_recurrent_state(gf, ssm_states_all, state_copy, hparams.n_embd_v_s(), ubatch.n_seqs, get_ssm_rows);
|
||||
ggml_tensor * y_ssm = build_rs(inp, gf, ssm_states_all, hparams.n_embd_s(), ubatch.n_seqs, get_ssm_rows);
|
||||
|
||||
// store last states
|
||||
ggml_build_forward_expand(gf,
|
||||
|
|
@ -9387,12 +9392,12 @@ struct llm_build_mamba : public llm_graph_context {
|
|||
}
|
||||
|
||||
ggml_tensor * build_mamba2_layer(
|
||||
llm_graph_input_rs * inp,
|
||||
ggml_cgraph * gf,
|
||||
ggml_tensor * cur,
|
||||
ggml_tensor * state_copy,
|
||||
const llama_ubatch & ubatch,
|
||||
int il) const {
|
||||
const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
|
||||
const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);
|
||||
|
||||
const auto kv_head = kv_state->get_head();
|
||||
|
||||
|
|
@ -9410,10 +9415,10 @@ struct llm_build_mamba : public llm_graph_context {
|
|||
GGML_ASSERT(ubatch.equal_seqs);
|
||||
GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
|
||||
|
||||
ggml_tensor * conv_states_all = kv_state->get_k_l(il);
|
||||
ggml_tensor * ssm_states_all = kv_state->get_v_l(il);
|
||||
ggml_tensor * conv_states_all = kv_state->get_r_l(il);
|
||||
ggml_tensor * ssm_states_all = kv_state->get_s_l(il);
|
||||
|
||||
ggml_tensor * conv = build_recurrent_state(gf, conv_states_all, state_copy, hparams.n_embd_k_s(), n_seqs);
|
||||
ggml_tensor * conv = build_rs(inp, gf, conv_states_all, hparams.n_embd_r(), n_seqs);
|
||||
conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner + 2*n_group*d_state, n_seqs);
|
||||
|
||||
// {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs}
|
||||
|
|
@ -9483,7 +9488,7 @@ struct llm_build_mamba : public llm_graph_context {
|
|||
return ggml_ssm_scan(ctx, ssm, x, dt, A, B, C, ids);
|
||||
};
|
||||
|
||||
ggml_tensor * y_ssm = build_recurrent_state(gf, ssm_states_all, state_copy, hparams.n_embd_v_s(), ubatch.n_seqs, get_ssm_rows);
|
||||
ggml_tensor * y_ssm = build_rs(inp, gf, ssm_states_all, hparams.n_embd_s(), ubatch.n_seqs, get_ssm_rows);
|
||||
|
||||
// store last states
|
||||
ggml_build_forward_expand(gf,
|
||||
|
|
@ -12131,13 +12136,13 @@ struct llm_build_rwkv6_base : public llm_graph_context {
|
|||
}
|
||||
|
||||
ggml_tensor * build_rwkv6_time_mix(
|
||||
llm_graph_input_rs * inp,
|
||||
ggml_cgraph * gf,
|
||||
ggml_tensor * cur,
|
||||
ggml_tensor * x_prev,
|
||||
ggml_tensor * state_copy,
|
||||
const llama_ubatch & ubatch,
|
||||
int il) const {
|
||||
const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
|
||||
const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);
|
||||
|
||||
const auto n_tokens = ubatch.n_tokens;
|
||||
const auto n_seqs = ubatch.n_seqs;
|
||||
|
|
@ -12258,9 +12263,9 @@ struct llm_build_rwkv6_base : public llm_graph_context {
|
|||
k = ggml_sub(ctx0, k, ggml_mul(ctx0, k, w));
|
||||
}
|
||||
|
||||
ggml_tensor * wkv_state = build_recurrent_state(
|
||||
gf, kv_state->get_v_l(il), state_copy,
|
||||
hparams.n_embd_v_s(), n_seqs);
|
||||
ggml_tensor * wkv_state = build_rs(
|
||||
inp, gf, kv_state->get_s_l(il),
|
||||
hparams.n_embd_s(), n_seqs);
|
||||
|
||||
ggml_tensor * wkv_output;
|
||||
if (is_qrwkv) {
|
||||
|
|
@ -12278,9 +12283,9 @@ struct llm_build_rwkv6_base : public llm_graph_context {
|
|||
wkv_state,
|
||||
ggml_view_1d(
|
||||
ctx0,
|
||||
kv_state->get_v_l(il),
|
||||
hparams.n_embd_v_s() * n_seqs,
|
||||
hparams.n_embd_v_s() * kv_head * ggml_element_size(kv_state->get_v_l(il))
|
||||
kv_state->get_s_l(il),
|
||||
hparams.n_embd_s() * n_seqs,
|
||||
hparams.n_embd_s() * kv_head * ggml_element_size(kv_state->get_s_l(il))
|
||||
)
|
||||
)
|
||||
);
|
||||
|
|
@ -12314,7 +12319,7 @@ struct llm_build_rwkv6 : public llm_build_rwkv6_base {
|
|||
inpL = build_inp_embd(model.tok_embd);
|
||||
inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, -1);
|
||||
|
||||
ggml_tensor * state_copy = build_inp_s_copy();
|
||||
auto * rs_inp = build_rs_inp();
|
||||
|
||||
const auto n_embd = hparams.n_embd;
|
||||
const auto n_seq_tokens = ubatch.n_seq_tokens;
|
||||
|
|
@ -12324,9 +12329,7 @@ struct llm_build_rwkv6 : public llm_build_rwkv6_base {
|
|||
const llama_layer * layer = &model.layers[il];
|
||||
inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);
|
||||
|
||||
ggml_tensor * token_shift = build_rwkv_token_shift_load(
|
||||
gf, state_copy, ubatch, il
|
||||
);
|
||||
ggml_tensor * token_shift = build_rwkv_token_shift_load(rs_inp, gf, ubatch, il);
|
||||
|
||||
ggml_tensor * att_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], 0);
|
||||
ggml_tensor * ffn_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], n_embd * ggml_element_size(token_shift));
|
||||
|
|
@ -12341,7 +12344,7 @@ struct llm_build_rwkv6 : public llm_build_rwkv6_base {
|
|||
1
|
||||
);
|
||||
|
||||
cur = build_rwkv6_time_mix(gf, att_norm, x_prev, state_copy, ubatch, il);
|
||||
cur = build_rwkv6_time_mix(rs_inp, gf, att_norm, x_prev, ubatch, il);
|
||||
|
||||
ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
|
||||
cb(ffn_inp, "ffn_inp", il);
|
||||
|
|
@ -12404,14 +12407,14 @@ struct llm_build_rwkv6 : public llm_build_rwkv6_base {
|
|||
// ref: https://huggingface.co/recursal/QRWKV6-32B-Instruct-Preview-v0.1/blob/main/modeling_rwkv6qwen2.py
|
||||
struct llm_build_rwkv6qwen2 : public llm_build_rwkv6_base {
|
||||
llm_build_rwkv6qwen2(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_build_rwkv6_base(model, params) {
|
||||
GGML_ASSERT(n_embd == hparams.n_embd_k_s());
|
||||
GGML_ASSERT(n_embd == hparams.n_embd_r());
|
||||
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
|
||||
inpL = build_inp_embd(model.tok_embd);
|
||||
|
||||
ggml_tensor * state_copy = build_inp_s_copy();
|
||||
auto * rs_inp = build_rs_inp();
|
||||
|
||||
const auto n_embd = hparams.n_embd;
|
||||
const auto n_seq_tokens = ubatch.n_seq_tokens;
|
||||
|
|
@ -12421,9 +12424,7 @@ struct llm_build_rwkv6qwen2 : public llm_build_rwkv6_base {
|
|||
const llama_layer * layer = &model.layers[il];
|
||||
inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);
|
||||
|
||||
ggml_tensor * token_shift = build_rwkv_token_shift_load(
|
||||
gf, state_copy, ubatch, il
|
||||
);
|
||||
ggml_tensor * token_shift = build_rwkv_token_shift_load(rs_inp, gf, ubatch, il);
|
||||
|
||||
ggml_tensor * att_norm = build_norm(inpL, layer->attn_norm, layer->attn_norm_b, LLM_NORM_RMS, il);
|
||||
cb(att_norm, "attn_norm", il);
|
||||
|
|
@ -12435,7 +12436,7 @@ struct llm_build_rwkv6qwen2 : public llm_build_rwkv6_base {
|
|||
1
|
||||
);
|
||||
|
||||
cur = build_rwkv6_time_mix(gf, att_norm, x_prev, state_copy, ubatch, il);
|
||||
cur = build_rwkv6_time_mix(rs_inp, gf, att_norm, x_prev, ubatch, il);
|
||||
|
||||
token_shift = ggml_view_3d(ctx0, att_norm, n_embd, 1, n_seqs, att_norm->nb[1], att_norm->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(att_norm));
|
||||
ggml_build_forward_expand(gf, build_rwkv_token_shift_store(token_shift, ubatch, il));
|
||||
|
|
@ -12523,14 +12524,14 @@ struct llm_build_rwkv7_base : public llm_graph_context {
|
|||
}
|
||||
|
||||
ggml_tensor * build_rwkv7_time_mix(
|
||||
llm_graph_input_rs * inp,
|
||||
ggml_cgraph * gf,
|
||||
ggml_tensor * cur,
|
||||
ggml_tensor * x_prev,
|
||||
ggml_tensor * state_copy,
|
||||
ggml_tensor *& first_layer_value,
|
||||
const llama_ubatch & ubatch,
|
||||
int il) const {
|
||||
const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
|
||||
const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);
|
||||
|
||||
const auto n_tokens = ubatch.n_tokens;
|
||||
const auto n_seqs = ubatch.n_seqs;
|
||||
|
|
@ -12609,9 +12610,9 @@ struct llm_build_rwkv7_base : public llm_graph_context {
|
|||
v = ggml_reshape_3d(ctx0, v, head_size, head_count, n_tokens);
|
||||
a = ggml_reshape_3d(ctx0, a, head_size, head_count, n_tokens);
|
||||
|
||||
ggml_tensor * wkv_state = build_recurrent_state(
|
||||
gf, kv_state->get_v_l(il), state_copy,
|
||||
hparams.n_embd_v_s(), n_seqs);
|
||||
ggml_tensor * wkv_state = build_rs(
|
||||
inp, gf, kv_state->get_s_l(il),
|
||||
hparams.n_embd_s(), n_seqs);
|
||||
|
||||
ggml_tensor * wkv_output = ggml_rwkv_wkv7(ctx0, r, w, k, v, ggml_neg(ctx0, kk), ggml_mul(ctx0, kk, a), wkv_state);
|
||||
cur = ggml_view_1d(ctx0, wkv_output, n_embd * n_tokens, 0);
|
||||
|
|
@ -12624,9 +12625,9 @@ struct llm_build_rwkv7_base : public llm_graph_context {
|
|||
wkv_state,
|
||||
ggml_view_1d(
|
||||
ctx0,
|
||||
kv_state->get_v_l(il),
|
||||
hparams.n_embd_v_s() * n_seqs,
|
||||
hparams.n_embd_v_s() * kv_head * ggml_element_size(kv_state->get_v_l(il))
|
||||
kv_state->get_s_l(il),
|
||||
hparams.n_embd_s() * n_seqs,
|
||||
hparams.n_embd_s() * kv_head * ggml_element_size(kv_state->get_s_l(il))
|
||||
)
|
||||
)
|
||||
);
|
||||
|
|
@ -12667,7 +12668,7 @@ struct llm_build_rwkv7 : public llm_build_rwkv7_base {
|
|||
inpL = build_inp_embd(model.tok_embd);
|
||||
inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, -1);
|
||||
|
||||
ggml_tensor * state_copy = build_inp_s_copy();
|
||||
auto * rs_inp = build_rs_inp();
|
||||
|
||||
const auto n_embd = hparams.n_embd;
|
||||
const auto n_seq_tokens = ubatch.n_seq_tokens;
|
||||
|
|
@ -12677,9 +12678,7 @@ struct llm_build_rwkv7 : public llm_build_rwkv7_base {
|
|||
const llama_layer * layer = &model.layers[il];
|
||||
inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);
|
||||
|
||||
ggml_tensor * token_shift = build_rwkv_token_shift_load(
|
||||
gf, state_copy, ubatch, il
|
||||
);
|
||||
ggml_tensor * token_shift = build_rwkv_token_shift_load(rs_inp, gf, ubatch, il);
|
||||
|
||||
ggml_tensor * att_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], 0);
|
||||
ggml_tensor * ffn_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], n_embd * ggml_element_size(token_shift));
|
||||
|
|
@ -12694,7 +12693,7 @@ struct llm_build_rwkv7 : public llm_build_rwkv7_base {
|
|||
1
|
||||
);
|
||||
|
||||
cur = build_rwkv7_time_mix(gf, att_norm, x_prev, state_copy, v_first, ubatch, il);
|
||||
cur = build_rwkv7_time_mix(rs_inp, gf, att_norm, x_prev, v_first, ubatch, il);
|
||||
|
||||
ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
|
||||
cb(ffn_inp, "ffn_inp", il);
|
||||
|
|
@ -12752,7 +12751,7 @@ struct llm_build_rwkv7 : public llm_build_rwkv7_base {
|
|||
|
||||
struct llm_build_arwkv7 : public llm_build_rwkv7_base {
|
||||
llm_build_arwkv7(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_build_rwkv7_base(model, params) {
|
||||
GGML_ASSERT(n_embd == hparams.n_embd_k_s());
|
||||
GGML_ASSERT(n_embd == hparams.n_embd_r());
|
||||
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
|
|
@ -12760,7 +12759,7 @@ struct llm_build_arwkv7 : public llm_build_rwkv7_base {
|
|||
|
||||
inpL = build_inp_embd(model.tok_embd);
|
||||
|
||||
ggml_tensor * state_copy = build_inp_s_copy();
|
||||
auto * rs_inp = build_rs_inp();
|
||||
|
||||
const auto n_embd = hparams.n_embd;
|
||||
const auto n_seq_tokens = ubatch.n_seq_tokens;
|
||||
|
|
@ -12770,9 +12769,7 @@ struct llm_build_arwkv7 : public llm_build_rwkv7_base {
|
|||
const llama_layer * layer = &model.layers[il];
|
||||
inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);
|
||||
|
||||
ggml_tensor * token_shift = build_rwkv_token_shift_load(
|
||||
gf, state_copy, ubatch, il
|
||||
);
|
||||
ggml_tensor * token_shift = build_rwkv_token_shift_load(rs_inp, gf, ubatch, il);
|
||||
|
||||
ggml_tensor * att_norm = build_norm(inpL, layer->attn_norm, layer->attn_norm_b, LLM_NORM_RMS, il);
|
||||
cb(att_norm, "attn_norm", il);
|
||||
|
|
@ -12784,7 +12781,7 @@ struct llm_build_arwkv7 : public llm_build_rwkv7_base {
|
|||
1
|
||||
);
|
||||
|
||||
cur = build_rwkv7_time_mix(gf, att_norm, x_prev, state_copy, v_first, ubatch, il);
|
||||
cur = build_rwkv7_time_mix(rs_inp, gf, att_norm, x_prev, v_first, ubatch, il);
|
||||
|
||||
token_shift = ggml_view_3d(ctx0, att_norm, n_embd, 1, n_seqs, att_norm->nb[1], att_norm->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(att_norm));
|
||||
ggml_build_forward_expand(gf, build_rwkv_token_shift_store(token_shift, ubatch, il));
|
||||
|
|
@@ -13965,6 +13962,8 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
llama_memory_i * res;

switch (arch) {
// Models that need specific instantiation should be handled in the
// switch statement
case LLM_ARCH_BERT:
case LLM_ARCH_JINA_BERT_V2:
case LLM_ARCH_NOMIC_BERT:
@@ -13974,23 +13973,39 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
{
res = nullptr;
} break;
case LLM_ARCH_MAMBA:
case LLM_ARCH_MAMBA2:
case LLM_ARCH_RWKV6:
case LLM_ARCH_RWKV6QWEN2:
case LLM_ARCH_RWKV7:
case LLM_ARCH_ARWKV7:
// Models that need standard caching should rely on recurrent/hybrid
// checks
default:
{
res = new llama_kv_cache_recurrent(
if (llm_arch_is_recurrent(arch)) {
res = new llama_memory_recurrent(
*this,
nullptr,
GGML_TYPE_F32,
GGML_TYPE_F32,
cparams.offload_kqv,
std::max((uint32_t) 1, cparams.n_seq_max),
cparams.n_seq_max);
} break;
default:
{
} else if (llm_arch_is_hybrid(arch)) {
const auto padding = llama_kv_cache_unified::get_padding(cparams);

cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding);

res = new llama_memory_hybrid(
/* model */ *this,
/* attn_type_k */ params.type_k,
/* attn_type_v */ params.type_v,
/* attn_v_trans */ !cparams.flash_attn,
/* attn_kv_size */ cparams.n_ctx,
/* attn_n_pad */ padding,
/* attn_n_swa */ hparams.n_swa,
/* attn_swa_type */ hparams.swa_type,
/* recurrent_type_k */ GGML_TYPE_F32,
/* recurrent_type_v */ GGML_TYPE_F32,
/* recurrent_kv_size */ std::max((uint32_t) 1, cparams.n_seq_max),
/* n_seq_max */ cparams.n_seq_max,
/* offload */ cparams.offload_kqv);
} else {
const auto padding = llama_kv_cache_unified::get_padding(cparams);

cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding);

@@ -14029,6 +14044,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
}
}
}
}

return res;
}
@@ -14607,17 +14623,7 @@ llama_token llama_model_decoder_start_token(const llama_model * model) {
}

bool llama_model_is_recurrent(const llama_model * model) {
switch (model->arch) {
case LLM_ARCH_MAMBA:
case LLM_ARCH_MAMBA2:
case LLM_ARCH_RWKV6:
case LLM_ARCH_RWKV6QWEN2:
case LLM_ARCH_RWKV7:
case LLM_ARCH_ARWKV7:
return true;
default:
return false;
}
return llm_arch_is_recurrent(model->arch);
}

const std::vector<std::pair<std::string, ggml_tensor *>> & llama_internal_get_tensor_map(const llama_model * model) {