Make sharding HIP-compatible

1. Use ggml_cuda_get_physical_warp_size() to determine warp size flexibly
2. Add test with partial warp to test sum reduction on CUDA
This commit is contained in:
Oliver Simons 2026-03-11 15:28:47 +01:00
parent e1cfcf8774
commit 0211798e56
3 changed files with 27 additions and 14 deletions

View File

@ -334,7 +334,7 @@ static bool blackwell_mma_available(const int cc) {
ggml_cuda_highest_compiled_arch(cc) < GGML_CUDA_CC_RUBIN;
}
static constexpr __device__ int ggml_cuda_get_physical_warp_size() {
static constexpr __device__ __host__ int ggml_cuda_get_physical_warp_size() {
#if defined(GGML_USE_HIP) && (defined(__GFX9__) || defined(__GFX8__))
return 64;
#else

View File

@ -41,12 +41,13 @@ __global__ void gated_delta_net_cuda(const float * q,
curr_state += state_offset;
attn_data += (sequence * n_tokens * H + h_idx) * S_v;
static_assert(S_v % WARP_SIZE == 0, "S_v must be a multiple of WARP_SIZE");
constexpr int ROWS_PER_LANE = S_v / WARP_SIZE;
constexpr int warp_size = ggml_cuda_get_physical_warp_size() < S_v ? ggml_cuda_get_physical_warp_size() : S_v;
static_assert(S_v % warp_size == 0, "S_v must be a multiple of warp_size");
constexpr int ROWS_PER_LANE = (S_v + warp_size - 1) / warp_size;
float s_shard[ROWS_PER_LANE];
#pragma unroll
for (int r = 0; r < ROWS_PER_LANE; r++) {
const int i = r * WARP_SIZE + lane;
const int i = r * warp_size + lane;
s_shard[r] = curr_state[i * S_v + col];
}
@ -68,10 +69,10 @@ __global__ void gated_delta_net_cuda(const float * q,
float kv_shard = 0.0f;
#pragma unroll
for (int r = 0; r < ROWS_PER_LANE; r++) {
const int i = r * WARP_SIZE + lane;
const int i = r * warp_size + lane;
kv_shard += s_shard[r] * k_t[i];
}
float kv_col = warp_reduce_sum(kv_shard);
float kv_col = warp_reduce_sum<warp_size>(kv_shard);
// delta[col] = (v[col] - g * kv[col]) * beta
float delta_col = (v_t[col] - g_val * kv_col) * beta_val;
@ -81,12 +82,12 @@ __global__ void gated_delta_net_cuda(const float * q,
float attn_partial = 0.0f;
#pragma unroll
for (int r = 0; r < ROWS_PER_LANE; r++) {
const int i = r * WARP_SIZE + lane;
const int i = r * warp_size + lane;
s_shard[r] = g_val * s_shard[r] + k_t[i] * delta_col;
attn_partial += s_shard[r] * q_t[i];
}
float attn_col = warp_reduce_sum(attn_partial);
float attn_col = warp_reduce_sum<warp_size>(attn_partial);
if (lane == 0) {
attn_data[col] = attn_col * scale;
@ -96,11 +97,11 @@ __global__ void gated_delta_net_cuda(const float * q,
float kv_shard = 0.0f;
#pragma unroll
for (int r = 0; r < ROWS_PER_LANE; r++) {
const int i = r * WARP_SIZE + lane;
const int i = r * warp_size + lane;
kv_shard += expf(g_t[i]) * s_shard[r] * k_t[i];
}
float kv_col = warp_reduce_sum(kv_shard);
float kv_col = warp_reduce_sum<warp_size>(kv_shard);
// delta[col] = (v[col] - kv[col]) * beta
float delta_col = (v_t[col] - kv_col) * beta_val;
@ -110,12 +111,12 @@ __global__ void gated_delta_net_cuda(const float * q,
float attn_partial = 0.0f;
#pragma unroll
for (int r = 0; r < ROWS_PER_LANE; r++) {
const int i = r * WARP_SIZE + lane;
const int i = r * warp_size + lane;
s_shard[r] = expf(g_t[i]) * s_shard[r] + k_t[i] * delta_col;
attn_partial += s_shard[r] * q_t[i];
}
float attn_col = warp_reduce_sum(attn_partial);
float attn_col = warp_reduce_sum<warp_size>(attn_partial);
if (lane == 0) {
attn_data[col] = attn_col * scale;
@ -128,7 +129,7 @@ __global__ void gated_delta_net_cuda(const float * q,
// Write state back to global memory
#pragma unroll
for (int r = 0; r < ROWS_PER_LANE; r++) {
const int i = r * WARP_SIZE + lane;
const int i = r * warp_size + lane;
state[i * S_v + col] = s_shard[r];
}
}
@ -145,14 +146,21 @@ static void launch_gated_delta_net(
int64_t neqk1, int64_t rq3,
float scale, cudaStream_t stream) {
constexpr uint32_t warp_size = ggml_cuda_get_physical_warp_size();
const int num_warps = 4;
dim3 grid_dims(H, n_seqs, (S_v + num_warps - 1) / num_warps);
dim3 block_dims(WARP_SIZE, num_warps, 1);
dim3 block_dims(warp_size <= S_v ? warp_size : S_v, num_warps, 1);
const fastdiv_consts_s64 neqk1_magic = init_fastdiv_s64(neqk1);
const fastdiv_consts_s64 rq3_magic = init_fastdiv_s64(rq3);
switch (S_v) {
case 16:
gated_delta_net_cuda<16, KDA><<<grid_dims, block_dims, 0, stream>>>(
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,
n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3,
sb1, sb2, sb3, neqk1_magic, rq3_magic, scale);
break;
case 32:
gated_delta_net_cuda<32, KDA><<<grid_dims, block_dims, 0, stream>>>(
q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H,

View File

@ -8451,6 +8451,9 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
}
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 32, 128, 1, 1));
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 32, 16, 1, 1));
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 32, 16, 1, 1, 1, true, true));
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 32, 16, 1, 1, 1, false, true));
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 16, 64, 1, 2));
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 4, 1));
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 4, 2));
@ -8460,10 +8463,12 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
// KDA (vector gate)
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 1, 1, 1, false, true));
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 1, 2, 1, false, true));
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 16, 1, 2, 1, false, true));
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 32, 4, 1, 1, false, true));
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 4, 2, 1, false, true));
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 8, 32, 4, 2, 2, false, true));
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 64, 4, 2, 1, true, true));
test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 16, 4, 2, 1, true, true));
#if 0
// these tests are disabled to save execution time, but they can be handy for debugging