This commit is contained in:
Michael Wand 2026-03-24 02:05:02 -07:00 committed by GitHub
commit 6fc268c673
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 160 additions and 2 deletions

View File

@ -931,6 +931,13 @@ struct ggml_cuda_type_traits<GGML_TYPE_MXFP4> {
static constexpr int qi = QI_MXFP4;
};
template<>
struct ggml_cuda_type_traits<GGML_TYPE_NVFP4> {
static constexpr int qk = QK_NVFP4;
static constexpr int qr = QR_NVFP4;
static constexpr int qi = QI_NVFP4;
};
template<>
struct ggml_cuda_type_traits<GGML_TYPE_Q2_K> {
static constexpr int qk = QK_K;

View File

@ -617,6 +617,45 @@ static void dequantize_row_mxfp4_cuda(const void * vx, dst_t * y, const int64_t
dequantize_block_mxfp4<<<nb, 32, 0, stream>>>(vx, y);
}
template <typename dst_t>
static __global__ void dequantize_block_nvfp4(
const void * __restrict__ vx,
dst_t * __restrict__ yy,
const int64_t ne) {
const int64_t i = blockIdx.x;
const int tid = threadIdx.x;
const int64_t base = i * QK_NVFP4;
if (base >= ne) {
return;
}
const block_nvfp4 * x = (const block_nvfp4 *) vx;
const block_nvfp4 & xb = x[i];
const int sub = tid / (QK_NVFP4_SUB / 2);
const int j = tid % (QK_NVFP4_SUB / 2);
const float d = ggml_cuda_cast<float>(ue4m3{xb.d[sub]});
const uint8_t q = xb.qs[sub * (QK_NVFP4_SUB / 2) + j];
const int64_t y0 = base + sub * QK_NVFP4_SUB + j;
const int64_t y1 = y0 + QK_NVFP4_SUB / 2;
yy[y0] = ggml_cuda_cast<dst_t>(d * kvalues_mxfp4[q & 0x0F]);
yy[y1] = ggml_cuda_cast<dst_t>(d * kvalues_mxfp4[q >> 4]);
}
template <typename dst_t>
static void dequantize_row_nvfp4_cuda(
const void * vx,
dst_t * y,
const int64_t k,
cudaStream_t stream) {
GGML_ASSERT(k % QK_NVFP4 == 0);
const int nb = k / QK_NVFP4;
dequantize_block_nvfp4<<<nb, 32, 0, stream>>>(vx, y, k);
}
template <typename src_t, typename dst_t>
static __global__ void convert_unary(
const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t ne00, const int64_t ne01,
@ -715,6 +754,8 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
return dequantize_row_iq3_s_cuda;
case GGML_TYPE_MXFP4:
return dequantize_row_mxfp4_cuda;
case GGML_TYPE_NVFP4:
return dequantize_row_nvfp4_cuda;
case GGML_TYPE_F32:
return convert_unary_cont_cuda<float>;
case GGML_TYPE_BF16:
@ -766,6 +807,8 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
return dequantize_row_iq3_s_cuda;
case GGML_TYPE_MXFP4:
return dequantize_row_mxfp4_cuda;
case GGML_TYPE_NVFP4:
return dequantize_row_nvfp4_cuda;
case GGML_TYPE_F16:
return convert_unary_cont_cuda<half>;
case GGML_TYPE_BF16:

View File

@ -31,6 +31,10 @@ to_fp32_nc_cuda_t ggml_get_to_fp32_nc_cuda(ggml_type type);
to_fp16_nc_cuda_t ggml_get_to_fp16_nc_cuda(ggml_type type);
to_bf16_nc_cuda_t ggml_get_to_bf16_nc_cuda(ggml_type type);
struct ue4m3 {
uint8_t x;
};
template<typename dst_t, typename src_t>
__host__ __device__ inline dst_t ggml_cuda_cast(src_t x) {
if constexpr (std::is_same_v<dst_t, src_t>) {
@ -58,6 +62,54 @@ template<typename dst_t, typename src_t>
#else
return {x.x, x.y};
#endif // GGML_USE_HIP
} else if constexpr (std::is_same_v<src_t, ue4m3> && std::is_same_v<dst_t, float>) {
#if defined(__CUDA_ARCH__)
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && defined(CUDART_VERSION) && CUDART_VERSION >= 12050 // This matches cuda_fp8.h's version gate.
// This uses the same fp8 conversion that __nv_fp8_e4m3 uses internally.
__half h = __half(__nv_cvt_fp8_to_halfraw((__nv_fp8_storage_t) x.x, __NV_E4M3));
unsigned short hb = __half_as_ushort(h) - 0x0400u; // Built in 0.5f op.
float f = __half2float(__ushort_as_half(hb));
// __nv_fp8_e4m3 is signed but UE4M3
if (x.x == 0u || x.x == 0x7Fu) { // force 0x7F and 0x00 to 0.0f
f = 0.0f;
}
return f;
#else
if (x.x == 0u || x.x == 0x7Fu) {
return 0.0f;
}
const uint32_t exp = x.x >> 3;
const uint32_t mant = x.x & 0x7u;
uint32_t bits;
if (exp != 0u) {
bits = ((exp + 119u) << 23) | (mant << 20);
} else {
uint32_t p;
if (mant < 2u) {
p = 0u;
} else if (mant < 4u) {
p = 1u;
} else {
p = 2u;
}
const uint32_t r = mant - (1u << p);
const uint32_t exp32 = p + 117u;
const uint32_t mant32 = r << (23u - p);
bits = (exp32 << 23) | mant32;
}
float f;
memcpy(&f, &bits, sizeof(f)); // 0.5f is baked in.
return f;
#endif
#else
return ggml_ue4m3_to_fp32(x.x);
#endif
} else if constexpr(std::is_same_v<dst_t, int32_t>) {
return int32_t(x);
} else {

View File

@ -1297,7 +1297,12 @@ static void ggml_cuda_op_mul_mat_cublas(
const bool supports_bf16 = GGML_CUDA_CC_IS_NVIDIA(cc) || GGML_CUDA_CC_IS_AMD(cc) ||
(GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_QY2);
const bool use_fp16 = (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT;
const bool use_fp16 =
src0->type != GGML_TYPE_NVFP4 &&
(src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) &&
ggml_is_contiguous(src0) &&
row_diff == src0->ne[1] &&
dst->op_params[0] == GGML_PREC_DEFAULT;
if (supports_bf16 && src0->type == GGML_TYPE_BF16 && ggml_is_contiguous(src0) && row_diff == src0->ne[1]) {
ggml_cuda_pool_alloc<nv_bfloat16> src1_as_bf16(ctx.pool(id));
@ -2281,6 +2286,11 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc);
}
// temporarily block MMQ for NVFP4
if (src0->type == GGML_TYPE_NVFP4) {
use_mul_mat_q = false;
}
// debug helpers
//printf("src0: %8d %8d %8d %8d\n", src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3]);
//printf(" %8d %8d %8d %8d\n", src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3]);
@ -2350,10 +2360,15 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
}
}
// this is temporary to block MMQ for now
if (src0->type != GGML_TYPE_NVFP4) {
if (ggml_cuda_should_use_mmq(src0->type, cc, ne12, /*n_experts=*/ne02)) {
ggml_cuda_mul_mat_q(ctx, src0, src1, ids, dst);
return;
}
}
if (ggml_cuda_should_use_mmf(src0->type, cc, WARP_SIZE, src0->ne, src0->nb, src1->ne[2], /*mul_mat_id=*/true)) {
ggml_cuda_mul_mat_f(ctx, src0, src1, ids, dst);
@ -4781,6 +4796,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_MXFP4:
case GGML_TYPE_NVFP4:
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:

View File

@ -15,6 +15,7 @@ static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type)
case GGML_TYPE_Q5_1: return vec_dot_q5_1_q8_1;
case GGML_TYPE_Q8_0: return vec_dot_q8_0_q8_1;
case GGML_TYPE_MXFP4: return vec_dot_mxfp4_q8_1;
case GGML_TYPE_NVFP4: return vec_dot_nvfp4_q8_1;
case GGML_TYPE_Q2_K: return vec_dot_q2_K_q8_1;
case GGML_TYPE_Q3_K: return vec_dot_q3_K_q8_1;
case GGML_TYPE_Q4_K: return vec_dot_q4_K_q8_1;
@ -41,6 +42,7 @@ static constexpr __host__ __device__ int get_vdr_mmvq(ggml_type type) {
case GGML_TYPE_Q5_1: return VDR_Q5_1_Q8_1_MMVQ;
case GGML_TYPE_Q8_0: return VDR_Q8_0_Q8_1_MMVQ;
case GGML_TYPE_MXFP4: return VDR_MXFP4_Q8_1_MMVQ;
case GGML_TYPE_NVFP4: return VDR_NVFP4_Q8_1_MMVQ;
case GGML_TYPE_Q2_K: return VDR_Q2_K_Q8_1_MMVQ;
case GGML_TYPE_Q3_K: return VDR_Q3_K_Q8_1_MMVQ;
case GGML_TYPE_Q4_K: return VDR_Q4_K_Q8_1_MMVQ;
@ -626,6 +628,12 @@ static void mul_mat_vec_q_switch_type(
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
break;
case GGML_TYPE_NVFP4:
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_NVFP4>
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
break;
case GGML_TYPE_Q2_K:
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q2_K>
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,

View File

@ -1,6 +1,6 @@
#pragma once
#include "common.cuh"
#include "convert.cuh"
#include <cstdint>
@ -322,6 +322,38 @@ static __device__ __forceinline__ float vec_dot_mxfp4_q8_1(
return d * sumi;
}
#define VDR_NVFP4_Q8_1_MMVQ 4
#define VDR_NVFP4_Q8_1_MMQ 8
static __device__ __forceinline__ float vec_dot_nvfp4_q8_1(
const void * __restrict__ vbq,
const block_q8_1 * __restrict__ bq8_1,
const int & kbx,
const int & iqs) {
const block_nvfp4 * bq4 = (const block_nvfp4 *) vbq + kbx;
float sum = 0.0f;
#pragma unroll
for (int i = 0; i < VDR_NVFP4_Q8_1_MMVQ/2; i++) {
const int iqs0 = iqs + 2*i;
const int iqs1 = iqs0 + 1;
const int is = iqs0 >> 1;
const int2 v0 = get_int_from_table_16(get_int_b4(bq4->qs, iqs0), kvalues_mxfp4);
const int2 v1 = get_int_from_table_16(get_int_b4(bq4->qs, iqs1), kvalues_mxfp4);
const block_q8_1 * bq8 = bq8_1 + (is >> 1);
const int i8 = ((is & 1) << 2);
int sumi = ggml_cuda_dp4a(v0.x, get_int_b4(bq8->qs, i8 + 0), 0);
sumi = ggml_cuda_dp4a(v0.y, get_int_b4(bq8->qs, i8 + 2), sumi);
sumi = ggml_cuda_dp4a(v1.x, get_int_b4(bq8->qs, i8 + 1), sumi);
sumi = ggml_cuda_dp4a(v1.y, get_int_b4(bq8->qs, i8 + 3), sumi);
const float d = ggml_cuda_cast<float>(ue4m3{bq4->d[is]}) * __low2float(bq8->ds);
sum += d * float(sumi);
}
return sum;
}
#define VDR_Q2_K_Q8_1_MMVQ 1
#define VDR_Q2_K_Q8_1_MMQ 4