f32 add all tests passing
This commit is contained in:
parent 96d107e505
commit 39aa11d9a4
@@ -451,6 +451,16 @@ static void ggml_webgpu_mul_mat(webgpu_context & ctx, ggml_tensor * src0, ggml_t
     ggml_backend_webgpu_build_and_enqueue(ctx, ctx->mul_mat_pipeline, params, entries, wg_x);
 }
 
+// sample test
+// ADD(type=f32, ne=[10,5,4,3], nr=[2,1,1,1], nf=1)
+// ne: number of elements in each dimension of tensor b
+// nr: number of repetitions in each dimension
+// tensor b is the smaller tensor, and is broadcast with repetitions to match the size of a
+// the broadcast shape is ne * nr
+// 10*2, 5*1, 4*1, 3*1 = [20, 5, 4, 3] is the shape of dst and a
+// essentially, if nr[x] > 1, that dimension of b is repeated
+// nf: number of fused operations (1 means a single addition)
+
 // adds src0 and src1 and puts the result in dst
 static void ggml_webgpu_add(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) {
     // each tensor in GGML is stored inside a buffer on the GPU
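The ne/nr relationship described in the new comments can be checked with a small standalone sketch. This is not part of the commit; the values are taken from the sample test case quoted above.

    #include <array>
    #include <cstdio>

    int main() {
        std::array<int, 4> ne = {10, 5, 4, 3}; // shape of the smaller tensor b
        std::array<int, 4> nr = {2, 1, 1, 1};  // repetitions of b along each dimension

        std::array<int, 4> a_shape{};          // shape of a and dst: ne[i] * nr[i]
        for (int i = 0; i < 4; ++i) {
            a_shape[i] = ne[i] * nr[i];
        }

        // prints 20 5 4 3, matching the [20, 5, 4, 3] shape in the comment above
        std::printf("%d %d %d %d\n", a_shape[0], a_shape[1], a_shape[2], a_shape[3]);
        return 0;
    }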
@@ -464,15 +474,13 @@ static void ggml_webgpu_add(webgpu_context & ctx, ggml_tensor * src0, ggml_tenso
     src0_offset &= ~(ctx->limits.minStorageBufferOffsetAlignment - 1);
 
     size_t src1_offset = ggml_backend_webgpu_tensor_offset(src1);
-    // assumes power of 2 offset alignment
     size_t src1_misalignment = src1_offset & (ctx->limits.minStorageBufferOffsetAlignment - 1);
-    // align to minimum offset alignment
     src1_offset &= ~(ctx->limits.minStorageBufferOffsetAlignment - 1);
 
     size_t dst_offset = ggml_backend_webgpu_tensor_offset(dst);
     size_t dst_misalignment = dst_offset & (ctx->limits.minStorageBufferOffsetAlignment - 1);
     dst_offset &= ~(ctx->limits.minStorageBufferOffsetAlignment - 1);
 
     // set up parameters
     std::vector<uint32_t> params = {
         // number of elements -- determines how many threads to dispatch (one for each addition operation)
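The offset handling above relies on the minimum storage-buffer offset alignment being a power of two. A standalone sketch of how an offset splits into an aligned binding offset plus a misalignment the shader adds back; not part of the commit, and the alignment and offset values below are made up for illustration.

    #include <cstdint>
    #include <cstdio>

    int main() {
        const uint64_t align  = 256;  // e.g. minStorageBufferOffsetAlignment (assumed power of two)
        const uint64_t offset = 1300; // arbitrary tensor byte offset inside its buffer

        uint64_t misalignment = offset & (align - 1);  // bytes past the aligned boundary -> 20
        uint64_t aligned      = offset & ~(align - 1); // largest aligned offset <= offset  -> 1280

        // aligned + misalignment always reconstructs the original offset
        std::printf("%llu + %llu = %llu\n",
                    (unsigned long long) aligned,
                    (unsigned long long) misalignment,
                    (unsigned long long) (aligned + misalignment));
        return 0;
    }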
@@ -480,11 +488,13 @@ static void ggml_webgpu_add(webgpu_context & ctx, ggml_tensor * src0, ggml_tenso
 
         // even though tensors are 4d, the actual data is stored linearly
         // stride = how many elements (or bytes) we must skip in memory to move from one value to another along a certain dimension
-        // i.e.
-        // nb[0] = 1 // each element is next to the previous
-        // nb[1] = nb[0] * ne[0] = 5 // to move to next row, skip 5 elements
-        // nb[2] = nb[1] * ne[1] = 20 // to next matrix, skip 20 elements
-        // nb[3] = nb[2] * ne[2] = 60 // to next batch, skip 60 elements
+        // e.g. tensor: [5, 6, 3, 2], ggml_type_size: 4 (each number is 4 bytes)
+        // (nb = number of bytes to skip for each element (stride))
+        // (ne = number of elements in that dimension)
+        // nb[0] = 4                             // each element is next to the previous, so only 4 bytes in between
+        // nb[1] = nb[0] * ne[0] = 4 * 5 = 20    // to move to the next row, skip 20 bytes
+        // nb[2] = nb[1] * ne[1] = 20 * 6 = 120  // to move to the next matrix, skip 120 bytes
+        // nb[3] = nb[2] * ne[2] = 120 * 3 = 360 // to move to the next batch, skip 360 bytes
 
         // calculate element strides for each tensor
         (uint32_t) (src0->nb[0] / ggml_type_size(src0->type)),
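The nb[] walkthrough in the comments above can be reproduced with a short standalone sketch. Not part of the commit; the [5, 6, 3, 2] shape comes from the example in the comments, and the element strides are the nb[i] / type-size values that get packed into the params vector.

    #include <cstdint>
    #include <cstdio>

    int main() {
        const uint64_t ne[4]     = {5, 6, 3, 2}; // elements per dimension
        const uint64_t type_size = 4;            // bytes per element (f32)

        uint64_t nb[4];
        nb[0] = type_size;     // 4:   adjacent elements
        nb[1] = nb[0] * ne[0]; // 20:  next row
        nb[2] = nb[1] * ne[1]; // 120: next matrix
        nb[3] = nb[2] * ne[2]; // 360: next batch

        for (int i = 0; i < 4; ++i) {
            std::printf("nb[%d] = %llu bytes, stride = %llu elements\n",
                        i, (unsigned long long) nb[i],
                        (unsigned long long) (nb[i] / type_size));
        }
        return 0;
    }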
@@ -502,16 +512,24 @@ static void ggml_webgpu_add(webgpu_context & ctx, ggml_tensor * src0, ggml_tenso
         (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
         (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
 
-        // number of elements in each dimension
+        // number of elements in each dimension of larger tensors (src0 and dst)
         (uint32_t) dst->ne[0],
         (uint32_t) dst->ne[1],
         (uint32_t) dst->ne[2],
         (uint32_t) dst->ne[3],
 
+        // number of elements in each dimension of smaller tensor to be broadcasted (src1)
+        (uint32_t) src1->ne[0],
+        (uint32_t) src1->ne[1],
+        (uint32_t) src1->ne[2],
+        (uint32_t) src1->ne[3],
+
         // offsets in terms of elements instead of bytes
         (uint32_t) (src0_misalignment / ggml_type_size(src0->type)),
         (uint32_t) (src1_misalignment / ggml_type_size(src1->type)),
         (uint32_t) (dst_misalignment / ggml_type_size(dst->type)),
 
     };
 
     // bind group = groups together several GPU resources that shaders will use (e.g., buffers holding tensor data)
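The params vector is pushed to the GPU as a flat list of uint32 values, so its order has to match the field order of the WGSL Params struct in the shader below. A hypothetical host-side mirror of that layout, for illustration only; the field grouping and the 24-value count are inferred from the visible parts of the diff, not taken verbatim from the commit.

    #include <cstdint>

    // hypothetical mirror of the shader's Params struct; the real code simply
    // pushes a std::vector<uint32_t> in this order
    struct add_params_mirror {
        uint32_t ne;              // total number of dst elements (one thread per element)
        uint32_t stride_src0[4];  // element strides of src0
        uint32_t stride_src1[4];  // element strides of src1
        uint32_t stride_dst[4];   // element strides of dst
        uint32_t a_ne[4];         // dims of src0/dst (the larger shape)
        uint32_t b_ne[4];         // dims of src1 (the smaller, repeated shape)
        uint32_t offset_src0;     // misalignment offsets, in elements
        uint32_t offset_src1;
        uint32_t offset_dst;
    };

    // 1 + 12 strides + 8 dims + 3 offsets = 24 uint32 values (inferred count)
    static_assert(sizeof(add_params_mirror) == 24 * sizeof(uint32_t),
                  "should match the number of entries pushed into the params vector");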
@@ -10,7 +10,7 @@ var<storage, read_write> src1: array<f32>;
 var<storage, read_write> dst: array<f32>;
 
 struct Params {
-    ne: u32, // total number of elements
+    ne: u32,
 
     stride_src0_0: u32,
     stride_src0_1: u32,
@@ -27,10 +27,15 @@ struct Params {
     stride_dst_2: u32,
     stride_dst_3: u32,
 
-    ne0: u32,
-    ne1: u32,
-    ne2: u32,
-    ne3: u32,
+    a_ne0: u32,
+    a_ne1: u32,
+    a_ne2: u32,
+    a_ne3: u32,
+
+    b_ne0: u32,
+    b_ne1: u32,
+    b_ne2: u32,
+    b_ne3: u32,
 
     // offsets in elements
     offset_src0: u32,
@@ -48,31 +53,41 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
         return;
     }
 
-    var i = gid.x; // i = thread id
+    // i = thread id, ranges from 0 to total ne - 1
+    // it represents the position in the flat array a that we are adding to array b
+    var i = gid.x;
 
-    // compute indexes for each dimension of the tensor
-    let i3 = i / (params.ne2 * params.ne1 * params.ne0);
-    i = i % (params.ne2 * params.ne1 * params.ne0);
+    // given the linear index into a, we want to compute the 4d index [a_i0, a_i1, a_i2, a_i3]
+    // we need this because tensors a and b have different shapes,
+    // so the same linear index won't work for b, and we can only compute b's linear index from the 4d index of a
+
+    let a_i3 = i / (params.a_ne2 * params.a_ne1 * params.a_ne0);
+    i = i % (params.a_ne2 * params.a_ne1 * params.a_ne0);
 
-    let i2 = i / (params.ne1 * params.ne0);
-    i = i % (params.ne1 * params.ne0);
+    let a_i2 = i / (params.a_ne1 * params.a_ne0);
+    i = i % (params.a_ne1 * params.a_ne0);
 
-    let i1 = i / params.ne0;
+    let a_i1 = i / params.a_ne0;
 
-    let i0 = i % params.ne0;
+    let a_i0 = i % params.a_ne0;
 
-    // compute indexes for position in each flat array
-    let src0_idx = i0 * params.stride_src0_0 + i1 * params.stride_src0_1 +
-                   i2 * params.stride_src0_2 + i3 * params.stride_src0_3;
-
-    let src1_idx = i0 * params.stride_src1_0 + i1 * params.stride_src1_1 +
-                   i2 * params.stride_src1_2 + i3 * params.stride_src1_3;
-
-    let dst_idx = i0 * params.stride_dst_0 + i1 * params.stride_dst_1 +
-                  i2 * params.stride_dst_2 + i3 * params.stride_dst_3;
-
-    // dst[dst_idx] = src0[src0_idx] + src1[src1_idx];
-    dst[params.offset_dst + dst_idx] = src0[params.offset_src0 + src0_idx] + src1[params.offset_src1 + src1_idx];
+    // handle repetition of b
+    // each index wraps back to the beginning and repeats once b's elements are exhausted, i.e. modulo
+    let b_i0 = a_i0 % params.b_ne0;
+    let b_i1 = a_i1 % params.b_ne1;
+    let b_i2 = a_i2 % params.b_ne2;
+    let b_i3 = a_i3 % params.b_ne3;
+
+    // compute the index for the position in b's flat array
+    let src1_idx = b_i0 * params.stride_src1_0 +
+                   b_i1 * params.stride_src1_1 +
+                   b_i2 * params.stride_src1_2 +
+                   b_i3 * params.stride_src1_3;
+
+    // actual addition operation, now that the indexes are all figured out,
+    // ensuring that the offsets are included
+    // gid.x is used for flat indexing into dst and a, since variable i was modified during the calculations
+    dst[params.offset_dst + gid.x] = src0[params.offset_src0 + gid.x] + src1[params.offset_src1 + src1_idx];
 }
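The shader's index arithmetic can be replayed on the CPU for a single thread. A standalone sketch, not part of the commit; the shapes come from the sample test comment, the thread id is arbitrary, and contiguous tensors are assumed.

    #include <cstdint>
    #include <cstdio>

    int main() {
        const uint32_t a_ne[4] = {20, 5, 4, 3}; // shape of src0/dst
        const uint32_t b_ne[4] = {10, 5, 4, 3}; // shape of src1 (broadcast source)

        uint32_t i = 777;                       // stand-in for gid.x (must be < 20*5*4*3 = 1200)

        // decompose the flat index into the 4d index of a
        uint32_t a_i3 = i / (a_ne[2] * a_ne[1] * a_ne[0]);
        i %= a_ne[2] * a_ne[1] * a_ne[0];
        uint32_t a_i2 = i / (a_ne[1] * a_ne[0]);
        i %= a_ne[1] * a_ne[0];
        uint32_t a_i1 = i / a_ne[0];
        uint32_t a_i0 = i % a_ne[0];

        // broadcast: b's index wraps around with modulo in each dimension
        uint32_t b_i[4] = {a_i0 % b_ne[0], a_i1 % b_ne[1], a_i2 % b_ne[2], a_i3 % b_ne[3]};

        // flat index into b, using contiguous element strides {1, 10, 50, 200}
        const uint32_t b_stride[4] = {1, b_ne[0], b_ne[0] * b_ne[1], b_ne[0] * b_ne[1] * b_ne[2]};
        uint32_t src1_idx = b_i[0] * b_stride[0] + b_i[1] * b_stride[1] +
                            b_i[2] * b_stride[2] + b_i[3] * b_stride[3];

        // prints: a index = [17, 3, 3, 1], src1_idx = 387
        std::printf("a index = [%u, %u, %u, %u], src1_idx = %u\n", a_i0, a_i1, a_i2, a_i3, src1_idx);
        return 0;
    }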