diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cu b/ggml/src/ggml-cuda/conv2d-implicit.cu
index d7ef4b5d95..57cd116d73 100644
--- a/ggml/src/ggml-cuda/conv2d-implicit.cu
+++ b/ggml/src/ggml-cuda/conv2d-implicit.cu
@@ -7,6 +7,9 @@
 
 typedef unsigned int uint;
+
+#define GGML_CUDA_CC_RUBIN 10000
+
 constexpr uint WARPSIZE = 32;
 #define CUDA_NCHW_2_NHWC_TILE_DIM 32
 #define CUDA_NCHW_2_NHWC_BLOCK_NM 8
@@ -343,7 +346,7 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
 template <unsigned int mma_tiles_per_warp_m, unsigned int mma_tiles_per_warp_k, unsigned int smem_stride>
 __device__ __forceinline__ void ldmatrix_a(
     const half* src,
-#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+#if __CUDA_ARCH__ >= GGML_CUDA_CC_RUBIN
     half (&reg)[mma_tiles_per_warp_m][mma_tiles_per_warp_k][8]
 #else
     half (&reg)[mma_tiles_per_warp_m][mma_tiles_per_warp_k][4]
@@ -351,13 +354,13 @@ __device__ __forceinline__ void ldmatrix_a(
 ){
 #if __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
     static_assert(mma_tiles_per_warp_m == 8, "mma_tiles_per_warp_m must be 8");
-#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+#if __CUDA_ARCH__ >= GGML_CUDA_CC_RUBIN
     static_assert(mma_tiles_per_warp_k == 2, "mma_tiles_per_warp_k must be 2");
 #else
     static_assert(mma_tiles_per_warp_k == 4, "mma_tiles_per_warp_k must be 4");
 #endif
 
-#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+#if __CUDA_ARCH__ >= GGML_CUDA_CC_RUBIN
     uint32_t (&reg_) [mma_tiles_per_warp_m][mma_tiles_per_warp_k][4] = reinterpret_cast<uint32_t(&)[mma_tiles_per_warp_m][mma_tiles_per_warp_k][4]>(reg);
 #else
     uint32_t (&reg_) [mma_tiles_per_warp_m][mma_tiles_per_warp_k][2] = reinterpret_cast<uint32_t(&)[mma_tiles_per_warp_m][mma_tiles_per_warp_k][2]>(reg);
@@ -403,7 +406,7 @@ __device__ __forceinline__ void ldmatrix_a(
     src_addr ^= 0b10000;
 
     // 1
-#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+#if __CUDA_ARCH__ >= GGML_CUDA_CC_RUBIN
     asm volatile (
         "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
         "{%0, %1, %2, %3}, [%4];"
@@ -471,7 +474,7 @@ __device__ __forceinline__ void ldmatrix_a(
     src_addr ^= 0b110000;
 
     // 2
-#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+#if __CUDA_ARCH__ >= GGML_CUDA_CC_RUBIN
     asm volatile (
         "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
         "{%0, %1, %2, %3}, [%4];"
@@ -537,7 +540,7 @@ __device__ __forceinline__ void ldmatrix_a(
     src_addr ^= 0b10000;
 
     // 3
-#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+#if __CUDA_ARCH__ >= GGML_CUDA_CC_RUBIN
     asm volatile (
         "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
         "{%0, %1, %2, %3}, [%4];"
@@ -610,7 +613,7 @@ __device__ __forceinline__ void ldmatrix_a(
 template <unsigned int mma_tiles_per_warp_k, unsigned int mma_tiles_per_warp_n, unsigned int smem_stride>
 __device__ __forceinline__ void ldmatrix_b(
     const half* src,
-#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+#if __CUDA_ARCH__ >= GGML_CUDA_CC_RUBIN
     half (&reg)[mma_tiles_per_warp_k][mma_tiles_per_warp_n][4]
 #else
     half (&reg)[mma_tiles_per_warp_k][mma_tiles_per_warp_n][2]
@@ -618,14 +621,14 @@ __device__ __forceinline__ void ldmatrix_b(
 ){
 #if __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
 
-#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+#if __CUDA_ARCH__ >= GGML_CUDA_CC_RUBIN
     static_assert(mma_tiles_per_warp_k == 2, "mma_tiles_per_warp_k must be 2");
 #else
     static_assert(mma_tiles_per_warp_k == 4, "mma_tiles_per_warp_k must be 4");
 #endif
     static_assert(mma_tiles_per_warp_n == 8, "mma_tiles_per_warp_n must be 8");
 
-#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+#if __CUDA_ARCH__ >= GGML_CUDA_CC_RUBIN
     uint32_t (&reg_) [2][8][2] = reinterpret_cast<uint32_t(&)[2][8][2]>(reg);
 #else
     uint32_t (&reg_) [4][8] = reinterpret_cast<uint32_t(&)[4][8]>(reg);
@@ -637,7 +640,7 @@ __device__ __forceinline__ void ldmatrix_b(
     constexpr unsigned int smem_stride_ = smem_stride * sizeof(half); // convert stride to bytes
 
     // 0
-#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+#if __CUDA_ARCH__ >= GGML_CUDA_CC_RUBIN
     asm volatile (
         "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
         "{%0, %1, %2, %3}, [%4];"
@@ -670,7 +673,7 @@ __device__ __forceinline__ void ldmatrix_b(
 
     src_addr ^= 0b10000;
 
-#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+#if __CUDA_ARCH__ >= GGML_CUDA_CC_RUBIN
     asm volatile (
         "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
         "{%0, %1, %2, %3}, [%4];"
@@ -702,7 +705,7 @@ __device__ __forceinline__ void ldmatrix_b(
 
     src_addr ^= 0b110000;
 
-#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+#if __CUDA_ARCH__ >= GGML_CUDA_CC_RUBIN
     asm volatile (
         "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
        "{%0, %1, %2, %3}, [%4];"
@@ -734,7 +737,7 @@ __device__ __forceinline__ void ldmatrix_b(
 
     src_addr ^= 0b10000;
 
-#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+#if __CUDA_ARCH__ >= GGML_CUDA_CC_RUBIN
     asm volatile (
         "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
         "{%0, %1, %2, %3}, [%4];"
@@ -790,7 +793,7 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
     // const unsigned int NKPQ = param.n * KPQ;
 
     // loop bounds, constexpr where possible allows for loop unrolling
-#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+#if __CUDA_ARCH__ >= GGML_CUDA_CC_RUBIN
     constexpr unsigned int mma_tiles_per_warp_k = 2;
 #else
     constexpr unsigned int mma_tiles_per_warp_k = 4;
@@ -829,7 +832,7 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
     // declare register storage
     // ptx instructions expect uint32_t registers, where each uint32_t is 2 halfs packed together
     uint32_t acc_register[mma_tiles_per_warp_m][mma_tiles_per_warp_n][2];
-#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+#if __CUDA_ARCH__ >= GGML_CUDA_CC_RUBIN
     uint32_t A_register[mma_tiles_per_warp_m][mma_tiles_per_warp_k][4];
     uint32_t B_register[mma_tiles_per_warp_k][mma_tiles_per_warp_n][2];
 #else
@@ -839,7 +842,7 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
 
     // convenience cast to half for register storage
     half (&acc_register_) [mma_tiles_per_warp_m][mma_tiles_per_warp_n][4] = reinterpret_cast<half(&)[mma_tiles_per_warp_m][mma_tiles_per_warp_n][4]>(acc_register);
-#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+#if __CUDA_ARCH__ >= GGML_CUDA_CC_RUBIN
     half (&A_register_) [mma_tiles_per_warp_m][mma_tiles_per_warp_k][8] = reinterpret_cast<half(&)[mma_tiles_per_warp_m][mma_tiles_per_warp_k][8]>(A_register);
     half (&B_register_) [mma_tiles_per_warp_k][mma_tiles_per_warp_n][4] = reinterpret_cast<half(&)[mma_tiles_per_warp_k][mma_tiles_per_warp_n][4]>(B_register);
 #else
@@ -962,7 +965,7 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
         for (unsigned int mma_n = 0; mma_n < mma_tiles_per_warp_n; mma_n++){
 #pragma unroll
             for (unsigned int mma_m = 0; mma_m < mma_tiles_per_warp_m; mma_m++){
-#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+#if __CUDA_ARCH__ >= GGML_CUDA_CC_RUBIN
                 asm volatile (
                     "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 "
                     "{%0, %1}, "
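
Net effect of the patch: every `__CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE` guard in the implicit conv2d kernel now compares against the new `GGML_CUDA_CC_RUBIN` value of 10000. ggml encodes compute capability in the same integer form as `__CUDA_ARCH__` (Turing = 750, Ampere = 800), and no shipping GPU reports anything near 10000, so the wide path (8-half ldmatrix.x4 A-fragments, `mma_tiles_per_warp_k == 2`, the m16n8k16 mma) is compiled out everywhere and all Turing-and-newer devices take the narrow path with `mma_tiles_per_warp_k == 4`. Below is a minimal standalone sketch of how the retargeted guard resolves; the Turing and Ampere macro values mirror the ones ggml defines in ggml-cuda/common.cuh, and the kernel body is illustrative only, not code from the patch.

// Standalone sketch (assumptions noted above): which branch of the
// retargeted guard each architecture compiles.
#include <cstdio>

#define GGML_CUDA_CC_TURING 750
#define GGML_CUDA_CC_AMPERE 800
#define GGML_CUDA_CC_RUBIN  10000  // value introduced by this patch

__global__ void report_conv2d_path() {
#if __CUDA_ARCH__ >= GGML_CUDA_CC_RUBIN
    // Former Ampere branch: ldmatrix.x4 fragments, mma_tiles_per_warp_k == 2.
    // No current GPU reports __CUDA_ARCH__ anywhere near 10000, so this
    // branch is dead code until a matching architecture exists.
    printf("wide path: mma_tiles_per_warp_k = 2, m16n8k16 mma\n");
#elif __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
    // Branch now taken by Turing, Ampere, and everything newer:
    // narrower fragments, mma_tiles_per_warp_k == 4.
    printf("narrow path: mma_tiles_per_warp_k = 4\n");
#else
    printf("pre-Turing: no tensor-core path\n");
#endif
}

int main() {
    report_conv2d_path<<<1, 1>>>();
    cudaDeviceSynchronize();
    return 0;
}

Pointing the guards at an effectively unreachable value, rather than deleting the branches, looks like a way to disable the Ampere code path while keeping it in the tree; re-enabling it (or gating it on a future architecture) is a one-macro change.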