CUDA: fix overflow in MMA kernel without stream-k (#17939)
This commit is contained in:
parent
7bed317f53
commit
482211438d
|
|
@@ -642,8 +642,8 @@ static __global__ void flash_attn_stream_k_fixup(
|
||||||
const int iter_k = (ne11 + (nbatch_fa - 1)) / nbatch_fa;
|
const int iter_k = (ne11 + (nbatch_fa - 1)) / nbatch_fa;
|
||||||
const int iter_j = (ne01 + (ncols1 - 1)) / ncols1;
|
const int iter_j = (ne01 + (ncols1 - 1)) / ncols1;
|
||||||
|
|
||||||
const int kbc0 = (bidx0 + 0)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
|
const int kbc0 = int64_t(bidx0 + 0)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
|
||||||
const int kbc0_stop = (bidx0 + 1)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
|
const int kbc0_stop = int64_t(bidx0 + 1)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
|
||||||
|
|
||||||
const bool did_not_have_any_data = kbc0 == kbc0_stop;
|
const bool did_not_have_any_data = kbc0 == kbc0_stop;
|
||||||
const bool wrote_beginning_of_tile = kbc0 % iter_k == 0;
|
const bool wrote_beginning_of_tile = kbc0 % iter_k == 0;
|
||||||
|
|
@@ -679,7 +679,7 @@ static __global__ void flash_attn_stream_k_fixup(
|
||||||
int bidx = bidx0 - 1;
|
int bidx = bidx0 - 1;
|
||||||
int kbc_stop = kbc0;
|
int kbc_stop = kbc0;
|
||||||
while(true) {
|
while(true) {
|
||||||
const int kbc = bidx*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
|
const int kbc = int64_t(bidx)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
|
||||||
if (kbc == kbc_stop) { // Did not have any data.
|
if (kbc == kbc_stop) { // Did not have any data.
|
||||||
bidx--;
|
bidx--;
|
||||||
kbc_stop = kbc;
|
kbc_stop = kbc;
|
||||||
|
|
|
||||||
|
|
@@ -1380,8 +1380,8 @@ static __global__ void flash_attn_ext_f16(
|
||||||
const int iter_j = (ne01.z + (ncols1 - 1)) / ncols1;
|
const int iter_j = (ne01.z + (ncols1 - 1)) / ncols1;
|
||||||
|
|
||||||
// kbc == k block continuous, current index in continuous ijk space.
|
// kbc == k block continuous, current index in continuous ijk space.
|
||||||
int kbc = (blockIdx.x + 0)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
|
int kbc = int64_t(blockIdx.x + 0)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
|
||||||
const int kbc_stop = (blockIdx.x + 1)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
|
const int kbc_stop = int64_t(blockIdx.x + 1)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
|
||||||
|
|
||||||
// If the seams of 2 CUDA blocks fall within an output tile their results need to be combined.
|
// If the seams of 2 CUDA blocks fall within an output tile their results need to be combined.
|
||||||
// For this we need to track both the block that starts the tile (needs_fixup) and the block that finishes the tile (is_fixup).
|
// For this we need to track both the block that starts the tile (needs_fixup) and the block that finishes the tile (is_fixup).
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue