This commit is contained in:
Abhijit Ramesh 2026-03-24 02:47:00 -07:00 committed by GitHub
commit 0a5a393ff8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 766 additions and 1003 deletions

View File

@ -535,6 +535,95 @@ struct ggml_webgpu_mul_mat_shader_decisions {
uint32_t mul_mat_wg_size;
};
/** Cpy **/
// Cache key identifying a compiled cpy (tensor copy/cast) pipeline variant.
struct ggml_webgpu_cpy_pipeline_key {
    ggml_type src_type;  // element type of the source tensor
    ggml_type dst_type;  // element type of the destination tensor
    bool operator==(const ggml_webgpu_cpy_pipeline_key & rhs) const {
        return dst_type == rhs.dst_type && src_type == rhs.src_type;
    }
};
// Hash functor so ggml_webgpu_cpy_pipeline_key can be an unordered_map key.
struct ggml_webgpu_cpy_pipeline_key_hash {
    size_t operator()(const ggml_webgpu_cpy_pipeline_key & k) const {
        // Fold both fields into one value; combine order matches the
        // field declaration order used by the other key hashers here.
        size_t h = 0;
        ggml_webgpu_hash_combine(h, k.src_type);
        ggml_webgpu_hash_combine(h, k.dst_type);
        return h;
    }
};
/** Glu **/
// Cache key identifying a compiled GLU pipeline variant.
struct ggml_webgpu_glu_pipeline_key {
    ggml_glu_op glu_op;  // which gated-linear-unit op (reglu/geglu/swiglu/...)
    ggml_type   type;    // element type of the tensors
    bool        split;   // true when gate input comes from a second tensor (src1)
    bool operator==(const ggml_webgpu_glu_pipeline_key & rhs) const {
        return split == rhs.split && type == rhs.type && glu_op == rhs.glu_op;
    }
};
// Hash functor so ggml_webgpu_glu_pipeline_key can be an unordered_map key.
struct ggml_webgpu_glu_pipeline_key_hash {
    size_t operator()(const ggml_webgpu_glu_pipeline_key & k) const {
        size_t h = 0;
        ggml_webgpu_hash_combine(h, k.glu_op);
        ggml_webgpu_hash_combine(h, k.type);
        ggml_webgpu_hash_combine(h, k.split);
        return h;
    }
};
/** Rope **/
// Cache key identifying a compiled RoPE pipeline variant.
struct ggml_webgpu_rope_pipeline_key {
    ggml_type type;     // element type of the tensors
    bool      inplace;  // true when dst aliases the input tensor
    bool      has_ff;   // true when frequency factors (src2) are supplied
    bool operator==(const ggml_webgpu_rope_pipeline_key & rhs) const {
        return has_ff == rhs.has_ff && inplace == rhs.inplace && type == rhs.type;
    }
};
// Hash functor so ggml_webgpu_rope_pipeline_key can be an unordered_map key.
struct ggml_webgpu_rope_pipeline_key_hash {
    size_t operator()(const ggml_webgpu_rope_pipeline_key & k) const {
        size_t h = 0;
        ggml_webgpu_hash_combine(h, k.type);
        ggml_webgpu_hash_combine(h, k.inplace);
        ggml_webgpu_hash_combine(h, k.has_ff);
        return h;
    }
};
/** SoftMax **/
// Cache key identifying a compiled soft_max pipeline variant.
struct ggml_webgpu_soft_max_pipeline_key {
    ggml_type mask_type;  // mask element type (defaults to F32 when has_mask is false)
    bool      has_mask;   // true when a mask tensor (src1) is supplied
    bool      has_sink;   // true when a sink tensor (src2) is supplied
    bool      inplace;    // true when dst aliases the input tensor
    bool operator==(const ggml_webgpu_soft_max_pipeline_key & rhs) const {
        return inplace == rhs.inplace && has_sink == rhs.has_sink &&
               has_mask == rhs.has_mask && mask_type == rhs.mask_type;
    }
};
// Hash functor so ggml_webgpu_soft_max_pipeline_key can be an unordered_map key.
struct ggml_webgpu_soft_max_pipeline_key_hash {
    size_t operator()(const ggml_webgpu_soft_max_pipeline_key & k) const {
        size_t h = 0;
        ggml_webgpu_hash_combine(h, k.mask_type);
        ggml_webgpu_hash_combine(h, k.has_mask);
        ggml_webgpu_hash_combine(h, k.has_sink);
        ggml_webgpu_hash_combine(h, k.inplace);
        return h;
    }
};
class ggml_webgpu_shader_lib {
wgpu::Device device;
pre_wgsl::Preprocessor preprocessor;
@ -582,6 +671,12 @@ class ggml_webgpu_shader_lib {
std::unordered_map<ggml_webgpu_set_rows_pipeline_key, webgpu_pipeline, ggml_webgpu_set_rows_pipeline_key_hash>
set_rows_pipelines;
std::unordered_map<ggml_webgpu_set_pipeline_key, webgpu_pipeline, ggml_webgpu_set_pipeline_key_hash> set_pipelines;
std::unordered_map<ggml_webgpu_cpy_pipeline_key, webgpu_pipeline, ggml_webgpu_cpy_pipeline_key_hash> cpy_pipelines;
std::unordered_map<ggml_webgpu_glu_pipeline_key, webgpu_pipeline, ggml_webgpu_glu_pipeline_key_hash> glu_pipelines;
std::unordered_map<ggml_webgpu_rope_pipeline_key, webgpu_pipeline, ggml_webgpu_rope_pipeline_key_hash>
rope_pipelines;
std::unordered_map<ggml_webgpu_soft_max_pipeline_key, webgpu_pipeline, ggml_webgpu_soft_max_pipeline_key_hash>
soft_max_pipelines;
public:
ggml_webgpu_shader_lib(wgpu::Device device) { this->device = device; }
@ -1679,6 +1774,236 @@ class ggml_webgpu_shader_lib {
return flash_attn_pipelines[key];
}
// Return the cpy (copy/cast) pipeline for the src0->dst type pair described
// by `context`, compiling and caching it on first use.
// Aborts on unsupported type combinations.
webgpu_pipeline get_cpy_pipeline(const ggml_webgpu_shader_lib_context & context) {
    const ggml_webgpu_cpy_pipeline_key key = {
        .src_type = context.src0->type,
        .dst_type = context.dst->type,
    };
    auto it = cpy_pipelines.find(key);
    if (it != cpy_pipelines.end()) {
        return it->second;
    }
    // Select the WGSL source/destination element types via preprocessor
    // defines; `variant` encodes the same choice in the pipeline label.
    std::vector<std::string> defines;
    std::string              variant = "cpy";
    switch (key.src_type) {
        case GGML_TYPE_F32:
            defines.push_back("SRC_F32");
            variant += "_f32";
            break;
        case GGML_TYPE_F16:
            defines.push_back("SRC_F16");
            variant += "_f16";
            break;
        default:
            GGML_ABORT("Unsupported src type for cpy shader");
    }
    switch (key.dst_type) {
        case GGML_TYPE_F32:
            defines.push_back("DST_F32");
            variant += "_f32";
            break;
        case GGML_TYPE_F16:
            defines.push_back("DST_F16");
            variant += "_f16";
            break;
        case GGML_TYPE_I32:
            defines.push_back("DST_I32");
            variant += "_i32";
            break;
        default:
            GGML_ABORT("Unsupported dst type for cpy shader");
    }
    defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
    auto processed = preprocessor.preprocess(wgsl_cpy, defines);
    auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
    decisions->wg_size = context.max_wg_size;
    webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
    pipeline.context = decisions;
    // Single insertion; avoids the double lookup of
    // `map[key] = pipeline; return map[key];`.
    return cpy_pipelines.emplace(key, std::move(pipeline)).first->second;
}
// Return the GLU pipeline for the op/type/split combination described by
// `context`, compiling and caching it on first use.
// `split` means the gate comes from a separate tensor (src1 != nullptr).
// Aborts on unsupported ops or types.
webgpu_pipeline get_glu_pipeline(const ggml_webgpu_shader_lib_context & context) {
    const ggml_webgpu_glu_pipeline_key key = {
        .glu_op = ggml_get_glu_op(context.dst),
        .type   = context.dst->type,
        .split  = (context.src1 != nullptr),
    };
    auto it = glu_pipelines.find(key);
    if (it != glu_pipelines.end()) {
        return it->second;
    }
    std::vector<std::string> defines;
    std::string              variant = "glu";
    switch (key.glu_op) {
        case GGML_GLU_OP_REGLU:
            defines.push_back("OP_REGLU");
            variant += "_reglu";
            break;
        case GGML_GLU_OP_GEGLU:
            defines.push_back("OP_GEGLU");
            variant += "_geglu";
            break;
        case GGML_GLU_OP_SWIGLU:
            defines.push_back("OP_SWIGLU");
            variant += "_swiglu";
            break;
        case GGML_GLU_OP_SWIGLU_OAI:
            defines.push_back("OP_SWIGLU_OAI");
            variant += "_swiglu_oai";
            break;
        case GGML_GLU_OP_GEGLU_ERF:
            defines.push_back("OP_GEGLU_ERF");
            variant += "_geglu_erf";
            break;
        case GGML_GLU_OP_GEGLU_QUICK:
            defines.push_back("OP_GEGLU_QUICK");
            variant += "_geglu_quick";
            break;
        default:
            GGML_ABORT("Unsupported GLU op");
    }
    switch (key.type) {
        case GGML_TYPE_F32:
            defines.push_back("TYPE_F32");
            variant += "_f32";
            break;
        case GGML_TYPE_F16:
            defines.push_back("TYPE_F16");
            variant += "_f16";
            break;
        default:
            GGML_ABORT("Unsupported type for GLU shader");
    }
    // The shader defaults to the split layout; NO_SPLIT selects the fused one.
    if (key.split) {
        variant += "_split";
    } else {
        defines.push_back("NO_SPLIT");
    }
    defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
    auto processed = preprocessor.preprocess(wgsl_glu, defines);
    auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
    decisions->wg_size = context.max_wg_size;
    webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
    pipeline.context = decisions;
    // Single insertion; avoids the double lookup of
    // `map[key] = pipeline; return map[key];`.
    return glu_pipelines.emplace(key, std::move(pipeline)).first->second;
}
// Return the RoPE pipeline for the type/inplace/frequency-factor combination
// described by `context`, compiling and caching it on first use.
// Aborts on unsupported types.
webgpu_pipeline get_rope_pipeline(const ggml_webgpu_shader_lib_context & context) {
    const ggml_webgpu_rope_pipeline_key key = {
        .type    = context.dst->type,
        .inplace = context.inplace,
        .has_ff  = (context.src2 != nullptr),
    };
    auto it = rope_pipelines.find(key);
    if (it != rope_pipelines.end()) {
        return it->second;
    }
    std::vector<std::string> defines;
    std::string              variant = "rope";
    switch (key.type) {
        case GGML_TYPE_F32:
            defines.push_back("TYPE_F32");
            variant += "_f32";
            break;
        case GGML_TYPE_F16:
            defines.push_back("TYPE_F16");
            variant += "_f16";
            break;
        default:
            GGML_ABORT("Unsupported type for ROPE shader");
    }
    if (key.inplace) {
        defines.push_back("INPLACE");
        variant += "_inplace";
    }
    if (key.has_ff) {
        defines.push_back("FF_FUNC");
        variant += "_ff";
    }
    defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
    auto processed = preprocessor.preprocess(wgsl_rope, defines);
    auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
    decisions->wg_size = context.max_wg_size;
    webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
    pipeline.context = decisions;
    // Single insertion; avoids the double lookup of
    // `map[key] = pipeline; return map[key];`.
    return rope_pipelines.emplace(key, std::move(pipeline)).first->second;
}
// Return the soft_max pipeline for the mask/sink/inplace combination described
// by `context`, compiling and caching it on first use.
// mask_type falls back to F32 when no mask is present so the key stays fully
// initialized; it is only consulted when has_mask is true.
// Aborts on unsupported mask types.
webgpu_pipeline get_soft_max_pipeline(const ggml_webgpu_shader_lib_context & context) {
    const ggml_webgpu_soft_max_pipeline_key key = {
        .mask_type = context.src1 ? context.src1->type : GGML_TYPE_F32,
        .has_mask  = (context.src1 != nullptr),
        .has_sink  = (context.src2 != nullptr),
        .inplace   = context.inplace,
    };
    auto it = soft_max_pipelines.find(key);
    if (it != soft_max_pipelines.end()) {
        return it->second;
    }
    std::vector<std::string> defines;
    std::string              variant = "soft_max";
    if (key.has_mask) {
        defines.push_back("HAS_MASK");
        switch (key.mask_type) {
            case GGML_TYPE_F32:
                defines.push_back("MASK_F32");
                variant += "_mask_f32";
                break;
            case GGML_TYPE_F16:
                defines.push_back("MASK_F16");
                variant += "_mask_f16";
                break;
            default:
                GGML_ABORT("Unsupported type for SOFT_MAX shader");
        }
    }
    if (key.has_sink) {
        defines.push_back("HAS_SINK");
        variant += "_sink";
    }
    if (key.inplace) {
        defines.push_back("INPLACE");
        variant += "_inplace";
    }
    defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
    auto processed = preprocessor.preprocess(wgsl_soft_max, defines);
    auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
    decisions->wg_size = context.max_wg_size;
    webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
    pipeline.context = decisions;
    // Single insertion; avoids the double lookup of
    // `map[key] = pipeline; return map[key];`.
    return soft_max_pipelines.emplace(key, std::move(pipeline)).first->second;
}
private:
static webgpu_pipeline ggml_webgpu_create_pipeline(wgpu::Device & device,
std::string shader_code,

View File

@ -364,13 +364,6 @@ struct webgpu_context_struct {
wgpu::Buffer set_rows_dev_error_buf;
wgpu::Buffer set_rows_host_error_buf;
std::map<int, std::map<int, webgpu_pipeline>> cpy_pipelines; // src_type, dst_type
std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> rope_pipelines; // type, ff, inplace
std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> glu_pipelines; // glu_op, type, split
std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> soft_max_pipelines; // mask_type, has_sink, inplace
size_t memset_bytes_per_thread;
};
@ -849,6 +842,16 @@ static binary_overlap_flags ggml_webgpu_detect_binary_overlap(ggml_tensor * src0
}
static webgpu_command ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
ggml_webgpu_shader_lib_context shader_lib_ctx = {
.src0 = src,
.dst = dst,
.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
};
webgpu_pipeline pipeline = ctx->shader_lib->get_cpy_pipeline(shader_lib_ctx);
auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
uint32_t ne = (uint32_t) ggml_nelements(dst);
std::vector<uint32_t> params = {
@ -875,9 +878,8 @@ static webgpu_command ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, g
.size = ggml_webgpu_tensor_binding_size(ctx, dst) }
};
uint32_t wg_x = CEIL_DIV(ne, WEBGPU_MAX_WG_SIZE);
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, ctx->cpy_pipelines[src->type][dst->type],
params, entries, wg_x);
uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size);
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
}
static webgpu_command ggml_webgpu_set(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) {
@ -1914,6 +1916,19 @@ static webgpu_command ggml_webgpu_rope(webgpu_context & ctx,
ggml_tensor * src1,
ggml_tensor * src2,
ggml_tensor * dst) {
ggml_webgpu_shader_lib_context shader_lib_ctx = {
.src0 = src0,
.src1 = src1,
.src2 = src2,
.dst = dst,
.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
.inplace = ggml_webgpu_tensor_equal(src0, dst),
};
webgpu_pipeline pipeline = ctx->shader_lib->get_rope_pipeline(shader_lib_ctx);
auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
const int inplace = ggml_webgpu_tensor_equal(src0, dst);
const int has_freq_factor = (src2 != nullptr);
@ -1996,12 +2011,22 @@ static webgpu_command ggml_webgpu_rope(webgpu_context & ctx,
.size = ggml_webgpu_tensor_binding_size(ctx, dst) });
}
webgpu_pipeline pipeline = ctx->rope_pipelines[dst->type][has_freq_factor][inplace];
uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), WEBGPU_MAX_WG_SIZE);
uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), decisions->wg_size);
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
}
static webgpu_command ggml_webgpu_glu(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) {
ggml_webgpu_shader_lib_context shader_lib_ctx = {
.src0 = src0,
.src1 = src1,
.dst = dst,
.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
};
webgpu_pipeline pipeline = ctx->shader_lib->get_glu_pipeline(shader_lib_ctx);
auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
const int split = (src1 != nullptr);
std::vector<uint32_t> params = {
@ -2048,8 +2073,7 @@ static webgpu_command ggml_webgpu_glu(webgpu_context & ctx, ggml_tensor * src0,
.offset = ggml_webgpu_tensor_align_offset(ctx, dst),
.size = ggml_webgpu_tensor_binding_size(ctx, dst) });
webgpu_pipeline pipeline = ctx->glu_pipelines[ggml_get_glu_op(dst)][dst->type][split];
uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), WEBGPU_MAX_WG_SIZE);
uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), decisions->wg_size);
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
}
@ -2109,9 +2133,20 @@ static webgpu_command ggml_webgpu_soft_max(webgpu_context & ctx,
ggml_tensor * src1,
ggml_tensor * src2,
ggml_tensor * dst) {
const int inplace = ggml_webgpu_tensor_equal(src0, dst);
const int mask_type = (src1 != nullptr) ? src1->type : 2; // use 2 for no mask here
const int has_sink = (src2 != nullptr);
ggml_webgpu_shader_lib_context shader_lib_ctx = {
.src0 = src0,
.src1 = src1,
.src2 = src2,
.dst = dst,
.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
.inplace = ggml_webgpu_tensor_equal(src0, dst),
};
webgpu_pipeline pipeline = ctx->shader_lib->get_soft_max_pipeline(shader_lib_ctx);
const int inplace = ggml_webgpu_tensor_equal(src0, dst);
const int has_mask = (src1 != nullptr);
const int has_sink = (src2 != nullptr);
float max_bias;
memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
float n_head_log2 = float(1u << (uint32_t) floor(log2(src0->ne[2])));
@ -2120,15 +2155,15 @@ static webgpu_command ggml_webgpu_soft_max(webgpu_context & ctx,
std::vector<uint32_t> params = {
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
mask_type < 2 ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)) : 0,
has_mask ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)) : 0,
has_sink ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src2) / ggml_type_size(src2->type)) : 0,
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
(uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
(uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
(uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),
mask_type < 2 ? (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)) : 0,
mask_type < 2 ? (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)) : 0,
mask_type < 2 ? (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)) : 0,
has_mask ? (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)) : 0,
has_mask ? (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)) : 0,
has_mask ? (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)) : 0,
(uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
(uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
(uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
@ -2136,8 +2171,8 @@ static webgpu_command ggml_webgpu_soft_max(webgpu_context & ctx,
(uint32_t) src0->ne[0],
(uint32_t) src0->ne[1],
(uint32_t) src0->ne[2],
mask_type < 2 ? (uint32_t) src1->ne[2] : 0,
mask_type < 2 ? (uint32_t) src1->ne[3] : 0,
has_mask ? (uint32_t) src1->ne[2] : 0,
has_mask ? (uint32_t) src1->ne[3] : 0,
*(uint32_t *) dst->op_params, // scale
*(uint32_t *) &max_bias,
*(uint32_t *) &n_head_log2,
@ -2152,7 +2187,7 @@ static webgpu_command ggml_webgpu_soft_max(webgpu_context & ctx,
.size = ggml_webgpu_tensor_binding_size(ctx, src0) }
};
uint32_t binding_num = 1;
if (mask_type < 2) {
if (has_mask) {
entries.push_back({ .binding = binding_num,
.buffer = ggml_webgpu_tensor_buf(src1),
.offset = ggml_webgpu_tensor_align_offset(ctx, src1),
@ -2173,9 +2208,7 @@ static webgpu_command ggml_webgpu_soft_max(webgpu_context & ctx,
.size = ggml_webgpu_tensor_binding_size(ctx, dst) });
}
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool,
ctx->soft_max_pipelines[mask_type][has_sink][inplace], params, entries,
ggml_nrows(dst));
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, ggml_nrows(dst));
}
static webgpu_command ggml_webgpu_argmax(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
@ -2885,139 +2918,6 @@ static void ggml_webgpu_init_memset_pipeline(webgpu_global_context & ctx) {
ctx->memset_pipelines[0] = ggml_webgpu_create_pipeline(ctx->device, wgsl_memset, "memset", constants);
}
// Eagerly compiles the supported cpy (copy/cast) shader variants at context
// init time and registers them in cpy_pipelines keyed [src_type][dst_type].
// NOTE(review): f32->i32 is registered but no i32 source variant exists.
static void ggml_webgpu_init_cpy_pipeline(webgpu_context & webgpu_ctx) {
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
webgpu_ctx->cpy_pipelines[GGML_TYPE_F32][GGML_TYPE_F32] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_cpy_f32_f32, "cpy_f32_f32", constants);
webgpu_ctx->cpy_pipelines[GGML_TYPE_F32][GGML_TYPE_I32] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_cpy_f32_i32, "cpy_f32_i32", constants);
webgpu_ctx->cpy_pipelines[GGML_TYPE_F32][GGML_TYPE_F16] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_cpy_f32_f16, "cpy_f32_f16", constants);
webgpu_ctx->cpy_pipelines[GGML_TYPE_F16][GGML_TYPE_F32] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_cpy_f16_f32, "cpy_f16_f32", constants);
webgpu_ctx->cpy_pipelines[GGML_TYPE_F16][GGML_TYPE_F16] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_cpy_f16_f16, "cpy_f16_f16", constants);
}
// Eagerly compiles all RoPE shader variants and registers them in
// rope_pipelines keyed [type][has_freq_factors][inplace].
static void ggml_webgpu_init_rope_pipeline(webgpu_context & webgpu_ctx) {
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
webgpu_ctx->rope_pipelines[GGML_TYPE_F32][0][0] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_rope_f32, "rope_f32", constants);
webgpu_ctx->rope_pipelines[GGML_TYPE_F32][0][1] = ggml_webgpu_create_pipeline(
webgpu_ctx->global_ctx->device, wgsl_rope_f32_inplace, "rope_f32_inplace", constants);
webgpu_ctx->rope_pipelines[GGML_TYPE_F32][1][0] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_rope_f32_ff, "rope_f32_ff", constants);
webgpu_ctx->rope_pipelines[GGML_TYPE_F32][1][1] = ggml_webgpu_create_pipeline(
webgpu_ctx->global_ctx->device, wgsl_rope_f32_ff_inplace, "rope_f32_ff_inplace", constants);
webgpu_ctx->rope_pipelines[GGML_TYPE_F16][0][0] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_rope_f16, "rope_f16", constants);
webgpu_ctx->rope_pipelines[GGML_TYPE_F16][0][1] = ggml_webgpu_create_pipeline(
webgpu_ctx->global_ctx->device, wgsl_rope_f16_inplace, "rope_f16_inplace", constants);
webgpu_ctx->rope_pipelines[GGML_TYPE_F16][1][0] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_rope_f16_ff, "rope_f16_ff", constants);
webgpu_ctx->rope_pipelines[GGML_TYPE_F16][1][1] = ggml_webgpu_create_pipeline(
webgpu_ctx->global_ctx->device, wgsl_rope_f16_ff_inplace, "rope_f16_ff_inplace", constants);
}
// Eagerly compiles all GLU shader variants and registers them in
// glu_pipelines keyed [glu_op][type][split] (split: 1 when the gate comes
// from a separate tensor).
static void ggml_webgpu_init_glu_pipeline(webgpu_context & webgpu_ctx) {
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
// REGLU
webgpu_ctx->glu_pipelines[GGML_GLU_OP_REGLU][GGML_TYPE_F32][0] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_reglu_f32, "reglu_f32", constants);
webgpu_ctx->glu_pipelines[GGML_GLU_OP_REGLU][GGML_TYPE_F16][0] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_reglu_f16, "reglu_f16", constants);
webgpu_ctx->glu_pipelines[GGML_GLU_OP_REGLU][GGML_TYPE_F32][1] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_reglu_f32_split, "reglu_f32_split", constants);
webgpu_ctx->glu_pipelines[GGML_GLU_OP_REGLU][GGML_TYPE_F16][1] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_reglu_f16_split, "reglu_f16_split", constants);
// GEGLU
webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU][GGML_TYPE_F32][0] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_f32, "geglu_f32", constants);
webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU][GGML_TYPE_F16][0] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_f16, "geglu_f16", constants);
webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU][GGML_TYPE_F32][1] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_f32_split, "geglu_f32_split", constants);
webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU][GGML_TYPE_F16][1] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_f16_split, "geglu_f16_split", constants);
// SWIGLU
webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU][GGML_TYPE_F32][0] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_swiglu_f32, "swiglu_f32", constants);
webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU][GGML_TYPE_F16][0] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_swiglu_f16, "swiglu_f16", constants);
webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
webgpu_ctx->global_ctx->device, wgsl_swiglu_f32_split, "swiglu_f32_split", constants);
webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline(
webgpu_ctx->global_ctx->device, wgsl_swiglu_f16_split, "swiglu_f16_split", constants);
// SWIGLU_OAI (f32 only; no f16 variants are registered)
webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU_OAI][GGML_TYPE_F32][0] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_swiglu_oai_f32, "swiglu_oai_f32", constants);
webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU_OAI][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
webgpu_ctx->global_ctx->device, wgsl_swiglu_oai_f32_split, "swiglu_oai_f32_split", constants);
// GEGLU_ERF
webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F32][0] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_erf_f32, "geglu_erf_f32", constants);
webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F16][0] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_erf_f16, "geglu_erf_f16", constants);
webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
webgpu_ctx->global_ctx->device, wgsl_geglu_erf_f32_split, "geglu_erf_f32_split", constants);
webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline(
webgpu_ctx->global_ctx->device, wgsl_geglu_erf_f16_split, "geglu_erf_f16_split", constants);
// GEGLU_QUICK
webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F32][0] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_quick_f32, "geglu_quick_f32", constants);
webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F16][0] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_quick_f16, "geglu_quick_f16", constants);
webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
webgpu_ctx->global_ctx->device, wgsl_geglu_quick_f32_split, "geglu_quick_f32_split", constants);
webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline(
webgpu_ctx->global_ctx->device, wgsl_geglu_quick_f16_split, "geglu_quick_f16_split", constants);
}
// Eagerly compiles all soft_max shader variants and registers them in
// soft_max_pipelines keyed [mask_type][has_sink][inplace], where mask_type is
// GGML_TYPE_F32 (0) or GGML_TYPE_F16 (1) for a masked variant and the
// sentinel 2 means "no mask".
static void ggml_webgpu_init_soft_max_pipeline(webgpu_context & webgpu_ctx) {
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_ROW_SPLIT_WG_SIZE);
// f32 (no mask)
webgpu_ctx->soft_max_pipelines[2][0][0] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_soft_max_f32, "soft_max_f32", constants);
webgpu_ctx->soft_max_pipelines[2][0][1] = ggml_webgpu_create_pipeline(
webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_inplace, "soft_max_f32_inplace", constants);
webgpu_ctx->soft_max_pipelines[2][1][0] = ggml_webgpu_create_pipeline(
webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_sink, "soft_max_f32_sink", constants);
webgpu_ctx->soft_max_pipelines[2][1][1] = ggml_webgpu_create_pipeline(
webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_sink_inplace, "soft_max_f32_sink_inplace", constants);
// f32 mask (mask_type = 0)
webgpu_ctx->soft_max_pipelines[0][0][0] = ggml_webgpu_create_pipeline(
webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f32, "soft_max_f32_mask_f32", constants);
webgpu_ctx->soft_max_pipelines[0][0][1] = ggml_webgpu_create_pipeline(
webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f32_inplace, "soft_max_f32_mask_f32_inplace", constants);
webgpu_ctx->soft_max_pipelines[0][1][0] = ggml_webgpu_create_pipeline(
webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f32_sink, "soft_max_f32_mask_f32_sink", constants);
webgpu_ctx->soft_max_pipelines[0][1][1] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f32_sink_inplace,
"soft_max_f32_mask_f32_sink_inplace", constants);
// f16 mask (mask_type = 1)
webgpu_ctx->soft_max_pipelines[1][0][0] = ggml_webgpu_create_pipeline(
webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f16, "soft_max_f32_mask_f16", constants);
webgpu_ctx->soft_max_pipelines[1][0][1] = ggml_webgpu_create_pipeline(
webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f16_inplace, "soft_max_f32_mask_f16_inplace", constants);
webgpu_ctx->soft_max_pipelines[1][1][0] = ggml_webgpu_create_pipeline(
webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f16_sink, "soft_max_f32_mask_f16_sink", constants);
webgpu_ctx->soft_max_pipelines[1][1][1] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f16_sink_inplace,
"soft_max_f32_mask_f16_sink_inplace", constants);
}
static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) {
wgpu::RequestAdapterOptions options = {};
@ -3183,10 +3083,6 @@ static webgpu_context initialize_webgpu_context(ggml_backend_dev_t dev) {
WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES,
wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead, "set_rows_host_error_buf");
ggml_webgpu_init_cpy_pipeline(webgpu_ctx);
ggml_webgpu_init_rope_pipeline(webgpu_ctx);
ggml_webgpu_init_glu_pipeline(webgpu_ctx);
ggml_webgpu_init_soft_max_pipeline(webgpu_ctx);
#ifdef GGML_WEBGPU_DEBUG
// Initialize debug buffers
ggml_webgpu_create_buffer(webgpu_ctx->global_ctx->device, webgpu_ctx->global_ctx->debug_host_buf,

View File

@ -1,66 +1,41 @@
#define(VARIANTS)
[
{
"REPLS": {
"SRC_TYPE": "f32",
"DST_TYPE": "f32"
}
},
{
"REPLS": {
"SRC_TYPE": "f32",
"DST_TYPE": "i32"
}
},
{
"REPLS": {
"SRC_TYPE": "f32",
"DST_TYPE": "f16"
}
},
{
"REPLS": {
"SRC_TYPE": "f16",
"DST_TYPE": "f16"
}
},
{
"REPLS": {
"SRC_TYPE": "f16",
"DST_TYPE": "f32"
}
}
]
#end(VARIANTS)
#define(SHADER)
enable f16;
#ifdef SRC_F32
#define SRC_TYPE f32
#elif defined(SRC_F16)
#define SRC_TYPE f16
#endif
#ifdef DST_F32
#define DST_TYPE f32
#elif defined(DST_F16)
#define DST_TYPE f16
#elif defined(DST_I32)
#define DST_TYPE i32
#endif
@group(0) @binding(0)
var<storage, read_write> src: array<{{SRC_TYPE}}>;
var<storage, read_write> src: array<SRC_TYPE>;
@group(0) @binding(1)
var<storage, read_write> dst: array<{{DST_TYPE}}>;
var<storage, read_write> dst: array<DST_TYPE>;
struct Params {
ne: u32, // total number of elements
offset_src: u32, // in elements
offset_dst: u32, // in elements
struct Params{
ne: u32,
offset_src: u32,
offset_dst: u32,
// Strides (in elements) may be permuted
stride_src0: u32,
stride_src1: u32,
stride_src2: u32,
stride_src3: u32,
stride_dst0: u32,
stride_dst1: u32,
stride_dst2: u32,
stride_dst3: u32,
// Logical shapes
src_ne0: u32,
src_ne1: u32,
src_ne2: u32,
@ -73,8 +48,7 @@ struct Params {
@group(0) @binding(2)
var<uniform> params: Params;
override wg_size: u32;
@compute @workgroup_size(wg_size)
@compute @workgroup_size(WG_SIZE)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
if (gid.x >= params.ne) {
return;
@ -102,6 +76,6 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let dst_idx = j0 * params.stride_dst0 + j1 * params.stride_dst1 +
j2 * params.stride_dst2 + j3 * params.stride_dst3;
dst[params.offset_dst + dst_idx] = {{DST_TYPE}}((src[params.offset_src + src_idx]));
dst[params.offset_dst + dst_idx] = DST_TYPE((src[params.offset_src + src_idx]));
}
#end(SHADER)

View File

@ -1,41 +1,8 @@
import os
import re
import ast
import argparse
def extract_block(text, name):
    """Return the stripped body between '#define(name)' and '#end(name)'.

    Raises ValueError when the named block is absent from ``text``.
    """
    m = re.search(rf'#define\({name}\)\s*(.*?)#end\({name}\)', text, re.DOTALL)
    if m is None:
        raise ValueError(f"Missing block: {name}")
    return m.group(1).strip()
def parse_decls(decls_text):
    """Parse '#decl(NAME) ... #enddecl(NAME)' sections into a {name: code} dict."""
    sections = re.findall(r'#decl\((.*?)\)\s*(.*?)#enddecl\(\1\)', decls_text, re.DOTALL)
    return {name.strip(): body.strip() for name, body in sections}
def replace_repl_placeholders(variant, template_map):
    """Expand template_map names inside each REPLS value of ``variant``.

    Substitution is whole-word only (\\b anchors), so a key never matches a
    substring of a longer identifier. Mutates variant["REPLS"] in place and
    returns the same variant dict.
    """
    repls = variant["REPLS"]
    for repl_name in list(repls):
        expanded = repls[repl_name]
        for tmpl_key, tmpl_val in template_map.items():
            expanded = re.sub(rf'\b{re.escape(str(tmpl_key))}\b', str(tmpl_val), expanded)
        repls[repl_name] = expanded
    return variant
def replace_placeholders(shader_text, replacements):
    """Substitute every '{{ KEY }}' occurrence in shader_text with its value."""
    result = shader_text
    for name, value in replacements.items():
        # Optional whitespace is allowed inside the braces: {{KEY}} or {{ KEY }}.
        result = re.sub(r'{{\s*' + re.escape(name) + r'\s*}}', str(value), result)
    return result
def expand_includes(shader, input_dir):
"""
Replace #include "file" lines in the text with the contents of that file.
@ -98,84 +65,24 @@ def write_shader(shader_name, shader_code, output_dir, outfile, input_dir):
outfile.write(f'const char* wgsl_{shader_name} = wgsl_{shader_name}_str().c_str();\n\n')
def generate_variants(fname, input_dir, output_dir, outfile):
    # Expand one .wgsl file into one or more concrete shader variants.
    # A file without a #define(VARIANTS) block is emitted verbatim; otherwise
    # each variant entry drives DECLS splicing and placeholder replacement
    # before the result goes through write_shader().
    shader_path = os.path.join(input_dir, fname)
    shader_base_name = fname.split(".")[0]
    with open(shader_path, "r", encoding="utf-8") as f:
        text = f.read()
    try:
        variants = ast.literal_eval(extract_block(text, "VARIANTS"))
    except ValueError:
        # No VARIANTS block: the whole file is a single ready-made shader.
        write_shader(shader_base_name, text, output_dir, outfile, input_dir)
    else:
        # DECLS and REPL_TEMPLATES blocks are optional; absence means empty.
        try:
            decls_map = parse_decls(extract_block(text, "DECLS"))
        except ValueError:
            decls_map = {}
        try:
            templates_map = ast.literal_eval(extract_block(text, "REPL_TEMPLATES"))
        except ValueError:
            templates_map = {}
        # Merge shared declarations from every sibling .tmpl file.
        # NOTE(review): this loop rebinds the parameter `fname` — confirm
        # nothing after it relies on the original argument.
        for fname in sorted(os.listdir(input_dir)):
            if fname.endswith(".tmpl"):
                tmpl_path = os.path.join(input_dir, fname)
                with open(tmpl_path, "r", encoding="utf-8") as f_tmpl:
                    decls = f_tmpl.read()
                decls_map.update(parse_decls(decls))
        shader_template = extract_block(text, "SHADER")
        for variant in variants:
            # Splice the requested DECLS bodies in place of the literal
            # token DECLS in the shader template.
            if "DECLS" in variant:
                decls = variant["DECLS"]
            else:
                decls = []
            decls_code = ""
            for key in decls:
                if key not in decls_map:
                    raise ValueError(f"DECLS key '{key}' not found.")
                decls_code += decls_map[key] + "\n\n"
            final_shader = re.sub(r'\bDECLS\b', decls_code, shader_template)
            if "REPLS" in variant:
                variant = replace_repl_placeholders(variant, templates_map)
                final_shader = replace_placeholders(final_shader, variant["REPLS"])
                # second run to expand placeholders in repl_template
                final_shader = replace_placeholders(final_shader, variant["REPLS"])
            final_shader = expand_includes(final_shader, input_dir)
            # Output-name priority: SHADER_NAME > SHADER_SUFFIX >
            # type-derived names from REPLS > bare base name.
            if "SHADER_NAME" in variant:
                output_name = variant["SHADER_NAME"]
            elif "SHADER_SUFFIX" in variant:
                output_name = f"{shader_base_name}_" + variant["SHADER_SUFFIX"]
            elif "REPLS" in variant and "SRC0_TYPE" in variant["REPLS"] and "SRC1_TYPE" in variant["REPLS"]:
                output_name = f"{shader_base_name}_" + "_".join([variant["REPLS"]["SRC0_TYPE"], variant["REPLS"]["SRC1_TYPE"]])
            elif "REPLS" in variant and "SRC_TYPE" in variant["REPLS"] and "DST_TYPE" in variant["REPLS"]:
                output_name = f"{shader_base_name}_" + "_".join([variant["REPLS"]["SRC_TYPE"], variant["REPLS"]["DST_TYPE"]])
            elif "REPLS" in variant and "TYPE" in variant["REPLS"]:
                output_name = f"{shader_base_name}_" + variant["REPLS"]["TYPE"]
            else:
                output_name = shader_base_name
            write_shader(output_name, final_shader, output_dir, outfile, input_dir)
def main():
    """CLI entry point: embed every .wgsl shader under --input_dir.

    Emits a single auto-generated C++ source (--output_file) containing the
    embedded shader strings.  --output_dir is optional; when given it is
    created up front and forwarded to generate_variants so per-variant
    shader files can be written out as well.
    """
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("--input_dir", required=True)
    arg_parser.add_argument("--output_file", required=True)
    arg_parser.add_argument("--output_dir")
    opts = arg_parser.parse_args()

    # Create the optional per-shader output directory before any writes.
    if opts.output_dir:
        os.makedirs(opts.output_dir, exist_ok=True)

    with open(opts.output_file, "w", encoding="utf-8") as out:
        out.write("// Auto-generated shader embedding\n")
        out.write("#include <string>\n\n")
        # Deterministic iteration order keeps the generated file stable.
        for entry in sorted(os.listdir(opts.input_dir)):
            if not entry.endswith(".wgsl"):
                continue
            # NOTE(review): each shader is processed twice here — once via
            # generate_variants and once as raw text below; confirm both
            # embeddings are intended (this page renders a diff, so one of
            # the two may belong to the other side of the change).
            generate_variants(entry, opts.input_dir, opts.output_dir, out)
            raw_path = os.path.join(opts.input_dir, entry)
            with open(raw_path, "r", encoding="utf-8") as src:
                raw_code = src.read()
            base_name = entry.replace(".wgsl", "")
            write_shader(base_name, raw_code, None, out, opts.input_dir)
if __name__ == "__main__":

View File

@ -1,323 +0,0 @@
#define(VARIANTS)
[
{
"SHADER_NAME": "reglu_f32",
"REPLS": {
"TYPE" : "f32",
},
"DECLS": ["NO_SPLIT", "REGLU"]
},
{
"SHADER_NAME": "reglu_f32_split",
"REPLS": {
"TYPE" : "f32",
},
"DECLS": ["SPLIT", "REGLU"]
},
{
"SHADER_NAME": "reglu_f16",
"REPLS": {
"TYPE" : "f16",
},
"DECLS": ["NO_SPLIT", "REGLU"]
},
{
"SHADER_NAME": "reglu_f16_split",
"REPLS": {
"TYPE" : "f16",
},
"DECLS": ["SPLIT", "REGLU"]
},
{
"SHADER_NAME": "geglu_f32",
"REPLS": {
"TYPE" : "f32",
},
"DECLS": ["NO_SPLIT", "GEGLU"]
},
{
"SHADER_NAME": "geglu_f32_split",
"REPLS": {
"TYPE" : "f32",
},
"DECLS": ["SPLIT", "GEGLU"]
},
{
"SHADER_NAME": "geglu_f16",
"REPLS": {
"TYPE" : "f16",
},
"DECLS": ["NO_SPLIT", "GEGLU"]
},
{
"SHADER_NAME": "geglu_f16_split",
"REPLS": {
"TYPE" : "f16",
},
"DECLS": ["SPLIT", "GEGLU"]
},
{
"SHADER_NAME": "swiglu_f32",
"REPLS": {
"TYPE" : "f32",
},
"DECLS": ["NO_SPLIT", "SWIGLU"]
},
{
"SHADER_NAME": "swiglu_f32_split",
"REPLS": {
"TYPE" : "f32",
},
"DECLS": ["SPLIT", "SWIGLU"]
},
{
"SHADER_NAME": "swiglu_f16",
"REPLS": {
"TYPE" : "f16",
},
"DECLS": ["NO_SPLIT", "SWIGLU"]
},
{
"SHADER_NAME": "swiglu_f16_split",
"REPLS": {
"TYPE" : "f16",
},
"DECLS": ["SPLIT", "SWIGLU"]
},
{
"SHADER_NAME": "swiglu_oai_f32",
"REPLS": {
"TYPE" : "f32",
},
"DECLS": ["NO_SPLIT", "SWIGLU_OAI"]
},
{
"SHADER_NAME": "swiglu_oai_f32_split",
"REPLS": {
"TYPE" : "f32",
},
"DECLS": ["SPLIT", "SWIGLU_OAI"]
},
{
"SHADER_NAME": "geglu_erf_f32",
"REPLS": {
"TYPE" : "f32",
},
"DECLS": ["NO_SPLIT", "GEGLU_ERF"]
},
{
"SHADER_NAME": "geglu_erf_f32_split",
"REPLS": {
"TYPE" : "f32",
},
"DECLS": ["SPLIT", "GEGLU_ERF"]
},
{
"SHADER_NAME": "geglu_erf_f16",
"REPLS": {
"TYPE" : "f16",
},
"DECLS": ["NO_SPLIT", "GEGLU_ERF"]
},
{
"SHADER_NAME": "geglu_erf_f16_split",
"REPLS": {
"TYPE" : "f16",
},
"DECLS": ["SPLIT", "GEGLU_ERF"]
},
{
"SHADER_NAME": "geglu_quick_f32",
"REPLS": {
"TYPE" : "f32",
},
"DECLS": ["NO_SPLIT", "GEGLU_QUICK"]
},
{
"SHADER_NAME": "geglu_quick_f32_split",
"REPLS": {
"TYPE" : "f32",
},
"DECLS": ["SPLIT", "GEGLU_QUICK"]
},
{
"SHADER_NAME": "geglu_quick_f16",
"REPLS": {
"TYPE" : "f16",
},
"DECLS": ["NO_SPLIT", "GEGLU_QUICK"]
},
{
"SHADER_NAME": "geglu_quick_f16_split",
"REPLS": {
"TYPE" : "f16",
},
"DECLS": ["SPLIT", "GEGLU_QUICK"]
},
]
#end(VARIANTS)
#define(DECLS)
// Gated-linear-unit building blocks.  Each activation #decl supplies
// op(a, b) where `a` is the activation input and `b` is the gate; the
// NO_SPLIT/SPLIT #decls supply the buffer bindings plus the a_value/b_value
// accessors consumed by the SHADER body.
#decl(REGLU)
// ReLU-gated linear unit: relu(a) * b.
fn op(a: {{TYPE}}, b: {{TYPE}}) -> {{TYPE}} {
return max(a, 0) * b;
}
#enddecl(REGLU)
#decl(GEGLU)
// GELU (tanh approximation) gated linear unit.
const SQRT_2_OVER_PI: {{TYPE}} = 0.79788456080286535587989211986876;
const GELU_COEF_A: {{TYPE}} = 0.044715;
fn op(a: {{TYPE}}, b: {{TYPE}}) -> {{TYPE}} {
let val = SQRT_2_OVER_PI * a * (1.0 + GELU_COEF_A * a * a);
// 2 - 2/(exp(2v)+1) == 1 + tanh(v), so this is 0.5*a*(1+tanh(val))*b.
return 0.5 * a * (2.0 - 2.0 / (exp(2 * val) + 1)) * b;
}
#enddecl(GEGLU)
#decl(SWIGLU)
// SiLU/swish-gated linear unit: a * sigmoid(a) * b.
fn op(a: {{TYPE}}, b: {{TYPE}}) -> {{TYPE}} {
return a / (1.0 + exp(-a)) * b;
}
#enddecl(SWIGLU)
#decl(SWIGLU_OAI)
// Swiglu variant with inputs clamped to params.limit and a params.alpha
// scale inside the sigmoid; always computed in f32.
fn op(a: f32, b: f32) -> f32 {
let xi = min(a, params.limit);
let gi = max(min(b, params.limit), -params.limit);
var out_glu = xi / (1.0 + exp(-xi * params.alpha));
out_glu = out_glu * (1.0 + gi);
return out_glu;
}
#enddecl(SWIGLU_OAI)
#decl(GEGLU_ERF)
// GELU via a polynomial erf approximation (Abramowitz & Stegun 7.1.26 form).
const p_erf: {{TYPE}} = 0.3275911;
const a1_erf: {{TYPE}} = 0.254829592;
const a2_erf: {{TYPE}} = -0.284496736;
const a3_erf: {{TYPE}} = 1.421413741;
const a4_erf: {{TYPE}} = -1.453152027;
const a5_erf: {{TYPE}} = 1.061405429;
const SQRT_2_INV: {{TYPE}} = 0.7071067811865476;
fn op(a: {{TYPE}}, b: {{TYPE}}) -> {{TYPE}} {
let a_div_sqr2 = a * SQRT_2_INV;
// erf is odd, so evaluate the polynomial on |x| and reapply the sign.
let sign_x = sign(a_div_sqr2);
let x = abs(a_div_sqr2);
let t = 1.0 / (1.0 + p_erf * x);
let y = 1.0 - (((((a5_erf * t + a4_erf) * t + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x));
let erf_approx = sign_x * y;
return 0.5 * a * (1.0 + erf_approx) * b;
}
#enddecl(GEGLU_ERF)
#decl(GEGLU_QUICK)
// Fast GELU approximation: a * sigmoid(1.702 * a) * b.
const GELU_QUICK_COEF: {{TYPE}} = -1.702;
fn op(a: {{TYPE}}, b: {{TYPE}}) -> {{TYPE}} {
return a * (1.0 / (1.0 + exp(GELU_QUICK_COEF * a))) * b;
}
#enddecl(GEGLU_QUICK)
#decl(NO_SPLIT)
// Fused layout: activation and gate are both halves of src0; params.swapped
// selects which half (offset by ne0 elements) acts as the gate.
@group(0) @binding(1)
var<storage, read_write> dst: array<{{TYPE}}>;
@group(0) @binding(2)
var<uniform> params: Params;
fn a_value(base: u32) -> {{TYPE}} {
let offset: u32 = select(0, params.ne0, params.swapped != 0);
return src0[base + offset];
}
fn b_value(base: u32) -> {{TYPE}} {
let offset: u32 = select(params.ne0, 0, params.swapped != 0);
return src0[base + offset];
}
#enddecl(NO_SPLIT)
#decl(SPLIT)
// Split layout: activation comes from src0, gate from a separate src1 buffer.
@group(0) @binding(1)
var<storage, read_write> src1: array<{{TYPE}}>;
@group(0) @binding(2)
var<storage, read_write> dst: array<{{TYPE}}>;
@group(0) @binding(3)
var<uniform> params: Params;
fn a_value(base: u32) -> {{TYPE}} {
return src0[base];
}
fn b_value(base: u32) -> {{TYPE}} {
return src1[base];
}
#enddecl(SPLIT)
#end(DECLS)
#define(SHADER)
// GLU compute shader body: one destination element per invocation.  The
// DECLS placeholder is replaced with one activation decl and one layout
// decl (see the DECLS section of this template).
enable f16;
struct Params {
// Element offsets into each bound buffer.
offset_src0: u32,
offset_src1: u32,
offset_dst: u32,
// Strides (in elements)
stride_src01: u32,
stride_src02: u32,
stride_src03: u32,
stride_src11: u32,
stride_src12: u32,
stride_src13: u32,
stride_dst1: u32,
stride_dst2: u32,
stride_dst3: u32,
// shape of dst
ne: u32,
ne0: u32,
ne1: u32,
ne2: u32,
// Non-zero when the gate half precedes the activation half (fused layout).
swapped: u32,
// Extra parameters used by the SWIGLU_OAI activation.
alpha: f32,
limit: f32,
}
@group(0) @binding(0)
var<storage, read_write> src0: array<{{TYPE}}>;
DECLS
override wg_size: u32;
@compute @workgroup_size(wg_size)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
// Guard the tail workgroup: ne is the total element count of dst.
if (gid.x >= params.ne) {
return;
}
// Decompose the flat index into (i3, i2, i1, i0) tensor coordinates.
var i = gid.x;
let i3 = i / (params.ne2 * params.ne1 * params.ne0);
i = i % (params.ne2 * params.ne1 * params.ne0);
let i2 = i / (params.ne1 * params.ne0);
i = i % (params.ne1 * params.ne0);
let i1 = i / params.ne0;
let i0 = i % params.ne0;
// Per-buffer element indices built from the strided offsets.
let i_a = params.offset_src0 + i3 * params.stride_src03 + i2 * params.stride_src02 + i1 * params.stride_src01 + i0;
let i_b = params.offset_src1 + i3 * params.stride_src13 + i2 * params.stride_src12 + i1 * params.stride_src11 + i0;
let i_dst = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1 + i0;
// a_value/b_value and op come from the spliced DECLS.
dst[i_dst] = op(a_value(i_a), b_value(i_b));
}
#end(SHADER)

View File

@ -0,0 +1,155 @@
// GLU compute shader, specialized at build time via preprocessor defines:
//   TYPE_F32 / TYPE_F16          — element type of src0/src1/dst
//   OP_*                         — which gated activation op() to compile
//   NO_SPLIT (else split layout) — fused vs. separate gate buffer
//   INPLACE is not used here; WG_SIZE is the workgroup size constant.
enable f16;
#ifdef TYPE_F32
#define DataType f32
#endif
#ifdef TYPE_F16
#define DataType f16
#endif
#ifdef OP_REGLU
// ReLU-gated linear unit: relu(a) * b.
fn op(a: DataType, b: DataType) -> DataType {
return max(a, 0) * b;
}
#endif
#ifdef OP_GEGLU
// GELU (tanh approximation) gated linear unit.
const SQRT_2_OVER_PI: DataType = 0.79788456080286535587989211986876;
const GELU_COEF_A: DataType = 0.044715;
fn op(a: DataType, b: DataType) -> DataType {
let val = SQRT_2_OVER_PI * a * (1.0 + GELU_COEF_A * a * a);
// 2 - 2/(exp(2v)+1) == 1 + tanh(v), so this is 0.5*a*(1+tanh(val))*b.
return 0.5 * a * (2.0 - 2.0/ (exp(2* val) + 1)) * b;
}
#endif
#ifdef OP_SWIGLU
// SiLU/swish-gated linear unit: a * sigmoid(a) * b.
fn op(a: DataType, b: DataType) -> DataType {
return a / (1.0 + exp(-a)) * b;
}
#endif
#ifdef OP_SWIGLU_OAI
// Swiglu variant with inputs clamped to params.limit and a params.alpha
// scale inside the sigmoid; always computed in f32.
fn op(a: f32, b: f32) -> f32 {
let xi = min(a, params.limit);
let gi = max(min(b, params.limit), -params.limit);
var out_glu = xi / (1.0 + exp(-xi * params.alpha));
out_glu = out_glu * (1.0 + gi);
return out_glu;
}
#endif
#ifdef OP_GEGLU_ERF
// GELU via a polynomial erf approximation (Abramowitz & Stegun 7.1.26 form).
const p_erf: DataType = 0.3275911;
const a1_erf: DataType = 0.254829592;
const a2_erf: DataType = -0.284496736;
const a3_erf: DataType = 1.421413741;
const a4_erf: DataType = -1.453152027;
const a5_erf: DataType = 1.061405429;
const SQRT_2_INV: DataType = 0.7071067811865476;
fn op(a: DataType, b: DataType) -> DataType {
let a_div_sqr2 = a * SQRT_2_INV;
// erf is odd, so evaluate the polynomial on |x| and reapply the sign.
let sign_x = sign(a_div_sqr2);
let x = abs(a_div_sqr2);
let t = 1.0 / (1.0 + p_erf * x);
let y = 1.0 - (((((a5_erf * t + a4_erf) * t + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x));
let erf_approx = sign_x * y;
return 0.5 * a * (1.0 + erf_approx) * b;
}
#endif
#ifdef OP_GEGLU_QUICK
// Fast GELU approximation: a * sigmoid(1.702 * a) * b.
const GELU_QUICK_COEF: DataType = -1.702;
fn op(a: DataType, b: DataType) -> DataType {
return a * (1.0 / (1.0 + exp(GELU_QUICK_COEF * a))) * b;
}
#endif
struct Params {
// Element offsets into each bound buffer.
offset_src0: u32,
offset_src1: u32,
offset_dst: u32,
// Strides (in elements)
stride_src01: u32,
stride_src02: u32,
stride_src03: u32,
stride_src11: u32,
stride_src12: u32,
stride_src13: u32,
stride_dst1: u32,
stride_dst2: u32,
stride_dst3: u32,
// shape of dst
ne: u32,
ne0: u32,
ne1: u32,
ne2: u32,
// Non-zero when the gate half precedes the activation half (fused layout).
swapped: u32,
// Extra parameters used by the OP_SWIGLU_OAI activation.
alpha: f32,
limit: f32,
}
@group(0) @binding(0)
var<storage, read_write> src0: array<DataType>;
#ifdef NO_SPLIT
// Fused layout: activation and gate are both halves of src0; params.swapped
// selects which half (offset by ne0 elements) acts as the gate.
@group(0) @binding(1)
var<storage, read_write> dst: array<DataType>;
@group(0) @binding(2)
var<uniform> params: Params;
fn a_value(base: u32) -> DataType {
let offset: u32 = select(0, params.ne0, params.swapped != 0);
return src0[base + offset];
}
fn b_value(base: u32) -> DataType {
let offset: u32 = select(params.ne0, 0, params.swapped != 0);
return src0[base + offset];
}
#else
// Split layout: activation comes from src0, gate from a separate src1 buffer.
@group(0) @binding(1)
var<storage, read_write> src1: array<DataType>;
@group(0) @binding(2)
var<storage, read_write> dst: array<DataType>;
@group(0) @binding(3)
var<uniform> params: Params;
fn a_value(base: u32) -> DataType {
return src0[base];
}
fn b_value(base: u32) -> DataType {
return src1[base];
}
#endif
// One destination element per invocation.
@compute @workgroup_size(WG_SIZE)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
// Guard the tail workgroup: ne is the total element count of dst.
if (gid.x >= params.ne) {
return;
}
// Decompose the flat index into (i3, i2, i1, i0) tensor coordinates.
var i = gid.x;
let i3 = i / (params.ne2 * params.ne1 * params.ne0);
i = i % (params.ne2 * params.ne1 * params.ne0);
let i2 = i / (params.ne1 * params.ne0);
i = i % (params.ne1 * params.ne0);
let i1 = i / params.ne0;
let i0 = i % params.ne0;
// Per-buffer element indices built from the strided offsets.
let i_a = params.offset_src0 + i3 * params.stride_src03 + i2 * params.stride_src02 + i1 * params.stride_src01 + i0;
let i_b = params.offset_src1 + i3 * params.stride_src13 + i2 * params.stride_src12 + i1 * params.stride_src11 + i0;
let i_dst = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1 + i0;
dst[i_dst] = op(a_value(i_a), b_value(i_b));
}

View File

@ -1,138 +1,12 @@
#define(VARIANTS)
[
{
"REPLS": {
"TYPE" : "f32",
},
"DECLS": ["NO_FF_BINDINGS", "NO_FF_FUNC", "ROTATE"]
},
{
"SHADER_SUFFIX": "f32_inplace",
"REPLS": {
"TYPE" : "f32",
},
"DECLS": ["NO_FF_BINDINGS_INPLACE", "NO_FF_FUNC", "ROTATE_INPLACE"]
},
{
"REPLS": {
"TYPE" : "f16",
},
"DECLS": ["NO_FF_BINDINGS", "NO_FF_FUNC", "ROTATE"]
},
{
"SHADER_SUFFIX": "f16_inplace",
"REPLS": {
"TYPE" : "f16",
},
"DECLS": ["NO_FF_BINDINGS_INPLACE", "NO_FF_FUNC", "ROTATE_INPLACE"]
},
{
"SHADER_SUFFIX": "f32_ff",
"REPLS": {
"TYPE" : "f32",
},
"DECLS": ["FF_BINDINGS", "FF_FUNC", "ROTATE"]
},
{
"SHADER_SUFFIX": "f32_ff_inplace",
"REPLS": {
"TYPE" : "f32",
},
"DECLS": ["FF_BINDINGS_INPLACE", "FF_FUNC", "ROTATE_INPLACE"]
},
{
"SHADER_SUFFIX": "f16_ff",
"REPLS": {
"TYPE" : "f16",
},
"DECLS": ["FF_BINDINGS", "FF_FUNC", "ROTATE"]
},
{
"SHADER_SUFFIX": "f16_ff_inplace",
"REPLS": {
"TYPE" : "f16",
},
"DECLS": ["FF_BINDINGS_INPLACE", "FF_FUNC", "ROTATE_INPLACE"]
}
]
#end(VARIANTS)
#define(DECLS)
#decl(ROTATE)
fn rotate(i_dst0: u32, i_dst1: u32, out0: f32, out1: f32) {
dst[i_dst0] = {{TYPE}}(out0);
dst[i_dst1] = {{TYPE}}(out1);
}
#enddecl(ROTATE)
#decl(ROTATE_INPLACE)
fn rotate(i_dst0: u32, i_dst1: u32, out0: f32, out1: f32) {
src0[i_dst0] = {{TYPE}}(out0);
src0[i_dst1] = {{TYPE}}(out1);
}
#enddecl(ROTATE_INPLACE)
#decl(NO_FF_FUNC)
fn freq_factor(i: u32) -> f32 {
return 1.0f;
}
#enddecl(NO_FF_FUNC)
#decl(FF_FUNC)
fn freq_factor(i: u32) -> f32 {
return src2[params.offset_src2 + i/2];
}
#enddecl(FF_FUNC)
#decl(NO_FF_BINDINGS)
@group(0) @binding(2)
var<storage, read_write> dst: array<{{TYPE}}>;
@group(0) @binding(3)
var<uniform> params: Params;
#enddecl(NO_FF_BINDINGS)
#decl(NO_FF_BINDINGS_INPLACE)
@group(0) @binding(2)
var<uniform> params: Params;
#enddecl(NO_FF_BINDINGS_INPLACE)
#decl(FF_BINDINGS)
@group(0) @binding(2)
var<storage, read_write> src2: array<f32>;
@group(0) @binding(3)
var<storage, read_write> dst: array<{{TYPE}}>;
@group(0) @binding(4)
var<uniform> params: Params;
#enddecl(FF_BINDINGS)
#decl(FF_BINDINGS_INPLACE)
@group(0) @binding(2)
var<storage, read_write> src2: array<f32>;
@group(0) @binding(3)
var<uniform> params: Params;
#enddecl(FF_BINDINGS_INPLACE)
#end(DECLS)
#define(SHADER)
enable f16;
#ifdef TYPE_F32
#define DataType f32
#endif
#ifdef TYPE_F16
#define DataType f16
#endif
struct Params {
offset_src0: u32,
offset_src1: u32,
@ -168,12 +42,69 @@ struct Params {
};
@group(0) @binding(0)
var<storage, read_write> src0: array<{{TYPE}}>;
var<storage, read_write> src0: array<DataType>;
@group(0) @binding(1)
var<storage, read_write> src1: array<i32>;
DECLS
#ifdef INPLACE
#ifdef FF_FUNC
@group(0) @binding(2)
var<storage, read_write> src2: array<f32>;
@group(0) @binding(3)
var<uniform> params: Params;
#else
@group(0) @binding(2)
var<uniform> params: Params;
#endif
#else
#ifdef FF_FUNC
@group(0) @binding(2)
var<storage, read_write> src2: array<f32>;
@group(0) @binding(3)
var<storage, read_write> dst: array<DataType>;
@group(0) @binding(4)
var<uniform> params: Params;
#else
@group(0) @binding(2)
var<storage, read_write> dst: array<DataType>;
@group(0) @binding(3)
var<uniform> params: Params;
#endif
#endif
#ifdef FF_FUNC
fn freq_factor(i: u32) -> f32 {
return src2[params.offset_src2 + i/2];
}
#else
fn freq_factor(i: u32) -> f32 {
return 1.0f;
}
#endif
#ifdef INPLACE
fn rotate(i_dst0: u32, i_dst1: u32, out0: f32, out1: f32) {
src0[i_dst0] = DataType(out0);
src0[i_dst1] = DataType(out1);
}
#else
fn rotate(i_dst0: u32, i_dst1: u32, out0: f32, out1: f32) {
dst[i_dst0] = DataType(out0);
dst[i_dst1] = DataType(out1);
}
#endif
fn rope_yarn_ramp(low: f32, high: f32, i: u32) -> f32 {
let y = (f32(i / 2) - low) / max(0.001f, high - low);
@ -184,7 +115,7 @@ fn rope_yarn_ramp(low: f32, high: f32, i: u32) -> f32 {
// TODO: check performance of instantiating once on the CPU and passed as buffer, since it's repeated per-row
fn rope_yarn(theta_extrap: f32, i: u32) -> vec2<f32> {
var mscale = params.attn_factor;
var theta = params.freq_scale * theta_extrap;
var theta = params.freq_scale * theta_extrap;
if (params.ext_factor != 0.0f) {
let ramp_mix = rope_yarn_ramp(params.corr_dim0, params.corr_dim1, i) * params.ext_factor;
theta = theta * (1 - ramp_mix) + theta_extrap * ramp_mix;
@ -207,14 +138,13 @@ fn pair_offset(is_neox: bool, is_mrope: bool, is_vision: bool) -> u32 {
} else if (is_neox || is_mrope) {
return params.n_dims / 2;
} else {
return 1;
return 1;
}
}
override wg_size: u32;
@compute @workgroup_size(wg_size)
@compute @workgroup_size(WG_SIZE)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
// two elements per thread
// two elements per n_threads
if (gid.x >= params.n_threads) {
return;
}
@ -235,7 +165,7 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let i_src_row = params.offset_src0 + i3 * params.stride_src03 + i2 * params.stride_src02 + i1 * params.stride_src01;
let i_dst_row = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1;
if (i0 >= params.n_dims && !is_vision) {
if (i0 >= params.n_dims && !is_vision) {
let i_src = i_src_row + i0;
let i_dst = i_dst_row + i0;
rotate(i_dst, i_dst + 1, f32(src0[i_src]), f32(src0[i_src + 1]));
@ -290,6 +220,5 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let x0 = f32(src0[i_src]);
let x1 = f32(src0[i_src + pair_offset(is_neox, is_mrope, is_vision)]);
rotate(i_dst, i_dst + pair_offset(is_neox, is_mrope, is_vision), x0 * thetas.x - x1 * thetas.y, x0 * thetas.y + x1 * thetas.x);
}
#end(SHADER)
}

View File

@ -1,215 +1,12 @@
#define(VARIANTS)
[
{
"SHADER_NAME": "soft_max_f32",
"DECLS": ["BASE_BINDINGS", "NOT_INPLACE", "NO_MASK", "NO_SINK"]
},
{
"SHADER_NAME": "soft_max_f32_inplace",
"DECLS": ["BASE_BINDINGS_INPLACE", "INPLACE", "NO_MASK", "NO_SINK"]
},
{
"SHADER_NAME": "soft_max_f32_sink",
"DECLS": ["SINK_BINDINGS", "NOT_INPLACE", "NO_MASK", "SINK"]
},
{
"SHADER_NAME": "soft_max_f32_sink_inplace",
"DECLS": ["SINK_BINDINGS_INPLACE", "INPLACE", "NO_MASK", "SINK"]
},
{
"SHADER_NAME": "soft_max_f32_mask_f32",
"REPLS": {
"MASK_TYPE" : "f32",
},
"DECLS": ["MASK_BINDINGS", "NOT_INPLACE", "MASK", "NO_SINK"]
},
{
"SHADER_NAME": "soft_max_f32_mask_f32_inplace",
"REPLS": {
"MASK_TYPE" : "f32",
},
"DECLS": ["MASK_BINDINGS_INPLACE", "INPLACE", "MASK", "NO_SINK"]
},
{
"SHADER_NAME": "soft_max_f32_mask_f16",
"REPLS": {
"MASK_TYPE" : "f16",
},
"DECLS": ["MASK_BINDINGS", "NOT_INPLACE", "MASK", "NO_SINK"]
},
{
"SHADER_NAME": "soft_max_f32_mask_f16_inplace",
"REPLS": {
"MASK_TYPE" : "f16",
},
"DECLS": ["MASK_BINDINGS_INPLACE", "INPLACE", "MASK", "NO_SINK"]
},
{
"SHADER_NAME": "soft_max_f32_mask_f32_sink",
"REPLS": {
"MASK_TYPE" : "f32",
},
"DECLS": ["MASK_SINK_BINDINGS", "NOT_INPLACE", "MASK", "SINK"]
},
{
"SHADER_NAME": "soft_max_f32_mask_f32_sink_inplace",
"REPLS": {
"MASK_TYPE" : "f32",
},
"DECLS": ["MASK_SINK_BINDINGS_INPLACE", "INPLACE", "MASK", "SINK"]
},
{
"SHADER_NAME": "soft_max_f32_mask_f16_sink",
"REPLS": {
"MASK_TYPE" : "f16",
},
"DECLS": ["MASK_SINK_BINDINGS", "NOT_INPLACE", "MASK", "SINK"]
},
{
"SHADER_NAME": "soft_max_f32_mask_f16_sink_inplace",
"REPLS": {
"MASK_TYPE" : "f16",
},
"DECLS": ["MASK_SINK_BINDINGS_INPLACE", "INPLACE", "MASK", "SINK"]
}
]
#end(VARIANTS)
#define(DECLS)
#decl(BASE_BINDINGS)
@group(0) @binding(1)
var<storage, read_write> dst: array<f32>;
@group(0) @binding(2)
var<uniform> params: Params;
#enddecl(BASE_BINDINGS)
#decl(BASE_BINDINGS_INPLACE)
@group(0) @binding(1)
var<uniform> params: Params;
#enddecl(BASE_BINDINGS_INPLACE)
#decl(SINK_BINDINGS)
@group(0) @binding(1)
var<storage, read_write> sinks: array<f32>;
@group(0) @binding(2)
var<storage, read_write> dst: array<f32>;
@group(0) @binding(3)
var<uniform> params: Params;
#enddecl(SINK_BINDINGS)
#decl(SINK_BINDINGS_INPLACE)
@group(0) @binding(1)
var<storage, read_write> sinks: array<f32>;
@group(0) @binding(2)
var<uniform> params: Params;
#enddecl(SINK_BINDINGS_INPLACE)
#decl(MASK_BINDINGS)
@group(0) @binding(1)
var<storage, read_write> mask: array<{{MASK_TYPE}}>;
@group(0) @binding(2)
var<storage, read_write> dst: array<f32>;
@group(0) @binding(3)
var<uniform> params: Params;
#enddecl(MASK_BINDINGS)
#decl(MASK_BINDINGS_INPLACE)
@group(0) @binding(1)
var<storage, read_write> mask: array<{{MASK_TYPE}}>;
@group(0) @binding(2)
var<uniform> params: Params;
#enddecl(MASK_BINDINGS_INPLACE)
#decl(MASK_SINK_BINDINGS)
@group(0) @binding(1)
var<storage, read_write> mask: array<{{MASK_TYPE}}>;
@group(0) @binding(2)
var<storage, read_write> sinks: array<f32>;
@group(0) @binding(3)
var<storage, read_write> dst: array<f32>;
@group(0) @binding(4)
var<uniform> params: Params;
#enddecl(MASK_SINK_BINDINGS)
#decl(MASK_SINK_BINDINGS_INPLACE)
@group(0) @binding(1)
var<storage, read_write> mask: array<{{MASK_TYPE}}>;
@group(0) @binding(2)
var<storage, read_write> sinks: array<f32>;
@group(0) @binding(3)
var<uniform> params: Params;
#enddecl(MASK_SINK_BINDINGS_INPLACE)
#decl(NOT_INPLACE)
fn inter_value(i: u32) -> f32 {
return dst[i];
}
fn update(i: u32, val: f32) {
dst[i] = val;
}
#enddecl(NOT_INPLACE)
#decl(INPLACE)
fn inter_value(i: u32) -> f32 {
return src[i];
}
fn update(i: u32, val: f32) {
src[i] = val;
}
#enddecl(INPLACE)
#decl(NO_MASK)
fn mask_val(i: u32) -> f32 {
return 0.0;
}
#enddecl(NO_MASK)
#decl(MASK)
fn mask_val(i: u32) -> f32 {
return f32(mask[i]);
}
#enddecl(MASK)
#decl(NO_SINK)
fn lower_max_bound(i2: u32) -> f32 {
return -1e30;
}
fn add_sinks(val: f32, i2: u32, max_val: f32) -> f32 {
return val;
}
#enddecl(NO_SINK)
#decl(SINK)
fn lower_max_bound(i2: u32) -> f32 {
return sinks[params.offset_sinks + i2];
}
fn add_sinks(val: f32, i2: u32, max_val: f32) -> f32 {
return val + exp(sinks[params.offset_sinks + i2] - max_val);
}
#enddecl(SINK)
#end(DECLS)
#define(SHADER)
enable f16;
#ifdef MASK_F32
#define MaskType f32
#endif
#ifdef MASK_F16
#define MaskType f16
#endif
struct Params {
offset_src0: u32,
offset_src1: u32,
@ -249,14 +46,117 @@ struct Params {
@group(0) @binding(0)
var<storage, read_write> src: array<f32>;
DECLS
#ifdef HAS_MASK
#ifdef HAS_SINK
@group(0) @binding(1)
var<storage, read_write> mask: array<MaskType>;
@group(0) @binding(2)
var<storage, read_write> sinks: array<f32>;
#ifdef INPLACE
@group(0) @binding(3)
var<uniform> params: Params;
#else
@group(0) @binding(3)
var<storage, read_write> dst: array<f32>;
@group(0) @binding(4)
var<uniform> params: Params;
#endif
#else
@group(0) @binding(1)
var<storage, read_write> mask: array<MaskType>;
#ifdef INPLACE
@group(0) @binding(2)
var<uniform> params: Params;
#else
@group(0) @binding(2)
var<storage, read_write> dst: array<f32>;
@group(0) @binding(3)
var<uniform> params: Params;
#endif
#endif
#else
#ifdef HAS_SINK
@group(0) @binding(1)
var<storage, read_write> sinks: array<f32>;
#ifdef INPLACE
@group(0) @binding(2)
var<uniform> params: Params;
#else
@group(0) @binding(2)
var<storage, read_write> dst: array<f32>;
@group(0) @binding(3)
var<uniform> params: Params;
#endif
#else
#ifdef INPLACE
@group(0) @binding(1)
var<uniform> params: Params;
#else
@group(0) @binding(1)
var<storage, read_write> dst: array<f32>;
@group(0) @binding(2)
var<uniform> params: Params;
#endif
#endif
#endif
#ifdef INPLACE
fn inter_value(i: u32) -> f32 {
return src[i];
}
fn update(i: u32, val: f32) {
src[i] = val;
}
#else
fn inter_value(i: u32) -> f32 {
return dst[i];
}
fn update(i: u32, val: f32) {
dst[i] = val;
}
#endif
#ifdef HAS_MASK
fn mask_val(i: u32) -> f32 {
return f32(mask[i]);
}
#else
fn mask_val(i: u32) -> f32 {
return 0.0;
}
#endif
#ifdef HAS_SINK
fn lower_max_bound(i2: u32) -> f32 {
return sinks[params.offset_sinks + i2];
}
fn add_sinks(val: f32, i2: u32, max_val: f32) -> f32 {
return val + exp(sinks[params.offset_sinks + i2] - max_val);
}
#else
fn lower_max_bound(i2: u32) -> f32 {
return -1e30;
}
fn add_sinks(val: f32, i2: u32, max_val: f32) -> f32 {
return val;
}
#endif
const CACHE_SIZE: u32 = 16;
var<workgroup> scratch: array<f32, WG_SIZE>;
override wg_size: u32;
var<workgroup> scratch: array<f32, wg_size>;
@compute @workgroup_size(wg_size)
@compute @workgroup_size(WG_SIZE)
fn main(@builtin(workgroup_id) wid: vec3<u32>,
@builtin(local_invocation_id) lid: vec3<u32>) {
@ -268,7 +168,7 @@ fn main(@builtin(workgroup_id) wid: vec3<u32>,
let i_src0_row = params.offset_src0 + i3 * params.stride_src03 + i2 * params.stride_src02 + i1 * params.stride_src01;
let i_src1_row = params.offset_src1 + (i3 % params.ne13) * params.stride_src13 + (i2 % params.ne12) * params.stride_src12 + i1 * params.stride_src11;
let i_dst_row = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1;
let elems = (params.ne0 + wg_size - 1) / wg_size;
let elems = (params.ne0 + WG_SIZE - 1) / WG_SIZE;
let head = f32(i2);
let slope = select(1, select(pow(params.m1, 2 * (head - params.n_head_log2) + 1), pow(params.m0, head + 1), head < params.n_head_log2), params.max_bias > 0);
@ -286,12 +186,12 @@ fn main(@builtin(workgroup_id) wid: vec3<u32>,
if (col < CACHE_SIZE) {
cache[col] = val;
}
col += wg_size;
col += WG_SIZE;
}
scratch[lid.x] = max_val;
workgroupBarrier();
var offset = wg_size / 2;
var offset: u32 = WG_SIZE / 2;
while (offset > 0) {
if (lid.x < offset) {
scratch[lid.x] = max(scratch[lid.x], scratch[lid.x + offset]);
@ -317,12 +217,12 @@ fn main(@builtin(workgroup_id) wid: vec3<u32>,
} else {
update(i_dst_row + col, ex);
}
col += wg_size;
col += WG_SIZE;
}
scratch[lid.x] = sum;
workgroupBarrier();
offset = wg_size / 2;
offset = WG_SIZE / 2;
while (offset > 0) {
if (lid.x < offset) {
scratch[lid.x] += scratch[lid.x + offset];
@ -339,7 +239,7 @@ fn main(@builtin(workgroup_id) wid: vec3<u32>,
break;
}
update(i_dst_row + col, select(inter_value(i_dst_row + col), cache[col], col < CACHE_SIZE) * sum_recip);
col += wg_size;
col += WG_SIZE;
}
}
#end(SHADER)