#version 450 #include "rte.glsl" #include "types.glsl" #include "generic_unary_head.glsl" #define GGML_TRI_TYPE_UPPER_DIAG 0 #define GGML_TRI_TYPE_UPPER 1 #define GGML_TRI_TYPE_LOWER_DIAG 2 #define GGML_TRI_TYPE_LOWER 3 layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; void main() { const uint idx = get_idx(); if (idx >= p.ne) { return; } const uint i03 = fastdiv(idx, p.ne0_012mp, p.ne0_012L); const uint i03_offset = i03 * p.ne02*p.ne01*p.ne00; const uint i02 = fastdiv(idx - i03_offset, p.ne0_01mp, p.ne0_01L); const uint i02_offset = i02*p.ne01*p.ne00; const uint i01 = fastdiv(idx - i03_offset - i02_offset, p.ne0_0mp, p.ne0_0L); const uint i00 = idx - i03_offset - i02_offset - i01*p.ne00; int param = floatBitsToInt(p.param1); bool pass = false; switch (param) { case GGML_TRI_TYPE_UPPER_DIAG: pass = i00 >= i01; break; case GGML_TRI_TYPE_UPPER: pass = i00 > i01; break; case GGML_TRI_TYPE_LOWER_DIAG: pass = i00 <= i01; break; case GGML_TRI_TYPE_LOWER: pass = i00 < i01; break; } if (pass) { const float val = float(data_a[get_aoffset() + src0_idx(idx)]); data_d[get_doffset() + dst_idx(idx)] = D_TYPE(val); } else { data_d[get_doffset() + dst_idx(idx)] = D_TYPE(0); } }