diff --git a/src/models/qwen3next.cpp b/src/models/qwen3next.cpp index 99b1a76a48..b4f509ee90 100644 --- a/src/models/qwen3next.cpp +++ b/src/models/qwen3next.cpp @@ -117,7 +117,7 @@ std::pair llm_build_qwen3next::build_delta_net_chu GGML_ASSERT(k->ne[2] == n_tokens); GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs); GGML_ASSERT(beta->ne[0] == H_v && beta->ne[2] == n_tokens && beta->ne[3] == n_seqs); - GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v * H_v && state->ne[2] == 1 && state->ne[3] == n_seqs); + GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v && state->ne[2] == H_v && state->ne[3] == n_seqs); GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs); GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs); @@ -141,25 +141,23 @@ std::pair llm_build_qwen3next::build_delta_net_chu cb(beta, "beta_in", il); cb(g, "g_in", il); - q = ggml_cont_4d(ctx0, ggml_permute(ctx0, q, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs); - k = ggml_cont_4d(ctx0, ggml_permute(ctx0, k, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs); - v = ggml_cont_4d(ctx0, ggml_permute(ctx0, v, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs); - g = ggml_cont_4d(ctx0, ggml_permute(ctx0, g, 2, 0, 3, 1), n_tokens, 1, H_k, n_seqs); + q = ggml_permute(ctx0, q, 0, 2, 1, 3); + k = ggml_permute(ctx0, k, 0, 2, 1, 3); + v = ggml_permute(ctx0, v, 0, 2, 1, 3); + g = ggml_cont_4d(ctx0, ggml_permute(ctx0, g, 2, 0, 3, 1), n_tokens, 1, H_v, n_seqs); - beta = ggml_cont(ctx0, ggml_permute(ctx0, beta, 2, 0, 1, 3)); - state = ggml_reshape_4d(ctx0, state, S_v, S_v, H_v, n_seqs); + beta = ggml_permute(ctx0, beta, 2, 0, 1, 3); cb(q, "q_perm", il); cb(k, "k_perm", il); cb(v, "v_perm", il); cb(beta, "beta_perm", il); cb(g, "g_perm", il); - cb(state, "state_in", il); GGML_ASSERT(q->ne[1] == n_tokens && q->ne[0] == S_k && q->ne[2] == H_k && q->ne[3] == n_seqs); GGML_ASSERT(k->ne[1] == n_tokens && k->ne[0] == S_k && k->ne[2] == H_k && k->ne[3] == n_seqs); - GGML_ASSERT(v->ne[1] == n_tokens && v->ne[0] == S_v && v->ne[2] == H_k && v->ne[3] == n_seqs); - GGML_ASSERT(beta->ne[1] == n_tokens && beta->ne[2] == H_k && beta->ne[0] == 1 && beta->ne[3] == n_seqs); + GGML_ASSERT(v->ne[1] == n_tokens && v->ne[0] == S_v && v->ne[2] == H_v && v->ne[3] == n_seqs); + GGML_ASSERT(beta->ne[1] == n_tokens && beta->ne[2] == H_v && beta->ne[0] == 1 && beta->ne[3] == n_seqs); // Do padding const int64_t chunk_size = CHUNK_SIZE; @@ -180,19 +178,19 @@ std::pair llm_build_qwen3next::build_delta_net_chu cb(g, "g_pad", il); ggml_tensor * v_beta = ggml_mul(ctx0, v, beta); - ggml_tensor * k_beta = ggml_mul(ctx0, k, beta); + ggml_tensor * k_beta = ggml_mul(ctx0, ggml_repeat_4d(ctx0, beta, k->ne[0], beta->ne[1], beta->ne[2], beta->ne[3]), k); cb(v_beta, "v_beta", il); cb(k_beta, "k_beta", il); q = ggml_reshape_4d(ctx0, q, S_k, chunk_size, n_chunks, H_k * n_seqs); k = ggml_reshape_4d(ctx0, k, S_k, chunk_size, n_chunks, H_k * n_seqs); - k_beta = ggml_reshape_4d(ctx0, k_beta, S_k, chunk_size, n_chunks, H_k * n_seqs); + k_beta = ggml_reshape_4d(ctx0, k_beta, S_k, chunk_size, n_chunks, H_v * n_seqs); v = ggml_reshape_4d(ctx0, v, S_v, chunk_size, n_chunks, H_v * n_seqs); v_beta = ggml_reshape_4d(ctx0, v_beta, S_v, chunk_size, n_chunks, H_v * n_seqs); - g = ggml_reshape_4d(ctx0, g, chunk_size, 1, n_chunks, H_k * n_seqs); - beta = ggml_reshape_4d(ctx0, beta, 1, chunk_size, n_chunks, H_k * n_seqs); + g = ggml_reshape_4d(ctx0, g, chunk_size, 1, n_chunks, H_v * n_seqs); + beta = ggml_reshape_4d(ctx0, beta, 1, chunk_size, n_chunks, H_v * n_seqs); ggml_tensor * g_cumsum = ggml_cumsum(ctx0, g); cb(g_cumsum, "g_cumsum", il); // shape: (chunk_size, 1, n_chunks, H_v * n_seqs) @@ -237,8 +235,8 @@ std::pair llm_build_qwen3next::build_delta_net_chu cb(k_cumdecay, "k_cumdecay", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs) ggml_tensor * attn_kq = ggml_mul_mat(ctx0, k, q); - attn_kq = ggml_mul(ctx0, attn_kq, decay_mask); - attn_kq = ggml_mul(ctx0, attn_kq, diag_mask); + attn_kq = ggml_mul(ctx0, decay_mask, attn_kq); + attn_kq = ggml_mul(ctx0, attn_kq, diag_mask); cb(attn_kq, "attn_kq", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs) @@ -268,17 +266,12 @@ std::pair llm_build_qwen3next::build_delta_net_chu ggml_tensor * g_diff_exp_t = ggml_reshape_4d(ctx0, g_diff_exp, 1, chunk_size, n_chunks, g_diff_exp->ne[3]); - ggml_tensor * key_gdiff = ggml_mul(ctx0, k, g_diff_exp_t); + ggml_tensor * key_gdiff = ggml_mul(ctx0, ggml_repeat_4d(ctx0, g_diff_exp_t, k->ne[0], g_diff_exp_t->ne[1], g_diff_exp_t->ne[2], g_diff_exp_t->ne[3]), k); cb(key_gdiff, "key_gdiff", il); // shape: (S_k, chunk_size, n_chunks, H_v * n_seqs) ggml_tensor * key_gdiff_t = ggml_cont(ctx0, ggml_transpose(ctx0, key_gdiff)); cb(key_gdiff_t, "key_gdiff_t", il); // shape: (chunk_size, S_k, n_chunks, H_v * n_seqs) - - // state to be updated per chunk - ggml_tensor * new_state = state; // ggml_dup(ctx0, state); - cb(new_state, "new_state", il); // shape: (S_v, S_v, H_v, n_seqs) - // shape after loop of chunks: (S_v, chunk_size, n_chunks, H_v * n_seqs) ggml_tensor * core_attn_out = nullptr; @@ -300,7 +293,7 @@ std::pair llm_build_qwen3next::build_delta_net_chu ggml_tensor * attn_chunk = get_slice_2d(ctx0, attn_kq, chunk); cb(attn_chunk, "attn_chunk", il); - ggml_tensor * state_t = ggml_cont_4d(ctx0, ggml_permute(ctx0, new_state, 1, 0, 2, 3), S_v, S_v, 1, H_v * n_seqs); + ggml_tensor * state_t = ggml_cont_4d(ctx0, ggml_permute(ctx0, state, 1, 0, 2, 3), S_v, S_v, 1, H_v * n_seqs); // v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state ggml_tensor * v_prime = ggml_mul_mat(ctx0, state_t, k_cumdecay_chunk); @@ -312,7 +305,7 @@ std::pair llm_build_qwen3next::build_delta_net_chu cb(v_new, "v_new_chunk", il); // attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state - ggml_tensor * q_g_exp = ggml_mul(ctx0, q_chunk, gexp_chunk); + ggml_tensor * q_g_exp = ggml_mul(ctx0, ggml_repeat_4d(ctx0, gexp_chunk, q_chunk->ne[0], gexp_chunk->ne[1], gexp_chunk->ne[2], gexp_chunk->ne[3]), q_chunk); ggml_tensor * attn_inter = ggml_mul_mat(ctx0, state_t, q_g_exp); cb(attn_inter, "attn_inter_chunk", il); @@ -334,8 +327,8 @@ std::pair llm_build_qwen3next::build_delta_net_chu // last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew ggml_tensor * gexp_last_chunk = ggml_cont(ctx0, get_slice_2d(ctx0, g_last_exp, chunk)); - new_state = ggml_add(ctx0, - ggml_mul(ctx0, new_state, ggml_reshape_4d(ctx0, gexp_last_chunk, gexp_last_chunk->ne[0], gexp_last_chunk->ne[1], H_v, n_seqs)), + state = ggml_add(ctx0, + ggml_mul(ctx0, state, ggml_reshape_4d(ctx0, gexp_last_chunk, gexp_last_chunk->ne[0], gexp_last_chunk->ne[1], H_v, n_seqs)), ggml_reshape_4d(ctx0, kgdmulvnew, kgdmulvnew->ne[0], kgdmulvnew->ne[1], H_v, n_seqs)); } @@ -345,14 +338,14 @@ std::pair llm_build_qwen3next::build_delta_net_chu ggml_row_size(core_attn_out->type, S_v), ggml_row_size(core_attn_out->type, S_v * chunk_size * n_chunks), ggml_row_size(core_attn_out->type, S_v * chunk_size * n_chunks * H_v), 0); - output_tokens = ggml_cont(ctx0, output_tokens); + cb(output_tokens, "output_tokens", il); // permute back to (S_v, H_v, n_tokens, n_seqs) output_tokens = ggml_permute(ctx0, output_tokens, 0, 2, 1, 3); output_tokens = ggml_cont(ctx0, output_tokens); - return {output_tokens, new_state}; + return {output_tokens, state}; } std::pair llm_build_qwen3next::build_delta_net_autoregressive( @@ -376,33 +369,35 @@ std::pair llm_build_qwen3next::build_delta_net_aut GGML_ASSERT(k->ne[2] == n_tokens); GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs); GGML_ASSERT(beta->ne[0] == H_v && beta->ne[2] == n_tokens && beta->ne[3] == n_seqs); - GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v * H_v && state->ne[2] == 1 && state->ne[3] == n_seqs); + GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v && state->ne[2] == H_v && state->ne[3] == n_seqs); GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs); GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs); - GGML_ASSERT(H_k == H_v); // we did a repeat to make sure this is the case + //GGML_ASSERT(H_k == H_v); // we did a repeat to make sure this is the case const float eps_norm = hparams.f_norm_rms_eps; q = ggml_l2_norm(ctx0, q, eps_norm); k = ggml_l2_norm(ctx0, k, eps_norm); - const float scale = 1.0f / sqrtf(S_v); + const float scale = 1.0f / sqrtf(S_k); q = ggml_scale(ctx0, q, scale); beta = ggml_sigmoid(ctx0, beta); + q = ggml_permute(ctx0, q, 0, 2, 1, 3); + k = ggml_permute(ctx0, k, 0, 2, 1, 3); + v = ggml_permute(ctx0, v, 1, 2, 0, 3); + cb(q, "q_in", il); cb(k, "k_in", il); cb(v, "v_in", il); cb(beta, "beta_in", il); cb(g, "g_in", il); - state = ggml_reshape_4d(ctx0, state, S_v, S_v, H_v, n_seqs); - - ggml_tensor * g_t = ggml_reshape_4d(ctx0, ggml_transpose(ctx0, g), 1, 1, H_k, n_seqs); - ggml_tensor * beta_t = ggml_reshape_4d(ctx0, ggml_transpose(ctx0, beta), 1, 1, H_k, n_seqs); + ggml_tensor * g_t = ggml_reshape_4d(ctx0, ggml_transpose(ctx0, g), 1, 1, H_v, n_seqs); + ggml_tensor * beta_t = ggml_reshape_4d(ctx0, ggml_transpose(ctx0, beta), 1, 1, H_v, n_seqs); // Apply exponential to g_t g_t = ggml_exp(ctx0, g_t); @@ -412,28 +407,26 @@ std::pair llm_build_qwen3next::build_delta_net_aut state = ggml_mul(ctx0, state, g_t); // kv_mem = (last_recurrent_state * k_t.unsqueeze(-1)).sum(dim=-2) - ggml_tensor * k_t_unsqueezed = ggml_reshape_4d(ctx0, k, 1, S_v, H_v, n_seqs); - ggml_tensor * kv_mem = ggml_mul(ctx0, state, k_t_unsqueezed); - // we need to sum over dim=-2, so we transpose, sum, then transpose again - kv_mem = ggml_transpose(ctx0, ggml_sum_rows(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, kv_mem)))); + state = ggml_cont(ctx0, ggml_transpose(ctx0, state)); + ggml_tensor * kv_mem = ggml_mul(ctx0, state, k); + kv_mem = ggml_sum_rows(ctx0, kv_mem); // v_t = v.unsqueeze(2) (we insert the singleton dimension after n_seqs and H_v) - ggml_tensor * v_t = ggml_reshape_4d(ctx0, v, S_v, 1, H_v, n_seqs); // delta = (v_t - kv_mem) * beta_t - ggml_tensor * v_diff = ggml_sub(ctx0, v_t, kv_mem); // both should be [S_v, 1, H_v, n_seqs] + ggml_tensor * v_diff = ggml_sub(ctx0, v, kv_mem); // both should be [1, S_v, H_v, n_seqs] ggml_tensor * delta = ggml_mul(ctx0, v_diff, beta_t); // last_recurrent_state = last_recurrent_state + k_t.unsqueeze(-1) * delta - ggml_tensor * k_t_delta = ggml_mul(ctx0, ggml_repeat_4d(ctx0, k_t_unsqueezed, S_v, S_v, H_v, n_seqs), delta); - state = ggml_add(ctx0, state, k_t_delta); + ggml_tensor * k_t_delta = ggml_mul(ctx0, ggml_repeat_4d(ctx0, k, S_v, S_v, H_v, n_seqs), delta); + state = ggml_add(ctx0, state, k_t_delta); // Compute the attention output // core_attn_out = (last_recurrent_state * q_t.unsqueeze(-1)).sum(dim=-2) - ggml_tensor * q_t_unsqueezed = ggml_reshape_4d(ctx0, q, 1, S_v, H_v, n_seqs); // unsqueeze q_t - ggml_tensor * state_q = ggml_mul(ctx0, state, q_t_unsqueezed); - // again, since it's over dim = -2, transpose, sum, transpose back - ggml_tensor * core_attn_out = - ggml_transpose(ctx0, ggml_sum_rows(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, state_q)))); + ggml_tensor * state_q = ggml_mul(ctx0, state, q); + ggml_tensor * core_attn_out = ggml_sum_rows(ctx0, state_q); + + core_attn_out = ggml_transpose(ctx0, core_attn_out); + state = ggml_transpose(ctx0, state); // core_attn_out should be [S_v, 1, H_v, n_seqs] after this cb(core_attn_out, "output_tokens", il); @@ -734,25 +727,27 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear( int64_t nb1_qkv = ggml_row_size(conv_qkv_mix->type, qkv_dim); // Extract the convolved Q, K, V from conv_output - ggml_tensor * q_conv = - ggml_view_2d(ctx0, conv_qkv_mix, head_k_dim * num_k_heads, n_seq_tokens * n_seqs, nb1_qkv, 0); - cb(q_conv, "q_conv", il); - ggml_tensor * k_conv = - ggml_view_2d(ctx0, conv_qkv_mix, head_k_dim * num_k_heads, n_seq_tokens * n_seqs, nb1_qkv, + ggml_tensor * q_conv = ggml_view_2d(ctx0, conv_qkv_mix, head_k_dim * num_k_heads, n_seq_tokens * n_seqs, nb1_qkv, 0); + ggml_tensor * k_conv = ggml_view_2d(ctx0, conv_qkv_mix, head_k_dim * num_k_heads, n_seq_tokens * n_seqs, nb1_qkv, head_k_dim * num_k_heads * ggml_element_size(conv_qkv_mix)); + + ggml_tensor * v_conv = ggml_view_4d(ctx0, conv_qkv_mix, head_v_dim, num_v_heads, n_seq_tokens, n_seqs, + ggml_row_size(conv_qkv_mix->type, head_v_dim), + nb1_qkv, + nb1_qkv * n_seq_tokens, + ggml_row_size(conv_qkv_mix->type, 2 * head_k_dim * num_k_heads)); + + cb(q_conv, "q_conv", il); cb(k_conv, "k_conv", il); - ggml_tensor * v_conv = - ggml_view_2d(ctx0, conv_qkv_mix, head_v_dim * num_v_heads, n_seq_tokens * n_seqs, nb1_qkv, - 2 * head_k_dim * num_k_heads * ggml_element_size(conv_qkv_mix)); cb(v_conv, "v_conv", il); // Unsqueeze them q_conv = ggml_cont_4d(ctx0, q_conv, head_k_dim, num_k_heads, n_seq_tokens, n_seqs); k_conv = ggml_cont_4d(ctx0, k_conv, head_k_dim, num_k_heads, n_seq_tokens, n_seqs); - v_conv = ggml_cont_4d(ctx0, v_conv, head_v_dim, num_v_heads, n_seq_tokens, n_seqs); + //v_conv = ggml_cont_4d(ctx0, v_conv, head_v_dim, num_v_heads, n_seq_tokens, n_seqs); ggml_tensor * state = build_rs(inp, ssm_states_all, hparams.n_embd_s(), n_seqs); - state = ggml_reshape_4d(ctx0, state, head_v_dim, head_v_dim * num_v_heads, 1, n_seqs); + state = ggml_reshape_4d(ctx0, state, head_v_dim, head_v_dim, num_v_heads, n_seqs); cb(state, "state_predelta", il); // if head keys and value keys are different, repeat to force tensors into matching shapes @@ -818,7 +813,7 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear( cb(cur, "linear_attn_out", il); // Reshape back to original dimensions - cur = ggml_cont_2d(ctx0, cur, n_embd, n_seq_tokens * n_seqs); + cur = ggml_reshape_2d(ctx0, cur, n_embd, n_seq_tokens * n_seqs); return cur; }