use __nv_fp4x4_e2m1

This commit is contained in:
Aman Gupta 2025-12-11 16:38:27 +01:00
parent 928cc5594f
commit a1672f620b
2 changed files with 25 additions and 0 deletions

View File

@@ -122,6 +122,26 @@ static __global__ void quantize_mmq_mxfp4(const float * __restrict__ x,
const uint8_t e = compute_e8m0_scale(amax);
scales[b] = e;
const float inv_s = (amax == 0.0f) ? 0.0f : __frcp_rn(ggml_cuda_e8m0_to_fp32(e));
#if CUDART_VERSION >= 12040
// Use hardware FP4 conversion: pre-scale and gather 4 floats, then convert+pack
const float scaled_val = xi * inv_s;
// Gather 4 scaled floats in the order matching __nv_fp4x4_e2m1 packing:
// float4(x,y,z,w) -> 16-bit with bits [3:0]=x, [7:4]=y, [11:8]=z, [15:12]=w
// This produces byte0 = (y<<4)|x, byte1 = (w<<4)|z
const float val0 = __shfl_sync(0xFFFFFFFF, scaled_val, base, WARP_SIZE); // -> low nibble byte 0
const float val1 = __shfl_sync(0xFFFFFFFF, scaled_val, base + 16, WARP_SIZE); // -> high nibble byte 0
const float val2 = __shfl_sync(0xFFFFFFFF, scaled_val, base + 1, WARP_SIZE); // -> low nibble byte 1
const float val3 = __shfl_sync(0xFFFFFFFF, scaled_val, base + 17, WARP_SIZE); // -> high nibble byte 1
if (lane_in_group == 0) {
// Convert 4 floats -> packed 16-bit FP4 in one step
__nv_fp4x4_e2m1 fp4_packed(make_float4(val0, val1, val2, val3));
yqs2[pair_idx_in_block * 16 + b * 8 + group_id] = reinterpret_cast<const char2&>(fp4_packed);
}
#else
// Fallback: manual FP4 conversion using LUT
const uint8_t q_val = ggml_cuda_float_to_fp4_e2m1(xi, inv_s);
const uint8_t q_lo_0 = __shfl_sync(0xFFFFFFFF, q_val, base, WARP_SIZE);
@@ -135,6 +155,7 @@ static __global__ void quantize_mmq_mxfp4(const float * __restrict__ x,
q.y = (q_hi_1 << 4) | q_lo_1;
yqs2[pair_idx_in_block * 16 + b * 8 + group_id] = q;
}
#endif // CUDART_VERSION >= 12040
}
if (lane_id_32 == 0) {

View File

@@ -10,6 +10,10 @@
#include <cuda_fp8.h>
#endif // CUDART_VERSION >= 12050
#if CUDART_VERSION >= 12040
#include <cuda_fp4.h>
#endif // CUDART_VERSION >= 12040
#if CUDART_VERSION < 11020
#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED CU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED
#define CUBLAS_TF32_TENSOR_OP_MATH CUBLAS_TENSOR_OP_MATH