vulkan: Fix data race/hang in scalar/cm1 flash attention (#17887)
This commit is contained in:
parent
4722671641
commit
3238b1400c
|
|
@ -256,6 +256,9 @@ void main() {
|
||||||
barrier();
|
barrier();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// prevent race on tmpsh
|
||||||
|
barrier();
|
||||||
|
|
||||||
// reduce across threads
|
// reduce across threads
|
||||||
|
|
||||||
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
|
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
|
||||||
|
|
|
||||||
|
|
@ -302,6 +302,9 @@ void main() {
|
||||||
barrier();
|
barrier();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// prevent race on tmpsh
|
||||||
|
barrier();
|
||||||
|
|
||||||
// reduce across threads
|
// reduce across threads
|
||||||
|
|
||||||
float rowmaxf[rows_per_thread], eMf[rows_per_thread], Moldf[rows_per_thread];
|
float rowmaxf[rows_per_thread], eMf[rows_per_thread], Moldf[rows_per_thread];
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue