diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index f765660272..8eadb69dbb 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -3218,7 +3218,7 @@ static void ggml_vk_load_shaders(vk_device& device) { // Nvidia prefers shared memory use to load large tiles of K. // Switch to loading from global memory when it would use too much shared memory. // AMD prefers loading K directly from global memory - const uint32_t k_load_shmem = device->vendor_id == VK_VENDOR_ID_NVIDIA && hsk < 256 ? 1 : 0; + const uint32_t k_load_shmem = 1; // device->vendor_id == VK_VENDOR_ID_NVIDIA && hsk < 256 ? 1 : 0; return {wg_size, rows_cols[0], rows_cols[1], hsk, hsv, clamp, D_split, device->subgroup_size, k_load_shmem, flags}; }; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp index e641debe3c..e4ca125eb4 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp @@ -50,8 +50,12 @@ shared ACC_TYPEV4 tmpsh_accv4[tmpsh_size]; const uint32_t masksh_stride = Br + 1; shared FLOAT_TYPE masksh[Bc * masksh_stride]; -const uint qfstride = HSK / 4 + 1; -shared FLOAT_TYPEV4 Qf[Br * qfstride]; +const uint32_t qf_stride = HSK / 4 + 1; +shared FLOAT_TYPEV4 Qf[Br * qf_stride]; + +const uint32_t D = HSK > HSV ? HSK : HSV; +const uint32_t kvsh_stride = D / 4 + 1; +shared FLOAT_TYPEV4 kvsh[K_LOAD_SHMEM != 0 ? Bc * kvsh_stride : 1]; void main() { #ifdef NEEDS_INIT_IQ_SHMEM @@ -75,7 +79,7 @@ void main() { uint32_t r = (idx + tid) / (HSK / 4); if (r < Br && d < HSK / 4 && i * Br + r < N) { - Qf[r * qfstride + d] = FLOAT_TYPEV4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d] * p.scale); + Qf[r * qf_stride + d] = FLOAT_TYPEV4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d] * p.scale); } } barrier(); @@ -184,10 +188,33 @@ void main() { } } + if (K_LOAD_SHMEM != 0) { + [[unroll]] for (uint32_t idx = 0; idx < Bc * HSK / 4; idx += gl_WorkGroupSize.x) { + uint32_t d = (idx + tid) % (HSK / 4); + uint32_t c = (idx + tid) / (HSK / 4); + if (c < Bc && d < HSK / 4) { + FLOAT_TYPEV4 K_Tf = FLOAT_TYPEV4(0); + if (!KV_bounds_check || j * Bc + c < KV) { +#if BLOCK_SIZE > 1 + uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE + 4 * d; + uint ib = coord / BLOCK_SIZE; + uint iqs = (coord % BLOCK_SIZE); + K_Tf = FLOAT_TYPEV4(dequantize4(ib, iqs, k_offset, BINDING_IDX_K)); +#else + K_Tf = FLOAT_TYPEV4(data_kv4[k_offset / 4 + (j * Bc + c) * k_stride / 4 + d]); +#endif + } + + kvsh[c * kvsh_stride + d] = K_Tf; + } + } + barrier(); + } + [[unroll]] for (uint32_t d = 0; d < HSK_per_thread / 4; ++d) { FLOAT_TYPEV4 Q_cache[rows_per_thread]; [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { - Q_cache[r] = Qf[tile_row(r) * qfstride + d * D_split + d_tid]; + Q_cache[r] = Qf[tile_row(r) * qf_stride + d * D_split + d_tid]; } [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { @@ -195,14 +222,19 @@ void main() { continue; } + FLOAT_TYPEV4 K_Tf; + if (K_LOAD_SHMEM != 0) { + K_Tf = kvsh[(c * cols_per_iter + col_tid) * kvsh_stride + (d * D_split + d_tid)]; + } else { #if BLOCK_SIZE > 1 - uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid); - uint ib = coord / BLOCK_SIZE; - uint iqs = (coord % BLOCK_SIZE); - FLOAT_TYPEV4 K_Tf = FLOAT_TYPEV4(dequantize4(ib, iqs, k_offset, BINDING_IDX_K)); + uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid); + uint ib = coord / BLOCK_SIZE; + uint iqs = (coord % BLOCK_SIZE); + K_Tf = FLOAT_TYPEV4(dequantize4(ib, iqs, k_offset, BINDING_IDX_K)); #else - FLOAT_TYPEV4 K_Tf = FLOAT_TYPEV4(data_kv4[k_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * k_stride / 4 + d * D_split + d_tid]); + K_Tf = FLOAT_TYPEV4(data_kv4[k_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * k_stride / 4 + d * D_split + d_tid]); #endif + } [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { Sf[r][c] += ACC_TYPE(dot(Q_cache[r], K_Tf)); }