From 83a3b7d6a98727705ed04f29afdd587ea3c17c37 Mon Sep 17 00:00:00 2001 From: bssrdf Date: Sat, 6 Sep 2025 17:26:19 -0400 Subject: [PATCH] Refactor conv2d_implicit_kernel for improved bitwise operations; add test for implicit convolution --- ggml/src/ggml-cuda/conv2d-implicit.cu | 72 +++-- tests/CMakeLists.txt | 1 + tests/test-conv2d-implicit.cpp | 390 ++++++++++++++++++++++++++ 3 files changed, 434 insertions(+), 29 deletions(-) create mode 100644 tests/test-conv2d-implicit.cpp diff --git a/ggml/src/ggml-cuda/conv2d-implicit.cu b/ggml/src/ggml-cuda/conv2d-implicit.cu index 31205187c1..1e2540f8ca 100644 --- a/ggml/src/ggml-cuda/conv2d-implicit.cu +++ b/ggml/src/ggml-cuda/conv2d-implicit.cu @@ -37,16 +37,16 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, // Warp tile - const int lane_id = threadIdx.x % 32; - const int warp_id = threadIdx.x / 32; - const int mma_tid_x = (lane_id / 2) % 8; - const int mma_tid_y = (lane_id / 16) * 2 + (lane_id % 2); + const int lane_id = threadIdx.x & 31; + const int warp_id = threadIdx.x >> 5; + const int mma_tid_x = (lane_id >> 1) % 8; + const int mma_tid_y = (lane_id >> 4) * 2 + (lane_id & 1); // lds addr - int weight_lds_addr = (warp_id / 2) * 32 + mma_tid_y * 4; - int input_lds_addr = (warp_id % 2) * 64 + mma_tid_x * 4; + int weight_lds_addr = (warp_id >> 1) * 32 + mma_tid_y * 4; + int input_lds_addr = (warp_id & 1) * 64 + mma_tid_x * 4; - int x = bx * 128 + input_lds_addr; - int y = by * 128 + weight_lds_addr; + // int x = bx * 128 + input_lds_addr; + // int y = by * 128 + weight_lds_addr; int z = blockIdx.z; T weight_ldg_reg[4]; @@ -56,20 +56,20 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, int posw_ori[4]; #pragma unroll for (int i = 0; i < 4; ++i){ - posh_ori[i] = ((bx * 128 + tx % 32 + i * 32) / param.Ow) * param.u - param.p; - posw_ori[i] = ((bx * 128 + tx % 32 + i * 32) % param.Ow) * param.v - param.q; + posh_ori[i] = ((bx * 128 + lane_id + i * 32) / param.Ow) * param.u - param.p; + posw_ori[i] = ((bx * 128 + lane_id + i * 32) % param.Ow) * param.v - param.q; } int inOffset = z * param.c * param.h * param.w; - int weiOffset = (by * 128 + tx / 8 * 4) * param.c * param.r * param.s; + int weiOffset = (by * 128 + (tx >> 3) * 4) * param.c * param.r * param.s; int inChannelOffset = param.h * param.w; - int weightChannelOffset = param.r * param.s; + // int weightChannelOffset = param.r * param.s; int weightKOffset = param.c * param.r * param.s; // sts addr - int weight_sts_addr = (tx % 8) * 132 + - (tx / 8) * 4; - int input_sts_addr = (tx / 32) * 128 + (tx % 32); + int weight_sts_addr = (tx & 7) * 132 + + (tx >> 3) * 4; + int input_sts_addr = (warp_id) * 128 + (lane_id); int write_flag = 1; T weight_frag[2][8]; @@ -85,16 +85,16 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, // ldg #pragma unroll for (int i = 0; i < 4; ++i){ - if (tx % 8 < weightKOffset && by * 128 + tx / 8 * 4 + i < param.k){ - weight_ldg_reg[i] = kernel[weiOffset + tx % 8 + i * weightKOffset]; + if (tx % 8 < weightKOffset && by * 128 + (tx >> 3) * 4 + i < param.k){ + weight_ldg_reg[i] = kernel[weiOffset + (tx & 7) + i * weightKOffset]; } else{ weight_ldg_reg[i] = (T)0.f; } } - int curC = (tx / 32) / (param.r * param.s); // channel offset - int curR = ((tx / 32) % (param.r * param.s)) / param.s; // kernel r offset - int curS = ((tx / 32) % (param.r * param.s)) % param.s; // kernel s offset + int curC = (warp_id) / (param.r * param.s); // channel offset + int curR = ((warp_id) % (param.r * param.s)) / param.s; // kernel r offset + int curS = ((warp_id) % (param.r * param.s)) % param.s; // kernel s offset #pragma unroll for (int i = 0; i < 4; ++i){ int curH = posh_ori[i] + curR * param.d_h; // input h @@ -127,21 +127,23 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, input_frag[0][i] = smeminput[input_lds_addr + i]; input_frag[0][i + 4] = smeminput[input_lds_addr + i + 32]; } + + // main loop for (int crs = 0; crs < param.r * param.s * param.c; crs += 8){ // ldg - int weiOffsetTmp = crs + 8 + tx % 8; + int weiOffsetTmp = crs + 8 + (tx & 7); #pragma unroll for (int i = 0; i < 4; ++i){ - if (weiOffsetTmp < weightKOffset && by * 128 + tx / 8 * 4 + i < param.k){ + if (weiOffsetTmp < weightKOffset && by * 128 + (tx >> 3) * 4 + i < param.k){ weight_ldg_reg[i] = kernel[weiOffset + weiOffsetTmp + i * weightKOffset]; } else{ weight_ldg_reg[i] = (T)0.f; } } - curC = (crs + 8 + tx / 32) / (param.r * param.s); // channel offset - curR = ((crs + 8 + tx / 32) % (param.r * param.s)) / param.s; // kernel r offset - curS = ((crs + 8 + tx / 32) % (param.r * param.s)) % param.s; // kernel s offset + curC = (crs + 8 + warp_id) / (param.r * param.s); // channel offset + curR = ((crs + 8 + warp_id) % (param.r * param.s)) / param.s; // kernel r offset + curS = ((crs + 8 + warp_id) % (param.r * param.s)) % param.s; // kernel s offset #pragma unroll for (int i = 0; i < 4; ++i){ @@ -160,13 +162,25 @@ static __global__ void conv2d_implicit_kernel(const float * __restrict__ input, for (int subcrs = 0; subcrs < 8 - 1; ++subcrs){ #pragma unroll for (int i = 0; i < 4; ++i){ - weight_frag[(subcrs + 1) % 2][i] = smemweight[load_flag * 132 * 8 + weight_lds_addr + (subcrs + 1) * 132 + i]; - weight_frag[(subcrs + 1) % 2][i + 4] = smemweight[load_flag * 132 * 8 + weight_lds_addr + (subcrs + 1) * 132 + i + 16]; + weight_frag[(subcrs + 1) & 1][i] = smemweight[load_flag * 132 * 8 + weight_lds_addr + (subcrs + 1) * 132 + i]; + weight_frag[(subcrs + 1) & 1][i + 4] = smemweight[load_flag * 132 * 8 + weight_lds_addr + (subcrs + 1) * 132 + i + 16]; } + // // compute base pointer once + // T* base_ptr = smemweight + load_flag * 132 * 8 + weight_lds_addr + (subcrs + 1) * 132; + + // // first 4 values -> weight_frag[...][0..3] + // float4 v0 = *reinterpret_cast(base_ptr); + + // // next 4 values (offset +16) -> weight_frag[...][4..7] + // float4 v1 = *reinterpret_cast(base_ptr + 16); + + // // unpack into weight_frag + // *reinterpret_cast(&weight_frag[(subcrs + 1) % 2][0]) = v0; + // *reinterpret_cast(&weight_frag[(subcrs + 1) % 2][4]) = v1; #pragma unroll for (int i = 0; i < 4; ++i){ - input_frag[(subcrs + 1) % 2][i] = smeminput[load_flag * 128 * 8 + input_lds_addr + (subcrs + 1) * 128 + i]; - input_frag[(subcrs + 1) % 2][i + 4] = smeminput[load_flag * 128 * 8 + input_lds_addr + (subcrs + 1) * 128 + i + 32]; + input_frag[(subcrs + 1) & 1][i] = smeminput[load_flag * 128 * 8 + input_lds_addr + (subcrs + 1) * 128 + i]; + input_frag[(subcrs + 1) & 1][i + 4] = smeminput[load_flag * 128 * 8 + input_lds_addr + (subcrs + 1) * 128 + i + 32]; } #pragma unroll diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 9171957756..7ce76f0105 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -198,6 +198,7 @@ if (NOT LLAMA_SANITIZE_ADDRESS) endif() llama_build_and_test(test-gguf.cpp) llama_build_and_test(test-backend-ops.cpp) +llama_build_and_test(test-conv2d-implicit.cpp) llama_build_and_test(test-model-load-cancel.cpp LABEL "model") llama_build_and_test(test-autorelease.cpp LABEL "model") diff --git a/tests/test-conv2d-implicit.cpp b/tests/test-conv2d-implicit.cpp new file mode 100644 index 0000000000..b0efba2f1c --- /dev/null +++ b/tests/test-conv2d-implicit.cpp @@ -0,0 +1,390 @@ +#include "ggml.h" +#include "ggml-alloc.h" +#include "ggml-cpu.h" +#include "ggml-backend.h" + +#ifdef GGML_USE_CUDA +#include "ggml-cuda.h" +//#include +#endif + +#ifdef GGML_USE_METAL +#include "ggml-metal.h" +#endif + +#include +#include +#include +#include +#include +#include +#include +#include + +static void ggml_log_callback_default(ggml_log_level level, const char * text, void * user_data) { + (void) level; + (void) user_data; + fputs(text, stderr); + fflush(stderr); +} + +struct test_model { + struct ggml_tensor * a; + struct ggml_tensor * b; + ggml_backend_t backend = NULL; + ggml_backend_buffer_t buffer; + struct ggml_context * ctx; +}; + + + +void load_model(test_model & model, int ic, int oc, int iw, int ih, bool use_gpu = false ) { + // create data + int KW = 3, KH = 3, IC = ic, OC = oc; + int IW = iw, IH = ih, N = 1; + + // printf(" input: IC = %d, OC = %d, IW = %d, IH = %d \n ", IC, OC, IW, IH); + + // Initialize adata + std::vector adata(KW * KH * IC * OC); + for (int i = 0; i < KW * KH * IC * OC; i++) { + adata[i] = 2.5f; + } + + // Convert adata to fp16 format + // std::vector hadata(KW * KH * IC * OC); + // ggml_fp32_to_fp16_row(adata.data(), hadata.data(), KW * KH * IC * OC); + + // Initialize bdata + std::vector bdata(IW * IH * IC * N); + for (int i = 0; i < IW * IH * IC * N; i++) { + bdata[i] = 1.5f; + } + + size_t buffer_size = 0; + { + buffer_size += KW * KH * IC * OC * ggml_type_size(GGML_TYPE_F32); // tensor a + buffer_size += IW * IH * IC * N * ggml_type_size(GGML_TYPE_F32); // tensor b + buffer_size += 1024; // overhead + } + + // printf("%s: ggml tensor size = %d bytes\n", __func__, (int) sizeof(ggml_tensor)); + // printf("%s: backend buffer size = %0.2f MB\n", __func__, (buffer_size/ 1024.f/ 1024.f)); + + int num_tensors = 2; + struct ggml_init_params params { + /*.mem_size =*/ ggml_tensor_overhead() * num_tensors, + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, + }; + + // initialize the backend +#ifdef GGML_USE_CUDA + if (use_gpu) { + // fprintf(stderr, "%s: using CUDA backend\n", __func__); + model.backend = ggml_backend_cuda_init(0); + if (!model.backend) { + fprintf(stderr, "%s: ggml_backend_cuda_init() failed\n", __func__); + } + } +#endif + +#ifdef GGML_USE_METAL + if (use_gpu) { + fprintf(stderr, "%s: using Metal backend\n", __func__); + ggml_backend_metal_log_set_callback(ggml_log_callback_default, nullptr); + model.backend = ggml_backend_metal_init(); + if (!model.backend) { + fprintf(stderr, "%s: ggml_backend_metal_init() failed\n", __func__); + } + } +#endif + + if(!model.backend) { + // fallback to CPU backend + model.backend = ggml_backend_cpu_init(); + } + + model.buffer = ggml_backend_alloc_buffer(model.backend, buffer_size); + + // create context + model.ctx = ggml_init(params); + + // create tensors + model.a = ggml_new_tensor_4d(model.ctx, GGML_TYPE_F32, KW, KH, IC, OC); + model.b = ggml_new_tensor_4d(model.ctx, GGML_TYPE_F32, IW, IH, IC, N); + + // create a allocator + struct ggml_tallocr alloc = ggml_tallocr_new(model.buffer); + + // alloc memory + ggml_tallocr_alloc(&alloc, model.a); + + // load data to buffer + if(ggml_backend_is_cpu(model.backend)) { + memcpy(model.a->data, adata.data(), ggml_nbytes(model.a)); + } else { + ggml_backend_tensor_set(model.a, adata.data(), 0, ggml_nbytes(model.a)); + } + + // alloc memory + ggml_tallocr_alloc(&alloc, model.b); + + if(ggml_backend_is_cpu(model.backend) +#ifdef GGML_USE_METAL + || ggml_backend_is_metal(model.backend) +#endif + ) { + memcpy(model.b->data, bdata.data(), ggml_nbytes(model.b)); + } else { + ggml_backend_tensor_set(model.b, bdata.data(), 0, ggml_nbytes(model.b)); + } +} + +typedef struct ggml_cgraph* (*build_graph_t)(const test_model& model); + +struct ggml_cgraph * build_graph_0(const test_model& model) { + static size_t buf_size = ggml_tensor_overhead()*GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead(); + static std::vector buf(buf_size); + + struct ggml_init_params params0 = { + /*.mem_size =*/ buf_size, + /*.mem_buffer =*/ buf.data(), + /*.no_alloc =*/ true, // the tensors will be allocated later by ggml_gallocr_alloc_graph() + }; + + // create a temporally context to build the graph + struct ggml_context * ctx0 = ggml_init(params0); + + struct ggml_cgraph * gf = ggml_new_graph(ctx0); + + int s0 = 1; + int s1 = 1; + int p0 = 1; + int p1 = 1; + int d0 = 1; + int d1 = 1; + + + + // recalculate for avoid fragmentation + struct ggml_tensor* conv2d_res = ggml_conv_2d(ctx0, model.a, model.b, s0, s1, p0, p1, d0, d1); + ggml_set_name(conv2d_res, "conv2d_res"); + ggml_build_forward_expand(gf, conv2d_res); + // int64_t *ne = conv2d_res->ne; + // printf("conv2d: (%zu, %zu, %zu, %zu) \n", ne[0], ne[1], ne[2], ne[3]); + + + // struct ggml_tensor* wino_res = ggml_conv_2d_3x3(ctx0, model.a, model.b); + // ggml_set_name(wino_res, "wino_res"); + // ggml_build_forward_expand(gf, wino_res); + // ne = wino_res->ne; + // printf("wino: (%zu, %zu, %zu, %zu) \n", ne[0], ne[1], ne[2], ne[3]); + ggml_free(ctx0); + return gf; +} + +struct ggml_cgraph * build_graph_1(const test_model& model) { + static size_t buf_size = ggml_tensor_overhead()*GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead(); + static std::vector buf(buf_size); + + struct ggml_init_params params0 = { + /*.mem_size =*/ buf_size, + /*.mem_buffer =*/ buf.data(), + /*.no_alloc =*/ true, // the tensors will be allocated later by ggml_gallocr_alloc_graph() + }; + + // create a temporally context to build the graph + struct ggml_context * ctx0 = ggml_init(params0); + + struct ggml_cgraph * gf = ggml_new_graph(ctx0); + + int s0 = 1; + int s1 = 1; + int p0 = 1; + int p1 = 1; + int d0 = 1; + int d1 = 1; + + + + // recalculate for avoid fragmentation + // struct ggml_tensor* conv2d_res = ggml_conv_2d(ctx0, model.a, model.b, s0, s1, p0, p1, d0, d1); + // ggml_set_name(conv2d_res, "conv2d_res"); + // ggml_build_forward_expand(gf, conv2d_res); + // int64_t *ne = conv2d_res->ne; + // printf("conv2d: (%zu, %zu, %zu, %zu) \n", ne[0], ne[1], ne[2], ne[3]); + + + struct ggml_tensor* wino_res = ggml_conv_2d_implicitgemm(ctx0, model.a, model.b, s0, s1, p0, p1, d0, d1); + ggml_set_name(wino_res, "wino_res"); + ggml_build_forward_expand(gf, wino_res); + // ne = wino_res->ne; + // printf("wino: (%zu, %zu, %zu, %zu) \n", ne[0], ne[1], ne[2], ne[3]); + ggml_free(ctx0); + return gf; +} + + + + +std::vector compute_graph(const test_model & model, ggml_gallocr_t allocr, + build_graph_t build_graph, int iters, double *t) { + struct ggml_cgraph * gf = build_graph(model); + + + // allocate tensors + ggml_gallocr_alloc_graph(allocr, gf); + int n_threads = 1; + + if (ggml_backend_is_cpu(model.backend)) { + ggml_backend_cpu_set_n_threads(model.backend, n_threads); + } + +#ifdef GGML_USE_METAL + if (ggml_backend_is_metal(model.backend)) { + ggml_backend_metal_set_n_cb(model.backend, n_threads); + } +#endif + + + + ggml_backend_graph_compute(model.backend, gf); + + ggml_backend_synchronize(model.backend); + + int64_t start_time = ggml_time_us(); + + for(int iter=0; iter data(ggml_nelements(res)); + ggml_backend_tensor_get(res, data.data(), 0, ggml_nbytes(res)); + + *t = time_us/1000; + return data; + +} + + +int main(void) +{ + ggml_time_init(); + std::vector> configs = { + std::make_tuple(64,64,48,64), + std::make_tuple(320,320,104,152), + std::make_tuple(640,640,52,76), + std::make_tuple(640,640,104,152), + std::make_tuple(960,320,104,152), + std::make_tuple(1280,1280,26,38), + std::make_tuple(1280,640,52,76), + std::make_tuple(1920,1280,26,38), + std::make_tuple(2560,1280,26,38), + std::make_tuple(512,512,104,152), + std::make_tuple(512,512,208,304), + std::make_tuple(512,256,416,608), + std::make_tuple(256,128,832,1216), + std::make_tuple(256,256,832,1216), + std::make_tuple(320,256,1024,1920) + }; + + int k = 0; + + for (auto c : configs){ + test_model model; + load_model(model, std::get<0>(c), std::get<1>(c), std::get<2>(c), std::get<3>(c), true); + + ggml_gallocr_t allocr = NULL; + allocr = ggml_gallocr_new(ggml_backend_get_default_buffer_type(model.backend)); + + //create the worst case graph for memory usage estimation + struct ggml_cgraph * gf = build_graph_0(model); + + // compute the required memory + ggml_gallocr_reserve(allocr, gf); + size_t mem_size0 = ggml_gallocr_get_buffer_size(allocr, 0); + // fprintf(stderr, "%s: compute buffer size: %.2f MB\n", __func__, mem_size/1024.0f/1024.0f); + + + struct ggml_cgraph * gf_res_0 = NULL; + int iterations = 20; + + double run_time0; + std::vector conv2d_data = compute_graph(model, allocr, build_graph_0, iterations, &run_time0); + + ggml_gallocr_free(allocr); + + allocr = NULL; + + allocr = ggml_gallocr_new(ggml_backend_get_default_buffer_type(model.backend)); + + //create the worst case graph for memory usage estimation + gf = build_graph_1(model); + + // compute the required memory + ggml_gallocr_reserve(allocr, gf); + size_t mem_size1 = ggml_gallocr_get_buffer_size(allocr, 0); + // fprintf(stderr, "%s: compute buffer size: %.2f MB\n", __func__, mem_size/1024.0f/1024.0f); + + + struct ggml_cgraph * gf_res_1 = NULL; + + double run_time1; + std::vector wino_data = compute_graph(model, allocr, build_graph_1, iterations, &run_time1); + + if(k==0) { + k = 1; + fprintf(stderr, "| (IC, OC, IW, IH) | im2col+GEMM TIME | im2col+GEMM VRAM | implicit GEMM TIME | implicit GEMM VRAM \n"); + fprintf(stderr, "| --- | --- | --- | --- | --- \n"); + } + + fprintf(stderr, " | (%d, %d, %d, %d) | %.2f ms | %.2f MB | %.2f ms | %.2f MB\n", + std::get<0>(c), std::get<1>(c), std::get<2>(c), std::get<3>(c), + run_time0, mem_size0/1024.0f/1024.0f, + run_time1, mem_size1/1024.0f/1024.0f); + + + // for(int i = 0; i < ggml_nelements(wino_res); i++) { + // for(int i = 0; i < 3*28; i++) { + // float diff = fabs(conv2d_data[i] - wino_data[i]); + // // if(diff > 1.e-4) { + // printf("(%f, %f, %f, %d) \n", + // conv2d_data[i], + // wino_data[i], diff, i); + // // break; + // // } + // } + + ggml_free(model.ctx); + ggml_backend_buffer_free(model.buffer); + ggml_backend_free(model.backend); + ggml_gallocr_free(allocr); + + } + + + // printf("\nPerforming test:\n"); + + return 0; +}