WIP: debugging

commit dbeb6ced46
parent 378bb8368e
@@ -831,6 +831,15 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
half* B_block_smem = &shmem[BM * BK];
constexpr int BUFFER_SIZE = BM * BK + BK * BN;

#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
half* SA1 = A_block_smem;
half* SB1 = B_block_smem;
half* SA2 = &shmem[BUFFER_SIZE];
half* SB2 = SA2 + BM * BK;
#else
float4 A_gmem_cache_reg[4];
float4 B_gmem_cache_reg[4];
#endif
// declare register storage
// ptx instructions expect uint32_t registers, where each uint32_t is 2 halves packed together
uint32_t acc_register[mma_tiles_per_warp_m][mma_tiles_per_warp_n][2];
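// Background sketch (not from this commit): each uint32_t accumulator slot packs two halves,
// so one 16x8 mma output tile per thread lives in acc_register[m][n][0..1]. Unpacking a slot
// for inspection could look like:
//   half2 pair = *reinterpret_cast<const half2 *>(&acc_register[0][0][0]);
//   float lo = __low2float(pair);
//   float hi = __high2float(pair);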
@@ -868,9 +877,6 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
static_assert(BN == 256);
static_assert(BK == 32);
static_assert(NUM_THREADS == 256);
float4 A_gmem_cache_reg[4];
float4 B_gmem_cache_reg[4];


prepareIteratorA<BM, BK, A_K_STRID, ROW_STEP>(thread_row, masks_a, element_offset_a, param);
@@ -898,7 +904,9 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
unsigned int curC = tileMemcpySwizzleA<BM, NUM_THREADS>(A_block_gmem, A_block_smem, 0, 0, masks_a, element_offset_a,
thread_row, thread_col, start_k, end_k, param);
tileMemcpySwizzleB<BN, NUM_THREADS>(B_block_gmem, B_block_smem, 0, 0, start_k, end_k, thread_row, thread_col, param);

#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
asm volatile("cp.async.commit_group;\n" ::);
#endif
int offset_direction = 1;
unsigned int block_k = 0;
unsigned int block_krs = 1;
@@ -906,6 +914,7 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
int s = 0;
int r = 0;
while (block_k < num_block_tiles_k){
asm volatile("cp.async.wait_group %0;\n" ::"n"(0));
__syncthreads();

// moves to the next tile
@@ -948,15 +957,29 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,

// if (block_k != num_block_tiles_k){
if (block_krs != num_block_tiles_krs){
#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
curC = tileMemcpyAsyncLoadA<BM, BK, NUM_THREADS, 4>(A_block_gmem, SA2, r, s,
masks_a, element_offset_a, thread_row, thread_col, block_k * BK,
start_k, end_k, curC, param);
tileMemcpyAsyncLoadB<BN, BK, NUM_THREADS, 4>(B_block_gmem, SB2, r, s, block_k * BK,
start_k, end_k, thread_row, thread_col, param);
asm volatile("cp.async.commit_group;\n" ::);
#else
curC = tileMemcpyLoadA<BM, BK, NUM_THREADS, 4>(A_block_gmem, A_gmem_cache_reg, r, s,
masks_a, element_offset_a, thread_row, thread_col, block_k * BK,
start_k, end_k, curC, param);
tileMemcpyLoadB<BN, BK, NUM_THREADS, 4>(B_block_gmem, B_gmem_cache_reg, r, s, block_k * BK,
start_k, end_k, thread_row, thread_col, param);
#endif
}

#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
half* A_warp_tile = SA1 + A_warp_tile_offset;
half* B_warp_tile = SB1 + B_warp_tile_offset;
#else
half* A_warp_tile = A_block_smem + A_warp_tile_offset;
half* B_warp_tile = B_block_smem + B_warp_tile_offset;
#endif

ldmatrix_a<mma_tiles_per_warp_m, mma_tiles_per_warp_k, BK>(A_warp_tile, A_register_);
ldmatrix_b<mma_tiles_per_warp_k, mma_tiles_per_warp_n, BK>(B_warp_tile, B_register_);
@@ -998,8 +1021,11 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
}

// if (block_k != num_block_tiles_k)
if (block_krs != num_block_tiles_krs) {
#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
half *tmp = SA1; SA1 = SA2; SA2 = tmp;
tmp = SB1; SB1 = SB2; SB2 = tmp;
#else
// switch smem buffers each iteration
A_block_smem = A_block_smem + BUFFER_SIZE * offset_direction;
B_block_smem = B_block_smem + BUFFER_SIZE * offset_direction;
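// How the pieces above fit together (sketch, not from this commit): on Ampere+ the next K-tile
// is prefetched with cp.async into SA2/SB2 while the mma path consumes SA1/SB1, then the pointer
// pairs are swapped; pre-Ampere the tile is staged through A_gmem_cache_reg/B_gmem_cache_reg and
// the smem base pointers step by BUFFER_SIZE * offset_direction (the direction flip is outside
// this hunk). Minimal order of operations on the Ampere path, using names from this diff:
//   asm volatile("cp.async.commit_group;\n" ::);          // queue the prefetch into SA2/SB2
//   asm volatile("cp.async.wait_group %0;\n" :: "n"(0));  // top of the next iteration
//   __syncthreads();
//   half *tmp = SA1; SA1 = SA2; SA2 = tmp;                // compute buffer <-> prefetch buffer
//   tmp = SB1; SB1 = SB2; SB2 = tmp;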
@@ -1007,15 +1033,56 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,

tileMemcpySwizzleStore<BM, NUM_THREADS, 4>(A_gmem_cache_reg, A_block_smem, thread_row, thread_col);
tileMemcpySwizzleStore<BN, NUM_THREADS, 4>(B_gmem_cache_reg, B_block_smem, thread_row, thread_col);
#endif
}

block_krs++;

}
// A_block_smem = shmem;
// B_block_smem = &shmem[BM * BK];

// } // iter block_k

#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
asm volatile("cp.async.wait_group %0;\n" ::"n"(0));
__syncthreads();
half* A_warp_tile = SA2 + A_warp_tile_offset;
half* B_warp_tile = SB2 + B_warp_tile_offset;
ldmatrix_a<mma_tiles_per_warp_m, mma_tiles_per_warp_k, BK>(A_warp_tile, A_register_);
ldmatrix_b<mma_tiles_per_warp_k, mma_tiles_per_warp_n, BK>(B_warp_tile, B_register_);
// outer product between mma tiles
#pragma unroll
for (unsigned int mma_k = 0; mma_k < mma_tiles_per_warp_k; mma_k++){
#pragma unroll
for (unsigned int mma_n = 0; mma_n < mma_tiles_per_warp_n; mma_n++){
#pragma unroll
for (unsigned int mma_m = 0; mma_m < mma_tiles_per_warp_m; mma_m++){
#if __CUDA_ARCH__ >= GGML_CUDA_CC_RUBIN
asm volatile (
"mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 "
"{%0, %1}, "
"{%2, %3, %4, %5}, "
"{%6, %7}, "
"{%8, %9};"
: "=r"(acc_register[mma_m][mma_n][0]), "=r"(acc_register[mma_m][mma_n][1])
: "r"(A_register[mma_m][mma_k][0]), "r"(A_register[mma_m][mma_k][1]), "r"(A_register[mma_m][mma_k][2]), "r"(A_register[mma_m][mma_k][3]),
"r"(B_register[mma_k][mma_n][0]), "r"(B_register[mma_k][mma_n][1]),
"r"(acc_register[mma_m][mma_n][0]), "r"(acc_register[mma_m][mma_n][1])
);
#else
asm volatile (
"mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 "
"{%0, %1}, "
"{%2, %3}, "
"{%4}, "
"{%5, %6};"
: "=r"(acc_register[mma_m][mma_n][0]), "=r"(acc_register[mma_m][mma_n][1])
: "r"(A_register[mma_m][mma_k][0]), "r"(A_register[mma_m][mma_k][1]),
"r"(B_register[mma_k][mma_n]),
"r"(acc_register[mma_m][mma_n][0]), "r"(acc_register[mma_m][mma_n][1])
);
#endif
}
}
}
#endif


// if(threadIdx.x == 0 && threadIdx.y == 0 && blockIdx.x == 0 && blockIdx.y == 0){
// printf(" %u, %f\n", blockIdx.z, __half2float(acc_register_[0][0][0]));
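// Operand shapes of the two mma variants above (PTX ISA, f16 inputs and accumulators, .row.col):
//   m16n8k16: D,C = 2 x .b32 (4 halves each), A = 4 x .b32 (8 halves), B = 2 x .b32 (4 halves)
//   m16n8k8 : D,C = 2 x .b32,                 A = 2 x .b32,            B = 1 x .b32
// which is why acc_register carries a trailing [2] and only the k16 path consumes A_register[..][..][0..3].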
@@ -124,6 +124,28 @@ __device__ void prepareIteratorA(unsigned int thread_row,
}
}

template <int preload=16>
__device__ void cp_async_zfill(void *ptr, void const *global_ptr, bool pred_guard = true) {
#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE

unsigned int smem_ptr;
int src_in_bytes = pred_guard ? preload : 0;

asm("{ .reg .u64 smem_ptr; cvta.to.shared.u64 smem_ptr, %1; cvt.u32.u64 "
"%0, smem_ptr; }\n"
: "=r"(smem_ptr)
: "l"(ptr));

asm volatile("cp.async.cg.shared.global [%0], [%1], %2, %3;\n" ::"r"(smem_ptr),
"l"(global_ptr),
"n"(preload), "r"(src_in_bytes));
#else
GGML_UNUSED(ptr);
GGML_UNUSED(global_ptr);
GGML_UNUSED(pred_guard);
#endif
}
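// Note on cp_async_zfill above: cp.async.cg.shared.global [dst], [src], cp-size, src-size copies
// cp-size bytes into shared memory but reads only src-size bytes from global memory, zero-filling
// the remainder; passing src_in_bytes = 0 when pred_guard is false therefore turns the copy into
// a 16-byte zero write with no branch. Illustrative call pattern (mirrors the call sites added
// later in this diff; the index names are placeholders):
//   bool valid = /* bounds check for this element */ true;
//   cp_async_zfill((void *)&smem_tile[dst_idx], (const void *)&gmem_src[src_idx], valid);
//   asm volatile("cp.async.commit_group;\n" ::);
//   asm volatile("cp.async.wait_group %0;\n" :: "n"(0));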

// same as above, but writes are swizzled to avoid bank conflicts when shared memory is read later in the kernel
template<unsigned int TILE_ROWS,
unsigned int NUM_THREADS>
@@ -177,17 +199,10 @@ __device__ __forceinline__ void tileMemcpySwizzleB(
dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_1) >> SWIZZLE_BITS_1);
dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_2) >> SWIZZLE_BITS_2);
#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
unsigned int smem_ptr;
void *ptr = (void *)(dst);
int src_in_bytes = thread_row + blockIdx.x * TILE_ROWS < param.k && curC < end_k ? 16 : 0;
asm("{ .reg .u64 smem_ptr; cvta.to.shared.u64 smem_ptr, %1; cvt.u32.u64 "
"%0, smem_ptr; }\n"
: "=r"(smem_ptr)
: "l"(ptr));

asm volatile("cp.async.cg.shared.global [%0], [%1], %2, %3;\n" ::"r"(smem_ptr),
"l"(&src[src_index]),
"n"(16), "r"(src_in_bytes));
cp_async_zfill((void *)(&dst_float4[dst_index]), (void const *)(&src[src_index]),
thread_row + blockIdx.x * TILE_ROWS < param.k && curC < end_k);

#else
if (thread_row + blockIdx.x * TILE_ROWS < param.k && curC < end_k){
dst_float4[dst_index] = reinterpret_cast<const float4 *>(&src[src_index])[0];
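// Worked example of the XOR swizzle used in these copy helpers (constants as declared elsewhere
// in this file: SWIZZLE_MASK_1 = 0b10000, SWIZZLE_BITS_1 = 4, SWIZZLE_MASK_2 = 0b1100,
// SWIZZLE_BITS_2 = 2), which permutes float4 slots so the later shared-memory reads avoid bank
// conflicts:
//   dst_index = 31 (0b11111)
//   step 1: bit 4 is set, so it is XORed into bit 0       -> 0b11110 (30)
//   step 2: bits 3:2 are 0b11, XORed into bits 1:0        -> 0b11101 (29)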
@@ -272,24 +287,14 @@ __device__ __forceinline__ unsigned int tileMemcpySwizzleA(
// }
// if (valid && curC < end_k){
#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
unsigned int smem_ptr;
void *ptr = (void *)(dst);
int src_in_bytes = valid ? 16 : 0;
asm("{ .reg .u64 smem_ptr; cvta.to.shared.u64 smem_ptr, %1; cvt.u32.u64 "
"%0, smem_ptr; }\n"
: "=r"(smem_ptr)
: "l"(ptr));

asm volatile("cp.async.cg.shared.global [%0], [%1], %2, %3;\n" ::"r"(smem_ptr),
"l"(&src[element_offset[i]+curC]),
"n"(16), "r"(src_in_bytes));
cp_async_zfill((void *)(&dst_float4[dst_index]), (void const *)(&src[element_offset[i]+curC]), valid);
#else
if (valid){
// if(element_offset[i] >= 327680 || element_offset[i] < 0)
// printf("%d, %d, %d, %d, %d, %d, %d, %d, %d \n", threadIdx.x, threadIdx.y, blockIdx.x, blockIdx.y,
// i, element_offset[i], curR, curS, curC);
dst_float4[dst_index] = reinterpret_cast<const float4 *>(&src[element_offset[i]+curC])[0];
} else {
dst_float4[dst_index] = make_float4(0.f, 0.f, 0.f, 0.f);
}
#endif
@@ -394,36 +399,6 @@ __device__ __forceinline__ unsigned int tileMemcpyLoadA(
dst_reg[i] = make_float4(0.f, 0.f, 0.f, 0.f);
}
}
// #pragma unroll
// for (unsigned int i = 0; i < NUM_ITERS; i++){
// unsigned int gemm_i = blockIdx.y * TILE_ROWS + thread_row;
// unsigned int n = fastdiv(gemm_i, param.OHOW_fastdiv);
// unsigned int npq_res = fastmodulo(gemm_i, param.OHOW_fastdiv);
// int posh_ori = fastdiv(npq_res, param.OW_fastdiv) * param.u - param.p;
// int posw_ori = fastmodulo(npq_res, param.OW_fastdiv) * param.v - param.q;
// // unsigned int inOffset = n * param.c * param.h * param.w;
// int curH = posh_ori + curR * param.d_h; // input h
// int curW = posw_ori + curS * param.d_w; // input w
// bool valid = (masks[i][0] & (1u << curR)) && (masks[i][1] & (1u << curS));
// bool ovl = curH >= 0 && curW >= 0 && curW < param.w && curH < param.h &&
// curR < param.r && curS < param.s && curC < param.c && n < param.n && ki < end_k;
// const int txx = curH * (int) inChannelOffset + curW * (int)param.c + (int)curC;

// if(threadIdx.x == 0 && threadIdx.y == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 1){
// printf(" %u, %u, %u, %u, %u, %lld, %lld, %d, %d, %d\n", i, curR, curS, oldC, curC,
// element_offset[i], element_offset[i]+(int64_t)curC, n * (int)chw + txx,
// valid?1:0, ovl?1:0);
// }

// if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h &&
// curR < param.r && curS < param.s && curC < param.c && n < param.n && ki < end_k){
// const unsigned int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC;
// dst_reg[i] = reinterpret_cast<const float4 *>(&src[n * chw + inOffsetTmp])[0];
// } else{
// dst_reg[i] = make_float4(0.f, 0.f, 0.f, 0.f);
// }
// thread_row += ROW_STEP;
// }
return curC;
#else
GGML_UNUSED(src);
@@ -443,6 +418,93 @@ __device__ __forceinline__ unsigned int tileMemcpyLoadA(
#endif
}

template<unsigned int TILE_ROWS,
unsigned int TILE_COLS,
unsigned int NUM_THREADS,
unsigned int ELEMENTS_PER_THREAD>
__device__ __forceinline__ unsigned int tileMemcpyAsyncLoadA(
const half* __restrict__ src,
half* __restrict__ dst,
const unsigned int curR,
const unsigned int curS,
unsigned int masks[][2],
const int64_t element_offset[],
unsigned int thread_row,
const unsigned int thread_col,
const unsigned int block_k,
const unsigned int start_k,
const unsigned int end_k,
unsigned int oldC,
// const unsigned int inChannelOffset,
param_t param
){
#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
// # of threads is multiple of # of columns in the tile
constexpr unsigned int SWIZZLE_MASK_1 = 0b10000;
constexpr unsigned int SWIZZLE_BITS_1 = 4;
constexpr unsigned int SWIZZLE_MASK_2 = 0b1100;
constexpr unsigned int SWIZZLE_BITS_2 = 2;

constexpr unsigned int TILE_COLS_VECTORIZED = TILE_COLS / 8;
static_assert(NUM_THREADS % TILE_COLS_VECTORIZED == 0);

float4* dst_float4 = reinterpret_cast<float4*>(dst);

// flatten the 2d grid of threads into 1d, in order of increasing threadIdx.x
// assign each thread a row/column in the tile, calculate how many iterations we need
// to cover the whole tile
constexpr unsigned int ROW_STEP = NUM_THREADS / TILE_COLS_VECTORIZED;
constexpr unsigned int NUM_ITERS = TILE_ROWS / ROW_STEP;
constexpr unsigned int ITER_STEPS = ROW_STEP * TILE_COLS_VECTORIZED;
// unsigned int thread_row = thread_idx / TILE_COLS_VECTORIZED;
// const unsigned int thread_col = thread_idx % TILE_COLS_VECTORIZED;

// compile time check that we provided the right amount of registers for storage
static_assert(ELEMENTS_PER_THREAD == NUM_ITERS);

// const unsigned int ki = start_k+block_k+thread_col*8;
// const unsigned int chw = param.c * param.h * param.w;

// const unsigned int curR = fastdiv(ki, param.SC_fastdiv); // channel offset
// const unsigned int curS = fastdiv(fastmodulo(ki, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
// const unsigned int curC = fastmodulo(fastmodulo(ki, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
const unsigned int curC = start_k+block_k+thread_col*8;
if (curC > oldC)
clear_mask<NUM_ITERS>(masks, curC >= end_k);

unsigned int iter_idx = thread_row * TILE_COLS_VECTORIZED + thread_col;
#pragma unroll
for (unsigned int i = 0; i < NUM_ITERS; i++){
bool valid = (masks[i][0] & (1u << curR)) && (masks[i][1] & (1u << curS));
// if(threadIdx.x == 3 && threadIdx.y == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 1){
// printf(" %u, %u, %u, %u, %u, %lld, %d\n", i, curR, curS, oldC, curC, element_offset[i], valid?1:0);
// }
unsigned int dst_index = iter_idx;
dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_1) >> SWIZZLE_BITS_1);
dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_2) >> SWIZZLE_BITS_2);

cp_async_zfill((void *)(&dst_float4[dst_index]), (void const *)(&src[element_offset[i]+curC]), valid);
iter_idx += ITER_STEPS;
}
return curC;
#else
GGML_UNUSED(src);
GGML_UNUSED(dst);
GGML_UNUSED(block_k);
GGML_UNUSED(curR);
GGML_UNUSED(curS);
GGML_UNUSED(start_k);
GGML_UNUSED(end_k);
GGML_UNUSED(masks);
GGML_UNUSED(element_offset);
GGML_UNUSED(thread_row);
GGML_UNUSED(thread_col);
GGML_UNUSED(oldC);
GGML_UNUSED(param);
NO_DEVICE_CODE;
#endif
}
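// Worked numbers for the mapping above, using the constants asserted earlier in this commit
// (TILE_COLS = BK = 32, NUM_THREADS = 256) and ELEMENTS_PER_THREAD = 4 from the call site:
//   TILE_COLS_VECTORIZED = 32 / 8  = 4    (each thread moves one float4 = 8 halves per step)
//   ROW_STEP             = 256 / 4 = 64   (tile rows covered per iteration)
//   NUM_ITERS            = TILE_ROWS / 64, so ELEMENTS_PER_THREAD = 4 implies TILE_ROWS (BM) = 256.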


template<unsigned int TILE_ROWS,
unsigned int TILE_COLS,
@@ -463,6 +525,12 @@ __device__ __forceinline__ void tileMemcpyLoadB(
){
#if __CUDA_ARCH__ >= GGML_CUDA_CC_TURING

constexpr unsigned int SWIZZLE_MASK_1 = 0b10000;
constexpr unsigned int SWIZZLE_BITS_1 = 4;
constexpr unsigned int SWIZZLE_MASK_2 = 0b1100;
constexpr unsigned int SWIZZLE_BITS_2 = 2;

// # of threads is multiple of # of columns in the tile
constexpr unsigned int TILE_COLS_VECTORIZED = TILE_COLS / 8;
static_assert(NUM_THREADS % TILE_COLS_VECTORIZED == 0);
@@ -518,6 +586,84 @@ __device__ __forceinline__ void tileMemcpyLoadB(
#endif
}

template<unsigned int TILE_ROWS,
unsigned int TILE_COLS,
unsigned int NUM_THREADS,
unsigned int ELEMENTS_PER_THREAD>
__device__ __forceinline__ void tileMemcpyAsyncLoadB(
const half *src,
half *dst,
const unsigned int curR,
const unsigned int curS,
const unsigned int block_k,
const unsigned int start_k,
const unsigned int end_k,
unsigned int thread_row,
const unsigned int thread_col,
param_t param
){

#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE

constexpr unsigned int SWIZZLE_MASK_1 = 0b10000;
constexpr unsigned int SWIZZLE_BITS_1 = 4;
constexpr unsigned int SWIZZLE_MASK_2 = 0b1100;
constexpr unsigned int SWIZZLE_BITS_2 = 2;

// # of threads is multiple of # of columns in the tile
constexpr unsigned int TILE_COLS_VECTORIZED = TILE_COLS / 8;
static_assert(NUM_THREADS % TILE_COLS_VECTORIZED == 0);

// flatten the 2d grid of threads into 1d, in order of increasing threadIdx.x
// const unsigned int thread_idx = threadIdx.y * blockDim.x + threadIdx.x;
float4* dst_float4 = reinterpret_cast<float4*>(dst);

// assign each thread a row/column in the tile, calculate how many iterations we need
// to cover the whole tile
constexpr unsigned int ROW_STEP = NUM_THREADS / TILE_COLS_VECTORIZED;
constexpr unsigned int NUM_ITERS = TILE_ROWS / ROW_STEP;
constexpr unsigned int ITER_DST_STEPS = ROW_STEP * TILE_COLS_VECTORIZED;

// compile time check that we provided the right amount of registers for storage
static_assert(ELEMENTS_PER_THREAD == NUM_ITERS);

const unsigned int curC = start_k+block_k+thread_col*8;
const unsigned int ki = (curR*param.s+curS)*param.c + curC;

unsigned int iter_src_idx = thread_row * param.weightKOffset + ki;
unsigned int iter_dst_idx = thread_row * TILE_COLS_VECTORIZED + thread_col;
unsigned int krow_idx = thread_row + blockIdx.x * TILE_ROWS;
const int ITER_SRC_STEPS = ROW_STEP * param.weightKOffset;

#pragma unroll
for (unsigned int i = 0; i < NUM_ITERS; i++){
// const unsigned int src_index = thread_row * param.weightKOffset + ki;
const unsigned int src_index = iter_src_idx;
unsigned int dst_index = iter_dst_idx;
dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_1) >> SWIZZLE_BITS_1);
dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_2) >> SWIZZLE_BITS_2);

cp_async_zfill((void *)(&dst_float4[dst_index]), (void const *)(&src[src_index]), krow_idx < param.k && curC < end_k);

iter_src_idx += ITER_SRC_STEPS;
krow_idx += ROW_STEP;
iter_dst_idx += ITER_DST_STEPS;
}
#else
GGML_UNUSED(src);
GGML_UNUSED(dst);
GGML_UNUSED(block_k);
GGML_UNUSED(curR);
GGML_UNUSED(curS);
GGML_UNUSED(start_k);
GGML_UNUSED(end_k);
GGML_UNUSED(thread_row);
GGML_UNUSED(thread_col);
GGML_UNUSED(param);
NO_DEVICE_CODE;
#endif
}
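// Index sketch for the B (filter) load above, assuming param.weightKOffset is the per-output-
// channel stride r*s*c (not spelled out in this hunk):
//   curC = start_k + block_k + thread_col*8        // position along the flattened (r,s,c) axis
//   ki   = (curR*param.s + curS)*param.c + curC    // same axis with the r/s offsets folded in
//   src  = thread_row * param.weightKOffset + ki   // row = output channel k, column = (r,s,c)
// i.e. B is read as a k x (r*s*c) matrix, one float4 (8 halves) per thread per iteration.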

// same as above but without the swizzle
@@ -716,15 +716,15 @@ int main(void)

// for(int i = 0; i < ggml_nelements(wino_res); i++) {
// for(int i = 0; i < 26*38; i++) {
// for(int i = 0; i < conv2d_data.size(); i++) {
// float diff = fabs(im2col_data[i] - conv2d_data[i]);
// // if(diff > 0.5) {
// printf("(%7.3f, %7.3f, %.2f, %d) \n",
// im2col_data[i], conv2d_data[i],
// diff, i);
// // break;
// // }
// }
for(int i = 0; i < conv2d_data.size(); i++) {
float diff = fabs(im2col_data[i] - conv2d_data[i]);
// if(diff > 0.5) {
printf("(%7.3f, %7.3f, %.2f, %d) \n",
im2col_data[i], conv2d_data[i],
diff, i);
// break;
// }
}

ggml_free(model.ctx);
ggml_backend_buffer_free(model.buffer);