This commit is contained in:
Richard Davison 2026-03-15 23:55:07 +02:00 committed by GitHub
commit 61bee6b71f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 219 additions and 2 deletions

View File

@ -764,6 +764,12 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv(ggml_meta
nr0 = N_R0_MXFP4;
smem = 32*sizeof(float);
} break;
case GGML_TYPE_NVFP4:
{
nsg = N_SG_NVFP4;
nr0 = N_R0_NVFP4;
smem = (32 + 128)*sizeof(float);
} break;
case GGML_TYPE_Q2_K:
{
nsg = N_SG_Q2_K;
@ -976,6 +982,12 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_id(ggml_m
nr0 = N_R0_MXFP4;
smem = 32*sizeof(float);
} break;
case GGML_TYPE_NVFP4:
{
nsg = N_SG_NVFP4;
nr0 = N_R0_NVFP4;
smem = (32 + 128)*sizeof(float);
} break;
case GGML_TYPE_Q2_K:
{
nsg = N_SG_Q2_K;

View File

@ -1161,7 +1161,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
case GGML_OP_SOLVE_TRI:
case GGML_OP_MUL_MAT:
case GGML_OP_MUL_MAT_ID:
return has_simdgroup_reduction && op->src[0]->type != GGML_TYPE_NVFP4;
return has_simdgroup_reduction;
case GGML_OP_SET:
case GGML_OP_CPY:
case GGML_OP_DUP:
@ -1219,7 +1219,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
};
}
case GGML_OP_GET_ROWS:
return op->src[0]->type != GGML_TYPE_NVFP4;
return true;
case GGML_OP_SET_ROWS:
{
if (op->src[0]->type != GGML_TYPE_F32) {

View File

@ -26,6 +26,9 @@
#define N_R0_MXFP4 2
#define N_SG_MXFP4 2
#define N_R0_NVFP4 2
#define N_SG_NVFP4 2
#define N_R0_Q2_K 4
#define N_SG_Q2_K 2

View File

@ -2049,6 +2049,7 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
op->src[0]->type == GGML_TYPE_Q5_1 ||
op->src[0]->type == GGML_TYPE_Q8_0 ||
op->src[0]->type == GGML_TYPE_MXFP4 ||
op->src[0]->type == GGML_TYPE_NVFP4 ||
op->src[0]->type == GGML_TYPE_IQ4_NL ||
false) && (ne11 >= 2 && ne11 <= 8)
) ||

View File

