From 9f498d29f1a3652bdbe426dd4802cc616b653aa2 Mon Sep 17 00:00:00 2001 From: bssrdf Date: Wed, 12 Nov 2025 11:55:15 -0500 Subject: [PATCH] only enable m16n8k16 on ampere or above --- ggml/src/ggml-cuda/conv2d-implicit.cu | 388 ++++++++++++++------------ tests/test-conv2d.cpp | 18 +- 2 files changed, 226 insertions(+), 180 deletions(-) diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cu b/ggml/src/ggml-cuda/conv2d-implicit.cu index 654d2dffe4..529a0b50fd 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cu +++ b/ggml/src/ggml-cuda/conv2d-implicit.cu @@ -343,15 +343,25 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, template __device__ __forceinline__ void ldmatrix_a( const half* src, +#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE half (®)[mma_tiles_per_warp_m][mma_tiles_per_warp_k][8] +#else + half (®)[mma_tiles_per_warp_m][mma_tiles_per_warp_k][4] +#endif ){ #if __CUDA_ARCH__ >= GGML_CUDA_CC_TURING static_assert(mma_tiles_per_warp_m == 8, "mma_tiles_per_warp_m must be 8"); - // static_assert(mma_tiles_per_warp_k == 4, "mma_tiles_per_warp_k must be 4"); +#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE static_assert(mma_tiles_per_warp_k == 2, "mma_tiles_per_warp_k must be 2"); +#else + static_assert(mma_tiles_per_warp_k == 4, "mma_tiles_per_warp_k must be 4"); +#endif - // uint32_t (®_) [mma_tiles_per_warp_m][mma_tiles_per_warp_k][2] = reinterpret_cast(reg); +#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE uint32_t (®_) [mma_tiles_per_warp_m][mma_tiles_per_warp_k][4] = reinterpret_cast(reg); +#else + uint32_t (®_) [mma_tiles_per_warp_m][mma_tiles_per_warp_k][2] = reinterpret_cast(reg); +#endif unsigned int logical_offset = (threadIdx.x % 32) * smem_stride; unsigned int swizzled_offset = logical_offset ^ ((logical_offset & 0b10000000) >> 4); swizzled_offset = swizzled_offset ^ ((swizzled_offset & 0b1100000) >> 2); @@ -392,38 +402,8 @@ __device__ __forceinline__ void ldmatrix_a( src_addr ^= 0b10000; - // // 1 - // asm volatile ( - // "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " - // "{%0, %1, %2, %3}, [%4];" - // : "=r"(reg_[0][1][0]), "=r"(reg_[0][1][1]), "=r"(reg_[1][1][0]), "=r"(reg_[1][1][1]) - // : "r"(src_addr) - // ); - - // // 1 - // asm volatile ( - // "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " - // "{%0, %1, %2, %3}, [%4];" - // : "=r"(reg_[2][1][0]), "=r"(reg_[2][1][1]), "=r"(reg_[3][1][0]), "=r"(reg_[3][1][1]) - // : "r"(src_addr + 32 * smem_stride_) - // ); - - // // 1 - // asm volatile ( - // "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " - // "{%0, %1, %2, %3}, [%4];" - // : "=r"(reg_[4][1][0]), "=r"(reg_[4][1][1]), "=r"(reg_[5][1][0]), "=r"(reg_[5][1][1]) - // : "r"(src_addr + 64 * smem_stride_) - // ); - - // // 1 - // asm volatile ( - // "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " - // "{%0, %1, %2, %3}, [%4];" - // : "=r"(reg_[6][1][0]), "=r"(reg_[6][1][1]), "=r"(reg_[7][1][0]), "=r"(reg_[7][1][1]) - // : "r"(src_addr + 96 * smem_stride_) - // ); - + // 1 +#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE asm volatile ( "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " "{%0, %1, %2, %3}, [%4];" @@ -455,41 +435,43 @@ __device__ __forceinline__ void ldmatrix_a( : "r"(src_addr + 96 * smem_stride_) ); +#else + asm volatile ( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " + "{%0, %1, %2, %3}, [%4];" + : "=r"(reg_[0][1][0]), "=r"(reg_[0][1][1]), "=r"(reg_[1][1][0]), "=r"(reg_[1][1][1]) + : "r"(src_addr) + ); + + // 1 + asm volatile ( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " + "{%0, %1, %2, %3}, [%4];" + : "=r"(reg_[2][1][0]), "=r"(reg_[2][1][1]), "=r"(reg_[3][1][0]), 
"=r"(reg_[3][1][1]) + : "r"(src_addr + 32 * smem_stride_) + ); + + // 1 + asm volatile ( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " + "{%0, %1, %2, %3}, [%4];" + : "=r"(reg_[4][1][0]), "=r"(reg_[4][1][1]), "=r"(reg_[5][1][0]), "=r"(reg_[5][1][1]) + : "r"(src_addr + 64 * smem_stride_) + ); + + // 1 + asm volatile ( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " + "{%0, %1, %2, %3}, [%4];" + : "=r"(reg_[6][1][0]), "=r"(reg_[6][1][1]), "=r"(reg_[7][1][0]), "=r"(reg_[7][1][1]) + : "r"(src_addr + 96 * smem_stride_) + ); +#endif + src_addr ^= 0b110000; - // // 2 - // asm volatile ( - // "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " - // "{%0, %1, %2, %3}, [%4];" - // : "=r"(reg_[0][2][0]), "=r"(reg_[0][2][1]), "=r"(reg_[1][2][0]), "=r"(reg_[1][2][1]) - // : "r"(src_addr) - // ); - - // // 2 - // asm volatile ( - // "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " - // "{%0, %1, %2, %3}, [%4];" - // : "=r"(reg_[2][2][0]), "=r"(reg_[2][2][1]), "=r"(reg_[3][2][0]), "=r"(reg_[3][2][1]) - // : "r"(src_addr + 32 * smem_stride_) - // ); - - // // 2 - // asm volatile ( - // "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " - // "{%0, %1, %2, %3}, [%4];" - // : "=r"(reg_[4][2][0]), "=r"(reg_[4][2][1]), "=r"(reg_[5][2][0]), "=r"(reg_[5][2][1]) - // : "r"(src_addr + 64 * smem_stride_) - // ); - - // // 2 - // asm volatile ( - // "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " - // "{%0, %1, %2, %3}, [%4];" - // : "=r"(reg_[6][2][0]), "=r"(reg_[6][2][1]), "=r"(reg_[7][2][0]), "=r"(reg_[7][2][1]) - // : "r"(src_addr + 96 * smem_stride_) - // ); - // 2 +#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE asm volatile ( "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " "{%0, %1, %2, %3}, [%4];" @@ -520,42 +502,42 @@ __device__ __forceinline__ void ldmatrix_a( : "=r"(reg_[6][1][0]), "=r"(reg_[6][1][1]), "=r"(reg_[7][1][0]), "=r"(reg_[7][1][1]) : "r"(src_addr + 96 * smem_stride_) ); +#else + asm volatile ( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " + "{%0, %1, %2, %3}, [%4];" + : "=r"(reg_[0][2][0]), "=r"(reg_[0][2][1]), "=r"(reg_[1][2][0]), "=r"(reg_[1][2][1]) + : "r"(src_addr) + ); + // 2 + asm volatile ( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " + "{%0, %1, %2, %3}, [%4];" + : "=r"(reg_[2][2][0]), "=r"(reg_[2][2][1]), "=r"(reg_[3][2][0]), "=r"(reg_[3][2][1]) + : "r"(src_addr + 32 * smem_stride_) + ); + + // 2 + asm volatile ( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " + "{%0, %1, %2, %3}, [%4];" + : "=r"(reg_[4][2][0]), "=r"(reg_[4][2][1]), "=r"(reg_[5][2][0]), "=r"(reg_[5][2][1]) + : "r"(src_addr + 64 * smem_stride_) + ); + + // 2 + asm volatile ( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " + "{%0, %1, %2, %3}, [%4];" + : "=r"(reg_[6][2][0]), "=r"(reg_[6][2][1]), "=r"(reg_[7][2][0]), "=r"(reg_[7][2][1]) + : "r"(src_addr + 96 * smem_stride_) + ); +#endif src_addr ^= 0b10000; - // // 3 - // asm volatile ( - // "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " - // "{%0, %1, %2, %3}, [%4];" - // : "=r"(reg_[0][3][0]), "=r"(reg_[0][3][1]), "=r"(reg_[1][3][0]), "=r"(reg_[1][3][1]) - // : "r"(src_addr) - // ); - - // // 3 - // asm volatile ( - // "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " - // "{%0, %1, %2, %3}, [%4];" - // : "=r"(reg_[2][3][0]), "=r"(reg_[2][3][1]), "=r"(reg_[3][3][0]), "=r"(reg_[3][3][1]) - // : "r"(src_addr + 32 * smem_stride_) - // ); - - // // 3 - // asm volatile ( - // "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " - // "{%0, %1, %2, %3}, [%4];" - // : "=r"(reg_[4][3][0]), "=r"(reg_[4][3][1]), "=r"(reg_[5][3][0]), "=r"(reg_[5][3][1]) - // : "r"(src_addr + 64 * smem_stride_) - // ); - - // // 3 - // asm volatile ( - // 
"ldmatrix.sync.aligned.m8n8.x4.shared.b16 " - // "{%0, %1, %2, %3}, [%4];" - // : "=r"(reg_[6][3][0]), "=r"(reg_[6][3][1]), "=r"(reg_[7][3][0]), "=r"(reg_[7][3][1]) - // : "r"(src_addr + 96 * smem_stride_) - // ); - // 3 +#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE asm volatile ( "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " "{%0, %1, %2, %3}, [%4];" @@ -586,6 +568,38 @@ __device__ __forceinline__ void ldmatrix_a( : "=r"(reg_[6][1][2]), "=r"(reg_[6][1][3]), "=r"(reg_[7][1][2]), "=r"(reg_[7][1][3]) : "r"(src_addr + 96 * smem_stride_) ); +#else + asm volatile ( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " + "{%0, %1, %2, %3}, [%4];" + : "=r"(reg_[0][3][0]), "=r"(reg_[0][3][1]), "=r"(reg_[1][3][0]), "=r"(reg_[1][3][1]) + : "r"(src_addr) + ); + + // 3 + asm volatile ( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " + "{%0, %1, %2, %3}, [%4];" + : "=r"(reg_[2][3][0]), "=r"(reg_[2][3][1]), "=r"(reg_[3][3][0]), "=r"(reg_[3][3][1]) + : "r"(src_addr + 32 * smem_stride_) + ); + + // 3 + asm volatile ( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " + "{%0, %1, %2, %3}, [%4];" + : "=r"(reg_[4][3][0]), "=r"(reg_[4][3][1]), "=r"(reg_[5][3][0]), "=r"(reg_[5][3][1]) + : "r"(src_addr + 64 * smem_stride_) + ); + + // 3 + asm volatile ( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " + "{%0, %1, %2, %3}, [%4];" + : "=r"(reg_[6][3][0]), "=r"(reg_[6][3][1]), "=r"(reg_[7][3][0]), "=r"(reg_[7][3][1]) + : "r"(src_addr + 96 * smem_stride_) + ); +#endif #else GGML_UNUSED(src); GGML_UNUSED(reg); @@ -596,17 +610,26 @@ __device__ __forceinline__ void ldmatrix_a( template __device__ __forceinline__ void ldmatrix_b( const half* src, - // half (®)[mma_tiles_per_warp_k][mma_tiles_per_warp_n][2] +#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE half (®)[mma_tiles_per_warp_k][mma_tiles_per_warp_n][4] +#else + half (®)[mma_tiles_per_warp_k][mma_tiles_per_warp_n][2] +#endif ){ #if __CUDA_ARCH__ >= GGML_CUDA_CC_TURING - // static_assert(mma_tiles_per_warp_k == 4, "mma_tiles_per_warp_k must be 4"); +#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE static_assert(mma_tiles_per_warp_k == 2, "mma_tiles_per_warp_k must be 2"); +#else + static_assert(mma_tiles_per_warp_k == 4, "mma_tiles_per_warp_k must be 4"); +#endif static_assert(mma_tiles_per_warp_n == 8, "mma_tiles_per_warp_n must be 8"); - // uint32_t (®_) [4][8] = reinterpret_cast(reg); +#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE uint32_t (®_) [2][8][2] = reinterpret_cast(reg); +#else + uint32_t (®_) [4][8] = reinterpret_cast(reg); +#endif unsigned int logical_offset = (threadIdx.x % 32) * smem_stride; unsigned int swizzled_offset = logical_offset ^ ((logical_offset & 0b10000000) >> 4); swizzled_offset = swizzled_offset ^ ((swizzled_offset & 0b1100000) >> 2); @@ -614,21 +637,7 @@ __device__ __forceinline__ void ldmatrix_b( constexpr unsigned int smem_stride_ = smem_stride * sizeof(half); // convert stride to bytes // 0 - // asm volatile ( - // "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " - // "{%0, %1, %2, %3}, [%4];" - // : "=r"(reg_[0][0]), "=r"(reg_[0][1]), "=r"(reg_[0][2]), "=r"(reg_[0][3]) - // : "r"(src_addr) - // ); - - - // asm volatile ( - // "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " - // "{%0, %1, %2, %3}, [%4];" - // : "=r"(reg_[0][4]), "=r"(reg_[0][5]), "=r"(reg_[0][6]), "=r"(reg_[0][7]) - // : "r"(src_addr + 32 * smem_stride_) - // ); - +#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE asm volatile ( "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " "{%0, %1, %2, %3}, [%4];" @@ -636,30 +645,32 @@ __device__ __forceinline__ void ldmatrix_b( : "r"(src_addr) ); - asm volatile ( 
"ldmatrix.sync.aligned.m8n8.x4.shared.b16 " "{%0, %1, %2, %3}, [%4];" : "=r"(reg_[0][4][0]), "=r"(reg_[0][5][0]), "=r"(reg_[0][6][0]), "=r"(reg_[0][7][0]) : "r"(src_addr + 32 * smem_stride_) ); +#else + asm volatile ( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " + "{%0, %1, %2, %3}, [%4];" + : "=r"(reg_[0][0]), "=r"(reg_[0][1]), "=r"(reg_[0][2]), "=r"(reg_[0][3]) + : "r"(src_addr) + ); + + + asm volatile ( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " + "{%0, %1, %2, %3}, [%4];" + : "=r"(reg_[0][4]), "=r"(reg_[0][5]), "=r"(reg_[0][6]), "=r"(reg_[0][7]) + : "r"(src_addr + 32 * smem_stride_) + ); +#endif src_addr ^= 0b10000; - // asm volatile ( - // "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " - // "{%0, %1, %2, %3}, [%4];" - // : "=r"(reg_[1][0]), "=r"(reg_[1][1]), "=r"(reg_[1][2]), "=r"(reg_[1][3]) - // : "r"(src_addr) - // ); - - // asm volatile ( - // "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " - // "{%0, %1, %2, %3}, [%4];" - // : "=r"(reg_[1][4]), "=r"(reg_[1][5]), "=r"(reg_[1][6]), "=r"(reg_[1][7]) - // : "r"(src_addr + 32 * smem_stride_) - // ); - +#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE asm volatile ( "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " "{%0, %1, %2, %3}, [%4];" @@ -673,23 +684,25 @@ __device__ __forceinline__ void ldmatrix_b( : "=r"(reg_[0][4][1]), "=r"(reg_[0][5][1]), "=r"(reg_[0][6][1]), "=r"(reg_[0][7][1]) : "r"(src_addr + 32 * smem_stride_) ); +#else + asm volatile ( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " + "{%0, %1, %2, %3}, [%4];" + : "=r"(reg_[1][0]), "=r"(reg_[1][1]), "=r"(reg_[1][2]), "=r"(reg_[1][3]) + : "r"(src_addr) + ); + + asm volatile ( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " + "{%0, %1, %2, %3}, [%4];" + : "=r"(reg_[1][4]), "=r"(reg_[1][5]), "=r"(reg_[1][6]), "=r"(reg_[1][7]) + : "r"(src_addr + 32 * smem_stride_) + ); +#endif src_addr ^= 0b110000; - // asm volatile ( - // "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " - // "{%0, %1, %2, %3}, [%4];" - // : "=r"(reg_[2][0]), "=r"(reg_[2][1]), "=r"(reg_[2][2]), "=r"(reg_[2][3]) - // : "r"(src_addr) - // ); - - // asm volatile ( - // "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " - // "{%0, %1, %2, %3}, [%4];" - // : "=r"(reg_[2][4]), "=r"(reg_[2][5]), "=r"(reg_[2][6]), "=r"(reg_[2][7]) - // : "r"(src_addr + 32 * smem_stride_) - // ); - +#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE asm volatile ( "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " "{%0, %1, %2, %3}, [%4];" @@ -703,23 +716,25 @@ __device__ __forceinline__ void ldmatrix_b( : "=r"(reg_[1][4][0]), "=r"(reg_[1][5][0]), "=r"(reg_[1][6][0]), "=r"(reg_[1][7][0]) : "r"(src_addr + 32 * smem_stride_) ); +#else + asm volatile ( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " + "{%0, %1, %2, %3}, [%4];" + : "=r"(reg_[2][0]), "=r"(reg_[2][1]), "=r"(reg_[2][2]), "=r"(reg_[2][3]) + : "r"(src_addr) + ); + + asm volatile ( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " + "{%0, %1, %2, %3}, [%4];" + : "=r"(reg_[2][4]), "=r"(reg_[2][5]), "=r"(reg_[2][6]), "=r"(reg_[2][7]) + : "r"(src_addr + 32 * smem_stride_) + ); +#endif src_addr ^= 0b10000; - // asm volatile ( - // "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " - // "{%0, %1, %2, %3}, [%4];" - // : "=r"(reg_[3][0]), "=r"(reg_[3][1]), "=r"(reg_[3][2]), "=r"(reg_[3][3]) - // : "r"(src_addr) - // ); - - // asm volatile ( - // "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " - // "{%0, %1, %2, %3}, [%4];" - // : "=r"(reg_[3][4]), "=r"(reg_[3][5]), "=r"(reg_[3][6]), "=r"(reg_[3][7]) - // : "r"(src_addr + 32 * smem_stride_) - // ); - +#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE asm volatile ( "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " 
"{%0, %1, %2, %3}, [%4];" @@ -733,6 +748,21 @@ __device__ __forceinline__ void ldmatrix_b( : "=r"(reg_[1][4][1]), "=r"(reg_[1][5][1]), "=r"(reg_[1][6][1]), "=r"(reg_[1][7][1]) : "r"(src_addr + 32 * smem_stride_) ); +#else + asm volatile ( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " + "{%0, %1, %2, %3}, [%4];" + : "=r"(reg_[3][0]), "=r"(reg_[3][1]), "=r"(reg_[3][2]), "=r"(reg_[3][3]) + : "r"(src_addr) + ); + + asm volatile ( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " + "{%0, %1, %2, %3}, [%4];" + : "=r"(reg_[3][4]), "=r"(reg_[3][5]), "=r"(reg_[3][6]), "=r"(reg_[3][7]) + : "r"(src_addr + 32 * smem_stride_) + ); +#endif #else GGML_UNUSED(src); GGML_UNUSED(reg); @@ -760,7 +790,11 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, const unsigned int NKPQ = param.n * KPQ; // loop bounds, constexpr where possible allows for loop unrolling +#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE constexpr unsigned int mma_tiles_per_warp_k = 2; +#else + constexpr unsigned int mma_tiles_per_warp_k = 4; +#endif constexpr unsigned int mma_tiles_per_warp_m = WM / MMA_M; constexpr unsigned int mma_tiles_per_warp_n = WN / MMA_N; const unsigned int z = blockIdx.z; @@ -787,14 +821,23 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, // declare register storage // ptx instructions expect uint32_t registers, where each uint32_t is 2 halfs packed together uint32_t acc_register[mma_tiles_per_warp_m][mma_tiles_per_warp_n][2]; +#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE uint32_t A_register[mma_tiles_per_warp_m][mma_tiles_per_warp_k][4]; uint32_t B_register[mma_tiles_per_warp_k][mma_tiles_per_warp_n][2]; +#else + uint32_t A_register[mma_tiles_per_warp_m][mma_tiles_per_warp_k][2]; + uint32_t B_register[mma_tiles_per_warp_k][mma_tiles_per_warp_n]; +#endif // convenience cast to half for register storage half (&acc_register_) [mma_tiles_per_warp_m][mma_tiles_per_warp_n][4] = reinterpret_cast(acc_register); +#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE half (&A_register_) [mma_tiles_per_warp_m][mma_tiles_per_warp_k][8] = reinterpret_cast(A_register); half (&B_register_) [mma_tiles_per_warp_k][mma_tiles_per_warp_n][4] = reinterpret_cast(B_register); - +#else + half (&A_register_) [mma_tiles_per_warp_m][mma_tiles_per_warp_k][4] = reinterpret_cast(A_register); + half (&B_register_) [mma_tiles_per_warp_k][mma_tiles_per_warp_n][2] = reinterpret_cast(B_register); +#endif // accumulators start at 0 for (unsigned int mma_m = 0; mma_m < mma_tiles_per_warp_m; mma_m++){ for (unsigned int mma_n = 0; mma_n < mma_tiles_per_warp_n; mma_n++){ @@ -844,17 +887,7 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, for (unsigned int mma_n = 0; mma_n < mma_tiles_per_warp_n; mma_n++){ #pragma unroll for (unsigned int mma_m = 0; mma_m < mma_tiles_per_warp_m; mma_m++){ - // asm volatile ( - // "mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 " - // "{%0, %1}, " - // "{%2, %3}, " - // "{%4}, " - // "{%5, %6};" - // : "=r"(acc_register[mma_m][mma_n][0]), "=r"(acc_register[mma_m][mma_n][1]) - // : "r"(A_register[mma_m][mma_k][0]), "r"(A_register[mma_m][mma_k][1]), - // "r"(B_register[mma_k][mma_n]) - // "r"(acc_register[mma_m][mma_n][0]), "r"(acc_register[mma_m][mma_n][1]) - // ); +#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE asm volatile ( "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 " "{%0, %1}, " @@ -866,6 +899,19 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, "r"(B_register[mma_k][mma_n][0]), "r"(B_register[mma_k][mma_n][1]) 
"r"(acc_register[mma_m][mma_n][0]), "r"(acc_register[mma_m][mma_n][1]) ); +#else + asm volatile ( + "mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 " + "{%0, %1}, " + "{%2, %3}, " + "{%4}, " + "{%5, %6};" + : "=r"(acc_register[mma_m][mma_n][0]), "=r"(acc_register[mma_m][mma_n][1]) + : "r"(A_register[mma_m][mma_k][0]), "r"(A_register[mma_m][mma_k][1]), + "r"(B_register[mma_k][mma_n]) + "r"(acc_register[mma_m][mma_n][0]), "r"(acc_register[mma_m][mma_n][1]) + ); +#endif } } } diff --git a/tests/test-conv2d.cpp b/tests/test-conv2d.cpp index 90ef1e5237..8ee0747989 100644 --- a/tests/test-conv2d.cpp +++ b/tests/test-conv2d.cpp @@ -714,15 +714,15 @@ int main(void) // for(int i = 0; i < ggml_nelements(wino_res); i++) { // for(int i = 0; i < 26*38; i++) { - for(int i = 0; i < conv2d_data.size(); i++) { - float diff = fabs(im2col_data[i] - conv2d_data[i]); - // if(diff > 0.5) { - printf("(%7.3f, %7.3f, %.2f, %d) \n", - im2col_data[i], conv2d_data[i], - diff, i); - // break; - // } - } + // for(int i = 0; i < conv2d_data.size(); i++) { + // float diff = fabs(im2col_data[i] - conv2d_data[i]); + // // if(diff > 0.5) { + // printf("(%7.3f, %7.3f, %.2f, %d) \n", + // im2col_data[i], conv2d_data[i], + // diff, i); + // // break; + // // } + // } ggml_free(model.ctx); ggml_backend_buffer_free(model.buffer);