diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 383840db7f..d1c470b1c1 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -10406,7 +10406,7 @@ static void ggml_vk_rwkv_wkv7(ggml_backend_vk_context * ctx, vk_context& subctx, } static constexpr uint32_t GDN_CHUNK_SIZE = 64; -static constexpr uint32_t GDN_CHUNK_THRESHOLD = UINT32_MAX; // Disabled +static constexpr uint32_t GDN_CHUNK_THRESHOLD = 2; static void ggml_vk_gated_delta_net(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) { const ggml_tensor * src_q = dst->src[0]; @@ -10472,7 +10472,9 @@ static void ggml_vk_gated_delta_net(ggml_backend_vk_context * ctx, vk_context& s vk_pipeline pl_intra = ctx->device->pipeline_gated_delta_net_chunk_intra; vk_pipeline pl_inter = ctx->device->pipeline_gated_delta_net_chunk_inter; - vk_pipeline pl_output = ctx->device->pipeline_gated_delta_net_chunk_output; + vk_pipeline pl_output = ctx->device->pipeline_gated_delta_net_chunk_output_cm + ? ctx->device->pipeline_gated_delta_net_chunk_output_cm + : ctx->device->pipeline_gated_delta_net_chunk_output; ggml_pipeline_request_descriptor_sets(ctx, pl_intra, 1); ggml_pipeline_request_descriptor_sets(ctx, pl_inter, 1);