ggml-cuda: gdn use shared mem for HIP (#20366)
Suggested-by: Aman Gupta <amangupta052@gmail.com>
This commit is contained in:
parent
9ef7523ee9
commit
5f91b1d5d5
|
|
@ -2,28 +2,29 @@
|
|||
#include "ggml-cuda/common.cuh"
|
||||
|
||||
template <int S_v, bool KDA>
|
||||
__global__ void gated_delta_net_cuda(const float * q,
|
||||
const float * k,
|
||||
const float * v,
|
||||
const float * g,
|
||||
const float * beta,
|
||||
const float * curr_state,
|
||||
float * dst,
|
||||
int64_t H,
|
||||
int64_t n_tokens,
|
||||
int64_t n_seqs,
|
||||
int64_t sq1,
|
||||
int64_t sq2,
|
||||
int64_t sq3,
|
||||
int64_t sv1,
|
||||
int64_t sv2,
|
||||
int64_t sv3,
|
||||
int64_t sb1,
|
||||
int64_t sb2,
|
||||
int64_t sb3,
|
||||
int64_t rq1,
|
||||
int64_t rq3,
|
||||
float scale) {
|
||||
__global__ void __launch_bounds__(S_v, 1)
|
||||
gated_delta_net_cuda(const float * q,
|
||||
const float * k,
|
||||
const float * v,
|
||||
const float * g,
|
||||
const float * beta,
|
||||
const float * curr_state,
|
||||
float * dst,
|
||||
const int64_t H,
|
||||
const int64_t n_tokens,
|
||||
const int64_t n_seqs,
|
||||
const int64_t sq1,
|
||||
const int64_t sq2,
|
||||
const int64_t sq3,
|
||||
const int64_t sv1,
|
||||
const int64_t sv2,
|
||||
const int64_t sv3,
|
||||
const int64_t sb1,
|
||||
const int64_t sb2,
|
||||
const int64_t sb3,
|
||||
const int64_t rq1,
|
||||
const int64_t rq3,
|
||||
const float scale) {
|
||||
const int64_t h_idx = blockIdx.x;
|
||||
const int64_t sequence = blockIdx.y;
|
||||
const int col = threadIdx.x; // each thread owns one column
|
||||
|
|
@ -40,8 +41,14 @@ __global__ void gated_delta_net_cuda(const float * q,
|
|||
curr_state += state_offset;
|
||||
attn_data += (sequence * n_tokens * H + h_idx) * S_v;
|
||||
|
||||
// Load state column into registers
|
||||
// GCN and CDNA devices spill registers, we use shared mem for them. See https://github.com/ggml-org/llama.cpp/pull/20282#issuecomment-4025770229
|
||||
// TODO: check optimal path for RDNA1 and RDNA2 devices.
|
||||
#if (defined(GGML_USE_HIP) && !defined(RDNA3) && !defined(RDNA4)) || defined(GGML_USE_MUSA)
|
||||
extern __shared__ float s_shared[];
|
||||
float * s = s_shared + col * S_v;
|
||||
#else
|
||||
float s[S_v];
|
||||
#endif
|
||||
#pragma unroll
|
||||
for (int i = 0; i < S_v; i++) {
|
||||
s[i] = curr_state[i * S_v + col];
|
||||
|
|
@ -114,6 +121,15 @@ __global__ void gated_delta_net_cuda(const float * q,
|
|||
}
|
||||
}
|
||||
|
||||
static size_t calculate_smem(const int sv, int cc)
|
||||
{
|
||||
size_t smem = 0;
|
||||
if ((GGML_CUDA_CC_IS_AMD(cc) && !GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_RDNA4(cc)) || GGML_CUDA_CC_IS_MTHREADS(cc)) {
|
||||
smem = sv * sv * sizeof(float);
|
||||
}
|
||||
return smem;
|
||||
}
|
||||
|
||||
template <bool KDA>
|
||||
static void launch_gated_delta_net(
|
||||
const float * q_d, const float * k_d, const float * v_d,
|
||||
|
|
@ -129,25 +145,36 @@ static void launch_gated_delta_net(
|
|||
dim3 grid_dims(H, n_seqs, 1);
|
||||
dim3 block_dims(S_v, 1, 1);
|
||||
|
||||
int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
|
||||
|
||||
switch (S_v) {
|
||||
case 32:
|
||||
gated_delta_net_cuda<32, KDA><<<grid_dims, block_dims, 0, stream>>>(
|
||||
case 32: {
|
||||
constexpr int sv = 32;
|
||||
size_t smem = calculate_smem(sv, cc);
|
||||
gated_delta_net_cuda<sv, KDA><<<grid_dims, block_dims, smem, stream>>>(
|
||||
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
|
||||
n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
|
||||
sb1, sb2, sb3, rq1, rq3, scale);
|
||||
break;
|
||||
case 64:
|
||||
gated_delta_net_cuda<64, KDA><<<grid_dims, block_dims, 0, stream>>>(
|
||||
}
|
||||
case 64: {
|
||||
constexpr int sv = 64;
|
||||
size_t smem = calculate_smem(sv, cc);
|
||||
gated_delta_net_cuda<sv, KDA><<<grid_dims, block_dims, smem, stream>>>(
|
||||
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
|
||||
n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
|
||||
sb1, sb2, sb3, rq1, rq3, scale);
|
||||
break;
|
||||
case 128:
|
||||
gated_delta_net_cuda<128, KDA><<<grid_dims, block_dims, 0, stream>>>(
|
||||
}
|
||||
case 128: {
|
||||
constexpr int sv = 128;
|
||||
size_t smem = calculate_smem(sv, cc);
|
||||
gated_delta_net_cuda<sv, KDA><<<grid_dims, block_dims, smem, stream>>>(
|
||||
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
|
||||
n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
|
||||
sb1, sb2, sb3, rq1, rq3, scale);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
GGML_ABORT("fatal error");
|
||||
break;
|
||||
|
|
|
|||
Loading…
Reference in New Issue