HIP: RDNA4 tensor core support for MMF (#17077)
* mmf for rdna4 * align the padding for rdna4 * forbit mul_mat_f for rdna4 * fix as comment * remove device kernels * add constexpr for early return * update based on review comment * change based on the review comment * pass compile error * keep code consistency --------- Co-authored-by: zhang hui <you@example.com>
This commit is contained in:
parent
8e9ddba610
commit
028f93ef98
|
|
@ -224,6 +224,10 @@ static const char * cu_get_error_str(CUresult err) {
|
||||||
#define AMD_MFMA_AVAILABLE
|
#define AMD_MFMA_AVAILABLE
|
||||||
#endif // defined(GGML_USE_HIP) && defined(CDNA) && !defined(GGML_HIP_NO_MMQ_MFMA)
|
#endif // defined(GGML_USE_HIP) && defined(CDNA) && !defined(GGML_HIP_NO_MMQ_MFMA)
|
||||||
|
|
||||||
|
#if defined(GGML_USE_HIP) && defined(RDNA4)
|
||||||
|
#define AMD_WMMA_AVAILABLE
|
||||||
|
#endif // defined(GGML_USE_HIP) && defined(RDNA4)
|
||||||
|
|
||||||
// The Volta instructions are in principle available on Turing or newer but they are effectively unusable:
|
// The Volta instructions are in principle available on Turing or newer but they are effectively unusable:
|
||||||
#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
|
#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
|
||||||
#define VOLTA_MMA_AVAILABLE
|
#define VOLTA_MMA_AVAILABLE
|
||||||
|
|
@ -283,6 +287,10 @@ static bool amd_mfma_available(const int cc) {
|
||||||
#endif //!defined(GGML_HIP_NO_MMQ_MFMA)
|
#endif //!defined(GGML_HIP_NO_MMQ_MFMA)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static bool amd_wmma_available(const int cc) {
|
||||||
|
return GGML_CUDA_CC_IS_RDNA4(cc);
|
||||||
|
}
|
||||||
|
|
||||||
static bool volta_mma_available(const int cc) {
|
static bool volta_mma_available(const int cc) {
|
||||||
return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_VOLTA;
|
return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_VOLTA;
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -39,6 +39,15 @@ template<typename dst_t, typename src_t>
|
||||||
return __float2bfloat16(float(x));
|
return __float2bfloat16(float(x));
|
||||||
} else if constexpr(std::is_same_v<src_t, nv_bfloat16>) {
|
} else if constexpr(std::is_same_v<src_t, nv_bfloat16>) {
|
||||||
return __bfloat162float(x);
|
return __bfloat162float(x);
|
||||||
|
} else if constexpr(std::is_same_v<src_t, float2> && std::is_same_v<dst_t, half2>) {
|
||||||
|
return __float22half2_rn(x);
|
||||||
|
} else if constexpr(std::is_same_v<src_t, float2> && std::is_same_v<dst_t, nv_bfloat162>) {
|
||||||
|
// bypass compile error on cuda 12.0.1
|
||||||
|
#ifdef GGML_USE_HIP
|
||||||
|
return __float22bfloat162_rn(x);
|
||||||
|
#else
|
||||||
|
return {x.x, x.y};
|
||||||
|
#endif // GGML_USE_HIP
|
||||||
} else if constexpr(std::is_same_v<dst_t, int32_t>) {
|
} else if constexpr(std::is_same_v<dst_t, int32_t>) {
|
||||||
return int32_t(x);
|
return int32_t(x);
|
||||||
} else {
|
} else {
|
||||||
|
|
|
||||||
|
|
@ -74,6 +74,33 @@ namespace ggml_cuda_mma {
|
||||||
static constexpr int J = J_;
|
static constexpr int J = J_;
|
||||||
|
|
||||||
#if defined(GGML_USE_HIP)
|
#if defined(GGML_USE_HIP)
|
||||||
|
#if defined(RDNA4)
|
||||||
|
static constexpr int ne = I * J / 32;
|
||||||
|
T x[ne] = {0};
|
||||||
|
|
||||||
|
static constexpr __device__ bool supported() {
|
||||||
|
if (I == 16 && J == 16) return true;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
static __device__ __forceinline__ int get_i(const int l) {
|
||||||
|
if constexpr (I == 16 && J == 16) {
|
||||||
|
return 8 * (threadIdx.x / 16) + l;
|
||||||
|
} else {
|
||||||
|
NO_DEVICE_CODE;
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static __device__ __forceinline__ int get_j(const int l) {
|
||||||
|
if constexpr (I == 16 && J == 16) {
|
||||||
|
return threadIdx.x % 16;
|
||||||
|
} else {
|
||||||
|
NO_DEVICE_CODE;
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#else
|
||||||
static constexpr int ne = I * J / 64;
|
static constexpr int ne = I * J / 64;
|
||||||
T x[ne] = {0};
|
T x[ne] = {0};
|
||||||
|
|
||||||
|
|
@ -119,6 +146,7 @@ namespace ggml_cuda_mma {
|
||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
#endif // defined(RDNA4)
|
||||||
#elif __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
|
#elif __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
|
||||||
static constexpr int ne = I * J / 32;
|
static constexpr int ne = I * J / 32;
|
||||||
T x[ne] = {0};
|
T x[ne] = {0};
|
||||||
|
|
@ -236,6 +264,32 @@ namespace ggml_cuda_mma {
|
||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
#elif defined(AMD_WMMA_AVAILABLE)
|
||||||
|
static constexpr int ne = I * J / 32;
|
||||||
|
half2 x[ne] = {{0.0f, 0.0f}};
|
||||||
|
|
||||||
|
static constexpr __device__ bool supported() {
|
||||||
|
if (I == 16 && J == 8) return true;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
static __device__ __forceinline__ int get_i(const int l) {
|
||||||
|
if constexpr (I == 16 && J == 8) {
|
||||||
|
return threadIdx.x % 16;
|
||||||
|
} else {
|
||||||
|
NO_DEVICE_CODE;
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static __device__ __forceinline__ int get_j(const int l) {
|
||||||
|
if constexpr (I == 16 && J == 8) {
|
||||||
|
return 4 * (threadIdx.x / 16) + l;
|
||||||
|
} else {
|
||||||
|
NO_DEVICE_CODE;
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
}
|
||||||
#else
|
#else
|
||||||
static constexpr int ne = I * J / WARP_SIZE;
|
static constexpr int ne = I * J / WARP_SIZE;
|
||||||
half2 x[ne] = {{0.0f, 0.0f}};
|
half2 x[ne] = {{0.0f, 0.0f}};
|
||||||
|
|
@ -285,6 +339,34 @@ namespace ggml_cuda_mma {
|
||||||
struct tile<I_, J_, nv_bfloat162> {
|
struct tile<I_, J_, nv_bfloat162> {
|
||||||
static constexpr int I = I_;
|
static constexpr int I = I_;
|
||||||
static constexpr int J = J_;
|
static constexpr int J = J_;
|
||||||
|
|
||||||
|
#if defined(AMD_WMMA_AVAILABLE)
|
||||||
|
static constexpr int ne = I * J / 32;
|
||||||
|
nv_bfloat162 x[ne] = {{0.0f, 0.0f}};
|
||||||
|
|
||||||
|
static constexpr __device__ bool supported() {
|
||||||
|
if (I == 16 && J == 8) return true;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
static __device__ __forceinline__ int get_i(const int l) {
|
||||||
|
if constexpr (I == 16 && J == 8) {
|
||||||
|
return threadIdx.x % 16;
|
||||||
|
} else {
|
||||||
|
NO_DEVICE_CODE;
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static __device__ __forceinline__ int get_j(const int l) {
|
||||||
|
if constexpr (I == 16 && J == 8) {
|
||||||
|
return 4 * (threadIdx.x / 16) + l;
|
||||||
|
} else {
|
||||||
|
NO_DEVICE_CODE;
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#else
|
||||||
static constexpr int ne = I * J / WARP_SIZE;
|
static constexpr int ne = I * J / WARP_SIZE;
|
||||||
nv_bfloat162 x[ne] = {{0.0f, 0.0f}};
|
nv_bfloat162 x[ne] = {{0.0f, 0.0f}};
|
||||||
|
|
||||||
|
|
@ -320,6 +402,7 @@ namespace ggml_cuda_mma {
|
||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
#endif // defined(AMD_WMMA_AVAILABLE)
|
||||||
};
|
};
|
||||||
|
|
||||||
template <int I, int J>
|
template <int I, int J>
|
||||||
|
|
@ -353,6 +436,8 @@ namespace ggml_cuda_mma {
|
||||||
const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 2 * (threadIdx.x / t.I));
|
const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 2 * (threadIdx.x / t.I));
|
||||||
xi[0] = xs[0];
|
xi[0] = xs[0];
|
||||||
}
|
}
|
||||||
|
#elif defined(AMD_WMMA_AVAILABLE)
|
||||||
|
ggml_cuda_memcpy_1<sizeof(t.x)>(t.x, xs0 + t.get_i(0) * stride + t.get_j(0));
|
||||||
#else
|
#else
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int l = 0; l < t.ne; ++l) {
|
for (int l = 0; l < t.ne; ++l) {
|
||||||
|
|
@ -639,12 +724,34 @@ namespace ggml_cuda_mma {
|
||||||
: "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
|
: "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
|
||||||
: "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[3]));
|
: "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[3]));
|
||||||
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
|
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
|
||||||
|
#elif defined(AMD_WMMA_AVAILABLE)
|
||||||
|
using halfx8_t = __attribute__((ext_vector_type(8))) _Float16;
|
||||||
|
using floatx8_t = __attribute__((ext_vector_type(8))) float;
|
||||||
|
floatx8_t& acc_frag = reinterpret_cast<floatx8_t&>(D.x[0]);
|
||||||
|
const halfx8_t& a_frag = reinterpret_cast<const halfx8_t&>(A.x[0]);
|
||||||
|
const halfx8_t& b_frag = reinterpret_cast<const halfx8_t&>(B.x[0]);
|
||||||
|
acc_frag = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12(a_frag, b_frag, acc_frag);
|
||||||
#else
|
#else
|
||||||
GGML_UNUSED_VARS(D, A, B);
|
GGML_UNUSED_VARS(D, A, B);
|
||||||
NO_DEVICE_CODE;
|
NO_DEVICE_CODE;
|
||||||
#endif // TURING_MMA_AVAILABLE
|
#endif // TURING_MMA_AVAILABLE
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static __device__ __forceinline__ void mma(
|
||||||
|
tile<16, 16, float> & D, const tile<16, 8, nv_bfloat162> & A, const tile<16, 8, nv_bfloat162> & B) {
|
||||||
|
#if defined(AMD_WMMA_AVAILABLE)
|
||||||
|
using bf16x8_t = __attribute__((ext_vector_type(8))) __bf16;
|
||||||
|
using floatx8_t = __attribute__((ext_vector_type(8))) float;
|
||||||
|
floatx8_t& acc_frag = reinterpret_cast<floatx8_t&>(D.x[0]);
|
||||||
|
const bf16x8_t& a_frag = reinterpret_cast<const bf16x8_t&>(A.x[0]);
|
||||||
|
const bf16x8_t& b_frag = reinterpret_cast<const bf16x8_t&>(B.x[0]);
|
||||||
|
acc_frag = __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32_gfx12(a_frag, b_frag, acc_frag);
|
||||||
|
#else
|
||||||
|
GGML_UNUSED_VARS(D, A, B);
|
||||||
|
NO_DEVICE_CODE;
|
||||||
|
#endif // AMPERE_MMA_AVAILABLE
|
||||||
|
}
|
||||||
|
|
||||||
static __device__ __forceinline__ void mma(
|
static __device__ __forceinline__ void mma(
|
||||||
tile<16, 16, int> & D, const tile<16, 8, int> & A, const tile<16, 8, int> & B) {
|
tile<16, 16, int> & D, const tile<16, 8, int> & A, const tile<16, 8, int> & B) {
|
||||||
#if defined(AMD_MFMA_AVAILABLE)
|
#if defined(AMD_MFMA_AVAILABLE)
|
||||||
|
|
|
||||||
|
|
@ -151,7 +151,7 @@ bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if (src1_ncols > 16) {
|
if (src1_ncols > 16 || GGML_CUDA_CC_IS_RDNA4(cc)) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -160,9 +160,9 @@ bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const
|
||||||
case GGML_TYPE_F32:
|
case GGML_TYPE_F32:
|
||||||
return ampere_mma_available(cc);
|
return ampere_mma_available(cc);
|
||||||
case GGML_TYPE_F16:
|
case GGML_TYPE_F16:
|
||||||
return volta_mma_available(cc) || turing_mma_available(cc);
|
return volta_mma_available(cc) || turing_mma_available(cc) || amd_wmma_available(cc);
|
||||||
case GGML_TYPE_BF16:
|
case GGML_TYPE_BF16:
|
||||||
return ampere_mma_available(cc);
|
return ampere_mma_available(cc) || amd_wmma_available(cc);
|
||||||
default:
|
default:
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -2,6 +2,7 @@
|
||||||
|
|
||||||
#include "mma.cuh"
|
#include "mma.cuh"
|
||||||
#include "common.cuh"
|
#include "common.cuh"
|
||||||
|
#include "convert.cuh"
|
||||||
|
|
||||||
using namespace ggml_cuda_mma;
|
using namespace ggml_cuda_mma;
|
||||||
|
|
||||||
|
|
@ -27,20 +28,35 @@ static __global__ void mul_mat_f(
|
||||||
const int stride_col_id, const int stride_row_id,
|
const int stride_col_id, const int stride_row_id,
|
||||||
const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
|
const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
|
||||||
const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
|
const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
|
||||||
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
|
// TODO: handle this in a consistent and simpler way after AMD MFMA support has been added
|
||||||
|
#if (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE)
|
||||||
|
#if defined(AMD_WMMA_AVAILABLE)
|
||||||
|
// Special case for tf32, just dummy mma layout as wmma doesn't support it.
|
||||||
|
constexpr int tile_B_I = std::is_same_v<T, float> ? 8 : 16;
|
||||||
|
constexpr int tile_C_J = std::is_same_v<T, float> ? 8 : 16;
|
||||||
|
typedef tile<16, 8, T> tile_A;
|
||||||
|
typedef tile<tile_B_I, 8, T> tile_B;
|
||||||
|
typedef tile<16, tile_C_J, float> tile_C;
|
||||||
|
|
||||||
|
constexpr bool a_supported = tile_A::supported();
|
||||||
|
constexpr bool b_supported = tile_B::supported();
|
||||||
|
constexpr bool c_supported = tile_C::supported();
|
||||||
|
constexpr bool supported = a_supported && b_supported && c_supported;
|
||||||
|
#else
|
||||||
constexpr bool I_16_supported = tile<16, 8, T>::supported() && tile<16, 8, float>::supported();
|
constexpr bool I_16_supported = tile<16, 8, T>::supported() && tile<16, 8, float>::supported();
|
||||||
constexpr bool I_32_supported = tile<32, 8, T>::supported() && tile<32, 8, float>::supported();
|
constexpr bool I_32_supported = tile<32, 8, T>::supported() && tile<32, 8, float>::supported();
|
||||||
|
constexpr bool supported = I_16_supported || I_32_supported;
|
||||||
if (!I_16_supported && !I_32_supported) {
|
|
||||||
NO_DEVICE_CODE;
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
constexpr int I_preferred = I_16_supported ? 16 : 32; // For Turing MMA both work but 16 is ~1% faster.
|
constexpr int I_preferred = I_16_supported ? 16 : 32; // For Turing MMA both work but 16 is ~1% faster.
|
||||||
|
|
||||||
typedef tile<I_preferred, 8, T> tile_A;
|
typedef tile<I_preferred, 8, T> tile_A;
|
||||||
typedef tile<8, 8, T> tile_B;
|
typedef tile<8, 8, T> tile_B;
|
||||||
typedef tile<I_preferred, 8, float> tile_C;
|
typedef tile<I_preferred, 8, float> tile_C;
|
||||||
|
#endif // defined(AMD_WMMA_AVAILABLE)
|
||||||
|
if constexpr (!supported) {
|
||||||
|
NO_DEVICE_CODE;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
||||||
constexpr int tile_k_padded = warp_size + 4;
|
constexpr int tile_k_padded = warp_size + 4;
|
||||||
|
|
@ -161,11 +177,11 @@ static __global__ void mul_mat_f(
|
||||||
|
|
||||||
if constexpr (!has_ids) {
|
if constexpr (!has_ids) {
|
||||||
const float2 tmp = j < cols_per_block ? y2[j*stride_col_y + col] : make_float2(0.0f, 0.0f);
|
const float2 tmp = j < cols_per_block ? y2[j*stride_col_y + col] : make_float2(0.0f, 0.0f);
|
||||||
tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y};
|
tile_xy[j0*tile_k_padded + threadIdx.x] = ggml_cuda_cast<T>(tmp);
|
||||||
} else {
|
} else {
|
||||||
const bool valid = j < cols_per_block && (col_base + j) < ncols_dst_total && slot_map[j] >= 0;
|
const bool valid = j < cols_per_block && (col_base + j) < ncols_dst_total && slot_map[j] >= 0;
|
||||||
float2 tmp = valid ? *(const float2*) &y[slot_map[j]*stride_channel_y + 2*(j*stride_col_y + col)] : make_float2(0.0f, 0.0f);
|
float2 tmp = valid ? *(const float2*) &y[slot_map[j]*stride_channel_y + 2*(j*stride_col_y + col)] : make_float2(0.0f, 0.0f);
|
||||||
tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y};
|
tile_xy[j0*tile_k_padded + threadIdx.x] = ggml_cuda_cast<T>(tmp);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
|
@ -239,7 +255,7 @@ static __global__ void mul_mat_f(
|
||||||
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||||
NO_DEVICE_CODE;
|
NO_DEVICE_CODE;
|
||||||
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
|
#endif // (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE)
|
||||||
}
|
}
|
||||||
|
|
||||||
//This kernel is for larger batch sizes of mul_mat_id
|
//This kernel is for larger batch sizes of mul_mat_id
|
||||||
|
|
@ -253,20 +269,35 @@ static __global__ void mul_mat_f_ids(
|
||||||
const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
|
const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
|
||||||
const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
|
const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
|
||||||
const uint3 sis1_fd, const uint3 nch_fd) {
|
const uint3 sis1_fd, const uint3 nch_fd) {
|
||||||
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
|
// TODO: handle this in a consistent and simpler way after AMD MFMA support has been added
|
||||||
|
#if (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE)
|
||||||
|
#if defined(AMD_WMMA_AVAILABLE)
|
||||||
|
// Special case for tf32, just dummy mma layout as wmma doesn't support it.
|
||||||
|
constexpr int tile_B_I = std::is_same_v<T, float> ? 8 : 16;
|
||||||
|
constexpr int tile_C_J = std::is_same_v<T, float> ? 8 : 16;
|
||||||
|
typedef tile<16, 8, T> tile_A;
|
||||||
|
typedef tile<tile_B_I, 8, T> tile_B;
|
||||||
|
typedef tile<16, tile_C_J, float> tile_C;
|
||||||
|
|
||||||
|
constexpr bool a_supported = tile_A::supported();
|
||||||
|
constexpr bool b_supported = tile_B::supported();
|
||||||
|
constexpr bool c_supported = tile_C::supported();
|
||||||
|
constexpr bool supported = a_supported && b_supported && c_supported;
|
||||||
|
#else
|
||||||
constexpr bool I_16_supported = tile<16, 8, T>::supported() && tile<16, 8, float>::supported();
|
constexpr bool I_16_supported = tile<16, 8, T>::supported() && tile<16, 8, float>::supported();
|
||||||
constexpr bool I_32_supported = tile<32, 8, T>::supported() && tile<32, 8, float>::supported();
|
constexpr bool I_32_supported = tile<32, 8, T>::supported() && tile<32, 8, float>::supported();
|
||||||
|
constexpr bool supported = I_16_supported || I_32_supported;
|
||||||
|
|
||||||
if (!I_16_supported && !I_32_supported) {
|
constexpr int I_preferred = I_16_supported ? 16 : 32; // For Turing MMA both work but 16 is ~1% faster.
|
||||||
NO_DEVICE_CODE;
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
constexpr int I_preferred = I_16_supported ? 16 : 32; // For Turing MMA both work butr 16 is ~1% faster.
|
|
||||||
|
|
||||||
typedef tile<I_preferred, 8, T> tile_A;
|
typedef tile<I_preferred, 8, T> tile_A;
|
||||||
typedef tile<8, 8, T> tile_B;
|
typedef tile<8, 8, T> tile_B;
|
||||||
typedef tile<I_preferred, 8, float> tile_C;
|
typedef tile<I_preferred, 8, float> tile_C;
|
||||||
|
#endif // defined(AMD_WMMA_AVAILABLE)
|
||||||
|
if constexpr (!supported) {
|
||||||
|
NO_DEVICE_CODE;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
||||||
constexpr int tile_k_padded = warp_size + 4;
|
constexpr int tile_k_padded = warp_size + 4;
|
||||||
|
|
@ -408,7 +439,7 @@ static __global__ void mul_mat_f_ids(
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int j0 = 0; j0 < tile_B::I; ++j0) {
|
for (int j0 = 0; j0 < tile_B::I; ++j0) {
|
||||||
const float2 tmp = vals_buf[curr_buf][j0];
|
const float2 tmp = vals_buf[curr_buf][j0];
|
||||||
tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y};
|
tile_xy[j0*tile_k_padded + threadIdx.x] = ggml_cuda_cast<T>(tmp);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (itB + 1 < ntB) {
|
if (itB + 1 < ntB) {
|
||||||
|
|
@ -492,7 +523,7 @@ static __global__ void mul_mat_f_ids(
|
||||||
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, sis1_fd, nch_fd);
|
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, sis1_fd, nch_fd);
|
||||||
NO_DEVICE_CODE;
|
NO_DEVICE_CODE;
|
||||||
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
|
#endif // (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE)
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename T, int cols_per_block, int nwarps>
|
template<typename T, int cols_per_block, int nwarps>
|
||||||
|
|
@ -554,7 +585,8 @@ void mul_mat_f_cuda(
|
||||||
cudaStream_t stream, const mmf_ids_data * ids_data) {
|
cudaStream_t stream, const mmf_ids_data * ids_data) {
|
||||||
typedef tile<16, 8, T> tile_A_16;
|
typedef tile<16, 8, T> tile_A_16;
|
||||||
typedef tile<32, 8, T> tile_A_32;
|
typedef tile<32, 8, T> tile_A_32;
|
||||||
typedef tile< 8, 8, T> tile_B;
|
typedef tile<16, 8, T> tile_B_16;
|
||||||
|
typedef tile< 8, 8, T> tile_B_8;
|
||||||
|
|
||||||
GGML_ASSERT(ncols_x % 2 == 0);
|
GGML_ASSERT(ncols_x % 2 == 0);
|
||||||
GGML_ASSERT(stride_row % 2 == 0);
|
GGML_ASSERT(stride_row % 2 == 0);
|
||||||
|
|
@ -581,7 +613,8 @@ void mul_mat_f_cuda(
|
||||||
|
|
||||||
constexpr int rows_per_block = MMF_ROWS_PER_BLOCK;
|
constexpr int rows_per_block = MMF_ROWS_PER_BLOCK;
|
||||||
const int nbytes_shared_iter = nwarps_best * (volta_mma_available(cc) ? tile_A_32::I : tile_A_16::I) * (warp_size + 4) * 4;
|
const int nbytes_shared_iter = nwarps_best * (volta_mma_available(cc) ? tile_A_32::I : tile_A_16::I) * (warp_size + 4) * 4;
|
||||||
const int nbytes_shared_combine = GGML_PAD(cols_per_block, tile_B::I) * (nwarps_best*rows_per_block + 4) * 4;
|
const int nbytes_cols_per_block_pad = amd_wmma_available(cc) ? tile_B_16::I : tile_B_8::I;
|
||||||
|
const int nbytes_shared_combine = GGML_PAD(cols_per_block, nbytes_cols_per_block_pad) * (nwarps_best*rows_per_block + 4) * 4;
|
||||||
const int nbytes_shared = std::max(nbytes_shared_iter, nbytes_shared_combine);
|
const int nbytes_shared = std::max(nbytes_shared_iter, nbytes_shared_combine);
|
||||||
const int nbytes_slotmap = ids ? GGML_PAD(cols_per_block, 16) * sizeof(int) : 0;
|
const int nbytes_slotmap = ids ? GGML_PAD(cols_per_block, 16) * sizeof(int) : 0;
|
||||||
const int nbytes_shared_total = nbytes_shared + nbytes_slotmap;
|
const int nbytes_shared_total = nbytes_shared + nbytes_slotmap;
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue