diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index d1c470b1c1..773fe7ddcc 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -831,6 +831,7 @@ struct vk_device_struct { vk_pipeline pipeline_gated_delta_net_chunk_inter; vk_pipeline pipeline_gated_delta_net_chunk_output; vk_pipeline pipeline_gated_delta_net_chunk_output_cm; + vk_pipeline pipeline_gated_delta_net_chunk_fused_cm; vk_pipeline pipeline_ssm_scan_f32_d128; vk_pipeline pipeline_ssm_scan_f32_d256; vk_pipeline pipeline_ssm_conv_f32; @@ -4629,6 +4630,9 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_gated_delta_net_chunk_output_cm, "gated_delta_net_chunk_output_cm1_f32_d128", gated_delta_net_chunk_output_cm1_f32_len, gated_delta_net_chunk_output_cm1_f32_data, "main", 6, sizeof(vk_op_gated_delta_net_chunk_push_constants), {1, 1, 1}, {256, 64, 128}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_gated_delta_net_chunk_fused_cm, "gated_delta_net_chunk_fused_cm_f32_d128", + gated_delta_net_chunk_fused_cm_f32_len, gated_delta_net_chunk_fused_cm_f32_data, "main", + 8, sizeof(vk_op_gated_delta_net_chunk_push_constants), {1, 1, 1}, {256, 64, 128}, 1, true); } if (device->subgroup_arithmetic && device->subgroup_require_full_support) { @@ -10470,45 +10474,12 @@ static void ggml_vk_gated_delta_net(ggml_backend_vk_context * ctx, vk_context& s // Chunked parallel path (PP acceleration) const uint32_t n_chunks = (n_tokens + GDN_CHUNK_SIZE - 1) / GDN_CHUNK_SIZE; - vk_pipeline pl_intra = ctx->device->pipeline_gated_delta_net_chunk_intra; - vk_pipeline pl_inter = ctx->device->pipeline_gated_delta_net_chunk_inter; - vk_pipeline pl_output = ctx->device->pipeline_gated_delta_net_chunk_output_cm - ? ctx->device->pipeline_gated_delta_net_chunk_output_cm - : ctx->device->pipeline_gated_delta_net_chunk_output; + vk_pipeline pl_intra = ctx->device->pipeline_gated_delta_net_chunk_intra; + vk_pipeline pl_fused = ctx->device->pipeline_gated_delta_net_chunk_fused_cm; - ggml_pipeline_request_descriptor_sets(ctx, pl_intra, 1); - ggml_pipeline_request_descriptor_sets(ctx, pl_inter, 1); - ggml_pipeline_request_descriptor_sets(ctx, pl_output, 1); - - // Scratch buffer layout within prealloc_split_k - const size_t wu_size = (size_t)n_seqs * n_chunks * H * GDN_CHUNK_SIZE * S_v * sizeof(float); - const size_t d_size = (size_t)n_seqs * n_chunks * H * sizeof(float); - const size_t g_size = (size_t)n_seqs * n_chunks * H * GDN_CHUNK_SIZE * sizeof(float); - const size_t h_size = (size_t)n_seqs * n_chunks * H * S_v * S_v * sizeof(float); - - const size_t w_off = 0; - const size_t u_off = wu_size; - const size_t vn_off = 2 * wu_size; - const size_t dec_off = 3 * wu_size; - const size_t gcum_off = dec_off + d_size; - const size_t h_off = gcum_off + g_size; - const size_t total_scratch = h_off + h_size; - - if (ctx->prealloc_size_split_k < total_scratch) { - ctx->prealloc_size_split_k = total_scratch; - ggml_vk_preallocate_buffers(ctx, subctx); - } - - if (ctx->prealloc_split_k_need_sync) { - ggml_vk_sync_buffers(ctx, subctx); - } - - vk_subbuffer scratch_w = { ctx->prealloc_split_k, w_off, wu_size }; - vk_subbuffer scratch_u = { ctx->prealloc_split_k, u_off, wu_size }; - vk_subbuffer scratch_vnew = { ctx->prealloc_split_k, vn_off, wu_size }; - vk_subbuffer scratch_dec = { ctx->prealloc_split_k, dec_off, d_size }; - vk_subbuffer scratch_gcum = { ctx->prealloc_split_k, gcum_off, g_size }; - vk_subbuffer scratch_h = { ctx->prealloc_split_k, h_off, h_size }; + const size_t wu_size = (size_t)n_seqs * n_chunks * H * GDN_CHUNK_SIZE * S_v * sizeof(float); + const size_t d_size = (size_t)n_seqs * n_chunks * H * sizeof(float); + const size_t g_size = (size_t)n_seqs * n_chunks * H * GDN_CHUNK_SIZE * sizeof(float); const vk_op_gated_delta_net_chunk_push_constants pc = { H, n_tokens, n_seqs, @@ -10519,29 +10490,95 @@ static void ggml_vk_gated_delta_net(ggml_backend_vk_context * ctx, vk_context& s n_chunks, s_off }; - // Dispatch 1: Intra-chunk (parallel across chunks) - // Bindings: K, V, G, Beta, W_out, U_out, Decay_out, GCum_out - ggml_vk_dispatch_pipeline(ctx, subctx, pl_intra, - {src_buf[1], src_buf[2], src_buf[3], src_buf[4], - scratch_w, scratch_u, scratch_dec, scratch_gcum}, - pc, { n_chunks * H, n_seqs, 1u }); + if (pl_fused) { + // Fused inter+output path: 2 dispatches, no vnew/h scratch + ggml_pipeline_request_descriptor_sets(ctx, pl_intra, 1); + ggml_pipeline_request_descriptor_sets(ctx, pl_fused, 1); - ggml_vk_sync_buffers(ctx, subctx); + const size_t w_off = 0; + const size_t u_off = wu_size; + const size_t dec_off = 2 * wu_size; + const size_t gcum_off = dec_off + d_size; + const size_t total_scratch = gcum_off + g_size; - // Dispatch 2: Inter-chunk state propagation (sequential across chunks) - // Bindings: K, W, U, Decay, GCum, State, H_out, VNew_out, Final(dst) - ggml_vk_dispatch_pipeline(ctx, subctx, pl_inter, - {src_buf[1], scratch_w, scratch_u, scratch_dec, scratch_gcum, - src_buf[5], scratch_h, scratch_vnew, dst_buf}, - pc, { H, n_seqs, 1u }); + if (ctx->prealloc_size_split_k < total_scratch) { + ctx->prealloc_size_split_k = total_scratch; + ggml_vk_preallocate_buffers(ctx, subctx); + } - ggml_vk_sync_buffers(ctx, subctx); + if (ctx->prealloc_split_k_need_sync) { + ggml_vk_sync_buffers(ctx, subctx); + } - // Dispatch 3: Output (parallel across chunks) - // Bindings: Q, K, H, VNew, GCum, Dst - ggml_vk_dispatch_pipeline(ctx, subctx, pl_output, - {src_buf[0], src_buf[1], scratch_h, scratch_vnew, scratch_gcum, dst_buf}, - pc, { n_chunks * H, n_seqs, 1u }); + vk_subbuffer scratch_w = { ctx->prealloc_split_k, w_off, wu_size }; + vk_subbuffer scratch_u = { ctx->prealloc_split_k, u_off, wu_size }; + vk_subbuffer scratch_dec = { ctx->prealloc_split_k, dec_off, d_size }; + vk_subbuffer scratch_gcum = { ctx->prealloc_split_k, gcum_off, g_size }; + + ggml_vk_dispatch_pipeline(ctx, subctx, pl_intra, + {src_buf[1], src_buf[2], src_buf[3], src_buf[4], + scratch_w, scratch_u, scratch_dec, scratch_gcum}, + pc, { n_chunks * H, n_seqs, 1u }); + + ggml_vk_sync_buffers(ctx, subctx); + + // Bindings: Q, K, W, U, Decay, GCum, State, Dst + ggml_vk_dispatch_pipeline(ctx, subctx, pl_fused, + {src_buf[0], src_buf[1], scratch_w, scratch_u, + scratch_dec, scratch_gcum, src_buf[5], dst_buf}, + pc, { H, n_seqs, 1u }); + } else { + // Fallback: 3-dispatch path (no coopmat) + vk_pipeline pl_inter = ctx->device->pipeline_gated_delta_net_chunk_inter; + vk_pipeline pl_output = ctx->device->pipeline_gated_delta_net_chunk_output; + + ggml_pipeline_request_descriptor_sets(ctx, pl_intra, 1); + ggml_pipeline_request_descriptor_sets(ctx, pl_inter, 1); + ggml_pipeline_request_descriptor_sets(ctx, pl_output, 1); + + const size_t h_size = (size_t)n_seqs * n_chunks * H * S_v * S_v * sizeof(float); + const size_t w_off = 0; + const size_t u_off = wu_size; + const size_t vn_off = 2 * wu_size; + const size_t dec_off = 3 * wu_size; + const size_t gcum_off = dec_off + d_size; + const size_t h_off = gcum_off + g_size; + const size_t total_scratch = h_off + h_size; + + if (ctx->prealloc_size_split_k < total_scratch) { + ctx->prealloc_size_split_k = total_scratch; + ggml_vk_preallocate_buffers(ctx, subctx); + } + + if (ctx->prealloc_split_k_need_sync) { + ggml_vk_sync_buffers(ctx, subctx); + } + + vk_subbuffer scratch_w = { ctx->prealloc_split_k, w_off, wu_size }; + vk_subbuffer scratch_u = { ctx->prealloc_split_k, u_off, wu_size }; + vk_subbuffer scratch_vnew = { ctx->prealloc_split_k, vn_off, wu_size }; + vk_subbuffer scratch_dec = { ctx->prealloc_split_k, dec_off, d_size }; + vk_subbuffer scratch_gcum = { ctx->prealloc_split_k, gcum_off, g_size }; + vk_subbuffer scratch_h = { ctx->prealloc_split_k, h_off, h_size }; + + ggml_vk_dispatch_pipeline(ctx, subctx, pl_intra, + {src_buf[1], src_buf[2], src_buf[3], src_buf[4], + scratch_w, scratch_u, scratch_dec, scratch_gcum}, + pc, { n_chunks * H, n_seqs, 1u }); + + ggml_vk_sync_buffers(ctx, subctx); + + ggml_vk_dispatch_pipeline(ctx, subctx, pl_inter, + {src_buf[1], scratch_w, scratch_u, scratch_dec, scratch_gcum, + src_buf[5], scratch_h, scratch_vnew, dst_buf}, + pc, { H, n_seqs, 1u }); + + ggml_vk_sync_buffers(ctx, subctx); + + ggml_vk_dispatch_pipeline(ctx, subctx, pl_output, + {src_buf[0], src_buf[1], scratch_h, scratch_vnew, scratch_gcum, dst_buf}, + pc, { n_chunks * H, n_seqs, 1u }); + } ctx->prealloc_split_k_need_sync = true; } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net_chunk_fused_cm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net_chunk_fused_cm.comp new file mode 100644 index 0000000000..4d676034c2 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net_chunk_fused_cm.comp @@ -0,0 +1,340 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : enable +#extension GL_EXT_shader_16bit_storage : require +#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require +#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require +#extension GL_KHR_shader_subgroup_basic : enable +#extension GL_KHR_shader_subgroup_arithmetic : enable +#extension GL_KHR_shader_subgroup_shuffle : enable +#extension GL_KHR_memory_scope_semantics : enable +#extension GL_KHR_cooperative_matrix : enable + +#include "types.glsl" + +// Fused inter+output kernel for chunked gated delta net +// +// Merges the inter-chunk state propagation and output computation into +// one dispatch, eliminating VNew and H_snapshot scratch buffers plus +// one dispatch barrier. +// +// Step 1: Load K, Q, gcum → shared +// Step 2: A = Q @ K^T (coopmat) +// Step 3: Decay mask → sh_adecay + O_inter = Q @ state → dst (parallel) +// Step 4: vnew = U - W@state → sh_kv (f16), accumulate delta +// Step 5: O_intra = A_decayed @ vnew → dst (coopmat GEMM) +// Step 6: state = exp(decay) * state + delta +// +// Grid: (H, n_seqs, 1) — sequential chunk loop +// Workgroup: 256 threads = 4 subgroups + +layout(constant_id = 0) const uint WG_SIZE = 256; +layout(constant_id = 1) const uint CHUNK_SIZE = 64; +layout(constant_id = 2) const uint S_V = 128; + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +layout(push_constant) uniform Parameters { + uint H; + uint n_tokens; + uint n_seqs; + uint sq1, sq2, sq3; + uint sv1, sv2, sv3; + uint sb1, sb2, sb3; + uint neq1, rq3; + uint n_chunks; + uint s_off; +}; + +layout(binding = 0) readonly buffer QBuf { float q_in[]; }; +layout(binding = 1) readonly buffer KBuf { float k_in[]; }; +layout(binding = 2) readonly buffer WBuf { float w_in[]; }; +layout(binding = 3) readonly buffer UBuf { float u_in[]; }; +layout(binding = 4) readonly buffer DecBuf { float decay_in[]; }; +layout(binding = 5) readonly buffer GCumBuf { float gcum_in[]; }; +layout(binding = 6) readonly buffer StBuf { float state_in[]; }; +layout(binding = 7) buffer DstBuf { float dst[]; }; + +const uint TM = 16; +const uint TN = 16; +const uint TK = 16; + +const uint C_TILES = CHUNK_SIZE / TM; +const uint D_TILES = S_V / TN; + +const uint QK_STRIDE = S_V / 4 + 2; +const uint ATTN_V4_STRIDE = CHUNK_SIZE / 4 + 2; + +shared f16vec4 sh_q[CHUNK_SIZE * QK_STRIDE]; +shared f16vec4 sh_kv[CHUNK_SIZE * QK_STRIDE]; +shared vec4 sh_attn[CHUNK_SIZE * ATTN_V4_STRIDE]; +shared f16vec4 sh_adecay[CHUNK_SIZE * ATTN_V4_STRIDE]; +shared float sh_gcum[CHUNK_SIZE]; +shared float sh_w[S_V]; +shared float sh_kg[S_V]; + +void main() { + const uint tid = gl_LocalInvocationIndex; + const uint sg_id = gl_SubgroupID; + const uint si = gl_SubgroupInvocationID; + + const uint head_id = gl_WorkGroupID.x; + const uint seq_id = gl_WorkGroupID.y; + const uint col = tid; + + const uint iq1 = head_id % neq1; + const uint iq3 = seq_id / rq3; + const float scale = 1.0 / sqrt(float(S_V)); + + const uint state_size = S_V * S_V; + const uint state_base = (seq_id * H + head_id) * state_size; + const uint attn_off = (seq_id * n_tokens * H + head_id) * S_V; + + // ================================================================ + // Load state into registers (threads 0-127) + // ================================================================ + + float state[S_V]; + if (col < S_V) { + [[unroll]] for (uint i = 0; i < S_V; i++) { + state[i] = state_in[state_base + i * S_V + col]; + } + } + + // ================================================================ + // Chunk loop + // ================================================================ + + for (uint c = 0; c < n_chunks; c++) { + const uint chunk_start = c * CHUNK_SIZE; + const uint chunk_len = min(CHUNK_SIZE, n_tokens - chunk_start); + const uint wu_base = ((seq_id * n_chunks + c) * H + head_id) * CHUNK_SIZE * S_V; + const uint gcum_base = ((seq_id * n_chunks + c) * H + head_id) * CHUNK_SIZE; + const uint decay_idx = (seq_id * n_chunks + c) * H + head_id; + const float g_total = decay_in[decay_idx]; + + // ============================================================ + // Step 1: Load K → sh_kv, Q → sh_q, gcum → sh_gcum + // ============================================================ + + if (tid < CHUNK_SIZE) { + sh_gcum[tid] = (tid < chunk_len) ? gcum_in[gcum_base + tid] : 0.0; + } + + for (uint idx = tid; idx < CHUNK_SIZE * (S_V / 4); idx += WG_SIZE) { + const uint row = idx / (S_V / 4); + const uint col4 = idx % (S_V / 4); + f16vec4 q_val = f16vec4(0.0); + f16vec4 k_val = f16vec4(0.0); + if (row < chunk_len) { + const uint off = iq3 * sq3 + (chunk_start + row) * sq2 + iq1 * sq1 + col4 * 4; + q_val = f16vec4(q_in[off], q_in[off + 1], q_in[off + 2], q_in[off + 3]); + k_val = f16vec4(k_in[off], k_in[off + 1], k_in[off + 2], k_in[off + 3]); + } + sh_q[row * QK_STRIDE + col4] = q_val; + sh_kv[row * QK_STRIDE + col4] = k_val; + } + + barrier(); + + // ============================================================ + // Step 2: Q @ K^T (coopmat, f16→f32) + // ============================================================ + + coopmat A_acc[C_TILES]; + [[unroll]] for (uint tj = 0; tj < C_TILES; tj++) { + A_acc[tj] = coopmat(0.0); + } + + { + coopmat Q_mat; + coopmat KT_mat; + + [[unroll]] for (uint dk = 0; dk < D_TILES; dk++) { + coopMatLoad(Q_mat, sh_q, + sg_id * TM * QK_STRIDE + dk * (TK / 4), + QK_STRIDE, gl_CooperativeMatrixLayoutRowMajor); + + [[unroll]] for (uint tj = 0; tj < C_TILES; tj++) { + coopMatLoad(KT_mat, sh_kv, + tj * TN * QK_STRIDE + dk * (TK / 4), + QK_STRIDE, gl_CooperativeMatrixLayoutColumnMajor); + + A_acc[tj] = coopMatMulAdd(Q_mat, KT_mat, A_acc[tj]); + } + } + } + + [[unroll]] for (uint tj = 0; tj < C_TILES; tj++) { + coopMatStore(A_acc[tj], sh_attn, + sg_id * TM * ATTN_V4_STRIDE + tj * (TN / 4), + ATTN_V4_STRIDE, gl_CooperativeMatrixLayoutRowMajor); + } + + barrier(); + + // ============================================================ + // Step 3: Decay mask + inter-chunk output (parallel, no conflict) + // ============================================================ + + for (uint idx = tid; idx < CHUNK_SIZE * (CHUNK_SIZE / 4); idx += WG_SIZE) { + const uint t = idx / (CHUNK_SIZE / 4); + const uint j4 = idx % (CHUNK_SIZE / 4); + f16vec4 val = f16vec4(0.0); + if (t < chunk_len) { + const float g_t = sh_gcum[t]; + [[unroll]] for (uint k = 0; k < 4; k++) { + const uint j = j4 * 4 + k; + if (j <= t && j < chunk_len) { + val[k] = float16_t(clamp( + exp(g_t - sh_gcum[j]) * sh_attn[t * ATTN_V4_STRIDE + j4][k], + -65504.0, 65504.0)); + } + } + } + sh_adecay[t * ATTN_V4_STRIDE + j4] = val; + } + + if (col < S_V) { + for (uint t = 0; t < chunk_len; t++) { + float o_inter = 0.0; + [[unroll]] for (uint i = 0; i < S_V / 4; i++) { + vec4 q_f32 = vec4(sh_q[t * QK_STRIDE + i]); + o_inter += dot(q_f32, + vec4(state[i*4], state[i*4+1], + state[i*4+2], state[i*4+3])); + } + dst[attn_off + (chunk_start + t) * S_V * H + col] = exp(sh_gcum[t]) * o_inter * scale; + } + } + + barrier(); + + // ============================================================ + // Step 4: vnew = U - W@state → sh_kv, accumulate delta + // Threads 0-127: load w, compute vnew, write sh_kv via shuffle + // Threads 128-255: load k_gated into sh_kg + // ============================================================ + + float delta[S_V]; + if (col < S_V) { + [[unroll]] for (uint i = 0; i < S_V; i++) { + delta[i] = 0.0; + } + } + + for (uint t = 0; t < chunk_len; t++) { + if (col < S_V) { + sh_w[col] = w_in[wu_base + t * S_V + col]; + } + if (tid >= 128 && tid < 256) { + const uint lc = tid - 128; + const float gcum_t = gcum_in[gcum_base + t]; + const float decay_factor = exp(g_total - gcum_t); + const uint k_off = iq3 * sq3 + (chunk_start + t) * sq2 + iq1 * sq1; + sh_kg[lc] = k_in[k_off + lc] * decay_factor; + } + barrier(); + + if (col < S_V) { + float ws = 0.0; + [[unroll]] for (uint i = 0; i < S_V; i += 4) { + ws += dot(vec4(sh_w[i], sh_w[i+1], sh_w[i+2], sh_w[i+3]), + vec4(state[i], state[i+1], state[i+2], state[i+3])); + } + + float vnew = u_in[wu_base + t * S_V + col] - ws; + float vnew_scaled = clamp(vnew * scale, -65504.0, 65504.0); + + float16_t v0 = float16_t(subgroupShuffle(vnew_scaled, si & ~3u)); + float16_t v1 = float16_t(subgroupShuffle(vnew_scaled, (si & ~3u) + 1u)); + float16_t v2 = float16_t(subgroupShuffle(vnew_scaled, (si & ~3u) + 2u)); + float16_t v3 = float16_t(subgroupShuffle(vnew_scaled, (si & ~3u) + 3u)); + if ((si & 3u) == 0u) { + sh_kv[t * QK_STRIDE + (col >> 2)] = f16vec4(v0, v1, v2, v3); + } + + [[unroll]] for (uint i = 0; i < S_V; i++) { + delta[i] += sh_kg[i] * vnew; + } + } + barrier(); + } + + // ============================================================ + // Step 5: O_intra = A_decayed @ vnew → dst (coopmat GEMM) + // ============================================================ + + if (chunk_len == CHUNK_SIZE) { + coopmat A_mat; + coopmat V_mat; + + [[unroll]] for (uint dn = 0; dn < D_TILES; dn++) { + coopmat O_acc; + + coopMatLoad(O_acc, dst, + attn_off + (chunk_start + sg_id * TM) * S_V * H + dn * TN, + S_V * H, gl_CooperativeMatrixLayoutRowMajor); + + [[unroll]] for (uint dk = 0; dk < C_TILES; dk++) { + coopMatLoad(A_mat, sh_adecay, + sg_id * TM * ATTN_V4_STRIDE + dk * (TK / 4), + ATTN_V4_STRIDE, gl_CooperativeMatrixLayoutRowMajor); + + coopMatLoad(V_mat, sh_kv, + dk * TK * QK_STRIDE + dn * (TN / 4), + QK_STRIDE, gl_CooperativeMatrixLayoutRowMajor); + + O_acc = coopMatMulAdd(A_mat, V_mat, O_acc); + } + + coopMatStore(O_acc, dst, + attn_off + (chunk_start + sg_id * TM) * S_V * H + dn * TN, + S_V * H, gl_CooperativeMatrixLayoutRowMajor); + } + } else { + if (col < S_V) { + const uint col4 = col / 4; + const uint comp = col % 4; + + vec4 my_vnew_v4[CHUNK_SIZE / 4]; + [[unroll]] for (uint j4 = 0; j4 < CHUNK_SIZE / 4; j4++) { + my_vnew_v4[j4] = vec4( + float(sh_kv[(j4*4 ) * QK_STRIDE + col4][comp]), + float(sh_kv[(j4*4+1) * QK_STRIDE + col4][comp]), + float(sh_kv[(j4*4+2) * QK_STRIDE + col4][comp]), + float(sh_kv[(j4*4+3) * QK_STRIDE + col4][comp])); + } + + for (uint t = 0; t < chunk_len; t++) { + float o_intra = 0.0; + [[unroll]] for (uint j4 = 0; j4 < CHUNK_SIZE / 4; j4++) { + o_intra += dot(vec4(sh_adecay[t * ATTN_V4_STRIDE + j4]), my_vnew_v4[j4]); + } + dst[attn_off + (chunk_start + t) * S_V * H + col] += o_intra; + } + } + } + + // ============================================================ + // Step 6: State update + // ============================================================ + + if (col < S_V) { + const float total_decay = exp(g_total); + [[unroll]] for (uint i = 0; i < S_V; i++) { + state[i] = total_decay * state[i] + delta[i]; + } + } + } + + // ================================================================ + // Write final state to dst + // ================================================================ + + if (col < S_V) { + [[unroll]] for (uint i = 0; i < S_V; i++) { + dst[s_off + state_base + i * S_V + col] = state[i]; + } + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index 2ddd4f2f8c..9fcdcb9545 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -992,6 +992,7 @@ void process_shaders() { string_to_spv("gated_delta_net_chunk_inter_f32", "gated_delta_net_chunk_inter.comp", merge_maps(base_dict, {{"A_TYPE", "float"}})); string_to_spv("gated_delta_net_chunk_output_f32", "gated_delta_net_chunk_output.comp", merge_maps(base_dict, {{"A_TYPE", "float"}})); string_to_spv("gated_delta_net_chunk_output_cm1_f32", "gated_delta_net_chunk_output_cm1.comp", merge_maps(base_dict, {{"A_TYPE", "float"}})); + string_to_spv("gated_delta_net_chunk_fused_cm_f32", "gated_delta_net_chunk_fused_cm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}})); string_to_spv("opt_step_adamw_f32", "opt_step_adamw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}})); string_to_spv("opt_step_sgd_f32", "opt_step_sgd.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));