unroll some loops
This commit is contained in:
parent
b70cca2ea3
commit
3f99818925
|
|
@ -4,8 +4,9 @@
|
||||||
#include "convert.cuh"
|
#include "convert.cuh"
|
||||||
#include "conv2d-implicit.cuh"
|
#include "conv2d-implicit.cuh"
|
||||||
|
|
||||||
static const int WARPSIZE = 32; // warpSize is not constexpr
|
|
||||||
typedef unsigned int uint;
|
typedef unsigned int uint;
|
||||||
|
constexpr uint WARPSIZE = 32;
|
||||||
|
|
||||||
static __global__ void reduce_f32(const float * __restrict__ x, float * __restrict__ dst, const int ncols, const int nrows) {
|
static __global__ void reduce_f32(const float * __restrict__ x, float * __restrict__ dst, const int ncols, const int nrows) {
|
||||||
const int row = blockIdx.x;
|
const int row = blockIdx.x;
|
||||||
|
|
@ -125,6 +126,7 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
|
||||||
|
|
||||||
// ldg
|
// ldg
|
||||||
const uint weight_sts_addr = innerRowA + innerColA * (BN+PAD) * 4;
|
const uint weight_sts_addr = innerRowA + innerColA * (BN+PAD) * 4;
|
||||||
|
#pragma unroll
|
||||||
for (uint offset = 0; offset + rowStrideA <= BN; offset += rowStrideA) {
|
for (uint offset = 0; offset + rowStrideA <= BN; offset += rowStrideA) {
|
||||||
if(vec_load){
|
if(vec_load){
|
||||||
// if (by * BN + innerRowA + offset < param.k && start_k + innerColA * 4 < param.c * param.r * param.s){
|
// if (by * BN + innerRowA + offset < param.k && start_k + innerColA * 4 < param.c * param.r * param.s){
|
||||||
|
|
@ -174,6 +176,7 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
|
||||||
// int curC = ((tx % 2) * 4 % (param.s * param.c)) % param.c; // kernel s offset
|
// int curC = ((tx % 2) * 4 % (param.s * param.c)) % param.c; // kernel s offset
|
||||||
|
|
||||||
const uint input_sts_addr = innerRowA + innerColA * BM * 4;
|
const uint input_sts_addr = innerRowA + innerColA * BM * 4;
|
||||||
|
#pragma unroll
|
||||||
for (uint offset = 0; offset + rowStrideA <= BM; offset += rowStrideA) {
|
for (uint offset = 0; offset + rowStrideA <= BM; offset += rowStrideA) {
|
||||||
int n = (ksplit > 0) ? (bx * BM + innerRowA + offset) / PQ : z;
|
int n = (ksplit > 0) ? (bx * BM + innerRowA + offset) / PQ : z;
|
||||||
const unsigned int npq_res = (bx * BM + innerRowA + offset) % PQ;
|
const unsigned int npq_res = (bx * BM + innerRowA + offset) % PQ;
|
||||||
|
|
@ -278,14 +281,18 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
|
||||||
// lds
|
// lds
|
||||||
// int input_lds_addr = (warp_id % 2) * 64 + mma_tid_x * 4;
|
// int input_lds_addr = (warp_id % 2) * 64 + mma_tid_x * 4;
|
||||||
const uint input_lds_addr = mma_tid_x * WM;
|
const uint input_lds_addr = mma_tid_x * WM;
|
||||||
|
#pragma unroll
|
||||||
for (uint wSubRowIdx = 0; wSubRowIdx < WMITER; ++wSubRowIdx)
|
for (uint wSubRowIdx = 0; wSubRowIdx < WMITER; ++wSubRowIdx)
|
||||||
|
#pragma unroll
|
||||||
for (uint i = 0; i < TM; ++i)
|
for (uint i = 0; i < TM; ++i)
|
||||||
input_frag[0][wSubRowIdx * TM + i] = smeminput[input_lds_addr + wSubRowIdx * WSUBM +
|
input_frag[0][wSubRowIdx * TM + i] = smeminput[input_lds_addr + wSubRowIdx * WSUBM +
|
||||||
threadRowInWarp * TM + i];
|
threadRowInWarp * TM + i];
|
||||||
|
|
||||||
// int weight_lds_addr = (warp_id / 2) * 32 + mma_tid_y * 4;
|
// int weight_lds_addr = (warp_id / 2) * 32 + mma_tid_y * 4;
|
||||||
const uint weight_lds_addr = mma_tid_y * WN;
|
const uint weight_lds_addr = mma_tid_y * WN;
|
||||||
|
#pragma unroll
|
||||||
for (uint wSubColIdx = 0; wSubColIdx < WNITER; ++wSubColIdx)
|
for (uint wSubColIdx = 0; wSubColIdx < WNITER; ++wSubColIdx)
|
||||||
|
#pragma unroll
|
||||||
for (uint i = 0; i < TN; ++i)
|
for (uint i = 0; i < TN; ++i)
|
||||||
weight_frag[0][wSubColIdx * TN + i] = smemweight[weight_lds_addr + wSubColIdx * WSUBN +
|
weight_frag[0][wSubColIdx * TN + i] = smemweight[weight_lds_addr + wSubColIdx * WSUBN +
|
||||||
threadColInWarp * TN + i];
|
threadColInWarp * TN + i];
|
||||||
|
|
@ -495,7 +502,7 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input,
|
||||||
smeminput[write_flag * BM * BK + input_sts_addr + offset + 2*BM] = tmp.z;
|
smeminput[write_flag * BM * BK + input_sts_addr + offset + 2*BM] = tmp.z;
|
||||||
smeminput[write_flag * BM * BK + input_sts_addr + offset + 3*BM] = tmp.w;
|
smeminput[write_flag * BM * BK + input_sts_addr + offset + 3*BM] = tmp.w;
|
||||||
} else {
|
} else {
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i = 0; i < 4; ++i)
|
for (int i = 0; i < 4; ++i)
|
||||||
smeminput[write_flag * BM * BK + input_sts_addr + offset + i*BM] = 0.f;
|
smeminput[write_flag * BM * BK + input_sts_addr + offset + i*BM] = 0.f;
|
||||||
}
|
}
|
||||||
|
|
@ -781,11 +788,11 @@ static void conv2d_implicit_cuda(const float * X_D, const T * K_D, float * Y_D,
|
||||||
}
|
}
|
||||||
|
|
||||||
static void conv2d_implicit_cuda_f16(const float * X_D, const half * K_D, float * Y_D, const param_t P, cudaStream_t st) {
|
static void conv2d_implicit_cuda_f16(const float * X_D, const half * K_D, float * Y_D, const param_t P, cudaStream_t st) {
|
||||||
conv2d_implicit_cuda<half, 3>(X_D, K_D, Y_D, P, st);
|
conv2d_implicit_cuda<half, 1>(X_D, K_D, Y_D, P, st);
|
||||||
}
|
}
|
||||||
|
|
||||||
static void conv2d_implicit_cuda_f32(const float * X_D, const float * K_D, float * Y_D, const param_t P, cudaStream_t st) {
|
static void conv2d_implicit_cuda_f32(const float * X_D, const float * K_D, float * Y_D, const param_t P, cudaStream_t st) {
|
||||||
conv2d_implicit_cuda<float, 3>(X_D, K_D, Y_D, P, st);
|
conv2d_implicit_cuda<float, 1>(X_D, K_D, Y_D, P, st);
|
||||||
}
|
}
|
||||||
|
|
||||||
void ggml_cuda_op_conv2d_implicit(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
void ggml_cuda_op_conv2d_implicit(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||||
|
|
|
||||||
|
|
@ -262,7 +262,7 @@ struct ggml_cgraph * build_graph_2(const test_model& model) {
|
||||||
// printf("conv2d: (%zu, %zu, %zu, %zu) \n", ne[0], ne[1], ne[2], ne[3]);
|
// printf("conv2d: (%zu, %zu, %zu, %zu) \n", ne[0], ne[1], ne[2], ne[3]);
|
||||||
|
|
||||||
|
|
||||||
struct ggml_tensor* wino_res = ggml_conv_2d_implicitgemm(ctx0, model.a, model.b, s0, s1, p0, p1, d0, d1, 0);
|
struct ggml_tensor* wino_res = ggml_conv_2d_implicitgemm(ctx0, model.a, model.b, s0, s1, p0, p1, d0, d1, 1);
|
||||||
// struct ggml_tensor* wino_res = ggml_conv_2d_direct(ctx0, model.a, model.b, s0, s1, p0, p1, d0, d1);
|
// struct ggml_tensor* wino_res = ggml_conv_2d_direct(ctx0, model.a, model.b, s0, s1, p0, p1, d0, d1);
|
||||||
ggml_set_name(wino_res, "wino_res");
|
ggml_set_name(wino_res, "wino_res");
|
||||||
ggml_build_forward_expand(gf, wino_res);
|
ggml_build_forward_expand(gf, wino_res);
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue