113 lines
3.7 KiB
Plaintext
113 lines
3.7 KiB
Plaintext
#version 450

// Provides the [[unroll]] loop attribute used in main().
#extension GL_EXT_control_flow_attributes : require

// Inter-chunk state propagation for chunked gated delta net

// Specialization constants.
//   S_V        - side length of the square per-head state (S_V x S_V floats).
//                It is also the workgroup width: local_size_x below is bound
//                to the same constant_id, so there is exactly one invocation
//                per state row handled in main(). main() consumes it four
//                elements at a time, so it must be a multiple of 4.
//   CHUNK_SIZE - maximum number of tokens processed per chunk.
layout(constant_id = 0) const uint S_V = 128;
layout(constant_id = 1) const uint CHUNK_SIZE = 64;

layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;

// Push constants. Member order and types define the host-visible layout and
// must not be changed. Usage notes below describe how main() in this file
// consumes each field; fields marked "unused" are not referenced here.
layout(push_constant) uniform Parameters {
    uint H;              // heads per sequence; used in every (seq, head) / (seq, chunk, head) index
    uint n_tokens;       // tokens per sequence; bounds the length of the last chunk
    uint n_seqs;         // unused by this shader
    uint sq1, sq2, sq3;  // k_in strides in floats: per k-head (iq1), per token, per iq3
    uint sv1, sv2, sv3;  // unused by this shader (presumably v strides - confirm with host code)
    uint sb1, sb2, sb3;  // unused by this shader (presumably beta strides - confirm with host code)
    uint neq1, rq3;      // k-head index = head_id % neq1; k batch index = seq_id / rq3
    uint n_chunks;       // chunks iterated per sequence
    uint s_off;          // offset in floats added to every final_out index
};

// Storage buffers. Layouts are described as main() indexes them; `ch` denotes
// the flattened index (seq * n_chunks + chunk) * H + head.

// Keys, addressed through the sq1/sq2/sq3 strides; S_V floats are read per token.
layout(binding = 0) readonly buffer KBuf { float k_in[]; };
// w rows: w_in[(ch * CHUNK_SIZE + t) * S_V + col].
layout(binding = 1) readonly buffer WBuf { float w_in[]; };
// u rows, same layout as w_in.
layout(binding = 2) readonly buffer UBuf { float u_in[]; };
// One float per ch; used as the total log-decay of the chunk (exp() is applied in main()).
layout(binding = 3) readonly buffer DecayBuf { float decay_in[]; };
// Per-token cumulative log-decay inside a chunk: gcum_in[ch * CHUNK_SIZE + t].
layout(binding = 4) readonly buffer GCumBuf { float gcum_in[]; };
// Initial state: state_in[((seq * H + head) * S_V + col) * S_V + i].
layout(binding = 5) readonly buffer StateBuf { float state_in[]; };
// State snapshot at the start of each chunk, stored transposed relative to
// state_in: h_out[(ch * S_V + i) * S_V + col].
layout(binding = 6) writeonly buffer HBuf { float h_out[]; };
// Corrected values (u - w . state), same layout as w_in / u_in.
layout(binding = 7) writeonly buffer VNewBuf { float vnew_out[]; };
// Final state, state_in layout shifted by s_off. Declared read-write, but
// main() only writes to it.
layout(binding = 8) buffer FinalBuf { float final_out[]; };

// Workgroup-shared staging for the current token. Each invocation writes its
// own element (index = local invocation id) and, after a barrier(), every
// invocation reads all S_V elements.
shared float s_w[S_V];   // w row of the current token
shared float s_kg[S_V];  // k row of the current token, scaled by exp(g_total - g_cumsum_t)

