models : avoid Q and K repeats when using fused GDA

This commit is contained in:
Georgi Gerganov 2026-03-10 11:23:25 +02:00
parent ec2443a94a
commit 39b6f5a760
7 changed files with 33 additions and 26 deletions

View File

@ -2464,6 +2464,8 @@ extern "C" {
bool lower,
bool uni);
// TODO: add ggml_gated_delta_net_set_bcast() to be able to configure Q, K broadcast type: tiled vs interleaved [TAG_GGML_GDN_BCAST]
// ref: https://github.com/ggml-org/llama.cpp/pull/19468#discussion_r2786394306
GGML_API struct ggml_tensor * ggml_gated_delta_net(
struct ggml_context * ctx,
struct ggml_tensor * q,

View File

@ -10436,8 +10436,8 @@ static void ggml_compute_forward_gated_delta_net_one_chunk(
const float * state_in_base = (const float *)src_state->data;
const int64_t rq1 = nev1 / neq1;
const int64_t rk1 = nev1 / nek1;
//const int64_t rq1 = nev1 / neq1;
//const int64_t rk1 = nev1 / nek1;
const int64_t rq3 = nev3 / neq3;
const int64_t rk3 = nev3 / nek3;
@ -10447,8 +10447,8 @@ static void ggml_compute_forward_gated_delta_net_one_chunk(
const int64_t iv1 = ir % H; // head_index
const int64_t iv3 = ir / H; // sequence
const int64_t iq1 = iv1 / rq1;
const int64_t ik1 = iv1 / rk1;
const int64_t iq1 = iv1 % neq1;
const int64_t ik1 = iv1 % nek1;
const int64_t iq3 = iv3 / rq3;
const int64_t ik3 = iv3 / rk3;

View File

@ -21,14 +21,14 @@ __global__ void gated_delta_net_cuda(const float * q,
int64_t sb1,
int64_t sb2,
int64_t sb3,
int64_t rq1,
int64_t neqk1,
int64_t rq3,
float scale) {
const int64_t h_idx = blockIdx.x;
const int64_t sequence = blockIdx.y;
const int col = threadIdx.x; // each thread owns one column
const int64_t iq1 = h_idx / rq1;
const int64_t iq1 = h_idx % neqk1;
const int64_t iq3 = sequence / rq3;
const int64_t attn_score_elems = S_v * H * n_tokens * n_seqs;
@ -119,11 +119,11 @@ static void launch_gated_delta_net(
const float * q_d, const float * k_d, const float * v_d,
const float * g_d, const float * b_d, const float * s_d,
float * dst_d,
int64_t S_v, int64_t H, int64_t n_tokens, int64_t n_seqs,
int64_t sq1, int64_t sq2, int64_t sq3,
int64_t sv1, int64_t sv2, int64_t sv3,
int64_t sb1, int64_t sb2, int64_t sb3,
int64_t rq1, int64_t rq3,
int64_t S_v, int64_t H, int64_t n_tokens, int64_t n_seqs,
int64_t sq1, int64_t sq2, int64_t sq3,
int64_t sv1, int64_t sv2, int64_t sv3,
int64_t sb1, int64_t sb2, int64_t sb3,
int64_t neqk1, int64_t rq3,
float scale, cudaStream_t stream) {
dim3 grid_dims(H, n_seqs, 1);
@ -134,19 +134,19 @@ static void launch_gated_delta_net(
gated_delta_net_cuda<32, KDA><<<grid_dims, block_dims, 0, stream>>>(
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
sb1, sb2, sb3, rq1, rq3, scale);
sb1, sb2, sb3, neqk1, rq3, scale);
break;
case 64:
gated_delta_net_cuda<64, KDA><<<grid_dims, block_dims, 0, stream>>>(
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
sb1, sb2, sb3, rq1, rq3, scale);
sb1, sb2, sb3, neqk1, rq3, scale);
break;
case 128:
gated_delta_net_cuda<128, KDA><<<grid_dims, block_dims, 0, stream>>>(
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
sb1, sb2, sb3, rq1, rq3, scale);
sb1, sb2, sb3, neqk1, rq3, scale);
break;
default:
GGML_ABORT("fatal error");
@ -163,10 +163,12 @@ void ggml_cuda_op_gated_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor *
ggml_tensor * src_state = dst->src[5];
GGML_TENSOR_LOCALS(int64_t, neq, src_q, ne);
GGML_TENSOR_LOCALS(size_t, nbq, src_q, nb);
GGML_TENSOR_LOCALS(size_t , nbq, src_q, nb);
GGML_TENSOR_LOCALS(int64_t, nek, src_k, ne);
GGML_TENSOR_LOCALS(size_t , nbk, src_k, nb);
GGML_TENSOR_LOCALS(int64_t, nev, src_v, ne);
GGML_TENSOR_LOCALS(size_t, nbv, src_v, nb);
GGML_TENSOR_LOCALS(size_t, nbb, src_beta, nb);
GGML_TENSOR_LOCALS(size_t, nbv, src_v, nb);
GGML_TENSOR_LOCALS(size_t, nbb, src_beta, nb);
const int64_t S_v = nev0;
const int64_t H = nev1;
@ -175,7 +177,9 @@ void ggml_cuda_op_gated_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor *
const bool kda = (src_g->ne[0] == S_v);
const int64_t rq1 = nev1 / neq1;
GGML_ASSERT(neq1 == nek1);
const int64_t neqk1 = neq1;
const int64_t rq3 = nev3 / neq3;
const float * q_d = (const float *) src_q->data;
@ -214,10 +218,10 @@ void ggml_cuda_op_gated_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor *
if (kda) {
launch_gated_delta_net<true>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d,
S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
sb1, sb2, sb3, rq1, rq3, scale, stream);
sb1, sb2, sb3, neqk1, rq3, scale, stream);
} else {
launch_gated_delta_net<false>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d,
S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
sb1, sb2, sb3, rq1, rq3, scale, stream);
sb1, sb2, sb3, neqk1, rq3, scale, stream);
}
}

