This commit is contained in:
zhang hui 2026-01-17 19:40:57 +08:00
parent 7b43cbc083
commit cd8a31ceb5
3 changed files with 39 additions and 2 deletions

View File

@ -986,6 +986,43 @@ namespace ggml_cuda_mma {
#endif // AMPERE_MMA_AVAILABLE
}
template <data_layout dl_ab, data_layout dl_d>
static __device__ __forceinline__ void mma(
tile<16, 16, float, dl_d> & D, const tile<16, 8, float, dl_ab> & A, const tile<16, 8, float, dl_ab> & B) {
#ifdef AMD_MFMA_AVAILABLE
using floatx4_t = __attribute__((ext_vector_type(4))) float;
floatx4_t& acc_frag = reinterpret_cast<floatx4_t&>(D.x[0]);
#if defined(CDNA3)
#if 0
using floatx2_t = __attribute__((ext_vector_type(2))) float;
const floatx2_t& a_frag = reinterpret_cast<const floatx2_t&>(A.x[0]);
const floatx2_t& b_frag = reinterpret_cast<const floatx2_t&>(B.x[0]);
acc_frag = __builtin_amdgcn_mfma_f32_16x16x8_xf32(a_frag, b_frag, acc_frag, 0, 0, 0);
#else
#pragma unroll
for (int i = 0; i < 2; ++i) {
const float& a_frag = reinterpret_cast<const float&>(A.x[i]);
const float& b_frag = reinterpret_cast<const float&>(B.x[i]);
acc_frag = __builtin_amdgcn_mfma_f32_16x16x4f32(a_frag, b_frag, acc_frag, 0, 0, 0);
}
#endif
#elif defined(CDNA2) || defined(CDNA1)
#pragma unroll
for (int i = 0; i < 2; ++i) {
const float& a_frag = reinterpret_cast<const float&>(A.x[i]);
const float& b_frag = reinterpret_cast<const float&>(B.x[i]);
acc_frag = __builtin_amdgcn_mfma_f32_16x16x4f32(a_frag, b_frag, acc_frag, 0, 0, 0);
}
#else
GGML_UNUSED_VARS(D, A, B);
NO_DEVICE_CODE;
#endif // defined(CDNA3)
#else
GGML_UNUSED_VARS(D, A, B);
NO_DEVICE_CODE;
#endif // AMD_MFMA_AVAILABLE
}
static __device__ __forceinline__ void mma_block_scaled(tile<16, 8, float> & D,
const tile<16, 8, int> & A,
const tile<8, 8, int> & B,

View File

@ -196,7 +196,7 @@ bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const
switch (type) {
case GGML_TYPE_F32:
return ampere_mma_available(cc);
return ampere_mma_available(cc) || amd_mfma_available(cc);
case GGML_TYPE_F16:
return volta_mma_available(cc) || turing_mma_available(cc) || amd_wmma_available(cc) || amd_mfma_available(cc);
case GGML_TYPE_BF16:

View File

@ -37,7 +37,7 @@ static __global__ void mul_mat_f(
typedef tile<16, 8, T, get_input_data_layout()> tile_B;
typedef tile<16, 16, float, DATA_LAYOUT_J_MAJOR> tile_C;
#elif defined(AMD_MFMA_AVAILABLE)
if constexpr (!(std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) || rows_per_block != MMF_ROWS_PER_BLOCK_CDNA) {NO_DEVICE_CODE;} else {
if constexpr (rows_per_block != MMF_ROWS_PER_BLOCK_CDNA) {NO_DEVICE_CODE;} else {
typedef tile<16, 8, T, DATA_LAYOUT_I_MAJOR> tile_A;
typedef tile<16, 8, T, DATA_LAYOUT_I_MAJOR> tile_B;
typedef tile<16, 16, float, DATA_LAYOUT_J_MAJOR> tile_C;