CUDA: fix unpadded strides in MMA FA kernel (#17891)
parent 9e79b0116e
commit 17f7f4baad

@@ -955,9 +955,11 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
             (K_h2 + int64_t(kb0)*nbatch_fa*stride_K, tile_K, nbatch_K2, stride_K, k_VKQ_sup);
     }

+    // kb0_start is always < kb0_stop so the last iter can be executed unconditionally.
+    if constexpr (ncols2 == 1) {
+        constexpr bool oob_check = true;
         for (; kb0 < kb0_stop-1; ++kb0) {
             constexpr bool last_iter = false;
-            constexpr bool oob_check = false;
             constexpr int k_VKQ_sup = nbatch_fa;
             flash_attn_ext_f16_iter
                 <DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter, oob_check,

@@ -966,21 +968,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
                 ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C,
                 KQ_max, KQ_rowsum, jt, kb0, k_VKQ_sup);
         }
-    // kb0_start is always < kb0_stop so the last iter can be executed unconditionally.
-    if constexpr (ncols2 == 1) {
-        if (ne11 % nbatch_fa == 0) {
         constexpr bool last_iter = true;
-            constexpr bool oob_check = false;
-            constexpr int k_VKQ_sup = nbatch_fa;
-            flash_attn_ext_f16_iter
-                <DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter, oob_check,
-                T_A_KQ, T_B_KQ, T_C_KQ, T_A_VKQ, T_B_VKQ, T_C_VKQ>
-                (Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap,
-                ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C,
-                KQ_max, KQ_rowsum, jt, kb0, k_VKQ_sup);
-        } else {
-            constexpr bool last_iter = true;
-            constexpr bool oob_check = true;
         const int k_VKQ_sup = ne11 - kb0*nbatch_fa;
         flash_attn_ext_f16_iter
             <DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter, oob_check,

@@ -988,10 +976,19 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
             (Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap,
             ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C,
             KQ_max, KQ_rowsum, jt, kb0, k_VKQ_sup);
-        }
     } else {
-        constexpr bool last_iter = true;
         constexpr bool oob_check = false;
+        for (; kb0 < kb0_stop-1; ++kb0) {
+            constexpr bool last_iter = false;
+            constexpr int k_VKQ_sup = nbatch_fa;
+            flash_attn_ext_f16_iter
+                <DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter, oob_check,
+                T_A_KQ, T_B_KQ, T_C_KQ, T_A_VKQ, T_B_VKQ, T_C_VKQ>
+                (Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap,
+                ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C,
+                KQ_max, KQ_rowsum, jt, kb0, k_VKQ_sup);
+        }
+        constexpr bool last_iter = true;
         constexpr int k_VKQ_sup = nbatch_fa;
         flash_attn_ext_f16_iter
             <DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter, oob_check,
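
For orientation: taken together, the three hunks above restructure flash_attn_ext_f16_process_tile so that each branch of if constexpr (ncols2 == 1) carries its own copy of the tile loop. The ncols2 == 1 specialization, the one used when the GQA optimizations are off, now instantiates every iteration with oob_check = true and bounds the final, possibly partial tile with the runtime value ne11 - kb0*nbatch_fa, while the other branch keeps the unchecked fast path. The following is a minimal, self-contained sketch of that compile-time flag pattern, not the llama.cpp kernel; copy_tile, tile, and k_sup are illustrative names.

#include <cuda_fp16.h>
#include <cstdio>

// The bounds check is a template parameter: the padded instantiation
// compiles to a branch-free body, the unpadded one clamps against the
// runtime bound k_sup (here simply the total element count n).
template <int tile, bool oob_check>
__global__ void copy_tile(const half * src, half * dst, const int k_sup) {
    const int i = blockIdx.x*tile + threadIdx.x;
    if (oob_check && i >= k_sup) {
        return; // only the final, partial tile takes this branch
    }
    dst[i] = src[i];
}

int main() {
    constexpr int tile = 256;
    const int n = 1000; // not a multiple of tile -> the checked variant is needed

    half * src;
    half * dst;
    cudaMalloc(&src, n*sizeof(half));
    cudaMalloc(&dst, n*sizeof(half));

    const int nblocks = (n + tile - 1) / tile;
    if (n % tile == 0) {
        copy_tile<tile, false><<<nblocks, tile>>>(src, dst, n);
    } else {
        copy_tile<tile, true><<<nblocks, tile>>>(src, dst, n);
    }
    cudaDeviceSynchronize();
    printf("%d blocks, oob_check=%d\n", nblocks, n % tile != 0);

    cudaFree(src);
    cudaFree(dst);
    return 0;
}

Keeping oob_check a template parameter means the padded instantiation pays nothing for the safety of the unpadded one; the patch applies the same idea to the whole loop instead of only its last iteration.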
@@ -36,12 +36,26 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2(ggml_backend_cuda_con
     const ggml_tensor * KQV  = dst;
     const ggml_tensor * Q    = dst->src[0];
     const ggml_tensor * K    = dst->src[1];
+    const ggml_tensor * V    = dst->src[2];
     const ggml_tensor * mask = dst->src[3];

     float max_bias = 0.0f;
     memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));

-    const bool use_gqa_opt = mask && max_bias == 0.0f && K->ne[1] % FATTN_KQ_STRIDE == 0;
+    // Edge cases like no mask, ALiBi, unpadded K/V, or misaligned addresses for large data transfers
+    // are put into the template specialization without GQA optimizations.
+    bool use_gqa_opt = mask && max_bias == 0.0f && K->ne[1] % FATTN_KQ_STRIDE == 0;
+    for (const ggml_tensor * t : {Q, K, V, mask}) {
+        if (t == nullptr) {
+            continue;
+        }
+        for (size_t i = 1; i < GGML_MAX_DIMS; ++i) {
+            if (t->nb[i] % 16 != 0) {
+                use_gqa_opt = false;
+                break;
+            }
+        }
+    }

     GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
     const int gqa_ratio = Q->ne[2] / K->ne[2];
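
The stride gate added above can be exercised in isolation. Below is a minimal sketch assuming only ggml's stride convention, namely that nb[] holds strides in bytes with nb[0] the element size; toy_tensor and the sample stride values are hypothetical stand-ins for ggml_tensor, and the 16-byte width matches the "large data transfers" the new comment refers to.

#include <cstddef>
#include <cstdio>

#define TOY_MAX_DIMS 4 // stand-in for GGML_MAX_DIMS

struct toy_tensor {
    size_t nb[TOY_MAX_DIMS]; // strides in bytes, nb[0] = element size
};

// Same predicate as the loop added above: every higher-order stride
// must be a multiple of 16 bytes for the GQA-optimized kernel.
static bool strides_16B_aligned(const toy_tensor * t) {
    for (size_t i = 1; i < TOY_MAX_DIMS; ++i) {
        if (t->nb[i] % 16 != 0) {
            return false;
        }
    }
    return true;
}

int main() {
    // f16 rows of 128 elements: nb[1] = 256 bytes, 256 % 16 == 0 -> aligned.
    const toy_tensor k_padded   = {{2, 256, 256*32, 256*32*8}};
    // f16 rows of 100 elements: nb[1] = 200 bytes, 200 % 16 == 8 -> unaligned.
    const toy_tensor k_unpadded = {{2, 200, 200*32, 200*32*8}};

    printf("padded:   aligned=%d\n", strides_16B_aligned(&k_padded));
    printf("unpadded: aligned=%d\n", strides_16B_aligned(&k_unpadded));
    return 0;
}

With f16 rows of 100 elements, nb[1] = 200 bytes and 200 % 16 == 8, so the check fails, use_gqa_opt is forced to false, and dispatch falls back to the ncols2 == 1 specialization patched in the device-side hunks above.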