diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cu b/ggml/src/ggml-cuda/conv2d-implicit.cu
index f646cf73b3..1866517775 100644
--- a/ggml/src/ggml-cuda/conv2d-implicit.cu
+++ b/ggml/src/ggml-cuda/conv2d-implicit.cu
@@ -732,6 +732,236 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
 
 #if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
 
+template <unsigned int mma_tiles_per_warp_m, unsigned int mma_tiles_per_warp_k, unsigned int smem_stride>
+__device__ __forceinline__ void ldmatrix_a(
+    const half* src,
+    half (&reg)[mma_tiles_per_warp_m][mma_tiles_per_warp_k][4]
+)
+{
+    static_assert(mma_tiles_per_warp_m == 8, "mma_tiles_per_warp_m must be 8");
+    static_assert(mma_tiles_per_warp_k == 4, "mma_tiles_per_warp_k must be 4");
+
+    uint32_t (&reg_)[mma_tiles_per_warp_m][mma_tiles_per_warp_k][2] = reinterpret_cast<uint32_t(&)[mma_tiles_per_warp_m][mma_tiles_per_warp_k][2]>(reg);
+    unsigned int logical_offset = (threadIdx.x % 32) * smem_stride;
+    unsigned int swizzled_offset = logical_offset ^ ((logical_offset & 0b10000000) >> 4);
+    swizzled_offset = swizzled_offset ^ ((swizzled_offset & 0b1100000) >> 2);
+    uint32_t src_addr = cvta_to_shared_u32(src + swizzled_offset);
+    constexpr unsigned int smem_stride_ = smem_stride * sizeof(half); // convert stride to bytes
+
+    // k tile 0, m tiles 0-1
+    asm volatile (
+        "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
+        "{%0, %1, %2, %3}, [%4];"
+        : "=r"(reg_[0][0][0]), "=r"(reg_[0][0][1]), "=r"(reg_[1][0][0]), "=r"(reg_[1][0][1])
+        : "r"(src_addr)
+    );
+
+    // k tile 0, m tiles 2-3
+    asm volatile (
+        "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
+        "{%0, %1, %2, %3}, [%4];"
+        : "=r"(reg_[2][0][0]), "=r"(reg_[2][0][1]), "=r"(reg_[3][0][0]), "=r"(reg_[3][0][1])
+        : "r"(src_addr + 32 * smem_stride_)
+    );
+
+    // k tile 0, m tiles 4-5
+    asm volatile (
+        "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
+        "{%0, %1, %2, %3}, [%4];"
+        : "=r"(reg_[4][0][0]), "=r"(reg_[4][0][1]), "=r"(reg_[5][0][0]), "=r"(reg_[5][0][1])
+        : "r"(src_addr + 64 * smem_stride_)
+    );
+
+    // k tile 0, m tiles 6-7
+    asm volatile (
+        "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
+        "{%0, %1, %2, %3}, [%4];"
+        : "=r"(reg_[6][0][0]), "=r"(reg_[6][0][1]), "=r"(reg_[7][0][0]), "=r"(reg_[7][0][1])
+        : "r"(src_addr + 96 * smem_stride_)
+    );
+
+    src_addr ^= 0b10000;
+
+    // k tile 1, m tiles 0-1
+    asm volatile (
+        "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
+        "{%0, %1, %2, %3}, [%4];"
+        : "=r"(reg_[0][1][0]), "=r"(reg_[0][1][1]), "=r"(reg_[1][1][0]), "=r"(reg_[1][1][1])
+        : "r"(src_addr)
+    );
+
+    // k tile 1, m tiles 2-3
+    asm volatile (
+        "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
+        "{%0, %1, %2, %3}, [%4];"
+        : "=r"(reg_[2][1][0]), "=r"(reg_[2][1][1]), "=r"(reg_[3][1][0]), "=r"(reg_[3][1][1])
+        : "r"(src_addr + 32 * smem_stride_)
+    );
+
+    // k tile 1, m tiles 4-5
+    asm volatile (
+        "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
+        "{%0, %1, %2, %3}, [%4];"
+        : "=r"(reg_[4][1][0]), "=r"(reg_[4][1][1]), "=r"(reg_[5][1][0]), "=r"(reg_[5][1][1])
+        : "r"(src_addr + 64 * smem_stride_)
+    );
+
+    // k tile 1, m tiles 6-7
+    asm volatile (
+        "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
+        "{%0, %1, %2, %3}, [%4];"
+        : "=r"(reg_[6][1][0]), "=r"(reg_[6][1][1]), "=r"(reg_[7][1][0]), "=r"(reg_[7][1][1])
+        : "r"(src_addr + 96 * smem_stride_)
+    );
+
+    src_addr ^= 0b110000;
+
+    // k tile 2, m tiles 0-1
+    asm volatile (
+        "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
+        "{%0, %1, %2, %3}, [%4];"
+        : "=r"(reg_[0][2][0]), "=r"(reg_[0][2][1]), "=r"(reg_[1][2][0]), "=r"(reg_[1][2][1])
+        : "r"(src_addr)
+    );
+
+    // k tile 2, m tiles 2-3
+    asm volatile (
+        "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
+        "{%0, %1, %2, %3}, [%4];"
+        : "=r"(reg_[2][2][0]), "=r"(reg_[2][2][1]), "=r"(reg_[3][2][0]), "=r"(reg_[3][2][1])
+        : "r"(src_addr + 32 * smem_stride_)
+    );
+
+    // k tile 2, m tiles 4-5
+    asm volatile (
+        "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
+        "{%0, %1, %2, %3}, [%4];"
+        : "=r"(reg_[4][2][0]), "=r"(reg_[4][2][1]), "=r"(reg_[5][2][0]), "=r"(reg_[5][2][1])
+        : "r"(src_addr + 64 * smem_stride_)
+    );
+
+    // k tile 2, m tiles 6-7
+    asm volatile (
+        "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
+        "{%0, %1, %2, %3}, [%4];"
+        : "=r"(reg_[6][2][0]), "=r"(reg_[6][2][1]), "=r"(reg_[7][2][0]), "=r"(reg_[7][2][1])
+        : "r"(src_addr + 96 * smem_stride_)
+    );
+    src_addr ^= 0b10000;
+
+    // k tile 3, m tiles 0-1
+    asm volatile (
+        "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
+        "{%0, %1, %2, %3}, [%4];"
+        : "=r"(reg_[0][3][0]), "=r"(reg_[0][3][1]), "=r"(reg_[1][3][0]), "=r"(reg_[1][3][1])
+        : "r"(src_addr)
+    );
+
+    // k tile 3, m tiles 2-3
+    asm volatile (
+        "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
+        "{%0, %1, %2, %3}, [%4];"
+        : "=r"(reg_[2][3][0]), "=r"(reg_[2][3][1]), "=r"(reg_[3][3][0]), "=r"(reg_[3][3][1])
+        : "r"(src_addr + 32 * smem_stride_)
+    );
+
+    // k tile 3, m tiles 4-5
+    asm volatile (
+        "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
+        "{%0, %1, %2, %3}, [%4];"
+        : "=r"(reg_[4][3][0]), "=r"(reg_[4][3][1]), "=r"(reg_[5][3][0]), "=r"(reg_[5][3][1])
+        : "r"(src_addr + 64 * smem_stride_)
+    );
+
+    // k tile 3, m tiles 6-7
+    asm volatile (
+        "ldmatrix.sync.aligned.m8n8.x4.shared.b16 "
+        "{%0, %1, %2, %3}, [%4];"
+        : "=r"(reg_[6][3][0]), "=r"(reg_[6][3][1]), "=r"(reg_[7][3][0]), "=r"(reg_[7][3][1])
+        : "r"(src_addr + 96 * smem_stride_)
+    );
+
+}
+
+template <unsigned int mma_tiles_per_warp_k, unsigned int mma_tiles_per_warp_n, unsigned int smem_stride>
+__device__ __forceinline__ void ldmatrix_b(
+    const half* src,
+    half (&reg)[mma_tiles_per_warp_k][mma_tiles_per_warp_n][2]
+)
+{
+    static_assert(mma_tiles_per_warp_k == 4, "mma_tiles_per_warp_k must be 4");
+    static_assert(mma_tiles_per_warp_n == 8, "mma_tiles_per_warp_n must be 8");
+
+    uint32_t (&reg_)[4][8] = reinterpret_cast<uint32_t(&)[4][8]>(reg);
+    const unsigned int logical_offset = ((threadIdx.x % 8) * smem_stride) + (((threadIdx.x % 32) / 8) * 8);
+    unsigned int swizzled_offset = logical_offset ^ ((logical_offset & 0b11100000000) >> 5);
+    uint32_t src_addr = cvta_to_shared_u32(src + swizzled_offset);
+    constexpr unsigned int smem_stride_ = smem_stride * sizeof(half); // convert stride to bytes
+
+    // k tile 0, n tiles 0-3 and 4-7
+    asm volatile (
+        "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 "
+        "{%0, %1, %2, %3}, [%4];"
+        : "=r"(reg_[0][0]), "=r"(reg_[0][1]), "=r"(reg_[0][2]), "=r"(reg_[0][3])
+        : "r"(src_addr)
+    );
+
+    asm volatile (
+        "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 "
+        "{%0, %1, %2, %3}, [%4];"
+        : "=r"(reg_[0][4]), "=r"(reg_[0][5]), "=r"(reg_[0][6]), "=r"(reg_[0][7])
+        : "r"(src_addr ^ 0b1000000)
+    );
+
+    src_addr += 8 * smem_stride_;
+
+    // k tile 1, n tiles 0-3 and 4-7
+    asm volatile (
+        "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 "
+        "{%0, %1, %2, %3}, [%4];"
+        : "=r"(reg_[1][0]), "=r"(reg_[1][1]), "=r"(reg_[1][2]), "=r"(reg_[1][3])
+        : "r"(src_addr)
+    );
+
+    asm volatile (
+        "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 "
+        "{%0, %1, %2, %3}, [%4];"
+        : "=r"(reg_[1][4]), "=r"(reg_[1][5]), "=r"(reg_[1][6]), "=r"(reg_[1][7])
+        : "r"(src_addr ^ 0b1000000)
+    );
+
+    src_addr += 8 * smem_stride_;
+
+    // k tile 2, n tiles 0-3 and 4-7
+    asm volatile (
+        "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 "
+        "{%0, %1, %2, %3}, [%4];"
+        : "=r"(reg_[2][0]), "=r"(reg_[2][1]), "=r"(reg_[2][2]), "=r"(reg_[2][3])
+        : "r"(src_addr)
+    );
+
+    asm volatile (
+        "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 "
+        "{%0, %1, %2, %3}, [%4];"
+        : "=r"(reg_[2][4]), "=r"(reg_[2][5]), "=r"(reg_[2][6]), "=r"(reg_[2][7])
+        : "r"(src_addr ^ 0b1000000)
+    );
+
+    src_addr += 8 * smem_stride_;
+
+    // k tile 3, n tiles 0-3 and 4-7
+    asm volatile (
+        "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 "
+        "{%0, %1, %2, %3}, [%4];"
+        : "=r"(reg_[3][0]), "=r"(reg_[3][1]), "=r"(reg_[3][2]), "=r"(reg_[3][3])
+        : "r"(src_addr)
+    );
+
+ "{%0, %1, %2, %3}, [%4];" + : "=r"(reg_[3][4]), "=r"(reg_[3][5]), "=r"(reg_[3][6]), "=r"(reg_[3][7]) + : "r"(src_addr ^ 0b1000000) + ); + +} + template static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, @@ -742,6 +972,9 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, constexpr unsigned int MMA_N = 8; const unsigned int K = param.c * param.r * param.s; + const uint PQ = param.Oh * param.Ow; + const uint inChannelOffset = param.c * param.w; + const uint weightKOffset = param.c * param.r * param.s; // for convenience/readability in index calculations const unsigned int A_stride = K; @@ -801,12 +1034,13 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, static_assert(NUM_THREADS == 256); float4 A_gmem_cache_reg[4]; float4 B_gmem_cache_reg[4]; - + // prefetch the first block tile of A,B into shared memory - half* A_block_gmem = input + (block_m * BM * A_stride); - half* B_block_gmem = weight + (block_n * BN); - tileMemcpySwizzleA(A_block_gmem, A_block_smem, K); - tileMemcpySwizzle(B_block_gmem, B_block_smem, N); +// half* A_block_gmem = input + (block_m * BM * A_stride); + half* A_block_gmem = input; + half* B_block_gmem = weight + (block_n * weightKOffset); + tileMemcpySwizzleA(A_block_gmem, A_block_smem, inChannelOffset, param); + tileMemcpySwizzleB(B_block_gmem, B_block_smem, weightKOffset, param); // construct const pointers to warp tiles for use inside the inner loop @@ -864,7 +1098,7 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, offset_direction = -1 * offset_direction; tileMemcpySwizzleStoreA(A_gmem_cache_reg, A_block_smem); - tileMemcpySwizzleStore(B_gmem_cache_reg, B_block_smem); + tileMemcpySwizzleStoreB (B_gmem_cache_reg, B_block_smem); } } @@ -1026,6 +1260,7 @@ void ggml_cuda_op_conv2d_implicit(ggml_backend_cuda_context & ctx, ggml_tensor * param_t params = { B, IC, IH, IW, OC, KH, KW, ST_Y, ST_X, PD_Y, PD_X, DL_Y, DL_X, OH, OW }; params.SC_fastdiv = init_fastdiv_values(KW*IC); params.OW_fastdiv = init_fastdiv_values(OW); + params.OHOW_fastdiv = init_fastdiv_values(OW*OH); params.C_fastdiv = init_fastdiv_values(IC); params.RS_fastdiv = init_fastdiv_values(KW*KH); params.S_fastdiv = init_fastdiv_values(KW); diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cuh b/ggml/src/ggml-cuda/conv2d-implicit.cuh index 0a5fc4ab6a..9c15d72c8f 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cuh +++ b/ggml/src/ggml-cuda/conv2d-implicit.cuh @@ -23,31 +23,70 @@ typedef struct{ uint3 C_fastdiv; uint3 RS_fastdiv; uint3 S_fastdiv; + uint3 OHOW_fastdiv; } param_t; #if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE // same as above, but writes are swizzled to avoid bank conflicts when shared memory is read later in the kernel template -__device__ __forceinline__ void tileMemcpySwizzle( +unsigned int NUM_THREADS> +__device__ __forceinline__ void tileMemcpySwizzleB( half* src, half* dst, const unsigned int src_stride ) { - constexpr unsigned int SWIZZLE_MASK = 0b111 << SWIZZLE_BITS; + // constexpr unsigned int SWIZZLE_MASK = 0b111 << SWIZZLE_BITS; + + // // reinterpret input/output as float4 + // float4* src_float4 = reinterpret_cast(src); + // float4* dst_float4 = reinterpret_cast(dst); + // const unsigned int src_stride_vectorized = src_stride / 8; + + // // # of threads is multiple of # of columns in the tile + // constexpr unsigned int TILE_COLS_VECTORIZED = TILE_COLS / 8; + // static_assert(NUM_THREADS % TILE_COLS_VECTORIZED == 0); + + // // flatten out 2d grid of threads into in order 
of increasing threadIdx.x + // const unsigned int thread_idx = threadIdx.y * blockDim.x + threadIdx.x; + + // // assign each thread a row/column in the tile, calculate how many iterations we need + // // to cover the whole tile + // constexpr unsigned int ROW_STEP = NUM_THREADS / TILE_COLS_VECTORIZED; + // constexpr unsigned int NUM_ITERS = TILE_ROWS / ROW_STEP; + // unsigned int thread_row = thread_idx / TILE_COLS_VECTORIZED; + // const unsigned int thread_col = thread_idx % TILE_COLS_VECTORIZED; + + // #pragma unroll + // for (unsigned int i = 0; i < NUM_ITERS; i++) + // { + // // apply swizzle to the dst index + // const unsigned int src_index = thread_row * src_stride_vectorized + thread_col; + // unsigned int dst_index = thread_row * TILE_COLS_VECTORIZED + thread_col; + // dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK) >> SWIZZLE_BITS); + // if (thread_col * 8 < param.k && start_k + innerColA * 4 < end_k){ + // float4 tmp = reinterpret_cast(&src[thread_row * src_stride_vectorized + thread_col*8)[0]; + // dst_float4[dst_index] = src_float4[src_index]; + // }else{ // read 4 halves + // dst_float4[dst_index] = make_float4(0.f, 0.f, 0.f, 0.f); + // } + // thread_row += ROW_STEP; + // } + + constexpr unsigned int SWIZZLE_MASK_1 = 0b10000; + constexpr unsigned int SWIZZLE_BITS_1 = 4; + constexpr unsigned int SWIZZLE_MASK_2 = 0b1100; + constexpr unsigned int SWIZZLE_BITS_2 = 2; + constexpr unsigned int TILE_COLS = 32; // reinterpret input/output as float4 - float4* src_float4 = reinterpret_cast(src); + // float4* src_float4 = reinterpret_cast(src); float4* dst_float4 = reinterpret_cast(dst); - const unsigned int src_stride_vectorized = src_stride / 8; + // const unsigned int src_stride_vectorized = src_stride / 8; // # of threads is multiple of # of columns in the tile constexpr unsigned int TILE_COLS_VECTORIZED = TILE_COLS / 8; static_assert(NUM_THREADS % TILE_COLS_VECTORIZED == 0); - // flatten out 2d grid of threads into in order of increasing threadIdx.x const unsigned int thread_idx = threadIdx.y * blockDim.x + threadIdx.x; @@ -57,15 +96,24 @@ __device__ __forceinline__ void tileMemcpySwizzle( constexpr unsigned int NUM_ITERS = TILE_ROWS / ROW_STEP; unsigned int thread_row = thread_idx / TILE_COLS_VECTORIZED; const unsigned int thread_col = thread_idx % TILE_COLS_VECTORIZED; - + // TODO: next block_k loop + const uint curR = fastdiv(thread_col*8, param.SC_fastdiv); // channel offset + const uint curS = fastdiv(fastmodulo(thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset + const uint curC = fastmodulo(fastmodulo(thread_col*8, param.SC_fastdiv), param.C_fastdiv); // + #pragma unroll for (unsigned int i = 0; i < NUM_ITERS; i++) { // apply swizzle to the dst index - const unsigned int src_index = thread_row * src_stride_vectorized + thread_col; + const unsigned int src_index = thread_row * src_stride + thread_col * 8; unsigned int dst_index = thread_row * TILE_COLS_VECTORIZED + thread_col; - dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK) >> SWIZZLE_BITS); - dst_float4[dst_index] = src_float4[src_index]; + dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_1) >> SWIZZLE_BITS_1); + dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_2) >> SWIZZLE_BITS_2); + if (thread_row < param.k && curR < param.R && curS < param.S && curC < param.c){ + dst_float4[dst_index] = reinterpret_cast(&src[src_index])[0]; + }else{ // read 4 halves + dst_float4[dst_index] = make_float4(0.f, 0.f, 0.f, 0.f); + } thread_row += ROW_STEP; } } @@ -77,7 +125,9 @@ unsigned int NUM_THREADS> 
 __device__ __forceinline__ void tileMemcpySwizzleA(
     half* src,
     half* dst,
-    const unsigned int src_stride
+    // const unsigned int src_stride,
+    const unsigned int inChannelOffset,
+    param_t param
 )
 {
     constexpr unsigned int SWIZZLE_MASK_1 = 0b10000;
@@ -87,14 +137,13 @@ __device__ __forceinline__ void tileMemcpySwizzleA(
     constexpr unsigned int TILE_COLS = 32;
 
     // reinterpret input/output as float4
-    float4* src_float4 = reinterpret_cast<float4*>(src);
+    // float4* src_float4 = reinterpret_cast<float4*>(src);
     float4* dst_float4 = reinterpret_cast<float4*>(dst);
-    const unsigned int src_stride_vectorized = src_stride / 8;
+    // const unsigned int src_stride_vectorized = src_stride / 8;
 
     // # of threads is multiple of # of columns in the tile
     constexpr unsigned int TILE_COLS_VECTORIZED = TILE_COLS / 8;
     static_assert(NUM_THREADS % TILE_COLS_VECTORIZED == 0);
-
     // flatten out 2d grid of threads into in order of increasing threadIdx.x
     const unsigned int thread_idx = threadIdx.y * blockDim.x + threadIdx.x;
 
@@ -104,16 +153,35 @@ __device__ __forceinline__ void tileMemcpySwizzleA(
     constexpr unsigned int NUM_ITERS = TILE_ROWS / ROW_STEP;
     unsigned int thread_row = thread_idx / TILE_COLS_VECTORIZED;
     const unsigned int thread_col = thread_idx % TILE_COLS_VECTORIZED;
-
+
+
     #pragma unroll
     for (unsigned int i = 0; i < NUM_ITERS; i++)
     {
+        unsigned int gemm_i = blockDim.y * TILE_ROWS + thread_row;
+        unsigned int n = fastdiv(gemm_i, param.OHOW_fastdiv);
+        unsigned int npq_res = fastmodulo(gemm_i, param.OHOW_fastdiv);
+        int posh_ori = fastdiv(npq_res, param.OW_fastdiv) * param.u - param.p;
+        int posw_ori = fastmodulo(npq_res, param.OW_fastdiv) * param.v - param.q;
+        unsigned int inOffset = n * param.c * param.h * param.w;
+        // TODO: next block_k loop
+        const uint curR = fastdiv(thread_col*8, param.SC_fastdiv);                                 // filter row (r)
+        const uint curS = fastdiv(fastmodulo(thread_col*8, param.SC_fastdiv), param.C_fastdiv);    // filter column (s)
+        const uint curC = fastmodulo(fastmodulo(thread_col*8, param.SC_fastdiv), param.C_fastdiv); // input channel (c)
+        int curH = posh_ori + curR * param.d_h;    // input h
+        int curW = posw_ori + curS * param.d_w;    // input w
         // apply swizzle to the dst index
-        const unsigned int src_index = thread_row * src_stride_vectorized + thread_col;
         unsigned int dst_index = thread_row * TILE_COLS_VECTORIZED + thread_col;
         dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_1) >> SWIZZLE_BITS_1);
         dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_2) >> SWIZZLE_BITS_2);
-        dst_float4[dst_index] = src_float4[src_index];
+        if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h &&
+            curR < param.r && curS < param.s && curC < param.c) {
+            // const unsigned int src_index = thread_row * src_stride_vectorized + thread_col;
+            const unsigned int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC;
+            dst_float4[dst_index] = reinterpret_cast<float4*>(&src[inOffset + inOffsetTmp])[0];
+        } else { // out of range: zero-fill the eight halves of this float4
+            dst_float4[dst_index] = make_float4(0.f, 0.f, 0.f, 0.f);
+        }
         thread_row += ROW_STEP;
     }
 }
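
The standalone host-side sketch below is not part of the patch; it restates, with plain / and %, the two ideas the new copy/load helpers rely on: the two-step XOR swizzle applied to float4 indices of the 32-half-wide shared-memory tiles (SWIZZLE_MASK_1/2 in tileMemcpySwizzleA/B), and the implicit-GEMM coordinate mapping that tileMemcpySwizzleA performs with fastdiv/fastmodulo. Helper names such as swizzle_float4_index and decompose are illustrative only and do not exist in the codebase.

// Standalone host-side sketch (illustration only, compiles with any C++14 compiler).
#include <cassert>
#include <cstdio>

// dst_index is a float4 (8-half) slot inside a TILE_ROWS x 32-half shared-memory tile.
static unsigned int swizzle_float4_index(unsigned int dst_index) {
    dst_index ^= (dst_index & 0b10000u) >> 4; // SWIZZLE_MASK_1 >> SWIZZLE_BITS_1
    dst_index ^= (dst_index & 0b1100u)  >> 2; // SWIZZLE_MASK_2 >> SWIZZLE_BITS_2
    return dst_index;
}

// A GEMM row is an output pixel (n, oh, ow); a GEMM column is a filter tap (r, s, c).
struct Coords { int n, oh, ow, r, s, c; };

static Coords decompose(int gemm_i, int gemm_k, int OH, int OW, int C, int S) {
    Coords o;
    o.n     = gemm_i / (OH * OW);     // fastdiv(gemm_i, OHOW_fastdiv)
    int npq = gemm_i % (OH * OW);     // fastmodulo(gemm_i, OHOW_fastdiv)
    o.oh    = npq / OW;
    o.ow    = npq % OW;
    o.r     = gemm_k / (S * C);       // curR: filter row
    o.s     = (gemm_k % (S * C)) / C; // curS: filter column
    o.c     = gemm_k % C;             // curC: input channel
    return o;
}

int main() {
    // The swizzle only permutes float4 slots within each aligned 32-slot (8-row) block,
    // so every thread still writes a unique slot and later reads avoid bank conflicts.
    for (unsigned int base = 0; base < 256; base += 32) {
        unsigned int seen = 0;
        for (unsigned int i = 0; i < 32; ++i) {
            seen |= 1u << (swizzle_float4_index(base + i) - base);
        }
        assert(seen == 0xffffffffu);
    }

    // Illustrative sizes: NHWC input with C=16 channels, 3x3 filter, 8x8 output.
    Coords x = decompose(/*gemm_i=*/71, /*gemm_k=*/100, /*OH=*/8, /*OW=*/8, /*C=*/16, /*S=*/3);
    std::printf("n=%d oh=%d ow=%d | r=%d s=%d c=%d\n", x.n, x.oh, x.ow, x.r, x.s, x.c);

    // With stride (u,v), padding (p,q) and dilation (d_h,d_w) the input coordinates become
    //   h = oh*u - p + r*d_h,   w = ow*v - q + s*d_w,
    // which is the curH/curW computation guarded by the bounds check in tileMemcpySwizzleA.
    return 0;
}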