@@ -98,6 +98,19 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co
     return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols);
 }
 
+static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config_rdna(const int DKQ, const int DV, const int ncols) {
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 16, 128, 2, 64, 128, 128, 128, 2, true);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 64, 128, 128, 64, 2, true);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 64, 128, 128, 64, 2, true);
+
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 96, 64, 128, 1, false);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 128, 1, false);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 128, 1, false);
+
+    // TODO tune specifically for RDNA
+    return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols);
+}
+
 static __host__ fattn_mma_config ggml_cuda_fattn_mma_get_config(const int DKQ, const int DV, const int ncols, const int cc) {
     if (ampere_mma_available(cc)) {
         return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols);
@@ -105,6 +118,9 @@ static __host__ fattn_mma_config ggml_cuda_fattn_mma_get_config(const int DKQ, c
     if (turing_mma_available(cc)) {
         return ggml_cuda_fattn_mma_get_config_turing(DKQ, DV, ncols);
     }
+    if (amd_wmma_available(cc)) {
+        return ggml_cuda_fattn_mma_get_config_rdna(DKQ, DV, ncols);
+    }
     GGML_ASSERT(volta_mma_available(cc));
     return ggml_cuda_fattn_mma_get_config_volta(DKQ, DV, ncols);
 }
@@ -116,6 +132,8 @@ static constexpr __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config(cons
     return ggml_cuda_fattn_mma_get_config_turing(DKQ, DV, ncols);
 #elif defined(VOLTA_MMA_AVAILABLE)
     return ggml_cuda_fattn_mma_get_config_volta(DKQ, DV, ncols);
+#elif defined(AMD_WMMA_AVAILABLE)
+    return ggml_cuda_fattn_mma_get_config_rdna(DKQ, DV, ncols);
 #else
     GGML_UNUSED_VARS(DKQ, DV, ncols);
     return fattn_mma_config(32, 1, 0, 0, 0, 0, 0, false);
@@ -186,6 +204,23 @@ static constexpr __device__ bool ggml_cuda_fattn_mma_get_Q_in_reg(const int DKQ,
     return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).Q_in_reg;
 }
 
+static constexpr __device__ int get_cols_per_thread() {
+#if defined(AMD_WMMA_AVAILABLE)
+    return 1; // RDNA has a single column.
+#else
+    return 2; // This is specifically KQ columns, Volta only has a single VKQ column.
+#endif // defined(AMD_WMMA_AVAILABLE)
+}
+
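+// Number of KQ columns covered by a single warp (16 on Turing/RDNA, 32 on Volta).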
+static __host__ int get_cols_per_warp(const int cc) {
+    if (turing_mma_available(cc) || amd_wmma_available(cc)) {
+        return 16;
+    } else {
+        // Volta
+        return 32;
+    }
+}
+
 // ------------------------------------------------------------------------------------------------------------------
 
 static __host__ int ggml_cuda_fattn_mma_get_nstages(const int DKQ, const int DV, const int ncols1, const int ncols2, const int cc) {
@@ -393,10 +428,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
         const int jt,
         const int kb0,
         const int k_VKQ_sup) {
-#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4))
     constexpr int ncols = ncols1 * ncols2;
     constexpr int cols_per_warp = T_B_KQ::I;
-    constexpr int cols_per_thread = 2; // This is specifically KQ columns, Volta only has a single VKQ column.
+    constexpr int cols_per_thread = get_cols_per_thread();
     constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
     constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa(DKQ, DV, ncols);
     constexpr int nbatch_K2 = ggml_cuda_fattn_mma_get_nbatch_K2(DKQ, DV, ncols);
@@ -413,6 +448,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
     const int k_VKQ_0 = kb0 * nbatch_fa;
 #if defined(TURING_MMA_AVAILABLE)
     T_C_KQ KQ_C[nbatch_fa/(np*(cols_per_warp == 8 ? T_C_KQ::I : T_C_KQ::J))];
+#elif defined(AMD_WMMA_AVAILABLE)
+    T_C_KQ KQ_C[nbatch_fa/(np*T_C_KQ::J)];
 #else // Volta
     T_C_KQ KQ_C[nbatch_fa/(np*T_C_KQ::J)];
 #endif // defined(TURING_MMA_AVAILABLE)
@@ -461,8 +498,14 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
                 if constexpr (cols_per_warp == 8) {
                     mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[k_KQ_0/T_A_KQ::J]);
                 } else {
-                    // Wide version of KQ_C is column-major => swap A and B.
+                    // Wide version of KQ_C is column-major
+#if defined(AMD_WMMA_AVAILABLE)
+                    // RDNA matrix C is column-major.
+                    mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[k_KQ_0/T_A_KQ::J]);
+#else
+                    // swap A and B for CUDA.
                     mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[k_KQ_0/T_A_KQ::J], K_A);
+#endif // defined(AMD_WMMA_AVAILABLE)
                 }
             }
         }
@@ -479,8 +522,14 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
                 T_A_KQ K_A;
                 load_ldmatrix(K_A, tile_K + i_KQ_0*stride_tile_K + (k_KQ_0 - k0_start), stride_tile_K);
 
-                // Wide version of KQ_C is column-major => swap A and B.
+                // Wide version of KQ_C is column-major
+#if defined(AMD_WMMA_AVAILABLE)
+                // RDNA matrix C is column-major.
+                mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[0]);
+#else
+                // swap A and B for CUDA.
                 mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[0], K_A);
+#endif // defined(AMD_WMMA_AVAILABLE)
             }
         }
     }
@@ -532,7 +581,13 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
 #pragma unroll
             for (int l = 0; l < T_C_KQ::ne; ++l) {
                 if (!oob_check || k0 + (threadIdx.y % np)*T_C_KQ::I + T_C_KQ::get_i(l) < k_VKQ_sup) {
-                    KQ_max_new[l % 2] = fmaxf(KQ_max_new[l % 2], KQ_C[k0/(np*T_C_KQ::I)].x[l] + FATTN_KQ_MAX_OFFSET);
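+                    // KQ_idx selects the per-thread accumulator for this element:
+                    // RDNA threads hold a single KQ column, Turing/Volta threads hold two (see get_cols_per_thread).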
+#if defined(AMD_WMMA_AVAILABLE)
+                    constexpr int KQ_idx = 0;
+#else
+                    // Turing + Volta:
+                    const int KQ_idx = l % 2;
+#endif // defined(AMD_WMMA_AVAILABLE)
+                    KQ_max_new[KQ_idx] = fmaxf(KQ_max_new[KQ_idx], KQ_C[k0/(np*T_C_KQ::I)].x[l] + FATTN_KQ_MAX_OFFSET);
                 }
             }
         }
@@ -552,8 +607,14 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
 #pragma unroll
             for (int l = 0; l < T_C_KQ::ne; ++l) {
                 if (!oob_check || k0 + (threadIdx.y % np)*T_C_KQ::I + T_C_KQ::get_i(l) < k_VKQ_sup) {
-                    KQ_C[k0/(np*T_C_KQ::I)].x[l] = expf(KQ_C[k0/(np*T_C_KQ::I)].x[l] - KQ_max_new[l % 2]);
-                    KQ_rowsum_add[l % 2] += KQ_C[k0/(np*T_C_KQ::I)].x[l];
+#if defined(AMD_WMMA_AVAILABLE)
+                    constexpr int KQ_idx = 0;
+#else
+                    // Turing + Volta:
+                    const int KQ_idx = l % 2;
+#endif // defined(AMD_WMMA_AVAILABLE)
+                    KQ_C[k0/(np*T_C_KQ::I)].x[l] = expf(KQ_C[k0/(np*T_C_KQ::I)].x[l] - KQ_max_new[KQ_idx]);
+                    KQ_rowsum_add[KQ_idx] += KQ_C[k0/(np*T_C_KQ::I)].x[l];
                 } else {
                     KQ_C[k0/(np*T_C_KQ::I)].x[l] = 0.0f;
                 }
@@ -584,8 +645,13 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
 #pragma unroll
             for (int l = 0; l < T_C_KQ::ne; ++l) {
                 if (!oob_check || k0 + (threadIdx.y % np)*T_C_KQ::J + T_C_KQ::get_j(l) < k_VKQ_sup) {
+#if defined(AMD_WMMA_AVAILABLE)
+                    constexpr int KQ_idx = 0;
+#else
+                    // Turing + Volta:
-                    KQ_max_new[(l/2) % 2] = fmaxf(KQ_max_new[(l/2) % 2], KQ_C[(k0/(np*T_C_KQ::J))].x[l] + FATTN_KQ_MAX_OFFSET);
+                    const int KQ_idx = (l/2) % 2;
+#endif // defined(AMD_WMMA_AVAILABLE)
+                    KQ_max_new[KQ_idx] = fmaxf(KQ_max_new[KQ_idx], KQ_C[(k0/(np*T_C_KQ::J))].x[l] + FATTN_KQ_MAX_OFFSET);
                 }
             }
         }
@@ -596,7 +662,11 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
         // Values per KQ column are spread across 4 threads:
         constexpr int offset_first = 2;
         constexpr int offset_last = 1;
-#else
+#elif defined(AMD_WMMA_AVAILABLE)
+        // Values per KQ column are spread across 2 threads:
+        constexpr int offset_first = 16;
+        constexpr int offset_last = 16;
+#else // Volta
         // Values per KQ column are spread across 2 threads:
         constexpr int offset_first = 2;
         constexpr int offset_last = 2;
@@ -612,10 +682,15 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
         for (int k0 = 0; k0 < nbatch_fa; k0 += np*T_C_KQ::J) {
 #pragma unroll
             for (int l = 0; l < T_C_KQ::ne; ++l) {
-                // Turing + Volta:
                 if (!oob_check || k0 + (threadIdx.y % np)*T_C_KQ::J + T_C_KQ::get_j(l) < k_VKQ_sup) {
-                    KQ_C[(k0/(np*T_C_KQ::J))].x[l] = expf(KQ_C[(k0/(np*T_C_KQ::J))].x[l] - KQ_max_new[(l/2) % 2]);
-                    KQ_rowsum_add[(l/2) % 2] += KQ_C[(k0/(np*T_C_KQ::J))].x[l];
+#if defined(AMD_WMMA_AVAILABLE)
+                    constexpr int KQ_idx = 0;
+#else
+                    // Turing + Volta:
+                    const int KQ_idx = (l/2) % 2;
+#endif // defined(AMD_WMMA_AVAILABLE)
+                    KQ_C[(k0/(np*T_C_KQ::J))].x[l] = expf(KQ_C[(k0/(np*T_C_KQ::J))].x[l] - KQ_max_new[KQ_idx]);
+                    KQ_rowsum_add[KQ_idx] += KQ_C[(k0/(np*T_C_KQ::J))].x[l];
                 } else {
                     KQ_C[(k0/(np*T_C_KQ::J))].x[l] = 0.0f;
                 }
@@ -639,7 +714,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
 
 #if defined(TURING_MMA_AVAILABLE)
     if constexpr (cols_per_warp == 8) {
-        const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[1]);
+        const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[cols_per_thread - 1]);
 #pragma unroll
         for (int i = 0; i < DV/T_C_VKQ::I; ++i) {
 #pragma unroll
@@ -660,6 +735,16 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
             }
         }
     }
+#elif defined(AMD_WMMA_AVAILABLE)
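+    // With a single KQ column per thread on RDNA the same scale applies to both f16 halves.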
+    const half2 KQ_max_scale_h2 = make_half2(
+        KQ_max_scale[0], KQ_max_scale[0]);
+#pragma unroll
+    for (int i = 0; i < (DV/2)/T_C_VKQ::J; ++i) {
+#pragma unroll
+        for (int l = 0; l < T_C_VKQ::ne; ++l) {
+            VKQ_C[i].x[l] *= KQ_max_scale_h2;
+        }
+    }
 #else // Volta
     const half2 KQ_max_scale_h2 = make_half2(
         KQ_max_scale[(threadIdx.x / 2) % 2], KQ_max_scale[(threadIdx.x / 2) % 2]);
@@ -707,6 +792,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
     // Therefore, iterate over V in reverse and re-use the data if possible.
     static_assert(!mla || nstages <= 1, "combination of MLA and multi-stage loading not implemented");
     constexpr int reusable_cutoff = mla ? (DKQ - 1) - (DKQ - 1) % (2*nbatch_K2) - (DKQ - DV) : DV;
+#if defined(AMD_WMMA_AVAILABLE) && !defined(LDMATRIX_TRANS_AVAILABLE)
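+    // Multiplying a tile by this identity with mma() keeps the values unchanged but
+    // emits them in the accumulator layout, which stands in for the missing transposed load.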
+    T_A_VKQ A_identity;
+    make_identity_mat(A_identity);
+#endif // defined(AMD_WMMA_AVAILABLE) && !defined(LDMATRIX_TRANS_AVAILABLE)
 
     // Calculate VKQ tile, need to use logical rather than physical elements for i0 due to transposition of V:
 #pragma unroll
@@ -727,7 +816,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
     }
     const half2 * tile_V_i = i0_start < reusable_cutoff ? tile_V : tile_V + (i0_start - reusable_cutoff)/2;
 
-#if defined(TURING_MMA_AVAILABLE)
+#if defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
     constexpr int i0_stride = cols_per_warp == 8 ? T_C_VKQ::I : 2*T_C_VKQ::J;
 #pragma unroll
     for (int i_VKQ_0 = i0_start; i_VKQ_0 < i0_stop; i_VKQ_0 += i0_stride) {
@@ -737,12 +826,26 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
             const int k0 = k00 + (threadIdx.y % np)*T_A_VKQ::J;
 
             T_A_VKQ A; // Transposed in SRAM but not in registers, gets transposed on load.
+#if defined(LDMATRIX_TRANS_AVAILABLE)
             load_ldmatrix_trans(A, tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V);
+#else
+            // TODO: Try to transpose tile_V when loading gmem to smem.
+            // Use mma to transpose T_A_VKQ for RDNA.
+            T_A_VKQ A_trans;
+            load_ldmatrix(A_trans, tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V);
+            mma(A, A_trans, A_identity);
+#endif // defined(LDMATRIX_TRANS_AVAILABLE)
             if constexpr (T_B_KQ::I == 8) {
                 mma(VKQ_C[i_VKQ_0/i0_stride], A, B[k00/(np*T_A_VKQ::J)]);
             } else {
-                // Wide version of VKQ_C is column-major => swap A and B.
+                // Wide version of VKQ_C is column-major.
+#if defined(AMD_WMMA_AVAILABLE)
+                // RDNA matrix C is column-major.
+                mma(VKQ_C[i_VKQ_0/i0_stride], A, B[k00/(np*T_A_VKQ::J)]);
+#else
+                // swap A and B for CUDA.
                 mma(VKQ_C[i_VKQ_0/i0_stride], B[k00/(np*T_A_VKQ::J)], A);
+#endif // defined(AMD_WMMA_AVAILABLE)
             }
         }
     }
@@ -761,7 +864,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
             mma(VKQ_C[i_VKQ_0/i0_stride], B[k00/(np*T_A_VKQ::I)], A);
         }
     }
-#endif // defined(TURING_MMA_AVAILABLE)
+#endif // defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
 
     if constexpr (nstages <= 1) {
         __syncthreads(); // Only needed if tile_K == tile_V.
@@ -774,7 +877,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
         tile_Q, tile_K, tile_V, tile_mask,
         Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0);
     NO_DEVICE_CODE;
-#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4))
 }
 
 #if defined(TURING_MMA_AVAILABLE)
@@ -794,6 +897,15 @@ template<> struct mma_tile_sizes<8> {
     using T_B_VKQ = tile< 8, 8, half2>; // column-major
     using T_C_VKQ = tile<16, 4, half2>; // row-major
 };
+#elif defined(AMD_WMMA_AVAILABLE)
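+// RDNA WMMA tiles are 16x16: KQ is accumulated in fp32, VKQ in packed fp16 (half2).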
+template<int ncols> struct mma_tile_sizes {
+    using T_A_KQ = tile<16, 8, half2>; // row-major
+    using T_B_KQ = tile<16, 8, half2>; // column-major
+    using T_C_KQ = tile<16, 16, float>; // column-major
+    using T_A_VKQ = tile<16, 8, half2>; // row-major
+    using T_B_VKQ = tile<16, 8, half2>; // column-major
+    using T_C_VKQ = tile<16, 8, half2>; // column-major
+};
 #else // Volta
 template<int ncols> struct mma_tile_sizes {
     using T_A_KQ = tile< 8, 4, half2, DATA_LAYOUT_I_MAJOR_MIRRORED>; // row-major
@@ -828,7 +940,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
         const int jt,
         const int kb0_start,
         const int kb0_stop) {
-#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4))
     //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
 
     constexpr int ncols = ncols1 * ncols2;
@@ -840,7 +952,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
     using T_C_VKQ = typename mma_tile_sizes<ncols>::T_C_VKQ;
 
     constexpr int cols_per_warp = T_B_KQ::I;
-    constexpr int cols_per_thread = 2; // This is specifically KQ columns, Volta only has a single VKQ column.
+    constexpr int cols_per_thread = get_cols_per_thread();
     constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
     constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa (DKQ, DV, ncols);
     constexpr int nbatch_K2 = ggml_cuda_fattn_mma_get_nbatch_K2 (DKQ, DV, ncols);
@@ -871,6 +983,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
     T_B_KQ Q_B[(Q_in_reg ? DKQ/(2*T_B_KQ::J) : 1)];
 #if defined(TURING_MMA_AVAILABLE)
     T_C_VKQ VKQ_C[cols_per_warp == 8 ? DV/T_C_VKQ::I : DV/(2*T_C_VKQ::J)];
+#elif defined(AMD_WMMA_AVAILABLE)
+    T_C_VKQ VKQ_C[ DV/(2*T_C_VKQ::J)];
 #else // Volta
     T_C_VKQ VKQ_C[ DV/(2*T_C_VKQ::J)];
 #endif // defined(TURING_MMA_AVAILABLE)
@@ -1010,6 +1124,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
         // The partial sums are spread across 8/4 threads.
         constexpr int offset_first = cols_per_warp == 8 ? 16 : 2;
         constexpr int offset_last = cols_per_warp == 8 ? 4 : 1;
+#elif defined(AMD_WMMA_AVAILABLE)
+        // The partial sums are spread across 2 threads.
+        constexpr int offset_first = 16;
+        constexpr int offset_last = 16;
 #else // Volta
         // The partial sums are spread across 2 threads.
         constexpr int offset_first = 2;
@@ -1047,7 +1165,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
 
 #if defined(TURING_MMA_AVAILABLE)
     if constexpr (cols_per_warp == 8) {
-        const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[1]);
+        const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[cols_per_thread - 1]);
 #pragma unroll
         for (int i = 0; i < DV/T_C_VKQ::I; ++i) {
 #pragma unroll
@@ -1068,6 +1186,15 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
             }
         }
     }
+#elif defined(AMD_WMMA_AVAILABLE)
+    const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[0]);
+#pragma unroll
+    for (int i = 0; i < (DV/2)/T_C_VKQ::J; ++i) {
+#pragma unroll
+        for (int l = 0; l < T_C_VKQ::ne; ++l) {
+            VKQ_C[i].x[l] *= KQ_max_scale_h2;
+        }
+    }
 #else // Volta
     const int col = (threadIdx.x / 2) % 2;
     const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[col], KQ_max_scale[col]);
@@ -1119,6 +1246,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
         const int jc_cwm = threadIdx.y*cols_per_warp + T_C_VKQ::get_i(threadIdx.x % 4);
         const float2 KQ_cmr = make_float2(KQ_max[threadIdx.x % cols_per_thread], KQ_rowsum[threadIdx.x % cols_per_thread]);
         const bool thread_should_write = threadIdx.x % 4 < cols_per_thread;
+#elif defined(AMD_WMMA_AVAILABLE)
+        const int jc_cwm = threadIdx.y*cols_per_warp + T_C_VKQ::get_i(0);
+        const float2 KQ_cmr = make_float2(KQ_max[0], KQ_rowsum[0]);
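+        // cols_per_thread == 1 on RDNA, so only the first 16 lanes write their KQ max/rowsum.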
+        const bool thread_should_write = threadIdx.x / 16 < cols_per_thread;
 #else // Volta
         const int jc_cwm = threadIdx.y*cols_per_warp + T_C_KQ::get_i(threadIdx.x & 2);
         const float2 KQ_cmr = make_float2(KQ_max[(threadIdx.x & 2) / 2], KQ_rowsum[(threadIdx.x & 2) / 2]);
@@ -1319,7 +1450,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
         stride_Q1, stride_Q2, stride_K, stride_V, stride_mask,
         jt, kb0_start, kb0_stop);
     NO_DEVICE_CODE;
-#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
+#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4))
 }
 
 template<int DKQ, int DV, int ncols1, int ncols2, bool use_logit_softcap, bool mla>
@@ -1346,7 +1477,7 @@ static __global__ void flash_attn_ext_f16(
         const int32_t nb21, const int32_t nb22, const int64_t nb23,
         const int32_t ne31, const int32_t ne32, const int32_t ne33,
         const int32_t nb31, const int32_t nb32, const int64_t nb33) {
-#if defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE))
+#if defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)))
 
     // Skip unused kernel variants for faster compilation:
     if (use_logit_softcap && !(DKQ == 128 || DKQ == 256)) {
@@ -1360,6 +1491,13 @@ static __global__ void flash_attn_ext_f16(
     }
 #endif // __CUDA_ARCH__ == GGML_CUDA_CC_TURING
 
+#if defined(AMD_WMMA_AVAILABLE)
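+    // Skip kernel variants that are not (yet) supported on RDNA: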
+    if (ncols1*ncols2 > 32 || ncols1*ncols2 < 16 || DKQ > 128 || ncols2 == 1) {
+        NO_DEVICE_CODE;
+        return;
+    }
+#endif // defined(AMD_WMMA_AVAILABLE)
+
     static_assert(!mla || DKQ >= DV, "MLA needs DKQ >= DV");
 
     constexpr int ncols = ncols1 * ncols2;
@@ -1473,7 +1611,7 @@ static __global__ void flash_attn_ext_f16(
         ne31, ne32, ne33,
         nb31, nb32, nb33);
     NO_DEVICE_CODE;
-#endif // defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE))
+#endif // defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)))
 }
 
 template <int DKQ, int DV, int ncols1, int ncols2>
@@ -1492,7 +1630,7 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
     const bool Q_in_reg = ggml_cuda_fattn_mma_get_Q_in_reg (DKQ, DV, ncols, cc);
     const int nstages = ggml_cuda_fattn_mma_get_nstages (DKQ, DV, ncols1, ncols2, cc);
 
-    const int cols_per_warp = std::min(ncols, turing_mma_available(cc) ? 16 : 32);
+    const int cols_per_warp = std::min(ncols, get_cols_per_warp(cc));
     const int nwarps = nthreads / WARP_SIZE;
 
     constexpr bool mla = DKQ == 576;
@@ -1512,29 +1650,34 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
     float logit_softcap;
     memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
 
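+    // hipFuncSetAttribute expects a const void *, unlike the typed CUDA overload, hence the cast below.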
+#if defined(GGML_USE_HIP)
+    using fattn_kernel_ptr_t = const void*;
+#else
+    using fattn_kernel_ptr_t = fattn_kernel_t;
+#endif // defined(GGML_USE_HIP)
     fattn_kernel_t fattn_kernel;
     if (logit_softcap == 0.0f) {
         constexpr bool use_logit_softcap = false;
         fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, use_logit_softcap, mla>;
 
-#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
+#if !defined(GGML_USE_MUSA)
         static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
         if (!shared_memory_limit_raised[id]) {
-            CUDA_CHECK(cudaFuncSetAttribute(fattn_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared_total));
+            CUDA_CHECK(cudaFuncSetAttribute(reinterpret_cast<fattn_kernel_ptr_t>(fattn_kernel), cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared_total));
            shared_memory_limit_raised[id] = true;
         }
-#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
+#endif // !defined(GGML_USE_MUSA)
     } else {
         constexpr bool use_logit_softcap = true;
        fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, use_logit_softcap, mla>;
 
-#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
+#if !defined(GGML_USE_MUSA)
         static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
         if (!shared_memory_limit_raised[id]) {
-            CUDA_CHECK(cudaFuncSetAttribute(fattn_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared_total));
+            CUDA_CHECK(cudaFuncSetAttribute(reinterpret_cast<fattn_kernel_ptr_t>(fattn_kernel), cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared_total));
            shared_memory_limit_raised[id] = true;
         }
-#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
+#endif // !defined(GGML_USE_MUSA)
     }
 
     launch_fattn<DV, ncols1, ncols2>