remove i_major_dual (#18157)
Co-authored-by: zhang hui <you@example.com>
This commit is contained in:
parent
9ce64aed7d
commit
54189c0d39
|
|
@ -78,27 +78,25 @@ namespace ggml_cuda_mma {
|
||||||
// MIRRORED == Each data value is held exactly once per thread subgroup.
|
// MIRRORED == Each data value is held exactly once per thread subgroup.
|
||||||
DATA_LAYOUT_I_MAJOR = 0, // Always used for Turing, Ampere, Ada Lovelace, consumer Blackwell, matrix A&B for RDNA4 and CDNA.
|
DATA_LAYOUT_I_MAJOR = 0, // Always used for Turing, Ampere, Ada Lovelace, consumer Blackwell, matrix A&B for RDNA4 and CDNA.
|
||||||
DATA_LAYOUT_J_MAJOR = 10, // Matrix C for CDNA and RDNA4, int and float matrix C for RDNA3.
|
DATA_LAYOUT_J_MAJOR = 10, // Matrix C for CDNA and RDNA4, int and float matrix C for RDNA3.
|
||||||
DATA_LAYOUT_I_MAJOR_MIRRORED = 20,
|
DATA_LAYOUT_I_MAJOR_MIRRORED = 20, // Volta, matrix A&B for RDNA3.
|
||||||
DATA_LAYOUT_J_MAJOR_MIRRORED = 30,
|
DATA_LAYOUT_J_MAJOR_MIRRORED = 30,
|
||||||
DATA_LAYOUT_I_MAJOR_DUAL = 40, // Matrix A&B for RDNA3.
|
|
||||||
};
|
};
|
||||||
// Implemented mma combinations are:
|
// Implemented mma combinations are:
|
||||||
// - (I_MAJOR, I_MAJOR) -> I_MAJOR
|
// - (I_MAJOR, I_MAJOR) -> I_MAJOR
|
||||||
// - (I_MAJOR, I_MAJOR_MIRRORED) -> I_MAJOR
|
// - (I_MAJOR, I_MAJOR_MIRRORED) -> I_MAJOR
|
||||||
// - (I_MAJOR, J_MAJOR_MIRRORED) -> I_MAJOR
|
// - (I_MAJOR, J_MAJOR_MIRRORED) -> I_MAJOR
|
||||||
|
|
||||||
constexpr bool is_i_major(const data_layout dl) {
|
static constexpr bool is_i_major(const data_layout dl) {
|
||||||
return dl == DATA_LAYOUT_I_MAJOR ||
|
return dl == DATA_LAYOUT_I_MAJOR ||
|
||||||
dl == DATA_LAYOUT_I_MAJOR_MIRRORED ||
|
dl == DATA_LAYOUT_I_MAJOR_MIRRORED;
|
||||||
dl == DATA_LAYOUT_I_MAJOR_DUAL;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
constexpr data_layout get_input_data_layout() {
|
static constexpr __device__ data_layout get_input_data_layout() {
|
||||||
#if defined(RDNA3)
|
#if defined(RDNA3) || __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
|
||||||
return DATA_LAYOUT_I_MAJOR_DUAL;
|
return DATA_LAYOUT_I_MAJOR_MIRRORED;
|
||||||
#else
|
#else
|
||||||
return DATA_LAYOUT_I_MAJOR;
|
return DATA_LAYOUT_I_MAJOR;
|
||||||
#endif // defined(RDNA3)
|
#endif // defined(RDNA3) || __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
|
||||||
}
|
}
|
||||||
|
|
||||||
template <int I_, int J_, typename T, data_layout ds_=DATA_LAYOUT_I_MAJOR>
|
template <int I_, int J_, typename T, data_layout ds_=DATA_LAYOUT_I_MAJOR>
|
||||||
|
|
@ -462,11 +460,65 @@ namespace ggml_cuda_mma {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
template <int I_, int J_, typename T>
|
||||||
|
struct tile<I_, J_, T, DATA_LAYOUT_I_MAJOR_MIRRORED> {
|
||||||
|
static constexpr int I = I_;
|
||||||
|
static constexpr int J = J_;
|
||||||
|
static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR_MIRRORED;
|
||||||
|
|
||||||
|
// RDNA3
|
||||||
|
static constexpr int ne = I * J / 32 * 2;
|
||||||
|
|
||||||
|
T x[ne] = {0};
|
||||||
|
|
||||||
|
static constexpr __device__ bool supported() {
|
||||||
|
if (I == 16 && J == 16) return true;
|
||||||
|
if (I == 16 && J == 8) return true;
|
||||||
|
if (I == 16 && J == 4) return true;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
static __device__ __forceinline__ int get_i(const int /*l*/) {
|
||||||
|
if constexpr (supported()) {
|
||||||
|
return threadIdx.x % 16;
|
||||||
|
} else {
|
||||||
|
NO_DEVICE_CODE;
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static __device__ __forceinline__ int get_j(const int l) {
|
||||||
|
if constexpr (supported()) {
|
||||||
|
return l;
|
||||||
|
} else {
|
||||||
|
NO_DEVICE_CODE;
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
template <int I_, int J_>
|
template <int I_, int J_>
|
||||||
struct tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR_MIRRORED> {
|
struct tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR_MIRRORED> {
|
||||||
static constexpr int I = I_;
|
static constexpr int I = I_;
|
||||||
static constexpr int J = J_;
|
static constexpr int J = J_;
|
||||||
static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR_MIRRORED;
|
static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR_MIRRORED;
|
||||||
|
#if defined(RDNA3)
|
||||||
|
static constexpr int ne = tile<I_, J_, float, DATA_LAYOUT_I_MAJOR_MIRRORED>::ne;
|
||||||
|
|
||||||
|
half2 x[ne] = {{0.0f, 0.0f}};
|
||||||
|
|
||||||
|
static constexpr __device__ bool supported() {
|
||||||
|
return tile<I_, J_, float, DATA_LAYOUT_I_MAJOR_MIRRORED>::supported();
|
||||||
|
}
|
||||||
|
|
||||||
|
static __device__ __forceinline__ int get_i(const int l) {
|
||||||
|
return tile<I_, J_, float, DATA_LAYOUT_I_MAJOR_MIRRORED>::get_i(l);
|
||||||
|
}
|
||||||
|
|
||||||
|
static __device__ __forceinline__ int get_j(const int l) {
|
||||||
|
return tile<I_, J_, float, DATA_LAYOUT_I_MAJOR_MIRRORED>::get_j(l);
|
||||||
|
}
|
||||||
|
#else // Volta
|
||||||
static constexpr int ne = I * J / (WARP_SIZE/4);
|
static constexpr int ne = I * J / (WARP_SIZE/4);
|
||||||
|
|
||||||
half2 x[ne] = {{0.0f, 0.0f}};
|
half2 x[ne] = {{0.0f, 0.0f}};
|
||||||
|
|
@ -493,6 +545,29 @@ namespace ggml_cuda_mma {
|
||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
#endif // defined(RDNA3)
|
||||||
|
};
|
||||||
|
|
||||||
|
template <int I_, int J_>
|
||||||
|
struct tile<I_, J_, nv_bfloat162, DATA_LAYOUT_I_MAJOR_MIRRORED> {
|
||||||
|
static constexpr int I = I_;
|
||||||
|
static constexpr int J = J_;
|
||||||
|
static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR_MIRRORED;
|
||||||
|
static constexpr int ne = tile<I_, J_, float, DATA_LAYOUT_I_MAJOR_MIRRORED>::ne;
|
||||||
|
|
||||||
|
nv_bfloat162 x[ne] = {{0.0f, 0.0f}};
|
||||||
|
|
||||||
|
static constexpr __device__ bool supported() {
|
||||||
|
return tile<I_, J_, float, DATA_LAYOUT_I_MAJOR_MIRRORED>::supported();
|
||||||
|
}
|
||||||
|
|
||||||
|
static __device__ __forceinline__ int get_i(const int l) {
|
||||||
|
return tile<I_, J_, float, DATA_LAYOUT_I_MAJOR_MIRRORED>::get_i(l);
|
||||||
|
}
|
||||||
|
|
||||||
|
static __device__ __forceinline__ int get_j(const int l) {
|
||||||
|
return tile<I_, J_, float, DATA_LAYOUT_I_MAJOR_MIRRORED>::get_j(l);
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
template <int I_, int J_>
|
template <int I_, int J_>
|
||||||
|
|
@ -528,42 +603,6 @@ namespace ggml_cuda_mma {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
template <int I_, int J_, typename T>
|
|
||||||
struct tile<I_, J_, T, DATA_LAYOUT_I_MAJOR_DUAL> {
|
|
||||||
static constexpr int I = I_;
|
|
||||||
static constexpr int J = J_;
|
|
||||||
static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR_DUAL;
|
|
||||||
|
|
||||||
static constexpr int ne = I * J / 32 * 2;
|
|
||||||
|
|
||||||
T x[ne] = {0};
|
|
||||||
|
|
||||||
static constexpr __device__ bool supported() {
|
|
||||||
if (I == 16 && J == 16) return true;
|
|
||||||
if (I == 16 && J == 8) return true;
|
|
||||||
if (I == 16 && J == 4) return true;
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
static __device__ __forceinline__ int get_i(const int l) {
|
|
||||||
if constexpr (supported()) {
|
|
||||||
return threadIdx.x % 16;
|
|
||||||
} else {
|
|
||||||
NO_DEVICE_CODE;
|
|
||||||
return -1;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
static __device__ __forceinline__ int get_j(const int l) {
|
|
||||||
if constexpr (supported()) {
|
|
||||||
return l;
|
|
||||||
} else {
|
|
||||||
NO_DEVICE_CODE;
|
|
||||||
return -1;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
#if defined(TURING_MMA_AVAILABLE)
|
#if defined(TURING_MMA_AVAILABLE)
|
||||||
template <int I, int J>
|
template <int I, int J>
|
||||||
static __device__ __forceinline__ tile<I, J/2, half2> get_half2(const tile<I, J, float> & tile_float) {
|
static __device__ __forceinline__ tile<I, J/2, half2> get_half2(const tile<I, J, float> & tile_float) {
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue