diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/neg.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/neg.wgsl index 7aa2a75ddd..23feb9aa7d 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/neg.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/neg.wgsl @@ -3,20 +3,19 @@ [ { "REPLS": { - "TYPE" : "f32", + "TYPE": "f32", } }, { "REPLS": { - "TYPE" : "f16", + "TYPE": "f16", } - } + }, ] #end(VARIANTS) #define(SHADER) - enable f16; @group(0) @binding(0) @@ -25,17 +24,64 @@ var src: array<{{TYPE}}>; @group(0) @binding(1) var dst: array<{{TYPE}}>; +struct Params { + ne: u32, // total number of elements + offset_src: u32, // in elements + offset_dst: u32, // in elements + + // Strides (in elements) — may be permuted + stride_src0: u32, + stride_src1: u32, + stride_src2: u32, + stride_src3: u32, + + stride_dst0: u32, + stride_dst1: u32, + stride_dst2: u32, + stride_dst3: u32, + + // Logical shapes + src_ne0: u32, + src_ne1: u32, + src_ne2: u32, + + dst_ne0: u32, + dst_ne1: u32, + dst_ne2: u32 +}; + @group(0) @binding(2) var params: Params; - override wg_size: u32; @compute @workgroup_size(wg_size) fn main(@builtin(global_invocation_id) gid: vec3) { - if (gid.x < params.ne) { - dst[gid.x] = -src[gid.x]; + if (gid.x >= params.ne) { + return; } -} + var i = gid.x; + let i3 = i / (params.src_ne2 * params.src_ne1 * params.src_ne0); + i = i % (params.src_ne2 * params.src_ne1 * params.src_ne0); + let i2 = i / (params.src_ne1 * params.src_ne0); + i = i % (params.src_ne1 * params.src_ne0); + let i1 = i / params.src_ne0; + let i0 = i % params.src_ne0; -#end(SHADER) \ No newline at end of file + var j = gid.x; + let j3 = j / (params.dst_ne2 * params.dst_ne1 * params.dst_ne0); + j = j % (params.dst_ne2 * params.dst_ne1 * params.dst_ne0); + let j2 = j / (params.dst_ne1 * params.dst_ne0); + j = j % (params.dst_ne1 * params.dst_ne0); + let j1 = j / params.dst_ne0; + let j0 = j % params.dst_ne0; + + let src_idx = i0 * params.stride_src0 + i1 * params.stride_src1 + + i2 * params.stride_src2 + i3 * params.stride_src3; + + let dst_idx = j0 * params.stride_dst0 + j1 * params.stride_dst1 + + j2 * params.stride_dst2 + j3 * params.stride_dst3; + + dst[params.offset_dst + dst_idx] = -((src[params.offset_src + src_idx])); +} +#end(SHADER) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/neg_in_place.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/neg_in_place.wgsl index 1ca0b3a76b..732b56cea2 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/neg_in_place.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/neg_in_place.wgsl @@ -3,36 +3,82 @@ [ { "REPLS": { - "TYPE" : "f32", + "TYPE": "f32", } }, { "REPLS": { - "TYPE" : "f16", + "TYPE": "f16", } - } + }, ] #end(VARIANTS) #define(SHADER) - enable f16; @group(0) @binding(0) var src: array<{{TYPE}}>; +struct Params { + ne: u32, // total number of elements + offset_src: u32, // in elements + offset_dst: u32, // in elements + + // Strides (in elements) — may be permuted + stride_src0: u32, + stride_src1: u32, + stride_src2: u32, + stride_src3: u32, + + stride_dst0: u32, + stride_dst1: u32, + stride_dst2: u32, + stride_dst3: u32, + + // Logical shapes + src_ne0: u32, + src_ne1: u32, + src_ne2: u32, + + dst_ne0: u32, + dst_ne1: u32, + dst_ne2: u32 +}; + @group(0) @binding(1) var params: Params; - override wg_size: u32; @compute @workgroup_size(wg_size) fn main(@builtin(global_invocation_id) gid: vec3) { - if (gid.x < params.ne) { - src[gid.x] = -src[gid.x]; + if (gid.x >= params.ne) { + return; } -} + var i = gid.x; + let i3 = i / (params.src_ne2 * params.src_ne1 * params.src_ne0); + i = i % (params.src_ne2 * params.src_ne1 * params.src_ne0); + let i2 = i / (params.src_ne1 * params.src_ne0); + i = i % (params.src_ne1 * params.src_ne0); + let i1 = i / params.src_ne0; + let i0 = i % params.src_ne0; -#end(SHADER) \ No newline at end of file + var j = gid.x; + let j3 = j / (params.dst_ne2 * params.dst_ne1 * params.dst_ne0); + j = j % (params.dst_ne2 * params.dst_ne1 * params.dst_ne0); + let j2 = j / (params.dst_ne1 * params.dst_ne0); + j = j % (params.dst_ne1 * params.dst_ne0); + let j1 = j / params.dst_ne0; + let j0 = j % params.dst_ne0; + + let src_idx = i0 * params.stride_src0 + i1 * params.stride_src1 + + i2 * params.stride_src2 + i3 * params.stride_src3; + + let dst_idx = j0 * params.stride_dst0 + j1 * params.stride_dst1 + + j2 * params.stride_dst2 + j3 * params.stride_dst3; + + dst[params.offset_dst + dst_idx] = -((src[params.offset_src + src_idx])); +} +#end(SHADER)