ggml-webgpu: Add supports for `DIAG` and `TRI` (#20664)

* Add supports for DIAG and TRI.

* Remove extra ttype and add a comment for TRI op.
This commit is contained in:
Masashi Yoshimura 2026-03-19 13:08:35 +09:00 committed by GitHub
parent 07ba6d275b
commit ea01d196d7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 77 additions and 21 deletions

View File

@ -37,7 +37,7 @@ Legend:
| CROSS_ENTROPY_LOSS | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| CROSS_ENTROPY_LOSS_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| CUMSUM | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ |
| DIAG | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | | ❌ | ❌ |
| DIAG | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | | ❌ | ❌ |
| DIAG_MASK_INF | ❌ | ✅ | ✅ | ✅ | ❌ | 🟡 | ✅ | ✅ | ❌ | ❌ | ❌ |
| DIV | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
| DUP | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | ✅ | ❌ | ❌ | ❌ |
@ -115,7 +115,7 @@ Legend:
| TANH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| TIMESTEP_EMBEDDING | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ |
| TOP_K | ❌ | ❌ | ✅ | ❌ | ✅ | ❌ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
| TRI | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | | ❌ | ❌ |
| TRI | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | | ❌ | ❌ |
| TRUNC | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
| UPSCALE | ❌ | 🟡 | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ | ❌ | ❌ |
| XIELU | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ |

View File

@ -10036,17 +10036,17 @@
"WebGPU: WebGPU","CUMSUM","type=f32,ne=[375960,1,1,1]","support","1","yes","WebGPU"
"WebGPU: WebGPU","CUMSUM","type=f32,ne=[20481,4,1,1]","support","1","yes","WebGPU"
"WebGPU: WebGPU","XIELU","type=f32,ne=[10,5,4,3]","support","1","yes","WebGPU"
"WebGPU: WebGPU","TRI","type=f32,ne=[10,10,4,3],tri_type=3","support","0","no","WebGPU"
"WebGPU: WebGPU","TRI","type=f32,ne=[10,10,4,3],tri_type=2","support","0","no","WebGPU"
"WebGPU: WebGPU","TRI","type=f32,ne=[10,10,4,3],tri_type=1","support","0","no","WebGPU"
"WebGPU: WebGPU","TRI","type=f32,ne=[10,10,4,3],tri_type=0","support","0","no","WebGPU"
"WebGPU: WebGPU","TRI","type=f32,ne=[10,10,4,3],tri_type=3","support","1","yes","WebGPU"
"WebGPU: WebGPU","TRI","type=f32,ne=[10,10,4,3],tri_type=2","support","1","yes","WebGPU"
"WebGPU: WebGPU","TRI","type=f32,ne=[10,10,4,3],tri_type=1","support","1","yes","WebGPU"
"WebGPU: WebGPU","TRI","type=f32,ne=[10,10,4,3],tri_type=0","support","1","yes","WebGPU"
"WebGPU: WebGPU","FILL","type=f32,ne=[10,10,4,3],c=0.000000","support","1","yes","WebGPU"
"WebGPU: WebGPU","FILL","type=f32,ne=[303,207,11,3],c=2.000000","support","1","yes","WebGPU"
"WebGPU: WebGPU","FILL","type=f32,ne=[800,600,4,4],c=-152.000000","support","1","yes","WebGPU"
"WebGPU: WebGPU","FILL","type=f32,ne=[2048,512,2,2],c=3.500000","support","1","yes","WebGPU"
"WebGPU: WebGPU","DIAG","type=f32,ne=[10,1,4,3]","support","0","no","WebGPU"
"WebGPU: WebGPU","DIAG","type=f32,ne=[79,1,19,13]","support","0","no","WebGPU"
"WebGPU: WebGPU","DIAG","type=f32,ne=[256,1,8,16]","support","0","no","WebGPU"
"WebGPU: WebGPU","DIAG","type=f32,ne=[10,1,4,3]","support","1","yes","WebGPU"
"WebGPU: WebGPU","DIAG","type=f32,ne=[79,1,19,13]","support","1","yes","WebGPU"
"WebGPU: WebGPU","DIAG","type=f32,ne=[256,1,8,16]","support","1","yes","WebGPU"
"WebGPU: WebGPU","SOLVE_TRI","type=f32,ne_lhs=[10,10,4,3],ne_rhs=[3,10,4,3]","support","0","no","WebGPU"
"WebGPU: WebGPU","SOLVE_TRI","type=f32,ne_lhs=[11,11,1,1],ne_rhs=[5,11,1,1]","support","0","no","WebGPU"
"WebGPU: WebGPU","SOLVE_TRI","type=f32,ne_lhs=[17,17,2,4],ne_rhs=[9,17,2,4]","support","0","no","WebGPU"

Can't render this file because it is too large.

View File

@ -244,13 +244,15 @@ struct ggml_webgpu_binary_pipeline_key_hash {
/** Unary **/
struct ggml_webgpu_unary_pipeline_key {
int type;
int op;
bool is_unary; // many unary operators fall under the GGML_OP_UNARY umbrella
bool inplace;
int type;
int op;
bool is_unary; // many unary operators fall under the GGML_OP_UNARY umbrella
bool inplace;
ggml_tri_type ttype; // only used for GGML_OP_TRI
bool operator==(const ggml_webgpu_unary_pipeline_key & other) const {
return type == other.type && op == other.op && is_unary == other.is_unary && inplace == other.inplace;
return type == other.type && op == other.op && is_unary == other.is_unary && inplace == other.inplace &&
ttype == other.ttype;
}
};
@ -261,6 +263,7 @@ struct ggml_webgpu_unary_pipeline_key_hash {
ggml_webgpu_hash_combine(seed, key.op);
ggml_webgpu_hash_combine(seed, key.is_unary);
ggml_webgpu_hash_combine(seed, key.inplace);
ggml_webgpu_hash_combine(seed, key.ttype);
return seed;
}
};
@ -1058,6 +1061,7 @@ class ggml_webgpu_shader_lib {
.op = op,
.is_unary = is_unary,
.inplace = context.inplace,
.ttype = (ggml_tri_type) ggml_get_op_params_i32(context.dst, 0),
};
auto it = unary_pipelines.find(key);
@ -1088,6 +1092,29 @@ class ggml_webgpu_shader_lib {
variant += "_inplace";
}
if (op == GGML_OP_TRI) {
switch (key.ttype) {
case GGML_TRI_TYPE_LOWER:
defines.push_back("TRI_TYPE_LOWER");
variant += "_tri_type_lower";
break;
case GGML_TRI_TYPE_LOWER_DIAG:
defines.push_back("TRI_TYPE_LOWER_DIAG");
variant += "_tri_type_lower_diag";
break;
case GGML_TRI_TYPE_UPPER:
defines.push_back("TRI_TYPE_UPPER");
variant += "_tri_type_upper";
break;
case GGML_TRI_TYPE_UPPER_DIAG:
defines.push_back("TRI_TYPE_UPPER_DIAG");
variant += "_tri_upper_diag";
break;
default:
GGML_ABORT("Unsupported ggml_tri_type for unary shader");
}
}
defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
auto processed = preprocessor.preprocess(wgsl_unary, defines);

View File

@ -2209,6 +2209,8 @@ static std::optional<webgpu_command> ggml_webgpu_encode_node(webgpu_context ctx,
case GGML_OP_SQRT:
case GGML_OP_SIN:
case GGML_OP_COS:
case GGML_OP_DIAG:
case GGML_OP_TRI:
return ggml_webgpu_unary_op(ctx, src0, node);
case GGML_OP_PAD:
return ggml_webgpu_pad(ctx, src0, node);
@ -3201,6 +3203,12 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
case GGML_OP_COS:
supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type);
break;
case GGML_OP_DIAG:
supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type);
break;
case GGML_OP_TRI:
supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type);
break;
case GGML_OP_PAD:
supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32;
break;

