From 3b3a9481341ce475089efaa25151508a961fe217 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Thu, 12 Feb 2026 11:35:28 +0200
Subject: [PATCH] metal : update sum_rows kernel to support float4 (#19524)

---
 ggml/src/ggml-metal/ggml-metal-device.cpp | 33 +++++++---
 ggml/src/ggml-metal/ggml-metal-impl.h     |  3 +
 ggml/src/ggml-metal/ggml-metal-ops.cpp    | 18 ++++--
 ggml/src/ggml-metal/ggml-metal.metal      | 79 ++++++++++++++---------
 tests/test-backend-ops.cpp                | 24 ++++---
 5 files changed, 106 insertions(+), 51 deletions(-)

diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp
index 517559d12a..06f3d80459 100644
--- a/ggml/src/ggml-metal/ggml-metal-device.cpp
+++ b/ggml/src/ggml-metal/ggml-metal-device.cpp
@@ -328,31 +328,46 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_sum(ggml_metal_l
 }
 
 ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_sum_rows(ggml_metal_library_t lib, const ggml_tensor * op) {
-    GGML_ASSERT(op->src[0]->nb[0] == ggml_type_size(op->src[0]->type));
+    GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
 
     char base[256];
     char name[256];
 
-    const char * op_str = "undefined";
+    int op_num = -1;
+
     switch (op->op) {
-        case GGML_OP_SUM_ROWS:
-            op_str = "sum_rows"; break;
-        case GGML_OP_MEAN:
-            op_str = "mean"; break;
+        case GGML_OP_SUM_ROWS: op_num = OP_SUM_ROWS_NUM_SUM_ROWS; break;
+        case GGML_OP_MEAN:     op_num = OP_SUM_ROWS_NUM_MEAN;     break;
         default: GGML_ABORT("fatal error");
     };
 
-    snprintf(base, 256, "kernel_%s_%s", op_str, ggml_type_name(op->src[0]->type));
+    const char * t0_str = ggml_type_name(op->src[0]->type);
+    const char * t_str  = ggml_type_name(op->type);
 
-    snprintf(name, 256, "%s", base);
+    const bool is_c4 = op->src[0]->ne[0] % 4 == 0;
+
+    snprintf(base, 256, "kernel_sum_rows_%s_%s%s", t0_str, t_str, is_c4 ? "_4" : "");
"_4" : ""); + snprintf(name, 256, "%s_op=%d", base, op_num); ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); if (!res.pipeline) { - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + ggml_metal_cv_t cv = ggml_metal_cv_init(); + + ggml_metal_cv_set_int16(cv, op_num, FC_SUM_ROWS + 0); + + res = ggml_metal_library_compile_pipeline(lib, base, name, cv); + + ggml_metal_cv_free(cv); } res.smem = 32*sizeof(float); + if (is_c4) { + res.smem *= 4; + } + + res.c4 = is_c4; + return res; } diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h index 952e1be076..383e0d6e93 100644 --- a/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -82,6 +82,7 @@ #define FC_COUNT_EQUAL 1100 #define FC_UNARY 1200 #define FC_BIN 1300 +#define FC_SUM_ROWS 1400 // op-specific constants #define OP_FLASH_ATTN_EXT_NQPSG 8 @@ -118,6 +119,8 @@ #define OP_UNARY_NUM_SOFTPLUS 115 #define OP_UNARY_NUM_EXPM1 116 +#define OP_SUM_ROWS_NUM_SUM_ROWS 10 +#define OP_SUM_ROWS_NUM_MEAN 11 // kernel argument structs // diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index 7db95d1c84..20880d9551 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -904,6 +904,11 @@ int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) { GGML_TENSOR_LOCALS( int32_t, ne, op, ne); GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); + GGML_ASSERT(ggml_is_contiguous_rows(op->src[0])); + + ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]); + ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op); + ggml_metal_kargs_sum_rows args = { /*.ne00 =*/ ne00, /*.ne01 =*/ ne01, @@ -925,21 +930,26 @@ int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) { auto pipeline = ggml_metal_library_get_pipeline_sum_rows(lib, op); + if (pipeline.c4) { + args.ne00 = ne00/4; + args.ne0 = ne0/4; + } + int nth = 32; // SIMD width - while (nth < ne00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) { + while (nth < args.ne00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) { nth *= 2; } nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); - nth = std::min(nth, ne00); + nth = std::min(nth, (int) args.ne00); const size_t smem = pipeline.smem; ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); - ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); - ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); + ggml_metal_encoder_set_buffer (enc, bid_src0, 1); + ggml_metal_encoder_set_buffer (enc, bid_dst, 2); ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 0036ba90ec..6c349aa0c9 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -77,6 +77,14 @@ static inline float dot(float x, float y) { return x*y; } +static inline float sum(float x) { + return x; +} + +static inline float sum(float4 x) { + return x[0] + x[1] + x[2] + x[3]; +} + // NOTE: this is not dequantizing - we are simply fitting the template template void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) { @@ -1501,33 +1509,35 @@ kernel void kernel_op_sum_f32( } } -template -kernel void kernel_sum_rows( +constant short FC_sum_rows_op [[function_constant(FC_SUM_ROWS + 0)]]; 
+
+template <typename T0, typename T>
+kernel void kernel_sum_rows_impl(
         constant ggml_metal_kargs_sum_rows & args,
-        device const float * src0,
-        device float * dst,
-        threadgroup float * shmem_f32 [[threadgroup(0)]],
+        device const char * src0,
+        device char * dst,
+        threadgroup char * shmem [[threadgroup(0)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
         ushort3 tpitg[[thread_position_in_threadgroup]],
         ushort sgitg[[simdgroup_index_in_threadgroup]],
         ushort tiisg[[thread_index_in_simdgroup]],
         ushort3 ntg[[threads_per_threadgroup]]) {
-    int64_t i3 = tgpig.z;
-    int64_t i2 = tgpig.y;
-    int64_t i1 = tgpig.x;
+#define FC_OP FC_sum_rows_op
 
-    if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01) {
-        return;
-    }
+    const int i3 = tgpig.z;
+    const int i2 = tgpig.y;
+    const int i1 = tgpig.x;
+
+    threadgroup T0 * shmem_t = (threadgroup T0 *) shmem;
 
     if (sgitg == 0) {
-        shmem_f32[tiisg] = 0.0f;
+        shmem_t[tiisg] = 0.0f;
     }
 
-    device const float * src_row = (device const float *) ((device const char *) src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03);
-    device       float * dst_row = (device       float *) ((device       char *) dst  + i1*args.nb1  + i2*args.nb2  + i3*args.nb3);
+    device const T0 * src_row = (device const T0 *) (src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03);
+    device       T  * dst_row = (device       T  *) (dst  + i1*args.nb1  + i2*args.nb2  + i3*args.nb3);
 
-    float sumf = 0;
+    T0 sumf = T0(0.0f);
 
     for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) {
         sumf += src_row[i0];
@@ -1538,23 +1548,33 @@ kernel void kernel_sum_rows(
     threadgroup_barrier(mem_flags::mem_threadgroup);
 
     if (tiisg == 0) {
-        shmem_f32[sgitg] = sumf;
+        shmem_t[sgitg] = sumf;
     }
 
     threadgroup_barrier(mem_flags::mem_threadgroup);
 
-    sumf = shmem_f32[tiisg];
+    sumf = shmem_t[tiisg];
     sumf = simd_sum(sumf);
 
     if (tpitg.x == 0) {
-        dst_row[0] = norm ? sumf / args.ne00 : sumf;
+        if (FC_OP == OP_SUM_ROWS_NUM_MEAN) {
+            if (is_same<T0, float4>::value) {
+                dst_row[0] = sum(sumf) / (4*args.ne00);
+            } else {
+                dst_row[0] = sum(sumf) / args.ne00;
+            }
+        } else {
+            dst_row[0] = sum(sumf);
+        }
     }
+
+#undef FC_OP
 }
 
-typedef decltype(kernel_sum_rows<false>) kernel_sum_rows_t;
+typedef decltype(kernel_sum_rows_impl<float, float>) kernel_sum_rows_t;
 
-template [[host_name("kernel_sum_rows_f32")]] kernel kernel_sum_rows_t kernel_sum_rows<false>;
-template [[host_name("kernel_mean_f32")]]     kernel kernel_sum_rows_t kernel_sum_rows<true>;
+template [[host_name("kernel_sum_rows_f32_f32")]]   kernel kernel_sum_rows_t kernel_sum_rows_impl<float,  float>;
+template [[host_name("kernel_sum_rows_f32_f32_4")]] kernel kernel_sum_rows_t kernel_sum_rows_impl<float4, float>;
 
 template <typename T>
 kernel void kernel_cumsum_blk(
@@ -2435,9 +2455,6 @@ kernel void kernel_solve_tri_f32(
     const short K = FC_solve_tri_k;
     const short NP = PAD2(N, NW);
 
-    const int32_t ne02 = args.ne02;
-    const int32_t ne03 = args.ne03;
-
     const int32_t i03 = tgpig.z;
     const int32_t i02 = tgpig.y;
     const int32_t i01 = tgpig.x*NSG + sgitg;
@@ -5949,7 +5966,7 @@ kernel void kernel_flash_attn_ext_vec(
     static_assert(DK4 % NL == 0, "DK4 must be divisible by NL");
     static_assert(DV4 % NL == 0, "DV4 must be divisible by NL");
 
-    const short T = PK + NSG*SH; // shared memory size per query in (half)
+    //const short T = PK + NSG*SH; // shared memory size per query in (half)
 
     //threadgroup q_t  * sq  = (threadgroup q_t  *) (shmem_f16 + 0*PK); // holds the query data
     threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*PK); // same as above but in q4_t
@@ -8537,7 +8554,9 @@ kernel void kernel_mul_mm(
     threadgroup S0 * sa = (threadgroup S0 *)(shmem);
     threadgroup S1 * sb = (threadgroup S1 *)(shmem + 4096);
 
+#ifdef GGML_METAL_HAS_TENSOR
     threadgroup float * sc = (threadgroup float *)(shmem);
+#endif
 
     constexpr int NR0 = 64;
     constexpr int NR1 = 32;
@@ -8660,8 +8679,8 @@ kernel void kernel_mul_mm(
             const short sx = (tiitg%NL1);
             const short sy = (tiitg/NL1)/8;
 
-            const short dx = sx;
-            const short dy = sy;
+            //const short dx = sx;
+            //const short dy = sy;
 
             const short ly = (tiitg/NL1)%8;
@@ -8910,7 +8929,9 @@ kernel void kernel_mul_mm_id(
     threadgroup S0 * sa = (threadgroup S0 *)(shmem);
     threadgroup S1 * sb = (threadgroup S1 *)(shmem + 4096);
 
+#ifdef GGML_METAL_HAS_TENSOR
     threadgroup float * sc = (threadgroup float *)(shmem);
+#endif
 
     constexpr int NR0 = 64;
     constexpr int NR1 = 32;
@@ -9045,8 +9066,8 @@ kernel void kernel_mul_mm_id(
             const short sx = (tiitg%NL1);
             const short sy = (tiitg/NL1)/8;
 
-            const short dx = sx;
-            const short dy = sy;
+            //const short dx = sx;
+            //const short dy = sy;
 
             const short ly = (tiitg/NL1)%8;

diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index ed99c24516..222b935841 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -8132,24 +8132,30 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     }
 
     test_cases.emplace_back(new test_sum());
-    test_cases.emplace_back(new test_sum_rows());
     test_cases.emplace_back(new test_sum(GGML_TYPE_F32, {11, 5, 6, 3}, {0, 2, 1, 3})); // row-contiguous but non-contiguous
     test_cases.emplace_back(new test_sum(GGML_TYPE_F32, {11, 5, 6, 3}, {0, 3, 2, 1}));
     test_cases.emplace_back(new test_sum(GGML_TYPE_F32, {11, 5, 6, 3}, {0, 1, 3, 2}));
+    test_cases.emplace_back(new test_mean());
+    test_cases.emplace_back(new test_mean(GGML_TYPE_F32, { 33, 1, 1, 1 }));
+    test_cases.emplace_back(new test_mean(GGML_TYPE_F32, { 33, 256, 1, 1 }));
+    test_cases.emplace_back(new test_mean(GGML_TYPE_F32, { 32769, 1, 1, 1 }));
+    test_cases.emplace_back(new test_mean(GGML_TYPE_F32, { 32, 1, 1, 1 }));
+    test_cases.emplace_back(new test_mean(GGML_TYPE_F32, { 32, 256, 1, 1 }));
+    test_cases.emplace_back(new test_mean(GGML_TYPE_F32, { 32768, 1, 1, 1 }));
+    test_cases.emplace_back(new test_sum(GGML_TYPE_F32, { 33, 1, 1, 1 }));
+    test_cases.emplace_back(new test_sum(GGML_TYPE_F32, { 33, 1024, 1, 1 }));
+    test_cases.emplace_back(new test_sum(GGML_TYPE_F32, { 33, 256, 1, 1 }));
+    test_cases.emplace_back(new test_sum(GGML_TYPE_F32, { 33, 256, 1, 1 }, { 1, 0, 2, 3 })); // sum dst not-contiguous
+    test_cases.emplace_back(new test_sum_rows());
     test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 11, 5, 6, 3 }, true, false));
     test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 11, 5, 6, 3 }, false, true));
     test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 11, 5, 6, 3 }, true, true));
-    test_cases.emplace_back(new test_mean());
-    test_cases.emplace_back(new test_sum(GGML_TYPE_F32, { 33, 1, 1, 1 }));
+    test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 16, 5, 6, 3 }, true, false));
+    test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 16, 5, 6, 3 }, false, true));
+    test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 16, 5, 6, 3 }, true, true));
     test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 33, 1, 1, 1 }));
-    test_cases.emplace_back(new test_mean(GGML_TYPE_F32, { 33, 1, 1, 1 }));
-    test_cases.emplace_back(new test_sum(GGML_TYPE_F32, { 33, 1024, 1, 1 }));
     test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 33, 1024, 1, 1 }));
-    test_cases.emplace_back(new test_sum(GGML_TYPE_F32, { 33, 256, 1, 1 }));
-    test_cases.emplace_back(new test_sum(GGML_TYPE_F32, { 33, 256, 1, 1 }, { 1, 0, 2, 3 })); // sum dst not-contiguous
     test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 33, 256, 1, 1 }));
-    test_cases.emplace_back(new test_mean(GGML_TYPE_F32, { 33, 256, 1, 1 }));
-    test_cases.emplace_back(new test_mean(GGML_TYPE_F32, { 32769, 1, 1, 1 }));
 
     test_cases.emplace_back(new test_group_norm(GGML_TYPE_F32, {64, 64, 320, 1}));
     test_cases.emplace_back(new test_group_norm(GGML_TYPE_F32, {9, 9, 1280, 1}));
     test_cases.emplace_back(new test_group_norm_mul_add(GGML_TYPE_F32, {64, 64, 320, 1}));
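
Reviewer note (not part of the patch): the host-side dispatch now picks a "_4"
kernel variant when the row length ne00 is a multiple of 4, pre-divides
args.ne00 so the float4 kernel loops over vector elements, and quadruples the
threadgroup reduction buffer. The sketch below restates that selection logic
outside the backend; pick_sum_rows_variant is a hypothetical helper written
only for illustration, assuming f32 input and output as in the two
instantiated kernels:

    #include <cstdint>
    #include <cstdio>
    #include <string>

    // Hypothetical restatement of the dispatch logic added by this patch:
    // - rows whose length is a multiple of 4 use the "_4" (float4) variant,
    // - the kernel's loop bound (args.ne00) is pre-divided by 4 for it,
    // - the threadgroup scratch grows from 32 floats to 32 float4s.
    static std::string pick_sum_rows_variant(int64_t ne00, int64_t & n_loop, size_t & smem) {
        const bool is_c4 = ne00 % 4 == 0;

        n_loop = is_c4 ? ne00/4 : ne00;            // passed to the kernel as args.ne00
        smem   = 32*sizeof(float)*(is_c4 ? 4 : 1); // per-threadgroup simdgroup partial sums

        return std::string("kernel_sum_rows_f32_f32") + (is_c4 ? "_4" : "");
    }

    int main() {
        // the new test shapes: 33/32769 hit the scalar path, 32/32768 the float4 path
        for (const int64_t ne00 : {33, 32, 32768, 32769}) {
            int64_t n_loop = 0;
            size_t  smem   = 0;
            const std::string name = pick_sum_rows_variant(ne00, n_loop, smem);
            std::printf("ne00 = %6lld -> %-26s loop = %6lld, smem = %3zu bytes\n",
                        (long long) ne00, name.c_str(), (long long) n_loop, smem);
        }
        return 0;
    }

GGML_OP_MEAN reuses the same pipeline, selected at compile time via the
FC_SUM_ROWS + 0 function constant (mirroring the existing FC_UNARY/FC_BIN
pattern). That is also why the float4 mean path divides by 4*args.ne00: the
host already divided args.ne00 by 4, so 4*args.ne00 recovers the true row
length.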