Get addition and multiplication working
parent 7f9ee10e75 · commit c10219705d
@@ -126,6 +126,9 @@ struct webgpu_context_struct {
     wgpu::ComputePipeline set_rows_pipeline;
     wgpu::ComputePipeline cpy_pipeline;
     wgpu::ComputePipeline add_pipeline[2];
+    wgpu::ComputePipeline add_ip_pipeline[2];
+    wgpu::ComputePipeline mul_pipeline[2];
+    wgpu::ComputePipeline mul_ip_pipeline[2];
 
     size_t memset_bytes_per_thread;
 
@@ -347,6 +350,7 @@ static void ggml_backend_webgpu_build_and_enqueue(webgpu_context &
                                                   std::vector<uint32_t> params,
                                                   std::vector<wgpu::BindGroupEntry> bind_group_entries,
                                                   uint32_t wg_x,
+                                                  const char * bind_group_label = nullptr,
                                                   bool submit_and_wait = false) {
     webgpu_pool_bufs params_bufs = ctx->param_buf_pool.alloc_bufs();
 
@@ -368,6 +372,9 @@ static void ggml_backend_webgpu_build_and_enqueue(webgpu_context &
     bind_group_desc.layout = pipeline.GetBindGroupLayout(0);
     bind_group_desc.entryCount = bind_group_entries.size();
     bind_group_desc.entries = bind_group_entries.data();
+    if (bind_group_label) {
+        bind_group_desc.label = bind_group_label;
+    }
     wgpu::BindGroup bind_group = ctx->device.CreateBindGroup(&bind_group_desc);
 
     wgpu::CommandEncoder encoder = ctx->device.CreateCommandEncoder();
 
@@ -413,7 +420,7 @@ static void ggml_backend_webgpu_buffer_memset(webgpu_context & ctx,
     };
     size_t bytes_per_wg = ctx->limits.maxComputeWorkgroupSizeX * ctx->memset_bytes_per_thread;
     uint32_t wg_x = ((size + 3) + bytes_per_wg - 1) / bytes_per_wg;
-    ggml_backend_webgpu_build_and_enqueue(ctx, ctx->memset_pipeline, params, entries, wg_x, true);
+    ggml_backend_webgpu_build_and_enqueue(ctx, ctx->memset_pipeline, params, entries, wg_x, "MEMSET", true);
 }
 
 /** End WebGPU Actions */
 
@@ -457,6 +464,12 @@ static size_t ggml_webgpu_tensor_binding_size(webgpu_context & ctx, ggml_tensor
            ~(WEBGPU_STORAGE_BUF_BINDING_MULT - 1);
 }
 
+// Used to determine if two tensors are the same for in-place operations
+static bool ggml_webgpu_tensor_equal(ggml_tensor * a, ggml_tensor * b) {
+    return (ggml_webgpu_tensor_buf(a).Get() == ggml_webgpu_tensor_buf(b).Get()) &&
+           (ggml_webgpu_tensor_offset(a) == ggml_webgpu_tensor_offset(b));
+}
+
 static void ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
     uint32_t ne = (uint32_t) ggml_nelements(dst);
 
@@ -485,7 +498,7 @@ static void ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, ggml_tensor
 
     size_t max_wg_size = ctx->limits.maxComputeWorkgroupSizeX;
     uint32_t wg_x = (ne + max_wg_size - 1) / max_wg_size;
-    ggml_backend_webgpu_build_and_enqueue(ctx, ctx->cpy_pipeline, params, entries, wg_x);
+    ggml_backend_webgpu_build_and_enqueue(ctx, ctx->cpy_pipeline, params, entries, wg_x, ggml_op_name(dst->op));
 }
 
 static void ggml_webgpu_set_rows(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * idx, ggml_tensor * dst) {
 
@@ -537,7 +550,7 @@ static void ggml_webgpu_set_rows(webgpu_context & ctx, ggml_tensor * src, ggml_t
     std::lock_guard<std::recursive_mutex> lock(ctx->mutex);
     ctx->staged_set_row_error_bufs.push_back(error_bufs);
 
-    ggml_backend_webgpu_build_and_enqueue(ctx, ctx->set_rows_pipeline, params, entries, wg_x);
+    ggml_backend_webgpu_build_and_enqueue(ctx, ctx->set_rows_pipeline, params, entries, wg_x, ggml_op_name(dst->op));
 }
 
 static void ggml_webgpu_mul_mat(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) {
 
@@ -577,10 +590,16 @@ static void ggml_webgpu_mul_mat(webgpu_context & ctx, ggml_tensor * src0, ggml_t
 
     uint32_t wg_x =
         (dst->ne[0] * dst->ne[1] * dst->ne[2] * dst->ne[3] + WEBGPU_MUL_MAT_WG_SIZE - 1) / WEBGPU_MUL_MAT_WG_SIZE;
-    ggml_backend_webgpu_build_and_enqueue(ctx, ctx->mul_mat_pipeline[src0->type][src1->type], params, entries, wg_x);
+    ggml_backend_webgpu_build_and_enqueue(ctx, ctx->mul_mat_pipeline[src0->type][src1->type], params, entries, wg_x,
+                                          ggml_op_name(dst->op));
 }
 
-static void ggml_webgpu_add(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_webgpu_binary_op(webgpu_context & ctx,
+                                  ggml_tensor * src0,
+                                  ggml_tensor * src1,
+                                  ggml_tensor * dst,
+                                  wgpu::ComputePipeline & pipeline,
+                                  bool in_place) {
     std::vector<uint32_t> params = {
         (uint32_t) ggml_nelements(dst),
         (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
 
@@ -607,16 +626,18 @@ static void ggml_webgpu_add(webgpu_context & ctx, ggml_tensor * src0, ggml_tenso
         { .binding = 1,
           .buffer = ggml_webgpu_tensor_buf(src1),
           .offset = ggml_webgpu_tensor_align_offset(ctx, src1),
-          .size = ggml_webgpu_tensor_binding_size(ctx, src1) },
-        { .binding = 2,
+          .size = ggml_webgpu_tensor_binding_size(ctx, src1) }
+    };
+    if (!in_place) {
+        entries.push_back({ .binding = 2,
           .buffer = ggml_webgpu_tensor_buf(dst),
           .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
-          .size = ggml_webgpu_tensor_binding_size(ctx, dst) }
-    };
+          .size = ggml_webgpu_tensor_binding_size(ctx, dst) });
+    }
 
     size_t max_wg_size = ctx->limits.maxComputeWorkgroupSizeX;
     uint32_t wg_x = (ggml_nelements(dst) + max_wg_size - 1) / max_wg_size;
-    ggml_backend_webgpu_build_and_enqueue(ctx, ctx->add_pipeline[dst->type], params, entries, wg_x);
+    ggml_backend_webgpu_build_and_enqueue(ctx, pipeline, params, entries, wg_x, ggml_op_name(dst->op));
 }
 
 // Returns true if node has enqueued work into the queue, false otherwise
 
@@ -654,7 +675,20 @@ static bool ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) {
             }
         case GGML_OP_ADD:
             {
-                ggml_webgpu_add(ctx, src0, src1, node);
+                if (ggml_webgpu_tensor_equal(src0, node)) {
+                    ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->add_ip_pipeline[node->type], true);
+                } else {
+                    ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->add_pipeline[node->type], false);
+                }
+                break;
+            }
+        case GGML_OP_MUL:
+            {
+                if (ggml_webgpu_tensor_equal(src0, node)) {
+                    ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->mul_ip_pipeline[node->type], true);
+                } else {
+                    ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->mul_pipeline[node->type], false);
+                }
                 break;
             }
         default:
 
@@ -994,8 +1028,28 @@ static void ggml_webgpu_init_add_pipeline(webgpu_context & webgpu_ctx) {
     std::vector<wgpu::ConstantEntry> constants(1);
     constants[0].key = "wg_size";
    constants[0].value = webgpu_ctx->limits.maxComputeWorkgroupSizeX;
-    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F32], wgsl_add_f32, "add_f32", constants);
-    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F16], wgsl_add_f16, "add_f16", constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F32], wgsl_add_f32, "add_f32",
+                                constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F16], wgsl_add_f16, "add_f16",
+                                constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_ip_pipeline[GGML_TYPE_F32], wgsl_add_in_place_f32,
+                                "add_in_place_f32", constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_ip_pipeline[GGML_TYPE_F16], wgsl_add_in_place_f16,
+                                "add_in_place_f16", constants);
+}
+
+static void ggml_webgpu_init_mul_pipeline(webgpu_context & webgpu_ctx) {
+    std::vector<wgpu::ConstantEntry> constants(1);
+    constants[0].key = "wg_size";
+    constants[0].value = webgpu_ctx->limits.maxComputeWorkgroupSizeX;
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F32], wgsl_mul_f32, "mul_f32",
+                                constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F16], wgsl_mul_f16, "mul_f16",
+                                constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_ip_pipeline[GGML_TYPE_F32], wgsl_mul_in_place_f32,
+                                "mul_in_place_f32", constants);
+    ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_ip_pipeline[GGML_TYPE_F16], wgsl_mul_in_place_f16,
+                                "mul_in_place_f16", constants);
 }
 
 static ggml_backend_t ggml_backend_webgpu_device_init(ggml_backend_dev_t dev, const char * params) {
 
@@ -1048,22 +1102,30 @@ static bool ggml_backend_webgpu_device_supports_buft(ggml_backend_dev_t dev, ggm
 static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
     GGML_UNUSED(dev);
 
+    bool supports_op = false;
     switch (op->op) {
         case GGML_OP_NONE:
         case GGML_OP_VIEW:
         case GGML_OP_PERMUTE:
         case GGML_OP_TRANSPOSE:
         case GGML_OP_RESHAPE:
+            supports_op = true;
+            break;
         case GGML_OP_ADD:
-            return true;
+        case GGML_OP_MUL:
+            supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (op->src[0]->type == op->type) &&
+                          (op->src[1]->type == op->type);
+            break;
         case GGML_OP_CPY:
         case GGML_OP_SET_ROWS:
-            return op->type == GGML_TYPE_F16 && op->src[0]->type == GGML_TYPE_F32;
+            supports_op = (op->type == GGML_TYPE_F16 && op->src[0]->type == GGML_TYPE_F32);
+            break;
         case GGML_OP_MUL_MAT:
             {
                 switch (op->src[1]->type) {
                     case GGML_TYPE_F16:
-                        return op->src[0]->type == GGML_TYPE_F16;
+                        supports_op = (op->src[0]->type == GGML_TYPE_F16);
+                        break;
                     case GGML_TYPE_F32:
                         switch (op->src[0]->type) {
                             case GGML_TYPE_F32:
 
@@ -1087,17 +1149,26 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
                             case GGML_TYPE_IQ1_M:
                             case GGML_TYPE_IQ4_NL:
                             case GGML_TYPE_IQ4_XS:
-                                return true;
+                                supports_op = true;
+                                break;
                             default:
-                                return false;
+                                break;
                         }
                     default:
-                        return false;
+                        break;
                 }
             }
         default:
-            return false;
+            break;
     }
+#ifdef GGML_WEBGPU_DEBUG
+    if (!supports_op) {
+        WEBGPU_LOG_DEBUG("not supported: " << ggml_op_name(op->op) << " with types dst: " << ggml_type_name(op->type)
+                                           << ", src0: " << (op->src[0] ? ggml_type_name(op->src[0]->type) : "null")
+                                           << ", src1: " << (op->src[1] ? ggml_type_name(op->src[1]->type) : "null"));
+    }
+#endif
+    return supports_op;
 }
 
 static struct ggml_backend_device_i ggml_backend_webgpu_device_i = {
 
@@ -1210,6 +1281,7 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t
     ggml_webgpu_init_set_rows_pipeline(ctx);
     ggml_webgpu_init_cpy_pipeline(ctx);
     ggml_webgpu_init_add_pipeline(ctx);
+    ggml_webgpu_init_mul_pipeline(ctx);
 
 #ifdef GGML_WEBGPU_DEBUG
     // Initialize debug buffers
 
@@ -19,6 +19,8 @@
 
 enable f16;
 
+#include "binary_head.tmpl"
+
 @group(0) @binding(0)
 var<storage, read_write> src0: array<{{TYPE}}>;
 
@@ -28,71 +30,15 @@ var<storage, read_write> src1: array<{{TYPE}}>
 @group(0) @binding(2)
 var<storage, read_write> dst: array<{{TYPE}}>;
 
-struct Params {
-    ne: u32,
-
-    // offsets in elements
-    offset_src0: u32,
-    offset_src1: u32,
-    offset_dst: u32,
-
-    stride_src1_0: u32,
-    stride_src1_1: u32,
-    stride_src1_2: u32,
-    stride_src1_3: u32,
-
-    a_ne0: u32,
-    a_ne1: u32,
-    a_ne2: u32,
-
-    b_ne0: u32,
-    b_ne1: u32,
-    b_ne2: u32,
-    b_ne3: u32,
-};
-
 @group(0) @binding(3)
 var<uniform> params: Params;
 
 override wg_size: u32;
 @compute @workgroup_size(wg_size)
 fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
-    if (gid.x >= params.ne) {
-        return;
+    if (gid.x < params.ne) {
+        dst[params.offset_dst + gid.x] = src0[params.offset_src0 + gid.x] + src1[params.offset_src1 + src1_index(gid.x)];
     }
-
-    // i = thread id, ranges from 0 --> total ne - 1
-    // represents the position in the flat array a we are adding with array b
-    var i = gid.x;
-
-    // given the index of linear a, we want to compute the 4d index [a_i0, a_i1, a_i2, a_i3]
-    // we need this because tensor a and b are different shapes
-    // so the same linear index won't work for b, and we can only compute b's linear index from the 4d index of a
-
-    let a_i3 = i / (params.a_ne2 * params.a_ne1 * params.a_ne0);
-    i = i % (params.a_ne2 * params.a_ne1 * params.a_ne0);
-    let a_i2 = i / (params.a_ne1 * params.a_ne0);
-    i = i % (params.a_ne1 * params.a_ne0);
-    let a_i1 = i / params.a_ne0;
-    let a_i0 = i % params.a_ne0;
-
-    // handle repetition of b
-    // index loops back to the beginning and repeats after elements are exhausted = modulo
-    let b_i0 = a_i0 % params.b_ne0;
-    let b_i1 = a_i1 % params.b_ne1;
-    let b_i2 = a_i2 % params.b_ne2;
-    let b_i3 = a_i3 % params.b_ne3;
-
-    // compute index for position in b's flat array
-    let src1_idx = b_i0 * params.stride_src1_0 +
-                   b_i1 * params.stride_src1_1 +
-                   b_i2 * params.stride_src1_2 +
-                   b_i3 * params.stride_src1_3;
-
-    // actual addition operation, now that the indexes are all figured out
-    // ensuring that the offsets are included
-    // gid.x used for flat indexing into dst and a, since variable i was modified during calcs
-    dst[params.offset_dst + gid.x] = src0[params.offset_src0 + gid.x] + src1[params.offset_src1 + src1_idx];
 }
 
 #end(SHADER)
 
@@ -0,0 +1,41 @@
+#define(VARIANTS)
+
+[
+  {
+    "REPLS": {
+      "TYPE" : "f32",
+    }
+  },
+  {
+    "REPLS": {
+      "TYPE" : "f16",
+    }
+  }
+]
+
+#end(VARIANTS)
+
+#define(SHADER)
+
+enable f16;
+
+#include "binary_head.tmpl"
+
+@group(0) @binding(0)
+var<storage, read_write> src0: array<{{TYPE}}>;
+
+@group(0) @binding(1)
+var<storage, read_write> src1: array<{{TYPE}}>;
+
+@group(0) @binding(2)
+var<uniform> params: Params;
+
+override wg_size: u32;
+@compute @workgroup_size(wg_size)
+fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
+    if (gid.x < params.ne) {
+        src0[params.offset_dst + gid.x] = src0[params.offset_src0 + gid.x] + src1[params.offset_src1 + src1_index(gid.x)];
+    }
+}
+
+#end(SHADER)
 
@@ -0,0 +1,45 @@
+struct Params {
+    ne: u32,
+
+    // offsets in elements
+    offset_src0: u32,
+    offset_src1: u32,
+    offset_dst: u32,
+
+    stride_src1_0: u32,
+    stride_src1_1: u32,
+    stride_src1_2: u32,
+    stride_src1_3: u32,
+
+    a_ne0: u32,
+    a_ne1: u32,
+    a_ne2: u32,
+
+    b_ne0: u32,
+    b_ne1: u32,
+    b_ne2: u32,
+    b_ne3: u32,
+};
+
+fn src1_index(_i: u32) -> u32 {
+    var i = _i;
+    let a_i3 = i / (params.a_ne2 * params.a_ne1 * params.a_ne0);
+    i = i % (params.a_ne2 * params.a_ne1 * params.a_ne0);
+    let a_i2 = i / (params.a_ne1 * params.a_ne0);
+    i = i % (params.a_ne1 * params.a_ne0);
+    let a_i1 = i / params.a_ne0;
+    let a_i0 = i % params.a_ne0;
+
+    // handle repetition of b
+    // index loops back to the beginning and repeats after elements are exhausted = modulo
+    let b_i0 = a_i0 % params.b_ne0;
+    let b_i1 = a_i1 % params.b_ne1;
+    let b_i2 = a_i2 % params.b_ne2;
+    let b_i3 = a_i3 % params.b_ne3;
+
+    // compute index for position in b's flat array
+    return b_i0 * params.stride_src1_0 +
+           b_i1 * params.stride_src1_1 +
+           b_i2 * params.stride_src1_2 +
+           b_i3 * params.stride_src1_3;
+}
 
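Note: the shared header above (included as binary_head.tmpl) carries the broadcasting logic for both add and mul: a thread's flat index into the output and src0 is unpacked into a 4-D index, each coordinate is wrapped modulo src1's extent so a smaller src1 repeats, and the wrapped coordinates are re-flattened with src1's strides. A minimal Python sketch of that same mapping, handy for checking it on the CPU; the shapes, strides, and helper name below are illustrative only, not taken from the backend:

    # Illustrative re-implementation of the src1_index() mapping from binary_head.tmpl.
    # a_ne / b_ne are example extents (ne0..ne3); b_stride are example element strides.
    def src1_index(i, a_ne, b_ne, b_stride):
        # unpack the flat index into a 4-D index over a's extents
        a_i3 = i // (a_ne[2] * a_ne[1] * a_ne[0])
        i %= a_ne[2] * a_ne[1] * a_ne[0]
        a_i2 = i // (a_ne[1] * a_ne[0])
        i %= a_ne[1] * a_ne[0]
        a_i1 = i // a_ne[0]
        a_i0 = i % a_ne[0]
        # b repeats whenever its extent is smaller than a's (broadcasting = modulo)
        b_i = (a_i0 % b_ne[0], a_i1 % b_ne[1], a_i2 % b_ne[2], a_i3 % b_ne[3])
        # re-flatten with b's strides (in elements)
        return sum(idx * stride for idx, stride in zip(b_i, b_stride))

    # Example: a is 4 x 3, b is a single row of 4 elements broadcast across the 3 rows.
    a_ne = (4, 3, 1, 1)
    b_ne = (4, 1, 1, 1)
    b_stride = (1, 4, 4, 4)
    print([src1_index(i, a_ne, b_ne, b_stride) for i in range(12)])
    # -> [0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]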
@@ -26,6 +26,24 @@ def replace_placeholders(shader_text, replacements):
         shader_text = re.sub(pattern, str(val), shader_text)
     return shader_text
 
+
+def expand_includes(shader, input_dir):
+    """
+    Replace #include "file" lines in the text with the contents of that file.
+    Searches for files relative to input_dir.
+    """
+    include_pattern = re.compile(r'^\s*#include\s+"([^"]+)"\s*$', re.MULTILINE)
+
+    def replacer(match):
+        fname = match.group(1)
+        file_path = os.path.join(input_dir, fname)
+        if not os.path.exists(file_path):
+            raise FileNotFoundError(f"Included file not found: {file_path}")
+        with open(file_path, "r", encoding="utf-8") as f:
+            included_code = f.read()
+        # Recursively expand includes inside the included file
+        return expand_includes(included_code, input_dir)
+
+    return include_pattern.sub(replacer, shader)
 
 def write_shader(shader_name, shader_code, output_dir, outfile):
     if output_dir:
 
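Note: expand_includes is what lets the shader templates pull in the shared header via #include "binary_head.tmpl" before variant generation. A self-contained usage sketch follows; the temp directory and file contents are throwaway examples, and the helper here is a trimmed copy of the script's logic with the error handling omitted:

    # Minimal demo of the include expansion added to the generator script.
    import os
    import re
    import tempfile

    def expand_includes(shader, input_dir):
        include_pattern = re.compile(r'^\s*#include\s+"([^"]+)"\s*$', re.MULTILINE)

        def replacer(match):
            path = os.path.join(input_dir, match.group(1))
            with open(path, "r", encoding="utf-8") as f:
                # recurse so included files may themselves use #include
                return expand_includes(f.read(), input_dir)

        return include_pattern.sub(replacer, shader)

    with tempfile.TemporaryDirectory() as d:
        with open(os.path.join(d, "binary_head.tmpl"), "w", encoding="utf-8") as f:
            f.write("struct Params { ne: u32, };")
        shader = 'enable f16;\n#include "binary_head.tmpl"\nfn main() {}'
        # The #include line is replaced inline by the contents of binary_head.tmpl.
        print(expand_includes(shader, d))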
@@ -35,8 +53,9 @@ def write_shader(shader_name, shader_code, output_dir, outfile):
         outfile.write(f'const char* wgsl_{shader_name} = R"({shader_code})";\n\n')
 
 
-def generate_variants(shader_path, output_dir, outfile):
-    shader_base_name = shader_path.split("/")[-1].split(".")[0]
+def generate_variants(fname, input_dir, output_dir, outfile):
+    shader_path = os.path.join(input_dir, fname)
+    shader_base_name = fname.split(".")[0]
 
     with open(shader_path, "r", encoding="utf-8") as f:
         text = f.read()
 
@@ -52,6 +71,7 @@ def generate_variants(shader_path, output_dir, outfile):
     decls_map = {}
 
     shader_template = extract_block(text, "SHADER")
+    shader_template = expand_includes(shader_template, input_dir)
     for variant in variants:
         if "DECLS" in variant:
             decls = variant["DECLS"]
 
@@ -89,7 +109,7 @@ def main():
         out.write("// Auto-generated shader embedding\n\n")
         for fname in sorted(os.listdir(args.input_dir)):
             if fname.endswith(".wgsl"):
-                generate_variants(os.path.join(args.input_dir, fname), args.output_dir, out)
+                generate_variants(fname, args.input_dir, args.output_dir, out)
 
 
 if __name__ == "__main__":
 
@@ -0,0 +1,44 @@
+#define(VARIANTS)
+
+[
+  {
+    "REPLS": {
+      "TYPE" : "f32",
+    }
+  },
+  {
+    "REPLS": {
+      "TYPE" : "f16",
+    }
+  }
+]
+
+#end(VARIANTS)
+
+#define(SHADER)
+
+enable f16;
+
+#include "binary_head.tmpl"
+
+@group(0) @binding(0)
+var<storage, read_write> src0: array<{{TYPE}}>;
+
+@group(0) @binding(1)
+var<storage, read_write> src1: array<{{TYPE}}>;
+
+@group(0) @binding(2)
+var<storage, read_write> dst: array<{{TYPE}}>;
+
+@group(0) @binding(3)
+var<uniform> params: Params;
+
+override wg_size: u32;
+@compute @workgroup_size(wg_size)
+fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
+    if (gid.x < params.ne) {
+        dst[params.offset_dst + gid.x] = src0[params.offset_src0 + gid.x] * src1[params.offset_src1 + src1_index(gid.x)];
+    }
+}
+
+#end(SHADER)
 
@@ -0,0 +1,41 @@
+#define(VARIANTS)
+
+[
+  {
+    "REPLS": {
+      "TYPE" : "f32",
+    }
+  },
+  {
+    "REPLS": {
+      "TYPE" : "f16",
+    }
+  }
+]
+
+#end(VARIANTS)
+
+#define(SHADER)
+
+enable f16;
+
+#include "binary_head.tmpl"
+
+@group(0) @binding(0)
+var<storage, read_write> src0: array<{{TYPE}}>;
+
+@group(0) @binding(1)
+var<storage, read_write> src1: array<{{TYPE}}>;
+
+@group(0) @binding(2)
+var<uniform> params: Params;
+
+override wg_size: u32;
+@compute @workgroup_size(wg_size)
+fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
+    if (gid.x < params.ne) {
+        src0[params.offset_dst + gid.x] = src0[params.offset_src0 + gid.x] * src1[params.offset_src1 + src1_index(gid.x)];
+    }
+}
+
+#end(SHADER)