trying to reduce integer ops; simplify code
This commit is contained in:
parent
c33e4301dc
commit
ea438d8b0e
|
|
@ -805,6 +805,9 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
|
|||
}
|
||||
}
|
||||
|
||||
const unsigned int A_warp_tile_offset = warp_m * WM * BK;
|
||||
const unsigned int B_warp_tile_offset = warp_n * WN * BK;
|
||||
|
||||
static_assert(BM == 256);
|
||||
static_assert(BN == 256);
|
||||
static_assert(BK == 32);
|
||||
|
|
@ -825,13 +828,11 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
|
|||
__syncthreads();
|
||||
|
||||
if (block_k != num_block_tiles_k){
|
||||
const half* A_block_gmem = input;
|
||||
const half* B_block_gmem = kernel + (block_n * BN * weightKOffset);
|
||||
tileMemcpyLoadA<BM, BK, NUM_THREADS, 4>(A_block_gmem, A_gmem_cache_reg, block_k * BK, start_k, end_k, inChannelOffset, param);
|
||||
tileMemcpyLoadB<BN, BK, NUM_THREADS, 4>(B_block_gmem, B_gmem_cache_reg, block_k * BK, start_k, end_k, weightKOffset, param);
|
||||
}
|
||||
half* A_warp_tile = A_block_smem + (warp_m * WM * BK);
|
||||
half* B_warp_tile = B_block_smem + (warp_n * WN * BK);
|
||||
half* A_warp_tile = A_block_smem + A_warp_tile_offset;
|
||||
half* B_warp_tile = B_block_smem + B_warp_tile_offset;
|
||||
|
||||
ldmatrix_a<mma_tiles_per_warp_m, mma_tiles_per_warp_k, BK>(A_warp_tile, A_register_);
|
||||
ldmatrix_b<mma_tiles_per_warp_k, mma_tiles_per_warp_n, BK>(B_warp_tile, B_register_);
|
||||
|
|
@ -886,23 +887,25 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
|
|||
const uint lane_id = threadIdx.x % WARPSIZE;
|
||||
const uint mma_row = lane_id / 4;
|
||||
const uint mma_col = lane_id % 4;
|
||||
const uint output_lds_addr = warp_m * WM * BN/2 + lane_id * BN/2 + warp_n * WN/2;
|
||||
const uint output_sts_addr = warp_m * WM * BN/2 + mma_row * BN/2 + warp_n * WN/2 + mma_col * 2;
|
||||
const uint warp_offset = warp_m * WM * BN/2 + warp_n * WN/2;
|
||||
const uint output_lds_addr = warp_offset + lane_id * BN/2;
|
||||
const uint output_sts_addr = warp_offset + mma_row * BN/2 + mma_col * 2;
|
||||
const uint m_idx = block_n * BN + warp_n * WN;
|
||||
const uint n_idx = block_m * BM + warp_m * WM + lane_id;
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 2; ++i)
|
||||
{
|
||||
const unsigned int i_offset = i * mma_tiles_per_warp_n/2;
|
||||
__syncthreads();
|
||||
#pragma unroll
|
||||
for (unsigned int mma_m = 0; mma_m < mma_tiles_per_warp_m; mma_m++)
|
||||
{
|
||||
for (unsigned int mma_n = i * mma_tiles_per_warp_n/2; mma_n < (i+1)*mma_tiles_per_warp_n/2; mma_n++)
|
||||
const unsigned int mma_m_offset = output_sts_addr + mma_m * MMA_M * BN / 2;
|
||||
for (unsigned int mma_n = i_offset; mma_n < (i+1)*mma_tiles_per_warp_n/2; mma_n++)
|
||||
{
|
||||
uint32_t (&reg_)[2] = reinterpret_cast<uint32_t(&)[2]>(acc_register_[mma_m][mma_n]);
|
||||
uint idx = output_sts_addr +
|
||||
mma_m * MMA_M * BN / 2 + (mma_n - i * mma_tiles_per_warp_n/2) * MMA_N;
|
||||
uint idx = mma_m_offset + (mma_n - i_offset) * MMA_N;
|
||||
idx = idx ^ ((idx & 0b110000000000) >> 9);
|
||||
idx = idx ^ ((idx & 0b1110000000) >> 4);
|
||||
uint32_t* dst_ptr = reinterpret_cast<uint32_t*>(&smemoutput[idx]);
|
||||
|
|
@ -913,6 +916,7 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
|
|||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
const unsigned int m_i_wn = m_idx + i * WN / 2;
|
||||
#pragma unroll
|
||||
for (int subk = 0; subk < WN / 4; ++subk){
|
||||
|
|
@ -925,29 +929,15 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
|
|||
const uint gemm_i = n_idx + j*32;
|
||||
const int n = fastdiv(gemm_i, param.OHOW_fastdiv);
|
||||
const int col = fastmodulo(gemm_i, param.OHOW_fastdiv);
|
||||
uint32_t dst_ptr = *(reinterpret_cast<uint32_t*>(&smemoutput[idx+j*32*BN/2]));
|
||||
uint32_t dst_ptr = *(reinterpret_cast<uint32_t*>(&smemoutput[idx+j*16*BN])); // 32*BN/2 = 16*BN
|
||||
half (&res_)[2] = reinterpret_cast<half(&)[2]>(dst_ptr);
|
||||
if (n < param.n && row < param.k && col < PQ) {
|
||||
if constexpr (ksplit > 0) {
|
||||
const uint outOffset = z * NKPQ +
|
||||
n * KPQ +
|
||||
row * PQ + col;
|
||||
output[outOffset] = ggml_cuda_cast<T>(res_[0]);
|
||||
} else {
|
||||
const uint outOffset = n * KPQ + row * PQ + col;
|
||||
output[outOffset] = ggml_cuda_cast<T>(res_[0]);
|
||||
}
|
||||
const uint outOffset = ((ksplit > 0) ? z * NKPQ : 0) + n * KPQ + row * PQ + col;
|
||||
output[outOffset] = ggml_cuda_cast<T>(res_[0]);
|
||||
}
|
||||
if (n < param.n && row+1 < param.k && col < PQ) {
|
||||
if constexpr (ksplit > 0) {
|
||||
const uint outOffset = z * NKPQ +
|
||||
n * KPQ +
|
||||
(row+1) * PQ + col;
|
||||
output[outOffset] = ggml_cuda_cast<T>(res_[1]);
|
||||
} else {
|
||||
const uint outOffset = n * KPQ + (row+1) * PQ + col;
|
||||
output[outOffset] = ggml_cuda_cast<T>(res_[1]);
|
||||
}
|
||||
const uint outOffset = ((ksplit > 0) ? z * NKPQ : 0) + n * KPQ + (row+1) * PQ + col;
|
||||
output[outOffset] = ggml_cuda_cast<T>(res_[1]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -714,15 +714,15 @@ int main(void)
|
|||
|
||||
// for(int i = 0; i < ggml_nelements(wino_res); i++) {
|
||||
// for(int i = 0; i < 26*38; i++) {
|
||||
// for(int i = 0; i < conv2d_data.size(); i++) {
|
||||
// float diff = fabs(im2col_data[i] - conv2d_data[i]);
|
||||
// // if(diff > 0.5) {
|
||||
// printf("(%7.3f, %7.3f, %.2f, %d) \n",
|
||||
// im2col_data[i], conv2d_data[i],
|
||||
// diff, i);
|
||||
// // break;
|
||||
// // }
|
||||
// }
|
||||
for(int i = 0; i < conv2d_data.size(); i++) {
|
||||
float diff = fabs(im2col_data[i] - conv2d_data[i]);
|
||||
// if(diff > 0.5) {
|
||||
printf("(%7.3f, %7.3f, %.2f, %d) \n",
|
||||
im2col_data[i], conv2d_data[i],
|
||||
diff, i);
|
||||
// break;
|
||||
// }
|
||||
}
|
||||
|
||||
ggml_free(model.ctx);
|
||||
ggml_backend_buffer_free(model.buffer);
|
||||
|
|
|
|||
Loading…
Reference in New Issue