WIP: fixed bugs now results are correct

This commit is contained in:
bssrdf 2025-11-14 11:10:34 -05:00
parent 7d99222a61
commit b015e4b7dc
3 changed files with 102 additions and 49 deletions

View File

@ -870,13 +870,28 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
prepareIteratorA<BM, BK, A_K_STRID, ROW_STEP>(thread_idx, masks_a, element_offset_a, param);
// for(int kk =0; kk < A_K_STRID; kk++){
// if(element_offset_a[kk] >= 327680)
// printf("%d, %d, %d, %d, %d, %lld \n",
// threadIdx.x, threadIdx.y, blockIdx.x, blockIdx.y, blockIdx.z,
// element_offset_a[kk]);
// }
// if(threadIdx.x == 64 && threadIdx.y == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0){
// printf("A[");
// for(int kk =0; kk < A_K_STRID; kk++)
// printf("%f,", element_offset_a[kk]);
// printf("]\n");
// }
// prefetch the first block tile of A,B into shared memory
const half* A_block_gmem = input;
const half* B_block_gmem = kernel + block_n * BN * weightKOffset;
tileMemcpySwizzleA<BM, NUM_THREADS>(A_block_gmem, A_block_smem, 0, 0, masks_a, element_offset_a,
thread_idx, start_k, end_k, inChannelOffset, param);
unsigned int curC = tileMemcpySwizzleA<BM, NUM_THREADS>(A_block_gmem, A_block_smem, 0, 0, masks_a, element_offset_a,
thread_idx, start_k, end_k, inChannelOffset, param);
tileMemcpySwizzleB<BN, NUM_THREADS>(B_block_gmem, B_block_smem, 0, 0, start_k, end_k, weightKOffset, param);
int offset_direction = 1;
@ -907,6 +922,18 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
if (next_idx == 2) {
++block_k;
}
// if(threadIdx.x == 64 && threadIdx.y == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0){
// printf("B %d,%d,%d [", s, r, block_k);
// for(int kk =0; kk < A_K_STRID; kk++){
// if(element_offset_a[kk] >= 327680)
// printf("%d, %d, %d, %d, %d, %lld, %d, %d, %d %d, %lld\n",
// threadIdx.x, threadIdx.y, blockIdx.x, blockIdx.y, blockIdx.z,
// element_offset_a[kk], r, s, block_k, next_idx, param.inc_next[next_idx]);
// }
// threadIdx.x == 64 && threadIdx.y == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0){
// printf("%f,", element_offset_a[kk]);
// printf("]\n");
// if(block_k == num_block_tiles_k)
// break;
@ -916,11 +943,12 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
// if (block_k != num_block_tiles_k){
if (block_krs != num_block_tiles_krs){
tileMemcpyLoadA<BM, BK, NUM_THREADS, 4>(A_block_gmem, A_gmem_cache_reg, r, s,
curC = tileMemcpyLoadA<BM, BK, NUM_THREADS, 4>(A_block_gmem, A_gmem_cache_reg, r, s,
masks_a, element_offset_a, thread_idx, block_k * BK,
start_k, end_k, inChannelOffset, param);
start_k, end_k, curC, inChannelOffset, param);
tileMemcpyLoadB<BN, BK, NUM_THREADS, 4>(B_block_gmem, B_gmem_cache_reg, r, s, block_k * BK, start_k, end_k, weightKOffset, param);
}
half* A_warp_tile = A_block_smem + A_warp_tile_offset;
half* B_warp_tile = B_block_smem + B_warp_tile_offset;
@ -983,6 +1011,10 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
// } // iter block_k
// if(threadIdx.x == 0 && threadIdx.y == 0 && blockIdx.x == 0 && blockIdx.y == 0){
// printf(" %u, %f\n", blockIdx.z, __half2float(acc_register_[0][0][0]));
// }
// reuse smem
half *smemoutput = shmem;
const uint lane_id = threadIdx.x % WARPSIZE;
@ -1116,15 +1148,6 @@ static void launch_conv2d_implicit_split_kernel(ggml_backend_cuda_context & ctx,
conv2d_implicit_kernel<half, BM, BN, BK,
WM, WN, WK, ksplit, NUM_THREADS><<<gridDim, blockDim, shmem_bytes, st>>>(X_H, K_H, Y_H.get(), P);
int64_t inc[3];
// next S
inc[0] = int64_t(P.c) * P.d_w;
// next R
inc[1] = int64_t(P.w * P.c) * P.d_h - (P.s - 1) * P.c * P.d_w;
// next C
inc[2] = BK - int64_t(P.r - 1) * P.w * P.c * P.d_h - int64_t(P.s - 1) * P.c * P.d_w ;
memcpy(P.inc_next, inc, sizeof(int64_t)*3);
const unsigned int nrows = P.n * P.k * P.Oh * P.Ow;
const unsigned int blockx = (nrows + 511) / 512;
const dim3 block_nums(blockx, 1, 1);
@ -1139,6 +1162,15 @@ static void conv2d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const floa
int id = ggml_cuda_get_device();
int64_t inc[3];
// next S
inc[0] = int64_t(P.c) * P.d_w;
// next R
inc[1] = int64_t(P.w * P.c) * P.d_h - (P.s - 1) * P.c * P.d_w;
// next C
inc[2] = - int64_t(P.r - 1) * P.w * P.c * P.d_h - int64_t(P.s - 1) * P.c * P.d_w ;
memcpy(P.inc_next, inc, sizeof(int64_t)*3);
int64_t ne = P.c * P.h * P.w * P.n;
int64_t ne00 = P.c;
int64_t ne01 = P.h * P.w;
@ -1295,15 +1327,6 @@ static void conv2d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const floa
}
}
int64_t inc[3];
// next S
inc[0] = int64_t(P.c) * P.d_w;
// next R
inc[1] = int64_t(P.w * P.c) * P.d_h - (P.s - 1) * P.c * P.d_w;
// next C
inc[2] = BK_dim - int64_t(P.r - 1) * P.w * P.c * P.d_h - int64_t(P.s - 1) * P.c * P.d_w ;
memcpy(P.inc_next, inc, sizeof(int64_t)*3);
cudaFuncSetAttribute(conv2d_implicit_kernel<float, BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, 0, NumThreads>,
cudaFuncAttributeMaxDynamicSharedMemorySize, 65536); // set shared memory limit to 64KB which is maximum for sm_75
dim3 gridDim(BlocksN, BlocksM);

