some f32 tests passing

Neha Abbas 2025-08-01 14:35:20 -05:00
parent 1d5726a26a
commit 96d107e505
2 changed files with 187 additions and 0 deletions


@@ -128,6 +128,7 @@ struct webgpu_context_struct {
wgpu::ComputePipeline memset_pipeline;
wgpu::ComputePipeline mul_mat_pipeline;
wgpu::ComputePipeline cpy_pipeline;
wgpu::ComputePipeline add_pipeline;
size_t memset_bytes_per_thread;
@@ -450,6 +451,95 @@ static void ggml_webgpu_mul_mat(webgpu_context & ctx, ggml_tensor * src0, ggml_t
ggml_backend_webgpu_build_and_enqueue(ctx, ctx->mul_mat_pipeline, params, entries, wg_x);
}
// element-wise addition: dst = src0 + src1
static void ggml_webgpu_add(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) {
// each tensor in GGML is stored inside a buffer on the GPU,
// but that buffer may contain more than one tensor's data (or it might be a subregion of a larger buffer)
// offset = starting byte position inside that buffer where this tensor's data actually begins
size_t src0_offset = ggml_backend_webgpu_tensor_offset(src0);
// assumes power of 2 offset alignment
size_t src0_misalignment = src0_offset & (ctx->limits.minStorageBufferOffsetAlignment - 1);
// align to minimum offset alignment
src0_offset &= ~(ctx->limits.minStorageBufferOffsetAlignment - 1);
size_t src1_offset = ggml_backend_webgpu_tensor_offset(src1);
// assumes power of 2 offset alignment
size_t src1_misalignment = src1_offset & (ctx->limits.minStorageBufferOffsetAlignment - 1);
// align to minimum offset alignment
src1_offset &= ~(ctx->limits.minStorageBufferOffsetAlignment - 1);
size_t dst_offset = ggml_backend_webgpu_tensor_offset(dst);
size_t dst_misalignment = dst_offset & (ctx->limits.minStorageBufferOffsetAlignment - 1);
dst_offset &= ~(ctx->limits.minStorageBufferOffsetAlignment - 1);
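// worked example (assuming minStorageBufferOffsetAlignment = 256, the WebGPU default limit):
//   raw byte offset = 784
//   misalignment    = 784 & 255  = 16 bytes -> 16 / sizeof(float) = 4 elements, passed to the shader below
//   aligned offset  = 784 & ~255 = 768 bytes, used when binding the buffer
// inside the shader, indexing at (element offset + idx) lands back on the tensor's real data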
// set up parameters
std::vector<uint32_t> params = {
// total number of elements; determines how many threads to dispatch (one per element-wise addition)
(uint32_t) ggml_nelements(dst),
// even though tensors are 4d, the actual data is stored linearly
// stride = how many elements we must skip in memory to move by one along a given dimension
// (ggml's nb[] holds byte strides; dividing by the type size below converts them to element strides)
// e.g. for a contiguous tensor with ne[0] = 5, ne[1] = 4, ne[2] = 3:
// nb[0] = 1                  // each element is next to the previous
// nb[1] = nb[0] * ne[0] = 5  // to move to the next row, skip 5 elements
// nb[2] = nb[1] * ne[1] = 20 // to move to the next matrix, skip 20 elements
// nb[3] = nb[2] * ne[2] = 60 // to move to the next batch, skip 60 elements
// calculate element strides for each tensor
(uint32_t) (src0->nb[0] / ggml_type_size(src0->type)),
(uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
(uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
(uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),
(uint32_t) (src1->nb[0] / ggml_type_size(src1->type)),
(uint32_t) (src1->nb[1] / ggml_type_size(src1->type)),
(uint32_t) (src1->nb[2] / ggml_type_size(src1->type)),
(uint32_t) (src1->nb[3] / ggml_type_size(src1->type)),
(uint32_t) (dst->nb[0] / ggml_type_size(dst->type)),
(uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
(uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
(uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
// number of elements in each dimension
(uint32_t) dst->ne[0],
(uint32_t) dst->ne[1],
(uint32_t) dst->ne[2],
(uint32_t) dst->ne[3],
// offsets in terms of elements instead of bytes
(uint32_t) (src0_misalignment / ggml_type_size(src0->type)),
(uint32_t) (src1_misalignment / ggml_type_size(src1->type)),
(uint32_t) (dst_misalignment / ggml_type_size(dst->type)),
};
// a bind group groups together the GPU resources the shader will use (e.g., buffers holding tensor data)
// each bind group entry describes one resource within the bind group (in this case, one tensor's buffer)
// offset + size specify exactly which region of the GPU buffer the shader should read/write
std::vector<wgpu::BindGroupEntry> entries = {
{ .binding = 0,
.buffer = ggml_backend_webgpu_tensor_buf(src0),
.offset = src0_offset,
.size = (ggml_nbytes(src0) + src0_misalignment + WEBGPU_STORAGE_BUF_BINDING_MULT - 1) & ~(WEBGPU_STORAGE_BUF_BINDING_MULT - 1) },
{ .binding = 1,
.buffer = ggml_backend_webgpu_tensor_buf(src1),
.offset = src1_offset,
.size = (ggml_nbytes(src1) + src1_misalignment + WEBGPU_STORAGE_BUF_BINDING_MULT - 1) & ~(WEBGPU_STORAGE_BUF_BINDING_MULT - 1) },
{ .binding = 2,
.buffer = ggml_backend_webgpu_tensor_buf(dst),
.offset = dst_offset,
.size = (ggml_nbytes(dst) + dst_misalignment + WEBGPU_STORAGE_BUF_BINDING_MULT - 1) & ~(WEBGPU_STORAGE_BUF_BINDING_MULT - 1) }
};
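// the .size expressions above round (nbytes + misalignment) up to the next multiple of
// WEBGPU_STORAGE_BUF_BINDING_MULT; e.g. if that constant were 16 and a tensor needed
// 1000 + 16 = 1016 bytes, the bound range would be padded to 1024 bytes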
size_t max_wg_size = ctx->limits.maxComputeWorkgroupSizeX; // max threads in a single workgroup
uint32_t wg_x = (ggml_nelements(dst) + max_wg_size - 1) / max_wg_size; // number of workgroups to dispatch to cover all elements
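// e.g. with 1000 elements and maxComputeWorkgroupSizeX = 256: wg_x = (1000 + 255) / 256 = 4 workgroups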
ggml_backend_webgpu_build_and_enqueue(ctx, ctx->add_pipeline, params, entries, wg_x); // dispatch shader
}
// Returns true if node has enqueued work into the queue, false otherwise
static bool ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) {
if (ggml_is_empty(node)) {
@@ -476,6 +566,11 @@ static bool ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) {
ggml_webgpu_mul_mat(ctx, src0, src1, node);
break;
}
case GGML_OP_ADD:
{
ggml_webgpu_add(ctx, src0, src1, node);
break;
}
default:
return false;
}
@@ -759,6 +854,14 @@ static void ggml_webgpu_init_cpy_pipeline(webgpu_context & webgpu_ctx) {
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->cpy_pipeline, wgsl_cpy, "cpy", constants);
}
static void ggml_webgpu_init_add_pipeline(webgpu_context & webgpu_ctx) {
std::vector<wgpu::ConstantEntry> constants(1);
constants[0].key = "wg_size";
constants[0].value = webgpu_ctx->limits.maxComputeWorkgroupSizeX;
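// the "wg_size" key matches the `override wg_size: u32;` pipeline-overridable constant declared in the add shader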
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline, wgsl_add, "add", constants);
}
static ggml_backend_t ggml_backend_webgpu_device_init(ggml_backend_dev_t dev, const char * params) {
GGML_UNUSED(params);
@@ -812,6 +915,7 @@ static ggml_backend_t ggml_backend_webgpu_device_init(ggml_backend_dev_t dev, co
ggml_webgpu_init_memset_pipeline(webgpu_ctx);
ggml_webgpu_init_mul_mat_pipeline(webgpu_ctx);
ggml_webgpu_init_cpy_pipeline(webgpu_ctx);
ggml_webgpu_init_add_pipeline(webgpu_ctx);
webgpu_ctx->device_init = true;
}
@@ -866,6 +970,11 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
return op->type == GGML_TYPE_F16 && op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_MUL_MAT:
return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32;
case GGML_OP_ADD:
// only F32 is enabled for now; the commented-out check below would also allow F16
// return (op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_F32) &&
// (op->src[0]->type == op->type) &&
// (op->src[1]->type == op->type);
return op->type == GGML_TYPE_F32;
default:
return false;
}


@@ -0,0 +1,78 @@
enable f16; // f16 extension is enabled, although this shader currently only reads and writes f32 data
@group(0) @binding(0)
var<storage, read_write> src0: array<f32>; // first input tensor
@group(0) @binding(1)
var<storage, read_write> src1: array<f32>; // second input tensor
@group(0) @binding(2)
var<storage, read_write> dst: array<f32>; // output tensor
struct Params {
ne: u32, // total number of elements
stride_src0_0: u32,
stride_src0_1: u32,
stride_src0_2: u32,
stride_src0_3: u32,
stride_src1_0: u32,
stride_src1_1: u32,
stride_src1_2: u32,
stride_src1_3: u32,
stride_dst_0: u32,
stride_dst_1: u32,
stride_dst_2: u32,
stride_dst_3: u32,
ne0: u32,
ne1: u32,
ne2: u32,
ne3: u32,
// offsets in elements
offset_src0: u32,
offset_src1: u32,
offset_dst: u32,
};
@group(0) @binding(3)
var<uniform> params: Params;
override wg_size: u32; // pipeline-overridable workgroup size, set on the host from maxComputeWorkgroupSizeX
@compute @workgroup_size(wg_size)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
if (gid.x >= params.ne) {
return;
}
var i = gid.x; // flat element index handled by this thread
// compute indexes for each dimension of the tensor
let i3 = i / (params.ne2 * params.ne1 * params.ne0);
i = i % (params.ne2 * params.ne1 * params.ne0);
let i2 = i / (params.ne1 * params.ne0);
i = i % (params.ne1 * params.ne0);
let i1 = i / params.ne0;
let i0 = i % params.ne0;
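// e.g. for ne0 = 5, ne1 = 4, ne2 = 3, ne3 = 2 and gid.x = 37:
//   i3 = 37 / 60 = 0, remainder 37
//   i2 = 37 / 20 = 1, remainder 17
//   i1 = 17 / 5  = 3
//   i0 = 17 % 5  = 2   -> element (2, 3, 1, 0)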
// convert the 4d indices into a flat index into each tensor's array using its element strides
let src0_idx = i0 * params.stride_src0_0 + i1 * params.stride_src0_1 +
i2 * params.stride_src0_2 + i3 * params.stride_src0_3;
let src1_idx = i0 * params.stride_src1_0 + i1 * params.stride_src1_1 +
i2 * params.stride_src1_2 + i3 * params.stride_src1_3;
let dst_idx = i0 * params.stride_dst_0 + i1 * params.stride_dst_1 +
i2 * params.stride_dst_2 + i3 * params.stride_dst_3;
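// for contiguous tensors the strides are (1, ne0, ne0*ne1, ne0*ne1*ne2),
// so src0_idx, src1_idx and dst_idx all reduce to the flat index gid.x;
// keeping separate strides allows non-contiguous (e.g. permuted) views as well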
// element-wise add; the per-tensor element offsets account for the buffer-offset alignment done on the host
dst[params.offset_dst + dst_idx] = src0[params.offset_src0 + src0_idx] + src1[params.offset_src1 + src1_idx];
}