ggml-webgpu: Support non-contiguous `src0` and overlapping `src0/src1` in binary ops (#19850)

* ggml-webgpu: Add binary op support for overlapping and non-contiguous.

* Add newline to binary.wgsl

* Append the test of binary op for src overlapping  to test_bin_bcast.

* Remove unnecessary newline.
This commit is contained in:
Masashi Yoshimura 2026-03-03 00:59:53 +09:00 committed by GitHub
parent feefb92836
commit 36a7a6589c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 131 additions and 39 deletions

View File

@ -68,6 +68,7 @@ struct ggml_webgpu_shader_lib_context {
size_t wg_mem_limit_bytes = 0;
bool inplace = false;
bool overlap = false;
bool src_overlap = false;
bool supports_subgroup_matrix = false;
uint32_t sg_mat_m = 0;
uint32_t sg_mat_n = 0;
@ -179,9 +180,10 @@ struct ggml_webgpu_binary_pipeline_key {
int op;
bool inplace;
bool overlap;
bool src_overlap;
bool operator==(const ggml_webgpu_binary_pipeline_key & other) const {
return type == other.type && op == other.op && inplace == other.inplace && overlap == other.overlap;
return type == other.type && op == other.op && inplace == other.inplace && overlap == other.overlap && src_overlap == other.src_overlap;
}
};
@ -192,6 +194,7 @@ struct ggml_webgpu_binary_pipeline_key_hash {
ggml_webgpu_hash_combine(seed, key.op);
ggml_webgpu_hash_combine(seed, key.inplace);
ggml_webgpu_hash_combine(seed, key.overlap);
ggml_webgpu_hash_combine(seed, key.src_overlap);
return seed;
}
};
@ -1044,6 +1047,7 @@ class ggml_webgpu_shader_lib {
.op = context.dst->op,
.inplace = context.inplace,
.overlap = context.overlap,
.src_overlap = context.src_overlap,
};
auto it = binary_pipelines.find(key);
@ -1076,6 +1080,9 @@ class ggml_webgpu_shader_lib {
} else if (key.overlap) {
defines.push_back("OVERLAP");
variant += "_overlap";
} else if (key.src_overlap) {
defines.push_back("SRC_OVERLAP");
variant += "_src_overlap";
}
defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));

View File

