diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp index 72ad876d5e..6e819f199b 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.cpp +++ b/ggml/src/ggml-metal/ggml-metal-device.cpp @@ -764,6 +764,12 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv(ggml_meta nr0 = N_R0_MXFP4; smem = 32*sizeof(float); } break; + case GGML_TYPE_NVFP4: + { + nsg = N_SG_NVFP4; + nr0 = N_R0_NVFP4; + smem = (32 + 128)*sizeof(float); + } break; case GGML_TYPE_Q2_K: { nsg = N_SG_Q2_K; @@ -976,6 +982,12 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_id(ggml_m nr0 = N_R0_MXFP4; smem = 32*sizeof(float); } break; + case GGML_TYPE_NVFP4: + { + nsg = N_SG_NVFP4; + nr0 = N_R0_NVFP4; + smem = (32 + 128)*sizeof(float); + } break; case GGML_TYPE_Q2_K: { nsg = N_SG_Q2_K; diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index 82101f4714..9f44229d9d 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -1161,7 +1161,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te case GGML_OP_SOLVE_TRI: case GGML_OP_MUL_MAT: case GGML_OP_MUL_MAT_ID: - return has_simdgroup_reduction && op->src[0]->type != GGML_TYPE_NVFP4; + return has_simdgroup_reduction; case GGML_OP_SET: case GGML_OP_CPY: case GGML_OP_DUP: @@ -1219,7 +1219,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te }; } case GGML_OP_GET_ROWS: - return op->src[0]->type != GGML_TYPE_NVFP4; + return true; case GGML_OP_SET_ROWS: { if (op->src[0]->type != GGML_TYPE_F32) { diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h index 53437b23cd..07bad47a57 100644 --- a/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -26,6 +26,9 @@ #define N_R0_MXFP4 2 #define N_SG_MXFP4 2 +#define N_R0_NVFP4 2 +#define N_SG_NVFP4 2 + #define 
N_R0_Q2_K 4 #define N_SG_Q2_K 2 diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index c0bcad392b..dedf47a76e 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -2049,6 +2049,7 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) { op->src[0]->type == GGML_TYPE_Q5_1 || op->src[0]->type == GGML_TYPE_Q8_0 || op->src[0]->type == GGML_TYPE_MXFP4 || + op->src[0]->type == GGML_TYPE_NVFP4 || op->src[0]->type == GGML_TYPE_IQ4_NL || false) && (ne11 >= 2 && ne11 <= 8) ) || diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index b2328605dd..1d288d01bf 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -50,6 +50,36 @@ constexpr constant static float kvalues_mxfp4_f[16] = { 0, .5f, 1.f, 1.5f, 2.f, 3.f, 4.f, 6.f, -0, -.5f, -1.f, -1.5f, -2.f, -3.f, -4.f, -6.f }; +// UE4M3 scale lookup table (128 entries, unsigned E4M3 with bias=7) +// Computed at compile time from the UE4M3 format definition +struct ue4m3_table { + float data[128]; + constexpr ue4m3_table() : data{} { + for (int x = 0; x < 128; ++x) { + if (x == 0 || x == 0x7F) { + data[x] = 0.0f; + } else { + int exp = (x >> 3) & 0xF; + int man = x & 0x7; + if (exp == 0) { + // subnormal: man * 2^(-9) + data[x] = (float)man / 512.0f; + } else { + // normal: (1 + man/8) * 2^(exp-7) + float mantissa = 1.0f + (float)man / 8.0f; + float scale = 1.0f; + int e = exp - 7; + if (e > 0) { for (int i = 0; i < e; ++i) scale *= 2.0f; } + if (e < 0) { for (int i = 0; i < -e; ++i) scale *= 0.5f; } + data[x] = mantissa * scale; + } + } + } + } +}; +constexpr constant static ue4m3_table kvalues_ue4m3_table = ue4m3_table(); +#define kvalues_ue4m3_f kvalues_ue4m3_table.data + static inline int best_index_int8(int n, constant float * val, float x) { if (x <= val[0]) return 0; if (x >= val[n-1]) return n-1; @@ -73,6 +103,22 @@ static inline float e8m0_to_fp32(uint8_t 
x) { return as_type(bits); } +// UE4M3: unsigned, 4 exp bits (bias=7), 3 mantissa bits +// Only used to populate the threadgroup LUT in kernel_mul_mv_nvfp4_f32 +static inline float ue4m3_to_fp32(uint8_t x) { + if (x == 0 || x == 0x7F) { // 0x7F is NaN in E4M3; return 0 to match kvalues_ue4m3_table + return 0.0f; + } + int exp = (x >> 3) & 0xF; + int man = x & 0x7; + if (exp == 0) { + return ldexp((float)man, -9); + } + // Normal: construct FP32 directly. exp-7+127 = exp+120 + uint32_t bits = (uint32_t)(exp + 120) << 23 | (uint32_t)man << 20; + return as_type(bits); +} + static inline float dot(float x, float y) { return x*y; } @@ -557,6 +603,49 @@ void dequantize_mxfp4_t4(device const block_mxfp4 * xb, short il, thread type4 & reg[3] = d * kvalues_mxfp4_f[(q2[4*il4 + 3] >> shr) & 0x0F]; } +// NVFP4 dequantize: block_nvfp4 has 64 elements = 4 sub-blocks of 16 +// nl=4, il=0..3 selects sub-block. Each call produces 4x4 = 16 elements. +// CPU layout per sub-block: low nibbles of 8 bytes -> elements 0..7, +// high nibbles -> elements 8..15 +template +void dequantize_nvfp4(device const block_nvfp4 * xb, short il, thread type4x4 & reg) { + device const uint8_t * qs = xb->qs + il * (QK_NVFP4_SUB/2); // 8 bytes per sub-block + const float d = kvalues_ue4m3_f[xb->d[il]]; + + // rows 0-1: low nibbles (elements 0..7) + for (int i = 0; i < 2; ++i) { + reg[i][0] = d * kvalues_mxfp4_f[qs[4*i + 0] & 0x0F]; + reg[i][1] = d * kvalues_mxfp4_f[qs[4*i + 1] & 0x0F]; + reg[i][2] = d * kvalues_mxfp4_f[qs[4*i + 2] & 0x0F]; + reg[i][3] = d * kvalues_mxfp4_f[qs[4*i + 3] & 0x0F]; + } + // rows 2-3: high nibbles (elements 8..15) + for (int i = 0; i < 2; ++i) { + reg[i+2][0] = d * kvalues_mxfp4_f[qs[4*i + 0] >> 4]; + reg[i+2][1] = d * kvalues_mxfp4_f[qs[4*i + 1] >> 4]; + reg[i+2][2] = d * kvalues_mxfp4_f[qs[4*i + 2] >> 4]; + reg[i+2][3] = d * kvalues_mxfp4_f[qs[4*i + 3] >> 4]; + } +} + +// t4 variant: chpb=16, il=0..15, each call produces 4 elements. 
+// Within each sub-block (4 calls): lo-row0, lo-row1, hi-row0, hi-row1 +template +void dequantize_nvfp4_t4(device const block_nvfp4 * xb, short il, thread type4 & reg) { + const short sub = il / 4; // sub-block index 0..3 + const short rem = il % 4; // 0..3 within sub-block + const short row = rem % 2; // byte row: 0 = bytes 0..3, 1 = bytes 4..7 + const uint8_t shr = rem >= 2 ? 4 : 0; // 0,1 = low nibble; 2,3 = high nibble + + device const uint8_t * qs = xb->qs + sub * (QK_NVFP4_SUB/2) + row * 4; + const float d = kvalues_ue4m3_f[xb->d[sub]]; + + reg[0] = d * kvalues_mxfp4_f[(qs[0] >> shr) & 0x0F]; + reg[1] = d * kvalues_mxfp4_f[(qs[1] >> shr) & 0x0F]; + reg[2] = d * kvalues_mxfp4_f[(qs[2] >> shr) & 0x0F]; + reg[3] = d * kvalues_mxfp4_f[(qs[3] >> shr) & 0x0F]; +} + template void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) { const float d = xb->d; @@ -3743,6 +3832,11 @@ template [[host_name("kernel_mul_mv_ext_mxfp4_f32_r1_3")]] kernel mul_mv_ext_q4 template [[host_name("kernel_mul_mv_ext_mxfp4_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_mxfp4, 32, dequantize_mxfp4_t4>; template [[host_name("kernel_mul_mv_ext_mxfp4_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_mxfp4, 32, dequantize_mxfp4_t4>; +template [[host_name("kernel_mul_mv_ext_nvfp4_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_nvfp4, 64, dequantize_nvfp4_t4>; +template [[host_name("kernel_mul_mv_ext_nvfp4_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_nvfp4, 64, dequantize_nvfp4_t4>; +template [[host_name("kernel_mul_mv_ext_nvfp4_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_nvfp4, 64, dequantize_nvfp4_t4>; +template [[host_name("kernel_mul_mv_ext_nvfp4_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_nvfp4, 64, dequantize_nvfp4_t4>; + template 
[[host_name("kernel_mul_mv_ext_iq4_nl_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_iq4_nl, 32, dequantize_iq4_nl_t4>; template [[host_name("kernel_mul_mv_ext_iq4_nl_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_iq4_nl, 32, dequantize_iq4_nl_t4>; template [[host_name("kernel_mul_mv_ext_iq4_nl_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_iq4_nl, 32, dequantize_iq4_nl_t4>; @@ -8807,6 +8901,106 @@ kernel void kernel_mul_mv_mxfp4_f32( kernel_mul_mv_mxfp4_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } +template +void kernel_mul_mv_nvfp4_f32_impl( + args_t args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem, + uint3 tgpig, + ushort tiisg, + ushort sgitg) { + const short NSG = FC_mul_mv_nsg; + + threadgroup float * shmem_f32 = (threadgroup float *) shmem; + + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; + + const int first_row = (r0 * NSG + sgitg) * NR0; + + const uint i12 = im%args.ne12; + const uint i13 = im/args.ne12; + + const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + + device const block_nvfp4 * x = (device const block_nvfp4 *) (src0 + offset0); + device const float * y = (device const float *) (src1 + offset1); + + const int nb = args.ne00/QK_NVFP4; + const int ns01 = args.nb01/args.nb00; + + const short ix = tiisg/2; // 0...15 + const short it = tiisg%2; // 0 or 1 + + shmem_f32[tiisg] = kvalues_mxfp4_f[tiisg%16]; + + // UE4M3 scale LUT: 128 entries, loaded cooperatively by all threads + threadgroup float * ue4m3_lut = shmem_f32 + 32; + ue4m3_lut[sgitg*32 + tiisg] = ue4m3_to_fp32(sgitg*32 + tiisg); + ue4m3_lut[sgitg*32 + tiisg + 64] = ue4m3_to_fp32(sgitg*32 + tiisg + 64); + threadgroup_barrier(mem_flags::mem_threadgroup); + + float 
sumf[NR0] = {0.f}; + + // each thread handles 4 bytes per sub-block (half of the 8 bytes) + // lo nibbles -> y at offset it*4, hi nibbles -> y at offset 8+it*4 + device const float * yb = y + ix*QK_NVFP4; + + float4 yl[8]; // pre-loaded y: 4 sub-blocks × {lo, hi} + + for (int ib = ix; ib < nb && ib < ns01; ib += 16) { + // pre-load y into registers before the row loop + FOR_UNROLL (short s = 0; s < 4; s++) { + yl[2*s + 0] = *(device const float4 *)(yb + s*QK_NVFP4_SUB + it*4); + yl[2*s + 1] = *(device const float4 *)(yb + s*QK_NVFP4_SUB + QK_NVFP4_SUB/2 + it*4); + } + + FOR_UNROLL (short row = 0; row < NR0; row++) { + device const block_nvfp4 & xb = x[row*ns01 + ib]; + + float sub_sum = 0.0f; + FOR_UNROLL (short s = 0; s < 4; s++) { + device const uint8_t * qs = (device const uint8_t *)(xb.qs + s*(QK_NVFP4_SUB/2) + it*4); + + float4 acc = yl[2*s]*float4(shmem_f32[qs[0] & 0x0F], shmem_f32[qs[1] & 0x0F], shmem_f32[qs[2] & 0x0F], shmem_f32[qs[3] & 0x0F]) + + yl[2*s+1]*float4(shmem_f32[qs[0] >> 4], shmem_f32[qs[1] >> 4], shmem_f32[qs[2] >> 4], shmem_f32[qs[3] >> 4]); + + sub_sum += ue4m3_lut[xb.d[s]] * ((acc[0] + acc[1]) + (acc[2] + acc[3])); + } + sumf[row] += sub_sum; + } + + yb += 16 * QK_NVFP4; + } + + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; + + for (int row = 0; row < NR0 && first_row + row < args.ne0; ++row) { + float sum_all = simd_sum(sumf[row]); + if (tiisg == 0) { + dst_f32[first_row + row] = sum_all; + } + } +} + +[[host_name("kernel_mul_mv_nvfp4_f32")]] +kernel void kernel_mul_mv_nvfp4_f32( + constant ggml_metal_kargs_mul_mv & args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + + kernel_mul_mv_nvfp4_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); +} + template 
kernel void kernel_get_rows_q( constant ggml_metal_kargs_get_rows & args, @@ -9655,6 +9849,7 @@ template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_q_t kernel_get template [[host_name("kernel_get_rows_q5_1")]] kernel get_rows_q_t kernel_get_rows_q; template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_q_t kernel_get_rows_q; template [[host_name("kernel_get_rows_mxfp4")]] kernel get_rows_q_t kernel_get_rows_q; +template [[host_name("kernel_get_rows_nvfp4")]] kernel get_rows_q_t kernel_get_rows_q; template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_q_t kernel_get_rows_q; template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_q_t kernel_get_rows_q; template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_q_t kernel_get_rows_q; @@ -9717,6 +9912,7 @@ template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mul_mm_t kernel_mul_m template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mul_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mul_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_mxfp4_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_nvfp4_f32")]] kernel mul_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mul_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mul_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mul_mm_t kernel_mul_mm; @@ -9740,6 +9936,7 @@ template [[host_name("kernel_mul_mm_q5_0_f16")]] kernel mul_mm_t kernel_mul_m template [[host_name("kernel_mul_mm_q5_1_f16")]] kernel mul_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q8_0_f16")]] kernel mul_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_mxfp4_f16")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_nvfp4_f16")]] kernel mul_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q2_K_f16")]] kernel mul_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q3_K_f16")]] 
kernel mul_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q4_K_f16")]] kernel mul_mm_t kernel_mul_mm; @@ -9772,6 +9969,7 @@ template [[host_name("kernel_mul_mm_id_q5_0_f32")]] kernel mul_mm_id kernel_m template [[host_name("kernel_mul_mm_id_q5_1_f32")]] kernel mul_mm_id kernel_mul_mm_id; template [[host_name("kernel_mul_mm_id_q8_0_f32")]] kernel mul_mm_id kernel_mul_mm_id; template [[host_name("kernel_mul_mm_id_mxfp4_f32")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_nvfp4_f32")]] kernel mul_mm_id kernel_mul_mm_id; template [[host_name("kernel_mul_mm_id_q2_K_f32")]] kernel mul_mm_id kernel_mul_mm_id; template [[host_name("kernel_mul_mm_id_q3_K_f32")]] kernel mul_mm_id kernel_mul_mm_id; template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mul_mm_id kernel_mul_mm_id; @@ -9795,6 +9993,7 @@ template [[host_name("kernel_mul_mm_id_q5_0_f16")]] kernel mul_mm_id kernel_m template [[host_name("kernel_mul_mm_id_q5_1_f16")]] kernel mul_mm_id kernel_mul_mm_id; template [[host_name("kernel_mul_mm_id_q8_0_f16")]] kernel mul_mm_id kernel_mul_mm_id; template [[host_name("kernel_mul_mm_id_mxfp4_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_nvfp4_f16")]] kernel mul_mm_id kernel_mul_mm_id; template [[host_name("kernel_mul_mm_id_q2_K_f16")]] kernel mul_mm_id kernel_mul_mm_id; template [[host_name("kernel_mul_mm_id_q3_K_f16")]] kernel mul_mm_id kernel_mul_mm_id; template [[host_name("kernel_mul_mm_id_q4_K_f16")]] kernel mul_mm_id kernel_mul_mm_id; @@ -9950,6 +10149,8 @@ template [[host_name("kernel_mul_mv_id_q5_1_f32")]] kernel kernel_mul_mv_id_t template [[host_name("kernel_mul_mv_id_mxfp4_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_nvfp4_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; + template [[host_name("kernel_mul_mv_id_q2_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; template [[host_name("kernel_mul_mv_id_q3_K_f32")]] kernel 
kernel_mul_mv_id_t kernel_mul_mv_id>>; template [[host_name("kernel_mul_mv_id_q4_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>;