use interleaved layout for mma

Aman Gupta 2025-12-11 08:41:20 +01:00
parent a6dcaa5742
commit b7deb96d7c
2 changed files with 31 additions and 70 deletions
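The change works because quantize_mmq_mxfp4 now emits its 4-bit codes in the interleaved nibble order that block_mxfp4 already uses: within each 32-value block, byte j carries element j in its low nibble and element j + 16 in its high nibble. With weights and quantized activations in the same order, the MMA load path no longer has to repack nibbles. A minimal host-side sketch of that layout, assuming the usual block_mxfp4 convention; pack_fp4_interleaved/unpack_fp4_interleaved are illustrative names, not code from this commit:

    #include <cassert>
    #include <cstdint>

    // Illustrative only. Pack 32 4-bit codes v[0..31] into 16 bytes: byte j holds
    // element j in its low nibble and element j + 16 in its high nibble.
    static void pack_fp4_interleaved(const uint8_t v[32], uint8_t qs[16]) {
        for (int j = 0; j < 16; ++j) {
            qs[j] = (uint8_t) ((v[j + 16] << 4) | (v[j] & 0x0F));
        }
    }

    // Recover element j (0..31) from the packed bytes.
    static uint8_t unpack_fp4_interleaved(const uint8_t qs[16], int j) {
        return j < 16 ? (uint8_t) (qs[j] & 0x0F) : (uint8_t) (qs[j - 16] >> 4);
    }

    int main() {
        uint8_t v[32], qs[16];
        for (int j = 0; j < 32; ++j) v[j] = (uint8_t) (j % 16);
        pack_fp4_interleaved(v, qs);
        for (int j = 0; j < 32; ++j) assert(unpack_fp4_interleaved(qs, j) == v[j]);
        return 0;
    }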

Changed file 1 of 2:

@@ -804,25 +804,8 @@ static __device__ __forceinline__ void load_tiles_mxfp4_fp4(const char * __restr
         const block_mxfp4 * bxi = (const block_mxfp4 *) x + kbx0 + i * stride + kbx;
-        int aux_q4[4];
-        memcpy(aux_q4, bxi->qs, 16);
-        // Compress: extract low nibbles from each byte and pack into 16 bits
-        // Input byte layout: [hi3|lo3][hi2|lo2][hi1|lo1][hi0|lo0]
-        // Output: [lo3|lo2|lo1|lo0] as 16 bits
-        const auto compress = [](const int x) -> int {
-            const int m = x & 0x0F0F0F0F; // isolate low nibbles: 0x0lo30lo20lo10lo0
-            // Pack nibbles: shift and combine
-            const int t1 = (m | (m >> 4)) & 0x00FF00FF; // 0x00_lo3lo2_00_lo1lo0
-            return (t1 | (t1 >> 8)) & 0x0000FFFF; // 0x0000_lo3lo2lo1lo0
-        };
         const int k0 = kbx * 4;
-        x_qs[i * MMQ_MMA_TILE_X_K_FP4 + k0 + 0] = compress(aux_q4[1]) << 16 | compress(aux_q4[0]);
-        x_qs[i * MMQ_MMA_TILE_X_K_FP4 + k0 + 1] = compress(aux_q4[3]) << 16 | compress(aux_q4[2]);
-        x_qs[i * MMQ_MMA_TILE_X_K_FP4 + k0 + 2] = compress(aux_q4[1] >> 4) << 16 | compress(aux_q4[0] >> 4);
-        x_qs[i * MMQ_MMA_TILE_X_K_FP4 + k0 + 3] = compress(aux_q4[3] >> 4) << 16 | compress(aux_q4[2] >> 4);
+        memcpy(x_qs + i * MMQ_MMA_TILE_X_K_FP4 + k0, bxi->qs, 16);
         // Load E8M0 scales: pack 2 consecutive scales into one uint32
         if (kbx % 2 == 0) {
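For reference, the deleted compress() lambda is the repacking step that the interleaved layout makes unnecessary: it keeps the low nibble of each byte in a 32-bit word and packs the four nibbles into the low 16 bits. A standalone sketch of the same bit trick with a worked value (the main() harness is illustrative):

    #include <cassert>

    // Same bit trick as the removed compress() lambda: keep the low nibble of each
    // of the 4 bytes in a 32-bit word and pack them into the low 16 bits.
    static int compress(const int x) {
        const int m  = x & 0x0F0F0F0F;              // 0x0d0c0b0a: isolate low nibbles
        const int t1 = (m | (m >> 4)) & 0x00FF00FF; // 0x00dc00ba: merge nibble pairs
        return (t1 | (t1 >> 8)) & 0x0000FFFF;       // 0x0000dcba: merge byte pairs
    }

    int main() {
        // Low nibbles of 0x51627384 are 1, 2, 3, 4 (most significant byte down),
        // so the packed result is 0x1234.
        assert(compress(0x51627384) == 0x1234);
        return 0;
    }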

Changed file 2 of 2:

@@ -79,10 +79,8 @@ static __global__ void quantize_mmq_mxfp4(const float * __restrict__ x,
     const int nwarps = blockDim.y;
     const int64_t warp_start_offset = (blockIdx.y * nwarps + warp_id) * vals_per_warp;
-    const int64_t i0_block0 = warp_start_offset + lane_id_32;
-    const int64_t i0_block1 = warp_start_offset + vals_per_scale + lane_id_32;
-    if (i0_block0 >= ne0) {
+    if (warp_start_offset >= ne0) {
         return;
     }
@@ -101,66 +99,46 @@ static __global__ void quantize_mmq_mxfp4(const float * __restrict__ x,
     const int64_t ib = ib0 + (warp_start_offset / block_fp4_mmq_size) * ne1 + blockIdx.x;
     const int64_t pair_idx_in_block = (warp_start_offset % block_fp4_mmq_size) / vals_per_warp;
     // Precompute common values
-    const int lane_id = lane_id_32 % 4;
     const int group_id = lane_id_32 / 4;
-    const int group_base = group_id * 4;
+    const int lane_in_group = lane_id_32 % 4;
+    const int base = group_id * 2;
     char2 * yqs2 = (char2 *) y[ib].qs;
     const int64_t base_pos = i03 * s03 + i02 * s02 + i01 * s01;
-    const float xi0 = (i0_block0 < ne00) ? x[base_pos + i0_block0] : 0.0f;
-    const float xi1 = (i0_block1 < ne00) ? x[base_pos + i0_block1] : 0.0f;
-    // === Process first block (0-31) ===
-    float amax0 = fabsf(xi0);
+    uint8_t scales[2];
     #pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        amax0 = fmaxf(amax0, __shfl_xor_sync(0xFFFFFFFF, amax0, mask, WARP_SIZE));
-    }
+    for (int b = 0; b < 2; ++b) {
+        const int64_t i0 = warp_start_offset + b * vals_per_scale + lane_id_32;
+        const float xi = (i0 < ne00) ? x[base_pos + i0] : 0.0f;
-    const uint8_t e0 = compute_e8m0_scale(amax0);
-    const float inv_s0 = (amax0 == 0.0f) ? 0.0f : __frcp_rn(ggml_cuda_e8m0_to_fp32(e0));
-    const uint8_t q_val0 = ggml_cuda_float_to_fp4_e2m1(xi0, inv_s0);
-    // Gather 4 values from consecutive threads using shuffle
-    const uint8_t q0_0 = __shfl_sync(0xFFFFFFFF, q_val0, group_base + 0, WARP_SIZE);
-    const uint8_t q0_1 = __shfl_sync(0xFFFFFFFF, q_val0, group_base + 1, WARP_SIZE);
-    const uint8_t q0_2 = __shfl_sync(0xFFFFFFFF, q_val0, group_base + 2, WARP_SIZE);
-    const uint8_t q0_3 = __shfl_sync(0xFFFFFFFF, q_val0, group_base + 3, WARP_SIZE);
-    if (lane_id == 0) {
-        char2 q;
-        q.x = (q0_1 << 4) | q0_0;
-        q.y = (q0_3 << 4) | q0_2;
-        yqs2[pair_idx_in_block * 16 + group_id] = q;
-    }
-    // === Process second block (32-63) ===
-    float amax1 = fabsf(xi1);
+        float amax = fabsf(xi);
     #pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        amax1 = fmaxf(amax1, __shfl_xor_sync(0xFFFFFFFF, amax1, mask, WARP_SIZE));
+        for (int mask = 16; mask > 0; mask >>= 1) {
+            amax = fmaxf(amax, __shfl_xor_sync(0xFFFFFFFF, amax, mask, WARP_SIZE));
     }
+        const uint8_t e = compute_e8m0_scale(amax);
+        scales[b] = e;
+        const float inv_s = (amax == 0.0f) ? 0.0f : __frcp_rn(ggml_cuda_e8m0_to_fp32(e));
+        const uint8_t q_val = ggml_cuda_float_to_fp4_e2m1(xi, inv_s);
+        const uint8_t q_lo_0 = __shfl_sync(0xFFFFFFFF, q_val, base, WARP_SIZE);
+        const uint8_t q_lo_1 = __shfl_sync(0xFFFFFFFF, q_val, base + 1, WARP_SIZE);
+        const uint8_t q_hi_0 = __shfl_sync(0xFFFFFFFF, q_val, base + 16, WARP_SIZE);
+        const uint8_t q_hi_1 = __shfl_sync(0xFFFFFFFF, q_val, base + 17, WARP_SIZE);
+        if (lane_in_group == 0) {
+            char2 q;
+            q.x = (q_hi_0 << 4) | q_lo_0;
+            q.y = (q_hi_1 << 4) | q_lo_1;
+            yqs2[pair_idx_in_block * 16 + b * 8 + group_id] = q;
+        }
+    }
-    const uint8_t e1 = compute_e8m0_scale(amax1);
-    const float inv_s1 = (amax1 == 0.0f) ? 0.0f : __frcp_rn(ggml_cuda_e8m0_to_fp32(e1));
-    const uint8_t q_val1 = ggml_cuda_float_to_fp4_e2m1(xi1, inv_s1);
-    const uint8_t q1_0 = __shfl_sync(0xFFFFFFFF, q_val1, group_base + 0, WARP_SIZE);
-    const uint8_t q1_1 = __shfl_sync(0xFFFFFFFF, q_val1, group_base + 1, WARP_SIZE);
-    const uint8_t q1_2 = __shfl_sync(0xFFFFFFFF, q_val1, group_base + 2, WARP_SIZE);
-    const uint8_t q1_3 = __shfl_sync(0xFFFFFFFF, q_val1, group_base + 3, WARP_SIZE);
-    if (lane_id == 0) {
-        char2 q;
-        q.x = (q1_1 << 4) | q1_0;
-        q.y = (q1_3 << 4) | q1_2;
-        yqs2[pair_idx_in_block * 16 + 8 + group_id] = q;
-    }
     // Write packed exponents
     if (lane_id_32 == 0) {
-        y[ib].d4[pair_idx_in_block] = (e1 << 8) | e0;
+        y[ib].d4[pair_idx_in_block] = (scales[1] << 8) | scales[0];
     }
 }
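A host-side way to see what the rewritten loop produces per 32-value block: lane l quantizes element l, each group of four lanes gathers elements base, base + 1, base + 16 and base + 17 with __shfl_sync, and one lane per group writes a char2, so byte j of the block ends up as (element j + 16) << 4 | element j. The sketch below emulates this without shuffles (emulate_warp_pack is an illustrative name) and checks that the result matches the block_mxfp4 nibble order that load_tiles_mxfp4_fp4 can now memcpy directly; the two E8M0 scales of a pair are packed separately as (scales[1] << 8) | scales[0].

    #include <cassert>
    #include <cstdint>

    // Host-side emulation of the per-block packing done with warp shuffles in the
    // kernel. v[l] stands for the 4-bit code produced by lane l; qs is the
    // 16-byte output for one 32-value block.
    static void emulate_warp_pack(const uint8_t v[32], uint8_t qs[16]) {
        for (int group_id = 0; group_id < 8; ++group_id) {
            const int base = group_id * 2;
            const uint8_t q_lo_0 = v[base + 0],  q_lo_1 = v[base + 1];
            const uint8_t q_hi_0 = v[base + 16], q_hi_1 = v[base + 17];
            qs[group_id * 2 + 0] = (uint8_t) ((q_hi_0 << 4) | q_lo_0); // char2.x
            qs[group_id * 2 + 1] = (uint8_t) ((q_hi_1 << 4) | q_lo_1); // char2.y
        }
    }

    int main() {
        uint8_t v[32], qs[16];
        for (int l = 0; l < 32; ++l) v[l] = (uint8_t) (l % 16);
        emulate_warp_pack(v, qs);
        // Same interleaving as block_mxfp4: low nibble = element j,
        // high nibble = element j + 16.
        for (int j = 0; j < 16; ++j) {
            assert(qs[j] == (uint8_t) ((v[j + 16] << 4) | v[j]));
        }
        return 0;
    }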