try to reduce index calculation

This commit is contained in:
bssrdf 2025-11-05 22:02:57 -05:00
parent d9a48580fc
commit 09e3a5f07d
3 changed files with 49 additions and 34 deletions

View File

@ -720,6 +720,10 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
const uint inChannelOffset = param.c * param.w;
const uint weightKOffset = K;
const unsigned int PQ = param.Ow * param.Oh;
const unsigned int KPQ = param.k * PQ;
const unsigned int NKPQ = param.n * KPQ;
// loop bounds, constexpr where possible allows for loop unrolling
constexpr unsigned int mma_tiles_per_warp_k = 4;
constexpr unsigned int mma_tiles_per_warp_m = WM / MMA_M;
@ -845,14 +849,15 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
for (int i = 0; i < 2; ++i)
{
__syncthreads();
#pragma unroll
for (unsigned int mma_m = 0; mma_m < mma_tiles_per_warp_m; mma_m++)
{
const int output_sts_offset = output_sts_addr + mma_m * MMA_M * BN / 2 - i * mma_tiles_per_warp_n/2 * MMA_N;
for (unsigned int mma_n = i * mma_tiles_per_warp_n/2; mma_n < (i+1)*mma_tiles_per_warp_n/2; mma_n++)
{
uint32_t (&reg_)[2] = reinterpret_cast<uint32_t(&)[2]>(acc_register_[mma_m][mma_n]);
uint idx = output_sts_addr +
mma_m * MMA_M * BN / 2 + (mma_n - i * mma_tiles_per_warp_n/2) * MMA_N;
uint idx = output_sts_offset + mma_n * MMA_N;
// mma_m * MMA_M * BN / 2 + (mma_n - i * mma_tiles_per_warp_n/2) * MMA_N;
idx = idx ^ ((idx & 0b1110000000) >> 4);
uint32_t* dst_ptr = reinterpret_cast<uint32_t*>(&smemoutput[idx]);
dst_ptr[0] = reg_[0];
@ -861,24 +866,25 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
}
}
__syncthreads();
const unsigned int m_i_wn = m_idx + i * WN / 2;
#pragma unroll
for (int subk = 0; subk < WN / 2; ++subk){
const uint row = m_i_wn + subk;
#pragma unroll
for (int j = 0; j < 4; ++j){
const uint row = m_idx + subk + i * WN / 2;
const uint gemm_i = n_idx + j*32;
const int n = fastdiv(gemm_i, param.OHOW_fastdiv);
const int col = fastmodulo(gemm_i, param.OHOW_fastdiv);
if (n < param.n && row < param.k && col < param.Oh * param.Ow) {
if (n < param.n && row < param.k && col < PQ) {
uint idx = output_lds_addr + subk + j*32*BN/2;
idx = idx ^ ((idx & 0b1110000000) >> 4);
if constexpr (ksplit > 0) {
const uint outOffset = z * param.n * param.k * param.Oh * param.Ow +
n * param.k * param.Oh * param.Ow +
row * param.Oh * param.Ow + col;
const uint outOffset = z * NKPQ +
n * KPQ +
row * PQ + col;
output[outOffset] = smemoutput[idx];
} else {
const uint outOffset = n * param.k * param.Oh * param.Ow + row * param.Oh * param.Ow + col;
const uint outOffset = n * KPQ + row * PQ + col;
output[outOffset] = smemoutput[idx];
}
}

View File

@ -59,18 +59,20 @@ __device__ __forceinline__ void tileMemcpySwizzleB(
constexpr unsigned int NUM_ITERS = TILE_ROWS / ROW_STEP;
unsigned int thread_row = thread_idx / TILE_COLS_VECTORIZED;
const unsigned int thread_col = thread_idx % TILE_COLS_VECTORIZED;
const unsigned int curR = fastdiv(start_k+thread_col*8, param.SC_fastdiv); // channel offset
const unsigned int curS = fastdiv(fastmodulo(start_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
const unsigned int curC = fastmodulo(fastmodulo(start_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); //
const unsigned int ki = start_k+thread_col*8;
const unsigned int curR = fastdiv(ki, param.SC_fastdiv); // channel offset
const unsigned int curS = fastdiv(fastmodulo(ki, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
const unsigned int curC = fastmodulo(fastmodulo(ki, param.SC_fastdiv), param.C_fastdiv); //
#pragma unroll
for (unsigned int i = 0; i < NUM_ITERS; i++){
// apply swizzle to the dst index
const unsigned int src_index = thread_row * src_stride + start_k + thread_col * 8;
const unsigned int src_index = thread_row * src_stride + ki;
unsigned int dst_index = thread_row * TILE_COLS_VECTORIZED + thread_col;
dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_1) >> SWIZZLE_BITS_1);
dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_2) >> SWIZZLE_BITS_2);
if (thread_row < param.k && curR < param.r && curS < param.s && curC < param.c && start_k+thread_col*8 < end_k){
if (thread_row < param.k && curR < param.r && curS < param.s && curC < param.c && ki < end_k){
dst_float4[dst_index] = reinterpret_cast<const float4 *>(&src[src_index])[0];
}else{ // read 4 halves
dst_float4[dst_index] = make_float4(0.f, 0.f, 0.f, 0.f);
@ -122,6 +124,12 @@ __device__ __forceinline__ void tileMemcpySwizzleA(
unsigned int thread_row = thread_idx / TILE_COLS_VECTORIZED;
const unsigned int thread_col = thread_idx % TILE_COLS_VECTORIZED;
const unsigned int ki = start_k+thread_col*8;
const unsigned int chw = param.c * param.h * param.w;
const unsigned int curR = fastdiv(ki, param.SC_fastdiv); // channel offset
const unsigned int curS = fastdiv(fastmodulo(ki, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
const unsigned int curC = fastmodulo(fastmodulo(ki, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
#pragma unroll
for (unsigned int i = 0; i < NUM_ITERS; i++){
@ -130,10 +138,7 @@ __device__ __forceinline__ void tileMemcpySwizzleA(
unsigned int npq_res = fastmodulo(gemm_i, param.OHOW_fastdiv);
int posh_ori = fastdiv(npq_res, param.OW_fastdiv) * param.u - param.p;
int posw_ori = fastmodulo(npq_res, param.OW_fastdiv) * param.v - param.q;
unsigned int inOffset = n * param.c * param.h * param.w;
const unsigned int curR = fastdiv(start_k+thread_col*8, param.SC_fastdiv); // channel offset
const unsigned int curS = fastdiv(fastmodulo(start_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
const unsigned int curC = fastmodulo(fastmodulo(start_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
// unsigned int inOffset = n * param.c * param.h * param.w;
int curH = posh_ori + curR * param.d_h; // input h
int curW = posw_ori + curS * param.d_w; // input w
// apply swizzle to the dst index
@ -141,9 +146,9 @@ __device__ __forceinline__ void tileMemcpySwizzleA(
dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_1) >> SWIZZLE_BITS_1);
dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_2) >> SWIZZLE_BITS_2);
if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h &&
curR < param.r && curS < param.s && curC < param.c && start_k+thread_col*8 < end_k){
curR < param.r && curS < param.s && curC < param.c && ki < end_k){
const unsigned int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC;
dst_float4[dst_index] = reinterpret_cast<const float4 *>(&src[inOffset + inOffsetTmp])[0];
dst_float4[dst_index] = reinterpret_cast<const float4 *>(&src[n * chw + inOffsetTmp])[0];
} else{
dst_float4[dst_index] = make_float4(0.f, 0.f, 0.f, 0.f);
}
@ -191,6 +196,13 @@ __device__ __forceinline__ void tileMemcpyLoadA(
// compile time check that we provided the right amount of registers for storage
static_assert(ELEMENTS_PER_THREAD == NUM_ITERS);
const unsigned int ki = start_k+block_k+thread_col*8;
const unsigned int chw = param.c * param.h * param.w;
const unsigned int curR = fastdiv(ki, param.SC_fastdiv); // channel offset
const unsigned int curS = fastdiv(fastmodulo(ki, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
const unsigned int curC = fastmodulo(fastmodulo(ki, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
#pragma unroll
for (unsigned int i = 0; i < NUM_ITERS; i++){
unsigned int gemm_i = blockIdx.y * TILE_ROWS + thread_row;
@ -198,16 +210,13 @@ __device__ __forceinline__ void tileMemcpyLoadA(
unsigned int npq_res = fastmodulo(gemm_i, param.OHOW_fastdiv);
int posh_ori = fastdiv(npq_res, param.OW_fastdiv) * param.u - param.p;
int posw_ori = fastmodulo(npq_res, param.OW_fastdiv) * param.v - param.q;
unsigned int inOffset = n * param.c * param.h * param.w;
const unsigned int curR = fastdiv(start_k+block_k+thread_col*8, param.SC_fastdiv); // channel offset
const unsigned int curS = fastdiv(fastmodulo(start_k+block_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
const unsigned int curC = fastmodulo(fastmodulo(start_k+block_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
// unsigned int inOffset = n * param.c * param.h * param.w;
int curH = posh_ori + curR * param.d_h; // input h
int curW = posw_ori + curS * param.d_w; // input w
if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h &&
curR < param.r && curS < param.s && curC < param.c && start_k+block_k+thread_col*8 < end_k){
curR < param.r && curS < param.s && curC < param.c && ki < end_k){
const unsigned int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC;
dst_reg[i] = reinterpret_cast<const float4 *>(&src[inOffset + inOffsetTmp])[0];
dst_reg[i] = reinterpret_cast<const float4 *>(&src[n * chw + inOffsetTmp])[0];
} else{
dst_reg[i] = make_float4(0.f, 0.f, 0.f, 0.f);
}
@ -256,14 +265,15 @@ __device__ __forceinline__ void tileMemcpyLoadB(
// compile time check that we provided the right amount of registers for storage
static_assert(ELEMENTS_PER_THREAD == NUM_ITERS);
const unsigned int curR = fastdiv(start_k+block_k+thread_col*8, param.SC_fastdiv); // channel offset
const unsigned int curS = fastdiv(fastmodulo(start_k+block_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
const unsigned int curC = fastmodulo(fastmodulo(start_k+block_k+thread_col*8, param.SC_fastdiv), param.C_fastdiv); //
const unsigned int ki = start_k+block_k+thread_col*8;
const unsigned int curR = fastdiv(ki, param.SC_fastdiv); // channel offset
const unsigned int curS = fastdiv(fastmodulo(ki, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
const unsigned int curC = fastmodulo(fastmodulo(ki, param.SC_fastdiv), param.C_fastdiv); //
#pragma unroll
for (unsigned int i = 0; i < NUM_ITERS; i++){
const unsigned int src_index = thread_row * src_stride + start_k + block_k + thread_col * 8;
if (thread_row < param.k && curR < param.r && curS < param.s && curC < param.c && start_k+block_k+thread_col*8 < end_k){
const unsigned int src_index = thread_row * src_stride + ki;
if (thread_row < param.k && curR < param.r && curS < param.s && curC < param.c && ki < end_k){
dst_reg[i] = reinterpret_cast<const float4 *>(&src[src_index])[0];
}else{ // read 4 halves
dst_reg[i] = make_float4(0.f, 0.f, 0.f, 0.f);

View File

@ -384,8 +384,7 @@ int main(void)
// for(int i = 0; i < ggml_nelements(wino_res); i++) {
// for(int i = 0; i < 26*38; i++) {
// for(int i = 26*38; i < 2*26*38; i++) {
// for(int i = 0; i < conv2d_data.size(); i++) {
// // for(int i = 0; i < conv2d_data.size(); i++) {
// float diff = fabs(im2col_data[i] - conv2d_data[i]);
// // if(diff > 0.5) {
// printf("(%7.3f, %7.3f, %.2f, %d) \n",