WIP: debugging

This commit is contained in:
bssrdf 2025-11-13 22:08:41 -05:00
parent 63c53fe1f1
commit 7d99222a61
2 changed files with 104 additions and 58 deletions

View File

@ -811,7 +811,7 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
constexpr unsigned int B_K_STRID = BN / ROW_STEP;
unsigned int masks_a[A_K_STRID][2];
unsigned int element_offset_a[A_K_STRID];
int64_t element_offset_a[A_K_STRID];
// calculate block/warp indices
const unsigned int block_m = blockIdx.y;
@ -867,6 +867,7 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
float4 B_gmem_cache_reg[4];
prepareIteratorA<BM, BK, A_K_STRID, ROW_STEP>(thread_idx, masks_a, element_offset_a, param);
// prefetch the first block tile of A,B into shared memory
@ -874,7 +875,8 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
const half* A_block_gmem = input;
const half* B_block_gmem = kernel + block_n * BN * weightKOffset;
tileMemcpySwizzleA<BM, NUM_THREADS>(A_block_gmem, A_block_smem, 0, 0, masks_a, element_offset_a, thread_idx, start_k, end_k, inChannelOffset, param);
tileMemcpySwizzleA<BM, NUM_THREADS>(A_block_gmem, A_block_smem, 0, 0, masks_a, element_offset_a,
thread_idx, start_k, end_k, inChannelOffset, param);
tileMemcpySwizzleB<BN, NUM_THREADS>(B_block_gmem, B_block_smem, 0, 0, start_k, end_k, weightKOffset, param);
int offset_direction = 1;
@ -899,6 +901,9 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
next_idx = 2;
}
}
add_byte_offset<A_K_STRID>(element_offset_a, param.inc_next[next_idx]);
if (next_idx == 2) {
++block_k;
}
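
Note on the new iterator advance: instead of recomputing global-memory addresses from (r, s) every step, the k-loop now picks one of three precomputed deltas and adds it to all cached offsets via add_byte_offset. A minimal host-side model of that advance (the s -> r -> next-K-block nesting is an assumption inferred from the inc_next indices; the filter size is illustrative):

#include <cstdio>

int main() {
    const int R = 3, S = 3;            // illustrative filter height/width
    int r = 0, s = 0, block_k = 0;
    for (int step = 0; step < 10; ++step) {
        int next_idx = 0;              // 0: advance S
        if (++s == S) {
            s = 0; next_idx = 1;       // 1: wrap S, advance R
            if (++r == R) {
                r = 0; next_idx = 2;   // 2: wrap R, advance to the next K-block
            }
        }
        // the kernel would call add_byte_offset<A_K_STRID>(element_offset_a,
        // param.inc_next[next_idx]) at this point
        if (next_idx == 2) ++block_k;
        printf("step %d: next_idx=%d r=%d s=%d block_k=%d\n", step, next_idx, r, s, block_k);
    }
    return 0;
}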
@ -911,7 +916,9 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
// if (block_k != num_block_tiles_k){
if (block_krs != num_block_tiles_krs){
tileMemcpyLoadA<BM, BK, NUM_THREADS, 4>(A_block_gmem, A_gmem_cache_reg, r, s, block_k * BK, start_k, end_k, inChannelOffset, param);
tileMemcpyLoadA<BM, BK, NUM_THREADS, 4>(A_block_gmem, A_gmem_cache_reg, r, s,
masks_a, element_offset_a, thread_idx, block_k * BK,
start_k, end_k, inChannelOffset, param);
tileMemcpyLoadB<BN, BK, NUM_THREADS, 4>(B_block_gmem, B_gmem_cache_reg, r, s, block_k * BK, start_k, end_k, weightKOffset, param);
}
half* A_warp_tile = A_block_smem + A_warp_tile_offset;
@ -1096,7 +1103,7 @@ template<const int BM, const int BN, const int BK,
static void launch_conv2d_implicit_split_kernel(ggml_backend_cuda_context & ctx, const half *X_H, const half *K_H, float *Y_D,
const unsigned int BlocksM, const unsigned int BlocksN,
const unsigned int shmem_bytes,
const param_t P, cudaStream_t st){
param_t P, cudaStream_t st){
int id = ggml_cuda_get_device();
@ -1109,6 +1116,15 @@ static void launch_conv2d_implicit_split_kernel(ggml_backend_cuda_context & ctx,
conv2d_implicit_kernel<half, BM, BN, BK,
WM, WN, WK, ksplit, NUM_THREADS><<<gridDim, blockDim, shmem_bytes, st>>>(X_H, K_H, Y_H.get(), P);
int64_t inc[3];
// next S: step one filter column (dilated)
inc[0] = int64_t(P.c) * P.d_w;
// next R: step one filter row, rewinding S
inc[1] = int64_t(P.w) * P.c * P.d_h - int64_t(P.s - 1) * P.c * P.d_w;
// next C: advance BK channels, rewinding R and S
inc[2] = BK - int64_t(P.r - 1) * P.w * P.c * P.d_h - int64_t(P.s - 1) * P.c * P.d_w;
memcpy(P.inc_next, inc, sizeof(int64_t)*3);
const unsigned int nrows = P.n * P.k * P.Oh * P.Ow;
const unsigned int blockx = (nrows + 511) / 512;
const dim3 block_nums(blockx, 1, 1);
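
The three deltas are the element-offset differences between consecutive (r, s) filter taps of the same output pixel in an NHWC input, so adding them incrementally reproduces the absolute offsets. A self-contained check of that identity (the NHWC layout and the s -> r -> K-block traversal order mirror the diff; the concrete sizes are made up for illustration):

#include <cassert>
#include <cstdint>
#include <cstdio>

int main() {
    const int64_t c = 64, w = 80, d_h = 2, d_w = 3, R = 3, S = 3, BK = 64;
    const int64_t inc_s = c * d_w;                         // next S
    const int64_t inc_r = w * c * d_h - (S - 1) * c * d_w; // next R (rewinds S)
    const int64_t inc_c = BK - (R - 1) * w * c * d_h - (S - 1) * c * d_w; // next K-block

    // absolute element offset of filter tap (r, s) at channel k, relative to
    // the output pixel's top-left input element
    auto abs_off = [&](int64_t r, int64_t s, int64_t k) {
        return r * d_h * w * c + s * d_w * c + k;
    };

    int64_t off = abs_off(0, 0, 0);
    for (int64_t k = 0; k < 2 * BK; k += BK) {
        for (int64_t r = 0; r < R; ++r) {
            for (int64_t s = 0; s < S; ++s) {
                assert(off == abs_off(r, s, k));
                off += (s + 1 < S) ? inc_s : (r + 1 < R) ? inc_r : inc_c;
            }
        }
    }
    puts("inc_next deltas consistent");
    return 0;
}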
@ -1116,7 +1132,7 @@ static void launch_conv2d_implicit_split_kernel(ggml_backend_cuda_context & ctx,
reduce_f32<half, float><<<block_nums, block_dims, 0, st>>>(Y_H.get(), Y_D, nrows, ksplit);
}
static void conv2d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const float * X_D, const half * K_D, float * Y_D, int cc, const param_t P, cudaStream_t st) {
static void conv2d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const float * X_D, const half * K_D, float * Y_D, int cc, param_t P, cudaStream_t st) {
// if (GGML_CUDA_CC_IS_NVIDIA(cc) && turing_mma_available(cc) && P.c % 8 == 0 && (P.r > 1 || P.s > 1)) {
if (GGML_CUDA_CC_IS_NVIDIA(cc) && turing_mma_available(cc) && P.c % 8 == 0) {
@ -1279,6 +1295,15 @@ static void conv2d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const floa
}
}
int64_t inc[3];
// next S: step one filter column (dilated)
inc[0] = int64_t(P.c) * P.d_w;
// next R: step one filter row, rewinding S
inc[1] = int64_t(P.w) * P.c * P.d_h - int64_t(P.s - 1) * P.c * P.d_w;
// next C: advance BK_dim channels, rewinding R and S
inc[2] = BK_dim - int64_t(P.r - 1) * P.w * P.c * P.d_h - int64_t(P.s - 1) * P.c * P.d_w;
memcpy(P.inc_next, inc, sizeof(int64_t)*3);
cudaFuncSetAttribute(conv2d_implicit_kernel<float, BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, 0, NumThreads>,
cudaFuncAttributeMaxDynamicSharedMemorySize, 65536); // raise the shared memory limit to 64KB, the maximum for sm_75
dim3 gridDim(BlocksN, BlocksM);
@ -1340,6 +1365,8 @@ void ggml_cuda_op_conv2d_implicit(ggml_backend_cuda_context & ctx, ggml_tensor *
const uint OC = kernel->ne[3]; // output channels
const uint B = input->ne[3]; // n_batches
param_t params = { B, IC, IH, IW, OC, KH, KW, ST_Y, ST_X, PD_Y, PD_X, DL_Y, DL_X, OH, OW,
init_fastdiv_values(KW*IC),
init_fastdiv_values(OW),

View File

@ -23,6 +23,7 @@ typedef struct{
uint3 RS_fastdiv;
uint3 S_fastdiv;
uint3 OHOW_fastdiv;
int64_t inc_next[3];
} param_t;
@ -38,13 +39,21 @@ __host__ __device__ void clear_mask(unsigned int masks_[][2], bool clear = true)
}
}
template<const unsigned int K_STRID>
__host__ __device__ void add_byte_offset(int64_t element_offset[], const int64_t offset){
#pragma unroll
for (unsigned int s = 0; s < K_STRID; ++s) {
element_offset[s] += offset;
}
}
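
add_byte_offset applies one shared delta to every cached per-thread offset, which is what keeps the hot loop free of per-element address recomputation. An illustrative host-side use (A_K_STRID and the values are made up; note that despite the name, the diff indexes src[element_offset[i]] directly, so these are element offsets, not byte offsets):

#include <cstdint>
#include <cstdio>

template<const unsigned int K_STRID>
void add_byte_offset(int64_t element_offset[], const int64_t offset) {
    for (unsigned int s = 0; s < K_STRID; ++s) {
        element_offset[s] += offset;   // same delta for every cached row
    }
}

int main() {
    constexpr unsigned int A_K_STRID = 4;       // illustrative iteration count
    int64_t element_offset_a[A_K_STRID] = {0, 1024, 2048, 3072};
    const int64_t inc_next_s = 192;             // e.g. c * d_w from the diff
    add_byte_offset<A_K_STRID>(element_offset_a, inc_next_s);
    for (auto e : element_offset_a) printf("%lld\n", (long long) e);
    return 0;
}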
template<const unsigned int TILE_ROWS,
const unsigned int TILE_COLS,
const unsigned int A_K_STRID,
const unsigned int ROW_STEP>
__device__ void prepareIteratorA(const int thread_idx,
unsigned int masks[][2],
unsigned int element_offset[],
int64_t element_offset[],
const param_t param){
int offset_n[A_K_STRID];
int offset_p[A_K_STRID];
@ -176,8 +185,8 @@ __device__ __forceinline__ void tileMemcpySwizzleA(
half* dst,
const unsigned int curR,
const unsigned int curS,
unsigned int masks[][2],
unsigned int element_offset[],
const unsigned int masks[][2],
const int64_t element_offset[],
const unsigned int thread_idx,
const unsigned int start_k,
const unsigned int end_k,
@ -208,52 +217,52 @@ __device__ __forceinline__ void tileMemcpySwizzleA(
unsigned int thread_row = thread_idx / TILE_COLS_VECTORIZED;
const unsigned int thread_col = thread_idx % TILE_COLS_VECTORIZED;
const unsigned int ki = start_k+thread_col*8;
// const unsigned int ki = start_k+thread_col*8;
const unsigned int chw = param.c * param.h * param.w;
// const unsigned int curR = fastdiv(ki, param.SC_fastdiv); // kernel r offset
// const unsigned int curS = fastdiv(fastmodulo(ki, param.SC_fastdiv), param.C_fastdiv); // kernel s offset
// const unsigned int curC = fastmodulo(fastmodulo(ki, param.SC_fastdiv), param.C_fastdiv); // channel offset
const unsigned int curC = ki;
// #pragma unroll
// for (unsigned int i = 0; i < NUM_ITERS; i++){
// bool valid = (masks[i][0] & (1u << curR)) && (masks[i][1] & (1u << curS));
// // apply swizzle to the dst index
// unsigned int dst_index = thread_row * TILE_COLS_VECTORIZED + thread_col;
// dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_1) >> SWIZZLE_BITS_1);
// dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_2) >> SWIZZLE_BITS_2);
// if (valid && ki < end_k){
// if(element_offset[i]+curC >= 327680 || element_offset[i]+curC < 0)
// printf("%d, %d, %d, %d, %d, %d, %d, %d, %d \n", threadIdx.x, threadIdx.y, blockIdx.x, blockIdx.y,
// i, element_offset[i], curR, curS, curC);
// dst_float4[dst_index] = reinterpret_cast<const float4 *>(&src[element_offset[i]+curC])[0];
// } else{
// dst_float4[dst_index] = make_float4(0.f, 0.f, 0.f, 0.f);
// }
// thread_row += ROW_STEP;
// }
const unsigned int curC = start_k+thread_col*8;
#pragma unroll
for (unsigned int i = 0; i < NUM_ITERS; i++){
unsigned int gemm_i = blockIdx.y * TILE_ROWS + thread_row;
unsigned int n = fastdiv(gemm_i, param.OHOW_fastdiv);
unsigned int npq_res = fastmodulo(gemm_i, param.OHOW_fastdiv);
int posh_ori = fastdiv(npq_res, param.OW_fastdiv) * param.u - param.p;
int posw_ori = fastmodulo(npq_res, param.OW_fastdiv) * param.v - param.q;
// unsigned int inOffset = n * param.c * param.h * param.w;
int curH = posh_ori + curR * param.d_h; // input h
int curW = posw_ori + curS * param.d_w; // input w
bool valid = (masks[i][0] & (1u << curR)) && (masks[i][1] & (1u << curS));
// apply swizzle to the dst index
unsigned int dst_index = thread_row * TILE_COLS_VECTORIZED + thread_col;
dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_1) >> SWIZZLE_BITS_1);
dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_2) >> SWIZZLE_BITS_2);
if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h &&
curR < param.r && curS < param.s && curC < param.c && n < param.n && ki < end_k){
const unsigned int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC;
dst_float4[dst_index] = reinterpret_cast<const float4 *>(&src[n * chw + inOffsetTmp])[0];
if (valid && curC < end_k){
// debug-only bounds check on the precomputed element offsets
if (element_offset[i] >= 327680 || element_offset[i] < 0)
printf("%d, %d, %d, %d, %d, %lld, %d, %d, %d \n", threadIdx.x, threadIdx.y, blockIdx.x, blockIdx.y,
i, (long long) element_offset[i], curR, curS, curC);
dst_float4[dst_index] = reinterpret_cast<const float4 *>(&src[element_offset[i]])[0];
} else{
dst_float4[dst_index] = make_float4(0.f, 0.f, 0.f, 0.f);
}
thread_row += ROW_STEP;
}
// #pragma unroll
// for (unsigned int i = 0; i < NUM_ITERS; i++){
// unsigned int gemm_i = blockIdx.y * TILE_ROWS + thread_row;
// unsigned int n = fastdiv(gemm_i, param.OHOW_fastdiv);
// unsigned int npq_res = fastmodulo(gemm_i, param.OHOW_fastdiv);
// int posh_ori = fastdiv(npq_res, param.OW_fastdiv) * param.u - param.p;
// int posw_ori = fastmodulo(npq_res, param.OW_fastdiv) * param.v - param.q;
// // unsigned int inOffset = n * param.c * param.h * param.w;
// int curH = posh_ori + curR * param.d_h; // input h
// int curW = posw_ori + curS * param.d_w; // input w
// // apply swizzle to the dst index
// unsigned int dst_index = thread_row * TILE_COLS_VECTORIZED + thread_col;
// dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_1) >> SWIZZLE_BITS_1);
// dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_2) >> SWIZZLE_BITS_2);
// if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h &&
// curR < param.r && curS < param.s && curC < param.c && n < param.n && ki < end_k){
// const unsigned int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC;
// dst_float4[dst_index] = reinterpret_cast<const float4 *>(&src[n * chw + inOffsetTmp])[0];
// } else{
// dst_float4[dst_index] = make_float4(0.f, 0.f, 0.f, 0.f);
// }
// thread_row += ROW_STEP;
// }
#else
GGML_UNUSED(src);
GGML_UNUSED(dst);
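
The masks replace the old per-element bounds test: prepareIteratorA presumably bakes, for each of a thread's rows, one bit per filter row r and one per filter column s saying whether that tap lands inside the input, so the guard collapses to two bit tests. A host-side sketch of the idea (the construction below is an assumption based on how the bits are consumed):

#include <cstdio>

int main() {
    const int R = 3, S = 3, H = 8, W = 8, d_h = 1, d_w = 1;
    const int posh_ori = -1, posw_ori = 6;  // example top-left input position (negative from padding)

    // masks[i][0] gets bit r set when filter row r is in-bounds for row i's
    // output pixel; masks[i][1] likewise for filter column s
    unsigned int mask_r = 0, mask_s = 0;
    for (int r = 0; r < R; ++r) {
        const int h = posh_ori + r * d_h;
        if (h >= 0 && h < H) mask_r |= 1u << r;
    }
    for (int s = 0; s < S; ++s) {
        const int w = posw_ori + s * d_w;
        if (w >= 0 && w < W) mask_s |= 1u << s;
    }

    // the kernel's guard for tap (curR, curS) then reduces to two bit tests
    for (int r = 0; r < R; ++r)
        for (int s = 0; s < S; ++s) {
            bool valid = (mask_r & (1u << r)) && (mask_s & (1u << s));
            printf("r=%d s=%d valid=%d\n", r, s, (int) valid);
        }
    return 0;
}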
@ -272,6 +281,9 @@ __device__ __forceinline__ void tileMemcpyLoadA(
float4 (&dst_reg)[ELEMENTS_PER_THREAD],
const unsigned int curR,
const unsigned int curS,
const unsigned int masks[][2],
const int64_t element_offset[],
const unsigned int thread_idx,
const unsigned int block_k,
const unsigned int start_k,
const unsigned int end_k,
@ -285,45 +297,52 @@ __device__ __forceinline__ void tileMemcpyLoadA(
static_assert(NUM_THREADS % TILE_COLS_VECTORIZED == 0);
// flatten the 2D grid of threads into a 1D index in order of increasing threadIdx.x
const unsigned int thread_idx = threadIdx.y * blockDim.x + threadIdx.x;
// assign each thread a row/column in the tile, calculate how many iterations we need
// to cover the whole tile
constexpr unsigned int ROW_STEP = NUM_THREADS / TILE_COLS_VECTORIZED;
constexpr unsigned int NUM_ITERS = TILE_ROWS / ROW_STEP;
unsigned int thread_row = thread_idx / TILE_COLS_VECTORIZED;
// unsigned int thread_row = thread_idx / TILE_COLS_VECTORIZED;
const unsigned int thread_col = thread_idx % TILE_COLS_VECTORIZED;
// compile time check that we provided the right amount of registers for storage
static_assert(ELEMENTS_PER_THREAD == NUM_ITERS);
const unsigned int ki = start_k+block_k+thread_col*8;
const unsigned int chw = param.c * param.h * param.w;
// const unsigned int ki = start_k+block_k+thread_col*8;
// const unsigned int chw = param.c * param.h * param.w;
// const unsigned int curR = fastdiv(ki, param.SC_fastdiv); // kernel r offset
// const unsigned int curS = fastdiv(fastmodulo(ki, param.SC_fastdiv), param.C_fastdiv); // kernel s offset
// const unsigned int curC = fastmodulo(fastmodulo(ki, param.SC_fastdiv), param.C_fastdiv); // channel offset
const unsigned int curC = ki;
const unsigned int curC = start_k+block_k+thread_col*8;
#pragma unroll
for (unsigned int i = 0; i < NUM_ITERS; i++){
unsigned int gemm_i = blockIdx.y * TILE_ROWS + thread_row;
unsigned int n = fastdiv(gemm_i, param.OHOW_fastdiv);
unsigned int npq_res = fastmodulo(gemm_i, param.OHOW_fastdiv);
int posh_ori = fastdiv(npq_res, param.OW_fastdiv) * param.u - param.p;
int posw_ori = fastmodulo(npq_res, param.OW_fastdiv) * param.v - param.q;
// unsigned int inOffset = n * param.c * param.h * param.w;
int curH = posh_ori + curR * param.d_h; // input h
int curW = posw_ori + curS * param.d_w; // input w
if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h &&
curR < param.r && curS < param.s && curC < param.c && n < param.n && ki < end_k){
const unsigned int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC;
dst_reg[i] = reinterpret_cast<const float4 *>(&src[n * chw + inOffsetTmp])[0];
bool valid = (masks[i][0] & (1u << curR)) && (masks[i][1] & (1u << curS));
if (valid && curC < end_k) {
dst_reg[i] = reinterpret_cast<const float4 *>(&src[element_offset[i]])[0];
} else{
dst_reg[i] = make_float4(0.f, 0.f, 0.f, 0.f);
}
thread_row += ROW_STEP;
}
// #pragma unroll
// for (unsigned int i = 0; i < NUM_ITERS; i++){
// unsigned int gemm_i = blockIdx.y * TILE_ROWS + thread_row;
// unsigned int n = fastdiv(gemm_i, param.OHOW_fastdiv);
// unsigned int npq_res = fastmodulo(gemm_i, param.OHOW_fastdiv);
// int posh_ori = fastdiv(npq_res, param.OW_fastdiv) * param.u - param.p;
// int posw_ori = fastmodulo(npq_res, param.OW_fastdiv) * param.v - param.q;
// // unsigned int inOffset = n * param.c * param.h * param.w;
// int curH = posh_ori + curR * param.d_h; // input h
// int curW = posw_ori + curS * param.d_w; // input w
// if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h &&
// curR < param.r && curS < param.s && curC < param.c && n < param.n && ki < end_k){
// const unsigned int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC;
// dst_reg[i] = reinterpret_cast<const float4 *>(&src[n * chw + inOffsetTmp])[0];
// } else{
// dst_reg[i] = make_float4(0.f, 0.f, 0.f, 0.f);
// }
// thread_row += ROW_STEP;
// }
#else
GGML_UNUSED(src);
GGML_UNUSED(dst_reg);
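
For reference, the now commented-out addressing path decomposed the GEMM row index into (batch, output y, output x) and rebuilt input coordinates per element, which the new iterator precomputes once. A plain-integer version of that decomposition (fastdiv/fastmodulo shown as ordinary / and %, since they compute the same thing; names follow the diff, sizes are illustrative):

#include <cstdio>

int main() {
    // illustrative conv params: u/v = stride, p/q = padding, d_h/d_w = dilation
    const int Oh = 4, Ow = 5, u = 1, v = 1, p = 1, q = 1, d_h = 1, d_w = 1;
    const int gemm_i = 27, curR = 2, curS = 0;

    const int n        = gemm_i / (Oh * Ow);      // fastdiv(gemm_i, OHOW_fastdiv)
    const int npq_res  = gemm_i % (Oh * Ow);      // fastmodulo(gemm_i, OHOW_fastdiv)
    const int posh_ori = (npq_res / Ow) * u - p;  // top-left input row of this output pixel
    const int posw_ori = (npq_res % Ow) * v - q;  // top-left input col of this output pixel
    const int curH     = posh_ori + curR * d_h;   // input h of this filter tap
    const int curW     = posw_ori + curS * d_w;   // input w of this filter tap

    printf("n=%d oh=%d ow=%d -> curH=%d curW=%d\n",
           n, npq_res / Ow, npq_res % Ow, curH, curW);
    return 0;
}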