reduced bank conflicts for output
This commit is contained in:
parent
75dde410a8
commit
4b1920e9e7
|
|
@ -1219,8 +1219,9 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
|
||||||
{
|
{
|
||||||
// output sts
|
// output sts
|
||||||
uint32_t (®_)[2] = reinterpret_cast<uint32_t(&)[2]>(acc_register_[mma_m][mma_n]);
|
uint32_t (®_)[2] = reinterpret_cast<uint32_t(&)[2]>(acc_register_[mma_m][mma_n]);
|
||||||
const uint idx = output_sts_addr +
|
uint idx = output_sts_addr +
|
||||||
mma_m * MMA_M * BN / 2 + (mma_n - i * mma_tiles_per_warp_n/2) * MMA_N;
|
mma_m * MMA_M * BN / 2 + (mma_n - i * mma_tiles_per_warp_n/2) * MMA_N;
|
||||||
|
idx = idx ^ ((idx & 0b1110000000) >> 4);
|
||||||
uint32_t* dst_ptr = reinterpret_cast<uint32_t*>(&smemoutput[idx]);
|
uint32_t* dst_ptr = reinterpret_cast<uint32_t*>(&smemoutput[idx]);
|
||||||
dst_ptr[0] = reg_[0];
|
dst_ptr[0] = reg_[0];
|
||||||
dst_ptr = reinterpret_cast<uint32_t*>(&smemoutput[idx + 8 * BN / 2]);
|
dst_ptr = reinterpret_cast<uint32_t*>(&smemoutput[idx + 8 * BN / 2]);
|
||||||
|
|
@ -1255,7 +1256,10 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
|
||||||
// if (n < param.n && (m_idx + i * 16 + subk) < param.k && (n_idx + j * 32) < param.Oh * param.Ow)
|
// if (n < param.n && (m_idx + i * 16 + subk) < param.k && (n_idx + j * 32) < param.Oh * param.Ow)
|
||||||
// param.interm[outOffset] = smemoutput[output_lds_addr + subk * 32];
|
// param.interm[outOffset] = smemoutput[output_lds_addr + subk * 32];
|
||||||
const uint outOffset = n * param.k * param.Oh * param.Ow + row * param.Oh * param.Ow + col;
|
const uint outOffset = n * param.k * param.Oh * param.Ow + row * param.Oh * param.Ow + col;
|
||||||
output[outOffset] = smemoutput[output_lds_addr + subk + j*32*BN/2];
|
uint idx = output_lds_addr + subk + j*32*BN/2;
|
||||||
|
idx = idx ^ ((idx & 0b1110000000) >> 4);
|
||||||
|
// output[outOffset] = smemoutput[output_lds_addr + subk + j*32*BN/2];
|
||||||
|
output[outOffset] = smemoutput[idx];
|
||||||
// if(outOffset == 32){
|
// if(outOffset == 32){
|
||||||
// printf("(%u, %u, %u, %u), output[%d,%d,%d]=%f \n", threadIdx.x, threadIdx.y, blockIdx.x, blockIdx.y,
|
// printf("(%u, %u, %u, %u), output[%d,%d,%d]=%f \n", threadIdx.x, threadIdx.y, blockIdx.x, blockIdx.y,
|
||||||
// n, row, col, __half2float(output[outOffset]));
|
// n, row, col, __half2float(output[outOffset]));
|
||||||
|
|
|
||||||
|
|
@ -357,6 +357,7 @@ int main(void)
|
||||||
// std::make_tuple(1280,1280,26,38,1,1),
|
// std::make_tuple(1280,1280,26,38,1,1),
|
||||||
// std::make_tuple(256,128,768,1024,3,3),
|
// std::make_tuple(256,128,768,1024,3,3),
|
||||||
// std::make_tuple(256,128,768,1024,1,1),
|
// std::make_tuple(256,128,768,1024,1,1),
|
||||||
|
// std::make_tuple(512,256,384,512,1,1),
|
||||||
// std::make_tuple(1280,640,52,76,3,3),
|
// std::make_tuple(1280,640,52,76,3,3),
|
||||||
// std::make_tuple(1920,1280,26,38,3,3),
|
// std::make_tuple(1920,1280,26,38,3,3),
|
||||||
// std::make_tuple(2560,1280,26,38,3,3),
|
// std::make_tuple(2560,1280,26,38,3,3),
|
||||||
|
|
@ -388,7 +389,7 @@ int main(void)
|
||||||
|
|
||||||
|
|
||||||
struct ggml_cgraph * gf_res_0 = NULL;
|
struct ggml_cgraph * gf_res_0 = NULL;
|
||||||
int iterations = 20;
|
int iterations = 0;
|
||||||
|
|
||||||
double run_time0;
|
double run_time0;
|
||||||
std::vector<float> im2col_data = compute_graph(model, allocr, build_graph_0, iterations, &run_time0);
|
std::vector<float> im2col_data = compute_graph(model, allocr, build_graph_0, iterations, &run_time0);
|
||||||
|
|
@ -451,17 +452,17 @@ int main(void)
|
||||||
|
|
||||||
// for(int i = 0; i < ggml_nelements(wino_res); i++) {
|
// for(int i = 0; i < ggml_nelements(wino_res); i++) {
|
||||||
// for(int i = 0; i < 26*38; i++) {
|
// for(int i = 0; i < 26*38; i++) {
|
||||||
for(int i = 0; i < conv2d_data.size(); i++) {
|
// for(int i = 0; i < conv2d_data.size(); i++) {
|
||||||
// float diff = fabs(conv2d_data[i] - wino_data[i]);
|
// // float diff = fabs(conv2d_data[i] - wino_data[i]);
|
||||||
float diff = fabs(im2col_data[i] - wino_data[i]);
|
// float diff = fabs(im2col_data[i] - wino_data[i]);
|
||||||
float diff1 = fabs(im2col_data[i] - conv2d_data[i]);
|
// float diff1 = fabs(im2col_data[i] - conv2d_data[i]);
|
||||||
if(diff > 0.5) {
|
// // if(diff > 0.5) {
|
||||||
printf("(%7.3f, %7.3f, %7.3f, %.2f, %.2f, %d) \n",
|
// printf("(%7.3f, %7.3f, %7.3f, %.2f, %.2f, %d) \n",
|
||||||
im2col_data[i], conv2d_data[i],
|
// im2col_data[i], conv2d_data[i],
|
||||||
wino_data[i], diff, diff1, i);
|
// wino_data[i], diff, diff1, i);
|
||||||
// break;
|
// // break;
|
||||||
}
|
// // }
|
||||||
}
|
// }
|
||||||
|
|
||||||
ggml_free(model.ctx);
|
ggml_free(model.ctx);
|
||||||
ggml_backend_buffer_free(model.buffer);
|
ggml_backend_buffer_free(model.buffer);
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue