diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cu b/ggml/src/ggml-cuda/conv2d-implicit.cu
index 529a0b50fd..bd67ac2b86 100644
--- a/ggml/src/ggml-cuda/conv2d-implicit.cu
+++ b/ggml/src/ggml-cuda/conv2d-implicit.cu
@@ -781,9 +781,9 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
     constexpr unsigned int MMA_M = 16;
     constexpr unsigned int MMA_N = 8;
 
-    const unsigned int K = param.c * param.r * param.s;
+    const unsigned int K = param.c;
     const uint inChannelOffset = param.c * param.w;
-    const uint weightKOffset = K;
+    const uint weightKOffset = param.c * param.r * param.s;
     const unsigned int PQ = param.Ow * param.Oh;
     const unsigned int KPQ = param.k * PQ;
 
@@ -799,18 +799,25 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
     constexpr unsigned int mma_tiles_per_warp_n = WN / MMA_N;
 
     const unsigned int z = blockIdx.z;
-    const unsigned int ks = (ksplit > 0) ? (weightKOffset + ksplit - 1) / ksplit : weightKOffset;
+    const unsigned int ks = (ksplit > 0) ? (K + ksplit - 1) / ksplit : K;
     const unsigned int start_k = (ksplit > 0) ? z * ks : 0;
-    const unsigned int end_k = min(start_k + ks, weightKOffset);
+    const unsigned int end_k = min(start_k + ks, K);
     const unsigned int num_block_tiles_k = (ks + (BK-1)) / BK;
 
+    constexpr unsigned int TILE_COLS_VECTORIZED = BK / 8;
+    constexpr unsigned int ROW_STEP = NUM_THREADS / TILE_COLS_VECTORIZED;
+    constexpr unsigned int A_K_STRID = BM / ROW_STEP;
+    constexpr unsigned int B_K_STRID = BN / ROW_STEP;
+    unsigned int masks_a[A_K_STRID][2];
+    unsigned int element_offset_a[A_K_STRID];
+
     // calculate block/warp indices
     const unsigned int block_m = blockIdx.y;
     const unsigned int block_n = blockIdx.x;
     const unsigned int warp_m = threadIdx.y;
     const unsigned int warp_n = threadIdx.x / 32;
+    const unsigned int thread_idx = threadIdx.y * blockDim.x + threadIdx.x;
 
     // double buffering
     extern __shared__ half shmem[];
@@ -858,12 +865,21 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
     float4 A_gmem_cache_reg[4];
     float4 B_gmem_cache_reg[4];
 
+
+    prepareIteratorA<BM, BK, ROW_STEP, A_K_STRID>(thread_idx, masks_a, element_offset_a, param);
+
+
     // prefetch the first block tile of A,B into shared memory
     const half* A_block_gmem = input;
     const half* B_block_gmem = kernel + block_n * BN * weightKOffset;
 
-    tileMemcpySwizzleA(A_block_gmem, A_block_smem, start_k, end_k, inChannelOffset, param);
-    tileMemcpySwizzleB(B_block_gmem, B_block_smem, start_k, end_k, weightKOffset, param);
+    int s = 0;
+    int r = 0;
+    while (r < param.r) {
+    // for (int r = 0; r < param.r; ++r) {
+
+        tileMemcpySwizzleA(A_block_gmem, A_block_smem, r, s, masks_a, element_offset_a, thread_idx, start_k, end_k, inChannelOffset, param);
+        tileMemcpySwizzleB(B_block_gmem, B_block_smem, r, s, start_k, end_k, weightKOffset, param);
 
     int offset_direction = 1;
 
@@ -871,8 +887,8 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
         __syncthreads();
 
        if (block_k != num_block_tiles_k){
-            tileMemcpyLoadA(A_block_gmem, A_gmem_cache_reg, block_k * BK, start_k, end_k, inChannelOffset, param);
-            tileMemcpyLoadB(B_block_gmem, B_gmem_cache_reg, block_k * BK, start_k, end_k, weightKOffset, param);
+            tileMemcpyLoadA(A_block_gmem, A_gmem_cache_reg, r, s, block_k * BK, start_k, end_k, inChannelOffset, param);
+            tileMemcpyLoadB(B_block_gmem, B_gmem_cache_reg, r, s, block_k * BK, start_k, end_k, weightKOffset, param);
        }
        half* A_warp_tile = A_block_smem + A_warp_tile_offset;
        half* B_warp_tile = B_block_smem + B_warp_tile_offset;
@@ -926,7 +942,14 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
            tileMemcpySwizzleStore(A_gmem_cache_reg, A_block_smem);
            tileMemcpySwizzleStore(B_gmem_cache_reg, B_block_smem);
        }
-    }
+    } // iter block_k
+
+        s++;
+        if (s == param.s) {
+            s = 0;
+            r++;
+        }
+    } // iter r
 
     // reuse smem
     half *smemoutput = shmem;
@@ -1166,7 +1189,8 @@ static void conv2d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const floa
         ks = 16;
         for (j = 2; j <= ks; j++){
             const int remainder = (BlocksM * BlocksN * j) % nsm;
-            if ((P.c * P.r * P.s) % (8*j) == 0){
+            // if ((P.c * P.r * P.s) % (8*j) == 0){
+            if ((P.c) % (8*j) == 0){
                 if (remainder == 0) {
                     candidate = j;
                     max_remaining_waves = 0;
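Illustration (not part of the patch): the kernel change above shrinks the implicit-GEMM reduction from K = param.c * param.r * param.s to K = param.c and moves the filter taps (r, s) into an outer loop, which is also why the k-split heuristic in conv2d_implicit_cuda_f16 now tests divisibility of P.c alone. The standalone host sketch below, with made-up sizes, checks that the flat weight offset ki = (r*S + s)*C + c used by the patched tileMemcpy* helpers enumerates exactly the same sequence as the old single loop over K = C*R*S:

// Illustrative only: the loop decomposition used by the patch, verified on the host.
#include <cassert>

int main() {
    const int C = 8, R = 3, S = 3;                    // channels, filter height, filter width (made-up sizes)
    int k = 0;                                        // old flat reduction index over K = C*R*S
    for (int r = 0; r < R; ++r) {                     // new outer loop over filter rows
        for (int s = 0; s < S; ++s) {                 // ... and filter columns
            for (int c = 0; c < C; ++c) {             // inner reduction now runs over C only
                const int ki = (r * S + s) * C + c;   // same formula as in the patched helpers
                assert(ki == k);                      // both orderings visit identical weight offsets
                ++k;
            }
        }
    }
    return 0;
}

The same identity is what lets the k-split path partition only the channel dimension: each blockIdx.z slice now covers a contiguous range of c that is reused for every (r, s) tap.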
diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cuh b/ggml/src/ggml-cuda/conv2d-implicit.cuh
index 981a183fd9..0f25b38dd6 100644
--- a/ggml/src/ggml-cuda/conv2d-implicit.cuh
+++ b/ggml/src/ggml-cuda/conv2d-implicit.cuh
@@ -26,12 +26,89 @@ typedef struct{
 } param_t;
 
+/// Clears the predicates
+
+template <int K_STRID>
+__host__ __device__ void clear_mask(unsigned int masks_[][2], bool clear = true) {
+
+#pragma unroll
+    for (int s = 0; s < K_STRID; ++s) {
+        masks_[s][0] = clear ? 0 : masks_[s][0];
+        masks_[s][1] = clear ? 0 : masks_[s][1];
+    }
+}
+
+template <unsigned int TILE_ROWS,
+unsigned int TILE_COLS,
+unsigned int ROW_STEP,
+unsigned int A_K_STRID>
+__device__ void prepareIteratorA(const int thread_idx,
+                                 unsigned int masks[][2],
+                                 unsigned int element_offset[],
+                                 const param_t param){
+    int offset_n[A_K_STRID];
+    int offset_p[A_K_STRID];
+    int offset_q[A_K_STRID];
+
+    constexpr unsigned int TILE_COLS_VECTORIZED = TILE_COLS / 8;
+    unsigned int thread_row = thread_idx / TILE_COLS_VECTORIZED;
+    const unsigned int chw = param.c * param.h * param.w;
+
+#pragma unroll
+    for (int s = 0; s < A_K_STRID; ++s) {
+
+        // pointer_[s] = reinterpret_cast(ptr);
+
+        // int offset_npq = threadblock_offset.row() + thread_coord.strided() + s * ThreadMap::Delta::kStrided;
+        const unsigned int gemm_i = blockIdx.y * TILE_ROWS + thread_row;
+        offset_n[s] = fastdiv(gemm_i, param.OHOW_fastdiv);
+        unsigned int npq_res = fastmodulo(gemm_i, param.OHOW_fastdiv);
+        offset_p[s] = fastdiv(npq_res, param.OW_fastdiv); //* param.u - param.p;
+        offset_q[s] = fastmodulo(npq_res, param.OW_fastdiv); // * param.v - param.q;
+        const int h = offset_p[s] * param.u - param.p;
+        const int w = offset_q[s] * param.v - param.q;
+
+        // if(threadIdx.x < 32 && threadIdx.y == 0 && blockIdx.x == 0 && blockIdx.y == 0)
+        //     printf("%d, %d : %d, %d, %d, %d offset (%d, %d, %d), kele %llu Kcont %d\n ", thread_idx, s,
+        // //     printf("[%s - %d] %d, %d : %d, %d, %d, %d\n ", __FUNCTION__, __LINE__, thread_idx, s,
+        //            threadblock_offset.row(), thread_coord.strided(), ThreadMap::Delta::kStrided,
+        //            offset_npq, offset_n[s], offset_p[s], offset_q[s], AccessType::kElements,
+        //            ThreadMap::Iterations::kContiguous);
+
+        element_offset[s] = offset_n[s] * chw + h * param.c * param.w + w * param.c;
+        thread_row += ROW_STEP;
+    }
+
+    clear_mask<A_K_STRID>(masks);
+
+    for (int r = 0; r < param.r; ++r) {
+#pragma unroll
+        for (int s_idx = 0; s_idx < A_K_STRID; ++s_idx) {
+            const int h = offset_p[s_idx] * param.u - param.p + r * param.d_h;
+
+            bool pred = (offset_n[s_idx] < param.n && h >= 0 && h < param.h);
+            masks[s_idx][0] |= (pred << r);
+        }
+    }
+
+    for (int s = 0; s < param.s; ++s) {
+#pragma unroll
+        for (int s_idx = 0; s_idx < A_K_STRID; ++s_idx) {
+            const int w = offset_q[s_idx] * param.v - param.q + s * param.d_w;
+            bool pred = (w >= 0 && w < param.w);
+            masks[s_idx][1] |= (pred << s);
+        }
+    }
+}
+
 // same as above, but writes are swizzled to avoid bank conflicts when shared memory is read later in the kernel
 template <unsigned int TILE_ROWS,
 unsigned int TILE_COLS,
 unsigned int NUM_THREADS>
 __device__ __forceinline__ void tileMemcpySwizzleB(
     const half* src,
     half* dst,
+    const unsigned int curR,
+    const unsigned int curS,
     const unsigned int start_k,
     const unsigned int end_k,
     const unsigned int src_stride,
@@ -60,10 +137,12 @@ __device__ __forceinline__ void tileMemcpySwizzleB(
     unsigned int thread_row = thread_idx / TILE_COLS_VECTORIZED;
     const unsigned int thread_col = thread_idx % TILE_COLS_VECTORIZED;
 
-    const unsigned int ki = start_k+thread_col*8;
-    const unsigned int curR = fastdiv(ki, param.SC_fastdiv); // channel offset
-    const unsigned int curS = fastdiv(fastmodulo(ki, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
-    const unsigned int curC = fastmodulo(fastmodulo(ki, param.SC_fastdiv), param.C_fastdiv); //
+    // const unsigned int ki = (curR*param.s+curS)*param.c + start_k+thread_col*8;
+    // const unsigned int curR = fastdiv(ki, param.SC_fastdiv); // channel offset
+    // const unsigned int curS = fastdiv(fastmodulo(ki, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
+    // const unsigned int curC = fastmodulo(fastmodulo(ki, param.SC_fastdiv), param.C_fastdiv); //
+    const unsigned int curC = start_k+thread_col*8;
+    const unsigned int ki = (curR*param.s+curS)*param.c + curC;
 
 #pragma unroll
     for (unsigned int i = 0; i < NUM_ITERS; i++){
@@ -72,7 +151,7 @@ __device__ __forceinline__ void tileMemcpySwizzleB(
         unsigned int dst_index = thread_row * TILE_COLS_VECTORIZED + thread_col;
         dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_1) >> SWIZZLE_BITS_1);
         dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_2) >> SWIZZLE_BITS_2);
-        if (thread_row + blockIdx.x * TILE_ROWS < param.k && curR < param.r && curS < param.s && curC < param.c && ki < end_k){
+        if (thread_row + blockIdx.x * TILE_ROWS < param.k && curC < end_k){
             dst_float4[dst_index] = reinterpret_cast<const float4 *>(&src[src_index])[0];
         }else{ // read 4 halves
             dst_float4[dst_index] = make_float4(0.f, 0.f, 0.f, 0.f);
@@ -95,6 +174,11 @@ unsigned int NUM_THREADS>
 __device__ __forceinline__ void tileMemcpySwizzleA(
     const half* src,
     half* dst,
+    const unsigned int curR,
+    const unsigned int curS,
+    unsigned int masks[][2],
+    unsigned int element_offset[],
+    const unsigned int thread_idx,
     const unsigned int start_k,
     const unsigned int end_k,
     const unsigned int inChannelOffset,
@@ -115,7 +199,7 @@ __device__ __forceinline__ void tileMemcpySwizzleA(
     constexpr unsigned int TILE_COLS_VECTORIZED = TILE_COLS / 8;
     static_assert(NUM_THREADS % TILE_COLS_VECTORIZED == 0);
     // flatten out 2d grid of threads into in order of increasing threadIdx.x
-    const unsigned int thread_idx = threadIdx.y * blockDim.x + threadIdx.x;
+    // const unsigned int thread_idx = threadIdx.y * blockDim.x + threadIdx.x;
 
     // assign each thread a row/column in the tile, calculate how many iterations we need
     // to cover the whole tile
@@ -126,11 +210,27 @@ __device__ __forceinline__ void tileMemcpySwizzleA(
     const unsigned int ki = start_k+thread_col*8;
     const unsigned int chw = param.c * param.h * param.w;
-    const unsigned int curR = fastdiv(ki, param.SC_fastdiv); // channel offset
-    const unsigned int curS = fastdiv(fastmodulo(ki, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
-    const unsigned int curC = fastmodulo(fastmodulo(ki, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
-
-
+    // const unsigned int curR = fastdiv(ki, param.SC_fastdiv); // channel offset
+    // const unsigned int curS = fastdiv(fastmodulo(ki, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
+    // const unsigned int curC = fastmodulo(fastmodulo(ki, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
+    const unsigned int curC = ki;
+    // #pragma unroll
+    // for (unsigned int i = 0; i < NUM_ITERS; i++){
+    //     bool valid = (masks[i][0] & (1u << curR)) && (masks[i][1] & (1u << curS));
+    //     // apply swizzle to the dst index
+    //     unsigned int dst_index = thread_row * TILE_COLS_VECTORIZED + thread_col;
+    //     dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_1) >> SWIZZLE_BITS_1);
+    //     dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_2) >> SWIZZLE_BITS_2);
+    //     if (valid && ki < end_k){
+    //         if(element_offset[i]+curC >= 327680 || element_offset[i]+curC < 0)
+    //             printf("%d, %d, %d, %d, %d, %d, %d, %d, %d \n", threadIdx.x, threadIdx.y, blockIdx.x, blockIdx.y,
+    //                 i, element_offset[i], curR, curS, curC);
+    //         dst_float4[dst_index] = reinterpret_cast<const float4 *>(&src[element_offset[i]+curC])[0];
+    //     } else{
+    //         dst_float4[dst_index] = make_float4(0.f, 0.f, 0.f, 0.f);
+    //     }
+    //     thread_row += ROW_STEP;
+    // }
 #pragma unroll
     for (unsigned int i = 0; i < NUM_ITERS; i++){
         unsigned int gemm_i = blockIdx.y * TILE_ROWS + thread_row;
@@ -170,7 +270,8 @@ unsigned int ELEMENTS_PER_THREAD>
 __device__ __forceinline__ void tileMemcpyLoadA(
     const half* src,
     float4 (&dst_reg)[ELEMENTS_PER_THREAD],
-    // const unsigned int src_stride,
+    const unsigned int curR,
+    const unsigned int curS,
     const unsigned int block_k,
     const unsigned int start_k,
     const unsigned int end_k,
@@ -199,9 +300,10 @@ __device__ __forceinline__ void tileMemcpyLoadA(
 
     const unsigned int ki = start_k+block_k+thread_col*8;
     const unsigned int chw = param.c * param.h * param.w;
-    const unsigned int curR = fastdiv(ki, param.SC_fastdiv); // channel offset
-    const unsigned int curS = fastdiv(fastmodulo(ki, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
-    const unsigned int curC = fastmodulo(fastmodulo(ki, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
+    // const unsigned int curR = fastdiv(ki, param.SC_fastdiv); // channel offset
+    // const unsigned int curS = fastdiv(fastmodulo(ki, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
+    // const unsigned int curC = fastmodulo(fastmodulo(ki, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
+    const unsigned int curC = ki;
 
 #pragma unroll
     for (unsigned int i = 0; i < NUM_ITERS; i++){
@@ -240,6 +342,8 @@ unsigned int ELEMENTS_PER_THREAD>
 __device__ __forceinline__ void tileMemcpyLoadB(
     const half* src,
     float4 (&dst_reg)[ELEMENTS_PER_THREAD],
+    const unsigned int curR,
+    const unsigned int curS,
     const unsigned int block_k,
     const unsigned int start_k,
     const unsigned int end_k,
@@ -265,15 +369,16 @@ __device__ __forceinline__ void tileMemcpyLoadB(
     // compile time check that we provided the right amount of registers for storage
     static_assert(ELEMENTS_PER_THREAD == NUM_ITERS);
 
-    const unsigned int ki = start_k+block_k+thread_col*8;
-    const unsigned int curR = fastdiv(ki, param.SC_fastdiv); // channel offset
-    const unsigned int curS = fastdiv(fastmodulo(ki, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
-    const unsigned int curC = fastmodulo(fastmodulo(ki, param.SC_fastdiv), param.C_fastdiv); //
+    // const unsigned int curR = fastdiv(ki, param.SC_fastdiv); // channel offset
+    // const unsigned int curS = fastdiv(fastmodulo(ki, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
+    // const unsigned int curC = fastmodulo(fastmodulo(ki, param.SC_fastdiv), param.C_fastdiv); //
+    const unsigned int curC = start_k+block_k+thread_col*8;
+    const unsigned int ki = (curR*param.s+curS)*param.c + curC;
 
 #pragma unroll
     for (unsigned int i = 0; i < NUM_ITERS; i++){
         const unsigned int src_index = thread_row * src_stride + ki;
-        if (thread_row + blockIdx.x * TILE_ROWS < param.k && curR < param.r && curS < param.s && curC < param.c && ki < end_k){
+        if (thread_row + blockIdx.x * TILE_ROWS < param.k && curC < end_k){
             dst_reg[i] = reinterpret_cast<const float4 *>(&src[src_index])[0];
         }else{ // read 4 halves
             dst_reg[i] = make_float4(0.f, 0.f, 0.f, 0.f);
diff --git a/tests/test-conv2d.cpp b/tests/test-conv2d.cpp
index 8ee0747989..daac5c9605 100644
--- a/tests/test-conv2d.cpp
+++ b/tests/test-conv2d.cpp
@@ -301,7 +301,9 @@ static std::vector<std::tuple<int, int, int, int, int, int>> configs = {
     // std::make_tuple(960,320,104,152,3,3),
     // std::make_tuple(1280,1280,26,38,3,3),
     // std::make_tuple(1920,640,32,32,3,3)
-    std::make_tuple(1280,1280,16,16,3,3),
+    // std::make_tuple(1280,1280,16,16,3,3),
+    // std::make_tuple(32,8,24,24,3,3),
+    std::make_tuple(640,640,64,64,3,3),
     // std::make_tuple(320,640,32,32,3,3),
     // std::make_tuple(4,320,96,128,3,3),
     // std::make_tuple(320,4,96,128,3,3),
@@ -671,7 +673,7 @@ int main(void)
     // fprintf(stderr, "%s: compute buffer size: %.2f MB\n", __func__, mem_size/1024.0f/1024.0f);
 
 
-    int iterations = 0;
+    int iterations = 20;
     double run_time0;
     std::vector<float> im2col_data = compute_graph(model, allocr, build_graph_0, iterations, &run_time0);
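Illustration (not part of the patch): prepareIteratorA precomputes, for each A-tile row a thread owns, one bitmask over filter rows r and one over filter columns s; bit r (or s) is set when the corresponding input coordinate is in bounds, so every (r, s) iteration can replace the old per-element bounds arithmetic with two bit tests and zero-fill otherwise. The sketch below is a simplified host-side analogue with made-up names (ConvGeom, build_masks) and it drops the batch-index check (offset_n < param.n) that the device code also folds into the row mask:

// Illustrative only: predicate bitmasks for conv2d boundary handling, one output pixel at a time.
#include <cstdio>

struct ConvGeom {                 // illustrative subset of param_t
    int h, w;                     // input height/width
    int pad_h, pad_w;             // padding
    int stride_h, stride_w;       // strides
    int dil_h, dil_w;             // dilations
    int r, s;                     // filter height/width
};

static void build_masks(const ConvGeom & g, int out_y, int out_x, unsigned int mask[2]) {
    mask[0] = mask[1] = 0;
    for (int r = 0; r < g.r; ++r) {                                   // bit r: input row in bounds for filter row r
        const int in_y = out_y * g.stride_h - g.pad_h + r * g.dil_h;
        mask[0] |= (unsigned int)(in_y >= 0 && in_y < g.h) << r;
    }
    for (int s = 0; s < g.s; ++s) {                                   // bit s: input column in bounds for filter column s
        const int in_x = out_x * g.stride_w - g.pad_w + s * g.dil_w;
        mask[1] |= (unsigned int)(in_x >= 0 && in_x < g.w) << s;
    }
}

int main() {
    const ConvGeom g = {64, 64, 1, 1, 1, 1, 1, 1, 3, 3};              // made-up 3x3 conv, pad 1, stride 1
    unsigned int mask[2];
    build_masks(g, /*out_y=*/0, /*out_x=*/0, mask);                   // top-left output pixel
    for (int r = 0; r < g.r; ++r) {
        for (int s = 0; s < g.s; ++s) {
            const bool valid = (mask[0] >> r & 1u) && (mask[1] >> s & 1u);
            printf("r=%d s=%d %s\n", r, s, valid ? "load" : "zero-fill");
        }
    }
    return 0;
}

For the top-left output pixel with 1-pixel padding, the r = 0 and s = 0 taps fall outside the input, so those iterations report zero-fill, which corresponds to the make_float4(0.f, 0.f, 0.f, 0.f) path in the copy helpers above.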