diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 566e271479..2f007260b4 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -2464,6 +2464,8 @@ extern "C" { bool lower, bool uni); + // TODO: add ggml_gated_delta_net_set_bcast() to be able to configure Q, K broadcast type: tiled vs interleaved [TAG_GGML_GDN_BCAST] + // ref: https://github.com/ggml-org/llama.cpp/pull/19468#discussion_r2786394306 GGML_API struct ggml_tensor * ggml_gated_delta_net( struct ggml_context * ctx, struct ggml_tensor * q, diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 331e071a26..d324128c89 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -10436,8 +10436,8 @@ static void ggml_compute_forward_gated_delta_net_one_chunk( const float * state_in_base = (const float *)src_state->data; - const int64_t rq1 = nev1 / neq1; - const int64_t rk1 = nev1 / nek1; + //const int64_t rq1 = nev1 / neq1; + //const int64_t rk1 = nev1 / nek1; const int64_t rq3 = nev3 / neq3; const int64_t rk3 = nev3 / nek3; @@ -10447,8 +10447,8 @@ static void ggml_compute_forward_gated_delta_net_one_chunk( const int64_t iv1 = ir % H; // head_index const int64_t iv3 = ir / H; // sequence - const int64_t iq1 = iv1 / rq1; - const int64_t ik1 = iv1 / rk1; + const int64_t iq1 = iv1 % neq1; + const int64_t ik1 = iv1 % nek1; const int64_t iq3 = iv3 / rq3; const int64_t ik3 = iv3 / rk3; diff --git a/ggml/src/ggml-cuda/gated_delta_net.cu b/ggml/src/ggml-cuda/gated_delta_net.cu index d8e8111455..086d38648e 100644 --- a/ggml/src/ggml-cuda/gated_delta_net.cu +++ b/ggml/src/ggml-cuda/gated_delta_net.cu @@ -21,14 +21,14 @@ __global__ void gated_delta_net_cuda(const float * q, int64_t sb1, int64_t sb2, int64_t sb3, - int64_t rq1, + int64_t neqk1, int64_t rq3, float scale) { const int64_t h_idx = blockIdx.x; const int64_t sequence = blockIdx.y; const int col = threadIdx.x; // each thread owns one column - const int64_t iq1 = h_idx / rq1; + const int64_t iq1 = h_idx % neqk1; const int64_t iq3 = sequence / rq3; const int64_t attn_score_elems = S_v * H * n_tokens * n_seqs; @@ -119,11 +119,11 @@ static void launch_gated_delta_net( const float * q_d, const float * k_d, const float * v_d, const float * g_d, const float * b_d, const float * s_d, float * dst_d, - int64_t S_v, int64_t H, int64_t n_tokens, int64_t n_seqs, - int64_t sq1, int64_t sq2, int64_t sq3, - int64_t sv1, int64_t sv2, int64_t sv3, - int64_t sb1, int64_t sb2, int64_t sb3, - int64_t rq1, int64_t rq3, + int64_t S_v, int64_t H, int64_t n_tokens, int64_t n_seqs, + int64_t sq1, int64_t sq2, int64_t sq3, + int64_t sv1, int64_t sv2, int64_t sv3, + int64_t sb1, int64_t sb2, int64_t sb3, + int64_t neqk1, int64_t rq3, float scale, cudaStream_t stream) { dim3 grid_dims(H, n_seqs, 1); @@ -134,19 +134,19 @@ static void launch_gated_delta_net( gated_delta_net_cuda<32, KDA><<>>( q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, - sb1, sb2, sb3, rq1, rq3, scale); + sb1, sb2, sb3, neqk1, rq3, scale); break; case 64: gated_delta_net_cuda<64, KDA><<>>( q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, - sb1, sb2, sb3, rq1, rq3, scale); + sb1, sb2, sb3, neqk1, rq3, scale); break; case 128: gated_delta_net_cuda<128, KDA><<>>( q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, - sb1, sb2, sb3, rq1, rq3, scale); + sb1, sb2, sb3, neqk1, rq3, scale); break; default: GGML_ABORT("fatal error"); @@ -163,10 +163,12 @@ void ggml_cuda_op_gated_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor * ggml_tensor * src_state = dst->src[5]; GGML_TENSOR_LOCALS(int64_t, neq, src_q, ne); - GGML_TENSOR_LOCALS(size_t, nbq, src_q, nb); + GGML_TENSOR_LOCALS(size_t , nbq, src_q, nb); + GGML_TENSOR_LOCALS(int64_t, nek, src_k, ne); + GGML_TENSOR_LOCALS(size_t , nbk, src_k, nb); GGML_TENSOR_LOCALS(int64_t, nev, src_v, ne); - GGML_TENSOR_LOCALS(size_t, nbv, src_v, nb); - GGML_TENSOR_LOCALS(size_t, nbb, src_beta, nb); + GGML_TENSOR_LOCALS(size_t, nbv, src_v, nb); + GGML_TENSOR_LOCALS(size_t, nbb, src_beta, nb); const int64_t S_v = nev0; const int64_t H = nev1; @@ -175,7 +177,9 @@ void ggml_cuda_op_gated_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor * const bool kda = (src_g->ne[0] == S_v); - const int64_t rq1 = nev1 / neq1; + GGML_ASSERT(neq1 == nek1); + const int64_t neqk1 = neq1; + const int64_t rq3 = nev3 / neq3; const float * q_d = (const float *) src_q->data; @@ -214,10 +218,10 @@ void ggml_cuda_op_gated_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor * if (kda) { launch_gated_delta_net(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, - sb1, sb2, sb3, rq1, rq3, scale, stream); + sb1, sb2, sb3, neqk1, rq3, scale, stream); } else { launch_gated_delta_net(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, - sb1, sb2, sb3, rq1, rq3, scale, stream); + sb1, sb2, sb3, neqk1, rq3, scale, stream); } } diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 3c36398647..7c717fcde8 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -4999,8 +4999,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g #ifdef GGML_USE_MUSA return false; #else - // TODO: add chunked support - return op->src[0]->ne[2] == 1; + // TODO: add non-KDA chunked support. for now enable chunked support for KDA only + return op->src[0]->ne[2] == 1 || op->src[3]->ne[0] == op->src[2]->ne[0]; #endif // GGML_USE_MUSA case GGML_OP_FLASH_ATTN_EXT: return ggml_cuda_flash_attn_ext_supported(dev_ctx->device, op); diff --git a/src/models/qwen35.cpp b/src/models/qwen35.cpp index ba096a5a7b..dda2cc74d3 100644 --- a/src/models/qwen35.cpp +++ b/src/models/qwen35.cpp @@ -321,9 +321,9 @@ ggml_tensor * llm_build_qwen35::build_layer_attn_linear( //v_conv = ggml_cont_4d(ctx0, v_conv, head_v_dim, num_v_heads, n_seq_tokens, n_seqs); // if head keys and value keys are different, repeat to force tensors into matching shapes - if (num_k_heads != num_v_heads) { + // note: need explicit repeat only if we are not using the fused GDN + if (num_k_heads != num_v_heads && (!cparams.fused_gdn_ar || !cparams.fused_gdn_ch)) { GGML_ASSERT(num_v_heads % num_k_heads == 0); - // TODO: try to avoid these explicit repeats by utilizing op broadcast q_conv = ggml_repeat_4d(ctx0, q_conv, head_k_dim, num_v_heads, n_seq_tokens, n_seqs); k_conv = ggml_repeat_4d(ctx0, k_conv, head_k_dim, num_v_heads, n_seq_tokens, n_seqs); } diff --git a/src/models/qwen35moe.cpp b/src/models/qwen35moe.cpp index fe382286e9..d9000eece6 100644 --- a/src/models/qwen35moe.cpp +++ b/src/models/qwen35moe.cpp @@ -321,9 +321,9 @@ ggml_tensor * llm_build_qwen35moe ::build_layer_attn_linear( //v_conv = ggml_cont_4d(ctx0, v_conv, head_v_dim, num_v_heads, n_seq_tokens, n_seqs); // if head keys and value keys are different, repeat to force tensors into matching shapes - if (num_k_heads != num_v_heads) { + // note: need explicit repeat only if we are not using the fused GDN + if (num_k_heads != num_v_heads && (!cparams.fused_gdn_ar || !cparams.fused_gdn_ch)) { GGML_ASSERT(num_v_heads % num_k_heads == 0); - // TODO: try to avoid these explicit repeats by utilizing op broadcast q_conv = ggml_repeat_4d(ctx0, q_conv, head_k_dim, num_v_heads, n_seq_tokens, n_seqs); k_conv = ggml_repeat_4d(ctx0, k_conv, head_k_dim, num_v_heads, n_seq_tokens, n_seqs); } diff --git a/src/models/qwen3next.cpp b/src/models/qwen3next.cpp index 30912fd5e3..5d6a4ac9f7 100644 --- a/src/models/qwen3next.cpp +++ b/src/models/qwen3next.cpp @@ -406,6 +406,7 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear( //v_conv = ggml_cont_4d(ctx0, v_conv, head_v_dim, num_v_heads, n_seq_tokens, n_seqs); // if head keys and value keys are different, repeat to force tensors into matching shapes + // TODO: avoid repeats for fused GDN, needs broadcast configuration for GDN op [TAG_GGML_GDN_BCAST] if (num_k_heads != num_v_heads) { GGML_ASSERT(num_v_heads % num_k_heads == 0); int64_t repeat_factor = num_v_heads / num_k_heads;