metal : update sum_rows kernel to support float4 (#19524)

This commit is contained in:
Georgi Gerganov 2026-02-12 11:35:28 +02:00 committed by GitHub
parent 6845f7f87f
commit 3b3a948134
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 106 additions and 51 deletions

View File

@ -328,31 +328,46 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_sum(ggml_metal_l
} }
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_sum_rows(ggml_metal_library_t lib, const ggml_tensor * op) { ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_sum_rows(ggml_metal_library_t lib, const ggml_tensor * op) {
GGML_ASSERT(op->src[0]->nb[0] == ggml_type_size(op->src[0]->type)); GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
char base[256]; char base[256];
char name[256]; char name[256];
const char * op_str = "undefined"; int op_num = -1;
switch (op->op) { switch (op->op) {
case GGML_OP_SUM_ROWS: case GGML_OP_SUM_ROWS: op_num = OP_SUM_ROWS_NUM_SUM_ROWS; break;
op_str = "sum_rows"; break; case GGML_OP_MEAN: op_num = OP_SUM_ROWS_NUM_MEAN; break;
case GGML_OP_MEAN:
op_str = "mean"; break;
default: GGML_ABORT("fatal error"); default: GGML_ABORT("fatal error");
}; };
snprintf(base, 256, "kernel_%s_%s", op_str, ggml_type_name(op->src[0]->type)); const char * t0_str = ggml_type_name(op->src[0]->type);
const char * t_str = ggml_type_name(op->type);
snprintf(name, 256, "%s", base); const bool is_c4 = op->src[0]->ne[0] % 4 == 0;
snprintf(base, 256, "kernel_sum_rows_%s_%s%s", t0_str, t_str, is_c4 ? "_4" : "");
snprintf(name, 256, "%s_op=%d", base, op_num);
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
if (!res.pipeline) { if (!res.pipeline) {
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); ggml_metal_cv_t cv = ggml_metal_cv_init();
ggml_metal_cv_set_int16(cv, op_num, FC_SUM_ROWS + 0);
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
ggml_metal_cv_free(cv);
} }
res.smem = 32*sizeof(float); res.smem = 32*sizeof(float);
if (is_c4) {
res.smem *= 4;
}
res.c4 = is_c4;
return res; return res;
} }

View File

@ -82,6 +82,7 @@
#define FC_COUNT_EQUAL 1100 #define FC_COUNT_EQUAL 1100
#define FC_UNARY 1200 #define FC_UNARY 1200
#define FC_BIN 1300 #define FC_BIN 1300
#define FC_SUM_ROWS 1400
// op-specific constants // op-specific constants
#define OP_FLASH_ATTN_EXT_NQPSG 8 #define OP_FLASH_ATTN_EXT_NQPSG 8
@ -118,6 +119,8 @@
#define OP_UNARY_NUM_SOFTPLUS 115 #define OP_UNARY_NUM_SOFTPLUS 115
#define OP_UNARY_NUM_EXPM1 116 #define OP_UNARY_NUM_EXPM1 116
#define OP_SUM_ROWS_NUM_SUM_ROWS 10
#define OP_SUM_ROWS_NUM_MEAN 11
// kernel argument structs // kernel argument structs
// //

View File

