From b7deb96d7c19a4884219a22367511a1635f91e1d Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Thu, 11 Dec 2025 08:41:20 +0100 Subject: [PATCH] use interleaved layout for mma --- ggml/src/ggml-cuda/mmq.cuh | 19 +------- ggml/src/ggml-cuda/quantize.cu | 82 +++++++++++++--------------------- 2 files changed, 31 insertions(+), 70 deletions(-) diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh index 319c7d08ee..15af470843 100644 --- a/ggml/src/ggml-cuda/mmq.cuh +++ b/ggml/src/ggml-cuda/mmq.cuh @@ -804,25 +804,8 @@ static __device__ __forceinline__ void load_tiles_mxfp4_fp4(const char * __restr const block_mxfp4 * bxi = (const block_mxfp4 *) x + kbx0 + i * stride + kbx; - int aux_q4[4]; - memcpy(aux_q4, bxi->qs, 16); - - // Compress: extract low nibbles from each byte and pack into 16 bits - // Input byte layout: [hi3|lo3][hi2|lo2][hi1|lo1][hi0|lo0] - // Output: [lo3|lo2|lo1|lo0] as 16 bits - const auto compress = [](const int x) -> int { - const int m = x & 0x0F0F0F0F; // isolate low nibbles: 0x0lo30lo20lo10lo0 - // Pack nibbles: shift and combine - const int t1 = (m | (m >> 4)) & 0x00FF00FF; // 0x00_lo3lo2_00_lo1lo0 - return (t1 | (t1 >> 8)) & 0x0000FFFF; // 0x0000_lo3lo2lo1lo0 - }; - const int k0 = kbx * 4; - - x_qs[i * MMQ_MMA_TILE_X_K_FP4 + k0 + 0] = compress(aux_q4[1]) << 16 | compress(aux_q4[0]); - x_qs[i * MMQ_MMA_TILE_X_K_FP4 + k0 + 1] = compress(aux_q4[3]) << 16 | compress(aux_q4[2]); - x_qs[i * MMQ_MMA_TILE_X_K_FP4 + k0 + 2] = compress(aux_q4[1] >> 4) << 16 | compress(aux_q4[0] >> 4); - x_qs[i * MMQ_MMA_TILE_X_K_FP4 + k0 + 3] = compress(aux_q4[3] >> 4) << 16 | compress(aux_q4[2] >> 4); + memcpy(x_qs + i * MMQ_MMA_TILE_X_K_FP4 + k0, bxi->qs, 16); // Load E8M0 scales: pack 2 consecutive scales into one uint32 if (kbx % 2 == 0) { diff --git a/ggml/src/ggml-cuda/quantize.cu b/ggml/src/ggml-cuda/quantize.cu index 39d8b6592d..8a16557796 100644 --- a/ggml/src/ggml-cuda/quantize.cu +++ b/ggml/src/ggml-cuda/quantize.cu @@ -79,10 +79,8 @@ static __global__ void quantize_mmq_mxfp4(const float * __restrict__ x, const int nwarps = blockDim.y; const int64_t warp_start_offset = (blockIdx.y * nwarps + warp_id) * vals_per_warp; - const int64_t i0_block0 = warp_start_offset + lane_id_32; - const int64_t i0_block1 = warp_start_offset + vals_per_scale + lane_id_32; - if (i0_block0 >= ne0) { + if (warp_start_offset >= ne0) { return; } @@ -101,66 +99,46 @@ static __global__ void quantize_mmq_mxfp4(const float * __restrict__ x, const int64_t ib = ib0 + (warp_start_offset / block_fp4_mmq_size) * ne1 + blockIdx.x; const int64_t pair_idx_in_block = (warp_start_offset % block_fp4_mmq_size) / vals_per_warp; - // Precompute common values - const int lane_id = lane_id_32 % 4; const int group_id = lane_id_32 / 4; - const int group_base = group_id * 4; + const int lane_in_group = lane_id_32 % 4; + const int base = group_id * 2; char2 * yqs2 = (char2 *) y[ib].qs; const int64_t base_pos = i03 * s03 + i02 * s02 + i01 * s01; - const float xi0 = (i0_block0 < ne00) ? x[base_pos + i0_block0] : 0.0f; - const float xi1 = (i0_block1 < ne00) ? x[base_pos + i0_block1] : 0.0f; - // === Process first block (0-31) === - float amax0 = fabsf(xi0); + uint8_t scales[2]; + #pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) { - amax0 = fmaxf(amax0, __shfl_xor_sync(0xFFFFFFFF, amax0, mask, WARP_SIZE)); - } + for (int b = 0; b < 2; ++b) { + const int64_t i0 = warp_start_offset + b * vals_per_scale + lane_id_32; + const float xi = (i0 < ne00) ? x[base_pos + i0] : 0.0f; - const uint8_t e0 = compute_e8m0_scale(amax0); - const float inv_s0 = (amax0 == 0.0f) ? 0.0f : __frcp_rn(ggml_cuda_e8m0_to_fp32(e0)); - const uint8_t q_val0 = ggml_cuda_float_to_fp4_e2m1(xi0, inv_s0); - - // Gather 4 values from consecutive threads using shuffle - const uint8_t q0_0 = __shfl_sync(0xFFFFFFFF, q_val0, group_base + 0, WARP_SIZE); - const uint8_t q0_1 = __shfl_sync(0xFFFFFFFF, q_val0, group_base + 1, WARP_SIZE); - const uint8_t q0_2 = __shfl_sync(0xFFFFFFFF, q_val0, group_base + 2, WARP_SIZE); - const uint8_t q0_3 = __shfl_sync(0xFFFFFFFF, q_val0, group_base + 3, WARP_SIZE); - - if (lane_id == 0) { - char2 q; - q.x = (q0_1 << 4) | q0_0; - q.y = (q0_3 << 4) | q0_2; - yqs2[pair_idx_in_block * 16 + group_id] = q; - } - - // === Process second block (32-63) === - float amax1 = fabsf(xi1); + float amax = fabsf(xi); #pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) { - amax1 = fmaxf(amax1, __shfl_xor_sync(0xFFFFFFFF, amax1, mask, WARP_SIZE)); + for (int mask = 16; mask > 0; mask >>= 1) { + amax = fmaxf(amax, __shfl_xor_sync(0xFFFFFFFF, amax, mask, WARP_SIZE)); + } + + const uint8_t e = compute_e8m0_scale(amax); + scales[b] = e; + const float inv_s = (amax == 0.0f) ? 0.0f : __frcp_rn(ggml_cuda_e8m0_to_fp32(e)); + const uint8_t q_val = ggml_cuda_float_to_fp4_e2m1(xi, inv_s); + + const uint8_t q_lo_0 = __shfl_sync(0xFFFFFFFF, q_val, base, WARP_SIZE); + const uint8_t q_lo_1 = __shfl_sync(0xFFFFFFFF, q_val, base + 1, WARP_SIZE); + const uint8_t q_hi_0 = __shfl_sync(0xFFFFFFFF, q_val, base + 16, WARP_SIZE); + const uint8_t q_hi_1 = __shfl_sync(0xFFFFFFFF, q_val, base + 17, WARP_SIZE); + + if (lane_in_group == 0) { + char2 q; + q.x = (q_hi_0 << 4) | q_lo_0; + q.y = (q_hi_1 << 4) | q_lo_1; + yqs2[pair_idx_in_block * 16 + b * 8 + group_id] = q; + } } - const uint8_t e1 = compute_e8m0_scale(amax1); - const float inv_s1 = (amax1 == 0.0f) ? 0.0f : __frcp_rn(ggml_cuda_e8m0_to_fp32(e1)); - const uint8_t q_val1 = ggml_cuda_float_to_fp4_e2m1(xi1, inv_s1); - - const uint8_t q1_0 = __shfl_sync(0xFFFFFFFF, q_val1, group_base + 0, WARP_SIZE); - const uint8_t q1_1 = __shfl_sync(0xFFFFFFFF, q_val1, group_base + 1, WARP_SIZE); - const uint8_t q1_2 = __shfl_sync(0xFFFFFFFF, q_val1, group_base + 2, WARP_SIZE); - const uint8_t q1_3 = __shfl_sync(0xFFFFFFFF, q_val1, group_base + 3, WARP_SIZE); - - if (lane_id == 0) { - char2 q; - q.x = (q1_1 << 4) | q1_0; - q.y = (q1_3 << 4) | q1_2; - yqs2[pair_idx_in_block * 16 + 8 + group_id] = q; - } - - // Write packed exponents if (lane_id_32 == 0) { - y[ib].d4[pair_idx_in_block] = (e1 << 8) | e0; + y[ib].d4[pair_idx_in_block] = (scales[1] << 8) | scales[0]; } }