speed up compile
parent fe24848d00
commit be5ec5a6f0
@@ -1054,6 +1054,13 @@ namespace ggml_cuda_mma {
         GGML_UNUSED_VARS(D, A, B);
         NO_DEVICE_CODE;
 #endif // RDNA4
+#elif defined(AMD_MFMA_AVAILABLE)
+        using halfx4_t = __attribute__((ext_vector_type(4))) _Float16;
+        using floatx4_t = __attribute__((ext_vector_type(4))) float;
+        floatx4_t& acc_frag = reinterpret_cast<floatx4_t&>(D.x[0]);
+        const halfx4_t& a_frag = reinterpret_cast<const halfx4_t&>(A.x[0]);
+        const halfx4_t& b_frag = reinterpret_cast<const halfx4_t&>(B.x[0]);
+        acc_frag = __builtin_amdgcn_mfma_f32_16x16x16f16(a_frag, b_frag, acc_frag, 0, 0, 0);
 #else
         GGML_UNUSED_VARS(D, A, B);
         NO_DEVICE_CODE;
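For context on the new AMD_MFMA_AVAILABLE branch above: it lowers this mma() overload to a single CDNA matrix instruction. A minimal sketch of the same pattern for HIP on a CDNA GPU (gfx908 or newer); the helper name is made up, only the ext_vector_type fragments and the intrinsic come from the diff.

    // Each lane of the 64-wide wavefront holds 4 f16 values of A, 4 of B and
    // 4 f32 accumulators; the builtin computes the whole 16x16x16 D += A*B
    // across the wavefront. The trailing 0, 0, 0 are the cbsz/abid/blgp modifiers.
    typedef __attribute__((ext_vector_type(4))) _Float16 halfx4_t;
    typedef __attribute__((ext_vector_type(4))) float    floatx4_t;

    static __device__ __forceinline__ void mfma_f16_16x16x16(
            const halfx4_t a, const halfx4_t b, floatx4_t & acc) {
        acc = __builtin_amdgcn_mfma_f32_16x16x16f16(a, b, acc, 0, 0, 0);
    }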
@@ -1081,11 +1088,18 @@ namespace ggml_cuda_mma {
 #else
         GGML_UNUSED_VARS(D, A, B);
         NO_DEVICE_CODE;
-#endif // RDNA4
+#endif // defined(RDNA4)
+#elif defined(AMD_MFMA_AVAILABLE)
+        using bf16x4_t = __attribute__((ext_vector_type(4))) __bf16;
+        using floatx4_t = __attribute__((ext_vector_type(4))) float;
+        floatx4_t& acc_frag = reinterpret_cast<floatx4_t&>(D.x[0]);
+        const bf16x4_t& a_frag = reinterpret_cast<const bf16x4_t&>(A.x[0]);
+        const bf16x4_t& b_frag = reinterpret_cast<const bf16x4_t&>(B.x[0]);
+        acc_frag = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(a_frag, b_frag, acc_frag, 0, 0, 0);
 #else
         GGML_UNUSED_VARS(D, A, B);
         NO_DEVICE_CODE;
-#endif // AMPERE_MMA_AVAILABLE
+#endif // defined(AMD_WMMA_AVAILABLE)
     }
 
     template <data_layout dl_d, data_layout dl_ab>
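The bf16 overload above follows the same fragment-aliasing pattern with the "_1k" form of the instruction (the newer bf16 MFMA encoding, CDNA2 and later). A sketch mirroring the diff's usage; whether the builtin accepts __bf16 vectors or 16-bit integer vectors depends on the compiler version, so treat the operand type as an assumption carried over from the diff.

    typedef __attribute__((ext_vector_type(4))) __bf16 bf16x4_t;
    typedef __attribute__((ext_vector_type(4))) float  floatx4_t;

    static __device__ __forceinline__ void mfma_bf16_16x16x16(
            const bf16x4_t a, const bf16x4_t b, floatx4_t & acc) {
        // 16x16x16 bf16 MFMA accumulating in f32; modifiers are again 0.
        acc = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(a, b, acc, 0, 0, 0);
    }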
@@ -2,11 +2,12 @@
 #include "mmf.cuh"
 #include "mmid.cuh"
 
-constexpr int mmf_rows_per_block = 32;
-constexpr int mmf_rows_per_block_cdna = 64;
-
-static int get_mmf_rows_per_block(const int cc, const int warp_size) {
-    return warp_size;
+static int get_mmf_rows_per_block(const int cc) {
+    if (GGML_CUDA_CC_IS_CDNA(cc)) {
+        return MMF_ROWS_PER_BLOCK_CDNA;
+    } else {
+        return MMF_ROWS_PER_BLOCK;
+    }
 }
 
 void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) {
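With the old helper the row count per block was simply the device's warp size; it is now one of two compile-time constants shared with mmf.cuh, which is presumably where the compile-time saving comes from: the kernels take rows_per_block as a template parameter, so pinning it to 32 (or 64 on CDNA) limits how many instantiations ever get built. A hypothetical, simplified sketch of the dispatch pattern (launch_mmf and dispatch_mmf are made-up names; the macros are the ones added to mmf.cuh):

    #define MMF_ROWS_PER_BLOCK      32
    #define MMF_ROWS_PER_BLOCK_CDNA 64

    // Stand-in for the real kernel launch: rows_per_block is a template
    // parameter, so every distinct value requested is a separate device-code
    // instantiation the compiler has to build.
    template <int rows_per_block>
    static void launch_mmf() {
        // ... launch mul_mat_f<..., rows_per_block> here ...
    }

    // Only the two supported values (32 and 64) can ever be requested.
    static void dispatch_mmf(const bool is_cdna) {
        if (is_cdna) {
            launch_mmf<MMF_ROWS_PER_BLOCK_CDNA>();
        } else {
            launch_mmf<MMF_ROWS_PER_BLOCK>();
        }
    }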
@@ -97,10 +98,9 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr
 
     const int device = ggml_cuda_get_device();
     const int cc = ggml_cuda_info().devices[device].cc;
-    const int warp_size = ggml_cuda_info().devices[device].warp_size;
-    const int rows_per_block = get_mmf_rows_per_block(cc, warp_size);
+    const int rows_per_block = get_mmf_rows_per_block(cc);
 
-    if (rows_per_block != mmf_rows_per_block && rows_per_block != mmf_rows_per_block_cdna) {
+    if (rows_per_block != MMF_ROWS_PER_BLOCK && rows_per_block != MMF_ROWS_PER_BLOCK_CDNA) {
         GGML_ABORT("unsupported rows_per_block: %i", rows_per_block);
     }
 
@@ -108,13 +108,13 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr
         case GGML_TYPE_F32: {
             const float * src0_d = (const float *) src0->data;
             constexpr int vals_per_T = 1;
-            if (rows_per_block == mmf_rows_per_block) {
-                mul_mat_f_switch_cols_per_block<float, mmf_rows_per_block>(
+            if (rows_per_block == MMF_ROWS_PER_BLOCK) {
+                mul_mat_f_switch_cols_per_block<float, MMF_ROWS_PER_BLOCK>(
                     src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
                     ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
                     ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr);
             } else {
-                mul_mat_f_switch_cols_per_block<float, mmf_rows_per_block_cdna>(
+                mul_mat_f_switch_cols_per_block<float, MMF_ROWS_PER_BLOCK_CDNA>(
                     src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
                     ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
                     ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr);
@@ -123,13 +123,13 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr
         case GGML_TYPE_F16: {
             const half2 * src0_d = (const half2 *) src0->data;
             constexpr int vals_per_T = 2;
-            if (rows_per_block == mmf_rows_per_block) {
-                mul_mat_f_switch_cols_per_block<half2, mmf_rows_per_block>(
+            if (rows_per_block == MMF_ROWS_PER_BLOCK) {
+                mul_mat_f_switch_cols_per_block<half2, MMF_ROWS_PER_BLOCK>(
                     src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
                     ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
                     ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr);
             } else {
-                mul_mat_f_switch_cols_per_block<half2, mmf_rows_per_block_cdna>(
+                mul_mat_f_switch_cols_per_block<half2, MMF_ROWS_PER_BLOCK_CDNA>(
                     src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
                     ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
                     ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr);
@@ -138,13 +138,13 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr
         case GGML_TYPE_BF16: {
            const nv_bfloat162 * src0_d = (const nv_bfloat162 *) src0->data;
             constexpr int vals_per_T = 2;
-            if (rows_per_block == mmf_rows_per_block) {
-                mul_mat_f_switch_cols_per_block<nv_bfloat162, mmf_rows_per_block>(
+            if (rows_per_block == MMF_ROWS_PER_BLOCK) {
+                mul_mat_f_switch_cols_per_block<nv_bfloat162, MMF_ROWS_PER_BLOCK>(
                     src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
                     ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
                     ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr);
             } else {
-                mul_mat_f_switch_cols_per_block<nv_bfloat162, mmf_rows_per_block_cdna>(
+                mul_mat_f_switch_cols_per_block<nv_bfloat162, MMF_ROWS_PER_BLOCK_CDNA>(
                     src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
                     ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
                     ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr);
@@ -176,7 +176,7 @@ bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const
             return false;
         }
     }
-    if (src0_ne[1] % get_mmf_rows_per_block(cc, warp_size) != 0) {
+    if (src0_ne[1] % get_mmf_rows_per_block(cc) != 0) {
         return false;
     }
 
@@ -6,6 +6,9 @@
 
 using namespace ggml_cuda_mma;
 
+#define MMF_ROWS_PER_BLOCK 32
+#define MMF_ROWS_PER_BLOCK_CDNA 64
+
 struct mmf_ids_data {
     const int32_t * ids_src_compact = nullptr;
     const int32_t * ids_dst_compact = nullptr;
@@ -29,21 +32,18 @@ static __global__ void mul_mat_f(
     // TODO: handle this in a consistent and simpler way after AMD MFMA support has been added
-#if (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE)
+#if defined(AMD_WMMA_AVAILABLE)
-    // Special case for tf32, just dummy mma layout as wmma doesn't support it.
-    constexpr bool is_tf32 = std::is_same_v<T, float>;
-    constexpr int tile_B_I = is_tf32 ? 8 : 16;
-    constexpr int tile_C_J = is_tf32 ? 8 : 16;
-    constexpr data_layout ab_layout = is_tf32 ? DATA_LAYOUT_I_MAJOR : get_input_data_layout();
-    typedef tile<16, 8, T, ab_layout> tile_A;
-    typedef tile<tile_B_I, 8, T, ab_layout> tile_B;
-    typedef tile<16, tile_C_J, float, DATA_LAYOUT_J_MAJOR> tile_C;
+    if constexpr (!(std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) || rows_per_block != MMF_ROWS_PER_BLOCK) {NO_DEVICE_CODE;} else {
+    typedef tile<16, 8, T, get_input_data_layout()> tile_A;
+    typedef tile<16, 8, T, get_input_data_layout()> tile_B;
+    typedef tile<16, 16, float, DATA_LAYOUT_J_MAJOR> tile_C;
 #else
 #ifdef VOLTA_MMA_AVAILABLE
-    if constexpr (!std::is_same_v<T, half2>) {NO_DEVICE_CODE;} else {
+    if constexpr (!std::is_same_v<T, half2> || rows_per_block != MMF_ROWS_PER_BLOCK) {NO_DEVICE_CODE;} else {
     typedef tile<32, 4, T, DATA_LAYOUT_I_MAJOR> tile_A;
     typedef tile< 8, 4, T, DATA_LAYOUT_I_MAJOR_MIRRORED> tile_B;
     typedef tile<32, 8, float, DATA_LAYOUT_I_MAJOR> tile_C;
 #else
+    if constexpr (rows_per_block != MMF_ROWS_PER_BLOCK) {NO_DEVICE_CODE;} else {
     typedef tile<16, 8, T> tile_A;
     typedef tile<8, 8, T> tile_B;
     typedef tile<16, 8, float> tile_C;
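The pattern introduced in this hunk, if constexpr (unsupported combination) {NO_DEVICE_CODE;} else { ... }, appears to be the main lever behind the commit title: combinations of T and rows_per_block that a given architecture can never use get only the trivial branch, and the heavy tile/MMA body in the else branch is never instantiated for them. A reduced sketch of the idiom with a made-up kernel and a trivial stand-in for NO_DEVICE_CODE (ggml's real macro flags unreachable device code):

    // Stand-in only; ggml defines NO_DEVICE_CODE itself.
    #define NO_DEVICE_CODE do {} while (0)

    template <int rows_per_block>
    __global__ void toy_kernel(float * dst, const float * src) {
        // Unsupported values are discarded at compile time: the else branch,
        // and everything it would instantiate, is skipped for them, so no
        // device code gets generated.
        if constexpr (rows_per_block != 32 && rows_per_block != 64) {
            NO_DEVICE_CODE;
        } else {
            const int i = blockIdx.x*blockDim.x + threadIdx.x;
            dst[i] = src[i];   // placeholder for the real matrix-multiply body
        }
    }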
@@ -253,9 +253,7 @@ static __global__ void mul_mat_f(
             }
         }
     }
-#ifdef VOLTA_MMA_AVAILABLE
     }
-#endif //VOLTA_MMA_AVAILABLE
 #else
     GGML_UNUSED_VARS(x, y, ids, dst,
         ncols, ncols_dst_total, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
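This hunk is the counterpart of the tile-typedef change at the top of the kernel: before, only the Volta path opened an "if constexpr (...) {NO_DEVICE_CODE;} else {" block, so its closing brace had to sit behind #ifdef VOLTA_MMA_AVAILABLE; now every preprocessor branch opens such a block, and a single unconditional closing brace suffices. A miniature of the resulting structure (branch bodies and the specific macro checks are placeholders):

    template <int rows_per_block>
    __global__ void toy_mul_mat_f() {
    #if defined(AMD_WMMA_AVAILABLE)
        if constexpr (rows_per_block != 32) { /* NO_DEVICE_CODE; */ } else {
    #elif defined(VOLTA_MMA_AVAILABLE)
        if constexpr (rows_per_block != 32) { /* NO_DEVICE_CODE; */ } else {
    #else
        if constexpr (rows_per_block != 32) { /* NO_DEVICE_CODE; */ } else {
    #endif
            // ... shared kernel body ...
        }   // every branch opened the else block, so one unconditional close suffices
    }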
@@ -280,21 +278,18 @@ static __global__ void mul_mat_f_ids(
     // TODO: handle this in a consistent and simpler way after AMD MFMA support has been added
-#if (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE)
+#if defined(AMD_WMMA_AVAILABLE)
-    // Special case for tf32, just dummy mma layout as wmma doesn't support it.
-    constexpr bool is_tf32 = std::is_same_v<T, float>;
-    constexpr int tile_B_I = is_tf32 ? 8 : 16;
-    constexpr int tile_C_J = is_tf32 ? 8 : 16;
-    constexpr data_layout ab_layout = is_tf32 ? DATA_LAYOUT_I_MAJOR : get_input_data_layout();
-    typedef tile<16, 8, T, ab_layout> tile_A;
-    typedef tile<tile_B_I, 8, T, ab_layout> tile_B;
-    typedef tile<16, tile_C_J, float, DATA_LAYOUT_J_MAJOR> tile_C;
+    if constexpr (!(std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) || rows_per_block != MMF_ROWS_PER_BLOCK) {NO_DEVICE_CODE;} else {
+    typedef tile<16, 8, T, get_input_data_layout()> tile_A;
+    typedef tile<16, 8, T, get_input_data_layout()> tile_B;
+    typedef tile<16, 16, float, DATA_LAYOUT_J_MAJOR> tile_C;
 #else
 #ifdef VOLTA_MMA_AVAILABLE
-    if constexpr (!std::is_same_v<T, half2>) {NO_DEVICE_CODE;} else {
+    if constexpr (!std::is_same_v<T, half2> || rows_per_block != MMF_ROWS_PER_BLOCK) {NO_DEVICE_CODE;} else {
     typedef tile<32, 4, T, DATA_LAYOUT_I_MAJOR> tile_A;
     typedef tile< 8, 4, T, DATA_LAYOUT_I_MAJOR_MIRRORED> tile_B;
     typedef tile<32, 8, float, DATA_LAYOUT_I_MAJOR> tile_C;
 #else
+    if constexpr (rows_per_block != MMF_ROWS_PER_BLOCK) {NO_DEVICE_CODE;} else {
     typedef tile<16, 8, T> tile_A;
     typedef tile<8, 8, T> tile_B;
     typedef tile<16, 8, float> tile_C;
@@ -530,9 +525,7 @@ static __global__ void mul_mat_f_ids(
             }
         }
     }
-#ifdef VOLTA_MMA_AVAILABLE
     }
-#endif // VOLTA_MMA_AVAILABLE
 #else
     GGML_UNUSED_VARS(x, y, ids_src_compact, ids_dst_compact, expert_bounds, dst,
         ncols, ncols_dst_total, nchannels_dst, stride_row, stride_col_y, stride_col_dst,