diff --git a/ggml/src/ggml-cuda/mma.cuh b/ggml/src/ggml-cuda/mma.cuh index 794d90bdd1..3268dadfe8 100644 --- a/ggml/src/ggml-cuda/mma.cuh +++ b/ggml/src/ggml-cuda/mma.cuh @@ -78,27 +78,25 @@ namespace ggml_cuda_mma { // MIRRORED == Each data value is held exactly once per thread subgroup. DATA_LAYOUT_I_MAJOR = 0, // Always used for Turing, Ampere, Ada Lovelace, consumer Blackwell, matrix A&B for RDNA4 and CDNA. DATA_LAYOUT_J_MAJOR = 10, // Matrix C for CDNA and RDNA4, int and float matrix C for RDNA3. - DATA_LAYOUT_I_MAJOR_MIRRORED = 20, + DATA_LAYOUT_I_MAJOR_MIRRORED = 20, // Volta, matrix A&B for RDNA3. DATA_LAYOUT_J_MAJOR_MIRRORED = 30, - DATA_LAYOUT_I_MAJOR_DUAL = 40, // Matrix A&B for RDNA3. }; // Implemented mma combinations are: // - (I_MAJOR, I_MAJOR) -> I_MAJOR // - (I_MAJOR, I_MAJOR_MIRRORED) -> I_MAJOR // - (I_MAJOR, J_MAJOR_MIRRORED) -> I_MAJOR - constexpr bool is_i_major(const data_layout dl) { + static constexpr bool is_i_major(const data_layout dl) { return dl == DATA_LAYOUT_I_MAJOR || - dl == DATA_LAYOUT_I_MAJOR_MIRRORED || - dl == DATA_LAYOUT_I_MAJOR_DUAL; + dl == DATA_LAYOUT_I_MAJOR_MIRRORED; } - constexpr data_layout get_input_data_layout() { -#if defined(RDNA3) - return DATA_LAYOUT_I_MAJOR_DUAL; + static constexpr __device__ data_layout get_input_data_layout() { +#if defined(RDNA3) || __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA + return DATA_LAYOUT_I_MAJOR_MIRRORED; #else return DATA_LAYOUT_I_MAJOR; -#endif // defined(RDNA3) +#endif // defined(RDNA3) || __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA } template @@ -462,11 +460,65 @@ namespace ggml_cuda_mma { } }; + template + struct tile { + static constexpr int I = I_; + static constexpr int J = J_; + static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR_MIRRORED; + + // RDNA3 + static constexpr int ne = I * J / 32 * 2; + + T x[ne] = {0}; + + static constexpr __device__ bool supported() { + if (I == 16 && J == 16) return true; + if (I == 16 && J == 8) return true; + if (I == 16 && J == 4) return true; + return false; + } + + static __device__ __forceinline__ int get_i(const int /*l*/) { + if constexpr (supported()) { + return threadIdx.x % 16; + } else { + NO_DEVICE_CODE; + return -1; + } + } + + static __device__ __forceinline__ int get_j(const int l) { + if constexpr (supported()) { + return l; + } else { + NO_DEVICE_CODE; + return -1; + } + } + }; + template struct tile { static constexpr int I = I_; static constexpr int J = J_; static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR_MIRRORED; +#if defined(RDNA3) + static constexpr int ne = tile::ne; + + half2 x[ne] = {{0.0f, 0.0f}}; + + static constexpr __device__ bool supported() { + return tile::supported(); + } + + static __device__ __forceinline__ int get_i(const int l) { + return tile::get_i(l); + } + + static __device__ __forceinline__ int get_j(const int l) { + return tile::get_j(l); + } +#else // Volta static constexpr int ne = I * J / (WARP_SIZE/4); half2 x[ne] = {{0.0f, 0.0f}}; @@ -493,6 +545,29 @@ namespace ggml_cuda_mma { return -1; } } +#endif // defined(RDNA3) + }; + + template + struct tile { + static constexpr int I = I_; + static constexpr int J = J_; + static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR_MIRRORED; + static constexpr int ne = tile::ne; + + nv_bfloat162 x[ne] = {{0.0f, 0.0f}}; + + static constexpr __device__ bool supported() { + return tile::supported(); + } + + static __device__ __forceinline__ int get_i(const int l) { + return tile::get_i(l); + } + + static __device__ __forceinline__ int get_j(const int l) { + return tile::get_j(l); + } }; template @@ -528,42 +603,6 @@ namespace ggml_cuda_mma { } }; - template - struct tile { - static constexpr int I = I_; - static constexpr int J = J_; - static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR_DUAL; - - static constexpr int ne = I * J / 32 * 2; - - T x[ne] = {0}; - - static constexpr __device__ bool supported() { - if (I == 16 && J == 16) return true; - if (I == 16 && J == 8) return true; - if (I == 16 && J == 4) return true; - return false; - } - - static __device__ __forceinline__ int get_i(const int l) { - if constexpr (supported()) { - return threadIdx.x % 16; - } else { - NO_DEVICE_CODE; - return -1; - } - } - - static __device__ __forceinline__ int get_j(const int l) { - if constexpr (supported()) { - return l; - } else { - NO_DEVICE_CODE; - return -1; - } - } - }; - #if defined(TURING_MMA_AVAILABLE) template static __device__ __forceinline__ tile get_half2(const tile & tile_float) {