From 1d594c295c8863e9077feeb50305a5e19493e6ee Mon Sep 17 00:00:00 2001 From: Han Qingzhe <95479277+hNSBQZ@users.noreply.github.com> Date: Thu, 27 Nov 2025 04:44:07 +0800 Subject: [PATCH 01/33] clip: (minicpmv) fix resampler kq_scale (#17516) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * debug:"solve minicpmv precision problem" * “debug minicpmv” * Apply suggestion from @ngxson --------- Co-authored-by: Xuan-Son Nguyen --- tools/mtmd/clip.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index abdb778f7a..52ea542dec 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -1175,10 +1175,11 @@ struct clip_graph { cb(K, "resampler_K", -1); cb(V, "resampler_V", -1); + float resampler_kq_scale = 1.0f/ sqrtf(float(d_head)); embeddings = build_attn( model.mm_model_attn_o_w, model.mm_model_attn_o_b, - Q, K, V, nullptr, kq_scale, -1); + Q, K, V, nullptr, resampler_kq_scale, -1); cb(embeddings, "resampler_attn_out", -1); } // layernorm From 5449367b2125d63069fad8d2ca13d0c5ebb2f003 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alberto=20Cabrera=20P=C3=A9rez?= <1478977+Alcpz@users.noreply.github.com> Date: Wed, 26 Nov 2025 21:14:54 +0000 Subject: [PATCH 02/33] Fix chunks being too small with small matrix sizes (#17526) --- ggml/src/ggml-cpu/repack.cpp | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/ggml/src/ggml-cpu/repack.cpp b/ggml/src/ggml-cpu/repack.cpp index d132119135..875faedf2d 100644 --- a/ggml/src/ggml-cpu/repack.cpp +++ b/ggml/src/ggml-cpu/repack.cpp @@ -1731,12 +1731,13 @@ template = min_chunk_size)) { nchunk0 = nth; + dr0 = (nr0 + nchunk0 - 1) / nchunk0; } - const int64_t dr0 = (nr0 + nchunk0 - 1) / nchunk0; - // Ensure nchunk doesn't exceed the number of rows divided by minimum chunk size // This prevents creating too many tiny chunks that could overlap after alignment const int64_t max_nchunk = (nr0 + min_chunk_size - 1) / min_chunk_size; From 7cba58bbeac0c262bde7d45adb452133f53cb56f Mon Sep 17 00:00:00 2001 From: lhez Date: Wed, 26 Nov 2025 13:29:58 -0800 Subject: [PATCH 03/33] opencl: add sqr, sqrt, mean and ssm_conv (#17476) * opencl: add sqr * opencl: add sqrt * opencl: add mean * opencl: add ssm_conv * opencl: add missing cl_khr_fp16 * opencl: do sqrt in f32 then convert to f16 for better precision --- ggml/src/ggml-opencl/CMakeLists.txt | 4 + ggml/src/ggml-opencl/ggml-opencl.cpp | 331 +++++++++++++++++++++++ ggml/src/ggml-opencl/kernels/mean.cl | 39 +++ ggml/src/ggml-opencl/kernels/sqr.cl | 53 ++++ ggml/src/ggml-opencl/kernels/sqrt.cl | 53 ++++ ggml/src/ggml-opencl/kernels/ssm_conv.cl | 77 ++++++ 6 files changed, 557 insertions(+) create mode 100644 ggml/src/ggml-opencl/kernels/mean.cl create mode 100644 ggml/src/ggml-opencl/kernels/sqr.cl create mode 100644 ggml/src/ggml-opencl/kernels/sqrt.cl create mode 100644 ggml/src/ggml-opencl/kernels/ssm_conv.cl diff --git a/ggml/src/ggml-opencl/CMakeLists.txt b/ggml/src/ggml-opencl/CMakeLists.txt index 681c81b88a..2a4b79eb6a 100644 --- a/ggml/src/ggml-opencl/CMakeLists.txt +++ b/ggml/src/ggml-opencl/CMakeLists.txt @@ -70,6 +70,7 @@ set(GGML_OPENCL_KERNELS group_norm im2col_f32 im2col_f16 + mean mul_mat_Ab_Bi_8x4 mul_mv_f16_f16 mul_mv_f16_f32_1row @@ -109,6 +110,9 @@ set(GGML_OPENCL_KERNELS softmax_4_f16 softmax_f32 softmax_f16 + sqr + sqrt + ssm_conv sub sum_rows transpose diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index 2319f7a9e2..e5302f4550 100644 --- 
a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -449,6 +449,9 @@ struct ggml_backend_opencl_context { cl_kernel kernel_sub, kernel_sub_row, kernel_sub_f16, kernel_sub_row_f16; cl_kernel kernel_add_id; cl_kernel kernel_scale; + cl_kernel kernel_sqr_cont_f32, kernel_sqr_cont_f32_4, kernel_sqr_cont_f16, kernel_sqr_cont_f16_4; + cl_kernel kernel_sqrt_cont_f32, kernel_sqrt_cont_f32_4, kernel_sqrt_cont_f16, kernel_sqrt_cont_f16_4; + cl_kernel kernel_mean_f32; cl_kernel kernel_silu, kernel_silu_4; cl_kernel kernel_gelu, kernel_gelu_4; cl_kernel kernel_gelu_erf, kernel_gelu_erf_4; @@ -509,6 +512,7 @@ struct ggml_backend_opencl_context { cl_kernel kernel_conv_2d_f16; cl_kernel kernel_conv_2d_f32; cl_kernel kernel_conv_2d_f16_f32; + cl_kernel kernel_ssm_conv_f32_f32, kernel_ssm_conv_f32_f32_4; cl_kernel kernel_timestep_embedding; cl_kernel kernel_gemv_moe_mxfp4_f32, kernel_gemm_moe_mxfp4_f32; cl_kernel kernel_mul_mv_id_q4_0_f32_8x_flat; @@ -1552,6 +1556,66 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve GGML_LOG_CONT("."); } + // sqr + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "sqr.cl.h" + }; +#else + const std::string kernel_src = read_file("sqr.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_sqr_cont_f32 = clCreateKernel(prog, "kernel_sqr_cont_f32", &err), err)); + CL_CHECK((backend_ctx->kernel_sqr_cont_f32_4 = clCreateKernel(prog, "kernel_sqr_cont_f32_4", &err), err)); + CL_CHECK((backend_ctx->kernel_sqr_cont_f16 = clCreateKernel(prog, "kernel_sqr_cont_f16", &err), err)); + CL_CHECK((backend_ctx->kernel_sqr_cont_f16_4 = clCreateKernel(prog, "kernel_sqr_cont_f16_4", &err), err)); + + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + + // sqrt + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "sqrt.cl.h" + }; +#else + const std::string kernel_src = read_file("sqrt.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_sqrt_cont_f32 = clCreateKernel(prog, "kernel_sqrt_cont_f32", &err), err)); + CL_CHECK((backend_ctx->kernel_sqrt_cont_f32_4 = clCreateKernel(prog, "kernel_sqrt_cont_f32_4", &err), err)); + CL_CHECK((backend_ctx->kernel_sqrt_cont_f16 = clCreateKernel(prog, "kernel_sqrt_cont_f16", &err), err)); + CL_CHECK((backend_ctx->kernel_sqrt_cont_f16_4 = clCreateKernel(prog, "kernel_sqrt_cont_f16_4", &err), err)); + + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + + // mean + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mean.cl.h" + }; +#else + const std::string kernel_src = read_file("mean.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mean_f32 = clCreateKernel(prog, "kernel_mean_f32", &err), err)); + + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + // sub { #ifdef GGML_OPENCL_EMBED_KERNELS @@ -1825,6 +1889,24 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve } } + // ssm_conv + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "ssm_conv.cl.h" + }; +#else + const std::string kernel_src = read_file("ssm_conv.cl"); +#endif + cl_program prog = + 
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_ssm_conv_f32_f32 = clCreateKernel(prog, "kernel_ssm_conv_f32_f32", &err), err)); + CL_CHECK((backend_ctx->kernel_ssm_conv_f32_f32_4 = clCreateKernel(prog, "kernel_ssm_conv_f32_f32_4", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + // mul_mv_id_q4_0_f32_8x_flat { #ifdef GGML_OPENCL_EMBED_KERNELS @@ -2959,6 +3041,10 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16); case GGML_OP_ADD_ID: return op->src[0]->type == GGML_TYPE_F32; + case GGML_OP_SQR: + case GGML_OP_SQRT: + return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) && + ggml_is_contiguous(op->src[0]); case GGML_OP_UNARY: switch (ggml_get_unary_op(op)) { case GGML_UNARY_OP_GELU: @@ -3007,6 +3093,8 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te return (op->src[0]->type == GGML_TYPE_F16 && op->src[1]->type == GGML_TYPE_F16 && op->type == GGML_TYPE_F16) || (op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32) || (op->src[0]->type == GGML_TYPE_F16 && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32); + case GGML_OP_SSM_CONV: + return (op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32); case GGML_OP_CONCAT: return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; case GGML_OP_TIMESTEP_EMBEDDING: @@ -3075,6 +3163,7 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te return cols <= max_workgroup_size && op->src[0]->type == GGML_TYPE_F32; } case GGML_OP_SUM_ROWS: + case GGML_OP_MEAN: return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous(op->src[0]); case GGML_OP_FLASH_ATTN_EXT: { @@ -5193,6 +5282,224 @@ static void ggml_cl_sub(ggml_backend_t backend, const ggml_tensor * src0, const } } +static void ggml_cl_sqr(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + UNUSED(src1); + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + + ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + + cl_ulong offset0 = extra0->offset + src0->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + cl_kernel kernel; + + // Currently assumes src0 is contiguous + int n = ggml_nelements(dst); + if (n % 4 == 0) { + if (src0->type == GGML_TYPE_F32) { + kernel = backend_ctx->kernel_sqr_cont_f32_4; + } else { + kernel = backend_ctx->kernel_sqr_cont_f16_4; + } + n /= 4; + } else { + if (src0->type == GGML_TYPE_F32) { + kernel = backend_ctx->kernel_sqr_cont_f32; + } else { + kernel = backend_ctx->kernel_sqr_cont_f16; + } + } + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); + + size_t global_work_size[] = {(size_t)n, 1, 1}; + size_t local_work_size[] = {64, 1, 1}; + + size_t * local_work_size_ptr = local_work_size; + if 
(n % 64 != 0 && !backend_ctx->non_uniform_workgroups) { + local_work_size_ptr = nullptr; + } + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size_ptr, dst); +} + +static void ggml_cl_sqrt(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + UNUSED(src1); + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + + ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + + cl_ulong offset0 = extra0->offset + src0->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + cl_kernel kernel; + + // Currently assumes src0 is contiguous + int n = ggml_nelements(dst); + if (n % 4 == 0) { + if (src0->type == GGML_TYPE_F32) { + kernel = backend_ctx->kernel_sqrt_cont_f32_4; + } else { + kernel = backend_ctx->kernel_sqrt_cont_f16_4; + } + n /= 4; + } else { + if (src0->type == GGML_TYPE_F32) { + kernel = backend_ctx->kernel_sqrt_cont_f32; + } else { + kernel = backend_ctx->kernel_sqrt_cont_f16; + } + } + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); + + size_t global_work_size[] = {(size_t)n, 1, 1}; + size_t local_work_size[] = {64, 1, 1}; + + size_t * local_work_size_ptr = local_work_size; + if (n % 64 != 0 && !backend_ctx->non_uniform_workgroups) { + local_work_size_ptr = nullptr; + } + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size_ptr, dst); +} + +static void ggml_cl_mean(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + GGML_UNUSED(src1); + + GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type)); + GGML_ASSERT(ggml_is_contiguous(src0)); + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + + ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + + cl_ulong offset0 = extra0->offset + src0->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + const int ne00 = src0->ne[0]; + const int ne01 = src0->ne[1]; + const int ne02 = src0->ne[2]; + const int ne03 = src0->ne[3]; + + const cl_ulong nb01 = src0->nb[1]; + const cl_ulong nb02 = src0->nb[2]; + const cl_ulong nb03 = src0->nb[3]; + + const cl_ulong nb1 = dst->nb[1]; + const cl_ulong nb2 = dst->nb[2]; + const cl_ulong nb3 = dst->nb[3]; + + cl_kernel kernel = backend_ctx->kernel_mean_f32; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne03)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb01)); + 
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb02)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb03)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb1)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb2)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb3)); + + size_t global_work_size[] = {(size_t)ne01, (size_t)ne02, (size_t)ne03}; + size_t local_work_size[] = {(size_t)64, 1, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); +} + +static void ggml_cl_ssm_conv(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(src1); + GGML_ASSERT(src1->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + + ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + + cl_ulong offset0 = extra0->offset + src0->view_offs; + cl_ulong offset1 = extra1->offset + src1->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + int ne01 = src0->ne[1]; + cl_ulong nb00 = src0->nb[0]; + cl_ulong nb01 = src0->nb[1]; + cl_ulong nb02 = src0->nb[2]; + + int ne10 = src1->ne[0]; + cl_ulong nb11 = src1->nb[1]; + + int ne1 = dst->ne[1]; + int ne2 = dst->ne[2]; + cl_ulong nb0 = dst->nb[0]; + cl_ulong nb1 = dst->nb[1]; + cl_ulong nb2 = dst->nb[2]; + + cl_kernel kernel = backend_ctx->kernel_ssm_conv_f32_f32; + + if (ne10 % 4 == 0) { + kernel = backend_ctx->kernel_ssm_conv_f32_f32_4; + } + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_ulong), &nb00)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb02)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne10)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb11)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb0)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb1)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb2)); + + size_t global_work_size[] = {(size_t)ne01, (size_t)ne1, (size_t)ne2}; + size_t local_work_size[] = {64, 1, 1}; + + size_t * local_work_size_ptr = local_work_size; + if (ne01 % 64 != 0 && !backend_ctx->non_uniform_workgroups) { + local_work_size_ptr = nullptr; + } + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size_ptr, dst); +} + static void ggml_cl_gelu(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { GGML_ASSERT(src0); GGML_ASSERT(src0->extra); @@ -9091,6 +9398,24 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor } func = ggml_cl_sub; break; + case GGML_OP_SQR: + if (!any_on_device) { + return false; + } + func = ggml_cl_sqr; + break; + case GGML_OP_SQRT: + if (!any_on_device) { + return false; + } + func = ggml_cl_sqrt; + break; + 
case GGML_OP_MEAN: + if (!any_on_device) { + return false; + } + func = ggml_cl_mean; + break; case GGML_OP_UNARY: switch (ggml_get_unary_op(tensor)) { case GGML_UNARY_OP_GELU: @@ -9192,6 +9517,12 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor } func = ggml_cl_conv_2d; break; + case GGML_OP_SSM_CONV: + if (!any_on_device) { + return false; + } + func = ggml_cl_ssm_conv; + break; case GGML_OP_CONCAT: if (!any_on_device) { return false; diff --git a/ggml/src/ggml-opencl/kernels/mean.cl b/ggml/src/ggml-opencl/kernels/mean.cl new file mode 100644 index 0000000000..5c3e8bcd86 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mean.cl @@ -0,0 +1,39 @@ + +kernel void kernel_mean_f32( + global float * src0, + ulong offset0, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb01, + ulong nb02, + ulong nb03, + ulong nb1, + ulong nb2, + ulong nb3 +) { + src0 = (global float *)((global char *)src0 + offset0); + dst = (global float *)((global char *)dst + offsetd); + + int i3 = get_global_id(2); + int i2 = get_global_id(1); + int i1 = get_global_id(0); + + if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) { + return; + } + + global float * src_row = (global float *) ((global char *) src0 + i1*nb01 + i2*nb02 + i3*nb03); + global float * dst_row = (global float *) ((global char *) dst + i1*nb1 + i2*nb2 + i3*nb3); + + float row_sum = 0; + + for (int i0 = 0; i0 < ne00; i0++) { + row_sum += src_row[i0]; + } + + dst_row[0] = row_sum / ne00; +} diff --git a/ggml/src/ggml-opencl/kernels/sqr.cl b/ggml/src/ggml-opencl/kernels/sqr.cl new file mode 100644 index 0000000000..4310906f6e --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/sqr.cl @@ -0,0 +1,53 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +kernel void kernel_sqr_cont_f32( + global float * src0, + ulong offset0, + global float * dst, + ulong offsetd +) { + src0 = (global float*)((global char*)src0 + offset0); + dst = (global float*)((global char*)dst + offsetd); + + uint gid = get_global_id(0); + dst[gid] = src0[gid] * src0[gid]; +} + +kernel void kernel_sqr_cont_f32_4( + global float4 * src0, + ulong offset0, + global float4 * dst, + ulong offsetd +) { + src0 = (global float4*)((global char*)src0 + offset0); + dst = (global float4*)((global char*)dst + offsetd); + + uint gid = get_global_id(0); + dst[gid] = src0[gid] * src0[gid]; +} + +kernel void kernel_sqr_cont_f16( + global half * src0, + ulong offset0, + global half * dst, + ulong offsetd +) { + src0 = (global half*)((global char*)src0 + offset0); + dst = (global half*)((global char*)dst + offsetd); + + uint gid = get_global_id(0); + dst[gid] = src0[gid] * src0[gid]; +} + +kernel void kernel_sqr_cont_f16_4( + global half4 * src0, + ulong offset0, + global half4 * dst, + ulong offsetd +) { + src0 = (global half4*)((global char*)src0 + offset0); + dst = (global half4*)((global char*)dst + offsetd); + + uint gid = get_global_id(0); + dst[gid] = src0[gid] * src0[gid]; +} diff --git a/ggml/src/ggml-opencl/kernels/sqrt.cl b/ggml/src/ggml-opencl/kernels/sqrt.cl new file mode 100644 index 0000000000..c59fbe06a6 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/sqrt.cl @@ -0,0 +1,53 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +kernel void kernel_sqrt_cont_f32( + global float * src0, + ulong offset0, + global float * dst, + ulong offsetd +) { + src0 = (global float*)((global char*)src0 + offset0); + dst = (global float*)((global char*)dst + offsetd); + + uint gid = get_global_id(0); + dst[gid] = sqrt(src0[gid]); +} + +kernel 
void kernel_sqrt_cont_f32_4( + global float4 * src0, + ulong offset0, + global float4 * dst, + ulong offsetd +) { + src0 = (global float4*)((global char*)src0 + offset0); + dst = (global float4*)((global char*)dst + offsetd); + + uint gid = get_global_id(0); + dst[gid] = sqrt(src0[gid]); +} + +kernel void kernel_sqrt_cont_f16( + global half * src0, + ulong offset0, + global half * dst, + ulong offsetd +) { + src0 = (global half*)((global char*)src0 + offset0); + dst = (global half*)((global char*)dst + offsetd); + + uint gid = get_global_id(0); + dst[gid] = convert_half(sqrt(convert_float(src0[gid]))); +} + +kernel void kernel_sqrt_cont_f16_4( + global half4 * src0, + ulong offset0, + global half4 * dst, + ulong offsetd +) { + src0 = (global half4*)((global char*)src0 + offset0); + dst = (global half4*)((global char*)dst + offsetd); + + uint gid = get_global_id(0); + dst[gid] = convert_half4(sqrt(convert_float4(src0[gid]))); +} diff --git a/ggml/src/ggml-opencl/kernels/ssm_conv.cl b/ggml/src/ggml-opencl/kernels/ssm_conv.cl new file mode 100644 index 0000000000..7ae21ac739 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/ssm_conv.cl @@ -0,0 +1,77 @@ +kernel void kernel_ssm_conv_f32_f32( + global char * src0, + ulong offset0, + global char * src1, + ulong offset1, + global char * dst, + ulong offsetd, + ulong nb00, + ulong nb01, + ulong nb02, + int ne10, + ulong nb11, + ulong nb0, + ulong nb1, + ulong nb2 +){ + src0 = src0 + offset0; + src1 = src1 + offset1; + dst = dst + offsetd; + + int ir = get_global_id(0); + int i2 = get_global_id(1); + int i3 = get_global_id(2); + + int nc = ne10; + + global float * s = (global float *) (src0 + ir*nb01 + i2*nb00 + i3*nb02); + global float * c = (global float *) (src1 + ir*nb11); + global float * d = (global float *) (dst + ir*nb0 + i2*nb1 + i3*nb2); + + float sumf = 0.0f; + + for (int i0 = 0; i0 < nc; ++i0) { + sumf += s[i0] * c[i0]; + } + + d[0] = sumf; +} + +kernel void kernel_ssm_conv_f32_f32_4( + global char * src0, + ulong offset0, + global char * src1, + ulong offset1, + global char * dst, + ulong offsetd, + ulong nb00, + ulong nb01, + ulong nb02, + int ne10, + ulong nb11, + ulong nb0, + ulong nb1, + ulong nb2 +) { + src0 = src0 + offset0; + src1 = src1 + offset1; + dst = dst + offsetd; + + int ir = get_global_id(0); + int i2 = get_global_id(1); + int i3 = get_global_id(2); + + int nc = ne10; + + global float4 * s = (global float4 *) (src0 + ir*nb01 + i2*nb00 + i3*nb02); + global float4 * c = (global float4 *) (src1 + ir*nb11); + global float * d = (global float *) (dst + ir*nb0 + i2*nb1 + i3*nb2); + + float sumf = 0.0f; + + for (int i0 = 0; i0 < nc/4; ++i0) { + sumf += dot(s[i0], c[i0]); + } + + d[0] = sumf; +} From e509411cf142807c947b53b340d2d5594ce38120 Mon Sep 17 00:00:00 2001 From: Xuan-Son Nguyen Date: Thu, 27 Nov 2025 01:02:50 +0100 Subject: [PATCH 04/33] server: enable jinja by default, update docs (#17524) * server: enable jinja by default, update docs * fix tests --- common/arg.cpp | 15 +++++++++++++- tools/server/README.md | 41 ++++++++++++++++++++++--------------- tools/server/tests/utils.py | 2 ++ 3 files changed, 41 insertions(+), 17 deletions(-) diff --git a/common/arg.cpp b/common/arg.cpp index dd787290d2..9a874c6b1d 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -694,6 +694,12 @@ static bool is_autoy(const std::string & value) { } common_params_context common_params_parser_init(common_params & params, llama_example ex, void(*print_usage)(int, char **)) { + // default values specific to example + // note: we place it here 
instead of inside server.cpp to allow llama-gen-docs to pick it up + if (ex == LLAMA_EXAMPLE_SERVER) { + params.use_jinja = true; + } + // load dynamic backends ggml_backend_load_all(); @@ -2488,11 +2494,18 @@ common_params_context common_params_parser_init(common_params & params, llama_ex ).set_examples({LLAMA_EXAMPLE_SERVER})); add_opt(common_arg( {"--jinja"}, - "use jinja template for chat (default: disabled)", + string_format("use jinja template for chat (default: %s)\n", params.use_jinja ? "enabled" : "disabled"), [](common_params & params) { params.use_jinja = true; } ).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_MTMD}).set_env("LLAMA_ARG_JINJA")); + add_opt(common_arg( + {"--no-jinja"}, + string_format("disable jinja template for chat (default: %s)\n", params.use_jinja ? "enabled" : "disabled"), + [](common_params & params) { + params.use_jinja = false; + } + ).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_MTMD}).set_env("LLAMA_ARG_NO_JINJA")); add_opt(common_arg( {"--reasoning-format"}, "FORMAT", "controls whether thought tags are allowed and/or extracted from the response, and in which format they're returned; one of:\n" diff --git a/tools/server/README.md b/tools/server/README.md index 8fd478eb32..b7fc565ec6 100644 --- a/tools/server/README.md +++ b/tools/server/README.md @@ -30,9 +30,10 @@ The project is under active development, and we are [looking for feedback and co | -------- | ----------- | | `-h, --help, --usage` | print usage and exit | | `--version` | show version and build info | +| `-cl, --cache-list` | show list of models in cache | | `--completion-bash` | print source-able bash completion script for llama.cpp | | `--verbose-prompt` | print a verbose prompt before generation (default: false) | -| `-t, --threads N` | number of threads to use during generation (default: -1)
(env: LLAMA_ARG_THREADS) | +| `-t, --threads N` | number of CPU threads to use during generation (default: -1)
(env: LLAMA_ARG_THREADS) | | `-tb, --threads-batch N` | number of threads to use during batch and prompt processing (default: same as --threads) | | `-C, --cpu-mask M` | CPU affinity mask: arbitrarily long hex. Complements cpu-range (default: "") | | `-Cr, --cpu-range lo-hi` | range of CPUs for affinity. Complements --cpu-mask | @@ -51,7 +52,7 @@ The project is under active development, and we are [looking for feedback and co | `--keep N` | number of tokens to keep from the initial prompt (default: 0, -1 = all) | | `--swa-full` | use full-size SWA cache (default: false)
[(more info)](https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)
(env: LLAMA_ARG_SWA_FULL) | | `--kv-unified, -kvu` | use single unified KV buffer for the KV cache of all sequences (default: false)
[(more info)](https://github.com/ggml-org/llama.cpp/pull/14363)
(env: LLAMA_ARG_KV_SPLIT) | -| `-fa, --flash-attn` | enable Flash Attention (default: disabled)
(env: LLAMA_ARG_FLASH_ATTN) | +| `-fa, --flash-attn [on\|off\|auto]` | set Flash Attention use ('on', 'off', or 'auto', default: 'auto')
(env: LLAMA_ARG_FLASH_ATTN) | | `--no-perf` | disable internal libllama performance timings (default: false)
(env: LLAMA_ARG_NO_PERF) | | `-e, --escape` | process escapes sequences (\n, \r, \t, \', \", \\) (default: true) | | `--no-escape` | do not process escape sequences | @@ -61,11 +62,12 @@ The project is under active development, and we are [looking for feedback and co | `--rope-freq-scale N` | RoPE frequency scaling factor, expands context by a factor of 1/N
(env: LLAMA_ARG_ROPE_FREQ_SCALE) | | `--yarn-orig-ctx N` | YaRN: original context size of model (default: 0 = model training context size)
(env: LLAMA_ARG_YARN_ORIG_CTX) | | `--yarn-ext-factor N` | YaRN: extrapolation mix factor (default: -1.0, 0.0 = full interpolation)
(env: LLAMA_ARG_YARN_EXT_FACTOR) | -| `--yarn-attn-factor N` | YaRN: scale sqrt(t) or attention magnitude (default: 1.0)
(env: LLAMA_ARG_YARN_ATTN_FACTOR) | -| `--yarn-beta-slow N` | YaRN: high correction dim or alpha (default: 1.0)
(env: LLAMA_ARG_YARN_BETA_SLOW) | -| `--yarn-beta-fast N` | YaRN: low correction dim or beta (default: 32.0)
(env: LLAMA_ARG_YARN_BETA_FAST) | +| `--yarn-attn-factor N` | YaRN: scale sqrt(t) or attention magnitude (default: -1.0)
(env: LLAMA_ARG_YARN_ATTN_FACTOR) | +| `--yarn-beta-slow N` | YaRN: high correction dim or alpha (default: -1.0)
(env: LLAMA_ARG_YARN_BETA_SLOW) | +| `--yarn-beta-fast N` | YaRN: low correction dim or beta (default: -1.0)
(env: LLAMA_ARG_YARN_BETA_FAST) | | `-nkvo, --no-kv-offload` | disable KV offload
(env: LLAMA_ARG_NO_KV_OFFLOAD) | | `-nr, --no-repack` | disable weight repacking
(env: LLAMA_ARG_NO_REPACK) | +| `--no-host` | bypass host buffer allowing extra buffers to be used
(env: LLAMA_ARG_NO_HOST) | | `-ctk, --cache-type-k TYPE` | KV cache data type for K
allowed values: f32, f16, bf16, q8_0, q4_0, q4_1, iq4_nl, q5_0, q5_1
(default: f16)
(env: LLAMA_ARG_CACHE_TYPE_K) | | `-ctv, --cache-type-v TYPE` | KV cache data type for V
allowed values: f32, f16, bf16, q8_0, q4_0, q4_1, iq4_nl, q5_0, q5_1
(default: f16)
(env: LLAMA_ARG_CACHE_TYPE_V) | | `-dt, --defrag-thold N` | KV cache defragmentation threshold (DEPRECATED)
(env: LLAMA_ARG_DEFRAG_THOLD) | @@ -78,7 +80,7 @@ The project is under active development, and we are [looking for feedback and co | `--override-tensor, -ot <tensor name pattern>=<buffer type>,...` | override tensor buffer type | | `--cpu-moe, -cmoe` | keep all Mixture of Experts (MoE) weights in the CPU
(env: LLAMA_ARG_CPU_MOE) | | `--n-cpu-moe, -ncmoe N` | keep the Mixture of Experts (MoE) weights of the first N layers in the CPU
(env: LLAMA_ARG_N_CPU_MOE) | -| `-ngl, --gpu-layers, --n-gpu-layers N` | number of layers to store in VRAM
(env: LLAMA_ARG_N_GPU_LAYERS) | +| `-ngl, --gpu-layers, --n-gpu-layers N` | max. number of layers to store in VRAM (default: -1)
(env: LLAMA_ARG_N_GPU_LAYERS) | | `-sm, --split-mode {none,layer,row}` | how to split the model across multiple GPUs, one of:
- none: use one GPU only
- layer (default): split layers and KV across GPUs
- row: split rows across GPUs
(env: LLAMA_ARG_SPLIT_MODE) | | `-ts, --tensor-split N0,N1,N2,...` | fraction of the model to offload to each GPU, comma-separated list of proportions, e.g. 3,1
(env: LLAMA_ARG_TENSOR_SPLIT) | | `-mg, --main-gpu INDEX` | the GPU to use for the model (with split-mode = none), or for intermediate results and KV (with split-mode = row) (default: 0)
(env: LLAMA_ARG_MAIN_GPU) | @@ -92,6 +94,7 @@ The project is under active development, and we are [looking for feedback and co | `--control-vector-layer-range START END` | layer range to apply the control vector(s) to, start and end inclusive | | `-m, --model FNAME` | model path (default: `models/$filename` with filename from `--hf-file` or `--model-url` if set, otherwise models/7B/ggml-model-f16.gguf)
(env: LLAMA_ARG_MODEL) | | `-mu, --model-url MODEL_URL` | model download url (default: unused)
(env: LLAMA_ARG_MODEL_URL) | +| `-dr, --docker-repo [<repo>/]<model>[:quant]` | Docker Hub model repository. repo is optional, default to ai/. quant is optional, default to :latest.
example: gemma3
(default: unused)
(env: LLAMA_ARG_DOCKER_REPO) | | `-hf, -hfr, --hf-repo <user>/<model>[:quant]` | Hugging Face model repository; quant is optional, case-insensitive, default to Q4_K_M, or falls back to the first file in the repo if Q4_K_M doesn't exist.
mmproj is also downloaded automatically if available. to disable, add --no-mmproj
example: unsloth/phi-4-GGUF:q4_k_m
(default: unused)
(env: LLAMA_ARG_HF_REPO) | | `-hfd, -hfrd, --hf-repo-draft <user>/<model>[:quant]` | Same as --hf-repo, but for the draft model (default: unused)
(env: LLAMA_ARG_HFD_REPO) | | `-hff, --hf-file FILE` | Hugging Face model file. If specified, it will override the quant in --hf-repo (default: unused)
(env: LLAMA_ARG_HF_FILE) | @@ -100,7 +103,7 @@ The project is under active development, and we are [looking for feedback and co | `-hft, --hf-token TOKEN` | Hugging Face access token (default: value from HF_TOKEN environment variable)
(env: HF_TOKEN) | | `--log-disable` | Log disable | | `--log-file FNAME` | Log to file | -| `--log-colors` | Enable colored logging
(env: LLAMA_LOG_COLORS) | +| `--log-colors [on\|off\|auto]` | Set colored logging ('on', 'off', or 'auto', default: 'auto')
'auto' enables colors when output is to a terminal
(env: LLAMA_LOG_COLORS) | | `-v, --verbose, --log-verbose` | Set verbosity level to infinity (i.e. log all messages, useful for debugging) | | `--offline` | Offline mode: forces use of cache, prevents network access
(env: LLAMA_OFFLINE) | | `-lv, --verbosity, --log-verbosity N` | Set the verbosity threshold. Messages with a higher verbosity will be ignored.
(env: LLAMA_LOG_VERBOSITY) | @@ -151,7 +154,8 @@ The project is under active development, and we are [looking for feedback and co | Argument | Explanation | | -------- | ----------- | -| `--swa-checkpoints N` | max number of SWA checkpoints per slot to create (default: 3)
[(more info)](https://github.com/ggml-org/llama.cpp/pull/15293)
(env: LLAMA_ARG_SWA_CHECKPOINTS) | +| `--ctx-checkpoints, --swa-checkpoints N` | max number of context checkpoints to create per slot (default: 8)
[(more info)](https://github.com/ggml-org/llama.cpp/pull/15293)
(env: LLAMA_ARG_CTX_CHECKPOINTS) | +| `--cache-ram, -cram N` | set the maximum cache size in MiB (default: 8192, -1 - no limit, 0 - disable)
[(more info)](https://github.com/ggml-org/llama.cpp/pull/16391)
(env: LLAMA_ARG_CACHE_RAM) | | `--no-context-shift` | disables context shift on infinite text generation (default: enabled)
(env: LLAMA_ARG_NO_CONTEXT_SHIFT) | | `--context-shift` | enables context shift on infinite text generation (default: disabled)
(env: LLAMA_ARG_CONTEXT_SHIFT) | | `-r, --reverse-prompt PROMPT` | halt generation at PROMPT, return control in interactive mode
| @@ -165,6 +169,8 @@ The project is under active development, and we are [looking for feedback and co | `--mmproj-url URL` | URL to a multimodal projector file. see tools/mtmd/README.md
(env: LLAMA_ARG_MMPROJ_URL) | | `--no-mmproj` | explicitly disable multimodal projector, useful when using -hf
(env: LLAMA_ARG_NO_MMPROJ) | | `--no-mmproj-offload` | do not offload multimodal projector to GPU
(env: LLAMA_ARG_NO_MMPROJ_OFFLOAD) | +| `--image-min-tokens N` | minimum number of tokens each image can take, only used by vision models with dynamic resolution (default: read from model)
(env: LLAMA_ARG_IMAGE_MIN_TOKENS) | +| `--image-max-tokens N` | maximum number of tokens each image can take, only used by vision models with dynamic resolution (default: read from model)
(env: LLAMA_ARG_IMAGE_MAX_TOKENS) | | `--override-tensor-draft, -otd =,...` | override tensor buffer type for draft model | | `--cpu-moe-draft, -cmoed` | keep all Mixture of Experts (MoE) weights in the CPU for the draft model
(env: LLAMA_ARG_CPU_MOE_DRAFT) | | `--n-cpu-moe-draft, -ncmoed N` | keep the Mixture of Experts (MoE) weights of the first N layers in the CPU for the draft model
(env: LLAMA_ARG_N_CPU_MOE_DRAFT) | @@ -189,13 +195,14 @@ The project is under active development, and we are [looking for feedback and co | `--slots` | enable slots monitoring endpoint (default: enabled)
(env: LLAMA_ARG_ENDPOINT_SLOTS) | | `--no-slots` | disables slots monitoring endpoint
(env: LLAMA_ARG_NO_ENDPOINT_SLOTS) | | `--slot-save-path PATH` | path to save slot kv cache (default: disabled) | -| `--jinja` | use jinja template for chat (default: disabled)
(env: LLAMA_ARG_JINJA) | -| `--reasoning-format FORMAT` | controls whether thought tags are allowed and/or extracted from the response, and in which format they're returned; one of:
- none: leaves thoughts unparsed in `message.content`
- deepseek: puts thoughts in `message.reasoning_content`
- deepseek-legacy: keeps `<think>` tags in `message.content` while also populating `message.reasoning_content`
(default: deepseek)
(env: LLAMA_ARG_THINK) | +| `--jinja` | use jinja template for chat (default: enabled)

(env: LLAMA_ARG_JINJA) | +| `--no-jinja` | disable jinja template for chat (default: enabled)

(env: LLAMA_ARG_NO_JINJA) | +| `--reasoning-format FORMAT` | controls whether thought tags are allowed and/or extracted from the response, and in which format they're returned; one of:
- none: leaves thoughts unparsed in `message.content`
- deepseek: puts thoughts in `message.reasoning_content`
- deepseek-legacy: keeps `<think>` tags in `message.content` while also populating `message.reasoning_content`
(default: auto)
(env: LLAMA_ARG_THINK) | | `--reasoning-budget N` | controls the amount of thinking allowed; currently only one of: -1 for unrestricted thinking budget, or 0 to disable thinking (default: -1)
(env: LLAMA_ARG_THINK_BUDGET) | -| `--chat-template JINJA_TEMPLATE` | set custom jinja chat template (default: template taken from model's metadata)
if suffix/prefix are specified, template will be disabled
only commonly used templates are accepted (unless --jinja is set before this flag):
list of built-in templates:
bailing, chatglm3, chatglm4, chatml, command-r, deepseek, deepseek2, deepseek3, exaone3, exaone4, falcon3, gemma, gigachat, glmedge, gpt-oss, granite, hunyuan-dense, hunyuan-moe, kimi-k2, llama2, llama2-sys, llama2-sys-bos, llama2-sys-strip, llama3, llama4, megrez, minicpm, mistral-v1, mistral-v3, mistral-v3-tekken, mistral-v7, mistral-v7-tekken, monarch, openchat, orion, phi3, phi4, rwkv-world, seed_oss, smolvlm, vicuna, vicuna-orca, yandex, zephyr
(env: LLAMA_ARG_CHAT_TEMPLATE) | -| `--chat-template-file JINJA_TEMPLATE_FILE` | set custom jinja chat template file (default: template taken from model's metadata)
if suffix/prefix are specified, template will be disabled
only commonly used templates are accepted (unless --jinja is set before this flag):
list of built-in templates:
bailing, chatglm3, chatglm4, chatml, command-r, deepseek, deepseek2, deepseek3, exaone3, exaone4, falcon3, gemma, gigachat, glmedge, gpt-oss, granite, hunyuan-dense, hunyuan-moe, kimi-k2, llama2, llama2-sys, llama2-sys-bos, llama2-sys-strip, llama3, llama4, megrez, minicpm, mistral-v1, mistral-v3, mistral-v3-tekken, mistral-v7, mistral-v7-tekken, monarch, openchat, orion, phi3, phi4, rwkv-world, seed_oss, smolvlm, vicuna, vicuna-orca, yandex, zephyr
(env: LLAMA_ARG_CHAT_TEMPLATE_FILE) | +| `--chat-template JINJA_TEMPLATE` | set custom jinja chat template (default: template taken from model's metadata)
if suffix/prefix are specified, template will be disabled
only commonly used templates are accepted (unless --jinja is set before this flag):
list of built-in templates:
bailing, bailing-think, bailing2, chatglm3, chatglm4, chatml, command-r, deepseek, deepseek2, deepseek3, exaone3, exaone4, falcon3, gemma, gigachat, glmedge, gpt-oss, granite, grok-2, hunyuan-dense, hunyuan-moe, kimi-k2, llama2, llama2-sys, llama2-sys-bos, llama2-sys-strip, llama3, llama4, megrez, minicpm, mistral-v1, mistral-v3, mistral-v3-tekken, mistral-v7, mistral-v7-tekken, monarch, openchat, orion, pangu-embedded, phi3, phi4, rwkv-world, seed_oss, smolvlm, vicuna, vicuna-orca, yandex, zephyr
(env: LLAMA_ARG_CHAT_TEMPLATE) | +| `--chat-template-file JINJA_TEMPLATE_FILE` | set custom jinja chat template file (default: template taken from model's metadata)
if suffix/prefix are specified, template will be disabled
only commonly used templates are accepted (unless --jinja is set before this flag):
list of built-in templates:
bailing, bailing-think, bailing2, chatglm3, chatglm4, chatml, command-r, deepseek, deepseek2, deepseek3, exaone3, exaone4, falcon3, gemma, gigachat, glmedge, gpt-oss, granite, grok-2, hunyuan-dense, hunyuan-moe, kimi-k2, llama2, llama2-sys, llama2-sys-bos, llama2-sys-strip, llama3, llama4, megrez, minicpm, mistral-v1, mistral-v3, mistral-v3-tekken, mistral-v7, mistral-v7-tekken, monarch, openchat, orion, pangu-embedded, phi3, phi4, rwkv-world, seed_oss, smolvlm, vicuna, vicuna-orca, yandex, zephyr
(env: LLAMA_ARG_CHAT_TEMPLATE_FILE) | | `--no-prefill-assistant` | whether to prefill the assistant's response if the last message is an assistant message (default: prefill enabled)
when this flag is set, if the last message is an assistant message then it will be treated as a full message and not prefilled

(env: LLAMA_ARG_NO_PREFILL_ASSISTANT) | -| `-sps, --slot-prompt-similarity SIMILARITY` | how much the prompt of a request must match the prompt of a slot in order to use that slot (default: 0.50, 0.0 = disabled)
| +| `-sps, --slot-prompt-similarity SIMILARITY` | how much the prompt of a request must match the prompt of a slot in order to use that slot (default: 0.10, 0.0 = disabled)
| | `--lora-init-without-apply` | load LoRA adapters without applying them (apply later via POST /lora-adapters) (default: disabled) | | `-td, --threads-draft N` | number of threads to use during generation (default: same as --threads) | | `-tbd, --threads-batch-draft N` | number of threads to use during batch and prompt processing (default: same as --threads-draft) | @@ -209,15 +216,17 @@ The project is under active development, and we are [looking for feedback and co | `--spec-replace TARGET DRAFT` | translate the string in TARGET into DRAFT if the draft model and main model are not compatible | | `-mv, --model-vocoder FNAME` | vocoder model for audio generation (default: unused) | | `--tts-use-guide-tokens` | Use guide tokens to improve TTS word recall | -| `--embd-bge-small-en-default` | use default bge-small-en-v1.5 model (note: can download weights from the internet) | -| `--embd-e5-small-en-default` | use default e5-small-v2 model (note: can download weights from the internet) | -| `--embd-gte-small-default` | use default gte-small model (note: can download weights from the internet) | +| `--embd-gemma-default` | use default EmbeddingGemma model (note: can download weights from the internet) | | `--fim-qwen-1.5b-default` | use default Qwen 2.5 Coder 1.5B (note: can download weights from the internet) | | `--fim-qwen-3b-default` | use default Qwen 2.5 Coder 3B (note: can download weights from the internet) | | `--fim-qwen-7b-default` | use default Qwen 2.5 Coder 7B (note: can download weights from the internet) | | `--fim-qwen-7b-spec` | use Qwen 2.5 Coder 7B + 0.5B draft for speculative decoding (note: can download weights from the internet) | | `--fim-qwen-14b-spec` | use Qwen 2.5 Coder 14B + 0.5B draft for speculative decoding (note: can download weights from the internet) | | `--fim-qwen-30b-default` | use default Qwen 3 Coder 30B A3B Instruct (note: can download weights from the internet) | +| `--gpt-oss-20b-default` | use gpt-oss-20b (note: can download weights from the internet) | +| `--gpt-oss-120b-default` | use gpt-oss-120b (note: can download weights from the internet) | +| `--vision-gemma-4b-default` | use Gemma 3 4B QAT (note: can download weights from the internet) | +| `--vision-gemma-12b-default` | use Gemma 3 12B QAT (note: can download weights from the internet) | Note: If both command line argument and environment variable are both set for the same param, the argument will take precedence over env var. 
diff --git a/tools/server/tests/utils.py b/tools/server/tests/utils.py index da703c4c51..a779283d69 100644 --- a/tools/server/tests/utils.py +++ b/tools/server/tests/utils.py @@ -205,6 +205,8 @@ class ServerProcess: server_args.append("--no-webui") if self.jinja: server_args.append("--jinja") + else: + server_args.append("--no-jinja") if self.reasoning_format is not None: server_args.extend(("--reasoning-format", self.reasoning_format)) if self.reasoning_budget is not None: From 142df17c9c296c846131041283c69edd2db754d8 Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Wed, 26 Nov 2025 23:32:30 -0600 Subject: [PATCH 05/33] vulkan: use a fixed 1KB buffer for the add_rms_fusion opt (#17514) --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 7f2cf795c9..7c7ce1d8e7 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -5289,7 +5289,8 @@ static void ggml_vk_init(ggml_backend_vk_context * ctx, size_t idx) { ctx->prealloc_size_x = 0; ctx->prealloc_size_y = 0; ctx->prealloc_size_split_k = 0; - ctx->prealloc_size_add_rms_partials = 0; + // Fixed size of 1KB, for deterministic behavior + ctx->prealloc_size_add_rms_partials = 1024; ctx->fence = ctx->device->device.createFence({}); ctx->almost_ready_fence = ctx->device->device.createFence({}); @@ -13095,7 +13096,6 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg ctx->fused_ops_write_mask = 0; } - ctx->prealloc_size_add_rms_partials = std::max(ctx->prealloc_size_add_rms_partials, ctx->prealloc_size_add_rms_partials_offset); ctx->last_total_mul_mat_bytes = total_mul_mat_bytes; if (vk_perf_logger_enabled) { From b78db3bd50b7d6b9578ced49243d787ee6622f89 Mon Sep 17 00:00:00 2001 From: Acly Date: Thu, 27 Nov 2025 06:54:19 +0100 Subject: [PATCH 06/33] vulkan : move contiguous checks to device_supports_op (#17490) * vulkan : remove op_supports_incontiguous and add missing constraints in device_supports_op * im2col: remove contraints on src0 (kernel input) --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 85 ++++++++++++---------------- 1 file changed, 35 insertions(+), 50 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 7c7ce1d8e7..b4ab85292f 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -8687,41 +8687,6 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const GGML_UNUSED(src2); } -static bool ggml_vk_op_supports_incontiguous(ggml_op op) { - switch (op) { - case GGML_OP_CPY: - case GGML_OP_GET_ROWS: - case GGML_OP_ADD: - case GGML_OP_SUB: - case GGML_OP_MUL: - case GGML_OP_DIV: - case GGML_OP_ADD_ID: - case GGML_OP_CONCAT: - case GGML_OP_UPSCALE: - case GGML_OP_SQR: - case GGML_OP_SQRT: - case GGML_OP_SIN: - case GGML_OP_COS: - case GGML_OP_LOG: - case GGML_OP_CLAMP: - case GGML_OP_PAD: - case GGML_OP_REPEAT: - case GGML_OP_REPEAT_BACK: - case GGML_OP_ROPE: - case GGML_OP_RMS_NORM: - case GGML_OP_CONV_2D_DW: - case GGML_OP_IM2COL: - case GGML_OP_IM2COL_3D: - case GGML_OP_SET_ROWS: - case GGML_OP_SUM: - case GGML_OP_SUM_ROWS: - case GGML_OP_MEAN: - return true; - default: - return false; - } -} - template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_unary_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst) 
{ const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type); const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type); @@ -8806,7 +8771,6 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3]; std::cerr << "), " << ggml_op_name(op) << ")"); GGML_ASSERT(op == GGML_OP_GET_ROWS || op == GGML_OP_CPY || (!ggml_is_quantized(src0->type) && (src1 == nullptr || !ggml_is_quantized(src1->type)))); // NOLINT - GGML_ASSERT(ggml_vk_op_supports_incontiguous(op) || ggml_vk_dim01_contiguous(src0)); // NOLINT GGML_ASSERT(dst->buffer != nullptr); const uint64_t ne00 = src0->ne[0]; const uint64_t ne01 = src0->ne[1]; @@ -8837,22 +8801,17 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1); - const bool op_supports_incontiguous = ggml_vk_op_supports_incontiguous(op); - - vk_subbuffer src0_buf = ggml_vk_tensor_subbuffer(ctx, src0, op_supports_incontiguous); - vk_subbuffer src1_buf = use_src1 ? ggml_vk_tensor_subbuffer(ctx, src1, op_supports_incontiguous) : vk_subbuffer{}; - vk_subbuffer src2_buf = use_src2 ? ggml_vk_tensor_subbuffer(ctx, src2, op_supports_incontiguous) : vk_subbuffer{}; - vk_subbuffer src3_buf = use_src3 ? ggml_vk_tensor_subbuffer(ctx, src3, op_supports_incontiguous) : vk_subbuffer{}; - vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst, op_supports_incontiguous); + vk_subbuffer src0_buf = ggml_vk_tensor_subbuffer(ctx, src0, true); + vk_subbuffer src1_buf = use_src1 ? ggml_vk_tensor_subbuffer(ctx, src1, true) : vk_subbuffer{}; + vk_subbuffer src2_buf = use_src2 ? ggml_vk_tensor_subbuffer(ctx, src2, true) : vk_subbuffer{}; + vk_subbuffer src3_buf = use_src3 ? ggml_vk_tensor_subbuffer(ctx, src3, true) : vk_subbuffer{}; + vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst, true); // Compute misalignment offset for descriptors and store it in in push constants. 
init_pushconst_tensor_offsets(ctx, pc, src0, src1, src2, src3, dst); std::array elements; - // Single call if dimension 2 is contiguous - GGML_ASSERT(op_supports_incontiguous || (ggml_is_contiguous(src0) && (src1 == nullptr || ggml_is_contiguous(src1)))); - switch (op) { case GGML_OP_NORM: case GGML_OP_RMS_NORM_BACK: @@ -13876,15 +13835,17 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm op->type == GGML_TYPE_F32; case GGML_OP_SILU_BACK: case GGML_OP_RMS_NORM_BACK: + return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32; case GGML_OP_SQR: case GGML_OP_SQRT: case GGML_OP_SIN: case GGML_OP_COS: case GGML_OP_CLAMP: + return op->src[0]->type == GGML_TYPE_F32; case GGML_OP_LEAKY_RELU: case GGML_OP_OPT_STEP_ADAMW: case GGML_OP_OPT_STEP_SGD: - return op->src[0]->type == GGML_TYPE_F32; + return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32; case GGML_OP_LOG: return op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16; case GGML_OP_ARGSORT: @@ -13919,17 +13880,29 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm return true; case GGML_OP_UPSCALE: case GGML_OP_ACC: + return op->src[0]->type == GGML_TYPE_F32; case GGML_OP_CONCAT: + return ggml_type_size(op->src[0]->type) == ggml_type_size(GGML_TYPE_F32); case GGML_OP_ADD1: + return (op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32) + || (op->src[0]->type == GGML_TYPE_F16 && op->src[1]->type == GGML_TYPE_F32) + || (op->src[0]->type == GGML_TYPE_F16 && op->src[1]->type == GGML_TYPE_F16); case GGML_OP_ARANGE: case GGML_OP_FILL: + return op->type == GGML_TYPE_F32; case GGML_OP_SCALE: + return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32; case GGML_OP_PAD: case GGML_OP_ROLL: + return op->src[0]->type == GGML_TYPE_F32; case GGML_OP_DIAG_MASK_INF: + return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32; case GGML_OP_SOFT_MAX: + return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32 + && (!op->src[1] || (op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F16)); case GGML_OP_SOFT_MAX_BACK: - return true; + return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32 + && ggml_is_contiguous(op->src[1]) && op->src[1]->type == GGML_TYPE_F32; case GGML_OP_SUM: case GGML_OP_SUM_ROWS: case GGML_OP_MEAN: @@ -13944,15 +13917,27 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm return false; } case GGML_OP_ARGMAX: + return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32; case GGML_OP_COUNT_EQUAL: + return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_I32 + && ggml_is_contiguous(op->src[1]) && op->src[1]->type == GGML_TYPE_I32; case GGML_OP_IM2COL: + return ggml_is_contiguous(op->src[1]) + && op->src[1]->type == GGML_TYPE_F32 + && (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16); case GGML_OP_IM2COL_3D: + return op->src[1]->type == GGML_TYPE_F32 + && (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16); case GGML_OP_TIMESTEP_EMBEDDING: + return op->src[0]->type == GGML_TYPE_F32; case GGML_OP_CONV_2D_DW: + return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) + && op->src[1]->type == GGML_TYPE_F32; case GGML_OP_POOL_2D: + return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32; case GGML_OP_RWKV_WKV6: case GGML_OP_RWKV_WKV7: - return true; + return true; // all inputs are contiguous, see ggml.c case 
GGML_OP_SSM_SCAN: { for (int i = 0; i < 6; i++) { @@ -13993,7 +13978,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm return true; } case GGML_OP_SSM_CONV: - return true; + return op->src[0]->type == GGML_TYPE_F32; case GGML_OP_CONV_TRANSPOSE_1D: return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32; case GGML_OP_CONV_2D: From 4fcd87cf7cbb131b3e28e121b29cc588e460eb40 Mon Sep 17 00:00:00 2001 From: Aleksei Nikiforov <103434461+AlekseiNikiforovIBM@users.noreply.github.com> Date: Thu, 27 Nov 2025 11:35:38 +0100 Subject: [PATCH 07/33] gguf-py : skip endian-conversion of MXFP4 data (#17523) * gguf_convert_endian.py: skip MXFP4 data * Use gguf.constants.GGML_QUANT_SIZES to determine block sizes --- gguf-py/gguf/scripts/gguf_convert_endian.py | 30 +++++++++------------ 1 file changed, 12 insertions(+), 18 deletions(-) diff --git a/gguf-py/gguf/scripts/gguf_convert_endian.py b/gguf-py/gguf/scripts/gguf_convert_endian.py index 0bda490a20..86bf87846c 100755 --- a/gguf-py/gguf/scripts/gguf_convert_endian.py +++ b/gguf-py/gguf/scripts/gguf_convert_endian.py @@ -19,6 +19,11 @@ import gguf logger = logging.getLogger("gguf-convert-endian") +def byteswap_noop(tensor, block_offs): + # this function is used when byteswapping is not needed + pass + + def byteswap_q4_0(tensor, block_offs): # Each block_q4_0 consists of an f16 delta (scaling factor) followed by 16 int8 quantizations. @@ -55,22 +60,11 @@ def byteswap_q6_k(tensor, block_offs): byteswap_tensors = { - gguf.GGMLQuantizationType.Q4_0: { - "block_size": 18, # 18 bytes = + 16 * - "byteswap_func": byteswap_q4_0, - }, - gguf.GGMLQuantizationType.Q8_0: { - "block_size": 34, # 34 bytes = + 32 * - "byteswap_func": byteswap_q8_0, - }, - gguf.GGMLQuantizationType.Q4_K: { - "block_size": 144, # 144 bytes = 2 * + 140 * - "byteswap_func": byteswap_q4_k, - }, - gguf.GGMLQuantizationType.Q6_K: { - "block_size": 210, # 210 bytes = + 208 * - "byteswap_func": byteswap_q6_k, - }, + gguf.GGMLQuantizationType.Q4_0: byteswap_q4_0, + gguf.GGMLQuantizationType.Q8_0: byteswap_q8_0, + gguf.GGMLQuantizationType.Q4_K: byteswap_q4_k, + gguf.GGMLQuantizationType.Q6_K: byteswap_q6_k, + gguf.GGMLQuantizationType.MXFP4: byteswap_noop, } @@ -135,8 +129,8 @@ def convert_byteorder(reader: gguf.GGUFReader, args: argparse.Namespace) -> None tensor.data.resize(newshape) - block_size = byteswap_tensors[tensor.tensor_type]["block_size"] - byteswap_func = byteswap_tensors[tensor.tensor_type]["byteswap_func"] + block_size = gguf.constants.GGML_QUANT_SIZES[tensor.tensor_type][1] + byteswap_func = byteswap_tensors[tensor.tensor_type] n_blocks = len(tensor.data) // block_size for block_num in (inner_pbar := tqdm(range(n_blocks), desc="Byte-swapping Blocks", leave=False)): From d21a76ac38ca9d95be610f72b613c7455fa04103 Mon Sep 17 00:00:00 2001 From: Eric Curtin Date: Thu, 27 Nov 2025 10:35:47 +0000 Subject: [PATCH 08/33] devops: Add build-essential to Ubuntu 26.04 image (#17531) This is no longer passing the build, needs more packages. 
Signed-off-by: Eric Curtin --- .devops/vulkan.Dockerfile | 1 + 1 file changed, 1 insertion(+) diff --git a/.devops/vulkan.Dockerfile b/.devops/vulkan.Dockerfile index ebf23ba5cf..fd7195c5be 100644 --- a/.devops/vulkan.Dockerfile +++ b/.devops/vulkan.Dockerfile @@ -50,6 +50,7 @@ WORKDIR /app RUN apt-get update \ && apt-get install -y \ + build-essential \ git \ python3 \ python3-pip \ From cd8370b40890dd40ae91a1cc0206107cf78c9303 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alberto=20Cabrera=20P=C3=A9rez?= <1478977+Alcpz@users.noreply.github.com> Date: Thu, 27 Nov 2025 11:25:14 +0000 Subject: [PATCH 09/33] ggml-cpu: aarm64: q4_K repack gemm and gemv implementations (dotprod only) (#17494) * Enabled q4_K_4x8 path * Fixed generic Q4_K 8x4 implementation * wip: dotprod gemm * Working arm q4_K dotprod gemm Signed-off-by: Alberto Cabrera * Undo acc rename Signed-off-by: Alberto Cabrera * Q4_K arm dotprod gemm Signed-off-by: Alberto Cabrera * Fix: q4_qs reinterpret from uint to int Signed-off-by: Alberto Cabrera * Removed comments * Fixed macro guards * Fixed unused vars in generic implementation * Fixed unused vars in 8x4 repack * Fixed unused vars in generic implementation, unneeded comment * Missing arch fallback for x86 * minor : style --------- Signed-off-by: Alberto Cabrera Co-authored-by: Georgi Gerganov --- ggml/src/ggml-cpu/arch-fallback.h | 22 ++ ggml/src/ggml-cpu/arch/arm/repack.cpp | 339 +++++++++++++++++++++++++- ggml/src/ggml-cpu/repack.cpp | 235 +++++++++++++++++- ggml/src/ggml-cpu/repack.h | 6 + 4 files changed, 598 insertions(+), 4 deletions(-) diff --git a/ggml/src/ggml-cpu/arch-fallback.h b/ggml/src/ggml-cpu/arch-fallback.h index d27a969706..0775c87f98 100644 --- a/ggml/src/ggml-cpu/arch-fallback.h +++ b/ggml/src/ggml-cpu/arch-fallback.h @@ -33,10 +33,12 @@ // repack.cpp #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4 #define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8 +#define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4 #define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8 #define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0 #define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0 +#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K #define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 @@ -44,12 +46,14 @@ #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0 #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0 +#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K #define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0 #elif defined(__aarch64__) || defined(__arm__) || defined(_M_ARM) || defined(_M_ARM64) // repack.cpp +#define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4 #define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0 #define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K @@ -58,11 +62,14 @@ #elif defined(__x86_64__) || defined(__i386__) || defined(_M_IX86) || 
defined(_M_X64) // repack.cpp #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4 +#define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4 #define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0 +#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0 +#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 #elif defined(__POWERPC__) || defined(__powerpc__) // ref: https://github.com/ggml-org/llama.cpp/pull/14146#issuecomment-2972561679 @@ -74,10 +81,12 @@ // repack.cpp #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4 #define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8 +#define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4 #define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8 #define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0 #define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0 +#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K #define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 @@ -85,6 +94,7 @@ #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0 #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0 +#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K #define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 @@ -99,10 +109,12 @@ // repack.cpp #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4 #define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8 +#define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4 #define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8 #define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0 #define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0 +#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K #define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 @@ -110,6 +122,7 @@ #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0 #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0 +#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K #define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 @@ -132,15 +145,18 @@ // repack.cpp #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4 #define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8 +#define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4 
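The new q4_K 8x4 entries follow the file's existing fallback convention: on a target without an optimized kernel, the _generic definition is simply compiled under the public symbol name, so callers are unaffected. A minimal sketch of that pattern, using hypothetical names rather than the real kernels:

// Hypothetical illustration of the aliasing convention used in arch-fallback.h.
#if !defined(HAVE_OPTIMIZED_GEMV)
// No arch-specific kernel on this target: the generic body below is compiled
// directly under the public name, exactly like the *_generic #defines above.
#define my_gemv_generic my_gemv
#endif

void my_gemv_generic(const float * x, float * y, int n) {
    for (int i = 0; i < n; ++i) {
        y[i] = 2.0f * x[i];   // stand-in body; the real generic kernels do the block dot products
    }
}

// On targets that do define HAVE_OPTIMIZED_GEMV, my_gemv() is provided by an
// arch-specific source file and my_gemv_generic keeps its own symbol as a fallback.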
#define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8 #define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0 +#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K #define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0 +#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K #define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 @@ -161,10 +177,12 @@ // repack.cpp #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4 #define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8 +#define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4 #define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8 #define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0 #define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0 +#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K #define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 @@ -172,6 +190,7 @@ #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0 #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0 +#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K #define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 @@ -194,10 +213,12 @@ // repack.cpp #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4 #define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8 +#define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4 #define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8 #define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0 #define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0 +#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K #define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 @@ -205,6 +226,7 @@ #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0 #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0 +#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K #define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 diff --git a/ggml/src/ggml-cpu/arch/arm/repack.cpp b/ggml/src/ggml-cpu/arch/arm/repack.cpp index d2adfbea87..082bd2bf04 100644 --- 
a/ggml/src/ggml-cpu/arch/arm/repack.cpp +++ b/ggml/src/ggml-cpu/arch/arm/repack.cpp @@ -497,6 +497,140 @@ void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const ggml_gemv_iq4_nl_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc); } +void ggml_gemv_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + constexpr int qk = QK_K; + const int nb = n / qk; + + constexpr int ncols_interleaved = 8; + constexpr int blocklen = 8; + + assert(n % qk == 0); + assert(nr % 4 == 0); + assert(nc % ncols_interleaved == 0); + + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + +#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD) + constexpr int col_groups = ncols_interleaved / 4; // 0123 and 4567 + const uint8x16_t m4b = vdupq_n_u8(0x0f); + + // 1x8 tile = 2 x 4 + float32x4_t acc_f32[col_groups]; + + const block_q8_K * GGML_RESTRICT q8_ptr = (const block_q8_K *) vy; + + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q4_Kx8 * GGML_RESTRICT q4_ptr = (const block_q4_Kx8 *) vx + (x * nb); + + for (int i = 0; i < col_groups; i++) { + acc_f32[i] = vdupq_n_f32(0); + } + + for (int b = 0; b < nb; b++) { + float32x4_t q4_d_0 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].d)); // d0 d1 d2 d3 + float32x4_t q4_d_1 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].d + 4)); // d4 d5 d6 d7 + float32x4_t q8_d = vdupq_n_f32(q8_ptr[b].d); + float32x4_t sb_scale_0123 = vmulq_f32(q4_d_0, q8_d); + float32x4_t sb_scale_4567 = vmulq_f32(q4_d_1, q8_d); + float32x4_t q4_dmin_0 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].dmin)); // dmin 0..3 + float32x4_t q4_dmin_1 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].dmin + 4)); // dmin 4..7 + float32x4_t sb_min_0123 = vmulq_f32(q4_dmin_0, q8_d); + float32x4_t sb_min_4567 = vmulq_f32(q4_dmin_1, q8_d); + + // interleaved bias_acc: [0]->r0 0123, [1]->r0 4567 + int32x4_t bias_acc[2] = { vdupq_n_s32(0), vdupq_n_s32(0) }; + int32x4_t acc_lo[col_groups]; + int32x4_t acc_hi[col_groups]; + + // Each bsum is 16 elements, pairwise add leaves us with the 8 bsums of the entire block + const int16x8_t bsums = vpaddq_s16(vld1q_s16(q8_ptr[b].bsums), vld1q_s16(q8_ptr[b].bsums + 8)); + int16_t bsums_arr[8]; + vst1q_s16(bsums_arr, bsums); + for (int sb = 0; sb < QK_K / 64; sb++) { + for (int i = 0; i < col_groups; i++) { + acc_lo[i] = vdupq_n_s32(0); + acc_hi[i] = vdupq_n_s32(0); + } + // Need scales for the low and high nibbles + // 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total + int16x8_t q4sb_mins[2]; + int16x8_t q4sb_scales[2]; + for (int i = 0; i < 2; i++) { + int8_t aux_q4sb[8]; + const int offset = sb * 24 + i * 12; + decode_q4_Kx8_scales_mins(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb); + q4sb_scales[i] = vmovl_s8(vld1_s8(aux_q4sb)); + } + + int8x16_t q8_qs[64 / 16]; + for (int i = 0; i < 64 / 16; i++) { + q8_qs[i] = vld1q_s8(q8_ptr[b].qs + sb * 64 + i * 16); + } + + for (int c = 0; c < col_groups; c++) { + uint8x16_t q4_cols[8]; + for (int i = 0; i < 8; i++) { + q4_cols[i] = vld1q_u8(q4_ptr[b].qs + sb * QK_K + i * 32 + 16 * c); + } + + acc_lo[c] = vdotq_laneq_s32(acc_lo[c], vreinterpretq_s8_u8(vandq_u8(q4_cols[0], m4b)), q8_qs[0], 0); + acc_lo[c] = vdotq_laneq_s32(acc_lo[c], vreinterpretq_s8_u8(vandq_u8(q4_cols[1], m4b)), q8_qs[0], 1); + acc_lo[c] = vdotq_laneq_s32(acc_lo[c], vreinterpretq_s8_u8(vandq_u8(q4_cols[2], m4b)), q8_qs[0], 2); + acc_lo[c] = vdotq_laneq_s32(acc_lo[c], 
vreinterpretq_s8_u8(vandq_u8(q4_cols[3], m4b)), q8_qs[0], 3); + acc_lo[c] = vdotq_laneq_s32(acc_lo[c], vreinterpretq_s8_u8(vandq_u8(q4_cols[4], m4b)), q8_qs[1], 0); + acc_lo[c] = vdotq_laneq_s32(acc_lo[c], vreinterpretq_s8_u8(vandq_u8(q4_cols[5], m4b)), q8_qs[1], 1); + acc_lo[c] = vdotq_laneq_s32(acc_lo[c], vreinterpretq_s8_u8(vandq_u8(q4_cols[6], m4b)), q8_qs[1], 2); + acc_lo[c] = vdotq_laneq_s32(acc_lo[c], vreinterpretq_s8_u8(vandq_u8(q4_cols[7], m4b)), q8_qs[1], 3); + + acc_hi[c] = vdotq_laneq_s32(acc_hi[c], vreinterpretq_s8_u8(vshrq_n_u8(q4_cols[0], 4)), q8_qs[2], 0); + acc_hi[c] = vdotq_laneq_s32(acc_hi[c], vreinterpretq_s8_u8(vshrq_n_u8(q4_cols[1], 4)), q8_qs[2], 1); + acc_hi[c] = vdotq_laneq_s32(acc_hi[c], vreinterpretq_s8_u8(vshrq_n_u8(q4_cols[2], 4)), q8_qs[2], 2); + acc_hi[c] = vdotq_laneq_s32(acc_hi[c], vreinterpretq_s8_u8(vshrq_n_u8(q4_cols[3], 4)), q8_qs[2], 3); + acc_hi[c] = vdotq_laneq_s32(acc_hi[c], vreinterpretq_s8_u8(vshrq_n_u8(q4_cols[4], 4)), q8_qs[3], 0); + acc_hi[c] = vdotq_laneq_s32(acc_hi[c], vreinterpretq_s8_u8(vshrq_n_u8(q4_cols[5], 4)), q8_qs[3], 1); + acc_hi[c] = vdotq_laneq_s32(acc_hi[c], vreinterpretq_s8_u8(vshrq_n_u8(q4_cols[6], 4)), q8_qs[3], 2); + acc_hi[c] = vdotq_laneq_s32(acc_hi[c], vreinterpretq_s8_u8(vshrq_n_u8(q4_cols[7], 4)), q8_qs[3], 3); + } + + // Scales + // row c0123 blk0 and blk1 + const int16x4_t sc_0123_lo = vget_low_s16(q4sb_scales[0]); + const int16x4_t sc_0123_hi = vget_low_s16(q4sb_scales[1]); + const float32x4_t sumf_0123 = vcvtq_f32_s32(vaddq_s32(vmulq_s32(vmovl_s16(sc_0123_lo), acc_lo[0]), + vmulq_s32(vmovl_s16(sc_0123_hi), acc_hi[0]))); + acc_f32[0] = vfmaq_f32(acc_f32[0], sb_scale_0123, sumf_0123); + // row c4567 blk0 and blk1 + const int16x4_t sc_4567_lo = vget_high_s16(q4sb_scales[0]); + const int16x4_t sc_4567_hi = vget_high_s16(q4sb_scales[1]); + const float32x4_t sumf_4567 = vcvtq_f32_s32(vaddq_s32(vmulq_s32(vmovl_s16(sc_4567_lo), acc_lo[1]), + vmulq_s32(vmovl_s16(sc_4567_hi), acc_hi[1]))); + acc_f32[1] = vfmaq_f32(acc_f32[1], sb_scale_4567, sumf_4567); + + // Bias Correction + const int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[2 * sb + 0]); + const int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[2 * sb + 1]); + + bias_acc[0] = vmlal_s16(bias_acc[0], bsums_vec_lo, vget_low_s16(q4sb_mins[0])); + bias_acc[0] = vmlal_s16(bias_acc[0], bsums_vec_hi, vget_low_s16(q4sb_mins[1])); + bias_acc[1] = vmlal_s16(bias_acc[1], bsums_vec_lo, vget_high_s16(q4sb_mins[0])); + bias_acc[1] = vmlal_s16(bias_acc[1], bsums_vec_hi, vget_high_s16(q4sb_mins[1])); + } // for sb + + acc_f32[0] = vmlsq_f32(acc_f32[0], vcvtq_f32_s32(bias_acc[0]), sb_min_0123); + acc_f32[1] = vmlsq_f32(acc_f32[1], vcvtq_f32_s32(bias_acc[1]), sb_min_4567); + } // for b + + int base = x * ncols_interleaved; + vst1q_f32(s + base, acc_f32[0]); + vst1q_f32(s + base + 4, acc_f32[1]); + } // for x + return; +#endif // #if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD) + ggml_gemv_q4_K_8x4_q8_K_generic(n, s, bs, vx, vy, nr, nc); +} + void ggml_gemv_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, @@ -518,7 +652,7 @@ void ggml_gemv_q4_K_8x8_q8_K(int n, UNUSED(ncols_interleaved); UNUSED(blocklen); -#if defined(__aarch64__) && defined(__ARM_NEON) +#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD) constexpr int col_pairs = ncols_interleaved / 2; const uint8x16_t m4b = vdupq_n_u8(0x0f); @@ -615,7 +749,6 @@ void ggml_gemv_q4_K_8x8_q8_K(int n, float32x4_t sb_scale = p == 0 ? 
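Each vdotq_laneq_s32 above accumulates four 4-way int8 dot products at once, pairing each 32-bit group of the interleaved Q4 column data with one broadcast 32-bit group of Q8 values. A scalar model of that single intrinsic (illustrative only):

#include <cstdint>

// Scalar model of vdotq_laneq_s32(acc, a, b, lane): each 32-bit lane j of acc gains the
// 4-way int8 dot product of a[4j..4j+3] with the four bytes of b's selected 32-bit lane,
// so one instruction advances four interleaved columns by four values each.
static void sdot_lane_model(int32_t acc[4], const int8_t a[16], const int8_t b[16], int lane) {
    for (int j = 0; j < 4; ++j) {
        int32_t sum = 0;
        for (int k = 0; k < 4; ++k) {
            sum += int32_t(a[4 * j + k]) * int32_t(b[4 * lane + k]);
        }
        acc[j] += sum;
    }
}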
sb_scale_0 : sb_scale_1; // 0123 or 4567 - // TODO: Single superblock mul at the end of the superblock float32x4_t sumf_0 = vcvtq_f32_s32(vmulq_s32(vmovl_s16(group_scales_lo), vpaddq_s32(acc_lo[p], acc_lo[p + 1]))); acc_f32[i] = vfmaq_f32(acc_f32[i], sb_scale, sumf_0); @@ -649,7 +782,7 @@ void ggml_gemv_q4_K_8x8_q8_K(int n, vst1q_f32(s + base + 4, acc_f32[1]); } // for x return; -#endif // defined(__aarch64__) && defined(__ARM_NEON) +#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD) ggml_gemv_q4_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc); } @@ -2069,6 +2202,206 @@ void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const ggml_gemm_iq4_nl_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc); } +void ggml_gemm_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + constexpr int qk = QK_K; + const int nb = n / qk; + + constexpr int ncols_interleaved = 8; + constexpr int blocklen = 4; + + assert(n % qk == 0); + assert(nr % 4 == 0); + assert(nc % ncols_interleaved == 0); + + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + +#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD) + constexpr int q8_k_blocklen = 4; + constexpr int acc_size = 2 * 4; // 2 row pairs × 4 col pairs + const uint8x16_t m4b = vdupq_n_u8(0x0f); + + // 8 accumulators: 2 row pairs × 4 col pairs + float32x4_t acc_f32[acc_size]; + + for (int y = 0; y < nr / q8_k_blocklen; y++) { + const block_q8_Kx4 * GGML_RESTRICT q8_ptr = (const block_q8_Kx4 *) vy + (y * nb); + + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q4_Kx8 * GGML_RESTRICT q4_ptr = (const block_q4_Kx8 *) vx + (x * nb); + + for (int i = 0; i < acc_size; i++) { + acc_f32[i] = vdupq_n_f32(0); + } + + for (int b = 0; b < nb; b++) { + // d4 0 1 2 3, 4 5 6 7 + float32x4_t q4_d_0123 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].d)); + float32x4_t q4_d_4567 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].d + 4)); + // d8 0 1 2 3 + float32x4_t q8_d_0123 = vld1q_f32(q8_ptr[b].d); + // mins + float32x4_t q4_dmin_0123 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].dmin)); + float32x4_t q4_dmin_4567 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].dmin + 4)); + + // Precomputation of scales and mins + float32x4_t sbd_scale_0123[q8_k_blocklen]; + float32x4_t sbd_scale_4567[q8_k_blocklen]; + float32x4_t sbd_min_0123[q8_k_blocklen]; + float32x4_t sbd_min_4567[q8_k_blocklen]; + + sbd_scale_0123[0] = vmulq_laneq_f32(q4_d_0123, q8_d_0123, 0); + sbd_scale_4567[0] = vmulq_laneq_f32(q4_d_4567, q8_d_0123, 0); + sbd_min_0123[0] = vmulq_laneq_f32(q4_dmin_0123, q8_d_0123, 0); + sbd_min_4567[0] = vmulq_laneq_f32(q4_dmin_4567, q8_d_0123, 0); + + sbd_scale_0123[1] = vmulq_laneq_f32(q4_d_0123, q8_d_0123, 1); + sbd_scale_4567[1] = vmulq_laneq_f32(q4_d_4567, q8_d_0123, 1); + sbd_min_0123[1] = vmulq_laneq_f32(q4_dmin_0123, q8_d_0123, 1); + sbd_min_4567[1] = vmulq_laneq_f32(q4_dmin_4567, q8_d_0123, 1); + + sbd_scale_0123[2] = vmulq_laneq_f32(q4_d_0123, q8_d_0123, 2); + sbd_scale_4567[2] = vmulq_laneq_f32(q4_d_4567, q8_d_0123, 2); + sbd_min_0123[2] = vmulq_laneq_f32(q4_dmin_0123, q8_d_0123, 2); + sbd_min_4567[2] = vmulq_laneq_f32(q4_dmin_4567, q8_d_0123, 2); + + sbd_scale_0123[3] = vmulq_laneq_f32(q4_d_0123, q8_d_0123, 3); + sbd_scale_4567[3] = vmulq_laneq_f32(q4_d_4567, q8_d_0123, 3); + sbd_min_0123[3] = vmulq_laneq_f32(q4_dmin_0123, q8_d_0123, 3); + sbd_min_4567[3] = 
vmulq_laneq_f32(q4_dmin_4567, q8_d_0123, 3); + + // Precomputation of bsums, each vpaddq calcs all the bsums for each row + const int16x8_t bsums[q8_k_blocklen] = { + vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 0), vld1q_s16(q8_ptr[b].bsums + 16 * 0 + 8)), + vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 1), vld1q_s16(q8_ptr[b].bsums + 16 * 1 + 8)), + vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 2), vld1q_s16(q8_ptr[b].bsums + 16 * 2 + 8)), + vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 3), vld1q_s16(q8_ptr[b].bsums + 16 * 3 + 8)), + }; + int16_t bsums_arr[QK_K / 64][8]; + for (int q8_row = 0; q8_row < 4; q8_row++) { + vst1q_s16(bsums_arr[q8_row], bsums[q8_row]); + } + + // interleaved bias_acc: [0]->r0 0123, [1]->r1 0123, .., [4]->r0 4567, [5]->r1 4567 .. + int32x4_t bias_acc[acc_size]; + for (int i = 0; i < acc_size; i++) { + bias_acc[i] = vdupq_n_s32(0); + } + + for (int sb = 0; sb < QK_K / 64; sb++) { + // Int accumulators for qs vecdot (4 row x 2 col quartets) + int32x4_t acc_lo[acc_size]; + int32x4_t acc_hi[acc_size]; + for (int i = 0; i < acc_size; i++) { + acc_lo[i] = vdupq_n_s32(0); + acc_hi[i] = vdupq_n_s32(0); + } + // Need scales for the low and high nibbles + // 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total + int16x8_t q4sb_scales[2]; + int16x8_t q4sb_mins[2]; + for (int i = 0; i < 2; i++) { + int8_t aux_q4sb[8]; + const int offset = sb * 24 + i * 12; + decode_q4_Kx8_scales_mins(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb); + q4sb_scales[i] = vmovl_s8(vld1_s8(aux_q4sb)); + } + + constexpr int reads_per_sb = 8; // 8 * 16 bytes each => 32 qs * 4 rows + for (int k = 0; k < reads_per_sb; k++) { + const int8x16_t q8_blk0 = vld1q_s8(q8_ptr[b].qs + sb * 256 + 16 * k); + const int8x16_t q8_blk1 = vld1q_s8(q8_ptr[b].qs + sb * 256 + 16 * k + 128); + + // 0..3 & 32..35 + const uint8x16_t q4_0123 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 32 * k); + const uint8x16_t q4_4567 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 32 * k + 16); + + const int8x16_t q4_0123_lo = vreinterpretq_s8_u8(vandq_u8(q4_0123, m4b)); + const int8x16_t q4_0123_hi = vreinterpretq_s8_u8(vshrq_n_u8(q4_0123, 4)); + + acc_lo[0] = vdotq_laneq_s32(acc_lo[0], q4_0123_lo, q8_blk0, 0); // 0..3 r0 c0123 + acc_lo[1] = vdotq_laneq_s32(acc_lo[1], q4_0123_lo, q8_blk0, 1); // 0..3 r1 c0123 + acc_lo[2] = vdotq_laneq_s32(acc_lo[2], q4_0123_lo, q8_blk0, 2); // 0..3 r2 c0123 + acc_lo[3] = vdotq_laneq_s32(acc_lo[3], q4_0123_lo, q8_blk0, 3); // 0..3 r3 c0123 + + acc_hi[0] = vdotq_laneq_s32(acc_hi[0], q4_0123_hi, q8_blk1, 0); // 32..35 r0 c0123 + acc_hi[1] = vdotq_laneq_s32(acc_hi[1], q4_0123_hi, q8_blk1, 1); // 32..35 r1 c0123 + acc_hi[2] = vdotq_laneq_s32(acc_hi[2], q4_0123_hi, q8_blk1, 2); // 32..35 r2 c0123 + acc_hi[3] = vdotq_laneq_s32(acc_hi[3], q4_0123_hi, q8_blk1, 3); // 32..35 r3 c0123 + + const int8x16_t q4_4567_lo = vreinterpretq_s8_u8(vandq_u8(q4_4567, m4b)); + const int8x16_t q4_4567_hi = vreinterpretq_s8_u8(vshrq_n_u8(q4_4567, 4)); + + acc_lo[4] = vdotq_laneq_s32(acc_lo[4], q4_4567_lo, q8_blk0, 0); // 0..3 r0 c4567 + acc_lo[5] = vdotq_laneq_s32(acc_lo[5], q4_4567_lo, q8_blk0, 1); // 0..3 r1 c4567 + acc_lo[6] = vdotq_laneq_s32(acc_lo[6], q4_4567_lo, q8_blk0, 2); // 0..3 r2 c4567 + acc_lo[7] = vdotq_laneq_s32(acc_lo[7], q4_4567_lo, q8_blk0, 3); // 0..3 r3 c4567 + + acc_hi[4] = vdotq_laneq_s32(acc_hi[4], q4_4567_hi, q8_blk1, 0); // 32..35 r0 c4567 + acc_hi[5] = vdotq_laneq_s32(acc_hi[5], q4_4567_hi, q8_blk1, 1); // 32..35 r1 c4567 + acc_hi[6] = vdotq_laneq_s32(acc_hi[6], q4_4567_hi, q8_blk1, 2); // 32..35 r2 c4567 + 
acc_hi[7] = vdotq_laneq_s32(acc_hi[7], q4_4567_hi, q8_blk1, 3); // 32..35 r3 c4567 + } + + // Scale and bias application + // acc is stored interleaved to match output layout + const int16x4_t sc_0123_lo = vget_low_s16(q4sb_scales[0]); + const int16x4_t sc_4567_lo = vget_high_s16(q4sb_scales[0]); + const int16x4_t sc_0123_hi = vget_low_s16(q4sb_scales[1]); + const int16x4_t sc_4567_hi = vget_high_s16(q4sb_scales[1]); + for (int row = 0; row < q8_k_blocklen; row++) { + // Bias correction + // row c0123 blk0 and blk1 + const float32x4_t sumf_0123 = + vcvtq_f32_s32(vaddq_s32(vmulq_s32(vmovl_s16(sc_0123_lo), acc_lo[row]), + vmulq_s32(vmovl_s16(sc_0123_hi), acc_hi[row]))); + acc_f32[2 * row] = vfmaq_f32(acc_f32[2 * row], sbd_scale_0123[row], sumf_0123); + + // row c4567 blk0 and blk1 + const float32x4_t sumf_4567 = + vcvtq_f32_s32(vaddq_s32(vmulq_s32(vmovl_s16(sc_4567_lo), acc_lo[row + 4]), + vmulq_s32(vmovl_s16(sc_4567_hi), acc_hi[row + 4]))); + acc_f32[2 * row + 1] = vfmaq_f32(acc_f32[2 * row + 1], sbd_scale_4567[row], sumf_4567); + + // Bias + const int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[sb][row * 2]); + const int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[sb][row * 2 + 1]); + + // row c0123 blk0 and blk1 + bias_acc[2 * row] = vmlal_s16(bias_acc[2 * row], bsums_vec_lo, vget_low_s16(q4sb_mins[0])); + bias_acc[2 * row] = vmlal_s16(bias_acc[2 * row], bsums_vec_hi, vget_low_s16(q4sb_mins[1])); + + // row c4567 blk0 and blk1 + bias_acc[2 * row + 1] = + vmlal_s16(bias_acc[2 * row + 1], bsums_vec_lo, vget_high_s16(q4sb_mins[0])); + bias_acc[2 * row + 1] = + vmlal_s16(bias_acc[2 * row + 1], bsums_vec_hi, vget_high_s16(q4sb_mins[1])); + } + } // for sb + + for (int row = 0; row < q8_k_blocklen; row++) { + acc_f32[2 * row] = vmlsq_f32(acc_f32[2 * row], vcvtq_f32_s32(bias_acc[2 * row]), sbd_min_0123[row]); + acc_f32[2 * row + 1] = + vmlsq_f32(acc_f32[2 * row + 1], vcvtq_f32_s32(bias_acc[2 * row + 1]), sbd_min_4567[row]); + } + } // for b + + for (int i = 0; i < q8_k_blocklen; i++) { + int row = y * q8_k_blocklen + i; + for (int j = 0; j < 2; j++) { + int col = x * ncols_interleaved + j * 4; + int offset = row * bs + col; + vst1q_f32(s + offset, acc_f32[2 * i + j]); + } + } + } // for x + } // for y + return; +#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD) + ggml_gemm_q4_K_8x4_q8_K_generic(n, s, bs, vx, vy, nr, nc); +} + void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, diff --git a/ggml/src/ggml-cpu/repack.cpp b/ggml/src/ggml-cpu/repack.cpp index 875faedf2d..9f0d449bd6 100644 --- a/ggml/src/ggml-cpu/repack.cpp +++ b/ggml/src/ggml-cpu/repack.cpp @@ -124,6 +124,58 @@ void ggml_quantize_mat_q8_0_4x8_generic(const float * GGML_RESTRICT x, void * GG } } + +void ggml_quantize_mat_q8_K_4x4_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { + assert(QK_K == 256); + assert(k % QK_K == 0); + const int nb = k / QK_K; + + block_q8_Kx4 * GGML_RESTRICT y = (block_q8_Kx4 *) vy; + + // scalar + const int blck_size_interleave = 4; + float srcv[4][QK_K]; + float iscale[4]; + + for (int i = 0; i < nb; i++) { + for (int row_iter = 0; row_iter < 4; row_iter++) { + float amax = 0.0f; // absolute max + float max = 0; + + for (int j = 0; j < QK_K; j++) { + srcv[row_iter][j] = x[row_iter * k + i * QK_K + j]; + // Update the maximum value of the corresponding super block + if(amax < fabsf(srcv[row_iter][j])) { + amax = fabsf(srcv[row_iter][j]); + max = srcv[row_iter][j]; + } + } + + iscale[row_iter] = amax ? 
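Both the dotprod kernels above and the generic fallbacks that follow reduce each 256-wide superblock to the same quantity: d·d8·Σ_sb sc_sb·(q4_sb·q8_sb) − dmin·d8·Σ_sb m_sb·bsum_sb. A scalar sketch of that reduction, with the packed fields replaced by plain arrays for clarity (an assumption for illustration, not the on-disk layout):

#include <cstdint>

// Scalar sketch of one Q4_K x Q8_K superblock (QK_K = 256 = 8 sub-blocks of 32 values),
// with quants and the 6-bit scales/mins assumed to be already unpacked into plain arrays.
static float q4K_q8K_superblock(const uint8_t q4[256], const int8_t q8[256],
                                const int sc[8], const int mins[8],
                                float d, float dmin, float d8) {
    float acc = 0.0f, bias = 0.0f;
    for (int sb = 0; sb < 8; ++sb) {
        int dot = 0, bsum = 0;
        for (int i = 0; i < 32; ++i) {
            dot  += int(q4[sb * 32 + i]) * int(q8[sb * 32 + i]);
            bsum += int(q8[sb * 32 + i]);
        }
        acc  += float(sc[sb])   * float(dot);   // per-sub-block scale
        bias += float(mins[sb]) * float(bsum);  // per-sub-block min, applied through the bsums
    }
    return d * d8 * acc - dmin * d8 * bias;     // corresponds to sumf[j] - sum_minf[j] in the generic code
}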
-127.f/max : 0; + + y[i].d[row_iter] = amax ? 1/iscale[row_iter] : 0; + } + + for (int j = 0; j < QK_K / 4; j++) { + y[i].bsums[j] = 0; + } + + // Quants values are interleaved in sequence of four bytes from corresponding super blocks + // Bsums values are interleaved in sequence of four bsums from each super block taken for interleaving + // i.e first four bsums from the first super block, followed by first four bsums from second super block and so on + for (int j = 0; j < QK_K * 4; j++) { + int src_offset = (j / (4 * blck_size_interleave)) * blck_size_interleave; + int src_id = (j % (4 * blck_size_interleave)) / blck_size_interleave; + src_offset += (j % blck_size_interleave); + int index = (((j & 15) >> 2) << 2) + ((j >> 8) << 4) + ((j >> 6) & 3); + + float x0 = srcv[src_id][src_offset] * iscale[src_id]; + y[i].qs[j] = nearest_int(x0); + y[i].bsums[index] += y[i].qs[j]; + } + } +} + void ggml_quantize_mat_q8_K_4x8_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { assert(QK_K == 256); assert(k % QK_K == 0); @@ -192,6 +244,12 @@ template <> void ggml_quantize_mat_t<8, GGML_TYPE_Q8_0>(const float * GGML_RESTR ggml_quantize_mat_q8_0_4x8(x, vy, n_per_row); } +template <> void ggml_quantize_mat_t<4, GGML_TYPE_Q8_K>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) { + assert(nrow == 4); + UNUSED(nrow); + ggml_quantize_mat_q8_K_4x4(x, vy, n_per_row); +} + template <> void ggml_quantize_mat_t<8, GGML_TYPE_Q8_K>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) { assert(nrow == 4); UNUSED(nrow); @@ -333,6 +391,77 @@ void ggml_gemv_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, } } +void ggml_gemv_q4_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK_K; + const int nb = n / qk; + const int ncols_interleaved = 8; + const int blocklen = 4; + static const uint32_t kmask1 = 0x3f3f3f3f; + static const uint32_t kmask2 = 0x0f0f0f0f; + static const uint32_t kmask3 = 0x03030303; + + assert (n % qk == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(bs); + UNUSED(nr); + + float sumf[8]; + float sum_minf[8]; + uint32_t utmp[32]; + int sumi1; + int sumi2; + int sumi; + + const block_q8_K * a_ptr = (const block_q8_K *) vy; + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q4_Kx8 * b_ptr = (const block_q4_Kx8 *) vx + (x * nb); + + for (int j = 0; j < ncols_interleaved; j++) { + sumf[j] = 0.0; + sum_minf[j] = 0.0; + } + for (int l = 0; l < nb; l++) { + for (int sb = 0; sb < 8; sb++) { + memcpy(utmp + sb * 4, b_ptr[l].scales + sb * 12, 12); + utmp[sb * 4 + 3] = ((utmp[sb * 4 + 2] >> 4) & kmask2) | (((utmp[sb * 4 + 1] >> 6) & kmask3) << 4); + const uint32_t uaux_0 = utmp[sb * 4 + 1] & kmask1; + utmp[sb * 4 + 1] = (utmp[sb * 4 + 2] & kmask2) | (((utmp[sb * 4 + 0] >> 6) & kmask3) << 4); + utmp[sb * 4 + 2] = uaux_0; + utmp[sb * 4 + 0] &= kmask1; + } + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + uint8_t * scales_0 = (uint8_t *) utmp + (k / 8) * 32; + uint8_t * scales_1 = (uint8_t *) utmp + (k / 8) * 32 + 16; + for (int j = 0; j < ncols_interleaved; j++) { + sumi1 = 0; + sumi2 = 0; + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF); + const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4); + sumi1 = (v0 * a_ptr[l].qs[(k 
/ 8) * 64 + (k % 8) * blocklen + i]); + sumi2 = (v1 * a_ptr[l].qs[(k / 8) * 64 + (k % 8) * blocklen + i + 32]); + sumi1 = sumi1 * scales_0[j]; + sumi2 = sumi2 * scales_1[j]; + sumi += sumi1 + sumi2; + } + sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d; + } + } + for (int sb = 0; sb < 8; sb++) { + uint8_t * mins = (uint8_t *) utmp + 8 + sb * 16; + for (int j = 0; j < ncols_interleaved; j++) { + sum_minf[j] += mins[j] * (a_ptr[l].bsums[sb * 2] + a_ptr[l].bsums[sb * 2 + 1]) * GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d; + } + } + } + for (int j = 0; j < ncols_interleaved; j++) { + s[x * ncols_interleaved + j] = sumf[j] - sum_minf[j]; + } + } +} + void ggml_gemv_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK_K; const int nb = n / qk; @@ -727,6 +856,89 @@ void ggml_gemm_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, } } +void ggml_gemm_q4_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK_K; + const int nb = n / qk; + const int ncols_interleaved = 8; + const int blocklen = 4; + static const uint32_t kmask1 = 0x3f3f3f3f; + static const uint32_t kmask2 = 0x0f0f0f0f; + static const uint32_t kmask3 = 0x03030303; + + assert (n % qk == 0); + assert (nr % 4 == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + + float sumf[4][8]; + float sum_minf[4][8]; + uint32_t utmp[32]; + int sumi1; + int sumi2; + int sumi; + + for (int y = 0; y < nr / 4; y++) { + const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q4_Kx8 * b_ptr = (const block_q4_Kx8 *) vx + (x * nb); + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumf[m][j] = 0.0; + sum_minf[m][j] = 0.0; + } + } + for (int l = 0; l < nb; l++) { + for (int sb = 0; sb < 8; sb++) { + memcpy(utmp + sb * 4, b_ptr[l].scales + sb * 12, 12); + utmp[sb * 4 + 3] = ((utmp[sb * 4 + 2] >> 4) & kmask2) | (((utmp[sb * 4 + 1] >> 6) & kmask3) << 4); + const uint32_t uaux_0 = utmp[sb * 4 + 1] & kmask1; + utmp[sb * 4 + 1] = (utmp[sb * 4 + 2] & kmask2) | (((utmp[sb * 4 + 0] >> 6) & kmask3) << 4); + utmp[sb * 4 + 2] = uaux_0; + utmp[sb * 4 + 0] &= kmask1; + } + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + uint8_t * scales_0 = (uint8_t *) utmp + (k / 8) * 32; + uint8_t * scales_1 = (uint8_t *) utmp + (k / 8) * 32 + 16; + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi1 = 0; + sumi2 = 0; + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF); + const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4); + sumi1 = (v0 * a_ptr[l].qs[(k / 8) * 256 + (k % 8) * 4 * blocklen + m * blocklen + i]); + sumi2 = (v1 * a_ptr[l].qs[(k / 8) * 256 + (k % 8) * 4 * blocklen + m * blocklen + i + 128]); + sumi1 = sumi1 * scales_0[j]; + sumi2 = sumi2 * scales_1[j]; + sumi += sumi1 + sumi2; + } + sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d[m]; + } + } + } + for (int sb = 0; sb < 8; sb++) { + uint8_t * mins = (uint8_t *) utmp + 8 + sb * 16; + for(int m = 0; m < 4; m++) { + const int16_t * bsums = a_ptr[l].bsums + (sb * 8) + (m * 4) - ((sb % 2) * 6); + for(int j = 
0; j < ncols_interleaved; j++) { + sum_minf[m][j] += mins[j] * (bsums[0] + bsums[1]) * GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d[m]; + } + } + } + } + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j] - sum_minf[m][j]; + } + } + } + } +} + void ggml_gemm_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK_K; const int nb = n / qk; @@ -1228,9 +1440,10 @@ static int repack_q4_0_to_q4_0_4_bl(struct ggml_tensor * t, int interleave_block GGML_UNUSED(data_size); } + static int repack_q4_K_to_q4_K_8_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) { GGML_ASSERT(t->type == GGML_TYPE_Q4_K); - GGML_ASSERT(interleave_block == 8); + GGML_ASSERT(interleave_block == 8 || interleave_block == 4); constexpr int nrows_interleaved = 8; block_q4_Kx8 * dst = (block_q4_Kx8*)t->data; @@ -1468,6 +1681,10 @@ template <> int repack(struct ggml_tensor * t, const void * da return repack_q4_K_to_q4_K_8_bl(t, 8, data, data_size); } +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_q4_K_to_q4_K_8_bl(t, 4, data, data_size); +} + template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { return repack_q2_K_to_q2_K_8_bl(t, 8, data, data_size); } @@ -1501,6 +1718,10 @@ template <> void gemv(int n, float * s, size_t ggml_gemv_q4_0_8x8_q8_0(n, s, bs, vx, vy, nr, nc); } +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemv_q4_K_8x4_q8_K(n, s, bs, vx, vy, nr, nc); +} + template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemv_q4_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc); } @@ -1529,6 +1750,10 @@ template <> void gemm(int n, float * s, size_t ggml_gemm_q4_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc); } +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_q4_K_8x4_q8_K(n, s, bs, vx, vy, nr, nc); +} + template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemm_q4_0_8x8_q8_0(n, s, bs, vx, vy, nr, nc); } @@ -1931,6 +2156,9 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons static const ggml::cpu::repack::tensor_traits q4_0_4x4_q8_0; static const ggml::cpu::repack::tensor_traits q4_0_4x8_q8_0; static const ggml::cpu::repack::tensor_traits q4_0_8x8_q8_0; + + // instance for Q4_K + static const ggml::cpu::repack::tensor_traits q4_K_8x4_q8_K; static const ggml::cpu::repack::tensor_traits q4_K_8x8_q8_K; // instance for Q2 @@ -1967,6 +2195,11 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons return &q4_K_8x8_q8_K; } } + if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) { + if (cur->ne[1] % 8 == 0) { + return &q4_K_8x4_q8_K; + } + } } else if (cur->type == GGML_TYPE_Q2_K) { if (ggml_cpu_has_avx512()) { if (cur->ne[1] % 8 == 0) { diff --git a/ggml/src/ggml-cpu/repack.h b/ggml/src/ggml-cpu/repack.h index cb32b503d3..c4d928cd15 100644 --- a/ggml/src/ggml-cpu/repack.h +++ b/ggml/src/ggml-cpu/repack.h @@ -80,10 +80,12 @@ extern "C" { void ggml_quantize_mat_q8_0_4x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); void ggml_quantize_mat_q8_0_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, 
int64_t k); +void ggml_quantize_mat_q8_K_4x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); void ggml_quantize_mat_q8_K_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); void ggml_gemv_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); @@ -91,6 +93,7 @@ void ggml_gemv_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); @@ -99,10 +102,12 @@ void ggml_gemm_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const // Native implementations void ggml_quantize_mat_q8_0_4x4_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); void ggml_quantize_mat_q8_0_4x8_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); +void ggml_quantize_mat_q8_K_4x4_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); void ggml_quantize_mat_q8_K_4x8_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); void ggml_gemv_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q4_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q4_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const 
void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); @@ -110,6 +115,7 @@ void ggml_gemv_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs void ggml_gemm_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q4_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); From 909072abcfed4798f86b14c0a79df057a9e6ab47 Mon Sep 17 00:00:00 2001 From: matt23654 <193348153+matt23654@users.noreply.github.com> Date: Thu, 27 Nov 2025 11:35:35 +0000 Subject: [PATCH 10/33] cuda : fix UMA detection on discrete GPUs. 
(#17537) --- ggml/src/ggml-cuda/ggml-cuda.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 0b29074f33..aa6570765a 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -3837,7 +3837,7 @@ static void ggml_backend_cuda_device_get_memory(ggml_backend_dev_t dev, size_t * // Check if UMA is explicitly enabled via environment variable bool uma_env = getenv("GGML_CUDA_ENABLE_UNIFIED_MEMORY") != nullptr; - bool is_uma = prop.unifiedAddressing > 0 || uma_env; + bool is_uma = prop.integrated > 0 || uma_env; if (is_uma) { // For UMA systems (like DGX Spark), use system memory info From 6783b11fb0889d68d0046176b4cc92ceee1961b0 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 27 Nov 2025 16:04:29 +0200 Subject: [PATCH 11/33] models : fix LFM2 tensors (#17548) --- src/llama-arch.cpp | 6 +++--- src/llama-model.cpp | 7 ++++--- src/models/lfm2.cpp | 8 +++++--- 3 files changed, 12 insertions(+), 9 deletions(-) diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index 7ef87acf1b..6da9e0a371 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -2237,7 +2237,7 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_SHORTCONV_INPROJ, "blk.%d.shortconv.in_proj" }, { LLM_TENSOR_SHORTCONV_OUTPROJ, "blk.%d.shortconv.out_proj" }, { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" }, + { LLM_TENSOR_OUTPUT_NORM, "token_embd_norm" }, // note: wrong tensor name { LLM_TENSOR_OUTPUT, "output" }, } }, @@ -2259,7 +2259,7 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_SHORTCONV_INPROJ, "blk.%d.shortconv.in_proj" }, { LLM_TENSOR_SHORTCONV_OUTPROJ, "blk.%d.shortconv.out_proj" }, { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" }, + { LLM_TENSOR_OUTPUT_NORM, "token_embd_norm" }, // note: wrong tensor name { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, @@ -2490,8 +2490,8 @@ static const std::map> LLM_TENSOR_N static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_TOKEN_EMBD, {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}}, {LLM_TENSOR_POS_EMBD, {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}}, - {LLM_TENSOR_TOKEN_EMBD_NORM, {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}}, {LLM_TENSOR_TOKEN_TYPES, {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}}, + {LLM_TENSOR_TOKEN_EMBD_NORM, {LLM_TENSOR_LAYER_INPUT, GGML_OP_MUL}}, {LLM_TENSOR_OUTPUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, {LLM_TENSOR_CLS, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, {LLM_TENSOR_CLS_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, diff --git a/src/llama-model.cpp b/src/llama-model.cpp index a042ea9632..cba875f114 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -6133,9 +6133,10 @@ bool llama_model::load_tensors(llama_model_loader & ml) { case LLM_ARCH_LFM2: case LLM_ARCH_LFM2MOE: { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); if (output == NULL) 
{ output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); diff --git a/src/models/lfm2.cpp b/src/models/lfm2.cpp index ca06bacd7b..7f805d7879 100644 --- a/src/models/lfm2.cpp +++ b/src/models/lfm2.cpp @@ -9,6 +9,8 @@ llm_build_lfm2::llm_build_lfm2(const llama_model & model, const llm_graph_params ggml_tensor * cur = build_inp_embd(model.tok_embd); cb(cur, "model.embed_tokens", -1); + ggml_build_forward_expand(gf, cur); + ggml_tensor * inp_pos = build_inp_pos(); auto * inp_hybrid = build_inp_mem_hybrid(); ggml_tensor * inp_out_ids = build_inp_out_ids(); @@ -40,12 +42,12 @@ llm_build_lfm2::llm_build_lfm2(const llama_model & model, const llm_graph_params cur = ggml_add(ctx0, cur, ffn_out); } - cur = build_norm(cur, model.tok_norm, NULL, LLM_NORM_RMS, -1); - cb(cur, "model.embedding_norm", -1); + cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1); + cb(cur, "result_norm", -1); res->t_embd = cur; cur = build_lora_mm(model.output, cur); - cb(cur, "lm_head", -1); + cb(cur, "result_output", -1); res->t_logits = cur; From c386114922dcc54a041ed99f855defda3fa7f225 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 27 Nov 2025 16:34:13 +0200 Subject: [PATCH 12/33] arch : add description about LLM_TENSOR_INFOS (#17550) --- src/llama-arch.cpp | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index 6da9e0a371..f6e26245ef 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -2487,6 +2487,16 @@ static const std::map> LLM_TENSOR_N }, }; +// declare information about the model weight tensors: +// - the layer in which the tensor is going to be used. this is needed in order to assign the correct buffer type for the weight +// - the operator which is going to use the weight. 
this is needed to determine if the respective backend supports the operator +// +// for example, input layers are usually assigned to CPU/host buffer types +// +// a mismatch between the declared information and the actual layer/op in which the tensor is used can lead to sub-optimal +// assignment of the buffer types and extra overhead during computation +// example: https://github.com/ggml-org/llama.cpp/pull/17548 +// static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_TOKEN_EMBD, {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}}, {LLM_TENSOR_POS_EMBD, {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}}, From 4abef75f2cf2eee75eb5083b30a94cf981587394 Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Thu, 27 Nov 2025 08:48:00 -0600 Subject: [PATCH 13/33] vulkan: Implement SOLVE_TRI (#17486) * vulkan: Implement SOLVE_TRI * load B matrix through shared memory * use FLOAT_TYPE --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 93 +++++++++++++++++++ .../ggml-vulkan/vulkan-shaders/solve_tri.comp | 72 ++++++++++++++ .../vulkan-shaders/vulkan-shaders-gen.cpp | 2 + 3 files changed, 167 insertions(+) create mode 100644 ggml/src/ggml-vulkan/vulkan-shaders/solve_tri.comp diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index b4ab85292f..8b50ebd259 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -399,6 +399,18 @@ struct vk_conv2d_pipeline_state { } }; +struct vk_solve_tri_pipeline_state { + vk_solve_tri_pipeline_state(uint32_t N, uint32_t K) + : N(N), K(K) {} + + uint32_t N, K; + + bool operator<(const vk_solve_tri_pipeline_state &b) const { + return std::tie(N, K) < + std::tie(b.N, b.K); + } +}; + enum shader_reduction_mode { SHADER_REDUCTION_MODE_SHMEM, SHADER_REDUCTION_MODE_HYBRID, @@ -711,6 +723,7 @@ struct vk_device_struct { vk_pipeline pipeline_cumsum_f32; vk_pipeline pipeline_argmax_f32; vk_pipeline pipeline_count_equal_i32; + std::map pipeline_solve_tri_f32; vk_pipeline pipeline_im2col_f32, pipeline_im2col_f32_f16; vk_pipeline pipeline_im2col_3d_f32, pipeline_im2col_3d_f32_f16; vk_pipeline pipeline_timestep_embedding_f32; @@ -4002,6 +4015,14 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_count_equal_i32, "count_equal_i32", count_equal_i32_len, count_equal_i32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, { device->subgroup_size }, 1); + for (auto &s : device->pipeline_solve_tri_f32) { + const vk_solve_tri_pipeline_state &state = s.first; + ggml_vk_create_pipeline( + device, s.second, "solve_tri_f32", + solve_tri_f32_len, solve_tri_f32_data, "main", 3, + sizeof(vk_op_binary_push_constants), {1, 1, 1}, { 0, state.N, state.K }, 1, true); + } + #define IM2COL(bda) \ ggml_vk_create_pipeline(device, device->pipeline_im2col_f32, "im2col_f32", im2col_f32 ## bda ## _len, im2col_f32 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); \ ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32, "im2col_3d_f32", im2col_3d_f32 ## bda ## _len, im2col_3d_f32 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true); \ @@ -8496,6 +8517,26 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return ctx->device->pipeline_cumsum_f32; } return nullptr; + case GGML_OP_SOLVE_TRI: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + + vk_solve_tri_pipeline_state solve_tri_pipeline_state(src0->ne[0], src1->ne[0]); + + 
vk_pipeline pipeline = nullptr; + + { + std::lock_guard guard(ctx->device->mutex); + auto it = ctx->device->pipeline_solve_tri_f32.find(solve_tri_pipeline_state); + if (it != ctx->device->pipeline_solve_tri_f32.end()) { + pipeline = it->second; + } else { + ctx->device->pipeline_solve_tri_f32[solve_tri_pipeline_state] = pipeline = std::make_shared(); + } + } + + return pipeline; + } + return nullptr; case GGML_OP_ARGMAX: if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_I32) { return ctx->device->pipeline_argmax_f32; @@ -8832,6 +8873,18 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co elements = { nr, 1, 1 }; } } break; + case GGML_OP_SOLVE_TRI: + { + uint32_t nr = (uint32_t)(ne02 * ne03); + if (nr > 262144) { + elements = { 512, 512, CEIL_DIV(nr, 262144) }; + } else if (nr > 512) { + elements = { 512, CEIL_DIV(nr, 512), 1 }; + } else { + elements = { nr, 1, 1 }; + } + } + break; case GGML_OP_RMS_NORM: if (ctx->do_add_rms_partials) { // Run one element per thread, 128 threads per workgroup @@ -10260,6 +10313,21 @@ static void ggml_vk_count_equal(ggml_backend_vk_context * ctx, vk_context& subct ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_COUNT_EQUAL, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }); } +static void ggml_vk_solve_tri(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + const uint32_t src0_type_size = ggml_type_size(src0->type); + const uint32_t src1_type_size = ggml_type_size(src1->type); + const uint32_t dst_type_size = ggml_type_size(dst->type); + + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_SOLVE_TRI, { + (uint32_t)ggml_nelements(src0), + (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, + (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size, + (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, + 0, + 0.0f, 0.0f, 0, + }); +} + static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { const int32_t s0 = dst->op_params[0]; const int32_t s1 = dst->op_params[1]; @@ -11871,6 +11939,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr case GGML_OP_COUNT_EQUAL: ggml_vk_count_equal(ctx, compute_ctx, src0, src1, node); + break; + case GGML_OP_SOLVE_TRI: + ggml_vk_solve_tri(ctx, compute_ctx, src0, src1, node); + break; case GGML_OP_IM2COL: ggml_vk_im2col(ctx, compute_ctx, src0, src1, node); @@ -13916,6 +13988,25 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm } return false; } + case GGML_OP_SOLVE_TRI: + { + ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; + const vk_device& device = ggml_vk_get_device(ctx->device); + + if (op->type != GGML_TYPE_F32 || op->src[0]->type != GGML_TYPE_F32) { + return false; + } + const uint32_t N = 
op->src[0]->ne[0]; + const uint32_t K = op->src[1]->ne[0]; + // K dimension limited to workgroup size + if (K > 128) { + return false; + } + if (N * N * sizeof(float) + N * K * sizeof(float) > device->properties.limits.maxComputeSharedMemorySize) { + return false; + } + return true; + } case GGML_OP_ARGMAX: return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32; case GGML_OP_COUNT_EQUAL: @@ -14588,6 +14679,8 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * tensor_clone = ggml_argmax(ggml_ctx, src_clone[0]); } else if (tensor->op == GGML_OP_COUNT_EQUAL) { tensor_clone = ggml_count_equal(ggml_ctx, src_clone[0], src_clone[1]); + } else if (tensor->op == GGML_OP_SOLVE_TRI) { + tensor_clone = ggml_solve_tri(ggml_ctx, src_clone[0], src_clone[1], true, true, false); } else if (tensor->op == GGML_OP_IM2COL) { const int32_t s0 = tensor->op_params[0]; const int32_t s1 = tensor->op_params[1]; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/solve_tri.comp b/ggml/src/ggml-vulkan/vulkan-shaders/solve_tri.comp new file mode 100644 index 0000000000..253a9e7efe --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/solve_tri.comp @@ -0,0 +1,72 @@ +#version 450 + +#include "types.glsl" +#include "generic_binary_head.glsl" + +layout (constant_id = 1) const uint N = 64; +layout (constant_id = 2) const uint K = 32; + +layout(local_size_x = 128, local_size_y = 1, local_size_z = 1) in; + +uint a_base, b_base, x_base; + +FLOAT_TYPE get_a(uint r, uint c) { + return FLOAT_TYPE(data_a[a_base + r * p.nb01 + c * p.nb00]); +} + +FLOAT_TYPE get_b(uint r, uint c) { + return FLOAT_TYPE(data_b[b_base + r * p.nb11 + c * p.nb10]); +} + +void store_x(uint r, uint c, FLOAT_TYPE v) { + data_d[x_base + r * p.nb21 + c * p.nb20] = D_TYPE(v); +} + +shared FLOAT_TYPE shA[N * N]; +shared FLOAT_TYPE shB[N * K]; + +void main() { + const uint batch = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x; + const uint tid = gl_LocalInvocationID.x; + + if (batch >= p.ne02 * p.ne03) { + return; + } + + const uint i3 = batch / p.ne22; + const uint i2 = batch % p.ne22; + a_base = get_aoffset() + i2 * p.nb02 + i3 * p.nb03; + b_base = get_boffset() + i2 * p.nb12 + i3 * p.nb13; + x_base = get_doffset() + i2 * p.nb22 + i3 * p.nb23; + + // Load the A matrix into shA + [[unroll]] for (uint i = 0; i < N * N; i += gl_WorkGroupSize.x) { + uint idx = i + tid; + if (((N * N) % gl_WorkGroupSize.x == 0) || idx < N * N) { + shA[idx] = get_a(idx / N, idx % N); + } + } + // Load the B matrix into shB + [[unroll]] for (uint i = 0; i < N * K; i += gl_WorkGroupSize.x) { + uint idx = i + tid; + if (((N * K) % gl_WorkGroupSize.x == 0) || idx < N * K) { + shB[idx] = get_b(idx / K, idx % K); + } + } + barrier(); + + FLOAT_TYPE X[N]; + // Each thread solves one column + if (tid < K) { + [[unroll]] for (int r = 0; r < N; ++r) { + FLOAT_TYPE b = shB[r * K + tid]; + // Compute x[r,c] = (b[r,c] - sum(a[r,c]*x[c])) / a[r,r] + [[unroll]] for (int c = 0; c < r; ++c) { + b -= shA[r * N + c] * X[c]; + } + FLOAT_TYPE x = b / shA[r * N + r]; + X[r] = x; + store_x(r, tid, x); + } + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index 4a802ab1c2..d695baa4a7 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -944,6 +944,8 @@ void process_shaders() { string_to_spv("opt_step_adamw_f32", "opt_step_adamw.comp", merge_maps(base_dict, 
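The shader above performs per-column forward substitution for the lower-triangular system A·X = B, i.e. x[r] = (b[r] − Σ_{c<r} A[r,c]·x[c]) / A[r,r] computed row by row. A plain C++ reference of the same recurrence (one batch, no shared-memory staging):

#include <cstddef>
#include <vector>

// Reference forward substitution for one batch: A is N x N lower-triangular (row-major),
// B and X are N x K (row-major); each column of X is solved independently, like one
// shader invocation per column above.
static void solve_tri_lower(const std::vector<float> & A, const std::vector<float> & B,
                            std::vector<float> & X, size_t N, size_t K) {
    X.assign(N * K, 0.0f);
    for (size_t c = 0; c < K; ++c) {
        for (size_t r = 0; r < N; ++r) {
            float b = B[r * K + c];
            for (size_t i = 0; i < r; ++i) {
                b -= A[r * N + i] * X[i * K + c];   // subtract the already-solved rows
            }
            X[r * K + c] = b / A[r * N + r];        // divide by the diagonal
        }
    }
}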
{{"A_TYPE", "float"}})); string_to_spv("opt_step_sgd_f32", "opt_step_sgd.comp", merge_maps(base_dict, {{"A_TYPE", "float"}})); + string_to_spv("solve_tri_f32", "solve_tri.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}})); + for (auto transpose : {false, true}) { for (auto unroll : {false, true}) { for (auto a_f16 : {false, true}) { From efaaccdd69cd9db777584c2a062f70c0526a6fb5 Mon Sep 17 00:00:00 2001 From: Neo Zhang Jianyu Date: Fri, 28 Nov 2025 08:50:56 +0800 Subject: [PATCH 14/33] refactor pad_reflect_1d to make the UT case pass (#17204) Co-authored-by: Zhang Jianyu --- ggml/src/ggml-sycl/common.hpp | 26 ++++++ ggml/src/ggml-sycl/pad_reflect_1d.cpp | 130 ++++++++++++++++---------- ggml/src/ggml-sycl/pad_reflect_1d.hpp | 2 + 3 files changed, 107 insertions(+), 51 deletions(-) diff --git a/ggml/src/ggml-sycl/common.hpp b/ggml/src/ggml-sycl/common.hpp index 338fa08cda..637630c1d2 100644 --- a/ggml/src/ggml-sycl/common.hpp +++ b/ggml/src/ggml-sycl/common.hpp @@ -617,4 +617,30 @@ static __dpct_inline__ float get_alibi_slope(const float max_bias, return dpct::pow(base, exph); } +static const sycl::uint3 init_fastdiv_values(uint32_t d) { + GGML_ASSERT(d != 0); + + uint32_t L = 0; + while (L < 32 && (uint32_t{ 1 } << L) < d) { + L++; + } + + uint32_t mp = (uint32_t) ((uint64_t{ 1 } << 32) * ((uint64_t{ 1 } << L) - d) / d + 1); + return sycl::uint3(mp, L, d); +} + + +static __dpct_inline__ uint32_t fastdiv(uint32_t n, const sycl::uint3 fastdiv_values) { + const uint32_t hi = sycl::mul_hi(n, fastdiv_values.x()); + return (hi + n) >> fastdiv_values.y(); +} + + +static __dpct_inline__ sycl::uint2 fast_div_modulo(uint32_t n, const sycl::uint3 fastdiv_values) { + const uint32_t div_val = fastdiv(n, fastdiv_values); + const uint32_t mod_val = n - div_val * fastdiv_values.z(); + return sycl::uint2(div_val, mod_val); +} + + #endif // GGML_SYCL_COMMON_HPP diff --git a/ggml/src/ggml-sycl/pad_reflect_1d.cpp b/ggml/src/ggml-sycl/pad_reflect_1d.cpp index e56655a98a..85e993628c 100644 --- a/ggml/src/ggml-sycl/pad_reflect_1d.cpp +++ b/ggml/src/ggml-sycl/pad_reflect_1d.cpp @@ -1,72 +1,100 @@ #include "pad_reflect_1d.hpp" -void pad_reflect_1d_f32(const float* src,float* dst, - const int64_t ne0, const int64_t ne02, const int p0, const int p1, - const int64_t nb0, const int64_t nb1, const int64_t nb2, const int64_t nb3, - const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03, - const sycl::nd_item<3> &item_ct1){ +static void pad_reflect_1d_kernel_f32( + const void *__restrict__ src0, void *__restrict__ dst, const int64_t ne0, + const int64_t ne00, const sycl::uint3 ne01, const int64_t ne02, + const int64_t ne03, const int64_t nb00, const int64_t nb01, + const int64_t nb02, const int64_t nb03, const int64_t nb0, + const int64_t nb1, const int64_t nb2, const int64_t nb3, const int p0, + const int p1, sycl::nd_item<3> item_ct1) { - const int i0 = item_ct1.get_group(0) * SYCL_CONCAT_BLOCK_SIZE + item_ct1.get_local_id(0); - const int i1 = item_ct1.get_group(1); - const int g2 = item_ct1.get_group(2); - const int i2 = g2 % ne02; - const int i3 = g2 / ne02; + const int64_t i3 = item_ct1.get_group(0); + const int64_t i2 = item_ct1.get_group(1); - if (i0 >= p0 + ne0 + p1) return; + const sycl::uint2 div_mod_packed = + fast_div_modulo(item_ct1.get_group(2), ne01); + const int64_t tile1 = div_mod_packed.y(); + const int64_t tile0 = div_mod_packed.x(); + const int64_t i1 = tile1; + const int64_t i0 = + item_ct1.get_local_id(2) + tile0 * 
item_ct1.get_local_range(2); - int t = i0 - p0; - int period = 2 * ne0 -2; - int m = t % period; - m += (m < 0) * period; - int center = ne0 -1; - int srci0 = center - abs(center - m); + if (i0 >= ne0 || i1 >= ne01.z() || i2 >= ne02 || i3 >= ne03) { + return; + } - int offest_src = i3*nb3 + i2*nb2 + i1*nb1 + srci0*nb0; - int offest_dst = i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00; - dst[offest_dst] = src[offest_src]; + const char *src0_ptr = + (const char *)src0 + i3 * nb03 + i2 * nb02 + i1 * nb01; + char *dst_ptr = (char *)dst + i3 * nb3 + i2 * nb2 + i1 * nb1; + const int64_t rel_i0 = i0 - p0; // relative i0 in src0 + int64_t src_idx; + + if (rel_i0 < 0) { + // Left padding - reflect + src_idx = -rel_i0; + } else if (rel_i0 < ne00) { + // Middle - copy + src_idx = rel_i0; + } else { + // Right padding - reflect + src_idx = 2 * ne00 - 2 - rel_i0; + } + const float value = *(const float *)(src0_ptr + src_idx * nb00); + *(float *)(dst_ptr + i0 * nb0) = value; + + GGML_UNUSED(p1); } -void ggml_sycl_op_pad_reflect_1d(ggml_backend_sycl_context& ctx, ggml_tensor* dst){ +void ggml_sycl_op_pad_reflect_1d(ggml_backend_sycl_context &ctx, + ggml_tensor *dst) { - const ggml_tensor * src0 = dst->src[0]; - queue_ptr stream = ctx.stream(); + const ggml_tensor *src0 = dst->src[0]; + dpct::queue_ptr stream = ctx.stream(); GGML_ASSERT(src0->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); - const int32_t * opts = (const int32_t *) dst->op_params; + const int32_t *opts = (const int32_t *)dst->op_params; const int p0 = opts[0]; const int p1 = opts[1]; - const int64_t ne0 = src0->ne[0]; + const int64_t ne00 = src0->ne[0]; + const int64_t ne01 = src0->ne[1]; + const sycl::uint3 ne01_packed = init_fastdiv_values(ne01); + const int64_t ne02 = src0->ne[2]; + const int64_t ne03 = src0->ne[3]; - const int64_t ne00 = dst->ne[0]; - const int64_t ne01 = dst->ne[1]; - const int64_t ne02 = dst->ne[2]; - const int64_t ne03 = dst->ne[3]; + const int64_t ne0 = dst->ne[0]; - const int64_t nb00 = dst->nb[0]; - const int64_t nb01 = dst->nb[1]; - const int64_t nb02 = dst->nb[2]; - const int64_t nb03 = dst->nb[3]; - const int64_t nb0 = src0->nb[0]; - const int64_t nb1 = src0->nb[1]; - const int64_t nb2 = src0->nb[2]; - const int64_t nb3 = src0->nb[3]; + GGML_ASSERT(ne0 == ne00 + p0 + p1); - int num_blocks = (ne00 + SYCL_CONCAT_BLOCK_SIZE - 1) / SYCL_CONCAT_BLOCK_SIZE; - sycl::range<3> global(num_blocks * SYCL_CONCAT_BLOCK_SIZE, ne01, ne02*ne03); - sycl::range<3> local(SYCL_CONCAT_BLOCK_SIZE, 1, 1); + constexpr int64_t bx = SYCL_PAD_REFLECT_1D_BLOCK_SIZE; + const int64_t tiles0 = (ne0 + bx - 1) / bx; + const dpct::dim3 grid_dims((unsigned)(ne01 * tiles0), (unsigned)ne02, + (unsigned)ne03); + const dpct::dim3 block_dims((unsigned)bx, 1, 1); - stream->parallel_for( - sycl::nd_range<3>(global, - local), - [=](sycl::nd_item<3> item_ct1) { pad_reflect_1d_f32( - (const float *) src0->data, (float *) dst->data, - ne0, ne02, p0, p1, - nb0, nb1, nb2, nb3, - nb00, nb01, nb02, nb03 - , item_ct1); - }); + stream->submit([&](sycl::handler &cgh) { + auto src0_data_ct0 = src0->data; + auto dst_data_ct1 = dst->data; + auto src0_nb_ct7 = src0->nb[0]; + auto src0_nb_ct8 = src0->nb[1]; + auto src0_nb_ct9 = src0->nb[2]; + auto src0_nb_ct10 = src0->nb[3]; + auto dst_nb_ct11 = dst->nb[0]; + auto dst_nb_ct12 = dst->nb[1]; + auto dst_nb_ct13 = dst->nb[2]; + auto dst_nb_ct14 = dst->nb[3]; + + cgh.parallel_for(sycl::nd_range<3>(grid_dims * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { 
+ pad_reflect_1d_kernel_f32( + src0_data_ct0, dst_data_ct1, ne0, ne00, + ne01_packed, ne02, ne03, src0_nb_ct7, + src0_nb_ct8, src0_nb_ct9, src0_nb_ct10, + dst_nb_ct11, dst_nb_ct12, dst_nb_ct13, + dst_nb_ct14, p0, p1, item_ct1); + }); + }); } diff --git a/ggml/src/ggml-sycl/pad_reflect_1d.hpp b/ggml/src/ggml-sycl/pad_reflect_1d.hpp index a24509dea6..45aaf9a911 100644 --- a/ggml/src/ggml-sycl/pad_reflect_1d.hpp +++ b/ggml/src/ggml-sycl/pad_reflect_1d.hpp @@ -3,6 +3,8 @@ #include "common.hpp" +#define SYCL_PAD_REFLECT_1D_BLOCK_SIZE 256 + void ggml_sycl_op_pad_reflect_1d(ggml_backend_sycl_context& ctx, ggml_tensor* dst); #endif // GGML_SYCL_PAD_REFLECT_1D_HPP From cd0e3a7a3b93b5c2c2aeceeb778b931125b1125d Mon Sep 17 00:00:00 2001 From: "Piotr Wilkin (ilintar)" Date: Fri, 28 Nov 2025 05:15:32 +0100 Subject: [PATCH 15/33] SOLVE_TRI CUDA kernel for small matrices (#17457) --- ggml/src/ggml-cuda/ggml-cuda.cu | 6 + ggml/src/ggml-cuda/solve_tri.cu | 203 +++++++++++++++++++++++++++++++ ggml/src/ggml-cuda/solve_tri.cuh | 3 + tests/test-backend-ops.cpp | 3 + 4 files changed, 215 insertions(+) create mode 100644 ggml/src/ggml-cuda/solve_tri.cu create mode 100644 ggml/src/ggml-cuda/solve_tri.cuh diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index aa6570765a..6463921a6e 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -53,6 +53,7 @@ #include "ggml-cuda/set.cuh" #include "ggml-cuda/set-rows.cuh" #include "ggml-cuda/pad_reflect_1d.cuh" +#include "ggml-cuda/solve_tri.cuh" #include "ggml.h" #include @@ -2717,6 +2718,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_OP_OPT_STEP_SGD: ggml_cuda_opt_step_sgd(ctx, dst); break; + case GGML_OP_SOLVE_TRI: + ggml_cuda_op_solve_tri(ctx, dst); + break; default: return false; } @@ -4255,6 +4259,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_OPT_STEP_ADAMW: case GGML_OP_OPT_STEP_SGD: return true; + case GGML_OP_SOLVE_TRI: + return op->src[0]->ne[0] <= 64 && op->src[1]->ne[0] <= 32; default: return false; } diff --git a/ggml/src/ggml-cuda/solve_tri.cu b/ggml/src/ggml-cuda/solve_tri.cu new file mode 100644 index 0000000000..2e2b39720f --- /dev/null +++ b/ggml/src/ggml-cuda/solve_tri.cu @@ -0,0 +1,203 @@ +#include "common.cuh" +#include "ggml.h" +#include "solve_tri.cuh" + +#define MAX_N_FAST 64 +#define MAX_K_FAST 32 + +// ====================== +// Fast Kernel (n <= 64, k <= 32) - Warp-based parallel reduction +// ====================== +// When ncols_template == 0 the bounds for the loops in this function are not +// known and can't be unrolled. As we want to keep pragma unroll for all other +// cases we supress the clang transformation warning here. +#ifdef __clang__ +# pragma clang diagnostic push +# pragma clang diagnostic ignored "-Wpass-failed" +#endif // __clang__ +template +static __global__ void solve_tri_f32_fast(const float * __restrict__ A, + const float * __restrict__ B, + float * __restrict__ X, + const uint3 ne02, + const size_t nb02, + const size_t nb03, + const size_t nb12, + const size_t nb13, + const size_t nb2, + const size_t nb3, + const int n_arg, + const int k_arg) { + const int n = n_template == 0 ? n_arg : n_template; + const int k = k_template == 0 ? 
k_arg : k_template; + + const int batch_idx = blockIdx.x; + const int lane = threadIdx.x; + const int col_idx = threadIdx.y; + + if (col_idx >= k) { + return; + } + + const uint2 i02_i03 = fast_div_modulo(batch_idx, ne02); + const int64_t i02 = i02_i03.y; + const int64_t i03 = i02_i03.x; + + const float * const A_batch = (const float *) (A + i02 * nb02 + i03 * nb03); + const float * const B_batch = (const float *) (B + i02 * nb12 + i03 * nb13); + float * X_batch = (float *) (X + i02 * nb2 + i03 * nb3); + + __shared__ float sA[MAX_N_FAST * MAX_N_FAST]; + __shared__ float sXt[MAX_N_FAST * (MAX_K_FAST + 1)]; + + const int offset = threadIdx.x + threadIdx.y * blockDim.x; + +#pragma unroll + for (int i = 0; i < n * n; i += k * WARP_SIZE) { + int i0 = i + offset; + if (i0 < n * n) { + sA[i0] = A_batch[i0]; + } + } + + const int rows_per_warp = (n + WARP_SIZE - 1) / WARP_SIZE; + +#pragma unroll + for (int i = 0; i < rows_per_warp; i++) { + const int i0 = lane + i * WARP_SIZE; + if (i0 < n) { + sXt[col_idx * n + i0] = B_batch[i0 * k + col_idx]; + } + } + + __syncthreads(); + +#pragma unroll + for (int row = 0; row < n; ++row) { + float sum = 0.0f; + + { + int j = lane; + if (j < row) { + sum += sA[row * n + j] * sXt[col_idx * n + j]; + } + } + if (row >= WARP_SIZE) { + int j = WARP_SIZE + lane; + if (j < row) { + sum += sA[row * n + j] * sXt[col_idx * n + j]; + } + } + + sum = warp_reduce_sum(sum); + + if (lane == 0) { + const float b_val = sXt[col_idx * n + row]; + const float a_diag = sA[row * n + row]; + // no safeguards for division by zero because that indicates corrupt + // data anyway + sXt[col_idx * n + row] = (b_val - sum) / a_diag; + } + } + + __syncthreads(); + +#pragma unroll + for (int i = 0; i < rows_per_warp; i++) { + const int i0 = lane + i * WARP_SIZE; + if (i0 < n) { + X_batch[i0 * k + col_idx] = sXt[col_idx * n + i0]; + } + } +} +#ifdef __clang__ +# pragma clang diagnostic pop +#endif // __clang__ + +static void solve_tri_f32_cuda(const float * A, + const float * B, + float * X, + int n, + int k, + int64_t ne02, + int64_t ne03, + size_t nb02, + size_t nb03, + size_t nb12, + size_t nb13, + size_t nb2, + size_t nb3, + cudaStream_t stream) { + const uint3 ne02_fd = init_fastdiv_values((uint32_t) ne02); + dim3 threads(WARP_SIZE, k); + dim3 grid(ne02 * ne03); + if (n == 64) { + switch (k) { + case 32: + solve_tri_f32_fast<64, 32> + <<>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0); + break; + case 16: + solve_tri_f32_fast<64, 16> + <<>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0); + break; + case 14: + solve_tri_f32_fast<64, 14> + <<>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0); + break; + case 12: + solve_tri_f32_fast<64, 12> + <<>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0); + break; + case 10: + solve_tri_f32_fast<64, 10> + <<>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0); + break; + case 8: + solve_tri_f32_fast<64, 8> + <<>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0); + break; + case 6: + solve_tri_f32_fast<64, 6> + <<>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0); + break; + case 4: + solve_tri_f32_fast<64, 4> + <<>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0); + break; + case 2: + solve_tri_f32_fast<64, 2> + <<>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0); + break; + case 1: + solve_tri_f32_fast<64, 1> + <<>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0); + break; + default: + solve_tri_f32_fast<0, 0> + <<>>(A, B, X, ne02_fd, nb02, nb03, 
nb12, nb13, nb2, nb3, n, k); + } + } else { // run general case + solve_tri_f32_fast<0, 0> + <<>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, n, k); + } +} + +void ggml_cuda_op_solve_tri(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; // A (triangular n x x matrix) + const ggml_tensor * src1 = dst->src[1]; // B (right hand side of n x k equation columns) + + ggml_is_contiguous(src0); + ggml_is_contiguous(src1); + + const int64_t n = src0->ne[0]; + const int64_t k = src1->ne[0]; + + GGML_ASSERT(n <= 64); + GGML_ASSERT(k <= 32); + + solve_tri_f32_cuda((const float *) src0->data, (const float *) src1->data, (float *) dst->data, n, k, src0->ne[2], + src0->ne[3], src0->nb[2] / sizeof(float), src0->nb[3] / sizeof(float), + src1->nb[2] / sizeof(float), src1->nb[3] / sizeof(float), dst->nb[2] / sizeof(float), + dst->nb[3] / sizeof(float), ctx.stream()); +} diff --git a/ggml/src/ggml-cuda/solve_tri.cuh b/ggml/src/ggml-cuda/solve_tri.cuh new file mode 100644 index 0000000000..639992396a --- /dev/null +++ b/ggml/src/ggml-cuda/solve_tri.cuh @@ -0,0 +1,3 @@ +#include "common.cuh" + +void ggml_cuda_op_solve_tri(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index d7ac2bc178..60bab47b9f 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -7935,6 +7935,9 @@ static std::vector> make_test_cases_perf() { test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 16416, 1, 128, {8, 1}, {4, 1}, {0, 2, 1, 3})); test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 128, 1, 16416, {8, 1}, {4, 1}, {0, 1, 2, 3}, 2*16416)); + test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 64, 64, 4, 2 }, { 6, 64, 4, 2 })); + test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 128, 128, 4, 1 }, { 8, 128, 4, 1 })); + for (int bs : {1, 2, 3, 4, 5, 8, 512}) { for (ggml_type type_a : all_types) { for (ggml_type type_b : {GGML_TYPE_F32}) { From 6bca76ff5ea5c7efe9c62e60852d88d350403d58 Mon Sep 17 00:00:00 2001 From: yulo <77381088+zhang-hui-yulo@users.noreply.github.com> Date: Fri, 28 Nov 2025 15:24:30 +0800 Subject: [PATCH 16/33] HIP: enable mul_mat_f for RDNA4 (#17437) * enable mmf for rdna4 * move some mmvf to mmf * revert lds128 for wmma loading * Revert "revert lds128 for wmma loading" This reverts commit db9ae8b6b4738a5def5b393caa1611d52133e9b5. * Revert "enable mmf for rdna4" This reverts commit 698c9f24187b990e35c3b73a8067e5387e6ddbd4. * Revert "move some mmvf to mmf" This reverts commit 99b92bd6653cc8593607f641e44606391691792f. * enable mul_mat for rdna4 --------- Co-authored-by: zhang hui --- ggml/src/ggml-cuda/mmf.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/ggml-cuda/mmf.cu b/ggml/src/ggml-cuda/mmf.cu index 5c51a22256..be2ad1c6b6 100644 --- a/ggml/src/ggml-cuda/mmf.cu +++ b/ggml/src/ggml-cuda/mmf.cu @@ -151,7 +151,7 @@ bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const return false; } } else { - if (src1_ncols > 16 || GGML_CUDA_CC_IS_RDNA4(cc)) { + if (src1_ncols > 16) { return false; } } From 15d2b46b4dec7747ca50c2f950f6f482ea0d6198 Mon Sep 17 00:00:00 2001 From: Radoslav Gerganov Date: Fri, 28 Nov 2025 10:33:51 +0200 Subject: [PATCH 17/33] rpc : cache and reuse compute graphs (#15405) Store the last computed graph and reuse it when possible. Also do not return response from GRAPH_COMPUTE and assume it always completes successfully. 
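Roughly, the client-side reuse check is a byte-wise comparison of the new
graph's node list against a cached copy of the last graph that was sent.
A minimal sketch (simplified from the graph_cache helper added in this
patch; the helper name below is illustrative only, and <vector>/<cstring>
are assumed to be available):

    // illustrative helper, not part of the patch: mirrors graph_cache::is_cached()
    static bool graph_unchanged(const std::vector<ggml_tensor> & last, const ggml_cgraph * g) {
        if ((int) last.size() != g->n_nodes) {
            return false;                     // node count changed -> full GRAPH_COMPUTE
        }
        for (int i = 0; i < g->n_nodes; i++) {
            if (memcmp(&last[i], g->nodes[i], sizeof(ggml_tensor)) != 0) {
                return false;                 // any node differs -> full GRAPH_COMPUTE
            }
        }
        return true;                          // identical graph -> cheap GRAPH_RECOMPUTE
    }

When the check succeeds the client only sends the small GRAPH_RECOMPUTE
request; otherwise it serializes the graph, updates the cache and sends
GRAPH_COMPUTE as before. Either way, the server assumes the graph
computation succeeds.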
If this this is not the case, the server closes the connection. This saves us a network round trip to the server. --- ggml/include/ggml-rpc.h | 2 +- ggml/src/ggml-rpc/ggml-rpc.cpp | 111 +++++++++++++++++++++++++++------ 2 files changed, 92 insertions(+), 21 deletions(-) diff --git a/ggml/include/ggml-rpc.h b/ggml/include/ggml-rpc.h index e6dca3f62b..832c26c61d 100644 --- a/ggml/include/ggml-rpc.h +++ b/ggml/include/ggml-rpc.h @@ -8,7 +8,7 @@ extern "C" { #endif #define RPC_PROTO_MAJOR_VERSION 3 -#define RPC_PROTO_MINOR_VERSION 0 +#define RPC_PROTO_MINOR_VERSION 5 #define RPC_PROTO_PATCH_VERSION 0 #define GGML_RPC_MAX_SERVERS 16 diff --git a/ggml/src/ggml-rpc/ggml-rpc.cpp b/ggml/src/ggml-rpc/ggml-rpc.cpp index a38df5a97e..48fd99a762 100644 --- a/ggml/src/ggml-rpc/ggml-rpc.cpp +++ b/ggml/src/ggml-rpc/ggml-rpc.cpp @@ -106,6 +106,7 @@ enum rpc_cmd { RPC_CMD_GET_ALLOC_SIZE, RPC_CMD_HELLO, RPC_CMD_DEVICE_COUNT, + RPC_CMD_GRAPH_RECOMPUTE, RPC_CMD_COUNT, }; @@ -205,10 +206,6 @@ struct rpc_msg_copy_tensor_rsp { uint8_t result; }; -struct rpc_msg_graph_compute_rsp { - uint8_t result; -}; - struct rpc_msg_get_device_memory_req { uint32_t device; }; @@ -217,6 +214,11 @@ struct rpc_msg_get_device_memory_rsp { uint64_t free_mem; uint64_t total_mem; }; + +struct rpc_msg_graph_recompute_req { + uint32_t device; +}; + #pragma pack(pop) // RPC data structures @@ -234,10 +236,35 @@ struct ggml_backend_rpc_buffer_type_context { size_t max_size; }; +struct graph_cache { + + bool is_cached(const ggml_cgraph * cgraph) { + if ((int)last_graph.size() != cgraph->n_nodes) { + return false; + } + for (int i = 0; i < cgraph->n_nodes; i++) { + if (memcmp(&last_graph[i], cgraph->nodes[i], sizeof(ggml_tensor)) != 0) { + return false; + } + } + return true; + } + + void add(const ggml_cgraph * cgraph) { + last_graph.resize(cgraph->n_nodes); + for (int i = 0; i < cgraph->n_nodes; i++) { + memcpy(&last_graph[i], cgraph->nodes[i], sizeof(ggml_tensor)); + } + } + + std::vector last_graph; +}; + struct ggml_backend_rpc_context { std::string endpoint; uint32_t device; std::string name; + graph_cache gc; }; struct ggml_backend_rpc_buffer_context { @@ -815,13 +842,24 @@ static void serialize_graph(uint32_t device, const ggml_cgraph * cgraph, std::ve static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) { ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context; - std::vector input; - serialize_graph(rpc_ctx->device, cgraph, input); - rpc_msg_graph_compute_rsp response; - auto sock = get_socket(rpc_ctx->endpoint); - bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_COMPUTE, input.data(), input.size(), &response, sizeof(response)); - RPC_STATUS_ASSERT(status); - return (enum ggml_status)response.result; + + GGML_ASSERT(cgraph->n_nodes > 0); + bool reuse = rpc_ctx->gc.is_cached(cgraph); + if (reuse) { + rpc_msg_graph_recompute_req request; + request.device = rpc_ctx->device; + auto sock = get_socket(rpc_ctx->endpoint); + bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_RECOMPUTE, &request, sizeof(request)); + RPC_STATUS_ASSERT(status); + } else { + rpc_ctx->gc.add(cgraph); + std::vector input; + serialize_graph(rpc_ctx->device, cgraph, input); + auto sock = get_socket(rpc_ctx->endpoint); + bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_COMPUTE, input.data(), input.size()); + RPC_STATUS_ASSERT(status); + } + return GGML_STATUS_SUCCESS; } static ggml_backend_i ggml_backend_rpc_interface = { @@ -880,7 +918,8 @@ ggml_backend_t ggml_backend_rpc_init(const char * endpoint, 
uint32_t device) { ggml_backend_rpc_context * ctx = new ggml_backend_rpc_context { /* .endpoint = */ endpoint, /* .device = */ device, - /* .name = */ dev_name + /* .name = */ dev_name, + /* .gc = */ {}, }; auto reg = ggml_backend_rpc_add_server(endpoint); ggml_backend_t backend = new ggml_backend { @@ -920,8 +959,9 @@ void ggml_backend_rpc_get_device_memory(const char * endpoint, uint32_t device, class rpc_server { public: - rpc_server(std::vector backends, const char * cache_dir) - : backends(std::move(backends)), cache_dir(cache_dir) { + rpc_server(std::vector all_backends, const char * cache_dir) + : backends(std::move(all_backends)), cache_dir(cache_dir) { + stored_graphs.resize(backends.size()); } ~rpc_server(); @@ -936,11 +976,17 @@ public: bool set_tensor_hash(const rpc_msg_set_tensor_hash_req & request, rpc_msg_set_tensor_hash_rsp & response); bool get_tensor(const rpc_msg_get_tensor_req & request, std::vector & response); bool copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_copy_tensor_rsp & response); - bool graph_compute(const std::vector & input, rpc_msg_graph_compute_rsp & response); + bool graph_compute(const std::vector & input); + bool graph_recompute(const rpc_msg_graph_recompute_req & request); bool init_tensor(const rpc_msg_init_tensor_req & request); bool get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response); bool get_device_memory(const rpc_msg_get_device_memory_req & request, rpc_msg_get_device_memory_rsp & response); + struct stored_graph { + ggml_context_ptr ctx_ptr; + ggml_cgraph * graph; + }; + private: bool get_cached_file(uint64_t hash, std::vector & data); ggml_tensor * deserialize_tensor(struct ggml_context * ctx, const rpc_tensor * tensor); @@ -953,6 +999,8 @@ private: std::vector backends; const char * cache_dir; std::unordered_set buffers; + // store the last computed graph for each backend + std::vector stored_graphs; }; void rpc_server::hello(rpc_msg_hello_rsp & response) { @@ -1394,7 +1442,7 @@ ggml_tensor * rpc_server::create_node(uint64_t id, return result; } -bool rpc_server::graph_compute(const std::vector & input, rpc_msg_graph_compute_rsp & response) { +bool rpc_server::graph_compute(const std::vector & input) { // serialization format: // | device (4 bytes) | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) | if (input.size() < 2*sizeof(uint32_t)) { @@ -1455,7 +1503,24 @@ bool rpc_server::graph_compute(const std::vector & input, rpc_msg_graph } } ggml_status status = ggml_backend_graph_compute(backends[device], graph); - response.result = status; + GGML_ASSERT(status == GGML_STATUS_SUCCESS && "Unsuccessful graph computations are not supported with RPC"); + stored_graphs[device].ctx_ptr.swap(ctx_ptr); + stored_graphs[device].graph = graph; + return true; +} + +bool rpc_server::graph_recompute(const rpc_msg_graph_recompute_req & request) { + uint32_t device = request.device; + if (device >= backends.size()) { + return false; + } + if (stored_graphs[device].graph == nullptr) { + return false; + } + ggml_cgraph * graph = stored_graphs[device].graph; + LOG_DBG("[%s] device: %u\n", __func__, device); + ggml_status status = ggml_backend_graph_compute(backends[device], graph); + GGML_ASSERT(status == GGML_STATUS_SUCCESS && "Unsuccessful graph computations are not supported with RPC"); return true; } @@ -1690,11 +1755,17 @@ static void rpc_serve_client(const std::vector & backends, const if (!recv_msg(sockfd, input)) { return; } - 
rpc_msg_graph_compute_rsp response; - if (!server.graph_compute(input, response)) { + if (!server.graph_compute(input)) { return; } - if (!send_msg(sockfd, &response, sizeof(response))) { + break; + } + case RPC_CMD_GRAPH_RECOMPUTE: { + rpc_msg_graph_recompute_req request; + if (!recv_msg(sockfd, &request, sizeof(request))) { + return; + } + if (!server.graph_recompute(request)) { return; } break; From 35cf8887e119eb9b9f090349129e1e71a9eb608b Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Fri, 28 Nov 2025 03:07:29 -0600 Subject: [PATCH 18/33] vulkan: Implement GGML_OP_TRI (#17503) * vulkan: Implement GGML_OP_TRI * check types match --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 28 +++++++++++- ggml/src/ggml-vulkan/vulkan-shaders/tri.comp | 43 +++++++++++++++++++ .../vulkan-shaders/vulkan-shaders-gen.cpp | 3 ++ 3 files changed, 73 insertions(+), 1 deletion(-) create mode 100644 ggml/src/ggml-vulkan/vulkan-shaders/tri.comp diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 8b50ebd259..73562bc1be 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -649,6 +649,7 @@ struct vk_device_struct { vk_pipeline pipeline_sin_f32; vk_pipeline pipeline_cos_f32; vk_pipeline pipeline_log[2]; + vk_pipeline pipeline_tri[2]; vk_pipeline pipeline_clamp_f32; vk_pipeline pipeline_pad_f32; vk_pipeline pipeline_roll_f32; @@ -3876,6 +3877,9 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_log[1], "log_f16", log_f16_len, log_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); } + ggml_vk_create_pipeline(device, device->pipeline_tri[0], "tri_f32", tri_f32_len, tri_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_tri[1], "tri_f16", tri_f16_len, tri_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_clamp_f32, "clamp_f32", clamp_f32_len, clamp_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_pad_f32, "pad_f32", pad_f32_len, pad_f32_data, "main", 2, sizeof(vk_op_pad_push_constants), {512, 1, 1}, {}, 1); @@ -8290,6 +8294,12 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return ctx->device->pipeline_log[dst->type == GGML_TYPE_F16]; } return nullptr; + case GGML_OP_TRI: + if (src0->type == dst->type && + (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) { + return ctx->device->pipeline_tri[dst->type == GGML_TYPE_F16]; + } + return nullptr; case GGML_OP_CLAMP: if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { return ctx->device->pipeline_clamp_f32; @@ -8991,6 +9001,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co case GGML_OP_SIN: case GGML_OP_COS: case GGML_OP_LOG: + case GGML_OP_TRI: case GGML_OP_CLAMP: case GGML_OP_PAD: case GGML_OP_ROLL: @@ -9671,6 +9682,13 @@ static void ggml_vk_log(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_LOG, vk_op_unary_push_constants_init(src0, dst)); } +static void ggml_vk_tri(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { + vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst); + p.param1 = ggml_get_op_params_f32(dst, 0); + + ggml_vk_op_f32(ctx, 
subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_TRI, std::move(p)); +} + static void ggml_vk_clamp(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst); p.param1 = ggml_get_op_params_f32(dst, 0); @@ -11794,6 +11812,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr case GGML_OP_LOG: ggml_vk_log(ctx, compute_ctx, src0, node); + break; + case GGML_OP_TRI: + ggml_vk_tri(ctx, compute_ctx, src0, node); + break; case GGML_OP_CLAMP: ggml_vk_clamp(ctx, compute_ctx, src0, node); @@ -13919,7 +13941,9 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_OP_OPT_STEP_SGD: return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32; case GGML_OP_LOG: - return op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16; + case GGML_OP_TRI: + return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) && + op->type == op->src[0]->type; case GGML_OP_ARGSORT: { if (!ggml_is_contiguous(op) || !ggml_is_contiguous(op->src[0])) { @@ -14510,6 +14534,8 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * tensor_clone = ggml_cos(ggml_ctx, src_clone[0]); } else if (tensor->op == GGML_OP_LOG) { tensor_clone = ggml_log(ggml_ctx, src_clone[0]); + } else if (tensor->op == GGML_OP_TRI) { + tensor_clone = ggml_tri(ggml_ctx, src_clone[0], ggml_get_op_params_i32(tensor, 0)); } else if (tensor->op == GGML_OP_CLAMP) { const float * params = (const float *)tensor->op_params; tensor_clone = ggml_clamp(ggml_ctx, src_clone[0], params[0], params[1]); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp b/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp new file mode 100644 index 0000000000..e18d0ffa30 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp @@ -0,0 +1,43 @@ +#version 450 + +#include "rte.glsl" +#include "types.glsl" +#include "generic_unary_head.glsl" + +#define GGML_TRI_TYPE_UPPER_DIAG 0 +#define GGML_TRI_TYPE_UPPER 1 +#define GGML_TRI_TYPE_LOWER_DIAG 2 +#define GGML_TRI_TYPE_LOWER 3 + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +void main() { + const uint idx = get_idx(); + + if (idx >= p.ne) { + return; + } + + const uint i03 = fastdiv(idx, p.ne0_012mp, p.ne0_012L); + const uint i03_offset = i03 * p.ne02*p.ne01*p.ne00; + const uint i02 = fastdiv(idx - i03_offset, p.ne0_01mp, p.ne0_01L); + const uint i02_offset = i02*p.ne01*p.ne00; + const uint i01 = fastdiv(idx - i03_offset - i02_offset, p.ne0_0mp, p.ne0_0L); + const uint i00 = idx - i03_offset - i02_offset - i01*p.ne00; + + int param = floatBitsToInt(p.param1); + bool pass = false; + switch (param) { + case GGML_TRI_TYPE_UPPER_DIAG: pass = i00 >= i01; break; + case GGML_TRI_TYPE_UPPER: pass = i00 > i01; break; + case GGML_TRI_TYPE_LOWER_DIAG: pass = i00 <= i01; break; + case GGML_TRI_TYPE_LOWER: pass = i00 < i01; break; + } + + if (pass) { + const float val = float(data_a[get_aoffset() + src0_idx(idx)]); + data_d[get_doffset() + dst_idx(idx)] = D_TYPE(val); + } else { + data_d[get_doffset() + dst_idx(idx)] = D_TYPE(0); + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index d695baa4a7..214a743b97 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -846,6 +846,9 @@ void 
process_shaders() { string_to_spv("abs_f16", "abs.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); string_to_spv("abs_f32", "abs.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("tri_f16", "tri.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("tri_f32", "tri.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("softplus_f16", "softplus.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); string_to_spv("softplus_f32", "softplus.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); From 73955f7d2a3ce1f36d7ecc14495e08957b51d113 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Fri, 28 Nov 2025 10:29:09 +0100 Subject: [PATCH 19/33] CUDA: no FP16 arithmetic for vector FA kernel (#17558) --- ggml/src/ggml-cuda/common.cuh | 8 ++++++-- ggml/src/ggml-cuda/fattn-common.cuh | 4 ++-- ggml/src/ggml-cuda/fattn-vec.cuh | 32 ++++++++++++++--------------- 3 files changed, 24 insertions(+), 20 deletions(-) diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 99ec96869a..81fb312409 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -558,8 +558,12 @@ static __device__ __forceinline__ void ggml_cuda_mad(float & acc, const float2 v acc += v.y*u.y; } -static __device__ __forceinline__ void ggml_cuda_mad(float & acc, const half2 v, const half2 u) { #if defined(GGML_USE_HIP) && (defined(RDNA2) || defined(RDNA3) || defined(RDNA4) || defined(__gfx906__) || defined(CDNA)) +#define V_DOT2_F32_F16_AVAILABLE +#endif // defined(GGML_USE_HIP) && (defined(RDNA2) || defined(RDNA3) || defined(RDNA4) || defined(__gfx906__) || defined(CDNA)) + +static __device__ __forceinline__ void ggml_cuda_mad(float & acc, const half2 v, const half2 u) { +#ifdef V_DOT2_F32_F16_AVAILABLE asm volatile("v_dot2_f32_f16 %0, %1, %2, %0" : "+v"(acc) : "v"(v), "v"(u)); #else #ifdef FAST_FP16_AVAILABLE @@ -571,7 +575,7 @@ static __device__ __forceinline__ void ggml_cuda_mad(float & acc, const half2 v, acc += tmpv.x * tmpu.x; acc += tmpv.y * tmpu.y; #endif // FAST_FP16_AVAILABLE -#endif // defined(GGML_USE_HIP) && (defined(RDNA2) || defined(RDNA3) || defined(RDNA4) || defined(GCN5) || defined(CDNA)) +#endif // V_DOT2_F32_F16_AVAILABLE } static __device__ __forceinline__ void ggml_cuda_mad(half2 & acc, const half2 v, const half2 u) { diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh index 218ccff14e..5cdd4bb211 100644 --- a/ggml/src/ggml-cuda/fattn-common.cuh +++ b/ggml/src/ggml-cuda/fattn-common.cuh @@ -55,11 +55,11 @@ static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_f16( ggml_cuda_memcpy_1(tmp, K_h2 + k_KQ_0 + (threadIdx.x % nthreads)*cpy_ne); #pragma unroll for (int k_KQ_1 = 0; k_KQ_1 < cpy_ne; ++k_KQ_1) { -#ifdef FAST_FP16_AVAILABLE +#ifdef V_DOT2_F32_F16_AVAILABLE ggml_cuda_mad(sum, tmp[k_KQ_1] , ((const half2 *) Q_v)[k_KQ_0/nthreads + k_KQ_1]); #else ggml_cuda_mad(sum, __half22float2(tmp[k_KQ_1]), ((const float2 *) Q_v)[k_KQ_0/nthreads + k_KQ_1]); -#endif // FP16_AVAILABLE +#endif // V_DOT2_F32_F16_AVAILABLE } } diff --git a/ggml/src/ggml-cuda/fattn-vec.cuh b/ggml/src/ggml-cuda/fattn-vec.cuh index e1838fdded..6e63e860ac 100644 --- a/ggml/src/ggml-cuda/fattn-vec.cuh +++ b/ggml/src/ggml-cuda/fattn-vec.cuh @@ -86,11 +86,11 @@ static __global__ void flash_attn_ext_vec( constexpr vec_dot_KQ_t vec_dot_KQ = get_vec_dot_KQ(); constexpr bool Q_q8_1 = type_K != GGML_TYPE_F16; -#ifdef FAST_FP16_AVAILABLE +#ifdef V_DOT2_F32_F16_AVAILABLE constexpr 
dequantize_V_t dequantize_V = get_dequantize_V(); #else constexpr dequantize_V_t dequantize_V = get_dequantize_V(); -#endif // FAST_FP16_AVAILABLE +#endif // V_DOT2_F32_F16_AVAILABLE const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on. @@ -112,13 +112,13 @@ static __global__ void flash_attn_ext_vec( constexpr int ne_KQ = ncols*D; constexpr int ne_combine = nwarps*V_cols_per_iter*D; -#ifdef FAST_FP16_AVAILABLE +#ifdef V_DOT2_F32_F16_AVAILABLE half2 VKQ[ncols][(D/2)/nthreads_V] = {{{0.0f, 0.0f}}}; __shared__ half KQ[ne_KQ > ne_combine ? ne_KQ : ne_combine]; #else float2 VKQ[ncols][(D/2)/nthreads_V] = {{{0.0f, 0.0f}}}; __shared__ float KQ[ne_KQ > ne_combine ? ne_KQ : ne_combine]; -#endif // FAST_FP16_AVAILABLE +#endif // V_DOT2_F32_F16_AVAILABLE float KQ_max[ncols]; float KQ_sum[ncols]; @@ -129,11 +129,11 @@ static __global__ void flash_attn_ext_vec( } // Convert Q to float2 (f16 K) or q8_1 (quantized K) and store in registers: -#ifdef FAST_FP16_AVAILABLE +#ifdef V_DOT2_F32_F16_AVAILABLE half2 Q_reg[ncols][(D/2)/nthreads_KQ]; // Will be initialized completely. #else float2 Q_reg[ncols][(D/2)/nthreads_KQ] = {{{0.0f, 0.0f}}}; // May be only partially initialized. -#endif // FAST_FP16_AVAILABLE +#endif // V_DOT2_F32_F16_AVAILABLE int Q_i32[ncols][1 > D/(sizeof(int)*nthreads_KQ) ? 1 : D/(sizeof(int)*nthreads_KQ)]; float2 Q_ds[ncols][1 > D/(sizeof(int)*nthreads_KQ) ? 1 : D/(sizeof(int)*nthreads_KQ)]; if constexpr (Q_q8_1) { @@ -191,7 +191,7 @@ static __global__ void flash_attn_ext_vec( __syncthreads(); } else { -#ifdef FAST_FP16_AVAILABLE +#ifdef V_DOT2_F32_F16_AVAILABLE const half2 scale_h2 = make_half2(scale, scale); #pragma unroll for (int j = 0; j < ncols; ++j) { @@ -233,7 +233,7 @@ static __global__ void flash_attn_ext_vec( Q_reg[j][k].y *= scale; } } -#endif // FAST_FP16_AVAILABLE +#endif // V_DOT2_F32_F16_AVAILABLE } const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11; @@ -291,7 +291,7 @@ static __global__ void flash_attn_ext_vec( KQ_sum[j] = KQ_sum[j]*KQ_max_scale + KQ_reg[j]; KQ[j*nthreads + tid] = KQ_reg[j]; -#ifdef FAST_FP16_AVAILABLE +#ifdef V_DOT2_F32_F16_AVAILABLE const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale); #pragma unroll for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) { @@ -303,7 +303,7 @@ static __global__ void flash_attn_ext_vec( VKQ[j][i_VKQ_0/nthreads_V].x *= KQ_max_scale; VKQ[j][i_VKQ_0/nthreads_V].y *= KQ_max_scale; } -#endif // FAST_FP16_AVAILABLE +#endif // V_DOT2_F32_F16_AVAILABLE } #ifndef GGML_USE_HIP @@ -314,7 +314,7 @@ static __global__ void flash_attn_ext_vec( for (int k0 = 0; k0 < WARP_SIZE; k0 += V_cols_per_iter) { const int k = threadIdx.y*WARP_SIZE + k0 + (nthreads_V == WARP_SIZE ? 0 : threadIdx.x / nthreads_V); -#ifdef FAST_FP16_AVAILABLE +#ifdef V_DOT2_F32_F16_AVAILABLE half2 KQ_k[ncols]; #pragma unroll for (int j = 0; j < ncols; ++j) { @@ -353,7 +353,7 @@ static __global__ void flash_attn_ext_vec( } } } -#endif // FAST_FP16_AVAILABLE +#endif // V_DOT2_F32_F16_AVAILABLE } } @@ -374,7 +374,7 @@ static __global__ void flash_attn_ext_vec( KQ_sum[j] = KQ_sum[j]*KQ_max_scale + (threadIdx.x == 0 ? 
expf(sink - KQ_max[j]) : 0.0f); -#ifdef FAST_FP16_AVAILABLE +#ifdef V_DOT2_F32_F16_AVAILABLE const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale); #pragma unroll for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) { @@ -386,7 +386,7 @@ static __global__ void flash_attn_ext_vec( VKQ[j][i_VKQ_0/nthreads_V].x *= KQ_max_scale; VKQ[j][i_VKQ_0/nthreads_V].y *= KQ_max_scale; } -#endif // FAST_FP16_AVAILABLE +#endif // V_DOT2_F32_F16_AVAILABLE } } @@ -421,7 +421,7 @@ static __global__ void flash_attn_ext_vec( const float kqmax_scale = expf(KQ_max[j_VKQ] - kqmax_new); KQ_max[j_VKQ] = kqmax_new; -#ifdef FAST_FP16_AVAILABLE +#ifdef V_DOT2_F32_F16_AVAILABLE half2 * VKQ_tmp = (half2 *) KQ + threadIdx.y*(V_cols_per_iter*D/2) + (nthreads_V == WARP_SIZE ? 0 : threadIdx.x / nthreads_V)*(D/2); @@ -452,7 +452,7 @@ static __global__ void flash_attn_ext_vec( ggml_cuda_memcpy_1(VKQ_tmp + i_VKQ, &VKQ[j_VKQ][i_VKQ_0/nthreads_V]); ggml_cuda_memcpy_1(VKQ_tmp + i_VKQ + V_rows_per_thread/4, &VKQ[j_VKQ][i_VKQ_0/nthreads_V + V_rows_per_thread/4]); } -#endif // FAST_FP16_AVAILABLE +#endif // V_DOT2_F32_F16_AVAILABLE KQ_sum[j_VKQ] *= kqmax_scale; KQ_sum[j_VKQ] = warp_reduce_sum(KQ_sum[j_VKQ]); From ff55414c42522adbeaa1bd9c52c0e9db16942484 Mon Sep 17 00:00:00 2001 From: "Piotr Wilkin (ilintar)" Date: Fri, 28 Nov 2025 12:02:56 +0100 Subject: [PATCH 20/33] model : Qwen3 Next (#16095) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Qwen3 Next - cleaned up version * Whitespaces and stuff * Correct minor errors * Update src/llama-model.cpp Co-authored-by: Sigbjørn Skjæret * Misc. fixes. * Clean up code, add missing hybrid qualifier * Did someone transpose the SOLVE_TRI result matrix? Perhaps... * Whitespace * Proper tensors for cb calls * Use llama-graph.h vertical alignment * BROKEN: chunking * Set new tensors as inputs. * Proper chunk logic * It's the circle of life... * More shenanigans for n_seq > 1 * Nail in the coffin? * Fix Windows build * Eh, one fails on Windows, the other fails on Mac... just use general capture. 
* quant : cleanup * model : cleanup * qwen3 : cleanup * cont : cleanup * cont : cleanup * ggml : revert change * qwen3 : cleanup * cont : cleanup * Readd cmath * qwen3 : fix typo * Update convert_hf_to_gguf.py Co-authored-by: Sigbjørn Skjæret * Usual suspects * fix my bad suggestion --------- Co-authored-by: Sigbjørn Skjæret Co-authored-by: Georgi Gerganov --- convert_hf_to_gguf.py | 30 + .../scripts/causal/run-converted-model.sh | 8 +- .../scripts/causal/run-org-model.py | 8 +- ggml/src/ggml-cpu/ops.cpp | 3 +- gguf-py/gguf/constants.py | 33 + gguf-py/gguf/tensor_mapping.py | 22 +- src/CMakeLists.txt | 1 + src/llama-arch.cpp | 35 + src/llama-arch.h | 2 + src/llama-context.cpp | 4 + src/llama-hparams.h | 2 +- src/llama-model.cpp | 100 +- src/llama-model.h | 4 + src/llama-quant.cpp | 18 +- src/models/models.h | 52 +- src/models/qwen3next.cpp | 1042 +++++++++++++++++ 16 files changed, 1345 insertions(+), 19 deletions(-) create mode 100644 src/models/qwen3next.cpp diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index daaf0bf497..866aa536f1 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -4183,6 +4183,36 @@ class Qwen3MoeModel(Qwen2MoeModel): super().set_vocab() +@ModelBase.register("Qwen3NextForCausalLM") +class Qwen3NextModel(Qwen2MoeModel): + model_arch = gguf.MODEL_ARCH.QWEN3NEXT + + def set_gguf_parameters(self): + super().set_gguf_parameters() + self.gguf_writer.add_ssm_conv_kernel(self.hparams["linear_conv_kernel_dim"]) + self.gguf_writer.add_ssm_state_size(self.hparams["linear_key_head_dim"]) + self.gguf_writer.add_ssm_group_count(self.hparams["linear_num_key_heads"]) + self.gguf_writer.add_ssm_time_step_rank(self.hparams["linear_num_value_heads"]) + self.gguf_writer.add_ssm_inner_size(self.hparams["linear_value_head_dim"] * self.hparams["linear_num_value_heads"]) + if (rope_dim := self.hparams.get("head_dim")) is None: + rope_dim = self.hparams["hidden_size"] // self.hparams["num_attention_heads"] + self.gguf_writer.add_rope_dimension_count(int(rope_dim * self.hparams.get("partial_rotary_factor", 0.25))) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + if name.startswith("mtp"): + return [] # ignore MTP layers for now + if name.endswith(".A_log"): + data_torch = -torch.exp(data_torch) + elif name.endswith(".dt_bias"): + name = name.rpartition(".dt_bias")[0] + ".dt_proj.bias" + elif "conv1d" in name: + data_torch = data_torch.squeeze() + elif name.endswith("norm.weight") and not name.endswith("linear_attn.norm.weight"): + data_torch = data_torch + 1 + + yield from super().modify_tensors(data_torch, name, bid) + + @ModelBase.register("RND1") class RND1Model(Qwen2MoeModel): model_arch = gguf.MODEL_ARCH.RND1 diff --git a/examples/model-conversion/scripts/causal/run-converted-model.sh b/examples/model-conversion/scripts/causal/run-converted-model.sh index f5f567d4ff..529e9987b0 100755 --- a/examples/model-conversion/scripts/causal/run-converted-model.sh +++ b/examples/model-conversion/scripts/causal/run-converted-model.sh @@ -4,6 +4,11 @@ set -e # First try command line argument, then environment variable, then file CONVERTED_MODEL="${1:-"$CONVERTED_MODEL"}" +MODEL_TESTING_PROMPT="${2:-"$MODEL_TESTING_PROMPT"}" + +if [ -z "$MODEL_TESTING_PROMPT"]; then + MODEL_TESTING_PROMPT="Hello, my name is" +fi # Final check if we have a model path if [ -z "$CONVERTED_MODEL" ]; then @@ -14,7 +19,8 @@ if [ -z "$CONVERTED_MODEL" ]; then fi echo $CONVERTED_MODEL +echo $MODEL_TESTING_PROMPT cmake --build ../../build 
--target llama-logits -j8 -../../build/bin/llama-logits -m "$CONVERTED_MODEL" "Hello, my name is" +../../build/bin/llama-logits -m "$CONVERTED_MODEL" "$MODEL_TESTING_PROMPT" diff --git a/examples/model-conversion/scripts/causal/run-org-model.py b/examples/model-conversion/scripts/causal/run-org-model.py index 85529c612f..7d2b80057c 100755 --- a/examples/model-conversion/scripts/causal/run-org-model.py +++ b/examples/model-conversion/scripts/causal/run-org-model.py @@ -184,8 +184,12 @@ model_name = os.path.basename(model_path) # of using AutoModelForCausalLM. print(f"Model class: {model.__class__.__name__}") -prompt = "Hello, my name is" -input_ids = tokenizer(prompt, return_tensors="pt").input_ids +device = next(model.parameters()).device +if os.getenv("MODEL_TESTING_PROMPT"): + prompt = os.getenv("MODEL_TESTING_PROMPT") +else: + prompt = "Hello, my name is" +input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device) print(f"Input tokens: {input_ids}") print(f"Input text: {repr(prompt)}") diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index d405696539..2745fc54e1 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -9766,7 +9766,8 @@ static void ggml_compute_forward_solve_tri_f32(const struct ggml_compute_params } const float diag = A_batch[i00 * n + i00]; - GGML_ASSERT(diag != 0.0f && "Zero diagonal in triangular matrix"); + assert(diag != 0.0f && "Zero diagonal in triangular matrix"); + X_batch[i00 * k + i01] = (B_batch[i00 * k + i01] - sum) / diag; } } diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 6f5a742e04..266d19f9dd 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -366,6 +366,7 @@ class MODEL_ARCH(IntEnum): QWEN2VL = auto() QWEN3 = auto() QWEN3MOE = auto() + QWEN3NEXT = auto() QWEN3VL = auto() QWEN3VLMOE = auto() PHI2 = auto() @@ -531,6 +532,7 @@ class MODEL_TENSOR(IntEnum): SSM_D = auto() SSM_NORM = auto() SSM_OUT = auto() + SSM_BETA_ALPHA = auto() # qwen3next TIME_MIX_W0 = auto() TIME_MIX_W1 = auto() TIME_MIX_W2 = auto() @@ -736,6 +738,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { MODEL_ARCH.QWEN2VL: "qwen2vl", MODEL_ARCH.QWEN3: "qwen3", MODEL_ARCH.QWEN3MOE: "qwen3moe", + MODEL_ARCH.QWEN3NEXT: "qwen3next", MODEL_ARCH.QWEN3VL: "qwen3vl", MODEL_ARCH.QWEN3VLMOE: "qwen3vlmoe", MODEL_ARCH.PHI2: "phi2", @@ -900,6 +903,7 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = { MODEL_TENSOR.SSM_D: "blk.{bid}.ssm_d", MODEL_TENSOR.SSM_NORM: "blk.{bid}.ssm_norm", MODEL_TENSOR.SSM_OUT: "blk.{bid}.ssm_out", + MODEL_TENSOR.SSM_BETA_ALPHA: "blk.{bid}.ssm_ba", MODEL_TENSOR.TIME_MIX_W0: "blk.{bid}.time_mix_w0", MODEL_TENSOR.TIME_MIX_W1: "blk.{bid}.time_mix_w1", MODEL_TENSOR.TIME_MIX_W2: "blk.{bid}.time_mix_w2", @@ -1569,6 +1573,35 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.FFN_DOWN_EXP, MODEL_TENSOR.FFN_UP_EXP, ], + MODEL_ARCH.QWEN3NEXT: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_Q_NORM, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_K_NORM, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.ATTN_POST_NORM, + MODEL_TENSOR.ATTN_GATE, + MODEL_TENSOR.FFN_GATE_INP, + MODEL_TENSOR.FFN_GATE_INP_SHEXP, + MODEL_TENSOR.FFN_UP_SHEXP, + MODEL_TENSOR.FFN_DOWN_SHEXP, + MODEL_TENSOR.FFN_GATE_SHEXP, + MODEL_TENSOR.FFN_DOWN_EXP, + MODEL_TENSOR.FFN_UP_EXP, + MODEL_TENSOR.FFN_GATE_EXP, + MODEL_TENSOR.SSM_A, + MODEL_TENSOR.SSM_CONV1D, + MODEL_TENSOR.SSM_DT, + 
MODEL_TENSOR.SSM_NORM, + MODEL_TENSOR.SSM_IN, + MODEL_TENSOR.SSM_BETA_ALPHA, + MODEL_TENSOR.SSM_OUT + ], MODEL_ARCH.QWEN3VL: [ MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.OUTPUT_NORM, diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 8c7ed10f2e..a7b0973979 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -672,10 +672,11 @@ class TensorNameMap: ), MODEL_TENSOR.SSM_IN: ( - "model.layers.{bid}.in_proj", # mamba-hf - "backbone.layers.{bid}.mixer.in_proj", # mamba - "model.layers.{bid}.mamba.in_proj", # jamba falcon-h1 granite-hybrid - "model.layers.layers.{bid}.mixer.in_proj", # plamo2 + "model.layers.{bid}.in_proj", # mamba-hf + "backbone.layers.{bid}.mixer.in_proj", # mamba + "model.layers.{bid}.mamba.in_proj", # jamba falcon-h1 granite-hybrid + "model.layers.layers.{bid}.mixer.in_proj", # plamo2 + "model.layers.{bid}.linear_attn.in_proj_qkvz", # qwen3next ), MODEL_TENSOR.SSM_CONV1D: ( @@ -683,6 +684,7 @@ class TensorNameMap: "backbone.layers.{bid}.mixer.conv1d", # mamba "model.layers.{bid}.mamba.conv1d", # jamba falcon-h1 granite-hybrid "model.layers.layers.{bid}.mixer.conv1d", # plamo2 + "model.layers.{bid}.linear_attn.conv1d", # qwen3next ), MODEL_TENSOR.SSM_X: ( @@ -697,6 +699,7 @@ class TensorNameMap: "backbone.layers.{bid}.mixer.dt_proj", # mamba "model.layers.{bid}.mamba.dt_proj", # jamba falcon-h1 granite-hybrid "model.layers.layers.{bid}.mixer.dt_proj", # plamo2 + "model.layers.{bid}.linear_attn.dt_proj", # qwen3next ), MODEL_TENSOR.SSM_DT_NORM: ( @@ -709,6 +712,7 @@ class TensorNameMap: "backbone.layers.{bid}.mixer.A_log", # mamba "model.layers.{bid}.mamba.A_log", # jamba falcon-h1 granite-hybrid "model.layers.layers.{bid}.mixer.A_log", # plamo2 + "model.layers.{bid}.linear_attn.A_log", # qwen3next ), MODEL_TENSOR.SSM_B_NORM: ( @@ -731,17 +735,23 @@ class TensorNameMap: ), MODEL_TENSOR.SSM_NORM: ( - "model.layers.{bid}.mamba.norm", # falcon-h1 granite-hybrid - "backbone.layers.{bid}.mixer.norm", # mamba2 + "model.layers.{bid}.mamba.norm", # falcon-h1 granite-hybrid + "model.layers.{bid}.linear_attn.norm", # qwen3next + "backbone.layers.{bid}.mixer.norm", # mamba2 ), MODEL_TENSOR.SSM_OUT: ( "model.layers.{bid}.out_proj", # mamba-hf "backbone.layers.{bid}.mixer.out_proj", # mamba "model.layers.{bid}.mamba.out_proj", # jamba falcon-h1 granite-hybrid + "model.layers.{bid}.linear_attn.out_proj", # qwen3next "model.layers.layers.{bid}.mixer.out_proj", # plamo2 ), + MODEL_TENSOR.SSM_BETA_ALPHA: ( + "model.layers.{bid}.linear_attn.in_proj_ba", # qwen3next + ), + MODEL_TENSOR.TIME_MIX_W0: ( "model.layers.{bid}.attention.w0", # rwkv7 ), diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index f7a8c9841e..67c7807e09 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -114,6 +114,7 @@ add_library(llama models/qwen3vl.cpp models/qwen3vl-moe.cpp models/qwen3moe.cpp + models/qwen3next.cpp models/refact.cpp models/rnd1.cpp models/rwkv6-base.cpp diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index f6e26245ef..8571a2e025 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -32,6 +32,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_QWEN2VL, "qwen2vl" }, { LLM_ARCH_QWEN3, "qwen3" }, { LLM_ARCH_QWEN3MOE, "qwen3moe" }, + { LLM_ARCH_QWEN3NEXT, "qwen3next" }, { LLM_ARCH_QWEN3VL, "qwen3vl" }, { LLM_ARCH_QWEN3VLMOE, "qwen3vlmoe" }, { LLM_ARCH_PHI2, "phi2" }, @@ -829,6 +830,38 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, }, }, + { + LLM_ARCH_QWEN3NEXT, + { + { 
LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + { LLM_TENSOR_FFN_GATE_INP_SHEXP, "blk.%d.ffn_gate_inp_shexp" }, + { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" }, + { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" }, + { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" }, + { LLM_TENSOR_SSM_A, "blk.%d.ssm_a" }, + { LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" }, + { LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" }, + { LLM_TENSOR_SSM_BETA_ALPHA, "blk.%d.ssm_ba" }, + { LLM_TENSOR_SSM_IN, "blk.%d.ssm_in" }, + { LLM_TENSOR_SSM_NORM, "blk.%d.ssm_norm" }, + { LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" }, + }, + }, { LLM_ARCH_QWEN3VL, { @@ -2556,6 +2589,7 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_SSM_X, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_SSM_DT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_SSM_OUT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_SSM_BETA_ALPHA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_TIME_MIX_W1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_TIME_MIX_W2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_TIME_MIX_A1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, @@ -2754,6 +2788,7 @@ bool llm_arch_is_hybrid(const llm_arch & arch) { case LLM_ARCH_LFM2: case LLM_ARCH_LFM2MOE: case LLM_ARCH_NEMOTRON_H: + case LLM_ARCH_QWEN3NEXT: return true; default: return false; diff --git a/src/llama-arch.h b/src/llama-arch.h index 9ad3157bf6..150646478a 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -36,6 +36,7 @@ enum llm_arch { LLM_ARCH_QWEN2VL, LLM_ARCH_QWEN3, LLM_ARCH_QWEN3MOE, + LLM_ARCH_QWEN3NEXT, LLM_ARCH_QWEN3VL, LLM_ARCH_QWEN3VLMOE, LLM_ARCH_PHI2, @@ -381,6 +382,7 @@ enum llm_tensor { LLM_TENSOR_SSM_D, LLM_TENSOR_SSM_NORM, LLM_TENSOR_SSM_OUT, + LLM_TENSOR_SSM_BETA_ALPHA, // qwen3next LLM_TENSOR_TIME_MIX_W0, LLM_TENSOR_TIME_MIX_W1, LLM_TENSOR_TIME_MIX_W2, diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 2aa6d52a24..a58914429c 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -1,5 +1,6 @@ #include "llama-context.h" +#include "llama-arch.h" #include "llama-impl.h" #include "llama-batch.h" #include "llama-io.h" @@ -1386,6 +1387,9 @@ void llama_context::output_reorder() { // uint32_t llama_context::graph_max_nodes() const { + if (model.arch == LLM_ARCH_QWEN3NEXT) { + return std::max(8192u, 32u*model.n_tensors()); + } return std::max(1024u, 8u*model.n_tensors()); } diff --git a/src/llama-hparams.h b/src/llama-hparams.h index 9203af83b2..c3a53be793 100644 --- a/src/llama-hparams.h +++ b/src/llama-hparams.h @@ -6,7 +6,7 @@ // bump if necessary #define LLAMA_MAX_LAYERS 512 -#define LLAMA_MAX_EXPERTS 384 // Kimi-K2 +#define LLAMA_MAX_EXPERTS 512 // Qwen3 Next enum llama_expert_gating_func_type { LLAMA_EXPERT_GATING_FUNC_TYPE_NONE = 0, diff 
--git a/src/llama-model.cpp b/src/llama-model.cpp index cba875f114..c2a545531a 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -2,7 +2,6 @@ #include "llama-impl.h" #include "llama-mmap.h" -#include "llama-batch.h" #include "llama-cparams.h" #include "llama-model-loader.h" @@ -2225,6 +2224,29 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_QWEN3NEXT: + { + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); + ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + // Load linear attention (gated delta net) parameters + ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv); + ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner); + ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state); + ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); + ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); + + // Mark recurrent layers (linear attention layers) + for (uint32_t i = 0; i < hparams.n_layer; ++i) { + hparams.recurrent_layer_arr[i] = ((i + 1) % 4 != 0); // TODO: extract the magic 4 from "full_attention_interval" + } + + switch (hparams.n_layer) { + case 80: type = LLM_TYPE_80B_A3B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; default: throw std::runtime_error("unsupported model architecture"); } @@ -6415,6 +6437,74 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); } } break; + case LLM_ARCH_QWEN3NEXT: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, TENSOR_DUPLICATED); + } + + const int64_t n_ff_exp = hparams.n_ff_exp ? 
hparams.n_ff_exp : n_ff / n_expert_used; + + // Calculate dimensions from hyperparameters + const int64_t head_k_dim = hparams.ssm_d_state; + const int64_t head_v_dim = hparams.ssm_d_state; + const int64_t n_k_heads = hparams.ssm_n_group; + const int64_t n_v_heads = hparams.ssm_dt_rank; + const int64_t key_dim = head_k_dim * n_k_heads; + const int64_t value_dim = head_v_dim * n_v_heads; + const int64_t conv_dim = key_dim * 2 + value_dim; + + // Calculate projection sizes + const int64_t qkvz_dim = key_dim * 2 + value_dim * 2; + const int64_t ba_dim = n_v_heads * 2; + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), { n_embd }, 0); + + if (!hparams.is_recurrent(i)) { + // Attention layers + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head_k * n_head * 2 }, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_k_gqa }, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_v_gqa }, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0); + + // Q/K normalization for attention layers + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head_k }, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), { n_embd_head_k }, 0); + } else { + // Linear attention (gated delta net) specific tensors + // Create tensors with calculated dimensions + layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), { n_embd, qkvz_dim }, 0); + layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), { hparams.ssm_d_conv, conv_dim }, 0); + layer.ssm_dt = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), { hparams.ssm_dt_rank }, 0); + layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), { hparams.ssm_dt_rank }, 0); + layer.ssm_beta_alpha = create_tensor(tn(LLM_TENSOR_SSM_BETA_ALPHA, "weight", i), { n_embd, ba_dim }, 0); + layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), { head_v_dim }, 0); + layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), { value_dim, n_embd }, 0); + } + + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), { n_embd, n_expert }, 0); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff_exp, n_embd, n_expert }, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, 0); + + // Shared experts + layer.ffn_gate_inp_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP_SHEXP, "weight", i), { n_embd }, 0); + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), { n_embd, hparams.n_ff_shexp }, 0); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), { n_embd, hparams.n_ff_shexp }, 0); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { hparams.n_ff_shexp, n_embd }, 0); + } + } break; default: throw std::runtime_error("unknown architecture"); } @@ -6685,6 +6775,7 @@ void llama_model::print_info() const { arch == LLM_ARCH_FALCON_H1 || arch == LLM_ARCH_PLAMO2 || arch == LLM_ARCH_GRANITE_HYBRID || + arch == LLM_ARCH_QWEN3NEXT || arch == LLM_ARCH_NEMOTRON_H) 
{ LLAMA_LOG_INFO("%s: ssm_d_conv = %u\n", __func__, hparams.ssm_d_conv); LLAMA_LOG_INFO("%s: ssm_d_inner = %u\n", __func__, hparams.ssm_d_inner); @@ -7426,7 +7517,11 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { case LLM_ARCH_PANGU_EMBED: { llm = std::make_unique(*this, params); - }break; + } break; + case LLM_ARCH_QWEN3NEXT: + { + llm = std::make_unique(*this, params); + } break; default: GGML_ABORT("fatal error"); } @@ -7653,6 +7748,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_COGVLM: case LLM_ARCH_PANGU_EMBED: case LLM_ARCH_AFMOE: + case LLM_ARCH_QWEN3NEXT: return LLAMA_ROPE_TYPE_NEOX; case LLM_ARCH_QWEN2VL: diff --git a/src/llama-model.h b/src/llama-model.h index f730c49540..f8342cf2cb 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -113,6 +113,7 @@ enum llm_type { LLM_TYPE_16B_A1B, LLM_TYPE_21B_A3B, // Ernie MoE small LLM_TYPE_30B_A3B, + LLM_TYPE_80B_A3B, // Qwen3 Next LLM_TYPE_100B_A6B, LLM_TYPE_106B_A12B, // GLM-4.5-Air LLM_TYPE_230B_A10B, // Minimax M2 @@ -309,6 +310,9 @@ struct llama_layer { struct ggml_tensor * ssm_conv1d_b = nullptr; struct ggml_tensor * ssm_dt_b = nullptr; + // qwen3next + struct ggml_tensor * ssm_beta_alpha = nullptr; + // rwkv struct ggml_tensor * time_mix_w1 = nullptr; struct ggml_tensor * time_mix_w2 = nullptr; diff --git a/src/llama-quant.cpp b/src/llama-quant.cpp index a56b2626ae..0b23eaef3a 100644 --- a/src/llama-quant.cpp +++ b/src/llama-quant.cpp @@ -681,7 +681,9 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: } LLAMA_LOG_DEBUG("%s: pruning tensor %s\n", __func__, it.first.c_str()); continue; - } else if (remapped_name != it.first) { + } + + if (remapped_name != it.first) { ggml_set_name(it.second.tensor, remapped_name.c_str()); LLAMA_LOG_DEBUG("%s: tensor %s remapped to %s\n", __func__, it.first.c_str(), ggml_get_name(it.second.tensor)); } @@ -726,13 +728,19 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: { const auto & n_head_kv_iter = model.hparams.n_head_kv_arr.begin(); // attention layers have a non-zero number of kv heads - int32_t n_attn_layer = model.hparams.n_layer - std::count(n_head_kv_iter, n_head_kv_iter + model.hparams.n_layer, 0); + int32_t n_layer_attn = model.hparams.n_layer - std::count(n_head_kv_iter, n_head_kv_iter + model.hparams.n_layer, 0); if (llama_model_has_encoder(&model)) { - // now n_attn_layer is the number of attention layers in the encoder + // now n_layer_attn is the number of attention layers in the encoder // for each decoder block, there are 2 attention layers - n_attn_layer += 2 * model.hparams.dec_n_layer; + n_layer_attn += 2 * model.hparams.dec_n_layer; } - GGML_ASSERT((qs.n_attention_wv == n_attn_layer - pruned_attention_w) && "n_attention_wv is unexpected"); + + // note: for linear-attention models (such as Qwen3 Next) this is the number of linear layers + const int32_t n_layer_recr = std::count(model.hparams.recurrent_layer_arr.begin(), model.hparams.recurrent_layer_arr.end(), true); + + LLAMA_LOG_INFO("%s: n_layer_attn = %d, n_layer_recr = %d, pruned_attention_w = %d\n", __func__, n_layer_attn, n_layer_recr, pruned_attention_w); + + GGML_ASSERT((qs.n_attention_wv == n_layer_attn - pruned_attention_w - n_layer_recr) && "n_attention_wv is unexpected"); } size_t total_size_org = 0; diff --git a/src/models/models.h b/src/models/models.h index 5f019c59be..7ba225b478 100644 --- a/src/models/models.h +++ b/src/models/models.h @@ -2,8 +2,9 @@ #include 
"../llama-model.h" #include "../llama-graph.h" -#include "../llama-memory-recurrent.h" +// TODO: remove in follow-up PR - move to .cpp files +#include "../llama-memory-recurrent.h" #include struct llm_graph_context_mamba : public llm_graph_context { @@ -421,7 +422,56 @@ struct llm_build_qwen3vl : public llm_graph_context { struct llm_build_qwen3vlmoe : public llm_graph_context { llm_build_qwen3vlmoe(const llama_model & model, const llm_graph_params & params); }; +struct llm_build_qwen3next : public llm_graph_context_mamba { + llm_build_qwen3next(const llama_model & model, const llm_graph_params & params); +private: + ggml_tensor * build_layer_attn( + llm_graph_input_attn_kv * inp_attn, + ggml_tensor * cur, + ggml_tensor * inp_pos, + int il); + ggml_tensor * build_layer_attn_linear( + llm_graph_input_rs * inp, + ggml_tensor * cur, + ggml_tensor * causal_mask, + ggml_tensor * identity, + int il); + + ggml_tensor * build_layer_ffn( + ggml_tensor * cur, + int il); + + ggml_tensor * build_delta_net_recurrent( + ggml_tensor * q, + ggml_tensor * k, + ggml_tensor * v, + ggml_tensor * g, + ggml_tensor * beta, + ggml_tensor * state, + ggml_tensor * causal_mask, + ggml_tensor * identity, + int il); + + ggml_tensor * build_delta_net_chunking( + ggml_tensor * q, + ggml_tensor * k, + ggml_tensor * v, + ggml_tensor * g, + ggml_tensor * beta, + ggml_tensor * state, + ggml_tensor * causal_mask, + ggml_tensor * identity, + int il); + + ggml_tensor * build_norm_gated( + ggml_tensor * input, + ggml_tensor * weights, + ggml_tensor * gate, + int layer); + + const llama_model & model; +}; struct llm_build_qwen : public llm_graph_context { llm_build_qwen(const llama_model & model, const llm_graph_params & params); diff --git a/src/models/qwen3next.cpp b/src/models/qwen3next.cpp new file mode 100644 index 0000000000..c8f1b5ec90 --- /dev/null +++ b/src/models/qwen3next.cpp @@ -0,0 +1,1042 @@ +#include "ggml.h" +#include "models.h" + +#define CHUNK_SIZE 64 + +llm_build_qwen3next::llm_build_qwen3next(const llama_model & model, const llm_graph_params & params) : + llm_graph_context_mamba(params), model(model) { + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + cb(inpL, "model.embed_tokens", -1); + + auto * inp = build_inp_mem_hybrid(); + + ggml_tensor * inp_pos = build_inp_pos(); + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + ggml_tensor * causal_mask = + ggml_tri(ctx0, ggml_fill_inplace(ctx0, ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, ubatch.n_seq_tokens, ubatch.n_seq_tokens), 1.0f), + GGML_TRI_TYPE_LOWER); + + ggml_tensor * identity = ggml_diag(ctx0, ggml_fill_inplace(ctx0, ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, ubatch.n_seq_tokens), 1.0f)); + + ggml_build_forward_expand(gf, causal_mask); + ggml_build_forward_expand(gf, identity); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + cur = build_norm(inpL, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // Determine layer type and build appropriate attention mechanism + if (hparams.is_recurrent(il)) { + // Linear attention layer (gated delta net) + cur = build_layer_attn_linear(inp->get_recr(), cur, causal_mask, identity, il); + } else { + // Full attention layer + cur = build_layer_attn(inp->get_attn(), cur, inp_pos, il); + } + + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + // Residual connection + cur = ggml_add(ctx0, cur, inpSA); + cb(cur, "attn_residual", 
il); + + // Save the tensor before post-attention norm for residual connection + ggml_tensor * ffn_residual = cur; + + // Post-attention norm + ggml_tensor * attn_post_norm = build_norm(cur, model.layers[il].attn_post_norm, nullptr, LLM_NORM_RMS, il); + cb(attn_post_norm, "attn_post_norm", il); + + // FFN layer (MoE or dense) - without residual connection + cur = build_layer_ffn(attn_post_norm, il); + cb(cur, "ffn_out", il); + + // Residual connection for FFN - add to the tensor from before post_attention_layernorm + cur = ggml_add(ctx0, cur, ffn_residual); + cb(cur, "post_moe", il); + + // Input for next layer + inpL = cur; + } + cur = inpL; + + // Final norm + cur = build_norm(cur, model.output_norm, nullptr, LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // LM head + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); +} + +ggml_tensor * llm_build_qwen3next::build_delta_net_chunking( + ggml_tensor * q, + ggml_tensor * k, + ggml_tensor * v, + ggml_tensor * g, + ggml_tensor * beta, + ggml_tensor * state, + ggml_tensor * causal_mask, + ggml_tensor * identity, + int il) { + GGML_ASSERT(ggml_is_contiguous(q)); + GGML_ASSERT(ggml_is_contiguous(k)); + GGML_ASSERT(ggml_is_contiguous(v)); + GGML_ASSERT(ggml_is_contiguous(g)); + GGML_ASSERT(ggml_is_contiguous(beta)); + GGML_ASSERT(ggml_is_contiguous(state)); + + const int64_t S_k = q->ne[0]; + const int64_t H_k = q->ne[1]; + const int64_t n_tokens = q->ne[2]; + const int64_t n_seqs = q->ne[3]; + + const int64_t S_v = v->ne[0]; + const int64_t H_v = v->ne[1]; + + GGML_ASSERT(v->ne[2] == n_tokens); + GGML_ASSERT(k->ne[2] == n_tokens); + GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs); + GGML_ASSERT(beta->ne[0] == H_v && beta->ne[2] == n_tokens && beta->ne[3] == n_seqs); + GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v * H_v && state->ne[2] == 1 && state->ne[3] == n_seqs); + + GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs); + GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs); + + GGML_ASSERT(H_k == H_v); // we did a repeat to make sure this is the case + + // TODO: can this ever be false? 
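+    //       (the reference gated delta net appears to always L2-normalize q and k, so the
+    //        flag stays hard-coded for now; per head the op is roughly x / sqrt(sum(x^2) + eps),
+    //        which keeps the q.k dot products bounded before the 1/sqrt(S_v) scaling below)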
+ const bool use_qk_l2norm = true; + + if (use_qk_l2norm) { + const float eps_norm = hparams.f_norm_rms_eps; + + q = ggml_l2_norm(ctx0, q, eps_norm); + k = ggml_l2_norm(ctx0, k, eps_norm); + } + + const float scale = 1.0f / sqrtf(S_v); + + q = ggml_scale(ctx0, q, scale); + + beta = ggml_sigmoid(ctx0, beta); + + ggml_tensor * causal_diag_mask = ggml_add(ctx0, causal_mask, identity); + + cb(q, "q_in", il); + cb(k, "k_in", il); + cb(v, "v_in", il); + cb(beta, "beta_in", il); + cb(g, "g_in", il); + + q = ggml_cont_4d(ctx0, ggml_permute(ctx0, q, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs); + k = ggml_cont_4d(ctx0, ggml_permute(ctx0, k, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs); + v = ggml_cont_4d(ctx0, ggml_permute(ctx0, v, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs); + g = ggml_cont_4d(ctx0, ggml_permute(ctx0, g, 2, 0, 3, 1), n_tokens, 1, H_k, n_seqs); + + beta = ggml_cont(ctx0, ggml_permute(ctx0, beta, 2, 0, 1, 3)); + state = ggml_reshape_4d(ctx0, state, S_v, S_v, H_v, n_seqs); + + cb(q, "q_perm", il); + cb(k, "k_perm", il); + cb(v, "v_perm", il); + cb(beta, "beta_perm", il); + cb(g, "g_perm", il); + cb(state, "state_in", il); + + GGML_ASSERT(q->ne[1] == n_tokens && q->ne[0] == S_k && q->ne[2] == H_k && q->ne[3] == n_seqs); + GGML_ASSERT(k->ne[1] == n_tokens && k->ne[0] == S_k && k->ne[2] == H_k && k->ne[3] == n_seqs); + GGML_ASSERT(v->ne[1] == n_tokens && v->ne[0] == S_v && v->ne[2] == H_k && v->ne[3] == n_seqs); + GGML_ASSERT(beta->ne[1] == n_tokens && beta->ne[2] == H_k && beta->ne[0] == 1 && beta->ne[3] == n_seqs); + + // Do padding + const int64_t chunk_size = CHUNK_SIZE; + + const int64_t pad = (chunk_size - n_tokens % chunk_size) % chunk_size; + const int64_t n_chunks = (n_tokens + pad) / chunk_size; + + q = ggml_pad(ctx0, q, 0, pad, 0, 0); + k = ggml_pad(ctx0, k, 0, pad, 0, 0); + v = ggml_pad(ctx0, v, 0, pad, 0, 0); + g = ggml_pad(ctx0, g, pad, 0, 0, 0); + beta = ggml_pad(ctx0, beta, 0, pad, 0, 0); + + cb(q, "q_pad", il); + cb(k, "k_pad", il); + cb(v, "v_pad", il); + cb(beta, "beta_pad", il); + cb(g, "g_pad", il); + + ggml_tensor * v_beta = ggml_mul(ctx0, v, beta); + ggml_tensor * k_beta = ggml_mul(ctx0, k, beta); + + cb(v_beta, "v_beta", il); + cb(k_beta, "k_beta", il); + + ggml_tensor * chunked_mask = + ggml_view_4d(ctx0, causal_mask, chunk_size, + chunk_size, causal_mask->ne[2], causal_mask->ne[3], + causal_mask->nb[1], causal_mask->nb[2], causal_mask->nb[3], 0); + + ggml_tensor * chunked_diag_mask = + ggml_view_4d(ctx0, causal_diag_mask, chunk_size, + chunk_size, causal_diag_mask->ne[2], causal_diag_mask->ne[3], + causal_diag_mask->nb[1], causal_diag_mask->nb[2], causal_diag_mask->nb[3], 0); + + ggml_tensor * chunked_identity = + ggml_view_4d(ctx0, identity, chunk_size, + chunk_size, identity->ne[2], identity->ne[3], + identity->nb[1], identity->nb[2], identity->nb[3], 0); + + q = ggml_cont_4d(ctx0, q, S_k, chunk_size, n_chunks, H_k * n_seqs); + k = ggml_cont_4d(ctx0, k, S_k, chunk_size, n_chunks, H_k * n_seqs); + k_beta = ggml_cont_4d(ctx0, k_beta, S_k, chunk_size, n_chunks, H_k * n_seqs); + v = ggml_cont_4d(ctx0, v, S_v, chunk_size, n_chunks, H_v * n_seqs); + v_beta = ggml_cont_4d(ctx0, v_beta, S_v, chunk_size, n_chunks, H_v * n_seqs); + + g = ggml_cont_4d(ctx0, g, chunk_size, 1, n_chunks, H_k * n_seqs); + beta = ggml_cont_4d(ctx0, beta, 1, chunk_size, n_chunks, H_k * n_seqs); + + ggml_tensor * g_cumsum = ggml_cumsum(ctx0, g); + + cb(g_cumsum, "g_cumsum", il); + + ggml_tensor * gcs_i = ggml_cont_4d(ctx0, g_cumsum, chunk_size, 1, n_chunks, H_v * n_seqs); + ggml_tensor * gcs_j = 
ggml_cont_4d(ctx0, g_cumsum, 1, chunk_size, n_chunks, H_v * n_seqs); + + ggml_tensor * gcs_j_broadcast = + ggml_repeat_4d(ctx0, gcs_j, chunk_size, chunk_size, n_chunks, H_v * n_seqs); + + ggml_tensor * decay_mask = ggml_sub(ctx0, gcs_j_broadcast, gcs_i); + + cb(decay_mask, "decay_mask", il); + + decay_mask = ggml_mul(ctx0, decay_mask, chunked_diag_mask); + decay_mask = ggml_exp(ctx0, decay_mask); + decay_mask = ggml_mul(ctx0, decay_mask, chunked_diag_mask); + + ggml_tensor * kmulkbeta = ggml_mul_mat(ctx0, k, k_beta); + + ggml_tensor * k_decay = ggml_mul(ctx0, kmulkbeta, decay_mask); + ggml_tensor * attn = ggml_neg(ctx0, ggml_mul(ctx0, k_decay, chunked_mask)); + + cb(attn, "attn_pre_solve", il); + + ggml_tensor * attn_lower = ggml_mul(ctx0, attn, chunked_mask); + ggml_tensor * lhs = ggml_sub(ctx0, ggml_repeat(ctx0, chunked_identity, attn_lower), attn_lower); + + ggml_tensor * lin_solve = ggml_solve_tri(ctx0, lhs, attn, true, true, false); + attn = ggml_mul(ctx0, lin_solve, chunked_mask); + attn = ggml_add(ctx0, attn, chunked_identity); + + cb(attn, "attn_solved", il); + + v = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, v_beta)), attn); + + ggml_tensor * g_cumsum_t = ggml_cont(ctx0, ggml_transpose(ctx0, g_cumsum)); + ggml_tensor * gexp = ggml_exp(ctx0, g_cumsum_t); + + ggml_tensor * kbeta_gexp = ggml_mul(ctx0, k_beta, gexp); + + cb(kbeta_gexp, "kbeta_gexp", il); + + ggml_tensor * k_cumdecay = + ggml_cont(ctx0, ggml_transpose(ctx0, ggml_mul_mat(ctx0, attn, ggml_cont(ctx0, ggml_transpose(ctx0, kbeta_gexp))))); + + cb(k_cumdecay, "k_cumdecay", il); + + ggml_tensor * core_attn_out = nullptr; + ggml_tensor * new_state = ggml_dup(ctx0, state); + + cb(new_state, "new_state", il); + + for (int64_t chunk = 0; chunk < n_chunks; chunk++) { + auto chunkify = [=](ggml_tensor * t) { + return ggml_cont(ctx0, ggml_view_4d(ctx0, t, t->ne[0], chunk_size, 1, t->ne[3], + t->nb[1], t->nb[2], t->nb[3], t->nb[2] * chunk)); + }; + + auto chunkify_g = [=](ggml_tensor * t) { + return ggml_cont(ctx0, ggml_view_4d(ctx0, t, chunk_size, t->ne[1], 1, t->ne[3], + t->nb[1], t->nb[2], t->nb[3], t->nb[2] * chunk)); + }; + + ggml_tensor * k_chunk = chunkify(k); + ggml_tensor * q_chunk = chunkify(q); + ggml_tensor * v_chunk = chunkify(v); + + ggml_tensor * g_cs_chunk = chunkify_g(g_cumsum); + ggml_tensor * g_cs_chunk_t = ggml_cont(ctx0, ggml_transpose(ctx0, g_cs_chunk)); + + ggml_tensor * decay_mask_chunk = chunkify(decay_mask); + ggml_tensor * k_cumdecay_chunk = chunkify(k_cumdecay); + + ggml_tensor * gexp_chunk = ggml_exp(ctx0, g_cs_chunk_t); + + // attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0) + attn = ggml_mul_mat(ctx0, k_chunk, q_chunk); + attn = ggml_mul(ctx0, attn, decay_mask_chunk); + attn = ggml_mul(ctx0, attn, ggml_add(ctx0, chunked_identity, chunked_mask)); + + ggml_tensor * state_t = ggml_cont_4d(ctx0, ggml_permute(ctx0, new_state, 1, 0, 2, 3), S_v, S_v, 1, H_v * n_seqs); + + // v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state + ggml_tensor * v_prime = ggml_mul_mat(ctx0, state_t, k_cumdecay_chunk); + + // v_new = v_i - v_prime + ggml_tensor * v_new = ggml_sub(ctx0, ggml_repeat(ctx0, v_chunk, v_prime), v_prime); + ggml_tensor * v_new_t = ggml_cont(ctx0, ggml_transpose(ctx0, v_new)); + + // attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state + ggml_tensor * q_g_exp = ggml_mul(ctx0, q_chunk, gexp_chunk); + ggml_tensor * attn_inter = ggml_mul_mat(ctx0, state_t, q_g_exp); + + // core_attn_out[:, :, i] = attn_inter + attn @ v_new + ggml_tensor * v_attn 
= ggml_mul_mat(ctx0, v_new_t, attn); + + ggml_tensor * core_attn_out_chunk = ggml_add(ctx0, attn_inter, v_attn); + + core_attn_out = core_attn_out == nullptr ? core_attn_out_chunk : ggml_concat(ctx0, core_attn_out, core_attn_out_chunk, 1); + + // g_last = torch.clamp(g_cum[:, :, -1], max=50.0).exp().unsqueeze(-1).unsqueeze(-1) + // g_diff = torch.clamp(g_cum[:, :, -1:] - g_cum, max=50.0).exp() + // key_gdiff = key * g_diff.unsqueeze(-1) + // kgdmulvnew = (key_gdiff).transpose(-1, -2) @ v_new + // last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew + + ggml_tensor * g_cum_last = + ggml_cont(ctx0, ggml_view_4d(ctx0, g_cs_chunk_t, g_cs_chunk_t->ne[0], 1, g_cs_chunk_t->ne[2], g_cs_chunk_t->ne[3], + g_cs_chunk_t->nb[1], g_cs_chunk_t->nb[2], g_cs_chunk_t->nb[3], + g_cs_chunk_t->nb[0] * (g_cs_chunk_t->ne[1] - 1))); + + ggml_tensor * gexp_last = + ggml_reshape_4d(ctx0, ggml_exp(ctx0, g_cum_last), 1, 1, g_cum_last->ne[0] * g_cum_last->ne[2], g_cum_last->ne[3]); + + ggml_tensor * g_cum_last_3d = + ggml_reshape_3d(ctx0, g_cum_last, g_cum_last->ne[0], g_cum_last->ne[2], g_cum_last->ne[3]); + + ggml_tensor * g_cumsum_3d = ggml_reshape_3d(ctx0, g_cs_chunk, g_cs_chunk->ne[0], g_cs_chunk->ne[2], g_cs_chunk->ne[3]); + + ggml_tensor * g_diff = ggml_neg(ctx0, ggml_sub(ctx0, g_cumsum_3d, g_cum_last_3d)); + + ggml_tensor * g_diff_exp = ggml_exp(ctx0, g_diff); + + ggml_tensor * key_gdiff = ggml_mul(ctx0, k_chunk, + ggml_reshape_4d(ctx0, g_diff_exp, 1, g_diff_exp->ne[0], g_diff_exp->ne[1], + g_diff_exp->ne[2] * g_diff_exp->ne[3])); + + ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx0, v_new_t, ggml_cont(ctx0, ggml_transpose(ctx0, key_gdiff))); + + new_state = ggml_add(ctx0, + ggml_mul(ctx0, new_state, ggml_reshape_4d(ctx0, gexp_last, gexp_last->ne[0], gexp_last->ne[1], H_v, n_seqs)), + ggml_reshape_4d(ctx0, kgdmulvnew, kgdmulvnew->ne[0], kgdmulvnew->ne[1], H_v, n_seqs)); + } + + core_attn_out = ggml_cont_4d(ctx0, core_attn_out, S_v, chunk_size * n_chunks, H_v, n_seqs); + + ggml_tensor * output_tokens = ggml_view_4d(ctx0, core_attn_out, S_v, n_tokens, H_v, n_seqs, core_attn_out->nb[1], core_attn_out->nb[2], core_attn_out->nb[3], 0); + cb(output_tokens, "output_tokens", il); + + // flatten output + ggml_tensor * flat_output = + ggml_cont_1d(ctx0, ggml_permute(ctx0, output_tokens, 0, 2, 1, 3), S_v * H_v * n_tokens * n_seqs); + + ggml_tensor * flat_state = ggml_cont_1d(ctx0, new_state, S_v * S_v * H_v * n_seqs); + + return ggml_concat(ctx0, flat_output, flat_state, 0); +} + +ggml_tensor * llm_build_qwen3next::build_delta_net_recurrent( + ggml_tensor * q, + ggml_tensor * k, + ggml_tensor * v, + ggml_tensor * g, + ggml_tensor * beta, + ggml_tensor * state, + ggml_tensor * causal_mask, + ggml_tensor * identity, + int il) { + GGML_ASSERT(ggml_is_contiguous(q)); + GGML_ASSERT(ggml_is_contiguous(k)); + GGML_ASSERT(ggml_is_contiguous(v)); + GGML_ASSERT(ggml_is_contiguous(g)); + GGML_ASSERT(ggml_is_contiguous(beta)); + GGML_ASSERT(ggml_is_contiguous(state)); + + const int64_t S_k = q->ne[0]; + const int64_t H_k = q->ne[1]; + const int64_t n_tokens = q->ne[2]; + const int64_t n_seqs = q->ne[3]; + + const int64_t S_v = v->ne[0]; + const int64_t H_v = v->ne[1]; + + GGML_ASSERT(v->ne[2] == n_tokens); + GGML_ASSERT(k->ne[2] == n_tokens); + GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs); + GGML_ASSERT(beta->ne[0] == H_v && beta->ne[2] == n_tokens && beta->ne[3] == n_seqs); + GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v * H_v && state->ne[2] == 1 && state->ne[3] == n_seqs); + + 
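+    // Input layout sketch (illustrative numbers only, not taken from a real config):
+    // with S_k = S_v = 128, H_k = H_v = 32, n_tokens = 4, n_seqs = 1 the caller passes
+    //   q, k  : [128, 32, 4, 1]        v    : [128, 32, 4, 1]
+    //   g     : [32, 4, 1]             beta : [32, 1, 4, 1]
+    //   state : [128, 128*32, 1, 1], i.e. one S_v x S_v matrix per value head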
GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs); + GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs); + + GGML_ASSERT(H_k == H_v); // we did a repeat to make sure this is the case + + // TODO: can this ever be false? + const bool use_qk_l2norm = true; + + if (use_qk_l2norm) { + const float eps_norm = hparams.f_norm_rms_eps; + + q = ggml_l2_norm(ctx0, q, eps_norm); + k = ggml_l2_norm(ctx0, k, eps_norm); + } + + const float scale = 1.0f / sqrtf(S_v); + + q = ggml_scale(ctx0, q, scale); + + beta = ggml_sigmoid(ctx0, beta); + + ggml_tensor * causal_diag_mask = ggml_add(ctx0, causal_mask, identity); + + cb(q, "q_in", il); + cb(k, "k_in", il); + cb(v, "v_in", il); + cb(beta, "beta_in", il); + cb(g, "g_in", il); + + q = ggml_cont_4d(ctx0, ggml_permute(ctx0, q, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs); + k = ggml_cont_4d(ctx0, ggml_permute(ctx0, k, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs); + v = ggml_cont_4d(ctx0, ggml_permute(ctx0, v, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs); + g = ggml_cont_4d(ctx0, ggml_permute(ctx0, g, 2, 0, 3, 1), n_tokens, 1, H_k, n_seqs); + + beta = ggml_cont(ctx0, ggml_permute(ctx0, beta, 2, 0, 1, 3)); + state = ggml_reshape_4d(ctx0, state, S_v, S_v, H_v, n_seqs); + + cb(q, "q_perm", il); + cb(k, "k_perm", il); + cb(v, "v_perm", il); + cb(beta, "beta_perm", il); + cb(g, "g_perm", il); + cb(state, "state_in", il); + + GGML_ASSERT(q->ne[1] == n_tokens && q->ne[0] == S_k && q->ne[2] == H_k && q->ne[3] == n_seqs); + GGML_ASSERT(k->ne[1] == n_tokens && k->ne[0] == S_k && k->ne[2] == H_k && k->ne[3] == n_seqs); + GGML_ASSERT(v->ne[1] == n_tokens && v->ne[0] == S_v && v->ne[2] == H_k && v->ne[3] == n_seqs); + GGML_ASSERT(beta->ne[1] == n_tokens && beta->ne[2] == H_k && beta->ne[0] == 1 && beta->ne[3] == n_seqs); + + ggml_tensor * v_beta = ggml_mul(ctx0, v, beta); + ggml_tensor * k_beta = ggml_mul(ctx0, k, beta); + + ggml_tensor * g_cumsum = ggml_cumsum(ctx0, g); + + cb(k_beta, "k_beta", il); + cb(v_beta, "v_beta", il); + cb(g_cumsum, "g_cumsum", il); + + ggml_tensor * gcs_i = ggml_cont_4d(ctx0, g_cumsum, n_tokens, 1, H_v, n_seqs); // [chunk_size, 1, n_tokens, n_seqs] + ggml_tensor * gcs_j = ggml_cont_4d(ctx0, g_cumsum, 1, n_tokens, H_v, n_seqs); // [1, chunk_size, n_tokens, n_seqs] + + // Broadcast both tensors to [chunk_size, chunk_size, H_v, n_seqs] + // ggml_tensor * gcs_i_broadcast = + // ggml_repeat_4d(ctx0, gcs_i, GGML_DELTA_NET_CHUNK, GGML_DELTA_NET_CHUNK, num_chunks * H_v, + // n_seqs); // [chunk_size, 1, H_v, n_seqs] -> [chunk_size, chunk_size, H_v, n_seqs] + // Don't need this, this one will get auto-broadcast + ggml_tensor * gcs_j_broadcast = + ggml_repeat_4d(ctx0, gcs_j, n_tokens, n_tokens, H_v, n_seqs); // [1, chunk_size, H_v, n_seqs] -> [chunk_size, chunk_size, H_v, n_seqs] + + ggml_tensor * decay_mask = ggml_sub(ctx0, gcs_j_broadcast, gcs_i); + + // Apply lower triangular mask to ensure attention is causal (only past tokens influence current) + decay_mask = ggml_mul(ctx0, decay_mask, causal_diag_mask); + // Apply exponential to get the decay mask values + decay_mask = ggml_exp(ctx0, decay_mask); + // Apply lower triangular mask again to ensure only lower triangular values remain + decay_mask = ggml_mul(ctx0, decay_mask, causal_diag_mask); + + cb(decay_mask, "decay_mask", il); + + // attn = -((k_beta @ key.transpose(-1, -2)) * decay_mask).masked_fill(mask, 0) + ggml_tensor * kmulkbeta = ggml_mul_mat(ctx0, k, k_beta); + + cb(kmulkbeta, "kmulkbeta", il); + + ggml_tensor * k_decay 
= ggml_mul(ctx0, kmulkbeta, decay_mask); + ggml_tensor * attn = ggml_neg(ctx0, ggml_mul(ctx0, k_decay, causal_mask)); + + cb(attn, "attn_pre_rec", il); + + // for i in range(1, chunk_size): + // row = attn[..., i, :i].clone() + // sub = attn[..., :i, :i].clone() + // attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2) + // attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device) + // + // We reduce this to a linear triangular solve: AX = B, where B = attn, A = I - tril(A) + ggml_tensor * attn_lower = ggml_mul(ctx0, attn, causal_mask); + ggml_tensor * lhs = ggml_sub(ctx0, ggml_repeat(ctx0, identity, attn_lower), attn_lower); + + ggml_tensor * lin_solve = ggml_solve_tri(ctx0, lhs, attn, true, true, false); + attn = ggml_mul(ctx0, lin_solve, causal_mask); + attn = ggml_add(ctx0, attn, identity); + + // value = attn @ v_beta + v = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, v_beta)), attn); + + cb(v, "value_beta", il); + + // k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1)) + ggml_tensor * g_cumsum_t = ggml_cont(ctx0, ggml_transpose(ctx0, g_cumsum)); + ggml_tensor * gexp = ggml_exp(ctx0, g_cumsum_t); + + cb(gexp, "g_cum_exp", il); + + ggml_tensor * kbeta_gexp = ggml_mul(ctx0, k_beta, gexp); + + cb(kbeta_gexp, "kbeta_gexp", il); + + ggml_tensor * k_cumdecay = + ggml_cont(ctx0, ggml_transpose(ctx0, ggml_mul_mat(ctx0, attn, ggml_cont(ctx0, ggml_transpose(ctx0, kbeta_gexp))))); + + cb(k_cumdecay, "k_cumdecay", il); + + // attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0) + attn = ggml_mul_mat(ctx0, k, q); + attn = ggml_mul(ctx0, attn, decay_mask); + attn = ggml_mul(ctx0, attn, ggml_add(ctx0, identity, causal_mask)); + + cb(attn, "attn_decay_key", il); + + ggml_tensor * state_t = ggml_cont(ctx0, ggml_transpose(ctx0, state)); + + // v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state + ggml_tensor * v_prime = ggml_mul_mat(ctx0, state_t, k_cumdecay); + + cb(v_prime, "v_prime", il); + + // v_new = v_i - v_prime + ggml_tensor * v_new = ggml_sub(ctx0, ggml_repeat(ctx0, v, v_prime), v_prime); + + ggml_tensor * v_new_t = ggml_cont(ctx0, ggml_transpose(ctx0, v_new)); + + cb(v_new, "v_new", il); + + // attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state + ggml_tensor * q_g_exp = ggml_mul(ctx0, q, gexp); + ggml_tensor * attn_inter = ggml_mul_mat(ctx0, state_t, q_g_exp); + + cb(attn_inter, "attn_inter", il); + + // core_attn_out[:, :, i] = attn_inter + attn @ v_new + ggml_tensor * v_attn = ggml_mul_mat(ctx0, v_new_t, attn); + + cb(v_attn, "v_attn", il); + + ggml_tensor * core_attn_out = ggml_add(ctx0, attn_inter, v_attn); + + cb(core_attn_out, "core_attn_out", il); + + // g_last = torch.clamp(g_cum[:, :, -1], max=50.0).exp().unsqueeze(-1).unsqueeze(-1) + // g_diff = torch.clamp(g_cum[:, :, -1:] - g_cum, max=50.0).exp() + // key_gdiff = key * g_diff.unsqueeze(-1) + // kgdmulvnew = (key_gdiff).transpose(-1, -2) @ v_new + // last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew + + ggml_tensor * g_cum_last = + ggml_cont(ctx0, ggml_view_4d(ctx0, g_cumsum_t, g_cumsum_t->ne[0], 1, g_cumsum_t->ne[2], g_cumsum_t->ne[3], + g_cumsum_t->nb[1], g_cumsum_t->nb[2], g_cumsum_t->nb[3], + g_cumsum_t->nb[0] * (g_cumsum_t->ne[1] - 1))); + + cb(g_cum_last, "g_cum_last", il); + + ggml_tensor * gexp_last = + ggml_reshape_4d(ctx0, ggml_exp(ctx0, g_cum_last), 1, 1, g_cum_last->ne[0] * g_cum_last->ne[2], g_cum_last->ne[3]); + + cb(gexp_last, "gexp_last", il); + + ggml_tensor * g_cum_last_3d = + ggml_reshape_3d(ctx0, g_cum_last, 
g_cum_last->ne[0], g_cum_last->ne[2], g_cum_last->ne[3]); + + cb(g_cum_last_3d, "g_cum_last_3d", il); + + ggml_tensor * g_cumsum_3d = ggml_reshape_3d(ctx0, g_cumsum, g_cumsum->ne[0], g_cumsum->ne[2], g_cumsum->ne[3]); + + cb(g_cumsum_3d, "g_cumsum_3d", il); + + ggml_tensor * g_diff = ggml_neg(ctx0, ggml_sub(ctx0, g_cumsum_3d, g_cum_last_3d)); + + cb(g_diff, "g_diff", il); + + ggml_tensor * g_diff_exp = ggml_exp(ctx0, g_diff); + + cb(g_diff_exp, "g_diff_exp", il); + + ggml_tensor * key_gdiff = ggml_mul(ctx0, k, + ggml_reshape_4d(ctx0, g_diff_exp, 1, g_diff_exp->ne[0], g_diff_exp->ne[1], + g_diff_exp->ne[2] * g_diff_exp->ne[3])); + + cb(key_gdiff, "key_gdiff", il); + + ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx0, v_new_t, ggml_cont(ctx0, ggml_transpose(ctx0, key_gdiff))); + + cb(kgdmulvnew, "kgdmulvnew", il); + + state = ggml_add(ctx0, ggml_mul(ctx0, state, gexp_last), kgdmulvnew); + + cb(state, "new_state", il); + + // flatten output + ggml_tensor * flat_output = + ggml_cont_1d(ctx0, ggml_permute(ctx0, core_attn_out, 0, 2, 1, 3), S_v * H_v * n_tokens * n_seqs); + + ggml_tensor * flat_state = ggml_cont_1d(ctx0, state, S_v * S_v * H_v * n_seqs); + + return ggml_concat(ctx0, flat_output, flat_state, 0); +} + +ggml_tensor * llm_build_qwen3next::build_norm_gated( + ggml_tensor * input, + ggml_tensor * weights, + ggml_tensor * gate, + int layer) { + ggml_tensor * normalized = build_norm(input, weights, nullptr, LLM_NORM_RMS, layer); + ggml_tensor * gated_silu = ggml_silu(ctx0, gate); + + return ggml_mul(ctx0, normalized, gated_silu); +} + +ggml_tensor * llm_build_qwen3next::build_layer_attn( + llm_graph_input_attn_kv * inp, + ggml_tensor * cur, + ggml_tensor * inp_pos, + int il) { + const int64_t n_embd_head = hparams.n_embd_head_v; + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + + // Order: joint QG projection, QG split, Q norm, KV projection, K norm, RoPE, attention + + // Qwen3Next uses a single Q projection that outputs query + gate + ggml_tensor * Qcur_full = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur_full, "Qcur_full", il); + + Qcur_full = ggml_reshape_4d(ctx0, Qcur_full, n_embd_head * 2, n_head, n_tokens, 1); + + // Split Q projection into query and gate + // The split should be along dimension 0 (the feature dimension) + ggml_tensor * Qcur = ggml_view_4d(ctx0, Qcur_full, n_embd_head, n_head, n_tokens, 1, + Qcur_full->nb[1], Qcur_full->nb[2], Qcur_full->nb[3], 0); + ggml_tensor * gate = + ggml_view_4d(ctx0, Qcur_full, n_embd_head, n_head, n_tokens, 1, + Qcur_full->nb[1], Qcur_full->nb[2], Qcur_full->nb[3], n_embd_head * ggml_element_size(Qcur_full)); + cb(Qcur, "Qcur", il); + cb(gate, "gate", il); + + // Now reshape Qcur to [n_embd_head, n_head, n_tokens] for multi-head attention + Qcur = ggml_cont_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + cb(Qcur, "Qcur_reshaped", il); + + // Apply Q normalization + Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, nullptr, LLM_NORM_RMS, il); + cb(Qcur, "Qcur_normed", il); + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + + // Apply K normalization + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, nullptr, LLM_NORM_RMS, il); + cb(Kcur, "Kcur_normed", il); + + // Reshape gate to [n_embd, n_tokens] for the sigmoid gating (flatten the heads) + gate = ggml_cont_2d(ctx0, gate, n_embd_head * n_head, n_tokens); + cb(gate, 
"gate_reshaped", il); + + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + // Apply RoPE + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, + freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + // Attention computation + const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f / sqrtf(float(n_embd_head)) : hparams.f_attention_scale; + + cur = build_attn(inp, + nullptr, nullptr, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); + cb(cur, "attn_pregate", il); + + ggml_tensor * gate_sigmoid = ggml_sigmoid(ctx0, gate); + cb(gate_sigmoid, "gate_sigmoid", il); + + cur = ggml_mul(ctx0, cur, gate_sigmoid); + cb(cur, "attn_gated", il); + + cur = build_lora_mm(model.layers[il].wo, cur); + cb(cur, "attn_output", il); + + return cur; +} + +ggml_tensor * llm_build_qwen3next::build_layer_attn_linear( + llm_graph_input_rs * inp, + ggml_tensor * cur, + ggml_tensor * causal_mask, + ggml_tensor * identity, + int il) { + const auto * mctx_cur = inp->mctx; + + const int64_t d_inner = hparams.ssm_d_inner; + const int64_t n_seqs = ubatch.n_seqs; + const int64_t head_k_dim = hparams.ssm_d_state; + const int64_t num_k_heads = hparams.ssm_n_group; + const int64_t num_v_heads = hparams.ssm_dt_rank; + const int64_t head_v_dim = d_inner / num_v_heads; + const int64_t n_seq_tokens = ubatch.n_seq_tokens; + + const auto kv_head = mctx_cur->get_head(); + + GGML_ASSERT(n_seqs != 0); + GGML_ASSERT(ubatch.equal_seqs()); + GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs); + + // Input projections + ggml_tensor * mixed_qkvz = build_lora_mm(model.layers[il].ssm_in, cur); + cb(mixed_qkvz, "linear_attn_mixed_qkvz", il); + + ggml_tensor * mixed_ba = build_lora_mm(model.layers[il].ssm_beta_alpha, cur); + cb(mixed_ba, "linear_attn_mixed_ba", il); + + int64_t qkvz_new_dim = 2 * head_k_dim + 2 * head_v_dim * (num_v_heads / num_k_heads); + ggml_tensor * mixed_qkvz_reshaped = ggml_cont_4d(ctx0, mixed_qkvz, qkvz_new_dim, num_k_heads, n_seq_tokens, n_seqs); + + // Reshape mixed_ba: [batch, seq_len, hidden_size] -> [batch, seq_len, num_k_heads, 2*num_v_heads/num_k_heads] + int64_t ba_new_dim = 2 * num_v_heads / num_k_heads; + ggml_tensor * mixed_ba_reshaped = ggml_cont_4d(ctx0, mixed_ba, ba_new_dim, num_k_heads, n_seq_tokens, n_seqs); + + // Split mixed_ba into b and a (beta and alpha parameters) + int64_t split_sizes_ba[2] = { + num_v_heads / num_k_heads, // beta size + num_v_heads / num_k_heads // alpha size + }; + + ggml_tensor * b = ggml_view_4d(ctx0, mixed_ba_reshaped, split_sizes_ba[0], num_k_heads, n_seq_tokens, n_seqs, + mixed_ba_reshaped->nb[1], mixed_ba_reshaped->nb[2], mixed_ba_reshaped->nb[3], 0); + cb(b, "b", il); + + ggml_tensor * a = ggml_view_4d(ctx0, mixed_ba_reshaped, split_sizes_ba[1], num_k_heads, n_seq_tokens, n_seqs, + mixed_ba_reshaped->nb[1], mixed_ba_reshaped->nb[2], mixed_ba_reshaped->nb[3], + split_sizes_ba[0] * ggml_element_size(mixed_ba_reshaped)); + cb(a, "a", il); + + // Reshape b and a to merge head dimensions: [batch, seq_len, num_k_heads, num_v_heads/num_k_heads] -> [batch, seq_len, num_v_heads] + ggml_tensor * beta = ggml_cont_3d(ctx0, b, num_v_heads, n_seq_tokens, n_seqs); + ggml_tensor * alpha = ggml_cont_3d(ctx0, a, num_v_heads, n_seq_tokens, n_seqs); + + 
GGML_ASSERT(ggml_nelements(beta) + ggml_nelements(alpha) == ggml_nelements(mixed_ba)); + + ggml_tensor * alpha_biased = ggml_add(ctx0, alpha, model.layers[il].ssm_dt); + ggml_tensor * alpha_softplus = ggml_softplus(ctx0, alpha_biased); + cb(alpha_softplus, "a_softplus", il); + ggml_tensor * gate = ggml_mul(ctx0, alpha_softplus, model.layers[il].ssm_a); // -A_log.exp() * softplus + cb(gate, "gate", il); + + // Split mixed_qkvz into query, key, value, z + int64_t split_sizes_qkvz[4] = { + head_k_dim, // query size + head_k_dim, // key size + head_v_dim * num_v_heads / num_k_heads, // value size + head_v_dim * num_v_heads / num_k_heads // z size + }; + + ggml_tensor * query = + ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[0], num_k_heads, n_seq_tokens, n_seqs, + mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3], 0); + cb(query, "q", il); + + ggml_tensor * key = ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[1], num_k_heads, n_seq_tokens, n_seqs, + mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3], + split_sizes_qkvz[0] * sizeof(float)); + cb(key, "k", il); + + ggml_tensor * value = + ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[2], num_k_heads, n_seq_tokens, n_seqs, + mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3], + (split_sizes_qkvz[0] + split_sizes_qkvz[1]) * sizeof(float)); + cb(value, "v", il); + + ggml_tensor * z = ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[3], num_k_heads, n_seq_tokens, n_seqs, + mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3], + (split_sizes_qkvz[0] + split_sizes_qkvz[1] + split_sizes_qkvz[2]) * sizeof(float)); + cb(z, "z", il); + + GGML_ASSERT(ggml_nelements(query) + ggml_nelements(key) + ggml_nelements(value) + ggml_nelements(z) == + ggml_nelements(mixed_qkvz)); + + // After creating query, key, and value_reshaped, reshape each to flatten the head dimensions + // query: [head_k_dim, num_k_heads, n_tokens, n_seqs] -> [head_k_dim * num_k_heads, n_tokens, n_seqs] + ggml_tensor * query_flat = ggml_cont_3d(ctx0, query, head_k_dim * num_k_heads, n_seq_tokens, n_seqs); + cb(query_flat, "query_flat", il); + + // key: [head_k_dim, num_k_heads, n_tokens, n_seqs] -> [head_k_dim * num_k_heads, n_tokens, n_seqs] + ggml_tensor * key_flat = ggml_cont_3d(ctx0, key, head_k_dim * num_k_heads, n_seq_tokens, n_seqs); + cb(key_flat, "key_flat", il); + + // value_reshaped: [head_v_dim, num_v_heads, n_tokens, n_seqs] -> [head_v_dim * num_v_heads, n_tokens, n_seqs] + ggml_tensor * value_flat = ggml_cont_3d(ctx0, value, head_v_dim * num_v_heads, n_seq_tokens, n_seqs); + cb(value_flat, "value_flat", il); + + // Get convolution states from cache + ggml_tensor * conv_states_all = mctx_cur->get_r_l(il); + ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il); + + // bool use_precomputed_states = n_seq_tokens == 1 && mctx_cur->has_previous_state(); + + // Build the convolution states tensor + ggml_tensor * conv_states = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs); + cb(conv_states, "conv_states", il); + + // Now concatenate along the feature dimension (dim 0) to get [conv_dim, n_tokens, n_seqs] + ggml_tensor * qkv_mixed = ggml_concat(ctx0, query_flat, key_flat, 0); + qkv_mixed = ggml_concat(ctx0, qkv_mixed, value_flat, 0); + cb(qkv_mixed, "qkv_mixed", il); + + qkv_mixed = ggml_permute(ctx0, qkv_mixed, 1, 0, 2, 3); + cb(qkv_mixed, "qkv_mixed_permuted", il); + + // Calculate the total conv dimension + 
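+    // (illustrative: with head_k_dim = 128, num_k_heads = 16, head_v_dim = 128, num_v_heads = 32
+    //  this is 2*128*16 + 128*32 = 8192; by construction it equals conv_channels =
+    //  d_inner + 2*ssm_n_group*ssm_d_state computed below, since head_k_dim = ssm_d_state,
+    //  num_k_heads = ssm_n_group and head_v_dim*num_v_heads = d_inner)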
int64_t qkv_dim = head_k_dim * num_k_heads * 2 + head_v_dim * num_v_heads; + + // Calculate convolution kernel size + ggml_tensor * conv_kernel = model.layers[il].ssm_conv1d; + const int64_t conv_kernel_size = conv_kernel->ne[0]; + const int64_t conv_channels = d_inner + 2 * hparams.ssm_n_group * hparams.ssm_d_state; + conv_states = ggml_reshape_3d(ctx0, conv_states, conv_kernel_size - 1, conv_channels, n_seqs); + cb(conv_states, "conv_states_reshaped", il); + + ggml_tensor * conv_input = ggml_concat(ctx0, conv_states, qkv_mixed, 0); + cb(conv_input, "conv_input", il); + + // Update convolution state cache + // Extract the last (conv_kernel_size - 1) states from conv_input + ggml_tensor * last_conv_states = + ggml_view_3d(ctx0, conv_input, conv_kernel_size - 1, conv_channels, n_seqs, conv_input->nb[1], + conv_input->nb[2], (conv_input->ne[0] - conv_states->ne[0]) * ggml_element_size(conv_input)); + cb(last_conv_states, "last_conv_states", il); + + ggml_tensor * state_update_target = + ggml_view_1d(ctx0, conv_states_all, (conv_kernel_size - 1) * conv_channels * n_seqs, + kv_head * (conv_kernel_size - 1) * conv_channels * ggml_element_size(conv_states_all)); + cb(state_update_target, "state_update_target", il); + + ggml_build_forward_expand(gf, ggml_cpy(ctx0, last_conv_states, state_update_target)); + cb(conv_states_all, "conv_states_updated", il); + + // Apply SSM convolution + ggml_tensor * conv_output_proper = ggml_ssm_conv(ctx0, conv_input, conv_kernel); + cb(conv_output_proper, "conv_output_raw", il); + + conv_output_proper = ggml_cont(ctx0, ggml_transpose(ctx0, conv_output_proper)); + cb(conv_output_proper, "conv_output_pre_silu", il); + + ggml_tensor * conv_output_silu = ggml_silu(ctx0, conv_output_proper); + cb(conv_output_silu, "conv_output_silu", il); + + ggml_tensor * conv_qkv_mix = + ggml_cont_2d(ctx0, ggml_transpose(ctx0, conv_output_silu), qkv_dim, n_seq_tokens * n_seqs); + cb(conv_qkv_mix, "conv_qkv_mix", il); + + // Extract the convolved Q, K, V from conv_output + ggml_tensor * q_conv = + ggml_view_2d(ctx0, conv_qkv_mix, head_k_dim * num_k_heads, n_seq_tokens * n_seqs, conv_qkv_mix->nb[1], 0); + cb(q_conv, "q_conv", il); + ggml_tensor * k_conv = + ggml_view_2d(ctx0, conv_qkv_mix, head_k_dim * num_k_heads, n_seq_tokens * n_seqs, conv_qkv_mix->nb[1], + head_k_dim * num_k_heads * ggml_element_size(conv_qkv_mix)); + cb(k_conv, "k_conv", il); + ggml_tensor * v_conv = + ggml_view_2d(ctx0, conv_qkv_mix, head_v_dim * num_v_heads, n_seq_tokens * n_seqs, conv_qkv_mix->nb[1], + 2 * head_k_dim * num_k_heads * ggml_element_size(conv_qkv_mix)); + cb(v_conv, "v_conv", il); + + // Unsqueeze them + q_conv = ggml_cont_4d(ctx0, q_conv, head_k_dim, num_k_heads, n_seq_tokens, n_seqs); + k_conv = ggml_cont_4d(ctx0, k_conv, head_k_dim, num_k_heads, n_seq_tokens, n_seqs); + v_conv = ggml_cont_4d(ctx0, v_conv, head_v_dim, num_v_heads, n_seq_tokens, n_seqs); + + beta = ggml_cont_4d(ctx0, b, num_v_heads, 1, n_seq_tokens, n_seqs); + + ggml_tensor * state = build_rs(inp, ssm_states_all, hparams.n_embd_s(), n_seqs); + state = ggml_reshape_4d(ctx0, state, head_v_dim, head_v_dim * num_v_heads, 1, n_seqs); + cb(state, "state_predelta", il); + + // if head keys and value keys are different, repeat to force tensors into matching shapes + if (num_k_heads != num_v_heads) { + GGML_ASSERT(num_v_heads % num_k_heads == 0); + int64_t repeat_factor = num_v_heads / num_k_heads; + + // repeat interleave: reshape to (repeat part, 1, remaining part), do repeat, then reshape back + ggml_tensor * q_reshaped = 
ggml_reshape_3d(ctx0, q_conv, head_k_dim, 1, num_k_heads * n_seq_tokens * n_seqs); + ggml_tensor * k_reshaped = ggml_reshape_3d(ctx0, k_conv, head_k_dim, 1, num_k_heads * n_seq_tokens * n_seqs); + + // Repeat along the third dimension (the new dimension with size 1) + ggml_tensor * q_repeated = + ggml_repeat_4d(ctx0, q_reshaped, head_k_dim, repeat_factor, num_k_heads * n_seq_tokens * n_seqs, 1); + ggml_tensor * k_repeated = + ggml_repeat_4d(ctx0, k_reshaped, head_k_dim, repeat_factor, num_k_heads * n_seq_tokens * n_seqs, 1); + + // Reshape back to merge the head and repeat dimensions + // From [head_dim, num_k_heads, repeat_factor, n_seq_tokens * n_seqs] + // Back to [head_dim, num_k_heads * repeat_factor, n_seq_tokens, n_seqs] + q_conv = ggml_reshape_4d(ctx0, q_repeated, head_k_dim, num_k_heads * repeat_factor, n_seq_tokens, n_seqs); + k_conv = ggml_reshape_4d(ctx0, k_repeated, head_k_dim, num_k_heads * repeat_factor, n_seq_tokens, n_seqs); + } + + cb(q_conv, "q_conv_predelta", il); + cb(k_conv, "k_conv_predelta", il); + cb(v_conv, "v_conv_predelta", il); + + // Choose between build_delta_net_chunking and build_delta_net_recurrent based on n_tokens + ggml_tensor * attn_out = n_seq_tokens > CHUNK_SIZE ? + build_delta_net_chunking (q_conv, k_conv, v_conv, gate, beta, state, causal_mask, identity, il) : + build_delta_net_recurrent(q_conv, k_conv, v_conv, gate, beta, state, causal_mask, identity, il); + cb(attn_out, "attn_out", il); + + // The tensors were concatenated 1d, so we need to extract them 1d as well + const int64_t output_flat_size = head_v_dim * num_v_heads * n_seq_tokens * n_seqs; + ggml_tensor * attn_out_1d = ggml_view_1d(ctx0, attn_out, output_flat_size, 0); + cb(attn_out_1d, "attn_out_1d", il); + + ggml_tensor * attn_out_final = ggml_cont_4d(ctx0, attn_out_1d, head_v_dim, num_v_heads, n_seq_tokens, n_seqs); + cb(attn_out_final, "attn_out_reshaped", il); + + // Extract the state part (second part of the concatenated tensor) + // State starts after n_tokens elements along dimension 1 + const int64_t state_flat_size = head_v_dim * head_v_dim * num_v_heads * n_seqs; + + ggml_tensor * state_1d = + ggml_view_1d(ctx0, attn_out, state_flat_size, output_flat_size * ggml_element_size(attn_out)); + cb(state_1d, "state_1d", il); + + // Update the recurrent states + ggml_build_forward_expand(gf, + ggml_cpy(ctx0, state_1d, + ggml_view_1d(ctx0, ssm_states_all, hparams.n_embd_s() * n_seqs, + kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all)))); + + GGML_ASSERT(ggml_nelements(attn_out_1d) + ggml_nelements(state_1d) == ggml_nelements(attn_out)); + + // Reshape both attn_out_final and z to 2D tensors for normalization + // attn_out_final: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim] + ggml_tensor * attn_out_2d_final = + ggml_cont_2d(ctx0, attn_out_final, head_v_dim, num_v_heads * n_seq_tokens * n_seqs); + + // z: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim] + ggml_tensor * z_2d = ggml_cont_2d(ctx0, z, head_v_dim, num_v_heads * n_seq_tokens * n_seqs); + + // Apply gated normalization: self.norm(core_attn_out, z) + ggml_tensor * attn_out_norm = build_norm_gated(attn_out_2d_final, model.layers[il].ssm_norm, z_2d, il); + + // Final reshape: [head_dim, n_heads, n_tokens, n_seqs] -> [n_tokens, n_seqs, n_heads * head_dim] + ggml_tensor * final_output = ggml_reshape_3d(ctx0, attn_out_norm, head_v_dim * num_v_heads, n_seq_tokens, n_seqs); + cb(final_output, "final_output", il); + + // Output projection + cur = 
build_lora_mm(model.layers[il].ssm_out, final_output); + cb(cur, "linear_attn_out", il); + + // Reshape back to original dimensions + cur = ggml_cont_2d(ctx0, cur, n_embd, n_seq_tokens * n_seqs); + return cur; +} + +ggml_tensor * llm_build_qwen3next::build_layer_ffn(ggml_tensor * cur, const int il) { + // Check if this is an MoE layer + if (model.layers[il].ffn_gate_inp != nullptr) { + // MoE branch + ggml_tensor * moe_out = + build_moe_ffn(cur, + model.layers[il].ffn_gate_inp, model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, model.layers[il].ffn_down_exps, + nullptr, + n_expert, n_expert_used, LLM_FFN_SILU, + true, false, 0.0, LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il); + cb(moe_out, "ffn_moe_out", il); + + // Add shared experts if present - following Qwen3Next reference implementation + if (model.layers[il].ffn_up_shexp != nullptr) { + ggml_tensor * ffn_shexp = + build_ffn(cur, + model.layers[il].ffn_up_shexp, NULL, NULL, + model.layers[il].ffn_gate_shexp, NULL, NULL, + model.layers[il].ffn_down_shexp, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(ffn_shexp, "ffn_shexp", il); + + // Apply shared expert gating as in the reference implementation + // The shared expert has its own gate that is sigmoided + // Note: ffn_gate_inp_shexp is the shared expert gate (outputs 1 value per token) + ggml_tensor * shared_gate = build_lora_mm(model.layers[il].ffn_gate_inp_shexp, cur); + cb(shared_gate, "shared_expert_gate", il); + + // Apply sigmoid to the gate + shared_gate = ggml_sigmoid(ctx0, shared_gate); + cb(shared_gate, "shared_expert_gate_sigmoid", il); + + // The gate needs to be broadcast to match the dimensions of ffn_shexp + // ffn_shexp is [n_embd, n_tokens, 1, 1] and shared_gate is [1, n_tokens, 1, 1] + // We need to repeat the gate along the feature dimension + shared_gate = ggml_repeat(ctx0, shared_gate, ffn_shexp); + cb(shared_gate, "shared_expert_gate_broadcast", il); + + // Apply the gate to the shared expert output + ffn_shexp = ggml_mul(ctx0, ffn_shexp, shared_gate); + cb(ffn_shexp, "ffn_shexp_gated", il); + + cur = ggml_add(ctx0, moe_out, ffn_shexp); + cb(cur, "ffn_out", il); + } else { + cur = moe_out; + } + } else { + // Dense FFN branch (not currently used I believe) + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + } + return cur; +} From ddf9f94389a614ce005347f1c3f60ce477df1be1 Mon Sep 17 00:00:00 2001 From: Fredrik Hultin Date: Fri, 28 Nov 2025 12:57:04 +0100 Subject: [PATCH 21/33] server : add Anthropic Messages API support (#17570) * server : add Anthropic Messages API support * remove -@pytest.mark.slow from tool calling/jinja tests * server : remove unused code and slow/skip on test_anthropic_vision_base64_with_multimodal_model in test_anthropic_api.py * server : removed redundant n field logic in anthropic_params_from_json * server : use single error object instead of error_array in streaming response handler for /v1/chat/completions and use unordered_set instead of set in to_json_anthropic_stream() * server : refactor Anthropic API to use OAI conversion * make sure basic test always go first * clean up * clean up api key check, add test --------- Co-authored-by: Xuan Son Nguyen --- tools/server/README.md | 72 ++ tools/server/server-common.cpp | 242 +++++- tools/server/server-common.h | 8 +- tools/server/server-http.cpp | 21 +- tools/server/server-task.cpp | 304 ++++++- 
tools/server/server-task.h | 47 +- tools/server/server.cpp | 97 ++- tools/server/tests/conftest.py | 6 + tools/server/tests/unit/test_basic.py | 6 - .../tests/unit/test_compat_anthropic.py | 807 ++++++++++++++++++ tools/server/tests/unit/test_security.py | 13 + 11 files changed, 1553 insertions(+), 70 deletions(-) create mode 100644 tools/server/tests/unit/test_compat_anthropic.py diff --git a/tools/server/README.md b/tools/server/README.md index b7fc565ec6..f42bc7921c 100644 --- a/tools/server/README.md +++ b/tools/server/README.md @@ -7,6 +7,7 @@ Set of LLM REST APIs and a simple web front end to interact with llama.cpp. **Features:** * LLM inference of F16 and quantized models on GPU and CPU * [OpenAI API](https://github.com/openai/openai-openapi) compatible chat completions and embeddings routes + * [Anthropic Messages API](https://docs.anthropic.com/en/api/messages) compatible chat completions * Reranking endpoint (https://github.com/ggml-org/llama.cpp/pull/9510) * Parallel decoding with multi-user support * Continuous batching @@ -1352,6 +1353,77 @@ See [OpenAI Embeddings API documentation](https://platform.openai.com/docs/api-r }' ``` +### POST `/v1/messages`: Anthropic-compatible Messages API + +Given a list of `messages`, returns the assistant's response. Streaming is supported via Server-Sent Events. While no strong claims of compatibility with the Anthropic API spec are made, in our experience it suffices to support many apps. + +*Options:* + +See [Anthropic Messages API documentation](https://docs.anthropic.com/en/api/messages). Tool use requires `--jinja` flag. + +`model`: Model identifier (required) + +`messages`: Array of message objects with `role` and `content` (required) + +`max_tokens`: Maximum tokens to generate (default: 4096) + +`system`: System prompt as string or array of content blocks + +`temperature`: Sampling temperature 0-1 (default: 1.0) + +`top_p`: Nucleus sampling (default: 1.0) + +`top_k`: Top-k sampling + +`stop_sequences`: Array of stop sequences + +`stream`: Enable streaming (default: false) + +`tools`: Array of tool definitions (requires `--jinja`) + +`tool_choice`: Tool selection mode (`{"type": "auto"}`, `{"type": "any"}`, or `{"type": "tool", "name": "..."}`) + +*Examples:* + +```shell +curl http://localhost:8080/v1/messages \ + -H "Content-Type: application/json" \ + -H "x-api-key: your-api-key" \ + -d '{ + "model": "gpt-4", + "max_tokens": 1024, + "system": "You are a helpful assistant.", + "messages": [ + {"role": "user", "content": "Hello!"} + ] + }' +``` + +### POST `/v1/messages/count_tokens`: Token Counting + +Counts the number of tokens in a request without generating a response. + +Accepts the same parameters as `/v1/messages`. The `max_tokens` parameter is not required. 
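+Since the request goes through the same conversion as `/v1/messages`, the returned count is expected to reflect the templated prompt rather than the raw message text. A request that also carries a system prompt is counted the same way (illustrative sketch):
+
+```shell
+curl http://localhost:8080/v1/messages/count_tokens \
+  -H "Content-Type: application/json" \
+  -d '{
+    "model": "gpt-4",
+    "system": "You are a helpful assistant.",
+    "messages": [
+      {"role": "user", "content": "Hello!"}
+    ]
+  }'
+```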
+ +*Example:* + +```shell +curl http://localhost:8080/v1/messages/count_tokens \ + -H "Content-Type: application/json" \ + -d '{ + "model": "gpt-4", + "messages": [ + {"role": "user", "content": "Hello!"} + ] + }' +``` + +*Response:* + +```json +{"input_tokens": 10} +``` + ## More examples ### Interactive mode diff --git a/tools/server/server-common.cpp b/tools/server/server-common.cpp index 18328f3afb..0bbc4e858f 100644 --- a/tools/server/server-common.cpp +++ b/tools/server/server-common.cpp @@ -725,7 +725,6 @@ std::vector tokenize_input_prompts(const llama_vocab * vocab, mtm return result; } - // // OAI utils // @@ -1048,6 +1047,222 @@ json oaicompat_chat_params_parse( return llama_params; } +json convert_anthropic_to_oai(const json & body) { + json oai_body; + + // Convert system prompt + json oai_messages = json::array(); + auto system_param = json_value(body, "system", json()); + if (!system_param.is_null()) { + std::string system_content; + + if (system_param.is_string()) { + system_content = system_param.get(); + } else if (system_param.is_array()) { + for (const auto & block : system_param) { + if (json_value(block, "type", std::string()) == "text") { + system_content += json_value(block, "text", std::string()); + } + } + } + + oai_messages.push_back({ + {"role", "system"}, + {"content", system_content} + }); + } + + // Convert messages + if (!body.contains("messages")) { + throw std::runtime_error("'messages' is required"); + } + const json & messages = body.at("messages"); + if (messages.is_array()) { + for (const auto & msg : messages) { + std::string role = json_value(msg, "role", std::string()); + + if (!msg.contains("content")) { + if (role == "assistant") { + continue; + } + oai_messages.push_back(msg); + continue; + } + + const json & content = msg.at("content"); + + if (content.is_string()) { + oai_messages.push_back(msg); + continue; + } + + if (!content.is_array()) { + oai_messages.push_back(msg); + continue; + } + + json tool_calls = json::array(); + json converted_content = json::array(); + json tool_results = json::array(); + bool has_tool_calls = false; + + for (const auto & block : content) { + std::string type = json_value(block, "type", std::string()); + + if (type == "text") { + converted_content.push_back(block); + } else if (type == "image") { + json source = json_value(block, "source", json::object()); + std::string source_type = json_value(source, "type", std::string()); + + if (source_type == "base64") { + std::string media_type = json_value(source, "media_type", std::string("image/jpeg")); + std::string data = json_value(source, "data", std::string()); + std::ostringstream ss; + ss << "data:" << media_type << ";base64," << data; + + converted_content.push_back({ + {"type", "image_url"}, + {"image_url", { + {"url", ss.str()} + }} + }); + } else if (source_type == "url") { + std::string url = json_value(source, "url", std::string()); + converted_content.push_back({ + {"type", "image_url"}, + {"image_url", { + {"url", url} + }} + }); + } + } else if (type == "tool_use") { + tool_calls.push_back({ + {"id", json_value(block, "id", std::string())}, + {"type", "function"}, + {"function", { + {"name", json_value(block, "name", std::string())}, + {"arguments", json_value(block, "input", json::object()).dump()} + }} + }); + has_tool_calls = true; + } else if (type == "tool_result") { + std::string tool_use_id = json_value(block, "tool_use_id", std::string()); + + auto result_content = json_value(block, "content", json()); + std::string result_text; + if 
(result_content.is_string()) { + result_text = result_content.get(); + } else if (result_content.is_array()) { + for (const auto & c : result_content) { + if (json_value(c, "type", std::string()) == "text") { + result_text += json_value(c, "text", std::string()); + } + } + } + + tool_results.push_back({ + {"role", "tool"}, + {"tool_call_id", tool_use_id}, + {"content", result_text} + }); + } + } + + if (!converted_content.empty() || has_tool_calls) { + json new_msg = {{"role", role}}; + if (!converted_content.empty()) { + new_msg["content"] = converted_content; + } else if (has_tool_calls) { + new_msg["content"] = ""; + } + if (!tool_calls.empty()) { + new_msg["tool_calls"] = tool_calls; + } + oai_messages.push_back(new_msg); + } + + for (const auto & tool_msg : tool_results) { + oai_messages.push_back(tool_msg); + } + } + } + + oai_body["messages"] = oai_messages; + + // Convert tools + if (body.contains("tools")) { + const json & tools = body.at("tools"); + if (tools.is_array()) { + json oai_tools = json::array(); + for (const auto & tool : tools) { + oai_tools.push_back({ + {"type", "function"}, + {"function", { + {"name", json_value(tool, "name", std::string())}, + {"description", json_value(tool, "description", std::string())}, + {"parameters", tool.contains("input_schema") ? tool.at("input_schema") : json::object()} + }} + }); + } + oai_body["tools"] = oai_tools; + } + } + + // Convert tool_choice + if (body.contains("tool_choice")) { + const json & tc = body.at("tool_choice"); + if (tc.is_object()) { + std::string type = json_value(tc, "type", std::string()); + if (type == "auto") { + oai_body["tool_choice"] = "auto"; + } else if (type == "any" || type == "tool") { + oai_body["tool_choice"] = "required"; + } + } + } + + // Convert stop_sequences to stop + if (body.contains("stop_sequences")) { + oai_body["stop"] = body.at("stop_sequences"); + } + + // Handle max_tokens (required in Anthropic, but we're permissive) + if (body.contains("max_tokens")) { + oai_body["max_tokens"] = body.at("max_tokens"); + } else { + oai_body["max_tokens"] = 4096; + } + + // Pass through common params + for (const auto & key : {"temperature", "top_p", "top_k", "stream"}) { + if (body.contains(key)) { + oai_body[key] = body.at(key); + } + } + + // Handle Anthropic-specific thinking param + if (body.contains("thinking")) { + json thinking = json_value(body, "thinking", json::object()); + std::string thinking_type = json_value(thinking, "type", std::string()); + if (thinking_type == "enabled") { + int budget_tokens = json_value(thinking, "budget_tokens", 10000); + oai_body["thinking_budget_tokens"] = budget_tokens; + } + } + + // Handle Anthropic-specific metadata param + if (body.contains("metadata")) { + json metadata = json_value(body, "metadata", json::object()); + std::string user_id = json_value(metadata, "user_id", std::string()); + if (!user_id.empty()) { + oai_body["__metadata_user_id"] = user_id; + } + } + + return oai_body; +} + json format_embeddings_response_oaicompat(const json & request, const json & embeddings, bool use_base64) { json data = json::array(); int32_t n_tokens = 0; @@ -1211,7 +1426,7 @@ std::string tokens_to_output_formatted_string(const llama_context * ctx, const l // format server-sent event (SSE), return the formatted string to send // note: if data is a json array, it will be sent as multiple events, one per item -std::string format_sse(const json & data) { +std::string format_oai_sse(const json & data) { std::ostringstream ss; auto send_single = [&ss](const json & data) { 
ss << "data: " << @@ -1230,6 +1445,29 @@ std::string format_sse(const json & data) { return ss.str(); } +std::string format_anthropic_sse(const json & data) { + std::ostringstream ss; + + auto send_event = [&ss](const json & event_obj) { + if (event_obj.contains("event") && event_obj.contains("data")) { + ss << "event: " << event_obj.at("event").get() << "\n"; + ss << "data: " << safe_json_to_str(event_obj.at("data")) << "\n\n"; + } else { + ss << "data: " << safe_json_to_str(event_obj) << "\n\n"; + } + }; + + if (data.is_array()) { + for (const auto & event : data) { + send_event(event); + } + } else { + send_event(data); + } + + return ss.str(); +} + bool is_valid_utf8(const std::string & str) { const unsigned char* bytes = reinterpret_cast(str.data()); const unsigned char* end = bytes + str.length(); diff --git a/tools/server/server-common.h b/tools/server/server-common.h index 868c506103..ab8aabbad0 100644 --- a/tools/server/server-common.h +++ b/tools/server/server-common.h @@ -294,6 +294,9 @@ json oaicompat_chat_params_parse( const oaicompat_parser_options & opt, std::vector & out_files); +// convert Anthropic Messages API format to OpenAI Chat Completions API format +json convert_anthropic_to_oai(const json & body); + // TODO: move it to server-task.cpp json format_embeddings_response_oaicompat(const json & request, const json & embeddings, bool use_base64 = false); @@ -320,7 +323,10 @@ std::string tokens_to_output_formatted_string(const llama_context * ctx, const l // format server-sent event (SSE), return the formatted string to send // note: if data is a json array, it will be sent as multiple events, one per item -std::string format_sse(const json & data); +std::string format_oai_sse(const json & data); + +// format Anthropic-style SSE with event types +std::string format_anthropic_sse(const json & data); bool is_valid_utf8(const std::string & str); diff --git a/tools/server/server-http.cpp b/tools/server/server-http.cpp index a82aa86b19..622505714c 100644 --- a/tools/server/server-http.cpp +++ b/tools/server/server-http.cpp @@ -136,15 +136,22 @@ bool server_http_context::init(const common_params & params) { return true; } - // Check for API key in the header - auto auth_header = req.get_header_value("Authorization"); + // Check for API key in the Authorization header + std::string req_api_key = req.get_header_value("Authorization"); + if (req_api_key.empty()) { + // retry with anthropic header + req_api_key = req.get_header_value("X-Api-Key"); + } + // remove the "Bearer " prefix if needed std::string prefix = "Bearer "; - if (auth_header.substr(0, prefix.size()) == prefix) { - std::string received_api_key = auth_header.substr(prefix.size()); - if (std::find(api_keys.begin(), api_keys.end(), received_api_key) != api_keys.end()) { - return true; // API key is valid - } + if (req_api_key.substr(0, prefix.size()) == prefix) { + req_api_key = req_api_key.substr(prefix.size()); + } + + // validate the API key + if (std::find(api_keys.begin(), api_keys.end(), req_api_key) != api_keys.end()) { + return true; // API key is valid } // API key is invalid or not provided diff --git a/tools/server/server-task.cpp b/tools/server/server-task.cpp index bc4436ba65..b447a1ef6d 100644 --- a/tools/server/server-task.cpp +++ b/tools/server/server-task.cpp @@ -565,15 +565,17 @@ std::vector completion_token_output::str_to_bytes(const std::stri // server_task_result_cmpl_final // json server_task_result_cmpl_final::to_json() { - switch (oaicompat) { - case OAICOMPAT_TYPE_NONE: + switch (res_type) { + 
case TASK_RESPONSE_TYPE_NONE: return to_json_non_oaicompat(); - case OAICOMPAT_TYPE_COMPLETION: + case TASK_RESPONSE_TYPE_OAI_CMPL: return to_json_oaicompat(); - case OAICOMPAT_TYPE_CHAT: + case TASK_RESPONSE_TYPE_OAI_CHAT: return stream ? to_json_oaicompat_chat_stream() : to_json_oaicompat_chat(); + case TASK_RESPONSE_TYPE_ANTHROPIC: + return stream ? to_json_anthropic_stream() : to_json_anthropic(); default: - GGML_ASSERT(false && "Invalid oaicompat_type"); + GGML_ASSERT(false && "Invalid task_response_type"); } } @@ -768,19 +770,203 @@ json server_task_result_cmpl_final::to_json_oaicompat_chat_stream() { return deltas; } +json server_task_result_cmpl_final::to_json_anthropic() { + std::string stop_reason = "max_tokens"; + if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) { + stop_reason = oaicompat_msg.tool_calls.empty() ? "end_turn" : "tool_use"; + } + + json content_blocks = json::array(); + + common_chat_msg msg; + if (!oaicompat_msg.empty()) { + msg = oaicompat_msg; + } else { + msg.role = "assistant"; + msg.content = content; + } + + if (!msg.content.empty()) { + content_blocks.push_back({ + {"type", "text"}, + {"text", msg.content} + }); + } + + for (const auto & tool_call : msg.tool_calls) { + json tool_use_block = { + {"type", "tool_use"}, + {"id", tool_call.id}, + {"name", tool_call.name} + }; + + try { + tool_use_block["input"] = json::parse(tool_call.arguments); + } catch (const std::exception &) { + tool_use_block["input"] = json::object(); + } + + content_blocks.push_back(tool_use_block); + } + + json res = { + {"id", oaicompat_cmpl_id}, + {"type", "message"}, + {"role", "assistant"}, + {"content", content_blocks}, + {"model", oaicompat_model}, + {"stop_reason", stop_reason}, + {"stop_sequence", stopping_word.empty() ? nullptr : json(stopping_word)}, + {"usage", { + {"input_tokens", n_prompt_tokens}, + {"output_tokens", n_decoded} + }} + }; + + return res; +} + +json server_task_result_cmpl_final::to_json_anthropic_stream() { + json events = json::array(); + + std::string stop_reason = "max_tokens"; + if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) { + stop_reason = oaicompat_msg.tool_calls.empty() ? "end_turn" : "tool_use"; + } + + bool has_text = !oaicompat_msg.content.empty(); + size_t num_tool_calls = oaicompat_msg.tool_calls.size(); + + bool text_block_started = false; + std::unordered_set tool_calls_started; + + for (const auto & diff : oaicompat_msg_diffs) { + if (!diff.content_delta.empty()) { + if (!text_block_started) { + events.push_back({ + {"event", "content_block_start"}, + {"data", { + {"type", "content_block_start"}, + {"index", 0}, + {"content_block", { + {"type", "text"}, + {"text", ""} + }} + }} + }); + text_block_started = true; + } + + events.push_back({ + {"event", "content_block_delta"}, + {"data", { + {"type", "content_block_delta"}, + {"index", 0}, + {"delta", { + {"type", "text_delta"}, + {"text", diff.content_delta} + }} + }} + }); + } + + if (diff.tool_call_index != std::string::npos) { + size_t content_block_index = (has_text ? 
1 : 0) + diff.tool_call_index; + + if (tool_calls_started.find(diff.tool_call_index) == tool_calls_started.end()) { + const auto & full_tool_call = oaicompat_msg.tool_calls[diff.tool_call_index]; + + events.push_back({ + {"event", "content_block_start"}, + {"data", { + {"type", "content_block_start"}, + {"index", content_block_index}, + {"content_block", { + {"type", "tool_use"}, + {"id", full_tool_call.id}, + {"name", full_tool_call.name} + }} + }} + }); + tool_calls_started.insert(diff.tool_call_index); + } + + if (!diff.tool_call_delta.arguments.empty()) { + events.push_back({ + {"event", "content_block_delta"}, + {"data", { + {"type", "content_block_delta"}, + {"index", content_block_index}, + {"delta", { + {"type", "input_json_delta"}, + {"partial_json", diff.tool_call_delta.arguments} + }} + }} + }); + } + } + } + + if (has_text) { + events.push_back({ + {"event", "content_block_stop"}, + {"data", { + {"type", "content_block_stop"}, + {"index", 0} + }} + }); + } + + for (size_t i = 0; i < num_tool_calls; i++) { + size_t content_block_index = (has_text ? 1 : 0) + i; + events.push_back({ + {"event", "content_block_stop"}, + {"data", { + {"type", "content_block_stop"}, + {"index", content_block_index} + }} + }); + } + + events.push_back({ + {"event", "message_delta"}, + {"data", { + {"type", "message_delta"}, + {"delta", { + {"stop_reason", stop_reason}, + {"stop_sequence", stopping_word.empty() ? nullptr : json(stopping_word)} + }}, + {"usage", { + {"output_tokens", n_decoded} + }} + }} + }); + + events.push_back({ + {"event", "message_stop"}, + {"data", { + {"type", "message_stop"} + }} + }); + + return events; +} + // // server_task_result_cmpl_partial // json server_task_result_cmpl_partial::to_json() { - switch (oaicompat) { - case OAICOMPAT_TYPE_NONE: + switch (res_type) { + case TASK_RESPONSE_TYPE_NONE: return to_json_non_oaicompat(); - case OAICOMPAT_TYPE_COMPLETION: + case TASK_RESPONSE_TYPE_OAI_CMPL: return to_json_oaicompat(); - case OAICOMPAT_TYPE_CHAT: + case TASK_RESPONSE_TYPE_OAI_CHAT: return to_json_oaicompat_chat(); + case TASK_RESPONSE_TYPE_ANTHROPIC: + return to_json_anthropic(); default: - GGML_ASSERT(false && "Invalid oaicompat_type"); + GGML_ASSERT(false && "Invalid task_response_type"); } } @@ -905,7 +1091,7 @@ json server_task_result_cmpl_partial::to_json_oaicompat_chat() { // server_task_result_embd // json server_task_result_embd::to_json() { - return oaicompat == OAICOMPAT_TYPE_EMBEDDING + return res_type == TASK_RESPONSE_TYPE_OAI_EMBD ? 
to_json_oaicompat() : to_json_non_oaicompat(); } @@ -936,6 +1122,102 @@ json server_task_result_rerank::to_json() { }; } +json server_task_result_cmpl_partial::to_json_anthropic() { + json events = json::array(); + bool first = (n_decoded == 1); + static bool text_block_started = false; + + if (first) { + text_block_started = false; + + events.push_back({ + {"event", "message_start"}, + {"data", { + {"type", "message_start"}, + {"message", { + {"id", oaicompat_cmpl_id}, + {"type", "message"}, + {"role", "assistant"}, + {"content", json::array()}, + {"model", oaicompat_model}, + {"stop_reason", nullptr}, + {"stop_sequence", nullptr}, + {"usage", { + {"input_tokens", n_prompt_tokens}, + {"output_tokens", 0} + }} + }} + }} + }); + } + + for (const auto & diff : oaicompat_msg_diffs) { + if (!diff.content_delta.empty()) { + if (!text_block_started) { + events.push_back({ + {"event", "content_block_start"}, + {"data", { + {"type", "content_block_start"}, + {"index", 0}, + {"content_block", { + {"type", "text"}, + {"text", ""} + }} + }} + }); + text_block_started = true; + } + + events.push_back({ + {"event", "content_block_delta"}, + {"data", { + {"type", "content_block_delta"}, + {"index", 0}, + {"delta", { + {"type", "text_delta"}, + {"text", diff.content_delta} + }} + }} + }); + } + + if (diff.tool_call_index != std::string::npos) { + size_t content_block_index = (text_block_started ? 1 : 0) + diff.tool_call_index; + + if (!diff.tool_call_delta.name.empty()) { + events.push_back({ + {"event", "content_block_start"}, + {"data", { + {"type", "content_block_start"}, + {"index", content_block_index}, + {"content_block", { + {"type", "tool_use"}, + {"id", diff.tool_call_delta.id}, + {"name", diff.tool_call_delta.name} + }} + }} + }); + } + + if (!diff.tool_call_delta.arguments.empty()) { + events.push_back({ + {"event", "content_block_delta"}, + {"data", { + {"type", "content_block_delta"}, + {"index", content_block_index}, + {"delta", { + {"type", "input_json_delta"}, + {"partial_json", diff.tool_call_delta.arguments} + }} + }} + }); + } + } + } + + return events; +} + // // server_task_result_error // diff --git a/tools/server/server-task.h b/tools/server/server-task.h index 0271caae11..a22d7cab11 100644 --- a/tools/server/server-task.h +++ b/tools/server/server-task.h @@ -27,11 +27,12 @@ enum server_task_type { }; // TODO: change this to more generic "response_format" to replace the "format_response_*" in server-common -enum oaicompat_type { - OAICOMPAT_TYPE_NONE, - OAICOMPAT_TYPE_CHAT, - OAICOMPAT_TYPE_COMPLETION, - OAICOMPAT_TYPE_EMBEDDING, +enum task_response_type { + TASK_RESPONSE_TYPE_NONE, // llama.cpp native format + TASK_RESPONSE_TYPE_OAI_CHAT, + TASK_RESPONSE_TYPE_OAI_CMPL, + TASK_RESPONSE_TYPE_OAI_EMBD, + TASK_RESPONSE_TYPE_ANTHROPIC, }; enum stop_type { @@ -66,9 +67,9 @@ struct task_params { struct common_params_sampling sampling; struct common_params_speculative speculative; - // OAI-compat fields + // response formatting bool verbose = false; - oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE; + task_response_type res_type = TASK_RESPONSE_TYPE_NONE; std::string oaicompat_model; std::string oaicompat_cmpl_id; common_chat_syntax oaicompat_chat_syntax; @@ -227,12 +228,12 @@ struct server_task_result_cmpl_final : server_task_result { task_params generation_params; - // OAI-compat fields - bool verbose = false; - oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE; - std::string oaicompat_model; - std::string oaicompat_cmpl_id; - common_chat_msg oaicompat_msg; + // response formatting + bool 
verbose = false; + task_response_type res_type = TASK_RESPONSE_TYPE_NONE; + std::string oaicompat_model; + std::string oaicompat_cmpl_id; + common_chat_msg oaicompat_msg; std::vector oaicompat_msg_diffs; @@ -253,6 +254,10 @@ struct server_task_result_cmpl_final : server_task_result { json to_json_oaicompat_chat(); json to_json_oaicompat_chat_stream(); + + json to_json_anthropic(); + + json to_json_anthropic_stream(); }; struct server_task_result_cmpl_partial : server_task_result { @@ -270,11 +275,11 @@ struct server_task_result_cmpl_partial : server_task_result { result_timings timings; result_prompt_progress progress; - // OAI-compat fields - bool verbose = false; - oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE; - std::string oaicompat_model; - std::string oaicompat_cmpl_id; + // response formatting + bool verbose = false; + task_response_type res_type = TASK_RESPONSE_TYPE_NONE; + std::string oaicompat_model; + std::string oaicompat_cmpl_id; std::vector oaicompat_msg_diffs; virtual int get_index() override { @@ -292,6 +297,8 @@ struct server_task_result_cmpl_partial : server_task_result { json to_json_oaicompat(); json to_json_oaicompat_chat(); + + json to_json_anthropic(); }; struct server_task_result_embd : server_task_result { @@ -300,8 +307,8 @@ struct server_task_result_embd : server_task_result { int32_t n_tokens; - // OAI-compat fields - oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE; + // response formatting + task_response_type res_type = TASK_RESPONSE_TYPE_NONE; virtual int get_index() override { return index; diff --git a/tools/server/server.cpp b/tools/server/server.cpp index 0f39def379..05bbe648c1 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -1255,7 +1255,7 @@ struct server_context { res->post_sampling_probs = slot.task->params.post_sampling_probs; res->verbose = slot.task->params.verbose; - res->oaicompat = slot.task->params.oaicompat; + res->res_type = slot.task->params.res_type; res->oaicompat_model = slot.task->params.oaicompat_model; res->oaicompat_cmpl_id = slot.task->params.oaicompat_cmpl_id; @@ -1297,7 +1297,7 @@ struct server_context { res->verbose = slot.task->params.verbose; res->stream = slot.task->params.stream; res->include_usage = slot.task->params.include_usage; - res->oaicompat = slot.task->params.oaicompat; + res->res_type = slot.task->params.res_type; res->oaicompat_model = slot.task->params.oaicompat_model; res->oaicompat_cmpl_id = slot.task->params.oaicompat_cmpl_id; res->oaicompat_msg = slot.update_chat_msg(res->oaicompat_msg_diffs); @@ -1328,7 +1328,7 @@ struct server_context { res->id = slot.task->id; res->index = slot.task->index; res->n_tokens = slot.task->n_tokens(); - res->oaicompat = slot.task->params.oaicompat; + res->res_type = slot.task->params.res_type; const int n_embd = llama_model_n_embd(model); @@ -2951,7 +2951,7 @@ public: data, files, req.should_stop, - OAICOMPAT_TYPE_NONE); // infill is not OAI compatible + TASK_RESPONSE_TYPE_NONE); // infill is not OAI compatible }; server_http_context::handler_t post_completions = [this](const server_http_req & req) { @@ -2962,7 +2962,7 @@ public: body, files, req.should_stop, - OAICOMPAT_TYPE_NONE); + TASK_RESPONSE_TYPE_NONE); }; server_http_context::handler_t post_completions_oai = [this](const server_http_req & req) { @@ -2973,7 +2973,7 @@ public: body, files, req.should_stop, - OAICOMPAT_TYPE_COMPLETION); + TASK_RESPONSE_TYPE_OAI_CMPL); }; server_http_context::handler_t post_chat_completions = [this](const server_http_req & req) { @@ -2988,7 +2988,38 @@ public: 
body_parsed, files, req.should_stop, - OAICOMPAT_TYPE_CHAT); + TASK_RESPONSE_TYPE_OAI_CHAT); + }; + + server_http_context::handler_t post_anthropic_messages = [this](const server_http_req & req) { + std::vector files; + json body = convert_anthropic_to_oai(json::parse(req.body)); + json body_parsed = oaicompat_chat_params_parse( + body, + ctx_server.oai_parser_opt, + files); + return handle_completions_impl( + SERVER_TASK_TYPE_COMPLETION, + body_parsed, + files, + req.should_stop, + TASK_RESPONSE_TYPE_ANTHROPIC); + }; + + server_http_context::handler_t post_anthropic_count_tokens = [this](const server_http_req & req) { + auto res = std::make_unique(ctx_server); + std::vector files; + json body = convert_anthropic_to_oai(json::parse(req.body)); + json body_parsed = oaicompat_chat_params_parse( + body, + ctx_server.oai_parser_opt, + files); + + json prompt = body_parsed.at("prompt"); + llama_tokens tokens = tokenize_mixed(ctx_server.vocab, prompt, true, true); + + res->ok({{"input_tokens", static_cast(tokens.size())}}); + return res; }; // same with handle_chat_completions, but without inference part @@ -3107,11 +3138,11 @@ public: }; server_http_context::handler_t post_embeddings = [this](const server_http_req & req) { - return handle_embeddings_impl(req, OAICOMPAT_TYPE_NONE); + return handle_embeddings_impl(req, TASK_RESPONSE_TYPE_NONE); }; server_http_context::handler_t post_embeddings_oai = [this](const server_http_req & req) { - return handle_embeddings_impl(req, OAICOMPAT_TYPE_EMBEDDING); + return handle_embeddings_impl(req, TASK_RESPONSE_TYPE_OAI_EMBD); }; server_http_context::handler_t post_rerank = [this](const server_http_req & req) { @@ -3262,7 +3293,7 @@ private: const json & data, const std::vector & files, const std::function & should_stop, - oaicompat_type oaicompat) { + task_response_type res_type) { GGML_ASSERT(type == SERVER_TASK_TYPE_COMPLETION || type == SERVER_TASK_TYPE_INFILL); auto res = std::make_unique(ctx_server); @@ -3279,7 +3310,7 @@ private: // process prompt std::vector inputs; - if (oaicompat && ctx_server.mctx != nullptr) { + if (res_type != TASK_RESPONSE_TYPE_NONE && ctx_server.mctx != nullptr) { // This is the case used by OAI compatible chat path with MTMD. TODO It can be moved to the path below. 
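            // note: this branch also covers the Anthropic /v1/messages path, since that request body is
            // converted to the OAI chat format (convert_anthropic_to_oai) before reaching this point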
inputs.push_back(process_mtmd_prompt(ctx_server.mctx, prompt.get(), files)); } else { @@ -3301,8 +3332,8 @@ private: task.id_slot = json_value(data, "id_slot", -1); // OAI-compat - task.params.oaicompat = oaicompat; - task.params.oaicompat_cmpl_id = completion_id; + task.params.res_type = res_type; + task.params.oaicompat_cmpl_id = completion_id; // oaicompat_model is already populated by params_from_json_cmpl tasks.push_back(std::move(task)); @@ -3352,10 +3383,14 @@ private: } // next responses are streamed - res->data = format_sse(first_result->to_json()); // to be sent immediately + if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) { + res->data = format_anthropic_sse(first_result->to_json()); + } else { + res->data = format_oai_sse(first_result->to_json()); // to be sent immediately + } res->status = 200; res->content_type = "text/event-stream"; - res->next = [res_this = res.get(), oaicompat, &should_stop](std::string & output) -> bool { + res->next = [res_this = res.get(), res_type, &should_stop](std::string & output) -> bool { if (should_stop()) { SRV_DBG("%s", "stopping streaming due to should_stop condition\n"); return false; // should_stop condition met @@ -3372,7 +3407,10 @@ private: // check if there is more data if (!rd.has_next()) { - if (oaicompat != OAICOMPAT_TYPE_NONE) { + if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) { + // Anthropic doesn't send [DONE], message_stop was already sent + output = ""; + } else if (res_type != TASK_RESPONSE_TYPE_NONE) { output = "data: [DONE]\n\n"; } else { output = ""; @@ -3391,7 +3429,14 @@ private: // send the results json res_json = result->to_json(); if (result->is_error()) { - output = format_sse(json {{ "error", res_json }}); + if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) { + output = format_anthropic_sse({ + {"event", "error"}, + {"data", res_json}, + }); + } else { + output = format_oai_sse(json {{ "error", res_json }}); + } SRV_DBG("%s", "error received during streaming, terminating stream\n"); return false; // terminate on error } else { @@ -3399,7 +3444,11 @@ private: dynamic_cast(result.get()) != nullptr || dynamic_cast(result.get()) != nullptr ); - output = format_sse(res_json); + if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) { + output = format_anthropic_sse(res_json); + } else { + output = format_oai_sse(res_json); + } } // has next data, continue @@ -3507,14 +3556,14 @@ private: return res; } - std::unique_ptr handle_embeddings_impl(const server_http_req & req, oaicompat_type oaicompat) { + std::unique_ptr handle_embeddings_impl(const server_http_req & req, task_response_type res_type) { auto res = std::make_unique(ctx_server); if (!ctx_server.params_base.embedding) { res->error(format_error_response("This server does not support embeddings. Start it with `--embeddings`", ERROR_TYPE_NOT_SUPPORTED)); return res; } - if (oaicompat != OAICOMPAT_TYPE_NONE && llama_pooling_type(ctx_server.ctx) == LLAMA_POOLING_TYPE_NONE) { + if (res_type != TASK_RESPONSE_TYPE_NONE && llama_pooling_type(ctx_server.ctx) == LLAMA_POOLING_TYPE_NONE) { res->error(format_error_response("Pooling type 'none' is not OAI compatible. 
Please use a different pooling type", ERROR_TYPE_INVALID_REQUEST)); return res; } @@ -3526,7 +3575,7 @@ private: if (body.count("input") != 0) { prompt = body.at("input"); } else if (body.contains("content")) { - oaicompat = OAICOMPAT_TYPE_NONE; // "content" field is not OAI compatible + res_type = TASK_RESPONSE_TYPE_NONE; // "content" field is not OAI compatible prompt = body.at("content"); } else { res->error(format_error_response("\"input\" or \"content\" must be provided", ERROR_TYPE_INVALID_REQUEST)); @@ -3574,7 +3623,7 @@ private: task.tokens = std::move(tokenized_prompts[i]); // OAI-compat - task.params.oaicompat = oaicompat; + task.params.res_type = res_type; task.params.embd_normalize = embd_normalize; tasks.push_back(std::move(task)); @@ -3599,7 +3648,7 @@ private: } // write JSON response - json root = oaicompat == OAICOMPAT_TYPE_EMBEDDING + json root = res_type == TASK_RESPONSE_TYPE_OAI_EMBD ? format_embeddings_response_oaicompat(body, responses, use_base64) : json(responses); res->ok(root); @@ -3712,6 +3761,8 @@ int main(int argc, char ** argv) { ctx_http.post("/chat/completions", ex_wrapper(routes.post_chat_completions)); ctx_http.post("/v1/chat/completions", ex_wrapper(routes.post_chat_completions)); ctx_http.post("/api/chat", ex_wrapper(routes.post_chat_completions)); // ollama specific endpoint + ctx_http.post("/v1/messages", ex_wrapper(routes.post_anthropic_messages)); // anthropic messages API + ctx_http.post("/v1/messages/count_tokens", ex_wrapper(routes.post_anthropic_count_tokens)); // anthropic token counting ctx_http.post("/infill", ex_wrapper(routes.post_infill)); ctx_http.post("/embedding", ex_wrapper(routes.post_embeddings)); // legacy ctx_http.post("/embeddings", ex_wrapper(routes.post_embeddings)); diff --git a/tools/server/tests/conftest.py b/tools/server/tests/conftest.py index 017d1bb841..c7ed775968 100644 --- a/tools/server/tests/conftest.py +++ b/tools/server/tests/conftest.py @@ -13,3 +13,9 @@ def stop_server_after_each_test(): ) # copy the set to prevent 'Set changed size during iteration' for server in instances: server.stop() + + +@pytest.fixture(scope="module", autouse=True) +def do_something(): + # this will be run once per test session, before any tests + ServerPreset.load_all() diff --git a/tools/server/tests/unit/test_basic.py b/tools/server/tests/unit/test_basic.py index 720b136b05..cadaa91849 100644 --- a/tools/server/tests/unit/test_basic.py +++ b/tools/server/tests/unit/test_basic.py @@ -5,12 +5,6 @@ from utils import * server = ServerPreset.tinyllama2() -@pytest.fixture(scope="session", autouse=True) -def do_something(): - # this will be run once per test session, before any tests - ServerPreset.load_all() - - @pytest.fixture(autouse=True) def create_server(): global server diff --git a/tools/server/tests/unit/test_compat_anthropic.py b/tools/server/tests/unit/test_compat_anthropic.py new file mode 100644 index 0000000000..d55dd1d945 --- /dev/null +++ b/tools/server/tests/unit/test_compat_anthropic.py @@ -0,0 +1,807 @@ +#!/usr/bin/env python3 +import pytest +import base64 +import requests + +from utils import * + +server: ServerProcess + + +def get_test_image_base64() -> str: + """Get a test image in base64 format""" + # Use the same test image as test_vision_api.py + IMG_URL = "https://huggingface.co/ggml-org/tinygemma3-GGUF/resolve/main/test/11_truck.png" + response = requests.get(IMG_URL) + response.raise_for_status() + return base64.b64encode(response.content).decode("utf-8") + +@pytest.fixture(autouse=True) +def create_server(): + global 
server + server = ServerPreset.tinyllama2() + server.model_alias = "tinyllama-2-anthropic" + server.server_port = 8082 + server.n_slots = 1 + server.n_ctx = 8192 + server.n_batch = 2048 + + +@pytest.fixture +def vision_server(): + """Separate fixture for vision tests that require multimodal support""" + global server + server = ServerPreset.tinygemma3() + server.offline = False # Allow downloading the model + server.model_alias = "tinygemma3-anthropic" + server.server_port = 8083 # Different port to avoid conflicts + server.n_slots = 1 + return server + + +# Basic message tests + +def test_anthropic_messages_basic(): + """Test basic Anthropic messages endpoint""" + server.start() + + res = server.make_request("POST", "/v1/messages", data={ + "model": "test", + "max_tokens": 50, + "messages": [ + {"role": "user", "content": "Say hello"} + ] + }) + + assert res.status_code == 200, f"Expected 200, got {res.status_code}" + assert res.body["type"] == "message", f"Expected type 'message', got {res.body.get('type')}" + assert res.body["role"] == "assistant", f"Expected role 'assistant', got {res.body.get('role')}" + assert "content" in res.body, "Missing 'content' field" + assert isinstance(res.body["content"], list), "Content should be an array" + assert len(res.body["content"]) > 0, "Content array should not be empty" + assert res.body["content"][0]["type"] == "text", "First content block should be text" + assert "text" in res.body["content"][0], "Text content block missing 'text' field" + assert res.body["stop_reason"] in ["end_turn", "max_tokens"], f"Invalid stop_reason: {res.body.get('stop_reason')}" + assert "usage" in res.body, "Missing 'usage' field" + assert "input_tokens" in res.body["usage"], "Missing usage.input_tokens" + assert "output_tokens" in res.body["usage"], "Missing usage.output_tokens" + assert isinstance(res.body["usage"]["input_tokens"], int), "input_tokens should be integer" + assert isinstance(res.body["usage"]["output_tokens"], int), "output_tokens should be integer" + assert res.body["usage"]["output_tokens"] > 0, "Should have generated some tokens" + # Anthropic API should NOT include timings + assert "timings" not in res.body, "Anthropic API should not include timings field" + + +def test_anthropic_messages_with_system(): + """Test messages with system prompt""" + server.start() + + res = server.make_request("POST", "/v1/messages", data={ + "model": "test", + "max_tokens": 50, + "system": "You are a helpful assistant.", + "messages": [ + {"role": "user", "content": "Hello"} + ] + }) + + assert res.status_code == 200 + assert res.body["type"] == "message" + assert len(res.body["content"]) > 0 + + +def test_anthropic_messages_multipart_content(): + """Test messages with multipart content blocks""" + server.start() + + res = server.make_request("POST", "/v1/messages", data={ + "model": "test", + "max_tokens": 50, + "messages": [ + { + "role": "user", + "content": [ + {"type": "text", "text": "What is"}, + {"type": "text", "text": " the answer?"} + ] + } + ] + }) + + assert res.status_code == 200 + assert res.body["type"] == "message" + + +def test_anthropic_messages_conversation(): + """Test multi-turn conversation""" + server.start() + + res = server.make_request("POST", "/v1/messages", data={ + "model": "test", + "max_tokens": 50, + "messages": [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there!"}, + {"role": "user", "content": "How are you?"} + ] + }) + + assert res.status_code == 200 + assert res.body["type"] == "message" + + +# 
Streaming tests + +def test_anthropic_messages_streaming(): + """Test streaming messages""" + server.start() + + res = server.make_stream_request("POST", "/v1/messages", data={ + "model": "test", + "max_tokens": 30, + "messages": [ + {"role": "user", "content": "Say hello"} + ], + "stream": True + }) + + events = [] + for data in res: + # Each event should have type and other fields + assert "type" in data, f"Missing 'type' in event: {data}" + events.append(data) + + # Verify event sequence + event_types = [e["type"] for e in events] + assert "message_start" in event_types, "Missing message_start event" + assert "content_block_start" in event_types, "Missing content_block_start event" + assert "content_block_delta" in event_types, "Missing content_block_delta event" + assert "content_block_stop" in event_types, "Missing content_block_stop event" + assert "message_delta" in event_types, "Missing message_delta event" + assert "message_stop" in event_types, "Missing message_stop event" + + # Check message_start structure + message_start = next(e for e in events if e["type"] == "message_start") + assert "message" in message_start, "message_start missing 'message' field" + assert message_start["message"]["type"] == "message" + assert message_start["message"]["role"] == "assistant" + assert message_start["message"]["content"] == [] + assert "usage" in message_start["message"] + assert message_start["message"]["usage"]["input_tokens"] > 0 + + # Check content_block_start + block_start = next(e for e in events if e["type"] == "content_block_start") + assert "index" in block_start, "content_block_start missing 'index'" + assert block_start["index"] == 0, "First content block should be at index 0" + assert "content_block" in block_start + assert block_start["content_block"]["type"] == "text" + + # Check content_block_delta + deltas = [e for e in events if e["type"] == "content_block_delta"] + assert len(deltas) > 0, "Should have at least one content_block_delta" + for delta in deltas: + assert "index" in delta + assert "delta" in delta + assert delta["delta"]["type"] == "text_delta" + assert "text" in delta["delta"] + + # Check content_block_stop + block_stop = next(e for e in events if e["type"] == "content_block_stop") + assert "index" in block_stop + assert block_stop["index"] == 0 + + # Check message_delta + message_delta = next(e for e in events if e["type"] == "message_delta") + assert "delta" in message_delta + assert "stop_reason" in message_delta["delta"] + assert message_delta["delta"]["stop_reason"] in ["end_turn", "max_tokens"] + assert "usage" in message_delta + assert message_delta["usage"]["output_tokens"] > 0 + + # Check message_stop + message_stop = next(e for e in events if e["type"] == "message_stop") + # message_stop should NOT have timings for Anthropic API + assert "timings" not in message_stop, "Anthropic streaming should not include timings" + + +# Token counting tests + +def test_anthropic_count_tokens(): + """Test token counting endpoint""" + server.start() + + res = server.make_request("POST", "/v1/messages/count_tokens", data={ + "model": "test", + "messages": [ + {"role": "user", "content": "Hello world"} + ] + }) + + assert res.status_code == 200 + assert "input_tokens" in res.body + assert isinstance(res.body["input_tokens"], int) + assert res.body["input_tokens"] > 0 + # Should only have input_tokens, no other fields + assert "output_tokens" not in res.body + + +def test_anthropic_count_tokens_with_system(): + """Test token counting with system prompt""" + 
server.start() + + res = server.make_request("POST", "/v1/messages/count_tokens", data={ + "model": "test", + "system": "You are a helpful assistant.", + "messages": [ + {"role": "user", "content": "Hello"} + ] + }) + + assert res.status_code == 200 + assert res.body["input_tokens"] > 0 + + +def test_anthropic_count_tokens_no_max_tokens(): + """Test that count_tokens doesn't require max_tokens""" + server.start() + + # max_tokens is NOT required for count_tokens + res = server.make_request("POST", "/v1/messages/count_tokens", data={ + "model": "test", + "messages": [ + {"role": "user", "content": "Hello"} + ] + }) + + assert res.status_code == 200 + assert "input_tokens" in res.body + + +# Tool use tests + +def test_anthropic_tool_use_basic(): + """Test basic tool use""" + server.jinja = True + server.start() + + res = server.make_request("POST", "/v1/messages", data={ + "model": "test", + "max_tokens": 200, + "tools": [{ + "name": "get_weather", + "description": "Get the current weather in a location", + "input_schema": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "City name" + } + }, + "required": ["location"] + } + }], + "messages": [ + {"role": "user", "content": "What's the weather in Paris?"} + ] + }) + + assert res.status_code == 200 + assert res.body["type"] == "message" + assert len(res.body["content"]) > 0 + + # Check if model used the tool (it might not always, depending on the model) + content_types = [block.get("type") for block in res.body["content"]] + + if "tool_use" in content_types: + # Model used the tool + assert res.body["stop_reason"] == "tool_use" + + # Find the tool_use block + tool_block = next(b for b in res.body["content"] if b.get("type") == "tool_use") + assert "id" in tool_block + assert "name" in tool_block + assert tool_block["name"] == "get_weather" + assert "input" in tool_block + assert isinstance(tool_block["input"], dict) + + +def test_anthropic_tool_result(): + """Test sending tool results back + + This test verifies that tool_result blocks are properly converted to + role="tool" messages internally. Without proper conversion, this would + fail with a 500 error: "unsupported content[].type" because tool_result + blocks would remain in the user message content array. + """ + server.jinja = True + server.start() + + res = server.make_request("POST", "/v1/messages", data={ + "model": "test", + "max_tokens": 100, + "messages": [ + {"role": "user", "content": "What's the weather?"}, + { + "role": "assistant", + "content": [ + { + "type": "tool_use", + "id": "test123", + "name": "get_weather", + "input": {"location": "Paris"} + } + ] + }, + { + "role": "user", + "content": [ + { + "type": "tool_result", + "tool_use_id": "test123", + "content": "The weather is sunny, 25°C" + } + ] + } + ] + }) + + # This would be 500 with the old bug where tool_result blocks weren't converted + assert res.status_code == 200 + assert res.body["type"] == "message" + # Model should respond to the tool result + assert len(res.body["content"]) > 0 + assert res.body["content"][0]["type"] == "text" + + +def test_anthropic_tool_result_with_text(): + """Test tool result mixed with text content + + This tests the edge case where a user message contains both text and + tool_result blocks. The server must properly split these into separate + messages: a user message with text, followed by tool messages. 
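    (In convert_anthropic_to_oai, the text blocks stay in the user message while each
    tool_result becomes its own role="tool" message.)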
+ Without proper handling, this would fail with 500: "unsupported content[].type" + """ + server.jinja = True + server.start() + + res = server.make_request("POST", "/v1/messages", data={ + "model": "test", + "max_tokens": 100, + "messages": [ + {"role": "user", "content": "What's the weather?"}, + { + "role": "assistant", + "content": [ + { + "type": "tool_use", + "id": "tool_1", + "name": "get_weather", + "input": {"location": "Paris"} + } + ] + }, + { + "role": "user", + "content": [ + {"type": "text", "text": "Here are the results:"}, + { + "type": "tool_result", + "tool_use_id": "tool_1", + "content": "Sunny, 25°C" + } + ] + } + ] + }) + + assert res.status_code == 200 + assert res.body["type"] == "message" + assert len(res.body["content"]) > 0 + + +def test_anthropic_tool_result_error(): + """Test tool result with error flag""" + server.jinja = True + server.start() + + res = server.make_request("POST", "/v1/messages", data={ + "model": "test", + "max_tokens": 100, + "messages": [ + {"role": "user", "content": "Get the weather"}, + { + "role": "assistant", + "content": [ + { + "type": "tool_use", + "id": "test123", + "name": "get_weather", + "input": {"location": "InvalidCity"} + } + ] + }, + { + "role": "user", + "content": [ + { + "type": "tool_result", + "tool_use_id": "test123", + "is_error": True, + "content": "City not found" + } + ] + } + ] + }) + + assert res.status_code == 200 + assert res.body["type"] == "message" + + +def test_anthropic_tool_streaming(): + """Test streaming with tool use""" + server.jinja = True + server.start() + + res = server.make_stream_request("POST", "/v1/messages", data={ + "model": "test", + "max_tokens": 200, + "stream": True, + "tools": [{ + "name": "calculator", + "description": "Calculate math", + "input_schema": { + "type": "object", + "properties": { + "expression": {"type": "string"} + }, + "required": ["expression"] + } + }], + "messages": [ + {"role": "user", "content": "Calculate 2+2"} + ] + }) + + events = [] + for data in res: + events.append(data) + + event_types = [e["type"] for e in events] + + # Should have basic events + assert "message_start" in event_types + assert "message_stop" in event_types + + # If tool was used, check for proper tool streaming + if any(e.get("type") == "content_block_start" and + e.get("content_block", {}).get("type") == "tool_use" + for e in events): + # Find tool use block start + tool_starts = [e for e in events if + e.get("type") == "content_block_start" and + e.get("content_block", {}).get("type") == "tool_use"] + + assert len(tool_starts) > 0, "Should have tool_use content_block_start" + + # Check index is correct (should be 0 if no text, 1 if there's text) + tool_start = tool_starts[0] + assert "index" in tool_start + assert tool_start["content_block"]["type"] == "tool_use" + assert "name" in tool_start["content_block"] + + +# Vision/multimodal tests + +def test_anthropic_vision_format_accepted(): + """Test that Anthropic vision format is accepted (format validation only)""" + server.start() + + # Small 1x1 red PNG image in base64 + red_pixel_png = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8DwHwAFBQIAX8jx0gAAAABJRU5ErkJggg==" + + res = server.make_request("POST", "/v1/messages", data={ + "model": "test", + "max_tokens": 10, + "messages": [ + { + "role": "user", + "content": [ + { + "type": "image", + "source": { + "type": "base64", + "media_type": "image/png", + "data": red_pixel_png + } + }, + { + "type": "text", + "text": "What is this?" 
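                    # note: convert_anthropic_to_oai turns the base64 image block above into an
                    # OAI-style image_url data URI before the request reaches the chat pipeline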
+ } + ] + } + ] + }) + + # Server accepts the format but tinyllama doesn't support images + # So it should return 500 with clear error message about missing mmproj + assert res.status_code == 500 + assert "image input is not supported" in res.body.get("error", {}).get("message", "").lower() + + +def test_anthropic_vision_base64_with_multimodal_model(vision_server): + """Test vision with base64 image using Anthropic format with multimodal model""" + global server + server = vision_server + server.start() + + # Get test image in base64 format + image_base64 = get_test_image_base64() + + res = server.make_request("POST", "/v1/messages", data={ + "model": "test", + "max_tokens": 10, + "messages": [ + { + "role": "user", + "content": [ + { + "type": "image", + "source": { + "type": "base64", + "media_type": "image/png", + "data": image_base64 + } + }, + { + "type": "text", + "text": "What is this:\n" + } + ] + } + ] + }) + + assert res.status_code == 200, f"Expected 200, got {res.status_code}: {res.body}" + assert res.body["type"] == "message" + assert len(res.body["content"]) > 0 + assert res.body["content"][0]["type"] == "text" + # The model should generate some response about the image + assert len(res.body["content"][0]["text"]) > 0 + + +# Parameter tests + +def test_anthropic_stop_sequences(): + """Test stop_sequences parameter""" + server.start() + + res = server.make_request("POST", "/v1/messages", data={ + "model": "test", + "max_tokens": 100, + "stop_sequences": ["\n", "END"], + "messages": [ + {"role": "user", "content": "Count to 10"} + ] + }) + + assert res.status_code == 200 + assert res.body["type"] == "message" + + +def test_anthropic_temperature(): + """Test temperature parameter""" + server.start() + + res = server.make_request("POST", "/v1/messages", data={ + "model": "test", + "max_tokens": 50, + "temperature": 0.5, + "messages": [ + {"role": "user", "content": "Hello"} + ] + }) + + assert res.status_code == 200 + assert res.body["type"] == "message" + + +def test_anthropic_top_p(): + """Test top_p parameter""" + server.start() + + res = server.make_request("POST", "/v1/messages", data={ + "model": "test", + "max_tokens": 50, + "top_p": 0.9, + "messages": [ + {"role": "user", "content": "Hello"} + ] + }) + + assert res.status_code == 200 + assert res.body["type"] == "message" + + +def test_anthropic_top_k(): + """Test top_k parameter (llama.cpp specific)""" + server.start() + + res = server.make_request("POST", "/v1/messages", data={ + "model": "test", + "max_tokens": 50, + "top_k": 40, + "messages": [ + {"role": "user", "content": "Hello"} + ] + }) + + assert res.status_code == 200 + assert res.body["type"] == "message" + + +# Error handling tests + +def test_anthropic_missing_messages(): + """Test error when messages are missing""" + server.start() + + res = server.make_request("POST", "/v1/messages", data={ + "model": "test", + "max_tokens": 50 + # missing "messages" field + }) + + # Should return an error (400 or 500) + assert res.status_code >= 400 + + +def test_anthropic_empty_messages(): + """Test permissive handling of empty messages array""" + server.start() + + res = server.make_request("POST", "/v1/messages", data={ + "model": "test", + "max_tokens": 50, + "messages": [] + }) + + # Server is permissive and accepts empty messages (provides defaults) + # This matches the permissive validation design choice + assert res.status_code == 200 + assert res.body["type"] == "message" + + +# Content block index tests + +def test_anthropic_streaming_content_block_indices(): + 
"""Test that content block indices are correct in streaming""" + server.jinja = True + server.start() + + # Request that might produce both text and tool use + res = server.make_stream_request("POST", "/v1/messages", data={ + "model": "test", + "max_tokens": 200, + "stream": True, + "tools": [{ + "name": "test_tool", + "description": "A test tool", + "input_schema": { + "type": "object", + "properties": { + "param": {"type": "string"} + }, + "required": ["param"] + } + }], + "messages": [ + {"role": "user", "content": "Use the test tool"} + ] + }) + + events = [] + for data in res: + events.append(data) + + # Check content_block_start events have sequential indices + block_starts = [e for e in events if e.get("type") == "content_block_start"] + if len(block_starts) > 1: + # If there are multiple blocks, indices should be sequential + indices = [e["index"] for e in block_starts] + expected_indices = list(range(len(block_starts))) + assert indices == expected_indices, f"Expected indices {expected_indices}, got {indices}" + + # Check content_block_stop events match the starts + block_stops = [e for e in events if e.get("type") == "content_block_stop"] + start_indices = set(e["index"] for e in block_starts) + stop_indices = set(e["index"] for e in block_stops) + assert start_indices == stop_indices, "content_block_stop indices should match content_block_start indices" + + +# Extended features tests + +def test_anthropic_thinking(): + """Test extended thinking parameter""" + server.jinja = True + server.start() + + res = server.make_request("POST", "/v1/messages", data={ + "model": "test", + "max_tokens": 100, + "thinking": { + "type": "enabled", + "budget_tokens": 50 + }, + "messages": [ + {"role": "user", "content": "What is 2+2?"} + ] + }) + + assert res.status_code == 200 + assert res.body["type"] == "message" + + +def test_anthropic_metadata(): + """Test metadata parameter""" + server.start() + + res = server.make_request("POST", "/v1/messages", data={ + "model": "test", + "max_tokens": 50, + "metadata": { + "user_id": "test_user_123" + }, + "messages": [ + {"role": "user", "content": "Hello"} + ] + }) + + assert res.status_code == 200 + assert res.body["type"] == "message" + + +# Compatibility tests + +def test_anthropic_vs_openai_different_response_format(): + """Verify Anthropic format is different from OpenAI format""" + server.start() + + # Make OpenAI request + openai_res = server.make_request("POST", "/v1/chat/completions", data={ + "model": "test", + "max_tokens": 50, + "messages": [ + {"role": "user", "content": "Hello"} + ] + }) + + # Make Anthropic request + anthropic_res = server.make_request("POST", "/v1/messages", data={ + "model": "test", + "max_tokens": 50, + "messages": [ + {"role": "user", "content": "Hello"} + ] + }) + + assert openai_res.status_code == 200 + assert anthropic_res.status_code == 200 + + # OpenAI has "object", Anthropic has "type" + assert "object" in openai_res.body + assert "type" in anthropic_res.body + assert openai_res.body["object"] == "chat.completion" + assert anthropic_res.body["type"] == "message" + + # OpenAI has "choices", Anthropic has "content" + assert "choices" in openai_res.body + assert "content" in anthropic_res.body + + # Different usage field names + assert "prompt_tokens" in openai_res.body["usage"] + assert "input_tokens" in anthropic_res.body["usage"] + assert "completion_tokens" in openai_res.body["usage"] + assert "output_tokens" in anthropic_res.body["usage"] diff --git a/tools/server/tests/unit/test_security.py 
b/tools/server/tests/unit/test_security.py index 0e11580553..e160a8e6d3 100644 --- a/tools/server/tests/unit/test_security.py +++ b/tools/server/tests/unit/test_security.py @@ -49,6 +49,19 @@ def test_correct_api_key(): assert "content" in res.body +def test_correct_api_key_anthropic_header(): + global server + server.start() + res = server.make_request("POST", "/completions", data={ + "prompt": "I believe the meaning of life is", + }, headers={ + "X-Api-Key": TEST_API_KEY, + }) + assert res.status_code == 200 + assert "error" not in res.body + assert "content" in res.body + + def test_openai_library_correct_api_key(): global server server.start() From 2e7ef98f18090b382611c135efc417200b23780b Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Fri, 28 Nov 2025 20:34:51 +0800 Subject: [PATCH 22/33] ggml-cuda: add stricter checking for fusion (#17568) * ggml-cuda: make conditions for fusion more explicit * ggml-cuda: remove size check as std::equal already does it --- ggml/src/ggml-cuda/ggml-cuda.cu | 26 +++++++++++++++----------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 6463921a6e..a844a3d99a 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -3050,7 +3050,12 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list topk_moe_ops_delayed_softmax = ggml_cuda_topk_moe_ops(/*with_norm=*/false, /*delayed_softmax=*/true); - if (ops.size() == topk_moe_ops_with_norm.size() && + const auto is_equal = [](const std::initializer_list & list1, + const std::initializer_list & list2) { + return std::equal(list1.begin(), list1.end(), list2.begin(), list2.end()); + }; + + if (is_equal(topk_moe_ops_with_norm, ops) && ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 3, node_idx + 9 })) { ggml_tensor * softmax = cgraph->nodes[node_idx]; ggml_tensor * weights = cgraph->nodes[node_idx + 9]; @@ -3060,8 +3065,7 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, } } - if (ops.size() == topk_moe_ops.size() && - ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 3, node_idx + 4 })) { + if (is_equal(topk_moe_ops, ops) && ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 3, node_idx + 4 })) { ggml_tensor * softmax = cgraph->nodes[node_idx]; ggml_tensor * weights = cgraph->nodes[node_idx + 4]; if (ggml_cuda_should_use_topk_moe(softmax, weights)) { @@ -3069,7 +3073,7 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, } } - if (ops.size() == topk_moe_ops_delayed_softmax.size() && + if (is_equal(topk_moe_ops_delayed_softmax, ops) && ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 1, node_idx + 5 })) { ggml_tensor * softmax = cgraph->nodes[node_idx + 4]; ggml_tensor * weights = cgraph->nodes[node_idx + 5]; @@ -3085,9 +3089,8 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list mul_mat_id_glu_ops = { GGML_OP_MUL_MAT_ID, GGML_OP_MUL_MAT_ID, GGML_OP_GLU }; std::initializer_list mul_mat_glu_ops = { GGML_OP_MUL_MAT, GGML_OP_MUL_MAT, GGML_OP_GLU }; - if (ops.size() == 5 && (ggml_can_fuse_subgraph(cgraph, node_idx, ops, {node_idx + 4}) || - ggml_can_fuse_subgraph(cgraph, node_idx, ops, {node_idx + 4}))) { - + if ((is_equal(mul_mat_bias_glu_ops, ops) || is_equal(mul_mat_id_bias_glu_ops, ops)) && + ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 4 })) { const ggml_tensor * ffn_gate = 
cgraph->nodes[node_idx]; const ggml_tensor * ffn_gate_bias = cgraph->nodes[node_idx + 1]; const ggml_tensor * ffn_up = cgraph->nodes[node_idx + 2]; @@ -3099,9 +3102,8 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, } } - if (ops.size() == 3 && (ggml_can_fuse_subgraph(cgraph, node_idx, ops, {node_idx + 2}) || - ggml_can_fuse_subgraph(cgraph, node_idx, ops, {node_idx + 2}))) { - + if ((is_equal(mul_mat_id_glu_ops, ops) || is_equal(mul_mat_glu_ops, ops)) && + ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 2 })) { const ggml_tensor * ffn_gate = cgraph->nodes[node_idx]; const ggml_tensor * ffn_up = cgraph->nodes[node_idx + 1]; const ggml_tensor * glu = cgraph->nodes[node_idx + 2]; @@ -3111,7 +3113,9 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, } } - if (ops.size() == 3 && ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 2 })) { + std::initializer_list rope_set_rows_ops = { GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS }; + + if (is_equal(rope_set_rows_ops, ops) && ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 2 })) { const ggml_tensor * rope = cgraph->nodes[node_idx]; const ggml_tensor * view = cgraph->nodes[node_idx + 1]; const ggml_tensor * set_rows = cgraph->nodes[node_idx + 2]; From c6f7a423c8c87748ef563a99d81c3b1b05cecff0 Mon Sep 17 00:00:00 2001 From: R0CKSTAR Date: Fri, 28 Nov 2025 21:08:29 +0800 Subject: [PATCH 23/33] [MUSA] enable fp16/fast_fp16/bf16_mma on PH1 (#17551) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * [MUSA] enable fp16/fast_fp16/bf16_mma on PH1 Signed-off-by: Xiaodong Ye * Update ggml/src/ggml-cuda/fattn-vec.cuh Co-authored-by: Johannes Gäßler * Update ggml/src/ggml-cuda/fattn-vec.cuh Co-authored-by: Johannes Gäßler * Update ggml/src/ggml-cuda/fattn-tile.cuh Co-authored-by: Johannes Gäßler * Address review comments Signed-off-by: Xiaodong Ye --------- Signed-off-by: Xiaodong Ye Co-authored-by: Johannes Gäßler --- ggml/src/ggml-cuda/common.cuh | 20 ++++++++++++-------- ggml/src/ggml-cuda/cpy.cu | 5 ++++- ggml/src/ggml-cuda/fattn-tile.cuh | 2 +- ggml/src/ggml-cuda/fattn-vec.cuh | 4 ++-- ggml/src/ggml-cuda/mma.cuh | 4 ++-- 5 files changed, 21 insertions(+), 14 deletions(-) diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 81fb312409..0b10e5f6ae 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -84,12 +84,12 @@ #define GGML_CUDA_CC_QY1 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x210) // MTT S80, MTT S3000 #define GGML_CUDA_CC_QY2 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x220) // MTT S4000 -#define GGML_CUDA_CC_NG (GGML_CUDA_CC_OFFSET_MTHREADS + 0x310) // TBD +#define GGML_CUDA_CC_PH1 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x310) // MTT S5000 #define GGML_CUDA_CC_IS_MTHREADS(cc) (cc >= GGML_CUDA_CC_OFFSET_MTHREADS && cc < GGML_CUDA_CC_OFFSET_AMD) #define GGML_CUDA_CC_IS_QY1(cc) (cc >= GGML_CUDA_CC_QY1 && cc < GGML_CUDA_CC_QY2) -#define GGML_CUDA_CC_IS_QY2(cc) (cc >= GGML_CUDA_CC_QY2 && cc < GGML_CUDA_CC_NG) -#define GGML_CUDA_CC_IS_NG(cc) (cc >= GGML_CUDA_CC_NG) +#define GGML_CUDA_CC_IS_QY2(cc) (cc >= GGML_CUDA_CC_QY2 && cc < GGML_CUDA_CC_PH1) +#define GGML_CUDA_CC_IS_PH1(cc) (cc >= GGML_CUDA_CC_PH1) #if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070 # define GGML_CUDA_USE_CUB @@ -212,9 +212,9 @@ static const char * cu_get_error_str(CUresult err) { #define GGML_USE_VMM #endif // (!defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM)) || (defined(GGML_USE_HIP) 
&& !defined(GGML_HIP_NO_VMM)) -#if defined(GGML_USE_HIP) || __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL +#if defined(GGML_USE_HIP) || defined(GGML_USE_MUSA) || __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL #define FP16_AVAILABLE -#endif // defined(GGML_USE_HIP) || __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL +#endif // defined(GGML_USE_HIP) || defined(GGML_USE_MUSA) || __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL #if defined(FP16_AVAILABLE) && __CUDA_ARCH__ != 610 #define FAST_FP16_AVAILABLE @@ -250,12 +250,14 @@ static const char * cu_get_error_str(CUresult err) { #endif // !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ < 220) static bool fp16_available(const int cc) { - return ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_PASCAL; + return ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_PASCAL || + (GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_PH1); } static bool fast_fp16_available(const int cc) { return GGML_CUDA_CC_IS_AMD(cc) || - (GGML_CUDA_CC_IS_NVIDIA(cc) && fp16_available(cc) && ggml_cuda_highest_compiled_arch(cc) != 610); + (GGML_CUDA_CC_IS_NVIDIA(cc) && fp16_available(cc) && ggml_cuda_highest_compiled_arch(cc) != 610) || + (GGML_CUDA_CC_IS_MTHREADS(cc) && fp16_available(cc)); } // To be used for feature selection of external libraries, e.g. cuBLAS. @@ -272,7 +274,9 @@ static bool fp16_mma_hardware_available(const int cc) { } static bool bf16_mma_hardware_available(const int cc) { - return (GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_AMPERE) || GGML_CUDA_CC_IS_CDNA(cc) || cc >= GGML_CUDA_CC_RDNA3; + return (GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_AMPERE) || + GGML_CUDA_CC_IS_CDNA(cc) || cc >= GGML_CUDA_CC_RDNA3 || + (GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_PH1); } static bool fp32_mma_hardware_available(const int cc) { diff --git a/ggml/src/ggml-cuda/cpy.cu b/ggml/src/ggml-cuda/cpy.cu index 82367ad3fb..c4ceb4fc57 100644 --- a/ggml/src/ggml-cuda/cpy.cu +++ b/ggml/src/ggml-cuda/cpy.cu @@ -86,6 +86,9 @@ static __global__ void cpy_scalar_transpose(const char * cx, char * cdst, const } } } + + GGML_UNUSED_VARS(ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, + nb12, nb13); } static __device__ void cpy_blck_q8_0_f32(const char * cxi, char * cdsti) { @@ -202,7 +205,7 @@ static void ggml_cpy_scalar_cuda( ne00n = ne00; ne01n = ne01; ne02n = ne02; - } else if (nb00 > nb02) { + } else { ne00n = ne00; ne01n = ne01*ne02; ne02n = 1; diff --git a/ggml/src/ggml-cuda/fattn-tile.cuh b/ggml/src/ggml-cuda/fattn-tile.cuh index c358aa1e87..3e58d64ff9 100644 --- a/ggml/src/ggml-cuda/fattn-tile.cuh +++ b/ggml/src/ggml-cuda/fattn-tile.cuh @@ -609,7 +609,7 @@ static __device__ __forceinline__ void flash_attn_tile_iter( float KQ_sum_add = 0.0f; #pragma unroll for (int i0 = 0; i0 < nbatch_fa; i0 += np*warp_size) { - const float val = !oob_check || i0 + (threadIdx.y % np)*warp_size + threadIdx.x < k_VKQ_sup ? + const float val = !oob_check || i0 + (threadIdx.y % np)*warp_size + threadIdx.x < static_cast(k_VKQ_sup) ? 
expf(KQ_acc[(i0/(np*warp_size))*cpw + jc] - KQ_max[jc]) : 0.0f; KQ_sum_add += val; tmp[i0/(np*warp_size)][jc1] = val; diff --git a/ggml/src/ggml-cuda/fattn-vec.cuh b/ggml/src/ggml-cuda/fattn-vec.cuh index 6e63e860ac..67aa67ecb9 100644 --- a/ggml/src/ggml-cuda/fattn-vec.cuh +++ b/ggml/src/ggml-cuda/fattn-vec.cuh @@ -155,7 +155,7 @@ static __global__ void flash_attn_ext_vec( for (int i0 = 0; i0 < int(D/sizeof(int)); i0 += WARP_SIZE) { const int i = i0 + threadIdx.x; - if (i0 + WARP_SIZE <= D/sizeof(int) || i < D/sizeof(int)) { + if (i0 + WARP_SIZE <= int(D/sizeof(int)) || i < int(D/sizeof(int))) { tmp_q_i32[i] = 0; } } @@ -272,7 +272,7 @@ static __global__ void flash_attn_ext_vec( KQ_max_new[j] = fmaxf(KQ_max_new[j], sum); - if ((nthreads_KQ == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_KQ) == i_KQ_0) { + if ((nthreads_KQ == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_KQ) == uint32_t(i_KQ_0)) { KQ_reg[j] = sum; } } diff --git a/ggml/src/ggml-cuda/mma.cuh b/ggml/src/ggml-cuda/mma.cuh index c0a9c2c08a..0ed42e87d3 100644 --- a/ggml/src/ggml-cuda/mma.cuh +++ b/ggml/src/ggml-cuda/mma.cuh @@ -889,8 +889,8 @@ namespace ggml_cuda_mma { : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]), "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7]) : "r"(Axi[6]), "r"(Axi[7]), "r"(Bxi[6]), "r"(Bxi[7])); #else - tile<16, 8, float> * D16 = (tile<16, 8, float> *) &D; - tile<16, 8, half2> * A16 = (tile<16, 8, half2> *) &A; + tile <16, 8, float> * D16 = reinterpret_cast *>(&D); + const tile<16, 8, half2> * A16 = reinterpret_cast *>(&A); mma(D16[0], A16[0], B); mma(D16[1], A16[1], B); #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE From e072b2052e9250395e4a28a28d37806342ac5db1 Mon Sep 17 00:00:00 2001 From: Diego Devesa Date: Fri, 28 Nov 2025 07:33:23 -0800 Subject: [PATCH 24/33] ggml : add GGML_SCHED_NO_REALLOC option to disable reallocations in ggml_backend_sched (#17276) * ggml : add GGML_SCHED_NO_REALLOC option to disable reallocations in ggml_backend_sched Enabled in ggml-ci for testing. * llama : update worst-case graph for unified cache * ci : disable op offload in some tests * fix spelling --------- Co-authored-by: Georgi Gerganov --- ci/run.sh | 16 ++++++++-------- examples/embedding/embedding.cpp | 7 ++++--- ggml/CMakeLists.txt | 1 + ggml/src/CMakeLists.txt | 4 ++++ ggml/src/ggml-alloc.c | 11 ++++++++--- ggml/src/ggml-backend.cpp | 12 +++++++++--- src/llama-context.cpp | 4 ++-- tests/CMakeLists.txt | 2 +- 8 files changed, 37 insertions(+), 20 deletions(-) diff --git a/ci/run.sh b/ci/run.sh index 3fec8e9110..1dd65adeaa 100755 --- a/ci/run.sh +++ b/ci/run.sh @@ -45,7 +45,7 @@ sd=`dirname $0` cd $sd/../ SRC=`pwd` -CMAKE_EXTRA="-DLLAMA_FATAL_WARNINGS=ON -DLLAMA_CURL=ON" +CMAKE_EXTRA="-DLLAMA_FATAL_WARNINGS=ON -DLLAMA_CURL=ON -DGGML_SCHED_NO_REALLOC=ON" if [ ! 
-z ${GG_BUILD_METAL} ]; then CMAKE_EXTRA="${CMAKE_EXTRA} -DGGML_METAL=ON" @@ -428,10 +428,10 @@ function gg_run_qwen3_0_6b { (time ./bin/llama-imatrix --model ${model_f16} -f ${wiki_test} -ngl 99 -c 1024 -b 512 --chunks 2 ) 2>&1 | tee -a $OUT/${ci}-imatrix.log - (time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 10 -c 1024 -fa off ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log - (time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 10 -c 1024 -fa on ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log - (time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 99 -c 1024 -fa off ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log - (time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 99 -c 1024 -fa on ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log + (time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 10 -c 1024 -fa off --no-op-offload) 2>&1 | tee -a $OUT/${ci}-save-load-state.log + (time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 10 -c 1024 -fa on --no-op-offload) 2>&1 | tee -a $OUT/${ci}-save-load-state.log + (time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 99 -c 1024 -fa off ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log + (time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 99 -c 1024 -fa on ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log function check_ppl { qnt="$1" @@ -523,8 +523,8 @@ function gg_run_embd_bge_small { ./bin/llama-quantize ${model_f16} ${model_q8_0} q8_0 - (time ./bin/llama-embedding --model ${model_f16} -p "I believe the meaning of life is" -ngl 99 -c 0 ) 2>&1 | tee -a $OUT/${ci}-tg-f16.log - (time ./bin/llama-embedding --model ${model_q8_0} -p "I believe the meaning of life is" -ngl 99 -c 0 ) 2>&1 | tee -a $OUT/${ci}-tg-q8_0.log + (time ./bin/llama-embedding --model ${model_f16} -p "I believe the meaning of life is" -ngl 99 -c 0 --no-op-offload) 2>&1 | tee -a $OUT/${ci}-tg-f16.log + (time ./bin/llama-embedding --model ${model_q8_0} -p "I believe the meaning of life is" -ngl 99 -c 0 --no-op-offload) 2>&1 | tee -a $OUT/${ci}-tg-q8_0.log set +e } @@ -564,7 +564,7 @@ function gg_run_rerank_tiny { model_f16="${path_models}/ggml-model-f16.gguf" # for this model, the SEP token is "" - (time ./bin/llama-embedding --model ${model_f16} -p "what is panda?\thi\nwhat is panda?\tit's a bear\nwhat is panda?\tThe giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China." -ngl 99 -c 0 --pooling rank --embd-normalize -1 --verbose-prompt) 2>&1 | tee -a $OUT/${ci}-rk-f16.log + (time ./bin/llama-embedding --model ${model_f16} -p "what is panda?\thi\nwhat is panda?\tit's a bear\nwhat is panda?\tThe giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China." -ngl 99 -c 0 --pooling rank --embd-normalize -1 --no-op-offload --verbose-prompt) 2>&1 | tee -a $OUT/${ci}-rk-f16.log # sample output # rerank score 0: 0.029 diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp index 9e3ab5905b..fe91b308cd 100644 --- a/examples/embedding/embedding.cpp +++ b/examples/embedding/embedding.cpp @@ -104,12 +104,16 @@ int main(int argc, char ** argv) { params.embedding = true; + // get max number of sequences per batch + const int n_seq_max = llama_max_parallel_sequences(); + // if the number of prompts that would be encoded is known in advance, it's more efficient to specify the // --parallel argument accordingly. 
for convenience, if not specified, we fallback to unified KV cache // in order to support any number of prompts if (params.n_parallel == 1) { LOG_INF("%s: n_parallel == 1 -> unified KV cache is enabled\n", __func__); params.kv_unified = true; + params.n_parallel = n_seq_max; } // utilize the full context @@ -123,9 +127,6 @@ int main(int argc, char ** argv) { params.n_ubatch = params.n_batch; } - // get max number of sequences per batch - const int n_seq_max = llama_max_parallel_sequences(); - llama_backend_init(); llama_numa_init(params.numa); diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index 0211255a76..9b10df00da 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -183,6 +183,7 @@ endif() # ggml core set(GGML_SCHED_MAX_COPIES "4" CACHE STRING "ggml: max input copies for pipeline parallelism") option(GGML_CPU "ggml: enable CPU backend" ON) +option(GGML_SCHED_NO_REALLOC "ggml: disallow reallocations in ggml-alloc (for debugging)" OFF) # 3rd party libs / backends option(GGML_ACCELERATE "ggml: enable Accelerate framework" ON) diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt index a4499509ec..a36f5b6647 100644 --- a/ggml/src/CMakeLists.txt +++ b/ggml/src/CMakeLists.txt @@ -221,6 +221,10 @@ if (GGML_BACKEND_DL) target_compile_definitions(ggml-base PUBLIC GGML_BACKEND_DL) endif() +if (GGML_SCHED_NO_REALLOC) + target_compile_definitions(ggml-base PUBLIC GGML_SCHED_NO_REALLOC) +endif() + add_library(ggml ggml-backend-reg.cpp) add_library(ggml::ggml ALIAS ggml) diff --git a/ggml/src/ggml-alloc.c b/ggml/src/ggml-alloc.c index 91aff205f1..218222ece8 100644 --- a/ggml/src/ggml-alloc.c +++ b/ggml/src/ggml-alloc.c @@ -921,10 +921,15 @@ bool ggml_gallocr_reserve_n(ggml_gallocr_t galloc, struct ggml_cgraph * graph, c } if (realloc) { #ifndef NDEBUG - size_t cur_size = galloc->buffers[i] ? ggml_vbuffer_size(galloc->buffers[i]) : 0; - GGML_LOG_DEBUG("%s: reallocating %s buffer from size %.02f MiB to %.02f MiB\n", __func__, ggml_backend_buft_name(galloc->bufts[i]), cur_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0); + { + size_t cur_size = galloc->buffers[i] ? 
ggml_vbuffer_size(galloc->buffers[i]) : 0; + if (cur_size > 0) { + GGML_LOG_DEBUG("%s: reallocating %s buffer from size %.02f MiB to %.02f MiB\n", + __func__, ggml_backend_buft_name(galloc->bufts[i]), + cur_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0); + } + } #endif - ggml_vbuffer_free(galloc->buffers[i]); galloc->buffers[i] = ggml_vbuffer_alloc(galloc->bufts[i], galloc->buf_tallocs[i], GGML_BACKEND_BUFFER_USAGE_COMPUTE); if (galloc->buffers[i] == NULL) { diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp index eeaf35c169..4cf377e7f3 100644 --- a/ggml/src/ggml-backend.cpp +++ b/ggml/src/ggml-backend.cpp @@ -1395,14 +1395,20 @@ static bool ggml_backend_sched_alloc_splits(ggml_backend_sched_t sched) { // allocate graph if (backend_ids_changed || !ggml_gallocr_alloc_graph(sched->galloc, &sched->graph)) { +#ifdef GGML_SCHED_NO_REALLOC + GGML_ABORT("%s: failed to allocate graph, but graph re-allocation is disabled by GGML_SCHED_NO_REALLOC\n", __func__); +#endif + +#ifndef NDEBUG + GGML_LOG_DEBUG("%s: failed to allocate graph, reserving (backend_ids_changed = %d)\n", __func__, backend_ids_changed); +#endif + // the re-allocation may cause the split inputs to be moved to a different address // synchronize without ggml_backend_sched_synchronize to avoid changing cur_copy for (int i = 0; i < sched->n_backends; i++) { ggml_backend_synchronize(sched->backends[i]); } -#ifndef NDEBUG - GGML_LOG_DEBUG("%s: failed to allocate graph, reserving (backend_ids_changed = %d)\n", __func__, backend_ids_changed); -#endif + ggml_gallocr_reserve_n(sched->galloc, &sched->graph, sched->node_backend_ids, sched->leaf_backend_ids); if (!ggml_gallocr_alloc_graph(sched->galloc, &sched->graph)) { GGML_LOG_ERROR("%s: failed to allocate graph\n", __func__); diff --git a/src/llama-context.cpp b/src/llama-context.cpp index a58914429c..e04f0fc4f9 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -300,7 +300,7 @@ llama_context::llama_context( cross.v_embd.clear(); - const uint32_t n_seqs = cparams.kv_unified ? 1 : cparams.n_seq_max; + const uint32_t n_seqs = cparams.n_seq_max; const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch); // avoid reserving graphs with zero outputs - assume one output per sequence @@ -543,7 +543,7 @@ bool llama_context::memory_update(bool optimize) { throw std::runtime_error("failed to initialize memory context"); } - const uint32_t n_seqs = cparams.kv_unified ? 
1 : cparams.n_seq_max; + const uint32_t n_seqs = cparams.n_seq_max; const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch); auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get()); diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index d9cc5e933f..9361a113a1 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -196,7 +196,7 @@ if (NOT WIN32) llama_build_and_test(test-arg-parser.cpp) endif() -if (NOT LLAMA_SANITIZE_ADDRESS) +if (NOT LLAMA_SANITIZE_ADDRESS AND NOT GGML_SCHED_NO_REALLOC) # TODO: repair known memory leaks llama_build_and_test(test-opt.cpp) endif() From 3ce7a65c2f2529a8fc566b4aead53b088f7faec2 Mon Sep 17 00:00:00 2001 From: o7si <32285332+o7si@users.noreply.github.com> Date: Sat, 29 Nov 2025 02:14:00 +0800 Subject: [PATCH 25/33] server: fix: /metrics endpoint returning JSON-escaped Prometheus format (#17386) * fix: /metrics endpoint returning JSON-escaped Prometheus format * mod: remove string overload from ok() method --- tools/server/server.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tools/server/server.cpp b/tools/server/server.cpp index 05bbe648c1..96b2df27f7 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -2713,7 +2713,8 @@ public: res->headers["Process-Start-Time-Unix"] = std::to_string(res_task->t_start); res->content_type = "text/plain; version=0.0.4"; - res->ok(prometheus.str()); + res->status = 200; + res->data = prometheus.str(); return res; }; From 03914c7ef826caf0b6371a6d1de270cda102b542 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?DAN=E2=84=A2?= Date: Fri, 28 Nov 2025 13:29:36 -0500 Subject: [PATCH 26/33] common : move all common_chat_parse_* to chat-parser.cpp. (#17481) --- common/chat-parser.cpp | 968 +++++++++++++++++++++++++++++++++++++++++ common/chat.cpp | 952 ---------------------------------------- 2 files changed, 968 insertions(+), 952 deletions(-) diff --git a/common/chat-parser.cpp b/common/chat-parser.cpp index ff83102788..301f439a6f 100644 --- a/common/chat-parser.cpp +++ b/common/chat-parser.cpp @@ -13,6 +13,120 @@ using json = nlohmann::ordered_json; +static void parse_prefixed_json_tool_call_array(common_chat_msg_parser & builder, + const common_regex & prefix, + size_t rstrip_prefix = 0) { + static const std::vector> args_paths = { { "arguments" } }; + if (auto res = builder.try_find_regex(prefix)) { + builder.move_back(rstrip_prefix); + auto tool_calls = builder.consume_json_with_dumped_args(args_paths); + if (!builder.add_tool_calls(tool_calls.value) || tool_calls.is_partial) { + throw common_chat_msg_partial_exception("incomplete tool call array"); + } + } else { + builder.add_content(builder.consume_rest()); + } +} + +static std::string wrap_code_as_arguments(common_chat_msg_parser & builder, const std::string & code) { + std::string arguments; + if (builder.is_partial()) { + arguments = (json{ + { "code", code + builder.healing_marker() } + }) + .dump(); + auto idx = arguments.find(builder.healing_marker()); + if (idx != std::string::npos) { + arguments.resize(idx); + } + } else { + arguments = (json{ + { "code", code } + }) + .dump(); + } + return arguments; +} + +/** + * Takes a prefix regex that must have 1 group to capture the function name, a closing suffix, and expects json parameters in between. + * Aggregates the prefix, suffix and in-between text into the content. 
+ */ +static void parse_json_tool_calls( + common_chat_msg_parser & builder, + const std::optional & block_open, + const std::optional & function_regex_start_only, + const std::optional & function_regex, + const common_regex & close_regex, + const std::optional & block_close, + bool allow_raw_python = false, + const std::function & get_function_name = + nullptr) { + auto parse_tool_calls = [&]() { + size_t from = std::string::npos; + auto first = true; + while (true) { + auto start_pos = builder.pos(); + auto res = function_regex_start_only && first ? builder.try_consume_regex(*function_regex_start_only) : + function_regex ? builder.try_find_regex(*function_regex, from) : + std::nullopt; + + if (res) { + std::string name; + if (get_function_name) { + name = get_function_name(*res); + } else { + GGML_ASSERT(res->groups.size() == 2); + name = builder.str(res->groups[1]); + } + first = false; + if (name.empty()) { + // get_function_name signalled us that we should skip this match and treat it as content. + from = res->groups[0].begin + 1; + continue; + } + from = std::string::npos; + + auto maybe_raw_python = name == "python" && allow_raw_python; + if (builder.input()[builder.pos()] == '{' || !maybe_raw_python) { + if (auto arguments = builder.try_consume_json_with_dumped_args({ {} })) { + if (!builder.add_tool_call(name, "", arguments->value) || arguments->is_partial) { + throw common_chat_msg_partial_exception("incomplete tool call"); + } + builder.consume_regex(close_regex); + } + continue; + } + if (maybe_raw_python) { + auto arguments = wrap_code_as_arguments(builder, builder.consume_rest()); + if (!builder.add_tool_call(name, "", arguments)) { + throw common_chat_msg_partial_exception("incomplete tool call"); + } + return; + } + throw common_chat_msg_partial_exception("incomplete tool call"); + } else { + builder.move_to(start_pos); + } + break; + } + if (block_close) { + builder.consume_regex(*block_close); + } + builder.consume_spaces(); + builder.add_content(builder.consume_rest()); + }; + if (block_open) { + if (auto res = builder.try_find_regex(*block_open)) { + parse_tool_calls(); + } else { + builder.add_content(builder.consume_rest()); + } + } else { + parse_tool_calls(); + } +} + common_chat_msg_parser::common_chat_msg_parser(const std::string & input, bool is_partial, const common_chat_syntax & syntax) : input_(input), is_partial_(is_partial), syntax_(syntax) { @@ -532,3 +646,857 @@ std::optional common_chat_msg_parse void common_chat_msg_parser::clear_tools() { result_.tool_calls.clear(); } + +/** + * All common_chat_parse_* moved from chat.cpp to chat-parser.cpp below + * to reduce incremental compile time for parser changes. 
+ */ +static void common_chat_parse_generic(common_chat_msg_parser & builder) { + if (!builder.syntax().parse_tool_calls) { + builder.add_content(builder.consume_rest()); + return; + } + static const std::vector> content_paths = { + {"response"}, + }; + static const std::vector> args_paths = { + {"tool_call", "arguments"}, + {"tool_calls", "arguments"}, + }; + auto data = builder.consume_json_with_dumped_args(args_paths, content_paths); + if (data.value.contains("tool_calls")) { + if (!builder.add_tool_calls(data.value.at("tool_calls")) || data.is_partial) { + throw common_chat_msg_partial_exception("incomplete tool calls"); + } + } else if (data.value.contains("tool_call")) { + if (!builder.add_tool_call(data.value.at("tool_call")) || data.is_partial) { + throw common_chat_msg_partial_exception("incomplete tool call"); + } + } else if (data.value.contains("response")) { + const auto & response = data.value.at("response"); + builder.add_content(response.is_string() ? response.template get() : response.dump(2)); + if (data.is_partial) { + throw common_chat_msg_partial_exception("incomplete response"); + } + } else { + throw common_chat_msg_partial_exception("Expected 'tool_call', 'tool_calls' or 'response' in JSON"); + } +} + +static void common_chat_parse_mistral_nemo(common_chat_msg_parser & builder) { + if (!builder.syntax().parse_tool_calls) { + builder.add_content(builder.consume_rest()); + return; + } + + static const common_regex prefix(regex_escape("[TOOL_CALLS]")); + parse_prefixed_json_tool_call_array(builder, prefix); +} + +static void common_chat_parse_magistral(common_chat_msg_parser & builder) { + builder.try_parse_reasoning("[THINK]", "[/THINK]"); + + if (!builder.syntax().parse_tool_calls) { + builder.add_content(builder.consume_rest()); + return; + } + + static const common_regex prefix(regex_escape("[TOOL_CALLS]")); + parse_prefixed_json_tool_call_array(builder, prefix); +} + +static void common_chat_parse_command_r7b(common_chat_msg_parser & builder) { + builder.try_parse_reasoning("<|START_THINKING|>", "<|END_THINKING|>"); + + static const common_regex start_action_regex("<\\|START_ACTION\\|>"); + static const common_regex end_action_regex("<\\|END_ACTION\\|>"); + static const common_regex start_response_regex("<\\|START_RESPONSE\\|>"); + static const common_regex end_response_regex("<\\|END_RESPONSE\\|>"); + + if (auto res = builder.try_find_regex(start_action_regex)) { + // If we didn't extract thoughts, prelude includes them. + auto tool_calls = builder.consume_json_with_dumped_args({{"parameters"}}); + for (const auto & tool_call : tool_calls.value) { + std::string name = tool_call.contains("tool_name") ? tool_call.at("tool_name") : ""; + std::string id = tool_call.contains("tool_call_id") ? tool_call.at("tool_call_id") : ""; + std::string arguments = tool_call.contains("parameters") ? 
tool_call.at("parameters") : ""; + if (!builder.add_tool_call(name, id, arguments) || tool_calls.is_partial) { + throw common_chat_msg_partial_exception("incomplete tool call"); + } + } + if (tool_calls.is_partial) { + throw common_chat_msg_partial_exception("incomplete tool call"); + } + builder.consume_regex(end_action_regex); + } else if (auto res = builder.try_find_regex(start_response_regex)) { + if (!builder.try_find_regex(end_response_regex)) { + builder.add_content(builder.consume_rest()); + throw common_chat_msg_partial_exception(end_response_regex.str()); + } + } else { + builder.add_content(builder.consume_rest()); + } +} + +static void common_chat_parse_llama_3_1(common_chat_msg_parser & builder, bool with_builtin_tools = false) { + builder.try_parse_reasoning("", ""); + + if (!builder.syntax().parse_tool_calls) { + builder.add_content(builder.consume_rest()); + return; + } + + static const common_regex function_regex( + "\\s*\\{\\s*(?:\"type\"\\s*:\\s*\"function\"\\s*,\\s*)?\"name\"\\s*:\\s*\"([^\"]+)\"\\s*,\\s*\"parameters\"\\s*: "); + static const common_regex close_regex("\\}\\s*"); + + static const common_regex function_name_regex("\\s*(\\w+)\\s*\\.\\s*call\\("); + static const common_regex arg_name_regex("\\s*(\\w+)\\s*=\\s*"); + + if (with_builtin_tools) { + static const common_regex builtin_call_regex("<\\|python_tag\\|>"); + if (auto res = builder.try_find_regex(builtin_call_regex)) { + auto fun_res = builder.consume_regex(function_name_regex); + auto function_name = builder.str(fun_res.groups[1]); + + common_healing_marker healing_marker; + json args = json::object(); + while (true) { + if (auto arg_res = builder.try_consume_regex(arg_name_regex)) { + auto arg_name = builder.str(arg_res->groups[1]); + auto partial = builder.consume_json(); + args[arg_name] = partial.json; + healing_marker.marker = partial.healing_marker.marker; + healing_marker.json_dump_marker = partial.healing_marker.json_dump_marker; + builder.consume_spaces(); + if (!builder.try_consume_literal(",")) { + break; + } + } else { + break; + } + } + builder.consume_literal(")"); + builder.consume_spaces(); + + auto arguments = args.dump(); + if (!builder.add_tool_call(function_name, "", arguments)) { + throw common_chat_msg_partial_exception("Incomplete tool call"); + } + return; + } + } + parse_json_tool_calls( + builder, + /* block_open= */ std::nullopt, + /* function_regex_start_only= */ function_regex, + /* function_regex= */ std::nullopt, + close_regex, + std::nullopt); + +} + +static void common_chat_parse_deepseek_r1(common_chat_msg_parser & builder) { + builder.try_parse_reasoning("", ""); + if (!builder.syntax().parse_tool_calls) { + builder.add_content(builder.consume_rest()); + return; + } + + static const common_regex tool_calls_begin("(?:<|tool▁calls▁begin|>|<|tool_calls_begin|>|<|tool calls begin|>|<|tool\\\\_calls\\\\_begin|>|<|tool▁calls|>)"); + static const common_regex tool_calls_end("<|tool▁calls▁end|>"); + static const common_regex function_regex("(?:<|tool▁call▁begin|>)?function<|tool▁sep|>([^\n]+)\n```json\n"); + static const common_regex close_regex("```[\\s\\r\\n]*<|tool▁call▁end|>"); + + parse_json_tool_calls( + builder, + /* block_open= */ tool_calls_begin, + /* function_regex_start_only= */ std::nullopt, + function_regex, + close_regex, + tool_calls_end); +} + +static void common_chat_parse_deepseek_v3_1_content(common_chat_msg_parser & builder) { + static const common_regex function_regex("(?:<|tool▁call▁begin|>)?([^\\n<]+)(?:<|tool▁sep|>)"); + + static const common_regex 
close_regex("(?:[\\s]*)?<|tool▁call▁end|>"); + static const common_regex tool_calls_begin("(?:<|tool▁calls▁begin|>|<|tool_calls_begin|>|<|tool calls begin|>|<|tool\\\\_calls\\\\_begin|>|<|tool▁calls|>)"); + static const common_regex tool_calls_end("<|tool▁calls▁end|>"); + + if (!builder.syntax().parse_tool_calls) { + LOG_DBG("%s: not parse_tool_calls\n", __func__); + builder.add_content(builder.consume_rest()); + return; + } + + LOG_DBG("%s: parse_tool_calls\n", __func__); + + parse_json_tool_calls( + builder, + /* block_open= */ tool_calls_begin, + /* function_regex_start_only= */ std::nullopt, + function_regex, + close_regex, + tool_calls_end); +} + +static void common_chat_parse_deepseek_v3_1(common_chat_msg_parser & builder) { + // DeepSeek V3.1 outputs reasoning content between "" and "" tags, followed by regular content + // First try to parse using the standard reasoning parsing method + LOG_DBG("%s: thinking_forced_open: %s\n", __func__, std::to_string(builder.syntax().thinking_forced_open).c_str()); + + auto start_pos = builder.pos(); + auto found_end_think = builder.try_find_literal("
"); + builder.move_to(start_pos); + + if (builder.syntax().thinking_forced_open && !builder.is_partial() && !found_end_think) { + LOG_DBG("%s: no end_think, not partial, adding content\n", __func__); + common_chat_parse_deepseek_v3_1_content(builder); + } else if (builder.try_parse_reasoning("", "")) { + // If reasoning was parsed successfully, the remaining content is regular content + LOG_DBG("%s: parsed reasoning, adding content\n", __func__); + //
<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>NAME\n```json\nJSON\n```<|tool▁call▁end|><|tool▁calls▁end|> + common_chat_parse_deepseek_v3_1_content(builder); + } else { + if (builder.syntax().reasoning_format == COMMON_REASONING_FORMAT_NONE) { + LOG_DBG("%s: reasoning_format none, adding content\n", __func__); + common_chat_parse_deepseek_v3_1_content(builder); + return; + } + // If no reasoning tags found, check if we should treat everything as reasoning + if (builder.syntax().thinking_forced_open) { + // If thinking is forced open but no tags found, treat everything as reasoning + LOG_DBG("%s: thinking_forced_open, adding reasoning content\n", __func__); + builder.add_reasoning_content(builder.consume_rest()); + } else { + LOG_DBG("%s: no thinking_forced_open, adding content\n", __func__); + // <|tool▁call▁begin|>NAME<|tool▁sep|>JSON<|tool▁call▁end|> + common_chat_parse_deepseek_v3_1_content(builder); + } + } +} + +static void common_chat_parse_minimax_m2(common_chat_msg_parser & builder) { + static const xml_tool_call_format form { + /* form.scope_start = */ "", + /* form.tool_start = */ "", + /* form.key_start = */ "", + /* form.val_end = */ "", + /* form.tool_end = */ "", + /* form.scope_end = */ "", + }; + builder.consume_reasoning_with_xml_tool_calls(form, "", ""); +} + +static void common_chat_parse_qwen3_coder_xml(common_chat_msg_parser & builder) { + static const xml_tool_call_format form = ([]() { + xml_tool_call_format form {}; + form.scope_start = ""; + form.tool_start = "", ""); +} + +static void common_chat_parse_apriel_1_5(common_chat_msg_parser & builder) { + static const xml_tool_call_format form = ([]() { + xml_tool_call_format form {}; + form.scope_start = "["; + form.tool_start = "{\"name\": \""; + form.tool_sep = "\", \"arguments\": {"; + form.key_start = "\""; + form.key_val_sep = "\": "; + form.val_end = ", "; + form.tool_end = "}, "; + form.scope_end = "]"; + form.raw_argval = false; + form.last_val_end = ""; + form.last_tool_end = "}"; + return form; + })(); + builder.consume_reasoning_with_xml_tool_calls(form, "", ""); +} + +static void common_chat_parse_xiaomi_mimo(common_chat_msg_parser & builder) { + static const xml_tool_call_format form = ([]() { + xml_tool_call_format form {}; + form.scope_start = ""; + form.tool_start = "\n{\"name\": \""; + form.tool_sep = "\", \"arguments\": {"; + form.key_start = "\""; + form.key_val_sep = "\": "; + form.val_end = ", "; + form.tool_end = "}\n"; + form.scope_end = ""; + form.raw_argval = false; + form.last_val_end = ""; + return form; + })(); + builder.consume_reasoning_with_xml_tool_calls(form); +} + +static void common_chat_parse_gpt_oss(common_chat_msg_parser & builder) { + static const std::string constraint = "(?: (<\\|constrain\\|>)?([a-zA-Z0-9_-]+))"; + static const std::string recipient("(?: to=functions\\.([^<\\s]+))"); + + static const common_regex start_regex("<\\|start\\|>assistant"); + static const common_regex analysis_regex("<\\|channel\\|>analysis"); + static const common_regex final_regex("<\\|channel\\|>final" + constraint + "?"); + static const common_regex preamble_regex("<\\|channel\\|>commentary"); + static const common_regex tool_call1_regex(recipient + "<\\|channel\\|>(analysis|commentary)" + constraint + "?"); + static const common_regex tool_call2_regex("<\\|channel\\|>(analysis|commentary)" + recipient + constraint + "?"); + + auto consume_end = [&](bool include_end = false) { + if (auto res = builder.try_find_literal("<|end|>")) { + return res->prelude + (include_end ? 
builder.str(res->groups[0]) : ""); + } + return builder.consume_rest(); + }; + + auto handle_tool_call = [&](const std::string & name) { + if (auto args = builder.try_consume_json_with_dumped_args({{}})) { + if (builder.syntax().parse_tool_calls) { + if (!builder.add_tool_call(name, "", args->value) || args->is_partial) { + throw common_chat_msg_partial_exception("incomplete tool call"); + } + } else if (args->is_partial) { + throw common_chat_msg_partial_exception("incomplete tool call"); + } + } + }; + + auto regex_match = [](const common_regex & regex, const std::string & input) -> std::optional { + auto match = regex.search(input, 0, true); + if (match.type == COMMON_REGEX_MATCH_TYPE_FULL) { + return match; + } + return std::nullopt; + }; + + do { + auto header_start_pos = builder.pos(); + auto content_start = builder.try_find_literal("<|message|>"); + if (!content_start) { + throw common_chat_msg_partial_exception("incomplete header"); + } + + auto header = content_start->prelude; + + if (auto match = regex_match(tool_call1_regex, header)) { + auto group = match->groups[1]; + auto name = header.substr(group.begin, group.end - group.begin); + handle_tool_call(name); + continue; + } + + if (auto match = regex_match(tool_call2_regex, header)) { + auto group = match->groups[2]; + auto name = header.substr(group.begin, group.end - group.begin); + handle_tool_call(name); + continue; + } + + if (regex_match(analysis_regex, header)) { + builder.move_to(header_start_pos); + if (builder.syntax().reasoning_format == COMMON_REASONING_FORMAT_NONE || builder.syntax().reasoning_in_content) { + builder.add_content(consume_end(true)); + } else { + builder.try_parse_reasoning("<|channel|>analysis<|message|>", "<|end|>"); + } + continue; + } + + if(regex_match(final_regex, header) || regex_match(preamble_regex, header)) { + builder.add_content(consume_end()); + continue; + } + + // Possibly a malformed message, attempt to recover by rolling + // back to pick up the next <|start|> + LOG_DBG("%s: unknown header from message: %s\n", __func__, header.c_str()); + builder.move_to(header_start_pos); + } while (builder.try_find_regex(start_regex, std::string::npos, false)); + + auto remaining = builder.consume_rest(); + if (!remaining.empty()) { + LOG_DBG("%s: content after last message: %s\n", __func__, remaining.c_str()); + } +} + +static void common_chat_parse_glm_4_5(common_chat_msg_parser & builder) { + static const xml_tool_call_format form { + /* form.scope_start = */ "", + /* form.tool_start = */ "", + /* form.tool_sep = */ "", + /* form.key_start = */ "", + /* form.key_val_sep = */ "", + /* form.val_end = */ "", + /* form.tool_end = */ "", + /* form.scope_end = */ "", + /* form.key_val_sep2 = */ "", + }; + builder.consume_reasoning_with_xml_tool_calls(form, "", ""); +} + +static void common_chat_parse_firefunction_v2(common_chat_msg_parser & builder) { + if (!builder.syntax().parse_tool_calls) { + builder.add_content(builder.consume_rest()); + return; + } + static const common_regex prefix(regex_escape(" functools[")); + parse_prefixed_json_tool_call_array(builder, prefix, /* rstrip_prefix= */ 1); +} + +static void common_chat_parse_functionary_v3_2(common_chat_msg_parser & builder) { + static const common_regex function_regex_start_only(R"((\w+\n\{|python\n|all\n))"); + static const common_regex function_regex(R"(>>>(\w+\n\{|python\n|all\n))"); + static const common_regex close_regex(R"(\s*)"); + + parse_json_tool_calls( + builder, + std::nullopt, + function_regex_start_only, + function_regex, + 
close_regex, + std::nullopt, + /* allow_raw_python= */ true, + /* get_function_name= */ [&](const auto & res) -> std::string { + auto at_start = res.groups[0].begin == 0; + auto name = builder.str(res.groups[1]); + if (!name.empty() && name.back() == '{') { + // Unconsume the opening brace '{' to ensure the JSON parsing goes well. + builder.move_back(1); + } + auto idx = name.find_last_not_of("\n{"); + name = name.substr(0, idx + 1); + if (at_start && name == "all") { + return ""; + } + return name; + }); +} + +static void common_chat_parse_functionary_v3_1_llama_3_1(common_chat_msg_parser & builder) { + if (!builder.syntax().parse_tool_calls) { + builder.add_content(builder.consume_rest()); + return; + } + // This version of Functionary still supports the llama 3.1 tool call format for the python tool. + static const common_regex python_tag_regex(regex_escape("<|python_tag|>")); + + static const common_regex function_regex(R"()"); + static const common_regex close_regex(R"()"); + + parse_json_tool_calls( + builder, + /* block_open= */ std::nullopt, + /* function_regex_start_only= */ std::nullopt, + function_regex, + close_regex, + std::nullopt); + + if (auto res = builder.try_find_regex(python_tag_regex)) { + auto arguments = wrap_code_as_arguments(builder, builder.consume_rest()); + builder.add_tool_call("python", "", arguments); + return; + } +} + +static void common_chat_parse_hermes_2_pro(common_chat_msg_parser & builder) { + builder.try_parse_reasoning("", ""); + if (!builder.syntax().parse_tool_calls) { + builder.add_content(builder.consume_rest()); + return; + } + + static const common_regex open_regex( + "(?:" + "(```(?:xml|json)?\\n\\s*)?" // match 1 (block_start) + "(" // match 2 (open_tag) + "" + "|" + "|" + "|" + "|" + "|" + "|" + "|" + ")?" + "(\\s*\\{\\s*\"name\")" // match 3 (named tool call) + ")" + "|]+)>" // match 4 (function name) + "|" // match 5 (function name again) + ); + + while (auto res = builder.try_find_regex(open_regex)) { + const auto & block_start = res->groups[1]; + std::string block_end = block_start.empty() ? "" : "```"; + + const auto & open_tag = res->groups[2]; + std::string close_tag; + + if (!res->groups[3].empty()) { + builder.move_to(res->groups[3].begin); + close_tag = open_tag.empty() ? 
"" : "value) || tool_call->is_partial) { + throw common_chat_msg_partial_exception("incomplete tool call"); + } + builder.consume_spaces(); + builder.consume_literal(close_tag); + builder.consume_spaces(); + if (!block_end.empty()) { + builder.consume_literal(block_end); + builder.consume_spaces(); + } + } else { + throw common_chat_msg_partial_exception("failed to parse tool call"); + } + } else { + auto function_name = builder.str(res->groups[4]); + if (function_name.empty()) { + function_name = builder.str(res->groups[5]); + } + GGML_ASSERT(!function_name.empty()); + + close_tag = ""; + + if (auto arguments = builder.try_consume_json_with_dumped_args({{}})) { + if (!builder.add_tool_call(function_name, "", arguments->value) || arguments->is_partial) { + throw common_chat_msg_partial_exception("incomplete tool call"); + } + builder.consume_spaces(); + builder.consume_literal(close_tag); + builder.consume_spaces(); + if (!block_end.empty()) { + builder.consume_literal(block_end); + builder.consume_spaces(); + } + } + } + } + + builder.add_content(builder.consume_rest()); +} + +static void common_chat_parse_granite(common_chat_msg_parser & builder) { + // Parse thinking tags + static const common_regex start_think_regex(regex_escape("")); + static const common_regex end_think_regex(regex_escape("")); + // Granite models output partial tokens such as "<" and "groups[0].begin); + builder.try_find_regex(end_think_regex, std::string::npos, false); + // Restore position for try_parse_reasoning() + builder.move_to(res->groups[0].begin); + } + builder.try_parse_reasoning("", ""); + + // Parse response tags + static const common_regex start_response_regex(regex_escape("")); + static const common_regex end_response_regex(regex_escape("")); + // Granite models output partial tokens such as "<" and "")); + if (auto res = builder.try_find_regex(tool_call_regex)) { + builder.move_to(res->groups[0].end); + + // Expect JSON array of tool calls + if (auto tool_call = builder.try_consume_json_with_dumped_args({{{"arguments"}}})) { + if (!builder.add_tool_calls(tool_call->value) || tool_call->is_partial) { + throw common_chat_msg_partial_exception("incomplete tool call"); + } + } + } else { + builder.add_content(builder.consume_rest()); + } +} + +static void common_chat_parse_nemotron_v2(common_chat_msg_parser & builder) { + // Parse thinking tags + builder.try_parse_reasoning("", ""); + if (!builder.syntax().parse_tool_calls) { + builder.add_content(builder.consume_rest()); + return; + } + + // Look for tool calls + static const common_regex tool_call_regex(regex_escape("")); + if (auto res = builder.try_find_regex(tool_call_regex)) { + builder.move_to(res->groups[0].end); + + // Expect JSON array of tool calls + auto tool_calls_data = builder.consume_json(); + if (tool_calls_data.json.is_array()) { + if (!builder.try_consume_literal("")) { + throw common_chat_msg_partial_exception("Incomplete tool call"); + } + builder.add_tool_calls(tool_calls_data.json); + } else { + throw common_chat_msg_partial_exception("Incomplete tool call"); + } + } + builder.add_content(builder.consume_rest()); +} + +static void common_chat_parse_apertus(common_chat_msg_parser & builder) { + // Parse thinking tags + builder.try_parse_reasoning("<|inner_prefix|>", "<|inner_suffix|>"); + if (!builder.syntax().parse_tool_calls) { + builder.add_content(builder.consume_rest()); + return; + } + + // Look for tool calls + static const common_regex tool_call_regex(regex_escape("<|tools_prefix|>")); + if (auto res = 
builder.try_find_regex(tool_call_regex)) { + builder.move_to(res->groups[0].end); + + auto tool_calls_data = builder.consume_json(); + if (tool_calls_data.json.is_array()) { + builder.consume_spaces(); + if (!builder.try_consume_literal("<|tools_suffix|>")) { + throw common_chat_msg_partial_exception("Incomplete tool call"); + } + for (const auto & value : tool_calls_data.json) { + if (value.is_object()) { + builder.add_tool_call_short_form(value); + } + } + } else { + throw common_chat_msg_partial_exception("Incomplete tool call"); + } + } + builder.add_content(builder.consume_rest()); +} + + +static void common_chat_parse_lfm2(common_chat_msg_parser & builder) { + if (!builder.syntax().parse_tool_calls) { + builder.add_content(builder.consume_rest()); + return; + } + + // LFM2 format: <|tool_call_start|>[{"name": "get_current_time", "arguments": {"location": "Paris"}}]<|tool_call_end|> + static const common_regex tool_call_start_regex(regex_escape("<|tool_call_start|>")); + static const common_regex tool_call_end_regex(regex_escape("<|tool_call_end|>")); + + // Loop through all tool calls + while (auto res = builder.try_find_regex(tool_call_start_regex, std::string::npos, /* add_prelude_to_content= */ true)) { + builder.move_to(res->groups[0].end); + + // Parse JSON array format: [{"name": "...", "arguments": {...}}] + auto tool_calls_data = builder.consume_json(); + + // Consume end marker + builder.consume_spaces(); + if (!builder.try_consume_regex(tool_call_end_regex)) { + throw common_chat_msg_partial_exception("Expected <|tool_call_end|>"); + } + + // Process each tool call in the array + if (tool_calls_data.json.is_array()) { + for (const auto & tool_call : tool_calls_data.json) { + if (!tool_call.is_object()) { + throw common_chat_msg_partial_exception("Tool call must be an object"); + } + + if (!tool_call.contains("name")) { + throw common_chat_msg_partial_exception("Tool call missing 'name' field"); + } + + std::string function_name = tool_call.at("name"); + std::string arguments = "{}"; + + if (tool_call.contains("arguments")) { + if (tool_call.at("arguments").is_object()) { + arguments = tool_call.at("arguments").dump(); + } else if (tool_call.at("arguments").is_string()) { + arguments = tool_call.at("arguments"); + } + } + + if (!builder.add_tool_call(function_name, "", arguments)) { + throw common_chat_msg_partial_exception("Incomplete tool call"); + } + } + } else { + throw common_chat_msg_partial_exception("Expected JSON array for tool calls"); + } + + // Consume any trailing whitespace after this tool call + builder.consume_spaces(); + } + + // Consume any remaining content after all tool calls + auto remaining = builder.consume_rest(); + if (!string_strip(remaining).empty()) { + builder.add_content(remaining); + } +} + +static void common_chat_parse_seed_oss(common_chat_msg_parser & builder) { + static const xml_tool_call_format form { + /* form.scope_start = */ "", + /* form.tool_start = */ "", + /* form.key_start = */ "", + /* form.val_end = */ "", + /* form.tool_end = */ "", + /* form.scope_end = */ "", + }; + builder.consume_reasoning_with_xml_tool_calls(form, "", ""); +} + +static void common_chat_parse_content_only(common_chat_msg_parser & builder) { + builder.try_parse_reasoning("", ""); + builder.add_content(builder.consume_rest()); +} + +static void common_chat_parse(common_chat_msg_parser & builder) { + LOG_DBG("Parsing input with format %s: %s\n", common_chat_format_name(builder.syntax().format), builder.input().c_str()); + + switch (builder.syntax().format) { 
+ case COMMON_CHAT_FORMAT_CONTENT_ONLY: + common_chat_parse_content_only(builder); + break; + case COMMON_CHAT_FORMAT_GENERIC: + common_chat_parse_generic(builder); + break; + case COMMON_CHAT_FORMAT_MISTRAL_NEMO: + common_chat_parse_mistral_nemo(builder); + break; + case COMMON_CHAT_FORMAT_MAGISTRAL: + common_chat_parse_magistral(builder); + break; + case COMMON_CHAT_FORMAT_LLAMA_3_X: + common_chat_parse_llama_3_1(builder); + break; + case COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS: + common_chat_parse_llama_3_1(builder, /* with_builtin_tools= */ true); + break; + case COMMON_CHAT_FORMAT_DEEPSEEK_R1: + common_chat_parse_deepseek_r1(builder); + break; + case COMMON_CHAT_FORMAT_DEEPSEEK_V3_1: + common_chat_parse_deepseek_v3_1(builder); + break; + case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2: + common_chat_parse_functionary_v3_2(builder); + break; + case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1: + common_chat_parse_functionary_v3_1_llama_3_1(builder); + break; + case COMMON_CHAT_FORMAT_HERMES_2_PRO: + common_chat_parse_hermes_2_pro(builder); + break; + case COMMON_CHAT_FORMAT_FIREFUNCTION_V2: + common_chat_parse_firefunction_v2(builder); + break; + case COMMON_CHAT_FORMAT_COMMAND_R7B: + common_chat_parse_command_r7b(builder); + break; + case COMMON_CHAT_FORMAT_GRANITE: + common_chat_parse_granite(builder); + break; + case COMMON_CHAT_FORMAT_GPT_OSS: + common_chat_parse_gpt_oss(builder); + break; + case COMMON_CHAT_FORMAT_SEED_OSS: + common_chat_parse_seed_oss(builder); + break; + case COMMON_CHAT_FORMAT_NEMOTRON_V2: + common_chat_parse_nemotron_v2(builder); + break; + case COMMON_CHAT_FORMAT_APERTUS: + common_chat_parse_apertus(builder); + break; + case COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS: + common_chat_parse_lfm2(builder); + break; + case COMMON_CHAT_FORMAT_MINIMAX_M2: + common_chat_parse_minimax_m2(builder); + break; + case COMMON_CHAT_FORMAT_GLM_4_5: + common_chat_parse_glm_4_5(builder); + break; + case COMMON_CHAT_FORMAT_KIMI_K2: + common_chat_parse_kimi_k2(builder); + break; + case COMMON_CHAT_FORMAT_QWEN3_CODER_XML: + common_chat_parse_qwen3_coder_xml(builder); + break; + case COMMON_CHAT_FORMAT_APRIEL_1_5: + common_chat_parse_apriel_1_5(builder); + break; + case COMMON_CHAT_FORMAT_XIAOMI_MIMO: + common_chat_parse_xiaomi_mimo(builder); + break; + default: + throw std::runtime_error(std::string("Unsupported format: ") + common_chat_format_name(builder.syntax().format)); + } + builder.finish(); +} + +common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_syntax & syntax) { + common_chat_msg_parser builder(input, is_partial, syntax); + try { + common_chat_parse(builder); + } catch (const common_chat_msg_partial_exception & ex) { + LOG_DBG("Partial parse: %s\n", ex.what()); + if (!is_partial) { + builder.clear_tools(); + builder.move_to(0); + common_chat_parse_content_only(builder); + } + } + auto msg = builder.result(); + if (!is_partial) { + LOG_DBG("Parsed message: %s\n", common_chat_msgs_to_json_oaicompat({msg}).at(0).dump().c_str()); + } + return msg; +} diff --git a/common/chat.cpp b/common/chat.cpp index 6fa05a6041..b4a0f985e2 100644 --- a/common/chat.cpp +++ b/common/chat.cpp @@ -678,114 +678,6 @@ common_reasoning_format common_reasoning_format_from_name(const std::string & fo throw std::runtime_error("Unknown reasoning format: " + format); } -static std::string wrap_code_as_arguments(common_chat_msg_parser & builder, const std::string & code) { - std::string arguments; - if (builder.is_partial()) { - arguments = (json {{"code", 
code + builder.healing_marker()}}).dump(); - auto idx = arguments.find(builder.healing_marker()); - if (idx != std::string::npos) { - arguments.resize(idx); - } - } else { - arguments = (json {{"code", code}}).dump(); - } - return arguments; -} - -/** - * Takes a prefix regex that must have 1 group to capture the function name, a closing suffix, and expects json parameters in between. - * Aggregates the prefix, suffix and in-between text into the content. - */ -static void parse_json_tool_calls( - common_chat_msg_parser & builder, - const std::optional & block_open, - const std::optional & function_regex_start_only, - const std::optional & function_regex, - const common_regex & close_regex, - const std::optional & block_close, - bool allow_raw_python = false, - const std::function & get_function_name = nullptr) { - - auto parse_tool_calls = [&]() { - size_t from = std::string::npos; - auto first = true; - while (true) { - auto start_pos = builder.pos(); - auto res = function_regex_start_only && first - ? builder.try_consume_regex(*function_regex_start_only) - : function_regex - ? builder.try_find_regex(*function_regex, from) - : std::nullopt; - - if (res) { - std::string name; - if (get_function_name) { - name = get_function_name(*res); - } else { - GGML_ASSERT(res->groups.size() == 2); - name = builder.str(res->groups[1]); - } - first = false; - if (name.empty()) { - // get_function_name signalled us that we should skip this match and treat it as content. - from = res->groups[0].begin + 1; - continue; - } - from = std::string::npos; - - auto maybe_raw_python = name == "python" && allow_raw_python; - if (builder.input()[builder.pos()] == '{' || !maybe_raw_python) { - if (auto arguments = builder.try_consume_json_with_dumped_args({{}})) { - if (!builder.add_tool_call(name, "", arguments->value) || arguments->is_partial) { - throw common_chat_msg_partial_exception("incomplete tool call"); - } - builder.consume_regex(close_regex); - } - continue; - } - if (maybe_raw_python) { - auto arguments = wrap_code_as_arguments(builder, builder.consume_rest()); - if (!builder.add_tool_call(name, "", arguments)) { - throw common_chat_msg_partial_exception("incomplete tool call"); - } - return; - } - throw common_chat_msg_partial_exception("incomplete tool call"); - } else { - builder.move_to(start_pos); - } - break; - } - if (block_close) { - builder.consume_regex(*block_close); - } - builder.consume_spaces(); - builder.add_content(builder.consume_rest()); - }; - if (block_open) { - if (auto res = builder.try_find_regex(*block_open)) { - parse_tool_calls(); - } else { - builder.add_content(builder.consume_rest()); - } - } else { - parse_tool_calls(); - } -} - -static void parse_prefixed_json_tool_call_array(common_chat_msg_parser & builder, const common_regex & prefix, size_t rstrip_prefix = 0) { - static const std::vector> args_paths = {{"arguments"}}; - if (auto res = builder.try_find_regex(prefix)) { - builder.move_back(rstrip_prefix); - auto tool_calls = builder.consume_json_with_dumped_args(args_paths); - if (!builder.add_tool_calls(tool_calls.value) || tool_calls.is_partial) { - throw common_chat_msg_partial_exception("incomplete tool call array"); - } - } else { - builder.add_content(builder.consume_rest()); - } -} - static void foreach_function(const json & tools, const std::function & fn) { for (const auto & tool : tools) { if (!tool.contains("type") || tool.at("type") != "function" || !tool.contains("function")) { @@ -918,37 +810,6 @@ static common_chat_params 
common_chat_params_init_generic(const common_chat_temp data.format = COMMON_CHAT_FORMAT_GENERIC; return data; } -static void common_chat_parse_generic(common_chat_msg_parser & builder) { - if (!builder.syntax().parse_tool_calls) { - builder.add_content(builder.consume_rest()); - return; - } - static const std::vector> content_paths = { - {"response"}, - }; - static const std::vector> args_paths = { - {"tool_call", "arguments"}, - {"tool_calls", "arguments"}, - }; - auto data = builder.consume_json_with_dumped_args(args_paths, content_paths); - if (data.value.contains("tool_calls")) { - if (!builder.add_tool_calls(data.value.at("tool_calls")) || data.is_partial) { - throw common_chat_msg_partial_exception("incomplete tool calls"); - } - } else if (data.value.contains("tool_call")) { - if (!builder.add_tool_call(data.value.at("tool_call")) || data.is_partial) { - throw common_chat_msg_partial_exception("incomplete tool call"); - } - } else if (data.value.contains("response")) { - const auto & response = data.value.at("response"); - builder.add_content(response.is_string() ? response.template get() : response.dump(2)); - if (data.is_partial) { - throw common_chat_msg_partial_exception("incomplete response"); - } - } else { - throw common_chat_msg_partial_exception("Expected 'tool_call', 'tool_calls' or 'response' in JSON"); - } -} static common_chat_params common_chat_params_init_mistral_nemo(const common_chat_template & tmpl, const struct templates_params & inputs) { common_chat_params data; @@ -1173,28 +1034,6 @@ static common_chat_params common_chat_params_init_magistral(const common_chat_te return data; } -static void common_chat_parse_mistral_nemo(common_chat_msg_parser & builder) { - if (!builder.syntax().parse_tool_calls) { - builder.add_content(builder.consume_rest()); - return; - } - - static const common_regex prefix(regex_escape("[TOOL_CALLS]")); - parse_prefixed_json_tool_call_array(builder, prefix); -} - -static void common_chat_parse_magistral(common_chat_msg_parser & builder) { - builder.try_parse_reasoning("[THINK]", "[/THINK]"); - - if (!builder.syntax().parse_tool_calls) { - builder.add_content(builder.consume_rest()); - return; - } - - static const common_regex prefix(regex_escape("[TOOL_CALLS]")); - parse_prefixed_json_tool_call_array(builder, prefix); -} - static common_chat_params common_chat_params_init_command_r7b(const common_chat_template & tmpl, const struct templates_params & inputs) { common_chat_params data; @@ -1275,39 +1114,6 @@ static common_chat_params common_chat_params_init_command_r7b(const common_chat_ return data; } -static void common_chat_parse_command_r7b(common_chat_msg_parser & builder) { - builder.try_parse_reasoning("<|START_THINKING|>", "<|END_THINKING|>"); - - static const common_regex start_action_regex("<\\|START_ACTION\\|>"); - static const common_regex end_action_regex("<\\|END_ACTION\\|>"); - static const common_regex start_response_regex("<\\|START_RESPONSE\\|>"); - static const common_regex end_response_regex("<\\|END_RESPONSE\\|>"); - - if (auto res = builder.try_find_regex(start_action_regex)) { - // If we didn't extract thoughts, prelude includes them. - auto tool_calls = builder.consume_json_with_dumped_args({{"parameters"}}); - for (const auto & tool_call : tool_calls.value) { - std::string name = tool_call.contains("tool_name") ? tool_call.at("tool_name") : ""; - std::string id = tool_call.contains("tool_call_id") ? tool_call.at("tool_call_id") : ""; - std::string arguments = tool_call.contains("parameters") ? 
tool_call.at("parameters") : ""; - if (!builder.add_tool_call(name, id, arguments) || tool_calls.is_partial) { - throw common_chat_msg_partial_exception("incomplete tool call"); - } - } - if (tool_calls.is_partial) { - throw common_chat_msg_partial_exception("incomplete tool call"); - } - builder.consume_regex(end_action_regex); - } else if (auto res = builder.try_find_regex(start_response_regex)) { - if (!builder.try_find_regex(end_response_regex)) { - builder.add_content(builder.consume_rest()); - throw common_chat_msg_partial_exception(end_response_regex.str()); - } - } else { - builder.add_content(builder.consume_rest()); - } -} - static void expect_tool_parameters(const std::string & name, const json & parameters, const std::vector & expected_properties) { if (!parameters.is_object() || !parameters.contains("type") || parameters.at("type") != "object" || !parameters.contains("properties") || !parameters.contains("required")) { throw std::runtime_error("Parameters of tool " + name + " must be an object w/ required properties"); @@ -1536,63 +1342,6 @@ static common_chat_params common_chat_params_init_apertus(const common_chat_temp } return data; } -static void common_chat_parse_llama_3_1(common_chat_msg_parser & builder, bool with_builtin_tools = false) { - builder.try_parse_reasoning("", ""); - - if (!builder.syntax().parse_tool_calls) { - builder.add_content(builder.consume_rest()); - return; - } - - static const common_regex function_regex( - "\\s*\\{\\s*(?:\"type\"\\s*:\\s*\"function\"\\s*,\\s*)?\"name\"\\s*:\\s*\"([^\"]+)\"\\s*,\\s*\"parameters\"\\s*: "); - static const common_regex close_regex("\\}\\s*"); - - static const common_regex function_name_regex("\\s*(\\w+)\\s*\\.\\s*call\\("); - static const common_regex arg_name_regex("\\s*(\\w+)\\s*=\\s*"); - - if (with_builtin_tools) { - static const common_regex builtin_call_regex("<\\|python_tag\\|>"); - if (auto res = builder.try_find_regex(builtin_call_regex)) { - auto fun_res = builder.consume_regex(function_name_regex); - auto function_name = builder.str(fun_res.groups[1]); - - common_healing_marker healing_marker; - json args = json::object(); - while (true) { - if (auto arg_res = builder.try_consume_regex(arg_name_regex)) { - auto arg_name = builder.str(arg_res->groups[1]); - auto partial = builder.consume_json(); - args[arg_name] = partial.json; - healing_marker.marker = partial.healing_marker.marker; - healing_marker.json_dump_marker = partial.healing_marker.json_dump_marker; - builder.consume_spaces(); - if (!builder.try_consume_literal(",")) { - break; - } - } else { - break; - } - } - builder.consume_literal(")"); - builder.consume_spaces(); - - auto arguments = args.dump(); - if (!builder.add_tool_call(function_name, "", arguments)) { - throw common_chat_msg_partial_exception("Incomplete tool call"); - } - return; - } - } - parse_json_tool_calls( - builder, - /* block_open= */ std::nullopt, - /* function_regex_start_only= */ function_regex, - /* function_regex= */ std::nullopt, - close_regex, - std::nullopt); - -} static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_template & tmpl, const struct templates_params & inputs) { common_chat_params data; @@ -1732,88 +1481,6 @@ static common_chat_params common_chat_params_init_deepseek_v3_1(const common_cha return data; } -static void common_chat_parse_deepseek_r1(common_chat_msg_parser & builder) { - builder.try_parse_reasoning("", ""); - if (!builder.syntax().parse_tool_calls) { - builder.add_content(builder.consume_rest()); - return; - } - - static 
const common_regex tool_calls_begin("(?:<|tool▁calls▁begin|>|<|tool_calls_begin|>|<|tool calls begin|>|<|tool\\\\_calls\\\\_begin|>|<|tool▁calls|>)"); - static const common_regex tool_calls_end("<|tool▁calls▁end|>"); - static const common_regex function_regex("(?:<|tool▁call▁begin|>)?function<|tool▁sep|>([^\n]+)\n```json\n"); - static const common_regex close_regex("```[\\s\\r\\n]*<|tool▁call▁end|>"); - - parse_json_tool_calls( - builder, - /* block_open= */ tool_calls_begin, - /* function_regex_start_only= */ std::nullopt, - function_regex, - close_regex, - tool_calls_end); -} - -static void common_chat_parse_deepseek_v3_1_content(common_chat_msg_parser & builder) { - static const common_regex function_regex("(?:<|tool▁call▁begin|>)?([^\\n<]+)(?:<|tool▁sep|>)"); - - static const common_regex close_regex("(?:[\\s]*)?<|tool▁call▁end|>"); - static const common_regex tool_calls_begin("(?:<|tool▁calls▁begin|>|<|tool_calls_begin|>|<|tool calls begin|>|<|tool\\\\_calls\\\\_begin|>|<|tool▁calls|>)"); - static const common_regex tool_calls_end("<|tool▁calls▁end|>"); - - if (!builder.syntax().parse_tool_calls) { - LOG_DBG("%s: not parse_tool_calls\n", __func__); - builder.add_content(builder.consume_rest()); - return; - } - - LOG_DBG("%s: parse_tool_calls\n", __func__); - - parse_json_tool_calls( - builder, - /* block_open= */ tool_calls_begin, - /* function_regex_start_only= */ std::nullopt, - function_regex, - close_regex, - tool_calls_end); -} - -static void common_chat_parse_deepseek_v3_1(common_chat_msg_parser & builder) { - // DeepSeek V3.1 outputs reasoning content between "" and "" tags, followed by regular content - // First try to parse using the standard reasoning parsing method - LOG_DBG("%s: thinking_forced_open: %s\n", __func__, std::to_string(builder.syntax().thinking_forced_open).c_str()); - - auto start_pos = builder.pos(); - auto found_end_think = builder.try_find_literal(""); - builder.move_to(start_pos); - - if (builder.syntax().thinking_forced_open && !builder.is_partial() && !found_end_think) { - LOG_DBG("%s: no end_think, not partial, adding content\n", __func__); - common_chat_parse_deepseek_v3_1_content(builder); - } else if (builder.try_parse_reasoning("", "")) { - // If reasoning was parsed successfully, the remaining content is regular content - LOG_DBG("%s: parsed reasoning, adding content\n", __func__); - // <|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>NAME\n```json\nJSON\n```<|tool▁call▁end|><|tool▁calls▁end|> - common_chat_parse_deepseek_v3_1_content(builder); - } else { - if (builder.syntax().reasoning_format == COMMON_REASONING_FORMAT_NONE) { - LOG_DBG("%s: reasoning_format none, adding content\n", __func__); - common_chat_parse_deepseek_v3_1_content(builder); - return; - } - // If no reasoning tags found, check if we should treat everything as reasoning - if (builder.syntax().thinking_forced_open) { - // If thinking is forced open but no tags found, treat everything as reasoning - LOG_DBG("%s: thinking_forced_open, adding reasoning content\n", __func__); - builder.add_reasoning_content(builder.consume_rest()); - } else { - LOG_DBG("%s: no thinking_forced_open, adding content\n", __func__); - // <|tool▁call▁begin|>NAME<|tool▁sep|>JSON<|tool▁call▁end|> - common_chat_parse_deepseek_v3_1_content(builder); - } - } -} - - static common_chat_params common_chat_params_init_minimax_m2(const common_chat_template & tmpl, const struct templates_params & params) { common_chat_params data; data.grammar_lazy = params.tools.is_array() && !params.tools.empty() && 
params.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED; @@ -1856,20 +1523,6 @@ static common_chat_params common_chat_params_init_minimax_m2(const common_chat_t return data; } -static void common_chat_parse_minimax_m2(common_chat_msg_parser & builder) { - static const xml_tool_call_format form { - /* form.scope_start = */ "", - /* form.tool_start = */ "", - /* form.key_start = */ "", - /* form.val_end = */ "", - /* form.tool_end = */ "", - /* form.scope_end = */ "", - }; - builder.consume_reasoning_with_xml_tool_calls(form, "", ""); -} - static common_chat_params common_chat_params_init_qwen3_coder_xml(const common_chat_template & tmpl, const struct templates_params & params) { common_chat_params data; data.grammar_lazy = params.tools.is_array() && !params.tools.empty() && params.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED; @@ -1902,23 +1555,6 @@ static common_chat_params common_chat_params_init_qwen3_coder_xml(const common_c return data; } -static void common_chat_parse_qwen3_coder_xml(common_chat_msg_parser & builder) { - static const xml_tool_call_format form = ([]() { - xml_tool_call_format form {}; - form.scope_start = ""; - form.tool_start = "", ""); -} - static common_chat_params common_chat_params_init_apriel_1_5(const common_chat_template & tmpl, const struct templates_params & params) { common_chat_params data; data.grammar_lazy = params.tools.is_array() && !params.tools.empty() && params.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED; @@ -2016,25 +1634,6 @@ static common_chat_params common_chat_params_init_apriel_1_5(const common_chat_t return data; } -static void common_chat_parse_apriel_1_5(common_chat_msg_parser & builder) { - static const xml_tool_call_format form = ([]() { - xml_tool_call_format form {}; - form.scope_start = "["; - form.tool_start = "{\"name\": \""; - form.tool_sep = "\", \"arguments\": {"; - form.key_start = "\""; - form.key_val_sep = "\": "; - form.val_end = ", "; - form.tool_end = "}, "; - form.scope_end = "]"; - form.raw_argval = false; - form.last_val_end = ""; - form.last_tool_end = "}"; - return form; - })(); - builder.consume_reasoning_with_xml_tool_calls(form, "", ""); -} - static common_chat_params common_chat_params_init_xiaomi_mimo(const common_chat_template & tmpl, const struct templates_params & params) { common_chat_params data; data.grammar_lazy = params.tools.is_array() && !params.tools.empty() && params.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED; @@ -2067,24 +1666,6 @@ static common_chat_params common_chat_params_init_xiaomi_mimo(const common_chat_ return data; } -static void common_chat_parse_xiaomi_mimo(common_chat_msg_parser & builder) { - static const xml_tool_call_format form = ([]() { - xml_tool_call_format form {}; - form.scope_start = ""; - form.tool_start = "\n{\"name\": \""; - form.tool_sep = "\", \"arguments\": {"; - form.key_start = "\""; - form.key_val_sep = "\": "; - form.val_end = ", "; - form.tool_end = "}\n"; - form.scope_end = ""; - form.raw_argval = false; - form.last_val_end = ""; - return form; - })(); - builder.consume_reasoning_with_xml_tool_calls(form); -} - static common_chat_params common_chat_params_init_gpt_oss(const common_chat_template & tmpl, const struct templates_params & inputs) { common_chat_params data; @@ -2231,93 +1812,6 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp return data; } -static void common_chat_parse_gpt_oss(common_chat_msg_parser & builder) { - static const std::string constraint = "(?: (<\\|constrain\\|>)?([a-zA-Z0-9_-]+))"; - static const 
std::string recipient("(?: to=functions\\.([^<\\s]+))"); - - static const common_regex start_regex("<\\|start\\|>assistant"); - static const common_regex analysis_regex("<\\|channel\\|>analysis"); - static const common_regex final_regex("<\\|channel\\|>final" + constraint + "?"); - static const common_regex preamble_regex("<\\|channel\\|>commentary"); - static const common_regex tool_call1_regex(recipient + "<\\|channel\\|>(analysis|commentary)" + constraint + "?"); - static const common_regex tool_call2_regex("<\\|channel\\|>(analysis|commentary)" + recipient + constraint + "?"); - - auto consume_end = [&](bool include_end = false) { - if (auto res = builder.try_find_literal("<|end|>")) { - return res->prelude + (include_end ? builder.str(res->groups[0]) : ""); - } - return builder.consume_rest(); - }; - - auto handle_tool_call = [&](const std::string & name) { - if (auto args = builder.try_consume_json_with_dumped_args({{}})) { - if (builder.syntax().parse_tool_calls) { - if (!builder.add_tool_call(name, "", args->value) || args->is_partial) { - throw common_chat_msg_partial_exception("incomplete tool call"); - } - } else if (args->is_partial) { - throw common_chat_msg_partial_exception("incomplete tool call"); - } - } - }; - - auto regex_match = [](const common_regex & regex, const std::string & input) -> std::optional { - auto match = regex.search(input, 0, true); - if (match.type == COMMON_REGEX_MATCH_TYPE_FULL) { - return match; - } - return std::nullopt; - }; - - do { - auto header_start_pos = builder.pos(); - auto content_start = builder.try_find_literal("<|message|>"); - if (!content_start) { - throw common_chat_msg_partial_exception("incomplete header"); - } - - auto header = content_start->prelude; - - if (auto match = regex_match(tool_call1_regex, header)) { - auto group = match->groups[1]; - auto name = header.substr(group.begin, group.end - group.begin); - handle_tool_call(name); - continue; - } - - if (auto match = regex_match(tool_call2_regex, header)) { - auto group = match->groups[2]; - auto name = header.substr(group.begin, group.end - group.begin); - handle_tool_call(name); - continue; - } - - if (regex_match(analysis_regex, header)) { - builder.move_to(header_start_pos); - if (builder.syntax().reasoning_format == COMMON_REASONING_FORMAT_NONE || builder.syntax().reasoning_in_content) { - builder.add_content(consume_end(true)); - } else { - builder.try_parse_reasoning("<|channel|>analysis<|message|>", "<|end|>"); - } - continue; - } - - if(regex_match(final_regex, header) || regex_match(preamble_regex, header)) { - builder.add_content(consume_end()); - continue; - } - - // Possibly a malformed message, attempt to recover by rolling - // back to pick up the next <|start|> - LOG_DBG("%s: unknown header from message: %s\n", __func__, header.c_str()); - builder.move_to(header_start_pos); - } while (builder.try_find_regex(start_regex, std::string::npos, false)); - - auto remaining = builder.consume_rest(); - if (!remaining.empty()) { - LOG_DBG("%s: content after last message: %s\n", __func__, remaining.c_str()); - } -} static common_chat_params common_chat_params_init_glm_4_5(const common_chat_template & tmpl, const struct templates_params & inputs) { common_chat_params data; @@ -2398,21 +1892,6 @@ static common_chat_params common_chat_params_init_glm_4_5(const common_chat_temp return data; } -static void common_chat_parse_glm_4_5(common_chat_msg_parser & builder) { - static const xml_tool_call_format form { - /* form.scope_start = */ "", - /* form.tool_start = */ "", - /* 
form.tool_sep = */ "", - /* form.key_start = */ "", - /* form.key_val_sep = */ "", - /* form.val_end = */ "", - /* form.tool_end = */ "", - /* form.scope_end = */ "", - /* form.key_val_sep2 = */ "", - }; - builder.consume_reasoning_with_xml_tool_calls(form, "", ""); -} - static common_chat_params common_chat_params_init_firefunction_v2(const common_chat_template & tmpl, const struct templates_params & inputs) { LOG_DBG("%s\n", __func__); common_chat_params data; @@ -2460,14 +1939,6 @@ static common_chat_params common_chat_params_init_firefunction_v2(const common_c } return data; } -static void common_chat_parse_firefunction_v2(common_chat_msg_parser & builder) { - if (!builder.syntax().parse_tool_calls) { - builder.add_content(builder.consume_rest()); - return; - } - static const common_regex prefix(regex_escape(" functools[")); - parse_prefixed_json_tool_call_array(builder, prefix, /* rstrip_prefix= */ 1); -} static common_chat_params common_chat_params_init_functionary_v3_2(const common_chat_template & tmpl, const struct templates_params & inputs) { // >>>all\nlet's call functions>>>fn1\n{"arg1": 1...}\n>>>fn2\n{"arg1": 1...}... @@ -2518,34 +1989,6 @@ static common_chat_params common_chat_params_init_functionary_v3_2(const common_ } return data; } -static void common_chat_parse_functionary_v3_2(common_chat_msg_parser & builder) { - static const common_regex function_regex_start_only(R"((\w+\n\{|python\n|all\n))"); - static const common_regex function_regex(R"(>>>(\w+\n\{|python\n|all\n))"); - static const common_regex close_regex(R"(\s*)"); - - parse_json_tool_calls( - builder, - std::nullopt, - function_regex_start_only, - function_regex, - close_regex, - std::nullopt, - /* allow_raw_python= */ true, - /* get_function_name= */ [&](const auto & res) -> std::string { - auto at_start = res.groups[0].begin == 0; - auto name = builder.str(res.groups[1]); - if (!name.empty() && name.back() == '{') { - // Unconsume the opening brace '{' to ensure the JSON parsing goes well. - builder.move_back(1); - } - auto idx = name.find_last_not_of("\n{"); - name = name.substr(0, idx + 1); - if (at_start && name == "all") { - return ""; - } - return name; - }); -} static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(const common_chat_template & tmpl, const struct templates_params & inputs) { // https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3-llama3.1.txt @@ -2605,31 +2048,6 @@ static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(con // TODO: if (has_raw_python) return data; } -static void common_chat_parse_functionary_v3_1_llama_3_1(common_chat_msg_parser & builder) { - if (!builder.syntax().parse_tool_calls) { - builder.add_content(builder.consume_rest()); - return; - } - // This version of Functionary still supports the llama 3.1 tool call format for the python tool. 
- static const common_regex python_tag_regex(regex_escape("<|python_tag|>")); - - static const common_regex function_regex(R"()"); - static const common_regex close_regex(R"()"); - - parse_json_tool_calls( - builder, - /* block_open= */ std::nullopt, - /* function_regex_start_only= */ std::nullopt, - function_regex, - close_regex, - std::nullopt); - - if (auto res = builder.try_find_regex(python_tag_regex)) { - auto arguments = wrap_code_as_arguments(builder, builder.consume_rest()); - builder.add_tool_call("python", "", arguments); - return; - } -} static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat_template & tmpl, const struct templates_params & inputs) { common_chat_params data; @@ -2746,83 +2164,6 @@ static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat return data; } -static void common_chat_parse_hermes_2_pro(common_chat_msg_parser & builder) { - builder.try_parse_reasoning("", ""); - if (!builder.syntax().parse_tool_calls) { - builder.add_content(builder.consume_rest()); - return; - } - - static const common_regex open_regex( - "(?:" - "(```(?:xml|json)?\\n\\s*)?" // match 1 (block_start) - "(" // match 2 (open_tag) - "" - "|" - "|" - "|" - "|" - "|" - "|" - "|" - ")?" - "(\\s*\\{\\s*\"name\")" // match 3 (named tool call) - ")" - "|]+)>" // match 4 (function name) - "|" // match 5 (function name again) - ); - - while (auto res = builder.try_find_regex(open_regex)) { - const auto & block_start = res->groups[1]; - std::string block_end = block_start.empty() ? "" : "```"; - - const auto & open_tag = res->groups[2]; - std::string close_tag; - - if (!res->groups[3].empty()) { - builder.move_to(res->groups[3].begin); - close_tag = open_tag.empty() ? "" : "value) || tool_call->is_partial) { - throw common_chat_msg_partial_exception("incomplete tool call"); - } - builder.consume_spaces(); - builder.consume_literal(close_tag); - builder.consume_spaces(); - if (!block_end.empty()) { - builder.consume_literal(block_end); - builder.consume_spaces(); - } - } else { - throw common_chat_msg_partial_exception("failed to parse tool call"); - } - } else { - auto function_name = builder.str(res->groups[4]); - if (function_name.empty()) { - function_name = builder.str(res->groups[5]); - } - GGML_ASSERT(!function_name.empty()); - - close_tag = ""; - - if (auto arguments = builder.try_consume_json_with_dumped_args({{}})) { - if (!builder.add_tool_call(function_name, "", arguments->value) || arguments->is_partial) { - throw common_chat_msg_partial_exception("incomplete tool call"); - } - builder.consume_spaces(); - builder.consume_literal(close_tag); - builder.consume_spaces(); - if (!block_end.empty()) { - builder.consume_literal(block_end); - builder.consume_spaces(); - } - } - } - } - - builder.add_content(builder.consume_rest()); -} static common_chat_params common_chat_params_init_granite(const common_chat_template & tmpl, const struct templates_params & inputs) { common_chat_params data; @@ -2905,190 +2246,6 @@ static common_chat_params common_chat_params_init_granite(const common_chat_temp return data; } -static void common_chat_parse_granite(common_chat_msg_parser & builder) { - // Parse thinking tags - static const common_regex start_think_regex(regex_escape("")); - static const common_regex end_think_regex(regex_escape("")); - // Granite models output partial tokens such as "<" and "groups[0].begin); - builder.try_find_regex(end_think_regex, std::string::npos, false); - // Restore position for try_parse_reasoning() - 
builder.move_to(res->groups[0].begin); - } - builder.try_parse_reasoning("", ""); - - // Parse response tags - static const common_regex start_response_regex(regex_escape("")); - static const common_regex end_response_regex(regex_escape("")); - // Granite models output partial tokens such as "<" and "")); - if (auto res = builder.try_find_regex(tool_call_regex)) { - builder.move_to(res->groups[0].end); - - // Expect JSON array of tool calls - if (auto tool_call = builder.try_consume_json_with_dumped_args({{{"arguments"}}})) { - if (!builder.add_tool_calls(tool_call->value) || tool_call->is_partial) { - throw common_chat_msg_partial_exception("incomplete tool call"); - } - } - } else { - builder.add_content(builder.consume_rest()); - } -} - -static void common_chat_parse_nemotron_v2(common_chat_msg_parser & builder) { - // Parse thinking tags - builder.try_parse_reasoning("", ""); - if (!builder.syntax().parse_tool_calls) { - builder.add_content(builder.consume_rest()); - return; - } - - // Look for tool calls - static const common_regex tool_call_regex(regex_escape("")); - if (auto res = builder.try_find_regex(tool_call_regex)) { - builder.move_to(res->groups[0].end); - - // Expect JSON array of tool calls - auto tool_calls_data = builder.consume_json(); - if (tool_calls_data.json.is_array()) { - if (!builder.try_consume_literal("")) { - throw common_chat_msg_partial_exception("Incomplete tool call"); - } - builder.add_tool_calls(tool_calls_data.json); - } else { - throw common_chat_msg_partial_exception("Incomplete tool call"); - } - } - builder.add_content(builder.consume_rest()); -} - -static void common_chat_parse_apertus(common_chat_msg_parser & builder) { - // Parse thinking tags - builder.try_parse_reasoning("<|inner_prefix|>", "<|inner_suffix|>"); - if (!builder.syntax().parse_tool_calls) { - builder.add_content(builder.consume_rest()); - return; - } - - // Look for tool calls - static const common_regex tool_call_regex(regex_escape("<|tools_prefix|>")); - if (auto res = builder.try_find_regex(tool_call_regex)) { - builder.move_to(res->groups[0].end); - - auto tool_calls_data = builder.consume_json(); - if (tool_calls_data.json.is_array()) { - builder.consume_spaces(); - if (!builder.try_consume_literal("<|tools_suffix|>")) { - throw common_chat_msg_partial_exception("Incomplete tool call"); - } - for (const auto & value : tool_calls_data.json) { - if (value.is_object()) { - builder.add_tool_call_short_form(value); - } - } - } else { - throw common_chat_msg_partial_exception("Incomplete tool call"); - } - } - builder.add_content(builder.consume_rest()); -} - - -static void common_chat_parse_lfm2(common_chat_msg_parser & builder) { - if (!builder.syntax().parse_tool_calls) { - builder.add_content(builder.consume_rest()); - return; - } - - // LFM2 format: <|tool_call_start|>[{"name": "get_current_time", "arguments": {"location": "Paris"}}]<|tool_call_end|> - static const common_regex tool_call_start_regex(regex_escape("<|tool_call_start|>")); - static const common_regex tool_call_end_regex(regex_escape("<|tool_call_end|>")); - - // Loop through all tool calls - while (auto res = builder.try_find_regex(tool_call_start_regex, std::string::npos, /* add_prelude_to_content= */ true)) { - builder.move_to(res->groups[0].end); - - // Parse JSON array format: [{"name": "...", "arguments": {...}}] - auto tool_calls_data = builder.consume_json(); - - // Consume end marker - builder.consume_spaces(); - if (!builder.try_consume_regex(tool_call_end_regex)) { - throw 
common_chat_msg_partial_exception("Expected <|tool_call_end|>"); - } - - // Process each tool call in the array - if (tool_calls_data.json.is_array()) { - for (const auto & tool_call : tool_calls_data.json) { - if (!tool_call.is_object()) { - throw common_chat_msg_partial_exception("Tool call must be an object"); - } - - if (!tool_call.contains("name")) { - throw common_chat_msg_partial_exception("Tool call missing 'name' field"); - } - - std::string function_name = tool_call.at("name"); - std::string arguments = "{}"; - - if (tool_call.contains("arguments")) { - if (tool_call.at("arguments").is_object()) { - arguments = tool_call.at("arguments").dump(); - } else if (tool_call.at("arguments").is_string()) { - arguments = tool_call.at("arguments"); - } - } - - if (!builder.add_tool_call(function_name, "", arguments)) { - throw common_chat_msg_partial_exception("Incomplete tool call"); - } - } - } else { - throw common_chat_msg_partial_exception("Expected JSON array for tool calls"); - } - - // Consume any trailing whitespace after this tool call - builder.consume_spaces(); - } - - // Consume any remaining content after all tool calls - auto remaining = builder.consume_rest(); - if (!string_strip(remaining).empty()) { - builder.add_content(remaining); - } -} - -static void common_chat_parse_seed_oss(common_chat_msg_parser & builder) { - static const xml_tool_call_format form { - /* form.scope_start = */ "", - /* form.tool_start = */ "", - /* form.key_start = */ "", - /* form.val_end = */ "", - /* form.tool_end = */ "", - /* form.scope_end = */ "", - }; - builder.consume_reasoning_with_xml_tool_calls(form, "", ""); -} - static common_chat_params common_chat_params_init_without_tools(const common_chat_template & tmpl, const struct templates_params & inputs) { common_chat_params data; data.prompt = apply(tmpl, inputs); @@ -3428,112 +2585,3 @@ common_chat_params common_chat_templates_apply( ? 
common_chat_templates_apply_jinja(tmpls, inputs) : common_chat_templates_apply_legacy(tmpls, inputs); } - -static void common_chat_parse_content_only(common_chat_msg_parser & builder) { - builder.try_parse_reasoning("", ""); - builder.add_content(builder.consume_rest()); -} - -static void common_chat_parse(common_chat_msg_parser & builder) { - LOG_DBG("Parsing input with format %s: %s\n", common_chat_format_name(builder.syntax().format), builder.input().c_str()); - - switch (builder.syntax().format) { - case COMMON_CHAT_FORMAT_CONTENT_ONLY: - common_chat_parse_content_only(builder); - break; - case COMMON_CHAT_FORMAT_GENERIC: - common_chat_parse_generic(builder); - break; - case COMMON_CHAT_FORMAT_MISTRAL_NEMO: - common_chat_parse_mistral_nemo(builder); - break; - case COMMON_CHAT_FORMAT_MAGISTRAL: - common_chat_parse_magistral(builder); - break; - case COMMON_CHAT_FORMAT_LLAMA_3_X: - common_chat_parse_llama_3_1(builder); - break; - case COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS: - common_chat_parse_llama_3_1(builder, /* with_builtin_tools= */ true); - break; - case COMMON_CHAT_FORMAT_DEEPSEEK_R1: - common_chat_parse_deepseek_r1(builder); - break; - case COMMON_CHAT_FORMAT_DEEPSEEK_V3_1: - common_chat_parse_deepseek_v3_1(builder); - break; - case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2: - common_chat_parse_functionary_v3_2(builder); - break; - case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1: - common_chat_parse_functionary_v3_1_llama_3_1(builder); - break; - case COMMON_CHAT_FORMAT_HERMES_2_PRO: - common_chat_parse_hermes_2_pro(builder); - break; - case COMMON_CHAT_FORMAT_FIREFUNCTION_V2: - common_chat_parse_firefunction_v2(builder); - break; - case COMMON_CHAT_FORMAT_COMMAND_R7B: - common_chat_parse_command_r7b(builder); - break; - case COMMON_CHAT_FORMAT_GRANITE: - common_chat_parse_granite(builder); - break; - case COMMON_CHAT_FORMAT_GPT_OSS: - common_chat_parse_gpt_oss(builder); - break; - case COMMON_CHAT_FORMAT_SEED_OSS: - common_chat_parse_seed_oss(builder); - break; - case COMMON_CHAT_FORMAT_NEMOTRON_V2: - common_chat_parse_nemotron_v2(builder); - break; - case COMMON_CHAT_FORMAT_APERTUS: - common_chat_parse_apertus(builder); - break; - case COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS: - common_chat_parse_lfm2(builder); - break; - case COMMON_CHAT_FORMAT_MINIMAX_M2: - common_chat_parse_minimax_m2(builder); - break; - case COMMON_CHAT_FORMAT_GLM_4_5: - common_chat_parse_glm_4_5(builder); - break; - case COMMON_CHAT_FORMAT_KIMI_K2: - common_chat_parse_kimi_k2(builder); - break; - case COMMON_CHAT_FORMAT_QWEN3_CODER_XML: - common_chat_parse_qwen3_coder_xml(builder); - break; - case COMMON_CHAT_FORMAT_APRIEL_1_5: - common_chat_parse_apriel_1_5(builder); - break; - case COMMON_CHAT_FORMAT_XIAOMI_MIMO: - common_chat_parse_xiaomi_mimo(builder); - break; - default: - throw std::runtime_error(std::string("Unsupported format: ") + common_chat_format_name(builder.syntax().format)); - } - builder.finish(); -} - -common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_syntax & syntax) { - common_chat_msg_parser builder(input, is_partial, syntax); - try { - common_chat_parse(builder); - } catch (const common_chat_msg_partial_exception & ex) { - LOG_DBG("Partial parse: %s\n", ex.what()); - if (!is_partial) { - builder.clear_tools(); - builder.move_to(0); - common_chat_parse_content_only(builder); - } - } - auto msg = builder.result(); - if (!is_partial) { - LOG_DBG("Parsed message: %s\n", common_chat_msgs_to_json_oaicompat({msg}).at(0).dump().c_str()); - } 
- return msg; -} From d82b7a7c1d73c0674698d9601b1bbb0200933f29 Mon Sep 17 00:00:00 2001 From: Aleksei Nikiforov <103434461+AlekseiNikiforovIBM@users.noreply.github.com> Date: Fri, 28 Nov 2025 20:53:01 +0100 Subject: [PATCH 27/33] gguf-py : fix passing non-native endian tensors (editor-gui and new-metadata) (#17553) gguf_new_metadata.py reads data from reader. Reader doesn't byteswap tensors to native endianness. But writer does expect tensors in native endianness to convert them into requested endianness. There are two ways to fix this: update reader and do conversion to native endianness and back, or skip converting endianness in writer in this particular USE-case. gguf_editor_gui.py doesn't allow editing or viewing tensor data. Let's go with skipping excessive byteswapping. If eventually capability to view or edit tensor data is added, tensor data should be instead byteswapped when reading it. --- gguf-py/gguf/gguf_writer.py | 18 ++++++++++++------ gguf-py/gguf/scripts/gguf_editor_gui.py | 2 +- gguf-py/gguf/scripts/gguf_new_metadata.py | 2 +- 3 files changed, 14 insertions(+), 8 deletions(-) diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index 57ca2035fe..8ddd895cb7 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -371,10 +371,13 @@ class GGUFWriter: def add_tensor( self, name: str, tensor: np.ndarray[Any, Any], raw_shape: Sequence[int] | None = None, - raw_dtype: GGMLQuantizationType | None = None, + raw_dtype: GGMLQuantizationType | None = None, tensor_endianess: GGUFEndian | None = None ) -> None: - if (self.endianess == GGUFEndian.BIG and sys.byteorder != 'big') or \ - (self.endianess == GGUFEndian.LITTLE and sys.byteorder != 'little'): + # if tensor endianness is not passed, assume it's native to system + if tensor_endianess is None: + tensor_endianess = GGUFEndian.BIG if sys.byteorder == 'big' else GGUFEndian.LITTLE + + if tensor_endianess != self.endianess: # Don't byteswap inplace since lazy copies cannot handle it tensor = tensor.byteswap(inplace=False) if self.use_temp_file and self.temp_file is None: @@ -397,13 +400,16 @@ class GGUFWriter: if pad != 0: fp.write(bytes([0] * pad)) - def write_tensor_data(self, tensor: np.ndarray[Any, Any]) -> None: + def write_tensor_data(self, tensor: np.ndarray[Any, Any], tensor_endianess: GGUFEndian | None = None) -> None: if self.state is not WriterState.TI_DATA and self.state is not WriterState.WEIGHTS: raise ValueError(f'Expected output file to contain tensor info or weights, got {self.state}') assert self.fout is not None - if (self.endianess == GGUFEndian.BIG and sys.byteorder != 'big') or \ - (self.endianess == GGUFEndian.LITTLE and sys.byteorder != 'little'): + # if tensor endianness is not passed, assume it's native to system + if tensor_endianess is None: + tensor_endianess = GGUFEndian.BIG if sys.byteorder == 'big' else GGUFEndian.LITTLE + + if tensor_endianess != self.endianess: # Don't byteswap inplace since lazy copies cannot handle it tensor = tensor.byteswap(inplace=False) diff --git a/gguf-py/gguf/scripts/gguf_editor_gui.py b/gguf-py/gguf/scripts/gguf_editor_gui.py index 05f4db0f8c..293316afed 100755 --- a/gguf-py/gguf/scripts/gguf_editor_gui.py +++ b/gguf-py/gguf/scripts/gguf_editor_gui.py @@ -1552,7 +1552,7 @@ class GGUFEditorWindow(QMainWindow): # Add tensors (including data) for tensor in self.reader.tensors: - writer.add_tensor(tensor.name, tensor.data, raw_shape=tensor.data.shape, raw_dtype=tensor.tensor_type) + writer.add_tensor(tensor.name, tensor.data, 
raw_shape=tensor.data.shape, raw_dtype=tensor.tensor_type, tensor_endianess=self.reader.endianess) # Write header and metadata writer.open_output_file(Path(file_path)) diff --git a/gguf-py/gguf/scripts/gguf_new_metadata.py b/gguf-py/gguf/scripts/gguf_new_metadata.py index 2fa5800cf7..c67436bad4 100755 --- a/gguf-py/gguf/scripts/gguf_new_metadata.py +++ b/gguf-py/gguf/scripts/gguf_new_metadata.py @@ -94,7 +94,7 @@ def copy_with_new_metadata(reader: gguf.GGUFReader, writer: gguf.GGUFWriter, new writer.write_ti_data_to_file() for tensor in reader.tensors: - writer.write_tensor_data(tensor.data) + writer.write_tensor_data(tensor.data, tensor_endianess=reader.endianess) bar.update(tensor.n_bytes) writer.close() From 59d8d4e96341eb54f362ac3d583ef522566e2a39 Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Sat, 29 Nov 2025 01:39:57 -0600 Subject: [PATCH 28/33] vulkan: improve topk perf for large k, fix overflow in unit tests (#17582) --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 4 +++- tests/test-backend-ops.cpp | 8 +++++--- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 73562bc1be..f3aba8165b 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -10239,7 +10239,9 @@ static void ggml_vk_topk(ggml_backend_vk_context * ctx, vk_context& subctx, cons // Prefer going as small as num_topk_pipelines - 3 for perf reasons. // But if K is larger, then we need a larger workgroup - uint32_t max_pipeline = num_topk_pipelines - 3; + uint32_t max_pipeline = num_topk_pipelines - 1; + uint32_t preferred_pipeline = std::max(num_topk_pipelines - 3, (uint32_t)log2f(float(k)) + 2); + max_pipeline = std::min(preferred_pipeline, max_pipeline); uint32_t min_pipeline = (uint32_t)log2f(float(k)) + 1; // require full subgroup min_pipeline = std::max(min_pipeline, ctx->device->subgroup_size_log2); diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 60bab47b9f..87a61aa122 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1446,14 +1446,14 @@ struct test_case { const uint64_t target_flops_cpu = 8ULL * GFLOP; const uint64_t target_flops_gpu = 100ULL * GFLOP; uint64_t target_flops = is_cpu ? target_flops_cpu : target_flops_gpu; - n_runs = std::min(ggml_graph_size(gf) - ggml_graph_n_nodes(gf), target_flops / op_flops(out)) + 1; + n_runs = (int)std::min(ggml_graph_size(gf) - ggml_graph_n_nodes(gf), target_flops / op_flops(out)) + 1; } else { // based on memory size const size_t GB = 1ULL << 30; const size_t target_size_cpu = 8 * GB; const size_t target_size_gpu = 32 * GB; size_t target_size = is_cpu ? 
target_size_cpu : target_size_gpu; - n_runs = std::min(ggml_graph_size(gf) - ggml_graph_n_nodes(gf), target_size / op_size(out)) + 1; + n_runs = (int)std::min(ggml_graph_size(gf) - ggml_graph_n_nodes(gf), target_size / op_size(out)) + 1; } // duplicate the op @@ -8043,7 +8043,9 @@ static std::vector> make_test_cases_perf() { } test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {65000, 16, 1, 1})); - for (auto k : {1, 10, 40}) { + + test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {2, 1, 1, 1}, 1)); + for (auto k : {1, 10, 40, 400}) { for (auto nrows : {1, 16}) { for (auto cols : {k, 1000, 65000, 200000}) { test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {cols, nrows, 1, 1}, k)); From 47a268ea5000fc0f05fc1c5cd0062efebfe84b92 Mon Sep 17 00:00:00 2001 From: Ruben Ortlam Date: Sat, 29 Nov 2025 09:37:22 +0100 Subject: [PATCH 29/33] Vulkan: MMVQ Integer Dot K-Quant and MUL_MAT_ID support (#16900) * vulkan: split mul_mmq_funcs for mul_mat_vecq use * add mxfp4 mmvq * add q2_k mmvq * add q3_k mmvq * add q4_k and q5_k mmvq * add q6_k mmvq * handle 4x4 quants per mmvq thread * enable MUL_MAT_ID mmvq support * enable subgroup optimizations for mul_mat_vec_id shaders * device tuning * request prealloc_y sync after quantization * fix indentation * fix llvmpipe test failures * fix mul_mat_id mmvq condition * fix unused variable warning --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 244 ++++++++--- .../vulkan-shaders/dequant_funcs.glsl | 7 - .../vulkan-shaders/generic_binary_head.glsl | 7 + .../vulkan-shaders/generic_unary_head.glsl | 7 + .../vulkan-shaders/mul_mat_vec.comp | 1 + .../vulkan-shaders/mul_mat_vec_base.glsl | 2 - .../vulkan-shaders/mul_mat_vec_iface.glsl | 8 +- .../vulkan-shaders/mul_mat_vecq.comp | 62 ++- .../vulkan-shaders/mul_mat_vecq_funcs.glsl | 379 ++++++++++++++++++ .../ggml-vulkan/vulkan-shaders/mul_mmq.comp | 2 - .../vulkan-shaders/mul_mmq_funcs.glsl | 229 +++-------- .../vulkan-shaders/vulkan-shaders-gen.cpp | 22 +- 12 files changed, 682 insertions(+), 288 deletions(-) create mode 100644 ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index f3aba8165b..66dd0bfabd 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -613,9 +613,10 @@ struct vk_device_struct { vk_pipeline pipeline_dequant[GGML_TYPE_COUNT]; vk_pipeline pipeline_dequant_mul_mat_vec_f32_f32[DMMV_WG_SIZE_COUNT][GGML_TYPE_COUNT][mul_mat_vec_max_cols]; vk_pipeline pipeline_dequant_mul_mat_vec_f16_f32[DMMV_WG_SIZE_COUNT][GGML_TYPE_COUNT][mul_mat_vec_max_cols]; - vk_pipeline pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_COUNT]; + vk_pipeline pipeline_dequant_mul_mat_vec_id_f32[DMMV_WG_SIZE_COUNT][GGML_TYPE_COUNT]; vk_pipeline pipeline_dequant_mul_mat_vec_q8_1_f32[DMMV_WG_SIZE_COUNT][GGML_TYPE_COUNT][mul_mat_vec_max_cols]; + vk_pipeline pipeline_dequant_mul_mat_vec_id_q8_1_f32[DMMV_WG_SIZE_COUNT][GGML_TYPE_COUNT]; vk_pipeline pipeline_mul_mat_vec_p021_f16_f32[p021_max_gqa_ratio]; vk_pipeline pipeline_mul_mat_vec_nc_f16_f32; @@ -1611,7 +1612,7 @@ class vk_perf_logger { } if (node->op == GGML_OP_MUL_MAT || node->op == GGML_OP_MUL_MAT_ID) { const uint64_t m = node->src[0]->ne[1]; - const uint64_t n = node->ne[1]; + const uint64_t n = (node->op == GGML_OP_MUL_MAT) ? 
node->ne[1] : node->ne[2]; const uint64_t k = node->src[1]->ne[0]; const uint64_t batch = node->src[1]->ne[2] * node->src[1]->ne[3]; std::string name = ggml_op_name(node->op); @@ -3525,13 +3526,18 @@ static void ggml_vk_load_shaders(vk_device& device) { // the number of rows computed per shader depends on GPU model and quant uint32_t rm_stdq = 1; uint32_t rm_kq = 2; + uint32_t rm_stdq_int = 1; + uint32_t rm_kq_int = 1; if (device->vendor_id == VK_VENDOR_ID_AMD) { if (device->architecture == AMD_GCN) { rm_stdq = 2; rm_kq = 4; + rm_stdq_int = 4; } - } else if (device->vendor_id == VK_VENDOR_ID_INTEL) + } else if (device->vendor_id == VK_VENDOR_ID_INTEL) { rm_stdq = 2; + rm_stdq_int = 2; + } uint32_t rm_iq = 2 * rm_kq; const bool use_subgroups = device->subgroup_arithmetic && device->architecture != vk_device_architecture::AMD_GCN; @@ -3612,39 +3618,73 @@ static void ggml_vk_load_shaders(vk_device& device) { const uint32_t subgroup_size_int = (device->vendor_id == VK_VENDOR_ID_INTEL && device->subgroup_size_control) ? device->subgroup_min_size : device->subgroup_size; const uint32_t wg_size_subgroup_int = (w == DMMV_WG_SIZE_SUBGROUP) ? subgroup_size_int : (subgroup_size_int * 4); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_q8_1_f32", arr_dmmv_q4_0_q8_1_f32_len[reduc], arr_dmmv_q4_0_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup_int, 2*rm_stdq, i+1}, 1, true, use_subgroups, subgroup_size_int); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_q8_1_f32", arr_dmmv_q4_1_q8_1_f32_len[reduc], arr_dmmv_q4_1_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup_int, 2*rm_stdq, i+1}, 1, true, use_subgroups, subgroup_size_int); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_q8_1_f32", arr_dmmv_q5_0_q8_1_f32_len[reduc], arr_dmmv_q5_0_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup_int, 2*rm_stdq, i+1}, 1, true, use_subgroups, subgroup_size_int); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q5_1][i], "mul_mat_vec_q5_1_q8_1_f32", arr_dmmv_q5_1_q8_1_f32_len[reduc], arr_dmmv_q5_1_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup_int, 2*rm_stdq, i+1}, 1, true, use_subgroups, subgroup_size_int); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q8_0][i], "mul_mat_vec_q8_0_q8_1_f32", arr_dmmv_q8_0_q8_1_f32_len[reduc], arr_dmmv_q8_0_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq, i+1}, 1, true, use_subgroups, subgroup_size_int); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_q8_1_f32", arr_dmmv_q4_0_q8_1_f32_len[reduc], arr_dmmv_q4_0_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int, i+1}, 1, true, use_subgroups, subgroup_size_int); + ggml_vk_create_pipeline(device, 
device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_q8_1_f32", arr_dmmv_q4_1_q8_1_f32_len[reduc], arr_dmmv_q4_1_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int, i+1}, 1, true, use_subgroups, subgroup_size_int); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_q8_1_f32", arr_dmmv_q5_0_q8_1_f32_len[reduc], arr_dmmv_q5_0_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int, i+1}, 1, true, use_subgroups, subgroup_size_int); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q5_1][i], "mul_mat_vec_q5_1_q8_1_f32", arr_dmmv_q5_1_q8_1_f32_len[reduc], arr_dmmv_q5_1_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int, i+1}, 1, true, use_subgroups, subgroup_size_int); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q8_0][i], "mul_mat_vec_q8_0_q8_1_f32", arr_dmmv_q8_0_q8_1_f32_len[reduc], arr_dmmv_q8_0_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int, i+1}, 1, true, use_subgroups, subgroup_size_int); + + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_MXFP4][i], "mul_mat_vec_mxfp4_q8_1_f32", arr_dmmv_mxfp4_q8_1_f32_len[reduc], arr_dmmv_mxfp4_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 2*rm_stdq_int, i+1}, 1, true, use_subgroups, subgroup_size_int); + + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q2_K][i], "mul_mat_vec_q2_k_q8_1_f32", arr_dmmv_q2_k_q8_1_f32_len[reduc], arr_dmmv_q2_k_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 2*rm_kq_int, i+1}, 1, true, use_subgroups, subgroup_size_int); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q3_K][i], "mul_mat_vec_q3_k_q8_1_f32", arr_dmmv_q3_k_q8_1_f32_len[reduc], arr_dmmv_q3_k_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int, i+1}, 1, true, use_subgroups, subgroup_size_int); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q4_K][i], "mul_mat_vec_q4_k_q8_1_f32", arr_dmmv_q4_k_q8_1_f32_len[reduc], arr_dmmv_q4_k_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int, i+1}, 1, true, use_subgroups, subgroup_size_int); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q5_K][i], "mul_mat_vec_q5_k_q8_1_f32", arr_dmmv_q5_k_q8_1_f32_len[reduc], arr_dmmv_q5_k_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int, i+1}, 1, true, use_subgroups, subgroup_size_int); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q6_K][i], "mul_mat_vec_q6_k_q8_1_f32", 
arr_dmmv_q6_k_q8_1_f32_len[reduc], arr_dmmv_q6_k_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int, i+1}, 1, true, use_subgroups, subgroup_size_int); } #endif // GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT } + + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_F32 ], "mul_mat_vec_id_f32_f32", arr_dmmv_id_f32_f32_f32_len[reduc], arr_dmmv_id_f32_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {wg_size_subgroup, 2}, 1, false, use_subgroups, force_subgroup_size); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_F16 ], "mul_mat_vec_id_f16_f32", arr_dmmv_id_f16_f32_f32_len[reduc], arr_dmmv_id_f16_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {wg_size_subgroup, 2}, 1, false, use_subgroups, force_subgroup_size); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_BF16], "mul_mat_vec_id_bf16_f32", arr_dmmv_id_bf16_f32_f32_len[reduc], arr_dmmv_id_bf16_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {wg_size_subgroup, 2}, 1, false, use_subgroups, force_subgroup_size); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q4_0], "mul_mat_vec_id_q4_0_f32", arr_dmmv_id_q4_0_f32_f32_len[reduc], arr_dmmv_id_q4_0_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq}, 1, true, use_subgroups, force_subgroup_size); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q4_1], "mul_mat_vec_id_q4_1_f32", arr_dmmv_id_q4_1_f32_f32_len[reduc], arr_dmmv_id_q4_1_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq}, 1, true, use_subgroups, force_subgroup_size); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q5_0], "mul_mat_vec_id_q5_0_f32", arr_dmmv_id_q5_0_f32_f32_len[reduc], arr_dmmv_id_q5_0_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq}, 1, true, use_subgroups, force_subgroup_size); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q5_1], "mul_mat_vec_id_q5_1_f32", arr_dmmv_id_q5_1_f32_f32_len[reduc], arr_dmmv_id_q5_1_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq}, 1, true, use_subgroups, force_subgroup_size); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q8_0], "mul_mat_vec_id_q8_0_f32", arr_dmmv_id_q8_0_f32_f32_len[reduc], arr_dmmv_id_q8_0_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {1*rm_stdq, 1, 1}, {wg_size_subgroup, 1*rm_stdq}, 1, true, use_subgroups, force_subgroup_size); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q2_K], "mul_mat_vec_id_q2_k_f32", arr_dmmv_id_q2_k_f32_f32_len[reduc16], arr_dmmv_id_q2_k_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq}, 1, true, 
use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q3_K], "mul_mat_vec_id_q3_k_f32", arr_dmmv_id_q3_k_f32_f32_len[reduc16], arr_dmmv_id_q3_k_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q4_K], "mul_mat_vec_id_q4_k_f32", arr_dmmv_id_q4_k_f32_f32_len[reduc16], arr_dmmv_id_q4_k_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q5_K], "mul_mat_vec_id_q5_k_f32", arr_dmmv_id_q5_k_f32_f32_len[reduc16], arr_dmmv_id_q5_k_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q6_K], "mul_mat_vec_id_q6_k_f32", arr_dmmv_id_q6_k_f32_f32_len[reduc16], arr_dmmv_id_q6_k_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ1_S], "mul_mat_vec_id_iq1_s_f32", arr_dmmv_id_iq1_s_f32_f32_len[reduc16], arr_dmmv_id_iq1_s_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ1_M], "mul_mat_vec_id_iq1_m_f32", arr_dmmv_id_iq1_m_f32_f32_len[reduc16], arr_dmmv_id_iq1_m_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ2_XXS], "mul_mat_vec_id_iq2_xxs_f32", arr_dmmv_id_iq2_xxs_f32_f32_len[reduc16], arr_dmmv_id_iq2_xxs_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ2_XS], "mul_mat_vec_id_iq2_xs_f32", arr_dmmv_id_iq2_xs_f32_f32_len[reduc16], arr_dmmv_id_iq2_xs_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ2_S], "mul_mat_vec_id_iq2_s_f32", arr_dmmv_id_iq2_s_f32_f32_len[reduc16], arr_dmmv_id_iq2_s_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ3_XXS], 
"mul_mat_vec_id_iq3_xxs_f32", arr_dmmv_id_iq3_xxs_f32_f32_len[reduc16], arr_dmmv_id_iq3_xxs_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ3_S], "mul_mat_vec_id_iq3_s_f32", arr_dmmv_id_iq3_s_f32_f32_len[reduc16], arr_dmmv_id_iq3_s_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ4_XS], "mul_mat_vec_id_iq4_xs_f32", arr_dmmv_id_iq4_xs_f32_f32_len[reduc16], arr_dmmv_id_iq4_xs_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ4_NL], "mul_mat_vec_id_iq4_nl_f32", arr_dmmv_id_iq4_nl_f32_f32_len[reduc16], arr_dmmv_id_iq4_nl_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_MXFP4], "mul_mat_vec_id_mxfp4_f32", arr_dmmv_id_mxfp4_f32_f32_len[reduc16], arr_dmmv_id_mxfp4_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16); + +#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) + if (device->integer_dot_product) { + const uint32_t subgroup_size_int = (device->vendor_id == VK_VENDOR_ID_INTEL && device->subgroup_size_control) ? device->subgroup_min_size : device->subgroup_size; + const uint32_t wg_size_subgroup_int = (w == DMMV_WG_SIZE_SUBGROUP) ? 
subgroup_size_int : (subgroup_size_int * 4); + + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q4_0], "mul_mat_vec_id_q4_0_q8_1_f32", arr_dmmv_id_q4_0_q8_1_f32_len[reduc], arr_dmmv_id_q4_0_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q4_1], "mul_mat_vec_id_q4_1_q8_1_f32", arr_dmmv_id_q4_1_q8_1_f32_len[reduc], arr_dmmv_id_q4_1_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q5_0], "mul_mat_vec_id_q5_0_q8_1_f32", arr_dmmv_id_q5_0_q8_1_f32_len[reduc], arr_dmmv_id_q5_0_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q5_1], "mul_mat_vec_id_q5_1_q8_1_f32", arr_dmmv_id_q5_1_q8_1_f32_len[reduc], arr_dmmv_id_q5_1_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q8_0], "mul_mat_vec_id_q8_0_q8_1_f32", arr_dmmv_id_q8_0_q8_1_f32_len[reduc], arr_dmmv_id_q8_0_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int); + + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_MXFP4], "mul_mat_vec_id_mxfp4_q8_1_f32", arr_dmmv_id_mxfp4_q8_1_f32_len[reduc], arr_dmmv_id_mxfp4_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 2*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int); + + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q2_K], "mul_mat_vec_id_q2_k_q8_1_f32", arr_dmmv_id_q2_k_q8_1_f32_len[reduc], arr_dmmv_id_q2_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 2*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q3_K], "mul_mat_vec_id_q3_k_q8_1_f32", arr_dmmv_id_q3_k_q8_1_f32_len[reduc], arr_dmmv_id_q3_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q4_K], "mul_mat_vec_id_q4_k_q8_1_f32", arr_dmmv_id_q4_k_q8_1_f32_len[reduc], arr_dmmv_id_q4_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int); + 
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q5_K], "mul_mat_vec_id_q5_k_q8_1_f32", arr_dmmv_id_q5_k_q8_1_f32_len[reduc], arr_dmmv_id_q5_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q6_K], "mul_mat_vec_id_q6_k_q8_1_f32", arr_dmmv_id_q6_k_q8_1_f32_len[reduc], arr_dmmv_id_q6_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int); + } +#endif // GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT } - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F32 ], "mul_mat_vec_id_f32_f32", mul_mat_vec_id_f32_f32_len, mul_mat_vec_id_f32_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F16 ], "mul_mat_vec_id_f16_f32", mul_mat_vec_id_f16_f32_len, mul_mat_vec_id_f16_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_BF16], "mul_mat_vec_id_bf16_f32", mul_mat_vec_id_bf16_f32_len, mul_mat_vec_id_bf16_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_0], "mul_mat_vec_id_q4_0_f32", mul_mat_vec_id_q4_0_f32_len, mul_mat_vec_id_q4_0_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_1], "mul_mat_vec_id_q4_1_f32", mul_mat_vec_id_q4_1_f32_len, mul_mat_vec_id_q4_1_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_0], "mul_mat_vec_id_q5_0_f32", mul_mat_vec_id_q5_0_f32_len, mul_mat_vec_id_q5_0_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_1], "mul_mat_vec_id_q5_1_f32", mul_mat_vec_id_q5_1_f32_len, mul_mat_vec_id_q5_1_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q8_0], "mul_mat_vec_id_q8_0_f32", mul_mat_vec_id_q8_0_f32_len, mul_mat_vec_id_q8_0_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {1*rm_stdq, 1, 1}, {device->subgroup_size, 1*rm_stdq}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q2_K], "mul_mat_vec_id_q2_k_f32", mul_mat_vec_id_q2_k_f32_len, mul_mat_vec_id_q2_k_f32_data, "main", mul_mat_vec_id_num_bindings, 
sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q3_K], "mul_mat_vec_id_q3_k_f32", mul_mat_vec_id_q3_k_f32_len, mul_mat_vec_id_q3_k_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_K], "mul_mat_vec_id_q4_k_f32", mul_mat_vec_id_q4_k_f32_len, mul_mat_vec_id_q4_k_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_K], "mul_mat_vec_id_q5_k_f32", mul_mat_vec_id_q5_k_f32_len, mul_mat_vec_id_q5_k_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q6_K], "mul_mat_vec_id_q6_k_f32", mul_mat_vec_id_q6_k_f32_len, mul_mat_vec_id_q6_k_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ1_S], "mul_mat_vec_id_iq1_s_f32", mul_mat_vec_id_iq1_s_f32_len, mul_mat_vec_id_iq1_s_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ1_M], "mul_mat_vec_id_iq1_m_f32", mul_mat_vec_id_iq1_m_f32_len, mul_mat_vec_id_iq1_m_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ2_XXS], "mul_mat_vec_id_iq2_xxs_f32", mul_mat_vec_id_iq2_xxs_f32_len, mul_mat_vec_id_iq2_xxs_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ2_XS], "mul_mat_vec_id_iq2_xs_f32", mul_mat_vec_id_iq2_xs_f32_len, mul_mat_vec_id_iq2_xs_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ2_S], "mul_mat_vec_id_iq2_s_f32", mul_mat_vec_id_iq2_s_f32_len, mul_mat_vec_id_iq2_s_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ3_XXS], "mul_mat_vec_id_iq3_xxs_f32", mul_mat_vec_id_iq3_xxs_f32_len, mul_mat_vec_id_iq3_xxs_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ3_S], "mul_mat_vec_id_iq3_s_f32", mul_mat_vec_id_iq3_s_f32_len, mul_mat_vec_id_iq3_s_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true); - 
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_XS], "mul_mat_vec_id_iq4_xs_f32", mul_mat_vec_id_iq4_xs_f32_len, mul_mat_vec_id_iq4_xs_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_id_iq4_nl_f32", mul_mat_vec_id_iq4_nl_f32_len, mul_mat_vec_id_iq4_nl_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_MXFP4], "mul_mat_vec_id_mxfp4_f32", mul_mat_vec_id_mxfp4_f32_len, mul_mat_vec_id_mxfp4_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true); +#if !defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) + GGML_UNUSED(rm_stdq_int); + GGML_UNUSED(rm_kq_int); +#endif // dequant shaders ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_F32 ], "f32_to_f16", dequant_f32_len, dequant_f32_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); @@ -5453,6 +5493,12 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context * case GGML_TYPE_Q5_0: case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: + case GGML_TYPE_MXFP4: + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + case GGML_TYPE_Q6_K: break; default: return nullptr; @@ -5592,9 +5638,28 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co } } -static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context * ctx, ggml_type a_type, ggml_type b_type) { +static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context * ctx, ggml_type a_type, ggml_type b_type, uint32_t m, uint32_t k) { VK_LOG_DEBUG("ggml_vk_get_dequantize_mul_mat_vec_id()"); - GGML_ASSERT(b_type == GGML_TYPE_F32); + GGML_ASSERT(b_type == GGML_TYPE_F32 || b_type == GGML_TYPE_Q8_1); + + if (b_type == GGML_TYPE_Q8_1) { + switch (a_type) { + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + case GGML_TYPE_MXFP4: + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + case GGML_TYPE_Q6_K: + break; + default: + return nullptr; + } + } switch (a_type) { case GGML_TYPE_F32: @@ -5625,7 +5690,31 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context return nullptr; } - return ctx->device->pipeline_dequant_mul_mat_vec_id_f32[a_type]; + // heuristic to choose workgroup size + uint32_t dmmv_wg = DMMV_WG_SIZE_SUBGROUP; + if ((ctx->device->vendor_id == VK_VENDOR_ID_NVIDIA && ctx->device->architecture != vk_device_architecture::NVIDIA_PRE_TURING) || ctx->device->vendor_id == VK_VENDOR_ID_INTEL) { + // Prefer larger workgroups when M is small, to spread the work out more + // and keep more SMs busy. + // q6_k seems to prefer small workgroup size even for "medium" values of M. 
+ if (a_type == GGML_TYPE_Q6_K) { + if (m < 4096 && k >= 1024) { + dmmv_wg = DMMV_WG_SIZE_LARGE; + } + } else { + if (m <= 8192 && k >= 1024) { + dmmv_wg = DMMV_WG_SIZE_LARGE; + } + } + } + + if (b_type == GGML_TYPE_Q8_1) { + if (ctx->device->vendor_id == VK_VENDOR_ID_INTEL) { + dmmv_wg = DMMV_WG_SIZE_SUBGROUP; + } + return ctx->device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[dmmv_wg][a_type]; + } + + return ctx->device->pipeline_dequant_mul_mat_vec_id_f32[dmmv_wg][a_type]; } static void * ggml_vk_host_malloc(vk_device& device, size_t size) { @@ -6817,20 +6906,35 @@ static bool ggml_vk_should_use_mmvq(const vk_device& device, uint32_t m, uint32_ return false; } + // General performance issue with q3_k and q6_k due to 2-byte alignment + if (src0_type == GGML_TYPE_Q3_K || src0_type == GGML_TYPE_Q6_K) { + return false; + } + // MMVQ is generally good for batches if (n > 1) { return true; } + // Quantization overhead is not worth it for small k switch (device->vendor_id) { case VK_VENDOR_ID_NVIDIA: + if (k <= 4096) { + return false; + } + switch (src0_type) { + case GGML_TYPE_MXFP4: case GGML_TYPE_Q8_0: return device->architecture == vk_device_architecture::NVIDIA_PRE_TURING; default: return true; } case VK_VENDOR_ID_AMD: + if (k < 2048) { + return false; + } + switch (src0_type) { case GGML_TYPE_Q8_0: return device->architecture == vk_device_architecture::AMD_GCN; @@ -6838,6 +6942,10 @@ static bool ggml_vk_should_use_mmvq(const vk_device& device, uint32_t m, uint32_ return true; } case VK_VENDOR_ID_INTEL: + if (k < 2048) { + return false; + } + switch (src0_type) { // From tests on A770 Linux, may need more tuning case GGML_TYPE_Q4_0: @@ -6851,7 +6959,6 @@ static bool ggml_vk_should_use_mmvq(const vk_device& device, uint32_t m, uint32_ } GGML_UNUSED(m); - GGML_UNUSED(k); } static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const struct ggml_cgraph * cgraph, int node_idx) { @@ -7574,7 +7681,7 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& if (x_non_contig || qx_needs_dequant) { ctx->prealloc_x_need_sync = true; } - if (y_non_contig) { + if (y_non_contig || quantize_y) { ctx->prealloc_y_need_sync = true; } } @@ -7600,7 +7707,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte const uint64_t ne10 = src1->ne[0]; const uint64_t ne11 = src1->ne[1]; - // const uint64_t ne12 = src1->ne[2]; + const uint64_t ne12 = src1->ne[2]; // const uint64_t ne13 = src1->ne[3]; const uint64_t nei0 = ids->ne[0]; @@ -7617,19 +7724,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte const bool y_non_contig = !ggml_vk_dim01_contiguous(src1); const bool f16_f32_kernel = src1->type == GGML_TYPE_F32; - - const bool qx_needs_dequant = x_non_contig; - const bool qy_needs_dequant = (src1->type != GGML_TYPE_F16 && !f16_f32_kernel) || y_non_contig; - - // Not implemented - GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT - - const uint64_t x_ne = ggml_nelements(src0); - const uint64_t y_ne = ggml_nelements(src1); - - const uint64_t qx_sz = ggml_vk_align_size(ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type), ctx->device->properties.limits.minStorageBufferOffsetAlignment); - const uint64_t x_sz = x_non_contig ? ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment) : qx_sz; - const uint64_t y_sz = f16_f32_kernel ? 
sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne; + bool quantize_y = ctx->device->integer_dot_product && src1->type == GGML_TYPE_F32 && ggml_is_contiguous(src1) && !y_non_contig && (ne11 * ne10) % 4 == 0 && ggml_vk_should_use_mmvq(ctx->device, ne01, ne12, ne10, src0->type); vk_pipeline to_fp16_vk_0 = nullptr; vk_pipeline to_fp16_vk_1 = nullptr; @@ -7641,11 +7736,38 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte } else { to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type); } - vk_pipeline dmmv = ggml_vk_get_dequantize_mul_mat_vec_id(ctx, src0->type, src1->type); + + // Check for mmq first + vk_pipeline dmmv = quantize_y ? ggml_vk_get_dequantize_mul_mat_vec_id(ctx, src0->type, GGML_TYPE_Q8_1, ne20, ne00) : nullptr; + vk_pipeline to_q8_1 = nullptr; + + if (dmmv == nullptr) { + // Fall back to f16 dequant mul mat + dmmv = ggml_vk_get_dequantize_mul_mat_vec_id(ctx, src0->type, src1->type, ne20, ne00); + quantize_y = false; + } + + if (quantize_y) { + to_q8_1 = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1); + } + + const bool qx_needs_dequant = x_non_contig; + const bool qy_needs_dequant = !quantize_y && ((src1->type != GGML_TYPE_F16 && !f16_f32_kernel) || y_non_contig); + + // Not implemented + GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr); // NOLINT GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT GGML_ASSERT(dmmv != nullptr); + const uint64_t x_ne = ggml_nelements(src0); + const uint64_t y_ne = ggml_nelements(src1); + + const uint64_t qx_sz = ggml_vk_align_size(ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type), ctx->device->properties.limits.minStorageBufferOffsetAlignment); + const uint64_t x_sz = x_non_contig ? ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment) : qx_sz; + const uint64_t y_sz = quantize_y ? (ggml_vk_align_size(y_ne, 128) * ggml_type_size(GGML_TYPE_Q8_1) / ggml_blck_size(GGML_TYPE_Q8_1)) : + (f16_f32_kernel ? 
sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne); + { if ( (qx_needs_dequant && x_sz > ctx->device->properties.limits.maxStorageBufferRange) || @@ -7656,7 +7778,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte ctx->prealloc_size_x = x_sz; ggml_vk_preallocate_buffers(ctx, subctx); } - if (qy_needs_dequant && ctx->prealloc_size_y < y_sz) { + if ((qy_needs_dequant || quantize_y) && ctx->prealloc_size_y < y_sz) { ctx->prealloc_size_y = y_sz; ggml_vk_preallocate_buffers(ctx, subctx); } @@ -7668,6 +7790,9 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte if (qy_needs_dequant) { ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_1, 1); } + if (quantize_y) { + ggml_pipeline_request_descriptor_sets(ctx, to_q8_1, 1); + } ggml_pipeline_request_descriptor_sets(ctx, dmmv, 1); } @@ -7683,7 +7808,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte } else { d_X = d_Qx; } - if (qy_needs_dequant) { + if (qy_needs_dequant || quantize_y) { d_Y = { ctx->prealloc_y, 0, ctx->prealloc_y->size }; } else { d_Y = d_Qy; @@ -7711,6 +7836,17 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte ctx->prealloc_y_last_tensor_used = src1; } } + if (quantize_y) { + if (ctx->prealloc_y_last_pipeline_used != to_q8_1.get() || + ctx->prealloc_y_last_tensor_used != src1) { + if (ctx->prealloc_y_need_sync) { + ggml_vk_sync_buffers(ctx, subctx); + } + ggml_vk_quantize_q8_1(ctx, subctx, d_Qy, d_Y, y_ne); + ctx->prealloc_y_last_pipeline_used = to_q8_1.get(); + ctx->prealloc_y_last_tensor_used = src1; + } + } uint32_t stride_batch_y = ne10*ne11; @@ -7772,7 +7908,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte if (x_non_contig) { ctx->prealloc_x_need_sync = true; } - if (y_non_contig) { + if (y_non_contig || quantize_y) { ctx->prealloc_y_need_sync = true; } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl index 09676a623b..70ee542d96 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl @@ -4,13 +4,6 @@ #include "types.glsl" -#if defined(A_TYPE_PACKED16) -layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];}; -#endif -#if defined(A_TYPE_PACKED32) -layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];}; -#endif - #if defined(DATA_A_F32) vec2 dequantize(uint ib, uint iqs, uint a_offset) { return vec2(data_a[a_offset + ib], data_a[a_offset + ib + 1]); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.glsl index c1ad517256..ba7909c4d3 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.glsl @@ -22,6 +22,13 @@ layout (push_constant) uniform parameter #if !RMS_NORM_ROPE_FUSION layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +#if defined(A_TYPE_PACKED16) +layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];}; +#endif +#if defined(A_TYPE_PACKED32) +layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];}; +#endif + layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; #endif diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/generic_unary_head.glsl 
b/ggml/src/ggml-vulkan/vulkan-shaders/generic_unary_head.glsl index 8dc9d360d5..cc181fda87 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/generic_unary_head.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/generic_unary_head.glsl @@ -18,6 +18,13 @@ layout (push_constant) uniform parameter } p; layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +#if defined(A_TYPE_PACKED16) +layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];}; +#endif +#if defined(A_TYPE_PACKED32) +layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];}; +#endif + layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; uint get_idx() { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp index 9a03925cfd..b3c96576de 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp @@ -3,6 +3,7 @@ #extension GL_EXT_shader_explicit_arithmetic_types_int32 : require #include "mul_mat_vec_base.glsl" +#include "dequant_funcs.glsl" layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl index e4651a683b..cfc8b0c7f4 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl @@ -13,8 +13,6 @@ #include "mul_mat_vec_iface.glsl" -#include "dequant_funcs.glsl" - layout (push_constant) uniform parameter { uint ncols; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl index 14ab1fd74c..337dbd796a 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl @@ -5,13 +5,15 @@ #define MAT_VEC_FUSION_FLAGS_SCALE0 0x4 #define MAT_VEC_FUSION_FLAGS_SCALE1 0x8 -#ifndef MMQ layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; #if defined(A_TYPE_VEC4) layout (binding = 0) readonly buffer AV4 {A_TYPE_VEC4 data_a_v4[];}; #endif -#else -layout (binding = 0) readonly buffer A {A_TYPE_PACKED16 data_a[];}; +#if defined(A_TYPE_PACKED16) +layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];}; +#endif +#if defined(A_TYPE_PACKED32) +layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];}; #endif layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp index 64293f6eca..15f005be3e 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp @@ -10,60 +10,56 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; +#if defined(DATA_A_QUANT_LEGACY) || defined(DATA_A_MXFP4) #define K_PER_ITER 8 - -#include "mul_mmq_funcs.glsl" +#elif defined(DATA_A_QUANT_K) +#define K_PER_ITER 16 +#else +#error unimplemented +#endif uint a_offset, b_offset, d_offset; -int32_t cache_b_qs[2]; +int32_t cache_b_qs[K_PER_ITER / 4]; vec2 cache_b_ds; +#include "mul_mat_vecq_funcs.glsl" + void iter(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const uint num_rows, const uint tid, const uint i) { [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { const uint col = i*BLOCK_SIZE + tid*K_PER_ITER; // Preload data_b block const uint 
b_block_idx = (j*p.batch_stride_b + col) / QUANT_K_Q8_1 + b_offset; - const uint b_qs_idx = tid % 4; + const uint b_qs_idx = tid % (32 / K_PER_ITER); const uint b_block_idx_outer = b_block_idx / 4; const uint b_block_idx_inner = b_block_idx % 4; cache_b_ds = vec2(data_b[b_block_idx_outer].ds[b_block_idx_inner]); #if QUANT_R == 2 + // Assumes K_PER_ITER == 8 cache_b_qs[0] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx]; cache_b_qs[1] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx + 4]; #else +#if K_PER_ITER == 8 cache_b_qs[0] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 2]; cache_b_qs[1] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 2 + 1]; +#elif K_PER_ITER == 16 + cache_b_qs[0] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 4 ]; + cache_b_qs[1] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 4 + 1]; + cache_b_qs[2] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 4 + 2]; + cache_b_qs[3] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 4 + 3]; +#else +#error unimplemented +#endif #endif uint ibi = first_row*p.ncols; [[unroll]] for (uint n = 0; n < num_rows; ++n) { - const uint a_block_idx = (ibi + col)/QUANT_K + a_offset; + const uint a_block_idx = (ibi + col)/QUANT_K_Q8_1 + a_offset; ibi += p.ncols; - int32_t q_sum = 0; -#if QUANT_R == 2 - const i32vec2 data_a_qs = repack(a_block_idx, b_qs_idx); - q_sum += dotPacked4x8EXT(data_a_qs.x, - cache_b_qs[0]); - q_sum += dotPacked4x8EXT(data_a_qs.y, - cache_b_qs[1]); -#else - int32_t data_a_qs = repack(a_block_idx, b_qs_idx * 2); - q_sum += dotPacked4x8EXT(data_a_qs, - cache_b_qs[0]); - data_a_qs = repack(a_block_idx, b_qs_idx * 2 + 1); - q_sum += dotPacked4x8EXT(data_a_qs, - cache_b_qs[1]); -#endif - -#if QUANT_AUXF == 1 - temp[j][n] += mul_q8_1(q_sum, get_d(a_block_idx), cache_b_ds, 4); -#else - temp[j][n] += mul_q8_1(q_sum, get_dm(a_block_idx), cache_b_ds, 4); -#endif + temp[j][n] += mmvq_dot_product(a_block_idx, b_qs_idx); } } } @@ -72,7 +68,7 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { const uint tid = gl_LocalInvocationID.x; get_offsets(a_offset, b_offset, d_offset); - a_offset /= QUANT_K; + a_offset /= QUANT_K_Q8_1; b_offset /= QUANT_K_Q8_1; FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; @@ -102,14 +98,6 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { unroll_count = 2; unrolled_iters = num_iters & ~(unroll_count - 1); -#if K_PER_ITER == 2 - if ((p.ncols & 1) != 0 && - unrolled_iters == num_iters && - unrolled_iters > 0) { - unrolled_iters -= unroll_count; - } -#endif - while (i < unrolled_iters) { // Manually partially unroll the loop [[unroll]] for (uint k = 0; k < unroll_count; ++k) { @@ -128,6 +116,10 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { void main() { const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z); +#ifdef NEEDS_INIT_IQ_SHMEM + init_iq_shmem(gl_WorkGroupSize); +#endif + // do NUM_ROWS at a time, unless there aren't enough remaining rows if (first_row + NUM_ROWS <= p.stride_d) { compute_outputs(first_row, NUM_ROWS); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl new file mode 100644 index 0000000000..2389ea0b1e --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl @@ -0,0 +1,379 @@ +#extension GL_EXT_shader_explicit_arithmetic_types_int32 : 
require +#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require +#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require + +#include "types.glsl" + +#if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ1_S) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL) +FLOAT_TYPE get_dm(uint ib) { + return FLOAT_TYPE(data_a[ib].d); +} +#endif + +#if defined(DATA_A_Q4_1) || defined(DATA_A_Q5_1) +FLOAT_TYPE_VEC2 get_dm(uint ib) { + return FLOAT_TYPE_VEC2(data_a_packed32[ib].dm); +} +#endif + +#if defined(DATA_A_MXFP4) +FLOAT_TYPE get_dm(uint ib) { + return FLOAT_TYPE(e8m0_to_fp32(data_a[ib].e)); +} +#endif + +#if defined(DATA_A_Q2_K) +FLOAT_TYPE_VEC2 get_dm(uint ib) { + const uint ib_k = ib / 8; + return FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm); +} +#endif + +// Each iqs value maps to a 32-bit integer +#if defined(DATA_A_Q4_0) +// 2-byte loads for Q4_0 blocks (18 bytes) +i32vec2 repack(uint ib, uint iqs) { + const u16vec2 quants = u16vec2(data_a_packed16[ib].qs[iqs * 2 ], + data_a_packed16[ib].qs[iqs * 2 + 1]); + const uint32_t vui = pack32(quants); + return i32vec2( vui & 0x0F0F0F0F, + (vui >> 4) & 0x0F0F0F0F); +} + +FLOAT_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) { + return FLOAT_TYPE(da * (float(q_sum) * dsb.x - (8 / sum_divisor) * dsb.y)); +} +#endif + +#if defined(DATA_A_Q4_1) +// 4-byte loads for Q4_1 blocks (20 bytes) +i32vec2 repack(uint ib, uint iqs) { + const uint32_t vui = data_a_packed32[ib].qs[iqs]; + return i32vec2( vui & 0x0F0F0F0F, + (vui >> 4) & 0x0F0F0F0F); +} + +FLOAT_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) { + return FLOAT_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y / sum_divisor); +} +#endif + +#if defined(DATA_A_Q5_0) +// 2-byte loads for Q5_0 blocks (22 bytes) +i32vec2 repack(uint ib, uint iqs) { + const u16vec2 quants = u16vec2(data_a_packed16[ib].qs[iqs * 2 ], + data_a_packed16[ib].qs[iqs * 2 + 1]); + const uint32_t vui = pack32(quants); + const int32_t qh = int32_t((uint32_t(data_a_packed16[ib].qh[1]) << 16 | data_a_packed16[ib].qh[0]) >> (4 * iqs)); + const int32_t v0 = int32_t(vui & 0x0F0F0F0F) + | ((qh & 0xF) * 0x02040810) & 0x10101010; // (0,1,2,3) -> (4,12,20,28) + + const int32_t v1 = int32_t((vui >> 4) & 0x0F0F0F0F) + | (((qh >> 16) & 0xF) * 0x02040810) & 0x10101010; // (16,17,18,19) -> (4,12,20,28) + + return i32vec2(v0, v1); +} + +FLOAT_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) { + return FLOAT_TYPE(da * (float(q_sum) * dsb.x - (16 / sum_divisor) * dsb.y)); +} +#endif + +#if defined(DATA_A_Q5_1) +// 4-byte loads for Q5_1 blocks (24 bytes) +i32vec2 repack(uint ib, uint iqs) { + const u16vec2 quants = u16vec2(data_a_packed16[ib].qs[iqs * 2 ], + data_a_packed16[ib].qs[iqs * 2 + 1]); + const uint32_t vui = pack32(quants); + const int32_t qh = int32_t(data_a_packed32[ib].qh >> (4 * iqs)); + const int32_t v0 = int32_t(vui & 0x0F0F0F0F) + | ((qh & 0xF) * 0x02040810) & 0x10101010; // (0,1,2,3) -> (4,12,20,28) + + const int32_t v1 = int32_t((vui >> 4) & 0x0F0F0F0F) + | (((qh >> 16) & 0xF) * 0x02040810) & 0x10101010; // (16,17,18,19) -> (4,12,20,28) + + return i32vec2(v0, v1); +} + +FLOAT_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) { + return FLOAT_TYPE(float(q_sum) * dma.x * 
dsb.x + dma.y * dsb.y / sum_divisor); +} +#endif + +#if defined(DATA_A_Q8_0) +// 2-byte loads for Q8_0 blocks (34 bytes) +int32_t repack(uint ib, uint iqs) { + return pack32(i16vec2(data_a_packed16[ib].qs[iqs * 2 ], + data_a_packed16[ib].qs[iqs * 2 + 1])); +} + +FLOAT_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) { + return FLOAT_TYPE(float(q_sum) * da * dsb.x); +} +#endif + +#if defined(DATA_A_MXFP4) +// 1-byte loads for mxfp4 blocks (17 bytes) +i32vec2 repack(uint ib, uint iqs) { + const uint32_t qs = pack32(u8vec4(data_a[ib].qs[iqs * 4 ], + data_a[ib].qs[iqs * 4 + 1], + data_a[ib].qs[iqs * 4 + 2], + data_a[ib].qs[iqs * 4 + 3])); + + const u8vec4 i_a0 = unpack8( qs & 0x0F0F0F0F); + const u8vec4 i_a1 = unpack8((qs >> 4) & 0x0F0F0F0F); + + return i32vec2(pack32(i8vec4(kvalues_mxfp4[i_a0.x], kvalues_mxfp4[i_a0.y], kvalues_mxfp4[i_a0.z], kvalues_mxfp4[i_a0.w])), + pack32(i8vec4(kvalues_mxfp4[i_a1.x], kvalues_mxfp4[i_a1.y], kvalues_mxfp4[i_a1.z], kvalues_mxfp4[i_a1.w]))); +} + +FLOAT_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) { + return FLOAT_TYPE(da * dsb.x * float(q_sum) * 0.5); +} +#endif + +#if defined(DATA_A_QUANT_LEGACY) || defined(DATA_A_MXFP4) +FLOAT_TYPE mmvq_dot_product(const uint ib_a, const uint iqs) { + int32_t q_sum = 0; +#if QUANT_R == 2 + const i32vec2 data_a_qs = repack(ib_a, iqs); + q_sum += dotPacked4x8EXT(data_a_qs.x, + cache_b_qs[0]); + q_sum += dotPacked4x8EXT(data_a_qs.y, + cache_b_qs[1]); +#else + int32_t data_a_qs = repack(ib_a, iqs * 2); + q_sum += dotPacked4x8EXT(data_a_qs, + cache_b_qs[0]); + data_a_qs = repack(ib_a, iqs * 2 + 1); + q_sum += dotPacked4x8EXT(data_a_qs, + cache_b_qs[1]); +#endif + + // 2 quants per call => divide sums by 8/2 = 4 + return mul_q8_1(q_sum, get_dm(ib_a), cache_b_ds, 4); +} +#endif + +#if defined(DATA_A_Q2_K) +// 4-byte loads for Q2_K blocks (84 bytes) +i32vec4 repack4(uint ib, uint iqs) { + const uint ib_k = ib / 8; + const uint iqs_k = (ib % 8) * 8 + iqs; + + const uint qs_idx = (iqs_k / 32) * 8 + (iqs_k % 8); + const uint qs_shift = ((iqs_k % 32) / 8) * 2; + + return i32vec4((data_a_packed32[ib_k].qs[qs_idx ] >> qs_shift) & 0x03030303, + (data_a_packed32[ib_k].qs[qs_idx + 1] >> qs_shift) & 0x03030303, + (data_a_packed32[ib_k].qs[qs_idx + 2] >> qs_shift) & 0x03030303, + (data_a_packed32[ib_k].qs[qs_idx + 3] >> qs_shift) & 0x03030303); +} + +uint8_t get_scale(uint ib, uint iqs) { + const uint ib_k = ib / 8; + const uint iqs_k = (ib % 8) * 8 + iqs; + + return data_a[ib_k].scales[iqs_k / 4]; +} + +FLOAT_TYPE mmvq_dot_product(const uint ib_a, const uint iqs) { + int32_t sum_d = 0; + int32_t sum_m = 0; + + const i32vec4 qs_a = repack4(ib_a, iqs * 4); + const uint8_t scale = get_scale(ib_a, iqs * 4); + const vec2 dm = vec2(get_dm(ib_a)); + const int32_t scale_m = int32_t(scale >> 4) * 0x01010101; // Duplicate 8-bit value across 32-bits. 
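+ // Q2_K keeps a 4-bit sub-block scale in the low nibble and a 4-bit min in the high nibble of each scales byte,
+ // so the Q8_1 dot product is split below: sum_d gathers the quant term weighted by (scale & 0xF), sum_m gathers
+ // the min term via the broadcast scale_m, and the result is dm.x * sum_d - dm.y * sum_m times the b-block scale.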
+ + sum_d += dotPacked4x8EXT(qs_a.x, cache_b_qs[0]) * (scale & 0xF); + sum_m += dotPacked4x8EXT(scale_m, cache_b_qs[0]); + + sum_d += dotPacked4x8EXT(qs_a.y, cache_b_qs[1]) * (scale & 0xF); + sum_m += dotPacked4x8EXT(scale_m, cache_b_qs[1]); + + sum_d += dotPacked4x8EXT(qs_a.z, cache_b_qs[2]) * (scale & 0xF); + sum_m += dotPacked4x8EXT(scale_m, cache_b_qs[2]); + + sum_d += dotPacked4x8EXT(qs_a.w, cache_b_qs[3]) * (scale & 0xF); + sum_m += dotPacked4x8EXT(scale_m, cache_b_qs[3]); + + return FLOAT_TYPE(float(cache_b_ds.x) * (float(dm.x) * float(sum_d) - float(dm.y) * float(sum_m))); +} +#endif + +#if defined(DATA_A_Q3_K) +// 2-byte loads for Q3_K blocks (110 bytes) +i32vec4 repack4(uint ib, uint iqs) { + const uint ib_k = ib / 8; + const uint iqs_k = (ib % 8) * 8 + iqs; + + const uint qs_idx = (iqs_k / 32) * 8 + (iqs_k % 8); + const uint qs_shift = ((iqs_k % 32) / 8) * 2; + const uint hm_shift = iqs_k / 8; + + // bitwise OR to add 4 if hmask is set, subtract later + const i8vec2 vals00 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 ] >> qs_shift) & uint16_t(0x0303))) | + unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 ] >> hm_shift) & uint16_t(0x0101)) << 2)); + const i8vec2 vals01 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 1] >> qs_shift) & uint16_t(0x0303))) | + unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 + 1] >> hm_shift) & uint16_t(0x0101)) << 2)); + const i8vec2 vals10 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 2] >> qs_shift) & uint16_t(0x0303))) | + unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 + 2] >> hm_shift) & uint16_t(0x0101)) << 2)); + const i8vec2 vals11 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 3] >> qs_shift) & uint16_t(0x0303))) | + unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 + 3] >> hm_shift) & uint16_t(0x0101)) << 2)); + const i8vec2 vals20 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 4] >> qs_shift) & uint16_t(0x0303))) | + unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 + 4] >> hm_shift) & uint16_t(0x0101)) << 2)); + const i8vec2 vals21 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 5] >> qs_shift) & uint16_t(0x0303))) | + unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 + 5] >> hm_shift) & uint16_t(0x0101)) << 2)); + const i8vec2 vals30 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 6] >> qs_shift) & uint16_t(0x0303))) | + unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 + 6] >> hm_shift) & uint16_t(0x0101)) << 2)); + const i8vec2 vals31 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 7] >> qs_shift) & uint16_t(0x0303))) | + unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 + 7] >> hm_shift) & uint16_t(0x0101)) << 2)); + + return i32vec4(pack32(i8vec4(vals00.x, vals00.y, vals01.x, vals01.y) - int8_t(4)), + pack32(i8vec4(vals10.x, vals10.y, vals11.x, vals11.y) - int8_t(4)), + pack32(i8vec4(vals20.x, vals20.y, vals21.x, vals21.y) - int8_t(4)), + pack32(i8vec4(vals30.x, vals30.y, vals31.x, vals31.y) - int8_t(4))); +} + +float get_d_scale(uint ib, uint iqs) { + const uint ib_k = ib / 8; + const uint iqs_k = (ib % 8) * 8 + iqs; + const uint is = iqs_k / 4; + + const int8_t scale = int8_t(((data_a[ib_k].scales[is % 8 ] >> (4 * (is / 8))) & 0x0F0F) | + (((data_a[ib_k].scales[8 + (is % 4)] >> (2 * (is / 4))) & 0x0303) << 4)); + return float(data_a[ib_k].d) * float(scale - 32); +} + +FLOAT_TYPE mmvq_dot_product(const uint ib_a, const uint iqs) { + int32_t q_sum = 0; + + const i32vec4 qs_a = repack4(ib_a, iqs * 4); + const float d_scale 
= get_d_scale(ib_a, iqs * 4); + + q_sum += dotPacked4x8EXT(qs_a.x, cache_b_qs[0]); + q_sum += dotPacked4x8EXT(qs_a.y, cache_b_qs[1]); + q_sum += dotPacked4x8EXT(qs_a.z, cache_b_qs[2]); + q_sum += dotPacked4x8EXT(qs_a.w, cache_b_qs[3]); + + return FLOAT_TYPE(float(cache_b_ds.x) * d_scale * float(q_sum)); +} +#endif + +#if defined(DATA_A_Q4_K) || defined(DATA_A_Q5_K) +// 4-byte loads for Q4_K blocks (144 bytes) and Q5_K blocks (176 bytes) +i32vec4 repack4(uint ib, uint iqs) { + const uint ib_k = ib / 8; + const uint iqs_k = (ib % 8) * 8 + iqs; + + const uint qs_idx = (iqs_k / 16) * 8 + (iqs_k % 8); + const uint qs_shift = ((iqs_k % 16) / 8) * 4; + +#if defined(DATA_A_Q4_K) + const uint32_t vals0 = (data_a_packed32[ib_k].qs[qs_idx ] >> qs_shift) & 0x0F0F0F0F; + const uint32_t vals1 = (data_a_packed32[ib_k].qs[qs_idx + 1] >> qs_shift) & 0x0F0F0F0F; + const uint32_t vals2 = (data_a_packed32[ib_k].qs[qs_idx + 2] >> qs_shift) & 0x0F0F0F0F; + const uint32_t vals3 = (data_a_packed32[ib_k].qs[qs_idx + 3] >> qs_shift) & 0x0F0F0F0F; + + return i32vec4(vals0, vals1, vals2, vals3); +#else // defined(DATA_A_Q5_K) + const uint qh_idx = iqs; + const uint qh_shift = iqs_k / 8; + + return i32vec4(((data_a_packed32[ib_k].qs[qs_idx ] >> qs_shift) & 0x0F0F0F0F) | + (((data_a_packed32[ib_k].qh[qh_idx ] >> qh_shift) & 0x01010101) << 4), + ((data_a_packed32[ib_k].qs[qs_idx + 1] >> qs_shift) & 0x0F0F0F0F) | + (((data_a_packed32[ib_k].qh[qh_idx + 1] >> qh_shift) & 0x01010101) << 4), + ((data_a_packed32[ib_k].qs[qs_idx + 2] >> qs_shift) & 0x0F0F0F0F) | + (((data_a_packed32[ib_k].qh[qh_idx + 2] >> qh_shift) & 0x01010101) << 4), + ((data_a_packed32[ib_k].qs[qs_idx + 3] >> qs_shift) & 0x0F0F0F0F) | + (((data_a_packed32[ib_k].qh[qh_idx + 3] >> qh_shift) & 0x01010101) << 4)); +#endif +} + +vec2 get_dm_scale(uint ib, uint iqs) { + const uint ib_k = ib / 8; + const uint iqs_k = (ib % 8) * 8 + iqs; + const uint is = iqs_k / 8; + u8vec2 scale_dm; + if (is < 4) { + scale_dm = u8vec2(data_a[ib_k].scales[is] & 0x3F, data_a[ib_k].scales[is + 4] & 0x3F); + } else { + scale_dm = u8vec2((data_a[ib_k].scales[is+4] & 0xF) | ((data_a[ib_k].scales[is-4] & 0xC0) >> 2), + (data_a[ib_k].scales[is+4] >> 4) | ((data_a[ib_k].scales[is ] & 0xC0) >> 2)); + } + + return FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm) * FLOAT_TYPE_VEC2(scale_dm); +} + +FLOAT_TYPE mmvq_dot_product(const uint ib_a, const uint iqs) { + int32_t q_sum = 0; + + const i32vec4 qs_a = repack4(ib_a, iqs * 4); + const vec2 dm_scale = get_dm_scale(ib_a, iqs * 4); + + q_sum += dotPacked4x8EXT(qs_a.x, cache_b_qs[0]); + q_sum += dotPacked4x8EXT(qs_a.y, cache_b_qs[1]); + q_sum += dotPacked4x8EXT(qs_a.z, cache_b_qs[2]); + q_sum += dotPacked4x8EXT(qs_a.w, cache_b_qs[3]); + + return FLOAT_TYPE(float(cache_b_ds.x) * float(dm_scale.x) * float(q_sum) - float(dm_scale.y) * float(cache_b_ds.y / 2)); +} +#endif + +#if defined(DATA_A_Q6_K) +// 2-byte loads for Q6_K blocks (210 bytes) +i32vec4 repack4(uint ib, uint iqs) { + const uint ib_k = ib / 8; + const uint iqs_k = (ib % 8) * 8 + iqs; + + const uint ql_idx = (iqs_k / 32) * 16 + iqs_k % 16; + const uint ql_shift = ((iqs_k % 32) / 16) * 4; + + const uint qh_idx = (iqs_k / 32) * 8 + iqs; + const uint qh_shift = ((iqs_k % 32) / 8) * 2; + + const i8vec2 vals00 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 ] >> ql_shift) & uint16_t(0x0F0F))) | + unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 ] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32); + const i8vec2 vals01 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 1] 
>> ql_shift) & uint16_t(0x0F0F))) | + unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 1] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32); + const i8vec2 vals10 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 2] >> ql_shift) & uint16_t(0x0F0F))) | + unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 2] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32); + const i8vec2 vals11 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 3] >> ql_shift) & uint16_t(0x0F0F))) | + unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 3] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32); + const i8vec2 vals20 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 4] >> ql_shift) & uint16_t(0x0F0F))) | + unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 4] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32); + const i8vec2 vals21 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 5] >> ql_shift) & uint16_t(0x0F0F))) | + unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 5] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32); + const i8vec2 vals30 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 6] >> ql_shift) & uint16_t(0x0F0F))) | + unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 6] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32); + const i8vec2 vals31 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 7] >> ql_shift) & uint16_t(0x0F0F))) | + unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 7] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32); + + return i32vec4(pack32(i8vec4(vals00.x, vals00.y, vals01.x, vals01.y)), + pack32(i8vec4(vals10.x, vals10.y, vals11.x, vals11.y)), + pack32(i8vec4(vals20.x, vals20.y, vals21.x, vals21.y)), + pack32(i8vec4(vals30.x, vals30.y, vals31.x, vals31.y))); +} + +float get_d_scale(uint ib, uint iqs) { + const uint ib_k = ib / 8; + const uint iqs_k = (ib % 8) * 8 + iqs; + return float(data_a[ib_k].d) * float(data_a[ib_k].scales[iqs_k / 4]); +} + +FLOAT_TYPE mmvq_dot_product(const uint ib_a, const uint iqs) { + int32_t q_sum = 0; + + const i32vec4 qs_a = repack4(ib_a, iqs * 4); + const float d_scale = get_d_scale(ib_a, iqs * 4); + + q_sum += dotPacked4x8EXT(qs_a.x, cache_b_qs[0]); + q_sum += dotPacked4x8EXT(qs_a.y, cache_b_qs[1]); + q_sum += dotPacked4x8EXT(qs_a.z, cache_b_qs[2]); + q_sum += dotPacked4x8EXT(qs_a.w, cache_b_qs[3]); + + return FLOAT_TYPE(float(cache_b_ds.x) * float(d_scale) * float(q_sum)); +} +#endif diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp index 5266e523b9..dc8b3df47b 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp @@ -78,8 +78,6 @@ layout (constant_id = 10) const uint WARP = 32; #define BK 32 -#define MMQ_SHMEM - #include "mul_mmq_shmem_types.glsl" #ifdef MUL_MAT_ID diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl index 4e3a561142..7f32dadf17 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl @@ -9,31 +9,6 @@ #if defined(DATA_A_Q4_0) || defined(DATA_A_Q4_1) // 2-byte loads for Q4_0 blocks (18 bytes) // 4-byte loads for Q4_1 blocks (20 bytes) -i32vec2 repack(uint ib, uint iqs) { -#ifdef DATA_A_Q4_0 - const u16vec2 quants = u16vec2(data_a_packed16[ib].qs[iqs * 2 ], - data_a_packed16[ib].qs[iqs * 2 + 1]); - const uint32_t vui = pack32(quants); - return i32vec2( 
vui & 0x0F0F0F0F, - (vui >> 4) & 0x0F0F0F0F); -#else // DATA_A_Q4_1 - const uint32_t vui = data_a_packed32[ib].qs[iqs]; - return i32vec2( vui & 0x0F0F0F0F, - (vui >> 4) & 0x0F0F0F0F); -#endif -} - -#ifdef DATA_A_Q4_0 -ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) { - return ACC_TYPE(da * (float(q_sum) * dsb.x - (8 / sum_divisor) * dsb.y)); -} -#else // DATA_A_Q4_1 -ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) { - return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y / sum_divisor); -} -#endif - -#ifdef MMQ_SHMEM void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) { #ifdef DATA_A_Q4_0 buf_a[buf_ib].qs[iqs] = pack32(u16vec2(data_a_packed16[ib].qs[iqs * 2], @@ -73,42 +48,17 @@ ACC_TYPE mmq_dot_product(const uint ib_a) { q_sum += dotPacked4x8EXT(qs_a.y, qs_b1); } - return mul_q8_1(q_sum, cache_a[ib_a].dm, cache_b.ds, 1); +#ifdef DATA_A_Q4_0 + return ACC_TYPE(float(cache_a[ib_a].dm) * (float(q_sum) * float(cache_b.ds.x) - 8.0 * float(cache_b.ds.y))); +#else // DATA_A_Q4_1 + return ACC_TYPE(float(q_sum) * float(cache_a[ib_a].dm.x) * float(cache_b.ds.x) + float(cache_a[ib_a].dm.y) * float(cache_b.ds.y)); +#endif } -#endif // MMQ_SHMEM +#endif -#elif defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1) +#if defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1) // 2-byte loads for Q5_0 blocks (22 bytes) // 4-byte loads for Q5_1 blocks (24 bytes) -i32vec2 repack(uint ib, uint iqs) { - const u16vec2 quants = u16vec2(data_a_packed16[ib].qs[iqs * 2 ], - data_a_packed16[ib].qs[iqs * 2 + 1]); - const uint32_t vui = pack32(quants); -#ifdef DATA_A_Q5_0 - const int32_t qh = int32_t((uint32_t(data_a_packed16[ib].qh[1]) << 16 | data_a_packed16[ib].qh[0]) >> (4 * iqs)); -#else // DATA_A_Q5_1 - const int32_t qh = int32_t(data_a_packed32[ib].qh >> (4 * iqs)); -#endif - const int32_t v0 = int32_t(vui & 0x0F0F0F0F) - | ((qh & 0xF) * 0x02040810) & 0x10101010; // (0,1,2,3) -> (4,12,20,28) - - const int32_t v1 = int32_t((vui >> 4) & 0x0F0F0F0F) - | (((qh >> 16) & 0xF) * 0x02040810) & 0x10101010; // (16,17,18,19) -> (4,12,20,28) - - return i32vec2(v0, v1); -} - -#ifdef DATA_A_Q5_0 -ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) { - return ACC_TYPE(da * (float(q_sum) * dsb.x - (16 / sum_divisor) * dsb.y)); -} -#else // DATA_A_Q5_1 -ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) { - return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y / sum_divisor); -} -#endif - -#ifdef MMQ_SHMEM void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) { #ifdef DATA_A_Q5_0 buf_a[buf_ib].qs[iqs] = pack32(u16vec2(data_a_packed16[ib].qs[iqs * 2], @@ -154,23 +104,16 @@ ACC_TYPE mmq_dot_product(const uint ib_a) { q_sum += dotPacked4x8EXT(qs_a1, qs_b1); } - return mul_q8_1(q_sum, cache_a[ib_a].dm, cache_b.ds, 1); +#ifdef DATA_A_Q5_0 + return ACC_TYPE(float(cache_a[ib_a].dm) * (float(q_sum) * float(cache_b.ds.x) - 16.0 * float(cache_b.ds.y))); +#else // DATA_A_Q5_1 + return ACC_TYPE(float(q_sum) * float(cache_a[ib_a].dm.x) * float(cache_b.ds.x) + float(cache_a[ib_a].dm.y) * float(cache_b.ds.y)); +#endif } -#endif // MMQ_SHMEM #endif #if defined(DATA_A_Q8_0) // 2-byte loads for Q8_0 blocks (34 bytes) -int32_t repack(uint ib, uint iqs) { - return pack32(i16vec2(data_a_packed16[ib].qs[iqs * 2 ], - data_a_packed16[ib].qs[iqs * 2 + 1])); -} - -ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 
dsb, const int32_t sum_divisor) { - return ACC_TYPE(float(q_sum) * da * dsb.x); -} - -#ifdef MMQ_SHMEM void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) { buf_a[buf_ib].qs[iqs] = pack32(i16vec2(data_a_packed16[ib].qs[iqs * 2], data_a_packed16[ib].qs[iqs * 2 + 1])); @@ -197,28 +140,12 @@ ACC_TYPE mmq_dot_product(const uint ib_a) { q_sum += dotPacked4x8EXT(qs_a, qs_b); } - return mul_q8_1(q_sum, cache_a[ib_a].dm, cache_b.ds, 1); + return ACC_TYPE(float(q_sum) * float(cache_a[ib_a].dm) * float(cache_b.ds.x)); } -#endif // MMQ_SHMEM #endif #if defined(DATA_A_MXFP4) // 1-byte loads for mxfp4 blocks (17 bytes) -i32vec2 repack(uint ib, uint iqs) { - const uint32_t quants = pack32(u8vec4(data_a[ib].qs[iqs * 4 ], - data_a[ib].qs[iqs * 4 + 1], - data_a[ib].qs[iqs * 4 + 2], - data_a[ib].qs[iqs * 4 + 3])); - - return i32vec2( quants & 0x0F0F0F0F, - (quants >> 4) & 0x0F0F0F0F); -} - -ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) { - return ACC_TYPE(da * dsb.x * float(q_sum)); -} - -#ifdef MMQ_SHMEM void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) { const uint32_t qs = pack32(u8vec4(data_a[ib].qs[iqs * 4 ], data_a[ib].qs[iqs * 4 + 1], @@ -252,37 +179,14 @@ ACC_TYPE mmq_dot_product(const uint ib_a) { q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]); } - return mul_q8_1(q_sum, cache_a[ib_a].d, cache_b.ds, 1); + return ACC_TYPE(float(cache_a[ib_a].d) * float(cache_b.ds.x) * float(q_sum)); } -#endif // MMQ_SHMEM #endif // For k-quants, ib and iqs still assume 32-wide blocks, but k-quants are 256-wide // iqs still refers to a 32-bit integer, meaning 0..7 for 32-wide quants #if defined(DATA_A_Q2_K) // 4-byte loads for Q2_K blocks (84 bytes) -int32_t repack(uint ib, uint iqs) { - const uint ib_k = ib / 8; - const uint iqs_k = (ib % 8) * 8 + iqs; - - const uint qs_idx = (iqs_k / 32) * 8 + (iqs_k % 8); - const uint qs_shift = ((iqs_k % 32) / 8) * 2; - - return int32_t((data_a_packed32[ib_k].qs[qs_idx] >> qs_shift) & 0x03030303); -} - -uint8_t get_scale(uint ib, uint iqs) { - const uint ib_k = ib / 8; - const uint iqs_k = (ib % 8) * 8 + iqs; - - return data_a[ib_k].scales[iqs_k / 4]; -} - -ACC_TYPE mul_q8_1(const int32_t sum_d, const int32_t sum_m, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) { - return ACC_TYPE(dsb.x * (dma.x * float(sum_d) - dma.y * float(sum_m))); -} - -#ifdef MMQ_SHMEM void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) { const uint ib_k = ib / 8; const uint iqs_k = (ib % 8) * 8 + iqs * QUANT_R_MMQ; @@ -326,14 +230,12 @@ ACC_TYPE mmq_dot_product(const uint ib_a) { sum_m += dotPacked4x8EXT(scale_m, cache_b.qs[iqs]); } - return mul_q8_1(sum_d, sum_m, cache_a[ib_a].dm, cache_b.ds, 1); + return ACC_TYPE(float(cache_b.ds.x) * (float(cache_a[ib_a].dm.x) * float(sum_d) - float(cache_a[ib_a].dm.y) * float(sum_m))); } -#endif // MMQ_SHMEM #endif #if defined(DATA_A_Q3_K) // 2-byte loads for Q3_K blocks (110 bytes) -#ifdef MMQ_SHMEM void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) { const uint ib_k = ib / 8; const uint hm_idx = iqs * QUANT_R_MMQ; @@ -394,18 +296,12 @@ ACC_TYPE mmq_dot_product(const uint ib_a) { } result += float(cache_a[ib_a].d_scales[1]) * float(q_sum); - return ACC_TYPE(cache_b.ds.x * result); + return ACC_TYPE(float(cache_b.ds.x) * result); } -#endif // MMQ_SHMEM #endif #if defined(DATA_A_Q4_K) || defined(DATA_A_Q5_K) // 4-byte loads for Q4_K blocks (144 bytes) and Q5_K blocks (176 bytes) -ACC_TYPE mul_q8_1(const int32_t q_sum, const 
vec2 dma, const vec2 dsb, const int32_t sum_divisor) { - return ACC_TYPE(dsb.x * dma.x * float(q_sum) - dma.y * dsb.y); -} - -#ifdef MMQ_SHMEM void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) { const uint ib_k = ib / 8; const uint iqs_k = (ib % 8) * 8 + iqs * QUANT_R_MMQ; @@ -427,7 +323,6 @@ void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) { (((data_a_packed32[ib_k].qh[qh_idx] >> qh_shift) & 0x01010101) << 4)); #endif - if (iqs == 0) { // Scale index const uint is = iqs_k / 8; @@ -464,49 +359,12 @@ ACC_TYPE mmq_dot_product(const uint ib_a) { q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]); } - return mul_q8_1(q_sum, cache_a[ib_a].dm, cache_b.ds, 1); -} -#endif // MMQ_SHMEM -#endif - -#ifdef MMQ_SHMEM -void block_b_to_shmem(const uint buf_ib, const uint ib, const uint iqs, const bool is_in_bounds) { - if (is_in_bounds) { - const uint ib_outer = ib / 4; - const uint ib_inner = ib % 4; - - if (iqs == 0) { - buf_b[buf_ib].ds = FLOAT_TYPE_VEC2(data_b[ib_outer].ds[ib_inner]); - } - - const ivec4 values = data_b[ib_outer].qs[ib_inner * 2 + iqs]; - buf_b[buf_ib].qs[iqs * 4 ] = values.x; - buf_b[buf_ib].qs[iqs * 4 + 1] = values.y; - buf_b[buf_ib].qs[iqs * 4 + 2] = values.z; - buf_b[buf_ib].qs[iqs * 4 + 3] = values.w; - } else { - if (iqs == 0) { - buf_b[buf_ib].ds = FLOAT_TYPE_VEC2(0.0f); - } - - buf_b[buf_ib].qs[iqs * 4 ] = 0; - buf_b[buf_ib].qs[iqs * 4 + 1] = 0; - buf_b[buf_ib].qs[iqs * 4 + 2] = 0; - buf_b[buf_ib].qs[iqs * 4 + 3] = 0; - } -} - -void block_b_to_registers(const uint ib) { - cache_b.ds = buf_b[ib].ds; - [[unroll]] for (uint iqs = 0; iqs < BK / 4; iqs++) { - cache_b.qs[iqs] = buf_b[ib].qs[iqs]; - } + return ACC_TYPE(float(cache_b.ds.x) * float(cache_a[ib_a].dm.x) * float(q_sum) - float(cache_a[ib_a].dm.y) * float(cache_b.ds.y)); } #endif #if defined(DATA_A_Q6_K) // 2-byte loads for Q6_K blocks (210 bytes) -#ifdef MMQ_SHMEM void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) { const uint ib_k = ib / 8; const uint iqs_k = (ib % 8) * 8 + iqs; @@ -558,32 +416,39 @@ ACC_TYPE mmq_dot_product(const uint ib_a) { } result += float(cache_a[ib_a].d_scales[1]) * float(q_sum); - return ACC_TYPE(cache_b.ds.x * result); -} -#endif // MMQ_SHMEM -#endif - -#if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ1_S) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL) -FLOAT_TYPE get_d(uint ib) { - return FLOAT_TYPE(data_a[ib].d); + return ACC_TYPE(float(cache_b.ds.x) * result); } #endif -#if defined(DATA_A_MXFP4) -FLOAT_TYPE get_d(uint ib) { - return FLOAT_TYPE(e8m0_to_fp32(data_a[ib].e)); -} -#endif +void block_b_to_shmem(const uint buf_ib, const uint ib, const uint iqs, const bool is_in_bounds) { + if (is_in_bounds) { + const uint ib_outer = ib / 4; + const uint ib_inner = ib % 4; -#if defined(DATA_A_Q4_1) || defined(DATA_A_Q5_1) -FLOAT_TYPE_VEC2 get_dm(uint ib) { - return FLOAT_TYPE_VEC2(data_a_packed32[ib].dm); -} -#endif + if (iqs == 0) { + buf_b[buf_ib].ds = FLOAT_TYPE_VEC2(data_b[ib_outer].ds[ib_inner]); + } -#if defined(DATA_A_Q2_K) -FLOAT_TYPE_VEC2 get_dm(uint ib) { - const uint ib_k = ib / 8; - return FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm); + const ivec4 values = data_b[ib_outer].qs[ib_inner * 2 + iqs]; + buf_b[buf_ib].qs[iqs * 4 ] = values.x; + buf_b[buf_ib].qs[iqs * 4 + 1] = values.y; + buf_b[buf_ib].qs[iqs * 4 + 2] = values.z; + buf_b[buf_ib].qs[iqs * 4 + 
3] = values.w; + } else { + if (iqs == 0) { + buf_b[buf_ib].ds = FLOAT_TYPE_VEC2(0.0f); + } + + buf_b[buf_ib].qs[iqs * 4 ] = 0; + buf_b[buf_ib].qs[iqs * 4 + 1] = 0; + buf_b[buf_ib].qs[iqs * 4 + 2] = 0; + buf_b[buf_ib].qs[iqs * 4 + 3] = 0; + } +} + +void block_b_to_registers(const uint ib) { + cache_b.ds = buf_b[ib].ds; + [[unroll]] for (uint iqs = 0; iqs < BK / 4; iqs++) { + cache_b.qs[iqs] = buf_b[ib].qs[iqs]; + } } -#endif diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index 214a743b97..92bae088b2 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -679,14 +679,20 @@ void process_shaders() { string_to_spv("mul_mat_vec_" + tname + "_f32_f32_subgroup_no_shmem", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}})); string_to_spv("mul_mat_vec_" + tname + "_f16_f32_subgroup_no_shmem", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPE_VEC2", "f16vec2"}, {"B_TYPE_VEC4", "f16vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}})); - string_to_spv("mul_mat_vec_id_" + tname + "_f32", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}})); + string_to_spv("mul_mat_vec_id_" + tname + "_f32_f32", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}})); + string_to_spv("mul_mat_vec_id_" + tname + "_f32_f32_subgroup", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}})); + string_to_spv("mul_mat_vec_id_" + tname + "_f32_f32_subgroup_no_shmem", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}})); // mul mat vec with integer dot product #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) - if (is_legacy_quant(tname)) { + if (is_legacy_quant(tname) || tname == "mxfp4" || is_k_quant(tname)) { string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}})); string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32_subgroup", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}})); string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32_subgroup_no_shmem", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}})); + + string_to_spv("mul_mat_vec_id_" + tname + "_q8_1_f32", "mul_mat_vecq.comp", merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}})); + string_to_spv("mul_mat_vec_id_" + tname + "_q8_1_f32_subgroup", "mul_mat_vecq.comp", merge_maps(base_dict, 
{{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}})); + string_to_spv("mul_mat_vec_id_" + tname + "_q8_1_f32_subgroup_no_shmem", "mul_mat_vecq.comp", merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}})); } #endif @@ -1100,7 +1106,7 @@ void write_output_files() { for (const std::string& btype : btypes) { for (const auto& tname : type_names) { - if (btype == "q8_1" && !is_legacy_quant(tname)) { + if (btype == "q8_1" && !is_legacy_quant(tname) && tname != "mxfp4" && !is_k_quant(tname)) { continue; } hdr << "extern const void * arr_dmmv_" << tname << "_" << btype << "_f32_data[3];\n"; @@ -1109,6 +1115,16 @@ void write_output_files() { src << "const void * arr_dmmv_" << tname << "_" << btype << "_f32_data[3] = {mul_mat_vec_" << tname << "_" << btype << "_f32_data, mul_mat_vec_" << tname << "_" << btype << "_f32_subgroup_data, mul_mat_vec_" << tname << "_" << btype << "_f32_subgroup_no_shmem_data};\n"; src << "const uint64_t arr_dmmv_" << tname << "_" << btype << "_f32_len[3] = {mul_mat_vec_" << tname << "_" << btype << "_f32_len, mul_mat_vec_" << tname << "_" << btype << "_f32_subgroup_len, mul_mat_vec_" << tname << "_" << btype << "_f32_subgroup_no_shmem_len};\n"; } + + if (btype == "f16") { + continue; + } + hdr << "extern const void * arr_dmmv_id_" << tname << "_" << btype << "_f32_data[3];\n"; + hdr << "extern const uint64_t arr_dmmv_id_" << tname << "_" << btype << "_f32_len[3];\n"; + if (basename(input_filepath) == "mul_mat_vec.comp") { + src << "const void * arr_dmmv_id_" << tname << "_" << btype << "_f32_data[3] = {mul_mat_vec_id_" << tname << "_" << btype << "_f32_data, mul_mat_vec_id_" << tname << "_" << btype << "_f32_subgroup_data, mul_mat_vec_id_" << tname << "_" << btype << "_f32_subgroup_no_shmem_data};\n"; + src << "const uint64_t arr_dmmv_id_" << tname << "_" << btype << "_f32_len[3] = {mul_mat_vec_id_" << tname << "_" << btype << "_f32_len, mul_mat_vec_id_" << tname << "_" << btype << "_f32_subgroup_len, mul_mat_vec_id_" << tname << "_" << btype << "_f32_subgroup_no_shmem_len};\n"; + } } } From f698a79c6396771a33e903930c5b5348b87a8715 Mon Sep 17 00:00:00 2001 From: ixgbe <1113177880@qq.com> Date: Sat, 29 Nov 2025 20:56:31 +0800 Subject: [PATCH 30/33] ggml: replace hwcap with riscv_hwprobe for RVV detection (#17567) Signed-off-by: Wang Yang --- ggml/src/ggml-cpu/arch/riscv/cpu-feats.cpp | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/ggml/src/ggml-cpu/arch/riscv/cpu-feats.cpp b/ggml/src/ggml-cpu/arch/riscv/cpu-feats.cpp index b181898818..43c757bd01 100644 --- a/ggml/src/ggml-cpu/arch/riscv/cpu-feats.cpp +++ b/ggml/src/ggml-cpu/arch/riscv/cpu-feats.cpp @@ -1,20 +1,23 @@ #include "ggml-backend-impl.h" #if defined(__riscv) && __riscv_xlen == 64 -#include - -//https://github.com/torvalds/linux/blob/master/arch/riscv/include/uapi/asm/hwcap.h#L24 -#ifndef COMPAT_HWCAP_ISA_V -#define COMPAT_HWCAP_ISA_V (1 << ('V' - 'A')) -#endif +#include +#include +#include struct riscv64_features { bool has_rvv = false; riscv64_features() { - uint32_t hwcap = getauxval(AT_HWCAP); + struct riscv_hwprobe probe; + probe.key = RISCV_HWPROBE_KEY_IMA_EXT_0; + probe.value = 0; - has_rvv = !!(hwcap & COMPAT_HWCAP_ISA_V); + int ret = syscall(__NR_riscv_hwprobe, &probe, 1, 0, NULL, 0); + + if (0 == ret) { + has_rvv 
= !!(probe.value & RISCV_HWPROBE_IMA_V); + } } }; From 7d2add51d8e3759020d70f2ff3a76b5795ff67bc Mon Sep 17 00:00:00 2001 From: Neo Zhang Date: Sat, 29 Nov 2025 20:59:44 +0800 Subject: [PATCH 31/33] sycl : support to malloc memory on device more than 4GB, update the doc and script (#17566) Co-authored-by: Neo Zhang Jianyu --- docs/backend/SYCL.md | 13 +++++++++++++ examples/sycl/run-llama2.sh | 3 +++ examples/sycl/run-llama3.sh | 9 ++++++--- examples/sycl/win-run-llama2.bat | 2 ++ examples/sycl/win-run-llama3.bat | 4 +++- ggml/src/ggml-sycl/CMakeLists.txt | 8 ++++++-- ggml/src/ggml-sycl/cpy.cpp | 3 --- 7 files changed, 33 insertions(+), 9 deletions(-) diff --git a/docs/backend/SYCL.md b/docs/backend/SYCL.md index 92ab27066b..02a72a9d51 100644 --- a/docs/backend/SYCL.md +++ b/docs/backend/SYCL.md @@ -42,6 +42,9 @@ The following releases are verified and recommended: ## News +- 2025.11 + - Support malloc memory on device more than 4GB. + - 2025.2 - Optimize MUL_MAT Q4_0 on Intel GPU for all dGPUs and built-in GPUs since MTL. Increase the performance of LLM (llama-2-7b.Q4_0.gguf) 21%-87% on Intel GPUs (MTL, ARL-H, Arc, Flex, PVC). |GPU|Base tokens/s|Increased tokens/s|Percent| @@ -789,6 +792,8 @@ use 1 SYCL GPUs: [0] with Max compute units:512 | GGML_SYCL_DISABLE_GRAPH | 0 or 1 (default) | Disable running computations through SYCL Graphs feature. Disabled by default because graph performance isn't yet better than non-graph performance. | | GGML_SYCL_DISABLE_DNN | 0 (default) or 1 | Disable running computations through oneDNN and always use oneMKL. | | ZES_ENABLE_SYSMAN | 0 (default) or 1 | Support to get free memory of GPU by sycl::aspect::ext_intel_free_memory.
Recommended to use when --split-mode = layer | +| UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS | 0 (default) or 1 | Support allocating more than 4GB of device memory.| + ## Known Issues @@ -835,6 +840,14 @@ use 1 SYCL GPUs: [0] with Max compute units:512 | The default context is too big. It leads to excessive memory usage.|Set `-c 8192` or a smaller value.| | The model is too big and requires more memory than what is available.|Choose a smaller model or change to a smaller quantization, like Q5 -> Q4;
Alternatively, use more than one device to load model.| +- `ggml_backend_sycl_buffer_type_alloc_buffer: can't allocate 5000000000 Bytes of memory on device` + + You need to enable to support 4GB memory malloc by: + ``` + export UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS=1 + set UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS=1 + ``` + ### **GitHub contribution**: Please add the `SYCL :` prefix/tag in issues/PRs titles to help the SYCL contributors to check/address them without delay. diff --git a/examples/sycl/run-llama2.sh b/examples/sycl/run-llama2.sh index 37195008de..a018e45197 100755 --- a/examples/sycl/run-llama2.sh +++ b/examples/sycl/run-llama2.sh @@ -15,6 +15,9 @@ MODEL_FILE=models/llama-2-7b.Q4_0.gguf NGL=99 CONTEXT=4096 +#support malloc device memory more than 4GB. +export UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS=1 + if [ $# -gt 0 ]; then GGML_SYCL_DEVICE=$1 echo "use $GGML_SYCL_DEVICE as main GPU" diff --git a/examples/sycl/run-llama3.sh b/examples/sycl/run-llama3.sh index 8e21b017f4..4770255703 100755 --- a/examples/sycl/run-llama3.sh +++ b/examples/sycl/run-llama3.sh @@ -6,7 +6,7 @@ # If you want more control, DPC++ Allows selecting a specific device through the # following environment variable -#export ONEAPI_DEVICE_SELECTOR="level_zero:0" +export ONEAPI_DEVICE_SELECTOR="level_zero:0" source /opt/intel/oneapi/setvars.sh #export GGML_SYCL_DEBUG=1 @@ -18,11 +18,14 @@ MODEL_FILE=models/Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf NGL=99 # Layers offloaded to the GPU. If the device runs out of memory, reduce this value according to the model you are using. CONTEXT=4096 +#support malloc device memory more than 4GB. +export UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS=1 + if [ $# -gt 0 ]; then GGML_SYCL_DEVICE=$1 echo "Using $GGML_SYCL_DEVICE as the main GPU" - ZES_ENABLE_SYSMAN=1 ./build/bin/llama-cli -m ${MODEL_FILE} -p "${INPUT_PROMPT}" -n 400 -e -ngl ${NGL} -c ${CONTEXT} -mg $GGML_SYCL_DEVICE -sm none + ZES_ENABLE_SYSMAN=1 ./build/bin/llama-cli -m ${MODEL_FILE} -p "${INPUT_PROMPT}" -n 400 -e -ngl ${NGL} -s 0 -c ${CONTEXT} -mg $GGML_SYCL_DEVICE -sm none else #use multiple GPUs with same max compute units - ZES_ENABLE_SYSMAN=1 ./build/bin/llama-cli -m ${MODEL_FILE} -p "${INPUT_PROMPT}" -n 400 -e -ngl ${NGL} -c ${CONTEXT} + ZES_ENABLE_SYSMAN=1 ./build/bin/llama-cli -m ${MODEL_FILE} -p "${INPUT_PROMPT}" -n 400 -e -ngl ${NGL} -s 0 -c ${CONTEXT} fi diff --git a/examples/sycl/win-run-llama2.bat b/examples/sycl/win-run-llama2.bat index d7564f4161..b654f88f62 100644 --- a/examples/sycl/win-run-llama2.bat +++ b/examples/sycl/win-run-llama2.bat @@ -5,5 +5,7 @@ set INPUT2="Building a website can be done in 10 simple steps:\nStep 1:" @call "C:\Program Files (x86)\Intel\oneAPI\setvars.bat" intel64 --force +:: support malloc device memory more than 4GB. +set UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS=1 .\build\bin\llama-cli.exe -m models\llama-2-7b.Q4_0.gguf -p %INPUT2% -n 400 -e -ngl 99 -s 0 diff --git a/examples/sycl/win-run-llama3.bat b/examples/sycl/win-run-llama3.bat index 4b61aebee5..608b834f60 100644 --- a/examples/sycl/win-run-llama3.bat +++ b/examples/sycl/win-run-llama3.bat @@ -5,5 +5,7 @@ set INPUT2="Building a website can be done in 10 simple steps:\nStep 1:" @call "C:\Program Files (x86)\Intel\oneAPI\setvars.bat" intel64 --force +:: support malloc device memory more than 4GB. 
+set UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS=1 -.\build\bin\llama-cli.exe -m models\Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf -p %INPUT2% -n 400 -e -ngl 99 +.\build\bin\llama-cli.exe -m models\Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf -p %INPUT2% -n 400 -s 0 -e -ngl 99 diff --git a/ggml/src/ggml-sycl/CMakeLists.txt b/ggml/src/ggml-sycl/CMakeLists.txt index efd78b912c..88f29221bb 100644 --- a/ggml/src/ggml-sycl/CMakeLists.txt +++ b/ggml/src/ggml-sycl/CMakeLists.txt @@ -91,7 +91,10 @@ if (GGML_SYCL_F16) add_compile_definitions(GGML_SYCL_F16) endif() -if (GGML_SYCL_TARGET STREQUAL "NVIDIA") +if (GGML_SYCL_TARGET STREQUAL "INTEL") + add_compile_definitions(GGML_SYCL_WARP_SIZE=16) + target_link_options(ggml-sycl PRIVATE -Xs -ze-intel-greater-than-4GB-buffer-required) +elseif (GGML_SYCL_TARGET STREQUAL "NVIDIA") add_compile_definitions(GGML_SYCL_WARP_SIZE=32) elseif (GGML_SYCL_TARGET STREQUAL "AMD") # INFO: Allowed Sub_group_sizes are not consistent through all @@ -100,7 +103,8 @@ elseif (GGML_SYCL_TARGET STREQUAL "AMD") # Target archs tested working: gfx1030, gfx1031, (Only tested sub_group_size = 32) add_compile_definitions(GGML_SYCL_WARP_SIZE=32) else() - add_compile_definitions(GGML_SYCL_WARP_SIZE=16) + # default for other target + add_compile_definitions(GGML_SYCL_WARP_SIZE=32) endif() if (GGML_SYCL_GRAPH) diff --git a/ggml/src/ggml-sycl/cpy.cpp b/ggml/src/ggml-sycl/cpy.cpp index 1ec99b0a5d..96709554cf 100644 --- a/ggml/src/ggml-sycl/cpy.cpp +++ b/ggml/src/ggml-sycl/cpy.cpp @@ -515,9 +515,6 @@ void ggml_sycl_cpy(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, co const int64_t ne = ggml_nelements(src0); GGML_ASSERT(ne == ggml_nelements(src1)); - GGML_ASSERT(ggml_nbytes(src0) <= INT_MAX); - GGML_ASSERT(ggml_nbytes(src1) <= INT_MAX); - GGML_TENSOR_BINARY_OP_LOCALS01; SYCL_CHECK(ggml_sycl_set_device(ctx.device)); From 0874693b449a847de2d052f2afb5d0cbe9409f92 Mon Sep 17 00:00:00 2001 From: Igor Smirnov Date: Sat, 29 Nov 2025 21:06:32 +0500 Subject: [PATCH 32/33] common : fix json schema with '\' in literals (#17307) * Fix json schema with '\' in literals * Add "literal string with escapes" test --- common/json-schema-to-grammar.cpp | 4 +-- examples/json_schema_to_grammar.py | 4 +-- tests/test-json-schema-to-grammar.cpp | 26 +++++++++++++++++++ .../public_legacy/json-schema-to-grammar.mjs | 4 +-- 4 files changed, 32 insertions(+), 6 deletions(-) diff --git a/common/json-schema-to-grammar.cpp b/common/json-schema-to-grammar.cpp index e64dc059f3..c8421e1e82 100644 --- a/common/json-schema-to-grammar.cpp +++ b/common/json-schema-to-grammar.cpp @@ -268,10 +268,10 @@ static bool is_reserved_name(const std::string & name) { } std::regex INVALID_RULE_CHARS_RE("[^a-zA-Z0-9-]+"); -std::regex GRAMMAR_LITERAL_ESCAPE_RE("[\r\n\"]"); +std::regex GRAMMAR_LITERAL_ESCAPE_RE("[\r\n\"\\\\]"); std::regex GRAMMAR_RANGE_LITERAL_ESCAPE_RE("[\r\n\"\\]\\-\\\\]"); std::unordered_map GRAMMAR_LITERAL_ESCAPES = { - {'\r', "\\r"}, {'\n', "\\n"}, {'"', "\\\""}, {'-', "\\-"}, {']', "\\]"} + {'\r', "\\r"}, {'\n', "\\n"}, {'"', "\\\""}, {'-', "\\-"}, {']', "\\]"}, {'\\', "\\\\"} }; std::unordered_set NON_LITERAL_SET = {'|', '.', '(', ')', '[', ']', '{', '}', '*', '+', '?'}; diff --git a/examples/json_schema_to_grammar.py b/examples/json_schema_to_grammar.py index 26989157fe..886dd3d81e 100755 --- a/examples/json_schema_to_grammar.py +++ b/examples/json_schema_to_grammar.py @@ -231,9 +231,9 @@ DOT = '[^\\x0A\\x0D]' RESERVED_NAMES = set(["root", "dot", *PRIMITIVE_RULES.keys(), *STRING_FORMAT_RULES.keys()]) 
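# The escape regex and escape map below now also cover the backslash, so a literal '\' in a schema string
# is emitted into the grammar as '\\' instead of passing through unescaped.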
INVALID_RULE_CHARS_RE = re.compile(r'[^a-zA-Z0-9-]+') -GRAMMAR_LITERAL_ESCAPE_RE = re.compile(r'[\r\n"]') +GRAMMAR_LITERAL_ESCAPE_RE = re.compile(r'[\r\n"\\]') GRAMMAR_RANGE_LITERAL_ESCAPE_RE = re.compile(r'[\r\n"\]\-\\]') -GRAMMAR_LITERAL_ESCAPES = {'\r': '\\r', '\n': '\\n', '"': '\\"', '-': '\\-', ']': '\\]'} +GRAMMAR_LITERAL_ESCAPES = {'\r': '\\r', '\n': '\\n', '"': '\\"', '-': '\\-', ']': '\\]', '\\': '\\\\'} NON_LITERAL_SET = set('|.()[]{}*+?') ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS = set('^$.[]()|{}*+?') diff --git a/tests/test-json-schema-to-grammar.cpp b/tests/test-json-schema-to-grammar.cpp index 8a55bc54ae..1e568219d2 100755 --- a/tests/test-json-schema-to-grammar.cpp +++ b/tests/test-json-schema-to-grammar.cpp @@ -1339,6 +1339,32 @@ static void test_all(const std::string & lang, std::function Date: Sun, 30 Nov 2025 01:43:29 +0800 Subject: [PATCH 33/33] server: explicitly set the function name in lambda (#17538) As [1] explained, the real debug message will be like: "res operator(): operator() : queue result stop" Set the name explicitly, the message is easy for debugging: "res operator(): recv : queue result stop" The left "operator()" is generated by 'RES_DBG() ... __func__' [1]: https://clang.llvm.org/extra/clang-tidy/checks/bugprone/lambda-function-name.html Signed-off-by: Haiyue Wang --- tools/server/server-queue.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tools/server/server-queue.cpp b/tools/server/server-queue.cpp index 5a74fd76ac..65c8a0a9ae 100644 --- a/tools/server/server-queue.cpp +++ b/tools/server/server-queue.cpp @@ -199,7 +199,7 @@ server_task_result_ptr server_response::recv(const std::unordered_set & id_ std::unique_lock lock(mutex_results); condition_results.wait(lock, [&]{ if (!running) { - RES_DBG("%s : queue result stop\n", __func__); + RES_DBG("%s : queue result stop\n", "recv"); std::terminate(); // we cannot return here since the caller is HTTP code } return !queue_results.empty();
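// Inside this lambda, __func__ would name the closure's call operator (typically "operator()") rather than the
// enclosing recv() function, so the literal "recv" is passed to RES_DBG above to keep the log message readable
// (see clang-tidy's bugprone-lambda-function-name).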