Remove fastdiv_s64, as we can treat neqk1 and rq3 as uint32_t
This commit is contained in:
parent
0211798e56
commit
e26d75b083
|
|
@ -821,91 +821,6 @@ __device__ __forceinline__ uint8_t ggml_cuda_float_to_fp4_e2m1(float x, float e)
|
||||||
return static_cast<uint8_t>(best_i | sign_bit);
|
return static_cast<uint8_t>(best_i | sign_bit);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Precomputed constants for dividing signed 64-bit integers by a fixed divisor.
// Produced on the host by init_fastdiv_s64 and consumed by fastdiv_s64 / fastmodulo_s64.
struct fastdiv_consts_s64 {
    // magic multiplier; 0 when the divisor's magnitude is a power of two
    int64_t mp;
    // the original (signed) divisor
    int64_t d;
    // bits 0..5 (mask 0x3F): shift amount; bit 6 (mask 0x40): set when d is negative
    uint8_t L;
};
|
|
||||||
|
|
||||||
// Returns floor(log2(x)), i.e. the index (0..63) of the most significant set bit of x.
// x == 0 has no set bit; 0 is returned in that case so the result is always defined
// (both __builtin_clzll(0) and _BitScanReverse64 with a zero mask leave the result unspecified).
static inline uint8_t floor_log2(uint64_t x) {
    if (x == 0) {
        return 0;
    }
#if defined(__GNUC__) || defined(__clang__)
    return (uint8_t) (63 - __builtin_clzll(x));
#elif defined(_MSC_VER)
    // MSVC: _BitScanReverse64 stores the MSB index (0 to 63) through an `unsigned long *`,
    // so a dedicated local of that type is required (a uint64_t * does not convert).
    unsigned long msb_index = 0;
    _BitScanReverse64(&msb_index, x);
    return (uint8_t) msb_index;
#else
    // Portable fallback: count how many times x can be halved before it reaches zero.
    uint8_t msb_index = 0;
    while (x >>= 1) {
        ++msb_index;
    }
    return msb_index;
#endif
}
|
|
||||||
|
|
||||||
// Helper: Safely computes floor(2^{64 + l} * (1 + fraction) / d)
|
|
||||||
// Helper: computes floor(2^{64 + l} / d) + extra, where extra approximates
// floor(2^{64 + l - prec} / d) from the quotient/remainder of 2^64 / d.
//
// Preconditions (not checked here):
//   - d != 0
//   - 0 <= prec - l < 64 (the value is used as a shift count)
//   - floor(2^{64 + l} / d) fits in 64 bits; NOTE(review): this appears to require
//     d > 2^l, which the visible caller satisfies by handling powers of two separately.
static uint64_t compute_m(uint64_t d, int l, int prec) {
    // 1. Calculate q and r such that: 2^64 = (q * d) + r, with 0 <= r < d.
    //    Since 2^64 overflows uint64_t, divide 2^64 - 1 instead and adjust by one.
    uint64_t q = 0xFFFFFFFFFFFFFFFF / d;
    uint64_t r = (0xFFFFFFFFFFFFFFFF % d) + 1;
    if (r >= d) {
        r -= d;
        q++;
    }

    // 2. We need (2^l * 2^64) / d => (2^l * q) + floor((2^l * r) / d).
    //    (r << l) can exceed 64 bits for large divisors, so the second term is
    //    produced one quotient bit at a time with shift-subtract long division.
    uint64_t base_r = 0;
    uint64_t rem    = r;
    for (int i = 0; i < l; ++i) {
        // bit shifted out of rem: the doubled remainder is then >= 2^64 > d
        const bool carry = (rem >> 63) != 0;
        rem    <<= 1;
        base_r <<= 1;
        if (carry || rem >= d) {
            rem -= d;  // modular subtraction yields the exact remainder even after a carry
            base_r |= 1;
        }
    }
    uint64_t m = (q << l) + base_r;

    // 3. Add the precision term: (2^{64 + l - prec}) / d => (2^64 / d) >> (prec - l)
    const int shift = prec - l;
    m += (q >> shift) + ((r >> shift) / d);

    return m;
}
|
|
||||||
|
|
||||||
// Host-side setup for fastdiv_s64 / fastmodulo_s64: precomputes the constants for divisor d.
//
// Asserts d != 0. Returns { mp, d, L } where
//   - the low 6 bits of L hold floor_log2(|d|),
//   - bit 6 of L (0x40) is set when d is negative,
//   - mp is 0 when |d| is a power of two (so the kernel needs no separate branch for that case),
//     otherwise the value of compute_m(|d|, floor_log2(|d|), 63) reinterpreted as int64_t.
static const fastdiv_consts_s64 init_fastdiv_s64(int64_t d) {
    GGML_ASSERT(d != 0);

    // Take the magnitude in unsigned arithmetic: -d would overflow for INT64_MIN.
    const uint64_t abs_d = d < 0 ? uint64_t{ 0 } - (uint64_t) d : (uint64_t) d;

    uint8_t L = floor_log2(abs_d);

    const bool    is_pow2 = (uint64_t{ 1 } << L) == abs_d;
    // compute_m must see the plain shift amount, so it is called before the sign flag is merged into L
    const int64_t mp      = is_pow2 ? 0 : (int64_t) compute_m(abs_d, L, 63);

    // signal negative divisor in L's 7th bit
    if (d < 0) {
        L |= 0x40;
    }

    return { mp, d, L };
}
|
|
||||||
|
|
||||||
// Device-side division of n by the divisor encoded in c (see init_fastdiv_s64).
// Multiplies by the precomputed magic number, applies a bias to negative intermediate
// values before the right shift, and negates the result when the divisor is negative.
static __device__ int64_t fastdiv_s64(int64_t n, fastdiv_consts_s64 c) {
    // low 6 bits of L: shift amount; anything above them: divisor-is-negative flag
    const uint8_t shift_bits = c.L & 0x3F;
    const bool    negate     = (c.L >> 6) != 0;

    // high 64 bits of n * mp, plus n
    int64_t quot = __mul64hi(n, c.mp);
    quot += n;

    // all-ones when quot is negative, zero otherwise
    const uint64_t neg_mask = quot >> 63;
    // 2^shift, reduced by one when mp == 0 (power-of-two divisor)
    const uint64_t bias     = (1ULL << shift_bits) - (c.mp == 0);

    quot += neg_mask & bias;
    quot >>= shift_bits;

    if (negate) {
        return -quot;
    }
    return quot;
}
|
|
||||||
|
|
||||||
// Device-side remainder of n with respect to the divisor encoded in c:
// n minus (fastdiv_s64 quotient times the stored divisor c.d).
__device__ __forceinline__ int64_t fastmodulo_s64(int64_t n, fastdiv_consts_s64 c) {
    return n - (fastdiv_s64(n, c) * c.d);
}
|
|
||||||
|
|
||||||
// See https://gmplib.org/~tege/divcnst-pldi94.pdf figure 4.1.
|
// See https://gmplib.org/~tege/divcnst-pldi94.pdf figure 4.1.
|
||||||
// Precompute mp (m' in the paper) and L such that division
|
// Precompute mp (m' in the paper) and L such that division
|
||||||
// can be computed using a multiply (high 32b of 64b result)
|
// can be computed using a multiply (high 32b of 64b result)
|
||||||
|
|
|
||||||
|
|
@ -1,36 +1,36 @@
|
||||||
#include "gated_delta_net.cuh"
|
#include "gated_delta_net.cuh"
|
||||||
|
|
||||||
template <int S_v, bool KDA>
|
template <int S_v, bool KDA>
|
||||||
__global__ void gated_delta_net_cuda(const float * q,
|
__global__ void gated_delta_net_cuda(const float * q,
|
||||||
const float * k,
|
const float * k,
|
||||||
const float * v,
|
const float * v,
|
||||||
const float * g,
|
const float * g,
|
||||||
const float * beta,
|
const float * beta,
|
||||||
const float * curr_state,
|
const float * curr_state,
|
||||||
float * dst,
|
float * dst,
|
||||||
int64_t H,
|
int64_t H,
|
||||||
int64_t n_tokens,
|
int64_t n_tokens,
|
||||||
int64_t n_seqs,
|
int64_t n_seqs,
|
||||||
int64_t sq1,
|
int64_t sq1,
|
||||||
int64_t sq2,
|
int64_t sq2,
|
||||||
int64_t sq3,
|
int64_t sq3,
|
||||||
int64_t sv1,
|
int64_t sv1,
|
||||||
int64_t sv2,
|
int64_t sv2,
|
||||||
int64_t sv3,
|
int64_t sv3,
|
||||||
int64_t sb1,
|
int64_t sb1,
|
||||||
int64_t sb2,
|
int64_t sb2,
|
||||||
int64_t sb3,
|
int64_t sb3,
|
||||||
const fastdiv_consts_s64 neqk1_magic,
|
const uint3 neqk1_magic,
|
||||||
const fastdiv_consts_s64 rq3_magic,
|
const uint3 rq3_magic,
|
||||||
float scale) {
|
float scale) {
|
||||||
const int64_t h_idx = blockIdx.x;
|
const uint32_t h_idx = blockIdx.x;
|
||||||
const int64_t sequence = blockIdx.y;
|
const uint32_t sequence = blockIdx.y;
|
||||||
// each warp owns one column, using warp-level primitives to reduce across rows
|
// each warp owns one column, using warp-level primitives to reduce across rows
|
||||||
const int lane = threadIdx.x;
|
const int lane = threadIdx.x;
|
||||||
const int col = blockIdx.z * blockDim.y + threadIdx.y;
|
const int col = blockIdx.z * blockDim.y + threadIdx.y;
|
||||||
|
|
||||||
const int64_t iq1 = fastmodulo_s64(h_idx, neqk1_magic);
|
const uint32_t iq1 = fastmodulo(h_idx, neqk1_magic);
|
||||||
const int64_t iq3 = fastdiv_s64(sequence, rq3_magic);
|
const uint32_t iq3 = fastdiv(sequence, rq3_magic);
|
||||||
|
|
||||||
const int64_t attn_score_elems = S_v * H * n_tokens * n_seqs;
|
const int64_t attn_score_elems = S_v * H * n_tokens * n_seqs;
|
||||||
float * attn_data = dst;
|
float * attn_data = dst;
|
||||||
|
|
@ -151,8 +151,8 @@ static void launch_gated_delta_net(
|
||||||
dim3 grid_dims(H, n_seqs, (S_v + num_warps - 1) / num_warps);
|
dim3 grid_dims(H, n_seqs, (S_v + num_warps - 1) / num_warps);
|
||||||
dim3 block_dims(warp_size <= S_v ? warp_size : S_v, num_warps, 1);
|
dim3 block_dims(warp_size <= S_v ? warp_size : S_v, num_warps, 1);
|
||||||
|
|
||||||
const fastdiv_consts_s64 neqk1_magic = init_fastdiv_s64(neqk1);
|
const uint3 neqk1_magic = init_fastdiv_values(neqk1);
|
||||||
const fastdiv_consts_s64 rq3_magic = init_fastdiv_s64(rq3);
|
const uint3 rq3_magic = init_fastdiv_values(rq3);
|
||||||
|
|
||||||
switch (S_v) {
|
switch (S_v) {
|
||||||
case 16:
|
case 16:
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue