models : avoid Q and K repeats when using fused GDN

commit 39b6f5a760 (parent ec2443a94a)
Author: Georgi Gerganov
Date:   2026-03-10 11:23:25 +02:00

7 changed files with 33 additions and 26 deletions

View File

@@ -2464,6 +2464,8 @@ extern "C" {
             bool                  lower,
             bool                  uni);

+    // TODO: add ggml_gated_delta_net_set_bcast() to be able to configure Q, K broadcast type: tiled vs interleaved [TAG_GGML_GDN_BCAST]
+    // ref: https://github.com/ggml-org/llama.cpp/pull/19468#discussion_r2786394306
     GGML_API struct ggml_tensor * ggml_gated_delta_net(
             struct ggml_context * ctx,
             struct ggml_tensor  * q,
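The TODO above mentions two candidate broadcast layouts ("tiled" vs "interleaved") for the reduced Q/K head dimension, without spelling out the exact mapping. As a rough illustration only (the helper names and head counts below are made up for the example and are not part of the ggml API), these are the two ways a V-head index can be mapped back onto a smaller set of Q/K heads; the CPU and CUDA changes in this commit use the modulo mapping:

// Illustrative only (not ggml API): two candidate ways to map a V-head index to
// the Q/K head it reads from, when n_head_v is an integer multiple of n_head_qk.
#include <cstdint>
#include <cstdio>

// Q/K heads repeat as a whole block along the head dim: q0 q1 q2 q0 q1 q2 ...
static int64_t map_head_block_repeat(int64_t iv, int64_t n_head_qk) {
    return iv % n_head_qk;
}

// each Q/K head is repeated r consecutive times: q0 q0 q1 q1 q2 q2 ...
static int64_t map_head_consecutive_repeat(int64_t iv, int64_t r) {
    return iv / r;
}

int main() {
    const int64_t n_head_qk = 4;
    const int64_t n_head_v  = 8;
    const int64_t r         = n_head_v / n_head_qk;

    for (int64_t iv = 0; iv < n_head_v; ++iv) {
        printf("v head %lld -> %lld (modulo) / %lld (division)\n",
               (long long) iv,
               (long long) map_head_block_repeat(iv, n_head_qk),
               (long long) map_head_consecutive_repeat(iv, r));
    }
    return 0;
}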

View File

@@ -10436,8 +10436,8 @@ static void ggml_compute_forward_gated_delta_net_one_chunk(
     const float * state_in_base = (const float *)src_state->data;

-    const int64_t rq1 = nev1 / neq1;
-    const int64_t rk1 = nev1 / nek1;
+    //const int64_t rq1 = nev1 / neq1;
+    //const int64_t rk1 = nev1 / nek1;
     const int64_t rq3 = nev3 / neq3;
     const int64_t rk3 = nev3 / nek3;
@@ -10447,8 +10447,8 @@ static void ggml_compute_forward_gated_delta_net_one_chunk(
         const int64_t iv1 = ir % H; // head_index
         const int64_t iv3 = ir / H; // sequence

-        const int64_t iq1 = iv1 / rq1;
-        const int64_t ik1 = iv1 / rk1;
+        const int64_t iq1 = iv1 % neq1;
+        const int64_t ik1 = iv1 % nek1;
         const int64_t iq3 = iv3 / rq3;
         const int64_t ik3 = iv3 / rk3;
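With Q and K no longer pre-repeated in the graph, the CPU chunk kernel resolves the Q/K head with a modulo over the native head counts (neq1, nek1) instead of dividing the V-head index by a repeat ratio. A small standalone sketch, not llama.cpp code, assuming only that ggml_repeat_4d tiles the head dimension as h0, h1, ..., h_{k-1}, h0, h1, ...; it checks that the modulo indexing reads the same head the previously repeated tensor would have exposed:

// Minimal sketch (not llama.cpp code), assuming ggml_repeat_4d tiles the head
// dimension: indexing the original tensor with iv1 % num_qk_heads reads the same
// head that the repeated tensor exposes at iv1.
#include <cassert>
#include <cstdint>
#include <vector>

int main() {
    const int64_t num_qk_heads = 4;
    const int64_t num_v_heads  = 12; // repeat ratio r = 3

    std::vector<int64_t> qk_heads(num_qk_heads);
    for (int64_t i = 0; i < num_qk_heads; ++i) {
        qk_heads[i] = 100 + i; // stand-in for the data of head i
    }

    // what the old graph-level repeat would produce: r full copies back-to-back
    std::vector<int64_t> qk_repeated;
    for (int64_t rep = 0; rep < num_v_heads / num_qk_heads; ++rep) {
        qk_repeated.insert(qk_repeated.end(), qk_heads.begin(), qk_heads.end());
    }

    // the fused op now skips the repeat and applies the modulo mapping on the fly
    for (int64_t iv1 = 0; iv1 < num_v_heads; ++iv1) {
        assert(qk_heads[iv1 % num_qk_heads] == qk_repeated[iv1]);
    }
    return 0;
}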

View File

@@ -21,14 +21,14 @@ __global__ void gated_delta_net_cuda(const float * q,
                                      int64_t sb1,
                                      int64_t sb2,
                                      int64_t sb3,
-                                     int64_t rq1,
+                                     int64_t neqk1,
                                      int64_t rq3,
                                      float scale) {
     const int64_t h_idx = blockIdx.x;
     const int64_t sequence = blockIdx.y;
     const int col = threadIdx.x; // each thread owns one column
-    const int64_t iq1 = h_idx / rq1;
+    const int64_t iq1 = h_idx % neqk1;
     const int64_t iq3 = sequence / rq3;
     const int64_t attn_score_elems = S_v * H * n_tokens * n_seqs;
@@ -119,11 +119,11 @@ static void launch_gated_delta_net(
         const float * q_d, const float * k_d, const float * v_d,
         const float * g_d, const float * b_d, const float * s_d,
         float * dst_d,
         int64_t S_v, int64_t H, int64_t n_tokens, int64_t n_seqs,
         int64_t sq1, int64_t sq2, int64_t sq3,
         int64_t sv1, int64_t sv2, int64_t sv3,
         int64_t sb1, int64_t sb2, int64_t sb3,
-        int64_t rq1, int64_t rq3,
+        int64_t neqk1, int64_t rq3,
         float scale, cudaStream_t stream) {
     dim3 grid_dims(H, n_seqs, 1);
@@ -134,19 +134,19 @@ static void launch_gated_delta_net(
             gated_delta_net_cuda<32, KDA><<<grid_dims, block_dims, 0, stream>>>(
                 q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
                 n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
-                sb1, sb2, sb3, rq1, rq3, scale);
+                sb1, sb2, sb3, neqk1, rq3, scale);
             break;
         case 64:
             gated_delta_net_cuda<64, KDA><<<grid_dims, block_dims, 0, stream>>>(
                 q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
                 n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
-                sb1, sb2, sb3, rq1, rq3, scale);
+                sb1, sb2, sb3, neqk1, rq3, scale);
             break;
         case 128:
             gated_delta_net_cuda<128, KDA><<<grid_dims, block_dims, 0, stream>>>(
                 q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
                 n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
-                sb1, sb2, sb3, rq1, rq3, scale);
+                sb1, sb2, sb3, neqk1, rq3, scale);
             break;
         default:
             GGML_ABORT("fatal error");
@@ -163,10 +163,12 @@ void ggml_cuda_op_gated_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor *
     ggml_tensor * src_state = dst->src[5];

     GGML_TENSOR_LOCALS(int64_t, neq, src_q, ne);
-    GGML_TENSOR_LOCALS(size_t, nbq, src_q, nb);
+    GGML_TENSOR_LOCALS(size_t , nbq, src_q, nb);
+    GGML_TENSOR_LOCALS(int64_t, nek, src_k, ne);
+    GGML_TENSOR_LOCALS(size_t , nbk, src_k, nb);
     GGML_TENSOR_LOCALS(int64_t, nev, src_v, ne);
     GGML_TENSOR_LOCALS(size_t, nbv, src_v, nb);
     GGML_TENSOR_LOCALS(size_t, nbb, src_beta, nb);

     const int64_t S_v = nev0;
     const int64_t H = nev1;
@@ -175,7 +177,9 @@ void ggml_cuda_op_gated_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor *
     const bool kda = (src_g->ne[0] == S_v);

-    const int64_t rq1 = nev1 / neq1;
+    GGML_ASSERT(neq1 == nek1);
+
+    const int64_t neqk1 = neq1;
     const int64_t rq3 = nev3 / neq3;

     const float * q_d = (const float *) src_q->data;
@@ -214,10 +218,10 @@ void ggml_cuda_op_gated_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor *
     if (kda) {
         launch_gated_delta_net<true>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d,
             S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
-            sb1, sb2, sb3, rq1, rq3, scale, stream);
+            sb1, sb2, sb3, neqk1, rq3, scale, stream);
     } else {
         launch_gated_delta_net<false>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d,
             S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
-            sb1, sb2, sb3, rq1, rq3, scale, stream);
+            sb1, sb2, sb3, neqk1, rq3, scale, stream);
     }
 }
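The CUDA kernel and its launcher now receive the native Q/K head count (neqk1) instead of the repeat ratio rq1, and each block resolves its Q/K head with a modulo. A simplified host-side sketch of that per-block index resolution (plain C++ stand-in, not a runnable CUDA kernel; the struct and function names are illustrative only):

// Sketch of the per-block index resolution from the kernel above: each block
// handles one (V head, sequence) pair; the matching Q/K head is recovered with
// a modulo rather than by dividing by a repeat ratio.
#include <cstdint>

struct block_idx { int64_t x, y; }; // stand-in for blockIdx

static void resolve_indices(block_idx blockIdx, int64_t neqk1, int64_t rq3,
                            int64_t & iq1, int64_t & iq3) {
    const int64_t h_idx    = blockIdx.x; // V head handled by this block
    const int64_t sequence = blockIdx.y;

    iq1 = h_idx    % neqk1; // Q/K head: modulo broadcast over the native Q/K head count
    iq3 = sequence / rq3;   // sequence broadcast is unchanged
}

int main() {
    int64_t iq1 = 0, iq3 = 0;
    resolve_indices({/*x=*/7, /*y=*/0}, /*neqk1=*/4, /*rq3=*/1, iq1, iq3);
    return (iq1 == 3 && iq3 == 0) ? 0 : 1;
}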

View File

@@ -4999,8 +4999,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
 #ifdef GGML_USE_MUSA
             return false;
 #else
-            // TODO: add chunked support
-            return op->src[0]->ne[2] == 1;
+            // TODO: add non-KDA chunked support. for now enable chunked support for KDA only
+            return op->src[0]->ne[2] == 1 || op->src[3]->ne[0] == op->src[2]->ne[0];
 #endif // GGML_USE_MUSA
         case GGML_OP_FLASH_ATTN_EXT:
             return ggml_cuda_flash_attn_ext_supported(dev_ctx->device, op);
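Judging from ggml_cuda_op_gated_delta_net above, src[2] is v and src[3] is the gate tensor g, so src[3]->ne[0] == src[2]->ne[0] is the same KDA test as (src_g->ne[0] == S_v), and src[0]->ne[2] is presumably the token dimension of q. A hedged paraphrase of the predicate (the struct and helper names are illustrative, not ggml API):

// Paraphrase of the support check: chunked decoding (more than one token per
// call) is only claimed as supported when the gate tensor g has the same row
// size as v, i.e. the KDA variant.
#include <cstdint>

struct tensor_shape { int64_t ne[4]; };

static bool gdn_supported(const tensor_shape & q, const tensor_shape & v, const tensor_shape & g) {
    const bool single_token = q.ne[2] == 1;
    const bool kda          = g.ne[0] == v.ne[0];
    return single_token || kda;
}

int main() {
    const tensor_shape q = {{128, 16, 32, 1}}; // 32 tokens -> chunked path
    const tensor_shape v = {{128, 32, 32, 1}};
    const tensor_shape g = {{128, 32, 32, 1}}; // g rows match v rows -> KDA
    return gdn_supported(q, v, g) ? 0 : 1;
}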

View File

@@ -321,9 +321,9 @@ ggml_tensor * llm_build_qwen35::build_layer_attn_linear(
     //v_conv = ggml_cont_4d(ctx0, v_conv, head_v_dim, num_v_heads, n_seq_tokens, n_seqs);

     // if head keys and value keys are different, repeat to force tensors into matching shapes
-    if (num_k_heads != num_v_heads) {
+    // note: need explicit repeat only if we are not using the fused GDN
+    if (num_k_heads != num_v_heads && (!cparams.fused_gdn_ar || !cparams.fused_gdn_ch)) {
         GGML_ASSERT(num_v_heads % num_k_heads == 0);
-        // TODO: try to avoid these explicit repeats by utilizing op broadcast
         q_conv = ggml_repeat_4d(ctx0, q_conv, head_k_dim, num_v_heads, n_seq_tokens, n_seqs);
         k_conv = ggml_repeat_4d(ctx0, k_conv, head_k_dim, num_v_heads, n_seq_tokens, n_seqs);
     }
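The new guard skips the explicit Q/K repeat only when both fused GDN flags (presumably the autoregressive and chunked variants) are enabled; the same condition appears in the qwen35moe build below. A minimal sketch of the decision (the flag names are taken from the diff; the surrounding struct and helper are illustrative):

// Minimal sketch: the explicit Q/K repeat is still built unless both fused GDN
// paths are enabled, since only the fused op handles the head broadcast itself.
#include <cstdio>

struct cparams_t { bool fused_gdn_ar; bool fused_gdn_ch; };

static bool needs_explicit_repeat(const cparams_t & cparams, int num_k_heads, int num_v_heads) {
    return num_k_heads != num_v_heads && (!cparams.fused_gdn_ar || !cparams.fused_gdn_ch);
}

int main() {
    const cparams_t fused   = {true, true };
    const cparams_t partial = {true, false}; // chunked fused op disabled
    printf("both fused:     repeat = %d\n", (int) needs_explicit_repeat(fused,   4, 8)); // 0
    printf("chunked unfused: repeat = %d\n", (int) needs_explicit_repeat(partial, 4, 8)); // 1
    return 0;
}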

View File

@@ -321,9 +321,9 @@ ggml_tensor * llm_build_qwen35moe::build_layer_attn_linear(
     //v_conv = ggml_cont_4d(ctx0, v_conv, head_v_dim, num_v_heads, n_seq_tokens, n_seqs);

     // if head keys and value keys are different, repeat to force tensors into matching shapes
-    if (num_k_heads != num_v_heads) {
+    // note: need explicit repeat only if we are not using the fused GDN
+    if (num_k_heads != num_v_heads && (!cparams.fused_gdn_ar || !cparams.fused_gdn_ch)) {
         GGML_ASSERT(num_v_heads % num_k_heads == 0);
-        // TODO: try to avoid these explicit repeats by utilizing op broadcast
         q_conv = ggml_repeat_4d(ctx0, q_conv, head_k_dim, num_v_heads, n_seq_tokens, n_seqs);
         k_conv = ggml_repeat_4d(ctx0, k_conv, head_k_dim, num_v_heads, n_seq_tokens, n_seqs);
     }

View File

@@ -406,6 +406,7 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
     //v_conv = ggml_cont_4d(ctx0, v_conv, head_v_dim, num_v_heads, n_seq_tokens, n_seqs);

     // if head keys and value keys are different, repeat to force tensors into matching shapes
+    // TODO: avoid repeats for fused GDN, needs broadcast configuration for GDN op [TAG_GGML_GDN_BCAST]
     if (num_k_heads != num_v_heads) {
         GGML_ASSERT(num_v_heads % num_k_heads == 0);
         int64_t repeat_factor = num_v_heads / num_k_heads;