commit 84565fce61
Author: yulo
Date:   2025-12-17 12:06:13 +08:00 (committed by GitHub)
3 changed files with 179 additions and 150 deletions

View File

@@ -76,15 +76,31 @@ namespace ggml_cuda_mma {
     // For the A/C matrices this means I major == row major, J major == column major.
     // For the B matrix this means I major == column major, J major == row major.
     // MIRRORED == Each data value is held exactly once per thread subgroup.
-        DATA_LAYOUT_I_MAJOR          = 0,  // Always used for Turing, Ampere, Ada Lovelace, consumer Blackwell.
-        DATA_LAYOUT_I_MAJOR_MIRRORED = 10,
-        DATA_LAYOUT_J_MAJOR_MIRRORED = 20,
+        DATA_LAYOUT_I_MAJOR          = 0,  // Always used for Turing, Ampere, Ada Lovelace, consumer Blackwell, matrix A&B for RDNA4 and CDNA.
+        DATA_LAYOUT_J_MAJOR          = 10, // Matrix C for CDNA and RDNA4, int and float matrix C for RDNA3.
+        DATA_LAYOUT_I_MAJOR_MIRRORED = 20,
+        DATA_LAYOUT_J_MAJOR_MIRRORED = 30,
+        DATA_LAYOUT_I_MAJOR_DUAL     = 40, // Matrix A&B for RDNA3.
     };
     // Implemented mma combinations are:
     // - (I_MAJOR, I_MAJOR)          -> I_MAJOR
     // - (I_MAJOR, I_MAJOR_MIRRORED) -> I_MAJOR
     // - (I_MAJOR, J_MAJOR_MIRRORED) -> I_MAJOR
+
+    constexpr bool is_i_major(const data_layout dl) {
+        return dl == DATA_LAYOUT_I_MAJOR ||
+               dl == DATA_LAYOUT_I_MAJOR_MIRRORED ||
+               dl == DATA_LAYOUT_I_MAJOR_DUAL;
+    }
+
+    constexpr data_layout get_input_data_layout() {
+#if defined(RDNA3)
+        return DATA_LAYOUT_I_MAJOR_DUAL;
+#else
+        return DATA_LAYOUT_I_MAJOR;
+#endif // defined(RDNA3)
+    }
+
     template <int I_, int J_, typename T, data_layout ds_=DATA_LAYOUT_I_MAJOR>
     struct tile {};
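The layouts only change what the accessors return: generic code addresses a tile exclusively through get_i() and get_j(). A minimal sketch of that contract, using the same addressing expression as load_generic further down (store_tile_sketch is a hypothetical helper, not part of this commit):

    // Sketch only: element l of this thread sits at logical row get_i(l), column get_j(l).
    // The DATA_LAYOUT_J_MAJOR specialization added below is the same tile with the two
    // accessors swapped, so identical generic code covers both orientations.
    template <typename tile_t, typename T>
    static __device__ void store_tile_sketch(const tile_t & t, T * __restrict__ dst, const int stride) {
#pragma unroll
        for (int l = 0; l < tile_t::ne; ++l) {
            dst[t.get_i(l)*stride + t.get_j(l)] = t.x[l];
        }
    }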
@@ -115,9 +131,9 @@ namespace ggml_cuda_mma {
             } else if constexpr (I == 32 && J == 4) {
                 return threadIdx.x % 32;
             } else if constexpr (I == 16 && J == 16) {
-                return 4 * (threadIdx.x / 16) + l;
+                return threadIdx.x % 16;
             } else if constexpr (I == 32 && J == 32) {
-                return 4 * (threadIdx.x / 32) + 8 * (l / 4) + (l % 4);
+                return threadIdx.x % 32;
             } else {
                 NO_DEVICE_CODE;
                 return -1;
@@ -132,9 +148,9 @@ namespace ggml_cuda_mma {
             } else if constexpr (I == 32 && J == 4) {
                 return 2 * (threadIdx.x / 32) + l;
             } else if constexpr (I == 16 && J == 16) {
-                return threadIdx.x % 16;
+                return 4 * (threadIdx.x / 16) + l;
             } else if constexpr (I == 32 && J == 32) {
-                return threadIdx.x % 32;
+                return 4 * (threadIdx.x / 32) + 8 * (l / 4) + (l % 4);
             } else {
                 NO_DEVICE_CODE;
                 return -1;
@@ -171,28 +187,19 @@ namespace ggml_cuda_mma {
             }
         }
 #elif defined(AMD_WMMA_AVAILABLE)
-#if defined(RDNA4)
         static constexpr int ne = I * J / 32;
-#elif defined(RDNA3)
-        static constexpr int ne = (I == 16 && J == 16) ? I * J / 32 : I * J / 16;
-#endif // defined(RDNA4)
         T x[ne] = {0};

         static constexpr __device__ bool supported() {
             if (I == 16 && J == 16) return true;
+            if (I == 16 && J == 8) return true;
+            if (I == 16 && J == 4) return true;
             return false;
         }

         static __device__ __forceinline__ int get_i(const int l) {
-            if constexpr (I == 16 && J == 16) {
-#if defined(RDNA4)
-                return 8 * (threadIdx.x / 16) + l;
-#elif defined(RDNA3)
-                return 2 * l + (threadIdx.x / 16);
-#else
-                NO_DEVICE_CODE;
-                return -1;
-#endif // defined(RDNA4)
+            if constexpr (supported()) {
+                return threadIdx.x % 16;
             } else {
                 NO_DEVICE_CODE;
                 return -1;
@@ -201,7 +208,17 @@ namespace ggml_cuda_mma {
         static __device__ __forceinline__ int get_j(const int l) {
             if constexpr (I == 16 && J == 16) {
-                return threadIdx.x % 16;
+                // matrix C
+#if defined(RDNA3)
+                return 2 * l + (threadIdx.x / 16);
+#else
+                return ne * (threadIdx.x / 16) + l;
+#endif // defined(RDNA3)
+            } else if constexpr (I == 16 && J == 8) {
+                // mmq input for RDNA4
+                return ne * (threadIdx.x / 16) + l;
+            } else if constexpr (I == 16 && J == 4) {
+                return ne * (threadIdx.x / 16) + l;
             } else {
                 NO_DEVICE_CODE;
                 return -1;
@@ -293,12 +310,7 @@ namespace ggml_cuda_mma {
             }
         }
 #elif defined(AMD_WMMA_AVAILABLE)
-#if defined(RDNA3)
-        // RDNA3 has duplicated data as input.
-        static constexpr int ne = I * J / 32 * 2;
-#else
         static constexpr int ne = I * J / 32;
-#endif // defined(RDNA3)
         half2 x[ne] = {{0.0f, 0.0f}};

         static constexpr __device__ bool supported() {
@@ -317,14 +329,7 @@ namespace ggml_cuda_mma {
         static __device__ __forceinline__ int get_j(const int l) {
             if constexpr (I == 16 && J == 8) {
-#if defined(RDNA4)
                 return 4 * (threadIdx.x / 16) + l;
-#elif defined(RDNA3)
-                return l;
-#else
-                NO_DEVICE_CODE;
-                return -1;
-#endif // defined(RDNA4)
             } else {
                 NO_DEVICE_CODE;
                 return -1;
@@ -382,42 +387,19 @@ namespace ggml_cuda_mma {
         static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR;
 #if defined(AMD_WMMA_AVAILABLE)
-#if defined(RDNA3)
-        // RDNA3 has duplicated data as input.
-        static constexpr int ne = I * J / 32 * 2;
-#else
         static constexpr int ne = I * J / 32;
-#endif // defined(RDNA3)
         nv_bfloat162 x[ne] = {{0.0f, 0.0f}};

         static constexpr __device__ bool supported() {
-            if (I == 16 && J == 8) return true;
-            return false;
+            return tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR>::supported();
         }

         static __device__ __forceinline__ int get_i(const int l) {
-            if constexpr (I == 16 && J == 8) {
-                return threadIdx.x % 16;
-            } else {
-                NO_DEVICE_CODE;
-                return -1;
-            }
+            return tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR>::get_i(l);
         }

         static __device__ __forceinline__ int get_j(const int l) {
-            if constexpr (I == 16 && J == 8) {
-#if defined(RDNA4)
-                return 4 * (threadIdx.x / 16) + l;
-#elif defined(RDNA3)
-                return l;
-#else
-                NO_DEVICE_CODE;
-                return -1;
-#endif // defined(RDNA4)
-            } else {
-                NO_DEVICE_CODE;
-                return -1;
-            }
+            return tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR>::get_j(l);
         }
 #else
         static constexpr int ne = I * J / WARP_SIZE;
@@ -458,6 +440,28 @@ namespace ggml_cuda_mma {
 #endif // defined(AMD_WMMA_AVAILABLE)
     };

+    template <int I_, int J_, typename T>
+    struct tile<I_, J_, T, DATA_LAYOUT_J_MAJOR> {
+        static constexpr int I = I_;
+        static constexpr int J = J_;
+        static constexpr data_layout dl = DATA_LAYOUT_J_MAJOR;
+
+        static constexpr int ne = tile<I_, J_, T, DATA_LAYOUT_I_MAJOR>::ne;
+        T x[ne] = {0};
+
+        static constexpr __device__ bool supported() {
+            return tile<I_, J_, T, DATA_LAYOUT_I_MAJOR>::supported();
+        }
+
+        static __device__ __forceinline__ int get_i(const int l) {
+            return tile<I_, J_, T, DATA_LAYOUT_I_MAJOR>::get_j(l);
+        }
+
+        static __device__ __forceinline__ int get_j(const int l) {
+            return tile<I_, J_, T, DATA_LAYOUT_I_MAJOR>::get_i(l);
+        }
+    };
+
     template <int I_, int J_>
     struct tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR_MIRRORED> {
         static constexpr int I = I_;
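For illustration (a reading of the code above, not text from the commit): on RDNA4 the J-major wrapper of the 16x16 int accumulator resolves to

    // tile<16, 16, int, DATA_LAYOUT_J_MAJOR> on RDNA4, values taken from the I-major tile above:
    //   ne       = 16*16/32 = 8
    //   get_i(l) = ne * (threadIdx.x / 16) + l   // the I-major tile's get_j()
    //   get_j(l) = threadIdx.x % 16              // the I-major tile's get_i()
    // i.e. each lane owns 8 consecutive rows of a single column of the C tile.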
@@ -524,6 +528,42 @@ namespace ggml_cuda_mma {
         }
     };

+    template <int I_, int J_, typename T>
+    struct tile<I_, J_, T, DATA_LAYOUT_I_MAJOR_DUAL> {
+        static constexpr int I = I_;
+        static constexpr int J = J_;
+        static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR_DUAL;
+
+        static constexpr int ne = I * J / 32 * 2;
+        T x[ne] = {0};
+
+        static constexpr __device__ bool supported() {
+            if (I == 16 && J == 16) return true;
+            if (I == 16 && J == 8) return true;
+            if (I == 16 && J == 4) return true;
+            return false;
+        }
+
+        static __device__ __forceinline__ int get_i(const int l) {
+            if constexpr (supported()) {
+                return threadIdx.x % 16;
+            } else {
+                NO_DEVICE_CODE;
+                return -1;
+            }
+        }
+
+        static __device__ __forceinline__ int get_j(const int l) {
+            if constexpr (supported()) {
+                return l;
+            } else {
+                NO_DEVICE_CODE;
+                return -1;
+            }
+        }
+    };
+
 #if defined(TURING_MMA_AVAILABLE)
     template <int I, int J>
     static __device__ __forceinline__ tile<I, J/2, half2> get_half2(const tile<I, J, float> & tile_float) {
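A short illustration of the new dual layout, assuming a 16x8 half2 operand tile (annotation, not part of the commit):

    // tile<16, 8, half2, DATA_LAYOUT_I_MAJOR_DUAL>:
    //   ne       = 16*8/32 * 2 = 8
    //   get_i(l) = threadIdx.x % 16   // lanes n and n+16 map to the same row ...
    //   get_j(l) = l                  // ... and to the same 8 contiguous half2 values of it,
    // so both 16-lane halves of the wave32 hold a full copy of the operand, the duplicated
    // A/B input that the removed RDNA3-only code paths produced explicitly.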
@@ -569,55 +609,28 @@ namespace ggml_cuda_mma {
                 t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)];
             }
         } else {
-            int64_t * xi = (int64_t *) t.x;
-            const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 2 * (threadIdx.x / t.I));
-            xi[0] = xs[0];
+            ggml_cuda_memcpy_1<sizeof(t.x)>(t.x, xs0 + t.get_i(0) * stride + t.get_j(0));
         }
 #elif defined(AMD_WMMA_AVAILABLE)
-        if constexpr (std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) {
-#if defined(RDNA4)
-            ggml_cuda_memcpy_1<sizeof(t.x)>(t.x, xs0 + t.get_i(0) * stride + t.get_j(0));
-#elif defined(RDNA3)
-            ggml_cuda_memcpy_1<sizeof(t.x)/2>(t.x, xs0 + t.get_i(0) * stride + t.get_j(0));
-            ggml_cuda_memcpy_1<sizeof(t.x)/2>(t.x + t.ne/2, xs0 + t.get_i(0) * stride + t.get_j(t.ne/2));
-#else
-            NO_DEVICE_CODE;
-#endif // defined(RDNA4)
-        } else if constexpr (std::is_same_v<T, int>) {
-            if constexpr (I == 16 && J == 4) {
-                int64_t * xi = (int64_t *) t.x;
-#if defined(RDNA4)
-                const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 2 * (threadIdx.x / t.I));
-                xi[0] = xs[0];
-#elif defined(RDNA3)
-                static_assert(tile<I,J,T>::ne >= 4, "fragment too small");
-                const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride);
-                xi[0] = xs[0];
-                xi[1] = xs[1];
-#endif // defined(RDNA4)
-            } else if constexpr (I == 16 && J == 8) {
-                int64_t * xi = (int64_t *) t.x;
-#if defined(RDNA4)
-                const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 4 * (threadIdx.x / t.I));
-                xi[0] = xs[0];
-                const int64_t * xs1 = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 4 * (threadIdx.x / t.I) + 2);
-                xi[1] = xs1[0];
-#elif defined(RDNA3)
-                static_assert(tile<I,J,T>::ne >= 8, "fragment too small");
-                const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride);
-                // contiguous four 64-bit chunks per lane for the wider RDNA3 fragment
-                xi[0] = xs[0];
-                xi[1] = xs[1];
-                const int64_t * xs1 = xs + 2;
-                xi[2] = xs1[0];
-                xi[3] = xs1[1];
-#endif // defined(RDNA4)
-            } else {
-                NO_DEVICE_CODE;
-            }
-        } else {
-            NO_DEVICE_CODE;
+        // All wmma layout has contiguous data when i-major.
+        if constexpr (is_i_major(dl)) {
+            // the data must be aligned to 16 bytes when bigger than ggml_cuda_get_max_cpy_bytes()
+            constexpr int aligned_copy_bytes = ggml_cuda_get_max_cpy_bytes();
+            if constexpr (sizeof(t.x) > aligned_copy_bytes) {
+                static_assert(sizeof(t.x) % aligned_copy_bytes == 0, "bad type size");
+                constexpr int aligned_copy_count = sizeof(t.x)/aligned_copy_bytes;
+#pragma unroll
+                for (int i = 0; i < aligned_copy_count; ++i) {
+                    ggml_cuda_memcpy_1<aligned_copy_bytes>(t.x + t.ne/aligned_copy_count*i, xs0 + t.get_i(0) * stride + t.get_j(t.ne/aligned_copy_count*i));
+                }
+            } else {
+                ggml_cuda_memcpy_1<sizeof(t.x)>(t.x, xs0 + t.get_i(0) * stride + t.get_j(0));
+            }
+        } else {
+#pragma unroll
+            for (int l = 0; l < t.ne; ++l) {
+                t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)];
+            }
         }
 #else
 #pragma unroll
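A worked example of the new copy path, assuming ggml_cuda_get_max_cpy_bytes() evaluates to 16 bytes (an assumption; the value is not shown in this hunk):

    // For the RDNA3 dual tile tile<16, 8, half2, DATA_LAYOUT_I_MAJOR_DUAL>:
    //   sizeof(t.x)        = 8 * sizeof(half2) = 32 bytes  (> 16, and 32 % 16 == 0)
    //   aligned_copy_count = 32 / 16 = 2
    // so the loop issues two 16-byte copies per lane, starting at
    //   xs0 + t.get_i(0)*stride + t.get_j(0)   and   xs0 + t.get_i(0)*stride + t.get_j(4),
    // i.e. the two contiguous halves of one row, which matches the pair of
    // ggml_cuda_memcpy_1<sizeof(t.x)/2> calls in the removed RDNA3-only branch.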
@@ -660,9 +673,9 @@ namespace ggml_cuda_mma {
 #endif // TURING_MMA_AVAILABLE
     }

-    template <typename T>
+    template <typename T, data_layout dl>
     static __device__ __forceinline__ void load_ldmatrix(
-            tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) {
+            tile<16, 8, T, dl> & t, const T * __restrict__ xs0, const int stride) {
 #if defined(TURING_MMA_AVAILABLE)
         int * xi = (int * ) t.x;
         const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + (threadIdx.x / t.I) * (t.J / 2);
@@ -832,8 +845,9 @@ namespace ggml_cuda_mma {
 #endif // TURING_MMA_AVAILABLE
     }

+    template <data_layout dl_ab, data_layout dl_d>
     static __device__ __forceinline__ void mma(
-            tile<16, 8, float> & D, const tile<16, 8, float> & A, const tile<8, 8, float> & B) {
+            tile<16, 8, float, dl_d> & D, const tile<16, 8, float, dl_ab> & A, const tile<8, 8, float, dl_ab> & B) {
 #ifdef AMPERE_MMA_AVAILABLE
         const int * Axi = (const int *) A.x;
         const int * Bxi = (const int *) B.x;
@@ -887,8 +901,9 @@ namespace ggml_cuda_mma {
 #endif // AMPERE_MMA_AVAILABLE
     }

+    template <data_layout dl_ab, data_layout dl_d>
     static __device__ __forceinline__ void mma(
-            tile<16, 16, float> & D, const tile<16, 8, half2> & A, const tile<16, 8, half2> & B) {
+            tile<16, 16, float, dl_d> & D, const tile<16, 8, half2, dl_ab> & A, const tile<16, 8, half2, dl_ab> & B) {
 #ifdef TURING_MMA_AVAILABLE
         const int * Axi = (const int *) A.x;
         const int * Bxi = (const int *) B.x;
@@ -940,8 +955,9 @@ namespace ggml_cuda_mma {
 #endif // TURING_MMA_AVAILABLE
     }

+    template <data_layout dl_ab, data_layout dl_d>
     static __device__ __forceinline__ void mma(
-            tile<16, 16, float> & D, const tile<16, 8, nv_bfloat162> & A, const tile<16, 8, nv_bfloat162> & B) {
+            tile<16, 16, float, dl_d> & D, const tile<16, 8, nv_bfloat162, dl_ab> & A, const tile<16, 8, nv_bfloat162, dl_ab> & B) {
 #if defined(AMD_WMMA_AVAILABLE)
 #if defined(RDNA4)
         using bf16x8_t = __attribute__((ext_vector_type(8))) __bf16;
@@ -967,8 +983,9 @@ namespace ggml_cuda_mma {
 #endif // AMPERE_MMA_AVAILABLE
     }

+    template <data_layout dl_d, data_layout dl_ab>
     static __device__ __forceinline__ void mma(
-            tile<16, 16, int> & D, const tile<16, 8, int> & A, const tile<16, 8, int> & B) {
+            tile<16, 16, int, dl_d> & D, const tile<16, 8, int, dl_ab> & A, const tile<16, 8, int, dl_ab> & B) {
 #if defined(AMD_MFMA_AVAILABLE)
         using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int;
         int32x4_t * acc = (int32x4_t *) D.x;
@@ -1122,8 +1139,9 @@ namespace ggml_cuda_mma {
 #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
     }

+    template <data_layout dl_d, data_layout dl_ab>
     static __device__ __forceinline__ void mma(
-            tile<16, 16, int> & D, const tile<16, 4, int> & A, const tile<16, 4, int> & B) {
+            tile<16, 16, int, dl_d> & D, const tile<16, 4, int, dl_ab> & A, const tile<16, 4, int, dl_ab> & B) {
 #if defined(AMD_WMMA_AVAILABLE)
         using int32x8_t = __attribute__((__vector_size__(8 * sizeof(int)))) int;
         int32x8_t * acc = (int32x8_t *) D.x;
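Because the layouts are ordinary template parameters, they are deduced from the argument tiles and existing call sites keep compiling unchanged. A call-site sketch using the RDNA3 mmq tile types set up in the mmq hunks below (mma_call_sketch is hypothetical):

    // Sketch only: dl_d / dl_ab are deduced from the tile types, no explicit template arguments needed.
    static __device__ void mma_call_sketch(
            tile<16, 16, int, DATA_LAYOUT_J_MAJOR> & C,
            const tile<16, 4, int, DATA_LAYOUT_I_MAJOR_DUAL> & A,
            const tile<16, 4, int, DATA_LAYOUT_I_MAJOR_DUAL> & B) {
        mma(C, A, B); // instantiates mma<DATA_LAYOUT_J_MAJOR, DATA_LAYOUT_I_MAJOR_DUAL>(...)
    }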

View File

@@ -32,11 +32,13 @@ static __global__ void mul_mat_f(
 #if (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE)
 #if defined(AMD_WMMA_AVAILABLE)
     // Special case for tf32, just dummy mma layout as wmma doesn't support it.
-    constexpr int tile_B_I = std::is_same_v<T, float> ? 8 : 16;
-    constexpr int tile_C_J = std::is_same_v<T, float> ? 8 : 16;
-    typedef tile<16, 8, T> tile_A;
-    typedef tile<tile_B_I, 8, T> tile_B;
-    typedef tile<16, tile_C_J, float> tile_C;
+    constexpr bool is_tf32 = std::is_same_v<T, float>;
+    constexpr int tile_B_I = is_tf32 ? 8 : 16;
+    constexpr int tile_C_J = is_tf32 ? 8 : 16;
+    constexpr data_layout ab_layout = is_tf32 ? DATA_LAYOUT_I_MAJOR : get_input_data_layout();
+    typedef tile<16, 8, T, ab_layout> tile_A;
+    typedef tile<tile_B_I, 8, T, ab_layout> tile_B;
+    typedef tile<16, tile_C_J, float, DATA_LAYOUT_J_MAJOR> tile_C;
 #else
 #ifdef VOLTA_MMA_AVAILABLE
     if constexpr (!std::is_same_v<T, half2>) {NO_DEVICE_CODE;} else {
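For reference (an annotation, not part of the diff), with T = half2 these typedefs resolve to

    //   RDNA4: tile_A = tile<16, 8, half2, DATA_LAYOUT_I_MAJOR>
    //   RDNA3: tile_A = tile<16, 8, half2, DATA_LAYOUT_I_MAJOR_DUAL>
    //   both:  tile_C = tile<16, 16, float, DATA_LAYOUT_J_MAJOR>
    // while the tf32 case (T = float) falls back to 8-wide dummy tiles with
    // ab_layout = DATA_LAYOUT_I_MAJOR, since the WMMA path has no tf32 support.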
@@ -272,11 +274,13 @@ static __global__ void mul_mat_f_ids(
 #if (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE)
 #if defined(AMD_WMMA_AVAILABLE)
     // Special case for tf32, just dummy mma layout as wmma doesn't support it.
-    constexpr int tile_B_I = std::is_same_v<T, float> ? 8 : 16;
-    constexpr int tile_C_J = std::is_same_v<T, float> ? 8 : 16;
-    typedef tile<16, 8, T> tile_A;
-    typedef tile<tile_B_I, 8, T> tile_B;
-    typedef tile<16, tile_C_J, float> tile_C;
+    constexpr bool is_tf32 = std::is_same_v<T, float>;
+    constexpr int tile_B_I = is_tf32 ? 8 : 16;
+    constexpr int tile_C_J = is_tf32 ? 8 : 16;
+    constexpr data_layout ab_layout = is_tf32 ? DATA_LAYOUT_I_MAJOR : get_input_data_layout();
+    typedef tile<16, 8, T, ab_layout> tile_A;
+    typedef tile<tile_B_I, 8, T, ab_layout> tile_B;
+    typedef tile<16, tile_C_J, float, DATA_LAYOUT_J_MAJOR> tile_C;
 #else
 #ifdef VOLTA_MMA_AVAILABLE
     if constexpr (!std::is_same_v<T, half2>) {NO_DEVICE_CODE;} else {

View File

@@ -797,9 +797,10 @@ template <int mmq_x, int mmq_y, mmq_q8_1_ds_layout ds_layout>
 static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
     const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
 #if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
-    typedef tile<16, 8, int> tile_A;
-    typedef tile<16, 8, int> tile_B;
-    typedef tile<16, 16, int> tile_C;
+    constexpr data_layout input_layout = get_input_data_layout();
+    typedef tile<16, 8, int, input_layout> tile_A;
+    typedef tile<16, 8, int, input_layout> tile_B;
+    typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;

     constexpr int granularity = mmq_get_granularity_device(mmq_x);
     constexpr int rows_per_warp = granularity;
@@ -966,9 +967,10 @@ template <int mmq_x, int mmq_y>
 static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
     const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
 #if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
-    typedef tile<16, 8, int> tile_A;
-    typedef tile<16, 8, int> tile_B;
-    typedef tile<16, 16, int> tile_C;
+    constexpr data_layout input_layout = get_input_data_layout();
+    typedef tile<16, 8, int, input_layout> tile_A;
+    typedef tile<16, 8, int, input_layout> tile_B;
+    typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;

     constexpr int granularity = mmq_get_granularity_device(mmq_x);
     constexpr int rows_per_warp = granularity;
@@ -1130,10 +1132,11 @@ template <int mmq_x, int mmq_y>
 static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
     const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
 #if defined(AMD_MFMA_AVAILABLE)
-    typedef tile<16, 8, int> tile_A;
-    typedef tile<16, 8, int> tile_B;
-    typedef tile<16, 16, int> tile_C;
-    typedef tile<64, 2, int> tile_load;
+    constexpr data_layout input_layout = get_input_data_layout();
+    typedef tile<16, 8, int, input_layout> tile_A;
+    typedef tile<16, 8, int, input_layout> tile_B;
+    typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
+    typedef tile<64, 2, int, input_layout> tile_load;

     constexpr int granularity = mmq_get_granularity_device(mmq_x);
     constexpr int rows_per_warp = granularity;
@@ -1179,9 +1182,10 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
         }
     }
 #elif defined(AMD_WMMA_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles
-    typedef tile<16, 4, int> tile_A;
-    typedef tile<16, 4, int> tile_B;
-    typedef tile<16, 16, int> tile_C;
+    constexpr data_layout input_layout = get_input_data_layout();
+    typedef tile<16, 4, int, input_layout> tile_A;
+    typedef tile<16, 4, int, input_layout> tile_B;
+    typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;

     constexpr int granularity = mmq_get_granularity_device(mmq_x);
     constexpr int rows_per_warp = granularity;
@@ -1435,10 +1439,11 @@ template <int mmq_x, int mmq_y>
 static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
     const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
 #if defined(AMD_MFMA_AVAILABLE)
-    typedef tile<16, 8, int> tile_A;
-    typedef tile<16, 8, int> tile_B;
-    typedef tile<16, 16, int> tile_C;
-    typedef tile<64, 2, int> tile_load;
+    constexpr data_layout input_layout = get_input_data_layout();
+    typedef tile<16, 8, int, input_layout> tile_A;
+    typedef tile<16, 8, int, input_layout> tile_B;
+    typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
+    typedef tile<64, 2, int, input_layout> tile_load;

     constexpr int granularity = mmq_get_granularity_device(mmq_x);
     constexpr int rows_per_warp = granularity;
@@ -1501,10 +1506,10 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
         }
     }
 #elif defined(AMD_WMMA_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles
-    typedef tile<16, 4, int> tile_A;
-    typedef tile<16, 4, int> tile_B;
-    typedef tile<16, 16, int> tile_C;
+    constexpr data_layout input_layout = get_input_data_layout();
+    typedef tile<16, 4, int, input_layout> tile_A;
+    typedef tile<16, 4, int, input_layout> tile_B;
+    typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;

     constexpr int granularity = mmq_get_granularity_device(mmq_x);
     constexpr int rows_per_warp = granularity;
@@ -2265,10 +2270,11 @@ template <int mmq_x, int mmq_y>
 static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
     const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
 #if defined(AMD_MFMA_AVAILABLE)
-    typedef tile<16, 8, int> tile_A;
-    typedef tile<16, 8, int> tile_B;
-    typedef tile<16, 16, int> tile_C;
-    typedef tile<64, 2, int> tile_load;
+    constexpr data_layout input_layout = get_input_data_layout();
+    typedef tile<16, 8, int, input_layout> tile_A;
+    typedef tile<16, 8, int, input_layout> tile_B;
+    typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
+    typedef tile<64, 2, int, input_layout> tile_load;

     constexpr int granularity = mmq_get_granularity_device(mmq_x);
     constexpr int rows_per_warp = granularity;
@@ -2316,9 +2322,10 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
         }
     }
 #elif defined(AMD_WMMA_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles
-    typedef tile<16, 4, int> tile_A;
-    typedef tile<16, 4, int> tile_B;
-    typedef tile<16, 16, int> tile_C;
+    constexpr data_layout input_layout = get_input_data_layout();
+    typedef tile<16, 4, int, input_layout> tile_A;
+    typedef tile<16, 4, int, input_layout> tile_B;
+    typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;

     constexpr int granularity = mmq_get_granularity_device(mmq_x);
     constexpr int rows_per_warp = granularity;
@@ -3015,7 +3022,7 @@ static __device__ __forceinline__ void mmq_write_back_mma(
 #if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
     constexpr int tileC_IJ = mmq_get_granularity_device(0);
-    typedef tile<tileC_IJ, tileC_IJ, int> tile_C;
+    typedef tile<tileC_IJ, tileC_IJ, int, DATA_LAYOUT_J_MAJOR> tile_C;
     constexpr int rows_per_warp = granularity;
 #else
     typedef tile<16, 8, int> tile_C;