trying to get rid of remaining bank conflicts; also fixed a bug for split-k condition check

This commit is contained in:
bssrdf 2025-11-07 15:38:36 -05:00
parent 4e9ebe92e0
commit df88b2c917
3 changed files with 150 additions and 124 deletions

View File

@ -672,12 +672,11 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
#pragma unroll
for (unsigned int mma_m = 0; mma_m < mma_tiles_per_warp_m; mma_m++)
{
const int output_sts_offset = output_sts_addr + mma_m * MMA_M * BN / 2 - i * mma_tiles_per_warp_n/2 * MMA_N;
for (unsigned int mma_n = i * mma_tiles_per_warp_n/2; mma_n < (i+1)*mma_tiles_per_warp_n/2; mma_n++)
{
uint32_t (&reg_)[2] = reinterpret_cast<uint32_t(&)[2]>(acc_register_[mma_m][mma_n]);
uint idx = output_sts_offset + mma_n * MMA_N;
// mma_m * MMA_M * BN / 2 + (mma_n - i * mma_tiles_per_warp_n/2) * MMA_N;
uint idx = output_sts_addr +
mma_m * MMA_M * BN / 2 + (mma_n - i * mma_tiles_per_warp_n/2) * MMA_N;
idx = idx ^ ((idx & 0b1110000000) >> 4);
uint32_t* dst_ptr = reinterpret_cast<uint32_t*>(&smemoutput[idx]);
dst_ptr[0] = reg_[0];
@ -688,24 +687,40 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
__syncthreads();
const unsigned int m_i_wn = m_idx + i * WN / 2;
#pragma unroll
for (int subk = 0; subk < WN / 2; ++subk){
const uint row = m_i_wn + subk;
for (int subk = 0; subk < WN / 4; ++subk){
const uint row = m_i_wn + subk*2;
#pragma unroll
for (int j = 0; j < 4; ++j){
const uint gemm_i = n_idx + j*32;
const int n = fastdiv(gemm_i, param.OHOW_fastdiv);
const int col = fastmodulo(gemm_i, param.OHOW_fastdiv);
uint idx = output_lds_addr + subk*2 + j*32*BN/2;
idx = idx ^ ((idx & 0b1110000000) >> 4);
uint32_t* dst_ptr = reinterpret_cast<uint32_t*>(&smemoutput[idx]);
if (n < param.n && row < param.k && col < PQ) {
uint idx = output_lds_addr + subk + j*32*BN/2;
idx = idx ^ ((idx & 0b1110000000) >> 4);
if constexpr (ksplit > 0) {
const uint outOffset = z * NKPQ +
n * KPQ +
row * PQ + col;
output[outOffset] = smemoutput[idx];
// output[outOffset] = smemoutput[idx];
output[outOffset] = reinterpret_cast<half *>(dst_ptr)[0];
} else {
const uint outOffset = n * KPQ + row * PQ + col;
output[outOffset] = smemoutput[idx];
// output[outOffset] = smemoutput[idx];
output[outOffset] = reinterpret_cast<half *>(dst_ptr)[0];
}
}
if (n < param.n && row+1 < param.k && col < PQ) {
if constexpr (ksplit > 0) {
const uint outOffset = z * NKPQ +
n * KPQ +
(row+1) * PQ + col;
// output[outOffset] = smemoutput[idx];
output[outOffset] = reinterpret_cast<half *>(dst_ptr)[1];
} else {
const uint outOffset = n * KPQ + (row+1) * PQ + col;
// output[outOffset] = smemoutput[idx];
output[outOffset] = reinterpret_cast<half *>(dst_ptr)[1];
}
}
}
@ -803,6 +818,9 @@ static void conv2d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const floa
constexpr unsigned int WM_dim = BM_dim / WARPS_PER_BLOCK_M;
constexpr unsigned int WN_dim = BN_dim / WARPS_PER_BLOCK_N;
constexpr unsigned int WK_dim = BK_dim / WARPS_PER_BLOCK_K;
static_assert(WN_dim % 4 == 0, "final output requires this to be bank conflicts free");
const unsigned int BlocksM = (P.n * P.Oh * P.Ow + BM_dim - 1) / BM_dim;
const unsigned int BlocksN = (P.k + BN_dim - 1) / BN_dim;
constexpr unsigned int ThreadsM = WARPS_PER_BLOCK_M;
@ -812,7 +830,7 @@ static void conv2d_implicit_cuda_f16(ggml_backend_cuda_context & ctx, const floa
const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
const unsigned int ksplit = 8;
if (BlocksM * BlocksN < nsm && P.c > 8 * ksplit) {
if (BlocksM * BlocksN < nsm && P.c >= 8 * ksplit && (P.c * P.r * P.s) % (8*ksplit) == 0) {
ggml_cuda_pool_alloc<half> Y_H(ctx.pool(id), ksplit * P.k * P.Oh * P.Ow * P.n);
cudaFuncSetAttribute(conv2d_implicit_kernel<BM_dim, BN_dim, BK_dim, WM_dim, WN_dim, WK_dim, ksplit, NumThreads>,

View File

@ -5848,6 +5848,14 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
}
}
test_cases.emplace_back(new test_conv_2d( { 24, 24, 32, 1 }, { 3, 3, 32, 8},
GGML_TYPE_F16, 1, 1, 1, 1, 1, 1, false));
test_cases.emplace_back(new test_conv_2d( { 24, 24, 96, 1 }, { 3, 3, 96, 8},
GGML_TYPE_F16, 1, 1, 1, 1, 1, 1, false));
test_cases.emplace_back(new test_conv_2d( { 24, 24, 128, 1 }, { 3, 3, 128, 8},
GGML_TYPE_F16, 1, 1, 1, 1, 1, 1, false));
// sycl backend will limit task global_range < MAX_INT
// test cases for 2D im2col with large input W and H (occurs in stable-diffusion)
// however these cases need to alloc more memory which may fail in some devices (Intel Arc770, etc.)

View File

@ -305,7 +305,7 @@ int main(void)
// std::make_tuple(640,640,52,76,3,3),
// std::make_tuple(640,640,104,152,3,3),
// std::make_tuple(960,320,104,152,3,3),
// std::make_tuple(1280,1280,26,38,3,3),
std::make_tuple(1280,1280,26,38,3,3),
// std::make_tuple(4,320,96,128,3,3),
// std::make_tuple(320,4,96,128,3,3),
// std::make_tuple(4,320,64,96,3,3),
@ -538,108 +538,108 @@ int main(void)
//1024x1024
std::make_tuple(4,320,128,128,3,3),
std::make_tuple(320,320,128,128,3,3),
std::make_tuple(320,320,128,128,3,3),
std::make_tuple(320,320,128,128,3,3),
std::make_tuple(320,320,128,128,3,3),
std::make_tuple(320,320,128,128,3,3),
std::make_tuple(320,640,64,64,3,3),
std::make_tuple(640,640,64,64,3,3),
std::make_tuple(320,640,64,64,3,3),
std::make_tuple(640,640,64,64,3,3),
std::make_tuple(640,640,64,64,3,3),
std::make_tuple(640,640,64,64,3,3),
std::make_tuple(640,1280,32,32,3,3),
std::make_tuple(1280,1280,32,32,3,3),
std::make_tuple(640,1280,32,32,3,3),
std::make_tuple(1280,1280,32,32,3,3),
std::make_tuple(1280,1280,32,32,3,3),
std::make_tuple(1280,1280,32,32,3,3),
std::make_tuple(1280,1280,32,32,3,3),
std::make_tuple(1280,1280,32,32,3,3),
std::make_tuple(1280,1280,32,32,3,3),
std::make_tuple(2560,1280,32,32,3,3),
std::make_tuple(1280,1280,32,32,3,3),
std::make_tuple(2560,1280,32,32,3,3),
std::make_tuple(2560,1280,32,32,3,3),
std::make_tuple(1280,1280,32,32,3,3),
std::make_tuple(2560,1280,32,32,3,3),
std::make_tuple(1920,1280,32,32,3,3),
std::make_tuple(1280,1280,32,32,3,3),
std::make_tuple(1920,1280,32,32,3,3),
std::make_tuple(1280,1280,64,64,3,3),
std::make_tuple(1920,640,64,64,3,3),
std::make_tuple(640,640,64,64,3,3),
std::make_tuple(1920,640,64,64,3,3),
std::make_tuple(1280,640,64,64,3,3),
std::make_tuple(640,640,64,64,3,3),
std::make_tuple(1280,640,64,64,3,3),
std::make_tuple(960,640,64,64,3,3),
std::make_tuple(640,640,64,64,3,3),
std::make_tuple(960,640,64,64,3,3),
std::make_tuple(640,640,128,128,3,3),
std::make_tuple(960,320,128,128,3,3),
std::make_tuple(320,320,128,128,3,3),
std::make_tuple(960,320,128,128,3,3),
std::make_tuple(640,320,128,128,3,3),
std::make_tuple(320,320,128,128,3,3),
std::make_tuple(640,320,128,128,3,3),
std::make_tuple(640,320,128,128,3,3),
std::make_tuple(320,320,128,128,3,3),
std::make_tuple(640,320,128,128,3,3),
std::make_tuple(320,4,128,128,3,3),
std::make_tuple(4,320,128,128,3,3),
std::make_tuple(320,320,128,128,3,3),
std::make_tuple(320,320,128,128,3,3),
std::make_tuple(320,320,128,128,3,3),
std::make_tuple(320,320,128,128,3,3),
std::make_tuple(320,320,128,128,3,3),
std::make_tuple(320,640,64,64,3,3),
std::make_tuple(640,640,64,64,3,3),
std::make_tuple(320,640,64,64,3,3),
std::make_tuple(640,640,64,64,3,3),
std::make_tuple(640,640,64,64,3,3),
std::make_tuple(640,640,64,64,3,3),
std::make_tuple(640,1280,32,32,3,3),
std::make_tuple(1280,1280,32,32,3,3),
std::make_tuple(640,1280,32,32,3,3),
std::make_tuple(1280,1280,32,32,3,3),
std::make_tuple(1280,1280,32,32,3,3),
std::make_tuple(1280,1280,32,32,3,3),
std::make_tuple(1280,1280,32,32,3,3),
std::make_tuple(1280,1280,32,32,3,3),
std::make_tuple(1280,1280,32,32,3,3),
std::make_tuple(2560,1280,32,32,3,3),
std::make_tuple(1280,1280,32,32,3,3),
std::make_tuple(2560,1280,32,32,3,3),
std::make_tuple(2560,1280,32,32,3,3),
std::make_tuple(1280,1280,32,32,3,3),
std::make_tuple(2560,1280,32,32,3,3),
std::make_tuple(1920,1280,32,32,3,3),
std::make_tuple(1280,1280,32,32,3,3),
std::make_tuple(1920,1280,32,32,3,3),
std::make_tuple(1280,1280,64,64,3,3),
std::make_tuple(1920,640,64,64,3,3),
std::make_tuple(640,640,64,64,3,3),
std::make_tuple(1920,640,64,64,3,3),
std::make_tuple(1280,640,64,64,3,3),
std::make_tuple(640,640,64,64,3,3),
std::make_tuple(1280,640,64,64,3,3),
std::make_tuple(960,640,64,64,3,3),
std::make_tuple(640,640,64,64,3,3),
std::make_tuple(960,640,64,64,3,3),
std::make_tuple(640,640,128,128,3,3),
std::make_tuple(960,320,128,128,3,3),
std::make_tuple(320,320,128,128,3,3),
std::make_tuple(960,320,128,128,3,3),
std::make_tuple(640,320,128,128,3,3),
std::make_tuple(320,320,128,128,3,3),
std::make_tuple(640,320,128,128,3,3),
std::make_tuple(640,320,128,128,3,3),
std::make_tuple(320,320,128,128,3,3),
std::make_tuple(640,320,128,128,3,3),
std::make_tuple(320,4,128,128,3,3),
// std::make_tuple(4,320,128,128,3,3),
// std::make_tuple(320,320,128,128,3,3),
// std::make_tuple(320,320,128,128,3,3),
// std::make_tuple(320,320,128,128,3,3),
// std::make_tuple(320,320,128,128,3,3),
// std::make_tuple(320,320,128,128,3,3),
// std::make_tuple(320,640,64,64,3,3),
// std::make_tuple(640,640,64,64,3,3),
// std::make_tuple(320,640,64,64,3,3),
// std::make_tuple(640,640,64,64,3,3),
// std::make_tuple(640,640,64,64,3,3),
// std::make_tuple(640,640,64,64,3,3),
// std::make_tuple(640,1280,32,32,3,3),
// std::make_tuple(1280,1280,32,32,3,3),
// std::make_tuple(640,1280,32,32,3,3),
// std::make_tuple(1280,1280,32,32,3,3),
// std::make_tuple(1280,1280,32,32,3,3),
// std::make_tuple(1280,1280,32,32,3,3),
// std::make_tuple(1280,1280,32,32,3,3),
// std::make_tuple(1280,1280,32,32,3,3),
// std::make_tuple(1280,1280,32,32,3,3),
// std::make_tuple(2560,1280,32,32,3,3),
// std::make_tuple(1280,1280,32,32,3,3),
// std::make_tuple(2560,1280,32,32,3,3),
// std::make_tuple(2560,1280,32,32,3,3),
// std::make_tuple(1280,1280,32,32,3,3),
// std::make_tuple(2560,1280,32,32,3,3),
// std::make_tuple(1920,1280,32,32,3,3),
// std::make_tuple(1280,1280,32,32,3,3),
// std::make_tuple(1920,1280,32,32,3,3),
// std::make_tuple(1280,1280,64,64,3,3),
// std::make_tuple(1920,640,64,64,3,3),
// std::make_tuple(640,640,64,64,3,3),
// std::make_tuple(1920,640,64,64,3,3),
// std::make_tuple(1280,640,64,64,3,3),
// std::make_tuple(640,640,64,64,3,3),
// std::make_tuple(1280,640,64,64,3,3),
// std::make_tuple(960,640,64,64,3,3),
// std::make_tuple(640,640,64,64,3,3),
// std::make_tuple(960,640,64,64,3,3),
// std::make_tuple(640,640,128,128,3,3),
// std::make_tuple(960,320,128,128,3,3),
// std::make_tuple(320,320,128,128,3,3),
// std::make_tuple(960,320,128,128,3,3),
// std::make_tuple(640,320,128,128,3,3),
// std::make_tuple(320,320,128,128,3,3),
// std::make_tuple(640,320,128,128,3,3),
// std::make_tuple(640,320,128,128,3,3),
// std::make_tuple(320,320,128,128,3,3),
// std::make_tuple(640,320,128,128,3,3),
// std::make_tuple(320,4,128,128,3,3),
// std::make_tuple(4,320,128,128,3,3),
// std::make_tuple(320,320,128,128,3,3),
// std::make_tuple(320,320,128,128,3,3),
// std::make_tuple(320,320,128,128,3,3),
// std::make_tuple(320,320,128,128,3,3),
// std::make_tuple(320,320,128,128,3,3),
// std::make_tuple(320,640,64,64,3,3),
// std::make_tuple(640,640,64,64,3,3),
// std::make_tuple(320,640,64,64,3,3),
// std::make_tuple(640,640,64,64,3,3),
// std::make_tuple(640,640,64,64,3,3),
// std::make_tuple(640,640,64,64,3,3),
// std::make_tuple(640,1280,32,32,3,3),
// std::make_tuple(1280,1280,32,32,3,3),
// std::make_tuple(640,1280,32,32,3,3),
// std::make_tuple(1280,1280,32,32,3,3),
// std::make_tuple(1280,1280,32,32,3,3),
// std::make_tuple(1280,1280,32,32,3,3),
// std::make_tuple(1280,1280,32,32,3,3),
// std::make_tuple(1280,1280,32,32,3,3),
// std::make_tuple(1280,1280,32,32,3,3),
// std::make_tuple(2560,1280,32,32,3,3),
// std::make_tuple(1280,1280,32,32,3,3),
// std::make_tuple(2560,1280,32,32,3,3),
// std::make_tuple(2560,1280,32,32,3,3),
// std::make_tuple(1280,1280,32,32,3,3),
// std::make_tuple(2560,1280,32,32,3,3),
// std::make_tuple(1920,1280,32,32,3,3),
// std::make_tuple(1280,1280,32,32,3,3),
// std::make_tuple(1920,1280,32,32,3,3),
// std::make_tuple(1280,1280,64,64,3,3),
// std::make_tuple(1920,640,64,64,3,3),
// std::make_tuple(640,640,64,64,3,3),
// std::make_tuple(1920,640,64,64,3,3),
// std::make_tuple(1280,640,64,64,3,3),
// std::make_tuple(640,640,64,64,3,3),
// std::make_tuple(1280,640,64,64,3,3),
// std::make_tuple(960,640,64,64,3,3),
// std::make_tuple(640,640,64,64,3,3),
// std::make_tuple(960,640,64,64,3,3),
// std::make_tuple(640,640,128,128,3,3),
// std::make_tuple(960,320,128,128,3,3),
// std::make_tuple(320,320,128,128,3,3),
// std::make_tuple(960,320,128,128,3,3),
// std::make_tuple(640,320,128,128,3,3),
// std::make_tuple(320,320,128,128,3,3),
// std::make_tuple(640,320,128,128,3,3),
// std::make_tuple(640,320,128,128,3,3),
// std::make_tuple(320,320,128,128,3,3),
// std::make_tuple(640,320,128,128,3,3),
// std::make_tuple(320,4,128,128,3,3),
};
@ -663,7 +663,7 @@ int main(void)
// fprintf(stderr, "%s: compute buffer size: %.2f MB\n", __func__, mem_size/1024.0f/1024.0f);
int iterations = 20;
int iterations = 0;
double run_time0;
std::vector<float> im2col_data = compute_graph(model, allocr, build_graph_0, iterations, &run_time0);
@ -705,16 +705,16 @@ int main(void)
// for(int i = 0; i < ggml_nelements(wino_res); i++) {
// for(int i = 0; i < 26*38; i++) {
// // for(int i = 0; i < conv2d_data.size(); i++) {
// float diff = fabs(im2col_data[i] - conv2d_data[i]);
// // if(diff > 0.5) {
// printf("(%7.3f, %7.3f, %.2f, %d) \n",
// im2col_data[i], conv2d_data[i],
// diff, i);
// // break;
// // }
// }
for(int i = 0; i < 26*38; i++) {
// for(int i = 0; i < conv2d_data.size(); i++) {
float diff = fabs(im2col_data[i] - conv2d_data[i]);
// if(diff > 0.5) {
printf("(%7.3f, %7.3f, %.2f, %d) \n",
im2col_data[i], conv2d_data[i],
diff, i);
// break;
// }
}
ggml_free(model.ctx);
ggml_backend_buffer_free(model.buffer);