@ -50,6 +50,36 @@ constexpr constant static float kvalues_mxfp4_f[16] = {
0, .5f, 1.f, 1.5f, 2.f, 3.f, 4.f, 6.f, -0, -.5f, -1.f, -1.5f, -2.f, -3.f, -4.f, -6.f
};
// UE4M3 scale lookup table (128 entries, unsigned E4M3 with bias=7)
// Computed at compile time from the UE4M3 format definition
struct ue4m3_table {
float data[128];
constexpr ue4m3_table() : data{} {
for (int x = 0; x < 128; ++x) {
if (x == 0 || x == 0x7F) {
data[x] = 0.0f;
} else {
int exp = (x >> 3) & 0xF;
int man = x & 0x7;
if (exp == 0) {
// subnormal: man * 2^(-9)
data[x] = (float)man / 512.0f;
} else {
// normal: (1 + man/8) * 2^(exp-7)
float mantissa = 1.0f + (float)man / 8.0f;
float scale = 1.0f;
int e = exp - 7;
if (e > 0) { for (int i = 0; i < e; ++i) scale *= 2.0f; }
if (e < 0) { for (int i = 0; i < -e; ++i) scale *= 0.5f; }
data[x] = mantissa * scale;
}
}
}
}
};
constexpr constant static ue4m3_table kvalues_ue4m3_table = ue4m3_table();
#define kvalues_ue4m3_f kvalues_ue4m3_table.data
static inline int best_index_int8(int n, constant float * val, float x) {
if (x <= val[0]) return 0;
if (x >= val[n-1]) return n-1;
@ -73,6 +103,22 @@ static inline float e8m0_to_fp32(uint8_t x) {
return as_type<float>(bits);
}
// UE4M3: unsigned, 4 exp bits (bias=7), 3 mantissa bits
// Only used to populate the threadgroup LUT in kernel_mul_mv_nvfp4_f32
static inline float ue4m3_to_fp32(uint8_t x) {
if (x == 0) {
return 0.0f;
}
int exp = (x >> 3) & 0xF;
int man = x & 0x7;
if (exp == 0) {
return ldexp((float)man, -9);
}
// Normal: construct FP32 directly. exp-7+127 = exp+120
uint32_t bits = (uint32_t)(exp + 120) << 23 | (uint32_t)man << 20;
return as_type<float>(bits);
}
static inline float dot(float x, float y) {
return x*y;
}
@ -557,6 +603,49 @@ void dequantize_mxfp4_t4(device const block_mxfp4 * xb, short il, thread type4 &
reg[3] = d * kvalues_mxfp4_f[(q2[4*il4 + 3] >> shr) & 0x0F];
}
// NVFP4 dequantize: block_nvfp4 has 64 elements = 4 sub-blocks of 16
// nl=4, il=0..3 selects sub-block. Each call produces 4x4 = 16 elements.
// CPU layout per sub-block: low nibbles of 8 bytes -> elements 0..7,
// high nibbles -> elements 8..15
template <typename type4x4>
void dequantize_nvfp4(device const block_nvfp4 * xb, short il, thread type4x4 & reg) {
device const uint8_t * qs = xb->qs + il * (QK_NVFP4_SUB/2); // 8 bytes per sub-block
const float d = kvalues_ue4m3_f[xb->d[il]];
// rows 0-1: low nibbles (elements 0..7)
for (int i = 0; i < 2; ++i) {
reg[i][0] = d * kvalues_mxfp4_f[qs[4*i + 0] & 0x0F];
reg[i][1] = d * kvalues_mxfp4_f[qs[4*i + 1] & 0x0F];
reg[i][2] = d * kvalues_mxfp4_f[qs[4*i + 2] & 0x0F];
reg[i][3] = d * kvalues_mxfp4_f[qs[4*i + 3] & 0x0F];
}
// rows 2-3: high nibbles (elements 8..15)
for (int i = 0; i < 2; ++i) {
reg[i+2][0] = d * kvalues_mxfp4_f[qs[4*i + 0] >> 4];
reg[i+2][1] = d * kvalues_mxfp4_f[qs[4*i + 1] >> 4];
reg[i+2][2] = d * kvalues_mxfp4_f[qs[4*i + 2] >> 4];
reg[i+2][3] = d * kvalues_mxfp4_f[qs[4*i + 3] >> 4];
}
}
// t4 variant: chpb=16, il=0..15, each call produces 4 elements.
// Within each sub-block (4 calls): lo-row0, lo-row1, hi-row0, hi-row1
template <typename type4>
void dequantize_nvfp4_t4(device const block_nvfp4 * xb, short il, thread type4 & reg) {
const short sub = il / 4; // sub-block index 0..3
const short rem = il % 4; // 0..3 within sub-block
const short row = rem % 2; // byte row: 0 = bytes 0..3, 1 = bytes 4..7
const uint8_t shr = rem >= 2 ? 4 : 0; // 0,1 = low nibble; 2,3 = high nibble
device const uint8_t * qs = xb->qs + sub * (QK_NVFP4_SUB/2) + row * 4;
const float d = kvalues_ue4m3_f[xb->d[sub]];
reg[0] = d * kvalues_mxfp4_f[(qs[0] >> shr) & 0x0F];
reg[1] = d * kvalues_mxfp4_f[(qs[1] >> shr) & 0x0F];
reg[2] = d * kvalues_mxfp4_f[(qs[2] >> shr) & 0x0F];
reg[3] = d * kvalues_mxfp4_f[(qs[3] >> shr) & 0x0F];
}
template <typename type4x4>
void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) {
const float d = xb->d;
@ -3743,6 +3832,11 @@ template [[host_name("kernel_mul_mv_ext_mxfp4_f32_r1_3")]] kernel mul_mv_ext_q4
template [[host_name("kernel_mul_mv_ext_mxfp4_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_mxfp4, 32, dequantize_mxfp4_t4>;
template [[host_name("kernel_mul_mv_ext_mxfp4_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_mxfp4, 32, dequantize_mxfp4_t4>;
template [[host_name("kernel_mul_mv_ext_nvfp4_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_nvfp4, 64, dequantize_nvfp4_t4>;
template [[host_name("kernel_mul_mv_ext_nvfp4_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_nvfp4, 64, dequantize_nvfp4_t4>;
template [[host_name("kernel_mul_mv_ext_nvfp4_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_nvfp4, 64, dequantize_nvfp4_t4>;
template [[host_name("kernel_mul_mv_ext_nvfp4_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_nvfp4, 64, dequantize_nvfp4_t4>;
template [[host_name("kernel_mul_mv_ext_iq4_nl_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_iq4_nl, 32, dequantize_iq4_nl_t4>;
template [[host_name("kernel_mul_mv_ext_iq4_nl_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_iq4_nl, 32, dequantize_iq4_nl_t4>;
template [[host_name("kernel_mul_mv_ext_iq4_nl_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_iq4_nl, 32, dequantize_iq4_nl_t4>;
@ -8807,6 +8901,106 @@ kernel void kernel_mul_mv_mxfp4_f32(
kernel_mul_mv_mxfp4_f32_impl<N_R0_MXFP4, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}
template<int NR0, typename args_t>
void kernel_mul_mv_nvfp4_f32_impl(
args_t args,
device const char * src0,
device const char * src1,
device char * dst,
threadgroup char * shmem,
uint3 tgpig,
ushort tiisg,
ushort sgitg) {
const short NSG = FC_mul_mv_nsg;
threadgroup float * shmem_f32 = (threadgroup float *) shmem;
const int r0 = tgpig.x;
const int r1 = tgpig.y;
const int im = tgpig.z;
const int first_row = (r0 * NSG + sgitg) * NR0;
const uint i12 = im%args.ne12;
const uint i13 = im/args.ne12;
const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
device const block_nvfp4 * x = (device const block_nvfp4 *) (src0 + offset0);
device const float * y = (device const float *) (src1 + offset1);
const int nb = args.ne00/QK_NVFP4;
const int ns01 = args.nb01/args.nb00;
const short ix = tiisg/2; // 0...15
const short it = tiisg%2; // 0 or 1
shmem_f32[tiisg] = kvalues_mxfp4_f[tiisg%16];
// UE4M3 scale LUT: 128 entries, loaded cooperatively by all threads
threadgroup float * ue4m3_lut = shmem_f32 + 32;
ue4m3_lut[sgitg*32 + tiisg] = ue4m3_to_fp32(sgitg*32 + tiisg);
ue4m3_lut[sgitg*32 + tiisg + 64] = ue4m3_to_fp32(sgitg*32 + tiisg + 64);
threadgroup_barrier(mem_flags::mem_threadgroup);
float sumf[NR0] = {0.f};
// each thread handles 4 bytes per sub-block (half of the 8 bytes)
// lo nibbles -> y at offset it*4, hi nibbles -> y at offset 8+it*4
device const float * yb = y + ix*QK_NVFP4;
float4 yl[8]; // pre-loaded y: 4 sub-blocks × {lo, hi}
for (int ib = ix; ib < nb && ib < ns01; ib += 16) {
// pre-load y into registers before the row loop
FOR_UNROLL (short s = 0; s < 4; s++) {
yl[2*s + 0] = *(device const float4 *)(yb + s*QK_NVFP4_SUB + it*4);
yl[2*s + 1] = *(device const float4 *)(yb + s*QK_NVFP4_SUB + QK_NVFP4_SUB/2 + it*4);
}
FOR_UNROLL (short row = 0; row < NR0; row++) {
device const block_nvfp4 & xb = x[row*ns01 + ib];
float sub_sum = 0.0f;
FOR_UNROLL (short s = 0; s < 4; s++) {
device const uint8_t * qs = (device const uint8_t *)(xb.qs + s*(QK_NVFP4_SUB/2) + it*4);
float4 acc = yl[2*s]*float4(shmem_f32[qs[0] & 0x0F], shmem_f32[qs[1] & 0x0F], shmem_f32[qs[2] & 0x0F], shmem_f32[qs[3] & 0x0F])
+ yl[2*s+1]*float4(shmem_f32[qs[0] >> 4], shmem_f32[qs[1] >> 4], shmem_f32[qs[2] >> 4], shmem_f32[qs[3] >> 4]);
sub_sum += ue4m3_lut[xb.d[s]] * ((acc[0] + acc[1]) + (acc[2] + acc[3]));
}
sumf[row] += sub_sum;
}
yb += 16 * QK_NVFP4;
}
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
for (int row = 0; row < NR0 && first_row + row < args.ne0; ++row) {
float sum_all = simd_sum(sumf[row]);
if (tiisg == 0) {
dst_f32[first_row + row] = sum_all;
}
}
}
[[host_name("kernel_mul_mv_nvfp4_f32")]]
kernel void kernel_mul_mv_nvfp4_f32(
constant ggml_metal_kargs_mul_mv & args,
device const char * src0,
device const char * src1,
device char * dst,
threadgroup char * shmem [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
kernel_mul_mv_nvfp4_f32_impl<N_R0_NVFP4, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
kernel void kernel_get_rows_q(
constant ggml_metal_kargs_get_rows & args,
@ -9655,6 +9849,7 @@ template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_q_t kernel_get
template [[host_name("kernel_get_rows_q5_1")]] kernel get_rows_q_t kernel_get_rows_q<block_q5_1, 2, dequantize_q5_1>;
template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q8_0, 2, dequantize_q8_0>;
template [[host_name("kernel_get_rows_mxfp4")]] kernel get_rows_q_t kernel_get_rows_q<block_mxfp4, 2, dequantize_mxfp4>;
template [[host_name("kernel_get_rows_nvfp4")]] kernel get_rows_q_t kernel_get_rows_q<block_nvfp4, 4, dequantize_nvfp4>;
template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q2_K, QK_NL, dequantize_q2_K>;
template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q3_K, QK_NL, dequantize_q3_K>;
template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q4_K, QK_NL, dequantize_q4_K>;
@ -9717,6 +9912,7 @@ template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mul_mm_t kernel_mul_m
template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_1, 2, dequantize_q5_1, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_mxfp4_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_mxfp4, 2, dequantize_mxfp4, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_nvfp4_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_nvfp4, 4, dequantize_nvfp4, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q3_K, QK_NL, dequantize_q3_K, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K, float, float4x4, float, float2x4>;
@ -9740,6 +9936,7 @@ template [[host_name("kernel_mul_mm_q5_0_f16")]] kernel mul_mm_t kernel_mul_m
template [[host_name("kernel_mul_mm_q5_1_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_1, 2, dequantize_q5_1, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_q8_0_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_mxfp4_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_mxfp4, 2, dequantize_mxfp4, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_nvfp4_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_nvfp4, 4, dequantize_nvfp4, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_q2_K_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_q3_K_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q3_K, QK_NL, dequantize_q3_K, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_q4_K_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K, float, float4x4, half, half2x4>;
@ -9772,6 +9969,7 @@ template [[host_name("kernel_mul_mm_id_q5_0_f32")]] kernel mul_mm_id kernel_m
template [[host_name("kernel_mul_mm_id_q5_1_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_1, 2, dequantize_q5_1, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_q8_0_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_mxfp4_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_mxfp4, 2, dequantize_mxfp4, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_nvfp4_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_nvfp4, 4, dequantize_nvfp4, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_q2_K_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_q3_K_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q3_K, QK_NL, dequantize_q3_K, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K, float, float4x4, float, float2x4>;
@ -9795,6 +9993,7 @@ template [[host_name("kernel_mul_mm_id_q5_0_f16")]] kernel mul_mm_id kernel_m
template [[host_name("kernel_mul_mm_id_q5_1_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_1, 2, dequantize_q5_1, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_q8_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_mxfp4_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_mxfp4, 2, dequantize_mxfp4, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_nvfp4_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_nvfp4, 4, dequantize_nvfp4, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_q2_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_q3_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q3_K, QK_NL, dequantize_q3_K, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_q4_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K, float, float4x4, half, half2x4>;
@ -9950,6 +10149,8 @@ template [[host_name("kernel_mul_mv_id_q5_1_f32")]] kernel kernel_mul_mv_id_t
template [[host_name("kernel_mul_mv_id_mxfp4_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_mxfp4_f32_impl<N_R0_MXFP4>>>;
template [[host_name("kernel_mul_mv_id_nvfp4_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_nvfp4_f32_impl<N_R0_NVFP4>>>;
template [[host_name("kernel_mul_mv_id_q2_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q2_K_f32_impl <N_R0_Q2_K>>>;
template [[host_name("kernel_mul_mv_id_q3_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q3_K_f32_impl <N_R0_Q3_K>>>;
template [[host_name("kernel_mul_mv_id_q4_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q4_K_f32_impl <N_R0_Q4_K>>>;