diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh
index 36d8a3aaab..9f93c70d21 100644
--- a/ggml/src/ggml-cuda/common.cuh
+++ b/ggml/src/ggml-cuda/common.cuh
@@ -799,6 +799,16 @@ static __device__ __forceinline__ float ggml_cuda_e8m0_to_fp32(uint8_t x) {
 #endif // CUDART_VERSION >= 12050
 }
 
+static __device__ __forceinline__ float ggml_cuda_ue4m3_to_fp32(uint8_t x) {
+#ifdef FP8_AVAILABLE
+    const uint32_t bits = x * (x != 0x7F && x != 0xFF); // Convert NaN to 0.0f to match CPU implementation.
+    const __nv_fp8_e4m3 xf = *reinterpret_cast<const __nv_fp8_e4m3 *>(&bits);
+    return static_cast<float>(xf) / 2; // Halve to compensate for the doubled e2m1 entries in kvalues_mxfp4.
+#else
+    NO_DEVICE_CODE;
+#endif // FP8_AVAILABLE
+}
+
 __device__ __forceinline__ uint8_t ggml_cuda_float_to_fp4_e2m1(float x, float e) {
     const uint8_t sign_bit = (x < 0.0f) << 3;
     float ax = fabsf(x) * e;
@@ -931,6 +941,13 @@ struct ggml_cuda_type_traits<GGML_TYPE_MXFP4> {
     static constexpr int qi = QI_MXFP4;
 };
 
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_NVFP4> {
+    static constexpr int qk = QK_NVFP4;
+    static constexpr int qr = QR_NVFP4;
+    static constexpr int qi = QI_NVFP4;
+};
+
 template<>
 struct ggml_cuda_type_traits<GGML_TYPE_Q2_K> {
     static constexpr int qk = QK_K;
diff --git a/ggml/src/ggml-cuda/convert.cu b/ggml/src/ggml-cuda/convert.cu
index b70492c7d6..79ccfe568a 100644
--- a/ggml/src/ggml-cuda/convert.cu
+++ b/ggml/src/ggml-cuda/convert.cu
@@ -617,6 +617,45 @@ static void dequantize_row_mxfp4_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
     dequantize_block_mxfp4<<<nb, 32, 0, stream>>>(vx, y);
 }
 
+template <typename dst_t>
+static __global__ void dequantize_block_nvfp4(
+    const void * __restrict__ vx,
+    dst_t * __restrict__ yy,
+    const int64_t ne) {
+    const int64_t i = blockIdx.x;
+    const int tid = threadIdx.x;
+
+    const int64_t base = i * QK_NVFP4;
+    if (base >= ne) {
+        return;
+    }
+
+    const block_nvfp4 * x = (const block_nvfp4 *) vx;
+    const block_nvfp4 & xb = x[i];
+
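+    // Each NVFP4 block covers QK_NVFP4 values, split into QK_NVFP4/QK_NVFP4_SUB
+    // sub-blocks that carry their own UE4M3 scale in d[]. A thread decodes one
+    // packed byte: two FP4 values sitting QK_NVFP4_SUB/2 positions apart.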
+    const int sub = tid / (QK_NVFP4_SUB / 2);
+    const int j = tid % (QK_NVFP4_SUB / 2);
+
+    const float d = ggml_cuda_ue4m3_to_fp32(xb.d[sub]);
+    const uint8_t q = xb.qs[sub * (QK_NVFP4_SUB / 2) + j];
+
+    const int64_t y0 = base + sub * QK_NVFP4_SUB + j;
+    const int64_t y1 = y0 + QK_NVFP4_SUB / 2;
+
+    yy[y0] = ggml_cuda_cast<dst_t>(d * kvalues_mxfp4[q & 0x0F]);
+    yy[y1] = ggml_cuda_cast<dst_t>(d * kvalues_mxfp4[q >> 4]);
+}
+
+template <typename dst_t>
+static void dequantize_row_nvfp4_cuda(
+    const void * vx,
+    dst_t * y,
+    const int64_t k,
+    cudaStream_t stream) {
+    GGML_ASSERT(k % QK_NVFP4 == 0);
+    const int nb = k / QK_NVFP4;
+    dequantize_block_nvfp4<<<nb, QK_NVFP4/2, 0, stream>>>(vx, y, k);
+}
 template <typename src_t, typename dst_t>
 static __global__ void convert_unary(
     const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t ne00, const int64_t ne01,
@@ -715,6 +754,8 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
             return dequantize_row_iq3_s_cuda;
         case GGML_TYPE_MXFP4:
            return dequantize_row_mxfp4_cuda;
+        case GGML_TYPE_NVFP4:
+            return dequantize_row_nvfp4_cuda;
         case GGML_TYPE_F32:
            return convert_unary_cont_cuda<float>;
         case GGML_TYPE_BF16:
@@ -766,6 +807,8 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
             return dequantize_row_iq3_s_cuda;
         case GGML_TYPE_MXFP4:
            return dequantize_row_mxfp4_cuda;
+        case GGML_TYPE_NVFP4:
+            return dequantize_row_nvfp4_cuda;
         case GGML_TYPE_F16:
            return convert_unary_cont_cuda<half>;
         case GGML_TYPE_BF16:
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index a31e843e15..cc80eb3ffc 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -1297,7 +1297,12 @@ static void ggml_cuda_op_mul_mat_cublas(
     const bool supports_bf16 = GGML_CUDA_CC_IS_NVIDIA(cc) || GGML_CUDA_CC_IS_AMD(cc) ||
         (GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_QY2);
 
-    const bool use_fp16 = (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT;
+    const bool use_fp16 =
+        src0->type != GGML_TYPE_NVFP4 &&
+        (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) &&
+        ggml_is_contiguous(src0) &&
+        row_diff == src0->ne[1] &&
+        dst->op_params[0] == GGML_PREC_DEFAULT;
 
     if (supports_bf16 && src0->type == GGML_TYPE_BF16 && ggml_is_contiguous(src0) && row_diff == src0->ne[1]) {
         ggml_cuda_pool_alloc<nv_bfloat16> src1_as_bf16(ctx.pool(id));
@@ -4781,6 +4786,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
         case GGML_TYPE_Q5_1:
         case GGML_TYPE_Q8_0:
         case GGML_TYPE_MXFP4:
+#ifdef FP8_AVAILABLE
+        case GGML_TYPE_NVFP4:
+#endif // FP8_AVAILABLE
         case GGML_TYPE_Q2_K:
         case GGML_TYPE_Q3_K:
         case GGML_TYPE_Q4_K:
diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu
index 024b3d8cf2..66bd8beeae 100644
--- a/ggml/src/ggml-cuda/mmvq.cu
+++ b/ggml/src/ggml-cuda/mmvq.cu
@@ -15,6 +15,7 @@ static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type) {
         case GGML_TYPE_Q5_1:  return vec_dot_q5_1_q8_1;
         case GGML_TYPE_Q8_0:  return vec_dot_q8_0_q8_1;
         case GGML_TYPE_MXFP4: return vec_dot_mxfp4_q8_1;
+        case GGML_TYPE_NVFP4: return vec_dot_nvfp4_q8_1;
         case GGML_TYPE_Q2_K:  return vec_dot_q2_K_q8_1;
         case GGML_TYPE_Q3_K:  return vec_dot_q3_K_q8_1;
         case GGML_TYPE_Q4_K:  return vec_dot_q4_K_q8_1;
@@ -41,6 +42,7 @@ static constexpr __host__ __device__ int get_vdr_mmvq(ggml_type type) {
         case GGML_TYPE_Q5_1:  return VDR_Q5_1_Q8_1_MMVQ;
         case GGML_TYPE_Q8_0:  return VDR_Q8_0_Q8_1_MMVQ;
         case GGML_TYPE_MXFP4: return VDR_MXFP4_Q8_1_MMVQ;
+        case GGML_TYPE_NVFP4: return VDR_NVFP4_Q8_1_MMVQ;
         case GGML_TYPE_Q2_K:  return VDR_Q2_K_Q8_1_MMVQ;
         case GGML_TYPE_Q3_K:  return VDR_Q3_K_Q8_1_MMVQ;
         case GGML_TYPE_Q4_K:  return VDR_Q4_K_Q8_1_MMVQ;
@@ -626,6 +628,12 @@ static void mul_mat_vec_q_switch_type(
                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
             break;
+        case GGML_TYPE_NVFP4:
+            mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_NVFP4>
+                (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
+                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
+            break;
         case GGML_TYPE_Q2_K:
             mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q2_K>
                 (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
diff --git a/ggml/src/ggml-cuda/vecdotq.cuh b/ggml/src/ggml-cuda/vecdotq.cuh
index ab803aca21..40b2b41e7e 100644
--- a/ggml/src/ggml-cuda/vecdotq.cuh
+++ b/ggml/src/ggml-cuda/vecdotq.cuh
@@ -322,6 +322,38 @@ static __device__ __forceinline__ float vec_dot_mxfp4_q8_1(
     return d * sumi;
 }
 
+#define VDR_NVFP4_Q8_1_MMVQ 4
+#define VDR_NVFP4_Q8_1_MMQ  8
+
+static __device__ __forceinline__ float vec_dot_nvfp4_q8_1(
+    const void * __restrict__ vbq,
+    const block_q8_1 * __restrict__ bq8_1,
+    const int32_t & kbx,
+    const int32_t & iqs) {
+
+    const block_nvfp4 * bq4 = (const block_nvfp4 *) vbq + kbx;
+    float sum = 0.0f;
+#pragma unroll
+    for (int i = 0; i < VDR_NVFP4_Q8_1_MMVQ/2; i++) {
+        const int32_t iqs0 = iqs + 2*i;
+        const int32_t iqs1 = iqs0 + 1;
+        const int32_t is = iqs0 >> 1;
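+        // Decode two 32-bit words of packed FP4: get_int_from_table_16 expands the
+        // eight nibbles of each word into eight int8 values (an int2) via the shared
+        // MXFP4 e2m1 table, so each 4-value group reduces to a single dp4a below.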
+        const int2 v0 = get_int_from_table_16(get_int_b4(bq4->qs, iqs0), kvalues_mxfp4);
+        const int2 v1 = get_int_from_table_16(get_int_b4(bq4->qs, iqs1), kvalues_mxfp4);
+        const block_q8_1 * bq8 = bq8_1 + (is >> 1);
+        const int32_t i8 = ((is & 1) << 2);
+
+        int sumi = ggml_cuda_dp4a(v0.x, get_int_b4(bq8->qs, i8 + 0), 0);
+        sumi = ggml_cuda_dp4a(v0.y, get_int_b4(bq8->qs, i8 + 2), sumi);
+        sumi = ggml_cuda_dp4a(v1.x, get_int_b4(bq8->qs, i8 + 1), sumi);
+        sumi = ggml_cuda_dp4a(v1.y, get_int_b4(bq8->qs, i8 + 3), sumi);
+
+        const float d = ggml_cuda_ue4m3_to_fp32(bq4->d[is]) * __low2float(bq8->ds);
+        sum += d * float(sumi);
+    }
+
+    return sum;
+}
+
 #define VDR_Q2_K_Q8_1_MMVQ 1
 #define VDR_Q2_K_Q8_1_MMQ  4
diff --git a/ggml/src/ggml-cuda/vendors/cuda.h b/ggml/src/ggml-cuda/vendors/cuda.h
index ba032cfab4..07bc47df3b 100644
--- a/ggml/src/ggml-cuda/vendors/cuda.h
+++ b/ggml/src/ggml-cuda/vendors/cuda.h
@@ -6,9 +6,10 @@
 #include <cuda_bf16.h>
 #include <cuda_fp16.h>
 
-#if CUDART_VERSION >= 12050
-#include <cuda_fp8.h>
-#endif // CUDART_VERSION >= 12050
+#if CUDART_VERSION >= 11080
+#include <cuda_fp8.h>
+#define FP8_AVAILABLE
+#endif // CUDART_VERSION >= 11080
 
 #if CUDART_VERSION >= 12080
 #include <cuda_fp4.h>
diff --git a/ggml/src/ggml-cuda/vendors/hip.h b/ggml/src/ggml-cuda/vendors/hip.h
index 35d1e1a063..9d9ba1ee21 100644
--- a/ggml/src/ggml-cuda/vendors/hip.h
+++ b/ggml/src/ggml-cuda/vendors/hip.h
@@ -235,6 +235,12 @@
 typedef __hip_bfloat16 nv_bfloat16;
 typedef __hip_bfloat162 nv_bfloat162;
 
+#if HIP_VERSION >= 60200000
+#include <hip/hip_fp8.h>
+typedef __hip_fp8_e4m3 __nv_fp8_e4m3;
+#define FP8_AVAILABLE
+#endif // HIP_VERSION >= 60200000
+
 typedef int8_t int8x4_t __attribute__((ext_vector_type(4)));
 typedef uint8_t uint8x4_t __attribute__((ext_vector_type(4)));
 static __device__ __forceinline__ int __vsubss4(const int a, const int b) {
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index ac284db2b8..6a4f9b634b 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -7284,7 +7284,7 @@ static const ggml_type all_types[] = {
     GGML_TYPE_Q4_0, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1,
     GGML_TYPE_Q8_0,
-    GGML_TYPE_MXFP4,
+    GGML_TYPE_MXFP4, GGML_TYPE_NVFP4,
     GGML_TYPE_Q2_K, GGML_TYPE_Q3_K, GGML_TYPE_Q4_K, GGML_TYPE_Q5_K,
     GGML_TYPE_Q6_K,
@@ -7300,7 +7300,7 @@ static const ggml_type base_types[] = {
     GGML_TYPE_Q4_0, GGML_TYPE_Q4_1, // for I8MM tests
     GGML_TYPE_Q4_K,
-    GGML_TYPE_MXFP4, // TODO: or "other"
+    GGML_TYPE_MXFP4, GGML_TYPE_NVFP4, // TODO: or "other"
     GGML_TYPE_IQ2_XXS
 };