ggml-webgpu: JIT compile binary operators and handle binding overlaps (#19310)
* ggml webgpu: port binary operators to use pre-wgsl * Add binary.wgsl: unified shader with conditionals for all 4 ops * Add gen_binary_shaders.cpp: build tool for using pre_wgsl preprocessor * Remove bin_op.tmpl.wgsl and binary.wgsl (Python template) * Update CMake to generate binary operator shaders at build time * ggml-webgpu: migrate binary ops to JIT compilation with overlap handling * port binary operators from AOT to pre-wgsl JIT compilation * add src1=dst overlap handling for binary ops * use compile-time workgroup size defines instead of runtime overrides * ggml-webgpu: complete overlap handling for binary ops * add support for inplace & overlap case in binding setup * restructure conditional logic to handle all overlap cases * ensure all buffer bindings are correctly assigned for edge cases * ggml-webgpu: remove unused binary overlap cases Remove src0==src1 binary overlap case that never occurs in practice. * keep INPLACE (src0==dst), OVERLAP (src1==dst), DEFAULT * remove unused src0==src1 and all-same variant * refactor wgsl to eliminate duplication
This commit is contained in:
parent
537eadb1b9
commit
7fbd36c50c
|
|
@ -465,4 +465,73 @@ inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_unary_shader(
|
|||
return result;
|
||||
}
|
||||
|
||||
/** Binary **/
|
||||
|
||||
struct ggml_webgpu_binary_pipeline_key {
|
||||
int type;
|
||||
int op;
|
||||
bool inplace;
|
||||
bool overlap;
|
||||
|
||||
bool operator==(const ggml_webgpu_binary_pipeline_key & other) const {
|
||||
return type == other.type && op == other.op && inplace == other.inplace && overlap == other.overlap;
|
||||
}
|
||||
};
|
||||
|
||||
struct ggml_webgpu_binary_pipeline_key_hash {
|
||||
size_t operator()(const ggml_webgpu_binary_pipeline_key & key) const {
|
||||
size_t seed = 0;
|
||||
ggml_webgpu_hash_combine(seed, key.type);
|
||||
ggml_webgpu_hash_combine(seed, key.op);
|
||||
ggml_webgpu_hash_combine(seed, key.inplace);
|
||||
ggml_webgpu_hash_combine(seed, key.overlap);
|
||||
return seed;
|
||||
}
|
||||
};
|
||||
|
||||
struct ggml_webgpu_binary_shader_lib_context {
|
||||
ggml_webgpu_binary_pipeline_key key;
|
||||
uint32_t max_wg_size;
|
||||
};
|
||||
|
||||
inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_binary_shader(
|
||||
pre_wgsl::Preprocessor & preprocessor,
|
||||
const char * shader_src,
|
||||
const ggml_webgpu_binary_shader_lib_context & context) {
|
||||
std::vector<std::string> defines;
|
||||
std::string op_name = ggml_op_name((ggml_op) context.key.op);
|
||||
std::string variant = op_name;
|
||||
|
||||
defines.push_back(std::string("OP_") + op_name);
|
||||
|
||||
switch (context.key.type) {
|
||||
case GGML_TYPE_F32:
|
||||
defines.push_back("TYPE_F32");
|
||||
variant += "_f32";
|
||||
break;
|
||||
case GGML_TYPE_F16:
|
||||
defines.push_back("TYPE_F16");
|
||||
variant += "_f16";
|
||||
break;
|
||||
default:
|
||||
GGML_ABORT("Unsupported type for binary shader");
|
||||
}
|
||||
|
||||
if (context.key.inplace) {
|
||||
defines.push_back("INPLACE");
|
||||
variant += "_inplace";
|
||||
} else if (context.key.overlap) {
|
||||
defines.push_back("OVERLAP");
|
||||
variant += "_overlap";
|
||||
}
|
||||
|
||||
defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
|
||||
ggml_webgpu_processed_shader result;
|
||||
result.wgsl = preprocessor.preprocess(shader_src, defines);
|
||||
result.variant = variant;
|
||||
ggml_webgpu_generic_shader_decisions * decisions = new ggml_webgpu_generic_shader_decisions();
|
||||
decisions->wg_size = context.max_wg_size;
|
||||
result.decisions = decisions;
|
||||
return result;
|
||||
}
|
||||
#endif // GGML_WEBGPU_SHADER_LIB_HPP
|
||||
|
|
|
|||
|
|
@ -348,13 +348,12 @@ struct webgpu_context_struct {
|
|||
|
||||
std::unordered_map<ggml_webgpu_set_rows_pipeline_key, webgpu_pipeline, ggml_webgpu_set_rows_pipeline_key_hash>
|
||||
set_rows_pipelines;
|
||||
std::map<int, std::map<int, webgpu_pipeline>> get_rows_pipelines; // src_type, vectorized
|
||||
std::map<int, std::map<int, webgpu_pipeline>> get_rows_pipelines; // src_type, vectorized
|
||||
|
||||
std::map<int, std::map<int, webgpu_pipeline>> cpy_pipelines; // src_type, dst_type
|
||||
std::map<int, std::map<int, webgpu_pipeline>> add_pipelines; // type, inplace
|
||||
std::map<int, std::map<int, webgpu_pipeline>> sub_pipelines; // type, inplace
|
||||
std::map<int, std::map<int, webgpu_pipeline>> mul_pipelines; // type, inplace
|
||||
std::map<int, std::map<int, webgpu_pipeline>> div_pipelines; // type, inplace
|
||||
std::map<int, std::map<int, webgpu_pipeline>> cpy_pipelines; // src_type, dst_type
|
||||
|
||||
std::unordered_map<ggml_webgpu_binary_pipeline_key, webgpu_pipeline, ggml_webgpu_binary_pipeline_key_hash>
|
||||
binary_pipelines;
|
||||
|
||||
std::map<int, webgpu_pipeline> rms_norm_pipelines; // inplace
|
||||
std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> rope_pipelines; // type, ff, inplace
|
||||
|
|
@ -823,6 +822,28 @@ static bool ggml_webgpu_tensor_equal(ggml_tensor * a, ggml_tensor * b) {
|
|||
(ggml_webgpu_tensor_offset(a) == ggml_webgpu_tensor_offset(b));
|
||||
}
|
||||
|
||||
// Used to determine if two tensors share the same buffer and their byte ranges overlap,
|
||||
static bool ggml_webgpu_tensor_overlap(ggml_tensor * a, ggml_tensor * b) {
|
||||
return (ggml_webgpu_tensor_buf(a).Get() == ggml_webgpu_tensor_buf(b).Get()) &&
|
||||
ggml_webgpu_tensor_offset(a) < (ggml_webgpu_tensor_offset(b) + ggml_nbytes(b)) &&
|
||||
ggml_webgpu_tensor_offset(b) < (ggml_webgpu_tensor_offset(a) + ggml_nbytes(a));
|
||||
}
|
||||
|
||||
struct binary_overlap_flags {
|
||||
bool inplace; // src0 == dst
|
||||
bool overlap; // src1 == dst
|
||||
};
|
||||
|
||||
static binary_overlap_flags ggml_webgpu_detect_binary_overlap(ggml_tensor * src0,
|
||||
ggml_tensor * src1,
|
||||
ggml_tensor * dst) {
|
||||
binary_overlap_flags flags = {};
|
||||
flags.inplace = ggml_webgpu_tensor_equal(src0, dst);
|
||||
flags.overlap = ggml_webgpu_tensor_overlap(src1, dst);
|
||||
|
||||
return flags;
|
||||
}
|
||||
|
||||
static webgpu_command ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
|
||||
uint32_t ne = (uint32_t) ggml_nelements(dst);
|
||||
|
||||
|
|
@ -1375,14 +1396,42 @@ static webgpu_command ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor * s
|
|||
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
|
||||
}
|
||||
|
||||
static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx,
|
||||
ggml_tensor * src0,
|
||||
ggml_tensor * src1,
|
||||
ggml_tensor * dst,
|
||||
webgpu_pipeline & pipeline,
|
||||
bool inplace) {
|
||||
static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx,
|
||||
ggml_tensor * src0,
|
||||
ggml_tensor * src1,
|
||||
ggml_tensor * dst) {
|
||||
binary_overlap_flags flags = ggml_webgpu_detect_binary_overlap(src0, src1, dst);
|
||||
|
||||
ggml_webgpu_binary_pipeline_key pipeline_key = {
|
||||
.type = dst->type,
|
||||
.op = dst->op,
|
||||
.inplace = flags.inplace,
|
||||
.overlap = flags.overlap,
|
||||
};
|
||||
ggml_webgpu_binary_shader_lib_context shader_lib_ctx = {
|
||||
.key = pipeline_key, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup
|
||||
};
|
||||
|
||||
webgpu_pipeline pipeline;
|
||||
auto it = ctx->binary_pipelines.find(pipeline_key);
|
||||
if (it != ctx->binary_pipelines.end()) {
|
||||
pipeline = it->second;
|
||||
} else {
|
||||
ggml_webgpu_processed_shader processed =
|
||||
ggml_webgpu_preprocess_binary_shader(ctx->p, wgsl_binary, shader_lib_ctx);
|
||||
pipeline =
|
||||
ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
|
||||
pipeline.context = processed.decisions;
|
||||
ctx->binary_pipelines.emplace(pipeline_key, pipeline);
|
||||
}
|
||||
|
||||
ggml_webgpu_generic_shader_decisions decisions =
|
||||
*static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context);
|
||||
|
||||
uint32_t ne = (uint32_t) ggml_nelements(dst);
|
||||
|
||||
std::vector<uint32_t> params = {
|
||||
(uint32_t) ggml_nelements(dst),
|
||||
ne,
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
|
||||
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
|
||||
|
|
@ -1399,24 +1448,30 @@ static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx,
|
|||
(uint32_t) src1->ne[3],
|
||||
};
|
||||
|
||||
std::vector<wgpu::BindGroupEntry> entries = {
|
||||
{ .binding = 0,
|
||||
.buffer = ggml_webgpu_tensor_buf(src0),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, src0),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, src0) },
|
||||
{ .binding = 1,
|
||||
.buffer = ggml_webgpu_tensor_buf(src1),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, src1),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, src1) }
|
||||
};
|
||||
if (!inplace) {
|
||||
std::vector<wgpu::BindGroupEntry> entries;
|
||||
|
||||
entries.push_back({
|
||||
.binding = 0,
|
||||
.buffer = ggml_webgpu_tensor_buf(src0),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, src0),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, src0),
|
||||
});
|
||||
|
||||
entries.push_back({
|
||||
.binding = 1,
|
||||
.buffer = ggml_webgpu_tensor_buf(src1),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, src1),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, src1),
|
||||
});
|
||||
|
||||
if (!flags.inplace && !flags.overlap) {
|
||||
entries.push_back({ .binding = 2,
|
||||
.buffer = ggml_webgpu_tensor_buf(dst),
|
||||
.offset = ggml_webgpu_tensor_align_offset(ctx, dst),
|
||||
.size = ggml_webgpu_tensor_binding_size(ctx, dst) });
|
||||
}
|
||||
|
||||
uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), WEBGPU_MAX_WG_SIZE);
|
||||
uint32_t wg_x = CEIL_DIV(ne, decisions.wg_size);
|
||||
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
|
||||
}
|
||||
|
||||
|
|
@ -2038,25 +2093,10 @@ static std::optional<webgpu_command> ggml_webgpu_encode_node(webgpu_context ctx,
|
|||
return std::nullopt;
|
||||
#endif
|
||||
case GGML_OP_ADD:
|
||||
{
|
||||
int inplace = ggml_webgpu_tensor_equal(src0, node);
|
||||
return ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->add_pipelines[node->type][inplace], inplace);
|
||||
}
|
||||
case GGML_OP_SUB:
|
||||
{
|
||||
int inplace = ggml_webgpu_tensor_equal(src0, node);
|
||||
return ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->sub_pipelines[node->type][inplace], inplace);
|
||||
}
|
||||
case GGML_OP_MUL:
|
||||
{
|
||||
int inplace = ggml_webgpu_tensor_equal(src0, node);
|
||||
return ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->mul_pipelines[node->type][inplace], inplace);
|
||||
}
|
||||
case GGML_OP_DIV:
|
||||
{
|
||||
int inplace = ggml_webgpu_tensor_equal(src0, node);
|
||||
return ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->div_pipelines[node->type][inplace], inplace);
|
||||
}
|
||||
return ggml_webgpu_binary_op(ctx, src0, src1, node);
|
||||
case GGML_OP_RMS_NORM:
|
||||
return ggml_webgpu_rms_norm(ctx, src0, node);
|
||||
case GGML_OP_ROPE:
|
||||
|
|
@ -2665,58 +2705,6 @@ static void ggml_webgpu_init_cpy_pipeline(webgpu_context & webgpu_ctx) {
|
|||
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_cpy_f16_f16, "cpy_f16_f16", constants);
|
||||
}
|
||||
|
||||
static void ggml_webgpu_init_add_pipeline(webgpu_context & webgpu_ctx) {
|
||||
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
|
||||
|
||||
webgpu_ctx->add_pipelines[GGML_TYPE_F32][0] =
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_add_f32, "add_f32", constants);
|
||||
webgpu_ctx->add_pipelines[GGML_TYPE_F16][0] =
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_add_f16, "add_f16", constants);
|
||||
webgpu_ctx->add_pipelines[GGML_TYPE_F32][1] =
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_add_f32_inplace, "add_f32_inplace", constants);
|
||||
webgpu_ctx->add_pipelines[GGML_TYPE_F16][1] =
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_add_f16_inplace, "add_f16_inplace", constants);
|
||||
}
|
||||
|
||||
static void ggml_webgpu_init_sub_pipeline(webgpu_context & webgpu_ctx) {
|
||||
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
|
||||
|
||||
webgpu_ctx->sub_pipelines[GGML_TYPE_F32][0] =
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_sub_f32, "sub_f32", constants);
|
||||
webgpu_ctx->sub_pipelines[GGML_TYPE_F16][0] =
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_sub_f16, "sub_f16", constants);
|
||||
webgpu_ctx->sub_pipelines[GGML_TYPE_F32][1] =
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_sub_f32_inplace, "sub_f32_inplace", constants);
|
||||
webgpu_ctx->sub_pipelines[GGML_TYPE_F16][1] =
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_sub_f16_inplace, "sub_f16_inplace", constants);
|
||||
}
|
||||
|
||||
static void ggml_webgpu_init_mul_pipeline(webgpu_context & webgpu_ctx) {
|
||||
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
|
||||
|
||||
webgpu_ctx->mul_pipelines[GGML_TYPE_F32][0] =
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_f32, "mul_f32", constants);
|
||||
webgpu_ctx->mul_pipelines[GGML_TYPE_F16][0] =
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_f16, "mul_f16", constants);
|
||||
webgpu_ctx->mul_pipelines[GGML_TYPE_F32][1] =
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_f32_inplace, "mul_f32_inplace", constants);
|
||||
webgpu_ctx->mul_pipelines[GGML_TYPE_F16][1] =
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_f16_inplace, "mul_f16_inplace", constants);
|
||||
}
|
||||
|
||||
static void ggml_webgpu_init_div_pipeline(webgpu_context & webgpu_ctx) {
|
||||
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
|
||||
|
||||
webgpu_ctx->div_pipelines[GGML_TYPE_F32][0] =
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_div_f32, "div_f32", constants);
|
||||
webgpu_ctx->div_pipelines[GGML_TYPE_F16][0] =
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_div_f16, "div_f16", constants);
|
||||
webgpu_ctx->div_pipelines[GGML_TYPE_F32][1] =
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_div_f32_inplace, "div_f32_inplace", constants);
|
||||
webgpu_ctx->div_pipelines[GGML_TYPE_F16][1] =
|
||||
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_div_f16_inplace, "div_f16_inplace", constants);
|
||||
}
|
||||
|
||||
static void ggml_webgpu_init_rms_norm_pipeline(webgpu_context & webgpu_ctx) {
|
||||
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_ROW_SPLIT_WG_SIZE);
|
||||
|
||||
|
|
@ -3018,10 +3006,6 @@ static webgpu_context initialize_webgpu_context(ggml_backend_dev_t dev) {
|
|||
ggml_webgpu_init_mul_mat_pipeline(webgpu_ctx);
|
||||
ggml_webgpu_init_get_rows_pipeline(webgpu_ctx);
|
||||
ggml_webgpu_init_cpy_pipeline(webgpu_ctx);
|
||||
ggml_webgpu_init_add_pipeline(webgpu_ctx);
|
||||
ggml_webgpu_init_sub_pipeline(webgpu_ctx);
|
||||
ggml_webgpu_init_mul_pipeline(webgpu_ctx);
|
||||
ggml_webgpu_init_div_pipeline(webgpu_ctx);
|
||||
ggml_webgpu_init_rms_norm_pipeline(webgpu_ctx);
|
||||
ggml_webgpu_init_rope_pipeline(webgpu_ctx);
|
||||
ggml_webgpu_init_glu_pipeline(webgpu_ctx);
|
||||
|
|
|
|||
|
|
@ -1,188 +0,0 @@
|
|||
#define(VARIANTS)
|
||||
|
||||
[
|
||||
{
|
||||
"SHADER_NAME": "add_f32",
|
||||
"REPLS": {
|
||||
"TYPE" : "f32",
|
||||
"OP": "+"
|
||||
},
|
||||
"DECLS": ["NOT_INPLACE"]
|
||||
},
|
||||
{
|
||||
"SHADER_NAME": "add_f16",
|
||||
"REPLS": {
|
||||
"TYPE" : "f16",
|
||||
"OP": "+"
|
||||
},
|
||||
"DECLS": ["NOT_INPLACE"]
|
||||
},
|
||||
{
|
||||
"SHADER_NAME": "add_f32_inplace",
|
||||
"REPLS": {
|
||||
"TYPE" : "f32",
|
||||
"OP": "+"
|
||||
},
|
||||
"DECLS": ["INPLACE"]
|
||||
},
|
||||
{
|
||||
"SHADER_NAME": "add_f16_inplace",
|
||||
"REPLS": {
|
||||
"TYPE" : "f16",
|
||||
"OP": "+"
|
||||
},
|
||||
"DECLS": ["INPLACE"]
|
||||
},
|
||||
{
|
||||
"SHADER_NAME": "mul_f32",
|
||||
"REPLS": {
|
||||
"TYPE" : "f32",
|
||||
"OP": "*"
|
||||
},
|
||||
"DECLS": ["NOT_INPLACE"]
|
||||
},
|
||||
{
|
||||
"SHADER_NAME": "mul_f16",
|
||||
"REPLS": {
|
||||
"TYPE" : "f16",
|
||||
"OP": "*"
|
||||
},
|
||||
"DECLS": ["NOT_INPLACE"]
|
||||
},
|
||||
{
|
||||
"SHADER_NAME": "mul_f32_inplace",
|
||||
"REPLS": {
|
||||
"TYPE" : "f32",
|
||||
"OP": "*"
|
||||
},
|
||||
"DECLS": ["INPLACE"]
|
||||
},
|
||||
{
|
||||
"SHADER_NAME": "mul_f16_inplace",
|
||||
"REPLS": {
|
||||
"TYPE" : "f16",
|
||||
"OP": "*"
|
||||
},
|
||||
"DECLS": ["INPLACE"]
|
||||
},
|
||||
{
|
||||
"SHADER_NAME": "sub_f32",
|
||||
"REPLS": {
|
||||
"TYPE" : "f32",
|
||||
"OP": "-"
|
||||
},
|
||||
"DECLS": ["NOT_INPLACE"]
|
||||
},
|
||||
{
|
||||
"SHADER_NAME": "sub_f16",
|
||||
"REPLS": {
|
||||
"TYPE" : "f16",
|
||||
"OP": "-"
|
||||
},
|
||||
"DECLS": ["NOT_INPLACE"]
|
||||
},
|
||||
{
|
||||
"SHADER_NAME": "sub_f32_inplace",
|
||||
"REPLS": {
|
||||
"TYPE" : "f32",
|
||||
"OP": "-"
|
||||
},
|
||||
"DECLS": ["INPLACE"]
|
||||
},
|
||||
{
|
||||
"SHADER_NAME": "sub_f16_inplace",
|
||||
"REPLS": {
|
||||
"TYPE" : "f16",
|
||||
"OP": "-"
|
||||
},
|
||||
"DECLS": ["INPLACE"]
|
||||
},
|
||||
{
|
||||
"SHADER_NAME": "div_f32",
|
||||
"REPLS": {
|
||||
"TYPE" : "f32",
|
||||
"OP": "/"
|
||||
},
|
||||
"DECLS": ["NOT_INPLACE"]
|
||||
},
|
||||
{
|
||||
"SHADER_NAME": "div_f16",
|
||||
"REPLS": {
|
||||
"TYPE" : "f16",
|
||||
"OP": "/"
|
||||
},
|
||||
"DECLS": ["NOT_INPLACE"]
|
||||
},
|
||||
{
|
||||
"SHADER_NAME": "div_f32_inplace",
|
||||
"REPLS": {
|
||||
"TYPE" : "f32",
|
||||
"OP": "/"
|
||||
},
|
||||
"DECLS": ["INPLACE"]
|
||||
},
|
||||
{
|
||||
"SHADER_NAME": "div_f16_inplace",
|
||||
"REPLS": {
|
||||
"TYPE" : "f16",
|
||||
"OP": "/"
|
||||
},
|
||||
"DECLS": ["INPLACE"]
|
||||
}
|
||||
]
|
||||
|
||||
#end(VARIANTS)
|
||||
|
||||
#define(DECLS)
|
||||
|
||||
#decl(NOT_INPLACE)
|
||||
|
||||
fn update(dst_i: u32, src0_i: u32, src1_i: u32) {
|
||||
dst[dst_i] = src0[src0_i] {{OP}} src1[src1_i];
|
||||
}
|
||||
|
||||
@group(0) @binding(2)
|
||||
var<storage, read_write> dst: array<{{TYPE}}>;
|
||||
|
||||
@group(0) @binding(3)
|
||||
var<uniform> params: Params;
|
||||
|
||||
#enddecl(NOT_INPLACE)
|
||||
|
||||
#decl(INPLACE)
|
||||
|
||||
fn update(dst_i: u32, src0_i: u32, src1_i: u32) {
|
||||
src0[dst_i] = src0[src0_i] {{OP}} src1[src1_i];
|
||||
}
|
||||
|
||||
@group(0) @binding(2)
|
||||
var<uniform> params: Params;
|
||||
|
||||
#enddecl(INPLACE)
|
||||
|
||||
#end(DECLS)
|
||||
|
||||
|
||||
#define(SHADER)
|
||||
|
||||
enable f16;
|
||||
|
||||
#include "binary_head.tmpl"
|
||||
|
||||
@group(0) @binding(0)
|
||||
var<storage, read_write> src0: array<{{TYPE}}>;
|
||||
|
||||
@group(0) @binding(1)
|
||||
var<storage, read_write> src1: array<{{TYPE}}>;
|
||||
|
||||
DECLS
|
||||
|
||||
override wg_size: u32;
|
||||
@compute @workgroup_size(wg_size)
|
||||
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
|
||||
if (gid.x < params.ne) {
|
||||
update(params.offset_dst + gid.x, params.offset_src0 + gid.x, params.offset_src1 + src1_index(gid.x));
|
||||
}
|
||||
}
|
||||
|
||||
#end(SHADER)
|
||||
|
|
@ -0,0 +1,107 @@
|
|||
enable f16;
|
||||
|
||||
struct Params {
|
||||
ne: u32,
|
||||
|
||||
// offsets in elements
|
||||
offset_src0: u32,
|
||||
offset_src1: u32,
|
||||
offset_dst: u32,
|
||||
|
||||
stride_src1_0: u32,
|
||||
stride_src1_1: u32,
|
||||
stride_src1_2: u32,
|
||||
stride_src1_3: u32,
|
||||
|
||||
a_ne0: u32,
|
||||
a_ne1: u32,
|
||||
a_ne2: u32,
|
||||
|
||||
b_ne0: u32,
|
||||
b_ne1: u32,
|
||||
b_ne2: u32,
|
||||
b_ne3: u32,
|
||||
};
|
||||
|
||||
fn src1_index(_i: u32) -> u32 {
|
||||
var i = _i;
|
||||
let a_i3 = i / (params.a_ne2 * params.a_ne1 * params.a_ne0);
|
||||
i = i % (params.a_ne2 * params.a_ne1 * params.a_ne0);
|
||||
let a_i2 = i / (params.a_ne1 * params.a_ne0);
|
||||
i = i % (params.a_ne1 * params.a_ne0);
|
||||
let a_i1 = i / params.a_ne0;
|
||||
let a_i0 = i % params.a_ne0;
|
||||
|
||||
// handle repetition of b
|
||||
// index loops back to the beginning and repeats after elements are exhausted = modulo
|
||||
let b_i0 = a_i0 % params.b_ne0;
|
||||
let b_i1 = a_i1 % params.b_ne1;
|
||||
let b_i2 = a_i2 % params.b_ne2;
|
||||
let b_i3 = a_i3 % params.b_ne3;
|
||||
|
||||
// compute index for position in b's flat array
|
||||
return b_i0 * params.stride_src1_0 +
|
||||
b_i1 * params.stride_src1_1 +
|
||||
b_i2 * params.stride_src1_2 +
|
||||
b_i3 * params.stride_src1_3;
|
||||
}
|
||||
|
||||
#ifdef TYPE_F32
|
||||
#define DataType f32
|
||||
#endif
|
||||
#ifdef TYPE_F16
|
||||
#define DataType f16
|
||||
#endif
|
||||
|
||||
@group(0) @binding(0)
|
||||
var<storage, read_write> src0: array<DataType>;
|
||||
|
||||
@group(0) @binding(1)
|
||||
var<storage, read_write> src1 : array<DataType>;
|
||||
|
||||
#ifdef INPLACE
|
||||
@group(0) @binding(2)
|
||||
var<uniform> params: Params;
|
||||
|
||||
#elif defined(OVERLAP)
|
||||
@group(0) @binding(2)
|
||||
var<uniform> params: Params;
|
||||
|
||||
#else
|
||||
@group(0) @binding(2)
|
||||
var<storage, read_write> dst: array<DataType>;
|
||||
|
||||
@group(0) @binding(3)
|
||||
var<uniform> params: Params;
|
||||
#endif
|
||||
|
||||
fn op(a: DataType, b: DataType) -> DataType {
|
||||
#ifdef OP_ADD
|
||||
return a + b;
|
||||
#elif defined(OP_SUB)
|
||||
return a - b;
|
||||
#elif defined(OP_MUL)
|
||||
return a * b;
|
||||
#elif defined(OP_DIV)
|
||||
return a / b;
|
||||
#endif
|
||||
}
|
||||
|
||||
fn update(dst_i: u32, src0_i: u32, src1_i: u32){
|
||||
let result = op(src0[src0_i], src1[src1_i]);
|
||||
|
||||
#ifdef INPLACE
|
||||
src0[dst_i] = result;
|
||||
#elif defined(OVERLAP)
|
||||
src1[dst_i] = result;
|
||||
#else
|
||||
dst[dst_i] = result;
|
||||
#endif
|
||||
}
|
||||
|
||||
@compute @workgroup_size(WG_SIZE)
|
||||
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
|
||||
if (gid.x < params.ne) {
|
||||
update(params.offset_dst + gid.x, params.offset_src0 + gid.x, params.offset_src1 + src1_index(gid.x));
|
||||
}
|
||||
}
|
||||
|
|
@ -1,45 +0,0 @@
|
|||
struct Params {
|
||||
ne: u32,
|
||||
|
||||
// offsets in elements
|
||||
offset_src0: u32,
|
||||
offset_src1: u32,
|
||||
offset_dst: u32,
|
||||
|
||||
stride_src1_0: u32,
|
||||
stride_src1_1: u32,
|
||||
stride_src1_2: u32,
|
||||
stride_src1_3: u32,
|
||||
|
||||
a_ne0: u32,
|
||||
a_ne1: u32,
|
||||
a_ne2: u32,
|
||||
|
||||
b_ne0: u32,
|
||||
b_ne1: u32,
|
||||
b_ne2: u32,
|
||||
b_ne3: u32,
|
||||
};
|
||||
|
||||
fn src1_index(_i: u32) -> u32 {
|
||||
var i = _i;
|
||||
let a_i3 = i / (params.a_ne2 * params.a_ne1 * params.a_ne0);
|
||||
i = i % (params.a_ne2 * params.a_ne1 * params.a_ne0);
|
||||
let a_i2 = i / (params.a_ne1 * params.a_ne0);
|
||||
i = i % (params.a_ne1 * params.a_ne0);
|
||||
let a_i1 = i / params.a_ne0;
|
||||
let a_i0 = i % params.a_ne0;
|
||||
|
||||
// handle repetition of b
|
||||
// index loops back to the beginning and repeats after elements are exhausted = modulo
|
||||
let b_i0 = a_i0 % params.b_ne0;
|
||||
let b_i1 = a_i1 % params.b_ne1;
|
||||
let b_i2 = a_i2 % params.b_ne2;
|
||||
let b_i3 = a_i3 % params.b_ne3;
|
||||
|
||||
// compute index for position in b's flat array
|
||||
return b_i0 * params.stride_src1_0 +
|
||||
b_i1 * params.stride_src1_1 +
|
||||
b_i2 * params.stride_src1_2 +
|
||||
b_i3 * params.stride_src1_3;
|
||||
}
|
||||
Loading…
Reference in New Issue