This commit is contained in:
bssrdf 2025-12-16 17:40:02 -08:00 committed by GitHub
commit dbd806c060
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 2060 additions and 0 deletions

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,402 @@
#pragma once
#include "common.cuh"
typedef struct{
unsigned int n; //batch size
unsigned int c; //number if channels
unsigned int h; //height
unsigned int w; //width
unsigned int d; //depth
unsigned int k; //number of filters
unsigned int r; //filter height
unsigned int s; //filter width
unsigned int t; //filter depth
unsigned int stride0; //stride width
unsigned int stride1; //stride height
unsigned int stride2; //stride depth
unsigned int padding0; //padding width
unsigned int padding1; //padding height
unsigned int padding2; //padding depth
unsigned int dilation0; //dilation width
unsigned int dilation1; //dilation height
unsigned int dilation2; //dilation depth
unsigned int Oh; //output height
unsigned int Ow; //output width
unsigned int Od; //output depth
uint3 SC_fastdiv;
uint3 OW_fastdiv;
uint3 C_fastdiv;
uint3 RS_fastdiv;
uint3 S_fastdiv;
uint3 OHOW_fastdiv;
uint3 PQZ_fastdiv;
uint3 RSC_fastdiv;
uint3 TRS_fastdiv;
} param_t;
template<const int layout>
__device__ __forceinline__ int4 inputIndices(const unsigned int kidx, param_t param) {
const unsigned int cur0 = fastdiv(kidx,
layout == 0 ? param.RSC_fastdiv : param.TRS_fastdiv);
const unsigned int cur0_res = fastmodulo(kidx,
layout == 0 ? param.RSC_fastdiv : param.TRS_fastdiv);
const unsigned int cur1 = fastdiv(cur0_res,
layout == 0 ? param.SC_fastdiv : param.RS_fastdiv);
const unsigned int cur1_res = fastmodulo(cur0_res,
layout == 0 ? param.SC_fastdiv : param.RS_fastdiv);
const unsigned int cur2 = fastdiv(cur1_res,
layout == 0 ? param.C_fastdiv : param.S_fastdiv);
const unsigned int cur3 = fastmodulo(cur1_res,
layout == 0 ? param.C_fastdiv : param.S_fastdiv);
const unsigned int curC = layout == 0 ? cur3 : cur0;
const unsigned int curT = layout == 0 ? cur0 : cur1;
const unsigned int curR = layout == 0 ? cur1 : cur2;
const unsigned int curS = layout == 0 ? cur2 : cur3;
return make_int4(curC, curT, curR, curS);
}
// same as above, but writes are swizzled to avoid bank conflicts when shared memory is read later in the kernel
template<unsigned int TILE_ROWS,
unsigned int NUM_THREADS>
__device__ __forceinline__ void tileMemcpySwizzleB(
const half* src,
half* dst,
const unsigned int start_k,
const unsigned int end_k,
const unsigned int src_stride,
param_t param
){
#if __CUDA_ARCH__ >= GGML_CUDA_TURING
constexpr unsigned int SWIZZLE_MASK_1 = 0b10000;
constexpr unsigned int SWIZZLE_BITS_1 = 4;
constexpr unsigned int SWIZZLE_MASK_2 = 0b1100;
constexpr unsigned int SWIZZLE_BITS_2 = 2;
constexpr unsigned int TILE_COLS = 32;
float4* dst_float4 = reinterpret_cast<float4*>(dst);
// # of threads is multiple of # of columns in the tile
constexpr unsigned int TILE_COLS_VECTORIZED = TILE_COLS / 8;
static_assert(NUM_THREADS % TILE_COLS_VECTORIZED == 0);
// flatten out 2d grid of threads into in order of increasing threadIdx.x
const unsigned int thread_idx = threadIdx.y * blockDim.x + threadIdx.x;
// assign each thread a row/column in the tile, calculate how many iterations we need
// to cover the whole tile
constexpr unsigned int ROW_STEP = NUM_THREADS / TILE_COLS_VECTORIZED;
constexpr unsigned int NUM_ITERS = TILE_ROWS / ROW_STEP;
unsigned int thread_row = thread_idx / TILE_COLS_VECTORIZED;
const unsigned int thread_col = thread_idx % TILE_COLS_VECTORIZED;
const unsigned int kidx = start_k + thread_col*8;
const int4 curIdx = inputIndices<0>(kidx, param);
const int curC = curIdx.x;
const int curT = curIdx.y;
const int curR = curIdx.z;
const int curS = curIdx.w;
#pragma unroll
for (unsigned int i = 0; i < NUM_ITERS; i++){
// apply swizzle to the dst index
const unsigned int src_index = thread_row * src_stride + kidx;
unsigned int dst_index = thread_row * TILE_COLS_VECTORIZED + thread_col;
dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_1) >> SWIZZLE_BITS_1);
dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_2) >> SWIZZLE_BITS_2);
// TODO: move some checks outside of loop?
if (thread_row + blockIdx.x * TILE_ROWS < param.k && curR < param.r && curS < param.s && curT < param.t && curC < param.c && kidx < end_k){
dst_float4[dst_index] = reinterpret_cast<const float4 *>(&src[src_index])[0];
}else{ // read 4 halves
dst_float4[dst_index] = make_float4(0.f, 0.f, 0.f, 0.f);
}
thread_row += ROW_STEP;
}
#else
GGML_UNUSED(src);
GGML_UNUSED(dst);
GGML_UNUSED(src_stride);
GGML_UNUSED(param);
NO_DEVICE_CODE;
#endif
}
// this is a special case of the above for when TILE_COLS == 32
template<unsigned int TILE_ROWS,
unsigned int NUM_THREADS>
__device__ __forceinline__ void tileMemcpySwizzleA(
const half* src,
half* dst,
const unsigned int start_k,
const unsigned int end_k,
const unsigned int inNOffset,
const unsigned int inDepthOffset,
const unsigned int inChannelOffset,
param_t param
)
{
#if __CUDA_ARCH__ >= GGML_CUDA_TURING
constexpr unsigned int SWIZZLE_MASK_1 = 0b10000;
constexpr unsigned int SWIZZLE_BITS_1 = 4;
constexpr unsigned int SWIZZLE_MASK_2 = 0b1100;
constexpr unsigned int SWIZZLE_BITS_2 = 2;
constexpr unsigned int TILE_COLS = 32;
float4* dst_float4 = reinterpret_cast<float4*>(dst);
// # of threads is multiple of # of columns in the tile
constexpr unsigned int TILE_COLS_VECTORIZED = TILE_COLS / 8;
static_assert(NUM_THREADS % TILE_COLS_VECTORIZED == 0);
// flatten out 2d grid of threads into in order of increasing threadIdx.x
const unsigned int thread_idx = threadIdx.y * blockDim.x + threadIdx.x;
// assign each thread a row/column in the tile, calculate how many iterations we need
// to cover the whole tile
constexpr unsigned int ROW_STEP = NUM_THREADS / TILE_COLS_VECTORIZED;
constexpr unsigned int NUM_ITERS = TILE_ROWS / ROW_STEP;
unsigned int thread_row = thread_idx / TILE_COLS_VECTORIZED;
const unsigned int thread_col = thread_idx % TILE_COLS_VECTORIZED;
const unsigned int kidx = start_k+thread_col*8;
const int4 curIdx = inputIndices<0>(kidx, param);
#pragma unroll
for (unsigned int i = 0; i < NUM_ITERS; i++){
unsigned int gemm_i = blockIdx.y * TILE_ROWS + thread_row;
unsigned int n = fastdiv(gemm_i, param.PQZ_fastdiv);
const unsigned int npqz_res = fastmodulo(gemm_i, param.PQZ_fastdiv);
const int posd_ori = fastdiv(npqz_res, param.OHOW_fastdiv) * param.stride2 - param.padding2;
const int ohow_res = fastmodulo(npqz_res, param.OHOW_fastdiv);
const int posh_ori = fastdiv(ohow_res, param.OW_fastdiv) * param.stride1 - param.padding1;
const int posw_ori = fastmodulo(ohow_res, param.OW_fastdiv) * param.stride0 - param.padding0;
const int curD = posd_ori + curIdx.y * param.dilation2; // input d
const int curH = posh_ori + curIdx.z * param.dilation1; // input h
const int curW = posw_ori + curIdx.w * param.dilation0; // input w
const int curC = curIdx.x;
// apply swizzle to the dst index
unsigned int dst_index = thread_row * TILE_COLS_VECTORIZED + thread_col;
dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_1) >> SWIZZLE_BITS_1);
dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_2) >> SWIZZLE_BITS_2);
if (curH >= 0 && curW >= 0 && curD >= 0 && curW < param.w && curH < param.h && curD < param.d &&
n < param.n && curC < param.c && kidx < end_k){
int inOffsetTmp = curD * inDepthOffset + curH * inChannelOffset + curW * param.c + curC;
dst_float4[dst_index] = reinterpret_cast<const float4 *>(&src[n * inNOffset + inOffsetTmp])[0];
} else{
dst_float4[dst_index] = make_float4(0.f, 0.f, 0.f, 0.f);
}
thread_row += ROW_STEP;
}
#else
GGML_UNUSED(src);
GGML_UNUSED(dst);
GGML_UNUSED(inChannelOffset);
GGML_UNUSED(param);
NO_DEVICE_CODE;
#endif
}
template<unsigned int TILE_ROWS,
unsigned int TILE_COLS,
unsigned int NUM_THREADS,
unsigned int ELEMENTS_PER_THREAD>
__device__ __forceinline__ void tileMemcpyLoadA(
const half* src,
float4 (&dst_reg)[ELEMENTS_PER_THREAD],
// const unsigned int src_stride,
const unsigned int block_k,
const unsigned int start_k,
const unsigned int end_k,
const unsigned int inNOffset,
const unsigned int inDepthOffset,
const unsigned int inChannelOffset,
param_t param
){
#if __CUDA_ARCH__ >= GGML_CUDA_TURING
// # of threads is multiple of # of columns in the tile
constexpr unsigned int TILE_COLS_VECTORIZED = TILE_COLS / 8;
static_assert(NUM_THREADS % TILE_COLS_VECTORIZED == 0);
// flatten out 2d grid of threads into in order of increasing threadIdx.x
const unsigned int thread_idx = threadIdx.y * blockDim.x + threadIdx.x;
// assign each thread a row/column in the tile, calculate how many iterations we need
// to cover the whole tile
constexpr unsigned int ROW_STEP = NUM_THREADS / TILE_COLS_VECTORIZED;
constexpr unsigned int NUM_ITERS = TILE_ROWS / ROW_STEP;
unsigned int thread_row = thread_idx / TILE_COLS_VECTORIZED;
const unsigned int thread_col = thread_idx % TILE_COLS_VECTORIZED;
// compile time check that we provided the right amount of registers for storage
static_assert(ELEMENTS_PER_THREAD == NUM_ITERS);
const unsigned int kidx = start_k + block_k + thread_col*8;
const int4 curIdx = inputIndices<0>(kidx, param);
#pragma unroll
for (unsigned int i = 0; i < NUM_ITERS; i++){
unsigned int gemm_i = blockIdx.y * TILE_ROWS + thread_row;
unsigned int n = fastdiv(gemm_i, param.PQZ_fastdiv);
const unsigned int npqz_res = fastmodulo(gemm_i, param.PQZ_fastdiv);
const int posd_ori = fastdiv(npqz_res, param.OHOW_fastdiv) * param.stride2 - param.padding2;
const int ohow_res = fastmodulo(npqz_res, param.OHOW_fastdiv);
const int posh_ori = fastdiv(ohow_res, param.OW_fastdiv) * param.stride1 - param.padding1;
const int posw_ori = fastmodulo(ohow_res, param.OW_fastdiv) * param.stride0 - param.padding0;
const int curD = posd_ori + curIdx.y * param.dilation2; // input d
const int curH = posh_ori + curIdx.z * param.dilation1; // input h
const int curW = posw_ori + curIdx.w * param.dilation0; // input w
const int curC = curIdx.x;
if (curH >= 0 && curW >= 0 && curD >= 0 && curW < param.w && curH < param.h && curD < param.d
&& n < param.n && curC < param.c && kidx < end_k){
int inOffsetTmp = curD * inDepthOffset + curH * inChannelOffset + curW * param.c + curC;
dst_reg[i] = reinterpret_cast<const float4 *>(&src[n * inNOffset + inOffsetTmp])[0];
} else{
dst_reg[i] = make_float4(0.f, 0.f, 0.f, 0.f);
}
thread_row += ROW_STEP;
}
#else
GGML_UNUSED(src);
GGML_UNUSED(dst_reg);
GGML_UNUSED(block_k);
GGML_UNUSED(inChannelOffset);
GGML_UNUSED(param);
NO_DEVICE_CODE;
#endif
}
template<unsigned int TILE_ROWS,
unsigned int TILE_COLS,
unsigned int NUM_THREADS,
unsigned int ELEMENTS_PER_THREAD>
__device__ __forceinline__ void tileMemcpyLoadB(
const half* src,
float4 (&dst_reg)[ELEMENTS_PER_THREAD],
const unsigned int block_k,
const unsigned int start_k,
const unsigned int end_k,
const unsigned int src_stride,
param_t param
){
#if __CUDA_ARCH__ >= GGML_CUDA_TURING
// # of threads is multiple of # of columns in the tile
constexpr unsigned int TILE_COLS_VECTORIZED = TILE_COLS / 8;
static_assert(NUM_THREADS % TILE_COLS_VECTORIZED == 0);
// flatten out 2d grid of threads into in order of increasing threadIdx.x
const unsigned int thread_idx = threadIdx.y * blockDim.x + threadIdx.x;
// assign each thread a row/column in the tile, calculate how many iterations we need
// to cover the whole tile
constexpr unsigned int ROW_STEP = NUM_THREADS / TILE_COLS_VECTORIZED;
constexpr unsigned int NUM_ITERS = TILE_ROWS / ROW_STEP;
unsigned int thread_row = thread_idx / TILE_COLS_VECTORIZED;
const unsigned int thread_col = thread_idx % TILE_COLS_VECTORIZED;
// compile time check that we provided the right amount of registers for storage
static_assert(ELEMENTS_PER_THREAD == NUM_ITERS);
const unsigned int kidx = start_k + block_k + thread_col*8;
const int4 curIdx = inputIndices<0>(kidx, param);
const int curC = curIdx.x;
const int curT = curIdx.y;
const int curR = curIdx.z;
const int curS = curIdx.w;
#pragma unroll
for (unsigned int i = 0; i < NUM_ITERS; i++){
const unsigned int src_index = thread_row * src_stride + kidx;
// TODO : move some checks outside of the loop
if (thread_row + blockIdx.x * TILE_ROWS < param.k && curR < param.r && curS < param.s && curT < param.t
&& curC < param.c && kidx < end_k){
dst_reg[i] = reinterpret_cast<const float4 *>(&src[src_index])[0];
}else{ // read 4 halves
dst_reg[i] = make_float4(0.f, 0.f, 0.f, 0.f);
}
thread_row += ROW_STEP;
}
#else
GGML_UNUSED(src);
GGML_UNUSED(dst_reg);
GGML_UNUSED(block_k);
GGML_UNUSED(src_stride);
GGML_UNUSED(param);
NO_DEVICE_CODE;
#endif
}
// same as above but without the swizzle
// this is a special case of the above for when TILE_COLS == 32
template<unsigned int TILE_ROWS,
unsigned int NUM_THREADS,
unsigned int ELEMENTS_PER_THREAD>
__device__ __forceinline__ void tileMemcpySwizzleStore(
const float4 (&src_reg)[ELEMENTS_PER_THREAD],
half* dst
)
{
#if __CUDA_ARCH__ >= GGML_CUDA_TURING
constexpr unsigned int SWIZZLE_MASK_1 = 0b10000;
constexpr unsigned int SWIZZLE_BITS_1 = 4;
constexpr unsigned int SWIZZLE_MASK_2 = 0b1100;
constexpr unsigned int SWIZZLE_BITS_2 = 2;
constexpr unsigned int TILE_COLS = 32;
// reinterpret input/output as float4
float4* dst_float4 = reinterpret_cast<float4*>(dst);
// # of threads is multiple of # of columns in the tile
constexpr unsigned int TILE_COLS_VECTORIZED = TILE_COLS / 8;
static_assert(NUM_THREADS % TILE_COLS_VECTORIZED == 0);
// flatten out 2d grid of threads into in order of increasing threadIdx.x
const unsigned int thread_idx = threadIdx.y * blockDim.x + threadIdx.x;
// assign each thread a row/column in the tile, calculate how many iterations we need
// to cover the whole tile
constexpr unsigned int ROW_STEP = NUM_THREADS / TILE_COLS_VECTORIZED;
constexpr unsigned int NUM_ITERS = TILE_ROWS / ROW_STEP;
unsigned int thread_row = thread_idx / TILE_COLS_VECTORIZED;
const unsigned int thread_col = thread_idx % TILE_COLS_VECTORIZED;
// compile time check that we provided the right amount of registers for storage
static_assert(ELEMENTS_PER_THREAD == NUM_ITERS);
#pragma unroll
for (unsigned int i = 0; i < NUM_ITERS; i++)
{
// apply swizzle to the dst index
unsigned int dst_index = thread_row * TILE_COLS_VECTORIZED + thread_col;
dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_1) >> SWIZZLE_BITS_1);
dst_index = dst_index ^ ((dst_index & SWIZZLE_MASK_2) >> SWIZZLE_BITS_2);
dst_float4[dst_index] = src_reg[i];
thread_row += ROW_STEP;
}
#else
GGML_UNUSED(src_reg);
GGML_UNUSED(dst);
NO_DEVICE_CODE;
#endif
}
__device__ __forceinline__ uint32_t cvta_to_shared_u32(const void *pointer) {
uint32_t address;
asm("{\n\t"
" .reg .u64 u64addr;\n\t"
" cvta.to.shared.u64 u64addr, %1;\n\t"
" cvt.u32.u64 %0, u64addr;\n\t"
"}"
: "=r"(address)
: "l"(pointer));
return address;
}
#define CUDA_CONV3D_IMPLICT_BLOCK_SIZE 256
void ggml_cuda_op_conv3d_implicit(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

View File

@ -13,6 +13,7 @@
#include "ggml-cuda/concat.cuh"
#include "ggml-cuda/conv-transpose-1d.cuh"
#include "ggml-cuda/conv2d.cuh"
#include "ggml-cuda/conv3d-implicit.cuh"
#include "ggml-cuda/conv2d-dw.cuh"
#include "ggml-cuda/conv2d-transpose.cuh"
#include "ggml-cuda/convert.cuh"
@ -2672,6 +2673,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_OP_CONV_2D:
ggml_cuda_op_conv2d(ctx, dst);
break;
case GGML_OP_CONV_3D:
ggml_cuda_op_conv3d_implicit(ctx, dst);
break;
case GGML_OP_CONV_2D_DW:
ggml_cuda_op_conv2d_dw(ctx, dst);
break;
@ -4593,6 +4597,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_OP_IM2COL:
case GGML_OP_IM2COL_3D:
case GGML_OP_CONV_2D:
case GGML_OP_CONV_3D:
case GGML_OP_CONV_2D_DW:
case GGML_OP_CONV_TRANSPOSE_2D:
case GGML_OP_POOL_2D:

View File

@ -218,6 +218,7 @@ if (NOT LLAMA_SANITIZE_ADDRESS AND NOT GGML_SCHED_NO_REALLOC)
endif()
llama_build_and_test(test-gguf.cpp)
llama_build_and_test(test-backend-ops.cpp)
llama_build_and_test(test-conv3d.cpp)
llama_build_and_test(test-model-load-cancel.cpp LABEL "model")
llama_build_and_test(test-autorelease.cpp LABEL "model")

View File

@ -8043,6 +8043,51 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
}
}
// for (ggml_type kernel_type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
for (ggml_type kernel_type : {GGML_TYPE_F16}) {
for (int N : {1}) {
for (int IC : {48, 320, 640, 1024}) {
for (int OC : {320, 640, 1024, 2048}) {
for (int s0 : {1}) {
for (int p1 : {1}) {
for (int d2 : {1}) {
int64_t IW = 26, IH = 38, ID = 8;
int64_t KW = 3, KH = 3, KD = 3;
int s1 = s0, s2 = s0;
int p0 = p1, p2 = p1;
int d0 = d2, d1 = d2;
test_cases.emplace_back(new test_conv_3d(
N, IC, ID, IH, IW,
OC, KD, KH, KW,
s0, s1, s2, p0, p1, p2, d0, d1, d2,
kernel_type));
IW = 52; IH = 76;
test_cases.emplace_back(new test_conv_3d(
N, IC, ID, IH, IW,
OC, KD, KH, KW,
s0, s1, s2, p0, p1, p2, d0, d1, d2,
kernel_type));
IW = 104; IH = 158;
test_cases.emplace_back(new test_conv_3d(
N, IC, ID, IH, IW,
OC, KD, KH, KW,
s0, s1, s2, p0, p1, p2, d0, d1, d2,
kernel_type));
// IW = 208; IH = 316;
// test_cases.emplace_back(new test_conv_3d(
// N, IC, ID, IH, IW,
// OC, KD, KH, KW,
// s0, s1, s2, p0, p1, p2, d0, d1, d2,
// kernel_type));
}
}
}
}
}
}
}
test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {4096, 1, 1, 1}, {1, 1, 1, 1}));
test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {4096, 1, 1, 1}, {1, 512, 1, 1}));

