Revert "vulkan: fused inter+output kernel for chunked GDN"
This reverts commit 08c355c01f3a298ef943216d4c55367a1c967286.
This commit is contained in:
parent
b0323615c9
commit
efbde13283
|
|
@ -831,7 +831,6 @@ struct vk_device_struct {
|
||||||
vk_pipeline pipeline_gated_delta_net_chunk_inter;
|
vk_pipeline pipeline_gated_delta_net_chunk_inter;
|
||||||
vk_pipeline pipeline_gated_delta_net_chunk_output;
|
vk_pipeline pipeline_gated_delta_net_chunk_output;
|
||||||
vk_pipeline pipeline_gated_delta_net_chunk_output_cm;
|
vk_pipeline pipeline_gated_delta_net_chunk_output_cm;
|
||||||
vk_pipeline pipeline_gated_delta_net_chunk_fused_cm;
|
|
||||||
vk_pipeline pipeline_ssm_scan_f32_d128;
|
vk_pipeline pipeline_ssm_scan_f32_d128;
|
||||||
vk_pipeline pipeline_ssm_scan_f32_d256;
|
vk_pipeline pipeline_ssm_scan_f32_d256;
|
||||||
vk_pipeline pipeline_ssm_conv_f32;
|
vk_pipeline pipeline_ssm_conv_f32;
|
||||||
|
|
@ -4630,9 +4629,6 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_gated_delta_net_chunk_output_cm, "gated_delta_net_chunk_output_cm1_f32_d128",
|
ggml_vk_create_pipeline(device, device->pipeline_gated_delta_net_chunk_output_cm, "gated_delta_net_chunk_output_cm1_f32_d128",
|
||||||
gated_delta_net_chunk_output_cm1_f32_len, gated_delta_net_chunk_output_cm1_f32_data, "main",
|
gated_delta_net_chunk_output_cm1_f32_len, gated_delta_net_chunk_output_cm1_f32_data, "main",
|
||||||
6, sizeof(vk_op_gated_delta_net_chunk_push_constants), {1, 1, 1}, {256, 64, 128}, 1, true);
|
6, sizeof(vk_op_gated_delta_net_chunk_push_constants), {1, 1, 1}, {256, 64, 128}, 1, true);
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_gated_delta_net_chunk_fused_cm, "gated_delta_net_chunk_fused_cm_f32_d128",
|
|
||||||
gated_delta_net_chunk_fused_cm_f32_len, gated_delta_net_chunk_fused_cm_f32_data, "main",
|
|
||||||
8, sizeof(vk_op_gated_delta_net_chunk_push_constants), {1, 1, 1}, {256, 64, 128}, 1, true);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if (device->subgroup_arithmetic && device->subgroup_require_full_support) {
|
if (device->subgroup_arithmetic && device->subgroup_require_full_support) {
|
||||||
|
|
@ -10474,12 +10470,45 @@ static void ggml_vk_gated_delta_net(ggml_backend_vk_context * ctx, vk_context& s
|
||||||
// Chunked parallel path (PP acceleration)
|
// Chunked parallel path (PP acceleration)
|
||||||
const uint32_t n_chunks = (n_tokens + GDN_CHUNK_SIZE - 1) / GDN_CHUNK_SIZE;
|
const uint32_t n_chunks = (n_tokens + GDN_CHUNK_SIZE - 1) / GDN_CHUNK_SIZE;
|
||||||
|
|
||||||
vk_pipeline pl_intra = ctx->device->pipeline_gated_delta_net_chunk_intra;
|
vk_pipeline pl_intra = ctx->device->pipeline_gated_delta_net_chunk_intra;
|
||||||
vk_pipeline pl_fused = ctx->device->pipeline_gated_delta_net_chunk_fused_cm;
|
vk_pipeline pl_inter = ctx->device->pipeline_gated_delta_net_chunk_inter;
|
||||||
|
vk_pipeline pl_output = ctx->device->pipeline_gated_delta_net_chunk_output_cm
|
||||||
|
? ctx->device->pipeline_gated_delta_net_chunk_output_cm
|
||||||
|
: ctx->device->pipeline_gated_delta_net_chunk_output;
|
||||||
|
|
||||||
const size_t wu_size = (size_t)n_seqs * n_chunks * H * GDN_CHUNK_SIZE * S_v * sizeof(float);
|
ggml_pipeline_request_descriptor_sets(ctx, pl_intra, 1);
|
||||||
const size_t d_size = (size_t)n_seqs * n_chunks * H * sizeof(float);
|
ggml_pipeline_request_descriptor_sets(ctx, pl_inter, 1);
|
||||||
const size_t g_size = (size_t)n_seqs * n_chunks * H * GDN_CHUNK_SIZE * sizeof(float);
|
ggml_pipeline_request_descriptor_sets(ctx, pl_output, 1);
|
||||||
|
|
||||||
|
// Scratch buffer layout within prealloc_split_k
|
||||||
|
const size_t wu_size = (size_t)n_seqs * n_chunks * H * GDN_CHUNK_SIZE * S_v * sizeof(float);
|
||||||
|
const size_t d_size = (size_t)n_seqs * n_chunks * H * sizeof(float);
|
||||||
|
const size_t g_size = (size_t)n_seqs * n_chunks * H * GDN_CHUNK_SIZE * sizeof(float);
|
||||||
|
const size_t h_size = (size_t)n_seqs * n_chunks * H * S_v * S_v * sizeof(float);
|
||||||
|
|
||||||
|
const size_t w_off = 0;
|
||||||
|
const size_t u_off = wu_size;
|
||||||
|
const size_t vn_off = 2 * wu_size;
|
||||||
|
const size_t dec_off = 3 * wu_size;
|
||||||
|
const size_t gcum_off = dec_off + d_size;
|
||||||
|
const size_t h_off = gcum_off + g_size;
|
||||||
|
const size_t total_scratch = h_off + h_size;
|
||||||
|
|
||||||
|
if (ctx->prealloc_size_split_k < total_scratch) {
|
||||||
|
ctx->prealloc_size_split_k = total_scratch;
|
||||||
|
ggml_vk_preallocate_buffers(ctx, subctx);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (ctx->prealloc_split_k_need_sync) {
|
||||||
|
ggml_vk_sync_buffers(ctx, subctx);
|
||||||
|
}
|
||||||
|
|
||||||
|
vk_subbuffer scratch_w = { ctx->prealloc_split_k, w_off, wu_size };
|
||||||
|
vk_subbuffer scratch_u = { ctx->prealloc_split_k, u_off, wu_size };
|
||||||
|
vk_subbuffer scratch_vnew = { ctx->prealloc_split_k, vn_off, wu_size };
|
||||||
|
vk_subbuffer scratch_dec = { ctx->prealloc_split_k, dec_off, d_size };
|
||||||
|
vk_subbuffer scratch_gcum = { ctx->prealloc_split_k, gcum_off, g_size };
|
||||||
|
vk_subbuffer scratch_h = { ctx->prealloc_split_k, h_off, h_size };
|
||||||
|
|
||||||
const vk_op_gated_delta_net_chunk_push_constants pc = {
|
const vk_op_gated_delta_net_chunk_push_constants pc = {
|
||||||
H, n_tokens, n_seqs,
|
H, n_tokens, n_seqs,
|
||||||
|
|
@ -10490,95 +10519,29 @@ static void ggml_vk_gated_delta_net(ggml_backend_vk_context * ctx, vk_context& s
|
||||||
n_chunks, s_off
|
n_chunks, s_off
|
||||||
};
|
};
|
||||||
|
|
||||||
if (pl_fused) {
|
// Dispatch 1: Intra-chunk (parallel across chunks)
|
||||||
// Fused inter+output path: 2 dispatches, no vnew/h scratch
|
// Bindings: K, V, G, Beta, W_out, U_out, Decay_out, GCum_out
|
||||||
ggml_pipeline_request_descriptor_sets(ctx, pl_intra, 1);
|
ggml_vk_dispatch_pipeline(ctx, subctx, pl_intra,
|
||||||
ggml_pipeline_request_descriptor_sets(ctx, pl_fused, 1);
|
{src_buf[1], src_buf[2], src_buf[3], src_buf[4],
|
||||||
|
scratch_w, scratch_u, scratch_dec, scratch_gcum},
|
||||||
|
pc, { n_chunks * H, n_seqs, 1u });
|
||||||
|
|
||||||
const size_t w_off = 0;
|
ggml_vk_sync_buffers(ctx, subctx);
|
||||||
const size_t u_off = wu_size;
|
|
||||||
const size_t dec_off = 2 * wu_size;
|
|
||||||
const size_t gcum_off = dec_off + d_size;
|
|
||||||
const size_t total_scratch = gcum_off + g_size;
|
|
||||||
|
|
||||||
if (ctx->prealloc_size_split_k < total_scratch) {
|
// Dispatch 2: Inter-chunk state propagation (sequential across chunks)
|
||||||
ctx->prealloc_size_split_k = total_scratch;
|
// Bindings: K, W, U, Decay, GCum, State, H_out, VNew_out, Final(dst)
|
||||||
ggml_vk_preallocate_buffers(ctx, subctx);
|
ggml_vk_dispatch_pipeline(ctx, subctx, pl_inter,
|
||||||
}
|
{src_buf[1], scratch_w, scratch_u, scratch_dec, scratch_gcum,
|
||||||
|
src_buf[5], scratch_h, scratch_vnew, dst_buf},
|
||||||
|
pc, { H, n_seqs, 1u });
|
||||||
|
|
||||||
if (ctx->prealloc_split_k_need_sync) {
|
ggml_vk_sync_buffers(ctx, subctx);
|
||||||
ggml_vk_sync_buffers(ctx, subctx);
|
|
||||||
}
|
|
||||||
|
|
||||||
vk_subbuffer scratch_w = { ctx->prealloc_split_k, w_off, wu_size };
|
// Dispatch 3: Output (parallel across chunks)
|
||||||
vk_subbuffer scratch_u = { ctx->prealloc_split_k, u_off, wu_size };
|
// Bindings: Q, K, H, VNew, GCum, Dst
|
||||||
vk_subbuffer scratch_dec = { ctx->prealloc_split_k, dec_off, d_size };
|
ggml_vk_dispatch_pipeline(ctx, subctx, pl_output,
|
||||||
vk_subbuffer scratch_gcum = { ctx->prealloc_split_k, gcum_off, g_size };
|
{src_buf[0], src_buf[1], scratch_h, scratch_vnew, scratch_gcum, dst_buf},
|
||||||
|
pc, { n_chunks * H, n_seqs, 1u });
|
||||||
ggml_vk_dispatch_pipeline(ctx, subctx, pl_intra,
|
|
||||||
{src_buf[1], src_buf[2], src_buf[3], src_buf[4],
|
|
||||||
scratch_w, scratch_u, scratch_dec, scratch_gcum},
|
|
||||||
pc, { n_chunks * H, n_seqs, 1u });
|
|
||||||
|
|
||||||
ggml_vk_sync_buffers(ctx, subctx);
|
|
||||||
|
|
||||||
// Bindings: Q, K, W, U, Decay, GCum, State, Dst
|
|
||||||
ggml_vk_dispatch_pipeline(ctx, subctx, pl_fused,
|
|
||||||
{src_buf[0], src_buf[1], scratch_w, scratch_u,
|
|
||||||
scratch_dec, scratch_gcum, src_buf[5], dst_buf},
|
|
||||||
pc, { H, n_seqs, 1u });
|
|
||||||
} else {
|
|
||||||
// Fallback: 3-dispatch path (no coopmat)
|
|
||||||
vk_pipeline pl_inter = ctx->device->pipeline_gated_delta_net_chunk_inter;
|
|
||||||
vk_pipeline pl_output = ctx->device->pipeline_gated_delta_net_chunk_output;
|
|
||||||
|
|
||||||
ggml_pipeline_request_descriptor_sets(ctx, pl_intra, 1);
|
|
||||||
ggml_pipeline_request_descriptor_sets(ctx, pl_inter, 1);
|
|
||||||
ggml_pipeline_request_descriptor_sets(ctx, pl_output, 1);
|
|
||||||
|
|
||||||
const size_t h_size = (size_t)n_seqs * n_chunks * H * S_v * S_v * sizeof(float);
|
|
||||||
const size_t w_off = 0;
|
|
||||||
const size_t u_off = wu_size;
|
|
||||||
const size_t vn_off = 2 * wu_size;
|
|
||||||
const size_t dec_off = 3 * wu_size;
|
|
||||||
const size_t gcum_off = dec_off + d_size;
|
|
||||||
const size_t h_off = gcum_off + g_size;
|
|
||||||
const size_t total_scratch = h_off + h_size;
|
|
||||||
|
|
||||||
if (ctx->prealloc_size_split_k < total_scratch) {
|
|
||||||
ctx->prealloc_size_split_k = total_scratch;
|
|
||||||
ggml_vk_preallocate_buffers(ctx, subctx);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (ctx->prealloc_split_k_need_sync) {
|
|
||||||
ggml_vk_sync_buffers(ctx, subctx);
|
|
||||||
}
|
|
||||||
|
|
||||||
vk_subbuffer scratch_w = { ctx->prealloc_split_k, w_off, wu_size };
|
|
||||||
vk_subbuffer scratch_u = { ctx->prealloc_split_k, u_off, wu_size };
|
|
||||||
vk_subbuffer scratch_vnew = { ctx->prealloc_split_k, vn_off, wu_size };
|
|
||||||
vk_subbuffer scratch_dec = { ctx->prealloc_split_k, dec_off, d_size };
|
|
||||||
vk_subbuffer scratch_gcum = { ctx->prealloc_split_k, gcum_off, g_size };
|
|
||||||
vk_subbuffer scratch_h = { ctx->prealloc_split_k, h_off, h_size };
|
|
||||||
|
|
||||||
ggml_vk_dispatch_pipeline(ctx, subctx, pl_intra,
|
|
||||||
{src_buf[1], src_buf[2], src_buf[3], src_buf[4],
|
|
||||||
scratch_w, scratch_u, scratch_dec, scratch_gcum},
|
|
||||||
pc, { n_chunks * H, n_seqs, 1u });
|
|
||||||
|
|
||||||
ggml_vk_sync_buffers(ctx, subctx);
|
|
||||||
|
|
||||||
ggml_vk_dispatch_pipeline(ctx, subctx, pl_inter,
|
|
||||||
{src_buf[1], scratch_w, scratch_u, scratch_dec, scratch_gcum,
|
|
||||||
src_buf[5], scratch_h, scratch_vnew, dst_buf},
|
|
||||||
pc, { H, n_seqs, 1u });
|
|
||||||
|
|
||||||
ggml_vk_sync_buffers(ctx, subctx);
|
|
||||||
|
|
||||||
ggml_vk_dispatch_pipeline(ctx, subctx, pl_output,
|
|
||||||
{src_buf[0], src_buf[1], scratch_h, scratch_vnew, scratch_gcum, dst_buf},
|
|
||||||
pc, { n_chunks * H, n_seqs, 1u });
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx->prealloc_split_k_need_sync = true;
|
ctx->prealloc_split_k_need_sync = true;
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,340 +0,0 @@
|
||||||
#version 450
|
|
||||||
|
|
||||||
#extension GL_EXT_control_flow_attributes : enable
|
|
||||||
#extension GL_EXT_shader_16bit_storage : require
|
|
||||||
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
|
|
||||||
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
|
|
||||||
#extension GL_KHR_shader_subgroup_basic : enable
|
|
||||||
#extension GL_KHR_shader_subgroup_arithmetic : enable
|
|
||||||
#extension GL_KHR_shader_subgroup_shuffle : enable
|
|
||||||
#extension GL_KHR_memory_scope_semantics : enable
|
|
||||||
#extension GL_KHR_cooperative_matrix : enable
|
|
||||||
|
|
||||||
#include "types.glsl"
|
|
||||||
|
|
||||||
// Fused inter+output kernel for chunked gated delta net
|
|
||||||
//
|
|
||||||
// Merges the inter-chunk state propagation and output computation into
|
|
||||||
// one dispatch, eliminating VNew and H_snapshot scratch buffers plus
|
|
||||||
// one dispatch barrier.
|
|
||||||
//
|
|
||||||
// Step 1: Load K, Q, gcum → shared
|
|
||||||
// Step 2: A = Q @ K^T (coopmat)
|
|
||||||
// Step 3: Decay mask → sh_adecay + O_inter = Q @ state → dst (parallel)
|
|
||||||
// Step 4: vnew = U - W@state → sh_kv (f16), accumulate delta
|
|
||||||
// Step 5: O_intra = A_decayed @ vnew → dst (coopmat GEMM)
|
|
||||||
// Step 6: state = exp(decay) * state + delta
|
|
||||||
//
|
|
||||||
// Grid: (H, n_seqs, 1) — sequential chunk loop
|
|
||||||
// Workgroup: 256 threads = 4 subgroups
|
|
||||||
|
|
||||||
layout(constant_id = 0) const uint WG_SIZE = 256;
|
|
||||||
layout(constant_id = 1) const uint CHUNK_SIZE = 64;
|
|
||||||
layout(constant_id = 2) const uint S_V = 128;
|
|
||||||
|
|
||||||
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
|
||||||
|
|
||||||
layout(push_constant) uniform Parameters {
|
|
||||||
uint H;
|
|
||||||
uint n_tokens;
|
|
||||||
uint n_seqs;
|
|
||||||
uint sq1, sq2, sq3;
|
|
||||||
uint sv1, sv2, sv3;
|
|
||||||
uint sb1, sb2, sb3;
|
|
||||||
uint neq1, rq3;
|
|
||||||
uint n_chunks;
|
|
||||||
uint s_off;
|
|
||||||
};
|
|
||||||
|
|
||||||
layout(binding = 0) readonly buffer QBuf { float q_in[]; };
|
|
||||||
layout(binding = 1) readonly buffer KBuf { float k_in[]; };
|
|
||||||
layout(binding = 2) readonly buffer WBuf { float w_in[]; };
|
|
||||||
layout(binding = 3) readonly buffer UBuf { float u_in[]; };
|
|
||||||
layout(binding = 4) readonly buffer DecBuf { float decay_in[]; };
|
|
||||||
layout(binding = 5) readonly buffer GCumBuf { float gcum_in[]; };
|
|
||||||
layout(binding = 6) readonly buffer StBuf { float state_in[]; };
|
|
||||||
layout(binding = 7) buffer DstBuf { float dst[]; };
|
|
||||||
|
|
||||||
const uint TM = 16;
|
|
||||||
const uint TN = 16;
|
|
||||||
const uint TK = 16;
|
|
||||||
|
|
||||||
const uint C_TILES = CHUNK_SIZE / TM;
|
|
||||||
const uint D_TILES = S_V / TN;
|
|
||||||
|
|
||||||
const uint QK_STRIDE = S_V / 4 + 2;
|
|
||||||
const uint ATTN_V4_STRIDE = CHUNK_SIZE / 4 + 2;
|
|
||||||
|
|
||||||
shared f16vec4 sh_q[CHUNK_SIZE * QK_STRIDE];
|
|
||||||
shared f16vec4 sh_kv[CHUNK_SIZE * QK_STRIDE];
|
|
||||||
shared vec4 sh_attn[CHUNK_SIZE * ATTN_V4_STRIDE];
|
|
||||||
shared f16vec4 sh_adecay[CHUNK_SIZE * ATTN_V4_STRIDE];
|
|
||||||
shared float sh_gcum[CHUNK_SIZE];
|
|
||||||
shared float sh_w[S_V];
|
|
||||||
shared float sh_kg[S_V];
|
|
||||||
|
|
||||||
void main() {
|
|
||||||
const uint tid = gl_LocalInvocationIndex;
|
|
||||||
const uint sg_id = gl_SubgroupID;
|
|
||||||
const uint si = gl_SubgroupInvocationID;
|
|
||||||
|
|
||||||
const uint head_id = gl_WorkGroupID.x;
|
|
||||||
const uint seq_id = gl_WorkGroupID.y;
|
|
||||||
const uint col = tid;
|
|
||||||
|
|
||||||
const uint iq1 = head_id % neq1;
|
|
||||||
const uint iq3 = seq_id / rq3;
|
|
||||||
const float scale = 1.0 / sqrt(float(S_V));
|
|
||||||
|
|
||||||
const uint state_size = S_V * S_V;
|
|
||||||
const uint state_base = (seq_id * H + head_id) * state_size;
|
|
||||||
const uint attn_off = (seq_id * n_tokens * H + head_id) * S_V;
|
|
||||||
|
|
||||||
// ================================================================
|
|
||||||
// Load state into registers (threads 0-127)
|
|
||||||
// ================================================================
|
|
||||||
|
|
||||||
float state[S_V];
|
|
||||||
if (col < S_V) {
|
|
||||||
[[unroll]] for (uint i = 0; i < S_V; i++) {
|
|
||||||
state[i] = state_in[state_base + i * S_V + col];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ================================================================
|
|
||||||
// Chunk loop
|
|
||||||
// ================================================================
|
|
||||||
|
|
||||||
for (uint c = 0; c < n_chunks; c++) {
|
|
||||||
const uint chunk_start = c * CHUNK_SIZE;
|
|
||||||
const uint chunk_len = min(CHUNK_SIZE, n_tokens - chunk_start);
|
|
||||||
const uint wu_base = ((seq_id * n_chunks + c) * H + head_id) * CHUNK_SIZE * S_V;
|
|
||||||
const uint gcum_base = ((seq_id * n_chunks + c) * H + head_id) * CHUNK_SIZE;
|
|
||||||
const uint decay_idx = (seq_id * n_chunks + c) * H + head_id;
|
|
||||||
const float g_total = decay_in[decay_idx];
|
|
||||||
|
|
||||||
// ============================================================
|
|
||||||
// Step 1: Load K → sh_kv, Q → sh_q, gcum → sh_gcum
|
|
||||||
// ============================================================
|
|
||||||
|
|
||||||
if (tid < CHUNK_SIZE) {
|
|
||||||
sh_gcum[tid] = (tid < chunk_len) ? gcum_in[gcum_base + tid] : 0.0;
|
|
||||||
}
|
|
||||||
|
|
||||||
for (uint idx = tid; idx < CHUNK_SIZE * (S_V / 4); idx += WG_SIZE) {
|
|
||||||
const uint row = idx / (S_V / 4);
|
|
||||||
const uint col4 = idx % (S_V / 4);
|
|
||||||
f16vec4 q_val = f16vec4(0.0);
|
|
||||||
f16vec4 k_val = f16vec4(0.0);
|
|
||||||
if (row < chunk_len) {
|
|
||||||
const uint off = iq3 * sq3 + (chunk_start + row) * sq2 + iq1 * sq1 + col4 * 4;
|
|
||||||
q_val = f16vec4(q_in[off], q_in[off + 1], q_in[off + 2], q_in[off + 3]);
|
|
||||||
k_val = f16vec4(k_in[off], k_in[off + 1], k_in[off + 2], k_in[off + 3]);
|
|
||||||
}
|
|
||||||
sh_q[row * QK_STRIDE + col4] = q_val;
|
|
||||||
sh_kv[row * QK_STRIDE + col4] = k_val;
|
|
||||||
}
|
|
||||||
|
|
||||||
barrier();
|
|
||||||
|
|
||||||
// ============================================================
|
|
||||||
// Step 2: Q @ K^T (coopmat, f16→f32)
|
|
||||||
// ============================================================
|
|
||||||
|
|
||||||
coopmat<float, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> A_acc[C_TILES];
|
|
||||||
[[unroll]] for (uint tj = 0; tj < C_TILES; tj++) {
|
|
||||||
A_acc[tj] = coopmat<float, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(0.0);
|
|
||||||
}
|
|
||||||
|
|
||||||
{
|
|
||||||
coopmat<float16_t, gl_ScopeSubgroup, TM, TK, gl_MatrixUseA> Q_mat;
|
|
||||||
coopmat<float16_t, gl_ScopeSubgroup, TK, TN, gl_MatrixUseB> KT_mat;
|
|
||||||
|
|
||||||
[[unroll]] for (uint dk = 0; dk < D_TILES; dk++) {
|
|
||||||
coopMatLoad(Q_mat, sh_q,
|
|
||||||
sg_id * TM * QK_STRIDE + dk * (TK / 4),
|
|
||||||
QK_STRIDE, gl_CooperativeMatrixLayoutRowMajor);
|
|
||||||
|
|
||||||
[[unroll]] for (uint tj = 0; tj < C_TILES; tj++) {
|
|
||||||
coopMatLoad(KT_mat, sh_kv,
|
|
||||||
tj * TN * QK_STRIDE + dk * (TK / 4),
|
|
||||||
QK_STRIDE, gl_CooperativeMatrixLayoutColumnMajor);
|
|
||||||
|
|
||||||
A_acc[tj] = coopMatMulAdd(Q_mat, KT_mat, A_acc[tj]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
[[unroll]] for (uint tj = 0; tj < C_TILES; tj++) {
|
|
||||||
coopMatStore(A_acc[tj], sh_attn,
|
|
||||||
sg_id * TM * ATTN_V4_STRIDE + tj * (TN / 4),
|
|
||||||
ATTN_V4_STRIDE, gl_CooperativeMatrixLayoutRowMajor);
|
|
||||||
}
|
|
||||||
|
|
||||||
barrier();
|
|
||||||
|
|
||||||
// ============================================================
|
|
||||||
// Step 3: Decay mask + inter-chunk output (parallel, no conflict)
|
|
||||||
// ============================================================
|
|
||||||
|
|
||||||
for (uint idx = tid; idx < CHUNK_SIZE * (CHUNK_SIZE / 4); idx += WG_SIZE) {
|
|
||||||
const uint t = idx / (CHUNK_SIZE / 4);
|
|
||||||
const uint j4 = idx % (CHUNK_SIZE / 4);
|
|
||||||
f16vec4 val = f16vec4(0.0);
|
|
||||||
if (t < chunk_len) {
|
|
||||||
const float g_t = sh_gcum[t];
|
|
||||||
[[unroll]] for (uint k = 0; k < 4; k++) {
|
|
||||||
const uint j = j4 * 4 + k;
|
|
||||||
if (j <= t && j < chunk_len) {
|
|
||||||
val[k] = float16_t(clamp(
|
|
||||||
exp(g_t - sh_gcum[j]) * sh_attn[t * ATTN_V4_STRIDE + j4][k],
|
|
||||||
-65504.0, 65504.0));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
sh_adecay[t * ATTN_V4_STRIDE + j4] = val;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (col < S_V) {
|
|
||||||
for (uint t = 0; t < chunk_len; t++) {
|
|
||||||
float o_inter = 0.0;
|
|
||||||
[[unroll]] for (uint i = 0; i < S_V / 4; i++) {
|
|
||||||
vec4 q_f32 = vec4(sh_q[t * QK_STRIDE + i]);
|
|
||||||
o_inter += dot(q_f32,
|
|
||||||
vec4(state[i*4], state[i*4+1],
|
|
||||||
state[i*4+2], state[i*4+3]));
|
|
||||||
}
|
|
||||||
dst[attn_off + (chunk_start + t) * S_V * H + col] = exp(sh_gcum[t]) * o_inter * scale;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
barrier();
|
|
||||||
|
|
||||||
// ============================================================
|
|
||||||
// Step 4: vnew = U - W@state → sh_kv, accumulate delta
|
|
||||||
// Threads 0-127: load w, compute vnew, write sh_kv via shuffle
|
|
||||||
// Threads 128-255: load k_gated into sh_kg
|
|
||||||
// ============================================================
|
|
||||||
|
|
||||||
float delta[S_V];
|
|
||||||
if (col < S_V) {
|
|
||||||
[[unroll]] for (uint i = 0; i < S_V; i++) {
|
|
||||||
delta[i] = 0.0;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for (uint t = 0; t < chunk_len; t++) {
|
|
||||||
if (col < S_V) {
|
|
||||||
sh_w[col] = w_in[wu_base + t * S_V + col];
|
|
||||||
}
|
|
||||||
if (tid >= 128 && tid < 256) {
|
|
||||||
const uint lc = tid - 128;
|
|
||||||
const float gcum_t = gcum_in[gcum_base + t];
|
|
||||||
const float decay_factor = exp(g_total - gcum_t);
|
|
||||||
const uint k_off = iq3 * sq3 + (chunk_start + t) * sq2 + iq1 * sq1;
|
|
||||||
sh_kg[lc] = k_in[k_off + lc] * decay_factor;
|
|
||||||
}
|
|
||||||
barrier();
|
|
||||||
|
|
||||||
if (col < S_V) {
|
|
||||||
float ws = 0.0;
|
|
||||||
[[unroll]] for (uint i = 0; i < S_V; i += 4) {
|
|
||||||
ws += dot(vec4(sh_w[i], sh_w[i+1], sh_w[i+2], sh_w[i+3]),
|
|
||||||
vec4(state[i], state[i+1], state[i+2], state[i+3]));
|
|
||||||
}
|
|
||||||
|
|
||||||
float vnew = u_in[wu_base + t * S_V + col] - ws;
|
|
||||||
float vnew_scaled = clamp(vnew * scale, -65504.0, 65504.0);
|
|
||||||
|
|
||||||
float16_t v0 = float16_t(subgroupShuffle(vnew_scaled, si & ~3u));
|
|
||||||
float16_t v1 = float16_t(subgroupShuffle(vnew_scaled, (si & ~3u) + 1u));
|
|
||||||
float16_t v2 = float16_t(subgroupShuffle(vnew_scaled, (si & ~3u) + 2u));
|
|
||||||
float16_t v3 = float16_t(subgroupShuffle(vnew_scaled, (si & ~3u) + 3u));
|
|
||||||
if ((si & 3u) == 0u) {
|
|
||||||
sh_kv[t * QK_STRIDE + (col >> 2)] = f16vec4(v0, v1, v2, v3);
|
|
||||||
}
|
|
||||||
|
|
||||||
[[unroll]] for (uint i = 0; i < S_V; i++) {
|
|
||||||
delta[i] += sh_kg[i] * vnew;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
barrier();
|
|
||||||
}
|
|
||||||
|
|
||||||
// ============================================================
|
|
||||||
// Step 5: O_intra = A_decayed @ vnew → dst (coopmat GEMM)
|
|
||||||
// ============================================================
|
|
||||||
|
|
||||||
if (chunk_len == CHUNK_SIZE) {
|
|
||||||
coopmat<float16_t, gl_ScopeSubgroup, TM, TK, gl_MatrixUseA> A_mat;
|
|
||||||
coopmat<float16_t, gl_ScopeSubgroup, TK, TN, gl_MatrixUseB> V_mat;
|
|
||||||
|
|
||||||
[[unroll]] for (uint dn = 0; dn < D_TILES; dn++) {
|
|
||||||
coopmat<float, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> O_acc;
|
|
||||||
|
|
||||||
coopMatLoad(O_acc, dst,
|
|
||||||
attn_off + (chunk_start + sg_id * TM) * S_V * H + dn * TN,
|
|
||||||
S_V * H, gl_CooperativeMatrixLayoutRowMajor);
|
|
||||||
|
|
||||||
[[unroll]] for (uint dk = 0; dk < C_TILES; dk++) {
|
|
||||||
coopMatLoad(A_mat, sh_adecay,
|
|
||||||
sg_id * TM * ATTN_V4_STRIDE + dk * (TK / 4),
|
|
||||||
ATTN_V4_STRIDE, gl_CooperativeMatrixLayoutRowMajor);
|
|
||||||
|
|
||||||
coopMatLoad(V_mat, sh_kv,
|
|
||||||
dk * TK * QK_STRIDE + dn * (TN / 4),
|
|
||||||
QK_STRIDE, gl_CooperativeMatrixLayoutRowMajor);
|
|
||||||
|
|
||||||
O_acc = coopMatMulAdd(A_mat, V_mat, O_acc);
|
|
||||||
}
|
|
||||||
|
|
||||||
coopMatStore(O_acc, dst,
|
|
||||||
attn_off + (chunk_start + sg_id * TM) * S_V * H + dn * TN,
|
|
||||||
S_V * H, gl_CooperativeMatrixLayoutRowMajor);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if (col < S_V) {
|
|
||||||
const uint col4 = col / 4;
|
|
||||||
const uint comp = col % 4;
|
|
||||||
|
|
||||||
vec4 my_vnew_v4[CHUNK_SIZE / 4];
|
|
||||||
[[unroll]] for (uint j4 = 0; j4 < CHUNK_SIZE / 4; j4++) {
|
|
||||||
my_vnew_v4[j4] = vec4(
|
|
||||||
float(sh_kv[(j4*4 ) * QK_STRIDE + col4][comp]),
|
|
||||||
float(sh_kv[(j4*4+1) * QK_STRIDE + col4][comp]),
|
|
||||||
float(sh_kv[(j4*4+2) * QK_STRIDE + col4][comp]),
|
|
||||||
float(sh_kv[(j4*4+3) * QK_STRIDE + col4][comp]));
|
|
||||||
}
|
|
||||||
|
|
||||||
for (uint t = 0; t < chunk_len; t++) {
|
|
||||||
float o_intra = 0.0;
|
|
||||||
[[unroll]] for (uint j4 = 0; j4 < CHUNK_SIZE / 4; j4++) {
|
|
||||||
o_intra += dot(vec4(sh_adecay[t * ATTN_V4_STRIDE + j4]), my_vnew_v4[j4]);
|
|
||||||
}
|
|
||||||
dst[attn_off + (chunk_start + t) * S_V * H + col] += o_intra;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ============================================================
|
|
||||||
// Step 6: State update
|
|
||||||
// ============================================================
|
|
||||||
|
|
||||||
if (col < S_V) {
|
|
||||||
const float total_decay = exp(g_total);
|
|
||||||
[[unroll]] for (uint i = 0; i < S_V; i++) {
|
|
||||||
state[i] = total_decay * state[i] + delta[i];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ================================================================
|
|
||||||
// Write final state to dst
|
|
||||||
// ================================================================
|
|
||||||
|
|
||||||
if (col < S_V) {
|
|
||||||
[[unroll]] for (uint i = 0; i < S_V; i++) {
|
|
||||||
dst[s_off + state_base + i * S_V + col] = state[i];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
@ -992,7 +992,6 @@ void process_shaders() {
|
||||||
string_to_spv("gated_delta_net_chunk_inter_f32", "gated_delta_net_chunk_inter.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
|
string_to_spv("gated_delta_net_chunk_inter_f32", "gated_delta_net_chunk_inter.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
|
||||||
string_to_spv("gated_delta_net_chunk_output_f32", "gated_delta_net_chunk_output.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
|
string_to_spv("gated_delta_net_chunk_output_f32", "gated_delta_net_chunk_output.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
|
||||||
string_to_spv("gated_delta_net_chunk_output_cm1_f32", "gated_delta_net_chunk_output_cm1.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
|
string_to_spv("gated_delta_net_chunk_output_cm1_f32", "gated_delta_net_chunk_output_cm1.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
|
||||||
string_to_spv("gated_delta_net_chunk_fused_cm_f32", "gated_delta_net_chunk_fused_cm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
|
|
||||||
|
|
||||||
string_to_spv("opt_step_adamw_f32", "opt_step_adamw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
|
string_to_spv("opt_step_adamw_f32", "opt_step_adamw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
|
||||||
string_to_spv("opt_step_sgd_f32", "opt_step_sgd.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
|
string_to_spv("opt_step_sgd_f32", "opt_step_sgd.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue