llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net_chunk_inter...

113 lines
3.7 KiB
Plaintext

#version 450
#extension GL_EXT_control_flow_attributes : require
// Inter-chunk state propagation for chunked gated delta net
//
// Each workgroup walks the chunks of one (head, sequence) pair in order,
// carrying an S_V x S_V state forward from chunk to chunk.

// Side length of the per-head state matrix (state_size = S_V * S_V floats) and
// the number of elements read per token row from k_in / w_in / u_in.
// The inner dot-product loop in main() steps by 4, so S_V must be a multiple of 4.
layout(constant_id = 0) const uint S_V = 128;
// Maximum number of tokens per chunk; also the per-chunk row stride of the
// w / u / vnew / gcum buffers.
layout(constant_id = 1) const uint CHUNK_SIZE = 64;
// local_size_x is tied to the same specialization constant id as S_V (id 0),
// so a workgroup has exactly S_V invocations along x when id 0 is specialized.
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
// Push constants supplied by the host. All strides/offsets are used as indices
// into float arrays, i.e. they are counted in floats, not bytes.
layout(push_constant) uniform Parameters {
// Head count; used as the per-sequence / per-chunk stride when indexing
// state_in, h_out, w_in, u_in, decay_in and gcum_in.
uint H;
// Number of tokens per sequence; bounds the length of the last chunk.
uint n_tokens;
// Not read by this shader (presumably the sequence count; the dispatch's
// gl_WorkGroupID.y is used as the sequence index instead).
uint n_seqs;
// Strides of k_in: sq1 per key head (iq1), sq2 per token, sq3 per key batch (iq3).
uint sq1, sq2, sq3;
// Not read by this shader. NOTE(review): presumably strides of the v tensor,
// kept so the push-constant layout matches sibling shaders - confirm.
uint sv1, sv2, sv3;
// Not read by this shader. NOTE(review): presumably strides of the beta
// tensor - confirm against the host code.
uint sb1, sb2, sb3;
// neq1: key-head index is head_id % neq1.
// rq3:  key-batch index is seq_id / rq3.
uint neq1, rq3;
// Number of chunks processed per sequence (outer loop bound in main()).
// NOTE(review): main() assumes c * CHUNK_SIZE < n_tokens for every chunk c.
uint n_chunks;
// Offset (in floats) into final_out at which the final states are written.
uint s_off;
};
// Keys, addressed as iq3 * sq3 + token * sq2 + iq1 * sq1 + element.
layout(binding = 0) readonly buffer KBuf { float k_in[]; };
// Per-token "w" rows, laid out as [seq][chunk][head][CHUNK_SIZE][S_V].
layout(binding = 1) readonly buffer WBuf { float w_in[]; };
// Per-token "u" rows, same layout as w_in.
layout(binding = 2) readonly buffer UBuf { float u_in[]; };
// One log-space decay value per chunk, laid out as [seq][chunk][head];
// main() applies exp() to it.
layout(binding = 3) readonly buffer DecayBuf { float decay_in[]; };
// Per-token cumulative gate values, laid out as [seq][chunk][head][CHUNK_SIZE].
layout(binding = 4) readonly buffer GCumBuf { float gcum_in[]; };
// Initial state, laid out as [seq][head][S_V][S_V].
layout(binding = 5) readonly buffer StateBuf { float state_in[]; };
// State at the start of every chunk, laid out as [seq][chunk][head][S_V][S_V].
// Written transposed relative to how state_in is read (see main()).
layout(binding = 6) writeonly buffer HBuf { float h_out[]; };
// Per-token "v_new" rows (u minus the state projection), same layout as w_in.
layout(binding = 7) writeonly buffer VNewBuf { float vnew_out[]; };
// Destination for the final state, written at s_off with the state_in layout.
// Declared without an access qualifier; this shader only writes to it.
layout(binding = 8) buffer FinalBuf { float final_out[]; };
// Workgroup-shared staging for the current token: each invocation stores one
// element, then every invocation reads the whole row after a barrier().
// s_w  : the token's row of w_in.
shared float s_w[S_V];
// s_kg : the token's key row scaled by exp(g_total - g_cumsum_t).
shared float s_kg[S_V];
// Propagates the recurrent state across all chunks of one (head, sequence) pair.
//
// Dispatch: gl_WorkGroupID.x = head, gl_WorkGroupID.y = sequence, and
// gl_LocalInvocationID.x = `col`, the state row this invocation owns.
//
// Each invocation keeps S_V state values in a private array, loaded from
// state_in[state_base + col * S_V + i]. For every chunk it:
//   1. stores the chunk-entry state to h_out (transposed: index i * S_V + col),
//   2. for each token t: v_new = u[t][col] - dot(w[t], state), written to
//      vnew_out, and delta[i] += k[t][i] * exp(g_total - gcum[t]) * v_new,
//   3. updates state[i] = exp(g_total) * state[i] + delta[i].
// After the last chunk the state is written to final_out at s_off, using the
// same layout it was read with.
void main() {
    const uint head_id = gl_WorkGroupID.x;
    const uint seq_id  = gl_WorkGroupID.y;
    const uint col     = gl_LocalInvocationID.x;
    // There is deliberately no `col >= S_V` early-out here: local_size_x and S_V
    // share specialization constant id 0, so col < S_V always holds, and a
    // return taken by only some invocations would leave the barrier() calls
    // below in non-uniform control flow.

    const uint iq1 = head_id % neq1;   // key head feeding this head
    const uint iq3 = seq_id / rq3;     // key batch feeding this sequence
    const uint state_size = S_V * S_V;
    const uint state_base = (seq_id * H + head_id) * state_size;

    // Private copy of this invocation's slice of the state.
    float state[S_V];
    [[unroll]] for (uint i = 0; i < S_V; i++) {
        state[i] = state_in[state_base + col * S_V + i];
    }

    for (uint c = 0; c < n_chunks; c++) {
        const uint chunk_start = c * CHUNK_SIZE;
        // NOTE(review): unsigned subtraction; assumes chunk_start < n_tokens.
        const uint chunk_len = min(CHUNK_SIZE, n_tokens - chunk_start);

        // Flat (seq, chunk, head) index shared by every per-chunk buffer.
        const uint chunk_slot = (seq_id * n_chunks + c) * H + head_id;

        // Snapshot the state as it is on entry to this chunk (transposed store).
        const uint h_base = chunk_slot * state_size;
        [[unroll]] for (uint i = 0; i < S_V; i++) {
            h_out[h_base + i * S_V + col] = state[i];
        }

        const uint wu_base   = chunk_slot * CHUNK_SIZE * S_V;
        const uint gcum_base = chunk_slot * CHUNK_SIZE;
        const float g_total  = decay_in[chunk_slot];

        // Accumulated k (x) v_new contribution of this chunk.
        float delta[S_V];
        [[unroll]] for (uint i = 0; i < S_V; i++) {
            delta[i] = 0.0;
        }

        for (uint t = 0; t < chunk_len; t++) {
            // Offset of this invocation's element in token t's w / u / v_new row.
            const uint row = wu_base + t * S_V + col;

            s_w[col] = w_in[row];

            const float g_cumsum_t = gcum_in[gcum_base + t];
            const float decay_factor = exp(g_total - g_cumsum_t);
            const uint t_global = chunk_start + t;
            const uint k_off = iq3 * sq3 + t_global * sq2 + iq1 * sq1;
            s_kg[col] = k_in[k_off + col] * decay_factor;

            // Make the full s_w / s_kg rows visible to every invocation.
            barrier();

            // Project the chunk-entry state onto w[t], four elements at a time.
            float ws = 0.0;
            [[unroll]] for (uint i = 0; i < S_V; i += 4) {
                ws += dot(
                    vec4(s_w[i], s_w[i+1], s_w[i+2], s_w[i+3]),
                    vec4(state[i], state[i+1], state[i+2], state[i+3])
                );
            }

            float vnew = u_in[row] - ws;
            vnew_out[row] = vnew;

            [[unroll]] for (uint i = 0; i < S_V; i++) {
                delta[i] += s_kg[i] * vnew;
            }

            // All reads of s_w / s_kg must finish before the next token overwrites them.
            barrier();
        }

        // Decay the chunk-entry state over the whole chunk and add the new contribution.
        float total_decay = exp(g_total);
        [[unroll]] for (uint i = 0; i < S_V; i++) {
            state[i] = total_decay * state[i] + delta[i];
        }
    }

    // Write final state to dst at s_off
    [[unroll]] for (uint i = 0; i < S_V; i++) {
        final_out[s_off + state_base + col * S_V + i] = state[i];
    }
}