diff --git a/ggml/src/ggml-cuda/mma.cuh b/ggml/src/ggml-cuda/mma.cuh
index dd45d6c78f..f3f1a07720 100644
--- a/ggml/src/ggml-cuda/mma.cuh
+++ b/ggml/src/ggml-cuda/mma.cuh
@@ -345,11 +345,12 @@ namespace ggml_cuda_mma {
 
         static constexpr __device__ bool supported() {
             if (I == 16 && J == 8) return true;
+            if (I == 16 && J == 16) return true;
             return false;
         }
 
         static __device__ __forceinline__ int get_i(const int l) {
-            if constexpr (I == 16 && J == 8) {
+            if constexpr (supported()) {
                 return threadIdx.x % 16;
             } else {
                 NO_DEVICE_CODE;
@@ -358,7 +359,7 @@ namespace ggml_cuda_mma {
         }
 
         static __device__ __forceinline__ int get_j(const int l) {
-            if constexpr (I == 16 && J == 8) {
+            if constexpr (supported()) {
                 return ne * (threadIdx.x / 16) + l;
             } else {
                 NO_DEVICE_CODE;
@@ -812,6 +813,14 @@ namespace ggml_cuda_mma {
 #endif // TURING_MMA_AVAILABLE
     }
 
+    template <typename T, data_layout dl>
+    static __device__ __forceinline__ void load_ldmatrix(
+            tile<16, 16, T, dl> & t, const T * __restrict__ xs0, const int stride) {
+#if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
+        load_generic(t, xs0, stride);
+#endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
+    }
+
     static __device__ __forceinline__ void load_ldmatrix(
             tile<8, 4, half2, DATA_LAYOUT_I_MAJOR_MIRRORED> & t, const half2 * __restrict__ xs0, const int stride) {
         ggml_cuda_memcpy_1<4*sizeof(half2)>(t.x, xs0 + t.get_i(0)*stride);
@@ -1012,6 +1021,35 @@ namespace ggml_cuda_mma {
 #endif // AMD_MFMA_AVAILABLE
     }
 
+    template <data_layout dl_d, data_layout dl_ab>
+    static __device__ __forceinline__ void mma(
+            tile<16, 16, float, dl_d> & D, const tile<16, 16, float, dl_ab> & A, const tile<16, 16, float, dl_ab> & B) {
+#ifdef AMD_MFMA_AVAILABLE
+        using floatx4_t = __attribute__((ext_vector_type(4))) float;
+        floatx4_t & acc_frag = reinterpret_cast<floatx4_t &>(D.x[0]);
+#if defined(CDNA3)
+#pragma unroll
+        for (int i = 0; i < 4; i += 2) {
+            using floatx2_t = __attribute__((ext_vector_type(2))) float;
+            const floatx2_t & a_frag = reinterpret_cast<const floatx2_t &>(A.x[i]);
+            const floatx2_t & b_frag = reinterpret_cast<const floatx2_t &>(B.x[i]);
+            acc_frag = __builtin_amdgcn_mfma_f32_16x16x8_xf32(a_frag, b_frag, acc_frag, 0, 0, 0);
+        }
+#elif defined(CDNA2) || defined(CDNA1)
+#pragma unroll
+        for (int i = 0; i < 4; ++i) {
+            acc_frag = __builtin_amdgcn_mfma_f32_16x16x4f32(A.x[i], B.x[i], acc_frag, 0, 0, 0);
+        }
+#else
+        GGML_UNUSED_VARS(D, A, B);
+        NO_DEVICE_CODE;
+#endif // defined(CDNA3)
+#else
+        GGML_UNUSED_VARS(D, A, B);
+        NO_DEVICE_CODE;
+#endif // AMD_MFMA_AVAILABLE
+    }
+
     static __device__ __forceinline__ void mma_block_scaled(tile<16, 8, float> & D,
                                                             const tile<16, 8, int> & A,
                                                             const tile<8, 8, int> & B,
@@ -1134,6 +1172,25 @@ namespace ggml_cuda_mma {
 #endif // TURING_MMA_AVAILABLE
     }
 
+    template <data_layout dl_d, data_layout dl_ab>
+    static __device__ __forceinline__ void mma(
+            tile<16, 16, float, dl_d> & D, const tile<16, 16, half2, dl_ab> & A, const tile<16, 16, half2, dl_ab> & B) {
+#if defined(AMD_MFMA_AVAILABLE)
+        using halfx4_t = __attribute__((ext_vector_type(4))) _Float16;
+        using floatx4_t = __attribute__((ext_vector_type(4))) float;
+        floatx4_t & acc_frag = reinterpret_cast<floatx4_t &>(D.x[0]);
+#pragma unroll
+        for (int i = 0; i < 4; i += 2) {
+            const halfx4_t & a_frag = reinterpret_cast<const halfx4_t &>(A.x[i]);
+            const halfx4_t & b_frag = reinterpret_cast<const halfx4_t &>(B.x[i]);
+            acc_frag = __builtin_amdgcn_mfma_f32_16x16x16f16(a_frag, b_frag, acc_frag, 0, 0, 0);
+        }
+#else
+        GGML_UNUSED_VARS(D, A, B);
+        NO_DEVICE_CODE;
+#endif // defined(AMD_MFMA_AVAILABLE)
+    }
+
     template <data_layout dl_d, data_layout dl_ab>
     static __device__ __forceinline__ void mma(
             tile<16, 16, float, dl_d> & D, const tile<16, 8, nv_bfloat162, dl_ab> & A, const tile<16, 8, nv_bfloat162, dl_ab> & B) {
@@ -1182,6 +1239,38 @@ namespace ggml_cuda_mma {
 #endif // defined(AMD_WMMA_AVAILABLE)
     }
 
+    template <data_layout dl_d, data_layout dl_ab>
+    static __device__ __forceinline__ void mma(
+            tile<16, 16, float, dl_d> & D, const tile<16, 16, nv_bfloat162, dl_ab> & A, const tile<16, 16, nv_bfloat162, dl_ab> & B) {
+#if defined(AMD_MFMA_AVAILABLE)
+        using floatx4_t = __attribute__((ext_vector_type(4))) float;
+        floatx4_t & acc_frag = reinterpret_cast<floatx4_t &>(D.x[0]);
+#if defined(CDNA3) || defined(CDNA2)
+        using bf16x4_t = __attribute__((ext_vector_type(4))) __bf16;
+#pragma unroll
+        for (int i = 0; i < 4; i += 2) {
+            const bf16x4_t & a_frag = reinterpret_cast<const bf16x4_t &>(A.x[i]);
+            const bf16x4_t & b_frag = reinterpret_cast<const bf16x4_t &>(B.x[i]);
+            acc_frag = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(a_frag, b_frag, acc_frag, 0, 0, 0);
+        }
+#elif defined(CDNA1)
+#pragma unroll
+        for (int i = 0; i < 4; ++i) {
+            using bf16x2_t = __attribute__((ext_vector_type(2))) __bf16;
+            const bf16x2_t & a_frag = reinterpret_cast<const bf16x2_t &>(A.x[i]);
+            const bf16x2_t & b_frag = reinterpret_cast<const bf16x2_t &>(B.x[i]);
+            acc_frag = __builtin_amdgcn_mfma_f32_16x16x8bf16(a_frag, b_frag, acc_frag, 0, 0, 0);
+        }
+#else
+        GGML_UNUSED_VARS(D, A, B);
+        NO_DEVICE_CODE;
+#endif // defined(CDNA3) || defined(CDNA2)
+#else
+        GGML_UNUSED_VARS(D, A, B);
+        NO_DEVICE_CODE;
+#endif // defined(AMD_MFMA_AVAILABLE)
+    }
+
     template <data_layout dl_d, data_layout dl_ab>
     static __device__ __forceinline__ void mma(
             tile<16, 16, int, dl_d> & D, const tile<16, 8, int, dl_ab> & A, const tile<16, 8, int, dl_ab> & B) {
diff --git a/ggml/src/ggml-cuda/mmf.cuh b/ggml/src/ggml-cuda/mmf.cuh
index 7d6f61cb85..8fc184a1d1 100644
--- a/ggml/src/ggml-cuda/mmf.cuh
+++ b/ggml/src/ggml-cuda/mmf.cuh
@@ -38,8 +38,8 @@ static __global__ void mul_mat_f(
         typedef tile<16, 16, float, DATA_LAYOUT_J_MAJOR> tile_C;
 #elif defined(AMD_MFMA_AVAILABLE)
     if constexpr (rows_per_block != MMF_ROWS_PER_BLOCK_CDNA) {NO_DEVICE_CODE;} else {
-        typedef tile<16, 8, T, DATA_LAYOUT_I_MAJOR> tile_A;
-        typedef tile<16, 8, T, DATA_LAYOUT_I_MAJOR> tile_B;
+        typedef tile<16, 16, T, DATA_LAYOUT_I_MAJOR> tile_A;
+        typedef tile<16, 16, T, DATA_LAYOUT_I_MAJOR> tile_B;
         typedef tile<16, 16, float, DATA_LAYOUT_J_MAJOR> tile_C;
 #else
 #ifdef VOLTA_MMA_AVAILABLE
@@ -289,8 +289,8 @@ static __global__ void mul_mat_f_ids(
         typedef tile<16, 16, float, DATA_LAYOUT_J_MAJOR> tile_C;
 #elif defined(AMD_MFMA_AVAILABLE)
     if constexpr (rows_per_block != MMF_ROWS_PER_BLOCK_CDNA) {NO_DEVICE_CODE;} else {
-        typedef tile<16, 8, T, DATA_LAYOUT_I_MAJOR> tile_A;
-        typedef tile<16, 8, T, DATA_LAYOUT_I_MAJOR> tile_B;
+        typedef tile<16, 16, T, DATA_LAYOUT_I_MAJOR> tile_A;
+        typedef tile<16, 16, T, DATA_LAYOUT_I_MAJOR> tile_B;
         typedef tile<16, 16, float, DATA_LAYOUT_J_MAJOR> tile_C;
 #else
 #ifdef VOLTA_MMA_AVAILABLE
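
Context for reviewers, not part of the patch: the following is a minimal standalone HIP sketch (kernel name, buffers, and launch are illustrative, not from ggml) of the per-lane layout the new 16x16 tiles assume, built around the same __builtin_amdgcn_mfma_f32_16x16x16f16 instruction used by the half2 overload above. As in get_i()/get_j(), lane l of a 64-lane wavefront owns row l % 16 and the four consecutive inner-dimension entries starting at 4 * (l / 16) (ne = 256 elements / 64 lanes = 4).

    // Illustrative sketch only -- compile for a CDNA target, e.g.:
    //   hipcc --offload-arch=gfx90a demo.hip
    #include <hip/hip_runtime.h>

    typedef _Float16 halfx4_t  __attribute__((ext_vector_type(4)));
    typedef float    floatx4_t __attribute__((ext_vector_type(4)));

    // One wavefront computes D = A * B^T for a single pair of 16x16 tiles.
    // A and B are 16x16 row-major with the contracted dimension (K) innermost,
    // i.e. the I-major layout used for tile_A/tile_B in mmf.cuh.
    __global__ void mfma_f16_demo(const _Float16 * A, const _Float16 * B, float * D) {
        const int lane16 = threadIdx.x % 16;       // row of A/B owned by this lane
        const int k0     = 4 * (threadIdx.x / 16); // first of this lane's 4 K entries

        halfx4_t  a_frag, b_frag;
        floatx4_t acc = {0.0f, 0.0f, 0.0f, 0.0f};
    #pragma unroll
        for (int k = 0; k < 4; ++k) {
            a_frag[k] = A[lane16*16 + k0 + k];
            b_frag[k] = B[lane16*16 + k0 + k];
        }
        // A single instruction performs the full 16x16x16 fp16 multiply with
        // fp32 accumulation.
        acc = __builtin_amdgcn_mfma_f32_16x16x16f16(a_frag, b_frag, acc, 0, 0, 0);

        // The accumulator comes back J-major: lane16 now selects a *column* of D
        // and the 4 elements walk rows k0..k0+3, matching the
        // tile<16, 16, float, DATA_LAYOUT_J_MAJOR> used for tile_C.
    #pragma unroll
        for (int k = 0; k < 4; ++k) {
            D[(k0 + k)*16 + lane16] = acc[k];
        }
    }

Launched as mfma_f16_demo<<<1, 64>>>(A, B, D), i.e. one wavefront per tile pair. Presumably this layout is also why mmf.cuh widens tile_A/tile_B from 16x8 to 16x16: each register pair now feeds one full K=16 f16/bf16 MFMA (or two K=8 xf32 steps on CDNA3), while CDNA1, which has no K=16 bf16 builtin, falls back to the four-step bf16x2_t mfma_f32_16x16x8bf16 path.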