reorder register tile loop
This commit is contained in:
parent
c6255442bb
commit
0ca43582e8
|
|
@ -183,12 +183,20 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
|
|||
input_frag[(subcrs + 1) & 1][i + 4] = smeminput[load_flag * 128 * 8 + input_lds_addr + (subcrs + 1) * 128 + i + 32];
|
||||
}
|
||||
|
||||
// #pragma unroll
|
||||
// for (int i = 0; i < 8; ++i){
|
||||
// auto weight_frag_i = ggml_cuda_cast<float>(weight_frag[subcrs % 2][i]);
|
||||
// #pragma unroll
|
||||
// for (int j = 0; j < 8; ++j){
|
||||
// output_frag[i][j] += weight_frag_i * input_frag[subcrs % 2][j];
|
||||
// }
|
||||
// }
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 8; ++i){
|
||||
auto weight_frag_i = ggml_cuda_cast<float>(weight_frag[subcrs % 2][i]);
|
||||
for (int j = 0; j < 8; ++j){
|
||||
// auto weight_frag_i = ggml_cuda_cast<float>(weight_frag[subcrs % 2][i]);
|
||||
#pragma unroll
|
||||
for (int j = 0; j < 8; ++j){
|
||||
output_frag[i][j] += weight_frag_i * input_frag[subcrs % 2][j];
|
||||
for (int i = 0; i < 8; ++i){
|
||||
output_frag[j][i] += ggml_cuda_cast<float>(weight_frag[subcrs % 2][i]) * input_frag[subcrs % 2][j];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -215,7 +223,7 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
|
|||
for (int i = 0; i < 8; ++i){
|
||||
#pragma unroll
|
||||
for (int j = 0; j < 8; ++j){
|
||||
output_frag[i][j] += ggml_cuda_cast<float>(weight_frag[1][i]) * input_frag[1][j];
|
||||
output_frag[i][j] += ggml_cuda_cast<float>(weight_frag[1][j]) * input_frag[1][i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -240,15 +248,15 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
|
|||
#pragma unroll
|
||||
for (int subj = 0; subj < 4; ++subj){
|
||||
// output sts
|
||||
smemoutput[output_sts_addr + subi * 8 * 4 + subj] = output_frag[i * 4 + subi][j * 4 + subj];
|
||||
smemoutput[output_sts_addr + subj * 8 * 4 + subi] = output_frag[i * 4 + subi][j * 4 + subj];
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
#pragma unroll
|
||||
for (int subk = 0; subk < 16; ++subk){
|
||||
int outOffset = z * param.k * param.Oh * param.Ow + (m_idx + i * 16 + subk) * param.Oh * param.Ow + n_idx + j * 32;
|
||||
if ((m_idx + i * 16 + subk) < param.k && (n_idx + j * 32) < param.Oh * param.Ow)
|
||||
int outOffset = z * param.k * param.Oh * param.Ow + (m_idx + j * 16 + subk) * param.Oh * param.Ow + n_idx + i * 32;
|
||||
if ((m_idx + j * 16 + subk) < param.k && (n_idx + i * 32) < param.Oh * param.Ow)
|
||||
output[outOffset] = smemoutput[output_lds_addr + subk * 32];
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue