// llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp
#version 450
#extension GL_EXT_control_flow_attributes : enable
#extension GL_KHR_memory_scope_semantics : enable
#pragma use_vulkan_memory_model
#include "types.glsl"
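// Bitonic argsort over matrix rows. This "large" variant keeps its scratch
// (source column, value bits) pairs in a storage buffer (tmp_idx) rather than
// workgroup shared memory, so a row is not limited by shared-memory capacity;
// each workgroup covers BLOCK_SIZE * WG_UNROLL_FACTOR consecutive columns.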
layout(constant_id = 0) const int BLOCK_SIZE = 1024;    // threads per workgroup (bound to local_size_x below)
layout(constant_id = 1) const int WG_UNROLL_FACTOR = 2; // columns processed per thread
#define ASC 0
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
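// data_a:  input values, nrows rows of ncols each
// tmp_idx: per-row scratch of ncols_padded (source column, float bits) pairs
// data_d:  output, the sorted source-column indices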
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) workgroupcoherent buffer B {ivec2 tmp_idx[];};
layout (binding = 2) workgroupcoherent buffer D {int data_d[];};
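// ncols_padded is ncols rounded up to a power of two (2^ncols_padded_log2),
// and order selects ascending (ASC) or descending output. The outer_*/inner_*
// ranges pick which stages (outer, run length k) and steps (inner, compare
// distance j) of the bitonic network this dispatch runs: controlBarrier only
// synchronizes one workgroup, which lets the host split steps whose compare
// distance crosses workgroup spans into separate dispatches.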
layout (push_constant) uniform parameter {
    uint ncols;
    uint ncols_padded;
    uint ncols_padded_log2;
    uint nrows;
    uint order;
    uint outer_start;
    uint outer_end;
    uint inner_start;
    uint inner_end;
} p;
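// Bitonic argsort of one row. needs_bounds_check is a constant at both call
// sites, so the padding checks can be compiled out when ncols is already a
// power of two; padding entries are ordered by their out-of-range source
// column alone, so the values loaded for them never matter.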
void argsort(bool needs_bounds_check, const uint row) {
    // bitonic sort
    int col = int(gl_GlobalInvocationID.x);
    // remap so each workgroup owns a contiguous span of BLOCK_SIZE * WG_UNROLL_FACTOR columns
    col = (col % BLOCK_SIZE) + (col / BLOCK_SIZE) * BLOCK_SIZE * WG_UNROLL_FACTOR;
    const uint row_offset = row * p.ncols;
    uint idx_offset = row * p.ncols_padded;
    bool need_barrier = false; // the first step of a dispatch needs no barrier unless init runs below
    // initialize indices
    if (p.outer_start == 0 && p.inner_start == 0) {
        [[unroll]] for (int u = 0; u < WG_UNROLL_FACTOR; ++u) {
            uint c = u*BLOCK_SIZE + col;
            if (c < p.ncols_padded) {
                ivec2 v = ivec2(c, floatBitsToInt(data_a[row_offset + c]));
                tmp_idx[idx_offset + c] = v;
            }
        }
        need_barrier = true;
    }
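    // Run the requested slice of the bitonic network: stage outer_idx merges
    // runs of length k = 2^(outer_idx + 1) in outer_idx + 1 steps, with the
    // compare distance j halving from k / 2 down to 1.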
    [[unroll]] for (uint outer_idx = p.outer_start, k = (2 << outer_idx); outer_idx < p.outer_end; k *= 2, outer_idx++) {
        uint inner_end = min(p.inner_end, outer_idx + 1);
        for (uint j = k >> (p.inner_start + 1), inner_idx = p.inner_start; inner_idx < inner_end; j /= 2, inner_idx++) {
            if (need_barrier) {
                // tmp_idx lives in a buffer, so the barrier needs buffer memory semantics
                controlBarrier(gl_ScopeWorkgroup, gl_ScopeWorkgroup, gl_StorageSemanticsBuffer, gl_SemanticsAcquireRelease);
            }
            need_barrier = true;
            [[unroll]] for (int u = 0; u < WG_UNROLL_FACTOR; ++u) {
                int c = u*BLOCK_SIZE + col;
                const int ixj = int(c ^ j);
                // each pair is handled once, by the lower of the two partners
                if (ixj < c) {
                    continue;
                }
                // (c & k) picks the compare direction; the smaller value ends up at idx_0
                int idx_0 = (c & k) == 0 ? c : ixj;
                int idx_1 = (c & k) == 0 ? ixj : c;
                ivec2 sh_idx_0 = tmp_idx[idx_offset + idx_0];
                ivec2 sh_idx_1 = tmp_idx[idx_offset + idx_1];
                // padding elements (source column >= ncols) sort to the high end
                bool idx_0_oob = needs_bounds_check ? sh_idx_0.x >= p.ncols : false;
                bool idx_1_oob = needs_bounds_check ? sh_idx_1.x >= p.ncols : false;
                if ((idx_0_oob ||
                     (!idx_1_oob && intBitsToFloat(sh_idx_0.y) > intBitsToFloat(sh_idx_1.y)))) {
                    tmp_idx[idx_offset + idx_0] = sh_idx_1;
                    tmp_idx[idx_offset + idx_1] = sh_idx_0;
                }
            }
        }
    }
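    // Only the dispatch covering the final stage and step writes the result.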
    if (p.outer_end == p.ncols_padded_log2 &&
        p.inner_end >= p.ncols_padded_log2 + 1) {
        controlBarrier(gl_ScopeWorkgroup, gl_ScopeWorkgroup, gl_StorageSemanticsBuffer, gl_SemanticsAcquireRelease);
        [[unroll]] for (int u = 0; u < WG_UNROLL_FACTOR; ++u) {
            uint c = u*BLOCK_SIZE + col;
            if (c < p.ncols) {
                if (p.order == ASC) {
                    data_d[row_offset + c] = tmp_idx[idx_offset + c].x;
                } else {
                    // the scratch ends up sorted ascending; reverse on the way out for descending order
                    data_d[row_offset + p.ncols - c - 1] = tmp_idx[idx_offset + c].x;
                }
            }
        }
    }
}
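// One workgroup row (y) per matrix row, striding by the grid height; the
// branch pins needs_bounds_check to a constant for each argsort call so the
// compiler can specialize both paths.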
void main() {
    if (p.ncols == p.ncols_padded) {
        uint row = gl_WorkGroupID.y;
        while (row < p.nrows) {
            argsort(false, row);
            row += gl_WorkGroupSize.y * gl_NumWorkGroups.y;
        }
    } else {
        uint row = gl_WorkGroupID.y;
        while (row < p.nrows) {
            argsort(true, row);
            row += gl_WorkGroupSize.y * gl_NumWorkGroups.y;
        }
    }
}