diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cu b/ggml/src/ggml-cuda/conv2d-implicit.cu new file mode 100644 index 0000000000..37144970d3 --- /dev/null +++ b/ggml/src/ggml-cuda/conv2d-implicit.cu @@ -0,0 +1,1243 @@ +// #include +#include +#include "ggml.h" +#include "common.cuh" +#include "convert.cuh" +#include "cp-async.cuh" +#include "conv2d-implicit.cuh" + + +typedef unsigned int uint; + +constexpr uint WARPSIZE = 32; +#define CUDA_NCHW_2_NHWC_TILE_DIM 32 +#define CUDA_NCHW_2_NHWC_BLOCK_NM 8 +#define CUDA_NCHW_2_NHWC_BLOCK_ROWS 8 +#define CUDA_NCHW_2_NHWC_BLOCK_C 64 + + +//currently not use; in future for split-k kernels +template +static __global__ void reduce_f32(const src_T * __restrict__ x, dst_T * __restrict__ dst, const int ncols, const int nrows) { + const int row = blockIdx.x; + const int col = threadIdx.x; + + float sum = 0.0f; + if (row * blockDim.x + col < ncols) { + for (int i = 0; i < nrows; ++i){ + sum += ggml_cuda_cast(x[i * ncols + row * blockDim.x + col]); + } + dst[row * blockDim.x + col] = ggml_cuda_cast(sum); + } +} + +constexpr uint32_t filter_swizzle_mask(uint32_t n, uint32_t m) { + if (n <= 1) return 1; + n--; + n |= n >> 1; + n |= n >> 2; + n |= n >> 4; + n |= n >> 8; + n |= n >> 16; + int count = 0; + while ((m >>= 1) != 0) + ++count; + return n << count; +} + +template +static __global__ void NCHW2NHWC(const src_T *src, dst_T * dst, const int ne, const int ne00, const int ne01){ + + const int64_t nmat = ne / (ne00 * ne01); + const int64_t n = ne00 * ne01; + + int x = blockIdx.x * CUDA_NCHW_2_NHWC_TILE_DIM + threadIdx.x; + int y = blockIdx.y * CUDA_NCHW_2_NHWC_TILE_DIM + threadIdx.y; + int tx = blockIdx.y * CUDA_NCHW_2_NHWC_TILE_DIM + threadIdx.x; // transpose block offset + int ty = blockIdx.x * CUDA_NCHW_2_NHWC_TILE_DIM + threadIdx.y; + + __shared__ src_T tile[CUDA_NCHW_2_NHWC_TILE_DIM][CUDA_NCHW_2_NHWC_TILE_DIM]; +#pragma unroll + for(int i = 0; i < CUDA_NCHW_2_NHWC_BLOCK_NM; ++i){ + + const unsigned int imat = blockIdx.z * CUDA_NCHW_2_NHWC_BLOCK_NM + i; + if(imat >= nmat) + break; +#pragma unroll + for (int j = 0; j < CUDA_NCHW_2_NHWC_TILE_DIM; j += CUDA_NCHW_2_NHWC_BLOCK_ROWS){ + if(x < ne01 && y + j < ne00){ + const int row = threadIdx.y+j; + const int col = threadIdx.x ^ row; + tile[row][col] = src[imat*n + (y+j)*ne01 + x]; + } + } + __syncthreads(); +#pragma unroll + for (int j = 0; j < CUDA_NCHW_2_NHWC_TILE_DIM; j += CUDA_NCHW_2_NHWC_BLOCK_ROWS){ + if(ty + j < ne01 && tx < ne00){ + const int col = (threadIdx.y+j) ^ threadIdx.x; + dst[imat*n + (ty+j)*ne00 + tx] = ggml_cuda_cast(tile[threadIdx.x][col]); + } + } + } +} + +template +static __global__ void NCHW2NHWC(const src_T *src, dst_T * dst, const int ne, const int ne00, const int ne01, param_t P){ + + const int64_t n = ne00 * ne01; + + const unsigned int tx = threadIdx.x; + const unsigned int bx = blockIdx.x; + const unsigned int by = blockIdx.y; + + const unsigned int blk = (bx+1) * blk_c <= ne00 ? 
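+// --- illustrative aside: reference for the layout conversion -----------------
+// The NCHW2NHWC kernels repack tensors so that channels become the innermost
+// (contiguous) dimension, which is what the implicit-GEMM kernels later in this
+// file want for vectorized 8-half loads. A minimal single-threaded host sketch
+// of the same permutation, ignoring the f32->f16 conversion (nchw_to_nhwc_ref
+// is an illustrative name, not used elsewhere in this file):
+static void nchw_to_nhwc_ref(const float * src, float * dst, int N, int C, int H, int W) {
+    for (int n = 0; n < N; ++n)
+        for (int c = 0; c < C; ++c)
+            for (int h = 0; h < H; ++h)
+                for (int w = 0; w < W; ++w)
+                    // NCHW offset ((n*C + c)*H + h)*W + w  ->  NHWC offset ((n*H + h)*W + w)*C + c
+                    dst[((size_t)(n*H + h)*W + w)*C + c] = src[((size_t)(n*C + c)*H + h)*W + w];
+}
+// The tiled kernel above transposes each ne00 x ne01 matrix the same way,
+// staging a 32x32 tile in shared memory and XOR-ing the column index so that
+// the transposed reads stay free of bank conflicts.
+// ----------------------------------------------------------------------------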
blk_c : ne00 - bx * blk_c; + + __shared__ src_T tile[rs*blk_c]; + + +#pragma unroll + for (unsigned int j = 0; j < rs; j++){ + const int i = j * blk + tx; + const unsigned int row = fastmodulo(i, P.RS_fastdiv); + const unsigned int col = fastdiv(i, P.RS_fastdiv); + const unsigned int src_index = by*n + bx * blk_c * rs + j * blk + tx; + unsigned int idx = row * blk_c + col; + idx = idx ^ ((idx & mask) >> 4); + if (src_index < ne && tx < blk) { + tile[idx] = src[src_index]; + } + } + __syncthreads(); +#pragma unroll + for (unsigned int j = 0; j < rs; j++){ + const unsigned int dst_index = by*n + j*ne00 + bx*blk_c + tx; + if(dst_index < ne && tx < blk){ + unsigned int idx = j*blk_c + tx; + idx = idx ^ ((idx & mask) >> 4); + dst[dst_index] = ggml_cuda_cast(tile[idx]); + } + } +} + + + +template +static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, + const T * __restrict__ kernel, + float * __restrict__ output, + const param_t param) { + + __shared__ char smem[sizeof(float) * (TM*TN*NUM_THREADS) <= sizeof(float) * 2 * (BM+PAD) * BK + sizeof(T)*2*BK * (BN+PAD) ? + sizeof(float)*2*(BM+PAD)*BK + sizeof(T)*2*BK*(BN+PAD) : sizeof(float) * (TM*TN*NUM_THREADS)]; + T *smemweight = reinterpret_cast(smem); + float *smeminput = reinterpret_cast(smem + 2 * BK * (BN+PAD) * sizeof(T)); + + const uint tx = threadIdx.x; + const uint bx = blockIdx.x; + const uint by = blockIdx.y; + + const uint PQ = param.Oh * param.Ow; + const uint CHW = param.c * param.h * param.w; + + // Warp tile + const uint lane_id = tx % WARPSIZE; + const uint warp_id = tx / WARPSIZE; + const int mma_tid_x = warp_id / (BN / WN); + const int mma_tid_y = warp_id % (BN / WN); + + // size of the warp subtile + constexpr uint WMITER = (WM * WN) / (WARPSIZE * TM * TN * WNITER); + constexpr uint WSUBM = WM / WMITER; // 64/2=32 + constexpr uint WSUBN = WN / WNITER; // 32/2=16 + + // Placement of the thread in the warp subtile + const uint threadColInWarp = lane_id % (WSUBN / TN); // i%(16/4) + const uint threadRowInWarp = lane_id / (WSUBN / TN); // i/4 + + int z = blockIdx.z; + + int inChannelOffset = layout == 0 ? param.c * param.w : param.h * param.w; + int weightKOffset = param.c * param.r * param.s; + + const uint ks = (ksplit > 0) ? (weightKOffset + ksplit - 1) / ksplit : weightKOffset; + const uint start_k = (ksplit > 0)? 
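+// --- illustrative aside: the implicit-GEMM view of the convolution -----------
+// The kernel above computes conv2d as a GEMM without materializing im2col:
+// GEMM rows are output pixels (n, oh, ow), GEMM columns are output channels k,
+// and the reduction dimension runs over (c, r, s). A scalar reference of one
+// output element for the channels-innermost case (layout == 0); the kernel also
+// handles the NCHW case via its layout parameter. Illustrative helper, both
+// operands shown as float for clarity:
+static float conv2d_implicit_ref(const float * input_nhwc, const float * kernel_krsc,
+                                 const param_t & P, int n, int k, int oh, int ow) {
+    float acc = 0.0f;
+    for (unsigned int r = 0; r < P.r; ++r) {
+        for (unsigned int s = 0; s < P.s; ++s) {
+            const int ih = (int)(oh*P.u) - (int)P.p + (int)(r*P.d_h);
+            const int iw = (int)(ow*P.v) - (int)P.q + (int)(s*P.d_w);
+            if (ih < 0 || ih >= (int)P.h || iw < 0 || iw >= (int)P.w) {
+                continue; // implicit zero padding
+            }
+            for (unsigned int c = 0; c < P.c; ++c) {
+                acc += input_nhwc [((size_t)(n*P.h + ih)*P.w + iw)*P.c + c]
+                     * kernel_krsc[((size_t)(k*P.r + r )*P.s + s )*P.c + c];
+            }
+        }
+    }
+    return acc; // = output[n][k][oh][ow], which the epilogue writes in NCHW order
+}
+// ----------------------------------------------------------------------------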
z * ks: 0; + const uint end_k = min(start_k + ks, weightKOffset); + + int write_flag = 1; + T weight_frag[2][WNITER * TN]; + float input_frag[2][WMITER * TM] = {0.f}; + float output_frag[WMITER * TM * WNITER * TN] = {0.f}; + + // calculating the indices that this thread will load into SMEM + // we'll load 128bit / 32bit = 4 elements per thread at each step + const uint innerRowA = tx / (BK / 4); + const uint innerColA = tx % (BK / 4); + constexpr uint rowStrideA = (NUM_THREADS * 4) / BK; + +// ldg + loadFilter + (kernel, smemweight, by, innerRowA, innerColA, weightKOffset, + start_k, end_k, param); + + loadInput + (input, smeminput, bx, innerRowA, innerColA, + start_k, end_k, PQ, CHW, inChannelOffset, param); + + __syncthreads(); + + // lds + const uint input_lds_addr = mma_tid_x * WM; +#pragma unroll + for (uint wSubRowIdx = 0; wSubRowIdx < WMITER; ++wSubRowIdx) +#pragma unroll + for (uint i = 0; i < TM; ++i) + input_frag[0][wSubRowIdx * TM + i] = smeminput[input_lds_addr + wSubRowIdx * WSUBM + + threadRowInWarp * TM + i]; + + const uint weight_lds_addr = mma_tid_y * WN; +#pragma unroll + for (uint wSubColIdx = 0; wSubColIdx < WNITER; ++wSubColIdx) +#pragma unroll + for (uint i = 0; i < TN; ++i) + weight_frag[0][wSubColIdx * TN + i] = smemweight[weight_lds_addr + wSubColIdx * WSUBN + + threadColInWarp * TN + i]; + + for (int crs = start_k; crs < end_k; crs += BK) { + + int load_flag = write_flag ^ 1; +#pragma unroll + for (int subcrs = 0; subcrs < BK - 1; ++subcrs) + { + +#pragma unroll + for (uint wSubColIdx = 0; wSubColIdx < WNITER; ++wSubColIdx) +#pragma unroll + for (uint i = 0; i < TN; ++i) + weight_frag[(subcrs + 1) % 2][wSubColIdx * TN + i] = smemweight[load_flag * (BN+PAD) * BK + + (subcrs + 1) * (BN+PAD) + weight_lds_addr + wSubColIdx * WSUBN + threadColInWarp * TN + i]; +#pragma unroll + for (uint wSubRowIdx = 0; wSubRowIdx < WMITER; ++wSubRowIdx) +#pragma unroll + for (uint i = 0; i < TM; ++i) + input_frag[(subcrs + 1) % 2][wSubRowIdx * TM + i] = smeminput[load_flag * (BM+PAD) * BK + + (subcrs + 1) * (BM+PAD) + input_lds_addr + wSubRowIdx * WSUBM + threadRowInWarp * TM + i]; + + // execute warptile matmul +#pragma unroll + for (uint wSubRowIdx = 0; wSubRowIdx < WMITER; ++wSubRowIdx) { +#pragma unroll + for (uint wSubColIdx = 0; wSubColIdx < WNITER; ++wSubColIdx) { + // calculate per-thread results +#pragma unroll + for (uint resIdxM = 0; resIdxM < TM; ++resIdxM) { +#pragma unroll + for (uint resIdxN = 0; resIdxN < TN; ++resIdxN) { + output_frag[(wSubRowIdx * TM + resIdxM) * (WNITER * TN) + + (wSubColIdx * TN) + resIdxN] += + input_frag[subcrs % 2][wSubRowIdx * TM + resIdxM] * + ggml_cuda_cast(weight_frag[subcrs % 2][wSubColIdx * TN + resIdxN]); + } + } + } + } + } + // ldg + + loadFilter + (kernel, &smemweight[write_flag * (BN+PAD) * BK], by, innerRowA, innerColA, weightKOffset, + crs+BK, end_k, param); + + loadInput + (input, &smeminput[write_flag * (BM+PAD) * BK], bx, innerRowA, innerColA, + crs + BK, end_k, PQ, CHW, inChannelOffset, param); + + __syncthreads(); + + write_flag ^= 1; + +#pragma unroll + for (uint wSubRowIdx = 0; wSubRowIdx < WMITER; ++wSubRowIdx) +#pragma unroll + for (uint i = 0; i < TM; ++i) + input_frag[0][wSubRowIdx * TM + i] = smeminput[(load_flag ^ 1) * (BM+PAD) * BK + + input_lds_addr + wSubRowIdx * WSUBM + threadRowInWarp * TM + i]; +#pragma unroll + for (uint wSubColIdx = 0; wSubColIdx < WNITER; ++wSubColIdx) +#pragma unroll + for (uint i = 0; i < TN; ++i) + weight_frag[0][wSubColIdx * TN + i] = smemweight[(load_flag ^ 1) * (BN+PAD) * BK + + 
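+// --- illustrative aside: the double-buffering schedule used above ------------
+// The crs loop keeps two copies of each shared-memory tile and two register
+// fragments, so the loads for iteration i+1 overlap the math of iteration i.
+// Stripped of the conv-specific indexing, the schedule is (a sketch, not the
+// kernel itself):
+//
+//   load tile 0 into smem[0];  __syncthreads();
+//   for (int i = 0; i < num_tiles; ++i) {
+//       if (i + 1 < num_tiles) {
+//           load tile i+1 into smem[(i + 1) % 2];  // global -> smem prefetch
+//       }
+//       compute on smem[i % 2];                    // tile loaded last iteration
+//       __syncthreads();                           // before smem[(i+1)%2] is read next round
+//   }
+//
+// write_flag selects the shared-memory half being filled and
+// load_flag = write_flag ^ 1 selects the half being consumed; the same idea is
+// applied one level down to the per-subcrs register fragments
+// (frag[subcrs % 2] consumed while frag[(subcrs + 1) % 2] is loaded).
+// ----------------------------------------------------------------------------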
weight_lds_addr + wSubColIdx * WSUBN + threadColInWarp * TN + i]; +#pragma unroll + for (uint wSubRowIdx = 0; wSubRowIdx < WMITER; ++wSubRowIdx) { +#pragma unroll + for (uint wSubColIdx = 0; wSubColIdx < WNITER; ++wSubColIdx) { + // calculate per-thread results +#pragma unroll + for (uint resIdxM = 0; resIdxM < TM; ++resIdxM) { +#pragma unroll + for (uint resIdxN = 0; resIdxN < TN; ++resIdxN) { + output_frag[(wSubRowIdx * TM + resIdxM) * (WNITER * TN) + + (wSubColIdx * TN) + resIdxN] += + input_frag[1][wSubRowIdx * TM + resIdxM] * + ggml_cuda_cast(weight_frag[1][wSubColIdx * TN + resIdxN]); + } + } + } + } + } + + // reuse smem + float *smemoutput = reinterpret_cast(smem); + + const uint output_lds_addr = warp_id * WSUBM * WSUBN + lane_id; + const uint output_sts_addr = mma_tid_x * BN / WN * TM * TN * WARPSIZE + mma_tid_y * TM * TN * WARPSIZE + + threadColInWarp * TN * WSUBM + threadRowInWarp * TM; + const uint m_idx = by * BN + mma_tid_y * WN; + const uint n_idx = bx * BM + mma_tid_x * WM; + +#pragma unroll + for (int i = 0; i < WMITER; ++i) + { +#pragma unroll + for (int j = 0; j < WNITER; ++j) + { + __syncthreads(); + +#pragma unroll + for (int subi = 0; subi < TM; ++subi) + { +#pragma unroll + for (int subj = 0; subj < TN; ++subj) + { + // output sts + smemoutput[output_sts_addr + subj * WSUBM + subi] = + output_frag[(i * TM + subi) * (WNITER * TN) + j * TN + subj]; + } + } + __syncthreads(); +#pragma unroll + for (int subk = 0; subk < TM * TN; ++subk){ + const uint row = m_idx + j * WSUBN + (lane_id + subk * WARPSIZE) / WSUBM; + const uint gemm_i = n_idx + i * WSUBM + (lane_id + subk * WARPSIZE) % WSUBM; + const int n = (ksplit > 0) ? gemm_i / PQ : z; + const int col = (ksplit > 0) ? gemm_i % PQ : gemm_i; + if (n < param.n && row < param.k && col < param.Oh * param.Ow){ + const uint outOffset = ksplit > 0 ? 
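+// --- illustrative aside: split-K partial outputs ------------------------------
+// When ksplit > 0 each blockIdx.z handles a slice of the C*R*S reduction and
+// writes a full partial output tensor at offset z * (N*K*Oh*Ow); the partials
+// are then summed by reduce_f32 at the top of this file. A scalar sketch of
+// that second pass, with nelems = N*K*Oh*Ow (reduce_ref is an illustrative
+// name):
+static void reduce_ref(const float * partials, float * out, int64_t nelems, int nsplits) {
+    for (int64_t i = 0; i < nelems; ++i) {
+        float sum = 0.0f;
+        for (int s = 0; s < nsplits; ++s) {
+            sum += partials[(int64_t)s * nelems + i]; // partials are stored split-major
+        }
+        out[i] = sum;
+    }
+}
+// ----------------------------------------------------------------------------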
+ z * param.n * param.k * param.Oh * param.Ow + n * param.k * param.Oh * param.Ow + + row * param.Oh * param.Ow + col : + z * param.k * param.Oh * param.Ow + row * param.Oh * param.Ow + col; + output[outOffset] = smemoutput[output_lds_addr + subk * WARPSIZE]; + } + } + } + } +} + + + +template +__device__ __forceinline__ void ldmatrix_a( + const half* src, + half (®)[mma_tiles_per_warp_m][mma_tiles_per_warp_k][4] +){ +#ifdef CP_ASYNC_AVAILABLE + static_assert(mma_tiles_per_warp_m == 8, "mma_tiles_per_warp_m must be 8"); + static_assert(mma_tiles_per_warp_k == 4, "mma_tiles_per_warp_k must be 4"); + + uint32_t (®_) [mma_tiles_per_warp_m][mma_tiles_per_warp_k][2] = reinterpret_cast(reg); + + unsigned int logical_offset = (threadIdx.x % 32) * smem_stride; + unsigned int swizzled_offset = logical_offset ^ ((logical_offset & 0b10000000) >> 4); + swizzled_offset = swizzled_offset ^ ((swizzled_offset & 0b1100000) >> 2); + uint32_t src_addr = ggml_cuda_cvta_generic_to_shared(src + swizzled_offset); + constexpr unsigned int smem_stride_ = smem_stride * sizeof(half); // convert stride to bytes + + // 0 + asm volatile ( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " + "{%0, %1, %2, %3}, [%4];" + : "=r"(reg_[0][0][0]), "=r"(reg_[0][0][1]), "=r"(reg_[1][0][0]), "=r"(reg_[1][0][1]) + : "r"(src_addr) + ); + + // 0 + asm volatile ( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " + "{%0, %1, %2, %3}, [%4];" + : "=r"(reg_[2][0][0]), "=r"(reg_[2][0][1]), "=r"(reg_[3][0][0]), "=r"(reg_[3][0][1]) + : "r"(src_addr + 32 * smem_stride_) + ); + + // 0 + asm volatile ( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " + "{%0, %1, %2, %3}, [%4];" + : "=r"(reg_[4][0][0]), "=r"(reg_[4][0][1]), "=r"(reg_[5][0][0]), "=r"(reg_[5][0][1]) + : "r"(src_addr + 64 * smem_stride_) + ); + + // 0 + asm volatile ( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " + "{%0, %1, %2, %3}, [%4];" + : "=r"(reg_[6][0][0]), "=r"(reg_[6][0][1]), "=r"(reg_[7][0][0]), "=r"(reg_[7][0][1]) + : "r"(src_addr + 96 * smem_stride_) + ); + + src_addr ^= 0b10000; + + // 1 + + asm volatile ( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " + "{%0, %1, %2, %3}, [%4];" + : "=r"(reg_[0][1][0]), "=r"(reg_[0][1][1]), "=r"(reg_[1][1][0]), "=r"(reg_[1][1][1]) + : "r"(src_addr) + ); + + // 1 + asm volatile ( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " + "{%0, %1, %2, %3}, [%4];" + : "=r"(reg_[2][1][0]), "=r"(reg_[2][1][1]), "=r"(reg_[3][1][0]), "=r"(reg_[3][1][1]) + : "r"(src_addr + 32 * smem_stride_) + ); + + // 1 + asm volatile ( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " + "{%0, %1, %2, %3}, [%4];" + : "=r"(reg_[4][1][0]), "=r"(reg_[4][1][1]), "=r"(reg_[5][1][0]), "=r"(reg_[5][1][1]) + : "r"(src_addr + 64 * smem_stride_) + ); + + // 1 + asm volatile ( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " + "{%0, %1, %2, %3}, [%4];" + : "=r"(reg_[6][1][0]), "=r"(reg_[6][1][1]), "=r"(reg_[7][1][0]), "=r"(reg_[7][1][1]) + : "r"(src_addr + 96 * smem_stride_) + ); + + src_addr ^= 0b110000; + + // 2 + asm volatile ( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " + "{%0, %1, %2, %3}, [%4];" + : "=r"(reg_[0][2][0]), "=r"(reg_[0][2][1]), "=r"(reg_[1][2][0]), "=r"(reg_[1][2][1]) + : "r"(src_addr) + ); + + // 2 + asm volatile ( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " + "{%0, %1, %2, %3}, [%4];" + : "=r"(reg_[2][2][0]), "=r"(reg_[2][2][1]), "=r"(reg_[3][2][0]), "=r"(reg_[3][2][1]) + : "r"(src_addr + 32 * smem_stride_) + ); + + // 2 + asm volatile ( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " + "{%0, %1, %2, %3}, [%4];" + : "=r"(reg_[4][2][0]), "=r"(reg_[4][2][1]), 
"=r"(reg_[5][2][0]), "=r"(reg_[5][2][1]) + : "r"(src_addr + 64 * smem_stride_) + ); + + // 2 + asm volatile ( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " + "{%0, %1, %2, %3}, [%4];" + : "=r"(reg_[6][2][0]), "=r"(reg_[6][2][1]), "=r"(reg_[7][2][0]), "=r"(reg_[7][2][1]) + : "r"(src_addr + 96 * smem_stride_) + ); + src_addr ^= 0b10000; + + // 3 + + asm volatile ( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " + "{%0, %1, %2, %3}, [%4];" + : "=r"(reg_[0][3][0]), "=r"(reg_[0][3][1]), "=r"(reg_[1][3][0]), "=r"(reg_[1][3][1]) + : "r"(src_addr) + ); + + // 3 + asm volatile ( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " + "{%0, %1, %2, %3}, [%4];" + : "=r"(reg_[2][3][0]), "=r"(reg_[2][3][1]), "=r"(reg_[3][3][0]), "=r"(reg_[3][3][1]) + : "r"(src_addr + 32 * smem_stride_) + ); + + // 3 + asm volatile ( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " + "{%0, %1, %2, %3}, [%4];" + : "=r"(reg_[4][3][0]), "=r"(reg_[4][3][1]), "=r"(reg_[5][3][0]), "=r"(reg_[5][3][1]) + : "r"(src_addr + 64 * smem_stride_) + ); + + // 3 + asm volatile ( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " + "{%0, %1, %2, %3}, [%4];" + : "=r"(reg_[6][3][0]), "=r"(reg_[6][3][1]), "=r"(reg_[7][3][0]), "=r"(reg_[7][3][1]) + : "r"(src_addr + 96 * smem_stride_) + ); + +#else + GGML_UNUSED(src); + GGML_UNUSED(reg); + NO_DEVICE_CODE; +#endif +} + +template +__device__ __forceinline__ void ldmatrix_b( + const half* src, + half (®)[mma_tiles_per_warp_k][mma_tiles_per_warp_n][2] +){ +#ifdef CP_ASYNC_AVAILABLE + static_assert(mma_tiles_per_warp_k == 4, "mma_tiles_per_warp_k must be 4"); + static_assert(mma_tiles_per_warp_n == 8, "mma_tiles_per_warp_n must be 8"); + + uint32_t (®_) [4][8] = reinterpret_cast(reg); + + unsigned int logical_offset = (threadIdx.x % 32) * smem_stride; + unsigned int swizzled_offset = logical_offset ^ ((logical_offset & 0b10000000) >> 4); + swizzled_offset = swizzled_offset ^ ((swizzled_offset & 0b1100000) >> 2); + uint32_t src_addr = ggml_cuda_cvta_generic_to_shared(src + swizzled_offset); + constexpr unsigned int smem_stride_ = smem_stride * sizeof(half); // convert stride to bytes + + // 0 + + asm volatile ( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " + "{%0, %1, %2, %3}, [%4];" + : "=r"(reg_[0][0]), "=r"(reg_[0][1]), "=r"(reg_[0][2]), "=r"(reg_[0][3]) + : "r"(src_addr) + ); + + + asm volatile ( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " + "{%0, %1, %2, %3}, [%4];" + : "=r"(reg_[0][4]), "=r"(reg_[0][5]), "=r"(reg_[0][6]), "=r"(reg_[0][7]) + : "r"(src_addr + 32 * smem_stride_) + ); + + src_addr ^= 0b10000; + + + asm volatile ( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " + "{%0, %1, %2, %3}, [%4];" + : "=r"(reg_[1][0]), "=r"(reg_[1][1]), "=r"(reg_[1][2]), "=r"(reg_[1][3]) + : "r"(src_addr) + ); + + asm volatile ( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " + "{%0, %1, %2, %3}, [%4];" + : "=r"(reg_[1][4]), "=r"(reg_[1][5]), "=r"(reg_[1][6]), "=r"(reg_[1][7]) + : "r"(src_addr + 32 * smem_stride_) + ); + + src_addr ^= 0b110000; + + + asm volatile ( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " + "{%0, %1, %2, %3}, [%4];" + : "=r"(reg_[2][0]), "=r"(reg_[2][1]), "=r"(reg_[2][2]), "=r"(reg_[2][3]) + : "r"(src_addr) + ); + + asm volatile ( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " + "{%0, %1, %2, %3}, [%4];" + : "=r"(reg_[2][4]), "=r"(reg_[2][5]), "=r"(reg_[2][6]), "=r"(reg_[2][7]) + : "r"(src_addr + 32 * smem_stride_) + ); + + + src_addr ^= 0b10000; + + + asm volatile ( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " + "{%0, %1, %2, %3}, [%4];" + : "=r"(reg_[3][0]), "=r"(reg_[3][1]), "=r"(reg_[3][2]), 
"=r"(reg_[3][3]) + : "r"(src_addr) + ); + + asm volatile ( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 " + "{%0, %1, %2, %3}, [%4];" + : "=r"(reg_[3][4]), "=r"(reg_[3][5]), "=r"(reg_[3][6]), "=r"(reg_[3][7]) + : "r"(src_addr + 32 * smem_stride_) + ); +#else + GGML_UNUSED(src); + GGML_UNUSED(reg); + NO_DEVICE_CODE; +#endif +} + +template +static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, + const half * __restrict__ kernel, + T * __restrict__ output, + const param_t param) { +#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING + + constexpr unsigned int MMA_M = 16; + constexpr unsigned int MMA_N = 8; + + // loop bounds, constexpr where possible allows for loop unrolling + + constexpr unsigned int mma_tiles_per_warp_k = 4; + constexpr unsigned int mma_tiles_per_warp_m = WM / MMA_M; + constexpr unsigned int mma_tiles_per_warp_n = WN / MMA_N; + const unsigned int z = blockIdx.z; + + const unsigned int ks = (ksplit > 0) ? (param.c + ksplit - 1) / ksplit : param.c; + const unsigned int start_k = (ksplit > 0) ? z * ks : 0; + const unsigned int end_k = min(start_k + ks, param.c); + const unsigned int num_block_tiles_k = (ks + (BK-1)) / BK; + const unsigned int num_block_tiles_krs = num_block_tiles_k * param.r * param.s; + + constexpr unsigned int TILE_COLS_VECTORIZED = BK / 8; + constexpr unsigned int ROW_STEP = NUM_THREADS / TILE_COLS_VECTORIZED; + constexpr unsigned int A_K_STRID = BM / ROW_STEP; + // constexpr unsigned int B_K_STRID = BN / ROW_STEP; + + unsigned int masks_a[A_K_STRID][2]; + int64_t element_offset_a[A_K_STRID]; + int64_t element_offset_b; + + // calculate block/warp indices + const unsigned int block_m = blockIdx.y; + const unsigned int block_n = blockIdx.x; + const unsigned int warp_m = threadIdx.y; + const unsigned int warp_n = threadIdx.x / 32; + const unsigned int thread_idx = threadIdx.y * blockDim.x + threadIdx.x; + unsigned int thread_row = thread_idx / TILE_COLS_VECTORIZED; + const unsigned int thread_col = thread_idx % TILE_COLS_VECTORIZED; + + // double buffering + extern __shared__ half shmem[]; + half* A_block_smem = shmem; + half* B_block_smem = &shmem[BM * BK]; + constexpr int BUFFER_SIZE = BM * BK + BK * BN; + +#ifdef CP_ASYNC_AVAILABLE + half* SA1 = A_block_smem; + half* SB1 = B_block_smem; + half* SA2 = &shmem[BUFFER_SIZE]; + half* SB2 = SA2 + BM * BK; +#else + float4 A_gmem_cache_reg[4]; + float4 B_gmem_cache_reg[4]; + int offset_direction = 1; +#endif + // declare register storage + // ptx instructions expect uint32_t registers, where each uint32_t is 2 halfs packed together + uint32_t acc_register[mma_tiles_per_warp_m][mma_tiles_per_warp_n][2]; + + uint32_t A_register[mma_tiles_per_warp_m][mma_tiles_per_warp_k][2]; + uint32_t B_register[mma_tiles_per_warp_k][mma_tiles_per_warp_n]; + + // convenience cast to half for register storage + half (&acc_register_) [mma_tiles_per_warp_m][mma_tiles_per_warp_n][4] = reinterpret_cast(acc_register); + half (&A_register_) [mma_tiles_per_warp_m][mma_tiles_per_warp_k][4] = reinterpret_cast(A_register); + half (&B_register_) [mma_tiles_per_warp_k][mma_tiles_per_warp_n][2] = reinterpret_cast(B_register); + + // accumulators start at 0 + for (unsigned int mma_m = 0; mma_m < mma_tiles_per_warp_m; mma_m++){ + for (unsigned int mma_n = 0; mma_n < mma_tiles_per_warp_n; mma_n++){ + acc_register_[mma_m][mma_n][0] = 0; + acc_register_[mma_m][mma_n][1] = 0; + acc_register_[mma_m][mma_n][2] = 0; + acc_register_[mma_m][mma_n][3] = 0; + } + } + + const unsigned int A_warp_tile_offset = warp_m * 
WM * BK; + const unsigned int B_warp_tile_offset = warp_n * WN * BK; + + static_assert(BM == 256); + static_assert(BN == 256); + static_assert(BK == 32); + static_assert(NUM_THREADS == 256); + + + prepareIteratorA(thread_row, masks_a, element_offset_a, param); + +#ifdef CP_ASYNC_AVAILABLE + unsigned int iter_src_idx = thread_row * param.weightKOffset; + unsigned int iter_dst_idx = thread_row * TILE_COLS_VECTORIZED + thread_col; + unsigned int krow_idx = thread_row + blockIdx.x * BN; + const int ITER_SRC_STEPS = ROW_STEP * param.weightKOffset; +#endif + + + // prefetch the first block tile of A,B into shared memory + + const half* A_block_gmem = input; + const half* B_block_gmem = kernel + block_n * BN * param.weightKOffset; + + unsigned int curC = tileMemcpySwizzleA(A_block_gmem, A_block_smem, 0, 0, masks_a, element_offset_a, + thread_row, thread_col, start_k, end_k, param); + element_offset_b = curC; + tileMemcpySwizzleB(B_block_gmem, B_block_smem, 0, 0, curC, element_offset_b, start_k, end_k, thread_row, thread_col, param); + +#ifdef CP_ASYNC_AVAILABLE + asm volatile("cp.async.commit_group;\n" ::); +#endif + + unsigned int block_k = 0; + unsigned int block_krs = 1; + int s = 0; + int r = 0; + +#ifdef CP_ASYNC_AVAILABLE + while (block_krs < num_block_tiles_krs) { + + asm volatile("cp.async.wait_group %0;\n" ::"n"(0)); +#else + while (block_k < num_block_tiles_k) { +#endif + __syncthreads(); + + // moves to the next channel block tile + int next_idx = 0; + ++s; + if (s == param.s) { + s = 0; + ++r; + if (r < param.r) { + next_idx = 1; + } else { + r = 0; + next_idx = 2; + } + } + + add_byte_offset(element_offset_a, param.inc_next[next_idx]); + + if (next_idx == 2) { + ++block_k; + } + + if (block_krs != num_block_tiles_krs) { +#ifdef CP_ASYNC_AVAILABLE + curC = tileMemcpyAsyncLoadA(A_block_gmem, SA2, r, s, + masks_a, element_offset_a, thread_row, thread_col, + iter_dst_idx, block_k * BK, + start_k, end_k, curC, param); + element_offset_b = (r*param.s+s)*param.c + curC; + tileMemcpyAsyncLoadB(B_block_gmem, SB2, r, s, curC, element_offset_b, block_k * BK, + start_k, end_k, thread_row, thread_col, + iter_src_idx, iter_dst_idx, krow_idx, ITER_SRC_STEPS,param); + asm volatile("cp.async.commit_group;\n" ::); +#else + curC = tileMemcpyLoadA(A_block_gmem, A_gmem_cache_reg, r, s, + masks_a, element_offset_a, thread_row, thread_col, block_k * BK, + start_k, end_k, curC, param); + element_offset_b = (r*param.s+s)*param.c + curC; + tileMemcpyLoadB(B_block_gmem, B_gmem_cache_reg, r, s, curC, element_offset_b, block_k * BK, + start_k, end_k, thread_row, thread_col, param); +#endif + } + +#ifdef CP_ASYNC_AVAILABLE + half* A_warp_tile = SA1 + A_warp_tile_offset; + half* B_warp_tile = SB1 + B_warp_tile_offset; +#else + half* A_warp_tile = A_block_smem + A_warp_tile_offset; + half* B_warp_tile = B_block_smem + B_warp_tile_offset; +#endif + + ldmatrix_a(A_warp_tile, A_register_); + ldmatrix_b(B_warp_tile, B_register_); + + // outer product between mma tiles +#pragma unroll + for (unsigned int mma_k = 0; mma_k < mma_tiles_per_warp_k; mma_k++) { +#pragma unroll + for (unsigned int mma_n = 0; mma_n < mma_tiles_per_warp_n; mma_n++) { +#pragma unroll + for (unsigned int mma_m = 0; mma_m < mma_tiles_per_warp_m; mma_m++) { + + asm volatile ( + "mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 " + "{%0, %1}, " + "{%2, %3}, " + "{%4}, " + "{%5, %6};" + : "=r"(acc_register[mma_m][mma_n][0]), "=r"(acc_register[mma_m][mma_n][1]) + : "r"(A_register[mma_m][mma_k][0]), "r"(A_register[mma_m][mma_k][1]), + 
"r"(B_register[mma_k][mma_n]) + "r"(acc_register[mma_m][mma_n][0]), "r"(acc_register[mma_m][mma_n][1]) + ); + } + } + } + + + if (block_krs != num_block_tiles_krs) { +#ifdef CP_ASYNC_AVAILABLE + half *tmp = SA1; SA1 = SA2; SA2 = tmp; + tmp = SB1; SB1 = SB2; SB2 = tmp; +#else + // switch smem buffers each iteration + A_block_smem = A_block_smem + BUFFER_SIZE * offset_direction; + B_block_smem = B_block_smem + BUFFER_SIZE * offset_direction; + offset_direction = -1 * offset_direction; + + tileMemcpySwizzleStore(A_gmem_cache_reg, A_block_smem, thread_row, thread_col); + tileMemcpySwizzleStore(B_gmem_cache_reg, B_block_smem, thread_row, thread_col); +#endif + } + block_krs++; + } + + +#ifdef CP_ASYNC_AVAILABLE + asm volatile("cp.async.wait_group %0;\n" ::"n"(0)); + __syncthreads(); + half* A_warp_tile = SA1 + A_warp_tile_offset; + half* B_warp_tile = SB1 + B_warp_tile_offset; + ldmatrix_a(A_warp_tile, A_register_); + ldmatrix_b(B_warp_tile, B_register_); + // outer product between mma tiles +#pragma unroll + for (unsigned int mma_k = 0; mma_k < mma_tiles_per_warp_k; mma_k++) { +#pragma unroll + for (unsigned int mma_n = 0; mma_n < mma_tiles_per_warp_n; mma_n++) { +#pragma unroll + for (unsigned int mma_m = 0; mma_m < mma_tiles_per_warp_m; mma_m++) { + asm volatile ( + "mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 " + "{%0, %1}, " + "{%2, %3}, " + "{%4}, " + "{%5, %6};" + : "=r"(acc_register[mma_m][mma_n][0]), "=r"(acc_register[mma_m][mma_n][1]) + : "r"(A_register[mma_m][mma_k][0]), "r"(A_register[mma_m][mma_k][1]), + "r"(B_register[mma_k][mma_n]) + "r"(acc_register[mma_m][mma_n][0]), "r"(acc_register[mma_m][mma_n][1]) + ); + } + } + } +#endif + + + // reuse smem + half *smemoutput = shmem; + const uint lane_id = threadIdx.x % WARPSIZE; + const uint mma_row = lane_id / 4; + const uint mma_col = lane_id % 4; + const uint warp_offset = warp_m * WM * BN/2 + warp_n * WN/2; + const uint output_lds_addr = warp_offset + lane_id * BN/2; + const uint output_sts_addr = warp_offset + mma_row * BN/2 + mma_col * 2; + const uint m_idx = block_n * BN + warp_n * WN; + const uint n_idx = block_m * BM + warp_m * WM + lane_id; + +#pragma unroll + for (int i = 0; i < 2; ++i) { + const unsigned int i_offset = i * mma_tiles_per_warp_n/2; + __syncthreads(); +#pragma unroll + for (unsigned int mma_m = 0; mma_m < mma_tiles_per_warp_m; mma_m++) { + const unsigned int mma_m_offset = output_sts_addr + mma_m * MMA_M * BN / 2; + for (unsigned int mma_n = i_offset; mma_n < (i+1)*mma_tiles_per_warp_n/2; mma_n++) { + uint32_t (®_)[2] = reinterpret_cast(acc_register_[mma_m][mma_n]); + uint idx = mma_m_offset + (mma_n - i_offset) * MMA_N; + idx = idx ^ ((idx & 0b110000000000) >> 9); + idx = idx ^ ((idx & 0b1110000000) >> 4); + uint32_t* dst_ptr = reinterpret_cast(&smemoutput[idx]); + dst_ptr[0] = reg_[0]; + idx = (idx + 8 * BN / 2 ) ^ 0b010; + dst_ptr = reinterpret_cast(&smemoutput[idx]); + dst_ptr[0] = reg_[1]; + } + } + __syncthreads(); + + const unsigned int m_i_wn = m_idx + i * WN / 2; +#pragma unroll + for (int subk = 0; subk < WN / 4; ++subk) { + const uint row = m_i_wn + subk*2; + uint idx = output_lds_addr + subk*2; + idx = idx ^ ((idx & 0b110000000000) >> 9); + idx = idx ^ ((idx & 0b1110000000) >> 4); +#pragma unroll + for (int j = 0; j < 4; ++j) { + const uint gemm_i = n_idx + j*32; + const int n = fastdiv(gemm_i, param.OHOW_fastdiv); + const int col = fastmodulo(gemm_i, param.OHOW_fastdiv); + uint32_t dst_ptr = *(reinterpret_cast(&smemoutput[idx+j*16*BN])); // 32*BN/2 = 16*BN + half (&res_)[2] = 
reinterpret_cast(dst_ptr); + if (n < param.n && row < param.k && col < param.PQ) { + const uint outOffset = ((ksplit > 0) ? z * param.NKPQ : 0) + n * param.KPQ + row * param.PQ + col; + output[outOffset] = ggml_cuda_cast(res_[0]); + } + if (n < param.n && row+1 < param.k && col < param.PQ) { + const uint outOffset = ((ksplit > 0) ? z * param.NKPQ : 0) + n * param.KPQ + (row+1) * param.PQ + col; + output[outOffset] = ggml_cuda_cast(res_[1]); + } + } + } + } +#else + GGML_UNUSED(input); + GGML_UNUSED(kernel); + GGML_UNUSED(output); + GGML_UNUSED(param); + NO_DEVICE_CODE; +#endif +} + + +#define NUM_VARIANTS 4 + +/* + conv_shapes[][0]: ne_input=[384,512,256,1],ne_kernel=[3,3,256,256] + conv_shapes[][1]: ne_input=[96,128,512,1],ne_kernel=[3,3,512,512] + conv_shapes[][2]: ne_input=[192,256,512,1git diff],ne_kernel=[3,3,512,512] +*/ +constexpr static int conv_shapes[][NUM_VARIANTS] = { + { 128, 128, 128, 256 }, // BM + { 256, 128, 256, 128 }, // BN + { 8, 8, 8, 8 }, // BK + { 128, 64, 32, 128 }, // WM + { 32, 32 , 256, 32 }, // WN + { 2, 2, 1, 1 }, // WNITER + { 8, 4, 4, 4 }, // TM + { 8, 4, 8, 8 }, // TN + { 256, 256, 128, 256} // NUM_THREADS +}; + +template +static void conv2d_implicit_cuda(const float * X_D, const T * K_D, float * Y_D, const param_t P, cudaStream_t st) { + + const uint BM = conv_shapes[0][CONV_SHAPE]; + const uint BN = conv_shapes[1][CONV_SHAPE]; + const uint BK = conv_shapes[2][CONV_SHAPE]; + const uint WM = conv_shapes[3][CONV_SHAPE]; + const uint WN = conv_shapes[4][CONV_SHAPE]; + const uint WNITER = conv_shapes[5][CONV_SHAPE]; + const uint TM = conv_shapes[6][CONV_SHAPE]; + const uint TN = conv_shapes[7][CONV_SHAPE]; + const uint NUM_THREADS = conv_shapes[8][CONV_SHAPE]; + int blockx = ((P.Oh * P.Ow + BM - 1) / BM); // blockx number + int blocky = (P.k + BN-1) / BN; // blocky number + int blockz = P.n; // blockz number + int thready = 1; // thready number per block + int threadz = 1; // threadz number per block + dim3 thblock(NUM_THREADS, thready, threadz); + dim3 grid(blockx, blocky, blockz); + + conv2d_implicit_kernel<<>>(X_D, K_D, Y_D, P); +} + +template +static void launch_conv2d_implicit_split_kernel(ggml_backend_cuda_context & ctx, const half *X_H, const half *K_H, float *Y_D, + const unsigned int BlocksM, const unsigned int BlocksN, + const unsigned int shmem_bytes, + param_t P, cudaStream_t st){ + + int id = ggml_cuda_get_device(); + + ggml_cuda_pool_alloc Y_H(ctx.pool(id), ksplit * P.k * P.Oh * P.Ow * P.n); + cudaFuncSetAttribute(conv2d_implicit_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, 65536); // set shared memory limit to 64KB which is maximum for sm_75 + dim3 gridDim(BlocksN, BlocksM, ksplit); + dim3 blockDim(ThreadsN, ThreadsM); + + conv2d_implicit_kernel<<>>(X_H, K_H, Y_H.get(), P); + + const unsigned int nrows = P.n * P.k * P.Oh * P.Ow; + const unsigned int blockx = (nrows + 511) / 512; + const dim3 block_nums(blockx, 1, 1); + const dim3 block_dims(512, 1, 1); + reduce_f32<<>>(Y_H.get(), Y_D, nrows, ksplit); +} + +static void conv2d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const float * X_D, const half * K_D, float * Y_D, int cc, param_t P, cudaStream_t st) { + + if (GGML_CUDA_CC_IS_NVIDIA(cc) && ampere_mma_available(cc) && P.c % 8 == 0 && (P.r <= 32 && P.s <= 32)) { + + int id = ggml_cuda_get_device(); + + int64_t inc[3]; + // next S + inc[0] = int64_t(P.c) * P.d_w; + // next R + inc[1] = int64_t(P.w * P.c) * P.d_h - (P.s - 1) * P.c * P.d_w; + // next C + inc[2] = - int64_t(P.r - 1) * P.w * P.c * P.d_h - int64_t(P.s - 1) * P.c * 
P.d_w ; + memcpy(P.inc_next, inc, sizeof(int64_t)*3); + + int64_t ne = P.c * P.h * P.w * P.n; + int64_t ne00 = P.c; + int64_t ne01 = P.h * P.w; + ggml_cuda_pool_alloc input_f16(ctx.pool(id), ne); + + dim3 dimGrid( (ne01 + CUDA_NCHW_2_NHWC_TILE_DIM - 1) / CUDA_NCHW_2_NHWC_TILE_DIM, + (ne00 + CUDA_NCHW_2_NHWC_TILE_DIM - 1) / CUDA_NCHW_2_NHWC_TILE_DIM, + (ne/(ne00*ne01) + CUDA_NCHW_2_NHWC_BLOCK_NM - 1) / CUDA_NCHW_2_NHWC_BLOCK_NM) ; + dim3 dimBlock(CUDA_NCHW_2_NHWC_TILE_DIM,CUDA_NCHW_2_NHWC_BLOCK_ROWS, 1); + NCHW2NHWC<<>>(X_D, input_f16.get(), ne, ne00, ne01); + + ne = P.c * P.r * P.s * P.k; + ne01 = P.r * P.s; + ggml_cuda_pool_alloc kernel_f16(ctx.pool(id)); + if (ne01 > 1){ + kernel_f16.alloc(ne); + + dim3 dimGrid1((ne00 + CUDA_NCHW_2_NHWC_BLOCK_C - 1) / CUDA_NCHW_2_NHWC_BLOCK_C, + ne/(ne00*ne01), + 1) ; + if (ne01 == 25) { + constexpr unsigned int mask = filter_swizzle_mask(25, CUDA_NCHW_2_NHWC_BLOCK_C); + NCHW2NHWC<<>>(K_D, kernel_f16.get(), ne, ne00, ne01, P); + } else if (ne01 == 16) { + constexpr unsigned int mask = filter_swizzle_mask(16, CUDA_NCHW_2_NHWC_BLOCK_C); + NCHW2NHWC<<>>(K_D, kernel_f16.get(), ne, ne00, ne01, P); + } else if (ne01 == 9) { + constexpr unsigned int mask = filter_swizzle_mask(9, CUDA_NCHW_2_NHWC_BLOCK_C); + NCHW2NHWC<<>>(K_D, kernel_f16.get(), ne, ne00, ne01, P); + } else if (ne01 == 8) { + constexpr unsigned int mask = filter_swizzle_mask(8, CUDA_NCHW_2_NHWC_BLOCK_C); + NCHW2NHWC<<>>(K_D, kernel_f16.get(), ne, ne00, ne01, P); + } else if (ne01 == 7) { + constexpr unsigned int mask = filter_swizzle_mask(7, CUDA_NCHW_2_NHWC_BLOCK_C); + NCHW2NHWC<<>>(K_D, kernel_f16.get(), ne, ne00, ne01, P); + } else if (ne01 == 6) { + constexpr unsigned int mask = filter_swizzle_mask(6, CUDA_NCHW_2_NHWC_BLOCK_C); + NCHW2NHWC<<>>(K_D, kernel_f16.get(), ne, ne00, ne01, P); + } else if (ne01 == 5) { + constexpr unsigned int mask = filter_swizzle_mask(5, CUDA_NCHW_2_NHWC_BLOCK_C); + NCHW2NHWC<<>>(K_D, kernel_f16.get(), ne, ne00, ne01, P); + } else if (ne01 == 4) { + constexpr unsigned int mask = filter_swizzle_mask(4, CUDA_NCHW_2_NHWC_BLOCK_C); + NCHW2NHWC<<>>(K_D, kernel_f16.get(), ne, ne00, ne01, P); + } else if (ne01 == 3) { + constexpr unsigned int mask = filter_swizzle_mask(3, CUDA_NCHW_2_NHWC_BLOCK_C); + NCHW2NHWC<<>>(K_D, kernel_f16.get(), ne, ne00, ne01, P); + } else if (ne01 == 2) { + constexpr unsigned int mask = filter_swizzle_mask(2, CUDA_NCHW_2_NHWC_BLOCK_C); + NCHW2NHWC<<>>(K_D, kernel_f16.get(), ne, ne00, ne01, P); + } else { + dim3 dimGrid2((ne01 + CUDA_NCHW_2_NHWC_TILE_DIM - 1) / CUDA_NCHW_2_NHWC_TILE_DIM, + (ne00 + CUDA_NCHW_2_NHWC_TILE_DIM - 1) / CUDA_NCHW_2_NHWC_TILE_DIM, + (ne/(ne00*ne01) + CUDA_NCHW_2_NHWC_BLOCK_NM - 1) / CUDA_NCHW_2_NHWC_BLOCK_NM) ; + NCHW2NHWC<<>>(K_D, kernel_f16.get(), ne, ne00, ne01); + } + } + + const half *X_H = input_f16.get(); + const half *K_H = ne01 == 1 ? 
K_D : kernel_f16.get(); + + constexpr unsigned int BM_dim = 256; + constexpr unsigned int BN_dim = 256; + constexpr unsigned int BK_dim = 32; + + constexpr unsigned int WARPS_PER_BLOCK_M = 2; + constexpr unsigned int WARPS_PER_BLOCK_N = 4; + constexpr unsigned int WARPS_PER_BLOCK_K = 4; + + constexpr unsigned int WM_dim = BM_dim / WARPS_PER_BLOCK_M; + constexpr unsigned int WN_dim = BN_dim / WARPS_PER_BLOCK_N; + constexpr unsigned int WK_dim = BK_dim / WARPS_PER_BLOCK_K; + + static_assert(WN_dim % 4 == 0, "final output requires this to be bank conflicts free"); + + const unsigned int BlocksM = (P.n * P.Oh * P.Ow + BM_dim - 1) / BM_dim; + const unsigned int BlocksN = (P.k + BN_dim - 1) / BN_dim; + constexpr unsigned int ThreadsM = WARPS_PER_BLOCK_M; + constexpr unsigned int ThreadsN = WARPSIZE * WARPS_PER_BLOCK_N; + constexpr unsigned int NumThreads = ThreadsM * ThreadsN; + const unsigned int shmem_bytes = (BM_dim * BK_dim + BK_dim * BN_dim) * 2 * sizeof(half); + + const unsigned int nsm = (unsigned int) (ggml_cuda_info().devices[ggml_cuda_get_device()].nsm); + // if (BlocksM * BlocksN < nsm && P.c >= 8 * ksplit && (P.c * P.r * P.s) % (8*ksplit) == 0) { + if (BlocksM * BlocksN < 2*nsm){ + int j, max_remaining_waves = -1, candidate = -1; + int ks = min(20, nsm / (BlocksM * BlocksN)); + if (ks < 2 && (BlocksM * BlocksN) % nsm < nsm*4/5) + ks = 20; + for (j = 2; j <= ks; j++){ + const int remainder = (BlocksM * BlocksN * j) % nsm; + // if ((P.c * P.r * P.s) % (8*j) == 0){ + if ((P.c) % (8*j) == 0){ + if (remainder == 0) { + candidate = j; + max_remaining_waves = 0; + break; + } else if (remainder > max_remaining_waves) { + max_remaining_waves = remainder; + candidate = j; + } + } + } + if(candidate != -1){ + j = candidate; + if (j == 2) { + launch_conv2d_implicit_split_kernel(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st); + } else if (j == 3) { + launch_conv2d_implicit_split_kernel(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st); + } else if (j == 4) { + launch_conv2d_implicit_split_kernel(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st); + } else if (j == 5) { + launch_conv2d_implicit_split_kernel(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st); + } else if (j == 6) { + launch_conv2d_implicit_split_kernel(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st); + } else if (j == 7) { + launch_conv2d_implicit_split_kernel(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st); + } else if (j == 8) { + launch_conv2d_implicit_split_kernel(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st); + } else if (j == 9) { + launch_conv2d_implicit_split_kernel(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st); + } else if (j == 10) { + launch_conv2d_implicit_split_kernel(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st); + } else if (j == 11) { + launch_conv2d_implicit_split_kernel(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st); + } else if (j == 12) { + launch_conv2d_implicit_split_kernel(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st); + } else if (j == 13) { + launch_conv2d_implicit_split_kernel(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st); + } else if (j == 14) { + launch_conv2d_implicit_split_kernel(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st); + } else if (j == 15) { + launch_conv2d_implicit_split_kernel(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st); + } else if (j == 16) { + launch_conv2d_implicit_split_kernel(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st); + } 
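+// --- illustrative aside: how the split-K factor is picked ---------------------
+// The dispatch chain here tries ksplit = 2..20 and keeps the factor that makes
+// BlocksM*BlocksN*ksplit fill the SMs most evenly, while requiring
+// P.c % (8*ksplit) == 0 so every split still gets a whole multiple of 8
+// channels for vectorized loads. The selection logic, pulled out of the chain
+// (choose_ksplit is an illustrative helper, not used by the code here):
+static int choose_ksplit(unsigned int blocks, unsigned int nsm, unsigned int channels, int max_split) {
+    int best     = -1;
+    int best_rem = -1;
+    for (int j = 2; j <= max_split; ++j) {
+        if (channels % (8u*j) != 0) {
+            continue;                 // split would break 8-channel alignment
+        }
+        const int rem = (int)((blocks*j) % nsm);
+        if (rem == 0) {
+            return j;                 // an exact number of waves -> take it
+        }
+        if (rem > best_rem) {         // otherwise prefer the fullest last wave
+            best_rem = rem;
+            best     = j;
+        }
+    }
+    return best;                      // -1 -> fall back to the non-split kernel
+}
+// ----------------------------------------------------------------------------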
else if (j == 17) { + launch_conv2d_implicit_split_kernel(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st); + } else if (j == 18) { + launch_conv2d_implicit_split_kernel(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st); + } else if (j == 19) { + launch_conv2d_implicit_split_kernel(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st); + } else if (j == 20) { + launch_conv2d_implicit_split_kernel(ctx, X_H, K_H, Y_D, BlocksM, BlocksN, shmem_bytes, P, st); + } + + return; + } + } + + cudaFuncSetAttribute(conv2d_implicit_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, 65536); // set shared memory limit to 64KB which is maximum for sm_75 + dim3 gridDim(BlocksN, BlocksM); + dim3 blockDim(ThreadsN, ThreadsM); + + conv2d_implicit_kernel + <<>>(X_H, K_H, Y_D, P); + } else{ + conv2d_implicit_cuda(X_D, K_D, Y_D, P, st); + } + +} + +static void conv2d_implicit_cuda_f32(ggml_backend_cuda_context & ctx, const float * X_D, const float * K_D, float * Y_D, int cc, const param_t P, cudaStream_t st) { + conv2d_implicit_cuda(X_D, K_D, Y_D, P, st); + GGML_UNUSED(ctx); + GGML_UNUSED(cc); +} + +void ggml_cuda_op_conv2d_implicit(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * kernel = dst->src[0]; + const ggml_tensor * input = dst->src[1]; + float * K_D = (float *) kernel->data; + const float * X_D = (const float *) input->data; + float * Y_D = (float *) dst->data; + + GGML_ASSERT(ggml_is_contiguous(kernel)); + GGML_ASSERT(kernel->type == GGML_TYPE_F16 || kernel->type == GGML_TYPE_F32); + + + cudaStream_t st = ctx.stream(); + const int cc = ggml_cuda_info().devices[ctx.device].cc; + + const int32_t * p = (const int32_t *) dst->op_params; + const uint ST_X = p[0]; // stride_x + const uint ST_Y = p[1]; // stride_y + const uint PD_X = p[2]; // padding_x + const uint PD_Y = p[3]; // padding_y + const uint DL_X = p[4]; // dilation_x + const uint DL_Y = p[5]; // dilation_y + + GGML_ASSERT(p[6] == false); + + const uint IW = input->ne[0]; // input_w + const uint IH = input->ne[1]; // input_h + const uint OW = dst->ne[0]; // output_w + const uint OH = dst->ne[1]; // output_h + const uint KW = kernel->ne[0]; // kernel_w + const uint KH = kernel->ne[1]; // kernel_h + const uint IC = input->ne[2]; // input_channels + + const uint OC = kernel->ne[3]; // ouptut_chanles + const uint B = input->ne[3]; // n_batches + + + int64_t pp[3] = {0}; + + param_t params = { B, IC, IH, IW, OC, KH, KW, ST_Y, ST_X, PD_Y, PD_X, DL_Y, DL_X, OH, OW, + init_fastdiv_values(KW*IC), + init_fastdiv_values(OW), + init_fastdiv_values(IC), + init_fastdiv_values(KW*KH), + init_fastdiv_values(KW), + init_fastdiv_values(OW*OH), + pp[0], pp[1], pp[2], + IC*IW, + IC*KW*KH, + OW*OH, + OC*OW*OH, + B*OC*OW*OH, + IC*IW*IH}; + + if (kernel->type == GGML_TYPE_F16) { + conv2d_implicit_cuda_f16(ctx, X_D, (half *) K_D, Y_D, cc, params, st); + } else { + conv2d_implicit_cuda_f32(ctx, X_D, K_D, Y_D, cc, params, st); + } +} diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cuh b/ggml/src/ggml-cuda/conv2d-implicit.cuh new file mode 100644 index 0000000000..aeaa158d72 --- /dev/null +++ b/ggml/src/ggml-cuda/conv2d-implicit.cuh @@ -0,0 +1,730 @@ +#pragma once +#include "common.cuh" + +constexpr unsigned int SWIZZLE_MASK_1 = 0b10000; +constexpr unsigned int SWIZZLE_BITS_1 = 4; +constexpr unsigned int SWIZZLE_MASK_2 = 0b1100; +constexpr unsigned int SWIZZLE_BITS_2 = 2; + +typedef struct{ + unsigned int n; //batch size + unsigned int c; //number if channels + unsigned int h; //height + unsigned int w; //width + unsigned 
int k; //number of filters + unsigned int r; //filter height + unsigned int s; //filter width + unsigned int u; //stride height + unsigned int v; //stride width + unsigned int p; //padding height + unsigned int q; //padding width + unsigned int d_h; //dilation height + unsigned int d_w; //dilation width + unsigned int Oh; //output height + unsigned int Ow; //output width + uint3 SC_fastdiv; + uint3 OW_fastdiv; + uint3 C_fastdiv; + uint3 RS_fastdiv; + uint3 S_fastdiv; + uint3 OHOW_fastdiv; + int64_t inc_next[3]; + unsigned int inChannelOffset; + unsigned int weightKOffset; + unsigned int PQ; + unsigned int KPQ; + unsigned int NKPQ; + unsigned int CHW; +} param_t; + + +/// Clears the predicates + +template +__device__ void clear_mask(unsigned int masks_[][2], bool clear = true) { + +#pragma unroll + for (int s = 0; s < K_STRID; ++s) { + masks_[s][0] = clear ? 0 : masks_[s][0]; + masks_[s][1] = clear ? 0 : masks_[s][1]; + } +} + +template +__device__ void add_byte_offset(int64_t element_offset[], const int64_t offset) { +#pragma unroll + for (int s = 0; s < K_STRID; ++s) { + element_offset[s] += offset; + } +} + +template +__device__ void prepareIteratorA(unsigned int thread_row, + unsigned int masks[][2], + int64_t element_offset[], + const param_t param) { + int offset_n[A_K_STRID]; + int offset_p[A_K_STRID]; + int offset_q[A_K_STRID]; + +#pragma unroll + for (int s = 0; s < A_K_STRID; ++s) { + + const unsigned int gemm_i = blockIdx.y * TILE_ROWS + thread_row; + offset_n[s] = fastdiv(gemm_i, param.OHOW_fastdiv); + unsigned int npq_res = fastmodulo(gemm_i, param.OHOW_fastdiv); + offset_p[s] = fastdiv(npq_res, param.OW_fastdiv); //* param.u - param.p; + offset_q[s] = fastmodulo(npq_res, param.OW_fastdiv); // * param.v - param.q; + const int h = offset_p[s] * (int)param.u - (int) param.p; + const int w = offset_q[s] * (int)param.v - (int) param.q; + + element_offset[s] = offset_n[s] * (int64_t)param.CHW + h * (int64_t)(param.inChannelOffset) + w * (int64_t)param.c; + + thread_row += ROW_STEP; + } + + clear_mask(masks); + + for (int r = 0; r < param.r; ++r) { +#pragma unroll + for (int s_idx = 0; s_idx < A_K_STRID; ++s_idx) { + const int h = offset_p[s_idx] * param.u - param.p + r * param.d_h; + + bool pred = (offset_n[s_idx] < param.n && h >= 0 && h < param.h); + masks[s_idx][0] |= (pred << r); + } + } + + for (int s = 0; s < param.s; ++s) { +#pragma unroll + for (int s_idx = 0; s_idx < A_K_STRID; ++s_idx) { + const int w = offset_q[s_idx] * param.v - param.q + s * param.d_w; + bool pred = (w >= 0 && w < param.w); + masks[s_idx][1] |= (pred << s); + } + } +} + +template +__device__ void cp_async_zfill(void *ptr, void const *global_ptr, bool pred_guard = true) { +#ifdef CP_ASYNC_AVAILABLE + unsigned int smem_ptr; + int src_in_bytes = pred_guard ? 
preload : 0; + + asm("{ .reg .u64 smem_ptr; cvta.to.shared.u64 smem_ptr, %1; cvt.u32.u64 " + "%0, smem_ptr; }\n" + : "=r"(smem_ptr) + : "l"(ptr)); + + asm volatile("cp.async.cg.shared.global [%0], [%1], %2, %3;\n" ::"r"(smem_ptr), + "l"(global_ptr), + "n"(preload), "r"(src_in_bytes)); +#else + GGML_UNUSED(ptr); + GGML_UNUSED(global_ptr); + GGML_UNUSED(pred_guard); +#endif +} + +// same as above, but writes are swizzled to avoid bank conflicts when shared memory is read later in the kernel +template +__device__ __forceinline__ void tileMemcpySwizzleB( + const half* __restrict__ src, + half* __restrict__ dst, + const unsigned int curR, + const unsigned int curS, + const unsigned int curC, + const int64_t ki, + const unsigned int start_k, + const unsigned int end_k, + unsigned int thread_row, + const unsigned int thread_col, + param_t param +) { +#if __CUDA_ARCH__ >= GGML_CUDA_TURING + + constexpr unsigned int TILE_COLS = 32; + + float4* dst_float4 = reinterpret_cast(dst); + + // # of threads is multiple of # of columns in the tile + constexpr unsigned int TILE_COLS_VECTORIZED = TILE_COLS / 8; + static_assert(NUM_THREADS % TILE_COLS_VECTORIZED == 0); + + // assign each thread a row/column in the tile, calculate how many iterations we need + // to cover the whole tile + constexpr unsigned int ROW_STEP = NUM_THREADS / TILE_COLS_VECTORIZED; + constexpr unsigned int NUM_ITERS = TILE_ROWS / ROW_STEP; + + #pragma unroll + for (unsigned int i = 0; i < NUM_ITERS; i++) { + // apply swizzle to the dst index + const unsigned int src_index = thread_row * param.weightKOffset + ki; + unsigned int dst_index = thread_row * TILE_COLS_VECTORIZED + thread_col; + dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_1) >> SWIZZLE_BITS_1); + dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_2) >> SWIZZLE_BITS_2); +#ifdef CP_ASYNC_AVAILABLE + cp_async_zfill((void *)(&dst_float4[dst_index]), (void const *)(&src[src_index]), + thread_row + blockIdx.x * TILE_ROWS < param.k && curC < end_k); + +#else + if (thread_row + blockIdx.x * TILE_ROWS < param.k && curC < end_k) { + dst_float4[dst_index] = reinterpret_cast(&src[src_index])[0]; + } else { // read 4 halves + dst_float4[dst_index] = make_float4(0.f, 0.f, 0.f, 0.f); + } +#endif + thread_row += ROW_STEP; + } +#else + GGML_UNUSED(src); + GGML_UNUSED(dst); + GGML_UNUSED(curR); + GGML_UNUSED(curS); + GGML_UNUSED(ki); + GGML_UNUSED(start_k); + GGML_UNUSED(end_k); + GGML_UNUSED(thread_row); + GGML_UNUSED(thread_col); + GGML_UNUSED(param); + NO_DEVICE_CODE; +#endif +} + + +// this is a special case of the above for when TILE_COLS == 32 +template +__device__ __forceinline__ unsigned int tileMemcpySwizzleA( + const half* __restrict__ src, + half* __restrict__ dst, + const unsigned int curR, + const unsigned int curS, + unsigned int masks[][2], + const int64_t element_offset[], + unsigned int thread_row, + const unsigned int thread_col, + const unsigned int start_k, + const unsigned int end_k, + param_t param +) { +#if __CUDA_ARCH__ >= GGML_CUDA_TURING + + constexpr unsigned int TILE_COLS = 32; + + float4* dst_float4 = reinterpret_cast(dst); + + // # of threads is multiple of # of columns in the tile + constexpr unsigned int TILE_COLS_VECTORIZED = TILE_COLS / 8; + static_assert(NUM_THREADS % TILE_COLS_VECTORIZED == 0); + + // assign each thread a row/column in the tile, calculate how many iterations we need + // to cover the whole tile + constexpr unsigned int ROW_STEP = NUM_THREADS / TILE_COLS_VECTORIZED; + constexpr unsigned int NUM_ITERS = TILE_ROWS / ROW_STEP; + + const 
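+// --- illustrative aside: predicated loads via cp.async zero-fill --------------
+// cp_async_zfill above relies on the optional src-size operand of
+// cp.async.cg.shared.global: cp-size bytes (the "preload" template argument,
+// 16 for a float4) are always written to shared memory, but only src-size
+// bytes are fetched from global memory and the rest is filled with zeros.
+// Passing src-size = 0 when pred_guard is false therefore turns an
+// out-of-bounds tap into "store 16 zero bytes" (exactly the implicit zero
+// padding the tiles need) without a divergent branch. The tileMemcpySwizzle*
+// and tileMemcpyAsync* helpers in this header call it roughly as
+//
+//     cp_async_zfill<16>(&dst_float4[dst_index], &src[element_offset[i] + curC], valid);
+//
+// and the kernel pairs those calls with cp.async.commit_group /
+// cp.async.wait_group to overlap the copies with the tensor-core math.
+// ----------------------------------------------------------------------------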
unsigned int curC = start_k+thread_col*8; + clear_mask(masks, curC >= end_k); + + #pragma unroll + for (unsigned int i = 0; i < NUM_ITERS; i++) { + bool valid = (masks[i][0] & (1u << curR)) && (masks[i][1] & (1u << curS)); + // apply swizzle to the dst index + unsigned int dst_index = thread_row * TILE_COLS_VECTORIZED + thread_col; + dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_1) >> SWIZZLE_BITS_1); + dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_2) >> SWIZZLE_BITS_2); +#ifdef CP_ASYNC_AVAILABLE + cp_async_zfill((void *)(&dst_float4[dst_index]), (void const *)(&src[element_offset[i]+curC]), valid); +#else + if (valid) { + dst_float4[dst_index] = reinterpret_cast(&src[element_offset[i]+curC])[0]; + } else { + dst_float4[dst_index] = make_float4(0.f, 0.f, 0.f, 0.f); + } +#endif + thread_row += ROW_STEP; + } + return curC; +#else + GGML_UNUSED(src); + GGML_UNUSED(dst); + GGML_UNUSED(curR); + GGML_UNUSED(curS); + GGML_UNUSED(start_k); + GGML_UNUSED(end_k); + GGML_UNUSED(masks); + GGML_UNUSED(element_offset); + GGML_UNUSED(thread_row); + GGML_UNUSED(thread_col); + GGML_UNUSED(param); + NO_DEVICE_CODE; +#endif +} + +template +__device__ __forceinline__ unsigned int tileMemcpyLoadA( + const half* __restrict__ src, + float4 (&dst_reg)[ELEMENTS_PER_THREAD], + const unsigned int curR, + const unsigned int curS, + unsigned int masks[][2], + const int64_t element_offset[], + unsigned int thread_row, + const unsigned int thread_col, + const unsigned int block_k, + const unsigned int start_k, + const unsigned int end_k, + unsigned int oldC, + param_t param +) { +#if __CUDA_ARCH__ >= GGML_CUDA_TURING + + // # of threads is multiple of # of columns in the tile + constexpr unsigned int TILE_COLS_VECTORIZED = TILE_COLS / 8; + static_assert(NUM_THREADS % TILE_COLS_VECTORIZED == 0); + + // assign each thread a row/column in the tile, calculate how many iterations we need + // to cover the whole tile + constexpr unsigned int ROW_STEP = NUM_THREADS / TILE_COLS_VECTORIZED; + constexpr unsigned int NUM_ITERS = TILE_ROWS / ROW_STEP; + + // compile time check that we provided the right amount of registers for storage + static_assert(ELEMENTS_PER_THREAD == NUM_ITERS); + + const unsigned int curC = start_k+block_k+thread_col*8; + if (curC > oldC) + clear_mask(masks, curC >= end_k); + + #pragma unroll + for (unsigned int i = 0; i < NUM_ITERS; i++) { + bool valid = (masks[i][0] & (1u << curR)) && (masks[i][1] & (1u << curS)); + if (valid) { + dst_reg[i] = reinterpret_cast(&src[element_offset[i]+curC])[0]; + } else{ + dst_reg[i] = make_float4(0.f, 0.f, 0.f, 0.f); + } + } + return curC; +#else + GGML_UNUSED(src); + GGML_UNUSED(dst_reg); + GGML_UNUSED(block_k); + GGML_UNUSED(curR); + GGML_UNUSED(curS); + GGML_UNUSED(start_k); + GGML_UNUSED(end_k); + GGML_UNUSED(masks); + GGML_UNUSED(element_offset); + GGML_UNUSED(thread_row); + GGML_UNUSED(thread_col); + GGML_UNUSED(oldC); + GGML_UNUSED(param); + NO_DEVICE_CODE; +#endif +} + +template +__device__ __forceinline__ unsigned int tileMemcpyAsyncLoadA( + const half* __restrict__ src, + half* __restrict__ dst, + const unsigned int curR, + const unsigned int curS, + unsigned int masks[][2], + const int64_t element_offset[], + unsigned int thread_row, + const unsigned int thread_col, + unsigned int iter_idx, + const unsigned int block_k, + const unsigned int start_k, + const unsigned int end_k, + unsigned int oldC, + param_t param +) { +#ifdef CP_ASYNC_AVAILABLE + + constexpr unsigned int TILE_COLS_VECTORIZED = TILE_COLS / 8; + static_assert(NUM_THREADS % 
TILE_COLS_VECTORIZED == 0); + + float4* dst_float4 = reinterpret_cast(dst); + + // assign each thread a row/column in the tile, calculate how many iterations we need + // to cover the whole tile + constexpr unsigned int ROW_STEP = NUM_THREADS / TILE_COLS_VECTORIZED; + constexpr unsigned int NUM_ITERS = TILE_ROWS / ROW_STEP; + constexpr unsigned int ITER_STEPS = ROW_STEP * TILE_COLS_VECTORIZED; + + // compile time check that we provided the right amount of registers for storage + static_assert(ELEMENTS_PER_THREAD == NUM_ITERS); + + const unsigned int curC = start_k+block_k+thread_col*8; + if (curC > oldC) + clear_mask(masks, curC >= end_k); + + #pragma unroll + for (unsigned int i = 0; i < NUM_ITERS; i++) { + bool valid = (masks[i][0] & (1u << curR)) && (masks[i][1] & (1u << curS)); + unsigned int dst_index = iter_idx; + dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_1) >> SWIZZLE_BITS_1); + dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_2) >> SWIZZLE_BITS_2); + + cp_async_zfill((void *)(&dst_float4[dst_index]), (void const *)(&src[element_offset[i]+curC]), valid); + iter_idx += ITER_STEPS; + } + return curC; +#else + GGML_UNUSED(src); + GGML_UNUSED(dst); + GGML_UNUSED(block_k); + GGML_UNUSED(curR); + GGML_UNUSED(curS); + GGML_UNUSED(start_k); + GGML_UNUSED(end_k); + GGML_UNUSED(masks); + GGML_UNUSED(element_offset); + GGML_UNUSED(thread_row); + GGML_UNUSED(thread_col); + GGML_UNUSED(iter_idx); + GGML_UNUSED(oldC); + GGML_UNUSED(param); + NO_DEVICE_CODE; +#endif +} + + +template +__device__ __forceinline__ void tileMemcpyLoadB( + const half* __restrict__ src, + float4 (&dst_reg)[ELEMENTS_PER_THREAD], + const unsigned int curR, + const unsigned int curS, + const unsigned int curC, + const int64_t ki, + const unsigned int block_k, + const unsigned int start_k, + const unsigned int end_k, + unsigned int thread_row, + const unsigned int thread_col, + param_t param +) { +#if __CUDA_ARCH__ >= GGML_CUDA_TURING + + // # of threads is multiple of # of columns in the tile + constexpr unsigned int TILE_COLS_VECTORIZED = TILE_COLS / 8; + static_assert(NUM_THREADS % TILE_COLS_VECTORIZED == 0); + + + // assign each thread a row/column in the tile, calculate how many iterations we need + // to cover the whole tile + constexpr unsigned int ROW_STEP = NUM_THREADS / TILE_COLS_VECTORIZED; + constexpr unsigned int NUM_ITERS = TILE_ROWS / ROW_STEP; + + // compile time check that we provided the right amount of registers for storage + static_assert(ELEMENTS_PER_THREAD == NUM_ITERS); + + unsigned int iter_idx = thread_row * param.weightKOffset + ki; + unsigned int krow_idx = thread_row + blockIdx.x * TILE_ROWS; + const int ITER_STEPS = ROW_STEP * param.weightKOffset; + + #pragma unroll + for (unsigned int i = 0; i < NUM_ITERS; i++) { + const unsigned int src_index = iter_idx; + if (krow_idx < param.k && curC < end_k) { + dst_reg[i] = reinterpret_cast(&src[src_index])[0]; + } else { // read 4 halves + dst_reg[i] = make_float4(0.f, 0.f, 0.f, 0.f); + } + krow_idx += ROW_STEP; + iter_idx += ITER_STEPS; + } +#else + GGML_UNUSED(src); + GGML_UNUSED(dst_reg); + GGML_UNUSED(block_k); + GGML_UNUSED(curR); + GGML_UNUSED(curS); + GGML_UNUSED(ki); + GGML_UNUSED(start_k); + GGML_UNUSED(end_k); + GGML_UNUSED(thread_row); + GGML_UNUSED(thread_col); + GGML_UNUSED(param); + NO_DEVICE_CODE; +#endif +} + +template +__device__ __forceinline__ void tileMemcpyAsyncLoadB( + const half *src, + half *dst, + const unsigned int curR, + const unsigned int curS, + const unsigned int curC, + const int64_t ki, + const unsigned int 
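+// --- illustrative aside: the per-thread (r, s) predicate masks ----------------
+// prepareIteratorA encodes, for each row a thread loads, which filter taps fall
+// inside the image: bit r of masks[i][0] says "row h = oh*u - p + r*d_h lies in
+// [0, H)" (together with the batch bound n < N in the real code), and bit s of
+// masks[i][1] says the same for the column. The copy loops above then reduce
+// the boundary check to two bit tests. The same idea for a single output pixel,
+// as a host-side sketch with illustrative names:
+static void build_rs_masks(unsigned int & mask_r, unsigned int & mask_s,
+                           int oh, int ow, const param_t & P) {
+    mask_r = 0;
+    mask_s = 0;
+    for (unsigned int r = 0; r < P.r; ++r) {
+        const int h = (int)(oh*P.u) - (int)P.p + (int)(r*P.d_h);
+        if (h >= 0 && h < (int)P.h) {
+            mask_r |= 1u << r;
+        }
+    }
+    for (unsigned int s = 0; s < P.s; ++s) {
+        const int w = (int)(ow*P.v) - (int)P.q + (int)(s*P.d_w);
+        if (w >= 0 && w < (int)P.w) {
+            mask_s |= 1u << s;
+        }
+    }
+    // tap (r, s) contributes iff (mask_r & (1u << r)) && (mask_s & (1u << s));
+    // this is why the f16 path requires r, s <= 32: each mask is one 32-bit word.
+}
+// ----------------------------------------------------------------------------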
block_k, + const unsigned int start_k, + const unsigned int end_k, + unsigned int thread_row, + const unsigned int thread_col, + unsigned int iter_src_idx, + unsigned int iter_dst_idx, + unsigned int krow_idx, + const int ITER_SRC_STEPS, + param_t param +) { + +#ifdef CP_ASYNC_AVAILABLE + + // # of threads is multiple of # of columns in the tile + constexpr unsigned int TILE_COLS_VECTORIZED = TILE_COLS / 8; + static_assert(NUM_THREADS % TILE_COLS_VECTORIZED == 0); + + float4* dst_float4 = reinterpret_cast(dst); + + // assign each thread a row/column in the tile, calculate how many iterations we need + // to cover the whole tile + constexpr unsigned int ROW_STEP = NUM_THREADS / TILE_COLS_VECTORIZED; + constexpr unsigned int NUM_ITERS = TILE_ROWS / ROW_STEP; + constexpr unsigned int ITER_DST_STEPS = ROW_STEP * TILE_COLS_VECTORIZED; + + // compile time check that we provided the right amount of registers for storage + static_assert(ELEMENTS_PER_THREAD == NUM_ITERS); + + iter_src_idx += ki; + + #pragma unroll + for (unsigned int i = 0; i < NUM_ITERS; i++) { + const unsigned int src_index = iter_src_idx; + unsigned int dst_index = iter_dst_idx; + dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_1) >> SWIZZLE_BITS_1); + dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_2) >> SWIZZLE_BITS_2); + + cp_async_zfill((void *)(&dst_float4[dst_index]), (void const *)(&src[src_index]), krow_idx < param.k && curC < end_k); + + iter_src_idx += ITER_SRC_STEPS; + krow_idx += ROW_STEP; + iter_dst_idx += ITER_DST_STEPS; + } +#else + GGML_UNUSED(src); + GGML_UNUSED(dst); + GGML_UNUSED(block_k); + GGML_UNUSED(curR); + GGML_UNUSED(curS); + GGML_UNUSED(ki); + GGML_UNUSED(start_k); + GGML_UNUSED(end_k); + GGML_UNUSED(thread_row); + GGML_UNUSED(thread_col); + GGML_UNUSED(iter_src_idx); + GGML_UNUSED(iter_dst_idx); + GGML_UNUSED(krow_idx); + GGML_UNUSED(ITER_SRC_STEPS); + GGML_UNUSED(param); + NO_DEVICE_CODE; +#endif +} + + +// same as above but without the swizzle + +// this is a special case of the above for when TILE_COLS == 32 +template +__device__ __forceinline__ void tileMemcpySwizzleStore( + const float4 (&src_reg)[ELEMENTS_PER_THREAD], + half* __restrict__ dst, + unsigned int thread_row, + const unsigned int thread_col +) { +#if __CUDA_ARCH__ >= GGML_CUDA_TURING + + constexpr unsigned int TILE_COLS = 32; + + // reinterpret input/output as float4 + float4* dst_float4 = reinterpret_cast(dst); + + // # of threads is multiple of # of columns in the tile + constexpr unsigned int TILE_COLS_VECTORIZED = TILE_COLS / 8; + static_assert(NUM_THREADS % TILE_COLS_VECTORIZED == 0); + + // assign each thread a row/column in the tile, calculate how many iterations we need + // to cover the whole tile + constexpr unsigned int ROW_STEP = NUM_THREADS / TILE_COLS_VECTORIZED; + constexpr unsigned int NUM_ITERS = TILE_ROWS / ROW_STEP; + constexpr unsigned int ITER_STEPS = ROW_STEP * TILE_COLS_VECTORIZED; + + // compile time check that we provided the right amount of registers for storage + static_assert(ELEMENTS_PER_THREAD == NUM_ITERS); + + unsigned int iter_idx = thread_row * TILE_COLS_VECTORIZED + thread_col; + #pragma unroll + for (unsigned int i = 0; i < NUM_ITERS; i++) { + // apply swizzle to the dst index + unsigned int dst_index = iter_idx; + dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_1) >> SWIZZLE_BITS_1); + dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_2) >> SWIZZLE_BITS_2); + dst_float4[dst_index] = src_reg[i]; + iter_idx += ITER_STEPS; + } +#else + GGML_UNUSED(src_reg); + GGML_UNUSED(dst); + 
GGML_UNUSED(thread_row); + GGML_UNUSED(thread_col); + NO_DEVICE_CODE; +#endif +} + +template +__device__ __forceinline__ void loadFilter(const T * __restrict__ kernel, + T * __restrict__ smemweight, + const unsigned int by, + const unsigned int innerRowA, + const unsigned int innerColA, + const unsigned int weightKOffset, + const unsigned int start_k, + const unsigned int end_k, + const param_t param){ + + const unsigned int weight_sts_addr = innerRowA + innerColA * (BN+PAD) * 4; + const unsigned int kidx = start_k + innerColA * 4; +#pragma unroll + for (int offset = 0; offset + rowStrideA <= BN; offset += rowStrideA) { + const unsigned int nidx = by * BN + innerRowA + offset; + if (vec_load) { + if (nidx < param.k && kidx < end_k) { + if constexpr (std::is_same_v<T, float>){ + float4 tmp = reinterpret_cast<const float4*>(&kernel[nidx * weightKOffset + kidx])[0]; + smemweight[weight_sts_addr + offset + 0] = tmp.x; + smemweight[weight_sts_addr + offset + (BN+PAD)] = tmp.y; + smemweight[weight_sts_addr + offset + 2*(BN+PAD)] = tmp.z; + smemweight[weight_sts_addr + offset + 3*(BN+PAD)] = tmp.w; + } else { // read 4 halves + float2 tmp = reinterpret_cast<const float2*>(&kernel[nidx * weightKOffset + kidx])[0]; + const half *val = reinterpret_cast<const half*>(&tmp); + smemweight[weight_sts_addr + offset + 0] = val[0]; + smemweight[weight_sts_addr + offset + (BN+PAD)] = val[1]; + smemweight[weight_sts_addr + offset + 2*(BN+PAD)] = val[2]; + smemweight[weight_sts_addr + offset + 3*(BN+PAD)] = val[3]; + } + } else { +#pragma unroll + for (int i = 0; i < 4; ++i) { + smemweight[weight_sts_addr + offset + i*(BN+PAD)] = (T)0.f; + } + } + } else { +#pragma unroll + for (int i = 0; i < 4; ++i) { + if (nidx < param.k && kidx + i < end_k) { + smemweight[weight_sts_addr + offset + i*(BN+PAD)] = kernel[nidx * weightKOffset + kidx + i]; + } else { + smemweight[weight_sts_addr + offset + i*(BN+PAD)] = (T)0.f; + } + } + } + } +} + + +template +__device__ __forceinline__ void loadInput(const float * __restrict__ input, + float * __restrict__ smeminput, + const unsigned int bx, + const unsigned int innerRowA, + const unsigned int innerColA, + const unsigned int start_k, + const unsigned int end_k, + const unsigned int PQ, + const unsigned int CHW, + const unsigned int inChannelOffset, + const param_t param) { + const unsigned int input_sts_addr = innerRowA + innerColA * (BM+PAD) * 4; + const unsigned int kidx = start_k + innerColA * 4; +#pragma unroll + for (unsigned int offset = 0; offset + rowStrideA <= BM; offset += rowStrideA) { + const unsigned int midx = bx * BM + innerRowA + offset; + int n = (ksplit > 0) ? midx / PQ : blockIdx.z; + const unsigned int npq_res = midx % PQ; + const int posh_ori = fastdiv((ksplit > 0) ? npq_res: midx, param.OW_fastdiv) * param.u - param.p; + const int posw_ori = fastmodulo((ksplit > 0) ? npq_res: midx, param.OW_fastdiv) * param.v - param.q; + const unsigned int inOffset = n * CHW; + if (vec_load) { + const unsigned int cur0 = fastdiv(kidx, + layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // channel offset + const unsigned int cur1 = fastdiv(fastmodulo(kidx, + layout == 0 ? param.SC_fastdiv : param.RS_fastdiv), + layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset + const unsigned int cur2 = fastmodulo(fastmodulo(kidx, + layout == 0 ? param.SC_fastdiv : param.RS_fastdiv), + layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel s offset + const unsigned int curC = layout == 0 ? cur2 : cur0; + const unsigned int curR = layout == 0 ? cur0 : cur1; + const unsigned int curS = layout == 0 ? 
cur1 : cur2; + const int curH = posh_ori + curR * param.d_h; // input h + const int curW = posw_ori + curS * param.d_w; // input w + if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && kidx < end_k) { + int inOffsetTmp = layout == 0 ? + curH * inChannelOffset + curW * param.c + curC: + curC * inChannelOffset + curH * param.w + curW; + float4 tmp = reinterpret_cast<const float4*>(&input[inOffset + inOffsetTmp])[0]; + smeminput[input_sts_addr + offset + 0] = tmp.x; + smeminput[input_sts_addr + offset + BM+PAD] = tmp.y; + smeminput[input_sts_addr + offset + 2*(BM+PAD)] = tmp.z; + smeminput[input_sts_addr + offset + 3*(BM+PAD)] = tmp.w; + } else { +#pragma unroll + for (int i = 0; i < 4; ++i) + smeminput[input_sts_addr + offset + i*(BM+PAD)] = 0.f; + } + } else { +#pragma unroll + for (int i = 0; i < 4; ++i) { + const unsigned int cur0 = fastdiv(kidx + i, + layout == 0 ? param.SC_fastdiv : param.RS_fastdiv); // channel offset + const unsigned int cur1 = fastdiv(fastmodulo(kidx + i, + layout == 0 ? param.SC_fastdiv : param.RS_fastdiv), + layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel r offset + const unsigned int cur2 = fastmodulo(fastmodulo(kidx + i, + layout == 0 ? param.SC_fastdiv : param.RS_fastdiv), + layout == 0 ? param.C_fastdiv : param.S_fastdiv); // kernel s offset + const unsigned int curC = layout == 0 ? cur2 : cur0; + const unsigned int curR = layout == 0 ? cur0 : cur1; + const unsigned int curS = layout == 0 ? cur1 : cur2; + const int curH = posh_ori + curR * param.d_h; // input h + const int curW = posw_ori + curS * param.d_w; // input w + if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && kidx + i < end_k) { + int inOffsetTmp = layout == 0 ? + curH * inChannelOffset + curW * param.c + curC: + curC * inChannelOffset + curH * param.w + curW; + smeminput[input_sts_addr + offset + i*(BM+PAD)] = input[inOffset + inOffsetTmp]; + } else { + smeminput[input_sts_addr + offset + i*(BM+PAD)] = 0.f; + } + } + } + } +} + + +#define CUDA_CONV2D_IMPLICT_BLOCK_SIZE 256 +void ggml_cuda_op_conv2d_implicit(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/cp-async.cuh b/ggml/src/ggml-cuda/cp-async.cuh index 63d0c482ff..91011234b2 100644 --- a/ggml/src/ggml-cuda/cp-async.cuh +++ b/ggml/src/ggml-cuda/cp-async.cuh @@ -3,7 +3,7 @@ #include "common.cuh" -static __device__ __forceinline__ unsigned int ggml_cuda_cvta_generic_to_shared(void * generic_ptr) { +static __device__ __forceinline__ unsigned int ggml_cuda_cvta_generic_to_shared(const void * generic_ptr) { #ifdef CP_ASYNC_AVAILABLE return __cvta_generic_to_shared(generic_ptr); #else diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 08383edb40..796278a15e 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -13,6 +13,7 @@ #include "ggml-cuda/concat.cuh" #include "ggml-cuda/conv-transpose-1d.cuh" #include "ggml-cuda/conv2d.cuh" +#include "ggml-cuda/conv2d-implicit.cuh" #include "ggml-cuda/conv2d-dw.cuh" #include "ggml-cuda/conv2d-transpose.cuh" #include "ggml-cuda/convert.cuh" @@ -2671,7 +2672,7 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg ggml_cuda_op_im2col_3d(ctx, dst); break; case GGML_OP_CONV_2D: - ggml_cuda_op_conv2d(ctx, dst); + ggml_cuda_op_conv2d_implicit(ctx, dst); break; case GGML_OP_CONV_2D_DW: ggml_cuda_op_conv2d_dw(ctx, dst); diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 411467e968..8ce1a13c68 100644 --- a/tests/test-backend-ops.cpp +++ 
b/tests/test-backend-ops.cpp @@ -40,6 +40,7 @@ #include #include #include +#include #ifdef __EMSCRIPTEN__ # define N_THREADS 1 @@ -7277,6 +7278,25 @@ static std::vector> make_test_cases_eval() { } } + test_cases.emplace_back(new test_conv_2d( { 16, 16, 8, 1}, { 3, 3, 8, 12}, + GGML_TYPE_F16, 1, 1, 1, 1, 1, 1, false)); + test_cases.emplace_back(new test_conv_2d( { 16, 16, 16, 1}, { 3, 3, 16, 6}, + GGML_TYPE_F16, 1, 1, 1, 1, 1, 1, false)); + test_cases.emplace_back(new test_conv_2d( { 16, 16, 24, 1}, { 3, 3, 24, 6}, + GGML_TYPE_F16, 1, 1, 1, 1, 1, 1, false)); + test_cases.emplace_back(new test_conv_2d( { 16, 16, 8, 3}, { 3, 3, 8, 6}, + GGML_TYPE_F16, 1, 1, 1, 1, 1, 1, false)); + test_cases.emplace_back(new test_conv_2d( { 24, 24, 32, 1 }, { 3, 3, 32, 8}, + GGML_TYPE_F16, 1, 1, 1, 1, 1, 1, false)); + test_cases.emplace_back(new test_conv_2d( { 24, 24, 96, 1 }, { 3, 3, 96, 8}, + GGML_TYPE_F16, 1, 1, 1, 1, 1, 1, false)); + test_cases.emplace_back(new test_conv_2d( { 24, 24, 128, 1 }, { 3, 3, 128, 8}, + GGML_TYPE_F16, 1, 1, 1, 1, 1, 1, false)); + test_cases.emplace_back(new test_conv_2d( { 24, 24, 128, 3 }, { 3, 3, 128, 8}, + GGML_TYPE_F16, 1, 1, 1, 1, 1, 1, false)); + + + // sycl backend will limit task global_range < MAX_INT // test cases for 2D im2col with large input W and H (occurs in stable-diffusion) // however these cases need to alloc more memory which may fail in some devices (Intel Arc770, etc.) @@ -8363,6 +8383,71 @@ static std::vector> make_test_cases_perf() { } } + // Stable-diffusion layers + std::map idx_sd{ + { "iw", 0 }, + { "ih", 1 }, + { "kw", 2 }, + { "kh", 3 }, + { "Cout", 4 }, + { "Cin", 5 }, + { "B", 6 }, + }; + + // Input image size + uint32_t w = 768; + uint32_t h = 1024; + + // Number of filters (base) + uint32_t Cout_b = 128; + uint32_t Cin_b = 128; + + std::vector> cases_sd = { + { w / 8, h / 8, 3, 3, Cout_b * 4, Cin_b * 4, 1 }, // x10 (called 10 times) + { w / 4, h / 4, 3, 3, Cout_b * 4, Cin_b * 4, 1 }, // x7 + { w / 2, h / 2, 3, 3, Cout_b * 2, Cin_b * 2, 1 }, // x5 + { w, h, 3, 3, Cout_b, Cin_b, 1 }, // x5 + { w / 8, h / 8, 1, 1, Cout_b * 4, Cin_b * 4, 1 }, // x4 + { w / 8, h / 8, 1, 1, 4, 4, 1 }, + { w / 8, h / 8, 3, 3, Cout_b * 4, 4, 1 }, + + { w / 2, h / 2, 3, 3, Cout_b * 4, Cin_b * 4, 1 }, + { w / 2, h / 2, 3, 3, Cout_b * 2, Cin_b * 4, 1 }, + { w / 2, h / 2, 1, 1, Cout_b * 2, Cin_b * 4, 1 }, + + { w, h, 3, 3, Cout_b, Cin_b * 2, 1 }, + { w, h, 1, 1, Cout_b, Cin_b * 2, 1 }, + { w, h, 3, 3, Cout_b * 2, Cin_b * 2, 1 }, + + { w, h, 3, 3, 3, Cin_b, 1 }, + }; + + for (auto act_case : cases_sd) { + GGML_ASSERT(act_case[idx_sd["kw"]] == 3 || act_case[idx_sd["kw"]] == 1); + GGML_ASSERT(act_case[idx_sd["kh"]] == 3 || act_case[idx_sd["kh"]] == 1); + + uint32_t p0 = act_case[idx_sd["kw"]] == 3 ? 1 : 0; + uint32_t p1 = act_case[idx_sd["kh"]] == 3 ? 1 : 0; + + test_cases.emplace_back(new test_conv_2d( + { act_case[idx_sd["iw"]], act_case[idx_sd["ih"]], act_case[idx_sd["Cin"]], act_case[idx_sd["B"]] }, + { act_case[idx_sd["kw"]], act_case[idx_sd["kh"]], act_case[idx_sd["Cin"]], act_case[idx_sd["Cout"]] }, + GGML_TYPE_F16, 1, 1, p0, p1, 1, 1, false)); + } + + for (auto act_case : cases_sd) { + GGML_ASSERT(act_case[idx_sd["kw"]] == 3 || act_case[idx_sd["kw"]] == 1); + GGML_ASSERT(act_case[idx_sd["kh"]] == 3 || act_case[idx_sd["kh"]] == 1); + + uint32_t p0 = act_case[idx_sd["kw"]] == 3 ? 1 : 0; + uint32_t p1 = act_case[idx_sd["kh"]] == 3 ? 
1 : 0; + + test_cases.emplace_back(new test_conv_2d( + { act_case[idx_sd["iw"]], act_case[idx_sd["ih"]], act_case[idx_sd["Cin"]], act_case[idx_sd["B"]] }, + { act_case[idx_sd["kw"]], act_case[idx_sd["kh"]], act_case[idx_sd["Cin"]], act_case[idx_sd["Cout"]] }, + GGML_TYPE_F32, 1, 1, p0, p1, 1, 1, false)); + } + test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {4096, 1, 1, 1}, {1, 1, 1, 1})); test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {4096, 1, 1, 1}, {1, 512, 1, 1}));
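+    // A minimal host-side sketch of the implicit-GEMM mapping that the CUDA conv2d kernels
+    // above rely on: no im2col buffer is ever materialized, every GEMM coordinate is mapped
+    // back to a conv2d coordinate on the fly (the role of fastdiv/fastmodulo in loadInput()).
+    // In the GEMM view there are M = N*OH*OW rows (output pixels), N = Cout columns (filters)
+    // and a reduction of length K = Cin*KH*KW (filter taps). The lambda assumes NCHW input,
+    // KCRS filters and unit dilation; it is illustrative only and not wired into any test case.
+    auto conv2d_implicit_ref_elem = [](const float * input, const float * kernel,
+                                       int C, int H, int W, int R, int S,
+                                       int OW, int stride, int pad,
+                                       int n, int gemm_m, int gemm_n) -> float {
+        const int K  = C * R * S;                  // GEMM reduction length
+        const int oh = gemm_m / OW;                // output row of this GEMM row
+        const int ow = gemm_m % OW;                // output column of this GEMM row
+        float acc = 0.0f;
+        for (int k = 0; k < K; ++k) {
+            const int c = k / (R * S);             // input channel
+            const int r = (k % (R * S)) / S;       // filter row
+            const int s = k % S;                   // filter column
+            const int ih = oh * stride - pad + r;  // input row
+            const int iw = ow * stride - pad + s;  // input column
+            if (ih < 0 || ih >= H || iw < 0 || iw >= W) {
+                continue;                          // implicit zero padding (the zfill paths above)
+            }
+            acc += input[((n * C + c) * H + ih) * W + iw] * kernel[gemm_n * K + k];
+        }
+        return acc;                                // == output[n][gemm_n][oh][ow]
+    };
+    (void) conv2d_implicit_ref_elem;
+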