optimize quantize_mxfp4

Aman Gupta 2025-12-10 16:44:07 +01:00
parent 41e876a24f
commit 40eb6c7ccd
2 changed files with 83 additions and 108 deletions

@@ -726,6 +726,8 @@ __device__ __forceinline__ uint8_t ggml_cuda_float_to_fp4_e2m1(float x, float e)
    int best_i = 0;
    float best_err = fabsf(ax - pos_lut[0]);
#pragma unroll
    for (int i = 1; i < 8; ++i) {
        float err = fabsf(ax - pos_lut[i]);
        if (err < best_err) {
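
For context, the nearest-value search in ggml_cuda_float_to_fp4_e2m1 can be mirrored host-side. A minimal sketch, assuming the standard E2M1 magnitude table {0, 0.5, 1, 1.5, 2, 3, 4, 6} with the sign carried in bit 3 of the 4-bit code; float_to_fp4_ref is a hypothetical helper, not part of this patch:

#include <cmath>
#include <cstdint>

// Round x * inv_scale to the nearest E2M1 value by linear search over the
// 8 non-negative magnitudes; bit 3 of the returned code is the sign bit.
static uint8_t float_to_fp4_ref(float x, float inv_scale) {
    static const float pos_lut[8] = {0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f};
    const float v  = x * inv_scale;
    const float av = fabsf(v);
    int   best_i   = 0;
    float best_err = fabsf(av - pos_lut[0]);
    for (int i = 1; i < 8; ++i) {
        const float err = fabsf(av - pos_lut[i]);
        if (err < best_err) {
            best_err = err;
            best_i   = i;
        }
    }
    return (uint8_t) (best_i | (v < 0.0f ? 0x8 : 0x0));
}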

@@ -47,6 +47,18 @@ static __global__ void quantize_q8_1(
    y[ib].ds = make_half2(d, sum);
}

+// Helper to compute E8M0 scale from amax using fast math
+__device__ __forceinline__ uint8_t compute_e8m0_scale(float amax) {
+    if (amax == 0.0f) {
+        return 127; // Special case: use scale of 1.0 for zero input
+    }
+    // log2(amax / 6.0) = log2(amax) - log2(6) ≈ log2(amax) - 2.585
+    // Use __log2f for fast approximate log2
+    const float log2_amax = __log2f(amax) - 2.5849625007211563f; // log2(6)
+    const int e_int = __float2int_rd(log2_amax) + 127; // floor + bias
+    return static_cast<uint8_t>(max(1, min(254, e_int)));
+}
+
static __global__ void quantize_mmq_mxfp4(const float * __restrict__ x,
                                          const int32_t * __restrict__ ids,
                                          void * __restrict__ vy,
@@ -60,10 +72,15 @@ static __global__ void quantize_mmq_mxfp4(const float * __restrict__ x,
    constexpr int vals_per_scale = 32;
    constexpr int vals_per_warp = 2 * vals_per_scale; // Each warp processes 2 blocks of 32

-    // Each warp processes 2 adjacent blocks of 32 values (64 values total)
-    const int64_t warp_start_offset = blockIdx.y * vals_per_warp;
-    const int64_t i0_block0 = warp_start_offset + threadIdx.x;                  // First block: 0-31
-    const int64_t i0_block1 = warp_start_offset + vals_per_scale + threadIdx.x; // Second block: 32-63
+    // Multiple warps per block - each warp handles different data
+    const int warp_id = threadIdx.y;
+    const int lane_id_32 = threadIdx.x;
+    const int nwarps = blockDim.y;
+
+    const int64_t warp_start_offset = (blockIdx.y * nwarps + warp_id) * vals_per_warp;
+    const int64_t i0_block0 = warp_start_offset + lane_id_32;
+    const int64_t i0_block1 = warp_start_offset + vals_per_scale + lane_id_32;

    if (i0_block0 >= ne0) {
        return;
    }
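
The new compute_e8m0_scale helper relies on the identity floor(log2(amax / 6)) + 127 = floor(log2(amax) - log2(6)) + 127, so a single fast __log2f plus a constant subtraction replaces the division and full-precision log2f of the old code; the clamp to [1, 254] keeps the biased exponent in E8M0's valid range, and the amax == 0 early-out takes over the old "if (e == 0) e = 127" fixup. A host-side reference with one worked value; e8m0_scale_ref is a hypothetical name, not part of this patch:

#include <cmath>
#include <cstdint>
#include <cstdio>

// Exact version of the computation that compute_e8m0_scale approximates.
static uint8_t e8m0_scale_ref(float amax) {
    if (amax == 0.0f) {
        return 127; // E8M0 code for 2^0 = 1.0
    }
    const int e = (int) floorf(log2f(amax / 6.0f)) + 127; // 6.0f = largest E2M1 magnitude
    return (uint8_t) (e < 1 ? 1 : (e > 254 ? 254 : e));  // same clamp as the device helper
}

int main() {
    // amax = 48: log2(48/6) = log2(8) = 3, so e = 3 + 127 = 130.
    // The decoded scale is 2^(130-127) = 8, mapping the largest input to
    // 48/8 = 6, the top of the E2M1 range.
    printf("%u\n", e8m0_scale_ref(48.0f)); // prints 130
    return 0;
}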
@@ -80,117 +97,70 @@ static __global__ void quantize_mmq_mxfp4(const float * __restrict__ x,
    block_fp4_mmq * y = (block_fp4_mmq *) vy;
    const int64_t block_fp4_mmq_size = 4 * QK_MXFP4; // 128 values

-    const int64_t ib0 =
-        blockIdx.z * ((int64_t) gridDim.x * gridDim.y * vals_per_warp / block_fp4_mmq_size); // first block of channel
-    const int64_t ib = ib0 + (warp_start_offset / block_fp4_mmq_size) * ne1 + blockIdx.x;    // block index in channel
-    const int64_t pair_idx_in_block =
-        (warp_start_offset % block_fp4_mmq_size) / vals_per_warp; // 0-1: which pair of blocks within block_fp4_mmq
-
-    uint8_t e_packed[2];
-
-    // Process first block (0-31)
-    {
-        const int64_t global_src_pos = i03 * s03 + i02 * s02 + i01 * s01 + i0_block0;
-        const float xi = i0_block0 < ne00 ? x[global_src_pos] : 0.0f;
-        float amax = fabsf(xi);
-
-        // Reduce max across all 32 threads in the warp
-#pragma unroll
-        for (int mask = 16; mask > 0; mask >>= 1) {
-            amax = fmaxf(amax, __shfl_xor_sync(0xFFFFFFFF, amax, mask, WARP_SIZE));
-        }
-
-        uint8_t e = amax > 0.0f ? (uint8_t) (floorf(log2f(amax / 6.0f)) + 127) : 0;
-        float val = ggml_cuda_e8m0_to_fp32(e);
-        float inv_s = (amax == 0.0f) ? 0.0f : 1.0f / val;
-
-        // Quantize: each thread processes 1 value
-        uint8_t q_val = ggml_cuda_float_to_fp4_e2m1(xi, inv_s);
-        if (e == 0) {
-            e = 127;
-        }
-
-        // Pack 4 values into char2: threads 0,1,2,3 -> first char2, etc.
-        const int lane_id = threadIdx.x % 4;
-        const int group_id = threadIdx.x / 4;
-
-        // Use shuffle to gather values from 4 consecutive threads
-        uint8_t q0 = __shfl_sync(0xFFFFFFFF, q_val, (group_id * 4) + 0, WARP_SIZE);
-        uint8_t q1 = __shfl_sync(0xFFFFFFFF, q_val, (group_id * 4) + 1, WARP_SIZE);
-        uint8_t q2 = __shfl_sync(0xFFFFFFFF, q_val, (group_id * 4) + 2, WARP_SIZE);
-        uint8_t q3 = __shfl_sync(0xFFFFFFFF, q_val, (group_id * 4) + 3, WARP_SIZE);
-
-        char2 q;
-        if (lane_id == 0) {
-            q.x = (q1 << 4) | q0;
-            q.y = (q3 << 4) | q2;
-
-            // Write to output: first block in pair uses positions based on pair_idx_in_block
-            // Each pair has 2 blocks of 32 = 64 values = 16 char2 elements
-            char2 * yqs2 = (char2 *) y[ib].qs;
-            yqs2[pair_idx_in_block * 16 + group_id] = q;
-        }
-        if (threadIdx.x == 0) {
-            e_packed[0] = e;
-        }
-    }
-
-    // Process second block (32-63)
-    {
-        const int64_t global_src_pos = i03 * s03 + i02 * s02 + i01 * s01 + i0_block1;
-        const float xi = i0_block1 < ne00 ? x[global_src_pos] : 0.0f;
-
-        float amax = fabsf(xi);
-#pragma unroll
-        for (int mask = 16; mask > 0; mask >>= 1) {
-            amax = fmaxf(amax, __shfl_xor_sync(0xFFFFFFFF, amax, mask, WARP_SIZE));
-        }
-
-        uint8_t e = amax > 0.0f ? (uint8_t) (floorf(log2f(amax / 6.0f)) + 127) : 0;
-        float val = ggml_cuda_e8m0_to_fp32(e);
-        float inv_s = (amax == 0.0f) ? 0.0f : 1.0f / val;
-        if (e == 0) {
-            e = 127;
-        }
-        uint8_t q_val = ggml_cuda_float_to_fp4_e2m1(xi, inv_s);
-
-        const int lane_id = threadIdx.x % 4;
-        const int group_id = threadIdx.x / 4;
-
-        // Use shuffle to gather values from 4 consecutive threads
-        uint8_t q0 = __shfl_sync(0xFFFFFFFF, q_val, (group_id * 4) + 0, WARP_SIZE);
-        uint8_t q1 = __shfl_sync(0xFFFFFFFF, q_val, (group_id * 4) + 1, WARP_SIZE);
-        uint8_t q2 = __shfl_sync(0xFFFFFFFF, q_val, (group_id * 4) + 2, WARP_SIZE);
-        uint8_t q3 = __shfl_sync(0xFFFFFFFF, q_val, (group_id * 4) + 3, WARP_SIZE);
-
-        char2 q;
-        if (lane_id == 0) {
-            q.x = (q1 << 4) | q0;
-            q.y = (q3 << 4) | q2;
-
-            // Write to output: second block in pair uses positions 8-15 within the pair
-            char2 * yqs2 = (char2 *) y[ib].qs;
-            yqs2[pair_idx_in_block * 16 + 8 + group_id] = q;
-        }
-        if (threadIdx.x == 0) {
-            e_packed[1] = e;
-        }
-    }
-
-    // Write packed exponents: d4[0-1] each stores 2 scales (for 2 blocks of 32)
-    // pair_idx_in_block tells us which d4 entry to use (0-1)
-    if (threadIdx.x == 0) {
-        y[ib].d4[pair_idx_in_block] = (e_packed[1] << 8) | e_packed[0];
-    }
+    const int64_t ib0 = blockIdx.z * ((int64_t) gridDim.x * gridDim.y * nwarps * vals_per_warp / block_fp4_mmq_size);
+    const int64_t ib = ib0 + (warp_start_offset / block_fp4_mmq_size) * ne1 + blockIdx.x;
+    const int64_t pair_idx_in_block = (warp_start_offset % block_fp4_mmq_size) / vals_per_warp;
+
+    // Precompute common values
+    const int lane_id = lane_id_32 % 4;
+    const int group_id = lane_id_32 / 4;
+    const int group_base = group_id * 4;
+    char2 * yqs2 = (char2 *) y[ib].qs;
+
+    const int64_t base_pos = i03 * s03 + i02 * s02 + i01 * s01;
+    const float xi0 = (i0_block0 < ne00) ? x[base_pos + i0_block0] : 0.0f;
+    const float xi1 = (i0_block1 < ne00) ? x[base_pos + i0_block1] : 0.0f;
+
+    // === Process first block (0-31) ===
+    float amax0 = fabsf(xi0);
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        amax0 = fmaxf(amax0, __shfl_xor_sync(0xFFFFFFFF, amax0, mask, WARP_SIZE));
+    }
+
+    const uint8_t e0 = compute_e8m0_scale(amax0);
+    const float inv_s0 = (amax0 == 0.0f) ? 0.0f : __frcp_rn(ggml_cuda_e8m0_to_fp32(e0));
+    const uint8_t q_val0 = ggml_cuda_float_to_fp4_e2m1(xi0, inv_s0);
+
+    // Gather 4 values from consecutive threads using shuffle
+    const uint8_t q0_0 = __shfl_sync(0xFFFFFFFF, q_val0, group_base + 0, WARP_SIZE);
+    const uint8_t q0_1 = __shfl_sync(0xFFFFFFFF, q_val0, group_base + 1, WARP_SIZE);
+    const uint8_t q0_2 = __shfl_sync(0xFFFFFFFF, q_val0, group_base + 2, WARP_SIZE);
+    const uint8_t q0_3 = __shfl_sync(0xFFFFFFFF, q_val0, group_base + 3, WARP_SIZE);
+
+    if (lane_id == 0) {
+        char2 q;
+        q.x = (q0_1 << 4) | q0_0;
+        q.y = (q0_3 << 4) | q0_2;
+        yqs2[pair_idx_in_block * 16 + group_id] = q;
+    }
+
+    // === Process second block (32-63) ===
+    float amax1 = fabsf(xi1);
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        amax1 = fmaxf(amax1, __shfl_xor_sync(0xFFFFFFFF, amax1, mask, WARP_SIZE));
+    }
+
+    const uint8_t e1 = compute_e8m0_scale(amax1);
+    const float inv_s1 = (amax1 == 0.0f) ? 0.0f : __frcp_rn(ggml_cuda_e8m0_to_fp32(e1));
+    const uint8_t q_val1 = ggml_cuda_float_to_fp4_e2m1(xi1, inv_s1);
+
+    const uint8_t q1_0 = __shfl_sync(0xFFFFFFFF, q_val1, group_base + 0, WARP_SIZE);
+    const uint8_t q1_1 = __shfl_sync(0xFFFFFFFF, q_val1, group_base + 1, WARP_SIZE);
+    const uint8_t q1_2 = __shfl_sync(0xFFFFFFFF, q_val1, group_base + 2, WARP_SIZE);
+    const uint8_t q1_3 = __shfl_sync(0xFFFFFFFF, q_val1, group_base + 3, WARP_SIZE);
+
+    if (lane_id == 0) {
+        char2 q;
+        q.x = (q1_1 << 4) | q1_0;
+        q.y = (q1_3 << 4) | q1_2;
+        yqs2[pair_idx_in_block * 16 + 8 + group_id] = q;
+    }
+
+    // Write packed exponents
+    if (lane_id_32 == 0) {
+        y[ib].d4[pair_idx_in_block] = (e1 << 8) | e0;
+    }
}
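
Two warp-level tricks carry the kernel above: a butterfly max-reduction that leaves every lane holding the block's amax, and a shuffle gather that lets one lane in four pack its neighbours' 4-bit codes. A self-contained sketch of both, assuming a single-warp launch; warp_pack_demo and its parameters are illustrative names, not part of this patch:

#include <cstdint>

// One warp: reduce |x| to a warp-wide max, then pack the lanes' 4-bit codes
// two-per-byte into char2, one char2 per group of 4 lanes.
__global__ void warp_pack_demo(const float * x, const uint8_t * codes, float * amax_out, char2 * packed) {
    const int lane = threadIdx.x; // 0..31

    // Butterfly reduction: after the 5 XOR steps every lane holds the maximum.
    float amax = fabsf(x[lane]);
#pragma unroll
    for (int mask = 16; mask > 0; mask >>= 1) {
        amax = fmaxf(amax, __shfl_xor_sync(0xFFFFFFFF, amax, mask, 32));
    }
    if (lane == 0) {
        *amax_out = amax;
    }

    // Each lane holds one 4-bit code; lanes 0, 4, 8, ... gather their group's
    // four codes via shuffle and write them as two packed bytes.
    const uint8_t q = codes[lane] & 0xF;
    const int group_base = (lane / 4) * 4;
    const uint8_t q0 = __shfl_sync(0xFFFFFFFF, q, group_base + 0, 32);
    const uint8_t q1 = __shfl_sync(0xFFFFFFFF, q, group_base + 1, 32);
    const uint8_t q2 = __shfl_sync(0xFFFFFFFF, q, group_base + 2, 32);
    const uint8_t q3 = __shfl_sync(0xFFFFFFFF, q, group_base + 3, 32);
    if (lane % 4 == 0) {
        char2 p;
        p.x = (char) ((q1 << 4) | q0); // low nibble = even lane, high nibble = odd lane
        p.y = (char) ((q3 << 4) | q2);
        packed[lane / 4] = p;
    }
}

In the patched kernel this gather runs twice per warp, once per 32-value block, and lane 0 then stores the pair of E8M0 scales packed as (e1 << 8) | e0 into d4[pair_idx_in_block].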
@@ -353,10 +323,13 @@ void quantize_mmq_mxfp4_cuda(const float * x,
                             cudaStream_t stream) {
    GGML_ASSERT(ne0 % (2 * QK_MXFP4) == 0); // Each warp processes 64 values

    // ne1 tends to assume the highest values, therefore use it as the "x" dimension of the CUDA grid:
-    constexpr int vals_per_warp = 2 * QK_MXFP4; // 64
-    const int64_t block_num_y = (ne0 + vals_per_warp - 1) / vals_per_warp;
+    constexpr int nwarps = 8;
+    constexpr int vals_per_warp = 2 * QK_MXFP4; // 64 values per warp
+    constexpr int vals_per_block = nwarps * vals_per_warp; // 512 values per block
+    const int64_t block_num_y = (ne0 + vals_per_block - 1) / vals_per_block;
    const dim3 num_blocks(ne1, block_num_y, ne2 * ne3);
-    const dim3 block_size(32, 1, 1); // Warp size
+    const dim3 block_size(WARP_SIZE, nwarps, 1); // 32 threads x 8 warps = 256 threads per block
    quantize_mmq_mxfp4<<<num_blocks, block_size, 0, stream>>>(x, ids, vy, ne00, s01, s02, s03, ne0, ne1, ne2);
}
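
The launch-geometry change is easiest to see with concrete numbers. A host-side worked example, assuming ne0 = 4096 and ne1 = 512 purely for illustration:

#include <cstdint>
#include <cstdio>

int main() {
    // Assumed shape for illustration: 4096 values per row, 512 rows, one channel.
    const int64_t ne0 = 4096, ne1 = 512, ne2 = 1, ne3 = 1;

    const int nwarps         = 8;
    const int vals_per_warp  = 2 * 32;                 // QK_MXFP4 = 32, two blocks per warp
    const int vals_per_block = nwarps * vals_per_warp; // 512 values per CUDA block

    const int64_t block_num_y = (ne0 + vals_per_block - 1) / vals_per_block; // 8

    // New launch: grid (512, 8, 1) with 256 threads per block.
    // Old launch: grid (512, 64, 1) with 32 threads per block - the same total
    // warp count, but 8x more blocks for the scheduler to juggle.
    printf("grid = (%lld, %lld, %lld), threads per block = %d\n",
           (long long) ne1, (long long) block_num_y, (long long) (ne2 * ne3), 32 * nwarps);
    return 0;
}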