graph : remove redundant GDN state transposes (#20443)
* ggml : transpose fused GDN state access for coalesced memory reads (#20436) The fused Gated Delta Net kernel accessed the [S_v, S_v] state matrix column-wise on row-major storage, causing strided reads (stride S_v = 128 floats = 512 bytes) that waste GPU cache bandwidth. This produced a 39% regression on Qwen3.5-9B (Metal, M4 Max) compared to the unfused path. Transpose the state indexing so threads read contiguously: - Metal: s_ptr[is*S_v] -> s_ptr[is] (stride 1 vs S_v) - CUDA: curr_state[i*S_v+col] -> curr_state[col*S_v+i] (coalesced) - CPU: restructured loops for row-wise transposed access Also add --fused-gdn [on|off|auto] CLI flag (mirrors --flash-attn) so users can control fused GDN independently of auto-detection. All GATED_DELTA_NET backend-ops tests pass. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * ggml : use SIMD dot products in CPU GDN kernel, couple AR/chunked fused flags - Replace scalar inner loops with ggml_vec_dot_f32 for SIMD-optimized dot products in the CPU fused GDN kernel (delta and attention output) - Couple fused_gdn_ar and fused_gdn_ch flags in auto-detection: if one path lacks device support, disable both to prevent state layout mismatch between transposed (fused) and non-transposed (unfused) formats Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * llama : rever fgdn argument changes * graph : remove GDN state transposes * vulkan : adapt * cuda : remove obsolete smem code --------- Co-authored-by: Paul Flynn <paul@arkavo.com> Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com> Co-authored-by: Oliver Simons <osimons@nvidia.com>
This commit is contained in:
parent
1430c35948
commit
e30f1fdf74
|
|
@ -10477,34 +10477,40 @@ static void ggml_compute_forward_gated_delta_net_one_chunk(
|
|||
const float beta_val = *(const float *)((const char *)src_beta->data + iv3 * nbb3 + t * nbb2 + iv1 * nbb1);
|
||||
const float * g_d = (const float *)((const char *)src_g->data + iv3 * nbg3 + t * nbg2 + iv1 * nbg1);
|
||||
|
||||
// state is stored transposed: s_out[j*S_v + i] = S[i][j]
|
||||
// so row j of s_out = column j of S (contiguous access)
|
||||
|
||||
if (kda) {
|
||||
// precompute exp(g) into delta scratch (reused below)
|
||||
for (int64_t i = 0; i < S_v; ++i) {
|
||||
ggml_vec_scale_f32(S_v, &s_out[i * S_v], expf(g_d[i]));
|
||||
delta[i] = expf(g_d[i]);
|
||||
}
|
||||
// S[i][:] *= exp(g[i]) => for each row j of M: M[j][i] *= exp(g[i])
|
||||
for (int64_t j = 0; j < S_v; ++j) {
|
||||
ggml_vec_mul_f32(S_v, &s_out[j * S_v], &s_out[j * S_v], delta);
|
||||
}
|
||||
} else {
|
||||
ggml_vec_scale_f32(S_v * S_v, s_out, expf(g_d[0]));
|
||||
}
|
||||
|
||||
// delta[j] = sum_i S[j][i] * k[i]
|
||||
memset(delta, 0, S_v * sizeof(float));
|
||||
for (int64_t i = 0; i < S_v; ++i) {
|
||||
ggml_vec_mad_f32(S_v, delta, &s_out[i * S_v], k_d[i]);
|
||||
}
|
||||
// delta[j] = sum_i S[i][j] * k[i] = dot(row j of M, k)
|
||||
for (int64_t j = 0; j < S_v; ++j) {
|
||||
delta[j] = (v_d[j] - delta[j]) * beta_val;
|
||||
float sum = 0.0f;
|
||||
ggml_vec_dot_f32(S_v, &sum, 0, &s_out[j * S_v], 0, k_d, 0, 1);
|
||||
delta[j] = (v_d[j] - sum) * beta_val;
|
||||
}
|
||||
|
||||
// outer product: S[j][i] += k[i] * delta[j]
|
||||
for (int64_t i = 0; i < S_v; ++i) {
|
||||
ggml_vec_mad_f32(S_v, &s_out[i * S_v], delta, k_d[i]);
|
||||
// outer product: S[i][j] += k[i] * delta[j] => M[j][i] += delta[j] * k[i]
|
||||
for (int64_t j = 0; j < S_v; ++j) {
|
||||
ggml_vec_mad_f32(S_v, &s_out[j * S_v], k_d, delta[j]);
|
||||
}
|
||||
|
||||
// attn_out[j] = sum_i S[j][i] * q[i]
|
||||
memset(attn_data, 0, S_v * sizeof(float));
|
||||
for (int64_t i = 0; i < S_v; ++i) {
|
||||
ggml_vec_mad_f32(S_v, attn_data, &s_out[i * S_v], q_d[i]);
|
||||
// attn_out[j] = sum_i S[i][j] * q[i] = dot(row j of M, q)
|
||||
for (int64_t j = 0; j < S_v; ++j) {
|
||||
float sum = 0.0f;
|
||||
ggml_vec_dot_f32(S_v, &sum, 0, &s_out[j * S_v], 0, q_d, 0, 1);
|
||||
attn_data[j] = sum * scale;
|
||||
}
|
||||
ggml_vec_scale_f32(S_v, attn_data, scale);
|
||||
|
||||
attn_data += S_v * H; // advance to next token
|
||||
}
|
||||
|
|
|
|||
|
|
@ -45,10 +45,11 @@ __global__ void gated_delta_net_cuda(const float * q,
|
|||
static_assert(S_v % warp_size == 0, "S_v must be a multiple of warp_size");
|
||||
constexpr int rows_per_lane = (S_v + warp_size - 1) / warp_size;
|
||||
float s_shard[rows_per_lane];
|
||||
// state is stored transposed: M[col][i] = S[i][col], row col is contiguous
|
||||
#pragma unroll
|
||||
for (int r = 0; r < rows_per_lane; r++) {
|
||||
const int i = r * warp_size + lane;
|
||||
s_shard[r] = curr_state[i * S_v + col];
|
||||
s_shard[r] = curr_state[col * S_v + i];
|
||||
}
|
||||
|
||||
for (int t = 0; t < n_tokens; t++) {
|
||||
|
|
@ -126,23 +127,14 @@ __global__ void gated_delta_net_cuda(const float * q,
|
|||
attn_data += S_v * H;
|
||||
}
|
||||
|
||||
// Write state back to global memory
|
||||
// Write state back to global memory (transposed layout)
|
||||
#pragma unroll
|
||||
for (int r = 0; r < rows_per_lane; r++) {
|
||||
const int i = r * warp_size + lane;
|
||||
state[i * S_v + col] = s_shard[r];
|
||||
state[col * S_v + i] = s_shard[r];
|
||||
}
|
||||
}
|
||||
|
||||
static size_t calculate_smem(const int sv, int cc)
|
||||
{
|
||||
size_t smem = 0;
|
||||
if ((GGML_CUDA_CC_IS_AMD(cc) && !GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_RDNA4(cc)) || GGML_CUDA_CC_IS_MTHREADS(cc)) {
|
||||
smem = sv * sv * sizeof(float);
|
||||
}
|
||||
return smem;
|
||||
}
|
||||
|
||||
template <bool KDA>
|
||||
static void launch_gated_delta_net(
|
||||
const float * q_d, const float * k_d, const float * v_d,
|
||||
|
|
@ -179,18 +171,14 @@ static void launch_gated_delta_net(
|
|||
sb1, sb2, sb3, neqk1_magic, rq3_magic, scale);
|
||||
break;
|
||||
case 64: {
|
||||
constexpr int sv = 64;
|
||||
size_t smem = calculate_smem(sv, cc);
|
||||
gated_delta_net_cuda<sv, KDA><<<grid_dims, block_dims, smem, stream>>>(
|
||||
gated_delta_net_cuda<64, KDA><<<grid_dims, block_dims, 0, stream>>>(
|
||||
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
|
||||
n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
|
||||
sb1, sb2, sb3, neqk1_magic, rq3_magic, scale);
|
||||
break;
|
||||
}
|
||||
case 128: {
|
||||
constexpr int sv = 128;
|
||||
size_t smem = calculate_smem(sv, cc);
|
||||
gated_delta_net_cuda<sv, KDA><<<grid_dims, block_dims, smem, stream>>>(
|
||||
gated_delta_net_cuda<128, KDA><<<grid_dims, block_dims, 0, stream>>>(
|
||||
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
|
||||
n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
|
||||
sb1, sb2, sb3, neqk1_magic, rq3_magic, scale);
|
||||
|
|
|
|||
|
|
@ -2469,13 +2469,14 @@ kernel void kernel_gated_delta_net_impl(
|
|||
|
||||
const float scale = 1.0f / sqrt((float)S_v);
|
||||
|
||||
device const float * s_ptr = (device const float *) (s) + (i23*args.ne21 + i21)*S_v*S_v + i20;
|
||||
// state is stored transposed: M[i20][is] = S[is][i20], so row i20 is contiguous
|
||||
device const float * s_ptr = (device const float *) (s) + (i23*args.ne21 + i21)*S_v*S_v + i20*S_v;
|
||||
|
||||
float ls[NSG];
|
||||
|
||||
FOR_UNROLL (short j = 0; j < NSG; j++) {
|
||||
const short is = tx*NSG + j;
|
||||
ls[j] = s_ptr[is*S_v];
|
||||
ls[j] = s_ptr[is];
|
||||
}
|
||||
|
||||
device float * dst_attn = (device float *) (dst) + (i23*args.ne22*args.ne21 + i21)*S_v + i20;
|
||||
|
|
@ -2536,11 +2537,11 @@ kernel void kernel_gated_delta_net_impl(
|
|||
g_ptr += args.ne21*G;
|
||||
}
|
||||
|
||||
device float * dst_state = (device float *) (dst) + args.ne23*args.ne22*args.ne21*S_v + (i23*args.ne21 + i21)*S_v*S_v + i20;
|
||||
device float * dst_state = (device float *) (dst) + args.ne23*args.ne22*args.ne21*S_v + (i23*args.ne21 + i21)*S_v*S_v + i20*S_v;
|
||||
|
||||
FOR_UNROLL (short j = 0; j < NSG; j++) {
|
||||
const short is = tx*NSG + j;
|
||||
dst_state[is*S_v] = ls[j];
|
||||
dst_state[is] = ls[j];
|
||||
}
|
||||
|
||||
#undef S_v
|
||||
|
|
|
|||
|
|
@ -44,7 +44,7 @@ void main() {
|
|||
|
||||
FLOAT_TYPE state[S_V];
|
||||
[[unroll]] for (uint i = 0; i < S_V; i++) {
|
||||
state[i] = FLOAT_TYPE(data_state[state_base + i * S_V + col]);
|
||||
state[i] = FLOAT_TYPE(data_state[state_base + col * S_V + i]);
|
||||
}
|
||||
|
||||
uint attn_off = (seq_id * n_tokens * H + head_id) * S_V;
|
||||
|
|
@ -123,6 +123,6 @@ void main() {
|
|||
}
|
||||
|
||||
[[unroll]] for (uint i = 0; i < S_V; i++) {
|
||||
data_dst[s_off + state_base + i * S_V + col] = state[i];
|
||||
data_dst[s_off + state_base + col * S_V + i] = state[i];
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -225,9 +225,8 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_delta_net_base::build_delta_ne
|
|||
ggml_tensor * kg_t = ggml_cont(ctx0, ggml_transpose(ctx0, kg));
|
||||
cb(kg_t, "key_gdiff_t", il);
|
||||
|
||||
ggml_tensor * s_t = ggml_transpose(ctx0, s);
|
||||
s_t = ggml_cont_4d(ctx0, s_t, S_v, S_v, 1, H_v * n_seqs);
|
||||
cb(s_t, "dnet_add_ch_state", il);
|
||||
s = ggml_reshape_4d(ctx0, s, S_v, S_v, 1, H_v * n_seqs);
|
||||
cb(s, "dnet_add_ch_state", il);
|
||||
|
||||
// [CS, S_v, n_chunks, H_v * n_seqs]
|
||||
ggml_tensor * v_t = ggml_cont(ctx0, ggml_transpose(ctx0, v));
|
||||
|
|
@ -240,7 +239,7 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_delta_net_base::build_delta_ne
|
|||
ggml_tensor * ch_kg_t = get_slice_2d(ctx0, kg_t, chunk); // [ CS, S_k, 1, H_v * n_seqs]
|
||||
|
||||
// [CS, S_v, 1, H_v * n_seqs]
|
||||
ggml_tensor * v_t_p = ggml_mul_mat(ctx0, ch_k_cd, s_t);
|
||||
ggml_tensor * v_t_p = ggml_mul_mat(ctx0, ch_k_cd, s);
|
||||
cb(v_t_p, "v_prime", il);
|
||||
|
||||
// [CS, S_v, 1, H_v * n_seqs]
|
||||
|
|
@ -252,7 +251,7 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_delta_net_base::build_delta_ne
|
|||
cb(v_attn, "v_attn", il);
|
||||
|
||||
// [S_v, CS, 1, H_v * n_seqs]
|
||||
ggml_tensor * attn_inter = ggml_mul_mat(ctx0, s_t, ch_q_g_exp);
|
||||
ggml_tensor * attn_inter = ggml_mul_mat(ctx0, s, ch_q_g_exp);
|
||||
cb(attn_inter, "attn_inter", il);
|
||||
|
||||
// [S_v, CS, 1, H_v * n_seqs]
|
||||
|
|
@ -268,13 +267,11 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_delta_net_base::build_delta_ne
|
|||
// last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew
|
||||
ggml_tensor * ch_g_last_exp_t = get_slice_2d(ctx0, g_last_exp_t, chunk);
|
||||
|
||||
s_t = ggml_mul(ctx0, s_t, ch_g_last_exp_t);
|
||||
s_t = ggml_add(ctx0, s_t, kgv);
|
||||
cb(s_t, "dnet_add_ch_state", il);
|
||||
s = ggml_mul(ctx0, s, ch_g_last_exp_t);
|
||||
s = ggml_add(ctx0, s, kgv);
|
||||
cb(s, "dnet_add_ch_state", il);
|
||||
}
|
||||
|
||||
s_t = ggml_reshape_4d(ctx0, s_t, S_v, S_v, H_v, n_seqs);
|
||||
|
||||
// truncate padded tokens
|
||||
ggml_tensor * o = ggml_view_4d(ctx0, v,
|
||||
S_v, n_tokens, H_v, n_seqs,
|
||||
|
|
@ -282,7 +279,7 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_delta_net_base::build_delta_ne
|
|||
ggml_row_size(v->type, S_v * CS * n_chunks),
|
||||
ggml_row_size(v->type, S_v * CS * n_chunks * H_v), 0);
|
||||
o = ggml_permute (ctx0, o, 0, 2, 1, 3); // [S_v, H_v, n_tokens, n_seqs]
|
||||
s = ggml_transpose(ctx0, s_t);
|
||||
s = ggml_reshape_4d(ctx0, s, S_v, S_v, H_v, n_seqs);
|
||||
cb(s, "output_state", il);
|
||||
|
||||
return {o, s};
|
||||
|
|
@ -341,11 +338,9 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_delta_net_base::build_delta_ne
|
|||
g = ggml_exp(ctx0, g);
|
||||
s = ggml_mul(ctx0, s, g);
|
||||
|
||||
ggml_tensor * s_t = ggml_cont(ctx0, ggml_transpose(ctx0, s));
|
||||
|
||||
// [1, S_v, H_v, n_seqs]
|
||||
ggml_tensor * sk;
|
||||
sk = ggml_mul (ctx0, s_t, k);
|
||||
sk = ggml_mul (ctx0, s, k);
|
||||
sk = ggml_sum_rows(ctx0, sk);
|
||||
|
||||
// [S_v, 1, H_v, n_seqs]
|
||||
|
|
@ -362,15 +357,14 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_delta_net_base::build_delta_ne
|
|||
k = ggml_repeat(ctx0, k, s);
|
||||
kd = ggml_mul (ctx0, k, d_t);
|
||||
|
||||
s_t = ggml_add(ctx0, s_t, kd);
|
||||
s = ggml_add(ctx0, s, kd);
|
||||
|
||||
cb(s_t, "dnet_add_ar_state", il);
|
||||
cb(s, "dnet_add_ar_state", il);
|
||||
|
||||
ggml_tensor * s_q = ggml_mul (ctx0, s_t, q);
|
||||
ggml_tensor * s_q = ggml_mul (ctx0, s, q);
|
||||
ggml_tensor * o = ggml_sum_rows(ctx0, s_q);
|
||||
|
||||
o = ggml_permute (ctx0, o, 2, 0, 1, 3); // [S_v, H_v, n_tokens, n_seqs]
|
||||
s = ggml_transpose(ctx0, s_t); // [S_v, S_v, H_v, n_seqs]
|
||||
|
||||
return {o, s};
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue