metal: optimise `GGML_OP_SUM` (#16559)
* optimise GGML_OP_SUM
* add non-contiguous tests by permuting the input
* change tests to require full contiguity of OP_SUM
* cuda : add check GGML_OP_SUM

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:

parent 17304cbcc1
commit f4ce81c45e
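In brief: the Metal kernel for `GGML_OP_SUM` previously computed the whole reduction on a single thread (every other thread returned immediately). It now dispatches a full threadgroup, has each thread accumulate a strided partial sum, and combines the partials with two levels of `simd_sum` through threadgroup memory. Alongside that, the Metal backend claims the op only on devices with simdgroup reductions and only for fully contiguous inputs, the CUDA backend gains a row-contiguity check for `GGML_OP_SUM`, and the test suite adds permuted (non-contiguous) `test_sum` cases.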
```diff
@@ -3625,9 +3625,10 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_CONV_2D_DW:
         case GGML_OP_CONV_TRANSPOSE_2D:
         case GGML_OP_POOL_2D:
-        case GGML_OP_SUM:
         case GGML_OP_ACC:
             return true;
+        case GGML_OP_SUM:
+            return ggml_is_contiguous_rows(op->src[0]);
         case GGML_OP_ARGSORT:
             // TODO: Support arbitrary column width
             return op->src[0]->ne[0] <= 1024;
```
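On CUDA, `GGML_OP_SUM` leaves the unconditional `return true` group: the backend now claims the op only when the input's rows are packed. As a sketch of what that predicate means (an assumption about the semantics of `ggml_is_contiguous_rows`, not a copy of the ggml source), the innermost stride must equal the element size, while outer dimensions may be permuted or strided:

```cpp
#include "ggml.h"

// Assumed semantics of ggml_is_contiguous_rows (sketch, not the real source):
// elements within each row are adjacent in memory, i.e. nb[0] equals the
// element size, regardless of how the outer dimensions are strided.
static bool is_contiguous_rows_sketch(const struct ggml_tensor * t) {
    // edge cases (empty tensors, quantized block types) elided
    return t->nb[0] == ggml_type_size(t->type);
}
```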
```diff
@@ -662,6 +662,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
         case GGML_OP_LOG:
             return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
         case GGML_OP_SUM:
+            return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]);
         case GGML_OP_SUM_ROWS:
         case GGML_OP_MEAN:
         case GGML_OP_SOFT_MAX:
```
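The added `return` also ends the previous fall-through: `GGML_OP_SUM` no longer shares the `GGML_OP_SUM_ROWS`/`GGML_OP_MEAN`/`GGML_OP_SOFT_MAX` path. It is claimed only on devices with simdgroup reductions and only for fully contiguous inputs, since the new kernel indexes `src0` linearly.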
```diff
@@ -866,12 +866,25 @@ int ggml_metal_op_sum(ggml_metal_op_t ctx, int idx) {
 
     ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_sum(lib, op);
 
+    int nth = 32; // SIMD width
+
+    while (nth < (int) n && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
+        nth *= 2;
+    }
+
+    nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
+    nth = std::min(nth, (int) n);
+
+    const int nsg = (nth + 31) / 32;
+
     ggml_metal_encoder_set_pipeline(enc, pipeline);
     ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
     ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
     ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op), 2);
 
-    ggml_metal_encoder_dispatch_threadgroups(enc, 1, 1, 1, 1, 1, 1);
+    ggml_metal_encoder_set_threadgroup_memory_size(enc, nsg * sizeof(float), 0);
+
+    ggml_metal_encoder_dispatch_threadgroups(enc, 1, 1, 1, nth, 1, 1);
 
     return 1;
 }
```
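The host code now sizes the single threadgroup to the problem instead of launching one thread: start at one SIMD width (32), double until the thread count covers either the element count `n` or the pipeline's thread limit, then clamp and derive the number of simdgroups, which also fixes the threadgroup-memory scratchpad at `nsg * sizeof(float)` bytes. A standalone restatement of that arithmetic (`max_threads` standing in for `ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)`):

```cpp
#include <algorithm>
#include <cstdint>

// Restates the sizing logic above: the smallest power-of-two thread count
// (at least one 32-wide simdgroup) that covers n, clamped to the pipeline
// limit and to n itself; nsg simdgroups -> nsg floats of shared memory.
static void size_sum_dispatch(int64_t n, int max_threads, int & nth, int & nsg) {
    nth = 32; // SIMD width
    while (nth < (int) n && nth < max_threads) {
        nth *= 2;
    }
    nth = std::min(nth, max_threads);
    nth = std::min(nth, (int) n);

    nsg = (nth + 31) / 32;
}
```

For example, the `{33, 1024, 1, 1}` test case has n = 33·1024 = 33792 elements; with a typical limit of 1024 threads per threadgroup this yields nth = 1024 and nsg = 32, i.e. 128 bytes of threadgroup memory.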
```diff
@@ -1727,18 +1727,48 @@ kernel void kernel_op_sum_f32(
         constant ggml_metal_kargs_sum & args,
         device const float * src0,
         device float * dst,
-        ushort tiitg[[thread_index_in_threadgroup]]) {
+        threadgroup float * shmem_f32 [[threadgroup(0)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        ushort3 tpitg[[thread_position_in_threadgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort3 ntg[[threads_per_threadgroup]]) {
 
-    if (tiitg != 0) {
+    if (args.np == 0) {
         return;
     }
 
-    float acc = 0.0f;
-    for (ulong i = 0; i < args.np; ++i) {
-        acc += src0[i];
+    const uint nsg = (ntg.x + 31) / 32;
+
+    float sumf = 0;
+
+    for (int64_t i0 = tpitg.x; i0 < args.np; i0 += ntg.x) {
+        sumf += src0[i0];
     }
 
-    dst[0] = acc;
+    sumf = simd_sum(sumf);
+
+    if (tiisg == 0) {
+        shmem_f32[sgitg] = sumf;
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    float total = 0;
+
+    if (sgitg == 0) {
+        float v = 0;
+
+        if (tpitg.x < nsg) {
+            v = shmem_f32[tpitg.x];
+        }
+
+        total = simd_sum(v);
+
+        if (tpitg.x == 0) {
+            dst[0] = total;
+        }
+    }
 }
 
 template <bool norm>
```
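The kernel itself changes from a serial loop on thread 0 (all other threads returned immediately) to a two-level reduction: each thread accumulates a private partial over a strided slice of `src0`, `simd_sum` folds the 32 lanes of each simdgroup, lane 0 of each simdgroup parks its partial in threadgroup memory, and after the barrier the first simdgroup reduces those partials with a second `simd_sum`, with thread 0 storing the result. The second level stays correct because `nsg` never exceeds the SIMD width: `nth` is capped by the pipeline thread limit, which is at most 1024 on Metal devices. A plain C++ emulation of the same data flow (illustrative only, no Metal semantics; SIMD width fixed at 32 as in the kernel):

```cpp
#include <cstddef>
#include <cstdio>
#include <vector>

// CPU emulation of the kernel's two-level reduction (illustrative only).
static float sum_two_level(const std::vector<float> & src, int nth) {
    const int nsg = (nth + 31) / 32;     // simdgroups per threadgroup
    std::vector<float> shmem(nsg, 0.0f); // stand-in for threadgroup memory

    // level 1: each "thread" strides over the input with step nth;
    // each block of 32 threads models one simdgroup's simd_sum
    for (int sg = 0; sg < nsg; ++sg) {
        float partial = 0.0f;
        for (int lane = 0; lane < 32; ++lane) {
            const int t = sg*32 + lane; // thread index in the threadgroup
            if (t >= nth) {
                break;
            }
            for (std::size_t i0 = t; i0 < src.size(); i0 += nth) {
                partial += src[i0];
            }
        }
        shmem[sg] = partial; // what lane 0 writes after simd_sum
    }

    // level 2 (after the threadgroup barrier): the first simdgroup
    // folds the per-simdgroup partials; thread 0 stores the result
    float total = 0.0f;
    for (int sg = 0; sg < nsg; ++sg) {
        total += shmem[sg];
    }
    return total;
}

int main() {
    std::vector<float> x(33*256, 1.0f);
    std::printf("%.1f\n", sum_two_level(x, 1024)); // prints 8448.0
}
```

Because the dispatch uses a single threadgroup, no cross-threadgroup combine is needed, and the host-side clamp of `nth` to `n` keeps tiny tensors from launching idle threads.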
```diff
@@ -4588,20 +4588,31 @@ struct test_topk_moe: public test_case {
 struct test_sum : public test_case {
     const ggml_type type;
     const std::array<int64_t, 4> ne;
+    const std::array<int64_t, 4> permute;
+    bool _use_permute;
 
     std::string vars() override {
-        return VARS_TO_STR2(type, ne);
+        std::string v = VARS_TO_STR2(type, ne);
+        if (_use_permute) v += "," + VAR_TO_STR(permute);
+        return v;
     }
 
     test_sum(ggml_type type = GGML_TYPE_F32,
-            std::array<int64_t, 4> ne = {10, 5, 4, 3})
-        : type(type), ne(ne) {}
+            std::array<int64_t, 4> ne = {10, 5, 4, 3},
+            std::array<int64_t, 4> permute = {0, 0, 0, 0})
+        : type(type), ne(ne), permute(permute),
+          _use_permute(permute[0] + permute[1] + permute[2] + permute[3] > 0) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
         ggml_set_param(a);
         ggml_set_name(a, "a");
 
+        if (_use_permute) {
+            a = ggml_permute(ctx, a, permute[0], permute[1], permute[2], permute[3]);
+            ggml_set_name(a, "a_permuted");
+        }
+
         ggml_tensor * out = ggml_sum(ctx, a);
         ggml_set_name(out, "out");
 
```
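The all-zero `permute` default acts as a sentinel for "no permutation": any genuine permutation of the four axes, including the identity `{0, 1, 2, 3}`, has a positive component sum, so `_use_permute` can be derived by summing the entries. When set, `build_graph` sums the permuted view `a_permuted` instead of `a` itself.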
```diff
@@ -6724,6 +6735,9 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
 
     test_cases.emplace_back(new test_sum());
     test_cases.emplace_back(new test_sum_rows());
+    test_cases.emplace_back(new test_sum(GGML_TYPE_F32, {11, 5, 6, 3}, {0, 2, 1, 3})); // row-contiguous but non-contiguous
+    test_cases.emplace_back(new test_sum(GGML_TYPE_F32, {11, 5, 6, 3}, {0, 3, 2, 1}));
+    test_cases.emplace_back(new test_sum(GGML_TYPE_F32, {11, 5, 6, 3}, {0, 1, 3, 2}));
     test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 11, 5, 6, 3 }, true, false));
     test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 11, 5, 6, 3 }, false, true));
     test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 11, 5, 6, 3 }, true, true));
```
```diff
@@ -6734,6 +6748,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     test_cases.emplace_back(new test_sum(GGML_TYPE_F32, { 33, 1024, 1, 1 }));
     test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 33, 1024, 1, 1 }));
     test_cases.emplace_back(new test_sum(GGML_TYPE_F32, { 33, 256, 1, 1 }));
+    test_cases.emplace_back(new test_sum(GGML_TYPE_F32, { 33, 256, 1, 1 }, { 1, 0, 2, 3 })); // sum dst not-contiguous
     test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 33, 256, 1, 1 }));
     test_cases.emplace_back(new test_mean(GGML_TYPE_F32, { 33, 256, 1, 1 }));
     test_cases.emplace_back(new test_mean(GGML_TYPE_F32, { 32769, 1, 1, 1 }));
```
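The chosen permutations line up with the new backend checks: `{0, 2, 1, 3}`, `{0, 3, 2, 1}`, and `{0, 1, 3, 2}` keep axis 0 in place, so `nb[0]` is untouched and rows stay packed, which CUDA's `ggml_is_contiguous_rows` check still accepts; `{1, 0, 2, 3}` moves the row axis and yields a view that is not even row-contiguous. Metal's stricter `ggml_is_contiguous` gate rejects all of these permuted views, and the test harness skips backends that report the op as unsupported for a given layout.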