fix amd workgroup size issue

This commit is contained in:
Ruben Ortlam 2026-02-05 17:17:04 +01:00
parent f92d7eddab
commit 9b309bbc51
2 changed files with 22 additions and 24 deletions

View File

@ -2759,11 +2759,11 @@ static constexpr uint32_t scalar_flash_attention_num_small_rows = 1;
// Selects a row count from the two head-size arguments and the small_cache flag
// (judging by the name, the "large rows" tile height of the scalar
// flash-attention path — confirm against the callers).
//
// NOTE(review): this text is a diff hunk whose +/- markers were lost, so two
// `else if` / `else` chains appear back to back and the function does not
// compile as shown. The lines follow unified-diff order (removed lines first,
// then added lines): the first chain is presumably the pre-commit version and
// the second chain the post-commit version.
//
// Returns 2 when hsv >= 192. Otherwise the smaller value is returned when
// bit 3 (value 8) is set in either hsv or hsk, or when small_cache is true,
// and the larger value in all remaining cases.
static uint32_t get_fa_scalar_num_large_rows(uint32_t hsk, uint32_t hsv, bool small_cache) {
if (hsv >= 192) {
return 2;
// --- first chain (old side of the diff): yields 4 or 8 ---
} else if ((hsv | hsk) & 8 || small_cache) {
return 4;
} else {
return 8;
// --- second chain (new side of the diff): same conditions, yields 8 or 16 ---
} else if ((hsv | hsk) & 8 || small_cache) {
return 8;
} else {
return 16;
}
}
@ -2789,13 +2789,7 @@ static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t hsk, uint3
if (small_rows) {
return {scalar_flash_attention_num_small_rows, 64};
} else {
if ((hsv | hsk) & 8) {
// HSV/HSK not being a multiple of 16 makes D_split smaller, which makes cols_per_iter
// larger, and Bc needs to be >= cols_per_thread. 64 is large enough, 32 is not.
return {get_fa_scalar_num_large_rows(hsk, hsv, small_cache), 64};
} else {
return {get_fa_scalar_num_large_rows(hsk, hsv, small_cache), 32};
}
return {get_fa_scalar_num_large_rows(hsk, hsv, small_cache), 64};
}
}
@ -3213,7 +3207,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
wg_size = (rows_cols[1] / 16) * device->subgroup_size; // enough subgroups for Bc/MatBc
break;
default:
wg_size = device->subgroup_size * 4;
if (device->subgroup_size > 32 && rows_cols[0] < 4) {
wg_size = device->subgroup_size * 2;
} else {
wg_size = device->subgroup_size * 4;
}
break;
}

View File

@ -309,10 +309,10 @@ void main() {
tmpsh[gl_SubgroupID * D_split + d_tid] = rowmaxf;
}
barrier();
rowmaxf = max(max(max(tmpsh[0 * D_split + d_tid],
tmpsh[1 * D_split + d_tid]),
tmpsh[2 * D_split + d_tid]),
tmpsh[3 * D_split + d_tid]);
rowmaxf = tmpsh[d_tid];
[[unroll]] for (uint32_t s = 1; s < num_subgroups; ++s) {
rowmaxf = max(rowmaxf, tmpsh[s * D_split + d_tid]);
}
}
float Moldf = Mf[r];
@ -334,10 +334,10 @@ void main() {
tmpsh[gl_SubgroupID * D_split + d_tid] = Lf[r];
}
barrier();
Lf[r] = tmpsh[0 * D_split + d_tid] +
tmpsh[1 * D_split + d_tid] +
tmpsh[2 * D_split + d_tid] +
tmpsh[3 * D_split + d_tid];
Lf[r] = tmpsh[d_tid];
[[unroll]] for (uint32_t s = 1; s < num_subgroups; ++s) {
Lf[r] += tmpsh[s * D_split + d_tid];
}
}
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
@ -352,10 +352,10 @@ void main() {
tmpsh_accv4[gl_SubgroupID * D_split + d_tid] = Of[r][d];
}
barrier();
Of[r][d] = tmpsh_accv4[0 * D_split + d_tid] +
tmpsh_accv4[1 * D_split + d_tid] +
tmpsh_accv4[2 * D_split + d_tid] +
tmpsh_accv4[3 * D_split + d_tid];
Of[r][d] = tmpsh_accv4[d_tid];
[[unroll]] for (uint32_t s = 1; s < num_subgroups; ++s) {
Of[r][d] += tmpsh_accv4[s * D_split + d_tid];
}
}
}
}