WIP
This commit is contained in:
parent
f931ad883f
commit
1b69ed44c6
|
|
@ -805,8 +805,8 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
|
|||
// prefetch the first block tile of A,B into shared memory
|
||||
half* A_block_gmem = input + (block_m * BM * A_stride);
|
||||
half* B_block_gmem = weight + (block_n * BN);
|
||||
tileMemcpySwizzleA<BM_dim, NUM_THREADS>(A_block_gmem, A_block_smem, K);
|
||||
tileMemcpySwizzle<BK_dim, BN_dim, NUM_THREADS, SWIZZLE_BITS_B>(B_block_gmem, B_block_smem, N);
|
||||
tileMemcpySwizzleA<BM, NUM_THREADS>(A_block_gmem, A_block_smem, K);
|
||||
tileMemcpySwizzle<BK, BN, NUM_THREADS, SWIZZLE_BITS_B>(B_block_gmem, B_block_smem, N);
|
||||
|
||||
// construct const pointers to warp tiles for use inside the inner loop
|
||||
|
||||
|
|
@ -819,16 +819,16 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
|
|||
|
||||
if (block_k != num_block_tiles_k)
|
||||
{
|
||||
half* A_block_gmem = A + (block_m * BM_dim * A_stride) + (block_k * BK_dim);
|
||||
half* B_block_gmem = B + (block_k * BK_dim * B_stride) + (block_n * BN_dim);
|
||||
tileMemcpyLoad<BM_dim, BK_dim, NUM_THREADS, 4>(A_block_gmem, A_gmem_cache_reg, K);
|
||||
tileMemcpyLoad<BK_dim, BN_dim, NUM_THREADS, 4>(B_block_gmem, B_gmem_cache_reg, N);
|
||||
half* A_block_gmem = A + (block_m * BM * A_stride) + (block_k * BK);
|
||||
half* B_block_gmem = B + (block_k * BK * B_stride) + (block_n * BN);
|
||||
tileMemcpyLoad<BM, BK, NUM_THREADS, 4>(A_block_gmem, A_gmem_cache_reg, K);
|
||||
tileMemcpyLoad<BK, BN, NUM_THREADS, 4>(B_block_gmem, B_gmem_cache_reg, N);
|
||||
}
|
||||
half* A_warp_tile = A_block_smem + (warp_m * WM_dim * BK_dim);
|
||||
half* B_warp_tile = B_block_smem + (warp_n * WN_dim);
|
||||
half* A_warp_tile = A_block_smem + (warp_m * WM * BK);
|
||||
half* B_warp_tile = B_block_smem + (warp_n * WN);
|
||||
|
||||
ldmatrix_a<mma_tiles_per_warp_m, mma_tiles_per_warp_k, BK_dim>(A_warp_tile, A_register_);
|
||||
ldmatrix_b<mma_tiles_per_warp_k, mma_tiles_per_warp_n, BN_dim>(B_warp_tile, B_register_);
|
||||
ldmatrix_a<mma_tiles_per_warp_m, mma_tiles_per_warp_k, BK>(A_warp_tile, A_register_);
|
||||
ldmatrix_b<mma_tiles_per_warp_k, mma_tiles_per_warp_n, BN>(B_warp_tile, B_register_);
|
||||
|
||||
// outer product between mma tiles
|
||||
#pragma unroll
|
||||
|
|
@ -863,8 +863,8 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
|
|||
B_block_smem = B_block_smem + BUFFER_SIZE * offset_direction;
|
||||
offset_direction = -1 * offset_direction;
|
||||
|
||||
tileMemcpySwizzleStoreA<BM_dim, NUM_THREADS, 4>(A_gmem_cache_reg, A_block_smem);
|
||||
tileMemcpySwizzleStore<BK_dim, BN_dim, NUM_THREADS, SWIZZLE_BITS_B, 4>(B_gmem_cache_reg, B_block_smem);
|
||||
tileMemcpySwizzleStoreA<BM, NUM_THREADS, 4>(A_gmem_cache_reg, A_block_smem);
|
||||
tileMemcpySwizzleStore<BK, BN, NUM_THREADS, SWIZZLE_BITS_B, 4>(B_gmem_cache_reg, B_block_smem);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue