diff --git a/ggml/src/ggml-cuda/mma.cuh b/ggml/src/ggml-cuda/mma.cuh
index 323173b90e..cd0d6c3e2a 100644
--- a/ggml/src/ggml-cuda/mma.cuh
+++ b/ggml/src/ggml-cuda/mma.cuh
@@ -986,6 +986,43 @@ namespace ggml_cuda_mma {
 #endif // AMPERE_MMA_AVAILABLE
     }
 
+    template <data_layout dl_d, data_layout dl_ab>
+    static __device__ __forceinline__ void mma(
+            tile<16, 16, float, dl_d> & D, const tile<16, 8, float, dl_ab> & A, const tile<16, 8, float, dl_ab> & B) {
+#ifdef AMD_MFMA_AVAILABLE
+        using floatx4_t = __attribute__((ext_vector_type(4))) float;
+        floatx4_t& acc_frag = reinterpret_cast<floatx4_t&>(D.x[0]);
+#if defined(CDNA3)
+#if 0
+        using floatx2_t = __attribute__((ext_vector_type(2))) float;
+        const floatx2_t& a_frag = reinterpret_cast<const floatx2_t&>(A.x[0]);
+        const floatx2_t& b_frag = reinterpret_cast<const floatx2_t&>(B.x[0]);
+        acc_frag = __builtin_amdgcn_mfma_f32_16x16x8_xf32(a_frag, b_frag, acc_frag, 0, 0, 0);
+#else
+#pragma unroll
+        for (int i = 0; i < 2; ++i) {
+            const float& a_frag = reinterpret_cast<const float&>(A.x[i]);
+            const float& b_frag = reinterpret_cast<const float&>(B.x[i]);
+            acc_frag = __builtin_amdgcn_mfma_f32_16x16x4f32(a_frag, b_frag, acc_frag, 0, 0, 0);
+        }
+#endif
+#elif defined(CDNA2) || defined(CDNA1)
+#pragma unroll
+        for (int i = 0; i < 2; ++i) {
+            const float& a_frag = reinterpret_cast<const float&>(A.x[i]);
+            const float& b_frag = reinterpret_cast<const float&>(B.x[i]);
+            acc_frag = __builtin_amdgcn_mfma_f32_16x16x4f32(a_frag, b_frag, acc_frag, 0, 0, 0);
+        }
+#else
+        GGML_UNUSED_VARS(D, A, B);
+        NO_DEVICE_CODE;
+#endif // defined(CDNA3)
+#else
+        GGML_UNUSED_VARS(D, A, B);
+        NO_DEVICE_CODE;
+#endif // AMD_MFMA_AVAILABLE
+    }
+
     static __device__ __forceinline__ void mma_block_scaled(tile<16, 8, float> & D,
                                                             const tile<16, 8, int> & A,
                                                             const tile<8, 8, int> & B,
diff --git a/ggml/src/ggml-cuda/mmf.cu b/ggml/src/ggml-cuda/mmf.cu
index 4e03d0db08..0ea7ffa350 100644
--- a/ggml/src/ggml-cuda/mmf.cu
+++ b/ggml/src/ggml-cuda/mmf.cu
@@ -196,7 +196,7 @@ bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const
 
     switch (type) {
         case GGML_TYPE_F32:
-            return ampere_mma_available(cc);
+            return ampere_mma_available(cc) || amd_mfma_available(cc);
         case GGML_TYPE_F16:
             return volta_mma_available(cc) || turing_mma_available(cc) || amd_wmma_available(cc) || amd_mfma_available(cc);
         case GGML_TYPE_BF16:
diff --git a/ggml/src/ggml-cuda/mmf.cuh b/ggml/src/ggml-cuda/mmf.cuh
index 43fc4f6b4b..db22502383 100644
--- a/ggml/src/ggml-cuda/mmf.cuh
+++ b/ggml/src/ggml-cuda/mmf.cuh
@@ -37,7 +37,7 @@ static __global__ void mul_mat_f(
     typedef tile<16, 8, T, get_input_data_layout()> tile_B;
     typedef tile<16, 16, float, DATA_LAYOUT_J_MAJOR> tile_C;
 #elif defined(AMD_MFMA_AVAILABLE)
-    if constexpr (!(std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) || rows_per_block != MMF_ROWS_PER_BLOCK_CDNA) {NO_DEVICE_CODE;} else {
+    if constexpr (rows_per_block != MMF_ROWS_PER_BLOCK_CDNA) {NO_DEVICE_CODE;} else {
     typedef tile<16, 8, T, DATA_LAYOUT_I_MAJOR> tile_A;
     typedef tile<16, 8, T, DATA_LAYOUT_I_MAJOR> tile_B;
     typedef tile<16, 16, float, DATA_LAYOUT_J_MAJOR> tile_C;
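
For reference, below is a minimal host-side scalar sketch of the arithmetic the new F32 overload is expected to perform for these tile shapes, assuming the usual ggml_cuda_mma convention that mma(D, A, B) accumulates D += A * B^T with D a 16x16 tile and A, B 16x8 tiles; two chained __builtin_amdgcn_mfma_f32_16x16x4f32 calls then cover the K = 8 depth, four k-values each. The per-lane fragment layout of the hardware tiles is not modeled, and the names used here (mma_ref, I/J/K) are illustrative only, not part of the patch.

#include <cstdio>

// Scalar reference: D(16x16) += A(16x8) * B(16x8)^T, i.e. D[i][j] += sum_k A[i][k] * B[j][k].
constexpr int I = 16, J = 16, K = 8;

static void mma_ref(float D[I][J], const float A[I][K], const float B[J][K]) {
    for (int i = 0; i < I; ++i) {
        for (int j = 0; j < J; ++j) {
            // Each mfma_f32_16x16x4f32 step would cover 4 of these k values;
            // the unrolled loop in the patch chains two such steps to reach K = 8.
            for (int k = 0; k < K; ++k) {
                D[i][j] += A[i][k] * B[j][k];
            }
        }
    }
}

int main() {
    static float A[I][K], B[J][K], D[I][J] = {};
    for (int i = 0; i < I; ++i) for (int k = 0; k < K; ++k) A[i][k] = 0.01f * (i + k);
    for (int j = 0; j < J; ++j) for (int k = 0; k < K; ++k) B[j][k] = 0.02f * (j - k);
    mma_ref(D, A, B);
    printf("D[0][0] = %.6f, D[15][15] = %.6f\n", D[0][0], D[15][15]);
    return 0;
}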