From 865ff06b2ffa2f91f30cbec1f8c73d66cc6642aa Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?=
Date: Sat, 11 Apr 2026 09:23:42 +0200
Subject: [PATCH] TP: fix Qwen 3 Next data split (#21732)

---
 src/llama-model.cpp      | 58 +++++++++++++++++++++++++---------------
 src/models/qwen3next.cpp |  4 +--
 2 files changed, 39 insertions(+), 23 deletions(-)

diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index f057fd31c3..d2ffc1f45f 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -202,24 +202,37 @@ struct ggml_backend_meta_split_state llama_meta_device_get_split_state(const str
     const int64_t n_v_heads = hparams.ssm_dt_rank;
     const int64_t key_dim = head_k_dim * n_k_heads;
     const int64_t value_dim = head_v_dim * n_v_heads;
-    const int64_t head_ratio = n_v_heads / n_k_heads;
-    if (std::regex_match(tensor_name, pattern_qkv_weight) || std::regex_match(tensor_name, pattern_ssm_conv1d)) {
-        GGML_ASSERT(tensor->ne[axis] == 2*key_dim + value_dim);
-        return std::vector(2 + head_ratio, key_dim);
-    }
-    if (std::regex_match(tensor_name, pattern_attn_gate_weight) || std::regex_match(tensor_name, pattern_ssm_out_weight)) {
-        return std::vector(head_ratio, key_dim);
-    }
-    if (std::regex_match(tensor_name, pattern_ssm_dt) || std::regex_match(tensor_name, pattern_ssm_a) ||
-        std::regex_match(tensor_name, pattern_ssm_alpha) || std::regex_match(tensor_name, pattern_ssm_beta)) {
-        return std::vector(head_ratio, n_k_heads);
-    }
-    if (std::regex_match(tensor_name, pattern_r_cache)) {
-        return std::vector(2 + head_ratio, key_dim * (hparams.ssm_d_conv - 1));
-    }
-    if (std::regex_match(tensor_name, pattern_s_cache)) {
-        return std::vector(head_ratio, n_k_heads * head_v_dim * head_v_dim);
+
+    // both Qwen 3 Next and Qwen 3.5 support n_v_heads > n_k_heads but the broadcasting pattern is different:
+    // - Qwen 3 Next: [k0_v0, k0_v1, k1_v2, k1_v3] (this is the default split pattern)
+    // - Qwen 3.5: [k0_v0, k1_v1, k0_v2, k1_v3] (needs segmenting of V on the scale of K to get the correct pattern)
+    if (ud->model->arch == LLM_ARCH_QWEN3NEXT) {
+        if (std::regex_match(tensor_name, pattern_qkv_weight) || std::regex_match(tensor_name, pattern_ssm_conv1d)) {
+            GGML_ASSERT(tensor->ne[axis] == 2*key_dim + value_dim);
+            return {key_dim, key_dim, value_dim};
+        }
+    } else {
+        const int64_t head_ratio = n_v_heads / n_k_heads;
+        if (std::regex_match(tensor_name, pattern_qkv_weight) || std::regex_match(tensor_name, pattern_ssm_conv1d)) {
+            GGML_ASSERT(tensor->ne[axis] == 2*key_dim + value_dim);
+            return std::vector(2 + head_ratio, key_dim);
+        }
+        if (std::regex_match(tensor_name, pattern_attn_gate_weight) || std::regex_match(tensor_name, pattern_ssm_out_weight)) {
+            return std::vector(head_ratio, key_dim);
+        }
+        if (std::regex_match(tensor_name, pattern_ssm_dt) || std::regex_match(tensor_name, pattern_ssm_a) ||
+            std::regex_match(tensor_name, pattern_ssm_alpha) || std::regex_match(tensor_name, pattern_ssm_beta)) {
+            return std::vector(head_ratio, n_k_heads);
+        }
+        if (std::regex_match(tensor_name, pattern_r_cache)) {
+            return std::vector(2 + head_ratio, key_dim * (hparams.ssm_d_conv - 1));
+        }
+        if (std::regex_match(tensor_name, pattern_s_cache)) {
+            return std::vector(head_ratio, n_k_heads * head_v_dim * head_v_dim);
+        }
     }
+
+    // the FFN is the same for Qwen 3 Next and Qwen 3.5:
     if (std::regex_match(tensor_name, pattern_ffn_gate_up_weight)) {
         const int64_t n_ff_exp = hparams.n_ff_exp;
         GGML_ASSERT(tensor->ne[axis] == 2*n_ff_exp);
@@ -249,13 +262,16 @@ struct ggml_backend_meta_split_state llama_meta_device_get_split_state(const str
     const int64_t head_dim = hparams.ssm_d_state;
     const int64_t granularity_qkv = std::lcm(blck_size, head_dim);
     if (std::regex_match(tensor_name, pattern_qkv_weight) || std::regex_match(tensor_name, pattern_attn_gate_weight) ||
-            std::regex_match(tensor_name, pattern_ssm_conv1d) || std::regex_match(tensor_name, pattern_ssm_out_weight)) {
+        std::regex_match(tensor_name, pattern_ssm_conv1d) || std::regex_match(tensor_name, pattern_ssm_out_weight)) {
         return std::vector(segments.size(), granularity_qkv);
     }
-    if (std::regex_match(tensor_name, pattern_ssm_dt) || std::regex_match(tensor_name, pattern_ssm_a) ||
-            std::regex_match(tensor_name, pattern_ssm_alpha) || std::regex_match(tensor_name, pattern_ssm_beta)) {
+    if (std::regex_match(tensor_name, pattern_ssm_dt) || std::regex_match(tensor_name, pattern_ssm_a) ||
+        std::regex_match(tensor_name, pattern_ssm_alpha) || std::regex_match(tensor_name, pattern_ssm_beta)) {
         return std::vector(segments.size(), granularity_qkv / head_dim);
     }
+    if (std::regex_match(tensor_name, pattern_ssm_beta_alpha)) {
+        return std::vector(segments.size(), 2 * (granularity_qkv / head_dim));
+    }
     if (std::regex_match(tensor_name, pattern_r_cache)) {
         return std::vector(segments.size(), granularity_qkv * (hparams.ssm_d_conv - 1));
     }
@@ -300,7 +316,7 @@ struct ggml_backend_meta_split_state llama_meta_device_get_split_state(const str
 
     // FFN
     if (std::regex_match(tensor_name, pattern_ffn_up_gate_weight) || std::regex_match(tensor_name, pattern_ffn_up_gate_bias) ||
-            std::regex_match(tensor_name, pattern_ffn_gate_up_weight) || std::regex_match(tensor_name, pattern_ffn_down_weight)) {
+        std::regex_match(tensor_name, pattern_ffn_gate_up_weight) || std::regex_match(tensor_name, pattern_ffn_down_weight)) {
         GGML_ASSERT(segments.size() <= 2);
         return std::vector(segments.size(), blck_size);
     }
diff --git a/src/models/qwen3next.cpp b/src/models/qwen3next.cpp
index 98b4cb1047..5fb0a1de98 100644
--- a/src/models/qwen3next.cpp
+++ b/src/models/qwen3next.cpp
@@ -354,7 +354,7 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
     cb(last_conv_states, "last_conv_states", il);
 
     ggml_tensor * state_update_target =
-        ggml_view_1d(ctx0, conv_states_all, (conv_kernel_size - 1) * conv_channels * n_seqs,
+        ggml_view_2d(ctx0, conv_states_all, (conv_kernel_size - 1) * conv_channels, n_seqs, conv_states_all->nb[1],
             kv_head * (conv_kernel_size - 1) * conv_channels * ggml_element_size(conv_states_all));
     cb(state_update_target, "state_update_target", il);
 
@@ -445,7 +445,7 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
     // Update the recurrent states
     ggml_build_forward_expand(gf,
         ggml_cpy(ctx0, new_state,
-            ggml_view_1d(ctx0, ssm_states_all, hparams.n_embd_s() * n_seqs,
+            ggml_view_2d(ctx0, ssm_states_all, hparams.n_embd_s(), n_seqs, ssm_states_all->nb[1],
                 kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all))));
 
     // z: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim]
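
Not part of the patch: the following is a minimal standalone sketch of the two V-head broadcast layouts described in the comment added to llama-model.cpp, and of the resulting QKV split segments. The head counts and head sizes (n_k_heads = 2, n_v_heads = 4, 128-dim heads) are made-up example values, not taken from any real Qwen configuration.

// sketch.cpp -- illustration only; compile with: g++ -std=c++17 sketch.cpp
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    const int64_t n_k_heads  = 2;   // hypothetical K/Q head count
    const int64_t n_v_heads  = 4;   // hypothetical V head count (n_v_heads > n_k_heads)
    const int64_t head_k_dim = 128; // hypothetical per-head K dimension
    const int64_t head_v_dim = 128; // equal to head_k_dim so the uniform-segment split below adds up

    const int64_t head_ratio = n_v_heads / n_k_heads;
    const int64_t key_dim    = head_k_dim * n_k_heads;
    const int64_t value_dim  = head_v_dim * n_v_heads;

    // which K head each V head is broadcast against, per the comment in the patch:
    printf("Qwen 3 Next (contiguous V blocks): ");
    for (int64_t v = 0; v < n_v_heads; v++) {
        printf("k%lld_v%lld ", (long long) (v / head_ratio), (long long) v);
    }
    printf("\nQwen 3.5 (interleaved V heads):    ");
    for (int64_t v = 0; v < n_v_heads; v++) {
        printf("k%lld_v%lld ", (long long) (v % n_k_heads), (long long) v);
    }
    printf("\n");

    // split segments for a fused QKV tensor of width 2*key_dim + value_dim,
    // mirroring the two branches of the patched code:
    const std::vector<int64_t> split_qwen3next = {key_dim, key_dim, value_dim}; // V kept as one block
    const std::vector<int64_t> split_qwen35(2 + head_ratio, key_dim);           // V segmented on the scale of K

    printf("Qwen 3 Next QKV segments: ");
    for (int64_t s : split_qwen3next) printf("%lld ", (long long) s);
    printf("\nQwen 3.5 QKV segments:    ");
    for (int64_t s : split_qwen35) printf("%lld ", (long long) s);
    printf("\n");
    return 0;
}

With these example sizes both segment lists sum to 2*key_dim + value_dim = 1024; the uniform (2 + head_ratio)-segment form relies on head_v_dim == head_k_dim, which holds here by assumption.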