Fix AMD workgroup size issue
This commit is contained in:
parent
f92d7eddab
commit
9b309bbc51
|
|
@ -2759,11 +2759,11 @@ static constexpr uint32_t scalar_flash_attention_num_small_rows = 1;
|
|||
|
||||
static uint32_t get_fa_scalar_num_large_rows(uint32_t hsk, uint32_t hsv, bool small_cache) {
|
||||
if (hsv >= 192) {
|
||||
return 2;
|
||||
} else if ((hsv | hsk) & 8 || small_cache) {
|
||||
return 4;
|
||||
} else {
|
||||
return 8;
|
||||
} else if ((hsv | hsk) & 8 || small_cache) {
|
||||
return 8;
|
||||
} else {
|
||||
return 16;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -2789,13 +2789,7 @@ static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t hsk, uint3
|
|||
if (small_rows) {
|
||||
return {scalar_flash_attention_num_small_rows, 64};
|
||||
} else {
|
||||
if ((hsv | hsk) & 8) {
|
||||
// HSV/HSK not being a multiple of 16 makes D_split smaller, which makes cols_per_iter
|
||||
// larger, and Bc needs to be >= cols_per_thread. 64 is large enough, 32 is not.
|
||||
return {get_fa_scalar_num_large_rows(hsk, hsv, small_cache), 64};
|
||||
} else {
|
||||
return {get_fa_scalar_num_large_rows(hsk, hsv, small_cache), 32};
|
||||
}
|
||||
return {get_fa_scalar_num_large_rows(hsk, hsv, small_cache), 64};
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -3213,7 +3207,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|||
wg_size = (rows_cols[1] / 16) * device->subgroup_size; // enough subgroups for Bc/MatBc
|
||||
break;
|
||||
default:
|
||||
wg_size = device->subgroup_size * 4;
|
||||
if (device->subgroup_size > 32 && rows_cols[0] < 4) {
|
||||
wg_size = device->subgroup_size * 2;
|
||||
} else {
|
||||
wg_size = device->subgroup_size * 4;
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -309,10 +309,10 @@ void main() {
|
|||
tmpsh[gl_SubgroupID * D_split + d_tid] = rowmaxf;
|
||||
}
|
||||
barrier();
|
||||
rowmaxf = max(max(max(tmpsh[0 * D_split + d_tid],
|
||||
tmpsh[1 * D_split + d_tid]),
|
||||
tmpsh[2 * D_split + d_tid]),
|
||||
tmpsh[3 * D_split + d_tid]);
|
||||
rowmaxf = tmpsh[d_tid];
|
||||
[[unroll]] for (uint32_t s = 1; s < num_subgroups; ++s) {
|
||||
rowmaxf = max(rowmaxf, tmpsh[s * D_split + d_tid]);
|
||||
}
|
||||
}
|
||||
|
||||
float Moldf = Mf[r];
|
||||
|
|
@ -334,10 +334,10 @@ void main() {
|
|||
tmpsh[gl_SubgroupID * D_split + d_tid] = Lf[r];
|
||||
}
|
||||
barrier();
|
||||
Lf[r] = tmpsh[0 * D_split + d_tid] +
|
||||
tmpsh[1 * D_split + d_tid] +
|
||||
tmpsh[2 * D_split + d_tid] +
|
||||
tmpsh[3 * D_split + d_tid];
|
||||
Lf[r] = tmpsh[d_tid];
|
||||
[[unroll]] for (uint32_t s = 1; s < num_subgroups; ++s) {
|
||||
Lf[r] += tmpsh[s * D_split + d_tid];
|
||||
}
|
||||
}
|
||||
|
||||
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
|
||||
|
|
@ -352,10 +352,10 @@ void main() {
|
|||
tmpsh_accv4[gl_SubgroupID * D_split + d_tid] = Of[r][d];
|
||||
}
|
||||
barrier();
|
||||
Of[r][d] = tmpsh_accv4[0 * D_split + d_tid] +
|
||||
tmpsh_accv4[1 * D_split + d_tid] +
|
||||
tmpsh_accv4[2 * D_split + d_tid] +
|
||||
tmpsh_accv4[3 * D_split + d_tid];
|
||||
Of[r][d] = tmpsh_accv4[d_tid];
|
||||
[[unroll]] for (uint32_t s = 1; s < num_subgroups; ++s) {
|
||||
Of[r][d] += tmpsh_accv4[s * D_split + d_tid];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue