diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp index 89539bd761..e8548b053e 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.cpp +++ b/ggml/src/ggml-metal/ggml-metal-device.cpp @@ -736,6 +736,11 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv(ggml_meta suffix = ne00 % 4 == 0 ? "_4" : ""; } } break; + case GGML_TYPE_Q1_0: + { + nsg = N_SG_Q1_0; + nr0 = N_R0_Q1_0; + } break; case GGML_TYPE_Q4_0: { nsg = N_SG_Q4_0; @@ -948,6 +953,11 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_id(ggml_m smem = 32*sizeof(float)*nr0; suffix = ne00 % 4 == 0 ? "_4" : ""; } break; + case GGML_TYPE_Q1_0: + { + nsg = N_SG_Q1_0; + nr0 = N_R0_Q1_0; + } break; case GGML_TYPE_Q4_0: { nsg = N_SG_Q4_0; diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index 17d51b11b6..40cacb4652 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -1184,6 +1184,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te case GGML_TYPE_F16: case GGML_TYPE_BF16: case GGML_TYPE_Q8_0: + case GGML_TYPE_Q1_0: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -1210,6 +1211,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te default: return false; } + case GGML_TYPE_Q1_0: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h index eb2253e029..62b028f4a4 100644 --- a/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -8,6 +8,9 @@ // // TODO: for optimal performance, become function of the device and work size +#define N_R0_Q1_0 8 +#define N_SG_Q1_0 2 + #define N_R0_Q4_0 4 #define N_SG_Q4_0 2 diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index 3cda21be43..846225d907 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -2047,6 +2047,7 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) { op->src[0]->type == GGML_TYPE_F32 || // TODO: helper function op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_BF16 || + op->src[0]->type == GGML_TYPE_Q1_0 || op->src[0]->type == GGML_TYPE_Q4_0 || op->src[0]->type == GGML_TYPE_Q4_1 || op->src[0]->type == GGML_TYPE_Q5_0 || diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 2074211594..f28bfa0b95 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -118,6 +118,56 @@ void dequantize_bf16_t4(device const bfloat4 * src, short il, thread type4 & reg } #endif +template +void dequantize_q1_0(device const block_q1_0 * xb, short il, thread type4x4 & reg) { + device const uint8_t * qs = xb->qs; + const float d = xb->d; + const float neg_d = -d; + + const int byte_offset = il * 2; // il*16 bits = il*2 bytes + const uint8_t b0 = qs[byte_offset]; + const uint8_t b1 = qs[byte_offset + 1]; + + float4x4 reg_f; + + reg_f[0][0] = select(neg_d, d, bool(b0 & 0x01)); + reg_f[0][1] = select(neg_d, d, bool(b0 & 0x02)); + reg_f[0][2] = select(neg_d, d, bool(b0 & 0x04)); + reg_f[0][3] = select(neg_d, d, bool(b0 & 0x08)); + reg_f[1][0] = select(neg_d, d, bool(b0 & 0x10)); + reg_f[1][1] = select(neg_d, d, bool(b0 & 0x20)); + reg_f[1][2] = select(neg_d, d, bool(b0 & 0x40)); + reg_f[1][3] = select(neg_d, d, bool(b0 & 0x80)); + + reg_f[2][0] = select(neg_d, d, bool(b1 & 0x01)); + reg_f[2][1] = select(neg_d, d, bool(b1 & 0x02)); + reg_f[2][2] = select(neg_d, d, bool(b1 & 0x04)); + reg_f[2][3] = select(neg_d, d, bool(b1 & 0x08)); + reg_f[3][0] = select(neg_d, d, bool(b1 & 0x10)); + reg_f[3][1] = select(neg_d, d, bool(b1 & 0x20)); + reg_f[3][2] = select(neg_d, d, bool(b1 & 0x40)); + reg_f[3][3] = select(neg_d, d, bool(b1 & 0x80)); + + reg = (type4x4) reg_f; +} + +template +void dequantize_q1_0_t4(device const block_q1_0 * xb, short il, thread type4 & reg) { + const float d = xb->d; + const float neg_d = -d; + const int base = il * 4; + const uint8_t byte = xb->qs[base / 8]; + const int s = base % 8; + + float4 reg_f; + reg_f[0] = select(neg_d, d, bool((byte >> (s )) & 1)); + reg_f[1] = select(neg_d, d, bool((byte >> (s + 1)) & 1)); + reg_f[2] = select(neg_d, d, bool((byte >> (s + 2)) & 1)); + reg_f[3] = select(neg_d, d, bool((byte >> (s + 3)) & 1)); + + reg = (type4) reg_f; +} + template void dequantize_q4_0(device const block_q4_0 * xb, short il, thread type4x4 & reg) { device const uint16_t * qs = ((device const uint16_t *)xb + 1); @@ -152,6 +202,23 @@ void dequantize_q4_0_t4(device const block_q4_0 * xb, short il, thread type4 & r } } +void quantize_q1_0(device const float * src, device block_q1_0 & dst) { + float sum_abs = 0.0f; + for (int j = 0; j < QK1_0; j++) { + sum_abs += fabs(src[j]); + } + dst.d = sum_abs / QK1_0; + + for (int j = 0; j < QK1_0 / 8; j++) { + dst.qs[j] = 0; + } + for (int j = 0; j < QK1_0; j++) { + if (src[j] >= 0.0f) { + dst.qs[j / 8] |= (1 << (j % 8)); + } + } +} + void quantize_q4_0(device const float * src, device block_q4_0 & dst) { #pragma METAL fp math_mode(safe) float amax = 0.0f; // absolute max @@ -3116,6 +3183,35 @@ kernel void kernel_group_norm_f32( } } +// Q1_0 dot product: dot = d * (2 * Σ(yl[i] where bit=1) - sumy) +inline float block_q_n_dot_y(device const block_q1_0 * qb_curr, float sumy, thread float * yl, int il) { + device const uint8_t * qs = qb_curr->qs + il / 8; + const uint8_t b0 = qs[0]; + const uint8_t b1 = qs[1]; + + float acc = 0.0f; + + acc += select(0.0f, yl[ 0], bool(b0 & 0x01)); + acc += select(0.0f, yl[ 1], bool(b0 & 0x02)); + acc += select(0.0f, yl[ 2], bool(b0 & 0x04)); + acc += select(0.0f, yl[ 3], bool(b0 & 0x08)); + acc += select(0.0f, yl[ 4], bool(b0 & 0x10)); + acc += select(0.0f, yl[ 5], bool(b0 & 0x20)); + acc += select(0.0f, yl[ 6], bool(b0 & 0x40)); + acc += select(0.0f, yl[ 7], bool(b0 & 0x80)); + + acc += select(0.0f, yl[ 8], bool(b1 & 0x01)); + acc += select(0.0f, yl[ 9], bool(b1 & 0x02)); + acc += select(0.0f, yl[10], bool(b1 & 0x04)); + acc += select(0.0f, yl[11], bool(b1 & 0x08)); + acc += select(0.0f, yl[12], bool(b1 & 0x10)); + acc += select(0.0f, yl[13], bool(b1 & 0x20)); + acc += select(0.0f, yl[14], bool(b1 & 0x40)); + acc += select(0.0f, yl[15], bool(b1 & 0x80)); + + return qb_curr->d * (2.0f * acc - sumy); +} + // function for calculate inner product between half a q4_0 block and 16 floats (yl), sumy is SUM(yl[i]) // il indicates where the q4 quants begin (0 or QK4_0/4) // we assume that the yl's have been multiplied with the appropriate scale factor @@ -3337,6 +3433,85 @@ void mul_vec_q_n_f32_impl( } } +template +void kernel_mul_mv_q1_0_f32_impl( + args_t args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem, + uint3 tgpig, + ushort tiisg, + ushort sgitg) { + const short NSG = FC_mul_mv_nsg; + + const int nb = args.ne00/QK1_0; + + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; + + const int first_row = (r0 * NSG + sgitg) * nr0; + + const uint i12 = im%args.ne12; + const uint i13 = im/args.ne12; + + const uint64_t offset1 = r1*args.nb11 + (i12)*args.nb12 + (i13)*args.nb13; + + device const float * y = (device const float *) (src1 + offset1); + + device const block_q1_0 * ax[nr0]; + for (int row = 0; row < nr0; ++row) { + const uint64_t offset0 = (first_row + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + ax[row] = (device const block_q1_0 *) ((device char *) src0 + offset0); + } + + float yl[16]; + float sumf[nr0] = {0.f}; + + const short ix = (tiisg/8); + const short il = (tiisg%8)*16; + + device const float * yb = y + ix*QK1_0 + il; + + for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/8) { + float sumy = 0.f; + + FOR_UNROLL (short i = 0; i < 16; i++) { + yl[i] = yb[i]; + sumy += yb[i]; + } + + FOR_UNROLL (short row = 0; row < nr0; row++) { + sumf[row] += block_q_n_dot_y(ax[row] + ib, sumy, yl, il); + } + + yb += QK1_0 * (N_SIMDWIDTH/8); + } + + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; + + for (int row = 0; row < nr0; ++row) { + const float tot = simd_sum(sumf[row]); + + if (tiisg == 0 && first_row + row < args.ne01) { + dst_f32[first_row + row] = tot; + } + } +} + +[[host_name("kernel_mul_mv_q1_0_f32")]] +kernel void kernel_mul_mv_q1_0_f32( + constant ggml_metal_kargs_mul_mv & args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + kernel_mul_mv_q1_0_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); +} + kernel void kernel_mul_mv_q4_0_f32( constant ggml_metal_kargs_mul_mv & args, device const char * src0, @@ -3729,6 +3904,11 @@ template [[host_name("kernel_mul_mv_ext_bf16_f32_r1_4")]] kernel mul_mv_ext_q4 template [[host_name("kernel_mul_mv_ext_bf16_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, bfloat4, 4, dequantize_bf16_t4>; #endif +template [[host_name("kernel_mul_mv_ext_q1_0_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q1_0, 128, dequantize_q1_0_t4>; +template [[host_name("kernel_mul_mv_ext_q1_0_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q1_0, 128, dequantize_q1_0_t4>; +template [[host_name("kernel_mul_mv_ext_q1_0_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q1_0, 128, dequantize_q1_0_t4>; +template [[host_name("kernel_mul_mv_ext_q1_0_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_q1_0, 128, dequantize_q1_0_t4>; + template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q4_0, 32, dequantize_q4_0_t4>; template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q4_0, 32, dequantize_q4_0_t4>; template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q4_0, 32, dequantize_q4_0_t4>; @@ -7133,6 +7313,7 @@ kernel void kernel_cpy_f32_q( typedef decltype(kernel_cpy_f32_q) cpy_f_q_t; template [[host_name("kernel_cpy_f32_q8_0")]] kernel cpy_f_q_t kernel_cpy_f32_q; +template [[host_name("kernel_cpy_f32_q1_0")]] kernel cpy_f_q_t kernel_cpy_f32_q; template [[host_name("kernel_cpy_f32_q4_0")]] kernel cpy_f_q_t kernel_cpy_f32_q; template [[host_name("kernel_cpy_f32_q4_1")]] kernel cpy_f_q_t kernel_cpy_f32_q; template [[host_name("kernel_cpy_f32_q5_0")]] kernel cpy_f_q_t kernel_cpy_f32_q; @@ -7173,12 +7354,14 @@ kernel void kernel_cpy_q_f32( typedef decltype(kernel_cpy_q_f32) cpy_q_f_t; +template [[host_name("kernel_cpy_q1_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32; template [[host_name("kernel_cpy_q4_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32; template [[host_name("kernel_cpy_q4_1_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32; template [[host_name("kernel_cpy_q5_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32; template [[host_name("kernel_cpy_q5_1_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32; template [[host_name("kernel_cpy_q8_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32; +template [[host_name("kernel_cpy_q1_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32; template [[host_name("kernel_cpy_q4_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32; template [[host_name("kernel_cpy_q4_1_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32; template [[host_name("kernel_cpy_q5_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32; @@ -9776,6 +9959,7 @@ template [[host_name("kernel_get_rows_bf16")]] kernel get_rows_f_t kernel_get_ro typedef decltype(kernel_get_rows_q) get_rows_q_t; +template [[host_name("kernel_get_rows_q1_0")]] kernel get_rows_q_t kernel_get_rows_q; template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_q_t kernel_get_rows_q; template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_q_t kernel_get_rows_q; template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_q_t kernel_get_rows_q; @@ -9838,6 +10022,7 @@ template [[host_name("kernel_mul_mm_f16_f32")]] kernel mul_mm_t kernel_mul_m #if defined(GGML_METAL_HAS_BF16) template [[host_name("kernel_mul_mm_bf16_f32")]] kernel mul_mm_t kernel_mul_mm; #endif +template [[host_name("kernel_mul_mm_q1_0_f32")]] kernel mul_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mul_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mul_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mul_mm_t kernel_mul_mm; @@ -9861,6 +10046,7 @@ template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mul_mm_t kernel_mul_m template [[host_name("kernel_mul_mm_f32_f16")]] kernel mul_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_f16_f16")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q1_0_f16")]] kernel mul_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q4_0_f16")]] kernel mul_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q4_1_f16")]] kernel mul_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q5_0_f16")]] kernel mul_mm_t kernel_mul_mm; @@ -10070,6 +10256,7 @@ template [[host_name("kernel_mul_mv_id_bf16_f32_4")]] kernel kernel_mul_mv_id_4 template [[host_name("kernel_mul_mv_id_q8_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_q1_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; template [[host_name("kernel_mul_mv_id_q4_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; template [[host_name("kernel_mul_mv_id_q4_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; template [[host_name("kernel_mul_mv_id_q5_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 093f17e23a..c7129d47a7 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -7251,6 +7251,7 @@ static const ggml_type all_types[] = { GGML_TYPE_Q4_0, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0, + GGML_TYPE_Q1_0, GGML_TYPE_MXFP4, GGML_TYPE_NVFP4, GGML_TYPE_Q2_K, GGML_TYPE_Q3_K, GGML_TYPE_Q4_K, GGML_TYPE_Q5_K, @@ -7275,6 +7276,7 @@ static const ggml_type other_types[] = { GGML_TYPE_Q4_1, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0, + GGML_TYPE_Q1_0, GGML_TYPE_Q2_K, GGML_TYPE_Q3_K, GGML_TYPE_Q5_K, GGML_TYPE_Q6_K,