qwen3next: trim comments
commit 0a192937a1
parent 4114537fb9
@@ -1418,17 +1418,15 @@ static void delta_net_f32_cuda(
     // - Vectors (Q,K,V,KBeta,VBeta,KCumdecay,VPrime,VNew,Out): 9 × HEAD_DIM × sizeof(float) = 4608 bytes
     // - Warp scratch: 16 × sizeof(float) = 64 bytes
     // Total: 65536 + 4608 + 64 = 70208 bytes (~68.6KB)
-    // Note: __shared__ scalars (decay, beta, etc.) are static, not dynamic
+    // __shared__ scalars (decay, beta, etc.) are static, not dynamic
     constexpr size_t state_bytes = 128 * 128 * sizeof(float); // 64KB
     constexpr size_t vector_bytes = 9 * 128 * sizeof(float); // 4.5KB
     constexpr size_t warp_scratch_bytes = 16 * sizeof(float); // 64B
     constexpr size_t blackwell_smem_size = state_bytes + vector_bytes + warp_scratch_bytes;

-    // Sanity check: ensure we allocated enough
     static_assert(blackwell_smem_size == 70208, "Shared memory size mismatch");

-    // Check for A/B comparison mode
-    // Use a function-local static for thread-safe lazy initialization
+    // A/B comparison mode (set GGML_CUDA_DELTA_NET_AB=1)
    static const bool ab_mode = []() {
        const char* env = std::getenv("GGML_CUDA_DELTA_NET_AB");
        if (env != nullptr) {
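Note on the budget above: 70208 bytes of dynamic shared memory exceeds the 48 KB per-block default, so the launch path has to opt in explicitly. A minimal sketch of that opt-in, with a hypothetical kernel name standing in for the real one:

    #include <cuda_runtime.h>

    __global__ void delta_net_kernel() { /* stands in for the real kernel */ }

    // Request the ~68.6 KB computed above; without this, a launch with more
    // than 48 KB of dynamic shared memory fails even on GPUs that have it.
    static bool request_blackwell_smem() {
        constexpr size_t smem = 128 * 128 * sizeof(float)  // state, 64 KB
                              + 9 * 128 * sizeof(float)    // vectors, 4.5 KB
                              + 16 * sizeof(float);        // warp scratch, 64 B
        return cudaFuncSetAttribute(delta_net_kernel,
                   cudaFuncAttributeMaxDynamicSharedMemorySize,
                   (int) smem) == cudaSuccess;
    }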
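The hunk cuts off inside the ab_mode initializer. The shape of the pattern is a function-local static whose lambda initializer runs exactly once, thread-safe since C++11. A self-contained sketch, assuming the flag counts as enabled when the variable is set to "1":

    #include <cstdlib>
    #include <cstring>

    static bool delta_net_ab_mode() {
        // Magic-static: concurrent first callers block until the lambda finishes.
        static const bool ab_mode = []() {
            const char* env = std::getenv("GGML_CUDA_DELTA_NET_AB");
            return env != nullptr && std::strcmp(env, "1") == 0; // assumed enable value
        }();
        return ab_mode;
    }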
@@ -33,12 +33,9 @@ llm_build_qwen3next::llm_build_qwen3next(const llama_model & model, const llm_gr
         cur = build_norm(inpL, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il);
         cb(cur, "attn_norm", il);

-        // Determine layer type and build appropriate attention mechanism
         if (hparams.is_recurrent(il)) {
-            // Linear attention layer (gated delta net)
             cur = build_layer_attn_linear(inp->get_recr(), cur, causal_mask, identity, diag_mask, il);
         } else {
-            // Full attention layer
             cur = build_layer_attn(inp->get_attn(), cur, inp_pos, il);
         }

@@ -47,37 +44,28 @@ llm_build_qwen3next::llm_build_qwen3next(const llama_model & model, const llm_gr
             inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
         }

-        // Residual connection
         cur = ggml_add(ctx0, cur, inpSA);
         cb(cur, "attn_residual", il);

-        // Save the tensor before post-attention norm for residual connection
         ggml_tensor * ffn_residual = cur;

-        // Post-attention norm
         ggml_tensor * attn_post_norm = build_norm(cur, model.layers[il].attn_post_norm, nullptr, LLM_NORM_RMS, il);
         cb(attn_post_norm, "attn_post_norm", il);

-        // FFN layer (MoE or dense) - without residual connection
         cur = build_layer_ffn(attn_post_norm, il);
         cb(cur, "ffn_out", il);

-        // Residual connection for FFN - add to the tensor from before post_attention_layernorm
         cur = ggml_add(ctx0, cur, ffn_residual);
         cb(cur, "post_moe", il);

-        // Input for next layer
         inpL = cur;
     }
     cur = inpL;
-
-    // Final norm
     cur = build_norm(cur, model.output_norm, nullptr, LLM_NORM_RMS, -1);

     cb(cur, "result_norm", -1);
     res->t_embd = cur;

-    // LM head
     cur = build_lora_mm(model.output, cur);

     cb(cur, "result_output", -1);
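With the comments gone, the per-layer dataflow this hunk implements reads as two pre-norm residual steps (a sketch of the wiring, not literal code):

    h   = x + Attention(RMSNorm(x))    // "attn_residual"
    out = h + FFN(RMSNorm(h))          // ffn_residual saved before attn_post_norm

The point of ffn_residual is that the FFN output is added back to h, the value from before attn_post_norm, not to the normed tensor.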
@@ -517,16 +505,11 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn(
     const int64_t n_embd_head = hparams.n_embd_head_v;
     GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);

-    // Order: joint QG projection, QG split, Q norm, KV projection, K norm, RoPE, attention
-
-    // Qwen3Next uses a single Q projection that outputs query + gate
     ggml_tensor * Qcur_full = build_lora_mm(model.layers[il].wq, cur);
     cb(Qcur_full, "Qcur_full", il);

     Qcur_full = ggml_reshape_4d(ctx0, Qcur_full, n_embd_head * 2, n_head, n_tokens, 1);

-    // Split Q projection into query and gate
-    // The split should be along dimension 0 (the feature dimension)
     ggml_tensor * Qcur = ggml_view_4d(ctx0, Qcur_full, n_embd_head, n_head, n_tokens, 1,
                                       Qcur_full->nb[1], Qcur_full->nb[2], Qcur_full->nb[3], 0);
     ggml_tensor * gate =
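The hunk ends mid-statement, so the gate view itself is not shown. It presumably mirrors the Qcur view with a byte offset of n_embd_head elements into each head's 2*n_embd_head slice; a sketch under that assumption:

    // Assumption: gate is the second half of each head's slice of Qcur_full.
    ggml_tensor * gate = ggml_view_4d(ctx0, Qcur_full, n_embd_head, n_head, n_tokens, 1,
                                      Qcur_full->nb[1], Qcur_full->nb[2], Qcur_full->nb[3],
                                      n_embd_head * ggml_element_size(Qcur_full));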
@@ -535,11 +518,9 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn(
     cb(Qcur, "Qcur", il);
     cb(gate, "gate", il);

-    // Now reshape Qcur to [n_embd_head, n_head, n_tokens] for multi-head attention
     Qcur = ggml_cont_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
     cb(Qcur, "Qcur_reshaped", il);

-    // Apply Q normalization
     Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, nullptr, LLM_NORM_RMS, il);
     cb(Qcur, "Qcur_normed", il);

@@ -549,18 +530,15 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn(
     ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
     cb(Vcur, "Vcur", il);

-    // Apply K normalization
     Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
     Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, nullptr, LLM_NORM_RMS, il);
     cb(Kcur, "Kcur_normed", il);

-    // Reshape gate to [n_embd, n_tokens] for the sigmoid gating (flatten the heads)
     gate = ggml_cont_2d(ctx0, gate, n_embd_head * n_head, n_tokens);
     cb(gate, "gate_reshaped", il);

     Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);

-    // Apply RoPE
     Qcur = ggml_rope_ext(
         ctx0, Qcur, inp_pos, nullptr,
         n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
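The flattened gate is consumed after the attention call, outside these hunks. In comparable gated-attention blocks the output is modulated elementwise, along the lines of this sketch (an assumption about the application site, which is not shown here):

    // Hypothetical: gate the attention output with the sigmoided, flattened gate.
    cur = ggml_mul(ctx0, cur, ggml_sigmoid(ctx0, gate));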
@@ -575,7 +553,6 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn(
     cb(Kcur, "Kcur", il);
     cb(Vcur, "Vcur", il);

-    // Attention computation
     const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f / sqrtf(float(n_embd_head)) : hparams.f_attention_scale;

     cur = build_attn(inp,
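For a concrete value of the default scale: a head size of 128, say, gives 1/sqrt(128) ≈ 0.0884; any nonzero f_attention_scale overrides it.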
@@ -861,9 +838,7 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
 }

 ggml_tensor * llm_build_qwen3next::build_layer_ffn(ggml_tensor * cur, const int il) {
-    // Check if this is an MoE layer
     if (model.layers[il].ffn_gate_inp != nullptr) {
-        // MoE branch
         ggml_tensor * moe_out =
             build_moe_ffn(cur,
                           model.layers[il].ffn_gate_inp, model.layers[il].ffn_up_exps,
@@ -873,7 +848,6 @@ ggml_tensor * llm_build_qwen3next::build_layer_ffn(ggml_tensor * cur, const int
                           true, false, 0.0, LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il);
         cb(moe_out, "ffn_moe_out", il);

-        // Add shared experts if present - following Qwen3Next reference implementation
         if (model.layers[il].ffn_up_shexp != nullptr) {
             ggml_tensor * ffn_shexp =
                 build_ffn(cur,
@@ -884,23 +858,15 @@ ggml_tensor * llm_build_qwen3next::build_layer_ffn(ggml_tensor * cur, const int
                           LLM_FFN_SILU, LLM_FFN_PAR, il);
             cb(ffn_shexp, "ffn_shexp", il);

-            // Apply shared expert gating as in the reference implementation
-            // The shared expert has its own gate that is sigmoided
-            // Note: ffn_gate_inp_shexp is the shared expert gate (outputs 1 value per token)
             ggml_tensor * shared_gate = build_lora_mm(model.layers[il].ffn_gate_inp_shexp, cur);
             cb(shared_gate, "shared_expert_gate", il);

-            // Apply sigmoid to the gate
             shared_gate = ggml_sigmoid(ctx0, shared_gate);
             cb(shared_gate, "shared_expert_gate_sigmoid", il);

-            // The gate needs to be broadcast to match the dimensions of ffn_shexp
-            // ffn_shexp is [n_embd, n_tokens, 1, 1] and shared_gate is [1, n_tokens, 1, 1]
-            // We need to repeat the gate along the feature dimension
             shared_gate = ggml_repeat(ctx0, shared_gate, ffn_shexp);
             cb(shared_gate, "shared_expert_gate_broadcast", il);

-            // Apply the gate to the shared expert output
             ffn_shexp = ggml_mul(ctx0, ffn_shexp, shared_gate);
             cb(ffn_shexp, "ffn_shexp_gated", il);

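Compactly, the shared-expert path the trimmed comments described is, with σ the sigmoid and W_shexp_gate standing for ffn_gate_inp_shexp:

    ffn_out = moe_out + σ(x · W_shexp_gate) ⊙ SharedExpertFFN(x)

shared_gate is [1, n_tokens] against ffn_shexp's [n_embd, n_tokens], hence the explicit ggml_repeat broadcast before ggml_mul; the sum with moe_out falls between these hunks.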
@@ -910,7 +876,6 @@ ggml_tensor * llm_build_qwen3next::build_layer_ffn(ggml_tensor * cur, const int
             cur = moe_out;
         }
     } else {
-        // Dense FFN branch
         cur = build_ffn(cur,
             model.layers[il].ffn_up, NULL, NULL,
             model.layers[il].ffn_gate, NULL, NULL,