change mac loop to match cutlass
This commit is contained in:
parent 9f498d29f1
commit 0939511846
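In outline: the MAC (main) loop previously walked a flattened K = C·R·S GEMM dimension and recovered (r, s, c) per element with fastdiv/fastmodulo; after this commit the filter taps (r, s) form an outer loop, GEMM-K covers only the C channels, and out-of-bounds predicates are precomputed once as per-row bitmasks, following the scheme of CUTLASS's conv2d forward-prop tile iterators. A minimal standalone sketch of the mask idea (all names and values below are hypothetical, not taken from this diff):

#include <cstdio>

// Hypothetical illustration: build CUTLASS-style predicate masks for a 3x3
// filter over a padded input, then consume them in the (r, s)-outer /
// channel-inner traversal this commit adopts.
int main() {
    const int R = 3, S = 3;              // filter height/width
    const int H = 5, W = 5;              // input height/width
    const int pad = 1, out_y = 0, out_x = 0;

    unsigned int mask_r = 0, mask_s = 0;
    for (int r = 0; r < R; ++r)          // one validity bit per filter row
        if (out_y - pad + r >= 0 && out_y - pad + r < H) mask_r |= 1u << r;
    for (int s = 0; s < S; ++s)          // one validity bit per filter column
        if (out_x - pad + s >= 0 && out_x - pad + s < W) mask_s |= 1u << s;

    for (int r = 0; r < R; ++r)          // outer tap loop; channels form the inner GEMM-K loop
        for (int s = 0; s < S; ++s) {
            bool valid = (mask_r & (1u << r)) && (mask_s & (1u << s));
            printf("tap (%d,%d): %s\n", r, s, valid ? "load" : "zero-fill");
        }
    return 0;
}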
@@ -781,9 +781,9 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
     constexpr unsigned int MMA_M = 16;
     constexpr unsigned int MMA_N = 8;
 
-    const unsigned int K = param.c * param.r * param.s;
+    const unsigned int K = param.c;
     const uint inChannelOffset = param.c * param.w;
-    const uint weightKOffset = K;
+    const uint weightKOffset = param.c * param.r * param.s;
 
     const unsigned int PQ = param.Ow * param.Oh;
     const unsigned int KPQ = param.k * PQ;
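With the taps hoisted out of the main loop, K (the per-tap GEMM depth) shrinks to param.c, while weightKOffset keeps the full C·R·S extent because it is still the stride between rows of the flattened weight matrix. A quick sanity check of that index math, with shapes assumed purely for illustration:

#include <cassert>

// The per-(r, s) slice of length C starts at offset (r*S + s)*C within a
// weight row of weightKOffset = C*R*S elements, matching
// "ki = (curR*param.s+curS)*param.c + curC" used later in this diff.
int main() {
    const int C = 64, R = 3, S = 3;      // assumed shapes
    const int weightKOffset = C * R * S;
    for (int r = 0; r < R; ++r)
        for (int s = 0; s < S; ++s)
            for (int c = 0; c < C; ++c)
                assert((r * S + s) * C + c < weightKOffset);
    return 0;
}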
@@ -799,18 +799,25 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
     constexpr unsigned int mma_tiles_per_warp_n = WN / MMA_N;
     const unsigned int z = blockIdx.z;
 
-    const unsigned int ks = (ksplit > 0) ? (weightKOffset + ksplit - 1) / ksplit : weightKOffset;
+    const unsigned int ks = (ksplit > 0) ? (K + ksplit - 1) / ksplit : K;
     const unsigned int start_k = (ksplit > 0) ? z * ks : 0;
-    const unsigned int end_k = min(start_k + ks, weightKOffset);
+    const unsigned int end_k = min(start_k + ks, K);
     const unsigned int num_block_tiles_k = (ks + (BK-1)) / BK;
 
+    constexpr unsigned int TILE_COLS_VECTORIZED = BK / 8;
+    constexpr unsigned int ROW_STEP = NUM_THREADS / TILE_COLS_VECTORIZED;
+    constexpr unsigned int A_K_STRID = BM / ROW_STEP;
+    constexpr unsigned int B_K_STRID = BN / ROW_STEP;
+
+    unsigned int masks_a[A_K_STRID][2];
+    unsigned int element_offset_a[A_K_STRID];
+
     // calculate block/warp indices
     const unsigned int block_m = blockIdx.y;
     const unsigned int block_n = blockIdx.x;
     const unsigned int warp_m = threadIdx.y;
     const unsigned int warp_n = threadIdx.x / 32;
     const unsigned int thread_idx = threadIdx.y * blockDim.x + threadIdx.x;
 
     // double buffering
     extern __shared__ half shmem[];
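The new constants size the per-thread predicate storage: each thread covers A_K_STRID rows of the A tile across its copy iterations, so it needs that many mask pairs and precomputed base offsets. For illustration, with assumed tile parameters (the kernel's real BM/BN/BK/NUM_THREADS come from its launch configuration):

// Assumed values, for illustration only.
constexpr unsigned int BM = 128, BN = 128, BK = 64, NUM_THREADS = 256;
constexpr unsigned int TILE_COLS_VECTORIZED = BK / 8;                  // 8 float4 (8-half) loads per tile row
constexpr unsigned int ROW_STEP = NUM_THREADS / TILE_COLS_VECTORIZED;  // 32 tile rows covered per iteration
constexpr unsigned int A_K_STRID = BM / ROW_STEP;                      // 4 iterations -> masks_a[4][2]
constexpr unsigned int B_K_STRID = BN / ROW_STEP;                      // 4 iterations over the B tile
static_assert(A_K_STRID == 4 && B_K_STRID == 4, "illustrative sizes");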
@@ -858,12 +865,21 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
     float4 A_gmem_cache_reg[4];
     float4 B_gmem_cache_reg[4];
 
+    prepareIteratorA<BM, BK, A_K_STRID, ROW_STEP>(thread_idx, masks_a, element_offset_a, param);
+
     // prefetch the first block tile of A,B into shared memory
     const half* A_block_gmem = input;
     const half* B_block_gmem = kernel + block_n * BN * weightKOffset;
-    tileMemcpySwizzleA<BM, NUM_THREADS>(A_block_gmem, A_block_smem, start_k, end_k, inChannelOffset, param);
-    tileMemcpySwizzleB<BN, NUM_THREADS>(B_block_gmem, B_block_smem, start_k, end_k, weightKOffset, param);
+    int s = 0;
+    int r = 0;
+    while (r < param.r) {
+    // for (int r = 0; r < param.r; ++r) {
+
+    tileMemcpySwizzleA<BM, NUM_THREADS>(A_block_gmem, A_block_smem, r, s, masks_a, element_offset_a, thread_idx, start_k, end_k, inChannelOffset, param);
+    tileMemcpySwizzleB<BN, NUM_THREADS>(B_block_gmem, B_block_smem, r, s, start_k, end_k, weightKOffset, param);
 
     int offset_direction = 1;
@@ -871,8 +887,8 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
         __syncthreads();
 
         if (block_k != num_block_tiles_k){
-            tileMemcpyLoadA<BM, BK, NUM_THREADS, 4>(A_block_gmem, A_gmem_cache_reg, block_k * BK, start_k, end_k, inChannelOffset, param);
-            tileMemcpyLoadB<BN, BK, NUM_THREADS, 4>(B_block_gmem, B_gmem_cache_reg, block_k * BK, start_k, end_k, weightKOffset, param);
+            tileMemcpyLoadA<BM, BK, NUM_THREADS, 4>(A_block_gmem, A_gmem_cache_reg, r, s, block_k * BK, start_k, end_k, inChannelOffset, param);
+            tileMemcpyLoadB<BN, BK, NUM_THREADS, 4>(B_block_gmem, B_gmem_cache_reg, r, s, block_k * BK, start_k, end_k, weightKOffset, param);
         }
         half* A_warp_tile = A_block_smem + A_warp_tile_offset;
         half* B_warp_tile = B_block_smem + B_warp_tile_offset;
@@ -926,7 +942,14 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
             tileMemcpySwizzleStore<BM, NUM_THREADS, 4>(A_gmem_cache_reg, A_block_smem);
             tileMemcpySwizzleStore<BN, NUM_THREADS, 4>(B_gmem_cache_reg, B_block_smem);
         }
     }
     } // iter block_k
 
+        s++;
+        if (s == param.s) {
+            s = 0;
+            r++;
+        }
+    } // iter r
+
     // reuse smem
     half *smemoutput = shmem;
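The while-loop with this trailing s++/r++ bookkeeping is just a flattened double loop over the filter taps; flattening keeps the loop body identical for every tap, which suits the double-buffered prefetch inside the block_k loop above. A tiny equivalence check (R and S assumed):

#include <cstdio>

// Visits the same (r, s) sequence as "for r { for s { ... } }".
int main() {
    const int R = 3, S = 3;              // assumed filter size
    int r = 0, s = 0;
    while (r < R) {
        printf("(%d,%d) ", r, s);        // one BK-tiled pass over C per tap
        s++;
        if (s == S) { s = 0; r++; }
    }
    printf("\n");                        // prints (0,0) (0,1) ... (2,2)
    return 0;
}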
@@ -1166,7 +1189,8 @@ static void conv2d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const float
         ks = 16;
         for (j = 2; j <= ks; j++){
             const int remainder = (BlocksM * BlocksN * j) % nsm;
-            if ((P.c * P.r * P.s) % (8*j) == 0){
+            // if ((P.c * P.r * P.s) % (8*j) == 0){
+            if ((P.c) % (8*j) == 0){
                 if (remainder == 0) {
                     candidate = j;
                     max_remaining_waves = 0;
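The split-K heuristic's divisibility test now checks only the channel count, matching the kernel's new per-tap GEMM-K. A simplified standalone version of the candidate search (all values assumed; the real code also tracks max_remaining_waves):

#include <cstdio>

int main() {
    const int C = 640;                           // P.c, assumed
    const int BlocksM = 8, BlocksN = 10, nsm = 80;
    int candidate = 1;
    for (int j = 2; j <= 16; j++) {
        const int remainder = (BlocksM * BlocksN * j) % nsm;
        if (C % (8 * j) == 0 && remainder == 0)
            candidate = j;                       // largest split that tiles evenly
    }
    printf("split-K candidate: %d\n", candidate);  // 16 for these values
    return 0;
}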
@@ -26,12 +26,89 @@ typedef struct{
 } param_t;
 
+/// Clears the predicates
+template<const unsigned int K_STRID>
+__host__ __device__ void clear_mask(unsigned int masks_[][2], bool clear = true) {
+#pragma unroll
+    for (int s = 0; s < K_STRID; ++s) {
+        masks_[s][0] = clear ? 0 : masks_[s][0];
+        masks_[s][1] = clear ? 0 : masks_[s][1];
+    }
+}
+
+template<const unsigned int TILE_ROWS,
+const unsigned int TILE_COLS,
+const unsigned int A_K_STRID,
+const unsigned int ROW_STEP>
+__device__ void prepareIteratorA(const int thread_idx,
+                                 unsigned int masks[][2],
+                                 unsigned int element_offset[],
+                                 const param_t param){
+    int offset_n[A_K_STRID];
+    int offset_p[A_K_STRID];
+    int offset_q[A_K_STRID];
+
+    constexpr unsigned int TILE_COLS_VECTORIZED = TILE_COLS / 8;
+    unsigned int thread_row = thread_idx / TILE_COLS_VECTORIZED;
+    const unsigned int chw = param.c * param.h * param.w;
+
+#pragma unroll
+    for (int s = 0; s < A_K_STRID; ++s) {
+        // pointer_[s] = reinterpret_cast<char const *>(ptr);
+        // int offset_npq = threadblock_offset.row() + thread_coord.strided() + s * ThreadMap::Delta::kStrided;
+        const unsigned int gemm_i = blockIdx.y * TILE_ROWS + thread_row;
+        offset_n[s] = fastdiv(gemm_i, param.OHOW_fastdiv);
+        unsigned int npq_res = fastmodulo(gemm_i, param.OHOW_fastdiv);
+        offset_p[s] = fastdiv(npq_res, param.OW_fastdiv); //* param.u - param.p;
+        offset_q[s] = fastmodulo(npq_res, param.OW_fastdiv); // * param.v - param.q;
+        const int h = offset_p[s] * param.u - param.p;
+        const int w = offset_q[s] * param.v - param.q;
+
+        // if(threadIdx.x < 32 && threadIdx.y == 0 && blockIdx.x == 0 && blockIdx.y == 0)
+        //     printf("%d, %d : %d, %d, %d, %d offset (%d, %d, %d), kele %llu Kcont %d\n ", thread_idx, s,
+        // // printf("[%s - %d] %d, %d : %d, %d, %d, %d\n ", __FUNCTION__, __LINE__, thread_idx, s,
+        //         threadblock_offset.row(), thread_coord.strided(), ThreadMap::Delta::kStrided,
+        //         offset_npq, offset_n[s], offset_p[s], offset_q[s], AccessType::kElements,
+        //         ThreadMap::Iterations::kContiguous);
+
+        element_offset[s] = offset_n[s] * chw + h * param.c * param.w + w * param.c;
+        thread_row += ROW_STEP;
+    }
+
+    clear_mask<A_K_STRID>(masks);
+
+    for (int r = 0; r < param.r; ++r) {
+#pragma unroll
+        for (int s_idx = 0; s_idx < A_K_STRID; ++s_idx) {
+            const int h = offset_p[s_idx] * param.u - param.p + r * param.d_h;
+
+            bool pred = (offset_n[s_idx] < param.n && h >= 0 && h < param.h);
+            masks[s_idx][0] |= (pred << r);
+        }
+    }
+
+    for (int s = 0; s < param.s; ++s) {
+#pragma unroll
+        for (int s_idx = 0; s_idx < A_K_STRID; ++s_idx) {
+            const int w = offset_q[s_idx] * param.v - param.q + s * param.d_w;
+            bool pred = (w >= 0 && w < param.w);
+            masks[s_idx][1] |= (pred << s);
+        }
+    }
+}
+
 // same as above, but writes are swizzled to avoid bank conflicts when shared memory is read later in the kernel
 template<unsigned int TILE_ROWS,
 unsigned int NUM_THREADS>
 __device__ __forceinline__ void tileMemcpySwizzleB(
     const half* src,
     half* dst,
+    const unsigned int curR,
+    const unsigned int curS,
     const unsigned int start_k,
     const unsigned int end_k,
     const unsigned int src_stride,
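prepareIteratorA packs one validity bit per filter row into masks[i][0] and one per filter column into masks[i][1]; a tap (r, s) may be loaded only when both bits are set, which is exactly the test in the commented-out consumer further down this diff. A standalone version of that test (values assumed):

#include <cassert>

// Mirrors "bool valid = (masks[i][0] & (1u << curR)) && (masks[i][1] & (1u << curS))".
bool tap_is_valid(const unsigned int mask[2], int r, int s) {
    return (mask[0] & (1u << r)) && (mask[1] & (1u << s));
}

int main() {
    // Example: rows 1..2 of a 3-tap filter are in bounds, all 3 columns are.
    const unsigned int mask[2] = { 0b110, 0b111 };
    assert(!tap_is_valid(mask, 0, 0));   // top row falls above the image
    assert( tap_is_valid(mask, 1, 2));
    return 0;
}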
@@ -60,10 +137,12 @@ __device__ __forceinline__ void tileMemcpySwizzleB(
     unsigned int thread_row = thread_idx / TILE_COLS_VECTORIZED;
     const unsigned int thread_col = thread_idx % TILE_COLS_VECTORIZED;
 
-    const unsigned int ki = start_k+thread_col*8;
-    const unsigned int curR = fastdiv(ki, param.SC_fastdiv); // channel offset
-    const unsigned int curS = fastdiv(fastmodulo(ki, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
-    const unsigned int curC = fastmodulo(fastmodulo(ki, param.SC_fastdiv), param.C_fastdiv); //
+    // const unsigned int ki = (curR*param.s+curS)*param.c + start_k+thread_col*8;
+    // const unsigned int curR = fastdiv(ki, param.SC_fastdiv); // channel offset
+    // const unsigned int curS = fastdiv(fastmodulo(ki, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
+    // const unsigned int curC = fastmodulo(fastmodulo(ki, param.SC_fastdiv), param.C_fastdiv); //
+    const unsigned int curC = start_k+thread_col*8;
+    const unsigned int ki = (curR*param.s+curS)*param.c + curC;
 
 #pragma unroll
     for (unsigned int i = 0; i < NUM_ITERS; i++){
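Since curR and curS now arrive from the caller's tap loop, the B copy recomposes the flat weight index as (curR*param.s + curS)*param.c + curC instead of decomposing a flat k with fastdiv/fastmodulo. The two directions are inverses, which a plain-division stand-in for the fastdiv helpers can verify (shapes assumed):

#include <cassert>

int main() {
    const int C = 64, S = 3;                     // assumed shapes
    const int ki = (2 * S + 1) * C + 37;         // (r=2, s=1, c=37)
    const int r = ki / (S * C);                  // old path: decompose
    const int s = (ki % (S * C)) / C;
    const int c = (ki % (S * C)) % C;
    assert((r * S + s) * C + c == ki);           // new path: recompose
    return 0;
}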
@@ -72,7 +151,7 @@ __device__ __forceinline__ void tileMemcpySwizzleB(
         unsigned int dst_index = thread_row * TILE_COLS_VECTORIZED + thread_col;
         dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_1) >> SWIZZLE_BITS_1);
         dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_2) >> SWIZZLE_BITS_2);
-        if (thread_row + blockIdx.x * TILE_ROWS < param.k && curR < param.r && curS < param.s && curC < param.c && ki < end_k){
+        if (thread_row + blockIdx.x * TILE_ROWS < param.k && curC < end_k){
             dst_float4[dst_index] = reinterpret_cast<const float4 *>(&src[src_index])[0];
         }else{ // read 4 halves
             dst_float4[dst_index] = make_float4(0.f, 0.f, 0.f, 0.f);
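The guard can drop the curR/curS/curC comparisons because the outer tap loop already guarantees curR < param.r and curS < param.s, and end_k ≤ K = param.c bounds the channel index, so curC < end_k suffices. The XOR swizzle above it (unchanged by this commit) spreads a column of float4 accesses across shared-memory banks; an illustration with assumed mask/shift values, not this kernel's actual SWIZZLE_MASK_*/SWIZZLE_BITS_*:

#include <cstdio>

int main() {
    const unsigned int MASK = 0b111000, BITS = 3;    // assumed swizzle constants
    for (unsigned int row = 0; row < 8; ++row) {
        unsigned int idx = row * 8;                  // same column, successive rows
        unsigned int sw = idx ^ ((idx & MASK) >> BITS);
        // without the XOR every row would map to bank group 0
        printf("row %u -> slot %u (bank group %u)\n", row, sw, sw % 8);
    }
    return 0;
}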
@@ -95,6 +174,11 @@ unsigned int NUM_THREADS>
 __device__ __forceinline__ void tileMemcpySwizzleA(
     const half* src,
     half* dst,
+    const unsigned int curR,
+    const unsigned int curS,
+    unsigned int masks[][2],
+    unsigned int element_offset[],
+    const unsigned int thread_idx,
     const unsigned int start_k,
     const unsigned int end_k,
     const unsigned int inChannelOffset,
@@ -115,7 +199,7 @@ __device__ __forceinline__ void tileMemcpySwizzleA(
     constexpr unsigned int TILE_COLS_VECTORIZED = TILE_COLS / 8;
     static_assert(NUM_THREADS % TILE_COLS_VECTORIZED == 0);
     // flatten out 2d grid of threads into in order of increasing threadIdx.x
-    const unsigned int thread_idx = threadIdx.y * blockDim.x + threadIdx.x;
+    // const unsigned int thread_idx = threadIdx.y * blockDim.x + threadIdx.x;
 
     // assign each thread a row/column in the tile, calculate how many iterations we need
     // to cover the whole tile
@@ -126,11 +210,27 @@ __device__ __forceinline__ void tileMemcpySwizzleA(
 
     const unsigned int ki = start_k+thread_col*8;
     const unsigned int chw = param.c * param.h * param.w;
-    const unsigned int curR = fastdiv(ki, param.SC_fastdiv); // channel offset
-    const unsigned int curS = fastdiv(fastmodulo(ki, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
-    const unsigned int curC = fastmodulo(fastmodulo(ki, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
-
+    // const unsigned int curR = fastdiv(ki, param.SC_fastdiv); // channel offset
+    // const unsigned int curS = fastdiv(fastmodulo(ki, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
+    // const unsigned int curC = fastmodulo(fastmodulo(ki, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
+    const unsigned int curC = ki;
+    // #pragma unroll
+    // for (unsigned int i = 0; i < NUM_ITERS; i++){
+    //     bool valid = (masks[i][0] & (1u << curR)) && (masks[i][1] & (1u << curS));
+    //     // apply swizzle to the dst index
+    //     unsigned int dst_index = thread_row * TILE_COLS_VECTORIZED + thread_col;
+    //     dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_1) >> SWIZZLE_BITS_1);
+    //     dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_2) >> SWIZZLE_BITS_2);
+    //     if (valid && ki < end_k){
+    //         if(element_offset[i]+curC >= 327680 || element_offset[i]+curC < 0)
+    //             printf("%d, %d, %d, %d, %d, %d, %d, %d, %d \n", threadIdx.x, threadIdx.y, blockIdx.x, blockIdx.y,
+    //                 i, element_offset[i], curR, curS, curC);
+    //         dst_float4[dst_index] = reinterpret_cast<const float4 *>(&src[element_offset[i]+curC])[0];
+    //     } else{
+    //         dst_float4[dst_index] = make_float4(0.f, 0.f, 0.f, 0.f);
+    //     }
+    //     thread_row += ROW_STEP;
+    // }
 #pragma unroll
     for (unsigned int i = 0; i < NUM_ITERS; i++){
+        unsigned int gemm_i = blockIdx.y * TILE_ROWS + thread_row;
@@ -170,7 +270,8 @@ unsigned int ELEMENTS_PER_THREAD>
 __device__ __forceinline__ void tileMemcpyLoadA(
     const half* src,
     float4 (&dst_reg)[ELEMENTS_PER_THREAD],
-    // const unsigned int src_stride,
+    const unsigned int curR,
+    const unsigned int curS,
     const unsigned int block_k,
     const unsigned int start_k,
     const unsigned int end_k,
@@ -199,9 +300,10 @@ __device__ __forceinline__ void tileMemcpyLoadA(
     const unsigned int ki = start_k+block_k+thread_col*8;
     const unsigned int chw = param.c * param.h * param.w;
 
-    const unsigned int curR = fastdiv(ki, param.SC_fastdiv); // channel offset
-    const unsigned int curS = fastdiv(fastmodulo(ki, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
-    const unsigned int curC = fastmodulo(fastmodulo(ki, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
+    // const unsigned int curR = fastdiv(ki, param.SC_fastdiv); // channel offset
+    // const unsigned int curS = fastdiv(fastmodulo(ki, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
+    // const unsigned int curC = fastmodulo(fastmodulo(ki, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
+    const unsigned int curC = ki;
 
 #pragma unroll
     for (unsigned int i = 0; i < NUM_ITERS; i++){
@@ -240,6 +342,8 @@ unsigned int ELEMENTS_PER_THREAD>
 __device__ __forceinline__ void tileMemcpyLoadB(
     const half* src,
     float4 (&dst_reg)[ELEMENTS_PER_THREAD],
+    const unsigned int curR,
+    const unsigned int curS,
     const unsigned int block_k,
     const unsigned int start_k,
     const unsigned int end_k,
@@ -265,15 +369,16 @@ __device__ __forceinline__ void tileMemcpyLoadB(
     // compile time check that we provided the right amount of registers for storage
     static_assert(ELEMENTS_PER_THREAD == NUM_ITERS);
 
-    const unsigned int ki = start_k+block_k+thread_col*8;
-    const unsigned int curR = fastdiv(ki, param.SC_fastdiv); // channel offset
-    const unsigned int curS = fastdiv(fastmodulo(ki, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
-    const unsigned int curC = fastmodulo(fastmodulo(ki, param.SC_fastdiv), param.C_fastdiv); //
+    // const unsigned int curR = fastdiv(ki, param.SC_fastdiv); // channel offset
+    // const unsigned int curS = fastdiv(fastmodulo(ki, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
+    // const unsigned int curC = fastmodulo(fastmodulo(ki, param.SC_fastdiv), param.C_fastdiv); //
+    const unsigned int curC = start_k+block_k+thread_col*8;
+    const unsigned int ki = (curR*param.s+curS)*param.c + curC;
 
 #pragma unroll
     for (unsigned int i = 0; i < NUM_ITERS; i++){
         const unsigned int src_index = thread_row * src_stride + ki;
-        if (thread_row + blockIdx.x * TILE_ROWS < param.k && curR < param.r && curS < param.s && curC < param.c && ki < end_k){
+        if (thread_row + blockIdx.x * TILE_ROWS < param.k && curC < end_k){
             dst_reg[i] = reinterpret_cast<const float4 *>(&src[src_index])[0];
         }else{ // read 4 halves
             dst_reg[i] = make_float4(0.f, 0.f, 0.f, 0.f);
@@ -301,7 +301,9 @@ static std::vector<std::tuple<int, int, int, int, int, int>> configs = {
     // std::make_tuple(960,320,104,152,3,3),
     // std::make_tuple(1280,1280,26,38,3,3),
     // std::make_tuple(1920,640,32,32,3,3)
-    std::make_tuple(1280,1280,16,16,3,3),
+    // std::make_tuple(1280,1280,16,16,3,3),
+    // std::make_tuple(32,8,24,24,3,3),
+    std::make_tuple(640,640,64,64,3,3),
     // std::make_tuple(320,640,32,32,3,3),
     // std::make_tuple(4,320,96,128,3,3),
     // std::make_tuple(320,4,96,128,3,3),
@@ -671,7 +673,7 @@ int main(void)
     // fprintf(stderr, "%s: compute buffer size: %.2f MB\n", __func__, mem_size/1024.0f/1024.0f);
 
-    int iterations = 0;
+    int iterations = 20;
 
     double run_time0;
     std::vector<float> im2col_data = compute_graph(model, allocr, build_graph_0, iterations, &run_time0);