refactor mmf rows_per_block
This commit is contained in:
parent
6e36299b47
commit
fe24848d00
|
|
@ -2,6 +2,12 @@
|
||||||
#include "mmf.cuh"
|
#include "mmf.cuh"
|
||||||
#include "mmid.cuh"
|
#include "mmid.cuh"
|
||||||
|
|
||||||
|
constexpr int mmf_rows_per_block = 32;
|
||||||
|
constexpr int mmf_rows_per_block_cdna = 64;
|
||||||
|
|
||||||
|
static int get_mmf_rows_per_block(const int cc, const int warp_size) {
|
||||||
|
return warp_size;
|
||||||
|
}
|
||||||
|
|
||||||
void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) {
|
void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) {
|
||||||
GGML_ASSERT( src1->type == GGML_TYPE_F32);
|
GGML_ASSERT( src1->type == GGML_TYPE_F32);
|
||||||
|
|
@ -89,30 +95,60 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr
|
||||||
ids_info_ptr = &ids_info;
|
ids_info_ptr = &ids_info;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const int device = ggml_cuda_get_device();
|
||||||
|
const int cc = ggml_cuda_info().devices[device].cc;
|
||||||
|
const int warp_size = ggml_cuda_info().devices[device].warp_size;
|
||||||
|
const int rows_per_block = get_mmf_rows_per_block(cc, warp_size);
|
||||||
|
|
||||||
|
if (rows_per_block != mmf_rows_per_block && rows_per_block != mmf_rows_per_block_cdna) {
|
||||||
|
GGML_ABORT("unsupported rows_per_block: %i", rows_per_block);
|
||||||
|
}
|
||||||
|
|
||||||
switch (src0->type) {
|
switch (src0->type) {
|
||||||
case GGML_TYPE_F32: {
|
case GGML_TYPE_F32: {
|
||||||
const float * src0_d = (const float *) src0->data;
|
const float * src0_d = (const float *) src0->data;
|
||||||
constexpr int vals_per_T = 1;
|
constexpr int vals_per_T = 1;
|
||||||
mul_mat_f_switch_cols_per_block(
|
if (rows_per_block == mmf_rows_per_block) {
|
||||||
src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
|
mul_mat_f_switch_cols_per_block<float, mmf_rows_per_block>(
|
||||||
ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
|
src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
|
||||||
ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr);
|
ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
|
||||||
|
ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr);
|
||||||
|
} else {
|
||||||
|
mul_mat_f_switch_cols_per_block<float, mmf_rows_per_block_cdna>(
|
||||||
|
src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
|
||||||
|
ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
|
||||||
|
ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr);
|
||||||
|
}
|
||||||
} break;
|
} break;
|
||||||
case GGML_TYPE_F16: {
|
case GGML_TYPE_F16: {
|
||||||
const half2 * src0_d = (const half2 *) src0->data;
|
const half2 * src0_d = (const half2 *) src0->data;
|
||||||
constexpr int vals_per_T = 2;
|
constexpr int vals_per_T = 2;
|
||||||
mul_mat_f_switch_cols_per_block(
|
if (rows_per_block == mmf_rows_per_block) {
|
||||||
src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
|
mul_mat_f_switch_cols_per_block<half2, mmf_rows_per_block>(
|
||||||
ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
|
src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
|
||||||
ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr);
|
ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
|
||||||
|
ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr);
|
||||||
|
} else {
|
||||||
|
mul_mat_f_switch_cols_per_block<half2, mmf_rows_per_block_cdna>(
|
||||||
|
src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
|
||||||
|
ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
|
||||||
|
ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr);
|
||||||
|
}
|
||||||
} break;
|
} break;
|
||||||
case GGML_TYPE_BF16: {
|
case GGML_TYPE_BF16: {
|
||||||
const nv_bfloat162 * src0_d = (const nv_bfloat162 *) src0->data;
|
const nv_bfloat162 * src0_d = (const nv_bfloat162 *) src0->data;
|
||||||
constexpr int vals_per_T = 2;
|
constexpr int vals_per_T = 2;
|
||||||
mul_mat_f_switch_cols_per_block(
|
if (rows_per_block == mmf_rows_per_block) {
|
||||||
src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
|
mul_mat_f_switch_cols_per_block<nv_bfloat162, mmf_rows_per_block>(
|
||||||
ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
|
src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
|
||||||
ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr);
|
ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
|
||||||
|
ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr);
|
||||||
|
} else {
|
||||||
|
mul_mat_f_switch_cols_per_block<nv_bfloat162, mmf_rows_per_block_cdna>(
|
||||||
|
src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
|
||||||
|
ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
|
||||||
|
ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr);
|
||||||
|
}
|
||||||
} break;
|
} break;
|
||||||
default:
|
default:
|
||||||
GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type));
|
GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type));
|
||||||
|
|
@ -140,7 +176,7 @@ bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (src0_ne[1] % MMF_ROWS_PER_BLOCK != 0) {
|
if (src0_ne[1] % get_mmf_rows_per_block(cc, warp_size) != 0) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -6,8 +6,6 @@
|
||||||
|
|
||||||
using namespace ggml_cuda_mma;
|
using namespace ggml_cuda_mma;
|
||||||
|
|
||||||
#define MMF_ROWS_PER_BLOCK 32
|
|
||||||
|
|
||||||
struct mmf_ids_data {
|
struct mmf_ids_data {
|
||||||
const int32_t * ids_src_compact = nullptr;
|
const int32_t * ids_src_compact = nullptr;
|
||||||
const int32_t * ids_dst_compact = nullptr;
|
const int32_t * ids_dst_compact = nullptr;
|
||||||
|
|
@ -228,21 +226,30 @@ static __global__ void mul_mat_f(
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
float sum = 0.0f;
|
float sum[rows_per_block/warp_size] = {0.0f};
|
||||||
static_assert(rows_per_block == warp_size, "need loop/check");
|
static_assert((rows_per_block % warp_size) == 0, "rows_per_block must be a multiple of warp_size.");
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i0 = 0; i0 < nwarps*rows_per_block; i0 += rows_per_block) {
|
for (int i0 = 0; i0 < nwarps*rows_per_block; i0 += rows_per_block) {
|
||||||
const int i = i0 + threadIdx.x;
|
#pragma unroll
|
||||||
|
for (int i1 = 0; i1 < sizeof(sum)/sizeof(sum[0]); ++i1) {
|
||||||
|
const int i = i0 + i1*warp_size + threadIdx.x;
|
||||||
|
|
||||||
sum += buf_iw[j*kiw + i];
|
sum[i1] += buf_iw[j*kiw + i];
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if constexpr (!has_ids) {
|
if constexpr (!has_ids) {
|
||||||
dst[j*stride_col_dst + row0 + threadIdx.x] = sum;
|
#pragma unroll
|
||||||
|
for (int i0 = 0; i0 < sizeof(sum)/sizeof(sum[0]); ++i0) {
|
||||||
|
dst[j*stride_col_dst + row0 + i0*warp_size + threadIdx.x] = sum[i0];
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
const int slot = (j < cols_per_block) ? slot_map[j] : -1;
|
const int slot = (j < cols_per_block) ? slot_map[j] : -1;
|
||||||
if (slot >= 0 && (col_base + j) < ncols_dst_total) {
|
if (slot >= 0 && (col_base + j) < ncols_dst_total) {
|
||||||
dst[slot*stride_channel_dst + j*stride_col_dst + row0 + threadIdx.x] = sum;
|
#pragma unroll
|
||||||
|
for (int i0 = 0; i0 < sizeof(sum)/sizeof(sum[0]); ++i0) {
|
||||||
|
dst[slot*stride_channel_dst + j*stride_col_dst + row0 + i0*warp_size + threadIdx.x] = sum[i0];
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -497,13 +504,16 @@ static __global__ void mul_mat_f_ids(
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
float sum = 0.0f;
|
float sum[rows_per_block/warp_size] = {0.0f};
|
||||||
static_assert(rows_per_block == warp_size, "need loop/check");
|
static_assert((rows_per_block % warp_size) == 0, "rows_per_block must be a multiple of warp_size.");
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i0 = 0; i0 < nwarps*rows_per_block; i0 += rows_per_block) {
|
for (int i0 = 0; i0 < nwarps*rows_per_block; i0 += rows_per_block) {
|
||||||
const int i = i0 + threadIdx.x;
|
#pragma unroll
|
||||||
|
for (int i1 = 0; i1 < sizeof(sum)/sizeof(sum[0]); ++i1) {
|
||||||
|
const int i = i0 + i1*warp_size + threadIdx.x;
|
||||||
|
|
||||||
sum += buf_iw[j*kiw + i];
|
sum[i1] += buf_iw[j * kiw + i];
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const int global_j = col_base + j;
|
const int global_j = col_base + j;
|
||||||
|
|
@ -513,7 +523,10 @@ static __global__ void mul_mat_f_ids(
|
||||||
const int token = (int) qrm.x;
|
const int token = (int) qrm.x;
|
||||||
if (token < ncols_dst_total) {
|
if (token < ncols_dst_total) {
|
||||||
const int slot = (int) qrm.y;
|
const int slot = (int) qrm.y;
|
||||||
dst[slot*stride_channel_dst + token*stride_col_dst + row0 + threadIdx.x] = sum;
|
#pragma unroll
|
||||||
|
for (int i0 = 0; i0 < sizeof(sum)/sizeof(sum[0]); ++i0) {
|
||||||
|
dst[slot * stride_channel_dst + token * stride_col_dst + row0 + i0*warp_size + threadIdx.x] = sum[i0];
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -529,7 +542,7 @@ static __global__ void mul_mat_f_ids(
|
||||||
#endif // (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE)
|
#endif // (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE)
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename T, int cols_per_block, int nwarps>
|
template<typename T, int rows_per_block, int cols_per_block, int nwarps>
|
||||||
static inline void mul_mat_f_switch_ids(
|
static inline void mul_mat_f_switch_ids(
|
||||||
const T * x, const float * y, const int32_t * ids, float * dst,
|
const T * x, const float * y, const int32_t * ids, float * dst,
|
||||||
const int64_t ncols_x, const int64_t ncols_dst, const int64_t nchannels_dst,
|
const int64_t ncols_x, const int64_t ncols_dst, const int64_t nchannels_dst,
|
||||||
|
|
@ -553,7 +566,7 @@ static inline void mul_mat_f_switch_ids(
|
||||||
const uint3 sis1_fd = ids_data->sis1 > 0 ? init_fastdiv_values((uint32_t) ids_data->sis1) : make_uint3(0, 0, 1);
|
const uint3 sis1_fd = ids_data->sis1 > 0 ? init_fastdiv_values((uint32_t) ids_data->sis1) : make_uint3(0, 0, 1);
|
||||||
const uint3 nch_fd = init_fastdiv_values((uint32_t) nchannels_dst);
|
const uint3 nch_fd = init_fastdiv_values((uint32_t) nchannels_dst);
|
||||||
|
|
||||||
mul_mat_f_ids<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps><<<block_nums_ids, block_dims, nbytes_shared_total, stream>>>
|
mul_mat_f_ids<T, rows_per_block, cols_per_block, nwarps><<<block_nums_ids, block_dims, nbytes_shared_total, stream>>>
|
||||||
(x, y, ids_data->ids_src_compact, ids_data->ids_dst_compact, ids_data->expert_bounds_dev, dst,
|
(x, y, ids_data->ids_src_compact, ids_data->ids_dst_compact, ids_data->expert_bounds_dev, dst,
|
||||||
ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
||||||
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||||
|
|
@ -564,19 +577,19 @@ static inline void mul_mat_f_switch_ids(
|
||||||
dim3 block_nums_ids = block_nums;
|
dim3 block_nums_ids = block_nums;
|
||||||
block_nums_ids.y *= col_tiles;
|
block_nums_ids.y *= col_tiles;
|
||||||
|
|
||||||
mul_mat_f<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps, true><<<block_nums_ids, block_dims, nbytes_shared_total, stream>>>
|
mul_mat_f<T, rows_per_block, cols_per_block, nwarps, true><<<block_nums_ids, block_dims, nbytes_shared_total, stream>>>
|
||||||
(x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
(x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
||||||
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||||
} else {
|
} else {
|
||||||
mul_mat_f<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps, false><<<block_nums, block_dims, nbytes_shared_total, stream>>>
|
mul_mat_f<T, rows_per_block, cols_per_block, nwarps, false><<<block_nums, block_dims, nbytes_shared_total, stream>>>
|
||||||
(x, y, ids, dst, ncols_x, cols_per_block, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
(x, y, ids, dst, ncols_x, cols_per_block, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
||||||
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T, int cols_per_block>
|
template <typename T, int rows_per_block, int cols_per_block>
|
||||||
void mul_mat_f_cuda(
|
void mul_mat_f_cuda(
|
||||||
const T * x, const float * y, const int32_t * ids, float * dst,
|
const T * x, const float * y, const int32_t * ids, float * dst,
|
||||||
const int64_t ncols_x, const int64_t nrows_x, const int64_t ncols_dst,
|
const int64_t ncols_x, const int64_t nrows_x, const int64_t ncols_dst,
|
||||||
|
|
@ -614,7 +627,6 @@ void mul_mat_f_cuda(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
constexpr int rows_per_block = MMF_ROWS_PER_BLOCK;
|
|
||||||
const int nbytes_shared_iter = nwarps_best * (volta_mma_available(cc) ? tile_A_32::I : tile_A_16::I) * (warp_size + 4) * 4;
|
const int nbytes_shared_iter = nwarps_best * (volta_mma_available(cc) ? tile_A_32::I : tile_A_16::I) * (warp_size + 4) * 4;
|
||||||
const int nbytes_cols_per_block_pad = amd_wmma_available(cc) ? tile_B_16::I : tile_B_8::I;
|
const int nbytes_cols_per_block_pad = amd_wmma_available(cc) ? tile_B_16::I : tile_B_8::I;
|
||||||
const int nbytes_shared_combine = GGML_PAD(cols_per_block, nbytes_cols_per_block_pad) * (nwarps_best*rows_per_block + 4) * 4;
|
const int nbytes_shared_combine = GGML_PAD(cols_per_block, nbytes_cols_per_block_pad) * (nwarps_best*rows_per_block + 4) * 4;
|
||||||
|
|
@ -628,56 +640,56 @@ void mul_mat_f_cuda(
|
||||||
|
|
||||||
switch (nwarps_best) {
|
switch (nwarps_best) {
|
||||||
case 1: {
|
case 1: {
|
||||||
mul_mat_f_switch_ids<T, cols_per_block, 1>(
|
mul_mat_f_switch_ids<T, rows_per_block, cols_per_block, 1>(
|
||||||
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
||||||
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
|
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
|
||||||
ids_data);
|
ids_data);
|
||||||
} break;
|
} break;
|
||||||
case 2: {
|
case 2: {
|
||||||
mul_mat_f_switch_ids<T, cols_per_block, 2>(
|
mul_mat_f_switch_ids<T, rows_per_block, cols_per_block, 2>(
|
||||||
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
||||||
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
|
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
|
||||||
ids_data);
|
ids_data);
|
||||||
} break;
|
} break;
|
||||||
case 3: {
|
case 3: {
|
||||||
mul_mat_f_switch_ids<T, cols_per_block, 3>(
|
mul_mat_f_switch_ids<T, rows_per_block, cols_per_block, 3>(
|
||||||
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
||||||
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
|
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
|
||||||
ids_data);
|
ids_data);
|
||||||
} break;
|
} break;
|
||||||
case 4: {
|
case 4: {
|
||||||
mul_mat_f_switch_ids<T, cols_per_block, 4>(
|
mul_mat_f_switch_ids<T, rows_per_block, cols_per_block, 4>(
|
||||||
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
||||||
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
|
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
|
||||||
ids_data);
|
ids_data);
|
||||||
} break;
|
} break;
|
||||||
case 5: {
|
case 5: {
|
||||||
mul_mat_f_switch_ids<T, cols_per_block, 5>(
|
mul_mat_f_switch_ids<T, rows_per_block, cols_per_block, 5>(
|
||||||
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
||||||
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
|
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
|
||||||
ids_data);
|
ids_data);
|
||||||
} break;
|
} break;
|
||||||
case 6: {
|
case 6: {
|
||||||
mul_mat_f_switch_ids<T, cols_per_block, 6>(
|
mul_mat_f_switch_ids<T, rows_per_block, cols_per_block, 6>(
|
||||||
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
||||||
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
|
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
|
||||||
ids_data);
|
ids_data);
|
||||||
} break;
|
} break;
|
||||||
case 7: {
|
case 7: {
|
||||||
mul_mat_f_switch_ids<T, cols_per_block, 7>(
|
mul_mat_f_switch_ids<T, rows_per_block, cols_per_block, 7>(
|
||||||
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
||||||
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
|
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
|
||||||
ids_data);
|
ids_data);
|
||||||
} break;
|
} break;
|
||||||
case 8: {
|
case 8: {
|
||||||
mul_mat_f_switch_ids<T, cols_per_block, 8>(
|
mul_mat_f_switch_ids<T, rows_per_block, cols_per_block, 8>(
|
||||||
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
|
||||||
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||||
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
|
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
|
||||||
|
|
@ -691,7 +703,7 @@ void mul_mat_f_cuda(
|
||||||
GGML_UNUSED_VARS(nchannels_y);
|
GGML_UNUSED_VARS(nchannels_y);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T, int rows_per_block>
|
||||||
static void mul_mat_f_switch_cols_per_block(
|
static void mul_mat_f_switch_cols_per_block(
|
||||||
const T * x, const float * y, const int32_t * ids, float * dst,
|
const T * x, const float * y, const int32_t * ids, float * dst,
|
||||||
const int64_t ncols_x, const int64_t nrows_x, const int64_t ncols_dst,
|
const int64_t ncols_x, const int64_t nrows_x, const int64_t ncols_dst,
|
||||||
|
|
@ -708,82 +720,82 @@ static void mul_mat_f_switch_cols_per_block(
|
||||||
|
|
||||||
switch (ncols_case) {
|
switch (ncols_case) {
|
||||||
case 1: {
|
case 1: {
|
||||||
mul_mat_f_cuda<T, 1>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
mul_mat_f_cuda<T, rows_per_block, 1>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||||
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
|
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
|
||||||
} break;
|
} break;
|
||||||
case 2: {
|
case 2: {
|
||||||
mul_mat_f_cuda<T, 2>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
mul_mat_f_cuda<T, rows_per_block, 2>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||||
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
|
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
|
||||||
} break;
|
} break;
|
||||||
case 3: {
|
case 3: {
|
||||||
mul_mat_f_cuda<T, 3>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
mul_mat_f_cuda<T, rows_per_block, 3>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||||
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
|
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
|
||||||
} break;
|
} break;
|
||||||
case 4: {
|
case 4: {
|
||||||
mul_mat_f_cuda<T, 4>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
mul_mat_f_cuda<T, rows_per_block, 4>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||||
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
|
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
|
||||||
} break;
|
} break;
|
||||||
case 5: {
|
case 5: {
|
||||||
mul_mat_f_cuda<T, 5>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
mul_mat_f_cuda<T, rows_per_block, 5>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||||
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
|
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
|
||||||
} break;
|
} break;
|
||||||
case 6: {
|
case 6: {
|
||||||
mul_mat_f_cuda<T, 6>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
mul_mat_f_cuda<T, rows_per_block, 6>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||||
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
|
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
|
||||||
} break;
|
} break;
|
||||||
case 7: {
|
case 7: {
|
||||||
mul_mat_f_cuda<T, 7>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
mul_mat_f_cuda<T, rows_per_block, 7>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||||
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
|
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
|
||||||
} break;
|
} break;
|
||||||
case 8: {
|
case 8: {
|
||||||
mul_mat_f_cuda<T, 8>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
mul_mat_f_cuda<T, rows_per_block, 8>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||||
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
|
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
|
||||||
} break;
|
} break;
|
||||||
case 9: {
|
case 9: {
|
||||||
mul_mat_f_cuda<T, 9>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
mul_mat_f_cuda<T, rows_per_block, 9>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||||
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
|
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
|
||||||
} break;
|
} break;
|
||||||
case 10: {
|
case 10: {
|
||||||
mul_mat_f_cuda<T, 10>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
mul_mat_f_cuda<T, rows_per_block, 10>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||||
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
|
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
|
||||||
} break;
|
} break;
|
||||||
case 11: {
|
case 11: {
|
||||||
mul_mat_f_cuda<T, 11>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
mul_mat_f_cuda<T, rows_per_block, 11>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||||
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
|
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
|
||||||
} break;
|
} break;
|
||||||
case 12: {
|
case 12: {
|
||||||
mul_mat_f_cuda<T, 12>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
mul_mat_f_cuda<T, rows_per_block, 12>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||||
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
|
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
|
||||||
} break;
|
} break;
|
||||||
case 13: {
|
case 13: {
|
||||||
mul_mat_f_cuda<T, 13>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
mul_mat_f_cuda<T, rows_per_block, 13>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||||
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
|
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
|
||||||
} break;
|
} break;
|
||||||
case 14: {
|
case 14: {
|
||||||
mul_mat_f_cuda<T, 14>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
mul_mat_f_cuda<T, rows_per_block, 14>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||||
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
|
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
|
||||||
} break;
|
} break;
|
||||||
case 15: {
|
case 15: {
|
||||||
mul_mat_f_cuda<T, 15>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
mul_mat_f_cuda<T, rows_per_block, 15>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||||
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
|
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
|
||||||
} break;
|
} break;
|
||||||
case 16: {
|
case 16: {
|
||||||
mul_mat_f_cuda<T, 16>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
mul_mat_f_cuda<T, rows_per_block, 16>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
||||||
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
||||||
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
|
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
|
||||||
} break;
|
} break;
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue