WIP: output

This commit is contained in:
bssrdf 2025-10-23 21:29:45 -04:00
parent 66f6d16265
commit 2715341c1d
1 changed file with 46 additions and 1 deletion

View File

@@ -986,7 +986,7 @@ template<const int BM, const int BN, const int BK, const int WM, const int WN,
const int WK, const int NUM_THREADS>
static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
const half * __restrict__ kernel,
float * __restrict__ output,
half * __restrict__ output,
const param_t param) {
constexpr unsigned int MMA_M = 16;
constexpr unsigned int MMA_N = 8;
@@ -1123,6 +1123,51 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
}
}
// reuse smem
half *smemoutput = reinterpret_cast<half *>(shmem);
const uint lane_id = threadIdx.x % WARPSIZE;
const uint mma_row = lane_id / 4;
const uint mma_col = lane_id % 4;
const uint output_lds_addr = warp_id * WSUBM * WSUBN + lane_id;
const uint output_sts_addr = warp_m * WM * BN/2 + mma_row * BN/2 + warp_n * WN/2 + mma_col * 2;
const uint m_idx = by * BN + mma_tid_y * WN;
const uint n_idx = block_m * BM + warp_m * WM;
#pragma unroll
for (int i = 0; i < 2; ++i)
{
for (unsigned int mma_m = 0; mma_m < mma_tiles_per_warp_m; mma_m++)
{
for (unsigned int mma_n = i * mma_tiles_per_warp_n/2; mma_n < (i+1)*mma_tiles_per_warp_n/2; mma_n++)
{
// output sts
uint32_t (&reg_)[2] = reinterpret_cast<uint32_t(&)[2]>(acc_register_[mma_m][mma_n]);
uint32_t* dst_ptr = reinterpret_cast<uint32_t*>(&smemoutput[output_sts_addr +
mma_m * MMA_M * BN / 2 + (mma_n - i * mma_tiles_per_warp_n/2) * MMA_N]);
dst_ptr[0] = reg_[0];
dst_ptr = reinterpret_cast<uint32_t*>(&smemoutput[output_sts_addr +
mma_m * MMA_M * BN / 2 + (mma_n - i * mma_tiles_per_warp_n/2) * MMA_N + 8 * BN / 2]);
dst_ptr[0] = reg_[1];
}
}
__syncthreads();
#pragma unroll
const uint row = m_idx + j * WSUBN + (lane_id + subk * WARPSIZE) / WSUBM;
const uint gemm_i = n_idx + i * WSUBM + (lane_id + subk * WARPSIZE) % WSUBM;
const int n = fastdiv(gemm_i, param.OHOW_fastdiv);
const int col = fastmodulo(gemm_i, param.OHOW_fastdiv);
if (n < param.n && row < param.k && col < param.Oh * param.Ow){
// int outOffset = z * param.n * param.k * param.Oh * param.Ow + n * param.k * param.Oh * param.Ow + (m_idx + i * 16 + subk) * param.Oh * param.Ow + (n_idx + j * 32);
// if (n < param.n && (m_idx + i * 16 + subk) < param.k && (n_idx + j * 32) < param.Oh * param.Ow)
// param.interm[outOffset] = smemoutput[output_lds_addr + subk * 32];
const uint outOffset = ksplit > 0 ?
z * param.n * param.k * param.Oh * param.Ow + n * param.k * param.Oh * param.Ow +
row * param.Oh * param.Ow + col :
z * param.k * param.Oh * param.Ow + row * param.Oh * param.Ow + col;
output[outOffset] = smemoutput[output_lds_addr + subk * WARPSIZE];
}
}
//////////////
// epilogue //
//////////////