diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cu b/ggml/src/ggml-cuda/conv2d-implicit.cu
index d2d775c9b2..9b2331876b 100644
--- a/ggml/src/ggml-cuda/conv2d-implicit.cu
+++ b/ggml/src/ggml-cuda/conv2d-implicit.cu
@@ -720,6 +720,10 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
     const uint inChannelOffset = param.c * param.w;
     const uint weightKOffset = K;

+    const unsigned int PQ = param.Ow * param.Oh;
+    const unsigned int KPQ = param.k * PQ;
+    const unsigned int NKPQ = param.n * KPQ;
+
     // loop bounds, constexpr where possible allows for loop unrolling
     constexpr unsigned int mma_tiles_per_warp_k = 4;
     constexpr unsigned int mma_tiles_per_warp_m = WM / MMA_M;
@@ -845,14 +849,15 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
         for (int i = 0; i < 2; ++i) {
             __syncthreads();
-
+#pragma unroll
             for (unsigned int mma_m = 0; mma_m < mma_tiles_per_warp_m; mma_m++) {
+                const int output_sts_offset = output_sts_addr + mma_m * MMA_M * BN / 2 - i * mma_tiles_per_warp_n/2 * MMA_N;
                 for (unsigned int mma_n = i * mma_tiles_per_warp_n/2; mma_n < (i+1)*mma_tiles_per_warp_n/2; mma_n++) {
                     uint32_t (&reg_)[2] = reinterpret_cast<uint32_t (&)[2]>(acc_register_[mma_m][mma_n]);
-                    uint idx = output_sts_addr +
-                               mma_m * MMA_M * BN / 2 + (mma_n - i * mma_tiles_per_warp_n/2) * MMA_N;
+                    uint idx = output_sts_offset + mma_n * MMA_N;
+                    // mma_m * MMA_M * BN / 2 + (mma_n - i * mma_tiles_per_warp_n/2) * MMA_N;
                     idx = idx ^ ((idx & 0b1110000000) >> 4);
                     uint32_t* dst_ptr = reinterpret_cast<uint32_t *>(&smemoutput[idx]);
                     dst_ptr[0] = reg_[0];
@@ -861,24 +866,25 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
                 }
             }
             __syncthreads();
-
+            const unsigned int m_i_wn = m_idx + i * WN / 2;
 #pragma unroll
             for (int subk = 0; subk < WN / 2; ++subk){
+                const uint row = m_i_wn + subk;
+#pragma unroll
                 for (int j = 0; j < 4; ++j){
-                    const uint row = m_idx + subk + i * WN / 2;
                     const uint gemm_i = n_idx + j*32;
                     const int n = fastdiv(gemm_i, param.OHOW_fastdiv);
                     const int col = fastmodulo(gemm_i, param.OHOW_fastdiv);
-                    if (n < param.n && row < param.k && col < param.Oh * param.Ow) {
+                    if (n < param.n && row < param.k && col < PQ) {
                         uint idx = output_lds_addr + subk + j*32*BN/2;
                         idx = idx ^ ((idx & 0b1110000000) >> 4);
                         if constexpr (ksplit > 0) {
-                            const uint outOffset = z * param.n * param.k * param.Oh * param.Ow +
-                                                   n * param.k * param.Oh * param.Ow +
-                                                   row * param.Oh * param.Ow + col;
+                            const uint outOffset = z * NKPQ +
+                                                   n * KPQ +
+                                                   row * PQ + col;
                             output[outOffset] = smemoutput[idx];
                         } else {
-                            const uint outOffset = n * param.k * param.Oh * param.Ow + row * param.Oh * param.Ow + col;
+                            const uint outOffset = n * KPQ + row * PQ + col;
                             output[outOffset] = smemoutput[idx];
                         }
                     }
diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cuh b/ggml/src/ggml-cuda/conv2d-implicit.cuh
index 0226e715ce..a49210ddd8 100644
--- a/ggml/src/ggml-cuda/conv2d-implicit.cuh
+++ b/ggml/src/ggml-cuda/conv2d-implicit.cuh
@@ -59,18 +59,20 @@ __device__ __forceinline__ void tileMemcpySwizzleB(
     constexpr unsigned int NUM_ITERS = TILE_ROWS / ROW_STEP;
     unsigned int thread_row = thread_idx / TILE_COLS_VECTORIZED;
     const unsigned int thread_col = thread_idx % TILE_COLS_VECTORIZED;
-    const unsigned int curR = fastdiv(start_k+thread_col*8, param.SC_fastdiv); // channel offset
-    const unsigned int curS = fastdiv(fastmodulo(start_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
-    const unsigned int curC = fastmodulo(fastmodulo(start_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); //
+
+    const unsigned int ki = start_k+thread_col*8;
+    const unsigned int curR = fastdiv(ki, param.SC_fastdiv); // channel offset
+    const unsigned int curS = fastdiv(fastmodulo(ki, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
+    const unsigned int curC = fastmodulo(fastmodulo(ki, param.SC_fastdiv), param.C_fastdiv); //

 #pragma unroll
     for (unsigned int i = 0; i < NUM_ITERS; i++){
         // apply swizzle to the dst index
-        const unsigned int src_index = thread_row * src_stride + start_k + thread_col * 8;
+        const unsigned int src_index = thread_row * src_stride + ki;
         unsigned int dst_index = thread_row * TILE_COLS_VECTORIZED + thread_col;
         dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_1) >> SWIZZLE_BITS_1);
         dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_2) >> SWIZZLE_BITS_2);
-        if (thread_row < param.k && curR < param.r && curS < param.s && curC < param.c && start_k+thread_col*8 < end_k){
+        if (thread_row < param.k && curR < param.r && curS < param.s && curC < param.c && ki < end_k){
             dst_float4[dst_index] = reinterpret_cast<const float4 *>(&src[src_index])[0];
         }else{ // read 4 halves
             dst_float4[dst_index] = make_float4(0.f, 0.f, 0.f, 0.f);
@@ -122,6 +124,12 @@ __device__ __forceinline__ void tileMemcpySwizzleA(
     unsigned int thread_row = thread_idx / TILE_COLS_VECTORIZED;
     const unsigned int thread_col = thread_idx % TILE_COLS_VECTORIZED;

+    const unsigned int ki = start_k+thread_col*8;
+    const unsigned int chw = param.c * param.h * param.w;
+    const unsigned int curR = fastdiv(ki, param.SC_fastdiv); // channel offset
+    const unsigned int curS = fastdiv(fastmodulo(ki, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
+    const unsigned int curC = fastmodulo(fastmodulo(ki, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
+
 #pragma unroll
     for (unsigned int i = 0; i < NUM_ITERS; i++){

@@ -130,10 +138,7 @@ __device__ __forceinline__ void tileMemcpySwizzleA(
         unsigned int npq_res = fastmodulo(gemm_i, param.OHOW_fastdiv);
         int posh_ori = fastdiv(npq_res, param.OW_fastdiv) * param.u - param.p;
         int posw_ori = fastmodulo(npq_res, param.OW_fastdiv) * param.v - param.q;
-        unsigned int inOffset = n * param.c * param.h * param.w;
-        const unsigned int curR = fastdiv(start_k+thread_col*8, param.SC_fastdiv); // channel offset
-        const unsigned int curS = fastdiv(fastmodulo(start_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
-        const unsigned int curC = fastmodulo(fastmodulo(start_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
+        // unsigned int inOffset = n * param.c * param.h * param.w;
         int curH = posh_ori + curR * param.d_h; // input h
         int curW = posw_ori + curS * param.d_w; // input w
         // apply swizzle to the dst index
@@ -141,9 +146,9 @@ __device__ __forceinline__ void tileMemcpySwizzleA(
         dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_1) >> SWIZZLE_BITS_1);
         dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_2) >> SWIZZLE_BITS_2);
         if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h &&
-            curR < param.r && curS < param.s && curC < param.c && start_k+thread_col*8 < end_k){
+            curR < param.r && curS < param.s && curC < param.c && ki < end_k){
             const unsigned int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC;
-            dst_float4[dst_index] = reinterpret_cast<const float4 *>(&src[inOffset + inOffsetTmp])[0];
+            dst_float4[dst_index] = reinterpret_cast<const float4 *>(&src[n * chw + inOffsetTmp])[0];
         } else{
             dst_float4[dst_index] = make_float4(0.f, 0.f, 0.f, 0.f);
         }
@@ -191,6 +196,13 @@ __device__ __forceinline__ void tileMemcpyLoadA(
     // compile time check that we provided the right amount of registers for storage
    static_assert(ELEMENTS_PER_THREAD == NUM_ITERS);

+    const unsigned int ki = start_k+block_k+thread_col*8;
+    const unsigned int chw = param.c * param.h * param.w;
+
+    const unsigned int curR = fastdiv(ki, param.SC_fastdiv); // channel offset
+    const unsigned int curS = fastdiv(fastmodulo(ki, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
+    const unsigned int curC = fastmodulo(fastmodulo(ki, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
+
 #pragma unroll
     for (unsigned int i = 0; i < NUM_ITERS; i++){
         unsigned int gemm_i = blockIdx.y * TILE_ROWS + thread_row;
@@ -198,16 +210,13 @@ __device__ __forceinline__ void tileMemcpyLoadA(
         unsigned int npq_res = fastmodulo(gemm_i, param.OHOW_fastdiv);
         int posh_ori = fastdiv(npq_res, param.OW_fastdiv) * param.u - param.p;
         int posw_ori = fastmodulo(npq_res, param.OW_fastdiv) * param.v - param.q;
-        unsigned int inOffset = n * param.c * param.h * param.w;
-        const unsigned int curR = fastdiv(start_k+block_k+thread_col*8, param.SC_fastdiv); // channel offset
-        const unsigned int curS = fastdiv(fastmodulo(start_k+block_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
-        const unsigned int curC = fastmodulo(fastmodulo(start_k+block_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
+        // unsigned int inOffset = n * param.c * param.h * param.w;
         int curH = posh_ori + curR * param.d_h; // input h
         int curW = posw_ori + curS * param.d_w; // input w
         if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h &&
-            curR < param.r && curS < param.s && curC < param.c && start_k+block_k+thread_col*8 < end_k){
+            curR < param.r && curS < param.s && curC < param.c && ki < end_k){
             const unsigned int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC;
-            dst_reg[i] = reinterpret_cast<const float4 *>(&src[inOffset + inOffsetTmp])[0];
+            dst_reg[i] = reinterpret_cast<const float4 *>(&src[n * chw + inOffsetTmp])[0];
         } else{
             dst_reg[i] = make_float4(0.f, 0.f, 0.f, 0.f);
         }
@@ -256,14 +265,15 @@ __device__ __forceinline__ void tileMemcpyLoadB(
     // compile time check that we provided the right amount of registers for storage
     static_assert(ELEMENTS_PER_THREAD == NUM_ITERS);

-    const unsigned int curR = fastdiv(start_k+block_k+thread_col*8, param.SC_fastdiv); // channel offset
-    const unsigned int curS = fastdiv(fastmodulo(start_k+block_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
-    const unsigned int curC = fastmodulo(fastmodulo(start_k+block_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); //
+    const unsigned int ki = start_k+block_k+thread_col*8;
+    const unsigned int curR = fastdiv(ki, param.SC_fastdiv); // channel offset
+    const unsigned int curS = fastdiv(fastmodulo(ki, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
+    const unsigned int curC = fastmodulo(fastmodulo(ki, param.SC_fastdiv), param.C_fastdiv); //
 #pragma unroll
     for (unsigned int i = 0; i < NUM_ITERS; i++){
-        const unsigned int src_index = thread_row * src_stride + start_k + block_k + thread_col * 8;
-        if (thread_row < param.k && curR < param.r && curS < param.s && curC < param.c && start_k+block_k+thread_col*8 < end_k){
+        const unsigned int src_index = thread_row * src_stride + ki;
+        if (thread_row < param.k && curR < param.r && curS < param.s && curC < param.c && ki < end_k){
             dst_reg[i] = reinterpret_cast<const float4 *>(&src[src_index])[0];
         }else{ // read 4 halves
             dst_reg[i] = make_float4(0.f, 0.f, 0.f, 0.f);
diff --git a/tests/test-conv2d.cpp b/tests/test-conv2d.cpp
index 87abd015dc..e5da8ab056 100644
--- a/tests/test-conv2d.cpp
+++ b/tests/test-conv2d.cpp
@@ -384,8 +384,7 @@ int main(void)
     // for(int i = 0; i < ggml_nelements(wino_res); i++) {
     // for(int i = 0; i < 26*38; i++) {
-    // for(int i = 26*38; i < 2*26*38; i++) {
-    // for(int i = 0; i < conv2d_data.size(); i++) {
+    // // for(int i = 0; i < conv2d_data.size(); i++) {
     //     float diff = fabs(im2col_data[i] - conv2d_data[i]);
     //     // if(diff > 0.5) {
     //     printf("(%7.3f, %7.3f, %.2f, %d) \n",
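
Note on the pattern behind these hunks, as a minimal standalone sketch (not the ggml kernel; the names fill_output, PQ, KPQ, NKPQ below are illustrative): loop-invariant products such as PQ = Ow*Oh, KPQ = K*PQ and NKPQ = N*KPQ are formed once and reused, and the part of an index that depends only on outer-loop state (here n * KPQ + k * PQ, analogous to the hoisted ki/curR/curS/curC above) is computed before the inner loop instead of per iteration.

#include <cuda_runtime.h>
#include <cstdio>

__global__ void fill_output(float * out, int PQ, int KPQ) {
    const int n = blockIdx.y;   // batch index
    const int k = blockIdx.x;   // output-channel index

    // loop-invariant part of the [N, K, Oh*Ow] offset, hoisted out of the pq loop
    const int base = n * KPQ + k * PQ;

    for (int pq = threadIdx.x; pq < PQ; pq += blockDim.x) {
        out[base + pq] = (float) (n + k);
    }
}

int main() {
    const int N = 2, K = 3, Oh = 4, Ow = 5;
    const int PQ = Ow * Oh, KPQ = K * PQ, NKPQ = N * KPQ;   // computed once, reused

    float * d_out = nullptr;
    cudaMalloc((void **) &d_out, NKPQ * sizeof(float));
    fill_output<<<dim3(K, N), 128>>>(d_out, PQ, KPQ);
    cudaDeviceSynchronize();
    printf("wrote %d output elements\n", NKPQ);
    cudaFree(d_out);
    return 0;
}

The same reasoning applies to the ki / curR / curS / curC decomposition in the .cuh copy helpers: it depends only on thread_col (and block_k), not on the NUM_ITERS counter, so evaluating it once per thread saves a fastdiv/fastmodulo chain on every iteration.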