vulkan: Set k_load_shmem to false when K is too large (#19301)
This commit is contained in:
parent
c342c3b93d
commit
3409ab842d
|
|
@ -3204,9 +3204,10 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|||
const uint32_t D_lsb = D ^ (D & (D-1));
|
||||
uint32_t D_split = std::min(std::min(device->subgroup_size, 8u), D_lsb / 4);
|
||||
|
||||
// Nvidia prefers shared memory use to load large tiles of K
|
||||
// Nvidia prefers shared memory use to load large tiles of K.
|
||||
// Switch to loading from global memory when it would use too much shared memory.
|
||||
// AMD prefers loading K directly from global memory
|
||||
const uint32_t k_load_shmem = device->vendor_id == VK_VENDOR_ID_NVIDIA ? 1 : 0;
|
||||
const uint32_t k_load_shmem = device->vendor_id == VK_VENDOR_ID_NVIDIA && hsk < 256 ? 1 : 0;
|
||||
|
||||
return {wg_size, rows_cols[0], rows_cols[1], hsk, hsv, clamp, D_split, device->subgroup_size, k_load_shmem};
|
||||
};
|
||||
|
|
@ -8412,7 +8413,7 @@ static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, co
|
|||
const uint32_t sfshstride = (hsk <= 128) ? (Br + 8) : Br;
|
||||
const uint32_t sfsh = Bc * sfshstride * acctype;
|
||||
|
||||
const bool k_load_shmem = device->vendor_id == VK_VENDOR_ID_NVIDIA;
|
||||
const bool k_load_shmem = device->vendor_id == VK_VENDOR_ID_NVIDIA && hsk < 256;
|
||||
const uint32_t kshstride = (k_load_shmem ? hsk_pad : MatBr) / 4 + 2;
|
||||
const uint32_t vsh_stride = MatBc / 4 * row_split;
|
||||
const uint32_t ksh = ((kshstride >= vsh_stride) ? (Bc * kshstride) : (Bc * vsh_stride)) * f16vec4;
|
||||
|
|
|
|||
Loading…
Reference in New Issue