use minimal subgroup size on Intel

This commit is contained in:
Ruben Ortlam 2026-02-11 00:41:14 +01:00
parent 9f9b701ff5
commit 3ed9183ac9
1 changed files with 23 additions and 2 deletions

View File

@ -2783,6 +2783,26 @@ static constexpr uint32_t coopmat1_flash_attention_num_large_rows = 16;
static constexpr uint32_t scalar_flash_attention_Bc = 64;
static constexpr uint32_t scalar_flash_attention_workgroup_size = 128;
// Reports whether subgroups are disabled for the given flash attention code path.
// True exactly when the device's vendor is Intel and the path is FA_SCALAR;
// false for every other vendor/path combination.
static bool fa_disable_subgroups(const vk_device& device, FaCodePath path) {
    const bool is_intel  = device->vendor_id == VK_VENDOR_ID_INTEL;
    const bool is_scalar = path == FA_SCALAR;
    return is_intel && is_scalar;
}
// Selects the subgroup size value for the flash attention code path `path` on `device`.
//
// Returns, in order of precedence:
//   - 0xFFFFFFFF when fa_disable_subgroups(device, path) is true
//   - for FA_VECTOR on AMD: 32, if 32 lies within [subgroup_min_size, subgroup_max_size]
//   - for FA_VECTOR on Intel with subgroup_size_control set: subgroup_min_size
//   - otherwise: device->subgroup_size
static uint32_t fa_subgroup_size(const vk_device& device, FaCodePath path) {
    if (fa_disable_subgroups(device, path)) {
        return 0xFFFFFFFF;
    }

    // Only the vector path has vendor-specific overrides.
    if (path != FA_VECTOR) {
        return device->subgroup_size;
    }

    const auto vendor = device->vendor_id;

    const bool amd_can_use_32 = vendor == VK_VENDOR_ID_AMD &&
                                device->subgroup_min_size <= 32 &&
                                device->subgroup_max_size >= 32;
    if (amd_can_use_32) {
        return 32;
    }

    const bool intel_can_pick_size = vendor == VK_VENDOR_ID_INTEL && device->subgroup_size_control;
    if (intel_can_pick_size) {
        return device->subgroup_min_size;
    }

    return device->subgroup_size;
}
static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, FaRows rows, bool small_cache) {
GGML_UNUSED(clamp);
@ -3223,17 +3243,18 @@ static void ggml_vk_load_shaders(vk_device& device) {
break;
}
const uint32_t subgroup_size = fa_subgroup_size(device, path);
// D_split can't be larger than a subgroup because we use subgroupShuffle to reduce it.
// D_split can't be larger than the LSB of D divided by 4 due to vectorization in the shader.
const uint32_t D_lsb = D ^ (D & (D-1));
uint32_t D_split = std::min(std::min(device->subgroup_size, 8u), D_lsb / 4);
uint32_t D_split = std::min(std::min(subgroup_size, 8u), D_lsb / 4);
// Nvidia prefers shared memory use to load large tiles of K/V.
// Switch to loading from global memory when it would use too much shared memory.
// AMD prefers loading K directly from global memory
const uint32_t shmem_staging = device->vendor_id == VK_VENDOR_ID_NVIDIA && hsk < 256 && hsv < 256 ? 1 : 0;
const uint32_t subgroup_size = disable_subgroups ? 0xFFFFFFFF : device->subgroup_size;
return {wg_size, rows_cols[0], rows_cols[1], hsk, hsv, clamp, D_split, subgroup_size, shmem_staging, flags};
};