This commit is contained in:
bssrdf 2025-10-21 17:15:26 -04:00
parent f931ad883f
commit 1b69ed44c6
1 changed file with 12 additions and 12 deletions

View File

@@ -805,8 +805,8 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
// prefetch the first block tile of A,B into shared memory
half* A_block_gmem = input + (block_m * BM * A_stride);
half* B_block_gmem = weight + (block_n * BN);
tileMemcpySwizzleA<BM_dim, NUM_THREADS>(A_block_gmem, A_block_smem, K);
tileMemcpySwizzle<BK_dim, BN_dim, NUM_THREADS, SWIZZLE_BITS_B>(B_block_gmem, B_block_smem, N);
tileMemcpySwizzleA<BM, NUM_THREADS>(A_block_gmem, A_block_smem, K);
tileMemcpySwizzle<BK, BN, NUM_THREADS, SWIZZLE_BITS_B>(B_block_gmem, B_block_smem, N);
// construct const pointers to warp tiles for use inside the inner loop
@@ -819,16 +819,16 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
if (block_k != num_block_tiles_k)
{
half* A_block_gmem = A + (block_m * BM_dim * A_stride) + (block_k * BK_dim);
half* B_block_gmem = B + (block_k * BK_dim * B_stride) + (block_n * BN_dim);
tileMemcpyLoad<BM_dim, BK_dim, NUM_THREADS, 4>(A_block_gmem, A_gmem_cache_reg, K);
tileMemcpyLoad<BK_dim, BN_dim, NUM_THREADS, 4>(B_block_gmem, B_gmem_cache_reg, N);
half* A_block_gmem = A + (block_m * BM * A_stride) + (block_k * BK);
half* B_block_gmem = B + (block_k * BK * B_stride) + (block_n * BN);
tileMemcpyLoad<BM, BK, NUM_THREADS, 4>(A_block_gmem, A_gmem_cache_reg, K);
tileMemcpyLoad<BK, BN, NUM_THREADS, 4>(B_block_gmem, B_gmem_cache_reg, N);
}
half* A_warp_tile = A_block_smem + (warp_m * WM_dim * BK_dim);
half* B_warp_tile = B_block_smem + (warp_n * WN_dim);
half* A_warp_tile = A_block_smem + (warp_m * WM * BK);
half* B_warp_tile = B_block_smem + (warp_n * WN);
ldmatrix_a<mma_tiles_per_warp_m, mma_tiles_per_warp_k, BK_dim>(A_warp_tile, A_register_);
ldmatrix_b<mma_tiles_per_warp_k, mma_tiles_per_warp_n, BN_dim>(B_warp_tile, B_register_);
ldmatrix_a<mma_tiles_per_warp_m, mma_tiles_per_warp_k, BK>(A_warp_tile, A_register_);
ldmatrix_b<mma_tiles_per_warp_k, mma_tiles_per_warp_n, BN>(B_warp_tile, B_register_);
// outer product between mma tiles
#pragma unroll
@@ -863,8 +863,8 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
B_block_smem = B_block_smem + BUFFER_SIZE * offset_direction;
offset_direction = -1 * offset_direction;
tileMemcpySwizzleStoreA<BM_dim, NUM_THREADS, 4>(A_gmem_cache_reg, A_block_smem);
tileMemcpySwizzleStore<BK_dim, BN_dim, NUM_THREADS, SWIZZLE_BITS_B, 4>(B_gmem_cache_reg, B_block_smem);
tileMemcpySwizzleStoreA<BM, NUM_THREADS, 4>(A_gmem_cache_reg, A_block_smem);
tileMemcpySwizzleStore<BK, BN, NUM_THREADS, SWIZZLE_BITS_B, 4>(B_gmem_cache_reg, B_block_smem);
}
}