diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index 207d625b10..f645124e39 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -10406,7 +10406,7 @@ static void ggml_vk_rwkv_wkv7(ggml_backend_vk_context * ctx, vk_context& subctx,
 }
 
 static constexpr uint32_t GDN_CHUNK_SIZE = 64;
-static constexpr uint32_t GDN_CHUNK_THRESHOLD = 2;
+static constexpr uint32_t GDN_CHUNK_THRESHOLD = GDN_CHUNK_SIZE;
 
 static void ggml_vk_gated_delta_net(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) {
     const ggml_tensor * src_q = dst->src[0];
@@ -10494,6 +10494,22 @@ static void ggml_vk_gated_delta_net(ggml_backend_vk_context * ctx, vk_context& s
     const size_t h_off = gcum_off + g_size;
     const size_t total_scratch = h_off + h_size;
 
+    if (total_scratch > ctx->device->properties.limits.maxStorageBufferRange) {
+        // Fall back to autoregressive if scratch exceeds device limits
+        vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, dst->src[0], dst->src[1], dst->src[2], dst, dst->op);
+        GGML_ASSERT(pipeline != nullptr);
+        ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
+        const float scale = 1.0f / sqrtf((float)S_v);
+        const vk_op_gated_delta_net_push_constants pc_ar = {
+            H, n_tokens, n_seqs, s_off,
+            sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2, sb3, neq1, rq3, scale
+        };
+        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
+            {src_buf[0], src_buf[1], src_buf[2], src_buf[3], src_buf[4], src_buf[5], dst_buf},
+            pc_ar, { H, n_seqs, 1u });
+        return;
+    }
+
     if (ctx->prealloc_size_split_k < total_scratch) {
         ctx->prealloc_size_split_k = total_scratch;
         ggml_vk_preallocate_buffers(ctx, subctx);
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net_chunk_output_cm1.comp b/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net_chunk_output_cm1.comp
index 31fba4522a..24a4f3e6be 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net_chunk_output_cm1.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net_chunk_output_cm1.comp
@@ -40,7 +40,7 @@ const uint TK = 16;
 
 const uint C_TILES = CHUNK_SIZE / TM;
 const uint D_TILES = S_V / TN;
 
-// Shared memory strides in f16vec4 units, padded for bank conflicts
+// Padded strides for bank conflict avoidance
 const uint QK_STRIDE = S_V / 4 + 2;
 const uint ATTN_V4_STRIDE = CHUNK_SIZE / 4 + 2;