From ccb87fa3ee1961ec915f77cb447706f471dca6a5 Mon Sep 17 00:00:00 2001 From: Gaurav Garg Date: Sun, 22 Mar 2026 14:19:35 +0530 Subject: [PATCH 1/9] [CUDA] Increase number of output elements per-thread block if the K-dimension is small (#20635) * Increase per-thread work if the K-dimension is small With tensor parallelism, the K-dimension of the FFN-down matrices is split, which makes it quite small, especially for MoEs. For example, Qwen3-30B-A3B has a K-dimension of 768, and Qwen3-235B-A22B has a K-dimension of 1536. The current heuristic uses a group of 4 warps irrespective of K-dimension size, leaving some of the threads idle and resulting in poor performance for these matrices. This change increases the number of output elements per block for such cases. * Limit this change to ncols_dst = 1 * tab to space --- ggml/src/ggml-cuda/mmvq.cu | 56 +++++++++++++++++++++++++++----------- 1 file changed, 40 insertions(+), 16 deletions(-) diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu index 632246e43f..024b3d8cf2 100644 --- a/ggml/src/ggml-cuda/mmvq.cu +++ b/ggml/src/ggml-cuda/mmvq.cu @@ -33,7 +33,7 @@ static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type) } } -static constexpr __device__ int get_vdr_mmvq(ggml_type type) { +static constexpr __host__ __device__ int get_vdr_mmvq(ggml_type type) { switch (type) { case GGML_TYPE_Q4_0: return VDR_Q4_0_Q8_1_MMVQ; case GGML_TYPE_Q4_1: return VDR_Q4_1_Q8_1_MMVQ; @@ -173,11 +173,11 @@ static constexpr __host__ __device__ int calc_nwarps(ggml_type type, int ncols_d return 1; } -static constexpr __host__ __device__ int calc_rows_per_block(int ncols_dst, int table_id) { +static constexpr __host__ __device__ int calc_rows_per_block(int ncols_dst, int table_id, bool small_k = false, int nwarps = 1) { if (table_id == MMVQ_PARAMETERS_GENERIC || table_id == MMVQ_PARAMETERS_GCN) { switch (ncols_dst) { case 1: - return 1; + return small_k ? 
nwarps : 1; case 2: case 3: case 4: @@ -193,7 +193,7 @@ static constexpr __host__ __device__ int calc_rows_per_block(int ncols_dst, int return 1; } -template +template __launch_bounds__(calc_nwarps(type, ncols_dst, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1) static __global__ void mul_mat_vec_q( const void * __restrict__ vx, const void * __restrict__ vy, const int32_t * __restrict__ ids, const ggml_cuda_mm_fusion_args_device fusion, float * __restrict__ dst, @@ -208,7 +208,7 @@ static __global__ void mul_mat_vec_q( constexpr int vdr = get_vdr_mmvq(type); constexpr mmvq_parameter_table_id table_id = get_device_table_id(); constexpr int nwarps = calc_nwarps(type, ncols_dst, table_id); - constexpr int rows_per_cuda_block = calc_rows_per_block(ncols_dst, table_id); + constexpr int rows_per_cuda_block = calc_rows_per_block(ncols_dst, table_id, small_k, nwarps); constexpr int warp_size = ggml_cuda_get_physical_warp_size(); constexpr vec_dot_q_cuda_t vec_dot_q_cuda = get_vec_dot_q_cuda(type); @@ -414,14 +414,16 @@ static __global__ void mul_mat_vec_q( template static std::pair calc_launch_params( const int ncols_dst, const int nrows_x, const int nchannels_dst, const int nsamples_or_ntokens, - const int warp_size, const mmvq_parameter_table_id table_id) { - const int64_t nblocks = (nrows_x + calc_rows_per_block(ncols_dst, table_id) - 1) / calc_rows_per_block(ncols_dst, table_id); + const int warp_size, const mmvq_parameter_table_id table_id, const bool small_k = false) { + const int nwarps = calc_nwarps(type, ncols_dst, table_id); + const int rpb = calc_rows_per_block(ncols_dst, table_id, small_k, nwarps); + const int64_t nblocks = (nrows_x + rpb - 1) / rpb; const dim3 block_nums(nblocks, nchannels_dst, nsamples_or_ntokens); - const dim3 block_dims(warp_size, calc_nwarps(type, ncols_dst, table_id), 1); + const dim3 block_dims(warp_size, nwarps, 1); return {block_nums, block_dims}; } -template +template static void mul_mat_vec_q_switch_fusion( const void * vx, const void * vy, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst, const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t stride_row_x, const uint32_t stride_col_y, @@ -434,7 +436,7 @@ static void mul_mat_vec_q_switch_fusion( const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr; if constexpr (c_ncols_dst == 1) { if (has_fusion) { - mul_mat_vec_q<<>> + mul_mat_vec_q<<>> (vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride); @@ -444,7 +446,7 @@ static void mul_mat_vec_q_switch_fusion( GGML_ASSERT(!has_fusion && "fusion only supported for ncols_dst=1"); - mul_mat_vec_q<<>> + mul_mat_vec_q<<>> (vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride); @@ -488,11 +490,33 @@ static void mul_mat_vec_q_switch_ncols_dst( switch (ncols_dst) { case 1: { constexpr int c_ncols_dst = 1; - std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); - mul_mat_vec_q_switch_fusion(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, - channel_ratio_fd, stride_channel_x, stride_channel_y, 
stride_channel_dst, - sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, - dims.first, dims.second, 0, ids_stride, stream); + + // When K is small, increase rows_per_block to match nwarps so each warp has more work to do + // Trigger when the full thread block covers all K blocks in a single loop iteration and few threads remain idle. + constexpr int qk = ggml_cuda_type_traits::qk; + constexpr int qi = ggml_cuda_type_traits::qi; + constexpr int vdr = get_vdr_mmvq(type); + const int blocks_per_row_x = ncols_x / qk; + const int blocks_per_iter_1warp = vdr * warp_size / qi; + const int nwarps = calc_nwarps(type, c_ncols_dst, table_id); + const bool use_small_k = nwarps > 1 && blocks_per_row_x < nwarps * blocks_per_iter_1warp; + if (use_small_k) { + std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, + warp_size, table_id, true); + mul_mat_vec_q_switch_fusion( + vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, + channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, + dims.first, dims.second, 0, ids_stride, stream); + } else { + std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, + warp_size, table_id); + mul_mat_vec_q_switch_fusion( + vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, + channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, + dims.first, dims.second, 0, ids_stride, stream); + } } break; case 2: { constexpr int c_ncols_dst = 2; From db9d8aa428012cc5593e18635d4c3c54095f5138 Mon Sep 17 00:00:00 2001 From: Patrick Buckley Date: Sun, 22 Mar 2026 03:05:51 -0700 Subject: [PATCH 2/9] ggml-cuda: native bf16 flash attention for vec kernel (#20525) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * ggml-cuda: native bf16 flash attention for vec and tile kernels mma kernel still converts bf16 to fp16 before launch, native mma bf16 todo * ggml-cuda: address code owner review feedback reverted tile kernel changes to avoid larger refactor * fix ci failures on turing and hip * fix bf16 vec kernel compile on hip v_dot2 platforms * add comments --------- Co-authored-by: Johannes Gäßler --- ggml/src/ggml-cuda/CMakeLists.txt | 11 ++--- ggml/src/ggml-cuda/convert.cuh | 6 +++ ggml/src/ggml-cuda/fattn-common.cuh | 48 +++++++++++++++++++ ggml/src/ggml-cuda/fattn-vec.cuh | 26 +++++++--- ggml/src/ggml-cuda/fattn.cu | 16 +++++++ .../fattn-vec-instance-bf16-bf16.cu | 7 +++ .../fattn-vec-instance-bf16-f16.cu | 7 +++ .../fattn-vec-instance-bf16-q4_0.cu | 7 +++ .../fattn-vec-instance-bf16-q4_1.cu | 7 +++ .../fattn-vec-instance-bf16-q5_0.cu | 7 +++ .../fattn-vec-instance-bf16-q5_1.cu | 7 +++ .../fattn-vec-instance-bf16-q8_0.cu | 7 +++ .../fattn-vec-instance-f16-bf16.cu | 7 +++ .../fattn-vec-instance-q4_0-bf16.cu | 7 +++ .../fattn-vec-instance-q4_1-bf16.cu | 7 +++ .../fattn-vec-instance-q5_0-bf16.cu | 7 +++ .../fattn-vec-instance-q5_1-bf16.cu | 7 +++ .../fattn-vec-instance-q8_0-bf16.cu | 7 +++ .../template-instances/generate_cu_files.py | 2 +- ggml/src/ggml-hip/CMakeLists.txt | 11 ++--- ggml/src/ggml-musa/CMakeLists.txt | 11 ++--- 21 files changed, 197 insertions(+), 25 deletions(-) create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-bf16.cu create mode 100644 
ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-f16.cu create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_0.cu create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_1.cu create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_0.cu create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_1.cu create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q8_0.cu create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-bf16.cu create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-bf16.cu create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-bf16.cu create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-bf16.cu create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-bf16.cu create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-bf16.cu diff --git a/ggml/src/ggml-cuda/CMakeLists.txt b/ggml/src/ggml-cuda/CMakeLists.txt index 262f88204e..419862101d 100644 --- a/ggml/src/ggml-cuda/CMakeLists.txt +++ b/ggml/src/ggml-cuda/CMakeLists.txt @@ -116,12 +116,11 @@ if (CUDAToolkit_FOUND) list(APPEND GGML_SOURCES_CUDA ${SRCS}) add_compile_definitions(GGML_CUDA_FA_ALL_QUANTS) else() - file(GLOB SRCS "template-instances/fattn-vec*q4_0-q4_0.cu") - list(APPEND GGML_SOURCES_CUDA ${SRCS}) - file(GLOB SRCS "template-instances/fattn-vec*q8_0-q8_0.cu") - list(APPEND GGML_SOURCES_CUDA ${SRCS}) - file(GLOB SRCS "template-instances/fattn-vec*f16-f16.cu") - list(APPEND GGML_SOURCES_CUDA ${SRCS}) + list(APPEND GGML_SOURCES_CUDA + template-instances/fattn-vec-instance-f16-f16.cu + template-instances/fattn-vec-instance-q4_0-q4_0.cu + template-instances/fattn-vec-instance-q8_0-q8_0.cu + template-instances/fattn-vec-instance-bf16-bf16.cu) endif() ggml_add_backend_library(ggml-cuda diff --git a/ggml/src/ggml-cuda/convert.cuh b/ggml/src/ggml-cuda/convert.cuh index 09f9a33f90..b8caeacf09 100644 --- a/ggml/src/ggml-cuda/convert.cuh +++ b/ggml/src/ggml-cuda/convert.cuh @@ -41,6 +41,12 @@ template return __bfloat162float(x); } else if constexpr(std::is_same_v && std::is_same_v) { return __float22half2_rn(x); + } else if constexpr(std::is_same_v && std::is_same_v) { +#if !defined(GGML_USE_HIP) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + return __bfloat1622float2(x); +#else + return make_float2(__bfloat162float(__low2bfloat16(x)), __bfloat162float(__high2bfloat16(x))); +#endif // !defined(GGML_USE_HIP) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 } else if constexpr(std::is_same_v && std::is_same_v) { // bypass compile error on cuda 12.0.1 #ifdef GGML_USE_HIP diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh index e9abdf288c..c59a4db399 100644 --- a/ggml/src/ggml-cuda/fattn-common.cuh +++ b/ggml/src/ggml-cuda/fattn-common.cuh @@ -74,6 +74,37 @@ static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_f16( return sum; } +template +static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_bf16( + const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds_v) { + + const nv_bfloat162 * K_bf16 = (const nv_bfloat162 *) K_c; + GGML_UNUSED(Q_q8); + GGML_UNUSED(Q_ds_v); + + constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes(); + constexpr int cpy_ne = cpy_nb / 4; + + float sum = 0.0f; + +#pragma unroll + for (int k_KQ_0 = 
0; k_KQ_0 < D/2; k_KQ_0 += nthreads*cpy_ne) { + __align__(16) nv_bfloat162 tmp[cpy_ne]; + ggml_cuda_memcpy_1(tmp, K_bf16 + k_KQ_0 + (threadIdx.x % nthreads)*cpy_ne); +#pragma unroll + for (int k_KQ_1 = 0; k_KQ_1 < cpy_ne; ++k_KQ_1) { +#ifdef V_DOT2_F32_F16_AVAILABLE + // FIXME replace macros in vector FA kernel with templating and use FP32 for BF16 + ggml_cuda_mad(sum, ggml_cuda_cast(tmp[k_KQ_1]), __half22float2(((const half2 *) Q_v)[k_KQ_0/nthreads + k_KQ_1])); +#else + ggml_cuda_mad(sum, ggml_cuda_cast(tmp[k_KQ_1]), ((const float2 *) Q_v)[k_KQ_0/nthreads + k_KQ_1]); +#endif // V_DOT2_F32_F16_AVAILABLE + } + } + + return sum; +} + template static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_q4_0( const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) { @@ -321,6 +352,19 @@ static __device__ __forceinline__ void dequantize_V_f16(const void * __restrict_ } } +template +static __device__ __forceinline__ void dequantize_V_bf16(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) { + static_assert(std::is_same_v, "BF16 V dequantization only supports float output"); + static_assert(ne % 2 == 0, "bad ne"); + __align__(16) nv_bfloat162 tmp[ne/2]; + ggml_cuda_memcpy_1(tmp, (const nv_bfloat16 *) vx + i0); + float2 * dst_f2 = (float2 *) dst; +#pragma unroll + for (int l = 0; l < ne/2; ++l) { + dst_f2[l] = ggml_cuda_cast(tmp[l]); + } +} + template static __device__ __forceinline__ void dequantize_V_q4_0(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) { const block_q4_0 * x = (const block_q4_0 *) vx; @@ -547,6 +591,8 @@ constexpr __device__ vec_dot_KQ_t get_vec_dot_KQ() { return vec_dot_fattn_vec_KQ_q5_1; } else if constexpr (type_K == GGML_TYPE_Q8_0) { return vec_dot_fattn_vec_KQ_q8_0; + } else if constexpr (type_K == GGML_TYPE_BF16) { + return vec_dot_fattn_vec_KQ_bf16; } else { static_assert(type_K == -1, "bad type"); return nullptr; @@ -567,6 +613,8 @@ constexpr __device__ dequantize_V_t get_dequantize_V() { return dequantize_V_q5_1; } else if constexpr (type_V == GGML_TYPE_Q8_0) { return dequantize_V_q8_0; + } else if constexpr (type_V == GGML_TYPE_BF16) { + return dequantize_V_bf16; } else { static_assert(type_V == -1, "bad type"); return nullptr; diff --git a/ggml/src/ggml-cuda/fattn-vec.cuh b/ggml/src/ggml-cuda/fattn-vec.cuh index 7cbe32633e..f0bd42a576 100644 --- a/ggml/src/ggml-cuda/fattn-vec.cuh +++ b/ggml/src/ggml-cuda/fattn-vec.cuh @@ -75,17 +75,17 @@ static __global__ void flash_attn_ext_vec( #endif // GGML_USE_HIP constexpr int nthreads = ggml_cuda_fattn_vec_get_nthreads_device(); - constexpr int nthreads_KQ = type_K == GGML_TYPE_F16 ? 128 / cpy_nb : nthreads_KQ_q; - constexpr int nthreads_V = type_V == GGML_TYPE_F16 ? 128 / cpy_nb : nthreads_V_q; + constexpr int nthreads_KQ = (type_K == GGML_TYPE_F16 || type_K == GGML_TYPE_BF16) ? 128 / cpy_nb : nthreads_KQ_q; + constexpr int nthreads_V = (type_V == GGML_TYPE_F16 || type_V == GGML_TYPE_BF16) ? 128 / cpy_nb : nthreads_V_q; static_assert(WARP_SIZE % nthreads_KQ == 0, "bad nthreads_K"); static_assert(WARP_SIZE % nthreads_V == 0, "bad nthreads_V"); - constexpr int V_rows_per_thread = type_V == GGML_TYPE_F16 ? 2*cpy_ne : 4; + constexpr int V_rows_per_thread = (type_V == GGML_TYPE_F16 || type_V == GGML_TYPE_BF16) ? 
2*cpy_ne : 4; constexpr int V_cols_per_iter = WARP_SIZE / nthreads_V; constexpr vec_dot_KQ_t vec_dot_KQ = get_vec_dot_KQ(); - constexpr bool Q_q8_1 = type_K != GGML_TYPE_F16; + constexpr bool Q_q8_1 = type_K != GGML_TYPE_F16 && type_K != GGML_TYPE_BF16; #ifdef V_DOT2_F32_F16_AVAILABLE constexpr dequantize_V_t dequantize_V = get_dequantize_V(); #else @@ -323,8 +323,18 @@ static __global__ void flash_attn_ext_vec( #pragma unroll for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V*V_rows_per_thread/2) { half2 tmp[V_rows_per_thread/2]; - dequantize_V(V + k*nb21, tmp, - 2*i_VKQ_0 + (nthreads_V == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_V)*V_rows_per_thread); + if constexpr (type_V == GGML_TYPE_BF16) { + float2 tmp_f[V_rows_per_thread/2]; + dequantize_V(V + k*nb21, tmp_f, + 2*i_VKQ_0 + (nthreads_V == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_V)*V_rows_per_thread); +#pragma unroll + for (int i_VKQ_1 = 0; i_VKQ_1 < V_rows_per_thread/2; ++i_VKQ_1) { + tmp[i_VKQ_1] = __float22half2_rn(tmp_f[i_VKQ_1]); + } + } else { + dequantize_V(V + k*nb21, tmp, + 2*i_VKQ_0 + (nthreads_V == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_V)*V_rows_per_thread); + } #pragma unroll for (int i_VKQ_1 = 0; i_VKQ_1 < V_rows_per_thread/2; ++i_VKQ_1) { #pragma unroll @@ -563,6 +573,7 @@ void ggml_cuda_flash_attn_ext_vec_case(ggml_backend_cuda_context & ctx, ggml_ten extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q5_0); \ extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q5_1); \ extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q8_0); \ + extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_BF16); \ EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_F16) EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q4_0) @@ -570,6 +581,7 @@ EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q4_1) EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q5_0) EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q5_1) EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q8_0) +EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_BF16) EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_F16) EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q4_0) @@ -577,6 +589,7 @@ EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q4_1) EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q5_0) EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q5_1) EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q8_0) +EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_BF16) EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_F16) EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q4_0) @@ -584,3 +597,4 @@ EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q4_1) EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q5_0) EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q5_1) EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q8_0) +EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_BF16) diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index 85c177f496..a25a890db6 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -224,6 +224,7 @@ static void ggml_cuda_flash_attn_ext_vec(ggml_backend_cuda_context & ctx, ggml_t FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_F16) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_F16) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_F16) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_F16) FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q4_0) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q4_0) @@ -231,6 +232,7 @@ static void ggml_cuda_flash_attn_ext_vec(ggml_backend_cuda_context & ctx, ggml_t FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q4_0) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q4_0) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, 
GGML_TYPE_Q4_0) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_Q4_0) FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q4_1) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q4_1) @@ -238,6 +240,7 @@ static void ggml_cuda_flash_attn_ext_vec(ggml_backend_cuda_context & ctx, ggml_t FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q4_1) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q4_1) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q4_1) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_Q4_1) FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q5_0) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q5_0) @@ -245,6 +248,7 @@ static void ggml_cuda_flash_attn_ext_vec(ggml_backend_cuda_context & ctx, ggml_t FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q5_0) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q5_0) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q5_0) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_Q5_0) FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q5_1) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q5_1) @@ -252,6 +256,7 @@ static void ggml_cuda_flash_attn_ext_vec(ggml_backend_cuda_context & ctx, ggml_t FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q5_1) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q5_1) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q5_1) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_Q5_1) FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q8_0) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q8_0) @@ -259,10 +264,20 @@ static void ggml_cuda_flash_attn_ext_vec(ggml_backend_cuda_context & ctx, ggml_t FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q8_0) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q8_0) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q8_0) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_Q8_0) + + FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_BF16) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_BF16) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_BF16) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_BF16) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_BF16) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_BF16) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_BF16) #else FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_F16) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q4_0) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q8_0) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_BF16) #endif // GGML_CUDA_FA_ALL_QUANTS GGML_ABORT("fatal error"); @@ -355,6 +370,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const #endif // GGML_CUDA_FA_ALL_QUANTS case GGML_TYPE_Q4_0: case GGML_TYPE_Q8_0: + case GGML_TYPE_BF16: break; default: return BEST_FATTN_KERNEL_NONE; diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-bf16.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-bf16.cu new file mode 100644 index 0000000000..3a2fa99b05 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-bf16.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. 
+ +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_BF16, GGML_TYPE_BF16); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_BF16, GGML_TYPE_BF16); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_BF16, GGML_TYPE_BF16); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-f16.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-f16.cu new file mode 100644 index 0000000000..60f0f6f795 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-f16.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_BF16, GGML_TYPE_F16); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_BF16, GGML_TYPE_F16); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_BF16, GGML_TYPE_F16); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_0.cu new file mode 100644 index 0000000000..489e05f08c --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_0.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_BF16, GGML_TYPE_Q4_0); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_BF16, GGML_TYPE_Q4_0); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_BF16, GGML_TYPE_Q4_0); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_1.cu new file mode 100644 index 0000000000..6fa3c26d30 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_1.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_BF16, GGML_TYPE_Q4_1); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_BF16, GGML_TYPE_Q4_1); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_BF16, GGML_TYPE_Q4_1); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_0.cu new file mode 100644 index 0000000000..421027fb29 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_0.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_BF16, GGML_TYPE_Q5_0); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_BF16, GGML_TYPE_Q5_0); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_BF16, GGML_TYPE_Q5_0); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_1.cu new file mode 100644 index 0000000000..abbc943480 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_1.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. 
+ +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_BF16, GGML_TYPE_Q5_1); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_BF16, GGML_TYPE_Q5_1); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_BF16, GGML_TYPE_Q5_1); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q8_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q8_0.cu new file mode 100644 index 0000000000..d641f859d8 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q8_0.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_BF16, GGML_TYPE_Q8_0); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_BF16, GGML_TYPE_Q8_0); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_BF16, GGML_TYPE_Q8_0); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-bf16.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-bf16.cu new file mode 100644 index 0000000000..d1071dc243 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-bf16.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_BF16); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_BF16); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_BF16); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-bf16.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-bf16.cu new file mode 100644 index 0000000000..8afda31423 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-bf16.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_BF16); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_BF16); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_BF16); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-bf16.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-bf16.cu new file mode 100644 index 0000000000..506864ac18 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-bf16.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_1, GGML_TYPE_BF16); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_BF16); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_1, GGML_TYPE_BF16); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-bf16.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-bf16.cu new file mode 100644 index 0000000000..0bbda8371e --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-bf16.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. 
+ +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_0, GGML_TYPE_BF16); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_BF16); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_0, GGML_TYPE_BF16); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-bf16.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-bf16.cu new file mode 100644 index 0000000000..79be24daf9 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-bf16.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_1, GGML_TYPE_BF16); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_BF16); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_1, GGML_TYPE_BF16); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-bf16.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-bf16.cu new file mode 100644 index 0000000000..45636e5e70 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-bf16.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_BF16); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_BF16); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_BF16); diff --git a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py index e382df1ae2..3b5ab12fc4 100755 --- a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +++ b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py @@ -5,7 +5,7 @@ import os HEAD_SIZES_KQ = [40, 64, 72, 80, 96, 112, 128, 256, 576] -TYPES_KV = ["GGML_TYPE_F16", "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0"] +TYPES_KV = ["GGML_TYPE_F16", "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0", "GGML_TYPE_BF16"] SOURCE_FATTN_TILE = """// This file has been autogenerated by generate_cu_files.py, do not edit manually. 
diff --git a/ggml/src/ggml-hip/CMakeLists.txt b/ggml/src/ggml-hip/CMakeLists.txt index f96c6e09a9..291b483745 100644 --- a/ggml/src/ggml-hip/CMakeLists.txt +++ b/ggml/src/ggml-hip/CMakeLists.txt @@ -71,12 +71,11 @@ if (GGML_CUDA_FA_ALL_QUANTS) list(APPEND GGML_SOURCES_ROCM ${SRCS}) add_compile_definitions(GGML_CUDA_FA_ALL_QUANTS) else() - file(GLOB SRCS "../ggml-cuda/template-instances/fattn-vec*q4_0-q4_0.cu") - list(APPEND GGML_SOURCES_ROCM ${SRCS}) - file(GLOB SRCS "../ggml-cuda/template-instances/fattn-vec*q8_0-q8_0.cu") - list(APPEND GGML_SOURCES_ROCM ${SRCS}) - file(GLOB SRCS "../ggml-cuda/template-instances/fattn-vec*f16-f16.cu") - list(APPEND GGML_SOURCES_ROCM ${SRCS}) + list(APPEND GGML_SOURCES_ROCM + ../ggml-cuda/template-instances/fattn-vec-instance-f16-f16.cu + ../ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_0.cu + ../ggml-cuda/template-instances/fattn-vec-instance-q8_0-q8_0.cu + ../ggml-cuda/template-instances/fattn-vec-instance-bf16-bf16.cu) endif() ggml_add_backend_library(ggml-hip diff --git a/ggml/src/ggml-musa/CMakeLists.txt b/ggml/src/ggml-musa/CMakeLists.txt index d76cb51977..cc53c812ce 100644 --- a/ggml/src/ggml-musa/CMakeLists.txt +++ b/ggml/src/ggml-musa/CMakeLists.txt @@ -48,12 +48,11 @@ if (MUSAToolkit_FOUND) list(APPEND GGML_SOURCES_MUSA ${SRCS}) add_compile_definitions(GGML_CUDA_FA_ALL_QUANTS) else() - file(GLOB SRCS "../ggml-cuda/template-instances/fattn-vec*q4_0-q4_0.cu") - list(APPEND GGML_SOURCES_MUSA ${SRCS}) - file(GLOB SRCS "../ggml-cuda/template-instances/fattn-vec*q8_0-q8_0.cu") - list(APPEND GGML_SOURCES_MUSA ${SRCS}) - file(GLOB SRCS "../ggml-cuda/template-instances/fattn-vec*f16-f16.cu") - list(APPEND GGML_SOURCES_MUSA ${SRCS}) + list(APPEND GGML_SOURCES_MUSA + ../ggml-cuda/template-instances/fattn-vec-instance-f16-f16.cu + ../ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_0.cu + ../ggml-cuda/template-instances/fattn-vec-instance-q8_0-q8_0.cu + ../ggml-cuda/template-instances/fattn-vec-instance-bf16-bf16.cu) endif() set_source_files_properties(${GGML_SOURCES_MUSA} PROPERTIES LANGUAGE CXX) From f40a80b4f3cd00c4c405c45b7f316f7e77352323 Mon Sep 17 00:00:00 2001 From: Neo Zhang Date: Sun, 22 Mar 2026 22:06:27 +0800 Subject: [PATCH 3/9] support bf16 and quantized type (#20803) --- ggml/src/ggml-sycl/ggml-sycl.cpp | 16 +--------------- 1 file changed, 1 insertion(+), 15 deletions(-) diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index 2ec1421841..456b1699fa 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -4667,22 +4667,8 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g if (a->ne[3] != b->ne[3]) { return false; } - ggml_type a_type = a->type; - if (a_type == GGML_TYPE_IQ4_NL || a_type == GGML_TYPE_IQ4_XS || - a_type == GGML_TYPE_IQ3_XXS || a_type == GGML_TYPE_IQ3_S || - a_type == GGML_TYPE_IQ2_XXS || a_type == GGML_TYPE_IQ2_XS || a_type == GGML_TYPE_IQ2_S || - a_type == GGML_TYPE_IQ1_S || a_type == GGML_TYPE_IQ1_M - ) { - if (b->ne[1] == 1 && ggml_nrows(b) > 1) { - return false; - } - } + ggml_type src0_type = op->src[0]->type; - if (src0_type == GGML_TYPE_BF16 ) { - // TODO: support GGML_TYPE_BF16 - // FIXME: keep a list of supported types to avoid breaking the backend when a new type is added - return false; - } // TODO: The configuration below needs more work to be supported with oneDNN if (ggml_is_permuted(a) && !ggml_is_contiguous(a) && From 81bc4d3ddcaa1767b60837d73a0c31babb1190cf Mon Sep 17 00:00:00 2001 From: Evgeny 
Kurnevsky Date: Sun, 22 Mar 2026 15:29:22 +0100 Subject: [PATCH 4/9] server: fix Host header (#20843) It should include port when it's not default. --- tools/server/server-models.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tools/server/server-models.cpp b/tools/server/server-models.cpp index 4ac55cd158..62dae2db7a 100644 --- a/tools/server/server-models.cpp +++ b/tools/server/server-models.cpp @@ -1181,7 +1181,8 @@ server_http_proxy::server_http_proxy( continue; } if (key == "Host" || key == "host") { - req.set_header(key, host); + bool is_default_port = (scheme == "https" && port == 443) || (scheme == "http" && port == 80); + req.set_header(key, is_default_port ? host : host + ":" + std::to_string(port)); } else { req.set_header(key, value); } From 23c9182ce872666aff6bd0fc571c3e3b6ae5fd89 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sigbj=C3=B8rn=20Skj=C3=A6ret?= Date: Sun, 22 Mar 2026 17:45:10 +0100 Subject: [PATCH 5/9] jinja : refactor token advancement (#20864) * refactor token advancement * exercise sub-expressions --- common/jinja/parser.cpp | 62 +++++++++++++++++++++++------------------ tests/test-jinja.cpp | 6 ++++ 2 files changed, 41 insertions(+), 27 deletions(-) diff --git a/common/jinja/parser.cpp b/common/jinja/parser.cpp index 7970336ac0..4ae4477445 100644 --- a/common/jinja/parser.cpp +++ b/common/jinja/parser.cpp @@ -53,6 +53,13 @@ private: return tokens[current + offset]; } + const token & next() { + if (current >= tokens.size()) { + throw parser_exception("Parser Error: Unexpected EOF", source, tokens.empty() ? 0 : tokens.back().pos); + } + return tokens[current++]; + } + token expect(token::type type, const std::string& error) { const auto & t = peek(); if (t.t != type) { @@ -90,9 +97,9 @@ private: size_t start_pos = current; switch (peek().t) { case token::comment: - return mk_stmt(start_pos, tokens[current++].value); + return mk_stmt(start_pos, next().value); case token::text: - return mk_stmt(start_pos, tokens[current++].value); + return mk_stmt(start_pos, next().value); case token::open_statement: return parse_jinja_statement(); case token::open_expression: @@ -119,8 +126,7 @@ private: } size_t start_pos = current; - std::string name = peek().value; - current++; // consume identifier + std::string name = next().value; statement_ptr result; if (name == "set") { @@ -202,7 +208,7 @@ private: // Ignore generation blocks (transformers-specific) // See https://github.com/huggingface/transformers/pull/30650 for more information. result = mk_stmt(start_pos); - current++; + ++current; } else { throw std::runtime_error("Unknown statement: " + name); @@ -217,7 +223,7 @@ private: statements body; if (is(token::equals)) { - current++; + ++current; value = parse_expression_sequence(); } else { // parsing multiline set here @@ -280,7 +286,7 @@ private: exprs.push_back(primary ? parse_primary_expression() : parse_expression()); bool is_tuple = is(token::comma); while (is(token::comma)) { - current++; // consume comma + ++current; // consume comma exprs.push_back(primary ? parse_primary_expression() : parse_expression()); } return is_tuple ? 
mk_stmt(start_pos, std::move(exprs)) : std::move(exprs[0]); @@ -290,7 +296,7 @@ private: // e.g., `message` in `for message in messages` auto loop_var = parse_expression_sequence(true); // should be an identifier/tuple if (!is_identifier("in")) throw std::runtime_error("Expected 'in'"); - current++; + ++current; // consume 'in' // `messages` in `for message in messages` auto iterable = parse_expression(); @@ -305,7 +311,8 @@ private: } if (is_statement({"else"})) { - current += 2; + ++current; // consume {% + ++current; // consume 'else' expect(token::close_statement, "Expected %}"); while (!is_statement({"endfor"})) { alternate.push_back(parse_any()); @@ -347,7 +354,7 @@ private: auto left = parse_logical_and_expression(); while (is_identifier("or")) { size_t start_pos = current; - token op = tokens[current++]; + token op = next(); left = mk_stmt(start_pos, op, std::move(left), parse_logical_and_expression()); } return left; @@ -357,7 +364,7 @@ private: auto left = parse_logical_negation_expression(); while (is_identifier("and")) { size_t start_pos = current; - auto op = tokens[current++]; + auto op = next(); left = mk_stmt(start_pos, op, std::move(left), parse_logical_negation_expression()); } return left; @@ -367,7 +374,7 @@ private: // Try parse unary operators if (is_identifier("not")) { size_t start_pos = current; - auto op = tokens[current++]; + auto op = next(); return mk_stmt(start_pos, op, parse_logical_negation_expression()); } return parse_comparison_expression(); @@ -382,11 +389,12 @@ private: size_t start_pos = current; if (is_identifier("not") && peek(1).t == token::identifier && peek(1).value == "in") { op = {token::identifier, "not in", tokens[current].pos}; - current += 2; + ++current; // consume 'not' + ++current; // consume 'in' } else if (is_identifier("in")) { - op = tokens[current++]; + op = next(); } else if (is(token::comparison_binary_operator)) { - op = tokens[current++]; + op = next(); } else break; left = mk_stmt(start_pos, op, std::move(left), parse_additive_expression()); } @@ -397,7 +405,7 @@ private: auto left = parse_multiplicative_expression(); while (is(token::additive_binary_operator)) { size_t start_pos = current; - auto op = tokens[current++]; + auto op = next(); left = mk_stmt(start_pos, op, std::move(left), parse_multiplicative_expression()); } return left; @@ -407,7 +415,7 @@ private: auto left = parse_test_expression(); while (is(token::multiplicative_binary_operator)) { size_t start_pos = current; - auto op = tokens[current++]; + auto op = next(); left = mk_stmt(start_pos, op, std::move(left), parse_test_expression()); } return left; @@ -417,9 +425,9 @@ private: auto operand = parse_filter_expression(); while (is_identifier("is")) { size_t start_pos = current; - current++; + ++current; // consume 'is' bool negate = false; - if (is_identifier("not")) { current++; negate = true; } + if (is_identifier("not")) { ++current; negate = true; } auto test_id = parse_primary_expression(); // FIXME: tests can also be expressed like this: if x is eq 3 if (is(token::open_paren)) test_id = parse_call_expression(std::move(test_id)); @@ -432,7 +440,7 @@ private: auto operand = parse_call_member_expression(); while (is(token::pipe)) { size_t start_pos = current; - current++; + ++current; // consume pipe auto filter = parse_primary_expression(); if (is(token::open_paren)) filter = parse_call_expression(std::move(filter)); operand = mk_stmt(start_pos, std::move(operand), std::move(filter)); @@ -490,7 +498,7 @@ private: statement_ptr 
parse_member_expression(statement_ptr object) { size_t start_pos = current; while (is(token::dot) || is(token::open_square_bracket)) { - auto op = tokens[current++]; + auto op = next(); bool computed = op.t == token::open_square_bracket; statement_ptr prop; if (computed) { @@ -536,7 +544,7 @@ private: statement_ptr parse_primary_expression() { size_t start_pos = current; - auto t = tokens[current++]; + auto t = next(); switch (t.t) { case token::numeric_literal: if (t.value.find('.') != std::string::npos) { @@ -547,7 +555,7 @@ private: case token::string_literal: { std::string val = t.value; while (is(token::string_literal)) { - val += tokens[current++].value; + val += next().value; } return mk_stmt(start_pos, val); } @@ -562,9 +570,9 @@ private: statements vals; while (!is(token::close_square_bracket)) { vals.push_back(parse_expression()); - if (is(token::comma)) current++; + if (is(token::comma)) ++current; } - current++; + ++current; return mk_stmt(start_pos, std::move(vals)); } case token::open_curly_bracket: { @@ -573,9 +581,9 @@ private: auto key = parse_expression(); expect(token::colon, "Expected :"); pairs.push_back({std::move(key), parse_expression()}); - if (is(token::comma)) current++; + if (is(token::comma)) ++current; } - current++; + ++current; return mk_stmt(start_pos, std::move(pairs)); } default: diff --git a/tests/test-jinja.cpp b/tests/test-jinja.cpp index ef9c8f73c8..1550627bf0 100644 --- a/tests/test-jinja.cpp +++ b/tests/test-jinja.cpp @@ -2264,6 +2264,7 @@ static void test_fuzzing(testing & t) { t.test("malformed templates (should error, not crash)", [&](testing & t) { const std::vector malformed = { + "", "{{ x", "{% if %}", "{% for %}", @@ -2284,6 +2285,11 @@ static void test_fuzzing(testing & t) { for (const auto & tmpl : malformed) { t.assert_true("malformed: " + tmpl, fuzz_test_template(tmpl, json::object())); } + std::string tmpl = "{% for message in messages %}{{ message.role | string }} : {{ message.content if ('content' in message and message.content is not none) }}{% endfor %"; + while (tmpl.length() > 0) { + t.assert_true("malformed: " + tmpl, fuzz_test_template(tmpl, json::object())); + tmpl.pop_back(); + } }); t.test("type coercion edge cases", [&](testing & t) { From bd3f1d9d65113737e795cd6d62eb66f90628663c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Sun, 22 Mar 2026 17:53:33 +0100 Subject: [PATCH 6/9] CUDA: fix BF16 FA compilation (#20865) --- ggml/src/ggml-cuda/convert.cuh | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/ggml/src/ggml-cuda/convert.cuh b/ggml/src/ggml-cuda/convert.cuh index b8caeacf09..f5d37c7b99 100644 --- a/ggml/src/ggml-cuda/convert.cuh +++ b/ggml/src/ggml-cuda/convert.cuh @@ -42,11 +42,15 @@ template } else if constexpr(std::is_same_v && std::is_same_v) { return __float22half2_rn(x); } else if constexpr(std::is_same_v && std::is_same_v) { -#if !defined(GGML_USE_HIP) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +#ifdef GGML_USE_HIP + return make_float2(__bfloat162float(__low2bfloat16(x)), __bfloat162float(__high2bfloat16(x))); +#else +#if __CUDA_ARCH__ >= 800 return __bfloat1622float2(x); #else - return make_float2(__bfloat162float(__low2bfloat16(x)), __bfloat162float(__high2bfloat16(x))); -#endif // !defined(GGML_USE_HIP) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + return make_float2(__bfloat162float(x.x), __bfloat162float(x.y)); +#endif // __CUDA_ARCH__ >= 800 +#endif // GGML_USE_HIP } else if constexpr(std::is_same_v && std::is_same_v) { // bypass compile 
error on cuda 12.0.1 #ifdef GGML_USE_HIP From 49bfddeca18e62fa3d39114a23e9fcbdf8a22388 Mon Sep 17 00:00:00 2001 From: Xuan-Son Nguyen Date: Sun, 22 Mar 2026 18:33:52 +0100 Subject: [PATCH 7/9] server: allow router to report child instances sleep status (#20849) * server: allow router to report child instances sleep status * refactor * move sleeping to state * nits --- tools/server/README.md | 7 ++++ tools/server/server-context.cpp | 3 ++ tools/server/server-context.h | 4 ++ tools/server/server-models.cpp | 70 ++++++++++++++++++++------------- tools/server/server-models.h | 40 +++++++++++++------ tools/server/server-queue.h | 12 +++++- tools/server/server.cpp | 9 ++++- 7 files changed, 102 insertions(+), 43 deletions(-) diff --git a/tools/server/README.md b/tools/server/README.md index 554444d74b..f584b60026 100644 --- a/tools/server/README.md +++ b/tools/server/README.md @@ -1634,6 +1634,13 @@ The `status` object can be: } ``` +```json +"status": { + "value": "sleeping", + "args": ["llama-server", "-ctx", "4096"] +} +``` + ### POST `/models/load`: Load a model Load a model diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index 9de554e900..b79a5270b5 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -3033,6 +3033,9 @@ struct server_res_generator : server_http_res { } }; +void server_context::on_sleeping_changed(std::function callback) { + impl->queue_tasks.on_sleeping_state(std::move(callback)); +} // diff --git a/tools/server/server-context.h b/tools/server/server-context.h index 75f3d2de56..a4d2201cbe 100644 --- a/tools/server/server-context.h +++ b/tools/server/server-context.h @@ -74,6 +74,10 @@ struct server_context { // get server metadata (read-only), can only be called after load_model() // not thread-safe, should only be used from the main thread server_context_meta get_meta() const; + + // register a callback to be called when sleeping state changes + // must be set before load_model() is called + void on_sleeping_changed(std::function callback); }; diff --git a/tools/server/server-models.cpp b/tools/server/server-models.cpp index 62dae2db7a..7e61844f08 100644 --- a/tools/server/server-models.cpp +++ b/tools/server/server-models.cpp @@ -39,7 +39,8 @@ extern char **environ; #define DEFAULT_STOP_TIMEOUT 10 // seconds #define CMD_ROUTER_TO_CHILD_EXIT "cmd_router_to_child:exit" -#define CMD_CHILD_TO_ROUTER_READY "cmd_child_to_router:ready" +#define CMD_CHILD_TO_ROUTER_READY "cmd_child_to_router:ready" // also sent when waking up from sleep +#define CMD_CHILD_TO_ROUTER_SLEEP "cmd_child_to_router:sleep" // address for child process, this is needed because router may run on 0.0.0.0 // ref: https://github.com/ggml-org/llama.cpp/issues/17862 @@ -380,7 +381,7 @@ void server_models::update_meta(const std::string & name, const server_model_met if (it != mapping.end()) { it->second.meta = meta; } - cv.notify_all(); // notify wait_until_loaded + cv.notify_all(); // notify wait_until_loading_finished } bool server_models::has_model(const std::string & name) { @@ -503,7 +504,7 @@ void server_models::unload_lru() { { std::unique_lock lk(mutex); for (const auto & m : mapping) { - if (m.second.meta.is_active()) { + if (m.second.meta.is_running()) { count_active++; if (m.second.meta.last_used < lru_last_used) { lru_model_name = m.first; @@ -546,7 +547,7 @@ void server_models::load(const std::string & name) { if (base_params.models_max > 0) { size_t count_active = 0; for (const auto & m : mapping) { - if (m.second.meta.is_active()) { + 
if (m.second.meta.is_running()) { count_active++; } } @@ -605,15 +606,15 @@ void server_models::load(const std::string & name) { std::thread log_thread([&]() { // read stdout/stderr and forward to main server log // also handle status report from child process - bool state_received = false; // true if child state received if (stdout_file) { char buffer[4096]; while (fgets(buffer, sizeof(buffer), stdout_file) != nullptr) { LOG("[%5d] %s", port, buffer); - if (!state_received && std::strstr(buffer, CMD_CHILD_TO_ROUTER_READY) != nullptr) { - // child process is ready + std::string str(buffer); + if (string_starts_with(buffer, CMD_CHILD_TO_ROUTER_READY)) { this->update_status(name, SERVER_MODEL_STATUS_LOADED, 0); - state_received = true; + } else if (string_starts_with(buffer, CMD_CHILD_TO_ROUTER_SLEEP)) { + this->update_status(name, SERVER_MODEL_STATUS_SLEEPING, 0); } } } else { @@ -706,13 +707,13 @@ void server_models::unload(const std::string & name) { std::lock_guard lk(mutex); auto it = mapping.find(name); if (it != mapping.end()) { - if (it->second.meta.is_active()) { - SRV_INF("unloading model instance name=%s\n", name.c_str()); + if (it->second.meta.is_running()) { + SRV_INF("stopping model instance name=%s\n", name.c_str()); stopping_models.insert(name); cv_stop.notify_all(); // status change will be handled by the managing thread } else { - SRV_WRN("model instance name=%s is not loaded\n", name.c_str()); + SRV_WRN("model instance name=%s is not running\n", name.c_str()); } } } @@ -722,8 +723,8 @@ void server_models::unload_all() { { std::lock_guard lk(mutex); for (auto & [name, inst] : mapping) { - if (inst.meta.is_active()) { - SRV_INF("unloading model instance name=%s\n", name.c_str()); + if (inst.meta.is_running()) { + SRV_INF("stopping model instance name=%s\n", name.c_str()); stopping_models.insert(name); cv_stop.notify_all(); // status change will be handled by the managing thread @@ -750,7 +751,7 @@ void server_models::update_status(const std::string & name, server_model_status cv.notify_all(); } -void server_models::wait_until_loaded(const std::string & name) { +void server_models::wait_until_loading_finished(const std::string & name) { std::unique_lock lk(mutex); cv.wait(lk, [this, &name]() { auto it = mapping.find(name); @@ -761,22 +762,25 @@ void server_models::wait_until_loaded(const std::string & name) { }); } -bool server_models::ensure_model_loaded(const std::string & name) { +bool server_models::ensure_model_ready(const std::string & name) { auto meta = get_meta(name); if (!meta.has_value()) { throw std::runtime_error("model name=" + name + " is not found"); } - if (meta->status == SERVER_MODEL_STATUS_LOADED) { - return false; // already loaded + if (meta->is_ready()) { + return false; // ready for taking requests + } + if (meta->status == SERVER_MODEL_STATUS_SLEEPING) { + return false; // child is sleeping but still running; new request will wake it up } if (meta->status == SERVER_MODEL_STATUS_UNLOADED) { SRV_INF("model name=%s is not loaded, loading...\n", name.c_str()); load(name); } - // for loading state + // wait for loading to complete SRV_INF("waiting until model name=%s is fully loaded...\n", name.c_str()); - wait_until_loaded(name); + wait_until_loading_finished(name); // check final status meta = get_meta(name); @@ -792,8 +796,8 @@ server_http_res_ptr server_models::proxy_request(const server_http_req & req, co if (!meta.has_value()) { throw std::runtime_error("model name=" + name + " is not found"); } - if (meta->status != SERVER_MODEL_STATUS_LOADED) { - 
throw std::invalid_argument("model name=" + name + " is not loaded"); + if (!meta->is_running()) { + throw std::invalid_argument("model name=" + name + " is not running"); } if (update_last_used) { std::unique_lock lk(mutex); @@ -819,6 +823,11 @@ server_http_res_ptr server_models::proxy_request(const server_http_req & req, co return proxy; } +bool server_models::is_child_server() { + const char * router_port = std::getenv("LLAMA_SERVER_ROUTER_PORT"); + return router_port != nullptr; +} + std::thread server_models::setup_child_server(const std::function & shutdown_handler) { // send a notification to the router server that a model instance is ready common_log_pause(common_log_main()); @@ -852,6 +861,13 @@ std::thread server_models::setup_child_server(const std::function & s }); } +void server_models::notify_router_sleeping_state(bool is_sleeping) { + common_log_pause(common_log_main()); + fflush(stdout); + fprintf(stdout, "%s\n", is_sleeping ? CMD_CHILD_TO_ROUTER_SLEEP : CMD_CHILD_TO_ROUTER_READY); + fflush(stdout); + common_log_resume(common_log_main()); +} // @@ -881,9 +897,9 @@ static bool router_validate_model(std::string & name, server_models & models, bo // resolve alias to canonical model name name = meta->name; if (models_autoload) { - models.ensure_model_loaded(name); + models.ensure_model_ready(name); } else { - if (meta->status != SERVER_MODEL_STATUS_LOADED) { + if (!meta->is_running()) { res_err(res, format_error_response("model is not loaded", ERROR_TYPE_INVALID_REQUEST)); return false; } @@ -956,8 +972,8 @@ void server_models_routes::init_routes() { res_err(res, format_error_response("model is not found", ERROR_TYPE_NOT_FOUND)); return res; } - if (meta->status == SERVER_MODEL_STATUS_LOADED) { - res_err(res, format_error_response("model is already loaded", ERROR_TYPE_INVALID_REQUEST)); + if (meta->is_running()) { + res_err(res, format_error_response("model is already running", ERROR_TYPE_INVALID_REQUEST)); return res; } models.load(meta->name); @@ -1015,8 +1031,8 @@ void server_models_routes::init_routes() { res_err(res, format_error_response("model is not found", ERROR_TYPE_INVALID_REQUEST)); return res; } - if (!model->is_active()) { - res_err(res, format_error_response("model is not loaded", ERROR_TYPE_INVALID_REQUEST)); + if (!model->is_running()) { + res_err(res, format_error_response("model is not running", ERROR_TYPE_INVALID_REQUEST)); return res; } models.unload(model->name); diff --git a/tools/server/server-models.h b/tools/server/server-models.h index 2b392f299a..1db34b6c4d 100644 --- a/tools/server/server-models.h +++ b/tools/server/server-models.h @@ -14,17 +14,18 @@ /** * state diagram: * - * UNLOADED ──► LOADING ──► LOADED - * ▲ │ │ - * └───failed───┘ │ - * ▲ │ + * UNLOADED ──► LOADING ──► LOADED ◄──── SLEEPING + * ▲ │ │ ▲ + * └───failed───┘ │ │ + * ▲ └──sleeping─────┘ * └────────unloaded─────────┘ */ enum server_model_status { // TODO: also add downloading state when the logic is added SERVER_MODEL_STATUS_UNLOADED, SERVER_MODEL_STATUS_LOADING, - SERVER_MODEL_STATUS_LOADED + SERVER_MODEL_STATUS_LOADED, + SERVER_MODEL_STATUS_SLEEPING }; static server_model_status server_model_status_from_string(const std::string & status_str) { @@ -37,6 +38,9 @@ static server_model_status server_model_status_from_string(const std::string & s if (status_str == "loaded") { return SERVER_MODEL_STATUS_LOADED; } + if (status_str == "sleeping") { + return SERVER_MODEL_STATUS_SLEEPING; + } throw std::runtime_error("invalid server model status"); } @@ -45,6 +49,7 @@ static std::string 
server_model_status_to_string(server_model_status status) { case SERVER_MODEL_STATUS_UNLOADED: return "unloaded"; case SERVER_MODEL_STATUS_LOADING: return "loading"; case SERVER_MODEL_STATUS_LOADED: return "loaded"; + case SERVER_MODEL_STATUS_SLEEPING: return "sleeping"; default: return "unknown"; } } @@ -61,8 +66,12 @@ struct server_model_meta { int exit_code = 0; // exit code of the model instance process (only valid if status == FAILED) int stop_timeout = 0; // seconds to wait before force-killing the model instance during shutdown - bool is_active() const { - return status == SERVER_MODEL_STATUS_LOADED || status == SERVER_MODEL_STATUS_LOADING; + bool is_ready() const { + return status == SERVER_MODEL_STATUS_LOADED; + } + + bool is_running() const { + return status == SERVER_MODEL_STATUS_LOADED || status == SERVER_MODEL_STATUS_LOADING || status == SERVER_MODEL_STATUS_SLEEPING; } bool is_failed() const { @@ -130,19 +139,26 @@ public: void update_status(const std::string & name, server_model_status status, int exit_code); // wait until the model instance is fully loaded (thread-safe) - // return when the model is loaded or failed to load - void wait_until_loaded(const std::string & name); + // return when the model no longer in "loading" state + void wait_until_loading_finished(const std::string & name); - // load the model if not loaded, otherwise do nothing (thread-safe) - // return false if model is already loaded; return true otherwise (meta may need to be refreshed) - bool ensure_model_loaded(const std::string & name); + // ensure the model is in ready state (thread-safe) + // return false if model is ready + // otherwise, load the model and blocking wait until it's ready, then return true (meta may need to be refreshed) + bool ensure_model_ready(const std::string & name); // proxy an HTTP request to the model instance server_http_res_ptr proxy_request(const server_http_req & req, const std::string & method, const std::string & name, bool update_last_used); + // return true if the current process is a child server instance + static bool is_child_server(); + // notify the router server that a model instance is ready // return the monitoring thread (to be joined by the caller) static std::thread setup_child_server(const std::function & shutdown_handler); + + // notify the router server that the sleeping state has changed + static void notify_router_sleeping_state(bool sleeping); }; struct server_models_routes { diff --git a/tools/server/server-queue.h b/tools/server/server-queue.h index 164f09b195..35f010401f 100644 --- a/tools/server/server-queue.h +++ b/tools/server/server-queue.h @@ -95,11 +95,19 @@ public: callback_update_slots = std::move(callback); } - // Register callback for sleeping state change + // Register callback for sleeping state change; multiple callbacks are allowed // note: when entering sleeping state, the callback is called AFTER sleeping is set to true // when leaving sleeping state, the callback is called BEFORE sleeping is set to false void on_sleeping_state(std::function callback) { - callback_sleeping_state = std::move(callback); + if (callback_sleeping_state) { + auto prev_callback = std::move(callback_sleeping_state); + callback_sleeping_state = [prev_callback, callback](bool sleeping) { + prev_callback(sleeping); + callback(sleeping); + }; + } else { + callback_sleeping_state = std::move(callback); + } } private: diff --git a/tools/server/server.cpp b/tools/server/server.cpp index 0bd6fda17d..73e3aa44e0 100644 --- a/tools/server/server.cpp +++ 
diff --git a/tools/server/server.cpp b/tools/server/server.cpp
index 0bd6fda17d..73e3aa44e0 100644
--- a/tools/server/server.cpp
+++ b/tools/server/server.cpp
@@ -259,6 +259,12 @@ int main(int argc, char ** argv) {
     // load the model
     LOG_INF("%s: loading model\n", __func__);
 
+    if (server_models::is_child_server()) {
+        ctx_server.on_sleeping_changed([&](bool sleeping) {
+            server_models::notify_router_sleeping_state(sleeping);
+        });
+    }
+
     if (!ctx_server.load_model(params)) {
         clean_up();
         if (ctx_http.thread.joinable()) {
@@ -309,9 +315,8 @@ int main(int argc, char ** argv) {
     LOG_INF("%s: starting the main loop...\n", __func__);
 
     // optionally, notify router server that this instance is ready
-    const char * router_port = std::getenv("LLAMA_SERVER_ROUTER_PORT");
     std::thread monitor_thread;
-    if (router_port != nullptr) {
+    if (server_models::is_child_server()) {
         monitor_thread = server_models::setup_child_server(shutdown_handler);
     }
 

From d3ac030a5d1edfbb7f7150126f6b9e9bf7c00c26 Mon Sep 17 00:00:00 2001
From: DorianRudolph
Date: Mon, 23 Mar 2026 01:04:14 +0100
Subject: [PATCH 8/9] mtmd : fix LightOnOCR image preprocessing (#20877)

---
 tools/mtmd/clip.cpp | 24 ++++++++++++++++++++++--
 1 file changed, 22 insertions(+), 2 deletions(-)

diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp
index 44a19189ea..5fcc7c5b59 100644
--- a/tools/mtmd/clip.cpp
+++ b/tools/mtmd/clip.cpp
@@ -1161,7 +1161,6 @@ struct clip_model_loader {
                     hparams.set_warmup_n_tokens(16*16);
                 } break;
             case PROJECTOR_TYPE_PIXTRAL:
-            case PROJECTOR_TYPE_LIGHTONOCR:
                 {
                     // ref: https://huggingface.co/mistral-community/pixtral-12b/blob/main/preprocessor_config.json
                     // TODO: verify the image_min_tokens
@@ -1171,6 +1170,15 @@ struct clip_model_loader {
                     hparams.set_limit_image_tokens(8, 1024);
                     hparams.set_warmup_n_tokens(256); // avoid OOM on warmup
                 } break;
+            case PROJECTOR_TYPE_LIGHTONOCR:
+                {
+                    hparams.n_merge = 1;
+                    hparams.rope_theta = 10000.0f;
+                    get_u32(KEY_SPATIAL_MERGE_SIZE, hparams.n_merge, false);
+                    hparams.image_longest_edge = hparams.image_size;
+                    get_u32(KEY_PREPROC_IMAGE_SIZE, hparams.image_longest_edge, false);
+                    hparams.set_warmup_n_tokens(256); // avoid OOM on warmup
+                } break;
             case PROJECTOR_TYPE_KIMIVL:
                 {
                     hparams.rope_theta = 10000.0f;
@@ -3180,7 +3188,6 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
         case PROJECTOR_TYPE_PHI4:
 
         case PROJECTOR_TYPE_PIXTRAL:
-        case PROJECTOR_TYPE_LIGHTONOCR:
            {
                GGML_ASSERT(params.image_min_pixels > 0 && params.image_max_pixels > 0);
                clip_image_u8 resized_image;
@@ -3196,6 +3203,19 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
                normalize_image_u8_to_f32(resized_image, *img_f32, params.image_mean, params.image_std);
                res_imgs->entries.push_back(std::move(img_f32));
            } break;
+        case PROJECTOR_TYPE_LIGHTONOCR:
+            {
+                GGML_ASSERT(params.image_longest_edge > 0);
+                clip_image_u8 resized_image;
+                const clip_image_size target_size = img_tool::calc_size_preserved_ratio(
+                    original_size,
+                    params.patch_size * params.n_merge,
+                    params.image_longest_edge);
+                img_tool::resize(*img, resized_image, target_size, img_tool::RESIZE_ALGO_BICUBIC);
+                clip_image_f32_ptr img_f32(clip_image_f32_init());
+                normalize_image_u8_to_f32(resized_image, *img_f32, params.image_mean, params.image_std);
+                res_imgs->entries.push_back(std::move(img_f32));
+            } break;
         case PROJECTOR_TYPE_LLAMA4:
             {
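Note on the LightOnOCR fix above: instead of reusing Pixtral's pixel-count limits, the dedicated path resizes so that the aspect ratio is preserved, both sides align to a multiple of patch_size * n_merge, and the longest side respects image_longest_edge. The helper img_tool::calc_size_preserved_ratio() is not shown in this patch, so the sketch below only approximates that computation; the rounding rule and the sample numbers are assumptions.

#include <algorithm>
#include <cmath>
#include <cstdio>

struct size2d { int width, height; };

// approximate the LightOnOCR resize target: keep aspect ratio, cap the longest
// edge, and round each side up to a multiple of `align` (patch_size * n_merge)
static size2d calc_target_size(size2d orig, int align, int longest_edge) {
    const float scale = std::min(1.0f, (float) longest_edge / (float) std::max(orig.width, orig.height));
    auto round_up = [align](float v) {
        return std::max(align, (int) (std::ceil(v / (float) align) * (float) align));
    };
    return { round_up((float) orig.width * scale), round_up((float) orig.height * scale) };
}

int main() {
    // e.g. a 2000x1000 page with patch_size 16, n_merge 1 and a 1536-pixel longest edge (made-up numbers)
    size2d t = calc_target_size({2000, 1000}, 16, 1536);
    std::printf("%dx%d\n", t.width, t.height); // prints 1536x768 with the rounding rule above
    return 0;
}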
From ec2b787ebe975cab1fa592c2505a6d3a9e4ff2e7 Mon Sep 17 00:00:00 2001
From: bssrdf
Date: Sun, 22 Mar 2026 20:06:30 -0400
Subject: [PATCH 9/9] mtmd: Add dynamic high-resolution image preprocessing for InternVL model (#20847)

* added support for internvl's dynamic high-resolution (Qianfan-OCR needed)

* add min/max dynamic patch to gguf meta

* clean up

* simplified handling min/max dynamic patch

* reuse llava_uhd logic for slice images

* provide default values for older models

* flake8

* prevent writing 0 value to gguf

* remove duplicated resolution candidates with a better algorithm

* fix indentation

* format

* add protection from divide by zero

* change to 0 to be safe

---------

Co-authored-by: Xuan Son Nguyen
---
 convert_hf_to_gguf.py       | 15 ++++++++++
 gguf-py/gguf/constants.py   |  2 ++
 gguf-py/gguf/gguf_writer.py |  6 ++++
 tools/mtmd/clip-impl.h      |  2 ++
 tools/mtmd/clip-model.h     |  3 ++
 tools/mtmd/clip.cpp         | 56 +++++++++++++++++++++++++++++++++++--
 tools/mtmd/mtmd.cpp         |  4 ++-
 7 files changed, 84 insertions(+), 4 deletions(-)

diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index dba190b480..0cd47645d3 100755
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -4273,6 +4273,16 @@ class Qwen25OmniModel(Qwen2VLVisionModel):
 
 @ModelBase.register("InternVisionModel")
 class InternVisionModel(MmprojModel):
+
+    min_dynamic_tiles: int = 0
+    max_dynamic_tiles: int = 0
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        assert self.hparams_vision is not None
+        self.min_dynamic_tiles = self.global_config.get("min_dynamic_patch", 0)
+        self.max_dynamic_tiles = self.global_config.get("max_dynamic_patch", 0)
+
     def set_gguf_parameters(self):
         assert self.hparams_vision is not None
         if isinstance(self.hparams_vision['image_size'], list):
@@ -4295,6 +4305,11 @@ class InternVisionModel(MmprojModel):
         downsample_ratio = self.global_config.get("downsample_ratio")
         assert downsample_ratio is not None
         self.gguf_writer.add_vision_projector_scale_factor(int(1.0 / downsample_ratio))
+        # older models may not have min/max_dynamic_patch in config
+        if self.min_dynamic_tiles > 0:
+            self.gguf_writer.add_vision_preproc_min_tiles(self.min_dynamic_tiles)
+        if self.max_dynamic_tiles > 0:
+            self.gguf_writer.add_vision_preproc_max_tiles(self.max_dynamic_tiles)
 
     def tensor_force_quant(self, name, new_name, bid, n_dims):
         if ".position_embd." in new_name:
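Note on the converter change above: the new tile keys are written only when the HF config carries positive min/max_dynamic_patch values, and the loader side (in the following diffs) falls back to defaults when the keys are absent. A toy sketch of that optional-key handshake, with std::map standing in for the GGUF metadata and a made-up helper name:

#include <cstdio>
#include <map>
#include <string>

// read a key if present, otherwise return the caller's default
// (stand-in for the loader's get_u32(key, value, /*required=*/false) pattern)
static int read_u32_or(const std::map<std::string, int> & kv, const std::string & key, int def) {
    auto it = kv.find(key);
    return it == kv.end() ? def : it->second;
}

int main() {
    std::map<std::string, int> gguf; // pretend this came from an older InternVL mmproj file

    int min_tiles = read_u32_or(gguf, "clip.vision.preproc_min_tiles", 1);
    int max_tiles = read_u32_or(gguf, "clip.vision.preproc_max_tiles", 12);
    std::printf("min=%d max=%d\n", min_tiles, max_tiles); // -> min=1 max=12
    return 0;
}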
diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py
index c5f92c7700..9383644abf 100644
--- a/gguf-py/gguf/constants.py
+++ b/gguf-py/gguf/constants.py
@@ -301,6 +301,8 @@ class Keys:
         IMAGE_SIZE         = "clip.vision.image_size"
         IMAGE_MIN_PIXELS   = "clip.vision.image_min_pixels"
         IMAGE_MAX_PIXELS   = "clip.vision.image_max_pixels"
+        PREPROC_MIN_TILES  = "clip.vision.preproc_min_tiles"
+        PREPROC_MAX_TILES  = "clip.vision.preproc_max_tiles"
         PREPROC_IMAGE_SIZE = "clip.vision.preproc_image_size"
         PATCH_SIZE         = "clip.vision.patch_size"
         EMBEDDING_LENGTH   = "clip.vision.embedding_length"
diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py
index 5f653d386d..010dfeea1c 100644
--- a/gguf-py/gguf/gguf_writer.py
+++ b/gguf-py/gguf/gguf_writer.py
@@ -1156,6 +1156,12 @@ class GGUFWriter:
     def add_vision_min_pixels(self, value: int) -> None:
         self.add_uint32(Keys.ClipVision.IMAGE_MIN_PIXELS, value)
 
+    def add_vision_preproc_max_tiles(self, value: int) -> None:
+        self.add_uint32(Keys.ClipVision.PREPROC_MAX_TILES, value)
+
+    def add_vision_preproc_min_tiles(self, value: int) -> None:
+        self.add_uint32(Keys.ClipVision.PREPROC_MIN_TILES, value)
+
     def add_vision_preproc_image_size(self, value: int) -> None:
         self.add_uint32(Keys.ClipVision.PREPROC_IMAGE_SIZE, value)
diff --git a/tools/mtmd/clip-impl.h b/tools/mtmd/clip-impl.h
index 3eb66f9145..bf55cec7ef 100644
--- a/tools/mtmd/clip-impl.h
+++ b/tools/mtmd/clip-impl.h
@@ -38,6 +38,8 @@
 #define KEY_IMAGE_SIZE         "clip.vision.image_size"
 #define KEY_IMAGE_MIN_PIXELS   "clip.vision.image_min_pixels"
 #define KEY_IMAGE_MAX_PIXELS   "clip.vision.image_max_pixels"
+#define KEY_PREPROC_MIN_TILES  "clip.vision.preproc_min_tiles"
+#define KEY_PREPROC_MAX_TILES  "clip.vision.preproc_max_tiles"
 #define KEY_PREPROC_IMAGE_SIZE "clip.vision.preproc_image_size"
 #define KEY_PATCH_SIZE         "clip.vision.patch_size"
 #define KEY_IMAGE_MEAN         "clip.vision.image_mean"
diff --git a/tools/mtmd/clip-model.h b/tools/mtmd/clip-model.h
index eeb8da58e0..265a17130f 100644
--- a/tools/mtmd/clip-model.h
+++ b/tools/mtmd/clip-model.h
@@ -42,6 +42,9 @@ struct clip_hparams {
     int32_t image_max_pixels = -1;
     int32_t n_merge = 0; // number of patch merges **per-side**
 
+    int32_t preproc_min_tiles = 0;
+    int32_t preproc_max_tiles = 0;
+
     float image_mean[3];
     float image_std[3];
diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp
index 5fcc7c5b59..a47f1f495d 100644
--- a/tools/mtmd/clip.cpp
+++ b/tools/mtmd/clip.cpp
@@ -1138,6 +1138,16 @@ struct clip_model_loader {
                     }
                 } break;
             case PROJECTOR_TYPE_INTERNVL:
+                {
+                    // older version of internvl doesn't have min/max tiles, we need to provide default values for them to avoid issues
+                    hparams.preproc_min_tiles = 1;
+                    hparams.preproc_max_tiles = 12;
+                    get_u32(KEY_PROJ_SCALE_FACTOR, hparams.n_merge, false);
+                    get_u32(KEY_PREPROC_MIN_TILES, hparams.preproc_min_tiles, false);
+                    get_u32(KEY_PREPROC_MAX_TILES, hparams.preproc_max_tiles, false);
+                    GGML_ASSERT(hparams.preproc_min_tiles <= hparams.preproc_max_tiles && hparams.preproc_max_tiles < INT32_MAX);
+                    set_internvl_dhr_res_candidates(model);
+                } break;
             case PROJECTOR_TYPE_NEMOTRON_V2_VL:
                 {
                     get_u32(KEY_PROJ_SCALE_FACTOR, hparams.n_merge, false);
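Note: the next hunk adds set_internvl_dhr_res_candidates(), which enumerates every a x b tile grid whose tile count stays within [preproc_min_tiles, preproc_max_tiles]. A standalone replica of that loop, useful for checking what the defaults produce (the 1/12 defaults are taken from the hunk above; image_size scaling is omitted here):

#include <algorithm>
#include <cstdio>

int main() {
    const int min_num = 1;  // hparams.preproc_min_tiles default
    const int max_num = 12; // hparams.preproc_max_tiles default

    int count = 0;
    for (int a = min_num; a <= max_num; ++a) {
        int b_lo = (min_num + a - 1) / a; // smallest b with a*b >= min_num
        int b_hi = max_num / a;           // largest  b with a*b <= max_num
        b_lo = std::max(b_lo, min_num);
        b_hi = std::min(b_hi, max_num);
        for (int b = b_lo; b <= b_hi; ++b) {
            // each a x b grid later becomes a resolution candidate of
            // (a*image_size) x (b*image_size)
            std::printf("%2d x %2d tiles\n", a, b);
            count++;
        }
    }
    std::printf("total candidates: %d\n", count); // 35 grids for the 1..12 defaults
    return 0;
}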
@@ -2188,6 +2198,27 @@ struct clip_model_loader {
             }
         }
     }
+
+    static void set_internvl_dhr_res_candidates(clip_model & model) {
+        auto & hparams = model.hparams;
+        int min_num = hparams.preproc_min_tiles;
+        int max_num = hparams.preproc_max_tiles;
+        if (min_num < 1) {
+            return; // avoid divide by 0
+        }
+        for (int a = min_num; a <= max_num; ++a) {
+            int b_lo = (min_num + a - 1) / a;
+            int b_hi = max_num / a;
+            b_lo = std::max(b_lo, min_num);
+            b_hi = std::min(b_hi, max_num);
+            for (int b = b_lo; b <= b_hi; ++b) {
+                hparams.image_res_candidates.push_back(clip_image_size {
+                    a*hparams.image_size,
+                    b*hparams.image_size,
+                });
+            }
+        }
+    }
 };
 
 struct clip_init_result clip_init(const char * fname, struct clip_context_params ctx_params) {
@@ -2734,17 +2765,22 @@ struct llava_uhd {
         return res;
     }
 
-    static std::vector<clip_image_u8_ptr> slice_image(const clip_image_u8 * img, const slice_instructions & inst) {
+    static std::vector<clip_image_u8_ptr> slice_image(const clip_image_u8 * img, const slice_instructions & inst, bool overview_first = true) {
         std::vector<clip_image_u8_ptr> output;
 
         // resize to overview size
         clip_image_u8_ptr resized_img(clip_image_u8_init());
         img_tool::resize(*img, *resized_img, inst.overview_size, inst.interpolation_overview, inst.padding_overview, inst.pad_color_overview);
-        output.push_back(std::move(resized_img));
+        if (overview_first) {
+            output.push_back(std::move(resized_img));
+        }
 
         if (inst.slices.empty()) {
             // no slices, just return the resized image
+            if (!overview_first) {
+                output.push_back(std::move(resized_img));
+            }
             return output;
         }
 
@@ -2765,6 +2801,10 @@ struct llava_uhd {
             output.push_back(std::move(img_slice));
         }
 
+        if (!overview_first) {
+            output.push_back(std::move(resized_img));
+        }
+
         return output;
     }
 
@@ -3149,10 +3189,20 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
                 res_imgs->grid_x = instructions.grid_size.width;
                 res_imgs->grid_y = instructions.grid_size.height;
             } break;
+        case PROJECTOR_TYPE_INTERNVL: // support dynamic high-resolution
+            {
+                GGML_ASSERT(!params.image_res_candidates.empty());
+                auto const inst = llava_uhd::get_slice_instructions(ctx, original_size);
+                std::vector<clip_image_u8_ptr> imgs = llava_uhd::slice_image(img, inst, false);
+                for (size_t i = 0; i < imgs.size(); ++i) {
+                    clip_image_f32_ptr res(clip_image_f32_init());
+                    normalize_image_u8_to_f32(*imgs[i], *res, params.image_mean, params.image_std);
+                    res_imgs->entries.push_back(std::move(res));
+                }
+            } break;
         case PROJECTOR_TYPE_GLM_EDGE:
         case PROJECTOR_TYPE_GEMMA3:
-        case PROJECTOR_TYPE_INTERNVL: // TODO @ngxson : support dynamic resolution
         case PROJECTOR_TYPE_NEMOTRON_V2_VL:
             {
                 clip_image_u8 resized_image;
diff --git a/tools/mtmd/mtmd.cpp b/tools/mtmd/mtmd.cpp
index f66c07345e..456ce7b73c 100644
--- a/tools/mtmd/mtmd.cpp
+++ b/tools/mtmd/mtmd.cpp
@@ -851,13 +851,15 @@ int32_t mtmd_encode(mtmd_context * ctx, const mtmd_image_tokens * image_tokens)
         LOG_ERR("%s: this API does not support non-vision input, please use mtmd_encode_chunk instead\n", __func__);
         return 1;
     }
+    auto proj_type = clip_get_projector_type(ctx_clip);
     int n_mmproj_embd = clip_n_mmproj_embd(ctx_clip);
     ctx->image_embd_v.resize(image_tokens->n_tokens() * n_mmproj_embd);
     bool ok = false;
 
     if (clip_is_llava(ctx_clip)
             || clip_is_minicpmv(ctx_clip)
-            || clip_is_glm(ctx_clip)) {
+            || clip_is_glm(ctx_clip)
+            || proj_type == PROJECTOR_TYPE_INTERNVL) {
         // TODO @ngxson : llava does not support batched encoding ; this should be fixed inside clip_image_batch_encode()
         const auto & entries = image_tokens->batch_f32.entries;
         for (size_t i = 0; i < entries.size(); i++) {