diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cu b/ggml/src/ggml-cuda/conv2d-implicit.cu
index 874d40e80b..1fceeb9a6e 100644
--- a/ggml/src/ggml-cuda/conv2d-implicit.cu
+++ b/ggml/src/ggml-cuda/conv2d-implicit.cu
@@ -343,13 +343,15 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
 template <unsigned int mma_tiles_per_warp_m, unsigned int mma_tiles_per_warp_k, unsigned int smem_stride>
 __device__ __forceinline__ void ldmatrix_a(
     const half* src,
-    half (&reg)[mma_tiles_per_warp_m][mma_tiles_per_warp_k][4]
+    half (&reg)[mma_tiles_per_warp_m][mma_tiles_per_warp_k][8]
 ){
 #if __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
-    static_assert(mma_tiles_per_warp_m == 8, "mma_tiles_per_warp_m must be 4");
-    static_assert(mma_tiles_per_warp_k == 4, "mma_tiles_per_warp_k must be 4");
+    static_assert(mma_tiles_per_warp_m == 8, "mma_tiles_per_warp_m must be 8");
+    // static_assert(mma_tiles_per_warp_k == 4, "mma_tiles_per_warp_k must be 4");
+    static_assert(mma_tiles_per_warp_k == 2, "mma_tiles_per_warp_k must be 2");
 
-    uint32_t (&reg_) [mma_tiles_per_warp_m][mma_tiles_per_warp_k][2] = reinterpret_cast<uint32_t (&)[mma_tiles_per_warp_m][mma_tiles_per_warp_k][2]>(reg);
+    // uint32_t (&reg_) [mma_tiles_per_warp_m][mma_tiles_per_warp_k][2] = reinterpret_cast<uint32_t (&)[mma_tiles_per_warp_m][mma_tiles_per_warp_k][2]>(reg);
+    uint32_t (&reg_) [mma_tiles_per_warp_m][mma_tiles_per_warp_k][4] = reinterpret_cast<uint32_t (&)[mma_tiles_per_warp_m][mma_tiles_per_warp_k][4]>(reg);
     unsigned int logical_offset = (threadIdx.x % 32) * smem_stride;
     unsigned int swizzled_offset = logical_offset ^ ((logical_offset & 0b10000000) >> 4);
     swizzled_offset = swizzled_offset ^ ((swizzled_offset & 0b1100000) >> 2);
@@ -390,7 +392,104 @@ __device__ __forceinline__ void ldmatrix_a(
 
     src_addr ^= 0b10000;
 
+    // // 1
+    // asm volatile (
+    //     "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
+    //     "{%0, %1, %2, %3}, [%4];"
+    //     : "=r"(reg_[0][1][0]), "=r"(reg_[0][1][1]), "=r"(reg_[1][1][0]), "=r"(reg_[1][1][1])
+    //     : "r"(src_addr)
+    // );
+
+    // // 1
+    // asm volatile (
+    //     "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
+    //     "{%0, %1, %2, %3}, [%4];"
+    //     : "=r"(reg_[2][1][0]), "=r"(reg_[2][1][1]), "=r"(reg_[3][1][0]), "=r"(reg_[3][1][1])
+    //     : "r"(src_addr + 32 * smem_stride_)
+    // );
+
+    // // 1
+    // asm volatile (
+    //     "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
+    //     "{%0, %1, %2, %3}, [%4];"
+    //     : "=r"(reg_[4][1][0]), "=r"(reg_[4][1][1]), "=r"(reg_[5][1][0]), "=r"(reg_[5][1][1])
+    //     : "r"(src_addr + 64 * smem_stride_)
+    // );
+
+    // // 1
+    // asm volatile (
+    //     "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
+    //     "{%0, %1, %2, %3}, [%4];"
+    //     : "=r"(reg_[6][1][0]), "=r"(reg_[6][1][1]), "=r"(reg_[7][1][0]), "=r"(reg_[7][1][1])
+    //     : "r"(src_addr + 96 * smem_stride_)
+    // );
+
+    asm volatile (
+        "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
+        "{%0, %1, %2, %3}, [%4];"
+        : "=r"(reg_[0][0][2]), "=r"(reg_[0][0][3]), "=r"(reg_[1][0][2]), "=r"(reg_[1][0][3])
+        : "r"(src_addr)
+    );
+    // 1
+    asm volatile (
+        "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
+        "{%0, %1, %2, %3}, [%4];"
+        : "=r"(reg_[2][0][2]), "=r"(reg_[2][0][3]), "=r"(reg_[3][0][2]), "=r"(reg_[3][0][3])
+        : "r"(src_addr + 32 * smem_stride_)
+    );
+
+    // 1
+    asm volatile (
+        "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
+        "{%0, %1, %2, %3}, [%4];"
+        : "=r"(reg_[4][0][2]), "=r"(reg_[4][0][3]), "=r"(reg_[5][0][2]), "=r"(reg_[5][0][3])
+        : "r"(src_addr + 64 * smem_stride_)
+    );
+
+    // 1
+    asm volatile (
+        "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
+        "{%0, %1, %2, %3}, [%4];"
+        : "=r"(reg_[6][0][2]), "=r"(reg_[6][0][3]), "=r"(reg_[7][0][2]), "=r"(reg_[7][0][3])
+        : "r"(src_addr + 96 * smem_stride_)
+    );
+
+    src_addr ^= 0b110000;
+
+    // // 2
+    // asm volatile (
+    //     "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
+    //     "{%0, %1, %2, %3}, [%4];"
+    //     : "=r"(reg_[0][2][0]), "=r"(reg_[0][2][1]), "=r"(reg_[1][2][0]), "=r"(reg_[1][2][1])
+    //     : "r"(src_addr)
+    // );
+
+    // // 2
+    // asm volatile (
+    //     "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
+    //     "{%0, %1, %2, %3}, [%4];"
+    //     : "=r"(reg_[2][2][0]), "=r"(reg_[2][2][1]), "=r"(reg_[3][2][0]), "=r"(reg_[3][2][1])
+    //     : "r"(src_addr + 32 * smem_stride_)
+    // );
+
+    // // 2
+    // asm volatile (
+    //     "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
+    //     "{%0, %1, %2, %3}, [%4];"
+    //     : "=r"(reg_[4][2][0]), "=r"(reg_[4][2][1]), "=r"(reg_[5][2][0]), "=r"(reg_[5][2][1])
+    //     : "r"(src_addr + 64 * smem_stride_)
+    // );
+
+    // // 2
+    // asm volatile (
+    //     "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
+    //     "{%0, %1, %2, %3}, [%4];"
+    //     : "=r"(reg_[6][2][0]), "=r"(reg_[6][2][1]), "=r"(reg_[7][2][0]), "=r"(reg_[7][2][1])
+    //     : "r"(src_addr + 96 * smem_stride_)
+    // );
+
+    // 2
     asm volatile (
         "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
         "{%0, %1, %2, %3}, [%4];"
@@ -398,7 +497,7 @@ __device__ __forceinline__ void ldmatrix_a(
         : "r"(src_addr)
     );
 
-    // 1
+    // 2
     asm volatile (
         "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
         "{%0, %1, %2, %3}, [%4];"
@@ -406,7 +505,7 @@ __device__ __forceinline__ void ldmatrix_a(
         : "r"(src_addr + 32 * smem_stride_)
     );
 
-    // 1
+    // 2
    asm volatile (
        "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
        "{%0, %1, %2, %3}, [%4];"
@@ -414,7 +513,7 @@ __device__ __forceinline__ void ldmatrix_a(
        : "r"(src_addr + 64 * smem_stride_)
    );
 
-    // 1
+    // 2
    asm volatile (
        "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
        "{%0, %1, %2, %3}, [%4];"
@@ -422,46 +521,45 @@ __device__ __forceinline__ void ldmatrix_a(
         : "r"(src_addr + 96 * smem_stride_)
     );
 
-    src_addr ^= 0b110000;
-
-    // 2
-    asm volatile (
-        "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
-        "{%0, %1, %2, %3}, [%4];"
-        : "=r"(reg_[0][2][0]), "=r"(reg_[0][2][1]), "=r"(reg_[1][2][0]), "=r"(reg_[1][2][1])
-        : "r"(src_addr)
-    );
-
-    // 2
-    asm volatile (
-        "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
-        "{%0, %1, %2, %3}, [%4];"
-        : "=r"(reg_[2][2][0]), "=r"(reg_[2][2][1]), "=r"(reg_[3][2][0]), "=r"(reg_[3][2][1])
-        : "r"(src_addr + 32 * smem_stride_)
-    );
-
-    // 2
-    asm volatile (
-        "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
-        "{%0, %1, %2, %3}, [%4];"
-        : "=r"(reg_[4][2][0]), "=r"(reg_[4][2][1]), "=r"(reg_[5][2][0]), "=r"(reg_[5][2][1])
-        : "r"(src_addr + 64 * smem_stride_)
-    );
-
-    // 2
-    asm volatile (
-        "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
-        "{%0, %1, %2, %3}, [%4];"
-        : "=r"(reg_[6][2][0]), "=r"(reg_[6][2][1]), "=r"(reg_[7][2][0]), "=r"(reg_[7][2][1])
-        : "r"(src_addr + 96 * smem_stride_)
-    );
 
     src_addr ^= 0b10000;
 
+    // // 3
+    // asm volatile (
+    //     "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
+    //     "{%0, %1, %2, %3}, [%4];"
+    //     : "=r"(reg_[0][3][0]), "=r"(reg_[0][3][1]), "=r"(reg_[1][3][0]), "=r"(reg_[1][3][1])
+    //     : "r"(src_addr)
+    // );
+
+    // // 3
+    // asm volatile (
+    //     "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
+    //     "{%0, %1, %2, %3}, [%4];"
+    //     : "=r"(reg_[2][3][0]), "=r"(reg_[2][3][1]), "=r"(reg_[3][3][0]), "=r"(reg_[3][3][1])
+    //     : "r"(src_addr + 32 * smem_stride_)
+    // );
+
+    // // 3
+    // asm volatile (
+    //     "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
+    //     "{%0, %1, %2, %3}, [%4];"
+    //     : "=r"(reg_[4][3][0]), "=r"(reg_[4][3][1]), "=r"(reg_[5][3][0]), "=r"(reg_[5][3][1])
+    //     : "r"(src_addr + 64 * smem_stride_)
+    // );
+
+    // // 3
+    // asm volatile (
+    //     "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
+    //     "{%0, %1, %2, %3}, [%4];"
+    //     : "=r"(reg_[6][3][0]), "=r"(reg_[6][3][1]), "=r"(reg_[7][3][0]), "=r"(reg_[7][3][1])
+    //     : "r"(src_addr + 96 * smem_stride_)
+    // );
+
     // 3
     asm volatile (
         "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
         "{%0, %1, %2, %3}, [%4];"
-        : "=r"(reg_[0][3][0]), "=r"(reg_[0][3][1]), "=r"(reg_[1][3][0]), "=r"(reg_[1][3][1])
+        : "=r"(reg_[0][1][2]), "=r"(reg_[0][1][3]), "=r"(reg_[1][1][2]), "=r"(reg_[1][1][3])
         : "r"(src_addr)
     );
@@ -469,7 +567,7 @@ __device__ __forceinline__ void ldmatrix_a(
     asm volatile (
         "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
         "{%0, %1, %2, %3}, [%4];"
-        : "=r"(reg_[2][3][0]), "=r"(reg_[2][3][1]), "=r"(reg_[3][3][0]), "=r"(reg_[3][3][1])
+        : "=r"(reg_[2][1][2]), "=r"(reg_[2][1][3]), "=r"(reg_[3][1][2]), "=r"(reg_[3][1][3])
         : "r"(src_addr + 32 * smem_stride_)
     );
@@ -477,7 +575,7 @@ __device__ __forceinline__ void ldmatrix_a(
     asm volatile (
         "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
         "{%0, %1, %2, %3}, [%4];"
-        : "=r"(reg_[4][3][0]), "=r"(reg_[4][3][1]), "=r"(reg_[5][3][0]), "=r"(reg_[5][3][1])
+        : "=r"(reg_[4][1][2]), "=r"(reg_[4][1][3]), "=r"(reg_[5][1][2]), "=r"(reg_[5][1][3])
         : "r"(src_addr + 64 * smem_stride_)
     );
@@ -485,7 +583,7 @@ __device__ __forceinline__ void ldmatrix_a(
     asm volatile (
         "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
         "{%0, %1, %2, %3}, [%4];"
-        : "=r"(reg_[6][3][0]), "=r"(reg_[6][3][1]), "=r"(reg_[7][3][0]), "=r"(reg_[7][3][1])
+        : "=r"(reg_[6][1][2]), "=r"(reg_[6][1][3]), "=r"(reg_[7][1][2]), "=r"(reg_[7][1][3])
         : "r"(src_addr + 96 * smem_stride_)
     );
 #else
@@ -498,14 +596,17 @@ __device__ __forceinline__ void ldmatrix_a(
 template <unsigned int mma_tiles_per_warp_k, unsigned int mma_tiles_per_warp_n, unsigned int smem_stride>
 __device__ __forceinline__ void ldmatrix_b(
     const half* src,
-    half (&reg)[mma_tiles_per_warp_k][mma_tiles_per_warp_n][2]
+    // half (&reg)[mma_tiles_per_warp_k][mma_tiles_per_warp_n][2]
+    half (&reg)[mma_tiles_per_warp_k][mma_tiles_per_warp_n][4]
 ){
 #if __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
-    static_assert(mma_tiles_per_warp_k == 4, "mma_tiles_per_warp_k must be 4");
+    // static_assert(mma_tiles_per_warp_k == 4, "mma_tiles_per_warp_k must be 4");
+    static_assert(mma_tiles_per_warp_k == 2, "mma_tiles_per_warp_k must be 2");
     static_assert(mma_tiles_per_warp_n == 8, "mma_tiles_per_warp_n must be 8");
 
-    uint32_t (&reg_) [4][8] = reinterpret_cast<uint32_t (&)[4][8]>(reg);
+    // uint32_t (&reg_) [4][8] = reinterpret_cast<uint32_t (&)[4][8]>(reg);
+    uint32_t (&reg_) [2][8][2] = reinterpret_cast<uint32_t (&)[2][8][2]>(reg);
     unsigned int logical_offset = (threadIdx.x % 32) * smem_stride;
     unsigned int swizzled_offset = logical_offset ^ ((logical_offset & 0b10000000) >> 4);
     swizzled_offset = swizzled_offset ^ ((swizzled_offset & 0b1100000) >> 2);
@@ -513,10 +614,25 @@ __device__ __forceinline__ void ldmatrix_b(
     constexpr unsigned int smem_stride_ = smem_stride * sizeof(half); // convert stride to bytes
 
     // 0
+    // asm volatile (
+    //     "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
+    //     "{%0, %1, %2, %3}, [%4];"
+    //     : "=r"(reg_[0][0]), "=r"(reg_[0][1]), "=r"(reg_[0][2]), "=r"(reg_[0][3])
+    //     : "r"(src_addr)
+    // );
+
+
+    // asm volatile (
+    //     "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
+    //     "{%0, %1, %2, %3}, [%4];"
+    //     : "=r"(reg_[0][4]), "=r"(reg_[0][5]), "=r"(reg_[0][6]), "=r"(reg_[0][7])
+    //     : "r"(src_addr + 32 * smem_stride_)
+    // );
+
     asm volatile (
         "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
         "{%0, %1, %2, %3}, [%4];"
-        : "=r"(reg_[0][0]), "=r"(reg_[0][1]), "=r"(reg_[0][2]), "=r"(reg_[0][3])
+        : "=r"(reg_[0][0][0]), "=r"(reg_[0][1][0]), "=r"(reg_[0][2][0]), "=r"(reg_[0][3][0])
         : "r"(src_addr)
     );
@@ -524,55 +640,97 @@ __device__ __forceinline__ void ldmatrix_b(
     asm volatile (
         "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
         "{%0, %1, %2, %3}, [%4];"
-        : "=r"(reg_[0][4]), "=r"(reg_[0][5]), "=r"(reg_[0][6]), "=r"(reg_[0][7])
+        : "=r"(reg_[0][4][0]), "=r"(reg_[0][5][0]), "=r"(reg_[0][6][0]), "=r"(reg_[0][7][0])
         : "r"(src_addr + 32 * smem_stride_)
     );
 
     src_addr ^= 0b10000;
 
+    // asm volatile (
+    //     "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
+    //     "{%0, %1, %2, %3}, [%4];"
+    //     : "=r"(reg_[1][0]), "=r"(reg_[1][1]), "=r"(reg_[1][2]), "=r"(reg_[1][3])
+    //     : "r"(src_addr)
+    // );
+
+    // asm volatile (
+    //     "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
+    //     "{%0, %1, %2, %3}, [%4];"
+    //     : "=r"(reg_[1][4]), "=r"(reg_[1][5]), "=r"(reg_[1][6]), "=r"(reg_[1][7])
+    //     : "r"(src_addr + 32 * smem_stride_)
+    // );
+
     asm volatile (
-        "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
-        "{%0, %1, %2, %3}, [%4];"
-        : "=r"(reg_[1][0]), "=r"(reg_[1][1]), "=r"(reg_[1][2]), "=r"(reg_[1][3])
-        : "r"(src_addr)
-    );
+        "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
+        "{%0, %1, %2, %3}, [%4];"
+        : "=r"(reg_[0][0][1]), "=r"(reg_[0][1][1]), "=r"(reg_[0][2][1]), "=r"(reg_[0][3][1])
+        : "r"(src_addr)
+    );
 
     asm volatile (
         "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
         "{%0, %1, %2, %3}, [%4];"
-        : "=r"(reg_[1][4]), "=r"(reg_[1][5]), "=r"(reg_[1][6]), "=r"(reg_[1][7])
+        : "=r"(reg_[0][4][1]), "=r"(reg_[0][5][1]), "=r"(reg_[0][6][1]), "=r"(reg_[0][7][1])
         : "r"(src_addr + 32 * smem_stride_)
     );
 
     src_addr ^= 0b110000;
 
+    // asm volatile (
+    //     "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
+    //     "{%0, %1, %2, %3}, [%4];"
+    //     : "=r"(reg_[2][0]), "=r"(reg_[2][1]), "=r"(reg_[2][2]), "=r"(reg_[2][3])
+    //     : "r"(src_addr)
+    // );
+
+    // asm volatile (
+    //     "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
+    //     "{%0, %1, %2, %3}, [%4];"
+    //     : "=r"(reg_[2][4]), "=r"(reg_[2][5]), "=r"(reg_[2][6]), "=r"(reg_[2][7])
+    //     : "r"(src_addr + 32 * smem_stride_)
+    // );
+
     asm volatile (
-        "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
-        "{%0, %1, %2, %3}, [%4];"
-        : "=r"(reg_[2][0]), "=r"(reg_[2][1]), "=r"(reg_[2][2]), "=r"(reg_[2][3])
-        : "r"(src_addr)
-    );
+        "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
+        "{%0, %1, %2, %3}, [%4];"
+        : "=r"(reg_[1][0][0]), "=r"(reg_[1][1][0]), "=r"(reg_[1][2][0]), "=r"(reg_[1][3][0])
+        : "r"(src_addr)
+    );
 
     asm volatile (
         "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
         "{%0, %1, %2, %3}, [%4];"
-        : "=r"(reg_[2][4]), "=r"(reg_[2][5]), "=r"(reg_[2][6]), "=r"(reg_[2][7])
+        : "=r"(reg_[1][4][0]), "=r"(reg_[1][5][0]), "=r"(reg_[1][6][0]), "=r"(reg_[1][7][0])
         : "r"(src_addr + 32 * smem_stride_)
     );
 
     src_addr ^= 0b10000;
 
+    // asm volatile (
+    //     "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
+    //     "{%0, %1, %2, %3}, [%4];"
+    //     : "=r"(reg_[3][0]), "=r"(reg_[3][1]), "=r"(reg_[3][2]), "=r"(reg_[3][3])
+    //     : "r"(src_addr)
+    // );
+
+    // asm volatile (
+    //     "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
+    //     "{%0, %1, %2, %3}, [%4];"
+    //     : "=r"(reg_[3][4]), "=r"(reg_[3][5]), "=r"(reg_[3][6]), "=r"(reg_[3][7])
+    //     : "r"(src_addr + 32 * smem_stride_)
+    // );
+
     asm volatile (
-        "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
-        "{%0, %1, %2, %3}, [%4];"
-        : "=r"(reg_[3][0]), "=r"(reg_[3][1]), "=r"(reg_[3][2]), "=r"(reg_[3][3])
-        : "r"(src_addr)
-    );
+        "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
+        "{%0, %1, %2, %3}, [%4];"
+        : "=r"(reg_[1][0][1]), "=r"(reg_[1][1][1]), "=r"(reg_[1][2][1]), "=r"(reg_[1][3][1])
+        : "r"(src_addr)
+    );
 
     asm volatile (
         "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
         "{%0, %1, %2, %3}, [%4];"
-        : "=r"(reg_[3][4]), "=r"(reg_[3][5]), "=r"(reg_[3][6]), "=r"(reg_[3][7])
+        : "=r"(reg_[1][4][1]), "=r"(reg_[1][5][1]), "=r"(reg_[1][6][1]), "=r"(reg_[1][7][1])
         : "r"(src_addr + 32 * smem_stride_)
     );
 #else
@@ -602,7 +760,7 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
     const unsigned int NKPQ = param.n * KPQ;
 
     // loop bounds, constexpr where possible allows for loop unrolling
-    constexpr unsigned int mma_tiles_per_warp_k = 4;
+    constexpr unsigned int mma_tiles_per_warp_k = 2;
     constexpr unsigned int mma_tiles_per_warp_m = WM / MMA_M;
     constexpr unsigned int mma_tiles_per_warp_n = WN / MMA_N;
     const unsigned int z = blockIdx.z;
@@ -629,13 +787,13 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
     // declare register storage
     // ptx instructions expect uint32_t registers, where each uint32_t is 2 halfs packed together
     uint32_t acc_register[mma_tiles_per_warp_m][mma_tiles_per_warp_n][2];
-    uint32_t A_register[mma_tiles_per_warp_m][mma_tiles_per_warp_k][2];
-    uint32_t B_register[mma_tiles_per_warp_k][mma_tiles_per_warp_n];
+    uint32_t A_register[mma_tiles_per_warp_m][mma_tiles_per_warp_k][4];
+    uint32_t B_register[mma_tiles_per_warp_k][mma_tiles_per_warp_n][2];
 
     // convenience cast to half for register storage
     half (&acc_register_) [mma_tiles_per_warp_m][mma_tiles_per_warp_n][4] = reinterpret_cast<half (&)[mma_tiles_per_warp_m][mma_tiles_per_warp_n][4]>(acc_register);
-    half (&A_register_) [mma_tiles_per_warp_m][mma_tiles_per_warp_k][4] = reinterpret_cast<half (&)[mma_tiles_per_warp_m][mma_tiles_per_warp_k][4]>(A_register);
-    half (&B_register_) [mma_tiles_per_warp_k][mma_tiles_per_warp_n][2] = reinterpret_cast<half (&)[mma_tiles_per_warp_k][mma_tiles_per_warp_n][2]>(B_register);
+    half (&A_register_) [mma_tiles_per_warp_m][mma_tiles_per_warp_k][8] = reinterpret_cast<half (&)[mma_tiles_per_warp_m][mma_tiles_per_warp_k][8]>(A_register);
+    half (&B_register_) [mma_tiles_per_warp_k][mma_tiles_per_warp_n][4] = reinterpret_cast<half (&)[mma_tiles_per_warp_k][mma_tiles_per_warp_n][4]>(B_register);
 
     // accumulators start at 0
     for (unsigned int mma_m = 0; mma_m < mma_tiles_per_warp_m; mma_m++){
@@ -685,15 +843,26 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
             for (unsigned int mma_n = 0; mma_n < mma_tiles_per_warp_n; mma_n++){
                 #pragma unroll
                 for (unsigned int mma_m = 0; mma_m < mma_tiles_per_warp_m; mma_m++){
+                    // asm volatile (
+                    //     "mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 "
+                    //     "{%0, %1}, "
+                    //     "{%2, %3}, "
+                    //     "{%4}, "
+                    //     "{%5, %6};"
+                    //     : "=r"(acc_register[mma_m][mma_n][0]), "=r"(acc_register[mma_m][mma_n][1])
+                    //     : "r"(A_register[mma_m][mma_k][0]), "r"(A_register[mma_m][mma_k][1]),
+                    //       "r"(B_register[mma_k][mma_n]),
+                    //       "r"(acc_register[mma_m][mma_n][0]), "r"(acc_register[mma_m][mma_n][1])
+                    // );
                     asm volatile (
-                        "mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 "
+                        "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 "
                         "{%0, %1}, "
-                        "{%2, %3}, "
-                        "{%4}, "
-                        "{%5, %6};"
+                        "{%2, %3, %4, %5}, "
+                        "{%6, %7}, "
+                        "{%8, %9};"
                         : "=r"(acc_register[mma_m][mma_n][0]), "=r"(acc_register[mma_m][mma_n][1])
-                        : "r"(A_register[mma_m][mma_k][0]), "r"(A_register[mma_m][mma_k][1]),
-                        "r"(B_register[mma_k][mma_n]),
+                        : "r"(A_register[mma_m][mma_k][0]), "r"(A_register[mma_m][mma_k][1]), "r"(A_register[mma_m][mma_k][2]), "r"(A_register[mma_m][mma_k][3]),
+                        "r"(B_register[mma_k][mma_n][0]), "r"(B_register[mma_k][mma_n][1]),
                         "r"(acc_register[mma_m][mma_n][0]), "r"(acc_register[mma_m][mma_n][1])
                     );
                 }
diff --git a/tests/test-conv2d.cpp b/tests/test-conv2d.cpp
index 23a3aab366..8ee0747989 100644
--- a/tests/test-conv2d.cpp
+++ b/tests/test-conv2d.cpp
@@ -301,7 +301,7 @@ static std::vector<std::tuple<int, int, int, int, int, int>> configs = {
     // std::make_tuple(960,320,104,152,3,3),
     // std::make_tuple(1280,1280,26,38,3,3),
     // std::make_tuple(1920,640,32,32,3,3)
-    // std::make_tuple(1280,1280,16,16,3,3),
+    std::make_tuple(1280,1280,16,16,3,3),
     // std::make_tuple(320,640,32,32,3,3),
     // std::make_tuple(4,320,96,128,3,3),
     // std::make_tuple(320,4,96,128,3,3),
@@ -317,7 +317,7 @@ static std::vector<std::tuple<int, int, int, int, int, int>> configs = {
     // std::make_tuple(1920,1280,26,38,3,3),
     // std::make_tuple(2560,1280,26,38,3,3),
     // std::make_tuple(320,1280,26,38,3,3),
-    std::make_tuple(512,512,104,152,3,3),
+    // std::make_tuple(512,512,104,152,3,3),
     // std::make_tuple(512,512,208,304,3,3),
     // std::make_tuple(512,256,416,608,3,3),
     // std::make_tuple(256,128,832,1216,3,3),
@@ -653,7 +653,7 @@ int main(void)
 
     int k = 0;
 
-    // for (auto c : configs_sdxl_768){
+    // for (auto c : configs_sdxl_1024){
     for (auto c : configs){
         test_model model;
         load_model(model, std::get<0>(c), std::get<1>(c), std::get<2>(c),