View File

@ -74,8 +74,8 @@ __device__ void prepareIteratorA(const int thread_idx,
unsigned int npq_res = fastmodulo(gemm_i, param.OHOW_fastdiv);
offset_p[s] = fastdiv(npq_res, param.OW_fastdiv); //* param.u - param.p;
offset_q[s] = fastmodulo(npq_res, param.OW_fastdiv); // * param.v - param.q;
const int h = offset_p[s] * param.u - param.p;
const int w = offset_q[s] * param.v - param.q;
const int h = offset_p[s] * (int)param.u - (int) param.p;
const int w = offset_q[s] * (int)param.v - (int) param.q;
// if(threadIdx.x < 32 && threadIdx.y == 0 && blockIdx.x == 0 && blockIdx.y == 0)
// printf("%d, %d : %d, %d, %d, %d offset (%d, %d, %d), kele %llu Kcont %d\n ", thread_idx, s,
@ -84,7 +84,12 @@ __device__ void prepareIteratorA(const int thread_idx,
// offset_npq, offset_n[s], offset_p[s], offset_q[s], AccessType::kElements,
// ThreadMap::Iterations::kContiguous);
element_offset[s] = offset_n[s] * chw + h * param.c * param.w + w * param.c;
element_offset[s] = offset_n[s] * (int64_t)chw + h * (int64_t)(param.c * param.w) + w * (int64_t)param.c;
// if(element_offset[s] >= 327680)
// printf("(%d, %d, %d, %d, %d), %d, %lld, %d, %d, %d, %d, %d, %u, %u, %u \n",
// threadIdx.x, threadIdx.y, blockIdx.x, blockIdx.y, blockIdx.z,
// s, element_offset[s], offset_n[s], offset_p[s], offset_q[s], h, w, chw, param.c * param.w, param.c);
thread_row += ROW_STEP;
}
@ -180,12 +185,12 @@ __device__ __forceinline__ void tileMemcpySwizzleB(
// this is a special case of the above for when TILE_COLS == 32
template<unsigned int TILE_ROWS,
unsigned int NUM_THREADS>
__device__ __forceinline__ void tileMemcpySwizzleA(
__device__ __forceinline__ unsigned int tileMemcpySwizzleA(
const half* src,
half* dst,
const unsigned int curR,
const unsigned int curS,
const unsigned int masks[][2],
unsigned int masks[][2],
const int64_t element_offset[],
const unsigned int thread_idx,
const unsigned int start_k,
@ -218,23 +223,29 @@ __device__ __forceinline__ void tileMemcpySwizzleA(
const unsigned int thread_col = thread_idx % TILE_COLS_VECTORIZED;
// const unsigned int ki = start_k+thread_col*8;
const unsigned int chw = param.c * param.h * param.w;
// const unsigned int chw = param.c * param.h * param.w;
// const unsigned int curR = fastdiv(ki, param.SC_fastdiv); // channel offset
// const unsigned int curS = fastdiv(fastmodulo(ki, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
// const unsigned int curC = fastmodulo(fastmodulo(ki, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
const unsigned int curC = start_k+thread_col*8;
clear_mask<NUM_ITERS>(masks, curC >= end_k);
#pragma unroll
for (unsigned int i = 0; i < NUM_ITERS; i++){
bool valid = (masks[i][0] & (1u << curR)) && (masks[i][1] & (1u << curS));
bool valid = (masks[i][0] & (1u << curR)) && (masks[i][1] & (1u << curS));
// apply swizzle to the dst index
unsigned int dst_index = thread_row * TILE_COLS_VECTORIZED + thread_col;
dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_1) >> SWIZZLE_BITS_1);
dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_2) >> SWIZZLE_BITS_2);
if (valid && curC < end_k){
if(element_offset[i] >= 327680 || element_offset[i] < 0)
printf("%d, %d, %d, %d, %d, %d, %d, %d, %d \n", threadIdx.x, threadIdx.y, blockIdx.x, blockIdx.y,
i, element_offset[i], curR, curS, curC);
dst_float4[dst_index] = reinterpret_cast<const float4 *>(&src[element_offset[i]])[0];
// if(threadIdx.x == 3 && threadIdx.y == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 1){
// printf(" %u, %u, %u, %u, %lld, %d\n", i, curR, curS, curC, element_offset[i], valid?1:0);
// }
// if (valid && curC < end_k){
if (valid){
// if(element_offset[i] >= 327680 || element_offset[i] < 0)
// printf("%d, %d, %d, %d, %d, %d, %d, %d, %d \n", threadIdx.x, threadIdx.y, blockIdx.x, blockIdx.y,
// i, element_offset[i], curR, curS, curC);
dst_float4[dst_index] = reinterpret_cast<const float4 *>(&src[element_offset[i]+curC])[0];
} else{
dst_float4[dst_index] = make_float4(0.f, 0.f, 0.f, 0.f);
}
@ -263,6 +274,7 @@ __device__ __forceinline__ void tileMemcpySwizzleA(
// }
// thread_row += ROW_STEP;
// }
return curC;
#else
GGML_UNUSED(src);
GGML_UNUSED(dst);
@ -276,17 +288,18 @@ template<unsigned int TILE_ROWS,
unsigned int TILE_COLS,
unsigned int NUM_THREADS,
unsigned int ELEMENTS_PER_THREAD>
__device__ __forceinline__ void tileMemcpyLoadA(
__device__ __forceinline__ unsigned int tileMemcpyLoadA(
const half* src,
float4 (&dst_reg)[ELEMENTS_PER_THREAD],
const unsigned int curR,
const unsigned int curS,
const unsigned int masks[][2],
unsigned int masks[][2],
const int64_t element_offset[],
const unsigned int thread_idx,
const unsigned int block_k,
const unsigned int start_k,
const unsigned int end_k,
unsigned int oldC,
const unsigned int inChannelOffset,
param_t param
){
@ -301,7 +314,7 @@ __device__ __forceinline__ void tileMemcpyLoadA(
// to cover the whole tile
constexpr unsigned int ROW_STEP = NUM_THREADS / TILE_COLS_VECTORIZED;
constexpr unsigned int NUM_ITERS = TILE_ROWS / ROW_STEP;
// unsigned int thread_row = thread_idx / TILE_COLS_VECTORIZED;
unsigned int thread_row = thread_idx / TILE_COLS_VECTORIZED;
const unsigned int thread_col = thread_idx % TILE_COLS_VECTORIZED;
// compile time check that we provided the right amount of registers for storage
@ -313,13 +326,18 @@ __device__ __forceinline__ void tileMemcpyLoadA(
// const unsigned int curR = fastdiv(ki, param.SC_fastdiv); // channel offset
// const unsigned int curS = fastdiv(fastmodulo(ki, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
// const unsigned int curC = fastmodulo(fastmodulo(ki, param.SC_fastdiv), param.C_fastdiv); // kernel r offset
const unsigned int curC = start_k+block_k+thread_col*8;;
const unsigned int curC = start_k+block_k+thread_col*8;
if (curC > oldC)
clear_mask<NUM_ITERS>(masks, curC >= end_k);
#pragma unroll
for (unsigned int i = 0; i < NUM_ITERS; i++){
bool valid = (masks[i][0] & (1u << curR)) && (masks[i][1] & (1u << curS));
if (valid && curC < end_k) {
dst_reg[i] = reinterpret_cast<const float4 *>(&src[element_offset[i]])[0];
// if(threadIdx.x == 3 && threadIdx.y == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 1){
// printf(" %u, %u, %u, %u, %u, %lld, %d\n", i, curR, curS, oldC, curC, element_offset[i], valid?1:0);
// }
if (valid) {
dst_reg[i] = reinterpret_cast<const float4 *>(&src[element_offset[i]+curC])[0];
} else{
dst_reg[i] = make_float4(0.f, 0.f, 0.f, 0.f);
}
@ -334,6 +352,17 @@ __device__ __forceinline__ void tileMemcpyLoadA(
// // unsigned int inOffset = n * param.c * param.h * param.w;
// int curH = posh_ori + curR * param.d_h; // input h
// int curW = posw_ori + curS * param.d_w; // input w
// bool valid = (masks[i][0] & (1u << curR)) && (masks[i][1] & (1u << curS));
// bool ovl = curH >= 0 && curW >= 0 && curW < param.w && curH < param.h &&
// curR < param.r && curS < param.s && curC < param.c && n < param.n && ki < end_k;
// const int txx = curH * (int) inChannelOffset + curW * (int)param.c + (int)curC;
// if(threadIdx.x == 0 && threadIdx.y == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 1){
// printf(" %u, %u, %u, %u, %u, %lld, %lld, %d, %d, %d\n", i, curR, curS, oldC, curC,
// element_offset[i], element_offset[i]+(int64_t)curC, n * (int)chw + txx,
// valid?1:0, ovl?1:0);
// }
// if (curH >= 0 && curW >= 0 && curW < param.w && curH < param.h &&
// curR < param.r && curS < param.s && curC < param.c && n < param.n && ki < end_k){
// const unsigned int inOffsetTmp = curH * inChannelOffset + curW * param.c + curC;
@ -343,6 +372,7 @@ __device__ __forceinline__ void tileMemcpyLoadA(
// }
// thread_row += ROW_STEP;
// }
return curC;
#else
GGML_UNUSED(src);
GGML_UNUSED(dst_reg);

View File

@ -716,15 +716,15 @@ int main(void)
// for(int i = 0; i < ggml_nelements(wino_res); i++) {
// for(int i = 0; i < 26*38; i++) {
for(int i = 0; i < conv2d_data.size(); i++) {
float diff = fabs(im2col_data[i] - conv2d_data[i]);
// if(diff > 0.5) {
printf("(%7.3f, %7.3f, %.2f, %d) \n",
im2col_data[i], conv2d_data[i],
diff, i);
// break;
// }
}
// for(int i = 0; i < conv2d_data.size(); i++) {
// float diff = fabs(im2col_data[i] - conv2d_data[i]);
// // if(diff > 0.5) {
// printf("(%7.3f, %7.3f, %.2f, %d) \n",
// im2col_data[i], conv2d_data[i],
// diff, i);
// // break;
// // }
// }
ggml_free(model.ctx);
ggml_backend_buffer_free(model.buffer);