Getting to decode stage...

Author: Piotr Wilkin
Date:   2025-09-18 21:47:40 +02:00
parent c78f9fce68
commit 178230ee21

2 changed files with 228 additions and 113 deletions


@@ -3435,7 +3435,7 @@ struct ggml_tensor * ggml_reshape_4d(
         int64_t               ne2,
         int64_t               ne3) {
     GGML_ASSERT(ggml_is_contiguous(a));
     GGML_ASSERT(ggml_nelements(a) == ne0*ne1*ne2*ne3);

     const int64_t ne[4] = { ne0, ne1, ne2, ne3 };
     struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 4, ne, a, 0);
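The ggml_is_contiguous assert above is why the graph-building changes further down wrap reshaped tensors in ggml_cont first. A minimal sketch of that pattern, with t standing in for any tensor that may be non-contiguous after a permute or view:

    // t may be non-contiguous (e.g. the result of ggml_permute or ggml_view_*)
    struct ggml_tensor * t_cont = ggml_cont(ctx, t);                                // materialize contiguous data
    struct ggml_tensor * t_4d   = ggml_reshape_4d(ctx, t_cont, ne0, ne1, ne2, ne3); // now safe to reshape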
@@ -5441,17 +5441,25 @@ struct ggml_tensor * ggml_delta_net(
     GGML_ASSERT(ggml_is_contiguous(beta));
     GGML_ASSERT(ggml_is_contiguous(state));

-    const int64_t S = k->ne[0];
-    const int64_t H = k->ne[1];
+    const int64_t S_k      = k->ne[0];
+    const int64_t H_k      = k->ne[1];
     const int64_t n_tokens = k->ne[2];
     const int64_t n_seqs   = state->ne[1];
+    const int64_t S_v      = v->ne[0];
+    const int64_t H_v      = v->ne[1];

-    // Validate dimensions
-    GGML_ASSERT(v->ne[0] == S && v->ne[1] == H && v->ne[2] == n_tokens);
-    GGML_ASSERT(q->ne[0] == S && q->ne[1] == H && q->ne[2] == n_tokens);
-    GGML_ASSERT(g->ne[0] == S && g->ne[1] == H && g->ne[2] == n_tokens);
-    GGML_ASSERT(beta->ne[0] == H && beta->ne[1] == n_tokens && beta->ne[2] == n_seqs);
-    GGML_ASSERT(ggml_nelements(state) == S * S * H * n_seqs);
+    // Validate dimensions - allow different head dimensions for q/k vs v
+    GGML_ASSERT(v->ne[2] == n_tokens);
+    GGML_ASSERT(q->ne[2] == n_tokens);
+    GGML_ASSERT(g->ne[2] == n_tokens);
+    GGML_ASSERT(beta->ne[0] == H_v && beta->ne[1] == n_tokens && (beta->ne[2] == n_seqs || beta->ne[2] == 1));
+    GGML_ASSERT(ggml_nelements(state) == S_v * H_v * n_seqs);
+
+    // Check that q and k have the same dimensions
+    GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens);
+    GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens);
+    GGML_ASSERT(g->ne[0] == S_v && g->ne[1] == H_v && g->ne[2] == n_tokens);

     // Apply L2 normalization to query and key if requested
     struct ggml_tensor * q_norm = q;
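For orientation, the shapes these relaxed asserts now admit look like this (purely illustrative numbers, not taken from any real model config):

    // q, k  : [S_k, H_k, n_tokens]         e.g. [ 96,  8, n_tokens]
    // v, g  : [S_v, H_v, n_tokens]         e.g. [128, 32, n_tokens]
    // beta  : [H_v, n_tokens, n_seqs]      e.g. [ 32, n_tokens, 1]
    // state : S_v * H_v * n_seqs elements  e.g. 128 * 32 * 1 = 4096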
@@ -5466,53 +5474,101 @@ struct ggml_tensor * ggml_delta_net(
     // Apply sigmoid to beta for gating
     struct ggml_tensor * beta_sigmoid = ggml_sigmoid(ctx, beta);

-    // Apply causal 1D convolution preprocessing to mixed QKV
-    // Concatenate q, k, v along the feature dimension
-    int64_t concat_ne[4] = { q->ne[0], q->ne[1], q->ne[2], q->ne[3] * 3 };
-    struct ggml_tensor * mixed_qkv = ggml_concat(ctx, q_norm, k_norm, 3);
-    mixed_qkv = ggml_concat(ctx, mixed_qkv, v, 3);
-
-    // Transpose for convolution: [S, H, n_tokens, n_seqs*3] -> [S, n_tokens, H, n_seqs*3]
-    mixed_qkv = ggml_permute(ctx, mixed_qkv, 0, 2, 1, 3);
-
-    // Apply causal 1D convolution
-    struct ggml_tensor * conv_out = ggml_conv_1d(
-        ctx,
-        conv_weight,
-        mixed_qkv,
-        1,                      // stride
-        conv_weight->ne[2] - 1, // padding (kernel_size - 1)
-        1                       // dilation
-    );
+    struct ggml_tensor * mixed_qkv = ggml_concat(ctx, q_norm, k_norm, 1);
+    mixed_qkv = ggml_concat(ctx, mixed_qkv, v, 1);
+
+    u_int32_t dim = (S_v * H_v) + 2 * (H_k * S_k);
+
+    mixed_qkv = ggml_reshape_3d(ctx, mixed_qkv, 1, dim, n_tokens);
+
+    struct ggml_tensor * mixed_qkv_padded = ggml_pad(ctx, mixed_qkv, 3, 0, 0, 0);
+
+    // Apply SSM convolution
+    struct ggml_tensor * conv_out = ggml_ssm_conv(ctx, mixed_qkv_padded, conv_weight);

     // Apply bias if provided
     if (conv_bias) {
         conv_out = ggml_add(ctx, conv_out, conv_bias);
     }

     // Apply SiLU activation
     conv_out = ggml_silu(ctx, conv_out);

-    // Transpose back: [S, n_tokens, H, n_seqs*3] -> [S, H, n_tokens, n_seqs*3]
+    // Reshape back to 4D: [dim, n_tokens, 1] -> [dim, n_tokens, 1, 1]
+    conv_out = ggml_reshape_4d(ctx, conv_out, dim, n_tokens, 1, 1);
+
+    // Transpose to get the right layout: [dim, n_tokens, 1] -> [dim, 1, n_tokens, 1]
     conv_out = ggml_permute(ctx, conv_out, 0, 2, 1, 3);

-    // Split the convolved output back into q, k, v components
-    // Split along the last dimension (3 * original size)
-    int64_t split_size = q->ne[3];
-    struct ggml_tensor * q_conv = ggml_view_4d(ctx, conv_out, q->ne[0], q->ne[1], q->ne[2], split_size,
-                                               conv_out->nb[0], conv_out->nb[1], conv_out->nb[2], 0);
-    struct ggml_tensor * k_conv = ggml_view_4d(ctx, conv_out, k->ne[0], k->ne[1], k->ne[2], split_size,
-                                               conv_out->nb[0], conv_out->nb[1], conv_out->nb[2],
-                                               split_size * ggml_type_size(q->type));
-    struct ggml_tensor * v_conv = ggml_view_4d(ctx, conv_out, v->ne[0], v->ne[1], v->ne[2], split_size,
-                                               conv_out->nb[0], conv_out->nb[1], conv_out->nb[2],
-                                               2 * split_size * ggml_type_size(q->type));
+    // q projection view
+    struct ggml_tensor * q_conv = ggml_view_4d(ctx, conv_out,
+        S_k,                              // ne0
+        H_k,                              // ne1
+        conv_out->ne[1],                  // ne2 = sequence length (1)
+        conv_out->ne[2],                  // ne3 = batch (1)
+        H_k * sizeof(float),              // nb1 = stride along H_k
+        conv_out->nb[1],                  // nb2 = stride along sequence dim
+        conv_out->nb[2],                  // nb3 = stride along batch dim
+        0                                 // offset in bytes
+    );
+
+    // k projection view
+    struct ggml_tensor * k_conv = ggml_view_4d(ctx, conv_out,
+        S_k,                              // ne0
+        H_k,                              // ne1
+        conv_out->ne[1],                  // ne2
+        conv_out->ne[2],                  // ne3
+        H_k * sizeof(float),              // nb1
+        conv_out->nb[1],                  // nb2
+        conv_out->nb[2],                  // nb3
+        S_k * H_k * sizeof(q->type)       // offset = skip q_out
+    );
+
+    // v projection view
+    struct ggml_tensor * v_conv = ggml_view_4d(ctx, conv_out,
+        S_v,                              // ne0
+        H_v,                              // ne1
+        conv_out->ne[1],                  // ne2
+        conv_out->ne[2],                  // ne3
+        H_v * sizeof(float),              // nb1
+        conv_out->nb[1],                  // nb2
+        conv_out->nb[2],                  // nb3
+        (2 * S_k * H_k) * sizeof(q->type) // offset = skip q_out + k_out
+    );
+
+    // Transpose each component back to original layout: [S_v, 1, token_split_size, 1] -> [S_v, token_split_size, 1, 1]
+    q_conv = ggml_permute(ctx, q_conv, 0, 2, 1, 3);
+    k_conv = ggml_permute(ctx, k_conv, 0, 2, 1, 3);
+    v_conv = ggml_permute(ctx, v_conv, 0, 2, 1, 3);
+
+    q_conv = ggml_reshape_3d(ctx, ggml_cont(ctx, q_conv), S_k * H_k, 1, n_tokens);
+    k_conv = ggml_reshape_3d(ctx, ggml_cont(ctx, k_conv), S_k * H_k, 1, n_tokens);
+    v_conv = ggml_reshape_3d(ctx, ggml_cont(ctx, v_conv), S_v * H_v, 1, n_tokens);
+
+    // NOW we repeat query and key to match value head dimensions if needed (after convolution)
+    struct ggml_tensor * q_broadcast = q_conv;
+    struct ggml_tensor * k_broadcast = k_conv;
+
+    if (H_k != H_v) {
+        // Calculate the repeat factor: H_v / H_k
+        GGML_ASSERT(H_v % H_k == 0);
+        int64_t repeat_factor = H_v / H_k;
+
+        // Repeat query and key along the head dimension
+        // First reshape to separate the repeat dimension: [S_k, H_k, n_tokens, 1] -> [S_k, 1, H_k, n_tokens]
+        q_broadcast = ggml_reshape_4d(ctx, q_conv, S_k, 1, H_k, n_tokens);
+        k_broadcast = ggml_reshape_4d(ctx, k_conv, S_k, 1, H_k, n_tokens);
+
+        // Repeat along the new dimension: [S_k, repeat_factor, H_k, n_tokens]
+        q_broadcast = ggml_repeat_4d(ctx, q_broadcast, S_k, repeat_factor, H_k, n_tokens);
+        k_broadcast = ggml_repeat_4d(ctx, k_broadcast, S_k, repeat_factor, H_k, n_tokens);
+
+        // Reshape back to original dimensions but with H_v heads: [S_k, H_v, n_tokens, 1]
+        q_broadcast = ggml_reshape_4d(ctx, q_broadcast, S_k, H_v, n_tokens, 1);
+        k_broadcast = ggml_reshape_4d(ctx, k_broadcast, S_k, H_v, n_tokens, 1);
+    }

     // concat output and new_state
-    const int64_t ne[4] = { S * H, n_tokens + S * n_seqs, 1, 1 };
+    const int64_t ne[4] = { S_v * H_v, n_tokens + H_v * n_seqs, 1, 1 };
     struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);

     // Set operation parameters for the delta rule computation
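The reshape/repeat/reshape sequence in this hunk is what lets a smaller set of q/k heads drive H_v value heads. A self-contained sketch of the same idea, assuming the input is already contiguous and H_v is a multiple of H_k (as the assert requires); the helper name is illustrative, not part of the commit:

    // Expand t from [S_k, H_k, n_tokens, 1] to [S_k, H_v, n_tokens, 1] by repeating each head H_v/H_k times.
    static struct ggml_tensor * repeat_heads(
            struct ggml_context * ctx, struct ggml_tensor * t,
            int64_t S_k, int64_t H_k, int64_t H_v, int64_t n_tokens) {
        const int64_t r = H_v / H_k;                            // repeat factor
        t = ggml_reshape_4d(ctx, t, S_k, 1, H_k, n_tokens);     // [S_k, 1, H_k, n_tokens]
        t = ggml_repeat_4d (ctx, t, S_k, r, H_k, n_tokens);     // [S_k, r, H_k, n_tokens]
        return ggml_reshape_4d(ctx, t, S_k, H_v, n_tokens, 1);  // [S_k, H_v, n_tokens, 1]
    }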
@@ -5520,15 +5576,15 @@ struct ggml_tensor * ggml_delta_net(
         chunk_size,
         use_qk_l2norm ? 1 : 0,
         0, 0,       // reserved
-        0, 0, 0, 0  // scale and other params
+        0, 0, 0     // scale and other params
     };
     memcpy(params + 4, &scale, sizeof(float));
     ggml_set_op_params(result, params, sizeof(params));

     // Use custom operation for the gated delta rule computation
     result->op     = GGML_OP_DELTA_NET;
-    result->src[0] = q_conv;
-    result->src[1] = k_conv;
+    result->src[0] = q_broadcast;
+    result->src[1] = k_broadcast;
     result->src[2] = v_conv;
     result->src[3] = g;
     result->src[4] = beta_sigmoid;
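Since the float scale is memcpy'd into slot 4 of the int32 op_params array, the matching read on the compute side would look roughly like this (a sketch of the usual ggml op_params idiom, not code from this commit):

    const int32_t chunk_size    = ((const int32_t *) dst->op_params)[0];
    const int32_t use_qk_l2norm = ((const int32_t *) dst->op_params)[1];
    float scale;
    memcpy(&scale, (const int32_t *) dst->op_params + 4, sizeof(float));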


@@ -19049,9 +19049,9 @@ private:
         cb(Kcur, "Kcur", il);
         cb(Vcur, "Vcur", il);

-        Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, hparams.n_head(il), n_tokens);
-        Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, hparams.n_head_kv(il), n_tokens);
-        Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, hparams.n_head_kv(il), n_tokens);
+        Qcur = ggml_reshape_3d(ctx0, ggml_cont(ctx0, Qcur), n_embd_head, hparams.n_head(il), n_tokens);
+        Kcur = ggml_reshape_3d(ctx0, ggml_cont(ctx0, Kcur), n_embd_head, hparams.n_head_kv(il), n_tokens);
+        Vcur = ggml_reshape_3d(ctx0, ggml_cont(ctx0, Vcur), n_embd_head, hparams.n_head_kv(il), n_tokens);

         // Apply Q/K normalization
         Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
@@ -19079,8 +19079,8 @@ private:
                 Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);

         // Apply gating
-        gate = ggml_reshape_2d(ctx0, gate, n_embd_q, n_tokens);
-        cur  = ggml_mul(ctx0, cur, ggml_sigmoid(ctx0, gate));
+        gate = ggml_reshape_2d(ctx0, ggml_cont(ctx0, gate), n_embd_q, n_tokens);
+        cur  = ggml_cont(ctx0, ggml_mul(ctx0, cur, ggml_sigmoid(ctx0, gate)));
         cb(cur, "attn_gated", il);

         return cur;
@@ -19096,59 +19096,102 @@ private:
         const auto kv_head = mctx_cur->get_head();

         const int64_t d_inner  = hparams.ssm_d_inner;
-        const int64_t d_state  = hparams.ssm_d_state;
         const int64_t n_heads  = hparams.ssm_dt_rank;
         const int64_t head_dim = d_inner / n_heads;
         const int64_t n_seqs   = ubatch.n_seqs;
+
+        const int64_t head_k_dim  = hparams.ssm_d_state;
+        const int64_t head_v_dim  = hparams.ssm_d_state;
+        const int64_t num_k_heads = hparams.ssm_n_group;
+        const int64_t num_v_heads = hparams.ssm_dt_rank;

         const int64_t n_seq_tokens = ubatch.n_seq_tokens;
+        const int64_t n_tokens     = ubatch.n_tokens;

         GGML_ASSERT(n_seqs != 0);
         GGML_ASSERT(ubatch.equal_seqs());
         GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);

-        // Input projection for QKV and beta/alpha
-        ggml_tensor * qkvz_ba = build_lora_mm(model.layers[il].ssm_in, cur);
-        cb(qkvz_ba, "linear_attn_in_proj", il);
-
-        // Split into QKV and beta/alpha components
-        const int64_t qkv_size = d_inner * 2 + d_state * 2;
-        ggml_tensor * qkv =
-            ggml_view_3d(ctx0, qkvz_ba, qkv_size, n_tokens, 1, qkv_size * sizeof(float), qkvz_ba->nb[1], 0);
-        ggml_tensor * ba = ggml_view_2d(ctx0, qkvz_ba, n_embd, n_tokens,
-                                        qkvz_ba->nb[1], qkv_size * sizeof(float));
-
-        // Reshape QKV for processing
-        qkv = ggml_reshape_3d(ctx0, qkv, head_dim, n_heads * 2 + d_state * 2 / head_dim, n_tokens);
-
-        // Split into individual components
-        ggml_tensor * query =
-            ggml_view_3d(ctx0, qkv, head_dim, n_heads, n_tokens, head_dim * sizeof(float), qkv->nb[1], 0);
-        ggml_tensor * key = ggml_view_3d(ctx0, qkv, head_dim, n_heads, n_tokens, head_dim * sizeof(float), qkv->nb[1],
-                                         n_heads * head_dim * sizeof(float));
-        ggml_tensor * value = ggml_view_3d(ctx0, qkv, head_dim, n_heads, n_tokens, head_dim * sizeof(float), qkv->nb[1],
-                                           n_heads * head_dim * 2 * sizeof(float));
-
-        // Process beta and alpha parameters (corrected dimensions)
-        ggml_tensor * beta_alpha = build_lora_mm(model.layers[il].ssm_beta_alpha, ba);
-        ggml_tensor * beta =
-            ggml_view_3d(ctx0, beta_alpha, n_heads, n_tokens, n_seqs, n_heads * sizeof(float), beta_alpha->nb[1], 0);
-        ggml_tensor * alpha = ggml_view_3d(ctx0, beta_alpha, n_heads, n_tokens, n_seqs, n_heads * sizeof(float),
-                                           beta_alpha->nb[1], n_heads * sizeof(float));
-
-        // Apply sigmoid to beta (exactly like reference: beta = b.sigmoid())
-        beta = ggml_sigmoid(ctx0, beta);
-
-        ggml_tensor * alpha_biased   = ggml_add(ctx0, alpha, model.layers[il].ssm_dt);  // a + dt_bias
-        ggml_tensor * alpha_exp      = ggml_exp(ctx0, alpha_biased);                    // exp(a + dt_bias)
-        ggml_tensor * one_tensor     = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);      // Create scalar tensor
-        one_tensor                   = ggml_exp(ctx0, one_tensor);                      // e^0 = 1
-        ggml_tensor * one_plus_exp   = ggml_add1(ctx0, alpha_exp, one_tensor);          // 1 + exp(a + dt_bias)
-        ggml_tensor * alpha_softplus = ggml_log(ctx0, one_plus_exp);                    // log(1 + exp(...))
-        ggml_tensor * A_log_exp      = ggml_exp(ctx0, model.layers[il].ssm_a);          // A_log.exp()
-        ggml_tensor * gate_scaled    = ggml_mul(ctx0, alpha_softplus, A_log_exp);       // A_log.exp() * softplus
-        ggml_tensor * gate           = ggml_neg(ctx0, gate_scaled);                     // - (A_log.exp() * softplus)
+        // Input projections
+        ggml_tensor * mixed_qkvz = build_lora_mm(model.layers[il].ssm_in, cur);
+        cb(mixed_qkvz, "linear_attn_mixed_qkvz", il);
+
+        ggml_tensor * mixed_ba = build_lora_mm(model.layers[il].ssm_beta_alpha, cur);
+        cb(mixed_ba, "linear_attn_mixed_ba", il);
+
+        // Reshape mixed_qkvz: [batch, seq_len, hidden_size] -> [batch, seq_len, num_k_heads, 2*head_k_dim + 2*head_v_dim*num_v_heads/num_k_heads]
+        int64_t qkvz_new_dim = 2 * head_k_dim + 2 * head_v_dim * num_v_heads / num_k_heads;
+        ggml_tensor * mixed_qkvz_reshaped =
+            ggml_reshape_4d(ctx0, mixed_qkvz, qkvz_new_dim, num_k_heads, n_tokens, n_seqs);
+
+        // Reshape mixed_ba: [batch, seq_len, hidden_size] -> [batch, seq_len, num_k_heads, 2*num_v_heads/num_k_heads]
+        int64_t ba_new_dim = 2 * num_v_heads / num_k_heads;
+        ggml_tensor * mixed_ba_reshaped = ggml_reshape_4d(ctx0, mixed_ba, ba_new_dim, num_k_heads, n_tokens, n_seqs);
+
+        // Split mixed_qkvz into query, key, value, z
+        int64_t split_sizes_qkvz[4] = {
+            head_k_dim,                             // query size
+            head_k_dim,                             // key size
+            head_v_dim * num_v_heads / num_k_heads, // value size
+            head_v_dim * num_v_heads / num_k_heads  // z size
+        };
+
+        ggml_tensor * query = ggml_cont(ctx0, ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[0], num_k_heads, n_tokens,
+                                                           n_seqs, split_sizes_qkvz[0] * sizeof(float), mixed_qkvz_reshaped->nb[1],
+                                                           mixed_qkvz_reshaped->nb[2], 0));
+
+        ggml_tensor * key = ggml_cont(ctx0, ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[1], num_k_heads, n_tokens, n_seqs,
+                                                         split_sizes_qkvz[1] * sizeof(float), mixed_qkvz_reshaped->nb[1],
+                                                         mixed_qkvz_reshaped->nb[2], split_sizes_qkvz[0] * sizeof(float)));
+
+        ggml_tensor * value =
+            ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[2], num_k_heads, n_tokens, n_seqs,
+                         split_sizes_qkvz[2] * sizeof(float), mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2],
+                         (split_sizes_qkvz[0] + split_sizes_qkvz[1]) * sizeof(float));
+
+        ggml_tensor * z =
+            ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[3], num_k_heads, n_tokens, n_seqs,
+                         split_sizes_qkvz[3] * sizeof(float), mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2],
+                         (split_sizes_qkvz[0] + split_sizes_qkvz[1] + split_sizes_qkvz[2]) * sizeof(float));
+
+        // Reshape value and z to merge head dimensions: [batch, seq_len, num_k_heads, head_v_dim*num_v_heads/num_k_heads] -> [batch, seq_len, num_v_heads, head_v_dim]
+        ggml_tensor * value_reshaped = ggml_reshape_4d(ctx0, ggml_cont(ctx0, value), head_v_dim, num_v_heads, n_tokens, n_seqs);
+        ggml_tensor * z_reshaped     = ggml_reshape_4d(ctx0, ggml_cont(ctx0, z), head_v_dim, num_v_heads, n_tokens, n_seqs);
+
+        GGML_ASSERT(ggml_nelements(query) + ggml_nelements(key) + ggml_nelements(value_reshaped) +
+                        ggml_nelements(z_reshaped) ==
+                    ggml_nelements(mixed_qkvz));
+
+        // Split mixed_ba into b and a (beta and alpha parameters)
+        int64_t split_sizes_ba[2] = {
+            num_v_heads / num_k_heads, // beta size
+            num_v_heads / num_k_heads  // alpha size
+        };
+
+        ggml_tensor * b =
+            ggml_view_4d(ctx0, mixed_ba_reshaped, split_sizes_ba[0], num_k_heads, n_tokens, n_seqs,
+                         split_sizes_ba[0] * sizeof(float), mixed_ba_reshaped->nb[1], mixed_ba_reshaped->nb[2], 0);
+        ggml_tensor * a = ggml_view_4d(ctx0, mixed_ba_reshaped, split_sizes_ba[1], num_k_heads, n_tokens, n_seqs,
+                                       split_sizes_ba[1] * sizeof(float), mixed_ba_reshaped->nb[1],
+                                       mixed_ba_reshaped->nb[2], split_sizes_ba[0] * sizeof(float));
+
+        // Reshape b and a to merge head dimensions: [batch, seq_len, num_k_heads, num_v_heads/num_k_heads] -> [batch, seq_len, num_v_heads]
+        ggml_tensor * beta  = ggml_reshape_3d(ctx0, ggml_cont(ctx0, b), num_v_heads, n_tokens, n_seqs);
+        ggml_tensor * alpha = ggml_reshape_3d(ctx0, ggml_cont(ctx0, a), num_v_heads, n_tokens, n_seqs);
+
+        GGML_ASSERT(ggml_nelements(beta) + ggml_nelements(alpha) == ggml_nelements(mixed_ba));
+
+        // Softplus would be nice...
+        ggml_tensor * alpha_biased   = ggml_add(ctx0, alpha, model.layers[il].ssm_dt);  // a + dt_bias
+        ggml_tensor * alpha_exp      = ggml_exp(ctx0, alpha_biased);                    // exp(a + dt_bias)
+        ggml_tensor * one_tensor     = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);      // Create scalar tensor
+        ggml_exp(ctx0, one_tensor);                                                     // make it a 1
+        ggml_tensor * one_plus_exp   = ggml_add1(ctx0, alpha_exp, one_tensor);          // 1 + exp(a + dt_bias)
+        ggml_tensor * alpha_softplus = ggml_log(ctx0, one_plus_exp);                    // log(1 + exp(...))
+        ggml_tensor * A_log_exp      = ggml_exp(ctx0, model.layers[il].ssm_a);          // A_log.exp()
+        ggml_tensor * gate_scaled    = ggml_mul(ctx0, alpha_softplus, A_log_exp);       // A_log.exp() * softplus
+        ggml_tensor * gate           = ggml_neg(ctx0, gate_scaled);                     // - (A_log.exp() * softplus)

         // Get convolution weights and bias
         ggml_tensor * conv_weight = model.layers[il].ssm_conv1d;
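What the "Softplus would be nice..." block assembles from primitives is the gate described in its own comments, written out as equations:

    softplus(x) = log(1 + exp(x))
    gate        = -exp(A_log) * softplus(a + dt_bias)

i.e. ggml_exp, ggml_add1, ggml_log and ggml_neg only emulate a missing softplus op, with one_tensor standing in for the constant 1.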
@@ -19157,12 +19200,6 @@ private:
         // Get recurrent states (conv_states not needed as it's handled internally by ggml_delta_net)
         ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il);

-        // Reshape tensors to match ggml_delta_net expectations
-        // [S, H, n_tokens, n_seqs] format
-        query = ggml_reshape_4d(ctx0, query, head_dim, n_heads, n_tokens, n_seqs);
-        key   = ggml_reshape_4d(ctx0, key, head_dim, n_heads, n_tokens, n_seqs);
-        value = ggml_reshape_4d(ctx0, value, head_dim, n_heads, n_tokens, n_seqs);
-
         // Beta tensor
         beta = ggml_reshape_3d(ctx0, beta, n_heads, n_tokens, n_seqs);
@@ -19170,22 +19207,25 @@ private:
         ggml_tensor * state = ggml_view_4d(ctx0, ssm_states_all, head_dim, head_dim, n_heads, n_seqs,
                                            ssm_states_all->nb[0], ssm_states_all->nb[1], ssm_states_all->nb[2],
                                            kv_head * head_dim * head_dim * n_heads * ggml_element_size(ssm_states_all));
         state = ggml_cont(ctx0, state);

-        gate = ggml_repeat(ctx0, gate, ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, 1, n_heads, n_tokens, n_seqs));
+        ggml_tensor * target_gate    = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, head_dim, n_heads, n_tokens, n_seqs);
+        ggml_tensor * gate_broadcast = ggml_reshape_4d(ctx0, gate, 1, n_heads, n_tokens, n_seqs);
+        gate = ggml_repeat(ctx0, gate_broadcast, target_gate);

-        // Call the new ggml_delta_net function
+        // Call the new ggml_delta_net function with the corrected flow
         ggml_tensor * output = ggml_delta_net(ctx0,
                                               key,            // k tensor
-                                              value,          // v tensor
+                                              value_reshaped, // v tensor
                                               query,          // q tensor
                                               gate,           // g tensor
                                               conv_weight,    // conv_weight tensor
                                               conv_bias,      // conv_bias tensor (can be nullptr)
                                               beta,           // beta tensor
                                               state,          // state tensor
                                               64,             // chunk_size (adjust as needed)
                                               true,           // use_qk_l2norm
                                               1.0f            // scale (adjust based on your model)
         );
         cb(output, "delta_net_output", il);
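ggml_repeat takes its target shape from its second argument, so target_gate above is allocated purely as a shape template; the per-head gate is tiled across head_dim. A minimal restatement of that idiom with explanatory comments:

    // Shape-template idiom: the second tensor only supplies the target shape.
    // out[d, h, t, s] = gate_broadcast[0, h, t, s] for every d in [0, head_dim)
    ggml_tensor * tpl      = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, head_dim, n_heads, n_tokens, n_seqs);
    ggml_tensor * gate_rep = ggml_repeat(ctx0, gate_broadcast, tpl);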
@@ -19205,18 +19245,37 @@ private:
                                 ctx0, ssm_states_all, head_dim * head_dim * n_heads * n_seqs,
                                 kv_head * n_seqs * head_dim * head_dim * n_heads * ggml_element_size(ssm_states_all))));

-        // Apply normalization and gating
-        attn_out = build_norm(attn_out, model.layers[il].ssm_norm, NULL, LLM_NORM_RMS, il);
+        // Reshape both attn_out and z to 2D tensors for normalization
+        // attn_out: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim]
+        ggml_tensor * attn_out_2d = ggml_reshape_2d(ctx0, ggml_cont(ctx0, attn_out), head_dim, n_heads * n_tokens * n_seqs);
+
+        // z: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim]
+        ggml_tensor * z_2d = ggml_reshape_2d(ctx0, z_reshaped, head_dim, n_heads * n_tokens * n_seqs);
+
+        // Apply gated normalization: self.norm(core_attn_out, z)
+        // This is Qwen3NextRMSNormGated which applies: RMSNorm(x) * silu(gate)
+        ggml_tensor * attn_out_norm = build_norm(attn_out_2d, model.layers[il].ssm_norm, NULL, LLM_NORM_RMS, il);
+
+        // Apply silu gate: attn_out_norm * silu(z_2d)
+        ggml_tensor * z_silu       = ggml_silu(ctx0, z_2d);
+        ggml_tensor * gated_output = ggml_mul(ctx0, attn_out_norm, z_silu);
+
+        // Reshape back to original dimensions: [n_heads * n_tokens * n_seqs, head_dim] -> [head_dim, n_heads, n_tokens, n_seqs]
+        ggml_tensor * gated_output_4d = ggml_reshape_4d(ctx0, gated_output, head_dim, n_heads, n_tokens, n_seqs);
+
+        // Final reshape: [head_dim, n_heads, n_tokens, n_seqs] -> [n_tokens, n_seqs, n_heads * head_dim]
+        ggml_tensor * final_output = ggml_reshape_3d(ctx0, gated_output_4d, n_heads * head_dim, n_tokens, n_seqs);

         // Output projection
-        cur = build_lora_mm(model.layers[il].wo, attn_out);
+        cur = build_lora_mm(model.layers[il].ssm_out, final_output);
         cb(cur, "linear_attn_out", il);

         // Reshape back to original dimensions
-        cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens);
+        cur = ggml_cont(ctx0, ggml_reshape_2d(ctx0, cur, n_embd, n_tokens));

         return cur;
     }

     ggml_tensor * build_layer_ffn(ggml_tensor * cur, const llama_model & model, const int il) {
         // Check if this is an MoE layer
         if (model.layers[il].ffn_gate_inp != nullptr) {
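The gated normalization in the last hunk above is described in its comments as RMSNorm(x) * silu(gate). A standalone sketch of that combination from plain ggml ops; norm_w and eps are placeholders for the layer's norm weight and RMS epsilon, and the helper name is illustrative only:

    // y = rms_norm(x) * norm_w * silu(z), applied row-wise over the first dimension
    static ggml_tensor * rms_norm_gated(
            ggml_context * ctx, ggml_tensor * x, ggml_tensor * z,
            ggml_tensor * norm_w, float eps) {
        ggml_tensor * y = ggml_rms_norm(ctx, x, eps);   // normalize each row of x
        y = ggml_mul(ctx, y, norm_w);                   // learned per-channel scale (broadcast over rows)
        return ggml_mul(ctx, y, ggml_silu(ctx, z));     // gate with silu(z)
    }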