// Propagates the per-head recurrent state across chunks.
//
// Work decomposition: gl_WorkGroupID.x selects the head, gl_WorkGroupID.y the
// sequence, and each invocation `col` owns S_V state values
// (state_in[state_base + col * S_V + 0 .. S_V-1]) held in a private array.
//
// For every chunk c, in order:
//   1. The state as it is at the start of the chunk is written to h_out
//      (transposed relative to the state_in layout).
//   2. For every token t of the chunk:
//        vnew      = u[t][col] - dot(w[t], state)
//        delta[i] += k[t][i] * exp(g_total - g_cumsum[t]) * vnew
//      vnew is also stored to vnew_out. `state` is NOT modified inside the
//      token loop; all tokens of a chunk see the chunk-start state.
//   3. state[i] = exp(g_total) * state[i] + delta[i]
//
// After the last chunk the state is written to final_out at offset s_off,
// using the same layout as state_in.
void main() {
    const uint head_id = gl_WorkGroupID.x;
    const uint seq_id  = gl_WorkGroupID.y;
    const uint col     = gl_LocalInvocationID.x;

    // local_size_x is bound to the same specialization constant as S_V, so
    // this guard never fires. It must stay that way: the barrier() calls below
    // require every invocation of the workgroup to reach them.
    if (col >= S_V) return;

    const uint iq1 = head_id % neq1;
    const uint iq3 = seq_id / rq3;

    const uint state_size = S_V * S_V;
    const uint state_base = (seq_id * H + head_id) * state_size;
    const uint row_base   = state_base + col * S_V;

    // Token-independent part of the k_in offset.
    const uint k_head_off = iq3 * sq3 + iq1 * sq1;

    float state[S_V];
    [[unroll]] for (uint i = 0; i < S_V; i++) {
        state[i] = state_in[row_base + i];
    }

    for (uint c = 0; c < n_chunks; c++) {
        const uint chunk_start = c * CHUNK_SIZE;
        // Guard the unsigned subtraction: if the host ever passes an n_chunks
        // that extends past n_tokens, treat the surplus chunk as empty instead
        // of wrapping around to a full-length chunk.
        const uint chunk_len = chunk_start < n_tokens
            ? min(CHUNK_SIZE, n_tokens - chunk_start)
            : 0u;

        // Flattened (sequence, chunk, head) index shared by all per-chunk buffers.
        const uint chunk_head = (seq_id * n_chunks + c) * H + head_id;

        // Snapshot of the state at the start of this chunk.
        const uint h_base = chunk_head * state_size;
        [[unroll]] for (uint i = 0; i < S_V; i++) {
            h_out[h_base + i * S_V + col] = state[i];
        }

        const uint  wu_base   = chunk_head * CHUNK_SIZE * S_V;
        const uint  gcum_base = chunk_head * CHUNK_SIZE;
        const float g_total   = decay_in[chunk_head];

        float delta[S_V];
        [[unroll]] for (uint i = 0; i < S_V; i++) {
            delta[i] = 0.0;
        }

        for (uint t = 0; t < chunk_len; t++) {
            // Element of the w / u / vnew rows owned by this invocation.
            const uint wu_idx = wu_base + t * S_V + col;

            s_w[col] = w_in[wu_idx];

            const float g_cumsum_t   = gcum_in[gcum_base + t];
            const float decay_factor = exp(g_total - g_cumsum_t);
            const uint  k_off        = k_head_off + (chunk_start + t) * sq2;
            s_kg[col] = k_in[k_off + col] * decay_factor;

            // Make the staged rows visible to the whole workgroup.
            barrier();

            // dot(w[t], state), accumulated four elements at a time
            // (requires S_V to be a multiple of 4).
            float ws = 0.0;
            [[unroll]] for (uint i = 0; i < S_V; i += 4) {
                ws += dot(
                    vec4(s_w[i], s_w[i+1], s_w[i+2], s_w[i+3]),
                    vec4(state[i], state[i+1], state[i+2], state[i+3])
                );
            }

            const float vnew = u_in[wu_idx] - ws;
            vnew_out[wu_idx] = vnew;

            [[unroll]] for (uint i = 0; i < S_V; i++) {
                delta[i] += s_kg[i] * vnew;
            }

            // Nobody may overwrite s_w / s_kg for the next token while another
            // invocation is still reading them.
            barrier();
        }

        // Decay the chunk-start state over the whole chunk and add the
        // contribution accumulated from its tokens.
        const float total_decay = exp(g_total);
        [[unroll]] for (uint i = 0; i < S_V; i++) {
            state[i] = total_decay * state[i] + delta[i];
        }
    }

    // Write final state to dst at s_off
    [[unroll]] for (uint i = 0; i < S_V; i++) {
        final_out[s_off + row_base + i] = state[i];
    }
}