diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cu b/ggml/src/ggml-cuda/conv2d-implicit.cu index a11d306c6c..06bb4c53f1 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cu +++ b/ggml/src/ggml-cuda/conv2d-implicit.cu @@ -986,7 +986,7 @@ template static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, const half * __restrict__ kernel, - float * __restrict__ output, + half * __restrict__ output, const param_t param) { constexpr unsigned int MMA_M = 16; constexpr unsigned int MMA_N = 8; @@ -1123,6 +1123,51 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input, } } + // reuse smem + half *smemoutput = reinterpret_cast(shmem); + const uint lane_id = threadIdx.x % WARPSIZE; + const uint mma_row = lane_id / 4; + const uint mma_col = lane_id % 4; + const uint output_lds_addr = warp_id * WSUBM * WSUBN + lane_id; + const uint output_sts_addr = warp_m * WM * BN/2 + mma_row * BN/2 + warp_n * WN/2 + mma_col * 2; + const uint m_idx = by * BN + mma_tid_y * WN; + const uint n_idx = block_m * BM + warp_m * WM; + +#pragma unroll + for (int i = 0; i < 2; ++i) + { + for (unsigned int mma_m = 0; mma_m < mma_tiles_per_warp_m; mma_m++) + { + for (unsigned int mma_n = i * mma_tiles_per_warp_n/2; mma_n < (i+1)*mma_tiles_per_warp_n/2; mma_n++) + { + // output sts + uint32_t (®_)[2] = reinterpret_cast(acc_register_[mma_m][mma_n]); + uint32_t* dst_ptr = reinterpret_cast(&smemoutput[output_sts_addr + + mma_m * MMA_M * BN / 2 + (mma_n - i * mma_tiles_per_warp_n/2) * MMA_N]); + dst_ptr[0] = reg_[0]; + dst_ptr = reinterpret_cast(&smemoutput[output_sts_addr + + mma_m * MMA_M * BN / 2 + (mma_n - i * mma_tiles_per_warp_n/2) * MMA_N + 8 * BN / 2]); + dst_ptr[0] = reg_[1]; + } + } + __syncthreads(); +#pragma unroll + const uint row = m_idx + j * WSUBN + (lane_id + subk * WARPSIZE) / WSUBM; + const uint gemm_i = n_idx + i * WSUBM + (lane_id + subk * WARPSIZE) % WSUBM; + const int n = fastdiv(gemm_i, param.OHOW_fastdiv); + const int col = fastmodulo(gemm_i, param.OHOW_fastdiv); + if (n < param.n && row < param.k && col < param.Oh * param.Ow){ + // int outOffset = z * param.n * param.k * param.Oh * param.Ow + n * param.k * param.Oh * param.Ow + (m_idx + i * 16 + subk) * param.Oh * param.Ow + (n_idx + j * 32); + // if (n < param.n && (m_idx + i * 16 + subk) < param.k && (n_idx + j * 32) < param.Oh * param.Ow) + // param.interm[outOffset] = smemoutput[output_lds_addr + subk * 32]; + const uint outOffset = ksplit > 0 ? + z * param.n * param.k * param.Oh * param.Ow + n * param.k * param.Oh * param.Ow + + row * param.Oh * param.Ow + col : + z * param.k * param.Oh * param.Ow + row * param.Oh * param.Ow + col; + output[outOffset] = smemoutput[output_lds_addr + subk * WARPSIZE]; + } + } + ////////////// // epilogue // //////////////