models : fix graph splits (#19866)

This commit is contained in:
Georgi Gerganov 2026-02-25 00:01:13 +02:00 committed by GitHub
parent 47eb12b953
commit 244641955f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 8 additions and 3 deletions

View File

@ -116,6 +116,8 @@ llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const ll
cur = build_norm(inpL, layer.attn_norm, NULL, LLM_NORM_RMS, il);
cb(cur, "attn_norm", il);
ggml_build_forward_expand(gf, cur);
// Check layer type by checking which tensors exist
// KDA layers have ssm_a_log tensor, MLA layers have wkv_a_mqa tensor
bool is_kda = (layer.ssm_a != nullptr);

View File

@ -29,6 +29,8 @@ llm_build_qwen35::llm_build_qwen35(const llama_model & model, const llm_graph_pa
cur = build_norm(inpL, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il);
cb(cur, "attn_norm", il);
ggml_build_forward_expand(gf, cur);
// Determine layer type and build appropriate attention mechanism
if (hparams.is_recurrent(il)) {
// Linear attention layer (gated delta net)
@ -269,7 +271,6 @@ ggml_tensor * llm_build_qwen35::build_layer_attn_linear(
cb(state_update_target, "state_update_target", il);
ggml_build_forward_expand(gf, ggml_cpy(ctx0, last_conv_states, state_update_target));
cb(conv_states_all, "conv_states_updated", il);
ggml_tensor * state = build_rs(inp, ssm_states_all, hparams.n_embd_s(), n_seqs);
state = ggml_reshape_4d(ctx0, state, head_v_dim, head_v_dim, num_v_heads, n_seqs);

View File

@ -29,6 +29,8 @@ llm_build_qwen35moe::llm_build_qwen35moe(const llama_model & model, const llm_gr
cur = build_norm(inpL, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il);
cb(cur, "attn_norm", il);
ggml_build_forward_expand(gf, cur);
// Determine layer type and build appropriate attention mechanism
if (hparams.is_recurrent(il)) {
// Linear attention layer (gated delta net)
@ -269,7 +271,6 @@ ggml_tensor * llm_build_qwen35moe::build_layer_attn_linear(
cb(state_update_target, "state_update_target", il);
ggml_build_forward_expand(gf, ggml_cpy(ctx0, last_conv_states, state_update_target));
cb(conv_states_all, "conv_states_updated", il);
ggml_tensor * state = build_rs(inp, ssm_states_all, hparams.n_embd_s(), n_seqs);
state = ggml_reshape_4d(ctx0, state, head_v_dim, head_v_dim, num_v_heads, n_seqs);

View File

@ -21,6 +21,8 @@ llm_build_qwen3next::llm_build_qwen3next(const llama_model & model, const llm_gr
cur = build_norm(inpL, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il);
cb(cur, "attn_norm", il);
ggml_build_forward_expand(gf, cur);
// Determine layer type and build appropriate attention mechanism
if (hparams.is_recurrent(il)) {
// Linear attention layer (gated delta net)
@ -354,7 +356,6 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
cb(state_update_target, "state_update_target", il);
ggml_build_forward_expand(gf, ggml_cpy(ctx0, last_conv_states, state_update_target));
cb(conv_states_all, "conv_states_updated", il);
ggml_tensor * state = build_rs(inp, ssm_states_all, hparams.n_embd_s(), n_seqs);
state = ggml_reshape_4d(ctx0, state, head_v_dim, head_v_dim, num_v_heads, n_seqs);