From 32286453463422a937a6d3e461c9a760d2b39f3e Mon Sep 17 00:00:00 2001 From: Li He Date: Wed, 21 Jan 2026 22:47:34 -0800 Subject: [PATCH] opencl: add general q6_k mm --- ggml/src/ggml-opencl/CMakeLists.txt | 1 + ggml/src/ggml-opencl/ggml-opencl.cpp | 58 +++++++ .../kernels/mul_mm_q6_k_f32_l4_lm.cl | 158 ++++++++++++++++++ 3 files changed, 217 insertions(+) create mode 100644 ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl diff --git a/ggml/src/ggml-opencl/CMakeLists.txt b/ggml/src/ggml-opencl/CMakeLists.txt index fa5fadd112..f08d96943e 100644 --- a/ggml/src/ggml-opencl/CMakeLists.txt +++ b/ggml/src/ggml-opencl/CMakeLists.txt @@ -101,6 +101,7 @@ set(GGML_OPENCL_KERNELS mul_mm_f32_f32_l4_lm mul_mm_f16_f32_l4_lm mul_mm_q8_0_f32_l4_lm + mul_mm_q6_k_f32_l4_lm mul_mm_q8_0_f32_8x4 gemv_noshuffle_general_q8_0_f32 mul diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index 508b2b8f03..af20604d29 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -564,6 +564,7 @@ struct ggml_backend_opencl_context { cl_kernel kernel_mul_mm_f32_f32_l4_lm; cl_kernel kernel_mul_mm_f16_f32_l4_lm; cl_kernel kernel_mul_mm_q8_0_f32_l4_lm; + cl_kernel kernel_mul_mm_q6_k_f32_l4_lm; std::vector profiling_info; @@ -1358,6 +1359,23 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve GGML_LOG_CONT("."); } + // mul_mm_q6_k_f32_l4_lm + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mul_mm_q6_k_f32_l4_lm.cl.h" + }; +#else + const std::string kernel_src = read_file("mul_mm_q6_k_f32_l4_lm.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mul_mm_q6_k_f32_l4_lm = clCreateKernel(prog, "kernel_mul_mm_q6_k_f32_l4_lm", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + // mul_mm_f16_f32_kq_kqv { #ifdef GGML_OPENCL_EMBED_KERNELS @@ -8927,6 +8945,46 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); return; } + case GGML_TYPE_Q6_K: { + //if (ne11 < 32) { + // break; + //} + kernel = backend_ctx->kernel_mul_mm_q6_k_f32_l4_lm; + nth0 = 128; // calculated as (BM*BN)/(TM*TN) + + int batch_stride_a = ne00*ne01; + int batch_stride_b = ne10*ne11; + int batch_stride_d = ne0*ne1; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q6_K->ql)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q6_K->qh)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra0_q6_K->s)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra0_q6_K->d)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne11)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne10)); // stride_a + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne10)); // stride_b + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne01)); // stride_d + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &batch_stride_a)); + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &batch_stride_b)); + CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &batch_stride_d)); + CL_CHECK(clSetKernelArg(kernel, 19, sizeof(int), &r2)); + CL_CHECK(clSetKernelArg(kernel, 20, sizeof(int), &r3)); + + // 64 is block tile size BM and BN - change here when BM and BN in the kernel are changed. + size_t global_work_size[] = {(size_t)(CEIL_DIV(ne01, 64)*nth0), (size_t)(CEIL_DIV(ne11, 64)), (size_t)ne12*ne13}; + size_t local_work_size[] = {(size_t)nth0, 1, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + return; + } default: break; } diff --git a/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl b/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl new file mode 100644 index 0000000000..3602c92fef --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl @@ -0,0 +1,158 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#define LOAD_VEC_A 2 +#define LOAD_VEC_B 4 + +#define BM 64 +#define BN 64 +#define BK 32 +#define TM 4 +#define TN 8 + +kernel void kernel_mul_mm_q6_k_f32_l4_lm( + global uchar * src0_ql, + global uchar * src0_qh, + global char * src0_s, + global half * src0_d, + global float4 * src1, + ulong offset1, + global float * dst, + ulong offsetd, + + int ne00, + int ne01, + int ne02, + int ne11, + int ne12, + + int stride_a, + int stride_b, + int stride_d, + + int batch_stride_a, + int batch_stride_b, + int batch_stride_d, + + int r2, + int r3 +) { + src1 = (global float4*)((global char*)src1 + offset1); + dst = (global float *)((global char*)dst + offsetd); + + local float buf_a[BM * BK]; + local float buf_b[BN * BK]; + + const int batch_idx = get_global_id(2); + + const int i13 = batch_idx / ne12; + const int i12 = batch_idx % ne12; + + const int i03 = i13 / r3; + const int i02 = i12 / r2; + + const int batch_idx_a = i03 * ne02 + i02; + + const int ir = get_group_id(0); + const int ic = get_group_id(1); + + const int tid = get_local_id(0); + const int th_r = tid % (BM / TM); + const int th_c = tid / (BM / TM); + + const int loadr_a = get_local_id(0) % (BK / LOAD_VEC_A); + const int loadc_a = get_local_id(0) / (BK / LOAD_VEC_A); + const int loadr_b = get_local_id(0) % (BK / LOAD_VEC_B); + const int loadc_b = get_local_id(0) / (BK / LOAD_VEC_B); + + const int loadstride_a = get_local_size(0) * LOAD_VEC_A / BK; + const int loadstride_b = get_local_size(0) * LOAD_VEC_B / BK; + + int pos_a = (batch_idx_a * batch_stride_a + ir * BM * stride_a) / LOAD_VEC_A; + int pos_b = (batch_idx * batch_stride_b + ic * BN * stride_b) / LOAD_VEC_B; + + float sums[TM * TN]; + float cache_a[TM]; + float cache_b[TN]; + + for (int i = 0; i < TM * TN; i++) { + sums[i] = 0.0f; + } + + for (int block = 0; block < ne00; block += BK) { + for (int l = 0; l < BM; l += loadstride_a) { + if (ir*BM + loadc_a + l < ne01) { + int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a; + + int ib = idx / 128; // 2 values per idx + int iqs = idx % 128; // 0..127 + + int n = iqs / 64; // 0,1 + int b = (iqs % 64) / 32; // 0,1 + int is_b = (iqs % 16) / 8; // 0,1 + int qhshift = ((iqs % 64) / 16) * 2; // 0,2,4,6 + int is = 8 * n + qhshift + is_b; // 0..15 + int qsi = n * 64 + (iqs % 32) * 2; // 0,2,4..126 + int qhi = n * 32 + (iqs % 16) * 2; // 0,2,4..62 + + float dscale = (float)src0_d[ib] * (float)src0_s[ib*16 + is]; + + buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = dscale * convert_float(convert_char(((src0_ql[128*ib + qsi + 0] >> (b * 4)) & 0xF) | (((src0_qh[64*ib + qhi + 0] >> qhshift) & 3) << 4)) - 32); + buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = dscale * convert_float(convert_char(((src0_ql[128*ib + qsi + 1] >> (b * 4)) & 0xF) | (((src0_qh[64*ib + qhi + 1] >> qhshift) & 3) << 4)) - 32); + } else { + buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = 0.0f; + } + } + + for (int l = 0; l < BN; l += loadstride_b) { + if (ic*BN + loadc_b + l < ne11) { + int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b; + buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0; + buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1; + buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = src1[idx].s2; + buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = src1[idx].s3; + } else { + buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = 0.0f; + buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = 0.0f; + buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = 0.0f; + buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = 0.0f; + } + } + + barrier(CLK_LOCAL_MEM_FENCE); + + pos_a += BK / LOAD_VEC_A; + pos_b += BK / LOAD_VEC_B; + + for (int i = 0; i < BK; i++) { + for (int j = 0; j < TM; j++) { + cache_a[j] = buf_a[(i) * BM + th_r * TM + j]; + } + + for (int j = 0; j < TN; j++) { + cache_b[j] = buf_b[(i) * BN + th_c * TN + j]; + } + + for (int cc = 0; cc < TN; cc++) { + for (int cr = 0; cr < TM; cr++) { + const int sums_idx = cc*TM + cr; + sums[sums_idx] = mad(cache_a[cr], cache_b[cc], sums[sums_idx]); + } + } + } + barrier(CLK_LOCAL_MEM_FENCE); + } + + const int dr = ir * BM + th_r * TM; + const int dc = ic * BN + th_c * TN; + + const int offsets = batch_idx * batch_stride_d; + + for (int cc = 0; cc < TN; cc++) { + for (int cr = 0; cr < TM; cr++) { + if (dr + cr < ne01 && dc + cc < ne11) { + dst[offsets + (dc + cc) * stride_d + dr + cr] = sums[cc * TM + cr]; + } + } + } +}