diff --git a/ggml/src/ggml-cuda/mma.cuh b/ggml/src/ggml-cuda/mma.cuh
index 42085d1002..30c08b6ffb 100644
--- a/ggml/src/ggml-cuda/mma.cuh
+++ b/ggml/src/ggml-cuda/mma.cuh
@@ -1054,6 +1054,13 @@ namespace ggml_cuda_mma {
         GGML_UNUSED_VARS(D, A, B);
         NO_DEVICE_CODE;
 #endif // RDNA4
+#elif defined(AMD_MFMA_AVAILABLE)
+        using halfx4_t = __attribute__((ext_vector_type(4))) _Float16;
+        using floatx4_t = __attribute__((ext_vector_type(4))) float;
+        floatx4_t& acc_frag = reinterpret_cast<floatx4_t&>(D.x[0]);
+        const halfx4_t& a_frag = reinterpret_cast<const halfx4_t&>(A.x[0]);
+        const halfx4_t& b_frag = reinterpret_cast<const halfx4_t&>(B.x[0]);
+        acc_frag = __builtin_amdgcn_mfma_f32_16x16x16f16(a_frag, b_frag, acc_frag, 0, 0, 0);
 #else
         GGML_UNUSED_VARS(D, A, B);
         NO_DEVICE_CODE;
@@ -1081,11 +1088,18 @@ namespace ggml_cuda_mma {
 #else
         GGML_UNUSED_VARS(D, A, B);
         NO_DEVICE_CODE;
-#endif // RDNA4
+#endif // defined(RDNA4)
+#elif defined(AMD_MFMA_AVAILABLE)
+        using bf16x4_t = __attribute__((ext_vector_type(4))) __bf16;
+        using floatx4_t = __attribute__((ext_vector_type(4))) float;
+        floatx4_t& acc_frag = reinterpret_cast<floatx4_t&>(D.x[0]);
+        const bf16x4_t& a_frag = reinterpret_cast<const bf16x4_t&>(A.x[0]);
+        const bf16x4_t& b_frag = reinterpret_cast<const bf16x4_t&>(B.x[0]);
+        acc_frag = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(a_frag, b_frag, acc_frag, 0, 0, 0);
 #else
         GGML_UNUSED_VARS(D, A, B);
         NO_DEVICE_CODE;
-#endif // AMPERE_MMA_AVAILABLE
+#endif // defined(AMD_WMMA_AVAILABLE)
     }
 
     template
diff --git a/ggml/src/ggml-cuda/mmf.cu b/ggml/src/ggml-cuda/mmf.cu
index 35d48f614c..96fd3c5847 100644
--- a/ggml/src/ggml-cuda/mmf.cu
+++ b/ggml/src/ggml-cuda/mmf.cu
@@ -2,11 +2,12 @@
 #include "mmf.cuh"
 #include "mmid.cuh"
 
-constexpr int mmf_rows_per_block = 32;
-constexpr int mmf_rows_per_block_cdna = 64;
-
-static int get_mmf_rows_per_block(const int cc, const int warp_size) {
-    return warp_size;
+static int get_mmf_rows_per_block(const int cc) {
+    if (GGML_CUDA_CC_IS_CDNA(cc)) {
+        return MMF_ROWS_PER_BLOCK_CDNA;
+    } else {
+        return MMF_ROWS_PER_BLOCK;
+    }
 }
 
 void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) {
@@ -97,10 +98,9 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr
 
     const int device = ggml_cuda_get_device();
     const int cc = ggml_cuda_info().devices[device].cc;
-    const int warp_size = ggml_cuda_info().devices[device].warp_size;
-    const int rows_per_block = get_mmf_rows_per_block(cc, warp_size);
+    const int rows_per_block = get_mmf_rows_per_block(cc);
 
-    if (rows_per_block != mmf_rows_per_block && rows_per_block != mmf_rows_per_block_cdna) {
+    if (rows_per_block != MMF_ROWS_PER_BLOCK && rows_per_block != MMF_ROWS_PER_BLOCK_CDNA) {
         GGML_ABORT("unsupported rows_per_block: %i", rows_per_block);
     }
 
@@ -108,13 +108,13 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr
         case GGML_TYPE_F32: {
             const float * src0_d = (const float *) src0->data;
             constexpr int vals_per_T = 1;
-            if (rows_per_block == mmf_rows_per_block) {
-                mul_mat_f_switch_cols_per_block<float, mmf_rows_per_block>(
+            if (rows_per_block == MMF_ROWS_PER_BLOCK) {
+                mul_mat_f_switch_cols_per_block<float, MMF_ROWS_PER_BLOCK>(
                     src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
                     ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
                     ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr);
             } else {
-                mul_mat_f_switch_cols_per_block<float, mmf_rows_per_block_cdna>(
+                mul_mat_f_switch_cols_per_block<float, MMF_ROWS_PER_BLOCK_CDNA>(
                     src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
                     ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
                     ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr);
@@ -123,13 +123,13 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr
         case GGML_TYPE_F16: {
             const half2 * src0_d = (const half2 *) src0->data;
             constexpr int vals_per_T = 2;
-            if (rows_per_block == mmf_rows_per_block) {
-                mul_mat_f_switch_cols_per_block<half2, mmf_rows_per_block>(
+            if (rows_per_block == MMF_ROWS_PER_BLOCK) {
+                mul_mat_f_switch_cols_per_block<half2, MMF_ROWS_PER_BLOCK>(
                     src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
                     ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
                     ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr);
             } else {
-                mul_mat_f_switch_cols_per_block<half2, mmf_rows_per_block_cdna>(
+                mul_mat_f_switch_cols_per_block<half2, MMF_ROWS_PER_BLOCK_CDNA>(
                     src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
                     ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
                     ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr);
@@ -138,13 +138,13 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr
        case GGML_TYPE_BF16: {
             const nv_bfloat162 * src0_d = (const nv_bfloat162 *) src0->data;
             constexpr int vals_per_T = 2;
-            if (rows_per_block == mmf_rows_per_block) {
-                mul_mat_f_switch_cols_per_block<nv_bfloat162, mmf_rows_per_block>(
+            if (rows_per_block == MMF_ROWS_PER_BLOCK) {
+                mul_mat_f_switch_cols_per_block<nv_bfloat162, MMF_ROWS_PER_BLOCK>(
                     src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
                     ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
                     ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr);
             } else {
-                mul_mat_f_switch_cols_per_block<nv_bfloat162, mmf_rows_per_block_cdna>(
+                mul_mat_f_switch_cols_per_block<nv_bfloat162, MMF_ROWS_PER_BLOCK_CDNA>(
                     src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
                     ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
                     ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr);
@@ -176,7 +176,7 @@ bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const
             return false;
         }
     }
-    if (src0_ne[1] % get_mmf_rows_per_block(cc, warp_size) != 0) {
+    if (src0_ne[1] % get_mmf_rows_per_block(cc) != 0) {
         return false;
     }
 
diff --git a/ggml/src/ggml-cuda/mmf.cuh b/ggml/src/ggml-cuda/mmf.cuh
index e6e9d79aa1..d4a0bee322 100644
--- a/ggml/src/ggml-cuda/mmf.cuh
+++ b/ggml/src/ggml-cuda/mmf.cuh
@@ -6,6 +6,9 @@
 
 using namespace ggml_cuda_mma;
 
+#define MMF_ROWS_PER_BLOCK 32
+#define MMF_ROWS_PER_BLOCK_CDNA 64
+
 struct mmf_ids_data {
     const int32_t * ids_src_compact = nullptr;
     const int32_t * ids_dst_compact = nullptr;
@@ -29,21 +32,18 @@ static __global__ void mul_mat_f(
     // TODO: handle this in a consistent and simpler way after AMD MFMA support has been added
 #if (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE)
 #if defined(AMD_WMMA_AVAILABLE)
-    // Special case for tf32, just dummy mma layout as wmma doesn't support it.
-    constexpr bool is_tf32 = std::is_same_v<T, float>;
-    constexpr int tile_B_I = is_tf32 ? 8 : 16;
-    constexpr int tile_C_J = is_tf32 ? 8 : 16;
-    constexpr data_layout ab_layout = is_tf32 ? DATA_LAYOUT_I_MAJOR : get_input_data_layout();
-    typedef tile<16, 8, T, ab_layout> tile_A;
-    typedef tile<tile_B_I, 8, T, ab_layout> tile_B;
-    typedef tile<16, tile_C_J, float, DATA_LAYOUT_J_MAJOR> tile_C;
+    if constexpr (!(std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) || rows_per_block != MMF_ROWS_PER_BLOCK) {NO_DEVICE_CODE;} else {
+    typedef tile<16, 8, T, get_input_data_layout()> tile_A;
+    typedef tile<16, 8, T, get_input_data_layout()> tile_B;
+    typedef tile<16, 16, float, DATA_LAYOUT_J_MAJOR> tile_C;
 #else
 #ifdef VOLTA_MMA_AVAILABLE
-    if constexpr (!std::is_same_v<T, half2>) {NO_DEVICE_CODE;} else {
+    if constexpr (!std::is_same_v<T, half2> || rows_per_block != MMF_ROWS_PER_BLOCK) {NO_DEVICE_CODE;} else {
     typedef tile<32, 4, T, DATA_LAYOUT_I_MAJOR> tile_A;
     typedef tile< 8, 4, T, DATA_LAYOUT_I_MAJOR_MIRRORED> tile_B;
     typedef tile<32, 8, float, DATA_LAYOUT_I_MAJOR> tile_C;
 #else
+    if constexpr (rows_per_block != MMF_ROWS_PER_BLOCK) {NO_DEVICE_CODE;} else {
     typedef tile<16, 8, T> tile_A;
     typedef tile<8, 8, T> tile_B;
     typedef tile<16, 8, float> tile_C;
@@ -253,9 +253,7 @@ static __global__ void mul_mat_f(
             }
         }
     }
-#ifdef VOLTA_MMA_AVAILABLE
     }
-#endif //VOLTA_MMA_AVAILABLE
 #else
     GGML_UNUSED_VARS(x, y, ids, dst,
         ncols, ncols_dst_total, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
@@ -280,21 +278,18 @@ static __global__ void mul_mat_f_ids(
     // TODO: handle this in a consistent and simpler way after AMD MFMA support has been added
 #if (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE)
 #if defined(AMD_WMMA_AVAILABLE)
-    // Special case for tf32, just dummy mma layout as wmma doesn't support it.
-    constexpr bool is_tf32 = std::is_same_v<T, float>;
-    constexpr int tile_B_I = is_tf32 ? 8 : 16;
-    constexpr int tile_C_J = is_tf32 ? 8 : 16;
-    constexpr data_layout ab_layout = is_tf32 ? DATA_LAYOUT_I_MAJOR : get_input_data_layout();
-    typedef tile<16, 8, T, ab_layout> tile_A;
-    typedef tile<tile_B_I, 8, T, ab_layout> tile_B;
-    typedef tile<16, tile_C_J, float, DATA_LAYOUT_J_MAJOR> tile_C;
+    if constexpr (!(std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) || rows_per_block != MMF_ROWS_PER_BLOCK) {NO_DEVICE_CODE;} else {
+    typedef tile<16, 8, T, get_input_data_layout()> tile_A;
+    typedef tile<16, 8, T, get_input_data_layout()> tile_B;
+    typedef tile<16, 16, float, DATA_LAYOUT_J_MAJOR> tile_C;
 #else
 #ifdef VOLTA_MMA_AVAILABLE
-    if constexpr (!std::is_same_v<T, half2>) {NO_DEVICE_CODE;} else {
+    if constexpr (!std::is_same_v<T, half2> || rows_per_block != MMF_ROWS_PER_BLOCK) {NO_DEVICE_CODE;} else {
     typedef tile<32, 4, T, DATA_LAYOUT_I_MAJOR> tile_A;
     typedef tile< 8, 4, T, DATA_LAYOUT_I_MAJOR_MIRRORED> tile_B;
     typedef tile<32, 8, float, DATA_LAYOUT_I_MAJOR> tile_C;
 #else
+    if constexpr (rows_per_block != MMF_ROWS_PER_BLOCK) {NO_DEVICE_CODE;} else {
     typedef tile<16, 8, T> tile_A;
     typedef tile<8, 8, T> tile_B;
     typedef tile<16, 8, float> tile_C;
@@ -530,9 +525,7 @@ static __global__ void mul_mat_f_ids(
             }
         }
     }
-#ifdef VOLTA_MMA_AVAILABLE
     }
-#endif // VOLTA_MMA_AVAILABLE
 #else
     GGML_UNUSED_VARS(x, y, ids_src_compact, ids_dst_compact, expert_bounds, dst,
         ncols, ncols_dst_total, nchannels_dst, stride_row, stride_col_y, stride_col_dst,