diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cu b/ggml/src/ggml-cuda/conv2d-implicit.cu
index 1fceeb9a6e..654d2dffe4 100644
--- a/ggml/src/ggml-cuda/conv2d-implicit.cu
+++ b/ggml/src/ggml-cuda/conv2d-implicit.cu
@@ -805,6 +805,9 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
         }
     }
 
+    const unsigned int A_warp_tile_offset = warp_m * WM * BK;
+    const unsigned int B_warp_tile_offset = warp_n * WN * BK;
+
     static_assert(BM == 256);
     static_assert(BN == 256);
     static_assert(BK == 32);
@@ -825,13 +828,11 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
         __syncthreads();
 
         if (block_k != num_block_tiles_k){
-            const half* A_block_gmem = input;
-            const half* B_block_gmem = kernel + (block_n * BN * weightKOffset);
            tileMemcpyLoadA(A_block_gmem, A_gmem_cache_reg, block_k * BK, start_k, end_k, inChannelOffset, param);
            tileMemcpyLoadB(B_block_gmem, B_gmem_cache_reg, block_k * BK, start_k, end_k, weightKOffset, param);
         }
 
-        half* A_warp_tile = A_block_smem + (warp_m * WM * BK);
-        half* B_warp_tile = B_block_smem + (warp_n * WN * BK);
+        half* A_warp_tile = A_block_smem + A_warp_tile_offset;
+        half* B_warp_tile = B_block_smem + B_warp_tile_offset;
         ldmatrix_a(A_warp_tile, A_register_);
         ldmatrix_b(B_warp_tile, B_register_);
@@ -886,23 +887,25 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
     const uint lane_id = threadIdx.x % WARPSIZE;
     const uint mma_row = lane_id / 4;
     const uint mma_col = lane_id % 4;
-    const uint output_lds_addr = warp_m * WM * BN/2 + lane_id * BN/2 + warp_n * WN/2;
-    const uint output_sts_addr = warp_m * WM * BN/2 + mma_row * BN/2 + warp_n * WN/2 + mma_col * 2;
+    const uint warp_offset = warp_m * WM * BN/2 + warp_n * WN/2;
+    const uint output_lds_addr = warp_offset + lane_id * BN/2;
+    const uint output_sts_addr = warp_offset + mma_row * BN/2 + mma_col * 2;
     const uint m_idx = block_n * BN + warp_n * WN;
     const uint n_idx = block_m * BM + warp_m * WM + lane_id;
 
    #pragma unroll
     for (int i = 0; i < 2; ++i) {
+        const unsigned int i_offset = i * mma_tiles_per_warp_n/2;
         __syncthreads();
 
        #pragma unroll
         for (unsigned int mma_m = 0; mma_m < mma_tiles_per_warp_m; mma_m++)
         {
-            for (unsigned int mma_n = i * mma_tiles_per_warp_n/2; mma_n < (i+1)*mma_tiles_per_warp_n/2; mma_n++)
+            const unsigned int mma_m_offset = output_sts_addr + mma_m * MMA_M * BN / 2;
+            for (unsigned int mma_n = i_offset; mma_n < (i+1)*mma_tiles_per_warp_n/2; mma_n++)
             {
                 uint32_t (&reg_)[2] = reinterpret_cast<uint32_t (&)[2]>(acc_register_[mma_m][mma_n]);
-                uint idx = output_sts_addr +
-                           mma_m * MMA_M * BN / 2 + (mma_n - i * mma_tiles_per_warp_n/2) * MMA_N;
+                uint idx = mma_m_offset + (mma_n - i_offset) * MMA_N;
                 idx = idx ^ ((idx & 0b110000000000) >> 9);
                 idx = idx ^ ((idx & 0b1110000000) >> 4);
                 uint32_t* dst_ptr = reinterpret_cast<uint32_t*>(&smemoutput[idx]);
@@ -913,6 +916,7 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
             }
         }
         __syncthreads();
+        const unsigned int m_i_wn = m_idx + i * WN / 2;
 
        #pragma unroll
         for (int subk = 0; subk < WN / 4; ++subk){
@@ -925,29 +929,15 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
                 const uint gemm_i = n_idx + j*32;
                 const int n = fastdiv(gemm_i, param.OHOW_fastdiv);
                 const int col = fastmodulo(gemm_i, param.OHOW_fastdiv);
-                uint32_t dst_ptr = *(reinterpret_cast<uint32_t*>(&smemoutput[idx+j*32*BN/2]));
+                uint32_t dst_ptr = *(reinterpret_cast<uint32_t*>(&smemoutput[idx+j*16*BN])); // 32*BN/2 = 16*BN
                 half (&res_)[2] = reinterpret_cast<half (&)[2]>(dst_ptr);
                 if (n < param.n && row < param.k && col < PQ) {
-                    if constexpr (ksplit > 0) {
-                        const uint outOffset = z * NKPQ +
-                                               n * KPQ +
-                                               row * PQ + col;
-                        output[outOffset] = ggml_cuda_cast(res_[0]);
-                    } else {
-                        const uint outOffset = n * KPQ + row * PQ + col;
-                        output[outOffset] = ggml_cuda_cast(res_[0]);
-                    }
+                    const uint outOffset = ((ksplit > 0) ? z * NKPQ : 0) + n * KPQ + row * PQ + col;
+                    output[outOffset] = ggml_cuda_cast(res_[0]);
                 }
                 if (n < param.n && row+1 < param.k && col < PQ) {
-                    if constexpr (ksplit > 0) {
-                        const uint outOffset = z * NKPQ +
-                                               n * KPQ +
-                                               (row+1) * PQ + col;
-                        output[outOffset] = ggml_cuda_cast(res_[1]);
-                    } else {
-                        const uint outOffset = n * KPQ + (row+1) * PQ + col;
-                        output[outOffset] = ggml_cuda_cast(res_[1]);
-                    }
+                    const uint outOffset = ((ksplit > 0) ? z * NKPQ : 0) + n * KPQ + (row+1) * PQ + col;
+                    output[outOffset] = ggml_cuda_cast(res_[1]);
                 }
             }
         }
diff --git a/tests/test-conv2d.cpp b/tests/test-conv2d.cpp
index 8ee0747989..90ef1e5237 100644
--- a/tests/test-conv2d.cpp
+++ b/tests/test-conv2d.cpp
@@ -714,15 +714,15 @@ int main(void)
 
     // for(int i = 0; i < ggml_nelements(wino_res); i++) {
     // for(int i = 0; i < 26*38; i++) {
-    // for(int i = 0; i < conv2d_data.size(); i++) {
-    //     float diff = fabs(im2col_data[i] - conv2d_data[i]);
-    //     // if(diff > 0.5) {
-    //     printf("(%7.3f, %7.3f, %.2f, %d) \n",
-    //            im2col_data[i], conv2d_data[i],
-    //            diff, i);
-    //     // break;
-    //     // }
-    // }
+    for(int i = 0; i < conv2d_data.size(); i++) {
+        float diff = fabs(im2col_data[i] - conv2d_data[i]);
+        // if(diff > 0.5) {
+        printf("(%7.3f, %7.3f, %.2f, %d) \n",
+               im2col_data[i], conv2d_data[i],
+               diff, i);
+        // break;
+        // }
+    }
 
     ggml_free(model.ctx);
     ggml_backend_buffer_free(model.buffer);