diff --git a/ggml/src/ggml-cuda/quantize.cu b/ggml/src/ggml-cuda/quantize.cu
index 8a16557796..d764725b33 100644
--- a/ggml/src/ggml-cuda/quantize.cu
+++ b/ggml/src/ggml-cuda/quantize.cu
@@ -122,6 +122,26 @@ static __global__ void quantize_mmq_mxfp4(const float * __restrict__ x,
         const uint8_t e = compute_e8m0_scale(amax);
         scales[b] = e;
         const float inv_s = (amax == 0.0f) ? 0.0f : __frcp_rn(ggml_cuda_e8m0_to_fp32(e));
+
+#if CUDART_VERSION >= 12040
+        // Use hardware FP4 conversion: pre-scale and gather 4 floats, then convert+pack.
+        const float scaled_val = xi * inv_s;
+
+        // Gather 4 scaled floats in the order matching the __nv_fp4x4_e2m1 packing:
+        // float4(x,y,z,w) -> 16 bits with bits [3:0]=x, [7:4]=y, [11:8]=z, [15:12]=w.
+        // This produces byte0 = (y<<4)|x and byte1 = (w<<4)|z.
+        const float val0 = __shfl_sync(0xFFFFFFFF, scaled_val, base,      WARP_SIZE); // -> low  nibble of byte 0
+        const float val1 = __shfl_sync(0xFFFFFFFF, scaled_val, base + 16, WARP_SIZE); // -> high nibble of byte 0
+        const float val2 = __shfl_sync(0xFFFFFFFF, scaled_val, base +  1, WARP_SIZE); // -> low  nibble of byte 1
+        const float val3 = __shfl_sync(0xFFFFFFFF, scaled_val, base + 17, WARP_SIZE); // -> high nibble of byte 1
+
+        if (lane_in_group == 0) {
+            // Convert 4 floats -> packed 16-bit FP4 in one step.
+            __nv_fp4x4_e2m1 fp4_packed(make_float4(val0, val1, val2, val3));
+            yqs2[pair_idx_in_block * 16 + b * 8 + group_id] = reinterpret_cast<const char2 &>(fp4_packed);
+        }
+#else
+        // Fallback: manual FP4 conversion using the LUT.
         const uint8_t q_val = ggml_cuda_float_to_fp4_e2m1(xi, inv_s);
 
         const uint8_t q_lo_0 = __shfl_sync(0xFFFFFFFF, q_val, base, WARP_SIZE);
@@ -135,6 +155,7 @@ static __global__ void quantize_mmq_mxfp4(const float * __restrict__ x,
             q.y = (q_hi_1 << 4) | q_lo_1;
             yqs2[pair_idx_in_block * 16 + b * 8 + group_id] = q;
         }
+#endif // CUDART_VERSION >= 12040
     }
 
     if (lane_id_32 == 0) {
diff --git a/ggml/src/ggml-cuda/vendors/cuda.h b/ggml/src/ggml-cuda/vendors/cuda.h
index 3b3086778e..5ba6227177 100644
--- a/ggml/src/ggml-cuda/vendors/cuda.h
+++ b/ggml/src/ggml-cuda/vendors/cuda.h
@@ -10,6 +10,10 @@
 #include <cuda_fp8.h>
 #endif // CUDART_VERSION >= 12050
 
+#if CUDART_VERSION >= 12040
+#include <cuda_fp4.h>
+#endif // CUDART_VERSION >= 12040
+
 #if CUDART_VERSION < 11020
 #define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED CU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED
 #define CUBLAS_TF32_TENSOR_OP_MATH CUBLAS_TENSOR_OP_MATH
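
For reviewers who want to confirm the nibble order the new comments claim, here is a minimal standalone sanity check. It is a sketch, not part of the patch: the file and kernel names (`pack_fp4x4_check.cu`, `pack_fp4x4`) are made up, and it assumes a toolkit that ships `cuda_fp4.h` (note the patch guards on `CUDART_VERSION >= 12040`; verify the header actually exists at that version in your toolkit). It packs four floats through the `__nv_fp4x4_e2m1` constructor, copies the raw 16-bit storage back to the host, and decodes each nibble with the E2M1 value table.

```cpp
// pack_fp4x4_check.cu -- hypothetical sanity check, not part of the patch.
// Build: nvcc pack_fp4x4_check.cu -o pack_fp4x4_check (requires cuda_fp4.h)
#include <cstdio>
#include <cstdint>
#include <cuda_runtime.h>
#include <cuda_fp4.h>

// Pack four floats via the __nv_fp4x4_e2m1 constructor and expose the raw
// 16-bit storage so the host can inspect the nibble order.
static __global__ void pack_fp4x4(const float4 v, uint16_t * out) {
    __nv_fp4x4_e2m1 packed(v);
    *out = reinterpret_cast<const uint16_t &>(packed);
}

// Decode one E2M1 nibble: bit 3 is the sign, bits 2:0 index the magnitude.
static float fp4_e2m1_to_float(uint8_t nibble) {
    static const float lut[8] = {0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f};
    const float mag = lut[nibble & 0x7];
    return (nibble & 0x8) ? -mag : mag;
}

int main() {
    uint16_t * d_out = nullptr;
    cudaMalloc(&d_out, sizeof(*d_out));

    pack_fp4x4<<<1, 1>>>(make_float4(1.0f, -2.0f, 3.0f, 6.0f), d_out);

    uint16_t raw = 0;
    cudaMemcpy(&raw, d_out, sizeof(raw), cudaMemcpyDeviceToHost);
    cudaFree(d_out);

    // If the packing matches the patch comments (bits [3:0]=x ... [15:12]=w),
    // this prints the inputs back: x=1 y=-2 z=3 w=6.
    printf("raw = 0x%04x\n", raw);
    printf("x=%g y=%g z=%g w=%g\n",
           fp4_e2m1_to_float((raw >>  0) & 0xF),
           fp4_e2m1_to_float((raw >>  4) & 0xF),
           fp4_e2m1_to_float((raw >>  8) & 0xF),
           fp4_e2m1_to_float((raw >> 12) & 0xF));
    return 0;
}
```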
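
The `#else` path relies on `ggml_cuda_float_to_fp4_e2m1`, which is outside the hunks shown. As a reading aid only, here is a hypothetical sketch of what a LUT-based float -> FP4 E2M1 conversion can look like; `float_to_fp4_e2m1_sketch` is an invented name, and ggml's real helper may differ, in particular in how it rounds ties.

```cpp
// float -> FP4 E2M1 via nearest-match over the 8 representable magnitudes.
// Hypothetical stand-in for ggml_cuda_float_to_fp4_e2m1; the real helper may
// round differently (e.g. ties-to-even).
#include <cstdio>
#include <cstdint>
#include <cmath>
#include <cuda_runtime.h>

static __host__ __device__ uint8_t float_to_fp4_e2m1_sketch(float x, float inv_s) {
    const float v = x * inv_s; // scale into the E2M1 range
    const float a = fabsf(v);

    // The 8 E2M1 magnitudes (2-bit exponent, 1-bit mantissa).
    constexpr float lut[8] = {0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f};

    // Pick the nearest representable magnitude.
    uint8_t best     = 0;
    float   best_err = fabsf(a - lut[0]);
    for (uint8_t i = 1; i < 8; ++i) {
        const float err = fabsf(a - lut[i]);
        if (err < best_err) {
            best_err = err;
            best     = i;
        }
    }

    // Bit 3 carries the sign, bits 2:0 the magnitude code.
    return best | (v < 0.0f ? 0x8 : 0x0);
}

int main() {
    // Quantize a few values at scale 1 (inv_s = 1) and print the nibble codes.
    const float vals[] = {0.0f, 0.4f, 1.2f, -2.9f, 5.0f, 7.0f};
    for (float v : vals) {
        printf("%5.2f -> 0x%x\n", v, float_to_fp4_e2m1_sketch(v, 1.0f));
    }
    return 0;
}
```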