44 lines
1.3 KiB
Plaintext
44 lines
1.3 KiB
Plaintext
#version 450
|
|
|
|
#include "rte.glsl"
|
|
#include "types.glsl"
|
|
#include "generic_unary_head.glsl"
|
|
|
|
// Triangle selector values, compared against the integer decoded from
// p.param1 in main(). Each one picks the predicate on (i00, i01) that decides
// whether a source element is copied or replaced by zero.
// NOTE(review): presumably these mirror a host-side tri-type enum; confirm the
// numeric values stay in sync with it.
#define GGML_TRI_TYPE_UPPER_DIAG 0  // keep where i00 >= i01
#define GGML_TRI_TYPE_UPPER      1  // keep where i00 >  i01
#define GGML_TRI_TYPE_LOWER_DIAG 2  // keep where i00 <= i01
#define GGML_TRI_TYPE_LOWER      3  // keep where i00 <  i01
|
|
|
|
// Workgroup dimensions: 512 invocations along x, 1 along y and 1 along z.
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
|
|
|
|
// Triangular mask: each invocation handles one flat element index. It copies
// the source element to the destination when the element's (i0, i1) position
// satisfies the predicate selected by p.param1, and writes zero otherwise.
void main() {
    const uint idx = get_idx();

    // Invocations past the element count have nothing to do.
    if (idx >= p.ne) {
        return;
    }

    // Peel the flat index apart one dimension at a time, keeping a running
    // remainder. fastdiv and its (mp, L) operands come from the included
    // headers; presumably a precomputed multiply/shift division -- confirm.
    uint rem = idx;

    const uint i3 = fastdiv(rem, p.ne0_012mp, p.ne0_012L);
    rem -= i3 * p.ne02 * p.ne01 * p.ne00;

    const uint i2 = fastdiv(rem, p.ne0_01mp, p.ne0_01L);
    rem -= i2 * p.ne01 * p.ne00;

    // i1 / i0 are the two innermost coordinates (presumably row / column of
    // each matrix -- confirm against the host-side tensor layout).
    const uint i1 = fastdiv(rem, p.ne0_0mp, p.ne0_0L);
    const uint i0 = rem - i1 * p.ne00;

    // The selector is stored in a float push-constant slot; reinterpret its
    // bits as an int rather than converting the value.
    const int tri_type = floatBitsToInt(p.param1);

    // An unrecognised selector keeps nothing, so the output is all zeros.
    bool keep = false;
    if (tri_type == GGML_TRI_TYPE_UPPER_DIAG) {
        keep = i0 >= i1;
    } else if (tri_type == GGML_TRI_TYPE_UPPER) {
        keep = i0 > i1;
    } else if (tri_type == GGML_TRI_TYPE_LOWER_DIAG) {
        keep = i0 <= i1;
    } else if (tri_type == GGML_TRI_TYPE_LOWER) {
        keep = i0 < i1;
    }

    // The source is read only for kept elements; everything else stores zero.
    D_TYPE result = D_TYPE(0);
    if (keep) {
        const float val = float(data_a[get_aoffset() + src0_idx(idx)]);
        result = D_TYPE(val);
    }
    data_d[get_doffset() + dst_idx(idx)] = result;
}
|