@ -904,6 +904,11 @@ int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) {
GGML_TENSOR_LOCALS( int32_t, ne, op, ne); GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
ggml_metal_kargs_sum_rows args = { ggml_metal_kargs_sum_rows args = {
/*.ne00 =*/ ne00, /*.ne00 =*/ ne00,
/*.ne01 =*/ ne01, /*.ne01 =*/ ne01,
@ -925,21 +930,26 @@ int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) {
auto pipeline = ggml_metal_library_get_pipeline_sum_rows(lib, op); auto pipeline = ggml_metal_library_get_pipeline_sum_rows(lib, op);
if (pipeline.c4) {
args.ne00 = ne00/4;
args.ne0 = ne0/4;
}
int nth = 32; // SIMD width int nth = 32; // SIMD width
while (nth < ne00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) { while (nth < args.ne00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
nth *= 2; nth *= 2;
} }
nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
nth = std::min(nth, ne00); nth = std::min(nth, (int) args.ne00);
const size_t smem = pipeline.smem; const size_t smem = pipeline.smem;
ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);

View File

@ -77,6 +77,14 @@ static inline float dot(float x, float y) {
return x*y; return x*y;
} }
static inline float sum(float x) {
return x;
}
static inline float sum(float4 x) {
return x[0] + x[1] + x[2] + x[3];
}
// NOTE: this is not dequantizing - we are simply fitting the template // NOTE: this is not dequantizing - we are simply fitting the template
template <typename type4x4> template <typename type4x4>
void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) { void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) {
@ -1501,33 +1509,35 @@ kernel void kernel_op_sum_f32(
} }
} }
template <bool norm> constant short FC_sum_rows_op [[function_constant(FC_SUM_ROWS + 0)]];
kernel void kernel_sum_rows(
template <typename T0, typename T>
kernel void kernel_sum_rows_impl(
constant ggml_metal_kargs_sum_rows & args, constant ggml_metal_kargs_sum_rows & args,
device const float * src0, device const char * src0,
device float * dst, device char * dst,
threadgroup float * shmem_f32 [[threadgroup(0)]], threadgroup char * shmem [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]], uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]], ushort3 tpitg[[thread_position_in_threadgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]],
ushort tiisg[[thread_index_in_simdgroup]], ushort tiisg[[thread_index_in_simdgroup]],
ushort3 ntg[[threads_per_threadgroup]]) { ushort3 ntg[[threads_per_threadgroup]]) {
int64_t i3 = tgpig.z; #define FC_OP FC_sum_rows_op
int64_t i2 = tgpig.y;
int64_t i1 = tgpig.x;
if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01) { const int i3 = tgpig.z;
return; const int i2 = tgpig.y;
} const int i1 = tgpig.x;
threadgroup T0 * shmem_t = (threadgroup T0 *) shmem;
if (sgitg == 0) { if (sgitg == 0) {
shmem_f32[tiisg] = 0.0f; shmem_t[tiisg] = 0.0f;
} }
device const float * src_row = (device const float *) ((device const char *) src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03); device const T0 * src_row = (device const T0 *) (src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03);
device float * dst_row = (device float *) ((device char *) dst + i1*args.nb1 + i2*args.nb2 + i3*args.nb3); device T * dst_row = (device T *) (dst + i1*args.nb1 + i2*args.nb2 + i3*args.nb3);
float sumf = 0; T0 sumf = T0(0.0f);
for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) { for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) {
sumf += src_row[i0]; sumf += src_row[i0];
@ -1538,23 +1548,33 @@ kernel void kernel_sum_rows(
threadgroup_barrier(mem_flags::mem_threadgroup); threadgroup_barrier(mem_flags::mem_threadgroup);
if (tiisg == 0) { if (tiisg == 0) {
shmem_f32[sgitg] = sumf; shmem_t[sgitg] = sumf;
} }
threadgroup_barrier(mem_flags::mem_threadgroup); threadgroup_barrier(mem_flags::mem_threadgroup);
sumf = shmem_f32[tiisg]; sumf = shmem_t[tiisg];
sumf = simd_sum(sumf); sumf = simd_sum(sumf);
if (tpitg.x == 0) { if (tpitg.x == 0) {
dst_row[0] = norm ? sumf / args.ne00 : sumf; if (FC_OP == OP_SUM_ROWS_NUM_MEAN) {
if (is_same<float4, T0>::value) {
dst_row[0] = sum(sumf) / (4*args.ne00);
} else {
dst_row[0] = sum(sumf) / args.ne00;
}
} else {
dst_row[0] = sum(sumf);
}
} }
#undef FC_OP
} }
typedef decltype(kernel_sum_rows<false>) kernel_sum_rows_t; typedef decltype(kernel_sum_rows_impl<float, float>) kernel_sum_rows_t;
template [[host_name("kernel_sum_rows_f32")]] kernel kernel_sum_rows_t kernel_sum_rows<false>; template [[host_name("kernel_sum_rows_f32_f32")]] kernel kernel_sum_rows_t kernel_sum_rows_impl<float, float>;
template [[host_name("kernel_mean_f32")]] kernel kernel_sum_rows_t kernel_sum_rows<true>; template [[host_name("kernel_sum_rows_f32_f32_4")]] kernel kernel_sum_rows_t kernel_sum_rows_impl<float4, float>;
template<typename T> template<typename T>
kernel void kernel_cumsum_blk( kernel void kernel_cumsum_blk(
@ -2435,9 +2455,6 @@ kernel void kernel_solve_tri_f32(
const short K = FC_solve_tri_k; const short K = FC_solve_tri_k;
const short NP = PAD2(N, NW); const short NP = PAD2(N, NW);
const int32_t ne02 = args.ne02;
const int32_t ne03 = args.ne03;
const int32_t i03 = tgpig.z; const int32_t i03 = tgpig.z;
const int32_t i02 = tgpig.y; const int32_t i02 = tgpig.y;
const int32_t i01 = tgpig.x*NSG + sgitg; const int32_t i01 = tgpig.x*NSG + sgitg;
@ -5949,7 +5966,7 @@ kernel void kernel_flash_attn_ext_vec(
static_assert(DK4 % NL == 0, "DK4 must be divisible by NL"); static_assert(DK4 % NL == 0, "DK4 must be divisible by NL");
static_assert(DV4 % NL == 0, "DV4 must be divisible by NL"); static_assert(DV4 % NL == 0, "DV4 must be divisible by NL");
const short T = PK + NSG*SH; // shared memory size per query in (half) //const short T = PK + NSG*SH; // shared memory size per query in (half)
//threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*PK); // holds the query data //threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*PK); // holds the query data
threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*PK); // same as above but in q4_t threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*PK); // same as above but in q4_t
@ -8537,7 +8554,9 @@ kernel void kernel_mul_mm(
threadgroup S0 * sa = (threadgroup S0 *)(shmem); threadgroup S0 * sa = (threadgroup S0 *)(shmem);
threadgroup S1 * sb = (threadgroup S1 *)(shmem + 4096); threadgroup S1 * sb = (threadgroup S1 *)(shmem + 4096);
#ifdef GGML_METAL_HAS_TENSOR
threadgroup float * sc = (threadgroup float *)(shmem); threadgroup float * sc = (threadgroup float *)(shmem);
#endif
constexpr int NR0 = 64; constexpr int NR0 = 64;
constexpr int NR1 = 32; constexpr int NR1 = 32;
@ -8660,8 +8679,8 @@ kernel void kernel_mul_mm(
const short sx = (tiitg%NL1); const short sx = (tiitg%NL1);
const short sy = (tiitg/NL1)/8; const short sy = (tiitg/NL1)/8;
const short dx = sx; //const short dx = sx;
const short dy = sy; //const short dy = sy;
const short ly = (tiitg/NL1)%8; const short ly = (tiitg/NL1)%8;
@ -8910,7 +8929,9 @@ kernel void kernel_mul_mm_id(
threadgroup S0 * sa = (threadgroup S0 *)(shmem); threadgroup S0 * sa = (threadgroup S0 *)(shmem);
threadgroup S1 * sb = (threadgroup S1 *)(shmem + 4096); threadgroup S1 * sb = (threadgroup S1 *)(shmem + 4096);
#ifdef GGML_METAL_HAS_TENSOR
threadgroup float * sc = (threadgroup float *)(shmem); threadgroup float * sc = (threadgroup float *)(shmem);
#endif
constexpr int NR0 = 64; constexpr int NR0 = 64;
constexpr int NR1 = 32; constexpr int NR1 = 32;
@ -9045,8 +9066,8 @@ kernel void kernel_mul_mm_id(
const short sx = (tiitg%NL1); const short sx = (tiitg%NL1);
const short sy = (tiitg/NL1)/8; const short sy = (tiitg/NL1)/8;
const short dx = sx; //const short dx = sx;
const short dy = sy; //const short dy = sy;
const short ly = (tiitg/NL1)%8; const short ly = (tiitg/NL1)%8;

View File

@ -8132,24 +8132,30 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
} }
test_cases.emplace_back(new test_sum()); test_cases.emplace_back(new test_sum());
test_cases.emplace_back(new test_sum_rows());
test_cases.emplace_back(new test_sum(GGML_TYPE_F32, {11, 5, 6, 3}, {0, 2, 1, 3})); // row-contiguous but non-contiguous test_cases.emplace_back(new test_sum(GGML_TYPE_F32, {11, 5, 6, 3}, {0, 2, 1, 3})); // row-contiguous but non-contiguous
test_cases.emplace_back(new test_sum(GGML_TYPE_F32, {11, 5, 6, 3}, {0, 3, 2, 1})); test_cases.emplace_back(new test_sum(GGML_TYPE_F32, {11, 5, 6, 3}, {0, 3, 2, 1}));
test_cases.emplace_back(new test_sum(GGML_TYPE_F32, {11, 5, 6, 3}, {0, 1, 3, 2})); test_cases.emplace_back(new test_sum(GGML_TYPE_F32, {11, 5, 6, 3}, {0, 1, 3, 2}));
test_cases.emplace_back(new test_mean());
test_cases.emplace_back(new test_mean(GGML_TYPE_F32, { 33, 1, 1, 1 }));
test_cases.emplace_back(new test_mean(GGML_TYPE_F32, { 33, 256, 1, 1 }));
test_cases.emplace_back(new test_mean(GGML_TYPE_F32, { 32769, 1, 1, 1 }));
test_cases.emplace_back(new test_mean(GGML_TYPE_F32, { 32, 1, 1, 1 }));
test_cases.emplace_back(new test_mean(GGML_TYPE_F32, { 32, 256, 1, 1 }));
test_cases.emplace_back(new test_mean(GGML_TYPE_F32, { 32768, 1, 1, 1 }));
test_cases.emplace_back(new test_sum(GGML_TYPE_F32, { 33, 1, 1, 1 }));
test_cases.emplace_back(new test_sum(GGML_TYPE_F32, { 33, 1024, 1, 1 }));
test_cases.emplace_back(new test_sum(GGML_TYPE_F32, { 33, 256, 1, 1 }));
test_cases.emplace_back(new test_sum(GGML_TYPE_F32, { 33, 256, 1, 1 }, { 1, 0, 2, 3 })); // sum dst not-contiguous
test_cases.emplace_back(new test_sum_rows());
test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 11, 5, 6, 3 }, true, false)); test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 11, 5, 6, 3 }, true, false));
test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 11, 5, 6, 3 }, false, true)); test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 11, 5, 6, 3 }, false, true));
test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 11, 5, 6, 3 }, true, true)); test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 11, 5, 6, 3 }, true, true));
test_cases.emplace_back(new test_mean()); test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 16, 5, 6, 3 }, true, false));
test_cases.emplace_back(new test_sum(GGML_TYPE_F32, { 33, 1, 1, 1 })); test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 16, 5, 6, 3 }, false, true));
test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 16, 5, 6, 3 }, true, true));
test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 33, 1, 1, 1 })); test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 33, 1, 1, 1 }));
test_cases.emplace_back(new test_mean(GGML_TYPE_F32, { 33, 1, 1, 1 }));
test_cases.emplace_back(new test_sum(GGML_TYPE_F32, { 33, 1024, 1, 1 }));
test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 33, 1024, 1, 1 })); test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 33, 1024, 1, 1 }));
test_cases.emplace_back(new test_sum(GGML_TYPE_F32, { 33, 256, 1, 1 }));
test_cases.emplace_back(new test_sum(GGML_TYPE_F32, { 33, 256, 1, 1 }, { 1, 0, 2, 3 })); // sum dst not-contiguous
test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 33, 256, 1, 1 })); test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 33, 256, 1, 1 }));
test_cases.emplace_back(new test_mean(GGML_TYPE_F32, { 33, 256, 1, 1 }));
test_cases.emplace_back(new test_mean(GGML_TYPE_F32, { 32769, 1, 1, 1 }));
test_cases.emplace_back(new test_group_norm(GGML_TYPE_F32, {64, 64, 320, 1})); test_cases.emplace_back(new test_group_norm(GGML_TYPE_F32, {64, 64, 320, 1}));
test_cases.emplace_back(new test_group_norm(GGML_TYPE_F32, {9, 9, 1280, 1})); test_cases.emplace_back(new test_group_norm(GGML_TYPE_F32, {9, 9, 1280, 1}));
test_cases.emplace_back(new test_group_norm_mul_add(GGML_TYPE_F32, {64, 64, 320, 1})); test_cases.emplace_back(new test_group_norm_mul_add(GGML_TYPE_F32, {64, 64, 320, 1}));