diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index 0b171ffd31..ff3550a2d3 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -3435,7 +3435,7 @@ struct ggml_tensor * ggml_reshape_4d(
         int64_t ne2,
         int64_t ne3) {
     GGML_ASSERT(ggml_is_contiguous(a));
-    GGML_ASSERT(ggml_nelements(a) == ne0*ne1*ne2*ne3);
+    GGML_ASSERT(ggml_nelements(a) == ne0*ne1*ne2*ne3);
 
     const int64_t ne[4] = { ne0, ne1, ne2, ne3 };
     struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, 4, ne, a, 0);
@@ -5441,17 +5441,25 @@ struct ggml_tensor * ggml_delta_net(
     GGML_ASSERT(ggml_is_contiguous(beta));
     GGML_ASSERT(ggml_is_contiguous(state));
 
-    const int64_t S = k->ne[0];
-    const int64_t H = k->ne[1];
+    const int64_t S_k = k->ne[0];
+    const int64_t H_k = k->ne[1];
     const int64_t n_tokens = k->ne[2];
     const int64_t n_seqs = state->ne[1];
 
-    // Validate dimensions
-    GGML_ASSERT(v->ne[0] == S && v->ne[1] == H && v->ne[2] == n_tokens);
-    GGML_ASSERT(q->ne[0] == S && q->ne[1] == H && q->ne[2] == n_tokens);
-    GGML_ASSERT(g->ne[0] == S && g->ne[1] == H && g->ne[2] == n_tokens);
-    GGML_ASSERT(beta->ne[0] == H && beta->ne[1] == n_tokens && beta->ne[2] == n_seqs);
-    GGML_ASSERT(ggml_nelements(state) == S * S * H * n_seqs);
+    const int64_t S_v = v->ne[0];
+    const int64_t H_v = v->ne[1];
+
+    // Validate dimensions - allow different head dimensions for q/k vs v
+    GGML_ASSERT(v->ne[2] == n_tokens);
+    GGML_ASSERT(q->ne[2] == n_tokens);
+    GGML_ASSERT(g->ne[2] == n_tokens);
+    GGML_ASSERT(beta->ne[0] == H_v && beta->ne[1] == n_tokens && (beta->ne[2] == n_seqs || beta->ne[2] == 1));
+    GGML_ASSERT(ggml_nelements(state) == S_k * S_v * H_v * n_seqs);
+
+    // q and k must agree on S_k/H_k; g matches v's shape
+    GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens);
+    GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens);
+    GGML_ASSERT(g->ne[0] == S_v && g->ne[1] == H_v && g->ne[2] == n_tokens);
 
     // Apply L2 normalization to query and key if requested
     struct ggml_tensor * q_norm = q;
@@ -5466,53 +5474,101 @@ struct ggml_tensor * ggml_delta_net(
 
     // Apply sigmoid to beta for gating
     struct ggml_tensor * beta_sigmoid = ggml_sigmoid(ctx, beta);
-
-    // Apply causal 1D convolution preprocessing to mixed QKV
-    // Concatenate q, k, v along the feature dimension
-    int64_t concat_ne[4] = { q->ne[0], q->ne[1], q->ne[2], q->ne[3] * 3 };
-    struct ggml_tensor * mixed_qkv = ggml_concat(ctx, q_norm, k_norm, 3);
-    mixed_qkv = ggml_concat(ctx, mixed_qkv, v, 3);
-
-    // Transpose for convolution: [S, H, n_tokens, n_seqs*3] -> [S, n_tokens, H, n_seqs*3]
-    mixed_qkv = ggml_permute(ctx, mixed_qkv, 0, 2, 1, 3);
-
-    // Apply causal 1D convolution
-    struct ggml_tensor * conv_out = ggml_conv_1d(
-        ctx,
-        conv_weight,
-        mixed_qkv,
-        1,                      // stride
-        conv_weight->ne[2] - 1, // padding (kernel_size - 1)
-        1                       // dilation
-    );
-
+    struct ggml_tensor * mixed_qkv = ggml_concat(ctx, q_norm, k_norm, 1);
+    mixed_qkv = ggml_concat(ctx, mixed_qkv, v, 1);
+
+    const int64_t dim = (S_v * H_v) + 2 * (H_k * S_k);
+
+    mixed_qkv = ggml_reshape_3d(ctx, mixed_qkv, 1, dim, n_tokens);
+    struct ggml_tensor * mixed_qkv_padded = ggml_pad(ctx, mixed_qkv, 3, 0, 0, 0);
+
+    // Apply SSM convolution
+    struct ggml_tensor * conv_out = ggml_ssm_conv(ctx, mixed_qkv_padded, conv_weight);
+
     // Apply bias if provided
     if (conv_bias) {
         conv_out = ggml_add(ctx, conv_out, conv_bias);
    }
-
+
     // Apply SiLU activation
     conv_out = ggml_silu(ctx, conv_out);
-
-    // Transpose back: [S, n_tokens, H, n_seqs*3] -> [S, H, n_tokens, n_seqs*3]
+
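+    // Note: ggml_ssm_conv expects input shaped [d_conv - 1 + n_t, d_inner, n_seqs] and
+    // takes the kernel width d_conv from conv_weight->ne[0]; with the [1, dim, n_tokens]
+    // reshape plus a pad of 3, each token is convolved as its own length-1 window
+    // (assuming d_conv == 4 here).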
+    // Reshape back to 4D: [dim, n_tokens, 1] -> [dim, n_tokens, 1, 1]
+    conv_out = ggml_reshape_4d(ctx, conv_out, dim, n_tokens, 1, 1);
+
+    // Permute to get the right layout: [dim, n_tokens, 1, 1] -> [dim, 1, n_tokens, 1]
     conv_out = ggml_permute(ctx, conv_out, 0, 2, 1, 3);
+
+    // q projection view
+    struct ggml_tensor * q_conv = ggml_view_4d(ctx, conv_out,
+        S_k,                             // ne0
+        H_k,                             // ne1
+        conv_out->ne[1],                 // ne2 (= 1 after the permute)
+        conv_out->ne[2],                 // ne3 (= n_tokens)
+        S_k * sizeof(float),             // nb1 = stride along H_k (S_k packed floats per head)
+        conv_out->nb[1],                 // nb2 = stride along the singleton dim
+        conv_out->nb[2],                 // nb3 = stride along the token dim
+        0                                // offset in bytes
+    );
+
+    // k projection view
+    struct ggml_tensor * k_conv = ggml_view_4d(ctx, conv_out,
+        S_k,                             // ne0
+        H_k,                             // ne1
+        conv_out->ne[1],                 // ne2
+        conv_out->ne[2],                 // ne3
+        S_k * sizeof(float),             // nb1
+        conv_out->nb[1],                 // nb2
+        conv_out->nb[2],                 // nb3
+        S_k * H_k * sizeof(float)        // offset = skip q_conv
+    );
+
+    // v projection view
+    struct ggml_tensor * v_conv = ggml_view_4d(ctx, conv_out,
+        S_v,                             // ne0
+        H_v,                             // ne1
+        conv_out->ne[1],                 // ne2
+        conv_out->ne[2],                 // ne3
+        S_v * sizeof(float),             // nb1
+        conv_out->nb[1],                 // nb2
+        conv_out->nb[2],                 // nb3
+        (2 * S_k * H_k) * sizeof(float)  // offset = skip q_conv + k_conv
+    );
+
+    // Permute each component back: [S, H, 1, n_tokens] -> [S, 1, H, n_tokens]
+    q_conv = ggml_permute(ctx, q_conv, 0, 2, 1, 3);
+    k_conv = ggml_permute(ctx, k_conv, 0, 2, 1, 3);
+    v_conv = ggml_permute(ctx, v_conv, 0, 2, 1, 3);
+
+    q_conv = ggml_reshape_3d(ctx, ggml_cont(ctx, q_conv), S_k * H_k, 1, n_tokens);
+    k_conv = ggml_reshape_3d(ctx, ggml_cont(ctx, k_conv), S_k * H_k, 1, n_tokens);
+    v_conv = ggml_reshape_3d(ctx, ggml_cont(ctx, v_conv), S_v * H_v, 1, n_tokens);
 
-    // Split the convolved output back into q, k, v components
-    // Split along the last dimension (3 * original size)
-    int64_t split_size = q->ne[3];
-    struct ggml_tensor * q_conv = ggml_view_4d(ctx, conv_out, q->ne[0], q->ne[1], q->ne[2], split_size,
-                                               conv_out->nb[0], conv_out->nb[1], conv_out->nb[2], 0);
+    // NOW we repeat query and key to match value head dimensions if needed (after convolution)
+    struct ggml_tensor * q_broadcast = q_conv;
 
-    struct ggml_tensor * k_conv = ggml_view_4d(ctx, conv_out, k->ne[0], k->ne[1], k->ne[2], split_size,
-                                               conv_out->nb[0], conv_out->nb[1], conv_out->nb[2],
-                                               split_size * ggml_type_size(q->type));
-
-    struct ggml_tensor * v_conv = ggml_view_4d(ctx, conv_out, v->ne[0], v->ne[1], v->ne[2], split_size,
-                                               conv_out->nb[0], conv_out->nb[1], conv_out->nb[2],
-                                               2 * split_size * ggml_type_size(q->type));
+    struct ggml_tensor * k_broadcast = k_conv;
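+    // Example with hypothetical sizes: H_k = 8 and H_v = 16 give repeat_factor = 2, and
+    // q/k head h is duplicated into v-heads 2h and 2h+1 ([h0, h0, h1, h1, ...]) - the
+    // same interleaved repeat that grouped-query attention broadcasting uses.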
+    if (H_k != H_v) {
+        // Calculate the repeat factor: H_v / H_k
+        GGML_ASSERT(H_v % H_k == 0);
+        int64_t repeat_factor = H_v / H_k;
+
+        // Repeat query and key along the head dimension.
+        // First reshape to make room for the repeat: [S_k * H_k, 1, n_tokens] -> [S_k, 1, H_k, n_tokens]
+        q_broadcast = ggml_reshape_4d(ctx, q_conv, S_k, 1, H_k, n_tokens);
+        k_broadcast = ggml_reshape_4d(ctx, k_conv, S_k, 1, H_k, n_tokens);
+
+        // Repeat along the new dimension: [S_k, repeat_factor, H_k, n_tokens]
+        q_broadcast = ggml_repeat_4d(ctx, q_broadcast, S_k, repeat_factor, H_k, n_tokens);
+        k_broadcast = ggml_repeat_4d(ctx, k_broadcast, S_k, repeat_factor, H_k, n_tokens);
+
+        // Reshape back to the original layout but with H_v heads: [S_k, H_v, n_tokens, 1]
+        q_broadcast = ggml_reshape_4d(ctx, q_broadcast, S_k, H_v, n_tokens, 1);
+        k_broadcast = ggml_reshape_4d(ctx, k_broadcast, S_k, H_v, n_tokens, 1);
+    }
 
     // concat output and new_state
-    const int64_t ne[4] = { S * H, n_tokens + S * n_seqs, 1, 1 };
+    const int64_t ne[4] = { S_v * H_v, n_tokens + S_k * n_seqs, 1, 1 };
     struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
 
     // Set operation parameters for the delta rule computation
@@ -5520,15 +5576,15 @@ struct ggml_tensor * ggml_delta_net(
         chunk_size,
         use_qk_l2norm ? 1 : 0,
         0, 0,      // reserved
-        0, 0, 0, 0 // scale and other params
+        0, 0, 0    // scale and other params
     };
     memcpy(params + 4, &scale, sizeof(float));
     ggml_set_op_params(result, params, sizeof(params));
 
     // Use custom operation for the gated delta rule computation
     result->op     = GGML_OP_DELTA_NET;
-    result->src[0] = q_conv;
-    result->src[1] = k_conv;
+    result->src[0] = q_broadcast;
+    result->src[1] = k_broadcast;
     result->src[2] = v_conv;
     result->src[3] = g;
     result->src[4] = beta_sigmoid;
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index b4fb644918..acd7ed8e31 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -19049,9 +19049,9 @@ private:
         cb(Kcur, "Kcur", il);
         cb(Vcur, "Vcur", il);
 
-        Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, hparams.n_head(il), n_tokens);
-        Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, hparams.n_head_kv(il), n_tokens);
-        Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, hparams.n_head_kv(il), n_tokens);
+        Qcur = ggml_reshape_3d(ctx0, ggml_cont(ctx0, Qcur), n_embd_head, hparams.n_head(il), n_tokens);
+        Kcur = ggml_reshape_3d(ctx0, ggml_cont(ctx0, Kcur), n_embd_head, hparams.n_head_kv(il), n_tokens);
+        Vcur = ggml_reshape_3d(ctx0, ggml_cont(ctx0, Vcur), n_embd_head, hparams.n_head_kv(il), n_tokens);
 
         // Apply Q/K normalization
         Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
@@ -19079,8 +19079,8 @@ private:
             Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
 
         // Apply gating
-        gate = ggml_reshape_2d(ctx0, gate, n_embd_q, n_tokens);
-        cur  = ggml_mul(ctx0, cur, ggml_sigmoid(ctx0, gate));
+        gate = ggml_reshape_2d(ctx0, ggml_cont(ctx0, gate), n_embd_q, n_tokens);
+        cur  = ggml_cont(ctx0, ggml_mul(ctx0, cur, ggml_sigmoid(ctx0, gate)));
         cb(cur, "attn_gated", il);
 
         return cur;
@@ -19096,59 +19096,102 @@ private:
         const auto kv_head = mctx_cur->get_head();
 
         const int64_t d_inner  = hparams.ssm_d_inner;
-        const int64_t d_state  = hparams.ssm_d_state;
         const int64_t n_heads  = hparams.ssm_dt_rank;
         const int64_t head_dim = d_inner / n_heads;
         const int64_t n_seqs   = ubatch.n_seqs;
+        const int64_t head_k_dim   = hparams.ssm_d_state;
+        const int64_t head_v_dim   = hparams.ssm_d_state;
+        const int64_t num_k_heads  = hparams.ssm_n_group;
+        const int64_t num_v_heads  = hparams.ssm_dt_rank;
+        const int64_t n_seq_tokens = ubatch.n_seq_tokens;
+        const int64_t n_tokens     = ubatch.n_tokens;
 
         GGML_ASSERT(n_seqs != 0);
         GGML_ASSERT(ubatch.equal_seqs());
         GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
 
-        // Input projection for QKV and beta/alpha
-        ggml_tensor * qkvz_ba = build_lora_mm(model.layers[il].ssm_in, cur);
-        cb(qkvz_ba, "linear_attn_in_proj", il);
+        // Input projections
+        ggml_tensor * mixed_qkvz = build_lora_mm(model.layers[il].ssm_in, cur);
+        cb(mixed_qkvz, "linear_attn_mixed_qkvz", il);
 
-        // Split into QKV and beta/alpha components
-        const int64_t qkv_size = d_inner * 2 + d_state * 2;
+        ggml_tensor * mixed_ba = build_lora_mm(model.layers[il].ssm_beta_alpha, cur);
+        cb(mixed_ba, "linear_attn_mixed_ba", il);
 
-        ggml_tensor * qkv =
-            ggml_view_3d(ctx0, qkvz_ba, qkv_size, n_tokens, 1, qkv_size * sizeof(float), qkvz_ba->nb[1], 0);
-        ggml_tensor * ba  = ggml_view_2d(ctx0, qkvz_ba, n_embd, n_tokens,
-                                         qkvz_ba->nb[1], qkv_size * sizeof(float));
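+        // Layout assumption (mirroring the reference model's fused in_proj): per k-head
+        // group, ssm_in packs [ q (head_k_dim) | k (head_k_dim) | v (head_v_dim * r) |
+        // z (head_v_dim * r) ] with r = num_v_heads / num_k_heads; the views below
+        // slice these components off dim 0.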
 
+        // Reshape mixed_qkvz: [batch, seq_len, hidden_size] -> [batch, seq_len, num_k_heads, 2*head_k_dim + 2*head_v_dim*num_v_heads/num_k_heads]
+        int64_t qkvz_new_dim = 2 * head_k_dim + 2 * head_v_dim * num_v_heads / num_k_heads;
+        ggml_tensor * mixed_qkvz_reshaped =
+            ggml_reshape_4d(ctx0, mixed_qkvz, qkvz_new_dim, num_k_heads, n_tokens, n_seqs);
 
-        // Reshape QKV for processing
-        qkv = ggml_reshape_3d(ctx0, qkv, head_dim, n_heads * 2 + d_state * 2 / head_dim, n_tokens);
+        // Reshape mixed_ba: [batch, seq_len, hidden_size] -> [batch, seq_len, num_k_heads, 2*num_v_heads/num_k_heads]
+        int64_t ba_new_dim = 2 * num_v_heads / num_k_heads;
+        ggml_tensor * mixed_ba_reshaped = ggml_reshape_4d(ctx0, mixed_ba, ba_new_dim, num_k_heads, n_tokens, n_seqs);
 
-        // Split into individual components
-        ggml_tensor * query =
-            ggml_view_3d(ctx0, qkv, head_dim, n_heads, n_tokens, head_dim * sizeof(float), qkv->nb[1], 0);
-        ggml_tensor * key = ggml_view_3d(ctx0, qkv, head_dim, n_heads, n_tokens, head_dim * sizeof(float), qkv->nb[1],
-                                         n_heads * head_dim * sizeof(float));
-        ggml_tensor * value = ggml_view_3d(ctx0, qkv, head_dim, n_heads, n_tokens, head_dim * sizeof(float), qkv->nb[1],
-                                           n_heads * head_dim * 2 * sizeof(float));
+        // Split mixed_qkvz into query, key, value, z
+        int64_t split_sizes_qkvz[4] = {
+            head_k_dim,                             // query size
+            head_k_dim,                             // key size
+            head_v_dim * num_v_heads / num_k_heads, // value size
+            head_v_dim * num_v_heads / num_k_heads  // z size
+        };
 
-        // Process beta and alpha parameters (corrected dimensions)
-        ggml_tensor * beta_alpha = build_lora_mm(model.layers[il].ssm_beta_alpha, ba);
-        ggml_tensor * beta =
-            ggml_view_3d(ctx0, beta_alpha, n_heads, n_tokens, n_seqs, n_heads * sizeof(float), beta_alpha->nb[1], 0);
-        ggml_tensor * alpha = ggml_view_3d(ctx0, beta_alpha, n_heads, n_tokens, n_seqs, n_heads * sizeof(float),
-                                           beta_alpha->nb[1], n_heads * sizeof(float));
+        ggml_tensor * query = ggml_cont(ctx0, ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[0], num_k_heads, n_tokens,
+                                        n_seqs, mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2],
+                                        mixed_qkvz_reshaped->nb[3], 0));
 
-        // Apply sigmoid to beta (exactly like reference: beta = b.sigmoid())
-        beta = ggml_sigmoid(ctx0, beta);
+        ggml_tensor * key = ggml_cont(ctx0, ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[1], num_k_heads, n_tokens, n_seqs,
+                                      mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2],
+                                      mixed_qkvz_reshaped->nb[3], split_sizes_qkvz[0] * sizeof(float)));
 
-        ggml_tensor * alpha_biased   = ggml_add(ctx0, alpha, model.layers[il].ssm_dt);    // a + dt_bias
-        ggml_tensor * alpha_exp      = ggml_exp(ctx0, alpha_biased);                      // exp(a + dt_bias)
-        ggml_tensor * one_tensor     = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);        // Create scalar tensor
-        one_tensor                   = ggml_exp(ctx0, one_tensor);                        // e^0 = 1
-        ggml_tensor * one_plus_exp   = ggml_add1(ctx0, alpha_exp, one_tensor);            // 1 + exp(a + dt_bias)
-        ggml_tensor * alpha_softplus = ggml_log(ctx0, one_plus_exp);                      // log(1 + exp(...))
-        ggml_tensor * A_log_exp      = ggml_exp(ctx0, model.layers[il].ssm_a);            // A_log.exp()
-        ggml_tensor * gate_scaled    = ggml_mul(ctx0, alpha_softplus, A_log_exp);         // A_log.exp() * softplus
-        ggml_tensor * gate           = ggml_neg(ctx0, gate_scaled);                       // - (A_log.exp() * softplus)
+        ggml_tensor * value =
+            ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[2], num_k_heads, n_tokens, n_seqs,
+                         mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3],
+                         (split_sizes_qkvz[0] + split_sizes_qkvz[1]) * sizeof(float));
+
+        ggml_tensor * z =
+            ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[3], num_k_heads, n_tokens, n_seqs,
+                         mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3],
+                         (split_sizes_qkvz[0] + split_sizes_qkvz[1] + split_sizes_qkvz[2]) * sizeof(float));
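+        // These are strided views: ne0 picks how many floats to take from each packed
+        // row, the parent's nb[1..3] keep the walk aligned to the full row pitch, and
+        // the byte offset skips the components that precede the slice.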
 
+        // Reshape value and z to merge head dimensions: [batch, seq_len, num_k_heads, head_v_dim*num_v_heads/num_k_heads] -> [batch, seq_len, num_v_heads, head_v_dim]
+        ggml_tensor * value_reshaped = ggml_reshape_4d(ctx0, ggml_cont(ctx0, value), head_v_dim, num_v_heads, n_tokens, n_seqs);
+        ggml_tensor * z_reshaped     = ggml_reshape_4d(ctx0, ggml_cont(ctx0, z), head_v_dim, num_v_heads, n_tokens, n_seqs);
+
+        GGML_ASSERT(ggml_nelements(query) + ggml_nelements(key) + ggml_nelements(value_reshaped) +
+                        ggml_nelements(z_reshaped) ==
+                    ggml_nelements(mixed_qkvz));
+
+        // Split mixed_ba into b and a (beta and alpha parameters)
+        int64_t split_sizes_ba[2] = {
+            num_v_heads / num_k_heads, // beta size
+            num_v_heads / num_k_heads  // alpha size
+        };
+
+        ggml_tensor * b =
+            ggml_view_4d(ctx0, mixed_ba_reshaped, split_sizes_ba[0], num_k_heads, n_tokens, n_seqs,
+                         mixed_ba_reshaped->nb[1], mixed_ba_reshaped->nb[2], mixed_ba_reshaped->nb[3], 0);
+
+        ggml_tensor * a = ggml_view_4d(ctx0, mixed_ba_reshaped, split_sizes_ba[1], num_k_heads, n_tokens, n_seqs,
+                                       mixed_ba_reshaped->nb[1], mixed_ba_reshaped->nb[2],
+                                       mixed_ba_reshaped->nb[3], split_sizes_ba[0] * sizeof(float));
+
+        // Reshape b and a to merge head dimensions: [batch, seq_len, num_k_heads, num_v_heads/num_k_heads] -> [batch, seq_len, num_v_heads]
+        ggml_tensor * beta  = ggml_reshape_3d(ctx0, ggml_cont(ctx0, b), num_v_heads, n_tokens, n_seqs);
+        ggml_tensor * alpha = ggml_reshape_3d(ctx0, ggml_cont(ctx0, a), num_v_heads, n_tokens, n_seqs);
+
+        GGML_ASSERT(ggml_nelements(beta) + ggml_nelements(alpha) == ggml_nelements(mixed_ba));
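+        // Decay gate, following the reference formulation:
+        //   gate = -exp(A_log) * softplus(a + dt_bias), where softplus(x) = log(1 + exp(x));
+        // ggml has no softplus op, so it is assembled from exp/add1/log below.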
+
+        // Softplus would be nice...
+        ggml_tensor * alpha_biased   = ggml_add(ctx0, alpha, model.layers[il].ssm_dt);  // a + dt_bias
+        ggml_tensor * alpha_exp      = ggml_exp(ctx0, alpha_biased);                    // exp(a + dt_bias)
+        ggml_tensor * one_tensor     = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);      // scalar tensor
+        one_tensor                   = ggml_exp(ctx0, one_tensor);                      // exp(0) = 1 (assumes a zero-initialized buffer)
+        ggml_tensor * one_plus_exp   = ggml_add1(ctx0, alpha_exp, one_tensor);          // 1 + exp(a + dt_bias)
+        ggml_tensor * alpha_softplus = ggml_log(ctx0, one_plus_exp);                    // log(1 + exp(...))
+        ggml_tensor * A_log_exp      = ggml_exp(ctx0, model.layers[il].ssm_a);          // A_log.exp()
+        ggml_tensor * gate_scaled    = ggml_mul(ctx0, alpha_softplus, A_log_exp);       // A_log.exp() * softplus
+        ggml_tensor * gate           = ggml_neg(ctx0, gate_scaled);                     // - (A_log.exp() * softplus)
 
         // Get convolution weights and bias
         ggml_tensor * conv_weight = model.layers[il].ssm_conv1d;
@@ -19157,12 +19200,6 @@ private:
         // Get recurrent states (conv_states not needed as it's handled internally by ggml_delta_net)
         ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il);
 
-        // Reshape tensors to match ggml_delta_net expectations
-        // [S, H, n_tokens, n_seqs] format
-        query = ggml_reshape_4d(ctx0, query, head_dim, n_heads, n_tokens, n_seqs);
-        key   = ggml_reshape_4d(ctx0, key, head_dim, n_heads, n_tokens, n_seqs);
-        value = ggml_reshape_4d(ctx0, value, head_dim, n_heads, n_tokens, n_seqs);
-
         // Beta tensor
         beta = ggml_reshape_3d(ctx0, beta, n_heads, n_tokens, n_seqs);
 
@@ -19170,22 +19207,25 @@ private:
         ggml_tensor * state = ggml_view_4d(ctx0, ssm_states_all, head_dim, head_dim, n_heads, n_seqs,
                                            ssm_states_all->nb[0], ssm_states_all->nb[1], ssm_states_all->nb[2],
                                            kv_head * head_dim * head_dim * n_heads * ggml_element_size(ssm_states_all));
-        state = ggml_cont(ctx0, state);
 
-        gate = ggml_repeat(ctx0, gate, ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, 1, n_heads, n_tokens, n_seqs));
+        state = ggml_cont(ctx0, state);
 
-        // Call the new ggml_delta_net function
+        ggml_tensor * target_gate    = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, head_dim, n_heads, n_tokens, n_seqs);
+        ggml_tensor * gate_broadcast = ggml_reshape_4d(ctx0, gate, 1, n_heads, n_tokens, n_seqs);
+        gate = ggml_repeat(ctx0, gate_broadcast, target_gate);
+
+        // Call the new ggml_delta_net function with the corrected flow
         ggml_tensor * output = ggml_delta_net(ctx0,
-                                              key,          // k tensor
-                                              value,        // v tensor
-                                              query,        // q tensor
-                                              gate,         // g tensor
-                                              conv_weight,  // conv_weight tensor
-                                              conv_bias,    // conv_bias tensor (can be nullptr)
-                                              beta,         // beta tensor
-                                              state,        // state tensor
-                                              64,           // chunk_size (adjust as needed)
-                                              true,         // use_qk_l2norm
-                                              1.0f          // scale (adjust based on your model)
+                                              key,            // k tensor
+                                              value_reshaped, // v tensor
+                                              query,          // q tensor
+                                              gate,           // g tensor
+                                              conv_weight,    // conv_weight tensor
+                                              conv_bias,      // conv_bias tensor (can be nullptr)
+                                              beta,           // beta tensor
+                                              state,          // state tensor
+                                              64,             // chunk_size (adjust as needed)
+                                              true,           // use_qk_l2norm
+                                              1.0f            // scale (adjust based on your model)
         );
 
         cb(output, "delta_net_output", il);
@@ -19205,18 +19245,37 @@ private:
                                  ctx0, ssm_states_all, head_dim * head_dim * n_heads * n_seqs,
                                  kv_head * n_seqs * head_dim * head_dim * n_heads * ggml_element_size(ssm_states_all))));
 
-        // Apply normalization and gating
-        attn_out = build_norm(attn_out, model.layers[il].ssm_norm, NULL, LLM_NORM_RMS, il);
+        // Reshape both attn_out and z to 2D tensors for normalization
+        // attn_out: [head_dim, n_heads, n_tokens, n_seqs] -> [head_dim, n_heads * n_tokens * n_seqs]
+        ggml_tensor * attn_out_2d = ggml_reshape_2d(ctx0, ggml_cont(ctx0, attn_out), head_dim,
+                                                    n_heads * n_tokens * n_seqs);
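+        // build_norm normalizes over ne0, so flattening to [head_dim, N] applies the
+        // RMSNorm per head (over head_dim elements) rather than across the whole
+        // concatenated hidden dimension.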
 
+        // z: [head_dim, n_heads, n_tokens, n_seqs] -> [head_dim, n_heads * n_tokens * n_seqs]
+        ggml_tensor * z_2d = ggml_reshape_2d(ctx0, z_reshaped, head_dim, n_heads * n_tokens * n_seqs);
+
+        // Apply gated normalization: self.norm(core_attn_out, z)
+        // This is Qwen3NextRMSNormGated, which applies RMSNorm(x) * silu(gate)
+        ggml_tensor * attn_out_norm = build_norm(attn_out_2d, model.layers[il].ssm_norm, NULL, LLM_NORM_RMS, il);
+
+        // Apply silu gate: attn_out_norm * silu(z_2d)
+        ggml_tensor * z_silu       = ggml_silu(ctx0, z_2d);
+        ggml_tensor * gated_output = ggml_mul(ctx0, attn_out_norm, z_silu);
+
+        // Reshape back: [head_dim, n_heads * n_tokens * n_seqs] -> [head_dim, n_heads, n_tokens, n_seqs]
+        ggml_tensor * gated_output_4d = ggml_reshape_4d(ctx0, gated_output, head_dim, n_heads, n_tokens, n_seqs);
+
+        // Final reshape: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * head_dim, n_tokens, n_seqs]
+        ggml_tensor * final_output = ggml_reshape_3d(ctx0, gated_output_4d, n_heads * head_dim, n_tokens, n_seqs);
 
         // Output projection
-        cur = build_lora_mm(model.layers[il].wo, attn_out);
+        cur = build_lora_mm(model.layers[il].ssm_out, final_output);
         cb(cur, "linear_attn_out", il);
 
         // Reshape back to original dimensions
-        cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens);
+        cur = ggml_cont(ctx0, ggml_reshape_2d(ctx0, cur, n_embd, n_tokens));
 
         return cur;
     }
+
     ggml_tensor * build_layer_ffn(ggml_tensor * cur, const llama_model & model, const int il) {
         // Check if this is an MoE layer
         if (model.layers[il].ffn_gate_inp != nullptr) {