443
tests/test-conv3d.cpp Normal file
View File

@ -0,0 +1,443 @@
#include "ggml.h"
#include "ggml-alloc.h"
#include "ggml-cpu.h"
#include "ggml-backend.h"
#ifdef GGML_USE_CUDA
#include "ggml-cuda.h"
//#include <cuda_runtime.h>
#endif
#ifdef GGML_USE_METAL
#include "ggml-metal.h"
#endif
#include <cassert>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <fstream>
#include <map>
#include <string>
#include <vector>
static void ggml_log_callback_default(ggml_log_level level, const char * text, void * user_data) {
(void) level;
(void) user_data;
fputs(text, stderr);
fflush(stderr);
}
struct test_model {
struct ggml_tensor * a;
struct ggml_tensor * b;
ggml_backend_t backend = NULL;
ggml_backend_buffer_t buffer;
struct ggml_context * ctx;
};
void load_model(test_model & model, int ic, int oc, int iw, int ih, int id,
int kw, int kh, int kd,
bool use_fp16, bool use_gpu);
struct ggml_cgraph * build_graph_0(const test_model& model, const int64_t ic, const int64_t n, const int64_t oc);
struct ggml_cgraph * build_graph_1(const test_model& model, const int64_t ic, const int64_t n, const int64_t oc);
typedef struct ggml_cgraph* (*build_graph_t)(const test_model& model,
const int64_t i0, const int64_t i1, const int64_t i2);
std::vector<float> compute_graph(const test_model & model, ggml_gallocr_t allocr,
build_graph_t build_graph, int iters,
const int64_t ic, const int64_t n, const int64_t oc, double *t);
void load_model(test_model & model, int ic, int oc, int iw, int ih, int id,
int kw = 3, int kh = 3, int kd = 3,
bool use_fp16 = true, bool use_gpu = false ) {
// create data
int KW = kw, KH = kh, KD = kd;
int IC = ic, OC = oc;
int IW = iw, IH = ih, ID = id, N = 1;
srand(time(NULL));
// printf(" input: IC = %d, OC = %d, IW = %d, IH = %d \n ", IC, OC, IW, IH);
// Initialize adata
std::vector<float> adata(KW * KH * KD * IC * OC);
for (int i = 0; i < KW * KH * KD * IC * OC; i++) {
// adata[i] = 2.f;
// adata[i] = (float)(i%KW)-1.f;
// adata[i] = (rand() % 255) / 255.0;
float r = -1.f + static_cast <float> (rand()) /( static_cast <float> (RAND_MAX/(1.f-(-1.f))));
adata[i] = r;
}
// Convert adata to fp16 format
std::vector<ggml_fp16_t> hadata(KW * KH * KD * IC * OC);
ggml_fp32_to_fp16_row(adata.data(), hadata.data(), KW * KH * KD * IC * OC);
// Initialize bdata
std::vector<float> bdata(IW * IH * ID * IC * N);
for (int i = 0; i < IW * IH * ID * IC * N; i++) {
// bdata[i] = (float)(i%IW)/10.f;
// bdata[i] = 1.5f;
// bdata[i] = (rand() % 255) / 255.0;
float r = -1.f + static_cast <float> (rand()) /( static_cast <float> (RAND_MAX/(1.f-(-1.f))));
bdata[i] = r;
}
size_t buffer_size = 0;
{ if(use_fp16)
buffer_size += KW * KH * KD * IC * OC * ggml_type_size(GGML_TYPE_F16); // tensor a
else
buffer_size += KW * KH * KD * IC * OC * ggml_type_size(GGML_TYPE_F32); // tensor a
buffer_size += IW * IH * ID * IC * N * ggml_type_size(GGML_TYPE_F32); // tensor b
buffer_size += 1024; // overhead
}
// printf("%s: ggml tensor size = %d bytes\n", __func__, (int) sizeof(ggml_tensor));
// printf("%s: backend buffer size = %0.2f MB\n", __func__, (buffer_size/ 1024.f/ 1024.f));
int num_tensors = 2;
struct ggml_init_params params {
/*.mem_size =*/ ggml_tensor_overhead() * num_tensors,
/*.mem_buffer =*/ NULL,
/*.no_alloc =*/ true,
};
// initialize the backend
#ifdef GGML_USE_CUDA
if (use_gpu) {
// fprintf(stderr, "%s: using CUDA backend\n", __func__);
model.backend = ggml_backend_cuda_init(0);
if (!model.backend) {
fprintf(stderr, "%s: ggml_backend_cuda_init() failed\n", __func__);
}
}
#else
GGML_UNUSED(use_gpu);
#endif
#ifdef GGML_USE_METAL
if (use_gpu) {
fprintf(stderr, "%s: using Metal backend\n", __func__);
model.backend = ggml_backend_metal_init();
if (!model.backend) {
fprintf(stderr, "%s: ggml_backend_metal_init() failed\n", __func__);
}
}
#else
GGML_UNUSED(use_gpu);
#endif
if(!model.backend) {
// fallback to CPU backend
model.backend = ggml_backend_cpu_init();
}
model.buffer = ggml_backend_alloc_buffer(model.backend, buffer_size);
// create context
model.ctx = ggml_init(params);
// create tensors
if(use_fp16)
model.a = ggml_new_tensor_4d(model.ctx, GGML_TYPE_F16, KW, KH, KD, IC*OC);
else
model.a = ggml_new_tensor_4d(model.ctx, GGML_TYPE_F32, KW, KH, KD, IC*OC);
model.b = ggml_new_tensor_4d(model.ctx, GGML_TYPE_F32, IW, IH, ID, IC*N);
// create a allocator
struct ggml_tallocr alloc = ggml_tallocr_new(model.buffer);
// alloc memory
ggml_tallocr_alloc(&alloc, model.a);
// load data to buffer
if(ggml_backend_is_cpu(model.backend)) {
if(use_fp16)
memcpy(model.a->data, hadata.data(), ggml_nbytes(model.a));
else
memcpy(model.a->data, adata.data(), ggml_nbytes(model.a));
} else {
if(use_fp16)
ggml_backend_tensor_set(model.a, hadata.data(), 0, ggml_nbytes(model.a));
else
ggml_backend_tensor_set(model.a, adata.data(), 0, ggml_nbytes(model.a));
}
// alloc memory
ggml_tallocr_alloc(&alloc, model.b);
if(ggml_backend_is_cpu(model.backend)
#ifdef GGML_USE_METAL
|| ggml_backend_is_metal(model.backend)
#endif
) {
memcpy(model.b->data, bdata.data(), ggml_nbytes(model.b));
} else {
ggml_backend_tensor_set(model.b, bdata.data(), 0, ggml_nbytes(model.b));
}
}
struct ggml_cgraph * build_graph_0(const test_model& model, const int64_t ic, const int64_t n, const int64_t oc) {
GGML_UNUSED(n);
GGML_UNUSED(oc);
static size_t buf_size = ggml_tensor_overhead()*GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead();
static std::vector<uint8_t> buf(buf_size);
struct ggml_init_params params0 = {
/*.mem_size =*/ buf_size,
/*.mem_buffer =*/ buf.data(),
/*.no_alloc =*/ true, // the tensors will be allocated later by ggml_gallocr_alloc_graph()
};
// create a temporally context to build the graph
struct ggml_context * ctx0 = ggml_init(params0);
struct ggml_cgraph * gf = ggml_new_graph(ctx0);
// int s0 = 2;
// int s1 = 1;
// int s2 = 1;
// int p0 = 2;
// int p1 = 0;
// int p2 = 1;
// int d0 = 1;
// int d1 = 1;
// int d2 = 2;
int s0 = 1;
int s1 = 1;
int s2 = 1;
int p0 = 1;
int p1 = 1;
int p2 = 1;
int d0 = 1;
int d1 = 1;
int d2 = 1;
// recalculate for avoid fragmentation
struct ggml_tensor* conv2d_res = ggml_conv_3d(ctx0, model.a, model.b, ic, s0, s1, s2, p0, p1, p2, d0, d1, d2);
ggml_set_name(conv2d_res, "conv2d_res");
ggml_build_forward_expand(gf, conv2d_res);
// int64_t *ne = conv2d_res->ne;
// printf("conv2d: (%zu, %zu, %zu, %zu) \n", ne[0], ne[1], ne[2], ne[3]);
// struct ggml_tensor* wino_res = ggml_conv_2d_3x3(ctx0, model.a, model.b);
// ggml_set_name(wino_res, "wino_res");
// ggml_build_forward_expand(gf, wino_res);
// ne = wino_res->ne;
// printf("wino: (%zu, %zu, %zu, %zu) \n", ne[0], ne[1], ne[2], ne[3]);
ggml_free(ctx0);
return gf;
}
struct ggml_cgraph * build_graph_1(const test_model& model, const int64_t ic, const int64_t n, const int64_t oc) {
static size_t buf_size = ggml_tensor_overhead()*GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead();
static std::vector<uint8_t> buf(buf_size);
struct ggml_init_params params0 = {
/*.mem_size =*/ buf_size,
/*.mem_buffer =*/ buf.data(),
/*.no_alloc =*/ true, // the tensors will be allocated later by ggml_gallocr_alloc_graph()
};
// create a temporally context to build the graph
struct ggml_context * ctx0 = ggml_init(params0);
struct ggml_cgraph * gf = ggml_new_graph(ctx0);
int s0 = 1;
int s1 = 1;
int s2 = 1;
int p0 = 1;
int p1 = 1;
int p2 = 1;
int d0 = 1;
int d1 = 1;
int d2 = 1;
// int s0 = 2;
// int s1 = 1;
// int s2 = 1;
// int p0 = 2;
// int p1 = 0;
// int p2 = 1;
// int d0 = 1;
// int d1 = 1;
// int d2 = 2;
// recalculate for avoid fragmentation
// struct ggml_tensor* conv2d_res = ggml_conv_2d(ctx0, model.a, model.b, s0, s1, p0, p1, d0, d1);
// ggml_set_name(conv2d_res, "conv2d_res");
// ggml_build_forward_expand(gf, conv2d_res);
// int64_t *ne = conv2d_res->ne;
// printf("conv2d: (%zu, %zu, %zu, %zu) \n", ne[0], ne[1], ne[2], ne[3]);
// struct ggml_tensor* wino_res = ggml_conv_2d_implicitgemm(ctx0, model.a, model.b, s0, s1, p0, p1, d0, d1);
struct ggml_tensor* wino_res = ggml_conv_3d_direct(ctx0, model.a, model.b,
s0, s1, s2, p0, p1, p2, d0, d1, d2,
ic, n, oc);
ggml_set_name(wino_res, "wino_res");
ggml_build_forward_expand(gf, wino_res);
// int64_t *ne = wino_res->ne;
// printf("wino: (%zu, %zu, %zu, %zu) \n", ne[0], ne[1], ne[2], ne[3]);
ggml_free(ctx0);
return gf;
}
std::vector<float> compute_graph(const test_model & model, ggml_gallocr_t allocr,
build_graph_t build_graph, int iters,
const int64_t ic, const int64_t n, const int64_t oc, double *t) {
struct ggml_cgraph * gf = build_graph(model, ic, n, oc);
// allocate tensors
ggml_gallocr_alloc_graph(allocr, gf);
int n_threads = 1;
if (ggml_backend_is_cpu(model.backend)) {
ggml_backend_cpu_set_n_threads(model.backend, n_threads);
}
ggml_backend_graph_compute(model.backend, gf);
ggml_backend_synchronize(model.backend);
int64_t start_time = ggml_time_us();
for(int iter=0; iter<iters; iter++){
ggml_backend_graph_compute(model.backend, gf);
ggml_backend_synchronize(model.backend);
}
// ggml_backend_synchronize(model.backend);
int64_t end_time = ggml_time_us();
double time_us = end_time - start_time;
time_us = time_us/iters;
//ggml_graph_print(gf);
struct ggml_tensor *res = NULL;
for(int i = 0; i < ggml_graph_n_nodes(gf); ++i) {
if(strcmp(ggml_get_name(ggml_graph_node(gf, i)), "wino_res") == 0) {
res = ggml_graph_node(gf, i);
} else if(strcmp(ggml_get_name(ggml_graph_node(gf, i)), "conv2d_res") == 0) {
res = ggml_graph_node(gf, i);
}
}
std::vector<float> data(ggml_nelements(res));
ggml_backend_tensor_get(res, data.data(), 0, ggml_nbytes(res));
*t = time_us/1000;
return data;
}
int main(void)
{
ggml_time_init();
std::vector<std::tuple<int, int, int, int, int, int, int, int>> configs = {
// std::make_tuple(1,2,16,32,4,3,3,3),
// std::make_tuple(320,1280,26,38,8,3,3,3),
// std::make_tuple(1280,1280,26,38,8,3,3,3),
// std::make_tuple(320,1280,52,76,8,3,3,3),
// std::make_tuple(1280,1280,52,76,8,3,3,3),
// std::make_tuple(320,1280,104,152,8,3,3,3),
// std::make_tuple(1280,1280,104,152,8,3,3,3),
// std::make_tuple(320,1280,208,304,4,3,3,3),
// std::make_tuple(1024,2048,30,52,3,3,3,3),
// std::make_tuple(1024,2048,52,76,4,3,3,3),
// std::make_tuple(1024,2048,52,76,6,3,3,3),
// std::make_tuple(48,3072,64,64,9,2,2,1),
// std::make_tuple(48,3072,64,64,17,2,2,1),
// std::make_tuple(48,3072,64,64,33,2,2,1),
std::make_tuple(320,320,104,158,8,3,3,3),
};
int k = 0;
for (auto c : configs){
test_model model;
load_model(model, std::get<0>(c), std::get<1>(c), std::get<2>(c),
std::get<3>(c), std::get<4>(c), std::get<5>(c), std::get<6>(c), std::get<7>(c), true, true);
ggml_gallocr_t allocr = NULL;
allocr = ggml_gallocr_new(ggml_backend_get_default_buffer_type(model.backend));
//create the worst case graph for memory usage estimation
struct ggml_cgraph * gf = build_graph_0(model, std::get<0>(c), 0, 0);
// compute the required memory
ggml_gallocr_reserve(allocr, gf);
size_t mem_size0 = ggml_gallocr_get_buffer_size(allocr, 0);
// fprintf(stderr, "%s: compute buffer size: %.2f MB\n", __func__, mem_size/1024.0f/1024.0f);
int iterations = 20;
double run_time0;
std::vector<float> im2col_data = compute_graph(model, allocr, build_graph_0, iterations,
std::get<0>(c), 1, std::get<1>(c), &run_time0);
ggml_gallocr_free(allocr);
allocr = NULL;
allocr = ggml_gallocr_new(ggml_backend_get_default_buffer_type(model.backend));
//create the worst case graph for memory usage estimation
gf = build_graph_1(model, std::get<0>(c), 1, std::get<1>(c));
// compute the required memory
ggml_gallocr_reserve(allocr, gf);
size_t mem_size1 = ggml_gallocr_get_buffer_size(allocr, 0);
double run_time1;
std::vector<float> conv2d_data = compute_graph(model, allocr, build_graph_1, iterations,
std::get<0>(c), 1, std::get<1>(c), &run_time1);
if(k==0) {
k = 1;
fprintf(stderr, "| (IC, OC, IW, IH, ID, KW, KH, KD) | im2col+GEMM TIME | im2col+GEMM VRAM | implicit GEMM TIME | implicit GEMM VRAM \n");
fprintf(stderr, "| --- | --- | --- | --- | --- \n");
}
fprintf(stderr, " | (%d, %d, %d, %d, %d, %d, %d, %d) | %.2f ms | %.2f MB | %.2f ms | %.2f MB\n",
std::get<0>(c), std::get<1>(c), std::get<2>(c),
std::get<3>(c), std::get<4>(c), std::get<5>(c),
std::get<6>(c), std::get<7>(c),
run_time0, mem_size0/1024.0f/1024.0f,
run_time1, mem_size1/1024.0f/1024.0f);
// for(int i = 0; i < conv2d_data.size(); i++) {
// float diff = fabs(im2col_data[i] - conv2d_data[i]);
// // if(diff > 0.5) {
// printf("(%7.3f, %7.3f, %f, %d) \n",
// im2col_data[i], conv2d_data[i],
// diff, i);
// // break;
// // }
// }
ggml_free(model.ctx);
ggml_backend_buffer_free(model.buffer);
ggml_backend_free(model.backend);
ggml_gallocr_free(allocr);
}
return 0;
}