From 55d2e3a361f877527f80c74bdb4af2bab5e9292a Mon Sep 17 00:00:00 2001
From: Oliver Simons
Date: Tue, 10 Mar 2026 13:25:00 +0100
Subject: [PATCH] Add FastDiv to gated_delta_net_cuda

---
 ggml/src/ggml-cuda/common.cuh         | 85 +++++++++++++++++++++++++++
 ggml/src/ggml-cuda/gated_delta_net.cu | 57 +++++++++---------
 2 files changed, 115 insertions(+), 27 deletions(-)

diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh
index 36d8a3aaab..0734e5a1bf 100644
--- a/ggml/src/ggml-cuda/common.cuh
+++ b/ggml/src/ggml-cuda/common.cuh
@@ -821,6 +821,91 @@ __device__ __forceinline__ uint8_t ggml_cuda_float_to_fp4_e2m1(float x, float e)
     return static_cast<uint8_t>(best_i | sign_bit);
 }
 
+struct fastdiv_consts_s64 {
+    int64_t mp; // magic number (0 signals |d| is a power of two)
+    int64_t d;  // divisor
+    uint8_t L;  // need at most 6 bits to represent L, use 7th bit to signal sign of d
+};
+
+static inline uint8_t floor_log2(uint64_t x) {
+#if defined(__GNUC__) || defined(__clang__)
+    return (uint8_t) (63 - __builtin_clzll(x));
+#elif defined(_MSC_VER)
+    // MSVC: _BitScanReverse64 writes the MSB index (0..63) into an unsigned long
+    unsigned long msb;
+    _BitScanReverse64(&msb, x);
+    return (uint8_t) msb;
+#endif
+}
+
+// Helper: Safely computes floor(2^{64 + l} * (1 + fraction) / d)
+static uint64_t compute_m(uint64_t d, int l, int prec) {
+    // 1. Calculate q and r such that: 2^64 = (q * d) + r
+    //    Since 2^64 overflows uint64_t, we use 0xFF...FF and adjust.
+    uint64_t q = (0xFFFFFFFFFFFFFFFF / d);
+    uint64_t r = (0xFFFFFFFFFFFFFFFF % d) + 1;
+    if (r >= d) {
+        r -= d;
+        q++;
+    }
+
+    // 2. We need (2^l * 2^64) / d => (2^l * q) + (2^l * r) / d
+    uint64_t base_q = (q << l);
+    uint64_t base_r = (r << l) / d;
+    uint64_t m      = base_q + base_r;
+
+    // 3. For m_high, we need to add the precision term: (2^{64 + l - prec}) / d
+    //    (2^{64 + l - prec}) / d => (2^{64} / d) >> (prec - l)
+    //    This is safe because (prec - l) >= 0 in this algorithm
+    uint64_t extra = (q >> (prec - l)) + ((r >> (prec - l)) / d);
+    m += extra;
+
+    return m;
+}
+
+static const fastdiv_consts_s64 init_fastdiv_s64(int64_t d) {
+    GGML_ASSERT(d != 0);
+    uint64_t abs_d = d < 0 ? -(uint64_t) d : (uint64_t) d; // cast before negating: no UB for INT64_MIN
+
+    uint8_t L = floor_log2(abs_d);
+
+    if (uint64_t{ 1 } << L == abs_d) {
+        // signal negative divisor in L's 7th bit
+        d < 0 ? L |= 0x40 : 0;
+        // multiply with 0 to avoid branching in fastdiv_s64 kernel when d is power of 2
+        return { 0, d, L };
+    }
+
+    uint64_t mh = compute_m(abs_d, L, 63);
+    // signal negative divisor in L's 7th bit
+    d < 0 ? L |= 0x40 : 0;
+    return { (int64_t) mh, d, L };
+}
+
+static __device__ int64_t fastdiv_s64(int64_t n, fastdiv_consts_s64 c) {
+    int64_t q;
+    q = __mul64hi(n, c.mp);
+    q += n;
+
+    // Extract the sign bit
+    uint64_t q_sign = q >> 63;
+    // L's lower 6 bits are the shift amount
+    uint8_t shift = c.L & 0x3F;
+    q += q_sign & ((1ULL << shift) - (c.mp == 0));
+    q >>= shift;
+
+    // if divisor is negative, negate the quotient
+    int64_t d_sign = c.L >> 6;
+    d_sign ? q = -q : 0;
+
+    return q;
+}
+
+__device__ __forceinline__ int64_t fastmodulo_s64(int64_t n, fastdiv_consts_s64 c) {
+    int64_t q = fastdiv_s64(n, c);
+    return n - (q * c.d);
+}
+
 // See https://gmplib.org/~tege/divcnst-pldi94.pdf figure 4.1.
