diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index 0d5a818dac..369475eaf5 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -68,6 +68,7 @@ struct ggml_webgpu_shader_lib_context { size_t wg_mem_limit_bytes = 0; bool inplace = false; bool overlap = false; + bool src_overlap = false; bool supports_subgroup_matrix = false; uint32_t sg_mat_m = 0; uint32_t sg_mat_n = 0; @@ -179,9 +180,10 @@ struct ggml_webgpu_binary_pipeline_key { int op; bool inplace; bool overlap; + bool src_overlap; bool operator==(const ggml_webgpu_binary_pipeline_key & other) const { - return type == other.type && op == other.op && inplace == other.inplace && overlap == other.overlap; + return type == other.type && op == other.op && inplace == other.inplace && overlap == other.overlap && src_overlap == other.src_overlap; } }; @@ -192,6 +194,7 @@ struct ggml_webgpu_binary_pipeline_key_hash { ggml_webgpu_hash_combine(seed, key.op); ggml_webgpu_hash_combine(seed, key.inplace); ggml_webgpu_hash_combine(seed, key.overlap); + ggml_webgpu_hash_combine(seed, key.src_overlap); return seed; } }; @@ -1044,6 +1047,7 @@ class ggml_webgpu_shader_lib { .op = context.dst->op, .inplace = context.inplace, .overlap = context.overlap, + .src_overlap = context.src_overlap, }; auto it = binary_pipelines.find(key); @@ -1076,6 +1080,9 @@ class ggml_webgpu_shader_lib { } else if (key.overlap) { defines.push_back("OVERLAP"); variant += "_overlap"; + } else if (key.src_overlap) { + defines.push_back("SRC_OVERLAP"); + variant += "_src_overlap"; } defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 1c00d3cb2b..4dc56e1dc5 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -788,6 +788,7 @@ static bool ggml_webgpu_tensor_overlap(ggml_tensor * a, ggml_tensor * b) { struct binary_overlap_flags { bool inplace; // src0 == dst bool overlap; // src1 == dst + bool src_overlap; }; static binary_overlap_flags ggml_webgpu_detect_binary_overlap(ggml_tensor * src0, @@ -796,6 +797,7 @@ static binary_overlap_flags ggml_webgpu_detect_binary_overlap(ggml_tensor * src0 binary_overlap_flags flags = {}; flags.inplace = ggml_webgpu_tensor_equal(src0, dst); flags.overlap = ggml_webgpu_tensor_overlap(src1, dst); + flags.src_overlap = ggml_webgpu_tensor_overlap(src0, src1); return flags; } @@ -1353,6 +1355,7 @@ static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, .inplace = flags.inplace, .overlap = flags.overlap, + .src_overlap = flags.src_overlap, }; webgpu_pipeline pipeline = ctx->shader_lib->get_binary_pipeline(shader_lib_ctx); @@ -1361,11 +1364,28 @@ static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx, uint32_t ne = (uint32_t) ggml_nelements(dst); + size_t src0_webgpu_tensor_align_offset = ggml_webgpu_tensor_align_offset(ctx, src0); + size_t src1_webgpu_tensor_align_offset = ggml_webgpu_tensor_align_offset(ctx, src1); + + uint32_t offset_merged_src0 = 0; + uint32_t offset_merged_src1 = 0; + if (flags.src_overlap) { + size_t min_off = std::min(src0_webgpu_tensor_align_offset, src1_webgpu_tensor_align_offset); + offset_merged_src0 = (uint32_t) ((src0_webgpu_tensor_align_offset - min_off) / ggml_type_size(src0->type)); + offset_merged_src1 = (uint32_t) ((src1_webgpu_tensor_align_offset - min_off) / ggml_type_size(src0->type)); + } + std::vector params = { ne, (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)), (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)), - (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), + offset_merged_src0, + offset_merged_src1, + (uint32_t) (src0->nb[0] / ggml_type_size(src0->type)), + (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)), + (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)), + (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)), (uint32_t) (src1->nb[0] / ggml_type_size(src1->type)), (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)), (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)), @@ -1381,25 +1401,43 @@ static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx, std::vector entries; - entries.push_back({ - .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src0), - .offset = ggml_webgpu_tensor_align_offset(ctx, src0), - .size = ggml_webgpu_tensor_binding_size(ctx, src0), - }); - - entries.push_back({ - .binding = 1, - .buffer = ggml_webgpu_tensor_buf(src1), - .offset = ggml_webgpu_tensor_align_offset(ctx, src1), - .size = ggml_webgpu_tensor_binding_size(ctx, src1), - }); - - if (!flags.inplace && !flags.overlap) { - entries.push_back({ .binding = 2, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); + if (flags.src_overlap) { + size_t merged_offset = std::min(src0_webgpu_tensor_align_offset, src1_webgpu_tensor_align_offset); + size_t merged_end = std::max(src0_webgpu_tensor_align_offset + ggml_webgpu_tensor_binding_size(ctx, src0), + src1_webgpu_tensor_align_offset + ggml_webgpu_tensor_binding_size(ctx, src1)); + entries.push_back({ + .binding = 0, + .buffer = ggml_webgpu_tensor_buf(src0), + .offset = merged_offset, + .size = merged_end - merged_offset, + }); + entries.push_back({ + .binding = 1, + .buffer = ggml_webgpu_tensor_buf(dst), + .offset = ggml_webgpu_tensor_align_offset(ctx, dst), + .size = ggml_webgpu_tensor_binding_size(ctx, dst), + }); + } else { + entries.push_back({ + .binding = 0, + .buffer = ggml_webgpu_tensor_buf(src0), + .offset = src0_webgpu_tensor_align_offset, + .size = ggml_webgpu_tensor_binding_size(ctx, src0), + }); + entries.push_back({ + .binding = 1, + .buffer = ggml_webgpu_tensor_buf(src1), + .offset = src1_webgpu_tensor_align_offset, + .size = ggml_webgpu_tensor_binding_size(ctx, src1), + }); + if (!flags.inplace && !flags.overlap) { + entries.push_back({ + .binding = 2, + .buffer = ggml_webgpu_tensor_buf(dst), + .offset = ggml_webgpu_tensor_align_offset(ctx, dst), + .size = ggml_webgpu_tensor_binding_size(ctx, dst), + }); + } } uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size); @@ -2816,10 +2854,8 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const case GGML_OP_SUB: case GGML_OP_MUL: case GGML_OP_DIV: - // TODO: support non-contiguous tensors, e.g. for MOE_EXPERT_REDUCE - // see https://github.com/ggml-org/llama.cpp/pull/16857 supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type) && - (src1->type == op->type) && ggml_is_contiguous(src0) && ggml_is_contiguous(src1); + (src1->type == op->type); break; case GGML_OP_CPY: case GGML_OP_CONT: diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl index 55dd66408a..a748dc1b86 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl @@ -7,6 +7,13 @@ struct Params { offset_src0: u32, offset_src1: u32, offset_dst: u32, + offset_merged_src0: u32, + offset_merged_src1: u32, + + stride_src0_0: u32, + stride_src0_1: u32, + stride_src0_2: u32, + stride_src0_3: u32, stride_src1_0: u32, stride_src1_1: u32, @@ -23,6 +30,21 @@ struct Params { b_ne3: u32, }; +fn src0_index(_i: u32) -> u32 { + var i = _i; + let a_i3 = i / (params.a_ne2 * params.a_ne1 * params.a_ne0); + i = i % (params.a_ne2 * params.a_ne1 * params.a_ne0); + let a_i2 = i / (params.a_ne1 * params.a_ne0); + i = i % (params.a_ne1 * params.a_ne0); + let a_i1 = i / params.a_ne0; + let a_i0 = i % params.a_ne0; + + return a_i0 * params.stride_src0_0 + + a_i1 * params.stride_src0_1 + + a_i2 * params.stride_src0_2 + + a_i3 * params.stride_src0_3; +} + fn src1_index(_i: u32) -> u32 { var i = _i; let a_i3 = i / (params.a_ne2 * params.a_ne1 * params.a_ne0); @@ -53,17 +75,22 @@ fn src1_index(_i: u32) -> u32 { #define DataType f16 #endif +#ifdef SRC_OVERLAP +@group(0) @binding(0) +var merged_src: array; + +@group(0) @binding(1) +var dst: array; + +@group(0) @binding(2) +var params: Params; +#else @group(0) @binding(0) var src0: array; @group(0) @binding(1) var src1 : array; - -#ifdef INPLACE -@group(0) @binding(2) -var params: Params; - -#elif defined(OVERLAP) +#if defined(INPLACE) || defined(OVERLAP) @group(0) @binding(2) var params: Params; @@ -74,6 +101,7 @@ var dst: array; @group(0) @binding(3) var params: Params; #endif +#endif fn op(a: DataType, b: DataType) -> DataType { #ifdef OP_ADD @@ -87,13 +115,17 @@ fn op(a: DataType, b: DataType) -> DataType { #endif } -fn update(dst_i: u32, src0_i: u32, src1_i: u32){ +fn update(dst_i: u32, src0_i: u32, src1_i: u32) { +#ifdef SRC_OVERLAP + let result = op(merged_src[src0_i], merged_src[src1_i]); +#else let result = op(src0[src0_i], src1[src1_i]); +#endif #ifdef INPLACE - src0[dst_i] = result; + src0[src0_i] = result; #elif defined(OVERLAP) - src1[dst_i] = result; + src1[src1_i] = result; #else dst[dst_i] = result; #endif @@ -102,6 +134,8 @@ fn update(dst_i: u32, src0_i: u32, src1_i: u32){ @compute @workgroup_size(WG_SIZE) fn main(@builtin(global_invocation_id) gid: vec3) { if (gid.x < params.ne) { - update(params.offset_dst + gid.x, params.offset_src0 + gid.x, params.offset_src1 + src1_index(gid.x)); + let src0_i = params.offset_src0 + params.offset_merged_src0 + src0_index(gid.x); + let src1_i = params.offset_src1 + params.offset_merged_src1 + src1_index(gid.x); + update(params.offset_dst + gid.x, src0_i, src1_i); } } diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index e8e237c6ec..0ac21cdcf6 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -2977,6 +2977,7 @@ struct test_bin_bcast : public test_case { const std::array nr; int nf; // number of fused ops, nf == 1 -> single op (no fusion) bool perm1; // permute src1? + bool src_overlap; // src0 and src1 are overlapping views of the same buffer bool run_whole_graph() override { return nf > 1; } @@ -2992,8 +2993,8 @@ struct test_bin_bcast : public test_case { std::array ne = {10, 10, 1, 1}, std::array nr = {1, 2, 1, 1}, int nf = 1, - bool perm1 = false) - : op(op), type(type), ne(ne), nr(nr), nf(nf), perm1(perm1) {} + bool perm1 = false, bool src_overlap = false) + : op(op), type(type), ne(ne), nr(nr), nf(nf), perm1(perm1), src_overlap(src_overlap) {} ggml_tensor * build_graph(ggml_context * ctx) override { GGML_ASSERT(nf <= 16); @@ -3008,6 +3009,8 @@ struct test_bin_bcast : public test_case { b[i] = ggml_new_tensor_4d(ctx, type, ne[p[0]], ne[p[1]], ne[p[2]], ne[p[3]]); b[i] = ggml_permute(ctx, b[i], p[0], p[1], p[2], p[3]); + } else if (src_overlap) { + b[i] = ggml_view_4d(ctx, a, ne[0], ne[1], ne[2], 2 * (ne[3] / 3), a->nb[1], a->nb[2], a->nb[3], (ne[3] / 3) * a->nb[3]); } else { b[i] = ggml_new_tensor(ctx, type, 4, ne.data()); } @@ -3021,7 +3024,13 @@ struct test_bin_bcast : public test_case { ggml_set_param(b[0]); } - ggml_tensor * out = a; + ggml_tensor *out; + + if (src_overlap) { + out = ggml_view_4d(ctx, a, ne[0], ne[1], ne[2], 2 * (ne[3] / 3), a->nb[1], a->nb[2], a->nb[3], 0); + } else { + out = a; + } for (int i = 0; i < nf; ++i) { out = op(ctx, out, b[i]); @@ -7527,9 +7536,9 @@ static std::vector> make_test_cases_eval() { } } - auto add_test_bin_bcast = [&](ggml_type type, std::array ne, std::array nr, bool perm1 = false) { + auto add_test_bin_bcast = [&](ggml_type type, std::array ne, std::array nr, bool perm1 = false, bool src_overlap = false) { for (auto op : {ggml_add, ggml_sub, ggml_mul, ggml_div}) { - test_cases.emplace_back(new test_bin_bcast(op, type, ne, nr, 1, perm1)); + test_cases.emplace_back(new test_bin_bcast(op, type, ne, nr, 1, perm1, src_overlap)); } }; for (ggml_type type : {GGML_TYPE_F16, GGML_TYPE_F32}) { @@ -7549,6 +7558,12 @@ static std::vector> make_test_cases_eval() { add_test_bin_bcast(type, {10, 5, 4, 3}, {2, 2, 2, 2}, perm1); } + // src_overlap + add_test_bin_bcast(type, {10, 5, 4, 6}, {1, 1, 1, 1}, false, true); + add_test_bin_bcast(type, {10, 5, 4, 5}, {1, 1, 1, 1}, false, true); + add_test_bin_bcast(type, {1, 1, 120, 120}, {1, 1, 1, 1}, false, true); + add_test_bin_bcast(type, {1, 1, 4, 320}, {1, 1, 1, 1}, false, true); + // test case for k_bin_bcast_unravel in CUDA backend add_test_bin_bcast(type, {1, 1, 65536, 1}, {256, 1, 1, 1});