CUDA: fix unpadded strides in MMA FA kernel (#17891)
parent 9e79b0116e
commit 17f7f4baad
@@ -955,9 +955,11 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
            (K_h2 + int64_t(kb0)*nbatch_fa*stride_K, tile_K, nbatch_K2, stride_K, k_VKQ_sup);
    }

    // kb0_start is always < kb0_stop so the last iter can be executed unconditionally.
    if constexpr (ncols2 == 1) {
        constexpr bool oob_check = true;
        for (; kb0 < kb0_stop-1; ++kb0) {
            constexpr bool last_iter = false;
            constexpr bool oob_check = false;
            constexpr int k_VKQ_sup = nbatch_fa;
            flash_attn_ext_f16_iter
                <DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter, oob_check,
@@ -966,21 +968,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
                 ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C,
                 KQ_max, KQ_rowsum, jt, kb0, k_VKQ_sup);
    }
    // kb0_start is always < kb0_stop so the last iter can be executed unconditionally.
    if constexpr (ncols2 == 1) {
        if (ne11 % nbatch_fa == 0) {
            constexpr bool last_iter = true;
            constexpr bool oob_check = false;
            constexpr int k_VKQ_sup = nbatch_fa;
            flash_attn_ext_f16_iter
                <DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter, oob_check,
                 T_A_KQ, T_B_KQ, T_C_KQ, T_A_VKQ, T_B_VKQ, T_C_VKQ>
                (Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap,
                 ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C,
                 KQ_max, KQ_rowsum, jt, kb0, k_VKQ_sup);
        } else {
            constexpr bool last_iter = true;
            constexpr bool oob_check = true;
            const int k_VKQ_sup = ne11 - kb0*nbatch_fa;
            flash_attn_ext_f16_iter
                <DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter, oob_check,
@@ -988,10 +976,19 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
                (Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap,
                 ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C,
                 KQ_max, KQ_rowsum, jt, kb0, k_VKQ_sup);
        }
    } else {
        constexpr bool last_iter = true;
        constexpr bool oob_check = false;
        for (; kb0 < kb0_stop-1; ++kb0) {
            constexpr bool last_iter = false;
            constexpr int k_VKQ_sup = nbatch_fa;
            flash_attn_ext_f16_iter
                <DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter, oob_check,
                 T_A_KQ, T_B_KQ, T_C_KQ, T_A_VKQ, T_B_VKQ, T_C_VKQ>
                (Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap,
                 ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C,
                 KQ_max, KQ_rowsum, jt, kb0, k_VKQ_sup);
        }
        constexpr bool last_iter = true;
        constexpr int k_VKQ_sup = nbatch_fa;
        flash_attn_ext_f16_iter
            <DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter, oob_check,
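The hunks above restructure how the last K/V tile is handled: every fully populated tile is processed with a compile-time oob_check = false, and only when ne11 is not a multiple of nbatch_fa does the final call get oob_check = true together with the runtime bound k_VKQ_sup = ne11 - kb0*nbatch_fa. Below is a minimal, hypothetical sketch of that peeled-last-iteration pattern; the names process_tile and peeled_loop are stand-ins, not the actual flash_attn_ext_f16_iter machinery.

// Hypothetical sketch, not llama.cpp code: the peeled-last-iteration pattern.
// One template instantiation handles the fully in-bounds tiles without any
// bounds check; only the final, possibly partial, tile pays for oob_check.
template <int nbatch_fa, bool oob_check>
__device__ void process_tile(const int tile_start, const int k_VKQ_sup) {
    for (int k = 0; k < nbatch_fa; ++k) {
        if constexpr (oob_check) {
            if (k >= k_VKQ_sup) {
                break; // only reachable in the last, partial tile
            }
        }
        // ... accumulate the KQ/VKQ contribution of K/V element tile_start + k ...
    }
}

template <int nbatch_fa>
__global__ void peeled_loop(const int ne11) {
    const int kb0_stop = (ne11 + nbatch_fa - 1) / nbatch_fa;

    int kb0 = 0;
    for (; kb0 < kb0_stop - 1; ++kb0) {
        // All but the last tile are guaranteed in bounds: no check compiled in.
        process_tile<nbatch_fa, false>(kb0*nbatch_fa, nbatch_fa);
    }
    if (ne11 % nbatch_fa == 0) {
        // Padded case: the last tile is also full, keep the unchecked path.
        process_tile<nbatch_fa, false>(kb0*nbatch_fa, nbatch_fa);
    } else {
        // Unpadded case: clamp to the number of valid K/V elements left.
        process_tile<nbatch_fa, true>(kb0*nbatch_fa, ne11 - kb0*nbatch_fa);
    }
}

Keeping oob_check a template parameter means the common, padded case compiles to the same unchecked inner loop as before; only the single trailing tile takes the branch.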
@@ -36,12 +36,26 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2(ggml_backend_cuda_con
    const ggml_tensor * KQV = dst;
    const ggml_tensor * Q = dst->src[0];
    const ggml_tensor * K = dst->src[1];
    const ggml_tensor * V = dst->src[2];
    const ggml_tensor * mask = dst->src[3];

    float max_bias = 0.0f;
    memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));

    const bool use_gqa_opt = mask && max_bias == 0.0f && K->ne[1] % FATTN_KQ_STRIDE == 0;
    // Edge cases like no mask, ALiBi, unpadded K/V, or misaligned addresses for large data transfers
    // are put into the template specialization without GQA optimizations.
    bool use_gqa_opt = mask && max_bias == 0.0f && K->ne[1] % FATTN_KQ_STRIDE == 0;
    for (const ggml_tensor * t : {Q, K, V, mask}) {
        if (t == nullptr) {
            continue;
        }
        for (size_t i = 1; i < GGML_MAX_DIMS; ++i) {
            if (t->nb[i] % 16 != 0) {
                use_gqa_opt = false;
                break;
            }
        }
    }

    GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
    const int gqa_ratio = Q->ne[2] / K->ne[2];
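The new host-side check above falls back to the non-GQA template specialization whenever any of Q, K, V, or the mask has a higher-dimension byte stride nb[i] that is not a multiple of 16, which appears to be what the commit title's "unpadded strides" refers to. A self-contained sketch of the same alignment test on a toy tensor type follows; toy_tensor and strides_are_16B_aligned are hypothetical, not the ggml API, and the 16-byte figure is taken from the check above under the assumption that the optimized path uses 16-byte wide transfers.

#include <cstddef>
#include <cstdio>

// Toy stand-in mirroring the ggml convention: nb[i] is the byte stride of
// dimension i, with nb[0] equal to the element size.
struct toy_tensor {
    long   ne[4]; // extents
    size_t nb[4]; // byte strides
};

// True if every higher-dimension stride is 16-byte aligned, i.e. rows can be
// copied with 16-byte vectorized loads/stores without misaligned accesses.
static bool strides_are_16B_aligned(const toy_tensor & t) {
    for (size_t i = 1; i < 4; ++i) {
        if (t.nb[i] % 16 != 0) {
            return false;
        }
    }
    return true;
}

int main() {
    // f16 rows of 256 elements, row stride 512 bytes: aligned.
    toy_tensor padded   = {{256, 64, 8, 1}, {2, 512, 512*64, 512*64*8}};
    // f16 rows of 125 elements with no padding, row stride 250 bytes: not aligned.
    toy_tensor unpadded = {{125, 64, 8, 1}, {2, 250, 250*64, 250*64*8}};

    printf("padded:   %d\n", strides_are_16B_aligned(padded));   // prints 1
    printf("unpadded: %d\n", strides_are_16B_aligned(unpadded)); // prints 0
    return 0;
}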