// Precompute mp (m' in the paper) and L such that division // can be computed using a multiply (high 32b of 64b result) diff --git a/ggml/src/ggml-cuda/gated_delta_net.cu b/ggml/src/ggml-cuda/gated_delta_net.cu index 086d38648e..801d828cea 100644 --- a/ggml/src/ggml-cuda/gated_delta_net.cu +++ b/ggml/src/ggml-cuda/gated_delta_net.cu @@ -2,34 +2,34 @@ #include "ggml-cuda/common.cuh" template -__global__ void gated_delta_net_cuda(const float * q, - const float * k, - const float * v, - const float * g, - const float * beta, - const float * curr_state, - float * dst, - int64_t H, - int64_t n_tokens, - int64_t n_seqs, - int64_t sq1, - int64_t sq2, - int64_t sq3, - int64_t sv1, - int64_t sv2, - int64_t sv3, - int64_t sb1, - int64_t sb2, - int64_t sb3, - int64_t neqk1, - int64_t rq3, - float scale) { +__global__ void gated_delta_net_cuda(const float * q, + const float * k, + const float * v, + const float * g, + const float * beta, + const float * curr_state, + float * dst, + int64_t H, + int64_t n_tokens, + int64_t n_seqs, + int64_t sq1, + int64_t sq2, + int64_t sq3, + int64_t sv1, + int64_t sv2, + int64_t sv3, + int64_t sb1, + int64_t sb2, + int64_t sb3, + const fastdiv_consts_s64 neqk1_magic, + const fastdiv_consts_s64 rq3_magic, + float scale) { const int64_t h_idx = blockIdx.x; const int64_t sequence = blockIdx.y; const int col = threadIdx.x; // each thread owns one column - const int64_t iq1 = h_idx % neqk1; - const int64_t iq3 = sequence / rq3; + const int64_t iq1 = fastmodulo_s64(h_idx, neqk1_magic); + const int64_t iq3 = fastdiv_s64(sequence, rq3_magic); const int64_t attn_score_elems = S_v * H * n_tokens * n_seqs; float * attn_data = dst; @@ -129,24 +129,27 @@ static void launch_gated_delta_net( dim3 grid_dims(H, n_seqs, 1); dim3 block_dims(S_v, 1, 1); + const fastdiv_consts_s64 neqk1_magic = init_fastdiv_s64(neqk1); + const fastdiv_consts_s64 rq3_magic = init_fastdiv_s64(rq3); + switch (S_v) { case 32: gated_delta_net_cuda<32, KDA><<>>( q_d, k_d, v_d, g_d, 
b_d, s_d, dst_d, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, - sb1, sb2, sb3, neqk1, rq3, scale); + sb1, sb2, sb3, neqk1_magic, rq3_magic, scale); break; case 64: gated_delta_net_cuda<64, KDA><<>>( q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, - sb1, sb2, sb3, neqk1, rq3, scale); + sb1, sb2, sb3, neqk1_magic, rq3_magic, scale); break; case 128: gated_delta_net_cuda<128, KDA><<>>( q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, - sb1, sb2, sb3, neqk1, rq3, scale); + sb1, sb2, sb3, neqk1_magic, rq3_magic, scale); break; default: GGML_ABORT("fatal error");