diff --git a/ggml/src/ggml-cuda/gated_delta_net.cu b/ggml/src/ggml-cuda/gated_delta_net.cu index d8e8111455..c249bbc86d 100644 --- a/ggml/src/ggml-cuda/gated_delta_net.cu +++ b/ggml/src/ggml-cuda/gated_delta_net.cu @@ -2,28 +2,29 @@ #include "ggml-cuda/common.cuh" template -__global__ void gated_delta_net_cuda(const float * q, - const float * k, - const float * v, - const float * g, - const float * beta, - const float * curr_state, - float * dst, - int64_t H, - int64_t n_tokens, - int64_t n_seqs, - int64_t sq1, - int64_t sq2, - int64_t sq3, - int64_t sv1, - int64_t sv2, - int64_t sv3, - int64_t sb1, - int64_t sb2, - int64_t sb3, - int64_t rq1, - int64_t rq3, - float scale) { +__global__ void __launch_bounds__(S_v, 1) +gated_delta_net_cuda(const float * q, + const float * k, + const float * v, + const float * g, + const float * beta, + const float * curr_state, + float * dst, + const int64_t H, + const int64_t n_tokens, + const int64_t n_seqs, + const int64_t sq1, + const int64_t sq2, + const int64_t sq3, + const int64_t sv1, + const int64_t sv2, + const int64_t sv3, + const int64_t sb1, + const int64_t sb2, + const int64_t sb3, + const int64_t rq1, + const int64_t rq3, + const float scale) { const int64_t h_idx = blockIdx.x; const int64_t sequence = blockIdx.y; const int col = threadIdx.x; // each thread owns one column @@ -40,8 +41,14 @@ __global__ void gated_delta_net_cuda(const float * q, curr_state += state_offset; attn_data += (sequence * n_tokens * H + h_idx) * S_v; - // Load state column into registers + // GCN and CDNA devices spill registers, we use shared mem for them. See https://github.com/ggml-org/llama.cpp/pull/20282#issuecomment-4025770229 + // TODO: check optimal path for RDNA1 and RDNA2 devices. +#if (defined(GGML_USE_HIP) && !defined(RDNA3) && !defined(RDNA4)) || defined(GGML_USE_MUSA) + extern __shared__ float s_shared[]; + float * s = s_shared + col * S_v; +#else float s[S_v]; +#endif #pragma unroll for (int i = 0; i < S_v; i++) { s[i] = curr_state[i * S_v + col]; @@ -114,6 +121,15 @@ __global__ void gated_delta_net_cuda(const float * q, } } +static size_t calculate_smem(const int sv, int cc) +{ + size_t smem = 0; + if ((GGML_CUDA_CC_IS_AMD(cc) && !GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_RDNA4(cc)) || GGML_CUDA_CC_IS_MTHREADS(cc)) { + smem = sv * sv * sizeof(float); + } + return smem; +} + template static void launch_gated_delta_net( const float * q_d, const float * k_d, const float * v_d, @@ -129,25 +145,36 @@ static void launch_gated_delta_net( dim3 grid_dims(H, n_seqs, 1); dim3 block_dims(S_v, 1, 1); + int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; + switch (S_v) { - case 32: - gated_delta_net_cuda<32, KDA><<>>( + case 32: { + constexpr int sv = 32; + size_t smem = calculate_smem(sv, cc); + gated_delta_net_cuda<<>>( q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2, sb3, rq1, rq3, scale); break; - case 64: - gated_delta_net_cuda<64, KDA><<>>( + } + case 64: { + constexpr int sv = 64; + size_t smem = calculate_smem(sv, cc); + gated_delta_net_cuda<<>>( q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2, sb3, rq1, rq3, scale); break; - case 128: - gated_delta_net_cuda<128, KDA><<>>( + } + case 128: { + constexpr int sv = 128; + size_t smem = calculate_smem(sv, cc); + gated_delta_net_cuda<<>>( q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2, sb3, rq1, rq3, scale); break; + } default: GGML_ABORT("fatal error"); break;