sycl: implement GGML_OP_TRI (#19089)
* sycl: implement GGML_OP_TRI
* docs: update ops.md for SYCL TRI
* docs: regenerate ops.md
* docs: update SYCL support for GGML_OP_TRI
parent d284baf1b5
commit c7358ddf64

@@ -114,7 +114,7 @@ Legend:
 | TANH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
 | TIMESTEP_EMBEDDING | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ |
 | TOP_K | ❌ | ❌ | ✅ | ❌ | ✅ | ❌ | ❌ | 🟡 | ✅ | ❌ | ❌ |
-| TRI | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
+| TRI | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
 | TRUNC | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
 | UPSCALE | ❌ | 🟡 | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | ❌ | ❌ |
 | XIELU | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |

@@ -10052,10 +10052,10 @@
 "SYCL0","CUMSUM","type=f32,ne=[375960,1,1,1]","support","0","no","SYCL"
 "SYCL0","CUMSUM","type=f32,ne=[20481,4,1,1]","support","0","no","SYCL"
 "SYCL0","XIELU","type=f32,ne=[10,5,4,3]","support","0","no","SYCL"
-"SYCL0","TRI","type=f32,ne=[10,10,4,3],tri_type=3","support","0","no","SYCL"
-"SYCL0","TRI","type=f32,ne=[10,10,4,3],tri_type=2","support","0","no","SYCL"
-"SYCL0","TRI","type=f32,ne=[10,10,4,3],tri_type=1","support","0","no","SYCL"
-"SYCL0","TRI","type=f32,ne=[10,10,4,3],tri_type=0","support","0","no","SYCL"
+"SYCL0","TRI","type=f32,ne=[10,10,4,3],tri_type=3","support","1","yes","SYCL"
+"SYCL0","TRI","type=f32,ne=[10,10,4,3],tri_type=2","support","1","yes","SYCL"
+"SYCL0","TRI","type=f32,ne=[10,10,4,3],tri_type=1","support","1","yes","SYCL"
+"SYCL0","TRI","type=f32,ne=[10,10,4,3],tri_type=0","support","1","yes","SYCL"
 "SYCL0","FILL","type=f32,ne=[10,10,4,3],c=0.000000","support","0","no","SYCL"
 "SYCL0","FILL","type=f32,ne=[303,207,11,3],c=2.000000","support","0","no","SYCL"
 "SYCL0","FILL","type=f32,ne=[800,600,4,4],c=-152.000000","support","0","no","SYCL"
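
The CSV rows above are from the regenerated per-op support listing for the SYCL device: each TRI case ("type=f32,ne=[10,10,4,3]" with tri_type 0–3) flips from "support","0","no" to "support","1","yes" now that the backend implements the op. As far as I know, listings like this are produced by ggml's test-backend-ops tool (something along the lines of `test-backend-ops support -o TRI`, though the exact mode and flags are an assumption on my part, not something recorded in this diff), and the ops.md table above is regenerated from them, which matches the "docs: regenerate ops.md" line in the commit message.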

@@ -2263,6 +2263,65 @@ inline void ggml_sycl_op_diag_mask_inf(ggml_backend_sycl_context & ctx, ggml_ten
     diag_mask_inf_f32_sycl(src0_dd, dst_dd, ne00, nrows0, ne01, n_past, main_stream);
 }
 
+static void tri_f32_sycl(
+        const float * src,
+        float * dst,
+        const int64_t ne0,
+        const int64_t ne1,
+        const int64_t ne2,
+        const int64_t ne3,
+        const ggml_tri_type ttype,
+        dpct::queue_ptr main_stream
+) {
+    const size_t total = (size_t) ne0 * (size_t) ne1 * (size_t) ne2 * (size_t) ne3;
+
+    main_stream->parallel_for(sycl::range<1>(total), [=](sycl::id<1> tid) {
+        const int64_t idx = (int64_t) tid[0];
+
+        const int64_t i0 = idx % ne0;
+        const int64_t t1 = idx / ne0;
+        const int64_t i1 = t1 % ne1;
+
+        bool keep = false;
+        switch (ttype) {
+            case GGML_TRI_TYPE_LOWER:      keep = (i0 <  i1); break;
+            case GGML_TRI_TYPE_LOWER_DIAG: keep = (i0 <= i1); break;
+            case GGML_TRI_TYPE_UPPER:      keep = (i0 >  i1); break;
+            case GGML_TRI_TYPE_UPPER_DIAG: keep = (i0 >= i1); break;
+            default:                       keep = false;      break;
+        }
+
+        dst[idx] = keep ? src[idx] : 0.0f;
+    });
+}
+
+static void ggml_sycl_op_tri(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    GGML_ASSERT(src0);
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_is_contiguous(dst));
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+
+    dpct::queue_ptr main_stream = ctx.stream();
+    SYCL_CHECK(ggml_sycl_set_device(ctx.device));
+
+    const float * src0_dd = static_cast<const float *>(src0->data);
+    float * dst_dd = static_cast<float *>(dst->data);
+
+    const ggml_tri_type ttype = (ggml_tri_type) ggml_get_op_params_i32(dst, 0);
+
+    const int64_t ne0 = src0->ne[0];
+    const int64_t ne1 = src0->ne[1];
+    const int64_t ne2 = src0->ne[2];
+    const int64_t ne3 = src0->ne[3];
+
+    tri_f32_sycl(src0_dd, dst_dd, ne0, ne1, ne2, ne3, ttype, main_stream);
+}
+
+
 inline void ggml_sycl_op_scale(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
     GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
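
The kernel above flattens the tensor into a single 1-D range of ne0·ne1·ne2·ne3 work items and recovers only the column (i0 = idx % ne0) and row (i1 = (idx / ne0) % ne1), since the triangular mask depends solely on the position inside each 2-D slice; the i2/i3 coordinates never need to be computed. For sanity-checking the masking rule without a SYCL toolchain, here is a small host-side reference that mirrors the same logic. It is not part of the commit, and it uses a local stand-in enum rather than ggml's ggml_tri_type:

// Host-side reference for the masking rule in tri_f32_sycl above.
// Standalone sketch, not part of the commit: 'tri_type_ref' is a local
// stand-in for the four ggml_tri_type cases handled by the kernel.
#include <cstdint>
#include <cstdio>
#include <vector>

enum tri_type_ref { LOWER, LOWER_DIAG, UPPER, UPPER_DIAG };

static void tri_ref(const float * src, float * dst, int64_t ne0, int64_t ne1, tri_type_ref t) {
    for (int64_t i1 = 0; i1 < ne1; ++i1) {        // row
        for (int64_t i0 = 0; i0 < ne0; ++i0) {    // column
            bool keep = false;
            switch (t) {
                case LOWER:      keep = i0 <  i1; break;  // strictly below the diagonal
                case LOWER_DIAG: keep = i0 <= i1; break;  // below, diagonal included
                case UPPER:      keep = i0 >  i1; break;  // strictly above the diagonal
                case UPPER_DIAG: keep = i0 >= i1; break;  // above, diagonal included
            }
            dst[i1 * ne0 + i0] = keep ? src[i1 * ne0 + i0] : 0.0f;
        }
    }
}

int main() {
    const int64_t n = 4;
    std::vector<float> src(n * n, 1.0f), dst(n * n, 0.0f);
    tri_ref(src.data(), dst.data(), n, n, LOWER_DIAG);
    for (int64_t i1 = 0; i1 < n; ++i1) {
        for (int64_t i0 = 0; i0 < n; ++i0) {
            std::printf("%.0f ", dst[i1 * n + i0]);
        }
        std::printf("\n");
    }
    // Expected output (diagonal kept, strict upper part zeroed):
    // 1 0 0 0
    // 1 1 0 0
    // 1 1 1 0
    // 1 1 1 1
    return 0;
}

Applying the same predicate per 2-D slice to a 4x4 slice of ones gives the lower-triangular pattern shown in the trailing comment, which is what the SYCL kernel is expected to produce for GGML_TRI_TYPE_LOWER_DIAG.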

@@ -3912,6 +3971,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
         case GGML_OP_TRANSPOSE:
             GGML_SYCL_DEBUG("%s: Tensor NO-OP\n", __func__);
            break;
+        case GGML_OP_TRI:
+            ggml_sycl_op_tri(ctx, dst);
+            break;
         case GGML_OP_DIAG_MASK_INF:
             ggml_sycl_diag_mask_inf(ctx, dst);
             break;

@@ -4616,6 +4678,13 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
             return true;
         case GGML_OP_CONT:
             return op->src[0]->type != GGML_TYPE_BF16;
+        case GGML_OP_TRI:
+            {
+                const ggml_tensor * src0 = op->src[0];
+                return src0 &&
+                       op->type == GGML_TYPE_F32 &&
+                       ggml_is_contiguous(src0);
+            }
         case GGML_OP_DIAG_MASK_INF:
             return true;
         case GGML_OP_SOFT_MAX: