WIP: move the r/s loop into the block-k loop, following CUTLASS

This commit is contained in:
bssrdf 2025-11-13 18:44:32 -05:00
parent 8bfb7ed2f2
commit 63c53fe1f1
2 changed files with 55 additions and 32 deletions

View File

@ -803,6 +803,7 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
const unsigned int start_k = (ksplit > 0) ? z * ks : 0;
const unsigned int end_k = min(start_k + ks, K);
const unsigned int num_block_tiles_k = (ks + (BK-1)) / BK;
const unsigned int num_block_tiles_krs = num_block_tiles_k * param.r * param.s;
constexpr unsigned int TILE_COLS_VECTORIZED = BK / 8;
constexpr unsigned int ROW_STEP = NUM_THREADS / TILE_COLS_VECTORIZED;
@ -867,26 +868,49 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
prepareIteratorA<BM, BK, A_K_STRID, ROW_STEP>(thread_idx, masks_a, element_offset_a, param);
// prefetch the first block tile of A,B into shared memory
const half* A_block_gmem = input;
const half* B_block_gmem = kernel + block_n * BN * weightKOffset;
int s = 0;
int r = 0;
while (r < param.r) {
// for (int r = 0; r < param.r; ++r) {
tileMemcpySwizzleA<BM, NUM_THREADS>(A_block_gmem, A_block_smem, r, s, masks_a, element_offset_a, thread_idx, start_k, end_k, inChannelOffset, param);
tileMemcpySwizzleB<BN, NUM_THREADS>(B_block_gmem, B_block_smem, r, s, start_k, end_k, weightKOffset, param);
tileMemcpySwizzleA<BM, NUM_THREADS>(A_block_gmem, A_block_smem, 0, 0, masks_a, element_offset_a, thread_idx, start_k, end_k, inChannelOffset, param);
tileMemcpySwizzleB<BN, NUM_THREADS>(B_block_gmem, B_block_smem, 0, 0, start_k, end_k, weightKOffset, param);
int offset_direction = 1;
for (unsigned int block_k = 1; block_k <= num_block_tiles_k; block_k++){
unsigned int block_k = 0;
unsigned int block_krs = 1;
// for (unsigned int block_k = 1; block_k <= num_block_tiles_k; block_k++){
int s = 0;
int r = 0;
while (block_k < num_block_tiles_k){
__syncthreads();
if (block_k != num_block_tiles_k){
// moves to the next tile
int next_idx = 0;
++s;
if (s == param.s) {
s = 0;
++r;
if (r < param.r) {
next_idx = 1;
} else {
r = 0;
next_idx = 2;
}
}
if (next_idx == 2) {
++block_k;
}
// if(block_k == num_block_tiles_k)
// break;
// if(thread_idx == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0){
// printf(" s = %d, r = %d, block_k = %d, next_idx = %d , %d %d \n", s, r, block_k, next_idx, block_krs, num_block_tiles_k);
// }
// if (block_k != num_block_tiles_k){
if (block_krs != num_block_tiles_krs){
tileMemcpyLoadA<BM, BK, NUM_THREADS, 4>(A_block_gmem, A_gmem_cache_reg, r, s, block_k * BK, start_k, end_k, inChannelOffset, param);
tileMemcpyLoadB<BN, BK, NUM_THREADS, 4>(B_block_gmem, B_gmem_cache_reg, r, s, block_k * BK, start_k, end_k, weightKOffset, param);
}
@ -932,7 +956,8 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
}
}
if (block_k != num_block_tiles_k)
// if (block_k != num_block_tiles_k)
if (block_krs != num_block_tiles_krs)
{
// switch smem buffers each iteration
A_block_smem = A_block_smem + BUFFER_SIZE * offset_direction;
@ -942,16 +967,14 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
tileMemcpySwizzleStore<BM, NUM_THREADS, 4>(A_gmem_cache_reg, A_block_smem);
tileMemcpySwizzleStore<BN, NUM_THREADS, 4>(B_gmem_cache_reg, B_block_smem);
}
} // iter block_k
s++;
if (s == param.s) {
s = 0;
r++;
}
A_block_smem = shmem;
B_block_smem = &shmem[BM * BK];
} // iter r
block_krs++;
}
// A_block_smem = shmem;
// B_block_smem = &shmem[BM * BK];
// } // iter block_k
// reuse smem
half *smemoutput = shmem;

View File

@ -301,9 +301,9 @@ static std::vector<std::tuple<int, int, int, int, int, int>> configs = {
// std::make_tuple(960,320,104,152,3,3),
// std::make_tuple(1280,1280,26,38,3,3),
// std::make_tuple(1920,640,32,32,3,3)
// std::make_tuple(1280,1280,16,16,3,3),
std::make_tuple(1280,1280,16,16,3,3),
// std::make_tuple(32,8,24,24,3,3),
std::make_tuple(640,640,64,64,3,3),
// std::make_tuple(640,640,64,64,3,3),
// std::make_tuple(320,640,32,32,3,3),
// std::make_tuple(4,320,96,128,3,3),
// std::make_tuple(320,4,96,128,3,3),
@ -673,7 +673,7 @@ int main(void)
// fprintf(stderr, "%s: compute buffer size: %.2f MB\n", __func__, mem_size/1024.0f/1024.0f);
int iterations = 20;
int iterations = 0;
double run_time0;
std::vector<float> im2col_data = compute_graph(model, allocr, build_graph_0, iterations, &run_time0);
@ -716,15 +716,15 @@ int main(void)
// for(int i = 0; i < ggml_nelements(wino_res); i++) {
// for(int i = 0; i < 26*38; i++) {
// for(int i = 0; i < conv2d_data.size(); i++) {
// float diff = fabs(im2col_data[i] - conv2d_data[i]);
// // if(diff > 0.5) {
// printf("(%7.3f, %7.3f, %.2f, %d) \n",
// im2col_data[i], conv2d_data[i],
// diff, i);
// // break;
// // }
// }
for(int i = 0; i < conv2d_data.size(); i++) {
float diff = fabs(im2col_data[i] - conv2d_data[i]);
// if(diff > 0.5) {
printf("(%7.3f, %7.3f, %.2f, %d) \n",
im2col_data[i], conv2d_data[i],
diff, i);
// break;
// }
}
ggml_free(model.ctx);
ggml_backend_buffer_free(model.buffer);