91 lines
1.7 KiB
WebGPU Shading Language
91 lines
1.7 KiB
WebGPU Shading Language
#define(VARIANTS)
|
|
|
|
[
|
|
{
|
|
"SHADER_NAME": "scale_f32",
|
|
"DECLS": ["NOT_INPLACE"]
|
|
},
|
|
{
|
|
"SHADER_NAME": "scale_f32_inplace",
|
|
"DECLS": ["INPLACE"]
|
|
}
|
|
]
|
|
|
|
#end(VARIANTS)
|
|
|
|
#define(DECLS)
|
|
|
|
#decl(NOT_INPLACE)
// Out-of-place variant: results are written to a separate `dst` buffer
// instead of back into `src`.

// Destination storage buffer (group 0, binding 1).
@group(0) @binding(1) var<storage, read_write> dst: array<f32>;

// Uniform parameter block (group 0, binding 2).
@group(0) @binding(2) var<uniform> params: Params;

// Stores `val` at element index `offset` of `dst`.
fn store_scale(val: f32, offset: u32) {
    dst[offset] = val;
}
#enddecl(NOT_INPLACE)
|
|
|
|
#decl(INPLACE)
// In-place variant: no separate destination buffer is bound, so the uniform
// parameter block sits at binding 1 and results are written back into `src`.

// Uniform parameter block (group 0, binding 1).
@group(0) @binding(1) var<uniform> params: Params;

// Stores `val` at element index `offset` of `src`.
fn store_scale(val: f32, offset: u32) {
    src[offset] = val;
}
#enddecl(INPLACE)
|
|
|
|
#end(DECLS)
|
|
|
|
#define(SHADER)
|
|
|
|
// Uniform parameters read by the scale kernel.
// NOTE(review): field order presumably has to mirror the host-side buffer
// layout - confirm before reordering anything here.
struct Params {
    // Base element offsets added to every source / destination index.
    offset_src: u32,
    offset_dst: u32,

    // Strides (in elements) of source dimensions 1..3. There is no stride
    // for dimension 0: the kernel adds the dimension-0 index directly.
    stride_src1: u32,
    stride_src2: u32,
    stride_src3: u32,

    // Strides (in elements) of destination dimensions 1..3.
    stride_dst1: u32,
    stride_dst2: u32,
    stride_dst3: u32,

    // Upper bound on the flat index: invocations whose global id is >= ne
    // return without touching memory.
    ne: u32,

    // Extents of dimensions 0..2, used to split the flat index into
    // per-dimension indices.
    ne0: u32,
    ne1: u32,
    ne2: u32,

    // Each output value is computed as input * scale + bias.
    scale: f32,
    bias: f32
};
|
|
|
|
// Input storage buffer (group 0, binding 0). It is declared read_write
// because the in-place variant of store_scale writes its results into it.
@group(0) @binding(0) var<storage, read_write> src: array<f32>;
|
|
|
|
DECLS
|
|
|
|
// Workgroup size along x. Declared without an initializer, so a value must
// be provided as a pipeline-overridable constant when the pipeline is built.
override wg_size: u32;

// Compute entry point: each invocation handles the element whose flat index
// is its global x id, reading src[i_src], applying `scale` and `bias`, and
// handing the result to store_scale together with the destination index.
@compute @workgroup_size(wg_size)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    // Invocations past the element count have nothing to do.
    if (gid.x >= params.ne) {
        return;
    }

    // Number of flat indices covered by one step of i2 and of i3.
    let elems_per_i2 = params.ne1 * params.ne0;
    let elems_per_i3 = params.ne2 * elems_per_i2;

    // Peel the flat index apart, outermost dimension first.
    let flat = gid.x;
    let i3 = flat / elems_per_i3;
    let rest3 = flat % elems_per_i3;
    let i2 = rest3 / elems_per_i2;
    let rest2 = rest3 % elems_per_i2;
    let i1 = rest2 / params.ne0;
    let i0 = rest2 % params.ne0;

    // Strided element indices; the dimension-0 index is added unscaled.
    let i_src = params.offset_src
        + i3 * params.stride_src3
        + i2 * params.stride_src2
        + i1 * params.stride_src1
        + i0;
    let i_dst = params.offset_dst
        + i3 * params.stride_dst3
        + i2 * params.stride_dst2
        + i1 * params.stride_dst1
        + i0;

    let scaled = src[i_src] * params.scale + params.bias;
    store_scale(scaled, i_dst);
}
|
|
#end(SHADER)
|