reorder register tile loop
This commit is contained in:
parent
c6255442bb
commit
0ca43582e8
|
|
@ -183,12 +183,20 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
|
||||||
input_frag[(subcrs + 1) & 1][i + 4] = smeminput[load_flag * 128 * 8 + input_lds_addr + (subcrs + 1) * 128 + i + 32];
|
input_frag[(subcrs + 1) & 1][i + 4] = smeminput[load_flag * 128 * 8 + input_lds_addr + (subcrs + 1) * 128 + i + 32];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// #pragma unroll
|
||||||
|
// for (int i = 0; i < 8; ++i){
|
||||||
|
// auto weight_frag_i = ggml_cuda_cast<float>(weight_frag[subcrs % 2][i]);
|
||||||
|
// #pragma unroll
|
||||||
|
// for (int j = 0; j < 8; ++j){
|
||||||
|
// output_frag[i][j] += weight_frag_i * input_frag[subcrs % 2][j];
|
||||||
|
// }
|
||||||
|
// }
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i = 0; i < 8; ++i){
|
for (int j = 0; j < 8; ++j){
|
||||||
auto weight_frag_i = ggml_cuda_cast<float>(weight_frag[subcrs % 2][i]);
|
// auto weight_frag_i = ggml_cuda_cast<float>(weight_frag[subcrs % 2][i]);
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int j = 0; j < 8; ++j){
|
for (int i = 0; i < 8; ++i){
|
||||||
output_frag[i][j] += weight_frag_i * input_frag[subcrs % 2][j];
|
output_frag[j][i] += ggml_cuda_cast<float>(weight_frag[subcrs % 2][i]) * input_frag[subcrs % 2][j];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -215,7 +223,7 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
|
||||||
for (int i = 0; i < 8; ++i){
|
for (int i = 0; i < 8; ++i){
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int j = 0; j < 8; ++j){
|
for (int j = 0; j < 8; ++j){
|
||||||
output_frag[i][j] += ggml_cuda_cast<float>(weight_frag[1][i]) * input_frag[1][j];
|
output_frag[i][j] += ggml_cuda_cast<float>(weight_frag[1][j]) * input_frag[1][i];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -240,15 +248,15 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int subj = 0; subj < 4; ++subj){
|
for (int subj = 0; subj < 4; ++subj){
|
||||||
// output sts
|
// output sts
|
||||||
smemoutput[output_sts_addr + subi * 8 * 4 + subj] = output_frag[i * 4 + subi][j * 4 + subj];
|
smemoutput[output_sts_addr + subj * 8 * 4 + subi] = output_frag[i * 4 + subi][j * 4 + subj];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
|
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int subk = 0; subk < 16; ++subk){
|
for (int subk = 0; subk < 16; ++subk){
|
||||||
int outOffset = z * param.k * param.Oh * param.Ow + (m_idx + i * 16 + subk) * param.Oh * param.Ow + n_idx + j * 32;
|
int outOffset = z * param.k * param.Oh * param.Ow + (m_idx + j * 16 + subk) * param.Oh * param.Ow + n_idx + i * 32;
|
||||||
if ((m_idx + i * 16 + subk) < param.k && (n_idx + j * 32) < param.Oh * param.Ow)
|
if ((m_idx + j * 16 + subk) < param.k && (n_idx + i * 32) < param.Oh * param.Ow)
|
||||||
output[outOffset] = smemoutput[output_lds_addr + subk * 32];
|
output[outOffset] = smemoutput[output_lds_addr + subk * 32];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue