diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cu b/ggml/src/ggml-cuda/conv2d-implicit.cu
index 80a406e2c5..e73ec150df 100644
--- a/ggml/src/ggml-cuda/conv2d-implicit.cu
+++ b/ggml/src/ggml-cuda/conv2d-implicit.cu
@@ -9,8 +9,6 @@
 typedef unsigned int uint;
-#define GGML_CUDA_CC_RUBIN 10000
-
 constexpr uint WARPSIZE = 32;
 #define CUDA_NCHW_2_NHWC_TILE_DIM 32
 #define CUDA_NCHW_2_NHWC_BLOCK_NM 8
@@ -344,25 +342,14 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
 template <unsigned int mma_tiles_per_warp_m, unsigned int mma_tiles_per_warp_k, unsigned int smem_stride>
 __device__ __forceinline__ void ldmatrix_a(
     const half* src,
-#if __CUDA_ARCH__ >= GGML_CUDA_CC_RUBIN
-    half (&reg)[mma_tiles_per_warp_m][mma_tiles_per_warp_k][8]
-#else
     half (&reg)[mma_tiles_per_warp_m][mma_tiles_per_warp_k][4]
-#endif
 ){
 #if __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
     static_assert(mma_tiles_per_warp_m == 8, "mma_tiles_per_warp_m must be 8");
-#if __CUDA_ARCH__ >= GGML_CUDA_CC_RUBIN
-    static_assert(mma_tiles_per_warp_k == 2, "mma_tiles_per_warp_k must be 2");
-#else
     static_assert(mma_tiles_per_warp_k == 4, "mma_tiles_per_warp_k must be 4");
-#endif
-#if __CUDA_ARCH__ >= GGML_CUDA_CC_RUBIN
-    uint32_t (&reg_) [mma_tiles_per_warp_m][mma_tiles_per_warp_k][4] = reinterpret_cast<uint32_t (&)[mma_tiles_per_warp_m][mma_tiles_per_warp_k][4]>(reg);
-#else
     uint32_t (&reg_) [mma_tiles_per_warp_m][mma_tiles_per_warp_k][2] = reinterpret_cast<uint32_t (&)[mma_tiles_per_warp_m][mma_tiles_per_warp_k][2]>(reg);
-#endif
+
     unsigned int logical_offset = (threadIdx.x % 32) * smem_stride;
     unsigned int swizzled_offset = logical_offset ^ ((logical_offset & 0b10000000) >> 4);
     swizzled_offset = swizzled_offset ^ ((swizzled_offset & 0b1100000) >> 2);
@@ -404,39 +391,7 @@ __device__ __forceinline__ void ldmatrix_a(
     src_addr ^= 0b10000;
     // 1
-#if __CUDA_ARCH__ >= GGML_CUDA_CC_RUBIN
-    asm volatile (
-        "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
-        "{%0, %1, %2, %3}, [%4];"
-        : "=r"(reg_[0][0][2]), "=r"(reg_[0][0][3]), "=r"(reg_[1][0][2]), "=r"(reg_[1][0][3])
-        : "r"(src_addr)
-    );
-    // 1
-    asm volatile (
-        "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
-        "{%0, %1, %2, %3}, [%4];"
-        : "=r"(reg_[2][0][2]), "=r"(reg_[2][0][3]), "=r"(reg_[3][0][2]), "=r"(reg_[3][0][3])
-        : "r"(src_addr + 32 * smem_stride_)
-    );
-
-    // 1
-    asm volatile (
-        "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
-        "{%0, %1, %2, %3}, [%4];"
-        : "=r"(reg_[4][0][2]), "=r"(reg_[4][0][3]), "=r"(reg_[5][0][2]), "=r"(reg_[5][0][3])
-        : "r"(src_addr + 64 * smem_stride_)
-    );
-
-    // 1
-    asm volatile (
-        "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
-        "{%0, %1, %2, %3}, [%4];"
-        : "=r"(reg_[6][0][2]), "=r"(reg_[6][0][3]), "=r"(reg_[7][0][2]), "=r"(reg_[7][0][3])
-        : "r"(src_addr + 96 * smem_stride_)
-    );
-
-#else
     asm volatile (
         "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
         "{%0, %1, %2, %3}, [%4];"
@@ -467,43 +422,10 @@ __device__ __forceinline__ void ldmatrix_a(
         : "=r"(reg_[6][1][0]), "=r"(reg_[6][1][1]), "=r"(reg_[7][1][0]), "=r"(reg_[7][1][1])
         : "r"(src_addr + 96 * smem_stride_)
     );
-#endif
     src_addr ^= 0b110000;
     // 2
-#if __CUDA_ARCH__ >= GGML_CUDA_CC_RUBIN
-    asm volatile (
-        "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
-        "{%0, %1, %2, %3}, [%4];"
-        : "=r"(reg_[0][1][0]), "=r"(reg_[0][1][1]), "=r"(reg_[1][1][0]), "=r"(reg_[1][1][1])
-        : "r"(src_addr)
-    );
-
-    // 2
-    asm volatile (
-        "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
-        "{%0, %1, %2, %3}, [%4];"
-        : "=r"(reg_[2][1][0]), "=r"(reg_[2][1][1]), "=r"(reg_[3][1][0]), "=r"(reg_[3][1][1])
-        : "r"(src_addr + 32 * smem_stride_)
-    );
-
-    // 2
-    asm volatile (
-        "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
-        "{%0, %1, %2, %3}, [%4];"
-        : "=r"(reg_[4][1][0]), "=r"(reg_[4][1][1]), "=r"(reg_[5][1][0]), "=r"(reg_[5][1][1])
-        : "r"(src_addr + 64 * smem_stride_)
-    );
-
-    // 2
-    asm volatile (
-        "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
-        "{%0, %1, %2, %3}, [%4];"
-        : "=r"(reg_[6][1][0]), "=r"(reg_[6][1][1]), "=r"(reg_[7][1][0]), "=r"(reg_[7][1][1])
-        : "r"(src_addr + 96 * smem_stride_)
-    );
-#else
     asm volatile (
         "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
        "{%0, %1, %2, %3}, [%4];"
@@ -534,42 +456,10 @@ __device__ __forceinline__ void ldmatrix_a(
         : "=r"(reg_[6][2][0]), "=r"(reg_[6][2][1]), "=r"(reg_[7][2][0]), "=r"(reg_[7][2][1])
         : "r"(src_addr + 96 * smem_stride_)
     );
-#endif
     src_addr ^= 0b10000;
     // 3
-#if __CUDA_ARCH__ >= GGML_CUDA_CC_RUBIN
-    asm volatile (
-        "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
-        "{%0, %1, %2, %3}, [%4];"
-        : "=r"(reg_[0][1][2]), "=r"(reg_[0][1][3]), "=r"(reg_[1][1][2]), "=r"(reg_[1][1][3])
-        : "r"(src_addr)
-    );
-    // 3
-    asm volatile (
-        "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
-        "{%0, %1, %2, %3}, [%4];"
-        : "=r"(reg_[2][1][2]), "=r"(reg_[2][1][3]), "=r"(reg_[3][1][2]), "=r"(reg_[3][1][3])
-        : "r"(src_addr + 32 * smem_stride_)
-    );
-
-    // 3
-    asm volatile (
-        "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
-        "{%0, %1, %2, %3}, [%4];"
-        : "=r"(reg_[4][1][2]), "=r"(reg_[4][1][3]), "=r"(reg_[5][1][2]), "=r"(reg_[5][1][3])
-        : "r"(src_addr + 64 * smem_stride_)
-    );
-
-    // 3
-    asm volatile (
-        "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
-        "{%0, %1, %2, %3}, [%4];"
-        : "=r"(reg_[6][1][2]), "=r"(reg_[6][1][3]), "=r"(reg_[7][1][2]), "=r"(reg_[7][1][3])
-        : "r"(src_addr + 96 * smem_stride_)
-    );
-#else
     asm volatile (
         "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
         "{%0, %1, %2, %3}, [%4];"
@@ -600,7 +490,7 @@ __device__ __forceinline__ void ldmatrix_a(
         : "=r"(reg_[6][3][0]), "=r"(reg_[6][3][1]), "=r"(reg_[7][3][0]), "=r"(reg_[7][3][1])
         : "r"(src_addr + 96 * smem_stride_)
     );
-#endif
+
 #else
     GGML_UNUSED(src);
     GGML_UNUSED(reg);
@@ -611,26 +501,14 @@
 template <unsigned int mma_tiles_per_warp_k, unsigned int mma_tiles_per_warp_n, unsigned int smem_stride>
 __device__ __forceinline__ void ldmatrix_b(
     const half* src,
-#if __CUDA_ARCH__ >= GGML_CUDA_CC_RUBIN
-    half (&reg)[mma_tiles_per_warp_k][mma_tiles_per_warp_n][4]
-#else
     half (&reg)[mma_tiles_per_warp_k][mma_tiles_per_warp_n][2]
-#endif
 ){
 #if __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
-
-#if __CUDA_ARCH__ >= GGML_CUDA_CC_RUBIN
-    static_assert(mma_tiles_per_warp_k == 2, "mma_tiles_per_warp_k must be 2");
-#else
     static_assert(mma_tiles_per_warp_k == 4, "mma_tiles_per_warp_k must be 4");
-#endif
     static_assert(mma_tiles_per_warp_n == 8, "mma_tiles_per_warp_n must be 8");
-#if __CUDA_ARCH__ >= GGML_CUDA_CC_RUBIN
-    uint32_t (&reg_) [2][8][2] = reinterpret_cast<uint32_t (&)[2][8][2]>(reg);
-#else
     uint32_t (&reg_) [4][8] = reinterpret_cast<uint32_t (&)[4][8]>(reg);
-#endif
+
     unsigned int logical_offset = (threadIdx.x % 32) * smem_stride;
     unsigned int swizzled_offset = logical_offset ^ ((logical_offset & 0b10000000) >> 4);
     swizzled_offset = swizzled_offset ^ ((swizzled_offset & 0b1100000) >> 2);
@@ -638,21 +516,7 @@ __device__ __forceinline__ void ldmatrix_b(
     constexpr unsigned int smem_stride_ = smem_stride * sizeof(half); // convert stride to bytes
     // 0
-#if __CUDA_ARCH__ >= GGML_CUDA_CC_RUBIN
-    asm volatile (
-        "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
-        "{%0, %1, %2, %3}, [%4];"
-        : "=r"(reg_[0][0][0]), "=r"(reg_[0][1][0]), "=r"(reg_[0][2][0]), "=r"(reg_[0][3][0])
-        : "r"(src_addr)
-    );
-    asm volatile (
-        "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
-        "{%0, %1, %2, %3}, [%4];"
-        : "=r"(reg_[0][4][0]), "=r"(reg_[0][5][0]), "=r"(reg_[0][6][0]), "=r"(reg_[0][7][0])
-        : "r"(src_addr + 32 * smem_stride_)
-    );
-#else
     asm volatile (
         "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
         "{%0, %1, %2, %3}, [%4];"
@@ -667,25 +531,10 @@ __device__ __forceinline__ void ldmatrix_b(
         : "=r"(reg_[0][4]), "=r"(reg_[0][5]), "=r"(reg_[0][6]), "=r"(reg_[0][7])
         : "r"(src_addr + 32 * smem_stride_)
     );
-#endif
     src_addr ^= 0b10000;
-#if __CUDA_ARCH__ >= GGML_CUDA_CC_RUBIN
-    asm volatile (
-        "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
-        "{%0, %1, %2, %3}, [%4];"
-        : "=r"(reg_[0][0][1]), "=r"(reg_[0][1][1]), "=r"(reg_[0][2][1]), "=r"(reg_[0][3][1])
-        : "r"(src_addr)
-    );
-    asm volatile (
-        "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
-        "{%0, %1, %2, %3}, [%4];"
-        : "=r"(reg_[0][4][1]), "=r"(reg_[0][5][1]), "=r"(reg_[0][6][1]), "=r"(reg_[0][7][1])
-        : "r"(src_addr + 32 * smem_stride_)
-    );
-#else
     asm volatile (
         "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
        "{%0, %1, %2, %3}, [%4];"
@@ -699,25 +548,10 @@ __device__ __forceinline__ void ldmatrix_b(
         : "=r"(reg_[1][4]), "=r"(reg_[1][5]), "=r"(reg_[1][6]), "=r"(reg_[1][7])
        : "r"(src_addr + 32 * smem_stride_)
     );
-#endif
     src_addr ^= 0b110000;
-#if __CUDA_ARCH__ >= GGML_CUDA_CC_RUBIN
-    asm volatile (
-        "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
-        "{%0, %1, %2, %3}, [%4];"
-        : "=r"(reg_[1][0][0]), "=r"(reg_[1][1][0]), "=r"(reg_[1][2][0]), "=r"(reg_[1][3][0])
-        : "r"(src_addr)
-    );
-    asm volatile (
-        "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
-        "{%0, %1, %2, %3}, [%4];"
-        : "=r"(reg_[1][4][0]), "=r"(reg_[1][5][0]), "=r"(reg_[1][6][0]), "=r"(reg_[1][7][0])
-        : "r"(src_addr + 32 * smem_stride_)
-    );
-#else
     asm volatile (
         "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
         "{%0, %1, %2, %3}, [%4];"
@@ -731,25 +565,11 @@ __device__ __forceinline__ void ldmatrix_b(
         : "=r"(reg_[2][4]), "=r"(reg_[2][5]), "=r"(reg_[2][6]), "=r"(reg_[2][7])
         : "r"(src_addr + 32 * smem_stride_)
     );
-#endif
+
     src_addr ^= 0b10000;
-#if __CUDA_ARCH__ >= GGML_CUDA_CC_RUBIN
-    asm volatile (
-        "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
-        "{%0, %1, %2, %3}, [%4];"
-        : "=r"(reg_[1][0][1]), "=r"(reg_[1][1][1]), "=r"(reg_[1][2][1]), "=r"(reg_[1][3][1])
-        : "r"(src_addr)
-    );
-    asm volatile (
-        "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
-        "{%0, %1, %2, %3}, [%4];"
-        : "=r"(reg_[1][4][1]), "=r"(reg_[1][5][1]), "=r"(reg_[1][6][1]), "=r"(reg_[1][7][1])
-        : "r"(src_addr + 32 * smem_stride_)
-    );
-#else
     asm volatile (
         "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
         "{%0, %1, %2, %3}, [%4];"
@@ -763,7 +583,6 @@ __device__ __forceinline__ void ldmatrix_b(
         : "=r"(reg_[3][4]), "=r"(reg_[3][5]), "=r"(reg_[3][6]), "=r"(reg_[3][7])
         : "r"(src_addr + 32 * smem_stride_)
     );
-#endif
 #else
     GGML_UNUSED(src);
     GGML_UNUSED(reg);
@@ -783,11 +602,8 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
     constexpr unsigned int MMA_N = 8;
     // loop bounds, constexpr where possible allows for loop unrolling
-#if __CUDA_ARCH__ >= GGML_CUDA_CC_RUBIN
-    constexpr unsigned int mma_tiles_per_warp_k = 2;
-#else
+
     constexpr unsigned int mma_tiles_per_warp_k = 4;
-#endif
     constexpr unsigned int mma_tiles_per_warp_m = WM / MMA_M;
     constexpr unsigned int mma_tiles_per_warp_n = WN / MMA_N;
     const unsigned int z = blockIdx.z;
@@ -835,23 +651,15 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
     // declare register storage
     // ptx instructions expect uint32_t registers, where each uint32_t is 2 halfs packed together
     uint32_t acc_register[mma_tiles_per_warp_m][mma_tiles_per_warp_n][2];
-#if __CUDA_ARCH__ >= GGML_CUDA_CC_RUBIN
-    uint32_t A_register[mma_tiles_per_warp_m][mma_tiles_per_warp_k][4];
-    uint32_t B_register[mma_tiles_per_warp_k][mma_tiles_per_warp_n][2];
-#else
+
    uint32_t A_register[mma_tiles_per_warp_m][mma_tiles_per_warp_k][2];
    uint32_t B_register[mma_tiles_per_warp_k][mma_tiles_per_warp_n];
-#endif
     // convenience cast to half for register storage
     half (&acc_register_) [mma_tiles_per_warp_m][mma_tiles_per_warp_n][4] = reinterpret_cast<half (&)[mma_tiles_per_warp_m][mma_tiles_per_warp_n][4]>(acc_register);
-#if __CUDA_ARCH__ >= GGML_CUDA_CC_RUBIN
-    half (&A_register_) [mma_tiles_per_warp_m][mma_tiles_per_warp_k][8] = reinterpret_cast<half (&)[mma_tiles_per_warp_m][mma_tiles_per_warp_k][8]>(A_register);
-    half (&B_register_) [mma_tiles_per_warp_k][mma_tiles_per_warp_n][4] = reinterpret_cast<half (&)[mma_tiles_per_warp_k][mma_tiles_per_warp_n][4]>(B_register);
-#else
     half (&A_register_) [mma_tiles_per_warp_m][mma_tiles_per_warp_k][4] = reinterpret_cast<half (&)[mma_tiles_per_warp_m][mma_tiles_per_warp_k][4]>(A_register);
     half (&B_register_) [mma_tiles_per_warp_k][mma_tiles_per_warp_n][2] = reinterpret_cast<half (&)[mma_tiles_per_warp_k][mma_tiles_per_warp_n][2]>(B_register);
-#endif
+
     // accumulators start at 0
     for (unsigned int mma_m = 0; mma_m < mma_tiles_per_warp_m; mma_m++){
         for (unsigned int mma_n = 0; mma_n < mma_tiles_per_warp_n; mma_n++){
@@ -968,19 +776,7 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
         for (unsigned int mma_n = 0; mma_n < mma_tiles_per_warp_n; mma_n++) {
 #pragma unroll
             for (unsigned int mma_m = 0; mma_m < mma_tiles_per_warp_m; mma_m++) {
-#if __CUDA_ARCH__ >= GGML_CUDA_CC_RUBIN
-                asm volatile (
-                    "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 "
-                    "{%0, %1}, "
-                    "{%2, %3, %4, %5}, "
-                    "{%6, %7}, "
-                    "{%8, %9};"
-                    : "=r"(acc_register[mma_m][mma_n][0]), "=r"(acc_register[mma_m][mma_n][1])
-                    : "r"(A_register[mma_m][mma_k][0]), "r"(A_register[mma_m][mma_k][1]),"r"(A_register[mma_m][mma_k][2]), "r"(A_register[mma_m][mma_k][3]),
-                      "r"(B_register[mma_k][mma_n][0]), "r"(B_register[mma_k][mma_n][1])
-                      "r"(acc_register[mma_m][mma_n][0]), "r"(acc_register[mma_m][mma_n][1])
-                );
-#else
+
                 asm volatile (
                     "mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 "
                     "{%0, %1}, "
@@ -992,7 +788,6 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
                      "r"(B_register[mma_k][mma_n]),
                      "r"(acc_register[mma_m][mma_n][0]), "r"(acc_register[mma_m][mma_n][1])
                 );
-#endif
             }
         }
     }
@@ -1030,19 +825,6 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
         for (unsigned int mma_n = 0; mma_n < mma_tiles_per_warp_n; mma_n++) {
 #pragma unroll
             for (unsigned int mma_m = 0; mma_m < mma_tiles_per_warp_m; mma_m++) {
-#if __CUDA_ARCH__ >= GGML_CUDA_CC_RUBIN
-                asm volatile (
-                    "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 "
-                    "{%0, %1}, "
-                    "{%2, %3, %4, %5}, "
-                    "{%6, %7}, "
-                    "{%8, %9};"
-                    : "=r"(acc_register[mma_m][mma_n][0]), "=r"(acc_register[mma_m][mma_n][1])
-                    : "r"(A_register[mma_m][mma_k][0]), "r"(A_register[mma_m][mma_k][1]),"r"(A_register[mma_m][mma_k][2]), "r"(A_register[mma_m][mma_k][3]),
-                      "r"(B_register[mma_k][mma_n][0]), "r"(B_register[mma_k][mma_n][1])
-                      "r"(acc_register[mma_m][mma_n][0]), "r"(acc_register[mma_m][mma_n][1])
-                );
-#else
                 asm volatile (
                     "mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 "
                     "{%0, %1}, "
@@ -1054,7 +836,6 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
                      "r"(B_register[mma_k][mma_n]),
                      "r"(acc_register[mma_m][mma_n][0]), "r"(acc_register[mma_m][mma_n][1])
                 );
-#endif
            }
        }
    }
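
For reference, below is a minimal standalone sketch (not part of this patch; the file name `mma_m16n8k8_demo.cu` and the `nvcc` invocation are only illustrative) of the f16 `mma.sync.aligned.m16n8k8` fragment shapes that the kernel keeps relying on once the `GGML_CUDA_CC_RUBIN` path is dropped: per thread, A is 2 x b32 registers (4 halfs), B is 1 x b32 register (2 halfs), and C/D are 2 x b32 registers (4 halfs), which is why the retained declarations are `A_register[..][..][2]`, `B_register[..][..]` and `acc_register[..][..][2]`. The demo fills A and B with 1.0h so the input fragment layout does not matter, issues one HMMA, and stores the 16x8 result with the standard m16n8 accumulator layout. Compile for Turing or newer, e.g. `nvcc -arch=sm_80 mma_m16n8k8_demo.cu`.

    // Illustrative sketch only -- not part of conv2d-implicit.cu.
    #include <cuda_fp16.h>
    #include <cstdint>
    #include <cstdio>

    // One warp computes D(16x8) = A(16x8) * B(8x8) with a single HMMA instruction.
    static __global__ void mma_m16n8k8_demo(half * D) { // D: 16x8, row-major
        const unsigned lane = threadIdx.x % 32;

        // Two 1.0h values packed into each 32-bit register.
        const __half2  one2 = __float2half2_rn(1.0f);
        const uint32_t ones = *reinterpret_cast<const uint32_t *>(&one2);

        uint32_t a[2] = { ones, ones }; // A fragment: 4 halfs per thread (2 x b32)
        uint32_t b    = ones;           // B fragment: 2 halfs per thread (1 x b32)
        uint32_t d[2] = { 0, 0 };       // C/D fragment: 4 halfs per thread (2 x b32)

        asm volatile (
            "mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 "
            "{%0, %1}, {%2, %3}, {%4}, {%5, %6};"
            : "=r"(d[0]), "=r"(d[1])
            : "r"(a[0]), "r"(a[1]), "r"(b), "r"(d[0]), "r"(d[1]));

        // m16n8 accumulator layout: row group = lane / 4, column pair = 2 * (lane % 4).
        const unsigned row = lane / 4;
        const unsigned col = (lane % 4) * 2;
        *reinterpret_cast<uint32_t *>(D + (row    ) * 8 + col) = d[0]; // rows 0..7
        *reinterpret_cast<uint32_t *>(D + (row + 8) * 8 + col) = d[1]; // rows 8..15
    }

    int main() {
        half * d_D = nullptr;
        cudaMalloc(&d_D, 16 * 8 * sizeof(half));
        mma_m16n8k8_demo<<<1, 32>>>(d_D);
        half h_D[16 * 8];
        cudaMemcpy(h_D, d_D, sizeof(h_D), cudaMemcpyDeviceToHost);
        printf("D[0][0] = %.1f (expected 8.0)\n", __half2float(h_D[0]));
        cudaFree(d_D);
        return 0;
    }

Every element of D is 8.0 because A (16x8 of ones) times B (8x8 of ones) sums eight products per output; the per-lane store pattern at the end is the same accumulator layout the kernel's epilogue has to respect when writing `acc_register_` back out.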