WIP: debugging
This commit is contained in:
parent
df88b2c917
commit
76885c7697
|
|
@ -677,6 +677,7 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
|
||||||
uint32_t (®_)[2] = reinterpret_cast<uint32_t(&)[2]>(acc_register_[mma_m][mma_n]);
|
uint32_t (®_)[2] = reinterpret_cast<uint32_t(&)[2]>(acc_register_[mma_m][mma_n]);
|
||||||
uint idx = output_sts_addr +
|
uint idx = output_sts_addr +
|
||||||
mma_m * MMA_M * BN / 2 + (mma_n - i * mma_tiles_per_warp_n/2) * MMA_N;
|
mma_m * MMA_M * BN / 2 + (mma_n - i * mma_tiles_per_warp_n/2) * MMA_N;
|
||||||
|
idx = idx ^ ((idx & 0b110000000000) >> 9);
|
||||||
idx = idx ^ ((idx & 0b1110000000) >> 4);
|
idx = idx ^ ((idx & 0b1110000000) >> 4);
|
||||||
uint32_t* dst_ptr = reinterpret_cast<uint32_t*>(&smemoutput[idx]);
|
uint32_t* dst_ptr = reinterpret_cast<uint32_t*>(&smemoutput[idx]);
|
||||||
dst_ptr[0] = reg_[0];
|
dst_ptr[0] = reg_[0];
|
||||||
|
|
@ -695,19 +696,24 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
|
||||||
const int n = fastdiv(gemm_i, param.OHOW_fastdiv);
|
const int n = fastdiv(gemm_i, param.OHOW_fastdiv);
|
||||||
const int col = fastmodulo(gemm_i, param.OHOW_fastdiv);
|
const int col = fastmodulo(gemm_i, param.OHOW_fastdiv);
|
||||||
uint idx = output_lds_addr + subk*2 + j*32*BN/2;
|
uint idx = output_lds_addr + subk*2 + j*32*BN/2;
|
||||||
|
idx = idx ^ ((idx & 0b110000000000) >> 9);
|
||||||
idx = idx ^ ((idx & 0b1110000000) >> 4);
|
idx = idx ^ ((idx & 0b1110000000) >> 4);
|
||||||
uint32_t* dst_ptr = reinterpret_cast<uint32_t*>(&smemoutput[idx]);
|
// uint32_t* dst_ptr = reinterpret_cast<uint32_t*>(&smemoutput[idx]);
|
||||||
|
uint32_t dst_ptr = *(reinterpret_cast<uint32_t*>(&smemoutput[idx]));
|
||||||
|
half (&res_)[2] = reinterpret_cast<half(&)[2]>(dst_ptr);
|
||||||
if (n < param.n && row < param.k && col < PQ) {
|
if (n < param.n && row < param.k && col < PQ) {
|
||||||
if constexpr (ksplit > 0) {
|
if constexpr (ksplit > 0) {
|
||||||
const uint outOffset = z * NKPQ +
|
const uint outOffset = z * NKPQ +
|
||||||
n * KPQ +
|
n * KPQ +
|
||||||
row * PQ + col;
|
row * PQ + col;
|
||||||
// output[outOffset] = smemoutput[idx];
|
// output[outOffset] = smemoutput[idx];
|
||||||
output[outOffset] = reinterpret_cast<half *>(dst_ptr)[0];
|
// output[outOffset] = reinterpret_cast<half *>(dst_ptr)[0];
|
||||||
|
output[outOffset] = res_[0];
|
||||||
} else {
|
} else {
|
||||||
const uint outOffset = n * KPQ + row * PQ + col;
|
const uint outOffset = n * KPQ + row * PQ + col;
|
||||||
// output[outOffset] = smemoutput[idx];
|
// output[outOffset] = smemoutput[idx];
|
||||||
output[outOffset] = reinterpret_cast<half *>(dst_ptr)[0];
|
// output[outOffset] = reinterpret_cast<half *>(dst_ptr)[0];
|
||||||
|
output[outOffset] = res_[0];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (n < param.n && row+1 < param.k && col < PQ) {
|
if (n < param.n && row+1 < param.k && col < PQ) {
|
||||||
|
|
@ -716,11 +722,13 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
|
||||||
n * KPQ +
|
n * KPQ +
|
||||||
(row+1) * PQ + col;
|
(row+1) * PQ + col;
|
||||||
// output[outOffset] = smemoutput[idx];
|
// output[outOffset] = smemoutput[idx];
|
||||||
output[outOffset] = reinterpret_cast<half *>(dst_ptr)[1];
|
// output[outOffset] = reinterpret_cast<half *>(dst_ptr)[1];
|
||||||
|
output[outOffset] = res_[1];
|
||||||
} else {
|
} else {
|
||||||
const uint outOffset = n * KPQ + (row+1) * PQ + col;
|
const uint outOffset = n * KPQ + (row+1) * PQ + col;
|
||||||
// output[outOffset] = smemoutput[idx];
|
// output[outOffset] = smemoutput[idx];
|
||||||
output[outOffset] = reinterpret_cast<half *>(dst_ptr)[1];
|
// output[outOffset] = reinterpret_cast<half *>(dst_ptr)[1];
|
||||||
|
output[outOffset] = res_[1];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue