NVFP4 metal
This commit is contained in:
parent
f90bd1dd84
commit
211ab16a3a
|
|
@ -729,6 +729,12 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv(ggml_meta
|
|||
nr0 = N_R0_MXFP4;
|
||||
smem = 32*sizeof(float);
|
||||
} break;
|
||||
case GGML_TYPE_NVFP4:
|
||||
{
|
||||
nsg = N_SG_NVFP4;
|
||||
nr0 = N_R0_NVFP4;
|
||||
smem = (32 + 128)*sizeof(float);
|
||||
} break;
|
||||
case GGML_TYPE_Q2_K:
|
||||
{
|
||||
nsg = N_SG_Q2_K;
|
||||
|
|
@ -941,6 +947,12 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_id(ggml_m
|
|||
nr0 = N_R0_MXFP4;
|
||||
smem = 32*sizeof(float);
|
||||
} break;
|
||||
case GGML_TYPE_NVFP4:
|
||||
{
|
||||
nsg = N_SG_NVFP4;
|
||||
nr0 = N_R0_NVFP4;
|
||||
smem = (32 + 128)*sizeof(float);
|
||||
} break;
|
||||
case GGML_TYPE_Q2_K:
|
||||
{
|
||||
nsg = N_SG_Q2_K;
|
||||
|
|
|
|||
|
|
@ -1158,7 +1158,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
|
|||
case GGML_OP_SOLVE_TRI:
|
||||
case GGML_OP_MUL_MAT:
|
||||
case GGML_OP_MUL_MAT_ID:
|
||||
return has_simdgroup_reduction && op->src[0]->type != GGML_TYPE_NVFP4;
|
||||
return has_simdgroup_reduction;
|
||||
case GGML_OP_SET:
|
||||
case GGML_OP_CPY:
|
||||
case GGML_OP_DUP:
|
||||
|
|
@ -1216,7 +1216,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
|
|||
};
|
||||
}
|
||||
case GGML_OP_GET_ROWS:
|
||||
return op->src[0]->type != GGML_TYPE_NVFP4;
|
||||
return true;
|
||||
case GGML_OP_SET_ROWS:
|
||||
{
|
||||
if (op->src[0]->type != GGML_TYPE_F32) {
|
||||
|
|
|
|||
|
|
@ -26,6 +26,9 @@
|
|||
#define N_R0_MXFP4 2
|
||||
#define N_SG_MXFP4 2
|
||||
|
||||
#define N_R0_NVFP4 2
|
||||
#define N_SG_NVFP4 2
|
||||
|
||||
#define N_R0_Q2_K 4
|
||||
#define N_SG_Q2_K 2
|
||||
|
||||
|
|
|
|||
|
|
@ -1970,6 +1970,7 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
|
|||
op->src[0]->type == GGML_TYPE_Q5_1 ||
|
||||
op->src[0]->type == GGML_TYPE_Q8_0 ||
|
||||
op->src[0]->type == GGML_TYPE_MXFP4 ||
|
||||
op->src[0]->type == GGML_TYPE_NVFP4 ||
|
||||
op->src[0]->type == GGML_TYPE_IQ4_NL ||
|
||||
false) && (ne11 >= 2 && ne11 <= 8)
|
||||
) ||
|
||||
|
|
|
|||
|
|
@ -50,6 +50,26 @@ constexpr constant static float kvalues_mxfp4_f[16] = {
|
|||
0, .5f, 1.f, 1.5f, 2.f, 3.f, 4.f, 6.f, -0, -.5f, -1.f, -1.5f, -2.f, -3.f, -4.f, -6.f
|
||||
};
|
||||
|
||||
// UE4M3 scale lookup table (128 entries, unsigned E4M3 with bias=7)
|
||||
constexpr constant static float kvalues_ue4m3_f[128] = {
|
||||
0.f, 1.953125e-03f, 3.906250e-03f, 5.859375e-03f, 7.812500e-03f, 9.765625e-03f, 1.171875e-02f, 1.367188e-02f,
|
||||
1.562500e-02f, 1.757813e-02f, 1.953125e-02f, 2.148438e-02f, 2.343750e-02f, 2.539063e-02f, 2.734375e-02f, 2.929688e-02f,
|
||||
3.125000e-02f, 3.515625e-02f, 3.906250e-02f, 4.296875e-02f, 4.687500e-02f, 5.078125e-02f, 5.468750e-02f, 5.859375e-02f,
|
||||
6.250000e-02f, 7.031250e-02f, 7.812500e-02f, 8.593750e-02f, 9.375000e-02f, 1.015625e-01f, 1.093750e-01f, 1.171875e-01f,
|
||||
1.250000e-01f, 1.406250e-01f, 1.562500e-01f, 1.718750e-01f, 1.875000e-01f, 2.031250e-01f, 2.187500e-01f, 2.343750e-01f,
|
||||
2.500000e-01f, 2.812500e-01f, 3.125000e-01f, 3.437500e-01f, 3.750000e-01f, 4.062500e-01f, 4.375000e-01f, 4.687500e-01f,
|
||||
5.000000e-01f, 5.625000e-01f, 6.250000e-01f, 6.875000e-01f, 7.500000e-01f, 8.125000e-01f, 8.750000e-01f, 9.375000e-01f,
|
||||
1.000000e+00f, 1.125000e+00f, 1.250000e+00f, 1.375000e+00f, 1.500000e+00f, 1.625000e+00f, 1.750000e+00f, 1.875000e+00f,
|
||||
2.000000e+00f, 2.250000e+00f, 2.500000e+00f, 2.750000e+00f, 3.000000e+00f, 3.250000e+00f, 3.500000e+00f, 3.750000e+00f,
|
||||
4.000000e+00f, 4.500000e+00f, 5.000000e+00f, 5.500000e+00f, 6.000000e+00f, 6.500000e+00f, 7.000000e+00f, 7.500000e+00f,
|
||||
8.000000e+00f, 9.000000e+00f, 1.000000e+01f, 1.100000e+01f, 1.200000e+01f, 1.300000e+01f, 1.400000e+01f, 1.500000e+01f,
|
||||
1.600000e+01f, 1.800000e+01f, 2.000000e+01f, 2.200000e+01f, 2.400000e+01f, 2.600000e+01f, 2.800000e+01f, 3.000000e+01f,
|
||||
3.200000e+01f, 3.600000e+01f, 4.000000e+01f, 4.400000e+01f, 4.800000e+01f, 5.200000e+01f, 5.600000e+01f, 6.000000e+01f,
|
||||
6.400000e+01f, 7.200000e+01f, 8.000000e+01f, 8.800000e+01f, 9.600000e+01f, 1.040000e+02f, 1.120000e+02f, 1.200000e+02f,
|
||||
1.280000e+02f, 1.440000e+02f, 1.600000e+02f, 1.760000e+02f, 1.920000e+02f, 2.080000e+02f, 2.240000e+02f, 2.400000e+02f,
|
||||
2.560000e+02f, 2.880000e+02f, 3.200000e+02f, 3.520000e+02f, 3.840000e+02f, 4.160000e+02f, 4.480000e+02f, 0.f,
|
||||
};
|
||||
|
||||
static inline int best_index_int8(int n, constant float * val, float x) {
|
||||
if (x <= val[0]) return 0;
|
||||
if (x >= val[n-1]) return n-1;
|
||||
|
|
@ -73,6 +93,32 @@ static inline float e8m0_to_fp32(uint8_t x) {
|
|||
return as_type<float>(bits);
|
||||
}
|
||||
|
||||
// UE4M3: unsigned, 4 exp bits (bias=7), 3 mantissa bits
|
||||
// Branchless using additive bias to avoid FP32 denormals on GPU
|
||||
static inline float ue4m3_to_fp32(uint8_t x) {
|
||||
// Add bias of 16 to exponent to keep all values in FP32 normal range
|
||||
// UE4M3 bits: [6:3]=exp, [2:0]=man
|
||||
// With biased exp: FP32 value = (1 + man/8) * 2^(exp+16-127) for normal
|
||||
// We want (1 + man/8) * 2^(exp-7), so divide by 2^(16-127+7) = multiply by 2^(127-23) = 2^104
|
||||
// But for exp=0 (subnormal): value should be man * 2^(-9)
|
||||
// With bias: we get (1 + man/8) * 2^(16-127) which is wrong for subnormals
|
||||
// So this approach doesn't handle subnormals correctly.
|
||||
// Since UE4M3 subnormals (exp=0, man=1..7) represent tiny values (max 7*2^-9 ≈ 0.0137),
|
||||
// and these are scale factors, they're extremely rare in practice.
|
||||
// Use select to handle the zero case, keep branches for subnormal.
|
||||
if (x == 0) {
|
||||
return 0.0f;
|
||||
}
|
||||
int exp = (x >> 3) & 0xF;
|
||||
int man = x & 0x7;
|
||||
if (exp == 0) {
|
||||
return ldexp((float)man, -9);
|
||||
}
|
||||
// Normal: construct FP32 directly. exp-7+127 = exp+120
|
||||
uint32_t bits = (uint32_t)(exp + 120) << 23 | (uint32_t)man << 20;
|
||||
return as_type<float>(bits);
|
||||
}
|
||||
|
||||
static inline float dot(float x, float y) {
|
||||
return x*y;
|
||||
}
|
||||
|
|
@ -557,6 +603,49 @@ void dequantize_mxfp4_t4(device const block_mxfp4 * xb, short il, thread type4 &
|
|||
reg[3] = d * kvalues_mxfp4_f[(q2[4*il4 + 3] >> shr) & 0x0F];
|
||||
}
|
||||
|
||||
// NVFP4 dequantize: block_nvfp4 has 64 elements = 4 sub-blocks of 16
|
||||
// nl=4, il=0..3 selects sub-block. Each call produces 4x4 = 16 elements.
|
||||
// CPU layout per sub-block: low nibbles of 8 bytes -> elements 0..7,
|
||||
// high nibbles -> elements 8..15
|
||||
template <typename type4x4>
|
||||
void dequantize_nvfp4(device const block_nvfp4 * xb, short il, thread type4x4 & reg) {
|
||||
device const uint8_t * qs = xb->qs + il * (QK_NVFP4_SUB/2); // 8 bytes per sub-block
|
||||
const float d = kvalues_ue4m3_f[xb->d[il]];
|
||||
|
||||
// rows 0-1: low nibbles (elements 0..7)
|
||||
for (int i = 0; i < 2; ++i) {
|
||||
reg[i][0] = d * kvalues_mxfp4_f[qs[4*i + 0] & 0x0F];
|
||||
reg[i][1] = d * kvalues_mxfp4_f[qs[4*i + 1] & 0x0F];
|
||||
reg[i][2] = d * kvalues_mxfp4_f[qs[4*i + 2] & 0x0F];
|
||||
reg[i][3] = d * kvalues_mxfp4_f[qs[4*i + 3] & 0x0F];
|
||||
}
|
||||
// rows 2-3: high nibbles (elements 8..15)
|
||||
for (int i = 0; i < 2; ++i) {
|
||||
reg[i+2][0] = d * kvalues_mxfp4_f[qs[4*i + 0] >> 4];
|
||||
reg[i+2][1] = d * kvalues_mxfp4_f[qs[4*i + 1] >> 4];
|
||||
reg[i+2][2] = d * kvalues_mxfp4_f[qs[4*i + 2] >> 4];
|
||||
reg[i+2][3] = d * kvalues_mxfp4_f[qs[4*i + 3] >> 4];
|
||||
}
|
||||
}
|
||||
|
||||
// t4 variant: chpb=16, il=0..15, each call produces 4 elements.
|
||||
// Within each sub-block (4 calls): lo-row0, lo-row1, hi-row0, hi-row1
|
||||
template <typename type4>
|
||||
void dequantize_nvfp4_t4(device const block_nvfp4 * xb, short il, thread type4 & reg) {
|
||||
const short sub = il / 4; // sub-block index 0..3
|
||||
const short rem = il % 4; // 0..3 within sub-block
|
||||
const short row = rem % 2; // byte row: 0 = bytes 0..3, 1 = bytes 4..7
|
||||
const uint8_t shr = rem >= 2 ? 4 : 0; // 0,1 = low nibble; 2,3 = high nibble
|
||||
|
||||
device const uint8_t * qs = xb->qs + sub * (QK_NVFP4_SUB/2) + row * 4;
|
||||
const float d = kvalues_ue4m3_f[xb->d[sub]];
|
||||
|
||||
reg[0] = d * kvalues_mxfp4_f[(qs[0] >> shr) & 0x0F];
|
||||
reg[1] = d * kvalues_mxfp4_f[(qs[1] >> shr) & 0x0F];
|
||||
reg[2] = d * kvalues_mxfp4_f[(qs[2] >> shr) & 0x0F];
|
||||
reg[3] = d * kvalues_mxfp4_f[(qs[3] >> shr) & 0x0F];
|
||||
}
|
||||
|
||||
template <typename type4x4>
|
||||
void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) {
|
||||
const float d = xb->d;
|
||||
|
|
@ -3518,6 +3607,11 @@ template [[host_name("kernel_mul_mv_ext_mxfp4_f32_r1_3")]] kernel mul_mv_ext_q4
|
|||
template [[host_name("kernel_mul_mv_ext_mxfp4_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_mxfp4, 32, dequantize_mxfp4_t4>;
|
||||
template [[host_name("kernel_mul_mv_ext_mxfp4_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_mxfp4, 32, dequantize_mxfp4_t4>;
|
||||
|
||||
template [[host_name("kernel_mul_mv_ext_nvfp4_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_nvfp4, 64, dequantize_nvfp4_t4>;
|
||||
template [[host_name("kernel_mul_mv_ext_nvfp4_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_nvfp4, 64, dequantize_nvfp4_t4>;
|
||||
template [[host_name("kernel_mul_mv_ext_nvfp4_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_nvfp4, 64, dequantize_nvfp4_t4>;
|
||||
template [[host_name("kernel_mul_mv_ext_nvfp4_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_nvfp4, 64, dequantize_nvfp4_t4>;
|
||||
|
||||
template [[host_name("kernel_mul_mv_ext_iq4_nl_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_iq4_nl, 32, dequantize_iq4_nl_t4>;
|
||||
template [[host_name("kernel_mul_mv_ext_iq4_nl_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_iq4_nl, 32, dequantize_iq4_nl_t4>;
|
||||
template [[host_name("kernel_mul_mv_ext_iq4_nl_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_iq4_nl, 32, dequantize_iq4_nl_t4>;
|
||||
|
|
@ -8563,6 +8657,106 @@ kernel void kernel_mul_mv_mxfp4_f32(
|
|||
kernel_mul_mv_mxfp4_f32_impl<N_R0_MXFP4, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
|
||||
}
|
||||
|
||||
template<int NR0, typename args_t>
|
||||
void kernel_mul_mv_nvfp4_f32_impl(
|
||||
args_t args,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
threadgroup char * shmem,
|
||||
uint3 tgpig,
|
||||
ushort tiisg,
|
||||
ushort sgitg) {
|
||||
const short NSG = FC_mul_mv_nsg;
|
||||
|
||||
threadgroup float * shmem_f32 = (threadgroup float *) shmem;
|
||||
|
||||
const int r0 = tgpig.x;
|
||||
const int r1 = tgpig.y;
|
||||
const int im = tgpig.z;
|
||||
|
||||
const int first_row = (r0 * NSG + sgitg) * NR0;
|
||||
|
||||
const uint i12 = im%args.ne12;
|
||||
const uint i13 = im/args.ne12;
|
||||
|
||||
const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
|
||||
const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
|
||||
|
||||
device const block_nvfp4 * x = (device const block_nvfp4 *) (src0 + offset0);
|
||||
device const float * y = (device const float *) (src1 + offset1);
|
||||
|
||||
const int nb = args.ne00/QK_NVFP4;
|
||||
const int ns01 = args.nb01/args.nb00;
|
||||
|
||||
const short ix = tiisg/2; // 0...15
|
||||
const short it = tiisg%2; // 0 or 1
|
||||
|
||||
shmem_f32[tiisg] = kvalues_mxfp4_f[tiisg%16];
|
||||
|
||||
// UE4M3 scale LUT: 128 entries, loaded cooperatively by all threads
|
||||
threadgroup float * ue4m3_lut = shmem_f32 + 32;
|
||||
ue4m3_lut[sgitg*32 + tiisg] = ue4m3_to_fp32(sgitg*32 + tiisg);
|
||||
ue4m3_lut[sgitg*32 + tiisg + 64] = ue4m3_to_fp32(sgitg*32 + tiisg + 64);
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
float sumf[NR0] = {0.f};
|
||||
|
||||
// each thread handles 4 bytes per sub-block (half of the 8 bytes)
|
||||
// lo nibbles -> y at offset it*4, hi nibbles -> y at offset 8+it*4
|
||||
device const float * yb = y + ix*QK_NVFP4;
|
||||
|
||||
float4 yl[8]; // pre-loaded y: 4 sub-blocks × {lo, hi}
|
||||
|
||||
for (int ib = ix; ib < nb && ib < ns01; ib += 16) {
|
||||
// pre-load y into registers before the row loop
|
||||
FOR_UNROLL (short s = 0; s < 4; s++) {
|
||||
yl[2*s + 0] = *(device const float4 *)(yb + s*QK_NVFP4_SUB + it*4);
|
||||
yl[2*s + 1] = *(device const float4 *)(yb + s*QK_NVFP4_SUB + QK_NVFP4_SUB/2 + it*4);
|
||||
}
|
||||
|
||||
FOR_UNROLL (short row = 0; row < NR0; row++) {
|
||||
device const block_nvfp4 & xb = x[row*ns01 + ib];
|
||||
|
||||
float sub_sum = 0.0f;
|
||||
FOR_UNROLL (short s = 0; s < 4; s++) {
|
||||
device const uint8_t * qs = (device const uint8_t *)(xb.qs + s*(QK_NVFP4_SUB/2) + it*4);
|
||||
|
||||
float4 acc = yl[2*s]*float4(shmem_f32[qs[0] & 0x0F], shmem_f32[qs[1] & 0x0F], shmem_f32[qs[2] & 0x0F], shmem_f32[qs[3] & 0x0F])
|
||||
+ yl[2*s+1]*float4(shmem_f32[qs[0] >> 4], shmem_f32[qs[1] >> 4], shmem_f32[qs[2] >> 4], shmem_f32[qs[3] >> 4]);
|
||||
|
||||
sub_sum += ue4m3_lut[xb.d[s]] * ((acc[0] + acc[1]) + (acc[2] + acc[3]));
|
||||
}
|
||||
sumf[row] += sub_sum;
|
||||
}
|
||||
|
||||
yb += 16 * QK_NVFP4;
|
||||
}
|
||||
|
||||
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
|
||||
|
||||
for (int row = 0; row < NR0 && first_row + row < args.ne0; ++row) {
|
||||
float sum_all = simd_sum(sumf[row]);
|
||||
if (tiisg == 0) {
|
||||
dst_f32[first_row + row] = sum_all;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
[[host_name("kernel_mul_mv_nvfp4_f32")]]
|
||||
kernel void kernel_mul_mv_nvfp4_f32(
|
||||
constant ggml_metal_kargs_mul_mv & args,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
threadgroup char * shmem [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort tiisg[[thread_index_in_simdgroup]],
|
||||
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
|
||||
kernel_mul_mv_nvfp4_f32_impl<N_R0_NVFP4, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
|
||||
}
|
||||
|
||||
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
|
||||
kernel void kernel_get_rows_q(
|
||||
constant ggml_metal_kargs_get_rows & args,
|
||||
|
|
@ -9411,6 +9605,7 @@ template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_q_t kernel_get
|
|||
template [[host_name("kernel_get_rows_q5_1")]] kernel get_rows_q_t kernel_get_rows_q<block_q5_1, 2, dequantize_q5_1>;
|
||||
template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q8_0, 2, dequantize_q8_0>;
|
||||
template [[host_name("kernel_get_rows_mxfp4")]] kernel get_rows_q_t kernel_get_rows_q<block_mxfp4, 2, dequantize_mxfp4>;
|
||||
template [[host_name("kernel_get_rows_nvfp4")]] kernel get_rows_q_t kernel_get_rows_q<block_nvfp4, 4, dequantize_nvfp4>;
|
||||
template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q2_K, QK_NL, dequantize_q2_K>;
|
||||
template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q3_K, QK_NL, dequantize_q3_K>;
|
||||
template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q4_K, QK_NL, dequantize_q4_K>;
|
||||
|
|
@ -9473,6 +9668,7 @@ template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mul_mm_t kernel_mul_m
|
|||
template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_1, 2, dequantize_q5_1, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_mxfp4_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_mxfp4, 2, dequantize_mxfp4, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_nvfp4_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_nvfp4, 4, dequantize_nvfp4, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q3_K, QK_NL, dequantize_q3_K, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K, float, float4x4, float, float2x4>;
|
||||
|
|
@ -9496,6 +9692,7 @@ template [[host_name("kernel_mul_mm_q5_0_f16")]] kernel mul_mm_t kernel_mul_m
|
|||
template [[host_name("kernel_mul_mm_q5_1_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_1, 2, dequantize_q5_1, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_q8_0_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_mxfp4_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_mxfp4, 2, dequantize_mxfp4, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_nvfp4_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_nvfp4, 4, dequantize_nvfp4, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_q2_K_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_q3_K_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q3_K, QK_NL, dequantize_q3_K, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_q4_K_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K, float, float4x4, half, half2x4>;
|
||||
|
|
@ -9528,6 +9725,7 @@ template [[host_name("kernel_mul_mm_id_q5_0_f32")]] kernel mul_mm_id kernel_m
|
|||
template [[host_name("kernel_mul_mm_id_q5_1_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_1, 2, dequantize_q5_1, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_q8_0_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_mxfp4_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_mxfp4, 2, dequantize_mxfp4, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_nvfp4_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_nvfp4, 4, dequantize_nvfp4, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_q2_K_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_q3_K_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q3_K, QK_NL, dequantize_q3_K, float, float4x4, float, float2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K, float, float4x4, float, float2x4>;
|
||||
|
|
@ -9551,6 +9749,7 @@ template [[host_name("kernel_mul_mm_id_q5_0_f16")]] kernel mul_mm_id kernel_m
|
|||
template [[host_name("kernel_mul_mm_id_q5_1_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_1, 2, dequantize_q5_1, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_q8_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_mxfp4_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_mxfp4, 2, dequantize_mxfp4, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_nvfp4_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_nvfp4, 4, dequantize_nvfp4, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_q2_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_q3_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q3_K, QK_NL, dequantize_q3_K, float, float4x4, half, half2x4>;
|
||||
template [[host_name("kernel_mul_mm_id_q4_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K, float, float4x4, half, half2x4>;
|
||||
|
|
@ -9706,6 +9905,8 @@ template [[host_name("kernel_mul_mv_id_q5_1_f32")]] kernel kernel_mul_mv_id_t
|
|||
|
||||
template [[host_name("kernel_mul_mv_id_mxfp4_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_mxfp4_f32_impl<N_R0_MXFP4>>>;
|
||||
|
||||
template [[host_name("kernel_mul_mv_id_nvfp4_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_nvfp4_f32_impl<N_R0_NVFP4>>>;
|
||||
|
||||
template [[host_name("kernel_mul_mv_id_q2_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q2_K_f32_impl <N_R0_Q2_K>>>;
|
||||
template [[host_name("kernel_mul_mv_id_q3_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q3_K_f32_impl <N_R0_Q3_K>>>;
|
||||
template [[host_name("kernel_mul_mv_id_q4_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q4_K_f32_impl <N_R0_Q4_K>>>;
|
||||
|
|
|
|||
Loading…
Reference in New Issue