ggml webgpu: unary op support, code refactoring, ops support (#17764)
* Squashed commit of the following:
commit b3c6bf4b0450d8d452b934df27a0fb7cb53cd755
Author: Abhijit Ramesh <abhijitramesh2k@gmail.com>
Date: Mon Dec 1 18:29:00 2025 -0800
ggml webgpu: fix xielu parameter passing (#11)
The XIELU operation was incorrectly using static_cast to convert
float parameters to uint32_t, which converted numeric values instead
of preserving IEEE 754 bit patterns. This caused incorrect values
to be interpreted by the GPU shader.
* Use reinterpret_cast to preserve float bit patterns when passing
through uint32_t params buffer
* Update WGSL shader parameter types from u32 to f32
* Re-enable XIELU support (was disabled due to numerical issues)
Fixes NMSE test failures for XIELU operation on WebGPU backend.
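To make the difference concrete, here is a minimal Python sketch (not the backend's C++ code) contrasting a numeric float-to-integer conversion with a bit-preserving reinterpretation. The helper names and the sample values are illustrative assumptions, not taken from the commit.

```python
import struct


def numeric_cast_to_u32(x: float) -> int:
    """Mimic a numeric float -> unsigned conversion: the fractional part is dropped."""
    return int(x) & 0xFFFFFFFF


def float_bits_to_u32(x: float) -> int:
    """Reinterpret the IEEE 754 single-precision bits of x as an unsigned 32-bit integer."""
    return struct.unpack("<I", struct.pack("<f", x))[0]


def u32_bits_to_float(u: int) -> float:
    """Inverse of float_bits_to_u32: read the 32 bits back as a float."""
    return struct.unpack("<f", struct.pack("<I", u))[0]


alpha = 0.8  # hypothetical parameter value
lossy = numeric_cast_to_u32(alpha)   # the numeric value 0.8 collapses to the integer 0
packed = float_bits_to_u32(alpha)    # the bit pattern survives the trip through a u32 slot
restored = u32_bits_to_float(packed)
```

With the numeric conversion, a shader reading the slot would see zero; with the bit-preserving path it can recover the original value up to single-precision rounding.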
commit 5ca9b5e49ea7cddc9ab7c8b43a11a9c76a4dff4a
Author: neha-ha <137219201+neha-ha@users.noreply.github.com>
Date: Tue Nov 18 12:17:00 2025 -0800
Refactored pipelines and workgroup calculations (#10)
* refactored pipelines
* refactored workgroup calculation
* removed commented out block of prior maps
* Clean up ceiling division pattern
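As an illustration of the ceiling division pattern mentioned above, a minimal Python sketch follows. The function names are hypothetical and the backend itself is C++; the sketch only shows the arithmetic for covering a number of elements with fixed-size workgroups.

```python
def ceil_div(n: int, d: int) -> int:
    """Smallest integer q with q * d >= n, for n >= 0 and d > 0."""
    return (n + d - 1) // d


def workgroups_needed(num_elements: int, workgroup_size: int) -> int:
    """Number of workgroups required so that every element gets a thread."""
    return ceil_div(num_elements, workgroup_size)
```

For example, 1025 elements with a workgroup size of 256 need 5 workgroups, while 1024 elements need exactly 4.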
---------
Co-authored-by: Neha Abbas <nehaabbas@eduroam-169-233-141-223.ucsc.edu>
Co-authored-by: Reese Levine <reeselevine1@gmail.com>
Author: James Contini <jamescontini@gmail.com>
Date: Wed Oct 29 23:13:06 2025 -0700
formatted embed wgsl and ggml-webgpu.cpp
commit e1f6baea31645e5d96ad53664acae856f74b96f4
Author: James Contini <jamescontini@gmail.com>
Date: Wed Oct 29 23:08:37 2025 -0700
implemented REPL_Template support and removed bug in unary operators kernel
commit 8c70b8fece445cdc9a8c660dbddbf201e52da2bb
Author: James Contini <jamescontini@gmail.com>
Date: Wed Oct 15 16:14:20 2025 -0700
responded and dealt with PR comments
commit f9282c660c10dec4487d434549bdb707a9cd9f37
Author: James Contini <jamescontini@gmail.com>
Date: Sun Oct 12 13:41:41 2025 -0700
removed unnecessary check for whether node->src[1] exists for unary operators
commit 4cf28d7dec41c29186d66152735b244c5699f9dc
Author: James Contini <jamescontini@gmail.com>
Date: Sun Oct 12 13:32:45 2025 -0700
All operators (including xielu) working
commit 74c6add1761a59d2c2ff60b60e8ad3c8300f6d3e
Author: James Contini <jamescontini@gmail.com>
Date: Fri Oct 10 13:16:48 2025 -0700
fixed autoconfig
commit 362749910be4f0120c8ffb21ceddeb7d2c088e51
Author: James Contini <jamescontini@gmail.com>
Date: Fri Oct 10 13:10:46 2025 -0700
removed vestigial files
commit cb0858333785757804c5104e59c4981843207c16
Author: James Contini <jamescontini@gmail.com>
Date: Fri Oct 10 12:59:32 2025 -0700
abides by editor-config
commit 5360e2852a4b51197d7d67d0a5d42e908b02d7ed
Author: James Contini <jamescontini@gmail.com>
Date: Fri Oct 10 12:45:57 2025 -0700
rms_norm double declaration bug atoned
commit 7b09baa4aa53711be5a126043670cc182c78bfcd
Merge: 8a6ec843 74b8fc17
Author: James Contini <jamescontini@gmail.com>
Date: Fri Oct 10 11:50:03 2025 -0700
resolving merge conflicts
commit 8a6ec843a50ab82f8cef59b4558eb63f318ba02d
Author: James Contini <jamescontini@gmail.com>
Date: Wed Oct 8 18:06:47 2025 -0700
unary operators pass ggml tests
commit c3ae38278a2db236adc5912c9140e4f0d63f2c19
Author: James Contini <jamescontini@gmail.com>
Date: Wed Oct 1 16:22:40 2025 -0700
neg passes backend test
commit aa1c9b2f8877a405470ca56709c42a1fd43713de
Author: James Contini <jamescontini@gmail.com>
Date: Tue Sep 30 23:55:27 2025 -0700
neg f16xf32xip builds and runs, haven't actually run a model that uses neg kernel yet though
Co-authored-by: James Contini <jamescontini@gmail.com>
Co-authored-by: Neha Abbas <neabbas@ucsc.edu>
Co-authored-by: Abhijit Ramesh <abhijitramesh2k@gmail.com>
* Remove extra code and format
* Add ops documentation (finally)
* Update ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py
Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
---------
Co-authored-by: James Contini <jamescontini@gmail.com>
Co-authored-by: Neha Abbas <neabbas@ucsc.edu>
Co-authored-by: Abhijit Ramesh <abhijitramesh2k@gmail.com>
Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
parent 6ab0d64960
commit fd57b24c0f
216
docs/ops.md
|
|
@ -12,111 +12,111 @@ Legend:
|
||||||
- 🟡 Partially supported by this backend
|
- 🟡 Partially supported by this backend
|
||||||
- ❌ Not supported by this backend
|
- ❌ Not supported by this backend
|
||||||
|
|
||||||
| Operation | BLAS | CANN | CPU | CUDA | Metal | OpenCL | SYCL | Vulkan | zDNN |
|
| Operation | BLAS | CANN | CPU | CUDA | Metal | OpenCL | SYCL | Vulkan | WebGPU | zDNN |
|
||||||
|-----------|------|------|------|------|------|------|------|------|------|
|
|-----------|------|------|------|------|------|------|------|------|------|------|
|
||||||
| ABS | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ❌ |
|
| ABS | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ |
|
||||||
| ACC | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
|
| ACC | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ |
|
||||||
| ADD | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ |
|
| ADD | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ✅ | ❌ |
|
||||||
| ADD1 | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ |
|
| ADD1 | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ |
|
||||||
| ADD_ID | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ |
|
| ADD_ID | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||||
| ARANGE | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
|
| ARANGE | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ |
|
||||||
| ARGMAX | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
|
| ARGMAX | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ |
|
||||||
| ARGSORT | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
|
| ARGSORT | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
|
||||||
| CEIL | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ❌ |
|
| CEIL | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ |
|
||||||
| CLAMP | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
|
| CLAMP | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | ❌ |
|
||||||
| CONCAT | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | ✅ | ❌ |
|
| CONCAT | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | ✅ | ❌ | ❌ |
|
||||||
| CONT | ❌ | 🟡 | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ❌ |
|
| CONT | ❌ | 🟡 | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | 🟡 | ❌ |
|
||||||
| CONV_2D | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ |
|
| CONV_2D | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ |
|
||||||
| CONV_2D_DW | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
|
| CONV_2D_DW | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||||
| CONV_3D | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
| CONV_3D | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||||
| CONV_TRANSPOSE_1D | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
|
| CONV_TRANSPOSE_1D | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ |
|
||||||
| CONV_TRANSPOSE_2D | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ |
|
| CONV_TRANSPOSE_2D | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||||
| COS | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | 🟡 | 🟡 | ❌ |
|
| COS | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | 🟡 | 🟡 | ❌ | ❌ |
|
||||||
| COUNT_EQUAL | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ |
|
| COUNT_EQUAL | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ |
|
||||||
| CPY | ❌ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
|
| CPY | ❌ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
|
||||||
| CROSS_ENTROPY_LOSS | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
| CROSS_ENTROPY_LOSS | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||||
| CROSS_ENTROPY_LOSS_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
| CROSS_ENTROPY_LOSS_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||||
| CUMSUM | ❌ | ❌ | ✅ | ❌ | ✅ | ❌ | ❌ | ✅ | ❌ |
|
| CUMSUM | ❌ | ❌ | ✅ | ❌ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||||
| DIAG_MASK_INF | ❌ | ✅ | ✅ | ✅ | ❌ | 🟡 | ✅ | ✅ | ❌ |
|
| DIAG_MASK_INF | ❌ | ✅ | ✅ | ✅ | ❌ | 🟡 | ✅ | ✅ | ❌ | ❌ |
|
||||||
| DIV | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ |
|
| DIV | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ✅ | ❌ |
|
||||||
| DUP | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | ✅ | ❌ |
|
| DUP | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | ✅ | ❌ | ❌ |
|
||||||
| ELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | ❌ | ❌ |
|
| ELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | ❌ | ✅ | ❌ |
|
||||||
| EXP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ❌ |
|
| EXP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ |
|
||||||
| EXPM1 | ❌ | ❌ | ✅ | 🟡 | 🟡 | ❌ | ❌ | ❌ | ❌ |
|
| EXPM1 | ❌ | ❌ | ✅ | 🟡 | 🟡 | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||||
| FILL | ❌ | ❌ | ✅ | ❌ | ✅ | ❌ | ❌ | ✅ | ❌ |
|
| FILL | ❌ | ❌ | ✅ | ❌ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||||
| FLASH_ATTN_EXT | ❌ | 🟡 | ✅ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ❌ |
|
| FLASH_ATTN_EXT | ❌ | 🟡 | ✅ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ❌ | ❌ |
|
||||||
| FLOOR | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ❌ |
|
| FLOOR | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ |
|
||||||
| GATED_LINEAR_ATTN | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ |
|
| GATED_LINEAR_ATTN | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
|
||||||
| GEGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ |
|
| GEGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ |
|
||||||
| GEGLU_ERF | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ |
|
| GEGLU_ERF | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ |
|
||||||
| GEGLU_QUICK | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ |
|
| GEGLU_QUICK | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ |
|
||||||
| GELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ❌ |
|
| GELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ |
|
||||||
| GELU_ERF | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ❌ |
|
| GELU_ERF | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ |
|
||||||
| GELU_QUICK | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ❌ |
|
| GELU_QUICK | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ |
|
||||||
| GET_ROWS | ❌ | 🟡 | ✅ | 🟡 | ✅ | 🟡 | 🟡 | 🟡 | ❌ |
|
| GET_ROWS | ❌ | 🟡 | ✅ | 🟡 | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
|
||||||
| GET_ROWS_BACK | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ | ❌ | ❌ | ❌ |
|
| GET_ROWS_BACK | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||||
| GROUP_NORM | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
|
| GROUP_NORM | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
|
||||||
| GROUP_NORM_MUL_ADD | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
| GROUP_NORM_MUL_ADD | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||||
| HARDSIGMOID | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ❌ |
|
| HARDSIGMOID | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ |
|
||||||
| HARDSWISH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ❌ |
|
| HARDSWISH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ |
|
||||||
| IM2COL | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
|
| IM2COL | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
|
||||||
| IM2COL_3D | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
|
| IM2COL_3D | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||||
| L2_NORM | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
|
| L2_NORM | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ |
|
||||||
| LEAKY_RELU | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | 🟡 | ❌ |
|
| LEAKY_RELU | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | 🟡 | ❌ | ❌ |
|
||||||
| LOG | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | 🟡 | ✅ | ❌ |
|
| LOG | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | 🟡 | ✅ | ❌ | ❌ |
|
||||||
| MEAN | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
|
| MEAN | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ |
|
||||||
| MUL | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ |
|
| MUL | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ✅ | ❌ |
|
||||||
| MUL_MAT | 🟡 | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | 🟡 | 🟡 | 🟡 |
|
| MUL_MAT | 🟡 | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 |
|
||||||
| MUL_MAT_ID | ❌ | 🟡 | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ❌ |
|
| MUL_MAT_ID | ❌ | 🟡 | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
|
||||||
| NEG | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ❌ |
|
| NEG | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ |
|
||||||
| NORM | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | 🟡 | ❌ |
|
| NORM | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | 🟡 | ❌ | ❌ |
|
||||||
| NORM_MUL_ADD | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
| NORM_MUL_ADD | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||||
| OPT_STEP_ADAMW | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ |
|
| OPT_STEP_ADAMW | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||||
| OPT_STEP_SGD | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ |
|
| OPT_STEP_SGD | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||||
| OUT_PROD | 🟡 | ❌ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ❌ | ❌ |
|
| OUT_PROD | 🟡 | ❌ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ❌ | ❌ | ❌ |
|
||||||
| PAD | ❌ | ✅ | ✅ | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ |
|
| PAD | ❌ | ✅ | ✅ | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||||
| PAD_REFLECT_1D | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ |
|
| PAD_REFLECT_1D | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ |
|
||||||
| POOL_2D | ❌ | 🟡 | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
|
| POOL_2D | ❌ | 🟡 | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ |
|
||||||
| REGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ |
|
| REGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ |
|
||||||
| RELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ❌ |
|
| RELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ |
|
||||||
| REPEAT | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | 🟡 | ❌ |
|
| REPEAT | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | 🟡 | ❌ | ❌ |
|
||||||
| REPEAT_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ |
|
| REPEAT_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ |
|
||||||
| RMS_NORM | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
|
| RMS_NORM | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||||
| RMS_NORM_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ |
|
| RMS_NORM_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ |
|
||||||
| RMS_NORM_MUL_ADD | ❌ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
|
| RMS_NORM_MUL_ADD | ❌ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ |
|
||||||
| ROLL | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ |
|
| ROLL | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ |
|
||||||
| ROPE | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
|
| ROPE | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||||
| ROPE_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
|
| ROPE_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||||
| ROUND | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ❌ |
|
| ROUND | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ |
|
||||||
| RWKV_WKV6 | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
|
| RWKV_WKV6 | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ |
|
||||||
| RWKV_WKV7 | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
|
| RWKV_WKV7 | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ |
|
||||||
| SCALE | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
|
| SCALE | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||||
| SET | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | 🟡 | ❌ | ❌ |
|
| SET | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | 🟡 | ❌ | ❌ | ❌ |
|
||||||
| SET_ROWS | ❌ | ❌ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
|
| SET_ROWS | ❌ | ❌ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
|
||||||
| SGN | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | ❌ | ❌ |
|
| SGN | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | ❌ | ✅ | ❌ |
|
||||||
| SIGMOID | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ❌ |
|
| SIGMOID | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ |
|
||||||
| SILU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ❌ |
|
| SILU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ |
|
||||||
| SILU_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
|
| SILU_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||||
| SIN | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | 🟡 | 🟡 | ❌ |
|
| SIN | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | 🟡 | 🟡 | ❌ | ❌ |
|
||||||
| SOFTCAP | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
| SOFTCAP | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||||
| SOFTPLUS | ❌ | ❌ | ✅ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ❌ |
|
| SOFTPLUS | ❌ | ❌ | ✅ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ❌ | ❌ |
|
||||||
| SOFT_MAX | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
|
| SOFT_MAX | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
|
||||||
| SOFT_MAX_BACK | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ✅ | ❌ |
|
| SOFT_MAX_BACK | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ✅ | ❌ | ❌ |
|
||||||
| SOLVE_TRI | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | 🟡 | ❌ |
|
| SOLVE_TRI | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | 🟡 | ❌ | ❌ |
|
||||||
| SQR | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | 🟡 | 🟡 | ❌ |
|
| SQR | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | 🟡 | 🟡 | ❌ | ❌ |
|
||||||
| SQRT | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | 🟡 | 🟡 | ❌ |
|
| SQRT | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | 🟡 | 🟡 | ❌ | ❌ |
|
||||||
| SSM_CONV | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
|
| SSM_CONV | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ |
|
||||||
| SSM_SCAN | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | 🟡 | ❌ |
|
| SSM_SCAN | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | 🟡 | ❌ | ❌ |
|
||||||
| STEP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ❌ |
|
| STEP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | ✅ | 🟡 | ✅ | ❌ |
|
||||||
| SUB | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ |
|
| SUB | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ✅ | ❌ |
|
||||||
| SUM | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | 🟡 | ❌ |
|
| SUM | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | 🟡 | ❌ | ❌ |
|
||||||
| SUM_ROWS | ❌ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ |
|
| SUM_ROWS | ❌ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
|
||||||
| SWIGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ |
|
| SWIGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ |
|
||||||
| SWIGLU_OAI | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | 🟡 | ❌ |
|
| SWIGLU_OAI | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | 🟡 | ✅ | ❌ |
|
||||||
| TANH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | 🟡 | ❌ |
|
| TANH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ |
|
||||||
| TIMESTEP_EMBEDDING | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
|
| TIMESTEP_EMBEDDING | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
|
||||||
| TOP_K | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | 🟡 | ❌ |
|
| TOP_K | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | 🟡 | ❌ | ❌ |
|
||||||
| TRI | ❌ | ❌ | ✅ | ❌ | ✅ | ❌ | ❌ | ✅ | ❌ |
|
| TRI | ❌ | ❌ | ✅ | ❌ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ |
|
||||||
| TRUNC | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ❌ |
|
| TRUNC | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ |
|
||||||
| UPSCALE | ❌ | 🟡 | ✅ | ✅ | 🟡 | ✅ | 🟡 | 🟡 | ❌ |
|
| UPSCALE | ❌ | 🟡 | ✅ | ✅ | 🟡 | ✅ | 🟡 | 🟡 | ❌ | ❌ |
|
||||||
| XIELU | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
| XIELU | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ |
|
||||||
|
|
|
||||||
File diff suppressed because it is too large
File diff suppressed because it is too large
@@ -19,6 +19,15 @@ def parse_decls(decls_text):
     return decls


+def replace_repl_placeholders(variant, template_map):
+    for repl, code in variant["REPLS"].items():
+        for key, val in template_map.items():
+            # Match "key" and avoid matching subsequences using by using \b
+            code = re.sub(rf'\b{re.escape(str(key))}\b', str(val), code)
+        variant["REPLS"][repl] = code
+    return variant
+
+
 def replace_placeholders(shader_text, replacements):
     for key, val in replacements.items():
         # Match {{KEY}} literally, where KEY is escaped
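A small, self-contained sketch of what the added `replace_repl_placeholders` does may help. The function body is reproduced from the hunk so the example runs on its own; the sample `templates` and `variant` dictionaries are trimmed-down assumptions modeled on the shader template file in this commit.

```python
import re


def replace_repl_placeholders(variant, template_map):
    # Same logic as the function added in this hunk.
    for repl, code in variant["REPLS"].items():
        for key, val in template_map.items():
            # \b keeps "FUNC" from matching inside "NEG_FUNC" ("_" is a word character)
            code = re.sub(rf'\b{re.escape(str(key))}\b', str(val), code)
        variant["REPLS"][repl] = code
    return variant


templates = {"NEG_FUNC": "{{MUTATE}}[dst_i] = -src[src_i];"}
variant = {"REPLS": {"TYPE": "f32", "FUNC": "NEG_FUNC", "MUTATE": "dst"}}
expanded = replace_repl_placeholders(variant, templates)

# A key that is only a suffix of the value does not match, thanks to the word boundary.
untouched = replace_repl_placeholders({"REPLS": {"FUNC": "NEG_FUNC"}}, {"FUNC": "oops"})
```

After the call, the `FUNC` entry holds the template text rather than the template name, while unrelated entries such as `TYPE` are left alone. One caveat of this approach: `re.sub` interprets backslashes in the replacement string, so templates containing backslashes would need extra care.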
@@ -71,6 +80,10 @@ def generate_variants(fname, input_dir, output_dir, outfile):
         decls_map = parse_decls(extract_block(text, "DECLS"))
     except ValueError:
         decls_map = {}
+    try:
+        templates_map = ast.literal_eval(extract_block(text, "REPL_TEMPLATES"))
+    except ValueError:
+        templates_map = {}

     for fname in sorted(os.listdir(input_dir)):
         if fname.endswith(".tmpl"):
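The optional-block pattern in this hunk can be sketched as follows. The body of `extract_block` is not shown in this diff, so the version below is a hypothetical stand-in that assumes blocks are delimited by `#define(NAME)` and `#end(NAME)` and that a missing block raises `ValueError`; `load_templates` is likewise an illustrative wrapper, not a function from the script.

```python
import ast
import re


def extract_block(text, name):
    # Hypothetical stand-in for the script's helper (its real body is not in this diff).
    pattern = rf'#define\({re.escape(name)}\)(.*?)#end\({re.escape(name)}\)'
    match = re.search(pattern, text, re.DOTALL)
    if match is None:
        raise ValueError(f"block '{name}' not found")
    return match.group(1).strip()


def load_templates(text):
    # Mirrors the try/except in the hunk: a missing block yields an empty map.
    try:
        return ast.literal_eval(extract_block(text, "REPL_TEMPLATES"))
    except ValueError:
        return {}


sample = '''
#define(REPL_TEMPLATES)
{
    "NEG_FUNC": "{{MUTATE}}[dst_i] = -src[src_i];"
}
#end(REPL_TEMPLATES)
'''

templates_map = load_templates(sample)
empty_map = load_templates("// shader without a REPL_TEMPLATES block")
```

`ast.literal_eval` only accepts Python literals, so the block is parsed as data and never executed. Note that a block with broken syntax raises `SyntaxError`, which an `except ValueError` clause does not catch.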
@@ -90,9 +103,11 @@ def generate_variants(fname, input_dir, output_dir, outfile):
                     if key not in decls_map:
                         raise ValueError(f"DECLS key '{key}' not found.")
                     decls_code += decls_map[key] + "\n\n"

                 final_shader = re.sub(r'\bDECLS\b', decls_code, shader_template)
                 if "REPLS" in variant:
+                    variant = replace_repl_placeholders(variant, templates_map)
+                    final_shader = replace_placeholders(final_shader, variant["REPLS"])
+                    # second run to expand placeholders in repl_template
                     final_shader = replace_placeholders(final_shader, variant["REPLS"])
                 final_shader = expand_includes(final_shader, input_dir)

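Why a second replacement run is needed can be seen in a small sketch. The `replace_placeholders` below is a simplified stand-in (plain string replacement of `{{KEY}}`), since the real function's body is only partly visible in this diff; the `repls` dictionary assumes the `FUNC` entry has already been expanded to the `STEP_FUNC` template text.

```python
def replace_placeholders(shader_text, replacements):
    # Simplified stand-in: substitute each {{KEY}} with its value, in dictionary order.
    for key, val in replacements.items():
        shader_text = shader_text.replace("{{" + key + "}}", str(val))
    return shader_text


repls = {
    "TYPE": "f32",
    "FUNC": "{{MUTATE}}[dst_i] = {{TYPE}}(select(0.0, 1.0, src[src_i] > 0.0));",
    "EXT_PARAMS": "",
    "MUTATE": "dst",
}
template = "{{FUNC}}"

once = replace_placeholders(template, repls)   # {{TYPE}} was visited before {{FUNC}} introduced it
twice = replace_placeholders(once, repls)      # the second run resolves the leftover {{TYPE}}
```

Because `TYPE` is processed before `FUNC`, the `{{TYPE}}` placeholder that the template brings in is still present after one pass; the second pass resolves it.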
@@ -0,0 +1,461 @@
+#define(REPL_TEMPLATES)
+
+{
+    "XIELU_FUNC": "{{MUTATE}}[dst_i] = select(((exp(min(src[src_i], {{TYPE}}(params.eps))) - 1.0) - src[src_i]) * {{TYPE}}(params.alpha_n) + {{TYPE}}(params.beta) * src[src_i], {{TYPE}}(params.alpha_p) * src[src_i] * src[src_i] + {{TYPE}}(params.beta) * src[src_i], src[src_i] > 0.0);",
+    "ABS_FUNC": "{{MUTATE}}[dst_i] = abs(src[src_i]);",
+    "SGN_FUNC": "{{MUTATE}}[dst_i] = select({{TYPE}}(select(0.0, -1.0, src[src_i] < 0.0)), {{TYPE}}(1.0), src[src_i] > 0.0);",
+    "NEG_FUNC": "{{MUTATE}}[dst_i] = -src[src_i];",
+    "STEP_FUNC": "{{MUTATE}}[dst_i] = {{TYPE}}(select(0.0, 1.0, src[src_i] > 0.0));",
+    "TANH_FUNC": "{{MUTATE}}[dst_i] = tanh(clamp(src[src_i], -9.010913, 9.010913)); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458",
+    "RELU_FUNC": "{{MUTATE}}[dst_i] = select(0.0, src[src_i], src[src_i] > 0.0);",
+    "ELU_FUNC": "{{MUTATE}}[dst_i] = select(exp(src[src_i]) - 1.0, src[src_i], src[src_i] > 0.0);",
+    "HARDSIGMOID_FUNC": "{{MUTATE}}[dst_i] = min(1.0, max(0.0, (src[src_i] + 3.0) / 6.0));",
+    "SIGMOID_FUNC": "{{MUTATE}}[dst_i] = 1.0 / (1.0 + exp(-src[src_i]));",
+    "SILU_FUNC": "{{MUTATE}}[dst_i] = src[src_i] / (1.0 + exp(-src[src_i]));",
+    "EXP_FUNC": "{{MUTATE}}[dst_i] = exp(src[src_i]);",
+    "HARDSWISH_FUNC": "{{MUTATE}}[dst_i] = src[src_i] * min(1.0, max(0.0, (src[src_i] + 3.0) / 6.0));",
+    "GELU_FUNC": "{{MUTATE}}[dst_i] = 0.5 * src[src_i] * (1.0 + tanh(clamp(sqrt(2.0 / 3.14159265) * (src[src_i] + 0.044715 * pow(src[src_i], 3.0)), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458",
+    "GELU_QUICK_FUNC": "{{MUTATE}}[dst_i] = src[src_i] * 0.5 * (1.0 + tanh(clamp(0.79788456 * (src[src_i] + 0.044715 * src[src_i] * src[src_i] * src[src_i]), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458",
+    "GELU_ERF_FUNC": "{{MUTATE}}[dst_i] = 0.5 * src[src_i] * (1.0 + tanh(clamp(0.79788456 * (src[src_i] + 0.044715 * src[src_i] * src[src_i] * src[src_i]), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458"
+}
+
+#end(REPL_TEMPLATES)
+
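Several of these templates lean on WGSL's `select(f, t, cond)`, which returns `t` when `cond` is true and `f` otherwise. The following Python sketch mirrors a few of the template expressions as scalar reference functions, which can be handy for sanity-checking expected values. The function names and the sample XIELU parameters are illustrative assumptions; this is not the backend's test code.

```python
import math


def select(f, t, cond):
    """WGSL-style select: returns t when cond is true, otherwise f."""
    return t if cond else f


def sgn(x):
    return select(select(0.0, -1.0, x < 0.0), 1.0, x > 0.0)


def step(x):
    return select(0.0, 1.0, x > 0.0)


def relu(x):
    return select(0.0, x, x > 0.0)


def hardsigmoid(x):
    return min(1.0, max(0.0, (x + 3.0) / 6.0))


def xielu(x, alpha_n, alpha_p, beta, eps):
    # Same expression shape as XIELU_FUNC: the first branch applies when x <= 0.
    non_positive = ((math.exp(min(x, eps)) - 1.0) - x) * alpha_n + beta * x
    positive = alpha_p * x * x + beta * x
    return select(non_positive, positive, x > 0.0)
```

Note the argument order: the value for the false case comes first, which is easy to get backwards when porting from a ternary expression.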
|
#define(VARIANTS)
|
||||||
|
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "abs_f32",
|
||||||
|
"REPLS": { "TYPE": "f32", "FUNC": "ABS_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
|
||||||
|
"DECLS": ["NOT_INPLACE"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "abs_f16",
|
||||||
|
"REPLS": { "TYPE": "f16", "FUNC": "ABS_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
|
||||||
|
"DECLS": ["NOT_INPLACE"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "abs_inplace_f32",
|
||||||
|
"REPLS": { "TYPE": "f32", "FUNC": "ABS_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
|
||||||
|
"DECLS": ["INPLACE"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "abs_inplace_f16",
|
||||||
|
"REPLS": { "TYPE": "f16", "FUNC": "ABS_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
|
||||||
|
"DECLS": ["INPLACE"]
|
||||||
|
},
|
||||||
|
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "sgn_f32",
|
||||||
|
"REPLS": { "TYPE": "f32", "FUNC": "SGN_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
|
||||||
|
"DECLS": ["NOT_INPLACE"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "sgn_f16",
|
||||||
|
"REPLS": { "TYPE": "f16", "FUNC": "SGN_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
|
||||||
|
"DECLS": ["NOT_INPLACE"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "sgn_inplace_f32",
|
||||||
|
"REPLS": { "TYPE": "f32", "FUNC": "SGN_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
|
||||||
|
"DECLS": ["INPLACE"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "sgn_inplace_f16",
|
||||||
|
"REPLS": { "TYPE": "f16", "FUNC": "SGN_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
|
||||||
|
"DECLS": ["INPLACE"]
|
||||||
|
},
|
||||||
|
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "neg_f32",
|
||||||
|
"REPLS": { "TYPE": "f32", "FUNC": "NEG_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
|
||||||
|
"DECLS": ["NOT_INPLACE"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "neg_f16",
|
||||||
|
"REPLS": { "TYPE": "f16", "FUNC": "NEG_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
|
||||||
|
"DECLS": ["NOT_INPLACE"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "neg_inplace_f32",
|
||||||
|
"REPLS": { "TYPE": "f32", "FUNC": "NEG_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
|
||||||
|
"DECLS": ["INPLACE"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "neg_inplace_f16",
|
||||||
|
"REPLS": { "TYPE": "f16", "FUNC": "NEG_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
|
||||||
|
"DECLS": ["INPLACE"]
|
||||||
|
},
|
||||||
|
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "step_f32",
|
||||||
|
"REPLS": { "TYPE": "f32", "FUNC": "STEP_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
|
||||||
|
"DECLS": ["NOT_INPLACE"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "step_f16",
|
||||||
|
"REPLS": { "TYPE": "f16", "FUNC": "STEP_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
|
||||||
|
"DECLS": ["NOT_INPLACE"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "step_inplace_f32",
|
||||||
|
"REPLS": { "TYPE": "f32", "FUNC": "STEP_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
|
||||||
|
"DECLS": ["INPLACE"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "step_inplace_f16",
|
||||||
|
"REPLS": { "TYPE": "f16", "FUNC": "STEP_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
|
||||||
|
"DECLS": ["INPLACE"]
|
||||||
|
},
|
||||||
|
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "tanh_f32",
|
||||||
|
"REPLS": { "TYPE": "f32", "FUNC": "TANH_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
|
||||||
|
"DECLS": ["NOT_INPLACE"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "tanh_f16",
|
||||||
|
"REPLS": { "TYPE": "f16", "FUNC": "TANH_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
|
||||||
|
"DECLS": ["NOT_INPLACE"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "tanh_inplace_f32",
|
||||||
|
"REPLS": { "TYPE": "f32", "FUNC": "TANH_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
|
||||||
|
"DECLS": ["INPLACE"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "tanh_inplace_f16",
|
||||||
|
"REPLS": { "TYPE": "f16", "FUNC": "TANH_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
|
||||||
|
"DECLS": ["INPLACE"]
|
||||||
|
},
|
||||||
|
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "elu_f32",
|
||||||
|
"REPLS": { "TYPE": "f32", "FUNC": "ELU_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
|
||||||
|
"DECLS": ["NOT_INPLACE"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "elu_f16",
|
||||||
|
"REPLS": { "TYPE": "f16", "FUNC": "ELU_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
|
||||||
|
"DECLS": ["NOT_INPLACE"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "elu_inplace_f32",
|
||||||
|
"REPLS": { "TYPE": "f32", "FUNC": "ELU_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
|
||||||
|
"DECLS": ["INPLACE"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "elu_inplace_f16",
|
||||||
|
"REPLS": { "TYPE": "f16", "FUNC": "ELU_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
|
||||||
|
"DECLS": ["INPLACE"]
|
||||||
|
},
|
||||||
|
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "relu_f32",
|
||||||
|
"REPLS": { "TYPE": "f32", "FUNC": "RELU_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
|
||||||
|
"DECLS": ["NOT_INPLACE"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "relu_f16",
|
||||||
|
"REPLS": { "TYPE": "f16", "FUNC": "RELU_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
|
||||||
|
"DECLS": ["NOT_INPLACE"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "relu_inplace_f32",
|
||||||
|
"REPLS": { "TYPE": "f32", "FUNC": "RELU_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
|
||||||
|
"DECLS": ["INPLACE"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "relu_inplace_f16",
|
||||||
|
"REPLS": { "TYPE": "f16", "FUNC": "RELU_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
|
||||||
|
"DECLS": ["INPLACE"]
|
||||||
|
},
|
||||||
|
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "sigmoid_f32",
|
||||||
|
"REPLS": { "TYPE": "f32", "FUNC": "SIGMOID_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
|
||||||
|
"DECLS": ["NOT_INPLACE"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "sigmoid_f16",
|
||||||
|
"REPLS": { "TYPE": "f16", "FUNC": "SIGMOID_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
|
||||||
|
"DECLS": ["NOT_INPLACE"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "sigmoid_inplace_f32",
|
||||||
|
"REPLS": { "TYPE": "f32", "FUNC": "SIGMOID_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
|
||||||
|
"DECLS": ["INPLACE"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "sigmoid_inplace_f16",
|
||||||
|
"REPLS": { "TYPE": "f16", "FUNC": "SIGMOID_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
|
||||||
|
"DECLS": ["INPLACE"]
|
||||||
|
},
|
||||||
|
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "silu_f32",
|
||||||
|
"REPLS": { "TYPE": "f32", "FUNC": "SILU_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
|
||||||
|
"DECLS": ["NOT_INPLACE"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "silu_f16",
|
||||||
|
"REPLS": { "TYPE": "f16", "FUNC": "SILU_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
|
||||||
|
"DECLS": ["NOT_INPLACE"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "silu_inplace_f32",
|
||||||
|
"REPLS": { "TYPE": "f32", "FUNC": "SILU_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
|
||||||
|
"DECLS": ["INPLACE"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "silu_inplace_f16",
|
||||||
|
"REPLS": { "TYPE": "f16", "FUNC": "SILU_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
|
||||||
|
"DECLS": ["INPLACE"]
|
||||||
|
},
|
||||||
|
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "exp_f32",
|
||||||
|
"REPLS": { "TYPE": "f32", "FUNC": "EXP_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
|
||||||
|
"DECLS": ["NOT_INPLACE"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "exp_f16",
|
||||||
|
"REPLS": { "TYPE": "f16", "FUNC": "EXP_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
|
||||||
|
"DECLS": ["NOT_INPLACE"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "exp_inplace_f32",
|
||||||
|
"REPLS": { "TYPE": "f32", "FUNC": "EXP_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
|
||||||
|
"DECLS": ["INPLACE"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "exp_inplace_f16",
|
||||||
|
"REPLS": { "TYPE": "f16", "FUNC": "EXP_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
|
||||||
|
"DECLS": ["INPLACE"]
|
||||||
|
},
|
||||||
|
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "hardsigmoid_f32",
|
||||||
|
"REPLS": { "TYPE": "f32", "FUNC": "HARDSIGMOID_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
|
||||||
|
"DECLS": ["NOT_INPLACE"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "hardsigmoid_f16",
|
||||||
|
"REPLS": { "TYPE": "f16", "FUNC": "HARDSIGMOID_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
|
||||||
|
"DECLS": ["NOT_INPLACE"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "hardsigmoid_inplace_f32",
|
||||||
|
"REPLS": { "TYPE": "f32", "FUNC": "HARDSIGMOID_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
|
||||||
|
"DECLS": ["INPLACE"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "hardsigmoid_inplace_f16",
|
||||||
|
"REPLS": { "TYPE": "f16", "FUNC": "HARDSIGMOID_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
|
||||||
|
"DECLS": ["INPLACE"]
|
||||||
|
},
|
||||||
|
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "hardswish_f32",
|
||||||
|
"REPLS": { "TYPE": "f32", "FUNC": "HARDSWISH_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
|
||||||
|
"DECLS": ["NOT_INPLACE"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "hardswish_f16",
|
||||||
|
"REPLS": { "TYPE": "f16", "FUNC": "HARDSWISH_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
|
||||||
|
"DECLS": ["NOT_INPLACE"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "hardswish_inplace_f32",
|
||||||
|
"REPLS": { "TYPE": "f32", "FUNC": "HARDSWISH_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
|
||||||
|
"DECLS": ["INPLACE"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"SHADER_NAME": "hardswish_inplace_f16",
|
||||||
|
"REPLS": { "TYPE": "f16", "FUNC": "HARDSWISH_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
|
||||||
|
"DECLS": ["INPLACE"]
|
||||||
|
},
|
||||||
|
|
||||||
|
  {
    "SHADER_NAME": "gelu_f32",
    "REPLS": { "TYPE": "f32", "FUNC": "GELU_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
    "DECLS": ["NOT_INPLACE"]
  },
  {
    "SHADER_NAME": "gelu_f16",
    "REPLS": { "TYPE": "f16", "FUNC": "GELU_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
    "DECLS": ["NOT_INPLACE"]
  },
  {
    "SHADER_NAME": "gelu_inplace_f32",
    "REPLS": { "TYPE": "f32", "FUNC": "GELU_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
    "DECLS": ["INPLACE"]
  },
  {
    "SHADER_NAME": "gelu_inplace_f16",
    "REPLS": { "TYPE": "f16", "FUNC": "GELU_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
    "DECLS": ["INPLACE"]
  },

  {
    "SHADER_NAME": "gelu_quick_f32",
    "REPLS": { "TYPE": "f32", "FUNC": "GELU_QUICK_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
    "DECLS": ["NOT_INPLACE"]
  },
  {
    "SHADER_NAME": "gelu_quick_f16",
    "REPLS": { "TYPE": "f16", "FUNC": "GELU_QUICK_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
    "DECLS": ["NOT_INPLACE"]
  },
  {
    "SHADER_NAME": "gelu_quick_inplace_f32",
    "REPLS": { "TYPE": "f32", "FUNC": "GELU_QUICK_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
    "DECLS": ["INPLACE"]
  },
  {
    "SHADER_NAME": "gelu_quick_inplace_f16",
    "REPLS": { "TYPE": "f16", "FUNC": "GELU_QUICK_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
    "DECLS": ["INPLACE"]
  },

  {
    "SHADER_NAME": "xielu_f32",
    "REPLS": { "TYPE": "f32", "FUNC": "XIELU_FUNC", "EXT_PARAMS": "alpha_n: f32, alpha_p: f32, beta: f32, eps: f32", "MUTATE": "dst" },
    "DECLS": ["NOT_INPLACE"]
  },
  {
    "SHADER_NAME": "xielu_f16",
    "REPLS": { "TYPE": "f16", "FUNC": "XIELU_FUNC", "EXT_PARAMS": "alpha_n: f32, alpha_p: f32, beta: f32, eps: f32", "MUTATE": "dst" },
    "DECLS": ["NOT_INPLACE"]
  },
  {
    "SHADER_NAME": "xielu_inplace_f32",
    "REPLS": { "TYPE": "f32", "FUNC": "XIELU_FUNC", "EXT_PARAMS": "alpha_n: f32, alpha_p: f32, beta: f32, eps: f32", "MUTATE": "src" },
    "DECLS": ["INPLACE"]
  },
  {
    "SHADER_NAME": "xielu_inplace_f16",
    "REPLS": { "TYPE": "f16", "FUNC": "XIELU_FUNC", "EXT_PARAMS": "alpha_n: f32, alpha_p: f32, beta: f32, eps: f32", "MUTATE": "src" },
    "DECLS": ["INPLACE"]
  },
  {
    "SHADER_NAME": "gelu_erf_f32",
    "REPLS": { "TYPE": "f32", "FUNC": "GELU_ERF_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
    "DECLS": ["NOT_INPLACE"]
  },
  {
    "SHADER_NAME": "gelu_erf_f16",
    "REPLS": { "TYPE": "f16", "FUNC": "GELU_ERF_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" },
    "DECLS": ["NOT_INPLACE"]
  },
  {
    "SHADER_NAME": "gelu_erf_inplace_f32",
    "REPLS": { "TYPE": "f32", "FUNC": "GELU_ERF_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
    "DECLS": ["INPLACE"]
  },
  {
    "SHADER_NAME": "gelu_erf_inplace_f16",
    "REPLS": { "TYPE": "f16", "FUNC": "GELU_ERF_FUNC", "EXT_PARAMS": "", "MUTATE": "src" },
    "DECLS": ["INPLACE"]
  }
]

#end(VARIANTS)

#define(DECLS)

#decl(INPLACE)

@group(0) @binding(1)
var<uniform> params: Params;

#enddecl(INPLACE)

#decl(NOT_INPLACE)

@group(0) @binding(1)
var<storage, read_write> dst: array<{{TYPE}}>;

@group(0) @binding(2)
var<uniform> params: Params;

#enddecl(NOT_INPLACE)

#end(DECLS)

#define(SHADER)

enable f16;

fn update(dst_i: u32, src_i: u32) {
    {{FUNC}}
}

@group(0) @binding(0)
var<storage, read_write> src: array<{{TYPE}}>;

DECLS

struct Params {
    ne: u32,            // total number of elements
    offset_src: u32,    // in elements
    offset_dst: u32,    // in elements

    // Strides (in elements); may be permuted
    stride_src0: u32,
    stride_src1: u32,
    stride_src2: u32,
    stride_src3: u32,

    stride_dst0: u32,
    stride_dst1: u32,
    stride_dst2: u32,
    stride_dst3: u32,

    // Logical shapes
    src_ne0: u32,
    src_ne1: u32,
    src_ne2: u32,

    dst_ne0: u32,
    dst_ne1: u32,
    dst_ne2: u32,

    {{EXT_PARAMS}}
};

override wg_size: u32;
@compute @workgroup_size(wg_size)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    // Invocations past the last element exit early.
    if (gid.x >= params.ne) {
        return;
    }

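    // Worked example of the index math below. The numbers are illustrative
    // assumptions, not values from the source: a transposed view with
    // src_ne0 = 3, src_ne1 = 4, src_ne2 = 2 and strides (4, 1, 12, 24).
    // For gid.x = 17:
    //   i3 = 17 / (2 * 4 * 3) = 0,  remainder 17
    //   i2 = 17 / (4 * 3)     = 1,  remainder 5
    //   i1 = 5 / 3            = 1
    //   i0 = 5 % 3            = 2
    //   src_idx = 2 * 4 + 1 * 1 + 1 * 12 + 0 * 24 = 21
    // With contiguous strides (1, 3, 12, 24) the same coordinates give
    // 2 + 3 + 12 = 17, i.e. src_idx equals gid.x.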
    // Decompose the flat index into 4D coordinates using the source shape.
    var i = gid.x;
    let i3 = i / (params.src_ne2 * params.src_ne1 * params.src_ne0);
    i = i % (params.src_ne2 * params.src_ne1 * params.src_ne0);
    let i2 = i / (params.src_ne1 * params.src_ne0);
    i = i % (params.src_ne1 * params.src_ne0);
    let i1 = i / params.src_ne0;
    let i0 = i % params.src_ne0;

    // Same decomposition using the destination shape.
    var j = gid.x;
    let j3 = j / (params.dst_ne2 * params.dst_ne1 * params.dst_ne0);
    j = j % (params.dst_ne2 * params.dst_ne1 * params.dst_ne0);
    let j2 = j / (params.dst_ne1 * params.dst_ne0);
    j = j % (params.dst_ne1 * params.dst_ne0);
    let j1 = j / params.dst_ne0;
    let j0 = j % params.dst_ne0;

    // Combine the coordinates with the per-dimension strides (in elements).
    let src_idx = i0 * params.stride_src0 + i1 * params.stride_src1 +
                  i2 * params.stride_src2 + i3 * params.stride_src3;

    let dst_idx = j0 * params.stride_dst0 + j1 * params.stride_dst1 +
                  j2 * params.stride_dst2 + j3 * params.stride_dst3;

    update(params.offset_dst + dst_idx, params.offset_src + src_idx);
}

#end(SHADER)