diff --git a/docs/ops.md b/docs/ops.md index 1357771442..47534b1401 100644 --- a/docs/ops.md +++ b/docs/ops.md @@ -37,7 +37,7 @@ Legend: | CROSS_ENTROPY_LOSS | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | | CROSS_ENTROPY_LOSS_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | | CUMSUM | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | -| DIAG | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | +| DIAG | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | | DIAG_MASK_INF | ❌ | ✅ | ✅ | ✅ | ❌ | 🟡 | ✅ | ✅ | ❌ | ❌ | ❌ | | DIV | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | | DUP | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | ✅ | ❌ | ❌ | ❌ | @@ -115,7 +115,7 @@ Legend: | TANH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ | | TIMESTEP_EMBEDDING | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | | TOP_K | ❌ | ❌ | ✅ | ❌ | ✅ | ❌ | 🟡 | 🟡 | ✅ | ❌ | ❌ | -| TRI | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | +| TRI | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | | TRUNC | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ✅ | ❌ | ❌ | | UPSCALE | ❌ | 🟡 | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ | ❌ | ❌ | | XIELU | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | diff --git a/docs/ops/WebGPU.csv b/docs/ops/WebGPU.csv index b7761b9dd3..56bae2f3c8 100644 --- a/docs/ops/WebGPU.csv +++ b/docs/ops/WebGPU.csv @@ -10036,17 +10036,17 @@ "WebGPU: WebGPU","CUMSUM","type=f32,ne=[375960,1,1,1]","support","1","yes","WebGPU" "WebGPU: WebGPU","CUMSUM","type=f32,ne=[20481,4,1,1]","support","1","yes","WebGPU" "WebGPU: WebGPU","XIELU","type=f32,ne=[10,5,4,3]","support","1","yes","WebGPU" -"WebGPU: WebGPU","TRI","type=f32,ne=[10,10,4,3],tri_type=3","support","0","no","WebGPU" -"WebGPU: WebGPU","TRI","type=f32,ne=[10,10,4,3],tri_type=2","support","0","no","WebGPU" -"WebGPU: WebGPU","TRI","type=f32,ne=[10,10,4,3],tri_type=1","support","0","no","WebGPU" -"WebGPU: WebGPU","TRI","type=f32,ne=[10,10,4,3],tri_type=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TRI","type=f32,ne=[10,10,4,3],tri_type=3","support","1","yes","WebGPU" +"WebGPU: WebGPU","TRI","type=f32,ne=[10,10,4,3],tri_type=2","support","1","yes","WebGPU" +"WebGPU: WebGPU","TRI","type=f32,ne=[10,10,4,3],tri_type=1","support","1","yes","WebGPU" +"WebGPU: WebGPU","TRI","type=f32,ne=[10,10,4,3],tri_type=0","support","1","yes","WebGPU" "WebGPU: WebGPU","FILL","type=f32,ne=[10,10,4,3],c=0.000000","support","1","yes","WebGPU" "WebGPU: WebGPU","FILL","type=f32,ne=[303,207,11,3],c=2.000000","support","1","yes","WebGPU" "WebGPU: WebGPU","FILL","type=f32,ne=[800,600,4,4],c=-152.000000","support","1","yes","WebGPU" "WebGPU: WebGPU","FILL","type=f32,ne=[2048,512,2,2],c=3.500000","support","1","yes","WebGPU" -"WebGPU: WebGPU","DIAG","type=f32,ne=[10,1,4,3]","support","0","no","WebGPU" -"WebGPU: WebGPU","DIAG","type=f32,ne=[79,1,19,13]","support","0","no","WebGPU" -"WebGPU: WebGPU","DIAG","type=f32,ne=[256,1,8,16]","support","0","no","WebGPU" +"WebGPU: WebGPU","DIAG","type=f32,ne=[10,1,4,3]","support","1","yes","WebGPU" +"WebGPU: WebGPU","DIAG","type=f32,ne=[79,1,19,13]","support","1","yes","WebGPU" +"WebGPU: WebGPU","DIAG","type=f32,ne=[256,1,8,16]","support","1","yes","WebGPU" "WebGPU: WebGPU","SOLVE_TRI","type=f32,ne_lhs=[10,10,4,3],ne_rhs=[3,10,4,3]","support","0","no","WebGPU" "WebGPU: WebGPU","SOLVE_TRI","type=f32,ne_lhs=[11,11,1,1],ne_rhs=[5,11,1,1]","support","0","no","WebGPU" "WebGPU: WebGPU","SOLVE_TRI","type=f32,ne_lhs=[17,17,2,4],ne_rhs=[9,17,2,4]","support","0","no","WebGPU" diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index 3d7e59fddf..ad665e4de9 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -244,13 +244,15 @@ struct ggml_webgpu_binary_pipeline_key_hash { /** Unary **/ struct ggml_webgpu_unary_pipeline_key { - int type; - int op; - bool is_unary; // many unary operators fall under the GGML_OP_UNARY umbrella - bool inplace; + int type; + int op; + bool is_unary; // many unary operators fall under the GGML_OP_UNARY umbrella + bool inplace; + ggml_tri_type ttype; // only used for GGML_OP_TRI bool operator==(const ggml_webgpu_unary_pipeline_key & other) const { - return type == other.type && op == other.op && is_unary == other.is_unary && inplace == other.inplace; + return type == other.type && op == other.op && is_unary == other.is_unary && inplace == other.inplace && + ttype == other.ttype; } }; @@ -261,6 +263,7 @@ struct ggml_webgpu_unary_pipeline_key_hash { ggml_webgpu_hash_combine(seed, key.op); ggml_webgpu_hash_combine(seed, key.is_unary); ggml_webgpu_hash_combine(seed, key.inplace); + ggml_webgpu_hash_combine(seed, key.ttype); return seed; } }; @@ -1058,6 +1061,7 @@ class ggml_webgpu_shader_lib { .op = op, .is_unary = is_unary, .inplace = context.inplace, + .ttype = (ggml_tri_type) ggml_get_op_params_i32(context.dst, 0), }; auto it = unary_pipelines.find(key); @@ -1088,6 +1092,29 @@ class ggml_webgpu_shader_lib { variant += "_inplace"; } + if (op == GGML_OP_TRI) { + switch (key.ttype) { + case GGML_TRI_TYPE_LOWER: + defines.push_back("TRI_TYPE_LOWER"); + variant += "_tri_type_lower"; + break; + case GGML_TRI_TYPE_LOWER_DIAG: + defines.push_back("TRI_TYPE_LOWER_DIAG"); + variant += "_tri_type_lower_diag"; + break; + case GGML_TRI_TYPE_UPPER: + defines.push_back("TRI_TYPE_UPPER"); + variant += "_tri_type_upper"; + break; + case GGML_TRI_TYPE_UPPER_DIAG: + defines.push_back("TRI_TYPE_UPPER_DIAG"); + variant += "_tri_upper_diag"; + break; + default: + GGML_ABORT("Unsupported ggml_tri_type for unary shader"); + } + } + defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); auto processed = preprocessor.preprocess(wgsl_unary, defines); diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 3976a171d1..4b0eeac0f4 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -2209,6 +2209,8 @@ static std::optional ggml_webgpu_encode_node(webgpu_context ctx, case GGML_OP_SQRT: case GGML_OP_SIN: case GGML_OP_COS: + case GGML_OP_DIAG: + case GGML_OP_TRI: return ggml_webgpu_unary_op(ctx, src0, node); case GGML_OP_PAD: return ggml_webgpu_pad(ctx, src0, node); @@ -3201,6 +3203,12 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const case GGML_OP_COS: supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type); break; + case GGML_OP_DIAG: + supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type); + break; + case GGML_OP_TRI: + supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type); + break; case GGML_OP_PAD: supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32; break; diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl index feaf6d0ac2..21beb9bb94 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl @@ -5,7 +5,6 @@ enable f16; #define TYPE f32 #endif - @group(0) @binding(0) var src: array; @@ -57,12 +56,20 @@ fn main(@builtin(global_invocation_id) gid: vec3) { return; } var i = gid.x; - let i3 = i / (params.ne2 * params.ne1 * params.ne0); - i = i % (params.ne2 * params.ne1 * params.ne0); - let i2 = i / (params.ne1 * params.ne0); - i = i % (params.ne1 * params.ne0); - let i1 = i / params.ne0; - let i0 = i % params.ne0; + let ne2 = params.ne2; +#ifdef DIAG + let ne1 = params.ne0; +#else + let ne1 = params.ne1; +#endif + let ne0 = params.ne0; + + let i3 = i / (ne2 * ne1 * ne0); + i = i % (ne2 * ne1 * ne0); + let i2 = i / (ne1 * ne0); + i = i % (ne1 * ne0); + let i1 = i / ne0; + let i0 = i % ne0; let src_idx = i0 * params.stride_src0 + i1 * params.stride_src1 + i2 * params.stride_src2 + i3 * params.stride_src3; @@ -184,6 +191,20 @@ fn main(@builtin(global_invocation_id) gid: vec3) { let res_f32 = cos(f32(src[params.offset_src + src_idx])); let res = TYPE(res_f32); #endif +#ifdef DIAG + let res = select(0.0, src[params.offset_src + i0 + i2 * params.stride_src2 + i3 * params.stride_src3], i0 == i1); +#endif +#ifdef TRI +#ifdef TRI_TYPE_LOWER + let res = select(0.0, src[params.offset_src + src_idx], i0 < i1); +#elif TRI_TYPE_LOWER_DIAG + let res = select(0.0, src[params.offset_src + src_idx], i0 <= i1); +#elif TRI_TYPE_UPPER + let res = select(0.0, src[params.offset_src + src_idx], i0 > i1); +#elif TRI_TYPE_UPPER_DIAG + let res = select(0.0, src[params.offset_src + src_idx], i0 >= i1); +#endif +#endif #ifdef INPLACE src[params.offset_src + src_idx] = res;