diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 773fe7ddcc..d1c470b1c1 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -831,7 +831,6 @@ struct vk_device_struct { vk_pipeline pipeline_gated_delta_net_chunk_inter; vk_pipeline pipeline_gated_delta_net_chunk_output; vk_pipeline pipeline_gated_delta_net_chunk_output_cm; - vk_pipeline pipeline_gated_delta_net_chunk_fused_cm; vk_pipeline pipeline_ssm_scan_f32_d128; vk_pipeline pipeline_ssm_scan_f32_d256; vk_pipeline pipeline_ssm_conv_f32; @@ -4630,9 +4629,6 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_gated_delta_net_chunk_output_cm, "gated_delta_net_chunk_output_cm1_f32_d128", gated_delta_net_chunk_output_cm1_f32_len, gated_delta_net_chunk_output_cm1_f32_data, "main", 6, sizeof(vk_op_gated_delta_net_chunk_push_constants), {1, 1, 1}, {256, 64, 128}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_gated_delta_net_chunk_fused_cm, "gated_delta_net_chunk_fused_cm_f32_d128", - gated_delta_net_chunk_fused_cm_f32_len, gated_delta_net_chunk_fused_cm_f32_data, "main", - 8, sizeof(vk_op_gated_delta_net_chunk_push_constants), {1, 1, 1}, {256, 64, 128}, 1, true); } if (device->subgroup_arithmetic && device->subgroup_require_full_support) { @@ -10474,12 +10470,45 @@ static void ggml_vk_gated_delta_net(ggml_backend_vk_context * ctx, vk_context& s // Chunked parallel path (PP acceleration) const uint32_t n_chunks = (n_tokens + GDN_CHUNK_SIZE - 1) / GDN_CHUNK_SIZE; - vk_pipeline pl_intra = ctx->device->pipeline_gated_delta_net_chunk_intra; - vk_pipeline pl_fused = ctx->device->pipeline_gated_delta_net_chunk_fused_cm; + vk_pipeline pl_intra = ctx->device->pipeline_gated_delta_net_chunk_intra; + vk_pipeline pl_inter = ctx->device->pipeline_gated_delta_net_chunk_inter; + vk_pipeline pl_output = ctx->device->pipeline_gated_delta_net_chunk_output_cm + ? ctx->device->pipeline_gated_delta_net_chunk_output_cm + : ctx->device->pipeline_gated_delta_net_chunk_output; - const size_t wu_size = (size_t)n_seqs * n_chunks * H * GDN_CHUNK_SIZE * S_v * sizeof(float); - const size_t d_size = (size_t)n_seqs * n_chunks * H * sizeof(float); - const size_t g_size = (size_t)n_seqs * n_chunks * H * GDN_CHUNK_SIZE * sizeof(float); + ggml_pipeline_request_descriptor_sets(ctx, pl_intra, 1); + ggml_pipeline_request_descriptor_sets(ctx, pl_inter, 1); + ggml_pipeline_request_descriptor_sets(ctx, pl_output, 1); + + // Scratch buffer layout within prealloc_split_k + const size_t wu_size = (size_t)n_seqs * n_chunks * H * GDN_CHUNK_SIZE * S_v * sizeof(float); + const size_t d_size = (size_t)n_seqs * n_chunks * H * sizeof(float); + const size_t g_size = (size_t)n_seqs * n_chunks * H * GDN_CHUNK_SIZE * sizeof(float); + const size_t h_size = (size_t)n_seqs * n_chunks * H * S_v * S_v * sizeof(float); + + const size_t w_off = 0; + const size_t u_off = wu_size; + const size_t vn_off = 2 * wu_size; + const size_t dec_off = 3 * wu_size; + const size_t gcum_off = dec_off + d_size; + const size_t h_off = gcum_off + g_size; + const size_t total_scratch = h_off + h_size; + + if (ctx->prealloc_size_split_k < total_scratch) { + ctx->prealloc_size_split_k = total_scratch; + ggml_vk_preallocate_buffers(ctx, subctx); + } + + if (ctx->prealloc_split_k_need_sync) { + ggml_vk_sync_buffers(ctx, subctx); + } + + vk_subbuffer scratch_w = { ctx->prealloc_split_k, w_off, wu_size }; + vk_subbuffer scratch_u = { ctx->prealloc_split_k, u_off, wu_size }; + vk_subbuffer scratch_vnew = { ctx->prealloc_split_k, vn_off, wu_size }; + vk_subbuffer scratch_dec = { ctx->prealloc_split_k, dec_off, d_size }; + vk_subbuffer scratch_gcum = { ctx->prealloc_split_k, gcum_off, g_size }; + vk_subbuffer scratch_h = { ctx->prealloc_split_k, h_off, h_size }; const vk_op_gated_delta_net_chunk_push_constants pc = { H, n_tokens, n_seqs, @@ -10490,95 +10519,29 @@ static void ggml_vk_gated_delta_net(ggml_backend_vk_context * ctx, vk_context& s n_chunks, s_off }; - if (pl_fused) { - // Fused inter+output path: 2 dispatches, no vnew/h scratch - ggml_pipeline_request_descriptor_sets(ctx, pl_intra, 1); - ggml_pipeline_request_descriptor_sets(ctx, pl_fused, 1); + // Dispatch 1: Intra-chunk (parallel across chunks) + // Bindings: K, V, G, Beta, W_out, U_out, Decay_out, GCum_out + ggml_vk_dispatch_pipeline(ctx, subctx, pl_intra, + {src_buf[1], src_buf[2], src_buf[3], src_buf[4], + scratch_w, scratch_u, scratch_dec, scratch_gcum}, + pc, { n_chunks * H, n_seqs, 1u }); - const size_t w_off = 0; - const size_t u_off = wu_size; - const size_t dec_off = 2 * wu_size; - const size_t gcum_off = dec_off + d_size; - const size_t total_scratch = gcum_off + g_size; + ggml_vk_sync_buffers(ctx, subctx); - if (ctx->prealloc_size_split_k < total_scratch) { - ctx->prealloc_size_split_k = total_scratch; - ggml_vk_preallocate_buffers(ctx, subctx); - } + // Dispatch 2: Inter-chunk state propagation (sequential across chunks) + // Bindings: K, W, U, Decay, GCum, State, H_out, VNew_out, Final(dst) + ggml_vk_dispatch_pipeline(ctx, subctx, pl_inter, + {src_buf[1], scratch_w, scratch_u, scratch_dec, scratch_gcum, + src_buf[5], scratch_h, scratch_vnew, dst_buf}, + pc, { H, n_seqs, 1u }); - if (ctx->prealloc_split_k_need_sync) { - ggml_vk_sync_buffers(ctx, subctx); - } + ggml_vk_sync_buffers(ctx, subctx); - vk_subbuffer scratch_w = { ctx->prealloc_split_k, w_off, wu_size }; - vk_subbuffer scratch_u = { ctx->prealloc_split_k, u_off, wu_size }; - vk_subbuffer scratch_dec = { ctx->prealloc_split_k, dec_off, d_size }; - vk_subbuffer scratch_gcum = { ctx->prealloc_split_k, gcum_off, g_size }; - - ggml_vk_dispatch_pipeline(ctx, subctx, pl_intra, - {src_buf[1], src_buf[2], src_buf[3], src_buf[4], - scratch_w, scratch_u, scratch_dec, scratch_gcum}, - pc, { n_chunks * H, n_seqs, 1u }); - - ggml_vk_sync_buffers(ctx, subctx); - - // Bindings: Q, K, W, U, Decay, GCum, State, Dst - ggml_vk_dispatch_pipeline(ctx, subctx, pl_fused, - {src_buf[0], src_buf[1], scratch_w, scratch_u, - scratch_dec, scratch_gcum, src_buf[5], dst_buf}, - pc, { H, n_seqs, 1u }); - } else { - // Fallback: 3-dispatch path (no coopmat) - vk_pipeline pl_inter = ctx->device->pipeline_gated_delta_net_chunk_inter; - vk_pipeline pl_output = ctx->device->pipeline_gated_delta_net_chunk_output; - - ggml_pipeline_request_descriptor_sets(ctx, pl_intra, 1); - ggml_pipeline_request_descriptor_sets(ctx, pl_inter, 1); - ggml_pipeline_request_descriptor_sets(ctx, pl_output, 1); - - const size_t h_size = (size_t)n_seqs * n_chunks * H * S_v * S_v * sizeof(float); - const size_t w_off = 0; - const size_t u_off = wu_size; - const size_t vn_off = 2 * wu_size; - const size_t dec_off = 3 * wu_size; - const size_t gcum_off = dec_off + d_size; - const size_t h_off = gcum_off + g_size; - const size_t total_scratch = h_off + h_size; - - if (ctx->prealloc_size_split_k < total_scratch) { - ctx->prealloc_size_split_k = total_scratch; - ggml_vk_preallocate_buffers(ctx, subctx); - } - - if (ctx->prealloc_split_k_need_sync) { - ggml_vk_sync_buffers(ctx, subctx); - } - - vk_subbuffer scratch_w = { ctx->prealloc_split_k, w_off, wu_size }; - vk_subbuffer scratch_u = { ctx->prealloc_split_k, u_off, wu_size }; - vk_subbuffer scratch_vnew = { ctx->prealloc_split_k, vn_off, wu_size }; - vk_subbuffer scratch_dec = { ctx->prealloc_split_k, dec_off, d_size }; - vk_subbuffer scratch_gcum = { ctx->prealloc_split_k, gcum_off, g_size }; - vk_subbuffer scratch_h = { ctx->prealloc_split_k, h_off, h_size }; - - ggml_vk_dispatch_pipeline(ctx, subctx, pl_intra, - {src_buf[1], src_buf[2], src_buf[3], src_buf[4], - scratch_w, scratch_u, scratch_dec, scratch_gcum}, - pc, { n_chunks * H, n_seqs, 1u }); - - ggml_vk_sync_buffers(ctx, subctx); - - ggml_vk_dispatch_pipeline(ctx, subctx, pl_inter, - {src_buf[1], scratch_w, scratch_u, scratch_dec, scratch_gcum, - src_buf[5], scratch_h, scratch_vnew, dst_buf}, - pc, { H, n_seqs, 1u }); - - ggml_vk_sync_buffers(ctx, subctx); - - ggml_vk_dispatch_pipeline(ctx, subctx, pl_output, - {src_buf[0], src_buf[1], scratch_h, scratch_vnew, scratch_gcum, dst_buf}, - pc, { n_chunks * H, n_seqs, 1u }); - } + // Dispatch 3: Output (parallel across chunks) + // Bindings: Q, K, H, VNew, GCum, Dst + ggml_vk_dispatch_pipeline(ctx, subctx, pl_output, + {src_buf[0], src_buf[1], scratch_h, scratch_vnew, scratch_gcum, dst_buf}, + pc, { n_chunks * H, n_seqs, 1u }); ctx->prealloc_split_k_need_sync = true; } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net_chunk_fused_cm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net_chunk_fused_cm.comp deleted file mode 100644 index 4d676034c2..0000000000 --- a/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net_chunk_fused_cm.comp +++ /dev/null @@ -1,340 +0,0 @@ -#version 450 - -#extension GL_EXT_control_flow_attributes : enable -#extension GL_EXT_shader_16bit_storage : require -#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require -#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require -#extension GL_KHR_shader_subgroup_basic : enable -#extension GL_KHR_shader_subgroup_arithmetic : enable -#extension GL_KHR_shader_subgroup_shuffle : enable -#extension GL_KHR_memory_scope_semantics : enable -#extension GL_KHR_cooperative_matrix : enable - -#include "types.glsl" - -// Fused inter+output kernel for chunked gated delta net -// -// Merges the inter-chunk state propagation and output computation into -// one dispatch, eliminating VNew and H_snapshot scratch buffers plus -// one dispatch barrier. -// -// Step 1: Load K, Q, gcum → shared -// Step 2: A = Q @ K^T (coopmat) -// Step 3: Decay mask → sh_adecay + O_inter = Q @ state → dst (parallel) -// Step 4: vnew = U - W@state → sh_kv (f16), accumulate delta -// Step 5: O_intra = A_decayed @ vnew → dst (coopmat GEMM) -// Step 6: state = exp(decay) * state + delta -// -// Grid: (H, n_seqs, 1) — sequential chunk loop -// Workgroup: 256 threads = 4 subgroups - -layout(constant_id = 0) const uint WG_SIZE = 256; -layout(constant_id = 1) const uint CHUNK_SIZE = 64; -layout(constant_id = 2) const uint S_V = 128; - -layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; - -layout(push_constant) uniform Parameters { - uint H; - uint n_tokens; - uint n_seqs; - uint sq1, sq2, sq3; - uint sv1, sv2, sv3; - uint sb1, sb2, sb3; - uint neq1, rq3; - uint n_chunks; - uint s_off; -}; - -layout(binding = 0) readonly buffer QBuf { float q_in[]; }; -layout(binding = 1) readonly buffer KBuf { float k_in[]; }; -layout(binding = 2) readonly buffer WBuf { float w_in[]; }; -layout(binding = 3) readonly buffer UBuf { float u_in[]; }; -layout(binding = 4) readonly buffer DecBuf { float decay_in[]; }; -layout(binding = 5) readonly buffer GCumBuf { float gcum_in[]; }; -layout(binding = 6) readonly buffer StBuf { float state_in[]; }; -layout(binding = 7) buffer DstBuf { float dst[]; }; - -const uint TM = 16; -const uint TN = 16; -const uint TK = 16; - -const uint C_TILES = CHUNK_SIZE / TM; -const uint D_TILES = S_V / TN; - -const uint QK_STRIDE = S_V / 4 + 2; -const uint ATTN_V4_STRIDE = CHUNK_SIZE / 4 + 2; - -shared f16vec4 sh_q[CHUNK_SIZE * QK_STRIDE]; -shared f16vec4 sh_kv[CHUNK_SIZE * QK_STRIDE]; -shared vec4 sh_attn[CHUNK_SIZE * ATTN_V4_STRIDE]; -shared f16vec4 sh_adecay[CHUNK_SIZE * ATTN_V4_STRIDE]; -shared float sh_gcum[CHUNK_SIZE]; -shared float sh_w[S_V]; -shared float sh_kg[S_V]; - -void main() { - const uint tid = gl_LocalInvocationIndex; - const uint sg_id = gl_SubgroupID; - const uint si = gl_SubgroupInvocationID; - - const uint head_id = gl_WorkGroupID.x; - const uint seq_id = gl_WorkGroupID.y; - const uint col = tid; - - const uint iq1 = head_id % neq1; - const uint iq3 = seq_id / rq3; - const float scale = 1.0 / sqrt(float(S_V)); - - const uint state_size = S_V * S_V; - const uint state_base = (seq_id * H + head_id) * state_size; - const uint attn_off = (seq_id * n_tokens * H + head_id) * S_V; - - // ================================================================ - // Load state into registers (threads 0-127) - // ================================================================ - - float state[S_V]; - if (col < S_V) { - [[unroll]] for (uint i = 0; i < S_V; i++) { - state[i] = state_in[state_base + i * S_V + col]; - } - } - - // ================================================================ - // Chunk loop - // ================================================================ - - for (uint c = 0; c < n_chunks; c++) { - const uint chunk_start = c * CHUNK_SIZE; - const uint chunk_len = min(CHUNK_SIZE, n_tokens - chunk_start); - const uint wu_base = ((seq_id * n_chunks + c) * H + head_id) * CHUNK_SIZE * S_V; - const uint gcum_base = ((seq_id * n_chunks + c) * H + head_id) * CHUNK_SIZE; - const uint decay_idx = (seq_id * n_chunks + c) * H + head_id; - const float g_total = decay_in[decay_idx]; - - // ============================================================ - // Step 1: Load K → sh_kv, Q → sh_q, gcum → sh_gcum - // ============================================================ - - if (tid < CHUNK_SIZE) { - sh_gcum[tid] = (tid < chunk_len) ? gcum_in[gcum_base + tid] : 0.0; - } - - for (uint idx = tid; idx < CHUNK_SIZE * (S_V / 4); idx += WG_SIZE) { - const uint row = idx / (S_V / 4); - const uint col4 = idx % (S_V / 4); - f16vec4 q_val = f16vec4(0.0); - f16vec4 k_val = f16vec4(0.0); - if (row < chunk_len) { - const uint off = iq3 * sq3 + (chunk_start + row) * sq2 + iq1 * sq1 + col4 * 4; - q_val = f16vec4(q_in[off], q_in[off + 1], q_in[off + 2], q_in[off + 3]); - k_val = f16vec4(k_in[off], k_in[off + 1], k_in[off + 2], k_in[off + 3]); - } - sh_q[row * QK_STRIDE + col4] = q_val; - sh_kv[row * QK_STRIDE + col4] = k_val; - } - - barrier(); - - // ============================================================ - // Step 2: Q @ K^T (coopmat, f16→f32) - // ============================================================ - - coopmat A_acc[C_TILES]; - [[unroll]] for (uint tj = 0; tj < C_TILES; tj++) { - A_acc[tj] = coopmat(0.0); - } - - { - coopmat Q_mat; - coopmat KT_mat; - - [[unroll]] for (uint dk = 0; dk < D_TILES; dk++) { - coopMatLoad(Q_mat, sh_q, - sg_id * TM * QK_STRIDE + dk * (TK / 4), - QK_STRIDE, gl_CooperativeMatrixLayoutRowMajor); - - [[unroll]] for (uint tj = 0; tj < C_TILES; tj++) { - coopMatLoad(KT_mat, sh_kv, - tj * TN * QK_STRIDE + dk * (TK / 4), - QK_STRIDE, gl_CooperativeMatrixLayoutColumnMajor); - - A_acc[tj] = coopMatMulAdd(Q_mat, KT_mat, A_acc[tj]); - } - } - } - - [[unroll]] for (uint tj = 0; tj < C_TILES; tj++) { - coopMatStore(A_acc[tj], sh_attn, - sg_id * TM * ATTN_V4_STRIDE + tj * (TN / 4), - ATTN_V4_STRIDE, gl_CooperativeMatrixLayoutRowMajor); - } - - barrier(); - - // ============================================================ - // Step 3: Decay mask + inter-chunk output (parallel, no conflict) - // ============================================================ - - for (uint idx = tid; idx < CHUNK_SIZE * (CHUNK_SIZE / 4); idx += WG_SIZE) { - const uint t = idx / (CHUNK_SIZE / 4); - const uint j4 = idx % (CHUNK_SIZE / 4); - f16vec4 val = f16vec4(0.0); - if (t < chunk_len) { - const float g_t = sh_gcum[t]; - [[unroll]] for (uint k = 0; k < 4; k++) { - const uint j = j4 * 4 + k; - if (j <= t && j < chunk_len) { - val[k] = float16_t(clamp( - exp(g_t - sh_gcum[j]) * sh_attn[t * ATTN_V4_STRIDE + j4][k], - -65504.0, 65504.0)); - } - } - } - sh_adecay[t * ATTN_V4_STRIDE + j4] = val; - } - - if (col < S_V) { - for (uint t = 0; t < chunk_len; t++) { - float o_inter = 0.0; - [[unroll]] for (uint i = 0; i < S_V / 4; i++) { - vec4 q_f32 = vec4(sh_q[t * QK_STRIDE + i]); - o_inter += dot(q_f32, - vec4(state[i*4], state[i*4+1], - state[i*4+2], state[i*4+3])); - } - dst[attn_off + (chunk_start + t) * S_V * H + col] = exp(sh_gcum[t]) * o_inter * scale; - } - } - - barrier(); - - // ============================================================ - // Step 4: vnew = U - W@state → sh_kv, accumulate delta - // Threads 0-127: load w, compute vnew, write sh_kv via shuffle - // Threads 128-255: load k_gated into sh_kg - // ============================================================ - - float delta[S_V]; - if (col < S_V) { - [[unroll]] for (uint i = 0; i < S_V; i++) { - delta[i] = 0.0; - } - } - - for (uint t = 0; t < chunk_len; t++) { - if (col < S_V) { - sh_w[col] = w_in[wu_base + t * S_V + col]; - } - if (tid >= 128 && tid < 256) { - const uint lc = tid - 128; - const float gcum_t = gcum_in[gcum_base + t]; - const float decay_factor = exp(g_total - gcum_t); - const uint k_off = iq3 * sq3 + (chunk_start + t) * sq2 + iq1 * sq1; - sh_kg[lc] = k_in[k_off + lc] * decay_factor; - } - barrier(); - - if (col < S_V) { - float ws = 0.0; - [[unroll]] for (uint i = 0; i < S_V; i += 4) { - ws += dot(vec4(sh_w[i], sh_w[i+1], sh_w[i+2], sh_w[i+3]), - vec4(state[i], state[i+1], state[i+2], state[i+3])); - } - - float vnew = u_in[wu_base + t * S_V + col] - ws; - float vnew_scaled = clamp(vnew * scale, -65504.0, 65504.0); - - float16_t v0 = float16_t(subgroupShuffle(vnew_scaled, si & ~3u)); - float16_t v1 = float16_t(subgroupShuffle(vnew_scaled, (si & ~3u) + 1u)); - float16_t v2 = float16_t(subgroupShuffle(vnew_scaled, (si & ~3u) + 2u)); - float16_t v3 = float16_t(subgroupShuffle(vnew_scaled, (si & ~3u) + 3u)); - if ((si & 3u) == 0u) { - sh_kv[t * QK_STRIDE + (col >> 2)] = f16vec4(v0, v1, v2, v3); - } - - [[unroll]] for (uint i = 0; i < S_V; i++) { - delta[i] += sh_kg[i] * vnew; - } - } - barrier(); - } - - // ============================================================ - // Step 5: O_intra = A_decayed @ vnew → dst (coopmat GEMM) - // ============================================================ - - if (chunk_len == CHUNK_SIZE) { - coopmat A_mat; - coopmat V_mat; - - [[unroll]] for (uint dn = 0; dn < D_TILES; dn++) { - coopmat O_acc; - - coopMatLoad(O_acc, dst, - attn_off + (chunk_start + sg_id * TM) * S_V * H + dn * TN, - S_V * H, gl_CooperativeMatrixLayoutRowMajor); - - [[unroll]] for (uint dk = 0; dk < C_TILES; dk++) { - coopMatLoad(A_mat, sh_adecay, - sg_id * TM * ATTN_V4_STRIDE + dk * (TK / 4), - ATTN_V4_STRIDE, gl_CooperativeMatrixLayoutRowMajor); - - coopMatLoad(V_mat, sh_kv, - dk * TK * QK_STRIDE + dn * (TN / 4), - QK_STRIDE, gl_CooperativeMatrixLayoutRowMajor); - - O_acc = coopMatMulAdd(A_mat, V_mat, O_acc); - } - - coopMatStore(O_acc, dst, - attn_off + (chunk_start + sg_id * TM) * S_V * H + dn * TN, - S_V * H, gl_CooperativeMatrixLayoutRowMajor); - } - } else { - if (col < S_V) { - const uint col4 = col / 4; - const uint comp = col % 4; - - vec4 my_vnew_v4[CHUNK_SIZE / 4]; - [[unroll]] for (uint j4 = 0; j4 < CHUNK_SIZE / 4; j4++) { - my_vnew_v4[j4] = vec4( - float(sh_kv[(j4*4 ) * QK_STRIDE + col4][comp]), - float(sh_kv[(j4*4+1) * QK_STRIDE + col4][comp]), - float(sh_kv[(j4*4+2) * QK_STRIDE + col4][comp]), - float(sh_kv[(j4*4+3) * QK_STRIDE + col4][comp])); - } - - for (uint t = 0; t < chunk_len; t++) { - float o_intra = 0.0; - [[unroll]] for (uint j4 = 0; j4 < CHUNK_SIZE / 4; j4++) { - o_intra += dot(vec4(sh_adecay[t * ATTN_V4_STRIDE + j4]), my_vnew_v4[j4]); - } - dst[attn_off + (chunk_start + t) * S_V * H + col] += o_intra; - } - } - } - - // ============================================================ - // Step 6: State update - // ============================================================ - - if (col < S_V) { - const float total_decay = exp(g_total); - [[unroll]] for (uint i = 0; i < S_V; i++) { - state[i] = total_decay * state[i] + delta[i]; - } - } - } - - // ================================================================ - // Write final state to dst - // ================================================================ - - if (col < S_V) { - [[unroll]] for (uint i = 0; i < S_V; i++) { - dst[s_off + state_base + i * S_V + col] = state[i]; - } - } -} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index 9fcdcb9545..2ddd4f2f8c 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -992,7 +992,6 @@ void process_shaders() { string_to_spv("gated_delta_net_chunk_inter_f32", "gated_delta_net_chunk_inter.comp", merge_maps(base_dict, {{"A_TYPE", "float"}})); string_to_spv("gated_delta_net_chunk_output_f32", "gated_delta_net_chunk_output.comp", merge_maps(base_dict, {{"A_TYPE", "float"}})); string_to_spv("gated_delta_net_chunk_output_cm1_f32", "gated_delta_net_chunk_output_cm1.comp", merge_maps(base_dict, {{"A_TYPE", "float"}})); - string_to_spv("gated_delta_net_chunk_fused_cm_f32", "gated_delta_net_chunk_fused_cm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}})); string_to_spv("opt_step_adamw_f32", "opt_step_adamw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}})); string_to_spv("opt_step_sgd_f32", "opt_step_sgd.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));