try to reduce index calculation
This commit is contained in:
parent
d9a48580fc
commit
09e3a5f07d
|
|
@ -720,6 +720,10 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
|
|||
const uint inChannelOffset = param.c * param.w;
|
||||
const uint weightKOffset = K;
|
||||
|
||||
const unsigned int PQ = param.Ow * param.Oh;
|
||||
const unsigned int KPQ = param.k * PQ;
|
||||
const unsigned int NKPQ = param.n * KPQ;
|
||||
|
||||
// loop bounds, constexpr where possible allows for loop unrolling
|
||||
constexpr unsigned int mma_tiles_per_warp_k = 4;
|
||||
constexpr unsigned int mma_tiles_per_warp_m = WM / MMA_M;
|
||||
|
|
@ -845,14 +849,15 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
|
|||
for (int i = 0; i < 2; ++i)
|
||||
{
|
||||
__syncthreads();
|
||||
|
||||
#pragma unroll
|
||||
for (unsigned int mma_m = 0; mma_m < mma_tiles_per_warp_m; mma_m++)
|
||||
{
|
||||
const int output_sts_offset = output_sts_addr + mma_m * MMA_M * BN / 2 - i * mma_tiles_per_warp_n/2 * MMA_N;
|
||||
for (unsigned int mma_n = i * mma_tiles_per_warp_n/2; mma_n < (i+1)*mma_tiles_per_warp_n/2; mma_n++)
|
||||
{
|
||||
uint32_t (&reg_)[2] = reinterpret_cast<uint32_t(&)[2]>(acc_register_[mma_m][mma_n]);
|
||||
uint idx = output_sts_addr +
|
||||
mma_m * MMA_M * BN / 2 + (mma_n - i * mma_tiles_per_warp_n/2) * MMA_N;
|
||||
uint idx = output_sts_offset + mma_n * MMA_N;
|
||||
// mma_m * MMA_M * BN / 2 + (mma_n - i * mma_tiles_per_warp_n/2) * MMA_N;
|
||||
idx = idx ^ ((idx & 0b1110000000) >> 4);
|
||||
uint32_t* dst_ptr = reinterpret_cast<uint32_t*>(&smemoutput[idx]);
|
||||
dst_ptr[0] = reg_[0];
|
||||
|
|
@ -861,24 +866,25 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
|
|||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
const unsigned int m_i_wn = m_idx + i * WN / 2;
|
||||
#pragma unroll
|
||||
for (int subk = 0; subk < WN / 2; ++subk){
|
||||
const uint row = m_i_wn + subk;
|
||||
#pragma unroll
|
||||
for (int j = 0; j < 4; ++j){
|
||||
const uint row = m_idx + subk + i * WN / 2;
|
||||
const uint gemm_i = n_idx + j*32;
|
||||
const int n = fastdiv(gemm_i, param.OHOW_fastdiv);
|
||||
const int col = fastmodulo(gemm_i, param.OHOW_fastdiv);
|
||||
if (n < param.n && row < param.k && col < param.Oh * param.Ow) {
|
||||
if (n < param.n && row < param.k && col < PQ) {
|
||||
uint idx = output_lds_addr + subk + j*32*BN/2;
|
||||
idx = idx ^ ((idx & 0b1110000000) >> 4);
|
||||
if constexpr (ksplit > 0) {
|
||||
const uint outOffset = z * param.n * param.k * param.Oh * param.Ow +
|
||||
n * param.k * param.Oh * param.Ow +
|
||||
row * param.Oh * param.Ow + col;
|
||||
const uint outOffset = z * NKPQ +
|
||||
n * KPQ +
|
||||
row * PQ + col;
|
||||
output[outOffset] = smemoutput[idx];
|
||||
} else {
|
||||
const uint outOffset = n * param.k * param.Oh * param.Ow + row * param.Oh * param.Ow + col;
|
||||
const uint outOffset = n * KPQ + row * PQ + col;
|
||||
output[outOffset] = smemoutput[idx];
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -59,18 +59,20 @@ __device__ __forceinline__ void tileMemcpySwizzleB(
|
|||
constexpr unsigned int NUM_ITERS = TILE_ROWS / ROW_STEP;
|
||||
unsigned int thread_row = thread_idx / TILE_COLS_VECTORIZED;
|
||||
const unsigned int thread_col = thread_idx % TILE_COLS_VECTORIZED;
|
||||
const unsigned int curR = fastdiv(start_k+thread_col*8, param.SC_fastdiv); // channel offset
|
||||
const unsigned int curS = fastdiv(fastmodulo(start_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
|
||||
const unsigned int curC = fastmodulo(fastmodulo(start_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); //
|
||||
|
||||
const unsigned int ki = start_k+thread_col*8;
|
||||
const unsigned int curR = fastdiv(ki, param.SC_fastdiv); // kernel row (r) index
|
||||
const unsigned int curS = fastdiv(fastmodulo(ki, param.SC_fastdiv), param.C_fastdiv); // kernel column (s) index
|
||||
const unsigned int curC = fastmodulo(fastmodulo(ki, param.SC_fastdiv), param.C_fastdiv); //
|
||||
|
||||
#pragma unroll
|
||||
for (unsigned int i = 0; i < NUM_ITERS; i++){
|
||||
// apply swizzle to the dst index
|
||||
const unsigned int src_index = thread_row * src_stride + start_k + thread_col * 8;
|
||||
const unsigned int src_index = thread_row * src_stride + ki;
|
||||
unsigned int dst_index = thread_row * TILE_COLS_VECTORIZED + thread_col;
|
||||
dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_1) >> SWIZZLE_BITS_1);
|
||||
dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_2) >> SWIZZLE_BITS_2);
|
||||
if (thread_row < param.k && curR < param.r && curS < param.s && curC < param.c && start_k+thread_col*8 < end_k){
|
||||
if (thread_row < param.k && curR < param.r && curS < param.s && curC < param.c && ki < end_k){
|
||||
dst_float4[dst_index] = reinterpret_cast<const float4 *>(&src[src_index])[0];
|
||||
}else{ // read 4 halves
|
||||
dst_float4[dst_index] = make_float4(0.f, 0.f, 0.f, 0.f);
|
||||
|
|
@ -122,6 +124,12 @@ __device__ __forceinline__ void tileMemcpySwizzleA(
|
|||
unsigned int thread_row = thread_idx / TILE_COLS_VECTORIZED;
|
||||
const unsigned int thread_col = thread_idx % TILE_COLS_VECTORIZED;
|
||||
|
||||
const unsigned int ki = start_k+thread_col*8;
|
||||
const unsigned int chw = param.c * param.h * param.w;
|
||||
const unsigned int curR = fastdiv(ki, param.SC_fastdiv); // kernel row (r) index
|
||||
const unsigned int curS = fastdiv(fastmodulo(ki, param.SC_fastdiv), param.C_fastdiv); // kernel column (s) index
|
||||
const unsigned int curC = fastmodulo(fastmodulo(ki, param.SC_fastdiv), param.C_fastdiv); // input channel (c) index
|
||||
|
||||
|
||||
#pragma unroll
|
||||
for (unsigned int i = 0; i < NUM_ITERS; i++){
|
||||
|
|
@ -130,10 +138,7 @@ __device__ __forceinline__ void tileMemcpySwizzleA(
|
|||
unsigned int npq_res = fastmodulo(gemm_i, param.OHOW_fastdiv);
|
||||
int posh_ori = fastdiv(npq_res, param.OW_fastdiv) * param.u - param.p;
|
||||
int posw_ori = fastmodulo(npq_res, param.OW_fastdiv) * param.v - param.q;
|
||||
unsigned int inOffset = n * param.c * param.h * param.w;
|
||||
const unsigned int curR = fastdiv(start_k+thread_col*8, param.SC_fastdiv); // channel offset
|
||||
const unsigned int curS = fastdiv(fastmodulo(start_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
|
||||
const unsigned int curC = fastmodulo(fastmodulo(start_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
|
||||
// unsigned int inOffset = n * param.c * param.h * param.w;
|
||||
int curH = posh_ori + curR * param.d_h; // input h
|
||||
int curW = posw_ori + curS * param.d_w; // input w
|
||||
// apply swizzle to the dst index
|
||||
|
|
@ -141,9 +146,9 @@ __device__ __forceinline__ void tileMemcpySwizzleA(
|
|||
dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_1) >> SWIZZLE_BITS_1);
|
||||
dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_2) >> SWIZZLE_BITS_2);
|
||||
if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h &&
|
||||
curR < param.r && curS < param.s && curC < param.c && start_k+thread_col*8 < end_k){
|
||||
curR < param.r && curS < param.s && curC < param.c && ki < end_k){
|
||||
const unsigned int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC;
|
||||
dst_float4[dst_index] = reinterpret_cast<const float4 *>(&src[inOffset + inOffsetTmp])[0];
|
||||
dst_float4[dst_index] = reinterpret_cast<const float4 *>(&src[n * chw + inOffsetTmp])[0];
|
||||
} else{
|
||||
dst_float4[dst_index] = make_float4(0.f, 0.f, 0.f, 0.f);
|
||||
}
|
||||
|
|
@ -191,6 +196,13 @@ __device__ __forceinline__ void tileMemcpyLoadA(
|
|||
// compile time check that we provided the right amount of registers for storage
|
||||
static_assert(ELEMENTS_PER_THREAD == NUM_ITERS);
|
||||
|
||||
const unsigned int ki = start_k+block_k+thread_col*8;
|
||||
const unsigned int chw = param.c * param.h * param.w;
|
||||
|
||||
const unsigned int curR = fastdiv(ki, param.SC_fastdiv); // kernel row (r) index
|
||||
const unsigned int curS = fastdiv(fastmodulo(ki, param.SC_fastdiv), param.C_fastdiv); // kernel column (s) index
|
||||
const unsigned int curC = fastmodulo(fastmodulo(ki, param.SC_fastdiv), param.C_fastdiv); // input channel (c) index
|
||||
|
||||
#pragma unroll
|
||||
for (unsigned int i = 0; i < NUM_ITERS; i++){
|
||||
unsigned int gemm_i = blockIdx.y * TILE_ROWS + thread_row;
|
||||
|
|
@ -198,16 +210,13 @@ __device__ __forceinline__ void tileMemcpyLoadA(
|
|||
unsigned int npq_res = fastmodulo(gemm_i, param.OHOW_fastdiv);
|
||||
int posh_ori = fastdiv(npq_res, param.OW_fastdiv) * param.u - param.p;
|
||||
int posw_ori = fastmodulo(npq_res, param.OW_fastdiv) * param.v - param.q;
|
||||
unsigned int inOffset = n * param.c * param.h * param.w;
|
||||
const unsigned int curR = fastdiv(start_k+block_k+thread_col*8, param.SC_fastdiv); // channel offset
|
||||
const unsigned int curS = fastdiv(fastmodulo(start_k+block_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
|
||||
const unsigned int curC = fastmodulo(fastmodulo(start_k+block_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
|
||||
// unsigned int inOffset = n * param.c * param.h * param.w;
|
||||
int curH = posh_ori + curR * param.d_h; // input h
|
||||
int curW = posw_ori + curS * param.d_w; // input w
|
||||
if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h &&
|
||||
curR < param.r && curS < param.s && curC < param.c && start_k+block_k+thread_col*8 < end_k){
|
||||
curR < param.r && curS < param.s && curC < param.c && ki < end_k){
|
||||
const unsigned int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC;
|
||||
dst_reg[i] = reinterpret_cast<const float4 *>(&src[inOffset + inOffsetTmp])[0];
|
||||
dst_reg[i] = reinterpret_cast<const float4 *>(&src[n * chw + inOffsetTmp])[0];
|
||||
} else{
|
||||
dst_reg[i] = make_float4(0.f, 0.f, 0.f, 0.f);
|
||||
}
|
||||
|
|
@ -256,14 +265,15 @@ __device__ __forceinline__ void tileMemcpyLoadB(
|
|||
// compile time check that we provided the right amount of registers for storage
|
||||
static_assert(ELEMENTS_PER_THREAD == NUM_ITERS);
|
||||
|
||||
const unsigned int curR = fastdiv(start_k+block_k+thread_col*8, param.SC_fastdiv); // channel offset
|
||||
const unsigned int curS = fastdiv(fastmodulo(start_k+block_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
|
||||
const unsigned int curC = fastmodulo(fastmodulo(start_k+block_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); //
|
||||
const unsigned int ki = start_k+block_k+thread_col*8;
|
||||
const unsigned int curR = fastdiv(ki, param.SC_fastdiv); // kernel row (r) index
|
||||
const unsigned int curS = fastdiv(fastmodulo(ki, param.SC_fastdiv), param.C_fastdiv); // kernel column (s) index
|
||||
const unsigned int curC = fastmodulo(fastmodulo(ki, param.SC_fastdiv), param.C_fastdiv); //
|
||||
|
||||
#pragma unroll
|
||||
for (unsigned int i = 0; i < NUM_ITERS; i++){
|
||||
const unsigned int src_index = thread_row * src_stride + start_k + block_k + thread_col * 8;
|
||||
if (thread_row < param.k && curR < param.r && curS < param.s && curC < param.c && start_k+block_k+thread_col*8 < end_k){
|
||||
const unsigned int src_index = thread_row * src_stride + ki;
|
||||
if (thread_row < param.k && curR < param.r && curS < param.s && curC < param.c && ki < end_k){
|
||||
dst_reg[i] = reinterpret_cast<const float4 *>(&src[src_index])[0];
|
||||
}else{ // read 4 halves
|
||||
dst_reg[i] = make_float4(0.f, 0.f, 0.f, 0.f);
|
||||
|
|
|
|||
|
|
@ -384,8 +384,7 @@ int main(void)
|
|||
|
||||
// for(int i = 0; i < ggml_nelements(wino_res); i++) {
|
||||
// for(int i = 0; i < 26*38; i++) {
|
||||
// for(int i = 26*38; i < 2*26*38; i++) {
|
||||
// for(int i = 0; i < conv2d_data.size(); i++) {
|
||||
// // for(int i = 0; i < conv2d_data.size(); i++) {
|
||||
// float diff = fabs(im2col_data[i] - conv2d_data[i]);
|
||||
// // if(diff > 0.5) {
|
||||
// printf("(%7.3f, %7.3f, %.2f, %d) \n",
|
||||
|
|
|
|||
Loading…
Reference in New Issue