CUDA: add gqa_ratio 4 for GLM 4.7 flash (#18953)
This commit is contained in:
parent
5516b9c16a
commit
b70d251076
|
|
@ -432,7 +432,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||||
constexpr int ncols = ncols1 * ncols2;
|
constexpr int ncols = ncols1 * ncols2;
|
||||||
constexpr int cols_per_warp = T_B_KQ::I;
|
constexpr int cols_per_warp = T_B_KQ::I;
|
||||||
constexpr int cols_per_thread = get_cols_per_thread();
|
constexpr int cols_per_thread = get_cols_per_thread();
|
||||||
constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
|
constexpr int np = cols_per_warp > ncols ? nwarps : nwarps * cols_per_warp/ncols; // Number of parallel CUDA warps per Q column.
|
||||||
constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa(DKQ, DV, ncols);
|
constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa(DKQ, DV, ncols);
|
||||||
constexpr int nbatch_K2 = ggml_cuda_fattn_mma_get_nbatch_K2(DKQ, DV, ncols);
|
constexpr int nbatch_K2 = ggml_cuda_fattn_mma_get_nbatch_K2(DKQ, DV, ncols);
|
||||||
constexpr int nbatch_V2 = ggml_cuda_fattn_mma_get_nbatch_V2(DKQ, DV, ncols);
|
constexpr int nbatch_V2 = ggml_cuda_fattn_mma_get_nbatch_V2(DKQ, DV, ncols);
|
||||||
|
|
@ -510,7 +510,6 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
static_assert(cols_per_warp != 8, "cols_per_warp == 8 not implemented");
|
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += T_A_KQ::J) {
|
for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += T_A_KQ::J) {
|
||||||
load_ldmatrix(Q_B[0], tile_Q + (threadIdx.y / np)*(T_B_KQ::I*stride_tile_Q) + k_KQ_0, stride_tile_Q);
|
load_ldmatrix(Q_B[0], tile_Q + (threadIdx.y / np)*(T_B_KQ::I*stride_tile_Q) + k_KQ_0, stride_tile_Q);
|
||||||
|
|
@ -522,14 +521,18 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
||||||
T_A_KQ K_A;
|
T_A_KQ K_A;
|
||||||
load_ldmatrix(K_A, tile_K + i_KQ_0*stride_tile_K + (k_KQ_0 - k0_start), stride_tile_K);
|
load_ldmatrix(K_A, tile_K + i_KQ_0*stride_tile_K + (k_KQ_0 - k0_start), stride_tile_K);
|
||||||
|
|
||||||
// Wide version of KQ_C is column-major
|
if constexpr (cols_per_warp == 8) {
|
||||||
|
mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[0]);
|
||||||
|
} else {
|
||||||
|
// Wide version of KQ_C is column-major
|
||||||
#if defined(AMD_WMMA_AVAILABLE)
|
#if defined(AMD_WMMA_AVAILABLE)
|
||||||
// RDNA matrix C is column-major.
|
// RDNA matrix C is column-major.
|
||||||
mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[0]);
|
mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[0]);
|
||||||
#else
|
#else
|
||||||
// swap A and B for CUDA.
|
// swap A and B for CUDA.
|
||||||
mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[0], K_A);
|
mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[0], K_A);
|
||||||
#endif // defined(AMD_WMMA_AVAILABLE)
|
#endif // defined(AMD_WMMA_AVAILABLE)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -953,7 +956,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
||||||
|
|
||||||
constexpr int cols_per_warp = T_B_KQ::I;
|
constexpr int cols_per_warp = T_B_KQ::I;
|
||||||
constexpr int cols_per_thread = get_cols_per_thread();
|
constexpr int cols_per_thread = get_cols_per_thread();
|
||||||
constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
|
constexpr int np = cols_per_warp > ncols ? nwarps : nwarps * cols_per_warp/ncols; // Number of parallel CUDA warps per Q column.
|
||||||
constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa (DKQ, DV, ncols);
|
constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa (DKQ, DV, ncols);
|
||||||
constexpr int nbatch_K2 = ggml_cuda_fattn_mma_get_nbatch_K2 (DKQ, DV, ncols);
|
constexpr int nbatch_K2 = ggml_cuda_fattn_mma_get_nbatch_K2 (DKQ, DV, ncols);
|
||||||
constexpr int nbatch_V2 = ggml_cuda_fattn_mma_get_nbatch_V2 (DKQ, DV, ncols);
|
constexpr int nbatch_V2 = ggml_cuda_fattn_mma_get_nbatch_V2 (DKQ, DV, ncols);
|
||||||
|
|
@ -1484,6 +1487,13 @@ static __global__ void flash_attn_ext_f16(
|
||||||
NO_DEVICE_CODE;
|
NO_DEVICE_CODE;
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
#ifdef VOLTA_MMA_AVAILABLE
|
||||||
|
if (ncols1*ncols2 < 32) {
|
||||||
|
NO_DEVICE_CODE;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
#endif // VOLTA_MMA_AVAILABLE
|
||||||
|
|
||||||
#if __CUDA_ARCH__ == GGML_CUDA_CC_TURING
|
#if __CUDA_ARCH__ == GGML_CUDA_CC_TURING
|
||||||
if (ncols1*ncols2 > 32) {
|
if (ncols1*ncols2 > 32) {
|
||||||
NO_DEVICE_CODE;
|
NO_DEVICE_CODE;
|
||||||
|
|
@ -1728,3 +1738,8 @@ DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 256, 64)
|
||||||
extern DECL_FATTN_MMA_F16_CASE(576, 512, 1, 16);
|
extern DECL_FATTN_MMA_F16_CASE(576, 512, 1, 16);
|
||||||
extern DECL_FATTN_MMA_F16_CASE(576, 512, 2, 16);
|
extern DECL_FATTN_MMA_F16_CASE(576, 512, 2, 16);
|
||||||
extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 16);
|
extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 16);
|
||||||
|
|
||||||
|
// For GLM 4.7 Flash
|
||||||
|
extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 4);
|
||||||
|
extern DECL_FATTN_MMA_F16_CASE(576, 512, 8, 4);
|
||||||
|
extern DECL_FATTN_MMA_F16_CASE(576, 512, 16, 4);
|
||||||
|
|
|
||||||
|
|
@ -68,6 +68,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv
|
||||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 64, 64)
|
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 64, 64)
|
||||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 64, 64)
|
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 64, 64)
|
||||||
|
|
||||||
|
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64)
|
||||||
|
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64)
|
||||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64)
|
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64)
|
||||||
|
|
||||||
return 0;
|
return 0;
|
||||||
|
|
@ -122,6 +124,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv
|
||||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128)
|
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128)
|
||||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 64)
|
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 64)
|
||||||
|
|
||||||
|
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 32, 64)
|
||||||
|
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 32, 64)
|
||||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 32, 64)
|
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 32, 64)
|
||||||
|
|
||||||
return 0;
|
return 0;
|
||||||
|
|
@ -183,6 +187,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am
|
||||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128)
|
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128)
|
||||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 128)
|
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 128)
|
||||||
|
|
||||||
|
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64)
|
||||||
|
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64)
|
||||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64)
|
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64)
|
||||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 32, 512, 1, 128, 64)
|
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 32, 512, 1, 128, 64)
|
||||||
|
|
||||||
|
|
@ -245,6 +251,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am
|
||||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 5, 32, 256)
|
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 5, 32, 256)
|
||||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 3, 64, 128)
|
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 3, 64, 128)
|
||||||
|
|
||||||
|
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64)
|
||||||
|
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64)
|
||||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 4, 64, 64)
|
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 4, 64, 64)
|
||||||
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 32, 256, 2, 128, 64)
|
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 32, 256, 2, 128, 64)
|
||||||
|
|
||||||
|
|
@ -1187,6 +1195,10 @@ static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggm
|
||||||
launch_fattn_tile_switch_ncols1<DKQ, DV, 16, use_logit_softcap>(ctx, dst);
|
launch_fattn_tile_switch_ncols1<DKQ, DV, 16, use_logit_softcap>(ctx, dst);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
if (use_gqa_opt && gqa_ratio % 4 == 0) {
|
||||||
|
launch_fattn_tile_switch_ncols1<DKQ, DV, 4, use_logit_softcap>(ctx, dst);
|
||||||
|
return;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if constexpr (DV <= 256) {
|
if constexpr (DV <= 256) {
|
||||||
|
|
|
||||||
|
|
@ -121,8 +121,12 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg
|
||||||
|
|
||||||
GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
|
GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
|
||||||
const int gqa_ratio = Q->ne[2] / K->ne[2];
|
const int gqa_ratio = Q->ne[2] / K->ne[2];
|
||||||
GGML_ASSERT(gqa_ratio % 16 == 0);
|
GGML_ASSERT(gqa_ratio % 4 == 0);
|
||||||
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);
|
if (gqa_ratio % 16 == 0) {
|
||||||
|
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);
|
||||||
|
} else {
|
||||||
|
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 4>(ctx, dst);
|
||||||
|
}
|
||||||
} break;
|
} break;
|
||||||
default:
|
default:
|
||||||
GGML_ABORT("fatal error");
|
GGML_ABORT("fatal error");
|
||||||
|
|
@ -262,7 +266,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
|
||||||
if (V->ne[0] != 512) {
|
if (V->ne[0] != 512) {
|
||||||
return BEST_FATTN_KERNEL_NONE;
|
return BEST_FATTN_KERNEL_NONE;
|
||||||
}
|
}
|
||||||
if (!gqa_opt_applies || gqa_ratio % 16 != 0) {
|
if (!gqa_opt_applies || gqa_ratio % 4 != 0) {
|
||||||
return BEST_FATTN_KERNEL_NONE;
|
return BEST_FATTN_KERNEL_NONE;
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
|
|
|
||||||
|
|
@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 16, 4);
|
||||||
DECL_FATTN_MMA_F16_CASE(112, 112, 16, 4);
|
DECL_FATTN_MMA_F16_CASE(112, 112, 16, 4);
|
||||||
DECL_FATTN_MMA_F16_CASE(128, 128, 16, 4);
|
DECL_FATTN_MMA_F16_CASE(128, 128, 16, 4);
|
||||||
DECL_FATTN_MMA_F16_CASE(256, 256, 16, 4);
|
DECL_FATTN_MMA_F16_CASE(256, 256, 16, 4);
|
||||||
|
DECL_FATTN_MMA_F16_CASE(576, 512, 16, 4);
|
||||||
|
|
|
||||||
|
|
@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 2, 4);
|
||||||
DECL_FATTN_MMA_F16_CASE(112, 112, 2, 4);
|
DECL_FATTN_MMA_F16_CASE(112, 112, 2, 4);
|
||||||
DECL_FATTN_MMA_F16_CASE(128, 128, 2, 4);
|
DECL_FATTN_MMA_F16_CASE(128, 128, 2, 4);
|
||||||
DECL_FATTN_MMA_F16_CASE(256, 256, 2, 4);
|
DECL_FATTN_MMA_F16_CASE(256, 256, 2, 4);
|
||||||
|
DECL_FATTN_MMA_F16_CASE(576, 512, 2, 4);
|
||||||
|
|
|
||||||
|
|
@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 4, 4);
|
||||||
DECL_FATTN_MMA_F16_CASE(112, 112, 4, 4);
|
DECL_FATTN_MMA_F16_CASE(112, 112, 4, 4);
|
||||||
DECL_FATTN_MMA_F16_CASE(128, 128, 4, 4);
|
DECL_FATTN_MMA_F16_CASE(128, 128, 4, 4);
|
||||||
DECL_FATTN_MMA_F16_CASE(256, 256, 4, 4);
|
DECL_FATTN_MMA_F16_CASE(256, 256, 4, 4);
|
||||||
|
DECL_FATTN_MMA_F16_CASE(576, 512, 4, 4);
|
||||||
|
|
|
||||||
|
|
@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 8, 4);
|
||||||
DECL_FATTN_MMA_F16_CASE(112, 112, 8, 4);
|
DECL_FATTN_MMA_F16_CASE(112, 112, 8, 4);
|
||||||
DECL_FATTN_MMA_F16_CASE(128, 128, 8, 4);
|
DECL_FATTN_MMA_F16_CASE(128, 128, 8, 4);
|
||||||
DECL_FATTN_MMA_F16_CASE(256, 256, 8, 4);
|
DECL_FATTN_MMA_F16_CASE(256, 256, 8, 4);
|
||||||
|
DECL_FATTN_MMA_F16_CASE(576, 512, 8, 4);
|
||||||
|
|
|
||||||
|
|
@ -85,7 +85,7 @@ for ncols in [8, 16, 32, 64]:
|
||||||
continue
|
continue
|
||||||
if head_size_kq != 576 and ncols2 == 16:
|
if head_size_kq != 576 and ncols2 == 16:
|
||||||
continue
|
continue
|
||||||
if head_size_kq == 576 and ncols2 != 16:
|
if head_size_kq == 576 and ncols2 not in (4, 16):
|
||||||
continue
|
continue
|
||||||
head_size_v = head_size_kq if head_size_kq != 576 else 512
|
head_size_v = head_size_kq if head_size_kq != 576 else 512
|
||||||
f.write(SOURCE_FATTN_MMA_CASE.format(ncols1=ncols1, ncols2=ncols2, head_size_kq=head_size_kq, head_size_v=head_size_v))
|
f.write(SOURCE_FATTN_MMA_CASE.format(ncols1=ncols1, ncols2=ncols2, head_size_kq=head_size_kq, head_size_v=head_size_v))
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue