extend tile k

parent 250ae9aee8
commit 4d2ef3d483
@@ -345,11 +345,12 @@ namespace ggml_cuda_mma {
         static constexpr __device__ bool supported() {
             if (I == 16 && J == 8) return true;
+            if (I == 16 && J == 16) return true;
             return false;
         }

         static __device__ __forceinline__ int get_i(const int l) {
-            if constexpr (I == 16 && J == 8) {
+            if constexpr (supported()) {
                 return threadIdx.x % 16;
             } else {
                 NO_DEVICE_CODE;
@@ -358,7 +359,7 @@ namespace ggml_cuda_mma {
         }

         static __device__ __forceinline__ int get_j(const int l) {
-            if constexpr (I == 16 && J == 8) {
+            if constexpr (supported()) {
                 return ne * (threadIdx.x / 16) + l;
             } else {
                 NO_DEVICE_CODE;
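
Note: these two hunks extend the existing 16-wide thread mapping to the new 16x16 tile: supported() now also accepts J == 16, and get_i()/get_j() branch on supported() instead of hard-coding J == 8. A standalone host sketch of the resulting lane-to-element mapping on a 64-lane CDNA wavefront (illustration only, not part of the patch; ne is the per-lane element count I*J/64):

// mapping_sketch.cpp -- hypothetical, illustration only.
#include <cstdio>

int main() {
    const int I  = 16, J = 16;  // tile shape newly accepted by supported()
    const int ne = I * J / 64;  // elements per lane on a 64-wide wavefront
    for (int lane = 0; lane < 64; ++lane) {
        for (int l = 0; l < ne; ++l) {
            const int i = lane % 16;            // mirrors get_i(l)
            const int j = ne * (lane / 16) + l; // mirrors get_j(l)
            std::printf("lane %2d, l %d -> (i=%2d, j=%2d)\n", lane, l, i, j);
        }
    }
    return 0;
}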
@@ -812,6 +813,14 @@ namespace ggml_cuda_mma {
 #endif // TURING_MMA_AVAILABLE
     }

+    template <typename T, data_layout dl>
+    static __device__ __forceinline__ void load_ldmatrix(
+            tile<16, 16, T, dl> & t, const T * __restrict__ xs0, const int stride) {
+#if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
+        load_generic(t, xs0, stride);
+#endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
+    }
+
     static __device__ __forceinline__ void load_ldmatrix(
             tile<8, 4, half2, DATA_LAYOUT_I_MAJOR_MIRRORED> & t, const half2 * __restrict__ xs0, const int stride) {
         ggml_cuda_memcpy_1<4*sizeof(half2)>(t.x, xs0 + t.get_i(0)*stride);
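
Note: AMD has no ldmatrix instruction, so the new 16x16 load_ldmatrix overload defers to load_generic. A host-side model of what such a fallback amounts to (a paraphrase for illustration, not the verbatim ggml implementation): each lane gathers its own ne elements through the tile's index helpers.

// load_generic_model.cpp -- hypothetical host model, illustration only.
struct TileModel {
    static const int I = 16, J = 16, ne = I * J / 64;
    int   lane;                 // wavefront lane id (threadIdx.x on device)
    float x[ne];                // per-lane fragment
    int get_i(int /*l*/) const { return lane % 16; }
    int get_j(int l)     const { return ne * (lane / 16) + l; }
};

void load_generic_model(TileModel & t, const float * xs0, int stride) {
    for (int l = 0; l < TileModel::ne; ++l) {
        t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)];
    }
}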
@@ -1012,6 +1021,35 @@ namespace ggml_cuda_mma {
 #endif // AMD_MFMA_AVAILABLE
     }

+    template <data_layout dl_ab, data_layout dl_d>
+    static __device__ __forceinline__ void mma(
+            tile<16, 16, float, dl_d> & D, const tile<16, 16, float, dl_ab> & A, const tile<16, 16, float, dl_ab> & B) {
+#ifdef AMD_MFMA_AVAILABLE
+        using floatx4_t = __attribute__((ext_vector_type(4))) float;
+        floatx4_t& acc_frag = reinterpret_cast<floatx4_t&>(D.x[0]);
+#if defined(CDNA3)
+#pragma unroll
+        for (int i = 0; i < 4; i += 2) {
+            using floatx2_t = __attribute__((ext_vector_type(2))) float;
+            const floatx2_t& a_frag = reinterpret_cast<const floatx2_t&>(A.x[i]);
+            const floatx2_t& b_frag = reinterpret_cast<const floatx2_t&>(B.x[i]);
+            acc_frag = __builtin_amdgcn_mfma_f32_16x16x8_xf32(a_frag, b_frag, acc_frag, 0, 0, 0);
+        }
+#elif defined(CDNA2) || defined(CDNA1)
+#pragma unroll
+        for (int i = 0; i < 4; ++i) {
+            acc_frag = __builtin_amdgcn_mfma_f32_16x16x4f32(A.x[i], B.x[i], acc_frag, 0, 0, 0);
+        }
+#else
+        GGML_UNUSED_VARS(D, A, B);
+        NO_DEVICE_CODE;
+#endif // defined(CDNA3)
+#else
+        GGML_UNUSED_VARS(D, A, B);
+        NO_DEVICE_CODE;
+#endif // AMD_MFMA_AVAILABLE
+    }
+
     static __device__ __forceinline__ void mma_block_scaled(tile<16, 8, float> & D,
                                                             const tile<16, 8, int> & A,
                                                             const tile<8, 8, int> & B,
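
Note: for fp32 the 16x16 tile carries K = 16, covered as two issues of the K = 8 xf32 MFMA on CDNA3 (floatx2_t fragments, hence i += 2) or four issues of the K = 4 f32 MFMA on CDNA1/2. A scalar reference of what one such mma() call accumulates, assuming both operands are stored I-major with K contiguous per row, i.e. D += A * B^T (hypothetical, illustration only):

// mma_f32_ref.cpp -- hypothetical scalar reference, illustration only.
void mma_f32_ref(float D[16][16], const float A[16][16], const float B[16][16]) {
    for (int m = 0; m < 16; ++m)
        for (int n = 0; n < 16; ++n)
            for (int k = 0; k < 16; ++k)
                D[m][n] += A[m][k] * B[n][k];
}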
@@ -1134,6 +1172,25 @@ namespace ggml_cuda_mma {
 #endif // TURING_MMA_AVAILABLE
     }

+    template <data_layout dl_ab, data_layout dl_d>
+    static __device__ __forceinline__ void mma(
+            tile<16, 16, float, dl_d> & D, const tile<16, 16, half2, dl_ab> & A, const tile<16, 16, half2, dl_ab> & B) {
+#if defined(AMD_MFMA_AVAILABLE)
+        using halfx4_t = __attribute__((ext_vector_type(4))) _Float16;
+        using floatx4_t = __attribute__((ext_vector_type(4))) float;
+        floatx4_t& acc_frag = reinterpret_cast<floatx4_t&>(D.x[0]);
+#pragma unroll
+        for (int i = 0; i < 4; i += 2) {
+            const halfx4_t& a_frag = reinterpret_cast<const halfx4_t&>(A.x[i]);
+            const halfx4_t& b_frag = reinterpret_cast<const halfx4_t&>(B.x[i]);
+            acc_frag = __builtin_amdgcn_mfma_f32_16x16x16f16(a_frag, b_frag, acc_frag, 0, 0, 0);
+        }
+#else
+        GGML_UNUSED_VARS(D, A, B);
+        NO_DEVICE_CODE;
+#endif // defined(AMD_MFMA_AVAILABLE)
+    }
+
     template <data_layout dl_ab, data_layout dl_d>
     static __device__ __forceinline__ void mma(
             tile<16, 16, float, dl_d> & D, const tile<16, 8, nv_bfloat162, dl_ab> & A, const tile<16, 8, nv_bfloat162, dl_ab> & B) {
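
Note: the f16 loop steps the fragment index by two because each MFMA issue consumes one halfx4_t, i.e. two adjacent half2 values reinterpreted as four _Float16. A standalone host check of that layout assumption, with plain structs standing in for the HIP types (hypothetical, illustration only):

// layout_check.cpp -- hypothetical, illustration only.
#include <cassert>
#include <cstring>

struct half_bits  { unsigned short v; };          // stand-in for _Float16
struct half2_bits { half_bits x, y; };            // stand-in for half2

int main() {
    half2_bits src[2] = {{{1}, {2}}, {{3}, {4}}}; // A.x[i], A.x[i + 1]
    half_bits  dst[4];                            // one halfx4_t worth
    static_assert(sizeof(src) == sizeof(dst), "two half2 == four halves");
    std::memcpy(dst, src, sizeof(dst));
    assert(dst[0].v == 1 && dst[1].v == 2 && dst[2].v == 3 && dst[3].v == 4);
    return 0;
}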
@@ -1182,6 +1239,38 @@ namespace ggml_cuda_mma {
 #endif // defined(AMD_WMMA_AVAILABLE)
     }

+    template <data_layout dl_ab, data_layout dl_d>
+    static __device__ __forceinline__ void mma(
+            tile<16, 16, float, dl_d> & D, const tile<16, 16, nv_bfloat162, dl_ab> & A, const tile<16, 16, nv_bfloat162, dl_ab> & B) {
+#if defined(AMD_MFMA_AVAILABLE)
+        using floatx4_t = __attribute__((ext_vector_type(4))) float;
+        floatx4_t& acc_frag = reinterpret_cast<floatx4_t&>(D.x[0]);
+#if defined(CDNA3) || defined(CDNA2)
+        using bf16x4_t = __attribute__((ext_vector_type(4))) __bf16;
+#pragma unroll
+        for (int i = 0; i < 4; i += 2) {
+            const bf16x4_t& a_frag = reinterpret_cast<const bf16x4_t&>(A.x[i]);
+            const bf16x4_t& b_frag = reinterpret_cast<const bf16x4_t&>(B.x[i]);
+            acc_frag = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(a_frag, b_frag, acc_frag, 0, 0, 0);
+        }
+#elif defined(CDNA1)
+#pragma unroll
+        for (int i = 0; i < 4; ++i) {
+            using bf16x2_t = __attribute__((ext_vector_type(2))) __bf16;
+            const bf16x2_t& a_frag = reinterpret_cast<const bf16x2_t&>(A.x[i]);
+            const bf16x2_t& b_frag = reinterpret_cast<const bf16x2_t&>(B.x[i]);
+            acc_frag = __builtin_amdgcn_mfma_f32_16x16x8bf16(a_frag, b_frag, acc_frag, 0, 0, 0);
+        }
+#else
+        GGML_UNUSED_VARS(D, A, B);
+        NO_DEVICE_CODE;
+#endif // defined(CDNA3) || defined(CDNA2)
+#else
+        GGML_UNUSED_VARS(D, A, B);
+        NO_DEVICE_CODE;
+#endif // defined(AMD_MFMA_AVAILABLE)
+    }
+
     template <data_layout dl_d, data_layout dl_ab>
     static __device__ __forceinline__ void mma(
             tile<16, 16, int, dl_d> & D, const tile<16, 8, int, dl_ab> & A, const tile<16, 8, int, dl_ab> & B) {
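
Note: for bf16 the available MFMA variant depends on the generation: CDNA2/CDNA3 have the _1k instruction with bf16x4_t fragments, while CDNA1 falls back to the older K = 8 form with bf16x2_t fragments and needs twice as many issues. Since each nv_bfloat162 packs two scalars, the 16x16 bf162 tile carries 32 scalar K elements either way:

arch          builtin                                     K/issue   issues
CDNA2/CDNA3   __builtin_amdgcn_mfma_f32_16x16x16bf16_1k   16        2
CDNA1         __builtin_amdgcn_mfma_f32_16x16x8bf16       8         4

The remaining hunks update the mul_mat_f kernels to use the wider tiles.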
@@ -38,8 +38,8 @@ static __global__ void mul_mat_f(
         typedef tile<16, 16, float, DATA_LAYOUT_J_MAJOR> tile_C;
 #elif defined(AMD_MFMA_AVAILABLE)
     if constexpr (rows_per_block != MMF_ROWS_PER_BLOCK_CDNA) {NO_DEVICE_CODE;} else {
-        typedef tile<16, 8, T, DATA_LAYOUT_I_MAJOR> tile_A;
-        typedef tile<16, 8, T, DATA_LAYOUT_I_MAJOR> tile_B;
+        typedef tile<16, 16, T, DATA_LAYOUT_I_MAJOR> tile_A;
+        typedef tile<16, 16, T, DATA_LAYOUT_I_MAJOR> tile_B;
         typedef tile<16, 16, float, DATA_LAYOUT_J_MAJOR> tile_C;
 #else
 #ifdef VOLTA_MMA_AVAILABLE
@@ -289,8 +289,8 @@ static __global__ void mul_mat_f_ids(
         typedef tile<16, 16, float, DATA_LAYOUT_J_MAJOR> tile_C;
 #elif defined(AMD_MFMA_AVAILABLE)
     if constexpr (rows_per_block != MMF_ROWS_PER_BLOCK_CDNA) {NO_DEVICE_CODE;} else {
-        typedef tile<16, 8, T, DATA_LAYOUT_I_MAJOR> tile_A;
-        typedef tile<16, 8, T, DATA_LAYOUT_I_MAJOR> tile_B;
+        typedef tile<16, 16, T, DATA_LAYOUT_I_MAJOR> tile_A;
+        typedef tile<16, 16, T, DATA_LAYOUT_I_MAJOR> tile_B;
         typedef tile<16, 16, float, DATA_LAYOUT_J_MAJOR> tile_C;
 #else
 #ifdef VOLTA_MMA_AVAILABLE
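
Note: both kernels get the same change: tile_A and tile_B grow from 16x8 to 16x16 T elements, matching the tile shapes added in mma.cuh above, so on CDNA each tile load plus mma() pair should cover twice as many K elements per iteration; tile_C keeps its 16x16 J-major shape.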