From c7358ddf6458ffaf45dc1d9ca447b354c919fcff Mon Sep 17 00:00:00 2001 From: RachelMantel Date: Fri, 30 Jan 2026 06:00:49 +0200 Subject: [PATCH] sycl: implement GGML_OP_TRI (#19089) * sycl: implement GGML_OP_TRI * docs: update ops.md for SYCL TRI * docs: regenerate ops.md * docs: update SYCL support for GGML_OP_TRI --- docs/ops.md | 2 +- docs/ops/SYCL.csv | 8 ++-- ggml/src/ggml-sycl/ggml-sycl.cpp | 69 ++++++++++++++++++++++++++++++++ 3 files changed, 74 insertions(+), 5 deletions(-) diff --git a/docs/ops.md b/docs/ops.md index c066ab5a85..b8e0347803 100644 --- a/docs/ops.md +++ b/docs/ops.md @@ -114,7 +114,7 @@ Legend: | TANH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ | | TIMESTEP_EMBEDDING | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | | TOP_K | ❌ | ❌ | ✅ | ❌ | ✅ | ❌ | ❌ | 🟡 | ✅ | ❌ | ❌ | -| TRI | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | +| TRI | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | | TRUNC | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ✅ | ❌ | ❌ | | UPSCALE | ❌ | 🟡 | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | ❌ | ❌ | | XIELU | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | diff --git a/docs/ops/SYCL.csv b/docs/ops/SYCL.csv index 091a5caed7..255b4ef473 100644 --- a/docs/ops/SYCL.csv +++ b/docs/ops/SYCL.csv @@ -10052,10 +10052,10 @@ "SYCL0","CUMSUM","type=f32,ne=[375960,1,1,1]","support","0","no","SYCL" "SYCL0","CUMSUM","type=f32,ne=[20481,4,1,1]","support","0","no","SYCL" "SYCL0","XIELU","type=f32,ne=[10,5,4,3]","support","0","no","SYCL" -"SYCL0","TRI","type=f32,ne=[10,10,4,3],tri_type=3","support","0","no","SYCL" -"SYCL0","TRI","type=f32,ne=[10,10,4,3],tri_type=2","support","0","no","SYCL" -"SYCL0","TRI","type=f32,ne=[10,10,4,3],tri_type=1","support","0","no","SYCL" -"SYCL0","TRI","type=f32,ne=[10,10,4,3],tri_type=0","support","0","no","SYCL" +"SYCL0","TRI","type=f32,ne=[10,10,4,3],tri_type=3","support","1","yes","SYCL" +"SYCL0","TRI","type=f32,ne=[10,10,4,3],tri_type=2","support","1","yes","SYCL" +"SYCL0","TRI","type=f32,ne=[10,10,4,3],tri_type=1","support","1","yes","SYCL" +"SYCL0","TRI","type=f32,ne=[10,10,4,3],tri_type=0","support","1","yes","SYCL" "SYCL0","FILL","type=f32,ne=[10,10,4,3],c=0.000000","support","0","no","SYCL" "SYCL0","FILL","type=f32,ne=[303,207,11,3],c=2.000000","support","0","no","SYCL" "SYCL0","FILL","type=f32,ne=[800,600,4,4],c=-152.000000","support","0","no","SYCL" diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index 3a4c092af5..d20b7ec57d 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -2263,6 +2263,65 @@ inline void ggml_sycl_op_diag_mask_inf(ggml_backend_sycl_context & ctx, ggml_ten diag_mask_inf_f32_sycl(src0_dd, dst_dd, ne00, nrows0, ne01, n_past, main_stream); } +static void tri_f32_sycl( + const float * src, + float * dst, + const int64_t ne0, + const int64_t ne1, + const int64_t ne2, + const int64_t ne3, + const ggml_tri_type ttype, + dpct::queue_ptr main_stream +) { + const size_t total = (size_t) ne0 * (size_t) ne1 * (size_t) ne2 * (size_t) ne3; + + main_stream->parallel_for(sycl::range<1>(total), [=](sycl::id<1> tid) { + const int64_t idx = (int64_t) tid[0]; + + const int64_t i0 = idx % ne0; + const int64_t t1 = idx / ne0; + const int64_t i1 = t1 % ne1; + + bool keep = false; + switch (ttype) { + case GGML_TRI_TYPE_LOWER: keep = (i0 < i1); break; + case GGML_TRI_TYPE_LOWER_DIAG: keep = (i0 <= i1); break; + case GGML_TRI_TYPE_UPPER: keep = (i0 > i1); break; + case GGML_TRI_TYPE_UPPER_DIAG: keep = (i0 >= i1); break; + default: keep = false; break; + } + + dst[idx] = keep ? 
src[idx] : 0.0f;
+    });
+}
+
+static void ggml_sycl_op_tri(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    GGML_ASSERT(src0);
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_is_contiguous(dst));
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+
+    dpct::queue_ptr main_stream = ctx.stream();
+    SYCL_CHECK(ggml_sycl_set_device(ctx.device));
+
+    const float * src0_dd = static_cast<const float *>(src0->data);
+    float * dst_dd = static_cast<float *>(dst->data);
+
+    const ggml_tri_type ttype = (ggml_tri_type) ggml_get_op_params_i32(dst, 0);
+
+    const int64_t ne0 = src0->ne[0];
+    const int64_t ne1 = src0->ne[1];
+    const int64_t ne2 = src0->ne[2];
+    const int64_t ne3 = src0->ne[3];
+
+    tri_f32_sycl(src0_dd, dst_dd, ne0, ne1, ne2, ne3, ttype, main_stream);
+}
+
+
 inline void ggml_sycl_op_scale(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
     GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
@@ -3912,6 +3971,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
         case GGML_OP_TRANSPOSE:
             GGML_SYCL_DEBUG("%s: Tensor NO-OP\n", __func__);
             break;
+        case GGML_OP_TRI:
+            ggml_sycl_op_tri(ctx, dst);
+            break;
         case GGML_OP_DIAG_MASK_INF:
             ggml_sycl_diag_mask_inf(ctx, dst);
             break;
@@ -4616,6 +4678,13 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
             return true;
         case GGML_OP_CONT:
             return op->src[0]->type != GGML_TYPE_BF16;
+        case GGML_OP_TRI:
+            {
+                const ggml_tensor * src0 = op->src[0];
+                return src0 &&
+                       op->type == GGML_TYPE_F32 &&
+                       ggml_is_contiguous(src0);
+            }
         case GGML_OP_DIAG_MASK_INF:
             return true;
         case GGML_OP_SOFT_MAX:
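
For reference, the kernel's index math can be checked against a plain host-side loop: the tensor is flattened to a 1-D range, the column (i0) and row (i1) are recovered from the flat index, and `src[idx]` is kept only where the requested triangle predicate holds. The sketch below is illustrative only; the names `tri_type` and `tri_f32_ref` are stand-ins for this example, not the ggml definitions, and the predicate per case mirrors the switch in `tri_f32_sycl` above.

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

// Stand-in for the four triangle modes handled by the SYCL kernel.
enum class tri_type { lower, lower_diag, upper, upper_diag };

// Host reference: same flat-index decomposition and keep/zero logic as the kernel.
static void tri_f32_ref(const float * src, float * dst,
                        int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3,
                        tri_type ttype) {
    const int64_t total = ne0 * ne1 * ne2 * ne3;
    for (int64_t idx = 0; idx < total; ++idx) {
        const int64_t i0 = idx % ne0;          // column within a row
        const int64_t i1 = (idx / ne0) % ne1;  // row within a matrix
        bool keep = false;
        switch (ttype) {
            case tri_type::lower:      keep = i0 <  i1; break;
            case tri_type::lower_diag: keep = i0 <= i1; break;
            case tri_type::upper:      keep = i0 >  i1; break;
            case tri_type::upper_diag: keep = i0 >= i1; break;
        }
        dst[idx] = keep ? src[idx] : 0.0f;
    }
}

int main() {
    // 4x4 matrix of ones -> lower-triangular mask with the diagonal kept.
    std::vector<float> src(16, 1.0f), dst(16, 0.0f);
    tri_f32_ref(src.data(), dst.data(), 4, 4, 1, 1, tri_type::lower_diag);
    for (int64_t r = 0; r < 4; ++r) {
        for (int64_t c = 0; c < 4; ++c) {
            std::printf("%.0f ", dst[r * 4 + c]);
        }
        std::printf("\n");
    }
    return 0;
}
```

The backend-level check is already covered by `test-backend-ops`, which is where the updated `docs/ops/SYCL.csv` rows for the four `tri_type` values come from.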