View File

@ -4999,8 +4999,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
#ifdef GGML_USE_MUSA
return false;
#else
// TODO: add chunked support
return op->src[0]->ne[2] == 1;
// TODO: add non-KDA chunked support. for now enable chunked support for KDA only
return op->src[0]->ne[2] == 1 || op->src[3]->ne[0] == op->src[2]->ne[0];
#endif // GGML_USE_MUSA
case GGML_OP_FLASH_ATTN_EXT:
return ggml_cuda_flash_attn_ext_supported(dev_ctx->device, op);

View File

@ -321,9 +321,9 @@ ggml_tensor * llm_build_qwen35::build_layer_attn_linear(
//v_conv = ggml_cont_4d(ctx0, v_conv, head_v_dim, num_v_heads, n_seq_tokens, n_seqs);
// if head keys and value keys are different, repeat to force tensors into matching shapes
if (num_k_heads != num_v_heads) {
// note: need explicit repeat only if we are not using the fused GDN
if (num_k_heads != num_v_heads && (!cparams.fused_gdn_ar || !cparams.fused_gdn_ch)) {
GGML_ASSERT(num_v_heads % num_k_heads == 0);
// TODO: try to avoid these explicit repeats by utilizing op broadcast
q_conv = ggml_repeat_4d(ctx0, q_conv, head_k_dim, num_v_heads, n_seq_tokens, n_seqs);
k_conv = ggml_repeat_4d(ctx0, k_conv, head_k_dim, num_v_heads, n_seq_tokens, n_seqs);
}

View File

@ -321,9 +321,9 @@ ggml_tensor * llm_build_qwen35moe ::build_layer_attn_linear(
//v_conv = ggml_cont_4d(ctx0, v_conv, head_v_dim, num_v_heads, n_seq_tokens, n_seqs);
// if head keys and value keys are different, repeat to force tensors into matching shapes
if (num_k_heads != num_v_heads) {
// note: need explicit repeat only if we are not using the fused GDN
if (num_k_heads != num_v_heads && (!cparams.fused_gdn_ar || !cparams.fused_gdn_ch)) {
GGML_ASSERT(num_v_heads % num_k_heads == 0);
// TODO: try to avoid these explicit repeats by utilizing op broadcast
q_conv = ggml_repeat_4d(ctx0, q_conv, head_k_dim, num_v_heads, n_seq_tokens, n_seqs);
k_conv = ggml_repeat_4d(ctx0, k_conv, head_k_dim, num_v_heads, n_seq_tokens, n_seqs);
}

View File

@ -406,6 +406,7 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
//v_conv = ggml_cont_4d(ctx0, v_conv, head_v_dim, num_v_heads, n_seq_tokens, n_seqs);
// if head keys and value keys are different, repeat to force tensors into matching shapes
// TODO: avoid repeats for fused GDN, needs broadcast configuration for GDN op [TAG_GGML_GDN_BCAST]
if (num_k_heads != num_v_heads) {
GGML_ASSERT(num_v_heads % num_k_heads == 0);
int64_t repeat_factor = num_v_heads / num_k_heads;