use __nv_fp4x4_e2m1
This commit is contained in:
parent
928cc5594f
commit
a1672f620b
|
|
@ -122,6 +122,26 @@ static __global__ void quantize_mmq_mxfp4(const float * __restrict__ x,
|
|||
const uint8_t e = compute_e8m0_scale(amax);
|
||||
scales[b] = e;
|
||||
const float inv_s = (amax == 0.0f) ? 0.0f : __frcp_rn(ggml_cuda_e8m0_to_fp32(e));
|
||||
|
||||
#if CUDART_VERSION >= 12040
|
||||
// Use hardware FP4 conversion: pre-scale and gather 4 floats, then convert+pack
|
||||
const float scaled_val = xi * inv_s;
|
||||
|
||||
// Gather 4 scaled floats in the order matching __nv_fp4x4_e2m1 packing:
|
||||
// float4(x,y,z,w) -> 16-bit with bits [3:0]=x, [7:4]=y, [11:8]=z, [15:12]=w
|
||||
// This produces byte0 = (y<<4)|x, byte1 = (w<<4)|z
|
||||
const float val0 = __shfl_sync(0xFFFFFFFF, scaled_val, base, WARP_SIZE); // -> low nibble byte 0
|
||||
const float val1 = __shfl_sync(0xFFFFFFFF, scaled_val, base + 16, WARP_SIZE); // -> high nibble byte 0
|
||||
const float val2 = __shfl_sync(0xFFFFFFFF, scaled_val, base + 1, WARP_SIZE); // -> low nibble byte 1
|
||||
const float val3 = __shfl_sync(0xFFFFFFFF, scaled_val, base + 17, WARP_SIZE); // -> high nibble byte 1
|
||||
|
||||
if (lane_in_group == 0) {
|
||||
// Convert 4 floats -> packed 16-bit FP4 in one step
|
||||
__nv_fp4x4_e2m1 fp4_packed(make_float4(val0, val1, val2, val3));
|
||||
yqs2[pair_idx_in_block * 16 + b * 8 + group_id] = reinterpret_cast<const char2&>(fp4_packed);
|
||||
}
|
||||
#else
|
||||
// Fallback: manual FP4 conversion using LUT
|
||||
const uint8_t q_val = ggml_cuda_float_to_fp4_e2m1(xi, inv_s);
|
||||
|
||||
const uint8_t q_lo_0 = __shfl_sync(0xFFFFFFFF, q_val, base, WARP_SIZE);
|
||||
|
|
@ -135,6 +155,7 @@ static __global__ void quantize_mmq_mxfp4(const float * __restrict__ x,
|
|||
q.y = (q_hi_1 << 4) | q_lo_1;
|
||||
yqs2[pair_idx_in_block * 16 + b * 8 + group_id] = q;
|
||||
}
|
||||
#endif // CUDART_VERSION >= 12040
|
||||
}
|
||||
|
||||
if (lane_id_32 == 0) {
|
||||
|
|
|
|||
|
|
@ -10,6 +10,10 @@
|
|||
#include <cuda_fp8.h>
|
||||
#endif // CUDART_VERSION >= 12050
|
||||
|
||||
#if CUDART_VERSION >= 12040
|
||||
#include <cuda_fp4.h>
|
||||
#endif // CUDART_VERSION >= 12040
|
||||
|
||||
#if CUDART_VERSION < 11020
|
||||
#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED CU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED
|
||||
#define CUBLAS_TF32_TENSOR_OP_MATH CUBLAS_TENSOR_OP_MATH
|
||||
|
|
|
|||
Loading…
Reference in New Issue