diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 3323f8e6c3..25f9601e9b 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -2466,6 +2466,8 @@ extern "C" { bool lower, bool uni); + // TODO: add ggml_gated_delta_net_set_bcast() to be able to configure Q, K broadcast type: tiled vs interleaved [TAG_GGML_GDN_BCAST] + // ref: https://github.com/ggml-org/llama.cpp/pull/19468#discussion_r2786394306 GGML_API struct ggml_tensor * ggml_gated_delta_net( struct ggml_context * ctx, struct ggml_tensor * q, diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index f9c4ec16e4..fa9d27046b 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -10443,8 +10443,8 @@ static void ggml_compute_forward_gated_delta_net_one_chunk( const float * state_in_base = (const float *)src_state->data; - const int64_t rq1 = nev1 / neq1; - const int64_t rk1 = nev1 / nek1; + //const int64_t rq1 = nev1 / neq1; + //const int64_t rk1 = nev1 / nek1; const int64_t rq3 = nev3 / neq3; const int64_t rk3 = nev3 / nek3; @@ -10454,8 +10454,8 @@ static void ggml_compute_forward_gated_delta_net_one_chunk( const int64_t iv1 = ir % H; // head_index const int64_t iv3 = ir / H; // sequence - const int64_t iq1 = iv1 / rq1; - const int64_t ik1 = iv1 / rk1; + const int64_t iq1 = iv1 % neq1; + const int64_t ik1 = iv1 % nek1; const int64_t iq3 = iv3 / rq3; const int64_t ik3 = iv3 / rk3; @@ -10475,7 +10475,7 @@ static void ggml_compute_forward_gated_delta_net_one_chunk( const float * v_d = (const float *)((const char *)src_v->data + iv3 * nbv3 + t * nbv2 + iv1 * nbv1); const float beta_val = *(const float *)((const char *)src_beta->data + iv3 * nbb3 + t * nbb2 + iv1 * nbb1); - const float * g_d = (const float *)((const char *)src_g->data + iv3 * nbg3 + t * nbg2 + iv1 * nbg1); + const float * g_d = (const float *)((const char *)src_g->data + iv3 * nbg3 + t * nbg2 + iv1 * nbg1); if (kda) { for (int64_t i = 0; i < S_v; ++i) { @@ -10508,7 +10508,6 @@ static void 
ggml_compute_forward_gated_delta_net_one_chunk( attn_data += S_v * H; // advance to next token } - } } diff --git a/ggml/src/ggml-cuda/gated_delta_net.cu b/ggml/src/ggml-cuda/gated_delta_net.cu index c249bbc86d..5f0fa8e58d 100644 --- a/ggml/src/ggml-cuda/gated_delta_net.cu +++ b/ggml/src/ggml-cuda/gated_delta_net.cu @@ -1,36 +1,36 @@ #include "gated_delta_net.cuh" -#include "ggml-cuda/common.cuh" template -__global__ void __launch_bounds__(S_v, 1) -gated_delta_net_cuda(const float * q, - const float * k, - const float * v, - const float * g, - const float * beta, - const float * curr_state, - float * dst, - const int64_t H, - const int64_t n_tokens, - const int64_t n_seqs, - const int64_t sq1, - const int64_t sq2, - const int64_t sq3, - const int64_t sv1, - const int64_t sv2, - const int64_t sv3, - const int64_t sb1, - const int64_t sb2, - const int64_t sb3, - const int64_t rq1, - const int64_t rq3, - const float scale) { - const int64_t h_idx = blockIdx.x; - const int64_t sequence = blockIdx.y; - const int col = threadIdx.x; // each thread owns one column +__global__ void gated_delta_net_cuda(const float * q, + const float * k, + const float * v, + const float * g, + const float * beta, + const float * curr_state, + float * dst, + int64_t H, + int64_t n_tokens, + int64_t n_seqs, + int64_t sq1, + int64_t sq2, + int64_t sq3, + int64_t sv1, + int64_t sv2, + int64_t sv3, + int64_t sb1, + int64_t sb2, + int64_t sb3, + const uint3 neqk1_magic, + const uint3 rq3_magic, + float scale) { + const uint32_t h_idx = blockIdx.x; + const uint32_t sequence = blockIdx.y; + // each warp owns one column, using warp-level primitives to reduce across rows + const int lane = threadIdx.x; + const int col = blockIdx.z * blockDim.y + threadIdx.y; - const int64_t iq1 = h_idx / rq1; - const int64_t iq3 = sequence / rq3; + const uint32_t iq1 = fastmodulo(h_idx, neqk1_magic); + const uint32_t iq3 = fastdiv(sequence, rq3_magic); const int64_t attn_score_elems = S_v * H * n_tokens * n_seqs; 
float * attn_data = dst; @@ -41,17 +41,14 @@ gated_delta_net_cuda(const float * q, curr_state += state_offset; attn_data += (sequence * n_tokens * H + h_idx) * S_v; - // GCN and CDNA devices spill registers, we use shared mem for them. See https://github.com/ggml-org/llama.cpp/pull/20282#issuecomment-4025770229 - // TODO: check optimal path for RDNA1 and RDNA2 devices. -#if (defined(GGML_USE_HIP) && !defined(RDNA3) && !defined(RDNA4)) || defined(GGML_USE_MUSA) - extern __shared__ float s_shared[]; - float * s = s_shared + col * S_v; -#else - float s[S_v]; -#endif + constexpr int warp_size = ggml_cuda_get_physical_warp_size() < S_v ? ggml_cuda_get_physical_warp_size() : S_v; + static_assert(S_v % warp_size == 0, "S_v must be a multiple of warp_size"); + constexpr int rows_per_lane = (S_v + warp_size - 1) / warp_size; + float s_shard[rows_per_lane]; #pragma unroll - for (int i = 0; i < S_v; i++) { - s[i] = curr_state[i * S_v + col]; + for (int r = 0; r < rows_per_lane; r++) { + const int i = r * warp_size + lane; + s_shard[r] = curr_state[i * S_v + col]; } for (int t = 0; t < n_tokens; t++) { @@ -69,46 +66,61 @@ gated_delta_net_cuda(const float * q, const float g_val = expf(*g_t); // kv[col] = (S^T @ k)[col] = sum_i S[i][col] * k[i] - float kv_col = 0.0f; + float kv_shard = 0.0f; #pragma unroll - for (int i = 0; i < S_v; i++) { - kv_col += s[i] * k_t[i]; + for (int r = 0; r < rows_per_lane; r++) { + const int i = r * warp_size + lane; + kv_shard += s_shard[r] * k_t[i]; } + float kv_col = warp_reduce_sum(kv_shard); // delta[col] = (v[col] - g * kv[col]) * beta float delta_col = (v_t[col] - g_val * kv_col) * beta_val; // fused: S[i][col] = g * S[i][col] + k[i] * delta[col] // attn[col] = (S^T @ q)[col] = sum_i S[i][col] * q[i] - float attn_col = 0.0f; + float attn_partial = 0.0f; #pragma unroll - for (int i = 0; i < S_v; i++) { - s[i] = g_val * s[i] + k_t[i] * delta_col; - attn_col += s[i] * q_t[i]; + for (int r = 0; r < rows_per_lane; r++) { + const int i = r * 
warp_size + lane; + s_shard[r] = g_val * s_shard[r] + k_t[i] * delta_col; + attn_partial += s_shard[r] * q_t[i]; } - attn_data[col] = attn_col * scale; + float attn_col = warp_reduce_sum(attn_partial); + + if (lane == 0) { + attn_data[col] = attn_col * scale; + } } else { // kv[col] = sum_i g[i] * S[i][col] * k[i] - float kv_col = 0.0f; + float kv_shard = 0.0f; #pragma unroll - for (int i = 0; i < S_v; i++) { - kv_col += expf(g_t[i]) * s[i] * k_t[i]; + for (int r = 0; r < rows_per_lane; r++) { + const int i = r * warp_size + lane; + kv_shard += expf(g_t[i]) * s_shard[r] * k_t[i]; } + float kv_col = warp_reduce_sum(kv_shard); + // delta[col] = (v[col] - kv[col]) * beta float delta_col = (v_t[col] - kv_col) * beta_val; // fused: S[i][col] = g[i] * S[i][col] + k[i] * delta[col] // attn[col] = (S^T @ q)[col] = sum_i S[i][col] * q[i] - float attn_col = 0.0f; + float attn_partial = 0.0f; #pragma unroll - for (int i = 0; i < S_v; i++) { - s[i] = expf(g_t[i]) * s[i] + k_t[i] * delta_col; - attn_col += s[i] * q_t[i]; + for (int r = 0; r < rows_per_lane; r++) { + const int i = r * warp_size + lane; + s_shard[r] = expf(g_t[i]) * s_shard[r] + k_t[i] * delta_col; + attn_partial += s_shard[r] * q_t[i]; } - attn_data[col] = attn_col * scale; + float attn_col = warp_reduce_sum(attn_partial); + + if (lane == 0) { + attn_data[col] = attn_col * scale; + } } attn_data += S_v * H; @@ -116,8 +128,9 @@ gated_delta_net_cuda(const float * q, // Write state back to global memory #pragma unroll - for (int i = 0; i < S_v; i++) { - state[i * S_v + col] = s[i]; + for (int r = 0; r < rows_per_lane; r++) { + const int i = r * warp_size + lane; + state[i * S_v + col] = s_shard[r]; } } @@ -135,35 +148,43 @@ static void launch_gated_delta_net( const float * q_d, const float * k_d, const float * v_d, const float * g_d, const float * b_d, const float * s_d, float * dst_d, - int64_t S_v, int64_t H, int64_t n_tokens, int64_t n_seqs, - int64_t sq1, int64_t sq2, int64_t sq3, - int64_t sv1, int64_t sv2, 
int64_t sv3, - int64_t sb1, int64_t sb2, int64_t sb3, - int64_t rq1, int64_t rq3, + int64_t S_v, int64_t H, int64_t n_tokens, int64_t n_seqs, + int64_t sq1, int64_t sq2, int64_t sq3, + int64_t sv1, int64_t sv2, int64_t sv3, + int64_t sb1, int64_t sb2, int64_t sb3, + int64_t neqk1, int64_t rq3, float scale, cudaStream_t stream) { + //TODO: Add chunked kernel for even faster pre-fill + const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size; + const int num_warps = 4; + dim3 grid_dims(H, n_seqs, (S_v + num_warps - 1) / num_warps); + dim3 block_dims(warp_size <= S_v ? warp_size : S_v, num_warps, 1); - dim3 grid_dims(H, n_seqs, 1); - dim3 block_dims(S_v, 1, 1); + const uint3 neqk1_magic = init_fastdiv_values(neqk1); + const uint3 rq3_magic = init_fastdiv_values(rq3); int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; switch (S_v) { - case 32: { - constexpr int sv = 32; - size_t smem = calculate_smem(sv, cc); - gated_delta_net_cuda<<>>( + case 16: + gated_delta_net_cuda<16, KDA><<>>( q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, - sb1, sb2, sb3, rq1, rq3, scale); + sb1, sb2, sb3, neqk1_magic, rq3_magic, scale); + break; + case 32: + gated_delta_net_cuda<32, KDA><<>>( + q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, + n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, + sb1, sb2, sb3, neqk1_magic, rq3_magic, scale); break; - } case 64: { constexpr int sv = 64; size_t smem = calculate_smem(sv, cc); gated_delta_net_cuda<<>>( q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, - sb1, sb2, sb3, rq1, rq3, scale); + sb1, sb2, sb3, neqk1_magic, rq3_magic, scale); break; } case 128: { @@ -172,7 +193,7 @@ static void launch_gated_delta_net( gated_delta_net_cuda<<>>( q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, - sb1, sb2, sb3, rq1, rq3, scale); + sb1, sb2, sb3, neqk1_magic, rq3_magic, scale); break; } default: @@ -190,10 +211,12 @@ 
void ggml_cuda_op_gated_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor * ggml_tensor * src_state = dst->src[5]; GGML_TENSOR_LOCALS(int64_t, neq, src_q, ne); - GGML_TENSOR_LOCALS(size_t, nbq, src_q, nb); + GGML_TENSOR_LOCALS(size_t , nbq, src_q, nb); + GGML_TENSOR_LOCALS(int64_t, nek, src_k, ne); + GGML_TENSOR_LOCALS(size_t , nbk, src_k, nb); GGML_TENSOR_LOCALS(int64_t, nev, src_v, ne); - GGML_TENSOR_LOCALS(size_t, nbv, src_v, nb); - GGML_TENSOR_LOCALS(size_t, nbb, src_beta, nb); + GGML_TENSOR_LOCALS(size_t, nbv, src_v, nb); + GGML_TENSOR_LOCALS(size_t, nbb, src_beta, nb); const int64_t S_v = nev0; const int64_t H = nev1; @@ -202,7 +225,9 @@ void ggml_cuda_op_gated_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor * const bool kda = (src_g->ne[0] == S_v); - const int64_t rq1 = nev1 / neq1; + GGML_ASSERT(neq1 == nek1); + const int64_t neqk1 = neq1; + const int64_t rq3 = nev3 / neq3; const float * q_d = (const float *) src_q->data; @@ -241,10 +266,10 @@ void ggml_cuda_op_gated_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor * if (kda) { launch_gated_delta_net(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, - sb1, sb2, sb3, rq1, rq3, scale, stream); + sb1, sb2, sb3, neqk1, rq3, scale, stream); } else { launch_gated_delta_net(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, - sb1, sb2, sb3, rq1, rq3, scale, stream); + sb1, sb2, sb3, neqk1, rq3, scale, stream); } } diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp index 169c63dd7a..15ae2e517d 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.cpp +++ b/ggml/src/ggml-metal/ggml-metal-device.cpp @@ -577,6 +577,41 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rwkv(ggml_metal_ return res; } +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_gated_delta_net(ggml_metal_library_t lib, const ggml_tensor * op) { + char base[256]; + char 
name[256]; + + // v is src[2], dimensions: S_v = ne[0], H = ne[1] + const int ne20 = op->src[2]->ne[0]; // S_v + const int ne21 = op->src[2]->ne[1]; // H + const int ne30 = op->src[3]->ne[0]; // G + + const int nsg = op->src[2]->ne[0]/32; + + GGML_ASSERT(op->src[5]->type == GGML_TYPE_F32); + GGML_ASSERT(op->ne[0] == ne20 * ne21); + GGML_ASSERT(ne20 % 32 == 0); + + snprintf(base, 256, "kernel_gated_delta_net_%s_%d", ggml_type_name(op->src[0]->type), nsg); + snprintf(name, 256, "%s_ne20=%d_ne30=%d", base, ne20, ne30); + + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + ggml_metal_cv_t cv = ggml_metal_cv_init(); + + ggml_metal_cv_set_int16(cv, ne20, FC_GATED_DELTA_NET + 0); + ggml_metal_cv_set_int16(cv, ne30, FC_GATED_DELTA_NET + 1); + + res = ggml_metal_library_compile_pipeline(lib, base, name, cv); + + ggml_metal_cv_free(cv); + } + + res.nsg = nsg; + + return res; +} + ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_solve_tri(ggml_metal_library_t lib, const ggml_tensor * op) { char base[256]; char name[256]; diff --git a/ggml/src/ggml-metal/ggml-metal-device.h b/ggml/src/ggml-metal/ggml-metal-device.h index 93d7f6a216..fd2b3ddeb5 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.h +++ b/ggml/src/ggml-metal/ggml-metal-device.h @@ -125,6 +125,7 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv_batched (ggml_metal_library_t lib, const struct ggml_tensor * op, int ssm_conv_bs); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_scan (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rwkv (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_gated_delta_net (ggml_metal_library_t lib, const struct ggml_tensor 
* op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_solve_tri (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_ext (ggml_metal_library_t lib, enum ggml_type tsrc0, enum ggml_type tsrc1, int nsg, int nxpsg, int r1ptg); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm (ggml_metal_library_t lib, const struct ggml_tensor * op); diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index d42b8ab1eb..05b826a61b 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -1155,6 +1155,8 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te case GGML_OP_RWKV_WKV6: case GGML_OP_RWKV_WKV7: return true; + case GGML_OP_GATED_DELTA_NET: + return op->src[2]->ne[0] % 32 == 0; case GGML_OP_SOLVE_TRI: case GGML_OP_MUL_MAT: case GGML_OP_MUL_MAT_ID: diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h index 99d64efc3b..53437b23cd 100644 --- a/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -84,6 +84,7 @@ #define FC_BIN 1300 #define FC_SUM_ROWS 1400 #define FC_UPSCALE 1500 +#define FC_GATED_DELTA_NET 1600 // op-specific constants #define OP_FLASH_ATTN_EXT_NQPSG 8 @@ -793,6 +794,44 @@ typedef struct { uint64_t nb0; } ggml_metal_kargs_ssm_scan; +typedef struct { + int32_t ne00; + int32_t ne01; + int32_t ne02; + int32_t ne03; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int32_t ne10; + int32_t ne11; + int32_t ne12; + int32_t ne13; + uint64_t nb10; + uint64_t nb11; + uint64_t nb12; + uint64_t nb13; + int32_t ne20; + int32_t ne21; + int32_t ne22; + int32_t ne23; + uint64_t nb20; + uint64_t nb21; + uint64_t nb22; + uint64_t nb23; + int32_t ns02; + int32_t ns12; + int32_t ns22; + int32_t ne0; + int32_t ne1; + int32_t ne2; + int32_t ne3; + 
uint64_t nb0; + uint64_t nb1; + uint64_t nb2; + uint64_t nb3; +} ggml_metal_kargs_gated_delta_net; + typedef struct { int32_t ne00; int32_t ne01; diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index 267755d08c..306dbcf366 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -333,6 +333,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) { { n_fuse = ggml_metal_op_rwkv(ctx, idx); } break; + case GGML_OP_GATED_DELTA_NET: + { + n_fuse = ggml_metal_op_gated_delta_net(ctx, idx); + } break; case GGML_OP_SOLVE_TRI: { n_fuse = ggml_metal_op_solve_tri(ctx, idx); @@ -1562,6 +1566,81 @@ int ggml_metal_op_rwkv(ggml_metal_op_t ctx, int idx) { return 1; } +int ggml_metal_op_gated_delta_net(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne); + GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb); + GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne); + GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); + + auto pipeline = ggml_metal_library_get_pipeline_gated_delta_net(lib, op); + + int ida = 0; + + ggml_metal_kargs_gated_delta_net args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne10 =*/ ne10, + /*.ne11 =*/ ne11, + /*.ne12 =*/ ne12, + /*.ne13 =*/ ne13, + /*.nb10 =*/ nb10, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.nb13 =*/ nb13, + /*.ne20 =*/ ne20, + /*.ne21 =*/ ne21, + /*.ne22 =*/ ne22, + /*.ne23 =*/ ne23, + /*.nb20 =*/ nb20, + /*.nb21 =*/ nb21, + /*.nb22 =*/ nb22, + /*.nb23 =*/ nb23, + /*.ns02 =*/ 
(int32_t) (nb02/sizeof(float)), + /*.ns12 =*/ (int32_t) (nb12/sizeof(float)), + /*.ns22 =*/ (int32_t) (nb22/sizeof(float)), + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + }; + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), ida++); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), ida++); // q + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), ida++); // k + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), ida++); // v + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[3]), ida++); // gate + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[4]), ida++); // beta + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[5]), ida++); // state + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), ida++); // dst + + const int nsg = pipeline.nsg; + + ggml_metal_encoder_dispatch_threadgroups(enc, op->src[2]->ne[0]/nsg, op->src[2]->ne[1], op->src[2]->ne[3], 32, nsg, 1); + + return 1; +} + int ggml_metal_op_solve_tri(ggml_metal_op_t ctx, int idx) { ggml_tensor * op = ctx->node(idx); diff --git a/ggml/src/ggml-metal/ggml-metal-ops.h b/ggml/src/ggml-metal/ggml-metal-ops.h index f3e38c7aa9..019f2fec9e 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.h +++ b/ggml/src/ggml-metal/ggml-metal-ops.h @@ -58,6 +58,7 @@ int ggml_metal_op_soft_max (ggml_metal_op_t ctx, int idx); int ggml_metal_op_ssm_conv (ggml_metal_op_t ctx, int idx); int ggml_metal_op_ssm_scan (ggml_metal_op_t ctx, int idx); int ggml_metal_op_rwkv (ggml_metal_op_t ctx, int idx); +int ggml_metal_op_gated_delta_net (ggml_metal_op_t ctx, int idx); int ggml_metal_op_solve_tri (ggml_metal_op_t ctx, int idx); int ggml_metal_op_set (ggml_metal_op_t ctx, int idx); int ggml_metal_op_cpy (ggml_metal_op_t ctx, int idx); diff --git 
a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 29e4a245d5..0b77d5349b 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -2434,6 +2434,227 @@ kernel void kernel_rwkv_wkv7_f32( } } +constant short FC_gated_delta_net_ne20 [[function_constant(FC_GATED_DELTA_NET + 0)]]; +constant short FC_gated_delta_net_ne30 [[function_constant(FC_GATED_DELTA_NET + 1)]]; + +#if 1 +template +kernel void kernel_gated_delta_net_impl( + constant ggml_metal_kargs_gated_delta_net & args, + device const char * q, + device const char * k, + device const char * v, + device const char * g, + device const char * b, + device const char * s, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { +#define S_v FC_gated_delta_net_ne20 +#define G FC_gated_delta_net_ne30 + + const uint tx = tpitg.x; + const uint ty = tpitg.y; + + const uint i23 = tgpig.z; // B + const uint i21 = tgpig.y; // H + const uint i20 = tgpig.x*NSG + ty; + + const uint i01 = i21 % args.ne01; + const uint i11 = i21 % args.ne11; + + const float scale = 1.0f / sqrt((float)S_v); + + device const float * s_ptr = (device const float *) (s) + (i23*args.ne21 + i21)*S_v*S_v + i20; + + float ls[NSG]; + + FOR_UNROLL (short j = 0; j < NSG; j++) { + const short is = tx*NSG + j; + ls[j] = s_ptr[is*S_v]; + } + + device float * dst_attn = (device float *) (dst) + (i23*args.ne22*args.ne21 + i21)*S_v + i20; + + device const float * q_ptr = (device const float *) (q + i23*args.nb03 + i01*args.nb01); + device const float * k_ptr = (device const float *) (k + i23*args.nb13 + i11*args.nb11); + device const float * v_ptr = (device const float *) (v + i23*args.nb23 + i21*args.nb21); + + device const float * b_ptr = (device const float *) (b) + (i23*args.ne22*args.ne21 + i21); + device const float * g_ptr = (device const float *) (g) + (i23*args.ne22*args.ne21 + i21)*G; 
+ + for (short t = 0; t < args.ne22; t++) { + float s_k = 0.0f; + + if (G == 1) { + const float g_exp = exp(g_ptr[0]); + + FOR_UNROLL (short j = 0; j < NSG; j++) { + const short is = tx*NSG + j; + ls[j] *= g_exp; + + s_k += ls[j]*k_ptr[is]; + } + } else { + // KDA + FOR_UNROLL (short j = 0; j < NSG; j++) { + const short is = tx*NSG + j; + ls[j] *= exp(g_ptr[is]); + + s_k += ls[j]*k_ptr[is]; + } + } + + s_k = simd_sum(s_k); + + const float d = (v_ptr[i20] - s_k)*b_ptr[0]; + + float y = 0.0f; + + FOR_UNROLL (short j = 0; j < NSG; j++) { + const short is = tx*NSG + j; + ls[j] += k_ptr[is]*d; + + y += ls[j]*q_ptr[is]; + } + + y = simd_sum(y); + + if (tx == 0) { + dst_attn[t*args.ne21*S_v] = y*scale; + } + + q_ptr += args.ns02; + k_ptr += args.ns12; + v_ptr += args.ns22; + + b_ptr += args.ne21; + g_ptr += args.ne21*G; + } + + device float * dst_state = (device float *) (dst) + args.ne23*args.ne22*args.ne21*S_v + (i23*args.ne21 + i21)*S_v*S_v + i20; + + FOR_UNROLL (short j = 0; j < NSG; j++) { + const short is = tx*NSG + j; + dst_state[is*S_v] = ls[j]; + } + +#undef S_v +#undef G +} + +typedef decltype(kernel_gated_delta_net_impl<4>) kernel_gated_delta_net_t; + +template [[host_name("kernel_gated_delta_net_f32_1")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl<1>; +template [[host_name("kernel_gated_delta_net_f32_2")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl<2>; +template [[host_name("kernel_gated_delta_net_f32_4")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl<4>; + +#else +// a simplified version of the above +// no performance improvement, so keep the above version for now + +template +kernel void kernel_gated_delta_net_impl( + constant ggml_metal_kargs_gated_delta_net & args, + device const char * q, + device const char * k, + device const char * v, + device const char * g, + device const char * b, + device const char * s, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 
tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { +#define S_v FC_gated_delta_net_ne20 +#define G FC_gated_delta_net_ne30 + + const uint tx = tpitg.x; + const uint ty = tpitg.y; + + const uint i23 = tgpig.z; // B + const uint i21 = tgpig.y; // H + const uint i20 = tgpig.x*NSG + ty; + + const uint i01 = i21 % args.ne01; + const uint i11 = i21 % args.ne11; + + const float scale = 1.0f / sqrt((float)S_v); + + device const float * s_ptr = (device const float *) (s) + (i23*args.ne21 + i21)*S_v*S_v + i20; + + float lsf[NSG]; + + FOR_UNROLL (short j = 0; j < NSG; j++) { + const short is = tx*NSG + j; + lsf[j] = s_ptr[is*S_v]; + } + + thread T * ls = (thread T *) (lsf); + + device float * dst_attn = (device float *) (dst) + (i23*args.ne22*args.ne21 + i21)*S_v + i20; + + device const float * q_ptr = (device const float *) (q + i23*args.nb03 + i01*args.nb01); + device const float * k_ptr = (device const float *) (k + i23*args.nb13 + i11*args.nb11); + device const float * v_ptr = (device const float *) (v + i23*args.nb23 + i21*args.nb21); + + device const float * b_ptr = (device const float *) (b) + (i23*args.ne22*args.ne21 + i21); + device const float * g_ptr = (device const float *) (g) + (i23*args.ne22*args.ne21 + i21)*G; + + for (short t = 0; t < args.ne22; t++) { + device const T * qt_ptr = (device const T *) (q_ptr); + device const T * kt_ptr = (device const T *) (k_ptr); + device const T * gt_ptr = (device const T *) (g_ptr); + + if (G == 1) { + *ls *= exp(g_ptr[0]); + } else { + // KDA + *ls *= exp(gt_ptr[tx]); + } + + const float s_k = simd_sum(dot(*ls, kt_ptr[tx])); + + const float d = (v_ptr[i20] - s_k)*b_ptr[0]; + + *ls += kt_ptr[tx]*d; + + const float y = simd_sum(dot(*ls, qt_ptr[tx])); + + if (tx == 0) { + *dst_attn = y*scale; + } + + q_ptr += args.ns02; + k_ptr += args.ns12; + v_ptr += args.ns22; + + b_ptr += args.ne21; + g_ptr += args.ne21*G; + + dst_attn += args.ne21*S_v; + } + + device float * dst_state = (device float *) 
(dst) + args.ne23*args.ne22*args.ne21*S_v + (i23*args.ne21 + i21)*S_v*S_v + i20; + device T * dstt_state = (device T *) (dst_state); + + FOR_UNROLL (short j = 0; j < NSG; j++) { + const short is = tx*NSG + j; + dst_state[is*S_v] = lsf[j]; + } + +#undef S_v +#undef G +} + +typedef decltype(kernel_gated_delta_net_impl) kernel_gated_delta_net_t; + +template [[host_name("kernel_gated_delta_net_f32_1")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl; +template [[host_name("kernel_gated_delta_net_f32_2")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl; +template [[host_name("kernel_gated_delta_net_f32_4")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl; +#endif + constant short FC_solve_tri_nsg [[function_constant(FC_SOLVE_TRI + 0)]]; constant short FC_solve_tri_n [[function_constant(FC_SOLVE_TRI + 1)]]; constant short FC_solve_tri_k [[function_constant(FC_SOLVE_TRI + 2)]]; diff --git a/src/llama-context.cpp b/src/llama-context.cpp index ee2669c154..0be9493910 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -151,7 +151,8 @@ llama_context::llama_context( cparams.auto_fa = params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO; cparams.fused_gdn_ar = true; - cparams.fused_gdn_ch = false; // TODO: implement + cparams.fused_gdn_ch = true; + cparams.auto_fgdn = true; // with causal attention, the batch size is limited by the context size cparams.n_batch = cparams.causal_attn ? 
std::min(cparams.n_ctx, params.n_batch) : params.n_batch; @@ -462,37 +463,81 @@ void llama_context::sched_reserve() { cparams.auto_fa = false; } - if (cparams.fused_gdn_ar) { - auto * gf = graph_reserve(1, n_seqs, n_outputs, mctx.get(), true); - if (!gf) { - throw std::runtime_error("failed to reserve graph for fused Gated Delta Net check"); - } + if (cparams.auto_fgdn) { + LLAMA_LOG_INFO("%s: resolving fused Gated Delta Net support:\n", __func__); - const size_t prefix_len = strlen(LLAMA_TENSOR_NAME_FGDNAR) + 1; - bool gdn_device_mismatch = false; - for (int i = 0; i < ggml_graph_n_nodes(gf); i++) { - ggml_tensor * n = ggml_graph_node(gf, i); - if (n->op != GGML_OP_GATED_DELTA_NET) { - continue; + if (cparams.fused_gdn_ar) { + auto * gf = graph_reserve(1, n_seqs, n_outputs, mctx.get(), true); + if (!gf) { + throw std::runtime_error("failed to reserve graph for fused Gated Delta Net check (autoregressive)"); } - ggml_backend_dev_t device_gdn = ggml_backend_get_device(ggml_backend_sched_get_tensor_backend(sched.get(), n)); - GGML_ASSERT(strncmp(n->name, LLAMA_TENSOR_NAME_FGDNAR "-", prefix_len) == 0); - const int il = std::stoi(n->name + prefix_len); - ggml_backend_dev_t device_kv = model.dev_layer(il); - if (device_gdn != device_kv) { - LLAMA_LOG_WARN("%s: layer %d is assigned to device %s but the fused Gated Delta Net tensor " - "is assigned to device %s (usually due to missing support)\n", - __func__, il, ggml_backend_dev_name(device_kv), ggml_backend_dev_name(device_gdn)); - gdn_device_mismatch = true; - break; + const size_t prefix_len = strlen(LLAMA_TENSOR_NAME_FGDN_AR) + 1; + bool gdn_device_mismatch = false; + for (int i = 0; i < ggml_graph_n_nodes(gf); i++) { + ggml_tensor * n = ggml_graph_node(gf, i); + if (n->op != GGML_OP_GATED_DELTA_NET) { + continue; + } + ggml_backend_dev_t device_gdn = ggml_backend_get_device(ggml_backend_sched_get_tensor_backend(sched.get(), n)); + + GGML_ASSERT(strncmp(n->name, LLAMA_TENSOR_NAME_FGDN_AR "-", prefix_len) == 0); + 
const int il = std::stoi(n->name + prefix_len); + ggml_backend_dev_t device_kv = model.dev_layer(il); + if (device_gdn != device_kv) { + LLAMA_LOG_WARN("%s: layer %d is assigned to device %s but the fused Gated Delta Net tensor " + "is assigned to device %s (usually due to missing support)\n", + __func__, il, ggml_backend_dev_name(device_kv), ggml_backend_dev_name(device_gdn)); + gdn_device_mismatch = true; + break; + } + } + + if (gdn_device_mismatch) { + cparams.fused_gdn_ar = false; + LLAMA_LOG_WARN("%s: fused Gated Delta Net (autoregressive) not supported, set to disabled\n", __func__); + } else { + LLAMA_LOG_INFO("%s: fused Gated Delta Net (autoregressive) enabled\n", __func__); } } - if (gdn_device_mismatch) { - cparams.fused_gdn_ar = false; - LLAMA_LOG_WARN("%s: fused Gated Delta Net not supported, set to disabled\n", __func__); + if (cparams.fused_gdn_ch) { + // more than one token in the batch per sequence in order to take the chunked path + auto * gf = graph_reserve(16*n_seqs, n_seqs, n_outputs, mctx.get(), true); + if (!gf) { + throw std::runtime_error("failed to reserve graph for fused Gated Delta Net check (chunked)"); + } + + const size_t prefix_len = strlen(LLAMA_TENSOR_NAME_FGDN_CH) + 1; + bool gdn_device_mismatch = false; + for (int i = 0; i < ggml_graph_n_nodes(gf); i++) { + ggml_tensor * n = ggml_graph_node(gf, i); + if (n->op != GGML_OP_GATED_DELTA_NET) { + continue; + } + ggml_backend_dev_t device_gdn = ggml_backend_get_device(ggml_backend_sched_get_tensor_backend(sched.get(), n)); + + GGML_ASSERT(strncmp(n->name, LLAMA_TENSOR_NAME_FGDN_CH "-", prefix_len) == 0); + const int il = std::stoi(n->name + prefix_len); + ggml_backend_dev_t device_kv = model.dev_layer(il); + if (device_gdn != device_kv) { + LLAMA_LOG_WARN("%s: layer %d is assigned to device %s but the fused Gated Delta Net tensor " + "is assigned to device %s (usually due to missing support)\n", + __func__, il, ggml_backend_dev_name(device_kv), ggml_backend_dev_name(device_gdn)); + 
gdn_device_mismatch = true; + break; + } + } + + if (gdn_device_mismatch) { + cparams.fused_gdn_ch = false; + LLAMA_LOG_WARN("%s: fused Gated Delta Net (chunked) not supported, set to disabled\n", __func__); + } else { + LLAMA_LOG_INFO("%s: fused Gated Delta Net (chunked) enabled\n", __func__); + } } + + cparams.auto_fgdn = false; } // reserve worst-case graph diff --git a/src/llama-cparams.h b/src/llama-cparams.h index 333922468c..9d35947413 100644 --- a/src/llama-cparams.h +++ b/src/llama-cparams.h @@ -33,6 +33,7 @@ struct llama_cparams { bool auto_fa; bool fused_gdn_ar; // use fused gated delta net (autoregressive) bool fused_gdn_ch; // use fused gated delta net (chunked) + bool auto_fgdn; bool no_perf; bool warmup; bool op_offload; diff --git a/src/llama-impl.h b/src/llama-impl.h index ee27ac1bea..e4f35c8e53 100644 --- a/src/llama-impl.h +++ b/src/llama-impl.h @@ -70,6 +70,6 @@ std::string llama_format_tensor_shape(const struct ggml_tensor * t); std::string gguf_kv_to_str(const struct gguf_context * ctx_gguf, int i); -#define LLAMA_TENSOR_NAME_FATTN "__fattn__" -#define LLAMA_TENSOR_NAME_FGDNAR "__fgdnar__" -#define LLAMA_TENSOR_NAME_FGDNCH "__fgdnch__" +#define LLAMA_TENSOR_NAME_FATTN "__fattn__" +#define LLAMA_TENSOR_NAME_FGDN_AR "__fgdn_ar__" +#define LLAMA_TENSOR_NAME_FGDN_CH "__fgdn_ch__" diff --git a/src/models/delta-net-base.cpp b/src/models/delta-net-base.cpp index b0be62fc68..a62dbc15dd 100644 --- a/src/models/delta-net-base.cpp +++ b/src/models/delta-net-base.cpp @@ -41,13 +41,6 @@ std::pair llm_build_delta_net_base::build_delta_ne GGML_ASSERT(b->ne[0] == 1 && b->ne[1] == H_v && b->ne[2] == n_tokens && b->ne[3] == n_seqs); GGML_ASSERT(s->ne[0] == S_v && s->ne[1] == S_v && s->ne[2] == H_v && s->ne[3] == n_seqs); - if (cparams.fused_gdn_ch) { - //ggml_tensor * result = ggml_gated_delta_net(ctx0, q, k, v, g, b, s); - //cb(result, LLAMA_TENSOR_NAME_FGDNCH, il); - - GGML_ABORT("not implemented yet"); - } - const float scale = 1.0f / sqrtf(S_k); q = 
ggml_scale(ctx0, q, scale); @@ -325,26 +318,6 @@ std::pair llm_build_delta_net_base::build_delta_ne GGML_ASSERT(b->ne[0] == 1 && b->ne[1] == H_v && b->ne[2] == n_tokens && b->ne[3] == n_seqs); GGML_ASSERT(s->ne[0] == S_v && s->ne[1] == S_v && s->ne[2] == H_v && s->ne[3] == n_seqs); - if (cparams.fused_gdn_ar) { - ggml_tensor * result = ggml_gated_delta_net(ctx0, q, k, v, g, b, s); - cb(result, LLAMA_TENSOR_NAME_FGDNAR, il); - - ggml_tensor * output = ggml_view_4d(ctx0, result, - S_v, H_v, n_tokens, n_seqs, - ggml_row_size(result->type, S_v), - ggml_row_size(result->type, S_v * H_v), - ggml_row_size(result->type, S_v * H_v * n_tokens), 0); - - ggml_tensor * new_state = ggml_view_4d(ctx0, result, - S_v, S_v, H_v, n_seqs, - ggml_row_size(result->type, S_v), - ggml_row_size(result->type, S_v * S_v), - ggml_row_size(result->type, S_v * S_v * H_v), - ggml_row_size(result->type, S_v * H_v * n_tokens * n_seqs)); - - return {output, new_state}; - } - const float scale = 1.0f / sqrtf(S_k); q = ggml_scale(ctx0, q, scale); @@ -401,3 +374,78 @@ std::pair llm_build_delta_net_base::build_delta_ne return {o, s}; } + +std::pair llm_build_delta_net_base::build_delta_net_fused( + ggml_tensor * q, + ggml_tensor * k, + ggml_tensor * v, + ggml_tensor * g, + ggml_tensor * b, + ggml_tensor * s, + int il) { + const int64_t S_k = q->ne[0]; + const int64_t H_k = q->ne[1]; + const int64_t n_tokens = q->ne[2]; + const int64_t n_seqs = q->ne[3]; + + const int64_t S_v = v->ne[0]; + const int64_t H_v = v->ne[1]; + + GGML_ASSERT(S_k == S_v); + GGML_ASSERT(H_v % H_k == 0); + + GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs); + GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs); + GGML_ASSERT(v->ne[0] == S_v && v->ne[1] == H_v && v->ne[2] == n_tokens && v->ne[3] == n_seqs); + + GGML_ASSERT(g->ne[0] == 1 || g->ne[0] == S_v); + GGML_ASSERT( g->ne[1] == H_v && g->ne[2] == n_tokens && g->ne[3] == n_seqs); + 
GGML_ASSERT(b->ne[0] == 1 && b->ne[1] == H_v && b->ne[2] == n_tokens && b->ne[3] == n_seqs); + GGML_ASSERT(s->ne[0] == S_v && s->ne[1] == S_v && s->ne[2] == H_v && s->ne[3] == n_seqs); + + ggml_tensor * result = ggml_gated_delta_net(ctx0, q, k, v, g, b, s); + if (n_tokens == 1) { + cb(result, LLAMA_TENSOR_NAME_FGDN_AR, il); + } else { + cb(result, LLAMA_TENSOR_NAME_FGDN_CH, il); + } + + ggml_tensor * output = ggml_view_4d(ctx0, result, + S_v, H_v, n_tokens, n_seqs, + ggml_row_size(result->type, S_v), + ggml_row_size(result->type, S_v * H_v), + ggml_row_size(result->type, S_v * H_v * n_tokens), 0); + + ggml_tensor * new_state = ggml_view_4d(ctx0, result, + S_v, S_v, H_v, n_seqs, + ggml_row_size(result->type, S_v), + ggml_row_size(result->type, S_v * S_v), + ggml_row_size(result->type, S_v * S_v * H_v), + ggml_row_size(result->type, S_v * H_v * n_tokens * n_seqs)); + + return {output, new_state}; +} + +std::pair llm_build_delta_net_base::build_delta_net( + ggml_tensor * q, + ggml_tensor * k, + ggml_tensor * v, + ggml_tensor * g, + ggml_tensor * b, + ggml_tensor * s, + int il) { + const int64_t n_seq_tokens = q->ne[2]; + + if (n_seq_tokens == 1) { + if (cparams.fused_gdn_ar) { + return build_delta_net_fused(q, k, v, g, b, s, il); + } + return build_delta_net_autoregressive(q, k, v, g, b, s, il); + } + + if (cparams.fused_gdn_ch) { + return build_delta_net_fused(q, k, v, g, b, s, il); + } + + return build_delta_net_chunking(q, k, v, g, b, s, il); +} diff --git a/src/models/kimi-linear.cpp b/src/models/kimi-linear.cpp index 063b17a2f6..4d62f4e715 100644 --- a/src/models/kimi-linear.cpp +++ b/src/models/kimi-linear.cpp @@ -169,9 +169,7 @@ llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const ll Kcur = ggml_l2_norm(ctx0, Kcur, eps_norm); // Choose between build_delta_net_chunking and build_delta_net_recurrent based on n_tokens - std::pair attn_out = n_seq_tokens == 1 ? 
- build_delta_net_autoregressive(Qcur, Kcur, Vcur, g1, beta, state, il) : - build_delta_net_chunking(Qcur, Kcur, Vcur, g1, beta, state, il); + auto attn_out = build_delta_net(Qcur, Kcur, Vcur, g1, beta, state, il); ggml_tensor * output = ggml_cont(ctx0, attn_out.first); ggml_tensor * new_state = attn_out.second; diff --git a/src/models/models.h b/src/models/models.h index cf9ba04e7f..a86b2b1ebd 100644 --- a/src/models/models.h +++ b/src/models/models.h @@ -44,6 +44,26 @@ struct llm_build_delta_net_base : public llm_graph_context { ggml_tensor * b, ggml_tensor * s, int il); + + // use the ggml_gated_delta_net fused operator + std::pair build_delta_net_fused( + ggml_tensor * q, + ggml_tensor * k, + ggml_tensor * v, + ggml_tensor * g, + ggml_tensor * b, + ggml_tensor * s, + int il); + + // choose between the fused, autoregressive and chunked implementations based on the number of tokens per sequence and the fused GDN settings + std::pair build_delta_net( + ggml_tensor * q, + ggml_tensor * k, + ggml_tensor * v, + ggml_tensor * g, + ggml_tensor * b, + ggml_tensor * s, + int il); }; struct llm_build_rwkv6_base : public llm_graph_context { diff --git a/src/models/qwen35.cpp b/src/models/qwen35.cpp index ba096a5a7b..e12dad7001 100644 --- a/src/models/qwen35.cpp +++ b/src/models/qwen35.cpp @@ -321,9 +321,9 @@ ggml_tensor * llm_build_qwen35::build_layer_attn_linear( //v_conv = ggml_cont_4d(ctx0, v_conv, head_v_dim, num_v_heads, n_seq_tokens, n_seqs); // if head keys and value keys are different, repeat to force tensors into matching shapes - if (num_k_heads != num_v_heads) { + // note: need explicit repeat only if we are not using the fused GDN + if (num_k_heads != num_v_heads && (!cparams.fused_gdn_ar || !cparams.fused_gdn_ch)) { GGML_ASSERT(num_v_heads % num_k_heads == 0); - // TODO: try to avoid these explicit repeats by utilizing op broadcast q_conv = ggml_repeat_4d(ctx0, q_conv, head_k_dim, num_v_heads, n_seq_tokens, n_seqs); k_conv = ggml_repeat_4d(ctx0, k_conv, head_k_dim, num_v_heads, n_seq_tokens, n_seqs); } @@ -332,12 +332,8 @@
ggml_tensor * llm_build_qwen35::build_layer_attn_linear( cb(k_conv, "k_conv_predelta", il); cb(v_conv, "v_conv_predelta", il); - std::pair attn_out; - if (n_seq_tokens == 1) { - attn_out = build_delta_net_autoregressive(q_conv, k_conv, v_conv, gate, beta, state, il); - } else { - attn_out = build_delta_net_chunking(q_conv, k_conv, v_conv, gate, beta, state, il); - } + auto attn_out = build_delta_net(q_conv, k_conv, v_conv, gate, beta, state, il); + ggml_tensor * output = attn_out.first; ggml_tensor * new_state = attn_out.second; cb(output, "attn_output", il); diff --git a/src/models/qwen35moe.cpp b/src/models/qwen35moe.cpp index fe382286e9..8d07c7ed27 100644 --- a/src/models/qwen35moe.cpp +++ b/src/models/qwen35moe.cpp @@ -321,9 +321,9 @@ ggml_tensor * llm_build_qwen35moe ::build_layer_attn_linear( //v_conv = ggml_cont_4d(ctx0, v_conv, head_v_dim, num_v_heads, n_seq_tokens, n_seqs); // if head keys and value keys are different, repeat to force tensors into matching shapes - if (num_k_heads != num_v_heads) { + // note: need explicit repeat only if we are not using the fused GDN + if (num_k_heads != num_v_heads && (!cparams.fused_gdn_ar || !cparams.fused_gdn_ch)) { GGML_ASSERT(num_v_heads % num_k_heads == 0); - // TODO: try to avoid these explicit repeats by utilizing op broadcast q_conv = ggml_repeat_4d(ctx0, q_conv, head_k_dim, num_v_heads, n_seq_tokens, n_seqs); k_conv = ggml_repeat_4d(ctx0, k_conv, head_k_dim, num_v_heads, n_seq_tokens, n_seqs); } @@ -332,12 +332,8 @@ ggml_tensor * llm_build_qwen35moe ::build_layer_attn_linear( cb(k_conv, "k_conv_predelta", il); cb(v_conv, "v_conv_predelta", il); - std::pair attn_out; - if (n_seq_tokens == 1) { - attn_out = build_delta_net_autoregressive(q_conv, k_conv, v_conv, gate, beta, state, il); - } else { - attn_out = build_delta_net_chunking(q_conv, k_conv, v_conv, gate, beta, state, il); - } + auto attn_out = build_delta_net(q_conv, k_conv, v_conv, gate, beta, state, il); + ggml_tensor * output = attn_out.first; 
ggml_tensor * new_state = attn_out.second; cb(output, "attn_output", il); diff --git a/src/models/qwen3next.cpp b/src/models/qwen3next.cpp index 30912fd5e3..cc479dd075 100644 --- a/src/models/qwen3next.cpp +++ b/src/models/qwen3next.cpp @@ -406,6 +406,7 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear( //v_conv = ggml_cont_4d(ctx0, v_conv, head_v_dim, num_v_heads, n_seq_tokens, n_seqs); // if head keys and value keys are different, repeat to force tensors into matching shapes + // TODO: avoid repeats for fused GDN, needs broadcast configuration for GDN op [TAG_GGML_GDN_BCAST] if (num_k_heads != num_v_heads) { GGML_ASSERT(num_v_heads % num_k_heads == 0); int64_t repeat_factor = num_v_heads / num_k_heads; @@ -431,13 +432,8 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear( cb(k_conv, "k_conv_predelta", il); cb(v_conv, "v_conv_predelta", il); - // Choose between build_delta_net_chunking, build_delta_net_recurrent, and build_delta_net_autoregressive based on n_tokens - std::pair attn_out; // pair of (output, new_state) - if (n_seq_tokens == 1) { - attn_out = build_delta_net_autoregressive(q_conv, k_conv, v_conv, gate, beta, state, il); - } else { - attn_out = build_delta_net_chunking(q_conv, k_conv, v_conv, gate, beta, state, il); - } + auto attn_out = build_delta_net(q_conv, k_conv, v_conv, gate, beta, state, il); + ggml_tensor * output = attn_out.first; ggml_tensor * new_state = attn_out.second; cb(output, "attn_output", il); diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 66aaddcfff..58d67d97f8 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -8447,6 +8447,9 @@ static std::vector> make_test_cases_eval() { } test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 32, 128, 1, 1)); + test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 32, 16, 1, 1)); + test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 32, 16, 1, 1, 1, true, true)); + 
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 32, 16, 1, 1, 1, false, true)); test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 16, 64, 1, 2)); test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 4, 1)); test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 4, 2)); @@ -8456,10 +8459,12 @@ static std::vector> make_test_cases_eval() { // KDA (vector gate) test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 1, 1, 1, false, true)); test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 1, 2, 1, false, true)); + test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 16, 1, 2, 1, false, true)); test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 32, 4, 1, 1, false, true)); test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 4, 2, 1, false, true)); test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 8, 32, 4, 2, 2, false, true)); test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 4, 2, 1, true, true)); + test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 16, 4, 2, 1, true, true)); #if 0 // these tests are disabled to save execution time, but they can be handy for debugging