From 211ab16a3a99be3e8a4492fa777fdabf398880ce Mon Sep 17 00:00:00 2001 From: richarddd Date: Thu, 12 Mar 2026 13:21:21 +0100 Subject: [PATCH] NVFP4 metal --- ggml/src/ggml-metal/ggml-metal-device.cpp | 12 ++ ggml/src/ggml-metal/ggml-metal-device.m | 4 +- ggml/src/ggml-metal/ggml-metal-impl.h | 3 + ggml/src/ggml-metal/ggml-metal-ops.cpp | 1 + ggml/src/ggml-metal/ggml-metal.metal | 201 ++++++++++++++++++++++ 5 files changed, 219 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp index 169c63dd7a..5d7d79e85e 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.cpp +++ b/ggml/src/ggml-metal/ggml-metal-device.cpp @@ -729,6 +729,12 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv(ggml_meta nr0 = N_R0_MXFP4; smem = 32*sizeof(float); } break; + case GGML_TYPE_NVFP4: + { + nsg = N_SG_NVFP4; + nr0 = N_R0_NVFP4; + smem = (32 + 128)*sizeof(float); + } break; case GGML_TYPE_Q2_K: { nsg = N_SG_Q2_K; @@ -941,6 +947,12 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_id(ggml_m nr0 = N_R0_MXFP4; smem = 32*sizeof(float); } break; + case GGML_TYPE_NVFP4: + { + nsg = N_SG_NVFP4; + nr0 = N_R0_NVFP4; + smem = (32 + 128)*sizeof(float); + } break; case GGML_TYPE_Q2_K: { nsg = N_SG_Q2_K; diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index d42b8ab1eb..23bd2b2ab7 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -1158,7 +1158,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te case GGML_OP_SOLVE_TRI: case GGML_OP_MUL_MAT: case GGML_OP_MUL_MAT_ID: - return has_simdgroup_reduction && op->src[0]->type != GGML_TYPE_NVFP4; + return has_simdgroup_reduction; case GGML_OP_SET: case GGML_OP_CPY: case GGML_OP_DUP: @@ -1216,7 +1216,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te }; } case GGML_OP_GET_ROWS: - return op->src[0]->type != GGML_TYPE_NVFP4; + return true; case GGML_OP_SET_ROWS: { if (op->src[0]->type != GGML_TYPE_F32) { diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h index 99d64efc3b..183cade91b 100644 --- a/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -26,6 +26,9 @@ #define N_R0_MXFP4 2 #define N_SG_MXFP4 2 +#define N_R0_NVFP4 2 +#define N_SG_NVFP4 2 + #define N_R0_Q2_K 4 #define N_SG_Q2_K 2 diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index 267755d08c..a3ee1e6a01 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -1970,6 +1970,7 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) { op->src[0]->type == GGML_TYPE_Q5_1 || op->src[0]->type == GGML_TYPE_Q8_0 || op->src[0]->type == GGML_TYPE_MXFP4 || + op->src[0]->type == GGML_TYPE_NVFP4 || op->src[0]->type == GGML_TYPE_IQ4_NL || false) && (ne11 >= 2 && ne11 <= 8) ) || diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 29e4a245d5..7f15a2064c 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -50,6 +50,26 @@ constexpr constant static float kvalues_mxfp4_f[16] = { 0, .5f, 1.f, 1.5f, 2.f, 3.f, 4.f, 6.f, -0, -.5f, -1.f, -1.5f, -2.f, -3.f, -4.f, -6.f }; +// UE4M3 scale lookup table (128 entries, unsigned E4M3 with bias=7) +constexpr constant static float kvalues_ue4m3_f[128] = { + 0.f, 1.953125e-03f, 3.906250e-03f, 5.859375e-03f, 7.812500e-03f, 9.765625e-03f, 1.171875e-02f, 1.367188e-02f, + 1.562500e-02f, 1.757813e-02f, 1.953125e-02f, 2.148438e-02f, 2.343750e-02f, 2.539063e-02f, 2.734375e-02f, 2.929688e-02f, + 3.125000e-02f, 3.515625e-02f, 3.906250e-02f, 4.296875e-02f, 4.687500e-02f, 5.078125e-02f, 5.468750e-02f, 5.859375e-02f, + 6.250000e-02f, 7.031250e-02f, 7.812500e-02f, 8.593750e-02f, 9.375000e-02f, 1.015625e-01f, 1.093750e-01f, 1.171875e-01f, + 1.250000e-01f, 1.406250e-01f, 1.562500e-01f, 1.718750e-01f, 1.875000e-01f, 2.031250e-01f, 2.187500e-01f, 2.343750e-01f, + 2.500000e-01f, 2.812500e-01f, 3.125000e-01f, 3.437500e-01f, 3.750000e-01f, 4.062500e-01f, 4.375000e-01f, 4.687500e-01f, + 5.000000e-01f, 5.625000e-01f, 6.250000e-01f, 6.875000e-01f, 7.500000e-01f, 8.125000e-01f, 8.750000e-01f, 9.375000e-01f, + 1.000000e+00f, 1.125000e+00f, 1.250000e+00f, 1.375000e+00f, 1.500000e+00f, 1.625000e+00f, 1.750000e+00f, 1.875000e+00f, + 2.000000e+00f, 2.250000e+00f, 2.500000e+00f, 2.750000e+00f, 3.000000e+00f, 3.250000e+00f, 3.500000e+00f, 3.750000e+00f, + 4.000000e+00f, 4.500000e+00f, 5.000000e+00f, 5.500000e+00f, 6.000000e+00f, 6.500000e+00f, 7.000000e+00f, 7.500000e+00f, + 8.000000e+00f, 9.000000e+00f, 1.000000e+01f, 1.100000e+01f, 1.200000e+01f, 1.300000e+01f, 1.400000e+01f, 1.500000e+01f, + 1.600000e+01f, 1.800000e+01f, 2.000000e+01f, 2.200000e+01f, 2.400000e+01f, 2.600000e+01f, 2.800000e+01f, 3.000000e+01f, + 3.200000e+01f, 3.600000e+01f, 4.000000e+01f, 4.400000e+01f, 4.800000e+01f, 5.200000e+01f, 5.600000e+01f, 6.000000e+01f, + 6.400000e+01f, 7.200000e+01f, 8.000000e+01f, 8.800000e+01f, 9.600000e+01f, 1.040000e+02f, 1.120000e+02f, 1.200000e+02f, + 1.280000e+02f, 1.440000e+02f, 1.600000e+02f, 1.760000e+02f, 1.920000e+02f, 2.080000e+02f, 2.240000e+02f, 2.400000e+02f, + 2.560000e+02f, 2.880000e+02f, 3.200000e+02f, 3.520000e+02f, 3.840000e+02f, 4.160000e+02f, 4.480000e+02f, 0.f, +}; + static inline int best_index_int8(int n, constant float * val, float x) { if (x <= val[0]) return 0; if (x >= val[n-1]) return n-1; @@ -73,6 +93,32 @@ static inline float e8m0_to_fp32(uint8_t x) { return as_type(bits); } +// UE4M3: unsigned, 4 exp bits (bias=7), 3 mantissa bits +// Branchless using additive bias to avoid FP32 denormals on GPU +static inline float ue4m3_to_fp32(uint8_t x) { + // Add bias of 16 to exponent to keep all values in FP32 normal range + // UE4M3 bits: [6:3]=exp, [2:0]=man + // With biased exp: FP32 value = (1 + man/8) * 2^(exp+16-127) for normal + // We want (1 + man/8) * 2^(exp-7), so divide by 2^(16-127+7) = multiply by 2^(127-23) = 2^104 + // But for exp=0 (subnormal): value should be man * 2^(-9) + // With bias: we get (1 + man/8) * 2^(16-127) which is wrong for subnormals + // So this approach doesn't handle subnormals correctly. + // Since UE4M3 subnormals (exp=0, man=1..7) represent tiny values (max 7*2^-9 ≈ 0.0137), + // and these are scale factors, they're extremely rare in practice. + // Use select to handle the zero case, keep branches for subnormal. + if (x == 0) { + return 0.0f; + } + int exp = (x >> 3) & 0xF; + int man = x & 0x7; + if (exp == 0) { + return ldexp((float)man, -9); + } + // Normal: construct FP32 directly. exp-7+127 = exp+120 + uint32_t bits = (uint32_t)(exp + 120) << 23 | (uint32_t)man << 20; + return as_type(bits); +} + static inline float dot(float x, float y) { return x*y; } @@ -557,6 +603,49 @@ void dequantize_mxfp4_t4(device const block_mxfp4 * xb, short il, thread type4 & reg[3] = d * kvalues_mxfp4_f[(q2[4*il4 + 3] >> shr) & 0x0F]; } +// NVFP4 dequantize: block_nvfp4 has 64 elements = 4 sub-blocks of 16 +// nl=4, il=0..3 selects sub-block. Each call produces 4x4 = 16 elements. +// CPU layout per sub-block: low nibbles of 8 bytes -> elements 0..7, +// high nibbles -> elements 8..15 +template +void dequantize_nvfp4(device const block_nvfp4 * xb, short il, thread type4x4 & reg) { + device const uint8_t * qs = xb->qs + il * (QK_NVFP4_SUB/2); // 8 bytes per sub-block + const float d = kvalues_ue4m3_f[xb->d[il]]; + + // rows 0-1: low nibbles (elements 0..7) + for (int i = 0; i < 2; ++i) { + reg[i][0] = d * kvalues_mxfp4_f[qs[4*i + 0] & 0x0F]; + reg[i][1] = d * kvalues_mxfp4_f[qs[4*i + 1] & 0x0F]; + reg[i][2] = d * kvalues_mxfp4_f[qs[4*i + 2] & 0x0F]; + reg[i][3] = d * kvalues_mxfp4_f[qs[4*i + 3] & 0x0F]; + } + // rows 2-3: high nibbles (elements 8..15) + for (int i = 0; i < 2; ++i) { + reg[i+2][0] = d * kvalues_mxfp4_f[qs[4*i + 0] >> 4]; + reg[i+2][1] = d * kvalues_mxfp4_f[qs[4*i + 1] >> 4]; + reg[i+2][2] = d * kvalues_mxfp4_f[qs[4*i + 2] >> 4]; + reg[i+2][3] = d * kvalues_mxfp4_f[qs[4*i + 3] >> 4]; + } +} + +// t4 variant: chpb=16, il=0..15, each call produces 4 elements. +// Within each sub-block (4 calls): lo-row0, lo-row1, hi-row0, hi-row1 +template +void dequantize_nvfp4_t4(device const block_nvfp4 * xb, short il, thread type4 & reg) { + const short sub = il / 4; // sub-block index 0..3 + const short rem = il % 4; // 0..3 within sub-block + const short row = rem % 2; // byte row: 0 = bytes 0..3, 1 = bytes 4..7 + const uint8_t shr = rem >= 2 ? 4 : 0; // 0,1 = low nibble; 2,3 = high nibble + + device const uint8_t * qs = xb->qs + sub * (QK_NVFP4_SUB/2) + row * 4; + const float d = kvalues_ue4m3_f[xb->d[sub]]; + + reg[0] = d * kvalues_mxfp4_f[(qs[0] >> shr) & 0x0F]; + reg[1] = d * kvalues_mxfp4_f[(qs[1] >> shr) & 0x0F]; + reg[2] = d * kvalues_mxfp4_f[(qs[2] >> shr) & 0x0F]; + reg[3] = d * kvalues_mxfp4_f[(qs[3] >> shr) & 0x0F]; +} + template void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) { const float d = xb->d; @@ -3518,6 +3607,11 @@ template [[host_name("kernel_mul_mv_ext_mxfp4_f32_r1_3")]] kernel mul_mv_ext_q4 template [[host_name("kernel_mul_mv_ext_mxfp4_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_mxfp4, 32, dequantize_mxfp4_t4>; template [[host_name("kernel_mul_mv_ext_mxfp4_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_mxfp4, 32, dequantize_mxfp4_t4>; +template [[host_name("kernel_mul_mv_ext_nvfp4_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_nvfp4, 64, dequantize_nvfp4_t4>; +template [[host_name("kernel_mul_mv_ext_nvfp4_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_nvfp4, 64, dequantize_nvfp4_t4>; +template [[host_name("kernel_mul_mv_ext_nvfp4_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_nvfp4, 64, dequantize_nvfp4_t4>; +template [[host_name("kernel_mul_mv_ext_nvfp4_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_nvfp4, 64, dequantize_nvfp4_t4>; + template [[host_name("kernel_mul_mv_ext_iq4_nl_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_iq4_nl, 32, dequantize_iq4_nl_t4>; template [[host_name("kernel_mul_mv_ext_iq4_nl_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_iq4_nl, 32, dequantize_iq4_nl_t4>; template [[host_name("kernel_mul_mv_ext_iq4_nl_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_iq4_nl, 32, dequantize_iq4_nl_t4>; @@ -8563,6 +8657,106 @@ kernel void kernel_mul_mv_mxfp4_f32( kernel_mul_mv_mxfp4_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } +template +void kernel_mul_mv_nvfp4_f32_impl( + args_t args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem, + uint3 tgpig, + ushort tiisg, + ushort sgitg) { + const short NSG = FC_mul_mv_nsg; + + threadgroup float * shmem_f32 = (threadgroup float *) shmem; + + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; + + const int first_row = (r0 * NSG + sgitg) * NR0; + + const uint i12 = im%args.ne12; + const uint i13 = im/args.ne12; + + const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + + device const block_nvfp4 * x = (device const block_nvfp4 *) (src0 + offset0); + device const float * y = (device const float *) (src1 + offset1); + + const int nb = args.ne00/QK_NVFP4; + const int ns01 = args.nb01/args.nb00; + + const short ix = tiisg/2; // 0...15 + const short it = tiisg%2; // 0 or 1 + + shmem_f32[tiisg] = kvalues_mxfp4_f[tiisg%16]; + + // UE4M3 scale LUT: 128 entries, loaded cooperatively by all threads + threadgroup float * ue4m3_lut = shmem_f32 + 32; + ue4m3_lut[sgitg*32 + tiisg] = ue4m3_to_fp32(sgitg*32 + tiisg); + ue4m3_lut[sgitg*32 + tiisg + 64] = ue4m3_to_fp32(sgitg*32 + tiisg + 64); + threadgroup_barrier(mem_flags::mem_threadgroup); + + float sumf[NR0] = {0.f}; + + // each thread handles 4 bytes per sub-block (half of the 8 bytes) + // lo nibbles -> y at offset it*4, hi nibbles -> y at offset 8+it*4 + device const float * yb = y + ix*QK_NVFP4; + + float4 yl[8]; // pre-loaded y: 4 sub-blocks × {lo, hi} + + for (int ib = ix; ib < nb && ib < ns01; ib += 16) { + // pre-load y into registers before the row loop + FOR_UNROLL (short s = 0; s < 4; s++) { + yl[2*s + 0] = *(device const float4 *)(yb + s*QK_NVFP4_SUB + it*4); + yl[2*s + 1] = *(device const float4 *)(yb + s*QK_NVFP4_SUB + QK_NVFP4_SUB/2 + it*4); + } + + FOR_UNROLL (short row = 0; row < NR0; row++) { + device const block_nvfp4 & xb = x[row*ns01 + ib]; + + float sub_sum = 0.0f; + FOR_UNROLL (short s = 0; s < 4; s++) { + device const uint8_t * qs = (device const uint8_t *)(xb.qs + s*(QK_NVFP4_SUB/2) + it*4); + + float4 acc = yl[2*s]*float4(shmem_f32[qs[0] & 0x0F], shmem_f32[qs[1] & 0x0F], shmem_f32[qs[2] & 0x0F], shmem_f32[qs[3] & 0x0F]) + + yl[2*s+1]*float4(shmem_f32[qs[0] >> 4], shmem_f32[qs[1] >> 4], shmem_f32[qs[2] >> 4], shmem_f32[qs[3] >> 4]); + + sub_sum += ue4m3_lut[xb.d[s]] * ((acc[0] + acc[1]) + (acc[2] + acc[3])); + } + sumf[row] += sub_sum; + } + + yb += 16 * QK_NVFP4; + } + + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; + + for (int row = 0; row < NR0 && first_row + row < args.ne0; ++row) { + float sum_all = simd_sum(sumf[row]); + if (tiisg == 0) { + dst_f32[first_row + row] = sum_all; + } + } +} + +[[host_name("kernel_mul_mv_nvfp4_f32")]] +kernel void kernel_mul_mv_nvfp4_f32( + constant ggml_metal_kargs_mul_mv & args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + + kernel_mul_mv_nvfp4_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); +} + template kernel void kernel_get_rows_q( constant ggml_metal_kargs_get_rows & args, @@ -9411,6 +9605,7 @@ template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_q_t kernel_get template [[host_name("kernel_get_rows_q5_1")]] kernel get_rows_q_t kernel_get_rows_q; template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_q_t kernel_get_rows_q; template [[host_name("kernel_get_rows_mxfp4")]] kernel get_rows_q_t kernel_get_rows_q; +template [[host_name("kernel_get_rows_nvfp4")]] kernel get_rows_q_t kernel_get_rows_q; template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_q_t kernel_get_rows_q; template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_q_t kernel_get_rows_q; template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_q_t kernel_get_rows_q; @@ -9473,6 +9668,7 @@ template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mul_mm_t kernel_mul_m template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mul_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mul_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_mxfp4_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_nvfp4_f32")]] kernel mul_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mul_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mul_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mul_mm_t kernel_mul_mm; @@ -9496,6 +9692,7 @@ template [[host_name("kernel_mul_mm_q5_0_f16")]] kernel mul_mm_t kernel_mul_m template [[host_name("kernel_mul_mm_q5_1_f16")]] kernel mul_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q8_0_f16")]] kernel mul_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_mxfp4_f16")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_nvfp4_f16")]] kernel mul_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q2_K_f16")]] kernel mul_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q3_K_f16")]] kernel mul_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q4_K_f16")]] kernel mul_mm_t kernel_mul_mm; @@ -9528,6 +9725,7 @@ template [[host_name("kernel_mul_mm_id_q5_0_f32")]] kernel mul_mm_id kernel_m template [[host_name("kernel_mul_mm_id_q5_1_f32")]] kernel mul_mm_id kernel_mul_mm_id; template [[host_name("kernel_mul_mm_id_q8_0_f32")]] kernel mul_mm_id kernel_mul_mm_id; template [[host_name("kernel_mul_mm_id_mxfp4_f32")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_nvfp4_f32")]] kernel mul_mm_id kernel_mul_mm_id; template [[host_name("kernel_mul_mm_id_q2_K_f32")]] kernel mul_mm_id kernel_mul_mm_id; template [[host_name("kernel_mul_mm_id_q3_K_f32")]] kernel mul_mm_id kernel_mul_mm_id; template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mul_mm_id kernel_mul_mm_id; @@ -9551,6 +9749,7 @@ template [[host_name("kernel_mul_mm_id_q5_0_f16")]] kernel mul_mm_id kernel_m template [[host_name("kernel_mul_mm_id_q5_1_f16")]] kernel mul_mm_id kernel_mul_mm_id; template [[host_name("kernel_mul_mm_id_q8_0_f16")]] kernel mul_mm_id kernel_mul_mm_id; template [[host_name("kernel_mul_mm_id_mxfp4_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_nvfp4_f16")]] kernel mul_mm_id kernel_mul_mm_id; template [[host_name("kernel_mul_mm_id_q2_K_f16")]] kernel mul_mm_id kernel_mul_mm_id; template [[host_name("kernel_mul_mm_id_q3_K_f16")]] kernel mul_mm_id kernel_mul_mm_id; template [[host_name("kernel_mul_mm_id_q4_K_f16")]] kernel mul_mm_id kernel_mul_mm_id; @@ -9706,6 +9905,8 @@ template [[host_name("kernel_mul_mv_id_q5_1_f32")]] kernel kernel_mul_mv_id_t template [[host_name("kernel_mul_mv_id_mxfp4_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_nvfp4_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; + template [[host_name("kernel_mul_mv_id_q2_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; template [[host_name("kernel_mul_mv_id_q3_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; template [[host_name("kernel_mul_mv_id_q4_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>;