|
|
|
|
@ -118,6 +118,56 @@ void dequantize_bf16_t4(device const bfloat4 * src, short il, thread type4 & reg
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
// Dequantize one 16-weight slice of a Q1_0 block into a 4x4 register tile.
//
//   xb  - source block; the scale `d` and two bytes of `qs` are read
//   il  - slice index; slice `il` starts at byte il*2 of `qs` (il*16 bits)
//   reg - destination tile; bit k of the slice lands in reg[k/4][k%4]
//
// Every stored bit only selects the sign of the block scale: 1 -> +d, 0 -> -d.
template <typename type4x4>
void dequantize_q1_0(device const block_q1_0 * xb, short il, thread type4x4 & reg) {
    const float scale = xb->d;

    // the two bytes of this slice, merged LSB-first into one 16-bit pattern
    device const uint8_t * slice = xb->qs + il * 2;
    const uint bits = uint(slice[0]) | (uint(slice[1]) << 8);

    float4x4 tile;
    for (short k = 0; k < 16; ++k) {
        tile[k / 4][k % 4] = ((bits >> k) & 1) ? scale : -scale;
    }

    reg = (type4x4) tile;
}
|
|
|
|
|
|
|
|
|
|
// Dequantize four consecutive weights of a Q1_0 block into a 4-component vector.
//
//   xb  - source block; the scale `d` and one byte of `qs` are read
//   il  - group index; group `il` covers bits [4*il, 4*il + 4) of `qs`
//   reg - destination vector; component k receives bit (4*il + k)
//
// A set bit yields +d, a cleared bit yields -d.
template <typename type4>
void dequantize_q1_0_t4(device const block_q1_0 * xb, short il, thread type4 & reg) {
    const float scale = xb->d;

    const int first_bit = il * 4;
    const int shift     = first_bit % 8;

    // byte containing the group, shifted so the group's first bit sits at bit 0
    const uint window = uint(xb->qs[first_bit / 8]) >> shift;

    float4 vals;
    for (short k = 0; k < 4; ++k) {
        vals[k] = ((window >> k) & 1) ? scale : -scale;
    }

    reg = (type4) vals;
}
|
|
|
|
|
|
|
|
|
|
template <typename type4x4>
|
|
|
|
|
void dequantize_q4_0(device const block_q4_0 * xb, short il, thread type4x4 & reg) {
|
|
|
|
|
device const uint16_t * qs = ((device const uint16_t *)xb + 1);
|
|
|
|
|
@ -152,6 +202,23 @@ void dequantize_q4_0_t4(device const block_q4_0 * xb, short il, thread type4 & r
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Quantize QK1_0 floats from `src` into a single Q1_0 block.
//
//   dst.d  - mean absolute value of the inputs
//   dst.qs - one bit per input, LSB-first within each byte: 1 if the input is >= 0, else 0
//
// NOTE(review): assumes QK1_0 is a multiple of 8 so that QK1_0/8 bytes hold every bit - confirm.
void quantize_q1_0(device const float * src, device block_q1_0 & dst) {
    float abs_total = 0.0f;
    for (int i = 0; i < QK1_0; ++i) {
        abs_total += fabs(src[i]);
    }
    dst.d = abs_total / QK1_0;

    // build each byte in a local and store it once
    for (int byte_idx = 0; byte_idx < QK1_0 / 8; ++byte_idx) {
        uint8_t packed = 0;
        for (int bit = 0; bit < 8; ++bit) {
            if (src[byte_idx * 8 + bit] >= 0.0f) {
                packed |= (1 << bit);
            }
        }
        dst.qs[byte_idx] = packed;
    }
}
|
|
|
|
|
|
|
|
|
|
void quantize_q4_0(device const float * src, device block_q4_0 & dst) {
|
|
|
|
|
#pragma METAL fp math_mode(safe)
|
|
|
|
|
float amax = 0.0f; // absolute max
|
|
|
|
|
@ -3116,6 +3183,35 @@ kernel void kernel_group_norm_f32(
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Inner product between 16 weights of a Q1_0 block and 16 floats (yl).
//
//   qb_curr - block to read; `d` is the scale, `qs` holds one sign bit per weight
//   sumy    - SUM(yl[i]) over the same 16 values, supplied by the caller
//   yl      - the 16 activations
//   il      - index of the first weight inside the block; il/8 is the first byte of `qs` used
//
// Weights are +d (bit set) or -d (bit clear), so
//   dot = d * (SUM(yl[i] | bit=1) - SUM(yl[i] | bit=0)) = d * (2 * SUM(yl[i] | bit=1) - sumy)
inline float block_q_n_dot_y(device const block_q1_0 * qb_curr, float sumy, thread float * yl, int il) {
    device const uint8_t * sign_bytes = qb_curr->qs + il / 8;

    // 16 sign bits, LSB-first: low byte -> yl[0..7], high byte -> yl[8..15]
    const uint bits = uint(sign_bytes[0]) | (uint(sign_bytes[1]) << 8);

    // sum of the activations whose weight bit is set (accumulated in index order)
    float set_sum = 0.0f;

    set_sum += (bits & 0x0001u) ? yl[ 0] : 0.0f;
    set_sum += (bits & 0x0002u) ? yl[ 1] : 0.0f;
    set_sum += (bits & 0x0004u) ? yl[ 2] : 0.0f;
    set_sum += (bits & 0x0008u) ? yl[ 3] : 0.0f;
    set_sum += (bits & 0x0010u) ? yl[ 4] : 0.0f;
    set_sum += (bits & 0x0020u) ? yl[ 5] : 0.0f;
    set_sum += (bits & 0x0040u) ? yl[ 6] : 0.0f;
    set_sum += (bits & 0x0080u) ? yl[ 7] : 0.0f;

    set_sum += (bits & 0x0100u) ? yl[ 8] : 0.0f;
    set_sum += (bits & 0x0200u) ? yl[ 9] : 0.0f;
    set_sum += (bits & 0x0400u) ? yl[10] : 0.0f;
    set_sum += (bits & 0x0800u) ? yl[11] : 0.0f;
    set_sum += (bits & 0x1000u) ? yl[12] : 0.0f;
    set_sum += (bits & 0x2000u) ? yl[13] : 0.0f;
    set_sum += (bits & 0x4000u) ? yl[14] : 0.0f;
    set_sum += (bits & 0x8000u) ? yl[15] : 0.0f;

    return qb_curr->d * (2.0f * set_sum - sumy);
}
|
|
|
|
|
|
|
|
|
|
// function for calculate inner product between half a q4_0 block and 16 floats (yl), sumy is SUM(yl[i])
|
|
|
|
|
// il indicates where the q4 quants begin (0 or QK4_0/4)
|
|
|
|
|
// we assume that the yl's have been multiplied with the appropriate scale factor
|
|
|
|
|
@ -3337,6 +3433,85 @@ void mul_vec_q_n_f32_impl(
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Matrix-vector product: Q1_0 weights (src0) times f32 activations (src1), f32 output (dst).
//
// Work split:
//   - each simdgroup produces nr0 consecutive output rows starting at `first_row`
//   - inside a simdgroup, 8 lanes share one Q1_0 block (16 weights per lane), so
//     N_SIMDWIDTH/8 blocks are consumed per loop iteration
//   - per-lane partial sums are combined with simd_sum and lane 0 stores the result
//
//   args   - kernel arguments (ne00, ne01, ne0, ne1, ne12, nb01..nb03, nb11..nb13, r2, r3 are read)
//   src0   - quantized matrix, addressed as block_q1_0 rows via the nb0x byte strides
//   src1   - f32 input, addressed via the nb1x byte strides
//   dst    - f32 output
//   shmem  - not used by this implementation
//   tgpig  - threadgroup position: x -> row group, y -> src1 row, z -> batch index
//   tiisg  - lane index within the simdgroup
//   sgitg  - simdgroup index within the threadgroup
template<int nr0, typename args_t>
void kernel_mul_mv_q1_0_f32_impl(
        args_t args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        threadgroup  char * shmem,
        uint3  tgpig,
        ushort tiisg,
        ushort sgitg) {
    const short NSG = FC_mul_mv_nsg;

    // number of Q1_0 blocks per row
    const int nb = args.ne00/QK1_0;

    const int r0 = tgpig.x;
    const int r1 = tgpig.y;
    const int im = tgpig.z;

    const int first_row = (r0 * NSG + sgitg) * nr0;

    const uint i12 = im%args.ne12;
    const uint i13 = im/args.ne12;

    const uint64_t offset1 = r1*args.nb11 + (i12)*args.nb12 + (i13)*args.nb13;

    device const float * y = (device const float *) (src1 + offset1);

    // base pointer of each row handled by this simdgroup
    device const block_q1_0 * ax[nr0];
    for (int row = 0; row < nr0; ++row) {
        const uint64_t offset0 = (first_row + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;

        // src0 is already `device const char *`: offset it directly instead of casting const away
        ax[row] = (device const block_q1_0 *) (src0 + offset0);
    }

    float yl[16];
    float sumf[nr0] = {0.f};

    const short ix = (tiisg/8);    // which block of the current batch this lane works on
    const short il = (tiisg%8)*16; // first weight inside that block handled by this lane

    device const float * yb = y + ix*QK1_0 + il;

    for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/8) {
        float sumy = 0.f;

        // cache the 16 activations in registers and accumulate their sum for block_q_n_dot_y
        FOR_UNROLL (short i = 0; i < 16; i++) {
            yl[i] = yb[i];
            sumy += yl[i];
        }

        FOR_UNROLL (short row = 0; row < nr0; row++) {
            sumf[row] += block_q_n_dot_y(ax[row] + ib, sumy, yl, il);
        }

        yb += QK1_0 * (N_SIMDWIDTH/8);
    }

    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;

    for (int row = 0; row < nr0; ++row) {
        const float tot = simd_sum(sumf[row]);

        // only lane 0 writes, and only for rows that exist in src0
        if (tiisg == 0 && first_row + row < args.ne01) {
            dst_f32[first_row + row] = tot;
        }
    }
}
|
|
|
|
|
|
|
|
|
|
// Kernel entry point for the Q1_0 x f32 matrix-vector product.
// Forwards to kernel_mul_mv_q1_0_f32_impl with N_R0_Q1_0 rows per simdgroup and no threadgroup memory.
[[host_name("kernel_mul_mv_q1_0_f32")]]
kernel void kernel_mul_mv_q1_0_f32(
        constant ggml_metal_kargs_mul_mv & args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        uint3  tgpig[[threadgroup_position_in_grid]],
        ushort tiisg[[thread_index_in_simdgroup]],
        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
    kernel_mul_mv_q1_0_f32_impl<N_R0_Q1_0, constant ggml_metal_kargs_mul_mv &>(
            args,
            src0,
            src1,
            dst,
            nullptr, // shmem
            tgpig,
            tiisg,
            sgitg);
}
|
|
|
|
|
|
|
|
|
|
kernel void kernel_mul_mv_q4_0_f32(
|
|
|
|
|
constant ggml_metal_kargs_mul_mv & args,
|
|
|
|
|
device const char * src0,
|
|
|
|
|
@ -3729,6 +3904,11 @@ template [[host_name("kernel_mul_mv_ext_bf16_f32_r1_4")]] kernel mul_mv_ext_q4
|
|
|
|
|
template [[host_name("kernel_mul_mv_ext_bf16_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, bfloat4, 4, dequantize_bf16_t4>;
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
// Q1_0 instantiations of kernel_mul_mv_ext_q4_f32_disp, one per first template argument 2..5
// (matching the _r1_N suffix of the host name); weights are expanded with dequantize_q1_0_t4.
// NOTE(review): 128 is presumably the number of weights per block_q1_0 (QK1_0) - confirm against the block definition.
template [[host_name("kernel_mul_mv_ext_q1_0_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q1_0, 128, dequantize_q1_0_t4>;
|
|
|
|
|
template [[host_name("kernel_mul_mv_ext_q1_0_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q1_0, 128, dequantize_q1_0_t4>;
|
|
|
|
|
template [[host_name("kernel_mul_mv_ext_q1_0_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q1_0, 128, dequantize_q1_0_t4>;
|
|
|
|
|
template [[host_name("kernel_mul_mv_ext_q1_0_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_q1_0, 128, dequantize_q1_0_t4>;
|
|
|
|
|
|
|
|
|
|
template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q4_0, 32, dequantize_q4_0_t4>;
|
|
|
|
|
template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q4_0, 32, dequantize_q4_0_t4>;
|
|
|
|
|
template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q4_0, 32, dequantize_q4_0_t4>;
|
|
|
|
|
@ -7133,6 +7313,7 @@ kernel void kernel_cpy_f32_q(
|
|
|
|
|
typedef decltype(kernel_cpy_f32_q<QK8_0, block_q8_0, quantize_q8_0>) cpy_f_q_t;
|
|
|
|
|
|
|
|
|
|
template [[host_name("kernel_cpy_f32_q8_0")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK8_0, block_q8_0, quantize_q8_0>;
|
|
|
|
|
// f32 -> Q1_0 copy: kernel_cpy_f32_q instantiated with QK1_0, block_q1_0 and quantize_q1_0.
template [[host_name("kernel_cpy_f32_q1_0")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK1_0, block_q1_0, quantize_q1_0>;
|
|
|
|
|
template [[host_name("kernel_cpy_f32_q4_0")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK4_0, block_q4_0, quantize_q4_0>;
|
|
|
|
|
template [[host_name("kernel_cpy_f32_q4_1")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK4_1, block_q4_1, quantize_q4_1>;
|
|
|
|
|
template [[host_name("kernel_cpy_f32_q5_0")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK5_0, block_q5_0, quantize_q5_0>;
|
|
|
|
|
@ -7173,12 +7354,14 @@ kernel void kernel_cpy_q_f32(
|
|
|
|
|
|
|
|
|
|
typedef decltype(kernel_cpy_q_f32<float4x4, block_q4_0, 2, dequantize_q4_0>) cpy_q_f_t;
|
|
|
|
|
|
|
|
|
|
// Q1_0 -> f32 copy: kernel_cpy_q_f32 with a float4x4 tile, expanded by dequantize_q1_0.
// NOTE(review): 8 is presumably the number of 16-weight slices per block (dequantize_q1_0 consumes 2 bytes of qs per call) - confirm.
template [[host_name("kernel_cpy_q1_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q1_0, 8, dequantize_q1_0>;
|
|
|
|
|
template [[host_name("kernel_cpy_q4_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q4_0, 2, dequantize_q4_0>;
|
|
|
|
|
template [[host_name("kernel_cpy_q4_1_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q4_1, 2, dequantize_q4_1>;
|
|
|
|
|
template [[host_name("kernel_cpy_q5_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q5_0, 2, dequantize_q5_0>;
|
|
|
|
|
template [[host_name("kernel_cpy_q5_1_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q5_1, 2, dequantize_q5_1>;
|
|
|
|
|
template [[host_name("kernel_cpy_q8_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q8_0, 2, dequantize_q8_0>;
|
|
|
|
|
|
|
|
|
|
// Q1_0 -> f16 copy: same kernel_cpy_q_f32 template, instantiated with a half4x4 tile and dequantize_q1_0.
template [[host_name("kernel_cpy_q1_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q1_0, 8, dequantize_q1_0>;
|
|
|
|
|
template [[host_name("kernel_cpy_q4_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q4_0, 2, dequantize_q4_0>;
|
|
|
|
|
template [[host_name("kernel_cpy_q4_1_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q4_1, 2, dequantize_q4_1>;
|
|
|
|
|
template [[host_name("kernel_cpy_q5_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q5_0, 2, dequantize_q5_0>;
|
|
|
|
|
@ -9776,6 +9959,7 @@ template [[host_name("kernel_get_rows_bf16")]] kernel get_rows_f_t kernel_get_ro
|
|
|
|
|
|
|
|
|
|
typedef decltype(kernel_get_rows_q<block_q4_0, 2, dequantize_q4_0>) get_rows_q_t;
|
|
|
|
|
|
|
|
|
|
// Row gather for Q1_0 tensors: kernel_get_rows_q instantiated with block_q1_0, 8 and dequantize_q1_0.
template [[host_name("kernel_get_rows_q1_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q1_0, 8, dequantize_q1_0>;
|
|
|
|
|
template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q4_0, 2, dequantize_q4_0>;
|
|
|
|
|
template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_q_t kernel_get_rows_q<block_q4_1, 2, dequantize_q4_1>;
|
|
|
|
|
template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q5_0, 2, dequantize_q5_0>;
|
|
|
|
|
@ -9838,6 +10022,7 @@ template [[host_name("kernel_mul_mm_f16_f32")]] kernel mul_mm_t kernel_mul_m
|
|
|
|
|
#if defined(GGML_METAL_HAS_BF16)
|
|
|
|
|
template [[host_name("kernel_mul_mm_bf16_f32")]] kernel mul_mm_t kernel_mul_mm<bfloat, bfloat4x4, simdgroup_bfloat8x8, bfloat, bfloat2x4, simdgroup_bfloat8x8, bfloat4x4, 1, dequantize_bf16, bfloat, bfloat4x4, float, float2x4>;
|
|
|
|
|
#endif
|
|
|
|
|
// Matrix-matrix product with Q1_0 src0 and f32 src1 (trailing template arguments float, float2x4);
// src0 blocks are expanded with dequantize_q1_0.
template [[host_name("kernel_mul_mm_q1_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q1_0, 8, dequantize_q1_0, float, float4x4, float, float2x4>;
|
|
|
|
|
template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0, float, float4x4, float, float2x4>;
|
|
|
|
|
template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1, float, float4x4, float, float2x4>;
|
|
|
|
|
template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0, float, float4x4, float, float2x4>;
|
|
|
|
|
@ -9861,6 +10046,7 @@ template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mul_mm_t kernel_mul_m
|
|
|
|
|
|
|
|
|
|
template [[host_name("kernel_mul_mm_f32_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, float4x4, 1, dequantize_f32, float, float4x4, half, half2x4>;
|
|
|
|
|
template [[host_name("kernel_mul_mm_f16_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, half4x4, 1, dequantize_f16, half, half4x4, half, half2x4>;
|
|
|
|
|
// Matrix-matrix product with Q1_0 src0 and f16 src1 (trailing template arguments half, half2x4);
// src0 blocks are expanded with dequantize_q1_0.
template [[host_name("kernel_mul_mm_q1_0_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q1_0, 8, dequantize_q1_0, float, float4x4, half, half2x4>;
|
|
|
|
|
template [[host_name("kernel_mul_mm_q4_0_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0, float, float4x4, half, half2x4>;
|
|
|
|
|
template [[host_name("kernel_mul_mm_q4_1_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1, float, float4x4, half, half2x4>;
|
|
|
|
|
template [[host_name("kernel_mul_mm_q5_0_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0, float, float4x4, half, half2x4>;
|
|
|
|
|
@ -10070,6 +10256,7 @@ template [[host_name("kernel_mul_mv_id_bf16_f32_4")]] kernel kernel_mul_mv_id_4
|
|
|
|
|
|
|
|
|
|
template [[host_name("kernel_mul_mv_id_q8_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q8_0_f32_impl<N_R0_Q8_0>>>;
|
|
|
|
|
|
|
|
|
|
// Indirect (mul_mv_id) matrix-vector kernel for Q1_0: wraps kernel_mul_mv_q1_0_f32_impl<N_R0_Q1_0> through mmv_fn.
template [[host_name("kernel_mul_mv_id_q1_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q1_0_f32_impl<N_R0_Q1_0>>>;
|
|
|
|
|
template [[host_name("kernel_mul_mv_id_q4_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_0, N_R0_Q4_0>>>;
|
|
|
|
|
template [[host_name("kernel_mul_mv_id_q4_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_1, N_R0_Q4_1>>>;
|
|
|
|
|
template [[host_name("kernel_mul_mv_id_q5_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_0, N_R0_Q5_0>>>;
|
|
|
|
|
|