vulkan: harden chunked GDN dispatch and fix minor issues

- Raise GDN_CHUNK_THRESHOLD from 2 to GDN_CHUNK_SIZE (64). Chunked path
  only activates when there's at least one full chunk. Below that,
  autoregressive is faster and the 3-dispatch overhead isn't justified.
- Add maxStorageBufferRange guard on scratch allocation. Falls back to
  autoregressive if the scratch buffers would exceed device limits.
- Fix inaccurate shared memory stride comment in cm1 output kernel.

16/16 tests pass.
This commit is contained in:
Progeny Alpha 2026-03-15 00:38:44 -04:00
parent ab79f14b42
commit 088cb0cbe8
2 changed files with 18 additions and 2 deletions

View File

@ -10406,7 +10406,7 @@ static void ggml_vk_rwkv_wkv7(ggml_backend_vk_context * ctx, vk_context& subctx,
}
static constexpr uint32_t GDN_CHUNK_SIZE = 64;
static constexpr uint32_t GDN_CHUNK_THRESHOLD = 2;
static constexpr uint32_t GDN_CHUNK_THRESHOLD = GDN_CHUNK_SIZE;
static void ggml_vk_gated_delta_net(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) {
const ggml_tensor * src_q = dst->src[0];
@ -10494,6 +10494,22 @@ static void ggml_vk_gated_delta_net(ggml_backend_vk_context * ctx, vk_context& s
const size_t h_off = gcum_off + g_size;
const size_t total_scratch = h_off + h_size;
if (total_scratch > ctx->device->properties.limits.maxStorageBufferRange) {
// Fall back to autoregressive if scratch exceeds device limits
vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, dst->src[0], dst->src[1], dst->src[2], dst, dst->op);
GGML_ASSERT(pipeline != nullptr);
ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
const float scale = 1.0f / sqrtf((float)S_v);
const vk_op_gated_delta_net_push_constants pc_ar = {
H, n_tokens, n_seqs, s_off,
sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2, sb3, neq1, rq3, scale
};
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
{src_buf[0], src_buf[1], src_buf[2], src_buf[3], src_buf[4], src_buf[5], dst_buf},
pc_ar, { H, n_seqs, 1u });
return;
}
if (ctx->prealloc_size_split_k < total_scratch) {
ctx->prealloc_size_split_k = total_scratch;
ggml_vk_preallocate_buffers(ctx, subctx);

View File

@ -40,7 +40,7 @@ const uint TK = 16;
const uint C_TILES = CHUNK_SIZE / TM;
const uint D_TILES = S_V / TN;
// Shared memory strides in f16vec4 units, padded for bank conflicts
// Padded strides for bank conflict avoidance
const uint QK_STRIDE = S_V / 4 + 2;
const uint ATTN_V4_STRIDE = CHUNK_SIZE / 4 + 2;