vulkan: clean up chunked GDN shaders for PR review
Remove verbose algorithm comments, section dividers, stale inline constant annotations, and unused GLSL extensions, to match llama.cpp codebase style (minimal comments, no section dividers). No functional changes. 16/16 tests pass.
This commit is contained in:
parent
d2fabedf09
commit
e22c2b2c85
|
|
@ -10519,8 +10519,6 @@ static void ggml_vk_gated_delta_net(ggml_backend_vk_context * ctx, vk_context& s
|
|||
n_chunks, s_off
|
||||
};
|
||||
|
||||
// Dispatch 1: Intra-chunk (parallel across chunks)
|
||||
// Bindings: K, V, G, Beta, W_out, U_out, Decay_out, GCum_out
|
||||
ggml_vk_dispatch_pipeline(ctx, subctx, pl_intra,
|
||||
{src_buf[1], src_buf[2], src_buf[3], src_buf[4],
|
||||
scratch_w, scratch_u, scratch_dec, scratch_gcum},
|
||||
|
|
@ -10528,8 +10526,6 @@ static void ggml_vk_gated_delta_net(ggml_backend_vk_context * ctx, vk_context& s
|
|||
|
||||
ggml_vk_sync_buffers(ctx, subctx);
|
||||
|
||||
// Dispatch 2: Inter-chunk state propagation (sequential across chunks)
|
||||
// Bindings: K, W, U, Decay, GCum, State, H_out, VNew_out, Final(dst)
|
||||
ggml_vk_dispatch_pipeline(ctx, subctx, pl_inter,
|
||||
{src_buf[1], scratch_w, scratch_u, scratch_dec, scratch_gcum,
|
||||
src_buf[5], scratch_h, scratch_vnew, dst_buf},
|
||||
|
|
@ -10537,8 +10533,6 @@ static void ggml_vk_gated_delta_net(ggml_backend_vk_context * ctx, vk_context& s
|
|||
|
||||
ggml_vk_sync_buffers(ctx, subctx);
|
||||
|
||||
// Dispatch 3: Output (parallel across chunks)
|
||||
// Bindings: Q, K, H, VNew, GCum, Dst
|
||||
ggml_vk_dispatch_pipeline(ctx, subctx, pl_output,
|
||||
{src_buf[0], src_buf[1], scratch_h, scratch_vnew, scratch_gcum, dst_buf},
|
||||
pc, { n_chunks * H, n_seqs, 1u });
|
||||
|
|
|
|||
|
|
@ -3,17 +3,6 @@
|
|||
#extension GL_EXT_control_flow_attributes : require
|
||||
|
||||
// Inter-chunk state propagation for chunked gated delta net
|
||||
//
|
||||
// Sequential across chunks, parallel across state columns.
|
||||
// For each chunk c:
|
||||
// 1. Store state snapshot h[c] for output kernel
|
||||
// 2. v_corrected = U - W @ S (C x d)
|
||||
// 3. S_next = exp(g_total) * S + K_gated^T @ v_corrected (d x d)
|
||||
//
|
||||
// where K_gated[t] = k[t] * exp(g_total - g_cumsum[t])
|
||||
//
|
||||
// Grid: (H, n_seqs, 1)
|
||||
// Workgroup: S_V threads (one per state column)
|
||||
|
||||
layout(constant_id = 0) const uint S_V = 128;
|
||||
layout(constant_id = 1) const uint CHUNK_SIZE = 64;
|
||||
|
|
|
|||
|
|
@ -2,21 +2,7 @@
|
|||
|
||||
#extension GL_EXT_control_flow_attributes : require
|
||||
|
||||
// Intra-chunk kernel for chunked gated delta net (non-KDA scalar gate)
|
||||
//
|
||||
// For each chunk of C=64 tokens, computes W and U using the WY representation.
|
||||
// Uses a single A matrix with gates, compensates W by multiplying k by exp(g).
|
||||
//
|
||||
// Algorithm (FLA-equivalent):
|
||||
// 1. g_cumsum[j] = cumsum(g[0..j]) within chunk (log-space)
|
||||
// 2. A[i][j] = -beta[i] * dot(k[i], k[j]) * exp(g_cumsum[i] - g_cumsum[j]) for i > j
|
||||
// 3. T = (I + A)^{-1} via forward substitution (row-by-row parallel)
|
||||
// 4. U[i] = sum_j T[i][j] * beta[j] * v[j]
|
||||
// 5. W[i] = sum_j T[i][j] * beta[j] * exp(g_cumsum[j]) * k[j]
|
||||
// 6. Output g_cumsum[C-1] as total chunk decay
|
||||
//
|
||||
// Grid: (n_chunks * H, n_seqs, 1)
|
||||
// Workgroup: CHUNK_SIZE threads (one per token in chunk)
|
||||
// Intra-chunk WY decomposition for chunked gated delta net
|
||||
|
||||
layout(constant_id = 0) const uint S_V = 128;
|
||||
layout(constant_id = 1) const uint CHUNK_SIZE = 64;
|
||||
|
|
@ -78,7 +64,6 @@ void main() {
|
|||
}
|
||||
barrier();
|
||||
|
||||
// Step 1: Prefix sum of log-gates
|
||||
if (tid == 0) {
|
||||
for (uint i = 1; i < chunk_len; i++) {
|
||||
s_decay[i] += s_decay[i - 1];
|
||||
|
|
@ -86,13 +71,11 @@ void main() {
|
|||
}
|
||||
barrier();
|
||||
|
||||
// Output per-token g_cumsum for inter-chunk and output kernels
|
||||
if (valid) {
|
||||
const uint gcum_base = ((seq_id * n_chunks + chunk_id) * H + head_id) * CHUNK_SIZE;
|
||||
gcum_out[gcum_base + tid] = s_decay[tid];
|
||||
}
|
||||
|
||||
// Load my k vector into registers
|
||||
float my_k[S_V];
|
||||
if (valid) {
|
||||
const uint k_off = iq3 * sq3 + global_t * sq2 + iq1 * sq1;
|
||||
|
|
@ -105,17 +88,12 @@ void main() {
|
|||
}
|
||||
}
|
||||
|
||||
// Step 2: Build A matrix using shared memory broadcast for k[j]
|
||||
// A[i][j] = -beta[i] * dot(k[i], k[j]) * exp(g_cumsum[i] - g_cumsum[j]) for i > j
|
||||
// Initialize to zero
|
||||
for (uint j = 0; j < CHUNK_SIZE; j++) {
|
||||
A(tid, j) = 0.0;
|
||||
}
|
||||
barrier();
|
||||
|
||||
// For each column j, broadcast k[j] via shared memory, all threads compute their row
|
||||
for (uint j = 0; j < chunk_len; j++) {
|
||||
// Broadcast k[j] — need multiple passes when S_V > CHUNK_SIZE
|
||||
{
|
||||
const uint j_global = chunk_start + j;
|
||||
const uint kj_off = iq3 * sq3 + j_global * sq2 + iq1 * sq1;
|
||||
|
|
@ -140,12 +118,8 @@ void main() {
|
|||
barrier();
|
||||
}
|
||||
|
||||
// Step 3: Forward substitution T = (I + A)^{-1}
|
||||
// Process row by row. For row i, all threads j < i compute in parallel:
|
||||
// T[i][j] += sum_m T[i][m] * T[m][j] for m in [j..i-1]
|
||||
// The A matrix is modified in-place to become T.
|
||||
// Forward substitution: T = (I + A)^{-1}, in-place
|
||||
for (uint i = 1; i < chunk_len; i++) {
|
||||
// Each thread with tid < i computes T[i][tid]
|
||||
if (tid < i) {
|
||||
float sum = 0.0;
|
||||
for (uint m = tid; m < i; m++) {
|
||||
|
|
@ -156,19 +130,12 @@ void main() {
|
|||
barrier();
|
||||
}
|
||||
|
||||
// Add identity
|
||||
if (valid) {
|
||||
A(tid, tid) = 1.0;
|
||||
}
|
||||
barrier();
|
||||
|
||||
// Step 4+5: Compute W and U via shared memory broadcast + register accumulation
|
||||
// U[tid][d] = sum_j T[tid][j] * beta[j] * v[j][d]
|
||||
// W[tid][d] = sum_j T[tid][j] * beta[j] * exp(g_cumsum[j]) * k[j][d]
|
||||
//
|
||||
// For each j, broadcast k[j]*exp(g[j]) and v[j] via shared memory.
|
||||
// Accumulate in d-tiles of 32 to limit register pressure.
|
||||
|
||||
// W and U via tiled broadcast accumulation
|
||||
const uint out_base = ((seq_id * n_chunks + chunk_id) * H + head_id) * CHUNK_SIZE * S_V;
|
||||
const uint TILE_D = 32;
|
||||
|
||||
|
|
|
|||
|
|
@ -1,27 +1,13 @@
|
|||
#version 450
|
||||
|
||||
#extension GL_EXT_control_flow_attributes : enable
|
||||
#extension GL_EXT_shader_16bit_storage : require
|
||||
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
|
||||
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
|
||||
#extension GL_KHR_shader_subgroup_basic : enable
|
||||
#extension GL_KHR_shader_subgroup_arithmetic : enable
|
||||
#extension GL_KHR_memory_scope_semantics : enable
|
||||
#extension GL_KHR_cooperative_matrix : enable
|
||||
|
||||
#include "types.glsl"
|
||||
|
||||
// Coopmat output kernel for chunked gated delta net
|
||||
//
|
||||
// Phase 1: A = Q @ K^T (coopmat, f16→f32)
|
||||
// Phase 2: Decay mask → sh_adecay (f16) + vnew → sh_kv (f16, pre-scaled)
|
||||
// Pass 1: O_inter = Q @ S → dst (scalar, 128 threads)
|
||||
// Pass 2: O_intra = A_decayed @ vnew → dst (coopmat GEMM, full chunks)
|
||||
// Partial last chunk: scalar fallback. 3 barriers total.
|
||||
//
|
||||
// Grid: (n_chunks * H, n_seqs, 1)
|
||||
// Workgroup: WG_SIZE threads = 4 subgroups
|
||||
|
||||
layout(constant_id = 0) const uint WG_SIZE = 256;
|
||||
layout(constant_id = 1) const uint CHUNK_SIZE = 64;
|
||||
layout(constant_id = 2) const uint S_V = 128;
|
||||
|
|
@ -51,16 +37,16 @@ const uint TM = 16;
|
|||
const uint TN = 16;
|
||||
const uint TK = 16;
|
||||
|
||||
const uint C_TILES = CHUNK_SIZE / TM; // 4
|
||||
const uint D_TILES = S_V / TN; // 8
|
||||
const uint C_TILES = CHUNK_SIZE / TM;
|
||||
const uint D_TILES = S_V / TN;
|
||||
|
||||
// Shared memory strides in f16vec4 units, padded for bank conflicts
|
||||
const uint QK_STRIDE = S_V / 4 + 2; // 34
|
||||
const uint ATTN_V4_STRIDE = CHUNK_SIZE / 4 + 2; // 18
|
||||
const uint QK_STRIDE = S_V / 4 + 2;
|
||||
const uint ATTN_V4_STRIDE = CHUNK_SIZE / 4 + 2;
|
||||
|
||||
shared f16vec4 sh_q[CHUNK_SIZE * QK_STRIDE]; // Q (f16) for coopmat Phase 1
|
||||
shared f16vec4 sh_kv[CHUNK_SIZE * QK_STRIDE]; // K (f16) for coopmat Phase 1
|
||||
shared vec4 sh_attn[CHUNK_SIZE * ATTN_V4_STRIDE]; // attention matrix (f32)
|
||||
shared f16vec4 sh_q[CHUNK_SIZE * QK_STRIDE];
|
||||
shared f16vec4 sh_kv[CHUNK_SIZE * QK_STRIDE];
|
||||
shared vec4 sh_attn[CHUNK_SIZE * ATTN_V4_STRIDE];
|
||||
shared f16vec4 sh_adecay[CHUNK_SIZE * ATTN_V4_STRIDE];
|
||||
shared float sh_gcum[CHUNK_SIZE];
|
||||
|
||||
|
|
@ -80,10 +66,6 @@ void main() {
|
|||
const uint chunk_len = min(CHUNK_SIZE, n_tokens - chunk_start);
|
||||
const float scale = 1.0 / sqrt(float(S_V));
|
||||
|
||||
// ================================================================
|
||||
// Load Q, K, gcum to shared memory
|
||||
// ================================================================
|
||||
|
||||
if (tid < CHUNK_SIZE) {
|
||||
const uint gcum_base = ((seq_id * n_chunks + chunk_id) * H + head_id) * CHUNK_SIZE;
|
||||
sh_gcum[tid] = (tid < chunk_len) ? gcum_in[gcum_base + tid] : 0.0;
|
||||
|
|
@ -105,10 +87,7 @@ void main() {
|
|||
|
||||
barrier();
|
||||
|
||||
// ================================================================
|
||||
// Phase 1: A = Q @ K^T [C×D] × [D×C] → [C×C] (coopmat)
|
||||
// ================================================================
|
||||
|
||||
// A = Q @ K^T (coopmat)
|
||||
coopmat<float, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> A_acc[C_TILES];
|
||||
[[unroll]] for (uint tj = 0; tj < C_TILES; tj++) {
|
||||
A_acc[tj] = coopmat<float, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(0.0);
|
||||
|
|
@ -133,7 +112,6 @@ void main() {
|
|||
}
|
||||
}
|
||||
|
||||
// Store A to sh_attn as f32
|
||||
[[unroll]] for (uint tj = 0; tj < C_TILES; tj++) {
|
||||
coopMatStore(A_acc[tj], sh_attn,
|
||||
sg_id * TM * ATTN_V4_STRIDE + tj * (TN / 4),
|
||||
|
|
@ -142,16 +120,10 @@ void main() {
|
|||
|
||||
barrier();
|
||||
|
||||
// ================================================================
|
||||
// Phase 2: Decay mask + vnew load (all 256 threads)
|
||||
// Pass 1: Inter-chunk Q@S → dst (128 active threads)
|
||||
// No shared-memory write conflicts — runs without intermediate barrier.
|
||||
// ================================================================
|
||||
|
||||
const uint wu_base = ((seq_id * n_chunks + chunk_id) * H + head_id) * CHUNK_SIZE * S_V;
|
||||
const uint attn_off = (seq_id * n_tokens * H + head_id) * S_V;
|
||||
|
||||
// Phase 2a: A_decayed = causal_decay_mask(A) → sh_adecay (f16)
|
||||
// Causal decay mask (f16)
|
||||
for (uint idx = tid; idx < CHUNK_SIZE * (CHUNK_SIZE / 4); idx += WG_SIZE) {
|
||||
const uint t = idx / (CHUNK_SIZE / 4);
|
||||
const uint j4 = idx % (CHUNK_SIZE / 4);
|
||||
|
|
@ -170,7 +142,7 @@ void main() {
|
|||
sh_adecay[t * ATTN_V4_STRIDE + j4] = val;
|
||||
}
|
||||
|
||||
// Phase 2b: vnew → sh_kv (f16, pre-scaled by 1/√S_V)
|
||||
// vnew to f16, pre-scaled
|
||||
for (uint idx = tid; idx < CHUNK_SIZE * (S_V / 4); idx += WG_SIZE) {
|
||||
const uint row = idx / (S_V / 4);
|
||||
const uint col4 = idx % (S_V / 4);
|
||||
|
|
@ -186,7 +158,7 @@ void main() {
|
|||
sh_kv[row * QK_STRIDE + col4] = val;
|
||||
}
|
||||
|
||||
// Pass 1: Inter-chunk (128 active threads, write directly to dst)
|
||||
// O_inter = Q @ state
|
||||
{
|
||||
const uint col = tid;
|
||||
const bool col_active = (col < S_V);
|
||||
|
|
@ -216,12 +188,7 @@ void main() {
|
|||
|
||||
barrier();
|
||||
|
||||
// ================================================================
|
||||
// Pass 2: Intra-chunk A_decayed[C×C] @ vnew[C×S_V] → [C×S_V]
|
||||
// Full chunks: coopmat GEMM (accumulates onto inter results in dst).
|
||||
// Partial (last) chunk: scalar fallback.
|
||||
// ================================================================
|
||||
|
||||
// O_intra = A_decayed @ vnew (coopmat GEMM, scalar fallback for partial chunks)
|
||||
if (chunk_len == CHUNK_SIZE) {
|
||||
coopmat<float16_t, gl_ScopeSubgroup, TM, TK, gl_MatrixUseA> A_mat;
|
||||
coopmat<float16_t, gl_ScopeSubgroup, TK, TN, gl_MatrixUseB> V_mat;
|
||||
|
|
|
|||
Loading…
Reference in New Issue