From 6e0691e5bd5e38fe95bde9953b17c3cd005f9ce9 Mon Sep 17 00:00:00 2001 From: Abhijit Ramesh Date: Sat, 21 Mar 2026 23:00:09 -0700 Subject: [PATCH] ggml-webgpu: port all AOT operators to JIT * port cpy pipeline to shader lib with JIT compilation * port glu pipeline to shader lib with JIT compilation * port rope pipeline to shader lib with JIT compilation * port soft_max pipeline to shader lib with JIT compilation * removed unused functions from embed_wgsl.py which were used for old AOT template expansion --- .../ggml-webgpu/ggml-webgpu-shader-lib.hpp | 325 +++++++++++++++++ ggml/src/ggml-webgpu/ggml-webgpu.cpp | 224 +++--------- .../wgsl-shaders/{cpy.tmpl.wgsl => cpy.wgsl} | 74 ++-- .../ggml-webgpu/wgsl-shaders/embed_wgsl.py | 107 +----- .../ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl | 323 ---------------- ggml/src/ggml-webgpu/wgsl-shaders/glu.wgsl | 155 ++++++++ .../{rope.tmpl.wgsl => rope.wgsl} | 217 ++++------- .../{soft_max.tmpl.wgsl => soft_max.wgsl} | 344 +++++++----------- 8 files changed, 766 insertions(+), 1003 deletions(-) rename ggml/src/ggml-webgpu/wgsl-shaders/{cpy.tmpl.wgsl => cpy.wgsl} (60%) delete mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl create mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/glu.wgsl rename ggml/src/ggml-webgpu/wgsl-shaders/{rope.tmpl.wgsl => rope.wgsl} (72%) rename ggml/src/ggml-webgpu/wgsl-shaders/{soft_max.tmpl.wgsl => soft_max.wgsl} (59%) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index 59861ac16c..97863f4041 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -535,6 +535,95 @@ struct ggml_webgpu_mul_mat_shader_decisions { uint32_t mul_mat_wg_size; }; +/** Cpy **/ + +struct ggml_webgpu_cpy_pipeline_key { + ggml_type src_type; + ggml_type dst_type; + + bool operator==(const ggml_webgpu_cpy_pipeline_key & other) const { + return src_type == other.src_type && dst_type == other.dst_type; + 
} +}; + +struct ggml_webgpu_cpy_pipeline_key_hash { + size_t operator()(const ggml_webgpu_cpy_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.src_type); + ggml_webgpu_hash_combine(seed, key.dst_type); + return seed; + } +}; + +/** Glu **/ + +struct ggml_webgpu_glu_pipeline_key { + ggml_glu_op glu_op; + ggml_type type; + bool split; + + bool operator==(const ggml_webgpu_glu_pipeline_key & other) const { + return glu_op == other.glu_op && type == other.type && split == other.split; + } +}; + +struct ggml_webgpu_glu_pipeline_key_hash { + size_t operator()(const ggml_webgpu_glu_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.glu_op); + ggml_webgpu_hash_combine(seed, key.type); + ggml_webgpu_hash_combine(seed, key.split); + return seed; + } +}; + +/** Rope **/ + +struct ggml_webgpu_rope_pipeline_key { + ggml_type type; + bool inplace; + bool has_ff; + + bool operator==(const ggml_webgpu_rope_pipeline_key & other) const { + return type == other.type && inplace == other.inplace && has_ff == other.has_ff; + } +}; + +struct ggml_webgpu_rope_pipeline_key_hash { + size_t operator()(const ggml_webgpu_rope_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.type); + ggml_webgpu_hash_combine(seed, key.inplace); + ggml_webgpu_hash_combine(seed, key.has_ff); + return seed; + } +}; + +/** SoftMax **/ + +struct ggml_webgpu_soft_max_pipeline_key { + ggml_type mask_type; + bool has_mask; + bool has_sink; + bool inplace; + + bool operator==(const ggml_webgpu_soft_max_pipeline_key & other) const { + return mask_type == other.mask_type && has_mask == other.has_mask && has_sink == other.has_sink && + inplace == other.inplace; + } +}; + +struct ggml_webgpu_soft_max_pipeline_key_hash { + size_t operator()(const ggml_webgpu_soft_max_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.mask_type); + ggml_webgpu_hash_combine(seed, key.has_mask); + 
ggml_webgpu_hash_combine(seed, key.has_sink); + ggml_webgpu_hash_combine(seed, key.inplace); + return seed; + } +}; + class ggml_webgpu_shader_lib { wgpu::Device device; pre_wgsl::Preprocessor preprocessor; @@ -582,6 +671,12 @@ class ggml_webgpu_shader_lib { std::unordered_map set_rows_pipelines; std::unordered_map set_pipelines; + std::unordered_map cpy_pipelines; + std::unordered_map glu_pipelines; + std::unordered_map + rope_pipelines; + std::unordered_map + soft_max_pipelines; public: ggml_webgpu_shader_lib(wgpu::Device device) { this->device = device; } @@ -1679,6 +1774,236 @@ class ggml_webgpu_shader_lib { return flash_attn_pipelines[key]; } + webgpu_pipeline get_cpy_pipeline(const ggml_webgpu_shader_lib_context & context) { + ggml_webgpu_cpy_pipeline_key key = { + .src_type = context.src0->type, + .dst_type = context.dst->type, + }; + + auto it = cpy_pipelines.find(key); + if (it != cpy_pipelines.end()) { + return it->second; + } + + std::vector defines; + std::string variant = "cpy"; + + switch (key.src_type) { + case GGML_TYPE_F32: + defines.push_back("SRC_F32"); + variant += "_f32"; + break; + case GGML_TYPE_F16: + defines.push_back("SRC_F16"); + variant += "_f16"; + break; + default: + GGML_ABORT("Unsupported src type for cpy shader"); + } + + switch (key.dst_type) { + case GGML_TYPE_F32: + defines.push_back("DST_F32"); + variant += "_f32"; + break; + case GGML_TYPE_F16: + defines.push_back("DST_F16"); + variant += "_f16"; + break; + case GGML_TYPE_I32: + defines.push_back("DST_I32"); + variant += "_i32"; + break; + default: + GGML_ABORT("Unsupported dst type for cpy shader"); + } + + defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); + + auto processed = preprocessor.preprocess(wgsl_cpy, defines); + auto decisions = std::make_shared(); + decisions->wg_size = context.max_wg_size; + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); + pipeline.context = decisions; + cpy_pipelines[key] = 
pipeline; + return cpy_pipelines[key]; + } + + webgpu_pipeline get_glu_pipeline(const ggml_webgpu_shader_lib_context & context) { + ggml_webgpu_glu_pipeline_key key = { + .glu_op = ggml_get_glu_op(context.dst), + .type = context.dst->type, + .split = (context.src1 != nullptr), + }; + + auto it = glu_pipelines.find(key); + if (it != glu_pipelines.end()) { + return it->second; + } + + std::vector defines; + std::string variant = "glu"; + + switch (key.glu_op) { + case GGML_GLU_OP_REGLU: + defines.push_back("OP_REGLU"); + variant += "_reglu"; + break; + case GGML_GLU_OP_GEGLU: + defines.push_back("OP_GEGLU"); + variant += "_geglu"; + break; + case GGML_GLU_OP_SWIGLU: + defines.push_back("OP_SWIGLU"); + variant += "_swiglu"; + break; + case GGML_GLU_OP_SWIGLU_OAI: + defines.push_back("OP_SWIGLU_OAI"); + variant += "_swiglu_oai"; + break; + case GGML_GLU_OP_GEGLU_ERF: + defines.push_back("OP_GEGLU_ERF"); + variant += "_geglu_erf"; + break; + case GGML_GLU_OP_GEGLU_QUICK: + defines.push_back("OP_GEGLU_QUICK"); + variant += "_geglu_quick"; + break; + default: + GGML_ABORT("Unsupported GLU op"); + } + switch (key.type) { + case GGML_TYPE_F32: + defines.push_back("TYPE_F32"); + variant += "_f32"; + break; + case GGML_TYPE_F16: + defines.push_back("TYPE_F16"); + variant += "_f16"; + break; + default: + GGML_ABORT("Unsupported type for GLU shader"); + } + + if (key.split) { + variant += "_split"; + } else { + defines.push_back("NO_SPLIT"); + } + + defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); + + auto processed = preprocessor.preprocess(wgsl_glu, defines); + auto decisions = std::make_shared(); + decisions->wg_size = context.max_wg_size; + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); + pipeline.context = decisions; + glu_pipelines[key] = pipeline; + return glu_pipelines[key]; + } + + webgpu_pipeline get_rope_pipeline(const ggml_webgpu_shader_lib_context & context) { + ggml_webgpu_rope_pipeline_key key 
= { + .type = context.dst->type, + .inplace = context.inplace, + .has_ff = (context.src2 != nullptr), + }; + + auto it = rope_pipelines.find(key); + if (it != rope_pipelines.end()) { + return it->second; + } + + std::vector defines; + std::string variant = "rope"; + + switch (key.type) { + case GGML_TYPE_F32: + defines.push_back("TYPE_F32"); + variant += "_f32"; + break; + case GGML_TYPE_F16: + defines.push_back("TYPE_F16"); + variant += "_f16"; + break; + default: + GGML_ABORT("Unsupported type for ROPE shader"); + } + + if (key.inplace) { + defines.push_back("INPLACE"); + variant += "_inplace"; + } + + if (key.has_ff) { + defines.push_back("FF_FUNC"); + variant += "_ff"; + } + + defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); + + auto processed = preprocessor.preprocess(wgsl_rope, defines); + auto decisions = std::make_shared(); + decisions->wg_size = context.max_wg_size; + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); + pipeline.context = decisions; + rope_pipelines[key] = pipeline; + return rope_pipelines[key]; + } + + webgpu_pipeline get_soft_max_pipeline(const ggml_webgpu_shader_lib_context & context) { + ggml_webgpu_soft_max_pipeline_key key = { + .mask_type = context.src1 ? 
context.src1->type : GGML_TYPE_F32, + .has_mask = (context.src1 != nullptr), + .has_sink = (context.src2 != nullptr), + .inplace = context.inplace, + }; + + auto it = soft_max_pipelines.find(key); + if (it != soft_max_pipelines.end()) { + return it->second; + } + + std::vector defines; + std::string variant = "soft_max"; + + if (key.has_mask) { + defines.push_back("HAS_MASK"); + switch (key.mask_type) { + case GGML_TYPE_F32: + defines.push_back("MASK_F32"); + variant += "_mask_f32"; + break; + case GGML_TYPE_F16: + defines.push_back("MASK_F16"); + variant += "_mask_f16"; + break; + default: + GGML_ABORT("Unsupported type for SOFT_MAX shader"); + } + } + + if (key.has_sink) { + defines.push_back("HAS_SINK"); + variant += "_sink"; + } + + if (key.inplace) { + defines.push_back("INPLACE"); + variant += "_inplace"; + } + + defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); + + auto processed = preprocessor.preprocess(wgsl_soft_max, defines); + auto decisions = std::make_shared(); + decisions->wg_size = context.max_wg_size; + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); + pipeline.context = decisions; + soft_max_pipelines[key] = pipeline; + return soft_max_pipelines[key]; + } + private: static webgpu_pipeline ggml_webgpu_create_pipeline(wgpu::Device & device, std::string shader_code, diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 5e16f84ddd..fa3c492a7a 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -364,13 +364,6 @@ struct webgpu_context_struct { wgpu::Buffer set_rows_dev_error_buf; wgpu::Buffer set_rows_host_error_buf; - std::map> cpy_pipelines; // src_type, dst_type - - std::map>> rope_pipelines; // type, ff, inplace - std::map>> glu_pipelines; // glu_op, type, split - - std::map>> soft_max_pipelines; // mask_type, has_sink, inplace - size_t memset_bytes_per_thread; }; @@ -849,6 +842,16 @@ static 
binary_overlap_flags ggml_webgpu_detect_binary_overlap(ggml_tensor * src0 } static webgpu_command ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { + ggml_webgpu_shader_lib_context shader_lib_ctx = { + .src0 = src, + .dst = dst, + .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, + }; + + webgpu_pipeline pipeline = ctx->shader_lib->get_cpy_pipeline(shader_lib_ctx); + + auto * decisions = static_cast(pipeline.context.get()); + uint32_t ne = (uint32_t) ggml_nelements(dst); std::vector params = { @@ -875,9 +878,8 @@ static webgpu_command ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, g .size = ggml_webgpu_tensor_binding_size(ctx, dst) } }; - uint32_t wg_x = CEIL_DIV(ne, WEBGPU_MAX_WG_SIZE); - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, ctx->cpy_pipelines[src->type][dst->type], - params, entries, wg_x); + uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size); + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); } static webgpu_command ggml_webgpu_set(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) { @@ -1914,6 +1916,19 @@ static webgpu_command ggml_webgpu_rope(webgpu_context & ctx, ggml_tensor * src1, ggml_tensor * src2, ggml_tensor * dst) { + ggml_webgpu_shader_lib_context shader_lib_ctx = { + .src0 = src0, + .src1 = src1, + .src2 = src2, + .dst = dst, + .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, + .inplace = ggml_webgpu_tensor_equal(src0, dst), + }; + + webgpu_pipeline pipeline = ctx->shader_lib->get_rope_pipeline(shader_lib_ctx); + + auto * decisions = static_cast(pipeline.context.get()); + const int inplace = ggml_webgpu_tensor_equal(src0, dst); const int has_freq_factor = (src2 != nullptr); @@ -1996,12 +2011,22 @@ static webgpu_command ggml_webgpu_rope(webgpu_context & ctx, .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); } - 
webgpu_pipeline pipeline = ctx->rope_pipelines[dst->type][has_freq_factor][inplace]; - uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), WEBGPU_MAX_WG_SIZE); + uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), decisions->wg_size); return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); } static webgpu_command ggml_webgpu_glu(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) { + ggml_webgpu_shader_lib_context shader_lib_ctx = { + .src0 = src0, + .src1 = src1, + .dst = dst, + .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, + }; + + webgpu_pipeline pipeline = ctx->shader_lib->get_glu_pipeline(shader_lib_ctx); + + auto * decisions = static_cast(pipeline.context.get()); + const int split = (src1 != nullptr); std::vector params = { @@ -2048,8 +2073,7 @@ static webgpu_command ggml_webgpu_glu(webgpu_context & ctx, ggml_tensor * src0, .offset = ggml_webgpu_tensor_align_offset(ctx, dst), .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); - webgpu_pipeline pipeline = ctx->glu_pipelines[ggml_get_glu_op(dst)][dst->type][split]; - uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), WEBGPU_MAX_WG_SIZE); + uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), decisions->wg_size); return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); } @@ -2109,9 +2133,20 @@ static webgpu_command ggml_webgpu_soft_max(webgpu_context & ctx, ggml_tensor * src1, ggml_tensor * src2, ggml_tensor * dst) { - const int inplace = ggml_webgpu_tensor_equal(src0, dst); - const int mask_type = (src1 != nullptr) ? 
src1->type : 2; // use 2 for no mask here - const int has_sink = (src2 != nullptr); + ggml_webgpu_shader_lib_context shader_lib_ctx = { + .src0 = src0, + .src1 = src1, + .src2 = src2, + .dst = dst, + .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, + .inplace = ggml_webgpu_tensor_equal(src0, dst), + }; + + webgpu_pipeline pipeline = ctx->shader_lib->get_soft_max_pipeline(shader_lib_ctx); + + const int inplace = ggml_webgpu_tensor_equal(src0, dst); + const int has_mask = (src1 != nullptr); + const int has_sink = (src2 != nullptr); float max_bias; memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float)); float n_head_log2 = float(1u << (uint32_t) floor(log2(src0->ne[2]))); @@ -2120,15 +2155,15 @@ static webgpu_command ggml_webgpu_soft_max(webgpu_context & ctx, std::vector params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)), - mask_type < 2 ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)) : 0, + has_mask ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)) : 0, has_sink ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src2) / ggml_type_size(src2->type)) : 0, (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)), (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)), (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)), - mask_type < 2 ? (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)) : 0, - mask_type < 2 ? (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)) : 0, - mask_type < 2 ? (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)) : 0, + has_mask ? (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)) : 0, + has_mask ? (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)) : 0, + has_mask ? 
(uint32_t) (src1->nb[3] / ggml_type_size(src1->type)) : 0, (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)), @@ -2136,8 +2171,8 @@ static webgpu_command ggml_webgpu_soft_max(webgpu_context & ctx, (uint32_t) src0->ne[0], (uint32_t) src0->ne[1], (uint32_t) src0->ne[2], - mask_type < 2 ? (uint32_t) src1->ne[2] : 0, - mask_type < 2 ? (uint32_t) src1->ne[3] : 0, + has_mask ? (uint32_t) src1->ne[2] : 0, + has_mask ? (uint32_t) src1->ne[3] : 0, *(uint32_t *) dst->op_params, // scale *(uint32_t *) &max_bias, *(uint32_t *) &n_head_log2, @@ -2152,7 +2187,7 @@ static webgpu_command ggml_webgpu_soft_max(webgpu_context & ctx, .size = ggml_webgpu_tensor_binding_size(ctx, src0) } }; uint32_t binding_num = 1; - if (mask_type < 2) { + if (has_mask) { entries.push_back({ .binding = binding_num, .buffer = ggml_webgpu_tensor_buf(src1), .offset = ggml_webgpu_tensor_align_offset(ctx, src1), @@ -2173,9 +2208,7 @@ static webgpu_command ggml_webgpu_soft_max(webgpu_context & ctx, .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); } - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, - ctx->soft_max_pipelines[mask_type][has_sink][inplace], params, entries, - ggml_nrows(dst)); + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, ggml_nrows(dst)); } static webgpu_command ggml_webgpu_argmax(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { @@ -2885,139 +2918,6 @@ static void ggml_webgpu_init_memset_pipeline(webgpu_global_context & ctx) { ctx->memset_pipelines[0] = ggml_webgpu_create_pipeline(ctx->device, wgsl_memset, "memset", constants); } -static void ggml_webgpu_init_cpy_pipeline(webgpu_context & webgpu_ctx) { - std::vector constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE); - - webgpu_ctx->cpy_pipelines[GGML_TYPE_F32][GGML_TYPE_F32] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, 
wgsl_cpy_f32_f32, "cpy_f32_f32", constants); - webgpu_ctx->cpy_pipelines[GGML_TYPE_F32][GGML_TYPE_I32] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_cpy_f32_i32, "cpy_f32_i32", constants); - webgpu_ctx->cpy_pipelines[GGML_TYPE_F32][GGML_TYPE_F16] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_cpy_f32_f16, "cpy_f32_f16", constants); - webgpu_ctx->cpy_pipelines[GGML_TYPE_F16][GGML_TYPE_F32] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_cpy_f16_f32, "cpy_f16_f32", constants); - webgpu_ctx->cpy_pipelines[GGML_TYPE_F16][GGML_TYPE_F16] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_cpy_f16_f16, "cpy_f16_f16", constants); -} - -static void ggml_webgpu_init_rope_pipeline(webgpu_context & webgpu_ctx) { - std::vector constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE); - - webgpu_ctx->rope_pipelines[GGML_TYPE_F32][0][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_rope_f32, "rope_f32", constants); - webgpu_ctx->rope_pipelines[GGML_TYPE_F32][0][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_rope_f32_inplace, "rope_f32_inplace", constants); - webgpu_ctx->rope_pipelines[GGML_TYPE_F32][1][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_rope_f32_ff, "rope_f32_ff", constants); - webgpu_ctx->rope_pipelines[GGML_TYPE_F32][1][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_rope_f32_ff_inplace, "rope_f32_ff_inplace", constants); - - webgpu_ctx->rope_pipelines[GGML_TYPE_F16][0][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_rope_f16, "rope_f16", constants); - webgpu_ctx->rope_pipelines[GGML_TYPE_F16][0][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_rope_f16_inplace, "rope_f16_inplace", constants); - webgpu_ctx->rope_pipelines[GGML_TYPE_F16][1][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_rope_f16_ff, "rope_f16_ff", 
constants); - webgpu_ctx->rope_pipelines[GGML_TYPE_F16][1][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_rope_f16_ff_inplace, "rope_f16_ff_inplace", constants); -} - -static void ggml_webgpu_init_glu_pipeline(webgpu_context & webgpu_ctx) { - std::vector constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE); - - // REGLU - webgpu_ctx->glu_pipelines[GGML_GLU_OP_REGLU][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_reglu_f32, "reglu_f32", constants); - webgpu_ctx->glu_pipelines[GGML_GLU_OP_REGLU][GGML_TYPE_F16][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_reglu_f16, "reglu_f16", constants); - webgpu_ctx->glu_pipelines[GGML_GLU_OP_REGLU][GGML_TYPE_F32][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_reglu_f32_split, "reglu_f32_split", constants); - webgpu_ctx->glu_pipelines[GGML_GLU_OP_REGLU][GGML_TYPE_F16][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_reglu_f16_split, "reglu_f16_split", constants); - - // GEGLU - webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_f32, "geglu_f32", constants); - webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU][GGML_TYPE_F16][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_f16, "geglu_f16", constants); - webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU][GGML_TYPE_F32][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_f32_split, "geglu_f32_split", constants); - webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU][GGML_TYPE_F16][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_f16_split, "geglu_f16_split", constants); - - // SWIGLU - webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_swiglu_f32, "swiglu_f32", constants); - 
webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU][GGML_TYPE_F16][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_swiglu_f16, "swiglu_f16", constants); - webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_swiglu_f32_split, "swiglu_f32_split", constants); - webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_swiglu_f16_split, "swiglu_f16_split", constants); - - // SWIGLU_OAI - webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU_OAI][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_swiglu_oai_f32, "swiglu_oai_f32", constants); - webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU_OAI][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_swiglu_oai_f32_split, "swiglu_oai_f32_split", constants); - - // GEGLU_ERF - webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_erf_f32, "geglu_erf_f32", constants); - webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F16][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_erf_f16, "geglu_erf_f16", constants); - webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_geglu_erf_f32_split, "geglu_erf_f32_split", constants); - webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_geglu_erf_f16_split, "geglu_erf_f16_split", constants); - - // GEGLU_QUICK - webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_quick_f32, "geglu_quick_f32", constants); - webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F16][0] = - 
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_quick_f16, "geglu_quick_f16", constants); - webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_geglu_quick_f32_split, "geglu_quick_f32_split", constants); - webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_geglu_quick_f16_split, "geglu_quick_f16_split", constants); -} - -static void ggml_webgpu_init_soft_max_pipeline(webgpu_context & webgpu_ctx) { - std::vector constants = ggml_webgpu_wg_size_entry(WEBGPU_ROW_SPLIT_WG_SIZE); - - // f32 (no mask) - webgpu_ctx->soft_max_pipelines[2][0][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_soft_max_f32, "soft_max_f32", constants); - webgpu_ctx->soft_max_pipelines[2][0][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_inplace, "soft_max_f32_inplace", constants); - webgpu_ctx->soft_max_pipelines[2][1][0] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_sink, "soft_max_f32_sink", constants); - webgpu_ctx->soft_max_pipelines[2][1][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_sink_inplace, "soft_max_f32_sink_inplace", constants); - - // f32 mask (mask_type = 0) - webgpu_ctx->soft_max_pipelines[0][0][0] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f32, "soft_max_f32_mask_f32", constants); - webgpu_ctx->soft_max_pipelines[0][0][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f32_inplace, "soft_max_f32_mask_f32_inplace", constants); - webgpu_ctx->soft_max_pipelines[0][1][0] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f32_sink, "soft_max_f32_mask_f32_sink", constants); - webgpu_ctx->soft_max_pipelines[0][1][1] = - 
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f32_sink_inplace, - "soft_max_f32_mask_f32_sink_inplace", constants); - - // f16 mask (mask_type = 1) - webgpu_ctx->soft_max_pipelines[1][0][0] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f16, "soft_max_f32_mask_f16", constants); - webgpu_ctx->soft_max_pipelines[1][0][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f16_inplace, "soft_max_f32_mask_f16_inplace", constants); - webgpu_ctx->soft_max_pipelines[1][1][0] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f16_sink, "soft_max_f32_mask_f16_sink", constants); - webgpu_ctx->soft_max_pipelines[1][1][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f16_sink_inplace, - "soft_max_f32_mask_f16_sink_inplace", constants); -} - static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) { wgpu::RequestAdapterOptions options = {}; @@ -3183,10 +3083,6 @@ static webgpu_context initialize_webgpu_context(ggml_backend_dev_t dev) { WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES, wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead, "set_rows_host_error_buf"); - ggml_webgpu_init_cpy_pipeline(webgpu_ctx); - ggml_webgpu_init_rope_pipeline(webgpu_ctx); - ggml_webgpu_init_glu_pipeline(webgpu_ctx); - ggml_webgpu_init_soft_max_pipeline(webgpu_ctx); #ifdef GGML_WEBGPU_DEBUG // Initialize debug buffers ggml_webgpu_create_buffer(webgpu_ctx->global_ctx->device, webgpu_ctx->global_ctx->debug_host_buf, diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl similarity index 60% rename from ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl rename to ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl index b5e93b812f..fa3bdf4e39 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl @@ -1,66 +1,41 
@@ -#define(VARIANTS) - -[ - { - "REPLS": { - "SRC_TYPE": "f32", - "DST_TYPE": "f32" - } - }, - { - "REPLS": { - "SRC_TYPE": "f32", - "DST_TYPE": "i32" - } - }, - { - "REPLS": { - "SRC_TYPE": "f32", - "DST_TYPE": "f16" - } - }, - { - "REPLS": { - "SRC_TYPE": "f16", - "DST_TYPE": "f16" - } - }, - { - "REPLS": { - "SRC_TYPE": "f16", - "DST_TYPE": "f32" - } - } -] - -#end(VARIANTS) - -#define(SHADER) enable f16; +#ifdef SRC_F32 +#define SRC_TYPE f32 +#elif defined(SRC_F16) +#define SRC_TYPE f16 +#endif + +#ifdef DST_F32 +#define DST_TYPE f32 +#elif defined(DST_F16) +#define DST_TYPE f16 +#elif defined(DST_I32) +#define DST_TYPE i32 +#endif + @group(0) @binding(0) -var src: array<{{SRC_TYPE}}>; +var src: array; @group(0) @binding(1) -var dst: array<{{DST_TYPE}}>; +var dst: array; -struct Params { - ne: u32, // total number of elements - offset_src: u32, // in elements - offset_dst: u32, // in elements +struct Params{ + ne: u32, + offset_src: u32, + offset_dst: u32, - // Strides (in elements) — may be permuted stride_src0: u32, stride_src1: u32, stride_src2: u32, stride_src3: u32, + stride_dst0: u32, stride_dst1: u32, stride_dst2: u32, stride_dst3: u32, - // Logical shapes src_ne0: u32, src_ne1: u32, src_ne2: u32, @@ -73,8 +48,7 @@ struct Params { @group(0) @binding(2) var params: Params; -override wg_size: u32; -@compute @workgroup_size(wg_size) +@compute @workgroup_size(WG_SIZE) fn main(@builtin(global_invocation_id) gid: vec3) { if (gid.x >= params.ne) { return; @@ -102,6 +76,6 @@ fn main(@builtin(global_invocation_id) gid: vec3) { let dst_idx = j0 * params.stride_dst0 + j1 * params.stride_dst1 + j2 * params.stride_dst2 + j3 * params.stride_dst3; - dst[params.offset_dst + dst_idx] = {{DST_TYPE}}((src[params.offset_src + src_idx])); + dst[params.offset_dst + dst_idx] = DST_TYPE((src[params.offset_src + src_idx])); } -#end(SHADER) + diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py b/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py index 
8b5cfe715e..79a3a9597a 100755 --- a/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +++ b/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py @@ -1,41 +1,8 @@ import os import re -import ast import argparse -def extract_block(text, name): - pattern = rf'#define\({name}\)\s*(.*?)#end\({name}\)' - match = re.search(pattern, text, re.DOTALL) - if not match: - raise ValueError(f"Missing block: {name}") - return match.group(1).strip() - - -def parse_decls(decls_text): - decls = {} - for name, code in re.findall(r'#decl\((.*?)\)\s*(.*?)#enddecl\(\1\)', decls_text, re.DOTALL): - decls[name.strip()] = code.strip() - return decls - - -def replace_repl_placeholders(variant, template_map): - for repl, code in variant["REPLS"].items(): - for key, val in template_map.items(): - # Match "key" and avoid matching subsequences using by using \b - code = re.sub(rf'\b{re.escape(str(key))}\b', str(val), code) - variant["REPLS"][repl] = code - return variant - - -def replace_placeholders(shader_text, replacements): - for key, val in replacements.items(): - # Match {{KEY}} literally, where KEY is escaped - pattern = r'{{\s*' + re.escape(key) + r'\s*}}' - shader_text = re.sub(pattern, str(val), shader_text) - return shader_text - - def expand_includes(shader, input_dir): """ Replace #include "file" lines in the text with the contents of that file. 
@@ -98,84 +65,24 @@ def write_shader(shader_name, shader_code, output_dir, outfile, input_dir): outfile.write(f'const char* wgsl_{shader_name} = wgsl_{shader_name}_str().c_str();\n\n') -def generate_variants(fname, input_dir, output_dir, outfile): - shader_path = os.path.join(input_dir, fname) - shader_base_name = fname.split(".")[0] - - with open(shader_path, "r", encoding="utf-8") as f: - text = f.read() - - try: - variants = ast.literal_eval(extract_block(text, "VARIANTS")) - except ValueError: - write_shader(shader_base_name, text, output_dir, outfile, input_dir) - else: - try: - decls_map = parse_decls(extract_block(text, "DECLS")) - except ValueError: - decls_map = {} - try: - templates_map = ast.literal_eval(extract_block(text, "REPL_TEMPLATES")) - except ValueError: - templates_map = {} - - for fname in sorted(os.listdir(input_dir)): - if fname.endswith(".tmpl"): - tmpl_path = os.path.join(input_dir, fname) - with open(tmpl_path, "r", encoding="utf-8") as f_tmpl: - decls = f_tmpl.read() - decls_map.update(parse_decls(decls)) - - shader_template = extract_block(text, "SHADER") - for variant in variants: - if "DECLS" in variant: - decls = variant["DECLS"] - else: - decls = [] - decls_code = "" - for key in decls: - if key not in decls_map: - raise ValueError(f"DECLS key '{key}' not found.") - decls_code += decls_map[key] + "\n\n" - final_shader = re.sub(r'\bDECLS\b', decls_code, shader_template) - if "REPLS" in variant: - variant = replace_repl_placeholders(variant, templates_map) - final_shader = replace_placeholders(final_shader, variant["REPLS"]) - # second run to expand placeholders in repl_template - final_shader = replace_placeholders(final_shader, variant["REPLS"]) - final_shader = expand_includes(final_shader, input_dir) - - if "SHADER_NAME" in variant: - output_name = variant["SHADER_NAME"] - elif "SHADER_SUFFIX" in variant: - output_name = f"{shader_base_name}_" + variant["SHADER_SUFFIX"] - elif "REPLS" in variant and "SRC0_TYPE" in variant["REPLS"] 
and "SRC1_TYPE" in variant["REPLS"]: - output_name = f"{shader_base_name}_" + "_".join([variant["REPLS"]["SRC0_TYPE"], variant["REPLS"]["SRC1_TYPE"]]) - elif "REPLS" in variant and "SRC_TYPE" in variant["REPLS"] and "DST_TYPE" in variant["REPLS"]: - output_name = f"{shader_base_name}_" + "_".join([variant["REPLS"]["SRC_TYPE"], variant["REPLS"]["DST_TYPE"]]) - elif "REPLS" in variant and "TYPE" in variant["REPLS"]: - output_name = f"{shader_base_name}_" + variant["REPLS"]["TYPE"] - else: - output_name = shader_base_name - write_shader(output_name, final_shader, output_dir, outfile, input_dir) - - def main(): parser = argparse.ArgumentParser() parser.add_argument("--input_dir", required=True) parser.add_argument("--output_file", required=True) - parser.add_argument("--output_dir") args = parser.parse_args() - if args.output_dir: - os.makedirs(args.output_dir, exist_ok=True) - with open(args.output_file, "w", encoding="utf-8") as out: out.write("// Auto-generated shader embedding\n") out.write("#include \n\n") for fname in sorted(os.listdir(args.input_dir)): if fname.endswith(".wgsl"): - generate_variants(fname, args.input_dir, args.output_dir, out) + shader_path = os.path.join(args.input_dir, fname) + shader_name = fname.replace(".wgsl", "") + + with open(shader_path, "r", encoding="utf-8") as f: + shader_code = f.read() + + write_shader(shader_name, shader_code, None, out, args.input_dir) if __name__ == "__main__": diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl deleted file mode 100644 index 03fcd54868..0000000000 --- a/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +++ /dev/null @@ -1,323 +0,0 @@ -#define(VARIANTS) - -[ - { - "SHADER_NAME": "reglu_f32", - "REPLS": { - "TYPE" : "f32", - }, - "DECLS": ["NO_SPLIT", "REGLU"] - }, - { - "SHADER_NAME": "reglu_f32_split", - "REPLS": { - "TYPE" : "f32", - }, - "DECLS": ["SPLIT", "REGLU"] - }, - { - "SHADER_NAME": "reglu_f16", - "REPLS": { - "TYPE" : "f16", - 
}, - "DECLS": ["NO_SPLIT", "REGLU"] - }, - { - "SHADER_NAME": "reglu_f16_split", - "REPLS": { - "TYPE" : "f16", - }, - "DECLS": ["SPLIT", "REGLU"] - }, - { - "SHADER_NAME": "geglu_f32", - "REPLS": { - "TYPE" : "f32", - }, - "DECLS": ["NO_SPLIT", "GEGLU"] - }, - { - "SHADER_NAME": "geglu_f32_split", - "REPLS": { - "TYPE" : "f32", - }, - "DECLS": ["SPLIT", "GEGLU"] - }, - { - "SHADER_NAME": "geglu_f16", - "REPLS": { - "TYPE" : "f16", - }, - "DECLS": ["NO_SPLIT", "GEGLU"] - }, - { - "SHADER_NAME": "geglu_f16_split", - "REPLS": { - "TYPE" : "f16", - }, - "DECLS": ["SPLIT", "GEGLU"] - }, - { - "SHADER_NAME": "swiglu_f32", - "REPLS": { - "TYPE" : "f32", - }, - "DECLS": ["NO_SPLIT", "SWIGLU"] - }, - { - "SHADER_NAME": "swiglu_f32_split", - "REPLS": { - "TYPE" : "f32", - }, - "DECLS": ["SPLIT", "SWIGLU"] - }, - { - "SHADER_NAME": "swiglu_f16", - "REPLS": { - "TYPE" : "f16", - }, - "DECLS": ["NO_SPLIT", "SWIGLU"] - }, - { - "SHADER_NAME": "swiglu_f16_split", - "REPLS": { - "TYPE" : "f16", - }, - "DECLS": ["SPLIT", "SWIGLU"] - }, - { - "SHADER_NAME": "swiglu_oai_f32", - "REPLS": { - "TYPE" : "f32", - }, - "DECLS": ["NO_SPLIT", "SWIGLU_OAI"] - }, - { - "SHADER_NAME": "swiglu_oai_f32_split", - "REPLS": { - "TYPE" : "f32", - }, - "DECLS": ["SPLIT", "SWIGLU_OAI"] - }, - { - "SHADER_NAME": "geglu_erf_f32", - "REPLS": { - "TYPE" : "f32", - }, - "DECLS": ["NO_SPLIT", "GEGLU_ERF"] - }, - { - "SHADER_NAME": "geglu_erf_f32_split", - "REPLS": { - "TYPE" : "f32", - }, - "DECLS": ["SPLIT", "GEGLU_ERF"] - }, - { - "SHADER_NAME": "geglu_erf_f16", - "REPLS": { - "TYPE" : "f16", - }, - "DECLS": ["NO_SPLIT", "GEGLU_ERF"] - }, - { - "SHADER_NAME": "geglu_erf_f16_split", - "REPLS": { - "TYPE" : "f16", - }, - "DECLS": ["SPLIT", "GEGLU_ERF"] - }, - { - "SHADER_NAME": "geglu_quick_f32", - "REPLS": { - "TYPE" : "f32", - }, - "DECLS": ["NO_SPLIT", "GEGLU_QUICK"] - }, - { - "SHADER_NAME": "geglu_quick_f32_split", - "REPLS": { - "TYPE" : "f32", - }, - "DECLS": ["SPLIT", "GEGLU_QUICK"] - }, - { - 
"SHADER_NAME": "geglu_quick_f16", - "REPLS": { - "TYPE" : "f16", - }, - "DECLS": ["NO_SPLIT", "GEGLU_QUICK"] - }, - { - "SHADER_NAME": "geglu_quick_f16_split", - "REPLS": { - "TYPE" : "f16", - }, - "DECLS": ["SPLIT", "GEGLU_QUICK"] - }, -] - -#end(VARIANTS) - -#define(DECLS) - -#decl(REGLU) -fn op(a: {{TYPE}}, b: {{TYPE}}) -> {{TYPE}} { - return max(a, 0) * b; -} -#enddecl(REGLU) - -#decl(GEGLU) -const SQRT_2_OVER_PI: {{TYPE}} = 0.79788456080286535587989211986876; -const GELU_COEF_A: {{TYPE}} = 0.044715; - -fn op(a: {{TYPE}}, b: {{TYPE}}) -> {{TYPE}} { - let val = SQRT_2_OVER_PI * a * (1.0 + GELU_COEF_A * a * a); - return 0.5 * a * (2.0 - 2.0 / (exp(2 * val) + 1)) * b; -} -#enddecl(GEGLU) - -#decl(SWIGLU) -fn op(a: {{TYPE}}, b: {{TYPE}}) -> {{TYPE}} { - return a / (1.0 + exp(-a)) * b; -} -#enddecl(SWIGLU) - -#decl(SWIGLU_OAI) -fn op(a: f32, b: f32) -> f32 { - let xi = min(a, params.limit); - let gi = max(min(b, params.limit), -params.limit); - var out_glu = xi / (1.0 + exp(-xi * params.alpha)); - out_glu = out_glu * (1.0 + gi); - return out_glu; -} -#enddecl(SWIGLU_OAI) - -#decl(GEGLU_ERF) -const p_erf: {{TYPE}} = 0.3275911; -const a1_erf: {{TYPE}} = 0.254829592; -const a2_erf: {{TYPE}} = -0.284496736; -const a3_erf: {{TYPE}} = 1.421413741; -const a4_erf: {{TYPE}} = -1.453152027; -const a5_erf: {{TYPE}} = 1.061405429; -const SQRT_2_INV: {{TYPE}} = 0.7071067811865476; - -fn op(a: {{TYPE}}, b: {{TYPE}}) -> {{TYPE}} { - let a_div_sqr2 = a * SQRT_2_INV; - let sign_x = sign(a_div_sqr2); - let x = abs(a_div_sqr2); - let t = 1.0 / (1.0 + p_erf * x); - let y = 1.0 - (((((a5_erf * t + a4_erf) * t + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x)); - let erf_approx = sign_x * y; - return 0.5 * a * (1.0 + erf_approx) * b; -} -#enddecl(GEGLU_ERF) - -#decl(GEGLU_QUICK) -const GELU_QUICK_COEF: {{TYPE}} = -1.702; - -fn op(a: {{TYPE}}, b: {{TYPE}}) -> {{TYPE}} { - return a * (1.0 / (1.0 + exp(GELU_QUICK_COEF * a))) * b; -} -#enddecl(GEGLU_QUICK) - -#decl(NO_SPLIT) -@group(0) 
@binding(1) -var dst: array<{{TYPE}}>; - -@group(0) @binding(2) -var params: Params; - -fn a_value(base: u32) -> {{TYPE}} { - let offset: u32 = select(0, params.ne0, params.swapped != 0); - return src0[base + offset]; -} - -fn b_value(base: u32) -> {{TYPE}} { - let offset: u32 = select(params.ne0, 0, params.swapped != 0); - return src0[base + offset]; -} -#enddecl(NO_SPLIT) - -#decl(SPLIT) -@group(0) @binding(1) -var src1: array<{{TYPE}}>; - -@group(0) @binding(2) -var dst: array<{{TYPE}}>; - -@group(0) @binding(3) -var params: Params; - -fn a_value(base: u32) -> {{TYPE}} { - return src0[base]; -} - -fn b_value(base: u32) -> {{TYPE}} { - return src1[base]; -} -#enddecl(SPLIT) - -#end(DECLS) - -#define(SHADER) - -enable f16; - -struct Params { - offset_src0: u32, - offset_src1: u32, - offset_dst: u32, - - // Strides (in elements) - stride_src01: u32, - stride_src02: u32, - stride_src03: u32, - - stride_src11: u32, - stride_src12: u32, - stride_src13: u32, - - stride_dst1: u32, - stride_dst2: u32, - stride_dst3: u32, - - // shape of dst - ne: u32, - ne0: u32, - ne1: u32, - ne2: u32, - - swapped: u32, - alpha: f32, - limit: f32, -} - -@group(0) @binding(0) -var src0: array<{{TYPE}}>; - -DECLS - -override wg_size: u32; -@compute @workgroup_size(wg_size) -fn main(@builtin(global_invocation_id) gid: vec3) { - if (gid.x >= params.ne) { - return; - } - - var i = gid.x; - let i3 = i / (params.ne2 * params.ne1 * params.ne0); - i = i % (params.ne2 * params.ne1 * params.ne0); - let i2 = i / (params.ne1 * params.ne0); - i = i % (params.ne1 * params.ne0); - let i1 = i / params.ne0; - let i0 = i % params.ne0; - - let i_a = params.offset_src0 + i3 * params.stride_src03 + i2 * params.stride_src02 + i1 * params.stride_src01 + i0; - let i_b = params.offset_src1 + i3 * params.stride_src13 + i2 * params.stride_src12 + i1 * params.stride_src11 + i0; - let i_dst = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1 + i0; - - dst[i_dst] = 
op(a_value(i_a), b_value(i_b)); -} - -#end(SHADER) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/glu.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/glu.wgsl new file mode 100644 index 0000000000..e6d7608cec --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/glu.wgsl @@ -0,0 +1,155 @@ +enable f16; + +#ifdef TYPE_F32 +#define DataType f32 +#endif +#ifdef TYPE_F16 +#define DataType f16 +#endif + +#ifdef OP_REGLU +fn op(a: DataType, b: DataType) -> DataType { + return max(a, 0) * b; +} +#endif + +#ifdef OP_GEGLU +const SQRT_2_OVER_PI: DataType = 0.79788456080286535587989211986876; +const GELU_COEF_A: DataType = 0.044715; + +fn op(a: DataType, b: DataType) -> DataType { + let val = SQRT_2_OVER_PI * a * (1.0 + GELU_COEF_A * a * a); + return 0.5 * a * (2.0 - 2.0/ (exp(2* val) + 1)) * b; +} +#endif + +#ifdef OP_SWIGLU +fn op(a: DataType, b: DataType) -> DataType { + return a / (1.0 + exp(-a)) * b; +} +#endif +#ifdef OP_SWIGLU_OAI +fn op(a: f32, b: f32) -> f32 { + let xi = min(a, params.limit); + let gi = max(min(b, params.limit), -params.limit); + var out_glu = xi / (1.0 + exp(-xi * params.alpha)); + out_glu = out_glu * (1.0 + gi); + return out_glu; +} +#endif +#ifdef OP_GEGLU_ERF +const p_erf: DataType = 0.3275911; +const a1_erf: DataType = 0.254829592; +const a2_erf: DataType = -0.284496736; +const a3_erf: DataType = 1.421413741; +const a4_erf: DataType = -1.453152027; +const a5_erf: DataType = 1.061405429; +const SQRT_2_INV: DataType = 0.7071067811865476; + +fn op(a: DataType, b: DataType) -> DataType { + let a_div_sqr2 = a * SQRT_2_INV; + let sign_x = sign(a_div_sqr2); + let x = abs(a_div_sqr2); + let t = 1.0 / (1.0 + p_erf * x); + let y = 1.0 - (((((a5_erf * t + a4_erf) * t + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x)); + let erf_approx = sign_x * y; + return 0.5 * a * (1.0 + erf_approx) * b; +} +#endif +#ifdef OP_GEGLU_QUICK +const GELU_QUICK_COEF: DataType = -1.702; + +fn op(a: DataType, b: DataType) -> DataType { + return a * (1.0 / (1.0 + 
exp(GELU_QUICK_COEF * a))) * b; +} +#endif + +struct Params { + offset_src0: u32, + offset_src1: u32, + offset_dst: u32, + + // Strides (in elements) + stride_src01: u32, + stride_src02: u32, + stride_src03: u32, + + stride_src11: u32, + stride_src12: u32, + stride_src13: u32, + + stride_dst1: u32, + stride_dst2: u32, + stride_dst3: u32, + + // shape of dst + ne: u32, + ne0: u32, + ne1: u32, + ne2: u32, + + swapped: u32, + alpha: f32, + limit: f32, +} + +@group(0) @binding(0) +var src0: array; + +#ifdef NO_SPLIT +@group(0) @binding(1) +var dst: array; + +@group(0) @binding(2) +var params: Params; + +fn a_value(base: u32) -> DataType { + let offset: u32 = select(0, params.ne0, params.swapped != 0); + return src0[base + offset]; +} + +fn b_value(base: u32) -> DataType { + let offset: u32 = select(params.ne0, 0, params.swapped != 0); + return src0[base + offset]; +} + +#else +@group(0) @binding(1) +var src1: array; + +@group(0) @binding(2) +var dst: array; + +@group(0) @binding(3) +var params: Params; + +fn a_value(base: u32) -> DataType { + return src0[base]; +} + +fn b_value(base: u32) -> DataType { + return src1[base]; +} + +#endif + +@compute @workgroup_size(WG_SIZE) +fn main(@builtin(global_invocation_id) gid: vec3) { + if (gid.x >= params.ne) { + return; + } + + var i = gid.x; + let i3 = i / (params.ne2 * params.ne1 * params.ne0); + i = i % (params.ne2 * params.ne1 * params.ne0); + let i2 = i / (params.ne1 * params.ne0); + i = i % (params.ne1 * params.ne0); + let i1 = i / params.ne0; + let i0 = i % params.ne0; + + let i_a = params.offset_src0 + i3 * params.stride_src03 + i2 * params.stride_src02 + i1 * params.stride_src01 + i0; + let i_b = params.offset_src1 + i3 * params.stride_src13 + i2 * params.stride_src12 + i1 * params.stride_src11 + i0; + let i_dst = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1 + i0; + + dst[i_dst] = op(a_value(i_a), b_value(i_b)); +} diff --git 
a/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/rope.wgsl similarity index 72% rename from ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl rename to ggml/src/ggml-webgpu/wgsl-shaders/rope.wgsl index 84dc8dbff6..3b28c4ea98 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/rope.wgsl @@ -1,138 +1,12 @@ -#define(VARIANTS) - -[ - { - "REPLS": { - "TYPE" : "f32", - }, - "DECLS": ["NO_FF_BINDINGS", "NO_FF_FUNC", "ROTATE"] - }, - { - "SHADER_SUFFIX": "f32_inplace", - "REPLS": { - "TYPE" : "f32", - }, - "DECLS": ["NO_FF_BINDINGS_INPLACE", "NO_FF_FUNC", "ROTATE_INPLACE"] - }, - { - "REPLS": { - "TYPE" : "f16", - }, - "DECLS": ["NO_FF_BINDINGS", "NO_FF_FUNC", "ROTATE"] - }, - { - "SHADER_SUFFIX": "f16_inplace", - "REPLS": { - "TYPE" : "f16", - }, - "DECLS": ["NO_FF_BINDINGS_INPLACE", "NO_FF_FUNC", "ROTATE_INPLACE"] - }, - { - "SHADER_SUFFIX": "f32_ff", - "REPLS": { - "TYPE" : "f32", - }, - "DECLS": ["FF_BINDINGS", "FF_FUNC", "ROTATE"] - }, - { - "SHADER_SUFFIX": "f32_ff_inplace", - "REPLS": { - "TYPE" : "f32", - }, - "DECLS": ["FF_BINDINGS_INPLACE", "FF_FUNC", "ROTATE_INPLACE"] - }, - { - "SHADER_SUFFIX": "f16_ff", - "REPLS": { - "TYPE" : "f16", - }, - "DECLS": ["FF_BINDINGS", "FF_FUNC", "ROTATE"] - }, - { - "SHADER_SUFFIX": "f16_ff_inplace", - "REPLS": { - "TYPE" : "f16", - }, - "DECLS": ["FF_BINDINGS_INPLACE", "FF_FUNC", "ROTATE_INPLACE"] - } -] - -#end(VARIANTS) - -#define(DECLS) - -#decl(ROTATE) -fn rotate(i_dst0: u32, i_dst1: u32, out0: f32, out1: f32) { - dst[i_dst0] = {{TYPE}}(out0); - dst[i_dst1] = {{TYPE}}(out1); -} -#enddecl(ROTATE) - -#decl(ROTATE_INPLACE) -fn rotate(i_dst0: u32, i_dst1: u32, out0: f32, out1: f32) { - src0[i_dst0] = {{TYPE}}(out0); - src0[i_dst1] = {{TYPE}}(out1); -} -#enddecl(ROTATE_INPLACE) - -#decl(NO_FF_FUNC) -fn freq_factor(i: u32) -> f32 { - return 1.0f; -} -#enddecl(NO_FF_FUNC) - -#decl(FF_FUNC) -fn freq_factor(i: u32) -> f32 { - return 
src2[params.offset_src2 + i/2]; -} -#enddecl(FF_FUNC) - -#decl(NO_FF_BINDINGS) - -@group(0) @binding(2) -var dst: array<{{TYPE}}>; - -@group(0) @binding(3) -var params: Params; - -#enddecl(NO_FF_BINDINGS) - -#decl(NO_FF_BINDINGS_INPLACE) - -@group(0) @binding(2) -var params: Params; - -#enddecl(NO_FF_BINDINGS_INPLACE) - -#decl(FF_BINDINGS) - -@group(0) @binding(2) -var src2: array; - -@group(0) @binding(3) -var dst: array<{{TYPE}}>; - -@group(0) @binding(4) -var params: Params; - -#enddecl(FF_BINDINGS) - -#decl(FF_BINDINGS_INPLACE) - -@group(0) @binding(2) -var src2: array; - -@group(0) @binding(3) -var params: Params; - -#enddecl(FF_BINDINGS_INPLACE) - -#end(DECLS) - -#define(SHADER) - enable f16; +#ifdef TYPE_F32 +#define DataType f32 +#endif +#ifdef TYPE_F16 +#define DataType f16 +#endif + struct Params { offset_src0: u32, offset_src1: u32, @@ -168,12 +42,69 @@ struct Params { }; @group(0) @binding(0) -var src0: array<{{TYPE}}>; - +var src0: array; @group(0) @binding(1) var src1: array; -DECLS +#ifdef INPLACE + +#ifdef FF_FUNC + +@group(0) @binding(2) +var src2: array; + +@group(0) @binding(3) +var params: Params; + +#else + +@group(0) @binding(2) +var params: Params; + +#endif + +#else + +#ifdef FF_FUNC +@group(0) @binding(2) +var src2: array; + +@group(0) @binding(3) +var dst: array; + +@group(0) @binding(4) +var params: Params; + +#else +@group(0) @binding(2) +var dst: array; + +@group(0) @binding(3) +var params: Params; +#endif +#endif + +#ifdef FF_FUNC +fn freq_factor(i: u32) -> f32 { + return src2[params.offset_src2 + i/2]; +} + +#else +fn freq_factor(i: u32) -> f32 { + return 1.0f; +} +#endif +#ifdef INPLACE +fn rotate(i_dst0: u32, i_dst1: u32, out0: f32, out1: f32) { + src0[i_dst0] = DataType(out0); + src0[i_dst1] = DataType(out1); +} +#else +fn rotate(i_dst0: u32, i_dst1: u32, out0: f32, out1: f32) { + dst[i_dst0] = DataType(out0); + dst[i_dst1] = DataType(out1); +} +#endif fn rope_yarn_ramp(low: f32, high: f32, i: u32) -> f32 { let y = (f32(i / 2) - 
low) / max(0.001f, high - low); @@ -184,7 +115,7 @@ fn rope_yarn_ramp(low: f32, high: f32, i: u32) -> f32 { // TODO: check performance of instantiating once on the CPU and passed as buffer, since it's repeated per-row fn rope_yarn(theta_extrap: f32, i: u32) -> vec2 { var mscale = params.attn_factor; - var theta = params.freq_scale * theta_extrap; + var theta = params.freq_scale * theta_extrap; if (params.ext_factor != 0.0f) { let ramp_mix = rope_yarn_ramp(params.corr_dim0, params.corr_dim1, i) * params.ext_factor; theta = theta * (1 - ramp_mix) + theta_extrap * ramp_mix; @@ -207,14 +138,13 @@ fn pair_offset(is_neox: bool, is_mrope: bool, is_vision: bool) -> u32 { } else if (is_neox || is_mrope) { return params.n_dims / 2; } else { - return 1; + return 1; } } -override wg_size: u32; -@compute @workgroup_size(wg_size) +@compute @workgroup_size(WG_SIZE) fn main(@builtin(global_invocation_id) gid: vec3) { - // two elements per thread + // two elements per n_threads if (gid.x >= params.n_threads) { return; } @@ -235,7 +165,7 @@ fn main(@builtin(global_invocation_id) gid: vec3) { let i_src_row = params.offset_src0 + i3 * params.stride_src03 + i2 * params.stride_src02 + i1 * params.stride_src01; let i_dst_row = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1; - if (i0 >= params.n_dims && !is_vision) { + if (i0 >= params.n_dims && !is_vision) { let i_src = i_src_row + i0; let i_dst = i_dst_row + i0; rotate(i_dst, i_dst + 1, f32(src0[i_src]), f32(src0[i_src + 1])); @@ -290,6 +220,5 @@ fn main(@builtin(global_invocation_id) gid: vec3) { let x0 = f32(src0[i_src]); let x1 = f32(src0[i_src + pair_offset(is_neox, is_mrope, is_vision)]); rotate(i_dst, i_dst + pair_offset(is_neox, is_mrope, is_vision), x0 * thetas.x - x1 * thetas.y, x0 * thetas.y + x1 * thetas.x); -} -#end(SHADER) +} diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.wgsl similarity index 59% rename from 
ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl rename to ggml/src/ggml-webgpu/wgsl-shaders/soft_max.wgsl index c74dc4cc92..10edf13604 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.wgsl @@ -1,215 +1,12 @@ -#define(VARIANTS) -[ - { - "SHADER_NAME": "soft_max_f32", - "DECLS": ["BASE_BINDINGS", "NOT_INPLACE", "NO_MASK", "NO_SINK"] - }, - { - "SHADER_NAME": "soft_max_f32_inplace", - "DECLS": ["BASE_BINDINGS_INPLACE", "INPLACE", "NO_MASK", "NO_SINK"] - }, - { - "SHADER_NAME": "soft_max_f32_sink", - "DECLS": ["SINK_BINDINGS", "NOT_INPLACE", "NO_MASK", "SINK"] - }, - { - "SHADER_NAME": "soft_max_f32_sink_inplace", - "DECLS": ["SINK_BINDINGS_INPLACE", "INPLACE", "NO_MASK", "SINK"] - }, - { - "SHADER_NAME": "soft_max_f32_mask_f32", - "REPLS": { - "MASK_TYPE" : "f32", - }, - "DECLS": ["MASK_BINDINGS", "NOT_INPLACE", "MASK", "NO_SINK"] - }, - { - "SHADER_NAME": "soft_max_f32_mask_f32_inplace", - "REPLS": { - "MASK_TYPE" : "f32", - }, - "DECLS": ["MASK_BINDINGS_INPLACE", "INPLACE", "MASK", "NO_SINK"] - }, - { - "SHADER_NAME": "soft_max_f32_mask_f16", - "REPLS": { - "MASK_TYPE" : "f16", - }, - "DECLS": ["MASK_BINDINGS", "NOT_INPLACE", "MASK", "NO_SINK"] - }, - { - "SHADER_NAME": "soft_max_f32_mask_f16_inplace", - "REPLS": { - "MASK_TYPE" : "f16", - }, - "DECLS": ["MASK_BINDINGS_INPLACE", "INPLACE", "MASK", "NO_SINK"] - }, - { - "SHADER_NAME": "soft_max_f32_mask_f32_sink", - "REPLS": { - "MASK_TYPE" : "f32", - }, - "DECLS": ["MASK_SINK_BINDINGS", "NOT_INPLACE", "MASK", "SINK"] - }, - { - "SHADER_NAME": "soft_max_f32_mask_f32_sink_inplace", - "REPLS": { - "MASK_TYPE" : "f32", - }, - "DECLS": ["MASK_SINK_BINDINGS_INPLACE", "INPLACE", "MASK", "SINK"] - }, - { - "SHADER_NAME": "soft_max_f32_mask_f16_sink", - "REPLS": { - "MASK_TYPE" : "f16", - }, - "DECLS": ["MASK_SINK_BINDINGS", "NOT_INPLACE", "MASK", "SINK"] - }, - { - "SHADER_NAME": "soft_max_f32_mask_f16_sink_inplace", - "REPLS": { - "MASK_TYPE" : "f16", - 
}, - "DECLS": ["MASK_SINK_BINDINGS_INPLACE", "INPLACE", "MASK", "SINK"] - } -] -#end(VARIANTS) - -#define(DECLS) - -#decl(BASE_BINDINGS) -@group(0) @binding(1) -var dst: array; - -@group(0) @binding(2) -var params: Params; -#enddecl(BASE_BINDINGS) - -#decl(BASE_BINDINGS_INPLACE) -@group(0) @binding(1) -var params: Params; -#enddecl(BASE_BINDINGS_INPLACE) - -#decl(SINK_BINDINGS) -@group(0) @binding(1) -var sinks: array; - -@group(0) @binding(2) -var dst: array; - -@group(0) @binding(3) -var params: Params; -#enddecl(SINK_BINDINGS) - -#decl(SINK_BINDINGS_INPLACE) -@group(0) @binding(1) -var sinks: array; - -@group(0) @binding(2) -var params: Params; -#enddecl(SINK_BINDINGS_INPLACE) - -#decl(MASK_BINDINGS) -@group(0) @binding(1) -var mask: array<{{MASK_TYPE}}>; - -@group(0) @binding(2) -var dst: array; - -@group(0) @binding(3) -var params: Params; -#enddecl(MASK_BINDINGS) - -#decl(MASK_BINDINGS_INPLACE) -@group(0) @binding(1) -var mask: array<{{MASK_TYPE}}>; - -@group(0) @binding(2) -var params: Params; -#enddecl(MASK_BINDINGS_INPLACE) - -#decl(MASK_SINK_BINDINGS) -@group(0) @binding(1) -var mask: array<{{MASK_TYPE}}>; - -@group(0) @binding(2) -var sinks: array; - -@group(0) @binding(3) -var dst: array; - -@group(0) @binding(4) -var params: Params; -#enddecl(MASK_SINK_BINDINGS) - -#decl(MASK_SINK_BINDINGS_INPLACE) -@group(0) @binding(1) -var mask: array<{{MASK_TYPE}}>; - -@group(0) @binding(2) -var sinks: array; - -@group(0) @binding(3) -var params: Params; -#enddecl(MASK_SINK_BINDINGS_INPLACE) - -#decl(NOT_INPLACE) -fn inter_value(i: u32) -> f32 { - return dst[i]; -} - -fn update(i: u32, val: f32) { - dst[i] = val; -} -#enddecl(NOT_INPLACE) - -#decl(INPLACE) -fn inter_value(i: u32) -> f32 { - return src[i]; -} - -fn update(i: u32, val: f32) { - src[i] = val; -} -#enddecl(INPLACE) - -#decl(NO_MASK) -fn mask_val(i: u32) -> f32 { - return 0.0; -} -#enddecl(NO_MASK) - -#decl(MASK) -fn mask_val(i: u32) -> f32 { - return f32(mask[i]); -} -#enddecl(MASK) - -#decl(NO_SINK) 
-fn lower_max_bound(i2: u32) -> f32 { - return -1e30; -} - -fn add_sinks(val: f32, i2: u32, max_val: f32) -> f32 { - return val; -} -#enddecl(NO_SINK) - -#decl(SINK) -fn lower_max_bound(i2: u32) -> f32 { - return sinks[params.offset_sinks + i2]; -} - -fn add_sinks(val: f32, i2: u32, max_val: f32) -> f32 { - return val + exp(sinks[params.offset_sinks + i2] - max_val); -} -#enddecl(SINK) - -#end(DECLS) - -#define(SHADER) enable f16; +#ifdef MASK_F32 +#define MaskType f32 +#endif +#ifdef MASK_F16 +#define MaskType f16 +#endif + struct Params { offset_src0: u32, offset_src1: u32, @@ -249,14 +46,117 @@ struct Params { @group(0) @binding(0) var src: array; -DECLS +#ifdef HAS_MASK +#ifdef HAS_SINK +@group(0) @binding(1) +var mask: array; +@group(0) @binding(2) +var sinks: array; + +#ifdef INPLACE +@group(0) @binding(3) +var params: Params; + +#else +@group(0) @binding(3) +var dst: array; +@group(0) @binding(4) +var params: Params; +#endif + +#else +@group(0) @binding(1) +var mask: array; + +#ifdef INPLACE +@group(0) @binding(2) +var params: Params; + +#else +@group(0) @binding(2) +var dst: array; +@group(0) @binding(3) +var params: Params; +#endif +#endif + +#else +#ifdef HAS_SINK +@group(0) @binding(1) +var sinks: array; + +#ifdef INPLACE +@group(0) @binding(2) +var params: Params; + +#else +@group(0) @binding(2) +var dst: array; +@group(0) @binding(3) +var params: Params; +#endif + +#else +#ifdef INPLACE +@group(0) @binding(1) +var params: Params; +#else +@group(0) @binding(1) +var dst: array; +@group(0) @binding(2) +var params: Params; +#endif +#endif +#endif + +#ifdef INPLACE +fn inter_value(i: u32) -> f32 { + return src[i]; +} +fn update(i: u32, val: f32) { + src[i] = val; +} + +#else +fn inter_value(i: u32) -> f32 { + return dst[i]; +} +fn update(i: u32, val: f32) { + dst[i] = val; +} +#endif + +#ifdef HAS_MASK +fn mask_val(i: u32) -> f32 { + return f32(mask[i]); +} + +#else +fn mask_val(i: u32) -> f32 { + return 0.0; +} +#endif + +#ifdef HAS_SINK +fn 
lower_max_bound(i2: u32) -> f32 { + return sinks[params.offset_sinks + i2]; +} +fn add_sinks(val: f32, i2: u32, max_val: f32) -> f32 { + return val + exp(sinks[params.offset_sinks + i2] - max_val); +} +#else +fn lower_max_bound(i2: u32) -> f32 { + return -1e30; +} +fn add_sinks(val: f32, i2: u32, max_val: f32) -> f32 { + return val; +} +#endif const CACHE_SIZE: u32 = 16; +var scratch: array; -override wg_size: u32; -var scratch: array; - -@compute @workgroup_size(wg_size) +@compute @workgroup_size(WG_SIZE) fn main(@builtin(workgroup_id) wid: vec3, @builtin(local_invocation_id) lid: vec3) { @@ -268,7 +168,7 @@ fn main(@builtin(workgroup_id) wid: vec3, let i_src0_row = params.offset_src0 + i3 * params.stride_src03 + i2 * params.stride_src02 + i1 * params.stride_src01; let i_src1_row = params.offset_src1 + (i3 % params.ne13) * params.stride_src13 + (i2 % params.ne12) * params.stride_src12 + i1 * params.stride_src11; let i_dst_row = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1; - let elems = (params.ne0 + wg_size - 1) / wg_size; + let elems = (params.ne0 + WG_SIZE - 1) / WG_SIZE; let head = f32(i2); let slope = select(1, select(pow(params.m1, 2 * (head - params.n_head_log2) + 1), pow(params.m0, head + 1), head < params.n_head_log2), params.max_bias > 0); @@ -286,12 +186,12 @@ fn main(@builtin(workgroup_id) wid: vec3, if (col < CACHE_SIZE) { cache[col] = val; } - col += wg_size; + col += WG_SIZE; } scratch[lid.x] = max_val; workgroupBarrier(); - var offset = wg_size / 2; + var offset: u32 = WG_SIZE / 2; while (offset > 0) { if (lid.x < offset) { scratch[lid.x] = max(scratch[lid.x], scratch[lid.x + offset]); @@ -317,12 +217,12 @@ fn main(@builtin(workgroup_id) wid: vec3, } else { update(i_dst_row + col, ex); } - col += wg_size; + col += WG_SIZE; } scratch[lid.x] = sum; workgroupBarrier(); - offset = wg_size / 2; + offset = WG_SIZE / 2; while (offset > 0) { if (lid.x < offset) { scratch[lid.x] += scratch[lid.x + offset]; 
@@ -339,7 +239,7 @@ fn main(@builtin(workgroup_id) wid: vec3, break; } update(i_dst_row + col, select(inter_value(i_dst_row + col), cache[col], col < CACHE_SIZE) * sum_recip); - col += wg_size; + col += WG_SIZE; } } -#end(SHADER) +