// llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp
//
// (file viewer metadata: 18 lines, 466 B, Plaintext)

#version 450
#include "rope_head.glsl"
#include "rope_funcs.glsl"
// Compute-shader entry point for the NeoX-style RoPE variant (per the name of
// the helper it calls). Each invocation turns its global invocation ID into an
// (i0, i1, i2, i3) index tuple and hands it, together with the push constants,
// to rope_neox().
// NOTE(review): `pc` and rope_neox() come from the included rope_head.glsl /
// rope_funcs.glsl, which are not visible here.
void main() {
    // Rows are spread over the x and z dispatch dimensions; every z slice
    // accounts for this many rows.
    // NOTE(review): presumably this must match how the host splits large row
    // counts across z when dispatching — confirm in the Vulkan backend.
    const uint rows_per_z_slice = 32768u;

    const uint row = gl_GlobalInvocationID.z * rows_per_z_slice + gl_GlobalInvocationID.x;

    // Invocations mapped past the last row (e.g. the unfilled tail of the last
    // z slice) have nothing to do.
    if (row >= pc.nrows) {
        return;
    }

    // Decompose the flat row index, i1 varying fastest:
    //   row = i3 * (ne01 * ne02) + i2 * ne01 + i1
    const uint rows_per_i3 = pc.ne01 * pc.ne02;
    const uint i3  = row / rows_per_i3;
    const uint rem = row % rows_per_i3;
    const uint i2  = rem / pc.ne01;
    const uint i1  = rem % pc.ne01;

    // i0 is always even: each step in y advances it by two.
    // NOTE(review): presumably i0 indexes dimension 0 (within a row) — the
    // bound check against it, if any, lives inside rope_neox().
    const uint i0 = gl_GlobalInvocationID.y * 2u;

    rope_neox(i0, i1, i2, i3, pc);
}