diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 0734e5a1bf..a2584465e3 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -334,7 +334,7 @@ static bool blackwell_mma_available(const int cc) { ggml_cuda_highest_compiled_arch(cc) < GGML_CUDA_CC_RUBIN; } -static constexpr __device__ int ggml_cuda_get_physical_warp_size() { +static constexpr __device__ __host__ int ggml_cuda_get_physical_warp_size() { #if defined(GGML_USE_HIP) && (defined(__GFX9__) || defined(__GFX8__)) return 64; #else diff --git a/ggml/src/ggml-cuda/gated_delta_net.cu b/ggml/src/ggml-cuda/gated_delta_net.cu index 3d3f268a65..7e5ae51420 100644 --- a/ggml/src/ggml-cuda/gated_delta_net.cu +++ b/ggml/src/ggml-cuda/gated_delta_net.cu @@ -41,12 +41,13 @@ __global__ void gated_delta_net_cuda(const float * q, curr_state += state_offset; attn_data += (sequence * n_tokens * H + h_idx) * S_v; - static_assert(S_v % WARP_SIZE == 0, "S_v must be a multiple of WARP_SIZE"); - constexpr int ROWS_PER_LANE = S_v / WARP_SIZE; + constexpr int warp_size = ggml_cuda_get_physical_warp_size() < S_v ? 
ggml_cuda_get_physical_warp_size() : S_v; + static_assert(S_v % warp_size == 0, "S_v must be a multiple of warp_size"); + constexpr int ROWS_PER_LANE = (S_v + warp_size - 1) / warp_size; float s_shard[ROWS_PER_LANE]; #pragma unroll for (int r = 0; r < ROWS_PER_LANE; r++) { - const int i = r * WARP_SIZE + lane; + const int i = r * warp_size + lane; s_shard[r] = curr_state[i * S_v + col]; } @@ -68,10 +69,10 @@ __global__ void gated_delta_net_cuda(const float * q, float kv_shard = 0.0f; #pragma unroll for (int r = 0; r < ROWS_PER_LANE; r++) { - const int i = r * WARP_SIZE + lane; + const int i = r * warp_size + lane; kv_shard += s_shard[r] * k_t[i]; } - float kv_col = warp_reduce_sum(kv_shard); + float kv_col = warp_reduce_sum(kv_shard); // delta[col] = (v[col] - g * kv[col]) * beta float delta_col = (v_t[col] - g_val * kv_col) * beta_val; @@ -81,12 +82,12 @@ __global__ void gated_delta_net_cuda(const float * q, float attn_partial = 0.0f; #pragma unroll for (int r = 0; r < ROWS_PER_LANE; r++) { - const int i = r * WARP_SIZE + lane; + const int i = r * warp_size + lane; s_shard[r] = g_val * s_shard[r] + k_t[i] * delta_col; attn_partial += s_shard[r] * q_t[i]; } - float attn_col = warp_reduce_sum(attn_partial); + float attn_col = warp_reduce_sum(attn_partial); if (lane == 0) { attn_data[col] = attn_col * scale; @@ -96,11 +97,11 @@ __global__ void gated_delta_net_cuda(const float * q, float kv_shard = 0.0f; #pragma unroll for (int r = 0; r < ROWS_PER_LANE; r++) { - const int i = r * WARP_SIZE + lane; + const int i = r * warp_size + lane; kv_shard += expf(g_t[i]) * s_shard[r] * k_t[i]; } - float kv_col = warp_reduce_sum(kv_shard); + float kv_col = warp_reduce_sum(kv_shard); // delta[col] = (v[col] - kv[col]) * beta float delta_col = (v_t[col] - kv_col) * beta_val; @@ -110,12 +111,12 @@ __global__ void gated_delta_net_cuda(const float * q, float attn_partial = 0.0f; #pragma unroll for (int r = 0; r < ROWS_PER_LANE; r++) { - const int i = r * WARP_SIZE + lane; + const int i 
= r * warp_size + lane; s_shard[r] = expf(g_t[i]) * s_shard[r] + k_t[i] * delta_col; attn_partial += s_shard[r] * q_t[i]; } - float attn_col = warp_reduce_sum(attn_partial); + float attn_col = warp_reduce_sum(attn_partial); if (lane == 0) { attn_data[col] = attn_col * scale; @@ -128,7 +129,7 @@ __global__ void gated_delta_net_cuda(const float * q, // Write state back to global memory #pragma unroll for (int r = 0; r < ROWS_PER_LANE; r++) { - const int i = r * WARP_SIZE + lane; + const int i = r * warp_size + lane; state[i * S_v + col] = s_shard[r]; } } @@ -145,14 +146,21 @@ static void launch_gated_delta_net( int64_t neqk1, int64_t rq3, float scale, cudaStream_t stream) { + constexpr uint32_t warp_size = ggml_cuda_get_physical_warp_size(); const int num_warps = 4; dim3 grid_dims(H, n_seqs, (S_v + num_warps - 1) / num_warps); - dim3 block_dims(WARP_SIZE, num_warps, 1); + dim3 block_dims(warp_size <= S_v ? warp_size : S_v, num_warps, 1); const fastdiv_consts_s64 neqk1_magic = init_fastdiv_s64(neqk1); const fastdiv_consts_s64 rq3_magic = init_fastdiv_s64(rq3); switch (S_v) { + case 16: + gated_delta_net_cuda<16, KDA><<<grid_dims, block_dims, 0, stream>>>( + q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, + n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, + sb1, sb2, sb3, neqk1_magic, rq3_magic, scale); + break; case 32: gated_delta_net_cuda<32, KDA><<<grid_dims, block_dims, 0, stream>>>( q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 32a83b001d..40117dccf2 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -8451,6 +8451,9 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() { } test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 32, 128, 1, 1)); + test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 32, 16, 1, 1)); + test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 32, 16, 1, 1, 1, true, true)); + test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 32, 16, 1, 1, 1, false, true)); test_cases.emplace_back(new
test_gated_delta_net(GGML_TYPE_F32, 16, 64, 1, 2)); test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 4, 1)); test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 4, 2)); @@ -8460,10 +8463,12 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() { // KDA (vector gate) test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 1, 1, 1, false, true)); test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 1, 2, 1, false, true)); + test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 16, 1, 2, 1, false, true)); test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 32, 4, 1, 1, false, true)); test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 4, 2, 1, false, true)); test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 8, 32, 4, 2, 2, false, true)); test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 4, 2, 1, true, true)); + test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 16, 4, 2, 1, true, true)); #if 0 // these tests are disabled to save execution time, sbut they can be handy for debugging