WIP: output
This commit is contained in:
parent
66f6d16265
commit
2715341c1d
|
|
@ -986,7 +986,7 @@ template<const int BM, const int BN, const int BK, const int WM, const int WN,
|
|||
const int WK, const int NUM_THREADS>
|
||||
static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
|
||||
const half * __restrict__ kernel,
|
||||
float * __restrict__ output,
|
||||
half * __restrict__ output,
|
||||
const param_t param) {
|
||||
constexpr unsigned int MMA_M = 16;
|
||||
constexpr unsigned int MMA_N = 8;
|
||||
|
|
@ -1123,6 +1123,51 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
|
|||
}
|
||||
}
|
||||
|
||||
// reuse smem
|
||||
half *smemoutput = reinterpret_cast<half *>(shmem);
|
||||
const uint lane_id = threadIdx.x % WARPSIZE;
|
||||
const uint mma_row = lane_id / 4;
|
||||
const uint mma_col = lane_id % 4;
|
||||
const uint output_lds_addr = warp_id * WSUBM * WSUBN + lane_id;
|
||||
const uint output_sts_addr = warp_m * WM * BN/2 + mma_row * BN/2 + warp_n * WN/2 + mma_col * 2;
|
||||
const uint m_idx = by * BN + mma_tid_y * WN;
|
||||
const uint n_idx = block_m * BM + warp_m * WM;
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 2; ++i)
|
||||
{
|
||||
for (unsigned int mma_m = 0; mma_m < mma_tiles_per_warp_m; mma_m++)
|
||||
{
|
||||
for (unsigned int mma_n = i * mma_tiles_per_warp_n/2; mma_n < (i+1)*mma_tiles_per_warp_n/2; mma_n++)
|
||||
{
|
||||
// output sts
|
||||
uint32_t (®_)[2] = reinterpret_cast<uint32_t(&)[2]>(acc_register_[mma_m][mma_n]);
|
||||
uint32_t* dst_ptr = reinterpret_cast<uint32_t*>(&smemoutput[output_sts_addr +
|
||||
mma_m * MMA_M * BN / 2 + (mma_n - i * mma_tiles_per_warp_n/2) * MMA_N]);
|
||||
dst_ptr[0] = reg_[0];
|
||||
dst_ptr = reinterpret_cast<uint32_t*>(&smemoutput[output_sts_addr +
|
||||
mma_m * MMA_M * BN / 2 + (mma_n - i * mma_tiles_per_warp_n/2) * MMA_N + 8 * BN / 2]);
|
||||
dst_ptr[0] = reg_[1];
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
#pragma unroll
|
||||
const uint row = m_idx + j * WSUBN + (lane_id + subk * WARPSIZE) / WSUBM;
|
||||
const uint gemm_i = n_idx + i * WSUBM + (lane_id + subk * WARPSIZE) % WSUBM;
|
||||
const int n = fastdiv(gemm_i, param.OHOW_fastdiv);
|
||||
const int col = fastmodulo(gemm_i, param.OHOW_fastdiv);
|
||||
if (n < param.n && row < param.k && col < param.Oh * param.Ow){
|
||||
// int outOffset = z * param.n * param.k * param.Oh * param.Ow + n * param.k * param.Oh * param.Ow + (m_idx + i * 16 + subk) * param.Oh * param.Ow + (n_idx + j * 32);
|
||||
// if (n < param.n && (m_idx + i * 16 + subk) < param.k && (n_idx + j * 32) < param.Oh * param.Ow)
|
||||
// param.interm[outOffset] = smemoutput[output_lds_addr + subk * 32];
|
||||
const uint outOffset = ksplit > 0 ?
|
||||
z * param.n * param.k * param.Oh * param.Ow + n * param.k * param.Oh * param.Ow +
|
||||
row * param.Oh * param.Ow + col :
|
||||
z * param.k * param.Oh * param.Ow + row * param.Oh * param.Ow + col;
|
||||
output[outOffset] = smemoutput[output_lds_addr + subk * WARPSIZE];
|
||||
}
|
||||
}
|
||||
|
||||
//////////////
|
||||
// epilogue //
|
||||
//////////////
|
||||
|
|
|
|||
Loading…
Reference in New Issue