View File

@ -5,7 +5,6 @@ enable f16;
#define TYPE f32
#endif
@group(0) @binding(0)
var<storage, read_write> src: array<TYPE>;
@ -57,12 +56,20 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
return;
}
var i = gid.x;
let i3 = i / (params.ne2 * params.ne1 * params.ne0);
i = i % (params.ne2 * params.ne1 * params.ne0);
let i2 = i / (params.ne1 * params.ne0);
i = i % (params.ne1 * params.ne0);
let i1 = i / params.ne0;
let i0 = i % params.ne0;
let ne2 = params.ne2;
#ifdef DIAG
let ne1 = params.ne0;
#else
let ne1 = params.ne1;
#endif
let ne0 = params.ne0;
let i3 = i / (ne2 * ne1 * ne0);
i = i % (ne2 * ne1 * ne0);
let i2 = i / (ne1 * ne0);
i = i % (ne1 * ne0);
let i1 = i / ne0;
let i0 = i % ne0;
let src_idx = i0 * params.stride_src0 + i1 * params.stride_src1 +
i2 * params.stride_src2 + i3 * params.stride_src3;
@ -184,6 +191,20 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let res_f32 = cos(f32(src[params.offset_src + src_idx]));
let res = TYPE(res_f32);
#endif
#ifdef DIAG
let res = select(0.0, src[params.offset_src + i0 + i2 * params.stride_src2 + i3 * params.stride_src3], i0 == i1);
#endif
#ifdef TRI
#ifdef TRI_TYPE_LOWER
let res = select(0.0, src[params.offset_src + src_idx], i0 < i1);
#elif TRI_TYPE_LOWER_DIAG
let res = select(0.0, src[params.offset_src + src_idx], i0 <= i1);
#elif TRI_TYPE_UPPER
let res = select(0.0, src[params.offset_src + src_idx], i0 > i1);
#elif TRI_TYPE_UPPER_DIAG
let res = select(0.0, src[params.offset_src + src_idx], i0 >= i1);
#endif
#endif
#ifdef INPLACE
src[params.offset_src + src_idx] = res;