diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index fdda05d3ea..daf249422a 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -57,13 +57,14 @@ add_library(llama models/deci.cpp models/deepseek.cpp models/deepseek2.cpp + models/delta-net-base.cpp models/dots1.cpp models/dream.cpp models/ernie4-5-moe.cpp models/ernie4-5.cpp + models/exaone-moe.cpp models/exaone.cpp models/exaone4.cpp - models/exaone-moe.cpp models/falcon-h1.cpp models/falcon.cpp models/gemma-embedding.cpp @@ -91,10 +92,12 @@ add_library(llama models/llama-iswa.cpp models/llama.cpp models/maincoder.cpp + models/mamba-base.cpp models/mamba.cpp models/mimo2-iswa.cpp models/minicpm3.cpp models/minimax-m2.cpp + models/mistral3.cpp models/modern-bert.cpp models/mpt.cpp models/nemotron-h.cpp @@ -118,12 +121,12 @@ add_library(llama models/qwen2moe.cpp models/qwen2vl.cpp models/qwen3.cpp - models/qwen3vl.cpp - models/qwen3vl-moe.cpp - models/qwen3moe.cpp - models/qwen3next.cpp models/qwen35.cpp models/qwen35moe.cpp + models/qwen3moe.cpp + models/qwen3next.cpp + models/qwen3vl-moe.cpp + models/qwen3vl.cpp models/refact.cpp models/rnd1.cpp models/rwkv6-base.cpp @@ -142,8 +145,6 @@ add_library(llama models/t5-enc.cpp models/wavtokenizer-dec.cpp models/xverse.cpp - models/mistral3.cpp - models/graph-context-mamba.cpp ) set_target_properties(llama PROPERTIES diff --git a/src/models/delta-net-base.cpp b/src/models/delta-net-base.cpp new file mode 100644 index 0000000000..0cdf9c324b --- /dev/null +++ b/src/models/delta-net-base.cpp @@ -0,0 +1,333 @@ +#include "models.h" + +#define CHUNK_SIZE 64 + +// utility to get one slice from the third dimension +// input dim: [x, y, c, b] +// output dim: [x, y, 1, b] +static ggml_tensor * get_slice_2d(ggml_context * ctx0, ggml_tensor * t, int64_t c) { + return ggml_view_4d(ctx0, t, t->ne[0], t->ne[1], 1, t->ne[3], + t->nb[1], t->nb[2], t->nb[3], t->nb[2] * c); +} + +llm_build_delta_net_base::llm_build_delta_net_base(const llm_graph_params & params) : llm_graph_context(params) {} + +std::pair llm_build_delta_net_base::build_delta_net_chunking( + ggml_tensor * q, + ggml_tensor * k, + ggml_tensor * v, + ggml_tensor * g, + ggml_tensor * b, + ggml_tensor * s, + int il) { + const int64_t S_k = q->ne[0]; + const int64_t H_k = q->ne[1]; + const int64_t n_tokens = q->ne[2]; + const int64_t n_seqs = q->ne[3]; + + const int64_t S_v = v->ne[0]; + const int64_t H_v = v->ne[1]; + + GGML_ASSERT(S_k == S_v); + GGML_ASSERT(H_v % H_k == 0); + + GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs); + GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs); + GGML_ASSERT(v->ne[0] == S_v && v->ne[1] == H_v && v->ne[2] == n_tokens && v->ne[3] == n_seqs); + + GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs); + GGML_ASSERT(b->ne[0] == H_v && b->ne[2] == n_tokens && b->ne[3] == n_seqs); + GGML_ASSERT(s->ne[0] == S_v && s->ne[1] == S_v && s->ne[2] == H_v && s->ne[3] == n_seqs); + + const float scale = 1.0f / sqrtf(S_k); + + q = ggml_scale(ctx0, q, scale); + + cb(q, "q_in", il); + cb(k, "k_in", il); + cb(v, "v_in", il); + cb(b, "b_in", il); + cb(g, "g_in", il); + + q = ggml_permute(ctx0, q, 0, 2, 1, 3); // [S_k, n_tokens, H_k, n_seqs] + k = ggml_permute(ctx0, k, 0, 2, 1, 3); // [S_k, n_tokens, H_k, n_seqs] + v = ggml_permute(ctx0, v, 0, 2, 1, 3); // [S_v, n_tokens, H_v, n_seqs] + g = ggml_permute(ctx0, g, 2, 1, 3, 0); // [ 1, n_tokens, H_v, n_seqs] + b = ggml_permute(ctx0, b, 2, 0, 1, 3); // [ 1, n_tokens, H_v, n_seqs] + + const int CS = CHUNK_SIZE; + + const int pad = (CS - n_tokens % CS) % CS; + const int n_chunks = (n_tokens + pad) / CS; + + q = ggml_pad(ctx0, q, 0, pad, 0, 0); + k = ggml_pad(ctx0, k, 0, pad, 0, 0); + v = ggml_pad(ctx0, v, 0, pad, 0, 0); + g = ggml_pad(ctx0, g, 0, pad, 0, 0); + b = ggml_pad(ctx0, b, 0, pad, 0, 0); + + ggml_tensor * v_b = ggml_mul(ctx0, v, b); + ggml_tensor * k_b = ggml_mul(ctx0, k, b); + + cb(v_b, "v_b", il); + cb(k_b, "k_b", il); + + q = ggml_reshape_4d(ctx0, q, S_k, CS, n_chunks, H_k * n_seqs); + k = ggml_reshape_4d(ctx0, k, S_k, CS, n_chunks, H_k * n_seqs); + k_b = ggml_reshape_4d(ctx0, k_b, S_k, CS, n_chunks, H_v * n_seqs); + v = ggml_reshape_4d(ctx0, v, S_v, CS, n_chunks, H_v * n_seqs); + v_b = ggml_reshape_4d(ctx0, v_b, S_v, CS, n_chunks, H_v * n_seqs); + + g = ggml_reshape_4d(ctx0, g, CS, 1, n_chunks, H_v * n_seqs); + b = ggml_reshape_4d(ctx0, b, 1, CS, n_chunks, H_v * n_seqs); + + // [CS, 1, n_chunks, H_v * n_seqs] + ggml_tensor * g_cs = ggml_cumsum(ctx0, g); + cb(g_cs, "g_cs", il); + + ggml_tensor * g_cs_i = g_cs; + ggml_tensor * g_cs_j = ggml_reshape_4d(ctx0, g_cs, 1, CS, n_chunks, H_v * n_seqs); + + g_cs_j = ggml_repeat_4d(ctx0, g_cs_j, CS, CS, n_chunks, H_v * n_seqs); + + // [CS, CS, n_chunks, H_v * n_seqs] + ggml_tensor * decay_mask; + decay_mask = ggml_sub(ctx0, g_cs_j, g_cs_i); + decay_mask = ggml_tri(ctx0, decay_mask, GGML_TRI_TYPE_LOWER_DIAG); + decay_mask = ggml_exp(ctx0, decay_mask); + cb(decay_mask, "decay_mask", il); + + // [CS, CS, n_chunks, H_k * n_seqs] + ggml_tensor * kb; + kb = ggml_mul_mat(ctx0, k, k_b); + kb = ggml_mul (ctx0, kb, decay_mask); + + // [CS, CS, n_chunks, H_k * n_seqs] + ggml_tensor * attn; + attn = ggml_tri(ctx0, kb, GGML_TRI_TYPE_LOWER); + + ggml_tensor * identity; + identity = ggml_view_1d(ctx0, attn, CS, 0); + identity = ggml_fill (ctx0, identity, 1.0f); + identity = ggml_diag (ctx0, identity); + + ggml_tensor * lhs = ggml_add(ctx0, attn, identity); + cb(lhs, "dnet_add_ch_lhs", il); + + attn = ggml_neg(ctx0, attn); + + ggml_tensor * lin_solve = ggml_solve_tri(ctx0, lhs, attn, true, true, false); + attn = ggml_add(ctx0, lin_solve, identity); + cb(attn, "dnet_add_ch_attn_solved", il); // [CS, CS, n_chunks, H_k * n_seqs] + + // [S_v, CS, n_chunks, H_v * n_seqs] + v = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, v_b)), attn); + + // [CS, 1, n_chunks, H_v * n_seqs] + ggml_tensor * g_exp = ggml_exp(ctx0, g_cs); + + k_b = ggml_cont(ctx0, ggml_transpose(ctx0, k_b)); + + // [CS, S_k, n_chunks, H_k * n_seqs] + ggml_tensor * kbg = ggml_mul(ctx0, k_b, g_exp); + cb(kbg, "k_beta_g_exp", il); + + // [S_k, CS, n_chunks, H_k * n_seqs] + ggml_tensor * k_cd = ggml_mul_mat(ctx0, kbg, attn); + cb(k_cd, "k_cumdecay", il); + + // [S_k, CS, n_chunks, H_k * n_seqs] + ggml_tensor * g_exp_t = ggml_transpose(ctx0, g_exp); + ggml_tensor * q_g_exp = ggml_mul(ctx0, q, g_exp_t); + + // [CS, CS, n_chunks, H_k * n_seqs] + ggml_tensor * kq = ggml_mul_mat(ctx0, k, q); + kq = ggml_mul(ctx0, kq, decay_mask); + kq = ggml_tri(ctx0, kq, GGML_TRI_TYPE_LOWER_DIAG); + cb(kq, "kq", il); + + // vectorized calculation of key_gdiff + // improved from the chunked version: + // g_last = torch.clamp(g_cum[:, :, -1], max=50.0).exp().unsqueeze(-1).unsqueeze(-1) + // g_diff = torch.clamp(g_cum[:, :, -1:] - g_cum, max=50.0).exp() + // key_gdiff = key * g_diff.unsqueeze(-1) + // kgdmulvnew = (key_gdiff).transpose(-1, -2) @ v_new + // last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew + + // get last element in g_cumsum along CS dimension (ne0) + // example: [[x, y, z, ..., last], ...] -> [[last], ...] + // [1, 1, n_chunks, H_v * n_seqs] + ggml_tensor * g_last = ggml_view_4d(ctx0, g_cs, 1, 1, g_cs->ne[2], g_cs->ne[3], + g_cs->nb[1], + g_cs->nb[2], + g_cs->nb[3], + ggml_row_size(g_cs->type, g_cs->ne[0] - 1)); + cb(g_last, "g_last", il); + + // TODO: remove this cont when CUDA supports non-cont unary ops + g_last = ggml_cont(ctx0, g_last); + + // [1, 1, n_chunks, H_v * n_seqs] + ggml_tensor * g_last_exp = ggml_exp(ctx0, g_last); + cb(g_last_exp, "g_last_exp", il); + + // [CS, 1, n_chunks, H_v * n_seqs] + ggml_tensor * g_diff = ggml_neg(ctx0, ggml_sub(ctx0, g_cs, g_last)); + cb(g_diff, "g_diff", il); + + ggml_tensor * g_diff_exp = ggml_exp(ctx0, g_diff); + ggml_tensor * g_diff_exp_t = ggml_transpose(ctx0, g_diff_exp); + + // [S_k, CS, n_chunks, H_v * n_seqs] + ggml_tensor * kg = ggml_mul(ctx0, k, g_diff_exp_t); + cb(kg, "key_gdiff", il); + + // [CS, S_k, n_chunks, H_v * n_seqs] + ggml_tensor * kg_t = ggml_cont(ctx0, ggml_transpose(ctx0, kg)); + cb(kg_t, "key_gdiff_t", il); + + ggml_tensor * s_t = ggml_transpose(ctx0, s); + s_t = ggml_cont_4d(ctx0, s_t, S_v, S_v, 1, H_v * n_seqs); + cb(s_t, "dnet_add_ch_state", il); + + // [CS, S_v, n_chunks, H_v * n_seqs] + ggml_tensor * v_t = ggml_cont(ctx0, ggml_transpose(ctx0, v)); + + for (int64_t chunk = 0; chunk < n_chunks; chunk++) { + ggml_tensor * ch_k_cd = get_slice_2d(ctx0, k_cd, chunk); // [S_k, CS, 1, H_k * n_seqs] + ggml_tensor * ch_v_t = get_slice_2d(ctx0, v_t, chunk); // [ CS, S_v, 1, H_v * n_seqs] + ggml_tensor * ch_kq = get_slice_2d(ctx0, kq, chunk); // [ CS, CS, 1, H_k * n_seqs] + ggml_tensor * ch_q_g_exp = get_slice_2d(ctx0, q_g_exp, chunk); // [S_k, CS, 1, H_k * n_seqs] + ggml_tensor * ch_kg_t = get_slice_2d(ctx0, kg_t, chunk); // [ CS, S_k, 1, H_v * n_seqs] + + // [CS, S_v, 1, H_v * n_seqs] + ggml_tensor * v_t_p = ggml_mul_mat(ctx0, ch_k_cd, s_t); + cb(v_t_p, "v_prime", il); + + // [CS, S_v, 1, H_v * n_seqs] + ggml_tensor * v_t_new = ggml_sub(ctx0, ch_v_t, v_t_p); + cb(v_t_new, "v_t_new", il); + + // [S_v, CS, 1, H_v * n_seqs] + ggml_tensor * v_attn = ggml_mul_mat(ctx0, v_t_new, ch_kq); + cb(v_attn, "v_attn", il); + + // [S_v, CS, 1, H_v * n_seqs] + ggml_tensor * attn_inter = ggml_mul_mat(ctx0, s_t, ch_q_g_exp); + cb(attn_inter, "attn_inter", il); + + // [S_v, CS, 1, H_v * n_seqs] + ggml_tensor * o_ch = ggml_add(ctx0, attn_inter, v_attn); + cb(o_ch, "dnet_add_ch_attn_out", il); + + v = ggml_set_inplace(ctx0, v, o_ch, v->nb[1], v->nb[2], v->nb[3], chunk * v->nb[2]); + + // kgdmulvnew = (key_gdiff).transpose(-1, -2) @ v_new + // TODO: head broadcast might not work here - probably will need a transpose + ggml_tensor * kgv = ggml_mul_mat(ctx0, ch_kg_t, v_t_new); // [S_k, S_v, 1, H_k * n_seqs] + + // last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew + ggml_tensor * ch_g_last_exp = get_slice_2d(ctx0, g_last_exp, chunk); + s_t = ggml_mul(ctx0, s_t, ch_g_last_exp); + s_t = ggml_add(ctx0, s_t, kgv); + cb(s_t, "dnet_add_ch_state", il); + } + + s_t = ggml_reshape_4d(ctx0, s_t, S_v, S_v, H_v, n_seqs); + + // truncate padded tokens + ggml_tensor * o = ggml_view_4d(ctx0, v, + S_v, n_tokens, H_v, n_seqs, + ggml_row_size(v->type, S_v), + ggml_row_size(v->type, S_v * CS * n_chunks), + ggml_row_size(v->type, S_v * CS * n_chunks * H_v), 0); + + o = ggml_permute (ctx0, o, 0, 2, 1, 3); // [S_v, H_v, n_tokens, n_seqs] + s = ggml_transpose(ctx0, s_t); // [S_v, S_v, H_v, n_seqs] + + return {o, s}; +} + +std::pair llm_build_delta_net_base::build_delta_net_autoregressive( + ggml_tensor * q, + ggml_tensor * k, + ggml_tensor * v, + ggml_tensor * g, + ggml_tensor * b, // beta + ggml_tensor * s, // state + int il) { + const int64_t S_k = q->ne[0]; + const int64_t H_k = q->ne[1]; + const int64_t n_tokens = q->ne[2]; + const int64_t n_seqs = q->ne[3]; + + const int64_t S_v = v->ne[0]; + const int64_t H_v = v->ne[1]; + + GGML_ASSERT(n_tokens == 1); + + GGML_ASSERT(S_k == S_v); + GGML_ASSERT(H_v % H_k == 0); + + GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs); + GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs); + GGML_ASSERT(v->ne[0] == S_v && v->ne[1] == H_v && v->ne[2] == n_tokens && v->ne[3] == n_seqs); + + GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs); + GGML_ASSERT(b->ne[0] == H_v && b->ne[2] == n_tokens && b->ne[3] == n_seqs); + GGML_ASSERT(s->ne[0] == S_v && s->ne[1] == S_v && s->ne[2] == H_v && s->ne[3] == n_seqs); + + const float scale = 1.0f / sqrtf(S_k); + + q = ggml_scale(ctx0, q, scale); + + q = ggml_permute(ctx0, q, 0, 2, 1, 3); // [S_k, n_tokens, H_k, n_seqs] + k = ggml_permute(ctx0, k, 0, 2, 1, 3); // [S_k, n_tokens, H_k, n_seqs] + v = ggml_permute(ctx0, v, 0, 2, 1, 3); // [S_v, n_tokens, H_v, n_seqs] + + cb(q, "q_in", il); + cb(k, "k_in", il); + cb(v, "v_in", il); + cb(b, "b_in", il); + cb(g, "g_in", il); + + g = ggml_reshape_4d(ctx0, g, 1, 1, H_v, n_seqs); + b = ggml_reshape_4d(ctx0, b, 1, 1, H_v, n_seqs); + + // [S_v, S_v, H_v, n_seqs] + g = ggml_exp(ctx0, g); + s = ggml_mul(ctx0, s, g); + + ggml_tensor * s_t = ggml_cont(ctx0, ggml_transpose(ctx0, s)); + + // [1, S_v, H_v, n_seqs] + ggml_tensor * sk; + sk = ggml_mul (ctx0, s_t, k); + sk = ggml_sum_rows(ctx0, sk); + + // [S_v, 1, H_v, n_seqs] + ggml_tensor * d; + d = ggml_sub(ctx0, v, ggml_transpose(ctx0, sk)); + d = ggml_mul(ctx0, d, b); + + // [1, S_v, H_v, n_seqs] + ggml_tensor * d_t; + d_t = ggml_transpose(ctx0, d); + + // [S_v, S_v, H_v, n_seqs] + ggml_tensor * kd; + k = ggml_repeat(ctx0, k, s); + kd = ggml_mul (ctx0, k, d_t); + + s_t = ggml_add(ctx0, s_t, kd); + + cb(s_t, "dnet_add_ar_state", il); + + ggml_tensor * s_q = ggml_mul (ctx0, s_t, q); + ggml_tensor * o = ggml_sum_rows(ctx0, s_q); + + o = ggml_permute (ctx0, o, 2, 0, 1, 3); // [S_v, H_v, n_tokens, n_seqs] + s = ggml_transpose(ctx0, s_t); // [S_v, S_v, H_v, n_seqs] + + return {o, s}; +} diff --git a/src/models/falcon-h1.cpp b/src/models/falcon-h1.cpp index b641a09407..785a7e5e66 100644 --- a/src/models/falcon-h1.cpp +++ b/src/models/falcon-h1.cpp @@ -1,9 +1,7 @@ #include "models.h" - - llm_build_falcon_h1::llm_build_falcon_h1(const llama_model & model, const llm_graph_params & params) : - llm_graph_context_mamba(params) { + llm_build_mamba_base(params) { const int64_t n_embd_head = hparams.n_embd_head_v; ggml_tensor * cur; diff --git a/src/models/granite-hybrid.cpp b/src/models/granite-hybrid.cpp index f6ca4c17a2..726ecdcca7 100644 --- a/src/models/granite-hybrid.cpp +++ b/src/models/granite-hybrid.cpp @@ -2,7 +2,7 @@ llm_build_granite_hybrid::llm_build_granite_hybrid(const llama_model & model, const llm_graph_params & params) : - llm_graph_context_mamba(params) { + llm_build_mamba_base(params) { const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); diff --git a/src/models/jamba.cpp b/src/models/jamba.cpp index a0187772cc..ceab581740 100644 --- a/src/models/jamba.cpp +++ b/src/models/jamba.cpp @@ -1,6 +1,6 @@ #include "models.h" -llm_build_jamba::llm_build_jamba(const llama_model & model, const llm_graph_params & params) : llm_graph_context_mamba(params) { +llm_build_jamba::llm_build_jamba(const llama_model & model, const llm_graph_params & params) : llm_build_mamba_base(params) { const int64_t n_embd_head = hparams.n_embd_head_v; ggml_tensor * cur; diff --git a/src/models/kimi-linear.cpp b/src/models/kimi-linear.cpp index 942844d071..133834021d 100644 --- a/src/models/kimi-linear.cpp +++ b/src/models/kimi-linear.cpp @@ -1,6 +1,8 @@ #include "models.h" #include "ggml.h" +#include "llama-memory-recurrent.h" + #define CHUNK_SIZE 64 // Causal Conv1d function for Q,K,V @@ -65,7 +67,7 @@ static ggml_tensor * causal_conv1d(ggml_cgraph * gf, ggml_context * ctx0, ggml_t } llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const llm_graph_params & params) : - llm_graph_context_mamba(params), model(model) { + llm_build_mamba_base(params), model(model) { ggml_tensor * cur; ggml_tensor * inpL; diff --git a/src/models/graph-context-mamba.cpp b/src/models/mamba-base.cpp similarity index 97% rename from src/models/graph-context-mamba.cpp rename to src/models/mamba-base.cpp index b9a363b32b..aaac9487df 100644 --- a/src/models/graph-context-mamba.cpp +++ b/src/models/mamba-base.cpp @@ -1,8 +1,10 @@ #include "models.h" -llm_graph_context_mamba::llm_graph_context_mamba(const llm_graph_params & params) : llm_graph_context(params) {} +#include "llama-memory-recurrent.h" -ggml_tensor * llm_graph_context_mamba::build_mamba_layer(llm_graph_input_rs * inp, +llm_build_mamba_base::llm_build_mamba_base(const llm_graph_params & params) : llm_graph_context(params) {} + +ggml_tensor * llm_build_mamba_base::build_mamba_layer(llm_graph_input_rs * inp, ggml_tensor * cur, const llama_model & model, const llama_ubatch & ubatch, @@ -143,7 +145,7 @@ ggml_tensor * llm_graph_context_mamba::build_mamba_layer(llm_graph_input_rs * in return cur; } -ggml_tensor * llm_graph_context_mamba::build_mamba2_layer(llm_graph_input_rs * inp, +ggml_tensor * llm_build_mamba_base::build_mamba2_layer(llm_graph_input_rs * inp, ggml_tensor * cur, const llama_model & model, const llama_ubatch & ubatch, diff --git a/src/models/mamba.cpp b/src/models/mamba.cpp index 46819613c2..55fd2e055c 100644 --- a/src/models/mamba.cpp +++ b/src/models/mamba.cpp @@ -1,7 +1,6 @@ #include "models.h" - -llm_build_mamba::llm_build_mamba(const llama_model & model, const llm_graph_params & params) : llm_graph_context_mamba(params) { +llm_build_mamba::llm_build_mamba(const llama_model & model, const llm_graph_params & params) : llm_build_mamba_base(params) { ggml_tensor * cur; ggml_tensor * inpL; diff --git a/src/models/models.h b/src/models/models.h index ec6f80e526..b98e43d8e8 100644 --- a/src/models/models.h +++ b/src/models/models.h @@ -1,23 +1,48 @@ #pragma once -#include "../llama-model.h" -#include "../llama-graph.h" +#include "llama-model.h" +#include "llama-graph.h" -// TODO: remove in follow-up PR - move to .cpp files -#include "../llama-memory-recurrent.h" -#include +// +// base classes +// -struct llm_graph_context_mamba : public llm_graph_context { - llm_graph_context_mamba(const llm_graph_params & params); +struct llm_build_mamba_base : public llm_graph_context { + llm_build_mamba_base(const llm_graph_params & params); - virtual ~llm_graph_context_mamba() = default; + virtual ~llm_build_mamba_base() = default; ggml_tensor * build_mamba_layer(llm_graph_input_rs * inp, ggml_tensor * cur, const llama_model & model, const llama_ubatch & ubatch, int il); ggml_tensor * build_mamba2_layer(llm_graph_input_rs * inp, ggml_tensor * cur, const llama_model & model, const llama_ubatch & ubatch, int il) const; }; -// Base class for RWKV-related models +struct llm_build_delta_net_base : public llm_graph_context { + llm_build_delta_net_base(const llm_graph_params & params); + + virtual ~llm_build_delta_net_base() = default; + + // returns pair of output and new state + std::pair build_delta_net_chunking( + ggml_tensor * q, + ggml_tensor * k, + ggml_tensor * v, + ggml_tensor * g, + ggml_tensor * b, + ggml_tensor * s, + int il); + + // returns pair of output and new state + std::pair build_delta_net_autoregressive( + ggml_tensor * q, + ggml_tensor * k, + ggml_tensor * v, + ggml_tensor * g, + ggml_tensor * b, + ggml_tensor * s, + int il); +}; + struct llm_build_rwkv6_base : public llm_graph_context { const llama_model & model; @@ -58,6 +83,10 @@ struct llm_build_rwkv7_base : public llm_graph_context { int il) const; }; +// +// models +// + struct llm_build_afmoe : public llm_graph_context { llm_build_afmoe(const llama_model & model, const llm_graph_params & params); }; @@ -175,7 +204,7 @@ struct llm_build_falcon : public llm_graph_context { llm_build_falcon(const llama_model & model, const llm_graph_params & params); }; -struct llm_build_falcon_h1 : public llm_graph_context_mamba { +struct llm_build_falcon_h1 : public llm_build_mamba_base { llm_build_falcon_h1(const llama_model & model, const llm_graph_params & params); }; @@ -253,7 +282,7 @@ private: const int il); }; -struct llm_build_granite_hybrid : public llm_graph_context_mamba { +struct llm_build_granite_hybrid : public llm_build_mamba_base { llm_build_granite_hybrid(const llama_model & model, const llm_graph_params & params); ggml_tensor * build_layer_ffn(ggml_tensor * cur, ggml_tensor * inpSA, const llama_model & model, const int il); ggml_tensor * build_attention_layer(ggml_tensor * cur, ggml_tensor * inp_pos, llm_graph_input_attn_kv * inp_attn, @@ -284,11 +313,12 @@ struct llm_build_jais : public llm_graph_context { llm_build_jais(const llama_model & model, const llm_graph_params & params); }; -struct llm_build_jamba : public llm_graph_context_mamba { +struct llm_build_jamba : public llm_build_mamba_base { llm_build_jamba(const llama_model & model, const llm_graph_params & params); }; -struct llm_build_kimi_linear : public llm_graph_context_mamba { +// TODO: derive llm_build_delta_net_base instead +struct llm_build_kimi_linear : public llm_build_mamba_base { llm_build_kimi_linear(const llama_model & model, const llm_graph_params & params); std::pair build_kda_autoregressive( @@ -347,7 +377,7 @@ struct llm_build_maincoder : public llm_graph_context { llm_build_maincoder(const llama_model & model, const llm_graph_params & params); }; -struct llm_build_mamba : public llm_graph_context_mamba { +struct llm_build_mamba : public llm_build_mamba_base { llm_build_mamba(const llama_model & model, const llm_graph_params & params); }; @@ -379,11 +409,11 @@ struct llm_build_nemotron : public llm_graph_context { llm_build_nemotron(const llama_model & model, const llm_graph_params & params); }; -struct llm_build_nemotron_h : public llm_graph_context_mamba { +struct llm_build_nemotron_h : public llm_build_mamba_base { llm_build_nemotron_h(const llama_model & model, const llm_graph_params & params); - ggml_tensor * build_ffn_layer(ggml_tensor * cur, const llama_model & model, const int il); + ggml_tensor * build_ffn_layer(ggml_tensor * cur, const llama_model & model, int il); ggml_tensor * build_attention_layer(ggml_tensor * cur, llm_graph_input_attn_kv * inp_attn, - const llama_model & model, const int64_t n_embd_head, const int il); + const llama_model & model, int64_t n_embd_head, int il); }; struct llm_build_neo_bert : public llm_graph_context { @@ -428,7 +458,7 @@ struct llm_build_phi3 : public llm_graph_context { llm_build_phi3(const llama_model & model, const llm_graph_params & params); }; -struct llm_build_plamo2 : public llm_graph_context_mamba { +struct llm_build_plamo2 : public llm_build_mamba_base { llm_build_plamo2(const llama_model & model, const llm_graph_params & params); private: ggml_tensor * build_plamo2_mamba_layer(llm_graph_input_rs * inp, ggml_tensor * cur, const llama_model & model, const llama_ubatch & ubatch, int il); @@ -477,7 +507,7 @@ struct llm_build_qwen3vlmoe : public llm_graph_context { llm_build_qwen3vlmoe(const llama_model & model, const llm_graph_params & params); }; -struct llm_build_qwen3next : public llm_graph_context_mamba { +struct llm_build_qwen3next : public llm_build_delta_net_base { llm_build_qwen3next(const llama_model & model, const llm_graph_params & params); private: ggml_tensor * build_layer_attn( @@ -495,26 +525,6 @@ private: ggml_tensor * cur, int il); - // returns pair of output and new state - std::pair build_delta_net_chunking( - ggml_tensor * q, - ggml_tensor * k, - ggml_tensor * v, - ggml_tensor * g, - ggml_tensor * beta, - ggml_tensor * state, - int il); - - // returns pair of output and new state - std::pair build_delta_net_autoregressive( - ggml_tensor * q, - ggml_tensor * k, - ggml_tensor * v, - ggml_tensor * g, - ggml_tensor * beta, - ggml_tensor * state, - int il); - ggml_tensor * build_norm_gated( ggml_tensor * input, ggml_tensor * weights, @@ -529,7 +539,7 @@ private: const llama_model & model; }; -struct llm_build_qwen35 : public llm_graph_context_mamba { +struct llm_build_qwen35 : public llm_build_delta_net_base { llm_build_qwen35(const llama_model & model, const llm_graph_params & params); private: ggml_tensor * build_layer_attn( @@ -542,38 +552,12 @@ private: ggml_tensor * build_layer_attn_linear( llm_graph_input_rs * inp, ggml_tensor * cur, - ggml_tensor * causal_mask, - ggml_tensor * identity, - ggml_tensor * diag_mask, int il); ggml_tensor * build_layer_ffn( ggml_tensor * cur, int il); - // returns pair of output and new state - std::pair build_delta_net_chunking( - ggml_tensor * q, - ggml_tensor * k, - ggml_tensor * v, - ggml_tensor * g, - ggml_tensor * beta, - ggml_tensor * state, - ggml_tensor * causal_mask, - ggml_tensor * identity, - ggml_tensor * diag_mask, - int il); - - // returns pair of output and new state - std::pair build_delta_net_autoregressive( - ggml_tensor * q, - ggml_tensor * k, - ggml_tensor * v, - ggml_tensor * g, - ggml_tensor * beta, - ggml_tensor * state, - int il); - ggml_tensor * build_norm_gated( ggml_tensor * input, ggml_tensor * weights, @@ -588,7 +572,7 @@ private: const llama_model & model; }; -struct llm_build_qwen35moe : public llm_graph_context_mamba { +struct llm_build_qwen35moe : public llm_build_delta_net_base { llm_build_qwen35moe(const llama_model & model, const llm_graph_params & params); private: ggml_tensor * build_layer_attn( @@ -601,38 +585,12 @@ private: ggml_tensor * build_layer_attn_linear( llm_graph_input_rs * inp, ggml_tensor * cur, - ggml_tensor * causal_mask, - ggml_tensor * identity, - ggml_tensor * diag_mask, int il); ggml_tensor * build_layer_ffn( ggml_tensor * cur, int il); - // returns pair of output and new state - std::pair build_delta_net_chunking( - ggml_tensor * q, - ggml_tensor * k, - ggml_tensor * v, - ggml_tensor * g, - ggml_tensor * beta, - ggml_tensor * state, - ggml_tensor * causal_mask, - ggml_tensor * identity, - ggml_tensor * diag_mask, - int il); - - // returns pair of output and new state - std::pair build_delta_net_autoregressive( - ggml_tensor * q, - ggml_tensor * k, - ggml_tensor * v, - ggml_tensor * g, - ggml_tensor * beta, - ggml_tensor * state, - int il); - ggml_tensor * build_norm_gated( ggml_tensor * input, ggml_tensor * weights, diff --git a/src/models/nemotron-h.cpp b/src/models/nemotron-h.cpp index 079c730ac2..d61d62a8c9 100644 --- a/src/models/nemotron-h.cpp +++ b/src/models/nemotron-h.cpp @@ -1,9 +1,7 @@ #include "models.h" - - llm_build_nemotron_h::llm_build_nemotron_h(const llama_model & model, const llm_graph_params & params) : - llm_graph_context_mamba(params) { + llm_build_mamba_base(params) { const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -65,8 +63,8 @@ llm_build_nemotron_h::llm_build_nemotron_h(const llama_model & model, const llm_ ggml_tensor * llm_build_nemotron_h::build_attention_layer(ggml_tensor * cur, llm_graph_input_attn_kv * inp_attn, const llama_model & model, - const int64_t n_embd_head, - const int il) { + int64_t n_embd_head, + int il) { // compute Q and K ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); cb(Qcur, "Qcur", il); @@ -106,7 +104,7 @@ ggml_tensor * llm_build_nemotron_h::build_attention_layer(ggml_tensor * return cur; } -ggml_tensor * llm_build_nemotron_h::build_ffn_layer(ggml_tensor * cur, const llama_model & model, const int il) { +ggml_tensor * llm_build_nemotron_h::build_ffn_layer(ggml_tensor * cur, const llama_model & model, int il) { if (model.layers[il].ffn_gate_inp == nullptr) { cur = build_ffn(cur, model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, diff --git a/src/models/plamo2.cpp b/src/models/plamo2.cpp index 31115a08f9..3af236843b 100644 --- a/src/models/plamo2.cpp +++ b/src/models/plamo2.cpp @@ -1,7 +1,9 @@ #include "models.h" +#include "llama-memory-recurrent.h" + llm_build_plamo2::llm_build_plamo2(const llama_model & model, const llm_graph_params & params) : - llm_graph_context_mamba(params) { + llm_build_mamba_base(params) { ggml_tensor * cur; ggml_tensor * inpL; diff --git a/src/models/qwen35.cpp b/src/models/qwen35.cpp index 592c170457..cb652a2cce 100644 --- a/src/models/qwen35.cpp +++ b/src/models/qwen35.cpp @@ -1,10 +1,9 @@ -#include "ggml.h" #include "models.h" -#define CHUNK_SIZE 64 +#include "llama-memory-recurrent.h" llm_build_qwen35::llm_build_qwen35(const llama_model & model, const llm_graph_params & params) : - llm_graph_context_mamba(params), model(model) { + llm_build_delta_net_base(params), model(model) { const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -24,17 +23,6 @@ llm_build_qwen35::llm_build_qwen35(const llama_model & model, const llm_graph_pa ggml_tensor * inp_pos = build_inp_pos(); ggml_tensor * inp_out_ids = build_inp_out_ids(); - ggml_tensor * causal_mask = - ggml_tri(ctx0, ggml_fill(ctx0, ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, CHUNK_SIZE, CHUNK_SIZE), 1.0f), - GGML_TRI_TYPE_LOWER); - - ggml_tensor * identity = ggml_diag(ctx0, ggml_fill(ctx0, ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, CHUNK_SIZE), 1.0f)); - ggml_tensor * diag_mask = ggml_add(ctx0, causal_mask, identity); - - ggml_build_forward_expand(gf, causal_mask); - ggml_build_forward_expand(gf, identity); - ggml_build_forward_expand(gf, diag_mask); - for (int il = 0; il < n_layer; ++il) { ggml_tensor * inpSA = inpL; @@ -44,7 +32,7 @@ llm_build_qwen35::llm_build_qwen35(const llama_model & model, const llm_graph_pa // Determine layer type and build appropriate attention mechanism if (hparams.is_recurrent(il)) { // Linear attention layer (gated delta net) - cur = build_layer_attn_linear(inp->get_recr(), cur, causal_mask, identity, diag_mask, il); + cur = build_layer_attn_linear(inp->get_recr(), cur, il); } else { // Full attention layer cur = build_layer_attn(inp->get_attn(), cur, inp_pos, sections, il); @@ -94,361 +82,6 @@ llm_build_qwen35::llm_build_qwen35(const llama_model & model, const llm_graph_pa ggml_build_forward_expand(gf, cur); } -// utility to get one slice from the third dimension -// input dim: [x, y, c, b] -// output dim: [x, y, 1, b] -static ggml_tensor * get_slice_2d(ggml_context * ctx0, ggml_tensor * t, int64_t c) { - return ggml_view_4d(ctx0, t, t->ne[0], t->ne[1], 1, t->ne[3], - t->nb[1], t->nb[2], t->nb[3], t->nb[2] * c); -} - -std::pair llm_build_qwen35::build_delta_net_chunking( - ggml_tensor * q, - ggml_tensor * k, - ggml_tensor * v, - ggml_tensor * g, - ggml_tensor * beta, - ggml_tensor * state, - ggml_tensor * causal_mask, - ggml_tensor * identity, - ggml_tensor * diag_mask, - int il) { - const int64_t S_k = q->ne[0]; - const int64_t H_k = q->ne[1]; - const int64_t n_tokens = q->ne[2]; - const int64_t n_seqs = q->ne[3]; - - const int64_t S_v = v->ne[0]; - const int64_t H_v = v->ne[1]; - - GGML_ASSERT(v->ne[2] == n_tokens); - GGML_ASSERT(k->ne[2] == n_tokens); - GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs); - GGML_ASSERT(beta->ne[0] == H_v && beta->ne[2] == n_tokens && beta->ne[3] == n_seqs); - GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v * H_v && state->ne[2] == 1 && state->ne[3] == n_seqs); - - GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs); - GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs); - - GGML_ASSERT(H_k == H_v); // we did a repeat to make sure this is the case - - const float eps_norm = hparams.f_norm_rms_eps; - - q = ggml_l2_norm(ctx0, q, eps_norm); - k = ggml_l2_norm(ctx0, k, eps_norm); - - const float scale = 1.0f / sqrtf(S_v); - - q = ggml_scale(ctx0, q, scale); - - beta = ggml_sigmoid(ctx0, beta); - - cb(q, "q_in", il); - cb(k, "k_in", il); - cb(v, "v_in", il); - cb(beta, "beta_in", il); - cb(g, "g_in", il); - - q = ggml_cont_4d(ctx0, ggml_permute(ctx0, q, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs); - k = ggml_cont_4d(ctx0, ggml_permute(ctx0, k, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs); - v = ggml_cont_4d(ctx0, ggml_permute(ctx0, v, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs); - g = ggml_cont_4d(ctx0, ggml_permute(ctx0, g, 2, 0, 3, 1), n_tokens, 1, H_k, n_seqs); - - beta = ggml_cont(ctx0, ggml_permute(ctx0, beta, 2, 0, 1, 3)); - state = ggml_reshape_4d(ctx0, state, S_v, S_v, H_v, n_seqs); - - cb(q, "q_perm", il); - cb(k, "k_perm", il); - cb(v, "v_perm", il); - cb(beta, "beta_perm", il); - cb(g, "g_perm", il); - cb(state, "state_in", il); - - GGML_ASSERT(q->ne[1] == n_tokens && q->ne[0] == S_k && q->ne[2] == H_k && q->ne[3] == n_seqs); - GGML_ASSERT(k->ne[1] == n_tokens && k->ne[0] == S_k && k->ne[2] == H_k && k->ne[3] == n_seqs); - GGML_ASSERT(v->ne[1] == n_tokens && v->ne[0] == S_v && v->ne[2] == H_k && v->ne[3] == n_seqs); - GGML_ASSERT(beta->ne[1] == n_tokens && beta->ne[2] == H_k && beta->ne[0] == 1 && beta->ne[3] == n_seqs); - - // Do padding - const int64_t chunk_size = CHUNK_SIZE; - - const int64_t pad = (chunk_size - n_tokens % chunk_size) % chunk_size; - const int64_t n_chunks = (n_tokens + pad) / chunk_size; - - q = ggml_pad(ctx0, q, 0, pad, 0, 0); - k = ggml_pad(ctx0, k, 0, pad, 0, 0); - v = ggml_pad(ctx0, v, 0, pad, 0, 0); - g = ggml_pad(ctx0, g, pad, 0, 0, 0); - beta = ggml_pad(ctx0, beta, 0, pad, 0, 0); - - cb(q, "q_pad", il); - cb(k, "k_pad", il); - cb(v, "v_pad", il); - cb(beta, "beta_pad", il); - cb(g, "g_pad", il); - - ggml_tensor * v_beta = ggml_mul(ctx0, v, beta); - ggml_tensor * k_beta = ggml_mul(ctx0, k, beta); - - cb(v_beta, "v_beta", il); - cb(k_beta, "k_beta", il); - - q = ggml_reshape_4d(ctx0, q, S_k, chunk_size, n_chunks, H_k * n_seqs); - k = ggml_reshape_4d(ctx0, k, S_k, chunk_size, n_chunks, H_k * n_seqs); - k_beta = ggml_reshape_4d(ctx0, k_beta, S_k, chunk_size, n_chunks, H_k * n_seqs); - v = ggml_reshape_4d(ctx0, v, S_v, chunk_size, n_chunks, H_v * n_seqs); - v_beta = ggml_reshape_4d(ctx0, v_beta, S_v, chunk_size, n_chunks, H_v * n_seqs); - - g = ggml_reshape_4d(ctx0, g, chunk_size, 1, n_chunks, H_k * n_seqs); - beta = ggml_reshape_4d(ctx0, beta, 1, chunk_size, n_chunks, H_k * n_seqs); - - ggml_tensor * g_cumsum = ggml_cumsum(ctx0, g); - cb(g_cumsum, "g_cumsum", il); // shape: (chunk_size, 1, n_chunks, H_v * n_seqs) - - ggml_tensor * gcs_i = g_cumsum; // ggml_reshape_4d(ctx0, g_cumsum, chunk_size, 1, n_chunks, H_v * n_seqs); - ggml_tensor * gcs_j = ggml_reshape_4d(ctx0, g_cumsum, 1, chunk_size, n_chunks, H_v * n_seqs); - - ggml_tensor * gcs_j_broadcast = - ggml_repeat_4d(ctx0, gcs_j, chunk_size, chunk_size, n_chunks, H_v * n_seqs); - - ggml_tensor * decay_mask = ggml_sub(ctx0, gcs_j_broadcast, gcs_i); - cb(decay_mask, "decay_mask", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs) - - decay_mask = ggml_mul(ctx0, decay_mask, diag_mask); - decay_mask = ggml_exp(ctx0, decay_mask); - decay_mask = ggml_mul(ctx0, decay_mask, diag_mask); - - ggml_tensor * kmulkbeta = ggml_mul_mat(ctx0, k, k_beta); - - ggml_tensor * k_decay = ggml_mul(ctx0, kmulkbeta, decay_mask); - ggml_tensor * attn = ggml_neg(ctx0, ggml_mul(ctx0, k_decay, causal_mask)); - cb(attn, "attn_pre_solve", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs) - - ggml_tensor * attn_lower = ggml_mul(ctx0, attn, causal_mask); - ggml_tensor * lhs = ggml_sub(ctx0, ggml_repeat(ctx0, identity, attn_lower), attn_lower); - - ggml_tensor * lin_solve = ggml_solve_tri(ctx0, lhs, attn, true, true, false); - attn = ggml_mul(ctx0, lin_solve, causal_mask); - attn = ggml_add(ctx0, attn, identity); - cb(attn, "attn_solved", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs) - - v = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, v_beta)), attn); - - ggml_tensor * g_cumsum_t = ggml_cont(ctx0, ggml_transpose(ctx0, g_cumsum)); - ggml_tensor * gexp = ggml_exp(ctx0, g_cumsum_t); - - ggml_tensor * kbeta_gexp = ggml_mul(ctx0, k_beta, gexp); - cb(kbeta_gexp, "kbeta_gexp", il); // shape: (S_k, chunk_size, n_chunks, H_v * n_seqs) - - ggml_tensor * k_cumdecay = - ggml_cont(ctx0, ggml_transpose(ctx0, ggml_mul_mat(ctx0, attn, ggml_cont(ctx0, ggml_transpose(ctx0, kbeta_gexp))))); - cb(k_cumdecay, "k_cumdecay", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs) - - ggml_tensor * attn_kq = ggml_mul_mat(ctx0, k, q); - attn_kq = ggml_mul(ctx0, attn_kq, decay_mask); - attn_kq = ggml_mul(ctx0, attn_kq, diag_mask); - cb(attn_kq, "attn_kq", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs) - - - // vectorized calculation of key_gdiff - // improved from the chunked version: - // g_last = torch.clamp(g_cum[:, :, -1], max=50.0).exp().unsqueeze(-1).unsqueeze(-1) - // g_diff = torch.clamp(g_cum[:, :, -1:] - g_cum, max=50.0).exp() - // key_gdiff = key * g_diff.unsqueeze(-1) - // kgdmulvnew = (key_gdiff).transpose(-1, -2) @ v_new - // last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew - - // get last element in g_cumsum along chunk_size dimension (ne0) - // example: [[x, y, z, ..., last], ...] -> [[last], ...] - ggml_tensor * g_last = ggml_view_4d(ctx0, g_cumsum, 1, 1, g_cumsum->ne[2], g_cumsum->ne[3], - g_cumsum->nb[1], g_cumsum->nb[2], g_cumsum->nb[3], - (g_cumsum->ne[0] - 1) * ggml_element_size(g_cumsum)); - g_last = ggml_cont(ctx0, g_last); - cb(g_last, "g_last", il); // shape: (1, 1, n_chunks, H_v * n_seqs) - - ggml_tensor * g_last_exp = ggml_exp(ctx0, g_last); - cb(g_last_exp, "g_last_exp", il); // shape: (1, 1, n_chunks, H_v * n_seqs) - - ggml_tensor * g_diff = ggml_neg(ctx0, ggml_sub(ctx0, g_cumsum, g_last)); - cb(g_diff, "g_diff", il); // shape: (chunk_size, 1, n_chunks, H_v * n_seqs) - - ggml_tensor * g_diff_exp = ggml_exp(ctx0, g_diff); - ggml_tensor * g_diff_exp_t = ggml_reshape_4d(ctx0, g_diff_exp, - 1, chunk_size, n_chunks, g_diff_exp->ne[3]); - - ggml_tensor * key_gdiff = ggml_mul(ctx0, k, g_diff_exp_t); - cb(key_gdiff, "key_gdiff", il); // shape: (S_k, chunk_size, n_chunks, H_v * n_seqs) - - ggml_tensor * key_gdiff_t = ggml_cont(ctx0, ggml_transpose(ctx0, key_gdiff)); - cb(key_gdiff_t, "key_gdiff_t", il); // shape: (chunk_size, S_k, n_chunks, H_v * n_seqs) - - // state to be updated per chunk - ggml_tensor * new_state = state; // ggml_dup(ctx0, state); - cb(new_state, "new_state", il); // shape: (S_v, S_v, H_v, n_seqs) - - // shape after loop of chunks: (S_v, chunk_size, n_chunks, H_v * n_seqs) - ggml_tensor * core_attn_out = nullptr; - - for (int64_t chunk = 0; chunk < n_chunks; chunk++) { - // shape: (S_k, chunk_size, 1, H_k * n_seqs) - ggml_tensor * q_chunk = get_slice_2d(ctx0, q, chunk); // (no cont), next op: ggml_mul - - // shape: (S_v, chunk_size, 1, H_v * n_seqs) - ggml_tensor * v_chunk = get_slice_2d(ctx0, v, chunk); // (no cont), next op: ggml_repeat - - // shape: (chunk_size, 1, n_chunks, H_v * n_seqs) - ggml_tensor * gexp_chunk = get_slice_2d(ctx0, gexp, chunk); // (no cont), next op: ggml_mul - - // shape: (chunk_size, 1, H_v * n_seqs) - ggml_tensor * k_cumdecay_chunk = get_slice_2d(ctx0, k_cumdecay, chunk); // (no cont), next op: ggml_mul_mat - - // attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0) - // replaced by precomputed attn_kq - ggml_tensor * attn_chunk = get_slice_2d(ctx0, attn_kq, chunk); - cb(attn_chunk, "attn_chunk", il); - - ggml_tensor * state_t = ggml_cont_4d(ctx0, ggml_permute(ctx0, new_state, 1, 0, 2, 3), S_v, S_v, 1, H_v * n_seqs); - - // v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state - ggml_tensor * v_prime = ggml_mul_mat(ctx0, state_t, k_cumdecay_chunk); - cb(v_prime, "v_prime_chunk", il); // shape: (S_v, 1, H_v * n_seqs) - - // v_new = v_i - v_prime - ggml_tensor * v_new = ggml_sub(ctx0, ggml_repeat(ctx0, v_chunk, v_prime), v_prime); - ggml_tensor * v_new_t = ggml_cont(ctx0, ggml_transpose(ctx0, v_new)); - cb(v_new, "v_new_chunk", il); - - // attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state - ggml_tensor * q_g_exp = ggml_mul(ctx0, q_chunk, gexp_chunk); - ggml_tensor * attn_inter = ggml_mul_mat(ctx0, state_t, q_g_exp); - cb(attn_inter, "attn_inter_chunk", il); - - // core_attn_out[:, :, i] = attn_inter + attn @ v_new - ggml_tensor * v_attn = ggml_mul_mat(ctx0, v_new_t, attn_chunk); - cb(v_attn, "v_attn_chunk", il); - - ggml_tensor * core_attn_out_chunk = ggml_add(ctx0, attn_inter, v_attn); - cb(core_attn_out_chunk, "core_attn_out_chunk", il); // shape: (S_v, chunk_size, 1, H_v * n_seqs) - - core_attn_out = core_attn_out == nullptr - ? core_attn_out_chunk - : ggml_concat(ctx0, core_attn_out, core_attn_out_chunk, 2); - - // kgdmulvnew = (key_gdiff).transpose(-1, -2) @ v_new - ggml_tensor * k_gdiff_t = get_slice_2d(ctx0, key_gdiff_t, chunk); - //ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx0, k_gdiff, v_new); // this is slower on metal, why? - ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx0, v_new_t, k_gdiff_t); - - // last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew - ggml_tensor * gexp_last_chunk = ggml_cont(ctx0, get_slice_2d(ctx0, g_last_exp, chunk)); - new_state = ggml_add(ctx0, - ggml_mul(ctx0, new_state, ggml_reshape_4d(ctx0, gexp_last_chunk, gexp_last_chunk->ne[0], gexp_last_chunk->ne[1], H_v, n_seqs)), - ggml_reshape_4d(ctx0, kgdmulvnew, kgdmulvnew->ne[0], kgdmulvnew->ne[1], H_v, n_seqs)); - } - - // truncate padded tokens - ggml_tensor * output_tokens = ggml_view_4d(ctx0, core_attn_out, - S_v, n_tokens, H_v, n_seqs, - ggml_row_size(core_attn_out->type, S_v), - ggml_row_size(core_attn_out->type, S_v * chunk_size * n_chunks), - ggml_row_size(core_attn_out->type, S_v * chunk_size * n_chunks * H_v), 0); - output_tokens = ggml_cont(ctx0, output_tokens); - cb(output_tokens, "output_tokens", il); - - // permute back to (S_v, H_v, n_tokens, n_seqs) - output_tokens = ggml_permute(ctx0, output_tokens, 0, 2, 1, 3); - output_tokens = ggml_cont(ctx0, output_tokens); - - return {output_tokens, new_state}; -} - -std::pair llm_build_qwen35::build_delta_net_autoregressive( - ggml_tensor * q, - ggml_tensor * k, - ggml_tensor * v, - ggml_tensor * g, - ggml_tensor * beta, - ggml_tensor * state, - int il) { - const int64_t S_k = q->ne[0]; - const int64_t H_k = q->ne[1]; - const int64_t n_tokens = q->ne[2]; - const int64_t n_seqs = q->ne[3]; - - const int64_t S_v = v->ne[0]; - const int64_t H_v = v->ne[1]; - - GGML_ASSERT(n_tokens == 1); // This function is optimized for single token processing - GGML_ASSERT(v->ne[2] == n_tokens); - GGML_ASSERT(k->ne[2] == n_tokens); - GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs); - GGML_ASSERT(beta->ne[0] == H_v && beta->ne[2] == n_tokens && beta->ne[3] == n_seqs); - GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v * H_v && state->ne[2] == 1 && state->ne[3] == n_seqs); - - GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs); - GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs); - - GGML_ASSERT(H_k == H_v); // we did a repeat to make sure this is the case - - const float eps_norm = hparams.f_norm_rms_eps; - - q = ggml_l2_norm(ctx0, q, eps_norm); - k = ggml_l2_norm(ctx0, k, eps_norm); - - const float scale = 1.0f / sqrtf(S_v); - - q = ggml_scale(ctx0, q, scale); - beta = ggml_sigmoid(ctx0, beta); - - cb(q, "q_in", il); - cb(k, "k_in", il); - cb(v, "v_in", il); - cb(beta, "beta_in", il); - cb(g, "g_in", il); - - state = ggml_reshape_4d(ctx0, state, S_v, S_v, H_v, n_seqs); - - ggml_tensor * g_t = ggml_reshape_4d(ctx0, ggml_transpose(ctx0, g), 1, 1, H_k, n_seqs); - ggml_tensor * beta_t = ggml_reshape_4d(ctx0, ggml_transpose(ctx0, beta), 1, 1, H_k, n_seqs); - - // Apply exponential to g_t - g_t = ggml_exp(ctx0, g_t); - - // Apply the gated delta rule for the single timestep - // last_recurrent_state = last_recurrent_state * g_t - state = ggml_mul(ctx0, state, g_t); - - // kv_mem = (last_recurrent_state * k_t.unsqueeze(-1)).sum(dim=-2) - ggml_tensor * k_t_unsqueezed = ggml_reshape_4d(ctx0, k, 1, S_v, H_v, n_seqs); - ggml_tensor * kv_mem = ggml_mul(ctx0, state, k_t_unsqueezed); - // we need to sum over dim=-2, so we transpose, sum, then transpose again - kv_mem = ggml_transpose(ctx0, ggml_sum_rows(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, kv_mem)))); - - // v_t = v.unsqueeze(2) (we insert the singleton dimension after n_seqs and H_v) - ggml_tensor * v_t = ggml_reshape_4d(ctx0, v, S_v, 1, H_v, n_seqs); - // delta = (v_t - kv_mem) * beta_t - ggml_tensor * v_diff = ggml_sub(ctx0, v_t, kv_mem); // both should be [S_v, 1, H_v, n_seqs] - ggml_tensor * delta = ggml_mul(ctx0, v_diff, beta_t); - - // last_recurrent_state = last_recurrent_state + k_t.unsqueeze(-1) * delta - ggml_tensor * k_t_delta = ggml_mul(ctx0, ggml_repeat_4d(ctx0, k_t_unsqueezed, S_v, S_v, H_v, n_seqs), delta); - state = ggml_add(ctx0, state, k_t_delta); - - // Compute the attention output - // core_attn_out = (last_recurrent_state * q_t.unsqueeze(-1)).sum(dim=-2) - ggml_tensor * q_t_unsqueezed = ggml_reshape_4d(ctx0, q, 1, S_v, H_v, n_seqs); // unsqueeze q_t - ggml_tensor * state_q = ggml_mul(ctx0, state, q_t_unsqueezed); - // again, since it's over dim = -2, transpose, sum, transpose back - ggml_tensor * core_attn_out = - ggml_transpose(ctx0, ggml_sum_rows(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, state_q)))); - - // core_attn_out should be [S_v, 1, H_v, n_seqs] after this - cb(core_attn_out, "output_tokens", il); - cb(state, "new_state", il); - - return {core_attn_out, state}; -} - std::pair llm_build_qwen35::build_qkvz( ggml_tensor * input, int il) { @@ -560,9 +193,6 @@ ggml_tensor * llm_build_qwen35::build_layer_attn( ggml_tensor * llm_build_qwen35::build_layer_attn_linear( llm_graph_input_rs * inp, ggml_tensor * cur, - ggml_tensor * causal_mask, - ggml_tensor * identity, - ggml_tensor * diag_mask, int il) { const auto * mctx_cur = inp->mctx; @@ -688,7 +318,7 @@ ggml_tensor * llm_build_qwen35::build_layer_attn_linear( if (n_seq_tokens == 1) { attn_out = build_delta_net_autoregressive(q_conv, k_conv, v_conv, gate, beta, state, il); } else { - attn_out = build_delta_net_chunking(q_conv, k_conv, v_conv, gate, beta, state, causal_mask, identity, diag_mask, il); + attn_out = build_delta_net_chunking(q_conv, k_conv, v_conv, gate, beta, state, il); } ggml_tensor * output = attn_out.first; ggml_tensor * new_state = attn_out.second; diff --git a/src/models/qwen35moe.cpp b/src/models/qwen35moe.cpp index 0db8f825c6..7e601d018e 100644 --- a/src/models/qwen35moe.cpp +++ b/src/models/qwen35moe.cpp @@ -1,10 +1,9 @@ -#include "ggml.h" #include "models.h" -#define CHUNK_SIZE 64 +#include "llama-memory-recurrent.h" llm_build_qwen35moe::llm_build_qwen35moe(const llama_model & model, const llm_graph_params & params) : - llm_graph_context_mamba(params), model(model) { + llm_build_delta_net_base(params), model(model) { const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -24,17 +23,6 @@ llm_build_qwen35moe::llm_build_qwen35moe(const llama_model & model, const llm_gr ggml_tensor * inp_pos = build_inp_pos(); ggml_tensor * inp_out_ids = build_inp_out_ids(); - ggml_tensor * causal_mask = - ggml_tri(ctx0, ggml_fill(ctx0, ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, CHUNK_SIZE, CHUNK_SIZE), 1.0f), - GGML_TRI_TYPE_LOWER); - - ggml_tensor * identity = ggml_diag(ctx0, ggml_fill(ctx0, ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, CHUNK_SIZE), 1.0f)); - ggml_tensor * diag_mask = ggml_add(ctx0, causal_mask, identity); - - ggml_build_forward_expand(gf, causal_mask); - ggml_build_forward_expand(gf, identity); - ggml_build_forward_expand(gf, diag_mask); - for (int il = 0; il < n_layer; ++il) { ggml_tensor * inpSA = inpL; @@ -44,7 +32,7 @@ llm_build_qwen35moe::llm_build_qwen35moe(const llama_model & model, const llm_gr // Determine layer type and build appropriate attention mechanism if (hparams.is_recurrent(il)) { // Linear attention layer (gated delta net) - cur = build_layer_attn_linear(inp->get_recr(), cur, causal_mask, identity, diag_mask, il); + cur = build_layer_attn_linear(inp->get_recr(), cur, il); } else { // Full attention layer cur = build_layer_attn(inp->get_attn(), cur, inp_pos, sections, il); @@ -94,362 +82,6 @@ llm_build_qwen35moe::llm_build_qwen35moe(const llama_model & model, const llm_gr ggml_build_forward_expand(gf, cur); } -// utility to get one slice from the third dimension -// input dim: [x, y, c, b] -// output dim: [x, y, 1, b] -static ggml_tensor * get_slice_2d(ggml_context * ctx0, ggml_tensor * t, int64_t c) { - return ggml_view_4d(ctx0, t, t->ne[0], t->ne[1], 1, t->ne[3], - t->nb[1], t->nb[2], t->nb[3], t->nb[2] * c); -} - -std::pair llm_build_qwen35moe::build_delta_net_chunking( - ggml_tensor * q, - ggml_tensor * k, - ggml_tensor * v, - ggml_tensor * g, - ggml_tensor * beta, - ggml_tensor * state, - ggml_tensor * causal_mask, - ggml_tensor * identity, - ggml_tensor * diag_mask, - int il) { - const int64_t S_k = q->ne[0]; - const int64_t H_k = q->ne[1]; - const int64_t n_tokens = q->ne[2]; - const int64_t n_seqs = q->ne[3]; - - const int64_t S_v = v->ne[0]; - const int64_t H_v = v->ne[1]; - - GGML_ASSERT(v->ne[2] == n_tokens); - GGML_ASSERT(k->ne[2] == n_tokens); - GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs); - GGML_ASSERT(beta->ne[0] == H_v && beta->ne[2] == n_tokens && beta->ne[3] == n_seqs); - GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v * H_v && state->ne[2] == 1 && state->ne[3] == n_seqs); - - GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs); - GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs); - - GGML_ASSERT(H_k == H_v); // we did a repeat to make sure this is the case - - const float eps_norm = hparams.f_norm_rms_eps; - - q = ggml_l2_norm(ctx0, q, eps_norm); - k = ggml_l2_norm(ctx0, k, eps_norm); - - const float scale = 1.0f / sqrtf(S_v); - - q = ggml_scale(ctx0, q, scale); - - beta = ggml_sigmoid(ctx0, beta); - - cb(q, "q_in", il); - cb(k, "k_in", il); - cb(v, "v_in", il); - cb(beta, "beta_in", il); - cb(g, "g_in", il); - - q = ggml_cont_4d(ctx0, ggml_permute(ctx0, q, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs); - k = ggml_cont_4d(ctx0, ggml_permute(ctx0, k, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs); - v = ggml_cont_4d(ctx0, ggml_permute(ctx0, v, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs); - g = ggml_cont_4d(ctx0, ggml_permute(ctx0, g, 2, 0, 3, 1), n_tokens, 1, H_k, n_seqs); - - beta = ggml_cont(ctx0, ggml_permute(ctx0, beta, 2, 0, 1, 3)); - state = ggml_reshape_4d(ctx0, state, S_v, S_v, H_v, n_seqs); - - cb(q, "q_perm", il); - cb(k, "k_perm", il); - cb(v, "v_perm", il); - cb(beta, "beta_perm", il); - cb(g, "g_perm", il); - cb(state, "state_in", il); - - GGML_ASSERT(q->ne[1] == n_tokens && q->ne[0] == S_k && q->ne[2] == H_k && q->ne[3] == n_seqs); - GGML_ASSERT(k->ne[1] == n_tokens && k->ne[0] == S_k && k->ne[2] == H_k && k->ne[3] == n_seqs); - GGML_ASSERT(v->ne[1] == n_tokens && v->ne[0] == S_v && v->ne[2] == H_k && v->ne[3] == n_seqs); - GGML_ASSERT(beta->ne[1] == n_tokens && beta->ne[2] == H_k && beta->ne[0] == 1 && beta->ne[3] == n_seqs); - - // Do padding - const int64_t chunk_size = CHUNK_SIZE; - - const int64_t pad = (chunk_size - n_tokens % chunk_size) % chunk_size; - const int64_t n_chunks = (n_tokens + pad) / chunk_size; - - q = ggml_pad(ctx0, q, 0, pad, 0, 0); - k = ggml_pad(ctx0, k, 0, pad, 0, 0); - v = ggml_pad(ctx0, v, 0, pad, 0, 0); - g = ggml_pad(ctx0, g, pad, 0, 0, 0); - beta = ggml_pad(ctx0, beta, 0, pad, 0, 0); - - cb(q, "q_pad", il); - cb(k, "k_pad", il); - cb(v, "v_pad", il); - cb(beta, "beta_pad", il); - cb(g, "g_pad", il); - - ggml_tensor * v_beta = ggml_mul(ctx0, v, beta); - ggml_tensor * k_beta = ggml_mul(ctx0, k, beta); - - cb(v_beta, "v_beta", il); - cb(k_beta, "k_beta", il); - - q = ggml_reshape_4d(ctx0, q, S_k, chunk_size, n_chunks, H_k * n_seqs); - k = ggml_reshape_4d(ctx0, k, S_k, chunk_size, n_chunks, H_k * n_seqs); - k_beta = ggml_reshape_4d(ctx0, k_beta, S_k, chunk_size, n_chunks, H_k * n_seqs); - v = ggml_reshape_4d(ctx0, v, S_v, chunk_size, n_chunks, H_v * n_seqs); - v_beta = ggml_reshape_4d(ctx0, v_beta, S_v, chunk_size, n_chunks, H_v * n_seqs); - - g = ggml_reshape_4d(ctx0, g, chunk_size, 1, n_chunks, H_k * n_seqs); - beta = ggml_reshape_4d(ctx0, beta, 1, chunk_size, n_chunks, H_k * n_seqs); - - ggml_tensor * g_cumsum = ggml_cumsum(ctx0, g); - cb(g_cumsum, "g_cumsum", il); // shape: (chunk_size, 1, n_chunks, H_v * n_seqs) - - ggml_tensor * gcs_i = g_cumsum; // ggml_reshape_4d(ctx0, g_cumsum, chunk_size, 1, n_chunks, H_v * n_seqs); - ggml_tensor * gcs_j = ggml_reshape_4d(ctx0, g_cumsum, 1, chunk_size, n_chunks, H_v * n_seqs); - - ggml_tensor * gcs_j_broadcast = - ggml_repeat_4d(ctx0, gcs_j, chunk_size, chunk_size, n_chunks, H_v * n_seqs); - - ggml_tensor * decay_mask = ggml_sub(ctx0, gcs_j_broadcast, gcs_i); - cb(decay_mask, "decay_mask", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs) - - decay_mask = ggml_mul(ctx0, decay_mask, diag_mask); - decay_mask = ggml_exp(ctx0, decay_mask); - decay_mask = ggml_mul(ctx0, decay_mask, diag_mask); - - ggml_tensor * kmulkbeta = ggml_mul_mat(ctx0, k, k_beta); - - ggml_tensor * k_decay = ggml_mul(ctx0, kmulkbeta, decay_mask); - ggml_tensor * attn = ggml_neg(ctx0, ggml_mul(ctx0, k_decay, causal_mask)); - cb(attn, "attn_pre_solve", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs) - - ggml_tensor * attn_lower = ggml_mul(ctx0, attn, causal_mask); - ggml_tensor * lhs = ggml_sub(ctx0, ggml_repeat(ctx0, identity, attn_lower), attn_lower); - - ggml_tensor * lin_solve = ggml_solve_tri(ctx0, lhs, attn, true, true, false); - attn = ggml_mul(ctx0, lin_solve, causal_mask); - attn = ggml_add(ctx0, attn, identity); - cb(attn, "attn_solved", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs) - - v = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, v_beta)), attn); - - ggml_tensor * g_cumsum_t = ggml_cont(ctx0, ggml_transpose(ctx0, g_cumsum)); - ggml_tensor * gexp = ggml_exp(ctx0, g_cumsum_t); - - ggml_tensor * kbeta_gexp = ggml_mul(ctx0, k_beta, gexp); - cb(kbeta_gexp, "kbeta_gexp", il); // shape: (S_k, chunk_size, n_chunks, H_v * n_seqs) - - ggml_tensor * k_cumdecay = - ggml_cont(ctx0, ggml_transpose(ctx0, ggml_mul_mat(ctx0, attn, ggml_cont(ctx0, ggml_transpose(ctx0, kbeta_gexp))))); - cb(k_cumdecay, "k_cumdecay", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs) - - ggml_tensor * attn_kq = ggml_mul_mat(ctx0, k, q); - attn_kq = ggml_mul(ctx0, attn_kq, decay_mask); - attn_kq = ggml_mul(ctx0, attn_kq, diag_mask); - cb(attn_kq, "attn_kq", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs) - - - // vectorized calculation of key_gdiff - // improved from the chunked version: - // g_last = torch.clamp(g_cum[:, :, -1], max=50.0).exp().unsqueeze(-1).unsqueeze(-1) - // g_diff = torch.clamp(g_cum[:, :, -1:] - g_cum, max=50.0).exp() - // key_gdiff = key * g_diff.unsqueeze(-1) - // kgdmulvnew = (key_gdiff).transpose(-1, -2) @ v_new - // last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew - - // get last element in g_cumsum along chunk_size dimension (ne0) - // example: [[x, y, z, ..., last], ...] -> [[last], ...] - ggml_tensor * g_last = ggml_view_4d(ctx0, g_cumsum, 1, 1, g_cumsum->ne[2], g_cumsum->ne[3], - g_cumsum->nb[1], g_cumsum->nb[2], g_cumsum->nb[3], - (g_cumsum->ne[0] - 1) * ggml_element_size(g_cumsum)); - g_last = ggml_cont(ctx0, g_last); - cb(g_last, "g_last", il); // shape: (1, 1, n_chunks, H_v * n_seqs) - - ggml_tensor * g_last_exp = ggml_exp(ctx0, g_last); - cb(g_last_exp, "g_last_exp", il); // shape: (1, 1, n_chunks, H_v * n_seqs) - - ggml_tensor * g_diff = ggml_neg(ctx0, ggml_sub(ctx0, g_cumsum, g_last)); - cb(g_diff, "g_diff", il); // shape: (chunk_size, 1, n_chunks, H_v * n_seqs) - - ggml_tensor * g_diff_exp = ggml_exp(ctx0, g_diff); - ggml_tensor * g_diff_exp_t = ggml_reshape_4d(ctx0, g_diff_exp, - 1, chunk_size, n_chunks, g_diff_exp->ne[3]); - - ggml_tensor * key_gdiff = ggml_mul(ctx0, k, g_diff_exp_t); - cb(key_gdiff, "key_gdiff", il); // shape: (S_k, chunk_size, n_chunks, H_v * n_seqs) - - ggml_tensor * key_gdiff_t = ggml_cont(ctx0, ggml_transpose(ctx0, key_gdiff)); - cb(key_gdiff_t, "key_gdiff_t", il); // shape: (chunk_size, S_k, n_chunks, H_v * n_seqs) - - - // state to be updated per chunk - ggml_tensor * new_state = state; // ggml_dup(ctx0, state); - cb(new_state, "new_state", il); // shape: (S_v, S_v, H_v, n_seqs) - - // shape after loop of chunks: (S_v, chunk_size, n_chunks, H_v * n_seqs) - ggml_tensor * core_attn_out = nullptr; - - for (int64_t chunk = 0; chunk < n_chunks; chunk++) { - // shape: (S_k, chunk_size, 1, H_k * n_seqs) - ggml_tensor * q_chunk = get_slice_2d(ctx0, q, chunk); // (no cont), next op: ggml_mul - - // shape: (S_v, chunk_size, 1, H_v * n_seqs) - ggml_tensor * v_chunk = get_slice_2d(ctx0, v, chunk); // (no cont), next op: ggml_repeat - - // shape: (chunk_size, 1, n_chunks, H_v * n_seqs) - ggml_tensor * gexp_chunk = get_slice_2d(ctx0, gexp, chunk); // (no cont), next op: ggml_mul - - // shape: (chunk_size, 1, H_v * n_seqs) - ggml_tensor * k_cumdecay_chunk = get_slice_2d(ctx0, k_cumdecay, chunk); // (no cont), next op: ggml_mul_mat - - // attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0) - // replaced by precomputed attn_kq - ggml_tensor * attn_chunk = get_slice_2d(ctx0, attn_kq, chunk); - cb(attn_chunk, "attn_chunk", il); - - ggml_tensor * state_t = ggml_cont_4d(ctx0, ggml_permute(ctx0, new_state, 1, 0, 2, 3), S_v, S_v, 1, H_v * n_seqs); - - // v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state - ggml_tensor * v_prime = ggml_mul_mat(ctx0, state_t, k_cumdecay_chunk); - cb(v_prime, "v_prime_chunk", il); // shape: (S_v, 1, H_v * n_seqs) - - // v_new = v_i - v_prime - ggml_tensor * v_new = ggml_sub(ctx0, ggml_repeat(ctx0, v_chunk, v_prime), v_prime); - ggml_tensor * v_new_t = ggml_cont(ctx0, ggml_transpose(ctx0, v_new)); - cb(v_new, "v_new_chunk", il); - - // attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state - ggml_tensor * q_g_exp = ggml_mul(ctx0, q_chunk, gexp_chunk); - ggml_tensor * attn_inter = ggml_mul_mat(ctx0, state_t, q_g_exp); - cb(attn_inter, "attn_inter_chunk", il); - - // core_attn_out[:, :, i] = attn_inter + attn @ v_new - ggml_tensor * v_attn = ggml_mul_mat(ctx0, v_new_t, attn_chunk); - cb(v_attn, "v_attn_chunk", il); - - ggml_tensor * core_attn_out_chunk = ggml_add(ctx0, attn_inter, v_attn); - cb(core_attn_out_chunk, "core_attn_out_chunk", il); // shape: (S_v, chunk_size, 1, H_v * n_seqs) - - core_attn_out = core_attn_out == nullptr - ? core_attn_out_chunk - : ggml_concat(ctx0, core_attn_out, core_attn_out_chunk, 2); - - // kgdmulvnew = (key_gdiff).transpose(-1, -2) @ v_new - ggml_tensor * k_gdiff_t = get_slice_2d(ctx0, key_gdiff_t, chunk); - //ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx0, k_gdiff, v_new); // this is slower on metal, why? - ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx0, v_new_t, k_gdiff_t); - - // last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew - ggml_tensor * gexp_last_chunk = ggml_cont(ctx0, get_slice_2d(ctx0, g_last_exp, chunk)); - new_state = ggml_add(ctx0, - ggml_mul(ctx0, new_state, ggml_reshape_4d(ctx0, gexp_last_chunk, gexp_last_chunk->ne[0], gexp_last_chunk->ne[1], H_v, n_seqs)), - ggml_reshape_4d(ctx0, kgdmulvnew, kgdmulvnew->ne[0], kgdmulvnew->ne[1], H_v, n_seqs)); - } - - // truncate padded tokens - ggml_tensor * output_tokens = ggml_view_4d(ctx0, core_attn_out, - S_v, n_tokens, H_v, n_seqs, - ggml_row_size(core_attn_out->type, S_v), - ggml_row_size(core_attn_out->type, S_v * chunk_size * n_chunks), - ggml_row_size(core_attn_out->type, S_v * chunk_size * n_chunks * H_v), 0); - output_tokens = ggml_cont(ctx0, output_tokens); - cb(output_tokens, "output_tokens", il); - - // permute back to (S_v, H_v, n_tokens, n_seqs) - output_tokens = ggml_permute(ctx0, output_tokens, 0, 2, 1, 3); - output_tokens = ggml_cont(ctx0, output_tokens); - - return {output_tokens, new_state}; -} - -std::pair llm_build_qwen35moe::build_delta_net_autoregressive( - ggml_tensor * q, - ggml_tensor * k, - ggml_tensor * v, - ggml_tensor * g, - ggml_tensor * beta, - ggml_tensor * state, - int il) { - const int64_t S_k = q->ne[0]; - const int64_t H_k = q->ne[1]; - const int64_t n_tokens = q->ne[2]; - const int64_t n_seqs = q->ne[3]; - - const int64_t S_v = v->ne[0]; - const int64_t H_v = v->ne[1]; - - GGML_ASSERT(n_tokens == 1); // This function is optimized for single token processing - GGML_ASSERT(v->ne[2] == n_tokens); - GGML_ASSERT(k->ne[2] == n_tokens); - GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs); - GGML_ASSERT(beta->ne[0] == H_v && beta->ne[2] == n_tokens && beta->ne[3] == n_seqs); - GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v * H_v && state->ne[2] == 1 && state->ne[3] == n_seqs); - - GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs); - GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs); - - GGML_ASSERT(H_k == H_v); // we did a repeat to make sure this is the case - - const float eps_norm = hparams.f_norm_rms_eps; - - q = ggml_l2_norm(ctx0, q, eps_norm); - k = ggml_l2_norm(ctx0, k, eps_norm); - - const float scale = 1.0f / sqrtf(S_v); - - q = ggml_scale(ctx0, q, scale); - beta = ggml_sigmoid(ctx0, beta); - - cb(q, "q_in", il); - cb(k, "k_in", il); - cb(v, "v_in", il); - cb(beta, "beta_in", il); - cb(g, "g_in", il); - - state = ggml_reshape_4d(ctx0, state, S_v, S_v, H_v, n_seqs); - - ggml_tensor * g_t = ggml_reshape_4d(ctx0, ggml_transpose(ctx0, g), 1, 1, H_k, n_seqs); - ggml_tensor * beta_t = ggml_reshape_4d(ctx0, ggml_transpose(ctx0, beta), 1, 1, H_k, n_seqs); - - // Apply exponential to g_t - g_t = ggml_exp(ctx0, g_t); - - // Apply the gated delta rule for the single timestep - // last_recurrent_state = last_recurrent_state * g_t - state = ggml_mul(ctx0, state, g_t); - - // kv_mem = (last_recurrent_state * k_t.unsqueeze(-1)).sum(dim=-2) - ggml_tensor * k_t_unsqueezed = ggml_reshape_4d(ctx0, k, 1, S_v, H_v, n_seqs); - ggml_tensor * kv_mem = ggml_mul(ctx0, state, k_t_unsqueezed); - // we need to sum over dim=-2, so we transpose, sum, then transpose again - kv_mem = ggml_transpose(ctx0, ggml_sum_rows(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, kv_mem)))); - - // v_t = v.unsqueeze(2) (we insert the singleton dimension after n_seqs and H_v) - ggml_tensor * v_t = ggml_reshape_4d(ctx0, v, S_v, 1, H_v, n_seqs); - // delta = (v_t - kv_mem) * beta_t - ggml_tensor * v_diff = ggml_sub(ctx0, v_t, kv_mem); // both should be [S_v, 1, H_v, n_seqs] - ggml_tensor * delta = ggml_mul(ctx0, v_diff, beta_t); - - // last_recurrent_state = last_recurrent_state + k_t.unsqueeze(-1) * delta - ggml_tensor * k_t_delta = ggml_mul(ctx0, ggml_repeat_4d(ctx0, k_t_unsqueezed, S_v, S_v, H_v, n_seqs), delta); - state = ggml_add(ctx0, state, k_t_delta); - - // Compute the attention output - // core_attn_out = (last_recurrent_state * q_t.unsqueeze(-1)).sum(dim=-2) - ggml_tensor * q_t_unsqueezed = ggml_reshape_4d(ctx0, q, 1, S_v, H_v, n_seqs); // unsqueeze q_t - ggml_tensor * state_q = ggml_mul(ctx0, state, q_t_unsqueezed); - // again, since it's over dim = -2, transpose, sum, transpose back - ggml_tensor * core_attn_out = - ggml_transpose(ctx0, ggml_sum_rows(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, state_q)))); - - // core_attn_out should be [S_v, 1, H_v, n_seqs] after this - cb(core_attn_out, "output_tokens", il); - cb(state, "new_state", il); - - return {core_attn_out, state}; -} - std::pair llm_build_qwen35moe::build_qkvz( ggml_tensor * input, int il) { @@ -561,9 +193,6 @@ ggml_tensor * llm_build_qwen35moe ::build_layer_attn( ggml_tensor * llm_build_qwen35moe ::build_layer_attn_linear( llm_graph_input_rs * inp, ggml_tensor * cur, - ggml_tensor * causal_mask, - ggml_tensor * identity, - ggml_tensor * diag_mask, int il) { const auto * mctx_cur = inp->mctx; @@ -689,7 +318,7 @@ ggml_tensor * llm_build_qwen35moe ::build_layer_attn_linear( if (n_seq_tokens == 1) { attn_out = build_delta_net_autoregressive(q_conv, k_conv, v_conv, gate, beta, state, il); } else { - attn_out = build_delta_net_chunking(q_conv, k_conv, v_conv, gate, beta, state, causal_mask, identity, diag_mask, il); + attn_out = build_delta_net_chunking(q_conv, k_conv, v_conv, gate, beta, state, il); } ggml_tensor * output = attn_out.first; ggml_tensor * new_state = attn_out.second; diff --git a/src/models/qwen3next.cpp b/src/models/qwen3next.cpp index aea8b29513..0fdf2d42c2 100644 --- a/src/models/qwen3next.cpp +++ b/src/models/qwen3next.cpp @@ -1,10 +1,9 @@ -#include "ggml.h" #include "models.h" -#define CHUNK_SIZE 64 +#include "llama-memory-recurrent.h" llm_build_qwen3next::llm_build_qwen3next(const llama_model & model, const llm_graph_params & params) : - llm_graph_context_mamba(params), model(model) { + llm_build_delta_net_base(params), model(model) { ggml_tensor * cur; ggml_tensor * inpL; @@ -83,326 +82,6 @@ static ggml_tensor * get_slice_2d(ggml_context * ctx0, ggml_tensor * t, int64_t t->nb[1], t->nb[2], t->nb[3], t->nb[2] * c); } -std::pair llm_build_qwen3next::build_delta_net_chunking( - ggml_tensor * q, - ggml_tensor * k, - ggml_tensor * v, - ggml_tensor * g, - ggml_tensor * b, - ggml_tensor * s, - int il) { - const int64_t S_k = q->ne[0]; - const int64_t H_k = q->ne[1]; - const int64_t n_tokens = q->ne[2]; - const int64_t n_seqs = q->ne[3]; - - const int64_t S_v = v->ne[0]; - const int64_t H_v = v->ne[1]; - - GGML_ASSERT(S_k == S_v); - GGML_ASSERT(H_v % H_k == 0); - - GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs); - GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs); - GGML_ASSERT(v->ne[0] == S_v && v->ne[1] == H_v && v->ne[2] == n_tokens && v->ne[3] == n_seqs); - - GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs); - GGML_ASSERT(b->ne[0] == H_v && b->ne[2] == n_tokens && b->ne[3] == n_seqs); - GGML_ASSERT(s->ne[0] == S_v && s->ne[1] == S_v && s->ne[2] == H_v && s->ne[3] == n_seqs); - - const float scale = 1.0f / sqrtf(S_k); - - q = ggml_scale(ctx0, q, scale); - - cb(q, "q_in", il); - cb(k, "k_in", il); - cb(v, "v_in", il); - cb(b, "b_in", il); - cb(g, "g_in", il); - - q = ggml_permute(ctx0, q, 0, 2, 1, 3); // [S_k, n_tokens, H_k, n_seqs] - k = ggml_permute(ctx0, k, 0, 2, 1, 3); // [S_k, n_tokens, H_k, n_seqs] - v = ggml_permute(ctx0, v, 0, 2, 1, 3); // [S_v, n_tokens, H_v, n_seqs] - g = ggml_permute(ctx0, g, 2, 1, 3, 0); // [ 1, n_tokens, H_v, n_seqs] - b = ggml_permute(ctx0, b, 2, 0, 1, 3); // [ 1, n_tokens, H_v, n_seqs] - - const int CS = CHUNK_SIZE; - - const int pad = (CS - n_tokens % CS) % CS; - const int n_chunks = (n_tokens + pad) / CS; - - q = ggml_pad(ctx0, q, 0, pad, 0, 0); - k = ggml_pad(ctx0, k, 0, pad, 0, 0); - v = ggml_pad(ctx0, v, 0, pad, 0, 0); - g = ggml_pad(ctx0, g, 0, pad, 0, 0); - b = ggml_pad(ctx0, b, 0, pad, 0, 0); - - ggml_tensor * v_b = ggml_mul(ctx0, v, b); - ggml_tensor * k_b = ggml_mul(ctx0, k, b); - - cb(v_b, "v_b", il); - cb(k_b, "k_b", il); - - q = ggml_reshape_4d(ctx0, q, S_k, CS, n_chunks, H_k * n_seqs); - k = ggml_reshape_4d(ctx0, k, S_k, CS, n_chunks, H_k * n_seqs); - k_b = ggml_reshape_4d(ctx0, k_b, S_k, CS, n_chunks, H_v * n_seqs); - v = ggml_reshape_4d(ctx0, v, S_v, CS, n_chunks, H_v * n_seqs); - v_b = ggml_reshape_4d(ctx0, v_b, S_v, CS, n_chunks, H_v * n_seqs); - - g = ggml_reshape_4d(ctx0, g, CS, 1, n_chunks, H_v * n_seqs); - b = ggml_reshape_4d(ctx0, b, 1, CS, n_chunks, H_v * n_seqs); - - // [CS, 1, n_chunks, H_v * n_seqs] - ggml_tensor * g_cs = ggml_cumsum(ctx0, g); - cb(g_cs, "g_cs", il); - - ggml_tensor * g_cs_i = g_cs; - ggml_tensor * g_cs_j = ggml_reshape_4d(ctx0, g_cs, 1, CS, n_chunks, H_v * n_seqs); - - g_cs_j = ggml_repeat_4d(ctx0, g_cs_j, CS, CS, n_chunks, H_v * n_seqs); - - // [CS, CS, n_chunks, H_v * n_seqs] - ggml_tensor * decay_mask; - decay_mask = ggml_sub(ctx0, g_cs_j, g_cs_i); - decay_mask = ggml_tri(ctx0, decay_mask, GGML_TRI_TYPE_LOWER_DIAG); - decay_mask = ggml_exp(ctx0, decay_mask); - cb(decay_mask, "decay_mask", il); - - // [CS, CS, n_chunks, H_k * n_seqs] - ggml_tensor * kb; - kb = ggml_mul_mat(ctx0, k, k_b); - kb = ggml_mul (ctx0, kb, decay_mask); - - // [CS, CS, n_chunks, H_k * n_seqs] - ggml_tensor * attn; - attn = ggml_tri(ctx0, kb, GGML_TRI_TYPE_LOWER); - - ggml_tensor * identity; - identity = ggml_view_1d(ctx0, attn, CS, 0); - identity = ggml_fill (ctx0, identity, 1.0f); - identity = ggml_diag (ctx0, identity); - - ggml_tensor * lhs = ggml_add(ctx0, attn, identity); - cb(lhs, "dnet_add_ch_lhs", il); - - attn = ggml_neg(ctx0, attn); - - ggml_tensor * lin_solve = ggml_solve_tri(ctx0, lhs, attn, true, true, false); - attn = ggml_add(ctx0, lin_solve, identity); - cb(attn, "dnet_add_ch_attn_solved", il); // [CS, CS, n_chunks, H_k * n_seqs] - - // [S_v, CS, n_chunks, H_v * n_seqs] - v = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, v_b)), attn); - - // [CS, 1, n_chunks, H_v * n_seqs] - ggml_tensor * g_exp = ggml_exp(ctx0, g_cs); - - k_b = ggml_cont(ctx0, ggml_transpose(ctx0, k_b)); - - // [CS, S_k, n_chunks, H_k * n_seqs] - ggml_tensor * kbg = ggml_mul(ctx0, k_b, g_exp); - cb(kbg, "k_beta_g_exp", il); - - // [S_k, CS, n_chunks, H_k * n_seqs] - ggml_tensor * k_cd = ggml_mul_mat(ctx0, kbg, attn); - cb(k_cd, "k_cumdecay", il); - - // [S_k, CS, n_chunks, H_k * n_seqs] - ggml_tensor * g_exp_t = ggml_transpose(ctx0, g_exp); - ggml_tensor * q_g_exp = ggml_mul(ctx0, q, g_exp_t); - - // [CS, CS, n_chunks, H_k * n_seqs] - ggml_tensor * kq = ggml_mul_mat(ctx0, k, q); - kq = ggml_mul(ctx0, kq, decay_mask); - kq = ggml_tri(ctx0, kq, GGML_TRI_TYPE_LOWER_DIAG); - cb(kq, "kq", il); - - // vectorized calculation of key_gdiff - // improved from the chunked version: - // g_last = torch.clamp(g_cum[:, :, -1], max=50.0).exp().unsqueeze(-1).unsqueeze(-1) - // g_diff = torch.clamp(g_cum[:, :, -1:] - g_cum, max=50.0).exp() - // key_gdiff = key * g_diff.unsqueeze(-1) - // kgdmulvnew = (key_gdiff).transpose(-1, -2) @ v_new - // last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew - - // get last element in g_cumsum along CS dimension (ne0) - // example: [[x, y, z, ..., last], ...] -> [[last], ...] - // [1, 1, n_chunks, H_v * n_seqs] - ggml_tensor * g_last = ggml_view_4d(ctx0, g_cs, 1, 1, g_cs->ne[2], g_cs->ne[3], - g_cs->nb[1], - g_cs->nb[2], - g_cs->nb[3], - ggml_row_size(g_cs->type, g_cs->ne[0] - 1)); - cb(g_last, "g_last", il); - - // TODO: remove this cont when CUDA supports non-cont unary ops - g_last = ggml_cont(ctx0, g_last); - - // [1, 1, n_chunks, H_v * n_seqs] - ggml_tensor * g_last_exp = ggml_exp(ctx0, g_last); - cb(g_last_exp, "g_last_exp", il); - - // [CS, 1, n_chunks, H_v * n_seqs] - ggml_tensor * g_diff = ggml_neg(ctx0, ggml_sub(ctx0, g_cs, g_last)); - cb(g_diff, "g_diff", il); - - ggml_tensor * g_diff_exp = ggml_exp(ctx0, g_diff); - ggml_tensor * g_diff_exp_t = ggml_transpose(ctx0, g_diff_exp); - - // [S_k, CS, n_chunks, H_v * n_seqs] - ggml_tensor * kg = ggml_mul(ctx0, k, g_diff_exp_t); - cb(kg, "key_gdiff", il); - - // [CS, S_k, n_chunks, H_v * n_seqs] - ggml_tensor * kg_t = ggml_cont(ctx0, ggml_transpose(ctx0, kg)); - cb(kg_t, "key_gdiff_t", il); - - ggml_tensor * s_t = ggml_transpose(ctx0, s); - s_t = ggml_cont_4d(ctx0, s_t, S_v, S_v, 1, H_v * n_seqs); - cb(s_t, "dnet_add_ch_state", il); - - // [CS, S_v, n_chunks, H_v * n_seqs] - ggml_tensor * v_t = ggml_cont(ctx0, ggml_transpose(ctx0, v)); - - for (int64_t chunk = 0; chunk < n_chunks; chunk++) { - ggml_tensor * ch_k_cd = get_slice_2d(ctx0, k_cd, chunk); // [S_k, CS, 1, H_k * n_seqs] - ggml_tensor * ch_v_t = get_slice_2d(ctx0, v_t, chunk); // [ CS, S_v, 1, H_v * n_seqs] - ggml_tensor * ch_kq = get_slice_2d(ctx0, kq, chunk); // [ CS, CS, 1, H_k * n_seqs] - ggml_tensor * ch_q_g_exp = get_slice_2d(ctx0, q_g_exp, chunk); // [S_k, CS, 1, H_k * n_seqs] - ggml_tensor * ch_kg_t = get_slice_2d(ctx0, kg_t, chunk); // [ CS, S_k, 1, H_v * n_seqs] - - // [CS, S_v, 1, H_v * n_seqs] - ggml_tensor * v_t_p = ggml_mul_mat(ctx0, ch_k_cd, s_t); - cb(v_t_p, "v_prime", il); - - // [CS, S_v, 1, H_v * n_seqs] - ggml_tensor * v_t_new = ggml_sub(ctx0, ch_v_t, v_t_p); - cb(v_t_new, "v_t_new", il); - - // [S_v, CS, 1, H_v * n_seqs] - ggml_tensor * v_attn = ggml_mul_mat(ctx0, v_t_new, ch_kq); - cb(v_attn, "v_attn", il); - - // [S_v, CS, 1, H_v * n_seqs] - ggml_tensor * attn_inter = ggml_mul_mat(ctx0, s_t, ch_q_g_exp); - cb(attn_inter, "attn_inter", il); - - // [S_v, CS, 1, H_v * n_seqs] - ggml_tensor * o_ch = ggml_add(ctx0, attn_inter, v_attn); - cb(o_ch, "dnet_add_ch_attn_out", il); - - v = ggml_set_inplace(ctx0, v, o_ch, v->nb[1], v->nb[2], v->nb[3], chunk * v->nb[2]); - - // kgdmulvnew = (key_gdiff).transpose(-1, -2) @ v_new - // TODO: head broadcast might not work here - probably will need a transpose - ggml_tensor * kgv = ggml_mul_mat(ctx0, ch_kg_t, v_t_new); // [S_k, S_v, 1, H_k * n_seqs] - - // last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew - ggml_tensor * ch_g_last_exp = get_slice_2d(ctx0, g_last_exp, chunk); - s_t = ggml_mul(ctx0, s_t, ch_g_last_exp); - s_t = ggml_add(ctx0, s_t, kgv); - cb(s_t, "dnet_add_ch_state", il); - } - - s_t = ggml_reshape_4d(ctx0, s_t, S_v, S_v, H_v, n_seqs); - - // truncate padded tokens - ggml_tensor * o = ggml_view_4d(ctx0, v, - S_v, n_tokens, H_v, n_seqs, - ggml_row_size(v->type, S_v), - ggml_row_size(v->type, S_v * CS * n_chunks), - ggml_row_size(v->type, S_v * CS * n_chunks * H_v), 0); - - o = ggml_permute (ctx0, o, 0, 2, 1, 3); // [S_v, H_v, n_tokens, n_seqs] - s = ggml_transpose(ctx0, s_t); // [S_v, S_v, H_v, n_seqs] - - return {o, s}; -} - -std::pair llm_build_qwen3next::build_delta_net_autoregressive( - ggml_tensor * q, - ggml_tensor * k, - ggml_tensor * v, - ggml_tensor * g, - ggml_tensor * b, // beta - ggml_tensor * s, // state - int il) { - const int64_t S_k = q->ne[0]; - const int64_t H_k = q->ne[1]; - const int64_t n_tokens = q->ne[2]; - const int64_t n_seqs = q->ne[3]; - - const int64_t S_v = v->ne[0]; - const int64_t H_v = v->ne[1]; - - GGML_ASSERT(n_tokens == 1); - - GGML_ASSERT(S_k == S_v); - GGML_ASSERT(H_v % H_k == 0); - - GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs); - GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs); - GGML_ASSERT(v->ne[0] == S_v && v->ne[1] == H_v && v->ne[2] == n_tokens && v->ne[3] == n_seqs); - - GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs); - GGML_ASSERT(b->ne[0] == H_v && b->ne[2] == n_tokens && b->ne[3] == n_seqs); - GGML_ASSERT(s->ne[0] == S_v && s->ne[1] == S_v && s->ne[2] == H_v && s->ne[3] == n_seqs); - - const float scale = 1.0f / sqrtf(S_k); - - q = ggml_scale(ctx0, q, scale); - - q = ggml_permute(ctx0, q, 0, 2, 1, 3); // [S_k, n_tokens, H_k, n_seqs] - k = ggml_permute(ctx0, k, 0, 2, 1, 3); // [S_k, n_tokens, H_k, n_seqs] - v = ggml_permute(ctx0, v, 0, 2, 1, 3); // [S_v, n_tokens, H_v, n_seqs] - - cb(q, "q_in", il); - cb(k, "k_in", il); - cb(v, "v_in", il); - cb(b, "b_in", il); - cb(g, "g_in", il); - - g = ggml_reshape_4d(ctx0, g, 1, 1, H_v, n_seqs); - b = ggml_reshape_4d(ctx0, b, 1, 1, H_v, n_seqs); - - // [S_v, S_v, H_v, n_seqs] - g = ggml_exp(ctx0, g); - s = ggml_mul(ctx0, s, g); - - ggml_tensor * s_t = ggml_cont(ctx0, ggml_transpose(ctx0, s)); - - // [1, S_v, H_v, n_seqs] - ggml_tensor * sk; - sk = ggml_mul (ctx0, s_t, k); - sk = ggml_sum_rows(ctx0, sk); - - // [S_v, 1, H_v, n_seqs] - ggml_tensor * d; - d = ggml_sub(ctx0, v, ggml_transpose(ctx0, sk)); - d = ggml_mul(ctx0, d, b); - - // [1, S_v, H_v, n_seqs] - ggml_tensor * d_t; - d_t = ggml_transpose(ctx0, d); - - // [S_v, S_v, H_v, n_seqs] - ggml_tensor * kd; - k = ggml_repeat(ctx0, k, s); - kd = ggml_mul (ctx0, k, d_t); - - s_t = ggml_add(ctx0, s_t, kd); - - cb(s_t, "dnet_add_ar_state", il); - - ggml_tensor * s_q = ggml_mul (ctx0, s_t, q); - ggml_tensor * o = ggml_sum_rows(ctx0, s_q); - - o = ggml_permute (ctx0, o, 2, 0, 1, 3); // [S_v, H_v, n_tokens, n_seqs] - s = ggml_transpose(ctx0, s_t); // [S_v, S_v, H_v, n_seqs] - - return {o, s}; -} - ggml_tensor * llm_build_qwen3next::build_norm_gated( ggml_tensor * input, ggml_tensor * weights, diff --git a/src/models/rwkv6-base.cpp b/src/models/rwkv6-base.cpp index 7beed2daff..83aeab7280 100644 --- a/src/models/rwkv6-base.cpp +++ b/src/models/rwkv6-base.cpp @@ -1,5 +1,7 @@ #include "models.h" +#include "llama-memory-recurrent.h" + llm_build_rwkv6_base::llm_build_rwkv6_base(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params), model(model) {} diff --git a/src/models/rwkv7-base.cpp b/src/models/rwkv7-base.cpp index cda4465384..7fcab77745 100644 --- a/src/models/rwkv7-base.cpp +++ b/src/models/rwkv7-base.cpp @@ -1,5 +1,7 @@ #include "models.h" +#include "llama-memory-recurrent.h" + llm_build_rwkv7_base::llm_build_rwkv7_base(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params), model(model) {}