TP: fix Qwen 3 Next data split (#21732)

This commit is contained in:
Johannes Gäßler 2026-04-11 09:23:42 +02:00 committed by GitHub
parent 2b2cd57de6
commit 865ff06b2f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 39 additions and 23 deletions

View File

@ -202,24 +202,37 @@ struct ggml_backend_meta_split_state llama_meta_device_get_split_state(const str
const int64_t n_v_heads = hparams.ssm_dt_rank;
const int64_t key_dim = head_k_dim * n_k_heads;
const int64_t value_dim = head_v_dim * n_v_heads;
const int64_t head_ratio = n_v_heads / n_k_heads;
if (std::regex_match(tensor_name, pattern_qkv_weight) || std::regex_match(tensor_name, pattern_ssm_conv1d)) {
GGML_ASSERT(tensor->ne[axis] == 2*key_dim + value_dim);
return std::vector<int64_t>(2 + head_ratio, key_dim);
}
if (std::regex_match(tensor_name, pattern_attn_gate_weight) || std::regex_match(tensor_name, pattern_ssm_out_weight)) {
return std::vector<int64_t>(head_ratio, key_dim);
}
if (std::regex_match(tensor_name, pattern_ssm_dt) || std::regex_match(tensor_name, pattern_ssm_a) ||
std::regex_match(tensor_name, pattern_ssm_alpha) || std::regex_match(tensor_name, pattern_ssm_beta)) {
return std::vector<int64_t>(head_ratio, n_k_heads);
}
if (std::regex_match(tensor_name, pattern_r_cache)) {
return std::vector<int64_t>(2 + head_ratio, key_dim * (hparams.ssm_d_conv - 1));
}
if (std::regex_match(tensor_name, pattern_s_cache)) {
return std::vector<int64_t>(head_ratio, n_k_heads * head_v_dim * head_v_dim);
// both Qwen 3 Next and Qwen 3.5 support n_v_heads > n_k_heads but the broadcasting pattern is different:
// - Qwen 3 Next: [k0_v0, k0_v1, k1_v2, k1_v3] (this is the default split pattern)
// - Qwen 3.5: [k0_v0, k1_v1, k0_v2, k1_v3] (needs segmenting of V on the scale of K to get the correct pattern)
if (ud->model->arch == LLM_ARCH_QWEN3NEXT) {
if (std::regex_match(tensor_name, pattern_qkv_weight) || std::regex_match(tensor_name, pattern_ssm_conv1d)) {
GGML_ASSERT(tensor->ne[axis] == 2*key_dim + value_dim);
return {key_dim, key_dim, value_dim};
}
} else {
const int64_t head_ratio = n_v_heads / n_k_heads;
if (std::regex_match(tensor_name, pattern_qkv_weight) || std::regex_match(tensor_name, pattern_ssm_conv1d)) {
GGML_ASSERT(tensor->ne[axis] == 2*key_dim + value_dim);
return std::vector<int64_t>(2 + head_ratio, key_dim);
}
if (std::regex_match(tensor_name, pattern_attn_gate_weight) || std::regex_match(tensor_name, pattern_ssm_out_weight)) {
return std::vector<int64_t>(head_ratio, key_dim);
}
if (std::regex_match(tensor_name, pattern_ssm_dt) || std::regex_match(tensor_name, pattern_ssm_a) ||
std::regex_match(tensor_name, pattern_ssm_alpha) || std::regex_match(tensor_name, pattern_ssm_beta)) {
return std::vector<int64_t>(head_ratio, n_k_heads);
}
if (std::regex_match(tensor_name, pattern_r_cache)) {
return std::vector<int64_t>(2 + head_ratio, key_dim * (hparams.ssm_d_conv - 1));
}
if (std::regex_match(tensor_name, pattern_s_cache)) {
return std::vector<int64_t>(head_ratio, n_k_heads * head_v_dim * head_v_dim);
}
}
// the FFN is the same for Qwen 3 Next and Qwen 3.5:
if (std::regex_match(tensor_name, pattern_ffn_gate_up_weight)) {
const int64_t n_ff_exp = hparams.n_ff_exp;
GGML_ASSERT(tensor->ne[axis] == 2*n_ff_exp);
@ -249,13 +262,16 @@ struct ggml_backend_meta_split_state llama_meta_device_get_split_state(const str
const int64_t head_dim = hparams.ssm_d_state;
const int64_t granularity_qkv = std::lcm(blck_size, head_dim);
if (std::regex_match(tensor_name, pattern_qkv_weight) || std::regex_match(tensor_name, pattern_attn_gate_weight) ||
std::regex_match(tensor_name, pattern_ssm_conv1d) || std::regex_match(tensor_name, pattern_ssm_out_weight)) {
std::regex_match(tensor_name, pattern_ssm_conv1d) || std::regex_match(tensor_name, pattern_ssm_out_weight)) {
return std::vector<int64_t>(segments.size(), granularity_qkv);
}
if (std::regex_match(tensor_name, pattern_ssm_dt) || std::regex_match(tensor_name, pattern_ssm_a) ||
std::regex_match(tensor_name, pattern_ssm_alpha) || std::regex_match(tensor_name, pattern_ssm_beta)) {
if (std::regex_match(tensor_name, pattern_ssm_dt) || std::regex_match(tensor_name, pattern_ssm_a) ||
std::regex_match(tensor_name, pattern_ssm_alpha) || std::regex_match(tensor_name, pattern_ssm_beta)) {
return std::vector<int64_t>(segments.size(), granularity_qkv / head_dim);
}
if (std::regex_match(tensor_name, pattern_ssm_beta_alpha)) {
return std::vector<int64_t>(segments.size(), 2 * (granularity_qkv / head_dim));
}
if (std::regex_match(tensor_name, pattern_r_cache)) {
return std::vector<int64_t>(segments.size(), granularity_qkv * (hparams.ssm_d_conv - 1));
}
@ -300,7 +316,7 @@ struct ggml_backend_meta_split_state llama_meta_device_get_split_state(const str
// FFN
if (std::regex_match(tensor_name, pattern_ffn_up_gate_weight) || std::regex_match(tensor_name, pattern_ffn_up_gate_bias) ||
std::regex_match(tensor_name, pattern_ffn_gate_up_weight) || std::regex_match(tensor_name, pattern_ffn_down_weight)) {
std::regex_match(tensor_name, pattern_ffn_gate_up_weight) || std::regex_match(tensor_name, pattern_ffn_down_weight)) {
GGML_ASSERT(segments.size() <= 2);
return std::vector<int64_t>(segments.size(), blck_size);
}

View File

@ -354,7 +354,7 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
cb(last_conv_states, "last_conv_states", il);
ggml_tensor * state_update_target =
ggml_view_1d(ctx0, conv_states_all, (conv_kernel_size - 1) * conv_channels * n_seqs,
ggml_view_2d(ctx0, conv_states_all, (conv_kernel_size - 1) * conv_channels, n_seqs, conv_states_all->nb[1],
kv_head * (conv_kernel_size - 1) * conv_channels * ggml_element_size(conv_states_all));
cb(state_update_target, "state_update_target", il);
@ -445,7 +445,7 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
// Update the recurrent states
ggml_build_forward_expand(gf,
ggml_cpy(ctx0, new_state,
ggml_view_1d(ctx0, ssm_states_all, hparams.n_embd_s() * n_seqs,
ggml_view_2d(ctx0, ssm_states_all, hparams.n_embd_s(), n_seqs, ssm_states_all->nb[1],
kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all))));
// z: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim]