From 80f19b41869728eeb6a26569957b92a773a2b2c6 Mon Sep 17 00:00:00 2001 From: lhez Date: Tue, 15 Apr 2025 12:26:00 -0700 Subject: [PATCH 01/39] opencl: split `ggml-opencl.cl` into multiple files and cleanup (#12886) * opencl: refactor - split the kernel files --------- Co-authored-by: Shangqing Gu * opencl: split more kernels into separate files * opencl: specify subgroup size instead of querying it * opencl: refine Adreno cl compiler version parsing * opencl: skip some kernels not used by Adreno on old compilers * opencl: refine logic for selecting Adreno kernels * opencl: refine Adreno cl compiler version * opencl: cleanup preprocessor for kernels * opencl: consider Adreno CL compiler on Windows * opencl: add final newline for `mul_mv_f16_f16.cl` --------- Co-authored-by: Shangqing Gu --- ggml/src/ggml-opencl/CMakeLists.txt | 45 +- ggml/src/ggml-opencl/ggml-opencl.cpp | 1053 ++++-- ggml/src/ggml-opencl/kernels/add.cl | 83 + ggml/src/ggml-opencl/kernels/clamp.cl | 20 + ggml/src/ggml-opencl/kernels/cpy.cl | 184 + .../kernels/{ggml-opencl_cvt.cl => cvt.cl} | 66 +- ggml/src/ggml-opencl/kernels/diag_mask_inf.cl | 58 + ggml/src/ggml-opencl/kernels/gelu.cl | 62 + ...cl_gemv_noshuffle.cl => gemv_noshuffle.cl} | 0 ...e_general.cl => gemv_noshuffle_general.cl} | 0 ggml/src/ggml-opencl/kernels/get_rows.cl | 163 + ggml/src/ggml-opencl/kernels/ggml-opencl.cl | 3231 ----------------- .../ggml-opencl/kernels/ggml-opencl_im2col.cl | 146 - .../src/ggml-opencl/kernels/ggml-opencl_mm.cl | 1225 ------- .../kernels/ggml-opencl_transpose_16.cl | 26 - .../kernels/ggml-opencl_transpose_32.cl | 25 - .../kernels/ggml-opencl_transpose_32_16.cl | 35 - ggml/src/ggml-opencl/kernels/im2col_f16.cl | 57 + ggml/src/ggml-opencl/kernels/im2col_f32.cl | 57 + ggml/src/ggml-opencl/kernels/mul.cl | 79 + ..._mat_Ab_Bi_8x4.cl => mul_mat_Ab_Bi_8x4.cl} | 0 .../src/ggml-opencl/kernels/mul_mv_f16_f16.cl | 118 + .../src/ggml-opencl/kernels/mul_mv_f16_f32.cl | 118 + .../kernels/mul_mv_f16_f32_1row.cl | 94 + .../ggml-opencl/kernels/mul_mv_f16_f32_l4.cl | 84 + .../src/ggml-opencl/kernels/mul_mv_f32_f32.cl | 118 + .../ggml-opencl/kernels/mul_mv_q4_0_f32.cl | 192 + .../kernels/mul_mv_q4_0_f32_1d_16x_flat.cl | 307 ++ .../kernels/mul_mv_q4_0_f32_1d_8x_flat.cl | 265 ++ .../kernels/mul_mv_q4_0_f32_8x_flat.cl | 272 ++ .../ggml-opencl/kernels/mul_mv_q4_0_f32_v.cl | 254 ++ ggml/src/ggml-opencl/kernels/mul_mv_q6_k.cl | 190 + ggml/src/ggml-opencl/kernels/norm.cl | 81 + ggml/src/ggml-opencl/kernels/relu.cl | 16 + ggml/src/ggml-opencl/kernels/rms_norm.cl | 96 + ggml/src/ggml-opencl/kernels/rope.cl | 721 ++++ ggml/src/ggml-opencl/kernels/scale.cl | 16 + ggml/src/ggml-opencl/kernels/silu.cl | 30 + ggml/src/ggml-opencl/kernels/softmax_4_f16.cl | 87 + ggml/src/ggml-opencl/kernels/softmax_4_f32.cl | 87 + ggml/src/ggml-opencl/kernels/softmax_f16.cl | 86 + ggml/src/ggml-opencl/kernels/softmax_f32.cl | 86 + ggml/src/ggml-opencl/kernels/transpose.cl | 84 + 43 files changed, 5021 insertions(+), 4996 deletions(-) create mode 100644 ggml/src/ggml-opencl/kernels/add.cl create mode 100644 ggml/src/ggml-opencl/kernels/clamp.cl create mode 100644 ggml/src/ggml-opencl/kernels/cpy.cl rename ggml/src/ggml-opencl/kernels/{ggml-opencl_cvt.cl => cvt.cl} (69%) create mode 100644 ggml/src/ggml-opencl/kernels/diag_mask_inf.cl create mode 100644 ggml/src/ggml-opencl/kernels/gelu.cl rename ggml/src/ggml-opencl/kernels/{ggml-opencl_gemv_noshuffle.cl => gemv_noshuffle.cl} (100%) rename ggml/src/ggml-opencl/kernels/{ggml-opencl_gemv_noshuffle_general.cl => 
gemv_noshuffle_general.cl} (100%) create mode 100644 ggml/src/ggml-opencl/kernels/get_rows.cl delete mode 100644 ggml/src/ggml-opencl/kernels/ggml-opencl.cl delete mode 100644 ggml/src/ggml-opencl/kernels/ggml-opencl_im2col.cl delete mode 100644 ggml/src/ggml-opencl/kernels/ggml-opencl_mm.cl delete mode 100644 ggml/src/ggml-opencl/kernels/ggml-opencl_transpose_16.cl delete mode 100644 ggml/src/ggml-opencl/kernels/ggml-opencl_transpose_32.cl delete mode 100644 ggml/src/ggml-opencl/kernels/ggml-opencl_transpose_32_16.cl create mode 100644 ggml/src/ggml-opencl/kernels/im2col_f16.cl create mode 100644 ggml/src/ggml-opencl/kernels/im2col_f32.cl create mode 100644 ggml/src/ggml-opencl/kernels/mul.cl rename ggml/src/ggml-opencl/kernels/{ggml-opencl_mul_mat_Ab_Bi_8x4.cl => mul_mat_Ab_Bi_8x4.cl} (100%) create mode 100644 ggml/src/ggml-opencl/kernels/mul_mv_f16_f16.cl create mode 100644 ggml/src/ggml-opencl/kernels/mul_mv_f16_f32.cl create mode 100644 ggml/src/ggml-opencl/kernels/mul_mv_f16_f32_1row.cl create mode 100644 ggml/src/ggml-opencl/kernels/mul_mv_f16_f32_l4.cl create mode 100644 ggml/src/ggml-opencl/kernels/mul_mv_f32_f32.cl create mode 100644 ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32.cl create mode 100644 ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_1d_16x_flat.cl create mode 100644 ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_1d_8x_flat.cl create mode 100644 ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_8x_flat.cl create mode 100644 ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_v.cl create mode 100644 ggml/src/ggml-opencl/kernels/mul_mv_q6_k.cl create mode 100644 ggml/src/ggml-opencl/kernels/norm.cl create mode 100644 ggml/src/ggml-opencl/kernels/relu.cl create mode 100644 ggml/src/ggml-opencl/kernels/rms_norm.cl create mode 100644 ggml/src/ggml-opencl/kernels/rope.cl create mode 100644 ggml/src/ggml-opencl/kernels/scale.cl create mode 100644 ggml/src/ggml-opencl/kernels/silu.cl create mode 100644 ggml/src/ggml-opencl/kernels/softmax_4_f16.cl create mode 100644 ggml/src/ggml-opencl/kernels/softmax_4_f32.cl create mode 100644 ggml/src/ggml-opencl/kernels/softmax_f16.cl create mode 100644 ggml/src/ggml-opencl/kernels/softmax_f32.cl create mode 100644 ggml/src/ggml-opencl/kernels/transpose.cl diff --git a/ggml/src/ggml-opencl/CMakeLists.txt b/ggml/src/ggml-opencl/CMakeLists.txt index 624cb1b9d0..352deb321e 100644 --- a/ggml/src/ggml-opencl/CMakeLists.txt +++ b/ggml/src/ggml-opencl/CMakeLists.txt @@ -54,16 +54,41 @@ function(ggml_opencl_add_kernel KNAME) endfunction() set(GGML_OPENCL_KERNELS - ggml-opencl - ggml-opencl_mm - ggml-opencl_cvt - ggml-opencl_gemv_noshuffle - ggml-opencl_gemv_noshuffle_general - ggml-opencl_mul_mat_Ab_Bi_8x4 - ggml-opencl_transpose_16 - ggml-opencl_transpose_32 - ggml-opencl_transpose_32_16 - ggml-opencl_im2col + add + clamp + cpy + cvt + diag_mask_inf + gelu + gemv_noshuffle_general + gemv_noshuffle + get_rows + im2col_f32 + im2col_f16 + mul_mat_Ab_Bi_8x4 + mul_mv_f16_f16 + mul_mv_f16_f32_1row + mul_mv_f16_f32_l4 + mul_mv_f16_f32 + mul_mv_f32_f32 + mul_mv_q4_0_f32 + mul_mv_q4_0_f32_v + mul_mv_q4_0_f32_8x_flat + mul_mv_q4_0_f32_1d_8x_flat + mul_mv_q4_0_f32_1d_16x_flat + mul_mv_q6_k + mul + norm + relu + rms_norm + rope + scale + silu + softmax_4_f32 + softmax_4_f16 + softmax_f32 + softmax_f16 + transpose ) foreach (K ${GGML_OPENCL_KERNELS}) diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index b8b5cbd3e2..fa40abc33e 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -64,11 
+64,33 @@ enum ADRENO_GPU_GEN {
     X1E,
 };
 
+enum ADRENO_CL_COMPILER_TYPE {
+    E031,
+    DX,
+};
+
 struct ggml_cl_version {
     cl_uint major = 0;
     cl_uint minor = 0;
 };
 
+struct ggml_cl_compiler_version {
+    ADRENO_CL_COMPILER_TYPE type;
+    int major = -1;
+    int minor = -1;
+    int patch = -1;
+
+    bool same(ADRENO_CL_COMPILER_TYPE t, int x, int y, int z) const {
+        return major == x && minor == y && patch == z && type == t;
+    }
+    bool newer_than(ADRENO_CL_COMPILER_TYPE t, int x, int y, int z) const {
+        return major*10000 + minor*100 + patch > x*10000 + y*100 + z && type == t;
+    }
+    bool newer_than_or_same(ADRENO_CL_COMPILER_TYPE t, int x, int y, int z) const {
+        return same(t, x, y, z) || newer_than(t, x, y, z);
+    }
+};
+
 // Parses a version string of form "XX.YY ". On error, returns a ggml_cl_version with all zeroes.
 static ggml_cl_version parse_cl_version(std::string_view str) {
     size_t major_str_begin = 0;
@@ -173,24 +195,30 @@ static ADRENO_GPU_GEN get_adreno_gpu_gen(const char *device_name) {
     return ADRENO_GPU_GEN::ADRENO_UNKNOWN;
 }
 
-static int get_adreno_cl_compiler_version(const char *driver_version) {
+static ggml_cl_compiler_version get_adreno_cl_compiler_version(const char *driver_version) {
     std::string driver_ver_str(driver_version);
+    ADRENO_CL_COMPILER_TYPE type = ADRENO_CL_COMPILER_TYPE::E031;
     size_t compiler_ver_pos = driver_ver_str.find("E031");
     size_t compiler_ver_len = 13;
-    size_t compiler_ver_offset = 5;
+    size_t compiler_major_offset = 5;
+    size_t compiler_minor_offset = 8;
+    size_t compiler_patch_offset = 11;
 
     if (compiler_ver_pos == std::string::npos) {
         compiler_ver_pos = driver_ver_str.find("DX");
         if (compiler_ver_pos == std::string::npos) {
-            return -1;
+            return {};
         }
+        type = ADRENO_CL_COMPILER_TYPE::DX;
         compiler_ver_len = 11;
-        compiler_ver_offset = 3;
+        compiler_major_offset = 3;
     }
 
     std::string compiler_ver_str = driver_ver_str.substr(compiler_ver_pos, compiler_ver_len);
-    std::string major_ver_str = compiler_ver_str.substr(compiler_ver_offset, 2);
-    return std::atoi(major_ver_str.c_str());
+    int major = std::atoi(compiler_ver_str.substr(compiler_major_offset, 2).c_str());
+    int minor = std::atoi(compiler_ver_str.substr(compiler_minor_offset, 2).c_str());
+    int patch = std::atoi(compiler_ver_str.substr(compiler_patch_offset, 2).c_str());
+    return { type, major, minor, patch };
 }
 
 // backend device context
@@ -215,16 +243,48 @@ struct ggml_backend_opencl_context {
     cl_int alignment;
     size_t max_alloc_size;
     bool fp16_support;
+    bool has_vector_subgroup_broadcast;
+    ggml_cl_compiler_version adreno_cl_compiler_version;
 
     int adreno_wave_size;
 
     cl_context context;
     cl_command_queue queue;
 
-    cl_program program;
-    cl_program program_1;
-    cl_program program_2;
-    cl_program program_im2col;
+    cl_program program_add;
+    cl_program program_clamp;
+    cl_program program_cpy;
+    cl_program program_cvt;
+    cl_program program_diag_mask_inf;
+    cl_program program_gelu;
+    cl_program program_gemv_noshuffle_general;
+    cl_program program_gemv_noshuffle;
+    cl_program program_get_rows;
+    cl_program program_im2col_f16;
+    cl_program program_im2col_f32;
+    cl_program program_mul_mat_Ab_Bi_8x4;
+    cl_program program_mul_mv_q4_0_f32;
+    cl_program program_mul_mv_q4_0_f32_v;
+    cl_program program_mul_mv_q4_0_f32_8x_flat;
+    cl_program program_mul_mv_q4_0_f32_1d_8x_flat;
+    cl_program program_mul_mv_q4_0_f32_1d_16x_flat;
+    cl_program program_mul_mv_q6_K;
+    cl_program program_mul_mv_f16_f16;
+    cl_program program_mul_mv_f16_f32_1row;
+    cl_program program_mul_mv_f16_f32_l4;
+    cl_program program_mul_mv_f16_f32;
+    cl_program
program_mul_mv_f32_f32; + cl_program program_mul; + cl_program program_norm; + cl_program program_relu; + cl_program program_rms_norm; + cl_program program_rope; + cl_program program_scale; + cl_program program_silu; + cl_program program_softmax_f32; + cl_program program_softmax_f16; + cl_program program_softmax_4_f32; + cl_program program_softmax_4_f16; cl_kernel kernel_add, kernel_add_row; cl_kernel kernel_mul, kernel_mul_row; @@ -249,19 +309,17 @@ struct ggml_backend_opencl_context { cl_kernel kernel_mul_mat_f16_f32; cl_kernel kernel_mul_mat_f16_f32_l4; cl_kernel kernel_mul_mat_q4_0_f32, kernel_mul_mat_q4_0_f32_v; - cl_kernel kernel_convert_block_q4_0, kernel_restore_block_q4_0, kernel_mul_mat_q4_0_f32_flat; + cl_kernel kernel_convert_block_q4_0, kernel_restore_block_q4_0; cl_kernel kernel_mul_mat_q4_0_f32_8x_flat; - cl_kernel kernel_convert_block_q4_0_noshuffle, kernel_mul_mat_q4_0_f32_flat_v0, - kernel_mul_mat_q4_0_f32_flat_img_v0; + cl_kernel kernel_convert_block_q4_0_noshuffle; cl_kernel kernel_mul_mat_q4_0_f32_1d_8x_flat, kernel_mul_mat_q4_0_f32_1d_16x_flat; cl_kernel kernel_mul_mv_q6_K_f32; cl_kernel kernel_im2col_f32, kernel_im2col_f16; #ifdef GGML_OPENCL_USE_ADRENO_KERNELS // Transpose kernels - cl_program program_transpose_32; - cl_program program_transpose_32_16; - cl_program program_transpose_16; + cl_program program_transpose; + cl_kernel kernel_transpose_32; cl_kernel kernel_transpose_32_16; cl_kernel kernel_transpose_16; @@ -374,6 +432,681 @@ static cl_program build_program_from_source(cl_context ctx, cl_device_id dev, co return p; } +static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_version opencl_c_version) { + cl_int err; + + // compiler options for general kernels + auto opencl_c_std = + std::string("CL") + std::to_string(opencl_c_version.major) + "." 
+ std::to_string(opencl_c_version.minor); + std::string compile_opts = std::string("-cl-std=") + opencl_c_std + + " -cl-mad-enable -cl-unsafe-math-optimizations" + " -cl-finite-math-only -cl-fast-relaxed-math"; + + GGML_LOG_INFO("ggml_opencl: loading OpenCL kernels"); + + // add + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "add.cl.h" + }; +#else + const std::string kernel_src = read_file("add.cl"); +#endif + backend_ctx->program_add = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_add = clCreateKernel(backend_ctx->program_add, "kernel_add", &err), err)); + CL_CHECK((backend_ctx->kernel_add_row = clCreateKernel(backend_ctx->program_add, "kernel_add_row", &err), err)); + GGML_LOG_CONT("."); + } + + // clamp + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "clamp.cl.h" + }; +#else + const std::string kernel_src = read_file("clamp.cl"); +#endif + backend_ctx->program_clamp = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_clamp = clCreateKernel(backend_ctx->program_clamp, "kernel_clamp", &err), err)); + GGML_LOG_CONT("."); + } + + // cpy + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "cpy.cl.h" + }; +#else + const std::string kernel_src = read_file("cpy.cl"); +#endif + backend_ctx->program_cpy = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_cpy_f16_f16 = clCreateKernel(backend_ctx->program_cpy, "kernel_cpy_f16_f16", &err), err)); + CL_CHECK((backend_ctx->kernel_cpy_f16_f32 = clCreateKernel(backend_ctx->program_cpy, "kernel_cpy_f16_f32", &err), err)); + CL_CHECK((backend_ctx->kernel_cpy_f32_f16 = clCreateKernel(backend_ctx->program_cpy, "kernel_cpy_f32_f16", &err), err)); + CL_CHECK((backend_ctx->kernel_cpy_f32_f32 = clCreateKernel(backend_ctx->program_cpy, "kernel_cpy_f32_f32", &err), err)); + GGML_LOG_CONT("."); + } + + // cvt + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "cvt.cl.h" + }; +#else + const std::string kernel_src = read_file("cvt.cl"); +#endif + backend_ctx->program_cvt = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_convert_block_q4_0_noshuffle = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q4_0_noshuffle", &err), err)); + CL_CHECK((backend_ctx->kernel_convert_block_q4_0 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q4_0", &err), err)); + CL_CHECK((backend_ctx->kernel_restore_block_q4_0 = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_0", &err), err)); + GGML_LOG_CONT("."); + } + + // diag_mask_inf + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "diag_mask_inf.cl.h" + }; +#else + const std::string kernel_src = read_file("diag_mask_inf.cl"); +#endif + backend_ctx->program_diag_mask_inf = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_diag_mask_inf_8 = clCreateKernel(backend_ctx->program_diag_mask_inf, "kernel_diag_mask_inf_8", &err), err)); + CL_CHECK((backend_ctx->kernel_diag_mask_inf = clCreateKernel(backend_ctx->program_diag_mask_inf, "kernel_diag_mask_inf", &err), err)); + GGML_LOG_CONT("."); 
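Every block in load_cl_kernels() follows the same shape: pick the embedded source (the generated *.cl.h header) under GGML_OPENCL_EMBED_KERNELS or read_file() at runtime, build one cl_program, create its kernels, and print a progress dot. As a minimal sketch of that shape — the helper name is hypothetical and not part of the patch, and it assumes CL_CHECK accepts a plain cl_int:

    // Sketch only: one program + one kernel, mirroring each block in load_cl_kernels().
    static cl_kernel load_one_kernel(ggml_backend_opencl_context * backend_ctx,
                                     const std::string & kernel_src,   // embedded or read_file() source
                                     const std::string & compile_opts, // e.g. "-cl-std=CL3.0 ..."
                                     const char * kernel_name,         // e.g. "kernel_diag_mask_inf"
                                     cl_program & program_out) {
        cl_int err;
        program_out = build_program_from_source(backend_ctx->context, backend_ctx->device,
                                                kernel_src.c_str(), compile_opts);
        cl_kernel kernel = clCreateKernel(program_out, kernel_name, &err);
        CL_CHECK(err);
        return kernel;
    }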
+ } + + // gelu + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "gelu.cl.h" + }; +#else + const std::string kernel_src = read_file("gelu.cl"); +#endif + backend_ctx->program_gelu = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_gelu = clCreateKernel(backend_ctx->program_gelu, "kernel_gelu", &err), err)); + CL_CHECK((backend_ctx->kernel_gelu_4 = clCreateKernel(backend_ctx->program_gelu, "kernel_gelu_4", &err), err)); + CL_CHECK((backend_ctx->kernel_gelu_quick = clCreateKernel(backend_ctx->program_gelu, "kernel_gelu_quick", &err), err)); + CL_CHECK((backend_ctx->kernel_gelu_quick_4 = clCreateKernel(backend_ctx->program_gelu, "kernel_gelu_quick_4", &err), err)); + GGML_LOG_CONT("."); + } + + // get_rows + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "get_rows.cl.h" + }; +#else + const std::string kernel_src = read_file("get_rows.cl"); +#endif + backend_ctx->program_get_rows = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_get_rows_f32 = clCreateKernel(backend_ctx->program_get_rows, "kernel_get_rows_f32", &err), err)); + CL_CHECK((backend_ctx->kernel_get_rows_f16 = clCreateKernel(backend_ctx->program_get_rows, "kernel_get_rows_f16", &err), err)); + CL_CHECK((backend_ctx->kernel_get_rows_q4_0 = clCreateKernel(backend_ctx->program_get_rows, "kernel_get_rows_q4_0", &err), err)); + GGML_LOG_CONT("."); + } + + // im2col_f32 + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "im2col_f32.cl.h" + }; +#else + const std::string kernel_src = read_file("im2col_f32.cl"); +#endif + backend_ctx->program_im2col_f32 = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_im2col_f32 = clCreateKernel(backend_ctx->program_im2col_f32, "kernel_im2col_f32", &err), err)); + GGML_LOG_CONT("."); + } + + // im2col_f16 + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "im2col_f16.cl.h" + }; +#else + const std::string kernel_src = read_file("im2col_f16.cl"); +#endif + backend_ctx->program_im2col_f16 = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_im2col_f16 = clCreateKernel(backend_ctx->program_im2col_f16, "kernel_im2col_f16", &err), err)); + GGML_LOG_CONT("."); + } + + // mul_mv_q4_0_f32 + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mul_mv_q4_0_f32.cl.h" + }; +#else + const std::string kernel_src = read_file("mul_mv_q4_0_f32.cl"); +#endif + backend_ctx->program_mul_mv_q4_0_f32 = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mul_mat_q4_0_f32 = clCreateKernel(backend_ctx->program_mul_mv_q4_0_f32, "kernel_mul_mat_q4_0_f32", &err), err)); + GGML_LOG_CONT("."); + } + + // mul_mv_q4_0_f32_v + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mul_mv_q4_0_f32_v.cl.h" + }; +#else + const std::string kernel_src = read_file("mul_mv_q4_0_f32_v.cl"); +#endif + backend_ctx->program_mul_mv_q4_0_f32_v = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mul_mat_q4_0_f32_v = 
clCreateKernel(backend_ctx->program_mul_mv_q4_0_f32_v, "kernel_mul_mat_q4_0_f32_v", &err), err));
+        GGML_LOG_CONT(".");
+    }
+
+    // mul_mv_q4_0_f32_8x_flat
+    {
+#ifdef GGML_OPENCL_EMBED_KERNELS
+        const std::string kernel_src {
+            #include "mul_mv_q4_0_f32_8x_flat.cl.h"
+        };
+#else
+        const std::string kernel_src = read_file("mul_mv_q4_0_f32_8x_flat.cl");
+#endif
+        backend_ctx->program_mul_mv_q4_0_f32_8x_flat =
+            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
+
+        CL_CHECK((backend_ctx->kernel_mul_mat_q4_0_f32_8x_flat = clCreateKernel(backend_ctx->program_mul_mv_q4_0_f32_8x_flat, "kernel_mul_mat_q4_0_f32_8x_flat", &err), err));
+        GGML_LOG_CONT(".");
+    }
+
+    // mul_mv_q4_0_f32_1d_8x_flat
+    // This kernel does not compile on Adreno cl compiler 38.01. Skip it for
+    // those compiler versions since it is not used for Adreno anyway.
+    if (backend_ctx->gpu_family != ADRENO ||
+        backend_ctx->adreno_cl_compiler_version.newer_than_or_same(E031, 38, 11, 0) ||
+        backend_ctx->adreno_cl_compiler_version.type == DX) {
+#ifdef GGML_OPENCL_EMBED_KERNELS
+        const std::string kernel_src {
+            #include "mul_mv_q4_0_f32_1d_8x_flat.cl.h"
+        };
+#else
+        const std::string kernel_src = read_file("mul_mv_q4_0_f32_1d_8x_flat.cl");
+#endif
+        backend_ctx->program_mul_mv_q4_0_f32_1d_8x_flat =
+            build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
+
+        CL_CHECK((backend_ctx->kernel_mul_mat_q4_0_f32_1d_8x_flat = clCreateKernel(backend_ctx->program_mul_mv_q4_0_f32_1d_8x_flat, "kernel_mul_mat_q4_0_f32_1d_8x_flat", &err), err));
+        GGML_LOG_CONT(".");
+    }
+
+    // mul_mv_q4_0_f32_1d_16x_flat
+    // This kernel does not compile on Adreno cl compiler 38.01. Skip it for
+    // those compiler versions since it is not used for Adreno anyway.
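To make the version gate concrete: assuming a driver string containing "Compiler E031.38.01.00" (the exact surrounding text is hypothetical), get_adreno_cl_compiler_version() yields { E031, 38, 1, 0 }, and newer_than_or_same() compares versions encoded as major*10000 + minor*100 + patch:

    // 38.01.00 -> 380100, which is below the 38.11.00 (381100) cutoff:
    ggml_cl_compiler_version v = get_adreno_cl_compiler_version("Compiler E031.38.01.00");
    bool keep = v.newer_than_or_same(E031, 38, 11, 0); // false -> kernel skipped on 38.01

DX-typed compilers are let through regardless of version by the type == DX arm of the condition that follows.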
+ if (backend_ctx->gpu_family != ADRENO || + backend_ctx->adreno_cl_compiler_version.newer_than_or_same(E031, 38, 11, 0) || + backend_ctx->adreno_cl_compiler_version.type == DX) { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mul_mv_q4_0_f32_1d_16x_flat.cl.h" + }; +#else + const std::string kernel_src = read_file("mul_mv_q4_0_f32_1d_16x_flat.cl"); +#endif + backend_ctx->program_mul_mv_q4_0_f32_1d_16x_flat = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mul_mat_q4_0_f32_1d_16x_flat = clCreateKernel(backend_ctx->program_mul_mv_q4_0_f32_1d_16x_flat, "kernel_mul_mat_q4_0_f32_1d_16x_flat", &err), err)); + GGML_LOG_CONT("."); + } + + // mul_mv_q6_k + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mul_mv_q6_k.cl.h" + }; +#else + const std::string kernel_src = read_file("mul_mv_q6_k.cl"); +#endif + backend_ctx->program_mul_mv_q6_K = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mul_mv_q6_K_f32 = clCreateKernel(backend_ctx->program_mul_mv_q6_K, "kernel_mul_mv_q6_K_f32", &err), err)); + GGML_LOG_CONT("."); + } + + // mul_mv_f16_f16 + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mul_mv_f16_f16.cl.h" + }; +#else + const std::string kernel_src = read_file("mul_mv_f16_f16.cl"); +#endif + backend_ctx->program_mul_mv_f16_f16 = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mul_mat_f16_f16 = clCreateKernel(backend_ctx->program_mul_mv_f16_f16, "kernel_mul_mat_f16_f16", &err), err)); + GGML_LOG_CONT("."); + } + + // mul_mv_f16_f32_1row + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mul_mv_f16_f32_1row.cl.h" + }; +#else + const std::string kernel_src = read_file("mul_mv_f16_f32_1row.cl"); +#endif + backend_ctx->program_mul_mv_f16_f32_1row = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mul_mat_f16_f32_1row = clCreateKernel(backend_ctx->program_mul_mv_f16_f32_1row, "kernel_mul_mat_f16_f32_1row", &err), err)); + GGML_LOG_CONT("."); + } + + // mul_mv_f16_f32_l4 + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mul_mv_f16_f32_l4.cl.h" + }; +#else + const std::string kernel_src = read_file("mul_mv_f16_f32_l4.cl"); +#endif + backend_ctx->program_mul_mv_f16_f32_l4 = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mul_mat_f16_f32_l4 = clCreateKernel(backend_ctx->program_mul_mv_f16_f32_l4, "kernel_mul_mat_f16_f32_l4", &err), err)); + GGML_LOG_CONT("."); + } + + // mul_mv_f16_f32 + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mul_mv_f16_f32.cl.h" + }; +#else + const std::string kernel_src = read_file("mul_mv_f16_f32.cl"); +#endif + backend_ctx->program_mul_mv_f16_f32 = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mul_mat_f16_f32 = clCreateKernel(backend_ctx->program_mul_mv_f16_f32, "kernel_mul_mat_f16_f32", &err), err)); + GGML_LOG_CONT("."); + } + + // mul_mv_f32_f32 + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include 
"mul_mv_f32_f32.cl.h" + }; +#else + const std::string kernel_src = read_file("mul_mv_f32_f32.cl"); +#endif + backend_ctx->program_mul_mv_f32_f32 = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mul_mat_f32_f32 = clCreateKernel(backend_ctx->program_mul_mv_f32_f32, "kernel_mul_mat_f32_f32", &err), err)); + GGML_LOG_CONT("."); + } + + // mul + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mul.cl.h" + }; +#else + const std::string kernel_src = read_file("mul.cl"); +#endif + backend_ctx->program_mul = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mul = clCreateKernel(backend_ctx->program_mul, "kernel_mul", &err), err)); + CL_CHECK((backend_ctx->kernel_mul_row = clCreateKernel(backend_ctx->program_mul, "kernel_mul_row", &err), err)); + GGML_LOG_CONT("."); + } + + // norm + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "norm.cl.h" + }; +#else + const std::string kernel_src = read_file("norm.cl"); +#endif + backend_ctx->program_norm = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_norm = clCreateKernel(backend_ctx->program_norm, "kernel_norm", &err), err)); + GGML_LOG_CONT("."); + } + + // relu + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "relu.cl.h" + }; +#else + const std::string kernel_src = read_file("relu.cl"); +#endif + backend_ctx->program_relu = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_relu = clCreateKernel(backend_ctx->program_relu, "kernel_relu", &err), err)); + GGML_LOG_CONT("."); + } + + // rms_norm + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "rms_norm.cl.h" + }; +#else + const std::string kernel_src = read_file("rms_norm.cl"); +#endif + backend_ctx->program_rms_norm = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_rms_norm = clCreateKernel(backend_ctx->program_rms_norm, "kernel_rms_norm", &err), err)); + GGML_LOG_CONT("."); + } + + // rope + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "rope.cl.h" + }; +#else + const std::string kernel_src = read_file("rope.cl"); +#endif + backend_ctx->program_rope = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_rope_norm_f32 = clCreateKernel(backend_ctx->program_rope, "kernel_rope_norm_f32", &err), err)); + CL_CHECK((backend_ctx->kernel_rope_norm_f16 = clCreateKernel(backend_ctx->program_rope, "kernel_rope_norm_f16", &err), err)); + CL_CHECK((backend_ctx->kernel_rope_neox_f32 = clCreateKernel(backend_ctx->program_rope, "kernel_rope_neox_f32", &err), err)); + CL_CHECK((backend_ctx->kernel_rope_neox_f16 = clCreateKernel(backend_ctx->program_rope, "kernel_rope_neox_f16", &err), err)); + CL_CHECK((backend_ctx->kernel_rope_multi_f32 = clCreateKernel(backend_ctx->program_rope, "kernel_rope_multi_f32", &err), err)); + CL_CHECK((backend_ctx->kernel_rope_multi_f16 = clCreateKernel(backend_ctx->program_rope, "kernel_rope_multi_f16", &err), err)); + CL_CHECK((backend_ctx->kernel_rope_vision_f32 = 
clCreateKernel(backend_ctx->program_rope, "kernel_rope_vision_f32", &err), err)); + CL_CHECK((backend_ctx->kernel_rope_vision_f16 = clCreateKernel(backend_ctx->program_rope, "kernel_rope_vision_f16", &err), err)); + GGML_LOG_CONT("."); + } + + // scale + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "scale.cl.h" + }; +#else + const std::string kernel_src = read_file("scale.cl"); +#endif + backend_ctx->program_scale = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_scale = clCreateKernel(backend_ctx->program_scale, "kernel_scale", &err), err)); + GGML_LOG_CONT("."); + } + + // silu + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "silu.cl.h" + }; +#else + const std::string kernel_src = read_file("silu.cl"); +#endif + backend_ctx->program_silu = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_silu = clCreateKernel(backend_ctx->program_silu, "kernel_silu", &err), err)); + CL_CHECK((backend_ctx->kernel_silu_4 = clCreateKernel(backend_ctx->program_silu, "kernel_silu_4", &err), err)); + GGML_LOG_CONT("."); + } + + // softmax_f32 + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "softmax_f32.cl.h" + }; +#else + const std::string kernel_src = read_file("softmax_f32.cl"); +#endif + backend_ctx->program_softmax_f32 = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_soft_max = clCreateKernel(backend_ctx->program_softmax_f32, "kernel_soft_max", &err), err)); + GGML_LOG_CONT("."); + } + + // softmax_f16 + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "softmax_f16.cl.h" + }; +#else + const std::string kernel_src = read_file("softmax_f16.cl"); +#endif + backend_ctx->program_softmax_f16 = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_soft_max_f16 = clCreateKernel(backend_ctx->program_softmax_f16, "kernel_soft_max_f16", &err), err)); + GGML_LOG_CONT("."); + } + + // softmax_4_f32 + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "softmax_4_f32.cl.h" + }; +#else + const std::string kernel_src = read_file("softmax_4_f32.cl"); +#endif + backend_ctx->program_softmax_4_f32 = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_soft_max_4 = clCreateKernel(backend_ctx->program_softmax_4_f32, "kernel_soft_max_4", &err), err)); + GGML_LOG_CONT("."); + } + + // softmax_4_f16 + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "softmax_4_f16.cl.h" + }; +#else + const std::string kernel_src = read_file("softmax_4_f16.cl"); +#endif + backend_ctx->program_softmax_4_f16 = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_soft_max_4_f16 = clCreateKernel(backend_ctx->program_softmax_4_f16, "kernel_soft_max_4_f16", &err), err)); + GGML_LOG_CONT("."); + } + + // Adreno kernels +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + // transpose + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "transpose.cl.h" + }; +#else + const std::string kernel_src = read_file("transpose.cl"); 
+#endif + backend_ctx->program_transpose = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_transpose_32_16 = clCreateKernel(backend_ctx->program_transpose, "kernel_transpose_32_16", &err), err)); + CL_CHECK((backend_ctx->kernel_transpose_32 = clCreateKernel(backend_ctx->program_transpose, "kernel_transpose_32", &err), err)); + CL_CHECK((backend_ctx->kernel_transpose_16 = clCreateKernel(backend_ctx->program_transpose, "kernel_transpose_16", &err), err)); + GGML_LOG_CONT("."); + } + + // gemv_noshuffle_general + { + std::string CL_gemv_compile_opts = std::string("-cl-std=") + opencl_c_std + + " -cl-mad-enable " + " -DSIMDGROUP_WIDTH=" + + std::to_string(backend_ctx->adreno_wave_size); + if (backend_ctx->has_vector_subgroup_broadcast) { + CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAT "; + } + +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src_CL_gemv_general { + #include "gemv_noshuffle_general.cl.h" + }; +#else + const std::string kernel_src_CL_gemv_general = read_file("gemv_noshuffle_general.cl"); +#endif + + backend_ctx->program_CL_gemv_general = build_program_from_source( + backend_ctx->context, backend_ctx->device, kernel_src_CL_gemv_general.c_str(), CL_gemv_compile_opts); + + CL_CHECK((backend_ctx->CL_mul_mat_vec_q4_0_f32_1d_4x_flat_general = clCreateKernel(backend_ctx->program_CL_gemv_general, "kernel_gemv_noshuffle", &err), err)); + GGML_LOG_CONT("."); + } + + // gemv_noshuffle + { + // Gemv 2048, 16384 + std::string CL_gemv_compile_opts = std::string("-cl-std=") + opencl_c_std + + " -cl-mad-enable " + " -DLINE_STRIDE_A=2048 " + " -DBLOCK_STRIDE_A=16384 " + " -DSIMDGROUP_WIDTH=" + + std::to_string(backend_ctx->adreno_wave_size); + if (backend_ctx->has_vector_subgroup_broadcast) { + CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAT "; + } + +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src_CL_gemv { + #include "gemv_noshuffle.cl.h" + }; +#else + const std::string kernel_src_CL_gemv = read_file("gemv_noshuffle.cl"); +#endif + + backend_ctx->program_CL_gemv_4096_1_4096 = build_program_from_source( + backend_ctx->context, backend_ctx->device, kernel_src_CL_gemv.c_str(), CL_gemv_compile_opts); + CL_CHECK((backend_ctx->CL_mul_mat_vec_q4_0_f32_1d_4x_flat_4096_1_4096 = clCreateKernel(backend_ctx->program_CL_gemv_4096_1_4096, "kernel_gemv_noshuffle", &err), err)); + GGML_LOG_CONT("."); + + // Gemv 2048, 16384 + CL_gemv_compile_opts = std::string("-cl-std=") + opencl_c_std + + " -cl-mad-enable " + " -DLINE_STRIDE_A=2048 " + " -DBLOCK_STRIDE_A=16384 " + " -DSIMDGROUP_WIDTH=" + + std::to_string(backend_ctx->adreno_wave_size); + if (backend_ctx->has_vector_subgroup_broadcast) { + CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAT "; + } + + backend_ctx->program_CL_gemv_4096_1_11008 = build_program_from_source( + backend_ctx->context, backend_ctx->device, kernel_src_CL_gemv.c_str(), CL_gemv_compile_opts); + CL_CHECK((backend_ctx->CL_mul_mat_vec_q4_0_f32_1d_4x_flat_4096_1_11008 = clCreateKernel(backend_ctx->program_CL_gemv_4096_1_11008, "kernel_gemv_noshuffle", &err), err)); + GGML_LOG_CONT("."); + + // Gemv 5504, 44032 + CL_gemv_compile_opts = std::string("-cl-std=") + opencl_c_std + + " -cl-mad-enable " + " -DLINE_STRIDE_A=5504 " + " -DBLOCK_STRIDE_A=44032 " + " -DSIMDGROUP_WIDTH=" + + std::to_string(backend_ctx->adreno_wave_size); + if (backend_ctx->has_vector_subgroup_broadcast) { + CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAT "; + } + + 
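Each of these shape-specialized gemv builds compiles the same gemv_noshuffle.cl source with different constants baked in via -D, trading several program binaries for the compiler knowing the strides statically. For the 5504/44032 variant assembled here, a wave-size-64 Adreno with the broadcast extension would hand the build call below roughly the following options (OpenCL C 3.0 and wave size 64 are illustrative, spacing approximate):

    -cl-std=CL3.0 -cl-mad-enable -DLINE_STRIDE_A=5504 -DBLOCK_STRIDE_A=44032 -DSIMDGROUP_WIDTH=64 -DVECTOR_SUB_GROUP_BROADCAT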
backend_ctx->program_CL_gemv_11008_1_4096 = build_program_from_source( + backend_ctx->context, backend_ctx->device, kernel_src_CL_gemv.c_str(), CL_gemv_compile_opts); + CL_CHECK((backend_ctx->CL_mul_mat_vec_q4_0_f32_1d_4x_flat_11008_1_4096 = clCreateKernel(backend_ctx->program_CL_gemv_11008_1_4096, "kernel_gemv_noshuffle", &err), err)); + GGML_LOG_CONT("."); + + // Gemv 16000, 128000 + CL_gemv_compile_opts = std::string("-cl-std=") + opencl_c_std + + " -cl-mad-enable " + " -DLINE_STRIDE_A=16000 " + " -DBLOCK_STRIDE_A=128000 " + " -DSIMDGROUP_WIDTH=" + + std::to_string(backend_ctx->adreno_wave_size); + + if (backend_ctx->has_vector_subgroup_broadcast) { + CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAT "; + } + + backend_ctx->program_CL_gemv_32000_1_4096 = build_program_from_source( + backend_ctx->context, backend_ctx->device, kernel_src_CL_gemv.c_str(), CL_gemv_compile_opts); + CL_CHECK((backend_ctx->CL_mul_mat_vec_q4_0_f32_1d_4x_flat_32000_1_4096 = clCreateKernel(backend_ctx->program_CL_gemv_32000_1_4096, "kernel_gemv_noshuffle", &err), err)); + GGML_LOG_CONT("."); + } + + // mul_mat_Ab_Bi_8x4 + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src_CL_gemm { + #include "mul_mat_Ab_Bi_8x4.cl.h" + }; +#else + const std::string kernel_src_CL_gemm = read_file("mul_mat_Ab_Bi_8x4.cl"); +#endif + backend_ctx->program_CL_gemm = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src_CL_gemm.c_str(), compile_opts); + CL_CHECK((backend_ctx->CL_mul_mat_Ab_Bi_8x4 = clCreateKernel(backend_ctx->program_CL_gemm, "kernel_mul_mat_Ab_Bi_8x4", &err), err)); + GGML_LOG_CONT("."); + } +#endif // GGML_OPENCL_USE_ADRENO_KERNELS + GGML_LOG_CONT("\n"); +} + static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) { static bool initialized = false; static ggml_backend_opencl_context *backend_ctx = nullptr; @@ -612,11 +1345,12 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) { GGML_LOG_INFO("ggml_opencl: OpenCL driver: %s\n", driver_version); backend_ctx->driver_version = driver_version; - int adreno_cl_compiler_version = get_adreno_cl_compiler_version(driver_version); - bool has_vector_subgroup_broadcast = - adreno_cl_compiler_version >= 47 || adreno_cl_compiler_version == 17; + backend_ctx->adreno_cl_compiler_version = get_adreno_cl_compiler_version(driver_version); + backend_ctx->has_vector_subgroup_broadcast = + backend_ctx->adreno_cl_compiler_version.major >= 47 || + backend_ctx->adreno_cl_compiler_version.major == 17; GGML_LOG_INFO("ggml_opencl: vector subgroup broadcast support: %s\n", - has_vector_subgroup_broadcast ? "true" : "false"); + backend_ctx->has_vector_subgroup_broadcast ? "true" : "false"); size_t ext_str_size; clGetDeviceInfo(device, CL_DEVICE_EXTENSIONS, 0, NULL, &ext_str_size); @@ -691,247 +1425,10 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) { #endif CL_CHECK((backend_ctx->queue = clCreateCommandQueue(context, device, command_queue_props, &err), err)); -#ifdef GGML_OPENCL_EMBED_KERNELS - const std::string kernel_src { - #include "ggml-opencl.cl.h" - }; -#else - const std::string kernel_src = read_file("ggml-opencl.cl"); -#endif + // Load kernels + load_cl_kernels(backend_ctx, opencl_c_version); - auto opencl_c_std = - std::string("CL") + std::to_string(opencl_c_version.major) + "." 
+ std::to_string(opencl_c_version.minor); - - std::string compile_opts = std::string("-cl-std=") + opencl_c_std + - " -cl-mad-enable -cl-unsafe-math-optimizations" - " -cl-finite-math-only -cl-fast-relaxed-math"; - backend_ctx->program = build_program_from_source(context, device, kernel_src.c_str(), compile_opts); - - // Non matmul kernels. - CL_CHECK((backend_ctx->kernel_get_rows_f32 = clCreateKernel(backend_ctx->program, "kernel_get_rows_f32", &err), err)); - CL_CHECK((backend_ctx->kernel_get_rows_f16 = clCreateKernel(backend_ctx->program, "kernel_get_rows_f16", &err), err)); - CL_CHECK((backend_ctx->kernel_get_rows_q4_0 = clCreateKernel(backend_ctx->program, "kernel_get_rows_q4_0", &err), err)); - CL_CHECK((backend_ctx->kernel_add = clCreateKernel(backend_ctx->program, "kernel_add", &err), err)); - CL_CHECK((backend_ctx->kernel_add_row = clCreateKernel(backend_ctx->program, "kernel_add_row", &err), err)); - CL_CHECK((backend_ctx->kernel_mul = clCreateKernel(backend_ctx->program, "kernel_mul", &err), err)); - CL_CHECK((backend_ctx->kernel_mul_row = clCreateKernel(backend_ctx->program, "kernel_mul_row", &err), err)); - CL_CHECK((backend_ctx->kernel_scale = clCreateKernel(backend_ctx->program, "kernel_scale", &err), err)); - CL_CHECK((backend_ctx->kernel_silu = clCreateKernel(backend_ctx->program, "kernel_silu", &err), err)); - CL_CHECK((backend_ctx->kernel_silu_4 = clCreateKernel(backend_ctx->program, "kernel_silu_4", &err), err)); - CL_CHECK((backend_ctx->kernel_gelu = clCreateKernel(backend_ctx->program, "kernel_gelu", &err), err)); - CL_CHECK((backend_ctx->kernel_gelu_4 = clCreateKernel(backend_ctx->program, "kernel_gelu_4", &err), err)); - CL_CHECK((backend_ctx->kernel_gelu_quick = clCreateKernel(backend_ctx->program, "kernel_gelu_quick", &err), err)); - CL_CHECK((backend_ctx->kernel_gelu_quick_4 = clCreateKernel(backend_ctx->program, "kernel_gelu_quick_4", &err), err)); - CL_CHECK((backend_ctx->kernel_relu = clCreateKernel(backend_ctx->program, "kernel_relu", &err), err)); - CL_CHECK((backend_ctx->kernel_clamp = clCreateKernel(backend_ctx->program, "kernel_clamp", &err), err)); - CL_CHECK((backend_ctx->kernel_norm = clCreateKernel(backend_ctx->program, "kernel_norm", &err), err)); - CL_CHECK((backend_ctx->kernel_rms_norm = clCreateKernel(backend_ctx->program, "kernel_rms_norm", &err), err)); - CL_CHECK((backend_ctx->kernel_diag_mask_inf = clCreateKernel(backend_ctx->program, "kernel_diag_mask_inf", &err), err)); - CL_CHECK((backend_ctx->kernel_diag_mask_inf_8 = clCreateKernel(backend_ctx->program, "kernel_diag_mask_inf_8", &err), err)); - CL_CHECK((backend_ctx->kernel_soft_max = clCreateKernel(backend_ctx->program, "kernel_soft_max", &err), err)); - CL_CHECK((backend_ctx->kernel_soft_max_4 = clCreateKernel(backend_ctx->program, "kernel_soft_max_4", &err), err)); - CL_CHECK((backend_ctx->kernel_soft_max_f16 = clCreateKernel(backend_ctx->program, "kernel_soft_max_f16", &err), err)); - CL_CHECK((backend_ctx->kernel_soft_max_4_f16 = clCreateKernel(backend_ctx->program, "kernel_soft_max_4_f16", &err), err)); - CL_CHECK((backend_ctx->kernel_rope_norm_f32 = clCreateKernel(backend_ctx->program, "kernel_rope_norm_f32", &err), err)); - CL_CHECK((backend_ctx->kernel_rope_norm_f16 = clCreateKernel(backend_ctx->program, "kernel_rope_norm_f16", &err), err)); - CL_CHECK((backend_ctx->kernel_rope_neox_f32 = clCreateKernel(backend_ctx->program, "kernel_rope_neox_f32", &err), err)); - CL_CHECK((backend_ctx->kernel_rope_neox_f16 = clCreateKernel(backend_ctx->program, "kernel_rope_neox_f16", &err), 
err)); - CL_CHECK((backend_ctx->kernel_rope_multi_f32 = clCreateKernel(backend_ctx->program, "kernel_rope_multi_f32", &err), err)); - CL_CHECK((backend_ctx->kernel_rope_multi_f16 = clCreateKernel(backend_ctx->program, "kernel_rope_multi_f16", &err), err)); - CL_CHECK((backend_ctx->kernel_rope_vision_f32 = clCreateKernel(backend_ctx->program, "kernel_rope_vision_f32", &err), err)); - CL_CHECK((backend_ctx->kernel_rope_vision_f16 = clCreateKernel(backend_ctx->program, "kernel_rope_vision_f16", &err), err)); - CL_CHECK((backend_ctx->kernel_cpy_f16_f16 = clCreateKernel(backend_ctx->program, "kernel_cpy_f16_f16", &err), err)); - CL_CHECK((backend_ctx->kernel_cpy_f16_f32 = clCreateKernel(backend_ctx->program, "kernel_cpy_f16_f32", &err), err)); - CL_CHECK((backend_ctx->kernel_cpy_f32_f16 = clCreateKernel(backend_ctx->program, "kernel_cpy_f32_f16", &err), err)); - CL_CHECK((backend_ctx->kernel_cpy_f32_f32 = clCreateKernel(backend_ctx->program, "kernel_cpy_f32_f32", &err), err)); - - // Matmul kernels. - CL_CHECK((backend_ctx->kernel_mul_mat_f32_f32 = clCreateKernel(backend_ctx->program, "kernel_mul_mat_f32_f32", &err), err)); - CL_CHECK((backend_ctx->kernel_mul_mat_f16_f16 = clCreateKernel(backend_ctx->program, "kernel_mul_mat_f16_f16", &err), err)); - CL_CHECK((backend_ctx->kernel_mul_mat_f16_f32_1row = clCreateKernel(backend_ctx->program, "kernel_mul_mat_f16_f32_1row", &err), err)); - CL_CHECK((backend_ctx->kernel_mul_mat_f16_f32 = clCreateKernel(backend_ctx->program, "kernel_mul_mat_f16_f32", &err), err)); - CL_CHECK((backend_ctx->kernel_mul_mat_f16_f32_l4 = clCreateKernel(backend_ctx->program, "kernel_mul_mat_f16_f32_l4", &err), err)); - CL_CHECK((backend_ctx->kernel_mul_mat_q4_0_f32 = clCreateKernel(backend_ctx->program, "kernel_mul_mat_q4_0_f32", &err), err)); - CL_CHECK((backend_ctx->kernel_mul_mat_q4_0_f32_v = clCreateKernel(backend_ctx->program, "kernel_mul_mat_q4_0_f32_v", &err), err)); - - CL_CHECK((backend_ctx->kernel_mul_mat_q4_0_f32_flat = clCreateKernel(backend_ctx->program, "kernel_mul_mat_q4_0_f32_flat", &err), err)); - CL_CHECK((backend_ctx->kernel_convert_block_q4_0 = clCreateKernel(backend_ctx->program, "kernel_convert_block_q4_0", &err), err)); - CL_CHECK((backend_ctx->kernel_restore_block_q4_0 = clCreateKernel(backend_ctx->program, "kernel_restore_block_q4_0", &err), err)); - CL_CHECK((backend_ctx->kernel_mul_mat_q4_0_f32_8x_flat = clCreateKernel(backend_ctx->program, "kernel_mul_mat_q4_0_f32_8x_flat", &err), err)); - - // Load additional mulmat kernels. 
-#ifdef GGML_OPENCL_EMBED_KERNELS - const std::string kernel_src_1 { - #include "ggml-opencl_mm.cl.h" - }; -#else - const std::string kernel_src_1 = read_file("ggml-opencl_mm.cl"); -#endif - backend_ctx->program_1 = build_program_from_source(context, device, kernel_src_1.c_str(), compile_opts); - - CL_CHECK((backend_ctx->kernel_mul_mat_q4_0_f32_1d_8x_flat = clCreateKernel(backend_ctx->program_1, "kernel_mul_mat_q4_0_f32_1d_8x_flat", &err), err)); - CL_CHECK((backend_ctx->kernel_mul_mat_q4_0_f32_1d_16x_flat = clCreateKernel(backend_ctx->program_1, "kernel_mul_mat_q4_0_f32_1d_16x_flat", &err), err)); - CL_CHECK((backend_ctx->kernel_mul_mv_q6_K_f32 = clCreateKernel(backend_ctx->program_1, "kernel_mul_mv_q6_K_f32", &err), err)); - CL_CHECK((backend_ctx->kernel_mul_mat_q4_0_f32_flat_v0 = clCreateKernel(backend_ctx->program_1, "kernel_mul_mat_q4_0_f32_flat_v0", &err), err)); - CL_CHECK((backend_ctx->kernel_mul_mat_q4_0_f32_flat_img_v0 = clCreateKernel(backend_ctx->program_1, "kernel_mul_mat_q4_0_f32_flat_img_v0", &err), err)); - - // Load additional data conversion kernels. -#ifdef GGML_OPENCL_EMBED_KERNELS - const std::string kernel_src_2 { - #include "ggml-opencl_cvt.cl.h" - }; -#else - const std::string kernel_src_2 = read_file("ggml-opencl_cvt.cl"); -#endif - backend_ctx->program_2 = build_program_from_source(context, device, kernel_src_2.c_str(), compile_opts); - - CL_CHECK((backend_ctx->kernel_convert_block_q4_0_noshuffle = clCreateKernel(backend_ctx->program_2, "kernel_convert_block_q4_0_noshuffle", &err), err)); - - // im2col kernels -#ifdef GGML_OPENCL_EMBED_KERNELS - const std::string kernel_src_im2col { - #include "ggml-opencl_im2col.cl.h" - }; -#else - const std::string kernel_src_im2col = read_file("ggml-opencl_im2col.cl"); -#endif - backend_ctx->program_im2col = build_program_from_source(context, device, kernel_src_im2col.c_str(), compile_opts); - - CL_CHECK((backend_ctx->kernel_im2col_f32 = clCreateKernel(backend_ctx->program_im2col, "kernel_im2col_f32", &err), err)); - CL_CHECK((backend_ctx->kernel_im2col_f16 = clCreateKernel(backend_ctx->program_im2col, "kernel_im2col_f16", &err), err)); - - // Kernels for Adreno #ifdef GGML_OPENCL_USE_ADRENO_KERNELS -#ifdef GGML_OPENCL_EMBED_KERNELS - const std::string transpose_32_src { - #include "ggml-opencl_transpose_32.cl.h" - }; -#else - const std::string transpose_32_src = read_file("ggml-opencl_transpose_32.cl"); -#endif - backend_ctx->program_transpose_32 = build_program_from_source(context, device, transpose_32_src.c_str(), compile_opts); - CL_CHECK((backend_ctx->kernel_transpose_32 = clCreateKernel(backend_ctx->program_transpose_32, "kernel_transpose_32", &err), err)); - -#ifdef GGML_OPENCL_EMBED_KERNELS - const std::string transpose_32_16_src { - #include "ggml-opencl_transpose_32_16.cl.h" - }; -#else - const std::string transpose_32_16_src = read_file("ggml-opencl_transpose_32_16.cl"); -#endif - backend_ctx->program_transpose_32_16 = build_program_from_source(context, device, transpose_32_16_src.c_str(), compile_opts); - CL_CHECK((backend_ctx->kernel_transpose_32_16 = clCreateKernel(backend_ctx->program_transpose_32_16, "kernel_transpose_32_16", &err), err)); - -#ifdef GGML_OPENCL_EMBED_KERNELS - const std::string transpose_16_src { - #include "ggml-opencl_transpose_16.cl.h" - }; -#else - const std::string transpose_16_src = read_file("ggml-opencl_transpose_16.cl"); -#endif - backend_ctx->program_transpose_16 = build_program_from_source(context, device, transpose_16_src.c_str(), compile_opts); - 
CL_CHECK((backend_ctx->kernel_transpose_16 = clCreateKernel(backend_ctx->program_transpose_16, "kernel_transpose_16", &err), err)); - - // Gemv general - std::string CL_gemv_compile_opts = std::string("-cl-std=") + opencl_c_std + - " -cl-mad-enable " - " -DSIMDGROUP_WIDTH=" + - std::to_string(backend_ctx->adreno_wave_size); - if (has_vector_subgroup_broadcast) { - CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAT "; - } -#ifdef GGML_OPENCL_EMBED_KERNELS - const std::string kernel_src_CL_gemv_general { - #include "ggml-opencl_gemv_noshuffle_general.cl.h" - }; -#else - const std::string kernel_src_CL_gemv_general = read_file("ggml-opencl_gemv_noshuffle_general.cl"); -#endif - - backend_ctx->program_CL_gemv_general = build_program_from_source( - context, device, kernel_src_CL_gemv_general.c_str(), CL_gemv_compile_opts); - CL_CHECK((backend_ctx->CL_mul_mat_vec_q4_0_f32_1d_4x_flat_general = clCreateKernel(backend_ctx->program_CL_gemv_general, "kernel_gemv_noshuffle", &err), err)); - - // Gemv 2048, 16384 - CL_gemv_compile_opts = std::string("-cl-std=") + opencl_c_std + - " -cl-mad-enable " - " -DLINE_STRIDE_A=2048 " - " -DBLOCK_STRIDE_A=16384 " - " -DSIMDGROUP_WIDTH=" + - std::to_string(backend_ctx->adreno_wave_size); - if (has_vector_subgroup_broadcast) { - CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAT "; - } -#ifdef GGML_OPENCL_EMBED_KERNELS - const std::string kernel_src_CL_gemv { - #include "ggml-opencl_gemv_noshuffle.cl.h" - }; -#else - const std::string kernel_src_CL_gemv = read_file("ggml-opencl_gemv_noshuffle.cl"); -#endif - - backend_ctx->program_CL_gemv_4096_1_4096 = build_program_from_source( - context, device, kernel_src_CL_gemv.c_str(), CL_gemv_compile_opts); - CL_CHECK((backend_ctx->CL_mul_mat_vec_q4_0_f32_1d_4x_flat_4096_1_4096 = clCreateKernel(backend_ctx->program_CL_gemv_4096_1_4096, "kernel_gemv_noshuffle", &err), err)); - - // Gemv 2048, 16384 - CL_gemv_compile_opts = std::string("-cl-std=") + opencl_c_std + - " -cl-mad-enable " - " -DLINE_STRIDE_A=2048 " - " -DBLOCK_STRIDE_A=16384 " - " -DSIMDGROUP_WIDTH=" + - std::to_string(backend_ctx->adreno_wave_size); - if (has_vector_subgroup_broadcast) { - CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAT "; - } - - backend_ctx->program_CL_gemv_4096_1_11008 = build_program_from_source( - context, device, kernel_src_CL_gemv.c_str(), CL_gemv_compile_opts); - CL_CHECK((backend_ctx->CL_mul_mat_vec_q4_0_f32_1d_4x_flat_4096_1_11008 = clCreateKernel(backend_ctx->program_CL_gemv_4096_1_11008, "kernel_gemv_noshuffle", &err), err)); - - // Gemv 5504, 44032 - CL_gemv_compile_opts = std::string("-cl-std=") + opencl_c_std + - " -cl-mad-enable " - " -DLINE_STRIDE_A=5504 " - " -DBLOCK_STRIDE_A=44032 " - " -DSIMDGROUP_WIDTH=" + - std::to_string(backend_ctx->adreno_wave_size); - if (has_vector_subgroup_broadcast) { - CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAT "; - } - - backend_ctx->program_CL_gemv_11008_1_4096 = build_program_from_source( - context, device, kernel_src_CL_gemv.c_str(), CL_gemv_compile_opts); - CL_CHECK((backend_ctx->CL_mul_mat_vec_q4_0_f32_1d_4x_flat_11008_1_4096 = clCreateKernel(backend_ctx->program_CL_gemv_11008_1_4096, "kernel_gemv_noshuffle", &err), err)); - - // Gemv 16000, 128000 - CL_gemv_compile_opts = std::string("-cl-std=") + opencl_c_std + - " -cl-mad-enable " - " -DLINE_STRIDE_A=16000 " - " -DBLOCK_STRIDE_A=128000 " - " -DSIMDGROUP_WIDTH=" + - std::to_string(backend_ctx->adreno_wave_size); - if (has_vector_subgroup_broadcast) { - CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAT "; - } - - 
backend_ctx->program_CL_gemv_32000_1_4096 = build_program_from_source(context, device, kernel_src_CL_gemv.c_str(), CL_gemv_compile_opts); - CL_CHECK((backend_ctx->CL_mul_mat_vec_q4_0_f32_1d_4x_flat_32000_1_4096 = clCreateKernel(backend_ctx->program_CL_gemv_32000_1_4096, "kernel_gemv_noshuffle", &err), err)); - - // Gemm -#ifdef GGML_OPENCL_EMBED_KERNELS - const std::string kernel_src_CL_gemm { - #include "ggml-opencl_mul_mat_Ab_Bi_8x4.cl.h" - }; -#else - const std::string kernel_src_CL_gemm = read_file("ggml-opencl_mul_mat_Ab_Bi_8x4.cl"); -#endif - backend_ctx->program_CL_gemm = build_program_from_source(context, device, kernel_src_CL_gemm.c_str(), compile_opts); - CL_CHECK((backend_ctx->CL_mul_mat_Ab_Bi_8x4 = clCreateKernel(backend_ctx->program_CL_gemm, "kernel_mul_mat_Ab_Bi_8x4", &err), err)); - - // TODO: fixme: these sizes are hardcoded for now. - // they should be allocated based on the model's size - // and the device's max alloc size // Allocate intermediate buffers and images size_t required_A_q_d_bytes = 311164928; size_t required_A_s_d_bytes = 38895616; @@ -1495,8 +1992,15 @@ static enum ggml_status ggml_backend_opencl_buffer_init_tensor(ggml_backend_buff // The optimized gemm and gemv kernels are used for large matrices without batch. // tensor is the quantized weights matrix. -inline bool use_adreno_kernels(const ggml_tensor *tensor) { - return tensor->ne[0] >= 512 && tensor->ne[1] >= 512 && +inline bool use_adreno_kernels(const ggml_backend_opencl_context *backend_ctx, const ggml_tensor *tensor) { + int64_t threshold_ne0 = 512; + int64_t threshold_ne1 = 512; + if (!backend_ctx->adreno_cl_compiler_version.newer_than_or_same(E031, 38, 11, 0) && + backend_ctx->adreno_cl_compiler_version.type != DX) { + threshold_ne0 = 128; + threshold_ne1 = 128; + } + return tensor->ne[0] >= threshold_ne0 && tensor->ne[1] >= threshold_ne1 && tensor->ne[2] == 1 && tensor->ne[3] == 1; } @@ -1574,7 +2078,7 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, cl_kernel kernel = backend_ctx->kernel_convert_block_q4_0; // The optimized kernels need weights in natural order, so unshuffle. 
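The updated use_adreno_kernels() above now takes the backend context because the size cutoff depends on the compiler: on pre-38.11 E031 compilers the threshold drops from 512x512 to 128x128, presumably because the mul_mv_q4_0_f32_1d_*x_flat fallbacks are skipped for those same compilers during loading. For example (shapes illustrative):

    // E031 37.xx (old):  thresholds 128/128 -> a 256 x 4096 q4_0 weight takes the Adreno path
    // E031 38.11+ or DX: thresholds 512/512 -> the same weight stays on the generic kernels

The call sites below pick up the extra backend_ctx argument accordingly.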
- if (use_adreno_kernels(tensor)) { + if (use_adreno_kernels(backend_ctx, tensor)) { kernel = backend_ctx->kernel_convert_block_q4_0_noshuffle; } #else @@ -1598,7 +2102,7 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, #ifdef GGML_OPENCL_USE_ADRENO_KERNELS // Only do transpose for large, non batched matrix // TODO: use preallocated images instead of sub-buffer then image - if (use_adreno_kernels(tensor)) { + if (use_adreno_kernels(backend_ctx, tensor)) { // <----------------------------------------------------------------------------------> // // start transpose // <----------------------------------------------------------------------------------> // @@ -2899,8 +3403,8 @@ static void ggml_cl_rms_norm(ggml_backend_t backend, const ggml_tensor * src0, c ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; cl_command_queue queue = backend_ctx->queue; - ggml_backend_opencl_device_context * dev_ctx = - (ggml_backend_opencl_device_context *)backend->device->context; + //ggml_backend_opencl_device_context * dev_ctx = + // (ggml_backend_opencl_device_context *)backend->device->context; ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; @@ -2931,13 +3435,20 @@ static void ggml_cl_rms_norm(ggml_backend_t backend, const ggml_tensor * src0, c // Note, this kernel declares local memory in kernel args and the size // depends on subgroup size. - // Retrieve subgroup size. // Note, this requires OpenCL 2.1 and above + // For now we use fixed subgroup size to simplify support for OpenCL 2.0. size_t sgs; - CL_CHECK(clGetKernelSubGroupInfo(kernel, dev_ctx->device, - CL_KERNEL_MAX_SUB_GROUP_SIZE_FOR_NDRANGE, - sizeof(local_work_size), local_work_size, - sizeof(size_t), &sgs, NULL)); + //CL_CHECK(clGetKernelSubGroupInfo(kernel, dev_ctx->device, + // CL_KERNEL_MAX_SUB_GROUP_SIZE_FOR_NDRANGE, + // sizeof(local_work_size), local_work_size, + // sizeof(size_t), &sgs, NULL)); + if (backend_ctx->gpu_family == ADRENO) { + sgs = 64; + } else if (backend_ctx->gpu_family == INTEL) { + sgs = 32; + } else { + GGML_ASSERT(false && "Unsupported GPU"); + } CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); @@ -3030,7 +3541,7 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co #ifdef GGML_OPENCL_USE_ADRENO_KERNELS cl_context context = backend_ctx->context; - if (ne01 && ne1 && use_adreno_kernels(src0)) { + if (ne01 && ne1 && use_adreno_kernels(backend_ctx, src0)) { // init CL objects // <--------------------------------------------> // diff --git a/ggml/src/ggml-opencl/kernels/add.cl b/ggml/src/ggml-opencl/kernels/add.cl new file mode 100644 index 0000000000..f73f3c0134 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/add.cl @@ -0,0 +1,83 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +//------------------------------------------------------------------------------ +// add +//------------------------------------------------------------------------------ + +// general-purpose kernel for addition of two tensors +// pros: works for non-contiguous tensors, supports broadcast across dims 1, 2 and 3 +// cons: not very efficient +kernel void kernel_add( + global char * src0, + ulong offset0, + global char * src1, + ulong offset1, + global char * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb00, + ulong nb01, + ulong 
nb02, + ulong nb03, + int ne10, + int ne11, + int ne12, + int ne13, + ulong nb10, + ulong nb11, + ulong nb12, + ulong nb13, + int ne0, + int ne1, + int ne2, + int ne3, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3 +) { + src0 = src0 + offset0; + src1 = src1 + offset1; + dst = dst + offsetd; + + int i03 = get_group_id(2); + int i02 = get_group_id(1); + int i01 = get_group_id(0); + + int i13 = i03 % ne13; + int i12 = i02 % ne12; + int i11 = i01 % ne11; + + global char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01; + global char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11; + global char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1; + + for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) { + const int i10 = i0 % ne10; + *((global float *)(dst_ptr + i0*nb0)) = *((global float *)(src0_ptr + i0*nb00)) + *((global float *)(src1_ptr + i10*nb10)); + } +} + +// assumption: src1 is a row +// broadcast src1 into src0 +kernel void kernel_add_row( + global float4 * src0, + ulong offset0, + global float4 * src1, + ulong offset1, + global float4 * dst, + ulong offsetd, + int ne +) { + src0 = (global float4*)((global char*)src0 + offset0); + src1 = (global float4*)((global char*)src1 + offset1); + dst = (global float4*)((global char*)dst + offsetd); + + // This performs better than using %. + uint gid = get_global_id(0); + uint idx1 = gid - (gid/ne)*ne; // get_global_id(0) % ne + dst[gid] = src0[gid] + src1[idx1]; +} diff --git a/ggml/src/ggml-opencl/kernels/clamp.cl b/ggml/src/ggml-opencl/kernels/clamp.cl new file mode 100644 index 0000000000..ae6032444e --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/clamp.cl @@ -0,0 +1,20 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +//------------------------------------------------------------------------------ +// clamp +//------------------------------------------------------------------------------ +kernel void kernel_clamp( + global float * src0, + ulong offset0, + global float * dst, + ulong offsetd, + float min, + float max +) { + src0 = (global float*)((global char*)src0 + offset0); + dst = (global float*)((global char*)dst + offsetd); + + dst[get_global_id(0)] = src0[get_global_id(0)] < min ? + min : + (src0[get_global_id(0)] > max ? 
max : src0[get_global_id(0)]); +} diff --git a/ggml/src/ggml-opencl/kernels/cpy.cl b/ggml/src/ggml-opencl/kernels/cpy.cl new file mode 100644 index 0000000000..9369351a60 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/cpy.cl @@ -0,0 +1,184 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +//------------------------------------------------------------------------------ +// cpy +//------------------------------------------------------------------------------ + +kernel void kernel_cpy_f16_f16( + global half * src0, + ulong offset0, + global half * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne0, + int ne1, + int ne2, + int ne3, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3 +) { + src0 = (global half*)((global char*)src0 + offset0); + dst = (global half*)((global char*)dst + offsetd); + + int i03 = get_group_id(2); + int i02 = get_group_id(1); + int i01 = get_group_id(0); + + int n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + + int i3 = n / (ne2*ne1*ne0); + int i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); + int i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; + int i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0); + + global half * dst_data = (global half *) ((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { + global const half * src = (global half *)((global char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); + dst_data[i00] = src[0]; + } +} + +kernel void kernel_cpy_f16_f32( + global half * src0, + ulong offset0, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne0, + int ne1, + int ne2, + int ne3, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3 +) { + + src0 = (global half*)((global char*)src0 + offset0); + dst = (global float*)((global char*)dst + offsetd); + + int i03 = get_group_id(2); + int i02 = get_group_id(1); + int i01 = get_group_id(0); + + int n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + + int i3 = n / (ne2*ne1*ne0); + int i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); + int i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; + int i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0); + + global float * dst_data = (global float *) ((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { + global half * src = (global half *)((global char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); + dst_data[i00] = src[0]; + } +} + +kernel void kernel_cpy_f32_f16( + global float * src0, + ulong offset0, + global half * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne0, + int ne1, + int ne2, + int ne3, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3 +) { + src0 = (global float*)((global char*)src0 + offset0); + dst = (global half*)((global char*)dst + offsetd); + + int i03 = get_group_id(2); + int i02 = get_group_id(1); + int i01 = get_group_id(0); + + int n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + + int i3 = n / (ne2*ne1*ne0); + int i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); + int i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; + int i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0); + + global half * dst_data = (global half *) ((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + for (int i00 = get_local_id(0); i00 < ne00; i00 += 
get_local_size(0)) { + global const float * src = (global float *)((global char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); + + dst_data[i00] = src[0]; + } +} + +kernel void kernel_cpy_f32_f32( + global float * src0, + ulong offset0, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne0, + int ne1, + int ne2, + int ne3, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3 +) { + src0 = (global float*)((global char*)src0 + offset0); + dst = (global float*)((global char*)dst + offsetd); + + int i03 = get_group_id(2); + int i02 = get_group_id(1); + int i01 = get_group_id(0); + + int n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + + int i3 = n / (ne2*ne1*ne0); + int i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); + int i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; + int i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0); + + global float * dst_data = (global float *) ((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { + global const float * src = (global float *)((global char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); + + dst_data[i00] = src[0]; + } +} diff --git a/ggml/src/ggml-opencl/kernels/ggml-opencl_cvt.cl b/ggml/src/ggml-opencl/kernels/cvt.cl similarity index 69% rename from ggml/src/ggml-opencl/kernels/ggml-opencl_cvt.cl rename to ggml/src/ggml-opencl/kernels/cvt.cl index e2024332f8..fe7975e3db 100644 --- a/ggml/src/ggml-opencl/kernels/ggml-opencl_cvt.cl +++ b/ggml/src/ggml-opencl/kernels/cvt.cl @@ -1,39 +1,20 @@ //------------------------------------------------------------------------------ -// This file is contains additional kernels for data conversion. +// This file contains kernels for data conversion. // These kernels are used when loading the model, so its performance is less // important. //------------------------------------------------------------------------------ -#ifdef cl_khr_fp16 #pragma OPENCL EXTENSION cl_khr_fp16 : enable -#elif defined(cl_amd_fp16) -#pragma OPENCL EXTENSION cl_amd_fp16 : enable -#else -#error "Half precision floating point not supportedby OpenCL implementation on your device." -#endif - -#ifdef cl_khr_subgroups -#pragma OPENCL EXTENSION cl_khr_subgroups : enable -#elif defined(cl_intel_subgroups) -#pragma OPENCL EXTENSION cl_intel_subgroups : enable -#else -#error "Subgroup not supported on your device." -#endif #ifdef cl_intel_required_subgroup_size -// Always use subgroup size of 32 on Intel. #pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable #define INTEL_GPU 1 #define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) #define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) #elif defined(cl_qcom_reqd_sub_group_size) -// Always use subgroups size of 64 on Adreno. #pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable #define ADRENO_GPU 1 #define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) #define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) -#else -// TODO: do not know how to choose subgroup size on other GPUs. -#error "Selecting subgroup size is not supported on your device."
#endif #define QK4_0 32 @@ -66,13 +47,44 @@ struct block_q4_0 }; //------------------------------------------------------------------------------ -// mul_vec_q_n_f32_flat_noshuffle -// -// This variation uses flat arrays (struct of arrays, SOA) representation for -// quant tensors. It also uses non shuffled bit order for weights. -// -// The shuffled version is kept in the original file because moving it here -// seems to result in worse performance for adreno. +// kernel_convert_block_q4_0 +// Convert the block_q4_0 format to 2 separate arrays (AOS -> SOA). +// This kernel does not deshuffle the bits. +//------------------------------------------------------------------------------ +kernel void kernel_convert_block_q4_0( + global struct block_q4_0 * src0, + global uchar * dst_q, + global half * dst_d +) { + global struct block_q4_0 * b = (global struct block_q4_0 *) src0 + get_global_id(0); + global uchar * q = (global uchar *) dst_q + QK4_0/2*get_global_id(0); + global half * d = (global half *) dst_d + get_global_id(0); + + *d = b->d; + + for (int i = 0; i < QK4_0/2; ++i) { + q[i] = b->qs[i]; + } +} + +kernel void kernel_restore_block_q4_0( + global uchar * src_q, + global half * src_d, + global struct block_q4_0 * dst +) { + global struct block_q4_0 * b = (global struct block_q4_0 *) dst + get_global_id(0); + global uchar * q = (global uchar *) src_q + QK4_0/2*get_global_id(0); + global half * d = (global half *) src_d + get_global_id(0); + + b->d = *d; + for (int i = 0; i < QK4_0/2; ++i) { + b->qs[i] = q[i]; + } +} + +//------------------------------------------------------------------------------ +// kernel_convert_block_q4_0_noshuffle +// Flatten q4_0 weights and unshuffle the bits //------------------------------------------------------------------------------ kernel void kernel_convert_block_q4_0_noshuffle( diff --git a/ggml/src/ggml-opencl/kernels/diag_mask_inf.cl b/ggml/src/ggml-opencl/kernels/diag_mask_inf.cl new file mode 100644 index 0000000000..36eff0439f --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/diag_mask_inf.cl @@ -0,0 +1,58 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +//------------------------------------------------------------------------------ +// diag_mask_inf kernels +//------------------------------------------------------------------------------ +kernel void kernel_diag_mask_inf( + global float * src0, + ulong offset0, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int n_past +) { + src0 = (global float*)((global char*)src0 + offset0); + dst = (global float*)((global char*)dst + offsetd); + + int i02 = get_global_id(2); + int i01 = get_global_id(1); + int i00 = get_global_id(0); + + if (i00 > n_past + i01) { + dst[i02*ne01*ne00 + i01*ne00 + i00] = -INFINITY; + } else { + dst[i02*ne01*ne00 + i01*ne00 + i00] = src0[i02*ne01*ne00 + i01*ne00 + i00]; + } +} + +kernel void kernel_diag_mask_inf_8( + global float4 * src0, + ulong offset0, + global float4 * dst, + ulong offsetd, + int ne00, + int ne01, + int n_past +) { + src0 = (global float4*)((global char*)src0 + offset0); + dst = (global float4*)((global char*)dst + offsetd); + + int i = 2*get_global_id(0); + + dst[i+0] = src0[i+0]; + dst[i+1] = src0[i+1]; + int i4 = 4*i; + int i02 = i4/(ne00*ne01); i4 -= i02*ne00*ne01; + int i01 = i4/(ne00); i4 -= i01*ne00; + int i00 = i4; + for (int k = 3; k >= 0; --k) { + if (i00 + 4 + k <= n_past + i01) { + break; + } + (&dst[i+1])[k] = -INFINITY; + if (i00 + k > n_past + i01) { + (&dst[i])[k] = -INFINITY; + } + } +} diff --git 
a/ggml/src/ggml-opencl/kernels/gelu.cl b/ggml/src/ggml-opencl/kernels/gelu.cl new file mode 100644 index 0000000000..71c310cc9f --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gelu.cl @@ -0,0 +1,62 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +//------------------------------------------------------------------------------ +// gelu +//------------------------------------------------------------------------------ +#define GELU_COEF_A 0.044715f +#define GELU_QUICK_COEF -1.702f +#define SQRT_2_OVER_PI 0.79788456080286535587989211986876f + +kernel void kernel_gelu( + global float * src0, + ulong offset0, + global float * dst, + ulong offsetd +) { + src0 = (global float*)((global char*)src0 + offset0); + dst = (global float*)((global char*)dst + offsetd); + + float x = src0[get_global_id(0)]; + + dst[get_global_id(0)] = 0.5f*x*(1.0f + tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x))); +} + +kernel void kernel_gelu_4( + global float4 * src0, + ulong offset0, + global float4 * dst, + ulong offsetd +) { + src0 = (global float4*)((global char*)src0 + offset0); + dst = (global float4*)((global char*)dst + offsetd); + + float4 x = src0[get_global_id(0)]; + + dst[get_global_id(0)] = 0.5f*x*(1.0f + tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x))); +} + +kernel void kernel_gelu_quick( + global float * src0, + ulong offset0, + global float * dst, + ulong offsetd +) { + src0 = (global float*)((global char*)src0 + offset0); + dst = (global float*)((global char*)dst + offsetd); + + float x = src0[get_global_id(0)]; + dst[get_global_id(0)] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x))); +} + +kernel void kernel_gelu_quick_4( + global float4 * src0, + ulong offset0, + global float4 * dst, + ulong offsetd +) { + src0 = (global float4*)((global char*)src0 + offset0); + dst = (global float4*)((global char*)dst + offsetd); + + float4 x = src0[get_global_id(0)]; + dst[get_global_id(0)] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x))); +} diff --git a/ggml/src/ggml-opencl/kernels/ggml-opencl_gemv_noshuffle.cl b/ggml/src/ggml-opencl/kernels/gemv_noshuffle.cl similarity index 100% rename from ggml/src/ggml-opencl/kernels/ggml-opencl_gemv_noshuffle.cl rename to ggml/src/ggml-opencl/kernels/gemv_noshuffle.cl diff --git a/ggml/src/ggml-opencl/kernels/ggml-opencl_gemv_noshuffle_general.cl b/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general.cl similarity index 100% rename from ggml/src/ggml-opencl/kernels/ggml-opencl_gemv_noshuffle_general.cl rename to ggml/src/ggml-opencl/kernels/gemv_noshuffle_general.cl diff --git a/ggml/src/ggml-opencl/kernels/get_rows.cl b/ggml/src/ggml-opencl/kernels/get_rows.cl new file mode 100644 index 0000000000..b3fea2923d --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/get_rows.cl @@ -0,0 +1,163 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +typedef char int8_t; +typedef uchar uint8_t; +typedef short int16_t; +typedef ushort uint16_t; +typedef int int32_t; +typedef uint uint32_t; + +#define QK4_0 32 + +//------------------------------------------------------------------------------ +// block_q4_0 +//------------------------------------------------------------------------------ +struct block_q4_0 +{ + half d; + uint8_t qs[QK4_0 / 2]; +}; + + +//------------------------------------------------------------------------------ +// dequantize_q4_0_f32, dequantize_q4_0_f16 +//------------------------------------------------------------------------------ +void dequantize_q4_0_f32(global struct block_q4_0 * xb, short il, float16 * reg) { + global ushort * qs = ((global ushort *)xb + 1); + float d1 = 
il ? (xb->d / 16.h) : xb->d; + float d2 = d1 / 256.f; + float md = -8.h * xb->d; + ushort mask0 = il ? 0x00F0 : 0x000F; + ushort mask1 = mask0 << 8; + + reg->s0 = d1 * (qs[0] & mask0) + md; + reg->s1 = d2 * (qs[0] & mask1) + md; + + reg->s2 = d1 * (qs[1] & mask0) + md; + reg->s3 = d2 * (qs[1] & mask1) + md; + + reg->s4 = d1 * (qs[2] & mask0) + md; + reg->s5 = d2 * (qs[2] & mask1) + md; + + reg->s6 = d1 * (qs[3] & mask0) + md; + reg->s7 = d2 * (qs[3] & mask1) + md; + + reg->s8 = d1 * (qs[4] & mask0) + md; + reg->s9 = d2 * (qs[4] & mask1) + md; + + reg->sa = d1 * (qs[5] & mask0) + md; + reg->sb = d2 * (qs[5] & mask1) + md; + + reg->sc = d1 * (qs[6] & mask0) + md; + reg->sd = d2 * (qs[6] & mask1) + md; + + reg->se = d1 * (qs[7] & mask0) + md; + reg->sf = d2 * (qs[7] & mask1) + md; +} + + +//------------------------------------------------------------------------------ +// get_rows +//------------------------------------------------------------------------------ +kernel void kernel_get_rows_f32( + global void * src0, + ulong offset0, + global int * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + ulong nb01, + ulong nb02, + int ne10, + ulong nb10, + ulong nb11, + ulong nb1, + ulong nb2 +) { + src0 = (global void*)((global char*)src0 + offset0); + src1 = (global int*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + int i10 = get_group_id(0); + int i11 = get_group_id(1); + + int r = ((global int *) ((global char *) src1 + i11*nb11 + i10*nb10))[0]; + + int i02 = i11; + + for (int ind = get_local_id(0); ind < ne00; ind += get_local_size(0)) { + ((global float *) ((global char *) dst + i11*nb2 + i10*nb1))[ind] = + ((global float *) ((global char *) src0 + r*nb01 + i02*nb02))[ind]; + } +} + +kernel void kernel_get_rows_f16( + global void * src0, + ulong offset0, + global int * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + ulong nb01, + ulong nb02, + int ne10, + ulong nb10, + ulong nb11, + ulong nb1, + ulong nb2 +) { + src0 = (global void*)((global char*)src0 + offset0); + src1 = (global int*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + int i10 = get_group_id(0); + int i11 = get_group_id(1); + + int r = ((global int32_t *) ((global char *) src1 + i11*nb11 + i10*nb10))[0]; + + int i02 = i11; + + for (int ind = get_local_id(0); ind < ne00; ind += get_local_size(0)) { + ((global float *) ((global char *) dst + i11*nb2 + i10*nb1))[ind] = + ((global half *) ((global char *) src0 + r*nb01 + i02*nb02))[ind]; + } +} + +kernel void kernel_get_rows_q4_0( + global void * src0, + ulong offset0, + global int * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + ulong nb01, + ulong nb02, + int ne10, + ulong nb10, + ulong nb11, + ulong nb1, + ulong nb2 +) { + src0 = (global void*)((global char*)src0 + offset0); + src1 = (global int*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + const int NL = 2; + + int i10 = get_group_id(0); + int i11 = get_group_id(1); + + int r = ((global int32_t *) ((global char *) src1 + i11*nb11 + i10*nb10))[0]; + + int i02 = i11; + + for (int ind = get_local_id(0); ind < ne00/16; ind += get_local_size(0)) { + float16 temp; + dequantize_q4_0_f32( + ((global struct block_q4_0 *) ((global char *) src0 + r*nb01 + i02*nb02)) + ind/NL, ind%NL, &temp); + *(((global float16 *) ((global char *) dst + i11*nb2 + i10*nb1)) + ind) = temp; + } +} diff --git 
a/ggml/src/ggml-opencl/kernels/ggml-opencl.cl b/ggml/src/ggml-opencl/kernels/ggml-opencl.cl deleted file mode 100644 index b887928879..0000000000 --- a/ggml/src/ggml-opencl/kernels/ggml-opencl.cl +++ /dev/null @@ -1,3231 +0,0 @@ -#ifdef cl_khr_fp16 -#pragma OPENCL EXTENSION cl_khr_fp16 : enable -#elif defined(cl_amd_fp16) -#pragma OPENCL EXTENSION cl_amd_fp16 : enable -#else -#error "Half precision floating point not supportedby OpenCL implementation on your device." -#endif - -#ifdef cl_khr_subgroups -#pragma OPENCL EXTENSION cl_khr_subgroups : enable -#elif defined(cl_intel_subgroups) -#pragma OPENCL EXTENSION cl_intel_subgroups : enable -#else -#error "Subgroup not supported on your device." -#endif - -#ifdef cl_intel_required_subgroup_size -// Always use subgroup size of 32 on Intel. -#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable -#define INTEL_GPU 1 -#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) -#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) -#elif defined(cl_qcom_reqd_sub_group_size) -// Always use subgroups size of 64 on Adreno. -#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable -#define ADRENO_GPU 1 -#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) -#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) -#else -// TODO: do not know how to choose subgroup size on other GPUs. -#error "Selecting subgroup size is not supported on your device." -#endif - -#define QK4_0 32 -#define QR4_0 2 -#define QK4_1 32 -#define QR4_1 2 -#define QK5_0 32 -#define QR5_0 2 -#define QK5_1 32 -#define QR5_1 2 -#define QK8_0 32 -#define QR8_0 1 -#define QK_K 256 -#define K_QUANTS_PER_ITERATION 2 - -typedef char int8_t; -typedef uchar uint8_t; -typedef short int16_t; -typedef ushort uint16_t; -typedef int int32_t; -typedef uint uint32_t; - -//------------------------------------------------------------------------------ -// block_q4_0 -//------------------------------------------------------------------------------ -struct block_q4_0 -{ - half d; - uint8_t qs[QK4_0 / 2]; -}; - -//------------------------------------------------------------------------------ -// block_q4_1 -//------------------------------------------------------------------------------ -struct block_q4_1 -{ - half d; - half m; - uint8_t qs[QK4_1 / 2]; -}; - -//------------------------------------------------------------------------------ -// block_q5_0 -//------------------------------------------------------------------------------ -struct block_q5_0 -{ - half d; - uint32_t qh; - uint8_t qs[QK5_0 / 2]; -}; - -//------------------------------------------------------------------------------ -// block_q5_1 -//------------------------------------------------------------------------------ -struct block_q5_1 -{ - half d; - half m; - uint32_t qh; - uint8_t qs[QK5_1 / 2]; -}; - -//------------------------------------------------------------------------------ -// block_q8_0 -//------------------------------------------------------------------------------ -struct block_q8_0 -{ - half d; - int8_t qs[QK8_0]; -}; - -//------------------------------------------------------------------------------ -// block_q2_K -//------------------------------------------------------------------------------ -struct block_q2_K -{ - uint8_t scales[16]; - uint8_t qs[64]; - half d; - half dmin; -}; - -//------------------------------------------------------------------------------ -// block_q3_K 
-//------------------------------------------------------------------------------ -struct block_q3_K -{ - uint8_t hmask[32]; - uint8_t qs[64]; - uint8_t scales[12]; - half d; -}; - -//------------------------------------------------------------------------------ -// block_q4_K -//------------------------------------------------------------------------------ -struct block_q4_K -{ - half d; - half dmin; - uint8_t scales[12]; - uint8_t qs[128]; -}; - -//------------------------------------------------------------------------------ -// block_q5_K -//------------------------------------------------------------------------------ -struct block_q5_K -{ - half d; - half dmin; - uint8_t scales[12]; - uint8_t qh[32]; - uint8_t qs[128]; -}; - -//------------------------------------------------------------------------------ -// block_q6_K -//------------------------------------------------------------------------------ -struct block_q6_K -{ - uint8_t ql[128]; - uint8_t qh[64]; - int8_t scales[16]; - half d; -}; - -//------------------------------------------------------------------------------ -// dequantize_q4_0_f32, dequantize_q4_0_f16 -//------------------------------------------------------------------------------ -void dequantize_q4_0_f32(global struct block_q4_0 * xb, short il, float16 * reg) { - global ushort * qs = ((global ushort *)xb + 1); - float d1 = il ? (xb->d / 16.h) : xb->d; - float d2 = d1 / 256.f; - float md = -8.h * xb->d; - ushort mask0 = il ? 0x00F0 : 0x000F; - ushort mask1 = mask0 << 8; - - reg->s0 = d1 * (qs[0] & mask0) + md; - reg->s1 = d2 * (qs[0] & mask1) + md; - - reg->s2 = d1 * (qs[1] & mask0) + md; - reg->s3 = d2 * (qs[1] & mask1) + md; - - reg->s4 = d1 * (qs[2] & mask0) + md; - reg->s5 = d2 * (qs[2] & mask1) + md; - - reg->s6 = d1 * (qs[3] & mask0) + md; - reg->s7 = d2 * (qs[3] & mask1) + md; - - reg->s8 = d1 * (qs[4] & mask0) + md; - reg->s9 = d2 * (qs[4] & mask1) + md; - - reg->sa = d1 * (qs[5] & mask0) + md; - reg->sb = d2 * (qs[5] & mask1) + md; - - reg->sc = d1 * (qs[6] & mask0) + md; - reg->sd = d2 * (qs[6] & mask1) + md; - - reg->se = d1 * (qs[7] & mask0) + md; - reg->sf = d2 * (qs[7] & mask1) + md; -} - -void dequantize_q4_0_f16(global struct block_q4_0 * xb, short il, half16 * reg) { - global ushort * qs = ((global ushort *)xb + 1); - half d1 = il ? (xb->d / 16.h) : xb->d; - half d2 = d1 / 256.h; - half md = -8.h * xb->d; - ushort mask0 = il ? 
0x00F0 : 0x000F; - ushort mask1 = mask0 << 8; - - reg->s0 = d1 * (qs[0] & mask0) + md; - reg->s1 = d2 * (qs[0] & mask1) + md; - - reg->s2 = d1 * (qs[1] & mask0) + md; - reg->s3 = d2 * (qs[1] & mask1) + md; - - reg->s4 = d1 * (qs[2] & mask0) + md; - reg->s5 = d2 * (qs[2] & mask1) + md; - - reg->s6 = d1 * (qs[3] & mask0) + md; - reg->s7 = d2 * (qs[3] & mask1) + md; - - reg->s8 = d1 * (qs[4] & mask0) + md; - reg->s9 = d2 * (qs[4] & mask1) + md; - - reg->sa = d1 * (qs[5] & mask0) + md; - reg->sb = d2 * (qs[5] & mask1) + md; - - reg->sc = d1 * (qs[6] & mask0) + md; - reg->sd = d2 * (qs[6] & mask1) + md; - - reg->se = d1 * (qs[7] & mask0) + md; - reg->sf = d2 * (qs[7] & mask1) + md; -} - -//------------------------------------------------------------------------------ -// add -//------------------------------------------------------------------------------ - -// general-purpose kernel for addition of two tensors -// pros: works for non-contiguous tensors, supports broadcast across dims 1, 2 and 3 -// cons: not very efficient -kernel void kernel_add( - global char * src0, - ulong offset0, - global char * src1, - ulong offset1, - global char * dst, - ulong offsetd, - int ne00, - int ne01, - int ne02, - int ne03, - ulong nb00, - ulong nb01, - ulong nb02, - ulong nb03, - int ne10, - int ne11, - int ne12, - int ne13, - ulong nb10, - ulong nb11, - ulong nb12, - ulong nb13, - int ne0, - int ne1, - int ne2, - int ne3, - ulong nb0, - ulong nb1, - ulong nb2, - ulong nb3 -) { - src0 = src0 + offset0; - src1 = src1 + offset1; - dst = dst + offsetd; - - int i03 = get_group_id(2); - int i02 = get_group_id(1); - int i01 = get_group_id(0); - - int i13 = i03 % ne13; - int i12 = i02 % ne12; - int i11 = i01 % ne11; - - global char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01; - global char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11; - global char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1; - - for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) { - const int i10 = i0 % ne10; - *((global float *)(dst_ptr + i0*nb0)) = *((global float *)(src0_ptr + i0*nb00)) + *((global float *)(src1_ptr + i10*nb10)); - } -} - -// assumption: src1 is a row -// broadcast src1 into src0 -kernel void kernel_add_row( - global float4 * src0, - ulong offset0, - global float4 * src1, - ulong offset1, - global float4 * dst, - ulong offsetd, - int ne -) { - src0 = (global float4*)((global char*)src0 + offset0); - src1 = (global float4*)((global char*)src1 + offset1); - dst = (global float4*)((global char*)dst + offsetd); - - // This performs better than using %. 
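// The line below computes gid % ne without a modulo: for unsigned integers,
// gid % ne == gid - (gid/ne)*ne, which the comment above reports is faster
// than % on these GPUs. A tiny host-side C++ check of the identity (the
// bounds are arbitrary example values, not from this patch):

#include <cassert>
#include <cstdint>

int main() {
    const uint32_t ne = 4096;                         // row length, example value
    for (uint32_t gid = 0; gid < 3 * ne; ++gid) {
        const uint32_t idx1 = gid - (gid / ne) * ne;  // same as gid % ne
        assert(idx1 == gid % ne);
    }
    return 0;
}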
- uint gid = get_global_id(0); - uint idx1 = gid - (gid/ne)*ne; // get_global_id(0) % ne - dst[gid] = src0[gid] + src1[idx1]; -} - -//------------------------------------------------------------------------------ -// mul -//------------------------------------------------------------------------------ -kernel void kernel_mul( - global char * src0, - ulong offset0, - global char * src1, - ulong offset1, - global char * dst, - ulong offsetd, - int ne00, - int ne01, - int ne02, - int ne03, - ulong nb00, - ulong nb01, - ulong nb02, - ulong nb03, - int ne10, - int ne11, - int ne12, - int ne13, - ulong nb10, - ulong nb11, - ulong nb12, - ulong nb13, - int ne0, - int ne1, - int ne2, - int ne3, - ulong nb0, - ulong nb1, - ulong nb2, - ulong nb3 -) { - src0 = src0 + offset0; - src1 = src1 + offset1; - dst = dst + offsetd; - - int i03 = get_group_id(2); - int i02 = get_group_id(1); - int i01 = get_group_id(0); - - int i13 = i03 % ne13; - int i12 = i02 % ne12; - int i11 = i01 % ne11; - - global char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01; - global char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11; - global char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1; - - for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) { - const int i10 = i0 % ne10; - *((global float *)(dst_ptr + i0*nb0)) = *((global float *)(src0_ptr + i0*nb00)) * *((global float *)(src1_ptr + i10*nb10)); - } -} - -// assumption: src1 is a row -// broadcast src1 into src0 -kernel void kernel_mul_row( - global float4 * src0, - ulong offset0, - global float4 * src1, - ulong offset1, - global float4 * dst, - ulong offsetd, - int ne -) { - src0 = (global float4*)((global char*)src0 + offset0); - src1 = (global float4*)((global char*)src1 + offset1); - dst = (global float4*)((global char*)dst + offsetd); - - // This performs better than using %. 
- uint gid = get_global_id(0); - uint idx1 = gid - (gid/ne)*ne; // get_global_id(0) % ne - dst[gid] = src0[gid] * src1[idx1]; -} - -//------------------------------------------------------------------------------ -// scale -//------------------------------------------------------------------------------ -kernel void kernel_scale( - global float4 * src0, - ulong offset0, - global float4 * dst, - ulong offsetd, - float scale -) { - src0 = (global float4*)((global char*)src0 + offset0); - dst = (global float4*)((global char*)dst + offsetd); - dst[get_global_id(0)] = src0[get_global_id(0)] * scale; -} - -//------------------------------------------------------------------------------ -// gelu -//------------------------------------------------------------------------------ -#define GELU_COEF_A 0.044715f -#define GELU_QUICK_COEF -1.702f -#define SQRT_2_OVER_PI 0.79788456080286535587989211986876f - -kernel void kernel_gelu( - global float * src0, - ulong offset0, - global float * dst, - ulong offsetd -) { - src0 = (global float*)((global char*)src0 + offset0); - dst = (global float*)((global char*)dst + offsetd); - - float x = src0[get_global_id(0)]; - - dst[get_global_id(0)] = 0.5f*x*(1.0f + tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x))); -} - -kernel void kernel_gelu_4( - global float4 * src0, - ulong offset0, - global float4 * dst, - ulong offsetd -) { - src0 = (global float4*)((global char*)src0 + offset0); - dst = (global float4*)((global char*)dst + offsetd); - - float4 x = src0[get_global_id(0)]; - - dst[get_global_id(0)] = 0.5f*x*(1.0f + tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x))); -} - -kernel void kernel_gelu_quick( - global float * src0, - ulong offset0, - global float * dst, - ulong offsetd -) { - src0 = (global float*)((global char*)src0 + offset0); - dst = (global float*)((global char*)dst + offsetd); - - float x = src0[get_global_id(0)]; - dst[get_global_id(0)] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x))); -} - -kernel void kernel_gelu_quick_4( - global float4 * src0, - ulong offset0, - global float4 * dst, - ulong offsetd -) { - src0 = (global float4*)((global char*)src0 + offset0); - dst = (global float4*)((global char*)dst + offsetd); - - float4 x = src0[get_global_id(0)]; - dst[get_global_id(0)] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x))); -} - -//------------------------------------------------------------------------------ -// silu -//------------------------------------------------------------------------------ -kernel void kernel_silu( - global float * src0, - ulong offset0, - global float * dst, - ulong offsetd -) { - src0 = (global float*)((global char*)src0 + offset0); - dst = (global float*)((global char*)dst + offsetd); - - float x = src0[get_global_id(0)]; - dst[get_global_id(0)] = x / (1.0f + exp(-x)); -} - -kernel void kernel_silu_4( - global float4 * src0, - ulong offset0, - global float4 * dst, - ulong offsetd -) { - src0 = (global float4*)((global char*)src0 + offset0); - dst = (global float4*)((global char*)dst + offsetd); - - float4 x = src0[get_global_id(0)]; - dst[get_global_id(0)] = x / (1.0f + exp(-x)); -} - -//------------------------------------------------------------------------------ -// relu -//------------------------------------------------------------------------------ -kernel void kernel_relu( - global float * src0, - ulong offset0, - global float * dst, - ulong offsetd -) { - src0 = (global float*)((global char*)src0 + offset0); - dst = (global float*)((global char*)dst + offsetd); - - dst[get_global_id(0)] = fmax(0.0f, src0[get_global_id(0)]); 
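// For reference, the kernel_gelu variants above use the standard tanh
// approximation gelu(x) ~= 0.5*x*(1 + tanh(sqrt(2/pi)*(x + 0.044715*x^3))),
// with x + 0.044715*x^3 folded as x*(1 + GELU_COEF_A*x*x). A scalar C++
// mirror of that expression, usable for spot-checking kernel output (the
// helper name is ours, not part of this patch):

#include <cmath>

static float gelu_ref(float x) {
    const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
    const float GELU_COEF_A    = 0.044715f;
    return 0.5f * x * (1.0f + std::tanh(SQRT_2_OVER_PI * x * (1.0f + GELU_COEF_A * x * x)));
}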
-} - -//------------------------------------------------------------------------------ -// clamp -//------------------------------------------------------------------------------ -kernel void kernel_clamp( - global float * src0, - ulong offset0, - global float * dst, - ulong offsetd, - float min, - float max -) { - src0 = (global float*)((global char*)src0 + offset0); - dst = (global float*)((global char*)dst + offsetd); - - dst[get_global_id(0)] = src0[get_global_id(0)] < min ? - min : - (src0[get_global_id(0)] > max ? max : src0[get_global_id(0)]); -} - -//------------------------------------------------------------------------------ -// norm -//------------------------------------------------------------------------------ -kernel void kernel_norm( - global void * src0, - ulong offset0, - global float * dst, - ulong offsetd, - int ne00, - int ne01, - int ne02, - int ne03, - ulong nb01, - ulong nb02, - ulong nb03, - float eps, - local float * sum -) { - src0 = (global void*)((global char*)src0 + offset0); - dst = (global void*)((global char*)dst + offsetd); - - int i03 = get_group_id(2); - int i02 = get_group_id(1); - int i01 = get_group_id(0); - - global float * x = (global float *) ((global char *) src0 + i03*nb03 + i02*nb02 + i01*nb01); - - // MEAN - // parallel sum - sum[get_local_id(0)] = 0.0f; - for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { - sum[get_local_id(0)] += x[i00]; - } - // reduce - barrier(CLK_LOCAL_MEM_FENCE); - for (uint i = get_local_size(0)/2; i > 0; i /= 2) { - if (get_local_id(0) < i) { - sum[get_local_id(0)] += sum[get_local_id(0) + i]; - } - barrier(CLK_LOCAL_MEM_FENCE); - } - float mean = sum[0] / ne00; - - // recenter and VARIANCE - barrier(CLK_LOCAL_MEM_FENCE); - global float * y = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; - sum[get_local_id(0)] = 0.0f; - for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { - y[i00] = x[i00] - mean; - sum[get_local_id(0)] += y[i00] * y[i00]; - } - - // reduce - barrier(CLK_LOCAL_MEM_FENCE); - for (uint i = get_local_size(0)/2; i > 0; i /= 2) { - if (get_local_id(0) < i) { - sum[get_local_id(0)] += sum[get_local_id(0) + i]; - } - barrier(CLK_LOCAL_MEM_FENCE); - } - float variance = sum[0] / ne00; - - float scale = 1.0f/sqrt(variance + eps); - for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { - y[i00] = y[i00] * scale; - } -} - -//------------------------------------------------------------------------------ -// rms_norm -//------------------------------------------------------------------------------ -// This kernel depends on subgroup size. 
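// kernel_rms_norm below writes one partial sum per subgroup into the local
// buffer passed as `sum`, which is why the host change earlier in this patch
// pins the subgroup size (64 on Adreno, 32 on Intel) instead of querying
// CL_KERNEL_MAX_SUB_GROUP_SIZE_FOR_NDRANGE. A hypothetical C++ sizing helper
// -- ours, not part of the patch -- assuming the work-group size is a
// multiple of the subgroup size:

#include <cstddef>

static size_t rms_norm_local_bytes(size_t local_work_size, bool is_adreno) {
    const size_t sgs         = is_adreno ? 64 : 32;   // fixed, matching the host code
    const size_t n_subgroups = (local_work_size + sgs - 1) / sgs;
    return n_subgroups * sizeof(float);               // one partial sum per subgroup
}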
-kernel void kernel_rms_norm( - global void * src0, - ulong offset0, - global float * dst, - ulong offsetd, - int ne00, - int ne01, - int ne02, - int ne03, - ulong nb01, - ulong nb02, - ulong nb03, - float eps, - local float * sum // Note, the size depends on number of subgroups -) { - src0 = (global void*)((global char*)src0 + offset0); - dst = (global float*)((global char*)dst + offsetd); - - int i03 = get_group_id(2); - int i02 = get_group_id(1); - int i01 = get_group_id(0); - - global float4 * x = (global float4 *) ((global char *) src0 + i03*nb03 + i02*nb02 + i01*nb01); - global float * x_scalar = (global float *) x; - float4 sumf = 0; - float all_sum = 0; - - // parallel sum - for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) { - sumf += x[i00] * x[i00]; - } - all_sum = sumf.s0 + sumf.s1 + sumf.s2 + sumf.s3; - all_sum = sub_group_reduce_add(all_sum); - if (get_sub_group_local_id() == 0) { - sum[get_sub_group_id()] = all_sum; - } - - barrier(CLK_LOCAL_MEM_FENCE); - // broadcast - for (uint i = get_local_size(0) / get_max_sub_group_size() / 2; i > 0; i /= 2) { - if (get_local_id(0) < i) { - sum[get_local_id(0)] += sum[get_local_id(0) + i]; - } - } - if (get_local_id(0) == 0) { - for (int i = 4 * (ne00 / 4); i < ne00; i++) { - sum[0] += x_scalar[i]; - } - sum[0] /= ne00; - } - - barrier(CLK_LOCAL_MEM_FENCE); - - const float mean = sum[0]; - const float scale = 1.0f/sqrt(mean + eps); - - global float4 * y = (global float4 *) (dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); - global float * y_scalar = (global float *) y; - for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) { - y[i00] = x[i00] * scale; - } - if (get_local_id(0) == 0) { - for (int i00 = 4 * (ne00 / 4); i00 < ne00; i00++) { - y_scalar[i00] = x_scalar[i00] * scale; - } - } -} - -//------------------------------------------------------------------------------ -// diag_mask_inf kernels -//------------------------------------------------------------------------------ -kernel void kernel_diag_mask_inf( - global float * src0, - ulong offset0, - global float * dst, - ulong offsetd, - int ne00, - int ne01, - int n_past -) { - src0 = (global float*)((global char*)src0 + offset0); - dst = (global float*)((global char*)dst + offsetd); - - int i02 = get_global_id(2); - int i01 = get_global_id(1); - int i00 = get_global_id(0); - - if (i00 > n_past + i01) { - dst[i02*ne01*ne00 + i01*ne00 + i00] = -INFINITY; - } else { - dst[i02*ne01*ne00 + i01*ne00 + i00] = src0[i02*ne01*ne00 + i01*ne00 + i00]; - } -} - -kernel void kernel_diag_mask_inf_8( - global float4 * src0, - ulong offset0, - global float4 * dst, - ulong offsetd, - int ne00, - int ne01, - int n_past -) { - src0 = (global float4*)((global char*)src0 + offset0); - dst = (global float4*)((global char*)dst + offsetd); - - int i = 2*get_global_id(0); - - dst[i+0] = src0[i+0]; - dst[i+1] = src0[i+1]; - int i4 = 4*i; - int i02 = i4/(ne00*ne01); i4 -= i02*ne00*ne01; - int i01 = i4/(ne00); i4 -= i01*ne00; - int i00 = i4; - for (int k = 3; k >= 0; --k) { - if (i00 + 4 + k <= n_past + i01) { - break; - } - (&dst[i+1])[k] = -INFINITY; - if (i00 + k > n_past + i01) { - (&dst[i])[k] = -INFINITY; - } - } -} - -//------------------------------------------------------------------------------ -// softmax -//------------------------------------------------------------------------------ -#ifdef ADRENO_GPU -REQD_SUBGROUP_SIZE_64 -#endif -kernel void kernel_soft_max( - global float * src0, - ulong offset0, - global float * src1, - ulong offset1, - global 
float * dst, - ulong offsetd, - int ne00, - int ne01, - int ne02, - float scale, - float max_bias, - float m0, - float m1, - int n_head_log2 -) { - src0 = (global float*)((global char*)src0 + offset0); - src1 = (global float*)((global char*)src1 + offset1); - dst = (global float*)((global char*)dst + offsetd); - - int i03 = get_group_id(2); - int i02 = get_group_id(1); - int i01 = get_group_id(0); - - global float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; - global float * pmask = src1 != src0 ? src1 + i01*ne00 : 0; - global float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; - - float slope = 1.0f; - - // ALiBi - if (max_bias > 0.0f) { - int h = i02; - - float base = h < n_head_log2 ? m0 : m1; - int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; - - slope = pow(base, exp); - } - - // parallel max - float lmax = -INFINITY; - for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { - lmax = fmax(lmax, psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)); - } - float max = sub_group_reduce_max(lmax); - - // parallel sum - float lsum = 0.0f; - for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { - float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)) - max); - lsum += exp_psrc0; - // Remember the result of exp here. exp is expensive, so we really do not - // wish to compute it twice. - pdst[i00] = exp_psrc0; - } - - const float sum = sub_group_reduce_add(lsum); - - for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { - pdst[i00] /= sum; - } -} - -#ifdef ADRENO_GPU -REQD_SUBGROUP_SIZE_64 -#endif -kernel void kernel_soft_max_4( - global float * src0, - ulong offset0, - global float * src1, - ulong offset1, - global float * dst, - ulong offsetd, - int ne00, - int ne01, - int ne02, - float scale, - float max_bias, - float m0, - float m1, - int n_head_log2 -) { - src0 = (global float*)((global char*)src0 + offset0); - src1 = (global float*)((global char*)src1 + offset1); - dst = (global float*)((global char*)dst + offsetd); - - int i03 = get_group_id(2); - int i02 = get_group_id(1); - int i01 = get_group_id(0); - - global float4 * psrc4 = (global float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); - global float4 * pmask = src1 != src0 ? (global float4 *)(src1 + i01*ne00) : 0; - global float4 * pdst4 = (global float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); - - float slope = 1.0f; - - // ALiBi - if (max_bias > 0.0f) { - int h = i02; - - float base = h < n_head_log2 ? m0 : m1; - int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; - - slope = pow(base, exp); - } - - // parallel max - float4 lmax4 = -INFINITY; - for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) { - lmax4 = fmax(lmax4, psrc4[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)); - } - float lmax = fmax(fmax(lmax4.s0, lmax4.s1), fmax(lmax4.s2, lmax4.s3)); - - const float max = sub_group_reduce_max(lmax); - - // parallel sum - float4 lsum4 = 0.0f; - for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) { - const float4 exp_psrc4 = exp((psrc4[i00]*scale + (pmask ? 
slope*pmask[i00] : 0.0f)) - max); - lsum4 += exp_psrc4; - pdst4[i00] = exp_psrc4; - } - float lsum = lsum4.s0 + lsum4.s1 + lsum4.s2 + lsum4.s3; - - const float sum = sub_group_reduce_add(lsum); - - for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) { - pdst4[i00] /= sum; - } -} - -#ifdef ADRENO_GPU -REQD_SUBGROUP_SIZE_64 -#endif -kernel void kernel_soft_max_f16( - global float * src0, - ulong offset0, - global half * src1, - ulong offset1, - global float * dst, - ulong offsetd, - int ne00, - int ne01, - int ne02, - float scale, - float max_bias, - float m0, - float m1, - int n_head_log2 -) { - src0 = (global float *)((global char *)src0 + offset0); - src1 = (global half *)((global char *)src1 + offset1); - dst = (global float *)((global char *)dst + offsetd); - - int i03 = get_group_id(2); - int i02 = get_group_id(1); - int i01 = get_group_id(0); - - global float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; - global half * pmask = (global char *)src1 != (global char *)src0 ? src1 + i01*ne00 : 0; - global float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; - - float slope = 1.0f; - - // ALiBi - if (max_bias > 0.0f) { - int h = i02; - - float base = h < n_head_log2 ? m0 : m1; - int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; - - slope = pow(base, exp); - } - - // parallel max - float lmax = -INFINITY; - for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { - lmax = fmax(lmax, psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)); - } - float max = sub_group_reduce_max(lmax); - - // parallel sum - float lsum = 0.0f; - for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { - float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)) - max); - lsum += exp_psrc0; - // Remember the result of exp here. exp is expensive, so we really do not - // wish to compute it twice. - pdst[i00] = exp_psrc0; - } - - const float sum = sub_group_reduce_add(lsum); - - for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { - pdst[i00] /= sum; - } -} - -#ifdef ADRENO_GPU -REQD_SUBGROUP_SIZE_64 -#endif -kernel void kernel_soft_max_4_f16( - global float * src0, - ulong offset0, - global half * src1, - ulong offset1, - global float * dst, - ulong offsetd, - int ne00, - int ne01, - int ne02, - float scale, - float max_bias, - float m0, - float m1, - int n_head_log2 -) { - src0 = (global float *)((global char *)src0 + offset0); - src1 = (global half *)((global char *)src1 + offset1); - dst = (global float *)((global char *)dst + offsetd); - - int i03 = get_group_id(2); - int i02 = get_group_id(1); - int i01 = get_group_id(0); - - global float4 * psrc4 = (global float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); - global half4 * pmask = (global char *)src1 != (global char *)src0 ? (global half4 *)(src1 + i01*ne00) : 0; - global float4 * pdst4 = (global float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); - - float slope = 1.0f; - - // ALiBi - if (max_bias > 0.0f) { - int h = i02; - - float base = h < n_head_log2 ? m0 : m1; - int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; - - slope = pow(base, exp); - } - - // parallel max - float4 lmax4 = -INFINITY; - for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) { - lmax4 = fmax(lmax4, psrc4[i00]*scale + slope*(pmask ? 
convert_float4(pmask[i00]) : 0.0f)); - } - float lmax = fmax(fmax(lmax4.s0, lmax4.s1), fmax(lmax4.s2, lmax4.s3)); - - const float max = sub_group_reduce_max(lmax); - - // parallel sum - float4 lsum4 = 0.0f; - for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) { - const float4 exp_psrc4 = exp((psrc4[i00]*scale + slope*(pmask ? convert_float4(pmask[i00]) : 0.0f)) - max); - lsum4 += exp_psrc4; - pdst4[i00] = exp_psrc4; - } - float lsum = lsum4.s0 + lsum4.s1 + lsum4.s2 + lsum4.s3; - - const float sum = sub_group_reduce_add(lsum); - - for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) { - pdst4[i00] /= sum; - } -} - -//------------------------------------------------------------------------------ -// kernel_rope -//------------------------------------------------------------------------------ -float rope_yarn_ramp(float low, float high, int i0) { - const float y = (i0 / 2 - low) / max(0.001f, high - low); - return 1.0f - min(1.0f, max(0.0f, y)); -} - -// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn -// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng. -float2 rope_yarn( - float theta_extrap, float freq_scale, float2 corr_dims, int i0, float ext_factor, float mscale -) { - // Get n-d rotational scaling corrected for extrapolation - float theta_interp = freq_scale * theta_extrap; - float theta = theta_interp; - if (ext_factor != 0.0f) { - float ramp_mix = rope_yarn_ramp(corr_dims.s0, corr_dims.s1, i0) * ext_factor; - theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix; - - // Get n-d magnitude scaling corrected for interpolation - mscale *= 1.0f + 0.1f * log(1.0f / freq_scale); - } - return (float2)(cos(theta) * mscale, sin(theta) * mscale); -} - -// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get -// `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))` -float rope_yarn_corr_factor(int n_dims, int n_ctx_orig, float n_rot, float base) { - return n_dims * log(n_ctx_orig / (n_rot * 2 * M_PI_F)) / (2 * log(base)); -} - -float2 rope_yarn_corr_dims( - int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow -) { - // start and end correction dims - return (float2)( - max(0.0f, floor(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_fast, freq_base))), - min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_slow, freq_base))) - ); -} - -kernel void kernel_rope_norm_f32( - global void * src0, - ulong offset0, - global int * src1, - ulong offset1, - global float * src2, - ulong offset2, - global float * dst, - ulong offsetd, - int ne00, - int ne01, - int ne02, - int ne03, - ulong nb00, - ulong nb01, - ulong nb02, - ulong nb03, - int ne0, - int ne1, - int ne2, - int ne3, - ulong nb0, - ulong nb1, - ulong nb2, - ulong nb3, - int n_past, - int n_dims, - int n_ctx_orig, - float freq_base, - float freq_scale, - float ext_factor, - float attn_factor, - float beta_fast, - float beta_slow -) { - src0 = (global void*)((global char*)src0 + offset0); - src1 = (global int*)((global char*)src1 + offset1); - src2 = (global float*)((global char*)src2 + offset2); - dst = (global float*)((global char*)dst + offsetd); - - int i3 = get_group_id(2); - int i2 = get_group_id(1); - int i1 = get_group_id(0); - - float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow); - - global int * pos = src1; - - float theta_base = (float) pos[i2]; - float inv_ndims = -1.f/n_dims; - - 
for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) { - if (i0 < n_dims) { - int ic = i0/2; - - float theta = theta_base * pow(freq_base, inv_ndims*i0); - - float freq_factor = src2 != src0 ? src2[ic] : 1.0f; - - float2 cos_sin_theta = rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor); - - global float * src = (global float *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); - global float * dst_data = (global float *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - float x0 = src[0]; - float x1 = src[1]; - - dst_data[0] = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1; - dst_data[1] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0; - } else { - global float * src = (global float *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); - global float * dst_data = (global float *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - dst_data[0] = src[0]; - dst_data[1] = src[1]; - } - } -} - -kernel void kernel_rope_norm_f16( - global void * src0, - ulong offset0, - global int * src1, - ulong offset1, - global float * src2, - ulong offset2, - global float * dst, - ulong offsetd, - int ne00, - int ne01, - int ne02, - int ne03, - ulong nb00, - ulong nb01, - ulong nb02, - ulong nb03, - int ne0, - int ne1, - int ne2, - int ne3, - ulong nb0, - ulong nb1, - ulong nb2, - ulong nb3, - int n_past, - int n_dims, - int n_ctx_orig, - float freq_base, - float freq_scale, - float ext_factor, - float attn_factor, - float beta_fast, - float beta_slow -) { - src0 = (global void*)((global char*)src0 + offset0); - src1 = (global int*)((global char*)src1 + offset1); - src2 = (global float*)((global char*)src2 + offset2); - dst = (global float*)((global char*)dst + offsetd); - - int i3 = get_group_id(2); - int i2 = get_group_id(1); - int i1 = get_group_id(0); - - float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow); - - global int * pos = src1; - - float theta_base = (float) pos[i2]; - float inv_ndims = -1.f/n_dims; - - for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) { - if (i0 < n_dims) { - int ic = i0/2; - - float theta = theta_base * pow(freq_base, inv_ndims*i0); - - float freq_factor = src2 != src0 ? 
src2[ic] : 1.0f; - - float2 cos_sin_theta = rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor); - - global half * src = (global half *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); - global half * dst_data = (global half *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - float x0 = src[0]; - float x1 = src[1]; - - dst_data[0] = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1; - dst_data[1] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0; - } else { - global half * src = (global half *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); - global half * dst_data = (global half *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - dst_data[0] = src[0]; - dst_data[1] = src[1]; - } - } -} - -kernel void kernel_rope_neox_f32( - global void * src0, - ulong offset0, - global int * src1, - ulong offset1, - global float * src2, - ulong offset2, - global float * dst, - ulong offsetd, - int ne00, - int ne01, - int ne02, - int ne03, - ulong nb00, - ulong nb01, - ulong nb02, - ulong nb03, - int ne0, - int ne1, - int ne2, - int ne3, - ulong nb0, - ulong nb1, - ulong nb2, - ulong nb3, - int n_past, - int n_dims, - int n_ctx_orig, - float freq_base, - float freq_scale, - float ext_factor, - float attn_factor, - float beta_fast, - float beta_slow -) { - src0 = (global void*)((global char*)src0 + offset0); - src1 = (global int*)((global char*)src1 + offset1); - src2 = (global float*)((global char*)src2 + offset2); - dst = (global float*)((global char*)dst + offsetd); - - int i3 = get_group_id(2); - int i2 = get_group_id(1); - int i1 = get_group_id(0); - - float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow); - - global int * pos = src1; - - float theta_base = (float) pos[i2]; - float inv_ndims = -1.f/n_dims; - - for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) { - if (i0 < n_dims) { - int ic = i0/2; - - const float theta = theta_base * pow(freq_base, inv_ndims*i0); - - const float freq_factor = src2 != src0 ? 
src2[ic] : 1.0f; - - float2 cos_sin_theta = rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor); - - global float * src = (global float *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00); - global float * dst_data = (global float *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0); - - const float x0 = src[0]; - const float x1 = src[n_dims/2]; - - dst_data[0] = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1; - dst_data[n_dims/2] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0; - } else { - global float * const src = (global float *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); - global float * dst_data = (global float *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - dst_data[0] = src[0]; - dst_data[1] = src[1]; - } - } -} - -kernel void kernel_rope_neox_f16( - global void * src0, - ulong offset0, - global int * src1, - ulong offset1, - global float * src2, - ulong offset2, - global float * dst, - ulong offsetd, - int ne00, - int ne01, - int ne02, - int ne03, - ulong nb00, - ulong nb01, - ulong nb02, - ulong nb03, - int ne0, - int ne1, - int ne2, - int ne3, - ulong nb0, - ulong nb1, - ulong nb2, - ulong nb3, - int n_past, - int n_dims, - int n_ctx_orig, - float freq_base, - float freq_scale, - float ext_factor, - float attn_factor, - float beta_fast, - float beta_slow -) { - src0 = (global void*)((global char*)src0 + offset0); - src1 = (global int*)((global char*)src1 + offset1); - src2 = (global float*)((global char*)src2 + offset2); - dst = (global float*)((global char*)dst + offsetd); - - int i3 = get_group_id(2); - int i2 = get_group_id(1); - int i1 = get_group_id(0); - - float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow); - - global int * pos = src1; - - float theta_base = (float) pos[i2]; - float inv_ndims = -1.f/n_dims; - - for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) { - if (i0 < n_dims) { - int ic = i0/2; - - const float theta = theta_base * pow(freq_base, inv_ndims*i0); - - const float freq_factor = src2 != src0 ? 
src2[ic] : 1.0f; - - float2 cos_sin_theta = rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor); - - global half * src = (global half *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00); - global half * dst_data = (global half *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0); - - const float x0 = src[0]; - const float x1 = src[n_dims/2]; - - dst_data[0] = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1; - dst_data[n_dims/2] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0; - } else { - global half * const src = (global half *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); - global half * dst_data = (global half *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - dst_data[0] = src[0]; - dst_data[1] = src[1]; - } - } -} - -kernel void kernel_rope_multi_f32( - global void * src0, - ulong offset0, - global int * src1, - ulong offset1, - global float * src2, - ulong offset2, - global float * dst, - ulong offsetd, - int ne00, - int ne01, - int ne02, - int ne03, - ulong nb00, - ulong nb01, - ulong nb02, - ulong nb03, - int ne0, - int ne1, - int ne2, - int ne3, - ulong nb0, - ulong nb1, - ulong nb2, - ulong nb3, - int n_past, - int n_dims, - int n_ctx_orig, - float freq_base, - float freq_scale, - float ext_factor, - float attn_factor, - float beta_fast, - float beta_slow, - int4 sections -) { - src0 = (global void*)((global char*)src0 + offset0); - src1 = (global int*)((global char*)src1 + offset1); - src2 = (global float*)((global char*)src2 + offset2); - dst = (global float*)((global char*)dst + offsetd); - - int i3 = get_group_id(2); - int i2 = get_group_id(1); - int i1 = get_group_id(0); - - float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow); - - global int * pos = src1; - - const int sect_dims = sections.s0 + sections.s1 + sections.s2 + sections.s3; - const int sec_w = sections.s1 + sections.s0; - - float inv_ndims = -1.f/n_dims; - - for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) { - if (i0 < n_dims) { - int ic = i0/2; - - const int sector = (i0 / 2) % sect_dims; - float theta_base = 0.0f; - - if (sector < sections.s0) { - theta_base = pos[i2]; - } - else if (sector >= sections.s0 && sector < sec_w) { - theta_base = pos[i2 + ne2 * 1]; - } - else if (sector >= sec_w && sector < sec_w + sections.s2) { - theta_base = pos[i2 + ne2 * 2]; - } - else if (sector >= sec_w + sections.s2) { - theta_base = pos[i2 + ne2 * 3]; - } - - const float theta = theta_base * pow(freq_base, inv_ndims*i0); - - const float freq_factor = src2 != src0 ? 
src2[ic] : 1.0f; - - float2 cos_sin_theta = rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor); - - global float * src = (global float *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00); - global float * dst_data = (global float *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0); - - const float x0 = src[0]; - const float x1 = src[n_dims/2]; - - dst_data[0] = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1; - dst_data[n_dims/2] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0; - } else { - global float * const src = (global float *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); - global float * dst_data = (global float *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - dst_data[0] = src[0]; - dst_data[1] = src[1]; - } - } -} - -kernel void kernel_rope_multi_f16( - global void * src0, - ulong offset0, - global int * src1, - ulong offset1, - global float * src2, - ulong offset2, - global half * dst, - ulong offsetd, - int ne00, - int ne01, - int ne02, - int ne03, - ulong nb00, - ulong nb01, - ulong nb02, - ulong nb03, - int ne0, - int ne1, - int ne2, - int ne3, - ulong nb0, - ulong nb1, - ulong nb2, - ulong nb3, - int n_past, - int n_dims, - int n_ctx_orig, - float freq_base, - float freq_scale, - float ext_factor, - float attn_factor, - float beta_fast, - float beta_slow, - int4 sections -) { - src0 = (global void*)((global char*)src0 + offset0); - src1 = (global int*)((global char*)src1 + offset1); - src2 = (global float*)((global char*)src2 + offset2); - dst = (global float*)((global char*)dst + offsetd); - - int i3 = get_group_id(2); - int i2 = get_group_id(1); - int i1 = get_group_id(0); - - float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow); - - global int * pos = src1; - - const int sect_dims = sections.s0 + sections.s1 + sections.s2 + sections.s3; - const int sec_w = sections.s1 + sections.s0; - - float inv_ndims = -1.f/n_dims; - - for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) { - if (i0 < n_dims) { - int ic = i0/2; - - const int sector = (i0 / 2) % sect_dims; - float theta_base = 0.0f; - - if (sector < sections.s0) { - theta_base = pos[i2]; - } - else if (sector >= sections.s0 && sector < sec_w) { - theta_base = pos[i2 + ne2 * 1]; - } - else if (sector >= sec_w && sector < sec_w + sections.s2) { - theta_base = pos[i2 + ne2 * 2]; - } - else if (sector >= sec_w + sections.s2) { - theta_base = pos[i2 + ne2 * 3]; - } - - const float theta = theta_base * pow(freq_base, inv_ndims*i0); - - const float freq_factor = src2 != src0 ? 
src2[ic] : 1.0f; - - float2 cos_sin_theta = rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor); - - global half * src = (global half *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00); - global half * dst_data = (global half *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0); - - const float x0 = src[0]; - const float x1 = src[n_dims/2]; - - dst_data[0] = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1; - dst_data[n_dims/2] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0; - } else { - global half * const src = (global half *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); - global half * dst_data = (global half *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - dst_data[0] = src[0]; - dst_data[1] = src[1]; - } - } -} - -kernel void kernel_rope_vision_f32( - global void * src0, - ulong offset0, - global int * src1, - ulong offset1, - global float * src2, - ulong offset2, - global float * dst, - ulong offsetd, - int ne00, - int ne01, - int ne02, - int ne03, - ulong nb00, - ulong nb01, - ulong nb02, - ulong nb03, - int ne0, - int ne1, - int ne2, - int ne3, - ulong nb0, - ulong nb1, - ulong nb2, - ulong nb3, - int n_past, - int n_dims, - int n_ctx_orig, - float freq_base, - float freq_scale, - float ext_factor, - float attn_factor, - float beta_fast, - float beta_slow, - int4 sections -) { - src0 = (global void*)((global char*)src0 + offset0); - src1 = (global int*)((global char*)src1 + offset1); - src2 = (global float*)((global char*)src2 + offset2); - dst = (global float*)((global char*)dst + offsetd); - - int i3 = get_group_id(2); - int i2 = get_group_id(1); - int i1 = get_group_id(0); - - float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow); - - global int * pos = src1; - - const int sect_dims = sections.s0 + sections.s1; - const int sec_w = sections.s1 + sections.s0; - - float inv_ndims = -1.f/n_dims; - - for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) { - int ic = i0/2; - - const int sector = (i0/2) % sect_dims; - float theta_base = 0.0f; - - if (sector < sections.s0) { - const int p = sector; - theta_base = pos[i2] * pow(freq_base, inv_ndims*2.0f*p); - } else if (sector >= sections.s0 && sector < sec_w) { - const int p = sector - sections.s0; - theta_base = pos[i2 + ne2] * pow(freq_base, inv_ndims*2.0f*p); - } - - const float freq_factor = src2 != src0 ? 
src2[ic] : 1.0f; - - float2 cos_sin_theta = rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor); - - global float * src = (global float *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00); - global float * dst_data = (global float *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0); - - const float x0 = src[0]; - const float x1 = src[n_dims]; - - dst_data[0] = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1; - dst_data[n_dims] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0; - } -} - -kernel void kernel_rope_vision_f16( - global void * src0, - ulong offset0, - global int * src1, - ulong offset1, - global float * src2, - ulong offset2, - global half * dst, - ulong offsetd, - int ne00, - int ne01, - int ne02, - int ne03, - ulong nb00, - ulong nb01, - ulong nb02, - ulong nb03, - int ne0, - int ne1, - int ne2, - int ne3, - ulong nb0, - ulong nb1, - ulong nb2, - ulong nb3, - int n_past, - int n_dims, - int n_ctx_orig, - float freq_base, - float freq_scale, - float ext_factor, - float attn_factor, - float beta_fast, - float beta_slow, - int4 sections -) { - src0 = (global void*)((global char*)src0 + offset0); - src1 = (global int*)((global char*)src1 + offset1); - src2 = (global float*)((global char*)src2 + offset2); - dst = (global float*)((global char*)dst + offsetd); - - int i3 = get_group_id(2); - int i2 = get_group_id(1); - int i1 = get_group_id(0); - - float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow); - - global int * pos = src1; - - const int sect_dims = sections.s0 + sections.s1; - const int sec_w = sections.s1 + sections.s0; - - float inv_ndims = -1.f/n_dims; - - for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) { - int ic = i0/2; - - const int sector = (i0/2) % sect_dims; - float theta_base = 0.0f; - - if (sector < sections.s0) { - const int p = sector; - theta_base = pos[i2] * pow(freq_base, inv_ndims*2.0f*p); - } else if (sector >= sections.s0 && sector < sec_w) { - const int p = sector - sections.s0; - theta_base = pos[i2 + ne2] * pow(freq_base, inv_ndims*2.0f*p); - } - - const float freq_factor = src2 != src0 ? 
src2[ic] : 1.0f; - - float2 cos_sin_theta = rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor); - - global half * src = (global half *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00); - global half * dst_data = (global half *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0); - - const float x0 = src[0]; - const float x1 = src[n_dims]; - - dst_data[0] = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1; - dst_data[n_dims] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0; - } -} - -//------------------------------------------------------------------------------ -// cpy -//------------------------------------------------------------------------------ - -kernel void kernel_cpy_f16_f16( - global half * src0, - ulong offset0, - global half * dst, - ulong offsetd, - int ne00, - int ne01, - int ne02, - int ne03, - ulong nb00, - ulong nb01, - ulong nb02, - ulong nb03, - int ne0, - int ne1, - int ne2, - int ne3, - ulong nb0, - ulong nb1, - ulong nb2, - ulong nb3 -) { - src0 = (global half*)((global char*)src0 + offset0); - dst = (global half*)((global char*)dst + offsetd); - - int i03 = get_group_id(2); - int i02 = get_group_id(1); - int i01 = get_group_id(0); - - int n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; - - int i3 = n / (ne2*ne1*ne0); - int i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); - int i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; - int i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0); - - global half * dst_data = (global half *) ((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { - global const half * src = (global half *)((global char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); - dst_data[i00] = src[0]; - } -} - -kernel void kernel_cpy_f16_f32( - global half * src0, - ulong offset0, - global float * dst, - ulong offsetd, - int ne00, - int ne01, - int ne02, - int ne03, - ulong nb00, - ulong nb01, - ulong nb02, - ulong nb03, - int ne0, - int ne1, - int ne2, - int ne3, - ulong nb0, - ulong nb1, - ulong nb2, - ulong nb3 -) { - - src0 = (global half*)((global char*)src0 + offset0); - dst = (global float*)((global char*)dst + offsetd); - - int i03 = get_group_id(2); - int i02 = get_group_id(1); - int i01 = get_group_id(0); - - int n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; - - int i3 = n / (ne2*ne1*ne0); - int i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); - int i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; - int i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0); - - global float * dst_data = (global float *) ((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { - global half * src = (global half *)((global char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); - dst_data[i00] = src[0]; - } -} - -kernel void kernel_cpy_f32_f16( - global float * src0, - ulong offset0, - global half * dst, - ulong offsetd, - int ne00, - int ne01, - int ne02, - int ne03, - ulong nb00, - ulong nb01, - ulong nb02, - ulong nb03, - int ne0, - int ne1, - int ne2, - int ne3, - ulong nb0, - ulong nb1, - ulong nb2, - ulong nb3 -) { - src0 = (global float*)((global char*)src0 + offset0); - dst = (global half*)((global char*)dst + offsetd); - - int i03 = get_group_id(2); - int i02 = get_group_id(1); - int i01 = get_group_id(0); - - int n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; - - int i3 = n / (ne2*ne1*ne0); - int i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); - int i1 = (n - i3*ne2*ne1*ne0 - 
i2*ne1*ne0) / ne0; - int i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0); - - global half * dst_data = (global half *) ((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { - global const float * src = (global float *)((global char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); - - dst_data[i00] = src[0]; - } -} - -kernel void kernel_cpy_f32_f32( - global float * src0, - ulong offset0, - global float * dst, - ulong offsetd, - int ne00, - int ne01, - int ne02, - int ne03, - ulong nb00, - ulong nb01, - ulong nb02, - ulong nb03, - int ne0, - int ne1, - int ne2, - int ne3, - ulong nb0, - ulong nb1, - ulong nb2, - ulong nb3 -) { - src0 = (global float*)((global char*)src0 + offset0); - dst = (global float*)((global char*)dst + offsetd); - - int i03 = get_group_id(2); - int i02 = get_group_id(1); - int i01 = get_group_id(0); - - int n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; - - int i3 = n / (ne2*ne1*ne0); - int i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); - int i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; - int i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0); - - global float * dst_data = (global float *) ((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - - for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { - global const float * src = (global float *)((global char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); - - dst_data[i00] = src[0]; - } -} - -//------------------------------------------------------------------------------ -// get_rows -//------------------------------------------------------------------------------ -kernel void kernel_get_rows_f32( - global void * src0, - ulong offset0, - global int * src1, - ulong offset1, - global float * dst, - ulong offsetd, - int ne00, - ulong nb01, - ulong nb02, - int ne10, - ulong nb10, - ulong nb11, - ulong nb1, - ulong nb2 -) { - src0 = (global void*)((global char*)src0 + offset0); - src1 = (global int*)((global char*)src1 + offset1); - dst = (global float*)((global char*)dst + offsetd); - - int i10 = get_group_id(0); - int i11 = get_group_id(1); - - int r = ((global int *) ((global char *) src1 + i11*nb11 + i10*nb10))[0]; - - int i02 = i11; - - for (int ind = get_local_id(0); ind < ne00; ind += get_local_size(0)) { - ((global float *) ((global char *) dst + i11*nb2 + i10*nb1))[ind] = - ((global float *) ((global char *) src0 + r*nb01 + i02*nb02))[ind]; - } -} - -kernel void kernel_get_rows_f16( - global void * src0, - ulong offset0, - global int * src1, - ulong offset1, - global float * dst, - ulong offsetd, - int ne00, - ulong nb01, - ulong nb02, - int ne10, - ulong nb10, - ulong nb11, - ulong nb1, - ulong nb2 -) { - src0 = (global void*)((global char*)src0 + offset0); - src1 = (global int*)((global char*)src1 + offset1); - dst = (global float*)((global char*)dst + offsetd); - - int i10 = get_group_id(0); - int i11 = get_group_id(1); - - int r = ((global int32_t *) ((global char *) src1 + i11*nb11 + i10*nb10))[0]; - - int i02 = i11; - - for (int ind = get_local_id(0); ind < ne00; ind += get_local_size(0)) { - ((global float *) ((global char *) dst + i11*nb2 + i10*nb1))[ind] = - ((global half *) ((global char *) src0 + r*nb01 + i02*nb02))[ind]; - } -} - -kernel void kernel_get_rows_q4_0( - global void * src0, - ulong offset0, - global int * src1, - ulong offset1, - global float * dst, - ulong offsetd, - int ne00, - ulong nb01, - ulong nb02, - int ne10, - ulong nb10, - ulong nb11, - ulong nb1, - ulong nb2 -) { - src0 = 
(global void*)((global char*)src0 + offset0); - src1 = (global int*)((global char*)src1 + offset1); - dst = (global float*)((global char*)dst + offsetd); - - const int NL = 2; - - int i10 = get_group_id(0); - int i11 = get_group_id(1); - - int r = ((global int32_t *) ((global char *) src1 + i11*nb11 + i10*nb10))[0]; - - int i02 = i11; - - for (int ind = get_local_id(0); ind < ne00/16; ind += get_local_size(0)) { - float16 temp; - dequantize_q4_0_f32( - ((global struct block_q4_0 *) ((global char *) src0 + r*nb01 + i02*nb02)) + ind/NL, ind%NL, &temp); - *(((global float16 *) ((global char *) dst + i11*nb2 + i10*nb1)) + ind) = temp; - } -} - -//------------------------------------------------------------------------------ -// mul_mat_f32_f32 -//------------------------------------------------------------------------------ -#define N_F32_F32 4 - -kernel void kernel_mul_mat_f32_f32( - global char * src0, - ulong offset0, - global char * src1, - ulong offset1, - global float * dst, - ulong offsetd, - int ne00, - int ne01, - int ne02, - ulong nb00, - ulong nb01, - ulong nb02, - ulong nb03, - int ne10, - int ne11, - int ne12, - ulong nb10, - ulong nb11, - ulong nb12, - ulong nb13, - int ne0, - int ne1, - int r2, - int r3 -) { - src0 = (global char*)((global char*)src0 + offset0); - src1 = (global char*)((global char*)src1 + offset1); - dst = (global float*)((global char*)dst + offsetd); - - int r0 = get_group_id(0); - int rb = get_group_id(1)*N_F32_F32; - int im = get_group_id(2); - - int i12 = im%ne12; - int i13 = im/ne12; - - ulong offset_src0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; - - global float * x = (global float *) (src0 + offset_src0); - - if (ne00 < 128) { - for (int row = 0; row < N_F32_F32; ++row) { - int r1 = rb + row; - if (r1 >= ne11) { - break; - } - - ulong offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; - - global float * y = (global float *) (src1 + offset_src1); - - float sumf = 0; - for (int i = get_sub_group_local_id(); i < ne00; i += get_max_sub_group_size()) { - sumf += (float) x[i] * (float) y[i]; - } - - float all_sum = sub_group_reduce_add(sumf); - if (get_sub_group_local_id() == 0) { - dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; - } - } - } else { - global float4 * x4 = (global float4 *)x; - for (int row = 0; row < N_F32_F32; ++row) { - int r1 = rb + row; - if (r1 >= ne11) { - break; - } - - ulong offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; - - global float * y = (global float *) (src1 + offset_src1); - global float4 * y4 = (global float4 *) y; - - float sumf = 0; - for (int i = get_sub_group_local_id(); i < ne00/4; i += get_max_sub_group_size()) { - sumf += (float) x4[i].s0 * y4[i].s0; - sumf += (float) x4[i].s1 * y4[i].s1; - sumf += (float) x4[i].s2 * y4[i].s2; - sumf += (float) x4[i].s3 * y4[i].s3; - } - - float all_sum = sub_group_reduce_add(sumf); - if (get_sub_group_local_id() == 0) { - for (int i = 4*(ne00/4); i < ne00; ++i) { - all_sum += (float) x[i] * y[i]; - } - dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; - } - } - } -} - -//------------------------------------------------------------------------------ -// mul_mat_f16_f16 -//------------------------------------------------------------------------------ -#define N_F16_F16 4 - -kernel void kernel_mul_mat_f16_f16( - global char * src0, - ulong offset0, - global char * src1, - ulong offset1, - global float * dst, - ulong offsetd, - int ne00, - int ne01, - int ne02, - ulong nb00, - ulong nb01, - ulong nb02, - ulong nb03, - int ne10, - int ne11, - int ne12, - ulong nb10, - ulong nb11, - ulong nb12, - ulong 
nb13, - int ne0, - int ne1, - int r2, - int r3) -{ - src0 = (global char*)((global char*)src0 + offset0); - src1 = (global char*)((global char*)src1 + offset1); - dst = (global float*)((global char*)dst + offsetd); - - int r0 = get_group_id(0); - int rb = get_group_id(1)*N_F16_F16; - int im = get_group_id(2); - - int i12 = im%ne12; - int i13 = im/ne12; - - ulong offset_src0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; - - global half * x = (global half *) (src0 + offset_src0); - - if (ne00 < 128) { - for (int row = 0; row < N_F16_F16; ++row) { - int r1 = rb + row; - if (r1 >= ne11) { - break; - } - - ulong offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; - - global half * y = (global half *) (src1 + offset_src1); - - float sumf = 0; - for (int i = get_sub_group_local_id(); i < ne00; i += get_max_sub_group_size()) { - sumf += (half) x[i] * (half) y[i]; - } - - float all_sum = sub_group_reduce_add(sumf); - if (get_sub_group_local_id() == 0) { - dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; - } - } - } else { - global half4 * x4 = (global half4 *)x; - for (int row = 0; row < N_F16_F16; ++row) { - int r1 = rb + row; - if (r1 >= ne11) { - break; - } - - ulong offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; - - global half * y = (global half *) (src1 + offset_src1); - global half4 * y4 = (global half4 *) y; - - float sumf = 0; - for (int i = get_sub_group_local_id(); i < ne00/4; i += get_max_sub_group_size()) { - sumf += (half) x4[i].s0 * y4[i].s0; - sumf += (half) x4[i].s1 * y4[i].s1; - sumf += (half) x4[i].s2 * y4[i].s2; - sumf += (half) x4[i].s3 * y4[i].s3; - } - - float all_sum = sub_group_reduce_add(sumf); - if (get_sub_group_local_id() == 0) { - for (int i = 4*(ne00/4); i < ne00; ++i) { - all_sum += (half) x[i] * y[i]; - } - dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; - } - } - } -} - -//------------------------------------------------------------------------------ -// mul_mat_f16_f32_1row -//------------------------------------------------------------------------------ -#ifdef ADRENO_GPU -REQD_SUBGROUP_SIZE_64 -#endif -kernel void kernel_mul_mat_f16_f32_1row( - global char * src0, - ulong offset0, - global char * src1, - ulong offset1, - global float * dst, - ulong offsetd, - int ne00, - int ne01, - int ne02, - ulong nb00, - ulong nb01, - ulong nb02, - ulong nb03, - int ne10, - int ne11, - int ne12, - ulong nb10, - ulong nb11, - ulong nb12, - ulong nb13, - int ne0, - int ne1, - int r2, - int r3 -) { - src0 = (global char*)((global char*)src0 + offset0); - src1 = (global char*)((global char*)src1 + offset1); - dst = (global float*)((global char*)dst + offsetd); - - int r0 = get_group_id(0); - int r1 = get_group_id(1); - int im = get_group_id(2); - - int i12 = im%ne12; - int i13 = im/ne12; - - ulong offset_src0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; - ulong offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; - - global half * x = (global half *) (src0 + offset_src0); - global float * y = (global float *) (src1 + offset_src1); - - float sumf = 0; - if (ne00 < 128) { - for (int i = get_sub_group_local_id(); i < ne00; i += get_max_sub_group_size()) { - sumf += (float) x[i] * (float) y[i]; - } - float all_sum = sub_group_reduce_add(sumf); - if (get_sub_group_local_id() == 0) { - dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; - } - } else { - global half4 * x4 = (global half4 *) x; - global float4 * y4 = (global float4 *) y; - for (int i = get_sub_group_local_id(); i < ne00/4; i += get_max_sub_group_size()) { - sumf += (float) x4[i].s0 * y4[i].s0; - sumf += (float) x4[i].s1 * y4[i].s1; - sumf += (float) 
x4[i].s2 * y4[i].s2; - sumf += (float) x4[i].s3 * y4[i].s3; - } - float all_sum = sub_group_reduce_add(sumf); - if (get_sub_group_local_id() == 0) { - for (int i = 4*(ne00/4); i < ne00; ++i) { - all_sum += (float) x[i] * y[i]; - } - dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; - } - } - -} - -//------------------------------------------------------------------------------ -// mul_mat_f16_f32 -//------------------------------------------------------------------------------ -#define N_F16_F32 4 - -#ifdef ADRENO_GPU -REQD_SUBGROUP_SIZE_64 -#endif -kernel void kernel_mul_mat_f16_f32( - global char * src0, - ulong offset0, - global char * src1, - ulong offset1, - global float * dst, - ulong offsetd, - int ne00, - int ne01, - int ne02, - ulong nb00, - ulong nb01, - ulong nb02, - ulong nb03, - int ne10, - int ne11, - int ne12, - ulong nb10, - ulong nb11, - ulong nb12, - ulong nb13, - int ne0, - int ne1, - int r2, - int r3 -) { - src0 = (global char*)((global char*)src0 + offset0); - src1 = (global char*)((global char*)src1 + offset1); - dst = (global float*)((global char*)dst + offsetd); - - int r0 = get_group_id(0); - int rb = get_group_id(1)*N_F16_F32; - int im = get_group_id(2); - - int i12 = im%ne12; - int i13 = im/ne12; - - ulong offset_src0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; - - global half * x = (global half *) (src0 + offset_src0); - - if (ne00 < 128) { - for (int row = 0; row < N_F16_F32; ++row) { - int r1 = rb + row; - if (r1 >= ne11) { - break; - } - - ulong offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; - - global float * y = (global float *) (src1 + offset_src1); - - float sumf = 0; - for (int i = get_sub_group_local_id(); i < ne00; i += get_max_sub_group_size()) { - sumf += convert_float(x[i]) * y[i]; - } - - float all_sum = sub_group_reduce_add(sumf); - if (get_sub_group_local_id() == 0) { - dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; - } - } - } else { - global half4 * x4 = (global half4 *)x; - for (int row = 0; row < N_F16_F32; ++row) { - int r1 = rb + row; - if (r1 >= ne11) { - break; - } - - ulong offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; - - global float * y = (global float *) (src1 + offset_src1); - global float4 * y4 = (global float4 *) y; - - float sumf = 0; - for (int i = get_sub_group_local_id(); i < ne00/4; i += get_max_sub_group_size()) { - sumf += convert_float(x4[i].s0) * y4[i].s0; - sumf += convert_float(x4[i].s1) * y4[i].s1; - sumf += convert_float(x4[i].s2) * y4[i].s2; - sumf += convert_float(x4[i].s3) * y4[i].s3; - } - - float all_sum = sub_group_reduce_add(sumf); - if (get_sub_group_local_id() == 0) { - for (int i = 4*(ne00/4); i < ne00; ++i) { - all_sum += (float) x[i] * y[i]; - } - dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; - } - } - } -} - -//------------------------------------------------------------------------------ -// mul_mat_f16_f32_l4 -//------------------------------------------------------------------------------ -// Assumes row size (ne00) is a multiple of 4 -#ifdef ADRENO_GPU -REQD_SUBGROUP_SIZE_64 -#endif -kernel void kernel_mul_mat_f16_f32_l4( - global char * src0, - ulong offset0, - global char * src1, - ulong offset1, - global float * dst, - ulong offsetd, - int ne00, - int ne01, - int ne02, - ulong nb00, - ulong nb01, - ulong nb02, - ulong nb03, - int ne10, - int ne11, - int ne12, - ulong nb10, - ulong nb11, - ulong nb12, - ulong nb13, - int ne0, - int ne1, - int r2, - int r3 -) { - src0 = (global char*)((global char*)src0 + offset0); - src1 = (global char*)((global char*)src1 + offset1); - dst = (global float*)((global char*)dst + 
offsetd); - - int nrows = ne11; - int r0 = get_group_id(0); - int im = get_group_id(2); - - int i12 = im%ne12; - int i13 = im/ne12; - - ulong offset_src0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; - - global half4 * x4 = (global half4 *) (src0 + offset_src0); - - for (int r1 = 0; r1 < nrows; ++r1) { - ulong offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; - - global float4 * y4 = (global float4 *) (src1 + offset_src1); - - float sumf = 0; - for (int i = get_sub_group_local_id(); i < ne00/4; i += get_max_sub_group_size()) { - sumf += convert_float(x4[i].s0) * y4[i].s0; - sumf += convert_float(x4[i].s1) * y4[i].s1; - sumf += convert_float(x4[i].s2) * y4[i].s2; - sumf += convert_float(x4[i].s3) * y4[i].s3; - } - - float all_sum = sub_group_reduce_add(sumf); - if (get_sub_group_local_id() == 0) { - dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; - } - } -} - -//------------------------------------------------------------------------------ -// mul_vec_q_n_f32 -//------------------------------------------------------------------------------ -// function to calculate the inner product between half a q4_0 block and 16 floats (yl), sumy is SUM(yl[i]) -// il indicates where the q4 quants begin (0 or QK4_0/4) -// we assume that the yl's have been multiplied with the appropriate scale factor -// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096) -inline float block_q_4_0_dot_y( - global struct block_q4_0 * qb_curr, - float sumy, - private float * yl, - int il -) { - float d = qb_curr->d; - float2 acc = 0.f; - global ushort * qs = ((global ushort *)qb_curr + 1 + il/2); - for (int i = 0; i < 8; i+=2) { - acc.s0 += yl[i + 0] * (qs[i / 2] & 0x000F) - + yl[i + 1] * (qs[i / 2] & 0x0F00); - acc.s1 += yl[i + 8] * (qs[i / 2] & 0x00F0) - + yl[i + 9] * (qs[i / 2] & 0xF000); - } - return d * (sumy * -8.f + acc.s0 + acc.s1); -} - -#ifdef INTEL_GPU -#define N_DST 4 // each SIMD group works on 4 rows -#define N_SIMDGROUP 1 // number of SIMD groups in a thread group -#define N_SIMDWIDTH 16 // assuming SIMD group size is 16 -#elif defined (ADRENO_GPU) -#define N_DST 4 -#define N_SIMDGROUP 1 -#define N_SIMDWIDTH 64 -#endif - -inline void mul_vec_q_n_f32( - global void * src0, - global float * src1, - global float * dst, - int ne00, - int ne01, - int ne02, - int ne10, - int ne12, - int ne0, - int ne1, - int r2, - int r3 -) { - - const ulong nb = ne00/QK4_0; - - int r0 = get_group_id(0); - int r1 = get_group_id(1); - int im = get_group_id(2); - - // (r0 * N_SIMDGROUP + get_sub_group_id()) is essentially the linear global - // id of a SIMD group in the grid. - int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST; - - int i12 = im%ne12; - int i13 = im/ne12; - - ulong offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); - - global struct block_q4_0 * x = (global struct block_q4_0 *) src0 + offset0; - global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1; - - float yl[16]; // src1 vector cache - float sumf[N_DST]={0.f}; - - int ix = get_sub_group_local_id()/2; - int il = 8*(get_sub_group_local_id()%2); - - global float * yb = y + ix * QK4_0 + il; - - // each thread in a SIMD group deals with half a block.
- for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) { - float sumy = 0; - for (int i = 0; i < 8; i += 2) { - sumy += yb[i] + yb[i+1]; - yl[i+0] = yb[i+ 0]; - yl[i+1] = yb[i+ 1]/256.f; - sumy += yb[i+16] + yb[i+17]; - yl[i+8] = yb[i+16]/16.f; - yl[i+9] = yb[i+17]/4096.f; - } - - for (int row = 0; row < N_DST; row++) { - sumf[row] += block_q_4_0_dot_y(x+ib+row*nb, sumy, yl, il); - } - - // One thread in a SIMD group (i.e., subgroup) handles a half block, - // hence the entire SIMD group handles SIMDWIDTH/2 blocks. - // y points to the activation matrix (of type float). Therefore for - // one thread, the # of blocks y should advance is SIMDWIDTH/2 (because - // SIMDWIDTH/2 blocks are processed by a SIMD group) - in terms of - // floats, it is QK4_0 * (SIMDWIDTH/2), where QK4_0 is the block size. - yb += QK4_0 * (N_SIMDWIDTH/2); - } - - // The above does not work for Adreno - it produces incorrect results for - // row = 1, 2, 3 and only row = 0 gives the correct result. - // If N_DST is changed, the below array must be initialized accordingly. - // This also seems to perform better on Intel. - float tot[N_DST] = { - sub_group_reduce_add(sumf[0]), sub_group_reduce_add(sumf[1]), - sub_group_reduce_add(sumf[2]), sub_group_reduce_add(sumf[3])}; - for (int row = 0; row < N_DST; ++row) { - if (get_sub_group_local_id() == 0 && first_row + row < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot[row]; - } - } -} - -#ifdef INTEL_GPU -REQD_SUBGROUP_SIZE_16 -#elif defined (ADRENO_GPU) -REQD_SUBGROUP_SIZE_64 -#endif -kernel void kernel_mul_mat_q4_0_f32( - global void * src0, - ulong offset0, - global float * src1, - ulong offset1, - global float * dst, - ulong offsetd, - int ne00, - int ne01, - int ne02, - int ne10, - int ne12, - int ne0, - int ne1, - int r2, - int r3 -) { - src0 = (global void*)((global char*)src0 + offset0); - src1 = (global float*)((global char*)src1 + offset1); - dst = (global float*)((global char*)dst + offsetd); - - mul_vec_q_n_f32(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3); -} - -// -// This variant unrolls the loops and uses vector types instead of pointers. -// It improves performance on Adreno but not so much on Intel.
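To make the masking trick in block_q_4_0_dot_y concrete: each ushort of a block_q4_0 packs four 4-bit quants at bit positions 0, 4, 8 and 12, and the kernels compensate for the position by pre-dividing the activations by 1, 16, 256 and 4096 instead of shifting the quants down. A minimal host-side C sketch (toy values, names local to the sketch and not part of the kernels) checking this identity against explicit dequantization w = d*(q - 8):

    #include <stdio.h>

    int main(void) {
        // one ushort packs 4 nibbles of the shuffled q4_0 layout:
        // q0 (bits 0-3), q16 (bits 4-7), q1 (bits 8-11), q17 (bits 12-15)
        unsigned short qs = (unsigned short)((12u << 12) | (3u << 8) | (9u << 4) | 5u);
        float d = 0.25f;                                      // block scale
        float y0 = 1.5f, y1 = -2.0f, y16 = 0.5f, y17 = 4.0f;  // activations

        // reference: dequantize each nibble as d*(q - 8), then dot with y
        float ref = d * ((5 - 8)*y0 + (3 - 8)*y1 + (9 - 8)*y16 + (12 - 8)*y17);

        // kernel-style: mask the nibbles in place, pre-scale y to cancel
        // the bit position, and apply the -8 offset once through sum(y)
        float sumy = y0 + y1 + y16 + y17;
        float acc  = (y0           ) * (qs & 0x000F)
                   + (y1  / 256.f  ) * (qs & 0x0F00)
                   + (y16 / 16.f   ) * (qs & 0x00F0)
                   + (y17 / 4096.f ) * (qs & 0xF000);
        float fast = d * (sumy * -8.f + acc);

        printf("ref=%f fast=%f\n", ref, fast); // prints the same value twice
        return 0;
    }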
-// -inline float block_q_4_0_dot_y_v( - global struct block_q4_0 * qb_curr, - float sumy, - float16 yl, - int il -) { - float d = qb_curr->d; - float acc = 0.f; - global ushort * qs = ((global ushort *)qb_curr + 1 + il/2); - - acc += yl.s0 * (qs[0] & 0x000F); - acc += yl.s1 * (qs[0] & 0x0F00); - acc += yl.s8 * (qs[0] & 0x00F0); - acc += yl.s9 * (qs[0] & 0xF000); - - acc += yl.s2 * (qs[1] & 0x000F); - acc += yl.s3 * (qs[1] & 0x0F00); - acc += yl.sa * (qs[1] & 0x00F0); - acc += yl.sb * (qs[1] & 0xF000); - - acc += yl.s4 * (qs[2] & 0x000F); - acc += yl.s5 * (qs[2] & 0x0F00); - acc += yl.sc * (qs[2] & 0x00F0); - acc += yl.sd * (qs[2] & 0xF000); - - acc += yl.s6 * (qs[3] & 0x000F); - acc += yl.s7 * (qs[3] & 0x0F00); - acc += yl.se * (qs[3] & 0x00F0); - acc += yl.sf * (qs[3] & 0xF000); - - return d * (sumy * -8.f + acc); -} - -#undef N_DST -#undef N_SIMDGROUP -#undef N_SIMDWIDTH - -#ifdef INTEL_GPU -#define N_DST 4 // each SIMD group works on 4 rows -#define N_SIMDGROUP 1 // number of SIMD groups in a thread group -#define N_SIMDWIDTH 16 // assuming SIMD group size is 16 -#elif defined (ADRENO_GPU) -#define N_DST 4 -#define N_SIMDGROUP 1 -#define N_SIMDWIDTH 64 -#endif - -inline void mul_vec_q_n_f32_v( - global void * src0, - global float * src1, - global float * dst, - int ne00, - int ne01, - int ne02, - int ne10, - int ne12, - int ne0, - int ne1, - int r2, - int r3 -) { - const ulong nb = ne00/QK4_0; - - int r0 = get_group_id(0); - int r1 = get_group_id(1); - int im = get_group_id(2); - - // (r0 * N_SIMDGROUP + get_sub_group_id()) is essentially the linear global - // id of a SIMD group in the grid. - int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST; - - int i12 = im%ne12; - int i13 = im/ne12; - - ulong offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); - - global struct block_q4_0 * x = (global struct block_q4_0 *) src0 + offset0; - global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1; - - float16 yl; // src1 vector cache - float4 sumf = (float4)(0.f, 0.f, 0.f, 0.f); - - int ix = get_sub_group_local_id()/2; - int il = 8*(get_sub_group_local_id()%2); - - global float * yb = y + ix * QK4_0 + il; - - // each thread in a SIMD group deals with half a block. - for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) { - float sumy = 0; - - sumy += yb[0]; - sumy += yb[1]; - sumy += yb[2]; - sumy += yb[3]; - sumy += yb[4]; - sumy += yb[5]; - sumy += yb[6]; - sumy += yb[7]; - - sumy += yb[16]; - sumy += yb[17]; - sumy += yb[18]; - sumy += yb[19]; - sumy += yb[20]; - sumy += yb[21]; - sumy += yb[22]; - sumy += yb[23]; - - - yl.s0 = yb[0]; - yl.s1 = yb[1]/256.f; - - yl.s2 = yb[2]; - yl.s3 = yb[3]/256.f; - - yl.s4 = yb[4]; - yl.s5 = yb[5]/256.f; - - yl.s6 = yb[6]; - yl.s7 = yb[7]/256.f; - - yl.s8 = yb[16]/16.f; - yl.s9 = yb[17]/4096.f; - - yl.sa = yb[18]/16.f; - yl.sb = yb[19]/4096.f; - - yl.sc = yb[20]/16.f; - yl.sd = yb[21]/4096.f; - - yl.se = yb[22]/16.f; - yl.sf = yb[23]/4096.f; - - sumf.s0 += block_q_4_0_dot_y_v(x+ib+0*nb, sumy, yl, il); - sumf.s1 += block_q_4_0_dot_y_v(x+ib+1*nb, sumy, yl, il); - sumf.s2 += block_q_4_0_dot_y_v(x+ib+2*nb, sumy, yl, il); - sumf.s3 += block_q_4_0_dot_y_v(x+ib+3*nb, sumy, yl, il); - - // One thread in a SIMD group (i.e., subgroup) handles a half block, - // hence the entire SIMD group handles SIMDWIDTH/2 blocks. - // y points to the activation matrix (of type float).
Therefore for - // one thread, the # of blocks y should advance is SIMDWIDTH/2 (because - // SIMDWIDTH/2 blocks are processed by a SIMD group) - in terms of - // floats, it is QK4_0 * (SIMDWIDTH/2), where QK4_0 is the block size. - yb += QK4_0 * (N_SIMDWIDTH/2); - } - - // The above does not work for Adreno - it produces incorrect results for - // row = 1, 2, 3 and only row = 0 gives the correct result. - // If N_DST is changed, the below array must be initialized accordingly. - // This also seems to perform better on Intel. - float4 tot = (float4)( - sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1), - sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3) - ); - - if (get_sub_group_local_id() == 0) { - if (first_row + 0 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0; - } - if (first_row + 1 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1; - } - if (first_row + 2 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2; - } - if (first_row + 3 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3; - } - } -} - -#ifdef INTEL_GPU -REQD_SUBGROUP_SIZE_16 -#elif defined (ADRENO_GPU) -REQD_SUBGROUP_SIZE_64 -#endif -kernel void kernel_mul_mat_q4_0_f32_v( - global void * src0, - ulong offset0, - global float * src1, - ulong offset1, - global float * dst, - ulong offsetd, - int ne00, - int ne01, - int ne02, - int ne10, - int ne12, - int ne0, - int ne1, - int r2, - int r3 -) { - src0 = (global void*)((global char*)src0 + offset0); - src1 = (global float*)((global char*)src1 + offset1); - dst = (global float*)((global char*)dst + offsetd); - - mul_vec_q_n_f32_v(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3); -} - -//------------------------------------------------------------------------------ -// kernel_convert_block_q4_0 -// Convert the block_q4_0 format to 2 separate arrays (AOS -> SOA). -// This kernel does not deshuffle the bits. -//------------------------------------------------------------------------------ -kernel void kernel_convert_block_q4_0( - global struct block_q4_0 * src0, - global uchar * dst_q, - global half * dst_d -) { - global struct block_q4_0 * b = (global struct block_q4_0 *) src0 + get_global_id(0); - global uchar * q = (global uchar *) dst_q + QK4_0/2*get_global_id(0); - global half * d = (global half *) dst_d + get_global_id(0); - - *d = b->d; - - for (int i = 0; i < QK4_0/2; ++i) { - q[i] = b->qs[i]; - } -} - -kernel void kernel_restore_block_q4_0( - global uchar * src_q, - global half * src_d, - global struct block_q4_0 * dst -) { - global struct block_q4_0 * b = (global struct block_q4_0 *) dst + get_global_id(0); - global uchar * q = (global uchar *) src_q + QK4_0/2*get_global_id(0); - global half * d = (global half *) src_d + get_global_id(0); - - b->d = *d; - for (int i = 0; i < QK4_0/2; ++i) { - b->qs[i] = q[i]; - } -} - -//------------------------------------------------------------------------------ -// mul_vec_q_n_f32_flat -// -// This variation uses flat arrays (struct of arrays, SOA) representation for -// quant tensors. -//------------------------------------------------------------------------------ - -// This function requires the original shuffled weights. -// As a reminder, the original weights are shuffled so that (q[0], q[16]) are -// packed together in a byte, so are (q[1], q[17]) and so on. 
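For reference, the AOS-to-SOA split performed by kernel_convert_block_q4_0 can be sketched on the host as follows. This is a minimal sketch assuming the block layout above; note that the nibbles stay shuffled exactly as stored, only the container changes, and kernel_restore_block_q4_0 is the inverse:

    #include <stdint.h>
    #include <string.h>

    #define QK4_0 32

    struct block_q4_0 {
        uint16_t d;             // half-precision scale, kept as raw bits here
        uint8_t  qs[QK4_0 / 2]; // byte i packs q[i] (low nibble) and q[i+16] (high)
    };

    // Split an array of blocks into one contiguous quant array and one
    // contiguous scale array, mirroring what the kernel does with one
    // work-item per block.
    void convert_q4_0_soa(const struct block_q4_0 *src, int nblocks,
                          uint8_t *dst_q, uint16_t *dst_d) {
        for (int ib = 0; ib < nblocks; ++ib) {
            dst_d[ib] = src[ib].d;
            memcpy(dst_q + (size_t)ib * (QK4_0 / 2), src[ib].qs, QK4_0 / 2);
        }
    }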
-inline float block_q_4_0_dot_y_flat( - global uchar * x, - global half * dh, - float sumy, - float16 yl, - int il -) { - float d = *dh; - global ushort * qs = ((global ushort *)x + il/2); - float acc = 0.f; - - acc += yl.s0 * (qs[0] & 0x000F); - acc += yl.s1 * (qs[0] & 0x0F00); - acc += yl.s8 * (qs[0] & 0x00F0); - acc += yl.s9 * (qs[0] & 0xF000); - - acc += yl.s2 * (qs[1] & 0x000F); - acc += yl.s3 * (qs[1] & 0x0F00); - acc += yl.sa * (qs[1] & 0x00F0); - acc += yl.sb * (qs[1] & 0xF000); - - acc += yl.s4 * (qs[2] & 0x000F); - acc += yl.s5 * (qs[2] & 0x0F00); - acc += yl.sc * (qs[2] & 0x00F0); - acc += yl.sd * (qs[2] & 0xF000); - - acc += yl.s6 * (qs[3] & 0x000F); - acc += yl.s7 * (qs[3] & 0x0F00); - acc += yl.se * (qs[3] & 0x00F0); - acc += yl.sf * (qs[3] & 0xF000); - - return d * (sumy * -8.f + acc); -} - -#undef N_DST -#undef N_SIMDGROUP -#undef N_SIMDWIDTH - -#ifdef INTEL_GPU -#define N_DST 4 // each SIMD group works on 4 rows -#define N_SIMDGROUP 1 // number of SIMD groups in a thread group -#define N_SIMDWIDTH 16 // assuming SIMD group size is 16 -#elif defined (ADRENO_GPU) -#define N_DST 4 -#define N_SIMDGROUP 1 -#define N_SIMDWIDTH 64 -#endif - -inline void mul_vec_q_n_f32_flat( - global uchar * src0_q, - global half * src0_d, - global float * src1, - global float * dst, - int ne00, - int ne01, - int ne02, - int ne10, - int ne12, - int ne0, - int ne1, - int r2, - int r3 -) { - const ulong nb = ne00/QK4_0; - - int r0 = get_group_id(0); - int r1 = get_group_id(1); - int im = get_group_id(2); - - // (r0 * N_SIMDGROUP + get_sub_group_id()) is the linear global id of - // a SIMD group in the grid. Each SIMD group produces N_DST values in the - // result, hence uses nb blocks, i.e., the offset becomes first_row*nb. - // Currently with llama2 7B, im is always 0. - // TODO: how to handle im/gqa*(nb*ne0)? - int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST; - - int i12 = im%ne12; - int i13 = im/ne12; - - // The number of scales is the same as the number of blocks. - ulong offset0_d = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); - // Each block contains QK4_0/2 uchars, hence offset for qs is as follows.
- ulong offset0_q = (first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02)) * QK4_0/2; - - global uchar * x = (global uchar *) src0_q + offset0_q; - global half * d = (global half *) src0_d + offset0_d; - global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1; - - float16 yl; - float4 sumf = (float4)(0.f, 0.f, 0.f, 0.f); - - int ix = get_sub_group_local_id()/2; - int il = 8*(get_sub_group_local_id()%2); - - global float * yb = y + ix*QK4_0 + il; - - for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) { - float sumy = 0.f; - - sumy += yb[0]; - sumy += yb[1]; - sumy += yb[2]; - sumy += yb[3]; - sumy += yb[4]; - sumy += yb[5]; - sumy += yb[6]; - sumy += yb[7]; - - sumy += yb[16]; - sumy += yb[17]; - sumy += yb[18]; - sumy += yb[19]; - sumy += yb[20]; - sumy += yb[21]; - sumy += yb[22]; - sumy += yb[23]; - - yl.s0 = yb[0]; - yl.s1 = yb[1]/256.f; - - yl.s2 = yb[2]; - yl.s3 = yb[3]/256.f; - - yl.s4 = yb[4]; - yl.s5 = yb[5]/256.f; - - yl.s6 = yb[6]; - yl.s7 = yb[7]/256.f; - - yl.s8 = yb[16]/16.f; - yl.s9 = yb[17]/4096.f; - - yl.sa = yb[18]/16.f; - yl.sb = yb[19]/4096.f; - - yl.sc = yb[20]/16.f; - yl.sd = yb[21]/4096.f; - - yl.se = yb[22]/16.f; - yl.sf = yb[23]/4096.f; - - sumf.s0 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 0*nb*QK4_0/2, d + ib + 0*nb, sumy, yl, il); - sumf.s1 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 1*nb*QK4_0/2, d + ib + 1*nb, sumy, yl, il); - sumf.s2 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 2*nb*QK4_0/2, d + ib + 2*nb, sumy, yl, il); - sumf.s3 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 3*nb*QK4_0/2, d + ib + 3*nb, sumy, yl, il); - - yb += QK4_0 * (N_SIMDWIDTH/2); - } - - float4 tot = (float4)( - sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1), - sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3) - ); - - if (get_sub_group_local_id() == 0) { - if (first_row + 0 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0; - } - if (first_row + 1 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1; - } - if (first_row + 2 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2; - } - if (first_row + 3 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3; - } - } -} - -#ifdef INTEL_GPU -REQD_SUBGROUP_SIZE_16 -#elif defined (ADRENO_GPU) -REQD_SUBGROUP_SIZE_64 -#endif -kernel void kernel_mul_mat_q4_0_f32_flat( - global uchar * src0_q, - global half * src0_d, - global float * src1, - ulong offset1, - global float * dst, - ulong offsetd, - int ne00, - int ne01, - int ne02, - int ne10, - int ne12, - int ne0, - int ne1, - int r2, - int r3 -) { - src1 = (global float*)((global char*)src1 + offset1); - dst = (global float*)((global char*)dst + offsetd); - - mul_vec_q_n_f32_flat(src0_q, src0_d, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3); -} - -// -// This variant outputs 8 values. 
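The flat addressing above reduces to two closed-form offsets: with nb = ne00/QK4_0 blocks per weight row, the scale of block ib in row `row` lives at row*nb + ib in src0_d, and its quants start at (row*nb + ib)*QK4_0/2 in src0_q. A small C sketch of the same arithmetic (helper names are illustrative, not part of the code; half scales are treated as raw uint16_t):

    #include <stddef.h>
    #include <stdint.h>

    #define QK4_0 32

    // Quants of block `ib` in weight row `row`: each block contributes
    // QK4_0/2 packed bytes, and rows are nb blocks long.
    static inline const uint8_t * q4_0_flat_quants(const uint8_t *src0_q,
                                                   int row, int ib, int nb) {
        return src0_q + ((size_t)row * nb + ib) * (QK4_0 / 2);
    }

    // Matching scale: one half-precision value per block, densely packed.
    static inline const uint16_t * q4_0_flat_scale(const uint16_t *src0_d,
                                                   int row, int ib, int nb) {
        return src0_d + (size_t)row * nb + ib;
    }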
-// -#undef N_DST -#undef N_SIMDGROUP -#undef N_SIMDWIDTH - -#ifdef INTEL_GPU -#define N_DST 8 // each SIMD group works on 8 rows -#define N_SIMDGROUP 1 // number of SIMD groups in a thread group -#define N_SIMDWIDTH 16 // assuming SIMD group size is 16 -#elif defined (ADRENO_GPU) -#define N_DST 8 -#define N_SIMDGROUP 1 -#define N_SIMDWIDTH 64 -#endif - -inline void mul_vec_q_n_f32_8x_flat( - global uchar * src0_q, - global half * src0_d, - global float * src1, - global float * dst, - int ne00, - int ne01, - int ne02, - int ne10, - int ne12, - int ne0, - int ne1, - int r2, - int r3 -) { - const ulong nb = ne00/QK4_0; - - int r0 = get_group_id(0); - int r1 = get_group_id(1); - int im = get_group_id(2); - - // (r0 * N_SIMDGROUP + get_sub_group_id()) is the linear global id of - // a SIMD group in the grid. Each SIMD group produces N_DST values in the - // result, hence uses nb blocks, i.e., the offset becomes first_row*nb. - // Currently with llama2 7B, im is always 0. - // TODO: how to handle im/gqa*(nb*ne0)? - int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST; - - int i12 = im%ne12; - int i13 = im/ne12; - - // The number of scales is the same as the number of blocks. - ulong offset0_d = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); - // Each block contains QK4_0/2 uchars, hence offset for qs is as follows. - ulong offset0_q = (first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02)) * QK4_0/2; - - global uchar * x = (global uchar *) src0_q + offset0_q; - global half * d = (global half *) src0_d + offset0_d; - global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1; - - float16 yl; - float8 sumf = 0.f; - - int ix = get_sub_group_local_id()/2; - int il = 8*(get_sub_group_local_id()%2); - - global float * yb = y + ix*QK4_0 + il; - - for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) { - float sumy = 0.f; - - sumy += yb[0]; - sumy += yb[1]; - sumy += yb[2]; - sumy += yb[3]; - sumy += yb[4]; - sumy += yb[5]; - sumy += yb[6]; - sumy += yb[7]; - - sumy += yb[16]; - sumy += yb[17]; - sumy += yb[18]; - sumy += yb[19]; - sumy += yb[20]; - sumy += yb[21]; - sumy += yb[22]; - sumy += yb[23]; - - yl.s0 = yb[0]; - yl.s1 = yb[1]/256.f; - - yl.s2 = yb[2]; - yl.s3 = yb[3]/256.f; - - yl.s4 = yb[4]; - yl.s5 = yb[5]/256.f; - - yl.s6 = yb[6]; - yl.s7 = yb[7]/256.f; - - yl.s8 = yb[16]/16.f; - yl.s9 = yb[17]/4096.f; - - yl.sa = yb[18]/16.f; - yl.sb = yb[19]/4096.f; - - yl.sc = yb[20]/16.f; - yl.sd = yb[21]/4096.f; - - yl.se = yb[22]/16.f; - yl.sf = yb[23]/4096.f; - - sumf.s0 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 0*nb*QK4_0/2, d + ib + 0*nb, sumy, yl, il); - sumf.s1 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 1*nb*QK4_0/2, d + ib + 1*nb, sumy, yl, il); - sumf.s2 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 2*nb*QK4_0/2, d + ib + 2*nb, sumy, yl, il); - sumf.s3 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 3*nb*QK4_0/2, d + ib + 3*nb, sumy, yl, il); - - sumf.s4 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 4*nb*QK4_0/2, d + ib + 4*nb, sumy, yl, il); - sumf.s5 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 5*nb*QK4_0/2, d + ib + 5*nb, sumy, yl, il); - sumf.s6 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 6*nb*QK4_0/2, d + ib + 6*nb, sumy, yl, il); - sumf.s7 += block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 7*nb*QK4_0/2, d + ib + 7*nb, sumy, yl, il); - - yb += QK4_0 * (N_SIMDWIDTH/2); - } - - float8 tot = (float8)( - sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1), - sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3), - sub_group_reduce_add(sumf.s4),
sub_group_reduce_add(sumf.s5), - sub_group_reduce_add(sumf.s6), sub_group_reduce_add(sumf.s7) - ); - - if (get_sub_group_local_id() == 0) { - if (first_row + 0 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0; - } - if (first_row + 1 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1; - } - if (first_row + 2 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2; - } - if (first_row + 3 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3; - } - - if (first_row + 4 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 4] = tot.s4; - } - if (first_row + 5 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 5] = tot.s5; - } - if (first_row + 6 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 6] = tot.s6; - } - if (first_row + 7 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 7] = tot.s7; - } - } -} - -#ifdef INTEL_GPU -REQD_SUBGROUP_SIZE_16 -#elif defined (ADRENO_GPU) -REQD_SUBGROUP_SIZE_64 -#endif -kernel void kernel_mul_mat_q4_0_f32_8x_flat( - global uchar * src0_q, - global half * src0_d, - global float * src1, - ulong offset1, - global float * dst, - ulong offsetd, - int ne00, - int ne01, - int ne02, - int ne10, - int ne12, - int ne0, - int ne1, - int r2, - int r3 -) { - src1 = (global float*)((global char*)src1 + offset1); - dst = (global float*)((global char*)dst + offsetd); - - mul_vec_q_n_f32_8x_flat(src0_q, src0_d, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3); -} diff --git a/ggml/src/ggml-opencl/kernels/ggml-opencl_im2col.cl b/ggml/src/ggml-opencl/kernels/ggml-opencl_im2col.cl deleted file mode 100644 index 9b41dfb255..0000000000 --- a/ggml/src/ggml-opencl/kernels/ggml-opencl_im2col.cl +++ /dev/null @@ -1,146 +0,0 @@ -#ifdef cl_khr_fp16 -#pragma OPENCL EXTENSION cl_khr_fp16 : enable -#elif defined(cl_amd_fp16) -#pragma OPENCL EXTENSION cl_amd_fp16 : enable -#else -#error "Half precision floating point not supported by OpenCL implementation on your device." -#endif - -#ifdef cl_khr_subgroups -#pragma OPENCL EXTENSION cl_khr_subgroups : enable -#elif defined(cl_intel_subgroups) -#pragma OPENCL EXTENSION cl_intel_subgroups : enable -#else -#error "Subgroup not supported on your device." -#endif - -#ifdef cl_intel_required_subgroup_size -// Always use subgroup size of 32 on Intel. -#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable -#define INTEL_GPU 1 -#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) -#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) -#elif defined(cl_qcom_reqd_sub_group_size) -// Always use subgroup size of 64 on Adreno. -#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable -#define ADRENO_GPU 1 -#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) -#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) -#else -// TODO: do not know how to choose subgroup size on other GPUs. -#error "Selecting subgroup size is not supported on your device."
-#endif - -kernel void kernel_im2col_f32( - global float * src1, - ulong offset1, - global float * dst, - ulong offsetd, - ulong batch_offset, - ulong delta_offset, - long IW, - long IH, - long IC, - long OW, - long OH, - long KW, - long KH, - long pelements, - long CHW, - int s0, - int s1, - int p0, - int p1, - int d0, - int d1 -) { - // threadIdx.x + blockIdx.x * blockDim.x - long i = get_global_id(0); - if (i >= pelements) { - return; - } - - src1 = (global float*)((global char*)src1 + offset1); - dst = (global float*)((global char*)dst + offsetd); - - long ksize = OW * (KH > 1 ? KW : 1); - long kx = i / ksize; - long kd = kx * ksize; - long ky = (i - kd) / OW; - long ix = i % OW; - - long oh = get_group_id(1); - long batch = get_group_id(2) / IC; - long ic = get_group_id(2) % IC; - - long iiw = ix * s0 + kx * d0 - p0; - long iih = oh * s1 + ky * d1 - p1; - - long offset_dst = - ((batch * OH + oh) * OW + ix) * CHW + - (ic * (KW * KH) + ky * KW + kx); - - if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) { - dst[offset_dst] = 0.0f; - } else { - long offset_src = ic * delta_offset + batch * batch_offset; - dst[offset_dst] = src1[offset_src + iih * IW + iiw]; - } -} - -kernel void kernel_im2col_f16( - global float * src1, - ulong offset1, - global half * dst, - ulong offsetd, - ulong batch_offset, - ulong delta_offset, - long IW, - long IH, - long IC, - long OW, - long OH, - long KW, - long KH, - long pelements, - long CHW, - int s0, - int s1, - int p0, - int p1, - int d0, - int d1 -) { - long i = get_global_id(0); - - if (i >= pelements) { - return; - } - - src1 = (global float*)((global char*)src1 + offset1); - dst = (global half*)((global char*)dst + offsetd); - - long ksize = OW * (KH > 1 ? KW : 1); - long kx = i / ksize; - long kd = kx * ksize; - long ky = (i - kd) / OW; - long ix = i % OW; - - long oh = get_group_id(1); - long batch = get_group_id(2) / IC; - long ic = get_group_id(2) % IC; - - long iiw = ix * s0 + kx * d0 - p0; - long iih = oh * s1 + ky * d1 - p1; - - long offset_dst = - ((batch * OH + oh) * OW + ix) * CHW + - (ic * (KW * KH) + ky * KW + kx); - - if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) { - dst[offset_dst] = 0.0f; - } else { - long offset_src = ic * delta_offset + batch * batch_offset; - dst[offset_dst] = src1[offset_src + iih * IW + iiw]; - } -} diff --git a/ggml/src/ggml-opencl/kernels/ggml-opencl_mm.cl b/ggml/src/ggml-opencl/kernels/ggml-opencl_mm.cl deleted file mode 100644 index e19e9a2f43..0000000000 --- a/ggml/src/ggml-opencl/kernels/ggml-opencl_mm.cl +++ /dev/null @@ -1,1225 +0,0 @@ -//------------------------------------------------------------------------------ -// This file contains additional mulmat kernels -// (and potentially other kernels). -//------------------------------------------------------------------------------ -#ifdef cl_khr_fp16 -#pragma OPENCL EXTENSION cl_khr_fp16 : enable -#elif defined(cl_amd_fp16) -#pragma OPENCL EXTENSION cl_amd_fp16 : enable -#else -#error "Half precision floating point not supported by OpenCL implementation on your device." -#endif - -#ifdef cl_khr_subgroups -#pragma OPENCL EXTENSION cl_khr_subgroups : enable -#elif defined(cl_intel_subgroups) -#pragma OPENCL EXTENSION cl_intel_subgroups : enable -#else -#error "Subgroup not supported on your device." -#endif - -#ifdef cl_intel_required_subgroup_size -// Always use subgroup size of 32 on Intel.
-#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable -#define INTEL_GPU 1 -#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) -#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) -#elif defined(cl_qcom_reqd_sub_group_size) -// Always use subgroup size of 64 on Adreno. -#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable -#define ADRENO_GPU 1 -#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) -#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) -#else -// TODO: do not know how to choose subgroup size on other GPUs. -#error "Selecting subgroup size is not supported on your device." -#endif - -#define QK4_0 32 -#define QR4_0 2 -#define QK4_1 32 -#define QR4_1 2 -#define QK5_0 32 -#define QR5_0 2 -#define QK5_1 32 -#define QR5_1 2 -#define QK8_0 32 -#define QR8_0 1 -#define QK_K 256 -#define K_QUANTS_PER_ITERATION 2 - -typedef char int8_t; -typedef uchar uint8_t; -typedef short int16_t; -typedef ushort uint16_t; -typedef int int32_t; -typedef uint uint32_t; - -//------------------------------------------------------------------------------ -// block_q4_0 -//------------------------------------------------------------------------------ -struct block_q4_0 -{ - half d; - uint8_t qs[QK4_0 / 2]; -}; - -//------------------------------------------------------------------------------ -// block_q6_K -//------------------------------------------------------------------------------ -// 6-bit quantization -// weight is represented as x = a * q -// 16 blocks of 16 elements each -// Effectively 6.5625 bits per weight -typedef struct { - uint8_t ql[QK_K/2]; // quants, lower 4 bits - uint8_t qh[QK_K/4]; // quants, upper 2 bits - int8_t scales[QK_K/16]; // scales, quantized with 8 bits - half d; // super-block scale -} block_q6_K; - -//------------------------------------------------------------------------------ -// These are the variants for matmatmul, based on the matvecmul kernel with -// flattened block_q4_0. -//------------------------------------------------------------------------------ - -// Common dot prod. -inline float mm_block_q_4_0_dot_y_flat( - global uchar * x, - global half * dh, - float sumy, - float16 yl, - int il -) { - float d = *dh; - global ushort * qs = ((global ushort *)x + il/2); - float acc = 0.f; - - acc += yl.s0 * (qs[0] & 0x000F); - acc += yl.s1 * (qs[0] & 0x0F00); - acc += yl.s8 * (qs[0] & 0x00F0); - acc += yl.s9 * (qs[0] & 0xF000); - - acc += yl.s2 * (qs[1] & 0x000F); - acc += yl.s3 * (qs[1] & 0x0F00); - acc += yl.sa * (qs[1] & 0x00F0); - acc += yl.sb * (qs[1] & 0xF000); - - acc += yl.s4 * (qs[2] & 0x000F); - acc += yl.s5 * (qs[2] & 0x0F00); - acc += yl.sc * (qs[2] & 0x00F0); - acc += yl.sd * (qs[2] & 0xF000); - - acc += yl.s6 * (qs[3] & 0x000F); - acc += yl.s7 * (qs[3] & 0x0F00); - acc += yl.se * (qs[3] & 0x00F0); - acc += yl.sf * (qs[3] & 0xF000); - - return d * (sumy * -8.f + acc); -} - -#undef N_DST -#undef N_SIMDGROUP -#undef N_SIMDWIDTH - -#ifdef INTEL_GPU -#define N_DST 8 // each SIMD group works on 8 rows (in weights matrix) -#define N_SIMDGROUP 1 // number of SIMD groups in a thread group -#define N_SIMDWIDTH 16 // assuming SIMD group size is 16 -#elif defined (ADRENO_GPU) -#define N_DST 8 -#define N_SIMDGROUP 1 -#define N_SIMDWIDTH 64 -#endif -// -// This variant performs 1d blocking with 8x output. -// Each simdgroup outputs 8 values on `n0` dim (row in the output matrix).
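In scalar terms, one subgroup of this 1d-8x variant computes eight consecutive row-vector dot products against the same activation column, with a bounds check on the final tile. A sequential C reference (operating on already-dequantized weights purely for readability; this is an illustrative sketch, not the parallel implementation):

    #include <stddef.h>

    // Sequential reference for one subgroup's share of the 1d-8x kernel:
    // eight consecutive output rows, each an ne00-long dot product, with
    // the same `first_row + row < ne01` guard used before the stores.
    void q4_0_1d_8x_reference(const float *w,  // ne01 x ne00 dequantized rows
                              const float *y,  // activation vector, ne00 long
                              float *dst, int first_row, int ne00, int ne01) {
        for (int row = 0; row < 8; ++row) {
            if (first_row + row >= ne01) {
                break;
            }
            float sum = 0.f;
            for (int i = 0; i < ne00; ++i) {
                sum += w[(size_t)(first_row + row) * ne00 + i] * y[i];
            }
            dst[first_row + row] = sum; // lane 0 stores after the subgroup reduce
        }
    }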
-// -inline void mul_mat_q_n_f32_1d_8x_flat( - global uchar * src0_q, - global half * src0_d, - global float * src1, - global float * dst, - int ne00, - int ne01, - int ne02, - int ne10, - int ne12, - int ne0, - int ne1, - int r2, - int r3 -) { - const int nb = ne00/QK4_0; - - int r0 = get_group_id(0); - int r1 = get_group_id(1); - int im = get_group_id(2); - - // (r0 * N_SIMDGROUP + get_sub_group_id()) is the linear global id of - // a SIMD group in the grid. Each SIMD group produces N_DST values in the - // result, hence uses nb blocks, i.e., the offset becomes first_row*nb. - // Currently with llama2 7B, im is always 0. - // TODO: how to handle im/gqa*(nb*ne0)? - int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST; - - int i12 = im%ne12; - int i13 = im/ne12; - - // The number of scales is the same as the number of blocks. - ulong offset0_d = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); - // Each block contains QK4_0/2 uchars, hence offset for qs is as follows. - ulong offset0_q = (first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02)) * QK4_0/2; - - global uchar * x = (global uchar *) src0_q + offset0_q; - global half * d = (global half *) src0_d + offset0_d; - global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1; - - float16 yl; - float8 sumf = (float8)(0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f); - - int ix = get_sub_group_local_id()/2; - int il = 8*(get_sub_group_local_id()%2); - - global float * yb = y + ix*QK4_0 + il; - - for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) { - float sumy = 0.f; - - sumy += yb[0]; - sumy += yb[1]; - sumy += yb[2]; - sumy += yb[3]; - sumy += yb[4]; - sumy += yb[5]; - sumy += yb[6]; - sumy += yb[7]; - - sumy += yb[16]; - sumy += yb[17]; - sumy += yb[18]; - sumy += yb[19]; - sumy += yb[20]; - sumy += yb[21]; - sumy += yb[22]; - sumy += yb[23]; - - yl.s0 = yb[0]; - yl.s1 = yb[1]/256.f; - - yl.s2 = yb[2]; - yl.s3 = yb[3]/256.f; - - yl.s4 = yb[4]; - yl.s5 = yb[5]/256.f; - - yl.s6 = yb[6]; - yl.s7 = yb[7]/256.f; - - yl.s8 = yb[16]/16.f; - yl.s9 = yb[17]/4096.f; - - yl.sa = yb[18]/16.f; - yl.sb = yb[19]/4096.f; - - yl.sc = yb[20]/16.f; - yl.sd = yb[21]/4096.f; - - yl.se = yb[22]/16.f; - yl.sf = yb[23]/4096.f; - - sumf.s0 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 0*nb*QK4_0/2, d + ib + 0*nb, sumy, yl, il); - sumf.s1 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 1*nb*QK4_0/2, d + ib + 1*nb, sumy, yl, il); - sumf.s2 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 2*nb*QK4_0/2, d + ib + 2*nb, sumy, yl, il); - sumf.s3 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 3*nb*QK4_0/2, d + ib + 3*nb, sumy, yl, il); - - sumf.s4 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 4*nb*QK4_0/2, d + ib + 4*nb, sumy, yl, il); - sumf.s5 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 5*nb*QK4_0/2, d + ib + 5*nb, sumy, yl, il); - sumf.s6 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 6*nb*QK4_0/2, d + ib + 6*nb, sumy, yl, il); - sumf.s7 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 7*nb*QK4_0/2, d + ib + 7*nb, sumy, yl, il); - - yb += QK4_0 * (N_SIMDWIDTH/2); - } - - float8 tot = (float8)( - sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1), - sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3), - sub_group_reduce_add(sumf.s4), sub_group_reduce_add(sumf.s5), - sub_group_reduce_add(sumf.s6), sub_group_reduce_add(sumf.s7) - ); - - if (get_sub_group_local_id() == 0) { - if (first_row + 0 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0; - } - if (first_row + 1 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = 
tot.s1; - } - if (first_row + 2 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2; - } - if (first_row + 3 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3; - } - - if (first_row + 4 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 4] = tot.s4; - } - if (first_row + 5 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 5] = tot.s5; - } - if (first_row + 6 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 6] = tot.s6; - } - if (first_row + 7 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 7] = tot.s7; - } - } -} - -#ifdef INTEL_GPU -REQD_SUBGROUP_SIZE_16 -#elif defined (ADRENO_GPU) -REQD_SUBGROUP_SIZE_64 -#endif -kernel void kernel_mul_mat_q4_0_f32_1d_8x_flat( - global uchar * src0_q, - global half * src0_d, - global float * src1, - ulong offset1, - global float * dst, - ulong offsetd, - int ne00, - int ne01, - int ne02, - int ne10, - int ne12, - int ne0, - int ne1, - int r2, - int r3 -) { - src1 = (global float*)((global char*)src1 + offset1); - dst = (global float*)((global char*)dst + offsetd); - - mul_mat_q_n_f32_1d_8x_flat(src0_q, src0_d, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3); -} - -#undef N_DST -#undef N_SIMDGROUP -#undef N_SIMDWIDTH - -#ifdef INTEL_GPU -#define N_DST 16 // each SIMD group works on 8 rows (in weights matrix) -#define N_SIMDGROUP 1 // number of SIMD groups in a thread group -#define N_SIMDWIDTH 16 // assuming SIMD group size is 16 -#elif defined (ADRENO_GPU) -#define N_DST 16 -#define N_SIMDGROUP 1 -#define N_SIMDWIDTH 64 -#endif -// -// This variant performs 1d blocking with 16x output. -// Eeach simdgroup outputs 16 values on `n0` dim (row in the output matrix). -// -inline void mul_mat_q_n_f32_1d_16x_flat( - global uchar * src0_q, - global half * src0_d, - global float * src1, - global float * dst, - int ne00, - int ne01, - int ne02, - int ne10, - int ne12, - int ne0, - int ne1, - int r2, - int r3 -) { - const int nb = ne00/QK4_0; - - int r0 = get_group_id(0); - int r1 = get_group_id(1); - int im = get_group_id(2); - - // (r0 * N_SIMDGROUP + get_sub_group_id()) is the linear global id of - // a SIMD group in the grid. Each SIMD group produces N_DST values in the - // result, hence uses nb blocks, i.e., the offset becomes first_row*nb. - // Currently with llama2 7B, im is always 0. - // TODO: how to handle im/gqa*(nb*ne0)? - int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST; - - int i12 = im%ne12; - int i13 = im/ne12; - - // The number of scales is the same as the number of blocks. - ulong offset0_d = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); - // Each block contains QK4_0/2 uchars, hence offset for qs is as follows. 
- ulong offset0_q = (first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02)) * QK4_0/2; - - global uchar * x = (global uchar *) src0_q + offset0_q; - global half * d = (global half *) src0_d + offset0_d; - global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1; - - float16 yl; - float16 sumf = (float16)(0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, - 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f); - - int ix = get_sub_group_local_id()/2; - int il = 8*(get_sub_group_local_id()%2); - - global float * yb = y + ix*QK4_0 + il; - - for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) { - float sumy = 0.f; - - sumy += yb[0]; - sumy += yb[1]; - sumy += yb[2]; - sumy += yb[3]; - sumy += yb[4]; - sumy += yb[5]; - sumy += yb[6]; - sumy += yb[7]; - - sumy += yb[16]; - sumy += yb[17]; - sumy += yb[18]; - sumy += yb[19]; - sumy += yb[20]; - sumy += yb[21]; - sumy += yb[22]; - sumy += yb[23]; - - yl.s0 = yb[0]; - yl.s1 = yb[1]/256.f; - - yl.s2 = yb[2]; - yl.s3 = yb[3]/256.f; - - yl.s4 = yb[4]; - yl.s5 = yb[5]/256.f; - - yl.s6 = yb[6]; - yl.s7 = yb[7]/256.f; - - yl.s8 = yb[16]/16.f; - yl.s9 = yb[17]/4096.f; - - yl.sa = yb[18]/16.f; - yl.sb = yb[19]/4096.f; - - yl.sc = yb[20]/16.f; - yl.sd = yb[21]/4096.f; - - yl.se = yb[22]/16.f; - yl.sf = yb[23]/4096.f; - - sumf.s0 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 0*nb*QK4_0/2, d + ib + 0*nb, sumy, yl, il); - sumf.s1 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 1*nb*QK4_0/2, d + ib + 1*nb, sumy, yl, il); - sumf.s2 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 2*nb*QK4_0/2, d + ib + 2*nb, sumy, yl, il); - sumf.s3 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 3*nb*QK4_0/2, d + ib + 3*nb, sumy, yl, il); - - sumf.s4 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 4*nb*QK4_0/2, d + ib + 4*nb, sumy, yl, il); - sumf.s5 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 5*nb*QK4_0/2, d + ib + 5*nb, sumy, yl, il); - sumf.s6 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 6*nb*QK4_0/2, d + ib + 6*nb, sumy, yl, il); - sumf.s7 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 7*nb*QK4_0/2, d + ib + 7*nb, sumy, yl, il); - - sumf.s8 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 8*nb*QK4_0/2, d + ib + 8*nb, sumy, yl, il); - sumf.s9 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 9*nb*QK4_0/2, d + ib + 9*nb, sumy, yl, il); - sumf.sa += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 10*nb*QK4_0/2, d + ib + 10*nb, sumy, yl, il); - sumf.sb += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 11*nb*QK4_0/2, d + ib + 11*nb, sumy, yl, il); - - sumf.sc += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 12*nb*QK4_0/2, d + ib + 12*nb, sumy, yl, il); - sumf.sd += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 13*nb*QK4_0/2, d + ib + 13*nb, sumy, yl, il); - sumf.se += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 14*nb*QK4_0/2, d + ib + 14*nb, sumy, yl, il); - sumf.sf += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 15*nb*QK4_0/2, d + ib + 15*nb, sumy, yl, il); - - yb += QK4_0 * (N_SIMDWIDTH/2); - } - - float16 tot = (float16)( - sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1), - sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3), - sub_group_reduce_add(sumf.s4), sub_group_reduce_add(sumf.s5), - sub_group_reduce_add(sumf.s6), sub_group_reduce_add(sumf.s7), - - sub_group_reduce_add(sumf.s8), sub_group_reduce_add(sumf.s9), - sub_group_reduce_add(sumf.sa), sub_group_reduce_add(sumf.sb), - sub_group_reduce_add(sumf.sc), sub_group_reduce_add(sumf.sd), - sub_group_reduce_add(sumf.se), sub_group_reduce_add(sumf.sf) - ); - - if (get_sub_group_local_id() == 0) { - if (first_row + 0 < ne01) { - dst[r1*ne0 + 
im*ne0*ne1 + first_row + 0] = tot.s0; - } - if (first_row + 1 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1; - } - if (first_row + 2 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2; - } - if (first_row + 3 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3; - } - - if (first_row + 4 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 4] = tot.s4; - } - if (first_row + 5 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 5] = tot.s5; - } - if (first_row + 6 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 6] = tot.s6; - } - if (first_row + 7 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 7] = tot.s7; - } - - if (first_row + 8 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 8] = tot.s8; - } - if (first_row + 9 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 9] = tot.s9; - } - if (first_row + 10 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 10] = tot.sa; - } - if (first_row + 11 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 11] = tot.sb; - } - - if (first_row + 12 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 12] = tot.sc; - } - if (first_row + 13 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 13] = tot.sd; - } - if (first_row + 14 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 14] = tot.se; - } - if (first_row + 15 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 15] = tot.sf; - } - } -} - -#ifdef INTEL_GPU -REQD_SUBGROUP_SIZE_16 -#elif defined (ADRENO_GPU) -REQD_SUBGROUP_SIZE_64 -#endif -kernel void kernel_mul_mat_q4_0_f32_1d_16x_flat( - global uchar * src0_q, - global half * src0_d, - global float * src1, - ulong offset1, - global float * dst, - ulong offsetd, - int ne00, - int ne01, - int ne02, - int ne10, - int ne12, - int ne0, - int ne1, - int r2, - int r3 -) { - src1 = (global float*)((global char*)src1 + offset1); - dst = (global float*)((global char*)dst + offsetd); - - mul_mat_q_n_f32_1d_16x_flat(src0_q, src0_d, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3); -} - -//------------------------------------------------------------------------------ -// kernel_mul_mat_q4_0_f32_flat_v0 -//------------------------------------------------------------------------------ -inline float block_q_4_0_dot_y_flat_v2( - half x, - half d, - float sumy, - float4 yl -) { - uchar2 q = as_uchar2(x); - float acc = 0.0f; - - acc += (q.s0 & 0x0F) * yl.s0; - acc += (q.s1 & 0x0F) * yl.s1; - - acc += (q.s0 & 0xF0) * yl.s2; - acc += (q.s1 & 0xF0) * yl.s3; - - return d * (sumy * -8.f + acc);; -} - -inline float block_q_4_0_dot_y_flat_v4( - float x, - half d, - float sumy, - float8 yl -) { - uchar4 q = as_uchar4(x); - float acc = 0.0f; - - acc += (q.s0 & 0x0F) * yl.s0; - acc += (q.s1 & 0x0F) * yl.s1; - acc += (q.s2 & 0x0F) * yl.s2; - acc += (q.s3 & 0x0F) * yl.s3; - - acc += (q.s0 & 0xF0) * yl.s4; - acc += (q.s1 & 0xF0) * yl.s5; - acc += (q.s2 & 0xF0) * yl.s6; - acc += (q.s3 & 0xF0) * yl.s7; - - return d * (sumy * -8.f + acc);; -} - -inline float block_q_4_0_dot_y_flat_v8( - float2 x, - half d, - float sumy, - float16 yl -) { - uchar8 q = as_uchar8(x); - float acc = 0.0f; - - acc += (q.s0 & 0x0F) * yl.s0; - acc += (q.s1 & 0x0F) * yl.s1; - acc += (q.s2 & 0x0F) * yl.s2; - acc += (q.s3 & 0x0F) * yl.s3; - acc += (q.s4 & 0x0F) * yl.s4; - acc += (q.s5 & 0x0F) * yl.s5; - acc += (q.s6 & 0x0F) * yl.s6; - acc += (q.s7 & 0x0F) * yl.s7; - - acc += (q.s0 & 0xF0) * yl.s8; - acc += (q.s1 & 0xF0) * yl.s9; - acc += (q.s2 & 0xF0) * yl.sa; - acc += (q.s3 & 0xF0) * yl.sb; - acc += (q.s4 & 0xF0) * yl.sc; - acc += (q.s5 & 0xF0) * yl.sd; - acc += (q.s6 & 0xF0) * 
yl.se; - acc += (q.s7 & 0xF0) * yl.sf; - - return d * (sumy * -8.f + acc);; -} - -#undef N_DST -#undef N_SIMDGROUP -#undef N_SIMDWIDTH - -#ifdef INTEL_GPU -#define THREADS_PER_BLK 4 // Number of threads per block, or each thread process 1/THREADS_PER_BLK of a block -#define N_DST 4 -#define N_SIMDGROUP 1 -#define N_SIMDWIDTH 16 -#elif defined (ADRENO_GPU) -#define THREADS_PER_BLK 4 -#define N_DST 4 -#define N_SIMDGROUP 1 -#define N_SIMDWIDTH 64 -#endif - -#if THREADS_PER_BLK == 2 // Each thread processes 1/2 block -# define ACT_TY float16 -# define Q_BLK_LD_TY float2 -# define block_q_4_0_dot_y_flat block_q_4_0_dot_y_flat_v8 -#elif THREADS_PER_BLK == 4 // Each thread processes 1/4 block -# define ACT_TY float8 -# define Q_BLK_LD_TY float -# define block_q_4_0_dot_y_flat block_q_4_0_dot_y_flat_v4 -#elif THREADS_PER_BLK == 8 // Each thread processes 1/8 block -# define ACT_TY float4 -# define Q_BLK_LD_TY half -# define block_q_4_0_dot_y_flat block_q_4_0_dot_y_flat_v2 -#endif - -#define BTYES_PER_THREAD_IN_BLK (QK4_0/2/THREADS_PER_BLK) - -#if N_DST == 2 -# define SUM_TY float2 -#elif N_DST == 4 -# define SUM_TY float4 -#elif N_DST == 8 -# define SUM_TY float8 -#elif N_DST == 16 -# define SUM_TY float16 -#endif - -#ifdef INTEL_GPU -REQD_SUBGROUP_SIZE_16 -#elif defined (ADRENO_GPU) -REQD_SUBGROUP_SIZE_64 -#endif -kernel void kernel_mul_mat_q4_0_f32_flat_v0( - global uchar * src0_q, - global half * src0_d, - global float * src1, - ulong offset1, - global float * dst, - ulong offsetd, - int ne00, - int ne01, - int ne02, - int ne10, - int ne12, - int ne0, - int ne1, - int r2, - int r3 -) { - src1 = (global float*)((global char*)src1 + offset1); - dst = (global float*)((global char*)dst + offsetd); - - const int nb = ne00/QK4_0; - - int r0 = get_group_id(0); - int r1 = get_group_id(1); - int im = get_group_id(2); - - int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST; - - int i12 = im%ne12; - int i13 = im/ne12; - - // The number of scales is the same as the number of blocks. - ulong offset0_d = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); - // Each block contains QK4_0/2 uchars, hence offset for qs is as follows. 
- ulong offset0_q = (first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02)) * QK4_0/2; - - global uchar * x = (global uchar *) src0_q + offset0_q; - global half * d = (global half *) src0_d + offset0_d; - global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1; - - int ix = get_sub_group_local_id()/THREADS_PER_BLK; - int il = get_sub_group_local_id()%THREADS_PER_BLK; - - global float * yb = y + ix*QK4_0 + BTYES_PER_THREAD_IN_BLK*il; - - // Registers for caching activation - ACT_TY yl = 0.f; - - // Registers for caching quants - Q_BLK_LD_TY q_blk_0 = 0, q_blk_1 = 0; -#if N_DST == 4 || N_DST == 8 || N_DST == 16 - Q_BLK_LD_TY q_blk_2 = 0, q_blk_3 = 0; -#endif -#if N_DST == 8 || N_DST == 16 - Q_BLK_LD_TY q_blk_4 = 0, q_blk_5 = 0, q_blk_6 = 0, q_blk_7 = 0; -#endif - - // Partial sum - SUM_TY sumf = 0.f; - - for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/THREADS_PER_BLK) { - float sumy = 0.f; - - q_blk_0 = *(global Q_BLK_LD_TY*)(x + ib*QK4_0/2 + BTYES_PER_THREAD_IN_BLK*il + 0*nb*QK4_0/2); - q_blk_1 = *(global Q_BLK_LD_TY*)(x + ib*QK4_0/2 + BTYES_PER_THREAD_IN_BLK*il + 1*nb*QK4_0/2); -#if N_DST == 4 || N_DST == 8 || N_DST == 16 - q_blk_2 = *(global Q_BLK_LD_TY*)(x + ib*QK4_0/2 + BTYES_PER_THREAD_IN_BLK*il + 2*nb*QK4_0/2); - q_blk_3 = *(global Q_BLK_LD_TY*)(x + ib*QK4_0/2 + BTYES_PER_THREAD_IN_BLK*il + 3*nb*QK4_0/2); -#endif -#if N_DST == 8 || N_DST == 16 - q_blk_4 = (*(global Q_BLK_LD_TY*)(x + ib*QK4_0/2 + BTYES_PER_THREAD_IN_BLK*il + 4*nb*QK4_0/2)); - q_blk_5 = (*(global Q_BLK_LD_TY*)(x + ib*QK4_0/2 + BTYES_PER_THREAD_IN_BLK*il + 5*nb*QK4_0/2)); - q_blk_6 = (*(global Q_BLK_LD_TY*)(x + ib*QK4_0/2 + BTYES_PER_THREAD_IN_BLK*il + 6*nb*QK4_0/2)); - q_blk_7 = (*(global Q_BLK_LD_TY*)(x + ib*QK4_0/2 + BTYES_PER_THREAD_IN_BLK*il + 7*nb*QK4_0/2)); -#endif - - // Load activation -#if THREADS_PER_BLK == 2 // Each thread processes 1/2 block - yl.s01234567 = *(global float8 *)(yb); - yl.s89abcdef = *(global float8 *)(yb + 16); - - sumy += yl.s0; - sumy += yl.s1; - sumy += yl.s2; - sumy += yl.s3; - sumy += yl.s4; - sumy += yl.s5; - sumy += yl.s6; - sumy += yl.s7; - sumy += yl.s8; yl.s8 /= 16.f; - sumy += yl.s9; yl.s9 /= 16.f; - sumy += yl.sa; yl.sa /= 16.f; - sumy += yl.sb; yl.sb /= 16.f; - sumy += yl.sc; yl.sc /= 16.f; - sumy += yl.sd; yl.sd /= 16.f; - sumy += yl.se; yl.se /= 16.f; - sumy += yl.sf; yl.sf /= 16.f; -#elif THREADS_PER_BLK == 4 // Each thread processes 1/4 block - yl.s0123 = *(global float4 *)(yb); - yl.s4567 = *(global float4 *)(yb + 16); - - sumy += yl.s0; - sumy += yl.s1; - sumy += yl.s2; - sumy += yl.s3; - sumy += yl.s4; yl.s4 /= 16.f; - sumy += yl.s5; yl.s5 /= 16.f; - sumy += yl.s6; yl.s6 /= 16.f; - sumy += yl.s7; yl.s7 /= 16.f; -#elif THREADS_PER_BLK == 8 // Each thread processes 1/8 block - yl.s01 = *(global float2 *)(yb); - yl.s23 = *(global float2 *)(yb + 16); - - sumy += yl.s0; - sumy += yl.s1; - sumy += yl.s2; yl.s2 /= 16.f; - sumy += yl.s3; yl.s3 /= 16.f; -#endif - - sumf.s0 += block_q_4_0_dot_y_flat(q_blk_0, *(d + ib + 0*nb), sumy, yl); - sumf.s1 += block_q_4_0_dot_y_flat(q_blk_1, *(d + ib + 1*nb), sumy, yl); -#if N_DST == 4 || N_DST == 8 || N_DST == 16 - sumf.s2 += block_q_4_0_dot_y_flat(q_blk_2, *(d + ib + 2*nb), sumy, yl); - sumf.s3 += block_q_4_0_dot_y_flat(q_blk_3, *(d + ib + 3*nb), sumy, yl); -#endif -#if N_DST == 8 || N_DST == 16 - sumf.s4 += block_q_4_0_dot_y_flat(q_blk_4, *(d + ib + 4*nb), sumy, yl); - sumf.s5 += block_q_4_0_dot_y_flat(q_blk_5, *(d + ib + 5*nb), sumy, yl); - sumf.s6 += block_q_4_0_dot_y_flat(q_blk_6, *(d + ib + 6*nb), sumy, yl); - sumf.s7 += 
block_q_4_0_dot_y_flat(q_blk_7, *(d + ib + 7*nb), sumy, yl); -#endif - - yb += QK4_0 * (N_SIMDWIDTH/THREADS_PER_BLK); - } - - SUM_TY tot = (SUM_TY)( - sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1) -#if N_DST == 4 || N_DST == 8 || N_DST == 16 - , sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3) -#endif -#if N_DST == 8 || N_DST == 16 - , sub_group_reduce_add(sumf.s4), sub_group_reduce_add(sumf.s5) - , sub_group_reduce_add(sumf.s6), sub_group_reduce_add(sumf.s7) -#endif - ); - - if (get_sub_group_local_id() == 0) { - if (first_row + 0 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0; - } - if (first_row + 1 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1; - } -#if N_DST == 4 || N_DST == 8 || N_DST == 16 - if (first_row + 2 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2; - } - if (first_row + 3 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3; - } -#endif -#if N_DST == 8 || N_DST == 16 - if (first_row + 4 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 4] = tot.s4; - } - if (first_row + 5 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 5] = tot.s5; - } - if (first_row + 6 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 6] = tot.s6; - } - if (first_row + 7 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 7] = tot.s7; - } -#endif - } -} - -//------------------------------------------------------------------------------ -// Using image1d_buffer_t - -#if defined(cl_qcom_subgroup_shuffle) -#pragma OPENCL EXTENSION cl_qcom_subgroup_shuffle : enable -float qcom_sub_group_reduce_add(float sum) { - sum += qcom_sub_group_shuffle_down(sum, 32, CLK_SUB_GROUP_SHUFFLE_WIDTH_WAVE_SIZE_QCOM, 0.f); - sum += qcom_sub_group_shuffle_down(sum, 16, CLK_SUB_GROUP_SHUFFLE_WIDTH_WAVE_SIZE_QCOM, 0.f); - sum += qcom_sub_group_shuffle_down(sum, 8, CLK_SUB_GROUP_SHUFFLE_WIDTH_WAVE_SIZE_QCOM, 0.f); - sum += qcom_sub_group_shuffle_down(sum, 4, CLK_SUB_GROUP_SHUFFLE_WIDTH_WAVE_SIZE_QCOM, 0.f); - sum += qcom_sub_group_shuffle_down(sum, 2, CLK_SUB_GROUP_SHUFFLE_WIDTH_WAVE_SIZE_QCOM, 0.f); - sum += qcom_sub_group_shuffle_down(sum, 1, CLK_SUB_GROUP_SHUFFLE_WIDTH_WAVE_SIZE_QCOM, 0.f); - return sum; -} -#define sub_group_reduce_add qcom_sub_group_reduce_add -#else -#define sub_group_reduce_add sub_group_reduce_add -#endif - -#undef THREADS_PER_BLK -#undef N_DST -#undef N_SIMDGROUP -#undef N_SIMDWIDTH - -#ifdef INTEL_GPU -#define THREADS_PER_BLK 4 // Number of threads per block, or each thread process 1/THREADS_PER_BLK of a block -#define N_DST 4 -#define N_SIMDGROUP 1 -#define N_SIMDWIDTH 16 -#elif defined (ADRENO_GPU) -#define THREADS_PER_BLK 4 -#define N_DST 4 -#define N_SIMDGROUP 1 -#define N_SIMDWIDTH 64 -#endif - -#if THREADS_PER_BLK == 2 // Each thread processes 1/2 block -# define ACT_TY float16 -# define Q_BLK_LD_TY float2 -# define EXTRACT_BLK_DATA(tmp, part) *((float2*)&tmp + part) -# define block_q_4_0_dot_y_flat block_q_4_0_dot_y_flat_v8 -#elif THREADS_PER_BLK == 4 // Each thread processes 1/4 block -# define ACT_TY float8 -# define Q_BLK_LD_TY float -# define EXTRACT_BLK_DATA(tmp, part) *((float*)&tmp + part) -# define block_q_4_0_dot_y_flat block_q_4_0_dot_y_flat_v4 -#elif THREADS_PER_BLK == 8 // Each thread processes 1/8 block -# define ACT_TY float4 -# define Q_BLK_LD_TY half -# define EXTRACT_BLK_DATA(tmp, part) *((half*)&tmp + part) -# define block_q_4_0_dot_y_flat block_q_4_0_dot_y_flat_v2 -#endif - -#define BTYES_PER_THREAD_IN_BLK (QK4_0/2/THREADS_PER_BLK) - -#if N_DST == 2 -# define SUM_TY float2 -#elif N_DST == 4 -# 
define SUM_TY float4 -#elif N_DST == 8 -# define SUM_TY float8 -#elif N_DST == 16 -# define SUM_TY float16 -#endif - -#ifdef INTEL_GPU -REQD_SUBGROUP_SIZE_16 -#elif defined (ADRENO_GPU) -REQD_SUBGROUP_SIZE_64 -#endif -kernel void kernel_mul_mat_q4_0_f32_flat_img_v0( - read_only image1d_buffer_t src0_q, - read_only image1d_buffer_t src0_d, - global float * src1, - ulong offset1, - global float * dst, - ulong offsetd, - int ne00, - int ne01, - int ne02, - int ne10, - int ne12, - int ne0, - int ne1, - int r2, - int r3 -) { - src1 = (global float*)((global char*)src1 + offset1); - dst = (global float*)((global char*)dst + offsetd); - - const int nb = ne00/QK4_0; - - int r0 = get_group_id(0); - int r1 = get_group_id(1); - int im = get_group_id(2); - - int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST; - - int i12 = im%ne12; - int i13 = im/ne12; - - // The number of scales is the same as the number of blocks. - ulong offset0_d = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); - // Each block contains QK4_0/2 uchars, hence offset for qs is as follows. - ulong offset0_q = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); - - global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1; - - int ix = get_sub_group_local_id()/THREADS_PER_BLK; - int il = get_sub_group_local_id()%THREADS_PER_BLK; - - global float * yb = y + ix*QK4_0 + BTYES_PER_THREAD_IN_BLK*il; - - // Registers for caching activation - ACT_TY yl = 0.f; - - // Registers for caching quants - Q_BLK_LD_TY q_blk_0 = 0, q_blk_1 = 0; -#if N_DST == 4 || N_DST == 8 || N_DST == 16 - Q_BLK_LD_TY q_blk_2 = 0, q_blk_3 = 0; -#endif -#if N_DST == 8 || N_DST == 16 - Q_BLK_LD_TY q_blk_4 = 0, q_blk_5 = 0, q_blk_6 = 0, q_blk_7 = 0; -#endif - - // Partial sum - SUM_TY sumf = 0.f; - - for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/THREADS_PER_BLK) { - float sumy = 0.f;; - - float4 tmp; - tmp = read_imagef(src0_q, offset0_q + ib + 0*nb); - q_blk_0 = EXTRACT_BLK_DATA(tmp, il); - tmp = read_imagef(src0_q, offset0_q + ib + 1*nb); - q_blk_1 = EXTRACT_BLK_DATA(tmp, il); -#if N_DST == 4 || N_DST == 8 || N_DST == 16 - tmp = read_imagef(src0_q, offset0_q + ib + 2*nb); - q_blk_2 = EXTRACT_BLK_DATA(tmp, il); - tmp = read_imagef(src0_q, offset0_q + ib + 3*nb); - q_blk_3 = EXTRACT_BLK_DATA(tmp, il); -#endif -#if N_DST == 8 || N_DST == 16 - tmp = read_imagef(src0_q, offset0_q + ib + 4*nb); - q_blk_4 = EXTRACT_BLK_DATA(tmp, il); - tmp = read_imagef(src0_q, offset0_q + ib + 5*nb); - q_blk_5 = EXTRACT_BLK_DATA(tmp, il); - tmp = read_imagef(src0_q, offset0_q + ib + 6*nb); - q_blk_6 = EXTRACT_BLK_DATA(tmp, il); - tmp = read_imagef(src0_q, offset0_q + ib + 7*nb); - q_blk_7 = EXTRACT_BLK_DATA(tmp, il); -#endif - - // Load activation -#if THREADS_PER_BLK == 2 // Each thread processes 1/2 block - yl.s01234567 = *(global float8 *)(yb); - yl.s89abcdef = *(global float8 *)(yb + 16); - - sumy += yl.s0; - sumy += yl.s1; - sumy += yl.s2; - sumy += yl.s3; - sumy += yl.s4; - sumy += yl.s5; - sumy += yl.s6; - sumy += yl.s7; - sumy += yl.s8; yl.s8 /= 16.f; - sumy += yl.s9; yl.s9 /= 16.f; - sumy += yl.sa; yl.sa /= 16.f; - sumy += yl.sb; yl.sb /= 16.f; - sumy += yl.sc; yl.sc /= 16.f; - sumy += yl.sd; yl.sd /= 16.f; - sumy += yl.se; yl.se /= 16.f; - sumy += yl.sf; yl.sf /= 16.f; -#elif THREADS_PER_BLK == 4 // Each thread processes 1/4 block - yl.s0123 = *(global float4 *)(yb); - yl.s4567 = *(global float4 *)(yb + 16); - - sumy += yl.s0; - sumy += yl.s1; - sumy += yl.s2; - sumy += yl.s3; - sumy += yl.s4; yl.s4 /= 16.f; - sumy += yl.s5; yl.s5 /= 
16.f; - sumy += yl.s6; yl.s6 /= 16.f; - sumy += yl.s7; yl.s7 /= 16.f; -#elif THREADS_PER_BLK == 8 // Each thread processes 1/8 block - yl.s01 = *(global float2 *)(yb); - yl.s23 = *(global float2 *)(yb + 16); - - sumy += yl.s0; - sumy += yl.s1; - sumy += yl.s2; yl.s2 /= 16.f; - sumy += yl.s3; yl.s3 /= 16.f; -#endif - - sumf.s0 += block_q_4_0_dot_y_flat(q_blk_0, read_imageh(src0_d, offset0_d + ib + 0*nb).s0, sumy, yl); - sumf.s1 += block_q_4_0_dot_y_flat(q_blk_1, read_imageh(src0_d, offset0_d + ib + 1*nb).s0, sumy, yl); -#if N_DST == 4 || N_DST == 8 || N_DST == 16 - sumf.s2 += block_q_4_0_dot_y_flat(q_blk_2, read_imageh(src0_d, offset0_d + ib + 2*nb).s0, sumy, yl); - sumf.s3 += block_q_4_0_dot_y_flat(q_blk_3, read_imageh(src0_d, offset0_d + ib + 3*nb).s0, sumy, yl); -#endif -#if N_DST == 8 || N_DST == 16 - sumf.s4 += block_q_4_0_dot_y_flat(q_blk_4, read_imageh(src0_d, offset0_d + ib + 4*nb).s0, sumy, yl); - sumf.s5 += block_q_4_0_dot_y_flat(q_blk_5, read_imageh(src0_d, offset0_d + ib + 5*nb).s0, sumy, yl); - sumf.s6 += block_q_4_0_dot_y_flat(q_blk_6, read_imageh(src0_d, offset0_d + ib + 6*nb).s0, sumy, yl); - sumf.s7 += block_q_4_0_dot_y_flat(q_blk_7, read_imageh(src0_d, offset0_d + ib + 7*nb).s0, sumy, yl); -#endif - - yb += QK4_0 * (N_SIMDWIDTH/THREADS_PER_BLK); - } - - SUM_TY tot = (SUM_TY)( - sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1) -#if N_DST == 4 || N_DST == 8 || N_DST == 16 - , sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3) -#endif -#if N_DST == 8 || N_DST == 16 - , sub_group_reduce_add(sumf.s4), sub_group_reduce_add(sumf.s5) - , sub_group_reduce_add(sumf.s6), sub_group_reduce_add(sumf.s7) -#endif - ); - - if (get_sub_group_local_id() == 0) { - if (first_row + 0 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0; - } - if (first_row + 1 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1; - } -#if N_DST == 4 || N_DST == 8 || N_DST == 16 - if (first_row + 2 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2; - } - if (first_row + 3 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3; - } -#endif -#if N_DST == 8 || N_DST == 16 - if (first_row + 4 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 4] = tot.s4; - } - if (first_row + 5 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 5] = tot.s5; - } - if (first_row + 6 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 6] = tot.s6; - } - if (first_row + 7 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 7] = tot.s7; - } -#endif - } -} - -//------------------------------------------------------------------------------ -// kernel_mul_mv_q6_K_f32 -//------------------------------------------------------------------------------ - -#undef N_DST -#undef N_SIMDGROUP -#undef N_SIMDWIDTH - -#ifdef INTEL_GPU -#define N_DST 1 // number of rows each SIMD group works on -#define N_SIMDGROUP 2 // number of SIMD groups in a thread group -#define N_SIMDWIDTH 16 // SIMD group size -#elif defined (ADRENO_GPU) -#define N_DST 1 -#define N_SIMDGROUP 2 -#define N_SIMDWIDTH 64 -#endif - -#define BLOCK_STRIDE (N_SIMDWIDTH/16) // number of blocks each subgroup processes - -#ifdef INTEL_GPU -REQD_SUBGROUP_SIZE_16 -#elif defined (ADRENO_GPU) -REQD_SUBGROUP_SIZE_64 -#endif -kernel void kernel_mul_mv_q6_K_f32( - global void * src0, - ulong offset0, - global float * src1, - ulong offset1, - global float * dst, - ulong offsetd, - int ne00, - int ne01, - int ne02, - int ne10, - int ne12, - int ne0, - int ne1, - int r2, - int r3 -) { - src0 = (global void*)((global char*)src0 + offset0); - src1 = 
(global float*)((global char*)src1 + offset1); - dst = (global float*)((global char*)dst + offsetd); - - uchar kmask1 = 0x03; - uchar kmask2 = 0x0C; - uchar kmask3 = 0x30; - uchar kmask4 = 0xC0; - - int nb = ne00/QK_K; - - int r0 = get_group_id(0); - int r1 = get_group_id(1); - int im = get_group_id(2); - - int row = N_SIMDGROUP * r0 + get_sub_group_id(); - - int i12 = im%ne12; - int i13 = im/ne12; - - ulong offset_src0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); - - global block_q6_K * x = (global block_q6_K *) src0 + row*nb + offset_src0; - global float * yy = (global float *) src1 + r1*ne10 + im*ne00*ne1; - - float sumf = 0; - - // For Q6_K quantization, 16 values forms a subblock, 16 subblock forms a - // block. Values in a subblock shares a scale that is quantized with 8 bits; - // the entire block shares a single floating point scale. - // For work distribution, each thread processes a subblock (16 weights), hence - // 16 threads process a (super) block -- a subgroup thus handles SIMDWIDTH/16 - // (super) blocks -- this is the block stride. - // The 16 threads that process a (super) block are split into 2 portions, each has - // 8 threads; each portion works on 8 subblocks. - // For subgroup of 16 threads, the entire subgroup works on a single (super) block - // before moving to the next (super) block. Thread0 - thread7 work on the - // first 8 subblocks; thread8 - thread15 works on the last 8 subblocks. - // Thread0 - thread3 work on subblocks 0, 2, 4, 6; thread4 - thread7 work on - // subblocks 1, 3, 5, 7. Each thread does not work on an entire subblock, but - // works on a total of 16 weight values. - int tid = get_sub_group_local_id()/BLOCK_STRIDE; // first block_stride groups have tid=0 - int ix = get_sub_group_local_id()%BLOCK_STRIDE; // first block is 0..block_stride-1 - int ip = tid/8; // first or second half of (super) block (0 or 1) - int il = tid%8; // each half has 8 parts, one per scale - int n = 4; // 4 scales at a time (and 4 sums) - int l0 = n*il; // offset into half-block, 0..28 - int is = 8*ip + l0/16; // 0, 1, 8, 9 - - int y_offset = 128*ip + l0; - int q_offset_l = 64*ip + l0; - int q_offset_h = 32*ip + l0; - - for (int i = ix; i < nb; i += BLOCK_STRIDE) { - - global uint8_t * q1 = x[i].ql + q_offset_l; - global uint8_t * q2 = q1 + QK_K/8; - global uint8_t * qh = x[i].qh + q_offset_h; - global int8_t * sc = x[i].scales + is; - - global float * y = yy + i * QK_K + y_offset; - - float dall = x[i].d; - - float4 sums = {0.f, 0.f, 0.f, 0.f}; - - sums.s0 += y[0+ 0] * ((float)((q1[0] & 0xF) | ((qh[0] & kmask1) << 4)) - 32.f); - sums.s1 += y[0+32] * ((float)((q2[0] & 0xF) | ((qh[0] & kmask2) << 2)) - 32.f); - sums.s2 += y[0+64] * ((float)((q1[0] >> 4) | ((qh[0] & kmask3) << 0)) - 32.f); - sums.s3 += y[0+96] * ((float)((q2[0] >> 4) | ((qh[0] & kmask4) >> 2)) - 32.f); - - sums.s0 += y[1+ 0] * ((float)((q1[1] & 0xF) | ((qh[1] & kmask1) << 4)) - 32.f); - sums.s1 += y[1+32] * ((float)((q2[1] & 0xF) | ((qh[1] & kmask2) << 2)) - 32.f); - sums.s2 += y[1+64] * ((float)((q1[1] >> 4) | ((qh[1] & kmask3) << 0)) - 32.f); - sums.s3 += y[1+96] * ((float)((q2[1] >> 4) | ((qh[1] & kmask4) >> 2)) - 32.f); - - sums.s0 += y[2+ 0] * ((float)((q1[2] & 0xF) | ((qh[2] & kmask1) << 4)) - 32.f); - sums.s1 += y[2+32] * ((float)((q2[2] & 0xF) | ((qh[2] & kmask2) << 2)) - 32.f); - sums.s2 += y[2+64] * ((float)((q1[2] >> 4) | ((qh[2] & kmask3) << 0)) - 32.f); - sums.s3 += y[2+96] * ((float)((q2[2] >> 4) | ((qh[2] & kmask4) >> 2)) - 32.f); - - sums.s0 += y[3+ 0] * ((float)((q1[3] & 0xF) | ((qh[3] & 
kmask1) << 4)) - 32.f); - sums.s1 += y[3+32] * ((float)((q2[3] & 0xF) | ((qh[3] & kmask2) << 2)) - 32.f); - sums.s2 += y[3+64] * ((float)((q1[3] >> 4) | ((qh[3] & kmask3) << 0)) - 32.f); - sums.s3 += y[3+96] * ((float)((q2[3] >> 4) | ((qh[3] & kmask4) >> 2)) - 32.f); - - sumf += dall * (sums.s0 * sc[0] + sums.s1 * sc[2] + sums.s2 * sc[4] + sums.s3 * sc[6]); - } - - float tot = sub_group_reduce_add(sumf); - if (get_sub_group_local_id() == 0) { - dst[r1*ne0 + im*ne0*ne1 + row] = tot; - } -} diff --git a/ggml/src/ggml-opencl/kernels/ggml-opencl_transpose_16.cl b/ggml/src/ggml-opencl/kernels/ggml-opencl_transpose_16.cl deleted file mode 100644 index cd4e0afbad..0000000000 --- a/ggml/src/ggml-opencl/kernels/ggml-opencl_transpose_16.cl +++ /dev/null @@ -1,26 +0,0 @@ -// 16-bit transpose, loading/storing a 4x4 tile of elements - -#pragma OPENCL EXTENSION cl_khr_fp16 : enable - -kernel void kernel_transpose_16( - __read_only image1d_buffer_t input, - __write_only image1d_buffer_t output, - const uint rows, - const uint cols -) { - - const int i = get_global_id(0); - const int j = get_global_id(1); - const int i_2 = i<<2; - const int j_2 = j<<2; - - half4 temp0 = read_imageh(input, (j_2+0)*cols+i); - half4 temp1 = read_imageh(input, (j_2+1)*cols+i); - half4 temp2 = read_imageh(input, (j_2+2)*cols+i); - half4 temp3 = read_imageh(input, (j_2+3)*cols+i); - - write_imageh(output, (i_2+0)*rows+j, (half4)(temp0.s0, temp1.s0, temp2.s0, temp3.s0)); - write_imageh(output, (i_2+1)*rows+j, (half4)(temp0.s1, temp1.s1, temp2.s1, temp3.s1)); - write_imageh(output, (i_2+2)*rows+j, (half4)(temp0.s2, temp1.s2, temp2.s2, temp3.s2)); - write_imageh(output, (i_2+3)*rows+j, (half4)(temp0.s3, temp1.s3, temp2.s3, temp3.s3)); -} diff --git a/ggml/src/ggml-opencl/kernels/ggml-opencl_transpose_32.cl b/ggml/src/ggml-opencl/kernels/ggml-opencl_transpose_32.cl deleted file mode 100644 index 914ec0193e..0000000000 --- a/ggml/src/ggml-opencl/kernels/ggml-opencl_transpose_32.cl +++ /dev/null @@ -1,25 +0,0 @@ -// 32-bit transpose, loading/storing a 4x4 tile of elements - -kernel void kernel_transpose_32( - __read_only image1d_buffer_t input, - __write_only image1d_buffer_t output, - const uint rows, - const uint cols -) { - - const int i = get_global_id(0); - const int j = get_global_id(1); - const int i_2 = i<<2; - const int j_2 = j<<2; - - float4 temp0 = read_imagef(input, (j_2+0)*cols+i); - float4 temp1 = read_imagef(input, (j_2+1)*cols+i); - float4 temp2 = read_imagef(input, (j_2+2)*cols+i); - float4 temp3 = read_imagef(input, (j_2+3)*cols+i); - - write_imagef(output, (i_2+0)*rows+j, (float4)(temp0.s0, temp1.s0, temp2.s0, temp3.s0)); - write_imagef(output, (i_2+1)*rows+j, (float4)(temp0.s1, temp1.s1, temp2.s1, temp3.s1)); - write_imagef(output, (i_2+2)*rows+j, (float4)(temp0.s2, temp1.s2, temp2.s2, temp3.s2)); - write_imagef(output, (i_2+3)*rows+j, (float4)(temp0.s3, temp1.s3, temp2.s3, temp3.s3)); - -} diff --git a/ggml/src/ggml-opencl/kernels/ggml-opencl_transpose_32_16.cl b/ggml/src/ggml-opencl/kernels/ggml-opencl_transpose_32_16.cl deleted file mode 100644 index d3bd1fabb7..0000000000 --- a/ggml/src/ggml-opencl/kernels/ggml-opencl_transpose_32_16.cl +++ /dev/null @@ -1,35 +0,0 @@ -// 32-bit transpose, loading/storing a 4x4 tile of elements -// Only used for activations -// converts to FP16 -// also adds zero padding for non multiple of 8 prompt lengths -#pragma OPENCL EXTENSION cl_khr_fp16 : enable - -kernel void kernel_transpose_32_16(__read_only image1d_buffer_t input, __write_only image1d_buffer_t output, const uint 
rows, const uint cols, const uint padded_rows) { - - const int i = get_global_id(0); - const int j = get_global_id(1); - const int i_2 = i<<2; - const int j_2 = j<<2; - half4 temp0 = {0,0,0,0}; // initialize outputs to 0 - half4 temp1 = {0,0,0,0}; - half4 temp2 = {0,0,0,0}; - half4 temp3 = {0,0,0,0}; - - if((j_2+0)*cols+i*4+3 < rows*cols*16){ // only load from a valid location. Otherwise keep register data as 0 - temp0 = read_imageh(input, (j_2+0)*cols+i); - } - if((j_2+1)*cols+i*4+3 < rows*cols*16){ - temp1 = read_imageh(input, (j_2+1)*cols+i); - } - if((j_2+2)*cols+i*4+3 < rows*cols*16){ - temp2 = read_imageh(input, (j_2+2)*cols+i); - } - if((j_2+3)*cols+i*4+3 < rows*cols*16){ - temp3 = read_imageh(input, (j_2+3)*cols+i); - } - - write_imageh(output, (i_2+0)*padded_rows+j, (half4)(temp0.s0, temp1.s0, temp2.s0, temp3.s0)); // no conditionals for output, includes zero padding - write_imageh(output, (i_2+1)*padded_rows+j, (half4)(temp0.s1, temp1.s1, temp2.s1, temp3.s1)); - write_imageh(output, (i_2+2)*padded_rows+j, (half4)(temp0.s2, temp1.s2, temp2.s2, temp3.s2)); - write_imageh(output, (i_2+3)*padded_rows+j, (half4)(temp0.s3, temp1.s3, temp2.s3, temp3.s3)); -} diff --git a/ggml/src/ggml-opencl/kernels/im2col_f16.cl b/ggml/src/ggml-opencl/kernels/im2col_f16.cl new file mode 100644 index 0000000000..b84c898465 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/im2col_f16.cl @@ -0,0 +1,57 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +kernel void kernel_im2col_f16( + global float * src1, + ulong offset1, + global half * dst, + ulong offsetd, + ulong batch_offset, + ulong delta_offset, + long IW, + long IH, + long IC, + long OW, + long OH, + long KW, + long KH, + long pelements, + long CHW, + int s0, + int s1, + int p0, + int p1, + int d0, + int d1 +) { + long i = get_global_id(0); + if (i >= pelements) { + return; + } + + src1 = (global float*)((global char*)src1 + offset1); + dst = (global half*)((global char*)dst + offsetd); + + long ksize = OW * (KH > 1 ? KW : 1); + long kx = i / ksize; + long kd = kx * ksize; + long ky = (i - kd) / OW; + long ix = i % OW; + + long oh = get_group_id(1); + long batch = get_group_id(2) / IC; + long ic = get_group_id(2) % IC; + + long iiw = ix * s0 + kx * d0 - p0; + long iih = oh * s1 + ky * d1 - p1; + + long offset_dst = + ((batch * OH + oh) * OW + ix) * CHW + + (ic * (KW * KH) + ky * KW + kx); + + if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) { + dst[offset_dst] = 0.0f; + } else { + long offset_src = ic * delta_offset + batch * batch_offset; + dst[offset_dst] = src1[offset_src + iih * IW + iiw]; + } +} diff --git a/ggml/src/ggml-opencl/kernels/im2col_f32.cl b/ggml/src/ggml-opencl/kernels/im2col_f32.cl new file mode 100644 index 0000000000..4bf65e4eaa --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/im2col_f32.cl @@ -0,0 +1,57 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +kernel void kernel_im2col_f32( + global float * src1, + ulong offset1, + global float * dst, + ulong offsetd, + ulong batch_offset, + ulong delta_offset, + long IW, + long IH, + long IC, + long OW, + long OH, + long KW, + long KH, + long pelements, + long CHW, + int s0, + int s1, + int p0, + int p1, + int d0, + int d1 +) { + long i = get_global_id(0); + if (i >= pelements) { + return; + } + + src1 = (global float*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + long ksize = OW * (KH > 1 ? 
KW : 1); + long kx = i / ksize; + long kd = kx * ksize; + long ky = (i - kd) / OW; + long ix = i % OW; + + long oh = get_group_id(1); + long batch = get_group_id(2) / IC; + long ic = get_group_id(2) % IC; + + long iiw = ix * s0 + kx * d0 - p0; + long iih = oh * s1 + ky * d1 - p1; + + long offset_dst = + ((batch * OH + oh) * OW + ix) * CHW + + (ic * (KW * KH) + ky * KW + kx); + + if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) { + dst[offset_dst] = 0.0f; + } else { + long offset_src = ic * delta_offset + batch * batch_offset; + dst[offset_dst] = src1[offset_src + iih * IW + iiw]; + } +} diff --git a/ggml/src/ggml-opencl/kernels/mul.cl b/ggml/src/ggml-opencl/kernels/mul.cl new file mode 100644 index 0000000000..2a2b4eb70a --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul.cl @@ -0,0 +1,79 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +//------------------------------------------------------------------------------ +// mul +//------------------------------------------------------------------------------ +kernel void kernel_mul( + global char * src0, + ulong offset0, + global char * src1, + ulong offset1, + global char * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne10, + int ne11, + int ne12, + int ne13, + ulong nb10, + ulong nb11, + ulong nb12, + ulong nb13, + int ne0, + int ne1, + int ne2, + int ne3, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3 +) { + src0 = src0 + offset0; + src1 = src1 + offset1; + dst = dst + offsetd; + + int i03 = get_group_id(2); + int i02 = get_group_id(1); + int i01 = get_group_id(0); + + int i13 = i03 % ne13; + int i12 = i02 % ne12; + int i11 = i01 % ne11; + + global char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01; + global char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11; + global char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1; + + for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) { + const int i10 = i0 % ne10; + *((global float *)(dst_ptr + i0*nb0)) = *((global float *)(src0_ptr + i0*nb00)) * *((global float *)(src1_ptr + i10*nb10)); + } +} + +// assumption: src1 is a row +// broadcast src1 into src0 +kernel void kernel_mul_row( + global float4 * src0, + ulong offset0, + global float4 * src1, + ulong offset1, + global float4 * dst, + ulong offsetd, + int ne +) { + src0 = (global float4*)((global char*)src0 + offset0); + src1 = (global float4*)((global char*)src1 + offset1); + dst = (global float4*)((global char*)dst + offsetd); + + // This performs better than using %. 
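+    // Worked example of the mod-free indexing (an illustration, not part of
+    // the original kernel): with ne = 4 and gid = 10, gid/ne = 2, so
+    // idx1 = 10 - 2*4 = 2, which is exactly 10 % 4.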
+ uint gid = get_global_id(0); + uint idx1 = gid - (gid/ne)*ne; // get_global_id(0) % ne + dst[gid] = src0[gid] * src1[idx1]; +} diff --git a/ggml/src/ggml-opencl/kernels/ggml-opencl_mul_mat_Ab_Bi_8x4.cl b/ggml/src/ggml-opencl/kernels/mul_mat_Ab_Bi_8x4.cl similarity index 100% rename from ggml/src/ggml-opencl/kernels/ggml-opencl_mul_mat_Ab_Bi_8x4.cl rename to ggml/src/ggml-opencl/kernels/mul_mat_Ab_Bi_8x4.cl diff --git a/ggml/src/ggml-opencl/kernels/mul_mv_f16_f16.cl b/ggml/src/ggml-opencl/kernels/mul_mv_f16_f16.cl new file mode 100644 index 0000000000..9393b54941 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mv_f16_f16.cl @@ -0,0 +1,118 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#define N_F16_F16 4 + +#ifdef ADRENO_GPU +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_mul_mat_f16_f16( + global char * src0, + ulong offset0, + global char * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne10, + int ne11, + int ne12, + ulong nb10, + ulong nb11, + ulong nb12, + ulong nb13, + int ne0, + int ne1, + int r2, + int r3) +{ + src0 = (global char*)((global char*)src0 + offset0); + src1 = (global char*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + int r0 = get_group_id(0); + int rb = get_group_id(1)*N_F16_F16; + int im = get_group_id(2); + + int i12 = im%ne12; + int i13 = im/ne12; + + ulong offset_src0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; + + global half * x = (global half *) (src0 + offset_src0); + + if (ne00 < 128) { + for (int row = 0; row < N_F16_F16; ++row) { + int r1 = rb + row; + if (r1 >= ne11) { + break; + } + + ulong offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + + global half * y = (global half *) (src1 + offset_src1); + + float sumf = 0; + for (int i = get_sub_group_local_id(); i < ne00; i += get_max_sub_group_size()) { + sumf += (half) x[i] * (half) y[i]; + } + + float all_sum = sub_group_reduce_add(sumf); + if (get_sub_group_local_id() == 0) { + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + } + } + } else { + global half4 * x4 = (global half4 *)x; + for (int row = 0; row < N_F16_F16; ++row) { + int r1 = rb + row; + if (r1 >= ne11) { + break; + } + + ulong offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + + global half * y = (global half *) (src1 + offset_src1); + global half4 * y4 = (global half4 *) y; + + float sumf = 0; + for (int i = get_sub_group_local_id(); i < ne00/4; i += get_max_sub_group_size()) { + sumf += (half) x4[i].s0 * y4[i].s0; + sumf += (half) x4[i].s1 * y4[i].s1; + sumf += (half) x4[i].s2 * y4[i].s2; + sumf += (half) x4[i].s3 * y4[i].s3; + } + + float all_sum = sub_group_reduce_add(sumf); + if (get_sub_group_local_id() == 0) { + for (int i = 4*(ne00/4); i < ne00; 
++i) { + all_sum += (half) x[i] * y[i]; + } + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + } + } + } +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32.cl b/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32.cl new file mode 100644 index 0000000000..e52d3c6d47 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32.cl @@ -0,0 +1,118 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#define N_F16_F32 4 + +#ifdef ADRENO_GPU +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_mul_mat_f16_f32( + global char * src0, + ulong offset0, + global char * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne10, + int ne11, + int ne12, + ulong nb10, + ulong nb11, + ulong nb12, + ulong nb13, + int ne0, + int ne1, + int r2, + int r3 +) { + src0 = (global char*)((global char*)src0 + offset0); + src1 = (global char*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + int r0 = get_group_id(0); + int rb = get_group_id(1)*N_F16_F32; + int im = get_group_id(2); + + int i12 = im%ne12; + int i13 = im/ne12; + + ulong offset_src0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; + + global half * x = (global half *) (src0 + offset_src0); + + if (ne00 < 128) { + for (int row = 0; row < N_F16_F32; ++row) { + int r1 = rb + row; + if (r1 >= ne11) { + break; + } + + ulong offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + + global float * y = (global float *) (src1 + offset_src1); + + float sumf = 0; + for (int i = get_sub_group_local_id(); i < ne00; i += get_max_sub_group_size()) { + sumf += convert_float(x[i]) * y[i]; + } + + float all_sum = sub_group_reduce_add(sumf); + if (get_sub_group_local_id() == 0) { + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + } + } + } else { + global half4 * x4 = (global half4 *)x; + for (int row = 0; row < N_F16_F32; ++row) { + int r1 = rb + row; + if (r1 >= ne11) { + break; + } + + ulong offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + + global float * y = (global float *) (src1 + offset_src1); + global float4 * y4 = (global float4 *) y; + + float sumf = 0; + for (int i = get_sub_group_local_id(); i < ne00/4; i += get_max_sub_group_size()) { + sumf += convert_float(x4[i].s0) * y4[i].s0; + sumf += convert_float(x4[i].s1) * y4[i].s1; + sumf += convert_float(x4[i].s2) * y4[i].s2; + sumf += convert_float(x4[i].s3) * y4[i].s3; + } + + float all_sum = sub_group_reduce_add(sumf); + if (get_sub_group_local_id() == 0) { + for (int i = 4*(ne00/4); i < ne00; ++i) { + all_sum += (float) x[i] * y[i]; + } + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + } + } + } +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32_1row.cl b/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32_1row.cl new file mode 100644 index 0000000000..28d30212cd 
--- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32_1row.cl @@ -0,0 +1,94 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#ifdef ADRENO_GPU +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_mul_mat_f16_f32_1row( + global char * src0, + ulong offset0, + global char * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne10, + int ne11, + int ne12, + ulong nb10, + ulong nb11, + ulong nb12, + ulong nb13, + int ne0, + int ne1, + int r2, + int r3 +) { + src0 = (global char*)((global char*)src0 + offset0); + src1 = (global char*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + int r0 = get_group_id(0); + int r1 = get_group_id(1); + int im = get_group_id(2); + + int i12 = im%ne12; + int i13 = im/ne12; + + ulong offset_src0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; + ulong offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + + global half * x = (global half *) (src0 + offset_src0); + global float * y = (global float *) (src1 + offset_src1); + + float sumf = 0; + if (ne00 < 128) { + for (int i = get_sub_group_local_id(); i < ne00; i += get_max_sub_group_size()) { + sumf += (float) x[i] * (float) y[i]; + } + float all_sum = sub_group_reduce_add(sumf); + if (get_sub_group_local_id() == 0) { + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + } + } else { + global half4 * x4 = (global half4 *) x; + global float4 * y4 = (global float4 *) y; + for (int i = get_sub_group_local_id(); i < ne00/4; i += get_max_sub_group_size()) { + sumf += (float) x4[i].s0 * y4[i].s0; + sumf += (float) x4[i].s1 * y4[i].s1; + sumf += (float) x4[i].s2 * y4[i].s2; + sumf += (float) x4[i].s3 * y4[i].s3; + } + float all_sum = sub_group_reduce_add(sumf); + if (get_sub_group_local_id() == 0) { + for (int i = 4*(ne00/4); i < ne00; ++i) { + all_sum += (float) x[i] * y[i]; + } + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + } + } + +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32_l4.cl b/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32_l4.cl new file mode 100644 index 0000000000..cdf8197c47 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mv_f16_f32_l4.cl @@ -0,0 +1,84 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : 
enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +// Assumes row size (ne00) is a multiple of 4 +#ifdef ADRENO_GPU +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_mul_mat_f16_f32_l4( + global char * src0, + ulong offset0, + global char * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne10, + int ne11, + int ne12, + ulong nb10, + ulong nb11, + ulong nb12, + ulong nb13, + int ne0, + int ne1, + int r2, + int r3 +) { + src0 = (global char*)((global char*)src0 + offset0); + src1 = (global char*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + int nrows = ne11; + int r0 = get_group_id(0); + int im = get_group_id(2); + + int i12 = im%ne12; + int i13 = im/ne12; + + ulong offset_src0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; + + global half4 * x4 = (global half4 *) (src0 + offset_src0); + + for (int r1 = 0; r1 < nrows; ++r1) { + ulong offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + + global float4 * y4 = (global float4 *) (src1 + offset_src1); + + float sumf = 0; + for (int i = get_sub_group_local_id(); i < ne00/4; i += get_max_sub_group_size()) { + sumf += convert_float(x4[i].s0) * y4[i].s0; + sumf += convert_float(x4[i].s1) * y4[i].s1; + sumf += convert_float(x4[i].s2) * y4[i].s2; + sumf += convert_float(x4[i].s3) * y4[i].s3; + } + + float all_sum = sub_group_reduce_add(sumf); + if (get_sub_group_local_id() == 0) { + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + } + } +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mv_f32_f32.cl b/ggml/src/ggml-opencl/kernels/mul_mv_f32_f32.cl new file mode 100644 index 0000000000..ec71b87565 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mv_f32_f32.cl @@ -0,0 +1,118 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#define N_F32_F32 4 + +#ifdef ADRENO_GPU +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_mul_mat_f32_f32( + global char * src0, + ulong offset0, + global char * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne10, + int ne11, + int ne12, + ulong nb10, + ulong nb11, + ulong nb12, + ulong nb13, + int ne0, + int ne1, + int r2, + int r3 +) { + src0 = (global char*)((global char*)src0 + offset0); + src1 = (global char*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + int r0 = get_group_id(0); + int rb = get_group_id(1)*N_F32_F32; + int im = get_group_id(2); + + int i12 = im%ne12; + int i13 = im/ne12; + + ulong offset_src0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; 
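+    // r2 and r3 are the broadcast ratios along dims 2 and 3, so i12/r2 and
+    // i13/r3 map a src1 batch index back onto the src0 slice it shares
+    // (several src1 batches can reuse one src0 matrix, as in grouped-query
+    // attention).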
+ + global float * x = (global float *) (src0 + offset_src0); + + if (ne00 < 128) { + for (int row = 0; row < N_F32_F32; ++row) { + int r1 = rb + row; + if (r1 >= ne11) { + break; + } + + ulong offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + + global float * y = (global float *) (src1 + offset_src1); + + float sumf = 0; + for (int i = get_sub_group_local_id(); i < ne00; i += get_max_sub_group_size()) { + sumf += (float) x[i] * (float) y[i]; + } + + float all_sum = sub_group_reduce_add(sumf); + if (get_sub_group_local_id() == 0) { + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + } + } + } else { + global float4 * x4 = (global float4 *)x; + for (int row = 0; row < N_F32_F32; ++row) { + int r1 = rb + row; + if (r1 >= ne11) { + break; + } + + ulong offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + + global float * y = (global float *) (src1 + offset_src1); + global float4 * y4 = (global float4 *) y; + + float sumf = 0; + for (int i = get_sub_group_local_id(); i < ne00/4; i += get_max_sub_group_size()) { + sumf += (float) x4[i].s0 * y4[i].s0; + sumf += (float) x4[i].s1 * y4[i].s1; + sumf += (float) x4[i].s2 * y4[i].s2; + sumf += (float) x4[i].s3 * y4[i].s3; + } + + float all_sum = sub_group_reduce_add(sumf); + if (get_sub_group_local_id() == 0) { + for (int i = 4*(ne00/4); i < ne00; ++i) { + all_sum += (float) x[i] * y[i]; + } + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + } + } + } +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32.cl b/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32.cl new file mode 100644 index 0000000000..52141e0ed5 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32.cl @@ -0,0 +1,192 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#define QK4_0 32 +#define QR4_0 2 +#define QK4_1 32 +#define QR4_1 2 +#define QK5_0 32 +#define QR5_0 2 +#define QK5_1 32 +#define QR5_1 2 +#define QK8_0 32 +#define QR8_0 1 +#define QK_K 256 +#define K_QUANTS_PER_ITERATION 2 + +typedef char int8_t; +typedef uchar uint8_t; +typedef short int16_t; +typedef ushort uint16_t; +typedef int int32_t; +typedef uint uint32_t; + +//------------------------------------------------------------------------------ +// block_q4_0 +//------------------------------------------------------------------------------ +struct block_q4_0 +{ + half d; + uint8_t qs[QK4_0 / 2]; +}; + +//------------------------------------------------------------------------------ +// mul_vec_q_n_f32 +//------------------------------------------------------------------------------ +// function for calculate inner product between half a q4_0 block and 16 floats (yl), sumy is SUM(yl[i]) +// il indicates where the q4 quants begin (0 or QK4_0/4) +// we assume that the yl's have been multiplied with the appropriate scale factor +// that corresponds to the missing bit shifts (1, 1/16, 1/256, 
1/4096) +inline float block_q_4_0_dot_y( + global struct block_q4_0 * qb_curr, + float sumy, + private float * yl, + int il +) { + float d = qb_curr->d; + float2 acc = 0.f; + global ushort * qs = ((global ushort *)qb_curr + 1 + il/2); + for (int i = 0; i < 8; i+=2) { + acc.s0 += yl[i + 0] * (qs[i / 2] & 0x000F) + + yl[i + 1] * (qs[i / 2] & 0x0F00); + acc.s1 += yl[i + 8] * (qs[i / 2] & 0x00F0) + + yl[i + 9] * (qs[i / 2] & 0xF000); + } + return d * (sumy * -8.f + acc.s0 + acc.s1); +} + +#ifdef INTEL_GPU +#define N_DST 4 // each SIMD group works on 4 rows +#define N_SIMDGROUP 1 // number of SIMD groups in a thread group +#define N_SIMDWIDTH 16 // assuming SIMD group size is 16 +#elif defined (ADRENO_GPU) +#define N_DST 4 +#define N_SIMDGROUP 1 +#define N_SIMDWIDTH 64 +#endif + +inline void mul_vec_q_n_f32( + global void * src0, + global float * src1, + global float * dst, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + + const ulong nb = ne00/QK4_0; + + int r0 = get_group_id(0); + int r1 = get_group_id(1); + int im = get_group_id(2); + + // (r0 * N_SIMDGROUP + get_sub_group_id()) is essentially the linear global + // id of a SIMD group in the grid. + int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST; + + int i12 = im%ne12; + int i13 = im/ne12; + + ulong offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + + global struct block_q4_0 * x = (global struct block_q4_0 *) src0 + offset0; + global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1; + + float yl[16]; // src1 vector cache + float sumf[N_DST]={0.f}; + + int ix = get_sub_group_local_id()/2; + int il = 8*(get_sub_group_local_id()%2); + + global float * yb = y + ix * QK4_0 + il; + + // each thread in a SIMD group deals with half a block. + for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) { + float sumy = 0; + for (int i = 0; i < 8; i += 2) { + sumy += yb[i] + yb[i+1]; + yl[i+0] = yb[i+ 0]; + yl[i+1] = yb[i+ 1]/256.f; + sumy += yb[i+16] + yb[i+17]; + yl[i+8] = yb[i+16]/16.f; + yl[i+9] = yb[i+17]/4096.f; + } + + for (int row = 0; row < N_DST; row++) { + sumf[row] += block_q_4_0_dot_y(x+ib+row*nb, sumy, yl, il); + } + + // One thread in a SIMD group (i.e., subgroup) handles a half block, + // hence the entire SIMD group handles SIMDWIDTH/2 blocks. + // y points to the activation matrix (of type float). Therefore for + // one thread, the # of blocks y should advance is SIMDWIDTH/2 (because + // SIMDWIDTH/2 blocks are processed by a SIMD group) - in terms of + // floats, it is QK4_0 * (SIMDWIDTH/2), where QK4_0 is the block size. + yb += QK4_0 * (N_SIMDWIDTH/2); + } + + // The above does not work for Adreno - it produces incorrect results for + // row = 1, 2, 3 and only row = 0 gives the correct result. + // If N_DST is changed, the below array must be initialized accordingly. + // This also seems to perform better on Intel.
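+    // Every work-item in the subgroup executes all four reductions below,
+    // so the partial sums are reduced with an explicit array initializer
+    // rather than inside a row loop (see the note above).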
+ float tot[N_DST] = { + sub_group_reduce_add(sumf[0]), sub_group_reduce_add(sumf[1]), + sub_group_reduce_add(sumf[2]), sub_group_reduce_add(sumf[3])}; + for (int row = 0; row < N_DST; ++row) { + if (get_sub_group_local_id() == 0 && first_row + row < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot[row]; + } + } +} + +#ifdef INTEL_GPU +REQD_SUBGROUP_SIZE_16 +#elif defined (ADRENO_GPU) +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_mul_mat_q4_0_f32( + global void * src0, + ulong offset0, + global float * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + src0 = (global void*)((global char*)src0 + offset0); + src1 = (global float*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + mul_vec_q_n_f32(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3); +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_1d_16x_flat.cl b/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_1d_16x_flat.cl new file mode 100644 index 0000000000..3eebab8f0f --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_1d_16x_flat.cl @@ -0,0 +1,307 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#define QK4_0 32 +#define QR4_0 2 +#define QK4_1 32 +#define QR4_1 2 +#define QK5_0 32 +#define QR5_0 2 +#define QK5_1 32 +#define QR5_1 2 +#define QK8_0 32 +#define QR8_0 1 +#define QK_K 256 +#define K_QUANTS_PER_ITERATION 2 + +typedef char int8_t; +typedef uchar uint8_t; +typedef short int16_t; +typedef ushort uint16_t; +typedef int int32_t; +typedef uint uint32_t; + +//------------------------------------------------------------------------------ +// block_q4_0 +//------------------------------------------------------------------------------ +struct block_q4_0 +{ + half d; + uint8_t qs[QK4_0 / 2]; +}; + +inline float mm_block_q_4_0_dot_y_flat( + global uchar * x, + global half * dh, + float sumy, + float16 yl, + int il +) { + float d = *dh; + global ushort * qs = ((global ushort *)x + il/2); + float acc = 0.f; + + acc += yl.s0 * (qs[0] & 0x000F); + acc += yl.s1 * (qs[0] & 0x0F00); + acc += yl.s8 * (qs[0] & 0x00F0); + acc += yl.s9 * (qs[0] & 0xF000); + + acc += yl.s2 * (qs[1] & 0x000F); + acc += yl.s3 * (qs[1] & 0x0F00); + acc += yl.sa * (qs[1] & 0x00F0); + acc += yl.sb * (qs[1] & 0xF000); + + acc += yl.s4 * (qs[2] & 0x000F); + acc += yl.s5 * (qs[2] & 0x0F00); + acc += yl.sc * (qs[2] & 0x00F0); + acc += yl.sd * (qs[2] & 0xF000); + + acc += yl.s6 * (qs[3] & 0x000F); + acc += yl.s7 * (qs[3] & 0x0F00); + acc += yl.se * (qs[3] & 0x00F0); + acc += yl.sf * (qs[3] & 0xF000); + + return d * (sumy * -8.f + acc); +} + +#ifdef INTEL_GPU +#define N_DST 16 // each SIMD group works on 16 rows (in weights matrix) +#define
N_SIMDGROUP 1 // number of SIMD groups in a thread group +#define N_SIMDWIDTH 16 // assuming SIMD group size is 16 +#elif defined (ADRENO_GPU) +#define N_DST 16 +#define N_SIMDGROUP 1 +#define N_SIMDWIDTH 64 +#endif +// +// This variant performs 1d blocking with 16x output. +// Each simdgroup outputs 16 values on `n0` dim (row in the output matrix). +// +inline void mul_mat_q_n_f32_1d_16x_flat( + global uchar * src0_q, + global half * src0_d, + global float * src1, + global float * dst, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + const int nb = ne00/QK4_0; + + int r0 = get_group_id(0); + int r1 = get_group_id(1); + int im = get_group_id(2); + + // (r0 * N_SIMDGROUP + get_sub_group_id()) is the linear global id of + // a SIMD group in the grid. Each SIMD group produces N_DST values in the + // result, hence uses nb blocks, i.e., the offset becomes first_row*nb. + // Currently with llama2 7B, im is always 0. + // TODO: how to handle im/gqa*(nb*ne0)? + int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST; + + int i12 = im%ne12; + int i13 = im/ne12; + + // The number of scales is the same as the number of blocks. + ulong offset0_d = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + // Each block contains QK4_0/2 uchars, hence offset for qs is as follows. + ulong offset0_q = (first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02)) * QK4_0/2; + + global uchar * x = (global uchar *) src0_q + offset0_q; + global half * d = (global half *) src0_d + offset0_d; + global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1; + + float16 yl; + float16 sumf = (float16)(0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f); + + int ix = get_sub_group_local_id()/2; + int il = 8*(get_sub_group_local_id()%2); + + global float * yb = y + ix*QK4_0 + il; + + for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) { + float sumy = 0.f; + + sumy += yb[0]; + sumy += yb[1]; + sumy += yb[2]; + sumy += yb[3]; + sumy += yb[4]; + sumy += yb[5]; + sumy += yb[6]; + sumy += yb[7]; + + sumy += yb[16]; + sumy += yb[17]; + sumy += yb[18]; + sumy += yb[19]; + sumy += yb[20]; + sumy += yb[21]; + sumy += yb[22]; + sumy += yb[23]; + + yl.s0 = yb[0]; + yl.s1 = yb[1]/256.f; + + yl.s2 = yb[2]; + yl.s3 = yb[3]/256.f; + + yl.s4 = yb[4]; + yl.s5 = yb[5]/256.f; + + yl.s6 = yb[6]; + yl.s7 = yb[7]/256.f; + + yl.s8 = yb[16]/16.f; + yl.s9 = yb[17]/4096.f; + + yl.sa = yb[18]/16.f; + yl.sb = yb[19]/4096.f; + + yl.sc = yb[20]/16.f; + yl.sd = yb[21]/4096.f; + + yl.se = yb[22]/16.f; + yl.sf = yb[23]/4096.f; + + sumf.s0 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 0*nb*QK4_0/2, d + ib + 0*nb, sumy, yl, il); + sumf.s1 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 1*nb*QK4_0/2, d + ib + 1*nb, sumy, yl, il); + sumf.s2 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 2*nb*QK4_0/2, d + ib + 2*nb, sumy, yl, il); + sumf.s3 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 3*nb*QK4_0/2, d + ib + 3*nb, sumy, yl, il); + + sumf.s4 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 4*nb*QK4_0/2, d + ib + 4*nb, sumy, yl, il); + sumf.s5 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 5*nb*QK4_0/2, d + ib + 5*nb, sumy, yl, il); + sumf.s6 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 6*nb*QK4_0/2, d + ib + 6*nb, sumy, yl, il); + sumf.s7 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 7*nb*QK4_0/2, d + ib + 7*nb, sumy, yl, il); + + sumf.s8 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 8*nb*QK4_0/2, d + ib + 8*nb, sumy, yl, il); + sumf.s9 +=
mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 9*nb*QK4_0/2, d + ib + 9*nb, sumy, yl, il); + sumf.sa += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 10*nb*QK4_0/2, d + ib + 10*nb, sumy, yl, il); + sumf.sb += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 11*nb*QK4_0/2, d + ib + 11*nb, sumy, yl, il); + + sumf.sc += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 12*nb*QK4_0/2, d + ib + 12*nb, sumy, yl, il); + sumf.sd += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 13*nb*QK4_0/2, d + ib + 13*nb, sumy, yl, il); + sumf.se += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 14*nb*QK4_0/2, d + ib + 14*nb, sumy, yl, il); + sumf.sf += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 15*nb*QK4_0/2, d + ib + 15*nb, sumy, yl, il); + + yb += QK4_0 * (N_SIMDWIDTH/2); + } + + float16 tot = (float16)( + sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1), + sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3), + sub_group_reduce_add(sumf.s4), sub_group_reduce_add(sumf.s5), + sub_group_reduce_add(sumf.s6), sub_group_reduce_add(sumf.s7), + + sub_group_reduce_add(sumf.s8), sub_group_reduce_add(sumf.s9), + sub_group_reduce_add(sumf.sa), sub_group_reduce_add(sumf.sb), + sub_group_reduce_add(sumf.sc), sub_group_reduce_add(sumf.sd), + sub_group_reduce_add(sumf.se), sub_group_reduce_add(sumf.sf) + ); + + if (get_sub_group_local_id() == 0) { + if (first_row + 0 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0; + } + if (first_row + 1 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1; + } + if (first_row + 2 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2; + } + if (first_row + 3 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3; + } + + if (first_row + 4 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 4] = tot.s4; + } + if (first_row + 5 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 5] = tot.s5; + } + if (first_row + 6 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 6] = tot.s6; + } + if (first_row + 7 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 7] = tot.s7; + } + + if (first_row + 8 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 8] = tot.s8; + } + if (first_row + 9 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 9] = tot.s9; + } + if (first_row + 10 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 10] = tot.sa; + } + if (first_row + 11 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 11] = tot.sb; + } + + if (first_row + 12 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 12] = tot.sc; + } + if (first_row + 13 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 13] = tot.sd; + } + if (first_row + 14 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 14] = tot.se; + } + if (first_row + 15 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 15] = tot.sf; + } + } +} + +#ifdef INTEL_GPU +REQD_SUBGROUP_SIZE_16 +#elif defined (ADRENO_GPU) +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_mul_mat_q4_0_f32_1d_16x_flat( + global uchar * src0_q, + global half * src0_d, + global float * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + src1 = (global float*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + mul_mat_q_n_f32_1d_16x_flat(src0_q, src0_d, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3); +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_1d_8x_flat.cl b/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_1d_8x_flat.cl new file mode 100644 index 0000000000..38024d00ad --- /dev/null +++ 
b/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_1d_8x_flat.cl @@ -0,0 +1,265 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#define QK4_0 32 +#define QR4_0 2 +#define QK4_1 32 +#define QR4_1 2 +#define QK5_0 32 +#define QR5_0 2 +#define QK5_1 32 +#define QR5_1 2 +#define QK8_0 32 +#define QR8_0 1 +#define QK_K 256 +#define K_QUANTS_PER_ITERATION 2 + +typedef char int8_t; +typedef uchar uint8_t; +typedef short int16_t; +typedef ushort uint16_t; +typedef int int32_t; +typedef uint uint32_t; + +//------------------------------------------------------------------------------ +// block_q4_0 +//------------------------------------------------------------------------------ +struct block_q4_0 +{ + half d; + uint8_t qs[QK4_0 / 2]; +}; + +inline float mm_block_q_4_0_dot_y_flat( + global uchar * x, + global half * dh, + float sumy, + float16 yl, + int il +) { + float d = *dh; + global ushort * qs = ((global ushort *)x + il/2); + float acc = 0.f; + + acc += yl.s0 * (qs[0] & 0x000F); + acc += yl.s1 * (qs[0] & 0x0F00); + acc += yl.s8 * (qs[0] & 0x00F0); + acc += yl.s9 * (qs[0] & 0xF000); + + acc += yl.s2 * (qs[1] & 0x000F); + acc += yl.s3 * (qs[1] & 0x0F00); + acc += yl.sa * (qs[1] & 0x00F0); + acc += yl.sb * (qs[1] & 0xF000); + + acc += yl.s4 * (qs[2] & 0x000F); + acc += yl.s5 * (qs[2] & 0x0F00); + acc += yl.sc * (qs[2] & 0x00F0); + acc += yl.sd * (qs[2] & 0xF000); + + acc += yl.s6 * (qs[3] & 0x000F); + acc += yl.s7 * (qs[3] & 0x0F00); + acc += yl.se * (qs[3] & 0x00F0); + acc += yl.sf * (qs[3] & 0xF000); + + return d * (sumy * -8.f + acc); +} + +#ifdef INTEL_GPU +#define N_DST 8 // each SIMD group works on 8 rows (in weights matrix) +#define N_SIMDGROUP 1 // number of SIMD groups in a thread group +#define N_SIMDWIDTH 16 // assuming SIMD group size is 16 +#elif defined (ADRENO_GPU) +#define N_DST 8 +#define N_SIMDGROUP 1 +#define N_SIMDWIDTH 64 +#endif +// +// This variant performs 1d blocking with 8x output. +// Each simdgroup outputs 8 values on `n0` dim (row in the output matrix). +// +inline void mul_mat_q_n_f32_1d_8x_flat( + global uchar * src0_q, + global half * src0_d, + global float * src1, + global float * dst, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + const int nb = ne00/QK4_0; + + int r0 = get_group_id(0); + int r1 = get_group_id(1); + int im = get_group_id(2); + + // (r0 * N_SIMDGROUP + get_sub_group_id()) is the linear global id of + // a SIMD group in the grid. Each SIMD group produces N_DST values in the + // result, hence uses nb blocks, i.e., the offset becomes first_row*nb. + // Currently with llama2 7B, im is always 0. + // TODO: how to handle im/gqa*(nb*ne0)?
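+    // For example, with N_SIMDGROUP = 1 and N_DST = 8, the subgroup in
+    // work-group r0 computes output rows [8*r0, 8*r0 + 8).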
+ int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST; + + int i12 = im%ne12; + int i13 = im/ne12; + + // The number of scales is the same as the number of blocks. + ulong offset0_d = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + // Each block contains QK4_0/2 uchars, hence offset for qs is as follows. + ulong offset0_q = (first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02)) * QK4_0/2; + + global uchar * x = (global uchar *) src0_q + offset0_q; + global half * d = (global half *) src0_d + offset0_d; + global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1; + + float16 yl; + float8 sumf = (float8)(0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f); + + int ix = get_sub_group_local_id()/2; + int il = 8*(get_sub_group_local_id()%2); + + global float * yb = y + ix*QK4_0 + il; + + for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) { + float sumy = 0.f; + + sumy += yb[0]; + sumy += yb[1]; + sumy += yb[2]; + sumy += yb[3]; + sumy += yb[4]; + sumy += yb[5]; + sumy += yb[6]; + sumy += yb[7]; + + sumy += yb[16]; + sumy += yb[17]; + sumy += yb[18]; + sumy += yb[19]; + sumy += yb[20]; + sumy += yb[21]; + sumy += yb[22]; + sumy += yb[23]; + + yl.s0 = yb[0]; + yl.s1 = yb[1]/256.f; + + yl.s2 = yb[2]; + yl.s3 = yb[3]/256.f; + + yl.s4 = yb[4]; + yl.s5 = yb[5]/256.f; + + yl.s6 = yb[6]; + yl.s7 = yb[7]/256.f; + + yl.s8 = yb[16]/16.f; + yl.s9 = yb[17]/4096.f; + + yl.sa = yb[18]/16.f; + yl.sb = yb[19]/4096.f; + + yl.sc = yb[20]/16.f; + yl.sd = yb[21]/4096.f; + + yl.se = yb[22]/16.f; + yl.sf = yb[23]/4096.f; + + sumf.s0 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 0*nb*QK4_0/2, d + ib + 0*nb, sumy, yl, il); + sumf.s1 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 1*nb*QK4_0/2, d + ib + 1*nb, sumy, yl, il); + sumf.s2 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 2*nb*QK4_0/2, d + ib + 2*nb, sumy, yl, il); + sumf.s3 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 3*nb*QK4_0/2, d + ib + 3*nb, sumy, yl, il); + + sumf.s4 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 4*nb*QK4_0/2, d + ib + 4*nb, sumy, yl, il); + sumf.s5 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 5*nb*QK4_0/2, d + ib + 5*nb, sumy, yl, il); + sumf.s6 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 6*nb*QK4_0/2, d + ib + 6*nb, sumy, yl, il); + sumf.s7 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 7*nb*QK4_0/2, d + ib + 7*nb, sumy, yl, il); + + yb += QK4_0 * (N_SIMDWIDTH/2); + } + + float8 tot = (float8)( + sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1), + sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3), + sub_group_reduce_add(sumf.s4), sub_group_reduce_add(sumf.s5), + sub_group_reduce_add(sumf.s6), sub_group_reduce_add(sumf.s7) + ); + + if (get_sub_group_local_id() == 0) { + if (first_row + 0 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0; + } + if (first_row + 1 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1; + } + if (first_row + 2 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2; + } + if (first_row + 3 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3; + } + + if (first_row + 4 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 4] = tot.s4; + } + if (first_row + 5 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 5] = tot.s5; + } + if (first_row + 6 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 6] = tot.s6; + } + if (first_row + 7 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 7] = tot.s7; + } + } +} + +#ifdef INTEL_GPU +REQD_SUBGROUP_SIZE_16 +#elif defined (ADRENO_GPU) +REQD_SUBGROUP_SIZE_64 +#endif +kernel void 
kernel_mul_mat_q4_0_f32_1d_8x_flat( + global uchar * src0_q, + global half * src0_d, + global float * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + src1 = (global float*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + mul_mat_q_n_f32_1d_8x_flat(src0_q, src0_d, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3); +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_8x_flat.cl b/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_8x_flat.cl new file mode 100644 index 0000000000..aed1ce7b26 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_8x_flat.cl @@ -0,0 +1,272 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#define QK4_0 32 +#define QR4_0 2 +#define QK4_1 32 +#define QR4_1 2 +#define QK5_0 32 +#define QR5_0 2 +#define QK5_1 32 +#define QR5_1 2 +#define QK8_0 32 +#define QR8_0 1 +#define QK_K 256 +#define K_QUANTS_PER_ITERATION 2 + +typedef char int8_t; +typedef uchar uint8_t; +typedef short int16_t; +typedef ushort uint16_t; +typedef int int32_t; +typedef uint uint32_t; + +//------------------------------------------------------------------------------ +// block_q4_0 +//------------------------------------------------------------------------------ +struct block_q4_0 +{ + half d; + uint8_t qs[QK4_0 / 2]; +}; + +// This function requires the original shuffled weights. +// As a reminder, the original weights are shuffled so that (q[0], q[16]) are +// packed together in a byte, so are (q[1], q[17]) and so on. +inline float block_q_4_0_dot_y_flat( + global uchar * x, + global half * dh, + float sumy, + float16 yl, + int il +) { + float d = *dh; + global ushort * qs = ((global ushort *)x + il/2); + float acc = 0.f; + + acc += yl.s0 * (qs[0] & 0x000F); + acc += yl.s1 * (qs[0] & 0x0F00); + acc += yl.s8 * (qs[0] & 0x00F0); + acc += yl.s9 * (qs[0] & 0xF000); + + acc += yl.s2 * (qs[1] & 0x000F); + acc += yl.s3 * (qs[1] & 0x0F00); + acc += yl.sa * (qs[1] & 0x00F0); + acc += yl.sb * (qs[1] & 0xF000); + + acc += yl.s4 * (qs[2] & 0x000F); + acc += yl.s5 * (qs[2] & 0x0F00); + acc += yl.sc * (qs[2] & 0x00F0); + acc += yl.sd * (qs[2] & 0xF000); + + acc += yl.s6 * (qs[3] & 0x000F); + acc += yl.s7 * (qs[3] & 0x0F00); + acc += yl.se * (qs[3] & 0x00F0); + acc += yl.sf * (qs[3] & 0xF000); + + return d * (sumy * -8.f + acc); +} + +// +// This variant outputs 8 values. 
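+// "Flat" means the q4_0 data arrives as two separate buffers - src0_q with
+// the packed nibbles and src0_d with the per-block scales - rather than as
+// interleaved struct block_q4_0 blocks, which is why offset0_q and offset0_d
+// are computed independently below.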
+// +#undef N_DST +#undef N_SIMDGROUP +#undef N_SIMDWIDTH + +#ifdef INTEL_GPU +#define N_DST 8 // each SIMD group works on 8 rows +#define N_SIMDGROUP 1 // number of SIMD groups in a thread group +#define N_SIMDWIDTH 16 // assuming SIMD group size is 16 +#elif defined (ADRENO_GPU) +#define N_DST 8 +#define N_SIMDGROUP 1 +#define N_SIMDWIDTH 64 +#endif + +inline void mul_vec_q_n_f32_8x_flat( + global uchar * src0_q, + global half * src0_d, + global float * src1, + global float * dst, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + const ulong nb = ne00/QK4_0; + + int r0 = get_group_id(0); + int r1 = get_group_id(1); + int im = get_group_id(2); + + // (r0 * N_SIMDGROUP + get_sub_group_id()) is the linear global id of + // a SIMD group in the grid. Each SIMD group produces N_DST values in the + // result, hence uses nb blocks, i.e., the offset becomes first_row*nb. + // Currently with llama2 7B, im is always 0. + // TODO: how to handle im/gqa*(nb*ne0)? + int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST; + + int i12 = im%ne12; + int i13 = im/ne12; + + // The number of scales is the same as the number of blocks. + ulong offset0_d = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + // Each block contains QK4_0/2 uchars, hence offset for qs is as follows. + ulong offset0_q = (first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02)) * QK4_0/2; + + global uchar * x = (global uchar *) src0_q + offset0_q; + global half * d = (global half *) src0_d + offset0_d; + global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1; + + float16 yl; + float8 sumf = 0.f; + + int ix = get_sub_group_local_id()/2; + int il = 8*(get_sub_group_local_id()%2); + + global float * yb = y + ix*QK4_0 + il; + + for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) { + float sumy = 0.f; + + sumy += yb[0]; + sumy += yb[1]; + sumy += yb[2]; + sumy += yb[3]; + sumy += yb[4]; + sumy += yb[5]; + sumy += yb[6]; + sumy += yb[7]; + + sumy += yb[16]; + sumy += yb[17]; + sumy += yb[18]; + sumy += yb[19]; + sumy += yb[20]; + sumy += yb[21]; + sumy += yb[22]; + sumy += yb[23]; + + yl.s0 = yb[0]; + yl.s1 = yb[1]/256.f; + + yl.s2 = yb[2]; + yl.s3 = yb[3]/256.f; + + yl.s4 = yb[4]; + yl.s5 = yb[5]/256.f; + + yl.s6 = yb[6]; + yl.s7 = yb[7]/256.f; + + yl.s8 = yb[16]/16.f; + yl.s9 = yb[17]/4096.f; + + yl.sa = yb[18]/16.f; + yl.sb = yb[19]/4096.f; + + yl.sc = yb[20]/16.f; + yl.sd = yb[21]/4096.f; + + yl.se = yb[22]/16.f; + yl.sf = yb[23]/4096.f; + + sumf.s0 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 0*nb*QK4_0/2, d + ib + 0*nb, sumy, yl, il); + sumf.s1 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 1*nb*QK4_0/2, d + ib + 1*nb, sumy, yl, il); + sumf.s2 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 2*nb*QK4_0/2, d + ib + 2*nb, sumy, yl, il); + sumf.s3 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 3*nb*QK4_0/2, d + ib + 3*nb, sumy, yl, il); + + sumf.s4 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 4*nb*QK4_0/2, d + ib + 4*nb, sumy, yl, il); + sumf.s5 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 5*nb*QK4_0/2, d + ib + 5*nb, sumy, yl, il); + sumf.s6 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 6*nb*QK4_0/2, d + ib + 6*nb, sumy, yl, il); + sumf.s7 += mm_block_q_4_0_dot_y_flat(x + ib*QK4_0/2 + 7*nb*QK4_0/2, d + ib + 7*nb, sumy, yl, il); + + yb += QK4_0 * (N_SIMDWIDTH/2); + } + + float8 tot = (float8)( + sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1), + sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3), + sub_group_reduce_add(sumf.s4),
sub_group_reduce_add(sumf.s5), + sub_group_reduce_add(sumf.s6), sub_group_reduce_add(sumf.s7) + ); + + if (get_sub_group_local_id() == 0) { + if (first_row + 0 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0; + } + if (first_row + 1 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1; + } + if (first_row + 2 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2; + } + if (first_row + 3 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3; + } + + if (first_row + 4 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 4] = tot.s4; + } + if (first_row + 5 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 5] = tot.s5; + } + if (first_row + 6 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 6] = tot.s6; + } + if (first_row + 7 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 7] = tot.s7; + } + } +} + +#ifdef INTEL_GPU +REQD_SUBGROUP_SIZE_16 +#elif defined (ADRENO_GPU) +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_mul_mat_q4_0_f32_8x_flat( + global uchar * src0_q, + global half * src0_d, + global float * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + src1 = (global float*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + mul_vec_q_n_f32_8x_flat(src0_q, src0_d, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3); +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_v.cl b/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_v.cl new file mode 100644 index 0000000000..9295521797 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mv_q4_0_f32_v.cl @@ -0,0 +1,254 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#define QK4_0 32 +#define QR4_0 2 +#define QK4_1 32 +#define QR4_1 2 +#define QK5_0 32 +#define QR5_0 2 +#define QK5_1 32 +#define QR5_1 2 +#define QK8_0 32 +#define QR8_0 1 +#define QK_K 256 +#define K_QUANTS_PER_ITERATION 2 + +typedef char int8_t; +typedef uchar uint8_t; +typedef short int16_t; +typedef ushort uint16_t; +typedef int int32_t; +typedef uint uint32_t; + +//------------------------------------------------------------------------------ +// block_q4_0 +//------------------------------------------------------------------------------ +struct block_q4_0 +{ + half d; + uint8_t qs[QK4_0 / 2]; +}; + +// +// This variant unrolls the loops and uses vector types instead of pointers. +// It improves performance on Adreno but not so much on Intel. 
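+// Worked example of the pre-scaling trick used by the masks below:
+// (qs[0] & 0x0F00) keeps a quant q in bits 8..11, i.e. the value q*256;
+// yl.s1 holds y/256, so yl.s1 * (qs[0] & 0x0F00) recovers q*y. The masks
+// 0x000F, 0x00F0 and 0xF000 pair with pre-scales 1, 1/16 and 1/4096 the
+// same way, and the implicit -8 offset of q4_0 is restored by sumy * -8.f.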
+// +inline float block_q_4_0_dot_y_v( + global struct block_q4_0 * qb_curr, + float sumy, + float16 yl, + int il +) { + float d = qb_curr->d; + float acc = 0.f; + global ushort * qs = ((global ushort *)qb_curr + 1 + il/2); + + acc += yl.s0 * (qs[0] & 0x000F); + acc += yl.s1 * (qs[0] & 0x0F00); + acc += yl.s8 * (qs[0] & 0x00F0); + acc += yl.s9 * (qs[0] & 0xF000); + + acc += yl.s2 * (qs[1] & 0x000F); + acc += yl.s3 * (qs[1] & 0x0F00); + acc += yl.sa * (qs[1] & 0x00F0); + acc += yl.sb * (qs[1] & 0xF000); + + acc += yl.s4 * (qs[2] & 0x000F); + acc += yl.s5 * (qs[2] & 0x0F00); + acc += yl.sc * (qs[2] & 0x00F0); + acc += yl.sd * (qs[2] & 0xF000); + + acc += yl.s6 * (qs[3] & 0x000F); + acc += yl.s7 * (qs[3] & 0x0F00); + acc += yl.se * (qs[3] & 0x00F0); + acc += yl.sf * (qs[3] & 0xF000); + + return d * (sumy * -8.f + acc); +} + +#undef N_DST +#undef N_SIMDGROUP +#undef N_SIMDWIDTH + +#ifdef INTEL_GPU +#define N_DST 4 // each SIMD group works on 4 rows +#define N_SIMDGROUP 1 // number of SIMD groups in a thread group +#define N_SIMDWIDTH 16 // assuming SIMD group size is 16 +#elif defined (ADRENO_GPU) +#define N_DST 4 +#define N_SIMDGROUP 1 +#define N_SIMDWIDTH 64 +#endif + +inline void mul_vec_q_n_f32_v( + global void * src0, + global float * src1, + global float * dst, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + const ulong nb = ne00/QK4_0; + + int r0 = get_group_id(0); + int r1 = get_group_id(1); + int im = get_group_id(2); + + // (r0 * N_SIMDGROUP + get_sub_group_id()) is essentially the linear global + // id of a SIMD group in the grid. + int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST; + + int i12 = im%ne12; + int i13 = im/ne12; + + ulong offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + + global struct block_q4_0 * x = (global struct block_q4_0 *) src0 + offset0; + global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1; + + float16 yl; // src1 vector cache + float4 sumf = (float4)(0.f, 0.f, 0.f, 0.f); + + int ix = get_sub_group_local_id()/2; + int il = 8*(get_sub_group_local_id()%2); + + global float * yb = y + ix * QK4_0 + il; + + // each thread in a SIMD group deals with half a block. + for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) { + float sumy = 0; + + sumy += yb[0]; + sumy += yb[1]; + sumy += yb[2]; + sumy += yb[3]; + sumy += yb[4]; + sumy += yb[5]; + sumy += yb[6]; + sumy += yb[7]; + + sumy += yb[16]; + sumy += yb[17]; + sumy += yb[18]; + sumy += yb[19]; + sumy += yb[20]; + sumy += yb[21]; + sumy += yb[22]; + sumy += yb[23]; + + + yl.s0 = yb[0]; + yl.s1 = yb[1]/256.f; + + yl.s2 = yb[2]; + yl.s3 = yb[3]/256.f; + + yl.s4 = yb[4]; + yl.s5 = yb[5]/256.f; + + yl.s6 = yb[6]; + yl.s7 = yb[7]/256.f; + + yl.s8 = yb[16]/16.f; + yl.s9 = yb[17]/4096.f; + + yl.sa = yb[18]/16.f; + yl.sb = yb[19]/4096.f; + + yl.sc = yb[20]/16.f; + yl.sd = yb[21]/4096.f; + + yl.se = yb[22]/16.f; + yl.sf = yb[23]/4096.f; + + sumf.s0 += block_q_4_0_dot_y_v(x+ib+0*nb, sumy, yl, il); + sumf.s1 += block_q_4_0_dot_y_v(x+ib+1*nb, sumy, yl, il); + sumf.s2 += block_q_4_0_dot_y_v(x+ib+2*nb, sumy, yl, il); + sumf.s3 += block_q_4_0_dot_y_v(x+ib+3*nb, sumy, yl, il); + + // One thread in a SIMD group (i.e., subgroup) handles a half block, + // hence the entire SIMD group handles SIMDWIDTH/2 blocks. + // y points to the activation matrix (of type float).
Therefore for + // one thread, the # of blocks y should advance is SIMDWIDTH/2 (because + // SIMDWIDTH/2 blocks are processed by a SIMD group) - in terms of + // floats, it is QK4_0 * (SIMDWIDTH/2), where QK4_0 is the block size. + yb += QK4_0 * (N_SIMDWIDTH/2); + } + + // The above does not work for Adreno - it produces incorrect results for + // row = 1, 2, 3 and only row = 0 gives the correct result. + // If N_DST is changed, the below array must be initialized accordingly. + // This also seems to perform better on Intel. + float4 tot = (float4)( + sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1), + sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3) + ); + + if (get_sub_group_local_id() == 0) { + if (first_row + 0 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0; + } + if (first_row + 1 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1; + } + if (first_row + 2 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2; + } + if (first_row + 3 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3; + } + } +} + +#ifdef INTEL_GPU +REQD_SUBGROUP_SIZE_16 +#elif defined (ADRENO_GPU) +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_mul_mat_q4_0_f32_v( + global void * src0, + ulong offset0, + global float * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + src0 = (global void*)((global char*)src0 + offset0); + src1 = (global float*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + mul_vec_q_n_f32_v(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3); +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mv_q6_k.cl b/ggml/src/ggml-opencl/kernels/mul_mv_q6_k.cl new file mode 100644 index 0000000000..8a17b9aae6 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mv_q6_k.cl @@ -0,0 +1,190 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#define QK4_0 32 +#define QR4_0 2 +#define QK4_1 32 +#define QR4_1 2 +#define QK5_0 32 +#define QR5_0 2 +#define QK5_1 32 +#define QR5_1 2 +#define QK8_0 32 +#define QR8_0 1 +#define QK_K 256 +#define K_QUANTS_PER_ITERATION 2 + +typedef char int8_t; +typedef uchar uint8_t; +typedef short int16_t; +typedef ushort uint16_t; +typedef int int32_t; +typedef uint uint32_t; + +//------------------------------------------------------------------------------ +// block_q6_K +//------------------------------------------------------------------------------ +// 6-bit quantization +// weight is represented as x = a * q +// 16 blocks of 16 elements each +// Effectively 6.5625 bits per weight +typedef struct { + uint8_t ql[QK_K/2]; // quants, lower 4 bits + uint8_t qh[QK_K/4]; // quants, upper 2 bits + int8_t scales[QK_K/16]; // 
scales, quantized with 8 bits + half d; // super-block scale +} block_q6_K; + +//------------------------------------------------------------------------------ +// kernel_mul_mv_q6_K_f32 +//------------------------------------------------------------------------------ + +#undef N_DST +#undef N_SIMDGROUP +#undef N_SIMDWIDTH + +#ifdef INTEL_GPU +#define N_DST 1 // number of rows each SIMD group works on +#define N_SIMDGROUP 2 // number of SIMD groups in a thread group +#define N_SIMDWIDTH 16 // SIMD group size +#elif defined (ADRENO_GPU) +#define N_DST 1 +#define N_SIMDGROUP 2 +#define N_SIMDWIDTH 64 +#endif + +#define BLOCK_STRIDE (N_SIMDWIDTH/16) // number of blocks each subgroup processes + +#ifdef INTEL_GPU +REQD_SUBGROUP_SIZE_16 +#elif defined (ADRENO_GPU) +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_mul_mv_q6_K_f32( + global void * src0, + ulong offset0, + global float * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + src0 = (global void*)((global char*)src0 + offset0); + src1 = (global float*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + uchar kmask1 = 0x03; + uchar kmask2 = 0x0C; + uchar kmask3 = 0x30; + uchar kmask4 = 0xC0; + + int nb = ne00/QK_K; + + int r0 = get_group_id(0); + int r1 = get_group_id(1); + int im = get_group_id(2); + + int row = N_SIMDGROUP * r0 + get_sub_group_id(); + + int i12 = im%ne12; + int i13 = im/ne12; + + ulong offset_src0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + + global block_q6_K * x = (global block_q6_K *) src0 + row*nb + offset_src0; + global float * yy = (global float *) src1 + r1*ne10 + im*ne00*ne1; + + float sumf = 0; + + // For Q6_K quantization, 16 values form a subblock, and 16 subblocks form a + // block. Values in a subblock share a scale that is quantized with 8 bits; + // the entire block shares a single floating point scale. + // For work distribution, each thread processes a subblock (16 weights), hence + // 16 threads process a (super) block -- a subgroup thus handles SIMDWIDTH/16 + // (super) blocks -- this is the block stride. + // The 16 threads that process a (super) block are split into 2 portions, each with + // 8 threads; each portion works on 8 subblocks. + // For a subgroup of 16 threads, the entire subgroup works on a single (super) block + // before moving to the next (super) block. Thread0 - thread7 work on the + // first 8 subblocks; thread8 - thread15 work on the last 8 subblocks. + // Thread0 - thread3 work on subblocks 0, 2, 4, 6; thread4 - thread7 work on + // subblocks 1, 3, 5, 7. Each thread does not work on an entire subblock, but + // works on a total of 16 weight values.
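+    // Illustration with N_SIMDWIDTH = 16 (BLOCK_STRIDE = 1): thread 5 gets
+    // tid = 5, ix = 0, ip = 0, il = 5, l0 = 20, is = 1, so it reads
+    // y[20..23], y[52..55], y[84..87] and y[116..119] of each (super) block.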
+ int tid = get_sub_group_local_id()/BLOCK_STRIDE; // first block_stride groups have tid=0 + int ix = get_sub_group_local_id()%BLOCK_STRIDE; // first block is 0..block_stride-1 + int ip = tid/8; // first or second half of (super) block (0 or 1) + int il = tid%8; // each half has 8 parts, one per scale + int n = 4; // 4 scales at a time (and 4 sums) + int l0 = n*il; // offset into half-block, 0..28 + int is = 8*ip + l0/16; // 0, 1, 8, 9 + + int y_offset = 128*ip + l0; + int q_offset_l = 64*ip + l0; + int q_offset_h = 32*ip + l0; + + for (int i = ix; i < nb; i += BLOCK_STRIDE) { + + global uint8_t * q1 = x[i].ql + q_offset_l; + global uint8_t * q2 = q1 + QK_K/8; + global uint8_t * qh = x[i].qh + q_offset_h; + global int8_t * sc = x[i].scales + is; + + global float * y = yy + i * QK_K + y_offset; + + float dall = x[i].d; + + float4 sums = {0.f, 0.f, 0.f, 0.f}; + + sums.s0 += y[0+ 0] * ((float)((q1[0] & 0xF) | ((qh[0] & kmask1) << 4)) - 32.f); + sums.s1 += y[0+32] * ((float)((q2[0] & 0xF) | ((qh[0] & kmask2) << 2)) - 32.f); + sums.s2 += y[0+64] * ((float)((q1[0] >> 4) | ((qh[0] & kmask3) << 0)) - 32.f); + sums.s3 += y[0+96] * ((float)((q2[0] >> 4) | ((qh[0] & kmask4) >> 2)) - 32.f); + + sums.s0 += y[1+ 0] * ((float)((q1[1] & 0xF) | ((qh[1] & kmask1) << 4)) - 32.f); + sums.s1 += y[1+32] * ((float)((q2[1] & 0xF) | ((qh[1] & kmask2) << 2)) - 32.f); + sums.s2 += y[1+64] * ((float)((q1[1] >> 4) | ((qh[1] & kmask3) << 0)) - 32.f); + sums.s3 += y[1+96] * ((float)((q2[1] >> 4) | ((qh[1] & kmask4) >> 2)) - 32.f); + + sums.s0 += y[2+ 0] * ((float)((q1[2] & 0xF) | ((qh[2] & kmask1) << 4)) - 32.f); + sums.s1 += y[2+32] * ((float)((q2[2] & 0xF) | ((qh[2] & kmask2) << 2)) - 32.f); + sums.s2 += y[2+64] * ((float)((q1[2] >> 4) | ((qh[2] & kmask3) << 0)) - 32.f); + sums.s3 += y[2+96] * ((float)((q2[2] >> 4) | ((qh[2] & kmask4) >> 2)) - 32.f); + + sums.s0 += y[3+ 0] * ((float)((q1[3] & 0xF) | ((qh[3] & kmask1) << 4)) - 32.f); + sums.s1 += y[3+32] * ((float)((q2[3] & 0xF) | ((qh[3] & kmask2) << 2)) - 32.f); + sums.s2 += y[3+64] * ((float)((q1[3] >> 4) | ((qh[3] & kmask3) << 0)) - 32.f); + sums.s3 += y[3+96] * ((float)((q2[3] >> 4) | ((qh[3] & kmask4) >> 2)) - 32.f); + + sumf += dall * (sums.s0 * sc[0] + sums.s1 * sc[2] + sums.s2 * sc[4] + sums.s3 * sc[6]); + } + + float tot = sub_group_reduce_add(sumf); + if (get_sub_group_local_id() == 0) { + dst[r1*ne0 + im*ne0*ne1 + row] = tot; + } +} diff --git a/ggml/src/ggml-opencl/kernels/norm.cl b/ggml/src/ggml-opencl/kernels/norm.cl new file mode 100644 index 0000000000..43167ba4d2 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/norm.cl @@ -0,0 +1,81 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +//------------------------------------------------------------------------------ +// norm +//------------------------------------------------------------------------------ +kernel void kernel_norm( + global void * src0, + ulong offset0, + global float * dst, + ulong offsetd, + int ne00, + int 
ne01, + int ne02, + int ne03, + ulong nb01, + ulong nb02, + ulong nb03, + float eps, + local float * sum +) { + src0 = (global void*)((global char*)src0 + offset0); + dst = (global float*)((global char*)dst + offsetd); + + int i03 = get_group_id(2); + int i02 = get_group_id(1); + int i01 = get_group_id(0); + + global float * x = (global float *) ((global char *) src0 + i03*nb03 + i02*nb02 + i01*nb01); + + // MEAN + // parallel sum + sum[get_local_id(0)] = 0.0f; + for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { + sum[get_local_id(0)] += x[i00]; + } + // reduce + barrier(CLK_LOCAL_MEM_FENCE); + for (uint i = get_local_size(0)/2; i > 0; i /= 2) { + if (get_local_id(0) < i) { + sum[get_local_id(0)] += sum[get_local_id(0) + i]; + } + barrier(CLK_LOCAL_MEM_FENCE); + } + float mean = sum[0] / ne00; + + // recenter and VARIANCE + barrier(CLK_LOCAL_MEM_FENCE); + global float * y = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + sum[get_local_id(0)] = 0.0f; + for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { + y[i00] = x[i00] - mean; + sum[get_local_id(0)] += y[i00] * y[i00]; + } + + // reduce + barrier(CLK_LOCAL_MEM_FENCE); + for (uint i = get_local_size(0)/2; i > 0; i /= 2) { + if (get_local_id(0) < i) { + sum[get_local_id(0)] += sum[get_local_id(0) + i]; + } + barrier(CLK_LOCAL_MEM_FENCE); + } + float variance = sum[0] / ne00; + + float scale = 1.0f/sqrt(variance + eps); + for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { + y[i00] = y[i00] * scale; + } +} diff --git a/ggml/src/ggml-opencl/kernels/relu.cl b/ggml/src/ggml-opencl/kernels/relu.cl new file mode 100644 index 0000000000..60ff28a61a --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/relu.cl @@ -0,0 +1,16 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +//------------------------------------------------------------------------------ +// relu +//------------------------------------------------------------------------------ +kernel void kernel_relu( + global float * src0, + ulong offset0, + global float * dst, + ulong offsetd +) { + src0 = (global float*)((global char*)src0 + offset0); + dst = (global float*)((global char*)dst + offsetd); + + dst[get_global_id(0)] = fmax(0.0f, src0[get_global_id(0)]); +} diff --git a/ggml/src/ggml-opencl/kernels/rms_norm.cl b/ggml/src/ggml-opencl/kernels/rms_norm.cl new file mode 100644 index 0000000000..9d21f3398e --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/rms_norm.cl @@ -0,0 +1,96 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +//------------------------------------------------------------------------------ +// rms_norm +//------------------------------------------------------------------------------ +// This kernel depends on subgroup size.
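+// The local buffer `sum` holds one partial sum per subgroup, so the host is
+// expected to size it as (work-group size / subgroup size) floats; fixing
+// the required subgroup size below keeps that count predictable.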
+#ifdef INTEL_GPU +REQD_SUBGROUP_SIZE_32 +#elif defined (ADRENO_GPU) +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_rms_norm( + global void * src0, + ulong offset0, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb01, + ulong nb02, + ulong nb03, + float eps, + local float * sum // Note, the size depends on number of subgroups +) { + src0 = (global void*)((global char*)src0 + offset0); + dst = (global float*)((global char*)dst + offsetd); + + int i03 = get_group_id(2); + int i02 = get_group_id(1); + int i01 = get_group_id(0); + + global float4 * x = (global float4 *) ((global char *) src0 + i03*nb03 + i02*nb02 + i01*nb01); + global float * x_scalar = (global float *) x; + float4 sumf = 0; + float all_sum = 0; + + // parallel sum + for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) { + sumf += x[i00] * x[i00]; + } + all_sum = sumf.s0 + sumf.s1 + sumf.s2 + sumf.s3; + all_sum = sub_group_reduce_add(all_sum); + if (get_sub_group_local_id() == 0) { + sum[get_sub_group_id()] = all_sum; + } + + barrier(CLK_LOCAL_MEM_FENCE); + // broadcast + for (uint i = get_local_size(0) / get_max_sub_group_size() / 2; i > 0; i /= 2) { + if (get_local_id(0) < i) { + sum[get_local_id(0)] += sum[get_local_id(0) + i]; + } + } + if (get_local_id(0) == 0) { + for (int i = 4 * (ne00 / 4); i < ne00; i++) { + sum[0] += x_scalar[i]; + } + sum[0] /= ne00; + } + + barrier(CLK_LOCAL_MEM_FENCE); + + const float mean = sum[0]; + const float scale = 1.0f/sqrt(mean + eps); + + global float4 * y = (global float4 *) (dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); + global float * y_scalar = (global float *) y; + for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) { + y[i00] = x[i00] * scale; + } + if (get_local_id(0) == 0) { + for (int i00 = 4 * (ne00 / 4); i00 < ne00; i00++) { + y_scalar[i00] = x_scalar[i00] * scale; + } + } +} diff --git a/ggml/src/ggml-opencl/kernels/rope.cl b/ggml/src/ggml-opencl/kernels/rope.cl new file mode 100644 index 0000000000..0247730c03 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/rope.cl @@ -0,0 +1,721 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +//------------------------------------------------------------------------------ +// kernel_rope +//------------------------------------------------------------------------------ +float rope_yarn_ramp(float low, float high, int i0) { + const float y = (i0 / 2 - low) / max(0.001f, high - low); + return 1.0f - min(1.0f, max(0.0f, y)); +} + +// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn +// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng. 
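+// rope_yarn computes the (cos, sin) pair used to rotate one dimension pair:
+// theta is interpolated by freq_scale and, when ext_factor != 0, blended
+// with the extrapolated angle via the ramp; mscale scales the magnitude.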
+float2 rope_yarn( + float theta_extrap, float freq_scale, float2 corr_dims, int i0, float ext_factor, float mscale +) { + // Get n-d rotational scaling corrected for extrapolation + float theta_interp = freq_scale * theta_extrap; + float theta = theta_interp; + if (ext_factor != 0.0f) { + float ramp_mix = rope_yarn_ramp(corr_dims.s0, corr_dims.s1, i0) * ext_factor; + theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix; + + // Get n-d magnitude scaling corrected for interpolation + mscale *= 1.0f + 0.1f * log(1.0f / freq_scale); + } + return (float2)(cos(theta) * mscale, sin(theta) * mscale); +} + +// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get +// `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))` +float rope_yarn_corr_factor(int n_dims, int n_ctx_orig, float n_rot, float base) { + return n_dims * log(n_ctx_orig / (n_rot * 2 * M_PI_F)) / (2 * log(base)); +} + +float2 rope_yarn_corr_dims( + int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow +) { + // start and end correction dims + return (float2)( + max(0.0f, floor(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_fast, freq_base))), + min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_slow, freq_base))) + ); +} + +kernel void kernel_rope_norm_f32( + global void * src0, + ulong offset0, + global int * src1, + ulong offset1, + global float * src2, + ulong offset2, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne0, + int ne1, + int ne2, + int ne3, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3, + int n_past, + int n_dims, + int n_ctx_orig, + float freq_base, + float freq_scale, + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow +) { + src0 = (global void*)((global char*)src0 + offset0); + src1 = (global int*)((global char*)src1 + offset1); + src2 = (global float*)((global char*)src2 + offset2); + dst = (global float*)((global char*)dst + offsetd); + + int i3 = get_group_id(2); + int i2 = get_group_id(1); + int i1 = get_group_id(0); + + float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow); + + global int * pos = src1; + + float theta_base = (float) pos[i2]; + float inv_ndims = -1.f/n_dims; + + for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) { + if (i0 < n_dims) { + int ic = i0/2; + + float theta = theta_base * pow(freq_base, inv_ndims*i0); + + float freq_factor = src2 != src0 ? 
src2[ic] : 1.0f; + + float2 cos_sin_theta = rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor); + + global float * src = (global float *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + global float * dst_data = (global float *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + float x0 = src[0]; + float x1 = src[1]; + + dst_data[0] = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1; + dst_data[1] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0; + } else { + global float * src = (global float *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + global float * dst_data = (global float *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + dst_data[0] = src[0]; + dst_data[1] = src[1]; + } + } +} + +kernel void kernel_rope_norm_f16( + global void * src0, + ulong offset0, + global int * src1, + ulong offset1, + global float * src2, + ulong offset2, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne0, + int ne1, + int ne2, + int ne3, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3, + int n_past, + int n_dims, + int n_ctx_orig, + float freq_base, + float freq_scale, + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow +) { + src0 = (global void*)((global char*)src0 + offset0); + src1 = (global int*)((global char*)src1 + offset1); + src2 = (global float*)((global char*)src2 + offset2); + dst = (global float*)((global char*)dst + offsetd); + + int i3 = get_group_id(2); + int i2 = get_group_id(1); + int i1 = get_group_id(0); + + float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow); + + global int * pos = src1; + + float theta_base = (float) pos[i2]; + float inv_ndims = -1.f/n_dims; + + for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) { + if (i0 < n_dims) { + int ic = i0/2; + + float theta = theta_base * pow(freq_base, inv_ndims*i0); + + float freq_factor = src2 != src0 ? 
src2[ic] : 1.0f; + + float2 cos_sin_theta = rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor); + + global half * src = (global half *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + global half * dst_data = (global half *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + float x0 = src[0]; + float x1 = src[1]; + + dst_data[0] = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1; + dst_data[1] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0; + } else { + global half * src = (global half *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + global half * dst_data = (global half *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + dst_data[0] = src[0]; + dst_data[1] = src[1]; + } + } +} + +kernel void kernel_rope_neox_f32( + global void * src0, + ulong offset0, + global int * src1, + ulong offset1, + global float * src2, + ulong offset2, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne0, + int ne1, + int ne2, + int ne3, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3, + int n_past, + int n_dims, + int n_ctx_orig, + float freq_base, + float freq_scale, + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow +) { + src0 = (global void*)((global char*)src0 + offset0); + src1 = (global int*)((global char*)src1 + offset1); + src2 = (global float*)((global char*)src2 + offset2); + dst = (global float*)((global char*)dst + offsetd); + + int i3 = get_group_id(2); + int i2 = get_group_id(1); + int i1 = get_group_id(0); + + float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow); + + global int * pos = src1; + + float theta_base = (float) pos[i2]; + float inv_ndims = -1.f/n_dims; + + for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) { + if (i0 < n_dims) { + int ic = i0/2; + + const float theta = theta_base * pow(freq_base, inv_ndims*i0); + + const float freq_factor = src2 != src0 ? 
src2[ic] : 1.0f; + + float2 cos_sin_theta = rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor); + + global float * src = (global float *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00); + global float * dst_data = (global float *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0); + + const float x0 = src[0]; + const float x1 = src[n_dims/2]; + + dst_data[0] = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1; + dst_data[n_dims/2] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0; + } else { + global float * const src = (global float *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + global float * dst_data = (global float *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + dst_data[0] = src[0]; + dst_data[1] = src[1]; + } + } +} + +kernel void kernel_rope_neox_f16( + global void * src0, + ulong offset0, + global int * src1, + ulong offset1, + global float * src2, + ulong offset2, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne0, + int ne1, + int ne2, + int ne3, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3, + int n_past, + int n_dims, + int n_ctx_orig, + float freq_base, + float freq_scale, + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow +) { + src0 = (global void*)((global char*)src0 + offset0); + src1 = (global int*)((global char*)src1 + offset1); + src2 = (global float*)((global char*)src2 + offset2); + dst = (global float*)((global char*)dst + offsetd); + + int i3 = get_group_id(2); + int i2 = get_group_id(1); + int i1 = get_group_id(0); + + float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow); + + global int * pos = src1; + + float theta_base = (float) pos[i2]; + float inv_ndims = -1.f/n_dims; + + for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) { + if (i0 < n_dims) { + int ic = i0/2; + + const float theta = theta_base * pow(freq_base, inv_ndims*i0); + + const float freq_factor = src2 != src0 ? 
src2[ic] : 1.0f; + + float2 cos_sin_theta = rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor); + + global half * src = (global half *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00); + global half * dst_data = (global half *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0); + + const float x0 = src[0]; + const float x1 = src[n_dims/2]; + + dst_data[0] = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1; + dst_data[n_dims/2] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0; + } else { + global half * const src = (global half *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + global half * dst_data = (global half *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + dst_data[0] = src[0]; + dst_data[1] = src[1]; + } + } +} + +kernel void kernel_rope_multi_f32( + global void * src0, + ulong offset0, + global int * src1, + ulong offset1, + global float * src2, + ulong offset2, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne0, + int ne1, + int ne2, + int ne3, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3, + int n_past, + int n_dims, + int n_ctx_orig, + float freq_base, + float freq_scale, + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow, + int4 sections +) { + src0 = (global void*)((global char*)src0 + offset0); + src1 = (global int*)((global char*)src1 + offset1); + src2 = (global float*)((global char*)src2 + offset2); + dst = (global float*)((global char*)dst + offsetd); + + int i3 = get_group_id(2); + int i2 = get_group_id(1); + int i1 = get_group_id(0); + + float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow); + + global int * pos = src1; + + const int sect_dims = sections.s0 + sections.s1 + sections.s2 + sections.s3; + const int sec_w = sections.s1 + sections.s0; + + float inv_ndims = -1.f/n_dims; + + for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) { + if (i0 < n_dims) { + int ic = i0/2; + + const int sector = (i0 / 2) % sect_dims; + float theta_base = 0.0f; + + if (sector < sections.s0) { + theta_base = pos[i2]; + } + else if (sector >= sections.s0 && sector < sec_w) { + theta_base = pos[i2 + ne2 * 1]; + } + else if (sector >= sec_w && sector < sec_w + sections.s2) { + theta_base = pos[i2 + ne2 * 2]; + } + else if (sector >= sec_w + sections.s2) { + theta_base = pos[i2 + ne2 * 3]; + } + + const float theta = theta_base * pow(freq_base, inv_ndims*i0); + + const float freq_factor = src2 != src0 ? 
src2[ic] : 1.0f; + + float2 cos_sin_theta = rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor); + + global float * src = (global float *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00); + global float * dst_data = (global float *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0); + + const float x0 = src[0]; + const float x1 = src[n_dims/2]; + + dst_data[0] = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1; + dst_data[n_dims/2] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0; + } else { + global float * const src = (global float *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + global float * dst_data = (global float *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + dst_data[0] = src[0]; + dst_data[1] = src[1]; + } + } +} + +kernel void kernel_rope_multi_f16( + global void * src0, + ulong offset0, + global int * src1, + ulong offset1, + global float * src2, + ulong offset2, + global half * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne0, + int ne1, + int ne2, + int ne3, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3, + int n_past, + int n_dims, + int n_ctx_orig, + float freq_base, + float freq_scale, + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow, + int4 sections +) { + src0 = (global void*)((global char*)src0 + offset0); + src1 = (global int*)((global char*)src1 + offset1); + src2 = (global float*)((global char*)src2 + offset2); + dst = (global float*)((global char*)dst + offsetd); + + int i3 = get_group_id(2); + int i2 = get_group_id(1); + int i1 = get_group_id(0); + + float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow); + + global int * pos = src1; + + const int sect_dims = sections.s0 + sections.s1 + sections.s2 + sections.s3; + const int sec_w = sections.s1 + sections.s0; + + float inv_ndims = -1.f/n_dims; + + for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) { + if (i0 < n_dims) { + int ic = i0/2; + + const int sector = (i0 / 2) % sect_dims; + float theta_base = 0.0f; + + if (sector < sections.s0) { + theta_base = pos[i2]; + } + else if (sector >= sections.s0 && sector < sec_w) { + theta_base = pos[i2 + ne2 * 1]; + } + else if (sector >= sec_w && sector < sec_w + sections.s2) { + theta_base = pos[i2 + ne2 * 2]; + } + else if (sector >= sec_w + sections.s2) { + theta_base = pos[i2 + ne2 * 3]; + } + + const float theta = theta_base * pow(freq_base, inv_ndims*i0); + + const float freq_factor = src2 != src0 ? 
src2[ic] : 1.0f; + + float2 cos_sin_theta = rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor); + + global half * src = (global half *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00); + global half * dst_data = (global half *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0); + + const float x0 = src[0]; + const float x1 = src[n_dims/2]; + + dst_data[0] = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1; + dst_data[n_dims/2] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0; + } else { + global half * const src = (global half *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + global half * dst_data = (global half *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + dst_data[0] = src[0]; + dst_data[1] = src[1]; + } + } +} + +kernel void kernel_rope_vision_f32( + global void * src0, + ulong offset0, + global int * src1, + ulong offset1, + global float * src2, + ulong offset2, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne0, + int ne1, + int ne2, + int ne3, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3, + int n_past, + int n_dims, + int n_ctx_orig, + float freq_base, + float freq_scale, + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow, + int4 sections +) { + src0 = (global void*)((global char*)src0 + offset0); + src1 = (global int*)((global char*)src1 + offset1); + src2 = (global float*)((global char*)src2 + offset2); + dst = (global float*)((global char*)dst + offsetd); + + int i3 = get_group_id(2); + int i2 = get_group_id(1); + int i1 = get_group_id(0); + + float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow); + + global int * pos = src1; + + const int sect_dims = sections.s0 + sections.s1; + const int sec_w = sections.s1 + sections.s0; + + float inv_ndims = -1.f/n_dims; + + for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) { + int ic = i0/2; + + const int sector = (i0/2) % sect_dims; + float theta_base = 0.0f; + + if (sector < sections.s0) { + const int p = sector; + theta_base = pos[i2] * pow(freq_base, inv_ndims*2.0f*p); + } else if (sector >= sections.s0 && sector < sec_w) { + const int p = sector - sections.s0; + theta_base = pos[i2 + ne2] * pow(freq_base, inv_ndims*2.0f*p); + } + + const float freq_factor = src2 != src0 ? 
src2[ic] : 1.0f; + + float2 cos_sin_theta = rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor); + + global float * src = (global float *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00); + global float * dst_data = (global float *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0); + + const float x0 = src[0]; + const float x1 = src[n_dims]; + + dst_data[0] = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1; + dst_data[n_dims] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0; + } +} + +kernel void kernel_rope_vision_f16( + global void * src0, + ulong offset0, + global int * src1, + ulong offset1, + global float * src2, + ulong offset2, + global half * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne0, + int ne1, + int ne2, + int ne3, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3, + int n_past, + int n_dims, + int n_ctx_orig, + float freq_base, + float freq_scale, + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow, + int4 sections +) { + src0 = (global void*)((global char*)src0 + offset0); + src1 = (global int*)((global char*)src1 + offset1); + src2 = (global float*)((global char*)src2 + offset2); + dst = (global float*)((global char*)dst + offsetd); + + int i3 = get_group_id(2); + int i2 = get_group_id(1); + int i1 = get_group_id(0); + + float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow); + + global int * pos = src1; + + const int sect_dims = sections.s0 + sections.s1; + const int sec_w = sections.s1 + sections.s0; + + float inv_ndims = -1.f/n_dims; + + for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) { + int ic = i0/2; + + const int sector = (i0/2) % sect_dims; + float theta_base = 0.0f; + + if (sector < sections.s0) { + const int p = sector; + theta_base = pos[i2] * pow(freq_base, inv_ndims*2.0f*p); + } else if (sector >= sections.s0 && sector < sec_w) { + const int p = sector - sections.s0; + theta_base = pos[i2 + ne2] * pow(freq_base, inv_ndims*2.0f*p); + } + + const float freq_factor = src2 != src0 ? 
src2[ic] : 1.0f; + + float2 cos_sin_theta = rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor); + + global half * src = (global half *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00); + global half * dst_data = (global half *)((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0); + + const float x0 = src[0]; + const float x1 = src[n_dims]; + + dst_data[0] = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1; + dst_data[n_dims] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0; + } +} diff --git a/ggml/src/ggml-opencl/kernels/scale.cl b/ggml/src/ggml-opencl/kernels/scale.cl new file mode 100644 index 0000000000..8cfd518fa5 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/scale.cl @@ -0,0 +1,16 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +//------------------------------------------------------------------------------ +// scale +//------------------------------------------------------------------------------ +kernel void kernel_scale( + global float4 * src0, + ulong offset0, + global float4 * dst, + ulong offsetd, + float scale +) { + src0 = (global float4*)((global char*)src0 + offset0); + dst = (global float4*)((global char*)dst + offsetd); + dst[get_global_id(0)] = src0[get_global_id(0)] * scale; +} diff --git a/ggml/src/ggml-opencl/kernels/silu.cl b/ggml/src/ggml-opencl/kernels/silu.cl new file mode 100644 index 0000000000..1d95e1b50f --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/silu.cl @@ -0,0 +1,30 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +//------------------------------------------------------------------------------ +// silu +//------------------------------------------------------------------------------ +kernel void kernel_silu( + global float * src0, + ulong offset0, + global float * dst, + ulong offsetd +) { + src0 = (global float*)((global char*)src0 + offset0); + dst = (global float*)((global char*)dst + offsetd); + + float x = src0[get_global_id(0)]; + dst[get_global_id(0)] = x / (1.0f + exp(-x)); +} + +kernel void kernel_silu_4( + global float4 * src0, + ulong offset0, + global float4 * dst, + ulong offsetd +) { + src0 = (global float4*)((global char*)src0 + offset0); + dst = (global float4*)((global char*)dst + offsetd); + + float4 x = src0[get_global_id(0)]; + dst[get_global_id(0)] = x / (1.0f + exp(-x)); +} diff --git a/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl b/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl new file mode 100644 index 0000000000..62c05369a8 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/softmax_4_f16.cl @@ -0,0 +1,87 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#ifdef ADRENO_GPU +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_soft_max_4_f16( + global float * src0, + ulong offset0, + global half * src1, + ulong offset1, + global float * dst, + ulong 
offsetd, + int ne00, + int ne01, + int ne02, + float scale, + float max_bias, + float m0, + float m1, + int n_head_log2 +) { + src0 = (global float *)((global char *)src0 + offset0); + src1 = (global half *)((global char *)src1 + offset1); + dst = (global float *)((global char *)dst + offsetd); + + int i03 = get_group_id(2); + int i02 = get_group_id(1); + int i01 = get_group_id(0); + + global float4 * psrc4 = (global float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); + global half4 * pmask = (global char *)src1 != (global char *)src0 ? (global half4 *)(src1 + i01*ne00) : 0; + global float4 * pdst4 = (global float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); + + float slope = 1.0f; + + // ALiBi + if (max_bias > 0.0f) { + int h = i02; + + float base = h < n_head_log2 ? m0 : m1; + int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; + + slope = pow(base, exp); + } + + // parallel max + float4 lmax4 = -INFINITY; + for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) { + lmax4 = fmax(lmax4, psrc4[i00]*scale + slope*(pmask ? convert_float4(pmask[i00]) : 0.0f)); + } + float lmax = fmax(fmax(lmax4.s0, lmax4.s1), fmax(lmax4.s2, lmax4.s3)); + + const float max = sub_group_reduce_max(lmax); + + // parallel sum + float4 lsum4 = 0.0f; + for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) { + const float4 exp_psrc4 = exp((psrc4[i00]*scale + slope*(pmask ? convert_float4(pmask[i00]) : 0.0f)) - max); + lsum4 += exp_psrc4; + pdst4[i00] = exp_psrc4; + } + float lsum = lsum4.s0 + lsum4.s1 + lsum4.s2 + lsum4.s3; + + const float sum = sub_group_reduce_add(lsum); + + for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) { + pdst4[i00] /= sum; + } +} diff --git a/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl b/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl new file mode 100644 index 0000000000..d562774eab --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/softmax_4_f32.cl @@ -0,0 +1,87 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#ifdef ADRENO_GPU +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_soft_max_4( + global float * src0, + ulong offset0, + global float * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + float scale, + float max_bias, + float m0, + float m1, + int n_head_log2 +) { + src0 = (global float*)((global char*)src0 + offset0); + src1 = (global float*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + int i03 = get_group_id(2); + int i02 = get_group_id(1); + int i01 = get_group_id(0); + + global float4 * psrc4 = (global float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); + global float4 * pmask = src1 != src0 ? 
(global float4 *)(src1 + i01*ne00) : 0; + global float4 * pdst4 = (global float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); + + float slope = 1.0f; + + // ALiBi + if (max_bias > 0.0f) { + int h = i02; + + float base = h < n_head_log2 ? m0 : m1; + int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; + + slope = pow(base, exp); + } + + // parallel max + float4 lmax4 = -INFINITY; + for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) { + lmax4 = fmax(lmax4, psrc4[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)); + } + float lmax = fmax(fmax(lmax4.s0, lmax4.s1), fmax(lmax4.s2, lmax4.s3)); + + const float max = sub_group_reduce_max(lmax); + + // parallel sum + float4 lsum4 = 0.0f; + for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) { + const float4 exp_psrc4 = exp((psrc4[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)) - max); + lsum4 += exp_psrc4; + pdst4[i00] = exp_psrc4; + } + float lsum = lsum4.s0 + lsum4.s1 + lsum4.s2 + lsum4.s3; + + const float sum = sub_group_reduce_add(lsum); + + for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) { + pdst4[i00] /= sum; + } +} diff --git a/ggml/src/ggml-opencl/kernels/softmax_f16.cl b/ggml/src/ggml-opencl/kernels/softmax_f16.cl new file mode 100644 index 0000000000..d38d099671 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/softmax_f16.cl @@ -0,0 +1,86 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#ifdef ADRENO_GPU +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_soft_max_f16( + global float * src0, + ulong offset0, + global half * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + float scale, + float max_bias, + float m0, + float m1, + int n_head_log2 +) { + src0 = (global float *)((global char *)src0 + offset0); + src1 = (global half *)((global char *)src1 + offset1); + dst = (global float *)((global char *)dst + offsetd); + + int i03 = get_group_id(2); + int i02 = get_group_id(1); + int i01 = get_group_id(0); + + global float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + global half * pmask = (global char *)src1 != (global char *)src0 ? src1 + i01*ne00 : 0; + global float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + + float slope = 1.0f; + + // ALiBi + if (max_bias > 0.0f) { + int h = i02; + + float base = h < n_head_log2 ? m0 : m1; + int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; + + slope = pow(base, exp); + } + + // parallel max + float lmax = -INFINITY; + for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { + lmax = fmax(lmax, psrc0[i00]*scale + (pmask ? 
slope*pmask[i00] : 0.0f)); + } + float max = sub_group_reduce_max(lmax); + + // parallel sum + float lsum = 0.0f; + for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { + float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)) - max); + lsum += exp_psrc0; + // Remember the result of exp here. exp is expensive, so we really do not + // wish to compute it twice. + pdst[i00] = exp_psrc0; + } + + const float sum = sub_group_reduce_add(lsum); + + for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { + pdst[i00] /= sum; + } +} diff --git a/ggml/src/ggml-opencl/kernels/softmax_f32.cl b/ggml/src/ggml-opencl/kernels/softmax_f32.cl new file mode 100644 index 0000000000..001b587abe --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/softmax_f32.cl @@ -0,0 +1,86 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#ifdef ADRENO_GPU +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_soft_max( + global float * src0, + ulong offset0, + global float * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + float scale, + float max_bias, + float m0, + float m1, + int n_head_log2 +) { + src0 = (global float*)((global char*)src0 + offset0); + src1 = (global float*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + int i03 = get_group_id(2); + int i02 = get_group_id(1); + int i01 = get_group_id(0); + + global float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + global float * pmask = src1 != src0 ? src1 + i01*ne00 : 0; + global float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + + float slope = 1.0f; + + // ALiBi + if (max_bias > 0.0f) { + int h = i02; + + float base = h < n_head_log2 ? m0 : m1; + int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; + + slope = pow(base, exp); + } + + // parallel max + float lmax = -INFINITY; + for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { + lmax = fmax(lmax, psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)); + } + float max = sub_group_reduce_max(lmax); + + // parallel sum + float lsum = 0.0f; + for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { + float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)) - max); + lsum += exp_psrc0; + // Remember the result of exp here. exp is expensive, so we really do not + // wish to compute it twice. 
+ pdst[i00] = exp_psrc0; + } + + const float sum = sub_group_reduce_add(lsum); + + for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { + pdst[i00] /= sum; + } +} diff --git a/ggml/src/ggml-opencl/kernels/transpose.cl b/ggml/src/ggml-opencl/kernels/transpose.cl new file mode 100644 index 0000000000..a11490b304 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/transpose.cl @@ -0,0 +1,84 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +// 16-bit transpose, loading/storing a 4x4 tile of elements +kernel void kernel_transpose_16( + __read_only image1d_buffer_t input, + __write_only image1d_buffer_t output, + const uint rows, + const uint cols +) { + + const int i = get_global_id(0); + const int j = get_global_id(1); + const int i_2 = i<<2; + const int j_2 = j<<2; + + half4 temp0 = read_imageh(input, (j_2+0)*cols+i); + half4 temp1 = read_imageh(input, (j_2+1)*cols+i); + half4 temp2 = read_imageh(input, (j_2+2)*cols+i); + half4 temp3 = read_imageh(input, (j_2+3)*cols+i); + + write_imageh(output, (i_2+0)*rows+j, (half4)(temp0.s0, temp1.s0, temp2.s0, temp3.s0)); + write_imageh(output, (i_2+1)*rows+j, (half4)(temp0.s1, temp1.s1, temp2.s1, temp3.s1)); + write_imageh(output, (i_2+2)*rows+j, (half4)(temp0.s2, temp1.s2, temp2.s2, temp3.s2)); + write_imageh(output, (i_2+3)*rows+j, (half4)(temp0.s3, temp1.s3, temp2.s3, temp3.s3)); +} + +// 32-bit transpose, loading/storing a 4x4 tile of elements +kernel void kernel_transpose_32( + __read_only image1d_buffer_t input, + __write_only image1d_buffer_t output, + const uint rows, + const uint cols +) { + + const int i = get_global_id(0); + const int j = get_global_id(1); + const int i_2 = i<<2; + const int j_2 = j<<2; + + float4 temp0 = read_imagef(input, (j_2+0)*cols+i); + float4 temp1 = read_imagef(input, (j_2+1)*cols+i); + float4 temp2 = read_imagef(input, (j_2+2)*cols+i); + float4 temp3 = read_imagef(input, (j_2+3)*cols+i); + + write_imagef(output, (i_2+0)*rows+j, (float4)(temp0.s0, temp1.s0, temp2.s0, temp3.s0)); + write_imagef(output, (i_2+1)*rows+j, (float4)(temp0.s1, temp1.s1, temp2.s1, temp3.s1)); + write_imagef(output, (i_2+2)*rows+j, (float4)(temp0.s2, temp1.s2, temp2.s2, temp3.s2)); + write_imagef(output, (i_2+3)*rows+j, (float4)(temp0.s3, temp1.s3, temp2.s3, temp3.s3)); + +} + +// 32-bit transpose, loading/storing a 4x4 tile of elements +// Only used for activations +// converts to FP16 +// also adds zero padding for non multiple of 8 prompt lengths +kernel void kernel_transpose_32_16(__read_only image1d_buffer_t input, __write_only image1d_buffer_t output, const uint rows, const uint cols, const uint padded_rows) { + + const int i = get_global_id(0); + const int j = get_global_id(1); + const int i_2 = i<<2; + const int j_2 = j<<2; + half4 temp0 = {0,0,0,0}; // initialize outputs to 0 + half4 temp1 = {0,0,0,0}; + half4 temp2 = {0,0,0,0}; + half4 temp3 = {0,0,0,0}; + + if((j_2+0)*cols+i*4+3 < rows*cols*16){ // only load from a valid location. 
Otherwise keep register data as 0 + temp0 = read_imageh(input, (j_2+0)*cols+i); + } + if((j_2+1)*cols+i*4+3 < rows*cols*16){ + temp1 = read_imageh(input, (j_2+1)*cols+i); + } + if((j_2+2)*cols+i*4+3 < rows*cols*16){ + temp2 = read_imageh(input, (j_2+2)*cols+i); + } + if((j_2+3)*cols+i*4+3 < rows*cols*16){ + temp3 = read_imageh(input, (j_2+3)*cols+i); + } + + write_imageh(output, (i_2+0)*padded_rows+j, (half4)(temp0.s0, temp1.s0, temp2.s0, temp3.s0)); // no conditionals for output, includes zero padding + write_imageh(output, (i_2+1)*padded_rows+j, (half4)(temp0.s1, temp1.s1, temp2.s1, temp3.s1)); + write_imageh(output, (i_2+2)*padded_rows+j, (half4)(temp0.s2, temp1.s2, temp2.s2, temp3.s2)); + write_imageh(output, (i_2+3)*padded_rows+j, (half4)(temp0.s3, temp1.s3, temp2.s3, temp3.s3)); +} From b43d89e311c5e7fbf62e5ec3c0401eb536677267 Mon Sep 17 00:00:00 2001 From: Chenguang Li <757486878@qq.com> Date: Wed, 16 Apr 2025 16:21:05 +0800 Subject: [PATCH 02/39] CANN: Add 310P operator support check (#12962) --- ggml/src/ggml-cann/aclnn_ops.cpp | 8 ++++++++ ggml/src/ggml-cann/ggml-cann.cpp | 10 ++++++++++ 2 files changed, 18 insertions(+) diff --git a/ggml/src/ggml-cann/aclnn_ops.cpp b/ggml/src/ggml-cann/aclnn_ops.cpp index 2c5cdcae32..2c6737ea8c 100644 --- a/ggml/src/ggml-cann/aclnn_ops.cpp +++ b/ggml/src/ggml-cann/aclnn_ops.cpp @@ -625,6 +625,10 @@ static void ggml_cann_avg_pool2d(ggml_backend_cann_context& ctx, bool count_include_pad = true; int64_t divisor_override = 0; int8_t cube_math_type = 0; +#ifdef ASCEND_310P + cube_math_type = 1; +#endif + GGML_CANN_CALL_ACLNN_OP(AvgPool2d, acl_src, kernel_size, strides, paddings_avg, ceil_mode, count_include_pad, divisor_override, cube_math_type, acl_dst); @@ -2590,6 +2594,10 @@ void ggml_cann_conv_transpose_1d(ggml_backend_cann_context& ctx, ggml_tensor* ds int64_t groups = 1; int8_t cubeMathType = 0; +#ifdef ASCEND_310P + cubeMathType = 1; +#endif + GGML_CANN_CALL_ACLNN_OP(Convolution, acl_input, acl_weight, nullptr, stride, padding, dilation, transposed, padding, groups, acl_dst, cubeMathType); diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp index 08b9ca301c..ca41e02607 100644 --- a/ggml/src/ggml-cann/ggml-cann.cpp +++ b/ggml/src/ggml-cann/ggml-cann.cpp @@ -2022,6 +2022,10 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, return true; case GGML_TYPE_Q8_0: case GGML_TYPE_Q4_0: +#ifdef ASCEND_310P + // Q4 && Q8 per group is not supported on 310p device + return false; +#endif // only support contiguous for quantized types. return ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]); @@ -2107,6 +2111,12 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, } case GGML_OP_POOL_2D: { const int32_t * opts = (const int32_t *) op->op_params; +#ifdef ASCEND_310P + enum ggml_op_pool opt = static_cast<ggml_op_pool>(opts[0]); + if(opt == GGML_OP_POOL_MAX){ + return false; + } +#endif const int k0 = opts[1]; const int k1 = opts[2]; const int p0 = opts[5]; From 015022bb53387baa8b23817ac03743705c7d472b Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Wed, 16 Apr 2025 13:37:25 -0500 Subject: [PATCH 03/39] vulkan: enable coopmat2 FA gqa and split_k optimizations more often (#12931) The grouped query attention optimization doesn't require a power-of-two ratio; the only thing relying on it was the modulo operation written as bitwise &. split_k need not depend on gqa_ratio - enable it any time there's only one workgroup in the X dimension.
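(As a side note, here is a made-up numeric example of why the bitwise form requires a power-of-two ratio; this snippet is illustrative only and not part of the patch:)

```
#include <assert.h>
int main(void) {
    // "r & (g - 1)" matches "r % g" only when g is a power of two:
    assert((5 % 4) == (5 & (4 - 1)));   // g = 4 (pow2): both equal 1
    assert((5 % 3) != (5 & (3 - 1)));   // g = 3 (not pow2): 2 vs 0
    return 0;
}
```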
The shader gets the split index from the x coord, and multiple workgroups in the X dimension (pre-split) indicates a larger FA operation that wouldn't need splitting. --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 6 +++--- ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp | 2 +- tests/test-backend-ops.cpp | 4 +++- 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 783a0ff86c..0e9b2e8135 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -5531,7 +5531,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx uint32_t workgroups_y = (uint32_t)neq2; uint32_t workgroups_z = (uint32_t)neq3; - if (N == 1 && qk_ratio > 1 && is_pow2(qk_ratio) && gqa_ratio <= flash_attention_num_small_rows && + if (N == 1 && qk_ratio > 1 && gqa_ratio <= flash_attention_num_small_rows && qk_ratio * nek2 == neq2 && nek2 == nev2 && neq3 == 1 && nek3 == 1 && nev3 == 1) { // grouped query attention - make the N dimension equal to gqa_ratio, reduce // workgroups proportionally in y dimension. The shader will detect gqa_ratio > 1 @@ -5544,8 +5544,8 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx uint32_t split_kv = KV; uint32_t split_k = 1; - if (gqa_ratio > 1 && ctx->device->shader_core_count > 0) { - GGML_ASSERT(workgroups_x == 1); + // Try to use split_k when KV is large enough to be worth the overhead + if (workgroups_x == 1 && ctx->device->shader_core_count > 0 && KV >= 512) { // Try to run two workgroups per SM. split_k = ctx->device->shader_core_count * 2 / workgroups_y; if (split_k > 1) { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp index e1baa85f9e..b926a578ad 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp @@ -131,7 +131,7 @@ ACC_TYPE perElemOpStoreCol0(const in uint32_t r, const in uint32_t c, const in A // Load the slope matrix, indexed by Q's dimension 2. ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t iq2) { - const uint32_t h = iq2 + (r & (p.gqa_ratio - 1)); + const uint32_t h = iq2 + (r % p.gqa_ratio); const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1); const int exph = int(h < p.n_head_log2 ? 
h + 1 : 2*(h - p.n_head_log2) + 1); diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 3a5741c8d9..1ee7428946 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -4532,7 +4532,9 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() { for (int kv : { 4096, 8192, 16384, }) { for (int hs : { 64, 128, }) { - test_cases.emplace_back(new test_flash_attn_ext(hs, hs, 8, 4, kv, 1, true, 0, 0, GGML_PREC_F32, GGML_TYPE_F16)); + for (int nr : { 1, 4, }) { + test_cases.emplace_back(new test_flash_attn_ext(hs, hs, 8, nr, kv, 1, true, 0, 0, GGML_PREC_F32, GGML_TYPE_F16)); + } } } From 12b17501e6015ffe568ac54fdf08e6580833bf1b Mon Sep 17 00:00:00 2001 From: kimminsu <80271594+kimminsu38oo@users.noreply.github.com> Date: Thu, 17 Apr 2025 06:25:57 +0900 Subject: [PATCH 04/39] opencl: fix incorrect local_size index in profiling log (#12868) --- ggml/src/ggml-opencl/ggml-opencl.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index fa40abc33e..05a2f4e630 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -1521,7 +1521,7 @@ static void ggml_cl2_free(void) { info.cmd_complete_duration_ns/1.e6f, info.cmd_total_duration_ns/1.e6f, info.global_size[0], info.global_size[1], info.global_size[2], - info.local_size[0], info.local_size[2], info.local_size[2], + info.local_size[0], info.local_size[1], info.local_size[2], info.output_size[0], info.output_size[1], info.output_size[2], info.output_size[3]); } fclose(fperf); From 971f245b3b5f3f55991bb779cb541b00f82eea1d Mon Sep 17 00:00:00 2001 From: Mikko Juola Date: Thu, 17 Apr 2025 01:37:05 -0700 Subject: [PATCH 05/39] llama : recognize IBM Granite 3.3 FIM tokens (#12988) Granite's FIM tokens are very similar to Qwen's; it's just that they use an underscore instead of a dash. So for example <fim_prefix> instead of <fim-prefix>. Opening up tokenizer_config.json in ibm-granite/granite-3.3-8b-base shows: ``` "<fim_prefix>", "<fim_middle>", "<fim_suffix>", "<fim_pad>", ... "<reponame>", ``` --- src/llama-vocab.cpp | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp index 464ff01e06..480605173d 100644 --- a/src/llama-vocab.cpp +++ b/src/llama-vocab.cpp @@ -1841,6 +1841,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
                 if (false
                         || t.first == "<|fim_prefix|>" // Qwen
                         || t.first == "<fim-prefix>"
+                        || t.first == "<fim_prefix>"   // Granite
                         || t.first == "<|fim▁begin|>" // DeepSeek
                         || t.first == "<PRE>"
                         || t.first == "▁<PRE>"          // CodeLlama
@@ -1859,6 +1860,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
                 if (false
                         || t.first == "<|fim_suffix|>" // Qwen
                         || t.first == "<fim-suffix>"
+                        || t.first == "<fim_suffix>"   // Granite
                         || t.first == "<|fim▁hole|>" // DeepSeek
                         || t.first == "<SUF>"
                         || t.first == "▁<SUF>"         // CodeLlama
@@ -1877,6 +1879,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
                 if (false
                         || t.first == "<|fim_middle|>" // Qwen
                         || t.first == "<fim-middle>"
+                        || t.first == "<fim_middle>"   // Granite
                         || t.first == "<|fim▁end|>"  // DeepSeek
                         || t.first == "<MID>"
                         || t.first == "▁<MID>"         // CodeLlama
@@ -1895,6 +1898,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
                 if (false
                         || t.first == "<|fim_pad|>" // Qwen
                         || t.first == "<fim-pad>"
+                        || t.first == "<fim_pad>"   // Granite
                         || t.first == "<PAD>"
                         ) {
                     special_fim_pad_id = t.second;
@@ -1913,6 +1917,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
                         || t.first == "<|repo_name|>"
                         || t.first == "<fim-repo>"
                         || t.first == "<REPO>"
+                        || t.first == "<reponame>"    // Granite
                         ) {
                     special_fim_rep_id = t.second;
                     if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {

From 7a395f67a7a02bb361d944b816d6e933889e28e1 Mon Sep 17 00:00:00 2001
From: hipudding 
Date: Thu, 17 Apr 2025 20:34:16 +0800
Subject: [PATCH 06/39] CANN: Add support for async operator submission
 (#12864)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Submit operators using asynchronous threads to improve performance.

Use the environment variable GGML_CANN_ASYNC_MODE to control whether
asynchronous submission is enabled. It is disabled by default.

Testing shows a 10%–20% performance improvement in scenarios with
small parameter sizes, especially in quantized models.
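
For example, to opt in for a single run (the binary name and flags below are illustrative only, not prescribed by this patch):

```
GGML_CANN_ASYNC_MODE=1 ./build/bin/llama-cli -m model.gguf -p "hello"
```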
---
 ggml/src/ggml-cann/aclnn_ops.cpp | 440 ++++++++++++-------------------
 ggml/src/ggml-cann/aclnn_ops.h   | 324 +++++++++++++++++++----
 ggml/src/ggml-cann/common.h      | 136 +++++++++-
 ggml/src/ggml-cann/ggml-cann.cpp |  60 ++---
 4 files changed, 604 insertions(+), 356 deletions(-)

diff --git a/ggml/src/ggml-cann/aclnn_ops.cpp b/ggml/src/ggml-cann/aclnn_ops.cpp
index 2c6737ea8c..67c0223c01 100644
--- a/ggml/src/ggml-cann/aclnn_ops.cpp
+++ b/ggml/src/ggml-cann/aclnn_ops.cpp
@@ -103,9 +103,7 @@ void ggml_cann_unary_op(
     aclTensor* acl_dst = ggml_cann_create_tensor(dst);
 
     unary_op(ctx, acl_src, acl_dst);
-
-    ACL_CHECK(aclDestroyTensor(acl_src));
-    ACL_CHECK(aclDestroyTensor(acl_dst));
+    ggml_cann_release_resources(ctx, acl_src, acl_dst);
 }
 
 /**
@@ -123,8 +121,8 @@ static void aclnn_repeat(ggml_backend_cann_context& ctx, aclTensor* acl_src,
     // repeat tensor along each dim with repeat_array
     aclIntArray* repeats = aclCreateIntArray(repeat_array, GGML_MAX_DIMS);
 
-    GGML_CANN_CALL_ACLNN_OP(Repeat, acl_src, repeats, acl_dst);
-    ACL_CHECK(aclDestroyIntArray(repeats));
+    GGML_CANN_CALL_ACLNN_OP(ctx, Repeat, acl_src, repeats, acl_dst);
+    ggml_cann_release_resources(ctx, repeats);
 }
 
 /**
@@ -142,7 +140,7 @@ static void aclnn_repeat(ggml_backend_cann_context& ctx, aclTensor* acl_src,
  */
 static void aclnn_cast(ggml_backend_cann_context& ctx, aclTensor* acl_src,
     aclTensor* acl_dst, aclDataType cast_data_type) {
-    GGML_CANN_CALL_ACLNN_OP(Cast, acl_src, cast_data_type, acl_dst);
+    GGML_CANN_CALL_ACLNN_OP(ctx, Cast, acl_src, cast_data_type, acl_dst);
 }
 
 void ggml_cann_repeat(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
@@ -156,8 +154,7 @@ void ggml_cann_repeat(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
                               dst->ne[1] / src->ne[1], dst->ne[0] / src->ne[0]};
 
     aclnn_repeat(ctx, acl_src, acl_dst, repeatsArray);
-    ACL_CHECK(aclDestroyTensor(acl_src));
-    ACL_CHECK(aclDestroyTensor(acl_dst));
+    ggml_cann_release_resources(ctx, acl_src, acl_dst);
 }
 
 void aclnn_add(ggml_backend_cann_context& ctx, aclTensor* acl_src0,
@@ -165,10 +162,10 @@ void aclnn_add(ggml_backend_cann_context& ctx, aclTensor* acl_src0,
     float alphaValue = 1.0f;
     aclScalar* alpha = aclCreateScalar(&alphaValue, aclDataType::ACL_FLOAT);
     if (acl_dst != nullptr)
-        GGML_CANN_CALL_ACLNN_OP(Add, acl_src0, acl_src1, alpha, acl_dst);
+        GGML_CANN_CALL_ACLNN_OP(ctx, Add, acl_src0, acl_src1, alpha, acl_dst);
     else
-        GGML_CANN_CALL_ACLNN_OP(InplaceAdd, acl_src0, acl_src1, alpha);
-    ACL_CHECK(aclDestroyScalar(alpha));
+        GGML_CANN_CALL_ACLNN_OP(ctx, InplaceAdd, acl_src0, acl_src1, alpha);
+    ggml_cann_release_resources(ctx, alpha);
 }
 
 void aclnn_sub(ggml_backend_cann_context& ctx, aclTensor* acl_src0,
@@ -176,26 +173,26 @@ void aclnn_sub(ggml_backend_cann_context& ctx, aclTensor* acl_src0,
     float alphaValue = 1.0f;
     aclScalar* alpha = aclCreateScalar(&alphaValue, aclDataType::ACL_FLOAT);
     if (acl_dst != nullptr)
-        GGML_CANN_CALL_ACLNN_OP(Sub, acl_src0, acl_src1, alpha, acl_dst);
+        GGML_CANN_CALL_ACLNN_OP(ctx, Sub, acl_src0, acl_src1, alpha, acl_dst);
     else
-        GGML_CANN_CALL_ACLNN_OP(InplaceSub, acl_src0, acl_src1, alpha);
-    ACL_CHECK(aclDestroyScalar(alpha));
+        GGML_CANN_CALL_ACLNN_OP(ctx, InplaceSub, acl_src0, acl_src1, alpha);
+    ggml_cann_release_resources(ctx, alpha);
 }
 
 void aclnn_mul(ggml_backend_cann_context& ctx, aclTensor* acl_src,
     aclTensor* acl_other, aclTensor* acl_dst) {
     if (acl_dst != nullptr)
-        GGML_CANN_CALL_ACLNN_OP(Mul, acl_src, acl_other, acl_dst);
+        GGML_CANN_CALL_ACLNN_OP(ctx, Mul, acl_src, acl_other, acl_dst);
     else
-        GGML_CANN_CALL_ACLNN_OP(InplaceMul, acl_src, acl_other);
+        GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMul, acl_src, acl_other);
 }
 
 void aclnn_div(ggml_backend_cann_context& ctx, aclTensor* acl_src,
     aclTensor* acl_other, aclTensor* acl_dst) {
     if (acl_dst != nullptr)
-        GGML_CANN_CALL_ACLNN_OP(Div, acl_src, acl_other, acl_dst);
+        GGML_CANN_CALL_ACLNN_OP(ctx, Div, acl_src, acl_other, acl_dst);
     else
-        GGML_CANN_CALL_ACLNN_OP(InplaceDiv, acl_src, acl_other);
+        GGML_CANN_CALL_ACLNN_OP(ctx, InplaceDiv, acl_src, acl_other);
 }
 
 /**
@@ -224,11 +221,11 @@ static void aclnn_muls(ggml_backend_cann_context& ctx, aclTensor* acl_src,
     float scale, aclTensor* acl_dst, bool inplace) {
     aclScalar* acl_scale = aclCreateScalar(&scale, aclDataType::ACL_FLOAT);
     if (inplace) {
-        GGML_CANN_CALL_ACLNN_OP(InplaceMuls, acl_src, acl_scale);
+        GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMuls, acl_src, acl_scale);
     } else {
-        GGML_CANN_CALL_ACLNN_OP(Muls, acl_src, acl_scale, acl_dst);
+        GGML_CANN_CALL_ACLNN_OP(ctx, Muls, acl_src, acl_scale, acl_dst);
     }
-    ACL_CHECK(aclDestroyScalar(acl_scale));
+    ggml_cann_release_resources(ctx, acl_scale);
 }
 
 void ggml_cann_leaky_relu(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
@@ -245,11 +242,8 @@ void ggml_cann_leaky_relu(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     aclScalar* acl_negative_slope =
         aclCreateScalar(&negative_slope, aclDataType::ACL_FLOAT);
 
-    GGML_CANN_CALL_ACLNN_OP(LeakyRelu, acl_src, acl_negative_slope, acl_dst);
-
-    ACL_CHECK(aclDestroyScalar(acl_negative_slope));
-    ACL_CHECK(aclDestroyTensor(acl_src));
-    ACL_CHECK(aclDestroyTensor(acl_dst));
+    GGML_CANN_CALL_ACLNN_OP(ctx, LeakyRelu, acl_src, acl_negative_slope, acl_dst);
+    ggml_cann_release_resources(ctx, acl_negative_slope, acl_src, acl_dst);
 }
 
 /**
@@ -265,7 +259,7 @@ void ggml_cann_leaky_relu(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
 static void aclnn_concat(ggml_backend_cann_context& ctx,
                          aclTensorList* tensorList, aclTensor* acl_dst,
                          int64_t concat_dim) {
-    GGML_CANN_CALL_ACLNN_OP(Cat, tensorList, concat_dim, acl_dst);
+    GGML_CANN_CALL_ACLNN_OP(ctx, Cat, tensorList, concat_dim, acl_dst);
 }
 
 void ggml_cann_concat(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
@@ -281,11 +275,10 @@ void ggml_cann_concat(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     int32_t acl_dim = 3 - dim;
 
     aclTensor* tensors[] = {acl_src0, acl_src1};
-    aclTensorList* tensorList = aclCreateTensorList(tensors, 2);
-    aclnn_concat(ctx, tensorList, acl_dst, acl_dim);
+    aclTensorList* tensor_list = aclCreateTensorList(tensors, 2);
+    aclnn_concat(ctx, tensor_list, acl_dst, acl_dim);
 
-    ACL_CHECK(aclDestroyTensorList(tensorList));
-    ACL_CHECK(aclDestroyTensor(acl_dst));
+    ggml_cann_release_resources(ctx, tensor_list, acl_dst);
 }
 
 /**
@@ -315,10 +308,8 @@ static void aclnn_arange(ggml_backend_cann_context& ctx, aclTensor* acl_dst,
     aclScalar* acl_end = aclCreateScalar(&stop, aclDataType::ACL_FLOAT);
     aclScalar* acl_step = aclCreateScalar(&step, aclDataType::ACL_FLOAT);
 
-    GGML_CANN_CALL_ACLNN_OP(Arange, acl_start, acl_end, acl_step, acl_dst);
-    ACL_CHECK(aclDestroyScalar(acl_start));
-    ACL_CHECK(aclDestroyScalar(acl_end));
-    ACL_CHECK(aclDestroyScalar(acl_step));
+    GGML_CANN_CALL_ACLNN_OP(ctx, Arange, acl_start, acl_end, acl_step, acl_dst);
+    ggml_cann_release_resources(ctx, acl_start, acl_end, acl_step);
 }
 
 void ggml_cann_arange(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
@@ -335,7 +326,7 @@ void ggml_cann_arange(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     memcpy(&step, (float*)dst->op_params + 2, sizeof(float));
 
     aclnn_arange(ctx, acl_dst, start, stop, step, n_elements);
-    ACL_CHECK(aclDestroyTensor(acl_dst));
+    ggml_cann_release_resources(ctx, acl_dst);
 }
 
 void ggml_cann_clamp(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
@@ -352,11 +343,8 @@ void ggml_cann_clamp(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     aclScalar* acl_min = aclCreateScalar(&min, aclDataType::ACL_FLOAT);
     aclScalar* acl_max = aclCreateScalar(&max, aclDataType::ACL_FLOAT);
 
-    GGML_CANN_CALL_ACLNN_OP(Clamp, acl_src, acl_min, acl_max, acl_dst);
-    ACL_CHECK(aclDestroyScalar(acl_min));
-    ACL_CHECK(aclDestroyScalar(acl_max));
-    ACL_CHECK(aclDestroyTensor(acl_src));
-    ACL_CHECK(aclDestroyTensor(acl_dst));
+    GGML_CANN_CALL_ACLNN_OP(ctx, Clamp, acl_src, acl_min, acl_max, acl_dst);
+    ggml_cann_release_resources(ctx, acl_min, acl_max, acl_src, acl_dst);
 }
 
 void ggml_cann_scale(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
@@ -370,10 +358,8 @@ void ggml_cann_scale(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     aclTensor* acl_src = ggml_cann_create_tensor(src);
     aclTensor* acl_dst = ggml_cann_create_tensor(dst);
 
-    GGML_CANN_CALL_ACLNN_OP(Muls, acl_src, scale, acl_dst);
-    ACL_CHECK(aclDestroyScalar(scale));
-    ACL_CHECK(aclDestroyTensor(acl_src));
-    ACL_CHECK(aclDestroyTensor(acl_dst));
+    GGML_CANN_CALL_ACLNN_OP(ctx, Muls, acl_src, scale, acl_dst);
+    ggml_cann_release_resources(ctx, scale, acl_src, acl_dst);
 }
 
 void ggml_cann_argsort(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
@@ -388,12 +374,10 @@ void ggml_cann_argsort(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     aclTensor* tmp_tensor =
         ggml_cann_create_tensor(buffer, ACL_INT64, ggml_type_size(dst->type),
                                 dst->ne, dst->nb, GGML_MAX_DIMS);
-    GGML_CANN_CALL_ACLNN_OP(Argsort, acl_src, -1, (order == GGML_SORT_ORDER_DESC ? true : false),
+    GGML_CANN_CALL_ACLNN_OP(ctx, Argsort, acl_src, -1, (order == GGML_SORT_ORDER_DESC ? true : false),
                       tmp_tensor);
-    GGML_CANN_CALL_ACLNN_OP(Cast, tmp_tensor, ggml_cann_type_mapping(dst->type), acl_dst);
-    ACL_CHECK(aclDestroyTensor(acl_src));
-    ACL_CHECK(aclDestroyTensor(tmp_tensor));
-    ACL_CHECK(aclDestroyTensor(acl_dst));
+    GGML_CANN_CALL_ACLNN_OP(ctx, Cast, tmp_tensor, ggml_cann_type_mapping(dst->type), acl_dst);
+    ggml_cann_release_resources(ctx, acl_src, tmp_tensor, acl_dst);
 }
 
 void ggml_cann_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
@@ -407,11 +391,9 @@ void ggml_cann_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
 
     std::vector<int64_t> normData = {dst->ne[0]};
     aclIntArray* norm = aclCreateIntArray(normData.data(), normData.size());
-    GGML_CANN_CALL_ACLNN_OP(LayerNorm, acl_src, norm, nullptr, nullptr,
+    GGML_CANN_CALL_ACLNN_OP(ctx, LayerNorm, acl_src, norm, nullptr, nullptr,
                     eps, acl_dst, nullptr, nullptr);
-    ACL_CHECK(aclDestroyIntArray(norm));
-    ACL_CHECK(aclDestroyTensor(acl_src));
-    ACL_CHECK(aclDestroyTensor(acl_dst));
+    ggml_cann_release_resources(ctx, norm, acl_src, acl_dst);
 }
 
 void ggml_cann_group_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
@@ -441,12 +423,9 @@ void ggml_cann_group_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     aclTensor* acl_rstd_out = ggml_cann_create_tensor(
         (char*)buffer + n_bytes, ACL_FLOAT, type_size, ne, nb, ACL_FORMAT_ND);
 
-    GGML_CANN_CALL_ACLNN_OP(GroupNorm, acl_src, nullptr, nullptr, N, C, HxW, n_groups, eps,
+    GGML_CANN_CALL_ACLNN_OP(ctx, GroupNorm, acl_src, nullptr, nullptr, N, C, HxW, n_groups, eps,
         acl_dst, acl_mean_out, acl_rstd_out);
-    ACL_CHECK(aclDestroyTensor(acl_src));
-    ACL_CHECK(aclDestroyTensor(acl_dst));
-    ACL_CHECK(aclDestroyTensor(acl_mean_out));
-    ACL_CHECK(aclDestroyTensor(acl_rstd_out));
+    ggml_cann_release_resources(ctx, acl_src, acl_dst, acl_mean_out, acl_rstd_out);
 }
 
 void ggml_cann_acc(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
@@ -471,19 +450,17 @@ void ggml_cann_acc(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
 
     if (!inplace) {
         size_t cpy_size = ggml_nbytes(dst);
-        ACL_CHECK(aclrtMemcpyAsync(dst->data, cpy_size, src0->data, cpy_size,
-                                   ACL_MEMCPY_DEVICE_TO_DEVICE, ctx.stream()));
+        ggml_cann_async_memcpy(ctx, dst->data, src0->data, cpy_size,
+            ACL_MEMCPY_DEVICE_TO_DEVICE);
         aclTensor* acl_src0 = ggml_cann_create_tensor(
             src0, src1->ne, src0->nb, GGML_MAX_DIMS, ACL_FORMAT_ND, offset);
 
-        GGML_CANN_CALL_ACLNN_OP(Add, acl_src0, acl_src1, alpha, acl_dst);
-        ACL_CHECK(aclDestroyTensor(acl_src0));
+        GGML_CANN_CALL_ACLNN_OP(ctx, Add, acl_src0, acl_src1, alpha, acl_dst);
+        ggml_cann_release_resources(ctx, acl_src0);
     } else {
-        GGML_CANN_CALL_ACLNN_OP(InplaceAdd, acl_dst, acl_src1, alpha);
+        GGML_CANN_CALL_ACLNN_OP(ctx, InplaceAdd, acl_dst, acl_src1, alpha);
     }
-
-    ACL_CHECK(aclDestroyTensor(acl_src1));
-    ACL_CHECK(aclDestroyTensor(acl_dst));
+    ggml_cann_release_resources(ctx, acl_src1, acl_dst);
 }
 
 /**
@@ -496,7 +473,6 @@ void ggml_cann_acc(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
  * @param dim An array of dimension indices.
  * @param dim_size The number of dimensions.
  */
-
 static void aclnn_reduce_sum(ggml_backend_cann_context& ctx, ggml_tensor* dst,
                              int64_t* dim, size_t dim_size) {
     GGML_ASSERT(dst->ne[0] == 1);
@@ -505,11 +481,9 @@ static void aclnn_reduce_sum(ggml_backend_cann_context& ctx, ggml_tensor* dst,
     aclTensor* acl_dst = ggml_cann_create_tensor(dst);
     aclIntArray* reduce_dims = aclCreateIntArray(dim, dim_size);
 
-    GGML_CANN_CALL_ACLNN_OP(ReduceSum, acl_src, reduce_dims, true,
+    GGML_CANN_CALL_ACLNN_OP(ctx, ReduceSum, acl_src, reduce_dims, true,
                       ggml_cann_type_mapping(dst->type), acl_dst);
-    ACL_CHECK(aclDestroyTensor(acl_src));
-    ACL_CHECK(aclDestroyTensor(acl_dst));
-    ACL_CHECK(aclDestroyIntArray(reduce_dims));
+    ggml_cann_release_resources(ctx, acl_src, acl_dst, reduce_dims);
 }
 
 void ggml_cann_sum_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
@@ -533,10 +507,8 @@ void ggml_cann_upsample_nearest2d(ggml_backend_cann_context& ctx,
     std::vector<int64_t> output_size{dst->ne[1], dst->ne[0]};
     auto output_size_array = aclCreateIntArray(output_size.data(), 2);
 
-    GGML_CANN_CALL_ACLNN_OP(UpsampleNearest2d, acl_src, output_size_array, acl_dst);
-    ACL_CHECK(aclDestroyIntArray(output_size_array));
-    ACL_CHECK(aclDestroyTensor(acl_src));
-    ACL_CHECK(aclDestroyTensor(acl_dst));
+    GGML_CANN_CALL_ACLNN_OP(ctx, UpsampleNearest2d, acl_src, output_size_array, acl_dst);
+    ggml_cann_release_resources(ctx, acl_src, acl_dst, output_size_array);
 }
 
 /**
@@ -559,9 +531,8 @@ static void aclnn_pad(ggml_backend_cann_context& ctx, aclTensor* acl_src,
     aclIntArray* acl_pad = aclCreateIntArray(paddings, GGML_MAX_DIMS * 2);
     aclScalar* acl_value = aclCreateScalar(&value, aclDataType::ACL_FLOAT);
 
-    GGML_CANN_CALL_ACLNN_OP(ConstantPadNd, acl_src, acl_pad, acl_value, acl_dst);
-    ACL_CHECK(aclDestroyIntArray(acl_pad));
-    ACL_CHECK(aclDestroyScalar(acl_value));
+    GGML_CANN_CALL_ACLNN_OP(ctx, ConstantPadNd, acl_src, acl_pad, acl_value, acl_dst);
+    ggml_cann_release_resources(ctx, acl_pad, acl_value);
 }
 
 void ggml_cann_pad(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
@@ -577,9 +548,7 @@ void ggml_cann_pad(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
         0, dst->ne[0] - src->ne[0], 0, dst->ne[1] - src->ne[1],
         0, dst->ne[2] - src->ne[2], 0, dst->ne[3] - src->ne[3]};
     aclnn_pad(ctx, acl_src, acl_dst, paddings);
-
-    ACL_CHECK(aclDestroyTensor(acl_dst));
-    ACL_CHECK(aclDestroyTensor(acl_src));
+    ggml_cann_release_resources(ctx, acl_src, acl_dst);
 }
 
 /**
@@ -629,14 +598,11 @@ static void ggml_cann_avg_pool2d(ggml_backend_cann_context& ctx,
     cube_math_type = 1;
 #endif
 
-    GGML_CANN_CALL_ACLNN_OP(AvgPool2d, acl_src, kernel_size, strides, paddings_avg,
+    GGML_CANN_CALL_ACLNN_OP(ctx, AvgPool2d, acl_src, kernel_size, strides, paddings_avg,
                     ceil_mode, count_include_pad, divisor_override,
                     cube_math_type, acl_dst);
-    ACL_CHECK(aclDestroyTensor(acl_src));
-    ACL_CHECK(aclDestroyTensor(acl_dst));
-    ACL_CHECK(aclDestroyIntArray(kernel_size));
-    ACL_CHECK(aclDestroyIntArray(strides));
-    ACL_CHECK(aclDestroyIntArray(paddings_avg));
+    ggml_cann_release_resources(ctx, acl_src, acl_dst, kernel_size, strides,
+                                paddings_avg);
 }
 
 /**
@@ -704,15 +670,10 @@ static void ggml_cann_max_pool2d(ggml_backend_cann_context& ctx,
 
     bool ceil_mode = false;
     int64_t auto_pads = 0;
-    GGML_CANN_CALL_ACLNN_OP(MaxPool, tmp_tensor, kernel_size, strides, auto_pads,
+    GGML_CANN_CALL_ACLNN_OP(ctx, MaxPool, tmp_tensor, kernel_size, strides, auto_pads,
                     paddings_max, dilations, ceil_mode, acl_dst);
-    ACL_CHECK(aclDestroyTensor(acl_src));
-    ACL_CHECK(aclDestroyTensor(acl_dst));
-    ACL_CHECK(aclDestroyTensor(tmp_tensor));
-    ACL_CHECK(aclDestroyIntArray(kernel_size));
-    ACL_CHECK(aclDestroyIntArray(strides));
-    ACL_CHECK(aclDestroyIntArray(paddings_max));
-    ACL_CHECK(aclDestroyIntArray(dilations));
+    ggml_cann_release_resources(ctx, acl_src, acl_dst, tmp_tensor, kernel_size,
+                                strides, paddings_max, dilations);
 }
 
 void ggml_cann_pool2d(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
@@ -743,7 +704,7 @@ void ggml_cann_pool2d(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
  */
 static void cann_copy(ggml_backend_cann_context& ctx, aclTensor* acl_src,
                       aclTensor* acl_dst) {
-    GGML_CANN_CALL_ACLNN_OP(InplaceCopy, acl_dst, acl_src);
+    GGML_CANN_CALL_ACLNN_OP(ctx, InplaceCopy, acl_dst, acl_src);
 }
 
 void ggml_cann_dup(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
@@ -761,9 +722,8 @@ void ggml_cann_dup(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
         if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst)) {
             if (dst->type == src0->type) {
                 size_t cpy_size = ggml_nbytes(dst);
-                ACL_CHECK(aclrtMemcpyAsync(
-                    dst->data, cpy_size, src0->data, cpy_size,
-                    ACL_MEMCPY_DEVICE_TO_DEVICE, ctx.stream()));
+                ggml_cann_async_memcpy(ctx, dst->data, src0->data, cpy_size,
+                    ACL_MEMCPY_DEVICE_TO_DEVICE);
                 return;
             } else {
                 ggml_cann_pool_alloc src_buffer_allocator(
@@ -782,10 +742,9 @@ void ggml_cann_dup(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
 
                 aclnn_cast(ctx, acl_src, src_trans_tensor, ggml_cann_type_mapping(dst->type));
                 size_t cpy_size = ggml_nbytes(dst);
-                ACL_CHECK(aclrtMemcpyAsync(
-                    dst->data, cpy_size, src_trans_buffer, cpy_size,
-                    ACL_MEMCPY_DEVICE_TO_DEVICE, ctx.stream()));
-                ACL_CHECK(aclDestroyTensor(src_trans_tensor));
+                ggml_cann_async_memcpy(ctx, dst->data, src_trans_buffer, cpy_size,
+                    ACL_MEMCPY_DEVICE_TO_DEVICE);
+                ggml_cann_release_resources(ctx, src_trans_tensor);
                 return;
             }
         } else if (ggml_is_contiguous(dst)) {
@@ -805,18 +764,15 @@ void ggml_cann_dup(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
             aclnn_cast(ctx, acl_src, src_trans_tensor, ggml_cann_type_mapping(dst->type));
 
             size_t cpy_size = ggml_nbytes(dst);
-            ACL_CHECK(aclrtMemcpyAsync(dst->data, cpy_size, src_trans_buffer,
-                                       cpy_size, ACL_MEMCPY_DEVICE_TO_DEVICE,
-                                       ctx.stream()));
-            ACL_CHECK(aclDestroyTensor(src_trans_tensor));
+            ggml_cann_async_memcpy(ctx, dst->data, src_trans_buffer, cpy_size,
+                ACL_MEMCPY_DEVICE_TO_DEVICE);
+            ggml_cann_release_resources(ctx, src_trans_tensor);
             return;
         } else {
             GGML_ABORT("Unsupport dst is not tontiguous.");
         }
     }
-
-    ACL_CHECK(aclDestroyTensor(acl_src));
-    ACL_CHECK(aclDestroyTensor(acl_dst));
+    ggml_cann_release_resources(ctx, acl_src, acl_dst);
 }
 
 /**
@@ -844,7 +800,7 @@ static aclTensor* aclnn_zero(ggml_backend_cann_context& ctx, void* buffer,
         nb[i] = nb[i - 1] * ne[i - 1];
     }
 
-    ACL_CHECK(aclrtMemsetAsync(buffer, n_bytes, 0, n_bytes, ctx.stream()));
+    ggml_cann_async_memset(ctx, buffer, n_bytes, 0);
     aclTensor* zero =
         ggml_cann_create_tensor(buffer, type, type_size, ne, nb, dims);
     return zero;
@@ -877,7 +833,7 @@ static aclTensor* aclnn_values(ggml_backend_cann_context& ctx, void* buffer,
     float alpha_host = 1.0f;
     aclScalar* alpha = aclCreateScalar(&alpha_host, aclDataType::ACL_FLOAT);
     aclScalar* other = aclCreateScalar(&value, aclDataType::ACL_FLOAT);
-    GGML_CANN_CALL_ACLNN_OP(InplaceAdds, acl_tensor, other, alpha);
+    GGML_CANN_CALL_ACLNN_OP(ctx, InplaceAdds, acl_tensor, other, alpha);
     return acl_tensor;
 }
 
@@ -903,11 +859,8 @@ void ggml_cann_rms_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
         aclnn_zero(ctx, zero_tensor_allocator.get(), zero_tensor_n_bytes,
                    src->ne, GGML_MAX_DIMS, ggml_cann_type_mapping(src->type),
                    ggml_element_size(src));
-    GGML_CANN_CALL_ACLNN_OP(RmsNorm, acl_src, acl_gamma, eps, acl_dst, acl_rstd);
-    ACL_CHECK(aclDestroyTensor(acl_src));
-    ACL_CHECK(aclDestroyTensor(acl_dst));
-    ACL_CHECK(aclDestroyTensor(acl_gamma));
-    ACL_CHECK(aclDestroyTensor(acl_rstd));
+    GGML_CANN_CALL_ACLNN_OP(ctx, RmsNorm, acl_src, acl_gamma, eps, acl_dst, acl_rstd);
+    ggml_cann_release_resources(ctx, acl_src, acl_dst, acl_gamma, acl_rstd);
 }
 
 // TODO: performance is low.
@@ -933,13 +886,10 @@ void ggml_cann_diag_mask(ggml_backend_cann_context& ctx, ggml_tensor* dst,
     float alphaValue = 1.0f;
     alpha = aclCreateScalar(&alphaValue, aclDataType::ACL_FLOAT);
 
-    GGML_CANN_CALL_ACLNN_OP(InplaceTriu, mask_tensor, n_past + 1);
-    GGML_CANN_CALL_ACLNN_OP(Tril, acl_src, n_past + 1, acl_dst);
-    GGML_CANN_CALL_ACLNN_OP(InplaceAdd, acl_dst, mask_tensor, alpha);
-    ACL_CHECK(aclDestroyScalar(alpha));
-    ACL_CHECK(aclDestroyTensor(mask_tensor));
-    ACL_CHECK(aclDestroyTensor(acl_src));
-    ACL_CHECK(aclDestroyTensor(acl_dst));
+    GGML_CANN_CALL_ACLNN_OP(ctx, InplaceTriu, mask_tensor, n_past + 1);
+    GGML_CANN_CALL_ACLNN_OP(ctx, Tril, acl_src, n_past + 1, acl_dst);
+    GGML_CANN_CALL_ACLNN_OP(ctx, InplaceAdd, acl_dst, mask_tensor, alpha);
+    ggml_cann_release_resources(ctx, alpha, acl_src, acl_dst, mask_tensor);
 }
 
 /**
@@ -960,7 +910,8 @@ void ggml_cann_diag_mask(ggml_backend_cann_context& ctx, ggml_tensor* dst,
 static void aclnn_permute(ggml_backend_cann_context& ctx, aclTensor* acl_src,
                           aclTensor* acl_dst, int64_t* new_dim, uint64_t dims) {
     aclIntArray* acl_dims = aclCreateIntArray(new_dim, dims);
-    GGML_CANN_CALL_ACLNN_OP(Permute, acl_src, acl_dims, acl_dst);
+    GGML_CANN_CALL_ACLNN_OP(ctx, Permute, acl_src, acl_dims, acl_dst);
+    ggml_cann_release_resources(ctx, acl_dims);
 }
 
 static void ggml_cann_im2col_2d_post_process(ggml_backend_cann_context& ctx,
@@ -981,8 +932,7 @@ static void ggml_cann_im2col_2d_post_process(ggml_backend_cann_context& ctx,
         aclnn_permute(ctx, tmp_im2col_tensor, acl_dst, permute_dim, 3);
     }
 
-    // release
-    ACL_CHECK(aclDestroyTensor(acl_dst));
+    ggml_cann_release_resources(ctx, acl_dst);
 }
 
 static void ggml_cann_im2col_1d_post_process(
@@ -1004,7 +954,6 @@ static void ggml_cann_im2col_1d_post_process(
 
     // Permute: [N, IC * KH * KW, OW * OH] ->
     // [N, OW * OH * n_bytes_factor, IC * KH * KW]
-    aclTensor* tmp_permute_tensor = nullptr;
     ggml_cann_pool_alloc tmp_permute_allocator(ctx.pool());
     tmp_permute_allocator.alloc(ggml_nbytes(dst) * n_bytes_factor);
     void* tmp_permute_buffer = tmp_permute_allocator.get();
@@ -1016,7 +965,7 @@ static void ggml_cann_im2col_1d_post_process(
         tmp_permute_nb[i] = tmp_permute_nb[i - 1] * tmp_permute_ne[i - 1];
     }
 
-    tmp_permute_tensor = ggml_cann_create_tensor(
+    aclTensor* tmp_permute_tensor = ggml_cann_create_tensor(
         tmp_permute_buffer, ggml_cann_type_mapping(dst->type),
         ggml_type_size(dst->type), tmp_permute_ne, tmp_permute_nb,
         GGML_MAX_DIMS - 1, ACL_FORMAT_ND);
@@ -1046,9 +995,8 @@ static void ggml_cann_im2col_1d_post_process(
                              c * KH * KW * n_step_w * ggml_type_size(dst->type);
 
             for (int i = 0; i < n_step_w; i++) {
-                ACL_CHECK(aclrtMemcpyAsync(
-                    cur_dst_buffer, size_cpy, cur_permute_buffer, size_cpy,
-                    ACL_MEMCPY_DEVICE_TO_DEVICE, ctx.stream()));
+                ggml_cann_async_memcpy(ctx, cur_dst_buffer, cur_permute_buffer, size_cpy,
+                    ACL_MEMCPY_DEVICE_TO_DEVICE);
                 cur_dst_buffer =
                     (char*)cur_dst_buffer + KH * KW * ggml_type_size(dst->type);
                 cur_permute_buffer = (char*)cur_permute_buffer +
@@ -1058,13 +1006,11 @@ static void ggml_cann_im2col_1d_post_process(
     } else {
         offset = KH * KW * n_step_w *
                  ggml_type_size(dst->type);  // equal to ggml_nbytes(dst)
-        ACL_CHECK(aclrtMemcpyAsync(dst->data, offset,
-                                   (char*)tmp_permute_buffer + offset, offset,
-                                   ACL_MEMCPY_DEVICE_TO_DEVICE, ctx.stream()));
+        ggml_cann_async_memcpy(ctx, dst->data, (char*)tmp_permute_buffer + offset, offset,
+            ACL_MEMCPY_DEVICE_TO_DEVICE);
     }
 
-    // release
-    ACL_CHECK(aclDestroyTensor(tmp_permute_tensor));
+    ggml_cann_release_resources(ctx, tmp_permute_tensor);
 }
 
 void ggml_cann_im2col(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
@@ -1126,7 +1072,7 @@ void ggml_cann_im2col(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     auto* dilations = aclCreateIntArray(dilation_size.data(), 2);
     auto* paddings = aclCreateIntArray(padding_dims.data(), 2);
     auto* strides = aclCreateIntArray(stride_dims.data(), 2);
-    GGML_CANN_CALL_ACLNN_OP(Im2col, acl_src1, kernel_size, dilations,
+    GGML_CANN_CALL_ACLNN_OP(ctx, Im2col, acl_src1, kernel_size, dilations,
                     paddings, strides, tmp_im2col_tensor);
 
     // Cast if dst is f16.
@@ -1160,14 +1106,8 @@ void ggml_cann_im2col(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
                                          tmp_im2col_tensor, im2col_op_params);
     }
 
-    // release
-    ACL_CHECK(aclDestroyTensor(acl_src1));
-    ACL_CHECK(aclDestroyTensor(tmp_im2col_tensor));
-    ACL_CHECK(aclDestroyTensor(tmp_cast_tensor));
-    ACL_CHECK(aclDestroyIntArray(kernel_size));
-    ACL_CHECK(aclDestroyIntArray(dilations));
-    ACL_CHECK(aclDestroyIntArray(paddings));
-    ACL_CHECK(aclDestroyIntArray(strides));
+    ggml_cann_release_resources(ctx, acl_src1, tmp_im2col_tensor, tmp_cast_tensor,
+        kernel_size, dilations, paddings, strides);
 }
 
 /**
@@ -1184,17 +1124,17 @@ void ggml_cann_im2col(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
  * @param acl_src The tensor on which the exponential function will be applied.
  */
 static void aclnn_exp(ggml_backend_cann_context& ctx, aclTensor* acl_src) {
-    GGML_CANN_CALL_ACLNN_OP(InplaceExp, acl_src);
+    GGML_CANN_CALL_ACLNN_OP(ctx, InplaceExp, acl_src);
 }
 
 void aclnn_cos(ggml_backend_cann_context& ctx, aclTensor* acl_src,
                       aclTensor* acl_dst) {
-    GGML_CANN_CALL_ACLNN_OP(Cos, acl_src, acl_dst);
+    GGML_CANN_CALL_ACLNN_OP(ctx, Cos, acl_src, acl_dst);
 }
 
 void aclnn_sin(ggml_backend_cann_context& ctx, aclTensor* acl_src,
                       aclTensor* acl_dst) {
-    GGML_CANN_CALL_ACLNN_OP(Sin, acl_src, acl_dst);
+    GGML_CANN_CALL_ACLNN_OP(ctx, Sin, acl_src, acl_dst);
 }
 
 void ggml_cann_timestep_embedding(ggml_backend_cann_context& ctx,
@@ -1243,13 +1183,13 @@ void ggml_cann_timestep_embedding(ggml_backend_cann_context& ctx,
 
     ggml_cann_pool_alloc permute_allocator(ctx.pool(), ggml_nbytes(src));
     void* tmp_permute_buffer = permute_allocator.get();
-    aclTensor* tmp_permute_tenosr = ggml_cann_create_tensor(
+    aclTensor* tmp_permute_tensor = ggml_cann_create_tensor(
         tmp_permute_buffer, ggml_cann_type_mapping(src->type),
         ggml_type_size(src->type), tmp_permute_ne, tmp_permute_nb,
         GGML_MAX_DIMS, ACL_FORMAT_ND);
     int64_t permute_dim[] = {0, 1, 3, 2};
     int64_t num_dims = 4;
-    aclnn_permute(ctx, acl_src, tmp_permute_tenosr, permute_dim, num_dims);
+    aclnn_permute(ctx, acl_src, tmp_permute_tensor, permute_dim, num_dims);
 
     // timestep * freq
     int64_t tmp_mul_ne[] = {src->ne[1] * half, src->ne[0], src->ne[2],
@@ -1270,7 +1210,7 @@ void ggml_cann_timestep_embedding(ggml_backend_cann_context& ctx,
         tmp_mul_buffer, ggml_cann_type_mapping(src->type),
         ggml_type_size(src->type), tmp_mul_ne, tmp_mul_nb, GGML_MAX_DIMS,
         ACL_FORMAT_ND);
-    aclnn_mul(ctx, tmp_permute_tenosr, tmp_arange_tensor, tmp_mul_tensor);
+    aclnn_mul(ctx, tmp_permute_tensor, tmp_arange_tensor, tmp_mul_tensor);
 
     // cos
     ggml_cann_pool_alloc cos_allocator(
@@ -1298,17 +1238,13 @@ void ggml_cann_timestep_embedding(ggml_backend_cann_context& ctx,
     int64_t concat_dim = 3;
     aclTensor* acl_dst = ggml_cann_create_tensor(dst);
     aclTensor* tensors[] = {tmp_cos_tensor, tmp_sin_tensor};
-    aclTensorList* tensorList = aclCreateTensorList(tensors, 2);
-    aclnn_concat(ctx, tensorList, acl_dst, concat_dim);
+    aclTensorList* tensor_list = aclCreateTensorList(tensors, 2);
+    aclnn_concat(ctx, tensor_list, acl_dst, concat_dim);
 
     // release
     // segmentation fault when deleting both tensor_list and its elements.
-    ACL_CHECK(aclDestroyTensorList(tensorList));
-    ACL_CHECK(aclDestroyTensor(acl_src));
-    ACL_CHECK(aclDestroyTensor(tmp_arange_tensor));
-    ACL_CHECK(aclDestroyTensor(tmp_permute_tenosr));
-    ACL_CHECK(aclDestroyTensor(tmp_mul_tensor));
-    ACL_CHECK(aclDestroyTensor(acl_dst));
+    ggml_cann_release_resources(ctx, tensor_list, acl_src, tmp_arange_tensor,
+        tmp_permute_tensor, tmp_mul_tensor, acl_dst);
 }
 
 /**
@@ -1324,8 +1260,8 @@ void ggml_cann_timestep_embedding(ggml_backend_cann_context& ctx,
 static void aclnn_fill_scalar(ggml_backend_cann_context& ctx, float scalar,
                               aclTensor* acl_dst) {
     auto acl_scalar = aclCreateScalar(&scalar, aclDataType::ACL_FLOAT);
-    GGML_CANN_CALL_ACLNN_OP(InplaceFillScalar, acl_dst, acl_scalar);
-    ACL_CHECK(aclDestroyScalar(acl_scalar));
+    GGML_CANN_CALL_ACLNN_OP(ctx, InplaceFillScalar, acl_dst, acl_scalar);
+    ggml_cann_release_resources(ctx, acl_scalar);
 }
 
 /**
@@ -1346,7 +1282,7 @@ static void aclnn_fill_scalar(ggml_backend_cann_context& ctx, float scalar,
  */
 static void aclnn_pow_tensor_tensor(ggml_backend_cann_context& ctx,
                                     aclTensor* acl_dst, aclTensor* acl_exp) {
-    GGML_CANN_CALL_ACLNN_OP(InplacePowTensorTensor, acl_dst, acl_exp);
+    GGML_CANN_CALL_ACLNN_OP(ctx, InplacePowTensorTensor, acl_dst, acl_exp);
 }
 
 /**
@@ -1498,15 +1434,9 @@ static void aclnn_alibi(ggml_backend_cann_context& ctx, aclTensor* acl_src,
 
     // add
     aclnn_add(ctx, tmp_output_tensor, acl_src, acl_dst);
-
-    ACL_CHECK(aclDestroyTensor(tmp_arange1_tensor));
-    ACL_CHECK(aclDestroyTensor(tmp_arange2_tensor));
-    ACL_CHECK(aclDestroyTensor(tmp_mk_base1_tensor));
-    ACL_CHECK(aclDestroyTensor(tmp_mk_base2_tensor));
-    ACL_CHECK(aclDestroyTensor(tmp_mk_base_tensor));
-    ACL_CHECK(aclDestroyTensor(tmp_arange_tensor));
-    ACL_CHECK(aclDestroyTensor(tmp_mk_tensor));
-    ACL_CHECK(aclDestroyTensor(tmp_output_tensor));
+    ggml_cann_release_resources(ctx, tmp_arange1_tensor, tmp_arange2_tensor,
+        tmp_mk_base1_tensor, tmp_mk_base2_tensor, tmp_mk_base_tensor,
+        tmp_arange_tensor, tmp_mk_tensor, tmp_output_tensor);
 }
 
 void ggml_cann_cpy(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
@@ -1529,7 +1459,7 @@ void ggml_cann_cpy(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
  */
 static void aclnn_softmax(ggml_backend_cann_context& ctx, aclTensor* acl_src,
                           int64_t dim, aclTensor* acl_dst) {
-    GGML_CANN_CALL_ACLNN_OP(Softmax, acl_src, dim, acl_dst);
+    GGML_CANN_CALL_ACLNN_OP(ctx, Softmax, acl_src, dim, acl_dst);
 }
 
 void ggml_cann_softmax(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
@@ -1579,8 +1509,7 @@ void ggml_cann_softmax(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
                 src1_fp32_nb, GGML_MAX_DIMS);
             aclTensor* acl_src1 = ggml_cann_create_tensor(src1);
             aclnn_cast(ctx, acl_src1, acl_src1_fp32_tensor, ACL_FLOAT);
-
-            ACL_CHECK(aclDestroyTensor(acl_src1));
+            ggml_cann_release_resources(ctx, acl_src1);
         } else {
             acl_src1_fp32_tensor = ggml_cann_create_tensor(src1);
         }
@@ -1633,17 +1562,13 @@ void ggml_cann_softmax(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
 
         // softmax
         aclnn_softmax(ctx, alibi_output_tensor, 3, acl_dst);
-        ACL_CHECK(aclDestroyTensor(alibi_output_tensor));
+        ggml_cann_release_resources(ctx, alibi_output_tensor);
     } else {
         aclnn_softmax(ctx, acl_input_mul_scale_tensor, 3, acl_dst);
     }
 
-    ACL_CHECK(aclDestroyTensor(acl_src0));
-    ACL_CHECK(aclDestroyTensor(acl_src1_fp32_tensor));
-    ACL_CHECK(aclDestroyTensor(acl_dst));
-    ACL_CHECK(aclDestroyScalar(acl_scale));
-    ACL_CHECK(aclDestroyTensor(acl_input_mul_scale_tensor));
-    ACL_CHECK(aclDestroyTensor(tmp_mask_tensor));
+    ggml_cann_release_resources(ctx, acl_src0, acl_src1_fp32_tensor, acl_dst,
+        acl_scale, acl_input_mul_scale_tensor, tmp_mask_tensor);
 }
 
 /**
@@ -1690,10 +1615,8 @@ static void aclnn_embedding_4d(ggml_backend_cann_context& ctx, void* src_buffer,
                 (char*)dst->data + i * dst->nb[3] + j * dst->nb[2],
                 ggml_cann_type_mapping(dst->type), ggml_element_size(dst),
                 acl_out_ne, acl_out_nb, 2);
-            GGML_CANN_CALL_ACLNN_OP(Embedding, acl_src_tensor, acl_index, acl_out);
-            ACL_CHECK(aclDestroyTensor(acl_src_tensor));
-            ACL_CHECK(aclDestroyTensor(acl_index));
-            ACL_CHECK(aclDestroyTensor(acl_out));
+            GGML_CANN_CALL_ACLNN_OP(ctx, Embedding, acl_src_tensor, acl_index, acl_out);
+            ggml_cann_release_resources(ctx, acl_src_tensor, acl_index, acl_out);
         }
     }
 }
@@ -1724,8 +1647,7 @@ void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
             aclnn_cast(ctx, acl_src0, src_trans_tensor, ggml_cann_type_mapping(dst->type));
             aclnn_embedding_4d(ctx, src_trans_buffer, src0->ne,
                                    src_trans_nb, src1, dst);
-            ACL_CHECK(aclDestroyTensor(acl_src0));
-            ACL_CHECK(aclDestroyTensor(src_trans_tensor));
+            ggml_cann_release_resources(ctx, acl_src0, src_trans_tensor);
             break;
         }
         case GGML_TYPE_Q8_0: {
@@ -1787,7 +1709,7 @@ void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
             aclnn_embedding_4d(ctx, dequant_buffer_allocator.get(),
                                    dequant_ne, dequant_nb, src1, dst);
 
-            ACL_CHECK(aclDestroyTensor(dequant_tensor));
+            ggml_cann_release_resources(ctx, dequant_tensor);
             break;
         }
         default:
@@ -1815,7 +1737,7 @@ static void aclnn_repeat_interleave(ggml_backend_cann_context& ctx,
                                     aclTensor* acl_src, aclTensor* acl_dst,
                                     int64_t dim, int64_t repeats,
                                     int64_t output_size) {
-    GGML_CANN_CALL_ACLNN_OP(RepeatInterleaveIntWithDim, acl_src, repeats, dim,
+    GGML_CANN_CALL_ACLNN_OP(ctx, RepeatInterleaveIntWithDim, acl_src, repeats, dim,
                   output_size, acl_dst);
 }
 
@@ -1864,21 +1786,19 @@ static void ggml_cann_mat_mul_fp(ggml_backend_cann_context& ctx,
 
     switch (n_dims) {
         case 2:
-            GGML_CANN_CALL_ACLNN_OP(Mm, acl_input_tensor, acl_weight_tensor, acl_dst, 2);
+            GGML_CANN_CALL_ACLNN_OP(ctx, Mm, acl_input_tensor, acl_weight_tensor, acl_dst, 2);
             break;
         case 3:
-            GGML_CANN_CALL_ACLNN_OP(BatchMatMul, acl_input_tensor, acl_weight_tensor, acl_dst, 2);
+            GGML_CANN_CALL_ACLNN_OP(ctx, BatchMatMul, acl_input_tensor, acl_weight_tensor, acl_dst, 2);
             break;
         default:
             // ALLOW_FP32_DOWN_PRECISION, when input is
             // fp32, atlas a2 will transpose it to HFLOAT32.
-            GGML_CANN_CALL_ACLNN_OP(Matmul, acl_input_tensor, acl_weight_tensor, acl_dst, 1);
+            GGML_CANN_CALL_ACLNN_OP(ctx, Matmul, acl_input_tensor, acl_weight_tensor, acl_dst, 1);
             break;
     }
 
-    ACL_CHECK(aclDestroyTensor(acl_weight_tensor));
-    ACL_CHECK(aclDestroyTensor(acl_input_tensor));
-    ACL_CHECK(aclDestroyTensor(acl_dst));
+    ggml_cann_release_resources(ctx, acl_weight_tensor, acl_input_tensor, acl_dst);
 }
 
 /**
@@ -1948,9 +1868,7 @@ static void ggml_cann_mul_mat_quant(ggml_backend_cann_context& ctx,
             input_buffer, ACL_FLOAT16, input_elem_size, input_cast_ne,
             input_cast_nb, GGML_MAX_DIMS);
         aclnn_cast(ctx, acl_src1_tensor, acl_input_tensor, ACL_FLOAT16);
-
-        ACL_CHECK(aclDestroyTensor(acl_input_tensor));
-        ACL_CHECK(aclDestroyTensor(acl_src1_tensor));
+        ggml_cann_release_resources(ctx, acl_input_tensor, acl_src1_tensor);
     }
 
     // output
@@ -2003,13 +1921,11 @@ static void ggml_cann_mul_mat_quant(ggml_backend_cann_context& ctx,
             if (src0->ne[0] > QK8_0) {
                 antiquantGroupSize = QK8_0;
             }
-            GGML_CANN_CALL_ACLNN_OP(WeightQuantBatchMatmulV2, acl_input_tensor,
+            GGML_CANN_CALL_ACLNN_OP(ctx, WeightQuantBatchMatmulV2, acl_input_tensor,
                            acl_weight_tensor, acl_scale_tensor, nullptr,
                            nullptr, nullptr, nullptr, antiquantGroupSize,
                            acl_output_tensor);
-            ACL_CHECK(aclDestroyTensor(acl_weight_tensor));
-            ACL_CHECK(aclDestroyTensor(acl_scale_tensor));
-            ACL_CHECK(aclDestroyTensor(acl_output_tensor));
+            ggml_cann_release_resources(ctx, acl_weight_tensor, acl_scale_tensor, acl_output_tensor);
 
             // other splits
             for (int64_t split = 1; split < split_size; split++) {
@@ -2036,16 +1952,14 @@ static void ggml_cann_mul_mat_quant(ggml_backend_cann_context& ctx,
                     (char*)output_buffer + batch1 * output_stride, ACL_FLOAT16,
                     output_elem_size, output_ne, output_nb, 2, ACL_FORMAT_ND,
                     output_ne_offset);
-                GGML_CANN_CALL_ACLNN_OP(WeightQuantBatchMatmulV2, acl_input_tensor,
+                GGML_CANN_CALL_ACLNN_OP(ctx, WeightQuantBatchMatmulV2, acl_input_tensor,
                                    acl_weight_tensor, acl_scale_tensor, nullptr,
                                    nullptr, nullptr, nullptr, antiquantGroupSize,
                                    acl_output_tensor);
-                ACL_CHECK(aclDestroyTensor(acl_weight_tensor));
-                ACL_CHECK(aclDestroyTensor(acl_scale_tensor));
-                ACL_CHECK(aclDestroyTensor(acl_output_tensor));
+                ggml_cann_release_resources(ctx, acl_weight_tensor, acl_scale_tensor, acl_output_tensor);
             }
 
-            ACL_CHECK(aclDestroyTensor(acl_input_tensor));
+            ggml_cann_release_resources(ctx, acl_input_tensor);
         }
     }
 
@@ -2064,8 +1978,7 @@ static void ggml_cann_mul_mat_quant(ggml_backend_cann_context& ctx,
         aclTensor* acl_dst_tensor = ggml_cann_create_tensor(dst);
         aclnn_cast(ctx, acl_output_tensor, acl_dst_tensor, ggml_cann_type_mapping(dst->type));
 
-        ACL_CHECK(aclDestroyTensor(acl_output_tensor));
-        ACL_CHECK(aclDestroyTensor(acl_dst_tensor));
+        ggml_cann_release_resources(ctx, acl_output_tensor, acl_dst_tensor);
     }
 }
 
@@ -2106,9 +2019,8 @@ static void aclnn_roll(ggml_backend_cann_context& ctx, aclTensor* acl_src,
                        aclTensor* acl_dst, int64_t* shifts, int64_t* dims) {
     aclIntArray* acl_shifts = aclCreateIntArray(shifts, 1);
     aclIntArray* acl_dims = aclCreateIntArray(dims, 1);
-    GGML_CANN_CALL_ACLNN_OP(Roll, acl_src, acl_shifts, acl_dims, acl_dst);
-    ACL_CHECK(aclDestroyIntArray(acl_shifts));
-    ACL_CHECK(aclDestroyIntArray(acl_dims));
+    GGML_CANN_CALL_ACLNN_OP(ctx, Roll, acl_src, acl_shifts, acl_dims, acl_dst);
+    ggml_cann_release_resources(ctx, acl_shifts, acl_dims);
 }
 
 /**
@@ -2130,9 +2042,8 @@ static void aclnn_index_fill_tensor(ggml_backend_cann_context& ctx,
                                     float value) {
     aclIntArray* acl_index = aclCreateIntArray(index, index_num);
     aclScalar* acl_value = aclCreateScalar(&value, aclDataType::ACL_FLOAT);
-    GGML_CANN_CALL_ACLNN_OP(InplaceIndexFillTensor, acl_src, dim, acl_index, acl_value);
-    ACL_CHECK(aclDestroyIntArray(acl_index));
-    ACL_CHECK(aclDestroyScalar(acl_value));
+    GGML_CANN_CALL_ACLNN_OP(ctx, InplaceIndexFillTensor, acl_src, dim, acl_index, acl_value);
+    ggml_cann_release_resources(ctx, acl_index, acl_value);
 }
 
 static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
@@ -2169,7 +2080,8 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
 
     // power
     aclScalar* acl_theta_scale = aclCreateScalar(&theta_scale, aclDataType::ACL_FLOAT);
-    GGML_CANN_CALL_ACLNN_OP(PowScalarTensor, acl_theta_scale, acl_theta_scale_tensor, acl_theta_scale_tensor);
+    GGML_CANN_CALL_ACLNN_OP(ctx, PowScalarTensor, acl_theta_scale, acl_theta_scale_tensor,
+                            acl_theta_scale_tensor);
 
     // freq_scale
     if (freq_scale != 1) {
@@ -2182,7 +2094,7 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
             src2->data, ggml_cann_type_mapping(src2->type),
             ggml_type_size(src2->type), theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
         aclnn_div(ctx, acl_theta_scale_tensor, acl_freq_factors_tensor);
-        ACL_CHECK(aclDestroyTensor(acl_freq_factors_tensor));
+        ggml_cann_release_resources(ctx, acl_freq_factors_tensor);
     }
 
     // position
@@ -2251,12 +2163,8 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
     }
 
     // release
-    ACL_CHECK(aclDestroyTensor(acl_theta_scale_tensor));
-    ACL_CHECK(aclDestroyTensor(acl_position_tensor));
-    ACL_CHECK(aclDestroyTensor(acl_theta_tensor));
-    ACL_CHECK(aclDestroyTensor(acl_sin_tensor));
-    ACL_CHECK(aclDestroyTensor(acl_cos_tensor));
-    ACL_CHECK(aclDestroyScalar(acl_theta_scale));
+    ggml_cann_release_resources(ctx, acl_theta_scale_tensor, acl_position_tensor,
+        acl_theta_tensor, acl_sin_tensor, acl_cos_tensor, acl_theta_scale);
 }
 
 #ifdef __cplusplus
@@ -2368,8 +2276,7 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
         int64_t shifts[] = {1};
         int64_t dims[] = {3};
         aclnn_roll(ctx, acl_input_tensor, acl_input_roll_tensor, shifts, dims);
-        ACL_CHECK(aclDestroyTensor(acl_input_roll_tensor));
-        ACL_CHECK(aclDestroyTensor(acl_input_tensor));
+        ggml_cann_release_resources(ctx, acl_input_roll_tensor, acl_input_tensor);
 
         // init [-1, 1, -1, 1, ...]
         minus_one_scale_buffer = minus_one_scale_allocator.get();
@@ -2405,8 +2312,7 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
         int64_t dims[] = {3};
         aclnn_roll(ctx, acl_input_tensor, acl_input_roll_tensor, shifts, dims);
 
-        ACL_CHECK(aclDestroyTensor(acl_input_roll_tensor));
-        ACL_CHECK(aclDestroyTensor(acl_input_tensor));
+        ggml_cann_release_resources(ctx, acl_input_roll_tensor, acl_input_tensor);
         // init [-1, -1, -1, 1, 1,1,...]
         minus_one_scale_buffer = minus_one_scale_allocator.get();
         int64_t minus_one_ne[4] = {src0->ne[0], 1, 1, 1};
@@ -2431,7 +2337,7 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
         bool inplace = true;
         float scale = -1;
         aclnn_muls(ctx, acl_first_half_tensor, scale, nullptr, inplace);
-        ACL_CHECK(aclDestroyTensor(acl_first_half_tensor));
+        ggml_cann_release_resources(ctx, acl_first_half_tensor);
     }
 
     // TODO: n_dims < ne0
@@ -2496,14 +2402,10 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
                   output_fp32_tensor);
         aclnn_cast(ctx, output_fp32_tensor, acl_dst, ACL_FLOAT16);
 
-        ACL_CHECK(aclDestroyTensor(input_fp32_tensor1));
-        ACL_CHECK(aclDestroyTensor(input_fp32_tensor2));
-        ACL_CHECK(aclDestroyTensor(output_fp32_tensor));
-        ACL_CHECK(aclDestroyTensor(acl_sin_reshape_tensor));
-        ACL_CHECK(aclDestroyTensor(acl_minus_one_tensor));
-        ACL_CHECK(aclDestroyTensor(acl_input_roll_mul_scale_tensor));
-        ACL_CHECK(aclDestroyTensor(acl_input_roll_reshape_tensor));
-        ACL_CHECK(aclDestroyTensor(acl_src));
+        ggml_cann_release_resources(ctx, input_fp32_tensor1, input_fp32_tensor2,
+            output_fp32_tensor, acl_sin_reshape_tensor,
+            acl_minus_one_tensor, acl_input_roll_mul_scale_tensor,
+            acl_input_roll_reshape_tensor, acl_src);
     }
     return;
 #endif
@@ -2513,8 +2415,8 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
 
     switch (src0->type) {
         case GGML_TYPE_F32: {
-            GGML_CANN_CALL_ACLNN_OP(RotaryPositionEmbedding, acl_src, acl_cos_reshape_tensor,
-                acl_sin_reshape_tensor, acl_mode, acl_dst);
+            GGML_CANN_CALL_ACLNN_OP(ctx, RotaryPositionEmbedding, acl_src,
+                acl_cos_reshape_tensor, acl_sin_reshape_tensor, acl_mode, acl_dst);
             break;
         }
         case GGML_TYPE_F16: {
@@ -2540,23 +2442,22 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
 
             aclnn_cast(ctx, acl_src, acl_src_trans_tensor, ACL_FLOAT);
 
-            GGML_CANN_CALL_ACLNN_OP(RotaryPositionEmbedding, acl_src_trans_tensor, acl_cos_reshape_tensor,
-                acl_sin_reshape_tensor, acl_mode, acl_dst_trans_tensor);
+            GGML_CANN_CALL_ACLNN_OP(ctx, RotaryPositionEmbedding, acl_src_trans_tensor,
+                acl_cos_reshape_tensor, acl_sin_reshape_tensor, acl_mode,
+                acl_dst_trans_tensor);
 
             aclnn_cast(ctx, acl_dst_trans_tensor, acl_dst, ACL_FLOAT16);
 
-            ACL_CHECK(aclDestroyTensor(acl_src_trans_tensor));
-            ACL_CHECK(aclDestroyTensor(acl_dst_trans_tensor));
+            ggml_cann_release_resources(ctx, acl_src_trans_tensor,
+                acl_dst_trans_tensor);
             break;
         }
         default:
             GGML_ABORT("Unsupported tensor type for GGML_OP_ROPE");
             break;
     }
-    ACL_CHECK(aclDestroyTensor(acl_src));
-    ACL_CHECK(aclDestroyTensor(acl_cos_reshape_tensor));
-    ACL_CHECK(aclDestroyTensor(acl_sin_reshape_tensor));
-    ACL_CHECK(aclDestroyTensor(acl_dst));
+    ggml_cann_release_resources(ctx, acl_cos_reshape_tensor,
+        acl_sin_reshape_tensor, acl_src, acl_dst);
 }
 
 
@@ -2566,10 +2467,9 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     aclTensor* acl_src = ggml_cann_create_tensor(src0);
     aclTensor* acl_dst = ggml_cann_create_tensor(dst, dst->ne, dst->nb, 3);
 
-    GGML_CANN_CALL_ACLNN_OP(ArgMax, acl_src, 3, false, acl_dst);
+    GGML_CANN_CALL_ACLNN_OP(ctx, ArgMax, acl_src, 3, false, acl_dst);
 
-    ACL_CHECK(aclDestroyTensor(acl_src));
-    ACL_CHECK(aclDestroyTensor(acl_dst));
+    ggml_cann_release_resources(ctx, acl_src, acl_dst);
 }
 
 void ggml_cann_conv_transpose_1d(ggml_backend_cann_context& ctx, ggml_tensor* dst){
@@ -2598,14 +2498,10 @@ void ggml_cann_conv_transpose_1d(ggml_backend_cann_context& ctx, ggml_tensor* ds
     cubeMathType = 1;
 #endif
 
-    GGML_CANN_CALL_ACLNN_OP(Convolution, acl_input, acl_weight, nullptr, stride,
+    GGML_CANN_CALL_ACLNN_OP(ctx, Convolution, acl_input, acl_weight, nullptr, stride,
         padding, dilation, transposed, padding, groups, acl_dst, cubeMathType);
 
-    ACL_CHECK(aclDestroyTensor(acl_weight));
-    ACL_CHECK(aclDestroyTensor(acl_dst));
-    ACL_CHECK(aclDestroyIntArray(stride));
-    ACL_CHECK(aclDestroyIntArray(padding));
-    ACL_CHECK(aclDestroyIntArray(dilation));
+    ggml_cann_release_resources(ctx, acl_weight, acl_dst, stride, padding, dilation);
 }
 
 void ggml_cann_elu(ggml_backend_cann_context& ctx, ggml_tensor* dst){
@@ -2618,12 +2514,10 @@ void ggml_cann_elu(ggml_backend_cann_context& ctx, ggml_tensor* dst){
     aclScalar* alpha = nullptr;
     alpha = aclCreateScalar(&alphaValue, aclDataType::ACL_FLOAT);
 
-    GGML_CANN_CALL_ACLNN_OP(Elu, acl_input, alpha, alpha, alpha,
+    GGML_CANN_CALL_ACLNN_OP(ctx, Elu, acl_input, alpha, alpha, alpha,
         acl_dst);
 
-    ACL_CHECK(aclDestroyTensor(acl_input));
-    ACL_CHECK(aclDestroyTensor(acl_dst));
-    ACL_CHECK(aclDestroyScalar(alpha));
+    ggml_cann_release_resources(ctx, acl_input, acl_dst, alpha);
 }
 
 void ggml_cann_mean(ggml_backend_cann_context& ctx, ggml_tensor* dst){
@@ -2636,11 +2530,9 @@ void ggml_cann_mean(ggml_backend_cann_context& ctx, ggml_tensor* dst){
     aclIntArray* reduceDim = aclCreateIntArray(reduceDimValue, 1);
     bool keepDim = true;
 
-    GGML_CANN_CALL_ACLNN_OP(Mean, acl_src, reduceDim, keepDim, ACL_FLOAT, acl_dst);
+    GGML_CANN_CALL_ACLNN_OP(ctx, Mean, acl_src, reduceDim, keepDim, ACL_FLOAT, acl_dst);
 
-    ACL_CHECK(aclDestroyTensor(acl_src));
-    ACL_CHECK(aclDestroyTensor(acl_dst));
-    ACL_CHECK(aclDestroyIntArray(reduceDim));
+    ggml_cann_release_resources(ctx, acl_src, acl_dst, reduceDim);
 }
 
 void ggml_cann_pad_reflect_1d(ggml_backend_cann_context& ctx, ggml_tensor* dst){
@@ -2660,12 +2552,11 @@ void ggml_cann_pad_reflect_1d(ggml_backend_cann_context& ctx, ggml_tensor* dst){
             ggml_cann_type_mapping(dst->type), ggml_element_size(dst),
             dst->ne, dst->nb, 3);
 
-            GGML_CANN_CALL_ACLNN_OP(ReflectionPad1d, acl_src, paddings, acl_dst);
+            GGML_CANN_CALL_ACLNN_OP(ctx, ReflectionPad1d, acl_src, paddings, acl_dst);
 
-            ACL_CHECK(aclDestroyTensor(acl_src));
-            ACL_CHECK(aclDestroyTensor(acl_dst));
+            ggml_cann_release_resources(ctx, acl_src, acl_dst);
     }
-    ACL_CHECK(aclDestroyIntArray(paddings));
+    ggml_cann_release_resources(ctx, paddings);
 }
 
 void ggml_cann_count_equal(ggml_backend_cann_context& ctx, ggml_tensor* dst){
@@ -2675,12 +2566,11 @@ void ggml_cann_count_equal(ggml_backend_cann_context& ctx, ggml_tensor* dst){
     aclTensor* acl_self = ggml_cann_create_tensor(src0);
     aclTensor* acl_other = ggml_cann_create_tensor(src1);
 
-    GGML_CANN_CALL_ACLNN_OP(InplaceEqTensor, acl_self, acl_other);
+    GGML_CANN_CALL_ACLNN_OP(ctx, InplaceEqTensor, acl_self, acl_other);
 
     ggml_cann_sum(ctx, dst);
 
-    ACL_CHECK(aclDestroyTensor(acl_self));
-    ACL_CHECK(aclDestroyTensor(acl_other));
+    ggml_cann_release_resources(ctx, acl_self, acl_other);
 }
 
 void ggml_cann_step(ggml_backend_cann_context& ctx, ggml_tensor* dst){
@@ -2693,9 +2583,7 @@ void ggml_cann_step(ggml_backend_cann_context& ctx, ggml_tensor* dst){
     aclScalar* alpha = nullptr;
     alpha = aclCreateScalar(&alphaValue, aclDataType::ACL_FLOAT);
 
-    GGML_CANN_CALL_ACLNN_OP(GtScalar, acl_src, alpha, acl_dst);
+    GGML_CANN_CALL_ACLNN_OP(ctx, GtScalar, acl_src, alpha, acl_dst);
 
-    ACL_CHECK(aclDestroyTensor(acl_src));
-    ACL_CHECK(aclDestroyTensor(acl_dst));
-    ACL_CHECK(aclDestroyScalar(alpha));
+    ggml_cann_release_resources(ctx, acl_src, acl_dst, alpha);
 }
diff --git a/ggml/src/ggml-cann/aclnn_ops.h b/ggml/src/ggml-cann/aclnn_ops.h
index b2d1b3c36d..462351542e 100644
--- a/ggml/src/ggml-cann/aclnn_ops.h
+++ b/ggml/src/ggml-cann/aclnn_ops.h
@@ -23,6 +23,7 @@
 #ifndef CANN_ACLNN_OPS
 #define CANN_ACLNN_OPS
 
+#include 
 #include 
 #include 
 #include 
@@ -713,6 +714,270 @@ void ggml_cann_count_equal(ggml_backend_cann_context& ctx, ggml_tensor* dst);
  */
 void ggml_cann_step(ggml_backend_cann_context& ctx, ggml_tensor* dst);
 
+/**
+ * @brief A generic wrapper for ACL resources with custom deleter support.
+ */
+using any_acl_resource = std::unique_ptr<void, std::function<void(void*)>>;
+
+/**
+ * @brief Trait structure used to define how to destroy a given ACL resource type.
+ *
+ * @tparam T ACL resource type.
+ */
+template<typename T>
+struct acl_resource_traits;
+
+/**
+ * @brief Specialization for aclTensor, defines how to destroy an aclTensor resource.
+ */
+template<>
+struct acl_resource_traits<aclTensor> {
+    static void destroy(void* p) {
+        ACL_CHECK(aclDestroyTensor(static_cast<aclTensor*>(p)));
+    }
+};
+
+/**
+ * @brief Specialization for aclIntArray, defines how to destroy an aclIntArray resource.
+ */
+template<>
+struct acl_resource_traits<aclIntArray> {
+    static void destroy(void* p) {
+        ACL_CHECK(aclDestroyIntArray(static_cast<aclIntArray*>(p)));
+    }
+};
+
+/**
+ * @brief Specialization for aclScalar, defines how to destroy an aclScalar resource.
+ */
+template<>
+struct acl_resource_traits<aclScalar> {
+    static void destroy(void* p) {
+        ACL_CHECK(aclDestroyScalar(static_cast<aclScalar*>(p)));
+    }
+};
+
+/**
+ * @brief Specialization for aclTensorList, defines how to destroy an aclTensorList resource.
+ */
+template<>
+struct acl_resource_traits<aclTensorList> {
+    static void destroy(void* p) {
+        ACL_CHECK(aclDestroyTensorList(static_cast<aclTensorList*>(p)));
+    }
+};
+
+/**
+ * @brief Creates a generic ACL resource wrapper with proper destruction logic.
+ *
+ * @tparam T ACL resource type.
+ * @param ptr Raw pointer to ACL resource.
+ * @return any_acl_resource Smart pointer that handles destruction.
+ */
+template<typename T>
+any_acl_resource make_acl_resource(T* ptr) {
+    return any_acl_resource(
+        static_cast<void*>(ptr),
+        [](void* p) {
+            acl_resource_traits<T>::destroy(p);
+        }
+    );
+}
+
+/**
+ * @brief Registers multiple ACL resources into a vector for lifetime management.
+ *
+ * @tparam Args Variadic list of ACL resource types.
+ * @param vec Target vector to hold ACL resources.
+ * @param args Raw pointers to ACL resources.
+ */
+template<typename... Args>
+void register_acl_resources(std::vector<any_acl_resource>& vec, Args*... args) {
+    (vec.emplace_back(make_acl_resource(args)), ...);
+}
+
+/**
+ * @brief Task class that wraps the execution of an aclnn function call.
+ */
+class aclnn_task : public cann_task {
+    public:
+        aclnn_task(aclnn_func_t aclnn_func, void * workspace_addr,
+                   uint64_t workspace_size, aclOpExecutor * executor,
+                   aclrtStream stream) :
+            aclnn_func_(aclnn_func),
+            workspace_addr_(workspace_addr),
+            workspace_size_(workspace_size),
+            executor_(executor),
+            stream_(stream) {}
+        virtual void run_task() override {
+            ACL_CHECK(aclnn_func_(workspace_addr_, workspace_size_, executor_, stream_));
+        }
+    private:
+        aclnn_func_t aclnn_func_;
+        void *          workspace_addr_;
+        uint64_t        workspace_size_;
+        aclOpExecutor * executor_;
+        aclrtStream     stream_;
+};
+
+/**
+ * @brief Task class that releases ACL resources after usage.
+ */
+class release_resource_task : public cann_task {
+public:
+    release_resource_task(std::vector<any_acl_resource>&& resources) {
+        resource_ = std::move(resources);
+    }
+
+    virtual void run_task() override {
+        resource_.clear();
+    }
+private:
+    std::vector<any_acl_resource> resource_;
+};
+
+/**
+ * @brief Task class for performing asynchronous memory copy operations.
+ */
+class async_memcpy_task : public cann_task {
+public:
+    async_memcpy_task(void* dst, const void* src, size_t size,
+                      aclrtMemcpyKind kind, aclrtStream stream)
+        : dst_(dst), src_(src), size_(size), kind_(kind), stream_(stream) {}
+
+    virtual void run_task() override {
+        ACL_CHECK(aclrtMemcpyAsync(dst_, size_, src_, size_, kind_, stream_));
+    }
+private:
+    void* dst_;
+    const void* src_;
+    size_t size_;
+    aclrtMemcpyKind kind_;
+    aclrtStream stream_;
+};
+
+/**
+ * @brief Task class for performing asynchronous memory set operations.
+ */
+class async_memset_task : public cann_task {
+public:
+    async_memset_task(void* buffer, size_t size, int32_t value, aclrtStream stream)
+        : buffer_(buffer), size_(size), value_(value), stream_(stream) {}
+
+    virtual void run_task() override {
+        ACL_CHECK(aclrtMemsetAsync(buffer_, size_, value_, size_, stream_));
+    }
+private:
+    void* buffer_;
+    size_t size_;
+    int32_t value_;
+    aclrtStream stream_;
+};
+
+/**
+ * @brief Launches an asynchronous task using the memory allocator.
+ *
+ * This macro submits an asynchronous task on the specified stream.
+ * The task uses memory allocated by the allocator. It is guaranteed
+ * that the memory will not be accessed by other tasks until this task
+ * completes, due to the sequential execution order within the same stream.
+ *
+ * @param CTX Backend context providing the memory pool, stream, and task queue.
+ * @param OP_NAME aclnn operator name.
+ * @param args Additional arguments required by the task.
+ *
+ * @note
+ * Memory from the allocator will be "freed" immediately and can be
+ * reallocated to other pointers. However, it won't be accessed by any
+ * other task before this asynchronous task ends, because all tasks in the
+ * same stream are executed in queue order.
+ */
+#define GGML_CANN_CALL_ACLNN_OP(CTX, OP_NAME, ...)                                          \
+    do {                                                                                    \
+        uint64_t        workspaceSize = 0;                                                  \
+        aclOpExecutor * executor;                                                           \
+        void *          workspaceAddr = nullptr;                                            \
+        ACL_CHECK(aclnn##OP_NAME##GetWorkspaceSize(__VA_ARGS__, &workspaceSize, &executor));\
+        /* workspace should be allocated in the main thread to keep malloc order when using vmm. */ \
+        if (workspaceSize > 0) {                                                            \
+            ggml_cann_pool_alloc workspace_allocator(CTX.pool(), workspaceSize);            \
+            workspaceAddr = workspace_allocator.get();                                      \
+        }                                                                                   \
+        if (CTX.async_mode) {                                                               \
+            auto task =                                                                     \
+                std::make_unique<aclnn_task>(aclnn##OP_NAME, workspaceAddr, workspaceSize, \
+                    executor, CTX.stream()); \
+            CTX.task_queue.submit_task(std::move(task));                                    \
+        } else {                                                                            \
+            ACL_CHECK(aclnn##OP_NAME(workspaceAddr, workspaceSize, executor, CTX.stream()));\
+        }                                                                                   \
+    } while (0)
+
+/**
+ * @brief Registers and releases multiple ACL resources, optionally deferring the release
+ *        using a task.
+ *
+ * @tparam Args Types of the ACL resources.
+ * @param ctx Backend context which manages task submission and async mode.
+ * @param args Pointers to ACL resources to be released.
+ */
+template <typename... Args>
+void ggml_cann_release_resources(ggml_backend_cann_context & ctx, Args &&... args) {
+    std::vector resources;
+    register_acl_resources(resources, std::forward<Args>(args)...);
+    if(ctx.async_mode) {
+        auto task = std::make_unique<release_resource_task>(std::move(resources));
+        ctx.task_queue.submit_task(std::move(task));
+    }
+}
+
+/**
+ * @brief Performs an asynchronous memory copy operation, optionally deferred via task submission.
+ *
+ * @param ctx Backend context containing stream and async configuration.
+ * @param dst Destination memory address.
+ * @param src Source memory address.
+ * @param len Size of memory to copy (in bytes).
+ * @param kind Type of memory copy (host-to-device, device-to-host, etc).
+ */
+inline void ggml_cann_async_memcpy(ggml_backend_cann_context & ctx, void * dst,
+                                   const void * src, size_t len, aclrtMemcpyKind kind) {
+    if (ctx.async_mode) {
+        auto task = std::make_unique<async_memcpy_task>(dst, const_cast<void*>(src), len, kind, ctx.stream());
+        ctx.task_queue.submit_task(std::move(task));
+    } else {
+        ACL_CHECK(aclrtMemcpyAsync(dst, len, src, len, kind, ctx.stream()));
+    }
+}
+
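+/**
+ * @brief Overload of ggml_cann_async_memcpy that takes the backend context by pointer.
+ */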
+inline void ggml_cann_async_memcpy(ggml_backend_cann_context * ctx, void * dst,
+                                   const void * src, size_t len, aclrtMemcpyKind kind) {
+    if (ctx->async_mode) {
+        auto task = std::make_unique<async_memcpy_task>(dst, const_cast<void*>(src), len, kind, ctx->stream());
+        ctx->task_queue.submit_task(std::move(task));
+    } else {
+        ACL_CHECK(aclrtMemcpyAsync(dst, len, src, len, kind, ctx->stream()));
+    }
+}
+
+/**
+ * @brief Performs an asynchronous memory set operation, optionally deferred via task submission.
+ *
+ * @param ctx Backend context containing stream and async configuration.
+ * @param buffer Memory buffer to be set.
+ * @param size Size of the memory buffer (in bytes).
+ * @param value Value to set in the buffer.
+ */
+inline void ggml_cann_async_memset(ggml_backend_cann_context & ctx, void * buffer,
+                                   size_t size, int value) {
+    if (ctx.async_mode) {
+        auto task = std::make_unique<async_memset_task>(buffer, size, value, ctx.stream());
+        ctx.task_queue.submit_task(std::move(task));
+    } else {
+        ACL_CHECK(aclrtMemsetAsync(buffer, size, value, size, ctx.stream()));
+    }
+}
+
 /**
  * @brief Applies an element-wise operation to two input tensors using the CANN
  * backend.
@@ -742,42 +1007,9 @@ void ggml_cann_binary_op(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     bcast_shape(src0, src1, dst, &acl_src0, &acl_src1, &acl_dst);
     binary_op(ctx, acl_src0, acl_src1, acl_dst);
 
-    ACL_CHECK(aclDestroyTensor(acl_src0));
-    ACL_CHECK(aclDestroyTensor(acl_src1));
-    ACL_CHECK(aclDestroyTensor(acl_dst));
+    ggml_cann_release_resources(ctx, acl_src0, acl_src1, acl_dst);
 }
 
-/**
- * @brief Launches an asynchronous task using the memory allocator.
- *
- * This macro submit an asynchronous task on the specified stream.
- * The task uses memory allocated by the allocator. It is guaranteed
- * that the memory will not be accessed by other tasks until this task
- * completes, due to the sequential execution order within the same stream.
- *
- * @param OP_NAME aclnn operator name.
- * @param args Additional arguments required by the task.
- *
- * @note
- * Memory from the allocator will be "freed" immediately and can be
- * reallocated to other pointers. However, it won't be accessed by any
- * other task before this asynchronous task ends, because all tasks in the
- * same stream are executed in queue order.
- */
-#define GGML_CANN_CALL_ACLNN_OP(OP_NAME, ...)                                                \
-    do {                                                                                     \
-        uint64_t        workspaceSize = 0;                                                   \
-        aclOpExecutor * executor;                                                            \
-        void *          workspaceAddr = nullptr;                                             \
-                                                                                             \
-        ACL_CHECK(aclnn##OP_NAME##GetWorkspaceSize(__VA_ARGS__, &workspaceSize, &executor)); \
-                                                                                             \
-        if (workspaceSize > 0) {                                                             \
-            ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);             \
-            workspaceAddr = workspace_allocator.get();                                       \
-        }                                                                                    \
-        ACL_CHECK(aclnn##OP_NAME(workspaceAddr, workspaceSize, executor, ctx.stream()));     \
-    } while (0)
 
 /**
  * @brief Applies a unary operation to an input tensor using the CANN backend.
@@ -799,9 +1031,7 @@ template 
     aclTensor* acl_dst = ggml_cann_create_tensor(dst);
 
     unary_op(ctx, acl_src, acl_dst);
-
-    ACL_CHECK(aclDestroyTensor(acl_src));
-    ACL_CHECK(aclDestroyTensor(acl_dst));
+    ggml_cann_release_resources(ctx, acl_src, acl_dst);
 }
 
 /**
@@ -832,7 +1062,7 @@ void ggml_cann_unary_op(
  *
  * Internally, the lambda will call:
  * @code
- * GGML_CANN_CALL_ACLNN_OP(OP_NAME, acl_src, acl_dst);
+ * GGML_CANN_CALL_ACLNN_OP(ctx, OP_NAME, acl_src, acl_dst);
  * @endcode
  *
  * @param OP_NAME The name of the ACL unary operator to invoke via GGML_CANN_CALL_ACLNN_OP.
@@ -840,14 +1070,14 @@ void ggml_cann_unary_op(
  * @see ggml_cann_unary_op
  * @see GGML_CANN_CALL_ACLNN_OP
  */
-#define GGML_CANN_CALL_UNARY_OP(OP_NAME)                         \
-    do {                                                         \
-        auto lambda = [](ggml_backend_cann_context& ctx,         \
-            aclTensor* acl_src,                                  \
-            aclTensor* acl_dst) {                                \
-            GGML_CANN_CALL_ACLNN_OP(OP_NAME, acl_src, acl_dst);  \
-        };                                                       \
-        ggml_cann_unary_op(lambda, ctx, dst);                    \
-    }                                                            \
+#define GGML_CANN_CALL_UNARY_OP(OP_NAME)                              \
+    do {                                                              \
+        auto lambda = [](ggml_backend_cann_context& ctx,              \
+            aclTensor* acl_src,                                       \
+            aclTensor* acl_dst) {                                     \
+            GGML_CANN_CALL_ACLNN_OP(ctx, OP_NAME, acl_src, acl_dst);  \
+        };                                                            \
+        ggml_cann_unary_op(lambda, ctx, dst);                         \
+    }                                                                 \
     while (0)
 #endif  // CANN_ACLNN_OPS
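
The header machinery above is a small type-erased RAII layer:
acl_resource_traits<T> names the right aclDestroy* function for each ACL
handle type, make_acl_resource erases the concrete type into a
unique_ptr<void, ...> carrying a custom deleter, and register_acl_resources
unpacks the pointer pack with a C++17 fold expression. A self-contained
sketch of the same idiom with generic names (not the ACL types), assuming
only the standard library:

    #include <functional>
    #include <memory>
    #include <vector>

    using any_resource = std::unique_ptr<void, std::function<void(void*)>>;

    template<typename T>
    struct resource_traits;                       // per-type destroy hook

    template<>
    struct resource_traits<int> {
        static void destroy(void* p) { delete static_cast<int*>(p); }
    };

    template<typename T>
    any_resource make_resource(T* ptr) {
        return any_resource(static_cast<void*>(ptr),
                            [](void* p) { resource_traits<T>::destroy(p); });
    }

    template<typename... Args>
    void register_resources(std::vector<any_resource>& vec, Args*... args) {
        (vec.emplace_back(make_resource(args)), ...);  // fold expression
    }

    // usage: int* p = new int(42); std::vector<any_resource> v;
    //        register_resources(v, p);  // freed when v goes out of scope
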
diff --git a/ggml/src/ggml-cann/common.h b/ggml/src/ggml-cann/common.h
index 5164cb74ec..7ef80a4793 100644
--- a/ggml/src/ggml-cann/common.h
+++ b/ggml/src/ggml-cann/common.h
@@ -31,9 +31,16 @@
 #include 
 #include 
 #include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
 
 #include "../include/ggml-cann.h"
 #include "../include/ggml.h"
+#include "../ggml-impl.h"
 
 #define MATRIX_ROW_PADDING 512
 #define GGML_CANN_MAX_STREAMS 8
@@ -205,6 +212,127 @@ struct ggml_cann_pool_alloc {
     ggml_cann_pool_alloc& operator=(ggml_cann_pool_alloc&&) = delete;
 };
 
+/**
+ * @brief Function pointer type for ACLNN operator calls.
+ */
+using aclnn_func_t = aclnnStatus (*)(void*, uint64_t, aclOpExecutor*, aclrtStream);
+
+/**
+ * @brief Base class for all CANN tasks to be submitted to the task queue.
+ *
+ * Users should override the run_task() method with actual task logic.
+ */
+class cann_task {
+public:
+    virtual void run_task() {}
+};
+
+/**
+ * @brief A lock-free, ring-buffer-based task queue for asynchronously executing cann_task instances.
+ */
+class cann_task_queue {
+public:
+    /**
+     * @brief Constructs a task queue with a fixed power-of-two capacity for a specific device.
+     *
+     * @param capacity Queue capacity. Must be a power of 2.
+     * @param device Target device ID (used for context setting).
+     */
+    explicit cann_task_queue(size_t capacity, int32_t device)
+        : buffer_(capacity), capacity_(capacity), head_(0), tail_(0),
+          running_(false), device_(device) {
+        GGML_ASSERT((capacity & (capacity - 1)) == 0 && "capacity must be power of 2");
+        mask_ = capacity_ - 1;
+    }
+
+    /**
+     * @brief Attempts to enqueue a task into the queue.
+     *
+     * @param item Unique pointer to the task.
+     * @return true if the task was successfully enqueued, false if the queue was full.
+     */
+    bool enqueue(std::unique_ptr<cann_task>&& item) {
+        size_t next_tail = (tail_ + 1) & mask_;
+
+        if (next_tail == head_) {
+            return false;
+        }
+
+        buffer_[tail_] = std::move(item);
+        std::atomic_thread_fence(std::memory_order_release);
+        tail_ = next_tail;
+
+        return true;
+    }
+
+    /**
+     * @brief Submits a task to the queue, and starts the worker thread if not already running.
+     *
+     * @param task Task to be submitted.
+     */
+    void submit_task(std::unique_ptr<cann_task>&& task) {
+        while(!enqueue(std::move(task))) {
+            std::this_thread::yield();
+            continue;
+        }
+
+        if (!running_) {
+            running_ = true;
+            thread_ = std::thread(&cann_task_queue::execute, this);
+        }
+
+    }
+
+    /**
+     * @brief Waits until the queue is completely empty and no tasks are being processed.
+     */
+    void wait() {
+        while (running_ && head_ != tail_) {
+            std::this_thread::yield();
+            continue;
+        }
+    }
+
+    /**
+     * @brief Stops the task queue and joins the worker thread.
+     */
+    void stop() {
+        running_ = false;
+        if (thread_.joinable()) {
+            thread_.join();
+        }
+    }
+
+private:
+    /**
+     * @brief Worker thread function that continuously dequeues and executes tasks.
+     */
+    void execute() {
+        ggml_cann_set_device(device_);
+
+        while (running_) {
+            if(head_ == tail_) {
+                std::this_thread::yield();
+                continue;
+            }
+
+            std::atomic_thread_fence(std::memory_order_acquire);
+            buffer_[head_]->run_task();
+            buffer_[head_].reset();
+            head_ = (head_ + 1) & mask_;
+        }
+    }
+
+    std::vector<std::unique_ptr<cann_task>> buffer_;
+    const size_t capacity_;
+    size_t mask_;
+    size_t head_;
+    size_t tail_;
+    bool running_;
+    std::thread thread_;
+    int32_t device_;
+};
+
 /**
  * @brief Context for managing CANN backend operations.
  */
@@ -213,6 +341,8 @@ struct ggml_backend_cann_context {
     std::string name;                /**< Name of the device. */
     std::string description;         /**< Description of the device. */
     aclrtEvent copy_event = nullptr; /**< Event for managing copy operations. */
+    cann_task_queue task_queue;      /**< Queue for asynchronous task execution. */
+    bool async_mode;                 /**< Whether asynchronous operator submission is enabled. */
 
     aclrtStream streams[GGML_CANN_MAX_STREAMS] = {nullptr}; /**< Array of streams for the device. */
 
@@ -221,9 +351,12 @@ struct ggml_backend_cann_context {
      * @param device Device ID.
      */
     explicit ggml_backend_cann_context(int device)
-        : device(device), name("CANN" + std::to_string(device)) {
+        : device(device), name("CANN" + std::to_string(device)), task_queue(1024, device) {
         ggml_cann_set_device(device);
         description = aclrtGetSocName();
+        async_mode = (getenv("GGML_CANN_ASYNC_MODE") != nullptr);
+        GGML_LOG_INFO("%s: device %d async operator submission is %s\n", __func__,
+            device, async_mode ? "ON" : "OFF");
     }
 
     /**
@@ -231,6 +364,7 @@ struct ggml_backend_cann_context {
      */
     ~ggml_backend_cann_context() {
         ggml_cann_set_device(device);
+        task_queue.stop();
         if (copy_event != nullptr) {
             ACL_CHECK(aclrtDestroyEvent(copy_event));
         }
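
cann_task_queue is a single-producer/single-consumer ring buffer: the
capacity is asserted to be a power of two so wrap-around is a cheap bitwise
AND, the release fence in enqueue() publishes the slot before tail_ moves,
and the matching acquire fence in execute() orders the consumer's read. The
index arithmetic in isolation, as a sketch rather than the exact class:

    #include <cstddef>

    // power-of-two capacity lets `index & mask` replace `index % capacity`
    constexpr size_t capacity = 8;                // must be a power of 2
    constexpr size_t mask     = capacity - 1;

    size_t head = 0;                              // consumer cursor
    size_t tail = 0;                              // producer cursor

    bool full()  { return ((tail + 1) & mask) == head; }  // one slot kept empty
    bool empty() { return head == tail; }

    // producer: write buffer[tail], fence(release), tail = (tail + 1) & mask
    // consumer: see head != tail, fence(acquire), run buffer[head],
    //           head = (head + 1) & mask
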
diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp
index ca41e02607..e2617b06e9 100644
--- a/ggml/src/ggml-cann/ggml-cann.cpp
+++ b/ggml/src/ggml-cann/ggml-cann.cpp
@@ -1606,7 +1606,7 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx,
                     auto lambda = [](ggml_backend_cann_context& ctx,
                         aclTensor* acl_src,
                         aclTensor* acl_dst) {
-                        GGML_CANN_CALL_ACLNN_OP(GeluV2, acl_src, 0, acl_dst);
+                        GGML_CANN_CALL_ACLNN_OP(ctx, GeluV2, acl_src, 0, acl_dst);
                     };
                     ggml_cann_unary_op(lambda, ctx, dst);
                 } break;
@@ -1789,12 +1789,11 @@ static void ggml_backend_cann_free(ggml_backend_t backend) {
     delete backend;
 }
 
 /**
  * @brief Sets tensor data asynchronously in the CANN backend.
  *
- * This function asynchronously sets tensor data in the CANN backend. Depending
- * on the tensor type, it may perform data transformations before copying data
- * to the device.
+ * This function asynchronously sets tensor data in the CANN backend.
  *
  * @param backend Pointer to the CANN backend structure.
  * @param tensor Pointer to the tensor structure to set data for.
@@ -1809,23 +1808,28 @@ static void ggml_backend_cann_set_tensor_async(ggml_backend_t backend,
                                                size_t size) {
     ggml_backend_cann_context *cann_ctx =
         (ggml_backend_cann_context *)backend->context;
+    ggml_backend_buffer_t buf =
+        tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
 
-    if (!need_transform(tensor->type)) {
-        ACL_CHECK(aclrtMemcpyAsync((char *)tensor->data + offset, size, data,
-                                   size, ACL_MEMCPY_HOST_TO_DEVICE,
-                                   cann_ctx->stream()));
-    } else {
-        void *transform_buffer = malloc(size);
-        ggml_backend_cann_transform(tensor, data, transform_buffer);
+    GGML_ASSERT(buf->buft == ggml_backend_cann_buffer_type(cann_ctx->device) &&
+        "unsupported buffer type");
+    GGML_ASSERT(!ggml_is_quantized(tensor->type));
 
-        ACL_CHECK(aclrtMemcpyAsync(
-            (char *)tensor->data + offset, size, transform_buffer, size,
-            ACL_MEMCPY_HOST_TO_DEVICE, cann_ctx->stream()));
-        ACL_CHECK(aclrtSynchronizeStream(cann_ctx->stream()));
-        free(transform_buffer);
-    }
+    ggml_cann_async_memcpy(cann_ctx, (char *)tensor->data + offset, data, size,
+        ACL_MEMCPY_HOST_TO_DEVICE);
 }
 
+/**
+ * @brief Gets tensor data asynchronously in the CANN backend.
+ *
+ * This function asynchronously gets tensor data in the CANN backend.
+ *
+ * @param backend Pointer to the CANN backend structure.
+ * @param tensor Pointer to the tensor structure to get data from.
+ * @param data Pointer to the host buffer that receives the tensor data.
+ * @param offset Offset in bytes within the tensor data.
+ * @param size Size of the data to copy in bytes.
+ */
 static void ggml_backend_cann_get_tensor_async(
     ggml_backend_t backend, const ggml_tensor *tensor, void *data,
     size_t offset, size_t size) {
@@ -1836,20 +1840,11 @@ static void ggml_backend_cann_get_tensor_async(
 
     GGML_ASSERT(buf->buft == ggml_backend_cann_buffer_type(cann_ctx->device) &&
                 "unsupported buffer type");
+    GGML_ASSERT(!ggml_is_quantized(tensor->type));
+
+    ggml_cann_async_memcpy(cann_ctx, data, (char *)tensor->data + offset, size,
+        ACL_MEMCPY_DEVICE_TO_HOST);
 
-    if (!need_transform(tensor->type)) {
-        ACL_CHECK(aclrtMemcpyAsync(data, size, (char *)tensor->data + offset,
-                                   size, ACL_MEMCPY_DEVICE_TO_HOST,
-                                   cann_ctx->stream()));
-    } else {
-        void *transform_buffer = malloc(size);
-        ACL_CHECK(aclrtMemcpyAsync(
-            transform_buffer, size, (char *)tensor->data + offset, size,
-            ACL_MEMCPY_DEVICE_TO_HOST, cann_ctx->stream()));
-        ACL_CHECK(aclrtSynchronizeStream(cann_ctx->stream()));
-        ggml_backend_cann_transform_back(tensor, transform_buffer, data);
-        free(transform_buffer);
-    }
 }
 
 /**
@@ -1909,6 +1904,8 @@ static bool ggml_backend_cann_cpy_tensor_async(
         ggml_cann_set_device(cann_ctx_src->device);
         ACL_CHECK(aclrtDeviceEnablePeerAccess(cann_ctx_dst->device, 0));
 
+        // wait until the task queue is empty to preserve task ordering.
+        cann_ctx_src->task_queue.wait();
         ACL_CHECK(aclrtMemcpyAsync(dst->data, copy_size, src->data, copy_size,
                                    ACL_MEMCPY_DEVICE_TO_DEVICE,
                                    cann_ctx_src->stream()));
@@ -1936,9 +1933,8 @@ static bool ggml_backend_cann_cpy_tensor_async(
 static void ggml_backend_cann_synchronize(ggml_backend_t backend) {
     ggml_backend_cann_context* cann_ctx =
         (ggml_backend_cann_context*)backend->context;
-
+    cann_ctx->task_queue.wait();
     ggml_cann_set_device(cann_ctx->device);
-
     ACL_CHECK(aclrtSynchronizeStream(cann_ctx->stream()));
 }
 

From 207c22ec2d6d793fc70830138617d1e016c5151c Mon Sep 17 00:00:00 2001
From: Alan Gray 
Date: Thu, 17 Apr 2025 14:19:42 +0100
Subject: [PATCH 07/39] ggml: Re-enable CUDA graphs in presence of CONT and DUP
 nodes (#12970)
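
Context (editorial note, not part of the original commit message): ggml's CUDA
copy kernels can read their destination pointer from a device-side pointer
table (the cdst_indirect / graph_cpynode_index arguments visible in the hunks
below), so a captured CUDA graph can be replayed even after host-side
destination pointers change. This patch threads a
disable_indirection_for_this_node flag (default false) through ggml_cuda_cpy()
and sets it for DUP, which also services CONT; with indirection bypassed for
those nodes, they no longer force CUDA graph capture off. A minimal sketch of
the indirection pattern, with a hypothetical cpy_sketch kernel standing in for
the real templated copy kernels:

    __global__ void cpy_sketch(const char * src, char * dst,
                               char ** dst_indirect, int node_idx, size_t n) {
        // with indirection enabled, the real destination is fetched from a
        // device-side table indexed by the node's position in the graph
        char * d = dst_indirect ? dst_indirect[node_idx] : dst;
        const size_t i = (size_t) blockIdx.x * blockDim.x + threadIdx.x;
        if (i < n) {
            d[i] = src[i];
        }
    }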

---
 ggml/src/ggml-cuda/cpy.cu       | 9 +++++----
 ggml/src/ggml-cuda/cpy.cuh      | 2 +-
 ggml/src/ggml-cuda/ggml-cuda.cu | 2 +-
 3 files changed, 7 insertions(+), 6 deletions(-)

diff --git a/ggml/src/ggml-cuda/cpy.cu b/ggml/src/ggml-cuda/cpy.cu
index 4f4faa3e63..ed25646e8e 100644
--- a/ggml/src/ggml-cuda/cpy.cu
+++ b/ggml/src/ggml-cuda/cpy.cu
@@ -551,7 +551,7 @@ static void ggml_cpy_f16_f16_cuda(
         (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
 }
 
-void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1) {
+void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1, bool disable_indirection_for_this_node) {
     const int64_t ne = ggml_nelements(src0);
     GGML_ASSERT(ne == ggml_nelements(src1));
 
@@ -588,7 +588,7 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
     char ** dest_ptrs_d = nullptr;
     int graph_cpynode_index = -1;
 #if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS)
-    if(ctx.cuda_graph->use_cpy_indirection) {
+    if(ctx.cuda_graph->use_cpy_indirection && !disable_indirection_for_this_node) {
         dest_ptrs_d = ctx.cuda_graph->dest_ptrs_d;
         graph_cpynode_index = ctx.cuda_graph->graph_cpynode_index;
     }
@@ -636,7 +636,7 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
                 ggml_type_name(src0->type), ggml_type_name(src1->type));
     }
 #if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS)
-    if(ctx.cuda_graph->use_cpy_indirection) {
+    if(ctx.cuda_graph->use_cpy_indirection && !disable_indirection_for_this_node) {
         ctx.cuda_graph->graph_cpynode_index = graph_cpynode_index;
     }
 #endif
@@ -645,7 +645,8 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
 
 void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
-    ggml_cuda_cpy(ctx, src0, dst);
+    bool disable_indirection = true;
+    ggml_cuda_cpy(ctx, src0, dst, disable_indirection);
 }
 
 void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
diff --git a/ggml/src/ggml-cuda/cpy.cuh b/ggml/src/ggml-cuda/cpy.cuh
index 6bed0564df..0bd3c0c6f8 100644
--- a/ggml/src/ggml-cuda/cpy.cuh
+++ b/ggml/src/ggml-cuda/cpy.cuh
@@ -2,7 +2,7 @@
 
 #define CUDA_CPY_BLOCK_SIZE 64
 
-void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1);
+void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1, bool disable_indirection = false);
 
 void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index 9ced466512..bab85809ac 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -2489,7 +2489,7 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud
 #endif
         }
 
-        if (node->op == GGML_OP_MUL_MAT_ID || node->op == GGML_OP_CONT || node->op == GGML_OP_DUP) {
+        if (node->op == GGML_OP_MUL_MAT_ID) {
             use_cuda_graph = false; // This node type is not supported by CUDA graph capture
 #ifndef NDEBUG
             GGML_LOG_DEBUG("%s: disabling CUDA graphs due to unsupported node type\n", __func__);

From 2f74c354c0f752ed9aabf7d3a350e6edebd7e744 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov 
Date: Thu, 17 Apr 2025 18:16:36 +0300
Subject: [PATCH 08/39] graph : make FA compatible with MLA + add initial Metal
 kernels (#12953)

* graph : make mla compatible with FA

* metal : add experimental FA kernels for DeepSeek models

ggml-ci

* llama : minor naming updates

ggml-ci

* ggml : disable FA for DS head sizes

* tests : add FA tests for MLA shapes

ggml-ci
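
Editorial note on the shapes, assuming the standard DeepSeek-V2/V3 MLA
hyperparameters (kv_lora_rank = 512, qk_rope_head_dim = 64): with the
absorption optimization, attention runs on the compressed KV latent, so the
K head size becomes kv_lora_rank + qk_rope_head_dim = 512 + 64 = 576 while
the V head size stays at kv_lora_rank = 512. That is why the new kernels come
only in the asymmetric HK576/HV512 variant and why the tests pair hsk = 576
exclusively with hsv = 512.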
---
 ggml/src/ggml-cuda/ggml-cuda.cu      |  4 ++
 ggml/src/ggml-metal/ggml-metal.m     | 83 ++++++++++++++++++++++++++--
 ggml/src/ggml-metal/ggml-metal.metal | 17 ++++++
 ggml/src/ggml-vulkan/ggml-vulkan.cpp |  1 +
 src/llama-context.cpp                | 11 +---
 src/llama-graph.cpp                  | 14 +++--
 src/llama-model.cpp                  |  6 +-
 tests/test-backend-ops.cpp           |  7 ++-
 8 files changed, 117 insertions(+), 26 deletions(-)

diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index bab85809ac..a7febef723 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -3237,6 +3237,10 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
             if (op->src[0]->ne[0] == 192) {
                 return false;
             }
+            if (op->src[0]->ne[0] == 576) {
+                // DeepSeek MLA
+                return false;
+            }
             if (op->src[0]->ne[3] != 1) {
                 return false;
             }
diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m
index 16ca08383c..85f3ae7bfd 100644
--- a/ggml/src/ggml-metal/ggml-metal.m
+++ b/ggml/src/ggml-metal/ggml-metal.m
@@ -354,6 +354,7 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H192,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK192_HV128,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK576_HV512,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96,
@@ -362,6 +363,7 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H192,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK192_HV128,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK576_HV512,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96,
@@ -370,6 +372,7 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H192,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK192_HV128,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK576_HV512,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96,
@@ -378,6 +381,7 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H192,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK192_HV128,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK576_HV512,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96,
@@ -386,6 +390,7 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H192,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK192_HV128,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK576_HV512,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96,
@@ -394,6 +399,7 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H192,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK192_HV128,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK576_HV512,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96,
@@ -402,6 +408,7 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H192,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H96,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H96,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H96,
@@ -437,6 +444,13 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK576_HV512,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_HK576_HV512,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_HK576_HV512,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_HK576_HV512,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK576_HV512,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK576_HV512,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK576_HV512,
     GGML_METAL_KERNEL_TYPE_SET_I32,
     GGML_METAL_KERNEL_TYPE_SET_F32,
     GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
@@ -1018,6 +1032,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H192,         flash_attn_ext_f16_h192,         has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK192_HV128,  flash_attn_ext_f16_hk192_hv128,  has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,         flash_attn_ext_f16_h256,         has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK576_HV512,  flash_attn_ext_f16_hk576_hv512,  has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64,         flash_attn_ext_bf16_h64,         has_simdgroup_mm && use_bfloat);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80,         flash_attn_ext_bf16_h80,         has_simdgroup_mm && use_bfloat);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96,         flash_attn_ext_bf16_h96,         has_simdgroup_mm && use_bfloat);
@@ -1026,6 +1041,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H192,        flash_attn_ext_bf16_h192,        has_simdgroup_mm && use_bfloat);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK192_HV128, flash_attn_ext_bf16_hk192_hv128, has_simdgroup_mm && use_bfloat);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256,        flash_attn_ext_bf16_h256,        has_simdgroup_mm && use_bfloat);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK576_HV512, flash_attn_ext_bf16_hk576_hv512, has_simdgroup_mm && use_bfloat);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64,         flash_attn_ext_q4_0_h64,         has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80,         flash_attn_ext_q4_0_h80,         has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96,         flash_attn_ext_q4_0_h96,         has_simdgroup_mm);
@@ -1034,6 +1050,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H192,        flash_attn_ext_q4_0_h192,        has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK192_HV128, flash_attn_ext_q4_0_hk192_hv128, has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256,        flash_attn_ext_q4_0_h256,        has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK576_HV512, flash_attn_ext_q4_0_hk576_hv512, has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64,         flash_attn_ext_q4_1_h64,         has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80,         flash_attn_ext_q4_1_h80,         has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96,         flash_attn_ext_q4_1_h96,         has_simdgroup_mm);
@@ -1042,6 +1059,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H192,        flash_attn_ext_q4_1_h192,        has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK192_HV128, flash_attn_ext_q4_1_hk192_hv128, has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256,        flash_attn_ext_q4_1_h256,        has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK576_HV512, flash_attn_ext_q4_1_hk576_hv512, has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64,         flash_attn_ext_q5_0_h64,         has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80,         flash_attn_ext_q5_0_h80,         has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96,         flash_attn_ext_q5_0_h96,         has_simdgroup_mm);
@@ -1050,6 +1068,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H192,        flash_attn_ext_q5_0_h192,        has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK192_HV128, flash_attn_ext_q5_0_hk192_hv128, has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256,        flash_attn_ext_q5_0_h256,        has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK576_HV512, flash_attn_ext_q5_0_hk576_hv512, has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64,         flash_attn_ext_q5_1_h64,         has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80,         flash_attn_ext_q5_1_h80,         has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96,         flash_attn_ext_q5_1_h96,         has_simdgroup_mm);
@@ -1058,6 +1077,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H192,        flash_attn_ext_q5_1_h192,        has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK192_HV128, flash_attn_ext_q5_1_hk192_hv128, has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256,        flash_attn_ext_q5_1_h256,        has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK576_HV512, flash_attn_ext_q5_1_hk576_hv512, has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64,         flash_attn_ext_q8_0_h64,         has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80,         flash_attn_ext_q8_0_h80,         has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96,         flash_attn_ext_q8_0_h96,         has_simdgroup_mm);
@@ -1066,6 +1086,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H192,        flash_attn_ext_q8_0_h192,        has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128, flash_attn_ext_q8_0_hk192_hv128, has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256,        flash_attn_ext_q8_0_h256,        has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512, flash_attn_ext_q8_0_hk576_hv512, has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H96,      flash_attn_ext_vec_f16_h96,      has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H96,     flash_attn_ext_vec_bf16_h96,     has_simdgroup_reduction && use_bfloat);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H96,     flash_attn_ext_vec_q4_0_h96,     has_simdgroup_reduction);
@@ -1101,6 +1122,13 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256,    flash_attn_ext_vec_q5_0_h256,    has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256,    flash_attn_ext_vec_q5_1_h256,    has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256,    flash_attn_ext_vec_q8_0_h256,    has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK576_HV512,     flash_attn_ext_vec_f16_hk576_hv512,     has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_HK576_HV512,    flash_attn_ext_vec_bf16_hk576_hv512,    has_simdgroup_reduction && use_bfloat);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_HK576_HV512,    flash_attn_ext_vec_q4_0_hk576_hv512,    has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_HK576_HV512,    flash_attn_ext_vec_q4_1_hk576_hv512,    has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK576_HV512,    flash_attn_ext_vec_q5_0_hk576_hv512,    has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK576_HV512,    flash_attn_ext_vec_q5_1_hk576_hv512,    has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK576_HV512,    flash_attn_ext_vec_q8_0_hk576_hv512,    has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_F32,                         set_f32,                         true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_I32,                         set_i32,                         true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32,                     cpy_f32_f32,                     true);
@@ -1365,6 +1393,11 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
                 // TODO: not sure if it is worth adding kernels for this size
                 return false;
             }
+            if (op->src[0]->ne[0] == 576) {
+                // DeepSeek sizes
+                // TODO: disabled for now, until optimized
+                return false;
+            }
             if (op->src[1]->type != op->src[2]->type) {
                 return false;
             }
@@ -3857,12 +3890,14 @@ static void ggml_metal_encode_node(
                 // TODO: add vec kernels for (ne00%64 == 0) and maybe also for (ne00%32 == 0)
                 //       for now avoiding mainly to keep the number of templates/kernels a bit lower
                 //       these are now trivial to add after: https://github.com/ggml-org/llama.cpp/pull/12612
-                if (ne01 >= 4 || (ne00%128 != 0 && ne00 != 96 && ne00 != 192)) {
+                if (ne01 >= 4 || (ne00%128 != 0 && ne00 != 96 && ne00 != 192 && ne00 != 576)) {
                     switch (src1->type) {
                         case GGML_TYPE_F16:
                             {
                                 if (ne00 == 192 && ne20 == 128) {
                                     pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK192_HV128].pipeline;
+                                } else if (ne00 == 576 && ne20 == 512) {
+                                    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_HK576_HV512].pipeline;
                                 } else {
                                     switch (ne00) {
                                         case 64:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break;
@@ -3885,6 +3920,8 @@ static void ggml_metal_encode_node(
                             {
                                 if (ne00 == 192 && ne20 == 128) {
                                     pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK192_HV128].pipeline;
+                                } else if (ne00 == 576 && ne20 == 512) {
+                                    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_HK576_HV512].pipeline;
                                 } else {
                                     switch (ne00) {
                                         case 64:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64 ].pipeline; break;
@@ -3907,6 +3944,8 @@ static void ggml_metal_encode_node(
                             {
                                 if (ne00 == 192 && ne20 == 128) {
                                     pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK192_HV128].pipeline;
+                                } else if (ne00 == 576 && ne20 == 512) {
+                                    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_HK576_HV512].pipeline;
                                 } else {
                                     switch (ne00) {
                                         case 64:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64 ].pipeline; break;
@@ -3929,6 +3968,8 @@ static void ggml_metal_encode_node(
                             {
                                 if (ne00 == 192 && ne20 == 128) {
                                     pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK192_HV128].pipeline;
+                                } else if (ne00 == 576 && ne20 == 512) {
+                                    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_HK576_HV512].pipeline;
                                 } else {
                                     switch (ne00) {
                                         case 64:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64 ].pipeline; break;
@@ -3951,6 +3992,8 @@ static void ggml_metal_encode_node(
                             {
                                 if (ne00 == 192 && ne20 == 128) {
                                     pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK192_HV128].pipeline;
+                                } else if (ne00 == 576 && ne20 == 512) {
+                                    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_HK576_HV512].pipeline;
                                 } else {
                                     switch (ne00) {
                                         case 64:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64 ].pipeline; break;
@@ -3973,6 +4016,8 @@ static void ggml_metal_encode_node(
                             {
                                 if (ne00 == 192 && ne20 == 128) {
                                     pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK192_HV128].pipeline;
+                                } else if (ne00 == 576 && ne20 == 512) {
+                                    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_HK576_HV512].pipeline;
                                 } else {
                                     switch (ne00) {
                                         case 64:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64 ].pipeline; break;
@@ -3995,6 +4040,8 @@ static void ggml_metal_encode_node(
                             {
                                 if (ne00 == 192 && ne20 == 128) {
                                     pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128].pipeline;
+                                } else if (ne00 == 576 && ne20 == 512) {
+                                    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512].pipeline;
                                 } else {
                                     switch (ne00) {
                                         case 64:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64 ].pipeline; break;
@@ -4114,12 +4161,36 @@ static void ggml_metal_encode_node(
                                         }
                                 }
                             } break;
+                        case 576:
+                            {
+                                if (ne20 == 512) {
+                                    switch (src1->type) {
+                                        case GGML_TYPE_F16:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_HK576_HV512].pipeline; break;
+                                        case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_HK576_HV512].pipeline; break;
+                                        case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_HK576_HV512].pipeline; break;
+                                        case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_HK576_HV512].pipeline; break;
+                                        case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK576_HV512].pipeline; break;
+                                        case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK576_HV512].pipeline; break;
+                                        case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK576_HV512].pipeline; break;
+                                        default:
+                                            {
+                                                GGML_LOG_ERROR("unsupported type: %d\n", src1->type);
+                                                GGML_LOG_ERROR("add template specialization for this type\n");
+                                                GGML_ABORT("add template specialization for this type");
+                                            }
+                                    }
+                                } else {
+                                    GGML_LOG_ERROR("unsupported size: %lld\n", ne20);
+                                    GGML_LOG_ERROR("add template specialization for this size\n");
+                                    GGML_ABORT("add template specialization for this size");
+                                }
+                            } break;
                         default:
-                                  {
-                                      GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
-                                      GGML_LOG_ERROR("add template specialization for this size\n");
-                                      GGML_ABORT("add template specialization for this size");
-                                  }
+                            {
+                                GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
+                                GGML_LOG_ERROR("add template specialization for this size\n");
+                                GGML_ABORT("add template specialization for this size");
+                            }
                     }
                 }
 
diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
index 0fb28238a8..dc7eab03ee 100644
--- a/ggml/src/ggml-metal/ggml-metal.metal
+++ b/ggml/src/ggml-metal/ggml-metal.metal
@@ -3546,6 +3546,7 @@ template [[host_name("kernel_flash_attn_ext_f16_h128")]]         kernel flash_at
 template [[host_name("kernel_flash_attn_ext_f16_h192")]]         kernel flash_attn_ext_t kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_f16_hk192_hv128")]]  kernel flash_attn_ext_t kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_f16_h256")]]         kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_f16_hk576_hv512")]]  kernel flash_attn_ext_t kernel_flash_attn_ext;
 
 #if defined(GGML_METAL_USE_BF16)
 template [[host_name("kernel_flash_attn_ext_bf16_h64" )]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
@@ -3556,6 +3557,7 @@ template [[host_name("kernel_flash_attn_ext_bf16_h128")]]        kernel flash_at
 template [[host_name("kernel_flash_attn_ext_bf16_h192")]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_bf16_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_bf16_h256")]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_bf16_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
 #endif
 
 template [[host_name("kernel_flash_attn_ext_q4_0_h64" )]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
@@ -3566,6 +3568,7 @@ template [[host_name("kernel_flash_attn_ext_q4_0_h128")]]        kernel flash_at
 template [[host_name("kernel_flash_attn_ext_q4_0_h192")]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_q4_0_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_q4_0_h256")]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q4_0_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
 
 template [[host_name("kernel_flash_attn_ext_q4_1_h64" )]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_q4_1_h80" )]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
@@ -3575,6 +3578,7 @@ template [[host_name("kernel_flash_attn_ext_q4_1_h128")]]        kernel flash_at
 template [[host_name("kernel_flash_attn_ext_q4_1_h192")]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_q4_1_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_q4_1_h256")]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q4_1_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
 
 template [[host_name("kernel_flash_attn_ext_q5_0_h64" )]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_q5_0_h80" )]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
@@ -3584,6 +3588,7 @@ template [[host_name("kernel_flash_attn_ext_q5_0_h128")]]        kernel flash_at
 template [[host_name("kernel_flash_attn_ext_q5_0_h192")]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_q5_0_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_q5_0_h256")]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q5_0_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
 
 template [[host_name("kernel_flash_attn_ext_q5_1_h64" )]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_q5_1_h80" )]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
@@ -3593,6 +3598,7 @@ template [[host_name("kernel_flash_attn_ext_q5_1_h128")]]        kernel flash_at
 template [[host_name("kernel_flash_attn_ext_q5_1_h192")]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_q5_1_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_q5_1_h256")]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q5_1_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
 
 template [[host_name("kernel_flash_attn_ext_q8_0_h64" )]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_q8_0_h80" )]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
@@ -3602,6 +3608,7 @@ template [[host_name("kernel_flash_attn_ext_q8_0_h128")]]        kernel flash_at
 template [[host_name("kernel_flash_attn_ext_q8_0_h192")]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_q8_0_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
 template [[host_name("kernel_flash_attn_ext_q8_0_h256")]]        kernel flash_attn_ext_t kernel_flash_attn_ext;
+template [[host_name("kernel_flash_attn_ext_q8_0_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext;
 
 #undef FA_TYPES
 
@@ -4009,6 +4016,16 @@ template [[host_name("kernel_flash_attn_ext_vec_q5_0_h256")]] kernel flash_attn_
 template [[host_name("kernel_flash_attn_ext_vec_q5_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
 template [[host_name("kernel_flash_attn_ext_vec_q8_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
 
+template [[host_name("kernel_flash_attn_ext_vec_f16_hk576_hv512")]]  kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+#if defined(GGML_METAL_USE_BF16)
+template [[host_name("kernel_flash_attn_ext_vec_bf16_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+#endif
+template [[host_name("kernel_flash_attn_ext_vec_q4_0_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+template [[host_name("kernel_flash_attn_ext_vec_q4_1_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+template [[host_name("kernel_flash_attn_ext_vec_q5_0_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+template [[host_name("kernel_flash_attn_ext_vec_q5_1_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+template [[host_name("kernel_flash_attn_ext_vec_q8_0_hk576_hv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec;
+
 #undef FA_TYPES
 
 template
diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index 0e9b2e8135..2f6d03c939 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -9261,6 +9261,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
                 case 112:
                 case 128:
                 case 256:
+                case 576: // DeepSeek MLA
                     break;
                 default:
                     return false;
diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index d3ef1cbdeb..983385f86d 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -484,7 +484,7 @@ ggml_tensor * llama_context::build_rope_shift(
 
     // See llm_build_deepseek2() for why attn_factor has to be scaled for YaRN RoPE to work correctly.
     // See https://github.com/ggerganov/llama.cpp/discussions/7416 for detailed explanation.
-    const float yarn_attn_factor_scaled = model.arch == LLM_ARCH_DEEPSEEK2 ? 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale)) : cparams.yarn_attn_factor;
+    const float yarn_attn_factor = model.arch == LLM_ARCH_DEEPSEEK2 ? 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale)) : cparams.yarn_attn_factor;
 
     ggml_tensor * tmp;
 
@@ -504,14 +504,14 @@ ggml_tensor * llama_context::build_rope_shift(
 
         tmp = ggml_rope_ext_inplace(ctx0, tmp,
                 shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                yarn_ext_factor, yarn_attn_factor_scaled, yarn_beta_fast, yarn_beta_slow);
+                yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow);
 
         tmp = ggml_cpy(ctx0, tmp, cur);
     } else {
         // we rotate only the first n_rot dimensions
         tmp = ggml_rope_ext_inplace(ctx0, cur,
                 shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                yarn_ext_factor, yarn_attn_factor_scaled, yarn_beta_fast, yarn_beta_slow);
+                yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow);
     }
 
     return tmp;
@@ -2278,11 +2278,6 @@ llama_context * llama_init_from_model(
         params.flash_attn = false;
     }
 
-    if (params.flash_attn && model->arch == LLM_ARCH_DEEPSEEK2) {
-        LLAMA_LOG_WARN("%s: flash_attn is not compatible with Deepseek2 - forcing off\n", __func__);
-        params.flash_attn = false;
-    }
-
     if (ggml_is_quantized(params.type_v) && !params.flash_attn) {
         LLAMA_LOG_ERROR("%s: V cache quantization requires flash_attn\n", __func__);
         return nullptr;
diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp
index 5d0222b981..a85e97288e 100644
--- a/src/llama-graph.cpp
+++ b/src/llama-graph.cpp
@@ -1200,9 +1200,6 @@ ggml_tensor * llm_graph_context::build_attn_mha(
   //const auto & n_embd_head_k = hparams.n_embd_head_k;
   //const auto & n_embd_head_v = hparams.n_embd_head_v;
 
-    // note: for MLA with the absorption optimization, the final embedding size will be changed via v_mla
-    const auto n_embd_head_v = v_mla == nullptr ? v_trans ? v->ne[1] : v->ne[0] : v_mla->ne[1];
-
     const auto n_tokens = q->ne[1];
     const auto n_head   = q->ne[2];
     const auto n_kv     = k->ne[1];
@@ -1231,7 +1228,12 @@ ggml_tensor * llm_graph_context::build_attn_mha(
 
         ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
 
-        cur = ggml_reshape_2d(ctx0, cur, n_embd_head_v*n_head, n_tokens);
+        if (v_mla) {
+            cur = ggml_reshape_4d(ctx0, cur, v_mla->ne[0], 1, n_head, n_tokens);
+            cur = ggml_mul_mat(ctx0, v_mla, cur);
+        }
+
+        cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens);
     } else {
         ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
 
@@ -1274,9 +1276,9 @@ ggml_tensor * llm_graph_context::build_attn_mha(
             kqv = ggml_mul_mat(ctx0, v_mla, kqv);
         }
 
-        ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
+        cur = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
 
-        cur = ggml_cont_2d(ctx0, kqv_merged, n_embd_head_v*n_head, n_tokens);
+        cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens);
 
         if (!cparams.offload_kqv) {
             // all nodes between the KV store and the attention output are run on the CPU
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index 248c61748e..6b7bfecf3a 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -10050,7 +10050,7 @@ struct llm_build_deepseek2 : public llm_graph_context {
         // See https://github.com/ggerganov/llama.cpp/discussions/7416 for detailed explanation.
         const float mscale = attn_factor * (1.0f + hparams.rope_yarn_log_mul * logf(1.0f / freq_scale));
         const float kq_scale = 1.0f*mscale*mscale/sqrtf(float(n_embd_head_k));
-        const float attn_factor_scaled = 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale));
+        const float attn_factor = 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale));
 
         ggml_tensor * cur;
         ggml_tensor * inpL;
@@ -10127,13 +10127,13 @@ struct llm_build_deepseek2 : public llm_graph_context {
 
                 q_pe = ggml_rope_ext(ctx0, q_pe, inp_pos, nullptr,
                         n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                        ext_factor, attn_factor_scaled, beta_fast, beta_slow
+                        ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(q_pe, "q_pe", il);
 
                 k_pe = ggml_rope_ext(ctx0, k_pe, inp_pos, nullptr,
                         n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
-                        ext_factor, attn_factor_scaled, beta_fast, beta_slow
+                        ext_factor, attn_factor, beta_fast, beta_slow
                 );
                 cb(k_pe, "k_pe", il);
 
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index 1ee7428946..5f6f87d1a3 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -4428,10 +4428,11 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     test_cases.emplace_back(new test_timestep_embedding());
     test_cases.emplace_back(new test_leaky_relu());
 
-    for (int hsk : { 64, 80, 128, 192, 256, }) {
-        for (int hsv : { 64, 80, 128, 192, 256, }) {
-            if (hsk != 192 && hsk != hsv) continue;
+    for (int hsk : { 64, 80, 128, 192, 256, 576 }) {
+        for (int hsv : { 64, 80, 128, 192, 256, 512 }) {
+            if (hsk != 192 && hsk != 576 && hsk != hsv) continue;
             if (hsk == 192 && (hsv != 128 && hsv != 192)) continue;
+            if (hsk == 576 && hsv != 512) continue; // DeepSeek MLA
 
             for (bool mask : { true, false } ) {
                 for (float max_bias : { 0.0f, 8.0f }) {

From 2db9ba1464f3de0aceb2b5289963e69fc369cb66 Mon Sep 17 00:00:00 2001
From: Radoslav Gerganov 
Date: Fri, 18 Apr 2025 10:13:42 +0300
Subject: [PATCH 09/39] rpc : add RPC_CMD_HELLO (#12955)

Add RPC_CMD_HELLO for getting the version of the protocol implemented by
the server. The version follows the semantic versioning rules at
https://semver.org

Hopefully this brings a better user experience when we make breaking
changes at the protocol level and avoids issues like #12465
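
Worked example of the compatibility rule (editorial, mirroring
check_server_version() below): a client built against protocol 1.2.3 accepts
any server reporting 1.0.x through 1.2.x (same major version, server minor
not newer than the client's), logs a warning unless the server reports
exactly 1.2.3, and refuses to connect to servers reporting 0.x.x, 2.x.x,
or 1.3.0 and above.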
---
 examples/rpc/rpc-server.cpp    |  5 +++-
 ggml/include/ggml-rpc.h        |  3 ++
 ggml/src/ggml-rpc/ggml-rpc.cpp | 54 +++++++++++++++++++++++++++++++++-
 3 files changed, 60 insertions(+), 2 deletions(-)

diff --git a/examples/rpc/rpc-server.cpp b/examples/rpc/rpc-server.cpp
index f8f2cb90ea..9db5542570 100644
--- a/examples/rpc/rpc-server.cpp
+++ b/examples/rpc/rpc-server.cpp
@@ -297,7 +297,10 @@ int main(int argc, char * argv[]) {
         }
         cache_dir = cache_dir_str.c_str();
     }
-    printf("Starting RPC server\n");
+    printf("Starting RPC server v%d.%d.%d\n",
+           RPC_PROTO_MAJOR_VERSION,
+           RPC_PROTO_MINOR_VERSION,
+           RPC_PROTO_PATCH_VERSION);
     printf("  endpoint       : %s\n", endpoint.c_str());
     printf("  local cache    : %s\n", cache_dir ? cache_dir : "n/a");
     printf("  backend memory : %zu MB\n", free_mem / (1024 * 1024));
diff --git a/ggml/include/ggml-rpc.h b/ggml/include/ggml-rpc.h
index 4e0d210f8e..c8b6097f7e 100644
--- a/ggml/include/ggml-rpc.h
+++ b/ggml/include/ggml-rpc.h
@@ -7,6 +7,9 @@
 extern "C" {
 #endif
 
+#define RPC_PROTO_MAJOR_VERSION    1
+#define RPC_PROTO_MINOR_VERSION    0
+#define RPC_PROTO_PATCH_VERSION    0
 #define GGML_RPC_MAX_SERVERS       16
 
 // backend API
diff --git a/ggml/src/ggml-rpc/ggml-rpc.cpp b/ggml/src/ggml-rpc/ggml-rpc.cpp
index 3189ae85d5..a0667b7d70 100644
--- a/ggml/src/ggml-rpc/ggml-rpc.cpp
+++ b/ggml/src/ggml-rpc/ggml-rpc.cpp
@@ -92,12 +92,19 @@ enum rpc_cmd {
     RPC_CMD_GET_DEVICE_MEMORY,
     RPC_CMD_INIT_TENSOR,
     RPC_CMD_GET_ALLOC_SIZE,
+    RPC_CMD_HELLO,
     RPC_CMD_COUNT,
 };
 
 // Try RPC_CMD_SET_TENSOR_HASH first when data size is larger than this threshold
 const size_t HASH_THRESHOLD = 10 * 1024 * 1024;
 
+struct rpc_msg_hello_rsp {
+    uint8_t major;
+    uint8_t minor;
+    uint8_t patch;
+};
+
 struct rpc_msg_get_alloc_size_req {
     rpc_tensor tensor;
 };
@@ -400,6 +407,20 @@ static bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cm
 
 // RPC client-side implementation
 
+static bool check_server_version(const std::shared_ptr<socket_t> & sock) {
+    rpc_msg_hello_rsp response;
+    bool status = send_rpc_cmd(sock, RPC_CMD_HELLO, nullptr, 0, &response, sizeof(response));
+    GGML_ASSERT(status);
+    if (response.major != RPC_PROTO_MAJOR_VERSION || response.minor > RPC_PROTO_MINOR_VERSION) {
+        fprintf(stderr, "RPC server version mismatch: %d.%d.%d\n", response.major, response.minor, response.patch);
+        return false;
+    }
+    if (response.minor != RPC_PROTO_MINOR_VERSION || response.patch != RPC_PROTO_PATCH_VERSION) {
+        fprintf(stderr, "WARNING: RPC server version mismatch: %d.%d.%d\n", response.major, response.minor, response.patch);
+    }
+    return true;
+}
+
 static std::shared_ptr<socket_t> get_socket(const std::string & endpoint) {
     static std::mutex mutex;
     std::lock_guard<std::mutex> lock(mutex);
@@ -433,6 +454,9 @@ static std::shared_ptr get_socket(const std::string & endpoint) {
     if (sock == nullptr) {
         return nullptr;
     }
+    if (!check_server_version(sock)) {
+        return nullptr;
+    }
     GGML_PRINT_DEBUG("[%s] connected to %s, sockfd=%d\n", __func__, endpoint.c_str(), sock->fd);
     sockets[endpoint] = sock;
     return sock;
@@ -818,6 +842,7 @@ public:
     }
     ~rpc_server();
 
+    void hello(rpc_msg_hello_rsp & response);
     void alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response);
     void get_alignment(rpc_msg_get_alignment_rsp & response);
     void get_max_size(rpc_msg_get_max_size_rsp & response);
@@ -846,6 +871,13 @@ private:
     std::unordered_set<ggml_backend_buffer_t> buffers;
 };
 
+void rpc_server::hello(rpc_msg_hello_rsp & response) {
+    response.major = RPC_PROTO_MAJOR_VERSION;
+    response.minor = RPC_PROTO_MINOR_VERSION;
+    response.patch = RPC_PROTO_PATCH_VERSION;
+    GGML_PRINT_DEBUG("[%s] version: %d.%d.%d\n", __func__, response.major, response.minor, response.patch);
+}
+
 bool rpc_server::get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response) {
     ggml_backend_buffer_type_t buft;
     struct ggml_init_params params {
@@ -1271,8 +1303,24 @@ rpc_server::~rpc_server() {
 static void rpc_serve_client(ggml_backend_t backend, const char * cache_dir,
                              sockfd_t sockfd, size_t free_mem, size_t total_mem) {
     rpc_server server(backend, cache_dir);
+    uint8_t cmd;
+    if (!recv_data(sockfd, &cmd, 1)) {
+        return;
+    }
+    // the first command sent by the client must be HELLO
+    if (cmd != RPC_CMD_HELLO) {
+        fprintf(stderr, "Expected HELLO command, update client\n");
+        return;
+    }
+    if (!recv_msg(sockfd, nullptr, 0)) {
+        return;
+    }
+    rpc_msg_hello_rsp response;
+    server.hello(response);
+    if (!send_msg(sockfd, &response, sizeof(response))) {
+        return;
+    }
     while (true) {
-        uint8_t cmd;
         if (!recv_data(sockfd, &cmd, 1)) {
             break;
         }
@@ -1282,6 +1330,10 @@ static void rpc_serve_client(ggml_backend_t backend, const char * cache_dir,
             break;
         }
         switch (cmd) {
+            case RPC_CMD_HELLO: {
+                // HELLO command is handled above
+                return;
+            }
             case RPC_CMD_ALLOC_BUFFER: {
                 rpc_msg_alloc_buffer_req request;
                 if (!recv_msg(sockfd, &request, sizeof(request))) {

From b9154ecff93ff54dc554411eb844a2a654be49f2 Mon Sep 17 00:00:00 2001
From: Xuan-Son Nguyen 
Date: Fri, 18 Apr 2025 10:04:51 +0200
Subject: [PATCH 10/39] mtmd : add methods to access `mtmd_image_tokens`
 (#12906)

* mtmd : add more api around mtmd_image_tokens

* mtmd : ability to calc image hash

* shared_ptr for mtmd_image_tokens

* move hash to user-defined ID (fixed)

* fix prompt_modified

* rm redundant data member
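
Editorial usage sketch of the reworked API (mirroring the gemma3-cli change
below): callers now own the chunk vector and check an integer status instead
of receiving a heap-allocated result:

    mtmd_input_chunks chunks; // caller-owned; no manual free needed
    int32_t res = mtmd_tokenize(ctx_vision, chunks, text, bitmaps);
    if (res != 0) { // 1: not enough images for the markers, 2: preprocessing error
        LOG_ERR("Unable to tokenize prompt, res = %d\n", res);
        return 1;
    }
    // per-chunk image token data is held by mtmd_image_tokens_ptr (a
    // unique_ptr with a custom deleter), so it is released automatically
    // when chunks goes out of scope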
---
 examples/llava/gemma3-cli.cpp | 11 +++--
 examples/llava/mtmd.cpp       | 88 ++++++++++++++++++++++++-----------
 examples/llava/mtmd.h         | 37 ++++++++++-----
 3 files changed, 92 insertions(+), 44 deletions(-)

diff --git a/examples/llava/gemma3-cli.cpp b/examples/llava/gemma3-cli.cpp
index 693b738a80..3d56647506 100644
--- a/examples/llava/gemma3-cli.cpp
+++ b/examples/llava/gemma3-cli.cpp
@@ -184,18 +184,19 @@ static int eval_message(gemma3_context & ctx, common_chat_msg & msg, std::vector
     text.text          = formatted_chat.prompt;
     text.add_special   = add_bos;
     text.parse_special = true;
-    mtmd_input_chunks_ptr chunks(mtmd_tokenize(ctx.ctx_vision.get(), text, bitmaps));
-    if (chunks == nullptr) {
-        LOG_ERR("Unable to tokenize prompt\n");
+    mtmd_input_chunks chunks;
+    int32_t res = mtmd_tokenize(ctx.ctx_vision.get(), chunks, text, bitmaps);
+    if (res != 0) {
+        LOG_ERR("Unable to tokenize prompt, res = %d\n", res);
         return 1;
     }
 
-    if (mtmd_helper_eval(ctx.ctx_vision.get(), ctx.lctx, chunks.get(), ctx.n_past, 0, ctx.n_batch)) {
+    if (mtmd_helper_eval(ctx.ctx_vision.get(), ctx.lctx, chunks, ctx.n_past, 0, ctx.n_batch)) {
         LOG_ERR("Unable to eval prompt\n");
         return 1;
     }
 
-    ctx.n_past += mtmd_helper_get_n_tokens(chunks.get());
+    ctx.n_past += mtmd_helper_get_n_tokens(chunks);
 
     return 0;
 }
diff --git a/examples/llava/mtmd.cpp b/examples/llava/mtmd.cpp
index 114c274bc1..3fd5bebc6a 100644
--- a/examples/llava/mtmd.cpp
+++ b/examples/llava/mtmd.cpp
@@ -16,6 +16,7 @@ struct mtmd_context {
     struct clip_ctx * ctx_clip;
     const struct llama_model * text_model;
     std::vector<float> image_embd_v; // image embedding vector
+
     bool print_timings;
     int n_threads;
     std::string image_marker;
@@ -24,7 +25,11 @@ struct mtmd_context {
 
     mtmd_context(const char * mmproj_fname,
                    const llama_model * text_model,
-                   const mtmd_context_params & ctx_params) : print_timings(ctx_params.print_timings), n_threads(ctx_params.n_threads), image_marker(ctx_params.image_marker) {
+                   const mtmd_context_params & ctx_params) :
+        print_timings(ctx_params.print_timings),
+        n_threads    (ctx_params.n_threads),
+        image_marker (ctx_params.image_marker)
+    {
         clip_context_params ctx_clip_params;
         ctx_clip_params.use_gpu   = ctx_params.use_gpu;
         ctx_clip_params.verbosity = ctx_params.verbosity;
@@ -49,6 +54,7 @@ struct mtmd_image_tokens {
     uint32_t ny; // number of tokens in y direction
     uint32_t n_tokens() const { return nx * ny; }
     clip_image_f32_batch batch_f32; // preprocessed image patches
+    std::string id; // optional user-defined ID, useful for KV cache tracking
 };
 
 mtmd_context * mtmd_init_from_file(const char * mmproj_fname,
@@ -88,10 +94,10 @@ static std::vector<llama_token> mtmd_tokenize_text_internal(
     return result;
 }
 
-mtmd_input_chunks * mtmd_tokenize(mtmd_context * ctx,
-                                const mtmd_input_text & text,
-                                const std::vector<mtmd_bitmap> & bitmaps) {
-    mtmd_input_chunks * output = new mtmd_input_chunks;
+int32_t mtmd_tokenize(mtmd_context * ctx,
+                        std::vector<mtmd_input_chunk> & output,
+                        const mtmd_input_text & text,
+                        const std::vector<mtmd_bitmap> & bitmaps) {
     auto vocab = llama_model_get_vocab(ctx->text_model);
 
     std::string prompt_modified(text.text);
@@ -105,9 +111,9 @@ mtmd_input_chunks * mtmd_tokenize(mtmd_context * ctx,
         string_replace_all(prompt_modified, ctx->image_marker, marker_modified);
     }
 
-    std::vector<std::string> parts = string_split_str(text.text, ctx->image_marker);
-    output->clear();
-    output->reserve(parts.size());
+    std::vector<std::string> parts = string_split_str(prompt_modified, ctx->image_marker);
+    output.clear();
+    output.reserve(parts.size());
 
     size_t i_img = 0;
 
@@ -123,14 +129,14 @@ mtmd_input_chunks * mtmd_tokenize(mtmd_context * ctx,
             std::move(tokens),
             {},
         };
-        output->emplace_back(std::move(chunk));
+        output.emplace_back(std::move(chunk));
 
         if (&parts.back() != &part) {
             // add image token to middle of 2 parts
 
             if (i_img >= bitmaps.size()) {
                 LOG_ERR("%s: error: not enough images for %d parts\n", __func__, (int)parts.size());
-                return nullptr;
+                return 1;
             }
 
             // shim layer
@@ -145,34 +151,48 @@ mtmd_input_chunks * mtmd_tokenize(mtmd_context * ctx,
             bool ok = clip_image_preprocess(ctx->ctx_clip, img_u8.get(), &batch_f32);
             if (!ok) {
                 LOG_ERR("Unable to preprocess image\n");
-                return nullptr;
+                return 2;
             }
 
-            mtmd_image_tokens * image_tokens = new mtmd_image_tokens;
+            mtmd_image_tokens_ptr image_tokens(new mtmd_image_tokens);
             image_tokens->nx = clip_n_patches(ctx->ctx_clip); // TODO @ngxson : use clip_n_patches_by_image
             image_tokens->ny = 1; // TODO
             image_tokens->batch_f32 = std::move(batch_f32);
+            image_tokens->id = bitmaps[i_img].id; // optional
 
             mtmd_input_chunk chunk{
                 MTMD_INPUT_CHUNK_TYPE_IMAGE,
                 {},
-                image_tokens,
+                std::move(image_tokens),
             };
-            output->emplace_back(std::move(chunk));
+            output.emplace_back(std::move(chunk));
             i_img++;
         }
     }
 
-    return output;
+    return 0;
 }
 
-void mtmd_input_chunks_free(mtmd_input_chunks * chunks) {
-    for (auto & chunk : *chunks) {
-        if (chunk.type == MTMD_INPUT_CHUNK_TYPE_IMAGE && chunk.tokens_image) {
-            delete chunk.tokens_image;
-        }
+void mtmd_image_tokens_free(mtmd_image_tokens * image_tokens) {
+    if (image_tokens) {
+        delete image_tokens;
     }
-    delete chunks;
+}
+
+size_t mtmd_image_tokens_get_n_tokens(const mtmd_image_tokens * image_tokens) {
+    return image_tokens->n_tokens();
+}
+
+size_t mtmd_image_tokens_get_nx(const mtmd_image_tokens * image_tokens) {
+    return image_tokens->nx;
+}
+
+size_t mtmd_image_tokens_get_ny(const mtmd_image_tokens * image_tokens) {
+    return image_tokens->ny;
+}
+
+std::string mtmd_image_tokens_get_id(const mtmd_image_tokens * image_tokens) {
+    return image_tokens->id;
 }
 
 int32_t mtmd_encode(mtmd_context * ctx, const mtmd_image_tokens * image_tokens) {
@@ -190,9 +210,9 @@ float * mtmd_get_output_embd(mtmd_context * ctx) {
     return ctx->image_embd_v.data();
 }
 
-size_t mtmd_helper_get_n_tokens(mtmd_input_chunks * chunks) {
+size_t mtmd_helper_get_n_tokens(mtmd_input_chunks & chunks) {
     size_t n_tokens = 0;
-    for (auto & chunk : *chunks) {
+    for (auto & chunk : chunks) {
         if (chunk.type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
             n_tokens += chunk.tokens_text.size();
         } else if (chunk.type == MTMD_INPUT_CHUNK_TYPE_IMAGE) {
@@ -241,7 +261,7 @@ struct decode_embd_batch {
 
 int32_t mtmd_helper_eval(mtmd_context * ctx,
         llama_context * lctx,
-        mtmd_input_chunks * chunks,
+        mtmd_input_chunks & chunks,
         llama_pos pos0,
         llama_seq_id seq_id,
         int32_t n_batch) {
@@ -249,8 +269,8 @@ int32_t mtmd_helper_eval(mtmd_context * ctx,
     llama_pos n_past = pos0;
     llama_batch text_batch = llama_batch_init(n_batch, 0, 1);
 
-    for (auto & chunk : *chunks) {
-        bool is_last = &chunk == &chunks->back();
+    for (auto & chunk : chunks) {
+        bool is_last = &chunk == &chunks.back();
         if (chunk.type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
             // TODO @ngxson : may need to split into smaller batches
             text_batch.n_tokens = chunk.tokens_text.size();
@@ -279,7 +299,7 @@ int32_t mtmd_helper_eval(mtmd_context * ctx,
             if (ctx->print_timings) {
                 LOG_INF("encoding image...\n");
             }
-            ret = mtmd_encode(ctx, chunk.tokens_image);
+            ret = mtmd_encode(ctx, chunk.tokens_image.get());
             if (ret != 0) {
                 LOG_ERR("failed to encode image\n");
                 llama_batch_free(text_batch);
@@ -289,7 +309,7 @@ int32_t mtmd_helper_eval(mtmd_context * ctx,
                 LOG_INF("image encoded in %" PRId64 " ms\n", ggml_time_ms() - t0);
             }
 
-            int32_t n_tokens = chunk.tokens_image->n_tokens();
+            int32_t n_tokens = mtmd_image_tokens_get_n_tokens(chunk.tokens_image.get());
             float * embd = mtmd_get_output_embd(ctx);
             decode_embd_batch batch_img(embd, n_tokens, n_past, 0);
             int64_t t1 = ggml_time_ms();
@@ -339,3 +359,15 @@ int32_t mtmd_helper_bitmap_init_from_file(const char * fname, mtmd_bitmap & outp
     std::memcpy(output.data.data(), data, output.nx * output.ny * 3);
     return 0;
 }
+
+bool mtmd_decode_use_non_causal(mtmd_context * ctx) {
+    projector_type proj_type = clip_get_projector_type(ctx->ctx_clip);
+    if (proj_type == PROJECTOR_TYPE_GEMMA3) {
+        return true;
+    }
+    return false;
+}
+
+void mtmd_image_tokens_deleter::operator()(mtmd_image_tokens * val) {
+    mtmd_image_tokens_free(val);
+}
diff --git a/examples/llava/mtmd.h b/examples/llava/mtmd.h
index 598f6947bb..78be192dd6 100644
--- a/examples/llava/mtmd.h
+++ b/examples/llava/mtmd.h
@@ -39,12 +39,18 @@ struct mtmd_bitmap {
     uint32_t nx;
     uint32_t ny;
     std::vector<unsigned char> data;
+    std::string id; // optional user-defined id, e.g. an image hash; useful for KV cache tracking
 };
 
+struct mtmd_image_tokens_deleter {
+    void operator()(mtmd_image_tokens * val); // forward declaration
+};
+using mtmd_image_tokens_ptr = std::unique_ptr<mtmd_image_tokens, mtmd_image_tokens_deleter>;
+
 struct mtmd_input_chunk {
     mtmd_input_chunk_type type;
     std::vector<llama_token> tokens_text;
-    mtmd_image_tokens * tokens_image = nullptr;
+    mtmd_image_tokens_ptr tokens_image;
 };
 
 using mtmd_input_chunks = std::vector<mtmd_input_chunk>;
@@ -82,12 +88,21 @@ MTMD_API void mtmd_free(mtmd_context * ctx);
 //   3. "\ndescribe it in detail."
 // number of bitmaps must be equal to the number of image markers in the prompt
 // this function is thread-safe (shared ctx)
-MTMD_API mtmd_input_chunks * mtmd_tokenize(mtmd_context * ctx,
+// return values:
+//   0 on success
+//   1 if the number of bitmaps does not match the number of image markers
+//   2 on image preprocessing error
+MTMD_API int32_t mtmd_tokenize(mtmd_context * ctx,
+                                std::vector<mtmd_input_chunk> & output,
                                 const mtmd_input_text & text,
                                 const std::vector<mtmd_bitmap> & bitmaps);
 
-// free image chunk data
-MTMD_API void mtmd_input_chunks_free(mtmd_input_chunks * chunks);
+// access mtmd_image_tokens
+MTMD_API size_t      mtmd_image_tokens_get_n_tokens(const mtmd_image_tokens * image_tokens);
+MTMD_API size_t      mtmd_image_tokens_get_nx(const mtmd_image_tokens * image_tokens);
+MTMD_API size_t      mtmd_image_tokens_get_ny(const mtmd_image_tokens * image_tokens);
+MTMD_API std::string mtmd_image_tokens_get_id(const mtmd_image_tokens * image_tokens);
+MTMD_API void        mtmd_image_tokens_free(mtmd_image_tokens * image_tokens);
 
 // returns 0 on success
 MTMD_API int32_t mtmd_encode(mtmd_context * ctx,
@@ -96,12 +111,17 @@ MTMD_API int32_t mtmd_encode(mtmd_context * ctx,
 // get output embeddings from the last encode pass
 MTMD_API float * mtmd_get_output_embd(mtmd_context * ctx);
 
+// whether we need to set non-causal mask before llama_decode
+MTMD_API bool mtmd_decode_use_non_causal(mtmd_context * ctx);
+
+
+
 //
 // helper functions (can be implemented based on other functions)
 //
 
 // helper to count the total number of tokens from a list of chunks, useful to keep track of n_past
-MTMD_API size_t mtmd_helper_get_n_tokens(mtmd_input_chunks * chunks);
+MTMD_API size_t mtmd_helper_get_n_tokens(mtmd_input_chunks & chunks);
 
 // helper function that automatically:
 // 1. run llama_decode() on text chunks
@@ -110,7 +130,7 @@ MTMD_API size_t mtmd_helper_get_n_tokens(mtmd_input_chunks * chunks);
 // otherwise, returns 0 on success
 MTMD_API int32_t mtmd_helper_eval(mtmd_context * ctx,
                                 llama_context * lctx,
-                                mtmd_input_chunks * chunks,
+                                mtmd_input_chunks & chunks,
                                 llama_pos pos0,
                                 llama_seq_id seq_id,
                                 int32_t n_batch);
@@ -132,11 +152,6 @@ struct mtmd_context_deleter {
 };
 using mtmd_context_ptr = std::unique_ptr<mtmd_context, mtmd_context_deleter>;
 
-struct mtmd_input_chunks_deleter {
-    void operator()(mtmd_input_chunks * val) { mtmd_input_chunks_free(val); }
-};
-using mtmd_input_chunks_ptr = std::unique_ptr<mtmd_input_chunks, mtmd_input_chunks_deleter>;
-
 #else
 
 static_assert(false && "C header is not yet supported by this library");

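Usage note for the mtmd changes above: mtmd_tokenize now reports status through
its int32_t return value and fills a caller-provided chunk list, and each image
chunk owns its tokens through mtmd_image_tokens_ptr, a unique_ptr whose deleter
operator() is defined out of line so the token type can stay opaque in the
header. A minimal caller sketch under those assumptions; the helper name and
surrounding setup are illustrative, not part of the patch:

    #include "mtmd.h"
    #include <vector>

    // hypothetical helper: tokenize a prompt with images, then evaluate it
    static int32_t tokenize_and_eval(mtmd_context * mctx, llama_context * lctx,
                                     const mtmd_input_text & text,
                                     const std::vector<mtmd_bitmap> & bitmaps,
                                     int32_t n_batch) {
        mtmd_input_chunks chunks; // output parameter; previously a returned pointer
        int32_t res = mtmd_tokenize(mctx, chunks, text, bitmaps);
        if (res != 0) {
            return res; // 1: image/marker count mismatch, 2: preprocessing failed
        }
        // no mtmd_input_chunks_free() needed: chunks release their image tokens on destruction
        return mtmd_helper_eval(mctx, lctx, chunks, /*pos0*/ 0, /*seq_id*/ 0, n_batch);
    }
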
From 8d6600576318dfc6b091fca744b0fd36a5e5255f Mon Sep 17 00:00:00 2001
From: Akarshan Biswas 
Date: Fri, 18 Apr 2025 19:27:56 +0530
Subject: [PATCH 11/39] SYCL: Refactor and enable FP16 in binary broadcast OPs
 (#12975)

* SYCL: refactor: move binary broadcast ops to a separate file

* Fix binbcast

* Remove duplicates

* fix include formatting

* fix typo
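
Note: the broadcast rule implemented by k_bin_bcast and k_bin_bcast_unravel
maps every dst index to a src1 index by taking it modulo src1's extent in that
dimension. A host-side sketch of the same rule, illustrative only and assuming
contiguous row-major float tensors (the real kernels use strides s01..s13):

    // dst[i3][i2][i1][i0] = op(src0[i3][i2][i1][i0],
    //                          src1[i3 % ne13][i2 % ne12][i1 % ne11][i0 % ne10])
    for (int i3 = 0; i3 < ne3; i3++)
    for (int i2 = 0; i2 < ne2; i2++)
    for (int i1 = 0; i1 < ne1; i1++)
    for (int i0 = 0; i0 < ne0; i0++) {
        const size_t idst = ((i3*ne2 + i2)*ne1 + i1)*ne0 + i0;
        const size_t isrc = (((i3 % ne13)*ne12 + (i2 % ne12))*ne11
                              + (i1 % ne11))*ne10 + (i0 % ne10);
        dst[idst] = op(src0[idst], src1[isrc]);
    }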
---
 ggml/src/ggml-sycl/backend.hpp      |   1 +
 ggml/src/ggml-sycl/binbcast.cpp     | 350 ++++++++++++++++++++++++++++
 ggml/src/ggml-sycl/binbcast.hpp     |  39 ++++
 ggml/src/ggml-sycl/common.hpp       | 281 ----------------------
 ggml/src/ggml-sycl/element_wise.cpp |  47 ----
 ggml/src/ggml-sycl/element_wise.hpp |  32 +--
 ggml/src/ggml-sycl/ggml-sycl.cpp    |  15 +-
 7 files changed, 393 insertions(+), 372 deletions(-)
 create mode 100644 ggml/src/ggml-sycl/binbcast.cpp
 create mode 100644 ggml/src/ggml-sycl/binbcast.hpp

diff --git a/ggml/src/ggml-sycl/backend.hpp b/ggml/src/ggml-sycl/backend.hpp
index 73d807cab0..de814ef91a 100644
--- a/ggml/src/ggml-sycl/backend.hpp
+++ b/ggml/src/ggml-sycl/backend.hpp
@@ -13,6 +13,7 @@
 #ifndef GGML_SYCL_BACKEND_HPP
 #define GGML_SYCL_BACKEND_HPP
 
+#include "binbcast.hpp"
 #include "concat.hpp"
 #include "common.hpp"
 #include "conv.hpp"
diff --git a/ggml/src/ggml-sycl/binbcast.cpp b/ggml/src/ggml-sycl/binbcast.cpp
new file mode 100644
index 0000000000..0a9d3a927c
--- /dev/null
+++ b/ggml/src/ggml-sycl/binbcast.cpp
@@ -0,0 +1,350 @@
+#include "binbcast.hpp"
+
+#include <cstddef>
+#include <cstdint>
+#include <sycl/sycl.hpp>
+
+#include "ggml.h"
+
+template <float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t>
+static void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst_t * dst,
+        int ne0, int ne1, int ne2, int ne3,
+        int ne10, int ne11, int ne12, int ne13,
+        /*int s0, */ int s1,  int s2,  int s3,
+        /*int s00,*/ int s01, int s02, int s03,
+        /*int s10,*/ int s11, int s12, int s13,
+        const sycl::nd_item<3> &item_ct1) {
+    const int i0s = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
+                    item_ct1.get_local_id(2);
+    const int i1 = (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
+                    item_ct1.get_local_id(1));
+    const int i2 = (item_ct1.get_local_range(0) * item_ct1.get_group(0) +
+                    item_ct1.get_local_id(0)) /
+                   ne3;
+    const int i3 = (item_ct1.get_local_range(0) * item_ct1.get_group(0) +
+                    item_ct1.get_local_id(0)) %
+                   ne3;
+
+    if (i0s >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
+        return;
+    }
+
+    const int i11 = i1 % ne11;
+    const int i12 = i2 % ne12;
+    const int i13 = i3 % ne13;
+
+    const size_t i_src0 =  i3*s03 +  i2*s02 +  i1*s01;
+    const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
+    const size_t i_dst  =  i3*s3  +  i2*s2  +  i1*s1;
+
+    const src0_t * src0_row = src0 + i_src0;
+    const src1_t * src1_row = src1 + i_src1;
+    dst_t * dst_row = dst + i_dst;
+
+    for (int i0 = i0s; i0 < ne0;
+         i0 += item_ct1.get_local_range(2) * item_ct1.get_group_range(2)) {
+        const int i10 = i0 % ne10;
+        dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]);
+    }
+}
+
+template <float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t>
+static void k_bin_bcast_unravel(const src0_t * src0, const src1_t * src1, dst_t * dst,
+        int ne0, int ne1, int ne2, int ne3,
+        int ne10, int ne11, int ne12, int ne13,
+        /*int s0, */ int s1,  int s2,  int s3,
+        /*int s00,*/ int s01, int s02, int s03,
+        /*int s10,*/ int s11, int s12, int s13,
+        const sycl::nd_item<3> &item_ct1) {
+
+    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
+                  item_ct1.get_local_id(2);
+
+    const int i3 = i/(ne2*ne1*ne0);
+    const int i2 = (i/(ne1*ne0)) % ne2;
+    const int i1 = (i/ne0) % ne1;
+    const int i0 = i % ne0;
+
+    if (i0 >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
+        return;
+    }
+
+    const int i11 = i1 % ne11;
+    const int i12 = i2 % ne12;
+    const int i13 = i3 % ne13;
+
+    const size_t i_src0 =  i3*s03 +  i2*s02 +  i1*s01;
+    const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
+    const size_t i_dst  =  i3*s3  +  i2*s2  +  i1*s1;
+
+    const src0_t * src0_row = src0 + i_src0;
+    const src1_t * src1_row = src1 + i_src1;
+    dst_t * dst_row = dst + i_dst;
+
+    const int i10 = i0 % ne10;
+    dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]);
+}
+
+
+template <float (*bin_op)(const float, const float)>
+struct bin_bcast_sycl {
+    template <typename src0_t, typename src1_t, typename dst_t>
+    void operator()(const src0_t * src0_dd, const src1_t * src1_dd, dst_t * dst_dd, const int64_t ne00,
+                    const int64_t ne01, const int64_t ne02, const int64_t ne03, const int64_t ne10, const int64_t ne11,
+                    const int64_t ne12, const int64_t ne13, const int64_t ne0, const int64_t ne1, const int64_t ne2,
+                    const int64_t ne3, const size_t nb00, const size_t nb01, const size_t nb02, const size_t nb03,
+                    const size_t nb10, const size_t nb11, const size_t nb12, const size_t nb13, const size_t nb0,
+                    const size_t nb1, const size_t nb2, const size_t nb3, const bool src0_is_contiguous,
+                    const bool src1_is_contiguous, const bool dst_is_contiguous, queue_ptr stream) {
+        int nr0 = ne10 / ne0;
+        int nr1 = ne11/ne1;
+        int nr2 = ne12/ne2;
+        int nr3 = ne13/ne3;
+
+        int nr[4] = { nr0, nr1, nr2, nr3 };
+
+        // collapse dimensions until first broadcast dimension
+        int64_t cne[] = {ne0, ne1, ne2, ne3};
+        int64_t cne0[] = {ne00, ne01, ne02, ne03};
+        int64_t cne1[] = {ne10, ne11, ne12, ne13};
+        size_t cnb[] = {nb0, nb1, nb2, nb3};
+        size_t cnb0[] = {nb00, nb01, nb02, nb03};
+        size_t cnb1[] = {nb10, nb11, nb12, nb13};
+        auto collapse = [](int64_t cne[]) {
+            cne[0] *= cne[1];
+            cne[1] = cne[2];
+            cne[2] = cne[3];
+            cne[3] = 1;
+        };
+
+        auto collapse_nb = [](size_t cnb[], int64_t cne[]) {
+            cnb[1] *= cne[1];
+            cnb[2] *= cne[2];
+            cnb[3] *= cne[3];
+        };
+
+        if (src0_is_contiguous && src1_is_contiguous && dst_is_contiguous) {
+            for (int i = 0; i < 4; i++) {
+                if (nr[i] != 1) {
+                    break;
+                }
+                if (i > 0) {
+                    collapse_nb(cnb, cne);
+                    collapse_nb(cnb0, cne0);
+                    collapse_nb(cnb1, cne1);
+                    collapse(cne);
+                    collapse(cne0);
+                    collapse(cne1);
+                }
+            }
+        }
+        {
+            int64_t ne0 = cne[0];
+            int64_t ne1 = cne[1];
+            int64_t ne2 = cne[2];
+            int64_t ne3 = cne[3];
+
+            int64_t ne10 = cne1[0];
+            int64_t ne11 = cne1[1];
+            int64_t ne12 = cne1[2];
+            int64_t ne13 = cne1[3];
+
+            size_t nb0 = cnb[0];
+            size_t nb1 = cnb[1];
+            size_t nb2 = cnb[2];
+            size_t nb3 = cnb[3];
+
+            size_t nb00 = cnb0[0];
+            size_t nb01 = cnb0[1];
+            size_t nb02 = cnb0[2];
+            size_t nb03 = cnb0[3];
+
+            size_t nb10 = cnb1[0];
+            size_t nb11 = cnb1[1];
+            size_t nb12 = cnb1[2];
+            size_t nb13 = cnb1[3];
+
+            size_t s0 = nb0 / sizeof(dst_t);
+            size_t s1 = nb1 / sizeof(dst_t);
+            size_t s2 = nb2 / sizeof(dst_t);
+            size_t s3 = nb3 / sizeof(dst_t);
+
+            size_t s10 = nb10 / sizeof(src1_t);
+            size_t s11 = nb11 / sizeof(src1_t);
+            size_t s12 = nb12 / sizeof(src1_t);
+            size_t s13 = nb13 / sizeof(src1_t);
+
+            size_t s00 = nb00 / sizeof(src0_t);
+            size_t s01 = nb01 / sizeof(src0_t);
+            size_t s02 = nb02 / sizeof(src0_t);
+            size_t s03 = nb03 / sizeof(src0_t);
+
+            GGML_UNUSED(s00);
+
+            GGML_ASSERT(nb0 % sizeof(dst_t) == 0);
+            GGML_ASSERT(nb1 % sizeof(dst_t) == 0);
+            GGML_ASSERT(nb2 % sizeof(dst_t) == 0);
+            GGML_ASSERT(nb3 % sizeof(dst_t) == 0);
+
+            GGML_ASSERT(nb00 % sizeof(src0_t) == 0);
+            GGML_ASSERT(nb01 % sizeof(src0_t) == 0);
+            GGML_ASSERT(nb02 % sizeof(src0_t) == 0);
+            GGML_ASSERT(nb03 % sizeof(src0_t) == 0);
+
+            GGML_ASSERT(nb10 % sizeof(src1_t) == 0);
+            GGML_ASSERT(nb11 % sizeof(src1_t) == 0);
+            GGML_ASSERT(nb12 % sizeof(src1_t) == 0);
+            GGML_ASSERT(nb13 % sizeof(src1_t) == 0);
+
+            GGML_ASSERT(s0 == 1);
+            GGML_ASSERT(s10 == 1);
+
+            const int block_size = 128;
+
+            int64_t hne0 = std::max(ne0/2LL, 1LL);
+
+            sycl::range<3> block_dims(1, 1, 1);
+            block_dims[2] = std::min<unsigned int>(hne0, block_size);
+            block_dims[1] = std::min<unsigned int>(
+                ne1, block_size / (unsigned int)block_dims[2]);
+            block_dims[0] = std::min(
+                std::min<unsigned int>(
+                    ne2 * ne3, block_size / (unsigned int)block_dims[2] /
+                                   (unsigned int)block_dims[1]),
+                64U);
+
+            sycl::range<3> block_nums(
+                (ne2 * ne3 + block_dims[0] - 1) / block_dims[0],
+                (ne1 + block_dims[1] - 1) / block_dims[1],
+                (hne0 + block_dims[2] - 1) / block_dims[2]);
+
+            if (block_nums[0] > 65535) {
+                // this is the maximum number of blocks in z direction, fallback to 1D grid kernel
+                int block_num = (ne0*ne1*ne2*ne3 + block_size - 1) / block_size;
+                {
+                    dpct::has_capability_or_fail(stream->get_device(),
+                                                 {sycl::aspect::fp16});
+
+                    stream->parallel_for(
+                        sycl::nd_range<3>(sycl::range<3>(1, 1, block_num) *
+                                              sycl::range<3>(1, 1, block_size),
+                                          sycl::range<3>(1, 1, block_size)),
+                        [=](sycl::nd_item<3> item_ct1) {
+                            k_bin_bcast_unravel<bin_op>(
+                                src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3,
+                                ne10, ne11, ne12, ne13, s1, s2, s3, s01, s02,
+                                s03, s11, s12, s13, item_ct1);
+                        });
+                }
+            } else {
+                /*
+                DPCT1049:16: The work-group size passed to the SYCL kernel may
+                exceed the limit. To get the device limit, query
+                info::device::max_work_group_size. Adjust the work-group size if
+                needed.
+                */
+                dpct::has_capability_or_fail(stream->get_device(),
+                                             {sycl::aspect::fp16});
+
+                stream->parallel_for(
+                    sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                    [=](sycl::nd_item<3> item_ct1) {
+                        k_bin_bcast<bin_op>(src0_dd, src1_dd, dst_dd, ne0, ne1,
+                                            ne2, ne3, ne10, ne11, ne12, ne13,
+                                            s1, s2, s3, s01, s02, s03, s11, s12, s13,
+                                            item_ct1);
+                    });
+            }
+        }
+    }
+};
+
+template <class op>
+inline void ggml_sycl_op_bin_bcast(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1,
+                                   ggml_tensor * dst) {
+    dpct::queue_ptr main_stream = ctx.stream();
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+        op()((const float *) src0->data, (const float *) src1->data, (float *) dst->data, ne00, ne01, ne02, ne03, ne10,
+             ne11, ne12, ne13, ne0, ne1, ne2, ne3, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb0, nb1, nb2, nb3,
+             ggml_is_contiguous(src0), ggml_is_contiguous(src1), ggml_is_contiguous(dst), main_stream);
+    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
+        op()((const sycl::half *) src0->data, (const sycl::half *) src1->data, (sycl::half *) dst->data, ne00, ne01,
+             ne02, ne03, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13,
+             nb0, nb1, nb2, nb3, ggml_is_contiguous(src0), ggml_is_contiguous(src1), ggml_is_contiguous(dst),
+             main_stream);
+    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {
+        op()((const sycl::half *) src0->data, (const float *) src1->data, (sycl::half *) dst->data, ne00, ne01, ne02,
+             ne03, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb0, nb1,
+             nb2, nb3, ggml_is_contiguous(src0), ggml_is_contiguous(src1), ggml_is_contiguous(dst), main_stream);
+    } else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_I32) {
+        op()((const int32_t *) src0->data, (const int32_t *) src1->data, (int32_t *) dst->data, ne00, ne01, ne02, ne03,
+             ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb0, nb1, nb2,
+             nb3, ggml_is_contiguous(src0), ggml_is_contiguous(src1), ggml_is_contiguous(dst), main_stream);
+    } else if (src0->type == GGML_TYPE_I16 && src1->type == GGML_TYPE_I16 && dst->type == GGML_TYPE_I16) {
+        op()((const int16_t *) src0->data, (const int16_t *) src1->data, (int16_t *) dst->data, ne00, ne01, ne02, ne03,
+             ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb0, nb1, nb2,
+             nb3, ggml_is_contiguous(src0), ggml_is_contiguous(src1), ggml_is_contiguous(dst), main_stream);
+    } else {
+        fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s, src1: %s\n", __func__, ggml_type_name(dst->type),
+                ggml_type_name(src0->type), ggml_type_name(src1->type));
+        GGML_ABORT("fatal error");
+    }
+}
+
+inline void ggml_sycl_op_add(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
+
+    ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_add>>(ctx, dst->src[0], dst->src[1], dst);
+}
+
+inline void ggml_sycl_op_sub(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
+
+    ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_sub>>(ctx, dst->src[0], dst->src[1], dst);
+}
+
+inline void ggml_sycl_op_mul(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
+
+    ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_mul>>(ctx, dst->src[0], dst->src[1], dst);
+}
+
+inline void ggml_sycl_op_div(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
+
+    ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_div>>(ctx, dst->src[0], dst->src[1], dst);
+}
+
+inline void ggml_sycl_op_repeat(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
+    ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_repeat>>(ctx, dst, dst->src[0], dst);
+}
+
+
+void ggml_sycl_add(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    GGML_SYCL_DEBUG("call %s\n", __func__);
+    ggml_sycl_op_add(ctx, dst);
+    GGML_SYCL_DEBUG("call %s done\n", __func__);
+}
+
+void ggml_sycl_sub(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    GGML_SYCL_DEBUG("call %s\n", __func__);
+    ggml_sycl_op_sub(ctx, dst);
+    GGML_SYCL_DEBUG("call %s done\n", __func__);
+}
+
+void ggml_sycl_mul(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    GGML_SYCL_DEBUG("call %s\n", __func__);
+    ggml_sycl_op_mul(ctx, dst);
+    GGML_SYCL_DEBUG("call %s done\n", __func__);
+}
+
+void ggml_sycl_div(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    GGML_SYCL_DEBUG("call %s\n", __func__);
+    ggml_sycl_op_div(ctx, dst);
+    GGML_SYCL_DEBUG("call %s done\n", __func__);
+}
+
+void ggml_sycl_repeat(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    GGML_SYCL_DEBUG("call %s\n", __func__);
+    ggml_sycl_op_repeat(ctx, dst);
+    GGML_SYCL_DEBUG("call %s done\n", __func__);
+}
+
diff --git a/ggml/src/ggml-sycl/binbcast.hpp b/ggml/src/ggml-sycl/binbcast.hpp
new file mode 100644
index 0000000000..9cce0f053a
--- /dev/null
+++ b/ggml/src/ggml-sycl/binbcast.hpp
@@ -0,0 +1,39 @@
+#ifndef GGML_SYCL_BINBCAST_HPP
+#define GGML_SYCL_BINBCAST_HPP
+#include "common.hpp"
+
+
+static __dpct_inline__ float op_repeat(const float a, const float b) {
+    return b;
+    GGML_UNUSED(a);
+}
+
+static __dpct_inline__ float op_add(const float a, const float b) {
+    return a + b;
+}
+
+static __dpct_inline__ float op_sub(const float a, const float b) {
+    return a - b;
+}
+
+static __dpct_inline__ float op_mul(const float a, const float b) {
+    return a * b;
+}
+
+static __dpct_inline__ float op_div(const float a, const float b) {
+    return a / b;
+}
+
+void ggml_sycl_add(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
+
+void ggml_sycl_sub(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
+
+void ggml_sycl_mul(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
+
+void ggml_sycl_div(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
+
+void ggml_sycl_repeat(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
+
+
+#endif //GGML_SYCL_BINBCAST_HPP
+
diff --git a/ggml/src/ggml-sycl/common.hpp b/ggml/src/ggml-sycl/common.hpp
index 3e1ceeaa49..96becabc85 100644
--- a/ggml/src/ggml-sycl/common.hpp
+++ b/ggml/src/ggml-sycl/common.hpp
@@ -494,286 +494,5 @@ static __dpct_inline__ Tp* get_pointer(sycl::local_accessor<Tp, 1> acc) {
 
 int64_t downsample_sycl_global_range(int64_t accumulate_block_num, int64_t block_size);
 
-template <float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t>
-static void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst_t * dst,
-        int ne0, int ne1, int ne2, int ne3,
-        int ne10, int ne11, int ne12, int ne13,
-        /*int s0, */ int s1,  int s2,  int s3,
-        /*int s00,*/ int s01, int s02, int s03,
-        /*int s10,*/ int s11, int s12, int s13,
-        const sycl::nd_item<3> &item_ct1) {
-    const int i0s = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                    item_ct1.get_local_id(2);
-    const int i1 = (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
-                    item_ct1.get_local_id(1));
-    const int i2 = (item_ct1.get_local_range(0) * item_ct1.get_group(0) +
-                    item_ct1.get_local_id(0)) /
-                   ne3;
-    const int i3 = (item_ct1.get_local_range(0) * item_ct1.get_group(0) +
-                    item_ct1.get_local_id(0)) %
-                   ne3;
-
-    if (i0s >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
-        return;
-    }
-
-    const int i11 = i1 % ne11;
-    const int i12 = i2 % ne12;
-    const int i13 = i3 % ne13;
-
-    const size_t i_src0 =  i3*s03 +  i2*s02 +  i1*s01;
-    const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
-    const size_t i_dst  =  i3*s3  +  i2*s2  +  i1*s1;
-
-    const src0_t * src0_row = src0 + i_src0;
-    const src1_t * src1_row = src1 + i_src1;
-    dst_t * dst_row = dst + i_dst;
-
-    for (int i0 = i0s; i0 < ne0;
-         i0 += item_ct1.get_local_range(2) * item_ct1.get_group_range(2)) {
-        const int i10 = i0 % ne10;
-        dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]);
-    }
-}
-
-template <float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t>
-static void k_bin_bcast_unravel(const src0_t * src0, const src1_t * src1, dst_t * dst,
-        int ne0, int ne1, int ne2, int ne3,
-        int ne10, int ne11, int ne12, int ne13,
-        /*int s0, */ int s1,  int s2,  int s3,
-        /*int s00,*/ int s01, int s02, int s03,
-        /*int s10,*/ int s11, int s12, int s13,
-        const sycl::nd_item<3> &item_ct1) {
-
-    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                  item_ct1.get_local_id(2);
-
-    const int i3 = i/(ne2*ne1*ne0);
-    const int i2 = (i/(ne1*ne0)) % ne2;
-    const int i1 = (i/ne0) % ne1;
-    const int i0 = i % ne0;
-
-    if (i0 >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
-        return;
-    }
-
-    const int i11 = i1 % ne11;
-    const int i12 = i2 % ne12;
-    const int i13 = i3 % ne13;
-
-    const size_t i_src0 =  i3*s03 +  i2*s02 +  i1*s01;
-    const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
-    const size_t i_dst  =  i3*s3  +  i2*s2  +  i1*s1;
-
-    const src0_t * src0_row = src0 + i_src0;
-    const src1_t * src1_row = src1 + i_src1;
-    dst_t * dst_row = dst + i_dst;
-
-    const int i10 = i0 % ne10;
-    dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]);
-}
-
-
-template <float (*bin_op)(const float, const float)>
-struct bin_bcast_sycl {
-    template <typename src0_t, typename src1_t, typename dst_t>
-    void operator()(ggml_backend_sycl_context & ctx,
-                    const struct ggml_tensor *src0,
-                    const struct ggml_tensor *src1, struct ggml_tensor *dst,
-                    const src0_t *src0_dd, const src1_t *src1_dd, dst_t *dst_dd,
-                    queue_ptr stream) {
-
-        GGML_TENSOR_BINARY_OP_LOCALS
-
-        int nr0 = ne10/ne0;
-        int nr1 = ne11/ne1;
-        int nr2 = ne12/ne2;
-        int nr3 = ne13/ne3;
-
-        int nr[4] = { nr0, nr1, nr2, nr3 };
-
-        // collapse dimensions until first broadcast dimension
-        int64_t cne[] = {ne0, ne1, ne2, ne3};
-        int64_t cne0[] = {ne00, ne01, ne02, ne03};
-        int64_t cne1[] = {ne10, ne11, ne12, ne13};
-        size_t cnb[] = {nb0, nb1, nb2, nb3};
-        size_t cnb0[] = {nb00, nb01, nb02, nb03};
-        size_t cnb1[] = {nb10, nb11, nb12, nb13};
-        auto collapse = [](int64_t cne[]) {
-            cne[0] *= cne[1];
-            cne[1] = cne[2];
-            cne[2] = cne[3];
-            cne[3] = 1;
-        };
-
-        auto collapse_nb = [](size_t cnb[], int64_t cne[]) {
-            cnb[1] *= cne[1];
-            cnb[2] *= cne[2];
-            cnb[3] *= cne[3];
-        };
-
-        if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ggml_is_contiguous(dst)) {
-            for (int i = 0; i < 4; i++) {
-                if (nr[i] != 1) {
-                    break;
-                }
-                if (i > 0) {
-                    collapse_nb(cnb, cne);
-                    collapse_nb(cnb0, cne0);
-                    collapse_nb(cnb1, cne1);
-                    collapse(cne);
-                    collapse(cne0);
-                    collapse(cne1);
-                }
-            }
-        }
-        {
-            int64_t ne0 = cne[0];
-            int64_t ne1 = cne[1];
-            int64_t ne2 = cne[2];
-            int64_t ne3 = cne[3];
-
-            int64_t ne10 = cne1[0];
-            int64_t ne11 = cne1[1];
-            int64_t ne12 = cne1[2];
-            int64_t ne13 = cne1[3];
-
-            size_t nb0 = cnb[0];
-            size_t nb1 = cnb[1];
-            size_t nb2 = cnb[2];
-            size_t nb3 = cnb[3];
-
-            size_t nb00 = cnb0[0];
-            size_t nb01 = cnb0[1];
-            size_t nb02 = cnb0[2];
-            size_t nb03 = cnb0[3];
-
-            size_t nb10 = cnb1[0];
-            size_t nb11 = cnb1[1];
-            size_t nb12 = cnb1[2];
-            size_t nb13 = cnb1[3];
-
-            size_t s0 = nb0 / sizeof(dst_t);
-            size_t s1 = nb1 / sizeof(dst_t);
-            size_t s2 = nb2 / sizeof(dst_t);
-            size_t s3 = nb3 / sizeof(dst_t);
-
-            size_t s10 = nb10 / sizeof(src1_t);
-            size_t s11 = nb11 / sizeof(src1_t);
-            size_t s12 = nb12 / sizeof(src1_t);
-            size_t s13 = nb13 / sizeof(src1_t);
-
-            size_t s00 = nb00 / sizeof(src0_t);
-            size_t s01 = nb01 / sizeof(src0_t);
-            size_t s02 = nb02 / sizeof(src0_t);
-            size_t s03 = nb03 / sizeof(src0_t);
-
-            GGML_UNUSED(s00);
-
-            GGML_ASSERT(nb0 % sizeof(dst_t) == 0);
-            GGML_ASSERT(nb1 % sizeof(dst_t) == 0);
-            GGML_ASSERT(nb2 % sizeof(dst_t) == 0);
-            GGML_ASSERT(nb3 % sizeof(dst_t) == 0);
-
-            GGML_ASSERT(nb00 % sizeof(src0_t) == 0);
-            GGML_ASSERT(nb01 % sizeof(src0_t) == 0);
-            GGML_ASSERT(nb02 % sizeof(src0_t) == 0);
-            GGML_ASSERT(nb03 % sizeof(src0_t) == 0);
-
-            GGML_ASSERT(nb10 % sizeof(src1_t) == 0);
-            GGML_ASSERT(nb11 % sizeof(src1_t) == 0);
-            GGML_ASSERT(nb12 % sizeof(src1_t) == 0);
-            GGML_ASSERT(nb13 % sizeof(src1_t) == 0);
-
-            GGML_ASSERT(s0 == 1);
-            GGML_ASSERT(s10 == 1);
-
-            const int block_size = 128;
-
-            int64_t hne0 = std::max(ne0/2LL, 1LL);
-
-            sycl::range<3> block_dims(1, 1, 1);
-            block_dims[2] = std::min<unsigned int>(hne0, block_size);
-            block_dims[1] = std::min<unsigned int>(
-                ne1, block_size / (unsigned int)block_dims[2]);
-            block_dims[0] = std::min(
-                std::min<unsigned int>(
-                    ne2 * ne3, block_size / (unsigned int)block_dims[2] /
-                                   (unsigned int)block_dims[1]),
-                64U);
-
-            sycl::range<3> block_nums(
-                (ne2 * ne3 + block_dims[0] - 1) / block_dims[0],
-                (ne1 + block_dims[1] - 1) / block_dims[1],
-                (hne0 + block_dims[2] - 1) / block_dims[2]);
-
-            if (block_nums[0] > 65535) {
-                // this is the maximum number of blocks in z direction, fallback to 1D grid kernel
-                int block_num = (ne0*ne1*ne2*ne3 + block_size - 1) / block_size;
-                {
-                    dpct::has_capability_or_fail(stream->get_device(),
-                                                 {sycl::aspect::fp16});
-
-                    stream->parallel_for(
-                        sycl::nd_range<3>(sycl::range<3>(1, 1, block_num) *
-                                              sycl::range<3>(1, 1, block_size),
-                                          sycl::range<3>(1, 1, block_size)),
-                        [=](sycl::nd_item<3> item_ct1) {
-                            k_bin_bcast_unravel<bin_op>(
-                                src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3,
-                                ne10, ne11, ne12, ne13, s1, s2, s3, s01, s02,
-                                s03, s11, s12, s13, item_ct1);
-                        });
-                }
-            } else {
-                /*
-                DPCT1049:16: The work-group size passed to the SYCL kernel may
-                exceed the limit. To get the device limit, query
-                info::device::max_work_group_size. Adjust the work-group size if
-                needed.
-                */
-                dpct::has_capability_or_fail(stream->get_device(),
-                                             {sycl::aspect::fp16});
-
-                stream->parallel_for(
-                    sycl::nd_range<3>(block_nums * block_dims, block_dims),
-                    [=](sycl::nd_item<3> item_ct1) {
-                        k_bin_bcast<bin_op>(src0_dd, src1_dd, dst_dd, ne0, ne1,
-                                            ne2, ne3, ne10, ne11, ne12, ne13,
-                                            s1, s2, s3, s01, s02, s03, s11, s12, s13,
-                                            item_ct1);
-                    });
-            }
-        }
-        GGML_UNUSED(ctx);
-    }
-};
-
-template <class op>
-inline void ggml_sycl_op_bin_bcast(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
-                                   const ggml_tensor *src1, ggml_tensor *dst) {
-    dpct::queue_ptr main_stream = ctx.stream();
-
-    if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
-        op()(ctx, src0, src1, dst, (const float *)src0->data, (const float *)src1->data, (float *)dst->data, main_stream);
-    } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
-        op()(ctx, src0, src1, dst, (const sycl::half *)src0->data, (const float *)src1->data,
-             (sycl::half *)dst->data, main_stream);
-    } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
-        op()(ctx, src0, src1, dst, (const sycl::half *)src0->data, (const float *)src1->data, (float *)dst->data,
-             main_stream);
-    } else if (src0->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_I32) {
-        op()(ctx, src0, src1, dst, (const int32_t *)src0->data, (const int32_t *)src1->data, (int32_t *)dst->data,
-             main_stream);
-    } else if (src0->type == GGML_TYPE_I16 && dst->type == GGML_TYPE_I16) {
-        op()(ctx, src0, src1, dst, (const int16_t *)src0->data, (const int16_t *)src1->data, (int16_t *)dst->data,
-             main_stream);
-    } else {
-        fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s, src1: %s\n", __func__,
-            ggml_type_name(dst->type), ggml_type_name(src0->type), ggml_type_name(src1->type));
-        GGML_ABORT("fatal error");
-    }
-}
-
 bool gpu_has_xmx(sycl::device &dev);
 #endif // GGML_SYCL_COMMON_HPP
diff --git a/ggml/src/ggml-sycl/element_wise.cpp b/ggml/src/ggml-sycl/element_wise.cpp
index b36aa7a9d2..fc25d98ddf 100644
--- a/ggml/src/ggml-sycl/element_wise.cpp
+++ b/ggml/src/ggml-sycl/element_wise.cpp
@@ -1261,27 +1261,6 @@ inline void ggml_sycl_op_acc(ggml_backend_sycl_context & ctx, ggml_tensor *dst)
 }
 
 
-inline void ggml_sycl_op_add(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
-
-    ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_add>>(ctx, dst->src[0], dst->src[1], dst);
-}
-
-inline void ggml_sycl_op_sub(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
-
-    ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_sub>>(ctx, dst->src[0], dst->src[1], dst);
-}
-
-inline void ggml_sycl_op_mul(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
-
-    ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_mul>>(ctx, dst->src[0], dst->src[1], dst);
-}
-
-inline void ggml_sycl_op_div(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
-
-    ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_div>>(ctx, dst->src[0], dst->src[1], dst);
-}
-
-
 void ggml_sycl_sqrt(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
     GGML_SYCL_DEBUG("call %s: DST Tensor type: %s\n", __func__, ggml_type_name(dst->type));
     ggml_sycl_op_sqrt(ctx, dst);
@@ -1409,29 +1388,3 @@ void ggml_sycl_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
     GGML_SYCL_DEBUG("call %s done\n", __func__);
 }
 
-
-
-
-void ggml_sycl_add(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
-    GGML_SYCL_DEBUG("call %s\n", __func__);
-    ggml_sycl_op_add(ctx, dst);
-    GGML_SYCL_DEBUG("call %s done\n", __func__);
-}
-
-void ggml_sycl_sub(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
-    GGML_SYCL_DEBUG("call %s\n", __func__);
-    ggml_sycl_op_sub(ctx, dst);
-    GGML_SYCL_DEBUG("call %s done\n", __func__);
-}
-
-void ggml_sycl_mul(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
-    GGML_SYCL_DEBUG("call %s\n", __func__);
-    ggml_sycl_op_mul(ctx, dst);
-    GGML_SYCL_DEBUG("call %s done\n", __func__);
-}
-
-void ggml_sycl_div(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
-    GGML_SYCL_DEBUG("call %s\n", __func__);
-    ggml_sycl_op_div(ctx, dst);
-    GGML_SYCL_DEBUG("call %s done\n", __func__);
-}
diff --git a/ggml/src/ggml-sycl/element_wise.hpp b/ggml/src/ggml-sycl/element_wise.hpp
index 76d91ba108..e623cb56f7 100644
--- a/ggml/src/ggml-sycl/element_wise.hpp
+++ b/ggml/src/ggml-sycl/element_wise.hpp
@@ -10,27 +10,6 @@ T neg_infinity() {
     return -std::numeric_limits<T>::infinity();
 }
 
-static __dpct_inline__ float op_repeat(const float a, const float b) {
-    return b;
-    GGML_UNUSED(a);
-}
-
-static __dpct_inline__ float op_add(const float a, const float b) {
-    return a + b;
-}
-
-static __dpct_inline__ float op_sub(const float a, const float b) {
-    return a - b;
-}
-
-static __dpct_inline__ float op_mul(const float a, const float b) {
-    return a * b;
-}
-
-static __dpct_inline__ float op_div(const float a, const float b) {
-    return a / b;
-}
-
 template<typename T>
 struct typed_data {
     const T * src;
@@ -87,14 +66,5 @@ void ggml_sycl_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
 
 void ggml_sycl_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
 
-// ---------
-
-void ggml_sycl_add(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
-
-void ggml_sycl_sub(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
-
-void ggml_sycl_mul(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
-
-void ggml_sycl_div(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
-
 #endif // GGML_SYCL_ELEMENTWISE_HPP
+
diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp
index 4d2fda0bfa..1de34c9629 100644
--- a/ggml/src/ggml-sycl/ggml-sycl.cpp
+++ b/ggml/src/ggml-sycl/ggml-sycl.cpp
@@ -1967,11 +1967,6 @@ catch (sycl::exception const &exc) {
   std::exit(1);
 }
 
-static void ggml_sycl_op_repeat(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
-    ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_repeat>>(ctx, dst, dst->src[0], dst);
-}
-
-
 inline void ggml_sycl_op_mul_mat_sycl(
     ggml_backend_sycl_context & ctx,
     const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst,
@@ -2600,12 +2595,6 @@ catch (sycl::exception const &exc) {
 }
 
 
-static void ggml_sycl_repeat(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
-    GGML_SYCL_DEBUG("call %s\n", __func__);
-    ggml_sycl_op_repeat(ctx, dst);
-    GGML_SYCL_DEBUG("call %s done\n", __func__);
-}
-
 static void ggml_sycl_get_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
     GGML_SYCL_DEBUG("call %s\n", __func__);
     ggml_sycl_op_get_rows(ctx, dst);
@@ -3972,7 +3961,6 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_ARGMAX:
         case GGML_OP_NONE:
         case GGML_OP_RESHAPE:
-        case GGML_OP_REPEAT:
         case GGML_OP_VIEW:
         case GGML_OP_PERMUTE:
         case GGML_OP_TRANSPOSE:
@@ -3982,7 +3970,8 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_SUB:
         case GGML_OP_MUL:
         case GGML_OP_DIV:
-            return (op->src[0]->type == GGML_TYPE_F32);
+        case GGML_OP_REPEAT:
+            return true;
         case GGML_OP_SQR:
         case GGML_OP_SQRT:
         case GGML_OP_SIN:

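Note on the supports_op change above: every dispatch in the new
ggml_sycl_op_bin_bcast still computes in float and only varies the load/store
types, which is what allows GGML_OP_ADD/SUB/MUL/DIV/REPEAT to be accepted for
more than GGML_TYPE_F32. A reduced, self-contained sketch of that pattern
(illustrative names, assuming USM device pointers; not code from the patch):

    #include <sycl/sycl.hpp>

    struct add_op {
        float operator()(float a, float b) const { return a + b; }
    };

    // storage types vary (fp16/fp32/int) while arithmetic stays in float,
    // mirroring dst_row[i0] = (dst_t)bin_op((float)a, (float)b) in the kernels
    template <typename Op, typename src_t, typename dst_t>
    void elementwise(sycl::queue & q, const src_t * a, const src_t * b, dst_t * d, size_t n) {
        q.parallel_for(sycl::range<1>(n), [=](sycl::id<1> i) {
            d[i] = (dst_t) Op{}((float) a[i], (float) b[i]);
        }).wait();
    }

    // e.g. elementwise<add_op>(q, a_half, b_half, d_half, n) for f16 in, f16 out
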
From 35370ba94593c3abc9550328377d461fc4f822b7 Mon Sep 17 00:00:00 2001
From: Xuan-Son Nguyen 
Date: Fri, 18 Apr 2025 19:58:12 +0200
Subject: [PATCH 12/39] server : use std::move whenever possible (#12936)

* server : use std::move whenever possible

* use r-value ref

* Apply suggestions from code review

Co-authored-by: Georgi Gerganov 

* make task creation scoped

* restore std::move

* fix task_id not set correctly

* apply changes from suggestion

Co-authored-by: ggerganov 

---------

Co-authored-by: Georgi Gerganov 
---
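A note on the r-value overloads below: once a task is moved into the queue the
local object is in a moved-from state, so any field needed afterwards (its id)
must be captured before the std::move. A minimal standalone sketch of the
pattern, simplified from server_queue::post (types here are illustrative):

    #include <deque>
    #include <utility>

    struct task { int id = -1; /* payload ... */ };

    struct queue {
        std::deque<task> tasks;
        int post(task && t) {
            const int task_id = t.id;       // read before the move
            tasks.push_back(std::move(t));  // t is moved-from after this line
            return task_id;                 // returning t.id here would be a bug
        }
    };
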
 examples/server/server.cpp | 245 ++++++++++++++++++++-----------------
 1 file changed, 132 insertions(+), 113 deletions(-)

diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index d87bda1a0d..c580ec1232 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -1552,29 +1552,30 @@ struct server_queue {
     std::condition_variable condition_tasks;
 
     // callback functions
-    std::function<void(server_task)> callback_new_task;
-    std::function<void(void)>        callback_update_slots;
+    std::function<void(server_task &&)> callback_new_task;
+    std::function<void(void)>           callback_update_slots;
 
     // Add a new task to the end of the queue
-    int post(server_task task, bool front = false) {
+    int post(server_task && task, bool front = false) {
         std::unique_lock<std::mutex> lock(mutex_tasks);
         GGML_ASSERT(task.id != -1);
         // if this is cancel task make sure to clean up pending tasks
         if (task.type == SERVER_TASK_TYPE_CANCEL) {
             cleanup_pending_task(task.id_target);
         }
-        QUE_DBG("new task, id = %d, front = %d\n", task.id, front);
+        const int task_id = task.id;
+        QUE_DBG("new task, id = %d, front = %d\n", task_id, front);
         if (front) {
             queue_tasks.push_front(std::move(task));
         } else {
             queue_tasks.push_back(std::move(task));
         }
         condition_tasks.notify_one();
-        return task.id;
+        return task_id;
     }
 
     // multi-task version of post()
-    int post(std::vector<server_task> & tasks, bool front = false) {
+    int post(std::vector<server_task> && tasks, bool front = false) {
         std::unique_lock<std::mutex> lock(mutex_tasks);
         for (auto & task : tasks) {
             if (task.id == -1) {
@@ -1596,7 +1597,7 @@ struct server_queue {
     }
 
     // Add a new task, but defer until one slot is available
-    void defer(server_task task) {
+    void defer(server_task && task) {
         std::unique_lock<std::mutex> lock(mutex_tasks);
         QUE_DBG("defer task, id = %d\n", task.id);
         queue_tasks_deferred.push_back(std::move(task));
@@ -1611,7 +1612,7 @@ struct server_queue {
     }
 
     // Register function to process a new task
-    void on_new_task(std::function<void(server_task)> callback) {
+    void on_new_task(std::function<void(server_task &&)> callback) {
         callback_new_task = std::move(callback);
     }
 
@@ -1660,7 +1661,7 @@ struct server_queue {
                     lock.unlock();
                     break;
                 }
-                server_task task = queue_tasks.front();
+                server_task task = std::move(queue_tasks.front());
                 queue_tasks.pop_front();
                 lock.unlock();
 
@@ -2004,7 +2005,7 @@ struct server_context {
 
             slot.reset();
 
-            slots.push_back(slot);
+            slots.push_back(std::move(slot));
         }
 
         default_generation_settings_for_props = slots[0].to_json();
@@ -2105,7 +2106,7 @@ struct server_context {
         return true;
     }
 
-    bool launch_slot_with_task(server_slot & slot, const server_task & task) {
+    bool launch_slot_with_task(server_slot & slot, server_task && task) {
         slot.reset();
         slot.id_task       = task.id;
         slot.index         = task.index;
@@ -2113,10 +2114,10 @@ struct server_context {
         slot.params        = std::move(task.params);
         slot.prompt_tokens = std::move(task.prompt_tokens);
 
-        if (!are_lora_equal(task.params.lora, slot.lora)) {
+        if (!are_lora_equal(slot.params.lora, slot.lora)) {
             // if lora is changed, we cannot reuse cached tokens
             slot.cache_tokens.clear();
-            slot.lora = task.params.lora;
+            slot.lora = slot.params.lora;
         }
 
         bool can_detokenize = can_be_detokenized(ctx, slot.prompt_tokens);
@@ -2547,10 +2548,10 @@ struct server_context {
             server_task task(SERVER_TASK_TYPE_CANCEL);
             task.id_target = id_task;
             queue_results.remove_waiting_task_id(id_task);
-            cancel_tasks.push_back(task);
+            cancel_tasks.push_back(std::move(task));
         }
         // push to beginning of the queue, so it has highest priority
-        queue_tasks.post(cancel_tasks, true);
+        queue_tasks.post(std::move(cancel_tasks), true);
     }
 
     // receive the results from task(s)
@@ -2637,7 +2638,7 @@ struct server_context {
     // Functions to process the task
     //
 
-    void process_single_task(server_task task) {
+    void process_single_task(server_task && task) {
         switch (task.type) {
             case SERVER_TASK_TYPE_COMPLETION:
             case SERVER_TASK_TYPE_INFILL:
@@ -2651,17 +2652,17 @@ struct server_context {
                     if (slot == nullptr) {
                         // if no slot is available, we defer this task for processing later
                         SRV_DBG("no slot is available, defer task, id_task = %d\n", task.id);
-                        queue_tasks.defer(task);
+                        queue_tasks.defer(std::move(task));
                         break;
                     }
                     if (slot->is_processing()) {
                         // if requested slot is unavailable, we defer this task for processing later
                         SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id);
-                        queue_tasks.defer(task);
+                        queue_tasks.defer(std::move(task));
                         break;
                     }
 
-                    if (!launch_slot_with_task(*slot, task)) {
+                    if (!launch_slot_with_task(*slot, std::move(task))) {
                         SRV_ERR("failed to launch slot with task, id_task = %d\n", task.id);
                         break;
                     }
@@ -2740,7 +2741,7 @@ struct server_context {
                     if (slot->is_processing()) {
                         // if requested slot is unavailable, we defer this task for processing later
                         SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id);
-                        queue_tasks.defer(task);
+                        queue_tasks.defer(std::move(task));
                         break;
                     }
 
@@ -2776,7 +2777,7 @@ struct server_context {
                     if (slot->is_processing()) {
                         // if requested slot is unavailable, we defer this task for processing later
                         SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id);
-                        queue_tasks.defer(task);
+                        queue_tasks.defer(std::move(task));
                         break;
                     }
 
@@ -2819,7 +2820,7 @@ struct server_context {
                     if (slot->is_processing()) {
                         // if requested slot is unavailable, we defer this task for processing later
                         SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id);
-                        queue_tasks.defer(task);
+                        queue_tasks.defer(std::move(task));
                         break;
                     }
 
@@ -2871,7 +2872,7 @@ struct server_context {
 
             server_task task(SERVER_TASK_TYPE_NEXT_RESPONSE);
             task.id = queue_tasks.get_new_id();
-            queue_tasks.post(task);
+            queue_tasks.post(std::move(task));
         }
 
         // apply context-shift if needed
@@ -3633,14 +3634,17 @@ int main(int argc, char ** argv) {
         }
 
         // request slots data using task queue
-        server_task task(SERVER_TASK_TYPE_METRICS);
-        task.id = ctx_server.queue_tasks.get_new_id();
-        ctx_server.queue_results.add_waiting_task_id(task.id);
-        ctx_server.queue_tasks.post(task, true); // high-priority task
+        int task_id = ctx_server.queue_tasks.get_new_id();
+        {
+            server_task task(SERVER_TASK_TYPE_METRICS);
+            task.id = task_id;
+            ctx_server.queue_results.add_waiting_task_id(task_id);
+            ctx_server.queue_tasks.post(std::move(task), true); // high-priority task
+        }
 
         // get the result
-        server_task_result_ptr result = ctx_server.queue_results.recv(task.id);
-        ctx_server.queue_results.remove_waiting_task_id(task.id);
+        server_task_result_ptr result = ctx_server.queue_results.recv(task_id);
+        ctx_server.queue_results.remove_waiting_task_id(task_id);
 
         if (result->is_error()) {
             res_error(res, result->to_json());
@@ -3669,16 +3673,17 @@ int main(int argc, char ** argv) {
         }
 
         // request slots data using task queue
-        server_task task(SERVER_TASK_TYPE_METRICS);
-        task.id = ctx_server.queue_tasks.get_new_id();
-        task.metrics_reset_bucket = true;
-
-        ctx_server.queue_results.add_waiting_task_id(task.id);
-        ctx_server.queue_tasks.post(task, true); // high-priority task
+        int task_id = ctx_server.queue_tasks.get_new_id();
+        {
+            server_task task(SERVER_TASK_TYPE_METRICS);
+            task.id = task_id;
+            ctx_server.queue_results.add_waiting_task_id(task_id);
+            ctx_server.queue_tasks.post(std::move(task), true); // high-priority task
+        }
 
         // get the result
-        server_task_result_ptr result = ctx_server.queue_results.recv(task.id);
-        ctx_server.queue_results.remove_waiting_task_id(task.id);
+        server_task_result_ptr result = ctx_server.queue_results.recv(task_id);
+        ctx_server.queue_results.remove_waiting_task_id(task_id);
 
         if (result->is_error()) {
             res_error(res, result->to_json());
@@ -3775,17 +3780,20 @@ int main(int argc, char ** argv) {
         }
         std::string filepath = params.slot_save_path + filename;
 
-        server_task task(SERVER_TASK_TYPE_SLOT_SAVE);
-        task.id = ctx_server.queue_tasks.get_new_id();
-        task.slot_action.slot_id  = id_slot;
-        task.slot_action.filename = filename;
-        task.slot_action.filepath = filepath;
+        int task_id = ctx_server.queue_tasks.get_new_id();
+        {
+            server_task task(SERVER_TASK_TYPE_SLOT_SAVE);
+            task.id = task_id;
+            task.slot_action.slot_id  = id_slot;
+            task.slot_action.filename = filename;
+            task.slot_action.filepath = filepath;
 
-        ctx_server.queue_results.add_waiting_task_id(task.id);
-        ctx_server.queue_tasks.post(task);
+            ctx_server.queue_results.add_waiting_task_id(task_id);
+            ctx_server.queue_tasks.post(std::move(task));
+        }
 
-        server_task_result_ptr result = ctx_server.queue_results.recv(task.id);
-        ctx_server.queue_results.remove_waiting_task_id(task.id);
+        server_task_result_ptr result = ctx_server.queue_results.recv(task_id);
+        ctx_server.queue_results.remove_waiting_task_id(task_id);
 
         if (result->is_error()) {
             res_error(res, result->to_json());
@@ -3804,17 +3812,20 @@ int main(int argc, char ** argv) {
         }
         std::string filepath = params.slot_save_path + filename;
 
-        server_task task(SERVER_TASK_TYPE_SLOT_RESTORE);
-        task.id = ctx_server.queue_tasks.get_new_id();
-        task.slot_action.slot_id  = id_slot;
-        task.slot_action.filename = filename;
-        task.slot_action.filepath = filepath;
+        int task_id = ctx_server.queue_tasks.get_new_id();
+        {
+            server_task task(SERVER_TASK_TYPE_SLOT_RESTORE);
+            task.id = task_id;
+            task.slot_action.slot_id  = id_slot;
+            task.slot_action.filename = filename;
+            task.slot_action.filepath = filepath;
 
-        ctx_server.queue_results.add_waiting_task_id(task.id);
-        ctx_server.queue_tasks.post(task);
+            ctx_server.queue_results.add_waiting_task_id(task_id);
+            ctx_server.queue_tasks.post(std::move(task));
+        }
 
-        server_task_result_ptr result = ctx_server.queue_results.recv(task.id);
-        ctx_server.queue_results.remove_waiting_task_id(task.id);
+        server_task_result_ptr result = ctx_server.queue_results.recv(task_id);
+        ctx_server.queue_results.remove_waiting_task_id(task_id);
 
         if (result->is_error()) {
             res_error(res, result->to_json());
@@ -3826,15 +3837,18 @@ int main(int argc, char ** argv) {
     };
 
     const auto handle_slots_erase = [&ctx_server, &res_error, &res_ok](const httplib::Request & /* req */, httplib::Response & res, int id_slot) {
-        server_task task(SERVER_TASK_TYPE_SLOT_ERASE);
-        task.id = ctx_server.queue_tasks.get_new_id();
-        task.slot_action.slot_id = id_slot;
+        int task_id = ctx_server.queue_tasks.get_new_id();
+        {
+            server_task task(SERVER_TASK_TYPE_SLOT_ERASE);
+            task.id = task_id;
+            task.slot_action.slot_id = id_slot;
 
-        ctx_server.queue_results.add_waiting_task_id(task.id);
-        ctx_server.queue_tasks.post(task);
+            ctx_server.queue_results.add_waiting_task_id(task_id);
+            ctx_server.queue_tasks.post(std::move(task));
+        }
 
-        server_task_result_ptr result = ctx_server.queue_results.recv(task.id);
-        ctx_server.queue_results.remove_waiting_task_id(task.id);
+        server_task_result_ptr result = ctx_server.queue_results.recv(task_id);
+        ctx_server.queue_results.remove_waiting_task_id(task_id);
 
         if (result->is_error()) {
             res_error(res, result->to_json());
@@ -3938,9 +3952,10 @@ int main(int argc, char ** argv) {
         }
 
         auto completion_id = gen_chatcmplid();
-        std::vector<server_task> tasks;
-
+        std::unordered_set<int> task_ids;
         try {
+            std::vector<server_task> tasks;
+
             const auto & prompt = data.at("prompt");
             // TODO: this log can become very long, put it behind a flag or think about a more compact format
             //SRV_DBG("Prompt: %s\n", prompt.is_string() ? prompt.get().c_str() : prompt.dump(2).c_str());
@@ -3955,9 +3970,9 @@ int main(int argc, char ** argv) {
 
                 task.prompt_tokens    = std::move(tokenized_prompts[i]);
                 task.params           = server_task::params_from_json_cmpl(
-                                            ctx_server.ctx,
-                                            ctx_server.params_base,
-                                            data);
+                        ctx_server.ctx,
+                        ctx_server.params_base,
+                        data);
                 task.id_selected_slot = json_value(data, "id_slot", -1);
 
                 // OAI-compat
@@ -3965,18 +3980,18 @@ int main(int argc, char ** argv) {
                 task.params.oaicompat_cmpl_id         = completion_id;
                 // oaicompat_model is already populated by params_from_json_cmpl
 
-                tasks.push_back(task);
+                tasks.push_back(std::move(task));
             }
+
+            task_ids = server_task::get_list_id(tasks);
+            ctx_server.queue_results.add_waiting_tasks(tasks);
+            ctx_server.queue_tasks.post(std::move(tasks));
         } catch (const std::exception & e) {
             res_error(res, format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST));
             return;
         }
 
-        ctx_server.queue_results.add_waiting_tasks(tasks);
-        ctx_server.queue_tasks.post(tasks);
-
         bool stream = json_value(data, "stream", false);
-        const auto task_ids = server_task::get_list_id(tasks);
 
         if (!stream) {
             ctx_server.receive_multi_results(task_ids, [&](std::vector<server_task_result_ptr> & results) {
@@ -4268,6 +4283,7 @@ int main(int argc, char ** argv) {
         // create and queue the task
         json responses = json::array();
         bool error = false;
+        std::unordered_set<int> task_ids;
         {
             std::vector<server_task> tasks;
             for (size_t i = 0; i < tokenized_prompts.size(); i++) {
@@ -4280,28 +4296,27 @@ int main(int argc, char ** argv) {
                 // OAI-compat
                 task.params.oaicompat = oaicompat;
 
-                tasks.push_back(task);
+                tasks.push_back(std::move(task));
             }
 
+            task_ids = server_task::get_list_id(tasks);
             ctx_server.queue_results.add_waiting_tasks(tasks);
-            ctx_server.queue_tasks.post(tasks);
-
-            // get the result
-            std::unordered_set<int> task_ids = server_task::get_list_id(tasks);
-
-            ctx_server.receive_multi_results(task_ids, [&](std::vector<server_task_result_ptr> & results) {
-                for (auto & res : results) {
-                    GGML_ASSERT(dynamic_cast<server_task_result_embd*>(res.get()) != nullptr);
-                    responses.push_back(res->to_json());
-                }
-            }, [&](const json & error_data) {
-                res_error(res, error_data);
-                error = true;
-            }, req.is_connection_closed);
-
-            ctx_server.queue_results.remove_waiting_task_ids(task_ids);
+            ctx_server.queue_tasks.post(std::move(tasks));
         }
 
+        // get the result
+        ctx_server.receive_multi_results(task_ids, [&](std::vector<server_task_result_ptr> & results) {
+            for (auto & res : results) {
+                GGML_ASSERT(dynamic_cast<server_task_result_embd*>(res.get()) != nullptr);
+                responses.push_back(res->to_json());
+            }
+        }, [&](const json & error_data) {
+            res_error(res, error_data);
+            error = true;
+        }, req.is_connection_closed);
+
+        ctx_server.queue_results.remove_waiting_task_ids(task_ids);
+
         if (error) {
             return;
         }
@@ -4367,6 +4382,7 @@ int main(int argc, char ** argv) {
         // create and queue the task
         json responses = json::array();
         bool error = false;
+        std::unordered_set<int> task_ids;
         {
             std::vector<server_task> tasks;
             std::vector<llama_tokens> tokenized_docs = tokenize_input_prompts(ctx_server.vocab, documents, /* add_special */ false, true);
@@ -4376,26 +4392,24 @@ int main(int argc, char ** argv) {
                 task.id            = ctx_server.queue_tasks.get_new_id();
                 task.index         = i;
                 task.prompt_tokens = format_rerank(ctx_server.vocab, tokenized_query, tokenized_docs[i]);
-                tasks.push_back(task);
+                tasks.push_back(std::move(task));
             }
 
+            task_ids = server_task::get_list_id(tasks);
             ctx_server.queue_results.add_waiting_tasks(tasks);
-            ctx_server.queue_tasks.post(tasks);
-
-            // get the result
-            std::unordered_set<int> task_ids = server_task::get_list_id(tasks);
-
-            ctx_server.receive_multi_results(task_ids, [&](std::vector<server_task_result_ptr> & results) {
-                for (auto & res : results) {
-                    GGML_ASSERT(dynamic_cast<server_task_result_rerank*>(res.get()) != nullptr);
-                    responses.push_back(res->to_json());
-                }
-            }, [&](const json & error_data) {
-                res_error(res, error_data);
-                error = true;
-            }, req.is_connection_closed);
+            ctx_server.queue_tasks.post(std::move(tasks));
         }
 
+        ctx_server.receive_multi_results(task_ids, [&](std::vector<server_task_result_ptr> & results) {
+            for (auto & res : results) {
+                GGML_ASSERT(dynamic_cast<server_task_result_rerank*>(res.get()) != nullptr);
+                responses.push_back(res->to_json());
+            }
+        }, [&](const json & error_data) {
+            res_error(res, error_data);
+            error = true;
+        }, req.is_connection_closed);
+
         if (error) {
             return;
         }
@@ -4431,14 +4445,19 @@ int main(int argc, char ** argv) {
             res_error(res, format_error_response("Request body must be an array", ERROR_TYPE_INVALID_REQUEST));
             return;
         }
-        server_task task(SERVER_TASK_TYPE_SET_LORA);
-        task.id = ctx_server.queue_tasks.get_new_id();
-        task.set_lora = parse_lora_request(ctx_server.params_base.lora_adapters, body);
-        ctx_server.queue_results.add_waiting_task_id(task.id);
-        ctx_server.queue_tasks.post(task);
 
-        server_task_result_ptr result = ctx_server.queue_results.recv(task.id);
-        ctx_server.queue_results.remove_waiting_task_id(task.id);
+        int task_id = ctx_server.queue_tasks.get_new_id();
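+        // Build the task in its own scope and move it into the queue; only the id
+        // survives for collecting the result below.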
+        {
+            server_task task(SERVER_TASK_TYPE_SET_LORA);
+            task.id = task_id;
+            task.set_lora = parse_lora_request(ctx_server.params_base.lora_adapters, body);
+            ctx_server.queue_results.add_waiting_task_id(task_id);
+            ctx_server.queue_tasks.post(std::move(task));
+        }
+
+        // get the result
+        server_task_result_ptr result = ctx_server.queue_results.recv(task_id);
+        ctx_server.queue_results.remove_waiting_task_id(task_id);
 
         if (result->is_error()) {
             res_error(res, result->to_json());
@@ -4582,8 +4601,8 @@ int main(int argc, char ** argv) {
         common_chat_templates_source(ctx_server.chat_templates.get()),
         common_chat_format_example(ctx_server.chat_templates.get(), ctx_server.params_base.use_jinja).c_str());
 
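+    // The queue now hands each task to the callback by rvalue reference, so it is
+    // moved (not copied) into process_single_task.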
-    ctx_server.queue_tasks.on_new_task([&ctx_server](const server_task & task) {
-        ctx_server.process_single_task(task);
+    ctx_server.queue_tasks.on_new_task([&ctx_server](server_task && task) {
+        ctx_server.process_single_task(std::move(task));
     });
 
     ctx_server.queue_tasks.on_update_slots([&ctx_server]() {

From aff9d107b066c9d464f7ab1324280acfdcebf569 Mon Sep 17 00:00:00 2001
From: Chris Thompson 
Date: Fri, 18 Apr 2025 12:30:41 -0600
Subject: [PATCH 13/39] gguf-py : GGUF Editor GUI - Python + Qt6 (#12930)

---
 gguf-py/README.md                             |    7 +
 gguf-py/gguf/scripts/__init__.py              |    1 +
 gguf-py/gguf/scripts/gguf_editor_gui.py       | 1610 +++++++++++++++++
 gguf-py/pyproject.toml                        |    7 +-
 requirements/requirements-all.txt             |    2 +
 requirements/requirements-gguf_editor_gui.txt |    3 +
 6 files changed, 1629 insertions(+), 1 deletion(-)
 create mode 100755 gguf-py/gguf/scripts/gguf_editor_gui.py
 create mode 100644 requirements/requirements-gguf_editor_gui.txt

diff --git a/gguf-py/README.md b/gguf-py/README.md
index dd4ab7bde7..ca7e09c681 100644
--- a/gguf-py/README.md
+++ b/gguf-py/README.md
@@ -11,6 +11,11 @@ as an example for its usage.
 pip install gguf
 ```
 
+Optionally, you can install gguf with the `gui` extra to enable the visual GGUF editor.
+```sh
+pip install gguf[gui]
+```
+
 ## API Examples/Simple Tools
 
 [examples/writer.py](https://github.com/ggml-org/llama.cpp/blob/master/gguf-py/examples/writer.py) — Generates `example.gguf` in the current directory to demonstrate generating a GGUF file. Note that this file cannot be used as a model.
@@ -25,6 +30,8 @@ pip install gguf
 
 [gguf/scripts/gguf_new_metadata.py](https://github.com/ggml-org/llama.cpp/blob/master/gguf-py/gguf/scripts/gguf_new_metadata.py) — Copies a GGUF file with added/modified/removed metadata values.
 
+[gguf/scripts/gguf_editor_gui.py](https://github.com/ggml-org/llama.cpp/blob/master/gguf-py/gguf/scripts/gguf_editor_gui.py) — A Qt-based editor for viewing, adding, editing, and removing metadata values in a GGUF file, as well as inspecting its tensors.
+
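+With the `gui` extra installed, the editor can be launched directly from a repository checkout, for example:
+
+```sh
+python3 gguf-py/gguf/scripts/gguf_editor_gui.py
+```
+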
 ## Development
 Maintainers who participate in development of this package are advised to install it in editable mode:
 
diff --git a/gguf-py/gguf/scripts/__init__.py b/gguf-py/gguf/scripts/__init__.py
index e77f2e9c97..72cc73e700 100644
--- a/gguf-py/gguf/scripts/__init__.py
+++ b/gguf-py/gguf/scripts/__init__.py
@@ -4,3 +4,4 @@ from .gguf_convert_endian import main as gguf_convert_endian_entrypoint
 from .gguf_dump import main as gguf_dump_entrypoint
 from .gguf_set_metadata import main as gguf_set_metadata_entrypoint
 from .gguf_new_metadata import main as gguf_new_metadata_entrypoint
+from .gguf_editor_gui import main as gguf_editor_gui_entrypoint
diff --git a/gguf-py/gguf/scripts/gguf_editor_gui.py b/gguf-py/gguf/scripts/gguf_editor_gui.py
new file mode 100755
index 0000000000..9dab6ca276
--- /dev/null
+++ b/gguf-py/gguf/scripts/gguf_editor_gui.py
@@ -0,0 +1,1610 @@
+#!/usr/bin/env python3
+from __future__ import annotations
+
+import logging
+import argparse
+import os
+import sys
+import numpy
+import enum
+from pathlib import Path
+from typing import Any, Optional, Tuple, Type
+import warnings
+
+import numpy as np
+from PySide6.QtWidgets import (
+    QApplication, QMainWindow, QWidget, QVBoxLayout, QHBoxLayout,
+    QPushButton, QLabel, QLineEdit, QFileDialog, QTableWidget,
+    QTableWidgetItem, QComboBox, QMessageBox, QTabWidget,
+    QTextEdit, QFormLayout,
+    QHeaderView, QDialog, QDialogButtonBox
+)
+from PySide6.QtCore import Qt
+
+# Necessary to load the local gguf package
+if "NO_LOCAL_GGUF" not in os.environ and (Path(__file__).parent.parent.parent.parent / 'gguf-py').exists():
+    sys.path.insert(0, str(Path(__file__).parent.parent.parent))
+
+import gguf
+from gguf import GGUFReader, GGUFWriter, GGUFValueType, ReaderField
+from gguf.constants import TokenType, RopeScalingType, PoolingType, GGMLQuantizationType
+
+logger = logging.getLogger("gguf-editor-gui")
+
+# Map of key names to enum types for automatic enum interpretation
+KEY_TO_ENUM_TYPE = {
+    gguf.Keys.Tokenizer.TOKEN_TYPE: TokenType,
+    gguf.Keys.Rope.SCALING_TYPE: RopeScalingType,
+    gguf.Keys.LLM.POOLING_TYPE: PoolingType,
+    gguf.Keys.General.FILE_TYPE: GGMLQuantizationType,
+}
+
+# Define the tokenizer keys that should be edited together
+TOKENIZER_LINKED_KEYS = [
+    gguf.Keys.Tokenizer.LIST,
+    gguf.Keys.Tokenizer.TOKEN_TYPE,
+    gguf.Keys.Tokenizer.SCORES
+]
+
+
+class TokenizerEditorDialog(QDialog):
+    def __init__(self, tokens, token_types, scores, parent=None):
+        super().__init__(parent)
+        self.setWindowTitle("Edit Tokenizer Data")
+        self.resize(900, 600)
+
+        self.tokens = tokens.copy() if tokens else []
+        self.token_types = token_types.copy() if token_types else []
+        self.scores = scores.copy() if scores else []
+
+        # Ensure all arrays have the same length
+        max_len = max(len(self.tokens), len(self.token_types), len(self.scores))
+        if len(self.tokens) < max_len:
+            self.tokens.extend([""] * (max_len - len(self.tokens)))
+        if len(self.token_types) < max_len:
+            self.token_types.extend([0] * (max_len - len(self.token_types)))
+        if len(self.scores) < max_len:
+            self.scores.extend([0.0] * (max_len - len(self.scores)))
+
+        layout = QVBoxLayout(self)
+
+        # Add filter controls
+        filter_layout = QHBoxLayout()
+        filter_layout.addWidget(QLabel("Filter:"))
+        self.filter_edit = QLineEdit()
+        self.filter_edit.setPlaceholderText("Type to filter tokens...")
+        self.filter_edit.textChanged.connect(self.apply_filter)
+        filter_layout.addWidget(self.filter_edit)
+
+        # Add page controls
+        self.page_size = 100  # Show 100 items per page
+        self.current_page = 0
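+        # Ceiling division: pages = ceil(len(tokens) / page_size), clamped to at
+        # least one page so an empty token list still renders.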
+        self.total_pages = max(1, (len(self.tokens) + self.page_size - 1) // self.page_size)
+
+        self.page_label = QLabel(f"Page 1 of {self.total_pages}")
+        filter_layout.addWidget(self.page_label)
+
+        prev_page = QPushButton("Previous")
+        prev_page.clicked.connect(self.previous_page)
+        filter_layout.addWidget(prev_page)
+
+        next_page = QPushButton("Next")
+        next_page.clicked.connect(self.next_page)
+        filter_layout.addWidget(next_page)
+
+        layout.addLayout(filter_layout)
+
+        # Tokenizer data table
+        self.tokens_table = QTableWidget()
+        self.tokens_table.setColumnCount(4)
+        self.tokens_table.setHorizontalHeaderLabels(["Index", "Token", "Type", "Score"])
+        self.tokens_table.horizontalHeader().setSectionResizeMode(0, QHeaderView.ResizeMode.ResizeToContents)
+        self.tokens_table.horizontalHeader().setSectionResizeMode(1, QHeaderView.ResizeMode.Stretch)
+        self.tokens_table.horizontalHeader().setSectionResizeMode(2, QHeaderView.ResizeMode.ResizeToContents)
+        self.tokens_table.horizontalHeader().setSectionResizeMode(3, QHeaderView.ResizeMode.ResizeToContents)
+
+        layout.addWidget(self.tokens_table)
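+
+        # Connect the double-click handler once here; reconnecting on every
+        # load_page() call would stack duplicate connections and open the edit
+        # dialog multiple times per double-click.
+        self.tokens_table.cellDoubleClicked.connect(self.handle_cell_double_click)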
+
+        # Controls
+        controls_layout = QHBoxLayout()
+
+        add_button = QPushButton("Add Token")
+        add_button.clicked.connect(self.add_token)
+        controls_layout.addWidget(add_button)
+
+        remove_button = QPushButton("Remove Selected")
+        remove_button.clicked.connect(self.remove_selected)
+        controls_layout.addWidget(remove_button)
+
+        controls_layout.addStretch()
+
+        layout.addLayout(controls_layout)
+
+        # Buttons
+        buttons = QDialogButtonBox(QDialogButtonBox.StandardButton.Ok | QDialogButtonBox.StandardButton.Cancel)
+        buttons.accepted.connect(self.accept)
+        buttons.rejected.connect(self.reject)
+        layout.addWidget(buttons)
+
+        # Initialize the filtered values
+        self.filtered_indices = list(range(len(self.tokens)))
+
+        # Load data for the first page
+        self.load_page()
+
+    def apply_filter(self):
+        """Filter the tokens based on the search text."""
+        filter_text = self.filter_edit.text().lower()
+
+        if not filter_text:
+            # No filter, show all values
+            self.filtered_indices = list(range(len(self.tokens)))
+        else:
+            # Apply filter
+            self.filtered_indices = []
+            for i, token in enumerate(self.tokens):
+                if filter_text in str(token).lower():
+                    self.filtered_indices.append(i)
+
+        # Reset to first page and reload
+        self.total_pages = max(1, (len(self.filtered_indices) + self.page_size - 1) // self.page_size)
+        self.current_page = 0
+        self.page_label.setText(f"Page 1 of {self.total_pages}")
+        self.load_page()
+
+    def previous_page(self):
+        """Go to the previous page of results."""
+        if self.current_page > 0:
+            self.current_page -= 1
+            self.page_label.setText(f"Page {self.current_page + 1} of {self.total_pages}")
+            self.load_page()
+
+    def next_page(self):
+        """Go to the next page of results."""
+        if self.current_page < self.total_pages - 1:
+            self.current_page += 1
+            self.page_label.setText(f"Page {self.current_page + 1} of {self.total_pages}")
+            self.load_page()
+
+    def load_page(self):
+        """Load the current page of tokenizer data."""
+        self.tokens_table.setRowCount(0)  # Clear the table
+
+        # Calculate start and end indices for the current page
+        start_idx = self.current_page * self.page_size
+        end_idx = min(start_idx + self.page_size, len(self.filtered_indices))
+
+        # Pre-allocate rows for better performance
+        self.tokens_table.setRowCount(end_idx - start_idx)
+
+        for row, i in enumerate(range(start_idx, end_idx)):
+            orig_idx = self.filtered_indices[i]
+
+            # Index
+            index_item = QTableWidgetItem(str(orig_idx))
+            index_item.setData(Qt.ItemDataRole.UserRole, orig_idx)  # Store original index
+            index_item.setFlags(index_item.flags() & ~Qt.ItemFlag.ItemIsEditable)
+            self.tokens_table.setItem(row, 0, index_item)
+
+            # Token
+            token_item = QTableWidgetItem(str(self.tokens[orig_idx]))
+            self.tokens_table.setItem(row, 1, token_item)
+
+            # Token Type
+            token_type = self.token_types[orig_idx] if orig_idx < len(self.token_types) else 0
+            try:
+                enum_val = TokenType(token_type)
+                display_text = f"{enum_val.name} ({token_type})"
+            except (ValueError, KeyError):
+                display_text = f"Unknown ({token_type})"
+
+            type_item = QTableWidgetItem(display_text)
+            type_item.setData(Qt.ItemDataRole.UserRole, token_type)
+
+            # The type cell is not inline-editable; it is edited via the double-click dialog
+            type_item.setFlags(type_item.flags() & ~Qt.ItemFlag.ItemIsEditable)
+            self.tokens_table.setItem(row, 2, type_item)
+
+            # Score
+            score = self.scores[orig_idx] if orig_idx < len(self.scores) else 0.0
+            score_item = QTableWidgetItem(str(score))
+            self.tokens_table.setItem(row, 3, score_item)
+
+    def handle_cell_double_click(self, row, column):
+        """Handle double-click on a cell, specifically for token type editing."""
+        if column == 2:  # Token Type column
+            orig_item = self.tokens_table.item(row, 0)
+            if orig_item:
+                orig_idx = orig_item.data(Qt.ItemDataRole.UserRole)
+                self.edit_token_type(row, orig_idx)
+
+    def edit_token_type(self, row, orig_idx):
+        """Edit a token type using a dialog with a dropdown of all enum options."""
+        current_value = self.token_types[orig_idx] if orig_idx < len(self.token_types) else 0
+
+        # Create a dialog with enum options
+        dialog = QDialog(self)
+        dialog.setWindowTitle("Select Token Type")
+        layout = QVBoxLayout(dialog)
+
+        combo = QComboBox()
+        for enum_val in TokenType:
+            combo.addItem(f"{enum_val.name} ({enum_val.value})", enum_val.value)
+
+        # Set current value
+        try:
+            if isinstance(current_value, int):
+                enum_val = TokenType(current_value)
+                combo.setCurrentText(f"{enum_val.name} ({current_value})")
+        except (ValueError, KeyError):
+            pass
+
+        layout.addWidget(combo)
+
+        buttons = QDialogButtonBox(QDialogButtonBox.StandardButton.Ok | QDialogButtonBox.StandardButton.Cancel)
+        buttons.accepted.connect(dialog.accept)
+        buttons.rejected.connect(dialog.reject)
+        layout.addWidget(buttons)
+
+        if dialog.exec() == QDialog.DialogCode.Accepted:
+            # Get the selected value
+            new_value = combo.currentData()
+            enum_val = TokenType(new_value)
+            display_text = f"{enum_val.name} ({new_value})"
+
+            # Update the display
+            type_item = self.tokens_table.item(row, 2)
+            if type_item:
+                type_item.setText(display_text)
+                type_item.setData(Qt.ItemDataRole.UserRole, new_value)
+
+            # Update the actual value
+            self.token_types[orig_idx] = new_value
+
+    def add_token(self):
+        """Add a new token to the end of the list."""
+        # Add to the end of the arrays
+        self.tokens.append("")
+        self.token_types.append(0)  # Default to normal token
+        self.scores.append(0.0)
+
+        orig_idx = len(self.tokens) - 1
+
+        # A freshly added token is the empty string, so it can only match an empty filter
+        filter_text = self.filter_edit.text().lower()
+        if not filter_text:
+            self.filtered_indices.append(orig_idx)
+
+        # Update pagination
+        self.total_pages = max(1, (len(self.filtered_indices) + self.page_size - 1) // self.page_size)
+
+        # Go to the last page to show the new item
+        self.current_page = self.total_pages - 1
+        self.page_label.setText(f"Page {self.current_page + 1} of {self.total_pages}")
+
+        # Reload the page
+        self.load_page()
+
+    def remove_selected(self):
+        """Remove selected tokens from all arrays."""
+        selected_rows = []
+        for item in self.tokens_table.selectedItems():
+            row = item.row()
+            if row not in selected_rows:
+                selected_rows.append(row)
+
+        if not selected_rows:
+            return
+
+        # Get original indices in descending order to avoid index shifting
+        orig_indices = []
+        for row in selected_rows:
+            orig_item = self.tokens_table.item(row, 0)
+            if orig_item:
+                orig_indices.append(orig_item.data(Qt.ItemDataRole.UserRole))
+        orig_indices.sort(reverse=True)
+
+        # Remove from all arrays
+        for idx in orig_indices:
+            if idx < len(self.tokens):
+                del self.tokens[idx]
+            if idx < len(self.token_types):
+                del self.token_types[idx]
+            if idx < len(self.scores):
+                del self.scores[idx]
+
+        # Rebuild filtered_indices
+        self.filtered_indices = []
+        filter_text = self.filter_edit.text().lower()
+
+        for i, token in enumerate(self.tokens):
+            if not filter_text or filter_text in str(token).lower():
+                self.filtered_indices.append(i)
+
+        # Update pagination
+        self.total_pages = max(1, (len(self.filtered_indices) + self.page_size - 1) // self.page_size)
+        self.current_page = min(self.current_page, self.total_pages - 1)
+        self.page_label.setText(f"Page {self.current_page + 1} of {self.total_pages}")
+
+        # Reload the page
+        self.load_page()
+
+    def get_data(self):
+        """Return the edited tokenizer data."""
+        return self.tokens, self.token_types, self.scores
+
+
+class ArrayEditorDialog(QDialog):
+    def __init__(self, array_values, element_type, key=None, parent=None):
+        super().__init__(parent)
+        self.setWindowTitle("Edit Array Values")
+        self.resize(700, 500)
+
+        self.array_values = array_values
+        self.element_type = element_type
+        self.key = key
+
+        # Get enum type for this array if applicable
+        self.enum_type = None
+        if key in KEY_TO_ENUM_TYPE and element_type == GGUFValueType.INT32:
+            self.enum_type = KEY_TO_ENUM_TYPE[key]
+
+        layout = QVBoxLayout(self)
+
+        # Add enum type information if applicable
+        if self.enum_type is not None:
+            enum_info_layout = QHBoxLayout()
+            enum_label = QLabel(f"Editing {self.enum_type.__name__} values:")
+            enum_info_layout.addWidget(enum_label)
+
+            # Add a legend for the enum values
+            enum_values = ", ".join([f"{e.name}={e.value}" for e in self.enum_type])
+            enum_values_label = QLabel(f"Available values: {enum_values}")
+            enum_values_label.setWordWrap(True)
+            enum_info_layout.addWidget(enum_values_label, 1)
+
+            layout.addLayout(enum_info_layout)
+
+        # Add search/filter controls
+        filter_layout = QHBoxLayout()
+        filter_layout.addWidget(QLabel("Filter:"))
+        self.filter_edit = QLineEdit()
+        self.filter_edit.setPlaceholderText("Type to filter values...")
+        self.filter_edit.textChanged.connect(self.apply_filter)
+        filter_layout.addWidget(self.filter_edit)
+
+        # Add page controls for large arrays
+        self.page_size = 100  # Show 100 items per page
+        self.current_page = 0
+        self.total_pages = max(1, (len(array_values) + self.page_size - 1) // self.page_size)
+
+        self.page_label = QLabel(f"Page 1 of {self.total_pages}")
+        filter_layout.addWidget(self.page_label)
+
+        prev_page = QPushButton("Previous")
+        prev_page.clicked.connect(self.previous_page)
+        filter_layout.addWidget(prev_page)
+
+        next_page = QPushButton("Next")
+        next_page.clicked.connect(self.next_page)
+        filter_layout.addWidget(next_page)
+
+        layout.addLayout(filter_layout)
+
+        # Array items table
+        self.items_table = QTableWidget()
+
+        # Set up columns based on whether we have an enum type
+        if self.enum_type is not None:
+            self.items_table.setColumnCount(3)
+            self.items_table.setHorizontalHeaderLabels(["Index", "Value", "Actions"])
+            self.items_table.horizontalHeader().setSectionResizeMode(0, QHeaderView.ResizeMode.ResizeToContents)
+            self.items_table.horizontalHeader().setSectionResizeMode(1, QHeaderView.ResizeMode.Stretch)
+            self.items_table.horizontalHeader().setSectionResizeMode(2, QHeaderView.ResizeMode.ResizeToContents)
+        else:
+            self.items_table.setColumnCount(2)
+            self.items_table.setHorizontalHeaderLabels(["Index", "Value"])
+            self.items_table.horizontalHeader().setSectionResizeMode(0, QHeaderView.ResizeMode.ResizeToContents)
+            self.items_table.horizontalHeader().setSectionResizeMode(1, QHeaderView.ResizeMode.Stretch)
+
+        layout.addWidget(self.items_table)
+
+        # Controls
+        controls_layout = QHBoxLayout()
+
+        add_button = QPushButton("Add Item")
+        add_button.clicked.connect(self.add_item)
+        controls_layout.addWidget(add_button)
+
+        remove_button = QPushButton("Remove Selected")
+        remove_button.clicked.connect(self.remove_selected)
+        controls_layout.addWidget(remove_button)
+
+        # Add bulk edit button for enum arrays
+        if self.enum_type is not None:
+            bulk_edit_button = QPushButton("Bulk Edit Selected")
+            bulk_edit_button.clicked.connect(self.bulk_edit_selected)
+            controls_layout.addWidget(bulk_edit_button)
+
+        controls_layout.addStretch()
+
+        layout.addLayout(controls_layout)
+
+        # Buttons
+        buttons = QDialogButtonBox(QDialogButtonBox.StandardButton.Ok | QDialogButtonBox.StandardButton.Cancel)
+        buttons.accepted.connect(self.accept)
+        buttons.rejected.connect(self.reject)
+        layout.addWidget(buttons)
+
+        # Initialize the filtered values
+        self.filtered_indices = list(range(len(self.array_values)))
+
+        # Load array values for the first page
+        self.load_page()
+
+    def apply_filter(self):
+        """Filter the array values based on the search text."""
+        filter_text = self.filter_edit.text().lower()
+
+        if not filter_text:
+            # No filter, show all values
+            self.filtered_indices = list(range(len(self.array_values)))
+        else:
+            # Apply filter
+            self.filtered_indices = []
+            for i, value in enumerate(self.array_values):
+                # For enum values, search in both name and value
+                if self.enum_type is not None and isinstance(value, int):
+                    try:
+                        enum_val = self.enum_type(value)
+                        display_text = f"{enum_val.name} ({value})".lower()
+                        if filter_text in display_text:
+                            self.filtered_indices.append(i)
+                    except (ValueError, KeyError):
+                        # If not a valid enum value, just check the raw value
+                        if filter_text in str(value).lower():
+                            self.filtered_indices.append(i)
+                else:
+                    # For non-enum values, just check the string representation
+                    if filter_text in str(value).lower():
+                        self.filtered_indices.append(i)
+
+        # Reset to first page and reload
+        self.total_pages = max(1, (len(self.filtered_indices) + self.page_size - 1) // self.page_size)
+        self.current_page = 0
+        self.page_label.setText(f"Page 1 of {self.total_pages}")
+        self.load_page()
+
+    def previous_page(self):
+        """Go to the previous page of results."""
+        if self.current_page > 0:
+            self.current_page -= 1
+            self.page_label.setText(f"Page {self.current_page + 1} of {self.total_pages}")
+            self.load_page()
+
+    def next_page(self):
+        """Go to the next page of results."""
+        if self.current_page < self.total_pages - 1:
+            self.current_page += 1
+            self.page_label.setText(f"Page {self.current_page + 1} of {self.total_pages}")
+            self.load_page()
+
+    def load_page(self):
+        """Load the current page of array values."""
+        self.items_table.setRowCount(0)  # Clear the table
+
+        # Calculate start and end indices for the current page
+        start_idx = self.current_page * self.page_size
+        end_idx = min(start_idx + self.page_size, len(self.filtered_indices))
+
+        # Pre-allocate rows for better performance
+        self.items_table.setRowCount(end_idx - start_idx)
+
+        for row, i in enumerate(range(start_idx, end_idx)):
+            orig_idx = self.filtered_indices[i]
+            value = self.array_values[orig_idx]
+
+            # Index
+            index_item = QTableWidgetItem(str(orig_idx))
+            index_item.setData(Qt.ItemDataRole.UserRole, orig_idx)  # Store original index
+            index_item.setFlags(index_item.flags() & ~Qt.ItemFlag.ItemIsEditable)
+            self.items_table.setItem(row, 0, index_item)
+
+            # Value
+            if self.enum_type is not None:
+                # Display enum value and name
+                try:
+                    if isinstance(value, (int, numpy.signedinteger)):
+                        enum_val = self.enum_type(value)
+                        display_text = f"{enum_val.name} ({value})"
+                    else:
+                        display_text = str(value)
+                except (ValueError, KeyError):
+                    display_text = f"Unknown ({value})"
+
+                # Store the enum value in the item
+                value_item = QTableWidgetItem(display_text)
+                value_item.setData(Qt.ItemDataRole.UserRole, value)
+                value_item.setFlags(value_item.flags() & ~Qt.ItemFlag.ItemIsEditable)
+                self.items_table.setItem(row, 1, value_item)
+
+                # Add an edit button in a separate column
+                edit_button = QPushButton("Edit")
+                edit_button.setProperty("row", row)
+                edit_button.clicked.connect(self.edit_array_enum_value)
+
+                # Create a widget to hold the button
+                button_widget = QWidget()
+                button_layout = QHBoxLayout(button_widget)
+                button_layout.setContentsMargins(2, 2, 2, 2)
+                button_layout.addWidget(edit_button)
+                button_layout.addStretch()
+
+                self.items_table.setCellWidget(row, 2, button_widget)
+            else:
+                value_item = QTableWidgetItem(str(value))
+                self.items_table.setItem(row, 1, value_item)
+
+    def edit_array_enum_value(self):
+        """Handle editing an enum value in the array editor."""
+        button = self.sender()
+        row = button.property("row")
+
+        # Get the original index from the table item
+        orig_item = self.items_table.item(row, 0)
+        new_item = self.items_table.item(row, 1)
+        if orig_item and new_item and self.enum_type and self.edit_enum_value(row, self.enum_type):
+            orig_idx = orig_item.data(Qt.ItemDataRole.UserRole)
+            new_value = new_item.data(Qt.ItemDataRole.UserRole)
+            # Update the stored value in the array
+            if isinstance(new_value, (int, float, str, bool)):
+                self.array_values[orig_idx] = new_value
+
+    def bulk_edit_selected(self):
+        """Edit multiple enum values at once."""
+        if not self.enum_type:
+            return
+
+        selected_rows = set()
+        for item in self.items_table.selectedItems():
+            selected_rows.add(item.row())
+
+        if not selected_rows:
+            QMessageBox.information(self, "No Selection", "Please select at least one row to edit.")
+            return
+
+        # Create a dialog with enum options
+        dialog = QDialog(self)
+        dialog.setWindowTitle(f"Bulk Edit {self.enum_type.__name__} Values")
+        layout = QVBoxLayout(dialog)
+
+        layout.addWidget(QLabel(f"Set {len(selected_rows)} selected items to:"))
+
+        combo = QComboBox()
+        for enum_val in self.enum_type:
+            combo.addItem(f"{enum_val.name} ({enum_val.value})", enum_val.value)
+
+        layout.addWidget(combo)
+
+        buttons = QDialogButtonBox(QDialogButtonBox.StandardButton.Ok | QDialogButtonBox.StandardButton.Cancel)
+        buttons.accepted.connect(dialog.accept)
+        buttons.rejected.connect(dialog.reject)
+        layout.addWidget(buttons)
+
+        if dialog.exec() == QDialog.DialogCode.Accepted:
+            # Get the selected value
+            new_value = combo.currentData()
+            enum_val = self.enum_type(new_value)
+            display_text = f"{enum_val.name} ({new_value})"
+
+            # Update all selected rows
+            for row in selected_rows:
+                orig_item = self.items_table.item(row, 0)
+                new_item = self.items_table.item(row, 1)
+                if orig_item and new_item:
+                    orig_idx = orig_item.data(Qt.ItemDataRole.UserRole)
+                    self.array_values[orig_idx] = new_value
+
+                    # Update the display
+                    new_item.setText(display_text)
+                    new_item.setData(Qt.ItemDataRole.UserRole, new_value)
+
+    def add_item(self):
+        # Add to the end of the array
+        orig_idx = len(self.array_values)
+
+        # Add default value based on type
+        if self.enum_type is not None:
+            # Default to first enum value
+            default_value = list(self.enum_type)[0].value
+            self.array_values.append(default_value)
+        else:
+            if self.element_type == GGUFValueType.STRING:
+                self.array_values.append("")
+            else:
+                self.array_values.append(0)
+
+        # Add to filtered indices if it matches the current filter
+        self.filtered_indices.append(orig_idx)
+
+        # Update pagination
+        self.total_pages = max(1, (len(self.filtered_indices) + self.page_size - 1) // self.page_size)
+
+        # Go to the last page to show the new item
+        self.current_page = self.total_pages - 1
+        self.page_label.setText(f"Page {self.current_page + 1} of {self.total_pages}")
+
+        # Reload the page
+        self.load_page()
+
+    def remove_selected(self):
+        selected_rows = []
+        for item in self.items_table.selectedItems():
+            row = item.row()
+            if row not in selected_rows:
+                selected_rows.append(row)
+
+        if not selected_rows:
+            return
+
+        # Get original indices in descending order to avoid index shifting
+        orig_indices = []
+        for row in selected_rows:
+            orig_item = self.items_table.item(row, 0)
+            if orig_item:
+                orig_indices.append(orig_item.data(Qt.ItemDataRole.UserRole))
+        orig_indices.sort(reverse=True)
+
+        # Remove from array_values
+        for idx in orig_indices:
+            del self.array_values[idx]
+
+        # Rebuild filtered_indices
+        self.filtered_indices = []
+        filter_text = self.filter_edit.text().lower()
+
+        for i, value in enumerate(self.array_values):
+            if not filter_text:
+                self.filtered_indices.append(i)
+            else:
+                # Apply filter
+                if self.enum_type is not None and isinstance(value, int):
+                    try:
+                        enum_val = self.enum_type(value)
+                        display_text = f"{enum_val.name} ({value})".lower()
+                        if filter_text in display_text:
+                            self.filtered_indices.append(i)
+                    except (ValueError, KeyError):
+                        if filter_text in str(value).lower():
+                            self.filtered_indices.append(i)
+                else:
+                    if filter_text in str(value).lower():
+                        self.filtered_indices.append(i)
+
+        # Update pagination
+        self.total_pages = max(1, (len(self.filtered_indices) + self.page_size - 1) // self.page_size)
+        self.current_page = min(self.current_page, self.total_pages - 1)
+        self.page_label.setText(f"Page {self.current_page + 1} of {self.total_pages}")
+
+        # Reload the page
+        self.load_page()
+
+    def edit_enum_value(self, row: int, enum_type: Type[enum.Enum]):
+        """Edit an enum value using a dialog with a dropdown of all enum options."""
+        # Get the original index from the table item
+        orig_item = self.items_table.item(row, 0)
+        if orig_item:
+            orig_idx = orig_item.data(Qt.ItemDataRole.UserRole)
+        else:
+            return
+        current_value = self.array_values[orig_idx]
+
+        # Create a dialog with enum options
+        dialog = QDialog(self)
+        dialog.setWindowTitle(f"Select {enum_type.__name__} Value")
+        layout = QVBoxLayout(dialog)
+
+        # Add description
+        description = QLabel(f"Select a {enum_type.__name__} value:")
+        layout.addWidget(description)
+
+        # Use a combo box for quick selection
+        combo = QComboBox()
+        for enum_val in enum_type:
+            combo.addItem(f"{enum_val.name} ({enum_val.value})", enum_val.value)
+
+        # Set current value
+        try:
+            if isinstance(current_value, int):
+                enum_val = enum_type(current_value)
+                combo.setCurrentText(f"{enum_val.name} ({current_value})")
+        except (ValueError, KeyError):
+            pass
+
+        layout.addWidget(combo)
+
+        buttons = QDialogButtonBox(QDialogButtonBox.StandardButton.Ok | QDialogButtonBox.StandardButton.Cancel)
+        buttons.accepted.connect(dialog.accept)
+        buttons.rejected.connect(dialog.reject)
+        layout.addWidget(buttons)
+
+        if dialog.exec() == QDialog.DialogCode.Accepted:
+            # Update the value display and stored data
+            new_value = combo.currentData()
+            enum_val = enum_type(new_value)
+            display_text = f"{enum_val.name} ({new_value})"
+
+            new_item = self.items_table.item(row, 1)
+            if new_item:
+                new_item.setText(display_text)
+                new_item.setData(Qt.ItemDataRole.UserRole, new_value)
+
+            # Update the actual array value
+            self.array_values[orig_idx] = new_value
+            return True
+        return False
+
+    def get_array_values(self):
+        # The array_values list is kept up-to-date as edits are made
+        return self.array_values
+
+
+class AddMetadataDialog(QDialog):
+    def __init__(self, parent=None):
+        super().__init__(parent)
+        self.setWindowTitle("Add Metadata")
+        self.resize(400, 200)
+
+        layout = QVBoxLayout(self)
+
+        form_layout = QFormLayout()
+
+        self.key_edit = QLineEdit()
+        form_layout.addRow("Key:", self.key_edit)
+
+        self.type_combo = QComboBox()
+        for value_type in GGUFValueType:
+            if value_type != GGUFValueType.ARRAY:  # Skip array type for simplicity
+                self.type_combo.addItem(value_type.name, value_type)
+        form_layout.addRow("Type:", self.type_combo)
+
+        self.value_edit = QTextEdit()
+        form_layout.addRow("Value:", self.value_edit)
+
+        layout.addLayout(form_layout)
+
+        buttons = QDialogButtonBox(QDialogButtonBox.StandardButton.Ok | QDialogButtonBox.StandardButton.Cancel)
+        buttons.accepted.connect(self.accept)
+        buttons.rejected.connect(self.reject)
+        layout.addWidget(buttons)
+
+    def get_data(self) -> Tuple[str, GGUFValueType, Any]:
+        key = self.key_edit.text()
+        value_type = self.type_combo.currentData()
+        value_text = self.value_edit.toPlainText()
+
+        # Convert value based on type
+        if value_type == GGUFValueType.UINT8:
+            value = np.uint8(int(value_text))
+        elif value_type == GGUFValueType.INT8:
+            value = np.int8(int(value_text))
+        elif value_type == GGUFValueType.UINT16:
+            value = np.uint16(int(value_text))
+        elif value_type == GGUFValueType.INT16:
+            value = np.int16(int(value_text))
+        elif value_type == GGUFValueType.UINT32:
+            value = np.uint32(int(value_text))
+        elif value_type == GGUFValueType.INT32:
+            value = np.int32(int(value_text))
+        elif value_type == GGUFValueType.FLOAT32:
+            value = np.float32(float(value_text))
+        elif value_type == GGUFValueType.BOOL:
+            value = value_text.lower() in ('true', 'yes', '1')
+        elif value_type == GGUFValueType.STRING:
+            value = value_text
+        else:
+            value = value_text
+
+        return key, value_type, value
+
+
+class GGUFEditorWindow(QMainWindow):
+    def __init__(self):
+        super().__init__()
+
+        self.setWindowTitle("GGUF Editor")
+        self.resize(1000, 800)
+
+        self.current_file = None
+        self.reader = None
+        self.modified = False
+        self.metadata_changes = {}  # Store changes to apply when saving
+        self.metadata_to_remove = set()  # Store keys to remove when saving
+
+        self.setup_ui()
+
+    def setup_ui(self):
+        central_widget = QWidget()
+        self.setCentralWidget(central_widget)
+
+        main_layout = QVBoxLayout(central_widget)
+
+        # File controls
+        file_layout = QHBoxLayout()
+
+        self.file_path_edit = QLineEdit()
+        self.file_path_edit.setReadOnly(True)
+        file_layout.addWidget(self.file_path_edit)
+
+        open_button = QPushButton("Open GGUF")
+        open_button.clicked.connect(self.open_file)
+        file_layout.addWidget(open_button)
+
+        save_button = QPushButton("Save As...")
+        save_button.clicked.connect(self.save_file)
+        file_layout.addWidget(save_button)
+
+        main_layout.addLayout(file_layout)
+
+        # Tabs for different views
+        self.tabs = QTabWidget()
+
+        # Metadata tab
+        self.metadata_tab = QWidget()
+        metadata_layout = QVBoxLayout(self.metadata_tab)
+
+        # Metadata table
+        self.metadata_table = QTableWidget()
+        self.metadata_table.setColumnCount(4)
+        self.metadata_table.setHorizontalHeaderLabels(["Key", "Type", "Value", "Actions"])
+        self.metadata_table.horizontalHeader().setSectionResizeMode(0, QHeaderView.ResizeMode.Stretch)
+        self.metadata_table.horizontalHeader().setSectionResizeMode(1, QHeaderView.ResizeMode.ResizeToContents)
+        self.metadata_table.horizontalHeader().setSectionResizeMode(2, QHeaderView.ResizeMode.Stretch)
+        self.metadata_table.horizontalHeader().setSectionResizeMode(3, QHeaderView.ResizeMode.ResizeToContents)
+        metadata_layout.addWidget(self.metadata_table)
+
+        # Metadata controls
+        metadata_controls = QHBoxLayout()
+
+        add_metadata_button = QPushButton("Add Metadata")
+        add_metadata_button.clicked.connect(self.add_metadata)
+        metadata_controls.addWidget(add_metadata_button)
+
+        metadata_controls.addStretch()
+
+        metadata_layout.addLayout(metadata_controls)
+
+        # Tensors tab
+        self.tensors_tab = QWidget()
+        tensors_layout = QVBoxLayout(self.tensors_tab)
+
+        self.tensors_table = QTableWidget()
+        self.tensors_table.setColumnCount(5)
+        self.tensors_table.setHorizontalHeaderLabels(["Name", "Type", "Shape", "Elements", "Size (bytes)"])
+        self.tensors_table.horizontalHeader().setSectionResizeMode(0, QHeaderView.ResizeMode.Stretch)
+        self.tensors_table.horizontalHeader().setSectionResizeMode(1, QHeaderView.ResizeMode.ResizeToContents)
+        self.tensors_table.horizontalHeader().setSectionResizeMode(2, QHeaderView.ResizeMode.ResizeToContents)
+        self.tensors_table.horizontalHeader().setSectionResizeMode(3, QHeaderView.ResizeMode.ResizeToContents)
+        self.tensors_table.horizontalHeader().setSectionResizeMode(4, QHeaderView.ResizeMode.ResizeToContents)
+        tensors_layout.addWidget(self.tensors_table)
+
+        # Add tabs to tab widget
+        self.tabs.addTab(self.metadata_tab, "Metadata")
+        self.tabs.addTab(self.tensors_tab, "Tensors")
+
+        main_layout.addWidget(self.tabs)
+
+        # Status bar
+        self.statusBar().showMessage("Ready")
+
+    def load_file(self, file_path):
+        """Load a GGUF file by path"""
+        try:
+            self.statusBar().showMessage(f"Loading {file_path}...")
+            QApplication.processEvents()
+
+            self.reader = GGUFReader(file_path, 'r')
+            self.current_file = file_path
+            self.file_path_edit.setText(file_path)
+
+            self.load_metadata()
+            self.load_tensors()
+
+            self.metadata_changes = {}
+            self.metadata_to_remove = set()
+            self.modified = False
+
+            self.statusBar().showMessage(f"Loaded {file_path}")
+            return True
+        except Exception as e:
+            QMessageBox.critical(self, "Error", f"Failed to open file: {str(e)}")
+            self.statusBar().showMessage("Error loading file")
+            return False
+
+    def open_file(self):
+        file_path, _ = QFileDialog.getOpenFileName(
+            self, "Open GGUF File", "", "GGUF Files (*.gguf);;All Files (*)"
+        )
+
+        if not file_path:
+            return
+
+        self.load_file(file_path)
+
+    def load_metadata(self):
+        self.metadata_table.setRowCount(0)
+
+        if not self.reader:
+            return
+
+        # Disconnect to prevent triggering during loading
+        with warnings.catch_warnings():
+            warnings.filterwarnings('ignore')
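+            # On the first load nothing is connected yet; the catch_warnings guard
+            # keeps the unconditional disconnect below from complaining in that case.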
+            self.metadata_table.itemChanged.disconnect(self.on_metadata_changed)
+
+        for i, (key, field) in enumerate(self.reader.fields.items()):
+            self.metadata_table.insertRow(i)
+
+            # Key
+            key_item = QTableWidgetItem(key)
+            key_item.setFlags(key_item.flags() & ~Qt.ItemFlag.ItemIsEditable)
+            self.metadata_table.setItem(i, 0, key_item)
+
+            # Type
+            if not field.types:
+                type_str = "N/A"
+            elif field.types[0] == GGUFValueType.ARRAY:
+                nest_count = len(field.types) - 1
+                element_type = field.types[-1].name
+                # Check if this is an enum array
+                enum_type = self.get_enum_for_key(key)
+                if enum_type is not None and field.types[-1] == GGUFValueType.INT32:
+                    element_type = enum_type.__name__
+                type_str = '[' * nest_count + element_type + ']' * nest_count
+            else:
+                type_str = str(field.types[0].name)
+                # Check if this is an enum field
+                enum_type = self.get_enum_for_key(key)
+                if enum_type is not None and field.types[0] == GGUFValueType.INT32:
+                    type_str = enum_type.__name__
+
+            type_item = QTableWidgetItem(type_str)
+            type_item.setFlags(type_item.flags() & ~Qt.ItemFlag.ItemIsEditable)
+            self.metadata_table.setItem(i, 1, type_item)
+
+            # Value
+            value_str = self.format_field_value(field)
+            value_item = QTableWidgetItem(value_str)
+
+            # Make only simple values editable
+            if len(field.types) == 1 and field.types[0] != GGUFValueType.ARRAY:
+                value_item.setFlags(value_item.flags() | Qt.ItemFlag.ItemIsEditable)
+            else:
+                value_item.setFlags(value_item.flags() & ~Qt.ItemFlag.ItemIsEditable)
+
+            self.metadata_table.setItem(i, 2, value_item)
+
+            # Actions
+            actions_widget = QWidget()
+            actions_layout = QHBoxLayout(actions_widget)
+            actions_layout.setContentsMargins(2, 2, 2, 2)
+
+            # Add Edit button for arrays and enum fields
+            if field.types and field.types[0] == GGUFValueType.ARRAY:
+                edit_button = QPushButton("Edit")
+                edit_button.setProperty("row", i)
+                edit_button.setProperty("key", key)
+                edit_button.clicked.connect(self.edit_array_metadata)
+                actions_layout.addWidget(edit_button)
+
+                # Add special label for tokenizer linked fields
+                if key in TOKENIZER_LINKED_KEYS:
+                    edit_button.setText("Edit Tokenizer")
+                    edit_button.setToolTip("Edit all tokenizer data together")
+            elif len(field.types) == 1 and self.get_enum_for_key(key) is not None:
+                edit_button = QPushButton("Edit")
+                edit_button.setProperty("row", i)
+                edit_button.setProperty("key", key)
+                edit_button.clicked.connect(self.edit_metadata_enum)
+                actions_layout.addWidget(edit_button)
+
+            remove_button = QPushButton("Remove")
+            remove_button.setProperty("row", i)
+            remove_button.setProperty("key", key)
+            remove_button.clicked.connect(self.remove_metadata)
+            actions_layout.addWidget(remove_button)
+
+            self.metadata_table.setCellWidget(i, 3, actions_widget)
+
+        # Reconnect after loading
+        self.metadata_table.itemChanged.connect(self.on_metadata_changed)
+
+    def extract_array_values(self, field: ReaderField) -> list:
+        """Extract all values from an array field."""
+        if not field.types or field.types[0] != GGUFValueType.ARRAY:
+            return []
+
+        curr_type = field.types[1]
+        array_values = []
+        total_elements = len(field.data)
+
+        if curr_type == GGUFValueType.STRING:
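+            # GGUFReader stores each string element as two parts (a length followed
+            # by the raw bytes), hence the stride of 2 counted back from the end of
+            # field.parts; the final part belongs to the last element.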
+            for element_pos in range(total_elements):
+                value_string = str(bytes(field.parts[-1 - (total_elements - element_pos - 1) * 2]), encoding='utf-8')
+                array_values.append(value_string)
+        elif self.reader and curr_type in self.reader.gguf_scalar_to_np:
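+            # Scalar elements occupy a single part each, so the stride here is 1.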
+            for element_pos in range(total_elements):
+                array_values.append(field.parts[-1 - (total_elements - element_pos - 1)][0])
+
+        return array_values
+
+    def get_enum_for_key(self, key: str) -> Optional[Type[enum.Enum]]:
+        """Get the enum type for a given key if it exists."""
+        return KEY_TO_ENUM_TYPE.get(key)
+
+    def format_enum_value(self, value: Any, enum_type: Type[enum.Enum]) -> str:
+        """Format a value as an enum if possible."""
+        try:
+            if isinstance(value, (int, str)):
+                enum_value = enum_type(value)
+                return f"{enum_value.name} ({value})"
+        except (ValueError, KeyError):
+            pass
+        return str(value)
+
+    def format_field_value(self, field: ReaderField) -> str:
+        if not field.types:
+            return "N/A"
+
+        if len(field.types) == 1:
+            curr_type = field.types[0]
+            if curr_type == GGUFValueType.STRING:
+                return str(bytes(field.parts[-1]), encoding='utf-8')
+            elif self.reader and curr_type in self.reader.gguf_scalar_to_np:
+                value = field.parts[-1][0]
+                # Check if this field has an enum type
+                enum_type = self.get_enum_for_key(field.name)
+                if enum_type is not None:
+                    return self.format_enum_value(value, enum_type)
+                return str(value)
+
+        if field.types[0] == GGUFValueType.ARRAY:
+            array_values = self.extract_array_values(field)
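+            # Show at most the first five elements; longer arrays get a ", ..." suffix.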
+            render_element = min(5, len(array_values))
+
+            # Get enum type for this array if applicable
+            enum_type = self.get_enum_for_key(field.name)
+
+            if enum_type is not None:
+                array_elements = []
+                for i in range(render_element):
+                    array_elements.append(self.format_enum_value(array_values[i], enum_type))
+            else:
+                array_elements = [str(array_values[i]) for i in range(render_element)]
+
+            return f"[ {', '.join(array_elements).strip()}{', ...' if len(array_values) > len(array_elements) else ''} ]"
+
+        return "Complex value"
+
+    def load_tensors(self):
+        self.tensors_table.setRowCount(0)
+
+        if not self.reader:
+            return
+
+        for i, tensor in enumerate(self.reader.tensors):
+            self.tensors_table.insertRow(i)
+
+            # Name
+            name_item = QTableWidgetItem(tensor.name)
+            name_item.setFlags(name_item.flags() & ~Qt.ItemFlag.ItemIsEditable)
+            self.tensors_table.setItem(i, 0, name_item)
+
+            # Type
+            type_item = QTableWidgetItem(tensor.tensor_type.name)
+            type_item.setFlags(type_item.flags() & ~Qt.ItemFlag.ItemIsEditable)
+            self.tensors_table.setItem(i, 1, type_item)
+
+            # Shape
+            shape_str = " × ".join(str(d) for d in tensor.shape)
+            shape_item = QTableWidgetItem(shape_str)
+            shape_item.setFlags(shape_item.flags() & ~Qt.ItemFlag.ItemIsEditable)
+            self.tensors_table.setItem(i, 2, shape_item)
+
+            # Elements
+            elements_item = QTableWidgetItem(str(tensor.n_elements))
+            elements_item.setFlags(elements_item.flags() & ~Qt.ItemFlag.ItemIsEditable)
+            self.tensors_table.setItem(i, 3, elements_item)
+
+            # Size
+            size_item = QTableWidgetItem(f"{tensor.n_bytes:,}")
+            size_item.setFlags(size_item.flags() & ~Qt.ItemFlag.ItemIsEditable)
+            self.tensors_table.setItem(i, 4, size_item)
+
+    def on_metadata_changed(self, item):
+        if item.column() != 2:  # Only handle value column changes
+            return
+
+        row = item.row()
+        orig_item = self.metadata_table.item(row, 0)
+        key = None
+        if orig_item:
+            key = orig_item.text()
+        new_value = item.text()
+
+        field = None
+        if self.reader and key:
+            field = self.reader.get_field(key)
+        if not field or not field.types or not key:
+            return
+
+        value_type = field.types[0]
+
+        # Check if this is an enum field
+        enum_type = self.get_enum_for_key(key)
+        if enum_type is not None and value_type == GGUFValueType.INT32:
+            # Try to parse the enum value from the text
+            try:
+                # Check if it's a name
+                try:
+                    enum_val = enum_type[new_value]
+                    converted_value = enum_val.value
+                except (KeyError, AttributeError):
+                    # Check if it's a number or "NAME (value)" format
+                    if '(' in new_value and ')' in new_value:
+                        # Extract the value from "NAME (value)" format
+                        value_part = new_value.split('(')[1].split(')')[0].strip()
+                        converted_value = int(value_part)
+                    else:
+                        # Try to convert directly to int
+                        converted_value = int(new_value)
+
+                # Validate that it's a valid enum value
+                enum_type(converted_value)
+
+                # Store the change
+                self.metadata_changes[key] = (value_type, converted_value)
+                self.modified = True
+
+                # Update display with formatted enum value
+                formatted_value = self.format_enum_value(converted_value, enum_type)
+                item.setText(formatted_value)
+
+                self.statusBar().showMessage(f"Changed {key} to {formatted_value}")
+                return
+            except (ValueError, KeyError) as e:
+                QMessageBox.warning(
+                    self,
+                    f"Invalid Enum Value ({e})",
+                    f"'{new_value}' is not a valid {enum_type.__name__} value.\n"
+                    f"Valid values are: {', '.join(v.name for v in enum_type)}")
+
+                # Revert to original value
+                original_value = self.format_field_value(field)
+                item.setText(original_value)
+                return
+
+        try:
+            # Convert the string value to the appropriate type
+            if value_type == GGUFValueType.UINT8:
+                converted_value = np.uint8(int(new_value))
+            elif value_type == GGUFValueType.INT8:
+                converted_value = np.int8(int(new_value))
+            elif value_type == GGUFValueType.UINT16:
+                converted_value = np.uint16(int(new_value))
+            elif value_type == GGUFValueType.INT16:
+                converted_value = np.int16(int(new_value))
+            elif value_type == GGUFValueType.UINT32:
+                converted_value = np.uint32(int(new_value))
+            elif value_type == GGUFValueType.INT32:
+                converted_value = np.int32(int(new_value))
+            elif value_type == GGUFValueType.FLOAT32:
+                converted_value = np.float32(float(new_value))
+            elif value_type == GGUFValueType.BOOL:
+                converted_value = new_value.lower() in ('true', 'yes', '1')
+            elif value_type == GGUFValueType.STRING:
+                converted_value = new_value
+            else:
+                # Unsupported type for editing
+                return
+
+            # Store the change
+            self.metadata_changes[key] = (value_type, converted_value)
+            self.modified = True
+
+            self.statusBar().showMessage(f"Changed {key} to {new_value}")
+        except ValueError:
+            QMessageBox.warning(self, "Invalid Value", f"The value '{new_value}' is not valid for type {value_type.name}")
+
+            # Revert to original value
+            original_value = self.format_field_value(field)
+            item.setText(original_value)
+
+    def remove_metadata(self):
+        button = self.sender()
+        key = button.property("key")
+        row = button.property("row")
+
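+        # Removal is deferred: the key is only dropped from the file when it is saved (see save_file)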
+        reply = QMessageBox.question(
+            self, "Confirm Removal",
+            f"Are you sure you want to remove the metadata key '{key}'?",
+            QMessageBox.StandardButton.Yes | QMessageBox.StandardButton.No, QMessageBox.StandardButton.No
+        )
+
+        if reply == QMessageBox.StandardButton.Yes:
+            self.metadata_table.removeRow(row)
+            self.metadata_to_remove.add(key)
+
+            # If we previously had changes for this key, remove them
+            if key in self.metadata_changes:
+                del self.metadata_changes[key]
+
+            self.modified = True
+            self.statusBar().showMessage(f"Marked {key} for removal")
+
+    def edit_metadata_enum(self):
+        """Edit an enum metadata field."""
+        button = self.sender()
+        key = button.property("key")
+        row = button.property("row")
+
+        field = None
+        if self.reader:
+            field = self.reader.get_field(key)
+        if not field or not field.types:
+            return
+
+        enum_type = self.get_enum_for_key(key)
+        if enum_type is None:
+            return
+
+        # Get current value
+        current_value = field.contents()
+
+        # Create a dialog with enum options
+        dialog = QDialog(self)
+        dialog.setWindowTitle(f"Select {enum_type.__name__} Value")
+        layout = QVBoxLayout(dialog)
+
+        combo = QComboBox()
+        for enum_val in enum_type:
+            combo.addItem(f"{enum_val.name} ({enum_val.value})", enum_val.value)
+
+        # Set current value
+        try:
+            if isinstance(current_value, (int, str)):
+                enum_val = enum_type(current_value)
+                combo.setCurrentText(f"{enum_val.name} ({current_value})")
+        except (ValueError, KeyError):
+            pass
+
+        layout.addWidget(combo)
+
+        buttons = QDialogButtonBox(QDialogButtonBox.StandardButton.Ok | QDialogButtonBox.StandardButton.Cancel)
+        buttons.accepted.connect(dialog.accept)
+        buttons.rejected.connect(dialog.reject)
+        layout.addWidget(buttons)
+
+        if dialog.exec() == QDialog.DialogCode.Accepted:
+            # Get the selected value
+            new_value = combo.currentData()
+            enum_val = enum_type(new_value)
+
+            # Store the change
+            self.metadata_changes[key] = (field.types[0], new_value)
+            self.modified = True
+
+            # Update display
+            display_text = f"{enum_val.name} ({new_value})"
+            target_item = self.metadata_table.item(row, 2)
+            if target_item:
+                target_item.setText(display_text)
+
+            self.statusBar().showMessage(f"Changed {key} to {display_text}")
+
+    def edit_array_metadata(self):
+        button = self.sender()
+        key = button.property("key")
+        row = button.property("row")
+
+        # Check if this is one of the linked tokenizer keys
+        if key in TOKENIZER_LINKED_KEYS:
+            self.edit_tokenizer_metadata(key)
+            return
+
+        field = None
+        if self.reader:
+            field = self.reader.get_field(key)
+        if not field or not field.types or field.types[0] != GGUFValueType.ARRAY:
+            return
+
+        # Get array element type
+        element_type = field.types[1]
+
+        # Extract array values
+        array_values = self.extract_array_values(field)
+
+        # Open array editor dialog
+        dialog = ArrayEditorDialog(array_values, element_type, key, self)
+        if dialog.exec() == QDialog.DialogCode.Accepted:
+            new_values = dialog.get_array_values()
+
+            # Store the change
+            self.metadata_changes[key] = (GGUFValueType.ARRAY, (element_type, new_values))
+            self.modified = True
+
+            # Update display
+            enum_type = self.get_enum_for_key(key)
+            if enum_type is not None and element_type == GGUFValueType.INT32:
+                value_str = f"[ {', '.join(self.format_enum_value(v, enum_type) for v in new_values[:5])}{', ...' if len(new_values) > 5 else ''} ]"
+            else:
+                value_str = f"[ {', '.join(str(v) for v in new_values[:5])}{', ...' if len(new_values) > 5 else ''} ]"
+            target_item = self.metadata_table.item(row, 2)
+            if target_item:
+                target_item.setText(value_str)
+
+            self.statusBar().showMessage(f"Updated array values for {key}")
+
+    def edit_tokenizer_metadata(self, trigger_key):
+        """Edit the linked tokenizer metadata arrays together."""
+        if not self.reader:
+            return
+
+        # Get all three fields
+        tokens_field = self.reader.get_field(gguf.Keys.Tokenizer.LIST)
+        token_types_field = self.reader.get_field(gguf.Keys.Tokenizer.TOKEN_TYPE)
+        scores_field = self.reader.get_field(gguf.Keys.Tokenizer.SCORES)
+
+        # Extract values from each field
+        tokens = self.extract_array_values(tokens_field) if tokens_field else []
+        token_types = self.extract_array_values(token_types_field) if token_types_field else []
+        scores = self.extract_array_values(scores_field) if scores_field else []
+
+        # Apply any pending changes
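+        # (array changes are stored as (GGUFValueType.ARRAY, (element_type, values)), hence the nested unpacking)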
+        if gguf.Keys.Tokenizer.LIST in self.metadata_changes:
+            _, (_, tokens) = self.metadata_changes[gguf.Keys.Tokenizer.LIST]
+        if gguf.Keys.Tokenizer.TOKEN_TYPE in self.metadata_changes:
+            _, (_, token_types) = self.metadata_changes[gguf.Keys.Tokenizer.TOKEN_TYPE]
+        if gguf.Keys.Tokenizer.SCORES in self.metadata_changes:
+            _, (_, scores) = self.metadata_changes[gguf.Keys.Tokenizer.SCORES]
+
+        # Open the tokenizer editor dialog
+        dialog = TokenizerEditorDialog(tokens, token_types, scores, self)
+        if dialog.exec() == QDialog.DialogCode.Accepted:
+            new_tokens, new_token_types, new_scores = dialog.get_data()
+
+            # Store changes for all three arrays
+            if tokens_field:
+                self.metadata_changes[gguf.Keys.Tokenizer.LIST] = (
+                    GGUFValueType.ARRAY,
+                    (tokens_field.types[1], new_tokens)
+                )
+
+            if token_types_field:
+                self.metadata_changes[gguf.Keys.Tokenizer.TOKEN_TYPE] = (
+                    GGUFValueType.ARRAY,
+                    (token_types_field.types[1], new_token_types)
+                )
+
+            if scores_field:
+                self.metadata_changes[gguf.Keys.Tokenizer.SCORES] = (
+                    GGUFValueType.ARRAY,
+                    (scores_field.types[1], new_scores)
+                )
+
+            self.modified = True
+
+            # Update display for all three fields
+            self.update_tokenizer_display(gguf.Keys.Tokenizer.LIST, new_tokens)
+            self.update_tokenizer_display(gguf.Keys.Tokenizer.TOKEN_TYPE, new_token_types)
+            self.update_tokenizer_display(gguf.Keys.Tokenizer.SCORES, new_scores)
+
+            self.statusBar().showMessage("Updated tokenizer data")
+
+    def update_tokenizer_display(self, key, values):
+        """Update the display of a tokenizer field in the metadata table."""
+        for row in range(self.metadata_table.rowCount()):
+            key_item = self.metadata_table.item(row, 0)
+            if key_item and key_item.text() == key:
+                value_str = f"[ {', '.join(str(v) for v in values[:5])}{', ...' if len(values) > 5 else ''} ]"
+                value_item = self.metadata_table.item(row, 2)
+                if value_item:
+                    value_item.setText(value_str)
+                break
+
+    def add_metadata(self):
+        dialog = AddMetadataDialog(self)
+        if dialog.exec() == QDialog.DialogCode.Accepted:
+            key, value_type, value = dialog.get_data()
+
+            if not key:
+                QMessageBox.warning(self, "Invalid Key", "Key cannot be empty")
+                return
+
+            # Check if key already exists
+            for row in range(self.metadata_table.rowCount()):
+                orig_item = self.metadata_table.item(row, 0)
+                if orig_item and orig_item.text() == key:
+                    QMessageBox.warning(self, "Duplicate Key", f"Key '{key}' already exists")
+                    return
+
+            # Add to table
+            row = self.metadata_table.rowCount()
+            self.metadata_table.insertRow(row)
+
+            # Key
+            key_item = QTableWidgetItem(key)
+            key_item.setFlags(key_item.flags() & ~Qt.ItemFlag.ItemIsEditable)
+            self.metadata_table.setItem(row, 0, key_item)
+
+            # Type
+            type_item = QTableWidgetItem(value_type.name)
+            type_item.setFlags(type_item.flags() & ~Qt.ItemFlag.ItemIsEditable)
+            self.metadata_table.setItem(row, 1, type_item)
+
+            # Value
+            value_item = QTableWidgetItem(str(value))
+            value_item.setFlags(value_item.flags() | Qt.ItemFlag.ItemIsEditable)
+            self.metadata_table.setItem(row, 2, value_item)
+
+            # Actions
+            actions_widget = QWidget()
+            actions_layout = QHBoxLayout(actions_widget)
+            actions_layout.setContentsMargins(2, 2, 2, 2)
+
+            remove_button = QPushButton("Remove")
+            remove_button.setProperty("row", row)
+            remove_button.setProperty("key", key)
+            remove_button.clicked.connect(self.remove_metadata)
+            actions_layout.addWidget(remove_button)
+
+            self.metadata_table.setCellWidget(row, 3, actions_widget)
+
+            # Store the change
+            self.metadata_changes[key] = (value_type, value)
+            self.modified = True
+
+            self.statusBar().showMessage(f"Added new metadata key {key}")
+
+    def save_file(self):
+        if not self.reader:
+            QMessageBox.warning(self, "No File Open", "Please open a GGUF file first")
+            return
+
+        if not self.modified and not self.metadata_changes and not self.metadata_to_remove:
+            QMessageBox.information(self, "No Changes", "No changes to save")
+            return
+
+        file_path, _ = QFileDialog.getSaveFileName(
+            self, "Save GGUF File As", "", "GGUF Files (*.gguf);;All Files (*)"
+        )
+
+        if not file_path:
+            return
+
+        try:
+            self.statusBar().showMessage(f"Saving to {file_path}...")
+            QApplication.processEvents()
+
+            # Get architecture and endianness from the original file
+            arch = 'unknown'
+            field = self.reader.get_field(gguf.Keys.General.ARCHITECTURE)
+            if field:
+                arch = field.contents()
+
+            # Create writer
+            writer = GGUFWriter(file_path, arch=arch, endianess=self.reader.endianess)
+
+            # Get alignment if present
+            alignment = None
+            field = self.reader.get_field(gguf.Keys.General.ALIGNMENT)
+            if field:
+                alignment = field.contents()
+                if alignment is not None:
+                    writer.data_alignment = alignment
+
+            # Copy metadata with changes
+            for field in self.reader.fields.values():
+                # Skip virtual fields and fields written by GGUFWriter
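+                # (general.architecture was already passed to the GGUFWriter constructor above)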
+                if field.name == gguf.Keys.General.ARCHITECTURE or field.name.startswith('GGUF.'):
+                    continue
+
+                # Skip fields marked for removal
+                if field.name in self.metadata_to_remove:
+                    continue
+
+                # Apply changes if any
+                if field.name in self.metadata_changes:
+                    value_type, value = self.metadata_changes[field.name]
+                    if value_type == GGUFValueType.ARRAY:
+                        # Handle array values
+                        element_type, array_values = value
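+                        # element_type is informational here; GGUFWriter.add_array infers the type from the values themselves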
+                        writer.add_array(field.name, array_values)
+                    else:
+                        writer.add_key_value(field.name, value, value_type)
+                else:
+                    # Copy original value
+                    value = field.contents()
+                    if value is not None and field.types:
+                        writer.add_key_value(field.name, value, field.types[0])
+
+            # Add new metadata
+            for key, (value_type, value) in self.metadata_changes.items():
+                # Skip if the key already existed (we handled it above)
+                if self.reader.get_field(key) is not None:
+                    continue
+
+                writer.add_key_value(key, value, value_type)
+
+            # Add tensors (including data)
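+            # (raw_shape/raw_dtype keep the original layout and quantization, so tensor data is copied unchanged)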
+            for tensor in self.reader.tensors:
+                writer.add_tensor(tensor.name, tensor.data, raw_shape=tensor.data.shape, raw_dtype=tensor.tensor_type)
+
+            # Write header and metadata
+            writer.open_output_file(Path(file_path))
+            writer.write_header_to_file()
+            writer.write_kv_data_to_file()
+
+            # Write tensor data using the optimized method
+            writer.write_tensors_to_file(progress=False)
+
+            writer.close()
+
+            self.statusBar().showMessage(f"Saved to {file_path}")
+
+            # Ask if user wants to open the new file
+            reply = QMessageBox.question(
+                self, "Open Saved File",
+                "Would you like to open the newly saved file?",
+                QMessageBox.StandardButton.Yes | QMessageBox.StandardButton.No, QMessageBox.StandardButton.Yes
+            )
+
+            if reply == QMessageBox.StandardButton.Yes:
+                self.reader = GGUFReader(file_path, 'r')
+                self.current_file = file_path
+                self.file_path_edit.setText(file_path)
+
+                self.load_metadata()
+                self.load_tensors()
+
+                self.metadata_changes = {}
+                self.metadata_to_remove = set()
+                self.modified = False
+
+        except Exception as e:
+            QMessageBox.critical(self, "Error", f"Failed to save file: {str(e)}")
+            self.statusBar().showMessage("Error saving file")
+
+
+def main() -> None:
+    parser = argparse.ArgumentParser(description="GUI GGUF Editor")
+    parser.add_argument("model_path", nargs="?", help="path to GGUF model file to load at startup")
+    parser.add_argument("--verbose", action="store_true", help="increase output verbosity")
+
+    args = parser.parse_args()
+
+    logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)
+
+    app = QApplication(sys.argv)
+    window = GGUFEditorWindow()
+    window.show()
+
+    # Load model if specified
+    if args.model_path:
+        if os.path.isfile(args.model_path) and args.model_path.endswith('.gguf'):
+            window.load_file(args.model_path)
+        else:
+            logger.error(f"Invalid model path: {args.model_path}")
+            QMessageBox.warning(
+                window,
+                "Invalid Model Path",
+                f"The specified file does not exist or is not a GGUF file: {args.model_path}")
+
+    sys.exit(app.exec())
+
+
+if __name__ == '__main__':
+    main()
diff --git a/gguf-py/pyproject.toml b/gguf-py/pyproject.toml
index d214e87201..9745d3ea84 100644
--- a/gguf-py/pyproject.toml
+++ b/gguf-py/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "gguf"
-version = "0.16.0"
+version = "0.16.1"
 description = "Read and write ML models in GGUF for GGML"
 authors = ["GGML "]
 packages = [
@@ -23,10 +23,14 @@ numpy = ">=1.17"
 tqdm = ">=4.27"
 pyyaml = ">=5.1"
 sentencepiece = ">=0.1.98,<=0.2.0"
+PySide6 = { version = "^6.9", optional = true }
 
 [tool.poetry.dev-dependencies]
 pytest = "^5.2"
 
+[tool.poetry.extras]
+gui = ["PySide6"]
+
 [build-system]
 requires = ["poetry-core>=1.0.0"]
 build-backend = "poetry.core.masonry.api"
@@ -36,3 +40,4 @@ gguf-convert-endian = "gguf.scripts:gguf_convert_endian_entrypoint"
 gguf-dump = "gguf.scripts:gguf_dump_entrypoint"
 gguf-set-metadata = "gguf.scripts:gguf_set_metadata_entrypoint"
 gguf-new-metadata = "gguf.scripts:gguf_new_metadata_entrypoint"
+gguf-editor-gui = "gguf.scripts:gguf_editor_gui_entrypoint"
diff --git a/requirements/requirements-all.txt b/requirements/requirements-all.txt
index 439db88863..eba0a59f62 100644
--- a/requirements/requirements-all.txt
+++ b/requirements/requirements-all.txt
@@ -11,3 +11,5 @@
 -r ./requirements-convert_legacy_llama.txt
 -r ./requirements-convert_llama_ggml_to_gguf.txt
 -r ./requirements-tool_bench.txt
+
+-r ./requirements-gguf_editor_gui.txt
diff --git a/requirements/requirements-gguf_editor_gui.txt b/requirements/requirements-gguf_editor_gui.txt
new file mode 100644
index 0000000000..920dc7cf90
--- /dev/null
+++ b/requirements/requirements-gguf_editor_gui.txt
@@ -0,0 +1,3 @@
+numpy~=1.26.4
+PySide6~=6.9.0
+gguf>=0.16.0

From 6408210082cc0a61b992b487be7e2ff2efbb9e36 Mon Sep 17 00:00:00 2001
From: Daniel Tang 
Date: Fri, 18 Apr 2025 16:02:55 -0400
Subject: [PATCH 14/39] main : Fix Ctrl+D/newline handling (#12951)

This restores the behavior from #491. This does not affect Ctrl+D's ability to
terminate --multiline-input lines (#1040).

This also actually implements #587: "If the user wants the text to end in a
newline, this should be accomplished by explicitly adding a newline by using
\ followed by return, then returning control by pressing return again."

Fixes #12949
---
 examples/main/main.cpp | 21 ++++++++++++++++-----
 1 file changed, 16 insertions(+), 5 deletions(-)

diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index fd7410a646..c59b941bf5 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -865,9 +865,22 @@ int main(int argc, char ** argv) {
                 console::set_display(console::reset);
                 display = true;
 
-                // Add tokens to embd only if the input buffer is non-empty
-                // Entering a empty line lets the user pass control back
-                if (buffer.length() > 1) {
+                if (buffer.empty()) { // Ctrl+D on empty line exits
+                    LOG("EOF by user\n");
+                    break;
+                }
+
+                if (buffer.back() == '\n') {
+                    // Implement #587:
+                    // If the user wants the text to end in a newline,
+                    // this should be accomplished by explicitly adding a newline by using \ followed by return,
+                    // then returning control by pressing return again.
+                    buffer.pop_back();
+                }
+
+                if (buffer.empty()) { // Enter key on empty line lets the user pass control back
+                    LOG_DBG("empty line, passing control back\n");
+                } else { // Add tokens to embd only if the input buffer is non-empty
                     // append input suffix if any
                     if (!params.input_suffix.empty() && !params.conversation_mode) {
                         LOG_DBG("appending input suffix: '%s'\n", params.input_suffix.c_str());
@@ -915,8 +928,6 @@ int main(int argc, char ** argv) {
 
                     n_remain -= line_inp.size();
                     LOG_DBG("n_remain: %d\n", n_remain);
-                } else {
-                    LOG_DBG("empty line, passing control back\n");
                 }
 
                 input_echo = false; // do not echo this again

From 37b9f0d29d7301dd0ec5dfae8f357ccee96325d7 Mon Sep 17 00:00:00 2001
From: Xuan-Son Nguyen 
Date: Sat, 19 Apr 2025 09:15:45 +0200
Subject: [PATCH 15/39] clip : refactor, add `image_manipulation` and
 `llava_uhd` classes (#13011)

* clip : refactor, add `image_manipulation` and `llava_uhd`

* refactor llava-1.6 preprocessing

* simplify logic for llava-1.5

* missing include
---
 examples/llava/clip.cpp | 922 ++++++++++++++++++++--------------------
 1 file changed, 464 insertions(+), 458 deletions(-)

diff --git a/examples/llava/clip.cpp b/examples/llava/clip.cpp
index 49c90b7506..759706156d 100644
--- a/examples/llava/clip.cpp
+++ b/examples/llava/clip.cpp
@@ -27,6 +27,7 @@
 #include 
 #include 
 #include 
+#include <array>
 
 struct clip_logger_state g_logger_state = {GGML_LOG_LEVEL_CONT, clip_log_callback_default, NULL};
 
@@ -1680,45 +1681,6 @@ bool clip_image_load_from_bytes(const unsigned char * bytes, size_t bytes_length
     return true;
 }
 
-// Linear interpolation between two points
-inline float clip_lerp(float s, float e, float t) {
-    return s + (e - s) * t;
-}
-// Bilinear resize function
-static void bilinear_resize(const clip_image_u8& src, clip_image_u8& dst, int target_width, int target_height) {
-    dst.nx = target_width;
-    dst.ny = target_height;
-    dst.buf.resize(3 * target_width * target_height);
-
-    float x_ratio = static_cast<float>(src.nx - 1) / target_width;
-    float y_ratio = static_cast<float>(src.ny - 1) / target_height;
-
-    for (int y = 0; y < target_height; y++) {
-        for (int x = 0; x < target_width; x++) {
-            float px = x_ratio * x;
-            float py = y_ratio * y;
-            int x_floor = static_cast<int>(px);
-            int y_floor = static_cast<int>(py);
-            float x_lerp = px - x_floor;
-            float y_lerp = py - y_floor;
-
-            for (int c = 0; c < 3; c++) {
-                float top = clip_lerp(
-                    static_cast<float>(src.buf[3 * (y_floor * src.nx + x_floor) + c]),
-                    static_cast<float>(src.buf[3 * (y_floor * src.nx + (x_floor + 1)) + c]),
-                    x_lerp
-                );
-                float bottom = clip_lerp(
-                    static_cast<float>(src.buf[3 * ((y_floor + 1) * src.nx + x_floor) + c]),
-                    static_cast<float>(src.buf[3 * ((y_floor + 1) * src.nx + (x_floor + 1)) + c]),
-                    x_lerp
-                );
-                dst.buf[3 * (y * target_width + x) + c] = static_cast<uint8_t>(clip_lerp(top, bottom, y_lerp));
-            }
-        }
-    }
-}
-
 // Normalize image to float32 - careful with pytorch .to(model.device, dtype=torch.float16) - this sometimes reduces precision (32>16>32), sometimes not
 static void normalize_image_u8_to_f32(const clip_image_u8 & src, clip_image_f32 & dst, const float mean[3], const float std[3]) {
     dst.nx = src.nx;
@@ -1732,336 +1694,489 @@ static void normalize_image_u8_to_f32(const clip_image_u8 & src, clip_image_f32
     }
 }
 
-inline int clip(int x, int lower, int upper) {
-    return std::max(lower, std::min(x, upper));
-}
+// set of tools to manipulate images
+// in the future, we can add HW acceleration by allowing this struct to access third-party libs like ImageMagick or OpenCV
+struct image_manipulation {
+    // Bilinear resize function
+    static void bilinear_resize(const clip_image_u8& src, clip_image_u8& dst, int target_width, int target_height) {
+        dst.nx = target_width;
+        dst.ny = target_height;
+        dst.buf.resize(3 * target_width * target_height);
 
-static bool bicubic_resize(const clip_image_u8 & img, clip_image_u8 & dst, int target_width, int target_height) {
-    const int nx = img.nx;
-    const int ny = img.ny;
+        float x_ratio = static_cast<float>(src.nx - 1) / target_width;
+        float y_ratio = static_cast<float>(src.ny - 1) / target_height;
 
-    dst.nx = target_width;
-    dst.ny = target_height;
-    dst.buf.resize(3 * target_width * target_height);
+        for (int y = 0; y < target_height; y++) {
+            for (int x = 0; x < target_width; x++) {
+                float px = x_ratio * x;
+                float py = y_ratio * y;
+                int x_floor = static_cast<int>(px);
+                int y_floor = static_cast<int>(py);
+                float x_lerp = px - x_floor;
+                float y_lerp = py - y_floor;
 
-    float Cc;
-    float C[5];
-    float d0, d2, d3, a0, a1, a2, a3;
-    int i, j, k, jj;
-    int x, y;
-    float dx, dy;
-    float tx, ty;
-
-    tx = (float)nx / (float)target_width;
-    ty = (float)ny / (float)target_height;
-
-    // Bicubic interpolation; adapted from ViT.cpp, inspired from :
-    //    -> https://github.com/yglukhov/bicubic-interpolation-image-processing/blob/master/libimage.c#L36
-    //    -> https://en.wikipedia.org/wiki/Bicubic_interpolation
-
-    for (i = 0; i < target_height; i++) {
-        for (j = 0; j < target_width; j++) {
-            x = (int)(tx * j);
-            y = (int)(ty * i);
-
-            dx = tx * j - x;
-            dy = ty * i - y;
-
-            for (k = 0; k < 3; k++) {
-                for (jj = 0; jj <= 3; jj++) {
-                    d0 = img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x - 1, 0, nx - 1)) * 3 + k] - img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x, 0, nx - 1)) * 3 + k];
-                    d2 = img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x + 1, 0, nx - 1)) * 3 + k] - img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x, 0, nx - 1)) * 3 + k];
-                    d3 = img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x + 2, 0, nx - 1)) * 3 + k] - img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x, 0, nx - 1)) * 3 + k];
-                    a0 = img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x, 0, nx - 1)) * 3 + k];
-
-                    a1 = -1.0 / 3 * d0 + d2 - 1.0 / 6 * d3;
-                    a2 =  1.0 / 2 * d0 +      1.0 / 2 * d2;
-                    a3 = -1.0 / 6 * d0 -      1.0 / 2 * d2 + 1.0 / 6 * d3;
-
-                    C[jj] = a0 + a1 * dx + a2 * dx * dx + a3 * dx * dx * dx;
-
-                    d0 = C[0] - C[1];
-                    d2 = C[2] - C[1];
-                    d3 = C[3] - C[1];
-                    a0 = C[1];
-                    a1 = -1.0 / 3 * d0 + d2 - 1.0 / 6 * d3;
-                    a2 =  1.0 / 2 * d0 +      1.0 / 2 * d2;
-                    a3 = -1.0 / 6 * d0 -      1.0 / 2 * d2 + 1.0 / 6 * d3;
-                    Cc = a0 + a1 * dy + a2 * dy * dy + a3 * dy * dy * dy;
-
-                    const uint8_t Cc2 = std::min(std::max(std::round(Cc), 0.0f), 255.0f);
-                    dst.buf[(i * target_width + j) * 3 + k] = float(Cc2);
+                for (int c = 0; c < 3; c++) {
+                    float top = lerp(
+                        static_cast<float>(src.buf[3 * (y_floor * src.nx + x_floor) + c]),
+                        static_cast<float>(src.buf[3 * (y_floor * src.nx + (x_floor + 1)) + c]),
+                        x_lerp
+                    );
+                    float bottom = lerp(
+                        static_cast<float>(src.buf[3 * ((y_floor + 1) * src.nx + x_floor) + c]),
+                        static_cast<float>(src.buf[3 * ((y_floor + 1) * src.nx + (x_floor + 1)) + c]),
+                        x_lerp
+                    );
+                    dst.buf[3 * (y * target_width + x) + c] = static_cast<uint8_t>(lerp(top, bottom, y_lerp));
                 }
             }
         }
     }
 
-    return true;
-}
+    // Bicubic resize function
+    // part of image will be cropped if the aspect ratio is different
+    static bool bicubic_resize(const clip_image_u8 & img, clip_image_u8 & dst, int target_width, int target_height) {
+        const int nx = img.nx;
+        const int ny = img.ny;
 
-// llava-1.6 type of resize_and_pad (black)
-static void resize_and_pad_image(const clip_image_u8& image, clip_image_u8 &image_output, const std::pair<int, int>& target_resolution) {
-    int target_width = target_resolution.first;
-    int target_height = target_resolution.second;
+        dst.nx = target_width;
+        dst.ny = target_height;
+        dst.buf.resize(3 * target_width * target_height);
 
-    float scale_w = static_cast<float>(target_width) / image.nx;
-    float scale_h = static_cast<float>(target_height) / image.ny;
+        float Cc;
+        float C[5];
+        float d0, d2, d3, a0, a1, a2, a3;
+        int i, j, k, jj;
+        int x, y;
+        float dx, dy;
+        float tx, ty;
 
-    int new_width, new_height;
+        tx = (float)nx / (float)target_width;
+        ty = (float)ny / (float)target_height;
 
-    if (scale_w < scale_h) {
-        new_width = target_width;
-        new_height = std::min(static_cast<int>(std::ceil(image.ny * scale_w)), target_height);
-    } else {
-        new_height = target_height;
-        new_width = std::min(static_cast<int>(std::ceil(image.nx * scale_h)), target_width);
+        // Bicubic interpolation; adapted from ViT.cpp, inspired from :
+        //    -> https://github.com/yglukhov/bicubic-interpolation-image-processing/blob/master/libimage.c#L36
+        //    -> https://en.wikipedia.org/wiki/Bicubic_interpolation
+
+        for (i = 0; i < target_height; i++) {
+            for (j = 0; j < target_width; j++) {
+                x = (int)(tx * j);
+                y = (int)(ty * i);
+
+                dx = tx * j - x;
+                dy = ty * i - y;
+
+                for (k = 0; k < 3; k++) {
+                    for (jj = 0; jj <= 3; jj++) {
+                        d0 = img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x - 1, 0, nx - 1)) * 3 + k] - img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x, 0, nx - 1)) * 3 + k];
+                        d2 = img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x + 1, 0, nx - 1)) * 3 + k] - img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x, 0, nx - 1)) * 3 + k];
+                        d3 = img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x + 2, 0, nx - 1)) * 3 + k] - img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x, 0, nx - 1)) * 3 + k];
+                        a0 = img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x, 0, nx - 1)) * 3 + k];
+
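+                        // cubic polynomial coefficients along x, derived from the finite differences d0, d2, d3 of the four neighboring samples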
+                        a1 = -1.0 / 3 * d0 + d2 - 1.0 / 6 * d3;
+                        a2 =  1.0 / 2 * d0 +      1.0 / 2 * d2;
+                        a3 = -1.0 / 6 * d0 -      1.0 / 2 * d2 + 1.0 / 6 * d3;
+
+                        C[jj] = a0 + a1 * dx + a2 * dx * dx + a3 * dx * dx * dx;
+
+                        d0 = C[0] - C[1];
+                        d2 = C[2] - C[1];
+                        d3 = C[3] - C[1];
+                        a0 = C[1];
+                        a1 = -1.0 / 3 * d0 + d2 - 1.0 / 6 * d3;
+                        a2 =  1.0 / 2 * d0 +      1.0 / 2 * d2;
+                        a3 = -1.0 / 6 * d0 -      1.0 / 2 * d2 + 1.0 / 6 * d3;
+                        Cc = a0 + a1 * dy + a2 * dy * dy + a3 * dy * dy * dy;
+
+                        const uint8_t Cc2 = std::min(std::max(std::round(Cc), 0.0f), 255.0f);
+                        dst.buf[(i * target_width + j) * 3 + k] = float(Cc2);
+                    }
+                }
+            }
+        }
+
+        return true;
     }
 
-    clip_image_u8 resized_image;
-    // bilinear_resize(image, resized_image, new_width, new_height);
-    bicubic_resize(image, resized_image, new_width, new_height);
+    // llava-1.6 style of resize_and_pad
+    // if the aspect ratio is not 1:1, padding with pad_color will be applied
+    // pad_color is an RGB triplet, default {0, 0, 0} (black)
+    static void resize_and_pad_image(const clip_image_u8 & image, clip_image_u8 & dst, const clip_image_size & target_resolution, std::array<uint8_t, 3> pad_color = {0, 0, 0}) {
+        int target_width  = target_resolution.width;
+        int target_height = target_resolution.height;
 
-    clip_image_u8 padded_image;
-    padded_image.nx = target_width;
-    padded_image.ny = target_height;
-    padded_image.buf.resize(3 * target_width * target_height, 0); // Initialize with black
+        float scale_w = static_cast<float>(target_width) / image.nx;
+        float scale_h = static_cast<float>(target_height) / image.ny;
 
-    // Calculate padding offsets
-    int pad_x = (target_width - new_width) / 2;
-    int pad_y = (target_height - new_height) / 2;
+        int new_width, new_height;
 
-    // Copy the resized image into the center of the padded buffer
-    for (int y = 0; y < new_height; ++y) {
-        for (int x = 0; x < new_width; ++x) {
-            for (int c = 0; c < 3; ++c) {
-                padded_image.buf[3 * ((y + pad_y) * target_width + (x + pad_x)) + c] = resized_image.buf[3 * (y * new_width + x) + c];
+        if (scale_w < scale_h) {
+            new_width  = target_width;
+            new_height = std::min(static_cast<int>(std::ceil(image.ny * scale_w)), target_height);
+        } else {
+            new_height = target_height;
+            new_width  = std::min(static_cast<int>(std::ceil(image.nx * scale_h)), target_width);
+        }
+
+        clip_image_u8 resized_image;
+        bicubic_resize(image, resized_image, new_width, new_height);
+
+        clip_image_u8 padded_image;
+        padded_image.nx = target_width;
+        padded_image.ny = target_height;
+        padded_image.buf.resize(3 * target_width * target_height);
+
+        // Fill the padded image with the fill color
+        for (size_t i = 0; i < padded_image.buf.size(); i += 3) {
+            padded_image.buf[i]     = pad_color[0];
+            padded_image.buf[i + 1] = pad_color[1];
+            padded_image.buf[i + 2] = pad_color[2];
+        }
+
+        // Calculate padding offsets
+        int pad_x = (target_width  - new_width)  / 2;
+        int pad_y = (target_height - new_height) / 2;
+
+        // Copy the resized image into the center of the padded buffer
+        for (int y = 0; y < new_height; ++y) {
+            for (int x = 0; x < new_width; ++x) {
+                for (int c = 0; c < 3; ++c) {
+                    padded_image.buf[3 * ((y + pad_y) * target_width + (x + pad_x)) + c] = resized_image.buf[3 * (y * new_width + x) + c];
+                }
+            }
+        }
+        dst = std::move(padded_image);
+    }
+
+    static void crop_image(const clip_image_u8 & image, clip_image_u8 & dst, int x, int y, int w, int h) {
+        dst.nx = w;
+        dst.ny = h;
+        dst.buf.resize(3 * w * h);
+
+        for (int i = 0; i < h; ++i) {
+            for (int j = 0; j < w; ++j) {
+                int src_idx = 3 * ((y + i)*image.nx + (x + j));
+                int dst_idx = 3 * (i*w + j);
+                dst.buf[dst_idx]     = image.buf[src_idx];
+                dst.buf[dst_idx + 1] = image.buf[src_idx + 1];
+                dst.buf[dst_idx + 2] = image.buf[src_idx + 2];
             }
         }
     }
-    image_output = std::move(padded_image);
-}
+
+private:
+    static inline int clip(int x, int lower, int upper) {
+        return std::max(lower, std::min(x, upper));
+    }
+
+    // Linear interpolation between two points
+    static inline float lerp(float s, float e, float t) {
+        return s + (e - s) * t;
+    }
+};
 
 /**
- * Selects the best resolution from a list of possible resolutions based on the original size.
+ * implementation of LLaVA-UHD:
+ *  - https://arxiv.org/pdf/2403.11703
+ *  - https://github.com/thunlp/LLaVA-UHD
+ *  - https://github.com/thunlp/LLaVA-UHD/blob/302301bc2175f7e717fb8548516188e89f649753/llava_uhd/train/llava-uhd/slice_logic.py#L118
  *
- * @param original_size The original size of the image in the format (width, height).
- * @param possible_resolutions A list of possible resolutions in the format [(width1, height1), (width2, height2), ...].
- * @return The best fit resolution in the format (width, height).
+ * overview:
+ *   - an image always have a single overview (downscaled image)
+ *   - an image can have 0 or multiple slices, depending on the image size
+ *   - each slice can then be considered as a separate image
+ *
+ * for example:
+ *
+ * [overview] --> [slice 1] --> [slice 2]
+ *           |                |
+ *           +--> [slice 3] --> [slice 4]
  */
-static std::pair<int, int> select_best_resolution(const std::pair<int, int> & original_size, const std::vector<std::pair<int, int>> & possible_resolutions) {
-    int original_width = original_size.first;
-    int original_height = original_size.second;
-    std::pair<int, int> best_fit;
-    int max_effective_resolution = 0;
-    int min_wasted_resolution = std::numeric_limits<int>::max();
+struct llava_uhd {
+    struct slice_coordinates {
+        int x;
+        int y;
+        clip_image_size size;
+    };
 
-    for (const auto& resolution : possible_resolutions) {
-        int width = resolution.first;
-        int height = resolution.second;
-        float scale = std::min(static_cast<float>(width) / original_width, static_cast<float>(height) / original_height);
-        int downscaled_width = static_cast<int>(original_width * scale);
-        int downscaled_height = static_cast<int>(original_height * scale);
-        int effective_resolution = std::min(downscaled_width * downscaled_height, original_width * original_height);
-        int wasted_resolution = (width * height) - effective_resolution;
-        // LOG_INF("resolution: %d %d, scale: %f, downscaled: %d %d, effective: %d, wasted: %d\n", width, height, scale, downscaled_width, downscaled_height, effective_resolution, wasted_resolution);
-        if (effective_resolution > max_effective_resolution || (effective_resolution == max_effective_resolution && wasted_resolution < min_wasted_resolution)) {
-            max_effective_resolution = effective_resolution;
-            min_wasted_resolution = wasted_resolution;
-            best_fit = resolution;
+    struct slice_instructions {
+        clip_image_size overview_size; // size of downscaled image
+        clip_image_size refined_size;  // size of image right before slicing (must be multiple of slice size)
+        clip_image_size grid_size;     // grid_size.width * grid_size.height = number of slices
+        std::vector<slice_coordinates> slices;
+        bool padding_refined = false;  // if true, the refined image will be padded to the grid size (e.g. llava-1.6)
+    };
+
+    static int get_max_slices(struct clip_ctx * ctx) {
+        if (clip_is_minicpmv(ctx)) {
+            return 9;
         }
+        return 0;
     }
 
-    return best_fit;
-}
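+    // decide how the image is downscaled (overview) and optionally sliced, based on the model's slicing configuration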
+    static slice_instructions get_slice_instructions(struct clip_ctx * ctx, const clip_image_size & original_size) {
+        slice_instructions res;
+        const int patch_size      = clip_get_patch_size(ctx);
+        const int slice_size      = clip_get_image_size(ctx);
+        const int max_slice_nums  = get_max_slices(ctx);
+        const int original_width  = original_size.width;
+        const int original_height = original_size.height;
+        const float log_ratio = log((float)original_width / original_height);
+        const float ratio = (float)original_width * original_height / (slice_size * slice_size);
+        const int multiple = fmin(ceil(ratio), max_slice_nums);
+        const bool has_slices = (multiple > 1);
+        const bool has_pinpoints = !ctx->vision_model.hparams.image_grid_pinpoints.empty();
 
-static std::vector<clip_image_u8_ptr> divide_to_patches_u8(const clip_image_u8 & image, int patch_size) {
-    std::vector patches;
-    int width = image.nx;
-    int height = image.ny;
-    for (int i = 0; i < height; i += patch_size) {
-        for (int j = 0; j < width; j += patch_size) {
-            clip_image_u8_ptr patch(clip_image_u8_init());
-            patch->nx = std::min(patch_size, width - j);
-            patch->ny = std::min(patch_size, height - i);
-            patch->buf.resize(3 * patch->nx * patch->ny);
-            for (int y = 0; y < patch->ny; ++y) {
-                for (int x = 0; x < patch->nx; ++x) {
-                    for (int c = 0; c < 3; ++c) {
-                        patch->buf[3 * (y * patch->nx + x) + c] = image.buf[3 * ((i + y) * width + (j + x)) + c];
+        if (has_pinpoints) {
+            // has pinpoints, use them to calculate the grid size (e.g. llava-1.6)
+            auto refine_size = llava_uhd::select_best_resolution(
+                ctx->vision_model.hparams.image_grid_pinpoints,
+                original_size);
+            res.overview_size   = clip_image_size{slice_size, slice_size};
+            res.refined_size    = refine_size;
+            res.grid_size       = clip_image_size{0, 0};
+            res.padding_refined = true;
+
+            for (int y = 0; y < refine_size.height; y += slice_size) {
+                for (int x = 0; x < refine_size.width; x += slice_size) {
+                    slice_coordinates slice;
+                    slice.x = x;
+                    slice.y = y;
+                    slice.size.width  = std::min(slice_size, refine_size.width  - x);
+                    slice.size.height = std::min(slice_size, refine_size.height - y);
+                    res.slices.push_back(slice);
+                    if (x == 0) {
+                        res.grid_size.width++;
                     }
                 }
+                res.grid_size.height++;
             }
-            patches.push_back(std::move(patch));
+
+            return res;
         }
-    }
-    return patches;
-}
 
-static int ensure_divide(int length, int patch_size) {
-    return std::max(static_cast<int>(std::round(static_cast<float>(length) / patch_size) * patch_size), patch_size);
-}
+        // no pinpoints, dynamically calculate the grid size (e.g. minicpmv)
 
-static std::pair<int, int> uhd_find_best_resize(std::pair<int, int> original_size, int scale_resolution, int patch_size, bool allow_upscale = false) {
-    int width = original_size.first;
-    int height = original_size.second;
-    if ((width * height > scale_resolution * scale_resolution) || allow_upscale) {
-        float r = static_cast<float>(width) / height;
-        height = static_cast<int>(scale_resolution / std::sqrt(r));
-        width = static_cast<int>(height * r);
-    }
-    int best_width = ensure_divide(width, patch_size);
-    int best_height = ensure_divide(height, patch_size);
-    return std::make_pair(best_width, best_height);
-}
+        auto best_size    = get_best_resize(original_size, slice_size, patch_size, has_slices);
+        res.overview_size = best_size;
 
-static std::pair<int, int> uhd_get_refine_size(std::pair<int, int> original_size, std::pair<int, int> grid, int scale_resolution, int patch_size, bool allow_upscale = false) {
-    int width, height;
-    std::tie(width, height) = original_size;
-    int grid_x, grid_y;
-    std::tie(grid_x, grid_y) = grid;
+        if (!has_slices) {
+            // skip slicing logic
+            res.refined_size = clip_image_size{0, 0};
+            res.grid_size    = clip_image_size{0, 0};
 
-    int refine_width = ensure_divide(width, grid_x);
-    int refine_height = ensure_divide(height, grid_y);
+        } else {
+            auto best_grid   = get_best_grid(max_slice_nums, multiple, log_ratio);
+            auto refine_size = get_refine_size(original_size, best_grid, slice_size, patch_size, true);
+            res.grid_size    = best_grid;
+            res.refined_size = refine_size;
 
-    int grid_width = refine_width / grid_x;
-    int grid_height = refine_height / grid_y;
-
-   // auto best_grid_size = find_best_resize(std::make_tuple(grid_width, grid_height), scale_resolution, patch_size, allow_upscale); (old line)
-    auto best_grid_size = uhd_find_best_resize(std::make_pair(grid_width, grid_height), scale_resolution, patch_size, allow_upscale); // (new line) => fixes conversion for make_tuple to make_pair
-    int best_grid_width, best_grid_height;
-    std::tie(best_grid_width, best_grid_height) = best_grid_size;
-
-  //  std::pair<int, int> refine_size = std::make_tuple(best_grid_width * grid_x, best_grid_height * grid_y); (old line)
-    std::pair<int, int> refine_size = std::make_pair(best_grid_width * grid_x, best_grid_height * grid_y); // (new line)
-    return refine_size;
-}
-
-static std::pair<int, int> uhd_best_grid(const int max_slice_nums, const int multiple, const float log_ratio) {
-    std::vector<int> candidate_split_grids_nums;
-    for (int i : {multiple - 1, multiple, multiple + 1}) {
-        if (i == 1 || i > max_slice_nums) {
-            continue;
-        }
-        candidate_split_grids_nums.push_back(i);
-    }
-
-    std::vector<std::pair<int, int>> candidate_grids;
-    for (int split_grids_nums : candidate_split_grids_nums) {
-        int m = 1;
-        while (m <= split_grids_nums) {
-            if (split_grids_nums % m == 0) {
-                candidate_grids.emplace_back(m, split_grids_nums / m);
-            }
-            ++m;
-        }
-    }
-
-    std::pair<int, int> best_grid{1, 1};
-    float min_error = std::numeric_limits<float>::infinity();
-    for (const auto& grid : candidate_grids) {
-        float error = std::abs(log_ratio - std::log(1.0 * grid.first / grid.second));
-        if (error < min_error) {
-            best_grid = grid;
-            min_error = error;
-        }
-    }
-    return best_grid;
-}
-
-// inspired from LLaVA-UHD:
-//    -> https://arxiv.org/pdf/2403.11703
-//    -> https://github.com/thunlp/LLaVA-UHD
-//    -> https://github.com/thunlp/LLaVA-UHD/blob/302301bc2175f7e717fb8548516188e89f649753/llava_uhd/train/llava-uhd/slice_logic.py#L118
-static std::vector<std::vector<clip_image_u8_ptr>> uhd_slice_image(const clip_image_u8 * img, const int max_slice_nums=9, const int scale_resolution=448, const int patch_size=14) {
-    const std::pair<int, int> original_size={img->nx,img->ny};
-    const int original_width = img->nx;
-    const int original_height = img->ny;
-    const float log_ratio = log(1.0*original_width/original_height);
-    const float ratio = 1.0 * original_width * original_height/ (scale_resolution * scale_resolution);
-    const int multiple = fmin(ceil(ratio), max_slice_nums);
-
-    std::vector<std::vector<clip_image_u8_ptr>> images;
-    LOG_DBG("%s: multiple %d\n", __func__, multiple);
-    images.push_back(std::vector<clip_image_u8_ptr>());
-
-    if (multiple <= 1) {
-        auto best_size = uhd_find_best_resize(original_size, scale_resolution, patch_size, true);
-        clip_image_u8_ptr source_image(clip_image_u8_init());
-        bicubic_resize(*img, *source_image, best_size.first, best_size.second);
-        // source_image = image.resize(best_size, Image.Resampling.BICUBIC)
-        images.back().push_back(std::move(source_image));
-    }
-    else if (multiple > 1) {
-        auto best_size = uhd_find_best_resize(original_size, scale_resolution, patch_size);
-        clip_image_u8_ptr source_image(clip_image_u8_init());
-        bicubic_resize(*img, *source_image, best_size.first, best_size.second);
-        // source_image = image.copy().resize(best_resize, Image.Resampling.BICUBIC)
-        LOG_DBG("%s: image_size: %d %d; source_image size: %d %d\n", __func__, img->nx, img->ny, best_size.first, best_size.second);
-        images.back().push_back(std::move(source_image));
-
-        std::pair<int, int> best_grid = uhd_best_grid(max_slice_nums, multiple, log_ratio);
-        LOG_DBG("%s: image_size: %d %d; best_grid: %d %d\n", __func__, img->nx, img->ny, best_grid.first, best_grid.second);
-
-        auto refine_size = uhd_get_refine_size(original_size, best_grid, scale_resolution, patch_size, true);
-        clip_image_u8_ptr refine_image(clip_image_u8_init());
-        bicubic_resize(*img, *refine_image, refine_size.first, refine_size.second);
-
-        LOG_DBG("%s: refine_image_size: %d %d; refine_size: %d %d\n", __func__, refine_image->nx, refine_image->ny, refine_size.first, refine_size.second);
-
-        // split_to_patches
-        int width = refine_image->nx;
-        int height = refine_image->ny;
-        int grid_x = int(width / best_grid.first);
-        int grid_y = int(height / best_grid.second);
-        for (int patches_i = 0, ic = 0; patches_i < height && ic < best_grid.second; patches_i += grid_y, ic += 1){
-            images.push_back(std::vector<clip_image_u8_ptr>());
-            for(int patches_j = 0, jc = 0; patches_j < width && jc < best_grid.first; patches_j += grid_x, jc += 1){
-                clip_image_u8_ptr patch(clip_image_u8_init());
-                patch->nx = grid_x;
-                patch->ny = grid_y;
-                patch->buf.resize(3 * patch->nx * patch->ny);
-                for (int y = patches_i; y < patches_i + grid_y; ++y) {
-                    for (int x = patches_j; x < patches_j + grid_x; ++x) {
-                        const int i = 3 * (y * refine_image->nx + x);
-                        const int j = 3 * ((y-patches_i) * patch->nx + (x-patches_j));
-                        patch->buf[j]   = refine_image->buf[i];
-                        patch->buf[j+1] = refine_image->buf[i+1];
-                        patch->buf[j+2] = refine_image->buf[i+2];
-                    }
+            int width  = refine_size.width;
+            int height = refine_size.height;
+            int grid_x = int(width  / best_grid.width);
+            int grid_y = int(height / best_grid.height);
+            for (int patches_y = 0,                    ic = 0;
+                    patches_y < refine_size.height && ic < best_grid.height;
+                    patches_y += grid_y,              ic += 1) {
+                for (int patches_x = 0,                   jc = 0;
+                        patches_x < refine_size.width && jc < best_grid.width;
+                        patches_x += grid_x,             jc += 1) {
+                    slice_coordinates slice;
+                    slice.x = patches_x;
+                    slice.y = patches_y;
+                    slice.size.width  = grid_x;
+                    slice.size.height = grid_y;
+                    res.slices.push_back(slice);
+                    // LOG_INF("slice %d: %d %d %d %d\n", ic, patches_i, patches_j, grid_x, grid_y);
                 }
-                images.back().push_back(std::move(patch));
             }
         }
-    }
-    return images;
-}
 
+        return res;
+    }
+
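+    // turn the instructions into actual images: the overview first, then each slice cropped from the refined image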
+    static std::vector<clip_image_u8_ptr> slice_image(const clip_image_u8 * img, const slice_instructions & inst) {
+        std::vector<clip_image_u8_ptr> output;
+
+        // resize to overview size
+        clip_image_u8_ptr resized_img(clip_image_u8_init());
+        image_manipulation::bicubic_resize(*img, *resized_img, inst.overview_size.width, inst.overview_size.height);
+        output.push_back(std::move(resized_img));
+        if (inst.slices.empty()) {
+            // no slices, just return the resized image
+            return output;
+        }
+
+        // resize to refined size
+        clip_image_u8_ptr refined_img(clip_image_u8_init());
+        if (inst.padding_refined) {
+            image_manipulation::resize_and_pad_image(*img, *refined_img, inst.refined_size);
+        } else {
+            image_manipulation::bilinear_resize(*img, *refined_img, inst.refined_size.width, inst.refined_size.height);
+        }
+
+        // create slices
+        for (const auto & slice : inst.slices) {
+            int x = slice.x;
+            int y = slice.y;
+            int w = slice.size.width;
+            int h = slice.size.height;
+
+            clip_image_u8_ptr img_slice(clip_image_u8_init());
+            image_manipulation::crop_image(*refined_img, *img_slice, x, y, w, h);
+            output.push_back(std::move(img_slice));
+        }
+
+        return output;
+    }
+
+private:
+    static clip_image_size get_best_resize(const clip_image_size & original_size, int scale_resolution, int patch_size, bool allow_upscale = false) {
+        int width  = original_size.width;
+        int height = original_size.height;
+        if ((width * height > scale_resolution * scale_resolution) || allow_upscale) {
+            float r = static_cast<float>(width) / height;
+            height  = static_cast<int>(scale_resolution / std::sqrt(r));
+            width   = static_cast<int>(height * r);
+        }
+        clip_image_size res;
+        res.width  = ensure_divide(width,  patch_size);
+        res.height = ensure_divide(height, patch_size);
+        return res;
+    }
+
+    /**
+     * Selects the best resolution from a list of possible resolutions based on the original size.
+     *
+     * @param original_size The original size of the image
+     * @param possible_resolutions A list of possible resolutions
+     * @return The best fit resolution
+     */
+    static clip_image_size select_best_resolution(const clip_image_size & original_size, const std::vector<clip_image_size> & possible_resolutions) {
+        int original_width = original_size.width;
+        int original_height = original_size.height;
+        clip_image_size best_fit;
+        int max_effective_resolution = 0;
+        int min_wasted_resolution = std::numeric_limits<int>::max();
+
+        for (const auto & resolution : possible_resolutions) {
+            int width  = resolution.width;
+            int height = resolution.height;
+            float scale = std::min(static_cast<float>(width) / original_width, static_cast<float>(height) / original_height);
+            int downscaled_width  = static_cast<int>(original_width * scale);
+            int downscaled_height = static_cast<int>(original_height * scale);
+            int effective_resolution = std::min(downscaled_width * downscaled_height, original_width * original_height);
+            int wasted_resolution = (width * height) - effective_resolution;
+            // LOG_INF("resolution: %d %d, scale: %f, downscaled: %d %d, effective: %d, wasted: %d\n", width, height, scale, downscaled_width, downscaled_height, effective_resolution, wasted_resolution);
+            if (effective_resolution > max_effective_resolution || (effective_resolution == max_effective_resolution && wasted_resolution < min_wasted_resolution)) {
+                max_effective_resolution = effective_resolution;
+                min_wasted_resolution = wasted_resolution;
+                best_fit = resolution;
+            }
+        }
+
+        return best_fit;
+    }
+
+    // used by llava 1.6 with custom list of pinpoints
+    static clip_image_size select_best_resolution(const std::vector<int32_t> & pinpoints, const clip_image_size & original_size) {
+        std::vector<clip_image_size> possible_resolutions;
+        for (size_t i = 0; i < pinpoints.size(); i += 2) {
+            possible_resolutions.push_back(clip_image_size{pinpoints[i], pinpoints[i+1]});
+        }
+        return select_best_resolution(original_size, possible_resolutions);
+    }
+
+    static int ensure_divide(int length, int patch_size) {
+        return std::max(static_cast<int>(std::round(static_cast<float>(length) / patch_size) * patch_size), patch_size);
+    }
+
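+    // size of the image right before slicing: each grid cell is snapped to a multiple of patch_size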
+    static clip_image_size get_refine_size(const clip_image_size & original_size, const clip_image_size & grid, int scale_resolution, int patch_size, bool allow_upscale = false) {
+        int width  = original_size.width;
+        int height = original_size.height;
+        int grid_x = grid.width;
+        int grid_y = grid.height;
+
+        int refine_width  = ensure_divide(width, grid_x);
+        int refine_height = ensure_divide(height, grid_y);
+
+        clip_image_size grid_size;
+        grid_size.width  = refine_width  / grid_x;
+        grid_size.height = refine_height / grid_y;
+
+        auto best_grid_size  = get_best_resize(grid_size, scale_resolution, patch_size, allow_upscale);
+        int best_grid_width  = best_grid_size.width;
+        int best_grid_height = best_grid_size.height;
+
+        clip_image_size refine_size;
+        refine_size.width  = best_grid_width  * grid_x;
+        refine_size.height = best_grid_height * grid_y;
+        return refine_size;
+    }
+
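+    // pick the slice grid whose aspect ratio is closest (in log space) to that of the image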
+    static clip_image_size get_best_grid(const int max_slice_nums, const int multiple, const float log_ratio) {
+        std::vector<int> candidate_split_grids_nums;
+        for (int i : {multiple - 1, multiple, multiple + 1}) {
+            if (i == 1 || i > max_slice_nums) {
+                continue;
+            }
+            candidate_split_grids_nums.push_back(i);
+        }
+
+        std::vector<clip_image_size> candidate_grids;
+        for (int split_grids_nums : candidate_split_grids_nums) {
+            int m = 1;
+            while (m <= split_grids_nums) {
+                if (split_grids_nums % m == 0) {
+                    candidate_grids.push_back(clip_image_size{m, split_grids_nums / m});
+                }
+                ++m;
+            }
+        }
+
+        clip_image_size best_grid{1, 1};
+        float min_error = std::numeric_limits<float>::infinity();
+        for (const auto& grid : candidate_grids) {
+            float error = std::abs(log_ratio - std::log(1.0 * grid.width / grid.height));
+            if (error < min_error) {
+                best_grid = grid;
+                min_error = error;
+            }
+        }
+        return best_grid;
+    }
+};
+
+// TODO @ngxson : deprecate the load_image_size singleton pattern
 int clip_uhd_num_image_embeds_col(struct clip_ctx * ctx_clip) {
-    const int max_slice_nums=9;
-    const int scale_resolution=448;
-    const int original_width = ctx_clip->load_image_size.width;
-    const int original_height = ctx_clip->load_image_size.height;
-    const float log_ratio = log(1.0*original_width/original_height);
-    const float ratio = 1.0 * original_width * original_height/ (scale_resolution * scale_resolution);
-    const int multiple = fmin(ceil(ratio), max_slice_nums);
-    std::pair<int, int> best_grid = uhd_best_grid(max_slice_nums, multiple, log_ratio);
-    return best_grid.first;
+    const auto inst = llava_uhd::get_slice_instructions(ctx_clip, ctx_clip->load_image_size);
+    return inst.grid_size.width;
 }
 
 // returns the normalized float tensor for llava-1.5, for spatial_unpad with anyres processing for llava-1.6 it returns the normalized image patch tensors as a vector
 // res_imgs memory is being allocated here, previous allocations will be freed if found
 bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, struct clip_image_f32_batch * res_imgs) {
+    if (!ctx->has_vision_encoder) {
+        LOG_ERR("%s: This gguf file seems to have no vision encoder\n", __func__);
+        return false;
+    }
+
+    clip_image_size original_size{img->nx, img->ny};
+    bool pad_to_square = true;
+    auto & params = ctx->vision_model.hparams;
+    // The model config actually contains all we need to decide on how to preprocess, here we automatically switch to the new llava-1.6 preprocessing
+    if (params.mm_patch_merge_type == PATCH_MERGE_SPATIAL_UNPAD) {
+        pad_to_square = false;
+    }
 
     if (clip_is_minicpmv(ctx)) {
-        int max_slice_nums = 9;
-        std::vector<std::vector<clip_image_u8_ptr>> imgs = uhd_slice_image(img, max_slice_nums);
+        auto const inst = llava_uhd::get_slice_instructions(ctx, original_size);
+        std::vector<clip_image_u8_ptr> imgs = llava_uhd::slice_image(img, inst);
+
         for (size_t i = 0; i < imgs.size(); ++i) {
-            for (size_t j = 0; j < imgs[i].size(); ++j) {
-                LOG_DBG("%s: %d %d\n", __func__,imgs[i][j]->nx,imgs[i][j]->ny);
-                clip_image_f32_ptr res(clip_image_f32_init());
-                normalize_image_u8_to_f32(*imgs[i][j], *res, ctx->image_mean, ctx->image_std);
-                res_imgs->entries.push_back(std::move(res));
-            }
+            // clip_image_save_to_bmp(*imgs[i], "slice_" + std::to_string(i) + ".bmp");
+            clip_image_f32_ptr res(clip_image_f32_init());
+            normalize_image_u8_to_f32(*imgs[i], *res, ctx->image_mean, ctx->image_std);
+            res_imgs->entries.push_back(std::move(res));
         }
         return true;
     }
@@ -2070,7 +2185,7 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
         auto patch_size = clip_get_patch_size(ctx) * 2;
         int nx = ceil((float)img->nx / patch_size) * patch_size;
         int ny = ceil((float)img->ny / patch_size) * patch_size;
-        bicubic_resize(*img, resized, nx, ny);
+        image_manipulation::bicubic_resize(*img, resized, nx, ny);
 
         clip_image_f32_ptr img_f32(clip_image_f32_init());
         // clip_image_f32_ptr res(clip_image_f32_init());
@@ -2082,8 +2197,8 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
 
     if (ctx->has_glm_projector || ctx->proj_type == PROJECTOR_TYPE_GEMMA3) {
         clip_image_u8 resized_image;
-        int32_t sz=ctx->vision_model.hparams.image_size;
-        bicubic_resize(*img, resized_image,sz,sz);
+        int sz = params.image_size;
+        image_manipulation::bicubic_resize(*img, resized_image, sz, sz);
         clip_image_f32_ptr img_f32(clip_image_f32_init());
         //clip_image_save_to_bmp(resized_image, "resized.bmp");
         normalize_image_u8_to_f32(resized_image, *img_f32, ctx->image_mean, ctx->image_std);
@@ -2091,156 +2206,47 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
         return true;
     }
 
-    bool pad_to_square = true;
-    if (!ctx->has_vision_encoder) {
-        LOG_ERR("%s: This gguf file seems to have no vision encoder\n", __func__);
-        return false;
-    }
-    auto & params = ctx->vision_model.hparams;
-    // The model config actually contains all we need to decide on how to preprocess, here we automatically switch to the new llava-1.6 preprocessing
-    if (params.mm_patch_merge_type == PATCH_MERGE_SPATIAL_UNPAD) {
-        pad_to_square = false;
-    }
-    // free the previous res_imgs if any set
-    res_imgs->entries.clear();
-
     // the logic below is to pad the shorter side to the longer side with a background color: rgb(122, 116, 104)
     // see https://github.com/haotian-liu/LLaVA/blob/e854a2bf85118c504f6f16bf5c3c7c92f8fa8c6b/llava/conversation.py#L113-L156
 
     clip_image_u8_ptr temp(clip_image_u8_init()); // we will keep the input image data here temporarily
-    if (pad_to_square && img->nx != img->ny) {
-        int longer_side = std::max(img->nx, img->ny);
+
+    if (pad_to_square) {
+        // for llava-1.5, we resize image to a square, and pad the shorter side with a background color
+        // see https://github.com/haotian-liu/LLaVA/blob/e854a2bf85118c504f6f16bf5c3c7c92f8fa8c6b/llava/conversation.py#L113-L156
+        const int longer_side = std::max(img->nx, img->ny);
         temp->nx = longer_side;
         temp->ny = longer_side;
         temp->buf.resize(3 * longer_side * longer_side);
-        const uint8_t bc[3] = {122, 116, 104}; // background color in RGB from LLaVA (this is the mean rgb color * 255)
 
-        // fill with background color
-        for (size_t i = 0; i < temp->buf.size(); i++) {
-            temp->buf[i] = bc[i % 3];
+        // background color in RGB from LLaVA (this is the mean rgb color * 255)
+        const std::array<uint8_t, 3> pad_color = {122, 116, 104};
+
+        // resize the image to the target_size
+        image_manipulation::resize_and_pad_image(*img, *temp, clip_image_size{params.image_size, params.image_size}, pad_color);
+
+        clip_image_f32_ptr res(clip_image_f32_init());
+        normalize_image_u8_to_f32(*temp, *res, ctx->image_mean, ctx->image_std);
+        res_imgs->entries.push_back(std::move(res));
+        return true;
+
+    } else if (!params.image_grid_pinpoints.empty()) {
+        // "spatial_unpad" with "anyres" processing for llava-1.6
+        auto const inst = llava_uhd::get_slice_instructions(ctx, original_size);
+        std::vector<clip_image_u8_ptr> imgs = llava_uhd::slice_image(img, inst);
+
+        for (size_t i = 0; i < imgs.size(); ++i) {
+            // clip_image_save_to_bmp(*imgs[i], "slice_" + std::to_string(i) + ".bmp");
+            clip_image_f32_ptr res(clip_image_f32_init());
+            normalize_image_u8_to_f32(*imgs[i], *res, ctx->image_mean, ctx->image_std);
+            res_imgs->entries.push_back(std::move(res));
         }
 
-        // copy from the input image
-        for (int y = 0; y < img->ny; y++) {
-            for (int x = 0; x < img->nx; x++) {
-                const int i = 3 * (y * img->nx + x);
-                const int j = 3 * (y * temp->nx + x);
-                temp->buf[j]   = img->buf[i];
-                temp->buf[j+1] = img->buf[i+1];
-                temp->buf[j+2] = img->buf[i+2];
-            }
-        }
-    } else {
-        if (!params.image_grid_pinpoints.empty()) {
-            // "spatial_unpad" with "anyres" processing for llava-1.6
-            std::vector<std::pair<int, int>> possible_resolutions;
-            for (size_t i = 0; i < params.image_grid_pinpoints.size(); i+=2) {
-                possible_resolutions.push_back({params.image_grid_pinpoints[i], params.image_grid_pinpoints[i+1]});
-            }
-            std::pair<int, int> best_resolution = select_best_resolution({img->nx, img->ny}, possible_resolutions);
-            // clip_image_save_to_bmp(*img, "input.bmp");
-            resize_and_pad_image(*img, *temp, best_resolution);  // we do not pad with mean-bg color anymore in llava-1.6
-            // clip_image_save_to_bmp(*temp, "resized.bmp");
-            // visually verify normalized image:
-            // normalize_image_u8_to_f32(*temp, *res, ctx->image_mean, ctx->image_std);
-            // {
-            //     clip_image_u8 * temp2 = clip_image_u8_init();
-            //     clip_image_convert_f32_to_u8(*res, *temp2);
-            //     clip_image_save_to_bmp(*temp2, "resized_normalized_f32.bmp");
-            //     clip_image_u8_free(temp2);
-            // }
+        return true;
 
-            std::vector<clip_image_u8_ptr> patches = divide_to_patches_u8(*temp, params.image_size); // prepare spatial sorted main patches of image_size each (336 in llava-1.6)
-
-            clip_image_u8_ptr image_original_resize(clip_image_u8_init());
-            // bilinear_resize(*img, *image_original_resize, params.image_size, params.image_size); // in python this is "shortest_edge", but all CLIP are square
-            bicubic_resize(*img, *image_original_resize, params.image_size, params.image_size); // in python this is "shortest_edge", but all CLIP are square
-            patches.insert(patches.begin(), std::move(image_original_resize));
-            for (auto & patch : patches) {
-                clip_image_f32_ptr res(clip_image_f32_init());
-                normalize_image_u8_to_f32(*patch, *res, ctx->image_mean, ctx->image_std);
-                res_imgs->entries.push_back(std::move(res));
-            }
-
-            return true;
-        } else {
-            temp->nx = img->nx;
-            temp->ny = img->ny;
-            temp->buf.resize(img->buf.size());
-            memcpy(temp->buf.data(), img->buf.data(), temp->buf.size());
-        }
     }
 
-    const int nx = temp->nx;
-    const int ny = temp->ny;
-    // clip_image_save_to_bmp(*temp, "resized_vanilla.bmp");
-
-    const int nx2 = ctx->vision_model.hparams.image_size;
-    const int ny2 = ctx->vision_model.hparams.image_size;
-    clip_image_f32_ptr res(clip_image_f32_init());
-    res->nx = nx2;
-    res->ny = ny2;
-    res->buf.resize(3 * nx2 * ny2);
-
-    const float scale = std::max(nx, ny) / (float)ctx->vision_model.hparams.image_size;
-
-    const int nx3 = int(nx / scale + 0.5f);
-    const int ny3 = int(ny / scale + 0.5f);
-
-    const auto & m3 = ctx->image_mean; // {0.48145466f, 0.4578275f, 0.40821073f};
-    const auto & s3 = ctx->image_std;  // {0.26862954f, 0.26130258f, 0.27577711f};
-
-    for (int y = 0; y < ny3; y++) {
-        for (int x = 0; x < nx3; x++) {
-            for (int c = 0; c < 3; c++) {
-                // linear interpolation
-                const float sx = (x + 0.5f) * scale - 0.5f;
-                const float sy = (y + 0.5f) * scale - 0.5f;
-
-                const int x0 = std::max(0, (int)std::floor(sx));
-                const int y0 = std::max(0, (int)std::floor(sy));
-
-                const int x1 = std::min(x0 + 1, nx - 1);
-                const int y1 = std::min(y0 + 1, ny - 1);
-
-                const float dx = sx - x0;
-                const float dy = sy - y0;
-
-                const int j00 = 3 * (y0 * nx + x0) + c;
-                const int j01 = 3 * (y0 * nx + x1) + c;
-                const int j10 = 3 * (y1 * nx + x0) + c;
-                const int j11 = 3 * (y1 * nx + x1) + c;
-
-                const float v00 = temp->buf[j00];
-                const float v01 = temp->buf[j01];
-                const float v10 = temp->buf[j10];
-                const float v11 = temp->buf[j11];
-
-                const float v0 = v00 * (1.0f - dx) + v01 * dx;
-                const float v1 = v10 * (1.0f - dx) + v11 * dx;
-
-                const float v = v0 * (1.0f - dy) + v1 * dy;
-
-                const uint8_t v2 = std::min(std::max(std::round(v), 0.0f), 255.0f);
-
-                const int i = 3 * (y * nx3 + x) + c;
-
-                res->buf[i] = ((float(v2) / 255.0f) - m3[c]) / s3[c];
-            }
-        }
-    }
-
-    // {
-    //     clip_image_u8 * temp2 = clip_image_u8_init();
-    //     clip_image_convert_f32_to_u8(*res, *temp2);
-    //     clip_image_save_to_bmp(*temp2, "resized_normalized_f32_vanilla.bmp");
-    //     clip_image_u8_free(temp2);
-    // }
-    // res_imgs.push_back(res);
-
-    res_imgs->entries.push_back(std::move(res));
-
-    return true;
+    GGML_ASSERT(false && "Unknown image preprocessing type");
 }
 
 ggml_tensor * clip_get_newline_tensor(const struct clip_ctx * ctx) {

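Editor's note: the slice-grid selection in llava_uhd::get_best_grid above is easiest to see with concrete numbers. The following is a minimal standalone C sketch of the same criterion (an illustration only, not code from the patch; constants and the example image size are chosen for demonstration):

    #include <math.h>
    #include <stdio.h>

    /* Restatement of llava_uhd::get_best_grid: among grids (m x n/m) with
       n in {multiple-1, multiple, multiple+1}, pick the grid minimizing
       |log(image_aspect) - log(grid_aspect)|. */
    int main(void) {
        const int max_slice_nums   = 9;
        const int scale_resolution = 448;
        const int img_w = 800, img_h = 1200;  /* example portrait image */

        const float ratio = 1.0f * img_w * img_h / (scale_resolution * scale_resolution);
        int multiple = (int) ceilf(ratio);    /* ceil(4.78) = 5 */
        if (multiple > max_slice_nums) multiple = max_slice_nums;
        const float log_ratio = logf(1.0f * img_w / img_h);  /* ~ -0.405 */

        int   best_w = 1, best_h = 1;
        float best_err = INFINITY;
        for (int n = multiple - 1; n <= multiple + 1; ++n) {
            if (n == 1 || n > max_slice_nums) continue;
            for (int m = 1; m <= n; ++m) {
                if (n % m != 0) continue;   /* only exact factorizations of n */
                const float err = fabsf(log_ratio - logf(1.0f * m / (n / m)));
                if (err < best_err) { best_err = err; best_w = m; best_h = n / m; }
            }
        }
        /* For 800x1200 this prints 2 x 3: log(2/3) equals log(800/1200) exactly. */
        printf("best grid: %d x %d\n", best_w, best_h);
        return 0;
    }
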
From fb28f4f80efb5a2990a6a240433b02e259b0dbb1 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Sigbj=C3=B8rn=20Skj=C3=A6ret?= 
Date: Sat, 19 Apr 2025 16:26:38 +0200
Subject: [PATCH 16/39] gguf-py : fix upload python package workflow (#13020)

---
 gguf-py/pyproject.toml | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/gguf-py/pyproject.toml b/gguf-py/pyproject.toml
index 9745d3ea84..0c82725677 100644
--- a/gguf-py/pyproject.toml
+++ b/gguf-py/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "gguf"
-version = "0.16.1"
+version = "0.16.2"
 description = "Read and write ML models in GGUF for GGML"
 authors = ["GGML "]
 packages = [
@@ -23,7 +23,7 @@ numpy = ">=1.17"
 tqdm = ">=4.27"
 pyyaml = ">=5.1"
 sentencepiece = ">=0.1.98,<=0.2.0"
-PySide6 = { version = "^6.9", optional = true }
+PySide6 = { version = "^6.9", python = ">=3.9,<3.14", optional = true }
 
 [tool.poetry.dev-dependencies]
 pytest = "^5.2"

From 00137157fca3d17b90380762b4d7cc158d385bd3 Mon Sep 17 00:00:00 2001
From: bandoti <141645996+bandoti@users.noreply.github.com>
Date: Sat, 19 Apr 2025 13:05:03 -0300
Subject: [PATCH 17/39] Disable CI cross-compile builds (#13022)

---
 .github/workflows/build.yml | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml
index 34417985d2..32c8b7717f 100644
--- a/.github/workflows/build.yml
+++ b/.github/workflows/build.yml
@@ -601,8 +601,9 @@ jobs:
             -DGGML_SYCL_F16=ON
           cmake --build build --config Release -j $(nproc)
 
-  build-linux-cross:
-    uses: ./.github/workflows/build-linux-cross.yml
+# Disabled for now due to sporadic issue syncing.
+#  build-linux-cross:
+#    uses: ./.github/workflows/build-linux-cross.yml
 
   macOS-latest-cmake-ios:
     runs-on: macos-latest

From 4ba9d711ba0a5bf0be00936d3258d4a5e4164d4f Mon Sep 17 00:00:00 2001
From: Jeffrey Morgan 
Date: Sat, 19 Apr 2025 22:28:40 -0700
Subject: [PATCH 18/39] metal: add neg operator (#13029)

---
 ggml/src/ggml-metal/ggml-metal.m     | 15 +++++++++++++++
 ggml/src/ggml-metal/ggml-metal.metal |  7 +++++++
 2 files changed, 22 insertions(+)

diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m
index 85f3ae7bfd..266d8af469 100644
--- a/ggml/src/ggml-metal/ggml-metal.m
+++ b/ggml/src/ggml-metal/ggml-metal.m
@@ -481,6 +481,7 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_SQRT,
     GGML_METAL_KERNEL_TYPE_SIN,
     GGML_METAL_KERNEL_TYPE_COS,
+    GGML_METAL_KERNEL_TYPE_NEG,
     GGML_METAL_KERNEL_TYPE_SUM_ROWS,
     GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
     GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32,
@@ -1159,6 +1160,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQRT,                            sqrt,                            true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIN,                             sin,                             true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS,                             cos,                             true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NEG,                             neg,                             true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS,                        sum_rows,                        true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGMAX,                          argmax,                          true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,                 pool_2d_avg_f32,                 true);
@@ -1320,6 +1322,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
                 case GGML_UNARY_OP_GELU_QUICK:
                 case GGML_UNARY_OP_SILU:
                 case GGML_UNARY_OP_ELU:
+                case GGML_UNARY_OP_NEG:
                     return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
                 default:
                     return false;
@@ -2010,6 +2013,18 @@ static void ggml_metal_encode_node(
 
                     [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                 } break;
+                case GGML_UNARY_OP_NEG:
+                {
+                    id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_NEG].pipeline;
+
+                    [encoder setComputePipelineState:pipeline];
+                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+
+                    const int64_t n = ggml_nelements(dst);
+
+                    [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                } break;
                 default:
                 {
                     GGML_LOG_WARN("%s: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(dst->op));
diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
index dc7eab03ee..8d6e99e621 100644
--- a/ggml/src/ggml-metal/ggml-metal.metal
+++ b/ggml/src/ggml-metal/ggml-metal.metal
@@ -949,6 +949,13 @@ kernel void kernel_cos(
     dst[tpig] = cos(src0[tpig]);
 }
 
+kernel void kernel_neg(
+        device const float * src0,
+        device       float * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = -src0[tpig];
+}
+
 kernel void kernel_sum_rows(
         device const float * src0,
         device       float * dst,

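Editor's note: for context on how the new op is reached, ggml_neg() records GGML_UNARY_OP_NEG in the compute graph, and a backend that reports support for it (as the Metal backend now does in ggml_metal_supports_op) dispatches its kernel when the graph runs. A minimal graph-construction sketch, assuming only the public ggml.h API; how the graph is scheduled onto the Metal backend is outside this snippet:

    #include "ggml.h"

    /* Sketch only: builds a graph containing the new op. Whether
       kernel_neg above is used depends on the graph being scheduled
       to the Metal backend. */
    static struct ggml_cgraph * build_neg_graph(struct ggml_context * ctx, int64_t n) {
        struct ggml_tensor * x  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n);
        struct ggml_tensor * y  = ggml_neg(ctx, x);   /* y[i] = -x[i] */
        struct ggml_cgraph * gf = ggml_new_graph(ctx);
        ggml_build_forward_expand(gf, y);
        return gf;
    }
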
From 66168204be9559dc841f06f0025f3da01d9a8546 Mon Sep 17 00:00:00 2001
From: Jeff Bolz 
Date: Sun, 20 Apr 2025 03:50:02 -0500
Subject: [PATCH 19/39] vulkan: support noncontiguous rms_norm (#13031)

---
 ggml/src/ggml-vulkan/ggml-vulkan.cpp          | 22 ++++++++++---
 .../ggml-vulkan/vulkan-shaders/rms_norm.comp  | 32 ++++++++++++-------
 2 files changed, 39 insertions(+), 15 deletions(-)

diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index 2f6d03c939..39f3cd343a 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -2397,7 +2397,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
 
     ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {1, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_rms_norm_back_f32, "rms_norm_back_f32", rms_norm_back_f32_len, rms_norm_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_l2_norm_f32, "l2_norm_f32", l2_norm_f32_len, l2_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
 
@@ -6006,6 +6006,7 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
     case GGML_OP_REPEAT:
     case GGML_OP_REPEAT_BACK:
     case GGML_OP_ROPE:
+    case GGML_OP_RMS_NORM:
         return true;
     default:
         return false;
@@ -6216,7 +6217,6 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
 
     switch (op) {
     case GGML_OP_NORM:
-    case GGML_OP_RMS_NORM:
     case GGML_OP_RMS_NORM_BACK:
     case GGML_OP_L2_NORM:
     case GGML_OP_SOFT_MAX:
@@ -6233,6 +6233,10 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
                 elements = { nr, 1, 1 };
             }
         } break;
+    case GGML_OP_RMS_NORM:
+        elements = { (uint32_t)ne01, (uint32_t)ne02, (uint32_t)ne03 };
+        break;
+
     case GGML_OP_SUM:
         // We use GGML_OP_SUM_ROWS with 1 row.
         elements = { 1, 1, 1 };
@@ -6883,7 +6887,17 @@ static void ggml_vk_group_norm(ggml_backend_vk_context * ctx, vk_context& subctx
 
 static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
     float * op_params = (float *)dst->op_params;
-    ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_RMS_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun);
+    const uint32_t src0_type_size = ggml_type_size(src0->type);
+    const uint32_t dst_type_size = ggml_type_size(dst->type);
+
+    ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_RMS_NORM, {
+        (uint32_t)ggml_nelements(src0),
+        (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
+        (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] /  dst_type_size, (uint32_t) dst->nb[1] /  dst_type_size, (uint32_t) dst->nb[2] /  dst_type_size, (uint32_t) dst->nb[3] /  dst_type_size,
+        0,
+        op_params[0], 0.0f,
+        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+    }, dryrun);
 }
 
 static void ggml_vk_rms_norm_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
@@ -9388,10 +9402,10 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
         case GGML_OP_VIEW:
         case GGML_OP_PERMUTE:
         case GGML_OP_TRANSPOSE:
+        case GGML_OP_RMS_NORM:
             return true;
         case GGML_OP_NORM:
         case GGML_OP_GROUP_NORM:
-        case GGML_OP_RMS_NORM:
         case GGML_OP_L2_NORM:
             return ggml_is_contiguous(op->src[0]);
         case GGML_OP_ADD:
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp
index b554400ba3..deb8ee9960 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp
@@ -1,6 +1,6 @@
 #version 450
 
-#include "generic_head.comp"
+#include "generic_unary_head.comp"
 #include "types.comp"
 
 #extension GL_EXT_control_flow_attributes : enable
@@ -8,19 +8,29 @@
 
 layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
 
-layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
-layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
-
 shared FLOAT_TYPE sum[BLOCK_SIZE];
 
 void main() {
-    const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
-    const uint tid = gl_LocalInvocationID.x;
+    const uint ncols     = p.ne00;
+    const uint nrows     = gl_NumWorkGroups.x;
+    const uint nchannels = gl_NumWorkGroups.y;
+
+    const uint row       = gl_WorkGroupID.x;
+    const uint channel   = gl_WorkGroupID.y;
+    const uint samp      = gl_WorkGroupID.z;
+    const uint tid       = gl_LocalInvocationID.x;
+
+    const uint stride_row       = p.nb01;
+    const uint stride_channel   = p.nb02;
+    const uint stride_sample    = p.nb03;
+
+    uint32_t a_offset = samp*stride_sample + channel*stride_channel + row*stride_row + get_aoffset();
+    uint32_t d_offset = ((samp*nchannels + channel)*nrows + row)*ncols + get_doffset();
 
     sum[tid] = FLOAT_TYPE(0.0f); // partial sum for thread in warp
 
-    [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
-        const FLOAT_TYPE xi = FLOAT_TYPE(data_a[row*p.KX + col]);
+    [[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) {
+        const FLOAT_TYPE xi = FLOAT_TYPE(data_a[a_offset + col]);
         sum[tid] += xi * xi;
     }
 
@@ -33,10 +43,10 @@ void main() {
         barrier();
     }
 
-    const FLOAT_TYPE mean = sum[0] / FLOAT_TYPE(p.KX);
+    const FLOAT_TYPE mean = sum[0] / FLOAT_TYPE(ncols);
     const FLOAT_TYPE scale = inversesqrt(mean + FLOAT_TYPE(p.param1));
 
-    [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
-        data_d[row*p.KX + col] = D_TYPE(scale * FLOAT_TYPE(data_a[row*p.KX + col]));
+    [[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) {
+        data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]));
     }
 }

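Editor's note: per row, the shader computes plain RMS normalization: y[c] = x[c] / sqrt(mean_c(x[c]^2) + eps). The noncontiguity this patch handles is between rows/channels/samples (the a_offset built from nb01/nb02/nb03); columns within a row stay contiguous, and the output is written densely (d_offset). A scalar C reference for one row (a sketch, not the project's CPU path):

    #include <math.h>
    #include <stddef.h>

    /* x_row points at the start of one (possibly noncontiguous) source row,
       as computed by the caller from the nb01/nb02/nb03 strides;
       y_row is a dense output row of the same length. */
    static void rms_norm_row(const float * x_row, float * y_row, size_t ncols, float eps) {
        float sumsq = 0.0f;
        for (size_t c = 0; c < ncols; ++c) {
            sumsq += x_row[c] * x_row[c];     /* partial sums, like sum[tid] in the shader */
        }
        const float scale = 1.0f / sqrtf(sumsq / (float) ncols + eps);
        for (size_t c = 0; c < ncols; ++c) {
            y_row[c] = scale * x_row[c];
        }
    }
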
From 6602304814e679cc8c162bb760a034aceb4f8965 Mon Sep 17 00:00:00 2001
From: Jeffrey Morgan 
Date: Sun, 20 Apr 2025 03:15:41 -0700
Subject: [PATCH 20/39] llava: fix errors in clip.h on certain compilers
 (#13030)

---
 examples/llava/clip.h | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/examples/llava/clip.h b/examples/llava/clip.h
index cc133a58de..5fc45d3e23 100644
--- a/examples/llava/clip.h
+++ b/examples/llava/clip.h
@@ -30,12 +30,13 @@ struct clip_image_size {
     int height;
 };
 
+struct clip_image_f32;
 struct clip_image_u8_batch;
 struct clip_image_f32_batch;
 
 struct clip_context_params {
     bool use_gpu;
-    ggml_log_level verbosity;
+    enum ggml_log_level verbosity;
 };
 
 // deprecated, use clip_init
@@ -84,7 +85,7 @@ CLIP_API void clip_image_f32_batch_free(struct clip_image_f32_batch * batch);
 CLIP_API size_t clip_image_f32_batch_n_images(const struct clip_image_f32_batch * batch); // equivalent to batch->size()
 CLIP_API size_t clip_image_f32_batch_nx(const struct clip_image_f32_batch * batch, int idx); // equivalent to batch[idx]->nx
 CLIP_API size_t clip_image_f32_batch_ny(const struct clip_image_f32_batch * batch, int idx); // equivalent to batch[idx]->ny
-CLIP_API clip_image_f32 * clip_image_f32_get_img(const struct clip_image_f32_batch * batch, int idx); // equivalent to batch[idx]->data
+CLIP_API struct clip_image_f32 * clip_image_f32_get_img(const struct clip_image_f32_batch * batch, int idx); // equivalent to batch[idx]->data
 
 /**
  * Build image from pixels decoded by other libraries instead of stb_image.h for better performance.

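Editor's note: the fix boils down to C vs C++ name lookup. C++ makes a class or enum name usable on its own, while C only introduces the tag, so a header meant to compile as both must spell out struct/enum and forward-declare clip_image_f32 before using it in a prototype. A compilable C illustration with stand-in declarations (the names mirror clip.h/ggml.h but are re-declared here only for the example):

    /* Stand-ins for the real declarations in ggml.h / clip.h: */
    enum ggml_log_level { GGML_LOG_LEVEL_INFO };
    struct clip_image_f32;                /* forward declaration */

    struct clip_context_params_example {
        int use_gpu;
        enum ggml_log_level verbosity;    /* C requires the `enum` keyword here */
    };

    /* OK in both C and C++ because the tag is spelled out: */
    struct clip_image_f32 * get_img_example(void);

    /* The same prototype without `struct` compiles in C++ but fails in C:
       clip_image_f32 * broken(void);     // error: unknown type name */
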
From 2016f07bd106c73699ecbaace80f55db5ed95dac Mon Sep 17 00:00:00 2001
From: Xuan-Son Nguyen 
Date: Sun, 20 Apr 2025 23:29:36 +0200
Subject: [PATCH 21/39] convert : experimental support for `--mmproj` flag
 (#13023)

* convert : experimental support for `--mmproj` flag

* fix bad ctrl+f replace

* fix style

* split into subclasses TextModel and VisionModel

* rename Model --> ModelBase

* small fix

* correct CLIP_VISION arch name (because existing GGUF already use it)

* Apply suggestions from code review

Co-authored-by: compilade 

* fix Mistral3Model

* fix typo

Co-authored-by: compilade 

---------

Co-authored-by: compilade 
---
 convert_hf_to_gguf.py          | 666 +++++++++++++++++++--------------
 convert_lora_to_gguf.py        |   6 +-
 examples/llava/clip-impl.h     |   3 -
 gguf-py/gguf/constants.py      | 139 ++++++-
 gguf-py/gguf/tensor_mapping.py | 144 +++++++
 5 files changed, 663 insertions(+), 295 deletions(-)

diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index 89522dee8b..6d34541a3c 100755
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -42,11 +42,19 @@ class SentencePieceTokenTypes(IntEnum):
     BYTE = 6
 
 
-AnyModel = TypeVar("AnyModel", bound="type[Model]")
+class ModelType(IntEnum):
+    TEXT = 1
+    VISION = 2
 
 
-class Model:
-    _model_classes: dict[str, type[Model]] = {}
+AnyModel = TypeVar("AnyModel", bound="type[ModelBase]")
+
+
+class ModelBase:
+    _model_classes: dict[ModelType, dict[str, type[ModelBase]]] = {
+        ModelType.TEXT: {},
+        ModelType.VISION: {},
+    }
 
     dir_model: Path
     ftype: gguf.LlamaFileType
@@ -75,7 +83,9 @@ class Model:
                  metadata_override: Path | None = None, model_name: str | None = None,
                  split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False,
                  small_first_shard: bool = False, hparams: dict[str, Any] | None = None, remote_hf_model_id: str | None = None):
-        if type(self) is Model:
+        if type(self) is ModelBase or \
+                type(self) is TextModel or \
+                type(self) is VisionModel:
             raise TypeError(f"{type(self).__name__!r} should not be directly instantiated")
 
         self.dir_model = dir_model
@@ -98,11 +108,11 @@ class Model:
 
             self.get_tensors = get_remote_tensors
         else:
-            self.part_names = Model.get_model_part_names(self.dir_model, "model", ".safetensors")
+            self.part_names = ModelBase.get_model_part_names(self.dir_model, "model", ".safetensors")
             self.is_safetensors = len(self.part_names) > 0
             if not self.is_safetensors:
-                self.part_names = Model.get_model_part_names(self.dir_model, "pytorch_model", ".bin")
-        self.hparams = Model.load_hparams(self.dir_model) if hparams is None else hparams
+                self.part_names = ModelBase.get_model_part_names(self.dir_model, "pytorch_model", ".bin")
+        self.hparams = ModelBase.load_hparams(self.dir_model) if hparams is None else hparams
         self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer", "num_layers"])
         self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
         self.tensor_names = None
@@ -126,11 +136,10 @@ class Model:
                                            split_max_tensors=split_max_tensors, split_max_size=split_max_size, dry_run=dry_run, small_first_shard=small_first_shard)
 
     @classmethod
-    def __init_subclass__(cls):
-        # can't use an abstract property, because overriding it without type errors
-        # would require using decorated functions instead of simply defining the property
-        if "model_arch" not in cls.__dict__:
-            raise TypeError(f"Missing property 'model_arch' for {cls.__name__!r}")
+    def add_prefix_to_filename(cls, path: Path, prefix: str) -> Path:
+        stem, suffix = path.stem, path.suffix
+        new_name = f"{prefix}{stem}{suffix}"
+        return path.with_name(new_name)
 
     def find_hparam(self, keys: Iterable[str], optional: bool = False) -> Any:
         key = next((k for k in keys if k in self.hparams), None)
@@ -140,9 +149,6 @@ class Model:
             return None
         raise KeyError(f"could not find any of: {keys}")
 
-    def set_vocab(self):
-        self._set_vocab_gpt2()
-
     def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
         tensor_names_from_parts: set[str] = set()
 
@@ -230,50 +236,7 @@ class Model:
         return new_name
 
     def set_gguf_parameters(self):
-        self.gguf_writer.add_block_count(self.block_count)
-
-        if (n_ctx := self.find_hparam(["max_position_embeddings", "n_ctx"], optional=True)) is not None:
-            self.gguf_writer.add_context_length(n_ctx)
-            logger.info(f"gguf: context length = {n_ctx}")
-
-        if (n_embd := self.find_hparam(["hidden_size", "n_embd"], optional=True)) is not None:
-            self.gguf_writer.add_embedding_length(n_embd)
-            logger.info(f"gguf: embedding length = {n_embd}")
-
-        if (n_ff := self.find_hparam(["intermediate_size", "n_inner"], optional=True)) is not None:
-            self.gguf_writer.add_feed_forward_length(n_ff)
-            logger.info(f"gguf: feed forward length = {n_ff}")
-
-        if (n_head := self.find_hparam(["num_attention_heads", "n_head"], optional=True)) is not None:
-            self.gguf_writer.add_head_count(n_head)
-            logger.info(f"gguf: head count = {n_head}")
-
-        if (n_head_kv := self.hparams.get("num_key_value_heads")) is not None:
-            self.gguf_writer.add_head_count_kv(n_head_kv)
-            logger.info(f"gguf: key-value head count = {n_head_kv}")
-
-        if (rope_theta := self.hparams.get("rope_theta")) is not None:
-            self.gguf_writer.add_rope_freq_base(rope_theta)
-            logger.info(f"gguf: rope theta = {rope_theta}")
-        if (f_rms_eps := self.hparams.get("rms_norm_eps")) is not None:
-            self.gguf_writer.add_layer_norm_rms_eps(f_rms_eps)
-            logger.info(f"gguf: rms norm epsilon = {f_rms_eps}")
-        if (f_norm_eps := self.find_hparam(["layer_norm_eps", "layer_norm_epsilon", "norm_epsilon"], optional=True)) is not None:
-            self.gguf_writer.add_layer_norm_eps(f_norm_eps)
-            logger.info(f"gguf: layer norm epsilon = {f_norm_eps}")
-        if (n_experts := self.hparams.get("num_local_experts")) is not None:
-            self.gguf_writer.add_expert_count(n_experts)
-            logger.info(f"gguf: expert count = {n_experts}")
-        if (n_experts_used := self.hparams.get("num_experts_per_tok")) is not None:
-            self.gguf_writer.add_expert_used_count(n_experts_used)
-            logger.info(f"gguf: experts used count = {n_experts_used}")
-
-        if (head_dim := self.hparams.get("head_dim")) is not None:
-            self.gguf_writer.add_key_length(head_dim)
-            self.gguf_writer.add_value_length(head_dim)
-
-        self.gguf_writer.add_file_type(self.ftype)
-        logger.info(f"gguf: file type = {self.ftype}")
+        raise NotImplementedError("set_gguf_parameters() must be implemented in subclasses")
 
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         del bid  # unused
@@ -419,6 +382,88 @@ class Model:
         if self.metadata.size_label is None and total_params > 0:
             self.metadata.size_label = gguf.size_label(total_params, shared_params, expert_params, expert_count)
 
+        self.set_type()
+
+        logger.info("Set meta model")
+        self.metadata.set_gguf_meta_model(self.gguf_writer)
+
+        logger.info("Set model parameters")
+        self.set_gguf_parameters()
+
+        logger.info("Set model quantization version")
+        self.gguf_writer.add_quantization_version(gguf.GGML_QUANT_VERSION)
+
+    def write_vocab(self):
+        raise NotImplementedError("write_vocab() must be implemented in subclasses")
+
+    def write(self):
+        self.prepare_tensors()
+        self.prepare_metadata(vocab_only=False)
+        self.gguf_writer.write_header_to_file(path=self.fname_out)
+        self.gguf_writer.write_kv_data_to_file()
+        self.gguf_writer.write_tensors_to_file(progress=True)
+        self.gguf_writer.close()
+
+    @staticmethod
+    def get_model_part_names(dir_model: Path, prefix: str, suffix: str) -> list[str]:
+        part_names: list[str] = []
+        for filename in os.listdir(dir_model):
+            if filename.startswith(prefix) and filename.endswith(suffix):
+                part_names.append(filename)
+
+        part_names.sort()
+
+        return part_names
+
+    @staticmethod
+    def load_hparams(dir_model: Path):
+        with open(dir_model / "config.json", "r", encoding="utf-8") as f:
+            hparams = json.load(f)
+            if "text_config" in hparams:
+                hparams = {**hparams, **hparams["text_config"]}
+            return hparams
+
+    @classmethod
+    def register(cls, *names: str) -> Callable[[AnyModel], AnyModel]:
+        assert names
+
+        def func(modelcls: AnyModel) -> AnyModel:
+            model_type = ModelType.VISION if modelcls.model_arch == gguf.MODEL_ARCH.CLIP_VISION else ModelType.TEXT
+            for name in names:
+                cls._model_classes[model_type][name] = modelcls
+            return modelcls
+        return func
+
+    @classmethod
+    def print_registered_models(cls):
+        for model_type, model_classes in cls._model_classes.items():
+            logger.error(f"{model_type.name} models:")
+            for name in sorted(model_classes.keys()):
+                logger.error(f"  - {name}")
+
+    @classmethod
+    def from_model_architecture(cls, arch: str, model_type = ModelType.TEXT) -> type[ModelBase]:
+        try:
+            return cls._model_classes[model_type][arch]
+        except KeyError:
+            raise NotImplementedError(f'Architecture {arch!r} not supported!') from None
+
+
+class TextModel(ModelBase):
+    @classmethod
+    def __init_subclass__(cls):
+        # can't use an abstract property, because overriding it without type errors
+        # would require using decorated functions instead of simply defining the property
+        if "model_arch" not in cls.__dict__:
+            raise TypeError(f"Missing property 'model_arch' for {cls.__name__!r}")
+
+    def set_vocab(self):
+        self._set_vocab_gpt2()
+
+    def prepare_metadata(self, vocab_only: bool):
+        super().prepare_metadata(vocab_only=vocab_only)
+
+        total_params = self.gguf_writer.get_total_parameter_count()[0]
         # Extract the encoding scheme from the file type name. e.g. 'gguf.LlamaFileType.MOSTLY_Q8_0' --> 'Q8_0'
         output_type: str = self.ftype.name.partition("_")[2]
 
@@ -440,27 +485,54 @@ class Model:
             # Process templated file name with the output ftype, useful with the "auto" ftype
             self.fname_out = self.fname_out.parent / gguf.fill_templated_filename(self.fname_out.name, output_type)
 
-        self.set_type()
-
-        logger.info("Set meta model")
-        self.metadata.set_gguf_meta_model(self.gguf_writer)
-
-        logger.info("Set model parameters")
-        self.set_gguf_parameters()
-
         logger.info("Set model tokenizer")
         self.set_vocab()
 
-        logger.info("Set model quantization version")
-        self.gguf_writer.add_quantization_version(gguf.GGML_QUANT_VERSION)
+    def set_gguf_parameters(self):
+        self.gguf_writer.add_block_count(self.block_count)
 
-    def write(self):
-        self.prepare_tensors()
-        self.prepare_metadata(vocab_only=False)
-        self.gguf_writer.write_header_to_file(path=self.fname_out)
-        self.gguf_writer.write_kv_data_to_file()
-        self.gguf_writer.write_tensors_to_file(progress=True)
-        self.gguf_writer.close()
+        if (n_ctx := self.find_hparam(["max_position_embeddings", "n_ctx"], optional=True)) is not None:
+            self.gguf_writer.add_context_length(n_ctx)
+            logger.info(f"gguf: context length = {n_ctx}")
+
+        if (n_embd := self.find_hparam(["hidden_size", "n_embd"], optional=True)) is not None:
+            self.gguf_writer.add_embedding_length(n_embd)
+            logger.info(f"gguf: embedding length = {n_embd}")
+
+        if (n_ff := self.find_hparam(["intermediate_size", "n_inner"], optional=True)) is not None:
+            self.gguf_writer.add_feed_forward_length(n_ff)
+            logger.info(f"gguf: feed forward length = {n_ff}")
+
+        if (n_head := self.find_hparam(["num_attention_heads", "n_head"], optional=True)) is not None:
+            self.gguf_writer.add_head_count(n_head)
+            logger.info(f"gguf: head count = {n_head}")
+
+        if (n_head_kv := self.hparams.get("num_key_value_heads")) is not None:
+            self.gguf_writer.add_head_count_kv(n_head_kv)
+            logger.info(f"gguf: key-value head count = {n_head_kv}")
+
+        if (rope_theta := self.hparams.get("rope_theta")) is not None:
+            self.gguf_writer.add_rope_freq_base(rope_theta)
+            logger.info(f"gguf: rope theta = {rope_theta}")
+        if (f_rms_eps := self.hparams.get("rms_norm_eps")) is not None:
+            self.gguf_writer.add_layer_norm_rms_eps(f_rms_eps)
+            logger.info(f"gguf: rms norm epsilon = {f_rms_eps}")
+        if (f_norm_eps := self.find_hparam(["layer_norm_eps", "layer_norm_epsilon", "norm_epsilon"], optional=True)) is not None:
+            self.gguf_writer.add_layer_norm_eps(f_norm_eps)
+            logger.info(f"gguf: layer norm epsilon = {f_norm_eps}")
+        if (n_experts := self.hparams.get("num_local_experts")) is not None:
+            self.gguf_writer.add_expert_count(n_experts)
+            logger.info(f"gguf: expert count = {n_experts}")
+        if (n_experts_used := self.hparams.get("num_experts_per_tok")) is not None:
+            self.gguf_writer.add_expert_used_count(n_experts_used)
+            logger.info(f"gguf: experts used count = {n_experts_used}")
+
+        if (head_dim := self.hparams.get("head_dim")) is not None:
+            self.gguf_writer.add_key_length(head_dim)
+            self.gguf_writer.add_value_length(head_dim)
+
+        self.gguf_writer.add_file_type(self.ftype)
+        logger.info(f"gguf: file type = {self.ftype}")
 
     def write_vocab(self):
         if len(self.gguf_writer.tensors) != 1:
@@ -471,44 +543,6 @@ class Model:
         self.gguf_writer.write_kv_data_to_file()
         self.gguf_writer.close()
 
-    @staticmethod
-    def get_model_part_names(dir_model: Path, prefix: str, suffix: str) -> list[str]:
-        part_names: list[str] = []
-        for filename in os.listdir(dir_model):
-            if filename.startswith(prefix) and filename.endswith(suffix):
-                part_names.append(filename)
-
-        part_names.sort()
-
-        return part_names
-
-    @staticmethod
-    def load_hparams(dir_model: Path):
-        with open(dir_model / "config.json", "r", encoding="utf-8") as f:
-            return json.load(f)
-
-    @classmethod
-    def register(cls, *names: str) -> Callable[[AnyModel], AnyModel]:
-        assert names
-
-        def func(modelcls: AnyModel) -> AnyModel:
-            for name in names:
-                cls._model_classes[name] = modelcls
-            return modelcls
-        return func
-
-    @classmethod
-    def print_registered_models(cls):
-        for name in sorted(cls._model_classes.keys()):
-            logger.error(f"- {name}")
-
-    @classmethod
-    def from_model_architecture(cls, arch: str) -> type[Model]:
-        try:
-            return cls._model_classes[arch]
-        except KeyError:
-            raise NotImplementedError(f'Architecture {arch!r} not supported!') from None
-
     def does_token_look_special(self, token: str | bytes) -> bool:
         if isinstance(token, (bytes, bytearray)):
             token_text = token.decode(encoding="utf-8")
@@ -1024,8 +1058,48 @@ class Model:
             self.gguf_writer.add_add_eos_token(field.parts[-1].tolist()[0])
 
 
-@Model.register("GPTNeoXForCausalLM")
-class GPTNeoXModel(Model):
+class VisionModel(ModelBase):
+    model_arch = gguf.MODEL_ARCH.CLIP_VISION
+    n_text_embd = 0
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        if self.model_arch != gguf.MODEL_ARCH.CLIP_VISION:
+            raise TypeError("VisionModel must be subclassed with model_arch = gguf.MODEL_ARCH.CLIP_VISION")
+
+        # small hack to correct the number of layers
+        self.tensor_map = gguf.get_tensor_name_map(gguf.MODEL_ARCH.CLIP_VISION, 128)
+        self.n_embd_text = self.find_hparam(["hidden_size", "n_embd"])
+        assert self.n_embd_text > 0, "n_embd not found in hparams"
+
+        if "vision_config" not in self.hparams:
+            raise ValueError("vision_config not found in hparams")
+        # move vision config to the top level
+        self.hparams = self.hparams["vision_config"]
+
+    def set_type(self):
+        self.gguf_writer.add_type(gguf.GGUFType.CLIP_VISION)
+
+    def set_gguf_parameters(self):
+        self.gguf_writer.add_file_type(self.ftype)
+        self.gguf_writer.add_uint32(gguf.Keys.ClipVision.PROJECTION_DIM, self.n_embd_text)
+        self.gguf_writer.add_bool(gguf.Keys.ClipVision.HAS_VISION_ENCODER, True)
+
+        # vision config
+        self.gguf_writer.add_uint32(gguf.Keys.ClipVision.IMAGE_SIZE,           self.find_hparam(["image_size"]))
+        self.gguf_writer.add_uint32(gguf.Keys.ClipVision.PATCH_SIZE,           self.find_hparam(["patch_size"]))
+        self.gguf_writer.add_uint32(gguf.Keys.ClipVision.EMBEDDING_LENGTH,     self.find_hparam(["hidden_size"]))
+        self.gguf_writer.add_uint32(gguf.Keys.ClipVision.FEED_FORWARD_LENGTH,  self.find_hparam(["intermediate_size"]))
+        self.gguf_writer.add_uint32(gguf.Keys.ClipVision.BLOCK_COUNT,          self.find_hparam(["num_hidden_layers"]))
+        self.gguf_writer.add_uint32(gguf.Keys.ClipVision.Attention.HEAD_COUNT, self.find_hparam(["num_attention_heads"]))
+
+    def write_vocab(self):
+        raise ValueError("VisionModel does not support vocab writing")
+
+
+@ModelBase.register("GPTNeoXForCausalLM")
+class GPTNeoXModel(TextModel):
     model_arch = gguf.MODEL_ARCH.GPTNEOX
 
     def set_gguf_parameters(self):
@@ -1081,8 +1155,8 @@ class GPTNeoXModel(Model):
         return tensors
 
 
-@Model.register("BloomForCausalLM", "BloomModel")
-class BloomModel(Model):
+@ModelBase.register("BloomForCausalLM", "BloomModel")
+class BloomModel(TextModel):
     model_arch = gguf.MODEL_ARCH.BLOOM
 
     def set_gguf_parameters(self):
@@ -1138,8 +1212,8 @@ class BloomModel(Model):
         return tensors
 
 
-@Model.register("MPTForCausalLM")
-class MPTModel(Model):
+@ModelBase.register("MPTForCausalLM")
+class MPTModel(TextModel):
     model_arch = gguf.MODEL_ARCH.MPT
 
     def set_vocab(self):
@@ -1182,8 +1256,8 @@ class MPTModel(Model):
         return [(new_name, data_torch)]
 
 
-@Model.register("OrionForCausalLM")
-class OrionModel(Model):
+@ModelBase.register("OrionForCausalLM")
+class OrionModel(TextModel):
     model_arch = gguf.MODEL_ARCH.ORION
 
     def set_vocab(self):
@@ -1217,8 +1291,8 @@ class OrionModel(Model):
         self.gguf_writer.add_layer_norm_eps(self.hparams["rms_norm_eps"])
 
 
-@Model.register("BaichuanForCausalLM", "BaiChuanForCausalLM")
-class BaichuanModel(Model):
+@ModelBase.register("BaichuanForCausalLM", "BaiChuanForCausalLM")
+class BaichuanModel(TextModel):
     model_arch = gguf.MODEL_ARCH.BAICHUAN
 
     def set_vocab(self):
@@ -1297,8 +1371,8 @@ class BaichuanModel(Model):
         return weights[r * n_part:r * n_part + r, ...]
 
 
-@Model.register("XverseForCausalLM")
-class XverseModel(Model):
+@ModelBase.register("XverseForCausalLM")
+class XverseModel(TextModel):
     model_arch = gguf.MODEL_ARCH.XVERSE
 
     def set_vocab(self):
@@ -1404,8 +1478,8 @@ class XverseModel(Model):
         )
 
 
-@Model.register("FalconForCausalLM", "RWForCausalLM")
-class FalconModel(Model):
+@ModelBase.register("FalconForCausalLM", "RWForCausalLM")
+class FalconModel(TextModel):
     model_arch = gguf.MODEL_ARCH.FALCON
 
     def set_gguf_parameters(self):
@@ -1458,8 +1532,8 @@ class FalconModel(Model):
         return [(self.map_tensor_name(name), data_torch)]
 
 
-@Model.register("GPTBigCodeForCausalLM")
-class StarCoderModel(Model):
+@ModelBase.register("GPTBigCodeForCausalLM")
+class StarCoderModel(TextModel):
     model_arch = gguf.MODEL_ARCH.STARCODER
 
     def set_gguf_parameters(self):
@@ -1475,8 +1549,8 @@ class StarCoderModel(Model):
         self.gguf_writer.add_file_type(self.ftype)
 
 
-@Model.register("GPTRefactForCausalLM")
-class RefactModel(Model):
+@ModelBase.register("GPTRefactForCausalLM")
+class RefactModel(TextModel):
     model_arch = gguf.MODEL_ARCH.REFACT
 
     def set_vocab(self):
@@ -1539,8 +1613,8 @@ class RefactModel(Model):
         return tensors
 
 
-@Model.register("StableLmForCausalLM", "StableLMEpochForCausalLM", "LlavaStableLMEpochForCausalLM")
-class StableLMModel(Model):
+@ModelBase.register("StableLmForCausalLM", "StableLMEpochForCausalLM", "LlavaStableLMEpochForCausalLM")
+class StableLMModel(TextModel):
     model_arch = gguf.MODEL_ARCH.STABLELM
 
     def set_vocab(self):
@@ -1629,8 +1703,8 @@ class StableLMModel(Model):
                 raise ValueError(f"Unprocessed norms: {norms}")
 
 
-@Model.register("LLaMAForCausalLM", "LlamaForCausalLM", "MistralForCausalLM", "MixtralForCausalLM")
-class LlamaModel(Model):
+@ModelBase.register("LLaMAForCausalLM", "LlamaForCausalLM", "MistralForCausalLM", "MixtralForCausalLM")
+class LlamaModel(TextModel):
     model_arch = gguf.MODEL_ARCH.LLAMA
     undo_permute = True
 
@@ -1778,23 +1852,13 @@ class LlamaModel(Model):
                 raise ValueError(f"Unprocessed experts: {experts}")
 
 
-@Model.register("Llama4ForConditionalGeneration")
+@ModelBase.register("Llama4ForConditionalGeneration")
 class Llama4Model(LlamaModel):
     model_arch = gguf.MODEL_ARCH.LLAMA4
-    has_vision: bool = False
     undo_permute = False
 
-    # TODO @ngxson : avoid duplicate this code everywhere by at least support "text_config"
-    # same with llama, but we need to merge the text_config into the root level of hparams
     def __init__(self, *args, **kwargs):
-        hparams = kwargs["hparams"] if "hparams" in kwargs else Model.load_hparams(args[0])
-        if "text_config" in hparams:
-            hparams = {**hparams, **hparams["text_config"]}
-            kwargs["hparams"] = hparams
         super().__init__(*args, **kwargs)
-        if "vision_config" in hparams:
-            logger.info("Has vision encoder, but it will be ignored")
-            self.has_vision = True
         # IMPORTANT: the normal "intermediate_size" is renamed to "intermediate_size_mlp", we need to undo this
         self.hparams["intermediate_size_moe"] = self.hparams["intermediate_size"]
         self.hparams["intermediate_size"] = self.hparams["intermediate_size_mlp"]
@@ -1829,18 +1893,10 @@ class Llama4Model(LlamaModel):
         return super().modify_tensors(data_torch, name, bid)
 
 
-@Model.register("Mistral3ForConditionalGeneration")
+@ModelBase.register("Mistral3ForConditionalGeneration")
 class Mistral3Model(LlamaModel):
     model_arch = gguf.MODEL_ARCH.LLAMA
 
-    # we need to merge the text_config into the root level of hparams
-    def __init__(self, *args, **kwargs):
-        hparams = kwargs["hparams"] if "hparams" in kwargs else Model.load_hparams(args[0])
-        if "text_config" in hparams:
-            hparams = {**hparams, **hparams["text_config"]}
-            kwargs["hparams"] = hparams
-        super().__init__(*args, **kwargs)
-
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None):
         name = name.replace("language_model.", "")
         if "multi_modal_projector" in name or "vision_tower" in name:
@@ -1848,8 +1904,8 @@ class Mistral3Model(LlamaModel):
         return super().modify_tensors(data_torch, name, bid)
 
 
-@Model.register("DeciLMForCausalLM")
-class DeciModel(Model):
+@ModelBase.register("DeciLMForCausalLM")
+class DeciModel(TextModel):
     model_arch = gguf.MODEL_ARCH.DECI
 
     @staticmethod
@@ -2020,8 +2076,8 @@ class DeciModel(Model):
         super().prepare_tensors()
 
 
-@Model.register("BitnetForCausalLM")
-class BitnetModel(Model):
+@ModelBase.register("BitnetForCausalLM")
+class BitnetModel(TextModel):
     model_arch = gguf.MODEL_ARCH.BITNET
 
     def set_vocab(self):
@@ -2061,8 +2117,8 @@ class BitnetModel(Model):
         yield (new_name, data_torch)
 
 
-@Model.register("GrokForCausalLM")
-class GrokModel(Model):
+@ModelBase.register("GrokForCausalLM")
+class GrokModel(TextModel):
     model_arch = gguf.MODEL_ARCH.GROK
 
     def set_vocab(self):
@@ -2114,8 +2170,8 @@ class GrokModel(Model):
         return [(self.map_tensor_name(name), data_torch)]
 
 
-@Model.register("DbrxForCausalLM")
-class DbrxModel(Model):
+@ModelBase.register("DbrxForCausalLM")
+class DbrxModel(TextModel):
     model_arch = gguf.MODEL_ARCH.DBRX
 
     def set_gguf_parameters(self):
@@ -2183,8 +2239,8 @@ class DbrxModel(Model):
         return n_dims > 1
 
 
-@Model.register("MiniCPMForCausalLM")
-class MiniCPMModel(Model):
+@ModelBase.register("MiniCPMForCausalLM")
+class MiniCPMModel(TextModel):
     model_arch = gguf.MODEL_ARCH.MINICPM
 
     def set_gguf_parameters(self):
@@ -2238,8 +2294,8 @@ class MiniCPMModel(Model):
         return [(self.map_tensor_name(name), data_torch)]
 
 
-@Model.register("MiniCPM3ForCausalLM")
-class MiniCPM3Model(Model):
+@ModelBase.register("MiniCPM3ForCausalLM")
+class MiniCPM3Model(TextModel):
     model_arch = gguf.MODEL_ARCH.MINICPM3
 
     def set_gguf_parameters(self):
@@ -2291,8 +2347,8 @@ class MiniCPM3Model(Model):
         )
 
 
-@Model.register("QWenLMHeadModel")
-class QwenModel(Model):
+@ModelBase.register("QWenLMHeadModel")
+class QwenModel(TextModel):
     model_arch = gguf.MODEL_ARCH.QWEN
 
     @staticmethod
@@ -2333,8 +2389,8 @@ class QwenModel(Model):
         self.gguf_writer.add_file_type(self.ftype)
 
 
-@Model.register("Qwen2ForCausalLM")
-class Qwen2Model(Model):
+@ModelBase.register("Qwen2ForCausalLM")
+class Qwen2Model(TextModel):
     model_arch = gguf.MODEL_ARCH.QWEN2
 
     def set_vocab(self):
@@ -2352,8 +2408,8 @@ class Qwen2Model(Model):
                 self.gguf_writer.add_rope_scaling_orig_ctx_len(self.hparams["rope_scaling"]["original_max_position_embeddings"])
 
 
-@Model.register("Qwen2VLForConditionalGeneration", "Qwen2_5_VLForConditionalGeneration")
-class Qwen2VLModel(Model):
+@ModelBase.register("Qwen2VLForConditionalGeneration", "Qwen2_5_VLForConditionalGeneration")
+class Qwen2VLModel(TextModel):
     model_arch = gguf.MODEL_ARCH.QWEN2VL
 
     def set_gguf_parameters(self):
@@ -2375,8 +2431,8 @@ class Qwen2VLModel(Model):
             yield name, data
 
 
-@Model.register("WavTokenizerDec")
-class WavTokenizerDecModel(Model):
+@ModelBase.register("WavTokenizerDec")
+class WavTokenizerDecModel(TextModel):
     model_arch = gguf.MODEL_ARCH.WAVTOKENIZER_DEC
 
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
@@ -2413,8 +2469,8 @@ class WavTokenizerDecModel(Model):
         self.gguf_writer.add_causal_attention(False)
 
 
-@Model.register("Qwen2MoeForCausalLM")
-class Qwen2MoeModel(Model):
+@ModelBase.register("Qwen2MoeForCausalLM")
+class Qwen2MoeModel(TextModel):
     model_arch = gguf.MODEL_ARCH.QWEN2MOE
 
     def set_gguf_parameters(self):
@@ -2476,18 +2532,18 @@ class Qwen2MoeModel(Model):
                 raise ValueError(f"Unprocessed experts: {experts}")
 
 
-@Model.register("Qwen3ForCausalLM")
+@ModelBase.register("Qwen3ForCausalLM")
 class Qwen3Model(Qwen2Model):
     model_arch = gguf.MODEL_ARCH.QWEN3
 
 
-@Model.register("Qwen3MoeForCausalLM")
+@ModelBase.register("Qwen3MoeForCausalLM")
 class Qwen3MoeModel(Qwen2MoeModel):
     model_arch = gguf.MODEL_ARCH.QWEN3MOE
 
 
-@Model.register("GPT2LMHeadModel")
-class GPT2Model(Model):
+@ModelBase.register("GPT2LMHeadModel")
+class GPT2Model(TextModel):
     model_arch = gguf.MODEL_ARCH.GPT2
 
     def set_gguf_parameters(self):
@@ -2518,8 +2574,8 @@ class GPT2Model(Model):
         return tensors
 
 
-@Model.register("PhiForCausalLM")
-class Phi2Model(Model):
+@ModelBase.register("PhiForCausalLM")
+class Phi2Model(TextModel):
     model_arch = gguf.MODEL_ARCH.PHI2
 
     def set_gguf_parameters(self):
@@ -2542,8 +2598,8 @@ class Phi2Model(Model):
         self.gguf_writer.add_add_bos_token(False)
 
 
-@Model.register("Phi3ForCausalLM")
-class Phi3MiniModel(Model):
+@ModelBase.register("Phi3ForCausalLM")
+class Phi3MiniModel(TextModel):
     model_arch = gguf.MODEL_ARCH.PHI3
 
     def set_vocab(self):
@@ -2720,7 +2776,7 @@ class Phi3MiniModel(Model):
         yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FACTORS_SHORT), torch.tensor(short_factors, dtype=torch.float32))
 
 
-@Model.register("PhiMoEForCausalLM")
+@ModelBase.register("PhiMoEForCausalLM")
 class PhiMoeModel(Phi3MiniModel):
     model_arch = gguf.MODEL_ARCH.PHIMOE
 
@@ -2777,8 +2833,8 @@ class PhiMoeModel(Phi3MiniModel):
                 raise ValueError(f"Unprocessed experts: {experts}")
 
 
-@Model.register("PlamoForCausalLM")
-class PlamoModel(Model):
+@ModelBase.register("PlamoForCausalLM")
+class PlamoModel(TextModel):
     model_arch = gguf.MODEL_ARCH.PLAMO
 
     def set_vocab(self):
@@ -2825,8 +2881,8 @@ class PlamoModel(Model):
         return [(new_name, data_torch)]
 
 
-@Model.register("CodeShellForCausalLM")
-class CodeShellModel(Model):
+@ModelBase.register("CodeShellForCausalLM")
+class CodeShellModel(TextModel):
     model_arch = gguf.MODEL_ARCH.CODESHELL
 
     def set_gguf_parameters(self):
@@ -2866,8 +2922,8 @@ class CodeShellModel(Model):
         return [(new_name, data_torch)]
 
 
-@Model.register("InternLM2ForCausalLM")
-class InternLM2Model(Model):
+@ModelBase.register("InternLM2ForCausalLM")
+class InternLM2Model(TextModel):
     model_arch = gguf.MODEL_ARCH.INTERNLM2
 
     def set_vocab(self):
@@ -3039,8 +3095,8 @@ class InternLM2Model(Model):
             return [(self.map_tensor_name(name), data_torch)]
 
 
-@Model.register("InternLM3ForCausalLM")
-class InternLM3Model(Model):
+@ModelBase.register("InternLM3ForCausalLM")
+class InternLM3Model(TextModel):
     model_arch = gguf.MODEL_ARCH.LLAMA
 
     def set_vocab(self):
@@ -3099,8 +3155,8 @@ class InternLM3Model(Model):
         return [(self.map_tensor_name(name), data_torch)]
 
 
-@Model.register("BertModel", "BertForMaskedLM", "CamembertModel")
-class BertModel(Model):
+@ModelBase.register("BertModel", "BertForMaskedLM", "CamembertModel")
+class BertModel(TextModel):
     model_arch = gguf.MODEL_ARCH.BERT
 
     def __init__(self, *args, **kwargs):
@@ -3187,7 +3243,7 @@ class BertModel(Model):
         return [(self.map_tensor_name(name), data_torch)]
 
 
-@Model.register("RobertaModel")
+@ModelBase.register("RobertaModel")
 class RobertaModel(BertModel):
     model_arch = gguf.MODEL_ARCH.BERT
 
@@ -3232,7 +3288,7 @@ class RobertaModel(BertModel):
         return super().modify_tensors(data_torch, name, bid)
 
 
-@Model.register("NomicBertModel")
+@ModelBase.register("NomicBertModel")
 class NomicBertModel(BertModel):
     model_arch = gguf.MODEL_ARCH.NOMIC_BERT
 
@@ -3262,7 +3318,7 @@ class NomicBertModel(BertModel):
         self.gguf_writer.add_rope_freq_base(self.hparams["rotary_emb_base"])
 
 
-@Model.register("XLMRobertaModel", "XLMRobertaForSequenceClassification")
+@ModelBase.register("XLMRobertaModel", "XLMRobertaForSequenceClassification")
 class XLMRobertaModel(BertModel):
     model_arch = gguf.MODEL_ARCH.BERT
 
@@ -3373,8 +3429,8 @@ class XLMRobertaModel(BertModel):
         return super().modify_tensors(data_torch, name, bid)
 
 
-@Model.register("GemmaForCausalLM")
-class GemmaModel(Model):
+@ModelBase.register("GemmaForCausalLM")
+class GemmaModel(TextModel):
     model_arch = gguf.MODEL_ARCH.GEMMA
 
     def set_vocab(self):
@@ -3424,8 +3480,8 @@ class GemmaModel(Model):
         return [(self.map_tensor_name(name), data_torch)]
 
 
-@Model.register("Gemma2ForCausalLM")
-class Gemma2Model(Model):
+@ModelBase.register("Gemma2ForCausalLM")
+class Gemma2Model(TextModel):
     model_arch = gguf.MODEL_ARCH.GEMMA2
 
     def set_vocab(self):
@@ -3471,27 +3527,9 @@ class Gemma2Model(Model):
         return [(self.map_tensor_name(name), data_torch)]
 
 
-@Model.register("Gemma3ForCausalLM", "Gemma3ForConditionalGeneration")
-class Gemma3Model(Model):
+@ModelBase.register("Gemma3ForCausalLM", "Gemma3ForConditionalGeneration")
+class Gemma3Model(TextModel):
     model_arch = gguf.MODEL_ARCH.GEMMA3
-    has_vision: bool = False
-
-    # we need to merge the text_config into the root level of hparams
-    def __init__(self, *args, **kwargs):
-        hparams = kwargs["hparams"] if "hparams" in kwargs else Model.load_hparams(args[0])
-        if "text_config" in hparams:
-            hparams = {**hparams, **hparams["text_config"]}
-            kwargs["hparams"] = hparams
-        super().__init__(*args, **kwargs)
-        if "vision_config" in hparams:
-            logger.info("Has vision encoder, but it will be ignored")
-            self.has_vision = True
-
-    def write(self):
-        super().write()
-        if self.has_vision:
-            logger.info("NOTE: this script only convert the language model to GGUF")
-            logger.info("      for the vision model, please use gemma3_convert_encoder_to_gguf.py")
 
     def set_vocab(self):
         self._set_vocab_sentencepiece()
@@ -3529,10 +3567,10 @@ class Gemma3Model(Model):
 
         if name.startswith("language_model."):
             name = name.replace("language_model.", "")
+
         elif name.startswith("multi_modal_projector.") or name.startswith("vision_tower.") \
-                or name.startswith("multimodal_projector.") or name.startswith("vision_model."): # this is for old HF model, should be removed later
-            # ignore vision tensors
-            return []
+                or name.startswith("multimodal_projector.") or name.startswith("vision_model."):
+            return [] # skip vision tensors
 
         # remove OOV (out-of-vocabulary) rows in token_embd
         if "embed_tokens.weight" in name:
@@ -3548,13 +3586,58 @@ class Gemma3Model(Model):
         return [(self.map_tensor_name(name), data_torch)]
 
 
-@Model.register("Starcoder2ForCausalLM")
-class StarCoder2Model(Model):
+@ModelBase.register("Gemma3ForConditionalGeneration")
+class Gemma3VisionModel(VisionModel):
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        hparams = self.hparams
+        self.gguf_writer.add_string(gguf.Keys.ClipVision.PROJECTOR_TYPE, "gemma3")
+        # default values below are taken from HF transformers code
+        self.gguf_writer.add_float32(gguf.Keys.ClipVision.Attention.LAYERNORM_EPS, hparams.get("layer_norm_eps", 1e-6))
+        self.gguf_writer.add_array(gguf.Keys.ClipVision.IMAGE_MEAN, [0.5, 0.5, 0.5])
+        self.gguf_writer.add_array(gguf.Keys.ClipVision.IMAGE_STD,  [0.5, 0.5, 0.5])
+        self.gguf_writer.add_bool (gguf.Keys.ClipVision.USE_GELU,   True)
+
+    def tensor_force_quant(self, name, new_name, bid, n_dims):
+        del bid, new_name, n_dims  # unused
+        # related to https://github.com/ggml-org/llama.cpp/issues/13025
+        if "input_projection" in name:
+            return gguf.GGMLQuantizationType.F16
+        if ".embeddings." in name:
+            return gguf.GGMLQuantizationType.F32
+        return False
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        del bid  # unused
+
+        if name.startswith("multi_modal_projector.") or name.startswith("vision_tower.") \
+                or name.startswith("multimodal_projector.") or name.startswith("vision_model."):
+            # process vision tensors
+            name = name.replace("_weight", ".weight")
+            if "fc1" in name:
+                name = name.replace("fc1", "fc2")
+            else:
+                name = name.replace("fc2", "fc1")
+
+            # correct norm value; only "soft_emb_norm" needs to be corrected, as it is part of the Gemma projector
+            # the other norm values belong to the SigLIP model and are already correct
+            # ref code: Gemma3RMSNorm
+            if "soft_emb_norm.weight" in name:
+                logger.info(f"Correcting norm value for '{name}'")
+                data_torch = data_torch + 1
+
+            return [(self.map_tensor_name(name), data_torch)]
+
+        return [] # skip other tensors
+
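For context on the `data_torch + 1` correction above: HF's Gemma3RMSNorm scales the normalized activations by (1 + weight), while the GGUF runtime applies the stored weight directly, so the converter folds the offset into the tensor once at conversion time. A minimal PyTorch sketch of the equivalence (illustrative only, not part of the patch):

    import torch

    def hf_gemma_rms_norm(x, weight, eps=1e-6):
        # HF Gemma3RMSNorm convention: weight is stored centered around 0
        x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
        return x * (1.0 + weight)

    def gguf_rms_norm(x, weight, eps=1e-6):
        # GGUF convention: rms_norm followed by a plain elementwise scale
        x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
        return x * weight

    x = torch.randn(4, 8)
    w = torch.randn(8) * 0.02
    # storing (w + 1) at conversion time makes the two paths agree
    assert torch.allclose(hf_gemma_rms_norm(x, w), gguf_rms_norm(x, w + 1))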
+
+@ModelBase.register("Starcoder2ForCausalLM")
+class StarCoder2Model(TextModel):
     model_arch = gguf.MODEL_ARCH.STARCODER2
 
 
-@Model.register("Rwkv6ForCausalLM")
-class Rwkv6Model(Model):
+@ModelBase.register("Rwkv6ForCausalLM")
+class Rwkv6Model(TextModel):
     model_arch = gguf.MODEL_ARCH.RWKV6
 
     def set_vocab(self):
@@ -3626,7 +3709,7 @@ class Rwkv6Model(Model):
         yield (new_name, data_torch)
 
 
-@Model.register("RWKV6Qwen2ForCausalLM")
+@ModelBase.register("RWKV6Qwen2ForCausalLM")
 class RWKV6Qwen2Model(Rwkv6Model):
     model_arch = gguf.MODEL_ARCH.RWKV6QWEN2
 
@@ -3680,8 +3763,8 @@ class RWKV6Qwen2Model(Rwkv6Model):
             yield (new_name, data)
 
 
-@Model.register("Rwkv7ForCausalLM", "RWKV7ForCausalLM")
-class Rwkv7Model(Model):
+@ModelBase.register("Rwkv7ForCausalLM", "RWKV7ForCausalLM")
+class Rwkv7Model(TextModel):
     model_arch = gguf.MODEL_ARCH.RWKV7
 
     def set_vocab(self):
@@ -3799,7 +3882,7 @@ class Rwkv7Model(Model):
             yield (new_name, data_torch)
 
 
-@Model.register("RwkvHybridForCausalLM")
+@ModelBase.register("RwkvHybridForCausalLM")
 class ARwkv7Model(Rwkv7Model):
     model_arch = gguf.MODEL_ARCH.ARWKV7
 
@@ -3842,8 +3925,8 @@ class ARwkv7Model(Rwkv7Model):
         self.gguf_writer.add_head_count(0)
 
 
-@Model.register("MambaForCausalLM", "MambaLMHeadModel", "FalconMambaForCausalLM")
-class MambaModel(Model):
+@ModelBase.register("MambaForCausalLM", "MambaLMHeadModel", "FalconMambaForCausalLM")
+class MambaModel(TextModel):
     model_arch = gguf.MODEL_ARCH.MAMBA
 
     def set_vocab(self):
@@ -3920,8 +4003,8 @@ class MambaModel(Model):
         return [(new_name, data_torch)]
 
 
-@Model.register("CohereForCausalLM")
-class CommandR2Model(Model):
+@ModelBase.register("CohereForCausalLM")
+class CommandR2Model(TextModel):
     model_arch = gguf.MODEL_ARCH.COMMAND_R
 
     def __init__(self, *args, **kwargs):
@@ -3938,8 +4021,8 @@ class CommandR2Model(Model):
         self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE)
 
 
-@Model.register("Cohere2ForCausalLM")
-class Cohere2Model(Model):
+@ModelBase.register("Cohere2ForCausalLM")
+class Cohere2Model(TextModel):
     model_arch = gguf.MODEL_ARCH.COHERE2
 
     def set_gguf_parameters(self):
@@ -3956,9 +4039,9 @@ class Cohere2Model(Model):
         self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE)
 
 
-@Model.register("OlmoForCausalLM")
-@Model.register("OLMoForCausalLM")
-class OlmoModel(Model):
+@ModelBase.register("OlmoForCausalLM")
+@ModelBase.register("OLMoForCausalLM")
+class OlmoModel(TextModel):
     model_arch = gguf.MODEL_ARCH.OLMO
 
     def set_gguf_parameters(self):
@@ -3984,13 +4067,13 @@ class OlmoModel(Model):
         return [(self.map_tensor_name(name), data_torch)]
 
 
-@Model.register("Olmo2ForCausalLM")
-class Olmo2Model(Model):
+@ModelBase.register("Olmo2ForCausalLM")
+class Olmo2Model(TextModel):
     model_arch = gguf.MODEL_ARCH.OLMO2
 
 
-@Model.register("OlmoeForCausalLM")
-class OlmoeModel(Model):
+@ModelBase.register("OlmoeForCausalLM")
+class OlmoeModel(TextModel):
     model_arch = gguf.MODEL_ARCH.OLMOE
 
     def set_gguf_parameters(self):
@@ -4049,7 +4132,7 @@ class OlmoeModel(Model):
                 raise ValueError(f"Unprocessed experts: {experts}")
 
 
-@Model.register("JinaBertModel", "JinaBertForMaskedLM")
+@ModelBase.register("JinaBertModel", "JinaBertForMaskedLM")
 class JinaBertV2Model(BertModel):
     model_arch = gguf.MODEL_ARCH.JINA_BERT_V2
 
@@ -4096,8 +4179,8 @@ class JinaBertV2Model(BertModel):
         return super().modify_tensors(data_torch, name, bid)
 
 
-@Model.register("OpenELMForCausalLM")
-class OpenELMModel(Model):
+@ModelBase.register("OpenELMForCausalLM")
+class OpenELMModel(TextModel):
     model_arch = gguf.MODEL_ARCH.OPENELM
 
     @staticmethod
@@ -4171,8 +4254,8 @@ class OpenELMModel(Model):
         yield (self.map_tensor_name(name), data_torch)
 
 
-@Model.register("ArcticForCausalLM")
-class ArcticModel(Model):
+@ModelBase.register("ArcticForCausalLM")
+class ArcticModel(TextModel):
     model_arch = gguf.MODEL_ARCH.ARCTIC
 
     def set_vocab(self):
@@ -4322,8 +4405,8 @@ class ArcticModel(Model):
                 raise ValueError(f"Unprocessed experts: {experts}")
 
 
-@Model.register("DeepseekForCausalLM")
-class DeepseekModel(Model):
+@ModelBase.register("DeepseekForCausalLM")
+class DeepseekModel(TextModel):
     model_arch = gguf.MODEL_ARCH.DEEPSEEK
 
     def set_vocab(self):
@@ -4413,9 +4496,9 @@ class DeepseekModel(Model):
                 raise ValueError(f"Unprocessed experts: {experts}")
 
 
-@Model.register("DeepseekV2ForCausalLM")
-@Model.register("DeepseekV3ForCausalLM")
-class DeepseekV2Model(Model):
+@ModelBase.register("DeepseekV2ForCausalLM")
+@ModelBase.register("DeepseekV3ForCausalLM")
+class DeepseekV2Model(TextModel):
     model_arch = gguf.MODEL_ARCH.DEEPSEEK2
 
     def set_vocab(self):
@@ -4541,8 +4624,8 @@ class DeepseekV2Model(Model):
                 raise ValueError(f"Unprocessed experts: {experts}")
 
 
-@Model.register("PLMForCausalLM")
-class PLMModel(Model):
+@ModelBase.register("PLMForCausalLM")
+class PLMModel(TextModel):
     model_arch = gguf.MODEL_ARCH.PLM
 
     def set_vocab(self):
@@ -4564,11 +4647,11 @@ class PLMModel(Model):
         super().prepare_tensors()
 
 
-@Model.register("T5WithLMHeadModel")
-@Model.register("T5ForConditionalGeneration")
-@Model.register("MT5ForConditionalGeneration")
-@Model.register("UMT5ForConditionalGeneration")
-class T5Model(Model):
+@ModelBase.register("T5WithLMHeadModel")
+@ModelBase.register("T5ForConditionalGeneration")
+@ModelBase.register("MT5ForConditionalGeneration")
+@ModelBase.register("UMT5ForConditionalGeneration")
+class T5Model(TextModel):
     model_arch = gguf.MODEL_ARCH.T5
 
     def __init__(self, *args, **kwargs):
@@ -4707,8 +4790,8 @@ class T5Model(Model):
         return [(self.map_tensor_name(name), data_torch)]
 
 
-@Model.register("T5EncoderModel")
-class T5EncoderModel(Model):
+@ModelBase.register("T5EncoderModel")
+class T5EncoderModel(TextModel):
     model_arch = gguf.MODEL_ARCH.T5ENCODER
 
     def __init__(self, *args, **kwargs):
@@ -4846,8 +4929,8 @@ class T5EncoderModel(Model):
         return [(self.map_tensor_name(name), data_torch)]
 
 
-@Model.register("JAISLMHeadModel")
-class JaisModel(Model):
+@ModelBase.register("JAISLMHeadModel")
+class JaisModel(TextModel):
     model_arch = gguf.MODEL_ARCH.JAIS
 
     def __init__(self, *args, **kwargs):
@@ -4929,8 +5012,8 @@ class JaisModel(Model):
         self.gguf_writer.add_max_alibi_bias(self.max_alibi_bias)
 
 
-@Model.register("Glm4ForCausalLM")
-class Glm4Model(Model):
+@ModelBase.register("Glm4ForCausalLM")
+class Glm4Model(TextModel):
     model_arch = gguf.MODEL_ARCH.GLM4
 
     def set_vocab(self):
@@ -4945,8 +5028,8 @@ class Glm4Model(Model):
                 self.gguf_writer.add_rope_scaling_orig_ctx_len(self.hparams["rope_scaling"]["original_max_position_embeddings"])
 
 
-@Model.register("GlmForCausalLM", "ChatGLMModel", "ChatGLMForConditionalGeneration")
-class ChatGLMModel(Model):
+@ModelBase.register("GlmForCausalLM", "ChatGLMModel", "ChatGLMForConditionalGeneration")
+class ChatGLMModel(TextModel):
     model_arch = gguf.MODEL_ARCH.CHATGLM
 
     def set_vocab_chatglm3(self):
@@ -5100,8 +5183,8 @@ class ChatGLMModel(Model):
         return [(self.map_tensor_name(name), data_torch)]
 
 
-@Model.register("NemotronForCausalLM")
-class NemotronModel(Model):
+@ModelBase.register("NemotronForCausalLM")
+class NemotronModel(TextModel):
     model_arch = gguf.MODEL_ARCH.NEMOTRON
 
     def set_vocab(self):
@@ -5141,8 +5224,8 @@ class NemotronModel(Model):
         return [(self.map_tensor_name(name), data_torch)]
 
 
-@Model.register("ExaoneForCausalLM")
-class ExaoneModel(Model):
+@ModelBase.register("ExaoneForCausalLM")
+class ExaoneModel(TextModel):
     model_arch = gguf.MODEL_ARCH.EXAONE
 
     def set_gguf_parameters(self):
@@ -5210,7 +5293,7 @@ class ExaoneModel(Model):
                 yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FREQS), torch.tensor(rope_factors, dtype=torch.float32))
 
 
-@Model.register("GraniteForCausalLM")
+@ModelBase.register("GraniteForCausalLM")
 class GraniteModel(LlamaModel):
     """Conversion for IBM's GraniteForCausalLM"""
     model_arch = gguf.MODEL_ARCH.GRANITE
@@ -5244,7 +5327,7 @@ class GraniteModel(LlamaModel):
             logger.info("gguf: (granite) logits_scale = %s", logits_scale)
 
 
-@Model.register("GraniteMoeForCausalLM")
+@ModelBase.register("GraniteMoeForCausalLM")
 class GraniteMoeModel(GraniteModel):
     """Conversion for IBM's GraniteMoeForCausalLM"""
     model_arch = gguf.MODEL_ARCH.GRANITE_MOE
@@ -5268,8 +5351,8 @@ class GraniteMoeModel(GraniteModel):
         return super().modify_tensors(data_torch, name, bid)
 
 
-@Model.register("BailingMoeForCausalLM")
-class BailingMoeModel(Model):
+@ModelBase.register("BailingMoeForCausalLM")
+class BailingMoeModel(TextModel):
     model_arch = gguf.MODEL_ARCH.BAILINGMOE
 
     def set_vocab(self):
@@ -5367,9 +5450,9 @@ class BailingMoeModel(Model):
                 raise ValueError(f"Unprocessed experts: {experts}")
 
 
-@Model.register("ChameleonForConditionalGeneration")
-@Model.register("ChameleonForCausalLM")  # obsolete
-class ChameleonModel(Model):
+@ModelBase.register("ChameleonForConditionalGeneration")
+@ModelBase.register("ChameleonForCausalLM")  # obsolete
+class ChameleonModel(TextModel):
     model_arch = gguf.MODEL_ARCH.CHAMELEON
 
     def set_gguf_parameters(self):
@@ -5554,6 +5637,10 @@ def parse_args() -> argparse.Namespace:
         "--remote", action="store_true",
         help="(Experimental) Read safetensors file remotely without downloading to disk. Config and tokenizer files will still be downloaded. To use this feature, you need to specify Hugging Face model repo name instead of a local directory. For example: 'HuggingFaceTB/SmolLM2-1.7B-Instruct'. Note: To access gated repo, set HF_TOKEN environment variable to your Hugging Face token.",
     )
+    parser.add_argument(
+        "--mmproj", action="store_true",
+        help="(Experimental) Export multimodal projector (mmproj) for vision models. This will only work on some vision models. A prefix 'mmproj-' will be added to the output file name.",
+    )
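The flag only switches the conversion path and prefixes the output file name. A minimal sketch of the filename prefixing it relies on (a hypothetical standalone version of `ModelBase.add_prefix_to_filename`, shown for illustration):

    from pathlib import Path

    def add_prefix_to_filename(path: Path, prefix: str) -> Path:
        # only the final path component is prefixed; the directory is untouched
        return path.with_name(prefix + path.name)

    assert add_prefix_to_filename(Path("out/model.gguf"), "mmproj-") == Path("out/mmproj-model.gguf")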
 
     args = parser.parse_args()
     if not args.print_supported_models and args.model is None:
@@ -5584,7 +5671,7 @@ def main() -> None:
 
     if args.print_supported_models:
         logger.error("Supported models:")
-        Model.print_registered_models()
+        ModelBase.print_registered_models()
         sys.exit(0)
 
     if args.verbose:
@@ -5631,13 +5718,18 @@ def main() -> None:
 
     logger.info(f"Loading model: {dir_model.name}")
 
-    hparams = Model.load_hparams(dir_model)
+    hparams = ModelBase.load_hparams(dir_model)
+
+    if args.mmproj:
+        if "mmproj" not in fname_out.name:
+            fname_out = ModelBase.add_prefix_to_filename(fname_out, "mmproj-")
 
     with torch.inference_mode():
         output_type = ftype_map[args.outtype]
         model_architecture = hparams["architectures"][0]
+        model_type = ModelType.VISION if args.mmproj else ModelType.TEXT
         try:
-            model_class = Model.from_model_architecture(model_architecture)
+            model_class = ModelBase.from_model_architecture(model_architecture, model_type=model_type)
         except NotImplementedError:
             logger.error(f"Model {model_architecture} is not supported")
             sys.exit(1)
diff --git a/convert_lora_to_gguf.py b/convert_lora_to_gguf.py
index bdc991533b..00a6733cbd 100755
--- a/convert_lora_to_gguf.py
+++ b/convert_lora_to_gguf.py
@@ -24,7 +24,7 @@ if 'NO_LOCAL_GGUF' not in os.environ:
 import gguf
 
 # reuse model definitions from convert_hf_to_gguf.py
-from convert_hf_to_gguf import LazyTorchTensor, Model
+from convert_hf_to_gguf import LazyTorchTensor, ModelBase
 
 logger = logging.getLogger("lora-to-gguf")
 
@@ -340,11 +340,11 @@ if __name__ == '__main__':
             sys.exit(1)
     else:
         logger.info(f"Loading base model: {dir_base_model.name}")
-        hparams = Model.load_hparams(dir_base_model)
+        hparams = ModelBase.load_hparams(dir_base_model)
 
     with torch.inference_mode():
         try:
-            model_class = Model.from_model_architecture(hparams["architectures"][0])
+            model_class = ModelBase.from_model_architecture(hparams["architectures"][0])
         except NotImplementedError:
             logger.error(f"Model {hparams['architectures'][0]} is not supported")
             sys.exit(1)
diff --git a/examples/llava/clip-impl.h b/examples/llava/clip-impl.h
index 4d7340a56b..180ae9880b 100644
--- a/examples/llava/clip-impl.h
+++ b/examples/llava/clip-impl.h
@@ -50,7 +50,6 @@
 // tensor name constants
 //
 
-#define TN_TOKEN_EMBD      "%s.token_embd.weight"
 #define TN_POS_EMBD        "%s.position_embd.weight"
 #define TN_CLASS_EMBD      "v.class_embd"
 #define TN_PATCH_EMBD      "v.patch_embd.weight"  // do not rename tensor with ".0" postfix for backward compat
@@ -66,8 +65,6 @@
 #define TN_LN_2            "%s.blk.%d.ln2.%s"
 #define TN_LN_PRE          "%s.pre_ln.%s"
 #define TN_LN_POST         "%s.post_ln.%s"
-#define TN_TEXT_PROJ       "text_projection.weight"
-#define TN_VIS_PROJ        "visual_projection.weight"
 #define TN_LLAVA_PROJ      "mm.%d.%s"
 #define TN_MVLM_PROJ_MLP   "mm.model.mlp.%d.%s"
 #define TN_MVLM_PROJ_BLOCK "mm.model.mb_block.%d.block.%d.%s"
diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py
index 8fcde2626a..3f24705201 100644
--- a/gguf-py/gguf/constants.py
+++ b/gguf-py/gguf/constants.py
@@ -218,17 +218,37 @@ class Keys:
         TYPE       = "adapter.type"
         LORA_ALPHA = "adapter.lora.alpha"
 
+    class ClipVision:
+        PROJECTOR_TYPE      = "clip.projector_type"
+        HAS_VISION_ENCODER  = "clip.has_vision_encoder"
+        HAS_LLAVA_PROJECTOR = "clip.has_llava_projector"
+        IMAGE_SIZE          = "clip.vision.image_size"
+        PATCH_SIZE          = "clip.vision.patch_size"
+        EMBEDDING_LENGTH    = "clip.vision.embedding_length"
+        FEED_FORWARD_LENGTH = "clip.vision.feed_forward_length"
+        PROJECTION_DIM      = "clip.vision.projection_dim"
+        BLOCK_COUNT         = "clip.vision.block_count"
+        IMAGE_MEAN          = "clip.vision.image_mean"
+        IMAGE_STD           = "clip.vision.image_std"
+        USE_GELU            = "clip.use_gelu"
+
+        class Attention:
+            HEAD_COUNT      = "clip.vision.attention.head_count"
+            LAYERNORM_EPS   = "clip.vision.attention.layer_norm_epsilon"
+
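These key constants are what `Gemma3VisionModel.set_gguf_parameters` above writes; a minimal, self-contained sketch of how they end up as GGUF KV pairs (illustrative usage, assuming the in-tree gguf-py package):

    import gguf

    w = gguf.GGUFWriter("mmproj-example.gguf", "clip")  # "clip" is the dummy arch used by clip.cpp
    w.add_string(gguf.Keys.ClipVision.PROJECTOR_TYPE, "gemma3")
    w.add_float32(gguf.Keys.ClipVision.Attention.LAYERNORM_EPS, 1e-6)
    w.add_array(gguf.Keys.ClipVision.IMAGE_MEAN, [0.5, 0.5, 0.5])
    w.add_bool(gguf.Keys.ClipVision.USE_GELU, True)
    w.write_header_to_file()
    w.write_kv_data_to_file()
    w.close()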
 #
 # recommended mapping of model tensor names for storage in gguf
 #
 
 
 class GGUFType:
-    MODEL   = "model"
-    ADAPTER = "adapter"
+    MODEL       = "model"
+    ADAPTER     = "adapter"
+    CLIP_VISION = "clip-vision"
 
 
 class MODEL_ARCH(IntEnum):
+    CLIP_VISION      = auto() # dummy arch for clip.cpp
     LLAMA            = auto()
     LLAMA4           = auto()
     DECI             = auto()
@@ -297,6 +317,16 @@ class MODEL_ARCH(IntEnum):
     BAILINGMOE       = auto()
 
 
+class VISION_PROJECTOR_TYPE(IntEnum):
+    MLP       = auto()
+    LDP       = auto()
+    LDPV2     = auto()
+    RESAMPLER = auto()
+    GLM_EDGE  = auto()
+    MERGER    = auto()
+    GEMMA3    = auto()
+
+
 class MODEL_TENSOR(IntEnum):
     TOKEN_EMBD           = auto()
     TOKEN_EMBD_NORM      = auto()
@@ -436,9 +466,41 @@ class MODEL_TENSOR(IntEnum):
     POSNET_ATTN_K        = auto()
     POSNET_ATTN_V        = auto()
     POSNET_ATTN_OUT      = auto()
+    # vision
+    V_MMPROJ             = auto()
+    V_MMPROJ_FC          = auto()
+    V_MMPROJ_MLP         = auto()
+    V_MMPROJ_PEG         = auto()
+    V_ENC_EMBD_CLS       = auto()
+    V_ENC_EMBD_PATCH     = auto()
+    V_ENC_EMBD_POS       = auto()
+    V_ENC_ATTN_Q         = auto()
+    V_ENC_ATTN_K         = auto()
+    V_ENC_ATTN_V         = auto()
+    V_ENC_INPUT_NORM     = auto()
+    V_ENC_OUTPUT         = auto()
+    V_ENC_OUTPUT_NORM    = auto()
+    V_ENC_FFN_UP         = auto()
+    V_ENC_FFN_DOWN       = auto()
+    V_PRE_NORM           = auto()
+    V_POST_NORM          = auto()
+    V_MM_INP_PROJ        = auto() # gemma3
+    V_MM_SOFT_EMB_NORM   = auto() # gemma3
+    V_RESMPL_POS_EMBD_K  = auto() # minicpmv
+    V_RESMPL_ATTN_Q      = auto() # minicpmv
+    V_RESMPL_ATTN_K      = auto() # minicpmv
+    V_RESMPL_ATTN_V      = auto() # minicpmv
+    V_RESMPL_ATTN_OUT    = auto() # minicpmv
+    V_RESMPL_KV          = auto() # minicpmv
+    V_RESMPL_KV_NORM     = auto() # minicpmv
+    V_RESMPL_POST_NORM   = auto() # minicpmv
+    V_RESMPL_Q_NORM      = auto() # minicpmv
+    V_RESMPL_PROJ        = auto() # minicpmv
+    V_RESMPL_QUERY       = auto() # minicpmv
 
 
 MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
+    MODEL_ARCH.CLIP_VISION:      "clip", # dummy arch for clip.cpp
     MODEL_ARCH.LLAMA:            "llama",
     MODEL_ARCH.LLAMA4:           "llama4",
     MODEL_ARCH.DECI:             "deci",
@@ -507,6 +569,16 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
     MODEL_ARCH.BAILINGMOE:       "bailingmoe",
 }
 
+VISION_PROJECTOR_TYPE_NAMES: dict[VISION_PROJECTOR_TYPE, str] = {
+    VISION_PROJECTOR_TYPE.MLP:       "mlp",
+    VISION_PROJECTOR_TYPE.LDP:       "ldp",
+    VISION_PROJECTOR_TYPE.LDPV2:     "ldpv2",
+    VISION_PROJECTOR_TYPE.RESAMPLER: "resampler",
+    VISION_PROJECTOR_TYPE.GLM_EDGE:  "adapter",
+    VISION_PROJECTOR_TYPE.MERGER:    "qwen2vl_merger",
+    VISION_PROJECTOR_TYPE.GEMMA3:    "gemma3",
+}
+
 TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
     MODEL_TENSOR.TOKEN_EMBD:                "token_embd",
     MODEL_TENSOR.TOKEN_EMBD_NORM:           "token_embd_norm",
@@ -646,9 +718,72 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
     MODEL_TENSOR.POSNET_ATTN_K:             "posnet.{bid}.attn_k",
     MODEL_TENSOR.POSNET_ATTN_V:             "posnet.{bid}.attn_v",
     MODEL_TENSOR.POSNET_ATTN_OUT:           "posnet.{bid}.attn_output",
+    # vision
+    MODEL_TENSOR.V_MMPROJ:                  "mm.{bid}",
+    MODEL_TENSOR.V_MMPROJ_FC:               "mm.model.fc",
+    MODEL_TENSOR.V_MMPROJ_MLP:              "mm.model.mlp.{bid}",
+    MODEL_TENSOR.V_MMPROJ_PEG:              "mm.model.peg.{bid}",
+    MODEL_TENSOR.V_ENC_EMBD_CLS:            "v.class_embd",
+    MODEL_TENSOR.V_ENC_EMBD_PATCH:          "v.patch_embd",
+    MODEL_TENSOR.V_ENC_EMBD_POS:            "v.position_embd",
+    MODEL_TENSOR.V_ENC_ATTN_Q:              "v.blk.{bid}.attn_q",
+    MODEL_TENSOR.V_ENC_ATTN_K:              "v.blk.{bid}.attn_k",
+    MODEL_TENSOR.V_ENC_ATTN_V:              "v.blk.{bid}.attn_v",
+    MODEL_TENSOR.V_ENC_INPUT_NORM:          "v.blk.{bid}.ln1",
+    MODEL_TENSOR.V_ENC_OUTPUT:              "v.blk.{bid}.attn_out",
+    MODEL_TENSOR.V_ENC_OUTPUT_NORM:         "v.blk.{bid}.ln2",
+    MODEL_TENSOR.V_ENC_FFN_UP:              "v.blk.{bid}.ffn_up",
+    MODEL_TENSOR.V_ENC_FFN_DOWN:            "v.blk.{bid}.ffn_down",
+    MODEL_TENSOR.V_PRE_NORM:                "v.pre_ln",
+    MODEL_TENSOR.V_POST_NORM:               "v.post_ln",
+    MODEL_TENSOR.V_MM_INP_PROJ:             "mm.input_projection",
+    MODEL_TENSOR.V_MM_SOFT_EMB_NORM:        "mm.soft_emb_norm",
+    MODEL_TENSOR.V_RESMPL_POS_EMBD_K:       "resampler.pos_embd_k",
+    MODEL_TENSOR.V_RESMPL_ATTN_Q:           "resampler.attn.q",
+    MODEL_TENSOR.V_RESMPL_ATTN_K:           "resampler.attn.k",
+    MODEL_TENSOR.V_RESMPL_ATTN_V:           "resampler.attn.v",
+    MODEL_TENSOR.V_RESMPL_ATTN_OUT:         "resampler.attn.out",
+    MODEL_TENSOR.V_RESMPL_KV:               "resampler.kv",
+    MODEL_TENSOR.V_RESMPL_KV_NORM:          "resampler.ln_kv",
+    MODEL_TENSOR.V_RESMPL_POST_NORM:        "resampler.ln_post",
+    MODEL_TENSOR.V_RESMPL_Q_NORM:           "resampler.ln_q",
+    MODEL_TENSOR.V_RESMPL_PROJ:             "resampler.proj",
+    MODEL_TENSOR.V_RESMPL_QUERY:            "resampler.query",
 }
 
 MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
+    MODEL_ARCH.CLIP_VISION: [
+        MODEL_TENSOR.V_MMPROJ,
+        MODEL_TENSOR.V_MMPROJ_FC,
+        MODEL_TENSOR.V_MMPROJ_MLP,
+        MODEL_TENSOR.V_MMPROJ_PEG,
+        MODEL_TENSOR.V_ENC_EMBD_CLS,
+        MODEL_TENSOR.V_ENC_EMBD_PATCH,
+        MODEL_TENSOR.V_ENC_EMBD_POS,
+        MODEL_TENSOR.V_ENC_ATTN_Q,
+        MODEL_TENSOR.V_ENC_ATTN_K,
+        MODEL_TENSOR.V_ENC_ATTN_V,
+        MODEL_TENSOR.V_ENC_INPUT_NORM,
+        MODEL_TENSOR.V_ENC_OUTPUT,
+        MODEL_TENSOR.V_ENC_OUTPUT_NORM,
+        MODEL_TENSOR.V_ENC_FFN_UP,
+        MODEL_TENSOR.V_ENC_FFN_DOWN,
+        MODEL_TENSOR.V_PRE_NORM,
+        MODEL_TENSOR.V_POST_NORM,
+        MODEL_TENSOR.V_MM_INP_PROJ,
+        MODEL_TENSOR.V_MM_SOFT_EMB_NORM,
+        MODEL_TENSOR.V_RESMPL_POS_EMBD_K,
+        MODEL_TENSOR.V_RESMPL_ATTN_Q,
+        MODEL_TENSOR.V_RESMPL_ATTN_K,
+        MODEL_TENSOR.V_RESMPL_ATTN_V,
+        MODEL_TENSOR.V_RESMPL_ATTN_OUT,
+        MODEL_TENSOR.V_RESMPL_KV,
+        MODEL_TENSOR.V_RESMPL_KV_NORM,
+        MODEL_TENSOR.V_RESMPL_POST_NORM,
+        MODEL_TENSOR.V_RESMPL_Q_NORM,
+        MODEL_TENSOR.V_RESMPL_PROJ,
+        MODEL_TENSOR.V_RESMPL_QUERY,
+    ],
     MODEL_ARCH.LLAMA: [
         MODEL_TENSOR.TOKEN_EMBD,
         MODEL_TENSOR.OUTPUT_NORM,
diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py
index 0bc75cf513..22066b2868 100644
--- a/gguf-py/gguf/tensor_mapping.py
+++ b/gguf-py/gguf/tensor_mapping.py
@@ -886,6 +886,150 @@ class TensorNameMap:
         MODEL_TENSOR.POSNET_ATTN_OUT: (
             "backbone.posnet.{bid}.proj_out", # wavtokenizer
         ),
+
+        #############################################################################
+        ## Vision encoder
+
+        MODEL_TENSOR.V_MMPROJ: (
+            "multi_modal_projector.linear_{bid}",
+        ),
+
+        MODEL_TENSOR.V_MMPROJ_FC: (
+            "model.connector.modality_projection.proj", # SmolVLM
+        ),
+
+        MODEL_TENSOR.V_MMPROJ_MLP: (
+            "model.mm_projector.mlp.mlp.{bid}",
+        ),
+
+        MODEL_TENSOR.V_MMPROJ_PEG: (
+            "model.mm_projector.peg.peg.{bid}",
+        ),
+
+        MODEL_TENSOR.V_ENC_EMBD_CLS: (
+            "vision_tower.vision_model.embeddings.class_embedding",
+        ),
+
+        MODEL_TENSOR.V_ENC_EMBD_PATCH: (
+            "vision_tower.vision_model.embeddings.patch_embedding",
+            "vpm.embeddings.patch_embedding",
+            "model.vision_model.embeddings.patch_embedding", # SmolVLM
+        ),
+
+        MODEL_TENSOR.V_ENC_EMBD_POS: (
+            "vision_tower.vision_model.embeddings.position_embedding",
+            "vpm.embeddings.position_embedding",
+            "model.vision_model.embeddings.position_embedding", # SmolVLM
+        ),
+
+        MODEL_TENSOR.V_ENC_ATTN_Q: (
+            "vision_tower.vision_model.encoder.layers.{bid}.self_attn.q_proj",
+            "vpm.encoder.layers.{bid}.self_attn.q_proj",
+            "model.vision_model.encoder.layers.{bid}.self_attn.q_proj", # SmolVLM
+        ),
+
+        MODEL_TENSOR.V_ENC_ATTN_K: (
+            "vision_tower.vision_model.encoder.layers.{bid}.self_attn.k_proj",
+            "vpm.encoder.layers.{bid}.self_attn.k_proj",
+            "model.vision_model.encoder.layers.{bid}.self_attn.k_proj", # SmolVLM
+        ),
+
+        MODEL_TENSOR.V_ENC_ATTN_V: (
+            "vision_tower.vision_model.encoder.layers.{bid}.self_attn.v_proj",
+            "vpm.encoder.layers.{bid}.self_attn.v_proj",
+            "model.vision_model.encoder.layers.{bid}.self_attn.v_proj", # SmolVLM
+        ),
+
+        MODEL_TENSOR.V_ENC_INPUT_NORM: (
+            "vision_tower.vision_model.encoder.layers.{bid}.layer_norm1",
+            "vpm.encoder.layers.{bid}.layer_norm1",
+            "model.vision_model.encoder.layers.{bid}.layer_norm1", # SmolVLM
+        ),
+
+        MODEL_TENSOR.V_ENC_OUTPUT: (
+            "vision_tower.vision_model.encoder.layers.{bid}.self_attn.out_proj",
+            "vpm.encoder.layers.{bid}.self_attn.out_proj",
+            "model.vision_model.encoder.layers.{bid}.self_attn.out_proj", # SmolVLM
+        ),
+
+        MODEL_TENSOR.V_ENC_OUTPUT_NORM: (
+            "vision_tower.vision_model.encoder.layers.{bid}.layer_norm2",
+            "vpm.encoder.layers.{bid}.layer_norm2",
+            "model.vision_model.encoder.layers.{bid}.layer_norm2", # SmolVLM
+        ),
+
+        MODEL_TENSOR.V_ENC_FFN_UP: (
+            "vision_tower.vision_model.encoder.layers.{bid}.mlp.fc1",
+            "vpm.encoder.layers.{bid}.mlp.fc1",
+            "model.vision_model.encoder.layers.{bid}.mlp.fc1", # SmolVLM
+        ),
+
+        MODEL_TENSOR.V_ENC_FFN_DOWN: (
+            "vision_tower.vision_model.encoder.layers.{bid}.mlp.fc2",
+            "vpm.encoder.layers.{bid}.mlp.fc2",
+            "model.vision_model.encoder.layers.{bid}.mlp.fc2", # SmolVLM
+        ),
+
+        MODEL_TENSOR.V_PRE_NORM: (
+            "vision_tower.vision_model.pre_layrnorm",
+        ),
+
+        MODEL_TENSOR.V_POST_NORM: (
+            "vision_tower.vision_model.post_layernorm",
+            "model.vision_model.post_layernorm", # SmolVLM
+        ),
+
+        MODEL_TENSOR.V_MM_INP_PROJ: (
+            "multi_modal_projector.mm_input_projection",
+        ),
+
+        MODEL_TENSOR.V_MM_SOFT_EMB_NORM: (
+            "multi_modal_projector.mm_soft_emb_norm",
+        ),
+
+        MODEL_TENSOR.V_RESMPL_POS_EMBD_K: (
+            "resampler.pos_embed_k",
+        ),
+
+        MODEL_TENSOR.V_RESMPL_ATTN_Q: (
+            "resampler.attn.in_proj_q", # tensor generated from resampler.attn.in_proj
+        ),
+
+        MODEL_TENSOR.V_RESMPL_ATTN_K: (
+            "resampler.attn.in_proj_k", # tensor generated from resampler.attn.in_proj
+        ),
+
+        MODEL_TENSOR.V_RESMPL_ATTN_V: (
+            "resampler.attn.in_proj_v", # tensor generated from resampler.attn.in_proj
+        ),
+
+        MODEL_TENSOR.V_RESMPL_ATTN_OUT: (
+            "resampler.attn.out_proj",
+        ),
+
+        MODEL_TENSOR.V_RESMPL_KV: (
+            "resampler.kv_proj",
+        ),
+
+        MODEL_TENSOR.V_RESMPL_POST_NORM: (
+            "resampler.ln_post",
+        ),
+
+        MODEL_TENSOR.V_RESMPL_KV_NORM: (
+            "resampler.ln_kv",
+        ),
+
+        MODEL_TENSOR.V_RESMPL_Q_NORM: (
+            "resampler.ln_q",
+        ),
+
+        MODEL_TENSOR.V_RESMPL_PROJ: (
+            "resampler.proj",
+        ),
+
+        MODEL_TENSOR.V_RESMPL_QUERY: (
+            "resampler.query",
+        ),
     }
 
     # architecture-specific block mappings
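As a sanity check, the new vision entries resolve through the usual `TensorNameMap` lookup. A minimal sketch (illustrative, assuming the in-tree gguf-py package with the additions above):

    import gguf

    # CLIP_VISION is the dummy arch added above; n_blocks bounds the {bid} expansion
    tmap = gguf.get_tensor_name_map(gguf.MODEL_ARCH.CLIP_VISION, 32)
    hf_name = "vision_tower.vision_model.encoder.layers.7.self_attn.q_proj.weight"
    print(tmap.get_name(hf_name, try_suffixes=(".weight", ".bias")))  # v.blk.7.attn_q.weight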

From 84a9bf2fc2875205f0806fbbfbb66dc67204094c Mon Sep 17 00:00:00 2001
From: Xuan-Son Nguyen 
Date: Mon, 21 Apr 2025 15:32:58 +0200
Subject: [PATCH 22/39] mtmd : merge llava, gemma3 and minicpmv CLI into single
 `llama-mtmd-cli` (#13012)

* mtmd : merge `llava-cli` and `gemma3-cli` into single `mtmd-cli`

* support for minicpmv

* remove cpp files of llava and minicpmv

* update hot topics

* mtmd : add not supported msg for qwen2vl

* Update examples/llava/mtmd.cpp

Co-authored-by: Georgi Gerganov 

---------

Co-authored-by: Georgi Gerganov 
---
 README.md                                     |   1 +
 common/arg.cpp                                |   2 +-
 examples/llava/CMakeLists.txt                 |  22 +-
 examples/llava/deprecation-warning.cpp        |  22 ++
 examples/llava/llava-cli.cpp                  | 332 ----------------
 examples/llava/minicpmv-cli.cpp               | 354 ------------------
 .../llava/{gemma3-cli.cpp => mtmd-cli.cpp}    |  50 ++-
 examples/llava/mtmd.cpp                       | 313 +++++++++++++---
 examples/llava/tests.sh                       |  34 +-
 src/llama-chat.cpp                            |   2 +
 10 files changed, 360 insertions(+), 772 deletions(-)
 create mode 100644 examples/llava/deprecation-warning.cpp
 delete mode 100644 examples/llava/llava-cli.cpp
 delete mode 100644 examples/llava/minicpmv-cli.cpp
 rename examples/llava/{gemma3-cli.cpp => mtmd-cli.cpp} (82%)

diff --git a/README.md b/README.md
index cf45f23cf4..a0e7bd2d21 100644
--- a/README.md
+++ b/README.md
@@ -16,6 +16,7 @@ Inference of Meta's [LLaMA](https://arxiv.org/abs/2302.13971) model (and others)
 
 ## Hot topics
 
+- A new binary `llama-mtmd-cli` is introduced to replace `llava-cli`, `minicpmv-cli` and `gemma3-cli` (https://github.com/ggml-org/llama.cpp/pull/13012); `libllava` will be deprecated
 - **How to use [MTLResidencySet](https://developer.apple.com/documentation/metal/mtlresidencyset?language=objc) to keep the GPU memory active?** https://github.com/ggml-org/llama.cpp/pull/11427
 - **VS Code extension for FIM completions:** https://github.com/ggml-org/llama.vscode
 - Universal [tool call support](./docs/function-calling.md) in `llama-server` https://github.com/ggml-org/llama.cpp/pull/9639
diff --git a/common/arg.cpp b/common/arg.cpp
index 0b57f9da1e..80c318a0e5 100644
--- a/common/arg.cpp
+++ b/common/arg.cpp
@@ -2726,7 +2726,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         [](common_params & params, const std::string & value) {
             params.chat_template = value;
         }
-    ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_CHAT_TEMPLATE"));
+    ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_LLAVA}).set_env("LLAMA_ARG_CHAT_TEMPLATE"));
     add_opt(common_arg(
         {"--chat-template-file"}, "JINJA_TEMPLATE_FILE",
         string_format(
diff --git a/examples/llava/CMakeLists.txt b/examples/llava/CMakeLists.txt
index 2d5061de46..6409b4f5e6 100644
--- a/examples/llava/CMakeLists.txt
+++ b/examples/llava/CMakeLists.txt
@@ -61,19 +61,9 @@ if(TARGET BUILD_INFO)
     add_dependencies(mtmd BUILD_INFO)
 endif()
 
-set(TARGET llama-llava-cli)
-add_executable(${TARGET} llava-cli.cpp)
-set_target_properties(${TARGET} PROPERTIES OUTPUT_NAME llama-llava-cli)
-install(TARGETS ${TARGET} RUNTIME)
-target_link_libraries(${TARGET} PRIVATE common llava ${CMAKE_THREAD_LIBS_INIT})
-target_compile_features(${TARGET} PRIVATE cxx_std_17)
-
-set(TARGET llama-minicpmv-cli)
-add_executable(${TARGET} minicpmv-cli.cpp)
-set_target_properties(${TARGET} PROPERTIES OUTPUT_NAME llama-minicpmv-cli)
-install(TARGETS ${TARGET} RUNTIME)
-target_link_libraries(${TARGET} PRIVATE common llava ${CMAKE_THREAD_LIBS_INIT})
-target_compile_features(${TARGET} PRIVATE cxx_std_17)
+add_executable(llama-llava-cli    deprecation-warning.cpp)
+add_executable(llama-gemma3-cli   deprecation-warning.cpp)
+add_executable(llama-minicpmv-cli deprecation-warning.cpp)
 
 set(TARGET llama-qwen2vl-cli)
 add_executable(${TARGET} qwen2vl-cli.cpp)
@@ -82,9 +72,9 @@ install(TARGETS ${TARGET} RUNTIME)
 target_link_libraries(${TARGET} PRIVATE common llava ${CMAKE_THREAD_LIBS_INIT})
 target_compile_features(${TARGET} PRIVATE cxx_std_17)
 
-set(TARGET llama-gemma3-cli)
-add_executable(${TARGET} gemma3-cli.cpp)
-set_target_properties(${TARGET} PROPERTIES OUTPUT_NAME llama-gemma3-cli)
+set(TARGET llama-mtmd-cli)
+add_executable(${TARGET} mtmd-cli.cpp)
+set_target_properties(${TARGET} PROPERTIES OUTPUT_NAME llama-mtmd-cli)
 install(TARGETS ${TARGET} RUNTIME)
 target_link_libraries(${TARGET} PRIVATE common mtmd ${CMAKE_THREAD_LIBS_INIT})
 target_compile_features(${TARGET} PRIVATE cxx_std_17)
diff --git a/examples/llava/deprecation-warning.cpp b/examples/llava/deprecation-warning.cpp
new file mode 100644
index 0000000000..dded0a56af
--- /dev/null
+++ b/examples/llava/deprecation-warning.cpp
@@ -0,0 +1,22 @@
+#include <cstdio>
+#include <string>
+
+int main(int argc, char** argv) {
+    std::string filename = "main";
+    if (argc >= 1) {
+        filename = argv[0];
+    }
+
+    // Get only the program name from the full path
+    size_t pos = filename.find_last_of("/\\");
+    if (pos != std::string::npos) {
+        filename = filename.substr(pos+1);
+    }
+
+    fprintf(stdout, "\n");
+    fprintf(stdout, "WARNING: The binary '%s' is deprecated.\n", filename.c_str());
+    fprintf(stdout, "Please use 'llama-mtmd-cli' instead.\n");
+    fprintf(stdout, "\n");
+
+    return EXIT_FAILURE;
+}
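Keeping the old binary names alive as these thin stubs, rather than removing them outright, means existing scripts that invoke `llama-llava-cli`, `llama-gemma3-cli` or `llama-minicpmv-cli` fail loudly with a pointer to `llama-mtmd-cli`; returning `EXIT_FAILURE` also keeps such scripts from treating the run as successful.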
diff --git a/examples/llava/llava-cli.cpp b/examples/llava/llava-cli.cpp
deleted file mode 100644
index 0fe0e333a5..0000000000
--- a/examples/llava/llava-cli.cpp
+++ /dev/null
@@ -1,332 +0,0 @@
-#include "arg.h"
-#include "base64.hpp"
-#include "log.h"
-#include "common.h"
-#include "sampling.h"
-#include "clip.h"
-#include "llava.h"
-#include "llama.h"
-#include "ggml.h"
-
-#include <cstdio>
-#include <cstdlib>
-#include <cstring>
-#include <vector>
-
-static bool eval_tokens(struct llama_context * ctx_llama, std::vector<llama_token> tokens, int n_batch, int * n_past) {
-    int N = (int) tokens.size();
-    for (int i = 0; i < N; i += n_batch) {
-        int n_eval = (int) tokens.size() - i;
-        if (n_eval > n_batch) {
-            n_eval = n_batch;
-        }
-        if (llama_decode(ctx_llama, llama_batch_get_one(&tokens[i], n_eval))) {
-            LOG_ERR("%s : failed to eval. token %d/%d (batch size %d, n_past %d)\n", __func__, i, N, n_batch, *n_past);
-            return false;
-        }
-        *n_past += n_eval;
-    }
-    return true;
-}
-
-static bool eval_id(struct llama_context * ctx_llama, int id, int * n_past) {
-    std::vector<llama_token> tokens;
-    tokens.push_back(id);
-    return eval_tokens(ctx_llama, tokens, 1, n_past);
-}
-
-static bool eval_string(struct llama_context * ctx_llama, const char* str, int n_batch, int * n_past, bool add_bos){
-    std::string              str2     = str;
-    std::vector<llama_token> embd_inp = common_tokenize(ctx_llama, str2, add_bos, true);
-    eval_tokens(ctx_llama, embd_inp, n_batch, n_past);
-    return true;
-}
-
-static const char * sample(struct common_sampler * smpl,
-                           struct llama_context * ctx_llama,
-                           int * n_past) {
-    const llama_token id = common_sampler_sample(smpl, ctx_llama, -1);
-    common_sampler_accept(smpl, id, true);
-
-    const llama_model * model = llama_get_model(ctx_llama);
-    const llama_vocab * vocab = llama_model_get_vocab(model);
-
-    static std::string ret;
-    if (llama_vocab_is_eog(vocab, id)) {
-        ret = "";
-    } else {
-        ret = common_token_to_piece(ctx_llama, id);
-    }
-    eval_id(ctx_llama, id, n_past);
-    return ret.c_str();
-}
-
-static const char* IMG_BASE64_TAG_BEGIN = "<img src=\"data:image/jpeg;base64,";
-static const char* IMG_BASE64_TAG_END = "\">";
-
-static void find_image_tag_in_prompt(const std::string& prompt, size_t& begin_out, size_t& end_out) {
-    begin_out = prompt.find(IMG_BASE64_TAG_BEGIN);
-    end_out = prompt.find(IMG_BASE64_TAG_END, (begin_out == std::string::npos) ? 0UL : begin_out);
-}
-
-static bool prompt_contains_image(const std::string& prompt) {
-    size_t begin, end;
-    find_image_tag_in_prompt(prompt, begin, end);
-    return (begin != std::string::npos);
-}
-
-// replaces the base64 image tag in the prompt with `replacement`
-static llava_image_embed * llava_image_embed_make_with_prompt_base64(struct clip_ctx * ctx_clip, int n_threads, const std::string& prompt) {
-    size_t img_base64_str_start, img_base64_str_end;
-    find_image_tag_in_prompt(prompt, img_base64_str_start, img_base64_str_end);
-    if (img_base64_str_start == std::string::npos || img_base64_str_end == std::string::npos) {
-        LOG_ERR("%s: invalid base64 image tag. must be %s%s\n", __func__, IMG_BASE64_TAG_BEGIN, IMG_BASE64_TAG_END);
-        return NULL;
-    }
-
-    auto base64_bytes_start = img_base64_str_start + strlen(IMG_BASE64_TAG_BEGIN);
-    auto base64_bytes_count = img_base64_str_end - base64_bytes_start;
-    auto base64_str = prompt.substr(base64_bytes_start, base64_bytes_count );
-
-    auto required_bytes = base64::required_encode_size(base64_str.size());
-    auto img_bytes = std::vector(required_bytes);
-    base64::decode(base64_str.begin(), base64_str.end(), img_bytes.begin());
-
-    auto embed = llava_image_embed_make_with_bytes(ctx_clip, n_threads, img_bytes.data(), img_bytes.size());
-    if (!embed) {
-        LOG_ERR("%s: could not load image from base64 string.\n", __func__);
-        return NULL;
-    }
-
-    return embed;
-}
-
-static std::string remove_image_from_prompt(const std::string& prompt, const char * replacement = "") {
-    size_t begin, end;
-    find_image_tag_in_prompt(prompt, begin, end);
-    if (begin == std::string::npos || end == std::string::npos) {
-        return prompt;
-    }
-    auto pre = prompt.substr(0, begin);
-    auto post = prompt.substr(end + strlen(IMG_BASE64_TAG_END));
-    return pre + replacement + post;
-}
-
-struct llava_context {
-    struct clip_ctx * ctx_clip = NULL;
-    struct llama_context * ctx_llama = NULL;
-    struct llama_model * model = NULL;
-};
-
-static void print_usage(int, char ** argv) {
-    LOG("\n example usage:\n");
-    LOG("\n     %s -m  --mmproj  --image  --image  [--temp 0.1] [-p \"describe the image in detail.\"]\n", argv[0]);
-    LOG("\n note: a lower temperature value like 0.1 is recommended for better quality.\n");
-}
-
-static struct llava_image_embed * load_image(llava_context * ctx_llava, common_params * params, const std::string & fname) {
-
-    // load and preprocess the image
-    llava_image_embed * embed = NULL;
-    auto prompt = params->prompt;
-    if (prompt_contains_image(prompt)) {
-        if (!params->image.empty()) {
-            LOG_INF("using base64 encoded image instead of command line image path\n");
-        }
-        embed = llava_image_embed_make_with_prompt_base64(ctx_llava->ctx_clip, params->cpuparams.n_threads, prompt);
-        if (!embed) {
-            LOG_ERR("%s: can't load image from prompt\n", __func__);
-            return NULL;
-        }
-        params->prompt = remove_image_from_prompt(prompt);
-    } else {
-        embed = llava_image_embed_make_with_filename(ctx_llava->ctx_clip, params->cpuparams.n_threads, fname.c_str());
-        if (!embed) {
-            fprintf(stderr, "%s: is %s really an image file?\n", __func__, fname.c_str());
-            return NULL;
-        }
-    }
-
-    return embed;
-}
-
-static void process_prompt(struct llava_context * ctx_llava, struct llava_image_embed * image_embed, common_params * params, const std::string & prompt) {
-    int n_past = 0;
-
-    const int max_tgt_len = params->n_predict < 0 ? 256 : params->n_predict;
-
-    std::string system_prompt, user_prompt;
-    size_t image_pos = prompt.find("<image>");
-    if (image_pos != std::string::npos) {
-        // new templating mode: Provide the full prompt including system message and use <image> as a placeholder for the image
-        system_prompt = prompt.substr(0, image_pos);
-        user_prompt = prompt.substr(image_pos + std::string("<image>").length());
-        LOG_INF("system_prompt: %s\n", system_prompt.c_str());
-        if (params->verbose_prompt) {
-            auto tmp = common_tokenize(ctx_llava->ctx_llama, system_prompt, true, true);
-            for (int i = 0; i < (int) tmp.size(); i++) {
-                LOG_INF("%6d -> '%s'\n", tmp[i], common_token_to_piece(ctx_llava->ctx_llama, tmp[i]).c_str());
-            }
-        }
-        LOG_INF("user_prompt: %s\n", user_prompt.c_str());
-        if (params->verbose_prompt) {
-            auto tmp = common_tokenize(ctx_llava->ctx_llama, user_prompt, true, true);
-            for (int i = 0; i < (int) tmp.size(); i++) {
-                LOG_INF("%6d -> '%s'\n", tmp[i], common_token_to_piece(ctx_llava->ctx_llama, tmp[i]).c_str());
-            }
-        }
-    } else {
-        // llava-1.5 native mode
-        system_prompt = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\nUSER:";
-        user_prompt = prompt + "\nASSISTANT:";
-        if (params->verbose_prompt) {
-            auto tmp = common_tokenize(ctx_llava->ctx_llama, user_prompt, true, true);
-            for (int i = 0; i < (int) tmp.size(); i++) {
-                LOG_INF("%6d -> '%s'\n", tmp[i], common_token_to_piece(ctx_llava->ctx_llama, tmp[i]).c_str());
-            }
-        }
-    }
-
-    eval_string(ctx_llava->ctx_llama, system_prompt.c_str(), params->n_batch, &n_past, true);
-    llava_eval_image_embed(ctx_llava->ctx_llama, image_embed, params->n_batch, &n_past);
-    eval_string(ctx_llava->ctx_llama, user_prompt.c_str(), params->n_batch, &n_past, false);
-
-    // generate the response
-
-    LOG("\n");
-
-    struct common_sampler * smpl = common_sampler_init(ctx_llava->model, params->sampling);
-    if (!smpl) {
-        LOG_ERR("%s: failed to initialize sampling subsystem\n", __func__);
-        exit(1);
-    }
-
-    std::string response = "";
-    for (int i = 0; i < max_tgt_len; i++) {
-        const char * tmp = sample(smpl, ctx_llava->ctx_llama, &n_past);
-        response += tmp;
-        if (strcmp(tmp, "") == 0) break;
-        if (strstr(tmp, "###")) break; // Yi-VL behavior
-        LOG("%s", tmp);
-        if (strstr(response.c_str(), "<|im_end|>")) break; // Yi-34B llava-1.6 - for some reason those decode not as the correct token (tokenizer works)
-        if (strstr(response.c_str(), "<|im_start|>")) break; // Yi-34B llava-1.6
-        if (strstr(response.c_str(), "USER:")) break; // mistral llava-1.6
-
-        fflush(stdout);
-    }
-
-    common_sampler_free(smpl);
-    LOG("\n");
-}
-
-static struct llama_model * llava_init(common_params * params) {
-    llama_backend_init();
-    llama_numa_init(params->numa);
-
-    llama_model_params model_params = common_model_params_to_llama(*params);
-
-    llama_model * model = llama_model_load_from_file(params->model.path.c_str(), model_params);
-    if (model == NULL) {
-        LOG_ERR("%s: unable to load model\n" , __func__);
-        return NULL;
-    }
-    return model;
-}
-
-static struct llava_context * llava_init_context(common_params * params, llama_model * model) {
-    const char * clip_path = params->mmproj.path.c_str();
-
-    auto prompt = params->prompt;
-    if (prompt.empty()) {
-        prompt = "describe the image in detail.";
-    }
-
-    auto ctx_clip = clip_model_load(clip_path, GGML_LOG_LEVEL_INFO);
-
-    llama_context_params ctx_params = common_context_params_to_llama(*params);
-    ctx_params.n_ctx           = params->n_ctx < 2048 ? 2048 : params->n_ctx; // we need a longer context size to process image embeddings
-
-    llama_context * ctx_llama = llama_init_from_model(model, ctx_params);
-
-    if (ctx_llama == NULL) {
-        LOG_ERR("%s: failed to create the llama_context\n" , __func__);
-        return NULL;
-    }
-
-    auto * ctx_llava = (struct llava_context *)malloc(sizeof(llava_context));
-
-    ctx_llava->ctx_llama = ctx_llama;
-    ctx_llava->ctx_clip = ctx_clip;
-    ctx_llava->model = model;
-    return ctx_llava;
-}
-
-static void llava_free(struct llava_context * ctx_llava) {
-    if (ctx_llava->ctx_clip) {
-        clip_free(ctx_llava->ctx_clip);
-        ctx_llava->ctx_clip = NULL;
-    }
-
-    llama_free(ctx_llava->ctx_llama);
-    llama_model_free(ctx_llava->model);
-    llama_backend_free();
-}
-
-int main(int argc, char ** argv) {
-    ggml_time_init();
-
-    common_params params;
-
-    if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_LLAVA, print_usage)) {
-        return 1;
-    }
-
-    common_init();
-
-    if (params.mmproj.path.empty() || (params.image.empty() && !prompt_contains_image(params.prompt))) {
-        print_usage(argc, argv);
-        return 1;
-    }
-
-    auto * model = llava_init(&params);
-    if (model == NULL) {
-        fprintf(stderr, "%s: error: failed to init llava model\n", __func__);
-        return 1;
-    }
-
-    if (prompt_contains_image(params.prompt)) {
-        auto * ctx_llava = llava_init_context(&params, model);
-
-        auto * image_embed = load_image(ctx_llava, &params, "");
-
-        // process the prompt
-        process_prompt(ctx_llava, image_embed, &params, params.prompt);
-
-        llama_perf_context_print(ctx_llava->ctx_llama);
-        llava_image_embed_free(image_embed);
-        ctx_llava->model = NULL;
-        llava_free(ctx_llava);
-    } else {
-        for (auto & image : params.image) {
-            auto * ctx_llava = llava_init_context(&params, model);
-
-            auto * image_embed = load_image(ctx_llava, &params, image);
-            if (!image_embed) {
-                LOG_ERR("%s: failed to load image %s. Terminating\n\n", __func__, image.c_str());
-                return 1;
-            }
-
-            // process the prompt
-            process_prompt(ctx_llava, image_embed, &params, params.prompt);
-
-            llama_perf_context_print(ctx_llava->ctx_llama);
-            llava_image_embed_free(image_embed);
-            ctx_llava->model = NULL;
-            llava_free(ctx_llava);
-        }
-    }
-
-    llama_model_free(model);
-
-    return 0;
-}
diff --git a/examples/llava/minicpmv-cli.cpp b/examples/llava/minicpmv-cli.cpp
deleted file mode 100644
index 5ad970c220..0000000000
--- a/examples/llava/minicpmv-cli.cpp
+++ /dev/null
@@ -1,354 +0,0 @@
-#include "arg.h"
-#include "log.h"
-#include "common.h"
-#include "sampling.h"
-#include "clip.h"
-#include "llava.h"
-#include "llama.h"
-#include "ggml.h"
-
-#include <algorithm>
-#include <cstdio>
-#include <cstdlib>
-#include <cstring>
-#include <vector>
-#include <iostream> // TODO: remove me
-
-struct llava_context {
-    struct clip_ctx * ctx_clip = NULL;
-    struct llama_context * ctx_llama = NULL;
-    struct llama_model * model = NULL;
-};
-
-static void show_additional_info(int /*argc*/, char ** argv) {
-    LOG("\nexample usage:\n\n%s -m  --mmproj  --image  --image  [--temp 0.1] [-p \"describe the image in detail.\"]\n", argv[0]);
-    LOG("\nnote: a lower temperature value like 0.1 is recommended for better quality.\n");
-}
-
-static struct llama_model * llava_init(common_params * params) {
-    llama_backend_init();
-    llama_numa_init(params->numa);
-
-    llama_model_params model_params = common_model_params_to_llama(*params);
-
-    llama_model * model = llama_model_load_from_file(params->model.path.c_str(), model_params);
-    if (model == NULL) {
-        LOG_ERR("%s: unable to load model\n" , __func__);
-        return NULL;
-    }
-    return model;
-}
-
-static struct llava_context * llava_init_context(common_params * params, llama_model * model) {
-    auto prompt = params->prompt;
-    if (prompt.empty()) {
-        prompt = "describe the image in detail.";
-    }
-
-    llama_context_params ctx_params = common_context_params_to_llama(*params);
-    if (params->n_ctx < 2048) {
-        // warn user here, "Image processing requires at least 2048 context, setting context to 2048"
-        LOG_WRN("%s: Image processing requires at least 2048 context, setting context to 2048\n" , __func__);
-        ctx_params.n_ctx = 2048;
-    } else {
-        ctx_params.n_ctx = params->n_ctx;
-    }
-
-    llama_context * ctx_llama = llama_init_from_model(model, ctx_params);
-
-    if (ctx_llama == NULL) {
-        LOG_ERR("%s: failed to create the llama_context\n" , __func__);
-        return NULL;
-    }
-
-    auto * ctx_llava = (struct llava_context *)malloc(sizeof(llava_context));
-
-    ctx_llava->ctx_llama = ctx_llama;
-    ctx_llava->model = model;
-    return ctx_llava;
-}
-
-static void llava_free(struct llava_context * ctx_llava) {
-    if (ctx_llava->ctx_clip) {
-        clip_free(ctx_llava->ctx_clip);
-        ctx_llava->ctx_clip = NULL;
-    }
-
-    llama_free(ctx_llava->ctx_llama);
-    llama_model_free(ctx_llava->model);
-    llama_backend_free();
-}
-
-static struct clip_ctx * clip_init_context(common_params * params) {
-    const char * clip_path = params->mmproj.path.c_str();
-
-    auto prompt = params->prompt;
-    if (prompt.empty()) {
-        prompt = "describe the image in detail.";
-    }
-    struct clip_context_params clip_params = {
-        /* use_gpu */   params->n_gpu_layers != 0,
-        /* verbosity */ GGML_LOG_LEVEL_INFO, // TODO: make this configurable
-    };
-    auto * ctx_clip = clip_init(clip_path, clip_params);
-    return ctx_clip;
-}
-
-static bool eval_tokens(struct llama_context * ctx_llama, std::vector<llama_token> tokens, int n_batch, int * n_past) {
-    int N = (int) tokens.size();
-    for (int i = 0; i < N; i += n_batch) {
-        int n_eval = (int) tokens.size() - i;
-        if (n_eval > n_batch) {
-            n_eval = n_batch;
-        }
-        if (llama_decode(ctx_llama, llama_batch_get_one(&tokens[i], n_eval))) {
-            LOG_ERR("%s : failed to eval. token %d/%d (batch size %d, n_past %d)\n", __func__, i, N, n_batch, *n_past);
-            return false;
-        }
-        *n_past += n_eval;
-    }
-    return true;
-}
-
-static bool eval_id(struct llama_context * ctx_llama, int id, int * n_past) {
-    std::vector<llama_token> tokens;
-    tokens.push_back(id);
-    return eval_tokens(ctx_llama, tokens, 1, n_past);
-}
-
-static bool eval_string(struct llama_context * ctx_llama, const char* str, int n_batch, int * n_past, bool add_bos){
-    std::string              str2     = str;
-    std::vector<llama_token> embd_inp = common_tokenize(ctx_llama, str2, add_bos, true);
-    return eval_tokens(ctx_llama, embd_inp, n_batch, n_past);
-}
-
-static void process_eval_image_embed(struct llava_context * ctx_llava, const struct llava_image_embed * embeds, int n_batch, int * n_past, int idx) {
-    float * image_embed = (float *)malloc(clip_embd_nbytes(ctx_llava->ctx_clip));
-    std::memcpy(image_embed, embeds->embed + idx * clip_n_patches(ctx_llava->ctx_clip) * clip_n_mmproj_embd(ctx_llava->ctx_clip), clip_embd_nbytes(ctx_llava->ctx_clip));
-
-    auto * slice_embed = (llava_image_embed*)malloc(sizeof(llava_image_embed));
-    slice_embed->embed = image_embed;
-    slice_embed->n_image_pos = clip_n_patches(ctx_llava->ctx_clip);
-    llava_eval_image_embed(ctx_llava->ctx_llama, slice_embed, n_batch, n_past);
-    llava_image_embed_free(slice_embed);
-}
-
-static void process_image(struct llava_context * ctx_llava, struct llava_image_embed * embeds, common_params * params, int &n_past) {
-    std::string system_prompt;
-    int idx = 0;
-    int num_image_embeds = embeds->n_image_pos / clip_n_patches(ctx_llava->ctx_clip);
-    int has_minicpmv_projector = clip_is_minicpmv(ctx_llava->ctx_clip);
-    if (has_minicpmv_projector == 2) {
-        system_prompt = "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n";
-    }
-    else if (has_minicpmv_projector == 3) {
-        system_prompt = "<|im_start|>user\n";
-    }
-    else if (has_minicpmv_projector == 4) {
-        system_prompt = "<|im_start|>user\n";
-    }
-    LOG_INF("%s: image token past: %d\n", __func__, n_past);
-    eval_string(ctx_llava->ctx_llama, (system_prompt+"<image>").c_str(), params->n_batch, &n_past, false);
-    process_eval_image_embed(ctx_llava, embeds, params->n_batch, &n_past, idx++);
-    eval_string(ctx_llava->ctx_llama, std::string("</image>").c_str(), params->n_batch, &n_past, false);
-    if (num_image_embeds > 1) {
-        if (has_minicpmv_projector == 2) {
-            size_t num_image_embeds_col = clip_uhd_num_image_embeds_col(ctx_llava->ctx_clip);
-            eval_string(ctx_llava->ctx_llama, std::string("<slice>").c_str(), params->n_batch, &n_past, false);
-            for (size_t i = 0; i < (num_image_embeds-1)/num_image_embeds_col; ++i) {
-                for (size_t j = 0; j < num_image_embeds_col; ++j) {
-                    eval_string(ctx_llava->ctx_llama, std::string("<image>").c_str(), params->n_batch, &n_past, false);
-                    process_eval_image_embed(ctx_llava, embeds, params->n_batch, &n_past, idx++);
-                    eval_string(ctx_llava->ctx_llama, std::string("</image>").c_str(), params->n_batch, &n_past, false);
-                    if (j == num_image_embeds_col - 1) {
-                        eval_string(ctx_llava->ctx_llama, std::string("\n").c_str(), params->n_batch, &n_past, false);
-                    }
-                }
-            }
-            eval_string(ctx_llava->ctx_llama, std::string("</slice>").c_str(), params->n_batch, &n_past, false);
-        }
-        else if (has_minicpmv_projector == 3 || has_minicpmv_projector == 4) {
-            size_t num_image_embeds_col = clip_uhd_num_image_embeds_col(ctx_llava->ctx_clip);
-            for (size_t i = 0; i < (num_image_embeds-1)/num_image_embeds_col; ++i) {
-                for (size_t j = 0; j < num_image_embeds_col; ++j) {
-                    eval_string(ctx_llava->ctx_llama, std::string("<slice>").c_str(), params->n_batch, &n_past, false);
-                    process_eval_image_embed(ctx_llava, embeds, params->n_batch, &n_past, idx++);
-                    eval_string(ctx_llava->ctx_llama, std::string("</slice>").c_str(), params->n_batch, &n_past, false);
-                    if (j == num_image_embeds_col - 1) {
-                        eval_string(ctx_llava->ctx_llama, std::string("\n").c_str(), params->n_batch, &n_past, false);
-                    }
-                }
-            }
-        }
-    }
-    LOG_INF("%s: image token past: %d\n", __func__, n_past);
-}
-
-static const char * sample(struct common_sampler * smpl,
-                           struct llama_context * ctx_llama,
-                           int * n_past) {
-    const llama_token id = common_sampler_sample(smpl, ctx_llama, -1);
-    common_sampler_accept(smpl, id, true);
-
-    const llama_model * model = llama_get_model(ctx_llama);
-    const llama_vocab * vocab = llama_model_get_vocab(model);
-
-    static std::string ret;
-    if (llama_vocab_is_eog(vocab, id)) {
-        ret = "";
-    } else {
-        ret = common_token_to_piece(ctx_llama, id);
-    }
-    eval_id(ctx_llama, id, n_past);
-    return ret.c_str();
-}
-
-static struct llava_context * minicpmv_init(common_params * params, const std::string & fname, int &n_past){
-    auto * ctx_clip = clip_init_context(params);
-    auto * embeds = llava_image_embed_make_with_filename(ctx_clip, params->cpuparams.n_threads, fname.c_str());
-    if (!embeds) {
-        LOG_ERR("failed to load image %s. Terminating\n\n", fname.c_str());
-        return NULL;
-    }
-
-    // process the prompt
-    if (params->prompt.empty() && params->interactive == false) {
-        LOG_ERR("prompt should be given or interactive mode should be on");
-        return NULL;
-    }
-
-    auto * model = llava_init(params);
-    if (model == NULL) {
-        fprintf(stderr, "%s: error: failed to init minicpmv model\n", __func__);
-        return NULL;
-    }
-    const int64_t t_llava_init_start_us = ggml_time_us();
-    auto * ctx_llava = llava_init_context(params, model);
-    ctx_llava->ctx_clip = ctx_clip;
-    const int64_t t_llava_init_end_us = ggml_time_us();
-    float t_llava_init_ms = (t_llava_init_end_us - t_llava_init_start_us) / 1000.0;
-    LOG_INF("%s: llava init in %8.2f ms.\n", __func__, t_llava_init_ms);
-
-    const int64_t t_process_image_start_us = ggml_time_us();
-    process_image(ctx_llava, embeds, params, n_past);
-    const int64_t t_process_image_end_us = ggml_time_us();
-    float t_process_image_ms = (t_process_image_end_us - t_process_image_start_us) / 1000.0;
-    LOG_INF("%s: llama process image in %8.2f ms.\n", __func__, t_process_image_ms);
-
-    llava_image_embed_free(embeds);
-    return ctx_llava;
-}
-
-static struct common_sampler * llama_init(struct llava_context * ctx_llava, common_params * params, const std::string & prompt, int & n_past, bool is_first = false){
-    std::string user_prompt = prompt;
-    int has_minicpmv_projector = clip_is_minicpmv(ctx_llava->ctx_clip);
-    if (!is_first) {
-        if (has_minicpmv_projector == 2) {
-            user_prompt = "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n" + prompt;
-        }
-        else if (has_minicpmv_projector == 3) {
-            user_prompt = "<|im_start|>user\n" + prompt;
-        }
-        else if (has_minicpmv_projector == 4) {
-            user_prompt = "<|im_start|>user\n" + prompt;
-        }
-    }
-
-    eval_string(ctx_llava->ctx_llama, user_prompt.c_str(), params->n_batch, &n_past, false);
-    if (has_minicpmv_projector == 2) {
-        eval_string(ctx_llava->ctx_llama, "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", params->n_batch, &n_past, false);
-    }
-    else if (has_minicpmv_projector == 3) {
-        eval_string(ctx_llava->ctx_llama, "<|im_end|><|im_start|>assistant\n", params->n_batch, &n_past, false);
-    }
-    else if (has_minicpmv_projector == 4) {
-        eval_string(ctx_llava->ctx_llama, "<|im_end|><|im_start|>assistant\n", params->n_batch, &n_past, false);
-    }
-
-    // generate the response
-
-    LOG_INF("\n");
-
-    struct common_sampler * smpl = common_sampler_init(ctx_llava->model, params->sampling);
-    return smpl;
-}
-
-static const char * llama_loop(struct llava_context * ctx_llava,struct common_sampler * smpl, int &n_past){
-
-    const char * tmp = sample(smpl, ctx_llava->ctx_llama, &n_past);
-    return tmp;
-}
-
-int main(int argc, char ** argv) {
-    ggml_time_init();
-
-    common_params params;
-
-    if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_LLAVA, show_additional_info)) {
-        return 1;
-    }
-
-    common_init();
-
-    if (params.mmproj.path.empty() || (params.image.empty())) {
-        show_additional_info(argc, argv);
-        return 1;
-    }
-
-    for (auto & image : params.image) {
-        int n_past = 0;
-        auto * ctx_llava = minicpmv_init(&params, image, n_past);
-
-        if (!params.prompt.empty()) {
-            LOG("%s\n", params.prompt.c_str());
-            LOG("");
-            auto * smpl = llama_init(ctx_llava, ¶ms, params.prompt, n_past, true);
-            const int max_tgt_len = params.n_predict < 0 ? 256 : params.n_predict;
-            std::string response;
-            bool have_tmp = false;
-            for (int i = 0; i < max_tgt_len; i++) {
-                const auto * tmp = llama_loop(ctx_llava, smpl, n_past);
-                response += tmp;
-                if (strcmp(tmp, "") == 0){
-                    if (!have_tmp) {
-                        continue;
-                    }
-                    break;
-                }
-                if (strstr(tmp, "###")) break; // Yi-VL behavior
-                have_tmp = true;
-                printf("%s", tmp);
-                if (strstr(response.c_str(), "<user>")) break; // minicpm-v
-
-                fflush(stdout);
-            }
-            common_sampler_free(smpl);
-        }else {
-            while (true) {
-                LOG("");
-                std::string prompt;
-                std::getline(std::cin, prompt);
-                LOG("");
-                auto * smpl = llama_init(ctx_llava, ¶ms, prompt, n_past, true);
-                const int max_tgt_len = params.n_predict < 0 ? 256 : params.n_predict;
-                std::string response;
-                for (int i = 0; i < max_tgt_len; i++) {
-                    const auto * tmp = llama_loop(ctx_llava, smpl, n_past);
-                    response += tmp;
-                    if (strcmp(tmp, "") == 0) break;
-                    printf("%s", tmp);// mistral llava-1.6
-                    if (strstr(response.c_str(), "")) break; // minicpm-v
-                    fflush(stdout);
-                }
-                common_sampler_free(smpl);
-            }
-        }
-        printf("\n");
-        llama_perf_context_print(ctx_llava->ctx_llama);
-
-        ctx_llava->model = NULL;
-        llava_free(ctx_llava);
-    }
-
-    return 0;
-}
diff --git a/examples/llava/gemma3-cli.cpp b/examples/llava/mtmd-cli.cpp
similarity index 82%
rename from examples/llava/gemma3-cli.cpp
rename to examples/llava/mtmd-cli.cpp
index 3d56647506..e80845a2c5 100644
--- a/examples/llava/gemma3-cli.cpp
+++ b/examples/llava/mtmd-cli.cpp
@@ -28,15 +28,16 @@ static bool g_is_generating = false;
 
 /**
  * Please note that this is NOT a production-ready stuff.
- * It is a playground for trying Gemma 3 vision capabilities.
+ * It is a playground for trying multimodal support in llama.cpp.
  * For contributors: please keep this code simple and easy to understand.
  */
 
 static void show_additional_info(int /*argc*/, char ** argv) {
     LOG(
-        "Experimental CLI for using Gemma 3 vision model\n\n"
+        "Experimental CLI for multimodal\n\n"
         "Usage: %s [options] -m  --mmproj  --image  -p \n\n"
         "  -m and --mmproj are required\n"
+        "  -hf user/repo can replace both -m and --mmproj in most cases\n"
         "  --image and -p are optional, if NOT provided, the CLI will run in chat mode\n",
         argv[0]
     );
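
Note: with the new -hf shorthand, a typical single-turn invocation (model reference taken from tests.sh later in this patch) would look like `llama-mtmd-cli -hf ggml-org/gemma-3-4b-it-GGUF:Q4_K_M --image test-1.jpeg -p "what is in this image?" --temp 0`; the model and its mmproj are then resolved from the Hugging Face repo instead of being passed explicitly via -m and --mmproj.
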
@@ -56,7 +57,7 @@ static void sigint_handler(int signo) {
 }
 #endif
 
-struct gemma3_context {
+struct mtmd_cli_context {
     mtmd_context_ptr ctx_vision;
     common_init_result llama_init;
 
@@ -70,18 +71,38 @@ struct gemma3_context {
     // so here we don't need to keep track of chat history
     common_chat_templates_ptr tmpls;
 
+    // support for legacy templates (models not having EOT token)
+    llama_tokens antiprompt_tokens;
+
     int n_threads    = 1;
     llama_pos n_past = 0;
 
-    gemma3_context(common_params & params) : llama_init(common_init_from_params(params)) {
+    mtmd_cli_context(common_params & params) : llama_init(common_init_from_params(params)) {
         model = llama_init.model.get();
         lctx = llama_init.context.get();
         vocab = llama_model_get_vocab(model);
         n_threads = params.cpuparams.n_threads;
         batch = llama_batch_init(params.n_batch, 0, 1);
         n_batch = params.n_batch;
+
+        if (!llama_model_chat_template(model, nullptr) && params.chat_template.empty()) {
+            LOG_ERR("Model does not have chat template.\n");
+            LOG_ERR("  For old llava models, you may need to use '--chat-template vicuna'\n");
+            LOG_ERR("  For MobileVLM models, use '--chat-template deepseek'\n");
+            exit(1);
+        }
+
         tmpls = common_chat_templates_init(model, params.chat_template);
+        LOG_INF("%s: chat template example:\n%s\n", __func__, common_chat_format_example(tmpls.get(), params.use_jinja).c_str());
+
         init_vision_context(params);
+
+        // load antiprompt tokens for legacy templates
+        if (params.chat_template == "vicuna") {
+            antiprompt_tokens = common_tokenize(lctx, "ASSISTANT:", false, true);
+        } else if (params.chat_template == "deepseek") {
+            antiprompt_tokens = common_tokenize(lctx, "###", false, true);
+        }
     }
 
     void init_vision_context(common_params & params) {
@@ -97,6 +118,17 @@ struct gemma3_context {
             exit(1);
         }
     }
+
+    bool check_antiprompt(const llama_tokens & generated_tokens) {
+        if (antiprompt_tokens.empty() || generated_tokens.size() < antiprompt_tokens.size()) {
+            return false;
+        }
+        return std::equal(
+            generated_tokens.end() - antiprompt_tokens.size(),
+            generated_tokens.end(),
+            antiprompt_tokens.begin()
+        );
+    }
 };
 
 struct decode_embd_batch {
@@ -132,7 +164,8 @@ struct decode_embd_batch {
     }
 };
 
-static int generate_response(gemma3_context & ctx, common_sampler * smpl, int n_predict) {
+static int generate_response(mtmd_cli_context & ctx, common_sampler * smpl, int n_predict) {
+    llama_tokens generated_tokens;
     for (int i = 0; i < n_predict; i++) {
         if (i > n_predict || !g_is_generating) {
             printf("\n");
@@ -140,9 +173,10 @@ static int generate_response(gemma3_context & ctx, common_sampler * smpl, int n_
         }
 
         llama_token token_id = common_sampler_sample(smpl, ctx.lctx, -1);
+        generated_tokens.push_back(token_id);
         common_sampler_accept(smpl, token_id, true);
 
-        if (llama_vocab_is_eog(ctx.vocab, token_id)) {
+        if (llama_vocab_is_eog(ctx.vocab, token_id) || ctx.check_antiprompt(generated_tokens)) {
             printf("\n");
             break; // end of generation
         }
@@ -161,7 +195,7 @@ static int generate_response(gemma3_context & ctx, common_sampler * smpl, int n_
     return 0;
 }
 
-static int eval_message(gemma3_context & ctx, common_chat_msg & msg, std::vector<std::string> & images_fname, bool add_bos = false) {
+static int eval_message(mtmd_cli_context & ctx, common_chat_msg & msg, std::vector<std::string> & images_fname, bool add_bos = false) {
     std::vector<mtmd_bitmap> bitmaps;
 
     common_chat_templates_inputs tmpl_inputs;
@@ -218,7 +252,7 @@ int main(int argc, char ** argv) {
         return 1;
     }
 
-    gemma3_context ctx(params);
+    mtmd_cli_context ctx(params);
     printf("%s: %s\n", __func__, params.model.path.c_str());
 
     bool is_single_turn = !params.prompt.empty() && !params.image.empty();
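
Note: the antiprompt mechanism added above stops generation when the tail of the generated tokens equals the tokenized antiprompt (the tokenization of "ASSISTANT:" for vicuna, "###" for deepseek). A minimal standalone sketch of the same tail comparison, with int32_t standing in for llama_token:

    #include <algorithm>
    #include <cstdint>
    #include <vector>

    using token = int32_t; // stand-in for llama_token

    // true when the last antiprompt.size() generated tokens match the antiprompt
    static bool ends_with_antiprompt(const std::vector<token> & generated,
                                     const std::vector<token> & antiprompt) {
        if (antiprompt.empty() || generated.size() < antiprompt.size()) {
            return false;
        }
        return std::equal(generated.end() - antiprompt.size(),
                          generated.end(),
                          antiprompt.begin());
    }

Because only the tail is compared on every step, the check stays O(K) per generated token for an antiprompt of K tokens.
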
diff --git a/examples/llava/mtmd.cpp b/examples/llava/mtmd.cpp
index 3fd5bebc6a..8866c12e4d 100644
--- a/examples/llava/mtmd.cpp
+++ b/examples/llava/mtmd.cpp
@@ -12,6 +12,15 @@
 #include 
 #include 
 
+// slice template, used by some llava-uhd models to correctly place the special tokens around image embeddings
+// models not having it (llava-1.6) will process embeddings without any special tokens in-between
+enum mtmd_slice_tmpl {
+    MTMD_SLICE_TMPL_NONE,
+    MTMD_SLICE_TMPL_MINICPMV_2_5,
+    MTMD_SLICE_TMPL_MINICPMV_2_6,
+    // TODO @ngxson : add support for idefics (SmolVLM)
+};
+
 struct mtmd_context {
     struct clip_ctx * ctx_clip;
     const struct llama_model * text_model;
@@ -21,6 +30,16 @@ struct mtmd_context {
     int n_threads;
     std::string image_marker;
 
+    // for minicpmv, we need special tokens in-between slices
+    mtmd_slice_tmpl slice_tmpl    = MTMD_SLICE_TMPL_NONE;
+    llama_token tok_ov_img_start  = LLAMA_TOKEN_NULL; // overview image
+    llama_token tok_ov_img_end    = LLAMA_TOKEN_NULL; // overview image
+    llama_token tok_slices_start  = LLAMA_TOKEN_NULL; // start of all slices
+    llama_token tok_slices_end    = LLAMA_TOKEN_NULL; // end of all slices
+    llama_token tok_sli_img_start = LLAMA_TOKEN_NULL; // single slice
+    llama_token tok_sli_img_end   = LLAMA_TOKEN_NULL; // single slice
+    llama_token tok_row_end       = LLAMA_TOKEN_NULL; // end of row
+
     // TODO @ngxson : add timings
 
     mtmd_context(const char * mmproj_fname,
@@ -38,11 +57,66 @@ struct mtmd_context {
             throw std::runtime_error(string_format("Failed to load CLIP model from %s\n", mmproj_fname));
         }
         this->text_model = text_model;
+
+        GGML_ASSERT(!clip_is_qwen2vl(ctx_clip) && "Qwen2VL model is not supported yet, use llama-qwen2vl-cli instead");
+
+        int minicpmv_version = clip_is_minicpmv(ctx_clip);
+        if (minicpmv_version == 2) {
+            // minicpmv 2.5 format:
+            //  <image> (overview) </image> <slice> <image> (slice) </image> <image> (slice) </image> \n ... </slice>
+            slice_tmpl        = MTMD_SLICE_TMPL_MINICPMV_2_5;
+            tok_ov_img_start  = lookup_token("<image>");
+            tok_ov_img_end    = lookup_token("</image>");
+            tok_slices_start  = lookup_token("<slice>");
+            tok_slices_end    = lookup_token("</slice>");
+            tok_sli_img_start = tok_ov_img_start;
+            tok_sli_img_end   = tok_ov_img_end;
+            tok_row_end       = lookup_token("\n");
+
+        } else if (minicpmv_version == 3 || minicpmv_version == 4) {
+            // minicpmv 2.6 format:
+            //  <image> (overview) </image> <slice> (slice) </slice> <slice> (slice) </slice> \n ...
+            slice_tmpl        = MTMD_SLICE_TMPL_MINICPMV_2_6;
+            tok_ov_img_start  = lookup_token("<image>");
+            tok_ov_img_end    = lookup_token("</image>");
+            tok_sli_img_start = lookup_token("<slice>");
+            tok_sli_img_end   = lookup_token("</slice>");
+            tok_row_end       = lookup_token("\n");
+
+        } else if (minicpmv_version != 0) {
+            GGML_ASSERT(false && "unsupported minicpmv version");
+        }
     }
 
     ~mtmd_context() {
         clip_free(ctx_clip);
     }
+
+private:
+    llama_token lookup_token(const std::string & token_text) {
+        const llama_vocab * vocab = llama_model_get_vocab(text_model);
+        const int n_vocab = llama_vocab_n_tokens(vocab);
+        for (int i = 0; i < n_vocab; i++) {
+            if (token_to_piece(vocab, i, true) == token_text) {
+                return i;
+            }
+        }
+        return LLAMA_TOKEN_NULL;
+    }
+
+    std::string token_to_piece(const llama_vocab * vocab, llama_token token, bool special) {
+        std::string piece;
+        piece.resize(piece.capacity());  // using string internal cache, 15 bytes + '\0'
+        const int n_chars = llama_token_to_piece(vocab, token, &piece[0], piece.size(), 0, special);
+        if (n_chars < 0) {
+            piece.resize(-n_chars);
+            int check = llama_token_to_piece(vocab, token, &piece[0], piece.size(), 0, special);
+            GGML_ASSERT(check == -n_chars);
+        } else {
+            piece.resize(n_chars);
+        }
+        return piece;
+    }
 };
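
Note: llama_token_to_piece follows the usual two-pass C API convention: it returns the number of bytes written, or a negative count whose magnitude is the required buffer size. The token_to_piece helper above first offers the string's small-string capacity and only grows on a reported shortfall. A self-contained sketch of that convention against a hypothetical to_piece(buf, size) stand-in (not the real llama.cpp API):

    #include <cstring>
    #include <string>

    // hypothetical stand-in: returns bytes written, or -needed if buf is too small
    static int to_piece(char * buf, int size) {
        static const char text[] = "example-piece";
        const int needed = (int) sizeof(text) - 1;
        if (size < needed) {
            return -needed;
        }
        std::memcpy(buf, text, needed);
        return needed;
    }

    static std::string read_piece() {
        std::string piece;
        piece.resize(piece.capacity());              // reuse the SSO buffer first
        int n = to_piece(&piece[0], (int) piece.size());
        if (n < 0) {
            piece.resize(-n);                        // grow to the reported size
            n = to_piece(&piece[0], (int) piece.size());
        }
        piece.resize(n);                             // shrink to the actual length
        return piece;
    }
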
 
 struct mtmd_image_tokens_data {
@@ -102,21 +176,58 @@ int32_t mtmd_tokenize(mtmd_context * ctx,
 
     std::string prompt_modified(text.text);
     std::string marker_modified(ctx->image_marker);
-    projector_type proj_type = clip_get_projector_type(ctx->ctx_clip);
     // a bit hacky here, but works for now
     // for some models, we need to add prefix and suffix to the image embeddings
-    if (proj_type == PROJECTOR_TYPE_GEMMA3) {
+    if (clip_is_gemma3(ctx->ctx_clip)) {
+        // gemma 3
         //  <start_of_image> ... (image embeddings) ... <end_of_image>
         marker_modified = "<start_of_image>" + ctx->image_marker + "<end_of_image>";
         string_replace_all(prompt_modified, ctx->image_marker, marker_modified);
     }
 
+    // llava-1.5, llava-1.6, Yi-VL, Yi-34B, granite: don't need to add prefix and suffix
+    // for glm-edge, we don't need to add because the tokens are already in the returned embeddings
+
+    // TODO @ngxson : glm-edge : remove BOI / EOI tokens embeddings, decode them as normal tokens
+
     std::vector<std::string> parts = string_split_str(prompt_modified, ctx->image_marker);
     output.clear();
     output.reserve(parts.size());
 
     size_t i_img = 0;
 
+    // utility for adding raw tokens
+    auto add_text_chunk = [&output](std::vector<llama_token> && tokens) {
+        mtmd_input_chunk chunk{
+            MTMD_INPUT_CHUNK_TYPE_TEXT,
+            std::move(tokens),
+            {},
+        };
+        output.emplace_back(std::move(chunk));
+    };
+
+    // utility for splitting batch of multiple images into chunks of batch having single images
+    auto split_batch_to_chunk = [&ctx](clip_image_f32_batch && batch_f32, const std::string & id) {
+        std::vector<mtmd_input_chunk> chunks;
+
+        for (auto & entry : batch_f32.entries) {
+            mtmd_image_tokens_ptr image_tokens(new mtmd_image_tokens);
+            image_tokens->nx = clip_n_patches(ctx->ctx_clip);
+            image_tokens->ny = 1;
+            image_tokens->batch_f32.entries.push_back(std::move(entry));
+            image_tokens->id = id;
+
+            mtmd_input_chunk chunk{
+                MTMD_INPUT_CHUNK_TYPE_IMAGE,
+                {},
+                std::move(image_tokens),
+            };
+            chunks.emplace_back(std::move(chunk));
+        }
+
+        return chunks;
+    };
+
     for (const auto & part : parts) {
         //printf("tokenizing part: %s\n", part.c_str());
         bool add_bos = &parts.front() == ∂
@@ -139,12 +250,13 @@ int32_t mtmd_tokenize(mtmd_context * ctx,
                 return 1;
             }
 
-            // shim layer
+            // convert mtmd_bitmap to clip_image_u8
             clip_image_u8_ptr img_u8(clip_image_u8_init());
             img_u8->nx = bitmaps[i_img].nx;
             img_u8->ny = bitmaps[i_img].ny;
             img_u8->buf.resize(bitmaps[i_img].data.size());
             std::memcpy(img_u8->buf.data(), bitmaps[i_img].data.data(), img_u8->nx * img_u8->ny * 3);
+            clip_image_size img_u8_size{img_u8->nx, img_u8->ny};
 
             // preprocess image
             clip_image_f32_batch batch_f32;
@@ -154,19 +266,70 @@ int32_t mtmd_tokenize(mtmd_context * ctx,
                 return 2;
             }
 
-            mtmd_image_tokens_ptr image_tokens(new mtmd_image_tokens);
-            image_tokens->nx = clip_n_patches(ctx->ctx_clip); // TODO @ngxson : use clip_n_patches_by_image
-            image_tokens->ny = 1; // TODO
-            image_tokens->batch_f32 = std::move(batch_f32);
-            image_tokens->id = bitmaps[i_img].id; // optional
+            if (ctx->slice_tmpl == MTMD_SLICE_TMPL_MINICPMV_2_5 || ctx->slice_tmpl == MTMD_SLICE_TMPL_MINICPMV_2_6) {
+                // split batch into chunks of single images
+                auto chunks = split_batch_to_chunk(std::move(batch_f32), bitmaps[i_img].id);
+                GGML_ASSERT(chunks.size() > 0);
 
-            mtmd_input_chunk chunk{
-                MTMD_INPUT_CHUNK_TYPE_IMAGE,
-                {},
-                std::move(image_tokens),
-            };
-            output.emplace_back(std::move(chunk));
-            i_img++;
+                // add overview image
+                add_text_chunk({ctx->tok_ov_img_start});
+                output.emplace_back(std::move(chunks.front()));
+                chunks.erase(chunks.begin());
+                add_text_chunk({ctx->tok_ov_img_end});
+
+                // add slices
+                if (!chunks.empty()) {
+                    clip_add_load_image_size(ctx->ctx_clip, &img_u8_size);
+                    int n_col = clip_uhd_num_image_embeds_col(ctx->ctx_clip);
+                    int n_row = (int)chunks.size() / n_col;
+                    GGML_ASSERT(n_row * n_col == (int)chunks.size());
+                    if (ctx->tok_slices_start != LLAMA_TOKEN_NULL) {
+                        add_text_chunk({ctx->tok_slices_start});
+                    }
+                    for (int y = 0; y < n_row; y++) {
+                        for (int x = 0; x < n_col; x++) {
+                            if (ctx->tok_sli_img_start != LLAMA_TOKEN_NULL) {
+                                add_text_chunk({ctx->tok_sli_img_start});
+                            }
+                            output.emplace_back(std::move(chunks[y * n_col + x]));
+                            if (ctx->tok_sli_img_end != LLAMA_TOKEN_NULL) {
+                                add_text_chunk({ctx->tok_sli_img_end});
+                            }
+                        }
+                        if (ctx->tok_row_end != LLAMA_TOKEN_NULL && y != n_row - 1) {
+                            add_text_chunk({ctx->tok_row_end});
+                        }
+                    }
+                    if (ctx->tok_slices_end != LLAMA_TOKEN_NULL) {
+                        add_text_chunk({ctx->tok_slices_end});
+                    }
+                }
+
+            } else {
+                mtmd_image_tokens_ptr image_tokens(new mtmd_image_tokens);
+                image_tokens->nx = clip_n_patches(ctx->ctx_clip) * batch_f32.entries.size(); // TODO @ngxson : use clip_n_patches_by_image
+                image_tokens->ny = 1; // TODO
+                image_tokens->batch_f32 = std::move(batch_f32);
+                image_tokens->id = bitmaps[i_img].id; // optional
+
+                LOG_DBG("image_tokens->nx = %d\n", image_tokens->nx);
+                LOG_DBG("image_tokens->ny = %d\n", image_tokens->ny);
+                LOG_DBG("batch_f32 size = %d\n", (int)image_tokens->batch_f32.entries.size());
+
+                if (clip_is_glm(ctx->ctx_clip)) {
+                    // glm-edge
+                    image_tokens->nx += 2; // add 2 for the begin_of_image and end_of_image token embeddings
+                }
+
+                mtmd_input_chunk chunk{
+                    MTMD_INPUT_CHUNK_TYPE_IMAGE,
+                    {},
+                    std::move(image_tokens),
+                };
+                output.emplace_back(std::move(chunk));
+            }
+
+            i_img++; // move to next image
         }
     }
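
Note: to make the slice template concrete, for a MiniCPM-V 2.5 image preprocessed into one overview plus a 2x2 grid of slices, the loop above emits chunks in the following order (an illustration using the tokens resolved in mtmd_context; the row separator "\n" is only emitted between rows, never after the last one):

    <image> [overview embeddings] </image>
    <slice>
    <image> [slice 0,0] </image> <image> [slice 0,1] </image> "\n"
    <image> [slice 1,0] </image> <image> [slice 1,1] </image>
    </slice>
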
 
@@ -198,11 +361,35 @@ std::string mtmd_image_tokens_get_id(const mtmd_image_tokens * image_tokens) {
 int32_t mtmd_encode(mtmd_context * ctx, const mtmd_image_tokens * image_tokens) {
     int n_mmproj_embd = clip_n_mmproj_embd(ctx->ctx_clip);
     ctx->image_embd_v.resize(image_tokens->n_tokens() * n_mmproj_embd);
-    bool ok = clip_image_batch_encode(
-        ctx->ctx_clip,
-        ctx->n_threads,
-        &image_tokens->batch_f32,
-        ctx->image_embd_v.data());
+    bool ok = false;
+
+    // only effective for minicpmv and qwen2vl, other models will ignore load_image_size
+    {
+        clip_image_size slice_size{
+            image_tokens->batch_f32.entries[0]->nx,
+            image_tokens->batch_f32.entries[0]->ny};
+        clip_add_load_image_size(ctx->ctx_clip, &slice_size);
+    }
+
+    if (clip_is_llava(ctx->ctx_clip) || clip_is_minicpmv(ctx->ctx_clip) || clip_is_glm(ctx->ctx_clip)) {
+        // TODO @ngxson : llava does not support batched encoding ; this should be fixed inside clip_image_batch_encode()
+        const auto & entries = image_tokens->batch_f32.entries;
+        for (size_t i = 0; i < entries.size(); i++) {
+            int n_tokens_per_image = clip_n_patches(ctx->ctx_clip);
+            ok = clip_image_encode(
+                ctx->ctx_clip,
+                ctx->n_threads,
+                entries[i].get(),
+                ctx->image_embd_v.data() + i*n_mmproj_embd*n_tokens_per_image);
+        }
+    } else {
+        ok = clip_image_batch_encode(
+            ctx->ctx_clip,
+            ctx->n_threads,
+            &image_tokens->batch_f32,
+            ctx->image_embd_v.data());
+    }
+
     return ok ? 0 : 1;
 }
 
@@ -268,28 +455,31 @@ int32_t mtmd_helper_eval(mtmd_context * ctx,
     int32_t ret;
     llama_pos n_past = pos0;
     llama_batch text_batch = llama_batch_init(n_batch, 0, 1);
+    int n_mmproj_embd = clip_n_mmproj_embd(ctx->ctx_clip);
 
     for (auto & chunk : chunks) {
         bool is_last = &chunk == &chunks.back();
         if (chunk.type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
-            // TODO @ngxson : may need to split into smaller batches
             text_batch.n_tokens = chunk.tokens_text.size();
-            for (size_t i = 0; i < chunk.tokens_text.size(); i++) {
-                text_batch.token   [i]    = chunk.tokens_text[i];
-                text_batch.pos     [i]    = n_past++;
-                text_batch.n_seq_id[i]    = 1;
-                text_batch.seq_id  [i][0] = seq_id;
-                text_batch.logits  [i]    = false;
-            }
-            if (is_last) {
-                // always get logits for last input chunk
-                text_batch.logits[text_batch.n_tokens - 1] = true;
-            }
-            ret = llama_decode(lctx, text_batch);
-            if (ret != 0) {
-                LOG_ERR("failed to decode text\n");
-                llama_batch_free(text_batch);
-                return ret;
+            size_t i = 0;
+            while (i < chunk.tokens_text.size()) { // split into batches
+                for (; i < chunk.tokens_text.size() && text_batch.n_tokens < n_batch; i++) {
+                    text_batch.token   [i]    = chunk.tokens_text[i];
+                    text_batch.pos     [i]    = n_past++;
+                    text_batch.n_seq_id[i]    = 1;
+                    text_batch.seq_id  [i][0] = seq_id;
+                    text_batch.logits  [i]    = false;
+                }
+                if (is_last) {
+                    // always get logits for last input chunk
+                    text_batch.logits[text_batch.n_tokens - 1] = true;
+                }
+                ret = llama_decode(lctx, text_batch);
+                if (ret != 0) {
+                    LOG_ERR("failed to decode text\n");
+                    llama_batch_free(text_batch);
+                    return ret;
+                }
             }
 
         } else if (chunk.type == MTMD_INPUT_CHUNK_TYPE_IMAGE) {
@@ -297,7 +487,7 @@ int32_t mtmd_helper_eval(mtmd_context * ctx,
             GGML_ASSERT(chunk.tokens_image != nullptr);
             int64_t t0 = ggml_time_ms();
             if (ctx->print_timings) {
-                LOG_INF("encoding image...\n");
+                LOG_INF("encoding image or slice...\n");
             }
             ret = mtmd_encode(ctx, chunk.tokens_image.get());
             if (ret != 0) {
@@ -306,24 +496,47 @@ int32_t mtmd_helper_eval(mtmd_context * ctx,
                 return ret;
             }
             if (ctx->print_timings) {
-                LOG_INF("image encoded in %" PRId64 " ms\n", ggml_time_ms() - t0);
+                LOG_INF("image/slice encoded in %" PRId64 " ms\n", ggml_time_ms() - t0);
             }
 
             int32_t n_tokens = mtmd_image_tokens_get_n_tokens(chunk.tokens_image.get());
+            int32_t i_batch = 0;
+            int32_t n_img_batches = GGML_PAD(n_tokens, n_batch) / n_batch;
             float * embd = mtmd_get_output_embd(ctx);
-            decode_embd_batch batch_img(embd, n_tokens, n_past, 0);
-            int64_t t1 = ggml_time_ms();
-            ret = llama_decode(lctx, batch_img.batch);
-            if (ret != 0) {
-                LOG_ERR("failed to decode image\n");
-                llama_batch_free(text_batch);
-                return ret;
-            }
-            if (ctx->print_timings) {
-                LOG_INF("image decoded in %" PRId64 " ms\n", ggml_time_ms() - t1);
+
+            if (mtmd_decode_use_non_causal(ctx)) {
+                llama_set_causal_attn(lctx, false);
+                // TODO @ngxson : need to make sure only one image is processed at a time, and n_ubatch must be enough to hold the image
             }
 
-            n_past += n_tokens;
+            while (i_batch < n_img_batches) { // split into batches
+                int32_t pos_offset = i_batch*n_batch;
+                int32_t n_tokens_batch = std::min(n_batch, n_tokens - pos_offset);
+                float * embd_batch = embd + pos_offset*n_mmproj_embd;
+                decode_embd_batch batch_img(embd_batch, n_tokens_batch, n_past, 0);
+
+                printf("decoding image batch %d/%d, n_tokens_batch = %d\n", i_batch+1, n_img_batches, n_tokens_batch);
+
+                int64_t t1 = ggml_time_ms();
+                ret = llama_decode(lctx, batch_img.batch);
+                if (ret != 0) {
+                    LOG_ERR("failed to decode image\n");
+                    llama_set_causal_attn(lctx, true); // restore causal attn
+                    llama_batch_free(text_batch);
+                    return ret;
+                }
+
+                if (ctx->print_timings) {
+                    LOG_INF("image decoded (batch %d/%d) in %" PRId64 " ms\n", i_batch+1, n_img_batches, ggml_time_ms() - t1);
+                }
+
+                i_batch++;
+                n_past += n_tokens_batch;
+            }
+
+            if (mtmd_decode_use_non_causal(ctx)) {
+                llama_set_causal_attn(lctx, true);
+            }
 
         } else {
             GGML_ASSERT(false && "chunk type not supported");
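
Note: n_img_batches above is a ceiling division in disguise: GGML_PAD(x, n) rounds x up to the next multiple of n, so GGML_PAD(n_tokens, n_batch) / n_batch equals ceil(n_tokens / n_batch). A small self-contained sketch of the sub-batch arithmetic, with made-up sizes for illustration:

    #include <algorithm>
    #include <cstdio>

    // round x up to the next multiple of n, as GGML_PAD does for positive n
    static int pad_to_multiple(int x, int n) { return ((x + n - 1) / n) * n; }

    int main() {
        const int n_tokens = 729, n_batch = 512, n_mmproj_embd = 4096; // example values
        const int n_img_batches = pad_to_multiple(n_tokens, n_batch) / n_batch; // == 2
        for (int i_batch = 0; i_batch < n_img_batches; i_batch++) {
            const int pos_offset     = i_batch * n_batch;
            const int n_tokens_batch = std::min(n_batch, n_tokens - pos_offset);
            // this sub-batch reads embeddings starting at embd + pos_offset * n_mmproj_embd
            printf("batch %d/%d: %d tokens, float offset %d\n",
                   i_batch + 1, n_img_batches, n_tokens_batch, pos_offset * n_mmproj_embd);
        }
        return 0;
    }
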
diff --git a/examples/llava/tests.sh b/examples/llava/tests.sh
index cc9bda8769..61ebb3ac18 100755
--- a/examples/llava/tests.sh
+++ b/examples/llava/tests.sh
@@ -17,26 +17,30 @@ cd $PROJ_ROOT
 
 arr_bin=()
 arr_hf=()
+arr_tmpl=() # chat template
 
 add_test() {
     local bin=$1
     local hf=$2
+    local tmpl=${3:-""} # default to empty string if not provided
     arr_bin+=("$bin")
     arr_hf+=("$hf")
+    arr_tmpl+=("$tmpl")
 }
 
-add_test "llama-gemma3-cli"   "ggml-org/gemma-3-4b-it-GGUF:Q4_K_M"
-add_test "llama-llava-cli"    "cmp-nct/Yi-VL-6B-GGUF:Q5_K"
-add_test "llama-llava-cli"    "guinmoon/MobileVLM-3B-GGUF:Q4_K_M"
-add_test "llama-llava-cli"    "THUDM/glm-edge-v-5b-gguf:Q4_K_M"
-add_test "llama-llava-cli"    "second-state/Llava-v1.5-7B-GGUF:Q2_K"
-add_test "llama-llava-cli"    "cjpais/llava-1.6-mistral-7b-gguf:Q3_K"
-add_test "llama-llava-cli"    "ibm-research/granite-vision-3.2-2b-GGUF:Q4_K_M"
-add_test "llama-minicpmv-cli" "second-state/MiniCPM-Llama3-V-2_5-GGUF:Q2_K" # model from openbmb is corrupted
-add_test "llama-minicpmv-cli" "openbmb/MiniCPM-V-2_6-gguf:Q2_K"
-add_test "llama-minicpmv-cli" "openbmb/MiniCPM-o-2_6-gguf:Q4_0"
+add_test "llama-mtmd-cli"  "ggml-org/gemma-3-4b-it-GGUF:Q4_K_M"
+add_test "llama-mtmd-cli"  "guinmoon/MobileVLM-3B-GGUF:Q4_K_M"               "deepseek"
+add_test "llama-mtmd-cli"  "THUDM/glm-edge-v-5b-gguf:Q4_K_M"
+add_test "llama-mtmd-cli"  "second-state/Llava-v1.5-7B-GGUF:Q2_K"            "vicuna"
+add_test "llama-mtmd-cli"  "cjpais/llava-1.6-mistral-7b-gguf:Q3_K"           "vicuna"
+add_test "llama-mtmd-cli"  "ibm-research/granite-vision-3.2-2b-GGUF:Q4_K_M"
+add_test "llama-mtmd-cli"  "second-state/MiniCPM-Llama3-V-2_5-GGUF:Q2_K"  # model from openbmb is corrupted
+add_test "llama-mtmd-cli"  "openbmb/MiniCPM-V-2_6-gguf:Q2_K"
+add_test "llama-mtmd-cli"  "openbmb/MiniCPM-o-2_6-gguf:Q4_0"
 add_test "llama-qwen2vl-cli"  "bartowski/Qwen2-VL-2B-Instruct-GGUF:Q4_K_M"
 
+# add_test "llama-mtmd-cli"  "cmp-nct/Yi-VL-6B-GGUF:Q5_K"  # this model has broken chat template, not usable
+
 ###############
 
 cmake --build build -j --target "${arr_bin[@]}"
@@ -46,12 +50,20 @@ arr_res=()
 for i in "${!arr_bin[@]}"; do
     bin="${arr_bin[$i]}"
     hf="${arr_hf[$i]}"
+    tmpl="${arr_tmpl[$i]}"
 
     echo "Running test with binary: $bin and HF model: $hf"
     echo ""
     echo ""
 
-    output=$("$PROJ_ROOT/build/bin/$bin" -hf "$hf" --image $SCRIPT_DIR/test-1.jpeg -p "what is the publisher name of the newspaper?" --temp 0 2>&1 | tee /dev/tty)
+    output=$(\
+        "$PROJ_ROOT/build/bin/$bin" \
+        -hf "$hf" \
+        --image $SCRIPT_DIR/test-1.jpeg \
+        -p "what is the publisher name of the newspaper?" \
+        --temp 0 -n 128 \
+        ${tmpl:+--chat-template "$tmpl"} \
+        2>&1 | tee /dev/tty)
 
     echo "$output" > $SCRIPT_DIR/output/$bin-$(echo "$hf" | tr '/' '-').log
 
diff --git a/src/llama-chat.cpp b/src/llama-chat.cpp
index 721faa4e81..f62850ca57 100644
--- a/src/llama-chat.cpp
+++ b/src/llama-chat.cpp
@@ -121,6 +121,8 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
         return LLM_CHAT_TEMPLATE_PHI_3;
     } else if (tmpl_contains("<|assistant|>") && tmpl_contains("<|user|>")) {
         return tmpl_contains("") ? LLM_CHAT_TEMPLATE_FALCON_3 : LLM_CHAT_TEMPLATE_GLMEDGE;
+    } else if (tmpl_contains("<|{{ item['role'] }}|>") && tmpl_contains("<|begin_of_image|>")) {
+        return LLM_CHAT_TEMPLATE_GLMEDGE;
     } else if (tmpl_contains("<|user|>") && tmpl_contains("<|endoftext|>")) {
         return LLM_CHAT_TEMPLATE_ZEPHYR;
     } else if (tmpl_contains("bos_token + message['role']")) {

From 5368ddda7a262d195b54687a31009dcc1f8b1602 Mon Sep 17 00:00:00 2001
From: Akarshan Biswas 
Date: Mon, 21 Apr 2025 19:13:30 +0530
Subject: [PATCH 23/39] SYCL: Add non-contiguous support in ROPE (#12993)

ggml-ci
---
 ggml/src/ggml-sycl/ggml-sycl.cpp |   7 +-
 ggml/src/ggml-sycl/rope.cpp      | 197 +++++++++++++++----------------
 ggml/src/ggml-sycl/rope.hpp      |   2 +-
 3 files changed, 96 insertions(+), 110 deletions(-)

diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp
index 1de34c9629..8081a77b74 100644
--- a/ggml/src/ggml-sycl/ggml-sycl.cpp
+++ b/ggml/src/ggml-sycl/ggml-sycl.cpp
@@ -3168,11 +3168,6 @@ static void ggml_sycl_diag_mask_inf(ggml_backend_sycl_context & ctx, ggml_tensor
     ggml_sycl_op_diag_mask_inf(ctx, dst);
 }
 
-static void ggml_sycl_rope(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
-    GGML_ASSERT(ggml_is_contiguous(dst->src[0])); // TODO: this restriction is temporary until non-cont support is implemented
-    ggml_sycl_op_rope(ctx, dst);
-}
-
 static void ggml_sycl_pool2d(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
     ggml_sycl_op_pool2d(ctx, dst);
 }
@@ -4002,7 +3997,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
                 if (mode == GGML_ROPE_TYPE_MROPE) {
                     return false;
                 }
-                return ggml_is_contiguous(op->src[0]);
+                return true;
             }
         case GGML_OP_IM2COL:
             return true;
diff --git a/ggml/src/ggml-sycl/rope.cpp b/ggml/src/ggml-sycl/rope.cpp
index 80e050f241..4e276d3b62 100644
--- a/ggml/src/ggml-sycl/rope.cpp
+++ b/ggml/src/ggml-sycl/rope.cpp
@@ -34,23 +34,21 @@ static void rope_yarn(
     *sin_theta = sycl::sin(theta) * mscale;
 }
 
-template<typename T, bool has_ff>
-static void rope_norm(
-    const T * x, T * dst, int ne0, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
-    float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors,
-    const sycl::nd_item<3> &item_ct1) {
-    const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
-                         item_ct1.get_local_id(1));
+template <typename T, bool has_ff>
+static void rope_norm(const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims,
+                      const int32_t * pos, float freq_scale, float ext_factor, float attn_factor,
+                      const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors,
+                      const sycl::nd_item<3> & item_ct1) {
+    const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) + item_ct1.get_local_id(1));
 
     if (i0 >= ne0) {
         return;
     }
 
-    const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                    item_ct1.get_local_id(2);
+    const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2);
 
     if (i0 >= n_dims) {
-        const int i = row*ne0 + i0;
+        const int i = row * ne0 + i0;
 
         dst[i + 0] = x[i + 0];
         dst[i + 1] = x[i + 1];
@@ -58,42 +56,43 @@ static void rope_norm(
         return;
     }
 
-    const int i = row*ne0 + i0;
-    const int i2 = row/p_delta_rows;
+    const int row0     = row % ne1;
+    const int channel0 = row / ne1;
 
-    const float theta_base = pos[i2] * sycl::pow(theta_scale, i0 / 2.0f);
+    const int i  = row * ne0 + i0;
+    const int i2 = channel0 * s2 + row0 * s1 + i0;
 
-    const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
+    const float theta_base = pos[channel0] * sycl::pow(theta_scale, i0 / 2.0f);
+
+    const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f;
 
     float cos_theta;
     float sin_theta;
 
-    rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
+    rope_yarn(theta_base / freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
 
-    const float x0 = x[i + 0];
-    const float x1 = x[i + 1];
+    const float x0 = x[i2 + 0];
+    const float x1 = x[i2 + 1];
 
-    dst[i + 0] = x0*cos_theta - x1*sin_theta;
-    dst[i + 1] = x0*sin_theta + x1*cos_theta;
+    dst[i + 0] = x0 * cos_theta - x1 * sin_theta;
+    dst[i + 1] = x0 * sin_theta + x1 * cos_theta;
 }
 
-template<typename T, bool has_ff>
-static void rope_neox(
-    const T * x, T * dst, int ne0, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
-    float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors,
-    const sycl::nd_item<3> &item_ct1) {
-    const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
-                         item_ct1.get_local_id(1));
+template <typename T, bool has_ff>
+static void rope_neox(const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims,
+                      const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor,
+                      const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors,
+                      const sycl::nd_item<3> & item_ct1) {
+    const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) + item_ct1.get_local_id(1));
 
     if (i0 >= ne0) {
         return;
     }
 
-    const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                    item_ct1.get_local_id(2);
+    const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2);
 
     if (i0 >= n_dims) {
-        const int i = row*ne0 + i0;
+        const int i = row * ne0 + i0;
 
         dst[i + 0] = x[i + 0];
         dst[i + 1] = x[i + 1];
@@ -101,23 +100,26 @@ static void rope_neox(
         return;
     }
 
-    const int i  = row*ne0 + i0/2;
-    const int i2 = row/p_delta_rows;
+    const int row0     = row % ne1;
+    const int channel0 = row / ne1;
 
-    const float theta_base = pos[i2] * sycl::pow(theta_scale, i0 / 2.0f);
+    const int i  = row * ne0 + i0 / 2;
+    const int i2 = channel0 * s2 + row0 * s1 + i0 / 2;
 
-    const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
+    const float theta_base = pos[channel0] * sycl::pow(theta_scale, i0 / 2.0f);
+
+    const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f;
 
     float cos_theta;
     float sin_theta;
 
-    rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
+    rope_yarn(theta_base / freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
 
-    const float x0 = x[i + 0];
-    const float x1 = x[i + n_dims/2];
+    const float x0 = x[i2 + 0];
+    const float x1 = x[i2 + n_dims / 2];
 
-    dst[i + 0]        = x0*cos_theta - x1*sin_theta;
-    dst[i + n_dims/2] = x0*sin_theta + x1*cos_theta;
+    dst[i + 0]          = x0 * cos_theta - x1 * sin_theta;
+    dst[i + n_dims / 2] = x0 * sin_theta + x1 * cos_theta;
 }
 
 template <typename T, bool has_ff>
@@ -163,18 +165,18 @@ static void rope_vision(const T * x, T * dst, const int ne0, const int ne1, cons
 }
 
 template <typename T>
-static void rope_norm_sycl(
-    const T *x, T *dst, int ne0, int n_dims, int nr, const int32_t *pos, float freq_scale, int p_delta_rows,
-    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, queue_ptr stream) {
+static void rope_norm_sycl(const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2,
+                           const int n_dims, int nr, const int32_t * pos, const float freq_scale, const float freq_base,
+                           const float ext_factor, const float attn_factor, const rope_corr_dims corr_dims,
+                           const float * freq_factors, queue_ptr stream) {
     GGML_ASSERT(ne0 % 2 == 0);
     const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
-    const int num_blocks_x = (ne0 + 2*SYCL_ROPE_BLOCK_SIZE - 1) / (2*SYCL_ROPE_BLOCK_SIZE);
+    const int            num_blocks_x = (ne0 + 2 * SYCL_ROPE_BLOCK_SIZE - 1) / (2 * SYCL_ROPE_BLOCK_SIZE);
     const sycl::range<3> block_nums(1, num_blocks_x, nr);
 
-    const float theta_scale = powf(freq_base, -2.0f/n_dims);
+    const float theta_scale = powf(freq_base, -2.0f / n_dims);
 
-    dpct::has_capability_or_fail(stream->get_device(),
-                                     {sycl::aspect::fp16});
+    dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
 
     if (freq_factors == nullptr) {
         /*
@@ -182,61 +184,47 @@ static void rope_norm_sycl(
         the limit. To get the device limit, query
         info::device::max_work_group_size. Adjust the work-group size if needed.
         */
-        stream->parallel_for(
-            sycl::nd_range<3>(block_nums * block_dims, block_dims),
-            [=](sycl::nd_item<3> item_ct1) {
-                rope_norm<T, false>(x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows,
-                               ext_factor, attn_factor, corr_dims, theta_scale, freq_factors,
-                               item_ct1);
-            });
+        stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
+            rope_norm<T, false>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims,
+                                theta_scale, freq_factors, item_ct1);
+        });
     } else {
         /*
         DPCT1049:41: The work-group size passed to the SYCL kernel may exceed
         the limit. To get the device limit, query
         info::device::max_work_group_size. Adjust the work-group size if needed.
         */
-        stream->parallel_for(
-            sycl::nd_range<3>(block_nums * block_dims, block_dims),
-            [=](sycl::nd_item<3> item_ct1) {
-                rope_norm<T, true>(x, dst, ne0, n_dims, pos, freq_scale, p_delta_rows,
-                              ext_factor, attn_factor, corr_dims, theta_scale, freq_factors,
-                              item_ct1);
-            });
+        stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
+            rope_norm<T, true>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims,
+                               theta_scale, freq_factors, item_ct1);
+        });
     }
 }
 
 template <typename T>
-static void rope_neox_sycl(
-    const T *x, T *dst, int ne0, int n_dims, int nr, const int32_t *pos, float freq_scale, int p_delta_rows,
-    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, const float * freq_factors, queue_ptr stream) {
+static void rope_neox_sycl(const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2,
+                           const int n_dims, const int nr, const int32_t * pos, const float freq_scale,
+                           const float freq_base, const float ext_factor, const float attn_factor,
+                           const rope_corr_dims corr_dims, const float * freq_factors, queue_ptr stream) {
     GGML_ASSERT(ne0 % 2 == 0);
     const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
-    const int num_blocks_x = (ne0 + 2*SYCL_ROPE_BLOCK_SIZE - 1) / (2*SYCL_ROPE_BLOCK_SIZE);
+    const int            num_blocks_x = (ne0 + 2 * SYCL_ROPE_BLOCK_SIZE - 1) / (2 * SYCL_ROPE_BLOCK_SIZE);
     const sycl::range<3> block_nums(1, num_blocks_x, nr);
 
-    const float theta_scale = powf(freq_base, -2.0f/n_dims);
+    const float theta_scale = powf(freq_base, -2.0f / n_dims);
 
-    dpct::has_capability_or_fail(stream->get_device(),
-                                    {sycl::aspect::fp16});
+    dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
 
     if (freq_factors == nullptr) {
-        stream->parallel_for(
-            sycl::nd_range<3>(block_nums * block_dims, block_dims),
-            [=](sycl::nd_item<3> item_ct1) {
-                rope_neox<T, false>(x, dst, ne0, n_dims, pos, freq_scale,
-                                    p_delta_rows, ext_factor, attn_factor,
-                                    corr_dims, theta_scale, freq_factors,
-                                    item_ct1);
-            });
+        stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
+            rope_neox<T, false>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims,
+                                theta_scale, freq_factors, item_ct1);
+        });
     } else {
-        stream->parallel_for(
-            sycl::nd_range<3>(block_nums * block_dims, block_dims),
-            [=](sycl::nd_item<3> item_ct1) {
-                rope_neox<T, true>(x, dst, ne0, n_dims, pos, freq_scale,
-                                    p_delta_rows, ext_factor, attn_factor,
-                                    corr_dims, theta_scale, freq_factors,
-                                    item_ct1);
-            });
+        stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
+            rope_neox<T, true>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims,
+                               theta_scale, freq_factors, item_ct1);
+        });
     }
 }
 
@@ -272,7 +260,7 @@ static void rope_vision_sycl(const T * x, T * dst, const int ne0, const int ne1,
     }
 }
 
-void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
+inline void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
 
     GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
     GGML_ASSERT( dst->type == GGML_TYPE_F32 ||  dst->type == GGML_TYPE_F16);
@@ -329,43 +317,46 @@ void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
     if (is_neox) {
         GGML_SYCL_DEBUG("%s: neox path\n", __func__);
         if (dst->src[0]->type == GGML_TYPE_F32) {
-            rope_neox_sycl(
-                (const float *)dst->src[0]->data, (float *)dst->data, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
-                attn_factor, corr_dims, freq_factors, main_stream
-            );
+            rope_neox_sycl((const float *) dst->src[0]->data, (float *) dst->data, ne00, ne01, s01, s02, n_dims, nr,
+                           pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, main_stream);
         } else if (dst->src[0]->type == GGML_TYPE_F16) {
-            rope_neox_sycl(
-                (const sycl::half *)dst->src[0]->data, (sycl::half *)dst->data, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
-                attn_factor, corr_dims, freq_factors, main_stream
-            );
+            rope_neox_sycl((const sycl::half *) dst->src[0]->data, (sycl::half *) dst->data, ne00, ne01, s01, s02,
+                           n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors,
+                           main_stream);
         } else {
             GGML_ABORT("fatal error");
         }
     } else if (is_vision) {
         GGML_SYCL_DEBUG("%s: vision path\n", __func__);
         if (dst->src[0]->type == GGML_TYPE_F16) {
-            rope_vision_sycl((const sycl::half *)dst->src[0]->data, (sycl::half *)dst->data, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale,
-                freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, main_stream);
+            rope_vision_sycl((const sycl::half *) dst->src[0]->data, (sycl::half *) dst->data, ne00, ne01, ne02, s01,
+                             s02, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
+                             freq_factors, sections, main_stream);
         } else if (dst->src[0]->type == GGML_TYPE_F32) {
-            rope_vision_sycl((const float *) dst->src[0]->data, (float *)dst->data, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale,
-                freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, main_stream);
+            rope_vision_sycl((const float *) dst->src[0]->data, (float *) dst->data, ne00, ne01, ne02, s01, s02, n_dims,
+                             nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections,
+                             main_stream);
         } else {
             GGML_ABORT("Fatal error: Tensor type unsupported!");
         }
     } else {
         GGML_SYCL_DEBUG("%s: norm path\n", __func__);
         if (dst->src[0]->type == GGML_TYPE_F32) {
-            rope_norm_sycl(
-                (const float *)dst->src[0]->data, (float *)dst->data, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
-                attn_factor, corr_dims, freq_factors, main_stream
-            );
+            rope_norm_sycl((const float *) dst->src[0]->data, (float *) dst->data, ne00, ne01, s01, s02, n_dims, nr,
+                           pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, main_stream);
         } else if (dst->src[0]->type == GGML_TYPE_F16) {
-            rope_norm_sycl(
-                (const sycl::half *)dst->src[0]->data, (sycl::half *)dst->data, ne00, n_dims, nr, pos, freq_scale, ne01, freq_base, ext_factor,
-                attn_factor, corr_dims, freq_factors, main_stream
-            );
+            rope_norm_sycl((const sycl::half *) dst->src[0]->data, (sycl::half *) dst->data, ne00, ne01, s01, s02,
+                           n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors,
+                           main_stream);
         } else {
             GGML_ABORT("fatal error");
         }
     }
 }
+
+void ggml_sycl_rope(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    GGML_SYCL_DEBUG("call %s\n", __func__);
+    ggml_sycl_op_rope(ctx, dst);
+    GGML_SYCL_DEBUG("call %s done\n", __func__);
+}
+
diff --git a/ggml/src/ggml-sycl/rope.hpp b/ggml/src/ggml-sycl/rope.hpp
index a399bddb8a..8c7141aac5 100644
--- a/ggml/src/ggml-sycl/rope.hpp
+++ b/ggml/src/ggml-sycl/rope.hpp
@@ -15,6 +15,6 @@
 
 #include "common.hpp"
 
-void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst);
+void ggml_sycl_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst);
 
 #endif // GGML_SYCL_ROPE_HPP
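
Note: the essence of the non-contiguous support is that source elements are now addressed through per-dimension element strides (s1, s2) while the destination stays densely packed. A sketch of the index mapping used by rope_norm above, assuming s1 and s2 are src0's element strides (the byte strides nb01/nb02 divided by the element size):

    // map a flattened (row, i0) work-item coordinate to dst/src indices
    // when src0 may be non-contiguous
    struct rope_idx {
        int i_dst;
        int i_src;
    };

    static rope_idx rope_norm_index(int row, int i0, int ne0, int ne1, int s1, int s2) {
        const int row0     = row % ne1;           // row inside its channel
        const int channel0 = row / ne1;           // channel (e.g. attention head)
        return {
            row * ne0 + i0,                       // dst: contiguous, row-major
            channel0 * s2 + row0 * s1 + i0,       // src: follows its own strides
        };
    }
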

From 1d735c0b4fa0551c51c2f4ac888dd9a01f447985 Mon Sep 17 00:00:00 2001
From: Diego Devesa 
Date: Mon, 21 Apr 2025 18:13:51 +0200
Subject: [PATCH 24/39] ggml : add SSE 4.2 and x64 base variant for CPUs
 without AVX (#12871)

* ggml : add SSE 4.2 variant for CPUs without AVX

* ggml : add x64 base ABI variant
---
 ggml/CMakeLists.txt                 |  1 +
 ggml/src/CMakeLists.txt             | 15 +++++++++------
 ggml/src/ggml-cpu/CMakeLists.txt    |  8 +++++---
 ggml/src/ggml-cpu/cpu-feats-x86.cpp |  2 +-
 4 files changed, 16 insertions(+), 10 deletions(-)

diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt
index 438c2a7309..61fe15a15f 100644
--- a/ggml/CMakeLists.txt
+++ b/ggml/CMakeLists.txt
@@ -107,6 +107,7 @@ message(DEBUG "INS_ENB             : ${INS_ENB}")
 option(GGML_CPU_HBM          "ggml: use memkind for CPU HBM" OFF)
 option(GGML_CPU_AARCH64      "ggml: use runtime weight conversion of Q4_0 to Q4_X_X" ON)
 option(GGML_CPU_KLEIDIAI     "ggml: use KleidiAI optimized kernels if applicable" OFF)
+option(GGML_SSE42            "ggml: enable SSE 4.2"          ${INS_ENB})
 option(GGML_AVX              "ggml: enable AVX"              ${INS_ENB})
 option(GGML_AVX_VNNI         "ggml: enable AVX-VNNI"         OFF)
 option(GGML_AVX2             "ggml: enable AVX2"             ${INS_ENB})
diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt
index f00700da71..43d9fc4fe2 100644
--- a/ggml/src/CMakeLists.txt
+++ b/ggml/src/CMakeLists.txt
@@ -267,6 +267,7 @@ function(ggml_add_cpu_backend_variant tag_name)
     set(GGML_CPU_TAG_NAME ${tag_name})
     # other: OPENMP LLAMAFILE CPU_HBM
     foreach (feat NATIVE
+                  SSE42
                   AVX AVX2 BMI2 AVX_VNNI FMA F16C
                   AVX512 AVX512_VBMI AVX512_VNNI AVX512_BF16
                   AMX_TILE AMX_INT8 AMX_BF16)
@@ -286,14 +287,16 @@ if (GGML_CPU_ALL_VARIANTS)
     if (NOT GGML_BACKEND_DL)
         message(FATAL_ERROR "GGML_CPU_ALL_VARIANTS requires GGML_BACKEND_DL")
     endif()
-    ggml_add_cpu_backend_variant(sandybridge    AVX)
-    ggml_add_cpu_backend_variant(haswell        AVX F16C AVX2 BMI2 FMA)
-    ggml_add_cpu_backend_variant(skylakex       AVX F16C AVX2 BMI2 FMA AVX512)
-    ggml_add_cpu_backend_variant(icelake        AVX F16C AVX2 BMI2 FMA AVX512 AVX512_VBMI AVX512_VNNI)
-    ggml_add_cpu_backend_variant(alderlake      AVX F16C AVX2 BMI2 FMA AVX_VNNI)
+    ggml_add_cpu_backend_variant(x64)
+    ggml_add_cpu_backend_variant(sse42        SSE42)
+    ggml_add_cpu_backend_variant(sandybridge  SSE42 AVX)
+    ggml_add_cpu_backend_variant(haswell      SSE42 AVX F16C AVX2 BMI2 FMA)
+    ggml_add_cpu_backend_variant(skylakex     SSE42 AVX F16C AVX2 BMI2 FMA AVX512)
+    ggml_add_cpu_backend_variant(icelake      SSE42 AVX F16C AVX2 BMI2 FMA AVX512 AVX512_VBMI AVX512_VNNI)
+    ggml_add_cpu_backend_variant(alderlake    SSE42 AVX F16C AVX2 BMI2 FMA AVX_VNNI)
     if (NOT MSVC)
         # MSVC doesn't support AMX
-        ggml_add_cpu_backend_variant(sapphirerapids AVX F16C AVX2 BMI2 FMA AVX512 AVX512_VBMI AVX512_VNNI AVX512_BF16 AMX_TILE AMX_INT8)
+        ggml_add_cpu_backend_variant(sapphirerapids SSE42 AVX F16C AVX2 BMI2 FMA AVX512 AVX512_VBMI AVX512_VNNI AVX512_BF16 AMX_TILE AMX_INT8)
     endif()
 elseif (GGML_CPU)
     ggml_add_cpu_backend_variant_impl("")
diff --git a/ggml/src/ggml-cpu/CMakeLists.txt b/ggml/src/ggml-cpu/CMakeLists.txt
index e73a3b69b5..6a652738c1 100644
--- a/ggml/src/ggml-cpu/CMakeLists.txt
+++ b/ggml/src/ggml-cpu/CMakeLists.txt
@@ -222,7 +222,7 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
             elseif (GGML_AVX)
                 list(APPEND ARCH_FLAGS /arch:AVX)
                 list(APPEND ARCH_DEFINITIONS GGML_AVX)
-            else ()
+            elseif (GGML_SSE42)
                 list(APPEND ARCH_FLAGS /arch:SSE4.2)
                 list(APPEND ARCH_DEFINITIONS GGML_SSE42)
             endif()
@@ -237,8 +237,10 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
             if (GGML_NATIVE)
                 list(APPEND ARCH_FLAGS -march=native)
             else ()
-                list(APPEND ARCH_FLAGS -msse4.2)
-                list(APPEND ARCH_DEFINITIONS GGML_SSE42)
+                if (GGML_SSE42)
+                    list(APPEND ARCH_FLAGS -msse4.2)
+                    list(APPEND ARCH_DEFINITIONS GGML_SSE42)
+                endif()
                 if (GGML_F16C)
                     list(APPEND ARCH_FLAGS -mf16c)
                     list(APPEND ARCH_DEFINITIONS GGML_F16C)
diff --git a/ggml/src/ggml-cpu/cpu-feats-x86.cpp b/ggml/src/ggml-cpu/cpu-feats-x86.cpp
index 902ee43466..d775a03638 100644
--- a/ggml/src/ggml-cpu/cpu-feats-x86.cpp
+++ b/ggml/src/ggml-cpu/cpu-feats-x86.cpp
@@ -263,7 +263,7 @@ void test_x86_is() {
 static int ggml_backend_cpu_x86_score() {
     // FIXME: this does not check for OS support
 
-    int score = 0;
+    int score = 1;
     cpuid_x86 is;
 
 #ifdef GGML_FMA

From 243453533e029334181dda50d911d5fc5a2b2486 Mon Sep 17 00:00:00 2001
From: Xuan-Son Nguyen 
Date: Tue, 22 Apr 2025 10:37:00 +0200
Subject: [PATCH 25/39] llava : update documentation (#13055)

* llava : update documentation

* fix typo
---
 common/arg.cpp                                |   3 +-
 .../multimodal/MobileVLM.md                   |  26 +-
 .../multimodal/gemma3.md                      |   7 +-
 .../multimodal/glmedge.md                     |   6 +-
 .../multimodal/granitevision.md               |   8 +-
 docs/multimodal/llava.md                      | 143 ++++++++
 .../multimodal/minicpmo2.6.md                 |   8 +-
 .../multimodal/minicpmv2.5.md                 |   8 +-
 .../multimodal/minicpmv2.6.md                 |   8 +-
 examples/llava/README.md                      | 177 ++--------
 examples/llava/android/adb_run.sh             |   2 +-
 .../llava/gemma3_convert_encoder_to_gguf.py   | 307 ------------------
 12 files changed, 212 insertions(+), 491 deletions(-)
 rename examples/llava/MobileVLM-README.md => docs/multimodal/MobileVLM.md (96%)
 rename examples/llava/README-gemma3.md => docs/multimodal/gemma3.md (82%)
 rename examples/llava/README-glmedge.md => docs/multimodal/glmedge.md (80%)
 rename examples/llava/README-granitevision.md => docs/multimodal/granitevision.md (92%)
 create mode 100644 docs/multimodal/llava.md
 rename examples/llava/README-minicpmo2.6.md => docs/multimodal/minicpmo2.6.md (73%)
 rename examples/llava/README-minicpmv2.5.md => docs/multimodal/minicpmv2.5.md (72%)
 rename examples/llava/README-minicpmv2.6.md => docs/multimodal/minicpmv2.6.md (71%)
 delete mode 100644 examples/llava/gemma3_convert_encoder_to_gguf.py

diff --git a/common/arg.cpp b/common/arg.cpp
index 80c318a0e5..1cfd0168d9 100644
--- a/common/arg.cpp
+++ b/common/arg.cpp
@@ -976,14 +976,13 @@ static void common_params_print_completion(common_params_context & ctx_arg) {
         "llama-gritlm",
         "llama-imatrix",
         "llama-infill",
-        "llama-llava-cli",
+        "llama-mtmd-cli",
         "llama-llava-clip-quantize-cli",
         "llama-lookahead",
         "llama-lookup",
         "llama-lookup-create",
         "llama-lookup-merge",
         "llama-lookup-stats",
-        "llama-minicpmv-cli",
         "llama-parallel",
         "llama-passkey",
         "llama-perplexity",
diff --git a/examples/llava/MobileVLM-README.md b/docs/multimodal/MobileVLM.md
similarity index 96%
rename from examples/llava/MobileVLM-README.md
rename to docs/multimodal/MobileVLM.md
index 4f783f3ce0..20ac02f7a8 100644
--- a/examples/llava/MobileVLM-README.md
+++ b/docs/multimodal/MobileVLM.md
@@ -9,15 +9,15 @@ The implementation is based on llava, and is compatible with llava and mobileVLM
 Notice: The overall process of model inference for both **MobileVLM** and **MobileVLM_V2** models is the same, but the process of model conversion is a little different. Therefore, using **MobileVLM-1.7B** as an example, the different conversion step will be shown.
 
 ## Usage
-Build with cmake or run `make llama-llava-cli` to build it.
 
-After building, run: `./llama-llava-cli` to see the usage. For example:
+Build the `llama-mtmd-cli` binary.
+
+After building, run: `./llama-mtmd-cli` to see the usage. For example:
 
 ```sh
-./llama-llava-cli -m MobileVLM-1.7B/ggml-model-q4_k.gguf \
+./llama-mtmd-cli -m MobileVLM-1.7B/ggml-model-q4_k.gguf \
     --mmproj MobileVLM-1.7B/mmproj-model-f16.gguf \
-    --image path/to/an/image.jpg \
-    -p "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: \nWho is the author of this book? Answer the question using a single word or phrase. ASSISTANT:"
+    --chat-template deepseek
 ```
 
 ## Model conversion
@@ -82,7 +82,7 @@ refer to `android/adb_run.sh`, modify resources' `name` and `path`
 ### case 1
 **input**
 ```sh
-/data/local/tmp/llama-llava-cli \
+/data/local/tmp/llama-mtmd-cli \
     -m /data/local/tmp/ggml-model-q4_k.gguf \
     --mmproj /data/local/tmp/mmproj-model-f16.gguf \
     -t 4 \
@@ -102,7 +102,7 @@ llama_print_timings:       total time =   34731.93 ms
 ### case 2
 **input**
 ```sh
-/data/local/tmp/llama-llava-cli \
+/data/local/tmp/llama-mtmd-cli \
     -m /data/local/tmp/ggml-model-q4_k.gguf \
     --mmproj /data/local/tmp/mmproj-model-f16.gguf \
     -t 4 \
@@ -123,10 +123,10 @@ llama_print_timings:       total time =   34570.79 ms
 
 ## Some result on Android with `Snapdragon 778G` chip
 ### MobileVLM-1.7B case
-#### llava-cli release-b2005
+#### mtmd-cli release-b2005
 **input**
 ```sh
-/data/local/tmp/llama-llava-cli \
+/data/local/tmp/llama-mtmd-cli \
     -m /data/local/tmp/ggml-model-q4_k.gguf \
     --mmproj /data/local/tmp/mmproj-model-f16.gguf \
     -t 4 \
@@ -147,7 +147,7 @@ llama_print_timings: prompt eval time =    8119.49 ms /   191 tokens (   42.51 m
 llama_print_timings:        eval time =    1005.75 ms /    14 runs   (   71.84 ms per token,    13.92 tokens per second)
 llama_print_timings:       total time =   28038.34 ms /   205 tokens
 ```
-#### llava-cli latest-version
+#### mtmd-cli latest-version
 **input**
 
 Just the same as above.
@@ -169,7 +169,7 @@ llama_print_timings:        eval time =   43894.02 ms /    13 runs   ( 3376.46 m
 llama_print_timings:       total time =  865441.76 ms /   204 tokens
 ```
 ### MobileVLM_V2-1.7B case
-#### llava-cli release-2005b
+#### mtmd-cli release-2005b
 **input**
 
 Just the same as above.
@@ -200,7 +200,7 @@ make GGML_CUDA=1 CUDA_DOCKER_ARCH=sm_87 GGML_CUDA_F16=1 -j 32
 ### case 1
 **input**
 ```sh
-./llama-llava-cli \
+./llama-mtmd-cli \
     -m /data/local/tmp/ggml-model-q4_k.gguf \
     --mmproj /data/local/tmp/mmproj-model-f16.gguf \
     --image /data/local/tmp/demo.jpeg \
@@ -224,7 +224,7 @@ llama_print_timings:       total time =    1352.63 ms /   252 tokens
 ### case 2
 **input**
 ```sh
-./llama-llava-cli \
+./llama-mtmd-cli \
     -m /data/local/tmp/ggml-model-q4_k.gguf \
     --mmproj /data/local/tmp/mmproj-model-f16.gguf \
     -p "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: \nWhat is in the image? ASSISTANT:" \
diff --git a/examples/llava/README-gemma3.md b/docs/multimodal/gemma3.md
similarity index 82%
rename from examples/llava/README-gemma3.md
rename to docs/multimodal/gemma3.md
index 3c25ee2583..8fa077de71 100644
--- a/examples/llava/README-gemma3.md
+++ b/docs/multimodal/gemma3.md
@@ -26,11 +26,12 @@ llama-gemma3-cli -hf ggml-org/gemma-3-27b-it-GGUF
 
 ## How to get mmproj.gguf?
 
+Simply add the `--mmproj` flag when converting the model via `convert_hf_to_gguf.py`:
+
 ```bash
 cd gemma-3-4b-it
-python ../llama.cpp/examples/llava/gemma3_convert_encoder_to_gguf.py .
-
-# output file is mmproj.gguf
+python ../llama.cpp/convert_hf_to_gguf.py --outfile model.gguf --outtype f16 --mmproj .
+# output file: mmproj-model.gguf
 ```
 
 ## How to run it?
diff --git a/examples/llava/README-glmedge.md b/docs/multimodal/glmedge.md
similarity index 80%
rename from examples/llava/README-glmedge.md
rename to docs/multimodal/glmedge.md
index 603d014745..af6b696a8a 100644
--- a/examples/llava/README-glmedge.md
+++ b/docs/multimodal/glmedge.md
@@ -3,12 +3,12 @@
 Currently this implementation supports [glm-edge-v-2b](https://huggingface.co/THUDM/glm-edge-v-2b) and [glm-edge-v-5b](https://huggingface.co/THUDM/glm-edge-v-5b).
 
 ## Usage
-Build with cmake or run `make llama-llava-cli` to build it.
+Build the `llama-mtmd-cli` binary.
 
-After building, run: `./llama-llava-cli` to see the usage. For example:
+After building, run: `./llama-mtmd-cli` to see the usage. For example:
 
 ```sh
-./llama-llava-cli -m model_path/ggml-model-f16.gguf --mmproj model_path/mmproj-model-f16.gguf --image img_path/image.jpg -p "<|system|>\n system prompt <|user|>\n prompt <|assistant|>\n"
+./llama-mtmd-cli -m model_path/ggml-model-f16.gguf --mmproj model_path/mmproj-model-f16.gguf
 ```
 
 **note**: A lower temperature like 0.1 is recommended for better quality. add `--temp 0.1` to the command to do so.
diff --git a/examples/llava/README-granitevision.md b/docs/multimodal/granitevision.md
similarity index 92%
rename from examples/llava/README-granitevision.md
rename to docs/multimodal/granitevision.md
index f08a21cc17..3118fe0cdc 100644
--- a/examples/llava/README-granitevision.md
+++ b/docs/multimodal/granitevision.md
@@ -176,15 +176,11 @@ Note that currently you cannot quantize the visual encoder because granite visio
 
 
 ### 5. Running the Model in Llama cpp
-Build llama cpp normally; you should have a target binary named `llama-llava-cli`, which you can pass two binaries to. As an example, we pass the the llama.cpp banner.
+Build llama cpp normally; you should have a target binary named `llama-mtmd-cli`, to which you can pass the two binaries. As an example, we pass the llama.cpp banner.
 
 ```bash
-$ ./build/bin/llama-llava-cli -m $LLM_GGUF_PATH \
+$ ./build/bin/llama-mtmd-cli -m $LLM_GGUF_PATH \
     --mmproj $VISUAL_GGUF_PATH \
-    --image ./media/llama0-banner.png \
     -c 16384 \
-    -p "<|system|>\nA chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n<|user|>\n\\nWhat does the text in this image say?\n<|assistant|>\n" \
     --temp 0
 ```
-
-Sample output: `The text in the image reads "LLAMA C++ Can it run DOOM Llama?"`
diff --git a/docs/multimodal/llava.md b/docs/multimodal/llava.md
new file mode 100644
index 0000000000..c5bdc82158
--- /dev/null
+++ b/docs/multimodal/llava.md
@@ -0,0 +1,143 @@
+# LLaVA
+
+Currently this implementation supports [llava-v1.5](https://huggingface.co/liuhaotian/llava-v1.5-7b) variants,
+as well as [llava-v1.6](https://huggingface.co/collections/liuhaotian/llava-16-65b9e40155f60fd046a5ccf2) variants.
+
+The pre-converted [7b](https://huggingface.co/mys/ggml_llava-v1.5-7b)
+and [13b](https://huggingface.co/mys/ggml_llava-v1.5-13b)
+models are available.
+For llava-1.6, a variety of prepared gguf models are available as well ([7b-34b](https://huggingface.co/cmp-nct/llava-1.6-gguf)).
+
+After the API is confirmed, more models will be supported / uploaded.
+
+## Usage
+Build the `llama-mtmd-cli` binary.
+
+After building, run: `./llama-mtmd-cli` to see the usage. For example:
+
+```sh
+./llama-mtmd-cli -m ../llava-v1.5-7b/ggml-model-f16.gguf \
+    --mmproj ../llava-v1.5-7b/mmproj-model-f16.gguf \
+    --chat-template vicuna
+```
+
+**note**: A lower temperature like 0.1 is recommended for better quality. Add `--temp 0.1` to the command to do so.
+**note**: For GPU offloading, be sure to use the `-ngl` flag as usual.
+
+## LLaVA 1.5
+
+1. Clone a LLaVA and a CLIP model ([available options](https://github.com/haotian-liu/LLaVA/blob/main/docs/MODEL_ZOO.md)). For example:
+
+```sh
+git clone https://huggingface.co/liuhaotian/llava-v1.5-7b
+
+git clone https://huggingface.co/openai/clip-vit-large-patch14-336
+```
+
+2. Install the required Python packages:
+
+```sh
+pip install -r examples/llava/requirements.txt
+```
+
+3. Use `llava_surgery.py` to split the LLaVA model into its LLaMA and multimodal projector constituents:
+
+```sh
+python ./examples/llava/llava_surgery.py -m ../llava-v1.5-7b
+```
+
+4. Use `convert_image_encoder_to_gguf.py` to convert the LLaVA image encoder to GGUF:
+
+```sh
+python ./examples/llava/convert_image_encoder_to_gguf.py -m ../clip-vit-large-patch14-336 --llava-projector ../llava-v1.5-7b/llava.projector --output-dir ../llava-v1.5-7b
+```
+
+5. Use `examples/convert_legacy_llama.py` to convert the LLaMA part of LLaVA to GGUF:
+
+```sh
+python ./examples/convert_legacy_llama.py ../llava-v1.5-7b --skip-unknown
+```
+
+Now both the LLaMA part and the image encoder are in the `llava-v1.5-7b` directory.
+
+## LLaVA 1.6 gguf conversion
+1) First clone a LLaVA 1.6 model:
+```console
+git clone https://huggingface.co/liuhaotian/llava-v1.6-vicuna-7b
+```
+
+2) Install the required Python packages:
+
+```sh
+pip install -r examples/llava/requirements.txt
+```
+
+3) Use `llava_surgery_v2.py`, which also supports llava-1.5 variants, for both pytorch and safetensor models:
+```console
+python examples/llava/llava_surgery_v2.py -C -m ../llava-v1.6-vicuna-7b/
+```
+- You will find a llava.projector and a llava.clip file in your model directory.
+
+4) Copy the llava.clip file into a subdirectory (like vit), rename it to pytorch_model.bin and add a fitting vit configuration to the directory:
+```console
+mkdir vit
+cp ../llava-v1.6-vicuna-7b/llava.clip vit/pytorch_model.bin
+cp ../llava-v1.6-vicuna-7b/llava.projector vit/
+curl -s -q https://huggingface.co/cmp-nct/llava-1.6-gguf/raw/main/config_vit.json -o vit/config.json
+```
+
+5) Create the visual gguf model:
+```console
+python ./examples/llava/convert_image_encoder_to_gguf.py -m vit --llava-projector vit/llava.projector --output-dir vit --clip-model-is-vision
+```
+- This is similar to llava-1.5; the difference is that we tell the encoder that we are working with the pure vision model part of CLIP.
+
+6) Then convert the model to gguf format:
+```console
+python ./examples/convert_legacy_llama.py ../llava-v1.6-vicuna-7b/ --skip-unknown
+```
+
+7) And finally we can run the CLI using the 1.6 model version:
+```console
+./llama-mtmd-cli -m ../llava-v1.6-vicuna-7b/ggml-model-f16.gguf --mmproj vit/mmproj-model-f16.gguf
+```
+
+**note** llava-1.6 needs more context than llava-1.5; at least 3000 is needed (just run it at `-c 4096`).
+
+**note** llava-1.6 greatly benefits from batched prompt processing (defaults work)
+
+**note** if the language model in step `6)` is incompatible with the legacy conversion script, the easiest way to handle the LLM conversion is to load the model in transformers and export only the LLM from the llava next model.
+
+```python
+import os
+import transformers
+
+model_path = ...
+llm_export_path = ...
+
+tokenizer = transformers.AutoTokenizer.from_pretrained(model_path)
+model = transformers.AutoModelForImageTextToText.from_pretrained(model_path)
+
+tokenizer.save_pretrained(llm_export_path)
+model.language_model.save_pretrained(llm_export_path)
+```
+
+Then, you can convert the LLM using the `convert_hf_to_gguf.py` script, which handles more LLM architectures.
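+
+For example (a sketch; the path is a placeholder for the `llm_export_path` used above):
+
+```sh
+# placeholder path; point this at the directory you exported the LLM to
+python ./convert_hf_to_gguf.py path/to/llm_export --outfile llm-f16.gguf --outtype f16
+```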
+
+## Chat template
+
+For llava-1.5 and llava-1.6, you need to use `vicuna` chat template. Simply add `--chat-template vicuna` to activate this template.
+
+
+## How to know if you are running in llava-1.5 or llava-1.6 mode
+
+When running the CLI you will see visual information right before the prompt is processed:
+
+**Llava-1.5:**
+`encode_image_with_clip: image embedding created: 576 tokens`
+
+**Llava-1.6 (anything above 576):**
+`encode_image_with_clip: image embedding created: 2880 tokens`
+
+
+Alternatively, just note how many "tokens" are used for your prompt; llava-1.6 will also show 1000+ tokens.
diff --git a/examples/llava/README-minicpmo2.6.md b/docs/multimodal/minicpmo2.6.md
similarity index 73%
rename from examples/llava/README-minicpmo2.6.md
rename to docs/multimodal/minicpmo2.6.md
index 48c4232383..de470d8a82 100644
--- a/examples/llava/README-minicpmo2.6.md
+++ b/docs/multimodal/minicpmo2.6.md
@@ -40,9 +40,9 @@ python ./convert_hf_to_gguf.py ../MiniCPM-o-2_6/model
 
 Inference on Linux or Mac
 ```bash
-# run f16 version
-./build/bin/llama-minicpmv-cli -m ../MiniCPM-o-2_6/model/ggml-model-f16.gguf --mmproj ../MiniCPM-o-2_6/mmproj-model-f16.gguf -c 4096 --temp 0.7 --top-p 0.8 --top-k 100 --repeat-penalty 1.05 --image xx.jpg -p "What is in the image?"
+# run in single-turn mode
+./build/bin/llama-mtmd-cli -m ../MiniCPM-o-2_6/model/ggml-model-f16.gguf --mmproj ../MiniCPM-o-2_6/mmproj-model-f16.gguf -c 4096 --temp 0.7 --top-p 0.8 --top-k 100 --repeat-penalty 1.05 --image xx.jpg -p "What is in the image?"
 
-# run quantized int4 version
-./build/bin/llama-minicpmv-cli -m ../MiniCPM-o-2_6/model/ggml-model-Q4_K_M.gguf --mmproj ../MiniCPM-o-2_6/mmproj-model-f16.gguf -c 4096 --temp 0.7 --top-p 0.8 --top-k 100 --repeat-penalty 1.05 --image xx.jpg  -p "What is in the image?"
+# run in conversation mode
+./build/bin/llama-mtmd-cli -m ../MiniCPM-o-2_6/model/ggml-model-Q4_K_M.gguf --mmproj ../MiniCPM-o-2_6/mmproj-model-f16.gguf
 ```
diff --git a/examples/llava/README-minicpmv2.5.md b/docs/multimodal/minicpmv2.5.md
similarity index 72%
rename from examples/llava/README-minicpmv2.5.md
rename to docs/multimodal/minicpmv2.5.md
index 6bfe7abd16..7a6879d395 100644
--- a/examples/llava/README-minicpmv2.5.md
+++ b/docs/multimodal/minicpmv2.5.md
@@ -39,9 +39,9 @@ python ./convert_hf_to_gguf.py ../MiniCPM-Llama3-V-2_5/model
 
 Inference on Linux or Mac
 ```bash
-# run f16 version
-./build/bin/llama-minicpmv-cli -m ../MiniCPM-Llama3-V-2_5/model/model-8B-F16.gguf --mmproj ../MiniCPM-Llama3-V-2_5/mmproj-model-f16.gguf -c 4096 --temp 0.7 --top-p 0.8 --top-k 100 --repeat-penalty 1.05 --image xx.jpg -p "What is in the image?"
+# run in single-turn mode
+./build/bin/llama-mtmd-cli -m ../MiniCPM-Llama3-V-2_5/model/model-8B-F16.gguf --mmproj ../MiniCPM-Llama3-V-2_5/mmproj-model-f16.gguf -c 4096 --temp 0.7 --top-p 0.8 --top-k 100 --repeat-penalty 1.05 --image xx.jpg -p "What is in the image?"
 
-# run quantized int4 version
-./build/bin/llama-minicpmv-cli -m ../MiniCPM-Llama3-V-2_5/model/ggml-model-Q4_K_M.gguf --mmproj ../MiniCPM-Llama3-V-2_5/mmproj-model-f16.gguf -c 4096 --temp 0.7 --top-p 0.8 --top-k 100 --repeat-penalty 1.05 --image xx.jpg  -p "What is in the image?"
+# run in conversation mode
+./build/bin/llama-mtmd-cli -m ../MiniCPM-Llama3-V-2_5/model/ggml-model-Q4_K_M.gguf --mmproj ../MiniCPM-Llama3-V-2_5/mmproj-model-f16.gguf
 ```
diff --git a/examples/llava/README-minicpmv2.6.md b/docs/multimodal/minicpmv2.6.md
similarity index 71%
rename from examples/llava/README-minicpmv2.6.md
rename to docs/multimodal/minicpmv2.6.md
index 2df39cdbac..410a5dd177 100644
--- a/examples/llava/README-minicpmv2.6.md
+++ b/docs/multimodal/minicpmv2.6.md
@@ -39,9 +39,9 @@ python ./convert_hf_to_gguf.py ../MiniCPM-V-2_6/model
 
 Inference on Linux or Mac
 ```bash
-# run f16 version
-./build/bin/llama-minicpmv-cli -m ../MiniCPM-V-2_6/model/ggml-model-f16.gguf --mmproj ../MiniCPM-V-2_6/mmproj-model-f16.gguf -c 4096 --temp 0.7 --top-p 0.8 --top-k 100 --repeat-penalty 1.05 --image xx.jpg -p "What is in the image?"
+# run in single-turn mode
+./build/bin/llama-mtmd-cli -m ../MiniCPM-V-2_6/model/ggml-model-f16.gguf --mmproj ../MiniCPM-V-2_6/mmproj-model-f16.gguf -c 4096 --temp 0.7 --top-p 0.8 --top-k 100 --repeat-penalty 1.05 --image xx.jpg -p "What is in the image?"
 
-# run quantized int4 version
-./build/bin/llama-minicpmv-cli -m ../MiniCPM-V-2_6/model/ggml-model-Q4_K_M.gguf --mmproj ../MiniCPM-V-2_6/mmproj-model-f16.gguf -c 4096 --temp 0.7 --top-p 0.8 --top-k 100 --repeat-penalty 1.05 --image xx.jpg  -p "What is in the image?"
+# run in conversation mode
+./build/bin/llama-mtmd-cli -m ../MiniCPM-V-2_6/model/ggml-model-Q4_K_M.gguf --mmproj ../MiniCPM-V-2_6/mmproj-model-f16.gguf
 ```
diff --git a/examples/llava/README.md b/examples/llava/README.md
index 0e3c320320..cadbc53fab 100644
--- a/examples/llava/README.md
+++ b/examples/llava/README.md
@@ -1,158 +1,47 @@
-# LLaVA
+# Multimodal Support in llama.cpp
 
-Currently this implementation supports [llava-v1.5](https://huggingface.co/liuhaotian/llava-v1.5-7b) variants,
-as well as llava-1.6 [llava-v1.6](https://huggingface.co/collections/liuhaotian/llava-16-65b9e40155f60fd046a5ccf2) variants.
+This directory provides multimodal capabilities for `llama.cpp`. Initially intended as a showcase for running LLaVA models, it has expanded significantly over time to include various other vision-capable models. As a result, LLaVA is no longer the only multimodal architecture supported.
 
-The pre-converted [7b](https://huggingface.co/mys/ggml_llava-v1.5-7b)
-and [13b](https://huggingface.co/mys/ggml_llava-v1.5-13b)
-models are available.
-For llava-1.6 a variety of prepared gguf models are available as well [7b-34b](https://huggingface.co/cmp-nct/llava-1.6-gguf)
+> [!IMPORTANT]
+>
+> Multimodal support can be viewed as a sub-project within `llama.cpp`. It is under **very heavy development**, and **breaking changes are expected**.
 
-After API is confirmed, more models will be supported / uploaded.
+The naming and structure related to multimodal support have evolved, which might cause some confusion. Here's a brief timeline to clarify:
 
-## Usage
-Build with cmake or run `make llama-llava-cli` to build it.
+- [#3436](https://github.com/ggml-org/llama.cpp/pull/3436): Initial support for LLaVA 1.5 was added, introducing `llava.cpp` and `clip.cpp`. The `llava-cli` binary was created for model interaction.
+- [#4954](https://github.com/ggml-org/llama.cpp/pull/4954): Support for MobileVLM was added, becoming the second vision model supported. This built upon the existing `llava.cpp`, `clip.cpp`, and `llava-cli` infrastructure.
+- **Expansion & Fragmentation:** Many new models were subsequently added (e.g., [#7599](https://github.com/ggml-org/llama.cpp/pull/7599), [#10361](https://github.com/ggml-org/llama.cpp/pull/10361), [#12344](https://github.com/ggml-org/llama.cpp/pull/12344), and others). However, `llava-cli` lacked support for the increasingly complex chat templates required by these models. This led to the creation of model-specific binaries like `qwen2vl-cli`, `minicpmv-cli`, and `gemma3-cli`. While functional, this proliferation of command-line tools became confusing for users.
+- [#12849](https://github.com/ggml-org/llama.cpp/pull/12849): `libmtmd` was introduced as a replacement for `llava.cpp`. Its goals include providing a single, unified command-line interface, improving the user/developer experience (UX/DX), and supporting both audio and image inputs.
+- [#13012](https://github.com/ggml-org/llama.cpp/pull/13012): `mtmd-cli` was added, consolidating the various model-specific CLIs into a single tool powered by `libmtmd`.
 
-After building, run: `./llama-llava-cli` to see the usage. For example:
+## How it works and what is `mmproj`?
 
-```sh
-./llama-llava-cli -m ../llava-v1.5-7b/ggml-model-f16.gguf --mmproj ../llava-v1.5-7b/mmproj-model-f16.gguf --image path/to/an/image.jpg
-```
+Multimodal support in `llama.cpp` works by encoding images into embeddings using a separate model component, and then feeding these embeddings into the language model.
 
-**note**: A lower temperature like 0.1 is recommended for better quality. add `--temp 0.1` to the command to do so.
-**note**: For GPU offloading ensure to use the `-ngl` flag just like usual
+This approach keeps the multimodal components distinct from the core `libllama` library. Separating these allows for faster, independent development cycles. While many modern vision models are based on Vision Transformers (ViTs), their specific pre-processing and projection steps can vary significantly. Integrating this diverse complexity directly into `libllama` is currently challenging.
 
-## LLaVA 1.5
+Consequently, running a multimodal model typically requires two GGUF files:
+1.  The standard language model file.
+2.  A corresponding **multimodal projector (`mmproj`)** file, which handles the image encoding and projection.
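+
+For example, a typical invocation passes both files (the paths here are placeholders):
+
+```sh
+./llama-mtmd-cli -m model.gguf --mmproj mmproj-model.gguf
+```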
 
-1. Clone a LLaVA and a CLIP model ([available options](https://github.com/haotian-liu/LLaVA/blob/main/docs/MODEL_ZOO.md)). For example:
+## What is `libmtmd`?
 
-```sh
-git clone https://huggingface.co/liuhaotian/llava-v1.5-7b
+As outlined in the history, `libmtmd` is the modern library designed to replace the original `llava.cpp` implementation for handling multimodal inputs.
 
-git clone https://huggingface.co/openai/clip-vit-large-patch14-336
-```
+Built upon `clip.cpp` (similar to `llava.cpp`), `libmtmd` offers several advantages:
+- **Unified Interface:** Aims to consolidate interaction for various multimodal models.
+- **Improved UX/DX:** Features a more intuitive API, inspired by the `Processor` class in the Hugging Face `transformers` library.
+- **Flexibility:** Designed to support multiple input types (text, audio, images) while respecting the wide variety of chat templates used by different models.
 
-2. Install the required Python packages:
+## How to obtain `mmproj`
 
-```sh
-pip install -r examples/llava/requirements.txt
-```
+Multimodal projector (`mmproj`) files are specific to each model architecture. Please refer to the relevant guide for instructions on how to obtain or create them:
 
-3. Use `llava_surgery.py` to split the LLaVA model to LLaMA and multimodel projector constituents:
-
-```sh
-python ./examples/llava/llava_surgery.py -m ../llava-v1.5-7b
-```
-
-4. Use `convert_image_encoder_to_gguf.py` to convert the LLaVA image encoder to GGUF:
-
-```sh
-python ./examples/llava/convert_image_encoder_to_gguf.py -m ../clip-vit-large-patch14-336 --llava-projector ../llava-v1.5-7b/llava.projector --output-dir ../llava-v1.5-7b
-```
-
-5. Use `examples/convert_legacy_llama.py` to convert the LLaMA part of LLaVA to GGUF:
-
-```sh
-python ./examples/convert_legacy_llama.py ../llava-v1.5-7b --skip-unknown
-```
-
-Now both the LLaMA part and the image encoder are in the `llava-v1.5-7b` directory.
-
-## LLaVA 1.6 gguf conversion
-1) First clone a LLaVA 1.6 model:
-```console
-git clone https://huggingface.co/liuhaotian/llava-v1.6-vicuna-7b
-```
-
-2) Install the required Python packages:
-
-```sh
-pip install -r examples/llava/requirements.txt
-```
-
-3) Use `llava_surgery_v2.py` which also supports llava-1.5 variants pytorch as well as safetensor models:
-```console
-python examples/llava/llava_surgery_v2.py -C -m ../llava-v1.6-vicuna-7b/
-```
-- you will find a llava.projector and a llava.clip file in your model directory
-
-4) Copy the llava.clip file into a subdirectory (like vit), rename it to pytorch_model.bin and add a fitting vit configuration to the directory:
-```console
-mkdir vit
-cp ../llava-v1.6-vicuna-7b/llava.clip vit/pytorch_model.bin
-cp ../llava-v1.6-vicuna-7b/llava.projector vit/
-curl -s -q https://huggingface.co/cmp-nct/llava-1.6-gguf/raw/main/config_vit.json -o vit/config.json
-```
-
-5) Create the visual gguf model:
-```console
-python ./examples/llava/convert_image_encoder_to_gguf.py -m vit --llava-projector vit/llava.projector --output-dir vit --clip-model-is-vision
-```
-- This is similar to llava-1.5, the difference is that we tell the encoder that we are working with the pure vision model part of CLIP
-
-6) Then convert the model to gguf format:
-```console
-python ./examples/convert_legacy_llama.py ../llava-v1.6-vicuna-7b/ --skip-unknown
-```
-
-7) And finally we can run the llava cli using the 1.6 model version:
-```console
-./llama-llava-cli -m ../llava-v1.6-vicuna-7b/ggml-model-f16.gguf --mmproj vit/mmproj-model-f16.gguf --image some-image.jpg -c 4096
-```
-
-**note** llava-1.6 needs more context than llava-1.5, at least 3000 is needed (just run it at -c 4096)
-
-**note** llava-1.6 greatly benefits from batched prompt processing (defaults work)
-
-**note** if the language model in step `6)` is incompatible with the legacy conversion script, the easiest way handle the LLM model conversion is to load the model in transformers, and export only the LLM from the llava next model.
-
-```python
-import os
-import transformers
-
-model_path = ...
-llm_export_path = ...
-
-tokenizer = transformers.AutoTokenizer.from_pretrained(model_path)
-model = transformers.AutoModelForImageTextToText.from_pretrained(model_path)
-
-tokenizer.save_pretrained(llm_export_path)
-model.language_model.save_pretrained(llm_export_path)
-```
-
-Then, you can convert the LLM using the `convert_hf_to_gguf.py` script, which handles more LLM architectures.
-
-## llava-cli templating and llava-1.6 prompting
-
-llava-1.5 models all use the same vicuna prompt, here you can just add your image question like `-p "Provide a full description."`
-For llava-1.5 models which are not vicuna (mistral and Yi) you need to adapt system prompt as well as user prompt, for this purpose llava-cli has a basic templating system:
-
-**For Mistral and using llava-cli binary:**
-Add this: `-p "\nUSER:\nProvide a full description.\nASSISTANT:\n"`
-The mistral template for llava-1.6 seems to be no system print and a USER/ASSISTANT role
-
-**For the 34B this should work:**
-Add this: `-e -p <|im_start|>system\nAnswer the questions.<|im_end|><|im_start|>user\n\nProvide a full description.<|im_end|><|im_start|>assistant\n`
-
-
-## How to know if you are running in llava-1.5 or llava-1.6 mode
-
-When running llava-cli you will see a visual information right before the prompt is being processed:
-
-**Llava-1.5:**
-`encode_image_with_clip: image embedding created: 576 tokens`
-
-**Llava-1.6 (anything above 576):**
-`encode_image_with_clip: image embedding created: 2880 tokens`
-
-
-Alternatively just pay notice to how many "tokens" have been used for your prompt, it will also show 1000+ tokens for llava-1.6
-
-
-
-
-## TODO
-
-- [x] Support non-CPU backend for the image encoding part.
-- [ ] Support different sampling methods.
-- [ ] Support more model variants.
+- [LLaVA](../../docs/multimodal/llava.md)
+- [MobileVLM](../../docs/multimodal/MobileVLM.md)
+- [GLM-Edge](../../docs/multimodal/glmedge.md)
+- [MiniCPM-V 2.5](../../docs/multimodal/minicpmv2.5.md)
+- [MiniCPM-V 2.6](../../docs/multimodal/minicpmv2.6.md)
+- [MiniCPM-o 2.6](../../docs/multimodal/minicpmo2.6.md)
+- [IBM Granite Vision](../../docs/multimodal/granitevision.md)
+- [Google Gemma 3](../../docs/multimodal/gemma3.md)
diff --git a/examples/llava/android/adb_run.sh b/examples/llava/android/adb_run.sh
index 45ccf8d70d..a24d6787d9 100755
--- a/examples/llava/android/adb_run.sh
+++ b/examples/llava/android/adb_run.sh
@@ -10,7 +10,7 @@ prompt="A chat between a curious user and an artificial intelligence assistant.
 # prompt="A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: \nWhat is in the image? ASSISTANT:"
 
 program_dir="build_64/bin"
-binName="llama-llava-cli"
+binName="llama-mtmd-cli"
 n_threads=4
 
 
diff --git a/examples/llava/gemma3_convert_encoder_to_gguf.py b/examples/llava/gemma3_convert_encoder_to_gguf.py
deleted file mode 100644
index 241b526b9e..0000000000
--- a/examples/llava/gemma3_convert_encoder_to_gguf.py
+++ /dev/null
@@ -1,307 +0,0 @@
-import gguf
-import argparse
-import logging
-import sys
-import torch
-import json
-import os
-import numpy as np
-from typing import cast, ContextManager, Any, Iterator
-from pathlib import Path
-from torch import Tensor
-
-logger = logging.getLogger("gemma3-mmproj")
-
-
-# (copied from convert_hf_to_gguf.py)
-# tree of lazy tensors
-class LazyTorchTensor(gguf.LazyBase):
-    _tensor_type = torch.Tensor
-    # to keep the type-checker happy
-    dtype: torch.dtype
-    shape: torch.Size
-
-    # only used when converting a torch.Tensor to a np.ndarray
-    _dtype_map: dict[torch.dtype, type] = {
-        torch.float16: np.float16,
-        torch.float32: np.float32,
-    }
-
-    # used for safetensors slices
-    # ref: https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/src/lib.rs#L1046
-    # TODO: uncomment U64, U32, and U16, ref: https://github.com/pytorch/pytorch/issues/58734
-    _dtype_str_map: dict[str, torch.dtype] = {
-        "F64": torch.float64,
-        "F32": torch.float32,
-        "BF16": torch.bfloat16,
-        "F16": torch.float16,
-        # "U64": torch.uint64,
-        "I64": torch.int64,
-        # "U32": torch.uint32,
-        "I32": torch.int32,
-        # "U16": torch.uint16,
-        "I16": torch.int16,
-        "U8": torch.uint8,
-        "I8": torch.int8,
-        "BOOL": torch.bool,
-        "F8_E4M3": torch.float8_e4m3fn,
-        "F8_E5M2": torch.float8_e5m2,
-    }
-
-    def numpy(self) -> gguf.LazyNumpyTensor:
-        dtype = self._dtype_map[self.dtype]
-        return gguf.LazyNumpyTensor(
-            meta=gguf.LazyNumpyTensor.meta_with_dtype_and_shape(dtype, self.shape),
-            args=(self,),
-            func=(lambda s: s.numpy())
-        )
-
-    @classmethod
-    def meta_with_dtype_and_shape(cls, dtype: torch.dtype, shape: tuple[int, ...]) -> Tensor:
-        return torch.empty(size=shape, dtype=dtype, device="meta")
-
-    @classmethod
-    def from_safetensors_slice(cls, st_slice: Any) -> Tensor:
-        dtype = cls._dtype_str_map[st_slice.get_dtype()]
-        shape: tuple[int, ...] = tuple(st_slice.get_shape())
-        lazy = cls(meta=cls.meta_with_dtype_and_shape(dtype, shape), args=(st_slice,), func=lambda s: s[:])
-        return cast(torch.Tensor, lazy)
-
-    @classmethod
-    def __torch_function__(cls, func, types, args=(), kwargs=None):
-        del types  # unused
-
-        if kwargs is None:
-            kwargs = {}
-
-        if func is torch.Tensor.numpy:
-            return args[0].numpy()
-
-        return cls._wrap_fn(func)(*args, **kwargs)
-
-
-class Gemma3VisionTower:
-    hparams: dict
-    gguf_writer: gguf.GGUFWriter
-    fname_out: Path
-    ftype: gguf.LlamaFileType
-
-    @staticmethod
-    def load_hparams(dir_model: Path):
-        with open(dir_model / "config.json", "r", encoding="utf-8") as f:
-            return json.load(f)
-
-    @staticmethod
-    def get_model_part_names(dir_model: Path, prefix: str, suffix: str) -> list[str]:
-        part_names: list[str] = []
-        for filename in os.listdir(dir_model):
-            if filename.startswith(prefix) and filename.endswith(suffix):
-                part_names.append(filename)
-        part_names.sort()
-        return part_names
-
-    def __init__(self,
-                 dir_model: Path,
-                 fname_out: Path,
-                 ftype: gguf.LlamaFileType,
-                 is_big_endian: bool,):
-        hparams = Gemma3VisionTower.load_hparams(dir_model)
-        self.hparams = hparams
-        self.fname_out = fname_out
-        self.ftype = ftype
-        endianess = gguf.GGUFEndian.BIG if is_big_endian else gguf.GGUFEndian.LITTLE
-        self.gguf_writer = gguf.GGUFWriter(path=None, arch="clip", endianess=endianess)
-
-        text_config = hparams["text_config"]
-        vision_config = hparams["vision_config"]
-
-        assert hparams["architectures"][0] == "Gemma3ForConditionalGeneration"
-        assert text_config is not None
-        assert vision_config is not None
-
-        self.gguf_writer.add_string ("clip.projector_type",              "gemma3")
-        self.gguf_writer.add_bool   ("clip.has_text_encoder",            False)
-        self.gguf_writer.add_bool   ("clip.has_vision_encoder",          True)
-        self.gguf_writer.add_bool   ("clip.has_llava_projector",         False) # legacy
-        self.gguf_writer.add_uint32 ("clip.vision.image_size",           vision_config["image_size"])
-        self.gguf_writer.add_uint32 ("clip.vision.patch_size",           vision_config["patch_size"])
-        self.gguf_writer.add_uint32 ("clip.vision.embedding_length",     vision_config["hidden_size"])
-        self.gguf_writer.add_uint32 ("clip.vision.feed_forward_length",  vision_config["intermediate_size"])
-        self.gguf_writer.add_uint32 ("clip.vision.projection_dim",       text_config["hidden_size"])
-        self.gguf_writer.add_uint32 ("clip.vision.block_count",          vision_config["num_hidden_layers"])
-        self.gguf_writer.add_uint32 ("clip.vision.attention.head_count", vision_config["num_attention_heads"])
-        self.gguf_writer.add_float32("clip.vision.attention.layer_norm_epsilon", vision_config.get("layer_norm_eps", 1e-6))
-        # default values taken from HF tranformers code
-        self.gguf_writer.add_array  ("clip.vision.image_mean", [0.5, 0.5, 0.5])
-        self.gguf_writer.add_array  ("clip.vision.image_std",  [0.5, 0.5, 0.5])
-        self.gguf_writer.add_bool   ("clip.use_gelu", True)
-
-        # load tensors
-        for name, data_torch in self.get_tensors(dir_model):
-            # convert any unsupported data types to float32
-            if data_torch.dtype not in (torch.float16, torch.float32):
-                data_torch = data_torch.to(torch.float32)
-            self.add_tensor(name, data_torch)
-
-    def get_tensors(self, dir_model: Path) -> Iterator[tuple[str, Tensor]]:
-        part_names = Gemma3VisionTower.get_model_part_names(dir_model, "model", ".safetensors")
-        tensor_names_from_parts: set[str] = set()
-        for part_name in part_names:
-            logger.info(f"gguf: loading model part '{part_name}'")
-            from safetensors import safe_open
-            ctx = cast(ContextManager[Any], safe_open(dir_model / part_name, framework="pt", device="cpu"))
-            with ctx as model_part:
-                tensor_names_from_parts.update(model_part.keys())
-
-                for name in model_part.keys():
-                    data = model_part.get_slice(name)
-                    data = LazyTorchTensor.from_safetensors_slice(data)
-                    yield name, data
-
-    def add_tensor(self, name: str, data_torch: Tensor):
-        is_1d = len(data_torch.shape) == 1
-        is_embd = ".embeddings." in name
-        old_dtype = data_torch.dtype
-        can_quantize = not is_1d and not is_embd
-        data_qtype = gguf.GGMLQuantizationType.F32
-
-        # this is to support old checkpoint
-        # TODO: remove this when we have the final model
-        name = name.replace("vision_model.vision_model.", "vision_tower.vision_model.")
-        name = name.replace("multimodal_projector.", "multi_modal_projector.")
-
-        # filter only vision tensors
-        if not name.startswith("vision_tower.vision_model.") and not name.startswith("multi_modal_projector."):
-            return
-        # prefix
-        name = name.replace("vision_tower.vision_model.encoder.layers.", "v.blk.")
-        name = name.replace("vision_tower.vision_model.", "v.")
-        # projector and input embd
-        name = name.replace(".embeddings.patch_embedding.", ".patch_embd.")
-        name = name.replace(".embeddings.position_embedding.", ".position_embd.")
-        name = name.replace(
-            "multi_modal_projector.mm_input_projection_weight",
-            "mm.input_projection.weight"
-        )
-        name = name.replace(
-            "multi_modal_projector.mm_soft_emb_norm.weight",
-            "mm.soft_emb_norm.weight"
-        )
-        name = name.replace("post_layernorm.", "post_ln.")
-        # each block
-        name = name.replace(".self_attn.k_proj.", ".attn_k.")
-        name = name.replace(".self_attn.v_proj.", ".attn_v.")
-        name = name.replace(".self_attn.q_proj.", ".attn_q.")
-        name = name.replace(".self_attn.out_proj.", ".attn_out.")
-        name = name.replace(".layer_norm1.", ".ln1.")
-        name = name.replace(".layer_norm2.", ".ln2.")
-        name = name.replace(".mlp.fc1.", ".ffn_down.")
-        name = name.replace(".mlp.fc2.", ".ffn_up.")
-
-        if can_quantize:
-            if self.ftype == gguf.LlamaFileType.ALL_F32:
-                data_qtype = gguf.GGMLQuantizationType.F32
-            elif self.ftype == gguf.LlamaFileType.MOSTLY_F16:
-                data_qtype = gguf.GGMLQuantizationType.F16
-            elif self.ftype == gguf.LlamaFileType.MOSTLY_BF16:
-                data_qtype = gguf.GGMLQuantizationType.BF16
-            elif self.ftype == gguf.LlamaFileType.MOSTLY_Q8_0:
-                data_qtype = gguf.GGMLQuantizationType.Q8_0
-            else:
-                raise ValueError(f"Unsupported file type: {self.ftype}")
-
-        # corrent norm value ; only this "soft_emb_norm" need to be corrected as it's part of Gemma projector
-        # the other norm values are part of SigLIP model, and they are already correct
-        # ref code: Gemma3RMSNorm
-        if "soft_emb_norm.weight" in name:
-            logger.info(f"Correcting norm value for '{name}'")
-            data_torch = data_torch + 1
-
-        data = data_torch.numpy()
-
-        try:
-            data = gguf.quants.quantize(data, data_qtype)
-        except Exception as e:
-            logger.error(f"Error quantizing tensor '{name}': {e}, fallback to F16")
-            data_qtype = gguf.GGMLQuantizationType.F16
-            data = gguf.quants.quantize(data, data_qtype)
-
-        # reverse shape to make it similar to the internal ggml dimension order
-        shape_str = f"{{{', '.join(str(n) for n in reversed(data_torch.shape))}}}"
-        logger.info(f"{f'%-32s' % f'{name},'} {old_dtype} --> {data_qtype.name}, shape = {shape_str}")
-
-        self.gguf_writer.add_tensor(name, data, raw_dtype=data_qtype)
-
-    def write(self):
-        self.gguf_writer.write_header_to_file(path=self.fname_out)
-        self.gguf_writer.write_kv_data_to_file()
-        self.gguf_writer.write_tensors_to_file(progress=True)
-        self.gguf_writer.close()
-
-def parse_args() -> argparse.Namespace:
-    parser = argparse.ArgumentParser(
-        description="Convert Gemma 3 vision tower safetensors to GGUF format",)
-    parser.add_argument(
-        "--outfile", type=Path, default="mmproj.gguf",
-        help="path to write to",
-    )
-    parser.add_argument(
-        "--outtype", type=str, choices=["f32", "f16", "bf16", "q8_0"], default="f16",
-        help="output format",
-    )
-    parser.add_argument(
-        "--bigendian", action="store_true",
-        help="model is executed on big endian machine",
-    )
-    parser.add_argument(
-        "model", type=Path,
-        help="directory containing model file",
-        nargs="?",
-    )
-    parser.add_argument(
-        "--verbose", action="store_true",
-        help="increase output verbosity",
-    )
-
-    args = parser.parse_args()
-    if args.model is None:
-        parser.error("the following arguments are required: model")
-    return args
-
-
-def main() -> None:
-    args = parse_args()
-
-    if args.verbose:
-        logging.basicConfig(level=logging.DEBUG)
-    else:
-        logging.basicConfig(level=logging.INFO)
-
-    dir_model = args.model
-
-    if not dir_model.is_dir():
-        logger.error(f'Error: {args.model} is not a directory')
-        sys.exit(1)
-
-    ftype_map: dict[str, gguf.LlamaFileType] = {
-        "f32": gguf.LlamaFileType.ALL_F32,
-        "f16": gguf.LlamaFileType.MOSTLY_F16,
-        "bf16": gguf.LlamaFileType.MOSTLY_BF16,
-        "q8_0": gguf.LlamaFileType.MOSTLY_Q8_0,
-    }
-
-    logger.info(f"Loading model: {dir_model.name}")
-
-    with torch.inference_mode():
-        gemma3_vision_tower = Gemma3VisionTower(
-            dir_model=dir_model,
-            fname_out=args.outfile,
-            ftype=ftype_map[args.outtype],
-            is_big_endian=args.bigendian,
-        )
-        gemma3_vision_tower.write()
-
-
-if __name__ == '__main__':
-    main()
-

From 7b53389c24a507564be39e1ea82746a39749059b Mon Sep 17 00:00:00 2001
From: Georgi Gerganov 
Date: Tue, 22 Apr 2025 16:15:51 +0300
Subject: [PATCH 26/39] metal : add memory pool for temp allocs (#12850)

* metal : add memory pool for temp allocs (wip) [no ci]

* cont : free buffers from the heap

* cont : resize heap [no ci]

* cont : refactor heap [no ci]

* cont : heap for each cmd buffer [no ci]

* cont : fix free

* wip

* cont : fix alignment [no ci]

* cont : not working .. [no ci]

* cont : heap allocation now works [no ci]

* cont : use MTLHeapTypePlacement

ggml-ci

* metal : use dynamic MTLHeap allocations

ggml-ci

* metal : add comments

* metal : disable softmax use of mem_pool

ggml-ci

* metal : final touches
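
The pool hands out sub-allocations by bumping an offset inside a heap and
resets that offset after each graph node so the same memory backs the next
node's temporary buffers. A minimal C++ sketch of the bump-and-reset idea,
with plain byte vectors standing in for MTLHeap objects (not the actual
backend code):

```cpp
#include <cstddef>
#include <vector>

struct heap {
    std::vector<unsigned char> mem;  // stands in for an MTLHeap
    size_t                     offs; // current bump offset into the heap
};

struct mem_pool {
    std::vector<heap> heaps;

    void * alloc(size_t size) {
        const size_t size_aligned = (size + 31) & ~(size_t) 31; // 32-byte alignment, as in the patch

        // try one of the existing heaps first
        for (auto & h : heaps) {
            if (h.offs + size_aligned <= h.mem.size()) {
                void * res = h.mem.data() + h.offs;
                h.offs += size_aligned;
                return res;
            }
        }

        // no heap can fit this buffer - create a new one sized for it
        heaps.push_back({ std::vector<unsigned char>(size_aligned), size_aligned });
        return heaps.back().mem.data();
    }

    // called between nodes: the memory stays allocated but is reused
    void clear() {
        for (auto & h : heaps) {
            h.offs = 0;
        }
    }
};
```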
---
 ggml/src/ggml-metal/ggml-metal.m | 406 ++++++++++++++++++++++++++++---
 1 file changed, 366 insertions(+), 40 deletions(-)

diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m
index 266d8af469..d92392edb7 100644
--- a/ggml/src/ggml-metal/ggml-metal.m
+++ b/ggml/src/ggml-metal/ggml-metal.m
@@ -44,8 +44,8 @@ static struct ggml_backend_device g_ggml_backend_metal_device;
 // note: assumes single GPU device - the default one
 // TODO: support multiple GPU devices
 static struct ggml_backend_metal_device_context {
-    id<MTLDevice> mtl_device;
-    int           mtl_device_ref_count;
+    id<MTLDevice>  mtl_device;
+    int            mtl_device_ref_count;
     id<MTLLibrary> mtl_library;
 
     bool has_simdgroup_reduction;
@@ -490,7 +490,259 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_COUNT
 };
 
+//
+// ggml_metal_heap
+//
+
+struct ggml_metal_heap {
+    // number of times the heap was unused
+    int n_unused;
+
+    // total number of buffer allocations in this heap across all computes
+    int64_t n_alloc;
+
+    // current offset in the heap - we reset this after each node in order to reuse the memory
+    size_t offs;
+
+    // the currently allocated MTLBuffer objects in this heap
+    id<MTLHeap> obj;
+
+    NSMutableArray * bufs;
+};
+
+static struct ggml_metal_heap * ggml_metal_heap_init(id<MTLDevice> device, size_t size) {
+    struct ggml_metal_heap * heap = calloc(1, sizeof(struct ggml_metal_heap));
+
+    MTLHeapDescriptor * desc = [[MTLHeapDescriptor alloc] init];
+    desc.storageMode  = MTLStorageModePrivate;
+    desc.cpuCacheMode = MTLCPUCacheModeDefaultCache;
+    desc.type         = MTLHeapTypePlacement;
+    desc.size         = size;
+
+    heap->n_unused = 0;
+    heap->n_alloc = 0;
+
+    heap->obj = [device newHeapWithDescriptor:desc];
+    if (!heap->obj) {
+        GGML_LOG_ERROR("%s: error: failed to create MTLHeap with size %zu\n", __func__, size);
+
+        free(heap);
+
+        return NULL;
+    }
+
+    [desc release];
+
+    heap->bufs = [[NSMutableArray alloc] init];
+
+    return heap;
+}
+
+static void ggml_metal_heap_reset(struct ggml_metal_heap * heap) {
+    heap->offs = 0;
+
+    // count how many graph computes the heap ended up being unused
+    if ([heap->bufs count] > 0) {
+        heap->n_unused = 0;
+    } else {
+        heap->n_unused++;
+    }
+
+    for (id<MTLBuffer> buf in heap->bufs) {
+        [buf release];
+    }
+    [heap->bufs removeAllObjects];
+
+    // tell the OS that it can reuse this memory if needed
+    // ref: https://developer.apple.com/documentation/metal/mtlpurgeablestate?language=objc
+    [heap->obj setPurgeableState:MTLPurgeableStateVolatile];
+}
+
+static void ggml_metal_heap_free(struct ggml_metal_heap * heap) {
+    if (heap == nil) {
+        return;
+    }
+
+    ggml_metal_heap_reset(heap);
+
+    [heap->obj  release];
+    [heap->bufs release];
+
+    free(heap);
+}
+
+@interface ggml_metal_heap_ptr : NSObject
+
+@property (nonatomic, assign) struct ggml_metal_heap * data;
+
+@end
+
+@implementation ggml_metal_heap_ptr
+@end
+
+//
+// ggml_metal_mem_pool
+//
+
+struct ggml_metal_mem_pool {
+    id<MTLDevice> device;
+
+    int n_heaps; // total number of heaps ever created (including those that were removed)
+
+    NSMutableArray * heaps;
+    NSMutableArray * heaps_to_remove;
+};
+
+static struct ggml_metal_mem_pool * ggml_metal_mem_pool_init(void) {
+    struct ggml_metal_mem_pool * mem_pool = calloc(1, sizeof(struct ggml_metal_mem_pool));
+
+    mem_pool->n_heaps = 0;
+
+    mem_pool->heaps           = [[NSMutableArray alloc] init];
+    mem_pool->heaps_to_remove = [[NSMutableArray alloc] init];
+
+    return mem_pool;
+}
+
+static void ggml_metal_mem_pool_free(struct ggml_metal_mem_pool * mem_pool) {
+    GGML_LOG_DEBUG("%s: freeing memory pool, num heaps = %zu (total = %d)\n", __func__, [mem_pool->heaps count], mem_pool->n_heaps);
+
+    size_t size_all = 0;
+    size_t size_cur = 0;
+
+    for (ggml_metal_heap_ptr * ptr in mem_pool->heaps) {
+        GGML_LOG_DEBUG("%s:   heap: %p\n",                __func__, (void *) ptr.data);
+        GGML_LOG_DEBUG("%s:     n_alloc:  %" PRId64 "\n", __func__, ptr.data->n_alloc);
+        GGML_LOG_DEBUG("%s:     n_unused: %d\n",          __func__, ptr.data->n_unused);
+        GGML_LOG_DEBUG("%s:     size:     %.2f MiB\n",    __func__, [ptr.data->obj size] / 1024.0 / 1024.0);
+        GGML_LOG_DEBUG("%s:     bufs:     %zu\n",         __func__, [ptr.data->bufs count]);
+
+        if ([ptr.data->bufs count] > 0) {
+            size_cur += [ptr.data->obj size];
+        }
+        size_all += [ptr.data->obj size];
+
+        ggml_metal_heap_free(ptr.data);
+        [ptr release];
+    }
+    [mem_pool->heaps           release];
+    [mem_pool->heaps_to_remove release];
+
+    if (size_all > 0) {
+        GGML_LOG_DEBUG("%s:   size_all: %.2f MiB\n", __func__, size_all / 1024.0 / 1024.0);
+        GGML_LOG_DEBUG("%s:   size_cur: %.2f MiB\n", __func__, size_cur / 1024.0 / 1024.0);
+    }
+
+    free(mem_pool);
+}
+
+static void ggml_metal_mem_pool_reset(struct ggml_metal_mem_pool * mem_pool) {
+    for (NSUInteger i = 0; i < [mem_pool->heaps count]; i++) {
+        ggml_metal_heap_ptr * ptr = [mem_pool->heaps objectAtIndex:i];
+
+        struct ggml_metal_heap * heap = ptr.data;
+        ggml_metal_heap_reset(heap);
+
+        // if the heap hasn't been used for a while, remove it
+        if (heap->n_unused >= 128) {
+            [mem_pool->heaps_to_remove addObject:@(i)];
+        }
+    }
+
+    if (mem_pool->heaps_to_remove.count > 0) {
+        for (NSUInteger i = 0; i < [mem_pool->heaps_to_remove count]; i++) {
+            NSUInteger index = [[mem_pool->heaps_to_remove objectAtIndex:i] intValue];
+            ggml_metal_heap_ptr * ptr = [mem_pool->heaps objectAtIndex:index];
+
+            struct ggml_metal_heap * heap = ptr.data;
+            ggml_metal_heap_free(heap);
+
+            [mem_pool->heaps removeObjectAtIndex:index];
+            [ptr release];
+        }
+
+        [mem_pool->heaps_to_remove removeAllObjects];
+    }
+}
+
+static void ggml_metal_mem_pool_clear(struct ggml_metal_mem_pool * mem_pool) {
+    for (ggml_metal_heap_ptr * ptr in mem_pool->heaps) {
+        ptr.data->offs = 0;
+    }
+}
+
+static id<MTLBuffer> ggml_metal_mem_pool_alloc(struct ggml_metal_mem_pool * mem_pool, size_t size) {
+    const size_t alignment = 32;
+
+    const size_t size_aligned = GGML_PAD(size, alignment);
+
+    // try one of the existing heaps
+    for (ggml_metal_heap_ptr * ptr in mem_pool->heaps) {
+        struct ggml_metal_heap * heap = ptr.data;
+        if (heap->offs + size_aligned <= [heap->obj size]) {
+            // if this is the first buffer in the heap for the current command buffer, tell the OS that
+            //   it cannot free the memory used by the heap
+            // ref: https://developer.apple.com/documentation/metal/mtlpurgeablestate?language=objc
+            if ([heap->bufs count] == 0) {
+                [heap->obj setPurgeableState:MTLPurgeableStateNonVolatile];
+            }
+
+            id<MTLBuffer> buf = [heap->obj newBufferWithLength:size_aligned options:MTLResourceStorageModePrivate offset:heap->offs];
+            if (buf == nil) {
+                GGML_LOG_ERROR("%s: error: failed to create MTLBuffer with size %zu\n", __func__, size_aligned);
+                return nil;
+            }
+
+            heap->n_alloc++;
+            heap->offs += size_aligned;
+
+            [heap->bufs addObject:buf];
+
+            return buf;
+        }
+    }
+
+    // create a new heap that can fit this buffer
+    ggml_metal_heap_ptr * heap_ptr = [ggml_metal_heap_ptr new];
+
+    struct ggml_metal_heap * heap = ggml_metal_heap_init(mem_pool->device, size_aligned);
+    if (heap == NULL) {
+        GGML_LOG_ERROR("%s: error: failed to create heap of size %zu\n", __func__, size_aligned);
+        return NULL;
+    }
+
+    //GGML_LOG_DEBUG("%s: creating new heap of size %zu, got %zu\n", __func__, size_aligned, [heap->obj size]);
+
+    heap_ptr.data = heap;
+    ggml_metal_heap_reset(heap);
+
+    [heap->obj setPurgeableState:MTLPurgeableStateNonVolatile];
+    id<MTLBuffer> buf = [heap->obj newBufferWithLength:size_aligned options:MTLResourceStorageModePrivate offset:heap->offs];
+    if (buf == nil) {
+        GGML_LOG_ERROR("%s: error: failed to create MTLBuffer with size %zu\n", __func__, size_aligned);
+        return NULL;
+    }
+
+    heap->n_alloc++;
+    heap->offs += size_aligned;
+
+    [heap->bufs addObject:buf];
+
+    [mem_pool->heaps addObject:heap_ptr];
+    mem_pool->n_heaps++;
+
+    return buf;
+}
+
+struct ggml_metal_command_buffer {
+    id<MTLCommandBuffer> obj;
+
+    // each command buffer has a memory pool from which it can allocate temporary buffers during the compute
+    struct ggml_metal_mem_pool * mem_pool;
+};
+
 struct ggml_backend_metal_context {
+    id<MTLDevice>       device;
     id<MTLCommandQueue> queue;
 
     dispatch_queue_t d_queue;
@@ -515,7 +767,7 @@ struct ggml_backend_metal_context {
     void (^encode_async)(size_t ith);
 
     // n_cb command buffers + 1 used by the main thread
-    id<MTLCommandBuffer> command_buffers[GGML_METAL_MAX_COMMAND_BUFFERS + 1];
+    struct ggml_metal_command_buffer cmd_bufs[GGML_METAL_MAX_COMMAND_BUFFERS + 1];
 
     // abort ggml_metal_graph_compute if callback returns true
     ggml_abort_callback abort_callback;
@@ -705,9 +957,11 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
     struct ggml_backend_metal_device_context * ctx_dev = dev->context;
 
     id<MTLDevice> device = ggml_backend_metal_device_acq(ctx_dev);
+
     GGML_LOG_INFO("%s: picking default device: %s\n", __func__, [[device name] UTF8String]);
 
-    ctx->queue  = [device newCommandQueue];
+    ctx->device = device;
+    ctx->queue = [device newCommandQueue];
     if (ctx->queue == nil) {
         GGML_LOG_ERROR("%s: error: failed to create command queue\n", __func__);
         return NULL;
@@ -768,7 +1022,10 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
     ctx->gf = nil;
     ctx->encode_async = nil;
     for (int i = 0; i < GGML_METAL_MAX_COMMAND_BUFFERS; ++i) {
-        ctx->command_buffers[i] = nil;
+        ctx->cmd_bufs[i].obj = nil;
+
+        ctx->cmd_bufs[i].mem_pool = ggml_metal_mem_pool_init();
+        ctx->cmd_bufs[i].mem_pool->device = device;
     }
 
 #if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15)
@@ -1181,6 +1438,12 @@ static void ggml_metal_free(struct ggml_backend_metal_context * ctx) {
 
     [ctx->queue release];
 
+    for (int i = 0; i < GGML_METAL_MAX_COMMAND_BUFFERS; ++i) {
+        // ctx->cmd_bufs[i].obj is auto released
+
+        ggml_metal_mem_pool_free(ctx->cmd_bufs[i].mem_pool);
+    }
+
     dispatch_release(ctx->d_queue);
 
     free(ctx);
@@ -1486,10 +1749,11 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
     }
 }
 
-static void ggml_metal_encode_node(
+static bool ggml_metal_encode_node(
                         ggml_backend_t   backend,
                                    int   idx,
-          id<MTLComputeCommandEncoder>   encoder) {
+          id<MTLComputeCommandEncoder>   encoder,
+            struct ggml_metal_mem_pool * mem_pool) {
     struct ggml_backend_metal_context        * ctx     = backend->context;
     struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
 
@@ -1505,7 +1769,7 @@ static void ggml_metal_encode_node(
     struct ggml_tensor * dst  = node;
 
     if (ggml_is_empty(dst)) {
-        return;
+        return true;
     }
 
     switch (dst->op) {
@@ -1516,7 +1780,7 @@ static void ggml_metal_encode_node(
         case GGML_OP_PERMUTE:
             {
                 // noop -> next node
-            } return;
+            } return true;
         default:
             {
             } break;
@@ -1527,6 +1791,8 @@ static void ggml_metal_encode_node(
         GGML_ABORT("unsupported op");
     }
 
+    ggml_metal_mem_pool_clear(mem_pool);
+
     const int64_t  ne00 = src0 ? src0->ne[0] : 0;
     const int64_t  ne01 = src0 ? src0->ne[1] : 0;
     const int64_t  ne02 = src0 ? src0->ne[2] : 0;
@@ -2173,26 +2439,76 @@ static void ggml_metal_encode_node(
                 const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
                 const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
-                ggml_metal_kargs_soft_max args = {
+// use this branch to test the ggml_metal_mem_pool functionality
+#if 0
+                // cpy to tmp buffer in MTLHeap
+
+                id<MTLBuffer> h_src0 = ggml_metal_mem_pool_alloc(mem_pool, ggml_nbytes(src0));
+                if (!h_src0) {
+                    GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, ggml_nbytes(src0));
+                    return false;
+                }
+
+                offs_src0 = 0;
+
+                ggml_metal_kargs_cpy args_cpy = {
                     /*.ne00 =*/ ne00,
                     /*.ne01 =*/ ne01,
                     /*.ne02 =*/ ne02,
-                    /*.scale =*/ scale,
-                    /*.max_bias =*/ max_bias,
-                    /*.m0 =*/ m0,
-                    /*.m1 =*/ m1,
+                    /*.ne03 =*/ ne03,
+                    /*.nb00 =*/ nb00,
+                    /*.nb01 =*/ nb01,
+                    /*.nb02 =*/ nb02,
+                    /*.nb03 =*/ nb03,
+                    /*.ne0  =*/ ne00,
+                    /*.ne1  =*/ ne01,
+                    /*.ne2  =*/ ne02,
+                    /*.ne3  =*/ ne03,
+                    /*.nb0  =*/ nb00,
+                    /*.nb1  =*/ nb01,
+                    /*.nb2  =*/ nb02,
+                    /*.nb3  =*/ nb03,
+                };
+
+                if (src0->type == GGML_TYPE_F16) {
+                    [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F16].pipeline];
+                } else {
+                    [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline];
+                }
+                [encoder setBytes:&args_cpy length:sizeof(args_cpy) atIndex:0];
+                [encoder setBuffer:id_src0  offset:offs_src0        atIndex:1];
+                [encoder setBuffer:h_src0   offset:0                atIndex:2];
+
+                GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0);
+                int nth_cpy = MIN(1024, ne00 / ggml_blck_size(src0->type));
+
+                [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth_cpy, 1, 1)];
+
+#else
+                id<MTLBuffer> h_src0 = id_src0;
+#endif
+                // softmax
+
+                ggml_metal_kargs_soft_max args = {
+                    /*.ne00        =*/ ne00,
+                    /*.ne01        =*/ ne01,
+                    /*.ne02        =*/ ne02,
+                    /*.scale       =*/ scale,
+                    /*.max_bias    =*/ max_bias,
+                    /*.m0          =*/ m0,
+                    /*.m1          =*/ m1,
                     /*.n_head_log2 =*/ n_head_log2,
                 };
 
                 [encoder setComputePipelineState:pipeline];
-                [encoder setBuffer:id_src0 offset:offs_src0   atIndex:0];
+                [encoder setBuffer:h_src0 offset:offs_src0      atIndex:0];
                 if (id_src1) {
-                    [encoder setBuffer:id_src1 offset:offs_src1   atIndex:1];
+                    [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
                 } else {
-                    [encoder setBuffer:id_src0 offset:offs_src0   atIndex:1];
+                    [encoder setBuffer:h_src0 offset:offs_src0  atIndex:1];
                 }
-                [encoder setBuffer:id_dst      offset:offs_dst            atIndex:2];
-                [encoder setBytes:&args        length:sizeof(args)        atIndex:3];
+                [encoder setBuffer:id_dst offset:offs_dst       atIndex:2];
+                [encoder setBytes:&args   length:sizeof(args)   atIndex:3];
 
                 [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
 
@@ -4601,6 +4917,8 @@ static void ggml_metal_encode_node(
                 GGML_ABORT("fatal error");
             }
     }
+
+    return true;
 }
 
 static enum ggml_status ggml_metal_graph_compute(
@@ -4654,25 +4972,25 @@ static enum ggml_status ggml_metal_graph_compute(
         }
 
         // the main thread commits the first few commands immediately
-        // command_buffer[n_cb]
+        // cmd_buf[n_cb]
         {
-            id<MTLCommandBuffer> command_buffer = [ctx->queue commandBufferWithUnretainedReferences];
-            ctx->command_buffers[n_cb] = command_buffer;
+            id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBufferWithUnretainedReferences];
+            ctx->cmd_bufs[n_cb].obj = cmd_buf;
 
-            [command_buffer enqueue];
+            [cmd_buf enqueue];
             ctx->encode_async(n_cb);
         }
 
         // prepare the rest of the command buffers asynchronously
-        // command_buffer[0.. n_cb)
+        // cmd_buf[0.. n_cb)
         for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
-            id<MTLCommandBuffer> command_buffer = [ctx->queue commandBufferWithUnretainedReferences];
-            ctx->command_buffers[cb_idx] = command_buffer;
+            id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBufferWithUnretainedReferences];
+            ctx->cmd_bufs[cb_idx].obj = cmd_buf;
 
             // always enqueue the first two command buffers
             // enqueue all of the command buffers if we don't need to abort
             if (cb_idx < 2 || ctx->abort_callback == NULL) {
-                [command_buffer enqueue];
+                [cmd_buf enqueue];
             }
         }
 
@@ -4681,14 +4999,14 @@ static enum ggml_status ggml_metal_graph_compute(
         // wait for completion and check status of each command buffer
         // needed to detect if the device ran out-of-memory for example (#1881)
         {
-            id<MTLCommandBuffer> command_buffer = ctx->command_buffers[n_cb];
-            [command_buffer waitUntilCompleted];
+            id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[n_cb].obj;
+            [cmd_buf waitUntilCompleted];
 
-            MTLCommandBufferStatus status = [command_buffer status];
+            MTLCommandBufferStatus status = [cmd_buf status];
             if (status != MTLCommandBufferStatusCompleted) {
                 GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, n_cb, status);
                 if (status == MTLCommandBufferStatusError) {
-                    GGML_LOG_INFO("error: %s\n", [[command_buffer error].localizedDescription UTF8String]);
+                    GGML_LOG_INFO("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]);
                 }
 
                 return GGML_STATUS_FAILED;
@@ -4696,20 +5014,20 @@ static enum ggml_status ggml_metal_graph_compute(
         }
 
         for (int i = 0; i < n_cb; ++i) {
-            id<MTLCommandBuffer> command_buffer = ctx->command_buffers[i];
-            [command_buffer waitUntilCompleted];
+            id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[i].obj;
+            [cmd_buf waitUntilCompleted];
 
-            MTLCommandBufferStatus status = [command_buffer status];
+            MTLCommandBufferStatus status = [cmd_buf status];
             if (status != MTLCommandBufferStatusCompleted) {
                 GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, i, status);
                 if (status == MTLCommandBufferStatusError) {
-                    GGML_LOG_INFO("error: %s\n", [[command_buffer error].localizedDescription UTF8String]);
+                    GGML_LOG_INFO("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]);
                 }
 
                 return GGML_STATUS_FAILED;
             }
 
-            id<MTLCommandBuffer> next_buffer = (i + 1 < n_cb ? ctx->command_buffers[i + 1] : nil);
+            id<MTLCommandBuffer> next_buffer = (i + 1 < n_cb ? ctx->cmd_bufs[i + 1].obj : nil);
             if (!next_buffer) {
                 continue;
             }
@@ -5092,8 +5410,9 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
 
         const int n_nodes_per_cb = ctx->n_nodes_per_cb;
 
-        id<MTLCommandBuffer> command_buffer  = ctx->command_buffers[cb_idx];
-        id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoder];
+        id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[cb_idx].obj;
+
+        id<MTLComputeCommandEncoder> encoder = [cmd_buf computeCommandEncoder];
 
         int node_start = 0;
         int node_end   = n_nodes_0;
@@ -5105,22 +5424,29 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
 
         const bool should_capture = ctx->capture_next_compute;
 
+        struct ggml_metal_mem_pool * mem_pool = ctx->cmd_bufs[cb_idx].mem_pool;
+        ggml_metal_mem_pool_reset(mem_pool);
+
         for (int idx = node_start; idx < node_end; ++idx) {
             if (should_capture) {
                 [encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(ggml_graph_node(ctx->gf, idx)) encoding:NSUTF8StringEncoding]];
             }
 
-            ggml_metal_encode_node(backend, idx, encoder);
+            const bool res = ggml_metal_encode_node(backend, idx, encoder, mem_pool);
 
             if (should_capture) {
                 [encoder popDebugGroup];
             }
+
+            if (!res) {
+                break;
+            }
         }
 
         [encoder endEncoding];
 
         if (cb_idx < 2 || ctx->abort_callback == NULL) {
-            [command_buffer commit];
+            [cmd_buf commit];
         }
     });
 }

From ab47dec3d37aa1927c2ec590e166b76141374ed3 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov 
Date: Tue, 22 Apr 2025 16:16:10 +0300
Subject: [PATCH 27/39] security : add note about RPC and server functionality
 (#13061)

* security : add note about RPC functionality

* security : add note about llama-server
---
 SECURITY.md | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/SECURITY.md b/SECURITY.md
index 6a1bb6c32c..9370fb1a88 100644
--- a/SECURITY.md
+++ b/SECURITY.md
@@ -40,7 +40,8 @@ To protect sensitive data from potential leaks or unauthorized access, it is cru
 ### Untrusted environments or networks
 
 If you can't run your models in a secure and isolated environment or if it must be exposed to an untrusted network, make sure to take the following security precautions:
-* Confirm the hash of any downloaded artifact (e.g. pre-trained model weights) matches a known-good value
+* Do not use the RPC backend, [rpc-server](https://github.com/ggml-org/llama.cpp/tree/master/examples/rpc), or [llama-server](https://github.com/ggml-org/llama.cpp/tree/master/examples/server) functionality (see https://github.com/ggml-org/llama.cpp/pull/13061).
+* Confirm the hash of any downloaded artifact (e.g. pre-trained model weights) matches a known-good value.
 * Encrypt your data if sending it over the network.
 
 ### Multi-Tenant environments

From dc39a5e7a84815a90fa0c515ed8927870cf858c9 Mon Sep 17 00:00:00 2001
From: Xuan-Son Nguyen 
Date: Tue, 22 Apr 2025 16:24:54 +0200
Subject: [PATCH 28/39] mtmd : support SmolVLM (version 1 and 2) (#13050)

* mtmd : support SmolVLM (version 1 and 2)

* correct chat template

* fix n_patches

* scale_factor is an int

* add more models to test
---
 convert_hf_to_gguf.py          | 100 +++++++++++++++++++-----
 examples/llava/clip-impl.h     |   7 +-
 examples/llava/clip.cpp        | 135 +++++++++++++++++++++++----------
 examples/llava/mtmd.cpp        |   7 ++
 examples/llava/tests.sh        |  11 ++-
 gguf-py/gguf/constants.py      |   9 +++
 gguf-py/gguf/gguf_writer.py    |  47 ++++++++++++
 gguf-py/gguf/tensor_mapping.py |   4 +-
 src/llama-chat.cpp             |  23 +++++-
 src/llama-chat.h               |   1 +
 10 files changed, 279 insertions(+), 65 deletions(-)
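
The scale_factor above drives the Idefics3/SmolVLM "pixel shuffle" that clip.cpp
implements further down with ggml reshape/permute ops: scale_factor^2 neighboring
patches are folded into the embedding dimension, shrinking the sequence before the
linear projection. A minimal NumPy sketch of the same transform (an illustration of
the idea, not the clip.cpp code; the 64x64 patch grid and 1152-dim embedding are
assumed example values):

    import numpy as np

    def pixel_shuffle(x: np.ndarray, scale_factor: int) -> np.ndarray:
        # x: (seq, n_embd) patch embeddings of a square image
        seq, n_embd = x.shape
        h = w = int(np.sqrt(seq))
        x = x.reshape(h, w // scale_factor, n_embd * scale_factor)
        x = x.transpose(1, 0, 2)
        x = x.reshape(w // scale_factor, h // scale_factor, n_embd * scale_factor**2)
        x = x.transpose(1, 0, 2)
        return x.reshape(seq // scale_factor**2, n_embd * scale_factor**2)

    x = np.random.rand(64 * 64, 1152).astype(np.float32)
    print(pixel_shuffle(x, 2).shape)  # (1024, 4608): 4x fewer tokens, 4x wider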

diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index 6d34541a3c..645bdad9b5 100755
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -419,8 +419,12 @@ class ModelBase:
     def load_hparams(dir_model: Path):
         with open(dir_model / "config.json", "r", encoding="utf-8") as f:
             hparams = json.load(f)
+            architectures = hparams.get("architectures")
             if "text_config" in hparams:
                 hparams = {**hparams, **hparams["text_config"]}
+            if architectures is not None:
+                # preserve "architectures" from root level config
+                hparams["architectures"] = architectures
             return hparams
 
     @classmethod
@@ -1061,6 +1065,8 @@ class TextModel(ModelBase):
 class VisionModel(ModelBase):
     model_arch = gguf.MODEL_ARCH.CLIP_VISION
     n_text_embd = 0
+    preprocessor_config: dict[str, Any]
+    global_config: dict[str, Any]
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
@@ -1075,24 +1081,33 @@ class VisionModel(ModelBase):
 
         if "vision_config" not in self.hparams:
             raise ValueError("vision_config not found in hparams")
-        # move vision config to the top level
+        # move vision config to the top level, while preserving the original hparams in global_config
+        self.global_config = self.hparams
         self.hparams = self.hparams["vision_config"]
 
+        # load preprocessor config
+        with open(self.dir_model / "preprocessor_config.json", "r", encoding="utf-8") as f:
+            self.preprocessor_config = json.load(f)
+
     def set_type(self):
         self.gguf_writer.add_type(gguf.GGUFType.CLIP_VISION)
 
     def set_gguf_parameters(self):
         self.gguf_writer.add_file_type(self.ftype)
-        self.gguf_writer.add_uint32(gguf.Keys.ClipVision.PROJECTION_DIM, self.n_embd_text)
-        self.gguf_writer.add_bool(gguf.Keys.ClipVision.HAS_VISION_ENCODER, True)
+        self.gguf_writer.add_vision_projection_dim(self.n_embd_text)
+        self.gguf_writer.add_vision_has_vision_encoder(True)
 
         # vision config
-        self.gguf_writer.add_uint32(gguf.Keys.ClipVision.IMAGE_SIZE,           self.find_hparam(["image_size"]))
-        self.gguf_writer.add_uint32(gguf.Keys.ClipVision.PATCH_SIZE,           self.find_hparam(["patch_size"]))
-        self.gguf_writer.add_uint32(gguf.Keys.ClipVision.EMBEDDING_LENGTH,     self.find_hparam(["hidden_size"]))
-        self.gguf_writer.add_uint32(gguf.Keys.ClipVision.FEED_FORWARD_LENGTH,  self.find_hparam(["intermediate_size"]))
-        self.gguf_writer.add_uint32(gguf.Keys.ClipVision.BLOCK_COUNT,          self.find_hparam(["num_hidden_layers"]))
-        self.gguf_writer.add_uint32(gguf.Keys.ClipVision.Attention.HEAD_COUNT, self.find_hparam(["num_attention_heads"]))
+        self.gguf_writer.add_vision_image_size(self.find_hparam(["image_size"]))
+        self.gguf_writer.add_vision_patch_size(self.find_hparam(["patch_size"]))
+        self.gguf_writer.add_vision_embedding_length(self.find_hparam(["hidden_size"]))
+        self.gguf_writer.add_vision_feed_forward_length(self.find_hparam(["intermediate_size"]))
+        self.gguf_writer.add_vision_block_count(self.find_hparam(["num_hidden_layers"]))
+        self.gguf_writer.add_vision_head_count(self.find_hparam(["num_attention_heads"]))
+
+        # preprocessor config
+        self.gguf_writer.add_vision_image_mean(self.preprocessor_config["image_mean"])
+        self.gguf_writer.add_vision_image_std(self.preprocessor_config["image_std"])
 
     def write_vocab(self):
         raise ValueError("VisionModel does not support vocab writing")
@@ -1703,11 +1718,23 @@ class StableLMModel(TextModel):
                 raise ValueError(f"Unprocessed norms: {norms}")
 
 
-@ModelBase.register("LLaMAForCausalLM", "LlamaForCausalLM", "MistralForCausalLM", "MixtralForCausalLM")
+@ModelBase.register(
+    "LLaMAForCausalLM",
+    "LlamaForCausalLM",
+    "MistralForCausalLM",
+    "MixtralForCausalLM",
+    "Idefics3ForConditionalGeneration",
+    "SmolVLMForConditionalGeneration")
 class LlamaModel(TextModel):
     model_arch = gguf.MODEL_ARCH.LLAMA
     undo_permute = True
 
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        # fix for SmolVLM2, missing `num_attention_heads` in config.json
+        if self.hparams["architectures"][0] == "SmolVLMForConditionalGeneration":
+            self.hparams["num_attention_heads"] = self.hparams.get("num_attention_heads", 32)
+
     def set_vocab(self):
         try:
             self._set_vocab_sentencepiece()
@@ -1770,6 +1797,12 @@ class LlamaModel(TextModel):
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         n_head = self.hparams["num_attention_heads"]
         n_kv_head = self.hparams.get("num_key_value_heads")
+        is_vision_tensor = "vision_tower" in name or "vision_model" in name or "model.connector" in name
+
+        if is_vision_tensor:
+            return [] # skip vision tensors
+        elif name.startswith("model.text_model"):
+            name = name.replace("text_model.", "") # for SmolVLM
 
         if self.undo_permute:
             if name.endswith(("q_proj.weight", "q_proj.bias")):
@@ -1852,6 +1885,41 @@ class LlamaModel(TextModel):
                 raise ValueError(f"Unprocessed experts: {experts}")
 
 
+@ModelBase.register("Idefics3ForConditionalGeneration", "SmolVLMForConditionalGeneration")
+class SmolVLMModel(VisionModel):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        # fix for SmolVLM2, missing some keys in config.json
+        # default values are taken from transformers code
+        if self.hparams["model_type"] == "smolvlm_vision":
+            self.hparams["hidden_size"] = self.hparams.get("hidden_size", 1152)
+            self.hparams["num_attention_heads"] = self.hparams.get("num_attention_heads", 16)
+            self.hparams["intermediate_size"] = self.hparams.get("intermediate_size", 3072)
+            self.hparams["num_hidden_layers"] = self.hparams.get("num_hidden_layers", 12)
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        self.gguf_writer.add_vision_projector_type(gguf.VisionProjectorType.IDEFICS3)
+        self.gguf_writer.add_vision_attention_layernorm_eps(self.hparams.get("layer_norm_eps", 1e-5))
+        self.gguf_writer.add_vision_projector_scale_factor(self.global_config.get("scale_factor", 2))
+        self.gguf_writer.add_vision_use_gelu(True)
+
+    def tensor_force_quant(self, name, new_name, bid, n_dims):
+        del bid, new_name, n_dims  # unused
+        if ".embeddings." in name:
+            return gguf.GGMLQuantizationType.F32
+        return False
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        del bid  # unused
+        is_vision_tensor = "vision_tower" in name or "vision_model" in name or "model.connector" in name
+
+        if is_vision_tensor:
+            return [(self.map_tensor_name(name), data_torch)]
+
+        return [] # skip other tensors
+
+
 @ModelBase.register("Llama4ForConditionalGeneration")
 class Llama4Model(LlamaModel):
     model_arch = gguf.MODEL_ARCH.LLAMA4
@@ -3591,12 +3659,10 @@ class Gemma3VisionModel(VisionModel):
     def set_gguf_parameters(self):
         super().set_gguf_parameters()
         hparams = self.hparams
-        self.gguf_writer.add_string(gguf.Keys.ClipVision.PROJECTOR_TYPE, "gemma3")
+        self.gguf_writer.add_vision_projector_type(gguf.VisionProjectorType.GEMMA3)
         # default values below are taken from HF transformers code
-        self.gguf_writer.add_float32(gguf.Keys.ClipVision.Attention.LAYERNORM_EPS, hparams.get("layer_norm_eps", 1e-6))
-        self.gguf_writer.add_array(gguf.Keys.ClipVision.IMAGE_MEAN, [0.5, 0.5, 0.5])
-        self.gguf_writer.add_array(gguf.Keys.ClipVision.IMAGE_STD,  [0.5, 0.5, 0.5])
-        self.gguf_writer.add_bool (gguf.Keys.ClipVision.USE_GELU,   True)
+        self.gguf_writer.add_vision_attention_layernorm_eps(hparams.get("layer_norm_eps", 1e-6))
+        self.gguf_writer.add_vision_use_gelu(True)
 
     def tensor_force_quant(self, name, new_name, bid, n_dims):
         del bid, new_name, n_dims  # unused
@@ -3614,10 +3680,6 @@ class Gemma3VisionModel(VisionModel):
                 or name.startswith("multimodal_projector.") or name.startswith("vision_model."):
             # process vision tensors
             name = name.replace("_weight", ".weight")
-            if "fc1" in name:
-                name = name.replace("fc1", "fc2")
-            else:
-                name = name.replace("fc2", "fc1")
 
             # correct norm value; only this "soft_emb_norm" needs to be corrected as it's part of the Gemma projector
             # the other norm values are part of SigLIP model, and they are already correct
diff --git a/examples/llava/clip-impl.h b/examples/llava/clip-impl.h
index 180ae9880b..faa8a9d5e9 100644
--- a/examples/llava/clip-impl.h
+++ b/examples/llava/clip-impl.h
@@ -33,13 +33,13 @@
 #define KEY_LAYER_NORM_EPS      "clip.%s.attention.layer_norm_epsilon"
 #define KEY_PROJ_DIM            "clip.%s.projection_dim"
 #define KEY_TOKENS              "tokenizer.ggml.tokens"
-#define KEY_N_POSITIONS         "clip.text.context_length"
 #define KEY_IMAGE_SIZE          "clip.vision.image_size"
 #define KEY_PATCH_SIZE          "clip.vision.patch_size"
 #define KEY_IMAGE_MEAN          "clip.vision.image_mean"
 #define KEY_IMAGE_STD           "clip.vision.image_std"
-#define KEY_PROJ_TYPE           "clip.projector_type"
 #define KEY_FEATURE_LAYER       "clip.vision.feature_layer"
+#define KEY_PROJ_SCALE_FACTOR   "clip.vision.projector.scale_factor"
+#define KEY_PROJ_TYPE           "clip.projector_type"
 
 #define KEY_MM_PATCH_MERGE_TYPE   "clip.vision.mm_patch_merge_type"
 #define KEY_IMAGE_GRID_PINPOINTS  "clip.vision.image_grid_pinpoints"
@@ -72,6 +72,7 @@
 #define TN_IMAGE_NEWLINE   "model.image_newline"
 #define TN_MM_INP_PROJ     "mm.input_projection.weight" // gemma3
 #define TN_MM_SOFT_EMB_N   "mm.soft_emb_norm.weight"    // gemma3
+#define TN_MM_PROJECTOR    "mm.model.fc.weight"         // idefics3
 
 // mimicpmv
 #define TN_MINICPMV_POS_EMBD_K "resampler.pos_embed_k"
@@ -99,6 +100,7 @@ enum projector_type {
     PROJECTOR_TYPE_GLM_EDGE,
     PROJECTOR_TYPE_MERGER,
     PROJECTOR_TYPE_GEMMA3,
+    PROJECTOR_TYPE_IDEFICS3,
     PROJECTOR_TYPE_UNKNOWN,
 };
 
@@ -110,6 +112,7 @@ static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
     { PROJECTOR_TYPE_GLM_EDGE,  "adapter"},
     { PROJECTOR_TYPE_MERGER,    "qwen2vl_merger"},
     { PROJECTOR_TYPE_GEMMA3,    "gemma3"},
+    { PROJECTOR_TYPE_IDEFICS3,  "idefics3"},
 };
 
 static projector_type clip_projector_type_from_string(const std::string & str) {
diff --git a/examples/llava/clip.cpp b/examples/llava/clip.cpp
index 759706156d..ad39e1bd61 100644
--- a/examples/llava/clip.cpp
+++ b/examples/llava/clip.cpp
@@ -159,6 +159,7 @@ struct clip_hparams {
     int32_t projection_dim;
     int32_t n_head;
     int32_t n_layer;
+    int32_t proj_scale_factor = 0; // idefics3
 
     patch_merge_type mm_patch_merge_type = PATCH_MERGE_FLAT;
 
@@ -506,6 +507,35 @@ static ggml_cgraph * clip_image_build_graph_siglip(clip_ctx * ctx, const clip_im
         embeddings = ggml_mul_mat(ctx0,
             ggml_cont(ctx0, ggml_transpose(ctx0, model.mm_input_proj_w)),
             embeddings);
+
+    } else if (ctx->proj_type == PROJECTOR_TYPE_IDEFICS3) {
+        // https://github.com/huggingface/transformers/blob/0a950e0bbe1ed58d5401a6b547af19f15f0c195e/src/transformers/models/idefics3/modeling_idefics3.py#L578
+
+        ggml_tensor * cur = embeddings;
+        const int scale_factor = model.hparams.proj_scale_factor;
+        const int n_embd = cur->ne[0];
+        const int seq    = cur->ne[1];
+        const int bsz    = 1; // batch size, always 1 for now since we don't support batching
+        const int height = std::sqrt(seq);
+        const int width  = std::sqrt(seq);
+        GGML_ASSERT(scale_factor != 0);
+        cur = ggml_reshape_4d(ctx0, cur, n_embd * scale_factor, width / scale_factor, height, bsz);
+        cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
+        cur = ggml_reshape_4d(ctx0, ggml_cont(ctx0, cur),
+            n_embd * scale_factor * scale_factor,
+            height / scale_factor,
+            width / scale_factor,
+            bsz);
+        cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
+        cur = ggml_reshape_3d(ctx0, ggml_cont(ctx0, cur),
+            n_embd * scale_factor * scale_factor,
+            seq / (scale_factor * scale_factor),
+            bsz);
+
+        cur = ggml_mul_mat(ctx0, model.projection, cur);
+        embeddings = cur;
+    } else {
+        GGML_ABORT("SigLIP: Unsupported projector type");
     }
 
     // build the graph
@@ -1081,12 +1111,20 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im
 }
 
 static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32_batch & imgs, struct clip_image_size load_image_size, bool is_inf = false) {
-    if (ctx->proj_type == PROJECTOR_TYPE_GEMMA3) {
-        return clip_image_build_graph_siglip(ctx, imgs);
-    } else {
-        // TODO: we should have one build_* function per model
-        return clip_image_build_graph_legacy(ctx, imgs, load_image_size, is_inf);
+    ggml_cgraph * res;
+    switch (ctx->proj_type) {
+        case PROJECTOR_TYPE_GEMMA3:
+        case PROJECTOR_TYPE_IDEFICS3:
+            {
+                res = clip_image_build_graph_siglip(ctx, imgs);
+            } break;
+        default:
+            {
+                // TODO: we should have one build_* function per model
+                res = clip_image_build_graph_legacy(ctx, imgs, load_image_size, is_inf);
+            } break;
     }
+    return res;
 }
 
 struct clip_model_loader {
@@ -1147,6 +1185,8 @@ struct clip_model_loader {
     }
 
     void load_hparams() {
+        auto & hparams = ctx_clip.vision_model.hparams;
+
         // projector type
         {
             std::string proj_type;
@@ -1177,7 +1217,6 @@ struct clip_model_loader {
             get_bool(KEY_USE_GELU, ctx_clip.use_gelu, false);
             get_bool(KEY_USE_SILU, ctx_clip.use_silu, false);
 
-            auto & hparams = ctx_clip.vision_model.hparams;
             get_u32(string_format(KEY_N_EMBD,         "vision"), hparams.hidden_size);
             get_u32(string_format(KEY_N_HEAD,         "vision"), hparams.n_head);
             get_u32(string_format(KEY_N_FF,           "vision"), hparams.n_intermediate);
@@ -1233,6 +1272,16 @@ struct clip_model_loader {
             LOG_INF("%s: model size:         %.2f MiB\n", __func__, model_size / 1024.0 / 1024.0);
             LOG_INF("%s: metadata size:      %.2f MiB\n", __func__, ggml_get_mem_size(ctx_meta.get()) / 1024.0 / 1024.0);
         }
+
+        // model-specific params
+        switch (ctx_clip.proj_type) {
+            case PROJECTOR_TYPE_IDEFICS3:
+                {
+                    get_u32(KEY_PROJ_SCALE_FACTOR, hparams.proj_scale_factor, false);
+                } break;
+            default:
+                break;
+        }
     }
 
     void load_tensors() {
@@ -1422,6 +1471,10 @@ struct clip_model_loader {
                     vision_model.mm_input_proj_w = get_tensor(TN_MM_INP_PROJ);
                     vision_model.mm_soft_emb_norm_w = get_tensor(TN_MM_SOFT_EMB_N);
                 } break;
+            case PROJECTOR_TYPE_IDEFICS3:
+                {
+                    vision_model.projection = get_tensor(TN_MM_PROJECTOR);
+                } break;
             default:
                 GGML_ASSERT(false && "unknown projector type");
         }
@@ -2195,10 +2248,12 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
         return true;
     }
 
-    if (ctx->has_glm_projector || ctx->proj_type == PROJECTOR_TYPE_GEMMA3) {
+    if (ctx->has_glm_projector
+            || ctx->proj_type == PROJECTOR_TYPE_GEMMA3
+            || ctx->proj_type == PROJECTOR_TYPE_IDEFICS3) {
         clip_image_u8 resized_image;
         int sz = params.image_size;
-        image_manipulation::bicubic_resize(*img, resized_image, sz, sz);
+        image_manipulation::resize_and_pad_image(*img, resized_image, {sz, sz});
         clip_image_f32_ptr img_f32(clip_image_f32_init());
         //clip_image_save_to_bmp(resized_image, "resized.bmp");
         normalize_image_u8_to_f32(resized_image, *img_f32, ctx->image_mean, ctx->image_std);
@@ -2330,6 +2385,8 @@ int clip_n_patches_by_img(const struct clip_ctx * ctx, struct clip_image_f32 * i
         n_patches = x_patch * y_patch;
     } else if (ctx->proj_type == PROJECTOR_TYPE_GEMMA3) {
         n_patches = 256;
+    } else if (ctx->proj_type == PROJECTOR_TYPE_IDEFICS3) {
+        n_patches /= ctx->vision_model.hparams.proj_scale_factor;
     }
 
     return n_patches;
@@ -2597,6 +2654,9 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
         else if (ctx->proj_type == PROJECTOR_TYPE_GEMMA3) {
             // do nothing
         }
+        else if (ctx->proj_type == PROJECTOR_TYPE_IDEFICS3) {
+            // do nothing
+        }
         else {
             struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions");
 
@@ -2783,37 +2843,34 @@ bool clip_model_quantize(const char * fname_inp, const char * fname_out, const i
 }
 
 int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
-    if (ctx->proj_type == PROJECTOR_TYPE_LDP) {
-        return ctx->vision_model.mm_model_block_1_block_2_1_b->ne[0];
-    }
-    if (ctx->proj_type == PROJECTOR_TYPE_LDPV2) {
-        return ctx->vision_model.mm_model_peg_0_b->ne[0];
-    }
-    if (ctx->proj_type == PROJECTOR_TYPE_MLP) {
-        return ctx->vision_model.mm_2_b->ne[0];
-    }
-    if (ctx->proj_type == PROJECTOR_TYPE_MLP_NORM) {
-        return ctx->vision_model.mm_3_b->ne[0];
-    }
-    if (ctx->proj_type == PROJECTOR_TYPE_RESAMPLER) {
-        if (ctx->minicpmv_version == 2) {
-            return 4096;
-        }
-        else if (ctx->minicpmv_version == 3) {
-            return 3584;
-        }
-        else if (ctx->minicpmv_version == 4) {
-            return 3584;
-        }
-    }
-    if (ctx->proj_type == PROJECTOR_TYPE_GLM_EDGE){
-        return ctx->vision_model.mm_model_mlp_3_w->ne[1];
-    }
-    if (ctx->proj_type == PROJECTOR_TYPE_MERGER) {
-        return ctx->vision_model.mm_1_b->ne[0];
-    }
-    if (ctx->proj_type == PROJECTOR_TYPE_GEMMA3) {
-        return ctx->vision_model.mm_input_proj_w->ne[0];
+    switch (ctx->proj_type) {
+        case PROJECTOR_TYPE_LDP:
+            return ctx->vision_model.mm_model_block_1_block_2_1_b->ne[0];
+        case PROJECTOR_TYPE_LDPV2:
+            return ctx->vision_model.mm_model_peg_0_b->ne[0];
+        case PROJECTOR_TYPE_MLP:
+            return ctx->vision_model.mm_2_b->ne[0];
+        case PROJECTOR_TYPE_MLP_NORM:
+            return ctx->vision_model.mm_3_b->ne[0];
+        case PROJECTOR_TYPE_RESAMPLER:
+            if (ctx->minicpmv_version == 2) {
+                return 4096;
+            } else if (ctx->minicpmv_version == 3) {
+                return 3584;
+            } else if (ctx->minicpmv_version == 4) {
+                return 3584;
+            }
+            break; // Should not happen if version is valid
+        case PROJECTOR_TYPE_GLM_EDGE:
+            return ctx->vision_model.mm_model_mlp_3_w->ne[1];
+        case PROJECTOR_TYPE_MERGER:
+            return ctx->vision_model.mm_1_b->ne[0];
+        case PROJECTOR_TYPE_GEMMA3:
+            return ctx->vision_model.mm_input_proj_w->ne[0];
+        case PROJECTOR_TYPE_IDEFICS3:
+            return ctx->vision_model.projection->ne[1];
+        default:
+            break; // Fall through to throw
     }
 
     std::string proj_type = PROJECTOR_TYPE_NAMES[ctx->proj_type];
diff --git a/examples/llava/mtmd.cpp b/examples/llava/mtmd.cpp
index 8866c12e4d..c3fb2f18a2 100644
--- a/examples/llava/mtmd.cpp
+++ b/examples/llava/mtmd.cpp
@@ -176,6 +176,8 @@ int32_t mtmd_tokenize(mtmd_context * ctx,
 
     std::string prompt_modified(text.text);
     std::string marker_modified(ctx->image_marker);
+    projector_type proj_type = clip_get_projector_type(ctx->ctx_clip);
+
     // a bit hacky here, but works for now
     // for some models, we need to add prefix and suffix to the image embeddings
     if (clip_is_gemma3(ctx->ctx_clip)) {
@@ -183,6 +185,11 @@ int32_t mtmd_tokenize(mtmd_context * ctx,
         //  <start_of_image> ... (image embeddings) ... <end_of_image>
         marker_modified = "<start_of_image>" + ctx->image_marker + "<end_of_image>";
         string_replace_all(prompt_modified, ctx->image_marker, marker_modified);
+
+    } else if (proj_type == PROJECTOR_TYPE_IDEFICS3) {
+        // https://github.com/huggingface/transformers/blob/a42ba80fa520c784c8f11a973ca9034e5f859b79/src/transformers/models/idefics3/processing_idefics3.py#L192-L215
+        marker_modified = "" + ctx->image_marker + "";
+        string_replace_all(prompt_modified, ctx->image_marker, marker_modified);
     }
 
     // llava-1.5, llava-1.6, Yi-VL, Yi-34B, granite: don't need to add prefix and suffix
diff --git a/examples/llava/tests.sh b/examples/llava/tests.sh
index 61ebb3ac18..8752fc267a 100755
--- a/examples/llava/tests.sh
+++ b/examples/llava/tests.sh
@@ -28,6 +28,9 @@ add_test() {
     arr_tmpl+=("$tmpl")
 }
 
+add_test "llama-mtmd-cli"  "ggml-org/SmolVLM-500M-Instruct-GGUF:Q8_0"
+add_test "llama-mtmd-cli"  "ggml-org/SmolVLM2-2.2B-Instruct-GGUF:Q4_K_M"
+add_test "llama-mtmd-cli"  "ggml-org/SmolVLM2-500M-Video-Instruct-GGUF:Q8_0"
 add_test "llama-mtmd-cli"  "ggml-org/gemma-3-4b-it-GGUF:Q4_K_M"
 add_test "llama-mtmd-cli"  "guinmoon/MobileVLM-3B-GGUF:Q4_K_M"               "deepseek"
 add_test "llama-mtmd-cli"  "THUDM/glm-edge-v-5b-gguf:Q4_K_M"
@@ -39,7 +42,13 @@ add_test "llama-mtmd-cli"  "openbmb/MiniCPM-V-2_6-gguf:Q2_K"
 add_test "llama-mtmd-cli"  "openbmb/MiniCPM-o-2_6-gguf:Q4_0"
 add_test "llama-qwen2vl-cli"  "bartowski/Qwen2-VL-2B-Instruct-GGUF:Q4_K_M"
 
-# add_test "llama-mtmd-cli"  "cmp-nct/Yi-VL-6B-GGUF:Q5_K"  # this model has broken chat template, not usable
+# these models always give the wrong answer, not sure why
+# add_test "llama-mtmd-cli"  "ggml-org/SmolVLM-Instruct-GGUF:Q4_K_M"
+# add_test "llama-mtmd-cli"  "ggml-org/SmolVLM-256M-Instruct-GGUF:Q8_0"
+# add_test "llama-mtmd-cli"  "ggml-org/SmolVLM2-256M-Video-Instruct-GGUF:Q8_0"
+
+# this model has broken chat template, not usable
+# add_test "llama-mtmd-cli"  "cmp-nct/Yi-VL-6B-GGUF:Q5_K"
 
 ###############
 
diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py
index 3f24705201..59510bd0c2 100644
--- a/gguf-py/gguf/constants.py
+++ b/gguf-py/gguf/constants.py
@@ -231,11 +231,15 @@ class Keys:
         IMAGE_MEAN          = "clip.vision.image_mean"
         IMAGE_STD           = "clip.vision.image_std"
         USE_GELU            = "clip.use_gelu"
+        USE_SILU            = "clip.use_silu"
 
         class Attention:
             HEAD_COUNT      = "clip.vision.attention.head_count"
             LAYERNORM_EPS   = "clip.vision.attention.layer_norm_epsilon"
 
+        class Projector:
+            SCALE_FACTOR    = "clip.vision.projector.scale_factor"
+
 #
 # recommended mapping of model tensor names for storage in gguf
 #
@@ -2122,6 +2126,11 @@ class GGUFValueType(IntEnum):
             raise ValueError(f"Unknown type: {type(val)}")
 
 
+class VisionProjectorType:
+    GEMMA3 = "gemma3"
+    IDEFICS3 = "idefics3"
+
+
 # Items here are (block size, type size)
 QK_K = 256
 GGML_QUANT_SIZES: dict[GGMLQuantizationType, tuple[int, int]] = {
diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py
index aef03db157..48e9a470b7 100644
--- a/gguf-py/gguf/gguf_writer.py
+++ b/gguf-py/gguf/gguf_writer.py
@@ -931,6 +931,53 @@ class GGUFWriter:
     def add_eom_token_id(self, id: int) -> None:
         self.add_uint32(Keys.Tokenizer.EOM_ID, id)
 
+    # for vision models
+
+    def add_vision_projection_dim(self, value: int) -> None:
+        self.add_uint32(Keys.ClipVision.PROJECTION_DIM, value)
+
+    def add_vision_has_vision_encoder(self, value: bool) -> None:
+        self.add_bool(Keys.ClipVision.HAS_VISION_ENCODER, value)
+
+    def add_vision_patch_size(self, value: int) -> None:
+        self.add_uint32(Keys.ClipVision.PATCH_SIZE, value)
+
+    def add_vision_embedding_length(self, value: int) -> None:
+        self.add_uint32(Keys.ClipVision.EMBEDDING_LENGTH, value)
+
+    def add_vision_feed_forward_length(self, value: int) -> None:
+        self.add_uint32(Keys.ClipVision.FEED_FORWARD_LENGTH, value)
+
+    def add_vision_block_count(self, value: int) -> None:
+        self.add_uint32(Keys.ClipVision.BLOCK_COUNT, value)
+
+    def add_vision_head_count(self, value: int) -> None:
+        self.add_uint32(Keys.ClipVision.Attention.HEAD_COUNT, value)
+
+    def add_vision_projector_type(self, value: str) -> None:
+        self.add_string(Keys.ClipVision.PROJECTOR_TYPE, value)
+
+    def add_vision_attention_layernorm_eps(self, value: float) -> None:
+        self.add_float32(Keys.ClipVision.Attention.LAYERNORM_EPS, value)
+
+    def add_vision_image_size(self, value: int) -> None:
+        self.add_uint32(Keys.ClipVision.IMAGE_SIZE, value)
+
+    def add_vision_image_mean(self, values: Sequence[float]) -> None:
+        self.add_array(Keys.ClipVision.IMAGE_MEAN, values)
+
+    def add_vision_image_std(self, values: Sequence[float]) -> None:
+        self.add_array(Keys.ClipVision.IMAGE_STD, values)
+
+    def add_vision_use_gelu(self, value: bool) -> None:
+        self.add_bool(Keys.ClipVision.USE_GELU, value)
+
+    def add_vision_use_silu(self, value: bool) -> None:
+        self.add_bool(Keys.ClipVision.USE_SILU, value)
+
+    def add_vision_projector_scale_factor(self, value: int) -> None:
+        self.add_uint32(Keys.ClipVision.Projector.SCALE_FACTOR, value)
+
     def _pack(self, fmt: str, value: Any, skip_pack_prefix: bool = False) -> bytes:
         pack_prefix = ''
         if not skip_pack_prefix:
diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py
index 22066b2868..3ff378c136 100644
--- a/gguf-py/gguf/tensor_mapping.py
+++ b/gguf-py/gguf/tensor_mapping.py
@@ -961,13 +961,13 @@ class TensorNameMap:
         MODEL_TENSOR.V_ENC_FFN_UP: (
             "vision_tower.vision_model.encoder.layers.{bid}.mlp.fc1",
             "vpm.encoder.layers.{bid}.mlp.fc1",
-            "model.vision_model.encoder.layers.{bid}.mlp.fc1", # SmolVLM
+            "model.vision_model.encoder.layers.{bid}.mlp.fc2", # SmolVLM, gemma3 (note: name is swapped)
         ),
 
         MODEL_TENSOR.V_ENC_FFN_DOWN: (
             "vision_tower.vision_model.encoder.layers.{bid}.mlp.fc2",
             "vpm.encoder.layers.{bid}.mlp.fc2",
-            "model.vision_model.encoder.layers.{bid}.mlp.fc2", # SmolVLM
+            "model.vision_model.encoder.layers.{bid}.mlp.fc1", # SmolVLM, gemma3 (note: name is swapped)
         ),
 
         MODEL_TENSOR.V_PRE_NORM: (
diff --git a/src/llama-chat.cpp b/src/llama-chat.cpp
index f62850ca57..41f89e3a9d 100644
--- a/src/llama-chat.cpp
+++ b/src/llama-chat.cpp
@@ -62,6 +62,7 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
     { "yandex",            LLM_CHAT_TEMPLATE_YANDEX            },
     { "bailing",           LLM_CHAT_TEMPLATE_BAILING           },
     { "llama4",            LLM_CHAT_TEMPLATE_LLAMA4            },
+    { "smolvlm",           LLM_CHAT_TEMPLATE_SMOLVLM           },
 };
 
 llm_chat_template llm_chat_template_from_str(const std::string & name) {
@@ -81,7 +82,9 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
     if (tmpl_contains("<|im_start|>")) {
         return tmpl_contains("<|im_sep|>")
             ? LLM_CHAT_TEMPLATE_PHI_4
-            : LLM_CHAT_TEMPLATE_CHATML;
+            : tmpl_contains("")
+                ? LLM_CHAT_TEMPLATE_SMOLVLM // SmolVLM uses <|im_start|> as BOS, but it is NOT chatml
+                : LLM_CHAT_TEMPLATE_CHATML;
     } else if (tmpl.find("mistral") == 0 || tmpl_contains("[INST]")) {
         if (tmpl_contains("[SYSTEM_PROMPT]")) {
             return LLM_CHAT_TEMPLATE_MISTRAL_V7;
@@ -622,7 +625,23 @@ int32_t llm_chat_apply_template(
         if (add_ass) {
             ss << "<|header_start|>assistant<|header_end|>\n\n";
         }
-    }  else {
+    } else if (tmpl == LLM_CHAT_TEMPLATE_SMOLVLM) {
+        // SmolVLM
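+        // renders roughly as (an illustrative sketch of the expected output):
+        //   <|im_start|>{system}\n\nUser: {user}<end_of_utterance>\nAssistant: {reply}<end_of_utterance>\nAssistant: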
+        ss << "<|im_start|>"; // uses <|im_start|> as BOS, but the actual content is NOT chatml
+        for (auto message : chat) {
+            std::string role(message->role);
+            if (role == "system") {
+                ss << message->content << "\n\n";
+            } else if (role == "user") {
+                ss << "User: " << message->content << "\n";
+            } else {
+                ss << "Assistant: " << message->content << "\n";
+            }
+        }
+        if (add_ass) {
+            ss << "Assistant:";
+        }
+    } else {
         // template not supported
         return -1;
     }
diff --git a/src/llama-chat.h b/src/llama-chat.h
index 34537ca21e..dc30df711a 100644
--- a/src/llama-chat.h
+++ b/src/llama-chat.h
@@ -41,6 +41,7 @@ enum llm_chat_template {
     LLM_CHAT_TEMPLATE_YANDEX,
     LLM_CHAT_TEMPLATE_BAILING,
     LLM_CHAT_TEMPLATE_LLAMA4,
+    LLM_CHAT_TEMPLATE_SMOLVLM,
     LLM_CHAT_TEMPLATE_UNKNOWN,
 };
 

From 658987cfc9d752dca7758987390d5fb1a7a0a54a Mon Sep 17 00:00:00 2001
From: Johannes Gäßler 
Date: Tue, 22 Apr 2025 21:27:40 +0200
Subject: [PATCH 29/39] CUDA: noncont MMVQ + batched bs1 MUL_MAT_ID (#13014)

* CUDA: noncont MMVQ + batched bs1 MUL_MAT_ID

* fix logic for RoPE support, CUDA graphs
---
 ggml/src/ggml-cuda/ggml-cuda.cu | 165 +++++-----
 ggml/src/ggml-cuda/mmv.cu       | 179 +++++++----
 ggml/src/ggml-cuda/mmv.cuh      |   2 +-
 ggml/src/ggml-cuda/mmvq.cu      | 543 ++++++++++++++++++--------------
 ggml/src/ggml-cuda/mmvq.cuh     |   3 +
 ggml/src/ggml-cuda/quantize.cu  |  66 ++--
 ggml/src/ggml-cuda/quantize.cuh |  12 +-
 ggml/src/ggml-cuda/vecdotq.cuh  |   2 +
 tests/test-backend-ops.cpp      |   2 +-
 9 files changed, 548 insertions(+), 426 deletions(-)
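
The core of the noncont MMVQ / batched-bs1 MUL_MAT_ID change is that each dst
channel of the vector kernels can now resolve its src0 channel through the expert
ids tensor instead of a fixed broadcast ratio. A small Python sketch of the index
math in the updated mul_mat_vec kernel below (a paraphrase for illustration, not
the CUDA code; `ids` holds the expert chosen for each dst channel):

    def resolve_channels(channel_dst, ids, nchannels_y, channel_ratio):
        if ids is not None:
            channel_x = ids[channel_dst]              # expert matrix picked by the router
            channel_y = channel_dst % nchannels_y     # matching src1 channel
        else:
            channel_x = channel_dst // channel_ratio  # plain GQA-style broadcast
            channel_y = channel_dst
        return channel_x, channel_y

    # e.g. two tokens routed to experts 2 and 0, with nchannels_y == 2:
    assert resolve_channels(0, [2, 0], 2, 1) == (2, 0)
    assert resolve_channels(1, [2, 0], 2, 1) == (0, 1)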

diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index a7febef723..e0e0d2137f 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -1410,6 +1410,11 @@ static void ggml_cuda_op_mul_mat(
     const int64_t ne0 = dst->ne[0];
     const int64_t ne1 = dst->ne[1];
 
+    // const int64_t nb10 = src1->nb[0];
+    const int64_t nb11 = src1->nb[1];
+    const int64_t nb12 = src1->nb[2];
+    const int64_t nb13 = src1->nb[3];
+
     const int64_t nb2 = dst->nb[2];
     const int64_t nb3 = dst->nb[3];
 
@@ -1545,7 +1550,10 @@ static void ggml_cuda_op_mul_mat(
             dev[id].src1_ddq = dev[id].src1_ddq_alloc.alloc(ctx.pool(id), src_1_ddq_size);
 
             if (src1_on_device && src1_is_contiguous) {
-                quantize_src1(dev[id].src1_ddf, dev[id].src1_ddq, ne10, ne11, ne12*ne13, src1_padded_col_size, src0->type, stream);
+                quantize_src1(
+                    dev[id].src1_ddf, dev[id].src1_ddq, src0->type, ne10,
+                    nb11/sizeof(float), nb12/sizeof(float), nb13/sizeof(float),
+                    src1_padded_col_size, ne11, ne12, ne13, stream);
                 CUDA_CHECK(cudaGetLastError());
             }
         }
@@ -1640,7 +1648,9 @@ static void ggml_cuda_op_mul_mat(
                 }
 
                 if (quantize_src1 && !src1_is_contiguous) {
-                    quantize_src1(src1_ddf_i, src1_ddq_i, ne10, src1_ncols, 1, src1_padded_col_size, src0->type, stream);
+                    quantize_src1(
+                        src1_ddf_i, src1_ddq_i, src0->type, ne10, ne10, ne11*ne10, ne12*ne11*ne10,
+                        src1_padded_col_size, src1_ncols, 1, 1, stream);
                     CUDA_CHECK(cudaGetLastError());
                 }
 
@@ -1878,7 +1888,7 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
 static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     const bool split = ggml_backend_buft_is_cuda_split(src0->buffer->buft);
 
-    bool use_mul_mat_vec   = (src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16)
+    bool use_mul_mat_vec   = (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16)
         && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
         && src0->ne[0] % 2 == 0 && src1->ne[1] == 1;
     bool use_mul_mat_vec_q = ggml_is_quantized(src0->type)
@@ -1919,10 +1929,12 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
     //printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name);
     //printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name);
 
-    if (!split && use_mul_mat_vec && (src0->ne[1] < MMV_MAX_ROWS || any_gpus_without_fp16_mma)) {
+    if (!split && use_mul_mat_vec && (src0->ne[1] <= MMV_MAX_ROWS || any_gpus_without_fp16_mma)) {
         // the custom F16 vector kernel can be used over batched cuBLAS GEMM
         // but this is only faster for GPUs without tensor cores or with a thin src0 matrix (particularly KQV in attention)
-        ggml_cuda_mul_mat_vec(ctx, src0, src1, dst);
+        ggml_cuda_mul_mat_vec(ctx, src0, src1, nullptr, dst);
+    } else if (!split && use_mul_mat_vec_q) {
+        ggml_cuda_mul_mat_vec_q(ctx, src0, src1, nullptr, dst);
     } else if (!split && src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16)
                && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
         // general KQ + KQV multi-batch without FlashAttention
@@ -1999,6 +2011,15 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
 
     GGML_TENSOR_BINARY_OP_LOCALS
 
+    if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 && ne2 == 1) {
+        if (ggml_is_quantized(src0->type)) {
+            ggml_cuda_mul_mat_vec_q(ctx, src0, src1, ids, dst);
+        } else {
+            ggml_cuda_mul_mat_vec(ctx, src0, src1, ids, dst);
+        }
+        return;
+    }
+
     GGML_ASSERT(!ggml_backend_buft_is_cuda_split(src0->buffer->buft) && "mul_mat_id does not support split buffers");
 
     cudaStream_t stream = ctx.stream();
@@ -2035,97 +2056,75 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
     dst_row.nb[2] = nb1;
     dst_row.nb[3] = nb1;
 
-    if (ne12 == 1) {
+    ggml_cuda_pool_alloc<char> src1_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(src1));
+    ggml_cuda_pool_alloc<char>  dst_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(dst));
+
+    src1_row.data = src1_contiguous.get();
+    dst_row.data  =  dst_contiguous.get();
+
+    for (int64_t i02 = 0; i02 < n_as; i02++) {
+        int64_t num_src1_rows = 0;
+
         for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
             for (int64_t id = 0; id < n_ids; id++) {
-                const int32_t i02 = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
+                const int32_t row_id_i = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
 
-                GGML_ASSERT(i02 >= 0 && i02 < n_as);
+                GGML_ASSERT(row_id_i >= 0 && row_id_i < n_as);
 
-                const int64_t i11 = id % ne11;
-                const int64_t i12 = iid1;
+                if (row_id_i != i02) {
+                    continue;
+                }
 
-                const int64_t i1 = id;
-                const int64_t i2 = i12;
-
-                src0_row.data = src0_original + i02*nb02;
-                src1_row.data = src1_original + i11*nb11 + i12*nb12;
-                dst_row.data  =  dst_original + i1*nb1   + i2*nb2;
-
-                ggml_cuda_mul_mat(ctx, &src0_row, &src1_row, &dst_row);
+                num_src1_rows++;
             }
         }
-    } else {
-        ggml_cuda_pool_alloc<char> src1_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(src1));
-        ggml_cuda_pool_alloc<char>  dst_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(dst));
 
-        src1_row.data = src1_contiguous.get();
-        dst_row.data  =  dst_contiguous.get();
+        if (num_src1_rows == 0) {
+            continue;
+        }
 
-        for (int64_t i02 = 0; i02 < n_as; i02++) {
-            int64_t num_src1_rows = 0;
+        ggml_cuda_pool_alloc<int> dev_cur_src1_row(ctx.pool(), 1);
+        ggml_cuda_pool_alloc<mmid_row_mapping> dev_row_mapping(ctx.pool(), num_src1_rows);
+        CUDA_CHECK(cudaMemsetAsync(dev_cur_src1_row.get(), 0, sizeof(int), stream));
 
-            for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
-                for (int64_t id = 0; id < n_ids; id++) {
-                    const int32_t row_id_i = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
+        {
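+            // gather the src1 rows routed to expert i02 into a contiguous buffer;
+            // one thread block per (token, expert-slot) pair in the ids tensor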
+            dim3 block_dims(std::min((unsigned int)ne10, 768u));
+            dim3 grid_dims(ids->ne[1], n_ids);
+            k_copy_src1_to_contiguous<<<grid_dims, block_dims, 0, stream>>>(
+                    src1_original, src1_contiguous.get(),
+                    dev_cur_src1_row.get(), dev_row_mapping.get(),
+                    ids_dev, i02, ids->nb[1], ids->nb[0],
+                    ne11, ne10,
+                    nb11, nb12);
+            CUDA_CHECK(cudaGetLastError());
+        }
 
-                    GGML_ASSERT(row_id_i >= 0 && row_id_i < n_as);
+        src0_row.data = src0_original + i02*nb02;
 
-                    if (row_id_i != i02) {
-                        continue;
-                    }
+        GGML_ASSERT(nb11 == sizeof(float)*ne10);
+        GGML_ASSERT(nb1 == sizeof(float)*ne0);
 
-                    num_src1_rows++;
-                }
-            }
+        src1_row.ne[1] = num_src1_rows;
+        src1_row.nb[1] = nb11;
+        src1_row.nb[2] = num_src1_rows*nb11;
+        src1_row.nb[3] = num_src1_rows*nb11;
 
-            if (num_src1_rows == 0) {
-                continue;
-            }
+        dst_row.ne[1] = num_src1_rows;
+        dst_row.nb[1] = nb1;
+        dst_row.nb[2] = num_src1_rows*nb1;
+        dst_row.nb[3] = num_src1_rows*nb1;
 
-            ggml_cuda_pool_alloc<int> dev_cur_src1_row(ctx.pool(), 1);
-            ggml_cuda_pool_alloc<mmid_row_mapping> dev_row_mapping(ctx.pool(), num_src1_rows);
-            CUDA_CHECK(cudaMemsetAsync(dev_cur_src1_row.get(), 0, sizeof(int), stream));
+        ggml_cuda_mul_mat(ctx, &src0_row, &src1_row, &dst_row);
 
-            {
-                dim3 block_dims(std::min((unsigned int)ne10, 768u));
-                dim3 grid_dims(ids->ne[1], n_ids);
-                k_copy_src1_to_contiguous<<<grid_dims, block_dims, 0, stream>>>(
-                        src1_original, src1_contiguous.get(),
-                        dev_cur_src1_row.get(), dev_row_mapping.get(),
-                        ids_dev, i02, ids->nb[1], ids->nb[0],
-                        ne11, ne10,
-                        nb11, nb12);
-                CUDA_CHECK(cudaGetLastError());
-            }
-
-            src0_row.data = src0_original + i02*nb02;
-
-            GGML_ASSERT(nb11 == sizeof(float)*ne10);
-            GGML_ASSERT(nb1 == sizeof(float)*ne0);
-
-            src1_row.ne[1] = num_src1_rows;
-            src1_row.nb[1] = nb11;
-            src1_row.nb[2] = num_src1_rows*nb11;
-            src1_row.nb[3] = num_src1_rows*nb11;
-
-            dst_row.ne[1] = num_src1_rows;
-            dst_row.nb[1] = nb1;
-            dst_row.nb[2] = num_src1_rows*nb1;
-            dst_row.nb[3] = num_src1_rows*nb1;
-
-            ggml_cuda_mul_mat(ctx, &src0_row, &src1_row, &dst_row);
-
-            {
-                dim3 block_dims(std::min((unsigned int)ne0, 768u));
-                dim3 grid_dims(num_src1_rows);
-                k_copy_dst_from_contiguous<<<grid_dims, block_dims, 0, stream>>>(
-                        dst_original, dst_contiguous.get(),
-                        dev_row_mapping.get(),
-                        ne0,
-                        nb1, nb2);
-                CUDA_CHECK(cudaGetLastError());
-            }
+        {
+            dim3 block_dims(std::min((unsigned int)ne0, 768u));
+            dim3 grid_dims(num_src1_rows);
+            k_copy_dst_from_contiguous<<<grid_dims, block_dims, 0, stream>>>(
+                    dst_original, dst_contiguous.get(),
+                    dev_row_mapping.get(),
+                    ne0,
+                    nb1, nb2);
+            CUDA_CHECK(cudaGetLastError());
         }
     }
 }
@@ -2489,7 +2488,7 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud
 #endif
         }
 
-        if (node->op == GGML_OP_MUL_MAT_ID) {
+        if (node->op == GGML_OP_MUL_MAT_ID && node->ne[2] != 1) {
             use_cuda_graph = false; // This node type is not supported by CUDA graph capture
 #ifndef NDEBUG
             GGML_LOG_DEBUG("%s: disabling CUDA graphs due to unsupported node type\n", __func__);
@@ -3203,9 +3202,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         }
         case GGML_OP_ROPE:
         case GGML_OP_ROPE_BACK: {
-            const size_t ts = ggml_type_size(op->src[0]->type);
-            const int64_t ne0_012 = op->src[0]->ne[0] * op->src[0]->ne[1] * op->src[0]->ne[2];
-            return op->src[0]->nb[0] == ts && op->src[0]->nb[3] == ne0_012*ts;
+            return op->src[0]->nb[0] == ggml_type_size(op->src[0]->type) && ggml_is_contiguous_2(op->src[0]);
         }
         case GGML_OP_IM2COL:
         case GGML_OP_POOL_2D:
diff --git a/ggml/src/ggml-cuda/mmv.cu b/ggml/src/ggml-cuda/mmv.cu
index b39961cd11..d8c385e239 100644
--- a/ggml/src/ggml-cuda/mmv.cu
+++ b/ggml/src/ggml-cuda/mmv.cu
@@ -4,18 +4,23 @@
 
 template <typename T, typename type_acc, int block_size>
 static __global__ void mul_mat_vec(
-        const T * __restrict__ x, const float * __restrict__ y, float * __restrict__ dst, const int64_t ncols2, const int64_t stride_row,
+        const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst,
+        const int64_t ncols2, const int64_t nchannels_y, const int64_t stride_row,
         const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst,
         const int64_t sample_ratio, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst) {
-    const int64_t row       = blockIdx.x;
-    const int64_t channel   = blockIdx.y;
-    const int64_t sample    = blockIdx.z;
-    const int     tid       = threadIdx.x;
-    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
+    const int64_t row         = blockIdx.x;
+    const int64_t channel_dst = blockIdx.y;
+    const int64_t channel_x   = ids ? ids[channel_dst]          : channel_dst / channel_ratio;
+    const int64_t channel_y   = ids ? channel_dst % nchannels_y : channel_dst;
+    const int64_t sample_dst  = blockIdx.z;
+    const int64_t sample_x    = sample_dst / sample_ratio;
+    const int64_t sample_y    = sample_dst;
+    const int     tid         = threadIdx.x;
+    constexpr int warp_size   = ggml_cuda_get_physical_warp_size();
 
-    x   +=  (sample/sample_ratio)*stride_sample_x   + (channel/channel_ratio)*stride_channel_x + row*stride_row;
-    y   +=   sample              *stride_sample_y   +  channel               *stride_channel_y;
-    dst +=   sample              *stride_sample_dst +  channel               *stride_channel_dst;
+    x   += sample_x  *stride_sample_x   + channel_x  *stride_channel_x   + row*stride_row;
+    y   += sample_y  *stride_sample_y   + channel_y  *stride_channel_y;
+    dst += sample_dst*stride_sample_dst + channel_dst*stride_channel_dst;
 
     const float2 * y2 = (const float2 *) y;
 
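
The float2 cast above is what the ncols2 = ncols/2 convention relies on: each loop iteration consumes two columns per thread, which is why launch_mul_mat_vec_cuda below asserts that ncols and stride_row are even. A self-contained sketch of the access pattern, assuming a single-warp block so one shuffle reduction completes the sum:

    #include <cuda_runtime.h>

    // Sketch of the paired-load dot product; launch with blockDim.x == 32.
    static __global__ void dot_pairs_sketch(const float * x, const float * y,
                                            float * out, const int64_t ncols2) {
        const float2 * x2 = (const float2 *) x;
        const float2 * y2 = (const float2 *) y;
        float sumf = 0.0f;
        for (int64_t col2 = threadIdx.x; col2 < ncols2; col2 += blockDim.x) {
            const float2 a = x2[col2];
            const float2 b = y2[col2];
            sumf += a.x*b.x + a.y*b.y;   // two columns per iteration
        }
        for (int offset = 16; offset > 0; offset >>= 1) { // warp-level reduction
            sumf += __shfl_down_sync(0xffffffff, sumf, offset);
        }
        if (threadIdx.x == 0) {
            *out = sumf;
        }
    }
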
@@ -31,12 +36,19 @@ static __global__ void mul_mat_vec(
 
     float sumf = 0.0f;
 
-    if constexpr (std::is_same<T, half>::value) {
+    if constexpr (std::is_same<T, float>::value) {
+        const float2 * x2 = (const float2 *) x;
+
+        for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) {
+            const float2 tmpx = x2[col2];
+            const float2 tmpy = y2[col2];
+            sumf += tmpx.x*tmpy.x;
+            sumf += tmpx.y*tmpy.y;
+        }
+    } else if constexpr (std::is_same<T, half>::value) {
         const half2 * x2 = (const half2 *) x;
 
         if (std::is_same<type_acc, float>::value) {
-            sumf = 0.0f;
-
             for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) {
                 const float2 tmpx = __half22float2(x2[col2]);
                 const float2 tmpy = y2[col2];
@@ -59,8 +71,6 @@ static __global__ void mul_mat_vec(
         }
     } else if constexpr (std::is_same<T, nv_bfloat16>::value) {
         const int * x2 = (const int *) x;
-        sumf = 0.0f;
-
         for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) {
             const int    tmpx = x2[col2];
             const float2 tmpy = y2[col2];
@@ -92,17 +102,17 @@ static __global__ void mul_mat_vec(
 
 template <typename T, typename type_acc>
 static void launch_mul_mat_vec_cuda(
-        const T * x, const float * y, float * dst,
-        const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y,
+        const T * x, const float * y, const int32_t * ids, float * dst,
+        const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
         const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
-        const int64_t nsamples_y, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
+        const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
         cudaStream_t stream) {
     GGML_ASSERT(ncols      % 2 == 0);
     GGML_ASSERT(stride_row % 2 == 0);
-    GGML_ASSERT(nchannels_y % nchannels_x == 0);
-    GGML_ASSERT(nsamples_y  % nsamples_x  == 0);
-    const int64_t channel_ratio = nchannels_y / nchannels_x;
-    const int64_t sample_ratio  = nsamples_y  / nsamples_x;
+    GGML_ASSERT(ids || nchannels_dst % nchannels_x == 0);
+    GGML_ASSERT(       nsamples_dst  % nsamples_x  == 0);
+    const int64_t channel_ratio = nchannels_dst / nchannels_x;
+    const int64_t sample_ratio  = nsamples_dst  / nsamples_x;
     int device;
     int warp_size;
 
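
The two ratios encode broadcast semantics: each x channel/sample serves several dst channels/samples, and the ids path sidesteps the division entirely. A worked example with assumed GQA-like sizes:

    #include <cstdint>

    // Assumed sizes: 8 x-channels broadcast over 32 dst-channels.
    constexpr int64_t nchannels_x   = 8;
    constexpr int64_t nchannels_dst = 32;
    constexpr int64_t channel_ratio = nchannels_dst / nchannels_x; // 4
    // Without ids, dst channel 13 reads x channel 13/4 == 3 (see the kernel above);
    // with ids, ids[13] selects the x channel directly.
    static_assert(13 / channel_ratio == 3, "4 dst channels map onto each x channel");
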
@@ -124,48 +134,48 @@ static void launch_mul_mat_vec_cuda(
     }
 
     const int smem = warp_size*sizeof(float);
-    const dim3 block_nums(nrows, nchannels_y, nsamples_y);
+    const dim3 block_nums(nrows, nchannels_dst, nsamples_dst);
     const dim3 block_dims(block_size_best, 1, 1);
     switch (block_size_best) {
         case   32: {
             mul_mat_vec<T, type_acc, 32><<<block_nums, block_dims, smem, stream>>>
-                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+                (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
+                 stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
         } break;
         case   64: {
             mul_mat_vec<T, type_acc, 64><<<block_nums, block_dims, smem, stream>>>
-                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+                (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
+                 stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
         } break;
         case   96: {
             mul_mat_vec<T, type_acc, 96><<<block_nums, block_dims, smem, stream>>>
-                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+                (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
+                 stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
         } break;
         case  128: {
             mul_mat_vec<T, type_acc, 128><<<block_nums, block_dims, smem, stream>>>
-                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+                (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
+                 stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
         } break;
         case  160: {
             mul_mat_vec<T, type_acc, 160><<<block_nums, block_dims, smem, stream>>>
-                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+                (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
+                 stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
         } break;
         case  192: {
             mul_mat_vec<T, type_acc, 192><<<block_nums, block_dims, smem, stream>>>
-                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+                (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
+                 stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
         } break;
         case  224: {
             mul_mat_vec<T, type_acc, 224><<<block_nums, block_dims, smem, stream>>>
-                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+                (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
+                 stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
         } break;
         case  256: {
             mul_mat_vec<T, type_acc, 256><<<block_nums, block_dims, smem, stream>>>
-                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+                (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
+                 stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
         } break;
         default: {
             GGML_ABORT("fatal error");
@@ -175,28 +185,28 @@ static void launch_mul_mat_vec_cuda(
 
 template<typename T>
 static void mul_mat_vec_cuda(
-        const T * x, const float * y, float * dst,
-        const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y,
+        const T * x, const float * y, const int32_t * ids, float * dst,
+        const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
         const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
-        const int64_t nsamples_y, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
+        const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
         enum ggml_prec prec, cudaStream_t stream) {
-    switch (prec) {
-        case GGML_PREC_DEFAULT: {
+    if constexpr(std::is_same<T, half>::value) {
+        if (prec == GGML_PREC_DEFAULT) {
             launch_mul_mat_vec_cuda<T, half>
-                (x, y, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 nsamples_x, nsamples_y, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
-        } break;
-        case GGML_PREC_F32: {
-            launch_mul_mat_vec_cuda<T, float>
-                (x, y, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 nsamples_x, nsamples_y, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
-        } break;
+                (x, y, ids, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
+                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+            return;
+        }
     }
+    launch_mul_mat_vec_cuda<T, float>
+        (x, y, ids, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
+         stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
 }
 
-void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    GGML_ASSERT(src1->type == GGML_TYPE_F32);
-    GGML_ASSERT(dst->type  == GGML_TYPE_F32);
+void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) {
+    GGML_ASSERT(        src1->type == GGML_TYPE_F32);
+    GGML_ASSERT(!ids ||  ids->type == GGML_TYPE_I32);
+    GGML_ASSERT(         dst->type == GGML_TYPE_F32);
 
     GGML_TENSOR_BINARY_OP_LOCALS;
 
@@ -204,21 +214,24 @@ void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor *
     const size_t ts_src1 = ggml_type_size(src1->type);
     const size_t ts_dst  = ggml_type_size(dst->type);
 
-    GGML_ASSERT(ne11 == 1);
-    GGML_ASSERT(ne12 == ne2);
+    GGML_ASSERT(!ids || ne12 == 1); // Implementation is only correct for batch size 1.
     GGML_ASSERT(ne13 == ne3);
 
-    GGML_ASSERT(nb00 == ts_src0);
-    GGML_ASSERT(nb10 == ts_src1);
-    GGML_ASSERT(nb0  == ts_dst);
+    GGML_ASSERT(        nb00       == ts_src0);
+    GGML_ASSERT(        nb10       == ts_src1);
+    GGML_ASSERT(!ids || ids->nb[0] == ggml_type_size(ids->type));
+    GGML_ASSERT(        nb0        == ts_dst);
 
     const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
     const enum ggml_prec prec = fast_fp16_available(cc) ? ggml_prec(dst->op_params[0]) : GGML_PREC_F32;
 
-    const float * src1_d = (const float *) src1->data;
-    float       *  dst_d = (float       *)  dst->data;
+    const float   * src1_d =       (const float   *) src1->data;
+    const int32_t *  ids_d = ids ? (const int32_t *)  ids->data : nullptr;
+    float         *  dst_d =       (float         *)  dst->data;
 
     const int64_t s01 = src0->nb[1] / ts_src0;
+    const int64_t s11 = src1->nb[1] / ts_src1;
+    const int64_t s1  =  dst->nb[1] / ts_dst;
     const int64_t s02 = src0->nb[2] / ts_src0;
     const int64_t s12 = src1->nb[2] / ts_src1;
     const int64_t s2  =  dst->nb[2] / ts_dst;
@@ -226,14 +239,33 @@ void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor *
     const int64_t s13 = src1->nb[3] / ts_src1;
     const int64_t s3  =  dst->nb[3] / ts_dst;
 
+    // For MUL_MAT_ID the memory layout is different from that of MUL_MAT:
+    const int64_t ncols_dst          = ids ? ne2  : ne1;
+    const int64_t nchannels_y        = ids ? ne11 : ne12;
+    const int64_t nchannels_dst      = ids ? ne1  : ne2;
+    const int64_t stride_channel_dst = ids ? s1   : s2;
+    const int64_t stride_channel_y   = ids ? s11  : s12;
+
+    GGML_ASSERT(ncols_dst == 1);
+
     switch (src0->type) {
+        case GGML_TYPE_F32: {
+            const float * src0_d = (const float *) src0->data;
+            mul_mat_vec_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, s01,
+                ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
+                ne03,              ne3,           s03, s13,              s3,                 prec, ctx.stream());
+        } break;
         case GGML_TYPE_F16: {
             const half * src0_d = (const half *) src0->data;
-            mul_mat_vec_cuda(src0_d, src1_d, dst_d, ne00, ne01, s01, ne02, ne12, s02, s12, s2, ne03, ne13, s03, s13, s3, prec, ctx.stream());
+            mul_mat_vec_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, s01,
+                ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
+                ne03,              ne3,           s03, s13,              s3,                 prec, ctx.stream());
         } break;
         case GGML_TYPE_BF16: {
             const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0->data;
-            mul_mat_vec_cuda(src0_d, src1_d, dst_d, ne00, ne01, s01, ne02, ne12, s02, s12, s2, ne03, ne13, s03, s13, s3, prec, ctx.stream());
+            mul_mat_vec_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, s01,
+                ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
+                ne03,              ne3,           s03, s13,              s3,                 prec, ctx.stream());
         } break;
         default:
             GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type));
@@ -262,27 +294,34 @@ void ggml_cuda_op_mul_mat_vec(
     const int64_t stride_row         = ne00;
     const int64_t nchannels_x        = 1;
     const int64_t nchannels_y        = 1;
+    const int64_t nchannels_dst      = 1;
     const int64_t stride_channel_x   = 0;
     const int64_t stride_channel_y   = 0;
     const int64_t stride_channel_dst = 0;
     const int64_t nsamples_x         = 1;
-    const int64_t nsamples_y         = 1;
+    const int64_t nsamples_dst       = 1;
     const int64_t stride_sample_x    = 0;
     const int64_t stride_sample_y    = 0;
     const int64_t stride_sample_dst  = 0;
 
     switch (src0->type) {
+        case GGML_TYPE_F32: {
+            const float * src0_d = (const float *) src0_dd_i;
+            mul_mat_vec_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, stride_row,
+                nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
+        } break;
         case GGML_TYPE_F16: {
             const half * src0_d = (const half *) src0_dd_i;
-            mul_mat_vec_cuda(src0_d, src1_ddf_i, dst_dd_i, ne00, row_diff, stride_row,
-                nchannels_x, nchannels_y, stride_channel_x, stride_channel_y, stride_channel_dst,
-                nsamples_x, nsamples_y, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
+            mul_mat_vec_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, stride_row,
+                nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
         } break;
         case GGML_TYPE_BF16: {
             const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0_dd_i;
-            mul_mat_vec_cuda(src0_d, src1_ddf_i, dst_dd_i, ne00, row_diff, stride_row,
-                nchannels_x, nchannels_y, stride_channel_x, stride_channel_y, stride_channel_dst,
-                nsamples_x, nsamples_y, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
+            mul_mat_vec_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, stride_row,
+                nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
         } break;
         default:
             GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type));
diff --git a/ggml/src/ggml-cuda/mmv.cuh b/ggml/src/ggml-cuda/mmv.cuh
index 78a1cd4a69..756e7e1cc7 100644
--- a/ggml/src/ggml-cuda/mmv.cuh
+++ b/ggml/src/ggml-cuda/mmv.cuh
@@ -3,7 +3,7 @@
 // maximum number of src0 rows with which to use mul_mat_vec over cuBLAS if FP16 tensor cores are available
 #define MMV_MAX_ROWS 512
 
-void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
+void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst);
 
 void ggml_cuda_op_mul_mat_vec(
     ggml_backend_cuda_context & ctx,
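
A hypothetical call-site sketch of the widened entry point, showing the two shapes it now serves:

    // ids == nullptr: plain mat-vec; ids != nullptr: MUL_MAT_ID expert routing
    // (ids being an I32 tensor selecting the x channel per dst channel).
    static void mul_mat_vec_dispatch_sketch(ggml_backend_cuda_context & ctx,
            const ggml_tensor * src0, const ggml_tensor * src1,
            const ggml_tensor * ids, ggml_tensor * dst) {
        ggml_cuda_mul_mat_vec(ctx, src0, src1, ids, dst);
    }
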
diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu
index eef8585a73..cac04916cd 100644
--- a/ggml/src/ggml-cuda/mmvq.cu
+++ b/ggml/src/ggml-cuda/mmvq.cu
@@ -1,6 +1,9 @@
 #include "mmvq.cuh"
+#include "quantize.cuh"
 #include "vecdotq.cuh"
 
+#include <cstdint>
+
 typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs);
 
 static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type) {
@@ -73,9 +76,9 @@ static __host__ mmvq_parameter_table_id get_device_table_id(int cc) {
     return MMVQ_PARAMETERS_GENERIC;
 }
 
-static constexpr __host__ __device__ int calc_nwarps(int ncols_y,  mmvq_parameter_table_id table_id) {
+static constexpr __host__ __device__ int calc_nwarps(int ncols_dst,  mmvq_parameter_table_id table_id) {
     if (table_id == MMVQ_PARAMETERS_GENERIC) {
-        switch (ncols_y) {
+        switch (ncols_dst) {
             case 1:
             case 2:
             case 3:
@@ -90,7 +93,7 @@ static constexpr __host__ __device__ int calc_nwarps(int ncols_y,  mmvq_paramete
                 return 1;
         }
     } else if (table_id == MMVQ_PARAMETERS_GCN) {
-        switch (ncols_y) {
+        switch (ncols_dst) {
             case 1:
             case 2:
             case 3:
@@ -107,9 +110,9 @@ static constexpr __host__ __device__ int calc_nwarps(int ncols_y,  mmvq_paramete
     return 1;
 }
 
-static constexpr __host__ __device__ int calc_rows_per_block(int ncols_y, int table_id) {
+static constexpr __host__ __device__ int calc_rows_per_block(int ncols_dst, int table_id) {
     if (table_id == MMVQ_PARAMETERS_GENERIC || table_id == MMVQ_PARAMETERS_GCN) {
-        switch (ncols_y) {
+        switch (ncols_dst) {
             case 1:
                 return 1;
             case 2:
@@ -127,19 +130,21 @@ static constexpr __host__ __device__ int calc_rows_per_block(int ncols_y, int ta
     return 1;
 }
 
-template <ggml_type type, int ncols_y>
+template <ggml_type type, int ncols_dst>
 // tell the compiler to use as many registers as it wants, see nwarps definition below
-__launch_bounds__(calc_nwarps(ncols_y, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1)
+__launch_bounds__(calc_nwarps(ncols_dst, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1)
 static __global__ void mul_mat_vec_q(
-    const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
-    const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
+        const void * __restrict__ vx, const void * __restrict__ vy, const int32_t * __restrict__ ids, float * __restrict__ dst,
+        const int ncols_x, const int nchannels_y, const int stride_row_x, const int stride_col_y, const int stride_col_dst,
+        const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
+        const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
 
     constexpr int qk  = ggml_cuda_type_traits<type>::qk;
     constexpr int qi  = ggml_cuda_type_traits<type>::qi;
     constexpr int vdr = get_vdr_mmvq(type);
     constexpr mmvq_parameter_table_id table_id = get_device_table_id();
-    constexpr int nwarps = calc_nwarps(ncols_y, table_id);
-    constexpr int rows_per_cuda_block = calc_rows_per_block(ncols_y, table_id);
+    constexpr int nwarps = calc_nwarps(ncols_dst, table_id);
+    constexpr int rows_per_cuda_block = calc_rows_per_block(ncols_dst, table_id);
     constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
     constexpr vec_dot_q_cuda_t vec_dot_q_cuda = get_vec_dot_q_cuda(type);
@@ -147,13 +152,21 @@ static __global__ void mul_mat_vec_q(
     const     int tid = warp_size*threadIdx.y + threadIdx.x;
     const     int row0 = rows_per_cuda_block*blockIdx.x;
     const     int blocks_per_row_x = ncols_x / qk;
-    const     int blocks_per_col_y = nrows_y / QK8_1;
     constexpr int blocks_per_iter = vdr * nwarps*warp_size / qi;
 
-    // partial sum for each thread
-    float tmp[ncols_y][rows_per_cuda_block] = {{0.0f}};
+    // The MUL_MAT_ID code path with ids != nullptr is only implemented for ncols_dst == 1.
+    const int channel_dst = blockIdx.y;
+    const int channel_x   = ncols_dst == 1 && ids ? ids[channel_dst]          : channel_dst / channel_ratio;
+    const int channel_y   = ncols_dst == 1 && ids ? channel_dst % nchannels_y : channel_dst;
+    const int sample_dst  = blockIdx.z;
+    const int sample_x    = sample_dst / sample_ratio;
+    const int sample_y    = sample_dst;
 
-    const block_q8_1 * y = (const block_q8_1 *) vy;
+    // partial sum for each thread
+    float tmp[ncols_dst][rows_per_cuda_block] = {{0.0f}};
+
+    const block_q8_1 * y = ((const block_q8_1 *) vy) + sample_y*stride_sample_y + channel_y*stride_channel_y;
+    const int kbx_offset = sample_x*stride_sample_x + channel_x*stride_channel_x + row0*stride_row_x;
 
     for (int kbx = tid / (qi/vdr); kbx < blocks_per_row_x; kbx += blocks_per_iter) {
         const int kby = kbx * (qk/QK8_1); // y block index that aligns with kbx
@@ -162,18 +175,19 @@ static __global__ void mul_mat_vec_q(
         const int kqs = vdr * (tid % (qi/vdr));
 
 #pragma unroll
-        for (int j = 0; j < ncols_y; ++j) {
+        for (int j = 0; j < ncols_dst; ++j) {
 #pragma unroll
             for (int i = 0; i < rows_per_cuda_block; ++i) {
-                tmp[j][i] += vec_dot_q_cuda(vx, &y[j*blocks_per_col_y + kby], (row0 + i)*blocks_per_row_x + kbx, kqs);
+                tmp[j][i] += vec_dot_q_cuda(
+                    vx, &y[j*stride_col_y + kby], kbx_offset + i*stride_row_x + kbx, kqs);
             }
         }
     }
 
-    __shared__ float tmp_shared[nwarps-1 > 0 ? nwarps-1 : 1][ncols_y][rows_per_cuda_block][warp_size];
+    __shared__ float tmp_shared[nwarps-1 > 0 ? nwarps-1 : 1][ncols_dst][rows_per_cuda_block][warp_size];
     if (threadIdx.y > 0) {
 #pragma unroll
-        for (int j = 0; j < ncols_y; ++j) {
+        for (int j = 0; j < ncols_dst; ++j) {
 #pragma unroll
             for (int i = 0; i < rows_per_cuda_block; ++i) {
                 tmp_shared[threadIdx.y-1][j][i][threadIdx.x] = tmp[j][i];
@@ -185,9 +199,11 @@ static __global__ void mul_mat_vec_q(
         return;
     }
 
+    dst += sample_dst*stride_sample_dst + channel_dst*stride_channel_dst + row0;
+
     // sum up partial sums and write back result
 #pragma unroll
-    for (int j = 0; j < ncols_y; ++j) {
+    for (int j = 0; j < ncols_dst; ++j) {
 #pragma unroll
         for (int i = 0; i < rows_per_cuda_block; ++i) {
 #pragma unroll
@@ -197,88 +213,121 @@ static __global__ void mul_mat_vec_q(
             tmp[j][i] = warp_reduce_sum(tmp[j][i]);
         }
 
-        if (threadIdx.x < rows_per_cuda_block && (rows_per_cuda_block == 1 || row0 + threadIdx.x < (unsigned)nrows_dst)) {
-            dst[j*nrows_dst + row0 + threadIdx.x] = tmp[j][threadIdx.x];
+        if (threadIdx.x < rows_per_cuda_block && (rows_per_cuda_block == 1 || row0 + int(threadIdx.x) < stride_col_dst)) {
+            dst[j*stride_col_dst + threadIdx.x] = tmp[j][threadIdx.x];
         }
     }
-
-    GGML_UNUSED(nrows_x);
 }
 
-static std::pair<dim3, dim3> calc_launch_params(const int ncols_y, const int nrows_x, const int warp_size, const mmvq_parameter_table_id table_id) {
-    const int64_t nblocks = (nrows_x + calc_rows_per_block(ncols_y, table_id) - 1) / calc_rows_per_block(ncols_y, table_id);
-    const dim3 block_nums(nblocks, 1, 1);
-    const dim3 block_dims(warp_size, calc_nwarps(ncols_y, table_id), 1);
+static std::pair<dim3, dim3> calc_launch_params(
+        const int ncols_dst, const int nrows_x, const int nchannels_y, const int nsamples_y,
+        const int warp_size, const mmvq_parameter_table_id table_id) {
+    const int64_t nblocks = (nrows_x + calc_rows_per_block(ncols_dst, table_id) - 1) / calc_rows_per_block(ncols_dst, table_id);
+    const dim3 block_nums(nblocks, nchannels_y, nsamples_y);
+    const dim3 block_dims(warp_size, calc_nwarps(ncols_dst, table_id), 1);
     return {block_nums, block_dims};
 }
 
 template <ggml_type type>
-static void mul_mat_vec_q_cuda(
-    const void * vx, const void * vy, float * dst,
-    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
+static void mul_mat_vec_q_switch_ncols_dst(
+        const void * vx, const void * vy, const int32_t * ids, float * dst,
+        const int ncols_x, const int nrows_x, const int ncols_dst,
+        const int stride_row_x, const int stride_col_y, const int stride_col_dst,
+        const int nchannels_x, const int nchannels_y, const int nchannels_dst,
+        const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
+        const int nsamples_x, const int nsamples_dst, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
+        cudaStream_t stream) {
 
     GGML_ASSERT(ncols_x % ggml_blck_size(type) == 0);
-    GGML_ASSERT(ncols_y <= MMVQ_MAX_BATCH_SIZE);
+    GGML_ASSERT(ncols_dst <= MMVQ_MAX_BATCH_SIZE);
+
+    const int channel_ratio = nchannels_dst / nchannels_x;
+    const int sample_ratio  = nsamples_dst  / nsamples_x;
 
     const int device = ggml_cuda_get_device();
     const int warp_size = ggml_cuda_info().devices[device].warp_size;
     const mmvq_parameter_table_id table_id = get_device_table_id(ggml_cuda_info().devices[device].cc);
 
-    switch (ncols_y) {
+    GGML_ASSERT(!ids || ncols_dst == 1);
+    switch (ncols_dst) {
         case 1:
         {
-            constexpr int c_ncols_y = 1;
-            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
-            mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+            constexpr int c_ncols_dst = 1;
+            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
+            mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
+                (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
+                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
             break;
         }
         case 2:
         {
-            constexpr int c_ncols_y = 2;
-            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
-            mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+            constexpr int c_ncols_dst = 2;
+            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
+            mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
+                (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
+                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
             break;
         }
         case 3:
         {
-            constexpr int c_ncols_y = 3;
-            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
-            mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+            constexpr int c_ncols_dst = 3;
+            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
+            mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
+                (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
+                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
             break;
         }
         case 4:
         {
-            constexpr int c_ncols_y = 4;
-            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
-            mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+            constexpr int c_ncols_dst = 4;
+            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
+            mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
+                (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
+                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
             break;
         }
         case 5:
         {
-            constexpr int c_ncols_y = 5;
-            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
-            mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+            constexpr int c_ncols_dst = 5;
+            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
+            mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
+                (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
+                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
             break;
         }
         case 6:
         {
-            constexpr int c_ncols_y = 6;
-            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
-            mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+            constexpr int c_ncols_dst = 6;
+            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
+            mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
+                (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
+                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
             break;
         }
         case 7:
         {
-            constexpr int c_ncols_y = 7;
-            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
-            mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+            constexpr int c_ncols_dst = 7;
+            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
+            mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
+                (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
+                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
             break;
         }
         case 8:
         {
-            constexpr int c_ncols_y = 8;
-            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
-            mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+            constexpr int c_ncols_dst = 8;
+            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
+            mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
+                (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
+                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
             break;
         }
         default:
@@ -287,137 +336,213 @@ static void mul_mat_vec_q_cuda(
     }
 }
 
-static void mul_mat_vec_q4_0_q8_1_cuda(
-    const void * vx, const void * vy, float * dst,
-    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
-
-    mul_mat_vec_q_cuda<GGML_TYPE_Q4_0>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+static void mul_mat_vec_q_switch_type(
+        const void * vx, const ggml_type type_x, const void * vy, const int32_t * ids, float * dst,
+        const int ncols_x, const int nrows_x, const int ncols_dst,
+        const int stride_row_x, const int stride_col_y, const int stride_col_dst,
+        const int nchannels_x, const int nchannels_y, const int nchannels_dst,
+        const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
+        const int nsamples_x, const int nsamples_dst, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
+        cudaStream_t stream) {
+    switch (type_x) {
+        case GGML_TYPE_Q4_0:
+            mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q4_0>
+                (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
+                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
+                 stream);
+            break;
+        case GGML_TYPE_Q4_1:
+            mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q4_1>
+                (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
+                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
+                 stream);
+            break;
+        case GGML_TYPE_Q5_0:
+            mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q5_0>
+                (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
+                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
+                 stream);
+            break;
+        case GGML_TYPE_Q5_1:
+            mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q5_1>
+                (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
+                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
+                 stream);
+            break;
+        case GGML_TYPE_Q8_0:
+            mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q8_0>
+                (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
+                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
+                 stream);
+            break;
+        case GGML_TYPE_Q2_K:
+            mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q2_K>
+                (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
+                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
+                 stream);
+            break;
+        case GGML_TYPE_Q3_K:
+            mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q3_K>
+                (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
+                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
+                 stream);
+            break;
+        case GGML_TYPE_Q4_K:
+            mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q4_K>
+                (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
+                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
+                 stream);
+            break;
+        case GGML_TYPE_Q5_K:
+            mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q5_K>
+                (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
+                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
+                 stream);
+            break;
+        case GGML_TYPE_Q6_K:
+            mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q6_K>
+                (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
+                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
+                 stream);
+            break;
+        case GGML_TYPE_IQ2_XXS:
+            mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ2_XXS>
+                (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
+                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
+                 stream);
+            break;
+        case GGML_TYPE_IQ2_XS:
+            mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ2_XS>
+                (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
+                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
+                 stream);
+            break;
+        case GGML_TYPE_IQ2_S:
+            mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ2_S>
+                (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
+                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
+                 stream);
+            break;
+        case GGML_TYPE_IQ3_XXS:
+            mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ3_XXS>
+                (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
+                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
+                 stream);
+            break;
+        case GGML_TYPE_IQ1_S:
+            mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ1_S>
+                (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
+                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
+                 stream);
+            break;
+        case GGML_TYPE_IQ1_M:
+            mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ1_M>
+                (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
+                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
+                 stream);
+            break;
+        case GGML_TYPE_IQ4_NL:
+            mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ4_NL>
+                (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
+                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
+                 stream);
+            break;
+        case GGML_TYPE_IQ4_XS:
+            mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ4_XS>
+                (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
+                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
+                 stream);
+            break;
+        case GGML_TYPE_IQ3_S:
+            mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ3_S>
+                (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
+                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
+                 stream);
+            break;
+        default:
+            GGML_ABORT("fatal error");
+            break;
+    }
 }
 
-static void mul_mat_vec_q4_1_q8_1_cuda(
-    const void * vx, const void * vy, float * dst,
-    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
+void ggml_cuda_mul_mat_vec_q(
+        ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) {
+    GGML_ASSERT(        src1->type == GGML_TYPE_F32);
+    GGML_ASSERT(        dst->type  == GGML_TYPE_F32);
+    GGML_ASSERT(!ids || ids->type  == GGML_TYPE_I32); // Optional, used for batched GGML_MUL_MAT_ID.
 
-    mul_mat_vec_q_cuda<GGML_TYPE_Q4_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
-}
+    GGML_TENSOR_BINARY_OP_LOCALS;
 
-static void mul_mat_vec_q5_0_q8_1_cuda(
-    const void * vx, const void * vy, float * dst,
-    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
+    cudaStream_t stream = ctx.stream();
 
-    mul_mat_vec_q_cuda<GGML_TYPE_Q5_0>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
-}
+    const size_t ts_src0 = ggml_type_size(src0->type);
+    const size_t ts_src1 = ggml_type_size(src1->type);
+    const size_t ts_dst  = ggml_type_size(dst->type);
 
-static void mul_mat_vec_q5_1_q8_1_cuda(
-    const void * vx, const void * vy, float * dst,
-    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
+    GGML_ASSERT(        nb00       == ts_src0);
+    GGML_ASSERT(        nb10       == ts_src1);
+    GGML_ASSERT(        nb0        == ts_dst);
+    GGML_ASSERT(!ids || ids->nb[0] == ggml_type_size(ids->type));
 
-    mul_mat_vec_q_cuda<GGML_TYPE_Q5_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
-}
+    GGML_ASSERT(!ids || ne12 == 1); // Implementation is only correct for batch size 1.
 
-static void mul_mat_vec_q8_0_q8_1_cuda(
-    const void * vx, const void * vy, float * dst,
-    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
+    const float   * src1_d =       (const float   *) src1->data;
+    const int32_t *  ids_d = ids ? (const int32_t *)  ids->data : nullptr;
+    float         *  dst_d =       (float         *)  dst->data;
 
-    mul_mat_vec_q_cuda<GGML_TYPE_Q8_0>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
-}
+    const int64_t ne10_padded = GGML_PAD(ne10, MATRIX_ROW_PADDING);
+    ggml_cuda_pool_alloc<char> src1_q8_1(ctx.pool(), ne13*ne12 * ne11*ne10_padded * sizeof(block_q8_1)/QK8_1);
+    {
+        const int64_t s11 = src1->nb[1] / ts_src1;
+        const int64_t s12 = src1->nb[2] / ts_src1;
+        const int64_t s13 = src1->nb[3] / ts_src1;
+        quantize_row_q8_1_cuda(src1_d, src1_q8_1.get(), src0->type, ne10, s11, s12, s13, ne10_padded, ne11, ne12, ne13, stream);
+    }
 
-static void mul_mat_vec_q2_K_q8_1_cuda(
-    const void * vx, const void * vy, float * dst,
-    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
+    const int64_t s01 = src0->nb[1] / ts_src0;
+    const int64_t s11 = ne10_padded / QK8_1;
+    const int64_t s1  =  dst->nb[1] / ts_dst;
+    const int64_t s02 = src0->nb[2] / ts_src0;
+    const int64_t s2  =  dst->nb[2] / ts_dst;
+    const int64_t s03 = src0->nb[3] / ts_src0;
+    const int64_t s3  =  dst->nb[3] / ts_dst;
 
-    mul_mat_vec_q_cuda<GGML_TYPE_Q2_K>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
-}
+    const int64_t s12 = ne11*s11;
+    const int64_t s13 = ne12*s12;
 
-static void mul_mat_vec_q3_K_q8_1_cuda(
-    const void * vx, const void * vy, float * dst,
-    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
+    // For MUL_MAT_ID the memory layout is different from that of MUL_MAT:
+    const int64_t ncols_dst          = ids ? ne2  : ne1;
+    const int64_t nchannels_y        = ids ? ne11 : ne12;
+    const int64_t nchannels_dst      = ids ? ne1  : ne2;
+    const int64_t stride_col_dst     = ids ? s2   : s1;
+    const int64_t stride_col_y       = ids ? s12  : s11;
+    const int64_t stride_channel_dst = ids ? s1   : s2;
+    const int64_t stride_channel_y   = ids ? s11  : s12;
 
-    mul_mat_vec_q_cuda<GGML_TYPE_Q3_K>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
-}
-
-static void mul_mat_vec_q4_K_q8_1_cuda(
-    const void * vx, const void * vy, float * dst,
-    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
-
-    mul_mat_vec_q_cuda<GGML_TYPE_Q4_K>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
-}
-
-static void mul_mat_vec_q5_K_q8_1_cuda(
-    const void * vx, const void * vy, float * dst,
-    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
-
-    mul_mat_vec_q_cuda<GGML_TYPE_Q5_K>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
-}
-
-static void mul_mat_vec_q6_K_q8_1_cuda(
-    const void * vx, const void * vy, float * dst,
-    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
-
-    mul_mat_vec_q_cuda<GGML_TYPE_Q6_K>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
-}
-
-static void mul_mat_vec_iq2_xxs_q8_1_cuda(
-    const void * vx, const void * vy, float * dst,
-    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
-
-    mul_mat_vec_q_cuda<GGML_TYPE_IQ2_XXS>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
-}
-
-static void mul_mat_vec_iq2_xs_q8_1_cuda(
-    const void * vx, const void * vy, float * dst,
-    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
-
-    mul_mat_vec_q_cuda<GGML_TYPE_IQ2_XS>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
-}
-
-static void mul_mat_vec_iq2_s_q8_1_cuda(
-    const void * vx, const void * vy, float * dst,
-    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
-
-    mul_mat_vec_q_cuda<GGML_TYPE_IQ2_S>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
-}
-
-static void mul_mat_vec_iq3_xxs_q8_1_cuda(
-    const void * vx, const void * vy, float * dst,
-    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
-
-    mul_mat_vec_q_cuda<GGML_TYPE_IQ3_XXS>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
-}
-
-static void mul_mat_vec_iq1_s_q8_1_cuda(
-    const void * vx, const void * vy, float * dst,
-    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
-
-    mul_mat_vec_q_cuda<GGML_TYPE_IQ1_S>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
-}
-
-static void mul_mat_vec_iq1_m_q8_1_cuda(
-    const void * vx, const void * vy, float * dst,
-    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
-
-    mul_mat_vec_q_cuda<GGML_TYPE_IQ1_M>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
-}
-
-static void mul_mat_vec_iq4_nl_q8_1_cuda(
-    const void * vx, const void * vy, float * dst,
-    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
-
-    mul_mat_vec_q_cuda<GGML_TYPE_IQ4_NL>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
-}
-
-static void mul_mat_vec_iq4_xs_q8_1_cuda(
-    const void * vx, const void * vy, float * dst,
-    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
-
-    mul_mat_vec_q_cuda<GGML_TYPE_IQ4_XS>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
-}
-
-static void mul_mat_vec_iq3_s_q8_1_cuda(
-    const void * vx, const void * vy, float * dst,
-    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
-
-    mul_mat_vec_q_cuda<GGML_TYPE_IQ3_S>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
+    mul_mat_vec_q_switch_type(
+        src0->data, src0->type, src1_q8_1.get(), ids_d, dst_d, ne00,
+        ne01,              ncols_dst,     s01, stride_col_y,     stride_col_dst,
+        ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
+        ne03,              ne3,           s03, s13,              s3,                 stream);
 }
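
Because src1 is repacked into a dense q8_1 buffer above, its strides become products of block counts rather than nb[]/ts quotients. A worked example assuming ggml's usual QK8_1 == 32 and MATRIX_ROW_PADDING == 512:

    #include <cstdint>

    // Assumed shape: ne10 = 4000 columns, ne11 = 2 rows, ne12 = 4 channels.
    constexpr int64_t ne10_padded = (4000 + 512 - 1) / 512 * 512; // GGML_PAD -> 4096
    constexpr int64_t s11 = ne10_padded / 32; // 128 block_q8_1 per row
    constexpr int64_t s12 = 2 * s11;          // ne11*s11: blocks per channel
    constexpr int64_t s13 = 4 * s12;          // ne12*s12: blocks per sample
    static_assert(s13 == 1024, "dense repack: strides are products of block counts");
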
 
 void ggml_cuda_op_mul_mat_vec_q(
@@ -440,68 +565,12 @@ void ggml_cuda_op_mul_mat_vec_q(
     // nrows_dst == nrows of the matrix that the kernel writes into
     const int64_t nrows_dst = id == ctx.device ? ne0 : row_diff;
 
-    switch (src0->type) {
-        case GGML_TYPE_Q4_0:
-            mul_mat_vec_q4_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
-            break;
-        case GGML_TYPE_Q4_1:
-            mul_mat_vec_q4_1_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
-            break;
-        case GGML_TYPE_Q5_0:
-            mul_mat_vec_q5_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
-            break;
-        case GGML_TYPE_Q5_1:
-            mul_mat_vec_q5_1_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
-            break;
-        case GGML_TYPE_Q8_0:
-            mul_mat_vec_q8_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
-            break;
-        case GGML_TYPE_Q2_K:
-            mul_mat_vec_q2_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
-            break;
-        case GGML_TYPE_Q3_K:
-            mul_mat_vec_q3_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
-            break;
-        case GGML_TYPE_Q4_K:
-            mul_mat_vec_q4_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
-            break;
-        case GGML_TYPE_Q5_K:
-            mul_mat_vec_q5_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
-            break;
-        case GGML_TYPE_Q6_K:
-            mul_mat_vec_q6_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
-            break;
-        case GGML_TYPE_IQ2_XXS:
-            mul_mat_vec_iq2_xxs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
-            break;
-        case GGML_TYPE_IQ2_XS:
-            mul_mat_vec_iq2_xs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
-            break;
-        case GGML_TYPE_IQ2_S:
-            mul_mat_vec_iq2_s_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
-            break;
-        case GGML_TYPE_IQ3_XXS:
-            mul_mat_vec_iq3_xxs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
-            break;
-        case GGML_TYPE_IQ1_S:
-            mul_mat_vec_iq1_s_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
-            break;
-        case GGML_TYPE_IQ1_M:
-            mul_mat_vec_iq1_m_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
-            break;
-        case GGML_TYPE_IQ4_NL:
-            mul_mat_vec_iq4_nl_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
-            break;
-        case GGML_TYPE_IQ4_XS:
-            mul_mat_vec_iq4_xs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
-            break;
-        case GGML_TYPE_IQ3_S:
-            mul_mat_vec_iq3_s_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
-            break;
-        default:
-            GGML_ABORT("fatal error");
-            break;
-    }
+    const int stride_row_x = ne00 / ggml_blck_size(src0->type);
+    const int stride_col_y = src1_padded_row_size / QK8_1;
+
+    mul_mat_vec_q_switch_type(
+        src0_dd_i, src0->type, src1_ddq_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row_x, stride_col_y, nrows_dst,
+        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, stream);
 
     GGML_UNUSED(src1);
     GGML_UNUSED(dst);
diff --git a/ggml/src/ggml-cuda/mmvq.cuh b/ggml/src/ggml-cuda/mmvq.cuh
index d9e42fdd6d..39dc7d33eb 100644
--- a/ggml/src/ggml-cuda/mmvq.cuh
+++ b/ggml/src/ggml-cuda/mmvq.cuh
@@ -2,6 +2,9 @@
 
 #define MMVQ_MAX_BATCH_SIZE 8 // Max. batch size for which to use MMVQ kernels.
 
+void ggml_cuda_mul_mat_vec_q(ggml_backend_cuda_context & ctx,
+    const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst);
+
 void ggml_cuda_op_mul_mat_vec_q(
     ggml_backend_cuda_context & ctx,
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
diff --git a/ggml/src/ggml-cuda/quantize.cu b/ggml/src/ggml-cuda/quantize.cu
index 1702e4ce2f..3bab47d56a 100644
--- a/ggml/src/ggml-cuda/quantize.cu
+++ b/ggml/src/ggml-cuda/quantize.cu
@@ -1,30 +1,40 @@
 #include "quantize.cuh"
 #include <cstdint>
 
-static __global__ void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int64_t kx, const int64_t kx0_padded) {
-    const int64_t ix0 = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;
+static __global__ void quantize_q8_1(
+        const float * __restrict__ x, void * __restrict__ vy,
+        const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03,
+        const int64_t ne0, const int ne1, const int ne2) {
+    const int64_t i0 = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;
 
-    if (ix0 >= kx0_padded) {
+    if (i0 >= ne0) {
         return;
     }
 
-    const int64_t ix1 = blockIdx.y;
+    const int64_t i1 = blockIdx.y;
+    const int64_t i2 = blockIdx.z % ne2;
+    const int64_t i3 = blockIdx.z / ne2;
 
-    const int64_t i_padded = ix1*kx0_padded + ix0;
+    const int64_t & i00 = i0;
+    const int64_t & i01 = i1;
+    const int64_t & i02 = i2;
+    const int64_t & i03 = i3;
+
+    const int64_t i_cont = ((i3*ne2 + i2) * ne1 + i1) * ne0 + i0;
 
     block_q8_1 * y = (block_q8_1 *) vy;
 
-    const int64_t ib = i_padded / QK8_1; // block index
-    const int64_t iqs = i_padded % QK8_1; // quant index
+    const int64_t ib  = i_cont / QK8_1; // block index
+    const int64_t iqs = i_cont % QK8_1; // quant index
 
-    const float xi = ix0 < kx ? x[ix1*kx + ix0] : 0.0f;
+    const float xi = i0 < ne00 ? x[i03*s03 + i02*s02 + i01*s01 + i00] : 0.0f;
     float amax = fabsf(xi);
     float sum = xi;
 
     amax = warp_reduce_max(amax);
-    sum = warp_reduce_sum(sum);
+    sum  = warp_reduce_sum(sum);
 
-    const float d = amax / 127;
+    const float  d = amax / 127;
     const int8_t q = amax == 0.0f ? 0 : roundf(xi / d);
 
     y[ib].qs[iqs] = q;
@@ -127,43 +137,45 @@ static __global__ void quantize_mmq_q8_1(
 }
 
 void quantize_row_q8_1_cuda(
-    const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels,
-    const int64_t kx0_padded, const ggml_type type_x, cudaStream_t stream) {
+    const float * x, void * vy, const ggml_type type_src0, const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03,
+    const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3, cudaStream_t stream) {
 
-    GGML_ASSERT(kx0_padded % QK8_1 == 0);
+    GGML_ASSERT(ne0 % QK8_1 == 0);
 
-    const int64_t block_num_x = (kx0_padded + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE;
-    const dim3 num_blocks(block_num_x, kx1*channels, 1);
+    const int64_t block_num_x = (ne0 + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE;
+    const dim3 num_blocks(block_num_x, ne1, ne2*ne3);
     const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE, 1, 1);
-    quantize_q8_1<<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx0_padded);
-
-    GGML_UNUSED(type_x);
+    quantize_q8_1<<<num_blocks, block_size, 0, stream>>>(x, vy, ne00, s01, s02, s03, ne0, ne1, ne2);
+    GGML_UNUSED(type_src0);
 }
 
 void quantize_mmq_q8_1_cuda(
-    const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels,
-    const int64_t kx0_padded, const ggml_type type_x, cudaStream_t stream) {
+    const float * x, void * vy, const ggml_type type_src0, const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03,
+    const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3, cudaStream_t stream) {
 
-    GGML_ASSERT(kx0_padded % (4*QK8_1) == 0);
+    GGML_ASSERT(ne0 % (4*QK8_1) == 0);
 
-    const int64_t block_num_x = (kx0_padded + 4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ - 1) / (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ);
-    const dim3 num_blocks(block_num_x, kx1, channels);
+    const int64_t block_num_x = (ne0 + 4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ - 1) / (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ);
+    const dim3 num_blocks(block_num_x, ne1, ne2*ne3);
     const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE_MMQ, 1, 1);
-    switch (mmq_get_q8_1_ds_layout(type_x)) {
+    switch (mmq_get_q8_1_ds_layout(type_src0)) {
         case MMQ_Q8_1_DS_LAYOUT_D4:
             quantize_mmq_q8_1<MMQ_Q8_1_DS_LAYOUT_D4>
-                <<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx1, kx0_padded);
+                <<<num_blocks, block_size, 0, stream>>>(x, vy, ne00, ne1, ne0);
             break;
         case MMQ_Q8_1_DS_LAYOUT_DS4:
             quantize_mmq_q8_1<MMQ_Q8_1_DS_LAYOUT_DS4>
-                <<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx1, kx0_padded);
+                <<<num_blocks, block_size, 0, stream>>>(x, vy, ne00, ne1, ne0);
             break;
         case MMQ_Q8_1_DS_LAYOUT_D2S6:
             quantize_mmq_q8_1<MMQ_Q8_1_DS_LAYOUT_D2S6>
-                <<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx1, kx0_padded);
+                <<<num_blocks, block_size, 0, stream>>>(x, vy, ne00, ne1, ne0);
             break;
         default:
             GGML_ABORT("fatal error");
             break;
     }
+    GGML_UNUSED(s01);
+    GGML_UNUSED(s02);
+    GGML_UNUSED(s03);
 }
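
For reference, the per-block q8_1 math that `quantize_q8_1` implements is small enough to sketch. Below is a minimal NumPy illustration (names and layout are illustrative, not the kernel's API; the kernel computes the same `amax`/`sum` with warp reductions, and the 4D indexing above only changes *where* each block's 32 inputs are read from):

```python
import numpy as np

QK8_1 = 32  # block size used by ggml's q8_1 format

def quantize_q8_1_block(x: np.ndarray):
    """Quantize one block of 32 floats to int8 plus a per-block scale d and sum s."""
    assert x.shape == (QK8_1,)
    amax = np.abs(x).max()                      # warp_reduce_max in the kernel
    d = amax / 127.0                            # per-block scale
    q = (np.zeros(QK8_1, dtype=np.int8) if amax == 0.0
         else np.round(x / d).astype(np.int8))
    s = float(x.sum())                          # warp_reduce_sum; stored alongside d
    return q, d, s

q, d, s = quantize_q8_1_block(np.linspace(-1.0, 1.0, QK8_1, dtype=np.float32))
print(d, s, q[:4])
```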
diff --git a/ggml/src/ggml-cuda/quantize.cuh b/ggml/src/ggml-cuda/quantize.cuh
index 03bf322b95..b627c4e400 100644
--- a/ggml/src/ggml-cuda/quantize.cuh
+++ b/ggml/src/ggml-cuda/quantize.cuh
@@ -12,13 +12,13 @@ static_assert(MATRIX_ROW_PADDING %    CUDA_QUANTIZE_BLOCK_SIZE      == 0, "Risk
 static_assert(MATRIX_ROW_PADDING % (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ) == 0, "Risk of out-of-bounds access.");
 
 typedef void (*quantize_cuda_t)(
-    const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels, const int64_t kx0_padded,
-    const ggml_type type_x, cudaStream_t stream);
+    const float * x, void * vy, const ggml_type type_src0, const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03,
+    const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3, cudaStream_t stream);
 
 void quantize_row_q8_1_cuda(
-    const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels, const int64_t kx0_padded,
-    const ggml_type type_x, cudaStream_t stream);
+    const float * x, void * vy, const ggml_type type_src0, const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03,
+    const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3, cudaStream_t stream);
 
 void quantize_mmq_q8_1_cuda(
-    const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels, const int64_t kx0_padded,
-    const ggml_type type_x, cudaStream_t stream);
+    const float * x, void * vy, const ggml_type type_src0, const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03,
+    const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3, cudaStream_t stream);
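
Worth noting about the signature change above: the single `kx0_padded` bookkeeping value is replaced by explicit source strides (`s01`..`s03`) plus destination extents (`ne0`..`ne3`). This is what lets the quantization kernels index a non-contiguous, batched source directly (`x[i03*s03 + i02*s02 + i01*s01 + i00]` in the kernel) instead of requiring a pre-flattened padded buffer.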
diff --git a/ggml/src/ggml-cuda/vecdotq.cuh b/ggml/src/ggml-cuda/vecdotq.cuh
index 40091a0ef0..ba195e1d10 100644
--- a/ggml/src/ggml-cuda/vecdotq.cuh
+++ b/ggml/src/ggml-cuda/vecdotq.cuh
@@ -1,3 +1,5 @@
+#pragma once
+
 #include "common.cuh"
 #include <cstdint>
 
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index 5f6f87d1a3..61751755b3 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -2071,7 +2071,7 @@ struct test_mul_mat_id : public test_case {
     const ggml_type type_b;
     const int n_mats;
     const int n_used;
-    const bool b; // brodcast b matrix
+    const bool b; // broadcast b matrix
     const int64_t m;
     const int64_t n;
     const int64_t k;

From 2cca6c01e46d2fc1124d15730273ed2acdad1016 Mon Sep 17 00:00:00 2001
From: Radoslav Gerganov 
Date: Wed, 23 Apr 2025 10:32:49 +0300
Subject: [PATCH 30/39] rpc : add command line option for number of threads for
 the CPU backend (#13060)

closes #13051
---
 examples/rpc/rpc-server.cpp | 19 +++++++++++++++++--
 1 file changed, 17 insertions(+), 2 deletions(-)

diff --git a/examples/rpc/rpc-server.cpp b/examples/rpc/rpc-server.cpp
index 9db5542570..0277e25cb5 100644
--- a/examples/rpc/rpc-server.cpp
+++ b/examples/rpc/rpc-server.cpp
@@ -22,6 +22,7 @@
 
 #include "ggml-rpc.h"
 #ifdef _WIN32
+#  define NOMINMAX
 #  define DIRECTORY_SEPARATOR '\\'
 #  include <windows.h>
 #  include <fcntl.h>
@@ -37,6 +38,8 @@
 #include <string>
 #include <stdio.h>
 #include <filesystem>
+#include <algorithm>
+#include <thread>
 
 namespace fs = std::filesystem;
 
@@ -150,12 +153,14 @@ struct rpc_server_params {
     int         port        = 50052;
     size_t      backend_mem = 0;
     bool        use_cache   = false;
+    int         n_threads   = std::max(1U, std::thread::hardware_concurrency()/2);
 };
 
 static void print_usage(int /*argc*/, char ** argv, rpc_server_params params) {
     fprintf(stderr, "Usage: %s [options]\n\n", argv[0]);
     fprintf(stderr, "options:\n");
     fprintf(stderr, "  -h, --help                show this help message and exit\n");
+    fprintf(stderr, "  -t,      --threads        number of threads for the CPU backend (default: %d)\n", params.n_threads);
     fprintf(stderr, "  -H HOST, --host HOST      host to bind to (default: %s)\n", params.host.c_str());
     fprintf(stderr, "  -p PORT, --port PORT      port to bind to (default: %d)\n", params.port);
     fprintf(stderr, "  -m MEM,  --mem MEM        backend memory size (in MB)\n");
@@ -172,6 +177,15 @@ static bool rpc_server_params_parse(int argc, char ** argv, rpc_server_params &
                 return false;
             }
             params.host = argv[i];
+        } else if (arg == "-t" || arg == "--threads") {
+            if (++i >= argc) {
+                return false;
+            }
+            params.n_threads = std::stoi(argv[i]);
+            if (params.n_threads <= 0) {
+                fprintf(stderr, "error: invalid number of threads: %d\n", params.n_threads);
+                return false;
+            }
         } else if (arg == "-p" || arg == "--port") {
             if (++i >= argc) {
                 return false;
@@ -199,7 +213,7 @@ static bool rpc_server_params_parse(int argc, char ** argv, rpc_server_params &
     return true;
 }
 
-static ggml_backend_t create_backend() {
+static ggml_backend_t create_backend(const rpc_server_params & params) {
     ggml_backend_t backend = NULL;
 #ifdef GGML_USE_CUDA
     fprintf(stderr, "%s: using CUDA backend\n", __func__);
@@ -231,6 +245,7 @@ static ggml_backend_t create_backend() {
     if (!backend) {
         fprintf(stderr, "%s: using CPU backend\n", __func__);
         backend = ggml_backend_cpu_init();
+        ggml_backend_cpu_set_n_threads(backend, params.n_threads);
     }
     return backend;
 }
@@ -275,7 +290,7 @@ int main(int argc, char * argv[]) {
         fprintf(stderr, "\n");
     }
 
-    ggml_backend_t backend = create_backend();
+    ggml_backend_t backend = create_backend(params);
     if (!backend) {
         fprintf(stderr, "Failed to create backend\n");
         return 1;
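
With this option the CPU fallback's thread count is user-controllable; for example, `rpc-server -t 8 -p 50052` pins the CPU backend to 8 threads, while omitting `-t` keeps the `hardware_concurrency()/2` default shown above.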

From eb1776b15a32d832f1266deeeab75b9d255c5849 Mon Sep 17 00:00:00 2001
From: piDack <104877312+piDack@users.noreply.github.com>
Date: Wed, 23 Apr 2025 22:59:14 +0800
Subject: [PATCH 31/39] convert : Append mult-eos,half-rope,bos to GLM4-0414
 and Z (#13021)

* append mult-eos,half-rope,bos to GLM4-0414

* remove unset var
---
 convert_hf_to_gguf.py | 17 ++++++++++++++++-
 1 file changed, 16 insertions(+), 1 deletion(-)

diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index 645bdad9b5..04131319b1 100755
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -5079,10 +5079,25 @@ class Glm4Model(TextModel):
     model_arch = gguf.MODEL_ARCH.GLM4
 
     def set_vocab(self):
-        self._set_vocab_gpt2()
+        from transformers import AutoTokenizer
+        tokenizer = AutoTokenizer.from_pretrained(self.dir_model, trust_remote_code=True)
+        special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True)
+        tokens, toktypes, tokpre = self.get_vocab_base()
+        self.gguf_writer.add_tokenizer_model("gpt2")
+        self.gguf_writer.add_tokenizer_pre(tokpre)
+        self.gguf_writer.add_token_list(tokens)
+        self.gguf_writer.add_token_types(toktypes)
+        special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True)
+        special_vocab._set_special_token("eos", tokenizer.get_added_vocab()["<|endoftext|>"])
+        special_vocab._set_special_token("eot", tokenizer.get_added_vocab()["<|user|>"])
+        special_vocab._set_special_token("unk", tokenizer.get_added_vocab()["<|endoftext|>"])
+        special_vocab._set_special_token("bos", tokenizer.get_added_vocab()["[gMASK]"])
+        special_vocab.add_to_gguf(self.gguf_writer)
 
     def set_gguf_parameters(self):
         super().set_gguf_parameters()
+        rope_dim = self.hparams["head_dim"]
+        self.gguf_writer.add_rope_dimension_count(int(rope_dim * self.hparams.get("partial_rotary_factor", 0.5)))
         if self.hparams.get("rope_scaling") is not None and "factor" in self.hparams["rope_scaling"]:
             if self.hparams["rope_scaling"].get("type") == "yarn":
                 self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
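
As a concrete check of the half-rope setting: with, say, `head_dim = 128` and no explicit `partial_rotary_factor`, the default factor of 0.5 applies and `add_rope_dimension_count` receives `int(128 * 0.5) = 64`, i.e. only the first half of each head's dimensions gets RoPE.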

From ecda2ec4b347031a9b8a89ee2efc664ce63f599c Mon Sep 17 00:00:00 2001
From: Xuan-Son Nguyen 
Date: Wed, 23 Apr 2025 20:21:59 +0200
Subject: [PATCH 32/39] mtmd : Support Pixtral 12B (#13065)

* add pixtral text model (vision is wip)

* cgraph ok, just missing 2D RoPE

* fix bad rebase

* first working version

* fix problem with img_break token

* support dynamic image size

* update docs

* update test script
---
 convert_hf_to_gguf.py              |  66 +++++-
 convert_hf_to_gguf_update.py       |   1 +
 docs/multimodal/gemma3.md          |  12 +-
 examples/llava/README.md           |  28 +++
 examples/llava/clip-impl.h         |   4 +
 examples/llava/clip.cpp            | 343 +++++++++++++++++++++++++++--
 examples/llava/mtmd.cpp            |  16 +-
 examples/llava/tests.sh            |  17 ++
 gguf-py/gguf/constants.py          |   7 +
 gguf-py/gguf/tensor_mapping.py     |  18 ++
 include/llama.h                    |   1 +
 models/ggml-vocab-pixtral.gguf.inp | 112 ++++++++++
 models/ggml-vocab-pixtral.gguf.out |  46 ++++
 src/llama-vocab.cpp                |   3 +-
 14 files changed, 643 insertions(+), 31 deletions(-)
 create mode 100644 models/ggml-vocab-pixtral.gguf.inp
 create mode 100644 models/ggml-vocab-pixtral.gguf.out

diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index 04131319b1..cf35fb86ec 100755
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -776,6 +776,9 @@ class TextModel(ModelBase):
         if chkhsh == "a1336059768a55c99a734006ffb02203cd450fed003e9a71886c88acf24fdbc2":
             # ref: https://huggingface.co/THUDM/glm-4-9b-hf
             res = "glm4"
+        if chkhsh == "0e9433cbbb161f89e264eb32e8e64bfe69e834973ffca5d41d3948a604a3e2a3":
+            # ref: https://huggingface.co/mistral-community/pixtral-12b
+            res = "pixtral"
 
         if res is None:
             logger.warning("\n")
@@ -1724,7 +1727,8 @@ class StableLMModel(TextModel):
     "MistralForCausalLM",
     "MixtralForCausalLM",
     "Idefics3ForConditionalGeneration",
-    "SmolVLMForConditionalGeneration")
+    "SmolVLMForConditionalGeneration",
+    "LlavaForConditionalGeneration")
 class LlamaModel(TextModel):
     model_arch = gguf.MODEL_ARCH.LLAMA
     undo_permute = True
@@ -1734,6 +1738,10 @@ class LlamaModel(TextModel):
         # fix for SmolVLM2, missing `num_attention_heads` in config.json
         if self.hparams["architectures"][0] == "SmolVLMForConditionalGeneration":
             self.hparams["num_attention_heads"] = self.hparams.get("num_attention_heads", 32)
+        # fix for Pixtral, missing `num_attention_heads` in config.json
+        if self.hparams["architectures"][0] == "LlavaForConditionalGeneration" \
+                and self.hparams.get("model_type") == "mistral":
+            self.hparams["num_attention_heads"] = self.hparams.get("num_attention_heads", 32)
 
     def set_vocab(self):
         try:
@@ -1797,12 +1805,17 @@ class LlamaModel(TextModel):
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         n_head = self.hparams["num_attention_heads"]
         n_kv_head = self.hparams.get("num_key_value_heads")
-        is_vision_tensor = "vision_tower" in name or "vision_model" in name or "model.connector" in name
+        is_vision_tensor = "vision_tower" in name \
+            or "vision_model" in name \
+            or "model.connector" in name \
+            or "multi_modal_projector" in name
 
         if is_vision_tensor:
             return [] # skip vision tensors
         elif name.startswith("model.text_model"):
             name = name.replace("text_model.", "") # for SmolVLM
+        elif name.startswith("language_model."):
+            name = name.replace("language_model.", "") # for the rest
 
         if self.undo_permute:
             if name.endswith(("q_proj.weight", "q_proj.bias")):
@@ -1885,6 +1898,55 @@ class LlamaModel(TextModel):
                 raise ValueError(f"Unprocessed experts: {experts}")
 
 
+@ModelBase.register("LlavaForConditionalGeneration")
+class LlavaVisionModel(VisionModel):
+    img_break_tok_id = -1
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        if self.hparams["model_type"] == "pixtral":
+            # fix missing config.json values
+            self.hparams["num_attention_heads"] = self.hparams.get("num_attention_heads", 16)
+            self.hparams["num_hidden_layers"] = self.hparams.get("num_hidden_layers", 24)
+            self.hparams["intermediate_size"] = self.hparams.get("intermediate_size", 4096)
+            self.hparams["hidden_size"] = self.hparams.get("hidden_size", 1024)
+            self.hparams["layer_norm_eps"] = self.hparams.get("layer_norm_eps", 1e-5)
+            self.img_break_tok_id = 12 # see tokenizer_config.json
+        else:
+            raise ValueError(f"Unsupported model type: {self.hparams['model_type']}")
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        hparams = self.hparams
+        if hparams["model_type"] == "pixtral":
+            self.gguf_writer.add_vision_projector_type(gguf.VisionProjectorType.PIXTRAL)
+            # default values below are taken from HF transformers code
+            self.gguf_writer.add_vision_attention_layernorm_eps(hparams["layer_norm_eps"])
+            self.gguf_writer.add_vision_use_silu(True)
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        del bid  # unused
+        n_head = self.hparams["num_attention_heads"]
+        n_kv_head = n_head
+
+        if name.startswith("multi_modal_projector.") or name.startswith("vision_tower."):
+            # process vision tensors
+            if name.endswith(("q_proj.weight", "q_proj.bias")):
+                data_torch = LlamaModel.permute(data_torch, n_head, n_head)
+            if name.endswith(("k_proj.weight", "k_proj.bias")):
+                data_torch = LlamaModel.permute(data_torch, n_head, n_kv_head)
+            return [(self.map_tensor_name(name), data_torch)]
+
+        if self.img_break_tok_id > 0 and "embed_tokens.weight" in name:
+            logger.info(f"Extracting [IMG_BREAK] token embedding from {name}")
+            # for pixtral model, we need to extract the [IMG_BREAK] token embedding
+            img_break_embd = data_torch[self.img_break_tok_id]
+            name = gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_TOK_EMBD_IMG_BREAK]
+            return [(self.map_tensor_name(name), img_break_embd)]
+
+        return [] # skip other tensors
+
+
 @ModelBase.register("Idefics3ForConditionalGeneration", "SmolVLMForConditionalGeneration")
 class SmolVLMModel(VisionModel):
     def __init__(self, *args, **kwargs):
diff --git a/convert_hf_to_gguf_update.py b/convert_hf_to_gguf_update.py
index 160c9fe0e6..03a1d8d8c9 100755
--- a/convert_hf_to_gguf_update.py
+++ b/convert_hf_to_gguf_update.py
@@ -115,6 +115,7 @@ models = [
     {"name": "bailingmoe",       "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/inclusionAI/Ling-lite", },
     {"name": "llama4",           "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E-Instruct", },
     {"name": "glm4",             "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/THUDM/glm-4-9b-hf", },
+    {"name": "pixtral",          "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/mistral-community/pixtral-12b", },
 ]
 
 
diff --git a/docs/multimodal/gemma3.md b/docs/multimodal/gemma3.md
index 8fa077de71..110a36f408 100644
--- a/docs/multimodal/gemma3.md
+++ b/docs/multimodal/gemma3.md
@@ -11,15 +11,15 @@ You can use pre-quantized model from [ggml-org](https://huggingface.co/ggml-org)
 ```bash
 # build
 cmake -B build
-cmake --build build --target llama-gemma3-cli
+cmake --build build --target llama-mtmd-cli
 
 # alternatively, install from brew (MacOS)
 brew install llama.cpp
 
 # run it
-llama-gemma3-cli -hf ggml-org/gemma-3-4b-it-GGUF
-llama-gemma3-cli -hf ggml-org/gemma-3-12b-it-GGUF
-llama-gemma3-cli -hf ggml-org/gemma-3-27b-it-GGUF
+llama-mtmd-cli -hf ggml-org/gemma-3-4b-it-GGUF
+llama-mtmd-cli -hf ggml-org/gemma-3-12b-it-GGUF
+llama-mtmd-cli -hf ggml-org/gemma-3-27b-it-GGUF
 
 # note: 1B model does not support vision
 ```
@@ -44,8 +44,8 @@ What you need:
 ```bash
 # build
 cmake -B build
-cmake --build build --target llama-gemma3-cli
+cmake --build build --target llama-mtmd-cli
 
 # run it
-./build/bin/llama-gemma3-cli -m {text_model}.gguf --mmproj mmproj.gguf --image your_image.jpg
+./build/bin/llama-mtmd-cli -m {text_model}.gguf --mmproj mmproj.gguf --image your_image.jpg
 ```
diff --git a/examples/llava/README.md b/examples/llava/README.md
index cadbc53fab..f58d9de710 100644
--- a/examples/llava/README.md
+++ b/examples/llava/README.md
@@ -14,6 +14,28 @@ The naming and structure related to multimodal support have evolved, which might
 - [#12849](https://github.com/ggml-org/llama.cpp/pull/12849): `libmtmd` was introduced as a replacement for `llava.cpp`. Its goals include providing a single, unified command-line interface, improving the user/developer experience (UX/DX), and supporting both audio and image inputs.
 - [#13012](https://github.com/ggml-org/llama.cpp/pull/13012): `mtmd-cli` was added, consolidating the various model-specific CLIs into a single tool powered by `libmtmd`.
 
+## Pre-quantized models
+
+These are ready-to-use models, most of which come with `Q4_K_M` quantization by default:
+
+```sh
+# Gemma 3
+llama-mtmd-cli -hf ggml-org/gemma-3-4b-it-GGUF
+llama-mtmd-cli -hf ggml-org/gemma-3-12b-it-GGUF
+llama-mtmd-cli -hf ggml-org/gemma-3-27b-it-GGUF
+
+# SmolVLM
+llama-mtmd-cli -hf ggml-org/SmolVLM-Instruct-GGUF
+llama-mtmd-cli -hf ggml-org/SmolVLM-256M-Instruct-GGUF
+llama-mtmd-cli -hf ggml-org/SmolVLM-500M-Instruct-GGUF
+llama-mtmd-cli -hf ggml-org/SmolVLM2-2.2B-Instruct-GGUF
+llama-mtmd-cli -hf ggml-org/SmolVLM2-256M-Video-Instruct-GGUF
+llama-mtmd-cli -hf ggml-org/SmolVLM2-500M-Video-Instruct-GGUF
+
+# Pixtral 12B
+llama-mtmd-cli -hf ggml-org/pixtral-12b-GGUF
+```
+
 ## How it works and what is `mmproj`?
 
 Multimodal support in `llama.cpp` works by encoding images into embeddings using a separate model component, and then feeding these embeddings into the language model.
@@ -45,3 +67,9 @@ Multimodal projector (`mmproj`) files are specific to each model architecture. P
 - [MiniCPM-o 2.6](../../docs/multimodal/minicpmo2.6.md)
 - [IBM Granite Vision](../../docs/multimodal/granitevision.md)
 - [Google Gemma 3](../../docs/multimodal/gemma3.md)
+
+For the following models, you can use `convert_hf_to_gguf.py` with the `--mmproj` flag to get the `mmproj` file:
+- [Gemma 3](https://huggingface.co/collections/google/gemma-3-release-67c6c6f89c4f76621268bb6d) - Note: 1B variant does not have vision support
+- SmolVLM (from [HuggingFaceTB](https://huggingface.co/HuggingFaceTB))
+- SmolVLM2 (from [HuggingFaceTB](https://huggingface.co/HuggingFaceTB))
+- [Pixtral 12B](https://huggingface.co/mistral-community/pixtral-12b) - only works with `transformers`-compatible checkpoint
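
To make the two-stage flow described above concrete, here is a minimal Python sketch of the conceptual pipeline; every function here is a hypothetical stub (a stand-in, not the real `libmtmd` C API), defined so the example runs:

```python
from typing import List

def encode_image(image: bytes) -> List[List[float]]:
    # stand-in for the vision encoder + mmproj projector producing embeddings
    return [[0.0] * 4 for _ in range(3)]    # 3 fake patch embeddings of width 4

def tokenize_text(text: str) -> List[int]:
    return [ord(c) for c in text]           # stand-in tokenizer

def language_model(*segments) -> str:
    # the LLM consumes text tokens and image embeddings as one interleaved sequence
    return f"consumed {sum(len(s) for s in segments)} items"

before, after = "describe <image> briefly".split("<image>")
print(language_model(tokenize_text(before), encode_image(b""), tokenize_text(after)))
```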
diff --git a/examples/llava/clip-impl.h b/examples/llava/clip-impl.h
index faa8a9d5e9..8d310fb027 100644
--- a/examples/llava/clip-impl.h
+++ b/examples/llava/clip-impl.h
@@ -60,6 +60,7 @@
 #define TN_ATTN_V          "%s.blk.%d.attn_v.%s"
 #define TN_ATTN_OUTPUT     "%s.blk.%d.attn_out.%s"
 #define TN_FFN_DOWN        "%s.blk.%d.ffn_down.%s"
+#define TN_FFN_GATE        "%s.blk.%d.ffn_gate.%s"
 #define TN_FFN_UP          "%s.blk.%d.ffn_up.%s"
 #define TN_LN_1            "%s.blk.%d.ln1.%s"
 #define TN_LN_2            "%s.blk.%d.ln2.%s"
@@ -73,6 +74,7 @@
 #define TN_MM_INP_PROJ     "mm.input_projection.weight" // gemma3
 #define TN_MM_SOFT_EMB_N   "mm.soft_emb_norm.weight"    // gemma3
 #define TN_MM_PROJECTOR    "mm.model.fc.weight"         // idefics3
+#define TN_TOK_IMG_BREAK   "v.token_embd.img_break"     // pixtral
 
 // minicpmv
 #define TN_MINICPMV_POS_EMBD_K "resampler.pos_embed_k"
@@ -101,6 +103,7 @@ enum projector_type {
     PROJECTOR_TYPE_MERGER,
     PROJECTOR_TYPE_GEMMA3,
     PROJECTOR_TYPE_IDEFICS3,
+    PROJECTOR_TYPE_PIXTRAL,
     PROJECTOR_TYPE_UNKNOWN,
 };
 
 static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
     { PROJECTOR_TYPE_MERGER,    "qwen2vl_merger"},
     { PROJECTOR_TYPE_GEMMA3,    "gemma3"},
     { PROJECTOR_TYPE_IDEFICS3,  "idefics3"},
+    { PROJECTOR_TYPE_PIXTRAL,   "pixtral"},
 };
 
 static projector_type clip_projector_type_from_string(const std::string & str) {
diff --git a/examples/llava/clip.cpp b/examples/llava/clip.cpp
index ad39e1bd61..4eec4a2646 100644
--- a/examples/llava/clip.cpp
+++ b/examples/llava/clip.cpp
@@ -163,7 +163,8 @@ struct clip_hparams {
 
     patch_merge_type mm_patch_merge_type = PATCH_MERGE_FLAT;
 
-    float eps;
+    float eps = 1e-6;
+    float rope_theta = 0.0;
 
     std::vector<int32_t> image_grid_pinpoints;
     int32_t image_crop_resolution;
@@ -187,11 +188,17 @@ struct clip_layer {
     struct ggml_tensor * ln_1_b = nullptr;
 
     // ff
-    struct ggml_tensor * ff_i_w = nullptr;
-    struct ggml_tensor * ff_i_b = nullptr;
+    struct ggml_tensor * ff_i_w = nullptr; // legacy naming
+    struct ggml_tensor * ff_i_b = nullptr; // legacy naming
+    struct ggml_tensor * ff_o_w = nullptr; // legacy naming
+    struct ggml_tensor * ff_o_b = nullptr; // legacy naming
 
-    struct ggml_tensor * ff_o_w = nullptr;
-    struct ggml_tensor * ff_o_b = nullptr;
+    struct ggml_tensor * ff_up_w = nullptr;
+    struct ggml_tensor * ff_up_b = nullptr;
+    struct ggml_tensor * ff_gate_w = nullptr;
+    struct ggml_tensor * ff_gate_b = nullptr;
+    struct ggml_tensor * ff_down_w = nullptr;
+    struct ggml_tensor * ff_down_b = nullptr;
 
     // layernorm 2
     struct ggml_tensor * ln_2_w = nullptr;
@@ -297,6 +304,9 @@ struct clip_vision_model {
     // gemma3
     struct ggml_tensor * mm_input_proj_w = nullptr;
     struct ggml_tensor * mm_soft_emb_norm_w = nullptr;
+
+    // pixtral
+    struct ggml_tensor * token_embd_img_break = nullptr;
 };
 
 struct clip_ctx {
@@ -329,6 +339,7 @@ struct clip_ctx {
     ggml_backend_t backend_cpu;
     ggml_backend_buffer_ptr buf;
 
+    int max_nodes = 8192;
     ggml_backend_sched_ptr sched;
 
     clip_image_size load_image_size;
@@ -544,6 +555,218 @@ static ggml_cgraph * clip_image_build_graph_siglip(clip_ctx * ctx, const clip_im
     return gf;
 }
 
+// implementation of the 2D RoPE without adding a new op in ggml
+static ggml_tensor * build_rope_2d(
+    ggml_cgraph * gf,
+    ggml_context * ctx0,
+    ggml_tensor * cur,
+    ggml_tensor * pos_h,
+    ggml_tensor * pos_w,
+    const float freq_base
+) {
+    ggml_tensor * tmp;
+    const int64_t n_dim  = cur->ne[0];
+    const int64_t n_head = cur->ne[1];
+    const int64_t n_pos  = cur->ne[2];
+
+    // for example, if we have cur tensor of shape (n_dim=8, n_head, n_pos)
+    // we will have a list of 4 inv_freq: 1e-0, 1e-1, 1e-2, 1e-3
+    // first half of cur will use 1e-0, 1e-2 (even)
+    // second half of cur will use 1e-1, 1e-3 (odd)
+    //
+    // for the first half, the trick here is to rotate n_dim/2, so inv_freq will be even
+    //  ^ don't ask me why, it's math! -2(2i) / n_dim == -2i / (n_dim/2)
+    // then for the second half, we use freq_scale to shift the inv_freq
+    //  ^ why? replace (2i) with (2i+1) in the above equation
+    const float freq_scale_odd = std::pow(freq_base, (float)-2/n_dim);
+
+    // first half
+    {
+        cur = ggml_rope_ext_inplace(
+            ctx0,
+            cur,
+            pos_h,      // positions
+            nullptr,    // freq factors
+            n_dim/2,    // n_dims
+            0, 0, freq_base,
+            1.0f, 0.0f, 1.0f, 0.0f, 0.0f
+        );
+    }
+
+    // second half
+    {
+        tmp = ggml_view_3d(ctx0, cur,
+            n_dim/2, n_head, n_pos,
+            ggml_row_size(cur->type, n_dim),
+            ggml_row_size(cur->type, n_dim*n_head),
+            n_dim/2 * ggml_element_size(cur));
+        tmp = ggml_rope_ext_inplace(
+            ctx0,
+            tmp,
+            pos_w,      // positions
+            nullptr,    // freq factors
+            n_dim/2,    // n_dims
+            0, 0, freq_base,
+            freq_scale_odd,
+            0.0f, 1.0f, 0.0f, 0.0f
+        );
+        // calculate inplace (modify cur directly)
+        ggml_build_forward_expand(gf, tmp);
+    }
+
+    return cur;
+}
+
+static ggml_cgraph * clip_image_build_graph_pixtral(clip_ctx * ctx, const clip_image_f32_batch & imgs) {
+    const auto & model = ctx->vision_model;
+    const auto & hparams = model.hparams;
+
+    GGML_ASSERT(ctx->proj_type == PROJECTOR_TYPE_PIXTRAL);
+    GGML_ASSERT(imgs.entries.size() == 1); // batch_size == 1
+
+    int image_size_width  = imgs.entries[0]->nx;
+    int image_size_height = imgs.entries[0]->ny;
+
+    const int patch_size  = hparams.patch_size;
+    const int n_patches_x = image_size_width  / patch_size;
+    const int n_patches_y = image_size_height / patch_size;
+    const int num_patches = n_patches_x * n_patches_y;
+    const int hidden_size = hparams.hidden_size;
+    const int n_head      = hparams.n_head;
+    const int d_head      = hidden_size / n_head;
+    const int n_layer     = hparams.n_layer;
+    const float eps       = hparams.eps;
+
+    struct ggml_init_params params = {
+        /*.mem_size   =*/ ctx->buf_compute_meta.size(),
+        /*.mem_buffer =*/ ctx->buf_compute_meta.data(),
+        /*.no_alloc   =*/ true,
+    };
+
+    ggml_context_ptr ctx0_ptr(ggml_init(params));
+    auto ctx0 = ctx0_ptr.get();
+
+    struct ggml_cgraph * gf = ggml_new_graph(ctx0);
+
+    // input raw
+    struct ggml_tensor * inp_raw = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, image_size_width, image_size_height, 3);
+    ggml_set_name(inp_raw, "inp_raw");
+    ggml_set_input(inp_raw);
+
+    // 2D input positions
+    struct ggml_tensor * pos_h = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_patches);
+    ggml_set_name(pos_h, "pos_h");
+    ggml_set_input(pos_h);
+    struct ggml_tensor * pos_w = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_patches);
+    ggml_set_name(pos_w, "pos_w");
+    ggml_set_input(pos_w);
+
+    struct ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings_0, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
+    inp = ggml_reshape_2d(ctx0, inp, num_patches, hidden_size);
+    inp = ggml_cont(ctx0, ggml_transpose(ctx0, inp));
+
+    struct ggml_tensor * embeddings = inp;
+
+    // pre-layer norm
+    embeddings = ggml_mul(ctx0, ggml_rms_norm(ctx0, embeddings, eps), model.pre_ln_w);
+
+    // loop over layers
+    for (int il = 0; il < n_layer; il++) {
+        struct ggml_tensor * cur = embeddings;
+
+        // pre-attention norm
+        cur = ggml_mul(ctx0, ggml_rms_norm(ctx0, cur, eps), model.layers[il].ln_1_w);
+
+        // self-attention
+        {
+            struct ggml_tensor * Q = ggml_mul_mat(ctx0, model.layers[il].q_w, cur);
+
+            Q = ggml_reshape_3d(ctx0, Q, d_head, n_head, num_patches);
+            Q = build_rope_2d(gf, ctx0, Q, pos_h, pos_w, hparams.rope_theta);
+            Q = ggml_cont(ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3));
+
+            struct ggml_tensor * K = ggml_mul_mat(ctx0, model.layers[il].k_w, cur);
+
+            K = ggml_reshape_3d(ctx0, K, d_head, n_head, num_patches);
+            K = build_rope_2d(gf, ctx0, K, pos_h, pos_w, hparams.rope_theta);
+            K = ggml_cont(ctx0, ggml_permute(ctx0, K, 0, 2, 1, 3));
+
+            struct ggml_tensor * V = ggml_mul_mat(ctx0, model.layers[il].v_w, cur);
+
+            V = ggml_reshape_3d(ctx0, V, d_head, n_head, num_patches);
+            V = ggml_cont(ctx0, ggml_permute(ctx0, V, 1, 2, 0, 3));
+
+            struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
+            KQ = ggml_soft_max_ext(ctx0, KQ, nullptr, 1.0f / sqrtf((float)d_head), 0.0f);
+
+            struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ);
+            KQV = ggml_reshape_3d(ctx0, KQV, d_head, num_patches, n_head);
+            KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
+
+            cur = ggml_cont_2d(ctx0, KQV, hidden_size, num_patches);
+
+            cur = ggml_mul_mat(ctx0, model.layers[il].o_w, cur);
+        }
+
+        // re-add the layer input, e.g., residual
+        cur = ggml_add(ctx0, cur, embeddings);
+
+        embeddings = cur; // embeddings = residual, cur = hidden_states
+
+        // pre-ffn norm
+        cur = ggml_mul(ctx0, ggml_rms_norm(ctx0, cur, eps), model.layers[il].ln_2_w);
+
+        // feed-forward
+        {
+            ggml_tensor * gate_proj = ggml_mul_mat(ctx0, model.layers[il].ff_gate_w, cur);
+            ggml_tensor * up_proj   = ggml_mul_mat(ctx0, model.layers[il].ff_up_w,   cur);
+            gate_proj = ggml_silu(ctx0, gate_proj); // pixtral uses silu
+            cur = ggml_mul(ctx0, up_proj, gate_proj);
+            cur = ggml_mul_mat(ctx0, model.layers[il].ff_down_w, cur);
+        }
+
+        // residual 2
+        cur = ggml_add(ctx0, embeddings, cur);
+
+        embeddings = cur;
+    }
+
+    // LlavaMultiModalProjector (with GELU activation)
+    {
+        embeddings = ggml_mul_mat(ctx0, model.mm_1_w, embeddings);
+        embeddings = ggml_add(ctx0, embeddings, model.mm_1_b);
+
+        embeddings = ggml_gelu(ctx0, embeddings);
+        embeddings = ggml_mul_mat(ctx0, model.mm_2_w, embeddings);
+        embeddings = ggml_add(ctx0, embeddings, model.mm_2_b);
+    }
+
+    // arrangement of the [IMG_BREAK] token
+    {
+        // not efficient, but works
+        // the trick is to view the embeddings as a 3D tensor with shape [hidden_size, n_patches_per_row, n_rows]
+        // and then concatenate the [IMG_BREAK] token to the end of each row, aka n_patches_per_row dimension
+        // after the concatenation, we have a tensor with shape [hidden_size, n_patches_per_row + 1, n_rows]
+
+        const int n_embd_text     = embeddings->ne[0];
+        const int n_tokens_output = num_patches + n_patches_y - 1; // one [IMG_BREAK] per row, except the last row
+
+        ggml_tensor * cur = ggml_reshape_3d(ctx0, embeddings, n_embd_text, n_patches_x, n_patches_y);
+        ggml_tensor * tok = ggml_new_tensor_3d(ctx0, embeddings->type, n_embd_text, 1, n_patches_y);
+        tok = ggml_scale(ctx0, tok, 0.0); // clear the tensor
+        tok = ggml_add(ctx0, tok, model.token_embd_img_break);
+        cur = ggml_concat(ctx0, cur, tok, 1);
+        embeddings = ggml_view_2d(ctx0, cur,
+            n_embd_text, n_tokens_output,
+            ggml_row_size(cur->type, n_embd_text), 0);
+    }
+
+    // build the graph
+    ggml_build_forward_expand(gf, embeddings);
+
+    return gf;
+}
+
 static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_image_f32_batch & imgs, struct clip_image_size load_image_size, bool is_inf = false) {
     if (!ctx->has_vision_encoder) {
         LOG_ERR("This gguf file seems to have no vision encoder\n");
@@ -1118,6 +1341,10 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
             {
                 res = clip_image_build_graph_siglip(ctx, imgs);
             } break;
+        case PROJECTOR_TYPE_PIXTRAL:
+            {
+                res = clip_image_build_graph_pixtral(ctx, imgs);
+            } break;
         default:
             {
                 // TODO: we should have one build_* function per model
@@ -1279,6 +1506,10 @@ struct clip_model_loader {
                 {
                     get_u32(KEY_PROJ_SCALE_FACTOR, hparams.proj_scale_factor, false);
                 } break;
+            case PROJECTOR_TYPE_PIXTRAL:
+                {
+                    hparams.rope_theta = 10000.0f;
+                } break;
             default:
                 break;
         }
@@ -1350,16 +1581,26 @@ struct clip_model_loader {
             layer.o_w    = get_tensor(string_format(TN_ATTN_OUTPUT, "v", il, "weight"));
             layer.ln_1_w = get_tensor(string_format(TN_LN_1,        "v", il, "weight"), false);
             layer.ln_2_w = get_tensor(string_format(TN_LN_2,        "v", il, "weight"), false);
-            layer.ff_i_w = get_tensor(string_format(TN_FFN_DOWN,    "v", il, "weight"));
-            layer.ff_o_w = get_tensor(string_format(TN_FFN_UP,      "v", il, "weight"));
             layer.k_b    = get_tensor(string_format(TN_ATTN_K,      "v", il, "bias"), false);
             layer.q_b    = get_tensor(string_format(TN_ATTN_Q,      "v", il, "bias"), false);
             layer.v_b    = get_tensor(string_format(TN_ATTN_V,      "v", il, "bias"), false);
             layer.o_b    = get_tensor(string_format(TN_ATTN_OUTPUT, "v", il, "bias"), false);
             layer.ln_1_b = get_tensor(string_format(TN_LN_1,        "v", il, "bias"), false);
             layer.ln_2_b = get_tensor(string_format(TN_LN_2,        "v", il, "bias"), false);
-            layer.ff_i_b = get_tensor(string_format(TN_FFN_DOWN,    "v", il, "bias"), false);
-            layer.ff_o_b = get_tensor(string_format(TN_FFN_UP,      "v", il, "bias"), false);
+
+            // new naming
+            layer.ff_up_w   = get_tensor(string_format(TN_FFN_UP,   "v", il, "weight"));
+            layer.ff_up_b   = get_tensor(string_format(TN_FFN_UP,   "v", il, "bias"),   false);
+            layer.ff_gate_w = get_tensor(string_format(TN_FFN_GATE, "v", il, "weight"), false);
+            layer.ff_gate_b = get_tensor(string_format(TN_FFN_GATE, "v", il, "bias"),   false);
+            layer.ff_down_w = get_tensor(string_format(TN_FFN_DOWN, "v", il, "weight"));
+            layer.ff_down_b = get_tensor(string_format(TN_FFN_DOWN, "v", il, "bias"),   false);
+
+            // legacy naming (the in and out are reversed! don't ask me why)
+            layer.ff_i_w = layer.ff_down_w;
+            layer.ff_o_w = layer.ff_up_w;
+            layer.ff_i_b = layer.ff_down_b;
+            layer.ff_o_b = layer.ff_up_b;
         }
 
         switch (ctx_clip.proj_type) {
@@ -1475,6 +1716,15 @@ struct clip_model_loader {
                 {
                     vision_model.projection = get_tensor(TN_MM_PROJECTOR);
                 } break;
+            case PROJECTOR_TYPE_PIXTRAL:
+                {
+                    vision_model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 1, "weight"));
+                    vision_model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 1, "bias"));
+                    vision_model.mm_2_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight"));
+                    vision_model.mm_2_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias"));
+                    // [IMG_BREAK] token embedding
+                    vision_model.token_embd_img_break = get_tensor(TN_TOK_IMG_BREAK);
+                } break;
             default:
                 GGML_ASSERT(false && "unknown projector type");
         }
@@ -1517,18 +1767,17 @@ struct clip_model_loader {
     }
 
     void alloc_compute_meta() {
-        ctx_clip.buf_compute_meta.resize(GGML_DEFAULT_GRAPH_SIZE * ggml_tensor_overhead() + ggml_graph_overhead());
+        ctx_clip.buf_compute_meta.resize(ctx_clip.max_nodes * ggml_tensor_overhead() + ggml_graph_overhead());
 
         // create a fake batch
         clip_image_f32_batch batch;
         clip_image_f32_ptr img(clip_image_f32_init());
         clip_image_size image_size;
-        image_size.width  = clip_get_image_size(&ctx_clip);
-        image_size.height = clip_get_image_size(&ctx_clip);
-        int n_patches = clip_get_image_size(&ctx_clip) / image_size.width;
-        img->nx = n_patches;
-        img->ny = n_patches;
-        img->buf.resize(n_patches * image_size.width * image_size.height * 3);
+        image_size.width  = ctx_clip.vision_model.hparams.image_size;
+        image_size.height = ctx_clip.vision_model.hparams.image_size;
+        img->nx = image_size.width;
+        img->ny = image_size.height;
+        img->buf.resize(image_size.width * image_size.height * 3);
         batch.entries.push_back(std::move(img));
 
         ggml_cgraph * gf = clip_image_build_graph(&ctx_clip, batch, image_size, false);
@@ -1916,6 +2165,26 @@ struct image_manipulation {
         }
     }
 
+    // calculate the size of the **resized** image, while preserving the aspect ratio
+    // the calculated size will be aligned to the nearest multiple of align_size
+    // if H or W size is larger than max_dimension, it will be resized to max_dimension
+    static clip_image_size calc_size_preserved_ratio(const clip_image_size & inp_size, const int align_size, const int max_dimension) {
+        if (inp_size.width <= 0 || inp_size.height <= 0 || align_size <= 0 || max_dimension <= 0) {
+            return {0, 0};
+        }
+
+        float scale = std::min(1.0f, std::min(static_cast<float>(max_dimension) / inp_size.width,
+                                              static_cast<float>(max_dimension) / inp_size.height));
+
+        float target_width_f  = static_cast<float>(inp_size.width)  * scale;
+        float target_height_f = static_cast<float>(inp_size.height) * scale;
+
+        int aligned_width  = GGML_PAD((int)target_width_f,  align_size);
+        int aligned_height = GGML_PAD((int)target_height_f, align_size);
+
+        return {aligned_width, aligned_height};
+    }
+
 private:
     static inline int clip(int x, int lower, int upper) {
         return std::max(lower, std::min(x, upper));
@@ -2247,8 +2516,7 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
         res_imgs->entries.push_back(std::move(img_f32));
         return true;
     }
-
-    if (ctx->has_glm_projector
+    else if (ctx->has_glm_projector
             || ctx->proj_type == PROJECTOR_TYPE_GEMMA3
             || ctx->proj_type == PROJECTOR_TYPE_IDEFICS3) {
         clip_image_u8 resized_image;
@@ -2260,6 +2528,15 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
         res_imgs->entries.push_back(std::move(img_f32));
         return true;
     }
+    else if (ctx->proj_type == PROJECTOR_TYPE_PIXTRAL) {
+        clip_image_u8 resized_image;
+        auto new_size = image_manipulation::calc_size_preserved_ratio(original_size, params.patch_size, params.image_size);
+        image_manipulation::bilinear_resize(*img, resized_image, new_size.width, new_size.height);
+        clip_image_f32_ptr img_f32(clip_image_f32_init());
+        normalize_image_u8_to_f32(resized_image, *img_f32, ctx->image_mean, ctx->image_std);
+        res_imgs->entries.push_back(std::move(img_f32));
+        return true;
+    }
 
     // the logic below is to pad the shorter side to the longer side with a background color: rgb(122, 116, 104)
     // see https://github.com/haotian-liu/LLaVA/blob/e854a2bf85118c504f6f16bf5c3c7c92f8fa8c6b/llava/conversation.py#L113-L156
@@ -2387,6 +2664,10 @@ int clip_n_patches_by_img(const struct clip_ctx * ctx, struct clip_image_f32 * i
         n_patches = 256;
     } else if (ctx->proj_type == PROJECTOR_TYPE_IDEFICS3) {
         n_patches /= ctx->vision_model.hparams.proj_scale_factor;
+    } else if (ctx->proj_type == PROJECTOR_TYPE_PIXTRAL) {
+        int n_patches_x = img->nx / params.patch_size;
+        int n_patches_y = img->ny / params.patch_size;
+        n_patches = n_patches_y*n_patches_x + n_patches_y - 1; // + one [IMG_BREAK] per row, except the last row
     }
 
     return n_patches;
@@ -2540,10 +2821,15 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
         struct ggml_tensor * inp_raw = ggml_graph_get_tensor(gf, "inp_raw");
         float * data = (float *)malloc(ggml_nbytes(inp_raw));
 
+        // TODO @ngxson : this whole code block is ugly, will need to be refactored
         for (size_t i = 0; i < imgs.entries.size(); i++) {
             const int nx = imgs.entries[i]->nx;
             const int ny = imgs.entries[i]->ny;
-            if (!(ctx->has_minicpmv_projector | ctx->has_qwen2vl_merger)) {
+
+            if (ctx->has_glm_projector
+                    || ctx->has_llava_projector
+                    || ctx->proj_type == PROJECTOR_TYPE_GEMMA3
+                    || ctx->proj_type == PROJECTOR_TYPE_IDEFICS3) {
                 GGML_ASSERT(nx == image_size && ny == image_size);
             }
 
@@ -2657,6 +2943,24 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
         else if (ctx->proj_type == PROJECTOR_TYPE_IDEFICS3) {
             // do nothing
         }
+        else if (ctx->proj_type == PROJECTOR_TYPE_PIXTRAL) {
+            // set the 2D positions
+            int n_patches_per_col = image_size_width / patch_size;
+            std::vector<int> pos_data(num_positions);
+            struct ggml_tensor * pos;
+            // dimension H
+            pos = ggml_graph_get_tensor(gf, "pos_h");
+            for (int i = 0; i < num_positions; i++) {
+                pos_data[i] = i / n_patches_per_col;
+            }
+            ggml_backend_tensor_set(pos, pos_data.data(), 0, ggml_nbytes(pos));
+            // dimension W
+            pos = ggml_graph_get_tensor(gf, "pos_w");
+            for (int i = 0; i < num_positions; i++) {
+                pos_data[i] = i % n_patches_per_col;
+            }
+            ggml_backend_tensor_set(pos, pos_data.data(), 0, ggml_nbytes(pos));
+        }
         else {
             struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions");
 
@@ -2849,6 +3153,7 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
         case PROJECTOR_TYPE_LDPV2:
             return ctx->vision_model.mm_model_peg_0_b->ne[0];
         case PROJECTOR_TYPE_MLP:
+        case PROJECTOR_TYPE_PIXTRAL:
             return ctx->vision_model.mm_2_b->ne[0];
         case PROJECTOR_TYPE_MLP_NORM:
             return ctx->vision_model.mm_3_b->ne[0];
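
The `[IMG_BREAK]` arrangement in `clip_image_build_graph_pixtral` above, and the matching token count in `clip_n_patches_by_img`, can be mimicked in a few lines of NumPy. A sketch assuming a tiny 3x2 patch grid (axis order reversed relative to ggml, where `ne[0]` is the hidden dimension):

```python
import numpy as np

hidden, nx, ny = 4, 3, 2                                  # toy sizes
emb = np.arange(hidden * nx * ny, dtype=np.float32).reshape(ny, nx, hidden)
brk = np.full((ny, 1, hidden), -1.0, dtype=np.float32)    # [IMG_BREAK] embedding, one per row

rows = np.concatenate([emb, brk], axis=1)                 # [ny, nx + 1, hidden]
n_tokens_output = nx * ny + ny - 1                        # one break per row, except the last
flat = rows.reshape(-1, hidden)[:n_tokens_output]         # the ggml_view_2d drops the tail break

assert flat.shape[0] == n_tokens_output                   # matches clip_n_patches_by_img
print(flat.shape)                                         # (7, 4)
```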
diff --git a/examples/llava/mtmd.cpp b/examples/llava/mtmd.cpp
index c3fb2f18a2..11ca7b30f1 100644
--- a/examples/llava/mtmd.cpp
+++ b/examples/llava/mtmd.cpp
@@ -190,6 +190,11 @@ int32_t mtmd_tokenize(mtmd_context * ctx,
         // https://github.com/huggingface/transformers/blob/a42ba80fa520c784c8f11a973ca9034e5f859b79/src/transformers/models/idefics3/processing_idefics3.py#L192-L215
         marker_modified = "" + ctx->image_marker + "";
         string_replace_all(prompt_modified, ctx->image_marker, marker_modified);
+
+    } else if (proj_type == PROJECTOR_TYPE_PIXTRAL) {
+        // https://github.com/huggingface/transformers/blob/1cd110c6cb6a6237614130c470e9a902dbc1a4bd/docs/source/en/model_doc/pixtral.md
+        marker_modified = ctx->image_marker + "[IMG_END]";
+        string_replace_all(prompt_modified, ctx->image_marker, marker_modified);
     }
 
     // llava-1.5, llava-1.6, Yi-VL, Yi-34B, granite: don't need to add prefix and suffix
@@ -219,7 +224,7 @@ int32_t mtmd_tokenize(mtmd_context * ctx,
 
         for (auto & entry : batch_f32.entries) {
             mtmd_image_tokens_ptr image_tokens(new mtmd_image_tokens);
-            image_tokens->nx = clip_n_patches(ctx->ctx_clip);
+            image_tokens->nx = clip_n_patches_by_img(ctx->ctx_clip, entry.get());
             image_tokens->ny = 1;
             image_tokens->batch_f32.entries.push_back(std::move(entry));
             image_tokens->id = id;
@@ -313,8 +318,13 @@ int32_t mtmd_tokenize(mtmd_context * ctx,
                 }
 
             } else {
+                size_t n_tokens = 0;
+                for (const auto & entry : batch_f32.entries) {
+                    n_tokens += clip_n_patches_by_img(ctx->ctx_clip, entry.get());
+                }
+
                 mtmd_image_tokens_ptr image_tokens(new mtmd_image_tokens);
-                image_tokens->nx = clip_n_patches(ctx->ctx_clip) * batch_f32.entries.size(); // TODO @ngxson : use clip_n_patches_by_image
+                image_tokens->nx = n_tokens;
                 image_tokens->ny = 1; // TODO
                 image_tokens->batch_f32 = std::move(batch_f32);
                 image_tokens->id = bitmaps[i_img].id; // optional
@@ -382,7 +392,7 @@ int32_t mtmd_encode(mtmd_context * ctx, const mtmd_image_tokens * image_tokens)
         // TODO @ngxson : llava does not support batched encoding ; this should be fixed inside clip_image_batch_encode()
         const auto & entries = image_tokens->batch_f32.entries;
         for (size_t i = 0; i < entries.size(); i++) {
-            int n_tokens_per_image = clip_n_patches(ctx->ctx_clip);
+            int n_tokens_per_image = clip_n_patches_by_img(ctx->ctx_clip, entries[i].get());
             ok = clip_image_encode(
                 ctx->ctx_clip,
                 ctx->n_threads,
diff --git a/examples/llava/tests.sh b/examples/llava/tests.sh
index 8752fc267a..e612857edc 100755
--- a/examples/llava/tests.sh
+++ b/examples/llava/tests.sh
@@ -13,6 +13,14 @@ mkdir -p $SCRIPT_DIR/output
 PROJ_ROOT="$SCRIPT_DIR/../.."
 cd $PROJ_ROOT
 
+# Check if the first argument is "big", then run test with big models
+# This is useful if we're running the script on a larger machine, so we can test the big models
+RUN_BIG_TESTS=false
+if [ "${1:-}" = "big" ]; then
+    RUN_BIG_TESTS=true
+    echo "Include BIG models..."
+fi
+
 ###############
 
 arr_bin=()
@@ -28,6 +36,12 @@ add_test() {
     arr_tmpl+=("$tmpl")
 }
 
+add_test_big() {
+    if [ "$RUN_BIG_TESTS" = true ]; then
+        add_test "$@"
+    fi
+}
+
 add_test "llama-mtmd-cli"  "ggml-org/SmolVLM-500M-Instruct-GGUF:Q8_0"
 add_test "llama-mtmd-cli"  "ggml-org/SmolVLM2-2.2B-Instruct-GGUF:Q4_K_M"
 add_test "llama-mtmd-cli"  "ggml-org/SmolVLM2-500M-Video-Instruct-GGUF:Q8_0"
@@ -42,6 +56,9 @@ add_test "llama-mtmd-cli"  "openbmb/MiniCPM-V-2_6-gguf:Q2_K"
 add_test "llama-mtmd-cli"  "openbmb/MiniCPM-o-2_6-gguf:Q4_0"
 add_test "llama-qwen2vl-cli"  "bartowski/Qwen2-VL-2B-Instruct-GGUF:Q4_K_M"
 
+# to test the big models, run: ./tests.sh big
+add_test_big "llama-mtmd-cli" "ggml-org/pixtral-12b-GGUF:Q4_K_M"
+
 # these models always give the wrong answer, not sure why
 # add_test "llama-mtmd-cli"  "ggml-org/SmolVLM-Instruct-GGUF:Q4_K_M"
 # add_test "llama-mtmd-cli"  "ggml-org/SmolVLM-256M-Instruct-GGUF:Q8_0"
diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py
index 59510bd0c2..b81017b142 100644
--- a/gguf-py/gguf/constants.py
+++ b/gguf-py/gguf/constants.py
@@ -485,6 +485,7 @@ class MODEL_TENSOR(IntEnum):
     V_ENC_OUTPUT         = auto()
     V_ENC_OUTPUT_NORM    = auto()
     V_ENC_FFN_UP         = auto()
+    V_ENC_FFN_GATE       = auto()
     V_ENC_FFN_DOWN       = auto()
     V_PRE_NORM           = auto()
     V_POST_NORM          = auto()
@@ -501,6 +502,7 @@ class MODEL_TENSOR(IntEnum):
     V_RESMPL_Q_NORM      = auto() # minicpmv
     V_RESMPL_PROJ        = auto() # minicpmv
     V_RESMPL_QUERY       = auto() # minicpmv
+    V_TOK_EMBD_IMG_BREAK = auto() # pixtral
 
 
 MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
@@ -737,6 +739,7 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
     MODEL_TENSOR.V_ENC_OUTPUT:              "v.blk.{bid}.attn_out",
     MODEL_TENSOR.V_ENC_OUTPUT_NORM:         "v.blk.{bid}.ln2",
     MODEL_TENSOR.V_ENC_FFN_UP:              "v.blk.{bid}.ffn_up",
+    MODEL_TENSOR.V_ENC_FFN_GATE:            "v.blk.{bid}.ffn_gate",
     MODEL_TENSOR.V_ENC_FFN_DOWN:            "v.blk.{bid}.ffn_down",
     MODEL_TENSOR.V_PRE_NORM:                "v.pre_ln",
     MODEL_TENSOR.V_POST_NORM:               "v.post_ln",
@@ -753,6 +756,7 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
     MODEL_TENSOR.V_RESMPL_Q_NORM:           "resampler.ln_q",
     MODEL_TENSOR.V_RESMPL_PROJ:             "resampler.proj",
     MODEL_TENSOR.V_RESMPL_QUERY:            "resampler.query",
+    MODEL_TENSOR.V_TOK_EMBD_IMG_BREAK:      "v.token_embd.img_break", # pixtral
 }
 
 MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
@@ -771,6 +775,7 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.V_ENC_OUTPUT,
         MODEL_TENSOR.V_ENC_OUTPUT_NORM,
         MODEL_TENSOR.V_ENC_FFN_UP,
+        MODEL_TENSOR.V_ENC_FFN_GATE,
         MODEL_TENSOR.V_ENC_FFN_DOWN,
         MODEL_TENSOR.V_PRE_NORM,
         MODEL_TENSOR.V_POST_NORM,
@@ -787,6 +792,7 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.V_RESMPL_Q_NORM,
         MODEL_TENSOR.V_RESMPL_PROJ,
         MODEL_TENSOR.V_RESMPL_QUERY,
+        MODEL_TENSOR.V_TOK_EMBD_IMG_BREAK,
     ],
     MODEL_ARCH.LLAMA: [
         MODEL_TENSOR.TOKEN_EMBD,
@@ -2129,6 +2135,7 @@ class GGUFValueType(IntEnum):
 class VisionProjectorType:
     GEMMA3 = "gemma3"
     IDEFICS3 = "idefics3"
+    PIXTRAL = "pixtral"
 
 
 # Items here are (block size, type size)
diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py
index 3ff378c136..1d70551973 100644
--- a/gguf-py/gguf/tensor_mapping.py
+++ b/gguf-py/gguf/tensor_mapping.py
@@ -914,6 +914,7 @@ class TensorNameMap:
             "vision_tower.vision_model.embeddings.patch_embedding",
             "vpm.embeddings.patch_embedding",
             "model.vision_model.embeddings.patch_embedding", # SmolVLM
+            "vision_tower.patch_conv", # pixtral
         ),
 
         MODEL_TENSOR.V_ENC_EMBD_POS: (
@@ -926,52 +927,65 @@ class TensorNameMap:
             "vision_tower.vision_model.encoder.layers.{bid}.self_attn.q_proj",
             "vpm.encoder.layers.{bid}.self_attn.q_proj",
             "model.vision_model.encoder.layers.{bid}.self_attn.q_proj", # SmolVLM
+            "vision_tower.transformer.layers.{bid}.attention.q_proj", # pixtral
         ),
 
         MODEL_TENSOR.V_ENC_ATTN_K: (
             "vision_tower.vision_model.encoder.layers.{bid}.self_attn.k_proj",
             "vpm.encoder.layers.{bid}.self_attn.k_proj",
             "model.vision_model.encoder.layers.{bid}.self_attn.k_proj", # SmolVLM
+            "vision_tower.transformer.layers.{bid}.attention.k_proj", # pixtral
         ),
 
         MODEL_TENSOR.V_ENC_ATTN_V: (
             "vision_tower.vision_model.encoder.layers.{bid}.self_attn.v_proj",
             "vpm.encoder.layers.{bid}.self_attn.v_proj",
             "model.vision_model.encoder.layers.{bid}.self_attn.v_proj", # SmolVLM
+            "vision_tower.transformer.layers.{bid}.attention.v_proj", # pixtral
         ),
 
         MODEL_TENSOR.V_ENC_INPUT_NORM: (
             "vision_tower.vision_model.encoder.layers.{bid}.layer_norm1",
             "vpm.encoder.layers.{bid}.layer_norm1",
             "model.vision_model.encoder.layers.{bid}.layer_norm1", # SmolVLM
+            "vision_tower.transformer.layers.{bid}.attention_norm", # pixtral
         ),
 
         MODEL_TENSOR.V_ENC_OUTPUT: (
             "vision_tower.vision_model.encoder.layers.{bid}.self_attn.out_proj",
             "vpm.encoder.layers.{bid}.self_attn.out_proj",
             "model.vision_model.encoder.layers.{bid}.self_attn.out_proj", # SmolVLM
+            "vision_tower.transformer.layers.{bid}.attention.o_proj", # pixtral
         ),
 
         MODEL_TENSOR.V_ENC_OUTPUT_NORM: (
             "vision_tower.vision_model.encoder.layers.{bid}.layer_norm2",
             "vpm.encoder.layers.{bid}.layer_norm2",
             "model.vision_model.encoder.layers.{bid}.layer_norm2", # SmolVLM
+            "vision_tower.transformer.layers.{bid}.ffn_norm", # pixtral
         ),
 
         MODEL_TENSOR.V_ENC_FFN_UP: (
             "vision_tower.vision_model.encoder.layers.{bid}.mlp.fc1",
             "vpm.encoder.layers.{bid}.mlp.fc1",
             "model.vision_model.encoder.layers.{bid}.mlp.fc2", # SmolVLM, gemma3 (note: name is swapped)
+            "vision_tower.transformer.layers.{bid}.feed_forward.up_proj", # pixtral
+        ),
+
+        MODEL_TENSOR.V_ENC_FFN_GATE: (
+            "vision_tower.transformer.layers.{bid}.feed_forward.gate_proj", # pixtral
         ),
 
         MODEL_TENSOR.V_ENC_FFN_DOWN: (
             "vision_tower.vision_model.encoder.layers.{bid}.mlp.fc2",
             "vpm.encoder.layers.{bid}.mlp.fc2",
             "model.vision_model.encoder.layers.{bid}.mlp.fc1", # SmolVLM, gemma3 (note: name is swapped)
+            "vision_tower.transformer.layers.{bid}.feed_forward.down_proj", # pixtral
         ),
 
         MODEL_TENSOR.V_PRE_NORM: (
             "vision_tower.vision_model.pre_layrnorm",
+            "vision_tower.ln_pre", # pixtral
         ),
 
         MODEL_TENSOR.V_POST_NORM: (
@@ -1030,6 +1044,10 @@ class TensorNameMap:
         MODEL_TENSOR.V_RESMPL_QUERY: (
             "resampler.query",
         ),
+
+        MODEL_TENSOR.V_TOK_EMBD_IMG_BREAK: (
+            "v.token_embd.img_break", # for pixtral, this is a generated vector
+        ),
     }
 
     # architecture-specific block mappings
diff --git a/include/llama.h b/include/llama.h
index 5657fbf0a7..a13350e15b 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -111,6 +111,7 @@ extern "C" {
         LLAMA_VOCAB_PRE_TYPE_TRILLION       = 31,
         LLAMA_VOCAB_PRE_TYPE_BAILINGMOE     = 32,
         LLAMA_VOCAB_PRE_TYPE_LLAMA4         = 33,
+        LLAMA_VOCAB_PRE_TYPE_PIXTRAL        = 34,
     };
 
     enum llama_rope_type {
diff --git a/models/ggml-vocab-pixtral.gguf.inp b/models/ggml-vocab-pixtral.gguf.inp
new file mode 100644
index 0000000000..9baf7d77ae
--- /dev/null
+++ b/models/ggml-vocab-pixtral.gguf.inp
@@ -0,0 +1,112 @@
+ied 4 ½ months
+__ggml_vocab_test__
+Führer
+__ggml_vocab_test__
+
+__ggml_vocab_test__
+ 
+__ggml_vocab_test__
+  
+__ggml_vocab_test__
+   
+__ggml_vocab_test__
+	
+__ggml_vocab_test__
+
+
+__ggml_vocab_test__
+
+
+
+__ggml_vocab_test__
+
+
+
+
+__ggml_vocab_test__
+	
+
+__ggml_vocab_test__
+Hello world
+__ggml_vocab_test__
+ Hello world
+__ggml_vocab_test__
+Hello World
+__ggml_vocab_test__
+ Hello World
+__ggml_vocab_test__
+ Hello World!
+__ggml_vocab_test__
+Hello, world!
+__ggml_vocab_test__
+ Hello, world!
+__ggml_vocab_test__
+ this is 🦙.cpp
+__ggml_vocab_test__
+w048 7tuijk dsdfhu
+__ggml_vocab_test__
+нещо на Български
+__ggml_vocab_test__
+កាន់តែពិសេសអាចខលចេញ
+__ggml_vocab_test__
+🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ (only emoji that has its own token)
+__ggml_vocab_test__
+Hello
+__ggml_vocab_test__
+ Hello
+__ggml_vocab_test__
+  Hello
+__ggml_vocab_test__
+   Hello
+__ggml_vocab_test__
+    Hello
+__ggml_vocab_test__
+    Hello
+    Hello
+__ggml_vocab_test__
+ (
+__ggml_vocab_test__
+
+ =
+__ggml_vocab_test__
+' era
+__ggml_vocab_test__
+Hello, y'all! How are you 😁 ?我想在apple工作1314151天~
+__ggml_vocab_test__
+!!!!!!
+__ggml_vocab_test__
+3
+__ggml_vocab_test__
+33
+__ggml_vocab_test__
+333
+__ggml_vocab_test__
+3333
+__ggml_vocab_test__
+33333
+__ggml_vocab_test__
+333333
+__ggml_vocab_test__
+3333333
+__ggml_vocab_test__
+33333333
+__ggml_vocab_test__
+333333333
+__ggml_vocab_test__
+Cửa Việt
+__ggml_vocab_test__
+ discards
+__ggml_vocab_test__
+
+ 
+
+ 
+
+
+ 	 		 	
+  
+   
+    
+     
+🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ 🦙🦙 3 33 333 3333 33333 333333 3333333 33333333 3.3 3..3 3...3 កាន់តែពិសេសអាច😁 ?我想在apple工作1314151天~ ------======= нещо на Български ''''''```````""""......!!!!!!?????? I've been 'told he's there, 'RE you sure? 'M not sure I'll make it, 'D you like some tea? We'Ve a'lL
+__ggml_vocab_test__
diff --git a/models/ggml-vocab-pixtral.gguf.out b/models/ggml-vocab-pixtral.gguf.out
new file mode 100644
index 0000000000..53309d1bc9
--- /dev/null
+++ b/models/ggml-vocab-pixtral.gguf.out
@@ -0,0 +1,46 @@
+ 2014 1032 1052 1032 28504 6972
+ 1070 7088 1258
+
+ 1032
+ 1256
+ 1293
+ 1009
+ 1010
+ 1267
+ 4688
+ 1009 1010
+ 22177 4304
+ 45383 4304
+ 22177 5325
+ 45383 5325
+ 45383 5325 1033
+ 22177 1044 4304 1033
+ 45383 1044 4304 1033
+ 1593 1395 119685 1166 1153 1046 51228
+ 1119 1048 1052 1056 1032 1055 17391 23216 30203 7785 17279
+ 3337 30757 1902 4200 63073 3671
+ 1225 1158 1128 1225 1158 1182 1225 1158 1147 1225 1159 1139 1225 1158 1143 1225 1159 1130 1225 1158 1150 1225 1158 1183 1225 1158 1159 1225 21359 1225 1158 1159 1225 1158 1162 1225 1158 1182 1225 1158 1133 1225 1158 1129 1225 1158 1155 1225 1158 1133 1225 21359 1225 1158 1137
+ 1240 1159 1154 1128 1319 13052 1041 119685 1152 1182 29568 1240 1159 1140 1171 1239 1184 1143 1319 88181 1873 3659 1275 56421 1621 1041 126241 1133 1319 11234 1873 26303 1455 1934 2246 3754 10835 1041
+ 22177
+ 45383
+ 1032 45383
+ 1256 45383
+ 1293 45383
+ 1293 45383 1010 1293 45383
+ 1319
+ 1010 1376
+ 1039 4033
+ 22177 1044 1404 48054 1033 3075 1584 1636 119685 1152 1129 3082 26060 2998 63614 82278 1049 1051 1049 1052 1049 1053 1049 6434 6749
+ 7290 7290 7290
+ 1051
+ 1051 1051
+ 1051 1051 1051
+ 1051 1051 1051 1051
+ 1051 1051 1051 1051 1051
+ 1051 1051 1051 1051 1051 1051
+ 1051 1051 1051 1051 1051 1051 1051
+ 1051 1051 1051 1051 1051 1051 1051 1051
+ 1051 1051 1051 1051 1051 1051 1051 1051 1051
+ 1067 59503 28783
+ 3724 4058
+ 1010 1032 1267 1032 4688 1032 17152 1458 29356 1010 1256 1010 1293 1010 1260 1010 1652 1010 1240 1159 1154 1128 1319 13052 1041 119685 1152 1182 29568 1240 1159 1140 1171 1239 1184 1143 1319 88181 1873 3659 1275 56421 1621 1041 126241 1133 119685 1166 1153 1240 1159 1166 1153 1032 1051 1032 1051 1051 1032 1051 1051 1051 1032 1051 1051 1051 1051 1032 1051 1051 1051 1051 1051 1032 1051 1051 1051 1051 1051 1051 1032 1051 1051 1051 1051 1051 1051 1051 1032 1051 1051 1051 1051 1051 1051 1051 1051 1032 1051 1046 1051 1032 1051 1791 1051 1032 1051 2880 1051 71881 1158 1128 1225 1158 1182 1225 1158 1147 1225 1159 1139 1225 1158 1143 1225 1159 1130 1225 1158 1150 1225 1158 1183 1225 1158 1159 1225 21359 1225 1158 1159 1225 1158 1162 1225 1158 1182 1225 1158 1133 1240 1159 1152 1129 3082 26060 2998 63614 82278 1049 1051 1049 1052 1049 1053 1049 6434 6749 45577 1045 6626 43555 2843 30757 1902 4200 63073 3671 14931 20040 20040 1657 1657 1975 14135 14135 83923 7290 7290 7290 45509 45509 45509 1362 6483 2151 1576 1116 2189 1514 1681 2156 1044 1576 3609 1636 5257 1063 1576 1077 1605 5257 1362 7534 3180 1494 1044 1576 1068 1636 2479 2269 26883 1063 2837 1039 45654 1261 54297 1076
diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp
index 480605173d..50ded286f3 100644
--- a/src/llama-vocab.cpp
+++ b/src/llama-vocab.cpp
@@ -1506,7 +1506,8 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
                     tokenizer_pre == "llama3"   ||
                     tokenizer_pre == "llama-v3" ||
                     tokenizer_pre == "llama-bpe"||
-                    tokenizer_pre == "falcon3") {
+                    tokenizer_pre == "falcon3"  ||
+                    tokenizer_pre == "pixtral") {
                 pre_type = LLAMA_VOCAB_PRE_TYPE_LLAMA3;
                 ignore_merges = true;
                 add_bos = true;

From 56304069599f4dd9749d94b9bca2c2c65bb27c02 Mon Sep 17 00:00:00 2001
From: pl752 
Date: Thu, 24 Apr 2025 02:32:35 +0500
Subject: [PATCH 33/39] llama-mtmd-cli: Sigint rework in mtmd vision example
 (#13080)

* Sigint rework in mtmd vision example

* Applied suggestions on mtmd-cli PR

* Forgot to invert one of the conditions

* Update examples/llava/mtmd-cli.cpp

* Removed redundant exit check

---------

Co-authored-by: pl752 
Co-authored-by: Xuan-Son Nguyen 
---
 examples/llava/mtmd-cli.cpp | 31 ++++++++++++++++++++++++-------
 1 file changed, 24 insertions(+), 7 deletions(-)
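
The two-flag pattern introduced here is easier to see outside the diff. Below is a minimal, self-contained sketch (illustrative names, not the example's actual code): the first Ctrl-C stops the in-flight generation, the second asks the main loop to shut down cleanly, and a handler re-entered after that hard-exits.

    #include <csignal>
    #include <cstdio>
    #include <cstdlib>

    static volatile sig_atomic_t g_generating  = 0;
    static volatile sig_atomic_t g_interrupted = 0;

    static void on_sigint(int /*signo*/) {
        if (g_generating) {
            g_generating = 0;    // first Ctrl-C: stop the current response
        } else if (!g_interrupted) {
            g_interrupted = 1;   // second Ctrl-C: request a clean shutdown
        } else {
            std::_Exit(1);       // handler re-entered before the loop noticed: bail out
        }
    }

    int main() {
        std::signal(SIGINT, on_sigint);
        while (!g_interrupted) {
            g_generating = 1;
            // ... read a line, evaluate, generate tokens, checking both flags ...
            g_generating = 0;
            break;               // placeholder so the sketch terminates on its own
        }
        if (g_interrupted) {
            std::printf("\nInterrupted by user\n");
        }
        return g_interrupted ? 130 : 0;
    }

Returning 130 (128 + SIGINT) mirrors the shell convention the patch adopts.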

diff --git a/examples/llava/mtmd-cli.cpp b/examples/llava/mtmd-cli.cpp
index e80845a2c5..89af7331a1 100644
--- a/examples/llava/mtmd-cli.cpp
+++ b/examples/llava/mtmd-cli.cpp
@@ -24,7 +24,9 @@
 #include 
 #endif
 
-static bool g_is_generating = false;
+// volatile: written from the signal handler, which can interrupt the main thread at any point
+static volatile bool g_is_generating = false;
+static volatile bool g_is_interrupted = false;
 
 /**
  * Please note that this is NOT production-ready code.
@@ -50,8 +52,10 @@ static void sigint_handler(int signo) {
             g_is_generating = false;
         } else {
             console::cleanup();
-            LOG("\nInterrupted by user\n");
-            _exit(130);
+            if (g_is_interrupted) {
+                _exit(1);
+            }
+            g_is_interrupted = true;
         }
     }
 }
@@ -167,7 +171,7 @@ struct decode_embd_batch {
 static int generate_response(mtmd_cli_context & ctx, common_sampler * smpl, int n_predict) {
     llama_tokens generated_tokens;
     for (int i = 0; i < n_predict; i++) {
-        if (i > n_predict || !g_is_generating) {
+        if (i > n_predict || !g_is_generating || g_is_interrupted) {
             printf("\n");
             break;
         }
@@ -184,6 +188,11 @@ static int generate_response(mtmd_cli_context & ctx, common_sampler * smpl, int
         printf("%s", common_token_to_piece(ctx.lctx, token_id).c_str());
         fflush(stdout);
 
+        if (g_is_interrupted) {
+            printf("\n");
+            break;
+        }
+
         // eval the token
         common_batch_clear(ctx.batch);
         common_batch_add(ctx.batch, token_id, ctx.n_past++, {0}, true);
@@ -219,6 +228,9 @@ static int eval_message(mtmd_cli_context & ctx, common_chat_msg & msg, std::vect
     text.add_special   = add_bos;
     text.parse_special = true;
     mtmd_input_chunks chunks;
+
+    if (g_is_interrupted) return 0;
+
     int32_t res = mtmd_tokenize(ctx.ctx_vision.get(), chunks, text, bitmaps);
     if (res != 0) {
         LOG_ERR("Unable to tokenize prompt, res = %d\n", res);
@@ -276,6 +288,8 @@ int main(int argc, char ** argv) {
 #endif
     }
 
+    if (g_is_interrupted) return 130;
+
     if (is_single_turn) {
         g_is_generating = true;
         if (params.prompt.find("<__image__>") == std::string::npos) {
@@ -287,7 +301,7 @@ int main(int argc, char ** argv) {
         if (eval_message(ctx, msg, params.image, true)) {
             return 1;
         }
-        if (generate_response(ctx, smpl, n_predict)) {
+        if (!g_is_interrupted && generate_response(ctx, smpl, n_predict)) {
             return 1;
         }
 
@@ -302,12 +316,13 @@ int main(int argc, char ** argv) {
         std::vector<std::string> images_fname;
         std::string content;
 
-        while (true) {
+        while (!g_is_interrupted) {
             g_is_generating = false;
             LOG("\n> ");
             console::set_display(console::user_input);
             std::string line;
             console::readline(line, false);
+            if (g_is_interrupted) break;
             console::set_display(console::reset);
             line = string_strip(line);
             if (line.empty()) {
@@ -335,6 +350,7 @@ int main(int argc, char ** argv) {
             msg.role = "user";
             msg.content = content;
             int ret = eval_message(ctx, msg, images_fname, is_first_msg);
+            if (g_is_interrupted) break;
             if (ret == 2) {
                 // non-fatal error
                 images_fname.clear();
@@ -352,6 +368,7 @@ int main(int argc, char ** argv) {
             is_first_msg = false;
         }
     }
+    if (g_is_interrupted) LOG("\nInterrupted by user\n");
     llama_perf_context_print(ctx.lctx);
-    return 0;
+    return g_is_interrupted ? 130 : 0;
 }

From b3b6d862cfdf190e1b9ad961639a25f5ebc0c7e3 Mon Sep 17 00:00:00 2001
From: Eve <139727413+netrunnereve@users.noreply.github.com>
Date: Thu, 24 Apr 2025 07:18:33 +0000
Subject: [PATCH 34/39] vulkan: matmul gcn tuning (#13016)

* tune matmul for gcn

* this one is more power efficient

* Update ggml/src/ggml-vulkan/ggml-vulkan.cpp

Co-authored-by: 0cc4m 

* disable this tune for the proprietary driver

---------

Co-authored-by: 0cc4m 
---
 ggml/src/ggml-vulkan/ggml-vulkan.cpp | 7 +++++++
 1 file changed, 7 insertions(+)
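
For context, the driver ID that gates this tune comes from VK_KHR_driver_properties (core in Vulkan 1.2). A hedged, standalone probe — assuming vulkan.hpp and a 1.2 loader, and not part of the patch — might look like:

    #include <vulkan/vulkan.hpp>
    #include <cstdio>

    int main() {
        vk::ApplicationInfo app("driver-id-probe", 1, nullptr, 0, VK_API_VERSION_1_2);
        vk::Instance instance = vk::createInstance(vk::InstanceCreateInfo({}, &app));
        for (vk::PhysicalDevice dev : instance.enumeratePhysicalDevices()) {
            vk::PhysicalDeviceDriverProperties driver_props{};
            vk::PhysicalDeviceProperties2 props2{};
            props2.pNext = &driver_props;   // chain the driver-properties struct
            dev.getProperties2(&props2);
            std::printf("%s: driverID = %d\n",
                        props2.properties.deviceName.data(),
                        static_cast<int>(driver_props.driverID));
        }
        instance.destroy();
        return 0;
    }

A device reporting vk::DriverId::eAmdProprietary would be skipped by the check above.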

diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index 39f3cd343a..c0bdb9e17a 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -246,6 +246,7 @@ struct vk_device_struct {
     bool pipeline_robustness;
     vk::Device device;
     uint32_t vendor_id;
+    vk::DriverId driver_id;
     vk_device_architecture architecture;
     vk_queue compute_queue;
     vk_queue transfer_queue;
@@ -1740,6 +1741,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
         m_warptile_mmq_int = { 128,  64,  64, 32, subgroup_size_8,     32, 2, 2, 2, 1, subgroup_size_8 };
         s_warptile_mmq_int = { subgroup_size_32, 32, 32, 32, 32,       32, 2, 2, 1, 1, subgroup_size_8 };
 
+        // chip specific tuning
+        if ((device->architecture == AMD_GCN) && (device->driver_id != vk::DriverId::eAmdProprietary)) {
+            m_warptile_mmq = m_warptile_mmq_int = { 256, 64, 64, 32, 16, 16, 2, 2, 2, 1, 16 };
+        }
+
         l_mmq_wg_denoms = l_wg_denoms = {128, 128, 1 };
         m_mmq_wg_denoms = m_wg_denoms = { 64,  64, 1 };
         s_mmq_wg_denoms = s_wg_denoms = { 32,  32, 1 };
@@ -2658,6 +2664,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
         device->physical_device.getProperties2(&props2);
         device->properties = props2.properties;
         device->vendor_id = device->properties.vendorID;
+        device->driver_id = driver_props.driverID;
 
         const char* GGML_VK_FORCE_MAX_ALLOCATION_SIZE = getenv("GGML_VK_FORCE_MAX_ALLOCATION_SIZE");
 

From 7604a7d6b80e78eef8f275fc700d0f64820d672f Mon Sep 17 00:00:00 2001
From: Georgi Gerganov 
Date: Thu, 24 Apr 2025 10:38:30 +0300
Subject: [PATCH 35/39] metal : fix floating-point range of attention scores in
 FA kernels (#13090)

ggml-ci
---
 ggml/src/ggml-metal/ggml-metal.metal | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)
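
Why the sentinel matters: the accumulators here are float, so seeding the running maximum with half-precision's limit (-65504/2) can beat genuinely smaller attention scores. A hedged host-side illustration of one plausible failure mode (plain C++ standing in for the Metal code):

    #include <cfloat>
    #include <cmath>
    #include <cstdio>

    int main() {
        const float score = -1.0e6f;          // e.g. a strongly masked attention logit
        const float m_old = -65504.0f / 2;    // old seed: -__FLT16_MAX__ / 2
        const float m_new = -FLT_MAX / 2;     // new seed: -__FLT_MAX__ / 2

        // online-softmax running-max update: M = max(M, s)
        const float M_old = std::fmax(m_old, score); // stays -32752: the seed wins, wrongly
        const float M_new = std::fmax(m_new, score); // becomes -1e6: the true maximum

        // the weight of the maximal element should be exp(0) = 1
        std::printf("old seed: exp(s - M) = %g\n", std::exp(score - M_old)); // underflows to 0
        std::printf("new seed: exp(s - M) = %g\n", std::exp(score - M_new)); // exactly 1
        return 0;
    }

With the old seed, every weight in such a block can underflow to zero, so the normalizing sum may vanish as well; seeding with -FLT_MAX/2 guarantees the true maximum always wins the comparison.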

diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
index 8d6e99e621..9f4147e939 100644
--- a/ggml/src/ggml-metal/ggml-metal.metal
+++ b/ggml/src/ggml-metal/ggml-metal.metal
@@ -3192,7 +3192,7 @@ kernel void kernel_flash_attn_ext(
 
     {
         float S[Q] = { [0 ... Q-1] = 0.0f };
-        float M[Q] = { [0 ... Q-1] = -__FLT16_MAX__/2 };
+        float M[Q] = { [0 ... Q-1] = -__FLT_MAX__/2 };
 
         // thread indices inside the simdgroup
         // TODO: see if we can utilize quad-group functions for better performance
@@ -3452,7 +3452,7 @@ kernel void kernel_flash_attn_ext(
     // reduce the warps sequentially
     for (ushort sg = 1; sg < nsg; ++sg) {
         float S = { 0.0f };
-        float M = { -__FLT16_MAX__/2 };
+        float M = { -__FLT_MAX__/2 };
 
         threadgroup_barrier(mem_flags::mem_threadgroup);
 
@@ -3699,7 +3699,7 @@ kernel void kernel_flash_attn_ext_vec(
 
     {
         float S = 0.0f;
-        float M = -__FLT16_MAX__/2;
+        float M = -__FLT_MAX__/2;
 
         // thread indices inside the simdgroup
         const short tx = tiisg%NL;

From 80982e815e67bae2442237f4e11466f44c9a2988 Mon Sep 17 00:00:00 2001
From: Xuan-Son Nguyen 
Date: Thu, 24 Apr 2025 12:14:13 +0200
Subject: [PATCH 36/39] arg : clean up handling --mmproj with -hf (#13082)

* arg : clean up handling --mmproj with -hf

* rm change about no_mmproj

* Revert "rm change about no_mmproj"

This reverts commit 2cac8e0efb629d66c612f137e75d562f94bb9e6c.

* handle no_mmproj explicitly

* skip download mmproj on examples not using it
---
 common/arg.cpp              | 68 +++++++++++++++++++++++++++----------
 common/common.h             |  1 +
 examples/llava/mtmd-cli.cpp |  1 +
 3 files changed, 53 insertions(+), 17 deletions(-)
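
A condensed sketch of the new control flow (hedged: illustrative enum, struct, and function names, not the real arg.cpp): the resolver reports whether the repo ships a projector, and the caller decides what to do with it, honoring --no-mmproj and fetching only for examples that use it.

    #include <initializer_list>
    #include <string>

    enum example { EX_MAIN, EX_LLAVA };
    static std::initializer_list<example> mmproj_examples = { EX_LLAVA };

    struct resolve_result {
        bool        found_mmproj = false;
        std::string mmproj_file;           // set when the HF repo ships a projector
    };

    static void apply_mmproj(std::string & mmproj, bool no_mmproj, example ex,
                             const resolve_result & res) {
        if (no_mmproj) {
            mmproj.clear();                       // explicit opt-out wins
        } else if (res.found_mmproj && mmproj.empty()) {
            mmproj = res.mmproj_file;             // adopt the auto-detected file
        }
        for (example e : mmproj_examples) {
            if (e == ex) { /* download mmproj here */ break; }
        }
    }

    int main() {
        std::string mmproj;
        resolve_result res{true, "mmproj-model.gguf"};
        apply_mmproj(mmproj, /*no_mmproj=*/false, EX_LLAVA, res);
        return mmproj.empty() ? 1 : 0;
    }

In effect, `-hf user/repo` alone now suffices for llama-mtmd-cli whenever the repo ships an mmproj file.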

diff --git a/common/arg.cpp b/common/arg.cpp
index 1cfd0168d9..85ba411146 100644
--- a/common/arg.cpp
+++ b/common/arg.cpp
@@ -38,6 +38,11 @@
 
 using json = nlohmann::ordered_json;
 
+std::initializer_list<enum llama_example> mmproj_examples = {
+    LLAMA_EXAMPLE_LLAVA,
+    // TODO: add LLAMA_EXAMPLE_SERVER when it's ready
+};
+
 common_arg & common_arg::set_examples(std::initializer_list<enum llama_example> examples) {
     this->examples = std::move(examples);
     return *this;
@@ -641,11 +646,16 @@ static struct common_hf_file_res common_get_hf_file(const std::string &, const s
 // utils
 //
 
-static void common_params_handle_model(
+struct handle_model_result {
+    bool found_mmproj = false;
+    common_params_model mmproj;
+};
+
+static handle_model_result common_params_handle_model(
         struct common_params_model & model,
         const std::string & bearer_token,
-        const std::string & model_path_default,
-        bool is_mmproj = false) { // TODO: move is_mmproj to an enum when we have more files?
+        const std::string & model_path_default) {
+    handle_model_result result;
     // handle pre-fill default model path and url based on hf_repo and hf_file
     {
         if (!model.hf_repo.empty()) {
@@ -657,7 +667,12 @@ static void common_params_handle_model(
                         exit(1); // built without CURL, error message already printed
                     }
                     model.hf_repo = auto_detected.repo;
-                    model.hf_file = is_mmproj ? auto_detected.mmprojFile : auto_detected.ggufFile;
+                    model.hf_file = auto_detected.ggufFile;
+                    if (!auto_detected.mmprojFile.empty()) {
+                        result.found_mmproj   = true;
+                        result.mmproj.hf_repo = model.hf_repo;
+                        result.mmproj.hf_file = auto_detected.mmprojFile;
+                    }
                 } else {
                     model.hf_file = model.path;
                 }
@@ -694,6 +709,8 @@ static void common_params_handle_model(
             exit(1);
         }
     }
+
+    return result;
 }
 
 const std::vector<ggml_type> kv_cache_types = {
@@ -827,16 +844,25 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
         throw std::invalid_argument("error: --prompt-cache-all not supported in interactive mode yet\n");
     }
 
-    common_params_handle_model(params.model,             params.hf_token, DEFAULT_MODEL_PATH);
-    common_params_handle_model(params.speculative.model, params.hf_token, "");
-    common_params_handle_model(params.vocoder.model,     params.hf_token, "");
-
-    // allow --mmproj to be set from -hf
-    // assuming that mmproj is always in the same repo as text model
-    if (!params.model.hf_repo.empty() && ctx_arg.ex == LLAMA_EXAMPLE_LLAVA) {
-        params.mmproj.hf_repo = params.model.hf_repo;
+    // handle model and download
+    {
+        auto res = common_params_handle_model(params.model, params.hf_token, DEFAULT_MODEL_PATH);
+        if (params.no_mmproj) {
+            params.mmproj = {};
+        } else if (res.found_mmproj && params.mmproj.path.empty() && params.mmproj.url.empty()) {
+            // optionally, handle mmproj model when -hf is specified
+            params.mmproj = res.mmproj;
+        }
+        // only download mmproj if the current example is using it
+        for (auto & ex : mmproj_examples) {
+            if (ctx_arg.ex == ex) {
+                common_params_handle_model(params.mmproj,    params.hf_token, "");
+                break;
+            }
+        }
+        common_params_handle_model(params.speculative.model, params.hf_token, "");
+        common_params_handle_model(params.vocoder.model,     params.hf_token, "");
     }
-    common_params_handle_model(params.mmproj,            params.hf_token, "", true);
 
     if (params.escape) {
         string_process_escapes(params.prompt);
@@ -2095,18 +2121,25 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
     ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_NO_CONT_BATCHING"));
     add_opt(common_arg(
         {"--mmproj"}, "FILE",
-        "path to a multimodal projector file for LLaVA. see examples/llava/README.md",
+        "path to a multimodal projector file. see examples/llava/README.md",
         [](common_params & params, const std::string & value) {
             params.mmproj.path = value;
         }
-    ).set_examples({LLAMA_EXAMPLE_LLAVA}));
+    ).set_examples(mmproj_examples));
     add_opt(common_arg(
         {"--mmproj-url"}, "URL",
-        "URL to a multimodal projector file for LLaVA. see examples/llava/README.md",
+        "URL to a multimodal projector file. see examples/llava/README.md",
         [](common_params & params, const std::string & value) {
             params.mmproj.url = value;
         }
-    ).set_examples({LLAMA_EXAMPLE_LLAVA}));
+    ).set_examples(mmproj_examples));
+    add_opt(common_arg(
+        {"--no-mmproj"},
+        "explicitly disable multimodal projector, useful when using -hf",
+        [](common_params & params) {
+            params.no_mmproj = true;
+        }
+    ).set_examples(mmproj_examples));
     add_opt(common_arg(
         {"--image"}, "FILE",
         "path to an image file. use with multimodal models. Specify multiple times for batching",
@@ -2381,6 +2414,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
     add_opt(common_arg(
         {"-hf", "-hfr", "--hf-repo"}, "/[:quant]",
         "Hugging Face model repository; quant is optional, case-insensitive, default to Q4_K_M, or falls back to the first file in the repo if Q4_K_M doesn't exist.\n"
+        "mmproj is also downloaded automatically if available. to disable, add --no-mmproj\n"
         "example: unsloth/phi-4-GGUF:q4_k_m\n"
         "(default: unused)",
         [](common_params & params, const std::string & value) {
diff --git a/common/common.h b/common/common.h
index e6eaa8e80c..70d3ef8f27 100644
--- a/common/common.h
+++ b/common/common.h
@@ -342,6 +342,7 @@ struct common_params {
 
     // multimodal models (see examples/llava)
     struct common_params_model mmproj;
+    bool no_mmproj = false;         // explicitly disable multimodal model
     std::vector<std::string> image; // path to image file(s)
 
     // embedding
diff --git a/examples/llava/mtmd-cli.cpp b/examples/llava/mtmd-cli.cpp
index 89af7331a1..1937376057 100644
--- a/examples/llava/mtmd-cli.cpp
+++ b/examples/llava/mtmd-cli.cpp
@@ -261,6 +261,7 @@ int main(int argc, char ** argv) {
 
     if (params.mmproj.path.empty()) {
         show_additional_info(argc, argv);
+        LOG_ERR("ERR: Missing --mmproj argument\n");
         return 1;
     }
 

From 7c727fbe39150fbe8381f4fa43fed08719ebebe6 Mon Sep 17 00:00:00 2001
From: Xuan-Son Nguyen 
Date: Thu, 24 Apr 2025 14:04:14 +0200
Subject: [PATCH 37/39] arg : add --no-mmproj-offload (#13093)

* arg : add --no-mmproj-offload

* Update common/arg.cpp
---
 common/arg.cpp              | 7 +++++++
 common/common.h             | 1 +
 examples/llava/mtmd-cli.cpp | 7 ++++---
 3 files changed, 12 insertions(+), 3 deletions(-)
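
Combined with the previous patch, a typical invocation (repo name illustrative) is `llama-mtmd-cli -hf user/repo --no-mmproj-offload`: the projector is still auto-downloaded but stays on the CPU, and, per the hunk below, a verbosity level above 0 now also switches the vision context to debug logging.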

diff --git a/common/arg.cpp b/common/arg.cpp
index 85ba411146..9cbf985710 100644
--- a/common/arg.cpp
+++ b/common/arg.cpp
@@ -2140,6 +2140,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.no_mmproj = true;
         }
     ).set_examples(mmproj_examples));
+    add_opt(common_arg(
+        {"--no-mmproj-offload"},
+        "do not offload multimodal projector to GPU",
+        [](common_params & params) {
+            params.mmproj_use_gpu = false;
+        }
+    ).set_examples(mmproj_examples));
     add_opt(common_arg(
         {"--image"}, "FILE",
         "path to an image file. use with multimodal models. Specify multiple times for batching",
diff --git a/common/common.h b/common/common.h
index 70d3ef8f27..0a9dc0599f 100644
--- a/common/common.h
+++ b/common/common.h
@@ -342,6 +342,7 @@ struct common_params {
 
     // multimodal models (see examples/llava)
     struct common_params_model mmproj;
+    bool mmproj_use_gpu = true;     // use GPU for multimodal model
     bool no_mmproj = false;         // explicitly disable multimodal model
     std::vector<std::string> image; // path to image file(s)
 
diff --git a/examples/llava/mtmd-cli.cpp b/examples/llava/mtmd-cli.cpp
index 1937376057..250e8c9a9e 100644
--- a/examples/llava/mtmd-cli.cpp
+++ b/examples/llava/mtmd-cli.cpp
@@ -40,7 +40,8 @@ static void show_additional_info(int /*argc*/, char ** argv) {
         "Usage: %s [options] -m  --mmproj  --image  -p \n\n"
         "  -m and --mmproj are required\n"
         "  -hf user/repo can replace both -m and --mmproj in most cases\n"
-        "  --image and -p are optional, if NOT provided, the CLI will run in chat mode\n",
+        "  --image and -p are optional, if NOT provided, the CLI will run in chat mode\n"
+        "  to disable using GPU for mmproj model, add --no-mmproj-offload\n",
         argv[0]
     );
 }
@@ -112,10 +113,10 @@ struct mtmd_cli_context {
     void init_vision_context(common_params & params) {
         const char * clip_path = params.mmproj.path.c_str();
         ctx_vision.reset(mtmd_init_from_file(clip_path, model, mtmd_context_params{
-            /* use_gpu */   true,
+            /* use_gpu */   params.mmproj_use_gpu,
             /* timings */   true,
             /* n_threads */ params.cpuparams.n_threads,
-            /* verbosity */ GGML_LOG_LEVEL_INFO,
+            /* verbosity */ params.verbosity > 0 ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_INFO,
         }));
         if (!ctx_vision.get()) {
             LOG_ERR("Failed to load vision model from %s\n", clip_path);

From 572b3141d343d7f947bf53b57513016e90db5680 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov 
Date: Thu, 24 Apr 2025 15:44:05 +0300
Subject: [PATCH 38/39] clang-tidy : disable warning about missing math
 parenthesis (#13091)

---
 .clang-tidy | 1 +
 1 file changed, 1 insertion(+)
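
For reference, `readability-math-missing-parentheses` flags mixed-precedence arithmetic written without explicit grouping, which ggml/llama code uses pervasively and deliberately. A minimal example of what it would warn on:

    int scale(int a, int b, int c) {
        return a + b * c;   // tidy suggests writing a + (b * c)
    }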

diff --git a/.clang-tidy b/.clang-tidy
index 310c3d182c..5bc63bc6e2 100644
--- a/.clang-tidy
+++ b/.clang-tidy
@@ -13,6 +13,7 @@ Checks: >
     -readability-magic-numbers,
     -readability-uppercase-literal-suffix,
     -readability-simplify-boolean-expr,
+    -readability-math-missing-parentheses,
     clang-analyzer-*,
     -clang-analyzer-security.insecureAPI.DeprecatedOrUnsafeBufferHandling,
     performance-*,

From 13b4548877326fdabee3e831b8cfd65d9844383c Mon Sep 17 00:00:00 2001
From: Georgi Gerganov 
Date: Thu, 24 Apr 2025 16:00:10 +0300
Subject: [PATCH 39/39] cmake : do not include ./src as public for libllama
 (#13062)

* cmake : do not include ./src as public for libllama

ggml-ci

* cmake : rework tests

ggml-ci

* llguidance : remove unicode include

ggml-ci

* cmake : make c++17 private

ggml-ci
---
 common/arg.cpp                                |  2 -
 examples/CMakeLists.txt                       |  9 ---
 examples/gbnf-validator/CMakeLists.txt        |  5 --
 examples/quantize-stats/CMakeLists.txt        |  6 --
 grammars/README.md                            |  2 +-
 src/CMakeLists.txt                            |  5 +-
 tests/CMakeLists.txt                          | 69 +++++++++++--------
 tests/test-chat.cpp                           |  5 +-
 .../test-gbnf-validator.cpp                   |  4 +-
 tests/test-grammar-integration.cpp            |  5 +-
 tests/test-grammar-llguidance.cpp             |  3 +-
 tests/test-grammar-parser.cpp                 |  4 +-
 tests/test-json-schema-to-grammar.cpp         |  2 +-
 tests/test-llama-grammar.cpp                  |  3 +-
 .../test-quantize-stats.cpp                   |  3 +-
 tests/test-tokenizer-1-bpe.cpp                |  3 +-
 tests/test-tokenizer-1-spm.cpp                |  3 +-
 17 files changed, 64 insertions(+), 69 deletions(-)
 delete mode 100644 examples/gbnf-validator/CMakeLists.txt
 delete mode 100644 examples/quantize-stats/CMakeLists.txt
 rename examples/gbnf-validator/gbnf-validator.cpp => tests/test-gbnf-validator.cpp (98%)
 rename examples/quantize-stats/quantize-stats.cpp => tests/test-quantize-stats.cpp (99%)
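
The practical effect for downstream consumers, sketched below (hedged: a toy consumer, not shipped code), is that only the headers under include/ remain visible through the `llama` target; internal headers must now be reached by explicit relative paths, exactly as the relocated tests below do.

    #include "llama.h"   // public header: still on the consumer's include path

    int main() {
        // #include "llama-grammar.h" no longer resolves here: src/ is now PRIVATE.
        llama_backend_init();
        llama_backend_free();
        return 0;
    }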

diff --git a/common/arg.cpp b/common/arg.cpp
index 9cbf985710..0657553e4e 100644
--- a/common/arg.cpp
+++ b/common/arg.cpp
@@ -994,7 +994,6 @@ static void common_params_print_completion(common_params_context & ctx_arg) {
         "llama-embedding",
         "llama-eval-callback",
         "llama-export-lora",
-        "llama-gbnf-validator",
         "llama-gen-docs",
         "llama-gguf",
         "llama-gguf-hash",
@@ -1014,7 +1013,6 @@ static void common_params_print_completion(common_params_context & ctx_arg) {
         "llama-perplexity",
         "llama-q8dot",
         "llama-quantize",
-        "llama-quantize-stats",
         "llama-qwen2vl-cli",
         "llama-retrieval",
         "llama-run",
diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt
index 66cfab2c3b..37476f9043 100644
--- a/examples/CMakeLists.txt
+++ b/examples/CMakeLists.txt
@@ -21,11 +21,6 @@ else()
     add_subdirectory(embedding)
     add_subdirectory(eval-callback)
 
-    if (NOT WIN32)
-        # disabled on Windows because it uses internal functions not exported with LLAMA_API
-        add_subdirectory(gbnf-validator)
-    endif()
-
     add_subdirectory(gguf-hash)
     add_subdirectory(gguf-split)
     add_subdirectory(gguf)
@@ -58,10 +53,6 @@ else()
         add_subdirectory(convert-llama2c-to-ggml)
         add_subdirectory(cvector-generator)
         add_subdirectory(export-lora)
-        if (NOT WIN32)
-            # disabled on Windows because it uses internal functions not exported with LLAMA_API
-            add_subdirectory(quantize-stats)
-        endif()
         add_subdirectory(llava)
         if (GGML_RPC)
             add_subdirectory(rpc)
diff --git a/examples/gbnf-validator/CMakeLists.txt b/examples/gbnf-validator/CMakeLists.txt
deleted file mode 100644
index d2cb524c0a..0000000000
--- a/examples/gbnf-validator/CMakeLists.txt
+++ /dev/null
@@ -1,5 +0,0 @@
-set(TARGET llama-gbnf-validator)
-add_executable(${TARGET} gbnf-validator.cpp)
-install(TARGETS ${TARGET} RUNTIME)
-target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
-target_compile_features(${TARGET} PRIVATE cxx_std_17)
diff --git a/examples/quantize-stats/CMakeLists.txt b/examples/quantize-stats/CMakeLists.txt
deleted file mode 100644
index 9a3a0d3cd2..0000000000
--- a/examples/quantize-stats/CMakeLists.txt
+++ /dev/null
@@ -1,6 +0,0 @@
-set(TARGET llama-quantize-stats)
-add_executable(${TARGET} quantize-stats.cpp)
-install(TARGETS ${TARGET} RUNTIME)
-target_link_libraries(${TARGET} PRIVATE llama build_info ${CMAKE_THREAD_LIBS_INIT})
-target_include_directories(${TARGET} PRIVATE ../../common)
-target_compile_features(${TARGET} PRIVATE cxx_std_17)
diff --git a/grammars/README.md b/grammars/README.md
index 935213f5c1..5aa12acc1b 100644
--- a/grammars/README.md
+++ b/grammars/README.md
@@ -112,7 +112,7 @@ You can use GBNF grammars:
 
 - In [llama-server](../examples/server)'s completion endpoints, passed as the `grammar` body field
 - In [llama-cli](../examples/main), passed as the `--grammar` & `--grammar-file` flags
-- With [llama-gbnf-validator](../examples/gbnf-validator) tool, to test them against strings.
+- With [test-gbnf-validator](../tests/test-gbnf-validator.cpp), to test them against strings.
 
 ## JSON Schemas → GBNF
 
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index 9f7ab13f1e..1cd316b03e 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -32,8 +32,9 @@ add_library(llama
             unicode.h
             )
 
-target_include_directories(llama PUBLIC . ../include)
-target_compile_features   (llama PUBLIC cxx_std_17) # don't bump
+target_include_directories(llama PRIVATE .)
+target_include_directories(llama PUBLIC ../include)
+target_compile_features   (llama PRIVATE cxx_std_17) # don't bump
 
 target_link_libraries(llama PUBLIC ggml)
 
diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt
index 2bb210702a..ae68275251 100644
--- a/tests/CMakeLists.txt
+++ b/tests/CMakeLists.txt
@@ -1,5 +1,17 @@
 llama_add_compile_flags()
 
+function(llama_build source)
+    if (DEFINED LLAMA_TEST_NAME)
+        set(TEST_TARGET ${LLAMA_TEST_NAME})
+    else()
+        get_filename_component(TEST_TARGET ${source} NAME_WE)
+    endif()
+
+    add_executable(${TEST_TARGET} ${source})
+    target_link_libraries(${TEST_TARGET} PRIVATE common)
+    install(TARGETS ${TEST_TARGET} RUNTIME)
+endfunction()
+
 function(llama_test target)
     include(CMakeParseArguments)
     set(options)
@@ -36,7 +48,7 @@ endfunction()
 # - LABEL: label for the test (defaults to main)
 # - ARGS: arguments to pass to the test executable
 # - WORKING_DIRECTORY
-function(llama_target_and_test source)
+function(llama_build_and_test source)
     include(CMakeParseArguments)
     set(options)
     set(oneValueArgs NAME LABEL WORKING_DIRECTORY)
@@ -58,6 +70,7 @@ function(llama_target_and_test source)
     add_executable(${TEST_TARGET} ${source} get-model.cpp)
     install(TARGETS ${TEST_TARGET} RUNTIME)
     target_link_libraries(${TEST_TARGET} PRIVATE common)
+
     add_test(
         NAME ${TEST_TARGET}
         WORKING_DIRECTORY ${LLAMA_TEST_WORKING_DIRECTORY}
@@ -68,9 +81,7 @@ function(llama_target_and_test source)
 endfunction()
 
 # build test-tokenizer-0 target once and add many tests
-add_executable(test-tokenizer-0 test-tokenizer-0.cpp)
-target_link_libraries(test-tokenizer-0 PRIVATE common)
-install(TARGETS test-tokenizer-0 RUNTIME)
+llama_build(test-tokenizer-0.cpp)
 
 llama_test(test-tokenizer-0 NAME test-tokenizer-0-bert-bge          ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-bert-bge.gguf)
 llama_test(test-tokenizer-0 NAME test-tokenizer-0-command-r         ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-command-r.gguf)
@@ -87,27 +98,27 @@ llama_test(test-tokenizer-0 NAME test-tokenizer-0-refact            ARGS ${CMAKE
 llama_test(test-tokenizer-0 NAME test-tokenizer-0-starcoder         ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-starcoder.gguf)
 
 if (LLAMA_LLGUIDANCE)
-    llama_target_and_test(test-grammar-llguidance.cpp ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-llama-bpe.gguf)
+    llama_build_and_test(test-grammar-llguidance.cpp ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-llama-bpe.gguf)
 endif ()
 
 if (NOT WIN32)
     # these tests are disabled on Windows because they use internal functions not exported with LLAMA_API
-    llama_target_and_test(test-sampling.cpp)
-    llama_target_and_test(test-grammar-parser.cpp)
-    llama_target_and_test(test-grammar-integration.cpp)
-    llama_target_and_test(test-llama-grammar.cpp)
-    llama_target_and_test(test-chat.cpp)
+    llama_build_and_test(test-sampling.cpp)
+    llama_build_and_test(test-grammar-parser.cpp)
+    llama_build_and_test(test-grammar-integration.cpp)
+    llama_build_and_test(test-llama-grammar.cpp)
+    llama_build_and_test(test-chat.cpp)
     # TODO: disabled on loongarch64 because the ggml-ci node lacks Python 3.8
     if (NOT ${CMAKE_SYSTEM_PROCESSOR} MATCHES "loongarch64")
-        llama_target_and_test(test-json-schema-to-grammar.cpp   WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/..)
+        llama_build_and_test(test-json-schema-to-grammar.cpp   WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/..)
         target_include_directories(test-json-schema-to-grammar PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../examples/server)
     endif()
 
+    llama_build(test-quantize-stats.cpp)
+    llama_build(test-gbnf-validator.cpp)
 
     # build test-tokenizer-1-bpe target once and add many tests
-    add_executable(test-tokenizer-1-bpe test-tokenizer-1-bpe.cpp)
-    target_link_libraries(test-tokenizer-1-bpe PRIVATE common)
-    install(TARGETS test-tokenizer-1-bpe RUNTIME)
+    llama_build(test-tokenizer-1-bpe.cpp)
 
     # TODO: disabled due to slowness
     #llama_test(test-tokenizer-1-bpe NAME test-tokenizer-1-aquila    ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-aquila.gguf)
@@ -120,37 +131,35 @@ if (NOT WIN32)
     #llama_test(test-tokenizer-1-bpe NAME test-tokenizer-1-starcoder ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-starcoder.gguf)
 
     # build test-tokenizer-1-spm target once and add many tests
-    add_executable(test-tokenizer-1-spm test-tokenizer-1-spm.cpp)
-    target_link_libraries(test-tokenizer-1-spm PRIVATE common)
-    install(TARGETS test-tokenizer-1-spm RUNTIME)
+    llama_build(test-tokenizer-1-spm.cpp)
 
     llama_test(test-tokenizer-1-spm  NAME test-tokenizer-1-llama-spm ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-llama-spm.gguf)
     #llama_test(test-tokenizer-1-spm  NAME test-tokenizer-1-baichuan  ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-baichuan.gguf)
 
-    # llama_target_and_test(test-double-float.cpp) # SLOW
+    # llama_build_and_test(test-double-float.cpp) # SLOW
 endif()
 
-llama_target_and_test(test-log.cpp)
-llama_target_and_test(test-chat-template.cpp)
+llama_build_and_test(test-log.cpp)
+llama_build_and_test(test-chat-template.cpp)
 
 # this fails on windows (github hosted runner) due to curl DLL not found (exit code 0xc0000135)
 if (NOT WIN32)
-    llama_target_and_test(test-arg-parser.cpp)
+    llama_build_and_test(test-arg-parser.cpp)
 endif()
 
-# llama_target_and_test(test-opt.cpp) # SLOW
-llama_target_and_test(test-gguf.cpp)
-llama_target_and_test(test-backend-ops.cpp)
+# llama_build_and_test(test-opt.cpp) # SLOW
+llama_build_and_test(test-gguf.cpp)
+llama_build_and_test(test-backend-ops.cpp)
 
-llama_target_and_test(test-model-load-cancel.cpp  LABEL "model")
-llama_target_and_test(test-autorelease.cpp        LABEL "model")
+llama_build_and_test(test-model-load-cancel.cpp  LABEL "model")
+llama_build_and_test(test-autorelease.cpp        LABEL "model")
 
 if (NOT GGML_BACKEND_DL)
     # these tests use the backends directly and cannot be built with dynamic loading
-    llama_target_and_test(test-barrier.cpp)
-    llama_target_and_test(test-quantize-fns.cpp)
-    llama_target_and_test(test-quantize-perf.cpp)
-    llama_target_and_test(test-rope.cpp)
+    llama_build_and_test(test-barrier.cpp)
+    llama_build_and_test(test-quantize-fns.cpp)
+    llama_build_and_test(test-quantize-perf.cpp)
+    llama_build_and_test(test-rope.cpp)
 endif()
 
 
diff --git a/tests/test-chat.cpp b/tests/test-chat.cpp
index a0bf6affe5..fa7aed82df 100644
--- a/tests/test-chat.cpp
+++ b/tests/test-chat.cpp
@@ -11,8 +11,9 @@
 #include 
 
 #include "chat.h"
-#include "llama-grammar.h"
-#include "unicode.h"
+
+#include "../src/unicode.h"
+#include "../src/llama-grammar.h"
 
 using json = nlohmann::ordered_json;
 
diff --git a/examples/gbnf-validator/gbnf-validator.cpp b/tests/test-gbnf-validator.cpp
similarity index 98%
rename from examples/gbnf-validator/gbnf-validator.cpp
rename to tests/test-gbnf-validator.cpp
index a610e6a0b1..6547eec32f 100644
--- a/examples/gbnf-validator/gbnf-validator.cpp
+++ b/tests/test-gbnf-validator.cpp
@@ -1,5 +1,5 @@
-#include "unicode.h"
-#include "llama-grammar.h"
+#include "../src/unicode.h"
+#include "../src/llama-grammar.h"
 
 #include 
 #include 
diff --git a/tests/test-grammar-integration.cpp b/tests/test-grammar-integration.cpp
index 8906086489..8988c347e3 100644
--- a/tests/test-grammar-integration.cpp
+++ b/tests/test-grammar-integration.cpp
@@ -2,10 +2,11 @@
 #undef NDEBUG
 #endif
 
-#include "unicode.h"
-#include "llama-grammar.h"
 #include "json-schema-to-grammar.h"
 
+#include "../src/unicode.h"
+#include "../src/llama-grammar.h"
+
 #include 
 #include 
 #include 
diff --git a/tests/test-grammar-llguidance.cpp b/tests/test-grammar-llguidance.cpp
index 3c19220e11..566b039a07 100644
--- a/tests/test-grammar-llguidance.cpp
+++ b/tests/test-grammar-llguidance.cpp
@@ -2,7 +2,6 @@
 #    undef NDEBUG
 #endif
 
-#include "unicode.h"
 #include "sampling.h"
 
 #include 
@@ -84,7 +83,7 @@ static void test(const std::string & test_desc, const std::string & grammar_str,
 
             fprintf(stderr,
                     "\n NOTE: Debug grammar file generated. To analyze this failure in detail, run the following "
-                    "command:     ./llama-gbnf-validator test-grammar-integration.grammar.gbnf "
+                    "command:     ./test-gbnf-validator test-grammar-integration.grammar.gbnf "
                     "test-grammar-integration.string.txt\n\n");
         } else {
             fprintf(stdout, "✅︎\n");
diff --git a/tests/test-grammar-parser.cpp b/tests/test-grammar-parser.cpp
index 259172d999..67821a2d5c 100644
--- a/tests/test-grammar-parser.cpp
+++ b/tests/test-grammar-parser.cpp
@@ -3,7 +3,9 @@
 #endif
 
 #include "llama.h"
-#include "llama-grammar.h"
+
+// TODO: should not include libllama sources
+#include "../src/llama-grammar.h"
 
 #include 
 
diff --git a/tests/test-json-schema-to-grammar.cpp b/tests/test-json-schema-to-grammar.cpp
index 4d78e91426..e35134f3cb 100755
--- a/tests/test-json-schema-to-grammar.cpp
+++ b/tests/test-json-schema-to-grammar.cpp
@@ -4,7 +4,7 @@
 
 #include "json-schema-to-grammar.h"
 
-#include "llama-grammar.h"
+#include "../src/llama-grammar.h"
 
 #include 
 #include 
diff --git a/tests/test-llama-grammar.cpp b/tests/test-llama-grammar.cpp
index e2129206be..cc198f3e3c 100644
--- a/tests/test-llama-grammar.cpp
+++ b/tests/test-llama-grammar.cpp
@@ -3,7 +3,8 @@
 #endif
 
 #include "llama.h"
-#include "llama-grammar.h"
+
+#include "../src/llama-grammar.h"
 
 #include 
 #include 
diff --git a/examples/quantize-stats/quantize-stats.cpp b/tests/test-quantize-stats.cpp
similarity index 99%
rename from examples/quantize-stats/quantize-stats.cpp
rename to tests/test-quantize-stats.cpp
index dd07ab9b37..db01059119 100644
--- a/examples/quantize-stats/quantize-stats.cpp
+++ b/tests/test-quantize-stats.cpp
@@ -1,8 +1,9 @@
 #include "ggml.h"
 #include "llama.h"
-#include "llama-model.h"
 #include "common.h"
 
+#include "../src/llama-model.h"
+
 #include 
 #include 
 #include 
diff --git a/tests/test-tokenizer-1-bpe.cpp b/tests/test-tokenizer-1-bpe.cpp
index 55425d88a7..b183da47f3 100644
--- a/tests/test-tokenizer-1-bpe.cpp
+++ b/tests/test-tokenizer-1-bpe.cpp
@@ -1,8 +1,9 @@
 #include "llama.h"
 #include "common.h"
-#include "unicode.h"
 #include "console.h"
 
+#include "../src/unicode.h"
+
 #include 
 #include 
 #include 
diff --git a/tests/test-tokenizer-1-spm.cpp b/tests/test-tokenizer-1-spm.cpp
index 9e7b77f31e..ba6e94ba8e 100644
--- a/tests/test-tokenizer-1-spm.cpp
+++ b/tests/test-tokenizer-1-spm.cpp
@@ -1,8 +1,9 @@
 #include "llama.h"
 #include "common.h"
-#include "unicode.h"
 #include "console.h"
 
+#include "../src/unicode.h"
+
 #include 
 #include 
 #include