f32 mmf
This commit is contained in:
parent
7b43cbc083
commit
cd8a31ceb5
|
|
@ -986,6 +986,43 @@ namespace ggml_cuda_mma {
|
|||
#endif // AMPERE_MMA_AVAILABLE
|
||||
}
|
||||
|
||||
template <data_layout dl_ab, data_layout dl_d>
|
||||
static __device__ __forceinline__ void mma(
|
||||
tile<16, 16, float, dl_d> & D, const tile<16, 8, float, dl_ab> & A, const tile<16, 8, float, dl_ab> & B) {
|
||||
#ifdef AMD_MFMA_AVAILABLE
|
||||
using floatx4_t = __attribute__((ext_vector_type(4))) float;
|
||||
floatx4_t& acc_frag = reinterpret_cast<floatx4_t&>(D.x[0]);
|
||||
#if defined(CDNA3)
|
||||
#if 0
|
||||
using floatx2_t = __attribute__((ext_vector_type(2))) float;
|
||||
const floatx2_t& a_frag = reinterpret_cast<const floatx2_t&>(A.x[0]);
|
||||
const floatx2_t& b_frag = reinterpret_cast<const floatx2_t&>(B.x[0]);
|
||||
acc_frag = __builtin_amdgcn_mfma_f32_16x16x8_xf32(a_frag, b_frag, acc_frag, 0, 0, 0);
|
||||
#else
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 2; ++i) {
|
||||
const float& a_frag = reinterpret_cast<const float&>(A.x[i]);
|
||||
const float& b_frag = reinterpret_cast<const float&>(B.x[i]);
|
||||
acc_frag = __builtin_amdgcn_mfma_f32_16x16x4f32(a_frag, b_frag, acc_frag, 0, 0, 0);
|
||||
}
|
||||
#endif
|
||||
#elif defined(CDNA2) || defined(CDNA1)
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 2; ++i) {
|
||||
const float& a_frag = reinterpret_cast<const float&>(A.x[i]);
|
||||
const float& b_frag = reinterpret_cast<const float&>(B.x[i]);
|
||||
acc_frag = __builtin_amdgcn_mfma_f32_16x16x4f32(a_frag, b_frag, acc_frag, 0, 0, 0);
|
||||
}
|
||||
#else
|
||||
GGML_UNUSED_VARS(D, A, B);
|
||||
NO_DEVICE_CODE;
|
||||
#endif // defined(CDNA3)
|
||||
#else
|
||||
GGML_UNUSED_VARS(D, A, B);
|
||||
NO_DEVICE_CODE;
|
||||
#endif // AMD_MFMA_AVAILABLE
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ void mma_block_scaled(tile<16, 8, float> & D,
|
||||
const tile<16, 8, int> & A,
|
||||
const tile<8, 8, int> & B,
|
||||
|
|
|
|||
|
|
@ -196,7 +196,7 @@ bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const
|
|||
|
||||
switch (type) {
|
||||
case GGML_TYPE_F32:
|
||||
return ampere_mma_available(cc);
|
||||
return ampere_mma_available(cc) || amd_mfma_available(cc);
|
||||
case GGML_TYPE_F16:
|
||||
return volta_mma_available(cc) || turing_mma_available(cc) || amd_wmma_available(cc) || amd_mfma_available(cc);
|
||||
case GGML_TYPE_BF16:
|
||||
|
|
|
|||
|
|
@ -37,7 +37,7 @@ static __global__ void mul_mat_f(
|
|||
typedef tile<16, 8, T, get_input_data_layout()> tile_B;
|
||||
typedef tile<16, 16, float, DATA_LAYOUT_J_MAJOR> tile_C;
|
||||
#elif defined(AMD_MFMA_AVAILABLE)
|
||||
if constexpr (!(std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) || rows_per_block != MMF_ROWS_PER_BLOCK_CDNA) {NO_DEVICE_CODE;} else {
|
||||
if constexpr (rows_per_block != MMF_ROWS_PER_BLOCK_CDNA) {NO_DEVICE_CODE;} else {
|
||||
typedef tile<16, 8, T, DATA_LAYOUT_I_MAJOR> tile_A;
|
||||
typedef tile<16, 8, T, DATA_LAYOUT_I_MAJOR> tile_B;
|
||||
typedef tile<16, 16, float, DATA_LAYOUT_J_MAJOR> tile_C;
|
||||
|
|
|
|||
Loading…
Reference in New Issue