diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cu b/ggml/src/ggml-cuda/conv2d-implicit.cu index d04c379bb0..5ec616a978 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cu +++ b/ggml/src/ggml-cuda/conv2d-implicit.cu @@ -811,7 +811,7 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, constexpr unsigned int B_K_STRID = BN / ROW_STEP; unsigned int masks_a[A_K_STRID][2]; - unsigned int element_offset_a[A_K_STRID]; + int64_t element_offset_a[A_K_STRID]; // calculate block/warp indices const unsigned int block_m = blockIdx.y; @@ -867,6 +867,7 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, float4 B_gmem_cache_reg[4]; + prepareIteratorA(thread_idx, masks_a, element_offset_a, param); // prefetch the first block tile of A,B into shared memory @@ -874,7 +875,8 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, const half* A_block_gmem = input; const half* B_block_gmem = kernel + block_n * BN * weightKOffset; - tileMemcpySwizzleA(A_block_gmem, A_block_smem, 0, 0, masks_a, element_offset_a, thread_idx, start_k, end_k, inChannelOffset, param); + tileMemcpySwizzleA(A_block_gmem, A_block_smem, 0, 0, masks_a, element_offset_a, + thread_idx, start_k, end_k, inChannelOffset, param); tileMemcpySwizzleB(B_block_gmem, B_block_smem, 0, 0, start_k, end_k, weightKOffset, param); int offset_direction = 1; @@ -899,6 +901,9 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, next_idx = 2; } } + + add_byte_offset(element_offset_a, param.inc_next[next_idx]); + if (next_idx == 2) { ++block_k; } @@ -911,7 +916,9 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, // if (block_k != num_block_tiles_k){ if (block_krs != num_block_tiles_krs){ - tileMemcpyLoadA(A_block_gmem, A_gmem_cache_reg, r, s, block_k * BK, start_k, end_k, inChannelOffset, param); + tileMemcpyLoadA(A_block_gmem, A_gmem_cache_reg, r, s, + masks_a, element_offset_a, thread_idx, block_k * BK, + start_k, end_k, inChannelOffset, param); tileMemcpyLoadB(B_block_gmem, B_gmem_cache_reg, r, s, block_k * BK, start_k, end_k, weightKOffset, param); } half* A_warp_tile = A_block_smem + A_warp_tile_offset; @@ -1096,7 +1103,7 @@ template<<>>(X_H, K_H, Y_H.get(), P); + int64_t inc[3]; + // next S + inc[0] = int64_t(P.c) * P.d_w; + // next R + inc[1] = int64_t(P.w * P.c) * P.d_h - (P.s - 1) * P.c * P.d_w; + // next C + inc[2] = BK - int64_t(P.r - 1) * P.w * P.c * P.d_h - int64_t(P.s - 1) * P.c * P.d_w ; + memcpy(P.inc_next, inc, sizeof(int64_t)*3); + const unsigned int nrows = P.n * P.k * P.Oh * P.Ow; const unsigned int blockx = (nrows + 511) / 512; const dim3 block_nums(blockx, 1, 1); @@ -1116,7 +1132,7 @@ static void launch_conv2d_implicit_split_kernel(ggml_backend_cuda_context & ctx, reduce_f32<<>>(Y_H.get(), Y_D, nrows, ksplit); } -static void conv2d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const float * X_D, const half * K_D, float * Y_D, int cc, const param_t P, cudaStream_t st) { +static void conv2d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const float * X_D, const half * K_D, float * Y_D, int cc, param_t P, cudaStream_t st) { // if (GGML_CUDA_CC_IS_NVIDIA(cc) && turing_mma_available(cc) && P.c % 8 == 0 && (P.r > 1 || P.s > 1)) { if (GGML_CUDA_CC_IS_NVIDIA(cc) && turing_mma_available(cc) && P.c % 8 == 0) { @@ -1279,6 +1295,15 @@ static void conv2d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const floa } } + int64_t inc[3]; + // next S + inc[0] = int64_t(P.c) * P.d_w; + // next R + inc[1] = int64_t(P.w * P.c) * P.d_h - (P.s - 1) * P.c * P.d_w; + // next C + inc[2] = BK_dim - int64_t(P.r - 1) * P.w * P.c * P.d_h - int64_t(P.s - 1) * P.c * P.d_w ; + memcpy(P.inc_next, inc, sizeof(int64_t)*3); + cudaFuncSetAttribute(conv2d_implicit_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, 65536); // set shared memory limit to 64KB which is maximum for sm_75 dim3 gridDim(BlocksN, BlocksM); @@ -1340,6 +1365,8 @@ void ggml_cuda_op_conv2d_implicit(ggml_backend_cuda_context & ctx, ggml_tensor * const uint OC = kernel->ne[3]; // ouptut_chanles const uint B = input->ne[3]; // n_batches + + param_t params = { B, IC, IH, IW, OC, KH, KW, ST_Y, ST_X, PD_Y, PD_X, DL_Y, DL_X, OH, OW, init_fastdiv_values(KW*IC), init_fastdiv_values(OW), diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cuh b/ggml/src/ggml-cuda/conv2d-implicit.cuh index 0f25b38dd6..22b597f7bb 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cuh +++ b/ggml/src/ggml-cuda/conv2d-implicit.cuh @@ -23,6 +23,7 @@ typedef struct{ uint3 RS_fastdiv; uint3 S_fastdiv; uint3 OHOW_fastdiv; + int64_t inc_next[3]; } param_t; @@ -38,13 +39,21 @@ __host__ __device__ void clear_mask(unsigned int masks_[][2], bool clear = true) } } +template +__host__ __device__ void add_byte_offset(int64_t element_offset[], const int64_t offset){ +#pragma unroll + for (int s = 0; s < K_STRID; ++s) { + element_offset[s] += offset; + } +} + template __device__ void prepareIteratorA(const int thread_idx, unsigned int masks[][2], - unsigned int element_offset[], + int64_t element_offset[], const param_t param){ int offset_n[A_K_STRID]; int offset_p[A_K_STRID]; @@ -176,8 +185,8 @@ __device__ __forceinline__ void tileMemcpySwizzleA( half* dst, const unsigned int curR, const unsigned int curS, - unsigned int masks[][2], - unsigned int element_offset[], + const unsigned int masks[][2], + const int64_t element_offset[], const unsigned int thread_idx, const unsigned int start_k, const unsigned int end_k, @@ -208,52 +217,52 @@ __device__ __forceinline__ void tileMemcpySwizzleA( unsigned int thread_row = thread_idx / TILE_COLS_VECTORIZED; const unsigned int thread_col = thread_idx % TILE_COLS_VECTORIZED; - const unsigned int ki = start_k+thread_col*8; + // const unsigned int ki = start_k+thread_col*8; const unsigned int chw = param.c * param.h * param.w; // const unsigned int curR = fastdiv(ki, param.SC_fastdiv); // channel offset // const unsigned int curS = fastdiv(fastmodulo(ki, param.SC_fastdiv), param.C_fastdiv); // kernel r offset // const unsigned int curC = fastmodulo(fastmodulo(ki, param.SC_fastdiv), param.C_fastdiv); // kernel r offset - const unsigned int curC = ki; - // #pragma unroll - // for (unsigned int i = 0; i < NUM_ITERS; i++){ - // bool valid = (masks[i][0] & (1u << curR)) && (masks[i][1] & (1u << curS)); - // // apply swizzle to the dst index - // unsigned int dst_index = thread_row * TILE_COLS_VECTORIZED + thread_col; - // dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_1) >> SWIZZLE_BITS_1); - // dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_2) >> SWIZZLE_BITS_2); - // if (valid && ki < end_k){ - // if(element_offset[i]+curC >= 327680 || element_offset[i]+curC < 0) - // printf("%d, %d, %d, %d, %d, %d, %d, %d, %d \n", threadIdx.x, threadIdx.y, blockIdx.x, blockIdx.y, - // i, element_offset[i], curR, curS, curC); - // dst_float4[dst_index] = reinterpret_cast(&src[element_offset[i]+curC])[0]; - // } else{ - // dst_float4[dst_index] = make_float4(0.f, 0.f, 0.f, 0.f); - // } - // thread_row += ROW_STEP; - // } + const unsigned int curC = start_k+thread_col*8; #pragma unroll for (unsigned int i = 0; i < NUM_ITERS; i++){ - unsigned int gemm_i = blockIdx.y * TILE_ROWS + thread_row; - unsigned int n = fastdiv(gemm_i, param.OHOW_fastdiv); - unsigned int npq_res = fastmodulo(gemm_i, param.OHOW_fastdiv); - int posh_ori = fastdiv(npq_res, param.OW_fastdiv) * param.u - param.p; - int posw_ori = fastmodulo(npq_res, param.OW_fastdiv) * param.v - param.q; - // unsigned int inOffset = n * param.c * param.h * param.w; - int curH = posh_ori + curR * param.d_h; // input h - int curW = posw_ori + curS * param.d_w; // input w + bool valid = (masks[i][0] & (1u << curR)) && (masks[i][1] & (1u << curS)); // apply swizzle to the dst index unsigned int dst_index = thread_row * TILE_COLS_VECTORIZED + thread_col; dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_1) >> SWIZZLE_BITS_1); dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_2) >> SWIZZLE_BITS_2); - if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && - curR < param.r && curS < param.s && curC < param.c && n < param.n && ki < end_k){ - const unsigned int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC; - dst_float4[dst_index] = reinterpret_cast(&src[n * chw + inOffsetTmp])[0]; + if (valid && curC < end_k){ + if(element_offset[i] >= 327680 || element_offset[i] < 0) + printf("%d, %d, %d, %d, %d, %d, %d, %d, %d \n", threadIdx.x, threadIdx.y, blockIdx.x, blockIdx.y, + i, element_offset[i], curR, curS, curC); + dst_float4[dst_index] = reinterpret_cast(&src[element_offset[i]])[0]; } else{ dst_float4[dst_index] = make_float4(0.f, 0.f, 0.f, 0.f); } thread_row += ROW_STEP; } + // #pragma unroll + // for (unsigned int i = 0; i < NUM_ITERS; i++){ + // unsigned int gemm_i = blockIdx.y * TILE_ROWS + thread_row; + // unsigned int n = fastdiv(gemm_i, param.OHOW_fastdiv); + // unsigned int npq_res = fastmodulo(gemm_i, param.OHOW_fastdiv); + // int posh_ori = fastdiv(npq_res, param.OW_fastdiv) * param.u - param.p; + // int posw_ori = fastmodulo(npq_res, param.OW_fastdiv) * param.v - param.q; + // // unsigned int inOffset = n * param.c * param.h * param.w; + // int curH = posh_ori + curR * param.d_h; // input h + // int curW = posw_ori + curS * param.d_w; // input w + // // apply swizzle to the dst index + // unsigned int dst_index = thread_row * TILE_COLS_VECTORIZED + thread_col; + // dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_1) >> SWIZZLE_BITS_1); + // dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_2) >> SWIZZLE_BITS_2); + // if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && + // curR < param.r && curS < param.s && curC < param.c && n < param.n && ki < end_k){ + // const unsigned int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC; + // dst_float4[dst_index] = reinterpret_cast(&src[n * chw + inOffsetTmp])[0]; + // } else{ + // dst_float4[dst_index] = make_float4(0.f, 0.f, 0.f, 0.f); + // } + // thread_row += ROW_STEP; + // } #else GGML_UNUSED(src); GGML_UNUSED(dst); @@ -272,6 +281,9 @@ __device__ __forceinline__ void tileMemcpyLoadA( float4 (&dst_reg)[ELEMENTS_PER_THREAD], const unsigned int curR, const unsigned int curS, + const unsigned int masks[][2], + const int64_t element_offset[], + const unsigned int thread_idx, const unsigned int block_k, const unsigned int start_k, const unsigned int end_k, @@ -285,45 +297,52 @@ __device__ __forceinline__ void tileMemcpyLoadA( static_assert(NUM_THREADS % TILE_COLS_VECTORIZED == 0); // flatten out 2d grid of threads into in order of increasing threadIdx.x - const unsigned int thread_idx = threadIdx.y * blockDim.x + threadIdx.x; - // assign each thread a row/column in the tile, calculate how many iterations we need // to cover the whole tile constexpr unsigned int ROW_STEP = NUM_THREADS / TILE_COLS_VECTORIZED; constexpr unsigned int NUM_ITERS = TILE_ROWS / ROW_STEP; - unsigned int thread_row = thread_idx / TILE_COLS_VECTORIZED; + // unsigned int thread_row = thread_idx / TILE_COLS_VECTORIZED; const unsigned int thread_col = thread_idx % TILE_COLS_VECTORIZED; // compile time check that we provided the right amount of registers for storage static_assert(ELEMENTS_PER_THREAD == NUM_ITERS); - const unsigned int ki = start_k+block_k+thread_col*8; - const unsigned int chw = param.c * param.h * param.w; + // const unsigned int ki = start_k+block_k+thread_col*8; + // const unsigned int chw = param.c * param.h * param.w; // const unsigned int curR = fastdiv(ki, param.SC_fastdiv); // channel offset // const unsigned int curS = fastdiv(fastmodulo(ki, param.SC_fastdiv), param.C_fastdiv); // kernel r offset // const unsigned int curC = fastmodulo(fastmodulo(ki, param.SC_fastdiv), param.C_fastdiv); // kernel r offset - const unsigned int curC = ki; + const unsigned int curC = start_k+block_k+thread_col*8;; #pragma unroll for (unsigned int i = 0; i < NUM_ITERS; i++){ - unsigned int gemm_i = blockIdx.y * TILE_ROWS + thread_row; - unsigned int n = fastdiv(gemm_i, param.OHOW_fastdiv); - unsigned int npq_res = fastmodulo(gemm_i, param.OHOW_fastdiv); - int posh_ori = fastdiv(npq_res, param.OW_fastdiv) * param.u - param.p; - int posw_ori = fastmodulo(npq_res, param.OW_fastdiv) * param.v - param.q; - // unsigned int inOffset = n * param.c * param.h * param.w; - int curH = posh_ori + curR * param.d_h; // input h - int curW = posw_ori + curS * param.d_w; // input w - if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && - curR < param.r && curS < param.s && curC < param.c && n < param.n && ki < end_k){ - const unsigned int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC; - dst_reg[i] = reinterpret_cast(&src[n * chw + inOffsetTmp])[0]; + bool valid = (masks[i][0] & (1u << curR)) && (masks[i][1] & (1u << curS)); + if (valid && curC < end_k) { + dst_reg[i] = reinterpret_cast(&src[element_offset[i]])[0]; } else{ dst_reg[i] = make_float4(0.f, 0.f, 0.f, 0.f); } - thread_row += ROW_STEP; } + // #pragma unroll + // for (unsigned int i = 0; i < NUM_ITERS; i++){ + // unsigned int gemm_i = blockIdx.y * TILE_ROWS + thread_row; + // unsigned int n = fastdiv(gemm_i, param.OHOW_fastdiv); + // unsigned int npq_res = fastmodulo(gemm_i, param.OHOW_fastdiv); + // int posh_ori = fastdiv(npq_res, param.OW_fastdiv) * param.u - param.p; + // int posw_ori = fastmodulo(npq_res, param.OW_fastdiv) * param.v - param.q; + // // unsigned int inOffset = n * param.c * param.h * param.w; + // int curH = posh_ori + curR * param.d_h; // input h + // int curW = posw_ori + curS * param.d_w; // input w + // if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h && + // curR < param.r && curS < param.s && curC < param.c && n < param.n && ki < end_k){ + // const unsigned int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC; + // dst_reg[i] = reinterpret_cast(&src[n * chw + inOffsetTmp])[0]; + // } else{ + // dst_reg[i] = make_float4(0.f, 0.f, 0.f, 0.f); + // } + // thread_row += ROW_STEP; + // } #else GGML_UNUSED(src); GGML_UNUSED(dst_reg);