diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh
index 85c587928e..3cfd5a4860 100644
--- a/ggml/src/ggml-cuda/common.cuh
+++ b/ggml/src/ggml-cuda/common.cuh
@@ -726,6 +726,8 @@ __device__ __forceinline__ uint8_t ggml_cuda_float_to_fp4_e2m1(float x, float e)
     int best_i = 0;
     float best_err = fabsf(ax - pos_lut[0]);
+
+#pragma unroll
     for (int i = 1; i < 8; ++i) {
         float err = fabsf(ax - pos_lut[i]);
         if (err < best_err) {
diff --git a/ggml/src/ggml-cuda/quantize.cu b/ggml/src/ggml-cuda/quantize.cu
index 0fe034fa5e..39d8b6592d 100644
--- a/ggml/src/ggml-cuda/quantize.cu
+++ b/ggml/src/ggml-cuda/quantize.cu
@@ -47,6 +47,18 @@ static __global__ void quantize_q8_1(
     y[ib].ds = make_half2(d, sum);
 }
 
+// Helper to compute E8M0 scale from amax using fast math
+__device__ __forceinline__ uint8_t compute_e8m0_scale(float amax) {
+    if (amax == 0.0f) {
+        return 127; // Special case: use scale of 1.0 for zero input
+    }
+    // log2(amax / 6.0) = log2(amax) - log2(6) ≈ log2(amax) - 2.585
+    // Use __log2f for fast approximate log2
+    const float log2_amax = __log2f(amax) - 2.5849625007211563f; // log2(6)
+    const int e_int = __float2int_rd(log2_amax) + 127; // floor + bias
+    return static_cast<uint8_t>(max(1, min(254, e_int)));
+}
+
 static __global__ void quantize_mmq_mxfp4(const float * __restrict__ x,
                                           const int32_t * __restrict__ ids,
                                           void * __restrict__ vy,
@@ -60,10 +72,15 @@ static __global__ void quantize_mmq_mxfp4(const float * __restrict__ x,
     constexpr int vals_per_scale = 32;
     constexpr int vals_per_warp  = 2 * vals_per_scale; // Each warp processes 2 blocks of 32
 
-    // Each warp processes 2 adjacent blocks of 32 values (64 values total)
-    const int64_t warp_start_offset = blockIdx.y * vals_per_warp;
-    const int64_t i0_block0 = warp_start_offset + threadIdx.x;                  // First block: 0-31
-    const int64_t i0_block1 = warp_start_offset + vals_per_scale + threadIdx.x; // Second block: 32-63
+    // Multiple warps per block - each warp handles different data
+    const int warp_id    = threadIdx.y;
+    const int lane_id_32 = threadIdx.x;
+
+    const int nwarps = blockDim.y;
+
+    const int64_t warp_start_offset = (blockIdx.y * nwarps + warp_id) * vals_per_warp;
+    const int64_t i0_block0 = warp_start_offset + lane_id_32;
+    const int64_t i0_block1 = warp_start_offset + vals_per_scale + lane_id_32;
 
     if (i0_block0 >= ne0) {
         return;
@@ -80,117 +97,70 @@ static __global__ void quantize_mmq_mxfp4(const float * __restrict__ x,
     block_fp4_mmq * y = (block_fp4_mmq *) vy;
 
     const int64_t block_fp4_mmq_size = 4 * QK_MXFP4; // 128 values
 
-    const int64_t ib0 =
-        blockIdx.z * ((int64_t) gridDim.x * gridDim.y * vals_per_warp / block_fp4_mmq_size); // first block of channel
-    const int64_t ib = ib0 + (warp_start_offset / block_fp4_mmq_size) * ne1 + blockIdx.x;    // block index in channel
-    const int64_t pair_idx_in_block =
-        (warp_start_offset % block_fp4_mmq_size) / vals_per_warp; // 0-1: which pair of blocks within block_fp4_mmq
-
-    uint8_t e_packed[2];
-
-    // Process first block (0-31)
-    {
-        const int64_t global_src_pos = i03 * s03 + i02 * s02 + i01 * s01 + i0_block0;
-        const float xi = i0_block0 < ne00 ? x[global_src_pos] : 0.0f;
-
-        float amax = fabsf(xi);
-
-        // Reduce max across all 32 threads in the warp
-#pragma unroll
-        for (int mask = 16; mask > 0; mask >>= 1) {
-            amax = fmaxf(amax, __shfl_xor_sync(0xFFFFFFFF, amax, mask, WARP_SIZE));
-        }
-
-        uint8_t e = amax > 0.0f ? (uint8_t) (floorf(log2f(amax / 6.0f)) + 127) : 0;
-
-        float val = ggml_cuda_e8m0_to_fp32(e);
-        float inv_s = (amax == 0.0f) ? 0.0f : 1.0f / val;
-
-        // Quantize: each thread processes 1 value
-        uint8_t q_val = ggml_cuda_float_to_fp4_e2m1(xi, inv_s);
-
-        if (e == 0) {
-            e = 127;
-        }
-
-        // Pack 4 values into char2: threads 0,1,2,3 -> first char2, etc.
-        const int lane_id = threadIdx.x % 4;
-        const int group_id = threadIdx.x / 4;
-
-        // Use shuffle to gather values from 4 consecutive threads
-        uint8_t q0 = __shfl_sync(0xFFFFFFFF, q_val, (group_id * 4) + 0, WARP_SIZE);
-        uint8_t q1 = __shfl_sync(0xFFFFFFFF, q_val, (group_id * 4) + 1, WARP_SIZE);
-        uint8_t q2 = __shfl_sync(0xFFFFFFFF, q_val, (group_id * 4) + 2, WARP_SIZE);
-        uint8_t q3 = __shfl_sync(0xFFFFFFFF, q_val, (group_id * 4) + 3, WARP_SIZE);
-
-        char2 q;
-        if (lane_id == 0) {
-            q.x = (q1 << 4) | q0;
-            q.y = (q3 << 4) | q2;
-
-            // Write to output: first block in pair uses positions based on pair_idx_in_block
-            // Each pair has 2 blocks of 32 = 64 values = 16 char2 elements
-            char2 * yqs2 = (char2 *) y[ib].qs;
-            yqs2[pair_idx_in_block * 16 + group_id] = q;
-        }
-
-        if (threadIdx.x == 0) {
-            e_packed[0] = e;
-        }
-    }
-
-    // Process second block (32-63)
-    {
-        const int64_t global_src_pos = i03 * s03 + i02 * s02 + i01 * s01 + i0_block1;
-        const float xi = i0_block1 < ne00 ? x[global_src_pos] : 0.0f;
-
-        float amax = fabsf(xi);
-
-#pragma unroll
-        for (int mask = 16; mask > 0; mask >>= 1) {
-            amax = fmaxf(amax, __shfl_xor_sync(0xFFFFFFFF, amax, mask, WARP_SIZE));
-        }
-
-        uint8_t e = amax > 0.0f ? (uint8_t) (floorf(log2f(amax / 6.0f)) + 127) : 0;
-
-        float val = ggml_cuda_e8m0_to_fp32(e);
-        float inv_s = (amax == 0.0f) ? 0.0f : 1.0f / val;
-
-        if (e == 0) {
-            e = 127;
-        }
-
-        uint8_t q_val = ggml_cuda_float_to_fp4_e2m1(xi, inv_s);
-
-        const int lane_id = threadIdx.x % 4;
-        const int group_id = threadIdx.x / 4;
-
-        // Use shuffle to gather values from 4 consecutive threads
-        uint8_t q0 = __shfl_sync(0xFFFFFFFF, q_val, (group_id * 4) + 0, WARP_SIZE);
-        uint8_t q1 = __shfl_sync(0xFFFFFFFF, q_val, (group_id * 4) + 1, WARP_SIZE);
-        uint8_t q2 = __shfl_sync(0xFFFFFFFF, q_val, (group_id * 4) + 2, WARP_SIZE);
-        uint8_t q3 = __shfl_sync(0xFFFFFFFF, q_val, (group_id * 4) + 3, WARP_SIZE);
-
-        char2 q;
-        if (lane_id == 0) {
-            q.x = (q1 << 4) | q0;
-            q.y = (q3 << 4) | q2;
-
-            // Write to output: second block in pair uses positions 8-15 within the pair
-            char2 * yqs2 = (char2 *) y[ib].qs;
-            yqs2[pair_idx_in_block * 16 + 8 + group_id] = q;
-        }
-
-        if (threadIdx.x == 0) {
-            e_packed[1] = e;
-        }
-    }
-
-    // Write packed exponents: d4[0-1] each stores 2 scales (for 2 blocks of 32)
-    // pair_idx_in_block tells us which d4 entry to use (0-1)
-    if (threadIdx.x == 0) {
-        y[ib].d4[pair_idx_in_block] = (e_packed[1] << 8) | e_packed[0];
+    const int64_t ib0 = blockIdx.z * ((int64_t) gridDim.x * gridDim.y * nwarps * vals_per_warp / block_fp4_mmq_size);
+    const int64_t ib  = ib0 + (warp_start_offset / block_fp4_mmq_size) * ne1 + blockIdx.x;
+    const int64_t pair_idx_in_block = (warp_start_offset % block_fp4_mmq_size) / vals_per_warp;
+
+    // Precompute common values
+    const int lane_id    = lane_id_32 % 4;
+    const int group_id   = lane_id_32 / 4;
+    const int group_base = group_id * 4;
+
+    char2 * yqs2 = (char2 *) y[ib].qs;
+
+    const int64_t base_pos = i03 * s03 + i02 * s02 + i01 * s01;
+    const float xi0 = (i0_block0 < ne00) ? x[base_pos + i0_block0] : 0.0f;
+    const float xi1 = (i0_block1 < ne00) ? x[base_pos + i0_block1] : 0.0f;
+
+    // === Process first block (0-31) ===
+    float amax0 = fabsf(xi0);
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        amax0 = fmaxf(amax0, __shfl_xor_sync(0xFFFFFFFF, amax0, mask, WARP_SIZE));
+    }
+
+    const uint8_t e0     = compute_e8m0_scale(amax0);
+    const float   inv_s0 = (amax0 == 0.0f) ? 0.0f : __frcp_rn(ggml_cuda_e8m0_to_fp32(e0));
+    const uint8_t q_val0 = ggml_cuda_float_to_fp4_e2m1(xi0, inv_s0);
+
+    // Gather 4 values from consecutive threads using shuffle
+    const uint8_t q0_0 = __shfl_sync(0xFFFFFFFF, q_val0, group_base + 0, WARP_SIZE);
+    const uint8_t q0_1 = __shfl_sync(0xFFFFFFFF, q_val0, group_base + 1, WARP_SIZE);
+    const uint8_t q0_2 = __shfl_sync(0xFFFFFFFF, q_val0, group_base + 2, WARP_SIZE);
+    const uint8_t q0_3 = __shfl_sync(0xFFFFFFFF, q_val0, group_base + 3, WARP_SIZE);
+
+    if (lane_id == 0) {
+        char2 q;
+        q.x = (q0_1 << 4) | q0_0;
+        q.y = (q0_3 << 4) | q0_2;
+        yqs2[pair_idx_in_block * 16 + group_id] = q;
+    }
+
+    // === Process second block (32-63) ===
+    float amax1 = fabsf(xi1);
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        amax1 = fmaxf(amax1, __shfl_xor_sync(0xFFFFFFFF, amax1, mask, WARP_SIZE));
+    }
+
+    const uint8_t e1     = compute_e8m0_scale(amax1);
+    const float   inv_s1 = (amax1 == 0.0f) ? 0.0f : __frcp_rn(ggml_cuda_e8m0_to_fp32(e1));
+    const uint8_t q_val1 = ggml_cuda_float_to_fp4_e2m1(xi1, inv_s1);
+
+    const uint8_t q1_0 = __shfl_sync(0xFFFFFFFF, q_val1, group_base + 0, WARP_SIZE);
+    const uint8_t q1_1 = __shfl_sync(0xFFFFFFFF, q_val1, group_base + 1, WARP_SIZE);
+    const uint8_t q1_2 = __shfl_sync(0xFFFFFFFF, q_val1, group_base + 2, WARP_SIZE);
+    const uint8_t q1_3 = __shfl_sync(0xFFFFFFFF, q_val1, group_base + 3, WARP_SIZE);
+
+    if (lane_id == 0) {
+        char2 q;
+        q.x = (q1_1 << 4) | q1_0;
+        q.y = (q1_3 << 4) | q1_2;
+        yqs2[pair_idx_in_block * 16 + 8 + group_id] = q;
+    }
+
+    // Write packed exponents
+    if (lane_id_32 == 0) {
+        y[ib].d4[pair_idx_in_block] = (e1 << 8) | e0;
     }
 }
 
@@ -353,10 +323,13 @@ void quantize_mmq_mxfp4_cuda(const float * x,
                              cudaStream_t stream) {
     GGML_ASSERT(ne0 % (2 * QK_MXFP4) == 0); // Each warp processes 64 values
 
-    // ne1 tends to assume the highest values, therefore use it as the "x" dimension of the CUDA grid:
-    constexpr int vals_per_warp = 2 * QK_MXFP4; // 64
-    const int64_t block_num_y = (ne0 + vals_per_warp - 1) / vals_per_warp;
+    constexpr int nwarps         = 8;
+    constexpr int vals_per_warp  = 2 * QK_MXFP4;           // 64 values per warp
+    constexpr int vals_per_block = nwarps * vals_per_warp; // 512 values per block
+
+    const int64_t block_num_y = (ne0 + vals_per_block - 1) / vals_per_block;
     const dim3 num_blocks(ne1, block_num_y, ne2 * ne3);
-    const dim3 block_size(32, 1, 1); // Warp size
+    const dim3 block_size(WARP_SIZE, nwarps, 1); // 32 threads x 8 warps = 256 threads per block
+
     quantize_mmq_mxfp4<<<num_blocks, block_size, 0, stream>>>(x, ids, vy, ne00, s01, s02, s03, ne0, ne1, ne2);
 }
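
Note (not part of the patch): below is a minimal host-side sketch of the E8M0 scale mapping that the new compute_e8m0_scale() helper evaluates on the device. It uses the exact floorf(log2f(amax / 6.0f)) + 127 form from the removed code plus the new clamp to [1, 254], so it may disagree with the __log2f-based device result by one exponent step when amax / 6 sits right at a power of two; ref_e8m0_scale and the sample values are made up for illustration.

// Host-side reference for the E8M0 scale mapping (illustrative only).
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>

static uint8_t ref_e8m0_scale(float amax) {
    if (amax == 0.0f) {
        return 127; // scale of 1.0 for an all-zero block, as in the kernel
    }
    // floor(log2(amax / 6)) + 127, clamped to the usable E8M0 range [1, 254]
    const int e = (int) std::floor(std::log2(amax / 6.0f)) + 127;
    return (uint8_t) std::max(1, std::min(254, e));
}

int main() {
    const float samples[] = { 0.0f, 1e-30f, 0.37f, 1.0f, 6.0f, 448.0f, 1e30f };
    for (float a : samples) {
        std::printf("amax = %-8g -> e = %u\n", a, (unsigned) ref_e8m0_scale(a));
    }
    return 0;
}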
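A second illustrative host-side sketch (also not part of the patch, names are made up) of the char2 nibble packing that lane 0 of each 4-lane group performs: the four FP4 codes gathered by shuffle become two bytes, (q1 << 4) | q0 and (q3 << 4) | q2, so each 32-value block fills 16 bytes of qs, and the second block of the pair lands 8 char2 entries further on.

// CPU model of the per-group nibble packing done by lane 0 (illustrative only).
#include <cstdint>
#include <cstdio>

int main() {
    uint8_t q[32];                  // FP4 codes for one 32-value block, one per lane
    for (int i = 0; i < 32; ++i) {
        q[i] = (uint8_t) (i % 16);  // dummy 4-bit codes
    }

    uint8_t packed[16];             // 8 char2 = 16 bytes per 32-value block
    for (int group = 0; group < 8; ++group) {
        const int base = group * 4;                                            // group_base in the kernel
        packed[2 * group + 0] = (uint8_t) ((q[base + 1] << 4) | q[base + 0]);  // q.x
        packed[2 * group + 1] = (uint8_t) ((q[base + 3] << 4) | q[base + 2]);  // q.y
    }

    for (int i = 0; i < 16; ++i) {
        std::printf("%02x ", (unsigned) packed[i]);
    }
    std::printf("\n");
    return 0;
}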