diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net_chunk_inter.comp b/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net_chunk_inter.comp index 8ebed83b20..2ed8260c76 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net_chunk_inter.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net_chunk_inter.comp @@ -74,6 +74,12 @@ void main() { for (uint t = 0; t < chunk_len; t++) { s_w[col] = w_in[wu_base + t * S_V + col]; + + const float g_cumsum_t = gcum_in[gcum_base + t]; + const float decay_factor = exp(g_total - g_cumsum_t); + const uint t_global = chunk_start + t; + const uint k_off = iq3 * sq3 + t_global * sq2 + iq1 * sq1; + s_kg[col] = k_in[k_off + col] * decay_factor; barrier(); float ws = 0.0; @@ -87,15 +93,6 @@ void main() { float vnew = u_in[wu_base + t * S_V + col] - ws; vnew_out[wu_base + t * S_V + col] = vnew; - // K_gated[t] = k[t] * exp(g_total - g_cumsum[t]) - float g_cumsum_t = gcum_in[gcum_base + t]; - float decay_factor = exp(g_total - g_cumsum_t); - - const uint t_global = chunk_start + t; - const uint k_off = iq3 * sq3 + t_global * sq2 + iq1 * sq1; - s_kg[col] = k_in[k_off + col] * decay_factor; - barrier(); - [[unroll]] for (uint i = 0; i < S_V; i++) { delta[i] += s_kg[i] * vnew; }