fixed bug now split-k is working
This commit is contained in:
parent
6f44f47113
commit
688de6d7d8
|
|
@ -819,9 +819,6 @@ static __global__ void conv2d_implicit_kernel(const half * __restrict__ input,
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
if (block_k != num_block_tiles_k)
|
||||
{
|
||||
// switch smem buffers each iteration
|
||||
|
|
|
|||
|
|
@ -66,7 +66,7 @@ __device__ __forceinline__ void tileMemcpySwizzleB(
|
|||
#pragma unroll
|
||||
for (unsigned int i = 0; i < NUM_ITERS; i++){
|
||||
// apply swizzle to the dst index
|
||||
const unsigned int src_index = thread_row * src_stride + thread_col * 8;
|
||||
const unsigned int src_index = thread_row * src_stride + start_k + thread_col * 8;
|
||||
unsigned int dst_index = thread_row * TILE_COLS_VECTORIZED + thread_col;
|
||||
dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_1) >> SWIZZLE_BITS_1);
|
||||
dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_2) >> SWIZZLE_BITS_2);
|
||||
|
|
@ -262,7 +262,7 @@ __device__ __forceinline__ void tileMemcpyLoadB(
|
|||
|
||||
#pragma unroll
|
||||
for (unsigned int i = 0; i < NUM_ITERS; i++){
|
||||
const unsigned int src_index = thread_row * src_stride + block_k + thread_col * 8;
|
||||
const unsigned int src_index = thread_row * src_stride + start_k + block_k + thread_col * 8;
|
||||
if (thread_row < param.k && curR < param.r && curS < param.s && curC < param.c && start_k+block_k+thread_col*8 < end_k){
|
||||
dst_reg[i] = reinterpret_cast<const float4 *>(&src[src_index])[0];
|
||||
}else{ // read 4 halves
|
||||
|
|
|
|||
|
|
@ -44,7 +44,7 @@ void load_model(test_model & model, int ic, int oc, int iw, int ih, int kw = 3,
|
|||
// create data
|
||||
int KW = kw, KH = kh, IC = ic, OC = oc;
|
||||
int IW = iw, IH = ih, N = 1;
|
||||
srand(time(NULL));
|
||||
// srand(time(NULL));
|
||||
|
||||
// printf(" input: IC = %d, OC = %d, IW = %d, IH = %d \n ", IC, OC, IW, IH);
|
||||
|
||||
|
|
@ -384,8 +384,8 @@ int main(void)
|
|||
|
||||
// for(int i = 0; i < ggml_nelements(wino_res); i++) {
|
||||
// for(int i = 0; i < 26*38; i++) {
|
||||
// // for(int i = 26*38; i < 2*26*38; i++) {
|
||||
// // for(int i = 0; i < conv2d_data.size(); i++) {
|
||||
// for(int i = 26*38; i < 2*26*38; i++) {
|
||||
// for(int i = 0; i < conv2d_data.size(); i++) {
|
||||
// float diff = fabs(im2col_data[i] - conv2d_data[i]);
|
||||
// // if(diff > 0.5) {
|
||||
// printf("(%7.3f, %7.3f, %.2f, %d) \n",
|
||||
|
|
|
|||
Loading…
Reference in New Issue