ggml webgpu: support for backend sampling (#18880)
* ggml webgpu: add SOFTPLUS unary operator

  Implements SOFTPLUS (log(1 + exp(x))) with f16/f32 support. Uses f32 precision for intermediate calculations to prevent f16 overflow.

  * Add shader implementation and 4 variants (f32/f16, inplace/non-inplace)
  * Register pipelines and device support
  * Follow Vulkan backend numerical stability pattern

* ggml webgpu: add EXPM1 unary operator

  Implements EXPM1 (exp(x) - 1) with f16/f32 support.

  * Add shader implementation and 4 variants (f32/f16, inplace/non-inplace)
  * Register pipelines and device support

* ggml webgpu: add FLOOR unary operator

  Implements FLOOR (rounds down to the nearest integer) with f16/f32 support.

  * Add shader implementation and 4 variants (f32/f16, inplace/non-inplace)
  * Register pipelines and device support

* ggml webgpu: add CEIL unary operator

  Implements CEIL (rounds up to the nearest integer) with f16/f32 support.

  * Add shader implementation and 4 variants (f32/f16, inplace/non-inplace)
  * Register pipelines and device support

* ggml webgpu: add ROUND unary operator

  Implements ROUND (rounds to the nearest integer) with f16/f32 support.

  * Add shader implementation and 4 variants (f32/f16, inplace/non-inplace)
  * Register pipelines and device support

* ggml webgpu: add TRUNC unary operator

  Implements TRUNC (truncates towards zero) with f16/f32 support.

  * Add shader implementation and 4 variants (f32/f16, inplace/non-inplace)
  * Register pipelines and device support

* docs: update WebGPU support for unary operators (FLOOR, CEIL, ROUND, TRUNC, EXPM1, SOFTPLUS)
* Update webgpu get_memory
* Add ARGMAX
* Add ARGMAX, CUMSUM, SUM, SUM_ROWS
* Add necessary CPY/GET_ROWS operators
* Support ARGSORT using a multi-pass strategy
* Update SET_ROWS for i32 indices, move to pre-wgsl
* Port unary operators to pre-wgsl and support FILL
* Implement PAD
* Add support for TOP_K
* Clean up, scope the pipeline init mutex
* Fix newline
* Add support for LOG
* Update LOG for better precision, and the ops doc

Co-authored-by: Abhijit Ramesh <abhijitramesh2k@gmail.com>
This commit is contained in:
parent 388ce82241
commit a89002f07b

docs/ops.md (33)
@@ -20,10 +20,10 @@ Legend:
| ADD1 | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| ADD_ID | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ |
| ARANGE | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| ARGMAX | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| ARGSORT | ❌ | ✅ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ❌ | ❌ | ❌ |
| ARGMAX | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ |
| ARGSORT | ❌ | ✅ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ | ❌ |
| CEIL | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
| CLAMP | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | 🟡 | ❌ | ❌ | ❌ |
| CLAMP | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
| CONCAT | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | ✅ | ❌ | ❌ | ❌ |
| CONT | ❌ | 🟡 | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | 🟡 | ❌ | ❌ |
| CONV_2D | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ |

@@ -36,17 +36,17 @@ Legend:
| CPY | ❌ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | ❌ |
| CROSS_ENTROPY_LOSS | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| CROSS_ENTROPY_LOSS_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| CUMSUM | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
| CUMSUM | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ |
| DIAG | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| DIAG_MASK_INF | ❌ | ✅ | ✅ | ✅ | ❌ | 🟡 | ✅ | ✅ | ❌ | ❌ | ❌ |
| DIV | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
| DUP | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | ✅ | ❌ | ❌ | ❌ |
| ELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | ❌ | ✅ | ❌ | ❌ |
| EXP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| EXPM1 | ❌ | ❌ | ✅ | 🟡 | 🟡 | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| FILL | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
| FLASH_ATTN_EXT | ❌ | 🟡 | ✅ | 🟡 | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ | ❌ |
| FLOOR | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ | ❌ |
| EXPM1 | ❌ | ❌ | ✅ | 🟡 | 🟡 | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
| FILL | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ |
| FLASH_ATTN_EXT | ❌ | 🟡 | ✅ | 🟡 | 🟡 | 🟡 | ❌ | 🟡 | 🟡 | ❌ | ❌ |
| FLOOR | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
| GATED_LINEAR_ATTN | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ |
| GEGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| GEGLU_ERF | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |

@@ -63,7 +63,7 @@ Legend:
| IM2COL_3D | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
| L2_NORM | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| LEAKY_RELU | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | 🟡 | ❌ | ❌ | ❌ |
| LOG | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| LOG | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ |
| MEAN | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ |
| MUL | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
| MUL_MAT | 🟡 | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 |

@@ -73,8 +73,9 @@ Legend:
| OPT_STEP_ADAMW | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
| OPT_STEP_SGD | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
| OUT_PROD | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ❌ | ❌ | ❌ | 🟡 |
| PAD | ❌ | 🟡 | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | ✅ | ❌ | ❌ | ❌ |
| PAD | ❌ | 🟡 | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | ✅ | ✅ | ❌ | ❌ |
| PAD_REFLECT_1D | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ |
| POOL_1D | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| POOL_2D | ❌ | 🟡 | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| REGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| RELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |

@@ -85,7 +86,7 @@ Legend:
| ROLL | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| ROPE | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
| ROPE_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
| ROUND | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ | ❌ |
| ROUND | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
| RWKV_WKV6 | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| RWKV_WKV7 | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| SCALE | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |

@@ -96,7 +97,7 @@ Legend:
| SILU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
| SILU_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
| SIN | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | 🟡 | ❌ | ❌ | ❌ |
| SOFTPLUS | ❌ | ❌ | ✅ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ❌ | ❌ | ❌ |
| SOFTPLUS | ❌ | ❌ | ✅ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ✅ | ❌ | ❌ |
| SOFT_MAX | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
| SOFT_MAX_BACK | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ✅ | ❌ | ❌ | ❌ |
| SOLVE_TRI | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | ❌ | 🟡 | ❌ | ❌ | ❌ |

@@ -106,14 +107,14 @@ Legend:
| SSM_SCAN | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | 🟡 | ❌ | ❌ | ❌ |
| STEP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| SUB | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
| SUM | ❌ | 🟡 | ✅ | 🟡 | 🟡 | ❌ | 🟡 | 🟡 | ❌ | ❌ | ❌ |
| SUM_ROWS | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | 🟡 | ✅ | ❌ | ❌ | ❌ |
| SUM | ❌ | 🟡 | ✅ | 🟡 | 🟡 | ❌ | 🟡 | 🟡 | 🟡 | ❌ | ❌ |
| SUM_ROWS | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ | ❌ |
| SWIGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| SWIGLU_OAI | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| TANH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| TIMESTEP_EMBEDDING | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ |
| TOP_K | ❌ | ❌ | ✅ | ❌ | ✅ | ❌ | ❌ | 🟡 | ❌ | ❌ | ❌ |
| TOP_K | ❌ | ❌ | ✅ | ❌ | ✅ | ❌ | ❌ | 🟡 | ✅ | ❌ | ❌ |
| TRI | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
| TRUNC | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ | ❌ |
| TRUNC | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
| UPSCALE | ❌ | 🟡 | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | ❌ | ❌ |
| XIELU | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
docs/ops/WebGPU.csv (15267)
File diff suppressed because it is too large
@@ -9,12 +9,28 @@

#define GGML_WEBGPU_F16_SIZE_BYTES 2
#define GGML_WEBGPU_F32_SIZE_BYTES 4
#define GGML_WEBGPU_I32_SIZE_BYTES 4
#define GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES 8u
#define GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE 128u
// Matches GGML_PAD(..., 256) in src/llama-context.cpp for KV cache sizing.
#define GGML_WEBGPU_KV_SEQ_PAD 256u

struct ggml_webgpu_flash_attn_shader_lib_context {
#define GGML_WEBGPU_ARGSORT_MERGE_MAX_WG_SIZE 512u

struct ggml_webgpu_processed_shader {
    std::string wgsl;
    std::string variant;
    void *      decisions;
};

// Same hash combine function as in boost
template <typename T> inline void ggml_webgpu_hash_combine(size_t & seed, const T & value) {
    seed ^= std::hash<T>{}(value) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
}

/** FlashAttention */

struct ggml_webgpu_flash_attn_pipeline_key {
    ggml_type kv_type;
    uint32_t  head_dim_qk;
    uint32_t  head_dim_v;
@@ -22,11 +38,35 @@ struct ggml_webgpu_flash_attn_shader_lib_context {
    bool     has_mask;
    bool     has_sinks;
    bool     uses_logit_softcap;
    uint32_t sg_mat_m;
    uint32_t sg_mat_n;
    uint32_t sg_mat_k;
    size_t   wg_mem_limit_bytes;
    uint32_t max_subgroup_size;

    bool operator==(const ggml_webgpu_flash_attn_pipeline_key & other) const {
        return kv_type == other.kv_type && head_dim_qk == other.head_dim_qk && head_dim_v == other.head_dim_v &&
               kv_direct == other.kv_direct && has_mask == other.has_mask && has_sinks == other.has_sinks &&
               uses_logit_softcap == other.uses_logit_softcap;
    }
};

struct ggml_webgpu_flash_attn_pipeline_key_hash {
    size_t operator()(const ggml_webgpu_flash_attn_pipeline_key & key) const {
        size_t seed = 0;
        ggml_webgpu_hash_combine(seed, key.kv_type);
        ggml_webgpu_hash_combine(seed, key.head_dim_qk);
        ggml_webgpu_hash_combine(seed, key.head_dim_v);
        ggml_webgpu_hash_combine(seed, key.kv_direct);
        ggml_webgpu_hash_combine(seed, key.has_mask);
        ggml_webgpu_hash_combine(seed, key.has_sinks);
        ggml_webgpu_hash_combine(seed, key.uses_logit_softcap);
        return seed;
    }
};

struct ggml_webgpu_flash_attn_shader_lib_context {
    ggml_webgpu_flash_attn_pipeline_key key;
    uint32_t sg_mat_m;
    uint32_t sg_mat_n;
    uint32_t sg_mat_k;
    size_t   wg_mem_limit_bytes;
    uint32_t max_subgroup_size;
};

struct ggml_webgpu_flash_attn_shader_decisions {
@@ -35,12 +75,6 @@ struct ggml_webgpu_flash_attn_shader_decisions {
    uint32_t wg_size = 0;
};

struct ggml_webgpu_processed_shader {
    std::string wgsl;
    std::string variant;
    ggml_webgpu_flash_attn_shader_decisions decisions;
};

// This is exposed because it's necessary in supports_op
inline size_t ggml_webgpu_flash_attn_wg_mem_bytes(uint32_t q_tile,
                                                  uint32_t kv_tile,
@@ -66,15 +100,16 @@ inline size_t ggml_webgpu_flash_attn_wg_mem_bytes(uint32_t q_tile,
}

static uint32_t ggml_webgpu_flash_attn_max_kv_tile(const ggml_webgpu_flash_attn_shader_lib_context & context) {
    const size_t limit_bytes  = context.wg_mem_limit_bytes;
    const size_t q_tile       = context.sg_mat_m;
    const size_t base_q_bytes = (context.head_dim_qk + context.head_dim_v) * q_tile * GGML_WEBGPU_F16_SIZE_BYTES +
                                2 * q_tile * GGML_WEBGPU_F32_SIZE_BYTES;
    const size_t limit_bytes  = context.wg_mem_limit_bytes;
    const size_t q_tile       = context.sg_mat_m;
    const size_t base_q_bytes =
        (context.key.head_dim_qk + context.key.head_dim_v) * q_tile * GGML_WEBGPU_F16_SIZE_BYTES +
        2 * q_tile * GGML_WEBGPU_F32_SIZE_BYTES;
    size_t bytes_per_kv = 0;
    if (!context.kv_direct) {
        bytes_per_kv += std::max(context.head_dim_qk, context.head_dim_v);
    if (!context.key.kv_direct) {
        bytes_per_kv += std::max(context.key.head_dim_qk, context.key.head_dim_v);
    }
    if (context.has_mask) {
    if (context.key.has_mask) {
        bytes_per_kv += q_tile;
    }
    bytes_per_kv += q_tile;
@@ -90,7 +125,7 @@ inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_flash_attn_shader(
    std::vector<std::string> defines;
    std::string variant = "flash_attn";

    switch (context.kv_type) {
    switch (context.key.kv_type) {
        case GGML_TYPE_F32:
            defines.push_back("KV_F32");
            break;
@@ -106,32 +141,31 @@ inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_flash_attn_shader(
        default:
            GGML_ABORT("Unsupported KV type for flash attention shader");
    }
    variant += std::string("_") + ggml_type_name(context.kv_type);
    variant += std::string("_") + ggml_type_name(context.key.kv_type);

    if (context.has_mask) {
    if (context.key.has_mask) {
        defines.push_back("MASK");
        variant += "_mask";
    }
    if (context.has_sinks) {
    if (context.key.has_sinks) {
        defines.push_back("SINKS");
        variant += "_sinks";
    }
    if (context.uses_logit_softcap) {
    if (context.key.uses_logit_softcap) {
        defines.push_back("LOGIT_SOFTCAP");
        variant += "_lgsc";
    }

    if (context.kv_direct) {
    if (context.key.kv_direct) {
        defines.push_back("KV_DIRECT");
        variant += "_kvdirect";
    }

    defines.push_back(std::string("HEAD_DIM_QK=") + std::to_string(context.head_dim_qk));
    variant += std::string("_hsqk") + std::to_string(context.head_dim_qk);

    defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(context.head_dim_v));
    variant += std::string("_hsv") + std::to_string(context.head_dim_v);
    defines.push_back(std::string("HEAD_DIM_QK=") + std::to_string(context.key.head_dim_qk));
    variant += std::string("_hsqk") + std::to_string(context.key.head_dim_qk);

    defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(context.key.head_dim_v));
    variant += std::string("_hsv") + std::to_string(context.key.head_dim_v);
    // For now these are not part of the variant name
    defines.push_back(std::string("SG_MAT_M=") + std::to_string(context.sg_mat_m));
    defines.push_back(std::string("SG_MAT_N=") + std::to_string(context.sg_mat_n));
@@ -141,7 +175,7 @@ inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_flash_attn_shader(
    uint32_t q_tile  = context.sg_mat_m;
    uint32_t kv_tile = std::min(ggml_webgpu_flash_attn_max_kv_tile(context),
                                context.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES);
    if (context.kv_direct) {
    if (context.key.kv_direct) {
        GGML_ASSERT(kv_tile <= GGML_WEBGPU_KV_SEQ_PAD);
        // Avoids having to use bounds-checks and decreasing performance for direct KV loads
        while (GGML_WEBGPU_KV_SEQ_PAD % kv_tile != 0) {
@@ -158,11 +192,276 @@ inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_flash_attn_shader(
    defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));

    ggml_webgpu_processed_shader result;
    result.wgsl              = preprocessor.preprocess(shader_src, defines);
    result.variant           = variant;
    result.decisions.q_tile  = q_tile;
    result.decisions.kv_tile = kv_tile;
    result.decisions.wg_size = wg_size;
    result.wgsl    = preprocessor.preprocess(shader_src, defines);
    result.variant = variant;
    ggml_webgpu_flash_attn_shader_decisions * decisions = new ggml_webgpu_flash_attn_shader_decisions();
    decisions->q_tile  = q_tile;
    decisions->kv_tile = kv_tile;
    decisions->wg_size = wg_size;
    result.decisions = decisions;
    return result;
}

/** Generic **/

struct ggml_webgpu_generic_shader_lib_context {
    int      vec4;
    uint32_t max_wg_size;
};

struct ggml_webgpu_generic_shader_decisions {
    uint32_t wg_size;
};

inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_generic_shader(
    pre_wgsl::Preprocessor & preprocessor,
    const char * shader_src,
    const ggml_webgpu_generic_shader_lib_context & context,
    const std::string & base_variant) {
    std::vector<std::string> defines;
    std::string variant = base_variant;

    if (context.vec4) {
        defines.push_back("VEC4");
        variant += "_vec";
    }

    defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));

    ggml_webgpu_processed_shader result;
    result.wgsl    = preprocessor.preprocess(shader_src, defines);
    result.variant = variant;
    return result;
}

/** Pad **/

struct ggml_webgpu_pad_pipeline_key {
    bool circular;

    bool operator==(const ggml_webgpu_pad_pipeline_key & other) const { return circular == other.circular; }
};

struct ggml_webgpu_pad_pipeline_key_hash {
    size_t operator()(const ggml_webgpu_pad_pipeline_key & key) const {
        size_t seed = 0;
        ggml_webgpu_hash_combine(seed, key.circular);
        return seed;
    }
};

struct ggml_webgpu_pad_shader_lib_context {
    ggml_webgpu_pad_pipeline_key key;
    uint32_t max_wg_size;
};

inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_pad_shader(
    pre_wgsl::Preprocessor & preprocessor,
    const char * shader_src,
    const ggml_webgpu_pad_shader_lib_context & context) {
    std::vector<std::string> defines;
    std::string variant = "pad";

    if (context.key.circular) {
        defines.push_back("CIRCULAR");
        variant += "_circular";
    }

    defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));

    ggml_webgpu_processed_shader result;
    result.wgsl    = preprocessor.preprocess(shader_src, defines);
    result.variant = variant;
    ggml_webgpu_generic_shader_decisions * decisions = new ggml_webgpu_generic_shader_decisions();
    decisions->wg_size = context.max_wg_size;
    result.decisions = decisions;
    return result;
}

/** Argsort **/

struct ggml_webgpu_argsort_shader_lib_context {
    uint32_t max_wg_size;
    size_t   wg_mem_limit_bytes;
    int32_t  order;
};

struct ggml_webgpu_argsort_shader_decisions {
    uint32_t wg_size = 0;
};

inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_argsort_shader(
    pre_wgsl::Preprocessor & preprocessor,
    const char * shader_src,
    const ggml_webgpu_argsort_shader_lib_context & context) {
    std::vector<std::string> defines;
    std::string variant = "argsort";
    defines.push_back(std::string("ORDER=") + std::to_string(context.order));
    variant += std::string("_order") + std::to_string(context.order);
    uint32_t wg_size = 1;
    while (wg_size * 2 <= context.max_wg_size &&
           wg_size * GGML_WEBGPU_I32_SIZE_BYTES <= context.wg_mem_limit_bytes / 2) {
        wg_size *= 2;
    }
    defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
    ggml_webgpu_processed_shader result;
    result.wgsl    = preprocessor.preprocess(shader_src, defines);
    result.variant = variant;
    ggml_webgpu_argsort_shader_decisions * decisions = new ggml_webgpu_argsort_shader_decisions();
    decisions->wg_size = wg_size;
    result.decisions = decisions;
    return result;
}

inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_argsort_merge_shader(
    pre_wgsl::Preprocessor & preprocessor,
    const char * shader_src,
    const ggml_webgpu_argsort_shader_lib_context & context) {
    std::vector<std::string> defines;
    std::string variant = "argsort_merge";
    defines.push_back(std::string("ORDER=") + std::to_string(context.order));
    variant += std::string("_order") + std::to_string(context.order);
    uint32_t wg_size = std::min(GGML_WEBGPU_ARGSORT_MERGE_MAX_WG_SIZE, context.max_wg_size);
    defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
    ggml_webgpu_processed_shader result;
    result.wgsl    = preprocessor.preprocess(shader_src, defines);
    result.variant = variant;
    ggml_webgpu_argsort_shader_decisions * decisions = new ggml_webgpu_argsort_shader_decisions();
    decisions->wg_size = wg_size;
    result.decisions = decisions;
    return result;
}

/** Set Rows **/

struct ggml_webgpu_set_rows_pipeline_key {
    int dst_type;
    int vec4;
    int i64_idx;

    bool operator==(const ggml_webgpu_set_rows_pipeline_key & other) const {
        return dst_type == other.dst_type && vec4 == other.vec4 && i64_idx == other.i64_idx;
    }
};

struct ggml_webgpu_set_rows_pipeline_key_hash {
    size_t operator()(const ggml_webgpu_set_rows_pipeline_key & key) const {
        size_t seed = 0;
        ggml_webgpu_hash_combine(seed, key.dst_type);
        ggml_webgpu_hash_combine(seed, key.vec4);
        ggml_webgpu_hash_combine(seed, key.i64_idx);
        return seed;
    }
};

struct ggml_webgpu_set_rows_shader_lib_context {
    ggml_webgpu_set_rows_pipeline_key key;
    uint32_t max_wg_size;
};

inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_set_rows_shader(
    pre_wgsl::Preprocessor & preprocessor,
    const char * shader_src,
    const ggml_webgpu_set_rows_shader_lib_context & context) {
    std::vector<std::string> defines;
    std::string variant = "set_rows";

    switch (context.key.dst_type) {
        case GGML_TYPE_F32:
            defines.push_back("DST_F32");
            variant += "_dstf32";
            break;
        case GGML_TYPE_F16:
            defines.push_back("DST_F16");
            variant += "_dstf16";
            break;
        default:
            GGML_ABORT("Unsupported dst type for set_rows shader");
    }

    if (context.key.vec4) {
        defines.push_back("VEC4");
        variant += "_vec";
    }
    if (context.key.i64_idx) {
        defines.push_back("I64_IDX");
        variant += "_i64idx";
    }

    defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));

    ggml_webgpu_processed_shader result;
    result.wgsl    = preprocessor.preprocess(shader_src, defines);
    result.variant = variant;
    ggml_webgpu_generic_shader_decisions * decisions = new ggml_webgpu_generic_shader_decisions();
    decisions->wg_size = context.max_wg_size;
    result.decisions = decisions;
    return result;
}

struct ggml_webgpu_unary_pipeline_key {
    int  type;
    int  op;
    bool is_unary; // many unary operators fall under the GGML_OP_UNARY umbrella
    bool inplace;

    bool operator==(const ggml_webgpu_unary_pipeline_key & other) const {
        return type == other.type && op == other.op && is_unary == other.is_unary && inplace == other.inplace;
    }
};

struct ggml_webgpu_unary_pipeline_key_hash {
    size_t operator()(const ggml_webgpu_unary_pipeline_key & key) const {
        size_t seed = 0;
        ggml_webgpu_hash_combine(seed, key.type);
        ggml_webgpu_hash_combine(seed, key.op);
        ggml_webgpu_hash_combine(seed, key.is_unary);
        ggml_webgpu_hash_combine(seed, key.inplace);
        return seed;
    }
};

struct ggml_webgpu_unary_shader_lib_context {
    ggml_webgpu_unary_pipeline_key key;
    uint32_t max_wg_size;
};

inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_unary_shader(
    pre_wgsl::Preprocessor & preprocessor,
    const char * shader_src,
    const ggml_webgpu_unary_shader_lib_context & context) {
    std::vector<std::string> defines;
    std::string variant = context.key.is_unary ? ggml_unary_op_name((ggml_unary_op) context.key.op) :
                                                 ggml_op_name((ggml_op) context.key.op);
    // Operation-specific behavior
    defines.push_back(variant);

    switch (context.key.type) {
        case GGML_TYPE_F32:
            defines.push_back("TYPE_F32");
            variant += "_f32";
            break;
        case GGML_TYPE_F16:
            defines.push_back("TYPE_F16");
            variant += "_f16";
            break;
        default:
            GGML_ABORT("Unsupported type for unary shader");
    }

    if (context.key.inplace) {
        defines.push_back("INPLACE");
        variant += "_inplace";
    }

    defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));

    ggml_webgpu_processed_shader result;
    result.wgsl    = preprocessor.preprocess(shader_src, defines);
    result.variant = variant;
    ggml_webgpu_generic_shader_decisions * decisions = new ggml_webgpu_generic_shader_decisions();
    decisions->wg_size = context.max_wg_size;
    result.decisions = decisions;
    return result;
}
File diff suppressed because it is too large
@@ -0,0 +1,72 @@
@group(0) @binding(0)
#ifdef VEC4
var<storage, read_write> src: array<vec4<f32>>;
#define VEC_SIZE 4
#else
var<storage, read_write> src: array<f32>;
#define VEC_SIZE 1
#endif

@group(0) @binding(1)
var<storage, read_write> dst: array<i32>;

struct Params {
    offset_src: u32, // in elements
    offset_dst: u32, // in elements
    ne0: u32,
};

@group(0) @binding(2)
var<uniform> params: Params;

const FLOAT_MIN: f32 = -1.0e9;

struct Pair {
    value: f32,
    index: i32
};

var<workgroup> shared_max: array<Pair, WG_SIZE>;

@compute @workgroup_size(WG_SIZE)
fn main(@builtin(workgroup_id) wid: vec3<u32>,
        @builtin(local_invocation_id) lid: vec3<u32>) {
    let row_idx = params.offset_src + wid.x * params.ne0;
    var local_pair = Pair(FLOAT_MIN, -1);
#ifdef VEC4
    for (var col = lid.x; col < params.ne0/VEC_SIZE; col += WG_SIZE) {
        let vec_val = src[row_idx / VEC_SIZE + col];
        for (var v = 0u; v < VEC_SIZE; v++) {
            let val = vec_val[v];
            if (val >= local_pair.value) {
                local_pair = Pair(val, i32(col * VEC_SIZE + v));
            }
        }
    }
#else
    for (var col = lid.x; col < params.ne0; col += WG_SIZE) {
        if (src[row_idx + col] >= local_pair.value) {
            local_pair = Pair(src[row_idx + col], i32(col));
        }
    }
#endif
    shared_max[lid.x] = local_pair;
    workgroupBarrier();
    var offset: u32 = WG_SIZE >> 1;
    while (offset > 0) {
        if (lid.x < offset) {
            let a = shared_max[lid.x];
            let b = shared_max[lid.x + offset];
            if (b.value > a.value) {
                shared_max[lid.x] = b;
            } else if (b.value == a.value && b.index > a.index) {
                shared_max[lid.x] = b;
            }
        }
        workgroupBarrier();
        offset >>= 1;
    }
    if (lid.x == 0u) {
        dst[params.offset_dst + wid.x] = shared_max[0].index;
    }
}
@@ -0,0 +1,106 @@
@group(0) @binding(0)
var<storage, read_write> src: array<f32>;

@group(0) @binding(1)
var<storage, read_write> dst: array<i32>;

struct Params {
    offset_src: u32, // in elements
    offset_dst: u32, // in elements

    stride_src1: u32,
    stride_src2: u32,
    stride_src3: u32,

    stride_dst1: u32,
    stride_dst2: u32,
    stride_dst3: u32,

    // src/dst dimensions
    src_ne0: u32,
    ne1: u32,
    ne2: u32,

    ne0: u32,
    top_k: u32,

    npr: u32, // tiles per row
    nrows: u32
};

@group(0) @binding(2)
var<uniform> params: Params;

var<workgroup> shmem_idx: array<u32, WG_SIZE>;

#if ORDER == 0
#define EXTREME_VALUE 1e30
#define SWAP_COMPARE_UP >
#define SWAP_COMPARE_DOWN <
#else
#define EXTREME_VALUE -1e30
#define SWAP_COMPARE_UP <
#define SWAP_COMPARE_DOWN >
#endif

@compute @workgroup_size(WG_SIZE)
fn main(@builtin(workgroup_id) wid: vec3<u32>,
        @builtin(num_workgroups) num_wg: vec3<u32>,
        @builtin(local_invocation_id) lid: vec3<u32>) {
    let linear = wid.x + wid.y * num_wg.x;
    // guard against overprovisioned workgroups
    if (linear >= params.npr * params.nrows) {
        return;
    }
    let tile = linear % params.npr;
    var row = linear / params.npr;
    let i3 = row / (params.ne2 * params.ne1);
    row = row % (params.ne2 * params.ne1);
    let i2 = row / params.ne1;
    let i1 = row % params.ne1;

    let row_base = params.offset_src +
                   i1 * params.stride_src1 +
                   i2 * params.stride_src2 +
                   i3 * params.stride_src3;

    let tile_base = tile * WG_SIZE;
    let idx = tile_base + lid.x;
    shmem_idx[lid.x] = select(params.src_ne0, idx, idx < params.src_ne0);
    workgroupBarrier();

    var k = 2u;
    while (k <= WG_SIZE) {
        var j = k >> 1;
        while (j > 0) {
            let ixj = lid.x ^ j;
            if (ixj > lid.x) {
                let dir_up = (lid.x & k) == 0;
                let a_idx = shmem_idx[lid.x];
                let b_idx = shmem_idx[ixj];
                let a_val = select(EXTREME_VALUE, src[row_base + a_idx], a_idx < params.src_ne0);
                let b_val = select(EXTREME_VALUE, src[row_base + b_idx], b_idx < params.src_ne0);
                let should_swap = select(
                    (a_val SWAP_COMPARE_DOWN b_val),
                    (a_val SWAP_COMPARE_UP b_val),
                    dir_up);
                if (should_swap) {
                    shmem_idx[lid.x] = b_idx;
                    shmem_idx[ixj] = a_idx;
                }
            }
            workgroupBarrier();
            j >>= 1;
        }
        k <<= 1;
    }

    let out_idx = tile * params.top_k + lid.x;
    if (out_idx < params.ne0 && lid.x < params.top_k) {
        let row_dst = params.offset_dst +
                      i1 * params.stride_dst1 +
                      i2 * params.stride_dst2 +
                      i3 * params.stride_dst3;
        dst[row_dst + out_idx] = i32(shmem_idx[lid.x]);
    }
}
@@ -0,0 +1,134 @@
@group(0) @binding(0)
var<storage, read_write> src: array<f32>;

@group(0) @binding(1)
var<storage, read_write> idx_in: array<i32>;

@group(0) @binding(2)
var<storage, read_write> idx_out: array<i32>;

struct Params {
    offset_src: u32, // in elements
    offset_in: u32,  // in elements
    offset_out: u32, // in elements

    stride_src1: u32,
    stride_src2: u32,
    stride_src3: u32,

    stride_idx1: u32,
    stride_idx2: u32,
    stride_idx3: u32,

    stride_out1: u32,
    stride_out2: u32,
    stride_out3: u32,

    ne0: u32,
    ne1: u32,
    ne2: u32,

    top_k: u32,

    len: u32,
    nm: u32,
    nrows: u32
};

@group(0) @binding(3)
var<uniform> params: Params;

fn take_left(a_idx: i32, b_idx: i32, row_base: u32) -> bool {
    let a_val = src[row_base + u32(a_idx)];
    let b_val = src[row_base + u32(b_idx)];
#if ORDER == 0
    return a_val <= b_val;
#else
    return a_val >= b_val;
#endif
}

@compute @workgroup_size(WG_SIZE)
fn main(@builtin(workgroup_id) wid: vec3<u32>,
        @builtin(num_workgroups) num_wg: vec3<u32>,
        @builtin(local_invocation_id) lid: vec3<u32>) {
    let linear = wid.x + wid.y * num_wg.x;
    // guard against overprovisioned workgroups
    if (linear >= params.nm * params.nrows) {
        return;
    }

    let start = (linear % params.nm) * params.len * 2;
    let len0 = min(params.len, params.ne0 - start);
    let rem1 = select(0, params.ne0 - (start + params.len), params.ne0 > (start + params.len));
    let len1 = min(params.len, rem1);
    let total = len0 + len1;
    let chunk = (total + WG_SIZE - 1u) / WG_SIZE;
    let k0 = lid.x * chunk;
    let k1 = min(min(k0 + chunk, total), params.top_k);
    // guard against overprovisioned threads
    if (k0 >= params.top_k || k0 >= total) {
        return;
    }

    var row = linear / params.nm;
    let i3 = row / (params.ne2 * params.ne1);
    row = row % (params.ne2 * params.ne1);
    let i2 = row / params.ne1;
    let i1 = row % params.ne1;

    let row_src = params.offset_src +
        i1 * params.stride_src1 +
        i2 * params.stride_src2 +
        i3 * params.stride_src3;

    let row_in = params.offset_in +
        i1 * params.stride_idx1 +
        i2 * params.stride_idx2 +
        i3 * params.stride_idx3;

    let row_out = params.offset_out +
        i1 * params.stride_out1 +
        i2 * params.stride_out2 +
        i3 * params.stride_out3;

    var low: u32 = select(0, k0 - len1, k0 > len1);
    var high: u32 = min(k0, len0);

    while (low < high) {
        let mid = (low + high) >> 1;
        let idx0 = idx_in[row_in + start + mid];
        let idx1 = idx_in[row_in + start + params.len + (k0 - mid - 1)];
        if (take_left(idx0, idx1, row_src)) {
            low = mid + 1;
        } else {
            high = mid;
        }
    }

    var i = low;
    var j = k0 - i;
    var k = k0;
    while (k < k1) {
        var take_l = false;
        if (i >= len0) {
            take_l = false;
        } else if (j >= len1) {
            take_l = true;
        } else {
            let idx0 = idx_in[row_in + start + i];
            let idx1 = idx_in[row_in + start + params.len + j];
            take_l = take_left(idx0, idx1, row_src);
        }

        let out_idx = select(
            idx_in[row_in + start + params.len + j],
            idx_in[row_in + start + i],
            take_l);
        idx_out[row_out + start + k] = out_idx;
        i = select(i, i + 1, take_l);
        j = select(j + 1, j, take_l);
        k += 1;
    }
}
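This merge pass gives each thread an independent slice of the output: a binary search over the "diagonal" `k0` finds how many elements of the left run precede output position `k0`, after which the thread merges sequentially. A minimal sketch of that partition-then-merge idea, assuming a 1-D ascending sort (function names here are hypothetical, not backend API):

```python
# Sketch of the per-thread merge-path partition used by the shader.
def merge_partition(a, b, k0):
    # low/high bound the number of elements taken from a before position k0
    low = max(0, k0 - len(b))
    high = min(k0, len(a))
    while low < high:
        mid = (low + high) // 2
        # "take_left": a[mid] sorts before b[k0 - mid - 1]
        if a[mid] <= b[k0 - mid - 1]:
            low = mid + 1
        else:
            high = mid
    return low  # i; the matching b position is j = k0 - i

def merge_chunk(a, b, k0, k1):
    # Merge output positions [k0, k1), mirroring one thread's sequential loop.
    i = merge_partition(a, b, k0)
    j = k0 - i
    out = []
    for _ in range(k1 - k0):
        take_l = j >= len(b) or (i < len(a) and a[i] <= b[j])
        out.append(a[i] if take_l else b[j])
        if take_l:
            i += 1
        else:
            j += 1
    return out
```

Because each thread's `[k0, k1)` slice is computed independently, concatenating the chunks reproduces the full merge without inter-thread communication, which is why the shader only needs the early-out guards rather than barriers in this loop.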
@@ -7,6 +7,12 @@
            "DST_TYPE": "f32"
        }
    },
    {
        "REPLS": {
            "SRC_TYPE": "f32",
            "DST_TYPE": "i32"
        }
    },
    {
        "REPLS": {
            "SRC_TYPE": "f32",
@@ -0,0 +1,66 @@
@group(0) @binding(0)
var<storage, read_write> src: array<f32>;

@group(0) @binding(1)
var<storage, read_write> dst: array<f32>;

struct Params {
    offset_src: u32, // in elements
    offset_dst: u32, // in elements
    ne0: u32,
};

@group(0) @binding(2)
var<uniform> params: Params;

var<workgroup> shared_sum: array<f32, WG_SIZE>;

@compute @workgroup_size(WG_SIZE)
fn main(@builtin(workgroup_id) wid: vec3<u32>,
        @builtin(local_invocation_id) lid: vec3<u32>) {
    let row_idx = params.offset_src + wid.x * params.ne0;
    let elems = (params.ne0 + WG_SIZE - 1) / WG_SIZE;
    var local_sum: f32 = 0.0;
    for (var col = lid.x * elems; col < (lid.x + 1) * elems && col < params.ne0; col++) {
        local_sum += src[row_idx + col];
    }
    shared_sum[lid.x] = local_sum;
    workgroupBarrier();

    // upsweep
    var offset = 1u;
    while (offset < WG_SIZE) {
        let idx = (lid.x + 1) * offset * 2 - 1;
        if (idx < WG_SIZE) {
            shared_sum[idx] = shared_sum[idx] + shared_sum[idx - offset];
        }
        workgroupBarrier();
        offset <<= 1;
    }

    // set last to 0 for exclusive sum
    if (lid.x == 0) {
        shared_sum[WG_SIZE - 1] = 0.0;
    }
    workgroupBarrier();

    // downsweep
    offset = WG_SIZE >> 1;
    while (offset > 0) {
        let idx = (lid.x + 1) * offset * 2 - 1;
        if (idx < WG_SIZE) {
            let t = shared_sum[idx - offset];
            shared_sum[idx - offset] = shared_sum[idx];
            shared_sum[idx] = shared_sum[idx] + t;
        }
        workgroupBarrier();
        offset = offset >> 1;
    }

    // shared_sum[lid] is the exclusive prefix sum up to this thread.
    var running_sum = shared_sum[lid.x];
    for (var col = lid.x * elems; col < (lid.x + 1) * elems && col < params.ne0; col++) {
        running_sum += src[row_idx + col];
        dst[params.offset_dst + wid.x * params.ne0 + col] = running_sum;
    }
}
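The upsweep/downsweep pair above is the classic Blelloch work-efficient scan: the upsweep builds a reduction tree in place, then zeroing the root and sweeping back down yields an exclusive prefix sum of the per-thread partial sums. A serial sketch of the same in-place tree (assuming, as for WG_SIZE, a power-of-two length):

```python
# Serial model of the Blelloch exclusive scan the shader runs in shared memory.
def blelloch_exclusive_scan(vals):
    a = list(vals)
    n = len(a)  # assumed a power of two, like WG_SIZE
    # upsweep (reduce): each level sums pairs offset apart
    offset = 1
    while offset < n:
        for idx in range(offset * 2 - 1, n, offset * 2):
            a[idx] += a[idx - offset]
        offset *= 2
    # clear the root so the downsweep produces an *exclusive* scan
    a[n - 1] = 0
    # downsweep: swap-and-add pushes partial sums back down the tree
    offset = n // 2
    while offset > 0:
        for idx in range(offset * 2 - 1, n, offset * 2):
            t = a[idx - offset]
            a[idx - offset] = a[idx]
            a[idx] += t
        offset //= 2
    return a
```

Each thread then seeds its sequential re-scan of its chunk with `a[lid]`, turning the exclusive scan of block sums into the final inclusive cumsum of the row.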
@@ -0,0 +1,86 @@
@group(0) @binding(0)
var<storage, read_write> src: array<f32>;

@group(0) @binding(1)
var<storage, read_write> dst: array<f32>;

struct Params {
    ne: u32,         // total number of elements
    offset_src: u32, // in elements
    offset_dst: u32, // in elements

    // Strides (in elements)
    stride_src0: u32,
    stride_src1: u32,
    stride_src2: u32,
    stride_src3: u32,

    // Logical shapes
    src_ne0: u32,
    src_ne1: u32,
    src_ne2: u32,
    src_ne3: u32,

    dst_ne0: u32,
    dst_ne1: u32,
    dst_ne2: u32,
    dst_ne3: u32,

    // Pad sizes (in elements)
    lp0: u32,
    rp0: u32,
    lp1: u32,
    rp1: u32,
    lp2: u32,
    rp2: u32,
    lp3: u32,
    rp3: u32,
};

@group(0) @binding(2)
var<uniform> params: Params;

fn wrap_around(idx: i32, n: u32) -> u32 {
    return u32(idx + i32(n)) % n;
}

@compute @workgroup_size(WG_SIZE)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    if (gid.x >= params.ne) {
        return;
    }

    var i = gid.x;
    let dst_plane = params.dst_ne2 * params.dst_ne1 * params.dst_ne0;
    let i3 = i / dst_plane;
    i = i % dst_plane;
    let i2 = i / (params.dst_ne1 * params.dst_ne0);
    i = i % (params.dst_ne1 * params.dst_ne0);
    let i1 = i / params.dst_ne0;
    let i0 = i % params.dst_ne0;

    var value: f32 = 0.0;

#ifdef CIRCULAR
    let ci0 = wrap_around(i32(i0) - i32(params.lp0), params.src_ne0);
    let ci1 = wrap_around(i32(i1) - i32(params.lp1), params.src_ne1);
    let ci2 = wrap_around(i32(i2) - i32(params.lp2), params.src_ne2);
    let ci3 = wrap_around(i32(i3) - i32(params.lp3), params.src_ne3);
    let circular_src_idx = ci0 * params.stride_src0 + ci1 * params.stride_src1 +
                           ci2 * params.stride_src2 + ci3 * params.stride_src3;
    value = src[params.offset_src + circular_src_idx];
#else
    let is_src =
        (i0 >= params.lp0 && i0 < params.dst_ne0 - params.rp0) &&
        (i1 >= params.lp1 && i1 < params.dst_ne1 - params.rp1) &&
        (i2 >= params.lp2 && i2 < params.dst_ne2 - params.rp2) &&
        (i3 >= params.lp3 && i3 < params.dst_ne3 - params.rp3);
    if (is_src) {
        let src_idx = (i0 - params.lp0) * params.stride_src0 + (i1 - params.lp1) * params.stride_src1 +
                      (i2 - params.lp2) * params.stride_src2 + (i3 - params.lp3) * params.stride_src3;
        value = src[params.offset_src + src_idx];
    }
#endif

    dst[params.offset_dst + gid.x] = value;
}
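The PAD kernel is pure index arithmetic: every destination coordinate either maps back into the source window (constant padding, zero outside it) or wraps modulo the source extent (circular padding); `wrap_around` adds `n` before the modulo because WGSL `u32` arithmetic cannot represent the intermediate negative. A 1-D sketch of the same mapping (function names are illustrative, not backend API):

```python
# 1-D model of the PAD shader's index math.
def wrap_around(idx, n):
    # Python's % already wraps negatives; the shader adds n first because
    # unsigned WGSL arithmetic has no negative intermediates.
    return (idx + n) % n

def pad_1d(src, lp, rp, circular=False):
    out = []
    for i0 in range(lp + len(src) + rp):
        if circular:
            # every output position wraps back into the source
            out.append(src[wrap_around(i0 - lp, len(src))])
        elif lp <= i0 < lp + len(src):
            # inside the source window: copy through
            out.append(src[i0 - lp])
        else:
            # constant padding fills with zero
            out.append(0.0)
    return out
```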
@@ -1,41 +1,37 @@
#define(VARIANTS)

[
    {
        "SHADER_SUFFIX": "f16_vec",
        "REPLS": {
            "TYPE" : "vec4<f32>",
            "DST_TYPE": "vec4<f16>",
            "VEC_SIZE": 4
        }
    },
    {
        "SHADER_SUFFIX": "f16",
        "REPLS": {
            "TYPE" : "f32",
            "DST_TYPE": "f16",
            "VEC_SIZE": 1
        }
    }
]

#end(VARIANTS)

#define(SHADER)

enable f16;

#ifdef DST_F32
#define DST_INNER_TYPE f32
#else
#define DST_INNER_TYPE f16
#endif

#ifdef VEC4
#define SRC_TYPE vec4<f32>
#define DST_TYPE vec4<DST_INNER_TYPE>
#define VEC_SIZE 4
#else
#define SRC_TYPE f32
#define DST_TYPE DST_INNER_TYPE
#define VEC_SIZE 1
#endif

@group(0) @binding(0)
var<storage, read_write> src: array<{{TYPE}}>;
var<storage, read_write> src: array<SRC_TYPE>;

@group(0) @binding(1)
var<storage, read_write> idx: array<u32>;

@group(0) @binding(2)
var<storage, read_write> dst: array<{{DST_TYPE}}>;
var<storage, read_write> dst: array<DST_TYPE>;

#ifdef I64_IDX
@group(0) @binding(3)
var<storage, read_write> error: atomic<u32>;
#define PARAMS_BINDING 4
#else
#define PARAMS_BINDING 3
#endif

struct Params {
    offset_src: u32, // in elements

@@ -66,18 +62,17 @@ struct Params {
    idx2: u32,
};

@group(0) @binding(4)
@group(0) @binding(PARAMS_BINDING)
var<uniform> params: Params;

override wg_size: u32;
@compute @workgroup_size(wg_size)
@compute @workgroup_size(WG_SIZE)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    if (gid.x >= (params.ne3 * params.ne2 * params.n_rows * params.ne0) / {{VEC_SIZE}}) {
    if (gid.x >= (params.ne3 * params.ne2 * params.n_rows * params.ne0) / VEC_SIZE) {
        return;
    }

    // getting the row from gid
    let elems_per_row = params.ne0 / {{VEC_SIZE}};
    let elems_per_row = params.ne0 / VEC_SIZE;
    var i = gid.x / elems_per_row;

    let i_src3 = i / (params.ne2 * params.n_rows);

@@ -90,9 +85,10 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let i_idx1 = i_src2 % params.idx1;
    let i_idx0 = i_src1;

#ifdef I64_IDX
    let idx_high = (params.offset_idx + i_idx0 * params.stride_idx0 + i_idx1 * params.stride_idx1 + i_idx2 * params.stride_idx2) * 2;

    let idx_high_val = idx[idx_high];
    let idx_val = idx[idx_high];
    let idx_low_val = idx[idx_high + 1];

    if (idx_low_val != 0) {

@@ -100,13 +96,14 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
        atomicStore(&error, 1);
        return;
    }
#else
    let idx_i = params.offset_idx + i_idx0 * params.stride_idx0 + i_idx1 * params.stride_idx1 + i_idx2 * params.stride_idx2;
    let idx_val = idx[idx_i];
#endif

    let i_dst_row = params.offset_dst + idx_high_val * params.stride_dst1 + i_src2 * params.stride_dst2 + i_src3 * params.stride_dst3;
    let i_dst_row = params.offset_dst + idx_val * params.stride_dst1 + i_src2 * params.stride_dst2 + i_src3 * params.stride_dst3;
    let i_src_row = params.offset_src + i_src1 * params.stride_src1 + i_src2 * params.stride_src2 + i_src3 * params.stride_src3;

    let col_idx = (gid.x % elems_per_row);
    dst[i_dst_row/{{VEC_SIZE}} + col_idx] = {{DST_TYPE}}(src[i_src_row/{{VEC_SIZE}} + col_idx]);
    dst[i_dst_row/VEC_SIZE + col_idx] = DST_TYPE(src[i_src_row/VEC_SIZE + col_idx]);
}

#end(SHADER)
@@ -0,0 +1,55 @@
@group(0) @binding(0)
var<storage, read_write> src: array<f32>;

@group(0) @binding(1)
var<storage, read_write> dst: array<f32>;

struct Params {
    offset_src: u32, // in elements
    offset_dst: u32, // in elements

    // Strides (in elements)
    stride_src1: u32,
    stride_src2: u32,
    stride_src3: u32,

    ne0: u32,
    ne1: u32,
    ne2: u32
};

@group(0) @binding(2)
var<uniform> params: Params;

var<workgroup> shared_sum: array<f32, WG_SIZE>;

@compute @workgroup_size(WG_SIZE)
fn main(@builtin(workgroup_id) wid: vec3<u32>,
        @builtin(local_invocation_id) lid: vec3<u32>) {

    var i = wid.x;
    let i3 = i / (params.ne2 * params.ne1);
    i = i % (params.ne2 * params.ne1);
    let i2 = i / params.ne1;
    let i1 = i % params.ne1;
    let i_src_row = params.offset_src + i3 * params.stride_src3 + i2 * params.stride_src2 + i1 * params.stride_src1;
    var local_sum: f32 = 0.0;
    for (var col = lid.x; col < params.ne0; col += WG_SIZE) {
        local_sum += src[i_src_row + col];
    }
    shared_sum[lid.x] = local_sum;
    workgroupBarrier();
    // reduce within workgroup
    var offset: u32 = WG_SIZE >> 1;
    while (offset > 0) {
        if (lid.x < offset) {
            shared_sum[lid.x] = shared_sum[lid.x] + shared_sum[lid.x + offset];
        }
        workgroupBarrier();
        offset >>= 1;
    }

    if (lid.x == 0) {
        dst[params.offset_dst + wid.x] = shared_sum[0];
    }
}
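The row-sum kernel is a two-phase reduction: each thread first accumulates a strided slice of the row, then the workgroup halves the number of active threads each step until lane 0 holds the total. A serial model of both phases (assuming a power-of-two workgroup size, as the shader does):

```python
# Serial model of the sum_rows shader: strided partial sums, then a tree reduce.
def sum_row(row, wg_size=8):
    # phase 1: thread t accumulates elements t, t+wg_size, t+2*wg_size, ...
    shared = [sum(row[t::wg_size]) for t in range(wg_size)]
    # phase 2: tree reduction, halving active "threads" each iteration
    offset = wg_size // 2
    while offset > 0:
        for t in range(offset):
            shared[t] += shared[t + offset]
        offset //= 2
    return shared[0]
```

The strided access pattern in phase 1 is what makes the real kernel's loads coalesced: adjacent threads touch adjacent elements on every iteration.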
@@ -0,0 +1,179 @@
#ifdef TYPE_F16
enable f16;
#define TYPE f16
#else
#define TYPE f32
#endif


@group(0) @binding(0)
var<storage, read_write> src: array<TYPE>;

#ifndef INPLACE
@group(0) @binding(1)
var<storage, read_write> dst: array<TYPE>;
#define PARAMS_BINDING 2
#else
#define PARAMS_BINDING 1
#endif

struct Params {
    ne: u32,         // total number of elements
    offset_src: u32, // in elements
    offset_dst: u32, // in elements

    // Strides (in elements)
    stride_src0: u32,
    stride_src1: u32,
    stride_src2: u32,
    stride_src3: u32,

    // Logical shapes
    ne0: u32,
    ne1: u32,
    ne2: u32,
#ifdef CLAMP
    clamp_min: f32,
    clamp_max: f32,
#endif
#ifdef FILL
    fill_val: f32,
#endif
#ifdef XIELU
    alpha_n: f32,
    alpha_p: f32,
    beta: f32,
    eps: f32,
#endif
};

@group(0) @binding(PARAMS_BINDING)
var<uniform> params: Params;

@compute @workgroup_size(WG_SIZE)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    if (gid.x >= params.ne) {
        return;
    }
    var i = gid.x;
    let i3 = i / (params.ne2 * params.ne1 * params.ne0);
    i = i % (params.ne2 * params.ne1 * params.ne0);
    let i2 = i / (params.ne1 * params.ne0);
    i = i % (params.ne1 * params.ne0);
    let i1 = i / params.ne0;
    let i0 = i % params.ne0;

    let src_idx = i0 * params.stride_src0 + i1 * params.stride_src1 +
                  i2 * params.stride_src2 + i3 * params.stride_src3;

#ifdef ABS
    let res = abs(src[params.offset_src + src_idx]);
#endif
#ifdef SGN
    let res = select(TYPE(select(0.0, -1.0, src[params.offset_src + src_idx] < 0.0)), TYPE(1.0),
                     src[params.offset_src + src_idx] > 0.0);
#endif
#ifdef NEG
    let res = -src[params.offset_src + src_idx];
#endif
#ifdef STEP
    let res = TYPE(select(0.0, 1.0, src[params.offset_src + src_idx] > 0.0));
#endif
#ifdef TANH
    let res = tanh(clamp(src[params.offset_src + src_idx], -9.010913, 9.010913));
#endif
#ifdef RELU
    let res = select(0.0, src[params.offset_src + src_idx], src[params.offset_src + src_idx] > 0.0);
#endif
#ifdef ELU
    let res = select(exp(src[params.offset_src + src_idx]) - 1.0, src[params.offset_src + src_idx],
                     src[params.offset_src + src_idx] > 0.0);
#endif
#ifdef HARDSIGMOID
    let res = min(1.0, max(0.0, (src[params.offset_src + src_idx] + 3.0) / 6.0));
#endif
#ifdef SIGMOID
    let res = 1.0 / (1.0 + exp(-src[params.offset_src + src_idx]));
#endif
#ifdef SILU
    let res = src[params.offset_src + src_idx] / (1.0 + exp(-src[params.offset_src + src_idx]));
#endif
#ifdef EXP
    let res = exp(src[params.offset_src + src_idx]);
#endif
#ifdef LOG
    let res = TYPE(log(f32(src[params.offset_src + src_idx])));
#endif
#ifdef CLAMP
    let res = clamp(src[params.offset_src + src_idx], TYPE(params.clamp_min), TYPE(params.clamp_max));
#endif
#ifdef FILL
    let res = TYPE(params.fill_val);
#endif
#ifdef HARDSWISH
    let res = src[params.offset_src + src_idx] *
              min(1.0, max(0.0, (src[params.offset_src + src_idx] + 3.0) / 6.0));
#endif
#ifdef GELU
    let res = 0.5 * src[params.offset_src + src_idx] *
              (1.0 + tanh(clamp(sqrt(2.0 / 3.14159265) *
                  (src[params.offset_src + src_idx] +
                   0.044715 * pow(src[params.offset_src + src_idx], 3.0)),
                  -9.010913, 9.010913)));
#endif
#ifdef GELU_QUICK
    let res = src[params.offset_src + src_idx] * 0.5 *
              (1.0 + tanh(clamp(0.79788456 *
                  (src[params.offset_src + src_idx] +
                   0.044715 * src[params.offset_src + src_idx] *
                   src[params.offset_src + src_idx] * src[params.offset_src + src_idx]),
                  -9.010913, 9.010913)));
#endif
#ifdef GELU_ERF
    let res = 0.5 * src[params.offset_src + src_idx] *
              (1.0 + tanh(clamp(0.79788456 *
                  (src[params.offset_src + src_idx] +
                   0.044715 * src[params.offset_src + src_idx] *
                   src[params.offset_src + src_idx] * src[params.offset_src + src_idx]),
                  -9.010913, 9.010913)));
#endif
#ifdef XIELU
    let res =
        select(((exp(min(src[params.offset_src + src_idx], TYPE(params.eps))) - 1.0) -
                   src[params.offset_src + src_idx]) *
                   TYPE(params.alpha_n) +
                   TYPE(params.beta) * src[params.offset_src + src_idx],
               TYPE(params.alpha_p) * src[params.offset_src + src_idx] *
                   src[params.offset_src + src_idx] +
                   TYPE(params.beta) * src[params.offset_src + src_idx],
               src[params.offset_src + src_idx] > 0.0);
#endif
#ifdef SOFTPLUS
    let src_f32 = f32(src[params.offset_src + src_idx]);
    let res = TYPE(select(log(1.0 + exp(src_f32)), src_f32, src_f32 > 20.0));
#endif
#ifdef EXPM1
    let res = exp(src[params.offset_src + src_idx]) - 1.0;
#endif
#ifdef FLOOR
    let res = floor(src[params.offset_src + src_idx]);
#endif
#ifdef CEIL
    let res = ceil(src[params.offset_src + src_idx]);
#endif
#ifdef ROUND
    let src_f32 = f32(src[params.offset_src + src_idx]);
    let result = select(ceil(src_f32 - 0.5), floor(src_f32 + 0.5), src_f32 >= 0.0);
    let res = TYPE(result);
#endif
#ifdef TRUNC
    let res = trunc(src[params.offset_src + src_idx]);
#endif

#ifdef INPLACE
    src[params.offset_src + src_idx] = res;
#else
    dst[params.offset_dst + gid.x] = res;
#endif
}
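Two numerics in this shader are worth spelling out. ROUND rounds halves away from zero (`floor(x + 0.5)` for non-negative x, `ceil(x - 0.5)` otherwise), which differs from IEEE round-half-to-even. SOFTPLUS computes in f32 and falls back to `x` for large inputs, since `log(1 + exp(x))` equals `x` to within f32 precision well before `exp(x)` overflows. A sketch of both, as plain scalar functions:

```python
import math

# Scalar models of the shader's ROUND and SOFTPLUS numerics.
def round_away(x):
    # halves round away from zero: 2.5 -> 3, -2.5 -> -3
    return math.ceil(x - 0.5) if x < 0.0 else math.floor(x + 0.5)

def softplus(x):
    # beyond the threshold, log(1 + exp(x)) == x within f32 precision,
    # so skip exp() entirely and avoid overflow
    return x if x > 20.0 else math.log(1.0 + math.exp(x))
```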
@@ -1,483 +0,0 @@
#define(REPL_TEMPLATES)

{
    "XIELU_FUNC": "{{MUTATE}}[dst_i] = select(((exp(min(src[src_i], {{TYPE}}(params.eps))) - 1.0) - src[src_i]) * {{TYPE}}(params.alpha_n) + {{TYPE}}(params.beta) * src[src_i], {{TYPE}}(params.alpha_p) * src[src_i] * src[src_i] + {{TYPE}}(params.beta) * src[src_i], src[src_i] > 0.0);",
    "ABS_FUNC": "{{MUTATE}}[dst_i] = abs(src[src_i]);",
    "SGN_FUNC": "{{MUTATE}}[dst_i] = select({{TYPE}}(select(0.0, -1.0, src[src_i] < 0.0)), {{TYPE}}(1.0), src[src_i] > 0.0);",
    "NEG_FUNC": "{{MUTATE}}[dst_i] = -src[src_i];",
    "STEP_FUNC": "{{MUTATE}}[dst_i] = {{TYPE}}(select(0.0, 1.0, src[src_i] > 0.0));",
    "TANH_FUNC": "{{MUTATE}}[dst_i] = tanh(clamp(src[src_i], -9.010913, 9.010913)); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458",
    "RELU_FUNC": "{{MUTATE}}[dst_i] = select(0.0, src[src_i], src[src_i] > 0.0);",
    "ELU_FUNC": "{{MUTATE}}[dst_i] = select(exp(src[src_i]) - 1.0, src[src_i], src[src_i] > 0.0);",
    "HARDSIGMOID_FUNC": "{{MUTATE}}[dst_i] = min(1.0, max(0.0, (src[src_i] + 3.0) / 6.0));",
    "SIGMOID_FUNC": "{{MUTATE}}[dst_i] = 1.0 / (1.0 + exp(-src[src_i]));",
    "SILU_FUNC": "{{MUTATE}}[dst_i] = src[src_i] / (1.0 + exp(-src[src_i]));",
    "EXP_FUNC": "{{MUTATE}}[dst_i] = exp(src[src_i]);",
    "HARDSWISH_FUNC": "{{MUTATE}}[dst_i] = src[src_i] * min(1.0, max(0.0, (src[src_i] + 3.0) / 6.0));",
    "GELU_FUNC": "{{MUTATE}}[dst_i] = 0.5 * src[src_i] * (1.0 + tanh(clamp(sqrt(2.0 / 3.14159265) * (src[src_i] + 0.044715 * pow(src[src_i], 3.0)), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458",
    "GELU_QUICK_FUNC": "{{MUTATE}}[dst_i] = src[src_i] * 0.5 * (1.0 + tanh(clamp(0.79788456 * (src[src_i] + 0.044715 * src[src_i] * src[src_i] * src[src_i]), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458",
    "GELU_ERF_FUNC": "{{MUTATE}}[dst_i] = 0.5 * src[src_i] * (1.0 + tanh(clamp(0.79788456 * (src[src_i] + 0.044715 * src[src_i] * src[src_i] * src[src_i]), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458",
    "CEIL_FUNC": "{{MUTATE}}[dst_i] = ceil(src[src_i]);"
}

#end(REPL_TEMPLATES)

#define(VARIANTS)

[
    {
        "SHADER_NAME": "abs_f32",
        "REPLS": { "TYPE": "f32", "FUNC": "ABS_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
        "DECLS": ["NOT_INPLACE"]
    },
    {
        "SHADER_NAME": "abs_f16",
        "REPLS": { "TYPE": "f16", "FUNC": "ABS_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
        "DECLS": ["NOT_INPLACE"]
    },
    {
        "SHADER_NAME": "abs_inplace_f32",
        "REPLS": { "TYPE": "f32", "FUNC": "ABS_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
        "DECLS": ["INPLACE"]
    },
    {
        "SHADER_NAME": "abs_inplace_f16",
        "REPLS": { "TYPE": "f16", "FUNC": "ABS_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
        "DECLS": ["INPLACE"]
    },

    {
        "SHADER_NAME": "sgn_f32",
        "REPLS": { "TYPE": "f32", "FUNC": "SGN_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
        "DECLS": ["NOT_INPLACE"]
    },
    {
        "SHADER_NAME": "sgn_f16",
        "REPLS": { "TYPE": "f16", "FUNC": "SGN_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
        "DECLS": ["NOT_INPLACE"]
    },
    {
        "SHADER_NAME": "sgn_inplace_f32",
        "REPLS": { "TYPE": "f32", "FUNC": "SGN_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
        "DECLS": ["INPLACE"]
    },
    {
        "SHADER_NAME": "sgn_inplace_f16",
        "REPLS": { "TYPE": "f16", "FUNC": "SGN_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
        "DECLS": ["INPLACE"]
    },

    {
        "SHADER_NAME": "neg_f32",
        "REPLS": { "TYPE": "f32", "FUNC": "NEG_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
        "DECLS": ["NOT_INPLACE"]
    },
    {
        "SHADER_NAME": "neg_f16",
        "REPLS": { "TYPE": "f16", "FUNC": "NEG_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
        "DECLS": ["NOT_INPLACE"]
    },
    {
        "SHADER_NAME": "neg_inplace_f32",
        "REPLS": { "TYPE": "f32", "FUNC": "NEG_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
        "DECLS": ["INPLACE"]
    },
    {
        "SHADER_NAME": "neg_inplace_f16",
        "REPLS": { "TYPE": "f16", "FUNC": "NEG_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
        "DECLS": ["INPLACE"]
    },

    {
        "SHADER_NAME": "step_f32",
        "REPLS": { "TYPE": "f32", "FUNC": "STEP_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
        "DECLS": ["NOT_INPLACE"]
    },
    {
        "SHADER_NAME": "step_f16",
        "REPLS": { "TYPE": "f16", "FUNC": "STEP_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
        "DECLS": ["NOT_INPLACE"]
    },
    {
        "SHADER_NAME": "step_inplace_f32",
        "REPLS": { "TYPE": "f32", "FUNC": "STEP_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
        "DECLS": ["INPLACE"]
    },
    {
        "SHADER_NAME": "step_inplace_f16",
        "REPLS": { "TYPE": "f16", "FUNC": "STEP_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
        "DECLS": ["INPLACE"]
    },

    {
        "SHADER_NAME": "tanh_f32",
        "REPLS": { "TYPE": "f32", "FUNC": "TANH_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
        "DECLS": ["NOT_INPLACE"]
    },
    {
        "SHADER_NAME": "tanh_f16",
        "REPLS": { "TYPE": "f16", "FUNC": "TANH_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
        "DECLS": ["NOT_INPLACE"]
    },
    {
        "SHADER_NAME": "tanh_inplace_f32",
        "REPLS": { "TYPE": "f32", "FUNC": "TANH_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
        "DECLS": ["INPLACE"]
    },
    {
        "SHADER_NAME": "tanh_inplace_f16",
        "REPLS": { "TYPE": "f16", "FUNC": "TANH_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
        "DECLS": ["INPLACE"]
    },

    {
        "SHADER_NAME": "elu_f32",
        "REPLS": { "TYPE": "f32", "FUNC": "ELU_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
        "DECLS": ["NOT_INPLACE"]
    },
    {
        "SHADER_NAME": "elu_f16",
        "REPLS": { "TYPE": "f16", "FUNC": "ELU_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
        "DECLS": ["NOT_INPLACE"]
    },
    {
        "SHADER_NAME": "elu_inplace_f32",
        "REPLS": { "TYPE": "f32", "FUNC": "ELU_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
        "DECLS": ["INPLACE"]
    },
    {
        "SHADER_NAME": "elu_inplace_f16",
        "REPLS": { "TYPE": "f16", "FUNC": "ELU_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
        "DECLS": ["INPLACE"]
    },

    {
        "SHADER_NAME": "relu_f32",
        "REPLS": { "TYPE": "f32", "FUNC": "RELU_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
        "DECLS": ["NOT_INPLACE"]
    },
    {
        "SHADER_NAME": "relu_f16",
        "REPLS": { "TYPE": "f16", "FUNC": "RELU_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
        "DECLS": ["NOT_INPLACE"]
    },
    {
        "SHADER_NAME": "relu_inplace_f32",
        "REPLS": { "TYPE": "f32", "FUNC": "RELU_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
        "DECLS": ["INPLACE"]
    },
    {
        "SHADER_NAME": "relu_inplace_f16",
        "REPLS": { "TYPE": "f16", "FUNC": "RELU_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
        "DECLS": ["INPLACE"]
    },

    {
        "SHADER_NAME": "sigmoid_f32",
        "REPLS": { "TYPE": "f32", "FUNC": "SIGMOID_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
        "DECLS": ["NOT_INPLACE"]
    },
    {
        "SHADER_NAME": "sigmoid_f16",
        "REPLS": { "TYPE": "f16", "FUNC": "SIGMOID_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
        "DECLS": ["NOT_INPLACE"]
    },
    {
        "SHADER_NAME": "sigmoid_inplace_f32",
        "REPLS": { "TYPE": "f32", "FUNC": "SIGMOID_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
        "DECLS": ["INPLACE"]
    },
    {
        "SHADER_NAME": "sigmoid_inplace_f16",
        "REPLS": { "TYPE": "f16", "FUNC": "SIGMOID_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
        "DECLS": ["INPLACE"]
    },

    {
        "SHADER_NAME": "silu_f32",
        "REPLS": { "TYPE": "f32", "FUNC": "SILU_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
        "DECLS": ["NOT_INPLACE"]
    },
    {
        "SHADER_NAME": "silu_f16",
        "REPLS": { "TYPE": "f16", "FUNC": "SILU_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
        "DECLS": ["NOT_INPLACE"]
    },
    {
        "SHADER_NAME": "silu_inplace_f32",
        "REPLS": { "TYPE": "f32", "FUNC": "SILU_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
        "DECLS": ["INPLACE"]
    },
    {
        "SHADER_NAME": "silu_inplace_f16",
        "REPLS": { "TYPE": "f16", "FUNC": "SILU_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
        "DECLS": ["INPLACE"]
    },

    {
        "SHADER_NAME": "exp_f32",
        "REPLS": { "TYPE": "f32", "FUNC": "EXP_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
        "DECLS": ["NOT_INPLACE"]
    },
    {
        "SHADER_NAME": "exp_f16",
        "REPLS": { "TYPE": "f16", "FUNC": "EXP_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
        "DECLS": ["NOT_INPLACE"]
    },
    {
        "SHADER_NAME": "exp_inplace_f32",
        "REPLS": { "TYPE": "f32", "FUNC": "EXP_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
        "DECLS": ["INPLACE"]
    },
    {
        "SHADER_NAME": "exp_inplace_f16",
        "REPLS": { "TYPE": "f16", "FUNC": "EXP_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
        "DECLS": ["INPLACE"]
    },

    {
        "SHADER_NAME": "hardsigmoid_f32",
        "REPLS": { "TYPE": "f32", "FUNC": "HARDSIGMOID_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
        "DECLS": ["NOT_INPLACE"]
    },
    {
        "SHADER_NAME": "hardsigmoid_f16",
        "REPLS": { "TYPE": "f16", "FUNC": "HARDSIGMOID_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
        "DECLS": ["NOT_INPLACE"]
    },
    {
        "SHADER_NAME": "hardsigmoid_inplace_f32",
        "REPLS": { "TYPE": "f32", "FUNC": "HARDSIGMOID_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
        "DECLS": ["INPLACE"]
    },
    {
        "SHADER_NAME": "hardsigmoid_inplace_f16",
        "REPLS": { "TYPE": "f16", "FUNC": "HARDSIGMOID_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
        "DECLS": ["INPLACE"]
    },

    {
        "SHADER_NAME": "hardswish_f32",
        "REPLS": { "TYPE": "f32", "FUNC": "HARDSWISH_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
        "DECLS": ["NOT_INPLACE"]
    },
    {
        "SHADER_NAME": "hardswish_f16",
        "REPLS": { "TYPE": "f16", "FUNC": "HARDSWISH_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
        "DECLS": ["NOT_INPLACE"]
    },
    {
        "SHADER_NAME": "hardswish_inplace_f32",
        "REPLS": { "TYPE": "f32", "FUNC": "HARDSWISH_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
        "DECLS": ["INPLACE"]
    },
    {
        "SHADER_NAME": "hardswish_inplace_f16",
        "REPLS": { "TYPE": "f16", "FUNC": "HARDSWISH_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
        "DECLS": ["INPLACE"]
    },

    {
        "SHADER_NAME": "gelu_f32",
        "REPLS": { "TYPE": "f32", "FUNC": "GELU_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
        "DECLS": ["NOT_INPLACE"]
    },
    {
        "SHADER_NAME": "gelu_f16",
        "REPLS": { "TYPE": "f16", "FUNC": "GELU_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
        "DECLS": ["NOT_INPLACE"]
    },
    {
        "SHADER_NAME": "gelu_inplace_f32",
        "REPLS": { "TYPE": "f32", "FUNC": "GELU_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
        "DECLS": ["INPLACE"]
    },
    {
        "SHADER_NAME": "gelu_inplace_f16",
        "REPLS": { "TYPE": "f16", "FUNC": "GELU_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
        "DECLS": ["INPLACE"]
    },

    {
        "SHADER_NAME": "gelu_quick_f32",
        "REPLS": { "TYPE": "f32", "FUNC": "GELU_QUICK_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
        "DECLS": ["NOT_INPLACE"]
    },
    {
        "SHADER_NAME": "gelu_quick_f16",
        "REPLS": { "TYPE": "f16", "FUNC": "GELU_QUICK_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
        "DECLS": ["NOT_INPLACE"]
    },
    {
        "SHADER_NAME": "gelu_quick_inplace_f32",
        "REPLS": { "TYPE": "f32", "FUNC": "GELU_QUICK_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
        "DECLS": ["INPLACE"]
    },
    {
        "SHADER_NAME": "gelu_quick_inplace_f16",
        "REPLS": { "TYPE": "f16", "FUNC": "GELU_QUICK_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
        "DECLS": ["INPLACE"]
    },

    {
        "SHADER_NAME": "xielu_f32",
        "REPLS": { "TYPE": "f32", "FUNC": "XIELU_FUNC", "EXT_PARAMS": "alpha_n: f32, alpha_p: f32, beta: f32, eps: f32", "MUTATE": "dst" },
        "DECLS": ["NOT_INPLACE"]
    },
    {
        "SHADER_NAME": "xielu_f16",
        "REPLS": { "TYPE": "f16", "FUNC": "XIELU_FUNC", "EXT_PARAMS": "alpha_n: f32, alpha_p: f32, beta: f32, eps: f32", "MUTATE": "dst" },
        "DECLS": ["NOT_INPLACE"]
    },
    {
        "SHADER_NAME": "xielu_inplace_f32",
        "REPLS": { "TYPE": "f32", "FUNC": "XIELU_FUNC", "EXT_PARAMS": "alpha_n: f32, alpha_p: f32, beta: f32, eps: f32", "MUTATE": "src" },
        "DECLS": ["INPLACE"]
    },
    {
        "SHADER_NAME": "xielu_inplace_f16",
        "REPLS": { "TYPE": "f16", "FUNC": "XIELU_FUNC", "EXT_PARAMS": "alpha_n: f32, alpha_p: f32, beta: f32, eps: f32", "MUTATE": "src" },
        "DECLS": ["INPLACE"]
    },
    {
        "SHADER_NAME": "gelu_erf_f32",
        "REPLS": { "TYPE": "f32", "FUNC": "GELU_ERF_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
        "DECLS": ["NOT_INPLACE"]
    },
    {
        "SHADER_NAME": "gelu_erf_f16",
        "REPLS": { "TYPE": "f16", "FUNC": "GELU_ERF_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
        "DECLS": ["NOT_INPLACE"]
    },
    {
        "SHADER_NAME": "gelu_erf_inplace_f32",
        "REPLS": { "TYPE": "f32", "FUNC": "GELU_ERF_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
        "DECLS": ["INPLACE"]
    },
    {
        "SHADER_NAME": "gelu_erf_inplace_f16",
        "REPLS": { "TYPE": "f16", "FUNC": "GELU_ERF_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
        "DECLS": ["INPLACE"]
    },

    {
        "SHADER_NAME": "ceil_f32",
        "REPLS": { "TYPE": "f32", "FUNC": "CEIL_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
        "DECLS": ["NOT_INPLACE"]
    },
    {
        "SHADER_NAME": "ceil_f16",
        "REPLS": { "TYPE": "f16", "FUNC": "CEIL_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
        "DECLS": ["NOT_INPLACE"]
    },
    {
        "SHADER_NAME": "ceil_inplace_f32",
        "REPLS": { "TYPE": "f32", "FUNC": "CEIL_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
        "DECLS": ["INPLACE"]
    },
    {
        "SHADER_NAME": "ceil_inplace_f16",
|
||||
"REPLS": { "TYPE": "f16", "FUNC": "CEIL_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
|
||||
"DECLS": ["INPLACE"]
|
||||
}
|
||||
]
|
||||
|
||||
#end(VARIANTS)
|
||||
|
||||
#define(DECLS)
|
||||
|
||||
#decl(INPLACE)
|
||||
|
||||
@group(0) @binding(1)
|
||||
var<uniform> params: Params;
|
||||
|
||||
#enddecl(INPLACE)
|
||||
|
||||
#decl(NOT_INPLACE)
|
||||
|
||||
@group(0) @binding(1)
|
||||
var<storage, read_write> dst: array<{{TYPE}}>;
|
||||
|
||||
@group(0) @binding(2)
|
||||
var<uniform> params: Params;
|
||||
|
||||
#enddecl(NOT_INPLACE)
|
||||
|
||||
#end(DECLS)
|
||||
|
||||
#define(SHADER)
|
||||
|
||||
enable f16;
|
||||
|
||||
fn update(dst_i: u32, src_i: u32) {
|
||||
{{FUNC}}
|
||||
}
|
||||
|
||||
@group(0) @binding(0)
|
||||
var<storage, read_write> src: array<{{TYPE}}>;
|
||||
|
||||
DECLS
|
||||
|
||||
struct Params {
|
||||
ne: u32, // total number of elements
|
||||
offset_src: u32, // in elements
|
||||
offset_dst: u32, // in elements
|
||||
|
||||
// Strides (in elements) — may be permuted
|
||||
stride_src0: u32,
|
||||
stride_src1: u32,
|
||||
stride_src2: u32,
|
||||
stride_src3: u32,
|
||||
|
||||
stride_dst0: u32,
|
||||
stride_dst1: u32,
|
||||
stride_dst2: u32,
|
||||
stride_dst3: u32,
|
||||
|
||||
// Logical shapes
|
||||
src_ne0: u32,
|
||||
src_ne1: u32,
|
||||
src_ne2: u32,
|
||||
|
||||
dst_ne0: u32,
|
||||
dst_ne1: u32,
|
||||
dst_ne2: u32,
|
||||
|
||||
{{EXT_PARAMS}}
|
||||
};
|
||||
|
||||
override wg_size: u32;
|
||||
@compute @workgroup_size(wg_size)
|
||||
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
|
||||
if (gid.x >= params.ne) {
|
||||
return;
|
||||
}
|
||||
|
||||
var i = gid.x;
|
||||
let i3 = i / (params.src_ne2 * params.src_ne1 * params.src_ne0);
|
||||
i = i % (params.src_ne2 * params.src_ne1 * params.src_ne0);
|
||||
let i2 = i / (params.src_ne1 * params.src_ne0);
|
||||
i = i % (params.src_ne1 * params.src_ne0);
|
||||
let i1 = i / params.src_ne0;
|
||||
let i0 = i % params.src_ne0;
|
||||
|
||||
var j = gid.x;
|
||||
let j3 = j / (params.dst_ne2 * params.dst_ne1 * params.dst_ne0);
|
||||
j = j % (params.dst_ne2 * params.dst_ne1 * params.dst_ne0);
|
||||
let j2 = j / (params.dst_ne1 * params.dst_ne0);
|
||||
j = j % (params.dst_ne1 * params.dst_ne0);
|
||||
let j1 = j / params.dst_ne0;
|
||||
let j0 = j % params.dst_ne0;
|
||||
|
||||
let src_idx = i0 * params.stride_src0 + i1 * params.stride_src1 +
|
||||
i2 * params.stride_src2 + i3 * params.stride_src3;
|
||||
|
||||
let dst_idx = j0 * params.stride_dst0 + j1 * params.stride_dst1 +
|
||||
j2 * params.stride_dst2 + j3 * params.stride_dst3;
|
||||
|
||||
|
||||
update(params.offset_dst + dst_idx, params.offset_src + src_idx);
|
||||
}
|
||||
|
||||
#end(SHADER)