@ -788,6 +788,7 @@ static bool ggml_webgpu_tensor_overlap(ggml_tensor * a, ggml_tensor * b) {
struct binary_overlap_flags {
bool inplace; // src0 == dst
bool overlap; // src1 == dst
bool src_overlap;
};
static binary_overlap_flags ggml_webgpu_detect_binary_overlap(ggml_tensor * src0,
@ -796,6 +797,7 @@ static binary_overlap_flags ggml_webgpu_detect_binary_overlap(ggml_tensor * src0
binary_overlap_flags flags = {};
flags.inplace = ggml_webgpu_tensor_equal(src0, dst);
flags.overlap = ggml_webgpu_tensor_overlap(src1, dst);
flags.src_overlap = ggml_webgpu_tensor_overlap(src0, src1);
return flags;
}
@ -1353,6 +1355,7 @@ static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx,
.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
.inplace = flags.inplace,
.overlap = flags.overlap,
.src_overlap = flags.src_overlap,
};
webgpu_pipeline pipeline = ctx->shader_lib->get_binary_pipeline(shader_lib_ctx);
@ -1361,11 +1364,28 @@ static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx,
uint32_t ne = (uint32_t) ggml_nelements(dst);
size_t src0_webgpu_tensor_align_offset = ggml_webgpu_tensor_align_offset(ctx, src0);
size_t src1_webgpu_tensor_align_offset = ggml_webgpu_tensor_align_offset(ctx, src1);
uint32_t offset_merged_src0 = 0;
uint32_t offset_merged_src1 = 0;
if (flags.src_overlap) {
size_t min_off = std::min(src0_webgpu_tensor_align_offset, src1_webgpu_tensor_align_offset);
offset_merged_src0 = (uint32_t) ((src0_webgpu_tensor_align_offset - min_off) / ggml_type_size(src0->type));
offset_merged_src1 = (uint32_t) ((src1_webgpu_tensor_align_offset - min_off) / ggml_type_size(src0->type));
}
std::vector<uint32_t> params = {
ne,
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
offset_merged_src0,
offset_merged_src1,
(uint32_t) (src0->nb[0] / ggml_type_size(src0->type)),
(uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
(uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
(uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),
(uint32_t) (src1->nb[0] / ggml_type_size(src1->type)),
(uint32_t) (src1->nb[1] / ggml_type_size(src1->type)),
(uint32_t) (src1->nb[2] / ggml_type_size(src1->type)),
@ -1381,25 +1401,43 @@ static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx,
std::vector<wgpu::BindGroupEntry> entries;
entries.push_back({
.binding = 0,
.buffer = ggml_webgpu_tensor_buf(src0),
.offset = ggml_webgpu_tensor_align_offset(ctx, src0),
.size = ggml_webgpu_tensor_binding_size(ctx, src0),
});
entries.push_back({
.binding = 1,
.buffer = ggml_webgpu_tensor_buf(src1),
.offset = ggml_webgpu_tensor_align_offset(ctx, src1),
.size = ggml_webgpu_tensor_binding_size(ctx, src1),
});
if (!flags.inplace && !flags.overlap) {
entries.push_back({ .binding = 2,
.buffer = ggml_webgpu_tensor_buf(dst),
.offset = ggml_webgpu_tensor_align_offset(ctx, dst),
.size = ggml_webgpu_tensor_binding_size(ctx, dst) });
if (flags.src_overlap) {
size_t merged_offset = std::min(src0_webgpu_tensor_align_offset, src1_webgpu_tensor_align_offset);
size_t merged_end = std::max(src0_webgpu_tensor_align_offset + ggml_webgpu_tensor_binding_size(ctx, src0),
src1_webgpu_tensor_align_offset + ggml_webgpu_tensor_binding_size(ctx, src1));
entries.push_back({
.binding = 0,
.buffer = ggml_webgpu_tensor_buf(src0),
.offset = merged_offset,
.size = merged_end - merged_offset,
});
entries.push_back({
.binding = 1,
.buffer = ggml_webgpu_tensor_buf(dst),
.offset = ggml_webgpu_tensor_align_offset(ctx, dst),
.size = ggml_webgpu_tensor_binding_size(ctx, dst),
});
} else {
entries.push_back({
.binding = 0,
.buffer = ggml_webgpu_tensor_buf(src0),
.offset = src0_webgpu_tensor_align_offset,
.size = ggml_webgpu_tensor_binding_size(ctx, src0),
});
entries.push_back({
.binding = 1,
.buffer = ggml_webgpu_tensor_buf(src1),
.offset = src1_webgpu_tensor_align_offset,
.size = ggml_webgpu_tensor_binding_size(ctx, src1),
});
if (!flags.inplace && !flags.overlap) {
entries.push_back({
.binding = 2,
.buffer = ggml_webgpu_tensor_buf(dst),
.offset = ggml_webgpu_tensor_align_offset(ctx, dst),
.size = ggml_webgpu_tensor_binding_size(ctx, dst),
});
}
}
uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size);
@ -2816,10 +2854,8 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
case GGML_OP_SUB:
case GGML_OP_MUL:
case GGML_OP_DIV:
// TODO: support non-contiguous tensors, e.g. for MOE_EXPERT_REDUCE
// see https://github.com/ggml-org/llama.cpp/pull/16857
supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type) &&
(src1->type == op->type) && ggml_is_contiguous(src0) && ggml_is_contiguous(src1);
(src1->type == op->type);
break;
case GGML_OP_CPY:
case GGML_OP_CONT:

View File

@ -7,6 +7,13 @@ struct Params {
offset_src0: u32,
offset_src1: u32,
offset_dst: u32,
offset_merged_src0: u32,
offset_merged_src1: u32,
stride_src0_0: u32,
stride_src0_1: u32,
stride_src0_2: u32,
stride_src0_3: u32,
stride_src1_0: u32,
stride_src1_1: u32,
@ -23,6 +30,21 @@ struct Params {
b_ne3: u32,
};
fn src0_index(_i: u32) -> u32 {
var i = _i;
let a_i3 = i / (params.a_ne2 * params.a_ne1 * params.a_ne0);
i = i % (params.a_ne2 * params.a_ne1 * params.a_ne0);
let a_i2 = i / (params.a_ne1 * params.a_ne0);
i = i % (params.a_ne1 * params.a_ne0);
let a_i1 = i / params.a_ne0;
let a_i0 = i % params.a_ne0;
return a_i0 * params.stride_src0_0 +
a_i1 * params.stride_src0_1 +
a_i2 * params.stride_src0_2 +
a_i3 * params.stride_src0_3;
}
fn src1_index(_i: u32) -> u32 {
var i = _i;
let a_i3 = i / (params.a_ne2 * params.a_ne1 * params.a_ne0);
@ -53,17 +75,22 @@ fn src1_index(_i: u32) -> u32 {
#define DataType f16
#endif
#ifdef SRC_OVERLAP
@group(0) @binding(0)
var<storage, read_write> merged_src: array<DataType>;
@group(0) @binding(1)
var<storage, read_write> dst: array<DataType>;
@group(0) @binding(2)
var<uniform> params: Params;
#else
@group(0) @binding(0)
var<storage, read_write> src0: array<DataType>;
@group(0) @binding(1)
var<storage, read_write> src1 : array<DataType>;
#ifdef INPLACE
@group(0) @binding(2)
var<uniform> params: Params;
#elif defined(OVERLAP)
#if defined(INPLACE) || defined(OVERLAP)
@group(0) @binding(2)
var<uniform> params: Params;
@ -74,6 +101,7 @@ var<storage, read_write> dst: array<DataType>;
@group(0) @binding(3)
var<uniform> params: Params;
#endif
#endif
fn op(a: DataType, b: DataType) -> DataType {
#ifdef OP_ADD
@ -87,13 +115,17 @@ fn op(a: DataType, b: DataType) -> DataType {
#endif
}
fn update(dst_i: u32, src0_i: u32, src1_i: u32){
fn update(dst_i: u32, src0_i: u32, src1_i: u32) {
#ifdef SRC_OVERLAP
let result = op(merged_src[src0_i], merged_src[src1_i]);
#else
let result = op(src0[src0_i], src1[src1_i]);
#endif
#ifdef INPLACE
src0[dst_i] = result;
src0[src0_i] = result;
#elif defined(OVERLAP)
src1[dst_i] = result;
src1[src1_i] = result;
#else
dst[dst_i] = result;
#endif
@ -102,6 +134,8 @@ fn update(dst_i: u32, src0_i: u32, src1_i: u32){
@compute @workgroup_size(WG_SIZE)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
if (gid.x < params.ne) {
update(params.offset_dst + gid.x, params.offset_src0 + gid.x, params.offset_src1 + src1_index(gid.x));
let src0_i = params.offset_src0 + params.offset_merged_src0 + src0_index(gid.x);
let src1_i = params.offset_src1 + params.offset_merged_src1 + src1_index(gid.x);
update(params.offset_dst + gid.x, src0_i, src1_i);
}
}

View File

@ -2977,6 +2977,7 @@ struct test_bin_bcast : public test_case {
const std::array<int, 4> nr;
int nf; // number of fused ops, nf == 1 -> single op (no fusion)
bool perm1; // permute src1?
bool src_overlap; // src0 and src1 are overlapping views of the same buffer
bool run_whole_graph() override { return nf > 1; }
@ -2992,8 +2993,8 @@ struct test_bin_bcast : public test_case {
std::array<int64_t, 4> ne = {10, 10, 1, 1},
std::array<int, 4> nr = {1, 2, 1, 1},
int nf = 1,
bool perm1 = false)
: op(op), type(type), ne(ne), nr(nr), nf(nf), perm1(perm1) {}
bool perm1 = false, bool src_overlap = false)
: op(op), type(type), ne(ne), nr(nr), nf(nf), perm1(perm1), src_overlap(src_overlap) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
GGML_ASSERT(nf <= 16);
@ -3008,6 +3009,8 @@ struct test_bin_bcast : public test_case {
b[i] = ggml_new_tensor_4d(ctx, type, ne[p[0]], ne[p[1]], ne[p[2]], ne[p[3]]);
b[i] = ggml_permute(ctx, b[i], p[0], p[1], p[2], p[3]);
} else if (src_overlap) {
b[i] = ggml_view_4d(ctx, a, ne[0], ne[1], ne[2], 2 * (ne[3] / 3), a->nb[1], a->nb[2], a->nb[3], (ne[3] / 3) * a->nb[3]);
} else {
b[i] = ggml_new_tensor(ctx, type, 4, ne.data());
}
@ -3021,7 +3024,13 @@ struct test_bin_bcast : public test_case {
ggml_set_param(b[0]);
}
ggml_tensor * out = a;
ggml_tensor *out;
if (src_overlap) {
out = ggml_view_4d(ctx, a, ne[0], ne[1], ne[2], 2 * (ne[3] / 3), a->nb[1], a->nb[2], a->nb[3], 0);
} else {
out = a;
}
for (int i = 0; i < nf; ++i) {
out = op(ctx, out, b[i]);
@ -7527,9 +7536,9 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
}
}
auto add_test_bin_bcast = [&](ggml_type type, std::array<int64_t, 4> ne, std::array<int, 4> nr, bool perm1 = false) {
auto add_test_bin_bcast = [&](ggml_type type, std::array<int64_t, 4> ne, std::array<int, 4> nr, bool perm1 = false, bool src_overlap = false) {
for (auto op : {ggml_add, ggml_sub, ggml_mul, ggml_div}) {
test_cases.emplace_back(new test_bin_bcast(op, type, ne, nr, 1, perm1));
test_cases.emplace_back(new test_bin_bcast(op, type, ne, nr, 1, perm1, src_overlap));
}
};
for (ggml_type type : {GGML_TYPE_F16, GGML_TYPE_F32}) {
@ -7549,6 +7558,12 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
add_test_bin_bcast(type, {10, 5, 4, 3}, {2, 2, 2, 2}, perm1);
}
// src_overlap
add_test_bin_bcast(type, {10, 5, 4, 6}, {1, 1, 1, 1}, false, true);
add_test_bin_bcast(type, {10, 5, 4, 5}, {1, 1, 1, 1}, false, true);
add_test_bin_bcast(type, {1, 1, 120, 120}, {1, 1, 1, 1}, false, true);
add_test_bin_bcast(type, {1, 1, 4, 320}, {1, 1, 1, 1}, false, true);
// test case for k_bin_bcast_unravel in CUDA backend
add_test_bin_bcast(type, {1, 1, 65536, 1}, {256, 1, 1, 1});