use interleaved layout for mma
This commit is contained in:
parent
a6dcaa5742
commit
b7deb96d7c
|
|
@ -804,25 +804,8 @@ static __device__ __forceinline__ void load_tiles_mxfp4_fp4(const char * __restr
|
|||
|
||||
const block_mxfp4 * bxi = (const block_mxfp4 *) x + kbx0 + i * stride + kbx;
|
||||
|
||||
int aux_q4[4];
|
||||
memcpy(aux_q4, bxi->qs, 16);
|
||||
|
||||
// Compress: extract low nibbles from each byte and pack into 16 bits
|
||||
// Input byte layout: [hi3|lo3][hi2|lo2][hi1|lo1][hi0|lo0]
|
||||
// Output: [lo3|lo2|lo1|lo0] as 16 bits
|
||||
const auto compress = [](const int x) -> int {
|
||||
const int m = x & 0x0F0F0F0F; // isolate low nibbles: 0x0lo30lo20lo10lo0
|
||||
// Pack nibbles: shift and combine
|
||||
const int t1 = (m | (m >> 4)) & 0x00FF00FF; // 0x00_lo3lo2_00_lo1lo0
|
||||
return (t1 | (t1 >> 8)) & 0x0000FFFF; // 0x0000_lo3lo2lo1lo0
|
||||
};
|
||||
|
||||
const int k0 = kbx * 4;
|
||||
|
||||
x_qs[i * MMQ_MMA_TILE_X_K_FP4 + k0 + 0] = compress(aux_q4[1]) << 16 | compress(aux_q4[0]);
|
||||
x_qs[i * MMQ_MMA_TILE_X_K_FP4 + k0 + 1] = compress(aux_q4[3]) << 16 | compress(aux_q4[2]);
|
||||
x_qs[i * MMQ_MMA_TILE_X_K_FP4 + k0 + 2] = compress(aux_q4[1] >> 4) << 16 | compress(aux_q4[0] >> 4);
|
||||
x_qs[i * MMQ_MMA_TILE_X_K_FP4 + k0 + 3] = compress(aux_q4[3] >> 4) << 16 | compress(aux_q4[2] >> 4);
|
||||
memcpy(x_qs + i * MMQ_MMA_TILE_X_K_FP4 + k0, bxi->qs, 16);
|
||||
|
||||
// Load E8M0 scales: pack 2 consecutive scales into one uint32
|
||||
if (kbx % 2 == 0) {
|
||||
|
|
|
|||
|
|
@ -79,10 +79,8 @@ static __global__ void quantize_mmq_mxfp4(const float * __restrict__ x,
|
|||
const int nwarps = blockDim.y;
|
||||
|
||||
const int64_t warp_start_offset = (blockIdx.y * nwarps + warp_id) * vals_per_warp;
|
||||
const int64_t i0_block0 = warp_start_offset + lane_id_32;
|
||||
const int64_t i0_block1 = warp_start_offset + vals_per_scale + lane_id_32;
|
||||
|
||||
if (i0_block0 >= ne0) {
|
||||
if (warp_start_offset >= ne0) {
|
||||
return;
|
||||
}
|
||||
|
||||
|
|
@ -101,66 +99,46 @@ static __global__ void quantize_mmq_mxfp4(const float * __restrict__ x,
|
|||
const int64_t ib = ib0 + (warp_start_offset / block_fp4_mmq_size) * ne1 + blockIdx.x;
|
||||
const int64_t pair_idx_in_block = (warp_start_offset % block_fp4_mmq_size) / vals_per_warp;
|
||||
|
||||
// Precompute common values
|
||||
const int lane_id = lane_id_32 % 4;
|
||||
const int group_id = lane_id_32 / 4;
|
||||
const int group_base = group_id * 4;
|
||||
const int lane_in_group = lane_id_32 % 4;
|
||||
const int base = group_id * 2;
|
||||
char2 * yqs2 = (char2 *) y[ib].qs;
|
||||
|
||||
const int64_t base_pos = i03 * s03 + i02 * s02 + i01 * s01;
|
||||
const float xi0 = (i0_block0 < ne00) ? x[base_pos + i0_block0] : 0.0f;
|
||||
const float xi1 = (i0_block1 < ne00) ? x[base_pos + i0_block1] : 0.0f;
|
||||
|
||||
// === Process first block (0-31) ===
|
||||
float amax0 = fabsf(xi0);
|
||||
uint8_t scales[2];
|
||||
|
||||
#pragma unroll
|
||||
for (int mask = 16; mask > 0; mask >>= 1) {
|
||||
amax0 = fmaxf(amax0, __shfl_xor_sync(0xFFFFFFFF, amax0, mask, WARP_SIZE));
|
||||
}
|
||||
for (int b = 0; b < 2; ++b) {
|
||||
const int64_t i0 = warp_start_offset + b * vals_per_scale + lane_id_32;
|
||||
const float xi = (i0 < ne00) ? x[base_pos + i0] : 0.0f;
|
||||
|
||||
const uint8_t e0 = compute_e8m0_scale(amax0);
|
||||
const float inv_s0 = (amax0 == 0.0f) ? 0.0f : __frcp_rn(ggml_cuda_e8m0_to_fp32(e0));
|
||||
const uint8_t q_val0 = ggml_cuda_float_to_fp4_e2m1(xi0, inv_s0);
|
||||
|
||||
// Gather 4 values from consecutive threads using shuffle
|
||||
const uint8_t q0_0 = __shfl_sync(0xFFFFFFFF, q_val0, group_base + 0, WARP_SIZE);
|
||||
const uint8_t q0_1 = __shfl_sync(0xFFFFFFFF, q_val0, group_base + 1, WARP_SIZE);
|
||||
const uint8_t q0_2 = __shfl_sync(0xFFFFFFFF, q_val0, group_base + 2, WARP_SIZE);
|
||||
const uint8_t q0_3 = __shfl_sync(0xFFFFFFFF, q_val0, group_base + 3, WARP_SIZE);
|
||||
|
||||
if (lane_id == 0) {
|
||||
char2 q;
|
||||
q.x = (q0_1 << 4) | q0_0;
|
||||
q.y = (q0_3 << 4) | q0_2;
|
||||
yqs2[pair_idx_in_block * 16 + group_id] = q;
|
||||
}
|
||||
|
||||
// === Process second block (32-63) ===
|
||||
float amax1 = fabsf(xi1);
|
||||
float amax = fabsf(xi);
|
||||
#pragma unroll
|
||||
for (int mask = 16; mask > 0; mask >>= 1) {
|
||||
amax1 = fmaxf(amax1, __shfl_xor_sync(0xFFFFFFFF, amax1, mask, WARP_SIZE));
|
||||
for (int mask = 16; mask > 0; mask >>= 1) {
|
||||
amax = fmaxf(amax, __shfl_xor_sync(0xFFFFFFFF, amax, mask, WARP_SIZE));
|
||||
}
|
||||
|
||||
const uint8_t e = compute_e8m0_scale(amax);
|
||||
scales[b] = e;
|
||||
const float inv_s = (amax == 0.0f) ? 0.0f : __frcp_rn(ggml_cuda_e8m0_to_fp32(e));
|
||||
const uint8_t q_val = ggml_cuda_float_to_fp4_e2m1(xi, inv_s);
|
||||
|
||||
const uint8_t q_lo_0 = __shfl_sync(0xFFFFFFFF, q_val, base, WARP_SIZE);
|
||||
const uint8_t q_lo_1 = __shfl_sync(0xFFFFFFFF, q_val, base + 1, WARP_SIZE);
|
||||
const uint8_t q_hi_0 = __shfl_sync(0xFFFFFFFF, q_val, base + 16, WARP_SIZE);
|
||||
const uint8_t q_hi_1 = __shfl_sync(0xFFFFFFFF, q_val, base + 17, WARP_SIZE);
|
||||
|
||||
if (lane_in_group == 0) {
|
||||
char2 q;
|
||||
q.x = (q_hi_0 << 4) | q_lo_0;
|
||||
q.y = (q_hi_1 << 4) | q_lo_1;
|
||||
yqs2[pair_idx_in_block * 16 + b * 8 + group_id] = q;
|
||||
}
|
||||
}
|
||||
|
||||
const uint8_t e1 = compute_e8m0_scale(amax1);
|
||||
const float inv_s1 = (amax1 == 0.0f) ? 0.0f : __frcp_rn(ggml_cuda_e8m0_to_fp32(e1));
|
||||
const uint8_t q_val1 = ggml_cuda_float_to_fp4_e2m1(xi1, inv_s1);
|
||||
|
||||
const uint8_t q1_0 = __shfl_sync(0xFFFFFFFF, q_val1, group_base + 0, WARP_SIZE);
|
||||
const uint8_t q1_1 = __shfl_sync(0xFFFFFFFF, q_val1, group_base + 1, WARP_SIZE);
|
||||
const uint8_t q1_2 = __shfl_sync(0xFFFFFFFF, q_val1, group_base + 2, WARP_SIZE);
|
||||
const uint8_t q1_3 = __shfl_sync(0xFFFFFFFF, q_val1, group_base + 3, WARP_SIZE);
|
||||
|
||||
if (lane_id == 0) {
|
||||
char2 q;
|
||||
q.x = (q1_1 << 4) | q1_0;
|
||||
q.y = (q1_3 << 4) | q1_2;
|
||||
yqs2[pair_idx_in_block * 16 + 8 + group_id] = q;
|
||||
}
|
||||
|
||||
// Write packed exponents
|
||||
if (lane_id_32 == 0) {
|
||||
y[ib].d4[pair_idx_in_block] = (e1 << 8) | e0;
|
||||
y[ib].d4[pair_idx_in_block] = (scales[1] << 8) | scales[0];
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue